1#include <Functions/IFunctionImpl.h>
2#include <Functions/FunctionHelpers.h>
3#include <Functions/GatherUtils/GatherUtils.h>
4#include <Functions/GatherUtils/Sources.h>
5#include <DataTypes/DataTypeString.h>
6#include <DataTypes/DataTypesNumber.h>
7#include <Columns/ColumnString.h>
8
9
10namespace DB
11{
12
13using namespace GatherUtils;
14
/// Error codes used in this file; they are only declared here.
namespace ErrorCodes
{
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;  /// Thrown from getReturnTypeImpl for a non-string argument.
    extern const int ILLEGAL_COLUMN;            /// Thrown when an argument column has an unsupported representation.
}
20
/// Tag type: supplies the name "startsWith" and selects prefix matching in FunctionStartsEndsWith.
struct NameStartsWith
{
    static constexpr const char * name = "startsWith";
};
/// Tag type: supplies the name "endsWith" and selects suffix matching in FunctionStartsEndsWith.
struct NameEndsWith
{
    static constexpr const char * name = "endsWith";
};
29
30template <typename Name>
31class FunctionStartsEndsWith : public IFunction
32{
33public:
34 static constexpr auto name = Name::name;
35 static FunctionPtr create(const Context &)
36 {
37 return std::make_shared<FunctionStartsEndsWith>();
38 }
39
40 String getName() const override
41 {
42 return name;
43 }
44
45 size_t getNumberOfArguments() const override
46 {
47 return 2;
48 }
49
50 bool useDefaultImplementationForConstants() const override
51 {
52 return true;
53 }
54
55 DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
56 {
57 if (!isStringOrFixedString(arguments[0]))
58 throw Exception("Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
59
60 if (!isStringOrFixedString(arguments[1]))
61 throw Exception("Illegal type " + arguments[1]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
62
63 return std::make_shared<DataTypeUInt8>();
64 }
65
66 void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override
67 {
68 const IColumn * haystack_column = block.getByPosition(arguments[0]).column.get();
69 const IColumn * needle_column = block.getByPosition(arguments[1]).column.get();
70
71 auto col_res = ColumnVector<UInt8>::create();
72 typename ColumnVector<UInt8>::Container & vec_res = col_res->getData();
73
74 vec_res.resize(input_rows_count);
75
76 if (const ColumnString * haystack = checkAndGetColumn<ColumnString>(haystack_column))
77 dispatch<StringSource>(StringSource(*haystack), needle_column, vec_res);
78 else if (const ColumnFixedString * haystack_fixed = checkAndGetColumn<ColumnFixedString>(haystack_column))
79 dispatch<FixedStringSource>(FixedStringSource(*haystack_fixed), needle_column, vec_res);
80 else if (const ColumnConst * haystack_const = checkAndGetColumnConst<ColumnString>(haystack_column))
81 dispatch<ConstSource<StringSource>>(ConstSource<StringSource>(*haystack_const), needle_column, vec_res);
82 else if (const ColumnConst * haystack_const_fixed = checkAndGetColumnConst<ColumnFixedString>(haystack_column))
83 dispatch<ConstSource<FixedStringSource>>(ConstSource<FixedStringSource>(*haystack_const_fixed), needle_column, vec_res);
84 else
85 throw Exception("Illegal combination of columns as arguments of function " + getName(), ErrorCodes::ILLEGAL_COLUMN);
86
87 block.getByPosition(result).column = std::move(col_res);
88 }
89
90private:
91 template <typename HaystackSource>
92 void dispatch(HaystackSource haystack_source, const IColumn * needle_column, PaddedPODArray<UInt8> & res_data) const
93 {
94 if (const ColumnString * needle = checkAndGetColumn<ColumnString>(needle_column))
95 execute<HaystackSource, StringSource>(haystack_source, StringSource(*needle), res_data);
96 else if (const ColumnFixedString * needle_fixed = checkAndGetColumn<ColumnFixedString>(needle_column))
97 execute<HaystackSource, FixedStringSource>(haystack_source, FixedStringSource(*needle_fixed), res_data);
98 else if (const ColumnConst * needle_const = checkAndGetColumnConst<ColumnString>(needle_column))
99 execute<HaystackSource, ConstSource<StringSource>>(haystack_source, ConstSource<StringSource>(*needle_const), res_data);
100 else if (const ColumnConst * needle_const_fixed = checkAndGetColumnConst<ColumnFixedString>(needle_column))
101 execute<HaystackSource, ConstSource<FixedStringSource>>(haystack_source, ConstSource<FixedStringSource>(*needle_const_fixed), res_data);
102 else
103 throw Exception("Illegal combination of columns as arguments of function " + getName(), ErrorCodes::ILLEGAL_COLUMN);
104 }
105
106 template <typename HaystackSource, typename NeedleSource>
107 static void execute(HaystackSource haystack_source, NeedleSource needle_source, PaddedPODArray<UInt8> & res_data)
108 {
109 size_t row_num = 0;
110
111 while (!haystack_source.isEnd())
112 {
113 auto haystack = haystack_source.getWhole();
114 auto needle = needle_source.getWhole();
115
116 if (needle.size > haystack.size)
117 {
118 res_data[row_num] = false;
119 }
120 else
121 {
122 if constexpr (std::is_same_v<Name, NameStartsWith>)
123 {
124 res_data[row_num] = StringRef(haystack.data, needle.size) == StringRef(needle.data, needle.size);
125 }
126 else /// endsWith
127 {
128 res_data[row_num] = StringRef(haystack.data + haystack.size - needle.size, needle.size) == StringRef(needle.data, needle.size);
129 }
130 }
131
132 haystack_source.next();
133 needle_source.next();
134 ++row_num;
135 }
136 }
137};
138
139}
140