1#pragma once
2
3#include <AggregateFunctions/IAggregateFunction.h>
4#include <DataTypes/DataTypeDateTime.h>
5#include <DataTypes/DataTypesNumber.h>
6#include <Columns/ColumnsNumber.h>
7#include <Common/assert_cast.h>
8#include <ext/range.h>
9#include <Common/PODArray.h>
10#include <IO/ReadHelpers.h>
11#include <IO/WriteHelpers.h>
12#include <bitset>
13#include <stack>
14
15
16namespace DB
17{
18
/// Declarations only: the error-code values are defined elsewhere in the project.
namespace ErrorCodes
{
    extern const int BAD_ARGUMENTS;
    extern const int LOGICAL_ERROR;
    extern const int SYNTAX_ERROR;
    extern const int TOO_FEW_ARGUMENTS_FOR_FUNCTION;
    extern const int TOO_MANY_ARGUMENTS_FOR_FUNCTION;
    extern const int TOO_SLOW;
}
28
/// Strict-weak-ordering functor for `std::pair`s that looks only at `.first`;
/// `.second` never influences the result. `Comparator` is a one-parameter class
/// template such as `std::less` or `std::greater`.
template <template <typename> class Comparator>
struct ComparePairFirst final
{
    template <typename T1, typename T2>
    bool operator()(const std::pair<T1, T2> & lhs, const std::pair<T1, T2> & rhs) const
    {
        const Comparator<T1> compare_keys{};
        const T1 & left_key = lhs.first;
        const T1 & right_key = rhs.first;
        return compare_keys(left_key, right_key);
    }
};
39
/// Width of the per-row event bitset, i.e. the largest number of event-condition columns a
/// single call can record (presumably enforced where the function is created — TODO confirm).
/// Serialization stores the bitset via `to_ulong()` into a UInt64, so this must stay <= 64.
static constexpr int max_events = 32;
41
42template <typename T>
43struct AggregateFunctionSequenceMatchData final
44{
45 using Timestamp = T;
46 using Events = std::bitset<max_events>;
47 using TimestampEvents = std::pair<Timestamp, Events>;
48 using Comparator = ComparePairFirst<std::less>;
49
50 bool sorted = true;
51 PODArrayWithStackMemory<TimestampEvents, 64> events_list;
52
53 void add(const Timestamp timestamp, const Events & events)
54 {
55 /// store information exclusively for rows with at least one event
56 if (events.any())
57 {
58 events_list.emplace_back(timestamp, events);
59 sorted = false;
60 }
61 }
62
63 void merge(const AggregateFunctionSequenceMatchData & other)
64 {
65 if (other.events_list.empty())
66 return;
67
68 const auto size = events_list.size();
69
70 events_list.insert(std::begin(other.events_list), std::end(other.events_list));
71
72 /// either sort whole container or do so partially merging ranges afterwards
73 if (!sorted && !other.sorted)
74 std::sort(std::begin(events_list), std::end(events_list), Comparator{});
75 else
76 {
77 const auto begin = std::begin(events_list);
78 const auto middle = std::next(begin, size);
79 const auto end = std::end(events_list);
80
81 if (!sorted)
82 std::sort(begin, middle, Comparator{});
83
84 if (!other.sorted)
85 std::sort(middle, end, Comparator{});
86
87 std::inplace_merge(begin, middle, end, Comparator{});
88 }
89
90 sorted = true;
91 }
92
93 void sort()
94 {
95 if (!sorted)
96 {
97 std::sort(std::begin(events_list), std::end(events_list), Comparator{});
98 sorted = true;
99 }
100 }
101
102 void serialize(WriteBuffer & buf) const
103 {
104 writeBinary(sorted, buf);
105 writeBinary(events_list.size(), buf);
106
107 for (const auto & events : events_list)
108 {
109 writeBinary(events.first, buf);
110 writeBinary(events.second.to_ulong(), buf);
111 }
112 }
113
114 void deserialize(ReadBuffer & buf)
115 {
116 readBinary(sorted, buf);
117
118 size_t size;
119 readBinary(size, buf);
120
121 events_list.clear();
122 events_list.reserve(size);
123
124 for (size_t i = 0; i < size; ++i)
125 {
126 Timestamp timestamp;
127 readBinary(timestamp, buf);
128
129 UInt64 events;
130 readBinary(events, buf);
131
132 events_list.emplace_back(timestamp, Events{events});
133 }
134 }
135};
136
137
/// Upper bound on the number of steps `backtrackingMatch` may take for one call. Exceeding it
/// makes the match throw `TOO_SLOW` rather than run unboundedly on a pathological pattern.
constexpr int sequence_match_max_iterations = 1'000'000;
140
141
142template <typename T, typename Data, typename Derived>
143class AggregateFunctionSequenceBase : public IAggregateFunctionDataHelper<Data, Derived>
144{
145public:
146 AggregateFunctionSequenceBase(const DataTypes & arguments, const Array & params, const String & pattern_)
147 : IAggregateFunctionDataHelper<Data, Derived>(arguments, params)
148 , pattern(pattern_)
149 {
150 arg_count = arguments.size();
151 parsePattern();
152 }
153
154 void add(AggregateDataPtr place, const IColumn ** columns, const size_t row_num, Arena *) const override
155 {
156 const auto timestamp = assert_cast<const ColumnVector<T> *>(columns[0])->getData()[row_num];
157
158 typename Data::Events events;
159 for (const auto i : ext::range(1, arg_count))
160 {
161 const auto event = assert_cast<const ColumnUInt8 *>(columns[i])->getData()[row_num];
162 events.set(i - 1, event);
163 }
164
165 this->data(place).add(timestamp, events);
166 }
167
168 void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
169 {
170 this->data(place).merge(this->data(rhs));
171 }
172
173 void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
174 {
175 this->data(place).serialize(buf);
176 }
177
178 void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override
179 {
180 this->data(place).deserialize(buf);
181 }
182
183private:
184 enum class PatternActionType
185 {
186 SpecificEvent,
187 AnyEvent,
188 KleeneStar,
189 TimeLessOrEqual,
190 TimeLess,
191 TimeGreaterOrEqual,
192 TimeGreater
193 };
194
195 struct PatternAction final
196 {
197 PatternActionType type;
198 std::uint64_t extra;
199
200 PatternAction() = default;
201 PatternAction(const PatternActionType type_, const std::uint64_t extra_ = 0) : type{type_}, extra{extra_} {}
202 };
203
204 using PatternActions = PODArrayWithStackMemory<PatternAction, 64>;
205
206 Derived & derived() { return static_cast<Derived &>(*this); }
207
208 void parsePattern()
209 {
210 actions.clear();
211 actions.emplace_back(PatternActionType::KleeneStar);
212
213 dfa_states.clear();
214 dfa_states.emplace_back(true);
215
216 pattern_has_time = false;
217
218 const char * pos = pattern.data();
219 const char * begin = pos;
220 const char * end = pos + pattern.size();
221
222 auto throw_exception = [&](const std::string & msg)
223 {
224 throw Exception{msg + " '" + std::string(pos, end) + "' at position " + toString(pos - begin), ErrorCodes::SYNTAX_ERROR};
225 };
226
227 auto match = [&pos, end](const char * str) mutable
228 {
229 size_t length = strlen(str);
230 if (pos + length <= end && 0 == memcmp(pos, str, length))
231 {
232 pos += length;
233 return true;
234 }
235 return false;
236 };
237
238 while (pos < end)
239 {
240 if (match("(?"))
241 {
242 if (match("t"))
243 {
244 PatternActionType type;
245
246 if (match("<="))
247 type = PatternActionType::TimeLessOrEqual;
248 else if (match("<"))
249 type = PatternActionType::TimeLess;
250 else if (match(">="))
251 type = PatternActionType::TimeGreaterOrEqual;
252 else if (match(">"))
253 type = PatternActionType::TimeGreater;
254 else
255 throw_exception("Unknown time condition");
256
257 UInt64 duration = 0;
258 auto prev_pos = pos;
259 pos = tryReadIntText(duration, pos, end);
260 if (pos == prev_pos)
261 throw_exception("Could not parse number");
262
263 if (actions.back().type != PatternActionType::SpecificEvent &&
264 actions.back().type != PatternActionType::AnyEvent &&
265 actions.back().type != PatternActionType::KleeneStar)
266 throw Exception{"Temporal condition should be preceeded by an event condition", ErrorCodes::BAD_ARGUMENTS};
267
268 pattern_has_time = true;
269 actions.emplace_back(type, duration);
270 }
271 else
272 {
273 UInt64 event_number = 0;
274 auto prev_pos = pos;
275 pos = tryReadIntText(event_number, pos, end);
276 if (pos == prev_pos)
277 throw_exception("Could not parse number");
278
279 if (event_number > arg_count - 1)
280 throw Exception{"Event number " + toString(event_number) + " is out of range", ErrorCodes::BAD_ARGUMENTS};
281
282 actions.emplace_back(PatternActionType::SpecificEvent, event_number - 1);
283 dfa_states.back().transition = DFATransition::SpecificEvent;
284 dfa_states.back().event = event_number - 1;
285 dfa_states.emplace_back();
286 }
287
288 if (!match(")"))
289 throw_exception("Expected closing parenthesis, found");
290
291 }
292 else if (match(".*"))
293 {
294 actions.emplace_back(PatternActionType::KleeneStar);
295 dfa_states.back().has_kleene = true;
296 }
297 else if (match("."))
298 {
299 actions.emplace_back(PatternActionType::AnyEvent);
300 dfa_states.back().transition = DFATransition::AnyEvent;
301 dfa_states.emplace_back();
302 }
303 else
304 throw_exception("Could not parse pattern, unexpected starting symbol");
305 }
306 }
307
308protected:
309 /// Uses a DFA based approach in order to better handle patterns without
310 /// time assertions.
311 ///
312 /// NOTE: This implementation relies on the assumption that the pattern are *small*.
313 ///
314 /// This algorithm performs in O(mn) (with m the number of DFA states and N the number
315 /// of events) with a memory consumption and memory allocations in O(m). It means that
316 /// if n >>> m (which is expected to be the case), this algorithm can be considered linear.
317 template <typename EventEntry>
318 bool dfaMatch(EventEntry & events_it, const EventEntry events_end) const
319 {
320 using ActiveStates = std::vector<bool>;
321
322 /// Those two vectors keep track of which states should be considered for the current
323 /// event as well as the states which should be considered for the next event.
324 ActiveStates active_states(dfa_states.size(), false);
325 ActiveStates next_active_states(dfa_states.size(), false);
326 active_states[0] = true;
327
328 /// Keeps track of dead-ends in order not to iterate over all the events to realize that
329 /// the match failed.
330 size_t n_active = 1;
331
332 for (/* empty */; events_it != events_end && n_active > 0 && !active_states.back(); ++events_it)
333 {
334 n_active = 0;
335 next_active_states.assign(dfa_states.size(), false);
336
337 for (size_t state = 0; state < dfa_states.size(); ++state)
338 {
339 if (!active_states[state])
340 {
341 continue;
342 }
343
344 switch (dfa_states[state].transition)
345 {
346 case DFATransition::None:
347 break;
348 case DFATransition::AnyEvent:
349 next_active_states[state + 1] = true;
350 ++n_active;
351 break;
352 case DFATransition::SpecificEvent:
353 if (events_it->second.test(dfa_states[state].event))
354 {
355 next_active_states[state + 1] = true;
356 ++n_active;
357 }
358 break;
359 }
360
361 if (dfa_states[state].has_kleene)
362 {
363 next_active_states[state] = true;
364 ++n_active;
365 }
366 }
367 swap(active_states, next_active_states);
368 }
369
370 return active_states.back();
371 }
372
373 template <typename EventEntry>
374 bool backtrackingMatch(EventEntry & events_it, const EventEntry events_end) const
375 {
376 const auto action_begin = std::begin(actions);
377 const auto action_end = std::end(actions);
378 auto action_it = action_begin;
379
380 const auto events_begin = events_it;
381 auto base_it = events_it;
382
383 /// an iterator to action plus an iterator to row in events list plus timestamp at the start of sequence
384 using backtrack_info = std::tuple<decltype(action_it), EventEntry, EventEntry>;
385 std::stack<backtrack_info> back_stack;
386
387 /// backtrack if possible
388 const auto do_backtrack = [&]
389 {
390 while (!back_stack.empty())
391 {
392 auto & top = back_stack.top();
393
394 action_it = std::get<0>(top);
395 events_it = std::next(std::get<1>(top));
396 base_it = std::get<2>(top);
397
398 back_stack.pop();
399
400 if (events_it != events_end)
401 return true;
402 }
403
404 return false;
405 };
406
407 size_t i = 0;
408 while (action_it != action_end && events_it != events_end)
409 {
410 if (action_it->type == PatternActionType::SpecificEvent)
411 {
412 if (events_it->second.test(action_it->extra))
413 {
414 /// move to the next action and events
415 base_it = events_it;
416 ++action_it, ++events_it;
417 }
418 else if (!do_backtrack())
419 /// backtracking failed, bail out
420 break;
421 }
422 else if (action_it->type == PatternActionType::AnyEvent)
423 {
424 base_it = events_it;
425 ++action_it, ++events_it;
426 }
427 else if (action_it->type == PatternActionType::KleeneStar)
428 {
429 back_stack.emplace(action_it, events_it, base_it);
430 base_it = events_it;
431 ++action_it;
432 }
433 else if (action_it->type == PatternActionType::TimeLessOrEqual)
434 {
435 if (events_it->first <= base_it->first + action_it->extra)
436 {
437 /// condition satisfied, move onto next action
438 back_stack.emplace(action_it, events_it, base_it);
439 base_it = events_it;
440 ++action_it;
441 }
442 else if (!do_backtrack())
443 break;
444 }
445 else if (action_it->type == PatternActionType::TimeLess)
446 {
447 if (events_it->first < base_it->first + action_it->extra)
448 {
449 back_stack.emplace(action_it, events_it, base_it);
450 base_it = events_it;
451 ++action_it;
452 }
453 else if (!do_backtrack())
454 break;
455 }
456 else if (action_it->type == PatternActionType::TimeGreaterOrEqual)
457 {
458 if (events_it->first >= base_it->first + action_it->extra)
459 {
460 back_stack.emplace(action_it, events_it, base_it);
461 base_it = events_it;
462 ++action_it;
463 }
464 else if (++events_it == events_end && !do_backtrack())
465 break;
466 }
467 else if (action_it->type == PatternActionType::TimeGreater)
468 {
469 if (events_it->first > base_it->first + action_it->extra)
470 {
471 back_stack.emplace(action_it, events_it, base_it);
472 base_it = events_it;
473 ++action_it;
474 }
475 else if (++events_it == events_end && !do_backtrack())
476 break;
477 }
478 else
479 throw Exception{"Unknown PatternActionType", ErrorCodes::LOGICAL_ERROR};
480
481 if (++i > sequence_match_max_iterations)
482 throw Exception{"Pattern application proves too difficult, exceeding max iterations (" + toString(sequence_match_max_iterations) + ")",
483 ErrorCodes::TOO_SLOW};
484 }
485
486 /// if there are some actions remaining
487 if (action_it != action_end)
488 {
489 /// match multiple empty strings at end
490 while (action_it->type == PatternActionType::KleeneStar ||
491 action_it->type == PatternActionType::TimeLessOrEqual ||
492 action_it->type == PatternActionType::TimeLess ||
493 (action_it->type == PatternActionType::TimeGreaterOrEqual && action_it->extra == 0))
494 ++action_it;
495 }
496
497 if (events_it == events_begin)
498 ++events_it;
499
500 return action_it == action_end;
501 }
502
503private:
504 enum class DFATransition : char
505 {
506 /// .-------.
507 /// | |
508 /// `-------'
509 None,
510 /// .-------. (?[0-9])
511 /// | | ----------
512 /// `-------'
513 SpecificEvent,
514 /// .-------. .
515 /// | | ----------
516 /// `-------'
517 AnyEvent,
518 };
519
520 struct DFAState
521 {
522 DFAState(bool has_kleene_ = false)
523 : has_kleene{has_kleene_}, event{0}, transition{DFATransition::None}
524 {}
525
526 /// .-------.
527 /// | | - - -
528 /// `-------'
529 /// |_^
530 bool has_kleene;
531 /// In the case of a state transitions with a `SpecificEvent`,
532 /// `event` contains the value of the event.
533 uint32_t event;
534 /// The kind of transition out of this state.
535 DFATransition transition;
536 };
537
538 using DFAStates = std::vector<DFAState>;
539
540protected:
541 /// `True` if the parsed pattern contains time assertions (?t...), `false` otherwise.
542 bool pattern_has_time;
543
544private:
545 std::string pattern;
546 size_t arg_count;
547 PatternActions actions;
548
549 DFAStates dfa_states;
550};
551
552template <typename T, typename Data>
553class AggregateFunctionSequenceMatch final : public AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceMatch<T, Data>>
554{
555public:
556 AggregateFunctionSequenceMatch(const DataTypes & arguments, const Array & params, const String & pattern_)
557 : AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceMatch<T, Data>>(arguments, params, pattern_) {}
558
559 using AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceMatch<T, Data>>::AggregateFunctionSequenceBase;
560
561 String getName() const override { return "sequenceMatch"; }
562
563 DataTypePtr getReturnType() const override { return std::make_shared<DataTypeUInt8>(); }
564
565 void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
566 {
567 const_cast<Data &>(this->data(place)).sort();
568
569 const auto & data_ref = this->data(place);
570
571 const auto events_begin = std::begin(data_ref.events_list);
572 const auto events_end = std::end(data_ref.events_list);
573 auto events_it = events_begin;
574
575 bool match = this->pattern_has_time ? this->backtrackingMatch(events_it, events_end) : this->dfaMatch(events_it, events_end);
576 assert_cast<ColumnUInt8 &>(to).getData().push_back(match);
577 }
578};
579
580template <typename T, typename Data>
581class AggregateFunctionSequenceCount final : public AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceCount<T, Data>>
582{
583public:
584 AggregateFunctionSequenceCount(const DataTypes & arguments, const Array & params, const String & pattern_)
585 : AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceCount<T, Data>>(arguments, params, pattern_) {}
586
587 using AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceCount<T, Data>>::AggregateFunctionSequenceBase;
588
589 String getName() const override { return "sequenceCount"; }
590
591 DataTypePtr getReturnType() const override { return std::make_shared<DataTypeUInt64>(); }
592
593 void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
594 {
595 const_cast<Data &>(this->data(place)).sort();
596 assert_cast<ColumnUInt64 &>(to).getData().push_back(count(place));
597 }
598
599private:
600 UInt64 count(const ConstAggregateDataPtr & place) const
601 {
602 const auto & data_ref = this->data(place);
603
604 const auto events_begin = std::begin(data_ref.events_list);
605 const auto events_end = std::end(data_ref.events_list);
606 auto events_it = events_begin;
607
608 size_t count = 0;
609 while (events_it != events_end && this->backtrackingMatch(events_it, events_end))
610 ++count;
611
612 return count;
613 }
614};
615
616}
617