1 | #include <Functions/IFunctionImpl.h> |
2 | #include <Functions/FunctionFactory.h> |
3 | #include <Functions/GatherUtils/GatherUtils.h> |
4 | #include <DataTypes/DataTypeArray.h> |
5 | #include <DataTypes/getLeastSupertype.h> |
6 | #include <Interpreters/castColumn.h> |
7 | #include <Columns/ColumnArray.h> |
8 | #include <Columns/ColumnConst.h> |
9 | #include <Common/typeid_cast.h> |
10 | #include <ext/range.h> |
11 | |
12 | |
13 | namespace DB |
14 | { |
15 | |
/// Error codes used by this file; they are only declared here (extern) and defined elsewhere.
namespace ErrorCodes
{
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int LOGICAL_ERROR;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
22 | |
23 | |
24 | /// arrayConcat(arr1, ...) - concatenate arrays. |
25 | class FunctionArrayConcat : public IFunction |
26 | { |
27 | public: |
28 | static constexpr auto name = "arrayConcat" ; |
29 | static FunctionPtr create(const Context & context) { return std::make_shared<FunctionArrayConcat>(context); } |
30 | FunctionArrayConcat(const Context & context_) : context(context_) {} |
31 | |
32 | String getName() const override { return name; } |
33 | |
34 | bool isVariadic() const override { return true; } |
35 | size_t getNumberOfArguments() const override { return 0; } |
36 | |
37 | DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override |
38 | { |
39 | if (arguments.empty()) |
40 | throw Exception{"Function " + getName() + " requires at least one argument." , ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH}; |
41 | |
42 | for (auto i : ext::range(0, arguments.size())) |
43 | { |
44 | auto array_type = typeid_cast<const DataTypeArray *>(arguments[i].get()); |
45 | if (!array_type) |
46 | throw Exception("Argument " + std::to_string(i) + " for function " + getName() + " must be an array but it has type " |
47 | + arguments[i]->getName() + "." , ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
48 | } |
49 | |
50 | return getLeastSupertype(arguments); |
51 | } |
52 | |
53 | void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override |
54 | { |
55 | const DataTypePtr & return_type = block.getByPosition(result).type; |
56 | |
57 | if (return_type->onlyNull()) |
58 | { |
59 | block.getByPosition(result).column = return_type->createColumnConstWithDefaultValue(input_rows_count); |
60 | return; |
61 | } |
62 | |
63 | auto result_column = return_type->createColumn(); |
64 | |
65 | size_t rows = input_rows_count; |
66 | size_t num_args = arguments.size(); |
67 | |
68 | Columns preprocessed_columns(num_args); |
69 | |
70 | for (size_t i = 0; i < num_args; ++i) |
71 | { |
72 | const ColumnWithTypeAndName & arg = block.getByPosition(arguments[i]); |
73 | ColumnPtr preprocessed_column = arg.column; |
74 | |
75 | if (!arg.type->equals(*return_type)) |
76 | preprocessed_column = castColumn(arg, return_type, context); |
77 | |
78 | preprocessed_columns[i] = std::move(preprocessed_column); |
79 | } |
80 | |
81 | std::vector<std::unique_ptr<GatherUtils::IArraySource>> sources; |
82 | |
83 | for (auto & argument_column : preprocessed_columns) |
84 | { |
85 | bool is_const = false; |
86 | |
87 | if (auto argument_column_const = typeid_cast<const ColumnConst *>(argument_column.get())) |
88 | { |
89 | is_const = true; |
90 | argument_column = argument_column_const->getDataColumnPtr(); |
91 | } |
92 | |
93 | if (auto argument_column_array = typeid_cast<const ColumnArray *>(argument_column.get())) |
94 | sources.emplace_back(GatherUtils::createArraySource(*argument_column_array, is_const, rows)); |
95 | else |
96 | throw Exception{"Arguments for function " + getName() + " must be arrays." , ErrorCodes::LOGICAL_ERROR}; |
97 | } |
98 | |
99 | auto sink = GatherUtils::createArraySink(typeid_cast<ColumnArray &>(*result_column), rows); |
100 | GatherUtils::concat(sources, *sink); |
101 | |
102 | block.getByPosition(result).column = std::move(result_column); |
103 | } |
104 | |
105 | bool useDefaultImplementationForConstants() const override { return true; } |
106 | |
107 | private: |
108 | const Context & context; |
109 | }; |
110 | |
111 | |
112 | void registerFunctionArrayConcat(FunctionFactory & factory) |
113 | { |
114 | factory.registerFunction<FunctionArrayConcat>(); |
115 | } |
116 | |
117 | } |
118 | |