1 | #include "duckdb/function/scalar/string_functions.hpp" |
2 | #include "fmt/format.h" |
3 | #include "fmt/printf.h" |
4 | |
5 | using namespace std; |
6 | |
7 | namespace duckdb { |
8 | |
9 | struct FMTPrintf { |
10 | template <class ctx> |
11 | static string OP(const char *format_str, std::vector<fmt::basic_format_arg<ctx>> &format_args) { |
12 | return fmt::vsprintf(format_str, |
13 | fmt::basic_format_args<ctx>(format_args.data(), static_cast<int>(format_args.size()))); |
14 | } |
15 | }; |
16 | |
17 | struct FMTFormat { |
18 | template <class ctx> |
19 | static string OP(const char *format_str, std::vector<fmt::basic_format_arg<ctx>> &format_args) { |
20 | return fmt::vformat(format_str, |
21 | fmt::basic_format_args<ctx>(format_args.data(), static_cast<int>(format_args.size()))); |
22 | } |
23 | }; |
24 | |
25 | template <class FORMAT_FUN, class ctx> |
26 | static void printf_function(DataChunk &args, ExpressionState &state, Vector &result) { |
27 | auto &format_string = args.data[0]; |
28 | result.vector_type = VectorType::CONSTANT_VECTOR; |
29 | for (idx_t i = 0; i < args.column_count(); i++) { |
30 | switch (args.data[i].vector_type) { |
31 | case VectorType::CONSTANT_VECTOR: |
32 | if (ConstantVector::IsNull(args.data[i])) { |
33 | // constant null! result is always NULL regardless of other input |
34 | result.vector_type = VectorType::CONSTANT_VECTOR; |
35 | ConstantVector::SetNull(result, true); |
36 | return; |
37 | } |
38 | break; |
39 | default: |
40 | // FLAT VECTOR, we can directly OR the nullmask |
41 | args.data[i].Normalify(args.size()); |
42 | result.vector_type = VectorType::FLAT_VECTOR; |
43 | FlatVector::Nullmask(result) |= FlatVector::Nullmask(args.data[i]); |
44 | break; |
45 | } |
46 | } |
47 | idx_t count = result.vector_type == VectorType::CONSTANT_VECTOR ? 1 : args.size(); |
48 | |
49 | auto format_data = FlatVector::GetData<string_t>(format_string); |
50 | auto result_data = FlatVector::GetData<string_t>(result); |
51 | for (idx_t idx = 0; idx < count; idx++) { |
52 | if (result.vector_type == VectorType::FLAT_VECTOR && FlatVector::IsNull(result, idx)) { |
53 | // this entry is NULL: skip it |
54 | continue; |
55 | } |
56 | |
57 | // first fetch the format string |
58 | auto fmt_idx = format_string.vector_type == VectorType::CONSTANT_VECTOR ? 0 : idx; |
59 | auto format_string = format_data[fmt_idx].GetData(); |
60 | |
61 | // now gather all the format arguments |
62 | std::vector<fmt::basic_format_arg<ctx>> format_args; |
63 | for (idx_t col_idx = 1; col_idx < args.column_count(); col_idx++) { |
64 | auto &col = args.data[col_idx]; |
65 | idx_t arg_idx = col.vector_type == VectorType::CONSTANT_VECTOR ? 0 : idx; |
66 | switch (col.type) { |
67 | case TypeId::BOOL: { |
68 | auto arg_data = FlatVector::GetData<bool>(col); |
69 | format_args.emplace_back(fmt::internal::make_arg<ctx>(arg_data[arg_idx])); |
70 | break; |
71 | } |
72 | case TypeId::INT8: { |
73 | auto arg_data = FlatVector::GetData<int8_t>(col); |
74 | format_args.emplace_back(fmt::internal::make_arg<ctx>(arg_data[arg_idx])); |
75 | break; |
76 | } |
77 | case TypeId::INT16: { |
78 | auto arg_data = FlatVector::GetData<int8_t>(col); |
79 | format_args.emplace_back(fmt::internal::make_arg<ctx>(arg_data[arg_idx])); |
80 | break; |
81 | } |
82 | case TypeId::INT32: { |
83 | auto arg_data = FlatVector::GetData<int32_t>(col); |
84 | format_args.emplace_back(fmt::internal::make_arg<ctx>(arg_data[arg_idx])); |
85 | break; |
86 | } |
87 | case TypeId::INT64: { |
88 | auto arg_data = FlatVector::GetData<int64_t>(col); |
89 | format_args.emplace_back(fmt::internal::make_arg<ctx>(arg_data[arg_idx])); |
90 | break; |
91 | } |
92 | case TypeId::FLOAT: { |
93 | auto arg_data = FlatVector::GetData<float>(col); |
94 | format_args.emplace_back(fmt::internal::make_arg<ctx>(arg_data[arg_idx])); |
95 | break; |
96 | } |
97 | case TypeId::DOUBLE: { |
98 | auto arg_data = FlatVector::GetData<double>(col); |
99 | format_args.emplace_back(fmt::internal::make_arg<ctx>(arg_data[arg_idx])); |
100 | break; |
101 | } |
102 | case TypeId::VARCHAR: { |
103 | auto arg_data = FlatVector::GetData<string_t>(col); |
104 | format_args.emplace_back(fmt::internal::make_arg<ctx>(arg_data[arg_idx].GetData())); |
105 | break; |
106 | } |
107 | default: |
108 | throw Exception("Unsupported type for format!" ); |
109 | } |
110 | } |
111 | // finally actually perform the format |
112 | string dynamic_result = FORMAT_FUN::template OP<ctx>(format_string, format_args); |
113 | result_data[idx] = StringVector::AddString(result, dynamic_result); |
114 | } |
115 | } |
116 | |
117 | void PrintfFun::RegisterFunction(BuiltinFunctions &set) { |
118 | // fmt::printf_context, fmt::vsprintf |
119 | ScalarFunction printf_fun = |
120 | ScalarFunction("printf" , {SQLType::VARCHAR}, SQLType::VARCHAR, printf_function<FMTPrintf, fmt::printf_context>); |
121 | printf_fun.varargs = SQLType::ANY; |
122 | set.AddFunction(printf_fun); |
123 | |
124 | // fmt::format_context, fmt::vformat |
125 | ScalarFunction format_fun = |
126 | ScalarFunction("format" , {SQLType::VARCHAR}, SQLType::VARCHAR, printf_function<FMTFormat, fmt::format_context>); |
127 | format_fun.varargs = SQLType::ANY; |
128 | set.AddFunction(format_fun); |
129 | } |
130 | |
131 | } // namespace duckdb |
132 | |