#include "arg.h"
#include "common.h"
#include "log.h"
#include "llama.h"

#include <cmath>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <vector>

#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif

int main(int argc, char ** argv) {
    common_params params;
    params.escape = false;

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_FINETUNE)) {
        return 1;
    }

    if (params.use_mmap) {
        LOG_INF("%s: force disabling memory mapping because it would result in read-only pointers to the weights\n",
                __func__);
        params.use_mmap = false;
    }
    if (params.cache_type_k != GGML_TYPE_F32) {
        LOG_INF("%s: force changing k cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
        params.cache_type_k = GGML_TYPE_F32;
    }
    if (params.cache_type_v != GGML_TYPE_F32) {
        LOG_INF("%s: force changing v cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
        params.cache_type_v = GGML_TYPE_F32;
    }

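    // initialize logging/common state, the llama backend, and the requested NUMA strategy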
    common_init();
    llama_backend_init();
    llama_numa_init(params.numa);

    // load the model and apply lora adapter, if any
    common_init_result llama_init = common_init_from_params(params);
    llama_model_ptr   & model = llama_init.model;
    llama_context_ptr & ctx   = llama_init.context;

    if (model == NULL) {
        LOG_ERR("%s: unable to load model\n", __func__);
        return 1;
    }

    // print system information
    {
        LOG_INF("\n");
        LOG_INF("%s\n", common_params_get_system_info(params).c_str());
    }

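    // tokenize the training text and slice it into a dataset, stepping by half the context size per example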
    std::vector<llama_token> tokens = common_tokenize(ctx.get(), params.prompt, true);
    ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get()) / 2);

    struct lr_opt & lr = params.lr;
    LOG_INF("-optimizer %s -lr0 %.2g -wd %.2g -lr-min %.2g -min-epochs %.2g -epochs %d -period %.2g -val %.2g\n",
            ggml_opt_optimizer_name(params.optimizer), (double) lr.lr0, (double) lr.wd, (double) lr.lr_min,
            (double) lr.decay_epochs, (unsigned) lr.epochs, (double) params.n_batch / params.n_ubatch,
            (double) params.val_split);

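    // optimize all model parameters; the learning-rate schedule and optimizer type come from the command line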
    struct llama_opt_params lopt_params {
        /*n_ctx_train     =*/ 0,
        /*param_filter    =*/ llama_opt_param_filter_all,
        /*param_filter_ud =*/ nullptr,
        /*get_opt_pars    =*/ common_opt_lr_pars,
        /*get_opt_pars_ud =*/ &params.lr,
        /*optimizer_type  =*/ params.optimizer,
    };
    llama_opt_init(ctx.get(), model.get(), lopt_params);

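    // everything before idata_split is used for training, the remaining val_split fraction for evaluation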
    const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - params.val_split);

    ggml_opt_result_t result_train = ggml_opt_result_init();
    ggml_opt_result_t result_eval  = ggml_opt_result_init();

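    // each epoch trains on the first part of the dataset and evaluates on the rest, with a progress bar for both passes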
    for (lr.epoch = 0; lr.epoch < lr.epochs; ++lr.epoch) {
        llama_opt_epoch(ctx.get(), dataset, result_train, result_eval, idata_split,
                        ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
        fprintf(stderr, "\n");

        ggml_opt_result_reset(result_train);
        ggml_opt_result_reset(result_eval);
    }
    ggml_opt_result_free(result_train);
    ggml_opt_result_free(result_eval);

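    // write the finetuned model to the requested output file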
    llama_model_save_to_file(model.get(), params.out_file.c_str());

    llama_backend_free();

    return 0;
}