1 | #include <Functions/IFunctionImpl.h> |
2 | #include <Functions/GatherUtils/GatherUtils.h> |
3 | #include <DataTypes/DataTypeArray.h> |
4 | #include <DataTypes/getLeastSupertype.h> |
5 | #include <Columns/ColumnArray.h> |
6 | #include <Columns/ColumnConst.h> |
7 | #include <Interpreters/castColumn.h> |
8 | #include <Common/typeid_cast.h> |
9 | |
10 | |
11 | namespace DB |
12 | { |
13 | |
namespace ErrorCodes
{
    /// Error codes thrown by FunctionArrayPush; the definitions live in another translation unit.
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int LOGICAL_ERROR;
}
19 | |
20 | |
21 | class FunctionArrayPush : public IFunction |
22 | { |
23 | public: |
24 | FunctionArrayPush(const Context & context_, bool push_front_, const char * name_) |
25 | : context(context_), push_front(push_front_), name(name_) {} |
26 | |
27 | String getName() const override { return name; } |
28 | |
29 | bool isVariadic() const override { return false; } |
30 | size_t getNumberOfArguments() const override { return 2; } |
31 | |
32 | DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override |
33 | { |
34 | if (arguments[0]->onlyNull()) |
35 | return arguments[0]; |
36 | |
37 | auto array_type = typeid_cast<const DataTypeArray *>(arguments[0].get()); |
38 | if (!array_type) |
39 | throw Exception("First argument for function " + getName() + " must be an array but it has type " |
40 | + arguments[0]->getName() + "." , ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
41 | |
42 | auto nested_type = array_type->getNestedType(); |
43 | |
44 | DataTypes types = {nested_type, arguments[1]}; |
45 | |
46 | return std::make_shared<DataTypeArray>(getLeastSupertype(types)); |
47 | } |
48 | |
49 | void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override |
50 | { |
51 | const auto & return_type = block.getByPosition(result).type; |
52 | |
53 | if (return_type->onlyNull()) |
54 | { |
55 | block.getByPosition(result).column = return_type->createColumnConstWithDefaultValue(input_rows_count); |
56 | return; |
57 | } |
58 | |
59 | auto result_column = return_type->createColumn(); |
60 | |
61 | auto array_column = block.getByPosition(arguments[0]).column; |
62 | auto appended_column = block.getByPosition(arguments[1]).column; |
63 | |
64 | if (!block.getByPosition(arguments[0]).type->equals(*return_type)) |
65 | array_column = castColumn(block.getByPosition(arguments[0]), return_type, context); |
66 | |
67 | const DataTypePtr & return_nested_type = typeid_cast<const DataTypeArray &>(*return_type).getNestedType(); |
68 | if (!block.getByPosition(arguments[1]).type->equals(*return_nested_type)) |
69 | appended_column = castColumn(block.getByPosition(arguments[1]), return_nested_type, context); |
70 | |
71 | std::unique_ptr<GatherUtils::IArraySource> array_source; |
72 | std::unique_ptr<GatherUtils::IValueSource> value_source; |
73 | |
74 | size_t size = array_column->size(); |
75 | bool is_const = false; |
76 | |
77 | if (auto const_array_column = typeid_cast<const ColumnConst *>(array_column.get())) |
78 | { |
79 | is_const = true; |
80 | array_column = const_array_column->getDataColumnPtr(); |
81 | } |
82 | |
83 | if (auto argument_column_array = typeid_cast<const ColumnArray *>(array_column.get())) |
84 | array_source = GatherUtils::createArraySource(*argument_column_array, is_const, size); |
85 | else |
86 | throw Exception{"First arguments for function " + getName() + " must be array." , ErrorCodes::LOGICAL_ERROR}; |
87 | |
88 | |
89 | bool is_appended_const = false; |
90 | if (auto const_appended_column = typeid_cast<const ColumnConst *>(appended_column.get())) |
91 | { |
92 | is_appended_const = true; |
93 | appended_column = const_appended_column->getDataColumnPtr(); |
94 | } |
95 | |
96 | value_source = GatherUtils::createValueSource(*appended_column, is_appended_const, size); |
97 | |
98 | auto sink = GatherUtils::createArraySink(typeid_cast<ColumnArray &>(*result_column), size); |
99 | |
100 | GatherUtils::push(*array_source, *value_source, *sink, push_front); |
101 | |
102 | block.getByPosition(result).column = std::move(result_column); |
103 | } |
104 | |
105 | bool useDefaultImplementationForConstants() const override { return true; } |
106 | bool useDefaultImplementationForNulls() const override { return false; } |
107 | |
108 | private: |
109 | const Context & context; |
110 | bool push_front; |
111 | const char * name; |
112 | }; |
113 | |
114 | } |
115 | |