1 | #pragma once |
2 | |
3 | #include <Common/HashTable/HashMap.h> |
4 | #include <Common/NaNUtils.h> |
5 | |
6 | #include <AggregateFunctions/IAggregateFunction.h> |
7 | #include <AggregateFunctions/UniqVariadicHash.h> |
8 | #include <DataTypes/DataTypesNumber.h> |
9 | #include <Columns/ColumnVector.h> |
10 | #include <Common/assert_cast.h> |
11 | |
12 | #include <cmath> |
13 | |
14 | |
15 | namespace DB |
16 | { |
17 | |
18 | /** Calculates Shannon Entropy, using HashMap and computing empirical distribution function. |
19 | * Entropy is measured in bits (base-2 logarithm is used). |
20 | */ |
21 | template <typename Value> |
22 | struct EntropyData |
23 | { |
24 | using Weight = UInt64; |
25 | |
26 | using HashingMap = HashMap< |
27 | Value, Weight, |
28 | HashCRC32<Value>, |
29 | HashTableGrower<4>, |
30 | HashTableAllocatorWithStackMemory<sizeof(std::pair<Value, Weight>) * (1 << 3)>>; |
31 | |
32 | /// For the case of pre-hashed values. |
33 | using TrivialMap = HashMap< |
34 | Value, Weight, |
35 | UInt128TrivialHash, |
36 | HashTableGrower<4>, |
37 | HashTableAllocatorWithStackMemory<sizeof(std::pair<Value, Weight>) * (1 << 3)>>; |
38 | |
39 | using Map = std::conditional_t<std::is_same_v<UInt128, Value>, TrivialMap, HashingMap>; |
40 | |
41 | Map map; |
42 | |
43 | void add(const Value & x) |
44 | { |
45 | if (!isNaN(x)) |
46 | ++map[x]; |
47 | } |
48 | |
49 | void add(const Value & x, const Weight & weight) |
50 | { |
51 | if (!isNaN(x)) |
52 | map[x] += weight; |
53 | } |
54 | |
55 | void merge(const EntropyData & rhs) |
56 | { |
57 | for (const auto & pair : rhs.map) |
58 | map[pair.getKey()] += pair.getMapped(); |
59 | } |
60 | |
61 | void serialize(WriteBuffer & buf) const |
62 | { |
63 | map.write(buf); |
64 | } |
65 | |
66 | void deserialize(ReadBuffer & buf) |
67 | { |
68 | typename Map::Reader reader(buf); |
69 | while (reader.next()) |
70 | { |
71 | const auto & pair = reader.get(); |
72 | map[pair.first] = pair.second; |
73 | } |
74 | } |
75 | |
76 | Float64 get() const |
77 | { |
78 | UInt64 total_value = 0; |
79 | for (const auto & pair : map) |
80 | total_value += pair.getMapped(); |
81 | |
82 | Float64 shannon_entropy = 0; |
83 | for (const auto & pair : map) |
84 | { |
85 | Float64 frequency = Float64(pair.getMapped()) / total_value; |
86 | shannon_entropy -= frequency * log2(frequency); |
87 | } |
88 | |
89 | return shannon_entropy; |
90 | } |
91 | }; |
92 | |
93 | |
94 | template <typename Value> |
95 | class AggregateFunctionEntropy final : public IAggregateFunctionDataHelper<EntropyData<Value>, AggregateFunctionEntropy<Value>> |
96 | { |
97 | private: |
98 | size_t num_args; |
99 | |
100 | public: |
101 | AggregateFunctionEntropy(const DataTypes & argument_types_) |
102 | : IAggregateFunctionDataHelper<EntropyData<Value>, AggregateFunctionEntropy<Value>>(argument_types_, {}) |
103 | , num_args(argument_types_.size()) |
104 | { |
105 | } |
106 | |
107 | String getName() const override { return "entropy" ; } |
108 | |
109 | DataTypePtr getReturnType() const override |
110 | { |
111 | return std::make_shared<DataTypeNumber<Float64>>(); |
112 | } |
113 | |
114 | void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override |
115 | { |
116 | if constexpr (!std::is_same_v<UInt128, Value>) |
117 | { |
118 | /// Here we manage only with numerical types |
119 | const auto & column = assert_cast<const ColumnVector <Value> &>(*columns[0]); |
120 | this->data(place).add(column.getData()[row_num]); |
121 | } |
122 | else |
123 | { |
124 | this->data(place).add(UniqVariadicHash<true, false>::apply(num_args, columns, row_num)); |
125 | } |
126 | } |
127 | |
128 | void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override |
129 | { |
130 | this->data(place).merge(this->data(rhs)); |
131 | } |
132 | |
133 | void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override |
134 | { |
135 | this->data(const_cast<AggregateDataPtr>(place)).serialize(buf); |
136 | } |
137 | |
138 | void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override |
139 | { |
140 | this->data(place).deserialize(buf); |
141 | } |
142 | |
143 | void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override |
144 | { |
145 | auto & column = assert_cast<ColumnVector<Float64> &>(to); |
146 | column.getData().push_back(this->data(place).get()); |
147 | } |
148 | }; |
149 | |
150 | } |
151 | |