1#include "ops.h"
2
3#include "ggml-cpu.h"
4#include "ggml-impl.h"
5#include "binary-ops.h"
6#include "ggml.h"
7#include "unary-ops.h"
8#include "vec.h"
9
10#include <float.h>
11#include <algorithm>
12
13// ggml_compute_forward_dup
14
15static void ggml_compute_forward_dup_same_cont(
16 const ggml_compute_params * params,
17 ggml_tensor * dst) {
18
19 const ggml_tensor * src0 = dst->src[0];
20
21 GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
22 GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
23 GGML_ASSERT(src0->type == dst->type);
24
25 const size_t nb0 = ggml_type_size(type: src0->type);
26
27 const int ith = params->ith; // thread index
28 const int nth = params->nth; // number of threads
29
30 // parallelize by blocks
31 const int nk = ggml_nelements(tensor: src0)/ggml_blck_size(type: src0->type);
32 const int dr = (nk + nth - 1) / nth;
33 const int k0 = dr * ith;
34 const int k1 = MIN(k0 + dr, nk);
35
36 if (k0 < k1) {
37 memcpy(
38 dest: ((char *) dst->data + k0*nb0),
39 src: ((char *) src0->data + k0*nb0),
40 n: (k1 - k0) * nb0);
41 }
42}
43
44template<typename src_t, typename dst_t>
45static void ggml_compute_forward_dup_flt(
46 const ggml_compute_params * params,
47 ggml_tensor * dst) {
48
49 const ggml_tensor * src0 = dst->src[0];
50
51 GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
52 GGML_ASSERT(!ggml_is_quantized(src0->type) && !ggml_is_quantized(dst->type));
53
54 GGML_TENSOR_UNARY_OP_LOCALS
55
56 const int ith = params->ith; // thread index
57 const int nth = params->nth; // number of threads
58
59 // parallelize by rows
60 const int nr = ne01;
61 // number of rows per thread
62 const int dr = (nr + nth - 1) / nth;
63 // row range for this thread
64 const int ir0 = dr * ith;
65 const int ir1 = MIN(ir0 + dr, nr);
66
67 // case: type & row size equal
68 if (src0->type == dst->type &&
69 ne00 == ne0 &&
70 nb00 == ggml_type_size(type: src0->type) && nb0 == ggml_type_size(type: dst->type)) {
71 // copy by rows
72 const size_t rs = ne00*nb00;
73 for (int64_t i03 = 0; i03 < ne03; i03++) {
74 for (int64_t i02 = 0; i02 < ne02; i02++) {
75 for (int64_t i01 = ir0; i01 < ir1; i01++) {
76 memcpy(
77 dest: ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
78 src: ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
79 n: rs);
80 }
81 }
82 }
83 return;
84 }
85
86 // case: dst tensor is contiguous
87 if (ggml_is_contiguous(tensor: dst)) {
88 if (nb00 == sizeof(src_t)) {
89 if constexpr (std::is_same_v<dst_t, src_t>) {
90 // same type
91 size_t id = 0;
92 const size_t rs = ne00 * nb00;
93 char * dst_ptr = (char *) dst->data;
94
95 for (int i03 = 0; i03 < ne03; i03++) {
96 for (int i02 = 0; i02 < ne02; i02++) {
97 id += rs * ir0;
98 for (int i01 = ir0; i01 < ir1; i01++) {
99 const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
100 memcpy(dest: dst_ptr + id, src: src0_ptr, n: rs);
101 id += rs;
102 }
103 id += rs * (ne01 - ir1);
104 }
105 }
106 } else {
107 // casting between non-quantized types
108 size_t id = 0;
109 dst_t * dst_ptr = (dst_t *) dst->data;
110
111 for (int i03 = 0; i03 < ne03; i03++) {
112 for (int i02 = 0; i02 < ne02; i02++) {
113 id += ne00 * ir0;
114 for (int i01 = ir0; i01 < ir1; i01++) {
115 const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
116 for (int i00 = 0; i00 < ne00; i00++) {
117 float tmp = type_conversion_table<src_t>::to_f32(src0_ptr[i00]);
118 dst_ptr[id] = type_conversion_table<dst_t>::from_f32(tmp);
119 id++;
120 }
121 }
122 id += ne00 * (ne01 - ir1);
123 }
124 }
125 }
126 } else {
127 //printf("%s: this is not optimal - fix me\n", __func__);
128
129 size_t id = 0;
130 dst_t * dst_ptr = (dst_t *) dst->data;
131
132 for (int i03 = 0; i03 < ne03; i03++) {
133 for (int i02 = 0; i02 < ne02; i02++) {
134 id += ne00 * ir0;
135 for (int i01 = ir0; i01 < ir1; i01++) {
136 for (int i00 = 0; i00 < ne00; i00++) {
137 const src_t * src0_ptr = (src_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
138
139 float tmp = type_conversion_table<src_t>::to_f32(*src0_ptr);
140 dst_ptr[id] = type_conversion_table<dst_t>::from_f32(tmp);
141 id++;
142 }
143 }
144 id += ne00 * (ne01 - ir1);
145 }
146 }
147 }
148 return;
149 }
150
151 // dst counters
152 int64_t i10 = 0;
153 int64_t i11 = 0;
154 int64_t i12 = 0;
155 int64_t i13 = 0;
156
157 if constexpr (std::is_same_v<dst_t, src_t>) {
158 for (int64_t i03 = 0; i03 < ne03; i03++) {
159 for (int64_t i02 = 0; i02 < ne02; i02++) {
160 i10 += ne00 * ir0;
161 while (i10 >= ne0) {
162 i10 -= ne0;
163 if (++i11 == ne1) {
164 i11 = 0;
165 if (++i12 == ne2) {
166 i12 = 0;
167 if (++i13 == ne3) {
168 i13 = 0;
169 }
170 }
171 }
172 }
173 for (int64_t i01 = ir0; i01 < ir1; i01++) {
174 for (int64_t i00 = 0; i00 < ne00; i00++) {
175 const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
176 char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
177
178 memcpy(dest: dst_ptr, src: src0_ptr, n: sizeof(dst_t));
179
180 if (++i10 == ne00) {
181 i10 = 0;
182 if (++i11 == ne01) {
183 i11 = 0;
184 if (++i12 == ne02) {
185 i12 = 0;
186 if (++i13 == ne03) {
187 i13 = 0;
188 }
189 }
190 }
191 }
192 }
193 }
194 i10 += ne00 * (ne01 - ir1);
195 while (i10 >= ne0) {
196 i10 -= ne0;
197 if (++i11 == ne1) {
198 i11 = 0;
199 if (++i12 == ne2) {
200 i12 = 0;
201 if (++i13 == ne3) {
202 i13 = 0;
203 }
204 }
205 }
206 }
207 }
208 }
209
210 } else {
211 for (int64_t i03 = 0; i03 < ne03; i03++) {
212 for (int64_t i02 = 0; i02 < ne02; i02++) {
213 i10 += ne00 * ir0;
214 while (i10 >= ne0) {
215 i10 -= ne0;
216 if (++i11 == ne1) {
217 i11 = 0;
218 if (++i12 == ne2) {
219 i12 = 0;
220 if (++i13 == ne3) {
221 i13 = 0;
222 }
223 }
224 }
225 }
226 for (int64_t i01 = ir0; i01 < ir1; i01++) {
227 for (int64_t i00 = 0; i00 < ne00; i00++) {
228 const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
229 char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
230
231 float tmp = type_conversion_table<src_t>::to_f32(*(const src_t *) src0_ptr);
232 *(dst_t *) dst_ptr = type_conversion_table<dst_t>::from_f32(tmp);
233
234 if (++i10 == ne0) {
235 i10 = 0;
236 if (++i11 == ne1) {
237 i11 = 0;
238 if (++i12 == ne2) {
239 i12 = 0;
240 if (++i13 == ne3) {
241 i13 = 0;
242 }
243 }
244 }
245 }
246 }
247 }
248 i10 += ne00 * (ne01 - ir1);
249 while (i10 >= ne0) {
250 i10 -= ne0;
251 if (++i11 == ne1) {
252 i11 = 0;
253 if (++i12 == ne2) {
254 i12 = 0;
255 if (++i13 == ne3) {
256 i13 = 0;
257 }
258 }
259 }
260 }
261 }
262 }
263 }
264}
265
266
267template<typename src_t>
268static void ggml_compute_forward_dup_to_q(
269 const ggml_compute_params * params,
270 ggml_tensor * dst) {
271
272 const ggml_tensor * src0 = dst->src[0];
273
274 GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
275 GGML_ASSERT(!ggml_is_quantized(src0->type));
276
277 GGML_TENSOR_UNARY_OP_LOCALS
278
279 const int ith = params->ith; // thread index
280 const int nth = params->nth; // number of threads
281
282 // parallelize by rows
283 const int nr = ne01;
284 // number of rows per thread
285 const int dr = (nr + nth - 1) / nth;
286 // row range for this thread
287 const int ir0 = dr * ith;
288 const int ir1 = MIN(ir0 + dr, nr);
289
290 if (ggml_is_contiguous(tensor: dst) &&
291 nb00 == sizeof(src_t) &&
292 ggml_get_type_traits_cpu(type: dst->type)->from_float) {
293 // casting non-quantized types --> intermediate f32 --> quantized
294 ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(type: dst->type)->from_float;
295 float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
296
297 size_t id = 0;
298 size_t rs = nb0 * (ne00 / ggml_blck_size(type: dst->type));
299 char * dst_ptr = (char *) dst->data;
300
301 for (int i03 = 0; i03 < ne03; i03++) {
302 for (int i02 = 0; i02 < ne02; i02++) {
303 id += rs * ir0;
304 for (int i01 = ir0; i01 < ir1; i01++) {
305 const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
306
307 for (int i00 = 0; i00 < ne00; i00++) {
308 src0_f32[i00] = type_conversion_table<src_t>::to_f32(src0_ptr[i00]);
309 }
310
311 quantize_row_q(src0_f32, dst_ptr + id, ne00);
312 id += rs;
313 }
314 id += rs * (ne01 - ir1);
315 }
316 }
317 } else {
318 // printf("%s %s\n", ggml_type_name(src0->type), ggml_type_name(dst->type));
319 GGML_ABORT("not implemented");
320 }
321}
322
323// A simplified version of ggml_compute_forward_dup that doesn't do float upcasting, and just plain old memcpy.
324static void ggml_compute_forward_dup_bytes(
325 const ggml_compute_params * params,
326 ggml_tensor * dst) {
327 const ggml_tensor * src0 = dst->src[0];
328
329 GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
330 GGML_ASSERT(src0->type == dst->type);
331
332 GGML_TENSOR_UNARY_OP_LOCALS;
333
334 if (ggml_is_contiguous(tensor: src0) && ggml_is_contiguous(tensor: dst)) {
335 ggml_compute_forward_dup_same_cont(params, dst);
336 return;
337 }
338
339 const size_t type_size = ggml_type_size(type: src0->type);
340
341 const int ith = params->ith; // thread index
342 const int nth = params->nth; // number of threads
343
344 // parallelize by rows
345 const int nr = ne01;
346 // number of rows per thread
347 const int dr = (nr + nth - 1) / nth;
348 // row range for this thread
349 const int ir0 = dr * ith;
350 const int ir1 = MIN(ir0 + dr, nr);
351
352 if (src0->type == dst->type &&
353 ggml_are_same_shape(t0: src0, t1: dst) &&
354 nb00 == type_size && nb0 == type_size) {
355 // copy by rows
356 const size_t rs = ggml_row_size(type: src0->type, ne: ne00);
357 for (int64_t i03 = 0; i03 < ne03; i03++) {
358 for (int64_t i02 = 0; i02 < ne02; i02++) {
359 for (int64_t i01 = ir0; i01 < ir1; i01++) {
360 memcpy(
361 dest: ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
362 src: ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
363 n: rs);
364 }
365 }
366 }
367 return;
368 }
369
370 if (ggml_is_contiguous(tensor: dst)) {
371 size_t id = 0;
372 char * dst_ptr = (char *) dst->data;
373 const size_t rs = ne00 * type_size;
374
375 if (nb00 == type_size) {
376 // src0 is contigous on first dimension, copy by rows
377 for (int64_t i03 = 0; i03 < ne03; i03++) {
378 for (int64_t i02 = 0; i02 < ne02; i02++) {
379 id += rs * ir0;
380 for (int64_t i01 = ir0; i01 < ir1; i01++) {
381 const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
382 memcpy(dest: dst_ptr + id, src: src0_ptr, n: rs);
383 id += rs;
384 }
385 id += rs * (ne01 - ir1);
386 }
387 }
388 } else {
389 //printf("%s: this is not optimal - fix me\n", __func__);
390
391 for (int64_t i03 = 0; i03 < ne03; i03++) {
392 for (int64_t i02 = 0; i02 < ne02; i02++) {
393 id += rs * ir0;
394 for (int64_t i01 = ir0; i01 < ir1; i01++) {
395 for (int64_t i00 = 0; i00 < ne00; i00++) {
396 const char * src0_ptr = (char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03;
397 memcpy(dest: dst_ptr + id, src: src0_ptr, n: type_size);
398
399 id += type_size;
400 }
401 }
402 id += rs * (ne01 - ir1);
403 }
404 }
405 }
406
407 return;
408 }
409
410 // dst counters
411 int64_t k10 = 0;
412 int64_t i11 = 0;
413 int64_t i12 = 0;
414 int64_t i13 = 0;
415
416 // number of blocks in a row
417 const int64_t nk00 = ne00 / ggml_blck_size(type: src0->type);
418 const int64_t nk0 = ne0 / ggml_blck_size(type: dst->type);
419
420 for (int64_t i03 = 0; i03 < ne03; i03++) {
421 for (int64_t i02 = 0; i02 < ne02; i02++) {
422 k10 += nk00 * ir0;
423 while (k10 >= nk0) {
424 k10 -= nk0;
425 if (++i11 == ne1) {
426 i11 = 0;
427 if (++i12 == ne2) {
428 i12 = 0;
429 if (++i13 == ne3) {
430 i13 = 0;
431 }
432 }
433 }
434 }
435 for (int64_t i01 = ir0; i01 < ir1; i01++) {
436 for (int64_t k00 = 0; k00 < nk00; k00++) {
437 const char * src0_ptr = ((char *) src0->data + k00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
438 char * dst_ptr = ((char *) dst->data + k10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
439
440 memcpy(dest: dst_ptr, src: src0_ptr, n: type_size);
441
442 if (++k10 == nk0) {
443 k10 = 0;
444 if (++i11 == ne1) {
445 i11 = 0;
446 if (++i12 == ne2) {
447 i12 = 0;
448 if (++i13 == ne3) {
449 i13 = 0;
450 }
451 }
452 }
453 }
454 }
455 }
456 k10 += nk00 * (ne01 - ir1);
457 while (k10 >= nk0) {
458 k10 -= nk0;
459 if (++i11 == ne1) {
460 i11 = 0;
461 if (++i12 == ne2) {
462 i12 = 0;
463 if (++i13 == ne3) {
464 i13 = 0;
465 }
466 }
467 }
468 }
469 }
470 }
471}
472
473static void ggml_compute_forward_dup_from_q(
474 const ggml_compute_params * params,
475 ggml_tensor * dst) {
476
477 const ggml_tensor * src0 = dst->src[0];
478 const ggml_tensor * src1 = dst->src[1];
479
480 GGML_TENSOR_BINARY_OP_LOCALS
481
482 const ggml_type type = src0->type;
483 ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
484
485 size_t qk = ggml_blck_size(type);
486 const int64_t nr = ggml_nelements(tensor: src1) / qk;
487
488 // destination must be contiguous in the first dimension
489 GGML_ASSERT(nb10 == ggml_type_size(dst->type));
490 // must either have first dimension large enough to hold a row, or fully contiguous
491 GGML_ASSERT((ne10 % qk) == 0 || ggml_is_contiguous(dst));
492
493 const int ith = params->ith;
494 const int nth = params->nth;
495
496 const int dr = (nr + nth - 1)/nth;
497
498 // row range for this thread
499 const int ir0 = dr*ith;
500 const int ir1 = MIN(ir0 + dr, nr);
501
502 for (int64_t ir = ir0; ir < ir1; ++ir) {
503
504 uint32_t i = ir * qk;
505
506 const int64_t i03 = i/(ne00 * ne01 * ne02);
507 const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
508 const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
509 const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
510 const int64_t x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
511
512 const int64_t i13 = i/(ne10 * ne11 * ne12);
513 const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
514 const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
515 const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
516 const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
517
518 dequantize_row_q(
519 (const void *) ((char *) src0->data + x_offset),
520 (float *) ((char *) dst->data + dst_offset), qk);
521 }
522}
523
524void ggml_compute_forward_dup(
525 const ggml_compute_params * params,
526 ggml_tensor * dst) {
527
528 const ggml_tensor * src0 = dst->src[0];
529
530 if (src0->type == dst->type) {
531 ggml_compute_forward_dup_bytes(params, dst);
532 return;
533 }
534
535 switch (src0->type) {
536 case GGML_TYPE_F16:
537 {
538 /**/ if (dst->type == GGML_TYPE_F16) ggml_compute_forward_dup_flt<ggml_fp16_t, ggml_fp16_t>(params, dst);
539 else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<ggml_fp16_t, ggml_bf16_t>(params, dst);
540 else if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt<ggml_fp16_t, float >(params, dst);
541 else ggml_compute_forward_dup_to_q<ggml_fp16_t>(params, dst);
542 } break;
543 case GGML_TYPE_BF16:
544 {
545 /**/ if (dst->type == GGML_TYPE_F16) ggml_compute_forward_dup_flt<ggml_bf16_t, ggml_fp16_t>(params, dst);
546 else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<ggml_bf16_t, ggml_bf16_t>(params, dst);
547 else if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt<ggml_bf16_t, float >(params, dst);
548 else ggml_compute_forward_dup_to_q<ggml_bf16_t>(params, dst);
549 } break;
550 case GGML_TYPE_F32:
551 {
552 /**/ if (dst->type == GGML_TYPE_F16) ggml_compute_forward_dup_flt<float, ggml_fp16_t>(params, dst);
553 else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<float, ggml_bf16_t>(params, dst);
554 else if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt<float, float >(params, dst);
555 else if (dst->type == GGML_TYPE_I32) ggml_compute_forward_dup_flt<float, int32_t >(params, dst);
556 else ggml_compute_forward_dup_to_q<float>(params, dst);
557 } break;
558 case GGML_TYPE_I32:
559 {
560 if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt<int32_t, float>(params, dst);
561 else GGML_ABORT("not implemented");
562 } break;
563 default:
564 {
565 if (ggml_is_quantized(type: src0->type) && dst->type == GGML_TYPE_F32) {
566 ggml_compute_forward_dup_from_q(params, dst);
567 break;
568 }
569 GGML_ABORT("fatal error");
570 }
571 }
572}
573
574// ggml_compute_forward_add
575
576static void ggml_compute_forward_add_q_f32(
577 const ggml_compute_params * params,
578 ggml_tensor * dst) {
579
580 const ggml_tensor * src0 = dst->src[0];
581 const ggml_tensor * src1 = dst->src[1];
582
583 GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
584
585 const int nr = ggml_nrows(tensor: src0);
586
587 GGML_TENSOR_BINARY_OP_LOCALS
588
589 const int ith = params->ith;
590 const int nth = params->nth;
591
592 const ggml_type type = src0->type;
593 const ggml_type dtype = dst->type;
594 ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
595 ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(type: dtype)->from_float;
596
597 // we don't support permuted src0 or src1
598 GGML_ASSERT(nb00 == ggml_type_size(type));
599 GGML_ASSERT(nb10 == sizeof(float));
600
601 // dst cannot be transposed or permuted
602 GGML_ASSERT(nb0 <= nb1);
603 GGML_ASSERT(nb1 <= nb2);
604 GGML_ASSERT(nb2 <= nb3);
605
606 GGML_ASSERT(ggml_is_quantized(src0->type));
607 GGML_ASSERT(src1->type == GGML_TYPE_F32);
608
609 // rows per thread
610 const int dr = (nr + nth - 1)/nth;
611
612 // row range for this thread
613 const int ir0 = dr*ith;
614 const int ir1 = MIN(ir0 + dr, nr);
615
616 float * wdata = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
617
618 for (int ir = ir0; ir < ir1; ++ir) {
619 // src0 indices
620 const int i03 = ir/(ne02*ne01);
621 const int i02 = (ir - i03*ne02*ne01)/ne01;
622 const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
623
624 // src1 and dst are same shape as src0 => same indices
625 const int i13 = i03;
626 const int i12 = i02;
627 const int i11 = i01;
628
629 const int i3 = i03;
630 const int i2 = i02;
631 const int i1 = i01;
632
633 void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
634 float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
635 void * dst_row = (void *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
636
637 assert(ne00 % 32 == 0);
638
639 // unquantize row from src0 to temp buffer
640 dequantize_row_q(src0_row, wdata, ne00);
641 // add src1
642 ggml_vec_acc_f32(n: ne00, y: wdata, x: src1_row);
643 // quantize row to dst
644 if (quantize_row_q != NULL) {
645 quantize_row_q(wdata, dst_row, ne00);
646 } else {
647 memcpy(dest: dst_row, src: wdata, n: ne0*nb0);
648 }
649 }
650}
651
652void ggml_compute_forward_add(
653 const ggml_compute_params * params,
654 ggml_tensor * dst) {
655
656 const ggml_tensor * src0 = dst->src[0];
657
658 switch (src0->type) {
659 case GGML_TYPE_F32:
660 case GGML_TYPE_F16:
661 case GGML_TYPE_BF16:
662 {
663 ggml_compute_forward_add_non_quantized(params, dst);
664 } break;
665 case GGML_TYPE_Q4_0:
666 case GGML_TYPE_Q4_1:
667 case GGML_TYPE_Q5_0:
668 case GGML_TYPE_Q5_1:
669 case GGML_TYPE_Q8_0:
670 case GGML_TYPE_MXFP4:
671 case GGML_TYPE_Q2_K:
672 case GGML_TYPE_Q3_K:
673 case GGML_TYPE_Q4_K:
674 case GGML_TYPE_Q5_K:
675 case GGML_TYPE_Q6_K:
676 case GGML_TYPE_TQ1_0:
677 case GGML_TYPE_TQ2_0:
678 case GGML_TYPE_IQ2_XXS:
679 case GGML_TYPE_IQ2_XS:
680 case GGML_TYPE_IQ3_XXS:
681 case GGML_TYPE_IQ1_S:
682 case GGML_TYPE_IQ1_M:
683 case GGML_TYPE_IQ4_NL:
684 case GGML_TYPE_IQ4_XS:
685 case GGML_TYPE_IQ3_S:
686 case GGML_TYPE_IQ2_S:
687 {
688 ggml_compute_forward_add_q_f32(params, dst);
689 } break;
690 default:
691 {
692 GGML_ABORT("fatal error");
693 }
694 }
695}
696
697// ggml_compute_forward_add_id
698
699static void ggml_compute_forward_add_id_f32(
700 const ggml_compute_params * params,
701 ggml_tensor * dst) {
702
703 const ggml_tensor * src0 = dst->src[0];
704 const ggml_tensor * src1 = dst->src[1];
705 const ggml_tensor * src2 = dst->src[2];
706
707 GGML_ASSERT(dst->type == GGML_TYPE_F32);
708 GGML_ASSERT(src0->type == GGML_TYPE_F32);
709 GGML_ASSERT(src1->type == GGML_TYPE_F32);
710 GGML_ASSERT(src2->type == GGML_TYPE_I32);
711
712 GGML_ASSERT(src0->nb[0] == sizeof(float));
713 GGML_ASSERT(src1->nb[0] == sizeof(float));
714
715 const int ith = params->ith;
716 const int nth = params->nth;
717
718 const int nr = ggml_nrows(tensor: src0);
719
720 GGML_TENSOR_TERNARY_OP_LOCALS
721
722 GGML_ASSERT( nb0 == sizeof(float));
723 GGML_ASSERT(nb10 == sizeof(float));
724
725 // rows per thread
726 const int dr = (nr + nth - 1)/nth;
727
728 // row range for this thread
729 const int ir0 = dr*ith;
730 const int ir1 = MIN(ir0 + dr, nr);
731
732 for (int ir = ir0; ir < ir1; ++ir) {
733 // src0 indices
734 const int i3 = ir/(ne2*ne1);
735 const int i2 = (ir - i3*ne2*ne1)/ne1;
736 const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
737
738 // src1 indices
739 const int i11 = *(int32_t *) ((char *) src2->data + i1*nb20 + i2*nb21);
740
741 GGML_ASSERT(i11 >= 0 && i11 < ne11);
742
743 ggml_vec_add_f32(n: ne0,
744 z: (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
745 x: (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
746 y: (float *) ((char *) src1->data + i11*nb11));
747 }
748}
749
750void ggml_compute_forward_add_id(
751 const ggml_compute_params * params,
752 ggml_tensor * dst) {
753
754 const ggml_tensor * src0 = dst->src[0];
755
756 switch (src0->type) {
757 case GGML_TYPE_F32:
758 {
759 ggml_compute_forward_add_id_f32(params, dst);
760 } break;
761 default:
762 {
763 GGML_ABORT("unsupported type for ggml_compute_forward_add_id: %s", ggml_type_name(src0->type));
764 }
765 }
766}
767
768// ggml_compute_forward_add1
769
770static void ggml_compute_forward_add1_f32(
771 const ggml_compute_params * params,
772 ggml_tensor * dst) {
773
774 const ggml_tensor * src0 = dst->src[0];
775 const ggml_tensor * src1 = dst->src[1];
776
777 GGML_ASSERT(ggml_are_same_shape(src0, dst));
778 GGML_ASSERT(ggml_is_scalar(src1));
779
780 const int ith = params->ith;
781 const int nth = params->nth;
782
783 const int nr = ggml_nrows(tensor: src0);
784
785 GGML_TENSOR_UNARY_OP_LOCALS
786
787 GGML_ASSERT( nb0 == sizeof(float));
788 GGML_ASSERT(nb00 == sizeof(float));
789
790 // rows per thread
791 const int dr = (nr + nth - 1)/nth;
792
793 // row range for this thread
794 const int ir0 = dr*ith;
795 const int ir1 = MIN(ir0 + dr, nr);
796
797 for (int ir = ir0; ir < ir1; ++ir) {
798 // src0 and dst are same shape => same indices
799 const int i3 = ir/(ne2*ne1);
800 const int i2 = (ir - i3*ne2*ne1)/ne1;
801 const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
802
803#ifdef GGML_USE_ACCELERATE
804 GGML_UNUSED(ggml_vec_add1_f32);
805
806 vDSP_vadd(
807 (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
808 (float *) ((char *) src1->data), 0,
809 (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1,
810 ne0);
811#else
812 ggml_vec_add1_f32(n: ne0,
813 z: (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
814 x: (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
815 v: *(float *) src1->data);
816#endif
817 }
818}
819
820static void ggml_compute_forward_add1_f16_f32(
821 const ggml_compute_params * params,
822 ggml_tensor * dst) {
823
824 const ggml_tensor * src0 = dst->src[0];
825 const ggml_tensor * src1 = dst->src[1];
826
827 GGML_ASSERT(ggml_are_same_shape(src0, dst));
828 GGML_ASSERT(ggml_is_scalar(src1));
829
830 // scalar to add
831 const float v = *(float *) src1->data;
832
833 const int ith = params->ith;
834 const int nth = params->nth;
835
836 const int nr = ggml_nrows(tensor: src0);
837
838 GGML_TENSOR_UNARY_OP_LOCALS
839
840 GGML_ASSERT(src0->type == GGML_TYPE_F16);
841 GGML_ASSERT(src1->type == GGML_TYPE_F32);
842 GGML_ASSERT(dst->type == GGML_TYPE_F16);
843
844 GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
845 GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
846
847 // rows per thread
848 const int dr = (nr + nth - 1)/nth;
849
850 // row range for this thread
851 const int ir0 = dr*ith;
852 const int ir1 = MIN(ir0 + dr, nr);
853
854 for (int ir = ir0; ir < ir1; ++ir) {
855 // src0 and dst are same shape => same indices
856 const int i3 = ir/(ne2*ne1);
857 const int i2 = (ir - i3*ne2*ne1)/ne1;
858 const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
859
860 ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
861 ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
862 for (int i = 0; i < ne0; i++) {
863 dst_ptr[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(src0_ptr[i]) + v);
864 }
865 }
866}
867
868static void ggml_compute_forward_add1_f16_f16(
869 const ggml_compute_params * params,
870 ggml_tensor * dst) {
871
872 const ggml_tensor * src0 = dst->src[0];
873 const ggml_tensor * src1 = dst->src[1];
874
875 GGML_ASSERT(ggml_are_same_shape(src0, dst));
876 GGML_ASSERT(ggml_is_scalar(src1));
877
878 // scalar to add
879 const float v = GGML_CPU_FP16_TO_FP32(*(ggml_fp16_t *) src1->data);
880
881 const int ith = params->ith;
882 const int nth = params->nth;
883
884 const int nr = ggml_nrows(tensor: src0);
885
886 GGML_TENSOR_UNARY_OP_LOCALS
887
888 GGML_ASSERT(src0->type == GGML_TYPE_F16);
889 GGML_ASSERT(src1->type == GGML_TYPE_F16);
890 GGML_ASSERT(dst->type == GGML_TYPE_F16);
891
892 GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
893 GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
894
895 // rows per thread
896 const int dr = (nr + nth - 1)/nth;
897
898 // row range for this thread
899 const int ir0 = dr*ith;
900 const int ir1 = MIN(ir0 + dr, nr);
901
902 for (int ir = ir0; ir < ir1; ++ir) {
903 // src0 and dst are same shape => same indices
904 const int i3 = ir/(ne2*ne1);
905 const int i2 = (ir - i3*ne2*ne1)/ne1;
906 const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
907
908 ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
909 ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
910 for (int i = 0; i < ne0; i++) {
911 dst_ptr[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(src0_ptr[i]) + v);
912 }
913 }
914}
915
916static void ggml_compute_forward_add1_q_f32(
917 const ggml_compute_params * params,
918 ggml_tensor * dst) {
919
920 const ggml_tensor * src0 = dst->src[0];
921 const ggml_tensor * src1 = dst->src[1];
922
923 GGML_ASSERT(ggml_are_same_shape(src0, dst));
924 GGML_ASSERT(ggml_is_scalar(src1));
925
926 // scalar to add
927 const float v = *(float *) src1->data;
928
929 const int ith = params->ith;
930 const int nth = params->nth;
931
932 const int nr = ggml_nrows(tensor: src0);
933
934 GGML_TENSOR_UNARY_OP_LOCALS
935
936 const ggml_type type = src0->type;
937 ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
938 ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(type)->from_float;
939
940 // we don't support permuted src0
941 GGML_ASSERT(nb00 == ggml_type_size(type));
942
943 // dst cannot be transposed or permuted
944 GGML_ASSERT(nb0 <= nb1);
945 GGML_ASSERT(nb1 <= nb2);
946 GGML_ASSERT(nb2 <= nb3);
947
948 GGML_ASSERT(ggml_is_quantized(src0->type));
949 GGML_ASSERT(dst->type == src0->type);
950 GGML_ASSERT(src1->type == GGML_TYPE_F32);
951
952 // rows per thread
953 const int dr = (nr + nth - 1)/nth;
954
955 // row range for this thread
956 const int ir0 = dr*ith;
957 const int ir1 = MIN(ir0 + dr, nr);
958
959 float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
960
961 for (int ir = ir0; ir < ir1; ++ir) {
962 // src0 and dst are same shape => same indices
963 const int i3 = ir/(ne2*ne1);
964 const int i2 = (ir - i3*ne2*ne1)/ne1;
965 const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
966
967 void * src0_row = (void *) ((char *) src0->data + (i1*nb01 + i2*nb02 + i3*nb03));
968 void * dst_row = (void *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb0 ));
969
970 assert(ne0 % 32 == 0);
971
972 // unquantize row from src0 to temp buffer
973 dequantize_row_q(src0_row, wdata, ne0);
974 // add src1
975 ggml_vec_acc1_f32(n: ne0, y: wdata, v);
976 // quantize row to dst
977 quantize_row_q(wdata, dst_row, ne0);
978 }
979}
980
981static void ggml_compute_forward_add1_bf16_f32(
982 const ggml_compute_params * params,
983 ggml_tensor * dst) {
984
985 const ggml_tensor * src0 = dst->src[0];
986 const ggml_tensor * src1 = dst->src[1];
987
988 GGML_ASSERT(ggml_are_same_shape(src0, dst));
989 GGML_ASSERT(ggml_is_scalar(src1));
990
991 // scalar to add
992 const float v = *(float *) src1->data;
993
994 const int ith = params->ith;
995 const int nth = params->nth;
996
997 const int nr = ggml_nrows(tensor: src0);
998
999 GGML_TENSOR_UNARY_OP_LOCALS
1000
1001 GGML_ASSERT(src0->type == GGML_TYPE_BF16);
1002 GGML_ASSERT(src1->type == GGML_TYPE_F32);
1003 GGML_ASSERT(dst->type == GGML_TYPE_BF16);
1004
1005 GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
1006 GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
1007
1008 // rows per thread
1009 const int dr = (nr + nth - 1)/nth;
1010
1011 // row range for this thread
1012 const int ir0 = dr*ith;
1013 const int ir1 = MIN(ir0 + dr, nr);
1014
1015 for (int ir = ir0; ir < ir1; ++ir) {
1016 // src0 and dst are same shape => same indices
1017 const int i3 = ir/(ne2*ne1);
1018 const int i2 = (ir - i3*ne2*ne1)/ne1;
1019 const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
1020
1021 ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
1022 ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
1023 for (int i = 0; i < ne0; i++) {
1024 dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
1025 }
1026 }
1027}
1028
1029static void ggml_compute_forward_add1_bf16_bf16(
1030 const ggml_compute_params * params,
1031 ggml_tensor * dst) {
1032
1033 const ggml_tensor * src0 = dst->src[0];
1034 const ggml_tensor * src1 = dst->src[1];
1035
1036 GGML_ASSERT(ggml_are_same_shape(src0, dst));
1037 GGML_ASSERT(ggml_is_scalar(src1));
1038
1039 // scalar to add
1040 const float v = GGML_BF16_TO_FP32(*(ggml_bf16_t *) src1->data);
1041
1042 const int ith = params->ith;
1043 const int nth = params->nth;
1044
1045 const int nr = ggml_nrows(tensor: src0);
1046
1047 GGML_TENSOR_UNARY_OP_LOCALS
1048
1049 GGML_ASSERT(src0->type == GGML_TYPE_BF16);
1050 GGML_ASSERT(src1->type == GGML_TYPE_BF16);
1051 GGML_ASSERT(dst->type == GGML_TYPE_BF16);
1052
1053 GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
1054 GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
1055
1056 // rows per thread
1057 const int dr = (nr + nth - 1)/nth;
1058
1059 // row range for this thread
1060 const int ir0 = dr*ith;
1061 const int ir1 = MIN(ir0 + dr, nr);
1062
1063 for (int ir = ir0; ir < ir1; ++ir) {
1064 // src0 and dst are same shape => same indices
1065 const int i3 = ir/(ne2*ne1);
1066 const int i2 = (ir - i3*ne2*ne1)/ne1;
1067 const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
1068
1069 ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
1070 ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
1071 for (int i = 0; i < ne0; i++) {
1072 dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
1073 }
1074 }
1075}
1076
1077void ggml_compute_forward_add1(
1078 const ggml_compute_params * params,
1079 ggml_tensor * dst) {
1080
1081 const ggml_tensor * src0 = dst->src[0];
1082 const ggml_tensor * src1 = dst->src[1];
1083
1084 switch (src0->type) {
1085 case GGML_TYPE_F32:
1086 {
1087 ggml_compute_forward_add1_f32(params, dst);
1088 } break;
1089 case GGML_TYPE_F16:
1090 {
1091 if (src1->type == GGML_TYPE_F16) {
1092 ggml_compute_forward_add1_f16_f16(params, dst);
1093 }
1094 else if (src1->type == GGML_TYPE_F32) {
1095 ggml_compute_forward_add1_f16_f32(params, dst);
1096 }
1097 else {
1098 GGML_ABORT("fatal error");
1099 }
1100 } break;
1101 case GGML_TYPE_BF16:
1102 {
1103 if (src1->type == GGML_TYPE_BF16) {
1104 ggml_compute_forward_add1_bf16_bf16(params, dst);
1105 }
1106 else if (src1->type == GGML_TYPE_F32) {
1107 ggml_compute_forward_add1_bf16_f32(params, dst);
1108 }
1109 else {
1110 GGML_ABORT("fatal error");
1111 }
1112 } break;
1113 case GGML_TYPE_Q4_0:
1114 case GGML_TYPE_Q4_1:
1115 case GGML_TYPE_Q5_0:
1116 case GGML_TYPE_Q5_1:
1117 case GGML_TYPE_Q8_0:
1118 case GGML_TYPE_Q8_1:
1119 case GGML_TYPE_MXFP4:
1120 case GGML_TYPE_Q2_K:
1121 case GGML_TYPE_Q3_K:
1122 case GGML_TYPE_Q4_K:
1123 case GGML_TYPE_Q5_K:
1124 case GGML_TYPE_Q6_K:
1125 case GGML_TYPE_TQ1_0:
1126 case GGML_TYPE_TQ2_0:
1127 case GGML_TYPE_IQ2_XXS:
1128 case GGML_TYPE_IQ2_XS:
1129 case GGML_TYPE_IQ3_XXS:
1130 case GGML_TYPE_IQ1_S:
1131 case GGML_TYPE_IQ1_M:
1132 case GGML_TYPE_IQ4_NL:
1133 case GGML_TYPE_IQ4_XS:
1134 case GGML_TYPE_IQ3_S:
1135 case GGML_TYPE_IQ2_S:
1136 {
1137 ggml_compute_forward_add1_q_f32(params, dst);
1138 } break;
1139 default:
1140 {
1141 GGML_ABORT("fatal error");
1142 }
1143 }
1144}
1145
1146// ggml_compute_forward_acc
1147
1148static void ggml_compute_forward_acc_f32(
1149 const ggml_compute_params * params,
1150 ggml_tensor * dst) {
1151
1152 const ggml_tensor * src0 = dst->src[0];
1153 const ggml_tensor * src1 = dst->src[1];
1154
1155 GGML_ASSERT(ggml_are_same_shape(src0, dst));
1156 GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
1157
1158 // view src0 and dst with these strides and data offset inbytes during acc
1159 // nb0 is implicitly element_size because src0 and dst are contiguous
1160 size_t nb1 = ((int32_t *) dst->op_params)[0];
1161 size_t nb2 = ((int32_t *) dst->op_params)[1];
1162 size_t nb3 = ((int32_t *) dst->op_params)[2];
1163 size_t offset = ((int32_t *) dst->op_params)[3];
1164 bool inplace = (bool) ((int32_t *) dst->op_params)[4];
1165
1166 if (!inplace) {
1167 if (params->ith == 0) {
1168 // memcpy needs to be synchronized across threads to avoid race conditions.
1169 // => do it in INIT phase
1170 memcpy(
1171 dest: ((char *) dst->data),
1172 src: ((char *) src0->data),
1173 n: ggml_nbytes(tensor: dst));
1174 }
1175 ggml_barrier(tp: params->threadpool);
1176 }
1177
1178 const int ith = params->ith;
1179 const int nth = params->nth;
1180
1181 const int nr = ggml_nrows(tensor: src1);
1182 const int nc = src1->ne[0];
1183
1184 GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
1185 GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
1186
1187 // src0 and dst as viewed during acc
1188 const size_t nb0 = ggml_element_size(tensor: src0);
1189
1190 const size_t nb00 = nb0;
1191 const size_t nb01 = nb1;
1192 const size_t nb02 = nb2;
1193 const size_t nb03 = nb3;
1194
1195 GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb0 + (ne11 == 0 ? 0 : ne11-1)*nb1 + (ne12 == 0 ? 0 : ne12-1)*nb2 + (ne13 == 0 ? 0 : ne13-1)*nb3 < ggml_nbytes(dst));
1196 GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb00 + (ne11 == 0 ? 0 : ne11-1)*nb01 + (ne12 == 0 ? 0 : ne12-1)*nb02 + (ne13 == 0 ? 0 : ne13-1)*nb03 < ggml_nbytes(src0));
1197
1198 GGML_ASSERT(nb10 == sizeof(float));
1199
1200 // rows per thread
1201 const int dr = (nr + nth - 1)/nth;
1202
1203 // row range for this thread
1204 const int ir0 = dr*ith;
1205 const int ir1 = MIN(ir0 + dr, nr);
1206
1207 for (int ir = ir0; ir < ir1; ++ir) {
1208 // src0 and dst are viewed with shape of src1 and offset
1209 // => same indices
1210 const int i3 = ir/(ne12*ne11);
1211 const int i2 = (ir - i3*ne12*ne11)/ne11;
1212 const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
1213
1214#ifdef GGML_USE_ACCELERATE
1215 vDSP_vadd(
1216 (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset), 1,
1217 (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
1218 (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset), 1, nc);
1219#else
1220 ggml_vec_add_f32(n: nc,
1221 z: (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset),
1222 x: (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset),
1223 y: (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
1224#endif
1225 }
1226}
1227
1228void ggml_compute_forward_acc(
1229 const ggml_compute_params * params,
1230 ggml_tensor * dst) {
1231
1232 const ggml_tensor * src0 = dst->src[0];
1233
1234 switch (src0->type) {
1235 case GGML_TYPE_F32:
1236 {
1237 ggml_compute_forward_acc_f32(params, dst);
1238 } break;
1239 case GGML_TYPE_F16:
1240 case GGML_TYPE_BF16:
1241 case GGML_TYPE_Q4_0:
1242 case GGML_TYPE_Q4_1:
1243 case GGML_TYPE_Q5_0:
1244 case GGML_TYPE_Q5_1:
1245 case GGML_TYPE_Q8_0:
1246 case GGML_TYPE_Q8_1:
1247 case GGML_TYPE_MXFP4:
1248 case GGML_TYPE_Q2_K:
1249 case GGML_TYPE_Q3_K:
1250 case GGML_TYPE_Q4_K:
1251 case GGML_TYPE_Q5_K:
1252 case GGML_TYPE_Q6_K:
1253 case GGML_TYPE_TQ1_0:
1254 case GGML_TYPE_TQ2_0:
1255 case GGML_TYPE_IQ2_XXS:
1256 case GGML_TYPE_IQ2_XS:
1257 case GGML_TYPE_IQ3_XXS:
1258 case GGML_TYPE_IQ1_S:
1259 case GGML_TYPE_IQ1_M:
1260 case GGML_TYPE_IQ4_NL:
1261 case GGML_TYPE_IQ4_XS:
1262 case GGML_TYPE_IQ3_S:
1263 case GGML_TYPE_IQ2_S:
1264 default:
1265 {
1266 GGML_ABORT("fatal error");
1267 }
1268 }
1269}
1270
1271// ggml_compute_forward_sum
1272
1273static void ggml_compute_forward_sum_f32(
1274 const ggml_compute_params * params,
1275 ggml_tensor * dst) {
1276
1277 const ggml_tensor * src0 = dst->src[0];
1278
1279 if (params->ith != 0) {
1280 return;
1281 }
1282
1283 assert(ggml_is_scalar(dst));
1284 assert(src0->nb[0] == sizeof(float));
1285
1286 GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
1287 GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
1288
1289 ggml_float sum = 0;
1290 ggml_float row_sum = 0;
1291
1292 for (int64_t i03 = 0; i03 < ne03; i03++) {
1293 for (int64_t i02 = 0; i02 < ne02; i02++) {
1294 for (int64_t i01 = 0; i01 < ne01; i01++) {
1295 ggml_vec_sum_f32_ggf(n: ne00,
1296 s: &row_sum,
1297 x: (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
1298 sum += row_sum;
1299 }
1300 }
1301 }
1302 ((float *) dst->data)[0] = sum;
1303}
1304
1305static void ggml_compute_forward_sum_f16(
1306 const ggml_compute_params * params,
1307 ggml_tensor * dst) {
1308
1309 const ggml_tensor * src0 = dst->src[0];
1310
1311 if (params->ith != 0) {
1312 return;
1313 }
1314
1315 assert(ggml_is_scalar(dst));
1316
1317 assert(src0->nb[0] == sizeof(ggml_fp16_t));
1318
1319 GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
1320 GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
1321
1322 float sum = 0;
1323 float row_sum = 0;
1324
1325 for (int64_t i03 = 0; i03 < ne03; i03++) {
1326 for (int64_t i02 = 0; i02 < ne02; i02++) {
1327 for (int64_t i01 = 0; i01 < ne01; i01++) {
1328 ggml_vec_sum_f16_ggf(n: ne00,
1329 s: &row_sum,
1330 x: (ggml_fp16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03));
1331 sum += row_sum;
1332 }
1333 }
1334 }
1335 ((ggml_fp16_t *) dst->data)[0] = GGML_CPU_FP32_TO_FP16(sum);
1336}
1337
1338static void ggml_compute_forward_sum_bf16(
1339 const ggml_compute_params * params,
1340 ggml_tensor * dst) {
1341
1342 const ggml_tensor * src0 = dst->src[0];
1343
1344 if (params->ith != 0) {
1345 return;
1346 }
1347
1348 assert(ggml_is_scalar(dst));
1349
1350 assert(src0->nb[0] == sizeof(ggml_bf16_t));
1351
1352 GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
1353 GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
1354
1355 float sum = 0;
1356 float row_sum = 0;
1357
1358 for (int64_t i03 = 0; i03 < ne03; i03++) {
1359 for (int64_t i02 = 0; i02 < ne02; i02++) {
1360 for (int64_t i01 = 0; i01 < ne01; i01++) {
1361 ggml_vec_sum_bf16_ggf(n: ne00,
1362 s: &row_sum,
1363 x: (ggml_bf16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03));
1364 sum += row_sum;
1365 }
1366 }
1367 }
1368 ((ggml_bf16_t *) dst->data)[0] = GGML_FP32_TO_BF16(sum);
1369}
1370
1371void ggml_compute_forward_sum(
1372 const ggml_compute_params * params,
1373 ggml_tensor * dst) {
1374
1375 const ggml_tensor * src0 = dst->src[0];
1376
1377 switch (src0->type) {
1378 case GGML_TYPE_F32:
1379 {
1380 ggml_compute_forward_sum_f32(params, dst);
1381 } break;
1382 case GGML_TYPE_F16:
1383 {
1384 ggml_compute_forward_sum_f16(params, dst);
1385 } break;
1386 case GGML_TYPE_BF16:
1387 {
1388 ggml_compute_forward_sum_bf16(params, dst);
1389 } break;
1390 default:
1391 {
1392 GGML_ABORT("fatal error");
1393 }
1394 }
1395}
1396
1397// ggml_compute_forward_sum_rows
1398
1399static void ggml_compute_forward_sum_rows_f32(
1400 const ggml_compute_params * params,
1401 ggml_tensor * dst) {
1402
1403 const ggml_tensor * src0 = dst->src[0];
1404
1405 if (params->ith != 0) {
1406 return;
1407 }
1408
1409 GGML_ASSERT(src0->nb[0] == sizeof(float));
1410 GGML_ASSERT(dst->nb[0] == sizeof(float));
1411
1412 GGML_TENSOR_UNARY_OP_LOCALS
1413
1414 GGML_ASSERT(ne0 == 1);
1415 GGML_ASSERT(ne1 == ne01);
1416 GGML_ASSERT(ne2 == ne02);
1417 GGML_ASSERT(ne3 == ne03);
1418
1419 for (int64_t i3 = 0; i3 < ne03; i3++) {
1420 for (int64_t i2 = 0; i2 < ne02; i2++) {
1421 for (int64_t i1 = 0; i1 < ne01; i1++) {
1422 float * src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03);
1423 float * dst_row = (float *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3);
1424 float row_sum = 0;
1425 ggml_vec_sum_f32(n: ne00, s: &row_sum, x: src_row);
1426 dst_row[0] = row_sum;
1427 }
1428 }
1429 }
1430}
1431
1432void ggml_compute_forward_sum_rows(
1433 const ggml_compute_params * params,
1434 ggml_tensor * dst) {
1435
1436 const ggml_tensor * src0 = dst->src[0];
1437
1438 switch (src0->type) {
1439 case GGML_TYPE_F32:
1440 {
1441 ggml_compute_forward_sum_rows_f32(params, dst);
1442 } break;
1443 default:
1444 {
1445 GGML_ABORT("fatal error");
1446 }
1447 }
1448}
1449
1450// ggml_compute_forward_mean
1451
1452static void ggml_compute_forward_mean_f32(
1453 const ggml_compute_params * params,
1454 ggml_tensor * dst) {
1455
1456 const ggml_tensor * src0 = dst->src[0];
1457
1458 if (params->ith != 0) {
1459 return;
1460 }
1461
1462 assert(src0->nb[0] == sizeof(float));
1463
1464 GGML_TENSOR_UNARY_OP_LOCALS
1465
1466 assert(ne0 == 1);
1467 assert(ne1 == ne01);
1468 assert(ne2 == ne02);
1469 assert(ne3 == ne03);
1470
1471 GGML_UNUSED(ne0);
1472 GGML_UNUSED(ne1);
1473 GGML_UNUSED(ne2);
1474 GGML_UNUSED(ne3);
1475
1476 for (int64_t i03 = 0; i03 < ne03; i03++) {
1477 for (int64_t i02 = 0; i02 < ne02; i02++) {
1478 for (int64_t i01 = 0; i01 < ne01; i01++) {
1479 ggml_vec_sum_f32(n: ne00,
1480 s: (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
1481 x: (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
1482
1483 *(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3) /= (float) ne00;
1484 }
1485 }
1486 }
1487}
1488
1489void ggml_compute_forward_mean(
1490 const ggml_compute_params * params,
1491 ggml_tensor * dst) {
1492
1493 const ggml_tensor * src0 = dst->src[0];
1494
1495 switch (src0->type) {
1496 case GGML_TYPE_F32:
1497 {
1498 ggml_compute_forward_mean_f32(params, dst);
1499 } break;
1500 default:
1501 {
1502 GGML_ABORT("fatal error");
1503 }
1504 }
1505}
1506
1507// ggml_compute_forward_argmax
1508
1509static void ggml_compute_forward_argmax_f32(
1510 const ggml_compute_params * params,
1511 ggml_tensor * dst) {
1512
1513 const ggml_tensor * src0 = dst->src[0];
1514
1515 if (params->ith != 0) {
1516 return;
1517 }
1518
1519 assert(src0->nb[0] == sizeof(float));
1520 assert(dst->nb[0] == sizeof(float));
1521
1522 const int64_t ne00 = src0->ne[0];
1523 const int64_t ne01 = src0->ne[1];
1524
1525 const size_t nb01 = src0->nb[1];
1526 const size_t nb0 = dst->nb[0];
1527
1528 for (int64_t i1 = 0; i1 < ne01; i1++) {
1529 float * src = (float *) ((char *) src0->data + i1*nb01);
1530 int32_t * dst_ = (int32_t *) ((char *) dst->data + i1*nb0);
1531 int v = 0;
1532 ggml_vec_argmax_f32(n: ne00, s: &v, x: src);
1533 dst_[0] = v;
1534 }
1535}
1536
1537void ggml_compute_forward_argmax(
1538 const ggml_compute_params * params,
1539 ggml_tensor * dst) {
1540
1541 const ggml_tensor * src0 = dst->src[0];
1542
1543 switch (src0->type) {
1544 case GGML_TYPE_F32:
1545 {
1546 ggml_compute_forward_argmax_f32(params, dst);
1547 } break;
1548 default:
1549 {
1550 GGML_ABORT("fatal error");
1551 }
1552 }
1553}
1554
1555// ggml_compute_forward_count_equal
1556
1557static void ggml_compute_forward_count_equal_i32(
1558 const ggml_compute_params * params,
1559 ggml_tensor * dst) {
1560
1561 const ggml_tensor * src0 = dst->src[0];
1562 const ggml_tensor * src1 = dst->src[1];
1563
1564 GGML_TENSOR_BINARY_OP_LOCALS;
1565
1566 GGML_ASSERT(src0->type == GGML_TYPE_I32);
1567 GGML_ASSERT(src1->type == GGML_TYPE_I32);
1568 GGML_ASSERT(ggml_are_same_shape(src0, src1));
1569 GGML_ASSERT(ggml_is_scalar(dst));
1570 GGML_ASSERT(dst->type == GGML_TYPE_I64);
1571
1572 const int64_t nr = ggml_nrows(tensor: src0);
1573
1574 const int ith = params->ith;
1575 const int nth = params->nth;
1576
1577 int64_t * sums = (int64_t *) params->wdata;
1578 int64_t sum_thread = 0;
1579
1580 // rows per thread
1581 const int64_t dr = (nr + nth - 1)/nth;
1582
1583 // row range for this thread
1584 const int64_t ir0 = dr*ith;
1585 const int64_t ir1 = MIN(ir0 + dr, nr);
1586
1587 for (int64_t ir = ir0; ir < ir1; ++ir) {
1588 const int64_t i03 = ir / (ne02*ne01);
1589 const int64_t i02 = (ir - i03*ne03) / ne01;
1590 const int64_t i01 = ir - i03*ne03 - i02*ne02;
1591
1592 const char * data0 = (const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01;
1593 const char * data1 = (const char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11;
1594
1595 for (int64_t i00 = 0; i00 < ne00; ++i00) {
1596 const int32_t val0 = *((const int32_t *) (data0 + i00*nb00));
1597 const int32_t val1 = *((const int32_t *) (data1 + i00*nb10));
1598
1599 sum_thread += val0 == val1;
1600 }
1601 }
1602 if (ith != 0) {
1603 sums[ith] = sum_thread;
1604 }
1605 ggml_barrier(tp: params->threadpool);
1606
1607 if (ith != 0) {
1608 return;
1609 }
1610
1611 for (int ith_other = 1; ith_other < nth; ++ith_other) {
1612 sum_thread += sums[ith_other];
1613 }
1614 *((int64_t *) dst->data) = sum_thread;
1615}
1616
1617void ggml_compute_forward_count_equal(
1618 const ggml_compute_params * params,
1619 ggml_tensor * dst) {
1620
1621 const ggml_tensor * src0 = dst->src[0];
1622
1623 switch (src0->type) {
1624 case GGML_TYPE_I32:
1625 {
1626 ggml_compute_forward_count_equal_i32(params, dst);
1627 } break;
1628 default:
1629 {
1630 GGML_ABORT("fatal error");
1631 }
1632 }
1633}
1634
1635// ggml_compute_forward_repeat
1636
1637static void ggml_compute_forward_repeat_f32(
1638 const ggml_compute_params * params,
1639 ggml_tensor * dst) {
1640
1641 const ggml_tensor * src0 = dst->src[0];
1642
1643 if (params->ith != 0) {
1644 return;
1645 }
1646
1647 GGML_ASSERT(ggml_can_repeat(src0, dst));
1648
1649 GGML_TENSOR_UNARY_OP_LOCALS
1650
1651 // guaranteed to be an integer due to the check in ggml_can_repeat
1652 const int nr0 = (int)(ne0/ne00);
1653 const int nr1 = (int)(ne1/ne01);
1654 const int nr2 = (int)(ne2/ne02);
1655 const int nr3 = (int)(ne3/ne03);
1656
1657 // TODO: support for transposed / permuted tensors
1658 GGML_ASSERT(nb0 == sizeof(float));
1659 GGML_ASSERT(nb00 == sizeof(float));
1660
1661 // TODO: maybe this is not optimal?
1662 for (int i3 = 0; i3 < nr3; i3++) {
1663 for (int k3 = 0; k3 < ne03; k3++) {
1664 for (int i2 = 0; i2 < nr2; i2++) {
1665 for (int k2 = 0; k2 < ne02; k2++) {
1666 for (int i1 = 0; i1 < nr1; i1++) {
1667 for (int k1 = 0; k1 < ne01; k1++) {
1668 for (int i0 = 0; i0 < nr0; i0++) {
1669 ggml_vec_cpy_f32(n: ne00,
1670 y: (float *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0),
1671 x: (float *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01));
1672 }
1673 }
1674 }
1675 }
1676 }
1677 }
1678 }
1679}
1680
1681static void ggml_compute_forward_repeat_f16(
1682 const ggml_compute_params * params,
1683 ggml_tensor * dst) {
1684
1685 const ggml_tensor * src0 = dst->src[0];
1686
1687 if (params->ith != 0) {
1688 return;
1689 }
1690
1691 GGML_ASSERT(ggml_can_repeat(src0, dst));
1692
1693 GGML_TENSOR_UNARY_OP_LOCALS
1694
1695 // guaranteed to be an integer due to the check in ggml_can_repeat
1696 const int nr0 = (int)(ne0/ne00);
1697 const int nr1 = (int)(ne1/ne01);
1698 const int nr2 = (int)(ne2/ne02);
1699 const int nr3 = (int)(ne3/ne03);
1700
1701 // TODO: support for transposed / permuted tensors
1702 GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
1703 GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
1704
1705 // TODO: maybe this is not optimal?
1706 for (int i3 = 0; i3 < nr3; i3++) {
1707 for (int k3 = 0; k3 < ne03; k3++) {
1708 for (int i2 = 0; i2 < nr2; i2++) {
1709 for (int k2 = 0; k2 < ne02; k2++) {
1710 for (int i1 = 0; i1 < nr1; i1++) {
1711 for (int k1 = 0; k1 < ne01; k1++) {
1712 for (int i0 = 0; i0 < nr0; i0++) {
1713 ggml_fp16_t * y = (ggml_fp16_t *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0);
1714 ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01);
1715 // ggml_vec_cpy_f16(ne00, y, x)
1716 for (int i = 0; i < ne00; ++i) {
1717 y[i] = x[i];
1718 }
1719 }
1720 }
1721 }
1722 }
1723 }
1724 }
1725 }
1726}
1727
1728void ggml_compute_forward_repeat(
1729 const ggml_compute_params * params,
1730 ggml_tensor * dst) {
1731
1732 const ggml_tensor * src0 = dst->src[0];
1733
1734 switch (src0->type) {
1735 case GGML_TYPE_F16:
1736 case GGML_TYPE_BF16:
1737 case GGML_TYPE_I16:
1738 {
1739 ggml_compute_forward_repeat_f16(params, dst);
1740 } break;
1741 case GGML_TYPE_F32:
1742 case GGML_TYPE_I32:
1743 {
1744 ggml_compute_forward_repeat_f32(params, dst);
1745 } break;
1746 // TODO: templateify the implemenation and support for I64
1747 // ref https://github.com/ggml-org/llama.cpp/pull/14274#discussion_r2169492225
1748 //case GGML_TYPE_I64:
1749 // {
1750 // ggml_compute_forward_repeat_i64(params, dst);
1751 // } break;
1752 default:
1753 {
1754 GGML_ABORT("fatal error");
1755 }
1756 }
1757}
1758
1759// ggml_compute_forward_repeat_back
1760
1761static void ggml_compute_forward_repeat_back_f32(
1762 const ggml_compute_params * params,
1763 ggml_tensor * dst) {
1764
1765 const ggml_tensor * src0 = dst->src[0];
1766
1767 if (params->ith != 0) {
1768 return;
1769 }
1770
1771 GGML_ASSERT(ggml_can_repeat(dst, src0));
1772
1773 GGML_TENSOR_UNARY_OP_LOCALS
1774
1775 // guaranteed to be an integer due to the check in ggml_can_repeat
1776 const int nr0 = (int)(ne00/ne0);
1777 const int nr1 = (int)(ne01/ne1);
1778 const int nr2 = (int)(ne02/ne2);
1779 const int nr3 = (int)(ne03/ne3);
1780
1781 // TODO: support for transposed / permuted tensors
1782 GGML_ASSERT(nb0 == sizeof(float));
1783 GGML_ASSERT(nb00 == sizeof(float));
1784
1785 if (ggml_is_contiguous(tensor: dst)) {
1786 ggml_vec_set_f32(n: ne0*ne1*ne2*ne3, x: (float *)dst->data, v: 0);
1787 } else {
1788 for (int k3 = 0; k3 < ne3; k3++) {
1789 for (int k2 = 0; k2 < ne2; k2++) {
1790 for (int k1 = 0; k1 < ne1; k1++) {
1791 ggml_vec_set_f32(n: ne0,
1792 x: (float *) ((char *) dst->data + k1*nb1 + k2*nb2 + k3*nb3),
1793 v: 0);
1794 }
1795 }
1796 }
1797 }
1798
1799 // TODO: maybe this is not optimal?
1800 for (int i3 = 0; i3 < nr3; i3++) {
1801 for (int k3 = 0; k3 < ne3; k3++) {
1802 for (int i2 = 0; i2 < nr2; i2++) {
1803 for (int k2 = 0; k2 < ne2; k2++) {
1804 for (int i1 = 0; i1 < nr1; i1++) {
1805 for (int k1 = 0; k1 < ne1; k1++) {
1806 for (int i0 = 0; i0 < nr0; i0++) {
1807 ggml_vec_acc_f32(n: ne0,
1808 y: (float *) ((char *) dst->data + ( k3)*nb3 + ( k2)*nb2 + ( k1)*nb1),
1809 x: (float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00));
1810 }
1811 }
1812 }
1813 }
1814 }
1815 }
1816 }
1817}
1818
1819void ggml_compute_forward_repeat_back(
1820 const ggml_compute_params * params,
1821 ggml_tensor * dst) {
1822
1823 const ggml_tensor * src0 = dst->src[0];
1824
1825 switch (src0->type) {
1826 case GGML_TYPE_F32:
1827 {
1828 ggml_compute_forward_repeat_back_f32(params, dst);
1829 } break;
1830 default:
1831 {
1832 GGML_ABORT("fatal error");
1833 }
1834 }
1835}
1836
1837// ggml_compute_forward_concat
1838
1839static void ggml_compute_forward_concat_any(
1840 const ggml_compute_params * params,
1841 ggml_tensor * dst) {
1842
1843 const ggml_tensor * src0 = dst->src[0];
1844 const ggml_tensor * src1 = dst->src[1];
1845
1846 const size_t len = ggml_type_size(type: src0->type);
1847
1848 const int ith = params->ith;
1849 const int nth = params->nth;
1850
1851 GGML_TENSOR_BINARY_OP_LOCALS
1852
1853 const int32_t dim = ggml_get_op_params_i32(tensor: dst, i: 0);
1854
1855 GGML_ASSERT(dim >= 0 && dim < 4);
1856
1857 int64_t o[4] = {0, 0, 0, 0};
1858 o[dim] = src0->ne[dim];
1859
1860 const char * x;
1861
1862 // TODO: smarter multi-theading
1863 for (int i3 = 0; i3 < ne3; i3++) {
1864 for (int i2 = ith; i2 < ne2; i2 += nth) {
1865 for (int i1 = 0; i1 < ne1; i1++) {
1866 for (int i0 = 0; i0 < ne0; i0++) {
1867 if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
1868 x = (const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03;
1869 } else {
1870 x = (const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13;
1871 }
1872
1873 char * y = (char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3;
1874
1875 memcpy(dest: y, src: x, n: len);
1876 }
1877 }
1878 }
1879 }
1880}
1881
1882static void ggml_compute_forward_concat_i8(
1883 const ggml_compute_params * params,
1884 ggml_tensor * dst) {
1885
1886 const ggml_tensor * src0 = dst->src[0];
1887 const ggml_tensor * src1 = dst->src[1];
1888
1889 GGML_ASSERT(ggml_type_size(src0->type) == sizeof(int8_t));
1890
1891 const int ith = params->ith;
1892 const int nth = params->nth;
1893
1894 GGML_TENSOR_BINARY_OP_LOCALS
1895
1896 const int32_t dim = ggml_get_op_params_i32(tensor: dst, i: 0);
1897
1898 GGML_ASSERT(dim >= 0 && dim < 4);
1899
1900 int64_t o[4] = {0, 0, 0, 0};
1901 o[dim] = src0->ne[dim];
1902
1903 const int8_t * x;
1904
1905 // TODO: smarter multi-theading
1906 for (int i3 = 0; i3 < ne3; i3++) {
1907 for (int i2 = ith; i2 < ne2; i2 += nth) {
1908 for (int i1 = 0; i1 < ne1; i1++) {
1909 for (int i0 = 0; i0 < ne0; i0++) {
1910 if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
1911 x = (const int8_t *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03);
1912 } else {
1913 x = (const int8_t *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
1914 }
1915
1916 int8_t * y = (int8_t *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
1917
1918 *y = *x;
1919 }
1920 }
1921 }
1922 }
1923}
1924
1925static void ggml_compute_forward_concat_f16(
1926 const ggml_compute_params * params,
1927 ggml_tensor * dst) {
1928
1929 const ggml_tensor * src0 = dst->src[0];
1930 const ggml_tensor * src1 = dst->src[1];
1931
1932 GGML_ASSERT(ggml_type_size(src0->type) == sizeof(ggml_fp16_t));
1933
1934 const int ith = params->ith;
1935 const int nth = params->nth;
1936
1937 GGML_TENSOR_BINARY_OP_LOCALS
1938
1939 const int32_t dim = ggml_get_op_params_i32(tensor: dst, i: 0);
1940
1941 GGML_ASSERT(dim >= 0 && dim < 4);
1942
1943 int64_t o[4] = {0, 0, 0, 0};
1944 o[dim] = src0->ne[dim];
1945
1946 const ggml_fp16_t * x;
1947
1948 // TODO: smarter multi-theading
1949 for (int i3 = 0; i3 < ne3; i3++) {
1950 for (int i2 = ith; i2 < ne2; i2 += nth) {
1951 for (int i1 = 0; i1 < ne1; i1++) {
1952 for (int i0 = 0; i0 < ne0; i0++) {
1953 if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
1954 x = (const ggml_fp16_t *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03);
1955 } else {
1956 x = (const ggml_fp16_t *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
1957 }
1958
1959 ggml_fp16_t * y = (ggml_fp16_t *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
1960
1961 *y = *x;
1962 }
1963 }
1964 }
1965 }
1966}
1967
1968static void ggml_compute_forward_concat_f32(
1969 const ggml_compute_params * params,
1970 ggml_tensor * dst) {
1971
1972 const ggml_tensor * src0 = dst->src[0];
1973 const ggml_tensor * src1 = dst->src[1];
1974
1975 GGML_ASSERT(ggml_type_size(src0->type) == sizeof(float));
1976
1977 const int ith = params->ith;
1978 const int nth = params->nth;
1979
1980 GGML_TENSOR_BINARY_OP_LOCALS
1981
1982 const int32_t dim = ggml_get_op_params_i32(tensor: dst, i: 0);
1983
1984 GGML_ASSERT(dim >= 0 && dim < 4);
1985
1986 int64_t o[4] = {0, 0, 0, 0};
1987 o[dim] = src0->ne[dim];
1988
1989 const float * x;
1990
1991 // TODO: smarter multi-theading
1992 for (int i3 = 0; i3 < ne3; i3++) {
1993 for (int i2 = ith; i2 < ne2; i2 += nth) {
1994 for (int i1 = 0; i1 < ne1; i1++) {
1995 for (int i0 = 0; i0 < ne0; i0++) {
1996 if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
1997 x = (const float *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03);
1998 } else {
1999 x = (const float *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
2000 }
2001
2002 float * y = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
2003
2004 *y = *x;
2005 }
2006 }
2007 }
2008 }
2009}
2010
2011void ggml_compute_forward_concat(
2012 const ggml_compute_params * params,
2013 ggml_tensor * dst) {
2014
2015 const ggml_tensor * src0 = dst->src[0];
2016
2017 switch (src0->type) {
2018 case GGML_TYPE_F16:
2019 case GGML_TYPE_BF16:
2020 case GGML_TYPE_I16:
2021 {
2022 ggml_compute_forward_concat_f16(params, dst);
2023 } break;
2024 case GGML_TYPE_I8:
2025 {
2026 ggml_compute_forward_concat_i8(params, dst);
2027 } break;
2028 case GGML_TYPE_F32:
2029 case GGML_TYPE_I32:
2030 {
2031 ggml_compute_forward_concat_f32(params, dst);
2032 } break;
2033 default:
2034 {
2035 ggml_compute_forward_concat_any(params, dst);
2036 }
2037 }
2038}
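
// note on the dispatch above: the typed concat kernels only depend on the
// element size, so F16/BF16/I16 share the 2-byte path, F32/I32 share the
// 4-byte path, I8 uses the 1-byte path and everything else falls back to the
// generic any-type variant.
//
// illustrative example (shapes assumed for the sake of the comment):
// concatenating a = [64, 32, 4, 1] with b = [64, 32, 6, 1] along dim = 2
// yields dst = [64, 32, 10, 1]; for i2 < 4 the value is read from a at i2,
// otherwise from b at i2 - o[2] = i2 - 4.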
2039
2040// ggml_compute_forward_gelu
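//
// reference: GELU(x) = x * Phi(x), where Phi is the CDF of the standard normal
// distribution. the row kernels in vec.h typically evaluate a tanh-based
// approximation of this; see ggml_vec_gelu_f32/f16 for the exact variant used.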
2041
2042static void ggml_compute_forward_gelu_f32(
2043 const ggml_compute_params * params,
2044 ggml_tensor * dst) {
2045
2046 const ggml_tensor * src0 = dst->src[0];
2047
2048 assert(ggml_is_contiguous_1(src0));
2049 assert(ggml_is_contiguous_1(dst));
2050 assert(ggml_are_same_shape(src0, dst));
2051
2052 const int ith = params->ith;
2053 const int nth = params->nth;
2054
2055 const int nc = src0->ne[0];
2056 const int nr = ggml_nrows(tensor: src0);
2057
2058 // rows per thread
2059 const int dr = (nr + nth - 1)/nth;
2060
2061 // row range for this thread
2062 const int ir0 = dr*ith;
2063 const int ir1 = MIN(ir0 + dr, nr);
2064
2065 for (int i1 = ir0; i1 < ir1; i1++) {
2066 ggml_vec_gelu_f32(n: nc,
2067 y: (float *) ((char *) dst->data + i1*( dst->nb[1])),
2068 x: (float *) ((char *) src0->data + i1*(src0->nb[1])));
2069
2070#ifndef NDEBUG
2071 for (int k = 0; k < nc; k++) {
2072 const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2073 GGML_UNUSED(x);
2074 assert(!isnan(x));
2075 assert(!isinf(x));
2076 }
2077#endif
2078 }
2079}
2080
2081static void ggml_compute_forward_gelu_f16(
2082 const ggml_compute_params * params,
2083 ggml_tensor * dst) {
2084
2085 const ggml_tensor * src0 = dst->src[0];
2086
2087 assert(ggml_is_contiguous_1(src0));
2088 assert(ggml_is_contiguous_1(dst));
2089 assert(ggml_are_same_shape(src0, dst));
2090
2091 const int ith = params->ith;
2092 const int nth = params->nth;
2093
2094 const int nc = src0->ne[0];
2095 const int nr = ggml_nrows(tensor: src0);
2096
2097 // rows per thread
2098 const int dr = (nr + nth - 1)/nth;
2099
2100 // row range for this thread
2101 const int ir0 = dr*ith;
2102 const int ir1 = MIN(ir0 + dr, nr);
2103
2104 for (int i1 = ir0; i1 < ir1; i1++) {
2105 ggml_vec_gelu_f16(n: nc,
2106 y: (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
2107 x: (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
2108
2109#ifndef NDEBUG
2110 for (int k = 0; k < nc; k++) {
2111 const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2112 const float v = GGML_CPU_FP16_TO_FP32(x);
2113 GGML_UNUSED(v);
2114 assert(!isnan(v));
2115 assert(!isinf(v));
2116 }
2117#endif
2118 }
2119}
2120
2121static void ggml_compute_forward_gelu(
2122 const ggml_compute_params * params,
2123 ggml_tensor * dst) {
2124
2125 const ggml_tensor * src0 = dst->src[0];
2126
2127 switch (src0->type) {
2128 case GGML_TYPE_F32:
2129 {
2130 ggml_compute_forward_gelu_f32(params, dst);
2131 } break;
2132 case GGML_TYPE_F16:
2133 {
2134 ggml_compute_forward_gelu_f16(params, dst);
2135 } break;
2136 default:
2137 {
2138 GGML_ABORT("fatal error");
2139 }
2140 }
2141}
2142
2143// ggml_compute_forward_gelu_erf
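//
// reference: the erf variant corresponds to the exact GELU definition,
//   GELU_erf(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
// (evaluated row-wise by ggml_vec_gelu_erf_f32/f16 below).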
2144
2145static void ggml_compute_forward_gelu_erf_f32(
2146 const ggml_compute_params * params,
2147 ggml_tensor * dst) {
2148
2149 const ggml_tensor * src0 = dst->src[0];
2150
2151 assert(ggml_is_contiguous_1(src0));
2152 assert(ggml_is_contiguous_1(dst));
2153 assert(ggml_are_same_shape(src0, dst));
2154
2155 const int ith = params->ith;
2156 const int nth = params->nth;
2157
2158 const int nc = src0->ne[0];
2159 const int nr = ggml_nrows(tensor: src0);
2160
2161 // rows per thread
2162 const int dr = (nr + nth - 1)/nth;
2163
2164 // row range for this thread
2165 const int ir0 = dr*ith;
2166 const int ir1 = MIN(ir0 + dr, nr);
2167
2168 for (int i1 = ir0; i1 < ir1; i1++) {
2169 ggml_vec_gelu_erf_f32(n: nc,
2170 y: (float *) ((char *) dst->data + i1*( dst->nb[1])),
2171 x: (float *) ((char *) src0->data + i1*(src0->nb[1])));
2172
2173#ifndef NDEBUG
2174 for (int k = 0; k < nc; k++) {
2175 const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2176 GGML_UNUSED(x);
2177 assert(!isnan(x));
2178 assert(!isinf(x));
2179 }
2180#endif
2181 }
2182}
2183
2184static void ggml_compute_forward_gelu_erf_f16(
2185 const ggml_compute_params * params,
2186 ggml_tensor * dst) {
2187
2188 const ggml_tensor * src0 = dst->src[0];
2189
2190 assert(ggml_is_contiguous_1(src0));
2191 assert(ggml_is_contiguous_1(dst));
2192 assert(ggml_are_same_shape(src0, dst));
2193
2194 const int ith = params->ith;
2195 const int nth = params->nth;
2196
2197 const int nc = src0->ne[0];
2198 const int nr = ggml_nrows(tensor: src0);
2199
2200 // rows per thread
2201 const int dr = (nr + nth - 1)/nth;
2202
2203 // row range for this thread
2204 const int ir0 = dr*ith;
2205 const int ir1 = MIN(ir0 + dr, nr);
2206
2207 for (int i1 = ir0; i1 < ir1; i1++) {
2208 ggml_vec_gelu_erf_f16(n: nc,
2209 y: (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
2210 x: (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
2211
2212#ifndef NDEBUG
2213 for (int k = 0; k < nc; k++) {
2214 const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2215 const float v = GGML_CPU_FP16_TO_FP32(x);
2216 GGML_UNUSED(v);
2217 assert(!isnan(v));
2218 assert(!isinf(v));
2219 }
2220#endif
2221 }
2222}
2223
2224static void ggml_compute_forward_gelu_erf(
2225 const ggml_compute_params * params,
2226 ggml_tensor * dst) {
2227
2228 const ggml_tensor * src0 = dst->src[0];
2229
2230 switch (src0->type) {
2231 case GGML_TYPE_F32:
2232 {
2233 ggml_compute_forward_gelu_erf_f32(params, dst);
2234 } break;
2235 case GGML_TYPE_F16:
2236 {
2237 ggml_compute_forward_gelu_erf_f16(params, dst);
2238 } break;
2239 default:
2240 {
2241 GGML_ABORT("fatal error");
2242 }
2243 }
2244}
2245
2246// ggml_compute_forward_gelu_quick
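//
// reference: the "quick" variant is the common sigmoid-based approximation,
//   GELU_quick(x) ~= x * sigmoid(1.702 * x)
// (the exact constant used lives in the vec.h kernels).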
2247
2248static void ggml_compute_forward_gelu_quick_f32(
2249 const ggml_compute_params * params,
2250 ggml_tensor * dst) {
2251
2252 const ggml_tensor * src0 = dst->src[0];
2253
2254 assert(ggml_is_contiguous_1(src0));
2255 assert(ggml_is_contiguous_1(dst));
2256 assert(ggml_are_same_shape(src0, dst));
2257
2258 const int ith = params->ith;
2259 const int nth = params->nth;
2260
2261 const int nc = src0->ne[0];
2262 const int nr = ggml_nrows(tensor: src0);
2263
2264 // rows per thread
2265 const int dr = (nr + nth - 1)/nth;
2266
2267 // row range for this thread
2268 const int ir0 = dr*ith;
2269 const int ir1 = MIN(ir0 + dr, nr);
2270
2271 for (int i1 = ir0; i1 < ir1; i1++) {
2272 ggml_vec_gelu_quick_f32(n: nc,
2273 y: (float *) ((char *) dst->data + i1*( dst->nb[1])),
2274 x: (float *) ((char *) src0->data + i1*(src0->nb[1])));
2275
2276#ifndef NDEBUG
2277 for (int k = 0; k < nc; k++) {
2278 const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2279 GGML_UNUSED(x);
2280 assert(!isnan(x));
2281 assert(!isinf(x));
2282 }
2283#endif
2284 }
2285}
2286
2287static void ggml_compute_forward_gelu_quick_f16(
2288 const ggml_compute_params * params,
2289 ggml_tensor * dst) {
2290
2291 const ggml_tensor * src0 = dst->src[0];
2292
2293 assert(ggml_is_contiguous_1(src0));
2294 assert(ggml_is_contiguous_1(dst));
2295 assert(ggml_are_same_shape(src0, dst));
2296
2297 const int ith = params->ith;
2298 const int nth = params->nth;
2299
2300 const int nc = src0->ne[0];
2301 const int nr = ggml_nrows(tensor: src0);
2302
2303 // rows per thread
2304 const int dr = (nr + nth - 1)/nth;
2305
2306 // row range for this thread
2307 const int ir0 = dr*ith;
2308 const int ir1 = MIN(ir0 + dr, nr);
2309
2310 for (int i1 = ir0; i1 < ir1; i1++) {
2311 ggml_vec_gelu_quick_f16(n: nc,
2312 y: (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
2313 x: (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
2314
2315#ifndef NDEBUG
2316 for (int k = 0; k < nc; k++) {
2317 const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2318 const float v = GGML_CPU_FP16_TO_FP32(x);
2319 GGML_UNUSED(v);
2320 assert(!isnan(v));
2321 assert(!isinf(v));
2322 }
2323#endif
2324 }
2325}
2326
2327static void ggml_compute_forward_gelu_quick(
2328 const ggml_compute_params * params,
2329 ggml_tensor * dst) {
2330
2331 const ggml_tensor * src0 = dst->src[0];
2332
2333 switch (src0->type) {
2334 case GGML_TYPE_F32:
2335 {
2336 ggml_compute_forward_gelu_quick_f32(params, dst);
2337 } break;
2338 case GGML_TYPE_F16:
2339 {
2340 ggml_compute_forward_gelu_quick_f16(params, dst);
2341 } break;
2342 default:
2343 {
2344 GGML_ABORT("fatal error");
2345 }
2346 }
2347}
2348
2349// ggml_compute_forward_silu
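//
// reference: SiLU(x) = x * sigmoid(x) = x / (1 + exp(-x)), also known as swish.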
2350
2351static void ggml_compute_forward_silu_f32(
2352 const ggml_compute_params * params,
2353 ggml_tensor * dst) {
2354
2355 const ggml_tensor * src0 = dst->src[0];
2356
2357 assert(ggml_is_contiguous_1(src0));
2358 assert(ggml_is_contiguous_1(dst));
2359 assert(ggml_are_same_shape(src0, dst));
2360
2361 const int ith = params->ith;
2362 const int nth = params->nth;
2363
2364 const int nc = src0->ne[0];
2365 const int nr = ggml_nrows(tensor: src0);
2366
2367 // rows per thread
2368 const int dr = (nr + nth - 1)/nth;
2369
2370 // row range for this thread
2371 const int ir0 = dr*ith;
2372 const int ir1 = MIN(ir0 + dr, nr);
2373
2374 for (int i1 = ir0; i1 < ir1; i1++) {
2375 ggml_vec_silu_f32(n: nc,
2376 y: (float *) ((char *) dst->data + i1*( dst->nb[1])),
2377 x: (float *) ((char *) src0->data + i1*(src0->nb[1])));
2378
2379#ifndef NDEBUG
2380 for (int k = 0; k < nc; k++) {
2381 const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
2382 GGML_UNUSED(x);
2383 assert(!isnan(x));
2384 assert(!isinf(x));
2385 }
2386#endif
2387 }
2388}
2389
2390static void ggml_compute_forward_silu_f16(
2391 const ggml_compute_params * params,
2392 ggml_tensor * dst) {
2393
2394 const ggml_tensor * src0 = dst->src[0];
2395
2396 assert(ggml_is_contiguous_1(src0));
2397 assert(ggml_is_contiguous_1(dst));
2398 assert(ggml_are_same_shape(src0, dst));
2399
2400 const int ith = params->ith;
2401 const int nth = params->nth;
2402
2403 const int nc = src0->ne[0];
2404 const int nr = ggml_nrows(tensor: src0);
2405
2406 // rows per thread
2407 const int dr = (nr + nth - 1)/nth;
2408
2409 // row range for this thread
2410 const int ir0 = dr*ith;
2411 const int ir1 = MIN(ir0 + dr, nr);
2412
2413 for (int i1 = ir0; i1 < ir1; i1++) {
2414 ggml_vec_silu_f16(n: nc,
2415 y: (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
2416 x: (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
2417
2418#ifndef NDEBUG
2419 for (int k = 0; k < nc; k++) {
2420 const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])))[k];
2421 const float v = GGML_CPU_FP16_TO_FP32(x);
2422 GGML_UNUSED(v);
2423 assert(!isnan(v));
2424 assert(!isinf(v));
2425 }
2426#endif
2427 }
2428}
2429
2430static void ggml_compute_forward_silu(
2431 const ggml_compute_params * params,
2432 ggml_tensor * dst) {
2433
2434 const ggml_tensor * src0 = dst->src[0];
2435
2436 switch (src0->type) {
2437 case GGML_TYPE_F32:
2438 {
2439 ggml_compute_forward_silu_f32(params, dst);
2440 } break;
2441 case GGML_TYPE_F16:
2442 {
2443 ggml_compute_forward_silu_f16(params, dst);
2444 } break;
2445 default:
2446 {
2447 GGML_ABORT("fatal error");
2448 }
2449 }
2450}

// ggml_compute_forward_leaky_relu
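//
// reference: leaky_relu(x) = x for x > 0 and negative_slope * x otherwise;
// the slope is read from op_params below.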
2452
2453static void ggml_compute_forward_leaky_relu_f32(
2454 const ggml_compute_params * params,
2455 ggml_tensor * dst) {
2456
2457 const ggml_tensor * src0 = dst->src[0];
2458
2459 if (params->ith != 0) {
2460 return;
2461 }
2462
2463 assert(ggml_is_contiguous_1(src0));
2464 assert(ggml_is_contiguous_1(dst));
2465 assert(ggml_are_same_shape(src0, dst));
2466
2467 const int n = ggml_nrows(tensor: src0);
2468 const int nc = src0->ne[0];
2469
2470 float negative_slope;
2471 memcpy(dest: &negative_slope, src: dst->op_params, n: sizeof(float));
2472
2473 assert(dst->nb[0] == sizeof(float));
2474 assert(src0->nb[0] == sizeof(float));
2475
2476 for (int i = 0; i < n; i++) {
2477 ggml_vec_leaky_relu_f32(n: nc,
2478 y: (float *) ((char *) dst->data + i*( dst->nb[1])),
2479 x: (float *) ((char *) src0->data + i*(src0->nb[1])), ns: negative_slope);
2480 }
2481}
2482
2483static void ggml_compute_forward_leaky_relu_f16(
2484 const ggml_compute_params * params,
2485 ggml_tensor * dst) {
2486
2487 const ggml_tensor * src0 = dst->src[0];
2488
2489 if (params->ith != 0) {
2490 return;
2491 }
2492
2493 assert(ggml_is_contiguous_1(src0));
2494 assert(ggml_is_contiguous_1(dst));
2495 assert(ggml_are_same_shape(src0, dst));
2496
2497 const int n = ggml_nrows(tensor: src0);
2498 const int nc = src0->ne[0];
2499
2500 float negative_slope;
2501 memcpy(dest: &negative_slope, src: dst->op_params, n: sizeof(float));
2502
2503 assert(dst->nb[0] == sizeof(ggml_fp16_t));
2504 assert(src0->nb[0] == sizeof(ggml_fp16_t));
2505
2506 for (int i = 0; i < n; i++) {
2507 ggml_vec_leaky_relu_f16(n: nc,
2508 y: (ggml_fp16_t *) ((char *) dst->data + i*( dst->nb[1])),
2509 x: (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1])), ns: negative_slope);
2510 }
2511}
2512
2513void ggml_compute_forward_leaky_relu(
2514 const ggml_compute_params * params,
2515 ggml_tensor * dst) {
2516
2517 const ggml_tensor * src0 = dst->src[0];
2518
2519 switch (src0->type) {
2520 case GGML_TYPE_F32:
2521 {
2522 ggml_compute_forward_leaky_relu_f32(params, dst);
2523 } break;
2524 case GGML_TYPE_F16:
2525 {
2526 ggml_compute_forward_leaky_relu_f16(params, dst);
2527 } break;
2528 default:
2529 {
2530 GGML_ABORT("fatal error");
2531 }
2532 }
2533}
2534
2535// ggml_compute_forward_silu_back
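//
// reference: src0 holds the incoming gradient dy and src1 the forward input x;
// the kernel computes dx = dy * d(SiLU)/dx with
//   d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x))).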
2536
2537static void ggml_compute_forward_silu_back_f32(
2538 const ggml_compute_params * params,
2539 ggml_tensor * dst) {
2540
2541 const ggml_tensor * grad = dst->src[0];
2542 const ggml_tensor * src1 = dst->src[1];
2543
2544 assert(ggml_is_contiguous_1(grad));
2545 assert(ggml_is_contiguous_1(src1));
2546 assert(ggml_is_contiguous_1(dst));
2547 assert(ggml_are_same_shape(src1, dst));
2548 assert(ggml_are_same_shape(src1, grad));
2549
2550 const int ith = params->ith;
2551 const int nth = params->nth;
2552
2553 const int nc = src1->ne[0];
2554 const int nr = ggml_nrows(tensor: src1);
2555
2556 // rows per thread
2557 const int dr = (nr + nth - 1)/nth;
2558
2559 // row range for this thread
2560 const int ir0 = dr*ith;
2561 const int ir1 = MIN(ir0 + dr, nr);
2562
2563 for (int i1 = ir0; i1 < ir1; i1++) {
2564 ggml_vec_silu_backward_f32(n: nc,
2565 dx: (float *) ((char *) dst->data + i1*( dst->nb[1])),
2566 x: (float *) ((char *) src1->data + i1*(src1->nb[1])),
2567 dy: (float *) ((char *) grad->data + i1*(grad->nb[1])));
2568
2569#ifndef NDEBUG
2570 for (int k = 0; k < nc; k++) {
2571 const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2572 GGML_UNUSED(x);
2573 assert(!isnan(x));
2574 assert(!isinf(x));
2575 }
2576#endif
2577 }
2578}
2579
2580static void ggml_compute_forward_silu_back_f16(
2581 const ggml_compute_params * params,
2582 ggml_tensor * dst) {
2583
2584 const ggml_tensor * grad = dst->src[0];
2585 const ggml_tensor * src1 = dst->src[1];
2586
2587 assert(ggml_is_contiguous_1(grad));
2588 assert(ggml_is_contiguous_1(src1));
2589 assert(ggml_is_contiguous_1(dst));
2590 assert(ggml_are_same_shape(src1, dst));
2591 assert(ggml_are_same_shape(src1, grad));
2592
2593 const int ith = params->ith;
2594 const int nth = params->nth;
2595
2596 const int nc = src1->ne[0];
2597 const int nr = ggml_nrows(tensor: src1);
2598
2599 // rows per thread
2600 const int dr = (nr + nth - 1)/nth;
2601
2602 // row range for this thread
2603 const int ir0 = dr*ith;
2604 const int ir1 = MIN(ir0 + dr, nr);
2605
2606 for (int i1 = ir0; i1 < ir1; i1++) {
2607 ggml_vec_silu_backward_f16(n: nc,
2608 dx: (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
2609 x: (ggml_fp16_t *) ((char *) src1->data + i1*(src1->nb[1])),
2610 dy: (ggml_fp16_t *) ((char *) grad->data + i1*(grad->nb[1])));
2611
#ifndef NDEBUG
        for (int k = 0; k < nc; k++) {
            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
            const float v = GGML_CPU_FP16_TO_FP32(x);
            GGML_UNUSED(v);
            assert(!isnan(v));
            assert(!isinf(v));
        }
#endif
2621 }
2622}
2623
2624void ggml_compute_forward_silu_back(
2625 const ggml_compute_params * params,
2626 ggml_tensor * dst) {
2627
2628 const ggml_tensor * src0 = dst->src[0];
2629
2630 switch (src0->type) {
2631 case GGML_TYPE_F32:
2632 {
2633 ggml_compute_forward_silu_back_f32(params, dst);
2634 } break;
2635 case GGML_TYPE_F16:
2636 {
2637 ggml_compute_forward_silu_back_f16(params, dst);
2638 } break;
2639 default:
2640 {
2641 GGML_ABORT("fatal error");
2642 }
2643 }
2644}
2645
2646// ggml_compute_forward_reglu
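//
// the GLU-family ops below (reglu/geglu/swiglu/...) share the same layout:
// with two sources, src0 is the activation input and src1 is the gate; with a
// single source, each row of src0 is split in half and op_params[1]
// ("swapped") selects which half is the input and which is the gate.
// reglu itself gates with ReLU, i.e. roughly y = max(x, 0) * g per element
// (see ggml_vec_reglu_* for the exact kernel).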
2647
2648static void ggml_compute_forward_reglu_f32(
2649 const ggml_compute_params * params,
2650 ggml_tensor * dst) {
2651
2652 const ggml_tensor * src0 = dst->src[0];
2653 const ggml_tensor * src1 = dst->src[1];
2654 char * src0_d = (char *) src0->data;
2655 char * src1_d = (char *) (src1 ? src1->data : src0->data);
2656 const size_t src0_o = src0->nb[1];
2657 const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
2658
2659 GGML_ASSERT(ggml_is_contiguous_1(src0));
2660 GGML_ASSERT(ggml_is_contiguous_1(dst));
2661
2662 if (src1) {
2663 GGML_ASSERT(ggml_is_contiguous_1(src1));
2664 GGML_ASSERT(src0->type == src1->type);
2665 }
2666
2667 const int ith = params->ith;
2668 const int nth = params->nth;
2669
2670 const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
2671 const int nr = ggml_nrows(tensor: src0);
2672
2673 GGML_ASSERT(dst->ne[0] == nc);
2674 GGML_ASSERT(ggml_nrows(dst) == nr);
2675
2676 const int32_t swapped = ggml_get_op_params_i32(tensor: dst, i: 1);
2677
2678 // rows per thread
2679 const int dr = (nr + nth - 1)/nth;
2680
2681 // row range for this thread
2682 const int ir0 = dr*ith;
2683 const int ir1 = MIN(ir0 + dr, nr);
2684
2685 for (int i1 = ir0; i1 < ir1; i1++) {
2686 float * src0_p = (float *) (src0_d + i1*src0_o);
2687 float * src1_p = (float *) (src1_d + i1*src1_o);
2688
2689 if (!src1) {
2690 src0_p += swapped ? nc : 0;
2691 src1_p += swapped ? 0 : nc;
2692 }
2693
2694 ggml_vec_reglu_f32(n: nc, y: (float *) ((char *) dst->data + i1*(dst->nb[1])), x: src0_p, g: src1_p);
2695
2696#ifndef NDEBUG
2697 for (int k = 0; k < nc; k++) {
2698 const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2699 GGML_UNUSED(x);
2700 assert(!isnan(x));
2701 assert(!isinf(x));
2702 }
2703#endif
2704 }
2705}
2706
2707static void ggml_compute_forward_reglu_f16(
2708 const ggml_compute_params * params,
2709 ggml_tensor * dst) {
2710
2711 const ggml_tensor * src0 = dst->src[0];
2712 const ggml_tensor * src1 = dst->src[1];
2713 char * src0_d = (char *) src0->data;
2714 char * src1_d = (char *) (src1 ? src1->data : src0->data);
2715 const size_t src0_o = src0->nb[1];
2716 const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
2717
2718 GGML_ASSERT(ggml_is_contiguous_1(src0));
2719 GGML_ASSERT(ggml_is_contiguous_1(dst));
2720
2721 if (src1) {
2722 GGML_ASSERT(ggml_is_contiguous_1(src1));
2723 GGML_ASSERT(src0->type == src1->type);
2724 }
2725
2726 const int ith = params->ith;
2727 const int nth = params->nth;
2728
2729 const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
2730 const int nr = ggml_nrows(tensor: src0);
2731
2732 GGML_ASSERT(dst->ne[0] == nc);
2733 GGML_ASSERT(ggml_nrows(dst) == nr);
2734
2735 const int32_t swapped = ggml_get_op_params_i32(tensor: dst, i: 1);
2736
2737 // rows per thread
2738 const int dr = (nr + nth - 1)/nth;
2739
2740 // row range for this thread
2741 const int ir0 = dr*ith;
2742 const int ir1 = MIN(ir0 + dr, nr);
2743
2744 for (int i1 = ir0; i1 < ir1; i1++) {
2745 ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
2746 ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
2747
2748 if (!src1) {
2749 src0_p += swapped ? nc : 0;
2750 src1_p += swapped ? 0 : nc;
2751 }
2752
2753 ggml_vec_reglu_f16(n: nc, y: (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), x: src0_p, g: src1_p);
2754
2755#ifndef NDEBUG
2756 for (int k = 0; k < nc; k++) {
2757 const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2758 const float v = GGML_FP16_TO_FP32(x);
2759 GGML_UNUSED(v);
2760 assert(!isnan(v));
2761 assert(!isinf(v));
2762 }
2763#endif
2764 }
2765}
2766
2767static void ggml_compute_forward_reglu(
2768 const ggml_compute_params * params,
2769 ggml_tensor * dst) {
2770
2771 const ggml_tensor * src0 = dst->src[0];
2772
2773 switch (src0->type) {
2774 case GGML_TYPE_F32:
2775 {
2776 ggml_compute_forward_reglu_f32(params, dst);
2777 } break;
2778 case GGML_TYPE_F16:
2779 {
2780 ggml_compute_forward_reglu_f16(params, dst);
2781 } break;
2782 default:
2783 {
2784 GGML_ABORT("fatal error");
2785 }
2786 }
2787}
2788
2789// ggml_compute_forward_geglu
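//
// reference: geglu gates with GELU, i.e. roughly y = gelu(x) * g per element,
// using the same input/gate layout as reglu above (see ggml_vec_geglu_*).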
2790
2791static void ggml_compute_forward_geglu_f32(
2792 const ggml_compute_params * params,
2793 ggml_tensor * dst) {
2794
2795 const ggml_tensor * src0 = dst->src[0];
2796 const ggml_tensor * src1 = dst->src[1];
2797 char * src0_d = (char *) src0->data;
2798 char * src1_d = (char *) (src1 ? src1->data : src0->data);
2799 const size_t src0_o = src0->nb[1];
2800 const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
2801
2802 GGML_ASSERT(ggml_is_contiguous_1(src0));
2803 GGML_ASSERT(ggml_is_contiguous_1(dst));
2804
2805 if (src1) {
2806 GGML_ASSERT(ggml_is_contiguous_1(src1));
2807 GGML_ASSERT(src0->type == src1->type);
2808 }
2809
2810 const int ith = params->ith;
2811 const int nth = params->nth;
2812
2813 const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
2814 const int nr = ggml_nrows(tensor: src0);
2815
2816 GGML_ASSERT(dst->ne[0] == nc);
2817 GGML_ASSERT(ggml_nrows(dst) == nr);
2818
2819 const int32_t swapped = ggml_get_op_params_i32(tensor: dst, i: 1);
2820
2821 // rows per thread
2822 const int dr = (nr + nth - 1)/nth;
2823
2824 // row range for this thread
2825 const int ir0 = dr*ith;
2826 const int ir1 = MIN(ir0 + dr, nr);
2827
2828 for (int i1 = ir0; i1 < ir1; i1++) {
2829 float * src0_p = (float *) (src0_d + i1*src0_o);
2830 float * src1_p = (float *) (src1_d + i1*src1_o);
2831
2832 if (!src1) {
2833 src0_p += swapped ? nc : 0;
2834 src1_p += swapped ? 0 : nc;
2835 }
2836
2837 ggml_vec_geglu_f32(n: nc, y: (float *) ((char *) dst->data + i1*(dst->nb[1])), x: src0_p, g: src1_p);
2838
2839#ifndef NDEBUG
2840 for (int k = 0; k < nc; k++) {
2841 const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2842 GGML_UNUSED(x);
2843 assert(!isnan(x));
2844 assert(!isinf(x));
2845 }
2846#endif
2847 }
2848}
2849
2850static void ggml_compute_forward_geglu_f16(
2851 const ggml_compute_params * params,
2852 ggml_tensor * dst) {
2853
2854 const ggml_tensor * src0 = dst->src[0];
2855 const ggml_tensor * src1 = dst->src[1];
2856 char * src0_d = (char *) src0->data;
2857 char * src1_d = (char *) (src1 ? src1->data : src0->data);
2858 const size_t src0_o = src0->nb[1];
2859 const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
2860
2861 GGML_ASSERT(ggml_is_contiguous_1(src0));
2862 GGML_ASSERT(ggml_is_contiguous_1(dst));
2863
2864 if (src1) {
2865 GGML_ASSERT(ggml_is_contiguous_1(src1));
2866 GGML_ASSERT(src0->type == src1->type);
2867 }
2868
2869 const int ith = params->ith;
2870 const int nth = params->nth;
2871
2872 const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
2873 const int nr = ggml_nrows(tensor: src0);
2874
2875 GGML_ASSERT(dst->ne[0] == nc);
2876 GGML_ASSERT(ggml_nrows(dst) == nr);
2877
2878 const int32_t swapped = ggml_get_op_params_i32(tensor: dst, i: 1);
2879
2880 // rows per thread
2881 const int dr = (nr + nth - 1)/nth;
2882
2883 // row range for this thread
2884 const int ir0 = dr*ith;
2885 const int ir1 = MIN(ir0 + dr, nr);
2886
2887 for (int i1 = ir0; i1 < ir1; i1++) {
2888 ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
2889 ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
2890
2891 if (!src1) {
2892 src0_p += swapped ? nc : 0;
2893 src1_p += swapped ? 0 : nc;
2894 }
2895
2896 ggml_vec_geglu_f16(n: nc, y: (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), x: src0_p, g: src1_p);
2897
2898#ifndef NDEBUG
2899 for (int k = 0; k < nc; k++) {
2900 const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2901 const float v = GGML_FP16_TO_FP32(x);
2902 GGML_UNUSED(v);
2903 assert(!isnan(v));
2904 assert(!isinf(v));
2905 }
2906#endif
2907 }
2908}
2909
2910static void ggml_compute_forward_geglu(
2911 const ggml_compute_params * params,
2912 ggml_tensor * dst) {
2913
2914 const ggml_tensor * src0 = dst->src[0];
2915
2916 switch (src0->type) {
2917 case GGML_TYPE_F32:
2918 {
2919 ggml_compute_forward_geglu_f32(params, dst);
2920 } break;
2921 case GGML_TYPE_F16:
2922 {
2923 ggml_compute_forward_geglu_f16(params, dst);
2924 } break;
2925 default:
2926 {
2927 GGML_ABORT("fatal error");
2928 }
2929 }
2930}
2931
2932// ggml_compute_forward_swiglu
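//
// reference: swiglu gates with SiLU, i.e. roughly y = silu(x) * g per element,
// using the same input/gate layout as reglu above (see ggml_vec_swiglu_*).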
2933
2934static void ggml_compute_forward_swiglu_f32(
2935 const ggml_compute_params * params,
2936 ggml_tensor * dst) {
2937
2938 const ggml_tensor * src0 = dst->src[0];
2939 const ggml_tensor * src1 = dst->src[1];
2940 char * src0_d = (char *) src0->data;
2941 char * src1_d = (char *) (src1 ? src1->data : src0->data);
2942 const size_t src0_o = src0->nb[1];
2943 const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
2944
2945 GGML_ASSERT(ggml_is_contiguous_1(src0));
2946 GGML_ASSERT(ggml_is_contiguous_1(dst));
2947
2948 if (src1) {
2949 GGML_ASSERT(ggml_is_contiguous_1(src1));
2950 GGML_ASSERT(src0->type == src1->type);
2951 }
2952
2953 const int ith = params->ith;
2954 const int nth = params->nth;
2955
2956 const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
2957 const int nr = ggml_nrows(tensor: src0);
2958
2959 GGML_ASSERT(dst->ne[0] == nc);
2960 GGML_ASSERT(ggml_nrows(dst) == nr);
2961
2962 const int32_t swapped = ggml_get_op_params_i32(tensor: dst, i: 1);
2963
2964 // rows per thread
2965 const int dr = (nr + nth - 1)/nth;
2966
2967 // row range for this thread
2968 const int ir0 = dr*ith;
2969 const int ir1 = MIN(ir0 + dr, nr);
2970
2971 for (int i1 = ir0; i1 < ir1; i1++) {
2972 float * src0_p = (float *) (src0_d + i1*src0_o);
2973 float * src1_p = (float *) (src1_d + i1*src1_o);
2974
2975 if (!src1) {
2976 src0_p += swapped ? nc : 0;
2977 src1_p += swapped ? 0 : nc;
2978 }
2979
2980 ggml_vec_swiglu_f32(n: nc, y: (float *) ((char *) dst->data + i1*(dst->nb[1])), x: src0_p, g: src1_p);
2981
2982#ifndef NDEBUG
2983 for (int k = 0; k < nc; k++) {
2984 const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2985 GGML_UNUSED(x);
2986 assert(!isnan(x));
2987 assert(!isinf(x));
2988 }
2989#endif
2990 }
2991}
2992
2993static void ggml_compute_forward_swiglu_f16(
2994 const ggml_compute_params * params,
2995 ggml_tensor * dst) {
2996
2997 const ggml_tensor * src0 = dst->src[0];
2998 const ggml_tensor * src1 = dst->src[1];
2999 char * src0_d = (char *) src0->data;
3000 char * src1_d = (char *) (src1 ? src1->data : src0->data);
3001 const size_t src0_o = src0->nb[1];
3002 const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3003
3004 GGML_ASSERT(ggml_is_contiguous_1(src0));
3005 GGML_ASSERT(ggml_is_contiguous_1(dst));
3006
3007 if (src1) {
3008 GGML_ASSERT(ggml_is_contiguous_1(src1));
3009 GGML_ASSERT(src0->type == src1->type);
3010 }
3011
3012 const int ith = params->ith;
3013 const int nth = params->nth;
3014
3015 const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3016 const int nr = ggml_nrows(tensor: src0);
3017
3018 GGML_ASSERT(dst->ne[0] == nc);
3019 GGML_ASSERT(ggml_nrows(dst) == nr);
3020
3021 const int32_t swapped = ggml_get_op_params_i32(tensor: dst, i: 1);
3022
3023 // rows per thread
3024 const int dr = (nr + nth - 1)/nth;
3025
3026 // row range for this thread
3027 const int ir0 = dr*ith;
3028 const int ir1 = MIN(ir0 + dr, nr);
3029
3030 for (int i1 = ir0; i1 < ir1; i1++) {
3031 ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
3032 ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
3033
3034 if (!src1) {
3035 src0_p += swapped ? nc : 0;
3036 src1_p += swapped ? 0 : nc;
3037 }
3038
3039 ggml_vec_swiglu_f16(n: nc, y: (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), x: src0_p, g: src1_p);
3040
3041#ifndef NDEBUG
3042 for (int k = 0; k < nc; k++) {
3043 const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3044 const float v = GGML_FP16_TO_FP32(x);
3045 GGML_UNUSED(v);
3046 assert(!isnan(v));
3047 assert(!isinf(v));
3048 }
3049#endif
3050 }
3051}
3052
3053static void ggml_compute_forward_swiglu(
3054 const ggml_compute_params * params,
3055 ggml_tensor * dst) {
3056
3057 const ggml_tensor * src0 = dst->src[0];
3058
3059 switch (src0->type) {
3060 case GGML_TYPE_F32:
3061 {
3062 ggml_compute_forward_swiglu_f32(params, dst);
3063 } break;
3064 case GGML_TYPE_F16:
3065 {
3066 ggml_compute_forward_swiglu_f16(params, dst);
3067 } break;
3068 default:
3069 {
3070 GGML_ABORT("fatal error");
3071 }
3072 }
3073}
3074
3075// ggml_compute_forward_swiglu_oai
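//
// clamped swiglu variant; per element (see the inner loop below):
//   x' = min(x, limit), g' = clamp(g, -limit, limit)
//   y  = x' * sigmoid(alpha * x') * (g' + 1)
// with alpha and limit read from op_params[2] and op_params[3].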
3076
3077static void ggml_compute_forward_swiglu_oai_f32(
3078 const ggml_compute_params * params,
3079 ggml_tensor * dst) {
3080
3081 const ggml_tensor * src0 = dst->src[0];
3082 const ggml_tensor * src1 = dst->src[1];
3083 char * src0_d = (char *) src0->data;
3084 char * src1_d = (char *) (src1 ? src1->data : src0->data);
3085 const size_t src0_o = src0->nb[1];
3086 const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3087
3088 GGML_ASSERT(ggml_is_contiguous_1(src0));
3089 GGML_ASSERT(ggml_is_contiguous_1(dst));
3090
3091 if (src1) {
3092 GGML_ASSERT(ggml_is_contiguous_1(src1));
3093 GGML_ASSERT(src0->type == src1->type);
3094 }
3095
3096 const int ith = params->ith;
3097 const int nth = params->nth;
3098
3099 const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3100 const int nr = ggml_nrows(tensor: src0);
3101
3102 GGML_ASSERT(dst->ne[0] == nc);
3103 GGML_ASSERT(ggml_nrows(dst) == nr);
3104
3105 const int32_t swapped = ggml_get_op_params_i32(tensor: dst, i: 1);
3106 const float alpha = ggml_get_op_params_f32(tensor: dst, i: 2);
3107 const float limit = ggml_get_op_params_f32(tensor: dst, i: 3);
3108
3109 // rows per thread
3110 const int dr = (nr + nth - 1)/nth;
3111
3112 // row range for this thread
3113 const int ir0 = dr*ith;
3114 const int ir1 = MIN(ir0 + dr, nr);
3115
3116 for (int i1 = ir0; i1 < ir1; i1++) {
3117 float * src0_p = (float *) (src0_d + i1*src0_o);
3118 float * src1_p = (float *) (src1_d + i1*src1_o);
3119 float * dst_p = (float *) ((char *) dst->data + i1*(dst->nb[1]));
3120
3121 if (!src1) {
3122 src0_p += swapped ? nc : 0;
3123 src1_p += swapped ? 0 : nc;
3124 }
3125
3126 for (int k = 0; k < nc; k++) {
3127 const float x = std::min(a: src0_p[k], b: limit);
3128 const float y = std::clamp(val: src1_p[k], lo: -limit, hi: limit);
3129 const float out_glu = x / (1.f + expf(x: alpha * (-x)));
3130 dst_p[k] = out_glu * (y + 1.f);
3131 }
3132
3133#ifndef NDEBUG
3134 for (int k = 0; k < nc; k++) {
3135 const float x = dst_p[k];
3136 GGML_UNUSED(x);
3137 assert(!isnan(x));
3138 assert(!isinf(x));
3139 }
3140#endif
3141 }
3142}
3143
3144static void ggml_compute_forward_swiglu_oai(
3145 const ggml_compute_params * params,
3146 ggml_tensor * dst) {
3147
3148 const ggml_tensor * src0 = dst->src[0];
3149
3150 switch (src0->type) {
3151 case GGML_TYPE_F32:
3152 {
3153 ggml_compute_forward_swiglu_oai_f32(params, dst);
3154 } break;
3155 default:
3156 {
3157 GGML_ABORT("fatal error");
3158 }
3159 }
3160}
3161
3162// ggml_compute_forward_geglu_erf
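//
// reference: like geglu, but the activation half goes through the erf-based
// (exact) GELU before multiplying by the gate (see ggml_vec_geglu_erf_*).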
3163
3164static void ggml_compute_forward_geglu_erf_f32(
3165 const ggml_compute_params * params,
3166 ggml_tensor * dst) {
3167
3168 const ggml_tensor * src0 = dst->src[0];
3169 const ggml_tensor * src1 = dst->src[1];
3170 char * src0_d = (char *) src0->data;
3171 char * src1_d = (char *) (src1 ? src1->data : src0->data);
3172 const size_t src0_o = src0->nb[1];
3173 const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3174
3175 GGML_ASSERT(ggml_is_contiguous_1(src0));
3176 GGML_ASSERT(ggml_is_contiguous_1(dst));
3177
3178 if (src1) {
3179 GGML_ASSERT(ggml_is_contiguous_1(src1));
3180 GGML_ASSERT(src0->type == src1->type);
3181 }
3182
3183 const int ith = params->ith;
3184 const int nth = params->nth;
3185
3186 const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3187 const int nr = ggml_nrows(tensor: src0);
3188
3189 GGML_ASSERT(dst->ne[0] == nc);
3190 GGML_ASSERT(ggml_nrows(dst) == nr);
3191
3192 const int32_t swapped = ggml_get_op_params_i32(tensor: dst, i: 1);
3193
3194 // rows per thread
3195 const int dr = (nr + nth - 1)/nth;
3196
3197 // row range for this thread
3198 const int ir0 = dr*ith;
3199 const int ir1 = MIN(ir0 + dr, nr);
3200
3201 for (int i1 = ir0; i1 < ir1; i1++) {
3202 float * src0_p = (float *) (src0_d + i1*src0_o);
3203 float * src1_p = (float *) (src1_d + i1*src1_o);
3204
3205 if (!src1) {
3206 src0_p += swapped ? nc : 0;
3207 src1_p += swapped ? 0 : nc;
3208 }
3209
3210 ggml_vec_geglu_erf_f32(n: nc, y: (float *) ((char *) dst->data + i1*(dst->nb[1])), x: src0_p, g: src1_p);
3211
3212#ifndef NDEBUG
3213 for (int k = 0; k < nc; k++) {
3214 const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3215 GGML_UNUSED(x);
3216 assert(!isnan(x));
3217 assert(!isinf(x));
3218 }
3219#endif
3220 }
3221}
3222
3223static void ggml_compute_forward_geglu_erf_f16(
3224 const ggml_compute_params * params,
3225 ggml_tensor * dst) {
3226
3227 const ggml_tensor * src0 = dst->src[0];
3228 const ggml_tensor * src1 = dst->src[1];
3229 char * src0_d = (char *) src0->data;
3230 char * src1_d = (char *) (src1 ? src1->data : src0->data);
3231 const size_t src0_o = src0->nb[1];
3232 const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3233
3234 GGML_ASSERT(ggml_is_contiguous_1(src0));
3235 GGML_ASSERT(ggml_is_contiguous_1(dst));
3236
3237 if (src1) {
3238 GGML_ASSERT(ggml_is_contiguous_1(src1));
3239 GGML_ASSERT(src0->type == src1->type);
3240 }
3241
3242 const int ith = params->ith;
3243 const int nth = params->nth;
3244
3245 const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3246 const int nr = ggml_nrows(tensor: src0);
3247
3248 GGML_ASSERT(dst->ne[0] == nc);
3249 GGML_ASSERT(ggml_nrows(dst) == nr);
3250
3251 const int32_t swapped = ggml_get_op_params_i32(tensor: dst, i: 1);
3252
3253 // rows per thread
3254 const int dr = (nr + nth - 1)/nth;
3255
3256 // row range for this thread
3257 const int ir0 = dr*ith;
3258 const int ir1 = MIN(ir0 + dr, nr);
3259
3260 for (int i1 = ir0; i1 < ir1; i1++) {
3261 ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
3262 ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
3263
3264 if (!src1) {
3265 src0_p += swapped ? nc : 0;
3266 src1_p += swapped ? 0 : nc;
3267 }
3268
3269 ggml_vec_geglu_erf_f16(n: nc, y: (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), x: src0_p, g: src1_p);
3270
3271#ifndef NDEBUG
3272 for (int k = 0; k < nc; k++) {
3273 const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3274 const float v = GGML_FP16_TO_FP32(x);
3275 GGML_UNUSED(v);
3276 assert(!isnan(v));
3277 assert(!isinf(v));
3278 }
3279#endif
3280 }
3281}
3282
3283static void ggml_compute_forward_geglu_erf(
3284 const ggml_compute_params * params,
3285 ggml_tensor * dst) {
3286
3287 const ggml_tensor * src0 = dst->src[0];
3288
3289 switch (src0->type) {
3290 case GGML_TYPE_F32:
3291 {
3292 ggml_compute_forward_geglu_erf_f32(params, dst);
3293 } break;
3294 case GGML_TYPE_F16:
3295 {
3296 ggml_compute_forward_geglu_erf_f16(params, dst);
3297 } break;
3298 default:
3299 {
3300 GGML_ABORT("fatal error");
3301 }
3302 }
3303}
3304
3305// ggml_compute_forward_geglu_quick
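//
// reference: like geglu, but using the quick (sigmoid) GELU approximation for
// the activation half (see ggml_vec_geglu_quick_*).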
3306
3307static void ggml_compute_forward_geglu_quick_f32(
3308 const ggml_compute_params * params,
3309 ggml_tensor * dst) {
3310
3311 const ggml_tensor * src0 = dst->src[0];
3312 const ggml_tensor * src1 = dst->src[1];
3313 char * src0_d = (char *) src0->data;
3314 char * src1_d = (char *) (src1 ? src1->data : src0->data);
3315 const size_t src0_o = src0->nb[1];
3316 const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3317
3318 GGML_ASSERT(ggml_is_contiguous_1(src0));
3319 GGML_ASSERT(ggml_is_contiguous_1(dst));
3320
3321 if (src1) {
3322 GGML_ASSERT(ggml_is_contiguous_1(src1));
3323 GGML_ASSERT(src0->type == src1->type);
3324 }
3325
3326 const int ith = params->ith;
3327 const int nth = params->nth;
3328
3329 const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3330 const int nr = ggml_nrows(tensor: src0);
3331
3332 GGML_ASSERT(dst->ne[0] == nc);
3333 GGML_ASSERT(ggml_nrows(dst) == nr);
3334
3335 const int32_t swapped = ggml_get_op_params_i32(tensor: dst, i: 1);
3336
3337 // rows per thread
3338 const int dr = (nr + nth - 1)/nth;
3339
3340 // row range for this thread
3341 const int ir0 = dr*ith;
3342 const int ir1 = MIN(ir0 + dr, nr);
3343
3344 for (int i1 = ir0; i1 < ir1; i1++) {
3345 float * src0_p = (float *) (src0_d + i1*src0_o);
3346 float * src1_p = (float *) (src1_d + i1*src1_o);
3347
3348 if (!src1) {
3349 src0_p += swapped ? nc : 0;
3350 src1_p += swapped ? 0 : nc;
3351 }
3352
3353 ggml_vec_geglu_quick_f32(n: nc, y: (float *) ((char *) dst->data + i1*(dst->nb[1])), x: src0_p, g: src1_p);
3354
3355#ifndef NDEBUG
3356 for (int k = 0; k < nc; k++) {
3357 const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3358 GGML_UNUSED(x);
3359 assert(!isnan(x));
3360 assert(!isinf(x));
3361 }
3362#endif
3363 }
3364}
3365
3366static void ggml_compute_forward_geglu_quick_f16(
3367 const ggml_compute_params * params,
3368 ggml_tensor * dst) {
3369
3370 const ggml_tensor * src0 = dst->src[0];
3371 const ggml_tensor * src1 = dst->src[1];
3372 char * src0_d = (char *) src0->data;
3373 char * src1_d = (char *) (src1 ? src1->data : src0->data);
3374 const size_t src0_o = src0->nb[1];
3375 const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3376
3377 GGML_ASSERT(ggml_is_contiguous_1(src0));
3378 GGML_ASSERT(ggml_is_contiguous_1(dst));
3379
3380 if (src1) {
3381 GGML_ASSERT(ggml_is_contiguous_1(src1));
3382 GGML_ASSERT(src0->type == src1->type);
3383 }
3384
3385 const int ith = params->ith;
3386 const int nth = params->nth;
3387
3388 const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3389 const int nr = ggml_nrows(tensor: src0);
3390
3391 GGML_ASSERT(dst->ne[0] == nc);
3392 GGML_ASSERT(ggml_nrows(dst) == nr);
3393
3394 const int32_t swapped = ggml_get_op_params_i32(tensor: dst, i: 1);
3395
3396 // rows per thread
3397 const int dr = (nr + nth - 1)/nth;
3398
3399 // row range for this thread
3400 const int ir0 = dr*ith;
3401 const int ir1 = MIN(ir0 + dr, nr);
3402
3403 for (int i1 = ir0; i1 < ir1; i1++) {
3404 ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
3405 ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
3406
3407 if (!src1) {
3408 src0_p += swapped ? nc : 0;
3409 src1_p += swapped ? 0 : nc;
3410 }
3411
3412 ggml_vec_geglu_quick_f16(n: nc, y: (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), x: src0_p, g: src1_p);
3413
3414#ifndef NDEBUG
3415 for (int k = 0; k < nc; k++) {
3416 const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3417 const float v = GGML_FP16_TO_FP32(x);
3418 GGML_UNUSED(v);
3419 assert(!isnan(v));
3420 assert(!isinf(v));
3421 }
3422#endif
3423 }
3424}
3425
3426static void ggml_compute_forward_geglu_quick(
3427 const ggml_compute_params * params,
3428 ggml_tensor * dst) {
3429
3430 const ggml_tensor * src0 = dst->src[0];
3431
3432 switch (src0->type) {
3433 case GGML_TYPE_F32:
3434 {
3435 ggml_compute_forward_geglu_quick_f32(params, dst);
3436 } break;
3437 case GGML_TYPE_F16:
3438 {
3439 ggml_compute_forward_geglu_quick_f16(params, dst);
3440 } break;
3441 default:
3442 {
3443 GGML_ABORT("fatal error");
3444 }
3445 }
3446}
3447
3448// ggml_compute_forward_norm
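//
// reference: standard layer normalization over the first dimension (ne00):
//   y = (x - mean(x)) / sqrt(var(x) + eps)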
3449
3450static void ggml_compute_forward_norm_f32(
3451 const ggml_compute_params * params,
3452 ggml_tensor * dst) {
3453
3454 const ggml_tensor * src0 = dst->src[0];
3455
3456 GGML_ASSERT(ggml_are_same_shape(src0, dst));
3457
3458 GGML_ASSERT(src0->nb[0] == sizeof(float));
3459
3460 const int ith = params->ith;
3461 const int nth = params->nth;
3462
3463 GGML_TENSOR_UNARY_OP_LOCALS
3464
3465 float eps;
3466 memcpy(dest: &eps, src: dst->op_params, n: sizeof(float));
3467
3468 GGML_ASSERT(eps >= 0.0f);
3469
3470 for (int64_t i03 = 0; i03 < ne03; i03++) {
3471 for (int64_t i02 = 0; i02 < ne02; i02++) {
3472 for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
3473 const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
3474
3475 float sum = 0.0;
3476 ggml_vec_sum_f32(n: ne00, s: &sum, x);
3477 float mean = sum/ne00;
3478
3479 float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
3480 float variance = 0;
3481
3482#ifdef GGML_USE_ACCELERATE
3483 mean = -mean;
3484 vDSP_vsadd(x, 1, &mean, y, 1, ne00);
3485 vDSP_measqv(y, 1, &variance, ne00);
3486#else
3487 variance = ggml_vec_cvar_f32(n: ne00, y, x, mean);
3488#endif //GGML_USE_ACCELERATE
3489
3490 const float scale = 1.0f/sqrtf(x: variance + eps);
3491 ggml_vec_scale_f32(n: ne00, y, v: scale);
3492 }
3493 }
3494 }
3495}
3496
3497void ggml_compute_forward_norm(
3498 const ggml_compute_params * params,
3499 ggml_tensor * dst) {
3500
3501 const ggml_tensor * src0 = dst->src[0];
3502
3503 switch (src0->type) {
3504 case GGML_TYPE_F32:
3505 {
3506 ggml_compute_forward_norm_f32(params, dst);
3507 } break;
3508 default:
3509 {
3510 GGML_ABORT("fatal error");
3511 }
3512 }
3513}
3514
// ggml_compute_forward_rms_norm
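//
// reference: RMS normalization over the first dimension (ne00):
//   y = x / sqrt(mean(x^2) + eps)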
3516
3517static void ggml_compute_forward_rms_norm_f32(
3518 const ggml_compute_params * params,
3519 ggml_tensor * dst) {
3520
3521 const ggml_tensor * src0 = dst->src[0];
3522
3523 GGML_ASSERT(ggml_are_same_shape(src0, dst));
3524
3525 GGML_ASSERT(src0->nb[0] == sizeof(float));
3526
3527 const int ith = params->ith;
3528 const int nth = params->nth;
3529
3530 GGML_TENSOR_UNARY_OP_LOCALS
3531
3532 float eps;
3533 memcpy(dest: &eps, src: dst->op_params, n: sizeof(float));
3534
3535 GGML_ASSERT(eps >= 0.0f);
3536
3537 // TODO: optimize
3538 for (int64_t i03 = 0; i03 < ne03; i03++) {
3539 for (int64_t i02 = 0; i02 < ne02; i02++) {
3540 for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
3541 const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
3542
3543 ggml_float sum = 0.0;
3544 for (int64_t i00 = 0; i00 < ne00; i00++) {
3545 sum += (ggml_float)(x[i00] * x[i00]);
3546 }
3547
3548 const float mean = sum/ne00;
3549
3550 float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
3551
3552 memcpy(dest: y, src: x, n: ne00 * sizeof(float));
3553 // for (int i00 = 0; i00 < ne00; i00++) {
3554 // y[i00] = x[i00];
3555 // }
3556
3557 const float scale = 1.0f/sqrtf(x: mean + eps);
3558
3559 // if you hit this, likely you got an inf somewhere earlier
3560 assert(scale > 0.0f);
3561
3562 ggml_vec_scale_f32(n: ne00, y, v: scale);
3563 }
3564 }
3565 }
3566}
3567
3568void ggml_compute_forward_rms_norm(
3569 const ggml_compute_params * params,
3570 ggml_tensor * dst) {
3571
3572 const ggml_tensor * src0 = dst->src[0];
3573
3574 switch (src0->type) {
3575 case GGML_TYPE_F32:
3576 {
3577 ggml_compute_forward_rms_norm_f32(params, dst);
3578 } break;
3579 default:
3580 {
3581 GGML_ABORT("fatal error");
3582 }
3583 }
3584}
3585
3586static void ggml_compute_forward_rms_norm_back_f32(
3587 const ggml_compute_params * params,
3588 ggml_tensor * dst) {
3589
3590 const ggml_tensor * src0 = dst->src[0]; // gradients from forward pass output
3591 const ggml_tensor * src1 = dst->src[1]; // src1 from forward pass
3592
3593 GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1));
3594
3595 GGML_ASSERT(src0->nb[0] == sizeof(float));
3596 GGML_ASSERT(src1->nb[0] == sizeof(float));
3597
3598 const int ith = params->ith;
3599 const int nth = params->nth;
3600
3601 GGML_TENSOR_BINARY_OP_LOCALS
3602
3603 float eps;
3604 memcpy(dest: &eps, src: dst->op_params, n: sizeof(float));
3605
3606 // TODO: optimize
3607 for (int64_t i03 = 0; i03 < ne03; i03++) {
3608 for (int64_t i02 = 0; i02 < ne02; i02++) {
3609 for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
3610 // src1 is same shape as src0 => same indices
3611 const int64_t i11 = i01;
3612 const int64_t i12 = i02;
3613 const int64_t i13 = i03;
3614
3615 const float * dz = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
3616 const float * x = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
3617
3618 ggml_float sum_xx = 0.0;
3619 ggml_float sum_xdz = 0.0;
3620
3621 for (int64_t i00 = 0; i00 < ne00; i00++) {
3622 sum_xx += (ggml_float)(x[i00] * x[i00]);
3623 sum_xdz += (ggml_float)(x[i00] * dz[i00]);
3624 }
3625
3626 //const float mean = (float)(sum_xx)/ne00;
3627 const float mean_eps = (float)(sum_xx)/ne00 + eps;
3628 const float sum_eps = (float)(sum_xx) + eps*ne00;
3629 //const float mean_xdz = (float)(sum_xdz)/ne00;
3630 // we could cache rms from forward pass to improve performance.
3631 // to do this implement ggml_rms and compose ggml_rms_norm using ggml_rms.
3632 //const float rms = sqrtf(mean_eps);
3633 const float rrms = 1.0f / sqrtf(x: mean_eps);
3634 //const float scale = -rrms/(ne00 * mean_eps); // -1/(n*rms**3)
3635
3636 {
3637 // z = rms_norm(x)
3638 //
3639 // rms_norm(src1) =
3640 // scale(
3641 // src1,
3642 // div(
3643 // 1,
3644 // sqrt(
3645 // add(
3646 // scale(
3647 // sum(
3648 // sqr(
3649 // src1)),
3650 // (1.0/N)),
3651 // eps))));
3652
3653 // postorder:
3654 // ## op args grad
3655 // 00 param src1 grad[#00]
3656 // 01 const 1
3657 // 02 sqr (#00) grad[#02]
3658 // 03 sum (#02) grad[#03]
3659 // 04 const 1/N
3660 // 05 scale (#03, #04) grad[#05]
3661 // 06 const eps
3662 // 07 add (#05, #06) grad[#07]
3663 // 08 sqrt (#07) grad[#08]
3664 // 09 div (#01,#08) grad[#09]
3665 // 10 scale (#00,#09) grad[#10]
3666 //
3667 // backward pass, given grad[#10]
3668 // #10: scale
3669 // grad[#00] += scale(grad[#10],#09)
3670 // grad[#09] += sum(mul(grad[#10],#00))
3671 // #09: div
3672 // grad[#08] += neg(mul(grad[#09], div(#09,#08)))
3673 // #08: sqrt
3674 // grad[#07] += mul(grad[#08], div(0.5, #08))
3675 // #07: add
3676 // grad[#05] += grad[#07]
3677 // #05: scale
3678 // grad[#03] += scale(grad[#05],#04)
3679 // #03: sum
3680 // grad[#02] += repeat(grad[#03], #02)
3681 // #02:
3682 // grad[#00] += scale(mul(#00, grad[#02]), 2.0)
3683 //
3684 // substitute and simplify:
3685 // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0)
3686 // grad[#02] = repeat(grad[#03], #02)
3687 // grad[#02] = repeat(scale(grad[#05],#04), #02)
3688 // grad[#02] = repeat(scale(grad[#07],#04), #02)
3689 // grad[#02] = repeat(scale(mul(grad[#08], div(0.5, #08)),#04), #02)
3690 // grad[#02] = repeat(scale(mul(neg(mul(grad[#09], div(#09,#08))), div(0.5, #08)),#04), #02)
3691 // grad[#02] = repeat(scale(mul(neg(mul(sum(mul(grad[#10],#00)), div(#09,#08))), div(0.5, #08)),#04), #02)
3692 // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(#09,#08) * div(0.5, #08) * (1/N)), #02)
3693 // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(div(#01,#08),#08) * div(0.5, #08) * (1/N)), #02)
3694 // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#08*#08) * div(0.5, #08) * (1/N)), #02)
3695 // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)
3696 // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0)
3697 // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)), 2.0)
3698 // grad[#00] = scale(grad(#10), #09) + scale(scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N))), 2.0)
3699 // grad[#00] = scale(grad(#10), #09) + scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(1,#08) * (1/N)))
3700 // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N))
3701 // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N))
3702 // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,mean_eps*rms) * (-1/N))
3703 // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*mean_eps))
3704 // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*(sum_xx/N+eps)))
3705 // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*sum_xx+rms*N*eps))
3706 // grad[#00] = scale(dz, rrms) + scale(x, sum(mul(dz,x)) * div(-1,rms*N*mean_eps))
3707 // grad[#00] = scale(dz, rrms) + scale(x, sum_xdz * div(-1,rms*N*mean_eps))
3708 // a = b*c + d*e
3709 // a = b*c*f/f + d*e*f/f
3710 // a = (b*c*f + d*e*f)*(1/f)
3711 // a = (b*c*(1/c) + d*e*(1/c))*(1/(1/c))
3712 // a = (b + d*e/c)*c
3713 // b = dz, c = rrms, d = x, e = sum_xdz * div(-1,rms*N*mean_eps)
3714 // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)/rrms)*rrms
3715 // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)*rms)*rrms
3716 // a = (dz + x*sum_xdz * div(-rms,rms*N*mean_eps))*rrms
3717 // a = (dz + x*sum_xdz * div(-1,N*mean_eps))*rrms
3718 // a = (dz + x*div(-sum_xdz,N*mean_eps))*rrms
3719 // a = (dz + x*div(-mean_xdz,mean_eps))*rrms
3720 // grad[#00] = scale(dz + scale(x, div(-mean_xdz,mean_eps)),rrms)
3721 // grad[#00] = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
3722 // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
3723 }
3724 // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
3725 // post-order:
3726 // dx := x
3727 // dx := scale(dx,-mean_xdz/mean_eps)
3728 // dx := add(dx, dz)
3729 // dx := scale(dx, rrms)
3730 float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
3731
3732 // dx[i00] = (x*(-sum_xdz/sum_eps) + dz) / sqrtf(mean_eps)
3733 ggml_vec_cpy_f32 (n: ne00, y: dx, x);
3734 // ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps);
3735 ggml_vec_scale_f32(n: ne00, y: dx, v: (float)(-sum_xdz)/sum_eps);
3736 ggml_vec_acc_f32 (n: ne00, y: dx, x: dz);
3737 ggml_vec_scale_f32(n: ne00, y: dx, v: rrms);
3738 }
3739 }
3740 }
3741}
3742
3743void ggml_compute_forward_rms_norm_back(
3744 const ggml_compute_params * params,
3745 ggml_tensor * dst) {
3746
3747 const ggml_tensor * src0 = dst->src[0];
3748
3749 switch (src0->type) {
3750 case GGML_TYPE_F32:
3751 {
3752 ggml_compute_forward_rms_norm_back_f32(params, dst);
3753 } break;
3754 default:
3755 {
3756 GGML_ABORT("fatal error");
3757 }
3758 }
3759}
3760
3761// ggml_compute_forward_group_norm
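//
// reference: group normalization; the channel dimension (ne02) is split into
// op_params[0] groups and each group is normalized over all of its
// ne00*ne01*channels elements:
//   y = (x - mean) / sqrt(var + eps)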
3762
3763static void ggml_compute_forward_group_norm_f32(
3764 const ggml_compute_params * params,
3765 ggml_tensor * dst) {
3766
3767 const ggml_tensor * src0 = dst->src[0];
3768
3769 GGML_ASSERT(ggml_are_same_shape(src0, dst));
3770
3771 GGML_ASSERT(src0->nb[0] == sizeof(float));
3772
3773 const int ith = params->ith;
3774 const int nth = params->nth;
3775
3776 GGML_TENSOR_UNARY_OP_LOCALS
3777
3778 // TODO: optimize
3779
3780 float eps;
3781 memcpy(dest: &eps, src: dst->op_params + 1, n: sizeof(float));
3782
3783 int n_channels = src0->ne[2];
3784 int n_groups = dst->op_params[0];
3785 int n_channels_per_group = (n_channels + n_groups - 1) / n_groups;
3786 for (int i = ith; i < n_groups; i += nth) {
3787 int start = i * n_channels_per_group;
3788 int end = start + n_channels_per_group;
3789 if (end > n_channels) {
3790 end = n_channels;
3791 }
3792 int step = end - start;
3793
3794 for (int64_t i03 = 0; i03 < ne03; i03++) {
3795 ggml_float sum = 0.0;
3796 for (int64_t i02 = start; i02 < end; i02++) {
3797 for (int64_t i01 = 0; i01 < ne01; i01++) {
3798 const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
3799
3800 ggml_float sumr = 0.0;
3801 for (int64_t i00 = 0; i00 < ne00; i00++) {
3802 sumr += (ggml_float)x[i00];
3803 }
3804 sum += sumr;
3805 }
3806 }
3807 const float mean = sum / (ne00 * ne01 * step);
3808
3809 ggml_float sum2 = 0.0;
3810 for (int64_t i02 = start; i02 < end; i02++) {
3811 for (int64_t i01 = 0; i01 < ne01; i01++) {
3812 const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
3813
3814 float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
3815
3816 ggml_float sumr = 0.0;
3817 for (int64_t i00 = 0; i00 < ne00; i00++) {
3818 float v = x[i00] - mean;
3819 y[i00] = v;
3820 sumr += (ggml_float)(v * v);
3821 }
3822 sum2 += sumr;
3823 }
3824 }
3825 const float variance = sum2 / (ne00 * ne01 * step);
3826 const float scale = 1.0f / sqrtf(x: variance + eps);
3827
3828 for (int64_t i02 = start; i02 < end; i02++) {
3829 for (int64_t i01 = 0; i01 < ne01; i01++) {
3830 float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
3831 ggml_vec_scale_f32(n: ne00, y, v: scale);
3832 }
3833 }
3834 }
3835 }
3836}
3837
3838void ggml_compute_forward_group_norm(
3839 const ggml_compute_params * params,
3840 ggml_tensor * dst) {
3841
3842 const ggml_tensor * src0 = dst->src[0];
3843
3844 switch (src0->type) {
3845 case GGML_TYPE_F32:
3846 {
3847 ggml_compute_forward_group_norm_f32(params, dst);
3848 } break;
3849 default:
3850 {
3851 GGML_ABORT("fatal error");
3852 }
3853 }
3854}
3855
3856// ggml_compute_forward_l2_norm
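//
// reference: L2 (euclidean) normalization over the first dimension:
//   y = x / max(||x||_2, eps)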
3857
3858static void ggml_compute_forward_l2_norm_f32(
3859 const ggml_compute_params * params,
3860 ggml_tensor * dst) {
3861
3862 const ggml_tensor * src0 = dst->src[0];
3863
3864 GGML_ASSERT(ggml_are_same_shape(src0, dst));
3865
3866 GGML_ASSERT(src0->nb[0] == sizeof(float));
3867
3868 const int ith = params->ith;
3869 const int nth = params->nth;
3870
3871 GGML_TENSOR_UNARY_OP_LOCALS
3872
3873 float eps;
3874 memcpy(dest: &eps, src: dst->op_params, n: sizeof(float));
3875
3876 GGML_ASSERT(eps >= 0.0f);
3877
3878 // TODO: optimize
3879 for (int64_t i03 = 0; i03 < ne03; i03++) {
3880 for (int64_t i02 = 0; i02 < ne02; i02++) {
3881 for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
3882 const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
3883
3884 ggml_float sum = 0.0;
3885 for (int64_t i00 = 0; i00 < ne00; i00++) {
3886 sum += (ggml_float)(x[i00] * x[i00]);
3887 }
3888
3889 float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
3890
3891 memcpy(dest: y, src: x, n: ne00 * sizeof(float));
3892
3893 const float scale = 1.0f/fmaxf(x: sqrtf(x: sum), y: eps);
3894
3895 ggml_vec_scale_f32(n: ne00, y, v: scale);
3896 }
3897 }
3898 }
3899}
3900
3901void ggml_compute_forward_l2_norm(
3902 const ggml_compute_params * params,
3903 ggml_tensor * dst) {
3904
3905 const ggml_tensor * src0 = dst->src[0];
3906
3907 switch (src0->type) {
3908 case GGML_TYPE_F32:
3909 {
3910 ggml_compute_forward_l2_norm_f32(params, dst);
3911 } break;
3912 default:
3913 {
3914 GGML_ABORT("fatal error");
3915 }
3916 }
3917}
3918
3919// ggml_compute_forward_out_prod
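//
// reference: for each (i2, i3) slice this accumulates an outer product over
// the shared index i01:
//   dst[i0, i1] += src0[i0, i01] * src1[i1, i01]
// i.e. dst = src0 * src1^T per slice, with src0's dims 2/3 optionally
// broadcast (see dps2/dps3 in the f32 kernel).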
3920
3921static void ggml_compute_forward_out_prod_f32(
3922 const ggml_compute_params * params,
3923 ggml_tensor * dst) {
3924
3925 const ggml_tensor * src0 = dst->src[0];
3926 const ggml_tensor * src1 = dst->src[1];
3927
3928 GGML_TENSOR_BINARY_OP_LOCALS
3929
3930 GGML_ASSERT(dst->type == GGML_TYPE_F32);
3931 GGML_ASSERT(src0->type == GGML_TYPE_F32);
3932 GGML_ASSERT(src1->type == GGML_TYPE_F32);
3933
3934 const int ith = params->ith;
3935 const int nth = params->nth;
3936
3937 GGML_ASSERT(ne0 == ne00);
3938 GGML_ASSERT(ne1 == ne10);
3939 GGML_ASSERT(ne2 == ne12);
3940 GGML_ASSERT(ne3 == ne13);
3941
3942 GGML_ASSERT(ne2 % ne02 == 0);
3943 GGML_ASSERT(ne3 % ne03 == 0);
3944
3945 // we don't support permuted src0 or src1
3946 GGML_ASSERT(nb00 == sizeof(float));
3947
3948 // dst cannot be transposed or permuted
3949 GGML_ASSERT(nb0 == sizeof(float));
3950 // GGML_ASSERT(nb0 <= nb1);
3951 // GGML_ASSERT(nb1 <= nb2);
3952 // GGML_ASSERT(nb2 <= nb3);
3953
3954 // nb01 >= nb00 - src0 is not transposed
3955 // compute by src0 rows
3956
3957 if (ith == 0) {
3958 ggml_vec_set_f32(n: ne0*ne1*ne2*ne3, x: (float *)dst->data, v: 0);
3959 }
3960 ggml_barrier(tp: params->threadpool);
3961
3962 // dst[:,:,:,:] = 0
3963 // for i2,i3:
3964 // for i1:
3965 // for i01:
3966 // for i0:
3967 // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
3968
3969 // parallelize by last three dimensions
3970
3971 // total rows in dst
3972 const int64_t nr = ne1*ne2*ne3;
3973
3974 // rows per thread
3975 const int64_t dr = (nr + nth - 1)/nth;
3976
3977 // row range for this thread
3978 const int64_t ir0 = dr*ith;
3979 const int64_t ir1 = MIN(ir0 + dr, nr);
3980
3981 // block-tiling attempt
3982 const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
3983 const int64_t blck_1 = 16;
3984
3985 // dps == dst per src0, used for group query attention
3986 const int64_t dps2 = ne2 / ne02;
3987 const int64_t dps3 = ne3 / ne03;
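// i02 = i2 / dps2 (and i03 = i3 / dps3) below broadcasts each src0 slice across dps2 (dps3)
// consecutive dst slices, e.g. with dps2 == 2, dst slices {0,1} read src0 slice 0 and {2,3} read slice 1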
3988
3989 for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
3990 const int64_t bir1 = MIN(bir + blck_1, ir1);
3991 for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
3992 const int64_t bne01 = MIN(bi01 + blck_0, ne01);
3993 for (int64_t ir = bir; ir < bir1; ++ir) {
3994 // dst indices
3995 const int64_t i3 = ir/(ne2*ne1);
3996 const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
3997 const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
3998
3999 const int64_t i02 = i2 / dps2;
4000 const int64_t i03 = i3 / dps3;
4001
4002 //const int64_t i10 = i1;
4003 const int64_t i12 = i2;
4004 const int64_t i13 = i3;
4005
4006#if GGML_VEC_MAD_UNROLL > 2
4007 const int64_t bne01_unroll = bne01 - (bne01 % GGML_VEC_MAD_UNROLL);
4008 for (int64_t i01 = bi01; i01 < bne01_unroll; i01 += GGML_VEC_MAD_UNROLL) {
4009 const int64_t i11 = i01;
4010
4011 float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
4012 float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
4013 float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
4014
4015 ggml_vec_mad_f32_unroll(n: ne0, xs: nb01, vs: nb11, y: d, xv: s0, vv: s1);
4016 }
4017 for (int64_t i01 = bne01_unroll; i01 < bne01; ++i01) {
4018 const int64_t i11 = i01;
4019
4020 float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
4021 float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
4022 float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
4023
4024 ggml_vec_mad_f32(n: ne0, y: d, x: s0, v: *s1);
4025 }
4026#else
4027 for (int64_t i01 = bi01; i01 < bne01; ++i01) {
4028 const int64_t i11 = i01;
4029
4030 float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
4031 float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
4032 float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
4033
4034 ggml_vec_mad_f32(ne0, d, s0, *s1);
4035 }
4036#endif
4037 }
4038 }
4039 }
4040}
4041
4042static void ggml_compute_forward_out_prod_q_f32(
4043 const ggml_compute_params * params,
4044 ggml_tensor * dst) {
4045
4046 const ggml_tensor * src0 = dst->src[0];
4047 const ggml_tensor * src1 = dst->src[1];
4048
4049 GGML_TENSOR_BINARY_OP_LOCALS;
4050
4051 const int ith = params->ith;
4052 const int nth = params->nth;
4053
4054 const ggml_type type = src0->type;
4055 ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
4056
4057 GGML_ASSERT(ne02 == ne12);
4058 GGML_ASSERT(ne03 == ne13);
4059 GGML_ASSERT(ne2 == ne12);
4060 GGML_ASSERT(ne3 == ne13);
4061
4062 // we don't support permuted src0 dim0
4063 GGML_ASSERT(nb00 == ggml_type_size(type));
4064
4065 // dst dim0 cannot be transposed or permuted
4066 GGML_ASSERT(nb0 == sizeof(float));
4067 // GGML_ASSERT(nb0 <= nb1);
4068 // GGML_ASSERT(nb1 <= nb2);
4069 // GGML_ASSERT(nb2 <= nb3);
4070
4071 GGML_ASSERT(ne0 == ne00);
4072 GGML_ASSERT(ne1 == ne10);
4073 GGML_ASSERT(ne2 == ne02);
4074 GGML_ASSERT(ne3 == ne03);
4075
4076 // nb01 >= nb00 - src0 is not transposed
4077 // compute by src0 rows
4078
4079 if (ith == 0) {
4080 ggml_vec_set_f32(n: ne0*ne1*ne2*ne3, x: (float *)dst->data, v: 0);
4081 }
4082 ggml_barrier(tp: params->threadpool);
4083
4084 // parallelize by last three dimensions
4085
4086 // total rows in dst
4087 const int64_t nr = ne1*ne2*ne3;
4088
4089 // rows per thread
4090 const int64_t dr = (nr + nth - 1)/nth;
4091
4092 // row range for this thread
4093 const int64_t ir0 = dr*ith;
4094 const int64_t ir1 = MIN(ir0 + dr, nr);
4095
4096 // dst[:,:,:,:] = 0
4097 // for i2,i3:
4098 // for i1:
4099 // for i01:
4100 // for i0:
4101 // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
4102
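// per-thread scratch buffer holding one dequantized src0 row (ne0 floats); the buffers of
// different threads are spaced CACHE_LINE_SIZE_F32 floats apart, presumably so that threads
// do not write to the same cache line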
4103 float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
4104
4105 for (int64_t ir = ir0; ir < ir1; ++ir) {
4106 // dst indices
4107 const int64_t i3 = ir/(ne2*ne1);
4108 const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
4109 const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
4110
4111 const int64_t i02 = i2;
4112 const int64_t i03 = i3;
4113
4114 //const int64_t i10 = i1;
4115 const int64_t i12 = i2;
4116 const int64_t i13 = i3;
4117
4118 for (int64_t i01 = 0; i01 < ne01; ++i01) {
4119 const int64_t i11 = i01;
4120
4121 float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
4122 float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
4123 float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
4124
4125 dequantize_row_q(s0, wdata, ne0);
4126 ggml_vec_mad_f32(n: ne0, y: d, x: wdata, v: *s1);
4127 }
4128 }
4129}
4130
4131void ggml_compute_forward_out_prod(
4132 const ggml_compute_params * params,
4133 ggml_tensor * dst) {
4134
4135 const ggml_tensor * src0 = dst->src[0];
4136
4137 switch (src0->type) {
4138 case GGML_TYPE_Q4_0:
4139 case GGML_TYPE_Q4_1:
4140 case GGML_TYPE_Q5_0:
4141 case GGML_TYPE_Q5_1:
4142 case GGML_TYPE_Q8_0:
4143 case GGML_TYPE_MXFP4:
4144 case GGML_TYPE_Q2_K:
4145 case GGML_TYPE_Q3_K:
4146 case GGML_TYPE_Q4_K:
4147 case GGML_TYPE_Q5_K:
4148 case GGML_TYPE_Q6_K:
4149 case GGML_TYPE_TQ1_0:
4150 case GGML_TYPE_TQ2_0:
4151 case GGML_TYPE_IQ2_XXS:
4152 case GGML_TYPE_IQ2_XS:
4153 case GGML_TYPE_IQ3_XXS:
4154 case GGML_TYPE_IQ1_S:
4155 case GGML_TYPE_IQ1_M:
4156 case GGML_TYPE_IQ4_NL:
4157 case GGML_TYPE_IQ4_XS:
4158 case GGML_TYPE_IQ3_S:
4159 case GGML_TYPE_IQ2_S:
4160 {
4161 ggml_compute_forward_out_prod_q_f32(params, dst);
4162 } break;
4163 case GGML_TYPE_F16:
4164 {
4165 GGML_ABORT("fatal error"); // todo
4166 // ggml_compute_forward_out_prod_f16_f32(params, dst);
4167 } break;
4168 case GGML_TYPE_F32:
4169 {
4170 ggml_compute_forward_out_prod_f32(params, dst);
4171 } break;
4172 default:
4173 {
4174 GGML_ABORT("fatal error");
4175 }
4176 }
4177}
4178
4179// ggml_compute_forward_scale
4180
4181static void ggml_compute_forward_scale_f32(
4182 const ggml_compute_params * params,
4183 ggml_tensor * dst) {
4184
4185 const ggml_tensor * src0 = dst->src[0];
4186
4187 GGML_ASSERT(ggml_is_contiguous(src0));
4188 GGML_ASSERT(ggml_is_contiguous(dst));
4189 GGML_ASSERT(ggml_are_same_shape(src0, dst));
4190
4191 float s; // scale factor
4192 float b; // bias
4193
4194 memcpy(dest: &s, src: (float *) dst->op_params + 0, n: sizeof(float));
4195 memcpy(dest: &b, src: (float *) dst->op_params + 1, n: sizeof(float));
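// computes dst = s*src0 + b element-wise; the b == 0.0f path below reduces to a plain scale
// (with an optional copy when dst and src0 are different buffers), otherwise a fused
// multiply-add is applied per row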
4196
4197 const int ith = params->ith;
4198 const int nth = params->nth;
4199
4200 const int nc = src0->ne[0];
4201 const int nr = ggml_nrows(tensor: src0);
4202
4203 // rows per thread
4204 const int dr = (nr + nth - 1)/nth;
4205
4206 // row range for this thread
4207 const int ir0 = dr*ith;
4208 const int ir1 = MIN(ir0 + dr, nr);
4209
4210 const size_t nb01 = src0->nb[1];
4211
4212 const size_t nb1 = dst->nb[1];
4213
4214 if (b == 0.0f) {
4215 for (int i1 = ir0; i1 < ir1; i1++) {
4216 if (dst->data != src0->data) {
4217 // src0 is same shape as dst => same indices
4218 // TODO: add x parameter to ggml_vec_scale_f32 and remove this memcpy
4219 memcpy(dest: (char *)dst->data + i1*nb1, src: (char *)src0->data + i1*nb01, n: nc * sizeof(float));
4220 }
4221 ggml_vec_scale_f32(n: nc, y: (float *) ((char *) dst->data + i1*nb1), v: s);
4222 }
4223 } else {
4224 for (int i1 = ir0; i1 < ir1; i1++) {
4225 ggml_vec_mad1_f32(n: nc,
4226 y: (float *) ((char *) dst->data + i1*nb1),
4227 x: (float *) ((char *) src0->data + i1*nb01),
4228 s, b);
4229 }
4230 }
4231}
4232
4233void ggml_compute_forward_scale(
4234 const ggml_compute_params * params,
4235 ggml_tensor * dst) {
4236
4237 const ggml_tensor * src0 = dst->src[0];
4238
4239 switch (src0->type) {
4240 case GGML_TYPE_F32:
4241 {
4242 ggml_compute_forward_scale_f32(params, dst);
4243 } break;
4244 default:
4245 {
4246 GGML_ABORT("fatal error");
4247 }
4248 }
4249}
4250
4251// ggml_compute_forward_set
4252
4253static void ggml_compute_forward_set_f32(
4254 const ggml_compute_params * params,
4255 ggml_tensor * dst) {
4256
4257 const ggml_tensor * src0 = dst->src[0];
4258 const ggml_tensor * src1 = dst->src[1];
4259
4260 GGML_ASSERT(ggml_are_same_shape(src0, dst));
4261 GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
4262
4263 // view src0 and dst with these strides and data offset in bytes during set
4264 // nb0 is implicitly element_size because src0 and dst are contiguous
4265 size_t nb1 = ((int32_t *) dst->op_params)[0];
4266 size_t nb2 = ((int32_t *) dst->op_params)[1];
4267 size_t nb3 = ((int32_t *) dst->op_params)[2];
4268 size_t offset = ((int32_t *) dst->op_params)[3];
4269 bool inplace = (bool) ((int32_t *) dst->op_params)[4];
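// i.e. src1 is written into a view of dst that starts at byte `offset` and uses strides
// {element_size, nb1, nb2, nb3}; the rest of dst keeps the contents of src0 (copied below
// unless the op is marked inplace)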
4270
4271 if (!inplace) {
4272 if (params->ith == 0) {
4273 // memcpy needs to be synchronized across threads to avoid race conditions.
4274 // => do it in INIT phase
4275 memcpy(
4276 dest: ((char *) dst->data),
4277 src: ((char *) src0->data),
4278 n: ggml_nbytes(tensor: dst));
4279 }
4280 ggml_barrier(tp: params->threadpool);
4281 }
4282
4283 const int ith = params->ith;
4284 const int nth = params->nth;
4285
4286 const int nr = ggml_nrows(tensor: src1);
4287 const int nc = src1->ne[0];
4288
4289 GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
4290 GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
4291
4292 // src0 and dst as viewed during set
4293 const size_t nb0 = ggml_element_size(tensor: src0);
4294
4295 const int im0 = (ne10 == 0 ? 0 : ne10-1);
4296 const int im1 = (ne11 == 0 ? 0 : ne11-1);
4297 const int im2 = (ne12 == 0 ? 0 : ne12-1);
4298 const int im3 = (ne13 == 0 ? 0 : ne13-1);
4299
4300 GGML_ASSERT(offset + im0*nb0 + im1*nb1 + im2*nb2 + im3*nb3 <= ggml_nbytes(dst));
4301
4302 GGML_ASSERT(nb10 == sizeof(float));
4303
4304 // rows per thread
4305 const int dr = (nr + nth - 1)/nth;
4306
4307 // row range for this thread
4308 const int ir0 = dr*ith;
4309 const int ir1 = MIN(ir0 + dr, nr);
4310
4311 for (int ir = ir0; ir < ir1; ++ir) {
4312 // src0 and dst are viewed with shape of src1 and offset
4313 // => same indices
4314 const int i3 = ir/(ne12*ne11);
4315 const int i2 = (ir - i3*ne12*ne11)/ne11;
4316 const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
4317
4318 ggml_vec_cpy_f32(n: nc,
4319 y: (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset),
4320 x: (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
4321 }
4322}
4323
4324static void ggml_compute_forward_set_i32(
4325 const ggml_compute_params * params,
4326 ggml_tensor * dst) {
4327
4328 const ggml_tensor * src0 = dst->src[0];
4329 const ggml_tensor * src1 = dst->src[1];
4330
4331 GGML_ASSERT(ggml_are_same_shape(src0, dst));
4332 GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
4333
4334 // view src0 and dst with these strides and data offset in bytes during set
4335 // nb0 is implicitly element_size because src0 and dst are contiguous
4336 size_t nb1 = ((int32_t *) dst->op_params)[0];
4337 size_t nb2 = ((int32_t *) dst->op_params)[1];
4338 size_t nb3 = ((int32_t *) dst->op_params)[2];
4339 size_t offset = ((int32_t *) dst->op_params)[3];
4340 bool inplace = (bool) ((int32_t *) dst->op_params)[4];
4341
4342 if (!inplace) {
4343 if (params->ith == 0) {
4344 // memcpy needs to be synchronized across threads to avoid race conditions.
4345 // => do it in INIT phase
4346 memcpy(
4347 dest: ((char *) dst->data),
4348 src: ((char *) src0->data),
4349 n: ggml_nbytes(tensor: dst));
4350 }
4351 ggml_barrier(tp: params->threadpool);
4352 }
4353
4354 const int ith = params->ith;
4355 const int nth = params->nth;
4356
4357 const int nr = ggml_nrows(tensor: src1);
4358 const int nc = src1->ne[0];
4359
4360 GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
4361 GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
4362
4363 // src0 and dst as viewed during set
4364 const size_t nb0 = ggml_element_size(tensor: src0);
4365
4366 const int im0 = (ne10 == 0 ? 0 : ne10-1);
4367 const int im1 = (ne11 == 0 ? 0 : ne11-1);
4368 const int im2 = (ne12 == 0 ? 0 : ne12-1);
4369 const int im3 = (ne13 == 0 ? 0 : ne13-1);
4370
4371 GGML_ASSERT(offset + im0*nb0 + im1*nb1 + im2*nb2 + im3*nb3 <= ggml_nbytes(dst));
4372
4373 GGML_ASSERT(nb10 == sizeof(int32_t));
4374
4375 // rows per thread
4376 const int dr = (nr + nth - 1)/nth;
4377
4378 // row range for this thread
4379 const int ir0 = dr*ith;
4380 const int ir1 = MIN(ir0 + dr, nr);
4381
4382 for (int ir = ir0; ir < ir1; ++ir) {
4383 // src0 and dst are viewed with shape of src1 and offset
4384 // => same indices
4385 const int i3 = ir/(ne12*ne11);
4386 const int i2 = (ir - i3*ne12*ne11)/ne11;
4387 const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
4388
4389 ggml_vec_cpy_i32(n: nc,
4390 y: (int32_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset),
4391 x: (int32_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
4392 }
4393}
4394
4395void ggml_compute_forward_set(
4396 const ggml_compute_params * params,
4397 ggml_tensor * dst) {
4398
4399 const ggml_tensor * src0 = dst->src[0];
4400
4401 switch (src0->type) {
4402 case GGML_TYPE_F32:
4403 {
4404 ggml_compute_forward_set_f32(params, dst);
4405 } break;
4406 case GGML_TYPE_I32:
4407 {
4408 ggml_compute_forward_set_i32(params, dst);
4409 } break;
4410 case GGML_TYPE_F16:
4411 case GGML_TYPE_BF16:
4412 case GGML_TYPE_Q4_0:
4413 case GGML_TYPE_Q4_1:
4414 case GGML_TYPE_Q5_0:
4415 case GGML_TYPE_Q5_1:
4416 case GGML_TYPE_Q8_0:
4417 case GGML_TYPE_Q8_1:
4418 case GGML_TYPE_MXFP4:
4419 case GGML_TYPE_Q2_K:
4420 case GGML_TYPE_Q3_K:
4421 case GGML_TYPE_Q4_K:
4422 case GGML_TYPE_Q5_K:
4423 case GGML_TYPE_Q6_K:
4424 case GGML_TYPE_TQ1_0:
4425 case GGML_TYPE_TQ2_0:
4426 case GGML_TYPE_IQ2_XXS:
4427 case GGML_TYPE_IQ2_XS:
4428 case GGML_TYPE_IQ3_XXS:
4429 case GGML_TYPE_IQ1_S:
4430 case GGML_TYPE_IQ1_M:
4431 case GGML_TYPE_IQ4_NL:
4432 case GGML_TYPE_IQ4_XS:
4433 case GGML_TYPE_IQ3_S:
4434 case GGML_TYPE_IQ2_S:
4435 default:
4436 {
4437 GGML_ABORT("fatal error");
4438 }
4439 }
4440}
4441
4442// ggml_compute_forward_cpy
4443
4444void ggml_compute_forward_cpy(
4445 const ggml_compute_params * params,
4446 ggml_tensor * dst) {
4447 ggml_compute_forward_dup(params, dst);
4448}
4449
4450// ggml_compute_forward_cont
4451
4452void ggml_compute_forward_cont(
4453 const ggml_compute_params * params,
4454 ggml_tensor * dst) {
4455 ggml_compute_forward_dup(params, dst);
4456}
4457
4458// ggml_compute_forward_reshape
4459
4460void ggml_compute_forward_reshape(
4461 const ggml_compute_params * params,
4462 ggml_tensor * dst) {
4463 // NOP
4464 GGML_UNUSED(params);
4465 GGML_UNUSED(dst);
4466}
4467
4468// ggml_compute_forward_view
4469
4470void ggml_compute_forward_view(
4471 const ggml_compute_params * params,
4472 ggml_tensor * dst) {
4473 // NOP
4474 GGML_UNUSED(params);
4475 GGML_UNUSED(dst);
4476}
4477
4478// ggml_compute_forward_permute
4479
4480void ggml_compute_forward_permute(
4481 const ggml_compute_params * params,
4482 ggml_tensor * dst) {
4483 // NOP
4484 GGML_UNUSED(params);
4485 GGML_UNUSED(dst);
4486}
4487
4488// ggml_compute_forward_transpose
4489
4490void ggml_compute_forward_transpose(
4491 const ggml_compute_params * params,
4492 ggml_tensor * dst) {
4493 // NOP
4494 GGML_UNUSED(params);
4495 GGML_UNUSED(dst);
4496}
4497
4498// ggml_compute_forward_get_rows
4499
4500static void ggml_compute_forward_get_rows_q(
4501 const ggml_compute_params * params,
4502 ggml_tensor * dst) {
4503
4504 const ggml_tensor * src0 = dst->src[0];
4505 const ggml_tensor * src1 = dst->src[1];
4506
4507 GGML_TENSOR_BINARY_OP_LOCALS
4508
4509 const int64_t nc = ne00;
4510 const int64_t nr = ggml_nelements(tensor: src1);
4511
4512 const ggml_type type = src0->type;
4513 ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
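// each element of src1 is a row index into src0: the selected quantized row is dequantized
// into the corresponding F32 row of dst (src1 may be up to 3-D; its higher dims select the
// matching src0/dst planes)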
4514
4515 assert(ne0 == nc);
4516 assert(ne02 == ne11);
4517 assert(nb00 == ggml_type_size(type));
4518 assert(ggml_nrows(dst) == nr);
4519
4520 const int ith = params->ith;
4521 const int nth = params->nth;
4522
4523 // rows per thread
4524 const int dr = (nr + nth - 1)/nth;
4525
4526 // row range for this thread
4527 const int ir0 = dr*ith;
4528 const int ir1 = MIN(ir0 + dr, nr);
4529
4530 for (int64_t i = ir0; i < ir1; ++i) {
4531 const int64_t i12 = i/(ne11*ne10);
4532 const int64_t i11 = (i - i12*ne11*ne10)/ne10;
4533 const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
4534 const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
4535
4536 GGML_ASSERT(i01 >= 0 && i01 < ne01);
4537
4538 dequantize_row_q(
4539 (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
4540 (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
4541 }
4542}
4543
4544static void ggml_compute_forward_get_rows_f16(
4545 const ggml_compute_params * params,
4546 ggml_tensor * dst) {
4547
4548 const ggml_tensor * src0 = dst->src[0];
4549 const ggml_tensor * src1 = dst->src[1];
4550
4551 GGML_TENSOR_BINARY_OP_LOCALS
4552
4553 const int64_t nc = ne00;
4554 const int64_t nr = ggml_nelements(tensor: src1);
4555
4556 assert(ne0 == nc);
4557 assert(ne02 == ne11);
4558 assert(nb00 == sizeof(ggml_fp16_t));
4559 assert(ggml_nrows(dst) == nr);
4560
4561 const int ith = params->ith;
4562 const int nth = params->nth;
4563
4564 // rows per thread
4565 const int dr = (nr + nth - 1)/nth;
4566
4567 // row range for this thread
4568 const int ir0 = dr*ith;
4569 const int ir1 = MIN(ir0 + dr, nr);
4570
4571 for (int64_t i = ir0; i < ir1; ++i) {
4572 const int64_t i12 = i/(ne11*ne10);
4573 const int64_t i11 = (i - i12*ne11*ne10)/ne10;
4574 const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
4575 const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
4576
4577 GGML_ASSERT(i01 >= 0 && i01 < ne01);
4578
4579 ggml_cpu_fp16_to_fp32(
4580 (const ggml_fp16_t*) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
4581 (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
4582 }
4583}
4584
4585static void ggml_compute_forward_get_rows_bf16(
4586 const ggml_compute_params * params,
4587 ggml_tensor * dst) {
4588
4589 const ggml_tensor * src0 = dst->src[0];
4590 const ggml_tensor * src1 = dst->src[1];
4591
4592 GGML_TENSOR_BINARY_OP_LOCALS
4593
4594 const int64_t nc = ne00;
4595 const int64_t nr = ggml_nelements(tensor: src1);
4596
4597 assert(ne0 == nc);
4598 assert(ne02 == ne11);
4599 assert(nb00 == sizeof(ggml_bf16_t));
4600 assert(ggml_nrows(dst) == nr);
4601
4602 const int ith = params->ith;
4603 const int nth = params->nth;
4604
4605 // rows per thread
4606 const int dr = (nr + nth - 1)/nth;
4607
4608 // row range for this thread
4609 const int ir0 = dr*ith;
4610 const int ir1 = MIN(ir0 + dr, nr);
4611
4612 for (int64_t i = ir0; i < ir1; ++i) {
4613 const int64_t i12 = i/(ne11*ne10);
4614 const int64_t i11 = (i - i12*ne11*ne10)/ne10;
4615 const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
4616 const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
4617
4618 GGML_ASSERT(i01 >= 0 && i01 < ne01);
4619
4620 ggml_cpu_bf16_to_fp32(
4621 (const ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
4622 (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
4623 }
4624}
4625
4626static void ggml_compute_forward_get_rows_f32(
4627 const ggml_compute_params * params,
4628 ggml_tensor * dst) {
4629
4630 const ggml_tensor * src0 = dst->src[0];
4631 const ggml_tensor * src1 = dst->src[1];
4632
4633 GGML_TENSOR_BINARY_OP_LOCALS
4634
4635 const int64_t nc = ne00;
4636 const int64_t nr = ggml_nelements(tensor: src1);
4637
4638 assert(ne0 == nc);
4639 assert(ne02 == ne11);
4640 assert(nb00 == sizeof(float));
4641 assert(ggml_nrows(dst) == nr);
4642
4643 const int ith = params->ith;
4644 const int nth = params->nth;
4645
4646 // rows per thread
4647 const int dr = (nr + nth - 1)/nth;
4648
4649 // row range for this thread
4650 const int ir0 = dr*ith;
4651 const int ir1 = MIN(ir0 + dr, nr);
4652
4653 for (int64_t i = ir0; i < ir1; ++i) {
4654 const int64_t i12 = i/(ne11*ne10);
4655 const int64_t i11 = (i - i12*ne11*ne10)/ne10;
4656 const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
4657 const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
4658
4659 GGML_ASSERT(i01 >= 0 && i01 < ne01);
4660
4661 ggml_vec_cpy_f32(n: nc,
4662 y: (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3),
4663 x: (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));
4664 }
4665}
4666
4667void ggml_compute_forward_get_rows(
4668 const ggml_compute_params * params,
4669 ggml_tensor * dst) {
4670
4671 const ggml_tensor * src0 = dst->src[0];
4672
4673 switch (src0->type) {
4674 case GGML_TYPE_Q4_0:
4675 case GGML_TYPE_Q4_1:
4676 case GGML_TYPE_Q5_0:
4677 case GGML_TYPE_Q5_1:
4678 case GGML_TYPE_Q8_0:
4679 case GGML_TYPE_Q8_1:
4680 case GGML_TYPE_MXFP4:
4681 case GGML_TYPE_Q2_K:
4682 case GGML_TYPE_Q3_K:
4683 case GGML_TYPE_Q4_K:
4684 case GGML_TYPE_Q5_K:
4685 case GGML_TYPE_Q6_K:
4686 case GGML_TYPE_TQ1_0:
4687 case GGML_TYPE_TQ2_0:
4688 case GGML_TYPE_IQ2_XXS:
4689 case GGML_TYPE_IQ2_XS:
4690 case GGML_TYPE_IQ3_XXS:
4691 case GGML_TYPE_IQ1_S:
4692 case GGML_TYPE_IQ1_M:
4693 case GGML_TYPE_IQ4_NL:
4694 case GGML_TYPE_IQ4_XS:
4695 case GGML_TYPE_IQ3_S:
4696 case GGML_TYPE_IQ2_S:
4697 {
4698 ggml_compute_forward_get_rows_q(params, dst);
4699 } break;
4700 case GGML_TYPE_F16:
4701 {
4702 ggml_compute_forward_get_rows_f16(params, dst);
4703 } break;
4704 case GGML_TYPE_BF16:
4705 {
4706 ggml_compute_forward_get_rows_bf16(params, dst);
4707 } break;
4708 case GGML_TYPE_F32:
4709 case GGML_TYPE_I32:
4710 {
4711 ggml_compute_forward_get_rows_f32(params, dst);
4712 } break;
4713 default:
4714 {
4715 GGML_ABORT("fatal error");
4716 }
4717 }
4718
4719 //static bool first = true;
4720 //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
4721 //if (first) {
4722 // first = false;
4723 //} else {
4724 // for (int k = 0; k < dst->ne[1]; ++k) {
4725 // for (int j = 0; j < dst->ne[0]/16; ++j) {
4726 // for (int i = 0; i < 16; ++i) {
4727 // printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
4728 // }
4729 // printf("\n");
4730 // }
4731 // printf("\n");
4732 // }
4733 // printf("\n");
4734 // exit(0);
4735 //}
4736}
4737
4738template<typename idx_t>
4739static void ggml_compute_forward_set_rows_f32(
4740 const ggml_compute_params * params,
4741 ggml_tensor * dst) {
4742
4743 const ggml_tensor * src0 = dst->src[0];
4744 const ggml_tensor * src1 = dst->src[1];
4745
4746 GGML_TENSOR_BINARY_OP_LOCALS
4747
4748 const int64_t nc = ne00;
4749 const int64_t nr = ne01;
4750
4751 assert(ne0 == nc);
4752 assert(ne2 == ne02);
4753 assert(ne3 == ne03);
4754 assert(src0->type == GGML_TYPE_F32);
4755 assert(ne02 % ne11 == 0);
4756 assert(ne03 % ne12 == 0);
4757
4758 const int ith = params->ith;
4759 const int nth = params->nth;
4760
4761 // rows per thread
4762 const int64_t dr = (nr + nth - 1)/nth;
4763
4764 // row range for this thread
4765 const int64_t ir0 = dr*ith;
4766 const int64_t ir1 = std::min(a: ir0 + dr, b: nr);
4767
4768 ggml_from_float_t const from_float = ggml_get_type_traits_cpu(type: dst->type)->from_float;
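// roughly the inverse of get_rows: F32 row i of src0 is converted to dst->type via from_float
// and stored at the dst row selected by src1 (src1 is broadcast along dims 2 and 3, see the
// modulo indexing below)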
4769
4770 for (int64_t i03 = 0; i03 < ne03; ++i03) {
4771 for (int64_t i02 = 0; i02 < ne02; ++i02) {
4772 for (int64_t i = ir0; i < ir1; ++i) {
4773 const int64_t i12 = i03%ne12;
4774 const int64_t i11 = i02%ne11;
4775 const int64_t i10 = i;
4776
4777 const int64_t i1 = *(idx_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
4778
4779 GGML_ASSERT(i1 >= 0 && i1 < ne1);
4780
4781 from_float(
4782 (const float *) ((char *) src0->data + i*nb01 + i02*nb02 + i03*nb03),
4783 ((char *) dst->data + i1*nb1 + i02*nb2 + i03*nb3), nc);
4784 }
4785 }
4786 }
4787}
4788
4789void ggml_compute_forward_set_rows(
4790 const ggml_compute_params * params,
4791 ggml_tensor * dst) {
4792
4793 const ggml_tensor * src0 = dst->src[0];
4794 const ggml_tensor * src1 = dst->src[1];
4795
4796 switch (src0->type) {
4797 case GGML_TYPE_F32:
4798 {
4799 if (src1->type == GGML_TYPE_I64) {
4800 ggml_compute_forward_set_rows_f32<int64_t>(params, dst);
4801 } else if (src1->type == GGML_TYPE_I32) {
4802 ggml_compute_forward_set_rows_f32<int32_t>(params, dst);
4803 } else {
4804 GGML_ABORT("src1->type = %d (%s) not supported", src1->type, ggml_type_name(src1->type));
4805 }
4806 } break;
4807 default:
4808 {
4809 GGML_ABORT("src0->type = %d (%s) not supported", src0->type, ggml_type_name(src0->type));
4810 }
4811 }
4812}
4813
4814// ggml_compute_forward_get_rows_back
4815
4816static void ggml_compute_forward_get_rows_back_f32_f16(
4817 const ggml_compute_params * params,
4818 ggml_tensor * dst) {
4819
4820 const ggml_tensor * src0 = dst->src[0];
4821 const ggml_tensor * src1 = dst->src[1];
4822
4823 if (params->ith != 0) {
4824 return;
4825 }
4826
4827 GGML_ASSERT(ggml_is_contiguous(dst));
4828
4829 // ggml_compute_forward_dup_same_cont(params, opt0, dst);
4830
4831 memset(s: dst->data, c: 0, n: ggml_nbytes(tensor: dst));
4832
4833 const int nc = src0->ne[0];
4834 const int nr = ggml_nelements(tensor: src1);
4835
4836 GGML_ASSERT( dst->ne[0] == nc);
4837 GGML_ASSERT(src0->nb[0] == sizeof(ggml_fp16_t));
4838
4839 for (int i = 0; i < nr; ++i) {
4840 const int r = ((int32_t *) src1->data)[i];
4841
4842 for (int j = 0; j < nc; ++j) {
4843 ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + i*src0->nb[1]))[j];
4844 ((float *) ((char *) dst->data + r*dst->nb[1]))[j] += GGML_CPU_FP16_TO_FP32(v);
4845 }
4846 }
4847}
4848
4849static void ggml_compute_forward_get_rows_back_f32(
4850 const ggml_compute_params * params,
4851 ggml_tensor * dst) {
4852
4853 const ggml_tensor * src0 = dst->src[0];
4854 const ggml_tensor * src1 = dst->src[1];
4855
4856 if (params->ith != 0) {
4857 return;
4858 }
4859
4860 GGML_ASSERT(ggml_is_contiguous(dst));
4861
4862 // ggml_compute_forward_dup_same_cont(params, opt0, dst);
4863
4864 memset(s: dst->data, c: 0, n: ggml_nbytes(tensor: dst));
4865
4866 const int nc = src0->ne[0];
4867 const int nr = ggml_nelements(tensor: src1);
4868
4869 GGML_ASSERT( dst->ne[0] == nc);
4870 GGML_ASSERT(src0->nb[0] == sizeof(float));
4871
4872 for (int i = 0; i < nr; ++i) {
4873 const int r = ((int32_t *) src1->data)[i];
4874
4875 ggml_vec_add_f32(n: nc,
4876 z: (float *) ((char *) dst->data + r*dst->nb[1]),
4877 x: (float *) ((char *) dst->data + r*dst->nb[1]),
4878 y: (float *) ((char *) src0->data + i*src0->nb[1]));
4879 }
4880}
4881
4882void ggml_compute_forward_get_rows_back(
4883 const ggml_compute_params * params,
4884 ggml_tensor * dst) {
4885
4886 const ggml_tensor * src0 = dst->src[0];
4887
4888 switch (src0->type) {
4889 case GGML_TYPE_F16:
4890 {
4891 ggml_compute_forward_get_rows_back_f32_f16(params, dst);
4892 } break;
4893 case GGML_TYPE_F32:
4894 {
4895 ggml_compute_forward_get_rows_back_f32(params, dst);
4896 } break;
4897 default:
4898 {
4899 GGML_ABORT("fatal error");
4900 }
4901 }
4902
4903 //static bool first = true;
4904 //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
4905 //if (first) {
4906 // first = false;
4907 //} else {
4908 // for (int k = 0; k < dst->ne[1]; ++k) {
4909 // for (int j = 0; j < dst->ne[0]/16; ++j) {
4910 // for (int i = 0; i < 16; ++i) {
4911 // printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
4912 // }
4913 // printf("\n");
4914 // }
4915 // printf("\n");
4916 // }
4917 // printf("\n");
4918 // exit(0);
4919 //}
4920}
4921
4922// ggml_compute_forward_diag
4923
4924static void ggml_compute_forward_diag_f32(
4925 const ggml_compute_params * params,
4926 ggml_tensor * dst) {
4927
4928 const ggml_tensor * src0 = dst->src[0];
4929
4930 if (params->ith != 0) {
4931 return;
4932 }
4933
4934 // TODO: handle transposed/permuted matrices
4935
4936 GGML_TENSOR_UNARY_OP_LOCALS
4937
4938 GGML_ASSERT(ne00 == ne0);
4939 GGML_ASSERT(ne00 == ne1);
4940 GGML_ASSERT(ne01 == 1);
4941 GGML_ASSERT(ne02 == ne2);
4942 GGML_ASSERT(ne03 == ne3);
4943
4944 GGML_ASSERT(nb00 == sizeof(float));
4945 GGML_ASSERT(nb0 == sizeof(float));
4946
4947 for (int i3 = 0; i3 < ne3; i3++) {
4948 for (int i2 = 0; i2 < ne2; i2++) {
4949 for (int i1 = 0; i1 < ne1; i1++) {
4950 float * d = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
4951 float * s = (float *)((char *) src0->data + i3*nb03 + i2*nb02);
4952 for (int i0 = 0; i0 < i1; i0++) {
4953 d[i0] = 0;
4954 }
4955 d[i1] = s[i1];
4956 for (int i0 = i1+1; i0 < ne0; i0++) {
4957 d[i0] = 0;
4958 }
4959 }
4960 }
4961 }
4962}
4963
4964void ggml_compute_forward_diag(
4965 const ggml_compute_params * params,
4966 ggml_tensor * dst) {
4967
4968 const ggml_tensor * src0 = dst->src[0];
4969
4970 switch (src0->type) {
4971 case GGML_TYPE_F32:
4972 {
4973 ggml_compute_forward_diag_f32(params, dst);
4974 } break;
4975 default:
4976 {
4977 GGML_ABORT("fatal error");
4978 }
4979 }
4980}
4981
4982// ggml_compute_forward_diag_mask_inf
4983
4984static void ggml_compute_forward_diag_mask_f32(
4985 const ggml_compute_params * params,
4986 ggml_tensor * dst,
4987 const float value) {
4988
4989 const ggml_tensor * src0 = dst->src[0];
4990
4991 const int ith = params->ith;
4992 const int nth = params->nth;
4993
4994 const int n_past = ((int32_t *) dst->op_params)[0];
4995 const bool inplace = src0->data == dst->data;
4996
4997 GGML_ASSERT(n_past >= 0);
4998
4999 if (!inplace) {
5000 if (ith == 0) {
5001 // memcpy needs to be synchronized across threads to avoid race conditions.
5002 // => do it in INIT phase
5003 GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
5004 GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
5005 memcpy(
5006 dest: ((char *) dst->data),
5007 src: ((char *) src0->data),
5008 n: ggml_nbytes(tensor: dst));
5009 }
5010 ggml_barrier(tp: params->threadpool);
5011 }
5012
5013 // TODO: handle transposed/permuted matrices
5014
5015 const int n = ggml_nrows(tensor: src0);
5016 const int nc = src0->ne[0];
5017 const int nr = src0->ne[1];
5018 const int nz = n/nr;
5019
5020 GGML_ASSERT( dst->nb[0] == sizeof(float));
5021 GGML_ASSERT(src0->nb[0] == sizeof(float));
5022
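// set every element above the diagonal shifted by n_past to `value`:
//   dst[k][j][i] = value for all i > n_past + j
// (with value == -INF this produces the usual causal attention mask)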
5023 for (int k = 0; k < nz; k++) {
5024 for (int j = ith; j < nr; j += nth) {
5025 for (int i = n_past; i < nc; i++) {
5026 if (i > n_past + j) {
5027 *(float *)((char *) dst->data + k*dst->nb[2] + j*dst->nb[1] + i*dst->nb[0]) = value;
5028 }
5029 }
5030 }
5031 }
5032}
5033
5034void ggml_compute_forward_diag_mask_inf(
5035 const ggml_compute_params * params,
5036 ggml_tensor * dst) {
5037
5038 const ggml_tensor * src0 = dst->src[0];
5039
5040 switch (src0->type) {
5041 case GGML_TYPE_F32:
5042 {
5043 ggml_compute_forward_diag_mask_f32(params, dst, value: -INFINITY);
5044 } break;
5045 default:
5046 {
5047 GGML_ABORT("fatal error");
5048 }
5049 }
5050}
5051
5052void ggml_compute_forward_diag_mask_zero(
5053 const ggml_compute_params * params,
5054 ggml_tensor * dst) {
5055
5056 const ggml_tensor * src0 = dst->src[0];
5057
5058 switch (src0->type) {
5059 case GGML_TYPE_F32:
5060 {
5061 ggml_compute_forward_diag_mask_f32(params, dst, value: 0);
5062 } break;
5063 default:
5064 {
5065 GGML_ABORT("fatal error");
5066 }
5067 }
5068}
5069
5070// ggml_compute_forward_soft_max
5071
5072static void ggml_compute_forward_soft_max_f32(
5073 const ggml_compute_params * params,
5074 ggml_tensor * dst) {
5075
5076 const ggml_tensor * src0 = dst->src[0];
5077 const ggml_tensor * src1 = dst->src[1];
5078 const ggml_tensor * src2 = dst->src[2];
5079
5080 assert(ggml_is_contiguous(dst));
5081 assert(ggml_are_same_shape(src0, dst));
5082
5083 float scale = 1.0f;
5084 float max_bias = 0.0f;
5085
5086 memcpy(dest: &scale, src: (float *) dst->op_params + 0, n: sizeof(float));
5087 memcpy(dest: &max_bias, src: (float *) dst->op_params + 1, n: sizeof(float));
5088
5089 const int ith = params->ith;
5090 const int nth = params->nth;
5091
5092 GGML_TENSOR_UNARY_OP_LOCALS
5093
5094 const int64_t nb11 = src1 ? src1->nb[1] : 1;
5095 const int64_t nb12 = src1 ? src1->nb[2] : 1;
5096 const int64_t nb13 = src1 ? src1->nb[3] : 1;
5097
5098 const int64_t ne12 = src1 ? src1->ne[2] : 1;
5099 const int64_t ne13 = src1 ? src1->ne[3] : 1;
5100
5101 // TODO: is this supposed to be ceil instead of floor?
5102 // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
5103 const uint32_t n_head = ne02;
5104 const uint32_t n_head_log2 = 1u << (uint32_t) floor(x: log2(x: n_head));
5105
5106 const float m0 = powf(x: 2.0f, y: -(max_bias ) / n_head_log2);
5107 const float m1 = powf(x: 2.0f, y: -(max_bias / 2.0f) / n_head_log2);
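// ALiBi slope for head h (computed per row below): m0^(h+1) for h < n_head_log2,
// m1^(2*(h - n_head_log2) + 1) otherwise; when max_bias == 0 the slope is 1 and the mask
// is added unscaled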
5108
5109 float * wp = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
5110
5111 const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
5112
5113 // sinks
5114 const float * sk = src2 ? (float *)((char *) src2->data) : nullptr;
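// when a sink is present, sk[i02] acts as an extra logit per head: it raises the running max
// and contributes expf(sk[i02] - max) to the normalizer, but is not written to dst, so the
// stored probabilities sum to slightly less than 1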
5115
5116 for (int64_t i03 = 0; i03 < ne03; i03++) {
5117 for (int64_t i02 = 0; i02 < ne02; i02++) {
5118 for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
5119 const int64_t i11 = i01;
5120 const int64_t i12 = i02%ne12;
5121 const int64_t i13 = i03%ne13;
5122
5123 // ALiBi
5124 const uint32_t h = i02; // head
5125 const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(x: m0, y: h + 1) : powf(x: m1, y: 2*(h - n_head_log2) + 1) : 1.0f;
5126
5127 float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
5128 float * dp = (float *)((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
5129
5130 // broadcast the mask across rows
5131 ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
5132 float * mp_f32 = src1 ? (float *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
5133
5134 ggml_vec_cpy_f32 (n: ne00, y: wp, x: sp);
5135 ggml_vec_scale_f32(n: ne00, y: wp, v: scale);
5136 if (mp_f32) {
5137 if (use_f16) {
5138 for (int i = 0; i < ne00; ++i) {
5139 wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
5140 }
5141 } else {
5142 for (int i = 0; i < ne00; ++i) {
5143 wp[i] += slope*mp_f32[i];
5144 }
5145 }
5146 }
5147
5148#ifndef NDEBUG
5149 for (int i = 0; i < ne00; ++i) {
5150 //printf("p[%d] = %f\n", i, p[i]);
5151 assert(!isnan(wp[i]));
5152 }
5153#endif
5154
5155 float max = -INFINITY;
5156 ggml_vec_max_f32(n: ne00, s: &max, x: wp);
5157
5158 // if we have sinks, make a correction as if they were included in the softmax
5159 if (sk) {
5160 max = MAX(max, sk[i02]);
5161 }
5162
5163 ggml_float sum = ggml_vec_soft_max_f32(n: ne00, y: dp, x: wp, max);
5164 assert(sum > 0.0);
5165
5166 if (sk) {
5167 sum += (ggml_float) expf(x: sk[i02] - max);
5168 }
5169
5170 sum = 1.0/sum;
5171 ggml_vec_scale_f32(n: ne00, y: dp, v: sum);
5172
5173#ifndef NDEBUG
5174 for (int i = 0; i < ne00; ++i) {
5175 assert(!isnan(dp[i]));
5176 assert(!isinf(dp[i]));
5177 }
5178#endif
5179 }
5180 }
5181 }
5182}
5183
5184void ggml_compute_forward_soft_max(
5185 const ggml_compute_params * params,
5186 ggml_tensor * dst) {
5187
5188 const ggml_tensor * src0 = dst->src[0];
5189
5190 switch (src0->type) {
5191 case GGML_TYPE_F32:
5192 {
5193 ggml_compute_forward_soft_max_f32(params, dst);
5194 } break;
5195 default:
5196 {
5197 GGML_ABORT("fatal error");
5198 }
5199 }
5200}
5201
5202
5203// ggml_compute_forward_soft_max_ext_back
5204
5205static void ggml_compute_forward_soft_max_ext_back_f32(
5206 const ggml_compute_params * params,
5207 ggml_tensor * dst) {
5208
5209 const ggml_tensor * src0 = dst->src[0];
5210 const ggml_tensor * src1 = dst->src[1];
5211
5212 GGML_ASSERT(ggml_is_contiguous(src0));
5213 GGML_ASSERT(ggml_is_contiguous(src1));
5214 GGML_ASSERT(ggml_is_contiguous(dst));
5215 GGML_ASSERT(ggml_are_same_shape(src0, dst));
5216 GGML_ASSERT(ggml_are_same_shape(src1, dst));
5217
5218 float scale = 1.0f;
5219 float max_bias = 0.0f;
5220
5221 memcpy(dest: &scale, src: (const float *) dst->op_params + 0, n: sizeof(float));
5222 memcpy(dest: &max_bias, src: (const float *) dst->op_params + 1, n: sizeof(float));
5223
5224 GGML_ASSERT(max_bias == 0.0f);
5225
5226 // TODO: handle transposed/permuted matrices
5227
5228 const int ith = params->ith;
5229 const int nth = params->nth;
5230
5231 const int nc = src0->ne[0];
5232 const int nr = ggml_nrows(tensor: src0);
5233
5234 // rows per thread
5235 const int dr = (nr + nth - 1)/nth;
5236
5237 // row range for this thread
5238 const int ir0 = dr*ith;
5239 const int ir1 = MIN(ir0 + dr, nr);
5240
5241 for (int i1 = ir0; i1 < ir1; i1++) {
5242 float *dy = (float *)((char *) src0->data + i1*src0->nb[1]);
5243 float *y = (float *)((char *) src1->data + i1*src1->nb[1]);
5244 float *dx = (float *)((char *) dst->data + i1*dst->nb[1]);
5245
5246#ifndef NDEBUG
5247 for (int i = 0; i < nc; ++i) {
5248 //printf("p[%d] = %f\n", i, p[i]);
5249 assert(!isnan(dy[i]));
5250 assert(!isnan(y[i]));
5251 }
5252#endif
5253 // Jii = yi - yi*yi
5254 // Jij = -yi*yj
5255 // J = diag(y)-y.T*y
5256 // dx = J * dy
5257 // dxk = sum_i(Jki * dyi)
5258 // dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk
5259 // dxk = sum_i(-yk*yi * dyi) + yk*yk*dyk + yk*dyk - yk*yk*dyk
5260 // dxk = sum_i(-yk*yi * dyi) + yk*dyk
5261 // dxk = -yk * sum_i(yi * dyi) + yk*dyk
5262 // dxk = -yk * dot(y, dy) + yk*dyk
5263 // dxk = yk * (- dot(y, dy) + dyk)
5264 // dxk = yk * (dyk - dot(y, dy))
5265 //
5266 // post-order:
5267 // dot_y_dy := dot(y, dy)
5268 // dx := dy
5269 // dx := dx - dot_y_dy
5270 // dx := dx * y
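// dx := dx * scale  (the forward pass applied softmax to scale*x (+ mask), so the
//                    chain rule contributes an extra factor of scale)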
5271
5272 // linear runtime, no additional memory
5273 float dot_y_dy = 0;
5274 ggml_vec_dot_f32 (n: nc, s: &dot_y_dy, bs: 0, x: y, bx: 0, y: dy, by: 0, nrc: 1);
5275 ggml_vec_cpy_f32 (n: nc, y: dx, x: dy);
5276 ggml_vec_acc1_f32 (n: nc, y: dx, v: -dot_y_dy);
5277 ggml_vec_mul_f32 (n: nc, z: dx, x: dx, y);
5278 ggml_vec_scale_f32(n: nc, y: dx, v: scale);
5279
5280#ifndef NDEBUG
5281 for (int i = 0; i < nc; ++i) {
5282 assert(!isnan(dx[i]));
5283 assert(!isinf(dx[i]));
5284 }
5285#endif
5286 }
5287}
5288
5289void ggml_compute_forward_soft_max_ext_back(
5290 const ggml_compute_params * params,
5291 ggml_tensor * dst) {
5292
5293 const ggml_tensor * src0 = dst->src[0];
5294
5295 switch (src0->type) {
5296 case GGML_TYPE_F32:
5297 {
5298 ggml_compute_forward_soft_max_ext_back_f32(params, dst);
5299 } break;
5300 default:
5301 {
5302 GGML_ABORT("fatal error");
5303 }
5304 }
5305}
5306
5307// ggml_compute_forward_clamp
5308
5309static void ggml_compute_forward_clamp_f32(
5310 const ggml_compute_params * params,
5311 ggml_tensor * dst) {
5312
5313 const ggml_tensor * src0 = dst->src[0];
5314
5315 float min;
5316 float max;
5317 memcpy(dest: &min, src: (float *) dst->op_params + 0, n: sizeof(float));
5318 memcpy(dest: &max, src: (float *) dst->op_params + 1, n: sizeof(float));
5319
5320 const int ith = params->ith;
5321 const int nth = params->nth;
5322
5323 const int n = ggml_nrows(tensor: src0);
5324 const int nc = src0->ne[0];
5325
5326 const size_t nb00 = src0->nb[0];
5327 const size_t nb01 = src0->nb[1];
5328
5329 const size_t nb0 = dst->nb[0];
5330 const size_t nb1 = dst->nb[1];
5331
5332 GGML_ASSERT( nb0 == sizeof(float));
5333 GGML_ASSERT(nb00 == sizeof(float));
5334
5335 for (int j = ith; j < n; j += nth) {
5336 float * dst_ptr = (float *) ((char *) dst->data + j*nb1);
5337 float * src0_ptr = (float *) ((char *) src0->data + j*nb01);
5338
5339 for (int i = 0; i < nc; i++) {
5340 dst_ptr[i] = MAX(MIN(src0_ptr[i], max), min);
5341 }
5342 }
5343}
5344
5345static void ggml_compute_forward_clamp_f16(
5346 const ggml_compute_params * params,
5347 ggml_tensor * dst) {
5348
5349 const ggml_tensor * src0 = dst->src[0];
5350
5351 float min;
5352 float max;
5353 memcpy(dest: &min, src: (float *) dst->op_params + 0, n: sizeof(float));
5354 memcpy(dest: &max, src: (float *) dst->op_params + 1, n: sizeof(float));
5355
5356 const int ith = params->ith;
5357 const int nth = params->nth;
5358
5359 const int n = ggml_nrows(tensor: src0);
5360 const int nc = src0->ne[0];
5361
5362 const size_t nb00 = src0->nb[0];
5363 const size_t nb01 = src0->nb[1];
5364
5365 const size_t nb0 = dst->nb[0];
5366 const size_t nb1 = dst->nb[1];
5367
5368 GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
5369 GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
5370
5371 for (int j = ith; j < n; j += nth) {
5372 ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1);
5373 ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
5374
5375 for (int i = 0; i < nc; i++) {
5376 float v = GGML_CPU_FP16_TO_FP32(src0_ptr[i]);
5377 dst_ptr[i] = GGML_CPU_FP32_TO_FP16(MAX(MIN(v, max), min));
5378 }
5379 }
5380}
5381
5382void ggml_compute_forward_clamp(
5383 const ggml_compute_params * params,
5384 ggml_tensor * dst) {
5385
5386 const ggml_tensor * src0 = dst->src[0];
5387
5388 switch (src0->type) {
5389 case GGML_TYPE_F32:
5390 {
5391 ggml_compute_forward_clamp_f32(params, dst);
5392 } break;
5393 case GGML_TYPE_F16:
5394 {
5395 ggml_compute_forward_clamp_f16(params, dst);
5396 } break;
5397 case GGML_TYPE_BF16:
5398 case GGML_TYPE_Q4_0:
5399 case GGML_TYPE_Q4_1:
5400 case GGML_TYPE_Q5_0:
5401 case GGML_TYPE_Q5_1:
5402 case GGML_TYPE_Q8_0:
5403 case GGML_TYPE_Q8_1:
5404 case GGML_TYPE_MXFP4:
5405 case GGML_TYPE_Q2_K:
5406 case GGML_TYPE_Q3_K:
5407 case GGML_TYPE_Q4_K:
5408 case GGML_TYPE_Q5_K:
5409 case GGML_TYPE_Q6_K:
5410 case GGML_TYPE_TQ1_0:
5411 case GGML_TYPE_TQ2_0:
5412 case GGML_TYPE_IQ2_XXS:
5413 case GGML_TYPE_IQ2_XS:
5414 case GGML_TYPE_IQ3_XXS:
5415 case GGML_TYPE_IQ1_S:
5416 case GGML_TYPE_IQ1_M:
5417 case GGML_TYPE_IQ4_NL:
5418 case GGML_TYPE_IQ4_XS:
5419 case GGML_TYPE_IQ3_S:
5420 case GGML_TYPE_IQ2_S:
5421 case GGML_TYPE_Q8_K:
5422 case GGML_TYPE_I8:
5423 case GGML_TYPE_I16:
5424 case GGML_TYPE_I32:
5425 case GGML_TYPE_I64:
5426 case GGML_TYPE_F64:
5427 case GGML_TYPE_COUNT:
5428 {
5429 GGML_ABORT("fatal error");
5430 }
5431 }
5432}
5433
5434// ggml_compute_forward_rope
5435
5436static float rope_yarn_ramp(const float low, const float high, const int i0) {
5437 const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
5438 return 1 - MIN(1, MAX(0, y));
5439}
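// the ramp is 1 for dimension pairs with index i0/2 <= low and 0 for i0/2 >= high, linear in
// between; in rope_yarn(), ramp_mix = ramp * ext_factor blends theta_interp (ramp_mix == 0)
// toward theta_extrap (ramp_mix == 1)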
5440
5441// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
5442// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
5443static void rope_yarn(
5444 float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
5445 float * cos_theta, float * sin_theta) {
5446 // Get n-d rotational scaling corrected for extrapolation
5447 float theta_interp = freq_scale * theta_extrap;
5448 float theta = theta_interp;
5449 if (ext_factor != 0.0f) {
5450 float ramp_mix = rope_yarn_ramp(low: corr_dims[0], high: corr_dims[1], i0) * ext_factor;
5451 theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
5452
5453 // Get n-d magnitude scaling corrected for interpolation
5454 mscale *= 1.0f + 0.1f * logf(x: 1.0f / freq_scale);
5455 }
5456 *cos_theta = cosf(x: theta) * mscale;
5457 *sin_theta = sinf(x: theta) * mscale;
5458}
5459
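// the caches filled below store one (cos, sin) pair per rotated dimension pair:
//   cache[2*k + 0] = cos(theta_k) * mscale, cache[2*k + 1] = sin(theta_k) * mscale * sin_sign
// and the rope kernels then apply the 2-D rotation (x0, x1) -> (x0*cos - x1*sin, x0*sin + x1*cos)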
5460static void ggml_rope_cache_init(
5461 float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
5462 float * cache, float sin_sign, float theta_scale) {
5463 // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
5464 float theta = theta_base;
5465 for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
5466 const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
5467 rope_yarn(
5468 theta_extrap: theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, cos_theta: &cache[i0 + 0], sin_theta: &cache[i0 + 1]
5469 );
5470 cache[i0 + 1] *= sin_sign;
5471
5472 theta *= theta_scale;
5473 }
5474}
5475
5476static void ggml_mrope_cache_init(
5477 float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool is_imrope, bool indep_sects,
5478 float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
5479 float * cache, float sin_sign, float theta_scale) {
5480 // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
5481 float theta_t = theta_base_t;
5482 float theta_h = theta_base_h;
5483 float theta_w = theta_base_w;
5484 float theta_e = theta_base_e; // extra position id for vision encoder
5485 int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
5486 int sec_w = sections[1] + sections[0];
5487 int sec_e = sections[2] + sec_w;
5488 GGML_ASSERT(sect_dims <= ne0);
5489
5490 for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
5491 const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
5492
5493 int sector = (i0 / 2) % sect_dims;
5494 if (indep_sects) {
5495 // compute theta independently for each dim section
5496 // (i.e. reset the corresponding theta when `i0` goes from one section to another)
5497 if (sector == 0) {
5498 theta_t = theta_base_t;
5499 }
5500 else if (sector == sections[0]) {
5501 theta_h = theta_base_h;
5502 }
5503 else if (sector == sec_w) {
5504 theta_w = theta_base_w;
5505 }
5506 else if (sector == sec_e) {
5507 theta_e = theta_base_e;
5508 }
5509 }
5510
5511 float theta = theta_t;
5512 if (is_imrope) { // qwen3vl applies interleaved mrope
5513 if (sector % 3 == 1 && sector < 3 * sections[1]) {
5514 theta = theta_h;
5515 } else if (sector % 3 == 2 && sector < 3 * sections[2]) {
5516 theta = theta_w;
5517 } else if (sector % 3 == 0 && sector < 3 * sections[0]) {
5518 theta = theta_t;
5519 } else {
5520 theta = theta_e;
5521 }
5522 } else {
5523 if (sector >= sections[0] && sector < sec_w) {
5524 theta = theta_h;
5525 }
5526 else if (sector >= sec_w && sector < sec_w + sections[2]) {
5527 theta = theta_w;
5528 }
5529 else if (sector >= sec_w + sections[2]) {
5530 theta = theta_e;
5531 }
5532 }
5533
5534 rope_yarn(
5535 theta_extrap: theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, cos_theta: &cache[i0 + 0], sin_theta: &cache[i0 + 1]
5536 );
5537 cache[i0 + 1] *= sin_sign;
5538
5539 theta_t *= theta_scale;
5540 theta_w *= theta_scale;
5541 theta_h *= theta_scale;
5542 theta_e *= theta_scale;
5543 }
5544}
5545
5546static void ggml_compute_forward_rope_f32(
5547 const ggml_compute_params * params,
5548 ggml_tensor * dst,
5549 const bool forward) {
5550
5551 const ggml_tensor * src0 = dst->src[0];
5552 const ggml_tensor * src1 = dst->src[1];
5553 const ggml_tensor * src2 = dst->src[2];
5554
5555 float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
5556 int sections[4];
5557
5558 //const int n_past = ((int32_t *) dst->op_params)[0];
5559 const int n_dims = ((int32_t *) dst->op_params)[1];
5560 const int mode = ((int32_t *) dst->op_params)[2];
5561 //const int n_ctx = ((int32_t *) dst->op_params)[3];
5562 const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
5563
5564 memcpy(dest: &freq_base, src: (int32_t *) dst->op_params + 5, n: sizeof(float));
5565 memcpy(dest: &freq_scale, src: (int32_t *) dst->op_params + 6, n: sizeof(float));
5566 memcpy(dest: &ext_factor, src: (int32_t *) dst->op_params + 7, n: sizeof(float));
5567 memcpy(dest: &attn_factor, src: (int32_t *) dst->op_params + 8, n: sizeof(float));
5568 memcpy(dest: &beta_fast, src: (int32_t *) dst->op_params + 9, n: sizeof(float));
5569 memcpy(dest: &beta_slow, src: (int32_t *) dst->op_params + 10, n: sizeof(float));
5570 memcpy(dest: &sections, src: (int32_t *) dst->op_params + 11, n: sizeof(int)*4);
5571
5572 GGML_TENSOR_UNARY_OP_LOCALS
5573
5574 //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
5575 //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
5576
5577 GGML_ASSERT(nb00 == sizeof(float));
5578
5579 const int ith = params->ith;
5580 const int nth = params->nth;
5581
5582 const int nr = ggml_nrows(tensor: dst);
5583
5584 GGML_ASSERT(n_dims <= ne0);
5585 GGML_ASSERT(n_dims % 2 == 0);
5586
5587 // rows per thread
5588 const int dr = (nr + nth - 1)/nth;
5589
5590 // row range for this thread
5591 const int ir0 = dr*ith;
5592 const int ir1 = MIN(ir0 + dr, nr);
5593
5594 // running row counter, used to skip rows not assigned to this thread
5595 int ir = 0;
5596
5597 const float theta_scale = powf(x: freq_base, y: -2.0f/n_dims);
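// standard RoPE frequencies: dimension pair k is rotated by theta_k = pos * freq_base^(-2k/n_dims)
// (optionally divided by freq_factors[k]); the cache builds this by multiplying theta by
// theta_scale once per pair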
5598
5599 float corr_dims[2];
5600 ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, dims: corr_dims);
5601
5602 const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
5603 const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, multimodal rotary position embedding
5604 const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl applies interleaved mrope
5605 const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
5606
5607 if (is_mrope) {
5608 GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
5609 }
5610
5611 if (is_vision) {
5612 GGML_ASSERT(n_dims == ne0/2);
5613 }
5614
5615 const float * freq_factors = NULL;
5616 if (src2 != NULL) {
5617 GGML_ASSERT(src2->type == GGML_TYPE_F32);
5618 GGML_ASSERT(src2->ne[0] >= n_dims / 2);
5619 freq_factors = (const float *) src2->data;
5620 }
5621
5622 // backward process uses inverse rotation by cos and sin.
5623 // cos and sin build a rotation matrix, where the inverse is the transpose.
5624 // this essentially just switches the sign of sin.
5625 const float sin_sign = forward ? 1.0f : -1.0f;
5626
5627 const int32_t * pos = (const int32_t *) src1->data;
5628
5629 for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
5630 for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
5631
5632 float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
5633 if (!is_mrope) {
5634 const int64_t p = pos[i2];
5635 ggml_rope_cache_init(theta_base: p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, mscale: attn_factor, cache, sin_sign, theta_scale);
5636 }
5637 else {
5638 const int64_t p_t = pos[i2];
5639 const int64_t p_h = pos[i2 + ne2];
5640 const int64_t p_w = pos[i2 + ne2 * 2];
5641 const int64_t p_e = pos[i2 + ne2 * 3];
5642 ggml_mrope_cache_init(
5643 theta_base_t: p_t, theta_base_h: p_h, theta_base_w: p_w, theta_base_e: p_e, sections, is_imrope, indep_sects: is_vision,
5644 freq_scale, freq_factors, corr_dims, ne0, ext_factor, mscale: attn_factor, cache, sin_sign, theta_scale);
5645 }
5646
5647 for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
5648 if (ir++ < ir0) continue;
5649 if (ir > ir1) break;
5650
5651 if (is_neox || is_mrope) {
5652 if (is_vision){
5653 for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
5654 const int64_t ic = i0/2;
5655
5656 const float cos_theta = cache[i0 + 0];
5657 const float sin_theta = cache[i0 + 1];
5658
5659 const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5660 float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
5661
5662 const float x0 = src[0];
5663 const float x1 = src[n_dims];
5664
5665 dst_data[0] = x0*cos_theta - x1*sin_theta;
5666 dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
5667 }
5668 } else {
5669 for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
5670 const int64_t ic = i0/2;
5671
5672 const float cos_theta = cache[i0 + 0];
5673 const float sin_theta = cache[i0 + 1];
5674
5675 const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5676 float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
5677
5678 const float x0 = src[0];
5679 const float x1 = src[n_dims/2];
5680
5681 dst_data[0] = x0*cos_theta - x1*sin_theta;
5682 dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
5683 }
5684 }
5685 } else {
5686 for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
5687 const float cos_theta = cache[i0 + 0];
5688 const float sin_theta = cache[i0 + 1];
5689
5690 const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
5691 float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
5692
5693 const float x0 = src[0];
5694 const float x1 = src[1];
5695
5696 dst_data[0] = x0*cos_theta - x1*sin_theta;
5697 dst_data[1] = x0*sin_theta + x1*cos_theta;
5698 }
5699 }
5700
5701 if (is_vision) {
5702 for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
5703 const int64_t ic = i0/2;
5704
5705 const float cos_theta = cache[i0 + 0];
5706 const float sin_theta = cache[i0 + 1];
5707
5708 const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5709 float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
5710
5711 const float x0 = src[0];
5712 const float x1 = src[n_dims];
5713
5714 dst_data[0] = x0*cos_theta - x1*sin_theta;
5715 dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
5716 }
5717 } else {
5718 // fill the remaining channels with data from the src tensor
5719 for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
5720 const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
5721 float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
5722
5723 dst_data[0] = src[0];
5724 dst_data[1] = src[1];
5725 }
5726 }
5727 }
5728 }
5729 }
5730}
5731
5732// TODO: deduplicate f16/f32 code
5733static void ggml_compute_forward_rope_f16(
5734 const ggml_compute_params * params,
5735 ggml_tensor * dst,
5736 const bool forward) {
5737
5738 const ggml_tensor * src0 = dst->src[0];
5739 const ggml_tensor * src1 = dst->src[1];
5740 const ggml_tensor * src2 = dst->src[2];
5741
5742 float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
5743 int sections[4];
5744
5745 //const int n_past = ((int32_t *) dst->op_params)[0];
5746 const int n_dims = ((int32_t *) dst->op_params)[1];
5747 const int mode = ((int32_t *) dst->op_params)[2];
5748 //const int n_ctx = ((int32_t *) dst->op_params)[3];
5749 const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
5750 memcpy(dest: &freq_base, src: (int32_t *) dst->op_params + 5, n: sizeof(float));
5751 memcpy(dest: &freq_scale, src: (int32_t *) dst->op_params + 6, n: sizeof(float));
5752 memcpy(dest: &ext_factor, src: (int32_t *) dst->op_params + 7, n: sizeof(float));
5753 memcpy(dest: &attn_factor, src: (int32_t *) dst->op_params + 8, n: sizeof(float));
5754 memcpy(dest: &beta_fast, src: (int32_t *) dst->op_params + 9, n: sizeof(float));
5755 memcpy(dest: &beta_slow, src: (int32_t *) dst->op_params + 10, n: sizeof(float));
5756 memcpy(dest: &sections, src: (int32_t *) dst->op_params + 11, n: sizeof(int)*4);
5757
5758
5759 GGML_TENSOR_UNARY_OP_LOCALS
5760
5761 //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
5762 //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
5763
5764 GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
5765
5766 const int ith = params->ith;
5767 const int nth = params->nth;
5768
5769 const int nr = ggml_nrows(tensor: dst);
5770
5771 GGML_ASSERT(n_dims <= ne0);
5772 GGML_ASSERT(n_dims % 2 == 0);
5773
5774 // rows per thread
5775 const int dr = (nr + nth - 1)/nth;
5776
5777 // row range for this thread
5778 const int ir0 = dr*ith;
5779 const int ir1 = MIN(ir0 + dr, nr);
5780
5781 // running row counter, used to skip rows not assigned to this thread
5782 int ir = 0;
5783
5784 const float theta_scale = powf(x: freq_base, y: -2.0f/n_dims);
5785
5786 float corr_dims[2];
5787 ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, dims: corr_dims);
5788
5789 const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
5790 const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
5791 const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
5792 const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
5793
5794 if (is_mrope) {
5795 GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
5796 }
5797
5798 if (is_vision) {
5799 GGML_ASSERT(n_dims == ne0/2);
5800 }
5801
5802 const float * freq_factors = NULL;
5803 if (src2 != NULL) {
5804 GGML_ASSERT(src2->type == GGML_TYPE_F32);
5805 GGML_ASSERT(src2->ne[0] >= n_dims / 2);
5806 freq_factors = (const float *) src2->data;
5807 }
5808
5809 // backward process uses inverse rotation by cos and sin.
5810 // cos and sin build a rotation matrix, where the inverse is the transpose.
5811 // this essentially just switches the sign of sin.
5812 const float sin_sign = forward ? 1.0f : -1.0f;
5813
5814 const int32_t * pos = (const int32_t *) src1->data;
5815
5816 for (int64_t i3 = 0; i3 < ne3; i3++) {
5817 for (int64_t i2 = 0; i2 < ne2; i2++) {
5818
            float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
            if (!is_mrope) {
                const int64_t p = pos[i2];
                ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
            }
            else {
                const int64_t p_t = pos[i2];
                const int64_t p_h = pos[i2 + ne2];
                const int64_t p_w = pos[i2 + ne2 * 2];
                const int64_t p_e = pos[i2 + ne2 * 3];
                ggml_mrope_cache_init(
                    p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
                    freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
            }
5833
5834 for (int64_t i1 = 0; i1 < ne1; i1++) {
5835 if (ir++ < ir0) continue;
5836 if (ir > ir1) break;
5837
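                // the NEOX/MROPE variants rotate element pairs that are n_dims/2 apart
                // (n_dims apart for the vision variant), while the default path below
                // rotates adjacent elements (i0, i0 + 1)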
5838 if (is_neox || is_mrope) {
5839 if (is_vision) {
5840 for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
5841 const int64_t ic = i0/2;
5842
5843 const float cos_theta = cache[i0 + 0];
5844 const float sin_theta = cache[i0 + 1];
5845
5846 const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5847 ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
5848
5849 const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
5850 const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims]);
5851
5852 dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
5853 dst_data[n_dims] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
5854 }
5855 } else {
5856 for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
5857 const int64_t ic = i0/2;
5858
5859 const float cos_theta = cache[i0 + 0];
5860 const float sin_theta = cache[i0 + 1];
5861
5862 const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5863 ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
5864
5865 const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
5866 const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims/2]);
5867
5868 dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
5869 dst_data[n_dims/2] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
5870 }
5871 }
5872 } else {
5873 for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
5874 const float cos_theta = cache[i0 + 0];
5875 const float sin_theta = cache[i0 + 1];
5876
5877 const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
5878 ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
5879
5880 const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
5881 const float x1 = GGML_CPU_FP16_TO_FP32(src[1]);
5882
5883 dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
5884 dst_data[1] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
5885 }
5886 }
5887
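                // columns beyond n_dims are rotated as a second half for the vision
                // variant and copied through unchanged otherwise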
5888 if (is_vision) {
5889 for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
5890 const int64_t ic = i0/2;
5891
5892 const float cos_theta = cache[i0 + 0];
5893 const float sin_theta = cache[i0 + 1];
5894
5895 const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5896 ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
5897
5898 const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
5899 const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims]);
5900
5901 dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
5902 dst_data[n_dims] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
5903 }
5904 } else {
5905 for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
5906 const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
5907 ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
5908
5909 dst_data[0] = src[0];
5910 dst_data[1] = src[1];
5911 }
5912 }
5913 }
5914 }
5915 }
5916}
5917
5918void ggml_compute_forward_rope(
5919 const ggml_compute_params * params,
5920 ggml_tensor * dst) {
5921
5922 const ggml_tensor * src0 = dst->src[0];
5923
5924 switch (src0->type) {
        case GGML_TYPE_F16:
            {
                ggml_compute_forward_rope_f16(params, dst, true);
            } break;
        case GGML_TYPE_F32:
            {
                ggml_compute_forward_rope_f32(params, dst, true);
            } break;
5933 default:
5934 {
5935 GGML_ABORT("fatal error");
5936 }
5937 }
5938}
5939
5940// ggml_compute_forward_rope_back
5941
5942void ggml_compute_forward_rope_back(
5943 const ggml_compute_params * params,
5944 ggml_tensor * dst) {
5945
5946 const ggml_tensor * src0 = dst->src[0];
5947
5948 switch (src0->type) {
        case GGML_TYPE_F16:
            {
                ggml_compute_forward_rope_f16(params, dst, false);
            } break;
        case GGML_TYPE_F32:
            {
                ggml_compute_forward_rope_f32(params, dst, false);
            } break;
5957 default:
5958 {
5959 GGML_ABORT("fatal error");
5960 }
5961 }
5962}
5963
5964// ggml_compute_forward_conv_transpose_1d
5965
5966static void ggml_compute_forward_conv_transpose_1d_f16_f32(
5967 const ggml_compute_params * params,
5968 ggml_tensor * dst) {
5969
5970 const ggml_tensor * src0 = dst->src[0];
5971 const ggml_tensor * src1 = dst->src[1];
5972
5973 GGML_ASSERT(src0->type == GGML_TYPE_F16);
5974 GGML_ASSERT(src1->type == GGML_TYPE_F32);
5975 GGML_ASSERT( dst->type == GGML_TYPE_F32);
5976
5977 GGML_TENSOR_BINARY_OP_LOCALS
5978
5979 const int ith = params->ith;
5980 const int nth = params->nth;
5981
5982 const int nk = ne00*ne01*ne02;
5983
5984 GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
5985 GGML_ASSERT(nb10 == sizeof(float));
5986
5987 if (ith == 0) {
        memset(params->wdata, 0, params->wsize);
5989
5990 // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
5991 {
5992 ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
5993
5994 for (int64_t i02 = 0; i02 < ne02; i02++) {
5995 for (int64_t i01 = 0; i01 < ne01; i01++) {
5996 const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
5997 ggml_fp16_t * dst_data = wdata + i01*ne00*ne02;
5998 for (int64_t i00 = 0; i00 < ne00; i00++) {
5999 dst_data[i00*ne02 + i02] = src[i00];
6000 }
6001 }
6002 }
6003 }
6004
6005 // permute source data (src1) from (L x Cin) to (Cin x L)
6006 {
6007 ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
6008 ggml_fp16_t * dst_data = wdata;
6009
6010 for (int64_t i11 = 0; i11 < ne11; i11++) {
6011 const float * const src = (float *)((char *) src1->data + i11*nb11);
6012 for (int64_t i10 = 0; i10 < ne10; i10++) {
6013 dst_data[i10*ne11 + i11] = GGML_CPU_FP32_TO_FP16(src[i10]);
6014 }
6015 }
6016 }
6017
6018 // need to zero dst since we are accumulating into it
        memset(dst->data, 0, ggml_nbytes(dst));
    }
    ggml_barrier(params->threadpool);
6022
6023 const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
6024
6025 // total rows in dst
6026 const int nr = ne1;
6027
6028 // rows per thread
6029 const int dr = (nr + nth - 1)/nth;
6030
6031 // row range for this thread
6032 const int ir0 = dr*ith;
6033 const int ir1 = MIN(ir0 + dr, nr);
6034
6035 ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
6036 ggml_fp16_t * const wdata_src = wdata + nk;
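    // each thread computes a range of output channels; for every input position and
    // kernel tap, a dot product over the input channels (ne02) is accumulated into
    // dst at an offset scaled by the stride s0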
6037
6038 for (int i1 = ir0; i1 < ir1; i1++) {
6039 float * dst_data = (float *)((char *) dst->data + i1*nb1);
6040 ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00;
6041 for (int i10 = 0; i10 < ne10; i10++) {
6042 const int i1n = i10*ne11;
6043 for (int i00 = 0; i00 < ne00; i00++) {
6044 float v = 0;
                ggml_vec_dot_f16(ne02, &v, 0,
                        (ggml_fp16_t *) wdata_src + i1n, 0,
                        (ggml_fp16_t *) wdata_kernel + i00*ne02, 0, 1);
6048 dst_data[i10*s0 + i00] += v;
6049 }
6050 }
6051 }
6052}
6053
6054static void ggml_compute_forward_conv_transpose_1d_f32(
6055 const ggml_compute_params * params,
6056 ggml_tensor * dst) {
6057
6058 const ggml_tensor * src0 = dst->src[0];
6059 const ggml_tensor * src1 = dst->src[1];
6060
6061 GGML_ASSERT(src0->type == GGML_TYPE_F32);
6062 GGML_ASSERT(src1->type == GGML_TYPE_F32);
6063 GGML_ASSERT( dst->type == GGML_TYPE_F32);
6064
6065 GGML_TENSOR_BINARY_OP_LOCALS
6066
6067 const int ith = params->ith;
6068 const int nth = params->nth;
6069
6070 const int nk = ne00*ne01*ne02;
6071
6072 GGML_ASSERT(nb00 == sizeof(float));
6073 GGML_ASSERT(nb10 == sizeof(float));
6074
6075 if (ith == 0) {
        memset(params->wdata, 0, params->wsize);
6077
6078 // prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
6079 {
6080 float * const wdata = (float *) params->wdata + 0;
6081
6082 for (int64_t i02 = 0; i02 < ne02; i02++) {
6083 for (int64_t i01 = 0; i01 < ne01; i01++) {
6084 const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
6085 float * dst_data = wdata + i01*ne00*ne02;
6086 for (int64_t i00 = 0; i00 < ne00; i00++) {
6087 dst_data[i00*ne02 + i02] = src[i00];
6088 }
6089 }
6090 }
6091 }
6092
6093 // prepare source data (src1)
6094 {
6095 float * const wdata = (float *) params->wdata + nk;
6096 float * dst_data = wdata;
6097
6098 for (int64_t i11 = 0; i11 < ne11; i11++) {
6099 const float * const src = (float *)((char *) src1->data + i11*nb11);
6100 for (int64_t i10 = 0; i10 < ne10; i10++) {
6101 dst_data[i10*ne11 + i11] = src[i10];
6102 }
6103 }
6104 }
6105
6106 // need to zero dst since we are accumulating into it
        memset(dst->data, 0, ggml_nbytes(dst));
    }
    ggml_barrier(params->threadpool);
6110
6111 const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
6112
6113 // total rows in dst
6114 const int nr = ne1;
6115
6116 // rows per thread
6117 const int dr = (nr + nth - 1)/nth;
6118
6119 // row range for this thread
6120 const int ir0 = dr*ith;
6121 const int ir1 = MIN(ir0 + dr, nr);
6122
6123 float * const wdata = (float *) params->wdata + 0;
6124 float * const wdata_src = wdata + nk;
6125
6126 for (int i1 = ir0; i1 < ir1; i1++) {
6127 float * dst_data = (float *)((char *) dst->data + i1*nb1);
6128 float * wdata_kernel = wdata + i1*ne02*ne00;
6129 for (int i10 = 0; i10 < ne10; i10++) {
6130 const int i1n = i10*ne11;
6131 for (int i00 = 0; i00 < ne00; i00++) {
6132 float v = 0;
                ggml_vec_dot_f32(ne02, &v, 0,
                        wdata_src + i1n, 0,
                        wdata_kernel + i00*ne02, 0, 1);
6136 dst_data[i10*s0 + i00] += v;
6137 }
6138 }
6139 }
6140}
6141
6142void ggml_compute_forward_conv_transpose_1d(
6143 const ggml_compute_params * params,
6144 ggml_tensor * dst) {
6145
6146 const ggml_tensor * src0 = dst->src[0];
6147
6148 switch (src0->type) {
6149 case GGML_TYPE_F16:
6150 {
6151 ggml_compute_forward_conv_transpose_1d_f16_f32(params, dst);
6152 } break;
6153 case GGML_TYPE_F32:
6154 {
6155 ggml_compute_forward_conv_transpose_1d_f32(params, dst);
6156 } break;
6157 default:
6158 {
6159 GGML_ABORT("fatal error");
6160 }
6161 }
6162}
6163
6164// ggml_compute_forward_im2col_f32
6165// src0: kernel [OC, IC, KH, KW]
6166// src1: image [N, IC, IH, IW]
6167// dst: result [N, OH, OW, IC*KH*KW]
6168static void ggml_compute_forward_im2col_f32(
6169 const ggml_compute_params * params,
6170 ggml_tensor * dst) {
6171
6172 const ggml_tensor * src0 = dst->src[0];
6173 const ggml_tensor * src1 = dst->src[1];
6174
6175 GGML_ASSERT(src1->type == GGML_TYPE_F32);
6176 GGML_ASSERT( dst->type == GGML_TYPE_F32);
6177
6178 GGML_TENSOR_BINARY_OP_LOCALS;
6179
6180 const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
6181 const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
6182 const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
6183 const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
6184 const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
6185 const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
6186 const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
6187
6188 const int ith = params->ith;
6189 const int nth = params->nth;
6190
6191 const int64_t N = is_2D ? ne13 : ne12;
6192 const int64_t IC = is_2D ? ne12 : ne11;
6193 const int64_t IH = is_2D ? ne11 : 1;
6194 const int64_t IW = ne10;
6195
6196 const int64_t KH = is_2D ? ne01 : 1;
6197 const int64_t KW = ne00;
6198
6199 const int64_t OH = is_2D ? ne2 : 1;
6200 const int64_t OW = ne1;
6201
    const int64_t ofs0 = is_2D ? nb13 : nb12;
    const int64_t ofs1 = is_2D ? nb12 : nb11;
6204
6205 GGML_ASSERT(nb10 == sizeof(float));
6206
6207 // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
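    // each output position (in, ioh, iow) receives a flattened copy of its receptive
    // field, one KH x KW patch per input channel; out-of-bounds taps are written as 0
    // so that the subsequent matrix multiplication implements the padding implicitly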
6208 {
6209 float * const wdata = (float *) dst->data;
6210
6211 for (int64_t in = 0; in < N; in++) {
6212 for (int64_t ioh = 0; ioh < OH; ioh++) { // 1
6213 for (int64_t iow = 0; iow < OW; iow++) {
6214 for (int64_t iic = ith; iic < IC; iic += nth) {
6215
6216 // micro kernel
6217 float * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
6218 const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
6219
6220 for (int64_t ikh = 0; ikh < KH; ikh++) { // 1
6221 for (int64_t ikw = 0; ikw < KW; ikw++) {
6222 const int64_t iiw = iow*s0 + ikw*d0 - p0;
6223 const int64_t iih = ioh*s1 + ikh*d1 - p1;
6224
6225 if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
6226 dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
6227 } else {
6228 dst_data[iic*(KH*KW) + ikh*KW + ikw] = (src_data[iih*IW + iiw]);
6229 }
6230 }
6231 }
6232 }
6233 }
6234 }
6235 }
6236 }
6237}
6238
6239
6240// ggml_compute_forward_im2col_f16
6241// src0: kernel [OC, IC, KH, KW]
6242// src1: image [N, IC, IH, IW]
6243// dst: result [N, OH, OW, IC*KH*KW]
6244static void ggml_compute_forward_im2col_f16(
6245 const ggml_compute_params * params,
6246 ggml_tensor * dst) {
6247
6248 const ggml_tensor * src0 = dst->src[0];
6249 const ggml_tensor * src1 = dst->src[1];
6250
6251 GGML_ASSERT(src0->type == GGML_TYPE_F16);
6252 GGML_ASSERT(src1->type == GGML_TYPE_F32);
6253 GGML_ASSERT( dst->type == GGML_TYPE_F16);
6254
6255 GGML_TENSOR_BINARY_OP_LOCALS;
6256
6257 const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
6258 const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
6259 const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
6260 const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
6261 const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
6262 const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
6263 const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
6264
6265 const int ith = params->ith;
6266 const int nth = params->nth;
6267
6268 const int64_t N = is_2D ? ne13 : ne12;
6269 const int64_t IC = is_2D ? ne12 : ne11;
6270 const int64_t IH = is_2D ? ne11 : 1;
6271 const int64_t IW = ne10;
6272
6273 const int64_t KH = is_2D ? ne01 : 1;
6274 const int64_t KW = ne00;
6275
6276 const int64_t OH = is_2D ? ne2 : 1;
6277 const int64_t OW = ne1;
6278
    const int64_t ofs0 = is_2D ? nb13 : nb12;
    const int64_t ofs1 = is_2D ? nb12 : nb11;
6281
6282 GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
6283 GGML_ASSERT(nb10 == sizeof(float));
6284
6285 // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
6286 {
6287 ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
6288
6289 for (int64_t in = 0; in < N; in++) {
6290 for (int64_t ioh = 0; ioh < OH; ioh++) { // 1
6291 for (int64_t iow = 0; iow < OW; iow++) {
6292 for (int64_t iic = ith; iic < IC; iic += nth) {
6293
6294 // micro kernel
6295 ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
6296 const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
6297
6298 for (int64_t ikh = 0; ikh < KH; ikh++) { // 1
6299 for (int64_t ikw = 0; ikw < KW; ikw++) {
6300 const int64_t iiw = iow*s0 + ikw*d0 - p0;
6301 const int64_t iih = ioh*s1 + ikh*d1 - p1;
6302
6303 if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
6304 dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
6305 } else {
6306 dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(src_data[iih*IW + iiw]);
6307 }
6308 }
6309 }
6310 }
6311 }
6312 }
6313 }
6314 }
6315}
6316
6317void ggml_compute_forward_im2col(
6318 const ggml_compute_params * params,
6319 ggml_tensor * dst) {
6320 switch (dst->type) {
6321 case GGML_TYPE_F16:
6322 {
6323 ggml_compute_forward_im2col_f16(params, dst);
6324 } break;
6325 case GGML_TYPE_F32:
6326 {
6327 ggml_compute_forward_im2col_f32(params, dst);
6328 } break;
6329 default:
6330 {
6331 GGML_ABORT("fatal error");
6332 }
6333 }
6334}
6335
6336// ggml_compute_forward_im2col_back_f32
6337
6338void ggml_compute_forward_im2col_back_f32(
6339 const ggml_compute_params * params,
6340 ggml_tensor * dst) {
6341
6342 const ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
6343 const ggml_tensor * src1 = dst->src[1]; // convolution kernel
6344
6345 GGML_ASSERT(src0->type == GGML_TYPE_F32);
6346 GGML_ASSERT(src1->type == GGML_TYPE_F32);
6347 GGML_ASSERT( dst->type == GGML_TYPE_F32);
6348
6349 GGML_TENSOR_BINARY_OP_LOCALS;
6350
6351 const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
6352 const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
6353 const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
6354 const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
6355 const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
6356 const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
6357 const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
6358
6359 const int ith = params->ith;
6360 const int nth = params->nth;
6361
6362 const int64_t N = is_2D ? ne3 : ne2;
6363 const int64_t IC = is_2D ? ne2 : ne1;
6364 const int64_t IH = is_2D ? ne1 : 1;
6365 const int64_t IW = ne0;
6366
6367 const int64_t KH = is_2D ? ne11 : 1;
6368 const int64_t KW = ne10;
6369
6370 const int64_t OH = is_2D ? ne02 : 1;
6371 const int64_t OW = ne01;
6372
    const int64_t ofs0 = is_2D ? nb3 : nb2;
    const int64_t ofs1 = is_2D ? nb2 : nb1;
6375
6376 GGML_ASSERT(nb0 == sizeof(float));
6377
    // im2col backward: accumulate gradients from [N, OH, OW, IC*KH*KW] back into [N, IC, IH, IW]
6379 {
6380 float * const wdata = (float *) dst->data;
6381
6382 for (int64_t in = 0; in < N; in++) {
6383 for (int64_t iic = ith; iic < IC; iic += nth) {
6384 for (int64_t iih = 0; iih < IH; iih++) {
6385 for (int64_t iiw = 0; iiw < IW; iiw++) {
6386
6387 // micro kernel
6388 float grad = 0.0f;
6389 for (int64_t ikh = 0; ikh < KH; ikh++) {
6390 for (int64_t ikw = 0; ikw < KW; ikw++) {
6391 // For s0 > 1 some values were skipped over in the forward pass.
6392 // These values have tmpw % s0 != 0 and need to be skipped in the backwards pass as well.
6393 const int64_t tmpw = (iiw + p0 - ikw*d0);
6394 if (tmpw % s0 != 0) {
6395 continue;
6396 }
6397 const int64_t iow = tmpw / s0;
6398
6399 // Equivalent logic as above except for s1.
6400 int64_t ioh;
6401 if (is_2D) {
6402 const int64_t tmph = iih + p1 - ikh*d1;
6403
6404 if (tmph % s1 != 0) {
6405 continue;
6406 }
6407
6408 ioh = tmph / s1;
6409 } else {
6410 ioh = 0;
6411 }
6412
6413 if (iow < 0 || iow >= OW || ioh < 0 || ioh >= OH) {
6414 continue;
6415 }
6416
6417 const float * const grad_in = (const float *) src0->data
6418 + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
6419 grad += grad_in[iic*(KH*KW) + ikh*KW + ikw];
6420 }
6421 }
6422 float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
6423 dst_data[iih*IW + iiw] = grad;
6424 }
6425 }
6426 }
6427 }
6428 }
6429}
6430
6431
6432// ggml_compute_forward_im2col_3d_f16
6433// src0: kernel [OC*IC, KD, KH, KW]
6434// src1: image [N*IC, ID, IH, IW]
6435// dst: result [N*OD, OH, OW, IC * KD * KH * KW]
6436static void ggml_compute_forward_im2col_3d_f16(
6437 const ggml_compute_params * params,
6438 ggml_tensor * dst) {
6439
6440 const ggml_tensor * src0 = dst->src[0];
6441 const ggml_tensor * src1 = dst->src[1];
6442
6443 GGML_ASSERT(src0->type == GGML_TYPE_F16);
6444 GGML_ASSERT(src1->type == GGML_TYPE_F32);
6445 GGML_ASSERT( dst->type == GGML_TYPE_F16);
6446
6447 GGML_TENSOR_BINARY_OP_LOCALS;
6448
6449 const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
6450 const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
6451 const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
6452 const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
6453 const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
6454 const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
6455 const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
6456 const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
6457 const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
6458 const int32_t IC = ((const int32_t *)(dst->op_params))[9];
6459
6460
6461 const int ith = params->ith;
6462 const int nth = params->nth;
6463
6464 const int64_t N = ne13 / IC;
6465 const int64_t ID = ne12;
6466 const int64_t IH = ne11;
6467 const int64_t IW = ne10;
6468
6469 const int64_t OC = ne03 / IC;
6470 GGML_UNUSED(OC);
6471 const int64_t KD = ne02;
6472 const int64_t KH = ne01;
6473 const int64_t KW = ne00;
6474
6475 const int64_t OD = ne3 / N;
6476 const int64_t OH = ne2;
6477 const int64_t OW = ne1;
6478 const int64_t OH_OW = OH*OW;
6479 const int64_t KD_KH_KW = KD*KH*KW;
6480 const int64_t KH_KW = KH*KW;
6481 const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
6482
6483 GGML_ASSERT(nb10 == sizeof(float));
6484
6485 // im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
6486 {
6487 ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
6488
6489 for (int64_t in = 0; in < N; in++) {
6490 for (int64_t iod = 0; iod < OD; iod++) {
6491 for (int64_t ioh = 0; ioh < OH; ioh++) {
6492 for (int64_t iow = 0; iow < OW; iow++) {
6493 for (int64_t iic = ith; iic < IC; iic += nth) {
6494
6495 // micro kernel
6496 ggml_fp16_t * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
6497 const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW]
6498
6499 for (int64_t ikd = 0; ikd < KD; ikd++) {
6500 for (int64_t ikh = 0; ikh < KH; ikh++) {
6501 for (int64_t ikw = 0; ikw < KW; ikw++) {
6502 const int64_t iiw = iow*s0 + ikw*d0 - p0;
6503 const int64_t iih = ioh*s1 + ikh*d1 - p1;
6504 const int64_t iid = iod*s2 + ikd*d2 - p2;
6505
                                        if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
6507 dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
6508 } else {
6509 const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
6510 dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(*s);
6511 }
6512 }
6513 }
6514 }
6515 }
6516 }
6517 }
6518 }
6519 }
6520 }
6521}
6522
6523// ggml_compute_forward_im2col_3d_f32
6524// src0: kernel [OC*IC, KD, KH, KW]
6525// src1: image [N*IC, ID, IH, IW]
6526// dst: result [N*OD, OH, OW, IC * KD * KH * KW]
6527static void ggml_compute_forward_im2col_3d_f32(
6528 const ggml_compute_params * params,
6529 ggml_tensor * dst) {
6530
6531 const ggml_tensor * src0 = dst->src[0];
6532 const ggml_tensor * src1 = dst->src[1];
6533
6534 GGML_ASSERT(src1->type == GGML_TYPE_F32);
6535 GGML_ASSERT( dst->type == GGML_TYPE_F32);
6536
6537 GGML_TENSOR_BINARY_OP_LOCALS;
6538
6539 const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
6540 const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
6541 const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
6542 const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
6543 const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
6544 const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
6545 const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
6546 const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
6547 const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
6548 const int32_t IC = ((const int32_t *)(dst->op_params))[9];
6549
6550
6551 const int ith = params->ith;
6552 const int nth = params->nth;
6553
6554 const int64_t N = ne13 / IC;
6555 const int64_t ID = ne12;
6556 const int64_t IH = ne11;
6557 const int64_t IW = ne10;
6558
6559 const int64_t OC = ne03 / IC;
6560 GGML_UNUSED(OC);
6561 const int64_t KD = ne02;
6562 const int64_t KH = ne01;
6563 const int64_t KW = ne00;
6564
6565 const int64_t OD = ne3 / N;
6566 const int64_t OH = ne2;
6567 const int64_t OW = ne1;
6568
6569 const int64_t OH_OW = OH*OW;
6570 const int64_t KD_KH_KW = KD*KH*KW;
6571 const int64_t KH_KW = KH*KW;
6572 const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
6573
6574 GGML_ASSERT(nb10 == sizeof(float));
6575
6576 // im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
6577 {
6578 float * const wdata = (float *) dst->data;
6579
6580 for (int64_t in = 0; in < N; in++) {
6581 for (int64_t iod = 0; iod < OD; iod++) {
6582 for (int64_t ioh = 0; ioh < OH; ioh++) {
6583 for (int64_t iow = 0; iow < OW; iow++) {
6584 for (int64_t iic = ith; iic < IC; iic += nth) {
6585
6586 // micro kernel
6587 float * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
6588 const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW]
6589
6590 for (int64_t ikd = 0; ikd < KD; ikd++) {
6591 for (int64_t ikh = 0; ikh < KH; ikh++) {
6592 for (int64_t ikw = 0; ikw < KW; ikw++) {
6593 const int64_t iiw = iow*s0 + ikw*d0 - p0;
6594 const int64_t iih = ioh*s1 + ikh*d1 - p1;
6595 const int64_t iid = iod*s2 + ikd*d2 - p2;
6596
                                        if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
6598 dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
6599 } else {
6600 const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
6601 dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = *s;
6602 }
6603 }
6604 }
6605 }
6606 }
6607 }
6608 }
6609 }
6610 }
6611 }
6612}
6613
6614
6615void ggml_compute_forward_im2col_3d(
6616 const ggml_compute_params * params,
6617 ggml_tensor * dst) {
6618 switch (dst->type) {
6619 case GGML_TYPE_F16:
6620 {
6621 ggml_compute_forward_im2col_3d_f16(params, dst);
6622 } break;
6623 case GGML_TYPE_F32:
6624 {
6625 ggml_compute_forward_im2col_3d_f32(params, dst);
6626 } break;
6627 default:
6628 {
6629 GGML_ABORT("fatal error");
6630 }
6631 }
6632}
6633
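// helper: wraps ggml_compute_forward_mul_mat over temporary tensor views so that
// c[m][n] (f32, row-major) = a · b^T, where a holds m rows and b holds n rows of k
// elements of the given type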
6634static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
6635 void * a, void * b, float * c) {
6636 const ggml_type_traits * traits = ggml_get_type_traits(type);
6637 struct ggml_tensor src1 = {};
6638 src1.type = type;
6639 src1.ne[0] = k;
6640 src1.ne[1] = m;
6641 src1.ne[2] = 1;
6642 src1.ne[3] = 1;
6643 src1.nb[0] = traits->type_size;
6644 src1.nb[1] = k * traits->type_size;
6645 src1.nb[2] = src1.nb[1];
6646 src1.nb[3] = src1.nb[2];
6647 src1.data = a;
6648
6649 struct ggml_tensor src0 = {};
6650 src0.type = type;
6651 src0.ne[0] = k;
6652 src0.ne[1] = n;
6653 src0.ne[2] = 1;
6654 src0.ne[3] = 1;
6655 src0.nb[0] = traits->type_size;
6656 src0.nb[1] = k * traits->type_size;
6657 src0.nb[2] = src0.nb[1];
6658 src0.nb[3] = src0.nb[2];
6659 src0.data = b;
6660
6661 struct ggml_tensor dst = {};
6662 dst.ne[0] = n;
6663 dst.ne[1] = m;
6664 dst.ne[2] = 1;
6665 dst.ne[3] = 1;
6666 dst.nb[0] = sizeof(float);
6667 dst.nb[1] = n * sizeof(float);
6668 dst.nb[2] = dst.nb[1];
6669 dst.nb[3] = dst.nb[2];
6670 dst.data = c;
6671 dst.src[0] = &src0;
6672 dst.src[1] = &src1;
6673
    ggml_compute_forward_mul_mat(params, &dst);
6675}
6676
6677// ggml_compute_forward_conv_2d
6678
6679static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params,
6680 const ggml_tensor * kernel, // [KW, KH, IC, OC]
6681 const ggml_tensor * src, // [W, H, C, N]
6682 ggml_tensor * dst, // [OW, OH, OC, N]
6683 ggml_type kernel_type) {
6684
6685 GGML_ASSERT(ggml_is_contiguous(kernel));
6686 GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);
6687 GGML_ASSERT(kernel->type == kernel_type);
6688
    const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
6690
6691 const int32_t stride_x = dst->op_params[0];
6692 const int32_t stride_y = dst->op_params[1];
6693 const int32_t pad_x = dst->op_params[2];
6694 const int32_t pad_y = dst->op_params[3];
6695 const int32_t dilation_x = dst->op_params[4];
6696 const int32_t dilation_y = dst->op_params[5];
6697
6698 const int64_t c_in = src->ne[2];
6699 const int64_t c_out = kernel->ne[3];
6700 GGML_ASSERT(c_in == kernel->ne[2]);
6701
6702 const int64_t src_w = src->ne[0];
6703 const int64_t src_h = src->ne[1];
6704 const int64_t knl_w = kernel->ne[0];
6705 const int64_t knl_h = kernel->ne[1];
6706 const int64_t dst_w = dst->ne[0];
6707 const int64_t dst_h = dst->ne[1];
6708
6709 const float * src_data = (float *) src->data;
6710 void * knl_data = kernel->data;
6711 float * dst_data = (float *) dst->data;
6712
6713 const int64_t knl_n = knl_w * knl_h * c_in;
6714 const int64_t patch_total = dst->ne[3] * dst_w * dst_h;
6715
6716 const int64_t space_per_patch = knl_n * traits->type_size + c_out * sizeof(float);
6717 const int64_t batch_size = params->wsize / space_per_patch;
6718 const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
6719 const int64_t batch_n = (patch_total + patches_per_batch - 1) / patches_per_batch;
6720
6721 GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
6722
6723 void * tmp = params->wdata;
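    // params->wdata is reused per batch: the first patches_per_batch * knl_n entries hold
    // the im2col patches, followed by the f32 GEMM output of size patches_per_batch * c_out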
6724
6725 for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
6726
6727 const int64_t patch_start_batch = batch_i * patches_per_batch;
        const int64_t patch_end_batch = std::min(patch_start_batch + patches_per_batch,
                                                 patch_total);
6730 const int64_t patch_n = patch_end_batch - patch_start_batch;
6731
6732 const int64_t patch_per_thread = (patch_n + params->nth - 1) / params->nth;
6733 const int64_t patch_start = patch_start_batch + params->ith * patch_per_thread;
        const int64_t patch_end = std::min(patch_start + patch_per_thread, patch_end_batch);

        // im2col for a patch
6737 for (int64_t p = patch_start; p < patch_end; ++p) {
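            // note: despite the names, src_x/src_y below are the output row/column of patch p;
            // the actual source coordinates are derived from them using stride, dilation and padding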
6738 const int64_t batch_n = p / (dst_w * dst_h);
6739 const int64_t src_x = (p / dst_w) % dst_h;
6740 const int64_t src_y = p % dst_w;
6741
6742 const float * src_base = (const float *)((const char *)src_data + batch_n * src->nb[3]);
6743 char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n * traits->type_size;
6744
6745 for (int64_t ic = 0; ic < c_in; ++ic) {
6746 for (int64_t ky = 0; ky < knl_h; ++ky) {
6747 for (int64_t kx = 0; kx < knl_w; ++kx) {
6748 const int64_t sy = src_x * stride_y + ky * dilation_y - pad_y;
6749 const int64_t sx = src_y * stride_x + kx * dilation_x - pad_x;
6750
6751 int64_t dst_idx = ic * (knl_h * knl_w) + ky * knl_w + kx;
6752
6753 float src_val;
6754 if (sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
6755 src_val = 0.0f;
6756 } else {
6757 const float * src_ptr = (const float *)((const char *)src_base + sx * src->nb[0] + sy * src->nb[1] + ic * src->nb[2]);
6758 src_val = *src_ptr;
6759 }
6760
6761 char * element_ptr = dst_row + dst_idx * traits->type_size;
6762 if (kernel_type == GGML_TYPE_F32) {
6763 *(float *) element_ptr = src_val;
6764 } else if (kernel_type == GGML_TYPE_F16) {
6765 *(ggml_fp16_t *) element_ptr = GGML_CPU_FP32_TO_FP16(src_val);
6766 }
6767 }
6768 }
6769 }
6770 } // patches handled by this thread
6771
        ggml_barrier(params->threadpool);

        float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n * traits->type_size);

        GGML_ASSERT(gemm_output + patch_n * c_out <= (float*)tmp + params->wsize);

        // GEMM: patches[patch_n, knl_n] × kernel[knl_n, c_out] = output[patch_n, c_out]
        ggml_call_mul_mat(kernel_type, params, patch_n, c_out, knl_n, tmp, knl_data, gemm_output);

        ggml_barrier(params->threadpool);
6782
6783
        // permute back [OC, N, OH, OW] to [N, OC, OH, OW]
        const int64_t permute_per_thread = (patch_n + params->nth - 1) / params->nth;
        const int64_t permute_start = params->ith * permute_per_thread;
        const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n);
6788
6789 for (int64_t i = permute_start; i < permute_end; ++i) {
6790 const int64_t p = patch_start_batch + i;
6791 const int64_t batch_n = p / (dst_w * dst_h);
6792 const int64_t dst_y = (p / dst_w) % dst_h;
6793 const int64_t dst_x = p % dst_w;
6794
6795 for (int64_t oc = 0; oc < c_out; ++oc) {
6796 const float value = gemm_output[i * c_out + oc];
6797 float * dst_ptr = (float *)((char *)dst_data + dst_x * dst->nb[0] + dst_y * dst->nb[1] + oc * dst->nb[2] + batch_n * dst->nb[3]);
6798 *dst_ptr = value;
6799 }
6800 }
6801 }
6802}
6803
6804void ggml_compute_forward_conv_2d(
6805 const ggml_compute_params * params,
6806 ggml_tensor * dst) {
6807
6808 const ggml_tensor * src0 = dst->src[0];
6809 const ggml_tensor * src1 = dst->src[1];
6810
    ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);
6812}
6813
6814// ggml_compute_forward_conv_3d
6815
6816static void ggml_compute_forward_conv_3d_impl(const ggml_compute_params * params,
6817 const ggml_tensor * kernel,
6818 const ggml_tensor * src,
6819 ggml_tensor * dst,
6820 ggml_type kernel_type) {
6821
6822 GGML_ASSERT(ggml_is_contiguous(kernel));
6823 GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);
6824 GGML_ASSERT(kernel->type == kernel_type);
6825
    const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
6827
6828 const int32_t s0 = dst->op_params[0];
6829 const int32_t s1 = dst->op_params[1];
6830 const int32_t s2 = dst->op_params[2];
6831 const int32_t p0 = dst->op_params[3];
6832 const int32_t p1 = dst->op_params[4];
6833 const int32_t p2 = dst->op_params[5];
6834 const int32_t d0 = dst->op_params[6];
6835 const int32_t d1 = dst->op_params[7];
6836 const int32_t d2 = dst->op_params[8];
6837 const int32_t c = dst->op_params[9];
6838 const int32_t n = dst->op_params[10];
6839 const int32_t oc = dst->op_params[11];
6840
6841 const int64_t src_w = src->ne[0];
6842 const int64_t src_h = src->ne[1];
6843 const int64_t src_d = src->ne[2];
6844 const int64_t knl_w = kernel->ne[0];
6845 const int64_t knl_h = kernel->ne[1];
6846 const int64_t knl_d = kernel->ne[2];
6847 const int64_t dst_w = dst->ne[0];
6848 const int64_t dst_h = dst->ne[1];
6849 const int64_t dst_d = dst->ne[2];
6850
6851 const float * src_data = (float *) src->data;
6852 void * knl_data = kernel->data;
6853 float * dst_data = (float *) dst->data;
6854
6855 const int64_t knl_n_per_channel = knl_w * knl_h * knl_d;
6856 const int64_t knl_n_total = knl_n_per_channel * c;
6857 const int64_t patch_total = n * dst_w * dst_h * dst_d;
6858
6859 const int64_t space_per_patch = knl_n_total * traits->type_size + oc * sizeof(float);
6860 const int64_t batch_size = params->wsize / space_per_patch;
6861 const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
6862 const int64_t batch_n = (patch_total + patches_per_batch - 1) / patches_per_batch;
6863
6864 GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
6865
6866 void * tmp = params->wdata;
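    // same batched im2col + GEMM scheme as the 2D path above, extended with a depth dimension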
6867
6868 for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
        const int64_t patch_start_batch = batch_i * patches_per_batch;
        const int64_t patch_end_batch = std::min(patch_start_batch + patches_per_batch, patch_total);
        const int64_t patch_n_in_batch = patch_end_batch - patch_start_batch;

        const int64_t patch_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
        const int64_t patch_start = patch_start_batch + params->ith * patch_per_thread;
        const int64_t patch_end = std::min(patch_start + patch_per_thread, patch_end_batch);
6876
6877 for (int64_t p = patch_start; p < patch_end; ++p) {
6878 const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
6879 const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
6880 const int64_t batch_idx = p / (dst_w * dst_h * dst_d);
6881 const int64_t dst_z = p_in_batch / (dst_w * dst_h);
6882 const int64_t dst_y = p_in_depth / dst_w;
6883 const int64_t dst_x = p_in_depth % dst_w;
6884
6885 char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n_total * traits->type_size;
6886
6887 for (int64_t ic = 0; ic < c; ++ic) {
6888 for (int64_t kz = 0; kz < knl_d; ++kz) {
6889 for (int64_t ky = 0; ky < knl_h; ++ky) {
6890 for (int64_t kx = 0; kx < knl_w; ++kx) {
6891 const int64_t sz = dst_z * s2 + kz * d2 - p2;
6892 const int64_t sy = dst_y * s1 + ky * d1 - p1;
6893 const int64_t sx = dst_x * s0 + kx * d0 - p0;
6894
6895 int64_t dst_idx = ic * knl_n_per_channel + kz * (knl_h * knl_w) + ky * knl_w + kx;
6896
6897 float src_val;
6898 if (sz < 0 || sz >= src_d || sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
6899 src_val = 0.0f;
6900 } else {
6901 const int64_t cn_idx = batch_idx * c + ic;
6902 const float * src_ptr = (const float *)((const char *)src_data + sx*src->nb[0] + sy*src->nb[1] + sz*src->nb[2] + cn_idx*src->nb[3]);
6903 src_val = *src_ptr;
6904 }
6905
6906 char * element_ptr = dst_row + dst_idx * traits->type_size;
6907 if (kernel_type == GGML_TYPE_F32) {
6908 *(float *)element_ptr = src_val;
6909 } else if (kernel_type == GGML_TYPE_F16) {
6910 *(ggml_fp16_t *)element_ptr = GGML_CPU_FP32_TO_FP16(src_val);
6911 }
6912 }
6913 }
6914 }
6915 }
6916 }
6917
        ggml_barrier(params->threadpool);

        float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n_total * traits->type_size);
        ggml_call_mul_mat(kernel_type, params, patch_n_in_batch, oc, knl_n_total, tmp, knl_data, gemm_output);

        ggml_barrier(params->threadpool);

        const int64_t permute_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
        const int64_t permute_start = params->ith * permute_per_thread;
        const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n_in_batch);
6928
6929 for (int64_t i = permute_start; i < permute_end; ++i) {
6930 const int64_t p = patch_start_batch + i;
6931 const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
6932 const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
6933 const int64_t batch_idx = p / (dst_w * dst_h * dst_d);
6934 const int64_t dst_z = p_in_batch / (dst_w * dst_h);
6935 const int64_t dst_y = p_in_depth / dst_w;
6936 const int64_t dst_x = p_in_depth % dst_w;
6937
6938 for (int64_t ioc = 0; ioc < oc; ++ioc) {
6939 const float value = gemm_output[i * oc + ioc];
6940 const int64_t ocn_idx = batch_idx * oc + ioc;
6941 float * dst_ptr = (float *)((char *)dst_data + dst_x*dst->nb[0] + dst_y*dst->nb[1] + dst_z*dst->nb[2] + ocn_idx*dst->nb[3]);
6942 *dst_ptr = value;
6943 }
6944 }
6945 }
6946}
6947
6948void ggml_compute_forward_conv_3d(
6949 const ggml_compute_params * params,
6950 ggml_tensor * dst) {
6951 const ggml_tensor * src0 = dst->src[0];
6952 const ggml_tensor * src1 = dst->src[1];
    ggml_compute_forward_conv_3d_impl(params, src0, src1, dst, src0->type);
6954}
6955
6956// ggml_compute_forward_conv_transpose_2d
6957
6958void ggml_compute_forward_conv_transpose_2d(
6959 const ggml_compute_params * params,
6960 ggml_tensor * dst) {
6961
6962 const ggml_tensor * src0 = dst->src[0];
6963 const ggml_tensor * src1 = dst->src[1];
6964
6965 GGML_ASSERT(src0->type == GGML_TYPE_F16);
6966 GGML_ASSERT(src1->type == GGML_TYPE_F32);
6967 GGML_ASSERT( dst->type == GGML_TYPE_F32);
6968
6969 GGML_TENSOR_BINARY_OP_LOCALS
6970
6971 const int ith = params->ith;
6972 const int nth = params->nth;
6973
6974 const int nk = ne00*ne01*ne02*ne03;
6975
6976 GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
6977 GGML_ASSERT(nb10 == sizeof(float));
6978
6979 if (ith == 0) {
        memset(params->wdata, 0, params->wsize);
6981
6982 // permute kernel data (src0) from (Kw x Kh x Cout x Cin) to (Cin x Kw x Kh x Cout)
6983 {
6984 ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
6985
6986 for (int64_t i03 = 0; i03 < ne03; i03++) {
6987 for (int64_t i02 = 0; i02 < ne02; i02++) {
6988 const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i03*nb03 + i02*nb02);
6989 ggml_fp16_t * dst_data = wdata + i02*ne01*ne00*ne03;
6990 for (int64_t i01 = 0; i01 < ne01; i01++) {
6991 for (int64_t i00 = 0; i00 < ne00; i00++) {
6992 dst_data[i01*ne00*ne03 + i00*ne03 + i03] = src[i01 * ne00 + i00];
6993 }
6994 }
6995 }
6996 }
6997 }
6998
6999 // permute source data (src1) from (Sw x Sh x Cin) to (Cin x Sw x Sh)
7000 {
7001 ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
7002 for (int i12 = 0; i12 < ne12; i12++) {
7003 for (int i11 = 0; i11 < ne11; i11++) {
7004 const float * const src = (float *)((char *) src1->data + i12*nb12 + i11*nb11);
7005 ggml_fp16_t * dst_data = wdata + i11*ne10*ne12;
7006 for (int i10 = 0; i10 < ne10; i10++) {
7007 dst_data[i10*ne12 + i12] = GGML_CPU_FP32_TO_FP16(src[i10]);
7008 }
7009 }
7010 }
7011 }
7012
        memset(dst->data, 0, ggml_nbytes(dst));
    }
    ggml_barrier(params->threadpool);

    const int32_t stride = ggml_get_op_params_i32(dst, 0);
7018
7019 // total patches in dst
7020 const int np = ne2;
7021
7022 // patches per thread
7023 const int dp = (np + nth - 1)/nth;
7024
7025 // patch range for this thread
7026 const int ip0 = dp*ith;
7027 const int ip1 = MIN(ip0 + dp, np);
7028
7029 ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
7030 ggml_fp16_t * const wdata_src = wdata + nk;
7031
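    // each thread handles a range of output channels; for every source pixel and kernel
    // tap, a dot product over the input channels (ne03) is accumulated into dst at the
    // position offset by the stride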
7032 for (int i2 = ip0; i2 < ip1; i2++) { // Cout
7033 float * dst_data = (float *)((char *) dst->data + i2*nb2);
7034 ggml_fp16_t * wdata_kernel = wdata + i2*ne01*ne00*ne03;
7035 for (int i11 = 0; i11 < ne11; i11++) {
7036 for (int i10 = 0; i10 < ne10; i10++) {
7037 const int i1n = i11*ne10*ne12 + i10*ne12;
7038 for (int i01 = 0; i01 < ne01; i01++) {
7039 for (int i00 = 0; i00 < ne00; i00++) {
7040 float v = 0;
                        ggml_vec_dot_f16(ne03, &v, 0,
                                wdata_src + i1n, 0,
                                wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1);
7044 dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v;
7045 }
7046 }
7047 }
7048 }
7049 }
7050}
7051
7052// ggml_compute_forward_conv_2d_dw
7053
7054struct ggml_conv_2d_dw_params {
7055 int64_t channels;
7056 int64_t batch;
7057 int64_t src_w;
7058 int64_t src_h;
7059 int64_t dst_w;
7060 int64_t dst_h;
7061 int64_t knl_w;
7062 int64_t knl_h;
7063 int stride_x;
7064 int stride_y;
7065 int pad_x;
7066 int pad_y;
7067 int dilation_x;
7068 int dilation_y;
7069};
7070
7071static void ggml_compute_forward_conv_2d_dw_cwhn(
7072 const ggml_compute_params * params,
7073 const ggml_tensor * src,
7074 const ggml_tensor * kernel,
7075 ggml_tensor * dst,
7076 const ggml_conv_2d_dw_params & p) {
7077
7078 const int64_t c = p.channels;
7079 const float * knl_data = (const float *)kernel->data;
7080
7081 const int64_t rows_total = p.dst_h * p.batch;
7082 const int64_t rows_per_thread = (rows_total + params->nth - 1) / params->nth;
7083 const int64_t row_start = params->ith * rows_per_thread;
7084 const int64_t row_end = MIN(row_start + rows_per_thread, rows_total);
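    // CWHN layout: channels are contiguous in memory, so the kernel/source loads below
    // can be vectorized across channels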
7085
7086#ifdef GGML_SIMD
7087 #if defined(__ARM_FEATURE_SVE)
7088 const int64_t pkg_size = svcntw();
7089 #else
7090 const int64_t pkg_size = GGML_F32_EPR;
7091 #endif
7092 const int64_t pkg_count = c / pkg_size;
7093 const int64_t c_pkg_end = pkg_count * pkg_size;
7094#else
7095 const int64_t c_pkg_end = 0;
7096#endif
7097
7098 for (int64_t row = row_start; row < row_end; ++row) {
7099 const int64_t dst_y = row % p.dst_h;
7100 const float * src_data = (const float *)src->data + (row / p.dst_h) * p.src_w * p.src_h * c;
7101 for (int64_t dst_x = 0; dst_x < p.dst_w; ++dst_x) {
7102 float * dst_data = (float *)dst->data + (row * p.dst_w + dst_x) * c;
7103 const int64_t src_y_base = dst_y * p.stride_y - p.pad_y;
7104 const int64_t src_x_base = dst_x * p.stride_x - p.pad_x;
7105
7106#ifdef GGML_SIMD
7107 // Vectorized loop
7108 for (int64_t c_i = 0; c_i < c_pkg_end; c_i += pkg_size) {
7109 GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
7110 for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
7111 const int64_t src_y = src_y_base + knl_y * p.dilation_y;
7112 if (src_y < 0 || src_y >= p.src_h) {
7113 continue;
7114 }
7115 for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
7116 const int64_t src_x = src_x_base + knl_x * p.dilation_x;
7117 if (src_x < 0 || src_x >= p.src_w) {
7118 continue;
7119 }
                        GGML_F32_VEC k = GGML_F32_VEC_LOAD(knl_data + (knl_y * p.knl_w + knl_x) * c + c_i);
                        GGML_F32_VEC s = GGML_F32_VEC_LOAD(src_data + (src_y * p.src_w + src_x) * c + c_i);
                        sum = GGML_F32_VEC_FMA(sum, k, s);
                    }
                }
                GGML_F32_VEC_STORE(dst_data + c_i, sum);
7126 }
7127#endif
7128 // Scalar loop
7129 for (int64_t c_i = c_pkg_end; c_i < c; ++c_i) {
7130 float sum = 0.0f;
7131 for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
7132 const int64_t src_y = src_y_base + knl_y * p.dilation_y;
7133 if (src_y < 0 || src_y >= p.src_h) {
7134 continue;
7135 }
7136 for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
7137 const int64_t src_x = src_x_base + knl_x * p.dilation_x;
7138 if (src_x < 0 || src_x >= p.src_w) {
7139 continue;
7140 }
7141 sum += knl_data[(knl_y * p.knl_w + knl_x) * c + c_i]
7142 * src_data[(src_y * p.src_w + src_x) * c + c_i];
7143 }
7144 }
7145 dst_data[c_i] = sum;
7146 }
7147 }
7148 }
7149}
7150
7151static void ggml_compute_forward_conv_2d_dw_whcn(
7152 const ggml_compute_params * params,
7153 const ggml_tensor * src,
7154 const ggml_tensor * kernel,
7155 ggml_tensor * dst,
7156 const ggml_conv_2d_dw_params & p) {
7157
7158 const int64_t n = p.channels * p.batch;
7159 const int64_t per_thread = (n + params->nth - 1) / params->nth;
7160 const int64_t start = params->ith * per_thread;
7161 const int64_t end = MIN(start + per_thread, n);
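    // WHCN layout: each (channel, batch) plane is contiguous, so threads are split over
    // the channels * batch planes and each plane is convolved independently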
7162
7163 for (int64_t i = start; i < end; ++i) {
7164 const float * knl_data = (const float *)kernel->data + (i % p.channels) * p.knl_w * p.knl_h;
7165 const float * src_data = (const float *)src->data + i * p.src_w * p.src_h;
7166 float * dst_data = (float *)dst->data + i * p.dst_w * p.dst_h;
7167
7168 for (int64_t dst_y = 0; dst_y < p.dst_h; ++dst_y) {
7169 for (int64_t dst_x = 0; dst_x < p.dst_w; ++dst_x) {
7170
7171 float sum = 0.0f;
7172 for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
7173 const int64_t src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;
7174 if (src_y < 0 || src_y >= p.src_h) {
7175 continue;
7176 }
7177 for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
7178 const int64_t src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;
7179 if (src_x < 0 || src_x >= p.src_w) {
7180 continue;
7181 }
7182 sum += knl_data[knl_y * p.knl_w + knl_x]
7183 * src_data[src_y * p.src_w + src_x];
7184 }
7185 }
7186 dst_data[dst_y * p.dst_w + dst_x] = sum;
7187 }
7188 }
7189 }
7190}
7191
7192void ggml_compute_forward_conv_2d_dw(
7193 const ggml_compute_params * params,
7194 ggml_tensor * dst) {
7195
7196 const ggml_tensor * kernel = dst->src[0];
7197 const ggml_tensor * src = dst->src[1];
7198 ggml_conv_2d_dw_params p;
7199 p.channels = src->ne[2];
7200 p.batch = src->ne[3];
7201 p.src_w = src->ne[0];
7202 p.src_h = src->ne[1];
7203 p.dst_w = dst->ne[0];
7204 p.dst_h = dst->ne[1];
7205 p.knl_w = kernel->ne[0];
7206 p.knl_h = kernel->ne[1];
7207 p.stride_x = dst->op_params[0];
7208 p.stride_y = dst->op_params[1];
7209 p.pad_x = dst->op_params[2];
7210 p.pad_y = dst->op_params[3];
7211 p.dilation_x = dst->op_params[4];
7212 p.dilation_y = dst->op_params[5];
7213
7214 GGML_ASSERT(kernel->ne[3] == p.channels);
7215 GGML_ASSERT(dst->ne[3] == p.batch);
7216
    if (ggml_is_contiguous(src)) {
        ggml_compute_forward_conv_2d_dw_whcn(params, src, kernel, dst, p);
    } else if (ggml_is_contiguous_channels(src)) {
7220 // kernel should also have channels most contiguous in memory
7221 GGML_ASSERT(kernel->nb[0] >= kernel->nb[2] && kernel->nb[1] >= kernel->nb[0]);
7222 ggml_compute_forward_conv_2d_dw_cwhn(params, src, kernel, dst, p);
7223 } else {
7224 GGML_ABORT("non-contiguous memory layout not supported");
7225 }
7226}
7227
7228// ggml_compute_forward_pool_1d_sk_p0
7229
7230static void ggml_compute_forward_pool_1d_sk_p0(
7231 const ggml_compute_params * params,
7232 const ggml_op_pool op,
7233 const int k,
7234 ggml_tensor * dst) {
7235
7236 const ggml_tensor * src = dst->src[0];
7237
7238 assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16);
7239
7240 if (params->ith != 0) {
7241 return;
7242 }
7243
7244 const char * cdata = (const char *)src->data;
    const char * const data_end = cdata + ggml_nbytes(src);
7246 float * drow = (float *)dst->data;
7247
7248 const int64_t rs = dst->ne[0];
7249
7250 while (cdata < data_end) {
7251 const void * srow = (const void *)cdata;
7252 int j = 0;
7253 for (int64_t i = 0; i < rs; ++i) {
7254 switch (op) {
7255 case GGML_OP_POOL_AVG: drow[i] = 0; break;
7256 case GGML_OP_POOL_MAX: drow[i] = -FLT_MAX; break;
7257 case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7258 }
7259 for (int ki = 0; ki < k; ++ki) {
7260 const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
7261 switch (op) {
7262 case GGML_OP_POOL_AVG: drow[i] += srow_j; break;
7263 case GGML_OP_POOL_MAX: if (srow_j > drow[i]) drow[i] = srow_j; break;
7264 case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7265 }
7266 ++j;
7267 }
7268 switch (op) {
7269 case GGML_OP_POOL_AVG: drow[i] /= k; break;
7270 case GGML_OP_POOL_MAX: break;
7271 case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7272 }
7273 }
7274
7275 cdata += src->nb[1];
7276 drow += rs;
7277 }
7278}
7279
7280// ggml_compute_forward_pool_1d
7281
7282void ggml_compute_forward_pool_1d(
7283 const ggml_compute_params * params,
7284 ggml_tensor * dst) {
7285
7286 const int32_t * opts = (const int32_t *)dst->op_params;
7287 ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
7288 const int k0 = opts[1];
7289 const int s0 = opts[2];
7290 const int p0 = opts[3];
7291 GGML_ASSERT(p0 == 0); // padding not supported
7292 GGML_ASSERT(k0 == s0); // only s = k supported
7293
    ggml_compute_forward_pool_1d_sk_p0(params, op, k0, dst);
7295}
7296
7297// ggml_compute_forward_pool_2d
7298
7299void ggml_compute_forward_pool_2d(
7300 const ggml_compute_params * params,
7301 ggml_tensor * dst) {
7302
7303 const ggml_tensor * src = dst->src[0];
7304
7305 assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16);
7306
7307 if (params->ith != 0) {
7308 return;
7309 }
7310
7311 const int32_t * opts = (const int32_t *)dst->op_params;
7312 ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
7313 const int k0 = opts[1];
7314 const int k1 = opts[2];
7315 const int s0 = opts[3];
7316 const int s1 = opts[4];
7317 const int p0 = opts[5];
7318 const int p1 = opts[6];
7319 const char * cdata = (const char*)src->data;
    const char * const data_end = cdata + ggml_nbytes(src);
7321
7322 const int64_t px = dst->ne[0];
7323 const int64_t py = dst->ne[1];
7324 const int64_t pa = px * py;
7325
7326 float * dplane = (float *)dst->data;
7327
7328 const int ka = k0 * k1;
7329 const int offset0 = -p0;
7330 const int offset1 = -p1;
7331
7332 while (cdata < data_end) {
7333 for (int oy = 0; oy < py; ++oy) {
7334 float * const drow = dplane + oy * px;
7335 for (int ox = 0; ox < px; ++ox) {
7336 float * const out = drow + ox;
7337 switch (op) {
7338 case GGML_OP_POOL_AVG: *out = 0; break;
7339 case GGML_OP_POOL_MAX: *out = -FLT_MAX; break;
7340 case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7341 }
7342
7343 const int ix = offset0 + ox * s0;
7344 const int iy = offset1 + oy * s1;
7345
7346 for (int ky = 0; ky < k1; ++ky) {
7347 if (iy + ky < 0 || iy + ky >= src->ne[1]) continue;
7348 const void * srow = (const void *)(cdata + src->nb[1] * (iy + ky));
7349 for (int kx = 0; kx < k0; ++kx) {
7350 int j = ix + kx;
7351 if (j < 0 || j >= src->ne[0]) continue;
7352 const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
7353 switch (op) {
7354 case GGML_OP_POOL_AVG: *out += srow_j; break;
7355 case GGML_OP_POOL_MAX: if (srow_j > *out) *out = srow_j; break;
7356 case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7357 }
7358 }
7359 }
7360 switch (op) {
7361 case GGML_OP_POOL_AVG: *out /= ka; break;
7362 case GGML_OP_POOL_MAX: break;
7363 case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7364 }
7365 }
7366 }
7367
7368 cdata += src->nb[2];
7369 dplane += pa;
7370 }
7371}
7372
7373// ggml_compute_forward_pool_2d_back
7374
7375void ggml_compute_forward_pool_2d_back(
7376 const ggml_compute_params * params,
7377 ggml_tensor * dst) {
7378
7379 const ggml_tensor * src = dst->src[0];
7380 const ggml_tensor * dstf = dst->src[1]; // forward tensor of dst
7381
7382 assert(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
7383
7384 if (params->ith != 0) {
7385 return;
7386 }
7387
7388 const int32_t * opts = (const int32_t *)dst->op_params;
7389 ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
7390 const int k0 = opts[1];
7391 const int k1 = opts[2];
7392 const int s0 = opts[3];
7393 const int s1 = opts[4];
7394 const int p0 = opts[5];
7395 const int p1 = opts[6];
7396
7397 char * cdata = (char *) dst->data;
7398 const char * cdataf = (const char *) dstf->data;
    const char * const data_end = cdata + ggml_nbytes(dst);

    GGML_ASSERT(params->ith == 0);
    memset(cdata, 0, ggml_nbytes(dst));
7403
7404 const int64_t px = src->ne[0];
7405 const int64_t py = src->ne[1];
7406 const int64_t pa = px * py;
7407
7408 const float * splane = (const float *) src->data;
7409
7410 const int ka = k0 * k1;
7411 const int offset0 = -p0;
7412 const int offset1 = -p1;
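    // max pooling recomputes the argmax from the forward-pass input (dstf) and routes the
    // incoming gradient to that position; average pooling spreads it uniformly over the window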
7413
7414 while (cdata < data_end) {
7415 for (int oy = 0; oy < py; ++oy) {
7416 const float * const srow = splane + oy * px;
7417 for (int ox = 0; ox < px; ++ox) {
7418 const float grad0 = srow[ox];
7419
7420 const int ix = offset0 + ox * s0;
7421 const int iy = offset1 + oy * s1;
7422
7423 if (op == GGML_OP_POOL_MAX) {
7424 float maxval = -FLT_MAX;
7425 int kxmax = -1;
7426 int kymax = -1;
7427
7428 for (int ky = 0; ky < k1; ++ky) {
7429 if (iy + ky < 0 || iy + ky >= dst->ne[1]) {
7430 continue;
7431 }
7432 const void * drowf = (const void *)(cdataf + dst->nb[1] * (iy + ky));
7433 for (int kx = 0; kx < k0; ++kx) {
7434 int j = ix + kx;
7435 if (j < 0 || j >= dst->ne[0]) {
7436 continue;
7437 }
7438
7439 const float val = dst->type == GGML_TYPE_F32 ?
7440 ((const float *) drowf)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) drowf)[j]);
7441 if (val <= maxval) {
7442 continue;
7443 }
7444
7445 maxval = val;
7446 kxmax = kx;
7447 kymax = ky;
7448 }
7449 }
7450
7451 if (kxmax == -1 || kymax == -1) {
7452 continue;
7453 }
7454
7455 void * drow = (void *)(cdata + dst->nb[1] * (iy + kymax));
7456 const int j = ix + kxmax;
7457 if (dst->type == GGML_TYPE_F32) {
7458 ((float *) drow)[j] += grad0;
7459 } else {
7460 ((ggml_fp16_t *) drow)[j] = GGML_CPU_FP32_TO_FP16(grad0 + GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) drow)[j]));
7461 }
7462 } else if (op == GGML_OP_POOL_AVG) {
7463 const float grad = grad0 / ka;
7464
7465 for (int ky = 0; ky < k1; ++ky) {
7466 if (iy + ky < 0 || iy + ky >= dst->ne[1]) {
7467 continue;
7468 }
7469 void * drow = (void *)(cdata + dst->nb[1] * (iy + ky));
7470 for (int kx = 0; kx < k0; ++kx) {
7471 int j = ix + kx;
7472 if (j < 0 || j >= dst->ne[0]) {
7473 continue;
7474 }
7475
                            if (dst->type == GGML_TYPE_F32) {
                                ((float *) drow)[j] += grad;
                            } else {
                                ((ggml_fp16_t *) drow)[j] = GGML_CPU_FP32_TO_FP16(grad + GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) drow)[j]));
                            }
7481 }
7482 }
7483 } else {
7484 GGML_ASSERT(false);
7485 }
7486 }
7487 }
7488
7489 cdata += dst->nb[2];
7490 cdataf += dst->nb[2];
7491 splane += pa;
7492 }
7493}
7494
7495// ggml_compute_forward_upscale
7496
7497static void ggml_compute_forward_upscale_f32(
7498 const ggml_compute_params * params,
7499 ggml_tensor * dst) {
7500
7501 const ggml_tensor * src0 = dst->src[0];
7502
7503 GGML_ASSERT(src0->type == GGML_TYPE_F32);
7504
7505 const int ith = params->ith;
7506 const int nth = params->nth;
7507
7508 GGML_TENSOR_UNARY_OP_LOCALS
7509
7510 float sf0 = (float)ne0/src0->ne[0];
7511 float sf1 = (float)ne1/src0->ne[1];
7512 float sf2 = (float)ne2/src0->ne[2];
7513 float sf3 = (float)ne3/src0->ne[3];
7514 float pixel_offset = 0.5f;
7515
    const int32_t mode_flags = ggml_get_op_params_i32(dst, 0);
7517 const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF);
7518
7519 if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
7520 pixel_offset = 0.0f;
7521 sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
7522 sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
7523 }
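    // sf0..sf3 are the per-dimension scale factors (dst extent / src extent); with
    // GGML_SCALE_FLAG_ALIGN_CORNERS the corner samples of src and dst coincide (pixel_offset = 0)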
7524
7525 if (mode == GGML_SCALE_MODE_NEAREST) {
7526 for (int64_t i3 = 0; i3 < ne3; i3++) {
7527 const int64_t i03 = i3 / sf3;
7528 for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
7529 const int64_t i02 = i2 / sf2;
7530 for (int64_t i1 = 0; i1 < ne1; i1++) {
7531 const int64_t i01 = i1 / sf1;
7532 for (int64_t i0 = 0; i0 < ne0; i0++) {
7533 const int64_t i00 = i0 / sf0;
7534
7535 const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7536 float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
7537
7538 *y = *x;
7539 }
7540 }
7541 }
7542 }
7543 } else if (mode == GGML_SCALE_MODE_BILINEAR) {
7544 for (int64_t i3 = 0; i3 < ne3; i3++) {
7545 const int64_t i03 = i3 / sf3;
7546 for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
7547 const int64_t i02 = i2 / sf2;
7548 for (int64_t i1 = 0; i1 < ne1; i1++) {
7549 const float y = ((float)i1 + pixel_offset) / sf1 - pixel_offset;
7550 int64_t y0 = (int64_t)floorf(x: y);
7551 int64_t y1 = y0 + 1;
7552
7553 y0 = std::max(a: int64_t(0), b: std::min(a: y0, b: ne01 - 1));
7554 y1 = std::max(a: int64_t(0), b: std::min(a: y1, b: ne01 - 1));
7555
7556 float dy = y - (float)y0;
7557 dy = std::max(a: 0.0f, b: std::min(a: dy, b: 1.0f));
7558
7559 for (int64_t i0 = 0; i0 < ne0; i0++) {
7560 const float x = ((float)i0 + pixel_offset) / sf0 - pixel_offset;
7561 int64_t x0 = (int64_t)floorf(x: x);
7562 int64_t x1 = x0 + 1;
7563
7564 x0 = std::max(a: int64_t(0), b: std::min(a: x0, b: ne00 - 1));
7565 x1 = std::max(a: int64_t(0), b: std::min(a: x1, b: ne00 - 1));
7566
7567 float dx = x - (float)x0;
7568 dx = std::max(a: 0.0f, b: std::min(a: dx, b: 1.0f));
7569
7570 // fetch the four surrounding pixel values and interpolate
7571 const float a = *(const float *)((const char *)src0->data + x0*nb00 + y0*nb01 + i02*nb02 + i03*nb03);
7572 const float b = *(const float *)((const char *)src0->data + x1*nb00 + y0*nb01 + i02*nb02 + i03*nb03);
7573 const float c = *(const float *)((const char *)src0->data + x0*nb00 + y1*nb01 + i02*nb02 + i03*nb03);
7574 const float d = *(const float *)((const char *)src0->data + x1*nb00 + y1*nb01 + i02*nb02 + i03*nb03);
7575
7576 const float val = a*(1 - dx)*(1 - dy) + b*dx*(1 - dy) + c*(1 - dx)*dy + d*dx*dy;
7577
7578 float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
7579 *y_dst = val;
7580 }
7581 }
7582 }
7583 }
7584 } else if (mode == GGML_SCALE_MODE_BICUBIC) {
7585 // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
7586 const float a = -0.75f; // use alpha = -0.75 (same as PyTorch)
7587 auto weight1 = [a](float x) { return ((a + 2) * x - (a + 3)) * x * x + 1; };
7588 auto weight2 = [a](float x) { return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; };
7589 auto bicubic = [=](float p0, float p1, float p2, float p3, float x) {
7590 const float w0 = weight2(x + 1);
7591 const float w1 = weight1(x + 0);
7592 const float w2 = weight1(1 - x);
7593 const float w3 = weight2(2 - x);
7594 return p0*w0 + p1*w1 + p2*w2 + p3*w3;
7595 };
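        // weight1/weight2 are the two branches of the cubic convolution kernel:
        //   W(t) = (a+2)|t|^3 - (a+3)|t|^2 + 1      for |t| <= 1      (weight1)
        //   W(t) = a|t|^3 - 5a|t|^2 + 8a|t| - 4a    for 1 < |t| < 2   (weight2)
        // so for a fractional offset x in [0,1) the four taps at distances
        // x+1, x, 1-x and 2-x receive the weights w0..w3 computed above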
7596
7597 for (int64_t i3 = 0; i3 < ne3; i3++) {
7598 const int64_t i03 = i3 / sf3;
7599 for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
7600 const int64_t i02 = i2 / sf2;
7601 for (int64_t i1 = 0; i1 < ne1; i1++) {
7602 const float y = ((float)i1 + pixel_offset) / sf1 - pixel_offset;
7603 const int64_t y0 = (int64_t)floorf(x: y);
7604 const float dy = y - (float)y0;
7605
7606 for (int64_t i0 = 0; i0 < ne0; i0++) {
7607 const float x = ((float)i0 + pixel_offset) / sf0 - pixel_offset;
7608 const int64_t x0 = (int64_t)floorf(x: x);
7609 const float dx = x - (float)x0;
7610
7611 auto p = [=](int64_t x_off, int64_t y_off) -> float {
7612 int64_t i00 = std::max(a: int64_t(0), b: std::min(a: x0 + x_off, b: ne00 - 1));
7613 int64_t i01 = std::max(a: int64_t(0), b: std::min(a: y0 + y_off, b: ne01 - 1));
7614 return *(const float *)((const char *)src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7615 };
7616
7617 const float val = bicubic(
7618 bicubic(p(-1,-1), p(0,-1), p(1,-1), p(2,-1), dx),
7619 bicubic(p(-1, 0), p(0, 0), p(1, 0), p(2, 0), dx),
7620 bicubic(p(-1, 1), p(0, 1), p(1, 1), p(2, 1), dx),
7621 bicubic(p(-1, 2), p(0, 2), p(1, 2), p(2, 2), dx), dy);
7622
7623 float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
7624 *y_dst = val;
7625 }
7626 }
7627 }
7628 }
7629 } else {
7630 GGML_ABORT("unsupported upscale mode");
7631 }
7632}
7633
7634void ggml_compute_forward_upscale(
7635 const ggml_compute_params * params,
7636 ggml_tensor * dst) {
7637
7638 const ggml_tensor * src0 = dst->src[0];
7639
7640 switch (src0->type) {
7641 case GGML_TYPE_F32:
7642 {
7643 ggml_compute_forward_upscale_f32(params, dst);
7644 } break;
7645 default:
7646 {
7647 GGML_ABORT("fatal error");
7648 }
7649 }
7650}
7651
7652
7653// ggml_compute_forward_pad
7654
7655static void ggml_compute_forward_pad_f32(
7656 const ggml_compute_params * params,
7657 ggml_tensor * dst) {
7658
7659 const ggml_tensor * src0 = dst->src[0];
7660
7661 GGML_ASSERT(src0->nb[0] == sizeof(float));
7662 GGML_ASSERT( dst->nb[0] == sizeof(float));
7663
7664 const int ith = params->ith;
7665 const int nth = params->nth;
7666
7667 GGML_TENSOR_UNARY_OP_LOCALS
7668
7669 float * dst_ptr = (float *) dst->data;
7670 const int32_t lp0 = ggml_get_op_params_i32(tensor: dst, i: 0);
7671 const int32_t rp0 = ggml_get_op_params_i32(tensor: dst, i: 1);
7672 const int32_t lp1 = ggml_get_op_params_i32(tensor: dst, i: 2);
7673 const int32_t rp1 = ggml_get_op_params_i32(tensor: dst, i: 3);
7674 const int32_t lp2 = ggml_get_op_params_i32(tensor: dst, i: 4);
7675 const int32_t rp2 = ggml_get_op_params_i32(tensor: dst, i: 5);
7676 const int32_t lp3 = ggml_get_op_params_i32(tensor: dst, i: 6);
7677 const int32_t rp3 = ggml_get_op_params_i32(tensor: dst, i: 7);
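
    // lpN/rpN are the left/right padding sizes along dimension N; for the shapes
    // produced by ggml_pad_ext (assumed here) ne[N] == lpN + src0->ne[N] + rpN.
    // Elements inside the un-padded window are copied from src0, everything else
    // is zeroed. The dst_idx math below assumes a contiguous dst.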
7678
7679
7680 // TODO: optimize
7681
7682 for (int64_t i2 = 0; i2 < ne2; ++i2) {
7683 for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
7684 for (int64_t i0 = 0; i0 < ne0; ++i0) {
7685 for (int64_t i3 = 0; i3 < ne3; ++i3) {
7686 const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
7687 if ((i0 >= lp0 && i0 < ne0 - rp0) \
7688 && (i1 >= lp1 && i1 < ne1 - rp1) \
7689 && (i2 >= lp2 && i2 < ne2 - rp2) \
7690 && (i3 >= lp3 && i3 < ne3 - rp3)) {
7691 const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
7692 const float * src_ptr = (const float *)((char *) src0->data + src_idx);
7693 dst_ptr[dst_idx] = *src_ptr;
7694 } else {
7695 dst_ptr[dst_idx] = 0;
7696 }
7697 }
7698 }
7699 }
7700 }
7701}
7702
7703void ggml_compute_forward_pad(
7704 const ggml_compute_params * params,
7705 ggml_tensor * dst) {
7706
7707 const ggml_tensor * src0 = dst->src[0];
7708
7709 switch (src0->type) {
7710 case GGML_TYPE_F32:
7711 {
7712 ggml_compute_forward_pad_f32(params, dst);
7713 } break;
7714 default:
7715 {
7716 GGML_ABORT("fatal error");
7717 }
7718 }
7719}
7720
7721// ggml_compute_forward_pad_reflect_1d
7722
7723void ggml_compute_forward_pad_reflect_1d(
7724 const ggml_compute_params * params,
7725 ggml_tensor * dst) {
7726
7727 const ggml_tensor * src0 = dst->src[0];
7728
7729 GGML_ASSERT(src0->type == GGML_TYPE_F32);
7730 GGML_ASSERT( dst->type == GGML_TYPE_F32);
7731
7732 const int ith = params->ith;
7733 const int nth = params->nth;
7734
7735 const int32_t * opts = (const int32_t *) dst->op_params;
7736 const int p0 = opts[0];
7737 const int p1 = opts[1];
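
    // reflection padding along dim 0 without repeating the edge sample, e.g.
    // ne00 = 5, p0 = 2, p1 = 1: [a b c d e] -> [c b a b c d e d]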
7738
7739 GGML_TENSOR_UNARY_OP_LOCALS
7740
7741 for (int64_t i3 = 0; i3 < ne3; i3++) {
7742 for (int64_t i2 = 0; i2 < ne2; i2++) {
7743 for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
7744 float * left = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + p0*nb0);
7745 float * right = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (ne0-p1-1)*nb0);
7746
7747 ggml_vec_cpy_f32(n: ne00, y: left, x: (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
7748
7749 for (int i0 = 1; i0 <= p0; i0++) { left[-i0] = left[i0]; }
7750 for (int i0 = 1; i0 <= p1; i0++) { right[i0] = right[-i0]; }
7751 }
7752 }
7753 }
7754}
7755
7756// ggml_compute_forward_roll
7757
7758static int64_t ggml_wrap_index(int64_t i, int64_t ne) {
7759 if (i < 0) {
7760 return i + ne;
7761 } else if (i >= ne) {
7762 return i - ne;
7763 }
7764 return i;
7765}
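
// maps an index that ran off either end of a dimension back into [0, ne),
// assuming at most a single wrap, e.g. ggml_wrap_index(-2, 5) == 3 and
// ggml_wrap_index(6, 5) == 1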
7766
7767static void ggml_compute_forward_roll_f32(
7768 const ggml_compute_params * params,
7769 ggml_tensor * dst) {
7770
7771 const ggml_tensor * src0 = dst->src[0];
7772 const float * src_data = (const float *) src0->data;
7773 float * dst_data = (float *) dst->data;
7774
7775 GGML_TENSOR_UNARY_OP_LOCALS
7776
7777 const int s0 = ggml_get_op_params_i32(tensor: dst, i: 0);
7778 const int s1 = ggml_get_op_params_i32(tensor: dst, i: 1);
7779 const int s2 = ggml_get_op_params_i32(tensor: dst, i: 2);
7780 const int s3 = ggml_get_op_params_i32(tensor: dst, i: 3);
7781
7782 const int64_t total = ne1 * ne2 * ne3;
7783 const int64_t per_thread = (total + params->nth) / params->nth;
7784 const int64_t start = params->ith * per_thread;
7785 const int64_t end = std::min(a: start + per_thread, b: total);
7786
7787 for (int64_t i = start; i < end; ++i) {
7788 const int64_t i1 = i % ne1;
7789 const int64_t i2 = (i / ne1) % ne2;
7790 const int64_t i3 = i / (ne2 * ne1);
7791 float * dst_row = dst_data + (i3*nb3 + i2*nb2 + i1*nb1) / sizeof(float);
7792
7793 const int64_t i01 = ggml_wrap_index(i: i1 - s1, ne: ne01);
7794 const int64_t i02 = ggml_wrap_index(i: i2 - s2, ne: ne02);
7795 const int64_t i03 = ggml_wrap_index(i: i3 - s3, ne: ne03);
7796 const float * src_row = src_data + (i03*nb03 + i02*nb02 + i01*nb01) / sizeof(float);
7797
7798 const int64_t s = ggml_wrap_index(i: -s0, ne: ne00);
7799 const int64_t n = ne00 - s;
7800 ggml_vec_cpy_f32(n, y: dst_row, x: src_row + s);
7801 ggml_vec_cpy_f32(n: s, y: dst_row + n, x: src_row);
7802 }
7803}
7804
7805void ggml_compute_forward_roll(
7806 const ggml_compute_params * params,
7807 ggml_tensor * dst) {
7808
7809 const ggml_tensor * src0 = dst->src[0];
7810
7811 switch (src0->type) {
7812 case GGML_TYPE_F32:
7813 {
7814 ggml_compute_forward_roll_f32(params, dst);
7815 } break;
7816 default:
7817 {
7818 GGML_ABORT("fatal error");
7819 }
7820 }
7821}
7822
7823// ggml_compute_forward_arange
7824
7825static void ggml_compute_forward_arange_f32(
7826 const ggml_compute_params * params,
7827 ggml_tensor * dst) {
7828
7829 GGML_ASSERT(dst->nb[0] == sizeof(float));
7830
7831 const int ith = params->ith;
7832 const int nth = params->nth;
7833
7834 const float start = ggml_get_op_params_f32(tensor: dst, i: 0);
7835 const float stop = ggml_get_op_params_f32(tensor: dst, i: 1);
7836 const float step = ggml_get_op_params_f32(tensor: dst, i: 2);
7837
7838 const int64_t steps = (int64_t) ceilf(x: (stop - start) / step);
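    // e.g. start = 1, stop = 10, step = 3 -> steps = ceil(9/3) = 3 -> {1, 4, 7}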
7839
7840 GGML_ASSERT(ggml_nelements(dst) == steps);
7841
7842 for (int64_t i = ith; i < steps; i+= nth) {
7843 float value = start + step * i;
7844 ((float *)dst->data)[i] = value;
7845 }
7846}
7847
7848void ggml_compute_forward_arange(
7849 const ggml_compute_params * params,
7850 ggml_tensor * dst) {
7851 switch (dst->type) {
7852 case GGML_TYPE_F32:
7853 {
7854 ggml_compute_forward_arange_f32(params, dst);
7855 } break;
7856 default:
7857 {
7858 GGML_ABORT("fatal error");
7859 }
7860 }
7861}
7862
7863static void ggml_compute_forward_timestep_embedding_f32(
7864 const ggml_compute_params * params,
7865 ggml_tensor * dst) {
7866
7867 const ggml_tensor * src0 = dst->src[0];
7868
7869 GGML_ASSERT(src0->nb[0] == sizeof(float));
7870
7871 const int ith = params->ith;
7872 const int nth = params->nth;
7873
7874 GGML_TENSOR_UNARY_OP_LOCALS
7875
7876 const int dim = ggml_get_op_params_i32(tensor: dst, i: 0);
7877 const int max_period = ggml_get_op_params_i32(tensor: dst, i: 1);
7878
7879 int half = dim / 2;
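
    // standard sinusoidal timestep embedding: for j in [0, half) the frequency is
    // f_j = max_period^(-j/half), and the output row is
    //   embed[j]        = cos(t * f_j)
    //   embed[j + half] = sin(t * f_j)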
7880
7881 for (int64_t i = 0; i < ne00; i++) {
7882 float * embed_data = (float *)((char *) dst->data + i*nb1);
7883 for (int64_t j = ith; j < half; j += nth) {
7884 float timestep = ((float *)src0->data)[i];
7885 float freq = (float)expf(x: -logf(x: max_period) * j / half);
7886 float arg = timestep * freq;
7887 embed_data[j] = cosf(x: arg);
7888 embed_data[j + half] = sinf(x: arg);
7889 }
7890 if (dim % 2 != 0 && ith == 0) {
7891 embed_data[2 * half] = 0.f;
7892 }
7893 }
7894}
7895
7896void ggml_compute_forward_timestep_embedding(
7897 const ggml_compute_params * params,
7898 ggml_tensor * dst) {
7899
7900 const ggml_tensor * src0 = dst->src[0];
7901
7902 switch (src0->type) {
7903 case GGML_TYPE_F32:
7904 {
7905 ggml_compute_forward_timestep_embedding_f32(params, dst);
7906 } break;
7907 default:
7908 {
7909 GGML_ABORT("fatal error");
7910 }
7911 }
7912}
7913
7914// ggml_compute_forward_argsort
7915
7916static void ggml_compute_forward_argsort_f32(
7917 const ggml_compute_params * params,
7918 ggml_tensor * dst) {
7919
7920 const ggml_tensor * src0 = dst->src[0];
7921
7922 GGML_TENSOR_UNARY_OP_LOCALS
7923
7924 GGML_ASSERT(nb0 == sizeof(float));
7925
7926 const int ith = params->ith;
7927 const int nth = params->nth;
7928
7929 const int64_t nr = ggml_nrows(tensor: src0);
7930
7931 ggml_sort_order order = (ggml_sort_order) ggml_get_op_params_i32(tensor: dst, i: 0);
7932
7933 for (int64_t i = ith; i < nr; i += nth) {
7934 int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
7935 const float * src_data = (float *)((char *) src0->data + i*nb01);
7936
7937 for (int64_t j = 0; j < ne0; j++) {
7938 dst_data[j] = j;
7939 }
7940
7941         // plain C has no generic sort, so the indices are ordered with a simple O(n^2) exchange sort
7942 for (int64_t j = 0; j < ne0; j++) {
7943 for (int64_t k = j + 1; k < ne0; k++) {
7944 if ((order == GGML_SORT_ORDER_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
7945 (order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
7946 int32_t tmp = dst_data[j];
7947 dst_data[j] = dst_data[k];
7948 dst_data[k] = tmp;
7949 }
7950 }
7951 }
7952 }
7953}
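
// a possible alternative for long rows (illustrative sketch only, not what the
// function above does): sort the index array with std::sort and a comparator
// that reads the source row, which is O(n log n) instead of O(n^2). <algorithm>
// is already included at the top of this file.
//
//   std::sort(dst_data, dst_data + ne0, [src_data, order](int32_t a, int32_t b) {
//       return order == GGML_SORT_ORDER_ASC ? src_data[a] < src_data[b]
//                                           : src_data[a] > src_data[b];
//   });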
7954
7955void ggml_compute_forward_argsort(
7956 const ggml_compute_params * params,
7957 ggml_tensor * dst) {
7958
7959 const ggml_tensor * src0 = dst->src[0];
7960
7961 switch (src0->type) {
7962 case GGML_TYPE_F32:
7963 {
7964 ggml_compute_forward_argsort_f32(params, dst);
7965 } break;
7966 default:
7967 {
7968 GGML_ABORT("fatal error");
7969 }
7970 }
7971}
7972
7973// ggml_compute_forward_flash_attn_ext
7974
7975static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
7976 const ggml_compute_params * params,
7977 ggml_tensor * dst,
7978 int ir0, int ir1) {
7979 const ggml_tensor * q = dst->src[0];
7980 const ggml_tensor * k = dst->src[1];
7981 const ggml_tensor * v = dst->src[2];
7982 const ggml_tensor * mask = dst->src[3];
7983 const ggml_tensor * sinks = dst->src[4];
7984
7985 GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
7986 GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
7987 GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
7988 GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
7989 GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
7990 GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
7991 GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
7992 GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
7993
7994 const int64_t DK = nek0;
7995 const int64_t DV = nev0;
7996 const int64_t N = neq1;
7997
7998 GGML_ASSERT(ne0 == DV);
7999 GGML_ASSERT(ne2 == N);
8000
8001 // input tensor rows must be contiguous
8002 GGML_ASSERT(nbq0 == ggml_type_size(q->type));
8003 GGML_ASSERT(nbk0 == ggml_type_size(k->type));
8004 GGML_ASSERT(nbv0 == ggml_type_size(v->type));
8005
8006 GGML_ASSERT(neq0 == DK);
8007 GGML_ASSERT(nek0 == DK);
8008 GGML_ASSERT(nev0 == DV);
8009
8010 GGML_ASSERT(neq1 == N);
8011
8012 // dst cannot be transposed or permuted
8013 GGML_ASSERT(nb0 == sizeof(float));
8014 GGML_ASSERT(nb0 <= nb1);
8015 GGML_ASSERT(nb1 <= nb2);
8016 GGML_ASSERT(nb2 <= nb3);
8017
8018 // broadcast factors
8019 const int64_t rk2 = neq2/nek2;
8020 const int64_t rk3 = neq3/nek3;
8021
8022 const int64_t rv2 = neq2/nev2;
8023 const int64_t rv3 = neq3/nev3;
8024
8025 // parallelize by q rows using ggml_vec_dot_f32
8026
8027 float scale = 1.0f;
8028 float max_bias = 0.0f;
8029 float logit_softcap = 0.0f;
8030
8031 memcpy(dest: &scale, src: (float *) dst->op_params + 0, n: sizeof(float));
8032 memcpy(dest: &max_bias, src: (float *) dst->op_params + 1, n: sizeof(float));
8033 memcpy(dest: &logit_softcap, src: (float *) dst->op_params + 2, n: sizeof(float));
8034
8035 if (logit_softcap != 0) {
8036 scale /= logit_softcap;
8037 }
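
    // with softcapping enabled the effective logit is
    //   logit_softcap * tanh((q.k) * scale_orig / logit_softcap)
    // folding the division into `scale` here lets the inner loop simply apply
    // the tanh after the usual s *= scale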
8038
8039 const uint32_t n_head = neq2;
8040 const uint32_t n_head_log2 = 1u << (uint32_t) floor(x: log2(x: n_head));
8041
8042 const float m0 = powf(x: 2.0f, y: -(max_bias ) / n_head_log2);
8043 const float m1 = powf(x: 2.0f, y: -(max_bias / 2.0f) / n_head_log2);
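
    // ALiBi slopes: head h (0-based) gets
    //   slope = m0^(h+1) = 2^(-max_bias*(h+1)/n_head_log2)    if h < n_head_log2
    //   slope = m1^(2*(h - n_head_log2) + 1)                  otherwise
    // (see the per-row computation of `slope` below)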
8044
8045 ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(type: k->type)->vec_dot_type;
8046 ggml_from_float_t const q_to_vec_dot = ggml_get_type_traits_cpu(type: k_vec_dot_type)->from_float;
8047 ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu(type: k->type)->vec_dot;
8048 ggml_to_float_t const v_to_float = ggml_get_type_traits(type: v->type)->to_float;
8049
8050 GGML_ASSERT(( q_to_vec_dot) && "fattn: unsupported K-type");
8051 GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float ) && "fattn: unsupported V-type");
8052
8053 int ith = params->ith;
8054
8055 // loop over n_batch and n_head
8056 for (int ir = ir0; ir < ir1; ++ir) {
8057 // q indices
8058 const int iq3 = ir/(neq2*neq1);
8059 const int iq2 = (ir - iq3*neq2*neq1)/neq1;
8060 const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
8061
8062 const uint32_t h = iq2; // head index
8063 const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(x: m0, y: h + 1) : powf(x: m1, y: 2*(h - n_head_log2) + 1) : 1.0f;
8064
8065 float S = 0.0f; // sum
8066 float M = -INFINITY; // maximum KQ value
8067
8068 float * VKQ32 = (float *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
8069 float * V32 = (VKQ32 + 1*DV); // (temporary) FP32 V buffer
8070 ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*DV); // (temporary) FP16 VKQ accumulator
8071 ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*DV); // (temporary) buffer for Q converted to quantized/FP16
8072
8073 if (v->type == GGML_TYPE_F16) {
8074 memset(s: VKQ16, c: 0, n: DV*sizeof(ggml_fp16_t));
8075 } else {
8076 memset(s: VKQ32, c: 0, n: DV*sizeof(float));
8077 }
8078
8079 const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]) : NULL;
8080
8081 // k indices
8082 const int ik3 = iq3 / rk3;
8083 const int ik2 = iq2 / rk2;
8084
8085 // v indices
8086 const int iv3 = iq3 / rv3;
8087 const int iv2 = iq2 / rv2;
8088
8089 const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
8090 q_to_vec_dot(pq, Q_q, DK);
8091
8092 // online softmax / attention
8093 // loop over n_kv and n_head_kv
8094 // ref: https://arxiv.org/pdf/2112.05682.pdf
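        // M tracks the running maximum logit and S the running sum of exp(s - M);
        // whenever a new maximum arrives, the accumulated V*KQ and S are rescaled
        // by ms = exp(M_old - M_new), so the final result equals softmax(KQ)*V
        // without ever materializing the full KQ row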
8095 for (int64_t ic = 0; ic < nek1; ++ic) {
8096 const float mv = mp ? slope*GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f;
8097 if (mv == -INFINITY) {
8098 continue;
8099 }
8100
8101 float s; // KQ value
8102
8103 const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
8104 kq_vec_dot(DK, &s, 0, k_data, 0, Q_q, 0, 1);
8105
8106 s = s*scale; // scale KQ value
8107
8108 if (logit_softcap != 0.0f) {
8109 s = logit_softcap*tanhf(x: s);
8110 }
8111
8112 s += mv; // apply mask
8113
8114 const float Mold = M;
8115
8116 float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
8117 float vs = 1.0f; // post-softmax KQ value, expf(s - M)
8118
8119 const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
8120
8121 if (v->type == GGML_TYPE_F16) {
8122 if (s > M) {
8123 // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
8124 M = s;
8125 ms = expf(x: Mold - M);
8126
8127 // V = V*expf(Mold - M)
8128 ggml_vec_scale_f16(n: DV, y: VKQ16, v: ms);
8129 } else {
8130 // no new maximum, ms == 1.0f, vs != 1.0f
8131 vs = expf(x: s - M);
8132 }
8133
8134 // V += v*expf(s - M)
8135 ggml_vec_mad_f16(n: DV, y: VKQ16, x: (const ggml_fp16_t *) v_data, v: vs);
8136 } else {
8137 if (s > M) {
8138 // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
8139 M = s;
8140 ms = expf(x: Mold - M);
8141
8142 // V = V*expf(Mold - M)
8143 ggml_vec_scale_f32(n: DV, y: VKQ32, v: ms);
8144 } else {
8145 // no new maximum, ms == 1.0f, vs != 1.0f
8146 vs = expf(x: s - M);
8147 }
8148
8149 // V += v*expf(s - M)
8150 if (v_to_float) {
8151 v_to_float(v_data, V32, DV);
8152 ggml_vec_mad_f32(n: DV, y: VKQ32, x: V32, v: vs);
8153 } else {
8154 // V is F32
8155 ggml_vec_mad_f32(n: DV, y: VKQ32, x: (const float *) v_data, v: vs);
8156 }
8157 }
8158
8159 S = S*ms + vs; // scale and increment sum with partial sum
8160 }
8161
8162 if (v->type == GGML_TYPE_F16) {
8163 for (int64_t d = 0; d < DV; ++d) {
8164 VKQ32[d] = GGML_CPU_FP16_TO_FP32(VKQ16[d]);
8165 }
8166 }
8167
8168 // sinks
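        // an attention sink contributes an extra logit to the softmax
        // normalization (rescaling VKQ and adding to S) but no value vector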
8169 if (sinks) {
8170 const float s = ((float *)((char *) sinks->data))[h];
8171
8172 float ms = 1.0f;
8173 float vs = 1.0f;
8174
8175 if (s > M) {
8176 ms = expf(x: M - s);
8177 ggml_vec_scale_f32(n: DV, y: VKQ32, v: ms);
8178 } else {
8179 vs = expf(x: s - M);
8180 }
8181
8182 S = S*ms + vs;
8183 }
8184
8185 // V /= S
8186 const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
8187 ggml_vec_scale_f32(n: DV, y: VKQ32, v: S_inv);
8188
8189 // dst indices
8190 const int i1 = iq1;
8191 const int i2 = iq2;
8192 const int i3 = iq3;
8193
8194 // original
8195 //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
8196
8197 // permute(0, 2, 1, 3)
8198 memcpy(dest: (char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, src: VKQ32, n: nb1);
8199 }
8200}
8201
8202static void ggml_compute_forward_flash_attn_ext_f16(
8203 const ggml_compute_params * params,
8204 ggml_tensor * dst) {
8205
8206 const ggml_tensor * q = dst->src[0];
8207 const ggml_tensor * k = dst->src[1];
8208 const ggml_tensor * v = dst->src[2];
8209
8210 GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
8211 GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
8212 GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
8213 GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
8214 GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
8215 GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
8216 GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
8217 GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
8218
8219 const int64_t DK = nek0;
8220 const int64_t DV = nev0;
8221 const int64_t N = neq1;
8222
8223 GGML_ASSERT(ne0 == DV);
8224 GGML_ASSERT(ne2 == N);
8225
8226 // input tensor rows must be contiguous
8227 GGML_ASSERT(nbq0 == ggml_type_size(q->type));
8228 GGML_ASSERT(nbk0 == ggml_type_size(k->type));
8229 GGML_ASSERT(nbv0 == ggml_type_size(v->type));
8230
8231 GGML_ASSERT(neq0 == DK);
8232 GGML_ASSERT(nek0 == DK);
8233 GGML_ASSERT(nev0 == DV);
8234
8235 GGML_ASSERT(neq1 == N);
8236
8237 // dst cannot be transposed or permuted
8238 GGML_ASSERT(nb0 == sizeof(float));
8239 GGML_ASSERT(nb0 <= nb1);
8240 GGML_ASSERT(nb1 <= nb2);
8241 GGML_ASSERT(nb2 <= nb3);
8242
8243 // parallelize by q rows using ggml_vec_dot_f32
8244
8245 // total rows in q
8246 const int64_t nr = neq1*neq2*neq3;
8247
8248 // rows per thread
8249 const int ith = params->ith;
8250 const int nth = params->nth;
8251
8252 // disable for NUMA
8253 const bool disable_chunking = ggml_is_numa();
8254
8255 // 4x chunks per thread
8256 int nth_scaled = nth * 4;
8257 int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
8258 int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
8259
8260 if (nth == 1 || nchunk < nth || disable_chunking) {
8261 nchunk = nth;
8262 }
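
    // e.g. nr = 8192 rows on nth = 8 threads: nth_scaled = 32, chunk_size = 256,
    // nchunk = 32; each thread handles 256-row chunks and picks up the next
    // unprocessed chunk index atomically below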
8263
8264 if (ith == 0) {
8265         // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
8266 ggml_threadpool_chunk_set(tp: params->threadpool, value: nth);
8267 }
8268
8269 ggml_barrier(tp: params->threadpool);
8270
8271 // The number of elements in each chunk
8272 const int64_t dr = (nr + nchunk - 1) / nchunk;
8273
8274 // The first chunk comes from our thread_id, the rest will get auto-assigned.
8275 int current_chunk = ith;
8276
8277 while (current_chunk < nchunk) {
8278 const int64_t ir0 = dr * current_chunk;
8279 const int64_t ir1 = MIN(ir0 + dr, nr);
8280
8281 ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
8282
8283 current_chunk = ggml_threadpool_chunk_add(tp: params->threadpool, value: 1);
8284 }
8285}
8286
8287void ggml_compute_forward_flash_attn_ext(
8288 const ggml_compute_params * params,
8289 ggml_tensor * dst) {
8290 switch (dst->op_params[3]) {
8291 case GGML_PREC_DEFAULT:
8292 case GGML_PREC_F32:
8293 {
8294 // uses F32 accumulators
8295 ggml_compute_forward_flash_attn_ext_f16(params, dst);
8296 } break;
8297 default:
8298 {
8299 GGML_ABORT("fatal error");
8300 }
8301 }
8302}
8303
8304// ggml_compute_forward_flash_attn_back
8305
8306static void ggml_compute_forward_flash_attn_back_f32(
8307 const ggml_compute_params * params,
8308 const bool masked,
8309 ggml_tensor * dst) {
8310
8311 const ggml_tensor * q = dst->src[0];
8312 const ggml_tensor * k = dst->src[1];
8313 const ggml_tensor * v = dst->src[2];
8314 const ggml_tensor * d = dst->src[3];
8315
8316 GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
8317 GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
8318 GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
8319 GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
8320 GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
8321 GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
8322 GGML_TENSOR_LOCALS(int64_t, ned, d, ne)
8323 GGML_TENSOR_LOCALS(size_t, nbd, d, nb)
8324 GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
8325 GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
8326
8327 const int ith = params->ith;
8328 const int nth = params->nth;
8329
8330 const int64_t D = neq0;
8331 const int64_t N = neq1;
8332 const int64_t P = nek1 - N;
8333 const int64_t M = P + N;
8334
8335 const int Mup = ggml_up(n: M, GGML_SOFT_MAX_UNROLL);
8336 const int mxDM = MAX(D, Mup);
8337
8338 // GGML_ASSERT(ne0 == D);
8339 // GGML_ASSERT(ne1 == N);
8340 GGML_ASSERT(P >= 0);
8341
8342 GGML_ASSERT(nbq0 == sizeof(float));
8343 GGML_ASSERT(nbk0 == sizeof(float));
8344 GGML_ASSERT(nbv0 == sizeof(float));
8345
8346 GGML_ASSERT(neq0 == D);
8347 GGML_ASSERT(nek0 == D);
8348 GGML_ASSERT(nev1 == D);
8349 GGML_ASSERT(ned0 == D);
8350
8351 GGML_ASSERT(neq1 == N);
8352 GGML_ASSERT(nek1 == N + P);
8353 GGML_ASSERT(nev1 == D);
8354 GGML_ASSERT(ned1 == N);
8355
8356 // dst cannot be transposed or permuted
8357 GGML_ASSERT(nb0 == sizeof(float));
8358 GGML_ASSERT(nb0 <= nb1);
8359 GGML_ASSERT(nb1 <= nb2);
8360 GGML_ASSERT(nb2 <= nb3);
8361
8362 if (ith == 0) {
8363 memset(s: dst->data, c: 0, n: nb0*ne0*ne1*ne2*ne3);
8364 }
8365 ggml_barrier(tp: params->threadpool);
8366
8367 const int64_t elem_q = ggml_nelements(tensor: q);
8368 const int64_t elem_k = ggml_nelements(tensor: k);
8369
8370 ggml_type result_type = dst->type;
8371 GGML_ASSERT(ggml_blck_size(result_type) == 1);
8372 const size_t tsize = ggml_type_size(type: result_type);
8373
8374 const size_t offs_q = 0;
8375 const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
8376 const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
8377
8378 void * grad_q = (char *) dst->data;
8379 void * grad_k = (char *) dst->data + offs_k;
8380 void * grad_v = (char *) dst->data + offs_v;
8381
8382 const size_t nbgq1 = nb0*neq0;
8383 const size_t nbgq2 = nb0*neq0*neq1;
8384 const size_t nbgq3 = nb0*neq0*neq1*neq2;
8385
8386 const size_t nbgk1 = nb0*nek0;
8387 const size_t nbgk2 = nb0*nek0*nek1;
8388 const size_t nbgk3 = nb0*nek0*nek1*neq2;
8389
8390 const size_t nbgv1 = nb0*nev0;
8391 const size_t nbgv2 = nb0*nev0*nev1;
8392 const size_t nbgv3 = nb0*nev0*nev1*neq2;
8393
8394 // parallelize by k rows using ggml_vec_dot_f32
8395
8396 // total rows in k
8397 const int nr = nek2*nek3;
8398
8399 // rows per thread
8400 const int dr = (nr + nth - 1)/nth;
8401
8402 // row range for this thread
8403 const int ir0 = dr*ith;
8404 const int ir1 = MIN(ir0 + dr, nr);
8405
8406 const float scale = 1.0f/sqrtf(x: D);
8407
8408 //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
8409
8410 // how often k2 (and v2) is repeated in q2
8411 int nrep = neq2/nek2;
8412
8413 for (int ir = ir0; ir < ir1; ++ir) {
8414 // q indices
8415 const int ik3 = ir/(nek2);
8416 const int ik2 = ir - ik3*nek2;
8417
8418 const int iq3 = ik3;
8419 const int id3 = ik3;
8420 const int iv3 = ik3;
8421 const int iv2 = ik2;
8422
8423 for (int irep = 0; irep < nrep; ++irep) {
8424 const int iq2 = ik2 + irep*nek2;
8425 const int id2 = iq2;
8426
8427 // (ik2 + irep*nek2) % nek2 == ik2
8428 for (int iq1 = 0; iq1 < neq1; ++iq1) {
8429 const int id1 = iq1;
8430
8431 // not sure about CACHE_LINE_SIZE_F32..
8432 // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset?
8433 float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
8434 float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
8435
8436 for (int i = M; i < Mup; ++i) {
8437 S[i] = -INFINITY;
8438 }
8439
8440 const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
8441 for (int64_t ic = 0; ic < masked_begin; ++ic) {
8442 // k indices
8443 const int ik1 = ic;
8444
8445 // S indices
8446 const int i1 = ik1;
8447
8448 ggml_vec_dot_f32(n: neq0,
8449 s: S + i1, bs: 0,
8450 x: (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), bx: 0,
8451 y: (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), by: 0, nrc: 1);
8452 }
8453
8454 // scale
8455 ggml_vec_scale_f32(n: masked_begin, y: S, v: scale);
8456
8457 for (int64_t i = masked_begin; i < M; i++) {
8458 S[i] = -INFINITY;
8459 }
8460
8461 // softmax
8462 // exclude known -INF S[..] values from max and loop
8463                 // don't forget to set their SM values to zero
8464 {
8465 float max = -INFINITY;
8466 ggml_vec_max_f32(n: masked_begin, s: &max, x: S);
8467
8468 ggml_float sum = 0.0;
8469 {
8470#ifdef GGML_SOFT_MAX_ACCELERATE
8471 max = -max;
8472 vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
8473 vvexpf(SM, SM, &Mup);
8474 ggml_vec_sum_f32(Mup, &sum, SM);
8475#else
8476 sum = ggml_vec_soft_max_f32(n: Mup, y: SM, x: S, max);
8477#endif
8478 }
8479
8480 assert(sum > 0.0);
8481
8482 sum = 1.0/sum;
8483 ggml_vec_scale_f32(n: masked_begin, y: SM, v: sum);
8484
8485 }
8486
8487 // step-by-step explanation
8488 {
8489 // forward-process shape grads from backward process
8490 // parallel_for ik2,ik3:
8491 // for irep:
8492 // iq2 = ik2 + irep*nek2
8493 // k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,ik2,ik3] += grad[kcur]
8494 // q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur]
8495 // v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iv2,iv3] += grad[vcur]
8496 // for iq1:
8497 // kcur = k[:D,:M,ik2,ik3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur
8498 // qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur
8499 // vcur = v[:M,:D,iv2,iv3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4
8500 // S0 = -Inf [D,1,1,1]
8501 // ~S1[i] = dot(kcur[:D,i], qcur)
8502 // S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale
8503 // S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P)
8504 // S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
8505 // S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur
8506 // ~S5[i] = dot(vcur[:,i], S4)
8507 // S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,id1,id2,id3]
8508 // ~dst[i,iq1,iq2,iq3] = S5[i] ^
8509 // dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,id1,id2,id3]
8510 // dst backward-/ grad[dst] = d
8511 //
8512 // output gradients with their dependencies:
8513 //
8514 // grad[kcur] = grad[S1].T @ qcur
8515 // grad[S1] = diag_mask_zero(grad[S3], P) * scale
8516 // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
8517 // grad[S4] = grad[S5] @ vcur
8518 // grad[S4] = d[:D,id1,id2,id3] @ vcur
8519 // grad[qcur] = grad[S1] @ kcur
8520 // grad[vcur] = grad[S5].T @ S4
8521 // grad[vcur] = d[:D,id1,id2,id3].T @ S4
8522 //
8523 // in post-order:
8524 //
8525 // S1 = qcur @ kcur.T
8526 // S2 = S1 * scale
8527 // S3 = diag_mask_inf(S2, P)
8528 // S4 = softmax(S3)
8529 // grad[S4] = d[:D,id1,id2,id3] @ vcur
8530 // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
8531 // grad[S1] = diag_mask_zero(grad[S3], P) * scale
8532 // grad[qcur] = grad[S1] @ kcur
8533 // grad[kcur] = grad[S1].T @ qcur
8534 // grad[vcur] = d[:D,id1,id2,id3].T @ S4
8535 //
8536 // using less variables (SM=S4):
8537 //
8538 // S = diag_mask_inf(qcur @ kcur.T * scale, P)
8539 // SM = softmax(S)
8540 // S = d[:D,iq1,iq2,iq3] @ vcur
8541 // dot_SM_gradSM = dot(SM, S)
8542 // S = SM * (S - dot(SM, S))
8543 // S = diag_mask_zero(S, P) * scale
8544 //
8545 // grad[q][:D,iq1,iq2,iq3] += S @ kcur
8546 // grad[k][:D,:M,ik2,ik3] += S.T @ qcur
8547 // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM
8548 }
8549
8550 // S = gradSM = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
8551 // S = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
8552 // for ic:
8553 // S[:M] += vcur[:M,ic,iv2,iv3] * d[ic,id1,id2,id3]
8554 // exclude known future zero S[..] values from operation
8555 ggml_vec_set_f32(n: masked_begin, x: S, v: 0);
8556 for (int64_t ic = 0; ic < D; ++ic) {
8557 ggml_vec_mad_f32(n: masked_begin,
8558 y: S,
8559 x: (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
8560 v: *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
8561 }
8562
8563 // S = SM * (S - dot(SM, S))
8564 float dot_SM_gradSM = 0;
8565 ggml_vec_dot_f32 (n: masked_begin, s: &dot_SM_gradSM, bs: 0, x: SM, bx: 0, y: S, by: 0, nrc: 1);
8566 ggml_vec_acc1_f32(n: M, y: S, v: -dot_SM_gradSM);
8567 ggml_vec_mul_f32 (n: masked_begin, z: S, x: S, y: SM);
8568
8569 // S = diag_mask_zero(S, P) * scale
8570 // already done by above ggml_vec_set_f32
8571
8572 // exclude known zero S[..] values from operation
8573 ggml_vec_scale_f32(n: masked_begin, y: S, v: scale);
8574
8575 // S shape [M,1]
8576 // SM shape [M,1]
8577 // kcur shape [D,M]
8578 // qcur shape [D,1]
8579 // vcur shape [M,D]
8580
8581 // grad[q][:D,iq1,iq2,iq3] += S @ kcur
8582 // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
8583 // for ic:
8584 // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic,ik2,ik3]
8585 // exclude known zero S[..] values from loop
8586 for (int64_t ic = 0; ic < masked_begin; ++ic) {
8587 ggml_vec_mad_f32(n: D,
8588 y: (float *) ((char *) grad_q + (iq1*nbgq1 + iq2*nbgq2 + iq3*nbgq3)),
8589 x: (float *) ((char *) k->data + (ic*nbk1 + ik2*nbk2 + ik3*nbk3)),
8590 v: S[ic]);
8591 }
8592
8593 // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
8594 // for ic:
8595 // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
8596 // grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0]
8597 // exclude known zero S[..] values from loop
8598 for (int64_t ic = 0; ic < masked_begin; ++ic) {
8599 ggml_vec_mad_f32(n: D,
8600 y: (float *) ((char *) grad_k + (ic*nbgk1 + ik2*nbgk2 + ik3*nbgk3)),
8601 x: (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)),
8602 v: S[ic]);
8603 }
8604
8605 // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM
8606 // for ic:
8607 // grad[v][:M,ic,iv2,iv3] += d[:D,id1,id2,id3].T[0,ic] * SM[:M]
8608 // grad[v][:M,ic,iv2,iv3] += d[ic,id1,id2,id3] * SM[:M]
8609 // exclude known zero SM[..] values from mad
8610 for (int64_t ic = 0; ic < D; ++ic) {
8611 ggml_vec_mad_f32(n: masked_begin,
8612 y: (float *) ((char *) grad_v + ( ic*nbgv1 + iv2*nbgv2 + iv3*nbgv3)),
8613 x: SM,
8614 v: *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
8615 }
8616 }
8617 }
8618 }
8619}
8620
8621void ggml_compute_forward_flash_attn_back(
8622 const ggml_compute_params * params,
8623 const bool masked,
8624 ggml_tensor * dst) {
8625
8626 const ggml_tensor * q = dst->src[0];
8627
8628 switch (q->type) {
8629 case GGML_TYPE_F32:
8630 {
8631 ggml_compute_forward_flash_attn_back_f32(params, masked, dst);
8632 } break;
8633 default:
8634 {
8635 GGML_ABORT("fatal error");
8636 }
8637 }
8638}
8639
8640// ggml_compute_forward_ssm_conv
8641
8642static void ggml_compute_forward_ssm_conv_f32(
8643 const ggml_compute_params * params,
8644 ggml_tensor * dst) {
8645 const ggml_tensor * src0 = dst->src[0]; // conv_x
8646 const ggml_tensor * src1 = dst->src[1]; // conv1d.weight
8647
8648 const int ith = params->ith;
8649 const int nth = params->nth;
8650
8651 const int nc = src1->ne[0]; // d_conv
8652 const int ncs = src0->ne[0]; // d_conv - 1 + n_t
8653 const int nr = src0->ne[1]; // d_inner
8654 const int n_t = dst->ne[1]; // tokens per sequence
8655 const int n_s = dst->ne[2]; // number of sequences in the batch
8656
8657 GGML_ASSERT( dst->ne[0] == nr);
8658 GGML_ASSERT(src0->nb[0] == sizeof(float));
8659 GGML_ASSERT(src1->nb[0] == sizeof(float));
8660 GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
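
    // per channel i1 and token i2 this computes a depthwise causal conv1d:
    //   x[i1, i2] = sum_{i0 < d_conv} conv_x[i1, i2 + i0] * weight[i1, i0]
    // where conv_x already holds the d_conv - 1 previous columns of context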
8661
8662 // rows per thread
8663 const int dr = (nr + nth - 1)/nth;
8664
8665 // row range for this thread
8666 const int ir0 = dr*ith;
8667 const int ir1 = MIN(ir0 + dr, nr);
8668 const int ir = ir1 - ir0;
8669
8670 for (int i3 = 0; i3 < n_s; ++i3) {
8671 for (int i2 = 0; i2 < n_t; ++i2) {
8672 // {d_conv - 1 + n_t, d_inner, n_seqs}
8673 // sliding window
8674 const float * s = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s}
8675 const float * c = (const float *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner}
8676 float * x = (float *) ((char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s}
8677
8678 // TODO: transpose the output for smaller strides for big batches?
8679 // d_inner
8680 for (int i1 = 0; i1 < ir; ++i1) {
8681 // rowwise dot product
8682 // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
8683 float sumf = 0.0f;
8684
8685 // d_conv
8686 for (int i0 = 0; i0 < nc; ++i0) {
8687 sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];
8688 }
8689 x[i1] = sumf;
8690 }
8691 }
8692 }
8693}
8694
8695void ggml_compute_forward_ssm_conv(
8696 const ggml_compute_params * params,
8697 ggml_tensor * dst) {
8698 switch (dst->src[0]->type) {
8699 case GGML_TYPE_F32:
8700 {
8701 ggml_compute_forward_ssm_conv_f32(params, dst);
8702 } break;
8703 default:
8704 {
8705 GGML_ABORT("fatal error");
8706 }
8707 }
8708}
8709
8710// ggml_compute_forward_ssm_scan
8711
8712static void ggml_compute_forward_ssm_scan_f32(
8713 const ggml_compute_params * params,
8714 ggml_tensor * dst) {
8715 const ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs+}
8716 const ggml_tensor * src1 = dst->src[1]; // x {dim, n_head, n_seq_tokens, n_seqs}
8717 const ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs}
8718 const ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {1, n_head}
8719 const ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs}
8720 const ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs}
8721 const ggml_tensor * src6 = dst->src[6]; // ids {n_seqs}
8722
8723 const int ith = params->ith;
8724 const int nth = params->nth;
8725
8726 const int64_t nc = src0->ne[0]; // d_state
8727 const int64_t nr = src0->ne[1]; // dim
8728 const int64_t nh = src1->ne[1]; // n_head
8729 const int64_t ng = src4->ne[1];
8730 const int64_t nt = src1->ne[2]; // number of tokens per sequence
8731 const int64_t ns = src1->ne[3]; // number of sequences in the batch
8732
8733 // can't use ggml_nbytes because src1 is not necessarily contiguous
8734 const int64_t s_off = ggml_nelements(tensor: src1) * ggml_element_size(tensor: src1);
8735
8736 GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst));
8737 GGML_ASSERT(src0->nb[0] == sizeof(float));
8738 GGML_ASSERT(src1->nb[0] == sizeof(float));
8739 GGML_ASSERT(src2->nb[0] == sizeof(float));
8740 GGML_ASSERT(src3->nb[0] == sizeof(float));
8741 GGML_ASSERT(src4->nb[0] == sizeof(float));
8742 GGML_ASSERT(src5->nb[0] == sizeof(float));
8743 GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
8744 GGML_ASSERT(nh % ng == 0);
8745
8746 // heads per thread
8747 const int dh = (nh + nth - 1)/nth;
8748
8749 // head range for this thread
8750 const int ih0 = dh*ith;
8751 const int ih1 = MIN(ih0 + dh, nh);
8752
8753 const int32_t * ids = (const int32_t *) src6->data;
8754
8755 for (int i3 = 0; i3 < ns; ++i3) {
8756 const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns}
8757 float * s = ( float *) (( char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}
8758
8759 for (int i2 = 0; i2 < nt; ++i2) {
8760 const float * x = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns}
8761 const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns}
8762 const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh}
8763 const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
8764 const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
8765 float * y = ( float *) (( char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}
8766
8767 if (src3->ne[0] == 1) {
8768                 // Mamba-2 has a scalar decay factor per head, so dA can be computed outside the state-wise loop
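                // per element the selective-scan recurrence is
                //   state = state * exp(dt_softplus * A[h]) + B * (x * dt_softplus)
                //   y     = sum over d_state of (state * C)
                // which is what both the SIMD paths and the scalar tail below compute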
8769
8770 // n_head
8771 for (int h = ih0; h < ih1; ++h) {
8772 // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
8773 const float dt_soft_plus = ggml_softplus(input: dt[h]);
8774 const float dA = expf(x: dt_soft_plus * A[h]);
8775 const int g = h / (nh / ng); // repeat_interleave
8776
8777 // dim
8778 for (int i1 = 0; i1 < nr; ++i1) {
8779 const int ii = i1 + h*nr;
8780 const float x_dt = x[ii] * dt_soft_plus;
8781 float sumf = 0.0f;
8782#if defined(GGML_SIMD)
8783 #if defined(__ARM_FEATURE_SVE)
8784 const int ggml_f32_epr = svcntw();
8785 const int ggml_f32_step = 1 * ggml_f32_epr;
8786
8787 const int np = (nc & ~(ggml_f32_step - 1));
8788
8789 GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
8790
8791 GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
8792 GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
8793
8794 for (int i = 0; i < np; i += ggml_f32_step) {
8795 // TODO: maybe unroll more?
8796 for (int j = 0; j < 1; j++) {
8797 GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
8798 GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + g*nc);
8799 GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + g*nc);
8800
8801 t0 = GGML_F32_VEC_MUL(t0, adA);
8802 t1 = GGML_F32_VEC_MUL(t1, axdt);
8803
8804 t0 = GGML_F32_VEC_ADD(t0, t1);
8805
8806 sum = GGML_F32_VEC_FMA(sum, t0, t2);
8807
8808 GGML_F32_VEC_STORE(s + i + j*ggml_f32_epr + ii*nc, t0);
8809 }
8810 }
8811
8812 sumf = GGML_F32xt_REDUCE_ONE(sum);
8813 #elif defined(__riscv_v_intrinsic)
8814 // todo: RVV implementation
8815 const int np = 0;
8816 #else
8817 const int np = (nc & ~(GGML_F32_STEP - 1));
8818
8819 GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
8820
8821 GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
8822 GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
8823
8824 GGML_F32_VEC ax[GGML_F32_ARR];
8825 GGML_F32_VEC ay[GGML_F32_ARR];
8826 GGML_F32_VEC az[GGML_F32_ARR];
8827
8828 for (int i = 0; i < np; i += GGML_F32_STEP) {
8829 for (int j = 0; j < GGML_F32_ARR; j++) {
8830 ax[j] = GGML_F32_VEC_LOAD(p: s0 + i + j*GGML_F32_EPR + ii*nc);
8831 ay[j] = GGML_F32_VEC_LOAD(p: B + i + j*GGML_F32_EPR + g*nc);
8832 az[j] = GGML_F32_VEC_LOAD(p: C + i + j*GGML_F32_EPR + g*nc);
8833
8834 ax[j] = GGML_F32_VEC_MUL(a: ax[j], b: adA);
8835 ay[j] = GGML_F32_VEC_MUL(a: ay[j], b: axdt);
8836
8837 ax[j] = GGML_F32_VEC_ADD(a: ax[j], b: ay[j]);
8838
8839 sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]);
8840
8841 GGML_F32_VEC_STORE(p: s + i + j*GGML_F32_EPR + ii*nc, a: ax[j]);
8842 }
8843 }
8844
8845 // reduce sum0..sum3 to sum0
8846 GGML_F32_VEC_REDUCE(sumf, sum);
8847 #endif
8848#else
8849 const int np = 0;
8850#endif
8851 // d_state
8852 for (int i0 = np; i0 < nc; ++i0) {
8853 const int i = i0 + ii*nc;
8854 const int ig = i0 + g*nc;
8855 // state = prev_state * dA + dB * x
8856 const float state = (s0[i] * dA) + (B[ig] * x_dt);
8857 // y = rowwise_dotprod(state, C)
8858 sumf += state * C[ig];
8859 s[i] = state;
8860 }
8861 y[ii] = sumf;
8862 }
8863 }
8864 } else {
8865 // Mamba-1 has an element-wise decay factor for the states
8866
8867 // n_head
8868 for (int h = ih0; h < ih1; ++h) {
8869 // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
8870 const float dt_soft_plus = ggml_softplus(input: dt[h]);
8871 const int g = h / (nh / ng); // repeat_interleave
8872
8873 // dim
8874 for (int i1 = 0; i1 < nr; ++i1) {
8875 const int ii = i1 + h*nr;
8876 const float x_dt = x[ii] * dt_soft_plus;
8877#if defined(__ARM_FEATURE_SVE)
8878 svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
8879 svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
8880 svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
8881
8882 // d_state
8883 // TODO: what happens when (d_state % svcntw()) != 0?
8884 for (int64_t k = 0; k < nc; k += svcntw()) {
8885 svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]);
8886 svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + g*nc]);
8887 svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + g*nc]);
8888 svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
8889
8890 svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
8891 t1 = exp_ps_sve(svptrue_b32(), t1);
8892 svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
8893
8894 vs0 = GGML_F32_VEC_FMA(t2, vs0, t1);
8895 r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
8896
8897 GGML_F32_VEC_STORE(&s[ii*nc + k], vs0);
8898 }
8899 y[ii] = GGML_F32xt_REDUCE_ONE(r1_vector);
8900#else
8901 float sumf = 0.0f;
8902 // NOTE: can't really use GGML_SIMD here because d_state is usually 16
8903 // and also because expf is used within the loop.
8904 // d_state
8905 for (int i0 = 0; i0 < nc; ++i0) {
8906 const int i = i0 + ii*nc;
8907 const int ig = i0 + g*nc;
8908 // state = prev_state * dA + dB * x
8909 const float state = (s0[i] * expf(x: dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
8910 // y = rowwise_dotprod(state, C)
8911 sumf += state * C[ig];
8912 s[i] = state;
8913 }
8914 y[ii] = sumf;
8915#endif
8916 }
8917 }
8918 }
8919 // use the output as the source when it's not the first token-wise iteration
8920 s0 = s;
8921 }
8922 }
8923}
8924
8925void ggml_compute_forward_ssm_scan(
8926 const ggml_compute_params * params,
8927 ggml_tensor * dst) {
8928 switch (dst->src[0]->type) {
8929 case GGML_TYPE_F32:
8930 {
8931 ggml_compute_forward_ssm_scan_f32(params, dst);
8932 } break;
8933 default:
8934 {
8935 GGML_ABORT("fatal error");
8936 }
8937 }
8938}
8939
8940// ggml_compute_forward_win_part
8941
8942static void ggml_compute_forward_win_part_f32(
8943 const ggml_compute_params * params,
8944 ggml_tensor * dst) {
8945 GGML_UNUSED(params);
8946
8947 const ggml_tensor * src0 = dst->src[0];
8948
8949 GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
8950 GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
8951
8952 const int32_t nep0 = ((const int32_t *)(dst->op_params))[0];
8953 const int32_t nep1 = ((const int32_t *)(dst->op_params))[1];
8954 const int32_t w = ((const int32_t *)(dst->op_params))[2];
8955
8956 assert(ne00 == ne0);
8957 assert(ne3 == nep0*nep1);
8958
8959 // TODO: optimize / multi-thread
8960 for (int py = 0; py < nep1; ++py) {
8961 for (int px = 0; px < nep0; ++px) {
8962 const int64_t i3 = py*nep0 + px;
8963 for (int64_t i2 = 0; i2 < ne2; ++i2) {
8964 for (int64_t i1 = 0; i1 < ne1; ++i1) {
8965 for (int64_t i0 = 0; i0 < ne0; ++i0) {
8966 const int64_t i02 = py*w + i2;
8967 const int64_t i01 = px*w + i1;
8968 const int64_t i00 = i0;
8969
8970 const int64_t i = i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + i0;
8971 const int64_t j = i02*ne01*ne00 + i01*ne00 + i00;
8972
8973 if (py*w + i2 >= ne02 || px*w + i1 >= ne01) {
8974 ((float *) dst->data)[i] = 0.0f;
8975 } else {
8976 ((float *) dst->data)[i] = ((float *) src0->data)[j];
8977 }
8978 }
8979 }
8980 }
8981 }
8982 }
8983}
8984
8985void ggml_compute_forward_win_part(
8986 const ggml_compute_params * params,
8987 ggml_tensor * dst) {
8988
8989 const ggml_tensor * src0 = dst->src[0];
8990
8991 switch (src0->type) {
8992 case GGML_TYPE_F32:
8993 {
8994 ggml_compute_forward_win_part_f32(params, dst);
8995 } break;
8996 default:
8997 {
8998 GGML_ABORT("fatal error");
8999 }
9000 }
9001}
9002
9003// ggml_compute_forward_win_unpart
9004
9005static void ggml_compute_forward_win_unpart_f32(
9006 const ggml_compute_params * params,
9007 ggml_tensor * dst) {
9008 GGML_UNUSED(params);
9009
9010 const ggml_tensor * src0 = dst->src[0];
9011
9012 GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
9013 GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
9014
9015 const int32_t w = ((const int32_t *)(dst->op_params))[0];
9016
9017 // padding
9018 const int px = (w - ne1%w)%w;
9019 //const int py = (w - ne2%w)%w;
9020
9021 const int npx = (px + ne1)/w;
9022 //const int npy = (py + ne2)/w;
9023
9024 assert(ne0 == ne00);
9025
9026 // TODO: optimize / multi-thread
9027 for (int64_t i2 = 0; i2 < ne2; ++i2) {
9028 for (int64_t i1 = 0; i1 < ne1; ++i1) {
9029 for (int64_t i0 = 0; i0 < ne0; ++i0) {
9030 const int ip2 = i2/w;
9031 const int ip1 = i1/w;
9032
9033 const int64_t i02 = i2%w;
9034 const int64_t i01 = i1%w;
9035 const int64_t i00 = i0;
9036
9037 const int64_t i = (ip2*npx + ip1)*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00 + i00;
9038 const int64_t j = i2*ne1*ne0 + i1*ne0 + i0;
9039
9040 ((float *) dst->data)[j] = ((float *) src0->data)[i];
9041 }
9042 }
9043 }
9044}
9045
9046void ggml_compute_forward_win_unpart(
9047 const ggml_compute_params * params,
9048 ggml_tensor * dst) {
9049
9050 const ggml_tensor * src0 = dst->src[0];
9051
9052 switch (src0->type) {
9053 case GGML_TYPE_F32:
9054 {
9055 ggml_compute_forward_win_unpart_f32(params, dst);
9056 } break;
9057 default:
9058 {
9059 GGML_ABORT("fatal error");
9060 }
9061 }
9062}
9063
9064 // ggml_compute_forward_unary
9065
9066void ggml_compute_forward_unary(
9067 const ggml_compute_params * params,
9068 ggml_tensor * dst) {
9069
9070 const ggml_unary_op op = ggml_get_unary_op(tensor: dst);
9071
9072 switch (op) {
9073 case GGML_UNARY_OP_ABS:
9074 {
9075 ggml_compute_forward_abs(params, dst);
9076 } break;
9077 case GGML_UNARY_OP_SGN:
9078 {
9079 ggml_compute_forward_sgn(params, dst);
9080 } break;
9081 case GGML_UNARY_OP_NEG:
9082 {
9083 ggml_compute_forward_neg(params, dst);
9084 } break;
9085 case GGML_UNARY_OP_STEP:
9086 {
9087 ggml_compute_forward_step(params, dst);
9088 } break;
9089 case GGML_UNARY_OP_TANH:
9090 {
9091 ggml_compute_forward_tanh(params, dst);
9092 } break;
9093 case GGML_UNARY_OP_ELU:
9094 {
9095 ggml_compute_forward_elu(params, dst);
9096 } break;
9097 case GGML_UNARY_OP_RELU:
9098 {
9099 ggml_compute_forward_relu(params, dst);
9100 } break;
9101 case GGML_UNARY_OP_SIGMOID:
9102 {
9103 ggml_compute_forward_sigmoid(params, dst);
9104 } break;
9105 case GGML_UNARY_OP_GELU:
9106 {
9107 ggml_compute_forward_gelu(params, dst);
9108 } break;
9109 case GGML_UNARY_OP_GELU_ERF:
9110 {
9111 ggml_compute_forward_gelu_erf(params, dst);
9112 } break;
9113 case GGML_UNARY_OP_GELU_QUICK:
9114 {
9115 ggml_compute_forward_gelu_quick(params, dst);
9116 } break;
9117 case GGML_UNARY_OP_SILU:
9118 {
9119 ggml_compute_forward_silu(params, dst);
9120 } break;
9121 case GGML_UNARY_OP_HARDSWISH:
9122 {
9123 ggml_compute_forward_hardswish(params, dst);
9124 } break;
9125 case GGML_UNARY_OP_HARDSIGMOID:
9126 {
9127 ggml_compute_forward_hardsigmoid(params, dst);
9128 } break;
9129 case GGML_UNARY_OP_EXP:
9130 {
9131 ggml_compute_forward_exp(params, dst);
9132 } break;
9133 case GGML_UNARY_OP_FLOOR:
9134 {
9135 ggml_compute_forward_floor(params, dst);
9136 } break;
9137 case GGML_UNARY_OP_CEIL:
9138 {
9139 ggml_compute_forward_ceil(params, dst);
9140 } break;
9141 case GGML_UNARY_OP_ROUND:
9142 {
9143 ggml_compute_forward_round(params, dst);
9144 } break;
9145 case GGML_UNARY_OP_TRUNC:
9146 {
9147 ggml_compute_forward_trunc(params, dst);
9148 } break;
9149 case GGML_UNARY_OP_XIELU:
9150 {
9151 ggml_compute_forward_xielu(params, dst);
9152 } break;
9153 default:
9154 {
9155 GGML_ABORT("fatal error");
9156 }
9157 }
9158}
9159
9160 // ggml_compute_forward_glu
9161
9162void ggml_compute_forward_glu(
9163 const ggml_compute_params * params,
9164 ggml_tensor * dst) {
9165
9166 const ggml_glu_op op = ggml_get_glu_op(tensor: dst);
9167
9168 switch (op) {
9169 case GGML_GLU_OP_REGLU:
9170 {
9171 ggml_compute_forward_reglu(params, dst);
9172 } break;
9173 case GGML_GLU_OP_GEGLU:
9174 {
9175 ggml_compute_forward_geglu(params, dst);
9176 } break;
9177 case GGML_GLU_OP_SWIGLU:
9178 {
9179 ggml_compute_forward_swiglu(params, dst);
9180 } break;
9181 case GGML_GLU_OP_SWIGLU_OAI:
9182 {
9183 ggml_compute_forward_swiglu_oai(params, dst);
9184 } break;
9185 case GGML_GLU_OP_GEGLU_ERF:
9186 {
9187 ggml_compute_forward_geglu_erf(params, dst);
9188 } break;
9189 case GGML_GLU_OP_GEGLU_QUICK:
9190 {
9191 ggml_compute_forward_geglu_quick(params, dst);
9192 } break;
9193 default:
9194 {
9195 GGML_ABORT("fatal error");
9196 }
9197 }
9198}
9199
9200// ggml_compute_forward_get_rel_pos
9201
9202static void ggml_compute_forward_get_rel_pos_f16(
9203 const ggml_compute_params * params,
9204 ggml_tensor * dst) {
9205 GGML_UNUSED(params);
9206
9207 const ggml_tensor * src0 = dst->src[0];
9208
9209 // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322
9210
9211 GGML_TENSOR_UNARY_OP_LOCALS
9212
9213 const int64_t w = ne1;
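
    // pos = (w - i1 - 1) + i2 == (i2 - i1) + (w - 1), i.e. the relative position
    // between query index i2 and key index i1 shifted into [0, 2*w - 2], which
    // selects the matching row of the precomputed relative positional embedding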
9214
9215 ggml_fp16_t * src0_data = (ggml_fp16_t *) src0->data;
9216 ggml_fp16_t * dst_data = (ggml_fp16_t *) dst->data;
9217
9218 for (int64_t i2 = 0; i2 < ne2; ++i2) {
9219 for (int64_t i1 = 0; i1 < ne1; ++i1) {
9220 const int64_t pos = (w - i1 - 1) + i2;
9221 for (int64_t i0 = 0; i0 < ne0; ++i0) {
9222 dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0];
9223 }
9224 }
9225 }
9226}
9227
9228void ggml_compute_forward_get_rel_pos(
9229 const ggml_compute_params * params,
9230 ggml_tensor * dst) {
9231
9232 const ggml_tensor * src0 = dst->src[0];
9233
9234 switch (src0->type) {
9235 case GGML_TYPE_F16:
9236 case GGML_TYPE_BF16:
9237 {
9238 ggml_compute_forward_get_rel_pos_f16(params, dst);
9239 } break;
9240 default:
9241 {
9242 GGML_ABORT("fatal error");
9243 }
9244 }
9245}
9246
9247// ggml_compute_forward_add_rel_pos
9248
9249static void ggml_compute_forward_add_rel_pos_f32(
9250 const ggml_compute_params * params,
9251 ggml_tensor * dst) {
9252
9253 const ggml_tensor * src0 = dst->src[0];
9254 const ggml_tensor * src1 = dst->src[1];
9255 const ggml_tensor * src2 = dst->src[2];
9256
9257 const bool inplace = (bool) ((int32_t *) dst->op_params)[0];
9258 if (!inplace) {
9259 if (params->ith == 0) {
9260 memcpy(dest: (char *) dst->data, src: (char *) src0->data, n: ggml_nbytes(tensor: dst));
9261 }
9262 ggml_barrier(tp: params->threadpool);
9263 }
9264 // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L357-L359
9265
9266 float * src1_data = (float *) src1->data;
9267 float * src2_data = (float *) src2->data;
9268 float * dst_data = (float *) dst->data;
9269
9270 const int64_t ne10 = src1->ne[0];
9271 const int64_t ne11 = src1->ne[1];
9272 const int64_t ne12 = src1->ne[2];
9273 const int64_t ne13 = src1->ne[3];
9274
9275 const int ith = params->ith;
9276 const int nth = params->nth;
9277
9278 // total patches in dst
9279 const int np = ne13;
9280
9281 // patches per thread
9282 const int dp = (np + nth - 1)/nth;
9283
9284 // patch range for this thread
9285 const int ip0 = dp*ith;
9286 const int ip1 = MIN(ip0 + dp, np);
9287
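    // dst is viewed as an ne10 x ne10 attention window per (i13, i12, i11):
    // src2_e is added across a contiguous run of ne10 elements (jdh + j), while
    // src1_e is added with stride ne10 (jdw + j*ne10), i.e. the two decomposed
    // relative-position terms are broadcast along the two window axes as in the
    // SAM reference above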
9288 for (int64_t i13 = ip0; i13 < ip1; ++i13) {
9289 for (int64_t i12 = 0; i12 < ne12; ++i12) {
9290 for (int64_t i11 = 0; i11 < ne11; ++i11) {
9291 const int64_t jp1 = i13*ne12*ne11*ne10 + i12*ne11*ne10 + i11*ne10;
9292 for (int64_t i10 = 0; i10 < ne10; ++i10) {
9293 const int64_t jp0 = jp1 + i10;
9294 const float src1_e = src1_data[jp0];
9295 const float src2_e = src2_data[jp0];
9296
9297 const int64_t jdh = jp0 * ne10;
9298 const int64_t jdw = jdh - (ne10 - 1) * i10;
9299
9300 for (int64_t j = 0; j < ne10; ++j) {
9301 dst_data[jdh + j ] += src2_e;
9302 dst_data[jdw + j*ne10] += src1_e;
9303 }
9304 }
9305 }
9306 }
9307 }
9308}
9309
9310void ggml_compute_forward_add_rel_pos(
9311 const ggml_compute_params * params,
9312 ggml_tensor * dst) {
9313
9314 const ggml_tensor * src0 = dst->src[0];
9315
9316 switch (src0->type) {
9317 case GGML_TYPE_F32:
9318 {
9319 ggml_compute_forward_add_rel_pos_f32(params, dst);
9320 } break;
9321 default:
9322 {
9323 GGML_ABORT("fatal error");
9324 }
9325 }
9326}
9327
9328// ggml_compute_forward_rwkv_wkv6
9329
static void ggml_compute_forward_rwkv_wkv6_f32(
        const ggml_compute_params * params,
        ggml_tensor * dst) {
    const int64_t T = dst->src[1]->ne[2];
    const int64_t C = dst->ne[0];
    const int64_t HEADS = dst->src[1]->ne[1];
    const int64_t n_seqs = dst->src[5]->ne[1];
    const int64_t head_size = C / HEADS;

    float * dst_data = (float *) dst->data;
    float * state = ((float *) dst->data) + C * T;

    const int ith = params->ith;
    const int nth = params->nth;

    if (ith >= HEADS) {
        return;
    }

    const int h_start = (HEADS * ith) / nth;
    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
                (HEADS * (ith + 1)) / nth : HEADS;

    float * k = (float *) dst->src[0]->data;
    float * v = (float *) dst->src[1]->data;
    float * r = (float *) dst->src[2]->data;
    float * time_faaaa = (float *) dst->src[3]->data;
    float * time_decay = (float *) dst->src[4]->data;

    size_t t_stride = HEADS * head_size; // same as C

    size_t h_stride = C / HEADS;
    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
    size_t h_stride_2d = head_size * head_size;

    if (ith == 0) {
        memset(dst_data, 0, T * C * sizeof(float));
    }
    ggml_barrier(params->threadpool);

    #if defined(__AVX__) && !defined(__AVX512F__)
        #define GGML_F32X GGML_F32x8
        #define GGML_F32X_SET1 GGML_F32x8_SET1
        #define GGML_F32X_LOAD GGML_F32x8_LOAD
        #define GGML_F32X_STORE GGML_F32x8_STORE
        #define GGML_F32X_MUL GGML_F32x8_MUL
        #define GGML_F32X_FMA GGML_F32x8_FMA
        #define WKV_VECTOR_SIZE 8
    #elif defined(__AVX512F__)
        #define GGML_F32X GGML_F32x16
        #define GGML_F32X_SET1 GGML_F32x16_SET1
        #define GGML_F32X_LOAD GGML_F32x16_LOAD
        #define GGML_F32X_STORE GGML_F32x16_STORE
        #define GGML_F32X_MUL GGML_F32x16_MUL
        #define GGML_F32X_FMA GGML_F32x16_FMA
        #define WKV_VECTOR_SIZE 16
    #elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
        #define GGML_F32X GGML_F32xt
        #define GGML_F32X_SET1 GGML_F32xt_SET1
        #define GGML_F32X_LOAD GGML_F32xt_LOAD
        #define GGML_F32X_STORE GGML_F32xt_STORE
        #define GGML_F32X_MUL GGML_F32xt_MUL
        #define GGML_F32X_FMA GGML_F32xt_FMA
        #define WKV_VECTOR_SIZE 8
    #elif defined(__ARM_NEON) && defined(__aarch64__)
        #define GGML_F32X GGML_F32x4
        #define GGML_F32X_SET1 GGML_F32x4_SET1
        #define GGML_F32X_LOAD GGML_F32x4_LOAD
        #define GGML_F32X_STORE GGML_F32x4_STORE
        #define GGML_F32X_MUL GGML_F32x4_MUL
        #define GGML_F32X_FMA GGML_F32x4_FMA
        #define WKV_VECTOR_SIZE 4
    #endif

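    // The GGML_F32X_* macros above alias the widest available f32 SIMD type for the
    // target (AVX, AVX-512, SVE or NEON) and WKV_VECTOR_SIZE is its nominal lane count.
    // For SVE the actual vector length is only known at runtime, which is presumably
    // why it is queried below with svcntw() instead of using the compile-time constant.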
    #ifdef WKV_VECTOR_SIZE
        int wkv_vector_size;
        #if defined(__ARM_FEATURE_SVE)
            wkv_vector_size = svcntw();
        #else
            wkv_vector_size = WKV_VECTOR_SIZE;
        #endif
        const int64_t vec_count = head_size / wkv_vector_size;

        for (int64_t t = 0; t < T; t++) {
            size_t t_offset = t * t_stride;
            size_t state_offset = head_size * C * (t / (T / n_seqs));
            float * state_cur = state + state_offset;
            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;

            for (int64_t h = h_start; h < h_end; h++) {
                size_t h_offset = h * h_stride;
                size_t t_h_offset = t_offset + h_offset;
                size_t h_2d_offset = h * h_stride_2d;

                for (int64_t i = 0; i < head_size; i++) {
                    size_t t_h_i_offset = t_h_offset + i;
                    size_t h_i_offset = h_offset + i;
                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;

                    float k_val = k[t_h_i_offset];
                    float r_val = r[t_h_i_offset];
                    float time_faaaa_val = time_faaaa[h_i_offset];
                    float time_decay_val = time_decay[t_h_i_offset];

                    // Broadcast scalar values to vectors
                    GGML_F32X k_vec = GGML_F32X_SET1(k_val);
                    GGML_F32X r_vec = GGML_F32X_SET1(r_val);
                    GGML_F32X time_faaaa_vec = GGML_F32X_SET1(time_faaaa_val);
                    GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val);

                    for (int64_t j = 0; j < vec_count; j++) {
                        size_t base_j = j * wkv_vector_size;
                        size_t t_h_j_offset = t_h_offset + base_j;
                        size_t h_2d_i_j_offset = h_2d_i_offset + base_j;

                        // Load wkv_vector_size elements at once
                        GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
                        GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
                        GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);

                        // Compute kv = v * k
                        GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);

                        // Compute temp = kv * time_faaaa + prev_state
                        GGML_F32X temp_vec = GGML_F32X_FMA(prev_state_vec, kv_vec, time_faaaa_vec);

                        // Update dst: dst += temp * r
                        dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, r_vec);
                        GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);

                        // Update state: state = prev_state * time_decay + kv
                        GGML_F32X new_state_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, time_decay_vec);
                        GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], new_state_vec);
                    }

                    // Handle remaining elements (normally none: head_size is expected to be a multiple of the vector size)
                    for (int64_t j = vec_count * wkv_vector_size; j < head_size; j++) {
                        size_t t_h_j_offset = t_h_offset + j;
                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
                        float v_val = v[t_h_j_offset];
                        float kv_val = v_val * k_val;
                        float prev_state_val = state_prev[h_2d_i_j_offset];
                        float temp_val = kv_val * time_faaaa_val + prev_state_val;
                        dst_data[t_h_j_offset] += temp_val * r_val;
                        state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
                    }
                }
            }
        }

    #else
        // basically the fused operations:
        //   dst   = r @ (time_faaaa * (k @ v) + state),
        //   state = time_decay * state + (k @ v),
        // applied recursively over the tokens of each sequence
        for (int64_t t = 0; t < T; t++) {
            size_t t_offset = t * t_stride;
            size_t state_offset = head_size * C * (t / (T / n_seqs));
            float * state_cur = state + state_offset;
            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;

            for (int64_t h = h_start; h < h_end; h++) {
                size_t h_offset = h * h_stride;
                size_t t_h_offset = t_offset + h_offset;
                size_t h_2d_offset = h * h_stride_2d;

                for (int64_t i = 0; i < head_size; i++) {
                    size_t t_h_i_offset = t_h_offset + i;
                    size_t h_i_offset = h_offset + i;
                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;

                    float k_val = k[t_h_i_offset];
                    float r_val = r[t_h_i_offset];
                    float time_faaaa_val = time_faaaa[h_i_offset];
                    // RWKV v6: different time_decay for each token
                    float time_decay_val = time_decay[t_h_i_offset];

                    for (int64_t j = 0; j < head_size; j++) {
                        size_t t_h_j_offset = t_h_offset + j;
                        size_t h_2d_i_j_offset = h_2d_i_offset + j;

                        float v_val = v[t_h_j_offset];
                        float kv_val = v_val * k_val;
                        float prev_state_val = state_prev[h_2d_i_j_offset];
                        float temp_val = kv_val * time_faaaa_val + prev_state_val;
                        dst_data[t_h_j_offset] += temp_val * r_val;
                        state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
                    }
                }
            }
        }
    #endif
}

void ggml_compute_forward_rwkv_wkv6(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];

    switch (src0->type) {
        case GGML_TYPE_F32:
            {
                ggml_compute_forward_rwkv_wkv6_f32(params, dst);
            } break;
        default:
            {
                GGML_ABORT("fatal error");
            }
    }
}

// ggml_compute_forward_gla

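// For reference, the gated linear attention recurrence computed per head h and
// output channel j is (using the variable names below):
//
//   state[i][j] = state[i][j] * g[i] + k[i] * v[j]
//   dst[j]     += (q[i] * scale) * state[i][j]
//
// accumulated over i, with the state carried across the tokens of each sequence.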
static void ggml_compute_forward_gla_f32(
        const ggml_compute_params * params,
        ggml_tensor * dst) {
    const int64_t T = dst->src[1]->ne[2];
    const int64_t C = dst->ne[0];
    const int64_t HEADS = dst->src[1]->ne[1];
    const int64_t n_seqs = dst->src[4]->ne[1];
    const int64_t head_size = C / HEADS;
    const float scale = ggml_get_op_params_f32(dst, 0);

    float * dst_data = (float *) dst->data;
    float * state = ((float *) dst->data) + C * T;

    const int ith = params->ith;
    const int nth = params->nth;

    if (ith >= HEADS) {
        return;
    }

    const int h_start = (HEADS * ith) / nth;
    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
                (HEADS * (ith + 1)) / nth : HEADS;

    float * k = (float *) dst->src[0]->data;
    float * v = (float *) dst->src[1]->data;
    float * q = (float *) dst->src[2]->data;
    float * g = (float *) dst->src[3]->data;

    size_t t_stride = HEADS * head_size; // same as C

    size_t h_stride = C / HEADS;
    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
    size_t h_stride_2d = head_size * head_size;

    if (ith == 0) {
        memset(dst_data, 0, T * C * sizeof(float));
    }
    ggml_barrier(params->threadpool);

    #if defined(__AVX__) && !defined(__AVX512F__)
        #define GGML_F32X GGML_F32x8
        #define GGML_F32X_SET1 GGML_F32x8_SET1
        #define GGML_F32X_LOAD GGML_F32x8_LOAD
        #define GGML_F32X_STORE GGML_F32x8_STORE
        #define GGML_F32X_MUL GGML_F32x8_MUL
        #define GGML_F32X_FMA GGML_F32x8_FMA
        #define GLA_VECTOR_SIZE 8
    #elif defined(__AVX512F__)
        #define GGML_F32X GGML_F32x16
        #define GGML_F32X_SET1 GGML_F32x16_SET1
        #define GGML_F32X_LOAD GGML_F32x16_LOAD
        #define GGML_F32X_STORE GGML_F32x16_STORE
        #define GGML_F32X_MUL GGML_F32x16_MUL
        #define GGML_F32X_FMA GGML_F32x16_FMA
        #define GLA_VECTOR_SIZE 16
    #elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
        #define GGML_F32X GGML_F32xt
        #define GGML_F32X_SET1 GGML_F32xt_SET1
        #define GGML_F32X_LOAD GGML_F32xt_LOAD
        #define GGML_F32X_STORE GGML_F32xt_STORE
        #define GGML_F32X_MUL GGML_F32xt_MUL
        #define GGML_F32X_FMA GGML_F32xt_FMA
        #define GLA_VECTOR_SIZE 8
    #elif defined(__ARM_NEON) && defined(__aarch64__)
        #define GGML_F32X GGML_F32x4
        #define GGML_F32X_SET1 GGML_F32x4_SET1
        #define GGML_F32X_LOAD GGML_F32x4_LOAD
        #define GGML_F32X_STORE GGML_F32x4_STORE
        #define GGML_F32X_MUL GGML_F32x4_MUL
        #define GGML_F32X_FMA GGML_F32x4_FMA
        #define GLA_VECTOR_SIZE 4
    #endif

    #ifdef GLA_VECTOR_SIZE
        int gla_vector_size;
        #if defined(__ARM_FEATURE_SVE)
            gla_vector_size = svcntw();
        #else
            gla_vector_size = GLA_VECTOR_SIZE;
        #endif
        const int64_t vec_count = head_size / gla_vector_size;

        for (int64_t t = 0; t < T; t++) {
            size_t t_offset = t * t_stride;
            size_t state_offset = head_size * C * (t / (T / n_seqs));
            float * state_cur = state + state_offset;
            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;

            for (int64_t h = h_start; h < h_end; h++) {
                size_t h_offset = h * h_stride;
                size_t t_h_offset = t_offset + h_offset;
                size_t h_2d_offset = h * h_stride_2d;

                for (int64_t i = 0; i < head_size; i++) {
                    size_t t_h_i_offset = t_h_offset + i;
                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;

                    float k_val = k[t_h_i_offset];
                    float q_val = q[t_h_i_offset] * scale;
                    float g_val = g[t_h_i_offset];

                    // Broadcast scalar values to vectors
                    GGML_F32X k_vec = GGML_F32X_SET1(k_val);
                    GGML_F32X q_vec = GGML_F32X_SET1(q_val);
                    GGML_F32X g_vec = GGML_F32X_SET1(g_val);

                    for (int64_t j = 0; j < vec_count; j++) {
                        size_t base_j = j * gla_vector_size;
                        size_t t_h_j_offset = t_h_offset + base_j;
                        size_t h_2d_i_j_offset = h_2d_i_offset + base_j;

                        // Load gla_vector_size elements at once
                        GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
                        GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
                        GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);

                        // Compute kv = v * k
                        GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);

                        // Compute temp = prev_state * g + kv
                        GGML_F32X temp_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, g_vec);

                        // Update dst: dst += temp * q
                        dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, q_vec);
                        GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);

                        // Update state
                        GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], temp_vec);
                    }

                    // Handle remaining elements (normally none: head_size is expected to be a multiple of the vector size)
                    for (int64_t j = vec_count * gla_vector_size; j < head_size; j++) {
                        size_t t_h_j_offset = t_h_offset + j;
                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
                        float v_val = v[t_h_j_offset];
                        float kv_val = v_val * k_val;
                        float prev_state_val = state_prev[h_2d_i_j_offset];
                        float temp_val = kv_val + prev_state_val * g_val;
                        dst_data[t_h_j_offset] += temp_val * q_val;
                        state_cur[h_2d_i_j_offset] = temp_val;
                    }
                }
            }
        }

    #else
        for (int64_t t = 0; t < T; t++) {
            size_t t_offset = t * t_stride;
            size_t state_offset = head_size * C * (t / (T / n_seqs));
            float * state_cur = state + state_offset;
            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;

            for (int64_t h = h_start; h < h_end; h++) {
                size_t h_offset = h * h_stride;
                size_t t_h_offset = t_offset + h_offset;
                size_t h_2d_offset = h * h_stride_2d;

                for (int64_t i = 0; i < head_size; i++) {
                    size_t t_h_i_offset = t_h_offset + i;
                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;

                    float k_val = k[t_h_i_offset];
                    float q_val = q[t_h_i_offset] * scale;
                    float g_val = g[t_h_i_offset];

                    for (int64_t j = 0; j < head_size; j++) {
                        size_t t_h_j_offset = t_h_offset + j;
                        size_t h_2d_i_j_offset = h_2d_i_offset + j;

                        float v_val = v[t_h_j_offset];
                        float kv_val = v_val * k_val;
                        float prev_state_val = state_prev[h_2d_i_j_offset];
                        float temp_val = prev_state_val * g_val + kv_val;
                        dst_data[t_h_j_offset] += temp_val * q_val;
                        state_cur[h_2d_i_j_offset] = temp_val;
                    }
                }
            }
        }
    #endif
}

void ggml_compute_forward_gla(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];

    switch (src0->type) {
        case GGML_TYPE_F32:
            {
                ggml_compute_forward_gla_f32(params, dst);
            } break;
        default:
            {
                GGML_ABORT("fatal error");
            }
    }
}

// ggml_compute_forward_rwkv_wkv7

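// For reference, the recurrence computed per head h and output channel i is
// (using the variable names below; see also the scalar branch):
//
//   sa          = sum_j a[j] * state[i][j]
//   state[i][j] = state[i][j] * w[j] + k[j] * v[i] + sa * b[j]
//   dst[i]      = sum_j state[i][j] * r[j]
//
// with the state carried across the tokens of each sequence.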
static void ggml_compute_forward_rwkv_wkv7_f32(
        const ggml_compute_params * params,
        ggml_tensor * dst) {
    const int64_t T = dst->src[1]->ne[2];
    const int64_t C = dst->ne[0];
    const int64_t HEADS = dst->src[1]->ne[1];
    const int64_t n_seqs = dst->src[6]->ne[1];
    const int64_t head_size = C / HEADS;

    float * dst_data = (float *) dst->data;
    float * state = ((float *) dst->data) + C * T;

    const int ith = params->ith;
    const int nth = params->nth;

    if (ith >= HEADS) {
        return;
    }

    const int h_start = (HEADS * ith) / nth;
    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
                (HEADS * (ith + 1)) / nth : HEADS;

    float * r = (float *) dst->src[0]->data;
    float * w = (float *) dst->src[1]->data;
    float * k = (float *) dst->src[2]->data;
    float * v = (float *) dst->src[3]->data;
    float * a = (float *) dst->src[4]->data;
    float * b = (float *) dst->src[5]->data;

    int64_t t_stride = HEADS * head_size; // same as C

    int64_t h_stride = C / HEADS;
    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
    int64_t h_stride_2d = head_size * head_size;

    #if defined(GGML_SIMD)
        #if defined(__ARM_FEATURE_SVE) || defined(__riscv_v_intrinsic)
            // route to the scalar implementation - TODO: write SVE and RVV code
            for (int64_t t = 0; t < T; t++) {
                int64_t t_offset = t * t_stride;
                int64_t state_offset = head_size * C * (t / (T / n_seqs));
                float * state_cur = state + state_offset;
                float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;

                for (int64_t h = h_start; h < h_end; h++) {
                    int64_t h_offset = h * h_stride;
                    int64_t t_h_offset = t_offset + h_offset;
                    int64_t h_2d_offset = h * h_stride_2d;

                    for (int64_t i = 0; i < head_size; i++) {
                        int64_t t_h_i_offset = t_h_offset + i;
                        int64_t h_2d_i_offset = h_2d_offset + i * h_stride;

                        float v_val = v[t_h_i_offset];

                        float sa = 0, result = 0;
                        for (int64_t j = 0; j < head_size; j++) {
                            sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
                        }

                        for (int64_t j = 0; j < head_size; j++) {
                            int64_t t_h_j_offset = t_h_offset + j;
                            int64_t h_2d_i_j_offset = h_2d_i_offset + j;

                            float r_val = r[t_h_j_offset];
                            float w_val = w[t_h_j_offset];
                            float k_val = k[t_h_j_offset];
                            float b_val = b[t_h_j_offset];
                            float kv_val = v_val * k_val;
                            float prev_state_val = state_prev[h_2d_i_j_offset];
                            state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
                            result += state_cur[h_2d_i_j_offset] * r_val;
                        }
                        dst_data[t_h_i_offset] = result;
                    }
                }
            }
        #else
            for (int64_t t = 0; t < T; t++) {
                int64_t t_offset = t * t_stride;
                int64_t state_offset = head_size * C * (t / (T / n_seqs));
                float * state_cur = state + state_offset;
                float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;

                for (int64_t h = h_start; h < h_end; h++) {
                    int64_t h_offset = h * h_stride;
                    int64_t t_h_offset = t_offset + h_offset;
                    int64_t h_2d_offset = h * h_stride_2d;

                    for (int64_t ii = 0; ii < head_size; ii++) {
                        int64_t t_h_i_offset = t_h_offset + ii;
                        int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;

                        GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);

                        float sa = 0;
                        {
                            GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
                            GGML_F32_VEC ax[GGML_F32_ARR];
                            GGML_F32_VEC ay[GGML_F32_ARR];
                            for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
                                for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
                                    ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
                                    ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
                                    sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
                                }
                            }
                            GGML_F32_VEC_REDUCE(sa, sum);
                        }

                        GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);

                        int64_t j = 0;
                        GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
                        for (; j < head_size; j += GGML_F32_STEP) {
                            for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
                                int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
                                int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;

                                GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
                                GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
                                GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
                                GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);

                                k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);

                                GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
                                // kv + s * decay + sa * b
                                state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
                                state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
                                GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);

                                result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
                            }
                        }
                        GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);

                        // Handle any left-over elements (none are expected, since head_size should be a multiple of GGML_F32_STEP)
                        for (; j < head_size; j++) {
                            int64_t t_h_j_offset = t_h_offset + j;
                            int64_t h_2d_i_j_offset = h_2d_i_offset + j;

                            float r_val = r[t_h_j_offset];
                            float w_val = w[t_h_j_offset];
                            float k_val = k[t_h_j_offset];
                            float b_val = b[t_h_j_offset];
                            float kv_val = v[t_h_i_offset] * k_val;

                            float prev_state_val = state_prev[h_2d_i_j_offset];
                            state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
                            dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
                        }
                    }
                }
            }
        #endif
    #else
        for (int64_t t = 0; t < T; t++) {
            int64_t t_offset = t * t_stride;
            int64_t state_offset = head_size * C * (t / (T / n_seqs));
            float * state_cur = state + state_offset;
            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;

            for (int64_t h = h_start; h < h_end; h++) {
                int64_t h_offset = h * h_stride;
                int64_t t_h_offset = t_offset + h_offset;
                int64_t h_2d_offset = h * h_stride_2d;

                for (int64_t i = 0; i < head_size; i++) {
                    int64_t t_h_i_offset = t_h_offset + i;
                    int64_t h_2d_i_offset = h_2d_offset + i * h_stride;

                    float v_val = v[t_h_i_offset];

                    float sa = 0, result = 0;
                    for (int64_t j = 0; j < head_size; j++) {
                        sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
                    }

                    for (int64_t j = 0; j < head_size; j++) {
                        int64_t t_h_j_offset = t_h_offset + j;
                        int64_t h_2d_i_j_offset = h_2d_i_offset + j;

                        float r_val = r[t_h_j_offset];
                        float w_val = w[t_h_j_offset];
                        float k_val = k[t_h_j_offset];
                        float b_val = b[t_h_j_offset];
                        float kv_val = v_val * k_val;
                        float prev_state_val = state_prev[h_2d_i_j_offset];
                        state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
                        result += state_cur[h_2d_i_j_offset] * r_val;
                    }
                    dst_data[t_h_i_offset] = result;
                }
            }
        }
    #endif
}

void ggml_compute_forward_rwkv_wkv7(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];

    switch (src0->type) {
        case GGML_TYPE_F32:
            {
                ggml_compute_forward_rwkv_wkv7_f32(params, dst);
            } break;
        default:
            {
                GGML_ABORT("fatal error");
            }
    }
}

// ggml_compute_forward_map_custom1

void ggml_compute_forward_map_custom1(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * a = dst->src[0];

    struct ggml_map_custom1_op_params p;
    memcpy(&p, dst->op_params, sizeof(p));

    p.fun(dst, a, params->ith, params->nth, p.userdata);
}

// ggml_compute_forward_map_custom2

void ggml_compute_forward_map_custom2(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * a = dst->src[0];
    const ggml_tensor * b = dst->src[1];

    struct ggml_map_custom2_op_params p;
    memcpy(&p, dst->op_params, sizeof(p));

    p.fun(dst, a, b, params->ith, params->nth, p.userdata);
}

// ggml_compute_forward_map_custom3

void ggml_compute_forward_map_custom3(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * a = dst->src[0];
    const ggml_tensor * b = dst->src[1];
    const ggml_tensor * c = dst->src[2];

    struct ggml_map_custom3_op_params p;
    memcpy(&p, dst->op_params, sizeof(p));

    p.fun(dst, a, b, c, params->ith, params->nth, p.userdata);
}

// ggml_compute_forward_custom

void ggml_compute_forward_custom(
        const struct ggml_compute_params * params,
        struct ggml_tensor * dst) {

    struct ggml_custom_op_params p;
    memcpy(&p, dst->op_params, sizeof(p));

    p.fun(dst, params->ith, params->nth, p.userdata);
}

// ggml_compute_forward_cross_entropy_loss

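// For reference, this computes a single scalar:
//
//   dst = -1/nr * sum_rows sum_i src1[i] * log_softmax(src0)[i]
//
// where nr is the number of rows; each thread accumulates its rows into sums[ith]
// and thread 0 reduces and scales the result at the end.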
static void ggml_compute_forward_cross_entropy_loss_f32(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
    GGML_ASSERT(ggml_are_same_shape(src0, src1));
    GGML_ASSERT(ggml_is_scalar(dst));
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    // TODO: handle transposed/permuted matrices
    const int64_t nc = src0->ne[0];
    const int64_t nr = ggml_nrows(src0);

    const int ith = params->ith;
    const int nth = params->nth;

    float * sums = (float *) params->wdata;
    float * st   = ((float *) params->wdata) + nth + ith*nc;
    float sum_thread = 0.0f;

    GGML_ASSERT(params->wsize >= sizeof(float) * (nth + nth * nc));

    // rows per thread
    const int64_t dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int64_t ir0 = dr*ith;
    const int64_t ir1 = MIN(ir0 + dr, nr);

    for (int64_t i1 = ir0; i1 < ir1; ++i1) {
        const float * s0 = (const float *)((const char *) src0->data + i1*src0->nb[1]);
        const float * s1 = (const float *)((const char *) src1->data + i1*src1->nb[1]);

#ifndef NDEBUG
        for (int64_t i = 0; i < nc; ++i) {
            //printf("p[%d] = %f\n", i, p[i]);
            assert(!isnan(s0[i]));
            assert(!isnan(s1[i]));
        }
#endif

        float max = -INFINITY;
        ggml_vec_max_f32(nc, &max, s0);
        const ggml_float sum_softmax = ggml_vec_log_soft_max_f32(nc, st, s0, max);
        assert(sum_softmax >= 0.0);

        ggml_vec_add1_f32(nc, st, st, -sum_softmax);
        ggml_vec_mul_f32(nc, st, st, s1);

        float sum_st = 0.0f;
        ggml_vec_sum_f32(nc, &sum_st, st);
        sum_thread += sum_st;

#ifndef NDEBUG
        for (int64_t i = 0; i < nc; ++i) {
            assert(!isnan(st[i]));
            assert(!isinf(st[i]));
        }
#endif
    }
    sums[ith] = sum_thread;
    ggml_barrier(params->threadpool);

    if (ith == 0) {
        float * dp = (float *) dst->data;
        ggml_vec_sum_f32(nth, dp, sums);
        dp[0] *= -1.0f / (float) nr;
    }
}

void ggml_compute_forward_cross_entropy_loss(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];

    switch (src0->type) {
        case GGML_TYPE_F32:
            {
                ggml_compute_forward_cross_entropy_loss_f32(params, dst);
            } break;
        default:
            {
                GGML_ABORT("fatal error");
            }
    }
}

// ggml_compute_forward_cross_entropy_loss_back

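// For reference: with L = -1/nr * sum_rows sum_i src1f[i] * log_softmax(src0f)[i],
// and assuming each row of src1f sums to 1 (e.g. one-hot or probability targets),
// the gradient per row reduces to
//
//   dL/dsrc0f = (softmax(src0f) - src1f) * grad / nr
//
// which is what the loop below computes.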
static void ggml_compute_forward_cross_entropy_loss_back_f32(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * grad  = dst->src[0]; // gradient of forward pass output
    const ggml_tensor * src0f = dst->src[1]; // src0 of forward pass
    const ggml_tensor * src1f = dst->src[2]; // src1 of forward pass

    GGML_ASSERT(ggml_is_contiguous(dst));
    GGML_ASSERT(ggml_is_contiguous(src0f));
    GGML_ASSERT(ggml_is_contiguous(src1f));
    GGML_ASSERT(ggml_is_contiguous(grad));
    GGML_ASSERT(ggml_are_same_shape(src0f, src1f) && ggml_are_same_shape(src0f, dst));

    const int64_t ith = params->ith;
    const int64_t nth = params->nth;

    // TODO: handle transposed/permuted matrices
    const int64_t nc = src0f->ne[0];
    const int64_t nr = ggml_nrows(src0f);

    // rows per thread
    const int64_t dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int64_t ir0 = dr*ith;
    const int64_t ir1 = MIN(ir0 + dr, nr);

    const float d_by_nr = ((const float *) grad->data)[0] / (float) nr;

    for (int64_t i1 = ir0; i1 < ir1; i1++) {
        float       * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]);
        const float * s0  = (const float *)((const char *) src0f->data + i1*src0f->nb[1]);
        const float * s1  = (const float *)((const char *) src1f->data + i1*src1f->nb[1]);

#ifndef NDEBUG
        for (int64_t i = 0; i < nc; ++i) {
            //printf("p[%d] = %f\n", i, p[i]);
            assert(!isnan(s0[i]));
            assert(!isnan(s1[i]));
        }
#endif

        // soft_max
        float max = -INFINITY;
        ggml_vec_max_f32(nc, &max, s0);
        const ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
        assert(sum > 0.0);
        ggml_vec_scale_f32(nc, ds0, 1.0/sum);

        // grad(src0f) = (softmax(src0f) - src1f) * grad(cross_entropy_loss(src0f, src1f)) / nr
        ggml_vec_sub_f32(nc, ds0, ds0, s1);
        ggml_vec_scale_f32(nc, ds0, d_by_nr);

#ifndef NDEBUG
        for (int64_t i = 0; i < nc; ++i) {
            assert(!isnan(ds0[i]));
            assert(!isinf(ds0[i]));
        }
#endif
    }
}

void ggml_compute_forward_cross_entropy_loss_back(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];

    switch (src0->type) {
        case GGML_TYPE_F32:
            {
                ggml_compute_forward_cross_entropy_loss_back_f32(params, dst);
            } break;
        default:
            {
                GGML_ABORT("fatal error");
            }
    }
}

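// AdamW update, as implemented below (one step per call):
//
//   m = beta1*m + (1 - beta1)*g
//   v = beta2*v + (1 - beta2)*g^2
//   w = w*(1 - alpha*wd) - alpha * (m*beta1h) / (sqrt(v*beta2h) + eps)
//
// beta1h and beta2h are passed in via adamw_params and presumably hold the
// bias-correction factors for the current step, so they are not recomputed here.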
static void ggml_compute_forward_opt_step_adamw_f32(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0         = dst->src[0];
    const ggml_tensor * src0_grad    = dst->src[1];
    const ggml_tensor * src0_grad_m  = dst->src[2];
    const ggml_tensor * src0_grad_v  = dst->src[3];
    const ggml_tensor * adamw_params = dst->src[4];

    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m));
    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v));
    GGML_ASSERT(ggml_nelements(adamw_params) == 7);

    const int ith = params->ith;
    const int nth = params->nth;

    const int nr = ggml_nrows(src0);

    GGML_TENSOR_UNARY_OP_LOCALS
    GGML_ASSERT(nb00 == sizeof(float));

    // rows per thread
    const int dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int ir0 = dr*ith;
    const int ir1 = MIN(ir0 + dr, nr);

    const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);

    const float alpha  = adamw_params_ptr[0];
    const float beta1  = adamw_params_ptr[1];
    const float beta2  = adamw_params_ptr[2];
    const float eps    = adamw_params_ptr[3];
    const float wd     = adamw_params_ptr[4];
    const float beta1h = adamw_params_ptr[5];
    const float beta2h = adamw_params_ptr[6];
    const float keep   = 1.f - alpha * wd;

    for (int ir = ir0; ir < ir1; ++ir) {
        const int64_t i03 = ir/(ne02*ne01);
        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);

        const size_t offset = i03*nb03 + i02*nb02 + i01*nb01;

        float       * w = (float *)       ((char *)       src0->data        + offset); // weight
        const float * g = (const float *) ((const char *) src0_grad->data   + offset); // grad
        float       * m = (float *)       ((char *)       src0_grad_m->data + offset);
        float       * v = (float *)       ((char *)       src0_grad_v->data + offset);

        for (int i00 = 0; i00 < ne00; ++i00) {
            m[i00] = m[i00]*beta1 + g[i00]*(1.0f - beta1);
            v[i00] = v[i00]*beta2 + g[i00]*g[i00]*(1.0f - beta2);

            const float mh = m[i00]*beta1h;
            const float vh = sqrtf(v[i00]*beta2h) + eps;

            // The weight decay is applied independently of the Adam momenta m and v.
            // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
            // See: https://arxiv.org/pdf/1711.05101v3.pdf
            w[i00] = w[i00] * keep - alpha * mh / vh;
        }
    }
}

void ggml_compute_forward_opt_step_adamw(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];

    switch (src0->type) {
        case GGML_TYPE_F32:
            {
                ggml_compute_forward_opt_step_adamw_f32(params, dst);
            } break;
        default:
            {
                GGML_ABORT("fatal error");
            }
    }
}

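// SGD with decoupled weight decay, as implemented below:
//
//   w = w*(1 - alpha*wd) - alpha*g
//
// where alpha and wd are read from sgd_params.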
static void ggml_compute_forward_opt_step_sgd_f32(const ggml_compute_params * params, ggml_tensor * dst) {
    const ggml_tensor * src0       = dst->src[0];
    const ggml_tensor * src0_grad  = dst->src[1];
    const ggml_tensor * sgd_params = dst->src[2];

    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
    GGML_ASSERT(ggml_nelements(sgd_params) == 2);

    const int ith = params->ith;
    const int nth = params->nth;

    const int nr = ggml_nrows(src0);

    GGML_TENSOR_UNARY_OP_LOCALS
    GGML_ASSERT(nb00 == sizeof(float));

    // rows per thread
    const int dr = (nr + nth - 1) / nth;

    // row range for this thread
    const int ir0 = dr * ith;
    const int ir1 = MIN(ir0 + dr, nr);

    // only the subset of the AdamW parameters that SGD also needs is used here (alpha, wd);
    // a dedicated parameter struct could be introduced instead
    const float * sgd_params_ptr = ggml_get_data_f32(sgd_params);
    const float   alpha          = sgd_params_ptr[0];
    const float   keep           = 1.f - alpha * sgd_params_ptr[1];

    for (int ir = ir0; ir < ir1; ++ir) {
        const int64_t i03 = ir / (ne02 * ne01);
        const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01;
        const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01);

        const size_t offset = i03 * nb03 + i02 * nb02 + i01 * nb01;

        float       * w = (float *)       ((char *)       src0->data      + offset); // weight
        const float * g = (const float *) ((const char *) src0_grad->data + offset); // grad

        for (int i00 = 0; i00 < ne00; ++i00) {
            w[i00] = w[i00] * keep - alpha * g[i00];
        }
    }
}

void ggml_compute_forward_opt_step_sgd(const ggml_compute_params * params, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];

    switch (src0->type) {
        case GGML_TYPE_F32:
            {
                ggml_compute_forward_opt_step_sgd_f32(params, dst);
            }
            break;
        default:
            {
                GGML_ABORT("fatal error - sgd is F32 only");
            }
    }
}