1#include <DataTypes/DataTypesNumber.h>
2#include <DataTypes/DataTypesDecimal.h>
3#include <Columns/ColumnsNumber.h>
4#include <Columns/ColumnDecimal.h>
5#include "FunctionArrayMapped.h"
6#include <Functions/FunctionFactory.h>
7
8
9namespace DB
10{
11
/// Error codes used in this file. They are only declared here (`extern`, no initializer),
/// so their definitions must come from another translation unit.
namespace ErrorCodes
{
    /// Thrown by ArrayCumSumNonNegativeImpl::execute for an unsupported mapped column.
    extern const int ILLEGAL_COLUMN;
    /// Thrown by ArrayCumSumNonNegativeImpl::getReturnType for an unsupported element type.
    /// Declared explicitly so this file does not depend on an included header declaring it.
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
16
17/** arrayCumSumNonNegative() - returns an array with cumulative sums of the original. (If value < 0 -> 0).
18 */
19struct ArrayCumSumNonNegativeImpl
20{
21 static bool useDefaultImplementationForConstants() { return true; }
22 static bool needBoolean() { return false; }
23 static bool needExpression() { return false; }
24 static bool needOneArray() { return false; }
25
26 static DataTypePtr getReturnType(const DataTypePtr & expression_return, const DataTypePtr & /*array_element*/)
27 {
28 WhichDataType which(expression_return);
29
30 if (which.isNativeUInt())
31 return std::make_shared<DataTypeArray>(std::make_shared<DataTypeUInt64>());
32
33 if (which.isNativeInt())
34 return std::make_shared<DataTypeArray>(std::make_shared<DataTypeInt64>());
35
36 if (which.isFloat())
37 return std::make_shared<DataTypeArray>(std::make_shared<DataTypeFloat64>());
38
39 if (which.isDecimal())
40 {
41 UInt32 scale = getDecimalScale(*expression_return);
42 DataTypePtr nested = std::make_shared<DataTypeDecimal<Decimal128>>(DecimalUtils::maxPrecision<Decimal128>(), scale);
43 return std::make_shared<DataTypeArray>(nested);
44 }
45
46 throw Exception("arrayCumSumNonNegativeImpl cannot add values of type " + expression_return->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
47 }
48
49
50 template <typename Element, typename Result>
51 static bool executeType(const ColumnPtr & mapped, const ColumnArray & array, ColumnPtr & res_ptr)
52 {
53 using ColVecType = std::conditional_t<IsDecimalNumber<Element>, ColumnDecimal<Element>, ColumnVector<Element>>;
54 using ColVecResult = std::conditional_t<IsDecimalNumber<Result>, ColumnDecimal<Result>, ColumnVector<Result>>;
55
56 const ColVecType * column = checkAndGetColumn<ColVecType>(&*mapped);
57
58 if (!column)
59 return false;
60
61 const IColumn::Offsets & offsets = array.getOffsets();
62 const typename ColVecType::Container & data = column->getData();
63
64 typename ColVecResult::MutablePtr res_nested;
65 if constexpr (IsDecimalNumber<Element>)
66 res_nested = ColVecResult::create(0, data.getScale());
67 else
68 res_nested = ColVecResult::create();
69
70 typename ColVecResult::Container & res_values = res_nested->getData();
71 res_values.resize(data.size());
72
73 size_t pos = 0;
74 Result accum_sum = 0;
75 for (size_t i = 0; i < offsets.size(); ++i)
76 {
77 // skip empty arrays
78 if (pos < offsets[i])
79 {
80 accum_sum = data[pos] > 0 ? data[pos] : Element(0);
81 res_values[pos] = accum_sum;
82 for (++pos; pos < offsets[i]; ++pos)
83 {
84 accum_sum = accum_sum + data[pos];
85 if (accum_sum < 0)
86 accum_sum = 0;
87
88 res_values[pos] = accum_sum;
89 }
90 }
91 }
92 res_ptr = ColumnArray::create(std::move(res_nested), array.getOffsetsPtr());
93 return true;
94
95 }
96
97 static ColumnPtr execute(const ColumnArray & array, ColumnPtr mapped)
98 {
99 ColumnPtr res;
100
101 if (executeType< UInt8 , UInt64>(mapped, array, res) ||
102 executeType< UInt16, UInt64>(mapped, array, res) ||
103 executeType< UInt32, UInt64>(mapped, array, res) ||
104 executeType< UInt64, UInt64>(mapped, array, res) ||
105 executeType< Int8 , Int64>(mapped, array, res) ||
106 executeType< Int16, Int64>(mapped, array, res) ||
107 executeType< Int32, Int64>(mapped, array, res) ||
108 executeType< Int64, Int64>(mapped, array, res) ||
109 executeType<Float32,Float64>(mapped, array, res) ||
110 executeType<Float64,Float64>(mapped, array, res) ||
111 executeType<Decimal32, Decimal128>(mapped, array, res) ||
112 executeType<Decimal64, Decimal128>(mapped, array, res) ||
113 executeType<Decimal128, Decimal128>(mapped, array, res))
114 return res;
115 else
116 throw Exception("Unexpected column for arrayCumSumNonNegativeImpl: " + mapped->getName(), ErrorCodes::ILLEGAL_COLUMN);
117 }
118
119};
120
/// Name tag: holds the string name "arrayCumSumNonNegative" as a compile-time constant.
struct NameArrayCumSumNonNegative
{
    static constexpr auto name = "arrayCumSumNonNegative";
};
122using FunctionArrayCumSumNonNegative = FunctionArrayMapped<ArrayCumSumNonNegativeImpl, NameArrayCumSumNonNegative>;
123
124void registerFunctionArrayCumSumNonNegative(FunctionFactory & factory)
125{
126 factory.registerFunction<FunctionArrayCumSumNonNegative>();
127}
128
129}
130
131