| 1 | #include "duckdb/function/aggregate/distributive_functions.hpp" |
| 2 | #include "duckdb/common/exception.hpp" |
| 3 | #include "duckdb/common/types/null_value.hpp" |
| 4 | #include "duckdb/common/vector_operations/vector_operations.hpp" |
| 5 | |
| 6 | using namespace std; |
| 7 | |
| 8 | namespace duckdb { |
| 9 | |
// Per-group state of the FIRST aggregate.
// `value` is only meaningful once `is_set` is true.
template <typename T>
struct FirstState {
    // true once a row has been recorded in this state
    bool is_set;
    // the recorded value
    T value;
};
| 14 | |
// Operations shared by every FIRST aggregate variant.
struct FirstFunctionBase {
    // Marks a fresh state as empty; `value` is left untouched.
    template <class STATE>
    static void Initialize(STATE *state) {
        state->is_set = false;
    }

    // Merges `source` into `target`: a target that already holds a value wins,
    // otherwise the whole source state (flag and value) is copied over.
    template <class STATE, class OP>
    static void Combine(STATE source, STATE *target) {
        if (target->is_set) {
            return;
        }
        *target = source;
    }

    // NULL rows are passed to the aggregate rather than skipped.
    static bool IgnoreNull() {
        return false;
    }
};
| 30 | |
| 31 | struct FirstFunction : public FirstFunctionBase { |
| 32 | template <class INPUT_TYPE, class STATE, class OP> |
| 33 | static void Operation(STATE *state, INPUT_TYPE *input, nullmask_t &nullmask, idx_t idx) { |
| 34 | if (!state->is_set) { |
| 35 | state->is_set = true; |
| 36 | if (nullmask[idx]) { |
| 37 | state->value = NullValue<INPUT_TYPE>(); |
| 38 | } else { |
| 39 | state->value = input[idx]; |
| 40 | } |
| 41 | } |
| 42 | } |
| 43 | |
| 44 | template <class INPUT_TYPE, class STATE, class OP> |
| 45 | static void ConstantOperation(STATE *state, INPUT_TYPE *input, nullmask_t &nullmask, idx_t count) { |
| 46 | Operation<INPUT_TYPE, STATE, OP>(state, input, nullmask, 0); |
| 47 | } |
| 48 | |
| 49 | template <class T, class STATE> |
| 50 | static void Finalize(Vector &result, STATE *state, T *target, nullmask_t &nullmask, idx_t idx) { |
| 51 | if (!state->is_set || IsNullValue<T>(state->value)) { |
| 52 | nullmask[idx] = true; |
| 53 | } else { |
| 54 | target[idx] = state->value; |
| 55 | } |
| 56 | } |
| 57 | }; |
| 58 | |
| 59 | struct FirstFunctionString : public FirstFunctionBase { |
| 60 | template <class INPUT_TYPE, class STATE, class OP> |
| 61 | static void Operation(STATE *state, INPUT_TYPE *input, nullmask_t &nullmask, idx_t idx) { |
| 62 | if (!state->is_set) { |
| 63 | state->is_set = true; |
| 64 | if (nullmask[idx]) { |
| 65 | state->value = NullValue<INPUT_TYPE>(); |
| 66 | } else { |
| 67 | if (input[idx].IsInlined()) { |
| 68 | state->value = input[idx]; |
| 69 | } else { |
| 70 | // non-inlined string, need to allocate space for it |
| 71 | auto len = input[idx].GetSize(); |
| 72 | auto ptr = new char[len + 1]; |
| 73 | memcpy(ptr, input[idx].GetData(), len + 1); |
| 74 | |
| 75 | state->value = string_t(ptr, len); |
| 76 | } |
| 77 | } |
| 78 | } |
| 79 | } |
| 80 | |
| 81 | template <class INPUT_TYPE, class STATE, class OP> |
| 82 | static void ConstantOperation(STATE *state, INPUT_TYPE *input, nullmask_t &nullmask, idx_t count) { |
| 83 | Operation<INPUT_TYPE, STATE, OP>(state, input, nullmask, 0); |
| 84 | } |
| 85 | |
| 86 | template <class T, class STATE> |
| 87 | static void Finalize(Vector &result, STATE *state, T *target, nullmask_t &nullmask, idx_t idx) { |
| 88 | if (!state->is_set || IsNullValue<T>(state->value)) { |
| 89 | nullmask[idx] = true; |
| 90 | } else { |
| 91 | target[idx] = StringVector::AddString(result, state->value); |
| 92 | } |
| 93 | } |
| 94 | |
| 95 | template <class STATE> static void Destroy(STATE *state) { |
| 96 | if (state->is_set && !state->value.IsInlined()) { |
| 97 | delete[] state->value.GetData(); |
| 98 | } |
| 99 | } |
| 100 | }; |
| 101 | |
| 102 | template <class T> static AggregateFunction GetFirstAggregateTemplated(SQLType type) { |
| 103 | return AggregateFunction::UnaryAggregate<FirstState<T>, T, T, FirstFunction>(type, type); |
| 104 | } |
| 105 | |
| 106 | AggregateFunction FirstFun::GetFunction(SQLType type) { |
| 107 | switch (type.id) { |
| 108 | case SQLTypeId::BOOLEAN: |
| 109 | return GetFirstAggregateTemplated<int8_t>(type); |
| 110 | case SQLTypeId::TINYINT: |
| 111 | return GetFirstAggregateTemplated<int8_t>(type); |
| 112 | case SQLTypeId::SMALLINT: |
| 113 | return GetFirstAggregateTemplated<int16_t>(type); |
| 114 | case SQLTypeId::INTEGER: |
| 115 | return GetFirstAggregateTemplated<int32_t>(type); |
| 116 | case SQLTypeId::BIGINT: |
| 117 | return GetFirstAggregateTemplated<int64_t>(type); |
| 118 | case SQLTypeId::FLOAT: |
| 119 | return GetFirstAggregateTemplated<float>(type); |
| 120 | case SQLTypeId::DOUBLE: |
| 121 | return GetFirstAggregateTemplated<double>(type); |
| 122 | case SQLTypeId::DECIMAL: |
| 123 | return GetFirstAggregateTemplated<double>(type); |
| 124 | case SQLTypeId::DATE: |
| 125 | return GetFirstAggregateTemplated<date_t>(type); |
| 126 | case SQLTypeId::TIMESTAMP: |
| 127 | return GetFirstAggregateTemplated<timestamp_t>(type); |
| 128 | case SQLTypeId::VARCHAR: |
| 129 | case SQLTypeId::BLOB: |
| 130 | return AggregateFunction::UnaryAggregateDestructor<FirstState<string_t>, string_t, string_t, |
| 131 | FirstFunctionString>(type, type); |
| 132 | default: |
| 133 | throw NotImplementedException("Unimplemented type for FIRST aggregate" ); |
| 134 | } |
| 135 | } |
| 136 | |
| 137 | void FirstFun::RegisterFunction(BuiltinFunctions &set) { |
| 138 | AggregateFunctionSet first("first" ); |
| 139 | for (auto type : SQLType::ALL_TYPES) { |
| 140 | first.AddFunction(FirstFun::GetFunction(type)); |
| 141 | } |
| 142 | set.AddFunction(first); |
| 143 | } |
| 144 | |
| 145 | } // namespace duckdb |
| 146 | |