# coding=utf-8

"""
@file crf.py_in

@brief Conditional Random Field: Driver functions

@namespace crf

Conditional Random Field: Driver functions
"""

import plpy

from utilities.validate_args import table_exists
from utilities.validate_args import columns_exist_in_table
from utilities.utilities import _assert


def __runIterativeAlg(stateType, initialState, source, updateExpr,
                      terminateExpr, maxNumIterations, cyclesPerIteration=1):
    """
    Driver for an iterative algorithm

    A general driver function for most iterative algorithms: The state between
    iterations is kept in a variable of type stateType, which is initialized
    with initialState. During each iteration, the SQL statement updateSQL is
    executed in the database. Afterwards, the SQL query terminateSQL decides
    whether the algorithm terminates.

    The states of all iterations are stored in the temporary table
    _madlib_iterative_alg (columns _madlib_iteration and _madlib_state). The
    table is not dropped on return, so callers can read the final state from
    it.

    @param stateType SQL type of the state between iterations
    @param initialState The initial value of the SQL state variable
    @param source The source relation
    @param updateExpr SQL expression that returns the new state of type
        stateType. The expression may use the replacement fields
        "{state}", "{iteration}", and "{sourceAlias}". Source alias is an
        alias for the source relation source.
    @param terminateExpr SQL expression that returns whether the algorithm
        should terminate. The expression may use the replacement fields
        "{oldState}", "{newState}" and "{iteration}". It must return a
        BOOLEAN value.
    @param maxNumIterations Maximum number of iterations. Algorithm will then
        terminate even when terminateExpr does not evaluate to \c true
    @param cyclesPerIteration Number of aggregate function calls per
        iteration.
    @return Number of the last iteration that was executed, i.e. the value of
        _madlib_iteration under which the final state is stored
    """
    # The templates are formatted twice: once here to splice in the caller
    # supplied expressions, and once per iteration. Doubled braces survive the
    # first pass as single-brace fields for the second pass.
    updateSQL = """
        INSERT INTO _madlib_iterative_alg
        SELECT
            {{iteration}},
            {updateExpr}
        FROM
            _madlib_iterative_alg AS st,
            {{source}} AS src
        WHERE
            st._madlib_iteration = {{iteration}} - 1
        """.format(updateExpr=updateExpr)
    terminateSQL = """
        SELECT
            {terminateExpr} AS should_terminate
        FROM
        (
            SELECT _madlib_state
            FROM _madlib_iterative_alg
            WHERE _madlib_iteration = {{iteration}} - {{cyclesPerIteration}}
        ) AS older,
        (
            SELECT _madlib_state
            FROM _madlib_iterative_alg
            WHERE _madlib_iteration = {{iteration}}
        ) AS newer
        """.format(terminateExpr=terminateExpr)
    checkForNullStateSQL = """
        SELECT _madlib_state IS NULL AS should_terminate
        FROM _madlib_iterative_alg
        WHERE _madlib_iteration = {iteration}
        """

    # Remember the current message level so it can be restored after the
    # (potentially noisy) DROP/CREATE of the state table.
    oldMsgLevel = plpy.execute(
        "SELECT setting FROM pg_settings WHERE name='client_min_messages'"
    )[0]['setting']
    plpy.execute("""
        SET client_min_messages = error;
        DROP TABLE IF EXISTS _madlib_iterative_alg;
        CREATE TEMPORARY TABLE _madlib_iterative_alg (
            _madlib_iteration INTEGER m4_ifdef(`__HAWQ__', `', ` PRIMARY KEY'),
            _madlib_state {stateType}
        )
        m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (_madlib_iteration)');
        SET client_min_messages = {oldMsgLevel};
        """.format(stateType=stateType, oldMsgLevel=oldMsgLevel))

    # Iteration 0 holds the initial state.
    iteration = 0
    plpy.execute("""
        INSERT INTO _madlib_iterative_alg VALUES ({iteration}, {initialState})
        """.format(iteration=iteration, initialState=initialState))

    while True:
        iteration = iteration + 1
        plpy.execute(updateSQL.format(
            source=source,
            state="(st._madlib_state)",
            iteration=iteration,
            sourceAlias="src"))

        # A NULL state always ends the loop.
        stateIsNull = plpy.execute(checkForNullStateSQL.format(
            iteration=iteration))[0]['should_terminate']
        if stateIsNull:
            break

        # The iteration cap and the termination query are only consulted once
        # more than cyclesPerIteration update steps have been run.
        # NOTE(review): as a consequence, maxNumIterations == 1 still performs
        # cyclesPerIteration + 1 update steps. Confirm whether that is
        # intended before changing it.
        if iteration > cyclesPerIteration:
            if iteration >= cyclesPerIteration * maxNumIterations:
                break
            shouldTerminate = plpy.execute(terminateSQL.format(
                iteration=iteration,
                cyclesPerIteration=cyclesPerIteration,
                oldState="(older._madlib_state)",
                newState="(newer._madlib_state)"))[0]['should_terminate']
            if shouldTerminate:
                break

    # Note: We do not drop the temporary table
    return iteration


