1#include <Columns/ColumnString.h>
2#include <DataTypes/DataTypeString.h>
3#include <DataTypes/getLeastSupertype.h>
4#include <Functions/FunctionFactory.h>
5#include <Functions/FunctionHelpers.h>
6#include <Functions/GatherUtils/Algorithms.h>
7#include <Functions/GatherUtils/GatherUtils.h>
8#include <Functions/GatherUtils/Sinks.h>
9#include <Functions/GatherUtils/Slices.h>
10#include <Functions/GatherUtils/Sources.h>
11#include <Functions/IFunctionImpl.h>
12#include <IO/WriteHelpers.h>
13#include <ext/map.h>
14#include <ext/range.h>
15
16#include "formatString.h"
17
18namespace DB
19{
/// Error codes thrown by the functions in this file. They are only declared here (extern, no initializer).
namespace ErrorCodes
{
    extern const int ILLEGAL_COLUMN;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
26
27using namespace GatherUtils;
28
29
30template <typename Name, bool is_injective>
31class ConcatImpl : public IFunction
32{
33public:
34 static constexpr auto name = Name::name;
35 ConcatImpl(const Context & context_) : context(context_) {}
36 static FunctionPtr create(const Context & context) { return std::make_shared<ConcatImpl>(context); }
37
38 String getName() const override { return name; }
39
40 bool isVariadic() const override { return true; }
41
42 size_t getNumberOfArguments() const override { return 0; }
43
44 bool isInjective(const Block &) override { return is_injective; }
45
46 bool useDefaultImplementationForConstants() const override { return true; }
47
48 DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
49 {
50 if (arguments.size() < 2)
51 throw Exception(
52 "Number of arguments for function " + getName() + " doesn't match: passed " + toString(arguments.size())
53 + ", should be at least 2.",
54 ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
55
56 if (arguments.size() > FormatImpl::argument_threshold)
57 throw Exception(
58 "Number of arguments for function " + getName() + " doesn't match: passed " + toString(arguments.size())
59 + ", should be at most " + std::to_string(FormatImpl::argument_threshold),
60 ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
61
62 for (const auto arg_idx : ext::range(0, arguments.size()))
63 {
64 const auto arg = arguments[arg_idx].get();
65 if (!isStringOrFixedString(arg))
66 throw Exception{"Illegal type " + arg->getName() + " of argument " + std::to_string(arg_idx + 1) + " of function "
67 + getName(),
68 ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
69 }
70
71 return std::make_shared<DataTypeString>();
72 }
73
74 void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override
75 {
76 /// Format function is not proven to be faster for two arguments.
77 /// Actually there is overhead of 2 to 5 extra instructions for each string for checking empty strings in FormatImpl.
78 /// Though, benchmarks are really close, for most examples we saw executeBinary is slightly faster (0-3%).
79 /// For 3 and more arguments FormatImpl is much faster (up to 50-60%).
80 if (arguments.size() == 2)
81 executeBinary(block, arguments, result, input_rows_count);
82 else
83 executeFormatImpl(block, arguments, result, input_rows_count);
84 }
85
86private:
87 const Context & context;
88
89 void executeBinary(Block & block, const ColumnNumbers & arguments, const size_t result, size_t input_rows_count)
90 {
91 const IColumn * c0 = block.getByPosition(arguments[0]).column.get();
92 const IColumn * c1 = block.getByPosition(arguments[1]).column.get();
93
94 const ColumnString * c0_string = checkAndGetColumn<ColumnString>(c0);
95 const ColumnString * c1_string = checkAndGetColumn<ColumnString>(c1);
96 const ColumnConst * c0_const_string = checkAndGetColumnConst<ColumnString>(c0);
97 const ColumnConst * c1_const_string = checkAndGetColumnConst<ColumnString>(c1);
98
99 auto c_res = ColumnString::create();
100
101 if (c0_string && c1_string)
102 concat(StringSource(*c0_string), StringSource(*c1_string), StringSink(*c_res, c0->size()));
103 else if (c0_string && c1_const_string)
104 concat(StringSource(*c0_string), ConstSource<StringSource>(*c1_const_string), StringSink(*c_res, c0->size()));
105 else if (c0_const_string && c1_string)
106 concat(ConstSource<StringSource>(*c0_const_string), StringSource(*c1_string), StringSink(*c_res, c0->size()));
107 else
108 {
109 /// Fallback: use generic implementation for not very important cases.
110 executeFormatImpl(block, arguments, result, input_rows_count);
111 return;
112 }
113
114 block.getByPosition(result).column = std::move(c_res);
115 }
116
117 void executeFormatImpl(Block & block, const ColumnNumbers & arguments, const size_t result, size_t input_rows_count)
118 {
119 auto c_res = ColumnString::create();
120 std::vector<const ColumnString::Chars *> data(arguments.size());
121 std::vector<const ColumnString::Offsets *> offsets(arguments.size());
122 std::vector<size_t> fixed_string_N(arguments.size());
123 std::vector<String> constant_strings(arguments.size());
124 bool has_column_string = false;
125 bool has_column_fixed_string = false;
126 for (size_t i = 0; i < arguments.size(); ++i)
127 {
128 const ColumnPtr & column = block.getByPosition(arguments[i]).column;
129 if (const ColumnString * col = checkAndGetColumn<ColumnString>(column.get()))
130 {
131 has_column_string = true;
132 data[i] = &col->getChars();
133 offsets[i] = &col->getOffsets();
134 }
135 else if (const ColumnFixedString * fixed_col = checkAndGetColumn<ColumnFixedString>(column.get()))
136 {
137 has_column_fixed_string = true;
138 data[i] = &fixed_col->getChars();
139 fixed_string_N[i] = fixed_col->getN();
140 }
141 else if (const ColumnConst * const_col = checkAndGetColumnConstStringOrFixedString(column.get()))
142 {
143 constant_strings[i] = const_col->getValue<String>();
144 }
145 else
146 throw Exception(
147 "Illegal column " + column->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_COLUMN);
148 }
149
150 String pattern;
151 pattern.reserve(2 * arguments.size());
152
153 for (size_t i = 0; i < arguments.size(); ++i)
154 pattern += "{}";
155
156 FormatImpl::formatExecute(
157 has_column_string,
158 has_column_fixed_string,
159 std::move(pattern),
160 data,
161 offsets,
162 fixed_string_N,
163 constant_strings,
164 c_res->getChars(),
165 c_res->getOffsets(),
166 input_rows_count);
167
168 block.getByPosition(result).column = std::move(c_res);
169 }
170};
171
172
/// Name tag for the "concat" function; used as the `Name` template argument of ConcatImpl.
struct NameConcat
{
    static constexpr const char * name = "concat";
};
/// Name tag for the "concatAssumeInjective" function; used as the `Name` template argument of ConcatImpl.
struct NameConcatAssumeInjective
{
    static constexpr const char * name = "concatAssumeInjective";
};
181
182using FunctionConcat = ConcatImpl<NameConcat, false>;
183using FunctionConcatAssumeInjective = ConcatImpl<NameConcatAssumeInjective, true>;
184
185
186/// Also works with arrays.
187class ConcatOverloadResolver : public IFunctionOverloadResolverImpl
188{
189public:
190 static constexpr auto name = "concat";
191 static FunctionOverloadResolverImplPtr create(const Context & context) { return std::make_unique<ConcatOverloadResolver>(context); }
192
193 explicit ConcatOverloadResolver(const Context & context_) : context(context_) {}
194
195 String getName() const override { return name; }
196 size_t getNumberOfArguments() const override { return 0; }
197 bool isVariadic() const override { return true; }
198
199 FunctionBaseImplPtr build(const ColumnsWithTypeAndName & arguments, const DataTypePtr & return_type) const override
200 {
201 if (isArray(arguments.at(0).type))
202 {
203 return FunctionOverloadResolverAdaptor(FunctionFactory::instance().getImpl("arrayConcat", context)).buildImpl(arguments);
204 }
205 else
206 return std::make_unique<DefaultFunction>(
207 FunctionConcat::create(context), ext::map<DataTypes>(arguments, [](const auto & elem) { return elem.type; }), return_type);
208 }
209
210 DataTypePtr getReturnType(const DataTypes & arguments) const override
211 {
212 if (arguments.size() < 2)
213 throw Exception(
214 "Number of arguments for function " + getName() + " doesn't match: passed " + toString(arguments.size())
215 + ", should be at least 2.",
216 ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
217
218 /// We always return Strings from concat, even if arguments were fixed strings.
219 return std::make_shared<DataTypeString>();
220 }
221
222private:
223 const Context & context;
224};
225
226
227void registerFunctionsConcat(FunctionFactory & factory)
228{
229 factory.registerFunction<ConcatOverloadResolver>(FunctionFactory::CaseInsensitive);
230 factory.registerFunction<FunctionConcatAssumeInjective>();
231}
232
233}
234