1 | #include <Columns/ColumnArray.h> |
2 | #include <Columns/ColumnConst.h> |
3 | #include <Columns/ColumnTuple.h> |
4 | #include <Columns/ColumnLowCardinality.h> |
5 | |
6 | #include <DataTypes/DataTypeLowCardinality.h> |
7 | #include <DataTypes/DataTypeArray.h> |
8 | #include <DataTypes/DataTypeTuple.h> |
9 | |
10 | #include <Common/assert_cast.h> |
11 | |
12 | |
13 | namespace DB |
14 | { |
15 | |
namespace ErrorCodes
{
    /// Error codes thrown by the conversion helpers below (declared extern; defined elsewhere).
    extern const int TYPE_MISMATCH;
    extern const int ILLEGAL_COLUMN;
}
21 | |
22 | DataTypePtr recursiveRemoveLowCardinality(const DataTypePtr & type) |
23 | { |
24 | if (!type) |
25 | return type; |
26 | |
27 | if (const auto * array_type = typeid_cast<const DataTypeArray *>(type.get())) |
28 | return std::make_shared<DataTypeArray>(recursiveRemoveLowCardinality(array_type->getNestedType())); |
29 | |
30 | if (const auto * tuple_type = typeid_cast<const DataTypeTuple *>(type.get())) |
31 | { |
32 | DataTypes elements = tuple_type->getElements(); |
33 | for (auto & element : elements) |
34 | element = recursiveRemoveLowCardinality(element); |
35 | |
36 | if (tuple_type->haveExplicitNames()) |
37 | return std::make_shared<DataTypeTuple>(elements, tuple_type->getElementNames()); |
38 | else |
39 | return std::make_shared<DataTypeTuple>(elements); |
40 | } |
41 | |
42 | if (const auto * low_cardinality_type = typeid_cast<const DataTypeLowCardinality *>(type.get())) |
43 | return low_cardinality_type->getDictionaryType(); |
44 | |
45 | return type; |
46 | } |
47 | |
48 | ColumnPtr recursiveRemoveLowCardinality(const ColumnPtr & column) |
49 | { |
50 | if (!column) |
51 | return column; |
52 | |
53 | if (const auto * column_array = typeid_cast<const ColumnArray *>(column.get())) |
54 | { |
55 | auto & data = column_array->getDataPtr(); |
56 | auto data_no_lc = recursiveRemoveLowCardinality(data); |
57 | if (data.get() == data_no_lc.get()) |
58 | return column; |
59 | |
60 | return ColumnArray::create(data_no_lc, column_array->getOffsetsPtr()); |
61 | } |
62 | |
63 | if (const auto * column_const = typeid_cast<const ColumnConst *>(column.get())) |
64 | { |
65 | auto & nested = column_const->getDataColumnPtr(); |
66 | auto nested_no_lc = recursiveRemoveLowCardinality(nested); |
67 | if (nested.get() == nested_no_lc.get()) |
68 | return column; |
69 | |
70 | return ColumnConst::create(nested_no_lc, column_const->size()); |
71 | } |
72 | |
73 | if (const auto * column_tuple = typeid_cast<const ColumnTuple *>(column.get())) |
74 | { |
75 | auto columns = column_tuple->getColumns(); |
76 | for (auto & element : columns) |
77 | element = recursiveRemoveLowCardinality(element); |
78 | return ColumnTuple::create(columns); |
79 | } |
80 | |
81 | if (const auto * column_low_cardinality = typeid_cast<const ColumnLowCardinality *>(column.get())) |
82 | return column_low_cardinality->convertToFullColumn(); |
83 | |
84 | return column; |
85 | } |
86 | |
87 | ColumnPtr recursiveTypeConversion(const ColumnPtr & column, const DataTypePtr & from_type, const DataTypePtr & to_type) |
88 | { |
89 | if (!column) |
90 | return column; |
91 | |
92 | if (from_type->equals(*to_type)) |
93 | return column; |
94 | |
95 | /// We can allow insert enum column if it's numeric type is the same as the column's type in table. |
96 | if (WhichDataType(to_type).isEnum() && from_type->getTypeId() == to_type->getTypeId()) |
97 | return column; |
98 | |
99 | if (const auto * column_const = typeid_cast<const ColumnConst *>(column.get())) |
100 | { |
101 | auto & nested = column_const->getDataColumnPtr(); |
102 | auto nested_no_lc = recursiveTypeConversion(nested, from_type, to_type); |
103 | if (nested.get() == nested_no_lc.get()) |
104 | return column; |
105 | |
106 | return ColumnConst::create(nested_no_lc, column_const->size()); |
107 | } |
108 | |
109 | if (const auto * low_cardinality_type = typeid_cast<const DataTypeLowCardinality *>(from_type.get())) |
110 | { |
111 | if (to_type->equals(*low_cardinality_type->getDictionaryType())) |
112 | return column->convertToFullColumnIfLowCardinality(); |
113 | } |
114 | |
115 | if (const auto * low_cardinality_type = typeid_cast<const DataTypeLowCardinality *>(to_type.get())) |
116 | { |
117 | if (from_type->equals(*low_cardinality_type->getDictionaryType())) |
118 | { |
119 | auto col = low_cardinality_type->createColumn(); |
120 | assert_cast<ColumnLowCardinality &>(*col).insertRangeFromFullColumn(*column, 0, column->size()); |
121 | return col; |
122 | } |
123 | } |
124 | |
125 | if (const auto * from_array_type = typeid_cast<const DataTypeArray *>(from_type.get())) |
126 | { |
127 | if (const auto * to_array_type = typeid_cast<const DataTypeArray *>(to_type.get())) |
128 | { |
129 | const auto * column_array = typeid_cast<const ColumnArray *>(column.get()); |
130 | if (!column_array) |
131 | throw Exception("Unexpected column " + column->getName() + " for type " + from_type->getName(), |
132 | ErrorCodes::ILLEGAL_COLUMN); |
133 | |
134 | auto & nested_from = from_array_type->getNestedType(); |
135 | auto & nested_to = to_array_type->getNestedType(); |
136 | |
137 | return ColumnArray::create( |
138 | recursiveTypeConversion(column_array->getDataPtr(), nested_from, nested_to), |
139 | column_array->getOffsetsPtr()); |
140 | } |
141 | } |
142 | |
143 | if (const auto * from_tuple_type = typeid_cast<const DataTypeTuple *>(from_type.get())) |
144 | { |
145 | if (const auto * to_tuple_type = typeid_cast<const DataTypeTuple *>(to_type.get())) |
146 | { |
147 | const auto * column_tuple = typeid_cast<const ColumnTuple *>(column.get()); |
148 | if (!column_tuple) |
149 | throw Exception("Unexpected column " + column->getName() + " for type " + from_type->getName(), |
150 | ErrorCodes::ILLEGAL_COLUMN); |
151 | |
152 | auto columns = column_tuple->getColumns(); |
153 | auto & from_elements = from_tuple_type->getElements(); |
154 | auto & to_elements = to_tuple_type->getElements(); |
155 | |
156 | bool has_converted = false; |
157 | |
158 | for (size_t i = 0; i < columns.size(); ++i) |
159 | { |
160 | auto & element = columns[i]; |
161 | auto element_no_lc = recursiveTypeConversion(element, from_elements.at(i), to_elements.at(i)); |
162 | if (element.get() != element_no_lc.get()) |
163 | { |
164 | element = element_no_lc; |
165 | has_converted = true; |
166 | } |
167 | } |
168 | |
169 | if (!has_converted) |
170 | return column; |
171 | |
172 | return ColumnTuple::create(columns); |
173 | } |
174 | } |
175 | |
176 | throw Exception("Cannot convert: " + from_type->getName() + " to " + to_type->getName(), ErrorCodes::TYPE_MISMATCH); |
177 | } |
178 | |
179 | } |
180 | |