def compute_lincrf(schema_madlib, source, sparse_R, dense_M, sparse_M,
                   featureSize, tagSize, maxNumIterations, **kwargs):
    """
    Compute conditional random field coefficients

    This method serves as an interface to L-BFGS optimization algorithms.

    @param schema_madlib Name of the MADlib schema, properly escaped/quoted
    @param source Name of relation containing the training data
    @param sparse_R Name of the sparse single state feature column
        (of type DOUBLE PRECISION[])
    @param dense_M Name of the dense two state feature column
        (of type DOUBLE PRECISION[])
    @param sparse_M Name of the sparse two state feature column
        (of type DOUBLE PRECISION[])
    @param featureSize Name of feature size column in training data
        (of type DOUBLE PRECISION)
    @param tagSize The size of the tag set
    @param maxNumIterations Maximum number of iterations
    @param kwargs We allow the caller to specify additional arguments (all of
        which will be ignored though). The purpose of this is to allow the
        caller to unpack a dictionary whose element set is a superset of
        the required arguments by this function.
    @return Number of the last iteration executed by the iterative driver.
        The corresponding state is stored in the temporary table
        _madlib_iterative_alg under that iteration number.
    """
    if maxNumIterations < 1:
        plpy.error("Number of iterations must be positive")

    # Doubled braces are left as "{state}" / "{newState}" fields that the
    # iterative driver fills in on every iteration.
    updateExpr = """
        {schema_madlib}.lincrf_lbfgs_step(
            ({sparse_R})::FLOAT8[],
            ({dense_M})::FLOAT8[],
            ({sparse_M})::FLOAT8[],
            ({featureSize})::FLOAT8,
            ({tagSize})::FLOAT8,
            {{state}}
        )
        """.format(
        schema_madlib=schema_madlib,
        sparse_R=sparse_R,
        dense_M=dense_M,
        sparse_M=sparse_M,
        featureSize=featureSize,
        tagSize=tagSize)
    terminateExpr = """
        {schema_madlib}.internal_lincrf_lbfgs_converge(
            {{newState}}) = 0
        """.format(schema_madlib=schema_madlib)

    return __runIterativeAlg(
        stateType="FLOAT8[]",
        initialState="NULL",
        source=source,
        updateExpr=updateExpr,
        terminateExpr=terminateExpr,
        maxNumIterations=maxNumIterations)


