| 1 | #include <Functions/IFunctionAdaptors.h> |
| 2 | |
| 3 | #include <Common/config.h> |
| 4 | #include <Common/typeid_cast.h> |
| 5 | #include <Common/assert_cast.h> |
| 6 | #include <Common/LRUCache.h> |
| 7 | #include <Columns/ColumnConst.h> |
| 8 | #include <Columns/ColumnNullable.h> |
| 9 | #include <Columns/ColumnArray.h> |
| 10 | #include <Columns/ColumnTuple.h> |
| 11 | #include <Columns/ColumnLowCardinality.h> |
| 12 | #include <DataTypes/DataTypeNothing.h> |
| 13 | #include <DataTypes/DataTypeNullable.h> |
| 14 | #include <DataTypes/DataTypeTuple.h> |
| 15 | #include <DataTypes/Native.h> |
| 16 | #include <DataTypes/DataTypeLowCardinality.h> |
| 17 | #include <DataTypes/getLeastSupertype.h> |
| 18 | #include <Functions/FunctionHelpers.h> |
| 19 | #include <Interpreters/ExpressionActions.h> |
| 20 | #include <IO/WriteHelpers.h> |
| 21 | #include <ext/range.h> |
| 22 | #include <ext/collection_cast.h> |
| 23 | #include <cstdlib> |
| 24 | #include <memory> |
| 25 | #include <optional> |
| 26 | |
| 27 | #if USE_EMBEDDED_COMPILER |
| 28 | #pragma GCC diagnostic push |
| 29 | #pragma GCC diagnostic ignored "-Wunused-parameter" |
| 30 | #include <llvm/IR/IRBuilder.h> |
| 31 | #pragma GCC diagnostic pop |
| 32 | #endif |
| 33 | |
| 34 | |
| 35 | namespace DB |
| 36 | { |
| 37 | |
/// Error codes thrown from this file.
/// LOGICAL_ERROR is used by findLowCardinalityArgument and
/// replaceLowCardinalityColumnsByNestedAndGetDictionaryIndexes, so it must be declared here as well.
namespace ErrorCodes
{
    extern const int LOGICAL_ERROR;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
    extern const int ILLEGAL_COLUMN;
}
| 43 | |
| 44 | |
| 45 | /// Cache for functions result if it was executed on low cardinality column. |
| 46 | /// It's LRUCache which stores function result executed on dictionary and index mapping. |
| 47 | /// It's expected that cache_size is a number of reading streams (so, will store single cached value per thread). |
| 48 | class ExecutableFunctionLowCardinalityResultCache |
| 49 | { |
| 50 | public: |
| 51 | /// Will assume that dictionaries with same hash has the same keys. |
| 52 | /// Just in case, check that they have also the same size. |
| 53 | struct DictionaryKey |
| 54 | { |
| 55 | UInt128 hash; |
| 56 | UInt64 size; |
| 57 | |
| 58 | bool operator== (const DictionaryKey & other) const { return hash == other.hash && size == other.size; } |
| 59 | }; |
| 60 | |
| 61 | struct DictionaryKeyHash |
| 62 | { |
| 63 | size_t operator()(const DictionaryKey & key) const |
| 64 | { |
| 65 | SipHash hash; |
| 66 | hash.update(key.hash.low); |
| 67 | hash.update(key.hash.high); |
| 68 | hash.update(key.size); |
| 69 | return hash.get64(); |
| 70 | } |
| 71 | }; |
| 72 | |
| 73 | struct CachedValues |
| 74 | { |
| 75 | /// Store ptr to dictionary to be sure it won't be deleted. |
| 76 | ColumnPtr dictionary_holder; |
| 77 | ColumnUniquePtr function_result; |
| 78 | /// Remap positions. new_pos = index_mapping->index(old_pos); |
| 79 | ColumnPtr index_mapping; |
| 80 | }; |
| 81 | |
| 82 | using CachedValuesPtr = std::shared_ptr<CachedValues>; |
| 83 | |
| 84 | explicit ExecutableFunctionLowCardinalityResultCache(size_t cache_size) : cache(cache_size) {} |
| 85 | |
| 86 | CachedValuesPtr get(const DictionaryKey & key) { return cache.get(key); } |
| 87 | void set(const DictionaryKey & key, const CachedValuesPtr & mapped) { cache.set(key, mapped); } |
| 88 | CachedValuesPtr getOrSet(const DictionaryKey & key, const CachedValuesPtr & mapped) |
| 89 | { |
| 90 | return cache.getOrSet(key, [&]() { return mapped; }).first; |
| 91 | } |
| 92 | |
| 93 | private: |
| 94 | using Cache = LRUCache<DictionaryKey, CachedValues, DictionaryKeyHash>; |
| 95 | Cache cache; |
| 96 | }; |
| 97 | |
| 98 | |
| 99 | void ExecutableFunctionAdaptor::createLowCardinalityResultCache(size_t cache_size) |
| 100 | { |
| 101 | if (!low_cardinality_result_cache) |
| 102 | low_cardinality_result_cache = std::make_shared<ExecutableFunctionLowCardinalityResultCache>(cache_size); |
| 103 | } |
| 104 | |
| 105 | |
| 106 | ColumnPtr wrapInNullable(const ColumnPtr & src, const Block & block, const ColumnNumbers & args, size_t result, size_t input_rows_count) |
| 107 | { |
| 108 | ColumnPtr result_null_map_column; |
| 109 | |
| 110 | /// If result is already nullable. |
| 111 | ColumnPtr src_not_nullable = src; |
| 112 | |
| 113 | if (src->onlyNull()) |
| 114 | return src; |
| 115 | else if (auto * nullable = checkAndGetColumn<ColumnNullable>(*src)) |
| 116 | { |
| 117 | src_not_nullable = nullable->getNestedColumnPtr(); |
| 118 | result_null_map_column = nullable->getNullMapColumnPtr(); |
| 119 | } |
| 120 | |
| 121 | for (const auto & arg : args) |
| 122 | { |
| 123 | const ColumnWithTypeAndName & elem = block.getByPosition(arg); |
| 124 | if (!elem.type->isNullable()) |
| 125 | continue; |
| 126 | |
| 127 | /// Const Nullable that are NULL. |
| 128 | if (elem.column->onlyNull()) |
| 129 | { |
| 130 | auto result_type = block.getByPosition(result).type; |
| 131 | assert(result_type->isNullable()); |
| 132 | return result_type->createColumnConstWithDefaultValue(input_rows_count); |
| 133 | } |
| 134 | |
| 135 | if (isColumnConst(*elem.column)) |
| 136 | continue; |
| 137 | |
| 138 | if (auto * nullable = checkAndGetColumn<ColumnNullable>(*elem.column)) |
| 139 | { |
| 140 | const ColumnPtr & null_map_column = nullable->getNullMapColumnPtr(); |
| 141 | if (!result_null_map_column) |
| 142 | { |
| 143 | result_null_map_column = null_map_column; |
| 144 | } |
| 145 | else |
| 146 | { |
| 147 | MutableColumnPtr mutable_result_null_map_column = (*std::move(result_null_map_column)).mutate(); |
| 148 | |
| 149 | NullMap & result_null_map = assert_cast<ColumnUInt8 &>(*mutable_result_null_map_column).getData(); |
| 150 | const NullMap & src_null_map = assert_cast<const ColumnUInt8 &>(*null_map_column).getData(); |
| 151 | |
| 152 | for (size_t i = 0, size = result_null_map.size(); i < size; ++i) |
| 153 | if (src_null_map[i]) |
| 154 | result_null_map[i] = 1; |
| 155 | |
| 156 | result_null_map_column = std::move(mutable_result_null_map_column); |
| 157 | } |
| 158 | } |
| 159 | } |
| 160 | |
| 161 | if (!result_null_map_column) |
| 162 | return makeNullable(src); |
| 163 | |
| 164 | return ColumnNullable::create(src_not_nullable->convertToFullColumnIfConst(), result_null_map_column); |
| 165 | } |
| 166 | |
| 167 | |
| 168 | namespace |
| 169 | { |
| 170 | |
| 171 | struct NullPresence |
| 172 | { |
| 173 | bool has_nullable = false; |
| 174 | bool has_null_constant = false; |
| 175 | }; |
| 176 | |
| 177 | NullPresence getNullPresense(const Block & block, const ColumnNumbers & args) |
| 178 | { |
| 179 | NullPresence res; |
| 180 | |
| 181 | for (const auto & arg : args) |
| 182 | { |
| 183 | const auto & elem = block.getByPosition(arg); |
| 184 | |
| 185 | if (!res.has_nullable) |
| 186 | res.has_nullable = elem.type->isNullable(); |
| 187 | if (!res.has_null_constant) |
| 188 | res.has_null_constant = elem.type->onlyNull(); |
| 189 | } |
| 190 | |
| 191 | return res; |
| 192 | } |
| 193 | |
| 194 | NullPresence getNullPresense(const ColumnsWithTypeAndName & args) |
| 195 | { |
| 196 | NullPresence res; |
| 197 | |
| 198 | for (const auto & elem : args) |
| 199 | { |
| 200 | if (!res.has_nullable) |
| 201 | res.has_nullable = elem.type->isNullable(); |
| 202 | if (!res.has_null_constant) |
| 203 | res.has_null_constant = elem.type->onlyNull(); |
| 204 | } |
| 205 | |
| 206 | return res; |
| 207 | } |
| 208 | |
| 209 | bool allArgumentsAreConstants(const Block & block, const ColumnNumbers & args) |
| 210 | { |
| 211 | for (auto arg : args) |
| 212 | if (!isColumnConst(*block.getByPosition(arg).column)) |
| 213 | return false; |
| 214 | return true; |
| 215 | } |
| 216 | } |
| 217 | |
| 218 | bool ExecutableFunctionAdaptor::defaultImplementationForConstantArguments( |
| 219 | Block & block, const ColumnNumbers & args, size_t result, size_t input_rows_count, bool dry_run) |
| 220 | { |
| 221 | ColumnNumbers arguments_to_remain_constants = impl->getArgumentsThatAreAlwaysConstant(); |
| 222 | |
| 223 | /// Check that these arguments are really constant. |
| 224 | for (auto arg_num : arguments_to_remain_constants) |
| 225 | if (arg_num < args.size() && !isColumnConst(*block.getByPosition(args[arg_num]).column)) |
| 226 | throw Exception("Argument at index " + toString(arg_num) + " for function " + getName() + " must be constant" , ErrorCodes::ILLEGAL_COLUMN); |
| 227 | |
| 228 | if (args.empty() || !impl->useDefaultImplementationForConstants() || !allArgumentsAreConstants(block, args)) |
| 229 | return false; |
| 230 | |
| 231 | Block temporary_block; |
| 232 | bool have_converted_columns = false; |
| 233 | |
| 234 | size_t arguments_size = args.size(); |
| 235 | for (size_t arg_num = 0; arg_num < arguments_size; ++arg_num) |
| 236 | { |
| 237 | const ColumnWithTypeAndName & column = block.getByPosition(args[arg_num]); |
| 238 | |
| 239 | if (arguments_to_remain_constants.end() != std::find(arguments_to_remain_constants.begin(), arguments_to_remain_constants.end(), arg_num)) |
| 240 | { |
| 241 | temporary_block.insert({column.column->cloneResized(1), column.type, column.name}); |
| 242 | } |
| 243 | else |
| 244 | { |
| 245 | have_converted_columns = true; |
| 246 | temporary_block.insert({ assert_cast<const ColumnConst *>(column.column.get())->getDataColumnPtr(), column.type, column.name }); |
| 247 | } |
| 248 | } |
| 249 | |
| 250 | /** When using default implementation for constants, the function requires at least one argument |
| 251 | * not in "arguments_to_remain_constants" set. Otherwise we get infinite recursion. |
| 252 | */ |
| 253 | if (!have_converted_columns) |
| 254 | throw Exception("Number of arguments for function " + getName() + " doesn't match: the function requires more arguments" , |
| 255 | ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); |
| 256 | |
| 257 | temporary_block.insert(block.getByPosition(result)); |
| 258 | |
| 259 | ColumnNumbers temporary_argument_numbers(arguments_size); |
| 260 | for (size_t i = 0; i < arguments_size; ++i) |
| 261 | temporary_argument_numbers[i] = i; |
| 262 | |
| 263 | executeWithoutLowCardinalityColumns(temporary_block, temporary_argument_numbers, arguments_size, temporary_block.rows(), dry_run); |
| 264 | |
| 265 | ColumnPtr result_column; |
| 266 | /// extremely rare case, when we have function with completely const arguments |
| 267 | /// but some of them produced by non isDeterministic function |
| 268 | if (temporary_block.getByPosition(arguments_size).column->size() > 1) |
| 269 | result_column = temporary_block.getByPosition(arguments_size).column->cloneResized(1); |
| 270 | else |
| 271 | result_column = temporary_block.getByPosition(arguments_size).column; |
| 272 | |
| 273 | block.getByPosition(result).column = ColumnConst::create(result_column, input_rows_count); |
| 274 | return true; |
| 275 | } |
| 276 | |
| 277 | |
| 278 | bool ExecutableFunctionAdaptor::defaultImplementationForNulls( |
| 279 | Block & block, const ColumnNumbers & args, size_t result, size_t input_rows_count, bool dry_run) |
| 280 | { |
| 281 | if (args.empty() || !impl->useDefaultImplementationForNulls()) |
| 282 | return false; |
| 283 | |
| 284 | NullPresence null_presence = getNullPresense(block, args); |
| 285 | |
| 286 | if (null_presence.has_null_constant) |
| 287 | { |
| 288 | auto & result_column = block.getByPosition(result).column; |
| 289 | auto result_type = block.getByPosition(result).type; |
| 290 | // Default implementation for nulls returns null result for null arguments, |
| 291 | // so the result type must be nullable. |
| 292 | assert(result_type->isNullable()); |
| 293 | |
| 294 | result_column = result_type->createColumnConstWithDefaultValue(input_rows_count); |
| 295 | return true; |
| 296 | } |
| 297 | |
| 298 | if (null_presence.has_nullable) |
| 299 | { |
| 300 | Block temporary_block = createBlockWithNestedColumns(block, args, result); |
| 301 | executeWithoutLowCardinalityColumns(temporary_block, args, result, temporary_block.rows(), dry_run); |
| 302 | block.getByPosition(result).column = wrapInNullable(temporary_block.getByPosition(result).column, block, args, |
| 303 | result, input_rows_count); |
| 304 | return true; |
| 305 | } |
| 306 | |
| 307 | return false; |
| 308 | } |
| 309 | |
| 310 | void ExecutableFunctionAdaptor::executeWithoutLowCardinalityColumns( |
| 311 | Block & block, const ColumnNumbers & args, size_t result, size_t input_rows_count, bool dry_run) |
| 312 | { |
| 313 | if (defaultImplementationForConstantArguments(block, args, result, input_rows_count, dry_run)) |
| 314 | return; |
| 315 | |
| 316 | if (defaultImplementationForNulls(block, args, result, input_rows_count, dry_run)) |
| 317 | return; |
| 318 | |
| 319 | if (dry_run) |
| 320 | impl->executeDryRun(block, args, result, input_rows_count); |
| 321 | else |
| 322 | impl->execute(block, args, result, input_rows_count); |
| 323 | } |
| 324 | |
| 325 | static const ColumnLowCardinality * findLowCardinalityArgument(const Block & block, const ColumnNumbers & args) |
| 326 | { |
| 327 | const ColumnLowCardinality * result_column = nullptr; |
| 328 | |
| 329 | for (auto arg : args) |
| 330 | { |
| 331 | const ColumnWithTypeAndName & column = block.getByPosition(arg); |
| 332 | if (auto * low_cardinality_column = checkAndGetColumn<ColumnLowCardinality>(column.column.get())) |
| 333 | { |
| 334 | if (result_column) |
| 335 | throw Exception("Expected single dictionary argument for function." , ErrorCodes::LOGICAL_ERROR); |
| 336 | |
| 337 | result_column = low_cardinality_column; |
| 338 | } |
| 339 | } |
| 340 | |
| 341 | return result_column; |
| 342 | } |
| 343 | |
/// Rewrites the argument columns of `block` in place so the function can run over a dictionary
/// instead of over the full LowCardinality data.
///   - The (single) LowCardinality argument is replaced by a dictionary column, and its type by the dictionary type.
///     * If `can_be_executed_on_default_arguments` is set, the dictionary's nested column is used as is.
///     * Otherwise a minimal dictionary covering only the values actually present is built
///       (getMinimalDictionaryEncodedColumn over all rows).
///   - Constant arguments lose LowCardinality and are resized to the row count of that dictionary.
///     If there is no LowCardinality argument, they are resized to `input_rows_count`.
/// Returns the indexes that map the original rows to dictionary positions,
/// or an empty pointer when no argument is LowCardinality.
/// Throws LOGICAL_ERROR if more than one argument is LowCardinality, or if a LowCardinality
/// column does not carry a DataTypeLowCardinality type.
static ColumnPtr replaceLowCardinalityColumnsByNestedAndGetDictionaryIndexes(
    Block & block, const ColumnNumbers & args, bool can_be_executed_on_default_arguments, size_t input_rows_count)
{
    /// Row count the arguments will have after rewriting; changes if a LowCardinality argument is found.
    size_t num_rows = input_rows_count;
    ColumnPtr indexes;

    /// Find first LowCardinality column and replace it to nested dictionary.
    for (auto arg : args)
    {
        ColumnWithTypeAndName & column = block.getByPosition(arg);
        if (auto * low_cardinality_column = checkAndGetColumn<ColumnLowCardinality>(column.column.get()))
        {
            /// Single LowCardinality column is supported now.
            if (indexes)
                throw Exception("Expected single dictionary argument for function." , ErrorCodes::LOGICAL_ERROR);

            auto * low_cardinality_type = checkAndGetDataType<DataTypeLowCardinality>(column.type.get());

            if (!low_cardinality_type)
                throw Exception("Incompatible type for low cardinality column: " + column.type->getName(),
                    ErrorCodes::LOGICAL_ERROR);

            if (can_be_executed_on_default_arguments)
            {
                /// Normal case, when function can be executed on values's default.
                column.column = low_cardinality_column->getDictionary().getNestedColumn();
                indexes = low_cardinality_column->getIndexesPtr();
            }
            else
            {
                /// Special case when default value can't be used. Example: 1 % LowCardinality(Int).
                /// LowCardinality always contains default, so 1 % 0 will throw exception in normal case.
                auto dict_encoded = low_cardinality_column->getMinimalDictionaryEncodedColumn(0, low_cardinality_column->size());
                column.column = dict_encoded.dictionary;
                indexes = dict_encoded.indexes;
            }

            /// From now on the arguments are sized like the dictionary, not like the input.
            num_rows = column.column->size();
            column.type = low_cardinality_type->getDictionaryType();
        }
    }

    /// Change size of constants.
    /// Constants are resized to `num_rows` so all arguments have the same row count.
    for (auto arg : args)
    {
        ColumnWithTypeAndName & column = block.getByPosition(arg);
        if (auto * column_const = checkAndGetColumn<ColumnConst>(column.column.get()))
        {
            column.column = column_const->removeLowCardinality()->cloneResized(num_rows);
            column.type = removeLowCardinality(column.type);
        }
    }

#ifndef NDEBUG
    /// Debug-only consistency check of column sizes in the block.
    block.checkNumberOfRows(true);
#endif

    return indexes;
}
| 403 | |
| 404 | static void convertLowCardinalityColumnsToFull(Block & block, const ColumnNumbers & args) |
| 405 | { |
| 406 | for (auto arg : args) |
| 407 | { |
| 408 | ColumnWithTypeAndName & column = block.getByPosition(arg); |
| 409 | |
| 410 | column.column = recursiveRemoveLowCardinality(column.column); |
| 411 | column.type = recursiveRemoveLowCardinality(column.type); |
| 412 | } |
| 413 | } |
| 414 | |
/// Executes the function and writes the result column into `block` at position `result`.
/// When the implementation enables useDefaultImplementationForLowCardinalityColumns():
///   - If the result type is LowCardinality, the function runs over the dictionary of the single
///     LowCardinality argument, and the result is re-encoded as a LowCardinality column through the
///     dictionary indexes. For shared dictionaries, per-dictionary results may be served from, or
///     stored to, low_cardinality_result_cache.
///   - Otherwise all LowCardinality arguments are converted to full columns before execution.
/// Without the default implementation, `block` is passed to executeWithoutLowCardinalityColumns as is.
void ExecutableFunctionAdaptor::execute(Block & block, const ColumnNumbers & args, size_t result, size_t input_rows_count, bool dry_run)
{
    if (impl->useDefaultImplementationForLowCardinalityColumns())
    {
        auto & res = block.safeGetByPosition(result);
        /// Work on a structural copy that shares only the argument columns.
        /// In this branch `block` is only read, apart from its result column.
        Block block_without_low_cardinality = block.cloneWithoutColumns();

        for (auto arg : args)
            block_without_low_cardinality.safeGetByPosition(arg).column = block.safeGetByPosition(arg).column;

        if (auto * res_low_cardinality_type = typeid_cast<const DataTypeLowCardinality *>(res.type.get()))
        {
            const auto * low_cardinality_column = findLowCardinalityArgument(block, args);
            bool can_be_executed_on_default_arguments = impl->canBeExecutedOnDefaultArguments();
            /// The cache is used only when it exists, the function may run on the whole dictionary
            /// (default value included), and the argument's dictionary is shared.
            bool use_cache = low_cardinality_result_cache && can_be_executed_on_default_arguments
                && low_cardinality_column && low_cardinality_column->isSharedDictionary();
            ExecutableFunctionLowCardinalityResultCache::DictionaryKey key;

            if (use_cache)
            {
                const auto & dictionary = low_cardinality_column->getDictionary();
                key = {dictionary.getHash(), dictionary.size()};

                auto cached_values = low_cardinality_result_cache->get(key);
                if (cached_values)
                {
                    /// Cache hit: remap this column's indexes through the cached mapping and reuse the cached result dictionary.
                    auto indexes = cached_values->index_mapping->index(low_cardinality_column->getIndexes(), 0);
                    res.column = ColumnLowCardinality::create(cached_values->function_result, indexes, true);
                    return;
                }
            }

            /// Execute with the dictionary type of the LowCardinality result type.
            block_without_low_cardinality.safeGetByPosition(result).type = res_low_cardinality_type->getDictionaryType();
            ColumnPtr indexes = replaceLowCardinalityColumnsByNestedAndGetDictionaryIndexes(
                block_without_low_cardinality, args, can_be_executed_on_default_arguments, input_rows_count);

            executeWithoutLowCardinalityColumns(block_without_low_cardinality, args, result, block_without_low_cardinality.rows(), dry_run);

            auto keys = block_without_low_cardinality.safeGetByPosition(result).column->convertToFullColumnIfConst();

            /// Insert the computed values into a new unique dictionary.
            /// res_indexes maps each computed row to its position in that dictionary.
            auto res_mut_dictionary = DataTypeLowCardinality::createColumnUnique(*res_low_cardinality_type->getDictionaryType());
            ColumnPtr res_indexes = res_mut_dictionary->uniqueInsertRangeFrom(*keys, 0, keys->size());
            ColumnUniquePtr res_dictionary = std::move(res_mut_dictionary);

            if (indexes)
            {
                if (use_cache)
                {
                    auto cache_values = std::make_shared<ExecutableFunctionLowCardinalityResultCache::CachedValues>();
                    cache_values->dictionary_holder = low_cardinality_column->getDictionaryPtr();
                    cache_values->function_result = res_dictionary;
                    cache_values->index_mapping = res_indexes;

                    /// If an entry for `key` is already present (presumably stored concurrently), use that one instead.
                    cache_values = low_cardinality_result_cache->getOrSet(key, cache_values);
                    res_dictionary = cache_values->function_result;
                    res_indexes = cache_values->index_mapping;
                }

                /// Compose the mappings: original row -> dictionary position (indexes) -> result position (res_indexes).
                res.column = ColumnLowCardinality::create(res_dictionary, res_indexes->index(*indexes, 0), use_cache);
            }
            else
            {
                /// No LowCardinality argument: computed rows map directly onto the result dictionary.
                res.column = ColumnLowCardinality::create(res_dictionary, res_indexes);
            }
        }
        else
        {
            /// The result is not LowCardinality: materialize LowCardinality arguments and execute normally.
            convertLowCardinalityColumnsToFull(block_without_low_cardinality, args);
            executeWithoutLowCardinalityColumns(block_without_low_cardinality, args, result, input_rows_count, dry_run);
            res.column = block_without_low_cardinality.safeGetByPosition(result).column;
        }
    }
    else
        executeWithoutLowCardinalityColumns(block, args, result, input_rows_count, dry_run);
}
| 490 | |
| 491 | void FunctionOverloadResolverAdaptor::checkNumberOfArguments(size_t number_of_arguments) const |
| 492 | { |
| 493 | if (isVariadic()) |
| 494 | return; |
| 495 | |
| 496 | size_t expected_number_of_arguments = getNumberOfArguments(); |
| 497 | |
| 498 | if (number_of_arguments != expected_number_of_arguments) |
| 499 | throw Exception("Number of arguments for function " + getName() + " doesn't match: passed " |
| 500 | + toString(number_of_arguments) + ", should be " + toString(expected_number_of_arguments), |
| 501 | ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); |
| 502 | } |
| 503 | |
| 504 | DataTypePtr FunctionOverloadResolverAdaptor::getReturnTypeWithoutLowCardinality(const ColumnsWithTypeAndName & arguments) const |
| 505 | { |
| 506 | checkNumberOfArguments(arguments.size()); |
| 507 | |
| 508 | if (!arguments.empty() && impl->useDefaultImplementationForNulls()) |
| 509 | { |
| 510 | NullPresence null_presence = getNullPresense(arguments); |
| 511 | |
| 512 | if (null_presence.has_null_constant) |
| 513 | { |
| 514 | return makeNullable(std::make_shared<DataTypeNothing>()); |
| 515 | } |
| 516 | if (null_presence.has_nullable) |
| 517 | { |
| 518 | Block nested_block = createBlockWithNestedColumns(Block(arguments), ext::collection_cast<ColumnNumbers>(ext::range(0, arguments.size()))); |
| 519 | auto return_type = impl->getReturnType(ColumnsWithTypeAndName(nested_block.begin(), nested_block.end())); |
| 520 | return makeNullable(return_type); |
| 521 | |
| 522 | } |
| 523 | } |
| 524 | |
| 525 | return impl->getReturnType(arguments); |
| 526 | } |
| 527 | |
| 528 | #if USE_EMBEDDED_COMPILER |
| 529 | |
| 530 | static std::optional<DataTypes> removeNullables(const DataTypes & types) |
| 531 | { |
| 532 | for (const auto & type : types) |
| 533 | { |
| 534 | if (!typeid_cast<const DataTypeNullable *>(type.get())) |
| 535 | continue; |
| 536 | DataTypes filtered; |
| 537 | for (const auto & sub_type : types) |
| 538 | filtered.emplace_back(removeNullable(sub_type)); |
| 539 | return filtered; |
| 540 | } |
| 541 | return {}; |
| 542 | } |
| 543 | |
| 544 | bool IFunction::isCompilable(const DataTypes & arguments) const |
| 545 | { |
| 546 | if (useDefaultImplementationForNulls()) |
| 547 | if (auto denulled = removeNullables(arguments)) |
| 548 | return isCompilableImpl(*denulled); |
| 549 | return isCompilableImpl(arguments); |
| 550 | } |
| 551 | |
/// Emits LLVM IR computing the function.
/// With the default NULL handling and at least one Nullable argument, the generated code does the following:
///   - it evaluates every Nullable argument and branches to `fail` if its null flag (field 1) is set;
///   - otherwise it replaces the argument's value getter with the extracted value (field 0);
///   - it calls compileImpl on the non-Nullable types and stores the result into field 0 of a zeroed
///     Nullable aggregate;
///   - `fail` produces the zeroed aggregate with field 1 set to true;
///   - a PHI node in `join` merges the two paths.
/// Otherwise it delegates to compileImpl directly.
llvm::Value * IFunction::compile(llvm::IRBuilderBase & builder, const DataTypes & arguments, ValuePlaceholders values) const
{
    if (useDefaultImplementationForNulls())
    {
        if (auto denulled = removeNullables(arguments))
        {
            /// FIXME: when only one column is nullable, this can actually be slower than the non-jitted version
            /// because this involves copying the null map while `wrapInNullable` reuses it.
            auto & b = static_cast<llvm::IRBuilder<> &>(builder);
            /// Block reached when any Nullable argument is NULL.
            auto * fail = llvm::BasicBlock::Create(b.GetInsertBlock()->getContext(), "" , b.GetInsertBlock()->getParent());
            /// Block where the NULL and non-NULL paths merge.
            auto * join = llvm::BasicBlock::Create(b.GetInsertBlock()->getContext(), "" , b.GetInsertBlock()->getParent());
            /// Zero value of the native Nullable result type. It is used as an aggregate whose field 0 is the
            /// value and field 1 the null flag (matching the insert/extract indices below).
            auto * zero = llvm::Constant::getNullValue(toNativeType(b, makeNullable(getReturnTypeImpl(*denulled))));
            for (size_t i = 0; i < arguments.size(); i++)
            {
                if (!arguments[i]->isNullable())
                    continue;
                /// Would be nice to evaluate all this lazily, but that'd change semantics: if only unevaluated
                /// arguments happen to contain NULLs, the return value would not be NULL, though it should be.
                auto * value = values[i]();
                auto * ok = llvm::BasicBlock::Create(b.GetInsertBlock()->getContext(), "" , b.GetInsertBlock()->getParent());
                b.CreateCondBr(b.CreateExtractValue(value, {1}), fail, ok);
                b.SetInsertPoint(ok);
                /// From here on the argument is seen as its non-Nullable value.
                values[i] = [value = b.CreateExtractValue(value, {0})]() { return value; };
            }
            auto * result = b.CreateInsertValue(zero, compileImpl(builder, *denulled, std::move(values)), {0});
            /// compileImpl may have moved the insert point to another block; the PHI needs the current one.
            auto * result_block = b.GetInsertBlock();
            b.CreateBr(join);
            b.SetInsertPoint(fail);
            auto * null = b.CreateInsertValue(zero, b.getTrue(), {1});
            b.CreateBr(join);
            b.SetInsertPoint(join);
            auto * phi = b.CreatePHI(result->getType(), 2);
            phi->addIncoming(result, result_block);
            phi->addIncoming(null, fail);
            return phi;
        }
    }
    return compileImpl(builder, arguments, std::move(values));
}
| 591 | |
| 592 | #endif |
| 593 | |
| 594 | DataTypePtr FunctionOverloadResolverAdaptor::getReturnType(const ColumnsWithTypeAndName & arguments) const |
| 595 | { |
| 596 | if (impl->useDefaultImplementationForLowCardinalityColumns()) |
| 597 | { |
| 598 | bool has_low_cardinality = false; |
| 599 | size_t num_full_low_cardinality_columns = 0; |
| 600 | size_t num_full_ordinary_columns = 0; |
| 601 | |
| 602 | ColumnsWithTypeAndName args_without_low_cardinality(arguments); |
| 603 | |
| 604 | for (ColumnWithTypeAndName & arg : args_without_low_cardinality) |
| 605 | { |
| 606 | bool is_const = arg.column && isColumnConst(*arg.column); |
| 607 | if (is_const) |
| 608 | arg.column = assert_cast<const ColumnConst &>(*arg.column).removeLowCardinality(); |
| 609 | |
| 610 | if (auto * low_cardinality_type = typeid_cast<const DataTypeLowCardinality *>(arg.type.get())) |
| 611 | { |
| 612 | arg.type = low_cardinality_type->getDictionaryType(); |
| 613 | has_low_cardinality = true; |
| 614 | |
| 615 | if (!is_const) |
| 616 | ++num_full_low_cardinality_columns; |
| 617 | } |
| 618 | else if (!is_const) |
| 619 | ++num_full_ordinary_columns; |
| 620 | } |
| 621 | |
| 622 | for (auto & arg : args_without_low_cardinality) |
| 623 | { |
| 624 | arg.column = recursiveRemoveLowCardinality(arg.column); |
| 625 | arg.type = recursiveRemoveLowCardinality(arg.type); |
| 626 | } |
| 627 | |
| 628 | auto type_without_low_cardinality = getReturnTypeWithoutLowCardinality(args_without_low_cardinality); |
| 629 | |
| 630 | if (impl->canBeExecutedOnLowCardinalityDictionary() && has_low_cardinality |
| 631 | && num_full_low_cardinality_columns <= 1 && num_full_ordinary_columns == 0 |
| 632 | && type_without_low_cardinality->canBeInsideLowCardinality()) |
| 633 | return std::make_shared<DataTypeLowCardinality>(type_without_low_cardinality); |
| 634 | else |
| 635 | return type_without_low_cardinality; |
| 636 | } |
| 637 | |
| 638 | return getReturnTypeWithoutLowCardinality(arguments); |
| 639 | } |
| 640 | } |
| 641 | |