1#include <Columns/ColumnString.h>
2#include <Columns/ColumnVector.h>
3#include <DataTypes/DataTypeString.h>
4#include <DataTypes/DataTypesNumber.h>
5#include <Functions/FunctionFactory.h>
6#include <Functions/FunctionHelpers.h>
7#include <Functions/IFunctionImpl.h>
8#include <Functions/castTypeToEither.h>
9
10
11namespace DB
12{
namespace ErrorCodes
{
    /// Raised by FunctionRepeat when the argument columns/types are not supported.
    extern const int ILLEGAL_COLUMN;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;

    /// Raised by RepeatImpl::checkRepeatTime when the repeat count exceeds its limit.
    extern const int TOO_LARGE_STRING_SIZE;
}
19
20struct RepeatImpl
21{
22 /// Safety threshold against DoS.
23 static inline void checkRepeatTime(UInt64 repeat_time)
24 {
25 static constexpr UInt64 max_repeat_times = 1000000;
26 if (repeat_time > max_repeat_times)
27 throw Exception("Too many times to repeat (" + std::to_string(repeat_time) + "), maximum is: " + std::to_string(max_repeat_times),
28 ErrorCodes::TOO_LARGE_STRING_SIZE);
29 }
30
31 static void vectorStrConstRepeat(
32 const ColumnString::Chars & data,
33 const ColumnString::Offsets & offsets,
34 ColumnString::Chars & res_data,
35 ColumnString::Offsets & res_offsets,
36 UInt64 repeat_time)
37 {
38 checkRepeatTime(repeat_time);
39
40 UInt64 data_size = 0;
41 res_offsets.assign(offsets);
42 for (UInt64 i = 0; i < offsets.size(); ++i)
43 {
44 data_size += (offsets[i] - offsets[i - 1] - 1) * repeat_time + 1; /// Note that accessing -1th element is valid for PaddedPODArray.
45 res_offsets[i] = data_size;
46 }
47 res_data.resize(data_size);
48 for (UInt64 i = 0; i < res_offsets.size(); ++i)
49 {
50 process(data.data() + offsets[i - 1], res_data.data() + res_offsets[i - 1], offsets[i] - offsets[i - 1], repeat_time);
51 }
52 }
53
54 template <typename T>
55 static void vectorStrVectorRepeat(
56 const ColumnString::Chars & data,
57 const ColumnString::Offsets & offsets,
58 ColumnString::Chars & res_data,
59 ColumnString::Offsets & res_offsets,
60 const PaddedPODArray<T> & col_num)
61 {
62 UInt64 data_size = 0;
63 res_offsets.assign(offsets);
64 for (UInt64 i = 0; i < col_num.size(); ++i)
65 {
66 data_size += (offsets[i] - offsets[i - 1] - 1) * col_num[i] + 1;
67 res_offsets[i] = data_size;
68 }
69 res_data.resize(data_size);
70
71 for (UInt64 i = 0; i < col_num.size(); ++i)
72 {
73 T repeat_time = col_num[i];
74 checkRepeatTime(repeat_time);
75 process(data.data() + offsets[i - 1], res_data.data() + res_offsets[i - 1], offsets[i] - offsets[i - 1], repeat_time);
76 }
77 }
78
79 template <typename T>
80 static void constStrVectorRepeat(
81 const StringRef & copy_str,
82 ColumnString::Chars & res_data,
83 ColumnString::Offsets & res_offsets,
84 const PaddedPODArray<T> & col_num)
85 {
86 UInt64 data_size = 0;
87 res_offsets.resize(col_num.size());
88 UInt64 str_size = copy_str.size;
89 UInt64 col_size = col_num.size();
90 for (UInt64 i = 0; i < col_size; ++i)
91 {
92 data_size += str_size * col_num[i] + 1;
93 res_offsets[i] = data_size;
94 }
95 res_data.resize(data_size);
96 for (UInt64 i = 0; i < col_size; ++i)
97 {
98 T repeat_time = col_num[i];
99 checkRepeatTime(repeat_time);
100 process(
101 reinterpret_cast<UInt8 *>(const_cast<char *>(copy_str.data)),
102 res_data.data() + res_offsets[i - 1],
103 str_size + 1,
104 repeat_time);
105 }
106 }
107
108private:
109 static void process(const UInt8 * src, UInt8 * dst, UInt64 size, UInt64 repeat_time)
110 {
111 for (UInt64 i = 0; i < repeat_time; ++i)
112 {
113 memcpy(dst, src, size - 1);
114 dst += size - 1;
115 }
116 *dst = 0;
117 }
118};
119
120
121class FunctionRepeat : public IFunction
122{
123 template <typename F>
124 static bool castType(const IDataType * type, F && f)
125 {
126 return castTypeToEither<DataTypeUInt8, DataTypeUInt16, DataTypeUInt32, DataTypeUInt64>(type, std::forward<F>(f));
127 }
128
129public:
130 static constexpr auto name = "repeat";
131 static FunctionPtr create(const Context &) { return std::make_shared<FunctionRepeat>(); }
132
133 String getName() const override { return name; }
134
135 size_t getNumberOfArguments() const override { return 2; }
136
137 DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
138 {
139 if (!isString(arguments[0]))
140 throw Exception(
141 "Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
142 if (!isUnsignedInteger(arguments[1]))
143 throw Exception(
144 "Illegal type " + arguments[1]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
145 return arguments[0];
146 }
147
148 bool useDefaultImplementationForConstants() const override { return true; }
149
150 void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t) override
151 {
152 const auto & strcolumn = block.getByPosition(arguments[0]).column;
153 const auto & numcolumn = block.getByPosition(arguments[1]).column;
154
155 if (const ColumnString * col = checkAndGetColumn<ColumnString>(strcolumn.get()))
156 {
157 if (const ColumnConst * scale_column_num = checkAndGetColumn<ColumnConst>(numcolumn.get()))
158 {
159 UInt64 repeat_time = scale_column_num->getValue<UInt64>();
160 auto col_res = ColumnString::create();
161 RepeatImpl::vectorStrConstRepeat(col->getChars(), col->getOffsets(), col_res->getChars(), col_res->getOffsets(), repeat_time);
162 block.getByPosition(result).column = std::move(col_res);
163 return;
164 }
165 else if (castType(block.getByPosition(arguments[1]).type.get(), [&](const auto & type)
166 {
167 using DataType = std::decay_t<decltype(type)>;
168 using T = typename DataType::FieldType;
169 const ColumnVector<T> * colnum = checkAndGetColumn<ColumnVector<T>>(numcolumn.get());
170 auto col_res = ColumnString::create();
171 RepeatImpl::vectorStrVectorRepeat(col->getChars(), col->getOffsets(), col_res->getChars(), col_res->getOffsets(), colnum->getData());
172 block.getByPosition(result).column = std::move(col_res);
173 return true;
174 }))
175 {
176 return;
177 }
178 }
179 else if (const ColumnConst * col_const = checkAndGetColumn<ColumnConst>(strcolumn.get()))
180 {
181 /// Note that const-const case is handled by useDefaultImplementationForConstants.
182
183 StringRef copy_str = col_const->getDataColumn().getDataAt(0);
184
185 if (castType(block.getByPosition(arguments[1]).type.get(), [&](const auto & type)
186 {
187 using DataType = std::decay_t<decltype(type)>;
188 using T = typename DataType::FieldType;
189 const ColumnVector<T> * colnum = checkAndGetColumn<ColumnVector<T>>(numcolumn.get());
190 auto col_res = ColumnString::create();
191 RepeatImpl::constStrVectorRepeat(copy_str, col_res->getChars(), col_res->getOffsets(), colnum->getData());
192 block.getByPosition(result).column = std::move(col_res);
193 return true;
194 }))
195 {
196 return;
197 }
198 }
199
200 throw Exception(
201 "Illegal column " + block.getByPosition(arguments[0]).column->getName() + " of argument of function " + getName(),
202 ErrorCodes::ILLEGAL_COLUMN);
203 }
204};
205
206
207void registerFunctionRepeat(FunctionFactory & factory)
208{
209 factory.registerFunction<FunctionRepeat>();
210}
211
212}
213