1#define GGML_COMMON_IMPL_CPP
2#define GGML_COMMON_DECL_CPP
3#include "ggml-common.h"
4#include "ggml-backend-impl.h"
5
6#include "ggml-impl.h"
7#include "ggml-cpu.h"
8#include "ggml-cpu-impl.h"
9#include "simd-mappings.h"
10#include "traits.h"
11
12#include "arch-fallback.h"
13
14#include <cmath>
15#include <cstring>
16#include <cassert>
17#include <cstdio> // for GGML_ASSERT
18
19#include "repack.h"
20
21#if defined(__GNUC__)
22#pragma GCC diagnostic ignored "-Woverlength-strings"
23#endif
24
25#define UNUSED GGML_UNUSED
26
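// Round-to-nearest trick: adding 2^23 + 2^22 = 12582912.0f to a float whose magnitude is below
// 2^22 pushes the rounded integer into the low mantissa bits, so a single add plus a bit mask
// replaces a call to lrintf(). For example, fval = 3.3f gives val = 12582915.0f after rounding;
// its low 23 mantissa bits are 0x400003, and subtracting 0x00400000 recovers 3.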
static inline int nearest_int(float fval) {
    assert(fabsf(fval) <= 4194303.f);
    float val = fval + 12582912.f;
    int i; memcpy(&i, &val, sizeof(int));
    return (i & 0x007fffff) - 0x00400000;
}
33
34// Functions to create the interleaved data layout formats
35
36// interleave 4 block_q4_0s in blocks of blck_size_interleave
37// returns an interleaved block_q4_0x4
38// in the interleaved block_q4_0x4, place deltas for 4 block_q4_0 blocks
39// first, then interleave quants from 4 block_q4_0s in blocks of blck_size_interleave
40//
41// - in : an array of block_q4_0 pointers
42// - blck_size_interleave : the block_q4_0 quants bytes are interleaved in blocks of
43// blck_size_interleave bytes
44// - xor_mask : the mask to convert the nibbles in block_q4_0 quants bytes
45// from bias offset form to pure sign form (this saves subtract
//                          operations during unpacking)
47//
48
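// For reference, the interleaved layouts produced below look roughly like this (repack.h has the
// authoritative definitions; the field lists here are only a sketch):
//
//   block_q4_0x4: ggml_half d[4];  uint8_t qs[QK4_0 * 2];  // 4 deltas, then interleaved nibbles
//   block_q4_0x8: ggml_half d[8];  uint8_t qs[QK4_0 * 4];  // 8 deltas, then interleaved nibbles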
49extern "C" {
50
51void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
52 assert(QK8_0 == 32);
53 assert(k % QK8_0 == 0);
54 const int nb = k / QK8_0;
55
56 block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy;
57
58 // scalar
59 const int blck_size_interleave = 4;
60 float srcv[4][QK8_0];
61 float id[4];
62
63 for (int i = 0; i < nb; i++) {
64 for (int row_iter = 0; row_iter < 4; row_iter++) {
65 float amax = 0.0f; // absolute max
66
67 for (int j = 0; j < QK8_0; j++) {
68 srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j];
69 amax = MAX(amax, fabsf(srcv[row_iter][j]));
70 }
71
72 const float d = amax / ((1 << 7) - 1);
73 id[row_iter] = d ? 1.0f / d : 0.0f;
74
75 y[i].d[row_iter] = GGML_CPU_FP32_TO_FP16(d);
76 }
77
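        // Interleave the four quantized rows in 4-byte chunks: output index j reads
        // srcv[src_id][src_offset] with src_id = (j % 16) / 4 and src_offset = (j / 16) * 4 + j % 4,
        // e.g. j = 21 reads srcv[1][5].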
78 for (int j = 0; j < QK8_0 * 4; j++) {
79 int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
80 int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
81 src_offset += (j % blck_size_interleave);
82
83 float x0 = srcv[src_id][src_offset] * id[src_id];
            y[i].qs[j] = roundf(x0);
85 }
86 }
87}
88
89void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
90 assert(QK8_0 == 32);
91 assert(k % QK8_0 == 0);
92 const int nb = k / QK8_0;
93
94 block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy;
95
96 // scalar
97 const int blck_size_interleave = 8;
98 float srcv[4][QK8_0];
99 float id[4];
100
101 for (int i = 0; i < nb; i++) {
102 for (int row_iter = 0; row_iter < 4; row_iter++) {
103 float amax = 0.0f; // absolute max
104
105 for (int j = 0; j < QK8_0; j++) {
106 srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j];
107 amax = MAX(amax, fabsf(srcv[row_iter][j]));
108 }
109
110 const float d = amax / ((1 << 7) - 1);
111 id[row_iter] = d ? 1.0f / d : 0.0f;
112
113 y[i].d[row_iter] = GGML_CPU_FP32_TO_FP16(d);
114 }
115
116 for (int j = 0; j < QK8_0 * 4; j++) {
117 int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
118 int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
119 src_offset += (j % blck_size_interleave);
120
121 float x0 = srcv[src_id][src_offset] * id[src_id];
            y[i].qs[j] = roundf(x0);
123 }
124 }
125}
126
127void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
128 assert(QK_K == 256);
129 assert(k % QK_K == 0);
130 const int nb = k / QK_K;
131
132 block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy;
133
134 // scalar
135 const int blck_size_interleave = 8;
136 float srcv[4][QK_K];
137 float iscale[4];
138
139 for (int i = 0; i < nb; i++) {
140 for (int row_iter = 0; row_iter < 4; row_iter++) {
141 float amax = 0.0f; // absolute max
142 float max = 0;
143
144 for (int j = 0; j < QK_K; j++) {
145 srcv[row_iter][j] = x[row_iter * k + i * QK_K + j];
146 // Update the maximum value of the corresponding super block
                if (amax < fabsf(srcv[row_iter][j])) {
                    amax = fabsf(srcv[row_iter][j]);
149 max = srcv[row_iter][j];
150 }
151 }
152
153 iscale[row_iter] = amax ? -127.f/max : 0;
154
155 y[i].d[row_iter] = amax ? 1/iscale[row_iter] : 0;
156 }
157
158 for (int j = 0; j < QK_K / 4; j++) {
159 y[i].bsums[j] = 0;
160 }
161
        // Quant values are interleaved in sequences of eight bytes from the corresponding super blocks
        // Bsums values are interleaved in sequences of four bsums from each super block taken for interleaving,
        // i.e. the first four bsums from the first super block, followed by the first four bsums from the second super block, and so on
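        // For example, j = 300 reads srcv[1][76] (row 1, quant 76); that quant belongs to the fifth
        // 16-quant group of row 1, so it is accumulated into bsums[20] = 16 * 1 (quarter) + 4 * 1 (row) + 0.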
165 for (int j = 0; j < QK_K * 4; j++) {
166 int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
167 int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
168 src_offset += (j % blck_size_interleave);
169 int index = (((j & 31) >> 3) << 2) + ((j >> 8) << 4) + ((j >> 6) & 3);
170
171 float x0 = srcv[src_id][src_offset] * iscale[src_id];
            y[i].qs[j] = nearest_int(x0);
173 y[i].bsums[index] += y[i].qs[j];
174 }
175 }
176}
177
178} // extern "C"
179
180template <int64_t INTER_SIZE, ggml_type PARAM_TYPE>
181void ggml_quantize_mat_t(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row);
182
183template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
184 assert(nrow == 4);
185 UNUSED(nrow);
    ggml_quantize_mat_q8_0_4x4(x, vy, n_per_row);
187}
188
189template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
190 assert(nrow == 4);
191 UNUSED(nrow);
    ggml_quantize_mat_q8_0_4x8(x, vy, n_per_row);
193}
194
195template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
196 assert(nrow == 4);
197 UNUSED(nrow);
    ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row);
199}
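// The ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE> specializations above only forward to the matching
// 4-row quantizer; forward_mul_mat() below selects the specialization through the tensor_traits
// template parameters, roughly as in this sketch (argument names are illustrative only):
//
//   ggml_quantize_mat_t<8, GGML_TYPE_Q8_0>(src1_rows_f32, dst_q8_0x4, /*nrow=*/4, /*n_per_row=*/ne10);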
200
201extern "C" {
202
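// Generic GEMV kernels. s points to the output row, vx to the repacked (interleaved) weights, vy to
// the quantized activations; nr is the number of activation rows (1 for gemv) and nc the number of
// output columns. The nibble trick used below relies on the repacked q4_0 nibbles being stored in
// signed (two's complement) form: (int8_t)(byte << 4) yields 16 * low_nibble and
// (int8_t)(byte & 0xF0) yields 16 * high_nibble, and the trailing ">> 4" drops the factor of 16
// after the two products have been summed.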
203void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
204 const int qk = QK8_0;
205 const int nb = n / qk;
206 const int ncols_interleaved = 4;
207 const int blocklen = 4;
208
209 assert(nr == 1);
210 assert(n % qk == 0);
211 assert(nc % ncols_interleaved == 0);
212
213 UNUSED(s);
214 UNUSED(bs);
215 UNUSED(vx);
216 UNUSED(vy);
217 UNUSED(nr);
218 UNUSED(nc);
219 UNUSED(nb);
220 UNUSED(ncols_interleaved);
221 UNUSED(blocklen);
222
223 float sumf[4];
224 int sumi;
225
226 const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
227 for (int x = 0; x < nc / ncols_interleaved; x++) {
228 const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
229
230 for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
231 for (int l = 0; l < nb; l++) {
232 for (int k = 0; k < (qk / (2 * blocklen)); k++) {
233 for (int j = 0; j < ncols_interleaved; j++) {
234 sumi = 0;
235 for (int i = 0; i < blocklen; ++i) {
236 const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
237 const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
238 sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
239 }
240 sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
241 }
242 }
243 }
244 for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
245 }
246}
247
248void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
249 const int qk = QK8_0;
250 const int nb = n / qk;
251 const int ncols_interleaved = 4;
252 const int blocklen = 8;
253
254 assert (n % qk == 0);
255 assert (nc % ncols_interleaved == 0);
256
257 UNUSED(s);
258 UNUSED(bs);
259 UNUSED(vx);
260 UNUSED(vy);
261 UNUSED(nr);
262 UNUSED(nc);
263 UNUSED(nb);
264 UNUSED(ncols_interleaved);
265 UNUSED(blocklen);
266
267 float sumf[4];
268 int sumi;
269
270 const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
271 for (int x = 0; x < nc / ncols_interleaved; x++) {
272 const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
273
274 for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
275 for (int l = 0; l < nb; l++) {
276 for (int k = 0; k < (qk / (2 * blocklen)); k++) {
277 for (int j = 0; j < ncols_interleaved; j++) {
278 sumi = 0;
279 for (int i = 0; i < blocklen; ++i) {
280 const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
281 const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
282 sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
283 }
284 sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
285 }
286 }
287 }
288 for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
289 }
290}
291
292void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
293 const int qk = QK8_0;
294 const int nb = n / qk;
295 const int ncols_interleaved = 8;
296 const int blocklen = 8;
297
298 assert (n % qk == 0);
299 assert (nc % ncols_interleaved == 0);
300
301 UNUSED(s);
302 UNUSED(bs);
303 UNUSED(vx);
304 UNUSED(vy);
305 UNUSED(nr);
306 UNUSED(nc);
307 UNUSED(nb);
308 UNUSED(ncols_interleaved);
309 UNUSED(blocklen);
310
311 float sumf[8];
312 int sumi;
313
314 const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
315 for (int x = 0; x < nc / ncols_interleaved; x++) {
316 const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
317
318 for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
319 for (int l = 0; l < nb; l++) {
320 for (int k = 0; k < (qk / (2 * blocklen)); k++) {
321 for (int j = 0; j < ncols_interleaved; j++) {
322 sumi = 0;
323 for (int i = 0; i < blocklen; ++i) {
324 const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
325 const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
326 sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
327 }
328 sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
329 }
330 }
331 }
332 for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
333 }
334}
335
336void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
337 const int qk = QK_K;
338 const int nb = n / qk;
339 const int ncols_interleaved = 8;
340 const int blocklen = 8;
341 static const uint32_t kmask1 = 0x3f3f3f3f;
342 static const uint32_t kmask2 = 0x0f0f0f0f;
343 static const uint32_t kmask3 = 0x03030303;
344
345 assert (n % qk == 0);
346 assert (nc % ncols_interleaved == 0);
347
348 UNUSED(s);
349 UNUSED(bs);
350 UNUSED(vx);
351 UNUSED(vy);
352 UNUSED(nr);
353 UNUSED(nc);
354 UNUSED(nb);
355 UNUSED(ncols_interleaved);
356 UNUSED(blocklen);
357
358 float sumf[8];
359 float sum_minf[8];
360 uint32_t utmp[32];
361 int sumi1;
362 int sumi2;
363 int sumi;
364
365 const block_q8_K * a_ptr = (const block_q8_K *) vy;
366 for (int x = 0; x < nc / ncols_interleaved; x++) {
367 const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb);
368
369 for (int j = 0; j < ncols_interleaved; j++) {
370 sumf[j] = 0.0;
371 sum_minf[j] = 0.0;
372 }
373 for (int l = 0; l < nb; l++) {
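            // b_ptr[l].scales packs, for every sub block, the 6-bit scales and mins of all 8
            // interleaved columns into 12 bytes; the kmask shuffle below expands each group into
            // 16 bytes of utmp: 8 scale bytes followed by 8 min bytes, one byte per column.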
374 for (int sb = 0; sb < 8; sb++) {
                memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
376 utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
377 const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
378 utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
379 utmp[sb * 4 + 2] = uaux_0;
380 utmp[sb * 4 + 0] &= kmask1;
381 }
382 for (int k = 0; k < (qk / (2 * blocklen)); k++) {
383 uint8_t *scales_0 = (uint8_t*) utmp + (k / 4) * 32;
384 uint8_t *scales_1 = (uint8_t*) utmp + (k / 4) * 32 + 16;
385 for (int j = 0; j < ncols_interleaved; j++) {
386 sumi1 = 0;
387 sumi2 = 0;
388 sumi = 0;
389 for (int i = 0; i < blocklen; ++i) {
390 const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF);
391 const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
392 sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 64 + (k % 4) * blocklen + i]);
393 sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 64 + (k % 4) * blocklen + i + 32]);
394 sumi1 = sumi1 * scales_0[j];
395 sumi2 = sumi2 * scales_1[j];
396 sumi += sumi1 + sumi2;
397 }
398 sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
399 }
400 }
401 for (int sb = 0; sb < 8; sb++) {
402 uint8_t *mins = (uint8_t*) utmp + 8 + sb * 16;
403 for (int j = 0; j < ncols_interleaved; j++) {
404 sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
405 }
406 }
407 }
408 for (int j = 0; j < ncols_interleaved; j++) {
409 s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
410 }
411 }
412}
413
414void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
415 const int qk = QK_K;
416 const int nb = n / qk;
417 const int ncols_interleaved = 8;
418 const int blocklen = 8;
419
420 assert (n % qk == 0);
421 assert (nc % ncols_interleaved == 0);
422
423 UNUSED(s);
424 UNUSED(bs);
425 UNUSED(vx);
426 UNUSED(vy);
427 UNUSED(nr);
428 UNUSED(nc);
429 UNUSED(nb);
430 UNUSED(ncols_interleaved);
431 UNUSED(blocklen);
432
433 float sumf[8];
434 float sum_minf[8];
435 int sumi1,sumi2,sumi3,sumi4;
436 int sumi;
437
438 const block_q8_K * a_ptr = (const block_q8_K *)vy;
439 for(int x = 0; x < nc / ncols_interleaved; x++) {
440 const block_q2_Kx8 * b_ptr = (const block_q2_Kx8 *) vx + (x * nb);
441 for (int j = 0; j < ncols_interleaved; j++) {
442 sumf[j] = 0.0;
443 sum_minf[j] = 0.0;
444 }
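        // b_ptr[l].scales holds 8 groups of 16 bytes: group g stores, for each of the 8 interleaved
        // columns, the two scale/min bytes of sub blocks 2*g and 2*g + 1 (low nibble = scale,
        // high nibble = min), which is why the mins are recovered below with a 4-bit right shift.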
445 for (int l = 0; l < nb; l++) {
446 for (int k = 0; k < (qk / (4 * blocklen)); k++) {
447 const uint8_t *scales_0 = b_ptr[l].scales + (k / 4) * 64 ;
448 const uint8_t *scales_1 = b_ptr[l].scales + (k / 4) * 64 + 16;
449 const uint8_t *scales_2 = b_ptr[l].scales + (k / 4) * 64 + 32;
450 const uint8_t *scales_3 = b_ptr[l].scales + (k / 4) * 64 + 48;
451 for (int j = 0; j < ncols_interleaved; j++) {
452 sumi1 = 0;
453 sumi2 = 0;
454 sumi3 = 0;
455 sumi4 = 0;
456 sumi = 0;
457 int offset = ((k / 2) % 2) + j * 2;
458 for (int i = 0; i < blocklen; ++i){
459 const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 3);
460 const int v1 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 2 ) & 3);
461 const int v2 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4 ) & 3);
462 const int v3 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 6 ) & 3);
463 sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i]);
464 sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 32]);
465 sumi3 = (v2 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 64]);
466 sumi4 = (v3 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 96]);
467
468 sumi1 = sumi1 * (scales_0[offset] & 0xF);
469 sumi2 = sumi2 * (scales_1[offset] & 0xF);
470 sumi3 = sumi3 * (scales_2[offset] & 0xF);
471 sumi4 = sumi4 * (scales_3[offset] & 0xF);
472 sumi += sumi1 + sumi2 + sumi3 + sumi4;
473 }
474 sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
475 }
476 }
477 for(int sb = 0; sb < 8; sb++) {
478 const uint8_t *mins = b_ptr[l].scales + sb * 16;
479 for(int j = 0; j < ncols_interleaved; j++){
480 sum_minf[j] += ((mins[j * 2] >> 4) * a_ptr[l].bsums[sb * 2] + (mins[(j * 2)+ 1] >> 4) * a_ptr[l].bsums[sb * 2 + 1]) * GGML_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
481 }
482 }
483 }
484 for (int j = 0; j < ncols_interleaved; j++) {
485 s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
486 }
487 }
488}
489
490void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
491 const int qk = QK8_0;
492 const int nb = n / qk;
493 const int ncols_interleaved = 4;
494 const int blocklen = 4;
495
496 assert(nr == 1);
497 assert(n % qk == 0);
498 assert(nc % ncols_interleaved == 0);
499
500 UNUSED(bs);
501 UNUSED(nr);
502
503 float sumf[4];
504 int sumi;
505
506 const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
507 for (int x = 0; x < nc / ncols_interleaved; x++) {
508 const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
509
510 for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
511 for (int l = 0; l < nb; l++) {
512 for (int k = 0; k < (qk / (2 * blocklen)); k++) {
513 for (int j = 0; j < ncols_interleaved; j++) {
514 sumi = 0;
515 for (int i = 0; i < blocklen; ++i) {
516 const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
517 const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
518 sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2]));
519 }
520 sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
521 }
522 }
523 }
524 for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
525 }
526}
527
528void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
529 const int qk = QK8_0;
530 const int nb = n / qk;
531 const int ncols_interleaved = 8;
532 const int blocklen = 8;
533
534 assert(nr == 1);
535 assert(n % qk == 0);
536 assert(nc % ncols_interleaved == 0);
537
538 UNUSED(bs);
539 UNUSED(nr);
540
541 float sumf[8];
542 int sumi;
543
544 const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
545 for (int x = 0; x < nc / ncols_interleaved; x++) {
546 const block_iq4_nlx8 * b_ptr = (const block_iq4_nlx8 *) vx + (x * nb);
547
548 for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
549 for (int l = 0; l < nb; l++) {
550 for (int k = 0; k < (qk / (2 * blocklen)); k++) {
551 for (int j = 0; j < ncols_interleaved; j++) {
552 sumi = 0;
553 for (int i = 0; i < blocklen; ++i) {
554 const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
555 const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
556 sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2]));
557 }
558 sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
559 }
560 }
561 }
562 for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
563 }
564}
565
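// Generic GEMM kernels: the inner math is identical to the GEMV kernels above, but the activations
// arrive as block_q8_0x4 / block_q8_Kx4 tiles of 4 rows, so each iteration produces a
// 4 x ncols_interleaved tile of the output, written with row stride bs.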
566void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
567 const int qk = QK8_0;
568 const int nb = n / qk;
569 const int ncols_interleaved = 4;
570 const int blocklen = 4;
571
572 assert (n % qk == 0);
573 assert (nr % 4 == 0);
574 assert (nc % ncols_interleaved == 0);
575
576 UNUSED(s);
577 UNUSED(bs);
578 UNUSED(vx);
579 UNUSED(vy);
580 UNUSED(nr);
581 UNUSED(nc);
582 UNUSED(nb);
583 UNUSED(ncols_interleaved);
584 UNUSED(blocklen);
585
586 {
587 float sumf[4][4];
588 int sumi;
589
590 for (int y = 0; y < nr / 4; y++) {
591 const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
592 for (int x = 0; x < nc / ncols_interleaved; x++) {
593 const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
594 for (int m = 0; m < 4; m++) {
595 for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
596 }
597 for (int l = 0; l < nb; l++) {
598 for (int k = 0; k < (qk / (2 * blocklen)); k++) {
599 for (int m = 0; m < 4; m++) {
600 for (int j = 0; j < ncols_interleaved; j++) {
601 sumi = 0;
602 for (int i = 0; i < blocklen; ++i) {
603 const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
604 const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
605 sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
606 (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
607 }
608 sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
609 }
610 }
611 }
612 }
613 for (int m = 0; m < 4; m++) {
614 for (int j = 0; j < ncols_interleaved; j++)
615 s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
616 }
617 }
618 }
619 }
620}
621
622void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
623 const int qk = QK8_0;
624 const int nb = n / qk;
625 const int ncols_interleaved = 4;
626 const int blocklen = 8;
627
628 assert (n % qk == 0);
629 assert (nr % 4 == 0);
630 assert (nc % ncols_interleaved == 0);
631
632 UNUSED(s);
633 UNUSED(bs);
634 UNUSED(vx);
635 UNUSED(vy);
636 UNUSED(nr);
637 UNUSED(nc);
638 UNUSED(nb);
639 UNUSED(ncols_interleaved);
640 UNUSED(blocklen);
641
642 float sumf[4][4];
643 int sumi;
644
645 for (int y = 0; y < nr / 4; y++) {
646 const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
647 for (int x = 0; x < nc / ncols_interleaved; x++) {
648 const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
649 for (int m = 0; m < 4; m++) {
650 for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
651 }
652 for (int l = 0; l < nb; l++) {
653 for (int k = 0; k < (qk / (2 * blocklen)); k++) {
654 for (int m = 0; m < 4; m++) {
655 for (int j = 0; j < ncols_interleaved; j++) {
656 sumi = 0;
657 for (int i = 0; i < blocklen; ++i) {
658 const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
659 const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
660 sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
661 (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
662 }
663 sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
664 }
665 }
666 }
667 }
668 for (int m = 0; m < 4; m++) {
669 for (int j = 0; j < ncols_interleaved; j++)
670 s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
671 }
672 }
673 }
674}
675
676void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
677 const int qk = QK8_0;
678 const int nb = n / qk;
679 const int ncols_interleaved = 8;
680 const int blocklen = 8;
681
682 assert (n % qk == 0);
683 assert (nr % 4 == 0);
684 assert (nc % ncols_interleaved == 0);
685
686 UNUSED(s);
687 UNUSED(bs);
688 UNUSED(vx);
689 UNUSED(vy);
690 UNUSED(nr);
691 UNUSED(nc);
692 UNUSED(nb);
693 UNUSED(ncols_interleaved);
694 UNUSED(blocklen);
695
696 float sumf[4][8];
697 int sumi;
698
699 for (int y = 0; y < nr / 4; y++) {
700 const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
701 for (int x = 0; x < nc / ncols_interleaved; x++) {
702 const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
703 for (int m = 0; m < 4; m++) {
704 for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
705 }
706 for (int l = 0; l < nb; l++) {
707 for (int k = 0; k < (qk / (2 * blocklen)); k++) {
708 for (int m = 0; m < 4; m++) {
709 for (int j = 0; j < ncols_interleaved; j++) {
710 sumi = 0;
711 for (int i = 0; i < blocklen; ++i) {
712 const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
713 const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
714 sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
715 (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
716 }
717 sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
718 }
719 }
720 }
721 }
722 for (int m = 0; m < 4; m++) {
723 for (int j = 0; j < ncols_interleaved; j++)
724 s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
725 }
726 }
727 }
728}
729
730void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
731 const int qk = QK_K;
732 const int nb = n / qk;
733 const int ncols_interleaved = 8;
734 const int blocklen = 8;
735 static const uint32_t kmask1 = 0x3f3f3f3f;
736 static const uint32_t kmask2 = 0x0f0f0f0f;
737 static const uint32_t kmask3 = 0x03030303;
738
739 assert (n % qk == 0);
740 assert (nr % 4 == 0);
741 assert (nc % ncols_interleaved == 0);
742
743 UNUSED(s);
744 UNUSED(bs);
745 UNUSED(vx);
746 UNUSED(vy);
747 UNUSED(nr);
748 UNUSED(nc);
749 UNUSED(nb);
750 UNUSED(ncols_interleaved);
751 UNUSED(blocklen);
752
753 float sumf[4][8];
754 float sum_minf[4][8];
755 uint32_t utmp[32];
756 int sumi1;
757 int sumi2;
758 int sumi;
759
760 for (int y = 0; y < nr / 4; y++) {
761 const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
762 for (int x = 0; x < nc / ncols_interleaved; x++) {
763 const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb);
764 for (int m = 0; m < 4; m++) {
765 for (int j = 0; j < ncols_interleaved; j++) {
766 sumf[m][j] = 0.0;
767 sum_minf[m][j] = 0.0;
768 }
769 }
770 for (int l = 0; l < nb; l++) {
771 for (int sb = 0; sb < 8; sb++) {
                    memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
773 utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
774 const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
775 utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
776 utmp[sb * 4 + 2] = uaux_0;
777 utmp[sb * 4 + 0] &= kmask1;
778 }
779 for (int k = 0; k < (qk / (2 * blocklen)); k++) {
780 uint8_t *scales_0 = (uint8_t*) utmp + (k / 4) * 32;
781 uint8_t *scales_1 = (uint8_t*) utmp + (k / 4) * 32 + 16;
782 for (int m = 0; m < 4; m++) {
783 for (int j = 0; j < ncols_interleaved; j++) {
784 sumi1 = 0;
785 sumi2 = 0;
786 sumi = 0;
787 for (int i = 0; i < blocklen; ++i) {
788 const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF);
789 const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
790 sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 256 + (k % 4) * 4 * blocklen + m * blocklen + i]);
791 sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 256 + (k % 4) * 4 * blocklen + m * blocklen + i + 128]);
792 sumi1 = sumi1 * scales_0[j];
793 sumi2 = sumi2 * scales_1[j];
794 sumi += sumi1 + sumi2;
795 }
796 sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
797 }
798 }
799 }
800 for (int sb = 0; sb < 8; sb++) {
801 uint8_t *mins = (uint8_t*) utmp + 8 + sb * 16;
802 for(int m = 0; m < 4; m++) {
803 const int16_t *bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
804 for(int j = 0; j < ncols_interleaved; j++) {
805 sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
806 }
807 }
808 }
809 }
810 for (int m = 0; m < 4; m++) {
811 for (int j = 0; j < ncols_interleaved; j++) {
812 s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
813 }
814 }
815 }
816 }
817}
818
819void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
820 const int qk = QK_K;
821 const int nb = n / qk;
822 const int ncols_interleaved = 8;
823 const int blocklen = 8;
824
825 assert (n % qk == 0);
826 assert (nr % 4 == 0);
827 assert (nc % ncols_interleaved == 0);
828
829 UNUSED(s);
830 UNUSED(bs);
831 UNUSED(vx);
832 UNUSED(vy);
833 UNUSED(nr);
834 UNUSED(nc);
835 UNUSED(nb);
836 UNUSED(ncols_interleaved);
837 UNUSED(blocklen);
838
839 float sumf[4][8];
840 float sum_minf[4][8];
841 int sumi1, sumi2, sumi3, sumi4;
842 int sumi;
843
844 for (int y = 0; y < nr / 4; y++) {
845 const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
846 for (int x = 0; x < nc / ncols_interleaved; x++) {
847 const block_q2_Kx8 * b_ptr = (const block_q2_Kx8 *) vx + (x * nb);
848 for (int m = 0; m < 4; m++) {
849 for (int j = 0; j < ncols_interleaved; j++) {
850 sumf[m][j] = 0.0;
851 sum_minf[m][j] = 0.0;
852 }
853 }
854 for (int l = 0; l < nb; l++) {
855 for (int k = 0; k < (qk / (4 * blocklen)); k++) {
856
857 const uint8_t *scales_0 = b_ptr[l].scales + (k / 4) * 64 ;
858 const uint8_t *scales_1 = b_ptr[l].scales + (k / 4) * 64 + 16;
859 const uint8_t *scales_2 = b_ptr[l].scales + (k / 4) * 64 + 32;
860 const uint8_t *scales_3 = b_ptr[l].scales + (k / 4) * 64 + 48;
861 for (int m = 0; m < 4; m++) {
862 for (int j = 0; j < ncols_interleaved; j++) {
863 sumi1 = 0;
864 sumi2 = 0;
865 sumi3 = 0;
866 sumi4 = 0;
867 sumi = 0;
868 int offset = ((k / 2) % 2) + j * 2;
869 for (int i = 0; i < blocklen; ++i){
870 const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 3);
871 const int v1 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 2 ) & 3);
872 const int v2 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4 ) & 3);
873 const int v3 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 6 ) & 3);
874 sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i]);
875 sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i + 128]);
876 sumi3 = (v2 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i + 256]);
877 sumi4 = (v3 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i + 384]);
878 sumi1 = sumi1 * (scales_0[offset] & 0xF);
879 sumi2 = sumi2 * (scales_1[offset] & 0xF);
880 sumi3 = sumi3 * (scales_2[offset] & 0xF);
881 sumi4 = sumi4 * (scales_3[offset] & 0xF);
882 sumi += sumi1 + sumi2 + sumi3 + sumi4;
883 }
884 sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
885 }
886 }
887 }
888 for(int sb = 0; sb < 8; sb++) {
889 const uint8_t *mins = b_ptr[l].scales + sb * 16;
890 for(int m = 0; m < 4; m++) {
891 const int16_t *bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
892 for(int j = 0; j < ncols_interleaved; j++) {
893 int mins_prod = ((mins[j * 2] >> 4) * bsums[0] + (mins[(j * 2)+ 1] >> 4) * bsums[1]);
894 sum_minf[m][j] += (mins_prod) * GGML_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
895 }
896 }
897 }
898 }
899
900 for (int m = 0; m < 4; m++) {
901 for (int j = 0; j < ncols_interleaved; j++) {
902 s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
903 }
904 }
905 }
906 }
907}
908
909
910void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
911 const int qk = QK8_0;
912 const int nb = n / qk;
913 const int ncols_interleaved = 4;
914 const int blocklen = 4;
915
916 assert (n % qk == 0);
917 assert (nr % 4 == 0);
918 assert (nc % ncols_interleaved == 0);
919
920 UNUSED(s);
921 UNUSED(bs);
922 UNUSED(vx);
923 UNUSED(vy);
924 UNUSED(nr);
925 UNUSED(nc);
926 UNUSED(nb);
927 UNUSED(ncols_interleaved);
928 UNUSED(blocklen);
929
930 {
931 float sumf[4][4];
932 int sumi;
933
934 for (int y = 0; y < nr / 4; y++) {
935 const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
936 for (int x = 0; x < nc / ncols_interleaved; x++) {
937 const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
938 for (int m = 0; m < 4; m++) {
939 for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
940 }
941 for (int l = 0; l < nb; l++) {
942 for (int k = 0; k < (qk / (2 * blocklen)); k++) {
943 for (int m = 0; m < 4; m++) {
944 for (int j = 0; j < ncols_interleaved; j++) {
945 sumi = 0;
946 for (int i = 0; i < blocklen; ++i) {
947 const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
948 const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
949 sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
950 (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4]));
951 }
952 sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
953 }
954 }
955 }
956 }
957 for (int m = 0; m < 4; m++) {
958 for (int j = 0; j < ncols_interleaved; j++)
959 s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
960 }
961 }
962 }
963 }
964}
965
966void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
967 const int qk = QK8_0;
968 const int nb = n / qk;
969 const int ncols_interleaved = 8;
970 const int blocklen = 8;
971
972 assert(n % qk == 0);
973 assert(nr % 4 == 0);
974 assert(nc % ncols_interleaved == 0);
975
976 float sumf[4][8];
977 int sumi;
978
979 for (int y = 0; y < nr / 4; y++) {
980 const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
981 for (int x = 0; x < nc / ncols_interleaved; x++) {
982 const block_iq4_nlx8 * b_ptr = (const block_iq4_nlx8 *) vx + (x * nb);
983 for (int m = 0; m < 4; m++) {
984 for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
985 }
986 for (int l = 0; l < nb; l++) {
987 for (int k = 0; k < (qk / (2 * blocklen)); k++) {
988 for (int m = 0; m < 4; m++) {
989 for (int j = 0; j < ncols_interleaved; j++) {
990 sumi = 0;
991 for (int i = 0; i < blocklen; ++i) {
992 const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
993 const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
994 sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
995 (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4]));
996 }
997 sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
998 }
999 }
1000 }
1001 }
1002 for (int m = 0; m < 4; m++) {
1003 for (int j = 0; j < ncols_interleaved; j++)
1004 s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
1005 }
1006 }
1007 }
1008}
1009
1010} // extern "C"
1011
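// The 0x88... xor masks below flip bit 3 of every nibble, turning the bias-offset encoding of q4_0
// (stored value = q + 8) into a plain 4-bit two's complement value. For example, q = -5 is stored
// as 3, and 3 ^ 8 = 0b1011 is -5 in 4-bit two's complement, so the GEMV/GEMM kernels can
// sign-extend the nibbles directly instead of subtracting 8.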
1012static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) {
1013 block_q4_0x4 out;
1014
1015 for (int i = 0; i < 4; i++) {
1016 out.d[i] = in[i].d;
1017 }
1018
1019 const int end = QK4_0 * 2 / blck_size_interleave;
1020
1021 if (blck_size_interleave == 8) {
1022 const uint64_t xor_mask = 0x8888888888888888ULL;
1023 for (int i = 0; i < end; ++i) {
1024 int src_id = i % 4;
1025 int src_offset = (i / 4) * blck_size_interleave;
1026 int dst_offset = i * blck_size_interleave;
1027
1028 uint64_t elems;
1029 // Using memcpy to avoid unaligned memory accesses
            memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
            elems ^= xor_mask;
            memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
1033 }
1034 } else if (blck_size_interleave == 4) {
1035 const uint32_t xor_mask = 0x88888888;
1036 for (int i = 0; i < end; ++i) {
1037 int src_id = i % 4;
1038 int src_offset = (i / 4) * blck_size_interleave;
1039 int dst_offset = i * blck_size_interleave;
1040
1041 uint32_t elems;
            memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint32_t));
            elems ^= xor_mask;
            memcpy(&out.qs[dst_offset], &elems, sizeof(uint32_t));
1045 }
1046 } else {
1047 GGML_ASSERT(false);
1048 }
1049
1050 return out;
1051}
1052
1053// interleave 8 block_q4_0s in blocks of blck_size_interleave
1054// returns an interleaved block_q4_0x8
1055// in the interleaved block_q4_0x8, place deltas for 8 block_q4_0 blocks
1056// first, then interleave quants from 8 block_q4_0s in blocks of blck_size_interleave
1057static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_interleave) {
1058 block_q4_0x8 out;
1059
1060 for (int i = 0; i < 8; i++) {
1061 out.d[i] = in[i].d;
1062 }
1063
1064 const int end = QK4_0 * 4 / blck_size_interleave;
1065 const uint64_t xor_mask = 0x8888888888888888ULL;
1066
1067 for (int i = 0; i < end; ++i) {
1068 int src_id = i % 8;
1069 int src_offset = (i / 8) * blck_size_interleave;
1070 int dst_offset = i * blck_size_interleave;
1071
1072 uint64_t elems;
        memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
        elems ^= xor_mask;
        memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
1076 }
1077
1078 return out;
1079}
1080
1081static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_interleave) {
1082 block_q4_Kx8 out;
    // Delta (scale) and dmin values of the eight Q4_K structures are copied onto the output interleaved structure
1084 for (int i = 0; i < 8; i++) {
1085 out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d;
1086 }
1087
1088 for (int i = 0; i < 8; i++) {
1089 out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin;
1090 }
1091
1092 const int end = QK_K * 4 / blck_size_interleave;
1093
1094 // Interleave Q4_K quants by taking 8 bytes at a time
1095 for (int i = 0; i < end; ++i) {
1096 int src_id = i % 8;
1097 int src_offset = (i / 8) * blck_size_interleave;
1098 int dst_offset = i * blck_size_interleave;
1099
1100 uint64_t elems;
        memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
        memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
1103 }
1104
    // The logic below unpacks and rearranges the scales and mins values of Q4_K
    // The Q4_K structure packs 8 scales and 8 mins into 12 bytes (6 bits for each value)
    // The output Q4_Kx8 structure has 96 bytes for storing the scales and mins
    // Every 12 bytes are packed such that they hold the scales and mins of the corresponding sub block from each of the eight Q4_K structures
    // For example, the first 12 bytes hold the 8 scales and 8 mins of the first sub block of the eight Q4_K structures
1110 uint8_t s[8], m[8];
1111
1112 for (int i = 0; i < 4; i++) {
1113 for (int j = 0; j < 8; j++) {
1114 s[j] = in[j].scales[i] & 63;
1115 m[j] = in[j].scales[i + 4] & 63;
1116 }
1117
1118 out.scales[i * 12] = (s[0] & 63) + ((s[4] & 48) << 2);
1119 out.scales[i * 12 + 1] = (s[1] & 63) + ((s[5] & 48) << 2);
1120 out.scales[i * 12 + 2] = (s[2] & 63) + ((s[6] & 48) << 2);
1121 out.scales[i * 12 + 3] = (s[3] & 63) + ((s[7] & 48) << 2);
1122 out.scales[i * 12 + 4] = (m[0] & 63) + ((m[4] & 48) << 2);
1123 out.scales[i * 12 + 5] = (m[1] & 63) + ((m[5] & 48) << 2);
1124 out.scales[i * 12 + 6] = (m[2] & 63) + ((m[6] & 48) << 2);
1125 out.scales[i * 12 + 7] = (m[3] & 63) + ((m[7] & 48) << 2);
1126 out.scales[i * 12 + 8] = (s[4] & 15) + ((m[4] & 15) << 4);
1127 out.scales[i * 12 + 9] = (s[5] & 15) + ((m[5] & 15) << 4);
1128 out.scales[i * 12 + 10] = (s[6] & 15) + ((m[6] & 15) << 4);
1129 out.scales[i * 12 + 11] = (s[7] & 15) + ((m[7] & 15) << 4);
1130
1131 }
1132
1133 for (int i = 0; i < 4; i++) {
1134 for (int j = 0; j < 8; j++) {
1135 s[j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i+8] & 15);
1136 m[j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i+8] & 240) >> 4);
1137 }
1138
1139 out.scales[i * 12 + 48] = (s[0] & 63) + ((s[4] & 48) << 2);
1140 out.scales[i * 12 + 49] = (s[1] & 63) + ((s[5] & 48) << 2);
1141 out.scales[i * 12 + 50] = (s[2] & 63) + ((s[6] & 48) << 2);
1142 out.scales[i * 12 + 51] = (s[3] & 63) + ((s[7] & 48) << 2);
1143 out.scales[i * 12 + 52] = (m[0] & 63) + ((m[4] & 48) << 2);
1144 out.scales[i * 12 + 53] = (m[1] & 63) + ((m[5] & 48) << 2);
1145 out.scales[i * 12 + 54] = (m[2] & 63) + ((m[6] & 48) << 2);
1146 out.scales[i * 12 + 55] = (m[3] & 63) + ((m[7] & 48) << 2);
1147 out.scales[i * 12 + 56] = (s[4] & 15) + ((m[4] & 15) << 4);
1148 out.scales[i * 12 + 57] = (s[5] & 15) + ((m[5] & 15) << 4);
1149 out.scales[i * 12 + 58] = (s[6] & 15) + ((m[6] & 15) << 4);
1150 out.scales[i * 12 + 59] = (s[7] & 15) + ((m[7] & 15) << 4);
1151
1152 }
1153
1154 return out;
1155}
1156
1157static block_q2_Kx8 make_block_q2_Kx8(block_q2_K * in, unsigned int blck_size_interleave) {
1158 block_q2_Kx8 out;
1159
1160 // Delta(scale) and dmin values of the eight Q2_K structures are copied onto the output interleaved structure
1161 for (int i = 0; i < 8; i++) {
1162 out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d;
1163 }
1164
1165 for (int i = 0; i < 8; i++) {
1166 out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin;
1167 }
1168
1169 const int end = QK_K * 2 / blck_size_interleave;
1170
1171 // Interleave Q2_K quants by taking 8 bytes at a time
1172 for (int i = 0; i < end; ++i) {
1173 int src_id = i % 8;
1174 int src_offset = (i / 8) * blck_size_interleave;
1175 int dst_offset = i * blck_size_interleave;
1176
1177 uint64_t elems;
        memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
        memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
1180 }
1181
    // The logic below unpacks and rearranges the scales and mins values of Q2_K
    // The Q2_K structure packs 16 scales and 16 mins into 16 bytes (4 bits for each value)
    // The output Q2_Kx8 structure has 128 bytes for storing the scales and mins
    // Every 16 bytes are packed such that they hold the scales and mins of the corresponding sub blocks from each of the eight Q2_K structures
    // For example, the first 16 bytes hold the 16 scales and 16 mins of the first and second sub blocks of the eight Q2_K structures
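    // For example, out.scales[37] belongs to group 2 (37 / 16) and column 2 ((37 % 16) / 2), and
    // copies scale byte 5 (2 * 2 + 37 % 2) of that column, i.e. in[2].scales[5].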
1187
1188 for(int i = 0; i < 128; i++){
1189
1190 // Index for selecting which q2k super block
1191 int src1 = (i % 16) / 2;
1192 // Index for selecting scale
1193 int src2 = ((i / 16) * 2) + (i % 2);
1194
1195 out.scales[i] = in[src1].scales[src2];
1196 }
1197 return out;
1198
1199}
1200
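// Repack helpers: each repack_*_bl() walks the source data nrows_interleaved rows at a time and,
// for every block column x, gathers the x-th block of each of those rows into dst_tmp before
// emitting one interleaved block. They return 0 on success and -1 when the tensor shape is not a
// multiple of the interleave granularity.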
1201static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
1202 GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
1203 GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
1204 constexpr int nrows_interleaved = 4;
1205
1206 block_q4_0x4 * dst = (block_q4_0x4 *)t->data;
1207 const block_q4_0 * src = (const block_q4_0 *)data;
1208 block_q4_0 dst_tmp[4];
    int nrow = ggml_nrows(t);
1210 int nblocks = t->ne[0] / QK4_0;
1211
1212 GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));
1213
1214 if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
1215 return -1;
1216 }
1217
1218 for (int b = 0; b < nrow; b += nrows_interleaved) {
1219 for (int64_t x = 0; x < nblocks; x++) {
1220 for (int i = 0; i < nrows_interleaved; i++) {
1221 dst_tmp[i] = src[x + i * nblocks];
1222 }
            *dst++ = make_block_q4_0x4(dst_tmp, interleave_block);
1224 }
1225 src += nrows_interleaved * nblocks;
1226 }
1227 return 0;
1228
1229 GGML_UNUSED(data_size);
1230}
1231static int repack_q4_K_to_q4_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
1232 GGML_ASSERT(t->type == GGML_TYPE_Q4_K);
1233 GGML_ASSERT(interleave_block == 8);
1234 constexpr int nrows_interleaved = 8;
1235
1236 block_q4_Kx8 * dst = (block_q4_Kx8*)t->data;
1237 const block_q4_K * src = (const block_q4_K*) data;
1238 block_q4_K dst_tmp[8];
    int nrow = ggml_nrows(t);
1240 int nblocks = t->ne[0] / QK_K;
1241
1242 GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_K));
1243
1244 if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
1245 return -1;
1246 }
1247
1248 for (int b = 0; b < nrow; b += nrows_interleaved) {
1249 for (int64_t x = 0; x < nblocks; x++) {
1250 for (int i = 0; i < nrows_interleaved; i++ ) {
1251 dst_tmp[i] = src[x + i * nblocks];
1252 }
            *dst++ = make_block_q4_Kx8(dst_tmp, interleave_block);
1254 }
1255 src += nrows_interleaved * nblocks;
1256 }
1257 return 0;
1258
1259 GGML_UNUSED(data_size);
1260}
1261
1262static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
1263 GGML_ASSERT(t->type == GGML_TYPE_Q2_K);
1264 GGML_ASSERT(interleave_block == 8);
1265 constexpr int nrows_interleaved = 8;
1266
1267 block_q2_Kx8 * dst = (block_q2_Kx8*)t->data;
1268 const block_q2_K * src = (const block_q2_K*) data;
1269 block_q2_K dst_tmp[8];
    int nrow = ggml_nrows(t);
1271 int nblocks = t->ne[0] / QK_K;
1272
1273 GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q2_K));
1274
1275 if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
1276 return -1;
1277 }
1278
1279 for (int b = 0; b < nrow; b += nrows_interleaved) {
1280 for (int64_t x = 0; x < nblocks; x++) {
1281 for (int i = 0; i < nrows_interleaved; i++ ) {
1282 dst_tmp[i] = src[x + i * nblocks];
1283 }
            *dst++ = make_block_q2_Kx8(dst_tmp, interleave_block);
1285 }
1286 src += nrows_interleaved * nblocks;
1287 }
1288 return 0;
1289
1290 GGML_UNUSED(data_size);
1291}
1292
1293static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
1294 GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
1295 GGML_ASSERT(interleave_block == 8);
1296 constexpr int nrows_interleaved = 8;
1297
1298 block_q4_0x8 * dst = (block_q4_0x8*)t->data;
1299 const block_q4_0 * src = (const block_q4_0*) data;
1300 block_q4_0 dst_tmp[8];
    int nrow = ggml_nrows(t);
1302 int nblocks = t->ne[0] / QK4_0;
1303
1304 GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));
1305
1306 if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
1307 return -1;
1308 }
1309
1310 for (int b = 0; b < nrow; b += nrows_interleaved) {
1311 for (int64_t x = 0; x < nblocks; x++) {
1312 for (int i = 0; i < nrows_interleaved; i++ ) {
1313 dst_tmp[i] = src[x + i * nblocks];
1314 }
            *dst++ = make_block_q4_0x8(dst_tmp, interleave_block);
1316 }
1317 src += nrows_interleaved * nblocks;
1318 }
1319 return 0;
1320
1321 GGML_UNUSED(data_size);
1322}
1323
1324static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_size_interleave) {
1325 block_iq4_nlx4 out;
1326
1327 for (int i = 0; i < 4; i++) {
1328 out.d[i] = in[i].d;
1329 }
1330
1331 const int end = QK4_NL * 2 / blck_size_interleave;
1332
1333 // TODO: this branch seems wrong
1334 //if (blck_size_interleave == 8) {
1335 // for (int i = 0; i < end; ++i) {
1336 // int src_id = i % 4;
1337 // int src_offset = (i / 4) * blck_size_interleave;
1338 // int dst_offset = i * blck_size_interleave;
1339
1340 // // Using memcpy to avoid unaligned memory accesses
1341 // memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t));
1342 // }
1343 //} else
1344 if (blck_size_interleave == 4) {
1345 for (int i = 0; i < end; ++i) {
1346 int src_id = i % 4;
1347 int src_offset = (i / 4) * blck_size_interleave;
1348 int dst_offset = i * blck_size_interleave;
1349
            memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint32_t));
1351 }
1352 } else {
1353 GGML_ASSERT(false);
1354 }
1355
1356 return out;
1357}
1358
1359static int repack_iq4_nl_to_iq4_nl_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
1360 GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL);
1361 GGML_ASSERT(interleave_block == 4);
1362
1363 const block_iq4_nl * src = (const block_iq4_nl *)data;
1364 block_iq4_nlx4 * dst = ( block_iq4_nlx4 *)t->data;
1365
1366 block_iq4_nl dst_tmp[4];
1367
    int nrow = ggml_nrows(t);
1369 int nrows_interleaved = 4;
1370 int nblocks = t->ne[0] / QK4_NL;
1371
1372 GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl));
1373
1374 if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
1375 return -1;
1376 }
1377
1378 for (int b = 0; b < nrow; b += nrows_interleaved) {
1379 for (int64_t x = 0; x < nblocks; x++) {
1380 for (int i = 0; i < nrows_interleaved; i++) {
1381 dst_tmp[i] = src[x + i * nblocks];
1382 }
            *dst++ = make_block_iq4_nlx4(dst_tmp, interleave_block);
1384 }
1385 src += nrows_interleaved * nblocks;
1386 }
1387 return 0;
1388
1389 GGML_UNUSED(data_size);
1390}
1391
1392static block_iq4_nlx8 make_block_iq4_nlx8(block_iq4_nl * in, unsigned int blck_size_interleave) {
1393 block_iq4_nlx8 out;
1394
1395 for (int i = 0; i < 8; i++) {
1396 out.d[i] = in[i].d;
1397 }
1398
1399 const int end = QK4_NL * 4 / blck_size_interleave;
1400
1401 if (blck_size_interleave == 8) {
1402 for (int i = 0; i < end; ++i) {
1403 int src_id = i % 8;
1404 int src_offset = (i / 8) * blck_size_interleave;
1405 int dst_offset = i * blck_size_interleave;
1406
            memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t));
1408 }
1409 } else {
1410 GGML_ASSERT(false);
1411 }
1412
1413 return out;
1414}
1415
1416static int repack_iq4_nl_to_iq4_nl_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
1417 GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL);
1418 GGML_ASSERT(interleave_block == 8);
1419
1420 const block_iq4_nl * src = (const block_iq4_nl *)data;
1421 block_iq4_nlx8 * dst = ( block_iq4_nlx8 *)t->data;
1422
1423 block_iq4_nl dst_tmp[8];
1424
    int nrow = ggml_nrows(t);
1426 int nrows_interleaved = 8;
1427 int nblocks = t->ne[0] / QK4_NL;
1428
1429 GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl));
1430
1431 if (t->ne[1] % nrows_interleaved != 0) {
1432 return -1;
1433 }
1434
1435 for (int b = 0; b < nrow; b += nrows_interleaved) {
1436 for (int64_t x = 0; x < nblocks; x++) {
1437 for (int i = 0; i < nrows_interleaved; i++) {
1438 dst_tmp[i] = src[x + i * nblocks];
1439 }
            *dst++ = make_block_iq4_nlx8(dst_tmp, interleave_block);
1441 }
1442 src += nrows_interleaved * nblocks;
1443 }
1444 return 0;
1445
1446 GGML_UNUSED(data_size);
1447}
1448
1449namespace ggml::cpu::repack {
1450// repack
1451template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
1452int repack(struct ggml_tensor *, const void *, size_t);
1453
1454// TODO: generalise.
1455template <> int repack<block_q4_0, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
    return repack_q4_0_to_q4_0_4_bl(t, 4, data, data_size);
1457}
1458
1459template <> int repack<block_q4_0, 8, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
    return repack_q4_0_to_q4_0_4_bl(t, 8, data, data_size);
1461}
1462
1463template <> int repack<block_q4_0, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
    return repack_q4_0_to_q4_0_8_bl(t, 8, data, data_size);
1465}
1466
1467template <> int repack<block_q4_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
    return repack_q4_K_to_q4_K_8_bl(t, 8, data, data_size);
1469}
1470
1471template <> int repack<block_q2_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
    return repack_q2_K_to_q2_K_8_bl(t, 8, data, data_size);
1473}
1474
1475template <> int repack<block_iq4_nl, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
    return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size);
1477}
1478
1479// TODO: needs to be revisited
1480//template <> int repack<block_iq4_nl, 8, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
1481// return repack_iq4_nl_to_iq4_nl_4_bl(t, 8, data, data_size);
1482//}
1483
1484template <> int repack<block_iq4_nl, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
    return repack_iq4_nl_to_iq4_nl_8_bl(t, 8, data, data_size);
1486}
1487
1488// gemv
1489template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>
1490void gemv(int, float *, size_t, const void *, const void *, int, int);
1491
1492template <> void gemv<block_q4_0, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
1493 ggml_gemv_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
1494}
1495
1496template <> void gemv<block_q4_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
1497 ggml_gemv_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
1498}
1499
1500template <> void gemv<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
1501 ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
1502}
1503
1504template <> void gemv<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
1505 ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
1506}
1507
1508template <> void gemv<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
1509 ggml_gemv_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
1510}
1511
1512template <> void gemv<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
1513 ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
1514}
1515
1516template <> void gemv<block_iq4_nl, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
1517 ggml_gemv_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
1518}
1519
1520// gemm
1521template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>
1522void gemm(int, float *, size_t, const void *, const void *, int, int);
1523
1524template <> void gemm<block_q4_0, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
1525 ggml_gemm_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
1526}
1527
1528template <> void gemm<block_q4_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
1529 ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
1530}
1531
1532template <> void gemm<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
1533 ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
1534}
1535
1536template <> void gemm<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
1537 ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
1538}
1539
1540template <> void gemm<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
1541 ggml_gemm_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
1542}
1543
1544template <> void gemm<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
1545 ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
1546}
1547
1548template <> void gemm<block_iq4_nl, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
1549 ggml_gemm_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
1550}
1551
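// tensor_traits<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE> ties the repack/gemv/gemm
// specializations above to the ggml::cpu::tensor_traits interface: work_size() reserves scratch
// space for the quantized activations, compute_forward() routes MUL_MAT / MUL_MAT_ID to the
// matching kernels, and the repack() override (further below) converts weights into the
// interleaved layout.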
1552class tensor_traits_base : public ggml::cpu::tensor_traits {
1553 public:
1554 virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0;
1555};
1556
1557template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE> class tensor_traits : public tensor_traits_base {
1558
1559 bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
        // not really a GGML_TYPE_Q8_0 but same size.
1561 switch (op->op) {
1562 case GGML_OP_MUL_MAT:
1563 {
                    size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
1565 return true;
1566 }
1567 case GGML_OP_MUL_MAT_ID:
1568 {
                    size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
                    size = GGML_PAD(size, sizeof(int64_t)); // + padding for next block.
1571
1572 const int64_t ne02 = op->src[0]->ne[2]; // n_as, n_expert
1573 const int64_t ne12 = op->src[1]->ne[2]; // n_tokens
1574
1575 const size_t sizeof_mmid_row_mapping = sizeof(int64_t);
1576
1577 size += sizeof_mmid_row_mapping*ne02*(ne12 + 1);
1578
1579 return true;
1580 }
1581 default:
1582 // GGML_ABORT("fatal error");
1583 break;
1584 }
1585 return false;
1586 }
1587
1588 bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override {
1589 switch (op->op) {
1590 case GGML_OP_MUL_MAT:
1591 forward_mul_mat(params, op);
1592 return true;
1593 case GGML_OP_MUL_MAT_ID:
1594 forward_mul_mat_id(params, op);
1595 return true;
1596 default:
1597 // GGML_ABORT("fatal error");
1598 break;
1599 }
1600 return false;
1601 }
1602
    void forward_mul_mat_one_chunk(ggml_compute_params * params, ggml_tensor * op, int64_t src0_start, int64_t src0_end) {
        const ggml_tensor * src0 = op->src[0];
        const ggml_tensor * src1 = op->src[1];
        ggml_tensor * dst = op;

        GGML_TENSOR_BINARY_OP_LOCALS

        const void * src1_wdata = params->wdata;
        const size_t src1_col_stride = ggml_row_size(PARAM_TYPE, ne10);

        // Use gemm for the bulk of src1 (the largest multiple of 4 rows) and gemv for the remaining rows.
        if (ne11 > 3) {
            gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
                    (float *) ((char *) dst->data) + src0_start, ne01,
                    (const char *) src0->data + src0_start * nb01,
                    (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
        }
        for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
            gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
                    (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
                    (const char *) src0->data + src0_start * nb01,
                    (const char *) src1_wdata + (src1_col_stride * iter), 1,
                    src0_end - src0_start);
        }
    }

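    // forward_mul_mat: quantize src1 into the work buffer (4 rows at a time through
    // ggml_quantize_mat_t, the remainder row by row via from_float), then split the src0 rows
    // into NB_COLS-aligned chunks that the threads pull from a shared chunk counter.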
    void forward_mul_mat(ggml_compute_params * params, ggml_tensor * op) {
        const ggml_tensor * src0 = op->src[0];
        const ggml_tensor * src1 = op->src[1];
        ggml_tensor * dst = op;

        GGML_TENSOR_BINARY_OP_LOCALS

        const int ith = params->ith;
        const int nth = params->nth;

        GGML_ASSERT(ne0 == ne01);
        GGML_ASSERT(ne1 == ne11);
        GGML_ASSERT(ne2 == ne12);
        GGML_ASSERT(ne3 == ne13);

        // dst cannot be transposed or permuted
        GGML_ASSERT(nb0 == sizeof(float));
        GGML_ASSERT(nb0 <= nb1);
        GGML_ASSERT(nb1 <= nb2);
        GGML_ASSERT(nb2 <= nb3);

        GGML_ASSERT(src1->type == GGML_TYPE_F32);

        GGML_ASSERT(ggml_n_dims(op->src[0]) == 2);
        // GGML_ASSERT(ggml_n_dims(op->src[1]) == 2);

        char * wdata = static_cast<char *>(params->wdata);
        const size_t nbw1 = ggml_row_size(PARAM_TYPE, ne10);

        assert(params->wsize >= nbw1 * ne11);

        const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float;

        int64_t i11_processed = 0;
        for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
            ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10);
        }

        i11_processed = ne11 - ne11 % 4;
        for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
            from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10);
        }

        // disable for NUMA
        const bool disable_chunking = ggml_is_numa();

        // 4x chunks per thread
        int64_t nr = ggml_nrows(op->src[0]);
        int nth_scaled = nth * 4;
        int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
        int64_t nchunk = (nr + chunk_size - 1) / chunk_size;

        // Ensure a minimum chunk size to avoid alignment issues with high thread counts.
        // The minimum chunk size should be at least NB_COLS to prevent overlapping chunks after alignment.
        const int64_t min_chunk_size = NB_COLS;
        if (nchunk > 0 && (nr / nchunk) < min_chunk_size && nr >= min_chunk_size) {
            nchunk = (nr + min_chunk_size - 1) / min_chunk_size;
        }

        if (nth == 1 || nchunk < nth || disable_chunking) {
            nchunk = nth;
        }

        // Ensure nchunk doesn't exceed the number of rows divided by the minimum chunk size.
        // This prevents creating too many tiny chunks that could overlap after alignment.
        const int64_t max_nchunk = (nr + min_chunk_size - 1) / min_chunk_size;
        if (nchunk > max_nchunk) {
            nchunk = max_nchunk;
        }

        if (ith == 0) {
            // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
            ggml_threadpool_chunk_set(params->threadpool, nth);
        }

        ggml_barrier(params->threadpool);

        // The first chunk comes from our thread_id, the rest will get auto-assigned.
        int current_chunk = ith;

        while (current_chunk < nchunk) {
            int64_t src0_start = (current_chunk * ne01) / nchunk;
            int64_t src0_end = ((current_chunk + 1) * ne01) / nchunk;

            // Align boundaries to NB_COLS - round up to ensure all data is included.
            // The chunk size limiting above ensures chunks are large enough to prevent overlaps.
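            // For example (illustrative values): with ne01 = 96, NB_COLS = 8 and nchunk = 7, the raw
            // boundaries 0, 13, 27, 41, 54, 68, 82, 96 round up to 0, 16, 32, 48, 56, 72, 88, 96,
            // which still covers every row exactly once without overlap.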
            src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
            src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
            if (src0_end > ne01) {
                src0_end = ne01;
            }

            if (src0_start >= src0_end) {
                break;
            }

            forward_mul_mat_one_chunk(params, dst, src0_start, src0_end);

            current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
        }
    }

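    // forward_mul_mat_id: quantize src1 into the work buffer, then (on thread 0) group the rows
    // by the expert selected in ids. After the barrier, every expert matrix is processed in
    // turn: each thread takes an NB_COLS-aligned slice of that expert's src0 rows and runs gemv
    // once per mapped src1 row.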
    void forward_mul_mat_id(ggml_compute_params * params, ggml_tensor * op) {
        const ggml_tensor * src0 = op->src[0];
        const ggml_tensor * src1 = op->src[1];
        const ggml_tensor * ids  = op->src[2];
        ggml_tensor * dst = op;

        GGML_TENSOR_BINARY_OP_LOCALS

        const int ith = params->ith;
        const int nth = params->nth;

        const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float;

        // we don't support permuted src0 or src1
        GGML_ASSERT(nb00 == ggml_type_size(src0->type));
        GGML_ASSERT(nb10 == ggml_type_size(src1->type));

        // dst cannot be transposed or permuted
        GGML_ASSERT(nb0 == sizeof(float));
        GGML_ASSERT(nb0 <= nb1);
        GGML_ASSERT(nb1 <= nb2);
        GGML_ASSERT(nb2 <= nb3);

        GGML_ASSERT(ne03 == 1);
        GGML_ASSERT(ne13 == 1);
        GGML_ASSERT(ne3 == 1);

        GGML_ASSERT(src1->type == GGML_TYPE_F32);

        // row groups
        const int n_ids = ids->ne[0]; // n_expert_used
        const int n_as = ne02;        // n_expert

        const size_t nbw1 = ggml_row_size(PARAM_TYPE, ne10);
        const size_t nbw2 = nbw1*ne11;
        const size_t nbw3 = nbw2*ne12;

        struct mmid_row_mapping {
            int32_t i1;
            int32_t i2;
        };

        GGML_ASSERT(params->wsize >=
                (GGML_PAD(nbw3, sizeof(int64_t)) +
                 n_as*(ne12 + 1)*sizeof(mmid_row_mapping))
        );

        auto * wdata = (char *) params->wdata;
        auto * wdata_src1_end = (char *) wdata + GGML_PAD(nbw3, sizeof(int64_t));

        // total of [n_as][ne12 + 1] elements of type mmid_row_mapping (2*int32_t = int64_t)
        auto * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
        struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]

        // src1: float32 => param type
        for (int64_t i12 = 0; i12 < ne12; ++i12) {
            for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
                from_float((float *)((char *) src1->data + i12 * nb12 + i11 * nb11),
                           (void *) (wdata + i12 * nbw2 + i11 * nbw1),
                           ne10);
            }
        }

#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id) * ne12 + (i1)]
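        // matrix_rows is indexed as [expert][k]: for expert e, entries 0 .. matrix_row_counts[e]-1
        // record the { id, iid1 } pairs (slot within ids->ne[0], row of ids) that routed to e.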

        if (ith == 0) {
            // initialize matrix_row_counts
            memset(matrix_row_counts, 0, n_as * sizeof(int64_t));

            // group rows by src0 matrix
            for (int32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
                for (int32_t id = 0; id < n_ids; ++id) {
                    const int32_t i02 =
                        *(const int32_t *) ((const char *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]);

                    GGML_ASSERT(i02 >= 0 && i02 < n_as);

                    MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = { id, iid1 };
                    matrix_row_counts[i02] += 1;
                }
            }
        }

        ggml_barrier(params->threadpool);

        // compute each matrix multiplication in sequence
        for (int cur_a = 0; cur_a < n_as; ++cur_a) {
            const int64_t cne1 = matrix_row_counts[cur_a];

            if (cne1 == 0) {
                continue;
            }

            const auto * src0_cur = (const char *) src0->data + cur_a*nb02;

            //const int64_t nr0 = ne01; // src0 rows
            const int64_t nr1 = cne1; // src1 rows

            int64_t src0_cur_start = (ith * ne01) / nth;
            int64_t src0_cur_end = ((ith + 1) * ne01) / nth;

            // Align boundaries to NB_COLS - round up to ensure all data is included
            src0_cur_start = (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
            src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;
            if (src0_cur_end > ne01) {
                src0_cur_end = ne01;
            }

            if (src0_cur_start >= src0_cur_end) {
                return;
            }

            for (int ir1 = 0; ir1 < nr1; ir1++) {
                struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);

                const int id = row_mapping.i1; // selected expert index

                const int64_t i11 = id % ne11;
                const int64_t i12 = row_mapping.i2; // row index in src1

                const int64_t i1 = id;  // selected expert index
                const int64_t i2 = i12; // row

                const auto * src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2);

                gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
                        (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01,
                        src0_cur + src0_cur_start * nb01,
                        src1_col, 1, src0_cur_end - src0_cur_start);
            }
        }
#undef MMID_MATRIX_ROW
    }

    int repack(struct ggml_tensor * t, const void * data, size_t data_size) override {
        GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, ggml_type_name(t->type),
                       (int) NB_COLS, (int) INTER_SIZE);
        return ggml::cpu::repack::repack<BLOC_TYPE, INTER_SIZE, NB_COLS>(t, data, data_size);
    }
};

} // namespace ggml::cpu::repack

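// Pick the interleaved layout (if any) that the current CPU can exploit for this tensor type.
// The choice depends on the available SIMD / int8-matmul extensions and requires the row count
// (ne[1]) to be a multiple of the layout's column block (4 or 8).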
static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(const struct ggml_tensor * cur) {

    // instance for Q4
    static const ggml::cpu::repack::tensor_traits<block_q4_0, 4, 4, GGML_TYPE_Q8_0> q4_0_4x4_q8_0;
    static const ggml::cpu::repack::tensor_traits<block_q4_0, 8, 4, GGML_TYPE_Q8_0> q4_0_4x8_q8_0;
    static const ggml::cpu::repack::tensor_traits<block_q4_0, 8, 8, GGML_TYPE_Q8_0> q4_0_8x8_q8_0;
    static const ggml::cpu::repack::tensor_traits<block_q4_K, 8, 8, GGML_TYPE_Q8_K> q4_K_8x8_q8_K;

    // instance for Q2
    static const ggml::cpu::repack::tensor_traits<block_q2_K, 8, 8, GGML_TYPE_Q8_K> q2_K_8x8_q8_K;

    // instance for IQ4
    static const ggml::cpu::repack::tensor_traits<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0> iq4_nl_4x4_q8_0;
    static const ggml::cpu::repack::tensor_traits<block_iq4_nl, 8, 8, GGML_TYPE_Q8_0> iq4_nl_8x8_q8_0;

    if (cur->type == GGML_TYPE_Q4_0) {
        if (ggml_cpu_has_avx2() || (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)) {
            if (cur->ne[1] % 8 == 0) {
                return &q4_0_8x8_q8_0;
            }
        }
        if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
            if (cur->ne[1] % 4 == 0) {
                return &q4_0_4x8_q8_0;
            }
        }
        if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
            if (cur->ne[1] % 4 == 0) {
                return &q4_0_4x4_q8_0;
            }
        }
    } else if (cur->type == GGML_TYPE_Q4_K) {
        if (ggml_cpu_has_avx2()) {
            if (cur->ne[1] % 8 == 0) {
                return &q4_K_8x8_q8_K;
            }
        }
    } else if (cur->type == GGML_TYPE_Q2_K) {
        if (ggml_cpu_has_avx512()) {
            if (cur->ne[1] % 8 == 0) {
                return &q2_K_8x8_q8_K;
            }
        }
    } else if (cur->type == GGML_TYPE_IQ4_NL) {
        if (ggml_cpu_has_avx2()) {
            if (cur->ne[1] % 8 == 0) {
                return &iq4_nl_8x8_q8_0;
            }
        }
        if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
            if (cur->ne[1] % 4 == 0) {
                return &iq4_nl_4x4_q8_0;
            }
        }
    }

    return nullptr;
}

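// Buffer hooks for the repack buffer type: init_tensor records the chosen tensor_traits in
// tensor->extra, and set_tensor repacks the incoming (standard-layout) weight data into the
// interleaved layout instead of copying it verbatim.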
static enum ggml_status ggml_backend_cpu_repack_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
    tensor->extra = (void *) const_cast<ggml::cpu::tensor_traits *>(ggml_repack_get_optimal_repack_type(tensor));

    GGML_UNUSED(buffer);
    return GGML_STATUS_SUCCESS;
}

static void ggml_backend_cpu_repack_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
                                                      const void * data, size_t offset, size_t size) {
    GGML_ASSERT(offset == 0);
    GGML_ASSERT(size == ggml_nbytes(tensor));

    auto tensor_traits = (ggml::cpu::repack::tensor_traits_base *) tensor->extra;
    auto OK = tensor_traits->repack(tensor, data, size);

    GGML_ASSERT(OK == 0);
    GGML_UNUSED(buffer);
}

static const char * ggml_backend_cpu_repack_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
    return "CPU_REPACK";

    GGML_UNUSED(buft);
}

static ggml_backend_buffer_t ggml_backend_cpu_repack_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
    ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);

    if (buffer == nullptr) {
        return nullptr;
    }

    buffer->buft = buft;
    buffer->iface.init_tensor = ggml_backend_cpu_repack_buffer_init_tensor;
    buffer->iface.set_tensor = ggml_backend_cpu_repack_buffer_set_tensor;
    buffer->iface.get_tensor = nullptr;
    buffer->iface.cpy_tensor = nullptr;
    return buffer;
}

static size_t ggml_backend_cpu_repack_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
    return TENSOR_ALIGNMENT;

    GGML_UNUSED(buft);
}

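// extra_buffer_type advertises which ops the repacked weights can serve: MUL_MAT with a 2D src0
// and MUL_MAT_ID with a 3D src0, in both cases only when src0 lives in the repack buffer type,
// a repacked layout exists for it, and src1 is F32 data in host memory.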
namespace ggml::cpu::repack {
class extra_buffer_type : ggml::cpu::extra_buffer_type {
    bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
        if (    op->op == GGML_OP_MUL_MAT &&
                op->src[0]->buffer &&
                (ggml_n_dims(op->src[0]) == 2) &&
                op->src[0]->buffer->buft == ggml_backend_cpu_repack_buffer_type() &&
                ggml_repack_get_optimal_repack_type(op->src[0])
                ) {
            if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
                return false;
            }
            if (op->src[1]->type == GGML_TYPE_F32) {
                return true;
            }
            //if (op->src[1]->type == GGML_TYPE_Q8_0) {
            //    return true;
            //}
            // may be possible if Q8_0 packed...
        } else if (op->op == GGML_OP_MUL_MAT_ID
                && op->src[0]->buffer
                && (ggml_n_dims(op->src[0]) == 3)
                && op->src[0]->buffer->buft == ggml_backend_cpu_repack_buffer_type()
                && ggml_repack_get_optimal_repack_type(op->src[0])
                ) {
            if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
                return false;
            }
            if (op->src[1]->type == GGML_TYPE_F32) {
                return true;
            }
            //if (op->src[1]->type == GGML_TYPE_Q8_0) {
            //    return true;
            //}
        }
        return false;
    }

    ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
        if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_MUL_MAT_ID) {
            if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_repack_buffer_type()) {
                return (ggml::cpu::tensor_traits *) op->src[0]->extra;
            }
        }
        return nullptr;
    }
};
} // namespace ggml::cpu::repack

ggml_backend_buffer_type_t ggml_backend_cpu_repack_buffer_type(void) {
    static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_repack = {
        /* .iface    = */ {
            /* .get_name         = */ ggml_backend_cpu_repack_buffer_type_get_name,
            /* .alloc_buffer     = */ ggml_backend_cpu_repack_buffer_type_alloc_buffer,
            /* .get_alignment    = */ ggml_backend_cpu_repack_buffer_type_get_alignment,
            /* .get_max_size     = */ nullptr, // defaults to SIZE_MAX
            /* .get_alloc_size   = */ nullptr, // defaults to ggml_nbytes
            /* .is_host          = */ nullptr,
        },
        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
        /* .context = */ new ggml::cpu::repack::extra_buffer_type(),
    };

    return &ggml_backend_cpu_buffer_type_repack;
}

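// A minimal usage sketch (illustrative only, not part of this file): a tensor allocated from this
// buffer type is repacked transparently when its data is uploaded, e.g.
//
//     ggml_init_params ip = { /*.mem_size =*/ ggml_tensor_overhead(), /*.mem_buffer =*/ nullptr, /*.no_alloc =*/ true };
//     ggml_context * ctx  = ggml_init(ip);
//     ggml_tensor  * w    = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, 4096, 4096);
//     ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_cpu_repack_buffer_type());
//     ggml_backend_tensor_set(w, q4_0_data, 0, ggml_nbytes(w)); // routed through ..._buffer_set_tensor -> repack()
//
// where q4_0_data (hypothetical) holds the weights in the standard block_q4_0 layout.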