# ----------------------------------------------------------------------
# Robust variance for Cox Proportional Hazards model
# ----------------------------------------------------------------------

import plpy
from utilities.utilities import unique_string
from utilities.utilities import _assert
from utilities.validate_args import table_exists
from utilities.validate_args import columns_exist_in_table


# ----------------------------------------------------------------------
def rb_help_message(schema_madlib, message=None, **kwargs):
    """ Given a help string, provide usage information

    Args:
        schema_madlib: String. Schema where MADlib is installed
            (not used when building the text).
        message: String. Which help message to display. "usage", "help"
            and "?" (case-insensitive) select the detailed usage text;
            anything else, including None, selects the short summary.

    Returns:
        String. A detailed help message
    """
    wants_usage = (message is not None and
                   message.lower() in ("usage", "help", "?"))

    if not wants_usage:
        return """
This function calculates the Huber-White robust statistics
for Cox Proportional Hazards Regression.

For more details on function usage:
SELECT madlib.robust_variance_coxph('usage');
    """

    return """
This function calculates the Robust statistics for Cox Proportional Hazards Regression.
-----------------------------------------------------------------------
                                USAGE
-----------------------------------------------------------------------
SELECT madlib.robust_variance_coxph (
    model_table     TEXT,   -- Model table name, which is generated FROM coxph_train
    output_table    TEXT,   -- Output table name
)
-----------------------------------------------------------------------
                                OUTPUT
-----------------------------------------------------------------------
The output table contains the following columns:
    - coef: Coefficients of regression
    - loglikelihood: Log-likelihood value
    - std_err: Standard errors
    - robust_se: Standard errors of the robust variance estimators
    - robust_z: z-stats of the robust variance estimators
    - robust_p: p-values of the robust variance estimators
    - hessian: Hessian matrix
    """


# ----------------------------------------------------------------------
def _validate_params(schema_madlib, model_table, output_table):
    """ Validate the input parameters for the robust variance computation

    Checks, in order: the model table and its '<model_table>_summary'
    companion exist, both carry the columns this module reads, the output
    table name is non-empty and not already taken, and the model table
    holds at least one row whose coef and hessian are both non-NULL.

    Args:
        @param schema_madlib - MADlib schema name
        @param model_table - A string, the model table name
        @param output_table - A string, the result table name

    Throws:
        "Robust Variance Cox error" if any argument is invalid
    """
    summary_table = None if model_table is None else model_table + '_summary'

    _assert(model_table is not None and table_exists(model_table),
            "Robust Variance Cox error: Model data table does not exist")

    _assert(model_table is not None and table_exists(summary_table),
            "Robust Variance Cox error: Model summary table does not exist")

    _assert(columns_exist_in_table(model_table,
                                   ['coef', 'loglikelihood', 'std_err',
                                    'z_stats', 'p_values', 'hessian']),
            "Robust Variance Cox error: Invalid model data table"
            " - some required columns missing")

    _assert(columns_exist_in_table(summary_table,
                                   ['source_table', 'dependent_varname',
                                    'independent_varname',
                                    'right_censoring_status', 'strata']),
            "Robust Variance Cox error: Invalid model summary table"
            " - some required columns missing")

    _assert(output_table is not None and output_table.strip() != '',
            "Robust Variance Cox error: Invalid output_table is given")

    _assert(not table_exists(output_table, only_first_schema=True),
            "Robust Variance Cox error: Output table {0}"
            " already exists!".format(str(output_table)))

    # Number of model rows that have both a coef and a hessian.
    not_null_coef = plpy.execute("""
        SELECT count(*) AS c FROM {model_table}
        WHERE coef IS NOT NULL AND hessian IS NOT NULL
        """.format(model_table=model_table))[0]["c"]

    _assert(not_null_coef != 0,
            "Robust Variance Cox error: No not-null coef and hessian found in "
            "model_table {model_table}".format(model_table=model_table))

    # To warn users who have multiple rows in the model table.
    if not_null_coef > 1:
        plpy.warning("Robust Variance Cox Warning: multiple rows in "
                     "model_table {model_table}".format(model_table=model_table))

    # Only the first row returned by this query is inspected: there is no
    # WHERE clause or aggregate, and index [0] is taken from the result.
    coef_contains_null = plpy.execute("""
        SELECT {schema_madlib}.array_contains_null(coef) OR
               {schema_madlib}.array_contains_null(hessian)
               AS contains_null
        FROM {model_table}
        """.format(schema_madlib=schema_madlib,
                   model_table=model_table))[0]["contains_null"]

    _assert(not coef_contains_null,
            "Robust Variance Cox error: coef or hessian array in {0} contains "
            "NULL values. (If the input table contains at least one row without "
            " NULLS then rerunning coxph should get correct values)".
            format(model_table))


