1 | #include <Columns/ColumnString.h> |
2 | #include <DataTypes/DataTypeString.h> |
3 | #include <DataTypes/getLeastSupertype.h> |
4 | #include <Functions/FunctionFactory.h> |
5 | #include <Functions/FunctionHelpers.h> |
6 | #include <Functions/GatherUtils/Algorithms.h> |
7 | #include <Functions/GatherUtils/GatherUtils.h> |
8 | #include <Functions/GatherUtils/Sinks.h> |
9 | #include <Functions/GatherUtils/Slices.h> |
10 | #include <Functions/GatherUtils/Sources.h> |
11 | #include <Functions/IFunctionImpl.h> |
12 | #include <IO/WriteHelpers.h> |
13 | #include <ext/map.h> |
14 | #include <ext/range.h> |
15 | |
16 | #include "formatString.h" |
17 | |
18 | namespace DB |
19 | { |
20 | namespace ErrorCodes |
21 | { |
22 | extern const int ILLEGAL_TYPE_OF_ARGUMENT; |
23 | extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; |
24 | extern const int ILLEGAL_COLUMN; |
25 | } |
26 | |
27 | using namespace GatherUtils; |
28 | |
29 | |
30 | template <typename Name, bool is_injective> |
31 | class ConcatImpl : public IFunction |
32 | { |
33 | public: |
34 | static constexpr auto name = Name::name; |
35 | ConcatImpl(const Context & context_) : context(context_) {} |
36 | static FunctionPtr create(const Context & context) { return std::make_shared<ConcatImpl>(context); } |
37 | |
38 | String getName() const override { return name; } |
39 | |
40 | bool isVariadic() const override { return true; } |
41 | |
42 | size_t getNumberOfArguments() const override { return 0; } |
43 | |
44 | bool isInjective(const Block &) override { return is_injective; } |
45 | |
46 | bool useDefaultImplementationForConstants() const override { return true; } |
47 | |
48 | DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override |
49 | { |
50 | if (arguments.size() < 2) |
51 | throw Exception( |
52 | "Number of arguments for function " + getName() + " doesn't match: passed " + toString(arguments.size()) |
53 | + ", should be at least 2." , |
54 | ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); |
55 | |
56 | if (arguments.size() > FormatImpl::argument_threshold) |
57 | throw Exception( |
58 | "Number of arguments for function " + getName() + " doesn't match: passed " + toString(arguments.size()) |
59 | + ", should be at most " + std::to_string(FormatImpl::argument_threshold), |
60 | ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); |
61 | |
62 | for (const auto arg_idx : ext::range(0, arguments.size())) |
63 | { |
64 | const auto arg = arguments[arg_idx].get(); |
65 | if (!isStringOrFixedString(arg)) |
66 | throw Exception{"Illegal type " + arg->getName() + " of argument " + std::to_string(arg_idx + 1) + " of function " |
67 | + getName(), |
68 | ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; |
69 | } |
70 | |
71 | return std::make_shared<DataTypeString>(); |
72 | } |
73 | |
74 | void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override |
75 | { |
76 | /// Format function is not proven to be faster for two arguments. |
77 | /// Actually there is overhead of 2 to 5 extra instructions for each string for checking empty strings in FormatImpl. |
78 | /// Though, benchmarks are really close, for most examples we saw executeBinary is slightly faster (0-3%). |
79 | /// For 3 and more arguments FormatImpl is much faster (up to 50-60%). |
80 | if (arguments.size() == 2) |
81 | executeBinary(block, arguments, result, input_rows_count); |
82 | else |
83 | executeFormatImpl(block, arguments, result, input_rows_count); |
84 | } |
85 | |
86 | private: |
87 | const Context & context; |
88 | |
89 | void executeBinary(Block & block, const ColumnNumbers & arguments, const size_t result, size_t input_rows_count) |
90 | { |
91 | const IColumn * c0 = block.getByPosition(arguments[0]).column.get(); |
92 | const IColumn * c1 = block.getByPosition(arguments[1]).column.get(); |
93 | |
94 | const ColumnString * c0_string = checkAndGetColumn<ColumnString>(c0); |
95 | const ColumnString * c1_string = checkAndGetColumn<ColumnString>(c1); |
96 | const ColumnConst * c0_const_string = checkAndGetColumnConst<ColumnString>(c0); |
97 | const ColumnConst * c1_const_string = checkAndGetColumnConst<ColumnString>(c1); |
98 | |
99 | auto c_res = ColumnString::create(); |
100 | |
101 | if (c0_string && c1_string) |
102 | concat(StringSource(*c0_string), StringSource(*c1_string), StringSink(*c_res, c0->size())); |
103 | else if (c0_string && c1_const_string) |
104 | concat(StringSource(*c0_string), ConstSource<StringSource>(*c1_const_string), StringSink(*c_res, c0->size())); |
105 | else if (c0_const_string && c1_string) |
106 | concat(ConstSource<StringSource>(*c0_const_string), StringSource(*c1_string), StringSink(*c_res, c0->size())); |
107 | else |
108 | { |
109 | /// Fallback: use generic implementation for not very important cases. |
110 | executeFormatImpl(block, arguments, result, input_rows_count); |
111 | return; |
112 | } |
113 | |
114 | block.getByPosition(result).column = std::move(c_res); |
115 | } |
116 | |
117 | void executeFormatImpl(Block & block, const ColumnNumbers & arguments, const size_t result, size_t input_rows_count) |
118 | { |
119 | auto c_res = ColumnString::create(); |
120 | std::vector<const ColumnString::Chars *> data(arguments.size()); |
121 | std::vector<const ColumnString::Offsets *> offsets(arguments.size()); |
122 | std::vector<size_t> fixed_string_N(arguments.size()); |
123 | std::vector<String> constant_strings(arguments.size()); |
124 | bool has_column_string = false; |
125 | bool has_column_fixed_string = false; |
126 | for (size_t i = 0; i < arguments.size(); ++i) |
127 | { |
128 | const ColumnPtr & column = block.getByPosition(arguments[i]).column; |
129 | if (const ColumnString * col = checkAndGetColumn<ColumnString>(column.get())) |
130 | { |
131 | has_column_string = true; |
132 | data[i] = &col->getChars(); |
133 | offsets[i] = &col->getOffsets(); |
134 | } |
135 | else if (const ColumnFixedString * fixed_col = checkAndGetColumn<ColumnFixedString>(column.get())) |
136 | { |
137 | has_column_fixed_string = true; |
138 | data[i] = &fixed_col->getChars(); |
139 | fixed_string_N[i] = fixed_col->getN(); |
140 | } |
141 | else if (const ColumnConst * const_col = checkAndGetColumnConstStringOrFixedString(column.get())) |
142 | { |
143 | constant_strings[i] = const_col->getValue<String>(); |
144 | } |
145 | else |
146 | throw Exception( |
147 | "Illegal column " + column->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_COLUMN); |
148 | } |
149 | |
150 | String pattern; |
151 | pattern.reserve(2 * arguments.size()); |
152 | |
153 | for (size_t i = 0; i < arguments.size(); ++i) |
154 | pattern += "{}" ; |
155 | |
156 | FormatImpl::formatExecute( |
157 | has_column_string, |
158 | has_column_fixed_string, |
159 | std::move(pattern), |
160 | data, |
161 | offsets, |
162 | fixed_string_N, |
163 | constant_strings, |
164 | c_res->getChars(), |
165 | c_res->getOffsets(), |
166 | input_rows_count); |
167 | |
168 | block.getByPosition(result).column = std::move(c_res); |
169 | } |
170 | }; |
171 | |
172 | |
/// Name tag for the plain `concat` instantiation of ConcatImpl.
struct NameConcat
{
    static constexpr const char * name = "concat";
};

/// Name tag for the `concatAssumeInjective` instantiation of ConcatImpl.
struct NameConcatAssumeInjective
{
    static constexpr const char * name = "concatAssumeInjective";
};
181 | |
182 | using FunctionConcat = ConcatImpl<NameConcat, false>; |
183 | using FunctionConcatAssumeInjective = ConcatImpl<NameConcatAssumeInjective, true>; |
184 | |
185 | |
186 | /// Also works with arrays. |
187 | class ConcatOverloadResolver : public IFunctionOverloadResolverImpl |
188 | { |
189 | public: |
190 | static constexpr auto name = "concat" ; |
191 | static FunctionOverloadResolverImplPtr create(const Context & context) { return std::make_unique<ConcatOverloadResolver>(context); } |
192 | |
193 | explicit ConcatOverloadResolver(const Context & context_) : context(context_) {} |
194 | |
195 | String getName() const override { return name; } |
196 | size_t getNumberOfArguments() const override { return 0; } |
197 | bool isVariadic() const override { return true; } |
198 | |
199 | FunctionBaseImplPtr build(const ColumnsWithTypeAndName & arguments, const DataTypePtr & return_type) const override |
200 | { |
201 | if (isArray(arguments.at(0).type)) |
202 | { |
203 | return FunctionOverloadResolverAdaptor(FunctionFactory::instance().getImpl("arrayConcat" , context)).buildImpl(arguments); |
204 | } |
205 | else |
206 | return std::make_unique<DefaultFunction>( |
207 | FunctionConcat::create(context), ext::map<DataTypes>(arguments, [](const auto & elem) { return elem.type; }), return_type); |
208 | } |
209 | |
210 | DataTypePtr getReturnType(const DataTypes & arguments) const override |
211 | { |
212 | if (arguments.size() < 2) |
213 | throw Exception( |
214 | "Number of arguments for function " + getName() + " doesn't match: passed " + toString(arguments.size()) |
215 | + ", should be at least 2." , |
216 | ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); |
217 | |
218 | /// We always return Strings from concat, even if arguments were fixed strings. |
219 | return std::make_shared<DataTypeString>(); |
220 | } |
221 | |
222 | private: |
223 | const Context & context; |
224 | }; |
225 | |
226 | |
227 | void registerFunctionsConcat(FunctionFactory & factory) |
228 | { |
229 | factory.registerFunction<ConcatOverloadResolver>(FunctionFactory::CaseInsensitive); |
230 | factory.registerFunction<FunctionConcatAssumeInjective>(); |
231 | } |
232 | |
233 | } |
234 | |