| 1 | #include "duckdb/execution/operator/helper/physical_reservoir_sample.hpp" |
| 2 | #include "duckdb/execution/reservoir_sample.hpp" |
| 3 | |
| 4 | namespace duckdb { |
| 5 | |
| 6 | //===--------------------------------------------------------------------===// |
| 7 | // Sink |
| 8 | //===--------------------------------------------------------------------===// |
| 9 | class SampleGlobalSinkState : public GlobalSinkState { |
| 10 | public: |
| 11 | explicit SampleGlobalSinkState(Allocator &allocator, SampleOptions &options) { |
| 12 | if (options.is_percentage) { |
| 13 | auto percentage = options.sample_size.GetValue<double>(); |
| 14 | if (percentage == 0) { |
| 15 | return; |
| 16 | } |
| 17 | sample = make_uniq<ReservoirSamplePercentage>(args&: allocator, args&: percentage, args&: options.seed); |
| 18 | } else { |
| 19 | auto size = options.sample_size.GetValue<int64_t>(); |
| 20 | if (size == 0) { |
| 21 | return; |
| 22 | } |
| 23 | sample = make_uniq<ReservoirSample>(args&: allocator, args&: size, args&: options.seed); |
| 24 | } |
| 25 | } |
| 26 | |
| 27 | //! The lock for updating the global aggregate state |
| 28 | mutex lock; |
| 29 | //! The reservoir sample |
| 30 | unique_ptr<BlockingSample> sample; |
| 31 | }; |
| 32 | |
| 33 | unique_ptr<GlobalSinkState> PhysicalReservoirSample::GetGlobalSinkState(ClientContext &context) const { |
| 34 | return make_uniq<SampleGlobalSinkState>(args&: Allocator::Get(context), args&: *options); |
| 35 | } |
| 36 | |
| 37 | SinkResultType PhysicalReservoirSample::Sink(ExecutionContext &context, DataChunk &chunk, |
| 38 | OperatorSinkInput &input) const { |
| 39 | auto &gstate = input.global_state.Cast<SampleGlobalSinkState>(); |
| 40 | if (!gstate.sample) { |
| 41 | return SinkResultType::FINISHED; |
| 42 | } |
| 43 | // we implement reservoir sampling without replacement and exponential jumps here |
| 44 | // the algorithm is adopted from the paper Weighted random sampling with a reservoir by Pavlos S. Efraimidis et al. |
| 45 | // note that the original algorithm is about weighted sampling; this is a simplified approach for uniform sampling |
| 46 | lock_guard<mutex> glock(gstate.lock); |
| 47 | gstate.sample->AddToReservoir(input&: chunk); |
| 48 | return SinkResultType::NEED_MORE_INPUT; |
| 49 | } |
| 50 | |
| 51 | //===--------------------------------------------------------------------===// |
| 52 | // Source |
| 53 | //===--------------------------------------------------------------------===// |
| 54 | SourceResultType PhysicalReservoirSample::GetData(ExecutionContext &context, DataChunk &chunk, |
| 55 | OperatorSourceInput &input) const { |
| 56 | auto &sink = this->sink_state->Cast<SampleGlobalSinkState>(); |
| 57 | if (!sink.sample) { |
| 58 | return SourceResultType::FINISHED; |
| 59 | } |
| 60 | auto sample_chunk = sink.sample->GetChunk(); |
| 61 | if (!sample_chunk) { |
| 62 | return SourceResultType::FINISHED; |
| 63 | } |
| 64 | chunk.Move(chunk&: *sample_chunk); |
| 65 | |
| 66 | return SourceResultType::HAVE_MORE_OUTPUT; |
| 67 | } |
| 68 | |
| 69 | string PhysicalReservoirSample::ParamsToString() const { |
| 70 | return options->sample_size.ToString() + (options->is_percentage ? "%" : " rows" ); |
| 71 | } |
| 72 | |
| 73 | } // namespace duckdb |
| 74 | |