1#include <Columns/ColumnString.h>
2#include <Common/assert_cast.h>
3#include <DataTypes/DataTypeString.h>
4#include <Functions/FunctionFactory.h>
5#include <Functions/FunctionHelpers.h>
6#include <Functions/IFunctionImpl.h>
7#include <ext/range.h>
8
9
10namespace DB
11{
12
/// Error codes referenced by this translation unit (declared here, defined elsewhere).
namespace ErrorCodes
{
    extern const int BAD_ARGUMENTS;
    extern const int ILLEGAL_COLUMN;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
19
20
21class FunctionAppendTrailingCharIfAbsent : public IFunction
22{
23public:
24 static constexpr auto name = "appendTrailingCharIfAbsent";
25 static FunctionPtr create(const Context &)
26 {
27 return std::make_shared<FunctionAppendTrailingCharIfAbsent>();
28 }
29
30 String getName() const override
31 {
32 return name;
33 }
34
35
36private:
37 size_t getNumberOfArguments() const override
38 {
39 return 2;
40 }
41
42 DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
43 {
44 if (!isString(arguments[0]))
45 throw Exception{"Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
46
47 if (!isString(arguments[1]))
48 throw Exception{"Illegal type " + arguments[1]->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
49
50 return std::make_shared<DataTypeString>();
51 }
52
53 bool useDefaultImplementationForConstants() const override { return true; }
54 ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {1}; }
55
56 void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override
57 {
58 const auto & column = block.getByPosition(arguments[0]).column;
59 const auto & column_char = block.getByPosition(arguments[1]).column;
60
61 if (!checkColumnConst<ColumnString>(column_char.get()))
62 throw Exception{"Second argument of function " + getName() + " must be a constant string", ErrorCodes::ILLEGAL_COLUMN};
63
64 String trailing_char_str = assert_cast<const ColumnConst &>(*column_char).getValue<String>();
65
66 if (trailing_char_str.size() != 1)
67 throw Exception{"Second argument of function " + getName() + " must be a one-character string", ErrorCodes::BAD_ARGUMENTS};
68
69 if (const auto col = checkAndGetColumn<ColumnString>(column.get()))
70 {
71 auto col_res = ColumnString::create();
72
73 const auto & src_data = col->getChars();
74 const auto & src_offsets = col->getOffsets();
75
76 auto & dst_data = col_res->getChars();
77 auto & dst_offsets = col_res->getOffsets();
78
79 const auto size = src_offsets.size();
80 dst_data.resize(src_data.size() + size);
81 dst_offsets.resize(size);
82
83 ColumnString::Offset src_offset{};
84 ColumnString::Offset dst_offset{};
85
86 for (const auto i : ext::range(0, size))
87 {
88 const auto src_length = src_offsets[i] - src_offset;
89 memcpySmallAllowReadWriteOverflow15(&dst_data[dst_offset], &src_data[src_offset], src_length);
90 src_offset = src_offsets[i];
91 dst_offset += src_length;
92
93 if (src_length > 1 && dst_data[dst_offset - 2] != trailing_char_str.front())
94 {
95 dst_data[dst_offset - 1] = trailing_char_str.front();
96 dst_data[dst_offset] = 0;
97 ++dst_offset;
98 }
99
100 dst_offsets[i] = dst_offset;
101 }
102
103 dst_data.resize_assume_reserved(dst_offset);
104 block.getByPosition(result).column = std::move(col_res);
105 }
106 else
107 throw Exception{"Illegal column " + block.getByPosition(arguments[0]).column->getName() + " of argument of function " + getName(),
108 ErrorCodes::ILLEGAL_COLUMN};
109 }
110};
111
112void registerFunctionAppendTrailingCharIfAbsent(FunctionFactory & factory)
113{
114 factory.registerFunction<FunctionAppendTrailingCharIfAbsent>();
115}
116
117}
118