1#pragma once
2
3#include <iostream>
4#include <sstream>
5#include <unordered_set>
6#include <Columns/ColumnsNumber.h>
7#include <DataTypes/DataTypeDateTime.h>
8#include <DataTypes/DataTypesNumber.h>
9#include <IO/ReadHelpers.h>
10#include <IO/WriteHelpers.h>
11#include <Common/ArenaAllocator.h>
12#include <Common/assert_cast.h>
13
14#include <AggregateFunctions/IAggregateFunction.h>
15
16namespace DB
17{
18
/// Error codes referenced by this header. They are only declared here;
/// the definitions live elsewhere in the project.
namespace ErrorCodes
{
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    /// Thrown by the AggregateFunctionWindowFunnel constructor for an unknown option.
    /// Declared explicitly so the header does not depend on another include providing it.
    extern const int BAD_ARGUMENTS;
}
24
/// Ordering predicate for pairs that looks at the `first` member only.
/// Two pairs with equal `first` compare as equivalent regardless of `second`,
/// so a stable sort using this predicate keeps their relative order.
struct ComparePairFirst final
{
    template <typename T1, typename T2>
    bool operator()(const std::pair<T1, T2> & lhs, const std::pair<T1, T2> & rhs) const
    {
        const T1 & left_key = lhs.first;
        const T1 & right_key = rhs.first;
        return left_key < right_key;
    }
};
33
/// Upper limit on the number of funnel events (32).
/// NOTE(review): not referenced in the visible part of this file; presumably
/// checked where the function is created - confirm.
static constexpr int max_events = 32;
35template <typename T>
36struct AggregateFunctionWindowFunnelData
37{
38 using TimestampEvent = std::pair<T, UInt8>;
39 using TimestampEvents = PODArray<TimestampEvent, 64>;
40 using Comparator = ComparePairFirst;
41
42 bool sorted = true;
43 TimestampEvents events_list;
44
45 size_t size() const
46 {
47 return events_list.size();
48 }
49
50 void add(T timestamp, UInt8 event)
51 {
52 // Since most events should have already been sorted by timestamp.
53 if (sorted && events_list.size() > 0 && events_list.back().first > timestamp)
54 sorted = false;
55 events_list.emplace_back(timestamp, event);
56 }
57
58 void merge(const AggregateFunctionWindowFunnelData & other)
59 {
60 const auto size = events_list.size();
61
62 events_list.insert(std::begin(other.events_list), std::end(other.events_list));
63
64 /// either sort whole container or do so partially merging ranges afterwards
65 if (!sorted && !other.sorted)
66 std::stable_sort(std::begin(events_list), std::end(events_list), Comparator{});
67 else
68 {
69 const auto begin = std::begin(events_list);
70 const auto middle = std::next(begin, size);
71 const auto end = std::end(events_list);
72
73 if (!sorted)
74 std::stable_sort(begin, middle, Comparator{});
75
76 if (!other.sorted)
77 std::stable_sort(middle, end, Comparator{});
78
79 std::inplace_merge(begin, middle, end, Comparator{});
80 }
81
82 sorted = true;
83 }
84
85 void sort()
86 {
87 if (!sorted)
88 {
89 std::stable_sort(std::begin(events_list), std::end(events_list), Comparator{});
90 sorted = true;
91 }
92 }
93
94 void serialize(WriteBuffer & buf) const
95 {
96 writeBinary(sorted, buf);
97 writeBinary(events_list.size(), buf);
98
99 for (const auto & events : events_list)
100 {
101 writeBinary(events.first, buf);
102 writeBinary(events.second, buf);
103 }
104 }
105
106 void deserialize(ReadBuffer & buf)
107 {
108 readBinary(sorted, buf);
109
110 size_t size;
111 readBinary(size, buf);
112
113 /// TODO Protection against huge size
114
115 events_list.clear();
116 events_list.reserve(size);
117
118 T timestamp;
119 UInt8 event;
120
121 for (size_t i = 0; i < size; ++i)
122 {
123 readBinary(timestamp, buf);
124 readBinary(event, buf);
125 events_list.emplace_back(timestamp, event);
126 }
127 }
128};
129
130/** Calculates the max event level in a sliding window.
131 * The max size of events is 32, that's enough for funnel analytics
132 *
133 * Usage:
134 * - windowFunnel(window)(timestamp, cond1, cond2, cond3, ....)
135 */
136template <typename T, typename Data>
137class AggregateFunctionWindowFunnel final
138 : public IAggregateFunctionDataHelper<Data, AggregateFunctionWindowFunnel<T, Data>>
139{
140private:
141 UInt64 window;
142 UInt8 events_size;
143 UInt8 strict;
144
145
146 // Loop through the entire events_list, update the event timestamp value
147 // The level path must be 1---2---3---...---check_events_size, find the max event level that statisfied the path in the sliding window.
148 // If found, returns the max event level, else return 0.
149 // The Algorithm complexity is O(n).
150 UInt8 getEventLevel(const Data & data) const
151 {
152 if (data.size() == 0)
153 return 0;
154 if (events_size == 1)
155 return 1;
156
157 const_cast<Data &>(data).sort();
158
159 /// events_timestamp stores the timestamp that latest i-th level event happen withing time window after previous level event.
160 /// timestamp defaults to -1, which unsigned timestamp value never meet
161 /// there may be some bugs when UInt64 type timstamp overflows Int64, but it works on most cases.
162 std::vector<Int64> events_timestamp(events_size, -1);
163 for (const auto & pair : data.events_list)
164 {
165 const T & timestamp = pair.first;
166 const auto & event_idx = pair.second - 1;
167
168 if (event_idx == 0)
169 events_timestamp[0] = timestamp;
170 else if (strict && events_timestamp[event_idx] >= 0)
171 {
172 return event_idx + 1;
173 }
174 else if (events_timestamp[event_idx - 1] >= 0 && timestamp <= events_timestamp[event_idx - 1] + window)
175 {
176 events_timestamp[event_idx] = events_timestamp[event_idx - 1];
177 if (event_idx + 1 == events_size)
178 return events_size;
179 }
180 }
181 for (size_t event = events_timestamp.size(); event > 0; --event)
182 {
183 if (events_timestamp[event - 1] >= 0)
184 return event;
185 }
186 return 0;
187 }
188
189public:
190 String getName() const override
191 {
192 return "windowFunnel";
193 }
194
195 AggregateFunctionWindowFunnel(const DataTypes & arguments, const Array & params)
196 : IAggregateFunctionDataHelper<Data, AggregateFunctionWindowFunnel<T, Data>>(arguments, params)
197 {
198 events_size = arguments.size() - 1;
199 window = params.at(0).safeGet<UInt64>();
200
201 strict = 0;
202 for (size_t i = 1; i < params.size(); ++i)
203 {
204 String option = params.at(i).safeGet<String>();
205 if (option.compare("strict") == 0)
206 strict = 1;
207 else
208 throw Exception{"Aggregate function " + getName() + " doesn't support a parameter: " + option, ErrorCodes::BAD_ARGUMENTS};
209 }
210 }
211
212 DataTypePtr getReturnType() const override
213 {
214 return std::make_shared<DataTypeUInt8>();
215 }
216
217 void add(AggregateDataPtr place, const IColumn ** columns, const size_t row_num, Arena *) const override
218 {
219 const auto timestamp = assert_cast<const ColumnVector<T> *>(columns[0])->getData()[row_num];
220 // reverse iteration and stable sorting are needed for events that are qualified by more than one condition.
221 for (auto i = events_size; i > 0; --i)
222 {
223 auto event = assert_cast<const ColumnVector<UInt8> *>(columns[i])->getData()[row_num];
224 if (event)
225 this->data(place).add(timestamp, i);
226 }
227 }
228
229 void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
230 {
231 this->data(place).merge(this->data(rhs));
232 }
233
234 void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
235 {
236 this->data(place).serialize(buf);
237 }
238
239 void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override
240 {
241 this->data(place).deserialize(buf);
242 }
243
244 void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
245 {
246 assert_cast<ColumnUInt8 &>(to).getData().push_back(getEventLevel(this->data(place)));
247 }
248};
249
250}
251