1#include "duckdb/function/scalar/string_functions.hpp"
2
3#include "duckdb/common/exception.hpp"
4#include "duckdb/common/vector_operations/vector_operations.hpp"
5#include "duckdb/common/vector_operations/unary_executor.hpp"
6#include "utf8proc.hpp"
7
8#include <string.h>
9
10using namespace std;
11
12namespace duckdb {
13
14template <bool IS_UPPER> static string_t strcase_unicode(Vector &result, const char *input_data, idx_t input_length) {
15 // first figure out the output length
16 // optimization: if only ascii then input_length = output_length
17 idx_t output_length = 0;
18 for (idx_t i = 0; i < input_length;) {
19 if (input_data[i] & 0x80) {
20 // unicode
21 int sz = 0;
22 int codepoint = utf8proc_codepoint(input_data + i, sz);
23 int converted_codepoint = IS_UPPER ? utf8proc_toupper(codepoint) : utf8proc_tolower(codepoint);
24 int new_sz = utf8proc_codepoint_length(converted_codepoint);
25 assert(new_sz >= 0);
26 output_length += new_sz;
27 i += sz;
28 } else {
29 // ascii
30 output_length++;
31 i++;
32 }
33 }
34 auto result_str = StringVector::EmptyString(result, output_length);
35 auto result_data = result_str.GetData();
36
37 for (idx_t i = 0; i < input_length;) {
38 if (input_data[i] & 0x80) {
39 // non-ascii character
40 int sz = 0, new_sz = 0;
41 int codepoint = utf8proc_codepoint(input_data + i, sz);
42 int converted_codepoint = IS_UPPER ? utf8proc_toupper(codepoint) : utf8proc_tolower(codepoint);
43 const auto success = utf8proc_codepoint_to_utf8(converted_codepoint, new_sz, result_data);
44 assert(success);
45 result_data += new_sz;
46 i += sz;
47 } else {
48 // ascii
49 *result_data = IS_UPPER ? toupper(input_data[i]) : tolower(input_data[i]);
50 result_data++;
51 i++;
52 }
53 }
54 result_str.Finalize();
55 return result_str;
56}
57
58template <bool IS_UPPER> static void caseconvert_function(Vector &input, Vector &result, idx_t count) {
59 assert(input.type == TypeId::VARCHAR);
60
61 UnaryExecutor::Execute<string_t, string_t, true>(input, result, count, [&](string_t input) {
62 auto input_data = input.GetData();
63 auto input_length = input.GetSize();
64 return strcase_unicode<IS_UPPER>(result, input_data, input_length);
65 });
66}
67
68static void caseconvert_upper_function(DataChunk &args, ExpressionState &state, Vector &result) {
69 assert(args.column_count() == 1);
70 caseconvert_function<true>(args.data[0], result, args.size());
71}
72
73static void caseconvert_lower_function(DataChunk &args, ExpressionState &state, Vector &result) {
74 assert(args.column_count() == 1);
75 caseconvert_function<false>(args.data[0], result, args.size());
76}
77
78ScalarFunction LowerFun::GetFunction() {
79 return ScalarFunction({SQLType::VARCHAR}, SQLType::VARCHAR, caseconvert_lower_function);
80}
81
82void LowerFun::RegisterFunction(BuiltinFunctions &set) {
83 set.AddFunction({"lower", "lcase"}, LowerFun::GetFunction());
84}
85
86void UpperFun::RegisterFunction(BuiltinFunctions &set) {
87 set.AddFunction({"upper", "ucase"},
88 ScalarFunction({SQLType::VARCHAR}, SQLType::VARCHAR, caseconvert_upper_function));
89}
90
91} // namespace duckdb
92