| 1 | #pragma once |
| 2 | #include <Interpreters/IExternalLoadable.h> |
| 3 | #include <Columns/IColumn.h> |
| 4 | #include <Columns/ColumnsNumber.h> |
| 5 | |
| 6 | |
| 7 | namespace DB |
| 8 | { |
| 9 | |
/// Table of CatBoost wrapper interface functions (forward-declared here, defined elsewhere).
struct CatBoostWrapperAPI;

/// Abstract source of the CatBoost wrapper API table.
///
/// The destructor is virtual so that concrete providers can be destroyed through a
/// base-class pointer (e.g. a std::shared_ptr<CatBoostWrapperAPIProvider>).
class CatBoostWrapperAPIProvider
{
public:
    virtual ~CatBoostWrapperAPIProvider() = default;

    /// @return reference to the wrapper API table.
    /// NOTE(review): presumably valid for as long as the provider lives — confirm in implementations.
    virtual CatBoostWrapperAPI const & getAPI() const = 0;
};
| 18 | |
| 19 | /// CatBoost model interface. |
| 20 | class ICatBoostModel |
| 21 | { |
| 22 | public: |
| 23 | virtual ~ICatBoostModel() = default; |
| 24 | /// Evaluate model. Use first `float_features_count` columns as float features, |
| 25 | /// the others `cat_features_count` as categorical features. |
| 26 | virtual ColumnPtr evaluate(const ColumnRawPtrs & columns) const = 0; |
| 27 | |
| 28 | virtual size_t getFloatFeaturesCount() const = 0; |
| 29 | virtual size_t getCatFeaturesCount() const = 0; |
| 30 | virtual size_t getTreeCount() const = 0; |
| 31 | }; |
| 32 | |
/// Forward declaration: only a pointer to IDataType is needed in this header.
class IDataType;
/// Shared pointer to a const IDataType.
using DataTypePtr = std::shared_ptr<IDataType const>;
| 35 | |
| 36 | /// General ML model evaluator interface. |
| 37 | class IModel : public IExternalLoadable |
| 38 | { |
| 39 | public: |
| 40 | virtual ColumnPtr evaluate(const ColumnRawPtrs & columns) const = 0; |
| 41 | virtual std::string getTypeName() const = 0; |
| 42 | virtual DataTypePtr getReturnType() const = 0; |
| 43 | }; |
| 44 | |
| 45 | class CatBoostModel : public IModel |
| 46 | { |
| 47 | public: |
| 48 | CatBoostModel(std::string name, std::string model_path, |
| 49 | std::string lib_path, const ExternalLoadableLifetime & lifetime); |
| 50 | |
| 51 | ColumnPtr evaluate(const ColumnRawPtrs & columns) const override; |
| 52 | std::string getTypeName() const override { return "catboost" ; } |
| 53 | |
| 54 | size_t getFloatFeaturesCount() const; |
| 55 | size_t getCatFeaturesCount() const; |
| 56 | size_t getTreeCount() const; |
| 57 | DataTypePtr getReturnType() const override; |
| 58 | |
| 59 | /// IExternalLoadable interface. |
| 60 | |
| 61 | const ExternalLoadableLifetime & getLifetime() const override; |
| 62 | |
| 63 | const std::string & getLoadableName() const override { return name; } |
| 64 | |
| 65 | bool supportUpdates() const override { return true; } |
| 66 | |
| 67 | bool isModified() const override; |
| 68 | |
| 69 | std::shared_ptr<const IExternalLoadable> clone() const override; |
| 70 | |
| 71 | private: |
| 72 | const std::string name; |
| 73 | std::string model_path; |
| 74 | std::string lib_path; |
| 75 | ExternalLoadableLifetime lifetime; |
| 76 | std::shared_ptr<CatBoostWrapperAPIProvider> api_provider; |
| 77 | const CatBoostWrapperAPI * api; |
| 78 | |
| 79 | std::unique_ptr<ICatBoostModel> model; |
| 80 | |
| 81 | size_t float_features_count; |
| 82 | size_t cat_features_count; |
| 83 | size_t tree_count; |
| 84 | |
| 85 | void init(); |
| 86 | }; |
| 87 | |
| 88 | } |
| 89 | |