1 | #include "duckdb/function/scalar/string_functions.hpp" |
2 | |
3 | #include "duckdb/common/exception.hpp" |
4 | #include "duckdb/common/vector_operations/vector_operations.hpp" |
5 | #include "duckdb/common/vector_operations/unary_executor.hpp" |
6 | #include "utf8proc.hpp" |
7 | |
8 | #include <string.h> |
9 | |
10 | using namespace std; |
11 | |
12 | namespace duckdb { |
13 | |
14 | template <bool IS_UPPER> static string_t strcase_unicode(Vector &result, const char *input_data, idx_t input_length) { |
15 | // first figure out the output length |
16 | // optimization: if only ascii then input_length = output_length |
17 | idx_t output_length = 0; |
18 | for (idx_t i = 0; i < input_length;) { |
19 | if (input_data[i] & 0x80) { |
20 | // unicode |
21 | int sz = 0; |
22 | int codepoint = utf8proc_codepoint(input_data + i, sz); |
23 | int converted_codepoint = IS_UPPER ? utf8proc_toupper(codepoint) : utf8proc_tolower(codepoint); |
24 | int new_sz = utf8proc_codepoint_length(converted_codepoint); |
25 | assert(new_sz >= 0); |
26 | output_length += new_sz; |
27 | i += sz; |
28 | } else { |
29 | // ascii |
30 | output_length++; |
31 | i++; |
32 | } |
33 | } |
34 | auto result_str = StringVector::EmptyString(result, output_length); |
35 | auto result_data = result_str.GetData(); |
36 | |
37 | for (idx_t i = 0; i < input_length;) { |
38 | if (input_data[i] & 0x80) { |
39 | // non-ascii character |
40 | int sz = 0, new_sz = 0; |
41 | int codepoint = utf8proc_codepoint(input_data + i, sz); |
42 | int converted_codepoint = IS_UPPER ? utf8proc_toupper(codepoint) : utf8proc_tolower(codepoint); |
43 | const auto success = utf8proc_codepoint_to_utf8(converted_codepoint, new_sz, result_data); |
44 | assert(success); |
45 | result_data += new_sz; |
46 | i += sz; |
47 | } else { |
48 | // ascii |
49 | *result_data = IS_UPPER ? toupper(input_data[i]) : tolower(input_data[i]); |
50 | result_data++; |
51 | i++; |
52 | } |
53 | } |
54 | result_str.Finalize(); |
55 | return result_str; |
56 | } |
57 | |
58 | template <bool IS_UPPER> static void caseconvert_function(Vector &input, Vector &result, idx_t count) { |
59 | assert(input.type == TypeId::VARCHAR); |
60 | |
61 | UnaryExecutor::Execute<string_t, string_t, true>(input, result, count, [&](string_t input) { |
62 | auto input_data = input.GetData(); |
63 | auto input_length = input.GetSize(); |
64 | return strcase_unicode<IS_UPPER>(result, input_data, input_length); |
65 | }); |
66 | } |
67 | |
68 | static void caseconvert_upper_function(DataChunk &args, ExpressionState &state, Vector &result) { |
69 | assert(args.column_count() == 1); |
70 | caseconvert_function<true>(args.data[0], result, args.size()); |
71 | } |
72 | |
73 | static void caseconvert_lower_function(DataChunk &args, ExpressionState &state, Vector &result) { |
74 | assert(args.column_count() == 1); |
75 | caseconvert_function<false>(args.data[0], result, args.size()); |
76 | } |
77 | |
78 | ScalarFunction LowerFun::GetFunction() { |
79 | return ScalarFunction({SQLType::VARCHAR}, SQLType::VARCHAR, caseconvert_lower_function); |
80 | } |
81 | |
82 | void LowerFun::RegisterFunction(BuiltinFunctions &set) { |
83 | set.AddFunction({"lower" , "lcase" }, LowerFun::GetFunction()); |
84 | } |
85 | |
86 | void UpperFun::RegisterFunction(BuiltinFunctions &set) { |
87 | set.AddFunction({"upper" , "ucase" }, |
88 | ScalarFunction({SQLType::VARCHAR}, SQLType::VARCHAR, caseconvert_upper_function)); |
89 | } |
90 | |
91 | } // namespace duckdb |
92 | |