1#pragma once
2#include <Interpreters/IExternalLoadable.h>
3#include <Columns/IColumn.h>
4#include <Columns/ColumnsNumber.h>
5
6
7namespace DB
8{
9
/// Table of CatBoost wrapper entry points. Only forward-declared here;
/// the definition lives outside this header.
struct CatBoostWrapperAPI;

/// Abstract supplier of a CatBoostWrapperAPI table.
/// The destructor is virtual so implementations may be destroyed through a pointer to this base.
class CatBoostWrapperAPIProvider
{
public:
    /// Returns the wrapper API table owned by this provider.
    /// NOTE(review): CatBoostModel keeps both the provider and a raw pointer to this table,
    /// which suggests the reference is only valid while the provider is alive — confirm in the .cpp.
    virtual const CatBoostWrapperAPI & getAPI() const = 0;

    virtual ~CatBoostWrapperAPIProvider() = default;
};
18
19/// CatBoost model interface.
20class ICatBoostModel
21{
22public:
23 virtual ~ICatBoostModel() = default;
24 /// Evaluate model. Use first `float_features_count` columns as float features,
25 /// the others `cat_features_count` as categorical features.
26 virtual ColumnPtr evaluate(const ColumnRawPtrs & columns) const = 0;
27
28 virtual size_t getFloatFeaturesCount() const = 0;
29 virtual size_t getCatFeaturesCount() const = 0;
30 virtual size_t getTreeCount() const = 0;
31};
32
/// A forward declaration suffices: the alias below only names a pointer to IDataType.
class IDataType;
/// Shared-ownership handle to a const IDataType.
using DataTypePtr = std::shared_ptr<IDataType const>;
35
36/// General ML model evaluator interface.
37class IModel : public IExternalLoadable
38{
39public:
40 virtual ColumnPtr evaluate(const ColumnRawPtrs & columns) const = 0;
41 virtual std::string getTypeName() const = 0;
42 virtual DataTypePtr getReturnType() const = 0;
43};
44
45class CatBoostModel : public IModel
46{
47public:
48 CatBoostModel(std::string name, std::string model_path,
49 std::string lib_path, const ExternalLoadableLifetime & lifetime);
50
51 ColumnPtr evaluate(const ColumnRawPtrs & columns) const override;
52 std::string getTypeName() const override { return "catboost"; }
53
54 size_t getFloatFeaturesCount() const;
55 size_t getCatFeaturesCount() const;
56 size_t getTreeCount() const;
57 DataTypePtr getReturnType() const override;
58
59 /// IExternalLoadable interface.
60
61 const ExternalLoadableLifetime & getLifetime() const override;
62
63 const std::string & getLoadableName() const override { return name; }
64
65 bool supportUpdates() const override { return true; }
66
67 bool isModified() const override;
68
69 std::shared_ptr<const IExternalLoadable> clone() const override;
70
71private:
72 const std::string name;
73 std::string model_path;
74 std::string lib_path;
75 ExternalLoadableLifetime lifetime;
76 std::shared_ptr<CatBoostWrapperAPIProvider> api_provider;
77 const CatBoostWrapperAPI * api;
78
79 std::unique_ptr<ICatBoostModel> model;
80
81 size_t float_features_count;
82 size_t cat_features_count;
83 size_t tree_count;
84
85 void init();
86};
87
88}
89