1#pragma once
2
3#include <common/logger_useful.h>
4
5#include <DataTypes/DataTypesNumber.h>
6#include <Columns/ColumnsNumber.h>
7
8#include <IO/ReadHelpers.h>
9#include <IO/WriteHelpers.h>
10
11#include <Common/ArenaAllocator.h>
12#include <Common/NaNUtils.h>
13#include <Common/assert_cast.h>
14
15#include <AggregateFunctions/IAggregateFunction.h>
16
/// Upper bound (2^24 - 1) on the number of serialized (point, weight) entries that deserialize() accepts
/// before allocating; larger counts are rejected with TOO_LARGE_ARRAY_SIZE.
#define AGGREGATE_FUNCTION_MAX_INTERSECTIONS_MAX_ARRAY_SIZE ((1 << 24) - 1)
18
19
20namespace DB
21{
22
namespace ErrorCodes
{
    /// Thrown by deserialize() when the stored entry count exceeds the allowed maximum.
    extern const int TOO_LARGE_ARRAY_SIZE;
    /// Thrown by the constructor when the argument types are unsuitable.
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
28
29
/** maxIntersections(start, end): returns the maximum number of intervals (each defined by a pair of start and end values)
  * that overlap at a single point.
  * maxIntersectionsPosition(start, end): returns the leftmost point at which that maximum overlap is reached.
  */
33
34/// Similar to GroupArrayNumericData.
35template <typename T>
36struct MaxIntersectionsData
37{
38 /// Left or right end of the interval and signed weight; with positive sign for begin of interval and negative sign for end of interval.
39 using Value = std::pair<T, Int64>;
40
41 // Switch to ordinary Allocator after 4096 bytes to avoid fragmentation and trash in Arena
42 using Allocator = MixedAlignedArenaAllocator<alignof(Value), 4096>;
43 using Array = PODArray<Value, 32, Allocator>;
44
45 Array value;
46};
47
/// Which result an AggregateFunctionIntersectionsMax instance produces:
/// Count    - the maximum number of simultaneously overlapping intervals (maxIntersections);
/// Position - the leftmost point where that maximum is reached (maxIntersectionsPosition).
enum class AggregateFunctionIntersectionsKind
{
    Count = 0,
    Position = 1
};
53
54template <typename PointType>
55class AggregateFunctionIntersectionsMax final
56 : public IAggregateFunctionDataHelper<MaxIntersectionsData<PointType>, AggregateFunctionIntersectionsMax<PointType>>
57{
58private:
59 AggregateFunctionIntersectionsKind kind;
60
61public:
62 AggregateFunctionIntersectionsMax(AggregateFunctionIntersectionsKind kind_, const DataTypes & arguments)
63 : IAggregateFunctionDataHelper<MaxIntersectionsData<PointType>, AggregateFunctionIntersectionsMax<PointType>>(arguments, {}), kind(kind_)
64 {
65 if (!isNativeNumber(arguments[0]))
66 throw Exception{getName() + ": first argument must be represented by integer", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
67
68 if (!isNativeNumber(arguments[1]))
69 throw Exception{getName() + ": second argument must be represented by integer", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
70
71 if (!arguments[0]->equals(*arguments[1]))
72 throw Exception{getName() + ": arguments must have the same type", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
73 }
74
75 String getName() const override
76 {
77 return kind == AggregateFunctionIntersectionsKind::Count
78 ? "maxIntersections"
79 : "maxIntersectionsPosition";
80 }
81
82 DataTypePtr getReturnType() const override
83 {
84 if (kind == AggregateFunctionIntersectionsKind::Count)
85 return std::make_shared<DataTypeUInt64>();
86 else
87 return std::make_shared<DataTypeNumber<PointType>>();
88 }
89
90 void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override
91 {
92 PointType left = assert_cast<const ColumnVector<PointType> &>(*columns[0]).getData()[row_num];
93 PointType right = assert_cast<const ColumnVector<PointType> &>(*columns[1]).getData()[row_num];
94
95 if (!isNaN(left))
96 this->data(place).value.push_back(std::make_pair(left, Int64(1)), arena);
97
98 if (!isNaN(right))
99 this->data(place).value.push_back(std::make_pair(right, Int64(-1)), arena);
100 }
101
102 void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena * arena) const override
103 {
104 auto & cur_elems = this->data(place);
105 auto & rhs_elems = this->data(rhs);
106
107 cur_elems.value.insert(rhs_elems.value.begin(), rhs_elems.value.end(), arena);
108 }
109
110 void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
111 {
112 const auto & value = this->data(place).value;
113 size_t size = value.size();
114 writeVarUInt(size, buf);
115 buf.write(reinterpret_cast<const char *>(value.data()), size * sizeof(value[0]));
116 }
117
118 void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena * arena) const override
119 {
120 size_t size = 0;
121 readVarUInt(size, buf);
122
123 if (unlikely(size > AGGREGATE_FUNCTION_MAX_INTERSECTIONS_MAX_ARRAY_SIZE))
124 throw Exception("Too large array size", ErrorCodes::TOO_LARGE_ARRAY_SIZE);
125
126 auto & value = this->data(place).value;
127
128 value.resize(size, arena);
129 buf.read(reinterpret_cast<char *>(value.data()), size * sizeof(value[0]));
130 }
131
132 void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
133 {
134 Int64 current_intersections = 0;
135 Int64 max_intersections = 0;
136 PointType position_of_max_intersections = 0;
137
138 /// const_cast because we will sort the array
139 auto & array = const_cast<typename MaxIntersectionsData<PointType>::Array &>(this->data(place).value);
140
141 /// Sort by position; for equal position, sort by weight to get deterministic result.
142 std::sort(array.begin(), array.end());
143
144 for (const auto & point_weight : array)
145 {
146 current_intersections += point_weight.second;
147 if (current_intersections > max_intersections)
148 {
149 max_intersections = current_intersections;
150 position_of_max_intersections = point_weight.first;
151 }
152 }
153
154 if (kind == AggregateFunctionIntersectionsKind::Count)
155 {
156 auto & result_column = assert_cast<ColumnUInt64 &>(to).getData();
157 result_column.push_back(max_intersections);
158 }
159 else
160 {
161 auto & result_column = assert_cast<ColumnVector<PointType> &>(to).getData();
162 result_column.push_back(position_of_max_intersections);
163 }
164 }
165};
166
167}
168