| 1 | #include <Functions/IFunctionImpl.h> |
| 2 | #include <Functions/FunctionHelpers.h> |
| 3 | #include <Functions/GatherUtils/GatherUtils.h> |
| 4 | #include <Functions/GatherUtils/Sources.h> |
| 5 | #include <DataTypes/DataTypeString.h> |
| 6 | #include <DataTypes/DataTypesNumber.h> |
| 7 | #include <Columns/ColumnString.h> |
| 8 | |
| 9 | |
| 10 | namespace DB |
| 11 | { |
| 12 | |
| 13 | using namespace GatherUtils; |
| 14 | |
/// Error codes thrown by FunctionStartsEndsWith. Only declared here (`extern`);
/// the actual values are defined elsewhere in the project.
namespace ErrorCodes
{
    extern const int ILLEGAL_TYPE_OF_ARGUMENT; /// thrown by getReturnTypeImpl for a non-String/FixedString argument
    extern const int ILLEGAL_COLUMN;           /// thrown when a column has an unsupported representation at execution time
}
| 20 | |
/// Name tag for FunctionStartsEndsWith: selects the prefix check and supplies the function name.
struct NameStartsWith
{
    static constexpr const char * name = "startsWith";
};
/// Name tag for FunctionStartsEndsWith: selects the suffix check and supplies the function name.
struct NameEndsWith
{
    static constexpr const char * name = "endsWith";
};
| 29 | |
| 30 | template <typename Name> |
| 31 | class FunctionStartsEndsWith : public IFunction |
| 32 | { |
| 33 | public: |
| 34 | static constexpr auto name = Name::name; |
| 35 | static FunctionPtr create(const Context &) |
| 36 | { |
| 37 | return std::make_shared<FunctionStartsEndsWith>(); |
| 38 | } |
| 39 | |
| 40 | String getName() const override |
| 41 | { |
| 42 | return name; |
| 43 | } |
| 44 | |
| 45 | size_t getNumberOfArguments() const override |
| 46 | { |
| 47 | return 2; |
| 48 | } |
| 49 | |
| 50 | bool useDefaultImplementationForConstants() const override |
| 51 | { |
| 52 | return true; |
| 53 | } |
| 54 | |
| 55 | DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override |
| 56 | { |
| 57 | if (!isStringOrFixedString(arguments[0])) |
| 58 | throw Exception("Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
| 59 | |
| 60 | if (!isStringOrFixedString(arguments[1])) |
| 61 | throw Exception("Illegal type " + arguments[1]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
| 62 | |
| 63 | return std::make_shared<DataTypeUInt8>(); |
| 64 | } |
| 65 | |
| 66 | void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override |
| 67 | { |
| 68 | const IColumn * haystack_column = block.getByPosition(arguments[0]).column.get(); |
| 69 | const IColumn * needle_column = block.getByPosition(arguments[1]).column.get(); |
| 70 | |
| 71 | auto col_res = ColumnVector<UInt8>::create(); |
| 72 | typename ColumnVector<UInt8>::Container & vec_res = col_res->getData(); |
| 73 | |
| 74 | vec_res.resize(input_rows_count); |
| 75 | |
| 76 | if (const ColumnString * haystack = checkAndGetColumn<ColumnString>(haystack_column)) |
| 77 | dispatch<StringSource>(StringSource(*haystack), needle_column, vec_res); |
| 78 | else if (const ColumnFixedString * haystack_fixed = checkAndGetColumn<ColumnFixedString>(haystack_column)) |
| 79 | dispatch<FixedStringSource>(FixedStringSource(*haystack_fixed), needle_column, vec_res); |
| 80 | else if (const ColumnConst * haystack_const = checkAndGetColumnConst<ColumnString>(haystack_column)) |
| 81 | dispatch<ConstSource<StringSource>>(ConstSource<StringSource>(*haystack_const), needle_column, vec_res); |
| 82 | else if (const ColumnConst * haystack_const_fixed = checkAndGetColumnConst<ColumnFixedString>(haystack_column)) |
| 83 | dispatch<ConstSource<FixedStringSource>>(ConstSource<FixedStringSource>(*haystack_const_fixed), needle_column, vec_res); |
| 84 | else |
| 85 | throw Exception("Illegal combination of columns as arguments of function " + getName(), ErrorCodes::ILLEGAL_COLUMN); |
| 86 | |
| 87 | block.getByPosition(result).column = std::move(col_res); |
| 88 | } |
| 89 | |
| 90 | private: |
| 91 | template <typename HaystackSource> |
| 92 | void dispatch(HaystackSource haystack_source, const IColumn * needle_column, PaddedPODArray<UInt8> & res_data) const |
| 93 | { |
| 94 | if (const ColumnString * needle = checkAndGetColumn<ColumnString>(needle_column)) |
| 95 | execute<HaystackSource, StringSource>(haystack_source, StringSource(*needle), res_data); |
| 96 | else if (const ColumnFixedString * needle_fixed = checkAndGetColumn<ColumnFixedString>(needle_column)) |
| 97 | execute<HaystackSource, FixedStringSource>(haystack_source, FixedStringSource(*needle_fixed), res_data); |
| 98 | else if (const ColumnConst * needle_const = checkAndGetColumnConst<ColumnString>(needle_column)) |
| 99 | execute<HaystackSource, ConstSource<StringSource>>(haystack_source, ConstSource<StringSource>(*needle_const), res_data); |
| 100 | else if (const ColumnConst * needle_const_fixed = checkAndGetColumnConst<ColumnFixedString>(needle_column)) |
| 101 | execute<HaystackSource, ConstSource<FixedStringSource>>(haystack_source, ConstSource<FixedStringSource>(*needle_const_fixed), res_data); |
| 102 | else |
| 103 | throw Exception("Illegal combination of columns as arguments of function " + getName(), ErrorCodes::ILLEGAL_COLUMN); |
| 104 | } |
| 105 | |
| 106 | template <typename HaystackSource, typename NeedleSource> |
| 107 | static void execute(HaystackSource haystack_source, NeedleSource needle_source, PaddedPODArray<UInt8> & res_data) |
| 108 | { |
| 109 | size_t row_num = 0; |
| 110 | |
| 111 | while (!haystack_source.isEnd()) |
| 112 | { |
| 113 | auto haystack = haystack_source.getWhole(); |
| 114 | auto needle = needle_source.getWhole(); |
| 115 | |
| 116 | if (needle.size > haystack.size) |
| 117 | { |
| 118 | res_data[row_num] = false; |
| 119 | } |
| 120 | else |
| 121 | { |
| 122 | if constexpr (std::is_same_v<Name, NameStartsWith>) |
| 123 | { |
| 124 | res_data[row_num] = StringRef(haystack.data, needle.size) == StringRef(needle.data, needle.size); |
| 125 | } |
| 126 | else /// endsWith |
| 127 | { |
| 128 | res_data[row_num] = StringRef(haystack.data + haystack.size - needle.size, needle.size) == StringRef(needle.data, needle.size); |
| 129 | } |
| 130 | } |
| 131 | |
| 132 | haystack_source.next(); |
| 133 | needle_source.next(); |
| 134 | ++row_num; |
| 135 | } |
| 136 | } |
| 137 | }; |
| 138 | |
| 139 | } |
| 140 | |