| 1 | #include "duckdb/execution/operator/aggregate/aggregate_object.hpp" |
| 2 | #include "duckdb/planner/expression/bound_aggregate_expression.hpp" |
| 3 | #include "duckdb/planner/expression/bound_window_expression.hpp" |
| 4 | |
| 5 | namespace duckdb { |
| 6 | |
| 7 | AggregateObject::AggregateObject(AggregateFunction function, FunctionData *bind_data, idx_t child_count, |
| 8 | idx_t payload_size, AggregateType aggr_type, PhysicalType return_type, |
| 9 | Expression *filter) |
| 10 | : function(std::move(function)), |
| 11 | bind_data_wrapper(bind_data ? make_shared<FunctionDataWrapper>(args: bind_data->Copy()) : nullptr), |
| 12 | child_count(child_count), payload_size(payload_size), aggr_type(aggr_type), return_type(return_type), |
| 13 | filter(filter) { |
| 14 | } |
| 15 | |
| 16 | AggregateObject::AggregateObject(BoundAggregateExpression *aggr) |
| 17 | : AggregateObject(aggr->function, aggr->bind_info.get(), aggr->children.size(), |
| 18 | AlignValue(n: aggr->function.state_size()), aggr->aggr_type, aggr->return_type.InternalType(), |
| 19 | aggr->filter.get()) { |
| 20 | } |
| 21 | |
| 22 | AggregateObject::AggregateObject(BoundWindowExpression &window) |
| 23 | : AggregateObject(*window.aggregate, window.bind_info.get(), window.children.size(), |
| 24 | AlignValue(n: window.aggregate->state_size()), AggregateType::NON_DISTINCT, |
| 25 | window.return_type.InternalType(), window.filter_expr.get()) { |
| 26 | } |
| 27 | |
| 28 | vector<AggregateObject> AggregateObject::CreateAggregateObjects(const vector<BoundAggregateExpression *> &bindings) { |
| 29 | vector<AggregateObject> aggregates; |
| 30 | aggregates.reserve(n: aggregates.size()); |
| 31 | for (auto &binding : bindings) { |
| 32 | aggregates.emplace_back(args: binding); |
| 33 | } |
| 34 | return aggregates; |
| 35 | } |
| 36 | |
| 37 | AggregateFilterData::AggregateFilterData(ClientContext &context, Expression &filter_expr, |
| 38 | const vector<LogicalType> &payload_types) |
| 39 | : filter_executor(context, &filter_expr), true_sel(STANDARD_VECTOR_SIZE) { |
| 40 | if (payload_types.empty()) { |
| 41 | return; |
| 42 | } |
| 43 | filtered_payload.Initialize(allocator&: Allocator::Get(context), types: payload_types); |
| 44 | } |
| 45 | |
| 46 | idx_t AggregateFilterData::ApplyFilter(DataChunk &payload) { |
| 47 | filtered_payload.Reset(); |
| 48 | |
| 49 | auto count = filter_executor.SelectExpression(input&: payload, sel&: true_sel); |
| 50 | filtered_payload.Slice(other&: payload, sel: true_sel, count); |
| 51 | return count; |
| 52 | } |
| 53 | |
| 54 | AggregateFilterDataSet::AggregateFilterDataSet() { |
| 55 | } |
| 56 | |
| 57 | void AggregateFilterDataSet::Initialize(ClientContext &context, const vector<AggregateObject> &aggregates, |
| 58 | const vector<LogicalType> &payload_types) { |
| 59 | bool has_filters = false; |
| 60 | for (auto &aggregate : aggregates) { |
| 61 | if (aggregate.filter) { |
| 62 | has_filters = true; |
| 63 | break; |
| 64 | } |
| 65 | } |
| 66 | if (!has_filters) { |
| 67 | // no filters: nothing to do |
| 68 | return; |
| 69 | } |
| 70 | filter_data.resize(new_size: aggregates.size()); |
| 71 | for (idx_t aggr_idx = 0; aggr_idx < aggregates.size(); aggr_idx++) { |
| 72 | auto &aggr = aggregates[aggr_idx]; |
| 73 | if (aggr.filter) { |
| 74 | filter_data[aggr_idx] = make_uniq<AggregateFilterData>(args&: context, args&: *aggr.filter, args: payload_types); |
| 75 | } |
| 76 | } |
| 77 | } |
| 78 | |
| 79 | AggregateFilterData &AggregateFilterDataSet::GetFilterData(idx_t aggr_idx) { |
| 80 | D_ASSERT(aggr_idx < filter_data.size()); |
| 81 | D_ASSERT(filter_data[aggr_idx]); |
| 82 | return *filter_data[aggr_idx]; |
| 83 | } |
| 84 | } // namespace duckdb |
| 85 | |