| 1 | #include "duckdb/common/exception.hpp" |
| 2 | #include "duckdb/common/vector_operations/vector_operations.hpp" |
| 3 | #include "duckdb/function/aggregate/distributive_functions.hpp" |
| 4 | #include "duckdb/planner/expression/bound_aggregate_expression.hpp" |
| 5 | |
| 6 | namespace duckdb { |
| 7 | |
| 8 | struct BaseCountFunction { |
| 9 | template <class STATE> |
| 10 | static void Initialize(STATE &state) { |
| 11 | state = 0; |
| 12 | } |
| 13 | |
| 14 | template <class STATE, class OP> |
| 15 | static void Combine(const STATE &source, STATE &target, AggregateInputData &) { |
| 16 | target += source; |
| 17 | } |
| 18 | |
| 19 | template <class T, class STATE> |
| 20 | static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { |
| 21 | target = state; |
| 22 | } |
| 23 | }; |
| 24 | |
| 25 | struct CountStarFunction : public BaseCountFunction { |
| 26 | template <class STATE, class OP> |
| 27 | static void Operation(STATE &state, AggregateInputData &, idx_t idx) { |
| 28 | state += 1; |
| 29 | } |
| 30 | |
| 31 | template <class STATE, class OP> |
| 32 | static void ConstantOperation(STATE &state, AggregateInputData &, idx_t count) { |
| 33 | state += count; |
| 34 | } |
| 35 | |
| 36 | template <typename RESULT_TYPE> |
| 37 | static void Window(Vector inputs[], const ValidityMask &filter_mask, AggregateInputData &aggr_input_data, |
| 38 | idx_t input_count, data_ptr_t state, const FrameBounds &frame, const FrameBounds &prev, |
| 39 | Vector &result, idx_t rid, idx_t bias) { |
| 40 | D_ASSERT(input_count == 0); |
| 41 | auto data = FlatVector::GetData<RESULT_TYPE>(result); |
| 42 | const auto begin = frame.first; |
| 43 | const auto end = frame.second; |
| 44 | // Slice to any filtered rows |
| 45 | if (!filter_mask.AllValid()) { |
| 46 | RESULT_TYPE filtered = 0; |
| 47 | for (auto i = begin; i < end; ++i) { |
| 48 | filtered += filter_mask.RowIsValid(row_idx: i); |
| 49 | } |
| 50 | data[rid] = filtered; |
| 51 | } else { |
| 52 | data[rid] = end - begin; |
| 53 | } |
| 54 | } |
| 55 | }; |
| 56 | |
| 57 | struct CountFunction : public BaseCountFunction { |
| 58 | using STATE = int64_t; |
| 59 | |
| 60 | static void Operation(STATE &state) { |
| 61 | state += 1; |
| 62 | } |
| 63 | |
| 64 | static void ConstantOperation(STATE &state, idx_t count) { |
| 65 | state += count; |
| 66 | } |
| 67 | |
| 68 | static bool IgnoreNull() { |
| 69 | return true; |
| 70 | } |
| 71 | |
| 72 | static inline void CountFlatLoop(STATE **__restrict states, ValidityMask &mask, idx_t count) { |
| 73 | if (!mask.AllValid()) { |
| 74 | idx_t base_idx = 0; |
| 75 | auto entry_count = ValidityMask::EntryCount(count); |
| 76 | for (idx_t entry_idx = 0; entry_idx < entry_count; entry_idx++) { |
| 77 | auto validity_entry = mask.GetValidityEntry(entry_idx); |
| 78 | idx_t next = MinValue<idx_t>(a: base_idx + ValidityMask::BITS_PER_VALUE, b: count); |
| 79 | if (ValidityMask::AllValid(entry: validity_entry)) { |
| 80 | // all valid: perform operation |
| 81 | for (; base_idx < next; base_idx++) { |
| 82 | CountFunction::Operation(state&: *states[base_idx]); |
| 83 | } |
| 84 | } else if (ValidityMask::NoneValid(entry: validity_entry)) { |
| 85 | // nothing valid: skip all |
| 86 | base_idx = next; |
| 87 | continue; |
| 88 | } else { |
| 89 | // partially valid: need to check individual elements for validity |
| 90 | idx_t start = base_idx; |
| 91 | for (; base_idx < next; base_idx++) { |
| 92 | if (ValidityMask::RowIsValid(entry: validity_entry, idx_in_entry: base_idx - start)) { |
| 93 | CountFunction::Operation(state&: *states[base_idx]); |
| 94 | } |
| 95 | } |
| 96 | } |
| 97 | } |
| 98 | } else { |
| 99 | for (idx_t i = 0; i < count; i++) { |
| 100 | CountFunction::Operation(state&: *states[i]); |
| 101 | } |
| 102 | } |
| 103 | } |
| 104 | |
| 105 | static inline void CountScatterLoop(STATE **__restrict states, const SelectionVector &isel, |
| 106 | const SelectionVector &ssel, ValidityMask &mask, idx_t count) { |
| 107 | if (!mask.AllValid()) { |
| 108 | // potential NULL values |
| 109 | for (idx_t i = 0; i < count; i++) { |
| 110 | auto idx = isel.get_index(idx: i); |
| 111 | auto sidx = ssel.get_index(idx: i); |
| 112 | if (mask.RowIsValid(row_idx: idx)) { |
| 113 | CountFunction::Operation(state&: *states[sidx]); |
| 114 | } |
| 115 | } |
| 116 | } else { |
| 117 | // quick path: no NULL values |
| 118 | for (idx_t i = 0; i < count; i++) { |
| 119 | auto sidx = ssel.get_index(idx: i); |
| 120 | CountFunction::Operation(state&: *states[sidx]); |
| 121 | } |
| 122 | } |
| 123 | } |
| 124 | |
| 125 | static void CountScatter(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, Vector &states, |
| 126 | idx_t count) { |
| 127 | auto &input = inputs[0]; |
| 128 | if (input.GetVectorType() == VectorType::FLAT_VECTOR && states.GetVectorType() == VectorType::FLAT_VECTOR) { |
| 129 | auto sdata = FlatVector::GetData<STATE *>(vector&: states); |
| 130 | CountFlatLoop(states: sdata, mask&: FlatVector::Validity(vector&: input), count); |
| 131 | } else { |
| 132 | UnifiedVectorFormat idata, sdata; |
| 133 | input.ToUnifiedFormat(count, data&: idata); |
| 134 | states.ToUnifiedFormat(count, data&: sdata); |
| 135 | CountScatterLoop(states: reinterpret_cast<STATE **>(sdata.data), isel: *idata.sel, ssel: *sdata.sel, mask&: idata.validity, count); |
| 136 | } |
| 137 | } |
| 138 | |
| 139 | static inline void CountFlatUpdateLoop(STATE &result, ValidityMask &mask, idx_t count) { |
| 140 | idx_t base_idx = 0; |
| 141 | auto entry_count = ValidityMask::EntryCount(count); |
| 142 | for (idx_t entry_idx = 0; entry_idx < entry_count; entry_idx++) { |
| 143 | auto validity_entry = mask.GetValidityEntry(entry_idx); |
| 144 | idx_t next = MinValue<idx_t>(a: base_idx + ValidityMask::BITS_PER_VALUE, b: count); |
| 145 | if (ValidityMask::AllValid(entry: validity_entry)) { |
| 146 | // all valid |
| 147 | result += next - base_idx; |
| 148 | base_idx = next; |
| 149 | } else if (ValidityMask::NoneValid(entry: validity_entry)) { |
| 150 | // nothing valid: skip all |
| 151 | base_idx = next; |
| 152 | continue; |
| 153 | } else { |
| 154 | // partially valid: need to check individual elements for validity |
| 155 | idx_t start = base_idx; |
| 156 | for (; base_idx < next; base_idx++) { |
| 157 | if (ValidityMask::RowIsValid(entry: validity_entry, idx_in_entry: base_idx - start)) { |
| 158 | result++; |
| 159 | } |
| 160 | } |
| 161 | } |
| 162 | } |
| 163 | } |
| 164 | |
| 165 | static inline void CountUpdateLoop(STATE &result, ValidityMask &mask, idx_t count, |
| 166 | const SelectionVector &sel_vector) { |
| 167 | if (mask.AllValid()) { |
| 168 | // no NULL values |
| 169 | result += count; |
| 170 | return; |
| 171 | } |
| 172 | for (idx_t i = 0; i < count; i++) { |
| 173 | auto idx = sel_vector.get_index(idx: i); |
| 174 | if (mask.RowIsValid(row_idx: idx)) { |
| 175 | result++; |
| 176 | } |
| 177 | } |
| 178 | } |
| 179 | |
| 180 | static void CountUpdate(Vector inputs[], AggregateInputData &, idx_t input_count, data_ptr_t state_p, idx_t count) { |
| 181 | auto &input = inputs[0]; |
| 182 | auto &result = *reinterpret_cast<STATE *>(state_p); |
| 183 | switch (input.GetVectorType()) { |
| 184 | case VectorType::CONSTANT_VECTOR: { |
| 185 | if (!ConstantVector::IsNull(vector: input)) { |
| 186 | // if the constant is not null increment the state |
| 187 | result += count; |
| 188 | } |
| 189 | break; |
| 190 | } |
| 191 | case VectorType::FLAT_VECTOR: { |
| 192 | CountFlatUpdateLoop(result, mask&: FlatVector::Validity(vector&: input), count); |
| 193 | break; |
| 194 | } |
| 195 | case VectorType::SEQUENCE_VECTOR: { |
| 196 | // sequence vectors cannot have NULL values |
| 197 | result += count; |
| 198 | break; |
| 199 | } |
| 200 | default: { |
| 201 | UnifiedVectorFormat idata; |
| 202 | input.ToUnifiedFormat(count, data&: idata); |
| 203 | CountUpdateLoop(result, mask&: idata.validity, count, sel_vector: *idata.sel); |
| 204 | break; |
| 205 | } |
| 206 | } |
| 207 | } |
| 208 | }; |
| 209 | |
| 210 | AggregateFunction CountFun::GetFunction() { |
| 211 | AggregateFunction fun({LogicalType(LogicalTypeId::ANY)}, LogicalType::BIGINT, AggregateFunction::StateSize<int64_t>, |
| 212 | AggregateFunction::StateInitialize<int64_t, CountFunction>, CountFunction::CountScatter, |
| 213 | AggregateFunction::StateCombine<int64_t, CountFunction>, |
| 214 | AggregateFunction::StateFinalize<int64_t, int64_t, CountFunction>, |
| 215 | FunctionNullHandling::SPECIAL_HANDLING, CountFunction::CountUpdate); |
| 216 | fun.name = "count" ; |
| 217 | fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; |
| 218 | return fun; |
| 219 | } |
| 220 | |
| 221 | static void CountStarSerialize(FieldWriter &writer, const FunctionData *bind_data, const AggregateFunction &function) { |
| 222 | } |
| 223 | |
| 224 | static unique_ptr<FunctionData> CountStarDeserialize(PlanDeserializationState &state, FieldReader &reader, |
| 225 | AggregateFunction &function) { |
| 226 | return nullptr; |
| 227 | } |
| 228 | |
| 229 | AggregateFunction CountStarFun::GetFunction() { |
| 230 | auto fun = AggregateFunction::NullaryAggregate<int64_t, int64_t, CountStarFunction>(return_type: LogicalType::BIGINT); |
| 231 | fun.name = "count_star" ; |
| 232 | fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; |
| 233 | fun.window = CountStarFunction::Window<int64_t>; |
| 234 | // TODO is there a better way to set those? |
| 235 | fun.serialize = CountStarSerialize; |
| 236 | fun.deserialize = CountStarDeserialize; |
| 237 | return fun; |
| 238 | } |
| 239 | |
| 240 | unique_ptr<BaseStatistics> CountPropagateStats(ClientContext &context, BoundAggregateExpression &expr, |
| 241 | AggregateStatisticsInput &input) { |
| 242 | if (!expr.IsDistinct() && !input.child_stats[0].CanHaveNull()) { |
| 243 | // count on a column without null values: use count star |
| 244 | expr.function = CountStarFun::GetFunction(); |
| 245 | expr.function.name = "count_star" ; |
| 246 | expr.children.clear(); |
| 247 | } |
| 248 | return nullptr; |
| 249 | } |
| 250 | |
| 251 | void CountFun::RegisterFunction(BuiltinFunctions &set) { |
| 252 | AggregateFunction count_function = CountFun::GetFunction(); |
| 253 | count_function.statistics = CountPropagateStats; |
| 254 | AggregateFunctionSet count("count" ); |
| 255 | count.AddFunction(function: count_function); |
| 256 | // the count function can also be called without arguments |
| 257 | count_function.arguments.clear(); |
| 258 | count_function.statistics = nullptr; |
| 259 | count_function.window = CountStarFunction::Window<int64_t>; |
| 260 | count.AddFunction(function: count_function); |
| 261 | set.AddFunction(set: count); |
| 262 | } |
| 263 | |
| 264 | void CountStarFun::RegisterFunction(BuiltinFunctions &set) { |
| 265 | AggregateFunctionSet count("count_star" ); |
| 266 | count.AddFunction(function: CountStarFun::GetFunction()); |
| 267 | set.AddFunction(set: count); |
| 268 | } |
| 269 | |
| 270 | } // namespace duckdb |
| 271 | |