1#include "duckdb/function/aggregate/distributive_functions.hpp"
2#include "duckdb/common/exception.hpp"
3#include "duckdb/common/types/null_value.hpp"
4#include "duckdb/common/vector_operations/vector_operations.hpp"
5
6using namespace std;
7
8namespace duckdb {
9
10struct BaseCountFunction {
11 template <class STATE> static void Initialize(STATE *state) {
12 *state = 0;
13 }
14
15 template <class INPUT_TYPE, class STATE, class OP>
16 static void Operation(STATE *state, INPUT_TYPE *input, nullmask_t &nullmask, idx_t idx) {
17 *state += 1;
18 }
19
20 template <class INPUT_TYPE, class STATE, class OP>
21 static void ConstantOperation(STATE *state, INPUT_TYPE *input, nullmask_t &nullmask, idx_t count) {
22 *state += count;
23 }
24
25 template <class STATE, class OP> static void Combine(STATE source, STATE *target) {
26 *target += source;
27 }
28
29 template <class T, class STATE>
30 static void Finalize(Vector &result, STATE *state, T *target, nullmask_t &nullmask, idx_t idx) {
31 target[idx] = *state;
32 }
33};
34
35struct CountStarFunction : public BaseCountFunction {
36 static bool IgnoreNull() {
37 return false;
38 }
39};
40
41struct CountFunction : public BaseCountFunction {
42 static bool IgnoreNull() {
43 return true;
44 }
45};
46
47AggregateFunction CountFun::GetFunction() {
48 return AggregateFunction::UnaryAggregate<int64_t, int64_t, int64_t, CountFunction>(SQLType(SQLTypeId::ANY),
49 SQLType::BIGINT);
50}
51
52AggregateFunction CountStarFun::GetFunction() {
53 return AggregateFunction::UnaryAggregate<int64_t, int64_t, int64_t, CountStarFunction>(SQLType(SQLTypeId::ANY),
54 SQLType::BIGINT);
55}
56
57void CountFun::RegisterFunction(BuiltinFunctions &set) {
58 AggregateFunction count_function = CountFun::GetFunction();
59 AggregateFunctionSet count("count");
60 count.AddFunction(count_function);
61 // the count function can also be called without arguments
62 count_function.arguments.clear();
63 count.AddFunction(count_function);
64 set.AddFunction(count);
65}
66
67void CountStarFun::RegisterFunction(BuiltinFunctions &set) {
68 AggregateFunctionSet count("count_star");
69 count.AddFunction(CountStarFun::GetFunction());
70 set.AddFunction(count);
71}
72
73} // namespace duckdb
74