1 | #include <DataTypes/DataTypeString.h> |
2 | #include <Columns/ColumnString.h> |
3 | #include <Columns/ColumnFixedString.h> |
4 | #include <Columns/ColumnConst.h> |
5 | #include <Functions/FunctionFactory.h> |
6 | #include <Functions/FunctionHelpers.h> |
7 | #include <Functions/IFunctionImpl.h> |
8 | #include <Functions/GatherUtils/GatherUtils.h> |
9 | #include <Functions/GatherUtils/Sources.h> |
10 | #include <Functions/GatherUtils/Sinks.h> |
11 | #include <Functions/GatherUtils/Slices.h> |
12 | #include <Functions/GatherUtils/Algorithms.h> |
13 | #include <IO/WriteHelpers.h> |
14 | |
15 | |
16 | namespace DB |
17 | { |
18 | |
19 | using namespace GatherUtils; |
20 | |
/// Error codes referenced from the shared ErrorCodes namespace.
/// They are only declared here (extern); the definitions live elsewhere in the project.
namespace ErrorCodes
{
    extern const int ARGUMENT_OUT_OF_BOUND;
    extern const int ILLEGAL_COLUMN;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
    extern const int ZERO_ARRAY_OR_TUPLE_INDEX;
}
29 | |
30 | |
31 | /// If 'is_utf8' - measure offset and length in code points instead of bytes. |
32 | /// UTF8 variant is not available for FixedString arguments. |
33 | template <bool is_utf8> |
34 | class FunctionSubstring : public IFunction |
35 | { |
36 | public: |
37 | static constexpr auto name = is_utf8 ? "substringUTF8" : "substring" ; |
38 | static FunctionPtr create(const Context &) |
39 | { |
40 | return std::make_shared<FunctionSubstring>(); |
41 | } |
42 | |
43 | String getName() const override |
44 | { |
45 | return name; |
46 | } |
47 | |
48 | bool isVariadic() const override { return true; } |
49 | size_t getNumberOfArguments() const override { return 0; } |
50 | |
51 | bool useDefaultImplementationForConstants() const override { return true; } |
52 | |
53 | DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override |
54 | { |
55 | size_t number_of_arguments = arguments.size(); |
56 | |
57 | if (number_of_arguments < 2 || number_of_arguments > 3) |
58 | throw Exception("Number of arguments for function " + getName() + " doesn't match: passed " |
59 | + toString(number_of_arguments) + ", should be 2 or 3" , |
60 | ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); |
61 | |
62 | if ((is_utf8 && !isString(arguments[0])) || !isStringOrFixedString(arguments[0])) |
63 | throw Exception("Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
64 | |
65 | if (!isNativeNumber(arguments[1])) |
66 | throw Exception("Illegal type " + arguments[1]->getName() |
67 | + " of second argument of function " |
68 | + getName(), |
69 | ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
70 | |
71 | if (number_of_arguments == 3 && !isNativeNumber(arguments[2])) |
72 | throw Exception("Illegal type " + arguments[2]->getName() |
73 | + " of second argument of function " |
74 | + getName(), |
75 | ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
76 | |
77 | return std::make_shared<DataTypeString>(); |
78 | } |
79 | |
80 | template <typename Source> |
81 | void executeForSource(const ColumnPtr & column_start, const ColumnPtr & column_length, |
82 | const ColumnConst * column_start_const, const ColumnConst * column_length_const, |
83 | Int64 start_value, Int64 length_value, Block & block, size_t result, Source && source, |
84 | size_t input_rows_count) |
85 | { |
86 | auto col_res = ColumnString::create(); |
87 | |
88 | if (!column_length) |
89 | { |
90 | if (column_start_const) |
91 | { |
92 | if (start_value > 0) |
93 | sliceFromLeftConstantOffsetUnbounded(source, StringSink(*col_res, input_rows_count), start_value - 1); |
94 | else if (start_value < 0) |
95 | sliceFromRightConstantOffsetUnbounded(source, StringSink(*col_res, input_rows_count), -start_value); |
96 | else |
97 | throw Exception("Indices in strings are 1-based" , ErrorCodes::ZERO_ARRAY_OR_TUPLE_INDEX); |
98 | } |
99 | else |
100 | sliceDynamicOffsetUnbounded(source, StringSink(*col_res, input_rows_count), *column_start); |
101 | } |
102 | else |
103 | { |
104 | if (column_start_const && column_length_const) |
105 | { |
106 | if (start_value > 0) |
107 | sliceFromLeftConstantOffsetBounded(source, StringSink(*col_res, input_rows_count), start_value - 1, length_value); |
108 | else if (start_value < 0) |
109 | sliceFromRightConstantOffsetBounded(source, StringSink(*col_res, input_rows_count), -start_value, length_value); |
110 | else |
111 | throw Exception("Indices in strings are 1-based" , ErrorCodes::ZERO_ARRAY_OR_TUPLE_INDEX); |
112 | } |
113 | else |
114 | sliceDynamicOffsetBounded(source, StringSink(*col_res, input_rows_count), *column_start, *column_length); |
115 | } |
116 | |
117 | block.getByPosition(result).column = std::move(col_res); |
118 | } |
119 | |
120 | void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override |
121 | { |
122 | size_t number_of_arguments = arguments.size(); |
123 | |
124 | ColumnPtr column_string = block.getByPosition(arguments[0]).column; |
125 | ColumnPtr column_start = block.getByPosition(arguments[1]).column; |
126 | ColumnPtr column_length; |
127 | |
128 | if (number_of_arguments == 3) |
129 | column_length = block.getByPosition(arguments[2]).column; |
130 | |
131 | const ColumnConst * column_start_const = checkAndGetColumn<ColumnConst>(column_start.get()); |
132 | const ColumnConst * column_length_const = nullptr; |
133 | |
134 | if (number_of_arguments == 3) |
135 | column_length_const = checkAndGetColumn<ColumnConst>(column_length.get()); |
136 | |
137 | Int64 start_value = 0; |
138 | Int64 length_value = 0; |
139 | |
140 | if (column_start_const) |
141 | start_value = column_start_const->getInt(0); |
142 | if (column_length_const) |
143 | length_value = column_length_const->getInt(0); |
144 | |
145 | if constexpr (is_utf8) |
146 | { |
147 | if (const ColumnString * col = checkAndGetColumn<ColumnString>(column_string.get())) |
148 | executeForSource(column_start, column_length, column_start_const, column_length_const, start_value, |
149 | length_value, block, result, UTF8StringSource(*col), input_rows_count); |
150 | else if (const ColumnConst * col_const = checkAndGetColumnConst<ColumnString>(column_string.get())) |
151 | executeForSource(column_start, column_length, column_start_const, column_length_const, start_value, |
152 | length_value, block, result, ConstSource<UTF8StringSource>(*col_const), input_rows_count); |
153 | else |
154 | throw Exception( |
155 | "Illegal column " + block.getByPosition(arguments[0]).column->getName() + " of first argument of function " + getName(), |
156 | ErrorCodes::ILLEGAL_COLUMN); |
157 | } |
158 | else |
159 | { |
160 | if (const ColumnString * col = checkAndGetColumn<ColumnString>(column_string.get())) |
161 | executeForSource(column_start, column_length, column_start_const, column_length_const, start_value, |
162 | length_value, block, result, StringSource(*col), input_rows_count); |
163 | else if (const ColumnFixedString * col_fixed = checkAndGetColumn<ColumnFixedString>(column_string.get())) |
164 | executeForSource(column_start, column_length, column_start_const, column_length_const, start_value, |
165 | length_value, block, result, FixedStringSource(*col_fixed), input_rows_count); |
166 | else if (const ColumnConst * col_const = checkAndGetColumnConst<ColumnString>(column_string.get())) |
167 | executeForSource(column_start, column_length, column_start_const, column_length_const, start_value, |
168 | length_value, block, result, ConstSource<StringSource>(*col_const), input_rows_count); |
169 | else if (const ColumnConst * col_const_fixed = checkAndGetColumnConst<ColumnFixedString>(column_string.get())) |
170 | executeForSource(column_start, column_length, column_start_const, column_length_const, start_value, |
171 | length_value, block, result, ConstSource<FixedStringSource>(*col_const_fixed), input_rows_count); |
172 | else |
173 | throw Exception( |
174 | "Illegal column " + block.getByPosition(arguments[0]).column->getName() + " of first argument of function " + getName(), |
175 | ErrorCodes::ILLEGAL_COLUMN); |
176 | } |
177 | } |
178 | }; |
179 | |
180 | void registerFunctionSubstring(FunctionFactory & factory) |
181 | { |
182 | factory.registerFunction<FunctionSubstring<false>>(FunctionFactory::CaseInsensitive); |
183 | factory.registerAlias("substr" , "substring" , FunctionFactory::CaseInsensitive); |
184 | factory.registerAlias("mid" , "substring" , FunctionFactory::CaseInsensitive); /// from MySQL dialect |
185 | |
186 | factory.registerFunction<FunctionSubstring<true>>(FunctionFactory::CaseSensitive); |
187 | } |
188 | |
189 | } |
190 | |