1#include <Interpreters/join_common.h>
2#include <Columns/ColumnNullable.h>
3#include <DataTypes/DataTypeNullable.h>
4#include <DataTypes/DataTypeLowCardinality.h>
5#include <DataStreams/materializeBlock.h>
6
7namespace DB
8{
9
/// Error codes thrown from this file. They are declared `extern` here; the definitions live in another translation unit.
namespace ErrorCodes
{
    extern const int TYPE_MISMATCH;
}
14
15
16namespace JoinCommon
17{
18
19void convertColumnToNullable(ColumnWithTypeAndName & column)
20{
21 if (column.type->isNullable() || !column.type->canBeInsideNullable())
22 return;
23
24 column.type = makeNullable(column.type);
25 if (column.column)
26 column.column = makeNullable(column.column);
27}
28
29void convertColumnsToNullable(Block & block, size_t starting_pos)
30{
31 for (size_t i = starting_pos; i < block.columns(); ++i)
32 convertColumnToNullable(block.getByPosition(i));
33}
34
35/// @warning It assumes that every NULL has default value in nested column (or it does not matter)
36void removeColumnNullability(ColumnWithTypeAndName & column)
37{
38 if (!column.type->isNullable())
39 return;
40
41 column.type = static_cast<const DataTypeNullable &>(*column.type).getNestedType();
42 if (column.column)
43 {
44 auto * nullable_column = checkAndGetColumn<ColumnNullable>(*column.column);
45 ColumnPtr nested_column = nullable_column->getNestedColumnPtr();
46 MutableColumnPtr mutable_column = (*std::move(nested_column)).mutate();
47 column.column = std::move(mutable_column);
48 }
49}
50
51ColumnRawPtrs materializeColumnsInplace(Block & block, const Names & names)
52{
53 ColumnRawPtrs ptrs;
54 ptrs.reserve(names.size());
55
56 for (auto & column_name : names)
57 {
58 auto & column = block.getByName(column_name).column;
59 column = recursiveRemoveLowCardinality(column->convertToFullColumnIfConst());
60 ptrs.push_back(column.get());
61 }
62
63 return ptrs;
64}
65
66Columns materializeColumns(const Block & block, const Names & names)
67{
68 Columns materialized;
69 materialized.reserve(names.size());
70
71 for (auto & column_name : names)
72 {
73 const auto & src_column = block.getByName(column_name).column;
74 materialized.emplace_back(recursiveRemoveLowCardinality(src_column->convertToFullColumnIfConst()));
75 }
76
77 return materialized;
78}
79
80ColumnRawPtrs getRawPointers(const Columns & columns)
81{
82 ColumnRawPtrs ptrs;
83 ptrs.reserve(columns.size());
84
85 for (auto & column : columns)
86 ptrs.push_back(column.get());
87
88 return ptrs;
89}
90
91void removeLowCardinalityInplace(Block & block)
92{
93 for (size_t i = 0; i < block.columns(); ++i)
94 {
95 auto & col = block.getByPosition(i);
96 col.column = recursiveRemoveLowCardinality(col.column);
97 col.type = recursiveRemoveLowCardinality(col.type);
98 }
99}
100
101ColumnRawPtrs extractKeysForJoin(const Names & key_names_right, const Block & right_sample_block,
102 Block & sample_block_with_keys, Block & sample_block_with_columns_to_add)
103{
104 size_t keys_size = key_names_right.size();
105 ColumnRawPtrs key_columns(keys_size);
106
107 sample_block_with_columns_to_add = materializeBlock(right_sample_block);
108
109 for (size_t i = 0; i < keys_size; ++i)
110 {
111 const String & column_name = key_names_right[i];
112
113 /// there could be the same key names
114 if (sample_block_with_keys.has(column_name))
115 {
116 key_columns[i] = sample_block_with_keys.getByName(column_name).column.get();
117 continue;
118 }
119
120 auto & col = sample_block_with_columns_to_add.getByName(column_name);
121 col.column = recursiveRemoveLowCardinality(col.column);
122 col.type = recursiveRemoveLowCardinality(col.type);
123
124 /// Extract right keys with correct keys order.
125 sample_block_with_keys.insert(col);
126 sample_block_with_columns_to_add.erase(column_name);
127
128 key_columns[i] = sample_block_with_keys.getColumns().back().get();
129
130 /// We will join only keys, where all components are not NULL.
131 if (auto * nullable = checkAndGetColumn<ColumnNullable>(*key_columns[i]))
132 key_columns[i] = &nullable->getNestedColumn();
133 }
134
135 return key_columns;
136}
137
138void checkTypesOfKeys(const Block & block_left, const Names & key_names_left, const Block & block_right, const Names & key_names_right)
139{
140 size_t keys_size = key_names_left.size();
141
142 for (size_t i = 0; i < keys_size; ++i)
143 {
144 DataTypePtr left_type = removeNullable(recursiveRemoveLowCardinality(block_left.getByName(key_names_left[i]).type));
145 DataTypePtr right_type = removeNullable(recursiveRemoveLowCardinality(block_right.getByName(key_names_right[i]).type));
146
147 if (!left_type->equals(*right_type))
148 throw Exception("Type mismatch of columns to JOIN by: "
149 + key_names_left[i] + " " + left_type->getName() + " at left, "
150 + key_names_right[i] + " " + right_type->getName() + " at right",
151 ErrorCodes::TYPE_MISMATCH);
152 }
153}
154
155void createMissedColumns(Block & block)
156{
157 for (size_t i = 0; i < block.columns(); ++i)
158 {
159 auto & column = block.getByPosition(i);
160 if (!column.column)
161 column.column = column.type->createColumn();
162 }
163}
164
165void joinTotals(const Block & totals, const Block & columns_to_add, const Names & key_names_right, Block & block)
166{
167 if (Block totals_without_keys = totals)
168 {
169 for (const auto & name : key_names_right)
170 totals_without_keys.erase(totals_without_keys.getPositionByName(name));
171
172 for (size_t i = 0; i < totals_without_keys.columns(); ++i)
173 block.insert(totals_without_keys.safeGetByPosition(i));
174 }
175 else
176 {
177 /// We will join empty `totals` - from one row with the default values.
178
179 for (size_t i = 0; i < columns_to_add.columns(); ++i)
180 {
181 const auto & col = columns_to_add.getByPosition(i);
182 block.insert({
183 col.type->createColumnConstWithDefaultValue(1)->convertToFullColumnIfConst(),
184 col.type,
185 col.name});
186 }
187 }
188}
189
190}
191}
192