1#include <Functions/IFunctionImpl.h>
2#include <Functions/FunctionFactory.h>
3#include <Functions/FunctionHelpers.h>
4#include <DataTypes/DataTypeArray.h>
5#include <Columns/ColumnArray.h>
6
7
8namespace DB
9{
10
/// Error codes thrown by this translation unit. Every code that is used below
/// is declared here explicitly instead of relying on a declaration that may
/// leak in transitively through one of the included headers.
namespace ErrorCodes
{
    extern const int ILLEGAL_COLUMN;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
15
16/// arrayFlatten([[1, 2, 3], [4, 5]]) = [1, 2, 3, 4, 5] - flatten array.
17class ArrayFlatten : public IFunction
18{
19public:
20 static constexpr auto name = "arrayFlatten";
21
22 static FunctionPtr create(const Context &) { return std::make_shared<ArrayFlatten>(); }
23
24 size_t getNumberOfArguments() const override { return 1; }
25 bool useDefaultImplementationForConstants() const override { return true; }
26
27 DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
28 {
29 if (!isArray(arguments[0]))
30 throw Exception("Illegal type " + arguments[0]->getName() +
31 " of argument of function " + getName() +
32 ", expected Array", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
33
34 DataTypePtr nested_type = arguments[0];
35 while (isArray(nested_type))
36 nested_type = checkAndGetDataType<DataTypeArray>(nested_type.get())->getNestedType();
37
38 return std::make_shared<DataTypeArray>(nested_type);
39 }
40
41 void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override
42 {
43 /** We create an array column with array elements as the most deep elements of nested arrays,
44 * and construct offsets by selecting elements of most deep offsets by values of ancestor offsets.
45 *
46Example 1:
47
48Source column: Array(Array(UInt8)):
49Row 1: [[1, 2, 3], [4, 5]], Row 2: [[6], [7, 8]]
50data: [1, 2, 3], [4, 5], [6], [7, 8]
51offsets: 2, 4
52data.data: 1 2 3 4 5 6 7 8
53data.offsets: 3 5 6 8
54
55Result column: Array(UInt8):
56Row 1: [1, 2, 3, 4, 5], Row 2: [6, 7, 8]
57data: 1 2 3 4 5 6 7 8
58offsets: 5 8
59
60Result offsets are selected from the most deep (data.offsets) by previous deep (offsets) (and values are decremented by one):
613 5 6 8
62 ^ ^
63
64Example 2:
65
66Source column: Array(Array(Array(UInt8))):
67Row 1: [[], [[1], [], [2, 3]]], Row 2: [[[4]]]
68
69most deep data: 1 2 3 4
70
71offsets1: 2 3
72offsets2: 0 3 4
73- ^ ^ - select by prev offsets
74offsets3: 1 1 3 4
75- ^ ^ - select by prev offsets
76
77result offsets: 3, 4
78result: Row 1: [1, 2, 3], Row2: [4]
79 */
80
81 const ColumnArray * src_col = checkAndGetColumn<ColumnArray>(block.getByPosition(arguments[0]).column.get());
82
83 if (!src_col)
84 throw Exception("Illegal column " + block.getByPosition(arguments[0]).column->getName() + " in argument of function 'arrayFlatten'",
85 ErrorCodes::ILLEGAL_COLUMN);
86
87 const IColumn::Offsets & src_offsets = src_col->getOffsets();
88
89 ColumnArray::ColumnOffsets::MutablePtr result_offsets_column;
90 const IColumn::Offsets * prev_offsets = &src_offsets;
91 const IColumn * prev_data = &src_col->getData();
92
93 while (const ColumnArray * next_col = checkAndGetColumn<ColumnArray>(prev_data))
94 {
95 if (!result_offsets_column)
96 result_offsets_column = ColumnArray::ColumnOffsets::create(input_rows_count);
97
98 IColumn::Offsets & result_offsets = result_offsets_column->getData();
99
100 const IColumn::Offsets * next_offsets = &next_col->getOffsets();
101
102 for (size_t i = 0; i < input_rows_count; ++i)
103 result_offsets[i] = (*next_offsets)[(*prev_offsets)[i] - 1]; /// -1 array subscript is Ok, see PaddedPODArray
104
105 prev_offsets = &result_offsets;
106 prev_data = &next_col->getData();
107 }
108
109 block.getByPosition(result).column = ColumnArray::create(
110 prev_data->getPtr(),
111 result_offsets_column ? std::move(result_offsets_column) : src_col->getOffsetsPtr());
112 }
113
114private:
115 String getName() const override
116 {
117 return name;
118 }
119};
120
121
122void registerFunctionArrayFlatten(FunctionFactory & factory)
123{
124 factory.registerFunction<ArrayFlatten>();
125 factory.registerAlias("flatten", "arrayFlatten", FunctionFactory::CaseInsensitive);
126}
127
128}
129