# coding=utf-8
"""
@file crf.py_in
@brief Conditional Random Field: Driver functions
@namespace crf
Conditional Random Field: Driver functions
"""
import plpy
from utilities.validate_args import table_exists
from utilities.validate_args import columns_exist_in_table
from utilities.utilities import _assert
def __runIterativeAlg(stateType, initialState, source, updateExpr,
                      terminateExpr, maxNumIterations, cyclesPerIteration = 1):
    """
    Driver for an iterative algorithm

    The state between iterations is stored in the temporary table
    _madlib_iterative_alg, one row per iteration, in a column of SQL type
    stateType. Row 0 holds initialState. Every pass of the loop
    increments the iteration counter and runs an INSERT ... SELECT that
    evaluates updateExpr against the previous row and the source relation.

    After each insert the loop stops when the first of these holds:
    - the state just inserted IS NULL;
    - the counter exceeds cyclesPerIteration and has reached
      cyclesPerIteration * maxNumIterations;
    - the counter exceeds cyclesPerIteration and terminateExpr
      evaluates to true.

    Both expressions are spliced into their statement with str.format and the
    resulting text is formatted a second time on every iteration, so any
    literal brace that must reach the database has to be doubled by the caller.

    @param stateType SQL type of the state between iterations
    @param initialState SQL expression for the initial value of the state
    @param source The source relation
    @param updateExpr SQL expression that returns the new state of type
        stateType. The replacement fields supplied are "{state}",
        "{iteration}", "{source}" and "{sourceAlias}". Source alias is the
        alias ("src") under which the source relation appears in FROM.
    @param terminateExpr SQL expression that returns whether the algorithm
        should terminate. The replacement fields supplied are "{oldState}",
        "{newState}", "{iteration}" and "{cyclesPerIteration}". The old state
        is the one stored cyclesPerIteration rows before the new one.
    @param maxNumIterations Maximum number of iterations. Algorithm will then
        terminate even when terminateExpr does not evaluate to \c true
    @param cyclesPerIteration Number of aggregate function calls per iteration.
    @return Value of the iteration counter when the loop stopped
    """
    # First formatting pass: embed the caller expressions. Doubled braces
    # collapse to single ones here and are filled in inside the loop.
    update_stmt_template = """
INSERT INTO _madlib_iterative_alg
SELECT
{{iteration}},
{updateExpr}
FROM
_madlib_iterative_alg AS st,
{{source}} AS src
WHERE
st._madlib_iteration = {{iteration}} - 1
""".format(updateExpr = updateExpr)

    terminate_query_template = """
SELECT
{terminateExpr} AS should_terminate
FROM
(
SELECT _madlib_state
FROM _madlib_iterative_alg
WHERE _madlib_iteration = {{iteration}} - {{cyclesPerIteration}}
) AS older,
(
SELECT _madlib_state
FROM _madlib_iterative_alg
WHERE _madlib_iteration = {{iteration}}
) AS newer
""".format(terminateExpr = terminateExpr)

    null_state_query_template = """
SELECT _madlib_state IS NULL AS should_terminate
FROM _madlib_iterative_alg
WHERE _madlib_iteration = {iteration}
"""

    # Remember the current message level so that it can be put back once the
    # state table has been (re)created quietly.
    settings_rows = plpy.execute("SELECT setting FROM pg_settings WHERE name='client_min_messages'")
    previous_msg_level = settings_rows[0]['setting']

    create_state_table_sql = """
SET client_min_messages = error;
DROP TABLE IF EXISTS _madlib_iterative_alg;
CREATE TEMPORARY TABLE _madlib_iterative_alg (
_madlib_iteration INTEGER
m4_ifdef(`__HAWQ__', `', ` PRIMARY KEY'),
_madlib_state {stateType}
)
m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (_madlib_iteration)');
SET client_min_messages = {oldMsgLevel};
""".format(stateType = stateType, oldMsgLevel = previous_msg_level)
    plpy.execute(create_state_table_sql)

    # Seed row for iteration 0.
    iteration = 0
    seed_sql = """
INSERT INTO _madlib_iterative_alg VALUES ({iteration}, {initialState})
""".format(iteration = iteration, initialState = initialState)
    plpy.execute(seed_sql)

    # Replacement fields that do not change from one iteration to the next.
    update_fields = dict(
        source = source,
        state = "(st._madlib_state)",
        sourceAlias = "src")
    terminate_fields = dict(
        cyclesPerIteration = cyclesPerIteration,
        oldState = "(older._madlib_state)",
        newState = "(newer._madlib_state)")

    while True:
        iteration += 1
        plpy.execute(update_stmt_template.format(
            iteration = iteration, **update_fields))

        null_rows = plpy.execute(null_state_query_template.format(
            iteration = iteration))
        if null_rows[0]['should_terminate']:
            break

        # No older state to compare against yet: keep iterating.
        if iteration <= cyclesPerIteration:
            continue

        # Iteration budget exhausted; the termination query is not run.
        if iteration >= cyclesPerIteration * maxNumIterations:
            break

        verdict_rows = plpy.execute(terminate_query_template.format(
            iteration = iteration, **terminate_fields))
        if verdict_rows[0]['should_terminate']:
            break

    # Note: We do not drop the temporary table
    return iteration
