1 | #include <DataTypes/DataTypesNumber.h> |
2 | #include <DataTypes/DataTypesDecimal.h> |
3 | #include <Columns/ColumnsNumber.h> |
4 | #include <Columns/ColumnDecimal.h> |
5 | #include "FunctionArrayMapped.h" |
6 | #include <Functions/FunctionFactory.h> |
7 | |
8 | |
9 | namespace DB |
10 | { |
11 | |
namespace ErrorCodes
{
    /// Error codes referenced in this file; the definitions live elsewhere.
    extern const int ILLEGAL_COLUMN;
    /// Thrown by ArrayCumSumImpl::getReturnType for unsupported element types. Declared here
    /// because this file uses it directly rather than relying on a transitive declaration.
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
16 | |
17 | struct ArrayCumSumImpl |
18 | { |
19 | static bool needBoolean() { return false; } |
20 | static bool needExpression() { return false; } |
21 | static bool needOneArray() { return false; } |
22 | |
23 | static DataTypePtr getReturnType(const DataTypePtr & expression_return, const DataTypePtr & /*array_element*/) |
24 | { |
25 | WhichDataType which(expression_return); |
26 | |
27 | if (which.isNativeUInt()) |
28 | return std::make_shared<DataTypeArray>(std::make_shared<DataTypeUInt64>()); |
29 | |
30 | if (which.isNativeInt()) |
31 | return std::make_shared<DataTypeArray>(std::make_shared<DataTypeInt64>()); |
32 | |
33 | if (which.isFloat()) |
34 | return std::make_shared<DataTypeArray>(std::make_shared<DataTypeFloat64>()); |
35 | |
36 | if (which.isDecimal()) |
37 | { |
38 | UInt32 scale = getDecimalScale(*expression_return); |
39 | DataTypePtr nested = std::make_shared<DataTypeDecimal<Decimal128>>(DecimalUtils::maxPrecision<Decimal128>(), scale); |
40 | return std::make_shared<DataTypeArray>(nested); |
41 | } |
42 | |
43 | throw Exception("arrayCumSum cannot add values of type " + expression_return->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
44 | } |
45 | |
46 | |
47 | template <typename Element, typename Result> |
48 | static bool executeType(const ColumnPtr & mapped, const ColumnArray & array, ColumnPtr & res_ptr) |
49 | { |
50 | using ColVecType = std::conditional_t<IsDecimalNumber<Element>, ColumnDecimal<Element>, ColumnVector<Element>>; |
51 | using ColVecResult = std::conditional_t<IsDecimalNumber<Result>, ColumnDecimal<Result>, ColumnVector<Result>>; |
52 | |
53 | const ColVecType * column = checkAndGetColumn<ColVecType>(&*mapped); |
54 | |
55 | if (!column) |
56 | { |
57 | const ColumnConst * column_const = checkAndGetColumnConst<ColVecType>(&*mapped); |
58 | |
59 | if (!column_const) |
60 | return false; |
61 | |
62 | const Element x = column_const->template getValue<Element>(); |
63 | const IColumn::Offsets & offsets = array.getOffsets(); |
64 | |
65 | typename ColVecResult::MutablePtr res_nested; |
66 | if constexpr (IsDecimalNumber<Element>) |
67 | { |
68 | const typename ColVecType::Container & data = |
69 | checkAndGetColumn<ColVecType>(&column_const->getDataColumn())->getData(); |
70 | res_nested = ColVecResult::create(0, data.getScale()); |
71 | } |
72 | else |
73 | res_nested = ColVecResult::create(); |
74 | |
75 | typename ColVecResult::Container & res_values = res_nested->getData(); |
76 | res_values.resize(column_const->size()); |
77 | |
78 | size_t pos = 0; |
79 | for (size_t i = 0; i < offsets.size(); ++i) |
80 | { |
81 | // skip empty arrays |
82 | if (pos < offsets[i]) |
83 | { |
84 | res_values[pos++] = x; |
85 | for (; pos < offsets[i]; ++pos) |
86 | { |
87 | res_values[pos] = res_values[pos - 1] + x; |
88 | } |
89 | } |
90 | } |
91 | |
92 | res_ptr = ColumnArray::create(std::move(res_nested), array.getOffsetsPtr()); |
93 | return true; |
94 | } |
95 | |
96 | const typename ColVecType::Container & data = column->getData(); |
97 | const IColumn::Offsets & offsets = array.getOffsets(); |
98 | |
99 | typename ColVecResult::MutablePtr res_nested; |
100 | if constexpr (IsDecimalNumber<Element>) |
101 | res_nested = ColVecResult::create(0, data.getScale()); |
102 | else |
103 | res_nested = ColVecResult::create(); |
104 | |
105 | typename ColVecResult::Container & res_values = res_nested->getData(); |
106 | res_values.resize(data.size()); |
107 | |
108 | size_t pos = 0; |
109 | for (size_t i = 0; i < offsets.size(); ++i) |
110 | { |
111 | // skip empty arrays |
112 | if (pos < offsets[i]) |
113 | { |
114 | res_values[pos] = data[pos]; |
115 | for (++pos; pos < offsets[i]; ++pos) |
116 | { |
117 | res_values[pos] = res_values[pos - 1] + data[pos]; |
118 | } |
119 | } |
120 | } |
121 | res_ptr = ColumnArray::create(std::move(res_nested), array.getOffsetsPtr()); |
122 | return true; |
123 | |
124 | } |
125 | |
126 | static ColumnPtr execute(const ColumnArray & array, ColumnPtr mapped) |
127 | { |
128 | ColumnPtr res; |
129 | |
130 | if (executeType< UInt8 , UInt64>(mapped, array, res) || |
131 | executeType< UInt16, UInt64>(mapped, array, res) || |
132 | executeType< UInt32, UInt64>(mapped, array, res) || |
133 | executeType< UInt64, UInt64>(mapped, array, res) || |
134 | executeType< Int8 , Int64>(mapped, array, res) || |
135 | executeType< Int16, Int64>(mapped, array, res) || |
136 | executeType< Int32, Int64>(mapped, array, res) || |
137 | executeType< Int64, Int64>(mapped, array, res) || |
138 | executeType<Float32,Float64>(mapped, array, res) || |
139 | executeType<Float64,Float64>(mapped, array, res) || |
140 | executeType<Decimal32, Decimal128>(mapped, array, res) || |
141 | executeType<Decimal64, Decimal128>(mapped, array, res) || |
142 | executeType<Decimal128, Decimal128>(mapped, array, res)) |
143 | return res; |
144 | else |
145 | throw Exception("Unexpected column for arrayCumSum: " + mapped->getName(), ErrorCodes::ILLEGAL_COLUMN); |
146 | } |
147 | |
148 | }; |
149 | |
150 | struct NameArrayCumSum { static constexpr auto name = "arrayCumSum" ; }; |
151 | using FunctionArrayCumSum = FunctionArrayMapped<ArrayCumSumImpl, NameArrayCumSum>; |
152 | |
153 | void registerFunctionArrayCumSum(FunctionFactory & factory) |
154 | { |
155 | factory.registerFunction<FunctionArrayCumSum>(); |
156 | } |
157 | |
158 | } |
159 | |
160 | |