| 1 | #include "duckdb/execution/operator/aggregate/physical_streaming_window.hpp" |
| 2 | |
| 3 | #include "duckdb/execution/expression_executor.hpp" |
| 4 | #include "duckdb/function/aggregate_function.hpp" |
| 5 | #include "duckdb/parallel/thread_context.hpp" |
| 6 | #include "duckdb/planner/expression/bound_reference_expression.hpp" |
| 7 | #include "duckdb/planner/expression/bound_window_expression.hpp" |
| 8 | |
| 9 | namespace duckdb { |
| 10 | |
| 11 | PhysicalStreamingWindow::PhysicalStreamingWindow(vector<LogicalType> types, vector<unique_ptr<Expression>> select_list, |
| 12 | idx_t estimated_cardinality, PhysicalOperatorType type) |
| 13 | : PhysicalOperator(type, std::move(types), estimated_cardinality), select_list(std::move(select_list)) { |
| 14 | } |
| 15 | |
| 16 | class StreamingWindowGlobalState : public GlobalOperatorState { |
| 17 | public: |
| 18 | StreamingWindowGlobalState() : row_number(1) { |
| 19 | } |
| 20 | |
| 21 | //! The next row number. |
| 22 | std::atomic<int64_t> row_number; |
| 23 | }; |
| 24 | |
| 25 | class StreamingWindowState : public OperatorState { |
| 26 | public: |
| 27 | using StateBuffer = vector<data_t>; |
| 28 | |
| 29 | StreamingWindowState() : initialized(false), statev(LogicalType::POINTER, data_ptr_cast(src: &state_ptr)) { |
| 30 | } |
| 31 | |
| 32 | ~StreamingWindowState() override { |
| 33 | for (size_t i = 0; i < aggregate_dtors.size(); ++i) { |
| 34 | auto dtor = aggregate_dtors[i]; |
| 35 | if (dtor) { |
| 36 | AggregateInputData aggr_input_data(aggregate_bind_data[i], Allocator::DefaultAllocator()); |
| 37 | state_ptr = aggregate_states[i].data(); |
| 38 | dtor(statev, aggr_input_data, 1); |
| 39 | } |
| 40 | } |
| 41 | } |
| 42 | |
| 43 | void Initialize(ClientContext &context, DataChunk &input, const vector<unique_ptr<Expression>> &expressions) { |
| 44 | const_vectors.resize(new_size: expressions.size()); |
| 45 | aggregate_states.resize(new_size: expressions.size()); |
| 46 | aggregate_bind_data.resize(new_size: expressions.size(), x: nullptr); |
| 47 | aggregate_dtors.resize(new_size: expressions.size(), x: nullptr); |
| 48 | |
| 49 | for (idx_t expr_idx = 0; expr_idx < expressions.size(); expr_idx++) { |
| 50 | auto &expr = *expressions[expr_idx]; |
| 51 | auto &wexpr = expr.Cast<BoundWindowExpression>(); |
| 52 | switch (expr.GetExpressionType()) { |
| 53 | case ExpressionType::WINDOW_AGGREGATE: { |
| 54 | auto &aggregate = *wexpr.aggregate; |
| 55 | auto &state = aggregate_states[expr_idx]; |
| 56 | aggregate_bind_data[expr_idx] = wexpr.bind_info.get(); |
| 57 | aggregate_dtors[expr_idx] = aggregate.destructor; |
| 58 | state.resize(new_size: aggregate.state_size()); |
| 59 | aggregate.initialize(state.data()); |
| 60 | break; |
| 61 | } |
| 62 | case ExpressionType::WINDOW_FIRST_VALUE: { |
| 63 | // Just execute the expression once |
| 64 | ExpressionExecutor executor(context); |
| 65 | executor.AddExpression(expr: *wexpr.children[0]); |
| 66 | DataChunk result; |
| 67 | result.Initialize(allocator&: Allocator::Get(context), types: {wexpr.children[0]->return_type}); |
| 68 | executor.Execute(input, result); |
| 69 | |
| 70 | const_vectors[expr_idx] = make_uniq<Vector>(args: result.GetValue(col_idx: 0, index: 0)); |
| 71 | break; |
| 72 | } |
| 73 | case ExpressionType::WINDOW_PERCENT_RANK: { |
| 74 | const_vectors[expr_idx] = make_uniq<Vector>(args: Value((double)0)); |
| 75 | break; |
| 76 | } |
| 77 | case ExpressionType::WINDOW_RANK: |
| 78 | case ExpressionType::WINDOW_RANK_DENSE: { |
| 79 | const_vectors[expr_idx] = make_uniq<Vector>(args: Value((int64_t)1)); |
| 80 | break; |
| 81 | } |
| 82 | default: |
| 83 | break; |
| 84 | } |
| 85 | } |
| 86 | initialized = true; |
| 87 | } |
| 88 | |
| 89 | public: |
| 90 | bool initialized; |
| 91 | vector<unique_ptr<Vector>> const_vectors; |
| 92 | |
| 93 | // Aggregation |
| 94 | vector<StateBuffer> aggregate_states; |
| 95 | vector<FunctionData *> aggregate_bind_data; |
| 96 | vector<aggregate_destructor_t> aggregate_dtors; |
| 97 | data_ptr_t state_ptr; |
| 98 | Vector statev; |
| 99 | }; |
| 100 | |
| 101 | unique_ptr<GlobalOperatorState> PhysicalStreamingWindow::GetGlobalOperatorState(ClientContext &context) const { |
| 102 | return make_uniq<StreamingWindowGlobalState>(); |
| 103 | } |
| 104 | |
| 105 | unique_ptr<OperatorState> PhysicalStreamingWindow::GetOperatorState(ExecutionContext &context) const { |
| 106 | return make_uniq<StreamingWindowState>(); |
| 107 | } |
| 108 | |
| 109 | OperatorResultType PhysicalStreamingWindow::Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, |
| 110 | GlobalOperatorState &gstate_p, OperatorState &state_p) const { |
| 111 | auto &gstate = gstate_p.Cast<StreamingWindowGlobalState>(); |
| 112 | auto &state = state_p.Cast<StreamingWindowState>(); |
| 113 | if (!state.initialized) { |
| 114 | state.Initialize(context&: context.client, input, expressions: select_list); |
| 115 | } |
| 116 | // Put payload columns in place |
| 117 | for (idx_t col_idx = 0; col_idx < input.data.size(); col_idx++) { |
| 118 | chunk.data[col_idx].Reference(other&: input.data[col_idx]); |
| 119 | } |
| 120 | // Compute window function |
| 121 | const idx_t count = input.size(); |
| 122 | for (idx_t expr_idx = 0; expr_idx < select_list.size(); expr_idx++) { |
| 123 | idx_t col_idx = input.data.size() + expr_idx; |
| 124 | auto &expr = *select_list[expr_idx]; |
| 125 | auto &result = chunk.data[col_idx]; |
| 126 | switch (expr.GetExpressionType()) { |
| 127 | case ExpressionType::WINDOW_AGGREGATE: { |
| 128 | // Establish the aggregation environment |
| 129 | auto &wexpr = expr.Cast<BoundWindowExpression>(); |
| 130 | auto &aggregate = *wexpr.aggregate; |
| 131 | auto &statev = state.statev; |
| 132 | state.state_ptr = state.aggregate_states[expr_idx].data(); |
| 133 | AggregateInputData aggr_input_data(wexpr.bind_info.get(), Allocator::DefaultAllocator()); |
| 134 | |
| 135 | // Check for COUNT(*) |
| 136 | if (wexpr.children.empty()) { |
| 137 | D_ASSERT(GetTypeIdSize(result.GetType().InternalType()) == sizeof(int64_t)); |
| 138 | auto data = FlatVector::GetData<int64_t>(vector&: result); |
| 139 | int64_t start_row = gstate.row_number; |
| 140 | for (idx_t i = 0; i < input.size(); ++i) { |
| 141 | data[i] = start_row + i; |
| 142 | } |
| 143 | break; |
| 144 | } |
| 145 | |
| 146 | // Compute the arguments |
| 147 | auto &allocator = Allocator::Get(context&: context.client); |
| 148 | ExpressionExecutor executor(context.client); |
| 149 | vector<LogicalType> payload_types; |
| 150 | for (auto &child : wexpr.children) { |
| 151 | payload_types.push_back(x: child->return_type); |
| 152 | executor.AddExpression(expr: *child); |
| 153 | } |
| 154 | |
| 155 | DataChunk payload; |
| 156 | payload.Initialize(allocator, types: payload_types); |
| 157 | executor.Execute(input, result&: payload); |
| 158 | |
| 159 | // Iterate through them using a single SV |
| 160 | payload.Flatten(); |
| 161 | DataChunk row; |
| 162 | row.Initialize(allocator, types: payload_types); |
| 163 | sel_t s = 0; |
| 164 | SelectionVector sel(&s); |
| 165 | row.Slice(sel_vector: sel, count: 1); |
| 166 | for (size_t col_idx = 0; col_idx < payload.ColumnCount(); ++col_idx) { |
| 167 | DictionaryVector::Child(vector&: row.data[col_idx]).Reference(other&: payload.data[col_idx]); |
| 168 | } |
| 169 | |
| 170 | // Update the state and finalize it one row at a time. |
| 171 | for (idx_t i = 0; i < input.size(); ++i) { |
| 172 | sel.set_index(idx: 0, loc: i); |
| 173 | aggregate.update(row.data.data(), aggr_input_data, row.ColumnCount(), statev, 1); |
| 174 | aggregate.finalize(statev, aggr_input_data, result, 1, i); |
| 175 | } |
| 176 | break; |
| 177 | } |
| 178 | case ExpressionType::WINDOW_FIRST_VALUE: |
| 179 | case ExpressionType::WINDOW_PERCENT_RANK: |
| 180 | case ExpressionType::WINDOW_RANK: |
| 181 | case ExpressionType::WINDOW_RANK_DENSE: { |
| 182 | // Reference constant vector |
| 183 | chunk.data[col_idx].Reference(other&: *state.const_vectors[expr_idx]); |
| 184 | break; |
| 185 | } |
| 186 | case ExpressionType::WINDOW_ROW_NUMBER: { |
| 187 | // Set row numbers |
| 188 | int64_t start_row = gstate.row_number; |
| 189 | auto rdata = FlatVector::GetData<int64_t>(vector&: chunk.data[col_idx]); |
| 190 | for (idx_t i = 0; i < count; i++) { |
| 191 | rdata[i] = start_row + i; |
| 192 | } |
| 193 | break; |
| 194 | } |
| 195 | default: |
| 196 | throw NotImplementedException("%s for StreamingWindow" , ExpressionTypeToString(type: expr.GetExpressionType())); |
| 197 | } |
| 198 | } |
| 199 | gstate.row_number += count; |
| 200 | chunk.SetCardinality(count); |
| 201 | return OperatorResultType::NEED_MORE_INPUT; |
| 202 | } |
| 203 | |
| 204 | string PhysicalStreamingWindow::ParamsToString() const { |
| 205 | string result; |
| 206 | for (idx_t i = 0; i < select_list.size(); i++) { |
| 207 | if (i > 0) { |
| 208 | result += "\n" ; |
| 209 | } |
| 210 | result += select_list[i]->GetName(); |
| 211 | } |
| 212 | return result; |
| 213 | } |
| 214 | |
| 215 | } // namespace duckdb |
| 216 | |