| 1 | #include <Functions/IFunctionImpl.h> |
| 2 | #include <Functions/FunctionFactory.h> |
| 3 | #include <Functions/GatherUtils/GatherUtils.h> |
| 4 | #include <DataTypes/DataTypeArray.h> |
| 5 | #include <DataTypes/DataTypeNullable.h> |
| 6 | #include <DataTypes/getLeastSupertype.h> |
| 7 | #include <Columns/ColumnArray.h> |
| 8 | #include <Columns/ColumnConst.h> |
| 9 | #include <Interpreters/castColumn.h> |
| 10 | #include <IO/WriteHelpers.h> |
| 11 | #include <Common/typeid_cast.h> |
| 12 | |
| 13 | |
| 14 | namespace DB |
| 15 | { |
| 16 | |
/// Error codes thrown by arrayResize in this file; the `extern` declarations
/// refer to definitions that live in another translation unit.
namespace ErrorCodes
{
    extern const int LOGICAL_ERROR;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
| 24 | |
| 25 | class FunctionArrayResize : public IFunction |
| 26 | { |
| 27 | public: |
| 28 | static constexpr auto name = "arrayResize" ; |
| 29 | static FunctionPtr create(const Context & context) { return std::make_shared<FunctionArrayResize>(context); } |
| 30 | FunctionArrayResize(const Context & context_) : context(context_) {} |
| 31 | |
| 32 | String getName() const override { return name; } |
| 33 | |
| 34 | bool isVariadic() const override { return true; } |
| 35 | size_t getNumberOfArguments() const override { return 0; } |
| 36 | |
| 37 | DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override |
| 38 | { |
| 39 | const size_t number_of_arguments = arguments.size(); |
| 40 | |
| 41 | if (number_of_arguments < 2 || number_of_arguments > 3) |
| 42 | throw Exception("Number of arguments for function " + getName() + " doesn't match: passed " |
| 43 | + toString(number_of_arguments) + ", should be 2 or 3" , |
| 44 | ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); |
| 45 | |
| 46 | if (arguments[0]->onlyNull()) |
| 47 | return arguments[0]; |
| 48 | |
| 49 | auto array_type = typeid_cast<const DataTypeArray *>(arguments[0].get()); |
| 50 | if (!array_type) |
| 51 | throw Exception("First argument for function " + getName() + " must be an array but it has type " |
| 52 | + arguments[0]->getName() + "." , ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
| 53 | |
| 54 | if (WhichDataType(array_type->getNestedType()).isNothing()) |
| 55 | throw Exception("Function " + getName() + " cannot resize " + array_type->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
| 56 | |
| 57 | if (!isInteger(removeNullable(arguments[1])) && !arguments[1]->onlyNull()) |
| 58 | throw Exception( |
| 59 | "Argument " + toString(1) + " for function " + getName() + " must be integer but it has type " |
| 60 | + arguments[1]->getName() + "." , ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
| 61 | |
| 62 | if (number_of_arguments == 2) |
| 63 | return arguments[0]; |
| 64 | else /* if (number_of_arguments == 3) */ |
| 65 | return std::make_shared<DataTypeArray>(getLeastSupertype({array_type->getNestedType(), arguments[2]})); |
| 66 | } |
| 67 | |
| 68 | void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override |
| 69 | { |
| 70 | const auto & return_type = block.getByPosition(result).type; |
| 71 | |
| 72 | if (return_type->onlyNull()) |
| 73 | { |
| 74 | block.getByPosition(result).column = return_type->createColumnConstWithDefaultValue(input_rows_count); |
| 75 | return; |
| 76 | } |
| 77 | |
| 78 | auto result_column = return_type->createColumn(); |
| 79 | |
| 80 | auto array_column = block.getByPosition(arguments[0]).column; |
| 81 | auto size_column = block.getByPosition(arguments[1]).column; |
| 82 | |
| 83 | if (!block.getByPosition(arguments[0]).type->equals(*return_type)) |
| 84 | array_column = castColumn(block.getByPosition(arguments[0]), return_type, context); |
| 85 | |
| 86 | const DataTypePtr & return_nested_type = typeid_cast<const DataTypeArray &>(*return_type).getNestedType(); |
| 87 | size_t size = array_column->size(); |
| 88 | |
| 89 | ColumnPtr appended_column; |
| 90 | if (arguments.size() == 3) |
| 91 | { |
| 92 | appended_column = block.getByPosition(arguments[2]).column; |
| 93 | if (!block.getByPosition(arguments[2]).type->equals(*return_nested_type)) |
| 94 | appended_column = castColumn(block.getByPosition(arguments[2]), return_nested_type, context); |
| 95 | } |
| 96 | else |
| 97 | appended_column = return_nested_type->createColumnConstWithDefaultValue(size); |
| 98 | |
| 99 | std::unique_ptr<GatherUtils::IArraySource> array_source; |
| 100 | std::unique_ptr<GatherUtils::IValueSource> value_source; |
| 101 | |
| 102 | bool is_const = false; |
| 103 | |
| 104 | if (auto const_array_column = typeid_cast<const ColumnConst *>(array_column.get())) |
| 105 | { |
| 106 | is_const = true; |
| 107 | array_column = const_array_column->getDataColumnPtr(); |
| 108 | } |
| 109 | |
| 110 | if (auto argument_column_array = typeid_cast<const ColumnArray *>(array_column.get())) |
| 111 | array_source = GatherUtils::createArraySource(*argument_column_array, is_const, size); |
| 112 | else |
| 113 | throw Exception{"First arguments for function " + getName() + " must be array." , ErrorCodes::LOGICAL_ERROR}; |
| 114 | |
| 115 | |
| 116 | bool is_appended_const = false; |
| 117 | if (auto const_appended_column = typeid_cast<const ColumnConst *>(appended_column.get())) |
| 118 | { |
| 119 | is_appended_const = true; |
| 120 | appended_column = const_appended_column->getDataColumnPtr(); |
| 121 | } |
| 122 | |
| 123 | value_source = GatherUtils::createValueSource(*appended_column, is_appended_const, size); |
| 124 | |
| 125 | auto sink = GatherUtils::createArraySink(typeid_cast<ColumnArray &>(*result_column), size); |
| 126 | |
| 127 | if (isColumnConst(*size_column)) |
| 128 | GatherUtils::resizeConstantSize(*array_source, *value_source, *sink, size_column->getInt(0)); |
| 129 | else |
| 130 | GatherUtils::resizeDynamicSize(*array_source, *value_source, *sink, *size_column); |
| 131 | |
| 132 | block.getByPosition(result).column = std::move(result_column); |
| 133 | } |
| 134 | |
| 135 | bool useDefaultImplementationForConstants() const override { return true; } |
| 136 | bool useDefaultImplementationForNulls() const override { return false; } |
| 137 | |
| 138 | private: |
| 139 | const Context & context; |
| 140 | }; |
| 141 | |
| 142 | |
| 143 | void registerFunctionArrayResize(FunctionFactory & factory) |
| 144 | { |
| 145 | factory.registerFunction<FunctionArrayResize>(); |
| 146 | } |
| 147 | |
| 148 | } |
| 149 | |