import plpy
from utilities.validate_args import table_exists
from utilities.validate_args import columns_exist_in_table
from utilities.utilities import _assert


def vcrf_top1_table(schema_madlib, segment_tbl, label_tbl, resulttbl_raw,
                    result_tbl, **kwargs):
    """
    @brief Build the final labeling table from the raw Viterbi output

    Creates result_tbl by joining every segment with the label id stored for
    its position in the raw result array, and attaches one score per document.

    @param schema_madlib Name of the MADlib schema
    @param segment_tbl Table with columns doc_id, start_pos, seg_text
    @param label_tbl Table with columns id, label
    @param resulttbl_raw Table with columns doc_id, label (integer array)
    @param result_tbl Name of the table to create. Its columns are doc_id,
           start_pos, seg_text, label, id, max_pos, prob
    @param kwargs Ignored
    """
    plpy.execute("SELECT {schema_madlib}.create_schema_pg_temp();".format(
        schema_madlib=schema_madlib))

    # Per document: index of its last segment (number of segments minus one).
    seg_count_tbl = "pg_temp._madlib_segcount_tbl"
    plpy.execute("""
        DROP TABLE IF EXISTS {seg_count_tbl};
        CREATE TABLE {seg_count_tbl} AS
            SELECT doc_id, count(*) -1 max_pos
            FROM {segment_tbl}
            GROUP BY doc_id
        m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (doc_id)')
        """.format(seg_count_tbl=seg_count_tbl, segment_tbl=segment_tbl))

    # The raw label array is indexed from 1, so the segment at start_pos reads
    # element start_pos+1. Element max_pos+2 (one past the last segment) is
    # divided by 1000000 and reported as prob.
    # NOTE(review): presumably that element is the path probability scaled to
    # an integer by the C function -- confirm against vcrf_top1_label.
    plpy.execute("""
        create table {result_tbl} AS
            select segs.doc_id, start_pos, seg_text, L.label, L.id as id,
                   count_tbl.max_pos,
                   (result.label[count_tbl.max_pos+2]::float/1000000) as prob
            from {segment_tbl} segs,
                 {label_tbl} L,
                 {resulttbl_raw} result,
                 {seg_count_tbl} count_tbl
            where result.label[segs.start_pos+1]=L.id
              and segs.doc_id=result.doc_id
              and segs.doc_id = count_tbl.doc_id
            order by doc_id, start_pos;
        """.format(result_tbl=result_tbl,
                   segment_tbl=segment_tbl,
                   label_tbl=label_tbl,
                   resulttbl_raw=resulttbl_raw,
                   seg_count_tbl=seg_count_tbl))


