1#include "duckdb/common/exception.hpp"
2#include "duckdb/common/vector_operations/vector_operations.hpp"
3#include "duckdb/function/aggregate/distributive_functions.hpp"
4#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
5
6namespace duckdb {
7
8struct BaseCountFunction {
9 template <class STATE>
10 static void Initialize(STATE &state) {
11 state = 0;
12 }
13
14 template <class STATE, class OP>
15 static void Combine(const STATE &source, STATE &target, AggregateInputData &) {
16 target += source;
17 }
18
19 template <class T, class STATE>
20 static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
21 target = state;
22 }
23};
24
25struct CountStarFunction : public BaseCountFunction {
26 template <class STATE, class OP>
27 static void Operation(STATE &state, AggregateInputData &, idx_t idx) {
28 state += 1;
29 }
30
31 template <class STATE, class OP>
32 static void ConstantOperation(STATE &state, AggregateInputData &, idx_t count) {
33 state += count;
34 }
35
36 template <typename RESULT_TYPE>
37 static void Window(Vector inputs[], const ValidityMask &filter_mask, AggregateInputData &aggr_input_data,
38 idx_t input_count, data_ptr_t state, const FrameBounds &frame, const FrameBounds &prev,
39 Vector &result, idx_t rid, idx_t bias) {
40 D_ASSERT(input_count == 0);
41 auto data = FlatVector::GetData<RESULT_TYPE>(result);
42 const auto begin = frame.first;
43 const auto end = frame.second;
44 // Slice to any filtered rows
45 if (!filter_mask.AllValid()) {
46 RESULT_TYPE filtered = 0;
47 for (auto i = begin; i < end; ++i) {
48 filtered += filter_mask.RowIsValid(row_idx: i);
49 }
50 data[rid] = filtered;
51 } else {
52 data[rid] = end - begin;
53 }
54 }
55};
56
57struct CountFunction : public BaseCountFunction {
58 using STATE = int64_t;
59
60 static void Operation(STATE &state) {
61 state += 1;
62 }
63
64 static void ConstantOperation(STATE &state, idx_t count) {
65 state += count;
66 }
67
68 static bool IgnoreNull() {
69 return true;
70 }
71
72 static inline void CountFlatLoop(STATE **__restrict states, ValidityMask &mask, idx_t count) {
73 if (!mask.AllValid()) {
74 idx_t base_idx = 0;
75 auto entry_count = ValidityMask::EntryCount(count);
76 for (idx_t entry_idx = 0; entry_idx < entry_count; entry_idx++) {
77 auto validity_entry = mask.GetValidityEntry(entry_idx);
78 idx_t next = MinValue<idx_t>(a: base_idx + ValidityMask::BITS_PER_VALUE, b: count);
79 if (ValidityMask::AllValid(entry: validity_entry)) {
80 // all valid: perform operation
81 for (; base_idx < next; base_idx++) {
82 CountFunction::Operation(state&: *states[base_idx]);
83 }
84 } else if (ValidityMask::NoneValid(entry: validity_entry)) {
85 // nothing valid: skip all
86 base_idx = next;
87 continue;
88 } else {
89 // partially valid: need to check individual elements for validity
90 idx_t start = base_idx;
91 for (; base_idx < next; base_idx++) {
92 if (ValidityMask::RowIsValid(entry: validity_entry, idx_in_entry: base_idx - start)) {
93 CountFunction::Operation(state&: *states[base_idx]);
94 }
95 }
96 }
97 }
98 } else {
99 for (idx_t i = 0; i < count; i++) {
100 CountFunction::Operation(state&: *states[i]);
101 }
102 }
103 }
104
105 static inline void CountScatterLoop(STATE **__restrict states, const SelectionVector &isel,
106 const SelectionVector &ssel, ValidityMask &mask, idx_t count) {
107 if (!mask.AllValid()) {
108 // potential NULL values
109 for (idx_t i = 0; i < count; i++) {
110 auto idx = isel.get_index(idx: i);
111 auto sidx = ssel.get_index(idx: i);
112 if (mask.RowIsValid(row_idx: idx)) {
113 CountFunction::Operation(state&: *states[sidx]);
114 }
115 }
116 } else {
117 // quick path: no NULL values
118 for (idx_t i = 0; i < count; i++) {
119 auto sidx = ssel.get_index(idx: i);
120 CountFunction::Operation(state&: *states[sidx]);
121 }
122 }
123 }
124
125 static void CountScatter(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, Vector &states,
126 idx_t count) {
127 auto &input = inputs[0];
128 if (input.GetVectorType() == VectorType::FLAT_VECTOR && states.GetVectorType() == VectorType::FLAT_VECTOR) {
129 auto sdata = FlatVector::GetData<STATE *>(vector&: states);
130 CountFlatLoop(states: sdata, mask&: FlatVector::Validity(vector&: input), count);
131 } else {
132 UnifiedVectorFormat idata, sdata;
133 input.ToUnifiedFormat(count, data&: idata);
134 states.ToUnifiedFormat(count, data&: sdata);
135 CountScatterLoop(states: reinterpret_cast<STATE **>(sdata.data), isel: *idata.sel, ssel: *sdata.sel, mask&: idata.validity, count);
136 }
137 }
138
139 static inline void CountFlatUpdateLoop(STATE &result, ValidityMask &mask, idx_t count) {
140 idx_t base_idx = 0;
141 auto entry_count = ValidityMask::EntryCount(count);
142 for (idx_t entry_idx = 0; entry_idx < entry_count; entry_idx++) {
143 auto validity_entry = mask.GetValidityEntry(entry_idx);
144 idx_t next = MinValue<idx_t>(a: base_idx + ValidityMask::BITS_PER_VALUE, b: count);
145 if (ValidityMask::AllValid(entry: validity_entry)) {
146 // all valid
147 result += next - base_idx;
148 base_idx = next;
149 } else if (ValidityMask::NoneValid(entry: validity_entry)) {
150 // nothing valid: skip all
151 base_idx = next;
152 continue;
153 } else {
154 // partially valid: need to check individual elements for validity
155 idx_t start = base_idx;
156 for (; base_idx < next; base_idx++) {
157 if (ValidityMask::RowIsValid(entry: validity_entry, idx_in_entry: base_idx - start)) {
158 result++;
159 }
160 }
161 }
162 }
163 }
164
165 static inline void CountUpdateLoop(STATE &result, ValidityMask &mask, idx_t count,
166 const SelectionVector &sel_vector) {
167 if (mask.AllValid()) {
168 // no NULL values
169 result += count;
170 return;
171 }
172 for (idx_t i = 0; i < count; i++) {
173 auto idx = sel_vector.get_index(idx: i);
174 if (mask.RowIsValid(row_idx: idx)) {
175 result++;
176 }
177 }
178 }
179
180 static void CountUpdate(Vector inputs[], AggregateInputData &, idx_t input_count, data_ptr_t state_p, idx_t count) {
181 auto &input = inputs[0];
182 auto &result = *reinterpret_cast<STATE *>(state_p);
183 switch (input.GetVectorType()) {
184 case VectorType::CONSTANT_VECTOR: {
185 if (!ConstantVector::IsNull(vector: input)) {
186 // if the constant is not null increment the state
187 result += count;
188 }
189 break;
190 }
191 case VectorType::FLAT_VECTOR: {
192 CountFlatUpdateLoop(result, mask&: FlatVector::Validity(vector&: input), count);
193 break;
194 }
195 case VectorType::SEQUENCE_VECTOR: {
196 // sequence vectors cannot have NULL values
197 result += count;
198 break;
199 }
200 default: {
201 UnifiedVectorFormat idata;
202 input.ToUnifiedFormat(count, data&: idata);
203 CountUpdateLoop(result, mask&: idata.validity, count, sel_vector: *idata.sel);
204 break;
205 }
206 }
207 }
208};
209
210AggregateFunction CountFun::GetFunction() {
211 AggregateFunction fun({LogicalType(LogicalTypeId::ANY)}, LogicalType::BIGINT, AggregateFunction::StateSize<int64_t>,
212 AggregateFunction::StateInitialize<int64_t, CountFunction>, CountFunction::CountScatter,
213 AggregateFunction::StateCombine<int64_t, CountFunction>,
214 AggregateFunction::StateFinalize<int64_t, int64_t, CountFunction>,
215 FunctionNullHandling::SPECIAL_HANDLING, CountFunction::CountUpdate);
216 fun.name = "count";
217 fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT;
218 return fun;
219}
220
221static void CountStarSerialize(FieldWriter &writer, const FunctionData *bind_data, const AggregateFunction &function) {
222}
223
224static unique_ptr<FunctionData> CountStarDeserialize(PlanDeserializationState &state, FieldReader &reader,
225 AggregateFunction &function) {
226 return nullptr;
227}
228
229AggregateFunction CountStarFun::GetFunction() {
230 auto fun = AggregateFunction::NullaryAggregate<int64_t, int64_t, CountStarFunction>(return_type: LogicalType::BIGINT);
231 fun.name = "count_star";
232 fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
233 fun.window = CountStarFunction::Window<int64_t>;
234 // TODO is there a better way to set those?
235 fun.serialize = CountStarSerialize;
236 fun.deserialize = CountStarDeserialize;
237 return fun;
238}
239
240unique_ptr<BaseStatistics> CountPropagateStats(ClientContext &context, BoundAggregateExpression &expr,
241 AggregateStatisticsInput &input) {
242 if (!expr.IsDistinct() && !input.child_stats[0].CanHaveNull()) {
243 // count on a column without null values: use count star
244 expr.function = CountStarFun::GetFunction();
245 expr.function.name = "count_star";
246 expr.children.clear();
247 }
248 return nullptr;
249}
250
251void CountFun::RegisterFunction(BuiltinFunctions &set) {
252 AggregateFunction count_function = CountFun::GetFunction();
253 count_function.statistics = CountPropagateStats;
254 AggregateFunctionSet count("count");
255 count.AddFunction(function: count_function);
256 // the count function can also be called without arguments
257 count_function.arguments.clear();
258 count_function.statistics = nullptr;
259 count_function.window = CountStarFunction::Window<int64_t>;
260 count.AddFunction(function: count_function);
261 set.AddFunction(set: count);
262}
263
264void CountStarFun::RegisterFunction(BuiltinFunctions &set) {
265 AggregateFunctionSet count("count_star");
266 count.AddFunction(function: CountStarFun::GetFunction());
267 set.AddFunction(set: count);
268}
269
270} // namespace duckdb
271