1#include <Columns/ColumnArray.h>
2#include <Columns/ColumnConst.h>
3#include <Columns/ColumnTuple.h>
4#include <Columns/ColumnLowCardinality.h>
5
6#include <DataTypes/DataTypeLowCardinality.h>
7#include <DataTypes/DataTypeArray.h>
8#include <DataTypes/DataTypeTuple.h>
9
10#include <Common/assert_cast.h>
11
12
13namespace DB
14{
15
namespace ErrorCodes
{
    /// Error codes thrown by the conversion helpers below. They are only declared
    /// here (extern); the definitions live in another translation unit.
    extern const int TYPE_MISMATCH;
    extern const int ILLEGAL_COLUMN;
}
21
22DataTypePtr recursiveRemoveLowCardinality(const DataTypePtr & type)
23{
24 if (!type)
25 return type;
26
27 if (const auto * array_type = typeid_cast<const DataTypeArray *>(type.get()))
28 return std::make_shared<DataTypeArray>(recursiveRemoveLowCardinality(array_type->getNestedType()));
29
30 if (const auto * tuple_type = typeid_cast<const DataTypeTuple *>(type.get()))
31 {
32 DataTypes elements = tuple_type->getElements();
33 for (auto & element : elements)
34 element = recursiveRemoveLowCardinality(element);
35
36 if (tuple_type->haveExplicitNames())
37 return std::make_shared<DataTypeTuple>(elements, tuple_type->getElementNames());
38 else
39 return std::make_shared<DataTypeTuple>(elements);
40 }
41
42 if (const auto * low_cardinality_type = typeid_cast<const DataTypeLowCardinality *>(type.get()))
43 return low_cardinality_type->getDictionaryType();
44
45 return type;
46}
47
48ColumnPtr recursiveRemoveLowCardinality(const ColumnPtr & column)
49{
50 if (!column)
51 return column;
52
53 if (const auto * column_array = typeid_cast<const ColumnArray *>(column.get()))
54 {
55 auto & data = column_array->getDataPtr();
56 auto data_no_lc = recursiveRemoveLowCardinality(data);
57 if (data.get() == data_no_lc.get())
58 return column;
59
60 return ColumnArray::create(data_no_lc, column_array->getOffsetsPtr());
61 }
62
63 if (const auto * column_const = typeid_cast<const ColumnConst *>(column.get()))
64 {
65 auto & nested = column_const->getDataColumnPtr();
66 auto nested_no_lc = recursiveRemoveLowCardinality(nested);
67 if (nested.get() == nested_no_lc.get())
68 return column;
69
70 return ColumnConst::create(nested_no_lc, column_const->size());
71 }
72
73 if (const auto * column_tuple = typeid_cast<const ColumnTuple *>(column.get()))
74 {
75 auto columns = column_tuple->getColumns();
76 for (auto & element : columns)
77 element = recursiveRemoveLowCardinality(element);
78 return ColumnTuple::create(columns);
79 }
80
81 if (const auto * column_low_cardinality = typeid_cast<const ColumnLowCardinality *>(column.get()))
82 return column_low_cardinality->convertToFullColumn();
83
84 return column;
85}
86
87ColumnPtr recursiveTypeConversion(const ColumnPtr & column, const DataTypePtr & from_type, const DataTypePtr & to_type)
88{
89 if (!column)
90 return column;
91
92 if (from_type->equals(*to_type))
93 return column;
94
95 /// We can allow insert enum column if it's numeric type is the same as the column's type in table.
96 if (WhichDataType(to_type).isEnum() && from_type->getTypeId() == to_type->getTypeId())
97 return column;
98
99 if (const auto * column_const = typeid_cast<const ColumnConst *>(column.get()))
100 {
101 auto & nested = column_const->getDataColumnPtr();
102 auto nested_no_lc = recursiveTypeConversion(nested, from_type, to_type);
103 if (nested.get() == nested_no_lc.get())
104 return column;
105
106 return ColumnConst::create(nested_no_lc, column_const->size());
107 }
108
109 if (const auto * low_cardinality_type = typeid_cast<const DataTypeLowCardinality *>(from_type.get()))
110 {
111 if (to_type->equals(*low_cardinality_type->getDictionaryType()))
112 return column->convertToFullColumnIfLowCardinality();
113 }
114
115 if (const auto * low_cardinality_type = typeid_cast<const DataTypeLowCardinality *>(to_type.get()))
116 {
117 if (from_type->equals(*low_cardinality_type->getDictionaryType()))
118 {
119 auto col = low_cardinality_type->createColumn();
120 assert_cast<ColumnLowCardinality &>(*col).insertRangeFromFullColumn(*column, 0, column->size());
121 return col;
122 }
123 }
124
125 if (const auto * from_array_type = typeid_cast<const DataTypeArray *>(from_type.get()))
126 {
127 if (const auto * to_array_type = typeid_cast<const DataTypeArray *>(to_type.get()))
128 {
129 const auto * column_array = typeid_cast<const ColumnArray *>(column.get());
130 if (!column_array)
131 throw Exception("Unexpected column " + column->getName() + " for type " + from_type->getName(),
132 ErrorCodes::ILLEGAL_COLUMN);
133
134 auto & nested_from = from_array_type->getNestedType();
135 auto & nested_to = to_array_type->getNestedType();
136
137 return ColumnArray::create(
138 recursiveTypeConversion(column_array->getDataPtr(), nested_from, nested_to),
139 column_array->getOffsetsPtr());
140 }
141 }
142
143 if (const auto * from_tuple_type = typeid_cast<const DataTypeTuple *>(from_type.get()))
144 {
145 if (const auto * to_tuple_type = typeid_cast<const DataTypeTuple *>(to_type.get()))
146 {
147 const auto * column_tuple = typeid_cast<const ColumnTuple *>(column.get());
148 if (!column_tuple)
149 throw Exception("Unexpected column " + column->getName() + " for type " + from_type->getName(),
150 ErrorCodes::ILLEGAL_COLUMN);
151
152 auto columns = column_tuple->getColumns();
153 auto & from_elements = from_tuple_type->getElements();
154 auto & to_elements = to_tuple_type->getElements();
155
156 bool has_converted = false;
157
158 for (size_t i = 0; i < columns.size(); ++i)
159 {
160 auto & element = columns[i];
161 auto element_no_lc = recursiveTypeConversion(element, from_elements.at(i), to_elements.at(i));
162 if (element.get() != element_no_lc.get())
163 {
164 element = element_no_lc;
165 has_converted = true;
166 }
167 }
168
169 if (!has_converted)
170 return column;
171
172 return ColumnTuple::create(columns);
173 }
174 }
175
176 throw Exception("Cannot convert: " + from_type->getName() + " to " + to_type->getName(), ErrorCodes::TYPE_MISMATCH);
177}
178
179}
180