1// Copyright 2024 Mozilla Foundation
2//
3// Permission is hereby granted, free of charge, to any person obtaining
4// a copy of this software and associated documentation files (the
5// "Software"), to deal in the Software without restriction, including
6// without limitation the rights to use, copy, modify, merge, publish,
7// distribute, sublicense, and/or sell copies of the Software, and to
8// permit persons to whom the Software is furnished to do so, subject to
9// the following conditions:
10//
11// The above copyright notice and this permission notice shall be
12// included in all copies or substantial portions of the Software.
13//
14// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
15// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
16// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
17// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
18// BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
19// ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
20// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21// SOFTWARE.
22
23//
24// _ _ ___ _ _ ___
25// | |_(_)_ _ _ _| _ ) | /_\ / __|
26// | _| | ' \ || | _ \ |__ / _ \\__ \.
27// \__|_|_||_\_, |___/____/_/ \_\___/
28// |__/
29//
30// BASIC LINEAR ALGEBRA SUBPROGRAMS
31//
32//
33// This file implements multithreaded CPU matrix multiplication for the
34// common contiguous use case C = Aᵀ * B. These kernels are designed to
35// have excellent performance[1] for matrices that fit in the CPU cache
36// without imposing any overhead such as cache filling or malloc calls.
37//
// This implementation does not guarantee any upper bound on rounding
// errors, which grow along with k. Our goal is to maximally exploit the
// hardware for performance, and then use whatever resources remain to
// improve numerical accuracy.
42//
43// [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online].
44// Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].
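//
// Throughout this file the kernels compute
//
//     C[ldc*j + i] = Σ_l A[lda*i + l] * B[ldb*j + l]
//
// i.e. each output element is the dot product of a length-k row of A
// (stride lda) with a length-k row of B (stride ldb), which is the
// C = Aᵀ * B convention described above.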
45
46#if defined(__GNUC__)
47#pragma GCC diagnostic ignored "-Wpedantic"
48#pragma GCC diagnostic ignored "-Wignored-attributes"
49#endif
50
51#include "sgemm.h"
52#include "ggml-impl.h"
53#include "ggml-cpu-impl.h"
54#include "ggml-quants.h"
55#include "simd-mappings.h"
56
57#include <array>
58#include <type_traits>
59
60#ifdef _MSC_VER
61#define NOINLINE __declspec(noinline)
62#else
63#define NOINLINE __attribute__((__noinline__))
64#endif
65
66#if defined(__ARM_NEON) || defined(__AVX512F__) || defined(__VXE__) || defined(__VXE2__)
67#define VECTOR_REGISTERS 32
68#else
69#define VECTOR_REGISTERS 16
70#endif
71
72#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
73
74namespace {
75
76inline float unhalf(ggml_fp16_t d) {
77 return GGML_CPU_FP16_TO_FP32(d);
78}
79
80////////////////////////////////////////////////////////////////////////////////////////////////////
81// VECTORIZED ARITHMETIC OPERATIONS
82
83#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
inline __m128 add(__m128 x, __m128 y) { return _mm_add_ps(x, y); }
inline __m128 sub(__m128 x, __m128 y) { return _mm_sub_ps(x, y); }
inline __m128 mul(__m128 x, __m128 y) { return _mm_mul_ps(x, y); }
87#endif // __SSE__
88
89#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
inline __m256 add(__m256 x, __m256 y) { return _mm256_add_ps(x, y); }
inline __m256 sub(__m256 x, __m256 y) { return _mm256_sub_ps(x, y); }
inline __m256 mul(__m256 x, __m256 y) { return _mm256_mul_ps(x, y); }
93#endif // __AVX__
94
95#if defined(__AVX512F__)
96inline __m512 add(__m512 x, __m512 y) { return _mm512_add_ps(x, y); }
97inline __m512 sub(__m512 x, __m512 y) { return _mm512_sub_ps(x, y); }
98inline __m512 mul(__m512 x, __m512 y) { return _mm512_mul_ps(x, y); }
99#endif // __AVX512F__
100
101#if defined(__ARM_NEON)
102inline float32x4_t add(float32x4_t x, float32x4_t y) { return vaddq_f32(x, y); }
103inline float32x4_t sub(float32x4_t x, float32x4_t y) { return vsubq_f32(x, y); }
104inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vmulq_f32(x, y); }
105#endif // __ARM_NEON
106
107#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
108inline float16x8_t add(float16x8_t x, float16x8_t y) { return vaddq_f16(x, y); }
109inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
110inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
111#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
112
113#if defined(__VXE__) || defined(__VXE2__)
114inline float32x4_t add(float32x4_t x, float32x4_t y) { return vec_add(x, y); }
115inline float32x4_t sub(float32x4_t x, float32x4_t y) { return vec_sub(x, y); }
116inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vec_mul(x, y); }
117#endif
118
119#if defined(__MMA__)
120typedef vector unsigned char vec_t;
121typedef __vector_quad acc_t;
122#endif
123////////////////////////////////////////////////////////////////////////////////////////////////////
124// VECTORIZED FUSED MULTIPLY ADD
125
126/**
127 * Computes a * b + c.
128 */
129template <typename T, typename U>
130inline U madd(T a, T b, U c) {
131 return add(mul(a, b), c);
132}
133
134#if defined(__FMA__)
135#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
136template <>
137inline __m256 madd(__m256 a, __m256 b, __m256 c) {
    return _mm256_fmadd_ps(a, b, c);
139}
140#endif
141#if defined(__AVX512F__)
142template <>
143inline __m512 madd(__m512 a, __m512 b, __m512 c) {
144 return _mm512_fmadd_ps(a, b, c);
145}
146#endif
147#if defined(__AVX512BF16__)
148template <>
149inline __m512 madd(__m512bh a, __m512bh b, __m512 c) {
150 return _mm512_dpbf16_ps(c, a, b);
151}
152template <>
153inline __m256 madd(__m256bh a, __m256bh b, __m256 c) {
154 return _mm256_dpbf16_ps(c, a, b);
155}
156#endif
157#endif
158
159#if defined(__ARM_FEATURE_FMA)
160template <>
161inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
162 return vfmaq_f32(c, b, a);
163}
164#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
165template <>
166inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {
167 return vfmaq_f16(c, b, a);
168}
169#endif
170#endif
171
172#if defined(__VXE__) || defined(__VXE2__)
173template <>
174inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
175 return vec_madd(a, b, c);
176}
177#endif
178
179////////////////////////////////////////////////////////////////////////////////////////////////////
180// VECTORIZED HORIZONTAL SUM
181
182#if defined(__ARM_NEON)
183inline float hsum(float32x4_t x) {
184 return vaddvq_f32(x);
185}
186#endif // __ARM_NEON
187
188#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
189inline float hsum(float16x8_t x) {
190 return vaddvq_f32(vaddq_f32(vcvt_f32_f16(vget_low_f16(x)),
191 vcvt_f32_f16(vget_high_f16(x))));
192}
193#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
194
195#if defined(__VXE__) || defined(__VXE2__)
196inline float hsum(float32x4_t x) {
197 float32x4_t tmp = x + vec_reve(x);
198 return tmp[0] + tmp[1];
199}
200#endif
201
202#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
203inline float hsum(__m128 x) {
204#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
    x = _mm_add_ps(x, _mm_movehl_ps(x, x));
    x = _mm_add_ss(x, _mm_movehdup_ps(x));
207#else
208 __m128 t;
209 t = _mm_shuffle_ps(x, x, _MM_SHUFFLE(2, 3, 0, 1));
210 x = _mm_add_ps(x, t);
211 t = _mm_movehl_ps(t, x);
212 x = _mm_add_ss(x, t);
213#endif
    return _mm_cvtss_f32(x);
215}
216#endif
217
218#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
219inline float hsum(__m256 x) {
    return hsum(_mm_add_ps(_mm256_extractf128_ps(x, 1),
                           _mm256_castps256_ps128(x)));
222}
223#endif // __AVX__
224
225#if defined(__AVX512F__)
226inline float hsum(__m512 x) {
227 return _mm512_reduce_add_ps(x);
228}
229#endif // __AVX512F__
230
231////////////////////////////////////////////////////////////////////////////////////////////////////
232// VECTORIZED MEMORY LOADING
233
234template <typename T, typename U> T load(const U *);
235
236#if defined(__ARM_NEON)
237template <> inline float32x4_t load(const float *p) {
238 return vld1q_f32(p);
239}
240#if !defined(_MSC_VER)
241// FIXME: this should check for __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
242template <> inline float16x8_t load(const ggml_fp16_t *p) {
243 return vld1q_f16((const float16_t *)p);
244}
245template <> inline float32x4_t load(const ggml_fp16_t *p) {
246 return vcvt_f32_f16(vld1_f16((const float16_t *)p));
247}
248#endif // _MSC_VER
249#endif // __ARM_NEON
250
251#if defined(__VXE__) || defined(__VXE2__)
252template <> inline float32x4_t load(const ggml_fp16_t * p) {
253 float tmp[4];
254
255 for (int i = 0; i < 4; i++) {
256 tmp[i] = GGML_CPU_FP16_TO_FP32(p[i]);
257 }
258
259 return vec_xl(0, (const float *)(tmp));
260}
261template <> inline float32x4_t load(const float * p) {
262 return vec_xl(0, p);
263}
264#endif
265
266#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
267template <> inline __m128 load(const float *p) {
    return _mm_loadu_ps(p);
269}
270#endif // __SSE__
271
272#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
273template <> inline __m256 load(const float *p) {
    return _mm256_loadu_ps(p);
275}
276#endif // __AVX__
277
278#if defined(__AVX2__) || defined(__AVX512F__)
279template <> inline __m256 load(const ggml_bf16_t *p) {
    return _mm256_castsi256_ps(
        _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)p)), 16));
282}
283#endif // __AVX2__
284
285#if defined(__F16C__)
286template <> inline __m256 load(const ggml_fp16_t *p) {
    return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)p));
288}
289#endif // __F16C__
290
291#if defined(__AVX512F__)
292template <> inline __m512 load(const float *p) {
293 return _mm512_loadu_ps(p);
294}
295template <> inline __m512 load(const ggml_fp16_t *p) {
296 return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)p));
297}
298template <> inline __m512 load(const ggml_bf16_t *p) {
299 return _mm512_castsi512_ps(
300 _mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)p)), 16));
301}
302#endif // __AVX512F__
303
304#if defined(__AVX512BF16__)
305template <> inline __m512bh load(const ggml_bf16_t *p) {
306 return (__m512bh)_mm512_loadu_ps((const float *)p);
307}
308template <> inline __m256bh load(const ggml_bf16_t *p) {
309 return (__m256bh)_mm256_loadu_ps((const float *)p);
310}
311template <> inline __m512bh load(const float *p) {
312 return _mm512_cvtne2ps_pbh(_mm512_loadu_ps(p + 16), _mm512_loadu_ps(p));
313}
314template <> inline __m256bh load(const float *p) {
315 return _mm512_cvtneps_pbh(_mm512_loadu_ps(p));
316}
317#endif
318
319////////////////////////////////////////////////////////////////////////////////////////////////////
320// FLOATING POINT MATRIX MULTIPLICATION
321
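// Split a dimension of length m into ceil(m/M) blocks and return the (rounded-up) size of one block.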
322template <int M>
323static inline int64_t BLOCK_SIZE(size_t m) {
324 const int64_t NB_BLOC_M = (m + M - 1) / M;
325 return (m % NB_BLOC_M == 0) ? m / NB_BLOC_M : (m / NB_BLOC_M) + 1;
326}
327
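// Starting index of block ib when the first ibN blocks have size bloc_size and the remaining blocks have size bloc_size - 1.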
328static constexpr inline int64_t BLOC_POS(int64_t ib, int64_t ibN, int64_t bloc_size) {
329 return ib < ibN ? ib * bloc_size : ibN * bloc_size + (ib - ibN) * (bloc_size - 1);
330}
331
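// tinyBLAS computes C = Aᵀ * B for float, fp16, and bf16 inputs, loading
// KN-element vectors of type V and accumulating into vectors of type D.
// The output is carved into tiles that are distributed over the threadpool.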
332template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
333class tinyBLAS {
334 public:
335 tinyBLAS(const ggml_compute_params * params, int64_t k,
336 const TA *A, int64_t lda,
337 const TB *B, int64_t ldb,
338 TC *C, int64_t ldc)
339 : params(params), A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc) {
340 }
341
342 bool matmul(int64_t m, int64_t n) {
343 if (k % KN != 0)
344 return false;
        // compute RM so that only tiles of size RM and RM-1 are needed
346#if VECTOR_REGISTERS == 32
347 if (m % 16 == 0 && (m/16 >= params->nth)) {
348 const int64_t SIZE_N = BLOCK_SIZE<6>(n);
349 mnpack<4, 6, 4>(m, n, SIZE_N, 12);
350 return true;
351 }
352 if (m % 8 == 0 ) {
353 const int64_t SIZE_N = BLOCK_SIZE<6>(n);
354 mnpack<4, 6, 2>(m, n, SIZE_N, 12);
355 return true;
356 }
357 if (m % 4 == 0) {
358 const int64_t SIZE_N = BLOCK_SIZE<6>(n);
359 mnpack<4, 6, 1>(m, n, SIZE_N, 12);
360 return true;
361 }
362#else // VECTOR_REGISTERS == 16
363 if (m % 16 == 0 && (m/16 >= params->nth)) {
            const int64_t SIZE_N = BLOCK_SIZE<3>(n);
365 mnpack<4, 3, 4>(m, n, SIZE_N, 24);
366 return true;
367 }
368 if (m % 8 == 0 ) {
            const int64_t SIZE_N = BLOCK_SIZE<3>(n);
370 mnpack<4, 3, 2>(m, n, SIZE_N, 24);
371 return true;
372 }
373 if (m % 4 == 0) {
            const int64_t SIZE_N = BLOCK_SIZE<3>(n);
375 mnpack<4, 3, 1>(m, n, SIZE_N, 24);
376 return true;
377 }
378#endif
379 return false;
380 }
381
382 private:
383 template <int RM, int RN, int BM>
384 inline void mnpack(int64_t m, int64_t n, int64_t SIZE_N, int64_t BN) {
385 if (SIZE_N == RN) {
386 return gemm<RM, RN, BM>(m, n, BN);
387 }
388 if constexpr (RN > 1) {
389 return mnpack<RM, RN-1, BM>(m, n, SIZE_N, BN);
390 } else {
391 GGML_LOG_ERROR("mnpack<%d, %d> bloc size not supported\n", RM, (int)SIZE_N);
            GGML_ASSERT(false); // we have missed something.
393 }
394 }
395
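    // Compute a single RM x RN tile of C at (ii, jj), accumulating dot products
    // over the full k dimension and reducing each accumulator with hsum().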
396 template <int RM, int RN>
397 inline void gemm_bloc(int64_t ii, int64_t jj) {
398 D Cv[RN][RM] = {};
399 for (int64_t l = 0; l < k; l += KN) {
            // help the compiler with operation ordering.
401 if constexpr (RM <= RN) {
402 V Av[RM];
403 for (int64_t i = 0; i < RM; ++i) {
404 Av[i] = load<V>(A + lda * (ii + i) + l);
405 }
406 for (int64_t j = 0; j < RN; ++j) {
407 V Bv = load<V>(B + ldb * (jj + j) + l);
408 for (int64_t i = 0; i < RM; ++i) {
409 Cv[j][i] = madd(Av[i], Bv, Cv[j][i]);
410 }
411 }
412 } else {
413 V Bv[RN];
414 for (int64_t j = 0; j < RN; ++j) {
415 Bv[j] = load<V>(B + ldb * (jj + j) + l);
416 }
417 for (int64_t i = 0; i < RM; ++i) {
418 V Av = load<V>(A + lda * (ii + i) + l);
419 for (int64_t j = 0; j < RN; ++j) {
420 Cv[j][i] = madd(Av, Bv[j], Cv[j][i]);
421 }
422 }
423 }
424 }
425 for (int64_t j = 0; j < RN; ++j)
426 for (int64_t i = 0; i < RM; ++i)
427 C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
428 }
429
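    // Threaded driver: each job covers RM*BM rows of C and a group of column
    // tiles; jobs are handed out through the threadpool chunk counter.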
430 template <int RM, int RN, int BM>
431 NOINLINE void gemm(int64_t m, int64_t n, int64_t BN) {
432 GGML_ASSERT(m % (RM * BM) == 0);
433 const int64_t ytiles = m / (RM * BM);
434 const int64_t xtiles = (n + RN -1) / RN;
435 const int64_t jj_RN = (xtiles - (xtiles * RN - n));
436
437 // "round" bloc_size to "nearest" BN
438 const int64_t NB_BN = xtiles < BN ? 1 : (xtiles + BN / 2) / BN;
439 const int64_t SIZE_BN = xtiles % NB_BN == 0 ? xtiles / NB_BN : xtiles / NB_BN + 1;
440 const int64_t jj_BN = (NB_BN - (NB_BN * SIZE_BN - xtiles));
441 const int64_t nb_job = ytiles * NB_BN;
442
443 if (params->ith == 0) {
444 GGML_ASSERT( jj_BN * SIZE_BN + (NB_BN - jj_BN) * (SIZE_BN - 1) == xtiles);
            // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
            ggml_threadpool_chunk_set(params->threadpool, params->nth);
447 }
448
        ggml_barrier(params->threadpool);
450
451 int64_t job = params->ith;
452 while (job < nb_job) {
453 const int64_t ii = (job % ytiles) * RM * BM;
454 const int64_t jb = job / ytiles;
            const int64_t jr0 = BLOC_POS(jb,     jj_BN, SIZE_BN);
            const int64_t jrN = BLOC_POS(jb + 1, jj_BN, SIZE_BN);

            const int64_t jj0 = BLOC_POS(jr0, jj_RN, RN);
            const int64_t jj2 = BLOC_POS(jrN, jj_RN, RN);
460 const int64_t jj1 = jj2 < jj_RN * RN ? jj2 : jj_RN * RN;
461
462 for (int64_t bi = 0; bi < BM * RM; bi += RM) {
463 int64_t jj = jj0;
464 for (; jj < jj1; jj += RN) {
465 gemm_bloc<RM, RN>(ii + bi, jj);
466 }
467 if constexpr (RN > 1) {
468 for (; jj < jj2; jj += RN - 1) {
469 gemm_bloc<RM, RN-1>(ii + bi, jj);
470 }
471 }
472 GGML_ASSERT(jj == jj2);
473 }
474
            job = ggml_threadpool_chunk_add(params->threadpool, 1);
476 }
477
        ggml_barrier(params->threadpool);
479 return;
480 }
481
482 const ggml_compute_params * params;
483 const TA *const A;
484 const TB *const B;
485 TC *const C;
486 const int64_t k;
487 const int64_t lda;
488 const int64_t ldb;
489 const int64_t ldc;
490};
491
492//////////////////////////////////////////////////////////////////////////////////////////
493// QUANT ZERO MATRIX MULTIPLICATION
494
495#if defined(__ARM_FEATURE_DOTPROD)
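// tinyBLAS_Q0_ARM multiplies a quantized matrix A (q4_0 or q8_0) by a q8_0 matrix B using the ARM dot product (sdot) instructions.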
496template <typename TA>
497class tinyBLAS_Q0_ARM {
498 public:
499 tinyBLAS_Q0_ARM(int64_t k,
500 const TA *A, int64_t lda,
501 const block_q8_0 *B, int64_t ldb,
502 float *C, int64_t ldc,
503 int ith, int nth)
504 : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
505 }
506
507 void matmul(int64_t m, int64_t n) {
508 mnpack(0, m, 0, n);
509 }
510
511 private:
512 NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
513 int64_t mc, nc, mp, np;
514 switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3ll)) {
515 case 0x33:
516 mc = 3;
517 nc = 3;
518 gemm<3, 3>(m0, m, n0, n);
519 break;
520 case 0x32:
521 mc = 3;
522 nc = 2;
523 gemm<3, 2>(m0, m, n0, n);
524 break;
525 case 0x23:
526 mc = 2;
527 nc = 3;
528 gemm<2, 3>(m0, m, n0, n);
529 break;
530 case 0x22:
531 mc = 2;
532 nc = 2;
533 gemm<2, 2>(m0, m, n0, n);
534 break;
535 case 0x31:
536 mc = 3;
537 nc = 1;
538 gemm<3, 1>(m0, m, n0, n);
539 break;
540 case 0x13:
541 mc = 1;
542 nc = 3;
543 gemm<1, 3>(m0, m, n0, n);
544 break;
545 case 0x21:
546 mc = 2;
547 nc = 1;
548 gemm<2, 1>(m0, m, n0, n);
549 break;
550 case 0x12:
551 mc = 1;
552 nc = 2;
553 gemm<1, 2>(m0, m, n0, n);
554 break;
555 case 0x11:
556 mc = 1;
557 nc = 1;
558 gemm<1, 1>(m0, m, n0, n);
559 break;
560 default:
561 return;
562 }
563 mp = m0 + (m - m0) / mc * mc;
564 np = n0 + (n - n0) / nc * nc;
565 mnpack(mp, m, n0, np);
566 mnpack(m0, m, np, n);
567 }
568
569 template <int RM, int RN>
570 NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
571 int64_t ytiles = (m - m0) / RM;
572 int64_t xtiles = (n - n0) / RN;
573 int64_t tiles = xtiles * ytiles;
574 int64_t duty = (tiles + nth - 1) / nth;
575 int64_t start = duty * ith;
576 int64_t end = start + duty;
577 if (end > tiles)
578 end = tiles;
579 for (int64_t job = start; job < end; ++job) {
580 int64_t ii = m0 + job / xtiles * RM;
581 int64_t jj = n0 + job % xtiles * RN;
582 float32x4_t Cv[RN][RM] = {};
583 for (int64_t l = 0; l < k; ++l)
584 for (int64_t j = 0; j < RN; ++j)
585 for (int64_t i = 0; i < RM; ++i)
586 Cv[j][i] = vmlaq_n_f32(Cv[j][i],
587 vcvtq_f32_s32(vdotq_s32(
588 vdotq_s32(vdupq_n_s32(0),
589 load_lo(A + lda * (ii + i) + l),
590 load_lo(B + ldb * (jj + j) + l)),
591 load_hi(A + lda * (ii + i) + l),
592 load_hi(B + ldb * (jj + j) + l))),
593 unhalf(A[lda * (ii + i) + l].d) *
594 unhalf(B[ldb * (jj + j) + l].d));
595 for (int64_t j = 0; j < RN; ++j)
596 for (int64_t i = 0; i < RM; ++i)
597 C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
598 }
599 }
600
601 inline int8x16_t load_lo(const block_q8_0 *b) {
602 return vld1q_s8(b->qs);
603 }
604
605 inline int8x16_t load_hi(const block_q8_0 *b) {
606 return vld1q_s8(b->qs + 16);
607 }
608
609 inline int8x16_t load_lo(const block_q4_0 *b) {
610 return vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vld1q_u8(b->qs),
611 vdupq_n_u8(0x0f))),
612 vdupq_n_s8(0x8));
613 }
614
615 inline int8x16_t load_hi(const block_q4_0 *b) {
616 return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)),
617 vdupq_n_s8(0x8));
618 }
619
620 const TA *const A;
621 const block_q8_0 *const B;
622 float *const C;
623 const int64_t k;
624 const int64_t lda;
625 const int64_t ldb;
626 const int64_t ldc;
627 const int ith;
628 const int nth;
629};
630#endif // __ARM_FEATURE_DOTPROD
631
632#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
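// tinyBLAS_Q0_AVX multiplies quantized matrices (q4_0, q5_0, q8_0, iq4_nl) using AVX, AVX2, or AVX512 integer dot products.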
633template <typename TA, typename TB, typename TC>
634class tinyBLAS_Q0_AVX {
635 public:
636 tinyBLAS_Q0_AVX(int64_t k,
637 const TA *A, int64_t lda,
638 const TB *B, int64_t ldb,
639 TC *C, int64_t ldc,
640 int ith, int nth)
641 : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
642 const int8_t kvalues_iq4nl[16] = {
643 -127, -104, -83, -65,
644 -49, -35, -22, -10,
645 1, 13, 25, 38,
646 53, 69, 89, 113
647 };
648
        iq4nlt = _mm_loadu_si128((const __m128i *)kvalues_iq4nl);
650 }
651
652 void matmul(int64_t m, int64_t n) {
        mnpack(0, m, 0, n);
654 }
655
656 private:
657 void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
658 int64_t mc, nc, mp, np;
659 switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) {
660#if VECTOR_REGISTERS == 32
661 case 0x44:
662 mc = 4;
663 nc = 4;
664#if defined(__AVX2__) && defined(__F16C__)
665 gemm4xN<4>(m0, m, n0, n);
666#else
667 gemm<4, 4>(m0, m, n0, n);
668#endif
669 break;
670 case 0x43:
671 mc = 4;
672 nc = 3;
673#if defined(__AVX2__) && defined(__F16C__)
674 gemm4xN<3>(m0, m, n0, n);
675#else
676 gemm<4, 3>(m0, m, n0, n);
677#endif
678 break;
679 case 0x34:
680 mc = 3;
681 nc = 4;
682#if defined(__AVX2__) && defined(__F16C__)
683 gemmMx4<3>(m0, m, n0, n);
684#else
685 gemm<3, 4>(m0, m, n0, n);
686#endif
687 break;
688 case 0x33:
689 mc = 3;
690 nc = 3;
691 gemm<3, 3>(m0, m, n0, n);
692 break;
693 case 0x42:
694 mc = 4;
695 nc = 2;
696#if defined(__AVX2__) && defined(__F16C__)
697 gemm4xN<2>(m0, m, n0, n);
698#else
699 gemm<4, 2>(m0, m, n0, n);
700#endif
701 break;
702 case 0x24:
703 mc = 2;
704 nc = 4;
705#if defined(__AVX2__) && defined(__F16C__)
706 gemmMx4<2>(m0, m, n0, n);
707#else
708 gemm<2, 4>(m0, m, n0, n);
709#endif
710 break;
711#else
712 case 0x44:
713 case 0x43:
714 case 0x42:
715 mc = 4;
716 nc = 2;
717#if defined(__AVX2__) && defined(__F16C__)
718 gemm4xN<2>(m0, m, n0, n);
719#else
720 gemm<4, 2>(m0, m, n0, n);
721#endif
722 break;
723 case 0x34:
724 case 0x24:
725 mc = 2;
726 nc = 4;
727#if defined(__AVX2__) && defined(__F16C__)
728 gemmMx4<2>(m0, m, n0, n);
729#else
730 gemm<2, 4>(m0, m, n0, n);
731#endif
732 break;
733 case 0x33:
734#endif
735 case 0x32:
736 mc = 3;
737 nc = 2;
738 gemm<3, 2>(m0, m, n0, n);
739 break;
740 case 0x23:
741 mc = 2;
742 nc = 3;
743 gemm<2, 3>(m0, m, n0, n);
744 break;
745 case 0x41:
746 mc = 4;
747 nc = 1;
748#if defined(__AVX2__) && defined(__F16C__)
749 gemm4xN<1>(m0, m, n0, n);
750#else
751 gemm<4, 1>(m0, m, n0, n);
752#endif
753 break;
754 case 0x22:
755 mc = 2;
756 nc = 2;
757 gemm<2, 2>(m0, m, n0, n);
758 break;
759 case 0x14:
760 mc = 1;
761 nc = 4;
762#if defined(__AVX2__) && defined(__F16C__)
763 gemmMx4<1>(m0, m, n0, n);
764#else
765 gemm<1, 4>(m0, m, n0, n);
766#endif
767 break;
768 case 0x31:
769 mc = 3;
770 nc = 1;
771 gemm<3, 1>(m0, m, n0, n);
772 break;
773 case 0x13:
774 mc = 1;
775 nc = 3;
776 gemm<1, 3>(m0, m, n0, n);
777 break;
778 case 0x21:
779 mc = 2;
780 nc = 1;
781 gemm<2, 1>(m0, m, n0, n);
782 break;
783 case 0x12:
784 mc = 1;
785 nc = 2;
786 gemm<1, 2>(m0, m, n0, n);
787 break;
788 case 0x11:
789 mc = 1;
790 nc = 1;
791 gemm<1, 1>(m0, m, n0, n);
792 break;
793 default:
794 return;
795 }
796 mp = m0 + (m - m0) / mc * mc;
797 np = n0 + (n - n0) / nc * nc;
        mnpack(mp, m, n0, np);
        mnpack(m0, m, np, n);
800 }
801
802#if defined(__AVX2__) && defined(__F16C__)
803// Templated functions for gemm of dimensions 4xN
804 template <int RN>
805 NOINLINE void gemm4xN(int64_t m0, int64_t m, int64_t n0, int64_t n) {
806 int64_t ytiles = (m - m0) / 4;
807 int64_t xtiles = (n - n0) / RN;
808 int64_t tiles = xtiles * ytiles;
809 int64_t duty = (tiles + nth - 1) / nth;
810 int64_t start = duty * ith;
811 int64_t end = start + duty;
812 if (end > tiles)
813 end = tiles;
814 for (int64_t job = start; job < end; ++job) {
815 int64_t ii = m0 + job / xtiles * 4;
816 int64_t jj = n0 + job % xtiles * RN;
817 __m256 Cv[RN][4] = {};
818 for (int64_t l = 0; l < k; ++l) {
819 uint64_t a_delta = ((uint64_t)A[lda * (ii + 3) + l].d << 48) | ((uint64_t)A[lda * (ii + 2) + l].d << 32) | ((uint64_t)A[lda * (ii + 1) + l].d << 16) | (A[lda * (ii + 0) + l].d);
820 // Convert delta values for four blocks to float values
                __m128 da = _mm_cvtph_ps(_mm_set_epi64x(0, a_delta));
822 __m256i avec0 = load(A + lda * (ii + 0) + l);
823 __m256i avec1 = load(A + lda * (ii + 1) + l);
824 __m256i avec2 = load(A + lda * (ii + 2) + l);
825 __m256i avec3 = load(A + lda * (ii + 3) + l);
826 for (int64_t j = 0; j < RN; ++j) {
827 __m128 db = _mm_set1_ps(unhalf(B[ldb * (jj + j) + l].d));
                    // Compute the product of the delta values for four blocks and replicate it across the 256-bit lane
                    __m256 dvec = _mm256_castps128_ps256(_mm_mul_ps(da, db));
                    dvec = _mm256_permute2f128_ps(dvec, dvec, 0);
                    // Compute the dot products and multiply them with the appropriate delta value products
                    Cv[j][0] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
                                    updot(_mm256_sign_epi8(avec0, avec0),
                                          _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec0)),
                                    Cv[j][0]);
                    Cv[j][1] = madd(_mm256_shuffle_ps(dvec, dvec, 85),
                                    updot(_mm256_sign_epi8(avec1, avec1),
                                          _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec1)),
                                    Cv[j][1]);
                    Cv[j][2] = madd(_mm256_shuffle_ps(dvec, dvec, 170),
                                    updot(_mm256_sign_epi8(avec2, avec2),
                                          _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec2)),
                                    Cv[j][2]);
                    Cv[j][3] = madd(_mm256_shuffle_ps(dvec, dvec, 255),
                                    updot(_mm256_sign_epi8(avec3, avec3),
                                          _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec3)),
                                    Cv[j][3]);
848 }
849 }
850
851 for (int64_t j = 0; j < RN; ++j)
852 for (int64_t i = 0; i < 4; ++i)
853 C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
854 }
855 }
856
857 // Templated functions for gemm of dimensions Mx4
858 template <int RM>
859 NOINLINE void gemmMx4(int64_t m0, int64_t m, int64_t n0, int64_t n) {
860 int64_t ytiles = (m - m0) / RM;
861 int64_t xtiles = (n - n0) / 4;
862 int64_t tiles = xtiles * ytiles;
863 int64_t duty = (tiles + nth - 1) / nth;
864 int64_t start = duty * ith;
865 int64_t end = start + duty;
866 if (end > tiles)
867 end = tiles;
868 for (int64_t job = start; job < end; ++job) {
869 int64_t ii = m0 + job / xtiles * RM;
870 int64_t jj = n0 + job % xtiles * 4;
871 __m256 Cv[4][RM] = {};
872 for (int64_t l = 0; l < k; ++l) {
873 uint64_t b_delta = ((uint64_t)B[ldb * (jj + 3) + l].d << 48) | ((uint64_t)B[ldb * (jj + 2) + l].d << 32) | ((uint64_t)B[ldb * (jj + 1) + l].d << 16) | (B[ldb * (jj + 0) + l].d);
874 // Convert delta values for four blocks to float values
                __m128 db = _mm_cvtph_ps(_mm_set_epi64x(0, b_delta));
876 __m256i bvec0 = load(B + ldb * (jj + 0) + l);
877 __m256i bvec1 = load(B + ldb * (jj + 1) + l);
878 __m256i bvec2 = load(B + ldb * (jj + 2) + l);
879 __m256i bvec3 = load(B + ldb * (jj + 3) + l);
880 for (int64_t i = 0; i < RM; ++i) {
881 __m128 da = _mm_set1_ps(unhalf((A[lda * (ii + i) + l].d)));
                    // Compute the product of the delta values for four blocks and replicate it across the 256-bit lane
                    __m256 dvec = _mm256_castps128_ps256(_mm_mul_ps(da, db));
                    dvec = _mm256_permute2f128_ps(dvec, dvec, 0);
                    // Compute the dot products and multiply them with the appropriate delta value products
                    Cv[0][i] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
                                                           load(A + lda * (ii + i) + l)),
                                          _mm256_sign_epi8(bvec0, load(A + lda * (ii + i) + l))),
                                    Cv[0][i]);
                    Cv[1][i] = madd(_mm256_shuffle_ps(dvec, dvec, 85),
                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
                                                           load(A + lda * (ii + i) + l)),
                                          _mm256_sign_epi8(bvec1, load(A + lda * (ii + i) + l))),
                                    Cv[1][i]);
                    Cv[2][i] = madd(_mm256_shuffle_ps(dvec, dvec, 170),
                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
                                                           load(A + lda * (ii + i) + l)),
                                          _mm256_sign_epi8(bvec2, load(A + lda * (ii + i) + l))),
                                    Cv[2][i]);
                    Cv[3][i] = madd(_mm256_shuffle_ps(dvec, dvec, 255),
                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
                                                           load(A + lda * (ii + i) + l)),
                                          _mm256_sign_epi8(bvec3, load(A + lda * (ii + i) + l))),
                                    Cv[3][i]);
906 }
907 }
908 for (int64_t j = 0; j < 4; ++j)
909 for (int64_t i = 0; i < RM; ++i)
910 C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
911 }
912 }
913#endif
914
915 template <int RM, int RN>
916 NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
917 int64_t ytiles = (m - m0) / RM;
918 int64_t xtiles = (n - n0) / RN;
919 int64_t tiles = xtiles * ytiles;
920 int64_t duty = (tiles + nth - 1) / nth;
921 int64_t start = duty * ith;
922 int64_t end = start + duty;
923 if (end > tiles)
924 end = tiles;
925 for (int64_t job = start; job < end; ++job) {
926 int64_t ii = m0 + job / xtiles * RM;
927 int64_t jj = n0 + job % xtiles * RN;
928 __m256 Cv[RN][RM] = {};
929 for (int64_t l = 0; l < k; ++l)
930 for (int64_t j = 0; j < RN; ++j)
931 for (int64_t i = 0; i < RM; ++i) {
932#if defined(__AVX2__)
                        __m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
                                                              load(A + lda * (ii + i) + l)),
                                             _mm256_sign_epi8(load(B + ldb * (jj + j) + l),
                                                              load(A + lda * (ii + i) + l)));
937#else
938 __m128i ali0 = load0(A + lda * (ii + i) + l);
939 __m128i ali1 = load1(A + lda * (ii + i) + l);
940 __m128i blj0 = load0(B + ldb * (jj + j) + l);
941 __m128i blj1 = load1(B + ldb * (jj + j) + l);
942
943 __m128i sepAA0 = _mm_sign_epi8(ali0, ali0);
944 __m128i sepAA1 = _mm_sign_epi8(ali1, ali1);
945 __m128i sepBA0 = _mm_sign_epi8(blj0, ali0);
946 __m128i sepBA1 = _mm_sign_epi8(blj1, ali1);
947
948 // updot
949 const __m128i oneFill = _mm_set1_epi16(1);
950 __m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0);
951 __m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1);
952 __m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0)));
953#endif
954 Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
955 unhalf(B[ldb * (jj + j) + l].d)),
956 udTmp,
957 Cv[j][i]);
958 }
959 for (int64_t j = 0; j < RN; ++j)
960 for (int64_t i = 0; i < RM; ++i)
961 C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
962 }
963 }
964
965 inline __m256i load(const block_q8_0 *b) {
        return _mm256_loadu_si256((const __m256i *)b->qs);
967 }
968
969 inline __m128i load0(const block_q8_0 *b) {
        return _mm_loadu_si128((const __m128i *)b->qs);
971 }
972
973 inline __m128i load1(const block_q8_0 *b) {
        return _mm_loadu_si128(((const __m128i *)b->qs) + 1);
975 }
976
977 inline __m256i load(const block_q4_0 *b) {
        return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8));
979 }
980
981 inline __m128i load0(const block_q4_0 *b) {
        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
        return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), x), _mm_set1_epi8(8));
984 }
985
986 inline __m128i load1(const block_q4_0 *b) {
        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
        return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));
989 }
990
991 inline __m256i load(const block_q5_0 *b) {
        return _mm256_or_si256(denibble(b->qs), bittobyte(b->qh));
993 }
994
995 inline __m128i load0(const block_q5_0* b) {
        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
        uint32_t x32;
        memcpy(&x32, b->qh, sizeof(uint32_t));
        __m128i qxl = _mm_and_si128(_mm_set1_epi8(15), x);
        __m128i bytesl = _mm_cmpeq_epi8(_mm_set1_epi64x(-1),
                                        _mm_or_si128(_mm_set1_epi64x(0x7fbfdfeff7fbfdfe),
                                                     _mm_shuffle_epi8(_mm_set1_epi32(x32),
                                                                      _mm_set_epi64x(0x0101010101010101, 0x0000000000000000))));
        bytesl = _mm_andnot_si128(bytesl, _mm_set1_epi8((char)0xF0));
        return _mm_or_si128(qxl, bytesl);
1006 }
1007
1008 inline __m128i load1(const block_q5_0* b) {
        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
        uint32_t x32;
        memcpy(&x32, b->qh, sizeof(uint32_t));
        __m128i qxh = _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4));
        __m128i bytesh = _mm_cmpeq_epi8(_mm_set1_epi64x(-1),
                                        _mm_or_si128(_mm_set1_epi64x(0x7fbfdfeff7fbfdfe),
                                                     _mm_shuffle_epi8(_mm_set1_epi32(x32),
                                                                      _mm_set_epi64x(0x0303030303030303, 0x0202020202020202))));
        bytesh = _mm_andnot_si128(bytesh, _mm_set1_epi8((char)0xF0));
        return _mm_or_si128(qxh, bytesh);
1019 }
1020
1021 inline __m256i load(const block_iq4_nl *b) {
1022 return MM256_SET_M128I(load1(b), load0(b));
1023 }
1024
1025 inline __m128i load0(const block_iq4_nl *b) {
        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
        return _mm_shuffle_epi8(iq4nlt, _mm_and_si128(_mm_set1_epi8(15), x));
1028 }
1029
1030 inline __m128i load1(const block_iq4_nl *b) {
        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
        return _mm_shuffle_epi8(iq4nlt, _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)));
1033 }
1034
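    // Dot product of the unsigned bytes of u with the signed bytes of s: groups of four products are summed into eight int32 lanes and converted to float.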
1035 inline __m256 updot(__m256i u, __m256i s) {
1036 __m256i res;
1037#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
1038 res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s);
1039#elif defined(__AVXVNNI__)
        res = _mm256_dpbusd_avx_epi32(_mm256_setzero_si256(), u, s);
1041#else
1042 res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s));
1043#endif
        return _mm256_cvtepi32_ps(res);
1045 }
1046
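    // Unpack 32 4-bit nibbles into 32 bytes: the low nibbles go to the lower 128-bit lane, the high nibbles to the upper lane.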
1047 static inline __m256i denibble(const uint8_t *p) {
        __m128i x = _mm_loadu_si128((const __m128i *)p);
        return _mm256_and_si256(_mm256_set1_epi8(15),
                                _mm256_insertf128_si256(_mm256_castsi128_si256(x),
                                                        _mm_srli_epi16(x, 4), 1));
1052 }
1053
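    // Expand the 32 bits at p into 32 bytes: 0x00 where a bit is set, 0xF0 where it is clear (this folds the -16 offset of q5_0 into the high nibble).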
1054 static inline __m256i bittobyte(const uint8_t *p) {
1055 uint32_t x32;
        memcpy(&x32, p, sizeof(uint32_t));
        __m256i bytes = _mm256_cmpeq_epi8(_mm256_set1_epi64x(-1),
                                          _mm256_or_si256(_mm256_set1_epi64x(0x7fbfdfeff7fbfdfe),
                                                          _mm256_shuffle_epi8(_mm256_set1_epi32(x32),
                                                                              _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202,
                                                                                                0x0101010101010101, 0x0000000000000000))));
        return _mm256_andnot_si256(bytes, _mm256_set1_epi8((char)0xF0));
1063 }
1064
1065 const TA *const A;
1066 const TB *const B;
1067 TC *const C;
1068 const int64_t k;
1069 const int64_t lda;
1070 const int64_t ldb;
1071 const int64_t ldc;
1072 const int ith;
1073 const int nth;
1074 __m128i iq4nlt;
1075};
1076#endif // __AVX__
1077
// PPC Implementation
1079#if defined(__MMA__)
1080
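// Disassemble an MMA accumulator and store its 4x4 float tile into C at row ii, column jj.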
1081#define SAVE_ACC(ACC, ii, jj) \
1082 __builtin_mma_disassemble_acc(vec_C, ACC); \
1083 for (int I = 0; I < 4; I++) { \
1084 for (int J = 0; J < 4; J++) { \
1085 *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J); \
1086 } \
1087 } \
1088
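// tinyBLAS_BF16_PPC multiplies bf16 matrices using the POWER10 MMA bf16 outer-product instructions (xvbf16ger2pp).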
1089template <typename TA, typename TB, typename TC>
1090class tinyBLAS_BF16_PPC {
1091 public:
1092 tinyBLAS_BF16_PPC(int64_t k,
1093 const TA *A, int64_t lda,
1094 const TB *B, int64_t ldb,
1095 TC *C, int64_t ldc,
1096 int ith, int nth)
1097 : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
1098 }
1099
1100 void matmul(int64_t m, int64_t n) {
1101 mnpack(0, m, 0, n);
1102 }
1103
1104 private:
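    // Interleave 2, 4, or 8 packed row vectors into the order expected by the bf16 MMA kernels and store them at vecOffset.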
1105 void vector_permute_store(vec_t *c, int numVec, unsigned char *vecOffset) {
1106 vec_t t[8], s[8];
1107 vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23};
1108 vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};
1109 vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
1110 vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
1111
1112 if (numVec == 2) {
1113 t[0] = vec_perm(c[0], c[1], swiz1);
1114 t[1] = vec_perm(c[2], c[3], swiz1);
1115 s[0] = vec_perm(t[0], t[1], swiz3);
1116 s[1] = vec_perm(t[0], t[1], swiz4);
1117 vec_xst(s[0], 0, (vec_t*)vecOffset);
1118 vec_xst(s[1], 0, (vec_t*)(vecOffset + 16));
1119 } else if (numVec == 4) {
1120 t[0] = vec_perm(c[0], c[1], swiz1);
1121 t[1] = vec_perm(c[0], c[1], swiz2);
1122 t[2] = vec_perm(c[2], c[3], swiz1);
1123 t[3] = vec_perm(c[2], c[3], swiz2);
1124 s[0] = vec_perm(t[0], t[2], swiz3);
1125 s[1] = vec_perm(t[0], t[2], swiz4);
1126 s[2] = vec_perm(t[1], t[3], swiz3);
1127 s[3] = vec_perm(t[1], t[3], swiz4);
1128 for (int i = 0; i < 4; ++i)
1129 vec_xst(s[i], 0, (vec_t*)(vecOffset + i * 16));
1130 } else if (numVec == 8) {
1131 for (int i = 0; i < 4; i += 2) {
1132 t[i+0] = vec_perm(c[i+0], c[i+1], swiz1);
1133 t[i+1] = vec_perm(c[i+0], c[i+1], swiz2);
1134 }
1135 for (int i = 4; i < 8; i += 2) {
1136 t[i+0] = vec_perm(c[i+0], c[i+1], swiz1);
1137 t[i+1] = vec_perm(c[i+0], c[i+1], swiz2);
1138 }
1139 s[0] = vec_perm(t[0], t[2], swiz3);
1140 s[1] = vec_perm(t[0], t[2], swiz4);
1141 s[2] = vec_perm(t[1], t[3], swiz3);
1142 s[3] = vec_perm(t[1], t[3], swiz4);
1143 s[4] = vec_perm(t[4], t[6], swiz3);
1144 s[5] = vec_perm(t[4], t[6], swiz4);
1145 s[6] = vec_perm(t[5], t[7], swiz3);
1146 s[7] = vec_perm(t[5], t[7], swiz4);
1147 for (int i = 0; i < 8; ++i)
1148 vec_xst(s[i], 0, (vec_t*)(vecOffset + i * 16));
1149 }
1150 }
1151
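    // Pack a rows x cols panel of the matrix into the contiguous, interleaved layout consumed by the MMA kernels.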
1152 void packNormal(const TA* a, int64_t lda, int rows, int cols, unsigned char* vec) {
1153 int64_t i, j;
1154 TA *aoffset = NULL;
1155 unsigned char *vecOffset = NULL;
1156 TA * aoffsets[8];
1157 vector unsigned char c_arr[8];
1158 aoffset = const_cast<TA*>(a);
1159 vecOffset = vec;
1160 j = (rows >> 3);
1161 if (j > 0) {
1162 do {
1163 if (cols == 4) {
1164 aoffsets[0] = aoffset;
1165 for (int it = 1; it < 4; ++it)
1166 aoffsets[it] = aoffsets[it-1] + lda;
1167 aoffset += 4 * lda;
1168 for (int i = 0; i < 4; ++i)
1169 c_arr[i] = vec_xl(0, (vector unsigned char*)aoffsets[i]);
1170 vector_permute_store(c_arr, 4, vecOffset);
1171 for (int i = 0; i<4; i++)
1172 aoffsets[i] = aoffsets[i]+lda;
1173 vecOffset +=64;
1174 }
1175 i = (cols >> 3);
1176 if (i > 0) {
1177 aoffsets[0] = aoffset;
1178 for (int it = 1; it < 8; ++it) {
1179 aoffsets[it] = aoffsets[it-1] + lda;
1180 }
1181 aoffset += 8 * lda;
1182 do {
1183 for (int it = 0; it < 8; ++it)
1184 c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);
1185 vector_permute_store(c_arr, 8, vecOffset);
1186 for (int it = 0; it < 8; ++it)
1187 aoffsets[it] = aoffsets[it] + 8*lda;
1188 vecOffset += 128;
1189 i--;
1190 } while(i > 0);
1191 }
1192 j--;
1193 } while(j > 0);
1194 }
1195 if (rows & 4) {
1196 aoffsets[0] = aoffset;
1197 for (int it = 1; it < 4; ++it)
1198 aoffsets[it] = aoffsets[it-1] + lda;
1199 aoffset += 4 * lda;
1200 if (cols == 4) {
1201 for (int it = 0; it < 4; ++it)
1202 c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);
1203 vector_permute_store(c_arr, 2, vecOffset);
1204 for (int it = 0; it< 4; it++)
1205 aoffsets[it] = aoffsets[it] + lda;
1206 vecOffset += 32;
1207 }
1208 i = (cols >> 3);
1209 if (i > 0) {
1210 do {
1211 for (int it = 0; it < 4; ++it)
1212 c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);
1213 vector_permute_store(c_arr, 4, vecOffset);
1214 for (int it = 0; it< 4; it++)
1215 aoffsets[it] = aoffsets[it] + 8*lda;
1216 vecOffset += 64;
1217 i--;
1218 } while(i > 0);
1219 }
1220 }
1221 if (rows & 3) {
1222 aoffsets[0] = aoffset;
1223 for (int it = 1; it < 4; ++it)
1224 aoffsets[it] = aoffsets[it-1] + lda;
1225 if (cols == 4) {
1226 switch(rows) {
1227 case 3: c_arr[2] = vec_xl(0, (vector unsigned char*)aoffsets[2]);
1228 case 2: c_arr[1] = vec_xl(0, (vector unsigned char*)aoffsets[1]);
1229 case 1: c_arr[0] = vec_xl(0, (vector unsigned char*)aoffsets[0]);
1230 break;
1231 }
1232 vector_permute_store(c_arr, 2, vecOffset);
1233 for (int it = 0; it< 4; it++)
1234 aoffsets[it] = aoffsets[it] + lda;
1235 vecOffset += 32;
1236 }
1237 i = (cols >> 3);
1238 if (i > 0) {
1239 do {
1240 switch(rows) {
1241 case 3: c_arr[2] = vec_xl(0, (vector unsigned char*)aoffsets[2]);
1242 case 2: c_arr[1] = vec_xl(0, (vector unsigned char*)aoffsets[1]);
1243 case 1: c_arr[0] = vec_xl(0, (vector unsigned char*)aoffsets[0]);
1244 break;
1245 }
1246 vector_permute_store(c_arr, 4, vecOffset);
1247 for (int it = 0; it <4; it++)
1248 aoffsets[it] = aoffsets[it] + 8* lda;
1249 vecOffset += 64;
1250 i--;
1251 } while(i > 0);
1252 }
1253 }
1254 }
1255
1256 void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
1257 int64_t mc, nc, mp, np;
1258 int m_rem = MIN(m - m0, 8);
1259 int n_rem = MIN(n - n0, 8);
1260
1261 if (m_rem >= 8 && n_rem >= 8) {
1262 mc = 8;
1263 nc = 8;
1264 gemm<8,8>(m0, m, n0, n);
1265 } else if (m_rem >= 4 && n_rem >= 8) {
1266 mc = 4;
1267 nc = 8;
1268 gemm<4,8>(m0, m, n0, n);
1269 } else if (m_rem >=8 && n_rem >=4){
1270 mc = 8;
1271 nc = 4;
1272 gemm<8,4>(m0, m, n0, n);
1273 } else if ((m_rem < 4) && (n_rem >= 8)) {
1274 nc = 8;
1275 switch(m_rem) {
1276 case 1:
1277 mc = 1;
1278 gemm_Mx8<1>(m0, m, n0, n);
1279 break;
1280 case 2:
1281 mc = 2;
1282 gemm_Mx8<2>(m0, m, n0, n);
1283 break;
1284 case 3:
1285 mc = 3;
1286 gemm_Mx8<3>(m0, m, n0, n);
1287 break;
1288 default:
1289 return;
1290 }
1291 } else if (m_rem >= 4 && n_rem >= 4) {
1292 mc = 4;
1293 nc = 4;
1294 gemm_small<4, 4>(m0, m, n0, n);
1295 } else if ((m_rem > 4) && (n_rem < 4)) {
1296 mc = 4;
1297 switch(n_rem) {
1298 case 1:
1299 nc = 1;
1300 gemm_small<4, 1>(m0, m, n0, n);
1301 break;
1302 case 2:
1303 nc = 2;
1304 gemm_small<4, 2>(m0, m, n0, n);
1305 break;
1306 case 3:
1307 nc = 3;
1308 gemm_small<4, 3>(m0, m, n0, n);
1309 break;
1310
1311 default:
1312 return;
1313 }
1314 } else {
1315 switch((m_rem << 4) | n_rem) {
1316 case 0x43:
1317 mc = 4;
1318 nc = 3;
1319 gemm_small<4, 3>(m0, m, n0, n);
1320 break;
1321 case 0x42:
1322 mc = 4;
1323 nc = 2;
1324 gemm_small<4, 2>(m0, m, n0, n);
1325 break;
1326 case 0x41:
1327 mc = 4;
1328 nc = 1;
1329 gemm_small<4, 1>(m0, m, n0, n);
1330 break;
1331 case 0x34:
1332 mc = 3;
1333 nc = 4;
1334 gemm_small<3, 4>(m0, m, n0, n);
1335 break;
1336 case 0x33:
1337 mc = 3;
1338 nc = 3;
1339 gemm_small<3, 3>(m0, m, n0, n);
1340 break;
1341 case 0x32:
1342 mc = 3;
1343 nc = 2;
1344 gemm_small<3, 2>(m0, m, n0, n);
1345 break;
1346 case 0x31:
1347 mc = 3;
1348 nc = 1;
1349 gemm_small<3, 1>(m0, m, n0, n);
1350 break;
1351 case 0x24:
1352 mc = 2;
1353 nc = 4;
1354 gemm_small<2,4>(m0, m, n0, n);
1355 break;
1356 case 0x23:
1357 mc = 2;
1358 nc = 3;
1359 gemm_small<2, 3>(m0, m, n0, n);
1360 break;
1361 case 0x22:
1362 mc = 2;
1363 nc = 2;
1364 gemm_small<2, 2>(m0, m, n0, n);
1365 break;
1366 case 0x21:
1367 mc = 2;
1368 nc = 1;
1369 gemm_small<2, 1>(m0, m, n0, n);
1370 break;
1371 case 0x14:
1372 mc = 1;
1373 nc = 4;
1374 gemm_small<1, 4>(m0, m, n0, n);
1375 break;
1376 case 0x13:
1377 mc = 1;
1378 nc = 3;
1379 gemm_small<1, 3>(m0, m, n0, n);
1380 break;
1381 case 0x12:
1382 mc = 1;
1383 nc = 2;
1384 gemm_small<1, 2>(m0, m, n0, n);
1385 break;
1386 case 0x11:
1387 mc = 1;
1388 nc = 1;
1389 gemm_small<1, 1>(m0, m, n0, n);
1390 break;
1391 default:
1392 return;
1393 }
1394 }
1395 mp = m0 + (m - m0) / mc * mc;
1396 np = n0 + (n - n0) / nc * nc;
1397 mnpack(mp, m, n0, np);
1398 mnpack(m0, m, np, n);
1399 }
1400
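    // Full-tile kernels: pack 4- or 8-row panels of A and B, step over k by 8 accumulating bf16 outer products, then store the results with SAVE_ACC.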
1401 void KERNEL_4x8(int64_t ii, int64_t jj) {
1402 vec_t vec_A[4], vec_B[8] , vec_C[4];
1403 acc_t acc_0, acc_1;
1404 __builtin_mma_xxsetaccz(&acc_0);
1405 __builtin_mma_xxsetaccz(&acc_1);
1406 for (int l = 0; l < k; l+=8) {
1407 packNormal((A+(ii*lda)+l), lda, 4, 8, (uint8_t*)vec_A);
1408 packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B);
1409 for (int x = 0; x < 4; x++) {
1410 __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
1411 __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]);
1412 }
1413 }
1414 SAVE_ACC(&acc_0, ii, jj);
1415 SAVE_ACC(&acc_1, ii, jj+4);
1416 }
1417
1418 void KERNEL_8x4(int64_t ii, int64_t jj) {
1419 vec_t vec_A[8], vec_B[4] , vec_C[4];
1420 acc_t acc_0, acc_1;
1421 __builtin_mma_xxsetaccz(&acc_0);
1422 __builtin_mma_xxsetaccz(&acc_1);
1423 for (int l = 0; l < k; l+=8) {
1424 packNormal((A+(ii*lda)+l), lda, 8, 8, (uint8_t*)vec_A);
1425 packNormal((B+(jj*ldb)+l), ldb, 8, 4, (uint8_t*)vec_B);
1426 for (int x = 0; x < 4; x++) {
1427 __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
1428 __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x+4], vec_B[x]);
1429 }
1430 }
1431 SAVE_ACC(&acc_0, ii, jj);
1432 SAVE_ACC(&acc_1, ii+4, jj);
1433 }
1434
1435
1436 void KERNEL_8x8(int64_t ii, int64_t jj) {
1437 vec_t vec_A[8], vec_B[8], vec_C[4];
1438 acc_t acc_0, acc_1, acc_2, acc_3;
1439 __builtin_mma_xxsetaccz(&acc_0);
1440 __builtin_mma_xxsetaccz(&acc_1);
1441 __builtin_mma_xxsetaccz(&acc_2);
1442 __builtin_mma_xxsetaccz(&acc_3);
1443 for (int l = 0; l < k; l+=8) {
1444 packNormal(A+(ii*lda)+l, lda, 8, 8, (uint8_t*)vec_A);
1445 packNormal(B+(jj*ldb)+l, ldb, 8, 8, (uint8_t*)vec_B);
1446 for (int x = 0; x < 4; x++) {
1447 __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
1448 __builtin_mma_xvbf16ger2pp(&acc_1, (vec_t)vec_A[x], (vec_t)vec_B[x+4]);
1449 __builtin_mma_xvbf16ger2pp(&acc_2, (vec_t)vec_A[x+4], (vec_t)vec_B[x]);
1450 __builtin_mma_xvbf16ger2pp(&acc_3, (vec_t)vec_A[x+4], (vec_t)vec_B[x+4]);
1451 }
1452 }
1453
1454 SAVE_ACC(&acc_0, ii, jj);
1455 SAVE_ACC(&acc_1, ii, jj+4);
1456 SAVE_ACC(&acc_2, ii+4, jj);
1457 SAVE_ACC(&acc_3, ii+4, jj+4);
1458 }
1459
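    // Handle small leftover tiles (up to 4x4) with a single accumulator, stepping over k by 4.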
1460 template<int RM, int RN>
1461 void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n) {
1462 int64_t ytiles = (m - m0) / RM;
1463 int64_t xtiles = (n - n0) / RN;
1464 int64_t tiles = xtiles * ytiles;
1465 int64_t duty = (tiles + nth - 1) / nth;
1466 int64_t start = duty * ith;
1467 int64_t end = start + duty;
1468 if (end > tiles)
1469 end = tiles;
1470 for (int64_t job = start; job < end; ++job) {
1471 int64_t ii = m0 + job / xtiles * RM;
1472 int64_t jj = n0 + job % xtiles * RN;
1473 vec_t vec_C[4];
1474 acc_t acc_0;
1475 __builtin_mma_xxsetaccz(&acc_0);
1476 vec_t vec_A[2], vec_B[2];
1477 for (int l=0; l<k; l+=4) {
1478 packNormal(A+(ii*lda)+l, lda, RM, 4, (uint8_t*)vec_A);
1479 packNormal(B+(jj*ldb)+l, ldb, RN, 4, (uint8_t*)vec_B);
1480 for (int x = 0; x<2; x++) {
1481 __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
1482 }
1483 }
1484 __builtin_mma_disassemble_acc(vec_C, &acc_0);
1485 for (int I = 0; I < RM; I++) {
1486 for (int J = 0; J < RN; J++) {
1487 *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
1488 }
1489 }
1490 }
1491 }
1492
1493 template<int RM>
1494 void gemm_Mx8(int64_t m0, int64_t m, int64_t n0, int64_t n) {
1495 int RN = 8;
1496 int64_t ytiles = (m - m0) / RM;
1497 int64_t xtiles = (n - n0) / RN;
1498 int64_t tiles = xtiles * ytiles;
1499 int64_t duty = (tiles + nth - 1) / nth;
1500 int64_t start = duty * ith;
1501 int64_t end = start + duty;
1502 if (end > tiles)
1503 end = tiles;
1504 for (int64_t job = start; job < end; ++job) {
1505 int64_t ii = m0 + job / xtiles * RM;
1506 int64_t jj = n0 + job % xtiles * RN;
1507 vec_t vec_C[4];
1508 acc_t acc_0, acc_1;
1509 __builtin_mma_xxsetaccz(&acc_0);
1510 __builtin_mma_xxsetaccz(&acc_1);
1511 vec_t vec_A[4], vec_B[8];
1512 for (int l=0; l<k; l+=8) {
1513 packNormal(A+(ii*lda)+l, lda, RM, 8, (uint8_t*)vec_A);
1514 packNormal(B+(jj*ldb)+l, ldb, RN, 8, (uint8_t*)vec_B);
1515 for (int x = 0; x<4; x++) {
1516 __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
1517 __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]);
1518 }
1519 }
1520 __builtin_mma_disassemble_acc(vec_C, &acc_0);
1521 for (int I = 0; I < RM; I++) {
1522 for (int J = 0; J < 4; J++) {
1523 *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
1524 }
1525 }
1526 __builtin_mma_disassemble_acc(vec_C, &acc_1);
1527 for (int I = 0; I < RM; I++) {
1528 for (int J = 0; J < 4; J++) {
1529 *((TC*)(C+ii+((jj+4+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
1530 }
1531 }
1532 }
1533 }
1534
1535 template<int RM, int RN>
1536 inline void kernel(int64_t ii, int64_t jj) {
1537 if constexpr(RM == 4 && RN == 8) {
1538 KERNEL_4x8(ii,jj);
1539 } else if constexpr(RM == 8 && RN == 8) {
1540 KERNEL_8x8(ii,jj);
1541 } else if constexpr(RM == 8 && RN == 4) {
1542 KERNEL_8x4(ii,jj);
1543 } else {
1544 assert(false && "RN/RM values not supported");
1545 }
1546 }
1547
1548 template <int RM, int RN>
1549 NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
1550 int64_t ytiles = (m - m0) / RM;
1551 int64_t xtiles = (n - n0) / RN;
1552 int64_t tiles = xtiles * ytiles;
1553 int64_t duty = (tiles + nth - 1) / nth;
1554 int64_t start = duty * ith;
1555 int64_t end = start + duty;
1556 if (end > tiles)
1557 end = tiles;
1558 for (int64_t job = start; job < end; ++job) {
1559 int64_t ii = m0 + job / xtiles * RM;
1560 int64_t jj = n0 + job % xtiles * RN;
1561 kernel<RM, RN>(ii, jj);
1562 }
1563 }
1564
1565 const TA *const A;
1566 const TB *const B;
1567 TC *C;
1568 const int64_t k;
1569 const int64_t lda;
1570 const int64_t ldb;
1571 const int64_t ldc;
1572 const int ith;
1573 const int nth;
1574};
1575
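// tinyBLAS_Q0_PPC multiplies a quantized matrix A (q4_0 or q8_0) by a q8_0 matrix B using POWER10 MMA integer outer products.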
1576template <typename TA>
1577class tinyBLAS_Q0_PPC {
1578 public:
1579 tinyBLAS_Q0_PPC(int64_t k,
1580 const TA *A, int64_t lda,
1581 const block_q8_0 *B, int64_t ldb,
1582 float *C, int64_t ldc,
1583 int ith, int nth)
1584 : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
1585 }
1586
1587 void matmul(int64_t m, int64_t n) {
1588 mnpack(0, m, 0, n);
1589 }
1590
1591 private:
1592
1593 inline void save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) {
1594 for (int I = 0; I < RM; I++) {
1595 for (int J = 0; J < RN; J++) {
1596 *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J);
1597 }
1598 }
1599 }
1600
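    // Convert an int32 accumulator tile to float, add the -128 * rowsum compensation for the unsigned bias of the quants, and scale by the delta products in vs, accumulating into fin_res.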
1601 template<int size>
1602 inline void compute(acc_t* ACC, int c_idx, int s_idx, std::array<int, size>& comparray, vector float* vs, vector float* fin_res) {
1603 vector signed int vec_C[4];
1604 vector float CA[4] = {0};
1605 vector float res[4] = {0};
1606 __builtin_mma_disassemble_acc(vec_C, ACC);
1607 for (int i = 0; i < 4; i++) {
1608 CA[i] = vec_splats((float)(((double)comparray[c_idx+i]) * -128.0));
1609 res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
1610 fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]);
1611 }
1612 }
    /* This function processes quantized data from block_q4_0 elements.
     * First we extract the two int4 values packed in each int8_t into two vectors of signed int8.
     * Then we subtract 8 from each resulting element to recenter the unsigned nibbles as signed values.
     * We also compute the row sum, which is required to compensate for the above conversion. */
1617 inline void process_q4_elements(vector signed char (&c)[2], int* ca) {
1618 const vector signed char lowMask = vec_splats((signed char)0xF);
1619 const vector unsigned char v4 = vec_splats((unsigned char)0x4);
1620 const vector signed char v8 = vec_splats((signed char)0x8);
1621 vector signed int vsum = {0};
1622 vector signed int vsum2 = {0};
1623 c[0] = vec_and(c[1], lowMask);
1624 c[1] = vec_sr(c[1], v4);
1625 c[0] = vec_sub(c[0], v8);
1626 c[1] = vec_sub(c[1], v8);
1627 vsum = vec_sum4s(c[0], vsum);
1628 vsum2 = vec_sum4s(c[1], vsum2);
1629 vsum = vec_add(vsum, vsum2);
1630 *(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3];
1631 }
1632
1633 template <typename V1, typename V2>
1634 inline void vector_permute_store(V2 &s1, V2 &s2, V2 &s3, V2 &s4, V1 *vecOffset, bool flip) {
1635 vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
1636 vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
1637 vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
1638 vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
1639 V2 t1, t2, t3, t4, t5, t6, t7, t8;
1640 vector unsigned char xor_vector;
1641 uint8_t flip_vec = 0x80;
1642 xor_vector = vec_splats(flip_vec);
1643 t1 = vec_perm(s1, s2, swiz1);
1644 t2 = vec_perm(s1, s2, swiz2);
1645 t3 = vec_perm(s3, s4, swiz1);
1646 t4 = vec_perm(s3, s4, swiz2);
1647 t5 = vec_perm(t1, t3, swiz3);
1648 t6 = vec_perm(t1, t3, swiz4);
1649 t7 = vec_perm(t2, t4, swiz3);
1650 t8 = vec_perm(t2, t4, swiz4);
1651 if (flip == true) {
1652 t5 = vec_xor(t5, xor_vector);
1653 t6 = vec_xor(t6, xor_vector);
1654 t7 = vec_xor(t7, xor_vector);
1655 t8 = vec_xor(t8, xor_vector);
1656 }
1657 vec_xst(t5, 0, vecOffset);
1658 vec_xst(t6, 0, vecOffset+16);
1659 vec_xst(t7, 0, vecOffset+32);
1660 vec_xst(t8, 0, vecOffset+48);
1661 }
1662
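    // Unpack a rows x cols panel of q4_0 blocks into signed int8, storing the permuted quants in vec and each row's sum in comparray for the bias compensation.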
1663 template<int size>
1664 void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array<int, size>& comparray) {
1665 int64_t i, j;
1666 TA *aoffset = NULL;
1667 int8_t *vecOffset = NULL;
1668 TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
1669 TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
1670 vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
1671 vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
1672 aoffset = const_cast<TA*>(a);
1673 vecOffset = vec;
1674 j = (rows >> 3);
1675 if (j > 0) {
1676 do {
1677 aoffset1 = aoffset;
1678 aoffset2 = aoffset1 + lda;
1679 aoffset3 = aoffset2 + lda;
1680 aoffset4 = aoffset3 + lda;
1681 aoffset5 = aoffset4 + lda;
1682 aoffset6 = aoffset5 + lda;
1683 aoffset7 = aoffset6 + lda;
1684 aoffset8 = aoffset7 + lda;
1685 aoffset += 8 * lda;
1686 i = (cols >> 2);
1687 if (i > 0) {
1688 do {
1689 c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
1690 c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
1691 c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
1692 c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset4->qs));
1693 c5[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset5->qs));
1694 c6[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset6->qs));
1695 c7[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset7->qs));
1696 c8[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset8->qs));
1697
1698 process_q4_elements(c1, &comparray[0]);
1699 process_q4_elements(c2, &comparray[1]);
1700 process_q4_elements(c3, &comparray[2]);
1701 process_q4_elements(c4, &comparray[3]);
1702 process_q4_elements(c5, &comparray[4]);
1703 process_q4_elements(c6, &comparray[5]);
1704 process_q4_elements(c7, &comparray[6]);
1705 process_q4_elements(c8, &comparray[7]);
1706 vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
1707 vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
1708 vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false);
1709 vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false);
1710 aoffset1 += lda;
1711 aoffset2 += lda;
1712 aoffset3 += lda;
1713 aoffset4 += lda;
1714 aoffset5 += lda;
1715 aoffset6 += lda;
1716 aoffset7 += lda;
1717 aoffset8 += lda;
1718 vecOffset += 256;
1719 i--;
1720 } while (i > 0);
1721 }
1722 j--;
1723 } while (j > 0);
1724 }
1725
1726 if (rows & 4) {
1727 aoffset1 = aoffset;
1728 aoffset2 = aoffset1 + lda;
1729 aoffset3 = aoffset2 + lda;
1730 aoffset4 = aoffset3 + lda;
1731 aoffset += 4 * lda;
1732 i = (cols >> 2);
1733 if (i > 0) {
1734 do {
1735 c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
1736 c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
1737 c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
1738 c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset4->qs));
1739
1740 process_q4_elements(c1, &comparray[0]);
1741 process_q4_elements(c2, &comparray[1]);
1742 process_q4_elements(c3, &comparray[2]);
1743 process_q4_elements(c4, &comparray[3]);
1744 vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
1745 vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
1746 aoffset1 += lda;
1747 aoffset2 += lda;
1748 aoffset3 += lda;
1749 aoffset4 += lda;
1750 vecOffset += 128;
1751 i--;
1752 } while (i > 0);
1753 }
1754 }
1755
1756 if (rows & 3) {
1757 aoffset1 = aoffset;
1758 aoffset2 = aoffset1 + lda;
1759 aoffset3 = aoffset2 + lda;
1760 i = (cols >> 2);
1761 if (i > 0) {
1762 do {
1763 switch(rows) {
1764 case 3: c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
1765 case 2: c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
1766 case 1: c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
1767 break;
1768 }
1769 process_q4_elements(c1, &comparray[0]);
1770 process_q4_elements(c2, &comparray[1]);
1771 process_q4_elements(c3, &comparray[2]);
1772 process_q4_elements(c4, &comparray[3]);
1773 vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
1774 vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
1775 aoffset1 += lda;
1776 aoffset2 += lda;
1777 aoffset3 += lda;
1778 vecOffset += 128;
1779 i--;
1780 } while(i > 0);
1781 }
1782 }
1783 }
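    // packNormal interleaves a rows x cols panel of q8_0 quants (row stride
    // lda) into the layout consumed by the int8 MMA kernels. With flip set,
    // every byte has its sign bit toggled (i.e. +128) so it can be fed to
    // xvi8ger4pp as the unsigned operand; the kernels cancel that bias using
    // the per-row sums kept in comparray.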
1784 template<typename VA, typename VB>
1785 void packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
1786 int64_t i, j;
1787 block_q8_0 *aoffset = NULL;
1788 VA *vecOffset = NULL;
1789 block_q8_0* aoffsets[8];
1790 __vector_pair arr[8];
1791 VB c[8][2] = {0};
1792 VB c1[8] = {0}; VB c2[8] = {0};
1793 aoffset = const_cast<block_q8_0*>(a);
1794 vecOffset = vec;
1795 j = (rows >> 3);
1796 if (j > 0) {
1797 do {
1798 aoffsets[0] = aoffset;
1799 for (int it = 1; it < 8; it++)
1800 aoffsets[it] = aoffsets[it-1] + lda;
1801 aoffset += 8 * lda;
1802
1803 i = (cols >> 3);
1804 if (i > 0) {
1805 do {
1806 for (int it = 0; it < 8; it++) {
1807 arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
1808 __builtin_vsx_disassemble_pair(c[it], &arr[it]);
1809 c1[it] = c[it][0];
1810 c2[it] = c[it][1];
1811 }
1812 vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
1813 vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
1814 vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip);
1815 vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip);
1816 for (int it = 0; it < 8; it++)
1817 aoffsets[it] += lda;
1818 vecOffset += 256;
1819 i--;
1820 } while(i > 0);
1821 }
1822 j--;
1823 } while(j > 0);
1824 }
1825
1826 if (rows & 4) {
1827 aoffsets[0] = aoffset;
1828 for (int it = 1; it < 4; it++ )
1829 aoffsets[it] = aoffsets[it-1] + lda;
1830 aoffset += 4 * lda;
1831 i = (cols >> 3);
1832 if (i > 0) {
1833 do {
1834 for (int it = 0; it < 4; it++) {
1835 arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
1836 __builtin_vsx_disassemble_pair(c[it], &arr[it]);
1837 c1[it] = c[it][0];
1838 c2[it] = c[it][1];
1839 }
1840 vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
1841 vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
1842 for (int it = 0; it < 4; it++) {
1843 aoffsets[it] += lda;
1844 }
1845 vecOffset += 128;
1846 i--;
1847 } while(i > 0);
1848 }
1849 }
1850
1851 if (rows & 3) {
1852 aoffsets[0] = aoffset;
1853 for (int it = 1; it < 3; it++ )
1854 aoffsets[it] = aoffsets[it-1] + lda;
1855 i = (cols >> 3);
1856 if (i > 0) {
1857 do {
1858 switch(rows) {
1859 case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[2]->qs);
1860 __builtin_vsx_disassemble_pair(c[2], &arr[2]);
1861 c1[2] = c[2][0]; c2[2] = c[2][1];
1862 case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[1]->qs);
1863 __builtin_vsx_disassemble_pair(c[1], &arr[1]);
1864 c1[1] = c[1][0]; c2[1] = c[1][1];
1865 case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[0]->qs);
1866 __builtin_vsx_disassemble_pair(c[0], &arr[0]);
1867 c1[0] = c[0][0]; c2[0] = c[0][1];
1868 break;
1869 }
1870 vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
1871 vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
1872 for (int it = 0; it < 3; it++)
1873 aoffsets[it] += lda;
1874 vecOffset += 128;
1875 i--;
1876 } while(i > 0);
1877 }
1878 }
1879 }
1880
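    // mnpack recursively tiles the [m0, m) x [n0, n) output: full 8x8, 4x8
    // and 8x4 tiles go to the template kernels, smaller remainders go to
    // gemm_small, and the two trailing calls recurse into the ragged right
    // and bottom edges.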
1881 void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
1882 int m_rem = MIN(m - m0, 16);
1883 int n_rem = MIN(n - n0, 16);
1884
1885 int mc = 0, nc = 0;
1886
1887 if (m_rem >= 8 && n_rem >= 8) {
1888 mc = 8;
1889 nc = 8;
1890 gemm<8, 8>(m0, m, n0, n);
1891 } else if (m_rem >= 4 && n_rem >= 8) {
1892 mc = 4;
1893 nc = 8;
1894 gemm<4, 8>(m0, m, n0, n);
1895 } else if (m_rem >= 8 && n_rem >= 4) {
1896 mc = 8;
1897 nc = 4;
1898 gemm<8, 4>(m0, m, n0, n);
1899 } else if (m_rem >= 4 && n_rem >= 4) {
1900 mc = 4;
1901 nc = 4;
1902 gemm_small(m0, m, n0, n, mc, nc);
1903 } else {
1904 mc = (m_rem >= 4) ? 4 : m_rem;
1905 nc = (n_rem >= 4) ? 4 : n_rem;
1906 if (mc == 0 || nc == 0)
1907 return;
1908 gemm_small(m0, m, n0, n, mc, nc);
1909 }
1910
1911 int64_t mp = m0 + ((m - m0) / mc) * mc;
1912 int64_t np = n0 + ((n - n0) / nc) * nc;
1913 mnpack(mp, m, n0, np);
1914 mnpack(m0, m, np, n);
1915 }
1916
1917
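    // KERNEL_4x8 computes a 4x8 tile of C. For each block along k it packs
    // 4 rows of A and 8 columns of B, accumulates int8 outer products into
    // two 4x4 MMA accumulators, and then scales the integer results by the
    // block scales d(A)*d(B) held in vs while removing the +128 bias on B.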
1918 void KERNEL_4x8(int64_t ii, int64_t jj) {
1919 vec_t vec_A[8], vec_B[16] = {0};
1920 acc_t acc_0, acc_1;
1921 std::array<int, 4> comparray {};
1922 vector float fin_res[8] = {0};
1923 vector float vs[8] = {0};
1924 bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
1925 for (int l = 0; l < k; l++) {
1926 __builtin_mma_xxsetaccz(&acc_0);
1927 __builtin_mma_xxsetaccz(&acc_1);
1928 if (std::is_same_v<TA, block_q4_0>) {
1929 packNormalInt4<4>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray);
1930 } else {
1931 packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
1932 }
1933 packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
1934 for(int x = 0; x < 8; x++) {
1935 __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
1936 __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x], vec_B[x+8]);
1937 }
1938 for (int I = 0; I<4; I++) {
1939 for (int J = 0; J<4; J++) {
1940 *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
1941 *((float*)&vs[I+4]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
1942 }
1943 }
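            // q8_0 A blocks are packed without producing row sums, so compute
            // them here; they cancel the +128 bias added to B during packing.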
1944 if (!isAblock_q4) {
1945 auto aoffset = A+(ii*lda)+l;
1946 for (int i = 0; i < 4; i++) {
1947 comparray[i] = 0;
1948 int ca = 0;
1949 auto *at = aoffset->qs;
1950 for (int j = 0; j < 32; j++)
1951 ca += (int)*at++;
1952 comparray[i] = ca;
1953 aoffset += lda;
1954 }
1955 }
1956 compute<4>(&acc_0, 0, 0, comparray, vs, fin_res);
1957 compute<4>(&acc_1, 0, 4, comparray, vs, fin_res);
1958 }
1959 save_res(ii, jj, 0, fin_res);
1960 save_res(ii, jj+4, 4, fin_res);
1961 }
1962
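    // KERNEL_8x4 mirrors KERNEL_4x8 with the tile shape transposed: 8 rows
    // of A against 4 columns of B, split across two accumulators by row.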
1963 void KERNEL_8x4(int64_t ii, int64_t jj) {
1964 vec_t vec_A[16], vec_B[8] = {0};
1965 acc_t acc_0, acc_1;
1966 std::array<int, 8> comparray {};
1967 vector float fin_res[8] = {0};
1968 vector float vs[8] = {0};
1969 bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
1970 for (int l = 0; l < k; l++) {
1971 __builtin_mma_xxsetaccz(&acc_0);
1972 __builtin_mma_xxsetaccz(&acc_1);
1973 if (std::is_same_v<TA, block_q4_0>) {
1974 packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
1975 } else {
1976 packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
1977 }
1978 packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true);
1979 for(int x = 0; x < 8; x++) {
1980 __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
1981 __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
1982 }
1983 for (int I = 0; I<8; I++) {
1984 for (int J = 0; J<4; J++) {
1985 *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
1986 }
1987 }
1988 if (!isAblock_q4) {
1989 auto aoffset = A+(ii*lda)+l;
1990 for (int i = 0; i < 8; i++) {
1991 comparray[i] = 0;
1992 int ca = 0;
1993 auto *at = aoffset->qs;
1994 for (int j = 0; j < 32; j++)
1995 ca += (int)*at++;
1996 comparray[i] = ca;
1997 aoffset += lda;
1998 }
1999 }
2000 compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
2001 compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
2002 }
2003 save_res(ii, jj, 0, fin_res);
2004 save_res(ii+4, jj, 4, fin_res);
2005 }
2006
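    // KERNEL_8x8 computes a full 8x8 tile using four accumulators, one per
    // 4x4 quadrant of the tile.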
2007 void KERNEL_8x8(int64_t ii, int64_t jj) {
2008 vec_t vec_A[16], vec_B[16] = {0};
2009 acc_t acc_0, acc_1, acc_2, acc_3;
2010 std::array<int, 8> comparray {};
2011 vector float fin_res[16] = {0};
2012 vector float vs[16] = {0};
2013 bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
2014 for (int l = 0; l < k; l++) {
2015 __builtin_mma_xxsetaccz(&acc_0);
2016 __builtin_mma_xxsetaccz(&acc_1);
2017 __builtin_mma_xxsetaccz(&acc_2);
2018 __builtin_mma_xxsetaccz(&acc_3);
2019 if (std::is_same_v<TA, block_q4_0>) {
2020 packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
2021 } else {
2022 packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
2023 }
2024 packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
2025 for(int x = 0; x < 8; x++) {
2026 __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
2027 __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
2028 __builtin_mma_xvi8ger4pp(&acc_2, vec_A[x], vec_B[x+8]);
2029 __builtin_mma_xvi8ger4pp(&acc_3, vec_A[x+8], vec_B[x+8]);
2030 }
2031 for (int I = 0; I<8; I++) {
2032 for (int J = 0; J<4; J++) {
2033 *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
2034 *((float*)&vs[I+8]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
2035 }
2036 }
2037 if (!isAblock_q4) {
2038 auto aoffset = A+(ii*lda)+l;
2039 for (int i = 0; i < 8; i++) {
2040 comparray[i] = 0;
2041 int ca = 0;
2042 auto *at = aoffset->qs;
2043 for (int j = 0; j < 32; j++)
2044 ca += (int)*at++;
2045 comparray[i] = ca;
2046 aoffset += lda;
2047 }
2048 }
2049 compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
2050 compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
2051 compute<8>(&acc_2, 0, 8, comparray, vs, fin_res);
2052 compute<8>(&acc_3, 4, 12, comparray, vs, fin_res);
2053 }
2054 save_res(ii, jj, 0, fin_res);
2055 save_res(ii+4, jj, 4, fin_res);
2056 save_res(ii, jj+4, 8, fin_res);
2057 save_res(ii+4, jj+4, 12, fin_res);
2058 }
2059
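    // gemm_small handles tiles the specialized kernels cannot cover
    // (RM, RN <= 4). Tiles are distributed over threads by index and each
    // tile uses a single accumulator that is corrected and scaled in place.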
2060 void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
2061 int64_t ytiles = (m - m0) / RM;
2062 int64_t xtiles = (n - n0) / RN;
2063 int64_t tiles = xtiles * ytiles;
2064 int64_t duty = (tiles + nth - 1) / nth;
2065 int64_t start = duty * ith;
2066 int64_t end = start + duty;
2067 vec_t vec_A[8] = {0}, vec_B[8] = {0};
2068 vector signed int vec_C[4];
2069 acc_t acc_0;
2070 bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
2071
2072 if (end > tiles)
2073 end = tiles;
2074 for (int64_t job = start; job < end; ++job) {
2075 int64_t ii = m0 + job / xtiles * RM;
2076 int64_t jj = n0 + job % xtiles * RN;
2077 std::array<int, 4> comparray{};
2078 vector float res[4] = {0};
2079 vector float fin_res[4] = {0};
2080 vector float vs[4] = {0};
2081 vector float CA[4] = {0};
2082 __builtin_prefetch((A+(ii*lda)+0)->qs, 0, 1); // prefetch first value
2083 __builtin_prefetch((B+(jj*ldb)+0)->qs, 0, 1); // prefetch first value
2084 for (int l = 0; l < k; l++) {
2085 __builtin_prefetch((A+(ii*lda)+(l+1))->qs, 0, 1); // prefetch one loop ahead
2086 __builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead
2087 __builtin_mma_xxsetaccz(&acc_0);
2088 if (isAblock_q4) {
2089 packNormalInt4<4>((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray);
2090 } else {
2091 packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
2092 }
2093 packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true);
2094 for(int x = 0; x < 8; x+=4) {
2095 __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
2096 __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+1], vec_B[x+1]);
2097 __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+2], vec_B[x+2]);
2098 __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+3], vec_B[x+3]);
2099 }
2100 for (int I = 0; I<RM; I++) {
2101 for (int J = 0; J<RN; J++) {
2102 *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
2103 }
2104 }
2105 __builtin_mma_disassemble_acc(vec_C, &acc_0);
2106 if (!isAblock_q4) {
2107 auto aoffset = A+(ii*lda)+l;
2108 for (int i = 0; i < RM; i++) {
2109 comparray[i] = 0;
2110 int ca = 0;
2111 auto *at = aoffset->qs;
2112 for (int j = 0; j < 32; j++)
2113 ca += (int)*at++;
2114 comparray[i] = ca;
2115 aoffset += lda;
2116 }
2117 }
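                // apply the -128 * sum(A row) correction and scale by the
                // per-block d(A)*d(B) factors accumulated in vs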
2118 for (int i = 0; i < RM; i++) {
2119 CA[i] = vec_splats((float)(((double)comparray[i]) * -128.0));
2120 res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
2121 fin_res[i] = vec_madd(res[i], vs[i], fin_res[i]);
2122 }
2123 }
2124 save_res(ii, jj, 0, fin_res, RM, RN);
2125 }
2126 }
2127
2128 template<int RM, int RN>
2129 inline void kernel(int64_t ii, int64_t jj) {
2130 if constexpr(RM == 4 && RN == 8) {
2131 KERNEL_4x8(ii,jj);
2132 } else if constexpr(RM == 8 && RN == 4) {
2133 KERNEL_8x4(ii,jj);
2134 } else if constexpr(RM == 8 && RN == 8) {
2135 KERNEL_8x8(ii,jj);
2136 } else {
2137 assert(false && "RN/RM values not supported");
2138 }
2139 }
2140
2141 template <int RM, int RN>
2142 NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
2143 int64_t ytiles = (m - m0) / RM;
2144 int64_t xtiles = (n - n0) / RN;
2145 int64_t tiles = xtiles * ytiles;
2146 int64_t duty = (tiles + nth - 1) / nth;
2147 int64_t start = duty * ith;
2148 int64_t end = start + duty;
2149 if (end > tiles)
2150 end = tiles;
2151 for (int64_t job = start; job < end; ++job) {
2152 int64_t ii = m0 + job / xtiles * RM;
2153 int64_t jj = n0 + job % xtiles * RN;
2154 kernel<RM, RN>(ii, jj);
2155 }
2156 }
2157
2158 const TA *const A;
2159 const block_q8_0 *const B;
2160 float *C;
2161 const int64_t k;
2162 const int64_t lda;
2163 const int64_t ldb;
2164 const int64_t ldc;
2165 const int ith;
2166 const int nth;
2167};
2168
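// tinyBLAS_PPC implements single-precision GEMM using POWER MMA float outer
// products (xvf32gerpp). When m, n and k are all multiples of 256 it takes
// the cache-blocked matmul_tiled path; otherwise it falls back to the
// recursive mnpack tiling.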
2169class tinyBLAS_PPC {
2170 public:
2171 tinyBLAS_PPC(int64_t k,
2172 const float * A, int64_t lda,
2173 const float * B, int64_t ldb,
2174 float * C, int64_t ldc,
2175 int ith, int nth)
2176 : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
2177 }
2178
2179 void matmul(int64_t m, int64_t n) {
2180 int64_t mc = 256; int64_t nc = 256; int64_t kc = 256;
2181 if (m % mc == 0 && n % nc == 0 && k % kc == 0) {
2182 matmul_tiled(m, n, mc, nc, kc);
2183 } else {
2184 mnpack(0, m, 0, n);
2185 }
2186 }
2187
2188 private:
2189
2190 inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
2191 vec_t vec_C[4];
2192 __builtin_mma_disassemble_acc(vec_C, ACC);
2193 for (int I = 0; I < 4; I++) {
2194 for (int J = 0; J < 4; J++) {
2195 *((float *)(C+ii+((jj+J)*ldc)+I)) = *((float *)&vec_C[I]+J);
2196 }
2197 }
2198 }
2199
2200 inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
2201 vec_t vec_C[4];
2202 __builtin_mma_disassemble_acc(vec_C, ACC);
2203 for (int I = 0; I < 4; I++) {
2204 for (int J = 0; J < 4; J++) {
2205 float * c_ptr = (float *)(C+ii+((jj+J)*ldc)+I);
2206 *c_ptr += *((float *)&vec_C[I]+J);
2207 }
2208 }
2209 }
2210
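    // vector_permute_store_4 transposes a 4x4 block of floats held in four
    // vectors and stores it contiguously at vecOffset.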
2211 inline void vector_permute_store_4(vector float * src, float * vecOffset) {
2212 vector float t1, t2, t3, t4, t5, t6, t7, t8;
2213 t1 = vec_mergeh(src[0], src[1]);
2214 t2 = vec_mergeh(src[2], src[3]);
2215 t3 = vec_mergel(src[0], src[1]);
2216 t4 = vec_mergel(src[2], src[3]);
2217
2218 t5 = vec_xxpermdi(t1, t2, 0);
2219 t6 = vec_xxpermdi(t1, t2, 3);
2220 t7 = vec_xxpermdi(t3, t4, 0);
2221 t8 = vec_xxpermdi(t3, t4, 3);
2222
2223 vec_xst(t5, 0, vecOffset);
2224 vec_xst(t6, 0, vecOffset + 4);
2225 vec_xst(t7, 0, vecOffset + 8);
2226 vec_xst(t8, 0, vecOffset + 12);
2227 }
2228
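    // vector_permute_store_8 transposes an 8x4 block (eight 4-float vectors)
    // into four rows of eight floats stored contiguously at vecOffset.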
2229 inline void vector_permute_store_8(vector float * src, float * vecOffset) {
2230 vector float t1, t2, t3, t4, t5, t6, t7, t8;
2231 t1 = vec_mergeh(src[0], src[1]);
2232 t2 = vec_mergeh(src[2], src[3]);
2233 t3 = vec_mergeh(src[4], src[5]);
2234 t4 = vec_mergeh(src[6], src[7]);
2235
2236 t5 = vec_xxpermdi(t1, t2, 0);
2237 t6 = vec_xxpermdi(t3, t4, 0);
2238 t7 = vec_xxpermdi(t1, t2, 3);
2239 t8 = vec_xxpermdi(t3, t4, 3);
2240
2241 vec_xst(t5, 0, vecOffset);
2242 vec_xst(t6, 0, vecOffset + 4);
2243 vec_xst(t7, 0, vecOffset + 8);
2244 vec_xst(t8, 0, vecOffset + 12);
2245
2246 t1 = vec_mergel(src[0], src[1]);
2247 t2 = vec_mergel(src[2], src[3]);
2248 t3 = vec_mergel(src[4], src[5]);
2249 t4 = vec_mergel(src[6], src[7]);
2250
2251 t5 = vec_xxpermdi(t1, t2, 0);
2252 t6 = vec_xxpermdi(t3, t4, 0);
2253 t7 = vec_xxpermdi(t1, t2, 3);
2254 t8 = vec_xxpermdi(t3, t4, 3);
2255
2256 vec_xst(t5, 0, vecOffset + 16);
2257 vec_xst(t6, 0, vecOffset + 20);
2258 vec_xst(t7, 0, vecOffset + 24);
2259 vec_xst(t8, 0, vecOffset + 28);
2260 }
2261
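    // packTranspose copies a rows x cols float panel (row stride lda) into
    // vec in transposed, vector-friendly order for the fp32 MMA kernels,
    // working in groups of 8 (or 4) rows at a time.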
2262 void packTranspose(const float * a, int64_t lda, int rows, int cols, float * vec) {
2263 int64_t i, j;
2264 float * aoffsets[8];
2265 float * aoffset = NULL, * boffset = NULL;
2266 __vector_pair arr[8];
2267 vector float c[8][2] = {0};
2268 vector float c1[8] = {0};
2269 vector float c2[8] = {0};
2270 aoffset = const_cast<float *>(a);
2271 boffset = vec;
2272 j = (rows >> 3);
2273 if (j > 0) {
2274 do {
2275 aoffsets[0] = aoffset;
2276 for (int it = 1; it < 8; it++)
2277 aoffsets[it] = aoffsets[it-1] + lda;
2278 aoffset += 8 * lda;
2279 i = (cols >> 3);
2280 if (i > 0) {
2281 do {
2282 for (int it = 0; it < 8; it++) {
2283 arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]);
2284 __builtin_vsx_disassemble_pair(c[it], &arr[it]);
2285 c1[it] = c[it][0];
2286 c2[it] = c[it][1];
2287 }
2288
2289 vector_permute_store_8(c1, boffset);
2290 vector_permute_store_8(c2, boffset + 32);
2291 boffset += 64;
2292 i--;
2293 if (i > 0) {
2294 for (int it = 0; it < 8; it++) {
2295 aoffsets[it] = aoffsets[it] + 8;
2296 }
2297 }
2298 } while(i > 0);
2299 }
2300 if (cols & 4) {
2301 for (int it = 0; it < 8 ; it++)
2302 c1[it] = vec_xl(0, aoffsets[it]);
2303 vector_permute_store_8(c1, boffset);
2304 }
2305 j--;
2306 } while(j > 0);
2307 }
2308
2309 if (rows & 4) {
2310 aoffsets[0] = aoffset;
2311 for (int it = 1; it < 4; it++)
2312 aoffsets[it] = aoffsets[it-1] + lda;
2313 aoffset += 4 * lda;
2314 i = (cols >> 3);
2315 if (i > 0) {
2316 do {
2317 for (int it = 0; it < 4; it++) {
2318 arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]);
2319 __builtin_vsx_disassemble_pair(c[it], &arr[it]);
2320 c1[it] = c[it][0];
2321 c2[it] = c[it][1];
2322 }
2323 vector_permute_store_4(c1, boffset);
2324 vector_permute_store_4(c2, boffset + 16);
2325 for (int it = 0; it < 4; it++)
                        aoffsets[it] += 8; // advance to the next 8 columns of the same 4 rows
2327 boffset += 32;
2328 i--;
2329 } while(i > 0);
2330 }
2331
2332 if (cols & 4) {
2333 for (int it = 0; it < 4; it++)
2334 c1[it] = vec_xl(0, aoffsets[it]);
2335 vector_permute_store_4(c1, boffset);
2336 }
2337 }
2338 if (rows & 3) {
2339 aoffsets[0] = aoffset;
2340 for (int it = 1; it < 3; it++)
2341 aoffsets[it] = aoffsets[it-1] + lda;
2342 if (cols & 4) {
2343 for (int it = 0; it < 3; it++)
2344 c1[it] = vec_xl(0, aoffsets[it]);
2345 vector_permute_store_4(c1, boffset);
2346 }
2347 }
2348 }
2349
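    // fp32 kernels: each packs/transposes its A and B panels and accumulates
    // rank-1 float outer products (xvf32gerpp) over k before storing the
    // finished accumulators with save_acc.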
2350 void KERNEL_4x4(int64_t ii, int64_t jj) {
2351 vec_t vec_A[4], vec_B[4], vec_C[4];
2352 acc_t acc_0;
2353 __builtin_mma_xxsetaccz(&acc_0);
2354 for (int l = 0; l < k; l += 4) {
2355 packTranspose(A + (ii * lda) + l, lda, 4, 4, (float *)vec_A);
2356 packTranspose(B + (jj * ldb) + l, ldb, 4, 4, (float *)vec_B);
2357 __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
2358 __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
2359 __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
2360 __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
2361 }
2362 save_acc(&acc_0, ii, jj);
2363 }
2364
2365 void KERNEL_4x8(int64_t ii, int64_t jj) {
2366 vec_t vec_A[4], vec_B[8], vec_C[4];
2367 acc_t acc_0, acc_1;
2368 __builtin_mma_xxsetaccz(&acc_0);
2369 __builtin_mma_xxsetaccz(&acc_1);
2370 for (int64_t l = 0; l < k; l += 4) {
2371 packTranspose(A + (ii * lda) + l, lda, 4, 4, (float *)vec_A);
2372 packTranspose(B + (jj * ldb) + l, ldb, 8, 4, (float *)vec_B);
2373 __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]);
2374 __builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]);
2375 __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]);
2376 __builtin_mma_xvf32gerpp(&acc_1, vec_A[1], (vec_t)vec_B[3]);
2377 __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], (vec_t)vec_B[4]);
2378 __builtin_mma_xvf32gerpp(&acc_1, vec_A[2], (vec_t)vec_B[5]);
2379 __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], (vec_t)vec_B[6]);
2380 __builtin_mma_xvf32gerpp(&acc_1, vec_A[3], (vec_t)vec_B[7]);
2381 }
2382 save_acc(&acc_0, ii, jj);
2383 save_acc(&acc_1, ii, jj + 4);
2384 }
2385
2386 void KERNEL_8x4(int64_t ii, int64_t jj) {
2387 vec_t vec_A[8], vec_B[4], vec_C[4];
2388 acc_t acc_0, acc_1;
2389 __builtin_mma_xxsetaccz(&acc_0);
2390 __builtin_mma_xxsetaccz(&acc_1);
2391 for (int64_t l = 0; l < k; l += 4) {
2392 packTranspose(A + (ii * lda) + l, lda, 8, 4, (float *)vec_A);
2393 packTranspose(B + (jj * ldb) + l, ldb, 4, 4, (float *)vec_B);
2394 __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]);
2395 __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]);
2396 __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]);
2397 __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[3], vec_B[1]);
2398 __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[4], vec_B[2]);
2399 __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[5], vec_B[2]);
2400 __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[6], vec_B[3]);
2401 __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[7], vec_B[3]);
2402 }
2403 save_acc(&acc_0, ii, jj);
2404 save_acc(&acc_1, ii + 4, jj);
2405 }
2406
2407 void KERNEL_8x8(int64_t ii, int64_t jj) {
2408 vec_t vec_A[16], vec_B[16], vec_C[4];
2409 acc_t acc_0, acc_1, acc_2, acc_3;
2410 __builtin_mma_xxsetaccz(&acc_0);
2411 __builtin_mma_xxsetaccz(&acc_1);
2412 __builtin_mma_xxsetaccz(&acc_2);
2413 __builtin_mma_xxsetaccz(&acc_3);
2414 for (int l = 0; l < k; l+=8) {
2415 packTranspose(A + (ii * lda) + l, lda, 8, 8, (float *)vec_A);
2416 packTranspose(B + (jj * ldb) + l, ldb, 8, 8, (float *)vec_B);
2417 for(int x = 0; x < 16; x+=2) {
2418 __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]);
2419 __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x + 1]);
2420 __builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x + 1], vec_B[x]);
2421 __builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x + 1], vec_B[x + 1]);
2422 }
2423 }
2424 save_acc(&acc_0, ii, jj);
2425 save_acc(&acc_1, ii, jj + 4);
2426 save_acc(&acc_2, ii + 4, jj);
2427 save_acc(&acc_3, ii + 4, jj + 4);
2428 }
2429
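    // MMA_16x8 accumulates one 16x8 sub-tile: vec_A0/vec_A1 hold the packed
    // upper and lower 8 rows of A, vec_B the 8 packed columns of B, and the
    // eight accumulators map onto the 4x4 quadrants saved by KERNEL below.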
2430 inline void MMA_16x8(vec_t * vec_A0, vec_t * vec_A1, vec_t * vec_B, acc_t * acc) {
2431 for (int x = 0; x < 16; x += 2) {
2432 __builtin_mma_xvf32gerpp(&acc[0], vec_A0[x + 0], vec_B[x]);
2433 __builtin_mma_xvf32gerpp(&acc[1], vec_A0[x + 0], vec_B[x + 1]);
2434 __builtin_mma_xvf32gerpp(&acc[2], vec_A0[x + 1], vec_B[x]);
2435 __builtin_mma_xvf32gerpp(&acc[3], vec_A0[x + 1], vec_B[x + 1]);
2436 __builtin_mma_xvf32gerpp(&acc[4], vec_A1[x + 0], vec_B[x]);
2437 __builtin_mma_xvf32gerpp(&acc[5], vec_A1[x + 0], vec_B[x + 1]);
2438 __builtin_mma_xvf32gerpp(&acc[6], vec_A1[x + 1], vec_B[x]);
2439 __builtin_mma_xvf32gerpp(&acc[7], vec_A1[x + 1], vec_B[x + 1]);
2440 }
2441 }
2442
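    // KERNEL walks one mc x nc output tile in 16x8 sub-tiles using the packed
    // A and B buffers. The first k slice (kk == 0) overwrites C; later slices
    // accumulate on top via add_save_acc.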
2443 void KERNEL(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, vec_t * vec_A, vec_t * vec_B, int64_t kk) {
2444 for (int64_t i = 0; i < mc; i += 16) {
2445 int A_base_addr = (mc / 8) * (i / 8) * 16;
2446 for (int64_t j = 0; j < nc; j += 8) {
2447 int B_base_addr = (nc / 8) * (j / 8) * 16;
2448 acc_t acc[8];
2450 for (int x = 0; x < 8; x++)
2451 __builtin_mma_xxsetaccz(&acc[x]);
2452 for (int64_t l = 0; l < kc; l += 8) {
2453 int A0_block_idx = A_base_addr + (l / 8) * 16;
2454 int A1_block_idx = A0_block_idx + (mc / 8) * 16;
2455 int B_block_idx = B_base_addr + (l / 8) * 16;
2456 vec_t* A0_block = &vec_A[A0_block_idx];
2457 vec_t* A1_block = &vec_A[A1_block_idx];
2458 vec_t* B_block = &vec_B[B_block_idx];
2459 MMA_16x8(A0_block, A1_block, B_block, acc);
2460 }
2461 if (kk == 0) {
2462 save_acc(&acc[0], ii + i, jj + j);
2463 save_acc(&acc[1], ii + i, jj + j + 4);
2464 save_acc(&acc[2], ii + i + 4, jj + j);
2465 save_acc(&acc[3], ii + i + 4, jj + j + 4);
2466 save_acc(&acc[4], ii + i + 8, jj + j);
2467 save_acc(&acc[5], ii + i + 8, jj + j + 4);
2468 save_acc(&acc[6], ii + i + 12, jj + j);
2469 save_acc(&acc[7], ii + i + 12, jj + j + 4);
2470 } else {
2471 add_save_acc(&acc[0], ii + i, jj + j);
2472 add_save_acc(&acc[1], ii + i, jj + j + 4);
2473 add_save_acc(&acc[2], ii + i + 4, jj + j);
2474 add_save_acc(&acc[3], ii + i + 4, jj + j + 4);
2475 add_save_acc(&acc[4], ii + i + 8, jj + j);
2476 add_save_acc(&acc[5], ii + i + 8, jj + j + 4);
2477 add_save_acc(&acc[6], ii + i + 12, jj + j);
2478 add_save_acc(&acc[7], ii + i + 12, jj + j + 4);
2479 }
2480 }
2481 }
2482 }
2483
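    // matmul_tiled: cache-blocked path used when m, n and k are multiples of
    // mc, nc and kc (256). Each thread packs its own A and B tiles into
    // on-stack buffers (roughly 256 KiB each at these tile sizes) and
    // accumulates over k one kc slice at a time.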
2484 void matmul_tiled(int64_t m , int64_t n, int64_t mc, int64_t nc, int64_t kc) {
2485 int64_t ytiles = m / mc;
2486 int64_t xtiles = n / nc;
2487 int64_t tiles = xtiles * ytiles;
2488 int64_t duty = (tiles + nth - 1) / nth;
2489 int64_t start = duty * ith;
2490 int64_t end = start + duty;
2491 if (end > tiles) {
2492 end = tiles;
2493 }
2494 for (int64_t job = start; job < end; ++job) {
2495 int64_t ii = (job / xtiles) * mc;
2496 int64_t jj = (job % xtiles) * nc;
2497 for (int64_t kk = 0; kk < k; kk += kc) {
2498 vec_t A_pack[kc * mc / 4];
2499 vec_t B_pack[kc * nc / 4];
2500 packTranspose(A + (ii * lda) + kk, lda, kc, mc, (float *)A_pack);
2501 packTranspose(B + (jj * ldb) + kk, ldb, kc, nc, (float *)B_pack);
2502 KERNEL(ii, jj, mc, nc, kc, A_pack, B_pack, kk);
2503 }
2504 }
2505 }
2506
2507 void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
2508 int m_rem = MIN(m - m0, 8);
2509 int n_rem = MIN(n - n0, 8);
2510 int mc = 0, nc = 0;
2511 if (m_rem >= 8 && n_rem >= 8) {
2512 mc = 8;
2513 nc = 8;
2514 gemm<8, 8>(m0, m, n0, n);
2515 } else if (m_rem >= 4 && n_rem >= 8) {
2516 mc = 4;
2517 nc = 8;
2518 gemm<4, 8>(m0, m, n0, n);
2519 } else if (m_rem >= 8 && n_rem >= 4) {
2520 mc = 8;
2521 nc = 4;
2522 gemm<8, 4>(m0, m, n0, n);
2523 } else if (m_rem >= 4 && n_rem >= 4) {
2524 mc = 4;
2525 nc = 4;
2526 gemm<4, 4>(m0, m, n0, n);
2527 } else {
2528 mc = (m_rem >= 4) ? 4 : m_rem;
2529 nc = (n_rem >= 4) ? 4 : n_rem;
2530 if (mc == 0 || nc == 0)
2531 return;
2532 gemm_small(m0, m, n0, n, mc, nc);
2533 }
2534 int64_t mp = m0 + ((m - m0) / mc) * mc;
2535 int64_t np = n0 + ((n - n0) / nc) * nc;
2536 mnpack(mp, m, n0, np);
2537 mnpack(m0, m, np, n);
2538 }
2539
2540 void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
2541 int64_t ytiles = (m - m0) / RM;
2542 int64_t xtiles = (n - n0) / RN;
2543 int64_t tiles = xtiles * ytiles;
2544 int64_t duty = (tiles + nth - 1) / nth;
2545 int64_t start = duty * ith;
2546 int64_t end = start + duty;
2547 if (end > tiles)
2548 end = tiles;
2549 for (int64_t job = start; job < end; ++job) {
2550 int64_t ii = m0 + job / xtiles * RM;
2551 int64_t jj = n0 + job % xtiles * RN;
2552 vec_t vec_C[4];
2553 acc_t acc_0;
2554 __builtin_mma_xxsetaccz(&acc_0);
2555 vec_t vec_A[4] = {0}, vec_B[4] = {0};
2556 for (int l = 0; l < k; l += 4) {
                /* 'GEMV forwarding' is used in the first two branches below:
                 * when one of the matrices has a single row/column, its
                 * elements are broadcast directly instead of being run
                 * through the packing routine.
                 */
2562 if (RM == 1) {
2563 float * a = const_cast<float *>(A + (ii) * lda + l);
2564 packTranspose(B + (jj * ldb) + l, ldb, RN, 4, (float *)vec_B);
2565 vec_A[0] = (vec_t)vec_xl(0,a);
2566 vec_A[1] = (vec_t)vec_splats(*((float *)&vec_A+1));
2567 vec_A[2] = (vec_t)vec_splats(*((float *)&vec_A+2));
2568 vec_A[3] = (vec_t)vec_splats(*((float *)&vec_A+3));
2569 } else if (RN == 1) {
2570 packTranspose(A + (ii * lda) + l, lda, RM, 4, (float *)vec_A);
2571 float * b = const_cast<float *>(B + (jj) * ldb + l);
2572 vec_B[0] = (vec_t)vec_xl(0,b);
2573 vec_B[1] = (vec_t)vec_splats(*((float *)&vec_B+1));
2574 vec_B[2] = (vec_t)vec_splats(*((float *)&vec_B+2));
2575 vec_B[3] = (vec_t)vec_splats(*((float *)&vec_B+3));
2576 } else {
2577 packTranspose(A + (ii * lda) + l, lda, RM, 4, (float *)vec_A);
2578 packTranspose(B + (jj * ldb) + l, ldb, RN, 4, (float *)vec_B);
2579 }
2580 __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
2581 __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
2582 __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
2583 __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
2584 }
2585 __builtin_mma_disassemble_acc(vec_C, &acc_0);
2586 for (int I = 0; I < RM; I++) {
2587 for (int J = 0; J < RN; J++) {
2588 *((float *)(C+ii+((jj+J)*ldc)+I)) = *((float *)&vec_C[I]+J);
2589 }
2590 }
2591 }
2592 }
2593
2594 template<int RM, int RN>
2595 inline void kernel(int64_t ii, int64_t jj) {
2596 if constexpr(RM == 4 && RN == 4) {
2597 KERNEL_4x4(ii, jj);
2598 } else if constexpr(RM == 4 && RN == 8) {
2599 KERNEL_4x8(ii, jj);
2600 } else if constexpr(RM == 8 && RN == 4) {
2601 KERNEL_8x4(ii, jj);
2602 } else if constexpr(RM == 8 && RN == 8) {
2603 KERNEL_8x8(ii, jj);
2604 } else {
2605 static_assert(false, "RN/RM values not supported");
2606 }
2607 }
2608
2609 template <int RM, int RN>
2610 NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
2611 int64_t ytiles = (m - m0) / RM;
2612 int64_t xtiles = (n - n0) / RN;
2613 int64_t tiles = xtiles * ytiles;
2614 int64_t duty = (tiles + nth - 1) / nth;
2615 int64_t start = duty * ith;
2616 int64_t end = start + duty;
2617 if (end > tiles)
2618 end = tiles;
2619 for (int64_t job = start; job < end; ++job) {
2620 int64_t ii = m0 + job / xtiles * RM;
2621 int64_t jj = n0 + job % xtiles * RN;
2622 kernel<RM, RN>(ii, jj);
2623 }
2624 }
2625
2626 const float * const A;
2627 const float * const B;
2628 float * C;
2629 const int64_t k;
2630 const int64_t lda;
2631 const int64_t ldb;
2632 const int64_t ldc;
2633 const int ith;
2634 const int nth;
2635};
2636#endif
2637} // namespace
2638
2639/**
2640 * Performs optimized matrix multiplication on CPU.
2641 *
2642 * This subroutine may compute C = Aᵀ * B with column major ordering.
 * Despite its name, this isn't a generalized implementation. Work is
 * only performed when a handwritten kernel for the requested data
 * types and instruction set is available. Otherwise the caller should
 * fall back to a general matmul routine.
2646 *
 * For example, for single-threaded single-precision GEMM you can say
 *
 *     llamafile_sgemm(params, m, n, k, A, lda, B, ldb, C, ldc,
 *                     GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32);
 *
 * where params->ith is 0 and params->nth is 1.
2652 *
 * @param params is the ggml compute parameters, which supply the
 *        thread id `ith` (must be less than `nth`) and the thread
 *        count `nth` (must be greater than zero)
 * @param m is rows in `A` and `C`
 * @param n is cols in `B` and `C`
 * @param k is cols in `A` and rows in `B`
 * @param A is first input matrix (always transposed)
 * @param lda is row stride of `A`
 * @param B is second input matrix (never transposed)
 * @param ldb is row stride of `B`
 * @param C is input/output array of output matrices
 * @param ldc is row stride of `C`
 * @param Atype is GGML data type of `A`
 * @param Btype is GGML data type of `B`
 * @param Ctype is GGML data type of `C`
2667 * @return true if this function was able to service the matmul request
2668 */
2669bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
2670 const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
2671 int64_t ldc, int Atype, int Btype, int Ctype) {
2672
2673 assert(m >= 0);
2674 assert(n >= 0);
2675 assert(k >= 0);
2676 assert(lda >= k);
2677 assert(ldb >= k);
2678 assert(ldc >= m);
2679 assert(params->nth > 0);
2680 assert(params->ith < params->nth);
2681
2682 // only enable sgemm for prompt processing
2683#if !defined(__MMA__)
2684 if (n < 2)
2685 return false;
2686#endif
2687
2688 if (Ctype != GGML_TYPE_F32)
2689 return false;
2690
2691 switch (Atype) {
2692
2693 case GGML_TYPE_F32: {
2694 if (Btype != GGML_TYPE_F32)
2695 return false;
2696#if defined(__AVX512F__)
2697 tinyBLAS<16, __m512, __m512, float, float, float> tb{ params,
2698 k, (const float *)A, lda,
2699 (const float *)B, ldb,
2700 (float *)C, ldc};
2701 return tb.matmul(m, n);
2702#elif defined(__AVX__) || defined(__AVX2__)
2703 tinyBLAS<8, __m256, __m256, float, float, float> tb{ params,
2704 k, (const float *)A, lda,
2705 (const float *)B, ldb,
2706 (float *)C, ldc};
2707 return tb.matmul(m, n);
2708#elif defined(__ARM_NEON)
2709 if (n < 4)
2710 return false;
2711 tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{ params,
2712 k, (const float *)A, lda,
2713 (const float *)B, ldb,
2714 (float *)C, ldc};
2715 return tb.matmul(m, n);
2716#elif defined(__VXE__) || defined(__VXE2__)
2717 if (n < 4)
2718 return false;
2719 tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{ params,
2720 k, (const float *)A, lda,
2721 (const float *)B, ldb,
2722 (float *)C, ldc};
2723 return tb.matmul(m, n);
2724#elif defined(__MMA__)
2725 if (k % 8)
2726 return false;
2727 tinyBLAS_PPC tb{
2728 k, (const float *)A, lda,
2729 (const float *)B, ldb,
2730 (float *)C, ldc,
2731 params->ith, params->nth};
2732 tb.matmul(m, n);
2733 return true;
2734#else
2735 return false;
2736#endif
2737 }
2738
2739 case GGML_TYPE_BF16: {
2740#if defined(__AVX512BF16__)
2741 if (Btype == GGML_TYPE_BF16) {
2742 tinyBLAS<32, __m512, __m512bh, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k,
2743 (const ggml_bf16_t *)A, lda,
2744 (const ggml_bf16_t *)B, ldb,
2745 (float *)C, ldc};
2746 return tb.matmul(m, n);
2747 }
2748#elif defined(__AVX512F__)
2749 if (Btype == GGML_TYPE_BF16) {
2750 tinyBLAS<16, __m512, __m512, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k,
2751 (const ggml_bf16_t *)A, lda,
2752 (const ggml_bf16_t *)B, ldb,
2753 (float *)C, ldc};
2754 return tb.matmul(m, n);
2755 }
2756#elif defined(__AVX2__)
2757 if (Btype == GGML_TYPE_BF16) {
2758 tinyBLAS<8, __m256, __m256, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k,
2759 (const ggml_bf16_t *)A, lda,
2760 (const ggml_bf16_t *)B, ldb,
2761 (float *)C, ldc};
2762 return tb.matmul(m, n);
2763 }
2764#elif defined(__MMA__)
2765 if ((k % 8))
2766 return false;
2767 if(Btype == GGML_TYPE_BF16) {
2768 tinyBLAS_BF16_PPC<ggml_bf16_t, ggml_bf16_t, float> tb{ k,
2769 (const ggml_bf16_t *)A, lda,
2770 (const ggml_bf16_t *)B, ldb,
2771 (float *)C, ldc,
2772 params->ith, params->nth};
2773 tb.matmul(m, n);
2774 return true;
2775 }
2776#endif
2777 return false;
2778 }
2779
2780 case GGML_TYPE_F16: {
2781#if defined(__AVX512F__)
2782 if (Btype == GGML_TYPE_F16) {
2783 tinyBLAS<16, __m512, __m512, ggml_fp16_t, ggml_fp16_t, float> tb{ params, k,
2784 (const ggml_fp16_t *)A, lda,
2785 (const ggml_fp16_t *)B, ldb,
2786 (float *)C, ldc};
2787 return tb.matmul(m, n);
2788 }
2789#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
2790 if (Btype == GGML_TYPE_F16) {
2791 tinyBLAS<8, __m256, __m256, ggml_fp16_t, ggml_fp16_t, float> tb{ params, k,
2792 (const ggml_fp16_t *)A, lda,
2793 (const ggml_fp16_t *)B, ldb,
2794 (float *)C, ldc};
2795 return tb.matmul(m, n);
2796 }
2797#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
2798 if (n < 8)
2799 return false;
2800 if (Btype == GGML_TYPE_F16) {
2801 tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
2802 k, (const ggml_fp16_t *)A, lda,
2803 (const ggml_fp16_t *)B, ldb,
2804 (float *)C, ldc};
2805 return tb.matmul(m, n);
2806 }
2807#elif defined(__ARM_NEON) && !defined(_MSC_VER)
2808 if (Btype == GGML_TYPE_F32) {
2809 tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{ params,
2810 k, (const ggml_fp16_t *)A, lda,
2811 (const float *)B, ldb,
2812 (float *)C, ldc};
2813 return tb.matmul(m, n);
2814 }
2815#elif defined(__VXE__) || defined(__VXE2__)
2816 if (n < 4)
2817 return false;
2818 if (Btype == GGML_TYPE_F16) {
2819 tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
2820 k, (const ggml_fp16_t *)A, lda,
2821 (const ggml_fp16_t *)B, ldb,
2822 (float *)C, ldc};
2823 return tb.matmul(m, n);
2824 }
2825#endif
2826 return false;
2827 }
2828
2829 case GGML_TYPE_Q8_0: {
2830 if (Btype != GGML_TYPE_Q8_0)
2831 return false;
2832#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
2833 tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float> tb{
2834 k, (const block_q8_0 *)A, lda,
2835 (const block_q8_0 *)B, ldb,
2836 (float *)C, ldc,
2837 params->ith, params->nth};
2838 tb.matmul(m, n);
2839 return true;
2840#elif defined(__ARM_FEATURE_DOTPROD)
2841 tinyBLAS_Q0_ARM<block_q8_0> tb{
2842 k, (const block_q8_0 *)A, lda,
2843 (const block_q8_0 *)B, ldb,
2844 (float *)C, ldc,
2845 params->ith, params->nth};
2846 tb.matmul(m, n);
2847 return true;
2848#elif defined(__MMA__)
        // TODO: remove this condition once GEMV forwarding is enabled.
2850 if (n < 8 && n != 4)
2851 return false;
2852 if (m < 8 && m != 4)
2853 return false;
2854 tinyBLAS_Q0_PPC<block_q8_0> tb{
2855 k, (const block_q8_0 *)A, lda,
2856 (const block_q8_0 *)B, ldb,
2857 (float *)C, ldc,
2858 params->ith, params->nth};
2859 tb.matmul(m, n);
2860 return true;
2861#else
2862 return false;
2863#endif
2864 }
2865
2866 case GGML_TYPE_Q4_0: {
2867 if (Btype != GGML_TYPE_Q8_0)
2868 return false;
2869#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
2870 tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float> tb{
2871 k, (const block_q4_0 *)A, lda,
2872 (const block_q8_0 *)B, ldb,
2873 (float *)C, ldc,
2874 params->ith, params->nth};
2875 tb.matmul(m, n);
2876 return true;
2877#elif defined(__ARM_FEATURE_DOTPROD)
2878 tinyBLAS_Q0_ARM<block_q4_0> tb{
2879 k, (const block_q4_0 *)A, lda,
2880 (const block_q8_0 *)B, ldb,
2881 (float *)C, ldc,
2882 params->ith, params->nth};
2883 tb.matmul(m, n);
2884 return true;
2885#elif defined(__MMA__)
        // TODO: remove this condition once GEMV forwarding is enabled.
2887 if (n < 8 && n != 4)
2888 return false;
2889 if (m < 8 && m != 4)
2890 return false;
2891 tinyBLAS_Q0_PPC<block_q4_0> tb{
2892 k, (const block_q4_0 *)A, lda,
2893 (const block_q8_0 *)B, ldb,
2894 (float *)C, ldc,
2895 params->ith, params->nth};
2896 tb.matmul(m, n);
2897 return true;
2898#else
2899 return false;
2900#endif
2901 }
2902
2903 case GGML_TYPE_Q5_0: {
2904 if (Btype != GGML_TYPE_Q8_0)
2905 return false;
2906#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
2907 tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float> tb{
2908 k, (const block_q5_0 *)A, lda,
2909 (const block_q8_0 *)B, ldb,
2910 (float *)C, ldc,
2911 params->ith, params->nth};
2912 tb.matmul(m, n);
2913 return true;
2914#else
2915 return false;
2916#endif
2917 }
2918
2919 case GGML_TYPE_IQ4_NL: {
2920 if (Btype != GGML_TYPE_Q8_0)
2921 return false;
2922#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
2923 tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float> tb{
2924 k, (const block_iq4_nl *)A, lda,
2925 (const block_q8_0 *)B, ldb,
2926 (float *)C, ldc,
2927 params->ith, params->nth};
2928 tb.matmul(m, n);
2929 return true;
2930#else
2931 return false;
2932#endif
2933 }
2934
2935 default:
2936 return false;
2937 }
2938
2939 (void)params;
2940 (void)m;
2941 (void)n;
2942 (void)k;
2943 (void)A;
2944 (void)lda;
2945 (void)B;
2946 (void)ldb;
2947 (void)C;
2948 (void)ldc;
2949 (void)Atype;
2950 (void)Btype;
2951 (void)Ctype;
2952}
2953