#include <Functions/IFunctionImpl.h>
#include <Functions/FunctionFactory.h>
#include <Functions/GatherUtils/GatherUtils.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/getLeastSupertype.h>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnConst.h>
#include <Columns/ColumnNullable.h>
#include <Interpreters/castColumn.h>
#include <IO/WriteHelpers.h>
#include <Common/typeid_cast.h>
12 | |
13 | |
14 | namespace DB |
15 | { |
16 | |
/// Error codes thrown by arrayResize in this file (definitions live elsewhere in the project).
namespace ErrorCodes
{
    extern const int LOGICAL_ERROR;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
24 | |
25 | class FunctionArrayResize : public IFunction |
26 | { |
27 | public: |
28 | static constexpr auto name = "arrayResize" ; |
29 | static FunctionPtr create(const Context & context) { return std::make_shared<FunctionArrayResize>(context); } |
30 | FunctionArrayResize(const Context & context_) : context(context_) {} |
31 | |
32 | String getName() const override { return name; } |
33 | |
34 | bool isVariadic() const override { return true; } |
35 | size_t getNumberOfArguments() const override { return 0; } |
36 | |
37 | DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override |
38 | { |
39 | const size_t number_of_arguments = arguments.size(); |
40 | |
41 | if (number_of_arguments < 2 || number_of_arguments > 3) |
42 | throw Exception("Number of arguments for function " + getName() + " doesn't match: passed " |
43 | + toString(number_of_arguments) + ", should be 2 or 3" , |
44 | ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); |
45 | |
46 | if (arguments[0]->onlyNull()) |
47 | return arguments[0]; |
48 | |
49 | auto array_type = typeid_cast<const DataTypeArray *>(arguments[0].get()); |
50 | if (!array_type) |
51 | throw Exception("First argument for function " + getName() + " must be an array but it has type " |
52 | + arguments[0]->getName() + "." , ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
53 | |
54 | if (WhichDataType(array_type->getNestedType()).isNothing()) |
55 | throw Exception("Function " + getName() + " cannot resize " + array_type->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
56 | |
57 | if (!isInteger(removeNullable(arguments[1])) && !arguments[1]->onlyNull()) |
58 | throw Exception( |
59 | "Argument " + toString(1) + " for function " + getName() + " must be integer but it has type " |
60 | + arguments[1]->getName() + "." , ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
61 | |
62 | if (number_of_arguments == 2) |
63 | return arguments[0]; |
64 | else /* if (number_of_arguments == 3) */ |
65 | return std::make_shared<DataTypeArray>(getLeastSupertype({array_type->getNestedType(), arguments[2]})); |
66 | } |
67 | |
68 | void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override |
69 | { |
70 | const auto & return_type = block.getByPosition(result).type; |
71 | |
72 | if (return_type->onlyNull()) |
73 | { |
74 | block.getByPosition(result).column = return_type->createColumnConstWithDefaultValue(input_rows_count); |
75 | return; |
76 | } |
77 | |
78 | auto result_column = return_type->createColumn(); |
79 | |
80 | auto array_column = block.getByPosition(arguments[0]).column; |
81 | auto size_column = block.getByPosition(arguments[1]).column; |
82 | |
83 | if (!block.getByPosition(arguments[0]).type->equals(*return_type)) |
84 | array_column = castColumn(block.getByPosition(arguments[0]), return_type, context); |
85 | |
86 | const DataTypePtr & return_nested_type = typeid_cast<const DataTypeArray &>(*return_type).getNestedType(); |
87 | size_t size = array_column->size(); |
88 | |
89 | ColumnPtr appended_column; |
90 | if (arguments.size() == 3) |
91 | { |
92 | appended_column = block.getByPosition(arguments[2]).column; |
93 | if (!block.getByPosition(arguments[2]).type->equals(*return_nested_type)) |
94 | appended_column = castColumn(block.getByPosition(arguments[2]), return_nested_type, context); |
95 | } |
96 | else |
97 | appended_column = return_nested_type->createColumnConstWithDefaultValue(size); |
98 | |
99 | std::unique_ptr<GatherUtils::IArraySource> array_source; |
100 | std::unique_ptr<GatherUtils::IValueSource> value_source; |
101 | |
102 | bool is_const = false; |
103 | |
104 | if (auto const_array_column = typeid_cast<const ColumnConst *>(array_column.get())) |
105 | { |
106 | is_const = true; |
107 | array_column = const_array_column->getDataColumnPtr(); |
108 | } |
109 | |
110 | if (auto argument_column_array = typeid_cast<const ColumnArray *>(array_column.get())) |
111 | array_source = GatherUtils::createArraySource(*argument_column_array, is_const, size); |
112 | else |
113 | throw Exception{"First arguments for function " + getName() + " must be array." , ErrorCodes::LOGICAL_ERROR}; |
114 | |
115 | |
116 | bool is_appended_const = false; |
117 | if (auto const_appended_column = typeid_cast<const ColumnConst *>(appended_column.get())) |
118 | { |
119 | is_appended_const = true; |
120 | appended_column = const_appended_column->getDataColumnPtr(); |
121 | } |
122 | |
123 | value_source = GatherUtils::createValueSource(*appended_column, is_appended_const, size); |
124 | |
125 | auto sink = GatherUtils::createArraySink(typeid_cast<ColumnArray &>(*result_column), size); |
126 | |
127 | if (isColumnConst(*size_column)) |
128 | GatherUtils::resizeConstantSize(*array_source, *value_source, *sink, size_column->getInt(0)); |
129 | else |
130 | GatherUtils::resizeDynamicSize(*array_source, *value_source, *sink, *size_column); |
131 | |
132 | block.getByPosition(result).column = std::move(result_column); |
133 | } |
134 | |
135 | bool useDefaultImplementationForConstants() const override { return true; } |
136 | bool useDefaultImplementationForNulls() const override { return false; } |
137 | |
138 | private: |
139 | const Context & context; |
140 | }; |
141 | |
142 | |
143 | void registerFunctionArrayResize(FunctionFactory & factory) |
144 | { |
145 | factory.registerFunction<FunctionArrayResize>(); |
146 | } |
147 | |
148 | } |
149 | |