#include "dbconnector/dbconnector.hpp"
#include "elastic_net_binomial_fista.hpp"
#include "state/fista.hpp"
#include "elastic_net_optimizer_fista.hpp"
#include "share/shared_utils.hpp"

namespace madlib {
namespace modules {
namespace elastic_net {

/**
 * @brief Model-specific hooks for the binomial model when optimized with FISTA
 *
 * All members are static; the class is used as the policy argument of the
 * generic Fista<> driver invoked by the SQL-facing functions at the bottom of
 * this file. The dependent variable is encoded as +1 / -1 (see get_y), and
 * the per-row loss is log(1 + exp(-y * r)) with r = intercept + <coef, x>.
 */
class BinomialFista
{
  public:
    static void initialize(FistaState<MutableArrayHandle<double> >& state);
    static void get_y(double& y, AnyType& args);
    static void normal_transition(FistaState<MutableArrayHandle<double> >& state,
                                  MappedColumnVector& x, double y);
    static void active_transition(FistaState<MutableArrayHandle<double> >& state,
                                  MappedColumnVector& x, double y);
    static void update_b_intercept(FistaState<MutableArrayHandle<double> >& state);
    static void update_loglikelihood(FistaState<MutableArrayHandle<double> >& state,
                                     MappedColumnVector& x, double y);
    static void update_y_intercept(FistaState<MutableArrayHandle<double> >& state,
                                   double old_tk);
    static void update_y_intercept_final(FistaState<MutableArrayHandle<double> >& state);
    static void merge_intercept(FistaState<MutableArrayHandle<double> >& state1,
                                FistaState<MutableArrayHandle<double> >& state2);

