| 1 | //===----------------------------------------------------------------------===// |
| 2 | // DuckDB |
| 3 | // |
| 4 | // duckdb/function/function_serialization.hpp |
| 5 | // |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #pragma once |
| 10 | |
| 11 | #include "duckdb/common/field_writer.hpp" |
| 12 | #include "duckdb/main/client_context.hpp" |
| 13 | #include "duckdb/catalog/catalog_entry/table_function_catalog_entry.hpp" |
| 14 | |
| 15 | namespace duckdb { |
| 16 | |
| 17 | class FunctionSerializer { |
| 18 | public: |
| 19 | template <class FUNC> |
| 20 | static void SerializeBase(FieldWriter &writer, const FUNC &function, FunctionData *bind_info) { |
| 21 | D_ASSERT(!function.name.empty()); |
| 22 | writer.WriteString(val: function.name); |
| 23 | writer.WriteRegularSerializableList(function.arguments); |
| 24 | writer.WriteRegularSerializableList(function.original_arguments); |
| 25 | bool serialize = function.serialize; |
| 26 | writer.WriteField(element: serialize); |
| 27 | if (serialize) { |
| 28 | function.serialize(writer, bind_info, function); |
| 29 | // First check if serialize throws a NotImplementedException, in which case it doesn't require a deserialize |
| 30 | // function |
| 31 | D_ASSERT(function.deserialize); |
| 32 | } |
| 33 | } |
| 34 | |
| 35 | template <class FUNC> |
| 36 | static void Serialize(FieldWriter &writer, const FUNC &function, const LogicalType &return_type, |
| 37 | const vector<unique_ptr<Expression>> &children, FunctionData *bind_info) { |
| 38 | SerializeBase(writer, function, bind_info); |
| 39 | writer.WriteSerializable(element: return_type); |
| 40 | writer.WriteSerializableList(elements: children); |
| 41 | } |
| 42 | |
| 43 | template <class FUNC, class CATALOG_ENTRY> |
| 44 | static FUNC DeserializeBaseInternal(FieldReader &reader, PlanDeserializationState &state, CatalogType type, |
| 45 | unique_ptr<FunctionData> &bind_info, bool &has_deserialize) { |
| 46 | auto &context = state.context; |
| 47 | auto name = reader.ReadRequired<string>(); |
| 48 | auto arguments = reader.ReadRequiredSerializableList<LogicalType, LogicalType>(); |
| 49 | // note: original_arguments are optional (can be list of size 0) |
| 50 | auto original_arguments = reader.ReadRequiredSerializableList<LogicalType, LogicalType>(); |
| 51 | |
| 52 | auto &func_catalog = Catalog::GetEntry(context, type, SYSTEM_CATALOG, DEFAULT_SCHEMA, name); |
| 53 | if (func_catalog.type != type) { |
| 54 | throw InternalException("Cant find catalog entry for function %s" , name); |
| 55 | } |
| 56 | |
| 57 | auto &functions = func_catalog.Cast<CATALOG_ENTRY>(); |
| 58 | auto function = functions.functions.GetFunctionByArguments( |
| 59 | state.context, original_arguments.empty() ? arguments : original_arguments); |
| 60 | function.arguments = std::move(arguments); |
| 61 | function.original_arguments = std::move(original_arguments); |
| 62 | |
| 63 | has_deserialize = reader.ReadRequired<bool>(); |
| 64 | if (has_deserialize) { |
| 65 | if (!function.deserialize) { |
| 66 | throw SerializationException("Function requires deserialization but no deserialization function for %s" , |
| 67 | function.name); |
| 68 | } |
| 69 | bind_info = function.deserialize(state, reader, function); |
| 70 | } else { |
| 71 | D_ASSERT(!function.serialize); |
| 72 | D_ASSERT(!function.deserialize); |
| 73 | } |
| 74 | return function; |
| 75 | } |
| 76 | template <class FUNC, class CATALOG_ENTRY> |
| 77 | static FUNC DeserializeBase(FieldReader &reader, PlanDeserializationState &state, CatalogType type, |
| 78 | unique_ptr<FunctionData> &bind_info) { |
| 79 | bool has_deserialize; |
| 80 | return DeserializeBaseInternal<FUNC, CATALOG_ENTRY>(reader, state, type, bind_info, has_deserialize); |
| 81 | } |
| 82 | |
| 83 | template <class FUNC, class CATALOG_ENTRY> |
| 84 | static FUNC Deserialize(FieldReader &reader, ExpressionDeserializationState &state, CatalogType type, |
| 85 | vector<unique_ptr<Expression>> &children, unique_ptr<FunctionData> &bind_info) { |
| 86 | bool has_deserialize; |
| 87 | auto function = |
| 88 | DeserializeBaseInternal<FUNC, CATALOG_ENTRY>(reader, state.gstate, type, bind_info, has_deserialize); |
| 89 | auto return_type = reader.ReadRequiredSerializable<LogicalType, LogicalType>(); |
| 90 | children = reader.ReadRequiredSerializableList<Expression>(args&: state.gstate); |
| 91 | |
| 92 | // we re-bind the function only if the function did not have an explicit deserialize method |
| 93 | auto &context = state.gstate.context; |
| 94 | if (!has_deserialize && function.bind) { |
| 95 | bind_info = function.bind(context, function, children); |
| 96 | } |
| 97 | function.return_type = return_type; |
| 98 | return function; |
| 99 | } |
| 100 | }; |
| 101 | |
| 102 | } // namespace duckdb |
| 103 | |