def compute_lincrf(schema_madlib, source, sparse_R, dense_M, sparse_M, featureSize, tagSize, maxNumIterations, **kwargs):
    """
    Compute conditional random field coefficients

    This method serves as an interface to L-BFGS optimization algorithms: it
    builds the SQL update and termination expressions around
    lincrf_lbfgs_step and internal_lincrf_lbfgs_converge and hands them to
    the iterative driver, starting from a NULL state of type FLOAT8[].

    @param schema_madlib Name of the MADlib schema, properly escaped/quoted
    @param source Name of relation containing the training data
    @param sparse_R SQL expression (typically a column name) for the sparse
        single state feature; it is cast to FLOAT8[]
    @param dense_M SQL expression (typically a column name) for the dense two
        state feature; it is cast to FLOAT8[]
    @param sparse_M SQL expression (typically a column name) for the sparse
        two state feature; it is cast to FLOAT8[]
    @param featureSize SQL expression (typically a column name) for the
        feature size; it is cast to FLOAT8
    @param tagSize The size of the tag set; it is cast to FLOAT8
    @param maxNumIterations Maximum number of iterations. A value below 1 is
        reported through plpy.error.
    @param kwargs We allow the caller to specify additional arguments (all of
        which will be ignored though). The purpose of this is to allow the
        caller to unpack a dictionary whose element set is a superset of
        the required arguments by this function.
    @return The iteration counter returned by the iterative driver. The
        resulting state itself stays in the driver state table.
    """
    if maxNumIterations < 1:
        plpy.error("Number of iterations must be positive")

    # The doubled braces around state and newState are left for the driver,
    # which fills them in on every iteration.
    step_expr = """
{schema_madlib}.lincrf_lbfgs_step(
({sparse_R})::FLOAT8[],
({dense_M})::FLOAT8[],
({sparse_M})::FLOAT8[],
({featureSize})::FLOAT8,
({tagSize})::FLOAT8,
{{state}}
)
""".format(
        schema_madlib = schema_madlib,
        sparse_R = sparse_R,
        dense_M = dense_M,
        sparse_M = sparse_M,
        featureSize = featureSize,
        tagSize = tagSize)

    converged_expr = """
{schema_madlib}.internal_lincrf_lbfgs_converge(
{{newState}}) = 0
""".format(
        schema_madlib = schema_madlib)

    return __runIterativeAlg(
        stateType = "FLOAT8[]",
        initialState = "NULL",
        source = source,
        updateExpr = step_expr,
        terminateExpr = converged_expr,
        maxNumIterations = maxNumIterations)
def lincrf_train(schema_madlib, train_feature_tbl, train_featureset_tbl,
                 label_tbl, crf_stats_tbl, crf_weights_tbl, max_iterations, **kwargs):
    """
    Train a conditional random field and store the results in two new tables

    Steps, in order:
    -# validate the arguments and obtain the tag set size;
    -# call compute_lincrf in the database on the feature table, using the
       columns sparse_r, dense_m, sparse_m and f_size;
    -# create crf_stats_tbl (coef, log_likelihood, num_iterations) from the
       state stored for the last iteration;
    -# create crf_weights_tbl (id, name, prev_label_id, label_id, weight) by
       combining the stats table with the featureset table;
    -# run ANALYZE on the weights table.

    @param schema_madlib Name of the MADlib schema, properly escaped/quoted
    @param train_feature_tbl Name of the table with the training features
    @param train_featureset_tbl Name of the table with the feature set
        (columns f_index, f_name, feature)
    @param label_tbl Name of the label table
    @param crf_stats_tbl Name of the stats table to create
    @param crf_weights_tbl Name of the weights table to create
    @param max_iterations Maximum number of iterations
    @param kwargs Additional arguments, all of which are ignored
    @return A fixed success message
    """
    tag_size = _validate_args(train_feature_tbl, train_featureset_tbl, label_tbl,
                              crf_stats_tbl, crf_weights_tbl, max_iterations)

    # Run the optimizer; the value selected is the final iteration number,
    # which identifies the row of the state table holding the result.
    train_sql = """
SELECT {schema_madlib}.compute_lincrf('{train_feature_tbl}',
'sparse_r', 'dense_m', 'sparse_m', 'f_size',
{tag_size}, {max_iterations}) iterations
""".format(schema_madlib = schema_madlib,
           train_feature_tbl = train_feature_tbl,
           tag_size = tag_size,
           max_iterations = max_iterations)
    train_rows = plpy.execute(train_sql)
    last_iteration = train_rows[0]['iterations']

    # Unpack the final state into the stats table.
    stats_sql = """CREATE TABLE {crf_stats_tbl} AS
SELECT (result).coef, (result).log_likelihood, (result).num_iterations
FROM
(SELECT {schema_madlib}.internal_lincrf_lbfgs_result(_madlib_state) AS result
FROM _madlib_iterative_alg
WHERE _madlib_iteration = {iteration}
) Q
m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (num_iterations)')
""".format(crf_stats_tbl = crf_stats_tbl,
           iteration = last_iteration,
           schema_madlib = schema_madlib)
    plpy.execute(stats_sql)

    # One weight per feature: coef is indexed from 1, f_index is shifted by 1
    # to address it.
    weights_sql = """CREATE TABLE {crf_weights_tbl} AS
SELECT f_index id, f_name as name,
feature[1] prev_label_id,
feature[2] label_id,
coef[f_index + 1] weight
FROM
{crf_stats_tbl},
{train_featureset_tbl}
m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (id)')
""".format(crf_weights_tbl = crf_weights_tbl,
           crf_stats_tbl = crf_stats_tbl,
           train_featureset_tbl = train_featureset_tbl,
           schema_madlib = schema_madlib)
    plpy.execute(weights_sql)

    # Enforce ANALYZE to gather proper table statistics required to generate optimized query plans
    analyze_sql = """ ANALYZE {crf_weights_tbl} """.format(crf_weights_tbl = crf_weights_tbl)
    plpy.execute(analyze_sql)

    return "CRF Train successful. Results stored in the specified CRF stats and weights table"