# ----------------------------------------------------------------------
def _rb_coxph_unstratified(schema_madlib, model_table, output_table,
                           temp_h_s, info):
    """ Build the H/S temp table and the output table for a model
    trained without strata.

    Args:
        @param schema_madlib - MADlib schema name
        @param model_table - A string, the table name of a trained cox model
        @param output_table - A string, the result table name
        @param temp_h_s - A string, name to use for the temporary H/S table
        @param info - The first row of the model summary table; its keys
                      are substituted into the SQL templates
    """
    # Create H and S table, use window function ordered desc
    plpy.execute(
        """
        CREATE TEMP TABLE {temp_H_S} as
        SELECT
            ({independent_varname})::FLOAT8[] AS x,
            ({dependent_varname})::FLOAT8 AS y,
            ({right_censoring_status})::BOOLEAN AS status,
            {schema_madlib}.coxph_h_s(
                {independent_varname},
                (SELECT coef FROM {model_table})
            ) OVER (ORDER BY {dependent_varname} desc) AS h_s
        FROM {source_table} s
        """.format(temp_H_S=temp_h_s, model_table=model_table,
                   schema_madlib=schema_madlib, **info))

    # W is computed in ascending order
    plpy.execute(
        """
        CREATE TABLE {output_table} AS
        SELECT
            u.coef, u.loglikelihood, u.std_err,
            (v.f).std_err AS robust_se,
            (v.f).stats AS robust_z,
            (v.f).p_values AS robust_p,
            u.hessian
        FROM
            {model_table} u,
            (
                SELECT
                    {schema_madlib}.rb_coxph_step(
                        x, y, status,
                        (SELECT coef FROM {model_table}),
                        (SELECT hessian FROM {model_table}),
                        (h_s).h, (h_s).s
                        ORDER BY y) AS f
                FROM {temp_H_S}
            ) v
        """.format(output_table=output_table, model_table=model_table,
                   schema_madlib=schema_madlib, temp_H_S=temp_h_s, **info))


