| 1 | #include "duckdb/execution/operator/projection/physical_pivot.hpp" |
| 2 | #include "duckdb/planner/expression/bound_aggregate_expression.hpp" |
| 3 | |
| 4 | namespace duckdb { |
| 5 | |
| 6 | PhysicalPivot::PhysicalPivot(vector<LogicalType> types_p, unique_ptr<PhysicalOperator> child, |
| 7 | BoundPivotInfo bound_pivot_p) |
| 8 | : PhysicalOperator(PhysicalOperatorType::PIVOT, std::move(types_p), child->estimated_cardinality), |
| 9 | bound_pivot(std::move(bound_pivot_p)) { |
| 10 | children.push_back(x: std::move(child)); |
| 11 | for (idx_t p = 0; p < bound_pivot.pivot_values.size(); p++) { |
| 12 | auto entry = pivot_map.find(x: bound_pivot.pivot_values[p]); |
| 13 | if (entry != pivot_map.end()) { |
| 14 | continue; |
| 15 | } |
| 16 | pivot_map[bound_pivot.pivot_values[p]] = bound_pivot.group_count + p; |
| 17 | } |
| 18 | // extract the empty aggregate expressions |
| 19 | for (auto &aggr_expr : bound_pivot.aggregates) { |
| 20 | auto &aggr = aggr_expr->Cast<BoundAggregateExpression>(); |
| 21 | // for each aggregate, initialize an empty aggregate state and finalize it immediately |
| 22 | auto state = make_unsafe_uniq_array<data_t>(n: aggr.function.state_size()); |
| 23 | aggr.function.initialize(state.get()); |
| 24 | Vector state_vector(Value::POINTER(value: CastPointerToValue(src: state.get()))); |
| 25 | Vector result_vector(aggr_expr->return_type); |
| 26 | AggregateInputData aggr_input_data(aggr.bind_info.get(), Allocator::DefaultAllocator()); |
| 27 | aggr.function.finalize(state_vector, aggr_input_data, result_vector, 1, 0); |
| 28 | empty_aggregates.push_back(x: result_vector.GetValue(index: 0)); |
| 29 | } |
| 30 | } |
| 31 | |
| 32 | OperatorResultType PhysicalPivot::Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, |
| 33 | GlobalOperatorState &gstate, OperatorState &state) const { |
| 34 | // copy the groups as-is |
| 35 | for (idx_t i = 0; i < bound_pivot.group_count; i++) { |
| 36 | chunk.data[i].Reference(other&: input.data[i]); |
| 37 | } |
| 38 | auto pivot_column_lists = FlatVector::GetData<list_entry_t>(vector&: input.data.back()); |
| 39 | auto &pivot_column_values = ListVector::GetEntry(vector&: input.data.back()); |
| 40 | auto pivot_columns = FlatVector::GetData<string_t>(vector&: pivot_column_values); |
| 41 | |
| 42 | // initialize all aggregate columns with the empty aggregate value |
| 43 | // if there are multiple aggregates the columns are in order of [AGGR1][AGGR2][AGGR1][AGGR2] |
| 44 | // so we need to alternate the empty_aggregate that we use |
| 45 | idx_t aggregate = 0; |
| 46 | for (idx_t c = bound_pivot.group_count; c < chunk.ColumnCount(); c++) { |
| 47 | chunk.data[c].Reference(value: empty_aggregates[aggregate]); |
| 48 | chunk.data[c].Flatten(count: input.size()); |
| 49 | aggregate++; |
| 50 | if (aggregate >= empty_aggregates.size()) { |
| 51 | aggregate = 0; |
| 52 | } |
| 53 | } |
| 54 | |
| 55 | // move the pivots to the given columns |
| 56 | for (idx_t r = 0; r < input.size(); r++) { |
| 57 | auto list = pivot_column_lists[r]; |
| 58 | for (idx_t l = 0; l < list.length; l++) { |
| 59 | // figure out the column value number of this list |
| 60 | auto &column_name = pivot_columns[list.offset + l]; |
| 61 | auto entry = pivot_map.find(x: column_name); |
| 62 | if (entry == pivot_map.end()) { |
| 63 | // column entry not found in map - that means this element is explicitly excluded from the pivot list |
| 64 | continue; |
| 65 | } |
| 66 | auto column_idx = entry->second; |
| 67 | for (idx_t aggr = 0; aggr < empty_aggregates.size(); aggr++) { |
| 68 | auto pivot_value_lists = FlatVector::GetData<list_entry_t>(vector&: input.data[bound_pivot.group_count + aggr]); |
| 69 | auto &pivot_value_child = ListVector::GetEntry(vector&: input.data[bound_pivot.group_count + aggr]); |
| 70 | if (list.offset != pivot_value_lists[r].offset || list.length != pivot_value_lists[r].length) { |
| 71 | throw InternalException("Pivot - unaligned lists between values and columns!?" ); |
| 72 | } |
| 73 | chunk.data[column_idx + aggr].SetValue(index: r, val: pivot_value_child.GetValue(index: list.offset + l)); |
| 74 | } |
| 75 | } |
| 76 | } |
| 77 | chunk.SetCardinality(input.size()); |
| 78 | return OperatorResultType::NEED_MORE_INPUT; |
| 79 | } |
| 80 | |
| 81 | } // namespace duckdb |
| 82 | |