1#include "duckdb/execution/operator/helper/physical_reservoir_sample.hpp"
2#include "duckdb/execution/reservoir_sample.hpp"
3
4namespace duckdb {
5
6//===--------------------------------------------------------------------===//
7// Sink
8//===--------------------------------------------------------------------===//
9class SampleGlobalSinkState : public GlobalSinkState {
10public:
11 explicit SampleGlobalSinkState(Allocator &allocator, SampleOptions &options) {
12 if (options.is_percentage) {
13 auto percentage = options.sample_size.GetValue<double>();
14 if (percentage == 0) {
15 return;
16 }
17 sample = make_uniq<ReservoirSamplePercentage>(args&: allocator, args&: percentage, args&: options.seed);
18 } else {
19 auto size = options.sample_size.GetValue<int64_t>();
20 if (size == 0) {
21 return;
22 }
23 sample = make_uniq<ReservoirSample>(args&: allocator, args&: size, args&: options.seed);
24 }
25 }
26
27 //! The lock for updating the global aggregate state
28 mutex lock;
29 //! The reservoir sample
30 unique_ptr<BlockingSample> sample;
31};
32
33unique_ptr<GlobalSinkState> PhysicalReservoirSample::GetGlobalSinkState(ClientContext &context) const {
34 return make_uniq<SampleGlobalSinkState>(args&: Allocator::Get(context), args&: *options);
35}
36
37SinkResultType PhysicalReservoirSample::Sink(ExecutionContext &context, DataChunk &chunk,
38 OperatorSinkInput &input) const {
39 auto &gstate = input.global_state.Cast<SampleGlobalSinkState>();
40 if (!gstate.sample) {
41 return SinkResultType::FINISHED;
42 }
43 // we implement reservoir sampling without replacement and exponential jumps here
44 // the algorithm is adopted from the paper Weighted random sampling with a reservoir by Pavlos S. Efraimidis et al.
45 // note that the original algorithm is about weighted sampling; this is a simplified approach for uniform sampling
46 lock_guard<mutex> glock(gstate.lock);
47 gstate.sample->AddToReservoir(input&: chunk);
48 return SinkResultType::NEED_MORE_INPUT;
49}
50
51//===--------------------------------------------------------------------===//
52// Source
53//===--------------------------------------------------------------------===//
54SourceResultType PhysicalReservoirSample::GetData(ExecutionContext &context, DataChunk &chunk,
55 OperatorSourceInput &input) const {
56 auto &sink = this->sink_state->Cast<SampleGlobalSinkState>();
57 if (!sink.sample) {
58 return SourceResultType::FINISHED;
59 }
60 auto sample_chunk = sink.sample->GetChunk();
61 if (!sample_chunk) {
62 return SourceResultType::FINISHED;
63 }
64 chunk.Move(chunk&: *sample_chunk);
65
66 return SourceResultType::HAVE_MORE_OUTPUT;
67}
68
69string PhysicalReservoirSample::ParamsToString() const {
70 return options->sample_size.ToString() + (options->is_percentage ? "%" : " rows");
71}
72
73} // namespace duckdb
74