#include <Functions/IFunctionImpl.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/GatherUtils/GatherUtils.h>
#include <Functions/GatherUtils/Sources.h>
#include <DataTypes/DataTypeString.h>
#include <DataTypes/DataTypesNumber.h>
#include <Columns/ColumnString.h>
#include <Columns/ColumnFixedString.h>
#include <Columns/ColumnConst.h>
#include <Columns/ColumnVector.h>

#include <memory>
#include <type_traits>
#include <utility>


10 | namespace DB |
11 | { |
12 | |
13 | using namespace GatherUtils; |
14 | |
15 | namespace ErrorCodes |
16 | { |
17 | extern const int ILLEGAL_COLUMN; |
18 | extern const int ILLEGAL_TYPE_OF_ARGUMENT; |
19 | } |
20 | |
21 | struct NameStartsWith |
22 | { |
23 | static constexpr auto name = "startsWith" ; |
24 | }; |
25 | struct NameEndsWith |
26 | { |
27 | static constexpr auto name = "endsWith" ; |
28 | }; |
29 | |
30 | template <typename Name> |
31 | class FunctionStartsEndsWith : public IFunction |
32 | { |
33 | public: |
34 | static constexpr auto name = Name::name; |
35 | static FunctionPtr create(const Context &) |
36 | { |
37 | return std::make_shared<FunctionStartsEndsWith>(); |
38 | } |
39 | |
40 | String getName() const override |
41 | { |
42 | return name; |
43 | } |
44 | |
45 | size_t getNumberOfArguments() const override |
46 | { |
47 | return 2; |
48 | } |
49 | |
50 | bool useDefaultImplementationForConstants() const override |
51 | { |
52 | return true; |
53 | } |
54 | |
55 | DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override |
56 | { |
57 | if (!isStringOrFixedString(arguments[0])) |
58 | throw Exception("Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
59 | |
60 | if (!isStringOrFixedString(arguments[1])) |
61 | throw Exception("Illegal type " + arguments[1]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
62 | |
63 | return std::make_shared<DataTypeUInt8>(); |
64 | } |
65 | |
66 | void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override |
67 | { |
68 | const IColumn * haystack_column = block.getByPosition(arguments[0]).column.get(); |
69 | const IColumn * needle_column = block.getByPosition(arguments[1]).column.get(); |
70 | |
71 | auto col_res = ColumnVector<UInt8>::create(); |
72 | typename ColumnVector<UInt8>::Container & vec_res = col_res->getData(); |
73 | |
74 | vec_res.resize(input_rows_count); |
75 | |
76 | if (const ColumnString * haystack = checkAndGetColumn<ColumnString>(haystack_column)) |
77 | dispatch<StringSource>(StringSource(*haystack), needle_column, vec_res); |
78 | else if (const ColumnFixedString * haystack_fixed = checkAndGetColumn<ColumnFixedString>(haystack_column)) |
79 | dispatch<FixedStringSource>(FixedStringSource(*haystack_fixed), needle_column, vec_res); |
80 | else if (const ColumnConst * haystack_const = checkAndGetColumnConst<ColumnString>(haystack_column)) |
81 | dispatch<ConstSource<StringSource>>(ConstSource<StringSource>(*haystack_const), needle_column, vec_res); |
82 | else if (const ColumnConst * haystack_const_fixed = checkAndGetColumnConst<ColumnFixedString>(haystack_column)) |
83 | dispatch<ConstSource<FixedStringSource>>(ConstSource<FixedStringSource>(*haystack_const_fixed), needle_column, vec_res); |
84 | else |
85 | throw Exception("Illegal combination of columns as arguments of function " + getName(), ErrorCodes::ILLEGAL_COLUMN); |
86 | |
87 | block.getByPosition(result).column = std::move(col_res); |
88 | } |
89 | |
90 | private: |
91 | template <typename HaystackSource> |
92 | void dispatch(HaystackSource haystack_source, const IColumn * needle_column, PaddedPODArray<UInt8> & res_data) const |
93 | { |
94 | if (const ColumnString * needle = checkAndGetColumn<ColumnString>(needle_column)) |
95 | execute<HaystackSource, StringSource>(haystack_source, StringSource(*needle), res_data); |
96 | else if (const ColumnFixedString * needle_fixed = checkAndGetColumn<ColumnFixedString>(needle_column)) |
97 | execute<HaystackSource, FixedStringSource>(haystack_source, FixedStringSource(*needle_fixed), res_data); |
98 | else if (const ColumnConst * needle_const = checkAndGetColumnConst<ColumnString>(needle_column)) |
99 | execute<HaystackSource, ConstSource<StringSource>>(haystack_source, ConstSource<StringSource>(*needle_const), res_data); |
100 | else if (const ColumnConst * needle_const_fixed = checkAndGetColumnConst<ColumnFixedString>(needle_column)) |
101 | execute<HaystackSource, ConstSource<FixedStringSource>>(haystack_source, ConstSource<FixedStringSource>(*needle_const_fixed), res_data); |
102 | else |
103 | throw Exception("Illegal combination of columns as arguments of function " + getName(), ErrorCodes::ILLEGAL_COLUMN); |
104 | } |
105 | |
106 | template <typename HaystackSource, typename NeedleSource> |
107 | static void execute(HaystackSource haystack_source, NeedleSource needle_source, PaddedPODArray<UInt8> & res_data) |
108 | { |
109 | size_t row_num = 0; |
110 | |
111 | while (!haystack_source.isEnd()) |
112 | { |
113 | auto haystack = haystack_source.getWhole(); |
114 | auto needle = needle_source.getWhole(); |
115 | |
116 | if (needle.size > haystack.size) |
117 | { |
118 | res_data[row_num] = false; |
119 | } |
120 | else |
121 | { |
122 | if constexpr (std::is_same_v<Name, NameStartsWith>) |
123 | { |
124 | res_data[row_num] = StringRef(haystack.data, needle.size) == StringRef(needle.data, needle.size); |
125 | } |
126 | else /// endsWith |
127 | { |
128 | res_data[row_num] = StringRef(haystack.data + haystack.size - needle.size, needle.size) == StringRef(needle.data, needle.size); |
129 | } |
130 | } |
131 | |
132 | haystack_source.next(); |
133 | needle_source.next(); |
134 | ++row_num; |
135 | } |
136 | } |
137 | }; |
138 | |
139 | } |
140 | |