1 | #include "duckdb/execution/operator/aggregate/physical_streaming_window.hpp" |
2 | |
3 | #include "duckdb/execution/expression_executor.hpp" |
4 | #include "duckdb/function/aggregate_function.hpp" |
5 | #include "duckdb/parallel/thread_context.hpp" |
6 | #include "duckdb/planner/expression/bound_reference_expression.hpp" |
7 | #include "duckdb/planner/expression/bound_window_expression.hpp" |
8 | |
9 | namespace duckdb { |
10 | |
11 | PhysicalStreamingWindow::PhysicalStreamingWindow(vector<LogicalType> types, vector<unique_ptr<Expression>> select_list, |
12 | idx_t estimated_cardinality, PhysicalOperatorType type) |
13 | : PhysicalOperator(type, std::move(types), estimated_cardinality), select_list(std::move(select_list)) { |
14 | } |
15 | |
16 | class StreamingWindowGlobalState : public GlobalOperatorState { |
17 | public: |
18 | StreamingWindowGlobalState() : row_number(1) { |
19 | } |
20 | |
21 | //! The next row number. |
22 | std::atomic<int64_t> row_number; |
23 | }; |
24 | |
25 | class StreamingWindowState : public OperatorState { |
26 | public: |
27 | using StateBuffer = vector<data_t>; |
28 | |
29 | StreamingWindowState() : initialized(false), statev(LogicalType::POINTER, data_ptr_cast(src: &state_ptr)) { |
30 | } |
31 | |
32 | ~StreamingWindowState() override { |
33 | for (size_t i = 0; i < aggregate_dtors.size(); ++i) { |
34 | auto dtor = aggregate_dtors[i]; |
35 | if (dtor) { |
36 | AggregateInputData aggr_input_data(aggregate_bind_data[i], Allocator::DefaultAllocator()); |
37 | state_ptr = aggregate_states[i].data(); |
38 | dtor(statev, aggr_input_data, 1); |
39 | } |
40 | } |
41 | } |
42 | |
43 | void Initialize(ClientContext &context, DataChunk &input, const vector<unique_ptr<Expression>> &expressions) { |
44 | const_vectors.resize(new_size: expressions.size()); |
45 | aggregate_states.resize(new_size: expressions.size()); |
46 | aggregate_bind_data.resize(new_size: expressions.size(), x: nullptr); |
47 | aggregate_dtors.resize(new_size: expressions.size(), x: nullptr); |
48 | |
49 | for (idx_t expr_idx = 0; expr_idx < expressions.size(); expr_idx++) { |
50 | auto &expr = *expressions[expr_idx]; |
51 | auto &wexpr = expr.Cast<BoundWindowExpression>(); |
52 | switch (expr.GetExpressionType()) { |
53 | case ExpressionType::WINDOW_AGGREGATE: { |
54 | auto &aggregate = *wexpr.aggregate; |
55 | auto &state = aggregate_states[expr_idx]; |
56 | aggregate_bind_data[expr_idx] = wexpr.bind_info.get(); |
57 | aggregate_dtors[expr_idx] = aggregate.destructor; |
58 | state.resize(new_size: aggregate.state_size()); |
59 | aggregate.initialize(state.data()); |
60 | break; |
61 | } |
62 | case ExpressionType::WINDOW_FIRST_VALUE: { |
63 | // Just execute the expression once |
64 | ExpressionExecutor executor(context); |
65 | executor.AddExpression(expr: *wexpr.children[0]); |
66 | DataChunk result; |
67 | result.Initialize(allocator&: Allocator::Get(context), types: {wexpr.children[0]->return_type}); |
68 | executor.Execute(input, result); |
69 | |
70 | const_vectors[expr_idx] = make_uniq<Vector>(args: result.GetValue(col_idx: 0, index: 0)); |
71 | break; |
72 | } |
73 | case ExpressionType::WINDOW_PERCENT_RANK: { |
74 | const_vectors[expr_idx] = make_uniq<Vector>(args: Value((double)0)); |
75 | break; |
76 | } |
77 | case ExpressionType::WINDOW_RANK: |
78 | case ExpressionType::WINDOW_RANK_DENSE: { |
79 | const_vectors[expr_idx] = make_uniq<Vector>(args: Value((int64_t)1)); |
80 | break; |
81 | } |
82 | default: |
83 | break; |
84 | } |
85 | } |
86 | initialized = true; |
87 | } |
88 | |
89 | public: |
90 | bool initialized; |
91 | vector<unique_ptr<Vector>> const_vectors; |
92 | |
93 | // Aggregation |
94 | vector<StateBuffer> aggregate_states; |
95 | vector<FunctionData *> aggregate_bind_data; |
96 | vector<aggregate_destructor_t> aggregate_dtors; |
97 | data_ptr_t state_ptr; |
98 | Vector statev; |
99 | }; |
100 | |
101 | unique_ptr<GlobalOperatorState> PhysicalStreamingWindow::GetGlobalOperatorState(ClientContext &context) const { |
102 | return make_uniq<StreamingWindowGlobalState>(); |
103 | } |
104 | |
105 | unique_ptr<OperatorState> PhysicalStreamingWindow::GetOperatorState(ExecutionContext &context) const { |
106 | return make_uniq<StreamingWindowState>(); |
107 | } |
108 | |
109 | OperatorResultType PhysicalStreamingWindow::Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, |
110 | GlobalOperatorState &gstate_p, OperatorState &state_p) const { |
111 | auto &gstate = gstate_p.Cast<StreamingWindowGlobalState>(); |
112 | auto &state = state_p.Cast<StreamingWindowState>(); |
113 | if (!state.initialized) { |
114 | state.Initialize(context&: context.client, input, expressions: select_list); |
115 | } |
116 | // Put payload columns in place |
117 | for (idx_t col_idx = 0; col_idx < input.data.size(); col_idx++) { |
118 | chunk.data[col_idx].Reference(other&: input.data[col_idx]); |
119 | } |
120 | // Compute window function |
121 | const idx_t count = input.size(); |
122 | for (idx_t expr_idx = 0; expr_idx < select_list.size(); expr_idx++) { |
123 | idx_t col_idx = input.data.size() + expr_idx; |
124 | auto &expr = *select_list[expr_idx]; |
125 | auto &result = chunk.data[col_idx]; |
126 | switch (expr.GetExpressionType()) { |
127 | case ExpressionType::WINDOW_AGGREGATE: { |
128 | // Establish the aggregation environment |
129 | auto &wexpr = expr.Cast<BoundWindowExpression>(); |
130 | auto &aggregate = *wexpr.aggregate; |
131 | auto &statev = state.statev; |
132 | state.state_ptr = state.aggregate_states[expr_idx].data(); |
133 | AggregateInputData aggr_input_data(wexpr.bind_info.get(), Allocator::DefaultAllocator()); |
134 | |
135 | // Check for COUNT(*) |
136 | if (wexpr.children.empty()) { |
137 | D_ASSERT(GetTypeIdSize(result.GetType().InternalType()) == sizeof(int64_t)); |
138 | auto data = FlatVector::GetData<int64_t>(vector&: result); |
139 | int64_t start_row = gstate.row_number; |
140 | for (idx_t i = 0; i < input.size(); ++i) { |
141 | data[i] = start_row + i; |
142 | } |
143 | break; |
144 | } |
145 | |
146 | // Compute the arguments |
147 | auto &allocator = Allocator::Get(context&: context.client); |
148 | ExpressionExecutor executor(context.client); |
149 | vector<LogicalType> payload_types; |
150 | for (auto &child : wexpr.children) { |
151 | payload_types.push_back(x: child->return_type); |
152 | executor.AddExpression(expr: *child); |
153 | } |
154 | |
155 | DataChunk payload; |
156 | payload.Initialize(allocator, types: payload_types); |
157 | executor.Execute(input, result&: payload); |
158 | |
159 | // Iterate through them using a single SV |
160 | payload.Flatten(); |
161 | DataChunk row; |
162 | row.Initialize(allocator, types: payload_types); |
163 | sel_t s = 0; |
164 | SelectionVector sel(&s); |
165 | row.Slice(sel_vector: sel, count: 1); |
166 | for (size_t col_idx = 0; col_idx < payload.ColumnCount(); ++col_idx) { |
167 | DictionaryVector::Child(vector&: row.data[col_idx]).Reference(other&: payload.data[col_idx]); |
168 | } |
169 | |
170 | // Update the state and finalize it one row at a time. |
171 | for (idx_t i = 0; i < input.size(); ++i) { |
172 | sel.set_index(idx: 0, loc: i); |
173 | aggregate.update(row.data.data(), aggr_input_data, row.ColumnCount(), statev, 1); |
174 | aggregate.finalize(statev, aggr_input_data, result, 1, i); |
175 | } |
176 | break; |
177 | } |
178 | case ExpressionType::WINDOW_FIRST_VALUE: |
179 | case ExpressionType::WINDOW_PERCENT_RANK: |
180 | case ExpressionType::WINDOW_RANK: |
181 | case ExpressionType::WINDOW_RANK_DENSE: { |
182 | // Reference constant vector |
183 | chunk.data[col_idx].Reference(other&: *state.const_vectors[expr_idx]); |
184 | break; |
185 | } |
186 | case ExpressionType::WINDOW_ROW_NUMBER: { |
187 | // Set row numbers |
188 | int64_t start_row = gstate.row_number; |
189 | auto rdata = FlatVector::GetData<int64_t>(vector&: chunk.data[col_idx]); |
190 | for (idx_t i = 0; i < count; i++) { |
191 | rdata[i] = start_row + i; |
192 | } |
193 | break; |
194 | } |
195 | default: |
196 | throw NotImplementedException("%s for StreamingWindow" , ExpressionTypeToString(type: expr.GetExpressionType())); |
197 | } |
198 | } |
199 | gstate.row_number += count; |
200 | chunk.SetCardinality(count); |
201 | return OperatorResultType::NEED_MORE_INPUT; |
202 | } |
203 | |
204 | string PhysicalStreamingWindow::ParamsToString() const { |
205 | string result; |
206 | for (idx_t i = 0; i < select_list.size(); i++) { |
207 | if (i > 0) { |
208 | result += "\n" ; |
209 | } |
210 | result += select_list[i]->GetName(); |
211 | } |
212 | return result; |
213 | } |
214 | |
215 | } // namespace duckdb |
216 | |