def _validate_columns(cols, table_name, err_msg_tbl):
    """
    @brief Assert that every column in cols exists in table_name

    @param cols List of column names that must be present
    @param table_name Name of the table to inspect
    @param err_msg_tbl Short label for the table, used only in the error text
    """
    all_present = columns_exist_in_table(table_name, cols)
    wanted = ', '.join(cols)
    failure_msg = "CRF error: Missing required columns from %s table: %s" % (err_msg_tbl, wanted)
    _assert(all_present, failure_msg)
def _validate_args(train_feature_tbl, train_featureset_tbl, label_tbl,
                   crf_stats_tbl, crf_weights_tbl, max_iterations):
    """
    @brief Validate the arguments of lincrf_train and compute the tag size

    Checks, in order: the three input tables exist; max_iterations is
    positive; the input tables have their required columns; the label table
    is not empty; both output table names are usable (not None, empty or
    "null"); neither output table exists yet.

    @param train_feature_tbl Name of the training feature table
    @param train_featureset_tbl Name of the training featureset table
    @param label_tbl Name of the label table
    @param crf_stats_tbl Name of the stats table that is about to be created
    @param crf_weights_tbl Name of the weights table that is about to be created
    @param max_iterations Maximum number of iterations
    @return Number of rows in the label table, used as the tag set size
    """
    # Input relations must be present.
    required_inputs = (
        (train_feature_tbl, "CRF error: Train feature table does not exist!"),
        (train_featureset_tbl, "CRF error: Train featureset table does not exist!"),
        (label_tbl, "CRF error: Label table does not exist!"),
    )
    for input_tbl, missing_msg in required_inputs:
        _assert(table_exists(input_tbl), missing_msg)

    _assert(max_iterations > 0,
            "CRF error: max iterations cannot be zero or negative")

    # Validate required columns
    feature_cols = ['doc_id', 'f_size', 'sparse_r', 'dense_m', 'sparse_m']
    _validate_columns(feature_cols, train_feature_tbl, "feature")
    featureset_cols = ['f_index', 'f_name', 'feature']
    _validate_columns(featureset_cols, train_featureset_tbl, "featureset")
    label_cols = ['id', 'label']
    _validate_columns(label_cols, label_tbl, "label")

    # The tag size is the number of rows in the label table.
    count_rows = plpy.execute(""" SELECT count(*) FROM {label_tbl}
""".format(label_tbl = label_tbl))
    tag_size = count_rows[0]['count']
    _assert(tag_size > 0,
            "CRF error: Label table is empty")

    # (table name, message for an unusable name, message for an existing table)
    output_tables = (
        (crf_stats_tbl,
         "CRF error: Invalid CRF stats table name",
         "CRF error: CRF stats table already exist!"
         " Please provide a different table name."),
        (crf_weights_tbl,
         "CRF error: Invalid CRF weights table name",
         "CRF error: CRF weights table already exist!"
         " Please provide a different table name."),
    )

    # Both names are checked before either existence lookup is attempted.
    for output_tbl, invalid_msg, _unused in output_tables:
        name_is_usable = (output_tbl is not None and
                          output_tbl.lower().strip() not in ('null', ''))
        _assert(name_is_usable, invalid_msg)

    for output_tbl, _unused, exists_msg in output_tables:
        _assert(not table_exists(output_tbl, only_first_schema=True),
                exists_msg)

    return tag_size