1 | #include "duckdb/common/exception.hpp" |
2 | #include "duckdb/common/vector_operations/vector_operations.hpp" |
3 | #include "duckdb/function/aggregate/distributive_functions.hpp" |
4 | #include "duckdb/planner/expression/bound_aggregate_expression.hpp" |
5 | |
6 | namespace duckdb { |
7 | |
8 | struct BaseCountFunction { |
9 | template <class STATE> |
10 | static void Initialize(STATE &state) { |
11 | state = 0; |
12 | } |
13 | |
14 | template <class STATE, class OP> |
15 | static void Combine(const STATE &source, STATE &target, AggregateInputData &) { |
16 | target += source; |
17 | } |
18 | |
19 | template <class T, class STATE> |
20 | static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { |
21 | target = state; |
22 | } |
23 | }; |
24 | |
25 | struct CountStarFunction : public BaseCountFunction { |
26 | template <class STATE, class OP> |
27 | static void Operation(STATE &state, AggregateInputData &, idx_t idx) { |
28 | state += 1; |
29 | } |
30 | |
31 | template <class STATE, class OP> |
32 | static void ConstantOperation(STATE &state, AggregateInputData &, idx_t count) { |
33 | state += count; |
34 | } |
35 | |
36 | template <typename RESULT_TYPE> |
37 | static void Window(Vector inputs[], const ValidityMask &filter_mask, AggregateInputData &aggr_input_data, |
38 | idx_t input_count, data_ptr_t state, const FrameBounds &frame, const FrameBounds &prev, |
39 | Vector &result, idx_t rid, idx_t bias) { |
40 | D_ASSERT(input_count == 0); |
41 | auto data = FlatVector::GetData<RESULT_TYPE>(result); |
42 | const auto begin = frame.first; |
43 | const auto end = frame.second; |
44 | // Slice to any filtered rows |
45 | if (!filter_mask.AllValid()) { |
46 | RESULT_TYPE filtered = 0; |
47 | for (auto i = begin; i < end; ++i) { |
48 | filtered += filter_mask.RowIsValid(row_idx: i); |
49 | } |
50 | data[rid] = filtered; |
51 | } else { |
52 | data[rid] = end - begin; |
53 | } |
54 | } |
55 | }; |
56 | |
57 | struct CountFunction : public BaseCountFunction { |
58 | using STATE = int64_t; |
59 | |
60 | static void Operation(STATE &state) { |
61 | state += 1; |
62 | } |
63 | |
64 | static void ConstantOperation(STATE &state, idx_t count) { |
65 | state += count; |
66 | } |
67 | |
68 | static bool IgnoreNull() { |
69 | return true; |
70 | } |
71 | |
72 | static inline void CountFlatLoop(STATE **__restrict states, ValidityMask &mask, idx_t count) { |
73 | if (!mask.AllValid()) { |
74 | idx_t base_idx = 0; |
75 | auto entry_count = ValidityMask::EntryCount(count); |
76 | for (idx_t entry_idx = 0; entry_idx < entry_count; entry_idx++) { |
77 | auto validity_entry = mask.GetValidityEntry(entry_idx); |
78 | idx_t next = MinValue<idx_t>(a: base_idx + ValidityMask::BITS_PER_VALUE, b: count); |
79 | if (ValidityMask::AllValid(entry: validity_entry)) { |
80 | // all valid: perform operation |
81 | for (; base_idx < next; base_idx++) { |
82 | CountFunction::Operation(state&: *states[base_idx]); |
83 | } |
84 | } else if (ValidityMask::NoneValid(entry: validity_entry)) { |
85 | // nothing valid: skip all |
86 | base_idx = next; |
87 | continue; |
88 | } else { |
89 | // partially valid: need to check individual elements for validity |
90 | idx_t start = base_idx; |
91 | for (; base_idx < next; base_idx++) { |
92 | if (ValidityMask::RowIsValid(entry: validity_entry, idx_in_entry: base_idx - start)) { |
93 | CountFunction::Operation(state&: *states[base_idx]); |
94 | } |
95 | } |
96 | } |
97 | } |
98 | } else { |
99 | for (idx_t i = 0; i < count; i++) { |
100 | CountFunction::Operation(state&: *states[i]); |
101 | } |
102 | } |
103 | } |
104 | |
105 | static inline void CountScatterLoop(STATE **__restrict states, const SelectionVector &isel, |
106 | const SelectionVector &ssel, ValidityMask &mask, idx_t count) { |
107 | if (!mask.AllValid()) { |
108 | // potential NULL values |
109 | for (idx_t i = 0; i < count; i++) { |
110 | auto idx = isel.get_index(idx: i); |
111 | auto sidx = ssel.get_index(idx: i); |
112 | if (mask.RowIsValid(row_idx: idx)) { |
113 | CountFunction::Operation(state&: *states[sidx]); |
114 | } |
115 | } |
116 | } else { |
117 | // quick path: no NULL values |
118 | for (idx_t i = 0; i < count; i++) { |
119 | auto sidx = ssel.get_index(idx: i); |
120 | CountFunction::Operation(state&: *states[sidx]); |
121 | } |
122 | } |
123 | } |
124 | |
125 | static void CountScatter(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, Vector &states, |
126 | idx_t count) { |
127 | auto &input = inputs[0]; |
128 | if (input.GetVectorType() == VectorType::FLAT_VECTOR && states.GetVectorType() == VectorType::FLAT_VECTOR) { |
129 | auto sdata = FlatVector::GetData<STATE *>(vector&: states); |
130 | CountFlatLoop(states: sdata, mask&: FlatVector::Validity(vector&: input), count); |
131 | } else { |
132 | UnifiedVectorFormat idata, sdata; |
133 | input.ToUnifiedFormat(count, data&: idata); |
134 | states.ToUnifiedFormat(count, data&: sdata); |
135 | CountScatterLoop(states: reinterpret_cast<STATE **>(sdata.data), isel: *idata.sel, ssel: *sdata.sel, mask&: idata.validity, count); |
136 | } |
137 | } |
138 | |
139 | static inline void CountFlatUpdateLoop(STATE &result, ValidityMask &mask, idx_t count) { |
140 | idx_t base_idx = 0; |
141 | auto entry_count = ValidityMask::EntryCount(count); |
142 | for (idx_t entry_idx = 0; entry_idx < entry_count; entry_idx++) { |
143 | auto validity_entry = mask.GetValidityEntry(entry_idx); |
144 | idx_t next = MinValue<idx_t>(a: base_idx + ValidityMask::BITS_PER_VALUE, b: count); |
145 | if (ValidityMask::AllValid(entry: validity_entry)) { |
146 | // all valid |
147 | result += next - base_idx; |
148 | base_idx = next; |
149 | } else if (ValidityMask::NoneValid(entry: validity_entry)) { |
150 | // nothing valid: skip all |
151 | base_idx = next; |
152 | continue; |
153 | } else { |
154 | // partially valid: need to check individual elements for validity |
155 | idx_t start = base_idx; |
156 | for (; base_idx < next; base_idx++) { |
157 | if (ValidityMask::RowIsValid(entry: validity_entry, idx_in_entry: base_idx - start)) { |
158 | result++; |
159 | } |
160 | } |
161 | } |
162 | } |
163 | } |
164 | |
165 | static inline void CountUpdateLoop(STATE &result, ValidityMask &mask, idx_t count, |
166 | const SelectionVector &sel_vector) { |
167 | if (mask.AllValid()) { |
168 | // no NULL values |
169 | result += count; |
170 | return; |
171 | } |
172 | for (idx_t i = 0; i < count; i++) { |
173 | auto idx = sel_vector.get_index(idx: i); |
174 | if (mask.RowIsValid(row_idx: idx)) { |
175 | result++; |
176 | } |
177 | } |
178 | } |
179 | |
180 | static void CountUpdate(Vector inputs[], AggregateInputData &, idx_t input_count, data_ptr_t state_p, idx_t count) { |
181 | auto &input = inputs[0]; |
182 | auto &result = *reinterpret_cast<STATE *>(state_p); |
183 | switch (input.GetVectorType()) { |
184 | case VectorType::CONSTANT_VECTOR: { |
185 | if (!ConstantVector::IsNull(vector: input)) { |
186 | // if the constant is not null increment the state |
187 | result += count; |
188 | } |
189 | break; |
190 | } |
191 | case VectorType::FLAT_VECTOR: { |
192 | CountFlatUpdateLoop(result, mask&: FlatVector::Validity(vector&: input), count); |
193 | break; |
194 | } |
195 | case VectorType::SEQUENCE_VECTOR: { |
196 | // sequence vectors cannot have NULL values |
197 | result += count; |
198 | break; |
199 | } |
200 | default: { |
201 | UnifiedVectorFormat idata; |
202 | input.ToUnifiedFormat(count, data&: idata); |
203 | CountUpdateLoop(result, mask&: idata.validity, count, sel_vector: *idata.sel); |
204 | break; |
205 | } |
206 | } |
207 | } |
208 | }; |
209 | |
210 | AggregateFunction CountFun::GetFunction() { |
211 | AggregateFunction fun({LogicalType(LogicalTypeId::ANY)}, LogicalType::BIGINT, AggregateFunction::StateSize<int64_t>, |
212 | AggregateFunction::StateInitialize<int64_t, CountFunction>, CountFunction::CountScatter, |
213 | AggregateFunction::StateCombine<int64_t, CountFunction>, |
214 | AggregateFunction::StateFinalize<int64_t, int64_t, CountFunction>, |
215 | FunctionNullHandling::SPECIAL_HANDLING, CountFunction::CountUpdate); |
216 | fun.name = "count" ; |
217 | fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; |
218 | return fun; |
219 | } |
220 | |
221 | static void CountStarSerialize(FieldWriter &writer, const FunctionData *bind_data, const AggregateFunction &function) { |
222 | } |
223 | |
224 | static unique_ptr<FunctionData> CountStarDeserialize(PlanDeserializationState &state, FieldReader &reader, |
225 | AggregateFunction &function) { |
226 | return nullptr; |
227 | } |
228 | |
229 | AggregateFunction CountStarFun::GetFunction() { |
230 | auto fun = AggregateFunction::NullaryAggregate<int64_t, int64_t, CountStarFunction>(return_type: LogicalType::BIGINT); |
231 | fun.name = "count_star" ; |
232 | fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; |
233 | fun.window = CountStarFunction::Window<int64_t>; |
234 | // TODO is there a better way to set those? |
235 | fun.serialize = CountStarSerialize; |
236 | fun.deserialize = CountStarDeserialize; |
237 | return fun; |
238 | } |
239 | |
240 | unique_ptr<BaseStatistics> CountPropagateStats(ClientContext &context, BoundAggregateExpression &expr, |
241 | AggregateStatisticsInput &input) { |
242 | if (!expr.IsDistinct() && !input.child_stats[0].CanHaveNull()) { |
243 | // count on a column without null values: use count star |
244 | expr.function = CountStarFun::GetFunction(); |
245 | expr.function.name = "count_star" ; |
246 | expr.children.clear(); |
247 | } |
248 | return nullptr; |
249 | } |
250 | |
251 | void CountFun::RegisterFunction(BuiltinFunctions &set) { |
252 | AggregateFunction count_function = CountFun::GetFunction(); |
253 | count_function.statistics = CountPropagateStats; |
254 | AggregateFunctionSet count("count" ); |
255 | count.AddFunction(function: count_function); |
256 | // the count function can also be called without arguments |
257 | count_function.arguments.clear(); |
258 | count_function.statistics = nullptr; |
259 | count_function.window = CountStarFunction::Window<int64_t>; |
260 | count.AddFunction(function: count_function); |
261 | set.AddFunction(set: count); |
262 | } |
263 | |
264 | void CountStarFun::RegisterFunction(BuiltinFunctions &set) { |
265 | AggregateFunctionSet count("count_star" ); |
266 | count.AddFunction(function: CountStarFun::GetFunction()); |
267 | set.AddFunction(set: count); |
268 | } |
269 | |
270 | } // namespace duckdb |
271 | |