"""
@file in_mem_group_control.py_in

@brief in-memory grouping controller classes
"""
import plpy
import math
from control import MinWarning
from utilities import unique_string
from collections import namedtuple
from collections import Iterable


class BaseState(object):
    """@brief Abstraction for intermediate iteration state

    Wraps a dict that maps a group key to the transition state of that
    group. A state value of None means "no state yet" for the group.
    ``_is_none`` caches the answer to "is every stored state None?"; it is
    set to None whenever the cache has to be recomputed.
    """

    def __init__(self, **kwargs):
        self._state = {}
        self._is_none = None
        self.initialize(**kwargs)

    def __len__(self):
        """Return the number of groups currently stored."""
        return len(self._state)

    def __del__(self):
        del self._state

    def __getitem__(self, k):
        return self._state[k]

    def __setitem__(self, k, v):
        # NOTE: does not reset the _is_none cache; update_from_state() does
        # so itself after assigning.
        self._state[k] = v

    @property
    def keys(self):
        """Group keys, in the dict's iteration order."""
        return self._state.keys()

    @property
    def values(self):
        """All per-group states flattened into a single list.

        Returns an empty list when every stored state is None. Otherwise
        each stored state is iterated, so every state must be iterable.
        """
        if self.is_none():
            return []
        return [s for x in self._state.values() for s in x]

    def delete(self, keys_to_remove):
        """Remove the given group keys; keys that are not stored are ignored."""
        for k in keys_to_remove:
            self._state.pop(k, None)
        # the set of stored states changed, force is_none() to recompute
        self._is_none = None

    def initialize(self, col_grp_key='', col_grp_state='', ret_states=None,
                   **kwargs):
        """Populate the state from query rows; see update() for arguments."""
        self.update(col_grp_key, col_grp_state, ret_states)

    def isclose_to_zero(self, a, b=0.0, rel_tol=1e-09, abs_tol=0.0):
        '''
        A variation of Python 3.5 math.isclose()
        https://hg.python.org/cpython/file/tip/Modules/mathmodule.c#l1993

        NOTE: with the defaults (b=0.0, abs_tol=0.0) both relative bounds
        collapse, so only a value exactly equal to zero is reported as
        close. Pass a positive abs_tol to get a real tolerance around zero.

        @return True if a and b are close under the given tolerances
        @raise ValueError if a tolerance is negative
        '''
        # sanity check on the inputs
        if rel_tol < 0 or abs_tol < 0:
            raise ValueError("tolerances must be non-negative")

        # short circuit exact equality -- needed to catch two infinities of
        # the same sign. And perhaps speeds things up a bit sometimes.
        if a == b:
            return True

        # This catches the case of two infinities of opposite sign, or
        # one infinity and one finite number. Two infinities of opposite
        # sign would otherwise have an infinite relative tolerance.
        # Two infinities of the same sign are caught by the equality check
        # above.
        if math.isinf(a) or math.isinf(b):
            return False

        # now do the regular computation
        # this is essentially the "weak" test from the Boost library
        diff = math.fabs(b - a)
        result = (((diff <= math.fabs(rel_tol * b)) or
                   (diff <= math.fabs(rel_tol * a))) or
                  (diff <= abs_tol))
        return result

    def are_last_state_value_zero(self):
        """Return True only if the last element of every group's state
        passes isclose_to_zero() with its default tolerances.

        Returns False as soon as one group has a non-zero last element and
        True when no group is stored.
        """
        return all(self.isclose_to_zero(val[-1])
                   for val in self._state.itervalues())

    def update(self, col_grp_key, col_grp_state, ret_states):
        """Store the states found in ret_states.

        @param col_grp_key Name of the group-key column in each row
        @param col_grp_state Name of the state column in each row. An empty
            string means "register the keys only": every key in ret_states
            is stored with state None.
        @param ret_states Sequence of dict-like rows
        @return List of group keys whose new state is None (failed groups).
            The previously stored state of such a group is left untouched.
        """
        failed_grp_keys = []
        if not ret_states:
            return failed_grp_keys
        t0 = ret_states[0]
        # no key column in table ret_states
        if col_grp_key not in t0:
            return failed_grp_keys

        # initialize state to None
        if col_grp_state == '':
            self._is_none = True
            for s in ret_states:
                self._state[s[col_grp_key]] = None
            return failed_grp_keys

        for t in ret_states:
            _grp_key, _grp_state = t[col_grp_key], t[col_grp_state]
            if _grp_state is None:
                failed_grp_keys.append(_grp_key)
            else:
                self._state[_grp_key] = _grp_state
        # no need to update if all failed
        if len(failed_grp_keys) < len(self):
            self._is_none = False
        return failed_grp_keys

    # entries in self not in other are kept
    def update_from_state(self, other, keys=None):
        """Copy states from another BaseState.

        @param other Source state; anything that is not a BaseState is
            ignored
        @param keys Keys to copy; None copies every entry of other
        """
        if not isinstance(other, BaseState):
            return
        if keys is None:
            self._state.update(other._state)
        else:
            for k in keys:
                self[k] = other[k]
        # reset cache
        self._is_none = None

    def sync_from(self, other):
        """Replace the content of self with the content of other."""
        self._state = {}
        self.update_from_state(other)
        self._is_none = other.is_none()

    def is_none(self):
        """Return True if every stored state is None (or nothing is stored)."""
        if self._is_none is None:
            self._is_none = all(v is None
                                for v in self._state.itervalues())
        return self._is_none

    def interpret(self, schema_madlib, state_type, keys=None):
        """Run internal_linear_svm_igd_result() on the state of each key.

        @param schema_madlib Schema holding the MADlib functions
        @param state_type SQL type used for the state parameter of the plan
        @param keys A single key, an iterable of keys, or None for all keys
        @return dict mapping each key to the first result row, which has
            the columns 'loss' and 'norm_of_gradient'
        """
        if keys is None:
            keys = self.keys
        elif isinstance(keys, basestring) or not isinstance(keys, Iterable):
            # basestring rather than str: a single unicode key must not be
            # split into its characters either
            keys = [keys]
        # keys is walked twice below, so a one-shot iterable must be
        # materialized first
        keys = list(keys)
        s = dict.fromkeys(keys)
        plan = plpy.prepare(
            """
            SELECT
                (result).loss               AS loss,
                (result).norm_of_gradient   AS norm_of_gradient
            FROM (
                SELECT {schema_madlib}.internal_linear_svm_igd_result($1)
                    AS result
            ) subq
            """.format(schema_madlib=schema_madlib), [state_type])
        for k in keys:
            s[k] = plpy.execute(plan, [self._state[k]])[0]
        return s


