1#include <Functions/IFunctionImpl.h>
2#include <Functions/FunctionFactory.h>
3#include <Functions/FunctionHelpers.h>
4#include <DataTypes/DataTypeArray.h>
5#include <DataTypes/DataTypeNothing.h>
6#include <DataTypes/DataTypesNumber.h>
7#include <DataTypes/DataTypesDecimal.h>
8#include <DataTypes/DataTypeDate.h>
9#include <DataTypes/DataTypeDateTime.h>
10#include <DataTypes/DataTypeDateTime64.h>
11#include <DataTypes/DataTypeNullable.h>
12#include <DataTypes/DataTypeTuple.h>
13#include <DataTypes/getMostSubtype.h>
14#include <Columns/ColumnArray.h>
15#include <Columns/ColumnString.h>
16#include <Columns/ColumnFixedString.h>
17#include <Columns/ColumnDecimal.h>
18#include <Columns/ColumnNullable.h>
19#include <Columns/ColumnTuple.h>
20#include <Common/HashTable/ClearableHashMap.h>
21#include <Common/assert_cast.h>
22#include <Core/TypeListNumber.h>
23#include <Interpreters/castColumn.h>
24#include <ext/range.h>
25
26
27namespace DB
28{
29
/// Error codes thrown by the functions in this file.
/// They are only declared here (`extern`); the definitions live elsewhere.
namespace ErrorCodes
{
    extern const int LOGICAL_ERROR;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
36
37class FunctionArrayIntersect : public IFunction
38{
39public:
40 static constexpr auto name = "arrayIntersect";
41 static FunctionPtr create(const Context & context) { return std::make_shared<FunctionArrayIntersect>(context); }
42 FunctionArrayIntersect(const Context & context_) : context(context_) {}
43
44 String getName() const override { return name; }
45
46 bool isVariadic() const override { return true; }
47 size_t getNumberOfArguments() const override { return 0; }
48
49 DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override;
50
51 void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override;
52
53 bool useDefaultImplementationForConstants() const override { return true; }
54
55private:
56 const Context & context;
57
58 /// Initially allocate a piece of memory for 512 elements. NOTE: This is just a guess.
59 static constexpr size_t INITIAL_SIZE_DEGREE = 9;
60
61 struct UnpackedArrays
62 {
63 size_t base_rows = 0;
64
65 struct UnpackedArray
66 {
67 bool is_const = false;
68 const NullMap * null_map = nullptr;
69 const NullMap * overflow_mask = nullptr;
70 const ColumnArray::ColumnOffsets::Container * offsets = nullptr;
71 const IColumn * nested_column = nullptr;
72
73 };
74
75 std::vector<UnpackedArray> args;
76 Columns column_holders;
77
78 UnpackedArrays() = default;
79 };
80
81 /// Cast column to data_type removing nullable if data_type hasn't.
82 /// It's expected that column can represent data_type after removing some NullMap's.
83 ColumnPtr castRemoveNullable(const ColumnPtr & column, const DataTypePtr & data_type) const;
84
85 struct CastArgumentsResult
86 {
87 ColumnsWithTypeAndName initial;
88 ColumnsWithTypeAndName casted;
89 };
90
91 CastArgumentsResult castColumns(Block & block, const ColumnNumbers & arguments,
92 const DataTypePtr & return_type, const DataTypePtr & return_type_with_nulls) const;
93 UnpackedArrays prepareArrays(const ColumnsWithTypeAndName & columns, ColumnsWithTypeAndName & initial_columns) const;
94
95 template <typename Map, typename ColumnType, bool is_numeric_column>
96 static ColumnPtr execute(const UnpackedArrays & arrays, MutableColumnPtr result_data);
97
98 struct NumberExecutor
99 {
100 const UnpackedArrays & arrays;
101 const DataTypePtr & data_type;
102 ColumnPtr & result;
103
104 NumberExecutor(const UnpackedArrays & arrays_, const DataTypePtr & data_type_, ColumnPtr & result_)
105 : arrays(arrays_), data_type(data_type_), result(result_) {}
106
107 template <typename T, size_t>
108 void operator()();
109 };
110
111 struct DecimalExecutor
112 {
113 const UnpackedArrays & arrays;
114 const DataTypePtr & data_type;
115 ColumnPtr & result;
116
117 DecimalExecutor(const UnpackedArrays & arrays_, const DataTypePtr & data_type_, ColumnPtr & result_)
118 : arrays(arrays_), data_type(data_type_), result(result_) {}
119
120 template <typename T, size_t>
121 void operator()();
122 };
123};
124
125
126DataTypePtr FunctionArrayIntersect::getReturnTypeImpl(const DataTypes & arguments) const
127{
128 DataTypes nested_types;
129 nested_types.reserve(arguments.size());
130
131 bool has_nothing = false;
132
133 if (arguments.empty())
134 throw Exception{"Function " + getName() + " requires at least one argument.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH};
135
136 for (auto i : ext::range(0, arguments.size()))
137 {
138 auto array_type = typeid_cast<const DataTypeArray *>(arguments[i].get());
139 if (!array_type)
140 throw Exception("Argument " + std::to_string(i) + " for function " + getName() + " must be an array but it has type "
141 + arguments[i]->getName() + ".", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
142
143 const auto & nested_type = array_type->getNestedType();
144
145 if (typeid_cast<const DataTypeNothing *>(nested_type.get()))
146 has_nothing = true;
147 else
148 nested_types.push_back(nested_type);
149 }
150
151 DataTypePtr result_type;
152
153 if (!nested_types.empty())
154 result_type = getMostSubtype(nested_types, true);
155
156 if (has_nothing)
157 result_type = std::make_shared<DataTypeNothing>();
158
159 return std::make_shared<DataTypeArray>(result_type);
160}
161
162ColumnPtr FunctionArrayIntersect::castRemoveNullable(const ColumnPtr & column, const DataTypePtr & data_type) const
163{
164 if (auto column_nullable = checkAndGetColumn<ColumnNullable>(column.get()))
165 {
166 auto nullable_type = checkAndGetDataType<DataTypeNullable>(data_type.get());
167 const auto & nested = column_nullable->getNestedColumnPtr();
168 if (nullable_type)
169 {
170 auto casted_column = castRemoveNullable(nested, nullable_type->getNestedType());
171 return ColumnNullable::create(casted_column, column_nullable->getNullMapColumnPtr());
172 }
173 return castRemoveNullable(nested, data_type);
174 }
175 else if (auto column_array = checkAndGetColumn<ColumnArray>(column.get()))
176 {
177 auto array_type = checkAndGetDataType<DataTypeArray>(data_type.get());
178 if (!array_type)
179 throw Exception{"Cannot cast array column to column with type "
180 + data_type->getName() + " in function " + getName(), ErrorCodes::LOGICAL_ERROR};
181
182 auto casted_column = castRemoveNullable(column_array->getDataPtr(), array_type->getNestedType());
183 return ColumnArray::create(casted_column, column_array->getOffsetsPtr());
184 }
185 else if (auto column_tuple = checkAndGetColumn<ColumnTuple>(column.get()))
186 {
187 auto tuple_type = checkAndGetDataType<DataTypeTuple>(data_type.get());
188
189 if (!tuple_type)
190 throw Exception{"Cannot cast tuple column to type "
191 + data_type->getName() + " in function " + getName(), ErrorCodes::LOGICAL_ERROR};
192
193 auto columns_number = column_tuple->tupleSize();
194 Columns columns(columns_number);
195
196 const auto & types = tuple_type->getElements();
197
198 for (auto i : ext::range(0, columns_number))
199 {
200 columns[i] = castRemoveNullable(column_tuple->getColumnPtr(i), types[i]);
201 }
202 return ColumnTuple::create(columns);
203 }
204
205 return column;
206}
207
208FunctionArrayIntersect::CastArgumentsResult FunctionArrayIntersect::castColumns(
209 Block & block, const ColumnNumbers & arguments, const DataTypePtr & return_type,
210 const DataTypePtr & return_type_with_nulls) const
211{
212 size_t num_args = arguments.size();
213 ColumnsWithTypeAndName initial_columns(num_args);
214 ColumnsWithTypeAndName columns(num_args);
215
216 auto type_array = checkAndGetDataType<DataTypeArray>(return_type.get());
217 auto & type_nested = type_array->getNestedType();
218 auto type_not_nullable_nested = removeNullable(type_nested);
219
220 const bool is_numeric_or_string = isNativeNumber(type_not_nullable_nested)
221 || isDateOrDateTime(type_not_nullable_nested)
222 || isStringOrFixedString(type_not_nullable_nested);
223
224 DataTypePtr nullable_return_type;
225
226 if (is_numeric_or_string)
227 {
228 auto type_nullable_nested = makeNullable(type_nested);
229 nullable_return_type = std::make_shared<DataTypeArray>(type_nullable_nested);
230 }
231
232 const bool nested_is_nullable = type_nested->isNullable();
233
234 for (size_t i = 0; i < num_args; ++i)
235 {
236 const ColumnWithTypeAndName & arg = block.getByPosition(arguments[i]);
237 initial_columns[i] = arg;
238 columns[i] = arg;
239 auto & column = columns[i];
240
241 if (is_numeric_or_string)
242 {
243 /// Cast to Array(T) or Array(Nullable(T)).
244 if (nested_is_nullable)
245 {
246 if (!arg.type->equals(*return_type))
247 {
248 column.column = castColumn(arg, return_type, context);
249 column.type = return_type;
250 }
251 }
252 else
253 {
254
255 if (!arg.type->equals(*return_type) && !arg.type->equals(*nullable_return_type))
256 {
257 /// If result has array type Array(T) still cast Array(Nullable(U)) to Array(Nullable(T))
258 /// because cannot cast Nullable(T) to T.
259 if (static_cast<const DataTypeArray &>(*arg.type).getNestedType()->isNullable())
260 {
261 column.column = castColumn(arg, nullable_return_type, context);
262 column.type = nullable_return_type;
263 }
264 else
265 {
266 column.column = castColumn(arg, return_type, context);
267 column.type = return_type;
268 }
269 }
270 }
271 }
272 else
273 {
274 /// return_type_with_nulls is the most common subtype with possible nullable parts.
275 if (!arg.type->equals(*return_type_with_nulls))
276 {
277 column.column = castColumn(arg, return_type_with_nulls, context);
278 column.type = return_type_with_nulls;
279 }
280 }
281 }
282
283 return {.initial = initial_columns, .casted = columns};
284}
285
286static ColumnPtr callFunctionNotEquals(ColumnWithTypeAndName first, ColumnWithTypeAndName second, const Context & context)
287{
288 ColumnsWithTypeAndName args;
289 args.reserve(2);
290 args.emplace_back(std::move(first));
291 args.emplace_back(std::move(second));
292
293 auto eq_func = FunctionFactory::instance().get("notEquals", context)->build(args);
294
295 Block block = args;
296 block.insert({nullptr, eq_func->getReturnType(), ""});
297
298 eq_func->execute(block, {0, 1}, 2, args.front().column->size());
299
300 return block.getByPosition(2).column;
301}
302
303FunctionArrayIntersect::UnpackedArrays FunctionArrayIntersect::prepareArrays(
304 const ColumnsWithTypeAndName & columns, ColumnsWithTypeAndName & initial_columns) const
305{
306 UnpackedArrays arrays;
307
308 size_t columns_number = columns.size();
309 arrays.args.resize(columns_number);
310
311 bool all_const = true;
312
313 for (auto i : ext::range(0, columns_number))
314 {
315 auto & arg = arrays.args[i];
316 auto argument_column = columns[i].column.get();
317 auto initial_column = initial_columns[i].column.get();
318
319 if (auto argument_column_const = typeid_cast<const ColumnConst *>(argument_column))
320 {
321 arg.is_const = true;
322 argument_column = argument_column_const->getDataColumnPtr().get();
323 initial_column = typeid_cast<const ColumnConst *>(initial_column)->getDataColumnPtr().get();
324 }
325
326 if (auto argument_column_array = typeid_cast<const ColumnArray *>(argument_column))
327 {
328 if (!arg.is_const)
329 all_const = false;
330
331 arg.offsets = &argument_column_array->getOffsets();
332 arg.nested_column = &argument_column_array->getData();
333
334 initial_column = &typeid_cast<const ColumnArray *>(initial_column)->getData();
335
336 if (auto column_nullable = typeid_cast<const ColumnNullable *>(arg.nested_column))
337 {
338 arg.null_map = &column_nullable->getNullMapData();
339 arg.nested_column = &column_nullable->getNestedColumn();
340 initial_column = &typeid_cast<const ColumnNullable *>(initial_column)->getNestedColumn();
341 }
342
343 /// In case column was casted need to create overflow mask for integer types.
344 if (arg.nested_column != initial_column)
345 {
346 auto & nested_init_type = typeid_cast<const DataTypeArray *>(removeNullable(initial_columns[i].type).get())->getNestedType();
347 auto & nested_cast_type = typeid_cast<const DataTypeArray *>(removeNullable(columns[i].type).get())->getNestedType();
348
349 if (isInteger(nested_init_type) || isDateOrDateTime(nested_init_type))
350 {
351 /// Compare original and casted columns. It seem to be the easiest way.
352 auto overflow_mask = callFunctionNotEquals(
353 {arg.nested_column->getPtr(), nested_init_type, ""},
354 {initial_column->getPtr(), nested_cast_type, ""},
355 context);
356
357 arg.overflow_mask = &typeid_cast<const ColumnUInt8 *>(overflow_mask.get())->getData();
358 arrays.column_holders.emplace_back(std::move(overflow_mask));
359 }
360 }
361 }
362 else
363 throw Exception{"Arguments for function " + getName() + " must be arrays.", ErrorCodes::LOGICAL_ERROR};
364 }
365
366 if (all_const)
367 {
368 arrays.base_rows = arrays.args.front().offsets->size();
369 }
370 else
371 {
372 for (auto i : ext::range(0, columns_number))
373 {
374 if (arrays.args[i].is_const)
375 continue;
376
377 size_t rows = arrays.args[i].offsets->size();
378 if (arrays.base_rows == 0 && rows > 0)
379 arrays.base_rows = rows;
380 else if (arrays.base_rows != rows)
381 throw Exception("Non-const array columns in function " + getName() + "should have same rows", ErrorCodes::LOGICAL_ERROR);
382 }
383 }
384
385 return arrays;
386}
387
388void FunctionArrayIntersect::executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count)
389{
390 const auto & return_type = block.getByPosition(result).type;
391 auto return_type_array = checkAndGetDataType<DataTypeArray>(return_type.get());
392
393 if (!return_type_array)
394 throw Exception{"Return type for function " + getName() + " must be array.", ErrorCodes::LOGICAL_ERROR};
395
396 const auto & nested_return_type = return_type_array->getNestedType();
397
398 if (typeid_cast<const DataTypeNothing *>(nested_return_type.get()))
399 {
400 block.getByPosition(result).column = return_type->createColumnConstWithDefaultValue(input_rows_count);
401 return;
402 }
403
404 auto num_args = arguments.size();
405 DataTypes data_types;
406 data_types.reserve(num_args);
407 for (size_t i = 0; i < num_args; ++i)
408 data_types.push_back(block.getByPosition(arguments[i]).type);
409
410 auto return_type_with_nulls = getMostSubtype(data_types, true, true);
411
412 auto columns = castColumns(block, arguments, return_type, return_type_with_nulls);
413
414 UnpackedArrays arrays = prepareArrays(columns.casted, columns.initial);
415
416 ColumnPtr result_column;
417 auto not_nullable_nested_return_type = removeNullable(nested_return_type);
418 TypeListNativeNumbers::forEach(NumberExecutor(arrays, not_nullable_nested_return_type, result_column));
419 TypeListDecimalNumbers::forEach(DecimalExecutor(arrays, not_nullable_nested_return_type, result_column));
420
421 using DateMap = ClearableHashMap<DataTypeDate::FieldType, size_t, DefaultHash<DataTypeDate::FieldType>,
422 HashTableGrower<INITIAL_SIZE_DEGREE>,
423 HashTableAllocatorWithStackMemory<(1ULL << INITIAL_SIZE_DEGREE) * sizeof(DataTypeDate::FieldType)>>;
424
425 using DateTimeMap = ClearableHashMap<DataTypeDateTime::FieldType, size_t, DefaultHash<DataTypeDateTime::FieldType>,
426 HashTableGrower<INITIAL_SIZE_DEGREE>,
427 HashTableAllocatorWithStackMemory<(1ULL << INITIAL_SIZE_DEGREE) * sizeof(DataTypeDateTime::FieldType)>>;
428
429 using StringMap = ClearableHashMap<StringRef, size_t, StringRefHash, HashTableGrower<INITIAL_SIZE_DEGREE>,
430 HashTableAllocatorWithStackMemory<(1ULL << INITIAL_SIZE_DEGREE) * sizeof(StringRef)>>;
431
432 if (!result_column)
433 {
434 auto column = not_nullable_nested_return_type->createColumn();
435 WhichDataType which(not_nullable_nested_return_type);
436
437 if (which.isDate())
438 result_column = execute<DateMap, ColumnVector<DataTypeDate::FieldType>, true>(arrays, std::move(column));
439 else if (which.isDateTime())
440 result_column = execute<DateTimeMap, ColumnVector<DataTypeDateTime::FieldType>, true>(arrays, std::move(column));
441 else if (which.isString())
442 result_column = execute<StringMap, ColumnString, false>(arrays, std::move(column));
443 else if (which.isFixedString())
444 result_column = execute<StringMap, ColumnFixedString, false>(arrays, std::move(column));
445 else
446 {
447 column = assert_cast<const DataTypeArray &>(*return_type_with_nulls).getNestedType()->createColumn();
448 result_column = castRemoveNullable(execute<StringMap, IColumn, false>(arrays, std::move(column)), return_type);
449 }
450 }
451
452 block.getByPosition(result).column = std::move(result_column);
453}
454
455template <typename T, size_t>
456void FunctionArrayIntersect::NumberExecutor::operator()()
457{
458 using Map = ClearableHashMap<T, size_t, DefaultHash<T>, HashTableGrower<INITIAL_SIZE_DEGREE>,
459 HashTableAllocatorWithStackMemory<(1ULL << INITIAL_SIZE_DEGREE) * sizeof(T)>>;
460
461 if (!result && typeid_cast<const DataTypeNumber<T> *>(data_type.get()))
462 result = execute<Map, ColumnVector<T>, true>(arrays, ColumnVector<T>::create());
463}
464
465template <typename T, size_t>
466void FunctionArrayIntersect::DecimalExecutor::operator()()
467{
468 using Map = ClearableHashMap<T, size_t, DefaultHash<T>, HashTableGrower<INITIAL_SIZE_DEGREE>,
469 HashTableAllocatorWithStackMemory<(1ULL << INITIAL_SIZE_DEGREE) * sizeof(T)>>;
470
471 if (!result)
472 if (auto * decimal = typeid_cast<const DataTypeDecimal<T> *>(data_type.get()))
473 result = execute<Map, ColumnDecimal<T>, true>(arrays, ColumnDecimal<T>::create(0, decimal->getScale()));
474}
475
/// Calculates the intersection of the unpacked arrays row by row.
///
/// Template parameters:
///  Map               - hash map from an element to the number of argument arrays it was found in;
///  ColumnType        - concrete type of the nested columns, or IColumn for the generic implementation
///                      (elements are then compared by their serialized representation);
///  is_numeric_column - keys are obtained with getElement and inserted with insertValue.
///
/// `result_data_ptr` is an empty column that receives the elements of the resulting arrays.
/// Returns ColumnArray; its nested column is Nullable only if all arguments have a null map.
/// Throws LOGICAL_ERROR if a nested column is not of type ColumnType.
template <typename Map, typename ColumnType, bool is_numeric_column>
ColumnPtr FunctionArrayIntersect::execute(const UnpackedArrays & arrays, MutableColumnPtr result_data_ptr)
{
    auto args = arrays.args.size();
    auto rows = arrays.base_rows;

    /// True if every argument has nullable elements. Only then the result can contain NULL.
    bool all_nullable = true;

    /// Nested columns of the arguments, downcasted to ColumnType.
    std::vector<const ColumnType *> columns;
    columns.reserve(args);
    for (auto & arg : arrays.args)
    {
        if constexpr (std::is_same<ColumnType, IColumn>::value)
            columns.push_back(arg.nested_column);
        else
            columns.push_back(checkAndGetColumn<ColumnType>(arg.nested_column));

        if (!columns.back())
            throw Exception("Unexpected array type for function arrayIntersect", ErrorCodes::LOGICAL_ERROR);

        if (!arg.null_map)
            all_nullable = false;
    }

    auto & result_data = static_cast<ColumnType &>(*result_data_ptr);
    auto result_offsets_ptr = ColumnArray::ColumnOffsets::create(rows);
    auto & result_offsets = assert_cast<ColumnArray::ColumnOffsets &>(*result_offsets_ptr);
    /// Null map of the result. It is filled (and used) only if all_nullable is true.
    auto null_map_column = ColumnUInt8::create();
    NullMap & null_map = assert_cast<ColumnUInt8 &>(*null_map_column).getData();

    /// Holds the serialized elements for the generic implementation.
    Arena arena;

    Map map;
    /// For every argument: the offset where the array of the current row begins.
    std::vector<size_t> prev_off(args, 0);
    size_t result_offset = 0;
    for (auto row : ext::range(0, rows))
    {
        map.clear();

        /// Becomes false as soon as some argument has no NULL in the current row.
        bool all_has_nullable = all_nullable;

        for (auto arg_num : ext::range(0, args))
        {
            auto & arg = arrays.args[arg_num];
            bool current_has_nullable = false;

            size_t off;
            // const array has only one row
            if (arg.is_const)
                off = (*arg.offsets)[0];
            else
                off = (*arg.offsets)[row];

            for (auto i : ext::range(prev_off[arg_num], off))
            {
                if (arg.null_map && (*arg.null_map)[i])
                    current_has_nullable = true;
                else if (!arg.overflow_mask || (*arg.overflow_mask)[i] == 0)
                {
                    /// Elements marked in the overflow mask are skipped entirely.
                    typename Map::mapped_type * value = nullptr;

                    if constexpr (is_numeric_column)
                        value = &map[columns[arg_num]->getElement(i)];
                    else if constexpr (std::is_same<ColumnType, ColumnString>::value || std::is_same<ColumnType, ColumnFixedString>::value)
                        value = &map[columns[arg_num]->getDataAt(i)];
                    else
                    {
                        const char * data = nullptr;
                        value = &map[columns[arg_num]->serializeValueIntoArena(i, arena, data)];
                    }

                    /// Here we count the number of element appearances, but no more than once per array.
                    /// The counter equals arg_num only if the element was found in every previous array
                    /// and has not been met in the current one yet.
                    if (*value == arg_num)
                        ++(*value);
                }
            }

            /// A constant array is read from the beginning for every row.
            prev_off[arg_num] = off;
            if (arg.is_const)
                prev_off[arg_num] = 0;

            if (!current_has_nullable)
                all_has_nullable = false;
        }

        /// NULL belongs to the intersection if every array of the row contains NULL.
        if (all_has_nullable)
        {
            ++result_offset;
            result_data.insertDefault();
            null_map.push_back(1);
        }

        for (const auto & pair : map)
        {
            /// The counter reached the number of arguments: the element is present in all arrays.
            if (pair.getMapped() == args)
            {
                ++result_offset;
                if constexpr (is_numeric_column)
                    result_data.insertValue(pair.getKey());
                else if constexpr (std::is_same<ColumnType, ColumnString>::value || std::is_same<ColumnType, ColumnFixedString>::value)
                    result_data.insertData(pair.getKey().data, pair.getKey().size);
                else
                    result_data.deserializeAndInsertFromArena(pair.getKey().data);

                /// Keep the null map of the same size as the data.
                if (all_nullable)
                    null_map.push_back(0);
            }
        }
        result_offsets.getElement(row) = result_offset;
    }

    ColumnPtr result_column = std::move(result_data_ptr);
    if (all_nullable)
        result_column = ColumnNullable::create(result_column, std::move(null_map_column));
    return ColumnArray::create(result_column, std::move(result_offsets_ptr));
}
592
593
/// Registers FunctionArrayIntersect in the given function factory.
void registerFunctionArrayIntersect(FunctionFactory & factory)
{
    factory.registerFunction<FunctionArrayIntersect>();
}
598
599}
600