#include "arg.h"
#include "chat.h"
#include "common.h"
#include "llama.h"
#include "log.h"

#include <limits.h>

#include <algorithm>
#include <cmath>
#include <cstring>
#include <limits>
#include <random>
#include <string>
#include <vector>

enum diffusion_algorithm { ORIGIN = 0, ENTROPY_BASED = 1, MARGIN_BASED = 2, RANDOM = 3, CONFIDENCE_BASED = 4 };

// Unified transfer scheduling methods
enum transfer_schedule {
    TIMESTEP_BASED = 0,  // Dream-style: (1.0 - s/t) * remaining
    BLOCK_BASED    = 1,  // LLaDA-style: process in blocks with get_num_transfer_tokens
};

typedef bool (*diffusion_step_callback_t)(int32_t             step,
                                          int32_t             total_steps,
                                          const llama_token * tokens,
                                          int32_t             n_tokens,
                                          void *              user_data);

struct diffusion_params {
    int32_t                   steps                   = 0;
    float                     temperature             = 0.0f;
    llama_token               mask_token_id           = LLAMA_TOKEN_NULL;
    diffusion_step_callback_t step_callback           = nullptr;
    void *                    step_callback_user_data = nullptr;
    int32_t                   seed                    = 0;
    bool                      visual_mode             = false;
    bool                      shift_logits            = false;  // Shift logits by -1 after decode

    float   top_p = 0.0f;
    int32_t top_k = 0;

    diffusion_algorithm algorithm = CONFIDENCE_BASED;
    transfer_schedule   schedule  = TIMESTEP_BASED;

    float   cfg_scale        = 0.0f;   // Scale for classifier-free guidance
    float   eps              = 0.0f;   // Timestep scheduling epsilon
    int32_t block_length     = 0;      // Block size (for block scheduling)
    float   alg_temp         = 0.0f;   // Algorithm temperature (0.0 = deterministic)
    bool    add_gumbel_noise = false;  // Add Gumbel noise to the logits if temp > 0.0

    int32_t max_length = 0;  // Maximum sequence length
};

struct callback_data {
    diffusion_params *  diff_params;
    const llama_vocab * vocab;
    int32_t             n_input;
};

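// Score how "sure" the model is about the token sampled at a masked position.
// These per-position scores later decide which masked positions get unmasked
// first (higher confidence = unmasked earlier).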
static float calculate_confidence(const llama_token_data_array & cur_p,
                                  diffusion_algorithm            algorithm,
                                  std::mt19937 &                 rng) {
    switch (algorithm) {
        case CONFIDENCE_BASED:
            return cur_p.data[cur_p.selected].p;  // Selected token probability

        case ENTROPY_BASED:
            {
                float       entropy = 0.0f;
                const float epsilon = 1e-10f;
                for (size_t i = 0; i < cur_p.size; i++) {
                    float prob = cur_p.data[i].p;
                    entropy += prob * logf(prob + epsilon);
                }
                return -entropy;  // Higher entropy = lower confidence
            }

        case MARGIN_BASED:
            return (cur_p.size > 1) ? cur_p.data[0].p - cur_p.data[1].p : cur_p.data[0].p;

        case RANDOM:
            {
                std::uniform_real_distribution<float> uniform(0.0f, 1.0f);
                return uniform(rng);  // Random confidence
            }

        case ORIGIN:
            return cur_p.data[cur_p.selected].p;

        default:
            return 0.0f;
    }
}

// Unified transfer count calculation function
static int32_t calculate_transfer_count(int32_t                      step,
                                        int32_t                      total_steps,
                                        int32_t                      remaining_masked,
                                        transfer_schedule            schedule,
                                        float                        eps,
                                        const std::vector<int32_t> & num_transfer_tokens = {}) {
    switch (schedule) {
        case TIMESTEP_BASED:
            {
                float t          = 1.0f - (float) step / total_steps * (1.0f - eps);
                float s          = 1.0f - (float) (step + 1) / total_steps * (1.0f - eps);
                float p_transfer = (step < total_steps - 1) ? (1.0f - s / t) : 1.0f;
                return (int32_t) (remaining_masked * p_transfer);
            }

        case BLOCK_BASED:
            if (!num_transfer_tokens.empty() && step < (int32_t) num_transfer_tokens.size()) {
                return num_transfer_tokens[step];
            }
            return remaining_masked / (total_steps - step);  // Fallback

        default:
            return remaining_masked / (total_steps - step);
    }
}

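// Progress/visualization callback invoked once per diffusion step.
// In visual mode the full (partially unmasked) sequence is re-rendered every step;
// otherwise only a progress bar is printed. Returning false aborts generation.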
static bool diffusion_step_callback(int32_t             step,
                                    int32_t             total_steps,
                                    const llama_token * tokens,
                                    int32_t             n_tokens,
                                    void *              user_data) {
    callback_data * data = static_cast<callback_data *>(user_data);

    auto print_progress_bar = [](int32_t step, int32_t total_steps) {
        int progress_percent = (step * 100) / total_steps;
        int progress_bars    = (step * 50) / total_steps;
        LOG_INF("\rdiffusion step: %d/%d [%s%s] %d%%",
                step,
                total_steps,
                std::string(progress_bars, '=').c_str(),
                std::string(50 - progress_bars, ' ').c_str(),
                progress_percent);
    };

    if (data->diff_params->visual_mode) {
        // Visual mode: redraw the whole sequence each step
        LOG_INF("\033[2J\033[H");  // Clear screen and move cursor to top-left

        print_progress_bar(step, total_steps);

        LOG_INF("\n");

        std::string current_text = " ";

        for (int32_t i = data->n_input; i < n_tokens; i++) {
            std::string token_str;
            if (tokens[i] != llama_vocab_mask(data->vocab)) {
                char piece[256];
                int  n_chars = llama_token_to_piece(data->vocab, tokens[i], piece, sizeof(piece), 0, false);
                if (n_chars > 0) {
                    piece[n_chars] = '\0';
                    token_str      = piece;
                }
            } else {
                token_str = " ";
            }

            current_text += token_str;
        }

        LOG_INF("%s\n", current_text.c_str());
    } else {
        print_progress_bar(step, total_steps);
    }

    return true;
}

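// Reweight logits with Gumbel noise: logit -> exp(logit) / (-log(u))^temperature,
// so that taking the argmax afterwards approximates temperature-based sampling.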
static void add_gumbel_noise(float * logits, int32_t n_vocab, float temperature, std::mt19937 & rng) {
    if (temperature == 0.0f) {
        return;
    }

    std::uniform_real_distribution<double> uniform(0.0, 1.0);
    for (int32_t i = 0; i < n_vocab; i++) {
        double noise = uniform(rng);
        // Prevent log(0)
        noise               = std::max(noise, 1e-20);
        double gumbel_noise = std::pow(-std::log(noise), temperature);
        logits[i]           = std::exp(logits[i]) / gumbel_noise;
    }
}

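// Distribute mask_count token transfers as evenly as possible over the given
// number of steps; the first (mask_count % steps) steps get one extra transfer.
// Example: mask_count = 10, steps = 4 -> { 3, 3, 2, 2 }.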
static std::vector<int32_t> get_num_transfer_tokens(int32_t mask_count, int32_t steps) {
    std::vector<int32_t> num_transfer_tokens(steps);

    int32_t base      = mask_count / steps;
    int32_t remainder = mask_count % steps;

    for (int32_t i = 0; i < steps; i++) {
        num_transfer_tokens[i] = base + (i < remainder ? 1 : 0);
    }

    return num_transfer_tokens;
}

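// Core masked-diffusion loop:
//   1. fill output_tokens with the prompt followed by mask tokens up to max_length,
//   2. each step, decode the whole sequence with non-causal attention (twice when
//      classifier-free guidance is enabled) and sample a candidate for every masked
//      position,
//   3. unmask a scheduled number of positions per step, either the most confident
//      ones (alg_temp == 0) or positions sampled in proportion to confidence.
// On success, n_generated is set to params.max_length.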
static void diffusion_generate(llama_context *          ctx,
                               const llama_token *      input_tokens,
                               llama_token *            output_tokens,
                               int32_t                  n_input,
                               const diffusion_params & params,
                               int32_t &                n_generated) {
    n_generated = 0;
    if (!ctx || !input_tokens || !output_tokens || n_input <= 0 || params.max_length <= n_input) {
        return;
    }

    const llama_model * model = llama_get_model(ctx);

    // Initialize with input and pad with mask tokens
    std::copy(input_tokens, input_tokens + n_input, output_tokens);
    std::fill(output_tokens + n_input, output_tokens + params.max_length, params.mask_token_id);

    std::mt19937 rng(params.seed);

    llama_set_causal_attn(ctx, false);

    int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));

    std::vector<llama_token_data> candidates(n_vocab);
    std::vector<llama_token_data> conf_candidates;
    conf_candidates.reserve(params.max_length);
    std::vector<int32_t> mask_positions;
    mask_positions.reserve(params.max_length);

    // Setup sampler chain
    struct llama_sampler * sampler = llama_sampler_chain_init(llama_sampler_chain_default_params());
    if (params.top_k > 0) {
        llama_sampler_chain_add(sampler, llama_sampler_init_top_k(params.top_k));
    }
    if (params.top_p < 1.0f) {
        llama_sampler_chain_add(sampler, llama_sampler_init_top_p(params.top_p, 1));
    }
    if (params.temperature > 0.0f) {
        llama_sampler_chain_add(sampler, llama_sampler_init_temp(params.temperature));
    }
    llama_sampler_chain_add(sampler, llama_sampler_init_dist(params.seed));

    struct llama_sampler * dist_sampler = llama_sampler_init_dist(params.seed);

    llama_batch batch = llama_batch_init(params.max_length, 0, 1);
    batch.n_tokens    = params.max_length;

    // Pre-allocate buffers for CFG if needed
    int32_t                  logits_size = n_vocab * params.max_length;
    std::vector<float>       cond_logits_buffer;
    std::vector<llama_token> un_x_buffer;
    if (params.cfg_scale > 0.0f) {
        cond_logits_buffer.resize(logits_size);
        un_x_buffer.resize(params.max_length);
    }

    // For block-based processing
    std::vector<int32_t> num_transfer_tokens;
    int32_t num_blocks      = 1;
    int32_t steps_per_block = params.steps;

    if (params.schedule == BLOCK_BASED) {
        GGML_ASSERT(params.max_length % params.block_length == 0);
        num_blocks = params.max_length / params.block_length;
        GGML_ASSERT(params.steps % num_blocks == 0);
        steps_per_block = params.steps / num_blocks;
    }

    std::vector<float> confidence(params.max_length);

    int64_t total_sampling_time = 0;
    int64_t total_time          = 0;
    int64_t time_start          = ggml_time_us();

    for (int block_num = 0; block_num < num_blocks; block_num++) {
        int32_t block_start = (params.schedule == BLOCK_BASED) ? n_input + block_num * params.block_length : 0;
        int32_t block_end   = (params.schedule == BLOCK_BASED) ?
                                  std::min(n_input + (block_num + 1) * params.block_length, params.max_length) :
                                  params.max_length;

        // Count masked tokens in current block for block-based processing
        if (params.schedule == BLOCK_BASED) {
            int32_t block_mask_count = 0;
            for (int i = block_start; i < block_end; i++) {
                if (output_tokens[i] == params.mask_token_id) {
                    block_mask_count++;
                }
            }
            num_transfer_tokens = get_num_transfer_tokens(block_mask_count, steps_per_block);
        }

        for (int32_t step = 0; step < steps_per_block; step++) {
            int32_t global_step = block_num * steps_per_block + step;

            if (params.step_callback) {
                if (!params.step_callback(
                        global_step, params.steps, output_tokens, params.max_length, params.step_callback_user_data)) {
                    break;
                }
            }

            // Setup batch
            for (int32_t i = 0; i < params.max_length; i++) {
                batch.token[i]     = output_tokens[i];
                batch.pos[i]       = i;
                batch.n_seq_id[i]  = 1;
                batch.seq_id[i][0] = 0;
                batch.logits[i]    = 1;
            }

            float * logits = nullptr;

            if (params.cfg_scale > 0.0f) {
                int ret = llama_decode(ctx, batch);
                if (ret != 0) {
                    LOG_ERR("%s: failed to decode conditional pass at step %d, ret = %d\n", __func__, global_step, ret);
                    break;
                }
                float * cond_logits_ptr = llama_get_logits(ctx);
                std::memcpy(cond_logits_buffer.data(), cond_logits_ptr, logits_size * sizeof(float));

                // Unconditional generation (mask the input)
                std::copy(output_tokens, output_tokens + params.max_length, un_x_buffer.begin());
                for (int32_t i = 0; i < n_input; i++) {
                    un_x_buffer[i] = params.mask_token_id;
                }

                for (int32_t i = 0; i < params.max_length; i++) {
                    batch.token[i] = un_x_buffer[i];
                }
                ret = llama_decode(ctx, batch);
                if (ret != 0) {
                    LOG_ERR("%s: failed to decode unconditional pass at step %d, ret = %d\n", __func__, global_step, ret);
                    break;
                }
                float * uncond_logits = llama_get_logits(ctx);

                // Apply CFG: uncond + (cfg_scale + 1) * (cond - uncond)
                for (int32_t i = 0; i < logits_size; i++) {
                    cond_logits_buffer[i] =
                        uncond_logits[i] + (params.cfg_scale + 1.0f) * (cond_logits_buffer[i] - uncond_logits[i]);
                }
                logits = cond_logits_buffer.data();
            } else {
                int ret = llama_decode(ctx, batch);
                if (ret != 0) {
                    LOG_ERR("%s: failed to decode at step %d, ret = %d\n", __func__, global_step, ret);
                    break;
                }
                logits = llama_get_logits(ctx);
            }

            if (!logits) {
                LOG_ERR("%s: failed to get logits at step %d\n", __func__, global_step);
                break;
            }

            auto get_logits_for_pos = [&](int32_t pos) -> const float * {
                if (params.shift_logits) {
                    return pos == 0 ? logits : logits + (pos - 1) * n_vocab;
                }
                return logits + pos * n_vocab;
            };

            int64_t time_start_sampling = ggml_time_us();

            mask_positions.clear();
            for (int32_t i = 0; i < params.max_length; i++) {
                if (output_tokens[i] == params.mask_token_id) {
                    // For block-based, only consider the current block
                    if (params.schedule != BLOCK_BASED || (i >= block_start && i < block_end)) {
                        mask_positions.push_back(i);
                    }
                }
            }

            if (mask_positions.empty()) {
                break;
            }

            if (params.add_gumbel_noise && params.temperature > 0.0f) {
                add_gumbel_noise(logits, n_vocab, params.temperature, rng);
            }

            if (params.algorithm == ORIGIN) {
                int32_t transfer_count = calculate_transfer_count(
                    step, steps_per_block, (int32_t) mask_positions.size(), params.schedule, params.eps, num_transfer_tokens);
                float p_transfer = (float) transfer_count / mask_positions.size();

                for (int32_t pos : mask_positions) {
                    if (std::uniform_real_distribution<float>(0.0f, 1.0f)(rng) < p_transfer) {
                        const float * pos_logits = get_logits_for_pos(pos);
                        for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
                            candidates[token_id].id    = token_id;
                            candidates[token_id].logit = pos_logits[token_id];
                            candidates[token_id].p     = 0.0f;
                        }

                        llama_token_data_array cur_p = {
                            /* .data     = */ candidates.data(),
                            /* .size     = */ (size_t) n_vocab,
                            /* .selected = */ -1,
                            /* .sorted   = */ false,
                        };

                        llama_sampler_apply(sampler, &cur_p);
                        output_tokens[pos] = cur_p.data[cur_p.selected].id;
                    }
                }
            } else {
                std::vector<std::pair<float, int32_t>> confidences;
                std::vector<llama_token>               sampled_tokens(mask_positions.size());

                for (size_t i = 0; i < mask_positions.size(); i++) {
                    int32_t       pos        = mask_positions[i];
                    const float * pos_logits = get_logits_for_pos(pos);

                    for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
                        candidates[token_id].logit = pos_logits[token_id];
                        candidates[token_id].p     = 0.0f;
                        candidates[token_id].id    = token_id;
                    }

                    llama_token_data_array cur_p = {
                        /* .data     = */ candidates.data(),
                        /* .size     = */ candidates.size(),
                        /* .selected = */ -1,
                        /* .sorted   = */ false,
                    };

                    llama_sampler_apply(sampler, &cur_p);
                    llama_token sampled_token = cur_p.data[cur_p.selected].id;

                    float conf = calculate_confidence(cur_p, params.algorithm, rng);

                    sampled_tokens[i] = sampled_token;
                    confidences.emplace_back(conf, (int32_t) i);
                }

                int32_t transfer_count = calculate_transfer_count(
                    step, steps_per_block, (int32_t) mask_positions.size(), params.schedule, params.eps, num_transfer_tokens);

                if (transfer_count > 0) {
                    if (params.alg_temp == 0.0f) {
                        // Deterministic: reveal the highest-confidence positions first
                        std::partial_sort(confidences.begin(),
                                          confidences.begin() + std::min(transfer_count, (int32_t) confidences.size()),
                                          confidences.end(),
                                          [](const std::pair<float, int32_t> & a, const std::pair<float, int32_t> & b) {
                                              if (a.first != b.first) {
                                                  return a.first > b.first;
                                              }
                                              return a.second < b.second;
                                          });

                        for (int32_t i = 0; i < std::min(transfer_count, (int32_t) confidences.size()); i++) {
                            int32_t mask_idx   = confidences[i].second;
                            int32_t pos        = mask_positions[mask_idx];
                            output_tokens[pos] = sampled_tokens[mask_idx];
                        }
                    } else {
                        // Stochastic: sample positions with probability derived from confidence / alg_temp
                        conf_candidates.clear();
                        for (size_t i = 0; i < confidences.size(); i++) {
                            float conf_logit = confidences[i].first / params.alg_temp;
                            conf_candidates.emplace_back(llama_token_data{ (int32_t) i, conf_logit, 0.0f });
                        }

                        llama_token_data_array conf_array = {
                            /* .data     = */ conf_candidates.data(),
                            /* .size     = */ conf_candidates.size(),
                            /* .selected = */ -1,
                            /* .sorted   = */ false,
                        };

                        for (int32_t i = 0; i < std::min(transfer_count, (int32_t) confidences.size()); i++) {
                            llama_sampler_apply(dist_sampler, &conf_array);
                            int32_t selected_idx = conf_array.selected;
                            int32_t mask_idx     = selected_idx;
                            int32_t pos          = mask_positions[mask_idx];
                            output_tokens[pos]   = sampled_tokens[mask_idx];

                            conf_candidates[selected_idx].p = 0.0f;
                            conf_array.selected             = -1;
                        }
                    }
                }
            }

            int64_t time_end_sampling = ggml_time_us();
            total_sampling_time += time_end_sampling - time_start_sampling;
        }
    }

    int64_t time_end = ggml_time_us();
    total_time += time_end - time_start;

    LOG_INF("\ntotal time: %0.2fms, time per step: %0.2fms, sampling time per step: %0.2fms\n",
            total_time / 1000.0,
            total_time / 1000.0 / params.steps,
            total_sampling_time / 1000.0 / params.steps);

    llama_batch_free(batch);
    llama_sampler_free(sampler);
    llama_sampler_free(dist_sampler);

    n_generated = params.max_length;
}

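// Apply the model's chat template (with optional system prompt) when chat
// templating is enabled; otherwise return the raw prompt unchanged.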
static std::string format_input_text(const std::string & prompt,
                                     const std::string & system_prompt,
                                     bool                use_chat_template,
                                     llama_model *       model) {
    if (!use_chat_template) {
        return prompt;
    }

    auto chat_templates = common_chat_templates_init(model, "");

    common_chat_templates_inputs inputs;
    common_chat_msg              system_msg;

    if (!system_prompt.empty()) {
        system_msg.role    = "system";
        system_msg.content = system_prompt;
        inputs.messages.push_back(system_msg);
    }

    common_chat_msg user_msg;
    user_msg.role    = "user";
    user_msg.content = prompt;

    inputs.messages.push_back(user_msg);
    inputs.add_generation_prompt = true;

    auto result = common_chat_templates_apply(chat_templates.get(), inputs);

    return result.prompt;
}

int main(int argc, char ** argv) {
    ggml_time_init();

    common_params params;

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_DIFFUSION)) {
        return 1;
    }

    common_init();
    llama_backend_init();

    llama_model_params model_params = llama_model_default_params();
    model_params.n_gpu_layers       = params.n_gpu_layers;
    model_params.devices            = params.devices.data();
    model_params.use_mmap           = params.use_mmap;
    model_params.use_mlock          = params.use_mlock;
    model_params.check_tensors      = params.check_tensors;

    llama_model * model = llama_model_load_from_file(params.model.path.c_str(), model_params);
    if (!model) {
        LOG_ERR("error: failed to load model '%s'\n", params.model.path.c_str());
        return 1;
    }

    if (!llama_model_is_diffusion(model)) {
        LOG_ERR("error: unsupported model for diffusion\n");
        llama_model_free(model);
        return 1;
    }

    llama_context_params ctx_params = llama_context_default_params();
    ctx_params.n_ctx                = params.n_ctx;
    ctx_params.n_batch              = params.n_batch;
    ctx_params.n_ubatch             = params.n_ubatch;
    ctx_params.flash_attn_type      = params.flash_attn_type;
    ctx_params.no_perf              = params.no_perf;
    ctx_params.type_k               = params.cache_type_k;
    ctx_params.type_v               = params.cache_type_v;

    llama_context * ctx = llama_init_from_model(model, ctx_params);
    if (!ctx) {
        LOG_ERR("error: failed to create context\n");
        llama_model_free(model);
        return 1;
    }

    llama_set_n_threads(ctx, params.cpuparams.n_threads, params.cpuparams_batch.n_threads);

    const llama_vocab * vocab = llama_model_get_vocab(model);

    std::string formatted_prompt =
        format_input_text(params.prompt, params.system_prompt, params.enable_chat_template, model);

    std::vector<llama_token> input_tokens = common_tokenize(vocab,
                                                            formatted_prompt,
                                                            /*add special tokens*/ true,
                                                            /*parse special*/ true);

    int n_input = (int) input_tokens.size();

    if (n_input >= params.n_ctx) {
        LOG_ERR("error: input too long (%d tokens), max context is %d\n", n_input, params.n_ctx);
        llama_free(ctx);
        llama_model_free(model);
        return 1;
    }

    llama_token mask_token_id = llama_vocab_mask(vocab);

    GGML_ASSERT(mask_token_id != LLAMA_TOKEN_NULL);

    bool visual_mode = params.diffusion.visual_mode;

    int32_t n_generated = 0;
    std::vector<llama_token> output_tokens(params.n_ubatch);

    struct diffusion_params diff_params;

    char shift_logits_str[8];
    if (llama_model_meta_val_str(model, "diffusion.shift_logits", shift_logits_str, sizeof(shift_logits_str)) >= 0) {
        diff_params.shift_logits = (strcmp(shift_logits_str, "true") == 0);
    } else {
        diff_params.shift_logits = true;
    }

    // Use either eps or block length, but not both
    GGML_ASSERT((params.diffusion.eps == 0) ^ (params.diffusion.block_length == 0));

    if (params.diffusion.eps) {
        diff_params.schedule = TIMESTEP_BASED;
        diff_params.eps      = params.diffusion.eps;
    } else if (params.diffusion.block_length) {
        diff_params.schedule     = BLOCK_BASED;
        diff_params.block_length = params.diffusion.block_length;
    }

    diff_params.mask_token_id    = mask_token_id;
    diff_params.seed             = params.sampling.seed;
    diff_params.temperature      = params.sampling.temp;
    diff_params.steps            = params.diffusion.steps;
    diff_params.algorithm        = static_cast<diffusion_algorithm>(params.diffusion.algorithm);
    diff_params.max_length       = params.n_ubatch;
    diff_params.top_p            = params.sampling.top_p;
    diff_params.top_k            = params.sampling.top_k;
    diff_params.visual_mode      = params.diffusion.visual_mode;
    diff_params.add_gumbel_noise = params.diffusion.add_gumbel_noise;

    diff_params.step_callback           = diffusion_step_callback;
    callback_data cb_data               = { &diff_params, vocab, n_input };
    diff_params.step_callback_user_data = &cb_data;

    const char * alg_names[]   = { "ORIGIN", "ENTROPY_BASED", "MARGIN_BASED", "RANDOM", "CONFIDENCE_BASED" };
    const char * sched_names[] = { "TIMESTEP_BASED", "BLOCK_BASED" };
    const char * alg_name =
        (diff_params.algorithm >= 0 && diff_params.algorithm <= 4) ? alg_names[diff_params.algorithm] : "UNKNOWN";
    const char * sched_name =
        (diff_params.schedule >= 0 && diff_params.schedule <= 1) ? sched_names[diff_params.schedule] : "UNKNOWN";

    LOG_INF("diffusion_params: - %-25s llama_token = %d\n", "mask_token_id", mask_token_id);
    LOG_INF("diffusion_params: - %-25s u32 = %d\n", "steps", diff_params.steps);
    LOG_INF("diffusion_params: - %-25s u32 = %d\n", "max_length", diff_params.max_length);
    LOG_INF("diffusion_params: - %-25s enum = %d (%s)\n", "algorithm", diff_params.algorithm, alg_name);
    LOG_INF("diffusion_params: - %-25s enum = %d (%s)\n", "schedule", diff_params.schedule, sched_name);
    LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "temperature", diff_params.temperature);
    if (diff_params.schedule == TIMESTEP_BASED) {
        LOG_INF("diffusion_params: - %-25s f32 = %.6f\n", "eps", diff_params.eps);
        LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "alg_temp", diff_params.alg_temp);
    }
    if (diff_params.schedule == BLOCK_BASED) {
        LOG_INF("diffusion_params: - %-25s u32 = %d\n", "block_length", diff_params.block_length);
        LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "cfg_scale", diff_params.cfg_scale);
    }

    diffusion_generate(ctx, input_tokens.data(), output_tokens.data(), n_input, diff_params, n_generated);

    if (n_generated > 0) {
        if (visual_mode) {
            // Clear screen and move cursor to top-left
            LOG_INF("\033[2J\033[H");
        }

        output_tokens.erase(output_tokens.begin(), output_tokens.begin() + n_input);
        std::string output_data = common_detokenize(vocab, output_tokens, false);
        LOG_INF("\n%s\n", output_data.c_str());
    } else {
        LOG_ERR("Error: diffusion generation failed\n");
    }

    llama_free(ctx);
    llama_model_free(model);
    llama_backend_free();

    return 0;
}