| 1 | #include "duckdb/function/cast/default_casts.hpp" |
| 2 | #include "duckdb/function/cast/cast_function_set.hpp" |
| 3 | #include "duckdb/function/cast/bound_cast_data.hpp" |
| 4 | |
| 5 | namespace duckdb { |
| 6 | |
| 7 | unique_ptr<BoundCastData> StructBoundCastData::BindStructToStructCast(BindCastInput &input, const LogicalType &source, |
| 8 | const LogicalType &target) { |
| 9 | vector<BoundCastInfo> child_cast_info; |
| 10 | auto &source_child_types = StructType::GetChildTypes(type: source); |
| 11 | auto &result_child_types = StructType::GetChildTypes(type: target); |
| 12 | if (source_child_types.size() != result_child_types.size()) { |
| 13 | throw TypeMismatchException(source, target, "Cannot cast STRUCTs of different size" ); |
| 14 | } |
| 15 | for (idx_t i = 0; i < source_child_types.size(); i++) { |
| 16 | auto child_cast = input.GetCastFunction(source: source_child_types[i].second, target: result_child_types[i].second); |
| 17 | child_cast_info.push_back(x: std::move(child_cast)); |
| 18 | } |
| 19 | return make_uniq<StructBoundCastData>(args: std::move(child_cast_info), args: target); |
| 20 | } |
| 21 | |
| 22 | unique_ptr<FunctionLocalState> StructBoundCastData::InitStructCastLocalState(CastLocalStateParameters ¶meters) { |
| 23 | auto &cast_data = parameters.cast_data->Cast<StructBoundCastData>(); |
| 24 | auto result = make_uniq<StructCastLocalState>(); |
| 25 | |
| 26 | for (auto &entry : cast_data.child_cast_info) { |
| 27 | unique_ptr<FunctionLocalState> child_state; |
| 28 | if (entry.init_local_state) { |
| 29 | CastLocalStateParameters child_params(parameters, entry.cast_data); |
| 30 | child_state = entry.init_local_state(child_params); |
| 31 | } |
| 32 | result->local_states.push_back(x: std::move(child_state)); |
| 33 | } |
| 34 | return std::move(result); |
| 35 | } |
| 36 | |
| 37 | static bool StructToStructCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { |
| 38 | auto &cast_data = parameters.cast_data->Cast<StructBoundCastData>(); |
| 39 | auto &lstate = parameters.local_state->Cast<StructCastLocalState>(); |
| 40 | auto &source_child_types = StructType::GetChildTypes(type: source.GetType()); |
| 41 | auto &source_children = StructVector::GetEntries(vector&: source); |
| 42 | D_ASSERT(source_children.size() == StructType::GetChildTypes(result.GetType()).size()); |
| 43 | |
| 44 | auto &result_children = StructVector::GetEntries(vector&: result); |
| 45 | bool all_converted = true; |
| 46 | for (idx_t c_idx = 0; c_idx < source_child_types.size(); c_idx++) { |
| 47 | auto &result_child_vector = *result_children[c_idx]; |
| 48 | auto &source_child_vector = *source_children[c_idx]; |
| 49 | CastParameters child_parameters(parameters, cast_data.child_cast_info[c_idx].cast_data, |
| 50 | lstate.local_states[c_idx]); |
| 51 | if (!cast_data.child_cast_info[c_idx].function(source_child_vector, result_child_vector, count, |
| 52 | child_parameters)) { |
| 53 | all_converted = false; |
| 54 | } |
| 55 | } |
| 56 | if (source.GetVectorType() == VectorType::CONSTANT_VECTOR) { |
| 57 | result.SetVectorType(VectorType::CONSTANT_VECTOR); |
| 58 | ConstantVector::SetNull(vector&: result, is_null: ConstantVector::IsNull(vector: source)); |
| 59 | } else { |
| 60 | source.Flatten(count); |
| 61 | FlatVector::Validity(vector&: result) = FlatVector::Validity(vector&: source); |
| 62 | } |
| 63 | return all_converted; |
| 64 | } |
| 65 | |
| 66 | static bool StructToVarcharCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { |
| 67 | auto constant = source.GetVectorType() == VectorType::CONSTANT_VECTOR; |
| 68 | // first cast all child elements to varchar |
| 69 | auto &cast_data = parameters.cast_data->Cast<StructBoundCastData>(); |
| 70 | Vector varchar_struct(cast_data.target, count); |
| 71 | StructToStructCast(source, result&: varchar_struct, count, parameters); |
| 72 | |
| 73 | // now construct the actual varchar vector |
| 74 | varchar_struct.Flatten(count); |
| 75 | auto &child_types = StructType::GetChildTypes(type: source.GetType()); |
| 76 | auto &children = StructVector::GetEntries(vector&: varchar_struct); |
| 77 | auto &validity = FlatVector::Validity(vector&: varchar_struct); |
| 78 | auto result_data = FlatVector::GetData<string_t>(vector&: result); |
| 79 | static constexpr const idx_t SEP_LENGTH = 2; |
| 80 | static constexpr const idx_t NAME_SEP_LENGTH = 4; |
| 81 | static constexpr const idx_t NULL_LENGTH = 4; |
| 82 | for (idx_t i = 0; i < count; i++) { |
| 83 | if (!validity.RowIsValid(row_idx: i)) { |
| 84 | FlatVector::SetNull(vector&: result, idx: i, is_null: true); |
| 85 | continue; |
| 86 | } |
| 87 | idx_t string_length = 2; // {} |
| 88 | for (idx_t c = 0; c < children.size(); c++) { |
| 89 | if (c > 0) { |
| 90 | string_length += SEP_LENGTH; |
| 91 | } |
| 92 | children[c]->Flatten(count); |
| 93 | auto &child_validity = FlatVector::Validity(vector&: *children[c]); |
| 94 | auto data = FlatVector::GetData<string_t>(vector&: *children[c]); |
| 95 | auto &name = child_types[c].first; |
| 96 | string_length += name.size() + NAME_SEP_LENGTH; // "'{name}': " |
| 97 | string_length += child_validity.RowIsValid(row_idx: i) ? data[i].GetSize() : NULL_LENGTH; |
| 98 | } |
| 99 | result_data[i] = StringVector::EmptyString(vector&: result, len: string_length); |
| 100 | auto dataptr = result_data[i].GetDataWriteable(); |
| 101 | idx_t offset = 0; |
| 102 | dataptr[offset++] = '{'; |
| 103 | for (idx_t c = 0; c < children.size(); c++) { |
| 104 | if (c > 0) { |
| 105 | memcpy(dest: dataptr + offset, src: ", " , n: SEP_LENGTH); |
| 106 | offset += SEP_LENGTH; |
| 107 | } |
| 108 | auto &child_validity = FlatVector::Validity(vector&: *children[c]); |
| 109 | auto data = FlatVector::GetData<string_t>(vector&: *children[c]); |
| 110 | auto &name = child_types[c].first; |
| 111 | // "'{name}': " |
| 112 | dataptr[offset++] = '\''; |
| 113 | memcpy(dest: dataptr + offset, src: name.c_str(), n: name.size()); |
| 114 | offset += name.size(); |
| 115 | dataptr[offset++] = '\''; |
| 116 | dataptr[offset++] = ':'; |
| 117 | dataptr[offset++] = ' '; |
| 118 | // value |
| 119 | if (child_validity.RowIsValid(row_idx: i)) { |
| 120 | auto len = data[i].GetSize(); |
| 121 | memcpy(dest: dataptr + offset, src: data[i].GetData(), n: len); |
| 122 | offset += len; |
| 123 | } else { |
| 124 | memcpy(dest: dataptr + offset, src: "NULL" , n: NULL_LENGTH); |
| 125 | offset += NULL_LENGTH; |
| 126 | } |
| 127 | } |
| 128 | dataptr[offset++] = '}'; |
| 129 | result_data[i].Finalize(); |
| 130 | } |
| 131 | |
| 132 | if (constant) { |
| 133 | result.SetVectorType(VectorType::CONSTANT_VECTOR); |
| 134 | } |
| 135 | return true; |
| 136 | } |
| 137 | |
| 138 | BoundCastInfo DefaultCasts::StructCastSwitch(BindCastInput &input, const LogicalType &source, |
| 139 | const LogicalType &target) { |
| 140 | switch (target.id()) { |
| 141 | case LogicalTypeId::STRUCT: |
| 142 | return BoundCastInfo(StructToStructCast, StructBoundCastData::BindStructToStructCast(input, source, target), |
| 143 | StructBoundCastData::InitStructCastLocalState); |
| 144 | case LogicalTypeId::VARCHAR: { |
| 145 | // bind a cast in which we convert all child entries to VARCHAR entries |
| 146 | auto &struct_children = StructType::GetChildTypes(type: source); |
| 147 | child_list_t<LogicalType> varchar_children; |
| 148 | for (auto &child_entry : struct_children) { |
| 149 | varchar_children.push_back(x: make_pair(x: child_entry.first, y: LogicalType::VARCHAR)); |
| 150 | } |
| 151 | auto varchar_type = LogicalType::STRUCT(children: varchar_children); |
| 152 | return BoundCastInfo(StructToVarcharCast, |
| 153 | StructBoundCastData::BindStructToStructCast(input, source, target: varchar_type), |
| 154 | StructBoundCastData::InitStructCastLocalState); |
| 155 | } |
| 156 | default: |
| 157 | return TryVectorNullCast; |
| 158 | } |
| 159 | } |
| 160 | |
| 161 | } // namespace duckdb |
| 162 | |