m4_changequote(`') # coding=utf-8 """ @file cox_prop_hazards.py_in @brief Cox prop Hazards: Driver functions @namespace coxprophazards Cox prop Hazards: Driver functions //""" import plpy import re from utilities.control import MinWarning from utilities.validate_args import table_exists from utilities.validate_args import table_is_empty from utilities.validate_args import get_cols from utilities.validate_args import is_var_valid from utilities.utilities import unique_string from utilities.utilities import preprocess_keyvalue_params from utilities.utilities import _assert from utilities.utilities import get_seg_number from utilities.utilities import add_postfix from utilities.utilities import py_list_to_sql_string from utilities.validate_args import columns_exist_in_table from utilities.utilities import __mad_version from utilities.control import IterationController2S from convex.utils_regularization import utils_ind_var_scales import random # ---------------------------------------------------------------------- version_wrapper = __mad_version() madvec = version_wrapper.select_vecfunc() def coxph_help_message(schema_madlib, message, **kwargs): """ Help message for Cox Proportional Hazards @brief Args: @param schema_madlib string, Name of the schema madlib @param message string, Help message indicator Returns: String. Contain the help message string """ if not message: help_string = """ ----------------------------------------------------------------------- SUMMARY ----------------------------------------------------------------------- Functionality: Cox proprtional hazards regression (Breslow method) Proportional-Hazard models enable the comparison of various survival models. These survival models are functions describing the probability of a one-item event (prototypically, this event is death) with respect to time. The interval of time before death occurs is the survival time. Let T be a random variable representing the survival time, with a cumulative probability function P(t). Informally, P(t) is the probability that death has happened before time t. For more details on function usage: SELECT {schema_madlib}.coxph_train('usage') For an example on using the function: SELECT {schema_madlib}.coxph_train('example') """ elif message in ['usage', 'help', '?']: help_string = """ ----------------------------------------------------------------------- USAGE ----------------------------------------------------------------------- SELECT {schema_madlib}.coxph_train( 'source_table', -- Name of data table 'output_table', -- Name of result table (overwrites if exists) 'dependent_variable', -- Name of column for dependent variables 'independent_variable', -- Name of column for independent variables (can be any SQL expression Eg: '*') 'right_censoring_status', -- Name of the column containing censoring status 0/false : If the observation is censored 1/true : otherwise Can also be an SQL expression: 'dependent_variable < 10') (Optional, DEFAULT = TRUE) 'strata', -- The stratification column names. (Optional, DEFAULT = NULL) 'optimizer_params' -- The optimizer parameters as a comma-separated string ); ----------------------------------------------------------------------- OUTUPT ----------------------------------------------------------------------- The output table ('output_table' above) has the following columns 'coef' DOUBLE PRECISION[], -- Coefficients of regression 'loglikelihood' DOUBLE PRECISION, -- Log-likelihood value 'std_err' DOUBLE PRECISION[], -- Standard errors 'z_stats' DOUBLE PRECISION[], -- z-stats of the standard errors 'p_values' DOUBLE PRECISION[], -- p-values of the standard errors 'num_iterations' INTEGER -- Number of iterations performed by the optimizer The output summary table is named as _summary has the following columns 'source_table' VARCHAR, Source table name 'dep_var' VARCHAR, Dependent variable name 'ind_var' VARCHAR, Independent variable name 'right_censoring_status' VARCHAR, Right censoring status 'strata' VARCHAR, Stratification columns num_rows_processed INTEGER, Number of rows processed during training num_missing_rows_skipped INTEGER, Number of rows skipped during training due to missing values """ elif message in ['example', 'examples']: help_string = """ DROP TABLE IF EXISTS sample_data; CREATE TABLE sample_data ( id INTEGER NOT NULL, grp DOUBLE PRECISION, wbc DOUBLE PRECISION, timedeath INTEGER, status BOOLEAN ); COPY sample_data FROM STDIN DELIMITER '|'; 0 | 0 | 1.45 | 35 | t 1 | 0 | 1.47 | 34 | t 3 | 0 | 2.2 | 32 | t 4 | 0 | 1.78 | 25 | t 5 | 0 | 2.57 | 23 | t 6 | 0 | 2.32 | 22 | t 7 | 0 | 2.01 | 20 | t 8 | 0 | 2.05 | 19 | t 9 | 0 | 2.16 | 17 | t 10 | 0 | 3.6 | 16 | t 11 | 1 | 2.3 | 15 | t 12 | 0 | 2.88 | 13 | t 13 | 1 | 1.5 | 12 | t 14 | 0 | 2.6 | 11 | t 15 | 0 | 2.7 | 10 | t 16 | 0 | 2.8 | 9 | t 17 | 1 | 2.32 | 8 | t 18 | 0 | 4.43 | 7 | t 19 | 0 | 2.31 | 6 | t 20 | 1 | 3.49 | 5 | t 21 | 1 | 2.42 | 4 | t 22 | 1 | 4.01 | 3 | t 23 | 1 | 4.91 | 2 | t 24 | 1 | 5 | 1 | t \. SELECT {schema_madlib}.coxph_train( 'sample_data', 'sample_cox', 'timedeath', 'ARRAY[grp,wbc]', 'status'); SELECT * FROM sample_cox; """ else: help_string = "No such option. Use {schema_madlib}.coxph_train()" return help_string.format(schema_madlib=schema_madlib) # ------------------------------------------------------------ def coxph( schema_madlib, source_table, output_table, dependent_varname, independent_varname, right_censoring_status, strata, optimizer_params, *args, **kwargs): """ Cox proportional hazards regression training function @brief Cox proportional hazards regression, with stratification support. Args: @param schema_madlib - MADlib schema name @param source_table - A string, the data table name @param output_table - A string, the result table name @param dependent_varname - A string, the survival time column name or a valid expression @param independent_varname - A string, the covariates in array formats. It is a valid expression. @param right_censoring_status - A string, a column name or a valid expression that has boolean values. Whether the row of data is censored. Default is 'TRUE'. @param strata - A string, column names seprated by commas. The columns used for stratification. Default is None. @param optimizer_params - A string, which contains key=value pairs separated by commas. Default values: max_iter=20, optimizer='newton', tolerance=1e-4. Returns: A table named by output_table, which contains the following columns: * coef - An array of double precision values, fitting coefs * std_err - An array of double precision values, standard erros of coef * z - An array of double precision values, z statistics * p - An array of double precision values, p value """ old_msg_level = plpy.execute( """ select setting from pg_settings where name='client_min_messages' """)[0]['setting'] plpy.execute("set client_min_messages to warning") all_arguments = { 'schema_madlib': schema_madlib, 'source_table': source_table, 'output_table': output_table, 'dependent_varname': dependent_varname, 'independent_varname': independent_varname, 'right_censoring_status': right_censoring_status, 'strata': strata } _validate_params(**all_arguments) indepColumn = __check_args(schema_madlib, source_table, independent_varname, dependent_varname, right_censoring_status) (max_iter, optimizer, tolerance, array_agg_size, sample_size) = _extract_params(schema_madlib, optimizer_params) # Number of features n_features = plpy.execute( """ select array_upper({independent_varname}, 1) as n_features from {source_table} limit 1 """.format( independent_varname=independent_varname, source_table=source_table))[0]['n_features'] # Use the quick split to re-distribute the data # construct a list of other columns other_cols = [right_censoring_status] new_source_table, index, dep, indep, status, n_processed, n_skipped, real_distid, std_str = quick_split( schema_madlib, source_table, dependent_varname, n_features, independent_varname, right_censoring_status, strata, array_agg_size, sample_size, True) if n_processed > 0: compute_coxph(schema_madlib, new_source_table, output_table, index, dep, indep, n_features, status, strata, optimizer, max_iter, tolerance, real_distid, std_str) plpy.execute('drop table if exists ' + new_source_table) else: plpy.execute( """ create table {output_table} ( coef DOUBLE PRECISION[], loglikelihood DOUBLE PRECISION, std_err DOUBLE PRECISION[], z_stats DOUBLE PRECISION[], p_values DOUBLE PRECISION[], hessian DOUBLE PRECISION[], num_iterations INTEGER ); """.format(output_table=output_table)) plpy.execute( """ insert into {output_table} values (NULL, NULL, NULL, NULL, NULL, NULL, 1); """.format(output_table=output_table)) # the summary table output_table_summary = add_postfix(output_table, "_summary") plpy.execute( """ create table {output_table_summary} as select 'coxph'::varchar as method, '{source_table}'::varchar as source_table, '{output_table}'::varchar as out_table, '{dependent_varname}'::varchar as dependent_varname, '{independent_varname}'::varchar as independent_varname, {right_censoring_status}::varchar as right_censoring_status, {strata}::varchar as strata, {n_processed}::integer as num_processed, {n_skipped}::integer as num_missing_rows_skipped; """.format( source_table=source_table, output_table_summary=output_table_summary, output_table=output_table, dependent_varname=dependent_varname, independent_varname=independent_varname, right_censoring_status=("NULL::text" if right_censoring_status is None else "'" + right_censoring_status + "'"), strata=("NULL::text" if strata is None else "'" + strata + "'"), n_processed=n_processed, n_skipped=n_skipped)) plpy.execute("set client_min_messages to " + old_msg_level) return None # ------------------------------------------------------------ def compute_coxph(schema_madlib, source_table, output_table, index, dep, indep, n_features, status, strata, optimizer, max_iter, precision, real_distid, std_str): """ Use the old sequential algorithm to solve coxph @brief Run ordered aggregate on the re-distributed data. Each row of the table contains many original rows of the original data table in inverse order. """ m4_ifdef(, , ) if max_iter < 1: plpy.error("Number of iterations must be positive") if optimizer not in ['newton']: plpy.error("Unknown optimizer requested. Must be 'newton'") # FIXME starting from random values or 0 ? coef = [0] * n_features L = float('-inf') # $1 - previous coef # $2 - coef limit, when it an array of all 0 (during 1-th iteration), # we compute the limit, and return it to Python. And then in the following # iterations, we use the limit. # the data is already sorted desc if strata is None: sql = """ select (f).* from ( select {schema_madlib}.coxph_improved_step( {indep}, {dep}, {status}, $1, $2 order by {index} ) as f from {source_table} group by {real_distid} -- to avoid gathering data before aggregating ) s """.format(**locals()) else: sql = """ select (f).* from ( select {schema_madlib}.coxph_improved_strata_step_outer(inner_state) as f from ( select {schema_madlib}.coxph_improved_strata_step_inner( {indep}, {dep}, {status}, $1, $2 order by {index} ) as inner_state from {source_table} group by {strata} ) t1 ) t2 """.format(**locals()) old_coef = coef n_iter = 0 max_coef = [-1] * n_features # limit of coefficients sql_plan = plpy.prepare(sql, ["double precision[]", "double precision[]"]) while True: n_iter += 1 # The result contains coef, L, d2L result = plpy.execute(sql_plan, [coef, max_coef])[0] prev_L, L = L, result['l'] new_coef = result['coef'] if n_iter == 1: max_coef = result['max_coef'] if L < prev_L: # Newton Raphson step coef = [(i + j)/2 for i, j in zip(coef, old_coef)] else: old_coef, coef = coef, new_coef if n_iter > max_iter or (L > prev_L and abs(1 - L / prev_L) <= precision): # exiting since max_iter iterations explored or # very small increase in log_likelihood # Compute std_err, t_stats, p_value if hawq_platform: coef_l_d2l = unique_string() + "_coef_l_d2l" plpy.execute(""" CREATE TEMP TABLE {coef_l_d2l} as SELECT {coef}::float8[] as coef, {l}::float8 as l, {d2l}::float8[] as d2l, {n_iter}::integer as n_iter """.format( coef_l_d2l=coef_l_d2l, coef=py_list_to_sql_string(coef), l=L, d2l=py_list_to_sql_string(result["d2l"]), n_iter=n_iter)) plpy.execute(""" create table {output_table} as SELECT (f).* FROM ( SELECT {schema_madlib}.compute_coxph_result( coef, l, d2l, n_iter, {std_str} ) as f FROM {coef_l_d2l} ) s; """.format(coef_l_d2l=coef_l_d2l, schema_madlib=schema_madlib, output_table=output_table, std_str=std_str)) plpy.execute("drop table if exists " + coef_l_d2l) else: plpy.execute( """ CREATE TABLE {output_table} as SELECT (f).* FROM ( SELECT {schema_madlib}.compute_coxph_result( {coef}, {L}, {d2L}, {n_iter}, {std_str} ) as f ) s; """.format( output_table=output_table, schema_madlib=schema_madlib, coef=py_list_to_sql_string(coef), L=L, d2L=py_list_to_sql_string(result["d2l"]), std_str=std_str, n_iter=n_iter)) return None # ------------------------------------------------------------ # The new quick method. Used by PLANET algorithm for # decision tree. Since it has been successfully applied # onto an algorithm, no reason that we could not use it. def quick_split( schema_madlib, source_table, split_col, num_features, indep, status, strata, array_agg_size, sample_size, reverse=True, split_col_alias=''): """ Quickly find the splits of the split_col so that we can evenly cut the data without sorting @brief We want to cut the table to n even pieces, and at the same time keep split_col sorted. For example, in ARIMA, we need to partition the data into multiple chunks while each chunk contains consecutive time series. In CoxPH, we need to do the same thing. This step proves to be time-consuming becuase sorting (row_number over (order by ...)) is really slow. The solution is to work on a random sample of the original data instead of the whole data set. @param source_table The data table @param split_col Find the cuts of this column @param num_features The number of features @param indep The indepent variables """ # Need the total number of records n_rows = plpy.execute( "select count(*) from {source_table}".format( source_table=source_table))[0]['count'] # It might be cubersome to deal with NULL values in C++, # since a matrix of independent variables is passed into C++. filter_null = """ {split_col} is not NULL and {schema_madlib}.array_contains_null({indep}) is False and {indep} is not NULL and ({status}) is not NULL """.format(schema_madlib=schema_madlib, split_col=split_col, indep=indep, status=status) # Number of rows to be processed n_processed = plpy.execute( "select count(*) from {source_table} where {filter_null}".format( source_table=source_table, filter_null=filter_null))[0]['count'] n_skipped = n_rows - n_processed if n_processed == 0: return ('', 1, '', '', '', n_processed, n_skipped, '', '') # Rewrite the num_splits n_rows_in_chunk = array_agg_size / num_features num_splits = (1 if n_rows < n_rows_in_chunk else n_rows / n_rows_in_chunk) # So that we could compute the percentage of the sample # We sample a few more values to make sure we can get enough # samples, otherwise the number of samples might be smaller # than sample_size. percentage = sample_size / n_rows + 0.01 n_per_seg = int(sample_size / get_seg_number()) + 1 output_table = unique_string() + '_redist' # nomalize the data to avoid possible overflow x_mean = plpy.execute( """ select {schema_madlib}.array_avg({indep}, false) as xmean from {source_table} where {filter_null} """.format( schema_madlib=schema_madlib, indep=indep, source_table=source_table, filter_null=filter_null))[0]['xmean'] mean_str = "array[" + ",".join(str(v) for v in x_mean) + "]" scales = plpy.execute( """ select {schema_madlib}.array_avg( {schema_madlib}.array_sub({indep}::float8[], {mean}::float8[]), true) as scale from {source_table} where {filter_null} """.format( schema_madlib=schema_madlib, indep=indep, mean=mean_str, source_table=source_table, filter_null=filter_null))[0]['scale'] scale_str = "array[" + ",".join(str(v) if v != 0 else "1" for v in scales) + "]" # compute the break points # Use a fixed number of rows of sample to approximate the # break points. # Each bin will have approximate the same number of records. splits = plpy.execute( """ select {schema_madlib}._compute_splits( {split_col}, {n_per_seg}, {num_splits}) as splits from {source_table} where random() <= {percentage} and {filter_null} """.format( schema_madlib=schema_madlib, split_col=split_col, n_per_seg=n_per_seg, num_splits=num_splits, source_table=source_table, filter_null=filter_null, percentage=percentage))[0]['splits'] # Since the array of break points is not big # we load it into memory. splits_str = ('NULL::DOUBLE PRECISION[]' if splits is None else "array[" + ",".join(str(split) for split in splits) + "]") # Use the same convention as _redistribute_data # ie. hard code dep, and indep as column names # NOTE: CoxPH will process the data in a reversed order. Thus # the latest data points will be grouped into the group 0. indep_name = unique_string() + '_indep' # avoid name conflicts distid = unique_string() + '_distid' if split_col_alias == "": split_col_alias = unique_string() + '_split_alias' status_name = unique_string() + '_status' order_str = 'desc' if reverse else '' # Each big rows may contain different number of original small rows # since the quick split is approximate. But this is perfectly fine. if strata is None: real_distid = unique_string() + '_real_distid' strata_sql = """ CREATE TEMP TABLE {output_table} AS select 0 as {real_distid}, -- to ensure that all chunks are stored in the same segment {distid}, array_agg({split_col_alias} ORDER BY {split_col_alias} {order_str}, {status_name}) as {split_col_alias}, array_agg({status_name} ORDER BY {split_col_alias} {order_str}, {status_name}) as {status_name}, {schema_madlib}.matrix_agg( {schema_madlib}.utils_normalize_data( {indep_name}, {mean_str}, {scale_str} ) ORDER BY {split_col_alias} {order_str}, {status_name}) as {indep_name} from ( select {schema_madlib}._compute_grpid( {splits}, {split_col}, {reverse}) as {distid}, {split_col} as {split_col_alias}, {indep} as {indep_name}, ({status})::INTEGER as {status_name} from {source_table} where {filter_null} ) table_split group by {distid} m4_ifdef(, , ); """.format( output_table=output_table, schema_madlib=schema_madlib, distid=distid, splits=splits_str, reverse='true' if reverse else 'false', split_col_alias=split_col_alias, split_col=split_col, indep=indep, indep_name=indep_name, status_name=status_name, status=status, filter_null=filter_null, source_table=source_table, order_str=order_str, mean_str=mean_str, scale_str=scale_str, real_distid=real_distid) plpy.execute(strata_sql) return (output_table, distid, split_col_alias, indep_name, status_name, n_processed, n_skipped, real_distid, scale_str) else: plpy.execute(""" create temp table {output_table} as select {strata}, {distid}, array_agg({split_col_alias} order by {split_col_alias} {order_str}, {status_name}) as {split_col_alias}, array_agg({status_name} order by {split_col_alias} {order_str}, {status_name}) as {status_name}, {schema_madlib}.matrix_agg( {schema_madlib}.utils_normalize_data( {indep_name}, {mean_str}, {scale_str} ) order by {split_col_alias} {order_str}, {status_name}) as {indep_name} from ( select {strata}, {schema_madlib}._compute_grpid( {splits}, {split_col}, {reverse}) as {distid}, {split_col} as {split_col_alias}, {indep} as {indep_name}, ({status})::INTEGER as {status_name} from {source_table} where {filter_null} ) table_split group by {distid}, {strata} -- ensure that all chunks belonging to the same strata go to the same segment m4_ifdef(, , ) """.format( output_table=output_table, schema_madlib=schema_madlib, distid=distid, splits=splits_str, reverse='true' if reverse else 'false', split_col_alias=split_col_alias, split_col=split_col, indep=indep, indep_name=indep_name, status_name=status_name, status=status, strata=strata, filter_null=filter_null, source_table=source_table, mean_str=mean_str, scale_str=scale_str, order_str=order_str)) return (output_table, distid, split_col_alias, indep_name, status_name, n_processed, n_skipped, '', scale_str) # ---------------------------------------------------------------------- def _validate_params(schema_madlib, source_table, output_table, dependent_varname, independent_varname, right_censoring_status, strata, *args, **kwargs): """ Validate the input parameters for coxph Args: @param schema_madlib - MADlib schema name @param source_table - A string, the data table name @param output_table - A string, the result table name @param dependent_varname - A string, the survival time column name or a valid expression @param independent_varname - A string, the covariates in array formats. It is a valid expression. @param right_censoring_status - A string, a column name or a valid expression that has boolean values. Whether the row of data is censored. Default is 'TRUE'. @param strata - A string, column name seprated by commas. The columns used for stratification. Default is None. Throws: "Cox error" if any argument is invalid """ # _assert(source_table is not None and table_exists(source_table), # "Cox error: Source data table does not exist!") _assert(not table_exists(output_table, only_first_schema=True), "Cox error: Output table {0}" " already exists!".format(str(output_table))) output_table_summary = add_postfix(output_table, "_summary") _assert(not table_exists(output_table_summary, only_first_schema=True), "Cox error: Output table {0}" " already exists!".format(str(output_table_summary))) if strata is not None: strata_cols = [a.strip() for a in strata.split(",")] _assert(columns_exist_in_table(source_table, strata_cols, schema_madlib), "ARIMA error: {1} columns do not exist in {0}!" .format(source_table, strata_cols)) return None # ---------------------------------------------------------------------- def _extract_params(schema_madlib, optimizer_params): """ Extract optimizer control parameter or set the default values @brief optimizer_params is a string with the format of 'max_iter=..., optimizer=..., tolerance=...'. The order does not matter. If a parameter is missing, then the default value for it is used. If optimizer_params is None or '', then all default values are used. If the parameter specified is none of 'max_iter', 'optimizer', or 'tolerance' then an error is raised. This function also validates the values of these parameters. Throws: "Cox error" - If the parameter is unsupported or the value is not valid. """ allowed_params = set(["max_iter", "optimizer", "tolerance", "array_agg_size", "sample_size"]) name_value = dict(max_iter=100, optimizer="newton", tolerance=1e-8, array_agg_size=10000000, sample_size=1000000) if optimizer_params is None or len(optimizer_params) == 0: return (name_value['max_iter'], name_value['optimizer'], name_value['tolerance'], name_value['array_agg_size'], name_value['sample_size']) for s in preprocess_keyvalue_params(optimizer_params): items = s.split("=") if (len(items) != 2): plpy.error("Cox error: Optimizer parameter list has incorrect format!") param_name = items[0].strip(" \"").lower() param_value = items[1].strip(" \"").lower() if param_name not in allowed_params: plpy.error( """ Cox error: {param_name} is not a valid parameter name. Run: SELECT {schema_madlib}.coxph('usage'); to see the allowed parameters. """.format(param_name=param_name, schema_madlib=schema_madlib)) if param_name == "array_agg_size": try: name_value["array_agg_size"] = int(param_value) except: plpy.error("Cox error: array_agg_size must be an integer value!") if param_name == "sample_size": try: name_value["sample_size"] = int(param_value) except: plpy.error("Cox error: sample_size must be an integer value!") if param_name == "max_iter": try: name_value["max_iter"] = int(param_value) except: plpy.error("Cox error: max_iter must be an integer number!") if param_name == "optimizer": name_value["optimizer"] = param_value if param_name == "tolerance": try: name_value["tolerance"] = float(param_value) except: plpy.error("Cox error: tolerance must be a double precision value!") if name_value["max_iter"] <= 0: plpy.error("Cox error: max_iter must be positive!") if name_value["optimizer"] != "newton": plpy.error("Cox error: this optimization method is not supported yet!") if name_value["tolerance"] < 0: plpy.error("Cox error: tolerance cannot be smaller than 0!") if name_value["array_agg_size"] <= 0: plpy.error("Cox error: array_agg_size must be positive!") if name_value["sample_size"] <= 0: plpy.error("Cox error: sample_size must be positive!") return (name_value['max_iter'], name_value['optimizer'], name_value['tolerance'], name_value['array_agg_size'], name_value['sample_size']) # ---------------------------------------------------------------------- def __check_args(schema_madlib, tbl_source, col_ind_var, col_dep_var, col_status): _assert(tbl_source is not None, "Cox Proportional Hazards Error: Source table should not be NULL!") _assert(col_ind_var is not None, "Cox Proportional Hazards Error: Independent variable should not be NULL!") _assert(col_dep_var is not None, "Cox Proportional Hazards Error: Dependent variable should not be NULL!") _assert(table_exists(tbl_source), "Cox Proportional Hazards Error: Source table " + tbl_source + " does not exist!") _assert(not table_is_empty(tbl_source), "Cox Proportional Hazards Error: Source table " + tbl_source + " is empty!") _assert(columns_exist_in_table(tbl_source, [col_dep_var]), "Cox Proportional Hazards Error: Dependent variable does not exist!") _assert(is_var_valid(tbl_source, col_ind_var), "Cox Proportional Hazards Error: The independent variable does not exist!") _assert(is_var_valid(tbl_source, col_status), "Cox Proportional Hazards Error: Not a valid boolean expression for status!") col_ind_var_new = col_ind_var cols = get_cols(tbl_source) # Select al columns except status and dependent variable if col_ind_var == "*": cols = get_cols(tbl_source) outstr_array = [] for each_col in cols: if each_col != col_dep_var.lower() and each_col not in col_status.lower(): outstr_array.append(each_col) col_ind_var_new = 'array[%s]' % (','.join(outstr_array)) return col_ind_var_new # ----------------------------------------------------------------------- # ZPH functionality # ----------------------------------------------------------------------- def zph_help_message(schema_madlib, message, **kwargs): """ Help message for function to test the proportional hazards assumption for a Cox regression model fit @brief Args: @param schema_madlib string, Name of the schema madlib @param message string, Help message indicator Returns: String. Contain the help message string """ if not message: help_string = """ ----------------------------------------------------------------------- SUMMARY ----------------------------------------------------------------------- Functionality: Test of proportional hazards assumption Proportional-Hazard models enable the comparison of various survival models. See {schema_madlib}.coxph_train() for details to create a Cox PH model. These PH models, however, assume that the hazard for a given individual is a fixed proportion of the hazard for any other individual, and the ratio of the hazards is constant across time. The cox_zph() function is used to test this assumption by computing the correlation of the residual of the Cox PH model with time. For more details on function usage: SELECT {schema_madlib}.cox_zph('usage') For an example on using the function: SELECT {schema_madlib}.cox_zph('example') """ elif message in ['usage', 'help', '?']: help_string = """ ----------------------------------------------------------------------- USAGE ----------------------------------------------------------------------- SELECT {schema_madlib}.cox_zph( 'cox_model_table', -- TEXT. The name of the table containing the Cox Proportional-Hazards model 'output_table', -- TEXT. The name of the table where the test statistics are saved ); ----------------------------------------------------------------------- OUTUPT ----------------------------------------------------------------------- The ('output_table' above) has the following columns - covariate TEXT. The names of independent variables - rho FLOAT8[]. Vector of the correlation coefficients between survival time and the scaled Schoenfeld residuals - chi_square FLOAT8[]. Chi-square test statistic for the correlation analysis - p_value FLOAT8[]. Two-side p-value for the chi-square statistic The output residual table is named as _residual has the following columns - FLOAT8. Time values (dependent variable) present in the original source table. - residual FLOAT8[]. Difference between the original covariate value and the expectation of the covariate obtained from the coxph model. - scaled_reisdual FLOAT8[]. Residual values scaled by the variance of the coefficients """ elif message in ['example', 'examples']: help_string = """ DROP TABLE IF EXISTS sample_data; CREATE TABLE sample_data ( id INTEGER NOT NULL, grp DOUBLE PRECISION, wbc DOUBLE PRECISION, timedeath INTEGER, status BOOLEAN ); -- Insert sample data COPY sample_data FROM STDIN DELIMITER '|'; 0 | 0 | 1.45 | 35 | t 1 | 0 | 1.47 | 34 | t 3 | 0 | 2.2 | 32 | t 4 | 0 | 1.78 | 25 | t 5 | 0 | 2.57 | 23 | t 6 | 0 | 2.32 | 22 | t 7 | 0 | 2.01 | 20 | t 8 | 0 | 2.05 | 19 | t 9 | 0 | 2.16 | 17 | t 10 | 0 | 3.6 | 16 | t 11 | 1 | 2.3 | 15 | t 12 | 0 | 2.88 | 13 | t 13 | 1 | 1.5 | 12 | t 14 | 0 | 2.6 | 11 | t 15 | 0 | 2.7 | 10 | t 16 | 0 | 2.8 | 9 | t 17 | 1 | 2.32 | 8 | t 18 | 0 | 4.43 | 7 | t 19 | 0 | 2.31 | 6 | t 20 | 1 | 3.49 | 5 | t 21 | 1 | 2.42 | 4 | t 22 | 1 | 4.01 | 3 | t 23 | 1 | 4.91 | 2 | t 24 | 1 | 5 | 1 | t \. -- Run coxph function SELECT {schema_madlib}.coxph_train( 'sample_data', 'sample_cox', 'timedeath', 'ARRAY[grp,wbc]', 'status'); -- Get the Cox PH model SELECT * FROM sample_cox; -- Run the PH assumption test and obtain the results SELECT {schema_madlib}.cox_zph('sample_cox', 'sample_zph_output'); SELECT * FROM sample_zph_output; """ else: help_string = "No such option. Use {schema_madlib}.cox_zph()" return help_string.format(schema_madlib=schema_madlib) def zph(schema_madlib, cox_output_table, output_table): """ Compute the Schoenfeld residuals for a Hazards data table @brief Compute the Schoenfeld residuals for a Hazards data table by using an aggregate-defined window function Args: @param schema_madlib: string, Name of the MADlib schema @param cox_output_table: string, Name of the coxph output_table Returns: None """ _validate_zph_params(schema_madlib, cox_output_table, output_table) cox_output_table_summary = add_postfix(cox_output_table, "_summary") rv = plpy.execute(""" SELECT source_table, dependent_varname, independent_varname, right_censoring_status, strata FROM {cox_output_table_summary} """.format(cox_output_table_summary=cox_output_table_summary)) source_table = rv[0]['source_table'] dependent_variable = rv[0]['dependent_varname'] independent_variable = rv[0]['independent_varname'] right_censoring_status = rv[0]['right_censoring_status'] strata = rv[0]['strata'] _compute_residual(schema_madlib, source_table, output_table, dependent_variable, independent_variable, cox_output_table, right_censoring_status, strata) # ---------------------------------------------------------------------- def _validate_zph_params(schema_madlib, cox_model_table, output_table): """ Args: @param schema_madlib: string, Name of the MADlib schema @param cox_model_table: string, Table name for Cox Prop Hazards model @param output_table: string, Output data table name Returns: None Throws: Error on any invalid parameter """ if cox_model_table is None or cox_model_table.strip() == '': plpy.error("Cox error: NULL/Empty model table is given!") cox_model_table_summary = add_postfix(cox_model_table, "_summary") _assert(cox_model_table is not None and table_exists(cox_model_table) and table_exists(cox_model_table_summary), "Cox error: Model table {0} or summary table {1} " "does not exist!".format(cox_model_table, cox_model_table_summary)) output_table_residual= add_postfix(output_table, "_residual") _assert((not table_exists(output_table, only_first_schema=True)) and (not table_exists(output_table_residual, only_first_schema=True)), "Cox error: Output table {0} or residual table {1} " "already exists!".format(output_table, output_table_residual)) summary_columns = ["source_table", "dependent_varname", "independent_varname", "right_censoring_status", "strata"] _assert(columns_exist_in_table(cox_model_table + "_summary", summary_columns), "Cox error: At least one column from {0} missing in " "model table {1}". format(str(summary_columns), cox_model_table)) return None # ---------------------------------------------------------------------- def _compute_residual(schema_madlib, source_table, output_table, dependent_variable, independent_variable, cox_output_table, right_censoring_status=None, strata=None, **kwargs): """ Compute the Schoenfeld residuals for a Hazards model @brief Computes the Schoenfeld residuals for a Hazards data table by using an aggregate-defined window function and outputs to a table Args: @param schema_madlib: string, Name of the MADlib schema @param source_table: string, Input data table name @param output_table: string, Output data table name @param dependent_variable: string, Dependent variable name @param independent_variable: string, Independent variable name (could also be an expression) @param right_censoring_status: string, Column name with right censoring status @param cox_output_table: string, Output table of coxph @param strata: string, Comma-separated list of columns to stratify with Returns: None """ if not right_censoring_status: right_censoring_status = 'TRUE' if strata: partition_str = "PARTITION BY {0}".format(strata) else: partition_str = '' coef = madvec(plpy.execute("SELECT coef FROM {table} ". format(table=cox_output_table))[0]["coef"], text=False) coef_str = "ARRAY" + str(coef) # We don't extract a copy of the Hessian 2D array, since Postgres/GPDB still # don't support getting a 2d array into plpython residual_table = unique_string() format_args = {'schema_madlib': schema_madlib, 'output': output_table, 'indep_column': independent_variable, 'dep_column': dependent_variable, 'status': right_censoring_status, 'cox_output_table': cox_output_table, 'source_table': source_table, 'residual_table': residual_table, 'coef_str': coef_str, 'partition_str': partition_str} # plpy.info("--------- Computing residuals --------- ") plpy.execute(""" CREATE TEMP TABLE {residual_table} AS SELECT {dep_column}, {schema_madlib}.array_sub( x::DOUBLE PRECISION[], expectation_x::DOUBLE PRECISION[] ) AS residual FROM ( SELECT {dep_column}, ({indep_column})::DOUBLE PRECISION[] AS x, ({status})::BOOLEAN as status, {schema_madlib}.zph_agg( ({indep_column})::DOUBLE PRECISION[], {coef_str} ) OVER ({partition_str} ORDER BY {dep_column} DESC) AS expectation_x FROM {source_table} WHERE {dep_column} IS NOT NULL AND NOT {schema_madlib}.array_contains_null( {indep_column}::DOUBLE PRECISION[]) ) AS q1 WHERE status is TRUE ORDER BY {dep_column} ASC m4_ifdef(, , ) """.format(**format_args)) n_uncensored = plpy.execute("""SELECT count(*)::INTEGER as n_uncensored FROM {table} """.format(table=residual_table))[0]["n_uncensored"] format_args['n_uncensored'] = n_uncensored # plpy.info("--------- Computing scaled residuals ---------") output_residual = add_postfix(output_table, "_residual") plpy.execute(""" CREATE TABLE {output_residual} AS SELECT {dep_column}, residual as residual, {schema_madlib}.__coxph_scale_resid( {n_uncensored}::INTEGER, (SELECT hessian FROM {cox_output_table}), residual ) AS scaled_residual FROM {residual_table} m4_ifdef(, , ) """.format(output_residual=output_residual, **format_args)) # plpy.info("--------- Computing metrics ---------") mean = plpy.execute(""" SELECT avg({dep_column}) AS w FROM {residual_table} """.format(**format_args))[0]['w'] plpy.execute(""" CREATE TABLE {output} AS SELECT ('{indep_column}')::TEXT as covariate, rho, (f).chi_square_stat as chi_square, (f).p_value as p_value FROM ( SELECT {schema_madlib}.array_elem_corr_agg( scaled_residual, ({dep_column} - {mean})::DOUBLE PRECISION) AS rho, {schema_madlib}.__coxph_resid_stat_agg( ({dep_column} - {mean})::DOUBLE PRECISION, residual, (SELECT hessian FROM {cox_output_table}), {n_uncensored}::INTEGER) AS f FROM {output_residual} ) AS q1 m4_ifdef(, , ) """.format(mean=mean, output_residual=output_residual, **format_args)) # Cleanup plpy.execute('DROP TABLE IF EXISTS ' + residual_table) def cox_prop_hazards(schema_madlib, usage_string, **kwargs): plpy.warning("This function has been deprecated. Please use 'coxph_train' instead.") if usage_string == '': message = """ Summary ------------------------------------------------------------------------------------ Functionality: Cox proprtional hazards regression (Breslow method) SELECT {schema_madlib}.cox_prop_hazards( 'source_table', 'output_table', 'dependent_variable', 'independent_variable', 'right_censoring_status' ); For more details on function usage: SELECT {schema_madlib}.cox_prop_hazards('usage'); """ elif usage_string in ('usage', 'help', '?'): message = """ Usage ------------------------------------------------------------------------------------ SELECT {schema_madlib}.cox_prop_hazards( 'source_table', -- Name of data table 'output_table', -- Name of result table (overwrites if exists) 'dependent_variable', -- Name of column for dependent variables 'independent_variable', -- Name of column for independent variables (can be any SQL expression Eg: ''*'') ['right_censoring_status', -- Name of the column containing censoring status -- 0/false : If the observation is censored -- 1/true : otherwise -- Default is 1/true for all observations -- Can also be an SQL expression: 'dependent_variable < 10' ); Output: ------------------------------------------------------------------------------------ The output table (''output_table'' above) has the following columns: 'coef' DOUBLE PRECISION[], -- Coefficients of regression 'std_err' DOUBLE PRECISION[], -- Standard errors 'z_stats' DOUBLE PRECISION[], -- z-stats of the standard errors 'p_values' DOUBLE PRECISION[], -- p-values of the standard errors """ else: message = "No such option. Run SELECT {schema_madlib}.cox_prop_hazards()" return message.format(schema_madlib=schema_madlib) def _validate_predict(schema_madlib, model_table, source_table, id_col_name, output_table, pred_type, reference): """ Validate the arguments """ _assert(pred_type in ('linear_predictors', 'risk', 'terms'), "Cox predict error: Invalid prediction type.") _assert(reference in ('overall', 'strata'), "Cox predict error: Invalid reference type.") _assert(model_table and model_table.strip().lower() not in ('null', ''), "Cox predict error: Invalid model table.") _assert(table_exists(model_table), "Cox predict error: Model table is missing.") _assert(source_table and source_table.strip().lower() not in ('null', ''), "Cox predict error: Invalid source table.") _assert(table_exists(source_table), "Cox predict error: source table is missing.") _assert(not table_exists(output_table, only_first_schema=True), "Cox predict error: output table already exists.") _assert(columns_exist_in_table(model_table, ["coef"], schema_madlib), "Cox predict error: Invalid model table ({0})".format(model_table)) model_summary = model_table + "_summary" _assert(table_exists(model_summary), "Cox predict error: Model summary table is missing.") _assert(columns_exist_in_table(model_summary, ["independent_varname", "strata"], schema_madlib), "Cox predict error: Invalid summary table ({0})".format(model_summary)) # ------------------------------------------------------------------------------ def coxph_predict(schema_madlib, model_table, source_table, id_col_name, output_table, pred_type, reference, **kwargs): """ Cox prediction function """ if pred_type is None: pred_type = "linear_predictors" if reference is None: reference = "strata" _validate_predict(schema_madlib, model_table, source_table, id_col_name, output_table, pred_type, reference) model_table_summary = model_table + "_summary" summary_elements = plpy.execute("SELECT strata, independent_varname FROM {0}" .format(model_table_summary))[0] strata = summary_elements['strata'] independent_varname = summary_elements['independent_varname'] term_avg = unique_string() select_strata = "" group_by_strata = "" where_strata = "" if pred_type == "terms": strata = None if strata is not None and reference == "strata": select_strata = "{strata}".format(**locals()) + "," group_by_strata = "group by {strata}".format(**locals()) strat_cols_condition = "" els = [] for s in strata.split(","): els.append("t." + s + " = s." + s) strat_cols_condition = " and ".join(els) where_strata = "where {strat_cols_condition}".format(**locals()) if columns_exist_in_table(source_table, [id_col_name], schema_madlib): coxph_predict_id = id_col_name else: coxph_predict_id = 'coxph_predict_id' # resolve name conflicts in output table output_name = pred_type + "_output" if id_col_name == pred_type else pred_type if pred_type != "terms": sql_predict = """ CREATE TABLE {output_table} AS SELECT {id_col_name} AS {coxph_predict_id}, {schema_madlib}._coxph_predict_resp( coef, {independent_varname}, {term_avg}, '{pred_type}'::TEXT) AS {output_name} FROM ( SELECT {select_strata} {schema_madlib}.avg({independent_varname}) as {term_avg} FROM {source_table} {group_by_strata} ) t, {model_table} m, {source_table} s {where_strata} """.format(**locals()) else: sql_predict = """ CREATE TABLE {output_table} AS SELECT {id_col_name} AS {coxph_predict_id}, {schema_madlib}._coxph_predict_terms( coef, {independent_varname}, {term_avg} ) AS {output_name} FROM ( SELECT {select_strata} {schema_madlib}.avg({independent_varname}) as {term_avg} FROM {source_table} {group_by_strata} ) t, {model_table} m, {source_table} s {where_strata} """.format(**locals()) with MinWarning('warning'): plpy.notice("sql_predict:\n" + sql_predict) plpy.execute(sql_predict) return None def coxph_predict_help_message(schema_madlib, message, **kwargs): """ Help message for prediction using a CoxPH model @brief Args: @param schema_madlib string, Name of the schema madlib @param message string, Help message indicator Returns: String. Help message string """ if not message: help_string = """ ----------------------------------------------------------------------- SUMMARY ----------------------------------------------------------------------- Functionality: Prediction using a CoxPH model The prediction function is provided to calculate the linear predictors, risk or the linear terms for the given prediction data. For more details on function usage: SELECT {schema_madlib}.coxph_predict('usage') For an example on using the function: SELECT {schema_madlib}.coxph_predict('example') """ elif message in ['usage', 'help', '?']: help_string = """ ----------------------------------------------------------------------- USAGE ----------------------------------------------------------------------- SELECT {schema_madlib}.coxph_predict( 'model_table', -- TEXT. Name of the table containing the cox model. 'source_table', -- TEXT. Name of the table containing the prediction data. 'id_col_name', -- TEXT. Name of the id column in the source table. 'output_table', -- TEXT. Name of the table to store the prediction results in. 'pred_type', -- TEXT. Type of prediction. Can be one of 'linear_predictors', 'risk' or 'terms'. Default = 'linear_predictors'. 'linear_predictors' calculates the dot product of the independent variables and the coefficients. 'risk' is the exponentiated value of the linear prediction. 'terms' correspond to the linear terms obtained by multiplying the independent variables with their corresponding coefficients values (without further calculating the sum of these terms) The resulting predictions, in all of the above cases, are then centered around a reference level. 'reference' -- TEXT. Reference level to use for centering the predictions. Can be one of 'strata' or 'overall'. Default = 'strata'. Cox model is a relative risk model wherein the predictions are relative to the sample that they are taken from. Therefore, all predictions are centered around the mean of the covariates within each stratum by default. If it is instead desired to use the mean over all samples, reference='overall' can be specified. If there was no stratification involved, the reference parameter does not have any effect. Note 1: For pred_type = 'terms', the predictions are always centered around the overall mean values of the covariates independent of stratification. Note 2: R uses 'sample' instead of 'overall' when referring to the overall mean value of the covariates as being the reference level. ) ----------------------------------------------------------------------- OUTUPT ----------------------------------------------------------------------- The ('output_table' above) has the following columns - id TEXT. The id column name from the source table - predicted_result DOUBLE PRECISION. Result of prediction based of the value of the pred_type parameter """ elif message in ['example', 'examples']: help_string = """ DROP TABLE IF EXISTS sample_data; CREATE TABLE sample_data ( id INTEGER NOT NULL, grp DOUBLE PRECISION, wbc DOUBLE PRECISION, timedeath INTEGER, status BOOLEAN ); -- Insert sample data COPY sample_data FROM STDIN DELIMITER '|'; 0 | 0 | 1.45 | 35 | t 1 | 0 | 1.47 | 34 | t 3 | 0 | 2.2 | 32 | t 4 | 0 | 1.78 | 25 | t 5 | 0 | 2.57 | 23 | t 6 | 0 | 2.32 | 22 | t 7 | 0 | 2.01 | 20 | t 8 | 0 | 2.05 | 19 | t 9 | 0 | 2.16 | 17 | t 10 | 0 | 3.6 | 16 | t 11 | 1 | 2.3 | 15 | t 12 | 0 | 2.88 | 13 | t 13 | 1 | 1.5 | 12 | t 14 | 0 | 2.6 | 11 | t 15 | 0 | 2.7 | 10 | t 16 | 0 | 2.8 | 9 | t 17 | 1 | 2.32 | 8 | t 18 | 0 | 4.43 | 7 | t 19 | 0 | 2.31 | 6 | t 20 | 1 | 3.49 | 5 | t 21 | 1 | 2.42 | 4 | t 22 | 1 | 4.01 | 3 | t 23 | 1 | 4.91 | 2 | t 24 | 1 | 5 | 1 | t \. -- Run coxph function SELECT {schema_madlib}.coxph_train( 'sample_data', 'sample_cox', 'timedeath', 'ARRAY[grp,wbc]', 'status'); -- View the Cox PH model SELECT * FROM sample_cox; -- Predict back on the original dataset SELECT madlib.coxph_predict('sample_cox', 'sample_data', 'id', 'sample_pred', 'risk'); """ else: help_string = "No such option. Use {schema_madlib}.coxph_predict()" return help_string.format(schema_madlib=schema_madlib)