1#pragma once
2
3#include <Common/HashTable/HashMap.h>
4#include <Common/NaNUtils.h>
5
6#include <AggregateFunctions/IAggregateFunction.h>
7#include <AggregateFunctions/UniqVariadicHash.h>
8#include <DataTypes/DataTypesNumber.h>
9#include <Columns/ColumnVector.h>
10#include <Common/assert_cast.h>
11
12#include <cmath>
13
14
15namespace DB
16{
17
18/** Calculates Shannon Entropy, using HashMap and computing empirical distribution function.
19 * Entropy is measured in bits (base-2 logarithm is used).
20 */
21template <typename Value>
22struct EntropyData
23{
24 using Weight = UInt64;
25
26 using HashingMap = HashMap<
27 Value, Weight,
28 HashCRC32<Value>,
29 HashTableGrower<4>,
30 HashTableAllocatorWithStackMemory<sizeof(std::pair<Value, Weight>) * (1 << 3)>>;
31
32 /// For the case of pre-hashed values.
33 using TrivialMap = HashMap<
34 Value, Weight,
35 UInt128TrivialHash,
36 HashTableGrower<4>,
37 HashTableAllocatorWithStackMemory<sizeof(std::pair<Value, Weight>) * (1 << 3)>>;
38
39 using Map = std::conditional_t<std::is_same_v<UInt128, Value>, TrivialMap, HashingMap>;
40
41 Map map;
42
43 void add(const Value & x)
44 {
45 if (!isNaN(x))
46 ++map[x];
47 }
48
49 void add(const Value & x, const Weight & weight)
50 {
51 if (!isNaN(x))
52 map[x] += weight;
53 }
54
55 void merge(const EntropyData & rhs)
56 {
57 for (const auto & pair : rhs.map)
58 map[pair.getKey()] += pair.getMapped();
59 }
60
61 void serialize(WriteBuffer & buf) const
62 {
63 map.write(buf);
64 }
65
66 void deserialize(ReadBuffer & buf)
67 {
68 typename Map::Reader reader(buf);
69 while (reader.next())
70 {
71 const auto & pair = reader.get();
72 map[pair.first] = pair.second;
73 }
74 }
75
76 Float64 get() const
77 {
78 UInt64 total_value = 0;
79 for (const auto & pair : map)
80 total_value += pair.getMapped();
81
82 Float64 shannon_entropy = 0;
83 for (const auto & pair : map)
84 {
85 Float64 frequency = Float64(pair.getMapped()) / total_value;
86 shannon_entropy -= frequency * log2(frequency);
87 }
88
89 return shannon_entropy;
90 }
91};
92
93
94template <typename Value>
95class AggregateFunctionEntropy final : public IAggregateFunctionDataHelper<EntropyData<Value>, AggregateFunctionEntropy<Value>>
96{
97private:
98 size_t num_args;
99
100public:
101 AggregateFunctionEntropy(const DataTypes & argument_types_)
102 : IAggregateFunctionDataHelper<EntropyData<Value>, AggregateFunctionEntropy<Value>>(argument_types_, {})
103 , num_args(argument_types_.size())
104 {
105 }
106
107 String getName() const override { return "entropy"; }
108
109 DataTypePtr getReturnType() const override
110 {
111 return std::make_shared<DataTypeNumber<Float64>>();
112 }
113
114 void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
115 {
116 if constexpr (!std::is_same_v<UInt128, Value>)
117 {
118 /// Here we manage only with numerical types
119 const auto & column = assert_cast<const ColumnVector <Value> &>(*columns[0]);
120 this->data(place).add(column.getData()[row_num]);
121 }
122 else
123 {
124 this->data(place).add(UniqVariadicHash<true, false>::apply(num_args, columns, row_num));
125 }
126 }
127
128 void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
129 {
130 this->data(place).merge(this->data(rhs));
131 }
132
133 void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
134 {
135 this->data(const_cast<AggregateDataPtr>(place)).serialize(buf);
136 }
137
138 void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override
139 {
140 this->data(place).deserialize(buf);
141 }
142
143 void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
144 {
145 auto & column = assert_cast<ColumnVector<Float64> &>(to);
146 column.getData().push_back(this->data(place).get());
147 }
148};
149
150}
151