def vcrf_label(schema_madlib, segment_tbl, factor_mtbl, factor_rtbl,
               label_tbl, result_tbl, **kwargs):
    """
    @brief Label every document in segment_tbl and store the output in
           result_tbl

    Validates the arguments, stages the factor scores in temporary tables
    under pg_temp, calls vcrf_top1_label once per document and finally calls
    vcrf_top1_table to produce result_tbl.

    @param schema_madlib Name of the MADlib schema
    @param segment_tbl Table with columns doc_id, start_pos, seg_text
    @param factor_mtbl Table with column score
    @param factor_rtbl Table with columns seg_text, label, score
    @param label_tbl Table with columns id, label
    @param result_tbl Name of the output table, which must not exist yet
    @param kwargs Ignored
    """
    _validate_args(segment_tbl, factor_mtbl, factor_rtbl, label_tbl,
                   result_tbl)

    # Silence notices while the helper tables are (re)created. The previous
    # level is restored at the end of a successful run only.
    # NOTE(review): on failure we rely on the transaction rollback to revert
    # the SET -- confirm this holds on every supported back end.
    orig_msg_level = plpy.execute(
        "SELECT setting AS setting FROM pg_settings "
        "WHERE name = 'client_min_messages';")[0]['setting']
    plpy.execute("SET client_min_messages TO warning;")

    plpy.execute("SELECT {schema_madlib}.create_schema_pg_temp();".format(
        schema_madlib=schema_madlib))

    m_factors = "pg_temp._madlib_m_factors"
    r_factors = "pg_temp._madlib_r_factors"
    segment_tbl_digits = "pg_temp._madlib_segment_tbl_digits"
    resulttbl_raw = "pg_temp._madlib_resulttbl_raw"

    plpy.execute("DROP TABLE IF EXISTS {tables};".format(
        tables=",".join([m_factors, r_factors, segment_tbl_digits,
                         resulttbl_raw])))

    plpy.execute("""
        CREATE TABLE {resulttbl_raw}(doc_id integer, label integer[])
        m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (doc_id)')
        """.format(resulttbl_raw=resulttbl_raw))

    # replace digits with "DIGIT" keyword
    #
    # The predicate matches integers (optionally grouped by commas) and
    # decimal numbers. It is defined once so that the complementary filters
    # of the CREATE and the INSERT below cannot drift apart. It is passed to
    # format() as a value, so its regex braces need no escaping.
    is_number = ("seg_text ~ E'^[-+]?([0-9]{1,3}[,]?)*[0-9]{1,3}$' OR "
                 "seg_text ~ E'^[-+]?[0-9]*[.][0-9]+$'")
    plpy.execute("""
        CREATE TABLE {segment_tbl_digits} AS
            SELECT start_pos, doc_id, seg_text
            FROM {segment_tbl}
            WHERE NOT ({is_number})
        m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (doc_id)')
        """.format(segment_tbl_digits=segment_tbl_digits,
                   segment_tbl=segment_tbl,
                   is_number=is_number))
    plpy.execute("""
        INSERT INTO {segment_tbl_digits}
            SELECT start_pos,doc_id,'DIGIT'
            FROM {segment_tbl}
            WHERE {is_number};
        """.format(segment_tbl_digits=segment_tbl_digits,
                   segment_tbl=segment_tbl,
                   is_number=is_number))

    # The SQL comment must stay on a line of its own, otherwise it would
    # swallow the statement that follows it.
    plpy.execute("""
        -- for each sentence, store array representation of r_factors
        m4_ifdef(`__HAS_ORDERED_AGGREGATES__', `
            select doc_id, array_agg(score order by start_pos, label) as score
        ', `
            select doc_id,
                   array(
                        select score
                        from {factor_rtbl} factors,
                             {segment_tbl_digits} seg
                        where factors.seg_text = seg.seg_text
                          and doc_id = ss.doc_id
                        order by start_pos, label
                   ) as score
        ')
        into {r_factors}
        from (select doc_id, start_pos, label, score
              from {factor_rtbl} factors,
                   {segment_tbl_digits} seg
              where factors.seg_text=seg.seg_text) as ss
        group by doc_id
        order by doc_id
        """.format(factor_rtbl=factor_rtbl,
                   segment_tbl_digits=segment_tbl_digits,
                   r_factors=r_factors))
    plpy.execute("analyze " + r_factors + ";")

    plpy.execute("""
        -- array representation of m_factor
        select score
        into {m_factors}
        from (select score
              from {factor_mtbl} factors) as ss
        """.format(m_factors=m_factors, factor_mtbl=factor_mtbl))

    nlabel = plpy.execute(
        'SELECT COUNT(*) AS total FROM ' + label_tbl)[0]['total']

    # One call of the Viterbi function per row of r_factors, each combined
    # with every row of m_factors (plain cross join).
    plpy.execute("""
        INSERT INTO {resulttbl_raw}
            SELECT doc_id,
                   {schema_madlib}.vcrf_top1_label(mfactors.score,
                                                   rfactors.score,
                                                   {nlabel})
            FROM {m_factors} mfactors,
                 {r_factors} rfactors
        """.format(schema_madlib=schema_madlib,
                   resulttbl_raw=resulttbl_raw,
                   m_factors=m_factors,
                   r_factors=r_factors,
                   nlabel=str(nlabel)))

    plpy.execute(
        "SELECT * FROM {schema_madlib}.vcrf_top1_table('{segment_tbl}', "
        "'{label_tbl}', '{resulttbl_raw}', '{result_tbl}')".format(
            schema_madlib=schema_madlib,
            segment_tbl=segment_tbl,
            label_tbl=label_tbl,
            resulttbl_raw=resulttbl_raw,
            result_tbl=result_tbl))

    plpy.execute("SET client_min_messages TO " + str(orig_msg_level) + ";")


def _validate_columns(cols, table_name, err_msg_tbl):
    """
    @brief Validate if cols exists in the table

    @param cols List of column names that must be present
    @param table_name Name of the table to inspect
    @param err_msg_tbl Human readable table description used in the error
           message
    """
    _assert(columns_exist_in_table(table_name, cols),
            "CRF error: Missing required columns from %s table: %s"
            % (err_msg_tbl, ', '.join(cols)))


def _validate_args(segment_tbl, factor_mtbl, factor_rtbl, label_tbl,
                   result_tbl):
    """
    @brief Validate the arguments of vcrf_label

    Checks that all four input tables exist and carry the required columns,
    that result_tbl is a usable name and that no table of that name exists.
    Failures are reported through _assert.
    """
    # Check existence of input tables.
    input_tables = (
        (segment_tbl, "Segment"),
        (factor_mtbl, "M Factor"),
        (factor_rtbl, "R Factor"),
        (label_tbl, "Label"),
    )
    for tbl, description in input_tables:
        _assert(table_exists(tbl),
                "CRF error: %s table does not exist!" % description)

    # validate required column names for existence
    _validate_columns(['doc_id', 'start_pos', 'seg_text'], segment_tbl,
                      "segment")
    _validate_columns(['seg_text', 'label', 'score'], factor_rtbl, "R factor")
    _validate_columns(['score'], factor_mtbl, "M factor")
    _validate_columns(['id', 'label'], label_tbl, "label")

    _assert(result_tbl is not None and
            result_tbl.lower().strip() not in ('null', ''),
            "CRF error: Invalid result table name")
    _assert(not table_exists(result_tbl, only_first_schema=True),
            "CRF error: Result table name already exist!"
            " Please provide a different table name.")