| 1 | #include "duckdb/function/scalar/string_functions.hpp" |
| 2 | #include "fmt/format.h" |
| 3 | #include "fmt/printf.h" |
| 4 | |
| 5 | using namespace std; |
| 6 | |
| 7 | namespace duckdb { |
| 8 | |
| 9 | struct FMTPrintf { |
| 10 | template <class ctx> |
| 11 | static string OP(const char *format_str, std::vector<fmt::basic_format_arg<ctx>> &format_args) { |
| 12 | return fmt::vsprintf(format_str, |
| 13 | fmt::basic_format_args<ctx>(format_args.data(), static_cast<int>(format_args.size()))); |
| 14 | } |
| 15 | }; |
| 16 | |
| 17 | struct FMTFormat { |
| 18 | template <class ctx> |
| 19 | static string OP(const char *format_str, std::vector<fmt::basic_format_arg<ctx>> &format_args) { |
| 20 | return fmt::vformat(format_str, |
| 21 | fmt::basic_format_args<ctx>(format_args.data(), static_cast<int>(format_args.size()))); |
| 22 | } |
| 23 | }; |
| 24 | |
| 25 | template <class FORMAT_FUN, class ctx> |
| 26 | static void printf_function(DataChunk &args, ExpressionState &state, Vector &result) { |
| 27 | auto &format_string = args.data[0]; |
| 28 | result.vector_type = VectorType::CONSTANT_VECTOR; |
| 29 | for (idx_t i = 0; i < args.column_count(); i++) { |
| 30 | switch (args.data[i].vector_type) { |
| 31 | case VectorType::CONSTANT_VECTOR: |
| 32 | if (ConstantVector::IsNull(args.data[i])) { |
| 33 | // constant null! result is always NULL regardless of other input |
| 34 | result.vector_type = VectorType::CONSTANT_VECTOR; |
| 35 | ConstantVector::SetNull(result, true); |
| 36 | return; |
| 37 | } |
| 38 | break; |
| 39 | default: |
| 40 | // FLAT VECTOR, we can directly OR the nullmask |
| 41 | args.data[i].Normalify(args.size()); |
| 42 | result.vector_type = VectorType::FLAT_VECTOR; |
| 43 | FlatVector::Nullmask(result) |= FlatVector::Nullmask(args.data[i]); |
| 44 | break; |
| 45 | } |
| 46 | } |
| 47 | idx_t count = result.vector_type == VectorType::CONSTANT_VECTOR ? 1 : args.size(); |
| 48 | |
| 49 | auto format_data = FlatVector::GetData<string_t>(format_string); |
| 50 | auto result_data = FlatVector::GetData<string_t>(result); |
| 51 | for (idx_t idx = 0; idx < count; idx++) { |
| 52 | if (result.vector_type == VectorType::FLAT_VECTOR && FlatVector::IsNull(result, idx)) { |
| 53 | // this entry is NULL: skip it |
| 54 | continue; |
| 55 | } |
| 56 | |
| 57 | // first fetch the format string |
| 58 | auto fmt_idx = format_string.vector_type == VectorType::CONSTANT_VECTOR ? 0 : idx; |
| 59 | auto format_string = format_data[fmt_idx].GetData(); |
| 60 | |
| 61 | // now gather all the format arguments |
| 62 | std::vector<fmt::basic_format_arg<ctx>> format_args; |
| 63 | for (idx_t col_idx = 1; col_idx < args.column_count(); col_idx++) { |
| 64 | auto &col = args.data[col_idx]; |
| 65 | idx_t arg_idx = col.vector_type == VectorType::CONSTANT_VECTOR ? 0 : idx; |
| 66 | switch (col.type) { |
| 67 | case TypeId::BOOL: { |
| 68 | auto arg_data = FlatVector::GetData<bool>(col); |
| 69 | format_args.emplace_back(fmt::internal::make_arg<ctx>(arg_data[arg_idx])); |
| 70 | break; |
| 71 | } |
| 72 | case TypeId::INT8: { |
| 73 | auto arg_data = FlatVector::GetData<int8_t>(col); |
| 74 | format_args.emplace_back(fmt::internal::make_arg<ctx>(arg_data[arg_idx])); |
| 75 | break; |
| 76 | } |
| 77 | case TypeId::INT16: { |
| 78 | auto arg_data = FlatVector::GetData<int8_t>(col); |
| 79 | format_args.emplace_back(fmt::internal::make_arg<ctx>(arg_data[arg_idx])); |
| 80 | break; |
| 81 | } |
| 82 | case TypeId::INT32: { |
| 83 | auto arg_data = FlatVector::GetData<int32_t>(col); |
| 84 | format_args.emplace_back(fmt::internal::make_arg<ctx>(arg_data[arg_idx])); |
| 85 | break; |
| 86 | } |
| 87 | case TypeId::INT64: { |
| 88 | auto arg_data = FlatVector::GetData<int64_t>(col); |
| 89 | format_args.emplace_back(fmt::internal::make_arg<ctx>(arg_data[arg_idx])); |
| 90 | break; |
| 91 | } |
| 92 | case TypeId::FLOAT: { |
| 93 | auto arg_data = FlatVector::GetData<float>(col); |
| 94 | format_args.emplace_back(fmt::internal::make_arg<ctx>(arg_data[arg_idx])); |
| 95 | break; |
| 96 | } |
| 97 | case TypeId::DOUBLE: { |
| 98 | auto arg_data = FlatVector::GetData<double>(col); |
| 99 | format_args.emplace_back(fmt::internal::make_arg<ctx>(arg_data[arg_idx])); |
| 100 | break; |
| 101 | } |
| 102 | case TypeId::VARCHAR: { |
| 103 | auto arg_data = FlatVector::GetData<string_t>(col); |
| 104 | format_args.emplace_back(fmt::internal::make_arg<ctx>(arg_data[arg_idx].GetData())); |
| 105 | break; |
| 106 | } |
| 107 | default: |
| 108 | throw Exception("Unsupported type for format!" ); |
| 109 | } |
| 110 | } |
| 111 | // finally actually perform the format |
| 112 | string dynamic_result = FORMAT_FUN::template OP<ctx>(format_string, format_args); |
| 113 | result_data[idx] = StringVector::AddString(result, dynamic_result); |
| 114 | } |
| 115 | } |
| 116 | |
| 117 | void PrintfFun::RegisterFunction(BuiltinFunctions &set) { |
| 118 | // fmt::printf_context, fmt::vsprintf |
| 119 | ScalarFunction printf_fun = |
| 120 | ScalarFunction("printf" , {SQLType::VARCHAR}, SQLType::VARCHAR, printf_function<FMTPrintf, fmt::printf_context>); |
| 121 | printf_fun.varargs = SQLType::ANY; |
| 122 | set.AddFunction(printf_fun); |
| 123 | |
| 124 | // fmt::format_context, fmt::vformat |
| 125 | ScalarFunction format_fun = |
| 126 | ScalarFunction("format" , {SQLType::VARCHAR}, SQLType::VARCHAR, printf_function<FMTFormat, fmt::format_context>); |
| 127 | format_fun.varargs = SQLType::ANY; |
| 128 | set.AddFunction(format_fun); |
| 129 | } |
| 130 | |
| 131 | } // namespace duckdb |
| 132 | |