1 | #include "duckdb/planner/expression/bound_window_expression.hpp" |
2 | #include "duckdb/parser/expression/window_expression.hpp" |
3 | |
4 | #include "duckdb/common/string_util.hpp" |
5 | #include "duckdb/function/aggregate_function.hpp" |
6 | #include "duckdb/function/function_serialization.hpp" |
7 | #include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" |
8 | |
9 | namespace duckdb { |
10 | |
11 | BoundWindowExpression::BoundWindowExpression(ExpressionType type, LogicalType return_type, |
12 | unique_ptr<AggregateFunction> aggregate, |
13 | unique_ptr<FunctionData> bind_info) |
14 | : Expression(type, ExpressionClass::BOUND_WINDOW, std::move(return_type)), aggregate(std::move(aggregate)), |
15 | bind_info(std::move(bind_info)), ignore_nulls(false) { |
16 | } |
17 | |
18 | string BoundWindowExpression::ToString() const { |
19 | string function_name = aggregate.get() ? aggregate->name : ExpressionTypeToString(type); |
20 | return WindowExpression::ToString<BoundWindowExpression, Expression, BoundOrderByNode>(entry: *this, schema: string(), |
21 | function_name); |
22 | } |
23 | |
24 | bool BoundWindowExpression::Equals(const BaseExpression &other_p) const { |
25 | if (!Expression::Equals(other: other_p)) { |
26 | return false; |
27 | } |
28 | auto &other = other_p.Cast<BoundWindowExpression>(); |
29 | |
30 | if (ignore_nulls != other.ignore_nulls) { |
31 | return false; |
32 | } |
33 | if (start != other.start || end != other.end) { |
34 | return false; |
35 | } |
36 | // check if the child expressions are equivalent |
37 | if (!Expression::ListEquals(left: children, right: other.children)) { |
38 | return false; |
39 | } |
40 | // check if the filter expressions are equivalent |
41 | if (!Expression::Equals(left: filter_expr, right: other.filter_expr)) { |
42 | return false; |
43 | } |
44 | |
45 | // check if the framing expressions are equivalent |
46 | if (!Expression::Equals(left: start_expr, right: other.start_expr) || !Expression::Equals(left: end_expr, right: other.end_expr) || |
47 | !Expression::Equals(left: offset_expr, right: other.offset_expr) || !Expression::Equals(left: default_expr, right: other.default_expr)) { |
48 | return false; |
49 | } |
50 | |
51 | return KeysAreCompatible(other); |
52 | } |
53 | |
54 | bool BoundWindowExpression::KeysAreCompatible(const BoundWindowExpression &other) const { |
55 | // check if the partitions are equivalent |
56 | if (!Expression::ListEquals(left: partitions, right: other.partitions)) { |
57 | return false; |
58 | } |
59 | // check if the orderings are equivalent |
60 | if (orders.size() != other.orders.size()) { |
61 | return false; |
62 | } |
63 | for (idx_t i = 0; i < orders.size(); i++) { |
64 | if (orders[i].type != other.orders[i].type) { |
65 | return false; |
66 | } |
67 | if (!Expression::Equals(left: *orders[i].expression, right: *other.orders[i].expression)) { |
68 | return false; |
69 | } |
70 | } |
71 | return true; |
72 | } |
73 | |
74 | unique_ptr<Expression> BoundWindowExpression::Copy() { |
75 | auto new_window = make_uniq<BoundWindowExpression>(args&: type, args&: return_type, args: nullptr, args: nullptr); |
76 | new_window->CopyProperties(other&: *this); |
77 | |
78 | if (aggregate) { |
79 | new_window->aggregate = make_uniq<AggregateFunction>(args&: *aggregate); |
80 | } |
81 | if (bind_info) { |
82 | new_window->bind_info = bind_info->Copy(); |
83 | } |
84 | for (auto &child : children) { |
85 | new_window->children.push_back(x: child->Copy()); |
86 | } |
87 | for (auto &e : partitions) { |
88 | new_window->partitions.push_back(x: e->Copy()); |
89 | } |
90 | for (auto &ps : partitions_stats) { |
91 | if (ps) { |
92 | new_window->partitions_stats.push_back(x: ps->ToUnique()); |
93 | } else { |
94 | new_window->partitions_stats.push_back(x: nullptr); |
95 | } |
96 | } |
97 | for (auto &o : orders) { |
98 | new_window->orders.emplace_back(args&: o.type, args&: o.null_order, args: o.expression->Copy()); |
99 | } |
100 | |
101 | new_window->filter_expr = filter_expr ? filter_expr->Copy() : nullptr; |
102 | |
103 | new_window->start = start; |
104 | new_window->end = end; |
105 | new_window->start_expr = start_expr ? start_expr->Copy() : nullptr; |
106 | new_window->end_expr = end_expr ? end_expr->Copy() : nullptr; |
107 | new_window->offset_expr = offset_expr ? offset_expr->Copy() : nullptr; |
108 | new_window->default_expr = default_expr ? default_expr->Copy() : nullptr; |
109 | new_window->ignore_nulls = ignore_nulls; |
110 | |
111 | return std::move(new_window); |
112 | } |
113 | |
114 | void BoundWindowExpression::Serialize(FieldWriter &writer) const { |
115 | writer.WriteField<bool>(element: aggregate.get()); |
116 | if (aggregate) { |
117 | D_ASSERT(return_type == aggregate->return_type); |
118 | FunctionSerializer::Serialize<AggregateFunction>(writer, function: *aggregate, return_type, children, bind_info: bind_info.get()); |
119 | } else { |
120 | // children and return_type are written as part of the aggregate function otherwise |
121 | writer.WriteSerializableList(elements: children); |
122 | writer.WriteSerializable(element: return_type); |
123 | } |
124 | writer.WriteSerializableList(elements: partitions); |
125 | writer.WriteRegularSerializableList(elements: orders); |
126 | // FIXME: partitions_stats |
127 | writer.WriteOptional(element: filter_expr); |
128 | writer.WriteField<bool>(element: ignore_nulls); |
129 | writer.WriteField<WindowBoundary>(element: start); |
130 | writer.WriteField<WindowBoundary>(element: end); |
131 | writer.WriteOptional(element: start_expr); |
132 | writer.WriteOptional(element: end_expr); |
133 | writer.WriteOptional(element: offset_expr); |
134 | writer.WriteOptional(element: default_expr); |
135 | } |
136 | |
137 | unique_ptr<Expression> BoundWindowExpression::Deserialize(ExpressionDeserializationState &state, FieldReader &reader) { |
138 | auto has_aggregate = reader.ReadRequired<bool>(); |
139 | unique_ptr<AggregateFunction> aggregate; |
140 | unique_ptr<FunctionData> bind_info; |
141 | vector<unique_ptr<Expression>> children; |
142 | LogicalType return_type; |
143 | if (has_aggregate) { |
144 | auto aggr_function = FunctionSerializer::Deserialize<AggregateFunction, AggregateFunctionCatalogEntry>( |
145 | reader, state, type: CatalogType::AGGREGATE_FUNCTION_ENTRY, children, bind_info); |
146 | aggregate = make_uniq<AggregateFunction>(args: std::move(aggr_function)); |
147 | return_type = aggregate->return_type; |
148 | } else { |
149 | children = reader.ReadRequiredSerializableList<Expression>(args&: state.gstate); |
150 | return_type = reader.ReadRequiredSerializable<LogicalType, LogicalType>(); |
151 | } |
152 | auto result = make_uniq<BoundWindowExpression>(args&: state.type, args&: return_type, args: std::move(aggregate), args: std::move(bind_info)); |
153 | |
154 | result->partitions = reader.ReadRequiredSerializableList<Expression>(args&: state.gstate); |
155 | result->orders = reader.ReadRequiredSerializableList<BoundOrderByNode, BoundOrderByNode>(args&: state.gstate); |
156 | result->filter_expr = reader.ReadOptional<Expression>(default_value: nullptr, args&: state.gstate); |
157 | result->ignore_nulls = reader.ReadRequired<bool>(); |
158 | result->start = reader.ReadRequired<WindowBoundary>(); |
159 | result->end = reader.ReadRequired<WindowBoundary>(); |
160 | result->start_expr = reader.ReadOptional<Expression>(default_value: nullptr, args&: state.gstate); |
161 | result->end_expr = reader.ReadOptional<Expression>(default_value: nullptr, args&: state.gstate); |
162 | result->offset_expr = reader.ReadOptional<Expression>(default_value: nullptr, args&: state.gstate); |
163 | result->default_expr = reader.ReadOptional<Expression>(default_value: nullptr, args&: state.gstate); |
164 | result->children = std::move(children); |
165 | return std::move(result); |
166 | } |
167 | |
168 | } // namespace duckdb |
169 | |