1#include <DataTypes/DataTypeString.h>
2#include <Columns/ColumnString.h>
3#include <Columns/ColumnFixedString.h>
4#include <Columns/ColumnConst.h>
5#include <Functions/FunctionFactory.h>
6#include <Functions/FunctionHelpers.h>
7#include <Functions/IFunctionImpl.h>
8#include <Functions/GatherUtils/GatherUtils.h>
9#include <Functions/GatherUtils/Sources.h>
10#include <Functions/GatherUtils/Sinks.h>
11#include <Functions/GatherUtils/Slices.h>
12#include <Functions/GatherUtils/Algorithms.h>
13#include <IO/WriteHelpers.h>
14
15
16namespace DB
17{
18
19using namespace GatherUtils;
20
21namespace ErrorCodes
22{
23 extern const int ILLEGAL_COLUMN;
24 extern const int ILLEGAL_TYPE_OF_ARGUMENT;
25 extern const int ARGUMENT_OUT_OF_BOUND;
26 extern const int ZERO_ARRAY_OR_TUPLE_INDEX;
27 extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
28}
29
30
31/// If 'is_utf8' - measure offset and length in code points instead of bytes.
32/// UTF8 variant is not available for FixedString arguments.
33template <bool is_utf8>
34class FunctionSubstring : public IFunction
35{
36public:
37 static constexpr auto name = is_utf8 ? "substringUTF8" : "substring";
38 static FunctionPtr create(const Context &)
39 {
40 return std::make_shared<FunctionSubstring>();
41 }
42
43 String getName() const override
44 {
45 return name;
46 }
47
48 bool isVariadic() const override { return true; }
49 size_t getNumberOfArguments() const override { return 0; }
50
51 bool useDefaultImplementationForConstants() const override { return true; }
52
53 DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
54 {
55 size_t number_of_arguments = arguments.size();
56
57 if (number_of_arguments < 2 || number_of_arguments > 3)
58 throw Exception("Number of arguments for function " + getName() + " doesn't match: passed "
59 + toString(number_of_arguments) + ", should be 2 or 3",
60 ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
61
62 if ((is_utf8 && !isString(arguments[0])) || !isStringOrFixedString(arguments[0]))
63 throw Exception("Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
64
65 if (!isNativeNumber(arguments[1]))
66 throw Exception("Illegal type " + arguments[1]->getName()
67 + " of second argument of function "
68 + getName(),
69 ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
70
71 if (number_of_arguments == 3 && !isNativeNumber(arguments[2]))
72 throw Exception("Illegal type " + arguments[2]->getName()
73 + " of second argument of function "
74 + getName(),
75 ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
76
77 return std::make_shared<DataTypeString>();
78 }
79
80 template <typename Source>
81 void executeForSource(const ColumnPtr & column_start, const ColumnPtr & column_length,
82 const ColumnConst * column_start_const, const ColumnConst * column_length_const,
83 Int64 start_value, Int64 length_value, Block & block, size_t result, Source && source,
84 size_t input_rows_count)
85 {
86 auto col_res = ColumnString::create();
87
88 if (!column_length)
89 {
90 if (column_start_const)
91 {
92 if (start_value > 0)
93 sliceFromLeftConstantOffsetUnbounded(source, StringSink(*col_res, input_rows_count), start_value - 1);
94 else if (start_value < 0)
95 sliceFromRightConstantOffsetUnbounded(source, StringSink(*col_res, input_rows_count), -start_value);
96 else
97 throw Exception("Indices in strings are 1-based", ErrorCodes::ZERO_ARRAY_OR_TUPLE_INDEX);
98 }
99 else
100 sliceDynamicOffsetUnbounded(source, StringSink(*col_res, input_rows_count), *column_start);
101 }
102 else
103 {
104 if (column_start_const && column_length_const)
105 {
106 if (start_value > 0)
107 sliceFromLeftConstantOffsetBounded(source, StringSink(*col_res, input_rows_count), start_value - 1, length_value);
108 else if (start_value < 0)
109 sliceFromRightConstantOffsetBounded(source, StringSink(*col_res, input_rows_count), -start_value, length_value);
110 else
111 throw Exception("Indices in strings are 1-based", ErrorCodes::ZERO_ARRAY_OR_TUPLE_INDEX);
112 }
113 else
114 sliceDynamicOffsetBounded(source, StringSink(*col_res, input_rows_count), *column_start, *column_length);
115 }
116
117 block.getByPosition(result).column = std::move(col_res);
118 }
119
120 void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override
121 {
122 size_t number_of_arguments = arguments.size();
123
124 ColumnPtr column_string = block.getByPosition(arguments[0]).column;
125 ColumnPtr column_start = block.getByPosition(arguments[1]).column;
126 ColumnPtr column_length;
127
128 if (number_of_arguments == 3)
129 column_length = block.getByPosition(arguments[2]).column;
130
131 const ColumnConst * column_start_const = checkAndGetColumn<ColumnConst>(column_start.get());
132 const ColumnConst * column_length_const = nullptr;
133
134 if (number_of_arguments == 3)
135 column_length_const = checkAndGetColumn<ColumnConst>(column_length.get());
136
137 Int64 start_value = 0;
138 Int64 length_value = 0;
139
140 if (column_start_const)
141 start_value = column_start_const->getInt(0);
142 if (column_length_const)
143 length_value = column_length_const->getInt(0);
144
145 if constexpr (is_utf8)
146 {
147 if (const ColumnString * col = checkAndGetColumn<ColumnString>(column_string.get()))
148 executeForSource(column_start, column_length, column_start_const, column_length_const, start_value,
149 length_value, block, result, UTF8StringSource(*col), input_rows_count);
150 else if (const ColumnConst * col_const = checkAndGetColumnConst<ColumnString>(column_string.get()))
151 executeForSource(column_start, column_length, column_start_const, column_length_const, start_value,
152 length_value, block, result, ConstSource<UTF8StringSource>(*col_const), input_rows_count);
153 else
154 throw Exception(
155 "Illegal column " + block.getByPosition(arguments[0]).column->getName() + " of first argument of function " + getName(),
156 ErrorCodes::ILLEGAL_COLUMN);
157 }
158 else
159 {
160 if (const ColumnString * col = checkAndGetColumn<ColumnString>(column_string.get()))
161 executeForSource(column_start, column_length, column_start_const, column_length_const, start_value,
162 length_value, block, result, StringSource(*col), input_rows_count);
163 else if (const ColumnFixedString * col_fixed = checkAndGetColumn<ColumnFixedString>(column_string.get()))
164 executeForSource(column_start, column_length, column_start_const, column_length_const, start_value,
165 length_value, block, result, FixedStringSource(*col_fixed), input_rows_count);
166 else if (const ColumnConst * col_const = checkAndGetColumnConst<ColumnString>(column_string.get()))
167 executeForSource(column_start, column_length, column_start_const, column_length_const, start_value,
168 length_value, block, result, ConstSource<StringSource>(*col_const), input_rows_count);
169 else if (const ColumnConst * col_const_fixed = checkAndGetColumnConst<ColumnFixedString>(column_string.get()))
170 executeForSource(column_start, column_length, column_start_const, column_length_const, start_value,
171 length_value, block, result, ConstSource<FixedStringSource>(*col_const_fixed), input_rows_count);
172 else
173 throw Exception(
174 "Illegal column " + block.getByPosition(arguments[0]).column->getName() + " of first argument of function " + getName(),
175 ErrorCodes::ILLEGAL_COLUMN);
176 }
177 }
178};
179
180void registerFunctionSubstring(FunctionFactory & factory)
181{
182 factory.registerFunction<FunctionSubstring<false>>(FunctionFactory::CaseInsensitive);
183 factory.registerAlias("substr", "substring", FunctionFactory::CaseInsensitive);
184 factory.registerAlias("mid", "substring", FunctionFactory::CaseInsensitive); /// from MySQL dialect
185
186 factory.registerFunction<FunctionSubstring<true>>(FunctionFactory::CaseSensitive);
187}
188
189}
190