class Bytea8State(BaseState):
    """@brief bytea8 type state

    Same as BaseState, except that ``values`` returns the per-group states
    as they are instead of flattening them.
    """

    def __init__(self, **kwargs):
        super(Bytea8State, self).__init__(**kwargs)

    @property
    def values(self):
        return self._state.values()


def state_factory(is_bytea8, **kwargs):
    """Return a Bytea8State if is_bytea8 is true, otherwise a BaseState."""
    state_class = Bytea8State if is_bytea8 else BaseState
    return state_class(**kwargs)


class GroupIterationController(object):
    """
    @brief Abstraction for implementing in-memory iteration controller for
    SQL aggregate with STYPE=madlib.bytea8 (e.g. DynamicStruct) in PL/Python
    """

    def __init__(self, arg_dict):
        """
        arg_dict: Dictionary containing arguments to be defined by the
        calling function. A copy of it, with the defaults below filled in,
        is kept as self.kwargs and used to format every SQL template.

        Keys read by this class:
            schema_madlib: Schema holding the MADlib functions
            grouping_col: Grouping column expression(s), or None for no
                grouping (replaced by the string "NULL" in self.kwargs)
            grouping_str: Expression list used to build the group key
            rel_source: Name of the source relation
            rel_state: Name of the temporary state table
            col_grp_state: Name of the state column
            col_grp_iteration: Name of the iteration column
            col_dep_var: Name of the dependent column name
            col_ind_var: Name of the independent column name
            stepsize: Only read by info() when verbose is set
        Optional:
            state_type: Type of the transition state (can be
                double precision[], float8[] or <schema_madlib>.bytea8).
                Defaults to 'double precision[]'; it is formatted with
                arg_dict, so it may contain placeholders such as
                {schema_madlib}
            as_rel_source: Alias of the source relation, default '_src'
            verbose: Enables info() output, default False
            col_n_tuples: Name of the column containing count of tuples.
                Set to a unique string if not defined
            col_grp_key: Name of the group-key column.
                Set to a unique string if not defined
        """
        self.schema_madlib = arg_dict['schema_madlib']
        self.in_with = False
        self.iteration = -1
        self.is_group_null = arg_dict["grouping_col"] is None
        self.verbose = arg_dict.get('verbose', False)
        self.kwargs = dict(arg_dict)
        self.kwargs.update(
            as_rel_source=arg_dict.get('as_rel_source', '_src'),
            state_type=arg_dict.get('state_type',
                                    'double precision[]').format(**arg_dict),
            col_grp_null=unique_string(desp='col_grp_null'),
            col_n_tuples=self.kwargs.get('col_n_tuples',
                                         unique_string(desp='col_n_tuples')),
            col_grp_key=self.kwargs.get('col_grp_key',
                                        unique_string(desp='col_grp_key')),
            grouping_col=("NULL" if arg_dict["grouping_col"] is None
                          else arg_dict["grouping_col"]),
        )
        # group key -> number of tuples in the group, filled by __enter__
        self.grp_to_n_tuples = {}
        # keys of the groups for which update() produced a NULL state
        self.failed_grp_keys = []
        self.is_state_type_bytea8 = False
        if self.kwargs['state_type'] == "{0}.bytea8".format(self.schema_madlib):
            self.is_state_type_bytea8 = True
        elif (self.kwargs['state_type'].lower() == "double precision[]" or
                self.kwargs['state_type'].lower() == "float8[]"):
            self.is_state_type_bytea8 = False
        else:
            plpy.error("Internal error: unexpected state type!")

        # new_states: states of the groups still in progress
        # old_states: copy of new_states taken before the latest update()
        # finished_states: states of the groups for which test() held
        self.new_states = state_factory(self.is_state_type_bytea8)
        self.old_states = state_factory(self.is_state_type_bytea8)
        self.finished_states = state_factory(self.is_state_type_bytea8)

        self.group_param = self._init_group_param()

    def _init_group_param(self):
        """Build the SQL fragments shared by update() and final().

        In the fragments, $1/$2 are the group keys and states and $3/$4 are
        the group keys and tuple counts of the prepared plan.

        @return GroupParam namedtuple with the fields groupby_str,
            using_str, select_rel_state, grouped_state_type, grp_key and
            select_n_tuples
        """
        _grp_key = ("array_to_string(ARRAY[{grouping_str}], ',')"
                    .format(grouping_str=self.kwargs['grouping_str']))
        _select_rel_state = ("SELECT "
                             "grp_key AS {col_grp_key},"
                             "state AS {col_grp_state} "
                             "FROM {schema_madlib}._gen_state($1, NULL, $2)"
                             .format(**self.kwargs))
        _select_n_tuples = ("SELECT "
                            "unnest($3) AS {col_grp_key}, "
                            "unnest($4) AS {col_n_tuples}"
                            .format(**self.kwargs))
        _using_str = "ON TRUE"
        _grouped_state_type = "float8[]"
        _groupby_str = ""

        if not self.is_group_null:
            _groupby_str = "GROUP BY {grouping_col}, {col_grp_key}".format(
                **self.kwargs)
            _using_str = "USING ({col_grp_key})".format(**self.kwargs)
            _grp_key = self.kwargs['col_grp_key']

        if self.is_state_type_bytea8:
            _select_rel_state = ("SELECT "
                                 "unnest($1) AS {col_grp_key}, "
                                 "unnest($2) AS {col_grp_state}"
                                 "".format(**self.kwargs))
            _grouped_state_type = "{schema_madlib}.bytea8[]".format(
                **self.kwargs)

        GroupParam = namedtuple('GroupParam',
                                'groupby_str, using_str,'
                                'select_rel_state,grouped_state_type,'
                                'grp_key, select_n_tuples')
        return GroupParam(groupby_str=_groupby_str,
                          using_str=_using_str,
                          select_rel_state=_select_rel_state,
                          grp_key=_grp_key,
                          select_n_tuples=_select_n_tuples,
                          grouped_state_type=_grouped_state_type)

    def _grouped_state_sql_type(self):
        """Return the SQL array type used to pass all states to a plan."""
        if self.is_state_type_bytea8:
            return "{schema_madlib}.bytea8[]".format(**self.kwargs)
        return "float8[]"

    def _state_rows_sql(self, key_param, state_param, state_alias):
        """Return the SELECT fragment yielding one (group key, state) row
        per group.

        @param key_param Index of the plan parameter holding the group keys
        @param state_param Index of the plan parameter holding the states
        @param state_alias Column alias given to the state
        """
        if self.is_state_type_bytea8:
            template = ("unnest(${key_param}) AS {col_grp_key}, "
                        "unnest(${state_param}) AS {state_alias}")
        else:
            template = ("grp_key AS {col_grp_key}, "
                        "state AS {state_alias} "
                        "FROM {schema_madlib}._gen_state("
                        "${key_param}, NULL, ${state_param})")
        # Only the names needed are passed (not **self.kwargs) so that the
        # helper's own placeholders can never clash with a caller-supplied
        # key.
        return template.format(key_param=key_param,
                               state_param=state_param,
                               state_alias=state_alias,
                               col_grp_key=self.kwargs['col_grp_key'],
                               schema_madlib=self.kwargs['schema_madlib'])

    def __enter__(self):
        """Create the state table and register every group.

        (Re)creates the temporary table {rel_state} with one row per group,
        raises through plpy.error if a grouping column contains NULL, then
        stores every group key with state None in new_states and records
        the tuple count of each group.

        @return self
        """
        verbosity_level = "info" if self.verbose else "error"
        with MinWarning(verbosity_level):
            ############################
            # create state table
            # currently assuming that groups is passed as a valid array
            group_col = ("NULL::integer as {col_grp_null}"
                         if self.is_group_null
                         else "{grouping_col}").format(**self.kwargs)
            groupby_str = ("{col_grp_null}" if self.is_group_null
                           else "{grouping_col}").format(**self.kwargs)
            plpy.execute("""
                DROP TABLE IF EXISTS {rel_state};
                CREATE TEMPORARY TABLE {rel_state} AS (
                    SELECT
                        array_to_string(ARRAY[{grouping_str}], ',')
                            AS {col_grp_key},
                        0::integer AS {col_grp_iteration},
                        NULL::{state_type} AS {col_grp_state},
                        count(*) AS {col_n_tuples},
                        {group_col}
                    FROM {rel_source}
                    WHERE ({col_dep_var}) IS NOT NULL
                        AND NOT {schema_madlib}.array_contains_null({col_ind_var})
                    GROUP BY {groupby_str}
                );
                """.format(group_col=group_col,
                           groupby_str=groupby_str,
                           **self.kwargs))

            ############################
            # checking null in group values
            # We cannot allow NULL due to array_to_string cannot handle it well.
            if not self.is_group_null:
                null_test = (" OR ".join(
                    [g.strip() + " is NULL"
                     for g in self.kwargs['grouping_col'].split(",")]))
                null_count = plpy.execute("""
                    SELECT count(*) FROM {rel_state}
                    WHERE {null_test}
                    """.format(null_test=null_test,
                               rel_state=self.kwargs['rel_state']))[0]['count']
                if null_count != 0:
                    plpy.error("Grouping error: at least one of the grouping "
                               "columns contains NULL values! Please filter "
                               "out those NULL values.")

            ############################
            # initialize states
            rel_state_str = self.kwargs['rel_state']
            col_grp_key = self.kwargs['col_grp_key']
            col_n_tuples = self.kwargs['col_n_tuples']
            ret_states = plpy.execute("SELECT * FROM " + rel_state_str)
            # an empty state column name registers the keys with state None
            self.new_states.initialize(col_grp_key=col_grp_key,
                                       col_grp_state='',
                                       ret_states=ret_states)
            for s in ret_states:
                self.grp_to_n_tuples[s[col_grp_key]] = long(s[col_n_tuples])
            self.in_with = True
        return self

    def are_last_state_value_zero(self):
        """Return True if the last state element of every in-progress group
        is zero; see BaseState.are_last_state_value_zero()."""
        return self.new_states.are_last_state_value_zero()

    def info(self):
        """ Logging intermediate state information

        Does nothing unless verbose is set. Emits one notice per in-progress
        group with its loss and gradient norm; the message also reads
        'stepsize' from self.kwargs.
        """
        if not self.verbose:
            return
        group_param = self.group_param
        schema_madlib = self.kwargs['schema_madlib']
        res = self.new_states.interpret(schema_madlib,
                                        group_param.grouped_state_type)
        for grp, t in res.iteritems():
            loss, normg = t['loss'], t['norm_of_gradient']
            iteration = self.iteration
            output_str = "DEBUG: \
                    grp = {grp:10s}, \
                    iter = {iteration:5d}, \
                    loss = {loss:.5e}, \
                    |gradient| = {normg:.5e}, \
                    stepsize = {stepsize:.5e}"
            plpy.notice(output_str.format(
                grp=grp, iteration=iteration, loss=loss, normg=normg,
                **self.kwargs))

    def final(self):
        """ Store the final converged state to a table for output

        Inserts one row per finished group into {rel_state}, tagged with the
        current iteration. If any group failed, the same statement is run a
        second time with the failed keys and an empty list of states.
        """
        group_param = self.group_param
        insert_sql = """
            INSERT INTO {rel_state}
            SELECT
                {col_grp_key},
                {iteration}::int,
                {col_grp_state},
                {col_n_tuples}::bigint,
                {grouping_col}
            FROM (
                SELECT {grouping_col}, {col_grp_key}
                FROM {rel_state}
            ) AS _src
            JOIN ( {select_rel_state} ) AS _rel_state
            USING ({col_grp_key})
            JOIN ( {select_n_tuples} ) AS _rel_n_tuples
            USING ({col_grp_key})
            """.format(
            iteration=self.iteration,
            select_rel_state=group_param.select_rel_state,
            select_n_tuples=group_param.select_n_tuples,
            **self.kwargs)
        insert_plan = plpy.prepare(insert_sql,
                                   ["text[]", group_param.grouped_state_type,
                                    "text[]", "bigint[]"])
        plpy.execute(insert_plan, [self.finished_states.keys,
                                   self.finished_states.values,
                                   self.grp_to_n_tuples.keys(),
                                   self.grp_to_n_tuples.values()])
        if self.failed_grp_keys:
            plpy.execute(insert_plan, [self.failed_grp_keys,
                                       [],
                                       self.grp_to_n_tuples.keys(),
                                       self.grp_to_n_tuples.values()])

    def __exit__(self, type, value, tb):
        # Returns None, so an exception raised inside the with-block
        # propagates.
        self.in_with = False

    def get_state_size(self):
        """
        Return the size of the state: len() of the current state in the
        first row returned for the in-progress groups.

        Greenplum does not have the array_length() method yet, hence this
        work-around.
        """
        unnest_str_current = self._state_rows_sql(3, 4, "_state_current")
        grouped_state_type = self._grouped_state_sql_type()
        # The plan takes the previous states as $1/$2 as well, although only
        # $3/$4 are referenced, mirroring the parameter layout of test().
        eval_plan = plpy.prepare("""
            SELECT _state_current
            FROM (
                SELECT {unnest_str_current}
            ) subq1
            """.format(unnest_str_current=unnest_str_current,
                       **self.kwargs),
            ["text[]", grouped_state_type] * 2)

        state_size = len(plpy.execute(eval_plan,
                                      [self.old_states.keys,
                                       self.old_states.values,
                                       self.new_states.keys,
                                       self.new_states.values]
                                      )[0]['_state_current'])
        return state_size

    def get_param_value_per_group(self, select_col):
        """
        Return desired values from the current state.

        Example: to return the value in the last index of the state, the
        'select_col' can be set to
        '_state_current[array_length(_state_current,1)] AS last'

        This is especially useful when there is grouping and we want to
        find some specific state value (such as loss) for each group.

        @param select_col SQL select-list expression; it may refer to the
            column _state_current
        @return Query result with one row per in-progress group, holding the
            select_col column(s) and the group-key column
        """
        # Check convex/mlp_igd.py_in for an example of how this is useful.
        unnest_str_current = self._state_rows_sql(3, 4, "_state_current")
        grouped_state_type = self._grouped_state_sql_type()
        eval_plan = plpy.prepare("""
            SELECT
                {select_col},
                {col_grp_key}
            FROM (
                SELECT {unnest_str_current}
            ) subq1
            """.format(select_col=select_col,
                       unnest_str_current=unnest_str_current,
                       **self.kwargs),
            ["text[]", grouped_state_type] * 2)

        res = plpy.execute(eval_plan,
                           [self.old_states.keys,
                            self.old_states.values,
                            self.new_states.keys,
                            self.new_states.values])
        return res

    def test(self, condition):
        """
        Return True if the given expression is TRUE for all in-progress
        groups. Also returns True if there are no groups currently
        in-progress.

        Evaluate the given expression for all in-progress groups. Move
        groups with expression as True from self.new_states to
        self.finished_states.

        @param condition SQL boolean expression. It is first formatted with
            iteration=<current iteration> and self.kwargs. The following
            column names are defined and can be used in the condition:
            - _state_previous - State of the group before the latest update
            - _state_current - State of the group after the latest update

        @return True if no group is left in self.new_states afterwards,
            otherwise False
        """
        if len(self.new_states) == 0:
            # self.new_states can become empty if the last of the groups failed
            # in the previous update
            # NOTE(review): nothing in this class removes a failed group
            # from new_states -- confirm where that is expected to happen.
            return True

        unnest_str_previous = self._state_rows_sql(1, 2, "_state_previous")
        unnest_str_current = self._state_rows_sql(3, 4, "_state_current")
        grouped_state_type = self._grouped_state_sql_type()

        condition = condition.format(iteration=self.iteration, **self.kwargs)
        eval_plan = plpy.prepare("""
            SELECT
                CAST(({condition}) AS BOOLEAN) AS _expression,
                {col_grp_key}
            FROM
            (
                (
                    SELECT {unnest_str_previous}
                ) sub1
                JOIN
                (
                    SELECT {unnest_str_current}
                ) sub2
                USING ({col_grp_key})
            ) subq1
            """.format(condition=condition,
                       unnest_str_current=unnest_str_current,
                       unnest_str_previous=unnest_str_previous,
                       **self.kwargs),
            ["text[]", grouped_state_type] * 2)

        res = plpy.execute(eval_plan,
                           [self.old_states.keys,
                            self.old_states.values,
                            self.new_states.keys,
                            self.new_states.values])

        col_grp_key = self.kwargs['col_grp_key']
        finished_keys = [t[col_grp_key] for t in res if t['_expression']]
        self.finished_states.update_from_state(self.new_states, finished_keys)
        self.new_states.delete(finished_keys)
        return len(self.new_states) == 0

    def update(self, newState, **updateKwargs):
        """
        Update the inter-iteration state

        Increments self.iteration, evaluates newState per group over
        {rel_source} joined with the current states and the tuple counts,
        saves the current states to self.old_states and stores the result
        in self.new_states. Keys of groups whose new state is NULL are
        appended to self.failed_grp_keys.

        @param newState SQL expression of (or returning) type
            state_type.kwargs.state_type. It is formatted with self.kwargs
            before use.
        @param updateKwargs Accepted for interface compatibility; not used.

        NOTE(review): newState is formatted with self.kwargs only, so a
        literal {iteration} placeholder in it needs an 'iteration' entry in
        the arguments -- confirm against callers.
        """
        newState = newState.format(**self.kwargs)
        self.iteration = self.iteration + 1

        group_param = self.group_param
        run_sql = """
            SELECT
                {_grp_key} AS {col_grp_key},
                {grouping_col},
                {iteration} AS {col_grp_iteration},
                ({newState}) AS {col_grp_state}
            FROM (
                SELECT
                    *,
                    array_to_string(ARRAY[{grouping_str}], ',') AS {col_grp_key}
                FROM {rel_source}
            ) AS {as_rel_source}
            JOIN ( {select_rel_state} ) AS {rel_state}
            {using_str}
            JOIN ( {select_n_tuples} ) AS _rel_n_tuples
            {using_str}
            {groupby_str}
            """.format(
            newState=newState,
            iteration=self.iteration,
            using_str=group_param.using_str,
            groupby_str=group_param.groupby_str,
            _grp_key=group_param.grp_key,
            select_rel_state=group_param.select_rel_state,
            select_n_tuples=group_param.select_n_tuples,
            **self.kwargs)
        # NOTE(review): the tuple counts are passed as integer[] here but as
        # bigint[] in final(), while __enter__ stores them as long. Left
        # unchanged because newState expressions may depend on the integer
        # type of {col_n_tuples} -- confirm before unifying.
        update_plan = plpy.prepare(run_sql,
                                   ["text[]", group_param.grouped_state_type,
                                    "text[]", "integer[]"])
        res_tuples = plpy.execute(update_plan, [self.new_states.keys,
                                                self.new_states.values,
                                                self.grp_to_n_tuples.keys(),
                                                self.grp_to_n_tuples.values()])

        col_grp_state = self.kwargs['col_grp_state']
        col_grp_key = self.kwargs['col_grp_key']
        # keep the pre-update states so that test() can compare them
        self.old_states.sync_from(self.new_states)
        self.failed_grp_keys.extend(self.new_states.update(
            col_grp_key, col_grp_state, res_tuples))