# ----------------------------------------------------------------------
def _rb_coxph_stratified(schema_madlib, model_table, output_table,
                         temp_h_s, info):
    """ Build the H/S temp table and the output table for a model
    trained with strata; the work is done per stratum and then combined.

    Args:
        @param schema_madlib - MADlib schema name
        @param model_table - A string, the table name of a trained cox model
        @param output_table - A string, the result table name
        @param temp_h_s - A string, name to use for the temporary H/S table
        @param info - The first row of the model summary table; its keys
                      are substituted into the SQL templates
    """
    # strata can have multiple columns, so cannot use '{strata} AS strata'
    # To avoid name conflicts, have to use unique strings
    x = unique_string()
    y = unique_string()
    status = unique_string()
    h_s = unique_string()

    # Create H and S table, use window function ordered desc
    # But for each strata, we need to do this
    plpy.execute(
        """
        CREATE TEMP TABLE {temp_H_S} AS
        SELECT
            ({independent_varname})::FLOAT8[] AS {x},
            ({dependent_varname})::FLOAT8 AS {y},
            ({right_censoring_status})::BOOLEAN AS {status},
            {strata},
            {schema_madlib}.coxph_h_s(
                {independent_varname},
                (SELECT coef FROM {model_table})
            ) OVER (partition by {strata}
                    ORDER BY {dependent_varname} desc) AS {h_s}
        FROM {source_table} s
        """.format(temp_H_S=temp_h_s, model_table=model_table,
                   x=x, y=y, status=status, h_s=h_s,
                   schema_madlib=schema_madlib, **info))

    # W is computed with ascending order
    plpy.execute(
        """
        CREATE TABLE {output_table} AS
        SELECT
            u.coef, u.loglikelihood, u.std_err,
            (v.f).std_err AS robust_se,
            (v.f).stats AS robust_z,
            (v.f).p_values AS robust_p,
            u.hessian
        FROM
            {model_table} u,
            (
                SELECT
                    {schema_madlib}.rb_sum_strata(in_state) AS f
                FROM (
                    SELECT
                        {schema_madlib}.rb_coxph_strata_step(
                            {x}, {y}, {status},
                            (SELECT coef FROM {model_table}),
                            (SELECT hessian FROM {model_table}),
                            ({h_s}).h, ({h_s}).s
                            ORDER BY {y}) AS in_state
                    FROM {temp_H_S}
                    group by {strata}
                ) w
            ) v
        """.format(output_table=output_table, model_table=model_table,
                   x=x, y=y, status=status, h_s=h_s,
                   schema_madlib=schema_madlib, temp_H_S=temp_h_s, **info))


# ----------------------------------------------------------------------
def rb_coxph(schema_madlib, model_table, output_table, **kwargs):
    """ Compute the Huber-White robust sandwich estimator for CoxPH model

    @brief Huber-White robust sandwich estimator for Cox proportional
    hazards regression.

    The session's client_min_messages is set to 'error' for the duration
    of the call. The previous value is restored and the temporary H/S
    table is dropped on every exit path, including when validation or
    one of the queries raises.

    Args:
        @param schema_madlib - MADlib schema name
        @param model_table - A string, the table name of a trained cox model
        @param output_table - A string, the result table name

    Returns:
        None

    Throws:
        "Robust Variance Cox error" if any argument is invalid (see
        _validate_params); errors raised by the executed queries propagate.

    Side effect:
        A table named by output_table, which contains the following columns:
        * coef - the input coef
        * loglikelihood - the input log-likelihood
        * std_err - the input standard errors
        * robust_se - standard errors of robust variance estimators
        * robust_z - z statistics of robust variance estimators
        * robust_p - p value of robust variance estimators
        * hessian - the input hessian
    """
    old_msg_level = plpy.execute("""
                                 SELECT setting FROM pg_settings
                                 WHERE name='client_min_messages'
                                 """)[0]['setting']
    plpy.execute("set client_min_messages to error")

    # table name of the result of H and S; chosen before the try block so
    # the cleanup in `finally` can always refer to it.
    temp_h_s = unique_string()

    try:
        # NOTE(review): validation runs while client_min_messages is
        # 'error', so the multi-row plpy.warning it raises is presumably
        # not delivered to the client - confirm whether that is intended.
        _validate_params(schema_madlib, model_table, output_table)

        # info is a dict that contains source_table, ind_var, dep_var,
        # right_censoring_status, strata
        info = plpy.execute("SELECT * FROM {model_table}_summary".
                            format(model_table=model_table))[0]

        if info['strata'] is None:
            _rb_coxph_unstratified(schema_madlib, model_table, output_table,
                                   temp_h_s, info)
        else:
            _rb_coxph_stratified(schema_madlib, model_table, output_table,
                                 temp_h_s, info)
    finally:
        # Undo the session-level side effects even on failure. The nested
        # try/finally makes sure the message level is restored even if the
        # DROP itself raises.
        try:
            plpy.execute("DROP TABLE IF EXISTS {temp_H_S}".format(
                temp_H_S=temp_h_s))
        finally:
            plpy.execute("SET client_min_messages TO " + old_msg_level)

    return None