1 | #pragma once |
2 | |
3 | #include <iostream> |
4 | #include <sstream> |
5 | #include <unordered_set> |
6 | #include <Columns/ColumnsNumber.h> |
7 | #include <Columns/ColumnArray.h> |
8 | #include <Common/assert_cast.h> |
9 | #include <DataTypes/DataTypesNumber.h> |
10 | #include <DataTypes/DataTypeArray.h> |
11 | #include <IO/ReadHelpers.h> |
12 | #include <IO/WriteHelpers.h> |
13 | #include <Common/ArenaAllocator.h> |
14 | #include <ext/range.h> |
15 | #include <bitset> |
16 | |
17 | #include <AggregateFunctions/IAggregateFunction.h> |
18 | |
19 | |
20 | namespace DB |
21 | { |
namespace ErrorCodes
{
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
    extern const int TOO_MANY_ARGUMENTS_FOR_FUNCTION;
    /// Thrown by the AggregateFunctionRetention constructor for non-UInt8 arguments;
    /// declared here so the header does not depend on another include declaring it.
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
27 | |
28 | struct AggregateFunctionRetentionData |
29 | { |
30 | static constexpr auto max_events = 32; |
31 | |
32 | using Events = std::bitset<max_events>; |
33 | |
34 | Events events; |
35 | |
36 | void add(UInt8 event) |
37 | { |
38 | events.set(event); |
39 | } |
40 | |
41 | void merge(const AggregateFunctionRetentionData & other) |
42 | { |
43 | events |= other.events; |
44 | } |
45 | |
46 | void serialize(WriteBuffer & buf) const |
47 | { |
48 | UInt32 event_value = events.to_ulong(); |
49 | writeBinary(event_value, buf); |
50 | } |
51 | |
52 | void deserialize(ReadBuffer & buf) |
53 | { |
54 | UInt32 event_value; |
55 | readBinary(event_value, buf); |
56 | events = event_value; |
57 | } |
58 | }; |
59 | |
60 | /** |
61 | * The max size of events is 32, that's enough for retention analytics |
62 | * |
63 | * Usage: |
64 | * - retention(cond1, cond2, cond3, ....) |
65 | * - returns [cond1_flag, cond1_flag && cond2_flag, cond1_flag && cond3_flag, ...] |
66 | */ |
67 | class AggregateFunctionRetention final |
68 | : public IAggregateFunctionDataHelper<AggregateFunctionRetentionData, AggregateFunctionRetention> |
69 | { |
70 | private: |
71 | UInt8 events_size; |
72 | |
73 | public: |
74 | String getName() const override |
75 | { |
76 | return "retention" ; |
77 | } |
78 | |
79 | AggregateFunctionRetention(const DataTypes & arguments) |
80 | : IAggregateFunctionDataHelper<AggregateFunctionRetentionData, AggregateFunctionRetention>(arguments, {}) |
81 | { |
82 | for (const auto i : ext::range(0, arguments.size())) |
83 | { |
84 | auto cond_arg = arguments[i].get(); |
85 | if (!isUInt8(cond_arg)) |
86 | throw Exception{"Illegal type " + cond_arg->getName() + " of argument " + toString(i) + " of aggregate function " |
87 | + getName() + ", must be UInt8" , |
88 | ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; |
89 | } |
90 | |
91 | events_size = static_cast<UInt8>(arguments.size()); |
92 | } |
93 | |
94 | |
95 | DataTypePtr getReturnType() const override |
96 | { |
97 | return std::make_shared<DataTypeArray>(std::make_shared<DataTypeUInt8>()); |
98 | } |
99 | |
100 | void add(AggregateDataPtr place, const IColumn ** columns, const size_t row_num, Arena *) const override |
101 | { |
102 | for (const auto i : ext::range(0, events_size)) |
103 | { |
104 | auto event = assert_cast<const ColumnVector<UInt8> *>(columns[i])->getData()[row_num]; |
105 | if (event) |
106 | { |
107 | this->data(place).add(i); |
108 | } |
109 | } |
110 | } |
111 | |
112 | void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override |
113 | { |
114 | this->data(place).merge(this->data(rhs)); |
115 | } |
116 | |
117 | void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override |
118 | { |
119 | this->data(place).serialize(buf); |
120 | } |
121 | |
122 | void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override |
123 | { |
124 | this->data(place).deserialize(buf); |
125 | } |
126 | |
127 | void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override |
128 | { |
129 | auto & data_to = assert_cast<ColumnUInt8 &>(assert_cast<ColumnArray &>(to).getData()).getData(); |
130 | auto & offsets_to = assert_cast<ColumnArray &>(to).getOffsets(); |
131 | |
132 | ColumnArray::Offset current_offset = data_to.size(); |
133 | data_to.resize(current_offset + events_size); |
134 | |
135 | const bool first_flag = this->data(place).events.test(0); |
136 | data_to[current_offset] = first_flag; |
137 | ++current_offset; |
138 | |
139 | for (size_t i = 1; i < events_size; ++i) |
140 | { |
141 | data_to[current_offset] = (first_flag && this->data(place).events.test(i)); |
142 | ++current_offset; |
143 | } |
144 | |
145 | offsets_to.push_back(current_offset); |
146 | } |
147 | }; |
148 | |
149 | } |
150 | |