1 | #include <Functions/IFunctionImpl.h> |
2 | #include <Functions/FunctionFactory.h> |
3 | #include <Functions/FunctionHelpers.h> |
4 | #include <DataTypes/IDataType.h> |
5 | #include <DataTypes/DataTypeTuple.h> |
6 | #include <DataTypes/DataTypeArray.h> |
7 | #include <Columns/ColumnTuple.h> |
8 | #include <Columns/ColumnArray.h> |
9 | #include <Columns/ColumnString.h> |
10 | #include <Columns/ColumnsNumber.h> |
11 | #include <Common/assert_cast.h> |
12 | #include <memory> |
13 | |
14 | |
15 | namespace DB |
16 | { |
17 | |
namespace ErrorCodes
{
    /// Thrown when a positional tuple index is 0 or exceeds the number of tuple elements.
    extern const int ILLEGAL_INDEX;
    /// Thrown when the first argument is not a (possibly Array-wrapped) Tuple,
    /// or when the second argument is not an accepted constant selector.
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
23 | |
24 | |
25 | /** Extract element of tuple by constant index or name. The operation is essentially free. |
26 | * Also the function looks through Arrays: you can get Array of tuple elements from Array of Tuples. |
27 | */ |
28 | class FunctionTupleElement : public IFunction |
29 | { |
30 | public: |
31 | static constexpr auto name = "tupleElement" ; |
32 | static FunctionPtr create(const Context &) |
33 | { |
34 | return std::make_shared<FunctionTupleElement>(); |
35 | } |
36 | |
37 | String getName() const override |
38 | { |
39 | return name; |
40 | } |
41 | |
42 | size_t getNumberOfArguments() const override |
43 | { |
44 | return 2; |
45 | } |
46 | |
47 | bool useDefaultImplementationForConstants() const override |
48 | { |
49 | return true; |
50 | } |
51 | |
52 | ColumnNumbers getArgumentsThatAreAlwaysConstant() const override |
53 | { |
54 | return {1}; |
55 | } |
56 | |
57 | DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override |
58 | { |
59 | size_t count_arrays = 0; |
60 | |
61 | const IDataType * tuple_col = arguments[0].type.get(); |
62 | while (const DataTypeArray * array = checkAndGetDataType<DataTypeArray>(tuple_col)) |
63 | { |
64 | tuple_col = array->getNestedType().get(); |
65 | ++count_arrays; |
66 | } |
67 | |
68 | const DataTypeTuple * tuple = checkAndGetDataType<DataTypeTuple>(tuple_col); |
69 | if (!tuple) |
70 | throw Exception("First argument for function " + getName() + " must be tuple or array of tuple." , ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
71 | |
72 | size_t index = getElementNum(arguments[1].column, *tuple); |
73 | DataTypePtr out_return_type = tuple->getElements()[index]; |
74 | |
75 | for (; count_arrays; --count_arrays) |
76 | out_return_type = std::make_shared<DataTypeArray>(out_return_type); |
77 | |
78 | return out_return_type; |
79 | } |
80 | |
81 | void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override |
82 | { |
83 | Columns array_offsets; |
84 | |
85 | const auto & first_arg = block.getByPosition(arguments[0]); |
86 | |
87 | const IDataType * tuple_type = first_arg.type.get(); |
88 | const IColumn * tuple_col = first_arg.column.get(); |
89 | while (const DataTypeArray * array_type = checkAndGetDataType<DataTypeArray>(tuple_type)) |
90 | { |
91 | const ColumnArray * array_col = assert_cast<const ColumnArray *>(tuple_col); |
92 | |
93 | tuple_type = array_type->getNestedType().get(); |
94 | tuple_col = &array_col->getData(); |
95 | array_offsets.push_back(array_col->getOffsetsPtr()); |
96 | } |
97 | |
98 | const DataTypeTuple * tuple_type_concrete = checkAndGetDataType<DataTypeTuple>(tuple_type); |
99 | const ColumnTuple * tuple_col_concrete = checkAndGetColumn<ColumnTuple>(tuple_col); |
100 | if (!tuple_type_concrete || !tuple_col_concrete) |
101 | throw Exception("First argument for function " + getName() + " must be tuple or array of tuple." , ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
102 | |
103 | size_t index = getElementNum(block.getByPosition(arguments[1]).column, *tuple_type_concrete); |
104 | ColumnPtr res = tuple_col_concrete->getColumns()[index]; |
105 | |
106 | /// Wrap into Arrays |
107 | for (auto it = array_offsets.rbegin(); it != array_offsets.rend(); ++it) |
108 | res = ColumnArray::create(res, *it); |
109 | |
110 | block.getByPosition(result).column = res; |
111 | } |
112 | |
113 | private: |
114 | size_t getElementNum(const ColumnPtr & index_column, const DataTypeTuple & tuple) const |
115 | { |
116 | if (auto index_col = checkAndGetColumnConst<ColumnUInt8>(index_column.get())) |
117 | { |
118 | size_t index = index_col->getValue<UInt8>(); |
119 | |
120 | if (index == 0) |
121 | throw Exception("Indices in tuples are 1-based." , ErrorCodes::ILLEGAL_INDEX); |
122 | |
123 | if (index > tuple.getElements().size()) |
124 | throw Exception("Index for tuple element is out of range." , ErrorCodes::ILLEGAL_INDEX); |
125 | |
126 | return index - 1; |
127 | } |
128 | else if (auto name_col = checkAndGetColumnConst<ColumnString>(index_column.get())) |
129 | { |
130 | return tuple.getPositionByName(name_col->getValue<String>()); |
131 | } |
132 | else |
133 | throw Exception("Second argument to " + getName() + " must be a constant UInt8 or String" , ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
134 | } |
135 | }; |
136 | |
137 | |
138 | void registerFunctionTupleElement(FunctionFactory & factory) |
139 | { |
140 | factory.registerFunction<FunctionTupleElement>(); |
141 | } |
142 | |
143 | } |
144 | |