| 1 | #include "duckdb/function/cast/default_casts.hpp" |
| 2 | #include "duckdb/function/cast/cast_function_set.hpp" |
| 3 | #include "duckdb/function/cast/bound_cast_data.hpp" |
| 4 | |
| 5 | namespace duckdb { |
| 6 | |
| 7 | unique_ptr<BoundCastData> ListBoundCastData::BindListToListCast(BindCastInput &input, const LogicalType &source, |
| 8 | const LogicalType &target) { |
| 9 | vector<BoundCastInfo> child_cast_info; |
| 10 | auto &source_child_type = ListType::GetChildType(type: source); |
| 11 | auto &result_child_type = ListType::GetChildType(type: target); |
| 12 | auto child_cast = input.GetCastFunction(source: source_child_type, target: result_child_type); |
| 13 | return make_uniq<ListBoundCastData>(args: std::move(child_cast)); |
| 14 | } |
| 15 | |
| 16 | unique_ptr<FunctionLocalState> ListBoundCastData::InitListLocalState(CastLocalStateParameters ¶meters) { |
| 17 | auto &cast_data = parameters.cast_data->Cast<ListBoundCastData>(); |
| 18 | if (!cast_data.child_cast_info.init_local_state) { |
| 19 | return nullptr; |
| 20 | } |
| 21 | CastLocalStateParameters child_parameters(parameters, cast_data.child_cast_info.cast_data); |
| 22 | return cast_data.child_cast_info.init_local_state(child_parameters); |
| 23 | } |
| 24 | |
| 25 | bool ListCast::ListToListCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { |
| 26 | auto &cast_data = parameters.cast_data->Cast<ListBoundCastData>(); |
| 27 | |
| 28 | // only handle constant and flat vectors here for now |
| 29 | if (source.GetVectorType() == VectorType::CONSTANT_VECTOR) { |
| 30 | result.SetVectorType(source.GetVectorType()); |
| 31 | ConstantVector::SetNull(vector&: result, is_null: ConstantVector::IsNull(vector: source)); |
| 32 | |
| 33 | auto ldata = ConstantVector::GetData<list_entry_t>(vector&: source); |
| 34 | auto tdata = ConstantVector::GetData<list_entry_t>(vector&: result); |
| 35 | *tdata = *ldata; |
| 36 | } else { |
| 37 | source.Flatten(count); |
| 38 | result.SetVectorType(VectorType::FLAT_VECTOR); |
| 39 | FlatVector::SetValidity(vector&: result, new_validity&: FlatVector::Validity(vector&: source)); |
| 40 | |
| 41 | auto ldata = FlatVector::GetData<list_entry_t>(vector&: source); |
| 42 | auto tdata = FlatVector::GetData<list_entry_t>(vector&: result); |
| 43 | for (idx_t i = 0; i < count; i++) { |
| 44 | tdata[i] = ldata[i]; |
| 45 | } |
| 46 | } |
| 47 | auto &source_cc = ListVector::GetEntry(vector&: source); |
| 48 | auto source_size = ListVector::GetListSize(vector: source); |
| 49 | |
| 50 | ListVector::Reserve(vec&: result, required_capacity: source_size); |
| 51 | auto &append_vector = ListVector::GetEntry(vector&: result); |
| 52 | |
| 53 | CastParameters child_parameters(parameters, cast_data.child_cast_info.cast_data, parameters.local_state); |
| 54 | bool all_succeeded = cast_data.child_cast_info.function(source_cc, append_vector, source_size, child_parameters); |
| 55 | ListVector::SetListSize(vec&: result, size: source_size); |
| 56 | D_ASSERT(ListVector::GetListSize(result) == source_size); |
| 57 | return all_succeeded; |
| 58 | } |
| 59 | |
| 60 | static bool ListToVarcharCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { |
| 61 | auto constant = source.GetVectorType() == VectorType::CONSTANT_VECTOR; |
| 62 | // first cast the child vector to varchar |
| 63 | Vector varchar_list(LogicalType::LIST(child: LogicalType::VARCHAR), count); |
| 64 | ListCast::ListToListCast(source, result&: varchar_list, count, parameters); |
| 65 | |
| 66 | // now construct the actual varchar vector |
| 67 | varchar_list.Flatten(count); |
| 68 | auto &child = ListVector::GetEntry(vector&: varchar_list); |
| 69 | auto list_data = FlatVector::GetData<list_entry_t>(vector&: varchar_list); |
| 70 | auto &validity = FlatVector::Validity(vector&: varchar_list); |
| 71 | |
| 72 | child.Flatten(count); |
| 73 | auto child_data = FlatVector::GetData<string_t>(vector&: child); |
| 74 | auto &child_validity = FlatVector::Validity(vector&: child); |
| 75 | |
| 76 | auto result_data = FlatVector::GetData<string_t>(vector&: result); |
| 77 | static constexpr const idx_t SEP_LENGTH = 2; |
| 78 | static constexpr const idx_t NULL_LENGTH = 4; |
| 79 | for (idx_t i = 0; i < count; i++) { |
| 80 | if (!validity.RowIsValid(row_idx: i)) { |
| 81 | FlatVector::SetNull(vector&: result, idx: i, is_null: true); |
| 82 | continue; |
| 83 | } |
| 84 | auto list = list_data[i]; |
| 85 | // figure out how long the result needs to be |
| 86 | idx_t list_length = 2; // "[" and "]" |
| 87 | for (idx_t list_idx = 0; list_idx < list.length; list_idx++) { |
| 88 | auto idx = list.offset + list_idx; |
| 89 | if (list_idx > 0) { |
| 90 | list_length += SEP_LENGTH; // ", " |
| 91 | } |
| 92 | // string length, or "NULL" |
| 93 | list_length += child_validity.RowIsValid(row_idx: idx) ? child_data[idx].GetSize() : NULL_LENGTH; |
| 94 | } |
| 95 | result_data[i] = StringVector::EmptyString(vector&: result, len: list_length); |
| 96 | auto dataptr = result_data[i].GetDataWriteable(); |
| 97 | auto offset = 0; |
| 98 | dataptr[offset++] = '['; |
| 99 | for (idx_t list_idx = 0; list_idx < list.length; list_idx++) { |
| 100 | auto idx = list.offset + list_idx; |
| 101 | if (list_idx > 0) { |
| 102 | memcpy(dest: dataptr + offset, src: ", " , n: SEP_LENGTH); |
| 103 | offset += SEP_LENGTH; |
| 104 | } |
| 105 | if (child_validity.RowIsValid(row_idx: idx)) { |
| 106 | auto len = child_data[idx].GetSize(); |
| 107 | memcpy(dest: dataptr + offset, src: child_data[idx].GetData(), n: len); |
| 108 | offset += len; |
| 109 | } else { |
| 110 | memcpy(dest: dataptr + offset, src: "NULL" , n: NULL_LENGTH); |
| 111 | offset += NULL_LENGTH; |
| 112 | } |
| 113 | } |
| 114 | dataptr[offset] = ']'; |
| 115 | result_data[i].Finalize(); |
| 116 | } |
| 117 | |
| 118 | if (constant) { |
| 119 | result.SetVectorType(VectorType::CONSTANT_VECTOR); |
| 120 | } |
| 121 | return true; |
| 122 | } |
| 123 | |
| 124 | BoundCastInfo DefaultCasts::ListCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target) { |
| 125 | switch (target.id()) { |
| 126 | case LogicalTypeId::LIST: |
| 127 | return BoundCastInfo(ListCast::ListToListCast, ListBoundCastData::BindListToListCast(input, source, target), |
| 128 | ListBoundCastData::InitListLocalState); |
| 129 | case LogicalTypeId::VARCHAR: |
| 130 | return BoundCastInfo( |
| 131 | ListToVarcharCast, |
| 132 | ListBoundCastData::BindListToListCast(input, source, target: LogicalType::LIST(child: LogicalType::VARCHAR)), |
| 133 | ListBoundCastData::InitListLocalState); |
| 134 | default: |
| 135 | return DefaultCasts::TryVectorNullCast; |
| 136 | } |
| 137 | } |
| 138 | |
| 139 | } // namespace duckdb |
| 140 | |