1#include "config_functions.h"
2#if USE_BASE64
3#include <Columns/ColumnConst.h>
4#include <Columns/ColumnString.h>
5#include <DataTypes/DataTypeString.h>
6#include <Functions/FunctionFactory.h>
7#include <Functions/FunctionHelpers.h>
8#include <Functions/GatherUtils/Algorithms.h>
9#include <IO/WriteHelpers.h>
10#include <libbase64.h>
11
12
13namespace DB
14{
15using namespace GatherUtils;
16
/// Error codes referenced by the base64 conversion functions.
/// Only declared here; the definitions are expected to live elsewhere in the project.
namespace ErrorCodes
{
    extern const int ILLEGAL_COLUMN;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int INCORRECT_DATA;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
24
/// Tag type that selects the encoding behaviour of FunctionBase64Conversion.
struct Base64Encode
{
    static constexpr auto name = "base64Encode";

    /** Upper bound on the number of output bytes needed to encode a batch of strings.
      *
      * string_length - total size of the source character buffer; as called from
      *                 FunctionBase64Conversion::executeImpl it includes one
      *                 terminating byte per string.
      * string_count  - number of strings in the batch.
      *
      * Every started group of 3 payload bytes yields 4 output characters, and each
      * output string gets one extra byte for its terminator.
      */
    static size_t getBufferSize(size_t string_length, size_t string_count)
    {
        const size_t payload_bytes = string_length - string_count;
        /// At most one partially filled group per string on top of the full groups.
        const size_t max_groups = payload_bytes / 3 + string_count;
        return max_groups * 4 + string_count;
    }
};
33
/// Tag type that selects the strict decoding behaviour of FunctionBase64Conversion
/// (invalid input raises an exception there).
struct Base64Decode
{
    static constexpr auto name = "base64Decode";

    /** Upper bound on the number of output bytes needed to decode a batch of strings.
      *
      * string_length - total size of the source character buffer; as called from
      *                 FunctionBase64Conversion::executeImpl it includes one
      *                 terminating byte per string.
      * string_count  - number of strings in the batch.
      *
      * Every started group of 4 input characters yields at most 3 output bytes, and
      * each output string gets one extra byte for its terminator.
      */
    static size_t getBufferSize(size_t string_length, size_t string_count)
    {
        const size_t payload_bytes = string_length - string_count;
        /// At most one partially filled group per string on top of the full groups.
        const size_t max_groups = payload_bytes / 4 + string_count;
        return max_groups * 3 + string_count;
    }
};
43
44struct TryBase64Decode
45{
46 static constexpr auto name = "tryBase64Decode";
47
48 static size_t getBufferSize(size_t string_length, size_t string_count)
49 {
50 return Base64Decode::getBufferSize(string_length, string_count);
51 }
52};
53
54template <typename Func>
55class FunctionBase64Conversion : public IFunction
56{
57public:
58 static constexpr auto name = Func::name;
59
60 static FunctionPtr create(const Context &)
61 {
62 return std::make_shared<FunctionBase64Conversion>();
63 }
64
65 String getName() const override
66 {
67 return Func::name;
68 }
69
70 size_t getNumberOfArguments() const override
71 {
72 return 1;
73 }
74
75 bool useDefaultImplementationForConstants() const override
76 {
77 return true;
78 }
79
80 DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
81 {
82 if (!WhichDataType(arguments[0].type).isString())
83 throw Exception(
84 "Illegal type " + arguments[0].type->getName() + " of 1 argument of function " + getName() + ". Must be String.",
85 ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
86
87 return std::make_shared<DataTypeString>();
88 }
89
90 void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override
91 {
92 const ColumnPtr column_string = block.getByPosition(arguments[0]).column;
93 const ColumnString * input = checkAndGetColumn<ColumnString>(column_string.get());
94
95 if (!input)
96 throw Exception(
97 "Illegal column " + block.getByPosition(arguments[0]).column->getName() + " of first argument of function " + getName(),
98 ErrorCodes::ILLEGAL_COLUMN);
99
100 auto dst_column = ColumnString::create();
101 auto & dst_data = dst_column->getChars();
102 auto & dst_offsets = dst_column->getOffsets();
103
104 size_t reserve = Func::getBufferSize(input->getChars().size(), input->size());
105 dst_data.resize(reserve);
106 dst_offsets.resize(input_rows_count);
107
108 const ColumnString::Offsets & src_offsets = input->getOffsets();
109
110 auto source = reinterpret_cast<const char *>(input->getChars().data());
111 auto dst = reinterpret_cast<char *>(dst_data.data());
112 auto dst_pos = dst;
113
114 size_t src_offset_prev = 0;
115
116 int codec = getCodec();
117 for (size_t row = 0; row < input_rows_count; ++row)
118 {
119 size_t srclen = src_offsets[row] - src_offset_prev - 1;
120 size_t outlen = 0;
121
122 if constexpr (std::is_same_v<Func, Base64Encode>)
123 {
124 base64_encode(source, srclen, dst_pos, &outlen, codec);
125 }
126 else if constexpr (std::is_same_v<Func, Base64Decode>)
127 {
128 if (!base64_decode(source, srclen, dst_pos, &outlen, codec))
129 {
130 throw Exception("Failed to " + getName() + " input '" + String(source, srclen) + "'", ErrorCodes::INCORRECT_DATA);
131 }
132 }
133 else
134 {
135 // during decoding character array can be partially polluted
136 // if fail, revert back and clean
137 auto savepoint = dst_pos;
138 if (!base64_decode(source, srclen, dst_pos, &outlen, codec))
139 {
140 outlen = 0;
141 dst_pos = savepoint;
142 // clean the symbol
143 dst_pos[0] = 0;
144 }
145 }
146
147 source += srclen + 1;
148 dst_pos += outlen + 1;
149
150 dst_offsets[row] = dst_pos - dst;
151 src_offset_prev = src_offsets[row];
152 }
153
154 dst_data.resize(dst_pos - dst);
155
156 block.getByPosition(result).column = std::move(dst_column);
157 }
158
159private:
160 static int getCodec()
161 {
162 /// You can provide different value if you want to test specific codecs.
163 /// Due to poor implementation of "base64" library (it will write to a global variable),
164 /// it doesn't scale for multiple threads. Never use non-zero values in production.
165 return 0;
166 }
167};
168}
169
170/** We must call it in advance from a single thread
171 * to avoid thread sanitizer report about data race in "codec_choose" function.
172 */
173inline void initializeBase64()
174{
175 size_t outlen = 0;
176 base64_encode(nullptr, 0, nullptr, &outlen, 0);
177}
178
179#endif
180