1#include "config_functions.h"
2#if USE_H3
3# include <vector>
4# include <Columns/ColumnArray.h>
5# include <Columns/ColumnsNumber.h>
6# include <DataTypes/DataTypeArray.h>
7# include <DataTypes/DataTypesNumber.h>
8# include <DataTypes/IDataType.h>
9# include <Functions/FunctionFactory.h>
10# include <Functions/IFunction.h>
11# include <Common/typeid_cast.h>
12# include <ext/range.h>
13
14# include <h3api.h>
15
16
17namespace DB
18{
19class FunctionH3KRing : public IFunction
20{
21public:
22 static constexpr auto name = "h3kRing";
23
24 static FunctionPtr create(const Context &) { return std::make_shared<FunctionH3KRing>(); }
25
26 std::string getName() const override { return name; }
27
28 size_t getNumberOfArguments() const override { return 2; }
29 bool useDefaultImplementationForConstants() const override { return true; }
30
31 DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
32 {
33 auto arg = arguments[0].get();
34 if (!WhichDataType(arg).isUInt64())
35 throw Exception(
36 "Illegal type " + arg->getName() + " of argument " + std::to_string(1) + " of function " + getName() + ". Must be UInt64",
37 ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
38
39 arg = arguments[1].get();
40 if (!isInteger(arg))
41 throw Exception(
42 "Illegal type " + arg->getName() + " of argument " + std::to_string(2) + " of function " + getName() + ". Must be integer",
43 ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
44
45 return std::make_shared<DataTypeArray>(std::make_shared<DataTypeUInt64>());
46 }
47
48 void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override
49 {
50 const auto col_hindex = block.getByPosition(arguments[0]).column.get();
51 const auto col_k = block.getByPosition(arguments[1]).column.get();
52
53 auto dst = ColumnArray::create(ColumnUInt64::create());
54 auto & dst_data = dst->getData();
55 auto & dst_offsets = dst->getOffsets();
56 dst_offsets.resize(input_rows_count);
57 auto current_offset = 0;
58
59 std::vector<H3Index> hindex_vec;
60
61 for (const auto row : ext::range(0, input_rows_count))
62 {
63 const H3Index origin_hindex = col_hindex->getUInt(row);
64 const int k = col_k->getInt(row);
65
66 const auto vec_size = maxKringSize(k);
67 hindex_vec.resize(vec_size);
68 kRing(origin_hindex, k, hindex_vec.data());
69
70 dst_data.reserve(dst_data.size() + vec_size);
71 for (auto hindex : hindex_vec)
72 {
73 if (hindex != 0)
74 {
75 ++current_offset;
76 dst_data.insert(hindex);
77 }
78 }
79 dst_offsets[row] = current_offset;
80 }
81
82 block.getByPosition(result).column = std::move(dst);
83 }
84};
85
86
87void registerFunctionH3KRing(FunctionFactory & factory)
88{
89 factory.registerFunction<FunctionH3KRing>();
90}
91
92}
93#endif
94