1 | #pragma once |
2 | |
3 | #include <AggregateFunctions/IAggregateFunction.h> |
4 | #include <Columns/ColumnArray.h> |
5 | #include <Columns/ColumnsNumber.h> |
6 | #include <DataTypes/DataTypeArray.h> |
7 | #include <DataTypes/DataTypesNumber.h> |
8 | #include <IO/ReadHelpers.h> |
9 | #include <IO/WriteHelpers.h> |
10 | |
11 | #include <ext/range.h> |
12 | |
13 | |
14 | namespace DB |
15 | { |
16 | |
17 | template <typename T = UInt64> |
18 | class AggregateFunctionCategoricalIV final : public IAggregateFunctionHelper<AggregateFunctionCategoricalIV<T>> |
19 | { |
20 | private: |
21 | size_t category_count; |
22 | |
23 | public: |
24 | AggregateFunctionCategoricalIV(const DataTypes & arguments_, const Array & params_) : |
25 | IAggregateFunctionHelper<AggregateFunctionCategoricalIV<T>> {arguments_, params_}, |
26 | category_count {arguments_.size() - 1} |
27 | { |
28 | // notice: argument types has been checked before |
29 | } |
30 | |
31 | String getName() const override |
32 | { |
33 | return "categoricalInformationValue" ; |
34 | } |
35 | |
36 | void create(AggregateDataPtr place) const override |
37 | { |
38 | memset(place, 0, sizeOfData()); |
39 | } |
40 | |
41 | void destroy(AggregateDataPtr) const noexcept override |
42 | { |
43 | // nothing |
44 | } |
45 | |
46 | bool hasTrivialDestructor() const override |
47 | { |
48 | return true; |
49 | } |
50 | |
51 | size_t sizeOfData() const override |
52 | { |
53 | return sizeof(T) * (category_count + 1) * 2; |
54 | } |
55 | |
56 | size_t alignOfData() const override |
57 | { |
58 | return alignof(T); |
59 | } |
60 | |
61 | void add( |
62 | AggregateDataPtr place, |
63 | const IColumn ** columns, |
64 | size_t row_num, |
65 | Arena * |
66 | ) const override |
67 | { |
68 | auto y_col = static_cast<const ColumnUInt8 *>(columns[category_count]); |
69 | bool y = y_col->getData()[row_num]; |
70 | |
71 | for (size_t i : ext::range(0, category_count)) |
72 | { |
73 | auto x_col = static_cast<const ColumnUInt8 *>(columns[i]); |
74 | bool x = x_col->getData()[row_num]; |
75 | |
76 | if (x) |
77 | reinterpret_cast<T *>(place)[i * 2 + size_t(y)] += 1; |
78 | } |
79 | |
80 | reinterpret_cast<T *>(place)[category_count * 2 + size_t(y)] += 1; |
81 | } |
82 | |
83 | void merge( |
84 | AggregateDataPtr place, |
85 | ConstAggregateDataPtr rhs, |
86 | Arena * |
87 | ) const override |
88 | { |
89 | for (size_t i : ext::range(0, category_count + 1)) |
90 | { |
91 | reinterpret_cast<T *>(place)[i * 2] += reinterpret_cast<const T *>(rhs)[i * 2]; |
92 | reinterpret_cast<T *>(place)[i * 2 + 1] += reinterpret_cast<const T *>(rhs)[i * 2 + 1]; |
93 | } |
94 | } |
95 | |
96 | void serialize( |
97 | ConstAggregateDataPtr place, |
98 | WriteBuffer & buf |
99 | ) const override |
100 | { |
101 | buf.write(place, sizeOfData()); |
102 | } |
103 | |
104 | void deserialize( |
105 | AggregateDataPtr place, |
106 | ReadBuffer & buf, |
107 | Arena * |
108 | ) const override |
109 | { |
110 | buf.read(place, sizeOfData()); |
111 | } |
112 | |
113 | DataTypePtr getReturnType() const override |
114 | { |
115 | return std::make_shared<DataTypeArray>( |
116 | std::make_shared<DataTypeNumber<Float64>>() |
117 | ); |
118 | } |
119 | |
120 | void insertResultInto( |
121 | ConstAggregateDataPtr place, |
122 | IColumn & to |
123 | ) const override |
124 | { |
125 | auto & col = static_cast<ColumnArray &>(to); |
126 | auto & data_col = static_cast<ColumnFloat64 &>(col.getData()); |
127 | auto & offset_col = static_cast<ColumnArray::ColumnOffsets &>( |
128 | col.getOffsetsColumn() |
129 | ); |
130 | |
131 | data_col.reserve(data_col.size() + category_count); |
132 | |
133 | T sum_no = reinterpret_cast<const T *>(place)[category_count * 2]; |
134 | T sum_yes = reinterpret_cast<const T *>(place)[category_count * 2 + 1]; |
135 | |
136 | Float64 rev_no = 1. / sum_no; |
137 | Float64 rev_yes = 1. / sum_yes; |
138 | |
139 | for (size_t i : ext::range(0, category_count)) |
140 | { |
141 | T no = reinterpret_cast<const T *>(place)[i * 2]; |
142 | T yes = reinterpret_cast<const T *>(place)[i * 2 + 1]; |
143 | |
144 | data_col.insertValue((no * rev_no - yes * rev_yes) * (log(no * rev_no) - log(yes * rev_yes))); |
145 | } |
146 | |
147 | offset_col.insertValue(data_col.size()); |
148 | } |
149 | }; |
150 | |
151 | } |
152 | |