1 | #pragma once |
2 | #include <Interpreters/IExternalLoadable.h> |
3 | #include <Columns/IColumn.h> |
4 | #include <Columns/ColumnsNumber.h> |
5 | |
6 | |
7 | namespace DB |
8 | { |
9 | |
/// Opaque table of CatBoost wrapper-library entry points; its definition is not visible in this header.
struct CatBoostWrapperAPI;

/// Abstract source of a CatBoostWrapperAPI table. Concrete providers decide how the
/// table is obtained; consumers only reach it through getAPI().
class CatBoostWrapperAPIProvider
{
public:
    /// Virtual so concrete providers can be destroyed through a base-class pointer.
    virtual ~CatBoostWrapperAPIProvider() = default;

    /// Returns the wrapper API table. The caller does not own the referenced object;
    /// presumably it stays valid for as long as the provider lives — confirm in implementations.
    virtual auto getAPI() const -> const CatBoostWrapperAPI & = 0;
};
18 | |
19 | /// CatBoost model interface. |
20 | class ICatBoostModel |
21 | { |
22 | public: |
23 | virtual ~ICatBoostModel() = default; |
24 | /// Evaluate model. Use first `float_features_count` columns as float features, |
25 | /// the others `cat_features_count` as categorical features. |
26 | virtual ColumnPtr evaluate(const ColumnRawPtrs & columns) const = 0; |
27 | |
28 | virtual size_t getFloatFeaturesCount() const = 0; |
29 | virtual size_t getCatFeaturesCount() const = 0; |
30 | virtual size_t getTreeCount() const = 0; |
31 | }; |
32 | |
33 | class IDataType; |
34 | using DataTypePtr = std::shared_ptr<const IDataType>; |
35 | |
36 | /// General ML model evaluator interface. |
37 | class IModel : public IExternalLoadable |
38 | { |
39 | public: |
40 | virtual ColumnPtr evaluate(const ColumnRawPtrs & columns) const = 0; |
41 | virtual std::string getTypeName() const = 0; |
42 | virtual DataTypePtr getReturnType() const = 0; |
43 | }; |
44 | |
45 | class CatBoostModel : public IModel |
46 | { |
47 | public: |
48 | CatBoostModel(std::string name, std::string model_path, |
49 | std::string lib_path, const ExternalLoadableLifetime & lifetime); |
50 | |
51 | ColumnPtr evaluate(const ColumnRawPtrs & columns) const override; |
52 | std::string getTypeName() const override { return "catboost" ; } |
53 | |
54 | size_t getFloatFeaturesCount() const; |
55 | size_t getCatFeaturesCount() const; |
56 | size_t getTreeCount() const; |
57 | DataTypePtr getReturnType() const override; |
58 | |
59 | /// IExternalLoadable interface. |
60 | |
61 | const ExternalLoadableLifetime & getLifetime() const override; |
62 | |
63 | const std::string & getLoadableName() const override { return name; } |
64 | |
65 | bool supportUpdates() const override { return true; } |
66 | |
67 | bool isModified() const override; |
68 | |
69 | std::shared_ptr<const IExternalLoadable> clone() const override; |
70 | |
71 | private: |
72 | const std::string name; |
73 | std::string model_path; |
74 | std::string lib_path; |
75 | ExternalLoadableLifetime lifetime; |
76 | std::shared_ptr<CatBoostWrapperAPIProvider> api_provider; |
77 | const CatBoostWrapperAPI * api; |
78 | |
79 | std::unique_ptr<ICatBoostModel> model; |
80 | |
81 | size_t float_features_count; |
82 | size_t cat_features_count; |
83 | size_t tree_count; |
84 | |
85 | void init(); |
86 | }; |
87 | |
88 | } |
89 | |