1#pragma once
2
3#include <Columns/ColumnConst.h>
4#include <Columns/ColumnsNumber.h>
5#include <DataTypes/DataTypesNumber.h>
6#include <Functions/FunctionHelpers.h>
7#include <Functions/IFunctionImpl.h>
8#include <Common/typeid_cast.h>
9#include <common/likely.h>
10
11
12namespace DB
13{
namespace ErrorCodes
{
    /// NOTE(review): LOGICAL_ERROR and NUMBER_OF_ARGUMENTS_DOESNT_MATCH are not referenced in this header;
    /// they are kept in case files including it rely on these declarations — confirm before removing.
    extern const int LOGICAL_ERROR;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
    extern const int ILLEGAL_COLUMN;
    /// Used by FunctionConsistentHashImpl for argument type errors; declared here so the header
    /// does not depend on a transitive declaration.
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int BAD_ARGUMENTS;
}
21
22
23template <typename Impl>
24class FunctionConsistentHashImpl : public IFunction
25{
26public:
27 static constexpr auto name = Impl::name;
28
29 static FunctionPtr create(const Context &)
30 {
31 return std::make_shared<FunctionConsistentHashImpl<Impl>>();
32 }
33
34 String getName() const override
35 {
36 return name;
37 }
38
39 size_t getNumberOfArguments() const override
40 {
41 return 2;
42 }
43
44 DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
45 {
46 if (!isInteger(arguments[0]))
47 throw Exception("Illegal type " + arguments[0]->getName() + " of the first argument of function " + getName(),
48 ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
49
50 if (arguments[0]->getSizeOfValueInMemory() > sizeof(HashType))
51 throw Exception("Function " + getName() + " accepts " + std::to_string(sizeof(HashType) * 8) + "-bit integers at most"
52 + ", got " + arguments[0]->getName(),
53 ErrorCodes::BAD_ARGUMENTS);
54
55 if (!isInteger(arguments[1]))
56 throw Exception("Illegal type " + arguments[1]->getName() + " of the second argument of function " + getName(),
57 ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
58
59 return std::make_shared<DataTypeNumber<ResultType>>();
60 }
61
62 bool useDefaultImplementationForConstants() const override
63 {
64 return true;
65 }
66 ColumnNumbers getArgumentsThatAreAlwaysConstant() const override
67 {
68 return {1};
69 }
70
71 void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override
72 {
73 if (isColumnConst(*block.getByPosition(arguments[1]).column))
74 executeConstBuckets(block, arguments, result);
75 else
76 throw Exception(
77 "The second argument of function " + getName() + " (number of buckets) must be constant", ErrorCodes::BAD_ARGUMENTS);
78 }
79
80private:
81 using HashType = typename Impl::HashType;
82 using ResultType = typename Impl::ResultType;
83 using BucketsType = typename Impl::BucketsType;
84
85 template <typename T>
86 inline BucketsType checkBucketsRange(T buckets)
87 {
88 if (unlikely(buckets <= 0))
89 throw Exception(
90 "The second argument of function " + getName() + " (number of buckets) must be positive number", ErrorCodes::BAD_ARGUMENTS);
91
92 if (unlikely(static_cast<UInt64>(buckets) > Impl::max_buckets))
93 throw Exception("The value of the second argument of function " + getName() + " (number of buckets) must not be greater than "
94 + std::to_string(Impl::max_buckets), ErrorCodes::BAD_ARGUMENTS);
95
96 return static_cast<BucketsType>(buckets);
97 }
98
99 void executeConstBuckets(Block & block, const ColumnNumbers & arguments, size_t result)
100 {
101 Field buckets_field = (*block.getByPosition(arguments[1]).column)[0];
102 BucketsType num_buckets;
103
104 if (buckets_field.getType() == Field::Types::Int64)
105 num_buckets = checkBucketsRange(buckets_field.get<Int64>());
106 else if (buckets_field.getType() == Field::Types::UInt64)
107 num_buckets = checkBucketsRange(buckets_field.get<UInt64>());
108 else
109 throw Exception("Illegal type " + String(buckets_field.getTypeName()) + " of the second argument of function " + getName(),
110 ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
111
112 const auto & hash_col = block.getByPosition(arguments[0]).column;
113 const IDataType * hash_type = block.getByPosition(arguments[0]).type.get();
114 auto res_col = ColumnVector<ResultType>::create();
115
116 WhichDataType which(hash_type);
117
118 if (which.isUInt8())
119 executeType<UInt8>(hash_col, num_buckets, res_col.get());
120 else if (which.isUInt16())
121 executeType<UInt16>(hash_col, num_buckets, res_col.get());
122 else if (which.isUInt32())
123 executeType<UInt32>(hash_col, num_buckets, res_col.get());
124 else if (which.isUInt64())
125 executeType<UInt64>(hash_col, num_buckets, res_col.get());
126 else if (which.isInt8())
127 executeType<Int8>(hash_col, num_buckets, res_col.get());
128 else if (which.isInt16())
129 executeType<Int16>(hash_col, num_buckets, res_col.get());
130 else if (which.isInt32())
131 executeType<Int32>(hash_col, num_buckets, res_col.get());
132 else if (which.isInt64())
133 executeType<Int64>(hash_col, num_buckets, res_col.get());
134 else
135 throw Exception("Illegal type " + hash_type->getName() + " of the first argument of function " + getName(),
136 ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
137
138 block.getByPosition(result).column = std::move(res_col);
139 }
140
141 template <typename CurrentHashType>
142 void executeType(const ColumnPtr & col_hash_ptr, BucketsType num_buckets, ColumnVector<ResultType> * col_result)
143 {
144 auto col_hash = checkAndGetColumn<ColumnVector<CurrentHashType>>(col_hash_ptr.get());
145 if (!col_hash)
146 throw Exception("Illegal type of the first argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
147
148 auto & vec_result = col_result->getData();
149 const auto & vec_hash = col_hash->getData();
150
151 size_t size = vec_hash.size();
152 vec_result.resize(size);
153 for (size_t i = 0; i < size; ++i)
154 vec_result[i] = Impl::apply(static_cast<HashType>(vec_hash[i]), num_buckets);
155 }
156};
157
158}
159