1 | #include <memory> |
2 | #include <random> |
3 | |
4 | #include <DataTypes/DataTypesNumber.h> |
5 | #include <Common/thread_local_rng.h> |
6 | #include <IO/ReadBuffer.h> |
7 | #include <IO/WriteBuffer.h> |
8 | #include <AggregateFunctions/IAggregateFunction.h> |
9 | #include <AggregateFunctions/AggregateFunctionFactory.h> |
10 | |
11 | |
12 | namespace DB |
13 | { |
14 | |
namespace ErrorCodes
{
    /// Error codes thrown by this file; the definitions live elsewhere in the project.
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
    extern const int AGGREGATE_FUNCTION_THROW;
}
20 | |
21 | namespace |
22 | { |
23 | |
/// State of aggThrow. It carries no aggregation data, only a liveness flag:
/// the destructor aborts the process if the same state is destroyed twice.
struct AggregateFunctionThrowData
{
    /// Set on construction, cleared by the first (and only legitimate) destruction.
    bool allocated = true;

    ~AggregateFunctionThrowData()
    {
        /// The flag is accessed through a volatile pointer, presumably so the compiler
        /// cannot optimise away the store in the destructor (or the check of a
        /// destroyed object's memory).
        volatile bool * const alive = &allocated;

        /// A cleared flag means this state was already destroyed once.
        if (!*alive)
            abort();

        *alive = false;
    }
};
39 | |
40 | /** Throw on creation with probability specified in parameter. |
41 | * It will check correct destruction of the state. |
42 | * This is intended to check for exception safety. |
43 | */ |
44 | class AggregateFunctionThrow final : public IAggregateFunctionDataHelper<AggregateFunctionThrowData, AggregateFunctionThrow> |
45 | { |
46 | private: |
47 | Float64 throw_probability; |
48 | |
49 | public: |
50 | AggregateFunctionThrow(const DataTypes & argument_types_, const Array & parameters_, Float64 throw_probability_) |
51 | : IAggregateFunctionDataHelper(argument_types_, parameters_), throw_probability(throw_probability_) {} |
52 | |
53 | String getName() const override |
54 | { |
55 | return "aggThrow" ; |
56 | } |
57 | |
58 | DataTypePtr getReturnType() const override |
59 | { |
60 | return std::make_shared<DataTypeUInt8>(); |
61 | } |
62 | |
63 | void create(AggregateDataPtr place) const override |
64 | { |
65 | if (std::uniform_real_distribution<>(0.0, 1.0)(thread_local_rng) <= throw_probability) |
66 | throw Exception("Aggregate function " + getName() + " has thrown exception successfully" , ErrorCodes::AGGREGATE_FUNCTION_THROW); |
67 | |
68 | new (place) Data; |
69 | } |
70 | |
71 | void destroy(AggregateDataPtr place) const noexcept override |
72 | { |
73 | data(place).~Data(); |
74 | } |
75 | |
76 | void add(AggregateDataPtr, const IColumn **, size_t, Arena *) const override |
77 | { |
78 | } |
79 | |
80 | void merge(AggregateDataPtr, ConstAggregateDataPtr, Arena *) const override |
81 | { |
82 | } |
83 | |
84 | void serialize(ConstAggregateDataPtr, WriteBuffer & buf) const override |
85 | { |
86 | char c = 0; |
87 | buf.write(c); |
88 | } |
89 | |
90 | void deserialize(AggregateDataPtr, ReadBuffer & buf, Arena *) const override |
91 | { |
92 | char c = 0; |
93 | buf.read(c); |
94 | } |
95 | |
96 | void insertResultInto(ConstAggregateDataPtr, IColumn & to) const override |
97 | { |
98 | to.insertDefault(); |
99 | } |
100 | }; |
101 | |
102 | } |
103 | |
104 | void registerAggregateFunctionAggThrow(AggregateFunctionFactory & factory) |
105 | { |
106 | factory.registerFunction("aggThrow" , [](const std::string & name, const DataTypes & argument_types, const Array & parameters) |
107 | { |
108 | Float64 throw_probability = 1.0; |
109 | if (parameters.size() == 1) |
110 | throw_probability = parameters[0].safeGet<Float64>(); |
111 | else if (parameters.size() > 1) |
112 | throw Exception("Aggregate function " + name + " cannot have more than one parameter" , ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); |
113 | |
114 | return std::make_shared<AggregateFunctionThrow>(argument_types, parameters, throw_probability); |
115 | }); |
116 | } |
117 | |
118 | } |
119 | |
120 | |