1#include "duckdb/function/scalar/string_functions.hpp"
2#include "fmt/format.h"
3#include "fmt/printf.h"
4
5using namespace std;
6
7namespace duckdb {
8
9struct FMTPrintf {
10 template <class ctx>
11 static string OP(const char *format_str, std::vector<fmt::basic_format_arg<ctx>> &format_args) {
12 return fmt::vsprintf(format_str,
13 fmt::basic_format_args<ctx>(format_args.data(), static_cast<int>(format_args.size())));
14 }
15};
16
17struct FMTFormat {
18 template <class ctx>
19 static string OP(const char *format_str, std::vector<fmt::basic_format_arg<ctx>> &format_args) {
20 return fmt::vformat(format_str,
21 fmt::basic_format_args<ctx>(format_args.data(), static_cast<int>(format_args.size())));
22 }
23};
24
25template <class FORMAT_FUN, class ctx>
26static void printf_function(DataChunk &args, ExpressionState &state, Vector &result) {
27 auto &format_string = args.data[0];
28 result.vector_type = VectorType::CONSTANT_VECTOR;
29 for (idx_t i = 0; i < args.column_count(); i++) {
30 switch (args.data[i].vector_type) {
31 case VectorType::CONSTANT_VECTOR:
32 if (ConstantVector::IsNull(args.data[i])) {
33 // constant null! result is always NULL regardless of other input
34 result.vector_type = VectorType::CONSTANT_VECTOR;
35 ConstantVector::SetNull(result, true);
36 return;
37 }
38 break;
39 default:
40 // FLAT VECTOR, we can directly OR the nullmask
41 args.data[i].Normalify(args.size());
42 result.vector_type = VectorType::FLAT_VECTOR;
43 FlatVector::Nullmask(result) |= FlatVector::Nullmask(args.data[i]);
44 break;
45 }
46 }
47 idx_t count = result.vector_type == VectorType::CONSTANT_VECTOR ? 1 : args.size();
48
49 auto format_data = FlatVector::GetData<string_t>(format_string);
50 auto result_data = FlatVector::GetData<string_t>(result);
51 for (idx_t idx = 0; idx < count; idx++) {
52 if (result.vector_type == VectorType::FLAT_VECTOR && FlatVector::IsNull(result, idx)) {
53 // this entry is NULL: skip it
54 continue;
55 }
56
57 // first fetch the format string
58 auto fmt_idx = format_string.vector_type == VectorType::CONSTANT_VECTOR ? 0 : idx;
59 auto format_string = format_data[fmt_idx].GetData();
60
61 // now gather all the format arguments
62 std::vector<fmt::basic_format_arg<ctx>> format_args;
63 for (idx_t col_idx = 1; col_idx < args.column_count(); col_idx++) {
64 auto &col = args.data[col_idx];
65 idx_t arg_idx = col.vector_type == VectorType::CONSTANT_VECTOR ? 0 : idx;
66 switch (col.type) {
67 case TypeId::BOOL: {
68 auto arg_data = FlatVector::GetData<bool>(col);
69 format_args.emplace_back(fmt::internal::make_arg<ctx>(arg_data[arg_idx]));
70 break;
71 }
72 case TypeId::INT8: {
73 auto arg_data = FlatVector::GetData<int8_t>(col);
74 format_args.emplace_back(fmt::internal::make_arg<ctx>(arg_data[arg_idx]));
75 break;
76 }
77 case TypeId::INT16: {
78 auto arg_data = FlatVector::GetData<int8_t>(col);
79 format_args.emplace_back(fmt::internal::make_arg<ctx>(arg_data[arg_idx]));
80 break;
81 }
82 case TypeId::INT32: {
83 auto arg_data = FlatVector::GetData<int32_t>(col);
84 format_args.emplace_back(fmt::internal::make_arg<ctx>(arg_data[arg_idx]));
85 break;
86 }
87 case TypeId::INT64: {
88 auto arg_data = FlatVector::GetData<int64_t>(col);
89 format_args.emplace_back(fmt::internal::make_arg<ctx>(arg_data[arg_idx]));
90 break;
91 }
92 case TypeId::FLOAT: {
93 auto arg_data = FlatVector::GetData<float>(col);
94 format_args.emplace_back(fmt::internal::make_arg<ctx>(arg_data[arg_idx]));
95 break;
96 }
97 case TypeId::DOUBLE: {
98 auto arg_data = FlatVector::GetData<double>(col);
99 format_args.emplace_back(fmt::internal::make_arg<ctx>(arg_data[arg_idx]));
100 break;
101 }
102 case TypeId::VARCHAR: {
103 auto arg_data = FlatVector::GetData<string_t>(col);
104 format_args.emplace_back(fmt::internal::make_arg<ctx>(arg_data[arg_idx].GetData()));
105 break;
106 }
107 default:
108 throw Exception("Unsupported type for format!");
109 }
110 }
111 // finally actually perform the format
112 string dynamic_result = FORMAT_FUN::template OP<ctx>(format_string, format_args);
113 result_data[idx] = StringVector::AddString(result, dynamic_result);
114 }
115}
116
117void PrintfFun::RegisterFunction(BuiltinFunctions &set) {
118 // fmt::printf_context, fmt::vsprintf
119 ScalarFunction printf_fun =
120 ScalarFunction("printf", {SQLType::VARCHAR}, SQLType::VARCHAR, printf_function<FMTPrintf, fmt::printf_context>);
121 printf_fun.varargs = SQLType::ANY;
122 set.AddFunction(printf_fun);
123
124 // fmt::format_context, fmt::vformat
125 ScalarFunction format_fun =
126 ScalarFunction("format", {SQLType::VARCHAR}, SQLType::VARCHAR, printf_function<FMTFormat, fmt::format_context>);
127 format_fun.varargs = SQLType::ANY;
128 set.AddFunction(format_fun);
129}
130
131} // namespace duckdb
132