1#include <DataTypes/DataTypesNumber.h>
2#include <DataTypes/DataTypesDecimal.h>
3#include <Columns/ColumnsNumber.h>
4#include <Columns/ColumnDecimal.h>
5#include "FunctionArrayMapped.h"
6#include <Functions/FunctionFactory.h>
7
8
9namespace DB
10{
11
/// Error codes thrown from this file (defined elsewhere in the project).
namespace ErrorCodes
{
    extern const int ILLEGAL_COLUMN;
    /// Thrown by ArraySumImpl::getReturnType for non-summable element types.
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
16
17struct ArraySumImpl
18{
19 static bool needBoolean() { return false; }
20 static bool needExpression() { return false; }
21 static bool needOneArray() { return false; }
22
23 static DataTypePtr getReturnType(const DataTypePtr & expression_return, const DataTypePtr & /*array_element*/)
24 {
25 WhichDataType which(expression_return);
26
27 if (which.isNativeUInt())
28 return std::make_shared<DataTypeUInt64>();
29
30 if (which.isNativeInt())
31 return std::make_shared<DataTypeInt64>();
32
33 if (which.isFloat())
34 return std::make_shared<DataTypeFloat64>();
35
36 if (which.isDecimal())
37 {
38 UInt32 scale = getDecimalScale(*expression_return);
39 return std::make_shared<DataTypeDecimal<Decimal128>>(DecimalUtils::maxPrecision<Decimal128>(), scale);
40 }
41
42 throw Exception("arraySum cannot add values of type " + expression_return->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
43 }
44
45 template <typename Element, typename Result>
46 static bool executeType(const ColumnPtr & mapped, const ColumnArray::Offsets & offsets, ColumnPtr & res_ptr)
47 {
48 using ColVecType = std::conditional_t<IsDecimalNumber<Element>, ColumnDecimal<Element>, ColumnVector<Element>>;
49 using ColVecResult = std::conditional_t<IsDecimalNumber<Result>, ColumnDecimal<Result>, ColumnVector<Result>>;
50
51 const ColVecType * column = checkAndGetColumn<ColVecType>(&*mapped);
52
53 if (!column)
54 {
55 const ColumnConst * column_const = checkAndGetColumnConst<ColVecType>(&*mapped);
56
57 if (!column_const)
58 return false;
59
60 const Element x = column_const->template getValue<Element>();
61
62 typename ColVecResult::MutablePtr res_column;
63 if constexpr (IsDecimalNumber<Element>)
64 {
65 const typename ColVecType::Container & data =
66 checkAndGetColumn<ColVecType>(&column_const->getDataColumn())->getData();
67 res_column = ColVecResult::create(offsets.size(), data.getScale());
68 }
69 else
70 res_column = ColVecResult::create(offsets.size());
71
72 typename ColVecResult::Container & res = res_column->getData();
73
74 size_t pos = 0;
75 for (size_t i = 0; i < offsets.size(); ++i)
76 {
77 res[i] = x * (offsets[i] - pos);
78 pos = offsets[i];
79 }
80
81 res_ptr = std::move(res_column);
82 return true;
83 }
84
85 const typename ColVecType::Container & data = column->getData();
86
87 typename ColVecResult::MutablePtr res_column;
88 if constexpr (IsDecimalNumber<Element>)
89 res_column = ColVecResult::create(offsets.size(), data.getScale());
90 else
91 res_column = ColVecResult::create(offsets.size());
92
93 typename ColVecResult::Container & res = res_column->getData();
94
95 size_t pos = 0;
96 for (size_t i = 0; i < offsets.size(); ++i)
97 {
98 Result s = 0;
99 for (; pos < offsets[i]; ++pos)
100 {
101 s += data[pos];
102 }
103 res[i] = s;
104 }
105
106 res_ptr = std::move(res_column);
107 return true;
108 }
109
110 static ColumnPtr execute(const ColumnArray & array, ColumnPtr mapped)
111 {
112 const IColumn::Offsets & offsets = array.getOffsets();
113 ColumnPtr res;
114
115 if (executeType< UInt8 , UInt64>(mapped, offsets, res) ||
116 executeType< UInt16, UInt64>(mapped, offsets, res) ||
117 executeType< UInt32, UInt64>(mapped, offsets, res) ||
118 executeType< UInt64, UInt64>(mapped, offsets, res) ||
119 executeType< Int8 , Int64>(mapped, offsets, res) ||
120 executeType< Int16, Int64>(mapped, offsets, res) ||
121 executeType< Int32, Int64>(mapped, offsets, res) ||
122 executeType< Int64, Int64>(mapped, offsets, res) ||
123 executeType<Float32,Float64>(mapped, offsets, res) ||
124 executeType<Float64,Float64>(mapped, offsets, res) ||
125 executeType<Decimal32, Decimal128>(mapped, offsets, res) ||
126 executeType<Decimal64, Decimal128>(mapped, offsets, res) ||
127 executeType<Decimal128, Decimal128>(mapped, offsets, res))
128 return res;
129 else
130 throw Exception("Unexpected column for arraySum: " + mapped->getName(), ErrorCodes::ILLEGAL_COLUMN);
131 }
132};
133
/// Name under which FunctionArraySum is exposed.
struct NameArraySum
{
    static constexpr auto name = "arraySum";
};
/// The `arraySum` function: the generic FunctionArrayMapped template instantiated with
/// ArraySumImpl (the summation logic) and NameArraySum (the function name).
using FunctionArraySum = FunctionArrayMapped<ArraySumImpl, NameArraySum>;
136
137void registerFunctionArraySum(FunctionFactory & factory)
138{
139 factory.registerFunction<FunctionArraySum>();
140}
141
142}
143
144