"""
@file correlation.py_in

@brief Cross-correlation function for multiple columns of a relation

@namespace correlation
"""
from time import time

import plpy
from utilities.control import MinWarning
from utilities.utilities import unique_string, add_postfix
from utilities.validate_args import get_cols_and_types
from utilities.validate_args import input_tbl_valid, output_tbl_valid, cols_in_tbl_valid
from utilities.utilities import py_list_to_sql_string


def correlation(schema_madlib, source_table, output_table,
                target_cols, get_cov=False, verbose=False, **kwargs):
    """
    Populates an output table with the coefficients of correlation (or the
    covariances) between the columns in a source table.

    Args:
        @param schema_madlib  MADlib schema namespace
        @param source_table   Name of input table
        @param output_table   Name of output table
        @param target_cols    Comma-separated names of the columns targeted
                              for correlation; empty, None or '*' selects
                              every numeric column of the source table
        @param get_cov        If False compute the correlation matrix else
                              compute the covariance matrix
        @param verbose        Flag to determine whether to output debug info

    Returns:
        Multi-line summary string: output table name, ignored columns,
        columns that were processed and the run time in seconds.

    Raises:
        plpy.error if fewer than two numeric columns are selected.
    """
    with MinWarning("info" if verbose else "error"):
        _validate_corr_arg(source_table, output_table)
        numeric_cols, nonnumeric_cols = _get_numeric_columns(source_table)
        requested_cols = _analyze_target_cols(source_table, target_cols)

        if requested_cols:
            # Sets give constant-time membership tests while the lists keep
            # the caller-specified ordering of the output columns.
            numeric_lookup = set(numeric_cols)
            nonnumeric_lookup = set(nonnumeric_cols)
            already_selected = set()
            _existing_target_cols = []
            _nonnumeric_target_cols = []
            _nonexisting_target_cols = []
            for col in requested_cols:
                if col in numeric_lookup:
                    # Output column names must be unique, so a numeric
                    # column requested twice is only kept once.
                    if col not in already_selected:
                        already_selected.add(col)
                        _existing_target_cols.append(col)
                elif col in nonnumeric_lookup:
                    _nonnumeric_target_cols.append(col)
                else:
                    _nonexisting_target_cols.append(col)
        else:
            # No explicit target list: every numeric column is targeted.
            _existing_target_cols = list(numeric_cols)
            _nonnumeric_target_cols = list(nonnumeric_cols)
            _nonexisting_target_cols = []

        if not _existing_target_cols:
            plpy.error("Correlation error: No numeric column found in the target list.")
        if len(_existing_target_cols) == 1:
            plpy.error("Correlation error: Only one numeric column found in the target list.")

        # _populate_output_table returns (table name, column count, seconds);
        # only the elapsed time belongs in the "run time" line of the summary.
        _, _, run_time = _populate_output_table(
            schema_madlib, source_table, output_table,
            _existing_target_cols, get_cov, verbose)

        # ---- Output message ----
        output_text_list = ["Summary for 'correlation' function"]
        output_text_list.append("Output table = " + str(output_table))
        if _nonnumeric_target_cols:
            output_text_list.append(
                "Non-numeric columns ignored: {0}".format(
                    str(_nonnumeric_target_cols)))
        if _nonexisting_target_cols:
            output_text_list.append(
                "Columns that don't exist in '{0}' ignored: {1}".format(
                    source_table, str(_nonexisting_target_cols)))
        output_text_list.append(
            "Producing correlation for columns: {0}".format(
                str(_existing_target_cols)))
        output_text_list.append("Total run time = " + str(run_time))
        # ---- Output message ----
        return '\n'.join(output_text_list)
# ------------------------------------------------------------------------------