  private:
    static void backtracking_transition(FistaState<MutableArrayHandle<double> >& state,
                                        MappedColumnVector& x, double y);
    static double log_one_plus_exp(double z);
    static double loss_derivative(double r, double y);
};

// ------------------------------------------------------------------------

/**
 * @brief Evaluate log(1 + exp(z)) without overflowing for large positive z
 *
 * std::exp(z) becomes +inf once z exceeds roughly 709, which would turn the
 * accumulated loss into +inf. For z > 0 the identity
 * log(1 + e^z) = z + log(1 + e^-z) keeps the exponent non-positive.
 *
 * @param z Any double; NaN propagates to the result.
 * @return log(1 + exp(z))
 */
inline double BinomialFista::log_one_plus_exp(double z)
{
    if (z > 0)
        return z + std::log(1 + std::exp(-z));

    return std::log(1 + std::exp(z));
}

// ------------------------------------------------------------------------

/**
 * @brief Derivative of the per-row loss log(1 + exp(-y * r)) with respect to r
 *
 * @param r Linear predictor for the row
 * @param y Label; only its sign is inspected (y > 0 versus otherwise)
 * @return -1 / (1 + exp(r)) for a positive label, 1 / (1 + exp(-r)) otherwise
 */
inline double BinomialFista::loss_derivative(double r, double y)
{
    // An overflowing exp() yields +inf here, and 1 / inf is a well-defined 0.
    if (y > 0)
        return -1. / (1. + std::exp(r));

    return 1. / (1. + std::exp(-r));
}

// ------------------------------------------------------------------------

/**
 * @brief Scale the accumulated intercept gradient by the number of rows
 *
 * Replaces state.gradient_intercept with its value divided by
 * state.totalRows.
 */
inline void BinomialFista::update_y_intercept_final(
    FistaState<MutableArrayHandle<double> >& state)
{
    state.gradient_intercept = state.gradient_intercept
        / static_cast<double>(state.totalRows);
}

// ------------------------------------------------------------------------

/**
 * @brief Accumulate the loss of one data point into state.loglikelihood
 *
 * Adds log(1 + exp(-r)) for a positive label and log(1 + exp(r)) otherwise,
 * where r = state.intercept + <state.coef, x>. This quantity is the negative
 * log-likelihood contribution of the row.
 */
inline void BinomialFista::update_loglikelihood(
    FistaState<MutableArrayHandle<double> >& state,
    MappedColumnVector& x, double y)
{
    double r = state.intercept + sparse_dot(state.coef, x);

    state.loglikelihood += log_one_plus_exp(y > 0 ? -r : r);
}

// ------------------------------------------------------------------------

/**
 * @brief Fold the intercept gradient of state2 into state1
 */
inline void BinomialFista::merge_intercept(
    FistaState<MutableArrayHandle<double> >& state1,
    FistaState<MutableArrayHandle<double> >& state2)
{
    state1.gradient_intercept += state2.gradient_intercept;
}

// ------------------------------------------------------------------------

/**
 * @brief Reset both coefficient vectors and both intercepts to zero
 */
inline void BinomialFista::initialize(
    FistaState<MutableArrayHandle<double> >& state)
{
    state.coef.setZero();
    state.coef_y.setZero();
    state.intercept = 0;
    state.intercept_y = 0;
}

// ------------------------------------------------------------------------

/**
 * @brief Extract the dependent variable from the aggregate arguments
 *
 * Reads args[2] as a boolean and maps true to +1 and false to -1.
 */
inline void BinomialFista::get_y(double& y, AnyType& args)
{
    y = args[2].getAs<bool>() ? 1. : -1.;
}

// ------------------------------------------------------------------------

/**
 * @brief Per-row transition over all coefficients
 *
 * When state.backtracking == 0 the gradient at (coef_y, intercept_y) is
 * accumulated for every dimension and for the intercept. Otherwise the row
 * is handed to backtracking_transition.
 */
inline void BinomialFista::normal_transition(
    FistaState<MutableArrayHandle<double> >& state,
    MappedColumnVector& x, double y)
{
    if (state.backtracking == 0)
    {
        double r = state.intercept_y + sparse_dot(state.coef_y, x);
        double u = loss_derivative(r, y);

        // update gradient
        for (uint32_t i = 0; i < state.dimension; i++)
            state.gradient(i) += x(i) * u;

        state.gradient_intercept += u;
    }
    else
        backtracking_transition(state, x, y);
}

// ------------------------------------------------------------------------

/**
 * @brief Per-row transition restricted to the active set
 *
 * Same as normal_transition, except that only dimensions whose coef_y entry
 * is non-zero receive a gradient update. The intercept gradient is always
 * updated.
 */
inline void BinomialFista::active_transition(
    FistaState<MutableArrayHandle<double> >& state,
    MappedColumnVector& x, double y)
{
    if (state.backtracking == 0)
    {
        double r = state.intercept_y + sparse_dot(state.coef_y, x);
        double u = loss_derivative(r, y);

        // Compute gradient for active set
        for (uint32_t i = 0; i < state.dimension; i++)
            if (state.coef_y(i) != 0)
                state.gradient(i) += x(i) * u;

        // always update intercept
        state.gradient_intercept += u;
    }
    else
        backtracking_transition(state, x, y);
}

// ------------------------------------------------------------------------

/**
 * @brief Accumulate the loss values needed by the backtracking line search
 *
 * state.fn receives the row loss evaluated at (b_coef, b_intercept).
 * state.Qfn receives the row loss evaluated at (coef_y, intercept_y), but
 * only while state.backtracking == 1.
 */
inline void BinomialFista::backtracking_transition(
    FistaState<MutableArrayHandle<double> >& state,
    MappedColumnVector& x, double y)
{
    // during backtracking, always use b_coef and b_intercept
    double r = state.b_intercept + sparse_dot(state.b_coef, x);
    state.fn += log_one_plus_exp(y > 0 ? -r : r);

    // Qfn only need to be calculated once in each backtracking
    if (state.backtracking == 1)
    {
        r = state.intercept_y + sparse_dot(state.coef_y, x);
        state.Qfn += log_one_plus_exp(y > 0 ? -r : r);
    }
}

// ------------------------------------------------------------------------

/**
 * @brief Take a gradient step for the intercept
 *
 * b_intercept = intercept_y - stepsize * gradient_intercept
 */
inline void BinomialFista::update_b_intercept(
    FistaState<MutableArrayHandle<double> >& state)
{
    state.b_intercept = state.intercept_y
        - state.stepsize * state.gradient_intercept;
}

// ------------------------------------------------------------------------

/**
 * @brief Extrapolate intercept_y from the new and the previous intercept
 *
 * intercept_y = b_intercept + (old_tk - 1) * (b_intercept - intercept) / tk
 *
 * @param state  FISTA state; reads b_intercept, intercept and tk
 * @param old_tk Value of tk from before the current update
 */
inline void BinomialFista::update_y_intercept(
    FistaState<MutableArrayHandle<double> >& state, double old_tk)
{
    state.intercept_y = state.b_intercept + (old_tk - 1)
        * (state.b_intercept - state.intercept) / state.tk;
}

// ------------------------------------------------------------------------
// ------------------------------------------------------------------------
// ------------------------------------------------------------------------

/*
  The following are the functions that are actually called by SQL
*/

/**
 * @brief Perform FISTA transition step
 *
 * It is called for each tuple of (x, y)
 */
AnyType
binomial_fista_transition::run (AnyType& args)
{
    return Fista<BinomialFista>::fista_transition(args, *this);
}

// ------------------------------------------------------------------------

/**
 * @brief Perform Merge transition steps
 */
AnyType
binomial_fista_merge::run (AnyType& args)
{
    return Fista<BinomialFista>::fista_merge(args);
}

// ------------------------------------------------------------------------

/**
 * @brief Perform the final computation
 */
AnyType
binomial_fista_final::run (AnyType& args)
{
    return Fista<BinomialFista>::fista_final(args);
}

// ------------------------------------------------------------------------

/**
 * @brief Return the difference between two states
 *
 * Delegates entirely to Fista<BinomialFista>::fista_state_diff.
 */
AnyType
__binomial_fista_state_diff::run (AnyType& args)
{
    return Fista<BinomialFista>::fista_state_diff(args);
}

// ------------------------------------------------------------------------

/**
 * @brief Return the coefficients and diagnostic statistics of the state
 */
AnyType
__binomial_fista_result::run (AnyType& args)
{
    return Fista<BinomialFista>::fista_result(args);
}

} // namespace elastic_net
} // namespace modules
} // namespace madlib