1#include <AggregateFunctions/Helpers.h>
2#include <AggregateFunctions/AggregateFunctionFactory.h>
3#include <AggregateFunctions/AggregateFunctionSequenceMatch.h>
4
5#include <DataTypes/DataTypeDate.h>
6#include <DataTypes/DataTypeDateTime.h>
7
8#include <ext/range.h>
9#include "registerAggregateFunctions.h"
10
11namespace DB
12{
13
/// Error codes thrown by createAggregateFunctionSequenceBase in this file.
/// Every code the file uses is declared here, so the file does not depend on
/// which codes included headers happen to declare.
namespace ErrorCodes
{
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
    extern const int TOO_FEW_ARGUMENTS_FOR_FUNCTION;
    extern const int TOO_MANY_ARGUMENTS_FOR_FUNCTION;
}
18
19namespace
20{
21
22template <template <typename, typename> class AggregateFunction, template <typename> class Data>
23AggregateFunctionPtr createAggregateFunctionSequenceBase(const std::string & name, const DataTypes & argument_types, const Array & params)
24{
25 if (params.size() != 1)
26 throw Exception{"Aggregate function " + name + " requires exactly one parameter.",
27 ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH};
28
29 const auto arg_count = argument_types.size();
30
31 if (arg_count < 3)
32 throw Exception{"Aggregate function " + name + " requires at least 3 arguments.",
33 ErrorCodes::TOO_FEW_ARGUMENTS_FOR_FUNCTION};
34
35 if (arg_count - 1 > max_events)
36 throw Exception{"Aggregate function " + name + " supports up to "
37 + toString(max_events) + " event arguments.",
38 ErrorCodes::TOO_MANY_ARGUMENTS_FOR_FUNCTION};
39
40 const auto time_arg = argument_types.front().get();
41
42 for (const auto i : ext::range(1, arg_count))
43 {
44 const auto cond_arg = argument_types[i].get();
45 if (!isUInt8(cond_arg))
46 throw Exception{"Illegal type " + cond_arg->getName() + " of argument " + toString(i + 1)
47 + " of aggregate function " + name + ", must be UInt8",
48 ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
49 }
50
51 String pattern = params.front().safeGet<std::string>();
52
53 AggregateFunctionPtr res(createWithUnsignedIntegerType<AggregateFunction, Data>(*argument_types[0], argument_types, params, pattern));
54 if (res)
55 return res;
56
57 WhichDataType which(argument_types.front().get());
58 if (which.isDateTime())
59 return std::make_shared<AggregateFunction<DataTypeDateTime::FieldType, Data<DataTypeDateTime::FieldType>>>(argument_types, params, pattern);
60 else if (which.isDate())
61 return std::make_shared<AggregateFunction<DataTypeDate::FieldType, Data<DataTypeDate::FieldType>>>(argument_types, params, pattern);
62
63 throw Exception{"Illegal type " + time_arg->getName() + " of first argument of aggregate function "
64 + name + ", must be DateTime",
65 ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
66}
67
68}
69
70void registerAggregateFunctionsSequenceMatch(AggregateFunctionFactory & factory)
71{
72 factory.registerFunction("sequenceMatch", createAggregateFunctionSequenceBase<AggregateFunctionSequenceMatch, AggregateFunctionSequenceMatchData>);
73 factory.registerFunction("sequenceCount", createAggregateFunctionSequenceBase<AggregateFunctionSequenceCount, AggregateFunctionSequenceMatchData>);
74}
75
76}
77