1#include "duckdb/execution/operator/aggregate/physical_streaming_window.hpp"
2
3#include "duckdb/execution/expression_executor.hpp"
4#include "duckdb/function/aggregate_function.hpp"
5#include "duckdb/parallel/thread_context.hpp"
6#include "duckdb/planner/expression/bound_reference_expression.hpp"
7#include "duckdb/planner/expression/bound_window_expression.hpp"
8
9namespace duckdb {
10
11PhysicalStreamingWindow::PhysicalStreamingWindow(vector<LogicalType> types, vector<unique_ptr<Expression>> select_list,
12 idx_t estimated_cardinality, PhysicalOperatorType type)
13 : PhysicalOperator(type, std::move(types), estimated_cardinality), select_list(std::move(select_list)) {
14}
15
16class StreamingWindowGlobalState : public GlobalOperatorState {
17public:
18 StreamingWindowGlobalState() : row_number(1) {
19 }
20
21 //! The next row number.
22 std::atomic<int64_t> row_number;
23};
24
25class StreamingWindowState : public OperatorState {
26public:
27 using StateBuffer = vector<data_t>;
28
29 StreamingWindowState() : initialized(false), statev(LogicalType::POINTER, data_ptr_cast(src: &state_ptr)) {
30 }
31
32 ~StreamingWindowState() override {
33 for (size_t i = 0; i < aggregate_dtors.size(); ++i) {
34 auto dtor = aggregate_dtors[i];
35 if (dtor) {
36 AggregateInputData aggr_input_data(aggregate_bind_data[i], Allocator::DefaultAllocator());
37 state_ptr = aggregate_states[i].data();
38 dtor(statev, aggr_input_data, 1);
39 }
40 }
41 }
42
43 void Initialize(ClientContext &context, DataChunk &input, const vector<unique_ptr<Expression>> &expressions) {
44 const_vectors.resize(new_size: expressions.size());
45 aggregate_states.resize(new_size: expressions.size());
46 aggregate_bind_data.resize(new_size: expressions.size(), x: nullptr);
47 aggregate_dtors.resize(new_size: expressions.size(), x: nullptr);
48
49 for (idx_t expr_idx = 0; expr_idx < expressions.size(); expr_idx++) {
50 auto &expr = *expressions[expr_idx];
51 auto &wexpr = expr.Cast<BoundWindowExpression>();
52 switch (expr.GetExpressionType()) {
53 case ExpressionType::WINDOW_AGGREGATE: {
54 auto &aggregate = *wexpr.aggregate;
55 auto &state = aggregate_states[expr_idx];
56 aggregate_bind_data[expr_idx] = wexpr.bind_info.get();
57 aggregate_dtors[expr_idx] = aggregate.destructor;
58 state.resize(new_size: aggregate.state_size());
59 aggregate.initialize(state.data());
60 break;
61 }
62 case ExpressionType::WINDOW_FIRST_VALUE: {
63 // Just execute the expression once
64 ExpressionExecutor executor(context);
65 executor.AddExpression(expr: *wexpr.children[0]);
66 DataChunk result;
67 result.Initialize(allocator&: Allocator::Get(context), types: {wexpr.children[0]->return_type});
68 executor.Execute(input, result);
69
70 const_vectors[expr_idx] = make_uniq<Vector>(args: result.GetValue(col_idx: 0, index: 0));
71 break;
72 }
73 case ExpressionType::WINDOW_PERCENT_RANK: {
74 const_vectors[expr_idx] = make_uniq<Vector>(args: Value((double)0));
75 break;
76 }
77 case ExpressionType::WINDOW_RANK:
78 case ExpressionType::WINDOW_RANK_DENSE: {
79 const_vectors[expr_idx] = make_uniq<Vector>(args: Value((int64_t)1));
80 break;
81 }
82 default:
83 break;
84 }
85 }
86 initialized = true;
87 }
88
89public:
90 bool initialized;
91 vector<unique_ptr<Vector>> const_vectors;
92
93 // Aggregation
94 vector<StateBuffer> aggregate_states;
95 vector<FunctionData *> aggregate_bind_data;
96 vector<aggregate_destructor_t> aggregate_dtors;
97 data_ptr_t state_ptr;
98 Vector statev;
99};
100
101unique_ptr<GlobalOperatorState> PhysicalStreamingWindow::GetGlobalOperatorState(ClientContext &context) const {
102 return make_uniq<StreamingWindowGlobalState>();
103}
104
105unique_ptr<OperatorState> PhysicalStreamingWindow::GetOperatorState(ExecutionContext &context) const {
106 return make_uniq<StreamingWindowState>();
107}
108
109OperatorResultType PhysicalStreamingWindow::Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk,
110 GlobalOperatorState &gstate_p, OperatorState &state_p) const {
111 auto &gstate = gstate_p.Cast<StreamingWindowGlobalState>();
112 auto &state = state_p.Cast<StreamingWindowState>();
113 if (!state.initialized) {
114 state.Initialize(context&: context.client, input, expressions: select_list);
115 }
116 // Put payload columns in place
117 for (idx_t col_idx = 0; col_idx < input.data.size(); col_idx++) {
118 chunk.data[col_idx].Reference(other&: input.data[col_idx]);
119 }
120 // Compute window function
121 const idx_t count = input.size();
122 for (idx_t expr_idx = 0; expr_idx < select_list.size(); expr_idx++) {
123 idx_t col_idx = input.data.size() + expr_idx;
124 auto &expr = *select_list[expr_idx];
125 auto &result = chunk.data[col_idx];
126 switch (expr.GetExpressionType()) {
127 case ExpressionType::WINDOW_AGGREGATE: {
128 // Establish the aggregation environment
129 auto &wexpr = expr.Cast<BoundWindowExpression>();
130 auto &aggregate = *wexpr.aggregate;
131 auto &statev = state.statev;
132 state.state_ptr = state.aggregate_states[expr_idx].data();
133 AggregateInputData aggr_input_data(wexpr.bind_info.get(), Allocator::DefaultAllocator());
134
135 // Check for COUNT(*)
136 if (wexpr.children.empty()) {
137 D_ASSERT(GetTypeIdSize(result.GetType().InternalType()) == sizeof(int64_t));
138 auto data = FlatVector::GetData<int64_t>(vector&: result);
139 int64_t start_row = gstate.row_number;
140 for (idx_t i = 0; i < input.size(); ++i) {
141 data[i] = start_row + i;
142 }
143 break;
144 }
145
146 // Compute the arguments
147 auto &allocator = Allocator::Get(context&: context.client);
148 ExpressionExecutor executor(context.client);
149 vector<LogicalType> payload_types;
150 for (auto &child : wexpr.children) {
151 payload_types.push_back(x: child->return_type);
152 executor.AddExpression(expr: *child);
153 }
154
155 DataChunk payload;
156 payload.Initialize(allocator, types: payload_types);
157 executor.Execute(input, result&: payload);
158
159 // Iterate through them using a single SV
160 payload.Flatten();
161 DataChunk row;
162 row.Initialize(allocator, types: payload_types);
163 sel_t s = 0;
164 SelectionVector sel(&s);
165 row.Slice(sel_vector: sel, count: 1);
166 for (size_t col_idx = 0; col_idx < payload.ColumnCount(); ++col_idx) {
167 DictionaryVector::Child(vector&: row.data[col_idx]).Reference(other&: payload.data[col_idx]);
168 }
169
170 // Update the state and finalize it one row at a time.
171 for (idx_t i = 0; i < input.size(); ++i) {
172 sel.set_index(idx: 0, loc: i);
173 aggregate.update(row.data.data(), aggr_input_data, row.ColumnCount(), statev, 1);
174 aggregate.finalize(statev, aggr_input_data, result, 1, i);
175 }
176 break;
177 }
178 case ExpressionType::WINDOW_FIRST_VALUE:
179 case ExpressionType::WINDOW_PERCENT_RANK:
180 case ExpressionType::WINDOW_RANK:
181 case ExpressionType::WINDOW_RANK_DENSE: {
182 // Reference constant vector
183 chunk.data[col_idx].Reference(other&: *state.const_vectors[expr_idx]);
184 break;
185 }
186 case ExpressionType::WINDOW_ROW_NUMBER: {
187 // Set row numbers
188 int64_t start_row = gstate.row_number;
189 auto rdata = FlatVector::GetData<int64_t>(vector&: chunk.data[col_idx]);
190 for (idx_t i = 0; i < count; i++) {
191 rdata[i] = start_row + i;
192 }
193 break;
194 }
195 default:
196 throw NotImplementedException("%s for StreamingWindow", ExpressionTypeToString(type: expr.GetExpressionType()));
197 }
198 }
199 gstate.row_number += count;
200 chunk.SetCardinality(count);
201 return OperatorResultType::NEED_MORE_INPUT;
202}
203
204string PhysicalStreamingWindow::ParamsToString() const {
205 string result;
206 for (idx_t i = 0; i < select_list.size(); i++) {
207 if (i > 0) {
208 result += "\n";
209 }
210 result += select_list[i]->GetName();
211 }
212 return result;
213}
214
215} // namespace duckdb
216