1#include "CatBoostModel.h"
2
3#include <Common/FieldVisitors.h>
4#include <mutex>
5#include <Columns/ColumnString.h>
6#include <Columns/ColumnFixedString.h>
7#include <Columns/ColumnVector.h>
8#include <Columns/ColumnTuple.h>
9#include <Common/typeid_cast.h>
10#include <IO/WriteBufferFromString.h>
11#include <IO/Operators.h>
12#include <Common/PODArray.h>
13#include <Common/SharedLibrary.h>
14#include <DataTypes/DataTypesNumber.h>
15#include <DataTypes/DataTypeTuple.h>
16
17namespace DB
18{
19
20namespace ErrorCodes
21{
22extern const int LOGICAL_ERROR;
23extern const int BAD_ARGUMENTS;
24extern const int CANNOT_LOAD_CATBOOST_MODEL;
25extern const int CANNOT_APPLY_CATBOOST_MODEL;
26}
27
28
/// CatBoost wrapper interface functions.
/// Table of entry points resolved from the CatBoost wrapper shared library.
/// Every pointer defaults to nullptr so that optional entries (those that may fail to resolve)
/// can be safely tested for presence before use, e.g. `if (api->GetDimensionsCount)`.
struct CatBoostWrapperAPI
{
    using ModelCalcerHandle = void;

    ModelCalcerHandle * (* ModelCalcerCreate)() = nullptr;

    void (* ModelCalcerDelete)(ModelCalcerHandle * calcer) = nullptr;

    const char * (* GetErrorString)() = nullptr;

    bool (* LoadFullModelFromFile)(ModelCalcerHandle * calcer, const char * filename) = nullptr;

    bool (* CalcModelPredictionFlat)(ModelCalcerHandle * calcer, size_t docCount,
                                     const float ** floatFeatures, size_t floatFeaturesSize,
                                     double * result, size_t resultSize) = nullptr;

    bool (* CalcModelPrediction)(ModelCalcerHandle * calcer, size_t docCount,
                                 const float ** floatFeatures, size_t floatFeaturesSize,
                                 const char *** catFeatures, size_t catFeaturesSize,
                                 double * result, size_t resultSize) = nullptr;

    bool (* CalcModelPredictionWithHashedCatFeatures)(ModelCalcerHandle * calcer, size_t docCount,
                                                      const float ** floatFeatures, size_t floatFeaturesSize,
                                                      const int ** catFeatures, size_t catFeaturesSize,
                                                      double * result, size_t resultSize) = nullptr;

    int (* GetStringCatFeatureHash)(const char * data, size_t size) = nullptr;
    int (* GetIntegerCatFeatureHash)(long long val) = nullptr;

    size_t (* GetFloatFeaturesCount)(ModelCalcerHandle * calcer) = nullptr;
    size_t (* GetCatFeaturesCount)(ModelCalcerHandle * calcer) = nullptr;
    size_t (* GetTreeCount)(ModelCalcerHandle * modelHandle) = nullptr;
    size_t (* GetDimensionsCount)(ModelCalcerHandle * modelHandle) = nullptr;

    bool (* CheckModelMetadataHasKey)(ModelCalcerHandle * modelHandle, const char * keyPtr, size_t keySize) = nullptr;
    size_t (* GetModelInfoValueSize)(ModelCalcerHandle * modelHandle, const char * keyPtr, size_t keySize) = nullptr;
    const char * (* GetModelInfoValue)(ModelCalcerHandle * modelHandle, const char * keyPtr, size_t keySize) = nullptr;
};
68
69
70namespace
71{
72
73class CatBoostModelHolder
74{
75private:
76 CatBoostWrapperAPI::ModelCalcerHandle * handle;
77 const CatBoostWrapperAPI * api;
78public:
79 explicit CatBoostModelHolder(const CatBoostWrapperAPI * api_) : api(api_) { handle = api->ModelCalcerCreate(); }
80 ~CatBoostModelHolder() { api->ModelCalcerDelete(handle); }
81
82 CatBoostWrapperAPI::ModelCalcerHandle * get() { return handle; }
83};
84
85
86class CatBoostModelImpl : public ICatBoostModel
87{
88public:
89 CatBoostModelImpl(const CatBoostWrapperAPI * api_, const std::string & model_path) : api(api_)
90 {
91 auto handle_ = std::make_unique<CatBoostModelHolder>(api);
92 if (!handle_)
93 {
94 std::string msg = "Cannot create CatBoost model: ";
95 throw Exception(msg + api->GetErrorString(), ErrorCodes::CANNOT_LOAD_CATBOOST_MODEL);
96 }
97 if (!api->LoadFullModelFromFile(handle_->get(), model_path.c_str()))
98 {
99 std::string msg = "Cannot load CatBoost model: ";
100 throw Exception(msg + api->GetErrorString(), ErrorCodes::CANNOT_LOAD_CATBOOST_MODEL);
101 }
102
103 float_features_count = api->GetFloatFeaturesCount(handle_->get());
104 cat_features_count = api->GetCatFeaturesCount(handle_->get());
105 tree_count = 1;
106 if (api->GetDimensionsCount)
107 tree_count = api->GetDimensionsCount(handle_->get());
108
109 handle = std::move(handle_);
110 }
111
112 ColumnPtr evaluate(const ColumnRawPtrs & columns) const override
113 {
114 if (columns.empty())
115 throw Exception("Got empty columns list for CatBoost model.", ErrorCodes::BAD_ARGUMENTS);
116
117 if (columns.size() != float_features_count + cat_features_count)
118 {
119 std::string msg;
120 {
121 WriteBufferFromString buffer(msg);
122 buffer << "Number of columns is different with number of features: ";
123 buffer << columns.size() << " vs " << float_features_count << " + " << cat_features_count;
124 }
125 throw Exception(msg, ErrorCodes::BAD_ARGUMENTS);
126 }
127
128 for (size_t i = 0; i < float_features_count; ++i)
129 {
130 if (!columns[i]->isNumeric())
131 {
132 std::string msg;
133 {
134 WriteBufferFromString buffer(msg);
135 buffer << "Column " << i << " should be numeric to make float feature.";
136 }
137 throw Exception(msg, ErrorCodes::BAD_ARGUMENTS);
138 }
139 }
140
141 bool cat_features_are_strings = true;
142 for (size_t i = float_features_count; i < float_features_count + cat_features_count; ++i)
143 {
144 auto column = columns[i];
145 if (column->isNumeric())
146 cat_features_are_strings = false;
147 else if (!(typeid_cast<const ColumnString *>(column)
148 || typeid_cast<const ColumnFixedString *>(column)))
149 {
150 std::string msg;
151 {
152 WriteBufferFromString buffer(msg);
153 buffer << "Column " << i << " should be numeric or string.";
154 }
155 throw Exception(msg, ErrorCodes::BAD_ARGUMENTS);
156 }
157 }
158
159 auto result = evalImpl(columns, cat_features_are_strings);
160
161 if (tree_count == 1)
162 return result;
163
164 size_t column_size = columns.front()->size();
165 auto result_buf = result->getData().data();
166
167 /// Multiple trees case. Copy data to several columns.
168 MutableColumns mutable_columns(tree_count);
169 std::vector<Float64 *> column_ptrs(tree_count);
170 for (size_t i = 0; i < tree_count; ++i)
171 {
172 auto col = ColumnFloat64::create(column_size);
173 column_ptrs[i] = col->getData().data();
174 mutable_columns[i] = std::move(col);
175 }
176
177 Float64 * data = result_buf;
178 for (size_t row = 0; row < column_size; ++row)
179 {
180 for (size_t i = 0; i < tree_count; ++i)
181 {
182 *column_ptrs[i] = *data;
183 ++column_ptrs[i];
184 ++data;
185 }
186 }
187
188 return ColumnTuple::create(std::move(mutable_columns));
189 }
190
191 size_t getFloatFeaturesCount() const override { return float_features_count; }
192 size_t getCatFeaturesCount() const override { return cat_features_count; }
193 size_t getTreeCount() const override { return tree_count; }
194
195private:
196 std::unique_ptr<CatBoostModelHolder> handle;
197 const CatBoostWrapperAPI * api;
198 size_t float_features_count;
199 size_t cat_features_count;
200 size_t tree_count;
201
202 /// Buffer should be allocated with features_count * column->size() elements.
203 /// Place column elements in positions buffer[0], buffer[features_count], ... , buffer[size * features_count]
204 template <typename T>
205 void placeColumnAsNumber(const IColumn * column, T * buffer, size_t features_count) const
206 {
207 size_t size = column->size();
208 FieldVisitorConvertToNumber<T> visitor;
209 for (size_t i = 0; i < size; ++i)
210 {
211 /// TODO: Replace with column visitor.
212 Field field;
213 column->get(i, field);
214 *buffer = applyVisitor(visitor, field);
215 buffer += features_count;
216 }
217 }
218
219 /// Buffer should be allocated with features_count * column->size() elements.
220 /// Place string pointers in positions buffer[0], buffer[features_count], ... , buffer[size * features_count]
221 void placeStringColumn(const ColumnString & column, const char ** buffer, size_t features_count) const
222 {
223 size_t size = column.size();
224 for (size_t i = 0; i < size; ++i)
225 {
226 *buffer = const_cast<char *>(column.getDataAtWithTerminatingZero(i).data);
227 buffer += features_count;
228 }
229 }
230
231 /// Buffer should be allocated with features_count * column->size() elements.
232 /// Place string pointers in positions buffer[0], buffer[features_count], ... , buffer[size * features_count]
233 /// Returns PODArray which holds data (because ColumnFixedString doesn't store terminating zero).
234 PODArray<char> placeFixedStringColumn(
235 const ColumnFixedString & column, const char ** buffer, size_t features_count) const
236 {
237 size_t size = column.size();
238 size_t str_size = column.getN();
239 PODArray<char> data(size * (str_size + 1));
240 char * data_ptr = data.data();
241
242 for (size_t i = 0; i < size; ++i)
243 {
244 auto ref = column.getDataAt(i);
245 memcpy(data_ptr, ref.data, ref.size);
246 data_ptr[ref.size] = 0;
247 *buffer = data_ptr;
248 data_ptr += ref.size + 1;
249 buffer += features_count;
250 }
251
252 return data;
253 }
254
255 /// Place columns into buffer, returns column which holds placed data. Buffer should contains column->size() values.
256 template <typename T>
257 ColumnPtr placeNumericColumns(const ColumnRawPtrs & columns,
258 size_t offset, size_t size, const T** buffer) const
259 {
260 if (size == 0)
261 return nullptr;
262 size_t column_size = columns[offset]->size();
263 auto data_column = ColumnVector<T>::create(size * column_size);
264 T * data = data_column->getData().data();
265 for (size_t i = 0; i < size; ++i)
266 {
267 auto column = columns[offset + i];
268 if (column->isNumeric())
269 placeColumnAsNumber(column, data + i, size);
270 }
271
272 for (size_t i = 0; i < column_size; ++i)
273 {
274 *buffer = data;
275 ++buffer;
276 data += size;
277 }
278
279 return data_column;
280 }
281
282 /// Place columns into buffer, returns data which was used for fixed string columns.
283 /// Buffer should contains column->size() values, each value contains size strings.
284 std::vector<PODArray<char>> placeStringColumns(
285 const ColumnRawPtrs & columns, size_t offset, size_t size, const char ** buffer) const
286 {
287 if (size == 0)
288 return {};
289
290 std::vector<PODArray<char>> data;
291 for (size_t i = 0; i < size; ++i)
292 {
293 auto column = columns[offset + i];
294 if (auto column_string = typeid_cast<const ColumnString *>(column))
295 placeStringColumn(*column_string, buffer + i, size);
296 else if (auto column_fixed_string = typeid_cast<const ColumnFixedString *>(column))
297 data.push_back(placeFixedStringColumn(*column_fixed_string, buffer + i, size));
298 else
299 throw Exception("Cannot place string column.", ErrorCodes::LOGICAL_ERROR);
300 }
301
302 return data;
303 }
304
305 /// Calc hash for string cat feature at ps positions.
306 template <typename Column>
307 void calcStringHashes(const Column * column, size_t ps, const int ** buffer) const
308 {
309 size_t column_size = column->size();
310 for (size_t j = 0; j < column_size; ++j)
311 {
312 auto ref = column->getDataAt(j);
313 const_cast<int *>(*buffer)[ps] = api->GetStringCatFeatureHash(ref.data, ref.size);
314 ++buffer;
315 }
316 }
317
318 /// Calc hash for int cat feature at ps position. Buffer at positions ps should contains unhashed values.
319 void calcIntHashes(size_t column_size, size_t ps, const int ** buffer) const
320 {
321 for (size_t j = 0; j < column_size; ++j)
322 {
323 const_cast<int *>(*buffer)[ps] = api->GetIntegerCatFeatureHash((*buffer)[ps]);
324 ++buffer;
325 }
326 }
327
328 /// buffer contains column->size() rows and size columns.
329 /// For int cat features calc hash inplace.
330 /// For string cat features calc hash from column rows.
331 void calcHashes(const ColumnRawPtrs & columns, size_t offset, size_t size, const int ** buffer) const
332 {
333 if (size == 0)
334 return;
335 size_t column_size = columns[offset]->size();
336
337 std::vector<PODArray<char>> data;
338 for (size_t i = 0; i < size; ++i)
339 {
340 auto column = columns[offset + i];
341 if (auto column_string = typeid_cast<const ColumnString *>(column))
342 calcStringHashes(column_string, i, buffer);
343 else if (auto column_fixed_string = typeid_cast<const ColumnFixedString *>(column))
344 calcStringHashes(column_fixed_string, i, buffer);
345 else
346 calcIntHashes(column_size, i, buffer);
347 }
348 }
349
350 /// buffer[column_size * cat_features_count] -> char * => cat_features[column_size][cat_features_count] -> char *
351 void fillCatFeaturesBuffer(const char *** cat_features, const char ** buffer,
352 size_t column_size) const
353 {
354 for (size_t i = 0; i < column_size; ++i)
355 {
356 *cat_features = buffer;
357 ++cat_features;
358 buffer += cat_features_count;
359 }
360 }
361
362 /// Convert values to row-oriented format and call evaluation function from CatBoost wrapper api.
363 /// * CalcModelPredictionFlat if no cat features
364 /// * CalcModelPrediction if all cat features are strings
365 /// * CalcModelPredictionWithHashedCatFeatures if has int cat features.
366 ColumnFloat64::MutablePtr evalImpl(
367 const ColumnRawPtrs & columns,
368 bool cat_features_are_strings) const
369 {
370 std::string error_msg = "Error occurred while applying CatBoost model: ";
371 size_t column_size = columns.front()->size();
372
373 auto result = ColumnFloat64::create(column_size * tree_count);
374 auto result_buf = result->getData().data();
375
376 if (!column_size)
377 return result;
378
379 /// Prepare float features.
380 PODArray<const float *> float_features(column_size);
381 auto float_features_buf = float_features.data();
382 /// Store all float data into single column. float_features is a list of pointers to it.
383 auto float_features_col = placeNumericColumns<float>(columns, 0, float_features_count, float_features_buf);
384
385 if (cat_features_count == 0)
386 {
387 if (!api->CalcModelPredictionFlat(handle->get(), column_size,
388 float_features_buf, float_features_count,
389 result_buf, column_size * tree_count))
390 {
391
392 throw Exception(error_msg + api->GetErrorString(), ErrorCodes::CANNOT_APPLY_CATBOOST_MODEL);
393 }
394 return result;
395 }
396
397 /// Prepare cat features.
398 if (cat_features_are_strings)
399 {
400 /// cat_features_holder stores pointers to ColumnString data or fixed_strings_data.
401 PODArray<const char *> cat_features_holder(cat_features_count * column_size);
402 PODArray<const char **> cat_features(column_size);
403 auto cat_features_buf = cat_features.data();
404
405 fillCatFeaturesBuffer(cat_features_buf, cat_features_holder.data(), column_size);
406 /// Fixed strings are stored without termination zero, so have to copy data into fixed_strings_data.
407 auto fixed_strings_data = placeStringColumns(columns, float_features_count,
408 cat_features_count, cat_features_holder.data());
409
410 if (!api->CalcModelPrediction(handle->get(), column_size,
411 float_features_buf, float_features_count,
412 cat_features_buf, cat_features_count,
413 result_buf, column_size * tree_count))
414 {
415 throw Exception(error_msg + api->GetErrorString(), ErrorCodes::CANNOT_APPLY_CATBOOST_MODEL);
416 }
417 }
418 else
419 {
420 PODArray<const int *> cat_features(column_size);
421 auto cat_features_buf = cat_features.data();
422 auto cat_features_col = placeNumericColumns<int>(columns, float_features_count,
423 cat_features_count, cat_features_buf);
424 calcHashes(columns, float_features_count, cat_features_count, cat_features_buf);
425 if (!api->CalcModelPredictionWithHashedCatFeatures(
426 handle->get(), column_size,
427 float_features_buf, float_features_count,
428 cat_features_buf, cat_features_count,
429 result_buf, column_size * tree_count))
430 {
431 throw Exception(error_msg + api->GetErrorString(), ErrorCodes::CANNOT_APPLY_CATBOOST_MODEL);
432 }
433 }
434
435 return result;
436 }
437};
438
439
440/// Holds CatBoost wrapper library and provides wrapper interface.
441class CatBoostLibHolder: public CatBoostWrapperAPIProvider
442{
443public:
444 explicit CatBoostLibHolder(std::string lib_path_) : lib_path(std::move(lib_path_)), lib(lib_path) { initAPI(); }
445
446 const CatBoostWrapperAPI & getAPI() const override { return api; }
447 const std::string & getCurrentPath() const { return lib_path; }
448
449private:
450 CatBoostWrapperAPI api;
451 std::string lib_path;
452 SharedLibrary lib;
453
454 void initAPI();
455
456 template <typename T>
457 void load(T& func, const std::string & name) { func = lib.get<T>(name); }
458
459 template <typename T>
460 void tryLoad(T& func, const std::string & name) { func = lib.tryGet<T>(name); }
461};
462
463void CatBoostLibHolder::initAPI()
464{
465 load(api.ModelCalcerCreate, "ModelCalcerCreate");
466 load(api.ModelCalcerDelete, "ModelCalcerDelete");
467 load(api.GetErrorString, "GetErrorString");
468 load(api.LoadFullModelFromFile, "LoadFullModelFromFile");
469 load(api.CalcModelPredictionFlat, "CalcModelPredictionFlat");
470 load(api.CalcModelPrediction, "CalcModelPrediction");
471 load(api.CalcModelPredictionWithHashedCatFeatures, "CalcModelPredictionWithHashedCatFeatures");
472 load(api.GetStringCatFeatureHash, "GetStringCatFeatureHash");
473 load(api.GetIntegerCatFeatureHash, "GetIntegerCatFeatureHash");
474 load(api.GetFloatFeaturesCount, "GetFloatFeaturesCount");
475 load(api.GetCatFeaturesCount, "GetCatFeaturesCount");
476 tryLoad(api.CheckModelMetadataHasKey, "CheckModelMetadataHasKey");
477 tryLoad(api.GetModelInfoValueSize, "GetModelInfoValueSize");
478 tryLoad(api.GetModelInfoValue, "GetModelInfoValue");
479 tryLoad(api.GetTreeCount, "GetTreeCount");
480 tryLoad(api.GetDimensionsCount, "GetDimensionsCount");
481}
482
483std::shared_ptr<CatBoostLibHolder> getCatBoostWrapperHolder(const std::string & lib_path)
484{
485 static std::weak_ptr<CatBoostLibHolder> ptr;
486 static std::mutex mutex;
487
488 std::lock_guard lock(mutex);
489 auto result = ptr.lock();
490
491 if (!result || result->getCurrentPath() != lib_path)
492 {
493 result = std::make_shared<CatBoostLibHolder>(lib_path);
494 /// This assignment is not atomic, which prevents from creating lock only inside 'if'.
495 ptr = result;
496 }
497
498 return result;
499}
500
501}
502
503
504CatBoostModel::CatBoostModel(std::string name_, std::string model_path_, std::string lib_path_,
505 const ExternalLoadableLifetime & lifetime_)
506 : name(std::move(name_)), model_path(std::move(model_path_)), lib_path(std::move(lib_path_)), lifetime(lifetime_)
507{
508 api_provider = getCatBoostWrapperHolder(lib_path);
509 api = &api_provider->getAPI();
510 model = std::make_unique<CatBoostModelImpl>(api, model_path);
511 float_features_count = model->getFloatFeaturesCount();
512 cat_features_count = model->getCatFeaturesCount();
513 tree_count = model->getTreeCount();
514}
515
516const ExternalLoadableLifetime & CatBoostModel::getLifetime() const
517{
518 return lifetime;
519}
520
521bool CatBoostModel::isModified() const
522{
523 return true;
524}
525
526std::shared_ptr<const IExternalLoadable> CatBoostModel::clone() const
527{
528 return std::make_shared<CatBoostModel>(name, model_path, lib_path, lifetime);
529}
530
531size_t CatBoostModel::getFloatFeaturesCount() const
532{
533 return float_features_count;
534}
535
536size_t CatBoostModel::getCatFeaturesCount() const
537{
538 return cat_features_count;
539}
540
541size_t CatBoostModel::getTreeCount() const
542{
543 return tree_count;
544}
545
546DataTypePtr CatBoostModel::getReturnType() const
547{
548 auto type = std::make_shared<DataTypeFloat64>();
549 if (tree_count == 1)
550 return type;
551
552 DataTypes types(tree_count, type);
553
554 return std::make_shared<DataTypeTuple>(types);
555}
556
557ColumnPtr CatBoostModel::evaluate(const ColumnRawPtrs & columns) const
558{
559 if (!model)
560 throw Exception("CatBoost model was not loaded.", ErrorCodes::LOGICAL_ERROR);
561 return model->evaluate(columns);
562}
563
564}
565