1 | #include "duckdb/execution/operator/helper/physical_streaming_sample.hpp" |
2 | #include "duckdb/common/random_engine.hpp" |
3 | #include "duckdb/common/to_string.hpp" |
4 | #include "duckdb/common/enum_util.hpp" |
5 | |
6 | namespace duckdb { |
7 | |
8 | PhysicalStreamingSample::PhysicalStreamingSample(vector<LogicalType> types, SampleMethod method, double percentage, |
9 | int64_t seed, idx_t estimated_cardinality) |
10 | : PhysicalOperator(PhysicalOperatorType::STREAMING_SAMPLE, std::move(types), estimated_cardinality), method(method), |
11 | percentage(percentage / 100), seed(seed) { |
12 | } |
13 | |
14 | //===--------------------------------------------------------------------===// |
15 | // Operator |
16 | //===--------------------------------------------------------------------===// |
17 | class StreamingSampleOperatorState : public OperatorState { |
18 | public: |
19 | explicit StreamingSampleOperatorState(int64_t seed) : random(seed) { |
20 | } |
21 | |
22 | RandomEngine random; |
23 | }; |
24 | |
25 | void PhysicalStreamingSample::SystemSample(DataChunk &input, DataChunk &result, OperatorState &state_p) const { |
26 | // system sampling: we throw one dice per chunk |
27 | auto &state = state_p.Cast<StreamingSampleOperatorState>(); |
28 | double rand = state.random.NextRandom(); |
29 | if (rand <= percentage) { |
30 | // rand is smaller than sample_size: output chunk |
31 | result.Reference(chunk&: input); |
32 | } |
33 | } |
34 | |
35 | void PhysicalStreamingSample::BernoulliSample(DataChunk &input, DataChunk &result, OperatorState &state_p) const { |
36 | // bernoulli sampling: we throw one dice per tuple |
37 | // then slice the result chunk |
38 | auto &state = state_p.Cast<StreamingSampleOperatorState>(); |
39 | idx_t result_count = 0; |
40 | SelectionVector sel(STANDARD_VECTOR_SIZE); |
41 | for (idx_t i = 0; i < input.size(); i++) { |
42 | double rand = state.random.NextRandom(); |
43 | if (rand <= percentage) { |
44 | sel.set_index(idx: result_count++, loc: i); |
45 | } |
46 | } |
47 | if (result_count > 0) { |
48 | result.Slice(other&: input, sel, count: result_count); |
49 | } |
50 | } |
51 | |
52 | unique_ptr<OperatorState> PhysicalStreamingSample::GetOperatorState(ExecutionContext &context) const { |
53 | return make_uniq<StreamingSampleOperatorState>(args: seed); |
54 | } |
55 | |
56 | OperatorResultType PhysicalStreamingSample::Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, |
57 | GlobalOperatorState &gstate, OperatorState &state) const { |
58 | switch (method) { |
59 | case SampleMethod::BERNOULLI_SAMPLE: |
60 | BernoulliSample(input, result&: chunk, state_p&: state); |
61 | break; |
62 | case SampleMethod::SYSTEM_SAMPLE: |
63 | SystemSample(input, result&: chunk, state_p&: state); |
64 | break; |
65 | default: |
66 | throw InternalException("Unsupported sample method for streaming sample" ); |
67 | } |
68 | return OperatorResultType::NEED_MORE_INPUT; |
69 | } |
70 | |
71 | string PhysicalStreamingSample::ParamsToString() const { |
72 | return EnumUtil::ToString(value: method) + ": " + to_string(val: 100 * percentage) + "%" ; |
73 | } |
74 | |
75 | } // namespace duckdb |
76 | |