#include "arg.h"
#include "common.h"
#include "log.h"
#include "llama.h"

#include <cmath>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <vector>

#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif

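// Minimal finetuning example: the prompt is tokenized into a training dataset,
// the optimizer runs for the requested number of epochs, and the updated model
// is written back out to params.out_file.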
int main(int argc, char ** argv) {
    common_params params;
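    // leave escape-sequence processing of the prompt (e.g. "\n") disabled so the training text is used literally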
    params.escape = false;

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_FINETUNE)) {
        return 1;
    }

    if (params.use_mmap) {
        LOG_INF("%s: force disabling memory mapping because it would result in read-only pointers to the weights\n",
                __func__);
        params.use_mmap = false;
    }
    if (params.cache_type_k != GGML_TYPE_F32) {
        LOG_INF("%s: force changing k cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
        params.cache_type_k = GGML_TYPE_F32;
    }
    if (params.cache_type_v != GGML_TYPE_F32) {
        LOG_INF("%s: force changing v cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
        params.cache_type_v = GGML_TYPE_F32;
    }

    common_init();
    llama_backend_init();
    llama_numa_init(params.numa);

    // load the model and apply lora adapter, if any
    common_init_result llama_init = common_init_from_params(params);

    llama_model_ptr & model = llama_init.model;
    llama_context_ptr & ctx = llama_init.context;

    if (model == NULL) {
        LOG_ERR("%s: unable to load model\n", __func__);
        return 1;
    }

    // print system information
    {
        LOG_INF("\n");
        LOG_INF("%s\n", common_params_get_system_info(params).c_str());
    }

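    // tokenize the entire prompt and slice it into training sequences of one context window each;
    // with a stride of n_ctx/2, consecutive sequences overlap by half a context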
    std::vector<llama_token> tokens = common_tokenize(ctx.get(), params.prompt, true);
    ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get()) / 2);

    struct lr_opt & lr = params.lr;
    LOG_INF("-optimizer %s -lr0 %.2g -wd %.2g -lr-min %.2g -min-epochs %.2g -epochs %d -period %.2g -val %.2g\n",
            ggml_opt_optimizer_name(params.optimizer), (double) lr.lr0, (double) lr.wd, (double) lr.lr_min, (double) lr.decay_epochs,
            (unsigned) lr.epochs, (double) params.n_batch / params.n_ubatch, (double) params.val_split);

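    // train all model tensors (a full finetune, no adapter); the learning-rate schedule is read from
    // params.lr through common_opt_lr_pars; n_ctx_train = 0 presumably falls back to the context size of ctx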
    struct llama_opt_params lopt_params {
        /*n_ctx_train     =*/ 0,
        /*param_filter    =*/ llama_opt_param_filter_all,
        /*param_filter_ud =*/ nullptr,
        /*get_opt_pars    =*/ common_opt_lr_pars,
        /*get_opt_pars_ud =*/ &params.lr,
        /*optimizer_type  =*/ params.optimizer,
    };
    llama_opt_init(ctx.get(), model.get(), lopt_params);

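    // the first (1 - val_split) fraction of the dataset is used for training, the remainder for validation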
    const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - params.val_split);

    ggml_opt_result_t result_train = ggml_opt_result_init();
    ggml_opt_result_t result_eval  = ggml_opt_result_init();

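    // one pass per epoch: train on the first idata_split examples, evaluate on the rest,
    // then reset the accumulated loss statistics before the next epoch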
    for (lr.epoch = 0; lr.epoch < lr.epochs; ++lr.epoch) {
        llama_opt_epoch(ctx.get(), dataset, result_train, result_eval, idata_split,
                        ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
        fprintf(stderr, "\n");

        ggml_opt_result_reset(result_train);
        ggml_opt_result_reset(result_eval);
    }
    ggml_opt_result_free(result_train);
    ggml_opt_result_free(result_eval);

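    // write the fully finetuned model (all weights, not an adapter) to params.out_file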
    llama_model_save_to_file(model.get(), params.out_file.c_str());

    llama_backend_free();

    return 0;
}