#pragma once

#include <AggregateFunctions/IAggregateFunction.h>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnsNumber.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypesNumber.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>

#include <ext/range.h>

#include <cmath>
12
13
14namespace DB
15{
16
17template <typename T = UInt64>
18class AggregateFunctionCategoricalIV final : public IAggregateFunctionHelper<AggregateFunctionCategoricalIV<T>>
19{
20private:
21 size_t category_count;
22
23public:
24 AggregateFunctionCategoricalIV(const DataTypes & arguments_, const Array & params_) :
25 IAggregateFunctionHelper<AggregateFunctionCategoricalIV<T>> {arguments_, params_},
26 category_count {arguments_.size() - 1}
27 {
28 // notice: argument types has been checked before
29 }
30
31 String getName() const override
32 {
33 return "categoricalInformationValue";
34 }
35
36 void create(AggregateDataPtr place) const override
37 {
38 memset(place, 0, sizeOfData());
39 }
40
41 void destroy(AggregateDataPtr) const noexcept override
42 {
43 // nothing
44 }
45
46 bool hasTrivialDestructor() const override
47 {
48 return true;
49 }
50
51 size_t sizeOfData() const override
52 {
53 return sizeof(T) * (category_count + 1) * 2;
54 }
55
56 size_t alignOfData() const override
57 {
58 return alignof(T);
59 }
60
61 void add(
62 AggregateDataPtr place,
63 const IColumn ** columns,
64 size_t row_num,
65 Arena *
66 ) const override
67 {
68 auto y_col = static_cast<const ColumnUInt8 *>(columns[category_count]);
69 bool y = y_col->getData()[row_num];
70
71 for (size_t i : ext::range(0, category_count))
72 {
73 auto x_col = static_cast<const ColumnUInt8 *>(columns[i]);
74 bool x = x_col->getData()[row_num];
75
76 if (x)
77 reinterpret_cast<T *>(place)[i * 2 + size_t(y)] += 1;
78 }
79
80 reinterpret_cast<T *>(place)[category_count * 2 + size_t(y)] += 1;
81 }
82
83 void merge(
84 AggregateDataPtr place,
85 ConstAggregateDataPtr rhs,
86 Arena *
87 ) const override
88 {
89 for (size_t i : ext::range(0, category_count + 1))
90 {
91 reinterpret_cast<T *>(place)[i * 2] += reinterpret_cast<const T *>(rhs)[i * 2];
92 reinterpret_cast<T *>(place)[i * 2 + 1] += reinterpret_cast<const T *>(rhs)[i * 2 + 1];
93 }
94 }
95
96 void serialize(
97 ConstAggregateDataPtr place,
98 WriteBuffer & buf
99 ) const override
100 {
101 buf.write(place, sizeOfData());
102 }
103
104 void deserialize(
105 AggregateDataPtr place,
106 ReadBuffer & buf,
107 Arena *
108 ) const override
109 {
110 buf.read(place, sizeOfData());
111 }
112
113 DataTypePtr getReturnType() const override
114 {
115 return std::make_shared<DataTypeArray>(
116 std::make_shared<DataTypeNumber<Float64>>()
117 );
118 }
119
120 void insertResultInto(
121 ConstAggregateDataPtr place,
122 IColumn & to
123 ) const override
124 {
125 auto & col = static_cast<ColumnArray &>(to);
126 auto & data_col = static_cast<ColumnFloat64 &>(col.getData());
127 auto & offset_col = static_cast<ColumnArray::ColumnOffsets &>(
128 col.getOffsetsColumn()
129 );
130
131 data_col.reserve(data_col.size() + category_count);
132
133 T sum_no = reinterpret_cast<const T *>(place)[category_count * 2];
134 T sum_yes = reinterpret_cast<const T *>(place)[category_count * 2 + 1];
135
136 Float64 rev_no = 1. / sum_no;
137 Float64 rev_yes = 1. / sum_yes;
138
139 for (size_t i : ext::range(0, category_count))
140 {
141 T no = reinterpret_cast<const T *>(place)[i * 2];
142 T yes = reinterpret_cast<const T *>(place)[i * 2 + 1];
143
144 data_col.insertValue((no * rev_no - yes * rev_yes) * (log(no * rev_no) - log(yes * rev_yes)));
145 }
146
147 offset_col.insertValue(data_col.size());
148 }
149};
150
151}
152