def _validate_corr_arg(source_table, output_table):
    """
    Validates the table arguments; the validation helpers are expected to
    raise an error when an argument is invalid.

    Both the output table and its companion '<output_table>_summary' table
    are checked, since both are created by the computation.

    Args:
        @param source_table  Name of input table (string)
        @param output_table  Name of output table (string)

    Returns:
        None
    """
    input_tbl_valid(source_table, "Correlation")
    output_tbl_valid(output_table, "Correlation")
    output_tbl_valid(add_postfix(output_table, "_summary"), "Correlation")
# ------------------------------------------------------------------------------


def _get_numeric_columns(source_table):
    """
    Splits the columns of a relation into numeric and non-numeric columns.

    Args:
        @param source_table  Name of the relation to inspect

    Returns:
        Tuple (list of numeric column names,
               list of non-numeric column names)
    """
    numeric_types = ('smallint', 'integer', 'bigint',
                     'real', 'numeric', 'double precision')
    numeric_col_names = []
    nonnum_col_names = []
    for col_name, col_type in get_cols_and_types(source_table):
        if col_type in numeric_types:
            numeric_col_names.append(col_name)
        else:
            nonnum_col_names.append(col_name)
    return (numeric_col_names, nonnum_col_names)
# ------------------------------------------------------------------------------


def _analyze_target_cols(source_table, target_cols):
    """
    Analyzes target_cols string input and converts it to a list.

    Args:
        @param source_table  Name of the relation the columns belong to
        @param target_cols   Comma-separated column names (string) or None

    Returns:
        None when target_cols is empty, None or '*' (meaning "all columns"),
        otherwise the list of stripped column names.
    """
    if not target_cols or target_cols.strip() in ('', '*'):
        target_cols = None
    else:
        target_cols = [i.strip() for i in target_cols.split(',')]
        # Only an explicit column list can be validated against the table.
        cols_in_tbl_valid(source_table, target_cols, "Correlation")
    return target_cols
# ------------------------------------------------------------------------------


