/* ----------------------------------------------------------------------- *//**
 *
 * @file weighted_sample.cpp
 *
 * @brief Generate a single weighted random sample
 *
 * Defines the transition, merge, and final steps of a weighted-sample
 * aggregate, in an int64 variant and a column-vector variant. Also defines
 * an in-memory routine that draws one weighted random index from an array
 * of weights.
 *
 *//* ----------------------------------------------------------------------- */

// NOTE(review): the first two #include directives below have no operand.
// Several template argument lists in this file are also empty: getAs(),
// static_cast(...), and WeightedSampleAccumulator. It looks like
// angle-bracket content was stripped from this copy of the file.
// TODO: restore it from the original source before building.
#include
#include

#include "WeightedSample_proto.hpp"
#include "WeightedSample_impl.hpp"
#include "weighted_sample.hpp"

namespace madlib {

namespace modules {

namespace sample {

// Aliases for the aggregate's state accumulator.
// - Judging by the names, the "Int64" variants hold an int64 sample and the
//   "ColVec" variants hold a column-vector sample.
// - The "Mutable" variants are the ones the code below modifies (transition
//   state, left-hand merge state).
// - The non-mutable variants are only read (right-hand merge state, final
//   state).
// NOTE(review): the template arguments of WeightedSampleAccumulator are
// missing in this copy; see the note on the #include lines.
typedef WeightedSampleAccumulator
    WeightedSampleInt64State;
typedef WeightedSampleAccumulator
    MutableWeightedSampleInt64State;
typedef WeightedSampleAccumulator
    WeightedSampleColVecState;
typedef WeightedSampleAccumulator
    MutableWeightedSampleColVecState;

/**
 * @brief Perform the weighted-sample transition step
 *
 * Arguments:
 * - args[0]: the current state.
 * - args[1]: the candidate value (int64_t).
 * - args[2]: its weight (double).
 *
 * The (value, weight) pair is fed into the state with operator<<. The
 * sampling logic behind that operator is in WeightedSample_impl.hpp, which
 * is not shown here.
 *
 * @return The updated state's storage.
 */
AnyType
weighted_sample_transition_int64::run(AnyType& args) {
    MutableWeightedSampleInt64State state = args[0].getAs();
    int64_t x = args[1].getAs();
    double weight = args[2].getAs();

    state << WeightedSampleInt64State::tuple_type(x, weight);
    return state.storage();
}

/**
 * @brief Transition step for the column-vector variant
 *
 * Same as weighted_sample_transition_int64::run, except that args[1] is a
 * MappedColumnVector.
 *
 * @return The updated state's storage.
 */
AnyType
weighted_sample_transition_vector::run(AnyType& args) {
    MutableWeightedSampleColVecState state = args[0].getAs();
    MappedColumnVector x = args[1].getAs();
    double weight = args[2].getAs();

    state << WeightedSampleColVecState::tuple_type(x, weight);
    return state.storage();
}

/**
 * @brief Perform the merging of two transition states
 *
 * Streams the right-hand state (args[1]) into the left-hand state (args[0])
 * with operator<<. How the two partial samples are combined is defined in
 * WeightedSample_impl.hpp.
 *
 * @return The left-hand state's storage.
 */
AnyType
weighted_sample_merge_int64::run(AnyType &args) {
    MutableWeightedSampleInt64State stateLeft = args[0].getAs();
    WeightedSampleInt64State stateRight = args[1].getAs();

    stateLeft << stateRight;
    return stateLeft.storage();
}

/**
 * @brief Merge step for the column-vector variant
 *
 * Same as weighted_sample_merge_int64::run, but for column-vector states.
 *
 * @return The left-hand state's storage.
 */
AnyType
weighted_sample_merge_vector::run(AnyType &args) {
    MutableWeightedSampleColVecState stateLeft = args[0].getAs();
    WeightedSampleColVecState stateRight = args[1].getAs();

    stateLeft << stateRight;
    return stateLeft.storage();
}

/**
 * @brief Perform the weighted-sample final step
 *
 * Reads the state from args[0] and returns its `sample` member.
 *
 * NOTE(review): the static_cast target type is missing in this copy. It is
 * presumably int64_t; TODO confirm.
 */
AnyType
weighted_sample_final_int64::run(AnyType &args) {
    WeightedSampleInt64State state = args[0].getAs();

    return static_cast(state.sample);
}

/**
 * @brief Final step for the column-vector variant
 *
 * Reads the state from args[0] and returns its `sample` member unchanged.
 */
AnyType
weighted_sample_final_vector::run(AnyType &args) {
    WeightedSampleColVecState state = args[0].getAs();

    return state.sample;
}

/**
 * @brief In-memory weighted sample, returning index
 *
 * args[0] is an array of weights.
 * - If reading it throws ArrayWithNullException, the function returns
 *   Null(). That exception is presumably raised when the array contains
 *   NULL elements.
 * - Otherwise, one index is drawn from a
 *   boost::random::discrete_distribution built over the weights, using a
 *   NativeRandomNumberGenerator.
 *
 * NOTE(review): discrete_distribution draws 0-based indices and expects
 * non-negative weights with a positive sum. Empty, all-zero, and
 * negative-weight arrays are not checked here. Confirm how callers handle
 * these cases, and whether they expect 0- or 1-based positions.
 */
AnyType
index_weighted_sample::run(AnyType &args) {
    // Declared outside the try block so it stays in scope after it.
    MappedColumnVector distribution;
    try {
        MappedColumnVector xx = args[0].getAs();
        // Point `distribution` at the same memory handle and length as `xx`
        // (presumably a re-binding rather than a copy -- confirm).
        distribution.rebind(xx.memoryHandle(), xx.size());
    } catch (const ArrayWithNullException &e) {
        return Null();
    }
    // The weights are the elements in [data(), data() + size()).
    boost::random::discrete_distribution<>
        dist(distribution.data(), distribution.data() + distribution.size());
    NativeRandomNumberGenerator gen;
    return dist(gen);
}

} // namespace sample

} // namespace modules

} // namespace madlib