/* ----------------------------------------------------------------------- */ /** * * @file cox_proportional_hazards.cpp * * @brief Cox proportional hazards * * ----------------------------------------------------------------------- */ #include #include #include #include #include "cox_prop_hazards.hpp" #include "CoxPHState.hpp" namespace madlib { namespace modules { namespace stats { // Use Eigen using namespace dbal::eigen_integration; // Import names from other MADlib modules using dbal::NoSolutionFoundException; using namespace std; // ---------------------------------------------------------------------- AnyType coxph_step_outer_transition::run(AnyType &args) { if (args[0].isNull()) return args[1]; if (args[1].isNull()) return args[0]; CoxPHState > stateLeft = args[0]; CoxPHState > stateRight = args[1]; stateLeft += stateRight; return stateLeft; } // ---------------------------------------------------------------------- /** * @brief Newton method final step for Cox Proportional Hazards * */ AnyType coxph_step_inner_final::run(AnyType &args) { CoxPHState > state = args[0]; // If we haven't seen any data, just return Null. if (state.numRows == 0) return Null(); if (!state.hessian.is_finite() || !state.grad.is_finite()) throw NoSolutionFoundException("Over- or underflow in intermediate " "calculation. Input data is likely of poor numerical condition."); // First merge all tied times of death for the last row state.grad -= state.multiplier*state.H/state.S; triangularView(state.hessian) -= ((state.H*trans(state.H))/(state.S*state.S) - state.V/state.S)*state.multiplier; state.logLikelihood -= state.multiplier*std::log(state.S); // Return all coefficients etc. in a tuple return state; } // ----------------------------------------------------------------------- // Schoenfeld Residual Aggregate // ----------------------------------------------------------------------- AnyType zph_transition::run(AnyType &args){ if (args[1].isNull()) { return args[0]; } MappedColumnVector x; try { // an exception is raised in the backend if args[2] contains nulls MappedColumnVector xx = args[1].getAs(); // x is a const reference, we can only rebind to change its pointer x.rebind(xx.memoryHandle(), xx.size()); } catch (const ArrayWithNullException &e) { return args[0]; } if (!dbal::eigen_integration::isfinite(x)) throw std::domain_error("Design matrix is not finite."); int data_dim = static_cast(x.size()); if (data_dim > std::numeric_limits::max()) throw std::domain_error( "Number of independent variables cannot be larger than 65535."); MutableArrayHandle state(NULL); if (args[0].isNull()){ // state[0:data_dim-1] - x * exp(coeff . x) // state[data_dim] - exp(coeff . x) state = allocateArray(data_dim + 1); } else { state = args[0].getAs >(); } MutableNativeColumnVector coef(allocateArray(data_dim)); if (args[2].isNull()){ for (int i=0; i < data_dim ; i++) coef(i) = 0; } else { coef = args[2].getAs(); } double exp_coef_x = std::exp(trans(coef) * x); MutableNativeColumnVector x_exp_coef_x(allocateArray(data_dim)); x_exp_coef_x = exp_coef_x * x; for (int i =0; i < data_dim; i++) state[i] += x_exp_coef_x[i]; state[data_dim] += exp_coef_x; return state; } // ------------------------------------------------------------------------- AnyType zph_merge::run(AnyType &/* args */){ throw std::logic_error("The aggregate is used as an aggregate over window. " "The merge function should not be used in this scenario."); } // ------------------------------------------------------------------------- AnyType zph_final::run(AnyType &args){ if (args[0].isNull()) return Null(); MutableArrayHandle state(NULL); state = args[0].getAs >(); size_t data_dim = state.size()-1; MutableArrayHandle result = allocateArray(data_dim); for (size_t i = 0; i < data_dim; i++) result[i] = state[i]/state[data_dim]; return result; } // ----------------------------------------------------------------------- // Correlation aggregate between an array and a scalar // ----------------------------------------------------------------------- /** * @brief Transition state for the Cox Proportional Hazards * * TransitionState encapsulates the transition state during the * aggregate functions. To the database, the state is exposed * as a single DOUBLE PRECISION array, to the C++ code it is a proper object * containing scalars, a vector, and a matrix. * * Note: We assume that the DOUBLE PRECISION array is initialized by the * database with length at least 5, and all elements are 0. */ template class ArrayElemCorrState { template friend class ArrayElemCorrState; public: ArrayElemCorrState(const AnyType &inArray) : mStorage(inArray.getAs()) { rebind(static_cast(mStorage[1])); } /** * @brief Convert to backend representation * * We define this function so that we can use TransitionState in the argument * list and as a return type. */ inline operator AnyType() const { return mStorage; } /** * @brief Initialize the transition state. Only called for first row. * * @param inAllocator Allocator for the memory transition state. Must fill * the memory block with zeros. * @param inWidthOfX Number of independent variables. The first row of data * determines the size of the transition state. This size is a quadratic * function of inWidthOfX. */ inline void initialize(const Allocator &inAllocator, uint16_t inWidthOfX) { mStorage = inAllocator.allocateArray(arraySize(inWidthOfX)); rebind(inWidthOfX); widthOfX = inWidthOfX; this->reset(); } /** * @brief We need to support assigning the previous state */ template ArrayElemCorrState &operator=( const ArrayElemCorrState &inOtherState) { for (size_t i = 0; i < mStorage.size(); i++) mStorage[i] = inOtherState.mStorage[i]; return *this; } /** * @brief Merge with another State object by copying the intra-iteration * fields */ template ArrayElemCorrState &operator+=( const ArrayElemCorrState &inOtherState) { if (mStorage.size() != inOtherState.mStorage.size() || widthOfX != inOtherState.widthOfX) throw std::logic_error( "Internal error: Incompatible transition states"); numRows += inOtherState.numRows; sum_y += inOtherState.sum_y; sum_yy += inOtherState.sum_yy; sum_x += inOtherState.sum_x; sum_xx += inOtherState.sum_xx; sum_xy += inOtherState.sum_xy; return *this; } /** * @brief Reset the inter-iteration fields. */ inline void reset() { numRows = 0; sum_y = 0; sum_yy = 0; sum_xy.fill(0); sum_x.fill(0); sum_xx.fill(0); } private: static inline size_t arraySize(const uint16_t inWidthOfX) { return 4 + 3 * inWidthOfX; } /** * @brief Rebind to a new storage array * * @param inWidthOfX The number of independent variables. * * Array layout: * Inter iteration components (updated in the final step) * - 0: numRows (number of rows seen so far) * - 1: widthOfX (number of features) * - 2: coef (multipliers for each of the features) * Intra interation components (updated in the current interation) * - 2 + widthofX: S (see design document for details) * - 3 + widthofX: Hi[j] (see design document for details) * - 3 + 2*widthofX: gradCoef (coefficients of the gradient) * - 3 + 3*widthofX: logLikelihood * - 4 + 3*widthofX: V (Precomputations for the hessian) * - 4 + 3*widthofX + widthofX^2: hessian * */ void rebind(uint16_t inWidthOfX) { // Inter iteration components numRows.rebind(&mStorage[0]); widthOfX.rebind(&mStorage[1]); sum_y.rebind(&mStorage[2]); sum_yy.rebind(&mStorage[3]); // Intra iteration components sum_xy.rebind(&mStorage[4], inWidthOfX); sum_x.rebind(&mStorage[4+inWidthOfX], inWidthOfX); sum_xx.rebind(&mStorage[4+2*inWidthOfX], inWidthOfX); } Handle mStorage; public: typename HandleTraits::ReferenceToUInt64 numRows; typename HandleTraits::ReferenceToUInt16 widthOfX; typename HandleTraits::ReferenceToDouble sum_y; typename HandleTraits::ReferenceToDouble sum_yy; typename HandleTraits::ColumnVectorTransparentHandleMap sum_xy; typename HandleTraits::ColumnVectorTransparentHandleMap sum_x; typename HandleTraits::ColumnVectorTransparentHandleMap sum_xx; }; AnyType array_elem_corr_transition::run(AnyType &args){ ArrayElemCorrState > state = args[0]; if (args[1].isNull() || args[2].isNull()) { return args[0]; } MappedColumnVector x; try { // an exception is raised in the backend if input data contains nulls MappedColumnVector xx = args[1].getAs(); // x is a const reference, we can only rebind to change its pointer x.rebind(xx.memoryHandle(), xx.size()); } catch (const ArrayWithNullException &e) { return args[0]; } double y = args[2].getAs(); if (!dbal::eigen_integration::isfinite(x)) throw std::domain_error("Design matrix is not finite."); if (x.size() > std::numeric_limits::max()) throw std::domain_error( "Number of variables cannot be larger than 65535."); if (state.numRows == 0) { state.initialize(*this, static_cast(x.size())); } state.numRows += 1; state.sum_y += y; state.sum_yy += y*y; state.sum_x += x; state.sum_xx += x.cwiseProduct(x); // perform element-wise multiplication state.sum_xy += x * y; // perform scalar multiplication return state; } AnyType array_elem_corr_merge::run(AnyType &args) { ArrayElemCorrState > stateLeft = args[0]; ArrayElemCorrState > stateRight = args[1]; // We first handle the trivial case where this function is called with one // of the states being the initial state if (stateLeft.numRows == 0) return stateRight; else if (stateRight.numRows == 0) return stateLeft; // Merge states together and return stateLeft += stateRight; return stateLeft; } AnyType array_elem_corr_final::run(AnyType &args) { ArrayElemCorrState > state = args[0]; // If we haven't seen any data, just return Null. if (state.numRows == 0) return Null(); ColumnVector S_xy = static_cast(state.numRows) * state.sum_xy - state.sum_x * state.sum_y; ColumnVector S_xx = static_cast(state.numRows) * state.sum_xx - state.sum_x.cwiseProduct(state.sum_x); double S_yy = static_cast(state.numRows) * state.sum_yy - state.sum_y * state.sum_y; ColumnVector correlation = S_xy.cwiseQuotient(S_xx.cwiseSqrt() * sqrt(S_yy)); return correlation; } AnyType coxph_resid_stat_transition::run(AnyType &args) { double w = args[1].getAs(); ArrayHandle residual = args[2].getAs >(); ArrayHandle hessian = args[3].getAs >(); int m = args[4].getAs(); MutableArrayHandle state(NULL); if(args[0].isNull()){ int n = static_cast(residual.size()); // state[0]: m // state[1]: n // state[2]: w_t * w // state[3:2 + n]: w_t * residual (size = n) // state[n + 3:n + n * n + 2]: hessian (size = n*n) state = allocateArray(n * n + n + 3); state[0] = m; state[1] = n; for(size_t i = 0; i < hessian.size(); i++) state[n + 3 + i] = hessian[i]; }else{ state = args[0].getAs >(); } state[2] += w * w; for(size_t i = 0; i < residual.size(); i++) state[3 + i] += residual[i] * w; return state; } AnyType coxph_resid_stat_merge::run(AnyType &args) { if(args[0].isNull()) return args[1]; if(args[1].isNull()) return args[0]; MutableArrayHandle state1 = args[0].getAs >(); ArrayHandle state2 = args[1].getAs >(); int n = static_cast(state1[1]); for(int i = 2; i <= n + 2; i++) state1[i] += state2[i]; return state1; } AnyType coxph_resid_stat_final::run(AnyType &args) { if(args[0].isNull()) return Null(); using boost::math::complement; MutableArrayHandle state = args[0].getAs >(); int m = static_cast(state[0]); int n = static_cast(state[1]); double w_trans_w = state[2]; Eigen::Map w_trans_residual(state.ptr() + 3, 1, n); Eigen::Map hessian(state.ptr() + 3 + n, n, n); SymmetricPositiveDefiniteEigenDecomposition decomposition( hessian, EigenvaluesOnly, ComputePseudoInverse); Matrix inverse_of_hessian = decomposition.pseudoInverse(); ColumnVector v = m * (w_trans_residual * inverse_of_hessian).transpose(); ColumnVector v_v = v.cwiseProduct(v); ColumnVector covar_diagonal = inverse_of_hessian.diagonal(); ColumnVector z = v_v.cwiseQuotient(covar_diagonal * m * w_trans_w); ColumnVector p(n); for(int i = 0; i < n; i++) p(i) = prob::cdf(complement(prob::chi_squared(static_cast(1)), z(i))); AnyType tuple; tuple << z << p; return tuple; } AnyType coxph_scale_resid::run(AnyType &args) { int m = args[0].getAs(); MappedMatrix hessian = args[1].getAs(); MappedColumnVector residual = args[2].getAs(); SymmetricPositiveDefiniteEigenDecomposition decomposition( hessian, EigenvaluesOnly, ComputePseudoInverse); Matrix inverse_of_hessian = decomposition.pseudoInverse(); ColumnVector scaled_residual = (inverse_of_hessian * Matrix(residual)) * m; return scaled_residual; } AnyType coxph_predict_resp::run(AnyType &args) { try { args[0].getAs(); } catch (const ArrayWithNullException &e) { throw std::runtime_error( "coxph error: the coefficients contain NULL values"); } // returns NULL if args[1] (features) contains NULL values try { args[1].getAs(); } catch (const ArrayWithNullException &e) { return Null(); } MappedColumnVector coefs = args[0].getAs(); MappedColumnVector indep = args[1].getAs(); MappedColumnVector mean_indep = args[2].getAs(); std::string predtype = args[3].getAs(); if (coefs.size() != indep.size()) throw std::runtime_error( "Coefficients and independent variables are of incompatible length"); if (coefs.size() != mean_indep.size()) throw std::runtime_error( "Coefficients and mean vector of independent variables are of incompatible length"); double dot = coefs.dot(indep); double meandot = coefs.dot(mean_indep); double diff = dot - meandot; if (predtype.compare("linear_predictors")==0) return(diff); if (predtype.compare("risk")==0) return(exp(diff)); throw std::runtime_error("Invalid prediction type!"); } AnyType coxph_predict_terms::run(AnyType &args) { try { args[0].getAs(); } catch (const ArrayWithNullException &e) { throw std::runtime_error( "coxph error: the coefficients contain NULL values"); } // returns NULL if args[1] (features) contains NULL values try { args[1].getAs(); } catch (const ArrayWithNullException &e) { return Null(); } MappedColumnVector coefs = args[0].getAs(); MappedColumnVector indep = args[1].getAs(); MappedColumnVector mean_indep = args[2].getAs(); if (coefs.size() != indep.size()) throw std::runtime_error( "Coefficients and independent variables are of incompatible length"); if (coefs.size() != mean_indep.size()) throw std::runtime_error( "Coefficients and mean vector of independent variables are of incompatible length"); return ColumnVector(coefs.cwiseProduct(indep - mean_indep)); } } // namespace stats } // namespace modules } // namespace madlib