| 1 | #include "duckdb/execution/operator/helper/physical_streaming_sample.hpp" |
| 2 | #include "duckdb/common/random_engine.hpp" |
| 3 | #include "duckdb/common/to_string.hpp" |
| 4 | #include "duckdb/common/enum_util.hpp" |
| 5 | |
| 6 | namespace duckdb { |
| 7 | |
| 8 | PhysicalStreamingSample::PhysicalStreamingSample(vector<LogicalType> types, SampleMethod method, double percentage, |
| 9 | int64_t seed, idx_t estimated_cardinality) |
| 10 | : PhysicalOperator(PhysicalOperatorType::STREAMING_SAMPLE, std::move(types), estimated_cardinality), method(method), |
| 11 | percentage(percentage / 100), seed(seed) { |
| 12 | } |
| 13 | |
| 14 | //===--------------------------------------------------------------------===// |
| 15 | // Operator |
| 16 | //===--------------------------------------------------------------------===// |
| 17 | class StreamingSampleOperatorState : public OperatorState { |
| 18 | public: |
| 19 | explicit StreamingSampleOperatorState(int64_t seed) : random(seed) { |
| 20 | } |
| 21 | |
| 22 | RandomEngine random; |
| 23 | }; |
| 24 | |
| 25 | void PhysicalStreamingSample::SystemSample(DataChunk &input, DataChunk &result, OperatorState &state_p) const { |
| 26 | // system sampling: we throw one dice per chunk |
| 27 | auto &state = state_p.Cast<StreamingSampleOperatorState>(); |
| 28 | double rand = state.random.NextRandom(); |
| 29 | if (rand <= percentage) { |
| 30 | // rand is smaller than sample_size: output chunk |
| 31 | result.Reference(chunk&: input); |
| 32 | } |
| 33 | } |
| 34 | |
| 35 | void PhysicalStreamingSample::BernoulliSample(DataChunk &input, DataChunk &result, OperatorState &state_p) const { |
| 36 | // bernoulli sampling: we throw one dice per tuple |
| 37 | // then slice the result chunk |
| 38 | auto &state = state_p.Cast<StreamingSampleOperatorState>(); |
| 39 | idx_t result_count = 0; |
| 40 | SelectionVector sel(STANDARD_VECTOR_SIZE); |
| 41 | for (idx_t i = 0; i < input.size(); i++) { |
| 42 | double rand = state.random.NextRandom(); |
| 43 | if (rand <= percentage) { |
| 44 | sel.set_index(idx: result_count++, loc: i); |
| 45 | } |
| 46 | } |
| 47 | if (result_count > 0) { |
| 48 | result.Slice(other&: input, sel, count: result_count); |
| 49 | } |
| 50 | } |
| 51 | |
| 52 | unique_ptr<OperatorState> PhysicalStreamingSample::GetOperatorState(ExecutionContext &context) const { |
| 53 | return make_uniq<StreamingSampleOperatorState>(args: seed); |
| 54 | } |
| 55 | |
| 56 | OperatorResultType PhysicalStreamingSample::Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, |
| 57 | GlobalOperatorState &gstate, OperatorState &state) const { |
| 58 | switch (method) { |
| 59 | case SampleMethod::BERNOULLI_SAMPLE: |
| 60 | BernoulliSample(input, result&: chunk, state_p&: state); |
| 61 | break; |
| 62 | case SampleMethod::SYSTEM_SAMPLE: |
| 63 | SystemSample(input, result&: chunk, state_p&: state); |
| 64 | break; |
| 65 | default: |
| 66 | throw InternalException("Unsupported sample method for streaming sample" ); |
| 67 | } |
| 68 | return OperatorResultType::NEED_MORE_INPUT; |
| 69 | } |
| 70 | |
| 71 | string PhysicalStreamingSample::ParamsToString() const { |
| 72 | return EnumUtil::ToString(value: method) + ": " + to_string(val: 100 * percentage) + "%" ; |
| 73 | } |
| 74 | |
| 75 | } // namespace duckdb |
| 76 | |