# coding=utf-8
m4_changequote(`')
"""
@file control.py_in
@brief controller classes (e.g. iteration controller)
@namespace utilities
@brief driver functions shared by modules
"""
import plpy
from utilities import __mad_version
version_wrapper = __mad_version()
from utilities import unique_string
_unique_string = unique_string
from control import MinWarning
# Build-time feature flags: each value is filled in by the m4 preprocessor
# when this template is expanded.
# NOTE(review): the macro arguments are empty in this copy of the file, so the
# symbol being tested and the two alternative values cannot be determined from
# here -- confirm against the build's macro definitions.

# When true, the iteration controller below keeps the inter-iteration state in
# Python attributes (new_state / old_state) and only writes it to the state
# table on leaving the "with" block.
STATE_IN_MEM = m4_ifdef(, , )
# Not referenced anywhere in the visible part of this file.
HAS_FUNCTION_PROPERTIES = m4_ifdef(, , )
# Not referenced anywhere in the visible part of this file.
UDF_ON_SEGMENT_NOT_ALLOWED = m4_ifdef(, , )
class IterationControllerComposite:
    """
    @brief Abstraction for implementing driver functions in PL/Python

    This class encapsulates handling of the inter-iteration state. The design
    goal is to avoid any conversion between backend-native types and those of
    procedural languages like PL/Python. Therefore, the expectation is that
    ***only "template" parameters are passed as PL/Python arguments***,
    whereas non-template arguments are provided in an argument table. Here,
    "template" arguments are those parameters that cannot be SQL parameters,
    ***such as table and column names***.

    The inter-state iteration table contains two columns:
    - <tt>_iteration INTEGER</tt> - The 0-based iteration number
    - <tt>_state self.kwargs.stateType</tt> - The state (after
      iteration \c _iteration)

    When the module-level flag STATE_IN_MEM is true, the state is kept in the
    Python attributes \c new_state and \c old_state between iterations and is
    written to the state table only once, when the "with" block is left.
    In that mode the state is always cast to <tt>kmeans_state</tt>.
    """

    # Components of the composite state, in the positional order in which they
    # are bound to prepared-plan parameters ($1..$5, or $6..$10 for the
    # previous state).
    _STATE_FIELDS = ('centroids', 'old_centroid_ids', 'cluster_variance',
                     'objective_fn', 'frac_reassigned')
    # SQL parameter types matching _STATE_FIELDS one-to-one.
    _STATE_FIELD_TYPES = ("DOUBLE PRECISION[]", "INTEGER[]",
                          "DOUBLE PRECISION[]",
                          "DOUBLE PRECISION", "DOUBLE PRECISION")

    def __init__(self, rel_args, rel_state, stateType,
                 temporaryTables=True,
                 truncAfterIteration=False,
                 schema_madlib="MADLIB_SCHEMA_MISSING",
                 verbose=False,
                 **kwargs):
        """
        @param rel_args Name of the (single-row) argument table
        @param rel_state Name of the inter-iteration state table
        @param stateType SQL type of the state; may contain the placeholder
            <tt>{schema_madlib}</tt>
        @param temporaryTables If true, both tables are addressed in the
            \c pg_temp schema and the state table is created as TEMPORARY
        @param truncAfterIteration If true, update() deletes the rows of all
            earlier iterations after inserting the new state
        @param schema_madlib Schema name substituted into SQL templates
        @param verbose If true, runSQL() emits every statement as a notice
        @param kwargs Additional values made available to all SQL templates
        """
        schema_prefix = 'pg_temp.' if temporaryTables else ''
        self.kwargs = kwargs
        self.kwargs.update(
            unqualified_rel_state=rel_state,
            rel_args=schema_prefix + rel_args,
            rel_state=schema_prefix + rel_state,
            stateType=stateType.format(schema_madlib=schema_madlib),
            schema_madlib=schema_madlib)
        self.temporaryTables = temporaryTables
        self.truncAfterIteration = truncAfterIteration
        self.verbose = verbose
        self.inWith = False
        # -1 means "no state yet"; the first update() produces iteration 0.
        self.iteration = -1
        if STATE_IN_MEM:
            self.new_state = None
            self.old_state = None

    def __enter__(self):
        """
        (Re)create the empty state table and mark the controller as active.

        @return This controller
        """
        with MinWarning('warning'):
            self.runSQL("""
                DROP TABLE IF EXISTS {rel_state};
                CREATE {temp} TABLE {unqualified_rel_state} (
                    _iteration INTEGER m4_ifdef(, , ),
                    _state {stateType}
                )m4_ifdef(, , );
                """.format(
                    temp='TEMPORARY' if self.temporaryTables else '',
                    **self.kwargs))
        self.inWith = True
        return self

    def __exit__(self, type, value, tb):
        """
        Mark the controller as inactive and, in in-memory mode, write the
        latest state to the state table.

        Nothing is written if update() has never produced a state.
        """
        self.inWith = False
        # Without the None check, leaving the block before the first update()
        # (e.g. because the body raised) would fail with a TypeError here and
        # hide the original error.
        if STATE_IN_MEM and self.new_state is not None:
            insert_plan = plpy.prepare("""
                INSERT INTO {rel_state}
                SELECT $1, CAST( ({schema_madlib}.array_to_2d($2), $3, $4, $5, $6) AS {schema_madlib}.kmeans_state)
                """.format(**self.kwargs),
                ["INTEGER"] + list(self._STATE_FIELD_TYPES))
            plpy.execute(
                insert_plan,
                [self.new_state['_iteration']] +
                self._state_params(self.new_state))

    def _state_params(self, state):
        """
        Flatten an in-memory state row into a list of plan parameters.

        @param state Mapping holding the keys in _STATE_FIELDS, or None
        @return Values in _STATE_FIELDS order; all None if \c state is None
        """
        if state is None:
            return [None] * len(self._STATE_FIELDS)
        return [state[field] for field in self._STATE_FIELDS]

    def runSQL(self, sql):
        """
        Execute a SQL string, echoing it as a notice first if verbose.

        @return The result object of <tt>plpy.execute</tt>
        """
        if self.verbose:
            plpy.notice(sql)
        return plpy.execute(sql)

    def evaluate(self, expression):
        """
        Evaluate the given expression. The expression may depend on
        the current inter-iteration state and all arguments

        @param expression SQL expression. The
            following names are defined and can be used in the condition:
            - \c _args - The (single-row) argument table
            - \c _state - The row of the state table containing the latest
              inter-iteration state
        @return None if the query yields no row or \c expression evaluates
            to NULL, otherwise the value of \c expression
        """
        # In both branches the expression is substituted first and the result
        # is formatted a second time, so the expression itself may use
        # placeholders such as {iteration} or {schema_madlib}.
        if STATE_IN_MEM:
            cast_str = "CAST (({schema_madlib}.array_to_2d($1), $2, $3, $4, $5) AS {schema_madlib}.kmeans_state)".format(**self.kwargs)
            # _state must be the full composite built from all five
            # parameters, as in update() and __exit__; hence {curr_state}
            # rather than only the 2-D centroid array.
            eval_plan = plpy.prepare("""
                SELECT ({expression}) AS expression
                FROM {{rel_args}} AS _args
                    LEFT OUTER JOIN (
                        SELECT {{curr_state}} AS _state
                    ) AS _state ON True
                """.format(expression=expression).
                    format(iteration=self.iteration,
                           curr_state=cast_str, **self.kwargs),
                list(self._STATE_FIELD_TYPES))
            resultObject = plpy.execute(
                eval_plan, self._state_params(self.new_state))
        else:
            resultObject = self.runSQL("""
                SELECT ({expression}) AS expression
                FROM {{rel_args}} AS _args
                    LEFT OUTER JOIN (
                        SELECT *
                        FROM {{rel_state}} AS _state
                        WHERE _state._iteration = {{iteration}}
                    ) AS _state ON True
                """.format(expression=expression).
                    format(iteration=self.iteration,
                           **self.kwargs))

        if resultObject.nrows() == 0:
            return None
        else:
            return resultObject[0]['expression']

    def test(self, condition):
        """
        Test if the given condition is satisfied. The condition may depend on
        the current inter-iteration state and all arguments

        @param condition Boolean SQL expression. The
            following names are defined and can be used in the condition:
            - \c _args - The (single-row) argument table
            - \c _state - The row of the state table containing the latest
              inter-iteration state
        @return None if \c condition evaluates to NULL, otherwise the Boolean
            value of \c condition
        """
        return self.evaluate("CAST(({0}) AS BOOLEAN)".format(condition))

    def update(self, newState, **updateKwargs):
        """
        Update the inter-iteration state

        @param newState SQL expression of type
            <tt>self.kwargs.stateType</tt>. The
            following names are defined and can be used in the condition:
            - \c _args - The (single-row) argument table
            - \c _state - The row of the state table containing the latest
              inter-iteration state
            .
            Note that <tt>{iteration}</tt> will still be the current iteration.
            For instance, it could be used in the expression as a WHERE
            condition: <tt>[...] WHERE _state._iteration = {iteration}</tt>
        @param updateKwargs Extra values substituted into \c newState. Entries
            of self.kwargs take precedence over entries with the same name.

        This updates the current inter-iteration state to the result of
        evaluating \c newState. If <tt>self.truncAfterIteration</tt> is true,
        this will replace the old state, otherwise the history of all old states
        is kept.
        """
        updateKwargs.update(**self.kwargs)
        # Substituted before the counter is advanced, so {iteration} inside
        # newState refers to the iteration that is being replaced.
        newState = newState.format(iteration=self.iteration, **updateKwargs)
        self.iteration = self.iteration + 1

        if STATE_IN_MEM:
            # $1..$5 carry the current state, $6..$10 the one before it.
            cast_str = """CAST (({schema_madlib}.array_to_2d($1), $2, $3, $4, $5)
                AS {schema_madlib}.kmeans_state)""".format(**self.kwargs)
            cast_str_old = """CAST (({schema_madlib}.array_to_2d($6), $7, $8, $9, $10)
                AS {schema_madlib}.kmeans_state)""".format(**self.kwargs)
            cast_type = list(self._STATE_FIELD_TYPES) * 2
            # Parameters are captured before the states are shifted below.
            cast_para = (self._state_params(self.new_state) +
                         self._state_params(self.old_state))

            # Placeholders {curr_state} / {old_state} that survived the first
            # formatting of newState are resolved by the second .format below.
            updateKwargs.update(curr_state=cast_str, old_state=cast_str_old)
            self.old_state = self.new_state
            update_plan = plpy.prepare("""
                SELECT
                    {iteration} AS _iteration,
                    {schema_madlib}.array_to_1d((_state).centroids) AS centroids,
                    (_state).old_centroid_ids,
                    (_state).cluster_variance,
                    (_state).objective_fn,
                    (_state).frac_reassigned
                FROM
                (
                    SELECT ({newState}) AS _state
                ) q
                """.format(iteration=self.iteration,
                           newState=newState, **self.kwargs).format(**updateKwargs),
                cast_type)
            self.new_state = plpy.execute(update_plan, cast_para)[0]
        else:
            self.runSQL("""
                INSERT INTO {rel_state}
                SELECT
                    {iteration},
                    ({newState})
                """.format(iteration=self.iteration, newState=newState,
                           **self.kwargs))
            if self.truncAfterIteration:
                self.runSQL("""
                    DELETE FROM {rel_state} AS _state
                    WHERE _state._iteration < {iteration}
                    """.format(iteration=self.iteration, **self.kwargs))
m4_changequote(, )