def _populate_output_table(schema_madlib, source_table, output_table,
                           col_names, get_cov=False, verbose=False):
    """
    Creates the output relation and its '<output_table>_summary' companion
    and populates them with the correlation coefficients (or covariances)
    of the given columns.

    Args:
        @param schema_madlib  Schema of MADlib
        @param source_table   Name of source table
        @param output_table   Name of output table
        @param col_names      Name of all columns to place in output table
        @param get_cov        If False compute the correlation matrix else
                              compute the covariance matrix
        @param verbose        Flag to determine whether to output debug info

    Returns:
        Tuple (output table name, number of columns, time for computation)
    """
    with MinWarning("info" if verbose else "error"):
        start = time()
        col_names_as_text_array = py_list_to_sql_string(col_names, "varchar")
        temp_table = unique_string()
        if get_cov:
            function_name = "Covariance"
            agg_str = """
                (CASE WHEN count(*) > 0
                      THEN {0}.array_scalar_mult({0}.covariance_agg(x, mean),
                                                 1.0 / count(*)::double precision)
                      ELSE NULL
                END) """.format(schema_madlib)
        else:
            function_name = "Correlation"
            agg_str = "{0}.correlation_agg(x, mean)".format(schema_madlib)

        # NULL entries of a column are replaced by that column's average.
        avg_names = [add_postfix(col, "_avg") for col in col_names]
        cols = ','.join("coalesce({0}, {1})".format(col, avg)
                        for col, avg in zip(col_names, avg_names))
        avgs = ','.join("avg({0}) AS {1}".format(col, avg)
                        for col, avg in zip(col_names, avg_names))
        avg_array = ','.join(str(avg) for avg in avg_names)

        # actual computation
        sql1 = """
            CREATE TEMP TABLE {temp_table} AS
            SELECT
                count(*) AS tot_cnt,
                mean,
                {agg_str} as cor_mat
            FROM
            (
                SELECT ARRAY[ {cols} ] AS x,
                       ARRAY [ {avg_array} ] AS mean
                FROM {source_table},
                (
                    SELECT {avgs}
                    FROM {source_table}
                )sub1
            ) sub2
            GROUP BY mean
            """.format(temp_table=temp_table,
                       agg_str=agg_str,
                       cols=cols,
                       avg_array=avg_array,
                       source_table=source_table,
                       avgs=avgs)
        plpy.execute(sql1)

        # create summary table
        summary_table = add_postfix(output_table, "_summary")
        q_summary = """
            CREATE TABLE {summary_table} AS
            SELECT
                '{function_name}'::varchar AS method,
                '{source_table}'::varchar AS source,
                '{output_table}'::varchar AS output_table,
                {col_names_as_text_array} AS column_names,
                mean AS mean_vector,
                tot_cnt AS total_rows_processed
            FROM {temp_table}
            """.format(summary_table=summary_table,
                       function_name=function_name,
                       source_table=source_table,
                       output_table=output_table,
                       col_names_as_text_array=col_names_as_text_array,
                       temp_table=temp_table)
        plpy.execute(q_summary)

        # create output table
        # Column definitions for the deconstructed matrix; a line break is
        # inserted before every tenth definition to keep the SQL readable.
        variable_list_str = ','.join(
            ("\n " if idx % 10 == 0 else "") + str(col) + " float8"
            for idx, col in enumerate(col_names))
        plpy.execute("""
            CREATE TABLE {output_table} AS
            SELECT *
            FROM
            (
                SELECT
                    generate_series(1, {num_cols}) AS column_position,
                    unnest({col_names_as_text_array}) AS variable
            ) variable_subq
            JOIN
            (
                SELECT *
                FROM
                    {schema_madlib}.__deconstruct_lower_triangle(
                        (SELECT cor_mat FROM {temp_table})
                    )
                    AS deconstructed(column_position integer, {variable_list_str})
            ) matrix_subq
            USING (column_position)
            """.format(output_table=output_table,
                       num_cols=len(col_names),
                       col_names_as_text_array=col_names_as_text_array,
                       schema_madlib=schema_madlib,
                       temp_table=temp_table,
                       variable_list_str=variable_list_str))

        # clean up and return
        plpy.execute("DROP TABLE {temp_table}".format(temp_table=temp_table))
        end = time()
        return (output_table, len(col_names), end - start)
# ------------------------------------------------------------------------------


