1#include <DataTypes/DataTypesNumber.h>
2#include <DataTypes/DataTypesDecimal.h>
3#include <Columns/ColumnsNumber.h>
4#include <Columns/ColumnDecimal.h>
5#include "FunctionArrayMapped.h"
6#include <Functions/FunctionFactory.h>
7
8
9namespace DB
10{
11
/// Error codes thrown from this file. Every code used below is declared here, so the file
/// does not rely on some included header happening to declare it.
namespace ErrorCodes
{
    extern const int ILLEGAL_COLUMN;            /// thrown by ArrayCumSumImpl::execute
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;  /// thrown by ArrayCumSumImpl::getReturnType
}
16
17struct ArrayCumSumImpl
18{
19 static bool needBoolean() { return false; }
20 static bool needExpression() { return false; }
21 static bool needOneArray() { return false; }
22
23 static DataTypePtr getReturnType(const DataTypePtr & expression_return, const DataTypePtr & /*array_element*/)
24 {
25 WhichDataType which(expression_return);
26
27 if (which.isNativeUInt())
28 return std::make_shared<DataTypeArray>(std::make_shared<DataTypeUInt64>());
29
30 if (which.isNativeInt())
31 return std::make_shared<DataTypeArray>(std::make_shared<DataTypeInt64>());
32
33 if (which.isFloat())
34 return std::make_shared<DataTypeArray>(std::make_shared<DataTypeFloat64>());
35
36 if (which.isDecimal())
37 {
38 UInt32 scale = getDecimalScale(*expression_return);
39 DataTypePtr nested = std::make_shared<DataTypeDecimal<Decimal128>>(DecimalUtils::maxPrecision<Decimal128>(), scale);
40 return std::make_shared<DataTypeArray>(nested);
41 }
42
43 throw Exception("arrayCumSum cannot add values of type " + expression_return->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
44 }
45
46
47 template <typename Element, typename Result>
48 static bool executeType(const ColumnPtr & mapped, const ColumnArray & array, ColumnPtr & res_ptr)
49 {
50 using ColVecType = std::conditional_t<IsDecimalNumber<Element>, ColumnDecimal<Element>, ColumnVector<Element>>;
51 using ColVecResult = std::conditional_t<IsDecimalNumber<Result>, ColumnDecimal<Result>, ColumnVector<Result>>;
52
53 const ColVecType * column = checkAndGetColumn<ColVecType>(&*mapped);
54
55 if (!column)
56 {
57 const ColumnConst * column_const = checkAndGetColumnConst<ColVecType>(&*mapped);
58
59 if (!column_const)
60 return false;
61
62 const Element x = column_const->template getValue<Element>();
63 const IColumn::Offsets & offsets = array.getOffsets();
64
65 typename ColVecResult::MutablePtr res_nested;
66 if constexpr (IsDecimalNumber<Element>)
67 {
68 const typename ColVecType::Container & data =
69 checkAndGetColumn<ColVecType>(&column_const->getDataColumn())->getData();
70 res_nested = ColVecResult::create(0, data.getScale());
71 }
72 else
73 res_nested = ColVecResult::create();
74
75 typename ColVecResult::Container & res_values = res_nested->getData();
76 res_values.resize(column_const->size());
77
78 size_t pos = 0;
79 for (size_t i = 0; i < offsets.size(); ++i)
80 {
81 // skip empty arrays
82 if (pos < offsets[i])
83 {
84 res_values[pos++] = x;
85 for (; pos < offsets[i]; ++pos)
86 {
87 res_values[pos] = res_values[pos - 1] + x;
88 }
89 }
90 }
91
92 res_ptr = ColumnArray::create(std::move(res_nested), array.getOffsetsPtr());
93 return true;
94 }
95
96 const typename ColVecType::Container & data = column->getData();
97 const IColumn::Offsets & offsets = array.getOffsets();
98
99 typename ColVecResult::MutablePtr res_nested;
100 if constexpr (IsDecimalNumber<Element>)
101 res_nested = ColVecResult::create(0, data.getScale());
102 else
103 res_nested = ColVecResult::create();
104
105 typename ColVecResult::Container & res_values = res_nested->getData();
106 res_values.resize(data.size());
107
108 size_t pos = 0;
109 for (size_t i = 0; i < offsets.size(); ++i)
110 {
111 // skip empty arrays
112 if (pos < offsets[i])
113 {
114 res_values[pos] = data[pos];
115 for (++pos; pos < offsets[i]; ++pos)
116 {
117 res_values[pos] = res_values[pos - 1] + data[pos];
118 }
119 }
120 }
121 res_ptr = ColumnArray::create(std::move(res_nested), array.getOffsetsPtr());
122 return true;
123
124 }
125
126 static ColumnPtr execute(const ColumnArray & array, ColumnPtr mapped)
127 {
128 ColumnPtr res;
129
130 if (executeType< UInt8 , UInt64>(mapped, array, res) ||
131 executeType< UInt16, UInt64>(mapped, array, res) ||
132 executeType< UInt32, UInt64>(mapped, array, res) ||
133 executeType< UInt64, UInt64>(mapped, array, res) ||
134 executeType< Int8 , Int64>(mapped, array, res) ||
135 executeType< Int16, Int64>(mapped, array, res) ||
136 executeType< Int32, Int64>(mapped, array, res) ||
137 executeType< Int64, Int64>(mapped, array, res) ||
138 executeType<Float32,Float64>(mapped, array, res) ||
139 executeType<Float64,Float64>(mapped, array, res) ||
140 executeType<Decimal32, Decimal128>(mapped, array, res) ||
141 executeType<Decimal64, Decimal128>(mapped, array, res) ||
142 executeType<Decimal128, Decimal128>(mapped, array, res))
143 return res;
144 else
145 throw Exception("Unexpected column for arrayCumSum: " + mapped->getName(), ErrorCodes::ILLEGAL_COLUMN);
146 }
147
148};
149
150struct NameArrayCumSum { static constexpr auto name = "arrayCumSum"; };
151using FunctionArrayCumSum = FunctionArrayMapped<ArrayCumSumImpl, NameArrayCumSum>;
152
/// Registers FunctionArrayCumSum (SQL name "arrayCumSum") in `factory`.
/// NOTE(review): presumably invoked from the central function-registration routine — confirm at call site.
void registerFunctionArrayCumSum(FunctionFactory & factory)
{
    factory.registerFunction<FunctionArrayCumSum>();
}
157
158}
159
160