| 1 | #include "duckdb/function/aggregate/distributive_functions.hpp" |
| 2 | #include "duckdb/common/exception.hpp" |
| 3 | #include "duckdb/common/vector_operations/vector_operations.hpp" |
| 4 | #include "duckdb/planner/expression.hpp" |
| 5 | |
| 6 | namespace duckdb { |
| 7 | |
//! Aggregate state for first/last/any_value over a fixed-size value of type T.
template <typename T>
struct FirstState {
	//! The captured value; Finalize only reads it when is_set && !is_null
	T value;
	//! True once a row has been captured into this state
	bool is_set;
	//! True when the captured row was NULL
	bool is_null;
};
| 14 | |
//! Pieces shared by the fixed-size and string first/last operators.
struct FirstFunctionBase {
	//! Resets a state to "nothing captured yet"; state.value is left untouched.
	template <class STATE>
	static void Initialize(STATE &state) {
		state.is_null = false;
		state.is_set = false;
	}

	//! Always reports false. The derived operators inspect RowIsValid() themselves,
	//! so presumably this asks the framework to forward NULL rows - confirm against the caller.
	static bool IgnoreNull() {
		return false;
	}
};
| 26 | |
| 27 | template <bool LAST, bool SKIP_NULLS> |
| 28 | struct FirstFunction : public FirstFunctionBase { |
| 29 | template <class INPUT_TYPE, class STATE, class OP> |
| 30 | static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { |
| 31 | if (LAST || !state.is_set) { |
| 32 | if (!unary_input.RowIsValid()) { |
| 33 | if (!SKIP_NULLS) { |
| 34 | state.is_set = true; |
| 35 | } |
| 36 | state.is_null = true; |
| 37 | } else { |
| 38 | state.is_set = true; |
| 39 | state.is_null = false; |
| 40 | state.value = input; |
| 41 | } |
| 42 | } |
| 43 | } |
| 44 | |
| 45 | template <class INPUT_TYPE, class STATE, class OP> |
| 46 | static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, |
| 47 | idx_t count) { |
| 48 | Operation<INPUT_TYPE, STATE, OP>(state, input, unary_input); |
| 49 | } |
| 50 | |
| 51 | template <class STATE, class OP> |
| 52 | static void Combine(const STATE &source, STATE &target, AggregateInputData &) { |
| 53 | if (!target.is_set) { |
| 54 | target = source; |
| 55 | } |
| 56 | } |
| 57 | |
| 58 | template <class T, class STATE> |
| 59 | static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { |
| 60 | if (!state.is_set || state.is_null) { |
| 61 | finalize_data.ReturnNull(); |
| 62 | } else { |
| 63 | target = state.value; |
| 64 | } |
| 65 | } |
| 66 | }; |
| 67 | |
| 68 | template <bool LAST, bool SKIP_NULLS> |
| 69 | struct FirstFunctionString : public FirstFunctionBase { |
| 70 | template <class STATE> |
| 71 | static void SetValue(STATE &state, AggregateInputData &input_data, string_t value, bool is_null) { |
| 72 | if (LAST && state.is_set) { |
| 73 | Destroy(state, input_data); |
| 74 | } |
| 75 | if (is_null) { |
| 76 | if (!SKIP_NULLS) { |
| 77 | state.is_set = true; |
| 78 | state.is_null = true; |
| 79 | } |
| 80 | } else { |
| 81 | state.is_set = true; |
| 82 | state.is_null = false; |
| 83 | if (value.IsInlined()) { |
| 84 | state.value = value; |
| 85 | } else { |
| 86 | // non-inlined string, need to allocate space for it |
| 87 | auto len = value.GetSize(); |
| 88 | auto ptr = new char[len]; |
| 89 | memcpy(dest: ptr, src: value.GetData(), n: len); |
| 90 | |
| 91 | state.value = string_t(ptr, len); |
| 92 | } |
| 93 | } |
| 94 | } |
| 95 | |
| 96 | template <class INPUT_TYPE, class STATE, class OP> |
| 97 | static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { |
| 98 | if (LAST || !state.is_set) { |
| 99 | SetValue(state, unary_input.input, input, !unary_input.RowIsValid()); |
| 100 | } |
| 101 | } |
| 102 | |
| 103 | template <class INPUT_TYPE, class STATE, class OP> |
| 104 | static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, |
| 105 | idx_t count) { |
| 106 | Operation<INPUT_TYPE, STATE, OP>(state, input, unary_input); |
| 107 | } |
| 108 | |
| 109 | template <class STATE, class OP> |
| 110 | static void Combine(const STATE &source, STATE &target, AggregateInputData &input_data) { |
| 111 | if (source.is_set && (LAST || !target.is_set)) { |
| 112 | SetValue(target, input_data, source.value, source.is_null); |
| 113 | } |
| 114 | } |
| 115 | |
| 116 | template <class T, class STATE> |
| 117 | static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { |
| 118 | if (!state.is_set || state.is_null) { |
| 119 | finalize_data.ReturnNull(); |
| 120 | } else { |
| 121 | target = StringVector::AddStringOrBlob(finalize_data.result, state.value); |
| 122 | } |
| 123 | } |
| 124 | |
| 125 | template <class STATE> |
| 126 | static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { |
| 127 | if (state.is_set && !state.is_null && !state.value.IsInlined()) { |
| 128 | delete[] state.value.GetData(); |
| 129 | } |
| 130 | } |
| 131 | }; |
| 132 | |
| 133 | struct FirstStateVector { |
| 134 | Vector *value; |
| 135 | }; |
| 136 | |
| 137 | template <bool LAST, bool SKIP_NULLS> |
| 138 | struct FirstVectorFunction { |
| 139 | template <class STATE> |
| 140 | static void Initialize(STATE &state) { |
| 141 | state.value = nullptr; |
| 142 | } |
| 143 | |
| 144 | template <class STATE> |
| 145 | static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { |
| 146 | if (state.value) { |
| 147 | delete state.value; |
| 148 | } |
| 149 | } |
| 150 | static bool IgnoreNull() { |
| 151 | return SKIP_NULLS; |
| 152 | } |
| 153 | |
| 154 | template <class STATE> |
| 155 | static void SetValue(STATE &state, Vector &input, const idx_t idx) { |
| 156 | if (!state.value) { |
| 157 | state.value = new Vector(input.GetType()); |
| 158 | state.value->SetVectorType(VectorType::CONSTANT_VECTOR); |
| 159 | } |
| 160 | sel_t selv = idx; |
| 161 | SelectionVector sel(&selv); |
| 162 | VectorOperations::Copy(input, *state.value, sel, 1, 0, 0); |
| 163 | } |
| 164 | |
| 165 | static void Update(Vector inputs[], AggregateInputData &, idx_t input_count, Vector &state_vector, idx_t count) { |
| 166 | auto &input = inputs[0]; |
| 167 | UnifiedVectorFormat idata; |
| 168 | input.ToUnifiedFormat(count, data&: idata); |
| 169 | |
| 170 | UnifiedVectorFormat sdata; |
| 171 | state_vector.ToUnifiedFormat(count, data&: sdata); |
| 172 | |
| 173 | auto states = UnifiedVectorFormat::GetData<FirstStateVector *>(format: sdata); |
| 174 | for (idx_t i = 0; i < count; i++) { |
| 175 | const auto idx = idata.sel->get_index(idx: i); |
| 176 | if (SKIP_NULLS && !idata.validity.RowIsValid(row_idx: idx)) { |
| 177 | continue; |
| 178 | } |
| 179 | auto &state = *states[sdata.sel->get_index(idx: i)]; |
| 180 | if (LAST || !state.value) { |
| 181 | SetValue(state, input, i); |
| 182 | } |
| 183 | } |
| 184 | } |
| 185 | |
| 186 | template <class STATE, class OP> |
| 187 | static void Combine(const STATE &source, STATE &target, AggregateInputData &) { |
| 188 | if (source.value && (LAST || !target.value)) { |
| 189 | SetValue(target, *source.value, 0); |
| 190 | } |
| 191 | } |
| 192 | |
| 193 | template <class STATE> |
| 194 | static void Finalize(STATE &state, AggregateFinalizeData &finalize_data) { |
| 195 | if (!state.value) { |
| 196 | finalize_data.ReturnNull(); |
| 197 | } else { |
| 198 | VectorOperations::Copy(*state.value, finalize_data.result, 1, 0, finalize_data.result_idx); |
| 199 | } |
| 200 | } |
| 201 | |
| 202 | static unique_ptr<FunctionData> Bind(ClientContext &context, AggregateFunction &function, |
| 203 | vector<unique_ptr<Expression>> &arguments) { |
| 204 | function.arguments[0] = arguments[0]->return_type; |
| 205 | function.return_type = arguments[0]->return_type; |
| 206 | return nullptr; |
| 207 | } |
| 208 | }; |
| 209 | |
| 210 | template <class T, bool LAST, bool SKIP_NULLS> |
| 211 | static AggregateFunction GetFirstAggregateTemplated(LogicalType type) { |
| 212 | return AggregateFunction::UnaryAggregate<FirstState<T>, T, T, FirstFunction<LAST, SKIP_NULLS>>(type, type); |
| 213 | } |
| 214 | |
| 215 | template <bool LAST, bool SKIP_NULLS> |
| 216 | static AggregateFunction GetFirstFunction(const LogicalType &type); |
| 217 | |
| 218 | template <bool LAST, bool SKIP_NULLS> |
| 219 | AggregateFunction GetDecimalFirstFunction(const LogicalType &type) { |
| 220 | D_ASSERT(type.id() == LogicalTypeId::DECIMAL); |
| 221 | switch (type.InternalType()) { |
| 222 | case PhysicalType::INT16: |
| 223 | return GetFirstFunction<LAST, SKIP_NULLS>(LogicalType::SMALLINT); |
| 224 | case PhysicalType::INT32: |
| 225 | return GetFirstFunction<LAST, SKIP_NULLS>(LogicalType::INTEGER); |
| 226 | case PhysicalType::INT64: |
| 227 | return GetFirstFunction<LAST, SKIP_NULLS>(LogicalType::BIGINT); |
| 228 | default: |
| 229 | return GetFirstFunction<LAST, SKIP_NULLS>(LogicalType::HUGEINT); |
| 230 | } |
| 231 | } |
| 232 | |
| 233 | template <bool LAST, bool SKIP_NULLS> |
| 234 | static AggregateFunction GetFirstFunction(const LogicalType &type) { |
| 235 | switch (type.id()) { |
| 236 | case LogicalTypeId::BOOLEAN: |
| 237 | return GetFirstAggregateTemplated<int8_t, LAST, SKIP_NULLS>(type); |
| 238 | case LogicalTypeId::TINYINT: |
| 239 | return GetFirstAggregateTemplated<int8_t, LAST, SKIP_NULLS>(type); |
| 240 | case LogicalTypeId::SMALLINT: |
| 241 | return GetFirstAggregateTemplated<int16_t, LAST, SKIP_NULLS>(type); |
| 242 | case LogicalTypeId::INTEGER: |
| 243 | case LogicalTypeId::DATE: |
| 244 | return GetFirstAggregateTemplated<int32_t, LAST, SKIP_NULLS>(type); |
| 245 | case LogicalTypeId::BIGINT: |
| 246 | case LogicalTypeId::TIME: |
| 247 | case LogicalTypeId::TIMESTAMP: |
| 248 | case LogicalTypeId::TIME_TZ: |
| 249 | case LogicalTypeId::TIMESTAMP_TZ: |
| 250 | return GetFirstAggregateTemplated<int64_t, LAST, SKIP_NULLS>(type); |
| 251 | case LogicalTypeId::UTINYINT: |
| 252 | return GetFirstAggregateTemplated<uint8_t, LAST, SKIP_NULLS>(type); |
| 253 | case LogicalTypeId::USMALLINT: |
| 254 | return GetFirstAggregateTemplated<uint16_t, LAST, SKIP_NULLS>(type); |
| 255 | case LogicalTypeId::UINTEGER: |
| 256 | return GetFirstAggregateTemplated<uint32_t, LAST, SKIP_NULLS>(type); |
| 257 | case LogicalTypeId::UBIGINT: |
| 258 | return GetFirstAggregateTemplated<uint64_t, LAST, SKIP_NULLS>(type); |
| 259 | case LogicalTypeId::HUGEINT: |
| 260 | return GetFirstAggregateTemplated<hugeint_t, LAST, SKIP_NULLS>(type); |
| 261 | case LogicalTypeId::FLOAT: |
| 262 | return GetFirstAggregateTemplated<float, LAST, SKIP_NULLS>(type); |
| 263 | case LogicalTypeId::DOUBLE: |
| 264 | return GetFirstAggregateTemplated<double, LAST, SKIP_NULLS>(type); |
| 265 | case LogicalTypeId::INTERVAL: |
| 266 | return GetFirstAggregateTemplated<interval_t, LAST, SKIP_NULLS>(type); |
| 267 | case LogicalTypeId::VARCHAR: |
| 268 | case LogicalTypeId::BLOB: |
| 269 | return AggregateFunction::UnaryAggregateDestructor<FirstState<string_t>, string_t, string_t, |
| 270 | FirstFunctionString<LAST, SKIP_NULLS>>(type, type); |
| 271 | case LogicalTypeId::DECIMAL: { |
| 272 | type.Verify(); |
| 273 | AggregateFunction function = GetDecimalFirstFunction<LAST, SKIP_NULLS>(type); |
| 274 | function.arguments[0] = type; |
| 275 | function.return_type = type; |
| 276 | // TODO set_key here? |
| 277 | return function; |
| 278 | } |
| 279 | default: { |
| 280 | using OP = FirstVectorFunction<LAST, SKIP_NULLS>; |
| 281 | return AggregateFunction({type}, type, AggregateFunction::StateSize<FirstStateVector>, |
| 282 | AggregateFunction::StateInitialize<FirstStateVector, OP>, OP::Update, |
| 283 | AggregateFunction::StateCombine<FirstStateVector, OP>, |
| 284 | AggregateFunction::StateVoidFinalize<FirstStateVector, OP>, nullptr, OP::Bind, |
| 285 | AggregateFunction::StateDestroy<FirstStateVector, OP>, nullptr, nullptr); |
| 286 | } |
| 287 | } |
| 288 | } |
| 289 | |
| 290 | AggregateFunction FirstFun::GetFunction(const LogicalType &type) { |
| 291 | auto fun = GetFirstFunction<false, false>(type); |
| 292 | fun.name = "first" ; |
| 293 | return fun; |
| 294 | } |
| 295 | |
| 296 | template <bool LAST, bool SKIP_NULLS> |
| 297 | unique_ptr<FunctionData> BindDecimalFirst(ClientContext &context, AggregateFunction &function, |
| 298 | vector<unique_ptr<Expression>> &arguments) { |
| 299 | auto decimal_type = arguments[0]->return_type; |
| 300 | auto name = std::move(function.name); |
| 301 | function = GetFirstFunction<LAST, SKIP_NULLS>(decimal_type); |
| 302 | function.name = std::move(name); |
| 303 | function.return_type = decimal_type; |
| 304 | return nullptr; |
| 305 | } |
| 306 | |
| 307 | template <bool LAST, bool SKIP_NULLS> |
| 308 | static AggregateFunction GetFirstOperator(const LogicalType &type) { |
| 309 | if (type.id() == LogicalTypeId::DECIMAL) { |
| 310 | throw InternalException("FIXME: this shouldn't happen..." ); |
| 311 | } |
| 312 | return GetFirstFunction<LAST, SKIP_NULLS>(type); |
| 313 | } |
| 314 | |
| 315 | template <bool LAST, bool SKIP_NULLS> |
| 316 | unique_ptr<FunctionData> BindFirst(ClientContext &context, AggregateFunction &function, |
| 317 | vector<unique_ptr<Expression>> &arguments) { |
| 318 | auto input_type = arguments[0]->return_type; |
| 319 | auto name = std::move(function.name); |
| 320 | function = GetFirstOperator<LAST, SKIP_NULLS>(input_type); |
| 321 | function.name = std::move(name); |
| 322 | if (function.bind) { |
| 323 | return function.bind(context, function, arguments); |
| 324 | } else { |
| 325 | return nullptr; |
| 326 | } |
| 327 | } |
| 328 | |
| 329 | template <bool LAST, bool SKIP_NULLS> |
| 330 | static void AddFirstOperator(AggregateFunctionSet &set) { |
| 331 | set.AddFunction(function: AggregateFunction({LogicalTypeId::DECIMAL}, LogicalTypeId::DECIMAL, nullptr, nullptr, nullptr, |
| 332 | nullptr, nullptr, nullptr, BindDecimalFirst<LAST, SKIP_NULLS>)); |
| 333 | set.AddFunction(function: AggregateFunction({LogicalType::ANY}, LogicalType::ANY, nullptr, nullptr, nullptr, nullptr, nullptr, |
| 334 | nullptr, BindFirst<LAST, SKIP_NULLS>)); |
| 335 | } |
| 336 | |
| 337 | void FirstFun::RegisterFunction(BuiltinFunctions &set) { |
| 338 | AggregateFunctionSet first("first" ); |
| 339 | AggregateFunctionSet last("last" ); |
| 340 | AggregateFunctionSet any_value("any_value" ); |
| 341 | |
| 342 | AddFirstOperator<false, false>(set&: first); |
| 343 | AddFirstOperator<true, false>(set&: last); |
| 344 | AddFirstOperator<false, true>(set&: any_value); |
| 345 | |
| 346 | set.AddFunction(set: first); |
| 347 | first.name = "arbitrary" ; |
| 348 | set.AddFunction(set: first); |
| 349 | |
| 350 | set.AddFunction(set: last); |
| 351 | |
| 352 | set.AddFunction(set: any_value); |
| 353 | } |
| 354 | |
| 355 | } // namespace duckdb |
| 356 | |