/* ----------------------------------------------------------------------- *//** * * @file WeightedSample_impl.hpp * *//* ----------------------------------------------------------------------- */ #ifndef MADLIB_MODULES_SAMPLE_WEIGHTED_SAMPLE_IMPL_HPP #define MADLIB_MODULES_SAMPLE_WEIGHTED_SAMPLE_IMPL_HPP #include // Import TR1 names (currently used from boost). This can go away once we make // the switch to C++11. namespace std { using tr1::bernoulli_distribution; } namespace madlib { namespace modules { namespace sample { template inline WeightedSampleAccumulator::WeightedSampleAccumulator( Init_type& inInitialization) : Base(inInitialization) { this->initialize(); } template inline void bindWeightedSampleAcc( WeightedSampleAccumulator& ioAccumulator, typename WeightedSampleAccumulator::ByteStream_type& inStream) { inStream >> ioAccumulator.weight_sum >> ioAccumulator.sample; } template inline void bindWeightedSampleAcc( WeightedSampleAccumulator& ioAccumulator, typename WeightedSampleAccumulator ::ByteStream_type& inStream) { inStream >> ioAccumulator.weight_sum >> ioAccumulator.header.width; uint32_t actualWidth = ioAccumulator.header.width.isNull() ? 0 : static_cast(ioAccumulator.header.width); inStream >> ioAccumulator.sample.rebind(actualWidth); } /** * @brief Bind all elements of the state to the data in the stream * * The bind() is special in that even after running operator>>() on an element, * there is no guarantee yet that the element can indeed be accessed. It is * cruicial to first check this. * * Provided that this methods correctly lists all member variables, all other * methods can, however, rely on that fact that all variables are correctly * initialized and accessible. */ template inline void WeightedSampleAccumulator::bind(ByteStream_type& inStream) { bindWeightedSampleAcc(*this, inStream); } template inline void prepareSample(WeightedSampleAccumulator&, const T&) { } template inline void prepareSample( WeightedSampleAccumulator& ioAccumulator, const MappedColumnVector& inX) { uint32_t width = static_cast(inX.size()); if (width > ioAccumulator.header.width) { ioAccumulator.header.width = width; ioAccumulator.resize(); } } /** * @brief Update the accumulation state */ template inline WeightedSampleAccumulator& WeightedSampleAccumulator::operator<<( const tuple_type& inTuple) { const T& x = std::get<0>(inTuple); const double& weight = std::get<1>(inTuple); // Instead of throwing an error, we will just ignore rows with a negative // weight if (weight > 0.) { weight_sum += weight; std::bernoulli_distribution success(weight / weight_sum); // Note that a NativeRandomNumberGenerator object is stateless, so it // is not a problem to instantiate an object for each RN generation... NativeRandomNumberGenerator generator; if (success(generator)) { prepareSample(*this, x); sample = x; } } return *this; } /** * @brief Merge with another accumulation state */ template template inline WeightedSampleAccumulator& WeightedSampleAccumulator::operator<<( const WeightedSampleAccumulator& inOther) { // Initialize if necessary if (weight_sum == 0) { *this = inOther; return *this; } *this << tuple_type(inOther.sample, inOther.weight_sum); return *this; } template template inline WeightedSampleAccumulator& WeightedSampleAccumulator::operator=( const WeightedSampleAccumulator& inOther) { this->copy(inOther); return *this; } } // namespace sample } // namespace modules } // namespace madlib #endif // defined(MADLIB_MODULES_REGRESS_LINEAR_REGRESSION_IMPL_HPP)