def correlation_help_message(schema_madlib, message, cov=False, **kwargs):
    """
    Given a help string, provide usage information.

    Args:
        @param schema_madlib  MADlib schema namespace
        @param message        Help topic: 'usage'/'help'/'?' for the
                              signature, 'example'/'examples' for sample
                              SQL; anything else (including None) returns a
                              short description of the function
        @param cov            If True describe covariance instead of
                              correlation

    Returns:
        Help text (string)
    """
    func = "covariance" if cov else "correlation"
    if message is not None and \
            message.lower() in ("usage", "help", "?"):
        return """
Usage:
-----------------------------------------------------------------------
SELECT {schema_madlib}.{func}
(
    source_table TEXT,   -- Source table name (Required)
    output_table TEXT,   -- Output table name (Required)
    target_cols TEXT,    -- Comma separated columns for which summary is desired
                         --   (Default: '*' - produces result for all columns)
    verbose BOOLEAN      -- Verbosity
)
-----------------------------------------------------------------------
Output will be a table with N+2 columns and N rows, where N is the number
of numeric columns in 'target_cols'.
The columns of the table are described as follows:

    - column_position : Position of the variable in the 'source_table'.
    - variable        : Provides the row-header for each variable
    - Rest of the table is the NxN {func} matrix for all numeric columns
    in 'source_table'.

The output table is arranged as a lower-triangular matrix with the upper
triangle set to NULL. To obtain the result from the output_table in this matrix
format ensure to order the elements using the 'column_position' column.
""".format(schema_madlib=schema_madlib, func=func)
    elif message is not None and message.lower() in ('example', 'examples'):
        return """
DROP TABLE IF EXISTS example_data;
CREATE TABLE example_data(
    id SERIAL,
    outlook text,
    temperature float8,
    humidity float8,
    windy text,
    class text) ;

INSERT INTO example_data(outlook, temperature, humidity, windy, class)
VALUES('sunny', 85, 85, 'false', E'Dont Play');
INSERT INTO example_data(outlook, temperature, humidity, windy, class)
VALUES('sunny', 80, 90, 'true', E'Dont Play');
INSERT INTO example_data(outlook, temperature, humidity, windy, class)
VALUES('overcast', 83, 78, 'false', 'Play');
INSERT INTO example_data(outlook, temperature, humidity, windy, class)
VALUES('rain', 70, 96, 'false', 'Play');
INSERT INTO example_data(outlook, temperature, humidity, windy, class)
VALUES('rain', 68, 80, 'false', 'Play');
INSERT INTO example_data(outlook, temperature, humidity, windy, class)
VALUES('rain', 65, 70, 'true', E'Dont Play');
INSERT INTO example_data(outlook, temperature, humidity, windy, class)
VALUES('overcast', 64, 65, 'true', 'Play');
INSERT INTO example_data(outlook, temperature, humidity, windy, class)
VALUES('sunny', 72, 95, 'false', E'Dont Play');
INSERT INTO example_data(outlook, temperature, humidity, windy, class)
VALUES('sunny', 69, 70, 'false', 'Play');
INSERT INTO example_data(outlook, temperature, humidity, windy, class)
VALUES('rain', 75, 80, 'false', 'Play');
INSERT INTO example_data(outlook, temperature, humidity, windy, class)
VALUES('sunny', 75, 70, 'true', 'Play');
INSERT INTO example_data(outlook, temperature, humidity, windy, class)
VALUES('overcast', 72, 90, 'true', 'Play');
INSERT INTO example_data(outlook, temperature, humidity, windy, class)
VALUES('overcast', 81, 75, 'false', 'Play');
INSERT INTO example_data(outlook, temperature, humidity, windy, class)
VALUES('rain', 71, 80, 'true', E'Dont Play');
INSERT INTO example_data(outlook, temperature, humidity, windy, class)
VALUES(NULL, 100, 100, 'true', NULL);
INSERT INTO example_data(outlook, temperature, humidity, windy, class)
VALUES(NULL, 110, 100, 'true', NULL);

SELECT madlib.{func}('example_data', 'example_data_output');
SELECT madlib.{func}('example_data', 'example_data_output', '*');
SELECT madlib.{func}('example_data', 'example_data_output', 'temperature, humidity');

-- To get the {func} matrix from output table:
SELECT * from example_data_output order by column_position;
""".format(func=func)
    else:
        if cov:
            return """
Covariance is a measure of how much two random variables change together. If the
greater values of one variable mainly correspond with the greater values of the
other variable, and the same holds for the smaller values, i.e., the variables
tend to show similar behavior, the covariance is positive. In the opposite case,
when the greater values of one variable mainly correspond to the smaller values
of the other, i.e., the variables tend to show opposite behavior, the covariance
is negative. The sign of the covariance therefore shows the tendency
-------
For an overview on usage, run:
    SELECT {schema_madlib}.covariance('usage');
-------
For examples:
    SELECT {schema_madlib}.covariance('example');
""".format(schema_madlib=schema_madlib)
        else:
            return """
A correlation function is the degree and direction of association of two
variables; how well can one random variable be predicted from the other. The
coefficient of correlation varies from -1 to 1: 1 implies perfect correlation,
0 means no correlation, and -1 means perfectly anti-correlated.
-------
For an overview on usage, run:
    SELECT {schema_madlib}.correlation('usage');
-------
For examples:
    SELECT {schema_madlib}.correlation('example');
""".format(schema_madlib=schema_madlib)
# ------------------------------------------------------------------------------