def lincrf_train(schema_madlib, train_feature_tbl, train_featureset_tbl,
                 label_tbl, crf_stats_tbl, crf_weights_tbl, max_iterations,
                 **kwargs):
    """
    Train a linear-chain CRF and store the results in two new tables

    Validates the arguments, runs the SQL function compute_lincrf on the
    feature table (using its columns sparse_r, dense_m, sparse_m and f_size),
    and then materializes the final state into the stats and weights tables.

    @param schema_madlib Name of the MADlib schema
    @param train_feature_tbl Name of the training feature table
    @param train_featureset_tbl Name of the feature set table (columns
        f_index, f_name, feature)
    @param label_tbl Name of the label table; its row count is used as the
        tag size
    @param crf_stats_tbl Name of the stats table to create (columns coef,
        log_likelihood, num_iterations)
    @param crf_weights_tbl Name of the weights table to create (columns id,
        name, prev_label_id, label_id, weight)
    @param max_iterations Maximum number of iterations
    @param kwargs Additional arguments, all of which are ignored
    @return Message stating that training succeeded
    """
    tag_size = _validate_args(train_feature_tbl, train_featureset_tbl,
                              label_tbl, crf_stats_tbl, crf_weights_tbl,
                              max_iterations)

    # Runs the optimizer; the returned iteration number identifies the row of
    # _madlib_iterative_alg that holds the final state.
    rv = plpy.execute("""
        SELECT {schema_madlib}.compute_lincrf('{train_feature_tbl}',
               'sparse_r', 'dense_m', 'sparse_m', 'f_size',
               {tag_size}, {max_iterations}) iterations
        """.format(schema_madlib=schema_madlib,
                   train_feature_tbl=train_feature_tbl,
                   tag_size=tag_size,
                   max_iterations=max_iterations))
    iterations = rv[0]['iterations']

    plpy.execute("""CREATE TABLE {crf_stats_tbl} AS
        SELECT (result).coef, (result).log_likelihood, (result).num_iterations
        FROM (SELECT {schema_madlib}.internal_lincrf_lbfgs_result(_madlib_state) AS result
              FROM _madlib_iterative_alg
              WHERE _madlib_iteration = {iteration}
             ) Q
        m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (num_iterations)')
        """.format(crf_stats_tbl=crf_stats_tbl,
                   iteration=iterations,
                   schema_madlib=schema_madlib))

    # Pairs every feature with its coefficient. The "+ 1" suggests f_index is
    # zero-based while SQL arrays are one-based -- TODO confirm.
    plpy.execute("""CREATE TABLE {crf_weights_tbl} AS
        SELECT f_index id, f_name as name, feature[1] prev_label_id,
               feature[2] label_id, coef[f_index + 1] weight
        FROM {crf_stats_tbl}, {train_featureset_tbl}
        m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (id)')
        """.format(crf_weights_tbl=crf_weights_tbl,
                   crf_stats_tbl=crf_stats_tbl,
                   train_featureset_tbl=train_featureset_tbl))

    # Enforce ANALYZE to gather proper table statistics required to generate
    # optimized query plans
    plpy.execute("""
        ANALYZE {crf_weights_tbl}
        """.format(crf_weights_tbl=crf_weights_tbl))

    return "CRF Train successful. Results stored in the specified CRF stats and weights table"


def _validate_columns(cols, table_name, err_msg_tbl):
    """
    @brief Validate if cols exists in the table

    @param cols List of column names that must be present
    @param table_name Name of the table to inspect
    @param err_msg_tbl Short table description used in the error message
    """
    _assert(columns_exist_in_table(table_name, cols),
            "CRF error: Missing required columns from %s table: %s"
            % (err_msg_tbl, ', '.join(cols)))


def _validate_args(train_feature_tbl, train_featureset_tbl, label_tbl,
                   crf_stats_tbl, crf_weights_tbl, max_iterations):
    """
    @brief Validate the arguments of lincrf_train

    Checks that the three input tables exist and contain the required
    columns, that max_iterations is positive, that the label table is not
    empty, and that the two output table names are usable and not yet taken.

    @return Number of rows in the label table (used as the tag size)
    """
    _assert(table_exists(train_feature_tbl),
            "CRF error: Train feature table does not exist!")
    _assert(table_exists(train_featureset_tbl),
            "CRF error: Train featureset table does not exist!")
    _assert(table_exists(label_tbl),
            "CRF error: Label table does not exist!")
    _assert(max_iterations > 0,
            "CRF error: max iterations cannot be zero or negative")

    # Validate required columns
    _validate_columns(['doc_id', 'f_size', 'sparse_r', 'dense_m', 'sparse_m'],
                      train_feature_tbl, "feature")
    _validate_columns(['f_index', 'f_name', 'feature'],
                      train_featureset_tbl, "featureset")
    _validate_columns(['id', 'label'], label_tbl, "label")

    rv = plpy.execute("""
        SELECT count(*) FROM {label_tbl}
        """.format(label_tbl=label_tbl))
    tag_size = rv[0]['count']
    _assert(tag_size > 0, "CRF error: Label table is empty")

    # Output names must be real identifiers before they are looked up.
    _assert(crf_stats_tbl is not None and
            crf_stats_tbl.lower().strip() not in ('null', ''),
            "CRF error: Invalid CRF stats table name")
    _assert(crf_weights_tbl is not None and
            crf_weights_tbl.lower().strip() not in ('null', ''),
            "CRF error: Invalid CRF weights table name")

    _assert(not table_exists(crf_stats_tbl, only_first_schema=True),
            "CRF error: CRF stats table already exist!"
            " Please provide a different table name.")
    _assert(not table_exists(crf_weights_tbl, only_first_schema=True),
            "CRF error: CRF weights table already exist!"
            " Please provide a different table name.")
    return tag_size