1 | #include "duckdb/execution/operator/helper/physical_reservoir_sample.hpp" |
2 | #include "duckdb/execution/reservoir_sample.hpp" |
3 | |
4 | namespace duckdb { |
5 | |
6 | //===--------------------------------------------------------------------===// |
7 | // Sink |
8 | //===--------------------------------------------------------------------===// |
9 | class SampleGlobalSinkState : public GlobalSinkState { |
10 | public: |
11 | explicit SampleGlobalSinkState(Allocator &allocator, SampleOptions &options) { |
12 | if (options.is_percentage) { |
13 | auto percentage = options.sample_size.GetValue<double>(); |
14 | if (percentage == 0) { |
15 | return; |
16 | } |
17 | sample = make_uniq<ReservoirSamplePercentage>(args&: allocator, args&: percentage, args&: options.seed); |
18 | } else { |
19 | auto size = options.sample_size.GetValue<int64_t>(); |
20 | if (size == 0) { |
21 | return; |
22 | } |
23 | sample = make_uniq<ReservoirSample>(args&: allocator, args&: size, args&: options.seed); |
24 | } |
25 | } |
26 | |
27 | //! The lock for updating the global aggregate state |
28 | mutex lock; |
29 | //! The reservoir sample |
30 | unique_ptr<BlockingSample> sample; |
31 | }; |
32 | |
33 | unique_ptr<GlobalSinkState> PhysicalReservoirSample::GetGlobalSinkState(ClientContext &context) const { |
34 | return make_uniq<SampleGlobalSinkState>(args&: Allocator::Get(context), args&: *options); |
35 | } |
36 | |
37 | SinkResultType PhysicalReservoirSample::Sink(ExecutionContext &context, DataChunk &chunk, |
38 | OperatorSinkInput &input) const { |
39 | auto &gstate = input.global_state.Cast<SampleGlobalSinkState>(); |
40 | if (!gstate.sample) { |
41 | return SinkResultType::FINISHED; |
42 | } |
43 | // we implement reservoir sampling without replacement and exponential jumps here |
44 | // the algorithm is adopted from the paper Weighted random sampling with a reservoir by Pavlos S. Efraimidis et al. |
45 | // note that the original algorithm is about weighted sampling; this is a simplified approach for uniform sampling |
46 | lock_guard<mutex> glock(gstate.lock); |
47 | gstate.sample->AddToReservoir(input&: chunk); |
48 | return SinkResultType::NEED_MORE_INPUT; |
49 | } |
50 | |
51 | //===--------------------------------------------------------------------===// |
52 | // Source |
53 | //===--------------------------------------------------------------------===// |
54 | SourceResultType PhysicalReservoirSample::GetData(ExecutionContext &context, DataChunk &chunk, |
55 | OperatorSourceInput &input) const { |
56 | auto &sink = this->sink_state->Cast<SampleGlobalSinkState>(); |
57 | if (!sink.sample) { |
58 | return SourceResultType::FINISHED; |
59 | } |
60 | auto sample_chunk = sink.sample->GetChunk(); |
61 | if (!sample_chunk) { |
62 | return SourceResultType::FINISHED; |
63 | } |
64 | chunk.Move(chunk&: *sample_chunk); |
65 | |
66 | return SourceResultType::HAVE_MORE_OUTPUT; |
67 | } |
68 | |
69 | string PhysicalReservoirSample::ParamsToString() const { |
70 | return options->sample_size.ToString() + (options->is_percentage ? "%" : " rows" ); |
71 | } |
72 | |
73 | } // namespace duckdb |
74 | |