1#pragma once
2
3#include "common.cuh"
4#include "vecdotq.cuh"
5#include "mma.cuh"
6
7#include <climits>
8#include <cstdint>
9
10using namespace ggml_cuda_mma;
11
12#define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
13#define MMQ_ITER_K 256
14#define MMQ_NWARPS 8
15
16typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride);
17typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00);
18typedef void (*mmq_write_back_t)(const float * __restrict__ sum, const int32_t * __restrict__ get_rows_to_sorted,
19 float * __restrict__ dst, const int stride, const int i_max, const int j_max);
20
21enum mmq_q8_1_ds_layout {
22 MMQ_Q8_1_DS_LAYOUT_D4,
23 MMQ_Q8_1_DS_LAYOUT_DS4,
24 MMQ_Q8_1_DS_LAYOUT_D2S6,
25};
26
27struct block_q8_1_mmq {
28 // The y float data is converted to a data layout that can simply be copied to shared memory as a contiguous block.
29 // The y float data is first grouped as blocks of 128 values.
30 // These blocks are then treated as individual data values and transposed.
31 //
32 // To avoid shared memory bank conflicts each block is padded with 16 bytes.
33 // This padding is also used to store block scales/partial sums.
    // Multiplying the quantized data with the scales yields the original, unquantized values.
    // The partial sums are obtained by summing up subgroups of the contained values (prior to quantization)
    // and are only needed for performance reasons.
37 //
38 // The exact data stored depends on the x data type.
39 union {
40 float d4[4]; // 1 32 bit scale per 32 values, stored as d0,d1,d2,d3
41 half2 ds4[4]; // 1 16 bit scale + 1 16 bit partial sum per 32 values, stored as d0,s0,d1,s1,d2,s2,d3,s3
        half  d2s6[8]; // 1 16 bit scale per 64 values + 1 16 bit partial sum per 16 values for the first 96 values,
                       //     stored as d0,d1,s1,s2,s3,s4,s5,s6
44 };
45 int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each
46};
47static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
48static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size");
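
// Size sketch, assuming the usual ggml constants (QK8_1 == 32, 16 bit half scales):
// 4*QK8_1 == 128 bytes of int8 quants plus 4*sizeof(half2) == 16 bytes of scales/partial sums,
// i.e. 144 bytes per block_q8_1_mmq == 4 blocks of block_q8_1 (36 bytes each).
static_assert(sizeof(block_q8_1_mmq) == 144, "Unexpected block_q8_1_mmq size for QK8_1 == 32");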
49
50static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
51 switch (type_x) {
52 case GGML_TYPE_Q4_0:
53 case GGML_TYPE_Q4_1:
54 return MMQ_Q8_1_DS_LAYOUT_DS4;
55 case GGML_TYPE_Q5_0:
56 return MMQ_Q8_1_DS_LAYOUT_D4;
57 case GGML_TYPE_Q5_1:
58 return MMQ_Q8_1_DS_LAYOUT_DS4;
59 case GGML_TYPE_Q8_0:
60 return MMQ_Q8_1_DS_LAYOUT_D4;
61 case GGML_TYPE_MXFP4:
62 return MMQ_Q8_1_DS_LAYOUT_D4;
63 case GGML_TYPE_Q2_K:
64 return MMQ_Q8_1_DS_LAYOUT_D2S6;
65 case GGML_TYPE_Q3_K:
66 return MMQ_Q8_1_DS_LAYOUT_D4;
67 case GGML_TYPE_Q4_K:
68 case GGML_TYPE_Q5_K:
69 return MMQ_Q8_1_DS_LAYOUT_DS4;
70 case GGML_TYPE_Q6_K:
71 case GGML_TYPE_IQ2_XXS:
72 case GGML_TYPE_IQ2_XS:
73 case GGML_TYPE_IQ2_S:
74 case GGML_TYPE_IQ3_XXS:
75 case GGML_TYPE_IQ3_S:
76 return MMQ_Q8_1_DS_LAYOUT_D4;
77 case GGML_TYPE_IQ1_S:
78 return MMQ_Q8_1_DS_LAYOUT_DS4;
79 case GGML_TYPE_IQ4_XS:
80 case GGML_TYPE_IQ4_NL:
81 return MMQ_Q8_1_DS_LAYOUT_D4;
82 default:
83 GGML_ABORT("fatal error");
84 break;
85 }
86}
87
88struct tile_x_sizes {
89 int qs;
90 int dm;
91 int sc;
92};
93
94static int get_mmq_x_max_host(const int cc) {
95 return (amd_mfma_available(cc) || turing_mma_available(cc)) ? 128 :
        GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ?
97#ifdef GGML_CUDA_FORCE_MMQ
98 128 : 64;
99#else
100 MMQ_DP4A_MAX_BATCH_SIZE : 64;
101#endif // GGML_CUDA_FORCE_MMQ
102}
103
104static constexpr __device__ int get_mmq_x_max_device() {
105#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
106 return 128;
107#else // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
108
109#if defined(GGML_USE_HIP)
110 return 64;
111#else // defined(GGML_USE_HIP)
112
113#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
114#ifdef GGML_CUDA_FORCE_MMQ
115 return 128;
116#else // GGML_CUDA_FORCE_MMQ
117 return MMQ_DP4A_MAX_BATCH_SIZE;
118#endif // GGML_CUDA_FORCE_MMQ
119#else // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
120 return 64;
121#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
122
123#endif // defined(GGML_USE_HIP)
124#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
125}
126
127static int get_mmq_y_host(const int cc) {
128 return GGML_CUDA_CC_IS_AMD(cc) ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) :
        ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ? 128 : 64);
130}
131
132static constexpr __device__ int get_mmq_y_device() {
133#if defined(GGML_USE_HIP)
134#if defined(RDNA1)
135 return 64;
136#else
137 return 128;
138#endif // defined RDNA1
139#else
140#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
141 return 128;
142#else
143 return 64;
144#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
145#endif // defined(GGML_USE_HIP)
146}
147
// Decouple shared memory tile sizes from WARP_SIZE to allow for different warp sizes.
// The K dimension of the tiles has either 1*MMQ_TILE_NE_K==32 (always for TILE_Y_K)
// or 2*MMQ_TILE_NE_K==64 (typically for TILE_X_K) 32 bit elements of quantized data (scales not included).
// In other words, the size of the quantized data in the K dimension is a multiple of MMQ_TILE_NE_K.
// The final tile size in the K direction is padded to avoid shared memory bank conflicts;
// in terms of 32 bit elements that means K % 2 == 1 for dp4a or K % 8 == 4 for mma.
155#define MMQ_TILE_NE_K 32
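
// Worked example of the padding rule, using the definitions below and assuming QI8_0 == 8:
//   dp4a: the Q4_0 x tile uses a row stride of MMQ_TILE_NE_K + 1 == 33 ints (K % 2 == 1),
//         wider tiles use 2*MMQ_TILE_NE_K + 1 == 65 ints;
//   mma:  MMQ_MMA_TILE_X_K_Q8_0 == 2*32 + 2*32/8 + 4 == 76 ints, and 76 % 8 == 4.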
156
157#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_0 + mmq_y/QI4_0, 0}
158#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_1 + mmq_y/QI4_1, 0}
159#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*2/QI8_0 + mmq_y/(QI8_0/2), 0}
160#define MMQ_DP4A_TXS_Q8_0_16 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*4/QI8_0 + mmq_y/(QI8_0/4), 0}
161#define MMQ_DP4A_TXS_Q8_1 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*2/QI8_1 + mmq_y/(QI8_1/2), 0}
162#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K + mmq_y, 0}
163#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}
164#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}
165#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K/QI5_K + mmq_y/QI5_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}
166#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K/QI6_K + mmq_y/QI6_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}
167
168static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
169 switch (type) {
170 case GGML_TYPE_Q4_0: return MMQ_DP4A_TXS_Q4_0;
171 case GGML_TYPE_Q4_1: return MMQ_DP4A_TXS_Q4_1;
172 case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0;
173 case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1;
174 case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0;
175 case GGML_TYPE_MXFP4: return MMQ_DP4A_TXS_Q8_1;
176 case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K;
177 case GGML_TYPE_Q3_K: return MMQ_DP4A_TXS_Q3_K;
178 case GGML_TYPE_Q4_K: return MMQ_DP4A_TXS_Q4_K;
179 case GGML_TYPE_Q5_K: return MMQ_DP4A_TXS_Q5_K;
180 case GGML_TYPE_Q6_K: return MMQ_DP4A_TXS_Q6_K;
181 case GGML_TYPE_IQ2_XXS: return MMQ_DP4A_TXS_Q8_0;
182 case GGML_TYPE_IQ2_XS: return MMQ_DP4A_TXS_Q8_0_16;
183 case GGML_TYPE_IQ2_S: return MMQ_DP4A_TXS_Q8_0_16;
184 case GGML_TYPE_IQ3_XXS: return MMQ_DP4A_TXS_Q8_0;
185 case GGML_TYPE_IQ3_S: return MMQ_DP4A_TXS_Q8_0;
186 case GGML_TYPE_IQ1_S: return MMQ_DP4A_TXS_Q8_0;
187 case GGML_TYPE_IQ4_XS: return MMQ_DP4A_TXS_Q8_0;
188 case GGML_TYPE_IQ4_NL: return MMQ_DP4A_TXS_Q8_0;
        default:                return tile_x_sizes{0, 0, 0};
190 }
191}
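
// Rough shared memory footprint sketch for the dp4a path, assuming mmq_y == 64 and QI4_0 == 4:
//   MMQ_DP4A_TXS_Q4_0 -> qs = 64*32 + 64 = 2112 ints, dm = 64*32/4 + 64/4 = 528 floats, sc = 0,
//   i.e. (2112 + 528)*4 bytes = 10560 bytes (~10.3 KiB) for the x tile alone.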
192
193#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
194#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
195#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
196#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
197#define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7)
198
199static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding.");
200static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
201static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding.");
202static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
203static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
204
205static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
206 switch (type) {
207 case GGML_TYPE_Q4_0: return MMQ_MMA_TILE_X_K_Q8_0;
208 case GGML_TYPE_Q4_1: return MMQ_MMA_TILE_X_K_Q8_1;
209 case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
210 case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1;
211 case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
212 case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
213 case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
214 case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
215 case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1;
216 case GGML_TYPE_Q5_K: return MMQ_MMA_TILE_X_K_Q8_1;
217 case GGML_TYPE_Q6_K: return MMQ_MMA_TILE_X_K_Q6_K;
218 case GGML_TYPE_IQ2_XXS: return MMQ_MMA_TILE_X_K_Q8_0;
219 case GGML_TYPE_IQ2_XS: return MMQ_MMA_TILE_X_K_Q3_K;
220 case GGML_TYPE_IQ2_S: return MMQ_MMA_TILE_X_K_Q3_K;
221 case GGML_TYPE_IQ3_XXS: return MMQ_MMA_TILE_X_K_Q8_0;
222 case GGML_TYPE_IQ3_S: return MMQ_MMA_TILE_X_K_Q8_0;
223 case GGML_TYPE_IQ1_S: return MMQ_MMA_TILE_X_K_Q8_0;
224 case GGML_TYPE_IQ4_XS: return MMQ_MMA_TILE_X_K_Q8_0;
225 case GGML_TYPE_IQ4_NL: return MMQ_MMA_TILE_X_K_Q8_0;
226 default: return 0;
227 }
228}
229
// block_q8_1_mmq has 128 8-bit quants (== 32 32-bit ints) + 4 32-bit values for scales/partial sums:
231#define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI8_1)
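
// With QI8_1 == 8 this is 32 + 4 == 36 ints per tile row, i.e. exactly one block_q8_1_mmq (144 bytes),
// which is why the y data can be copied to shared memory without repacking:
static_assert(MMQ_TILE_Y_K*sizeof(int) == sizeof(block_q8_1_mmq), "MMQ_TILE_Y_K does not match block_q8_1_mmq");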
232
233static int mmq_get_granularity_host(const int mmq_x, const int cc) {
234 if (amd_mfma_available(cc)) {
235 return mmq_x >= 128 ? 32 : 16;
236 } else if (turing_mma_available(cc) && mmq_x >= 48) {
237 return 16;
238 } else {
239 return 8;
240 }
241}
242
243#if defined(AMD_MFMA_AVAILABLE)
244static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
245 return mmq_x >= 128 ? 32 : 16;
246}
247#elif defined(TURING_MMA_AVAILABLE)
248static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
249 return mmq_x >= 48 ? 16 : 8;
250}
251#else
252static constexpr __device__ int mmq_get_granularity_device(const int /*mmq_x*/) {
253 return 8;
254}
255#endif // AMD_MFMA_AVAILABLE
256
257#if defined(GGML_USE_HIP)
258static int mmq_get_nwarps_host(const int cc, const int warp_size) {
259 return amd_mfma_available(cc) ? 8 : 256/warp_size;
260}
261#else
262static int mmq_get_nwarps_host(const int /*cc*/, const int warp_size) {
263 return 256/warp_size;
264}
265#endif // (GGML_USE_HIP)
266
267static constexpr __device__ int mmq_get_nwarps_device() {
268#if defined(AMD_MFMA_AVAILABLE)
269 return 8;
270#else
271 return 256/ggml_cuda_get_physical_warp_size();
272#endif // AMD_MFMA_AVAILABLE
273}
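
// Example values (sketch): 256/32 == 8 warps on NVIDIA and 256/64 == 4 waves on AMD without MFMA,
// i.e. 256 threads per block; the MFMA path instead always uses 8 waves.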
274
275// ------------------------------------------------------------
276
277template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
278 const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
279 constexpr int nwarps = mmq_get_nwarps_device();
280 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
281
282#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
283 int * x_qs = (int *) x_tile;
284 float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K);
285#else
286 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
287 int * x_qs = (int *) x_tile;
288 float * x_df = (float *) (x_qs + txs.qs);
289#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
290
291 constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_0);
292 constexpr int nrows = warp_size / threads_per_row;
293 const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
294 const int kbx = txi / QI4_0;
295 const int kqsx = txi % QI4_0;
296
297#pragma unroll
298 for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
299 int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
300
301 if (need_check) {
            i = min(i, i_max);
303 }
304
305 const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
306 const int qs0 = get_int_b2(bxi->qs, kqsx);
307
308#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
309 x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0] = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808);
310 x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808);
311#else
312 x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
313#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
314 }
315
316 constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_0;
317 constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
318 const int kbxd = threadIdx.x % blocks_per_tile_x_row;
319
320#pragma unroll
321 for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
322 int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
323
324 if (need_check) {
            i = min(i, i_max);
326 }
327
328 const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;
329
330#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
331 x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
332#else
333 x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
334#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
335 }
336}
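
// Unpacking sketch for the mma branch above: each int qs0 packs 8 nibbles. For example, for
// qs0 == 0x21436587 the low nibbles are 0x01030507 (byte values 7, 5, 3, 1 from lowest to highest);
// __vsubss4(..., 0x08080808) shifts them from [0, 15] to [-8, 7], giving 0xF9FBFDFF (-1, -3, -5, -7).
// The high nibbles are handled the same way and are stored QI4_0 ints further along the K direction.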
337
338template <int mmq_x, int mmq_y>
339static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
340 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
341 constexpr int nwarps = mmq_get_nwarps_device();
342 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
343
344 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
345 const int * x_qs = (const int *) x;
346 const float * x_df = (const float *) x_qs + txs.qs;
347 const int * y_qs = (const int *) y + 4;
348 const half2 * y_ds = (const half2 *) y;
349
350// #pragma unroll
351 for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_0*VDR_Q4_0_Q8_1_MMQ) {
352 const int k0 = k00 + k01;
353
354#pragma unroll
355 for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
356 const int j = j0 + threadIdx.y;
357
358#pragma unroll
359 for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
360 const int i = i0 + threadIdx.x;
361
362 const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
363
364 int u[2*VDR_Q4_0_Q8_1_MMQ];
365
366#pragma unroll
367 for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
368 u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l];
369 u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_0)];
370 }
371
372 sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
373 (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_0], u,
374 x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
375 }
376 }
377 }
378}
379
380template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
381 const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
382 constexpr int nwarps = mmq_get_nwarps_device();
383 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
384
385#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
386 int * x_qs = (int *) x_tile;
387 half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
388#else
389 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
390 int * x_qs = (int *) x_tile;
391 half2 * x_dm = (half2 *) (x_qs + txs.qs);
392#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
393
394 constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_1);
395 constexpr int nrows = warp_size / threads_per_row;
396 const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
397 const int kbx = txi / QI4_1;
398 const int kqsx = txi % QI4_1;
399
400#pragma unroll
401 for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
402 int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
403
404 if (need_check) {
            i = min(i, i_max);
406 }
407
408 const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
409 const int qs0 = get_int_b4(bxi->qs, kqsx);
410
411#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
412 x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0] = (qs0 >> 0) & 0x0F0F0F0F;
413 x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F;
414#else
415 x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
416#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
417 }
418
419 constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_1;
420 constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
421 const int kbxd = threadIdx.x % blocks_per_tile_x_row;
422
423#pragma unroll
424 for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
425 int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
426
427 if (need_check) {
            i = min(i, i_max);
429 }
430
431 const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;
432
433#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
434 x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
435#else
436 x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
437#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
438 }
439}
440
441template <int mmq_x, int mmq_y>
442static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
443 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
444 constexpr int nwarps = mmq_get_nwarps_device();
445 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
446
447 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
448 const int * x_qs = (const int *) x;
449 const half2 * x_dm = (const half2 *) x_qs + txs.qs;
450 const int * y_qs = (const int *) y + 4;
451 const half2 * y_ds = (const half2 *) y;
452
453// #pragma unroll
454 for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_1*VDR_Q4_1_Q8_1_MMQ) {
455 const int k0 = k00 + k01;
456
457#pragma unroll
458 for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
459 const int j = j0 + threadIdx.y;
460
461#pragma unroll
462 for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
463 const int i = i0 + threadIdx.x;
464
465 const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
466
467 int u[2*VDR_Q4_1_Q8_1_MMQ];
468
469#pragma unroll
470 for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
471 u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l];
472 u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_1)];
473 }
474
475 sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
476 (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_1], u,
477 x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
478 }
479 }
480 }
481}
482
483template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
484 const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
485 constexpr int nwarps = mmq_get_nwarps_device();
486 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
487
488#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
489 int * x_qs = (int *) x_tile;
490 float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
491#else
492 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
493 int * x_qs = (int *) x_tile;
494 float * x_df = (float *) (x_qs + txs.qs);
495#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
496
497 constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_0);
498 constexpr int nrows = warp_size / threads_per_row;
499 const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
500 const int kbx = txi / QI5_0;
501 const int kqsx = txi % QI5_0;
502
503#pragma unroll
504 for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
505 int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
506
507 if (need_check) {
            i = min(i, i_max);
509 }
510
511 const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbx;
512
513 const int ql = get_int_b2(bxi->qs, kqsx);
514 const int qh = get_int_b2(bxi->qh, 0) >> (4 * kqsx);
515
516 int qs0 = (ql >> 0) & 0x0F0F0F0F;
517 qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
518 qs0 |= (qh << 11) & 0x00001000; // 1 -> 12
519 qs0 |= (qh << 18) & 0x00100000; // 2 -> 20
520 qs0 |= (qh << 25) & 0x10000000; // 3 -> 28
        qs0 = __vsubss4(qs0, 0x10101010); // subtract 16
522
523 int qs1 = (ql >> 4) & 0x0F0F0F0F;
524 qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
525 qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12
526 qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
527 qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
        qs1 = __vsubss4(qs1, 0x10101010); // subtract 16
529
530#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
531 x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0;
532 x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
533#else
534 x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0;
535 x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
536#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
537 }
538
539 constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_0;
540 constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
541 const int kbxd = threadIdx.x % blocks_per_tile_x_row;
542
543#pragma unroll
544 for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
545 int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
546
547 if (need_check) {
            i = min(i, i_max);
549 }
550
551 const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;
552
553#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
554 x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
555#else
556 x_df[i*(MMQ_TILE_NE_K/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
557#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
558 }
559}
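
// Bit reassembly sketch for the loop above: for kqsx == 0, qh bits 0..3 hold the 5th bits of values
// 0..3 and qh bits 16..19 those of values 16..19. The shifts place each high bit at bit 4 of the
// corresponding byte (e.g. "0 -> 4": qh bit 0 to bit 4 of byte 0), so every byte becomes a full
// 5 bit value in [0, 31]; __vsubss4 with 0x10101010 then recenters it to [-16, 15].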
560
561template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
562 const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
563 constexpr int nwarps = mmq_get_nwarps_device();
564 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
565
566#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
567 int * x_qs = (int *) x_tile;
568 half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
569#else
570 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
571 int * x_qs = (int *) x_tile;
572 half2 * x_dm = (half2 *) (x_qs + txs.qs);
573#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
574
575 constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_1);
576 constexpr int nrows = warp_size / threads_per_row;
577 const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
578 const int kbx = txi / QI5_1;
579 const int kqsx = txi % QI5_1;
580
581#pragma unroll
582 for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
583 int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
584
585 if (need_check) {
            i = min(i, i_max);
587 }
588
589 const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbx;
590
591 const int ql = get_int_b4(bxi->qs, kqsx);
592 const int qh = get_int_b4(bxi->qh, 0) >> (4 * kqsx);
593
594 int qs0 = (ql >> 0) & 0x0F0F0F0F;
595 qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
596 qs0 |= (qh << 11) & 0x00001000; // 1 -> 12
597 qs0 |= (qh << 18) & 0x00100000; // 2 -> 20
598 qs0 |= (qh << 25) & 0x10000000; // 3 -> 28
599
600 int qs1 = (ql >> 4) & 0x0F0F0F0F;
601 qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
602 qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12
603 qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
604 qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
605
606#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
607 x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0;
608 x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
609#else
610 x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0;
611 x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
612#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
613 }
614
615 constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_1;
616 constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
617 const int kbxd = threadIdx.x % blocks_per_tile_x_row;
618
619#pragma unroll
620 for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
621 int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
622
623 if (need_check) {
            i = min(i, i_max);
625 }
626
627 const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;
628
629#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
630 x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
631#else
632 x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
633#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
634 }
635}
636
637template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
638 const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
639 constexpr int nwarps = mmq_get_nwarps_device();
640 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
641
642#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
643 int * x_qs = (int *) x_tile;
644 float * x_df = (float *) (x_tile + 2*MMQ_TILE_NE_K);
645#else
646 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
647 int * x_qs = (int *) x_tile;
648 float * x_df = (float *) (x_qs + txs.qs);
649#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
650
    // MMQ_ITER_K / (4 * QR8_0) == 64 threads per row would be required, but NVIDIA warps have only 32 threads,
    // so each thread loads 2 ints instead (one from each half of the 2*MMQ_TILE_NE_K wide row).
652 constexpr int threads_per_row = 32;
653 constexpr int nrows = warp_size / threads_per_row;
654 const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
655 const int kbx = txi / QI8_0;
656 const int kqsx = txi % QI8_0;
657
658#pragma unroll
659 for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
660 int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
661
662 if (need_check) {
            i = min(i, i_max);
664 }
665
666 const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
667
668#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
669 x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + txi] = get_int_b2(bxi[0].qs, kqsx);
670 x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);
671#else
672 x_qs[i*(2*MMQ_TILE_NE_K + 1) + 0 + txi] = get_int_b2(bxi[0].qs, kqsx);
673 x_qs[i*(2*MMQ_TILE_NE_K + 1) + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);
674#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
675 }
676
677 constexpr int blocks_per_tile_x_row = 2*MMQ_TILE_NE_K / QI8_0;
678 constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
679 const int kbxd = threadIdx.x % blocks_per_tile_x_row;
680
681#pragma unroll
682 for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
683 int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
684
685 if (need_check) {
            i = min(i, i_max);
687 }
688
689 const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
690
691#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
692 x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
693#else
694 x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d;
695#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
696 }
697}
698
699template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_mxfp4(
700 const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
701 constexpr int nwarps = mmq_get_nwarps_device();
702 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
703
704#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
705 int * x_qs = (int *) x_tile;
706 float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
707#else
708 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_MXFP4, mmq_y);
709 int * x_qs = (int *) x_tile;
710 float * x_df = (float *) (x_qs + txs.qs);
711#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
712
713 constexpr int threads_per_row = MMQ_ITER_K / (4 * QR_MXFP4);
714 constexpr int nrows = warp_size / threads_per_row;
715 const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
716 const int kbx = txi / QI_MXFP4;
717 const int kqsx = txi % QI_MXFP4;
718
719#pragma unroll
720 for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
721 int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
722
723 if (need_check) {
            i = min(i, i_max);
725 }
726
727 const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbx;
728
729 const int aux_q4 = get_int_b1(bxi->qs, kqsx);
730 const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
731 const int k0 = kbx * (2 * QI_MXFP4) + kqsx;
732
733#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
734 x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + 0] = v.x;
735 x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + QI_MXFP4] = v.y;
736#else
737 x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
738 x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI_MXFP4] = v.y;
739#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
740 }
741
742 constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI_MXFP4;
743 constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
744 const int kbxd = threadIdx.x % blocks_per_tile_x_row;
745
746#pragma unroll
747 for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
748 int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
749
750 if (need_check) {
            i = min(i, i_max);
752 }
753
754 const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbxd;
755
756#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
757 x_df[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
758#else
759 x_df[i*(MMQ_TILE_NE_K/QI_MXFP4) + i/QI_MXFP4 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
760#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
761 }
762}
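
// Note on the 0.5f factor above (a sketch of the reasoning, assuming the kvalues_mxfp4 table stores
// the E2M1 values doubled so that they fit into int8): halving the decoded E8M0 block scale
// compensates for that doubling, so scale times quantized value reproduces the original MXFP4 values.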
763
764template <int mmq_x, int mmq_y>
765static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
766 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
767 constexpr int nwarps = mmq_get_nwarps_device();
768 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
769
770 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
771 const int * x_qs = (const int *) x;
772 const float * x_df = (const float *) x_qs + txs.qs;
773 const int * y_qs = (const int *) y + 4;
774 const float * y_df = (const float *) y;
775
776// #pragma unroll
777 for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += VDR_Q8_0_Q8_1_MMQ) {
778 const int k0 = k00 + k01;
779
780#pragma unroll
781 for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
782 const int j = j0 + threadIdx.y;
783
784#pragma unroll
785 for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
786 const int i = i0 + threadIdx.x;
787
788 sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
789 (&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0 % MMQ_TILE_NE_K],
790 x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + k0/QI8_0], y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (MMQ_TILE_NE_K/QI8_1)]);
791 }
792 }
793 }
794}
795
796template <int mmq_x, int mmq_y, mmq_q8_1_ds_layout ds_layout>
797static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
798 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
799#if defined(AMD_MFMA_AVAILABLE)
800 typedef tile<16, 8, int> tile_A;
801 typedef tile<16, 8, int> tile_B;
802 typedef tile<16, 16, int> tile_C;
803
804 constexpr int granularity = mmq_get_granularity_device(mmq_x);
805 constexpr int rows_per_warp = granularity;
806 constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
807
808 y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
809
810 const int * x_qs = (const int *) x;
811 const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K;
812 const int * y_qs = (const int *) y + 4;
813 const float * y_df = (const float *) y;
814 const half2 * y_ds = (const half2 *) y;
815
816 const int i0 = (threadIdx.y / ntx) * rows_per_warp;
817
818 for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
819 const int k0 = k00 + k01;
820
821 tile_A A[ntx];
822#pragma unroll
823 for (int n = 0; n < ntx; ++n) {
824 load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
825 }
826
827#pragma unroll
828 for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
829 tile_B B;
830 load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
831
832 float dB;
833 const int j = j0 + tile_C::get_j(0);
834 if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
835 dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
836 } else {
837 dB = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
838 }
839
840#pragma unroll
841 for (int n = 0; n < ntx; ++n) {
842 tile_C C;
843 mma(C, A[n], B);
844
845#pragma unroll
846 for (int l = 0; l < tile_C::ne; ++l) {
847 const int i = i0 + n*tile_A::I + tile_C::get_i(l);
848 const float dA = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
849 sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA*dB;
850 }
851 }
852 }
853 }
854#else
855 typedef tile<16, 8, int> tile_A;
856 typedef tile< 8, 8, int> tile_B;
857 typedef tile<16, 8, int> tile_C;
858
859 constexpr int granularity = mmq_get_granularity_device(mmq_x);
860 constexpr int rows_per_warp = 2 * granularity;
861 constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
862
863 y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
864
865 const int * x_qs = (const int *) x;
866 const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K;
867 const int * y_qs = (const int *) y + 4;
868 const float * y_df = (const float *) y;
869 const half2 * y_ds = (const half2 *) y;
870
871 tile_A A[ntx][MMQ_TILE_NE_K/QI8_0];
872 float dA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_0];
873
874 const int i0 = (threadIdx.y/ntx)*rows_per_warp;
875
876#pragma unroll
877 for (int n = 0; n < ntx; ++n) {
878#pragma unroll
879 for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
880 const int k0 = k00 + k01;
881
882 load_ldmatrix(A[n][k01/QI8_0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
883 }
884
885#pragma unroll
886 for (int l = 0; l < tile_C::ne/2; ++l) {
            const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
888
889#pragma unroll
890 for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
891 const int k0 = k00 + k01;
892
893 dA[n][l][k01/QI8_0] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
894 }
895 }
896 }
897
898#pragma unroll
899 for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
900#pragma unroll
901 for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
902 tile_B B;
903 float dB[tile_C::ne/2];
904
905 load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
906
907#pragma unroll
908 for (int l = 0; l < tile_C::ne/2; ++l) {
909 const int j = j0 + tile_C::get_j(l);
910
911 if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
912 dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
913 } else {
914 dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
915 }
916 }
917
918#pragma unroll
919 for (int n = 0; n < ntx; ++n) {
920 tile_C C;
921 mma(C, A[n][k01/QI8_0], B);
922
923#pragma unroll
924 for (int l = 0; l < tile_C::ne; ++l) {
925 sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];
926 }
927 }
928 }
929 }
930#endif // defined(AMD_MFMA_AVAILABLE)
931}
932
933template <int mmq_x, int mmq_y>
934static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
935 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
936 constexpr int nwarps = mmq_get_nwarps_device();
937 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
938
939 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
940 const int * x_qs = (const int *) x;
941 const half2 * x_dm = (const half2 *) x_qs + txs.qs;
942 const int * y_qs = (const int *) y + 4;
943 const half2 * y_ds = (const half2 *) y;
944
945// #pragma unroll
946 for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += VDR_Q8_0_Q8_1_MMQ) {
947 const int k0 = k00 + k01;
948
949#pragma unroll
950 for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
951 const int j = j0 + threadIdx.y;
952
953#pragma unroll
954 for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
955 const int i = i0 + threadIdx.x;
956
957 sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
958 (&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
959 x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + k0/QI8_1], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
960 }
961 }
962 }
963}
964
965template <int mmq_x, int mmq_y>
966static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
967 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
968#if defined(AMD_MFMA_AVAILABLE)
969 typedef tile<16, 8, int> tile_A;
970 typedef tile<16, 8, int> tile_B;
971 typedef tile<16, 16, int> tile_C;
972
973 constexpr int granularity = mmq_get_granularity_device(mmq_x);
974 constexpr int rows_per_warp = granularity;
975 constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
976
977 y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
978
979 const int * x_qs = (const int *) x;
980 const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K;
981 const int * y_qs = (const int *) y + 4;
982 const half2 * y_dm = (const half2 *) y;
983
984 const int i0 = (threadIdx.y / ntx) * rows_per_warp;
985
986 for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
987 const int k0 = k00 + k01;
988
989 tile_A A[ntx];
990#pragma unroll
991 for (int n = 0; n < ntx; ++n) {
992 load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
993 }
994
995#pragma unroll
996 for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
997 tile_B B;
998 load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
999
1000 const int j = j0 + tile_C::get_j(0);
1001 const float2 dsB = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
1002
1003#pragma unroll
1004 for (int n = 0; n < ntx; ++n) {
1005 tile_C C;
1006 mma(C, A[n], B);
1007
1008#pragma unroll
1009 for (int l = 0; l < tile_C::ne; ++l) {
1010 const int i = i0 + n*tile_A::I + tile_C::get_i(l);
1011 float2 dmA = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]);
1012 sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.x*dsB.x*C.x[l];
1013 sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.y*dsB.y;
1014 }
1015 }
1016 }
1017 }
1018#else
1019 typedef tile<16, 8, int> tile_A;
1020 typedef tile< 8, 8, int> tile_B;
1021 typedef tile<16, 8, int> tile_C;
1022
1023 constexpr int granularity = mmq_get_granularity_device(mmq_x);
1024 constexpr int rows_per_warp = 2 * granularity;
1025 constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
1026
1027 y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
1028
1029 const int * x_qs = (const int *) x;
1030 const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K;
1031 const int * y_qs = (const int *) y + 4;
1032 const half2 * y_dm = (const half2 *) y;
1033
1034 tile_A A[ntx][MMQ_TILE_NE_K/QI8_1];
1035 float2 dmA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_1];
1036
1037 const int i0 = (threadIdx.y/ntx)*rows_per_warp;
1038
1039#pragma unroll
1040 for (int n = 0; n < ntx; ++n) {
1041#pragma unroll
1042 for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
1043 const int k0 = k00 + k01;
1044
1045 load_ldmatrix(A[n][k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
1046 }
1047
1048#pragma unroll
1049 for (int l = 0; l < tile_C::ne/2; ++l) {
            const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
1051
1052#pragma unroll
1053 for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
1054 const int k0 = k00 + k01;
1055
1056 dmA[n][l][k01/QI8_1] = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]);
1057 }
1058 }
1059 }
1060
1061#pragma unroll
1062 for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
1063#pragma unroll
1064 for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
1065 tile_B B;
1066 float2 dsB[tile_C::ne/2];
1067
1068 load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
1069
1070#pragma unroll
1071 for (int l = 0; l < tile_C::ne/2; ++l) {
1072 const int j = j0 + tile_C::get_j(l);
1073
1074 dsB[l] = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
1075 }
1076
1077#pragma unroll
1078 for (int n = 0; n < ntx; ++n) {
1079 tile_C C;
1080 mma(C, A[n][k01/QI8_1], B);
1081
1082#pragma unroll
1083 for (int l = 0; l < tile_C::ne; ++l) {
1084 sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l];
1085 sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y;
1086 }
1087 }
1088 }
1089 }
1090#endif // defined(AMD_MFMA_AVAILABLE)
1091}
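
// Why the Q8_1 paths above add a dmA.y*dsB.y term: for Q4_1/Q5_1-style x data each value is
// x_i = dmA.x*qx_i + dmA.y (scale plus offset), while y_i = d_y*qy_i with a stored per-32-value sum s_y.
// Therefore sum_i x_i*y_i = dmA.x*d_y*dot(qx, qy) + dmA.y*sum_i y_i = dmA.x*dsB.x*dot + dmA.y*dsB.y,
// with dsB == (d_y, s_y) being the (scale, partial sum) pair from block_q8_1_mmq.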
1092
1093// Used for Q3_K, IQ2_S, and IQ2_XS
1094template <int mmq_x, int mmq_y>
1095static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
1096 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
1097 constexpr int nwarps = mmq_get_nwarps_device();
1098 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1099
1100 constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
1101 const int * x_qs = (const int *) x;
1102 const float * x_df = (const float *) x_qs + txs.qs;
1103 const int * y_qs = (const int *) y + 4;
1104 const float * y_df = (const float *) y;
1105
1106// #pragma unroll
1107 for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
1108 const int k0 = k00 + k01;
1109
1110#pragma unroll
1111 for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
1112 const int j = j0 + threadIdx.y;
1113
1114#pragma unroll
1115 for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
1116 const int i = i0 + threadIdx.x;
1117
1118 sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_0_16_q8_1_impl<QI8_0>(
1119 &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0],
1120 &y_qs[j*MMQ_TILE_Y_K + k01],
1121 &x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + k0/(QI8_0/2)],
1122 y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
1123 }
1124 }
1125 }
1126}
1127
1128// Used for Q3_K, IQ2_S, and IQ2_XS:
1129template <int mmq_x, int mmq_y>
1130static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
1131 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
1132#if defined(AMD_MFMA_AVAILABLE)
1133 typedef tile<16, 8, int> tile_A;
1134 typedef tile<16, 8, int> tile_B;
1135 typedef tile<16, 16, int> tile_C;
1136 typedef tile<64, 2, int> tile_load;
1137
1138 constexpr int granularity = mmq_get_granularity_device(mmq_x);
1139 constexpr int rows_per_warp = granularity;
1140 constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
1141
1142 y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
1143
1144 const int * x_qs = (const int *) x;
1145 const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
1146 const int * y_qs = (const int *) y + 4;
1147 const float * y_df = (const float *) y;
1148
1149 const int i0 = (threadIdx.y / ntx) * rows_per_warp;
1150
1151 for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
1152 const int k0 = k00 + k01;
1153
1154 tile_A A[ntx];
1155#pragma unroll
1156 for (int n = 0; n < ntx; ++n) {
1157 load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
1158 }
1159
1160#pragma unroll
1161 for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
1162 tile_B B[1];
1163 load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
1164
1165 const int j = j0 + tile_C::get_j(0);
1166 const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2;
1167
1168#pragma unroll
1169 for (int n = 0; n < ntx; ++n) {
1170 tile_C C;
1171 mma(C, A[n], B[0]);
1172
1173#pragma unroll
1174 for (int l = 0; l < tile_C::ne; ++l) {
1175 const int i = i0 + n*tile_C::I + tile_C::get_i(l);
1176 sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB;
1177 }
1178 }
1179 }
1180 }
1181#elif defined(TURING_MMA_AVAILABLE)
1182
1183 typedef tile<16, 4, int> tile_A;
1184 typedef tile<16, 8, int> tile_A_8;
1185 typedef tile< 8, 4, int> tile_B;
1186 typedef tile<16, 8, int> tile_C;
1187
1188 constexpr int granularity = mmq_get_granularity_device(mmq_x);
1189 constexpr int rows_per_warp = 2 * granularity;
1190 constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
1191
1192 y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
1193
1194 const int * x_qs = (const int *) x;
1195 const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
1196 const int * y_qs = (const int *) y + 4;
1197 const float * y_df = (const float *) y;
1198
1199 const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);
1200
1201 tile_A A[ntx][8];
1202 float dA[ntx][tile_C::ne/2][8];
1203
1204#pragma unroll
1205 for (int n = 0; n < ntx; ++n) {
1206#pragma unroll
1207 for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) {
1208 const int k0 = k00 + k01;
1209
1210 load_ldmatrix(((tile_A_8 *) A[n])[k01/8], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
1211 }
1212
1213#pragma unroll
1214 for (int l = 0; l < tile_C::ne/2; ++l) {
1215 const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
1216
1217#pragma unroll
1218 for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
1219 const int k0 = k00 + k01;
1220
1221 dA[n][l][k01/4] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4];
1222 }
1223 }
1224 }
1225
1226#pragma unroll
1227 for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
1228#pragma unroll
1229 for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
1230 tile_B B[2];
1231 float dB[tile_C::ne/2];
1232
1233 // Here load_generic is faster than load_ldmatrix.
1234 load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
1235 load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K);
1236
1237#pragma unroll
1238 for (int l = 0; l < tile_C::ne/2; ++l) {
1239 const int j = j0 + tile_C::get_j(l);
1240
1241 dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
1242 }
1243
1244#pragma unroll
1245 for (int n = 0; n < ntx; ++n) {
1246 tile_C C[2];
1247 mma(C[0], A[n][k01/4 + 0], B[0]);
1248 mma(C[1], A[n][k01/4 + 1], B[1]);
1249
1250#pragma unroll
1251 for (int l = 0; l < tile_C::ne; ++l) {
1252 sum[(j0/tile_C::J + n)*tile_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]);
1253 }
1254 }
1255 }
1256 }
1257#else
1258 GGML_UNUSED_VARS(x, y, sum, k00);
1259 NO_DEVICE_CODE;
1260#endif // AMD_MFMA_AVAILABLE
1261}
1262
1263template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
1264 const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
1265 constexpr int nwarps = mmq_get_nwarps_device();
1266
1267#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1268 int * x_qs = (int *) x_tile;
1269 half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
1270#else
1271 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
1272 int * x_qs = (int *) x_tile;
1273 half2 * x_dm = (half2 *) (x_qs + txs.qs);
1274#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1275
1276 constexpr int threads_per_row = MMQ_ITER_K / (4 * QR2_K);
1277 constexpr int nrows = ggml_cuda_get_physical_warp_size() / threads_per_row;
1278 const int kqsx = threadIdx.x % threads_per_row;
1279
1280#pragma unroll
1281 for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
1282 int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
1283
1284 if (need_check) {
            i = min(i, i_max);
1286 }
1287
1288 const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride;
1289
1290 const int x_ql_0 = get_int_b2(bxi->qs, kqsx);
1291
1292#pragma unroll
1293 for (int l = 0; l < QR2_K; ++l) {
1294 const int k = (kqsx/8)*32 + l*8 + kqsx % 8;
1295
1296 const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303;
1297
1298#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1299 x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k;
1300#else
1301 x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
1302#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1303 }
1304
1305 const int sc_m = bxi->scales[kqsx];
1306#ifdef FAST_FP16_AVAILABLE
1307 const half2 x_dm_ik = __hmul2(bxi->dm, make_half2(sc_m & 0x0F, sc_m >> 4));
1308#else
1309 const float2 bxi_dmf = __half22float2(bxi->dm);
        const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4));
1311#endif // FAST_FP16_AVAILABLE
1312
1313#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1314 x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik;
1315#else
1316 x_dm[i*(MMQ_TILE_NE_K + 1) + kqsx] = x_dm_ik;
1317#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1318 }
1319}
1320
1321template <int mmq_x, int mmq_y>
1322static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
1323 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
1324 constexpr int nwarps = mmq_get_nwarps_device();
1325 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1326
1327 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
1328 const int * x_qs = (const int *) x;
1329 const half2 * x_dm = (const half2 *) x_qs + txs.qs;
1330 const int * y_qs = (const int *) y + 4;
1331 const half2 * y_ds = (const half2 *) y;
1332
1333 float2 y_df[mmq_x/nwarps];
1334#pragma unroll
1335 for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
1336 const int j = j0 + threadIdx.y;
1337
1338 y_df[j0/nwarps] = __half22float2(y_ds[j*MMQ_TILE_Y_K]);
1339 }
1340
1341#pragma unroll
1342 for (int k01 = 0; k01 < MMQ_TILE_NE_K/2; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
1343 const int k0 = k00 + k01;
1344
1345#pragma unroll
1346 for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
1347 const int j = j0 + threadIdx.y;
1348
1349#pragma unroll
1350 for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
1351 const int i = i0 + threadIdx.x;
1352
1353 constexpr int ns = 2;
1354 sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
1355 &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
1356 &x_dm[i*(MMQ_TILE_NE_K + 1) + k0/4], k01 < MMQ_TILE_NE_K/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
1357 &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
1358 }
1359 }
1360 }
1361
1362 // Some compilers fail to unroll the loop over k01 if there is a conditional statement for ns in the inner loop.
1363 // As a workaround 2 separate loops are used instead.
1364#pragma unroll
1365 for (int k01 = MMQ_TILE_NE_K/2; k01 < MMQ_TILE_NE_K; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
1366 const int k0 = k00 + k01;
1367
1368#pragma unroll
1369 for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
1370 const int j = j0 + threadIdx.y;
1371
1372#pragma unroll
1373 for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
1374 const int i = i0 + threadIdx.x;
1375
1376 constexpr int ns = 1;
1377 sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
1378 &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
1379 &x_dm[i*(MMQ_TILE_NE_K + 1) + k0/4], k01 < MMQ_TILE_NE_K/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
1380 &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
1381 }
1382 }
1383 }
1384}
1385
1386template <int mmq_x, int mmq_y>
1387static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
1388 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
1389#if defined(AMD_MFMA_AVAILABLE)
1390 typedef tile<16, 8, int> tile_A;
1391 typedef tile<16, 8, int> tile_B;
1392 typedef tile<16, 16, int> tile_C;
1393 typedef tile<64, 2, int> tile_load;
1394
1395 constexpr int granularity = mmq_get_granularity_device(mmq_x);
1396 constexpr int rows_per_warp = granularity;
1397 constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
1398
1399 y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
1400
1401 const int * x_qs = (const int *) x;
1402 const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;
1403 const int * y_qs = (const int *) y + 4;
1404 const half2 * y_ds = (const half2 *) y;
1405
1406 const int i0 = (threadIdx.y / ntx) * rows_per_warp;
1407
1408 for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
1409 const int k0 = k00 + k01;
1410
1411 tile_A A[ntx];
1412#pragma unroll
1413 for (int n = 0; n < ntx; ++n) {
1414 load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
1415 }
1416
1417#pragma unroll
1418 for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
1419 tile_B B[1];
1420 load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
1421
1422 const int j = j0 + tile_C::get_j(0);
1423 const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x/2 : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y/2;
1424 const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0
1425 : (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y
1426 : __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x);
1427
1428 tile_C Cm;
1429 if (k01 >= MMQ_TILE_NE_K * 3/4) {
1430 tile_A A1;
1431 A1.x[0] = 0x01010101;
1432 A1.x[1] = 0x01010101;
1433 mma(Cm, A1, B[0]);
1434 }
1435
1436#pragma unroll
1437 for (int n = 0; n < ntx; ++n) {
1438 tile_C Cd;
1439 mma(Cd, A[n], B[0]);
1440
1441#pragma unroll
1442 for (int l = 0; l < tile_C::ne; ++l) {
1443 const int i = i0 + n*tile_C::I + tile_C::get_i(l);
1444 const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]);
1445 float tmp = Cd.x[l]*dm.x;
1446 if (k01 >= MMQ_TILE_NE_K * 3/4) {
1447 tmp -= Cm.x[l]*dm.y;
1448 }
1449 sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB;
1450 sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB;
1451 }
1452 }
1453 }
1454 }
1455#elif defined(TURING_MMA_AVAILABLE)
1456
1457 typedef tile<16, 4, int> tile_A;
1458 typedef tile<16, 8, int> tile_A_8;
1459 typedef tile< 8, 4, int> tile_B;
1460 typedef tile<16, 8, int> tile_C;
1461
1462 constexpr int granularity = mmq_get_granularity_device(mmq_x);
1463 constexpr int rows_per_warp = 2 * granularity;
1464 constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
1465
1466 y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
1467
1468 const int * x_qs = (const int *) x;
1469 const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;
1470 const int * y_qs = (const int *) y + 4;
1471 const half2 * y_ds = (const half2 *) y;
1472
1473 const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);
1474
1475 tile_A A[ntx][8];
1476 float dA[ntx][tile_C::ne/2][8];
1477 float mA[ntx][tile_C::ne/2][8];
1478
1479#pragma unroll
1480 for (int n = 0; n < ntx; ++n) {
1481#pragma unroll
1482 for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
1483 const int k0 = k00 + k01;
1484
1485 load_ldmatrix(((tile_A_8 *) A[n])[k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
1486 }
1487 }
1488
1489#pragma unroll
1490 for (int n = 0; n < ntx; ++n) {
1491#pragma unroll
1492 for (int l = 0; l < tile_C::ne/2; ++l) {
1493 const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
1494
1495#pragma unroll
1496 for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1/2) {
1497 const int k0 = k00 + k01;
1498
1499 const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/(QI8_1/2)]);
1500
1501 dA[n][l][k01/(QI8_1/2)] = dm.x;
1502 mA[n][l][k01/(QI8_1/2)] = dm.y;
1503 }
1504 }
1505 }
1506
1507#pragma unroll
1508 for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
1509 float2 dB[tile_C::ne/2];
1510
1511#pragma unroll
1512 for (int l = 0; l < tile_C::ne/2; ++l) {
1513 const int j = j0 + tile_C::get_j(l);
1514
1515 dB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K]);
1516 }
1517
1518#pragma unroll
1519 for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
1520 tile_B B[2];
1521
1522 // Here load_generic is faster than load_ldmatrix.
1523 load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
1524 load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K);
1525
1526 tile_C Cm[2];
1527 if (k01 >= MMQ_TILE_NE_K * 3/4) {
1528 tile_A A1;
1529 A1.x[0] = 0x01010101;
1530 A1.x[1] = 0x01010101;
1531 mma(Cm[0], A1, B[0]);
1532 mma(Cm[1], A1, B[1]);
1533 }
1534
1535#pragma unroll
1536 for (int n = 0; n < ntx; ++n) {
1537 tile_C Cd[2];
1538
1539 mma(Cd[0], A[n][k01/4 + 0], B[0]);
1540 mma(Cd[1], A[n][k01/4 + 1], B[1]);
1541
1542#pragma unroll
1543 for (int l = 0; l < tile_C::ne; ++l) {
1544 float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1];
1545 if (k01 >= MMQ_TILE_NE_K * 3/4) {
1546 tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1];
1547 }
1548 sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 < MMQ_TILE_NE_K/2 ? dB[l%2].x : dB[l%2].y);
1549 }
1550 }
1551 }
1552
1553#pragma unroll
1554 for (int k01 = 0; k01 < MMQ_TILE_NE_K * 3/4; k01 += QI8_1) {
1555 float2 sB[tile_C::ne/2];
1556
1557#pragma unroll
1558 for (int l = 0; l < tile_C::ne/2; ++l) {
1559 const int j = j0 + tile_C::get_j(l);
1560
1561 sB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
1562 }
1563
1564#pragma unroll
1565 for (int n = 0; n < ntx; ++n) {
1566#pragma unroll
1567 for (int l = 0; l < tile_C::ne; ++l) {
1568 sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x;
1569 sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y;
1570 }
1571 }
1572 }
1573 }
1574#else
1575 GGML_UNUSED_VARS(x, y, sum, k00);
1576 NO_DEVICE_CODE;
1577#endif // AMD_MFMA_AVAILABLE
1578}
1579
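// Load a tile of q3_K data: the 2-bit quants are combined with the corresponding hmask bit,
// recentered to signed values, and stored together with the 6-bit block scales
// (pre-multiplied by the super-block scale on the MMA path).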
1580template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
1581 const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
1582 constexpr int nwarps = mmq_get_nwarps_device();
1583 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1584
1585#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1586 int * x_qs = (int *) x_tile;
1587 float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
1588#else
1589 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
1590 int * x_qs = (int *) x_tile;
1591 float * x_df = (float *) (x_qs + txs.qs);
1592 int * x_sc = (int *) (x_df + txs.dm);
1593#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1594
1595 constexpr int threads_per_row = MMQ_ITER_K / (4 * QR3_K);
1596 constexpr int nrows = warp_size / threads_per_row;
1597 const int kqsx = threadIdx.x % threads_per_row;
1598
1599#pragma unroll
1600 for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
1601 int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
1602
1603 if (need_check) {
            i = min(i, i_max);
1605 }
1606
1607 const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
1608
1609 const int x_ql_0 = get_int_b2(bxi->qs, kqsx);
1610 const int x_qh_0 = get_int_b2(bxi->hmask, kqsx % (QI3_K/2)) >> (4 * (kqsx / (QI3_K/2)));
1611
1612#pragma unroll
1613 for (int l = 0; l < QR3_K; ++l) {
1614 const int k = (kqsx/8)*32 + l*8 + kqsx % 8;
1615
1616 const int x_ql_k = (x_ql_0 >> (2*l)) & 0x03030303;
1617 const int x_qh_k = ((x_qh_0 >> l) << 2) & 0x04040404;
1618
            const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404);
1620
1621#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1622 x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k;
1623#else
1624 x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
1625#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1626 }
1627 }
1628
1629 constexpr int rows_per_warp = warp_size / 4;
1630#pragma unroll
1631 for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
1632 int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/4;
1633
1634 if (need_check) {
            i = min(i, i_max);
1636 }
1637
1638 const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
1639
1640 const int ksc = threadIdx.x % 4;
1641
1642 const int ksc_low = ksc % (QI3_K/8);
1643 const int shift_low = 4 * (ksc / (QI3_K/8));
1644 const int sc_low = (get_int_b2(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;
1645
1646 const int ksc_high = QI3_K/8;
1647 const int shift_high = 2 * ksc;
1648 const int sc_high = ((get_int_b2(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;
1649
        const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
1651
1652#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1653 const int8_t * sc8 = (const int8_t *) &sc;
1654 const float d = bxi->d;
1655
1656#pragma unroll
1657 for (int l = 0; l < int(sizeof(int)); ++l) {
1658 x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*ksc + l] = d*sc8[l];
1659 }
1660#else
1661 x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = sc;
1662#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1663 }
1664
1665#if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
1666#pragma unroll
1667 for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
1668 int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
1669
1670 if (need_check) {
            i = min(i, i_max);
1672 }
1673
1674 const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
1675
1676 x_df[i] = bxi->d;
1677 }
1678#endif // !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
1679}
1680
1681template <int mmq_x, int mmq_y>
1682static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
1683 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
1684 constexpr int nwarps = mmq_get_nwarps_device();
1685 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1686
1687 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
1688 const int * x_qs = (const int *) x;
1689 const float * x_df = (const float *) x_qs + txs.qs;
1690 const int * x_sc = (const int *) x_df + txs.dm;
1691 const int * y_qs = (const int *) y + 4;
1692 const float * y_df = (const float *) y;
1693
1694// #pragma unroll
1695 for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
1696 const int k0 = k00 + k01;
1697
1698#pragma unroll
1699 for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
1700 const int j = j0 + threadIdx.y;
1701
1702#pragma unroll
1703 for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
1704 const int i = i0 + threadIdx.x;
1705
1706 const int8_t * scales = ((const int8_t *) (x_sc + i*(MMQ_TILE_NE_K/8) + i/8)) + k0/4;
1707
1708 sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q3_K_q8_1_impl_mmq(
1709 &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], scales,
1710 x_df[i], y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
1711 }
1712 }
1713 }
1714}
1715
1716static __device__ __forceinline__ int unpack_scales_q45_K(const int * scales, const int ksc) {
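    // The 12 scale/min bytes of q4_K/q5_K are viewed as 3 ints:
    // bytes 0-3  hold the low 6 bits of sc0-sc3 and, in their top 2 bits, the high bits of sc4-sc7,
    // bytes 4-7  hold the low 6 bits of m0-m3  and, in their top 2 bits, the high bits of m4-m7,
    // bytes 8-11 hold the low 4 bits of sc4-sc7 (low nibble) and of m4-m7 (high nibble).
    //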
1717 // scale arrangement after the following two lines:
1718 // - ksc == 0: sc0, sc1, sc2, sc3
1719 // - ksc == 1: sc4, sc5, sc6, sc7
1720 // - ksc == 2: m0, m1, m2, m3
1721 // - ksc == 3: m4, m5, m6, m7
1722 return ((scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F) | // lower 4 bits
1723 ((scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030); // upper 2 bits
1724}
1725
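// Load a tile of q4_K data: each 32-bit load covers 8 packed 4-bit quants. On the MMA path the
// per-subblock scales/mins are pre-multiplied with the super-block d/dmin and stored as half2
// (d*sc, -dmin*m); on the dp4a path dm and the packed scales are stored separately.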
1726template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
1727 const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
1728 constexpr int nwarps = mmq_get_nwarps_device();
1729 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1730
1731#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1732 int * x_qs = (int *) x_tile;
1733 half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
1734#else
1735 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
1736 int * x_qs = (int *) x_tile;
1737 half2 * x_dm = (half2 *) (x_qs + txs.qs);
1738 int * x_sc = (int *) (x_dm + txs.dm);
1739#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1740
1741 constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_K);
1742 constexpr int nrows = warp_size / threads_per_row;
1743 const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
1744
1745#pragma unroll
1746 for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
1747 int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
1748
1749 if (need_check) {
            i = min(i, i_max);
1751 }
1752
1753 const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
1754 const int qs0 = get_int_b4(bxi->qs, txi);
1755
1756#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1757 x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F;
1758 x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F;
1759#else
1760 x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
1761#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1762 }
1763
1764#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1765 constexpr int rows_per_warp = warp_size / 2;
1766#pragma unroll
1767 for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
1768#if defined(AMD_MFMA_AVAILABLE)
        // On AMD an 'if' bound check is needed instead of the '%' wrap because warp_size == 64:
        // the wrap-around would process some rows twice (double work, throughput loss on MI300X).
        // On H100 the 'if' condition costs about 100 t/s compared to '%', hence the separate branches.
1772 int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2;
1773 if (i < mmq_y) {
1774#else
1775 int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y;
1776 {
1777#endif // defined(AMD_MFMA_AVAILABLE)
1778 if (need_check) {
1779 i = min(i, i_max);
1780 }
1781
1782 const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
1783
1784 const int * scales = (const int *) bxi->scales;
1785 const int ksc = threadIdx.x % 2;
1786
1787 const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
1788 const int m32 = unpack_scales_q45_K(scales, ksc + 2);
1789
1790 const uint8_t * sc8 = (const uint8_t *) &sc32;
1791 const uint8_t * m8 = (const uint8_t *) &m32;
1792
1793 const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
1794
1795 #pragma unroll
            for (int l = 0; l < int(sizeof(int)); ++l) {
1797 x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
1798 }
1799 }
1800 }
1801#else
1802#pragma unroll
1803 for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
1804 int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
1805
1806 if (need_check) {
            i = min(i, i_max);
1808 }
1809
1810 const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
1811
1812 x_dm[i] = bxi->dm;
1813 }
1814 constexpr int rows_per_warp = warp_size / 4;
1815#pragma unroll
1816 for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
1817 int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y;
1818
1819 if (need_check) {
            i = min(i, i_max);
1821 }
1822
1823 const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / (QI4_K/8);
1824
1825 const int * scales = (const int *) bxi->scales;
1826
1827 const int ksc = threadIdx.x % (MMQ_TILE_NE_K/8);
1828 const int scales8 = unpack_scales_q45_K(scales, ksc);
1829
1830 x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
1831 }
1832#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1833}
1834
1835template <int mmq_x, int mmq_y>
1836static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
1837 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
1838 constexpr int nwarps = mmq_get_nwarps_device();
1839 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1840
1841 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
1842 const int * x_qs = (const int *) x;
1843 const half2 * x_dm = (const half2 *) x_qs + txs.qs;
1844 const int * x_sc = (const int *) x_dm + txs.dm;
1845 const int * y_qs = (const int *) y + 4;
1846 const half2 * y_ds = (const half2 *) y;
1847
1848// #pragma unroll
1849 for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_K*VDR_Q4_K_Q8_1_MMQ) {
1850 const int k0 = k00 + k01;
1851
1852#pragma unroll
1853 for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
1854 const int j = j0 + threadIdx.y;
1855
1856#pragma unroll
1857 for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
1858 const int i = i0 + threadIdx.x;
1859
1860 const uint8_t * sc = (const uint8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k0/32] + 2*(k01/16);
1861
1862 sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_K_q8_1_impl_mmq(
1863 &x_qs[i*(MMQ_TILE_NE_K + 1) + k0/2], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
1864 x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
1865 }
1866 }
1867 }
1868}
1869
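// Load a tile of q5_K data: the 5th bit from qh is merged into the 4-bit quants from qs;
// the scales/mins are handled the same way as for q4_K above.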
1870template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
1871 const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
1872 constexpr int nwarps = mmq_get_nwarps_device();
1873 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1874
1875#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1876 int * x_qs = (int *) x_tile;
1877 half2 * x_dm = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
1878#else
1879 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
1880 int * x_qs = (int *) x_tile;
1881 half2 * x_dm = (half2 *) (x_qs + txs.qs);
1882 int * x_sc = (int *) (x_dm + txs.dm);
1883#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1884
1885 constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_K);
1886 constexpr int nrows = warp_size / threads_per_row;
1887 const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
1888
1889#pragma unroll
1890 for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
1891 int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
1892
1893 if (need_check) {
            i = min(i, i_max);
1895 }
1896
1897 const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
1898 const int ky = QR5_K*txi;
1899
1900 const int ql = get_int_b4(bxi->qs, txi);
1901 const int ql0 = (ql >> 0) & 0x0F0F0F0F;
1902 const int ql1 = (ql >> 4) & 0x0F0F0F0F;
1903
1904 const int qh = get_int_b4(bxi->qh, txi % (QI5_K/4));
1905 const int qh0 = ((qh >> (2 * (txi / (QI5_K/4)) + 0)) << 4) & 0x10101010;
1906 const int qh1 = ((qh >> (2 * (txi / (QI5_K/4)) + 1)) << 4) & 0x10101010;
1907
1908 const int kq0 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + 0;
1909 const int kq1 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + QI5_K/4;
1910
1911#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1912 x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0;
1913 x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1;
1914#else
1915 x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = ql0 | qh0;
1916 x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = ql1 | qh1;
1917#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1918 }
1919
1920#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1921 constexpr int rows_per_warp = warp_size / 2;
1922#pragma unroll
1923 for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
1924#if defined(AMD_MFMA_AVAILABLE)
        // On AMD an 'if' bound check is needed instead of the '%' wrap because warp_size == 64:
        // the wrap-around would process some rows twice (double work, throughput loss on MI300X).
        // On H100 the 'if' condition costs about 100 t/s compared to '%', hence the separate branches.
1928 int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2;
1929 if (i < mmq_y) {
1930#else
1931 int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y;
1932 {
1933#endif // defined(AMD_MFMA_AVAILABLE)
1934 if (need_check) {
1935 i = min(i, i_max);
1936 }
1937
1938 const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
1939
1940 const int * scales = (const int *) bxi->scales;
1941 const int ksc = threadIdx.x % 2;
1942
1943 const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
1944 const int m32 = unpack_scales_q45_K(scales, ksc + 2);
1945
1946 const uint8_t * sc8 = (const uint8_t *) &sc32;
1947 const uint8_t * m8 = (const uint8_t *) &m32;
1948
1949 const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
1950
1951#pragma unroll
1952 for (int l = 0; l < int(sizeof(int)); ++l) {
1953 x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
1954 }
1955 }
1956 }
1957#else
1958#pragma unroll
1959 for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
1960 int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
1961
1962 if (need_check) {
            i = min(i, i_max);
1964 }
1965
1966 const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
1967
1968 x_dm[i] = bxi->dm;
1969 }
1970
1971 constexpr int rows_per_warp = warp_size / 4;
1972#pragma unroll
1973 for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
1974 int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y;
1975
1976 if (need_check) {
            i = min(i, i_max);
1978 }
1979
1980 const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
1981
1982 const int * scales = (const int *) bxi->scales;
1983
1984 const int ksc = threadIdx.x % (MMQ_TILE_NE_K/8);
1985 const int scales8 = unpack_scales_q45_K(scales, ksc);
1986
1987 x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
1988 }
1989#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1990}
1991
1992template <int mmq_x, int mmq_y>
1993static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
1994 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
1995 constexpr int nwarps = mmq_get_nwarps_device();
1996 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1997
1998 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
1999 const int * x_qs = (const int *) x;
2000 const half2 * x_dm = (const half2 *) x_qs + txs.qs;
2001 const int * x_sc = (const int *) x_dm + txs.dm;
2002 const int * y_qs = (const int *) y + 4;
2003 const half2 * y_ds = (const half2 *) y;
2004
2005// #pragma unroll
2006 for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR5_K*VDR_Q5_K_Q8_1_MMQ) {
2007 const int k0 = k00 + k01;
2008
2009#pragma unroll
2010 for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
2011 const int j = j0 + threadIdx.y;
2012
2013#pragma unroll
2014 for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
2015 const int i = i0 + threadIdx.x;
2016
2017 const uint8_t * sc = ((const uint8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k00/32]) + 2*(k01/16);
2018
2019 sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q5_K_q8_1_impl_mmq(
2020 &x_qs[i*(QR5_K*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
2021 x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
2022 }
2023 }
2024 }
2025}
2026
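// Load a tile of q6_K data: the low 4 bits from ql and the high 2 bits from qh are combined and
// recentered by 32; the int8 subblock scales and the super-block scale are stored separately.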
2027template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
2028 const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2029 constexpr int nwarps = mmq_get_nwarps_device();
2030 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2031
2032#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2033 int * x_qs = (int *) x_tile;
2034 float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2035 int * x_sc = (int *) (x_df + MMQ_TILE_NE_K/QI6_K);
2036#else
2037 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
2038 int * x_qs = (int *) x_tile;
2039 float * x_df = (float *) (x_qs + txs.qs);
2040 int * x_sc = (int *) (x_df + txs.dm);
2041#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2042
2043 constexpr int threads_per_row = MMQ_ITER_K / (4 * QR6_K);
2044 constexpr int nrows = warp_size / threads_per_row;
2045 const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
2046
2047#pragma unroll
2048 for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
2049 int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
2050
2051 if (need_check) {
            i = min(i, i_max);
2053 }
2054
2055 const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride;
2056
2057 const int ql = get_int_b2(bxi->ql, txi);
2058 const int ql0 = (ql >> 0) & 0x0F0F0F0F;
2059 const int ql1 = (ql >> 4) & 0x0F0F0F0F;
2060
2061 const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (txi / (QI6_K/2)) + txi % (QI6_K/4));
2062 const int qh0 = ((qh >> ((txi & 0x08) >> 2)) << 4) & 0x30303030;
2063 const int qh1 = (qh >> ((txi & 0x08) >> 2)) & 0x30303030;
2064
2065 const int kq0 = 2*txi - txi % (QI6_K/2) + 0;
2066 const int kq1 = 2*txi - txi % (QI6_K/2) + QI6_K/2;
2067
2068#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2069 x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
2070 x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
2071#else
        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
2074#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2075 }
2076
2077#pragma unroll
2078 for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
2079 int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
2080
2081 if (need_check) {
            i = min(i, i_max);
2083 }
2084
2085 const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride;
2086
2087#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2088 x_df[i*MMQ_MMA_TILE_X_K_Q6_K] = bxi->d;
2089#else
2090 x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K] = bxi->d;
2091#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2092 }
2093
2094 constexpr int rows_per_warp = warp_size / 4;
2095#pragma unroll
2096 for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
2097 int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y;
2098
2099 if (need_check) {
            i = min(i, i_max);
2101 }
2102
2103 const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / 4;
2104
2105#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2106 x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x%4] = get_int_b2(bxi->scales, threadIdx.x % (MMQ_TILE_NE_K/8));
2107#else
2108 x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + threadIdx.x%(MMQ_TILE_NE_K/8)] = get_int_b2(bxi->scales, threadIdx.x%(QI6_K/8));
2109#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2110 }
2111}
2112
2113template <int mmq_x, int mmq_y>
2114static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
2115 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
2116 constexpr int nwarps = mmq_get_nwarps_device();
2117 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2118
2119 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
2120 const int * x_qs = (const int *) x;
2121 const float * x_df = (const float *) x_qs + txs.qs;
2122 const int * x_sc = (const int *) x_df + txs.dm;
2123 const int * y_qs = (const int *) y + 4;
2124 const float * y_df = (const float *) y;
2125
2126// #pragma unroll
2127 for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR6_K*VDR_Q6_K_Q8_1_MMQ) {
2128 const int k0 = k00 + k01;
2129
2130#pragma unroll
2131 for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
2132 const int j = j0 + threadIdx.y;
2133
2134#pragma unroll
2135 for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
2136 const int i = i0 + threadIdx.x;
2137
2138 const int8_t * sc = ((const int8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k0/16]);
2139
2140 sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q6_K_q8_1_impl_mmq(
2141 &x_qs[i*(QR6_K*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc,
2142 x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
2143 }
2144 }
2145 }
2146}
2147
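// MMA dot product for q6_K: the int8 subblock scales are applied to the mma results and the
// super-block scale (together with the q8_1 block scale) is factored in when accumulating into sum.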
2148template <int mmq_x, int mmq_y>
2149static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
2150 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
2151#if defined(AMD_MFMA_AVAILABLE)
2152 typedef tile<16, 8, int> tile_A;
2153 typedef tile<16, 8, int> tile_B;
2154 typedef tile<16, 16, int> tile_C;
2155 typedef tile<64, 2, int> tile_load;
2156
2157 constexpr int granularity = mmq_get_granularity_device(mmq_x);
2158 constexpr int rows_per_warp = granularity;
2159 constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
2160
2161 y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
2162
2163 const int * x_qs = (const int *) x;
2164 const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
2165 const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K;
2166 const int * y_qs = (const int *) y + 4;
2167 const float * y_df = (const float *) y;
2168
2169 const int i0 = (threadIdx.y / ntx) * rows_per_warp;
2170
2171 for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
2172 const int k0 = k00 + k01;
2173
2174 tile_A A[ntx];
2175#pragma unroll
2176 for (int n = 0; n < ntx; ++n) {
2177 load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);
2178 }
2179
2180#pragma unroll
2181 for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
2182 tile_B B[1];
2183 load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
2184
2185 const int j = j0 + tile_C::get_j(0);
2186 const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2;
2187
2188#pragma unroll
2189 for (int n = 0; n < ntx; ++n) {
2190 tile_C C;
2191 mma(C, A[n], B[0]);
2192
2193#pragma unroll
2194 for (int l = 0; l < tile_C::ne; ++l) {
2195 const int i = i0 + n*tile_C::I + tile_C::get_i(l);
2196 const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16);
2197 sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB;
2198 }
2199 }
2200 }
2201 }
2202#elif defined(TURING_MMA_AVAILABLE)
2203
2204 typedef tile<16, 4, int> tile_A;
2205 typedef tile< 8, 4, int> tile_B;
2206 typedef tile<16, 8, int> tile_C;
2207
2208 constexpr int granularity = mmq_get_granularity_device(mmq_x);
2209 constexpr int rows_per_warp = 2 * granularity;
2210 constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
2211
2212 y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
2213
2214 const int * x_qs = (const int *) x;
2215 const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
2216 const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K;
2217 const int * y_qs = (const int *) y + 4;
2218 const float * y_df = (const float *) y;
2219
2220 const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);
2221
2222 tile_A A[ntx][8];
2223 int scA[ntx][tile_C::ne/2][8];
2224 float dA[ntx][tile_C::ne/2];
2225
2226#pragma unroll
2227 for (int n = 0; n < ntx; ++n) {
2228#pragma unroll
2229 for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) {
2230 const int k0 = k00 + k01;
2231
2232 load_ldmatrix(A[n][k01/4 + 0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K);
2233 load_ldmatrix(A[n][k01/4 + 1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + tile_A::J), MMQ_MMA_TILE_X_K_Q6_K);
2234 }
2235
2236#pragma unroll
2237 for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 16) {
2238 const int k0 = k00 + k01;
2239
2240#pragma unroll
2241 for (int l = 0; l < tile_C::ne/2; ++l) {
2242 const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
2243
2244 const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/16];
2245 const int8_t * sc = (const int8_t *) &sc_packed;
2246
2247#pragma unroll
                for (int ksc = 0; ksc < int(sizeof(int)); ++ksc) {
2249 scA[n][l][k01/4 + ksc] = sc[ksc];
2250 }
2251 }
2252 }
2253
2254#pragma unroll
2255 for (int l = 0; l < tile_C::ne/2; ++l) {
2256 const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
2257
2258 dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K];
2259 }
2260 }
2261
2262#pragma unroll
2263 for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
2264 float tmp[ntx][tile_C::ne] = {{0.0f}};
2265
2266#pragma unroll
2267 for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) {
2268 tile_B B[2];
2269 float dB[tile_C::ne/2];
2270
2271 // Here load_generic is faster than load_ldmatrix.
2272 load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K);
2273 load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + tile_B::J + k01, MMQ_TILE_Y_K);
2274
2275#pragma unroll
2276 for (int l = 0; l < tile_C::ne/2; ++l) {
2277 const int j = j0 + tile_C::get_j(l);
2278
2279 dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
2280 }
2281
2282#pragma unroll
2283 for (int n = 0; n < ntx; ++n) {
2284 tile_C C[2];
2285 mma(C[0], A[n][k01/4 + 0], B[0]);
2286 mma(C[1], A[n][k01/4 + 1], B[1]);
2287
2288#pragma unroll
2289 for (int l = 0; l < tile_C::ne; ++l) {
2290 tmp[n][l] += (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1])*dB[l%2];
2291 }
2292 }
2293 }
2294
2295#pragma unroll
2296 for (int n = 0; n < ntx; ++n) {
2297#pragma unroll
2298 for (int l = 0; l < tile_C::ne; ++l) {
2299 sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp[n][l]*dA[n][l/2];
2300 }
2301 }
2302 }
2303#else
2304 GGML_UNUSED_VARS(x, y, sum, k00);
2305 NO_DEVICE_CODE;
2306#endif // AMD_MFMA_AVAILABLE
2307}
2308
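// Load a tile of iq4_nl data: the 4-bit indices are expanded to int8 values via the
// kvalues_iq4nl lookup table; one float scale is stored per block.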
2309template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
2310 const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2311 constexpr int nwarps = mmq_get_nwarps_device();
2312 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2313
2314#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2315 int * x_qs = (int *) x_tile;
2316 float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2317#else
2318 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y);
2319 int * x_qs = (int *) x_tile;
2320 float * x_df = (float *) (x_qs + txs.qs);
2321#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2322
2323 constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_NL);
2324 constexpr int nrows = warp_size / threads_per_row;
2325 const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
2326 const int kbx = txi / QI4_NL;
2327 const int kqsx = txi % QI4_NL;
2328
2329#pragma unroll
2330 for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
2331 int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
2332
2333 if (need_check) {
            i = min(i, i_max);
2335 }
2336
2337 const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbx;
2338
2339 const int aux_q4 = get_int_b2(bxi->qs, kqsx);
2340 const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
2341 const int k0 = kbx * (2 * QI4_NL) + kqsx;
2342
2343#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2344 x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
2345 x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + QI4_NL] = v.y;
2346#else
2347 x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
2348 x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI4_NL] = v.y;
2349#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2350 }
2351
2352 constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_NL;
2353 constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
2354 const int kbxd = threadIdx.x % blocks_per_tile_x_row;
2355
2356#pragma unroll
2357 for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
2358 int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
2359
2360 if (need_check) {
            i = min(i, i_max);
2362 }
2363
2364 const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;
2365
2366#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2367 x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d);
2368#else
2369 x_df[i*(MMQ_TILE_NE_K/QI4_NL) + i/QI4_NL + kbxd] = __half2float(bxi->d);
2370#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2371 }
2372}
2373
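// Load a tile of iq2_xxs data: grid positions come from the iq2xxs_grid lookup table, the packed
// signs are applied via ksigns_iq2xs, and the 4-bit block scale is folded into a single float.
// The other iq2/iq3 loaders below follow the same pattern with their respective grids and scales.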
2374template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xxs(
2375 const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2376 constexpr int nwarps = mmq_get_nwarps_device();
2377 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2378
2379#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2380 int * x_qs = (int *) x_tile;
2381 float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2382#else
2383 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y);
2384 int * x_qs = (int *) x_tile;
2385 float * x_df = (float *) (x_qs + txs.qs);
2386#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2387
2388 constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XXS)) / 2;
2389 constexpr int nrows = warp_size / threads_per_row;
2390 const int kqsx = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
2391
2392#pragma unroll
2393 for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
2394 int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
2395
2396 if (need_check) {
            i = min(i, i_max);
2398 }
2399
2400 const block_iq2_xxs * bxi = (const block_iq2_xxs *) x + kbx0 + i*stride;
2401
2402 const int q2 = get_int_b2(bxi->qs, 2*kqsx+0);
2403 const uint8_t * aux8 = (const uint8_t *) &q2;
2404 const uint32_t aux32 = get_int_b2(bxi->qs, 2*kqsx+1);
2405
2406#pragma unroll
2407 for (int l = 0; l < QR2_XXS; ++l) {
2408 const int * grid_pos = (const int *) (iq2xxs_grid + aux8[l]);
2409 const int signs_packed = ksigns_iq2xs[(aux32 >> (7*l)) & 0x7F];
2410
            const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000);
            const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0);

            const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
            const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
2416
2417#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2418 x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
2419 x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1;
2420#else
2421 x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid0;
2422 x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid1;
2423#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2424 }
2425
2426 const int ls = aux32 >> 28;
2427 const float d = bxi->d;
2428#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2429 x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4;
2430#else
2431 x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4;
2432#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2433 }
2434}
2435
2436template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xs(
2437 const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2438 constexpr int nwarps = mmq_get_nwarps_device();
2439 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2440
2441#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2442 int * x_qs = (int *) x_tile;
2443 float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2444#else
2445 constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
2446 int * x_qs = (int *) x_tile;
2447 float * x_df = (float *) (x_qs + txs.qs);
2448#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2449
2450 constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XS)) / 2;
2451 constexpr int nrows = warp_size / threads_per_row;
2452 const int kqsx = threadIdx.x % threads_per_row;
2453
2454#pragma unroll
2455 for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
2456 int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
2457
2458 if (need_check) {
            i = min(i, i_max);
2460 }
2461
2462 const block_iq2_xs * bxi = (const block_iq2_xs *) x + kbx0 + i*stride;
2463
2464 const int2 q2_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1));
2465 const uint16_t * q2 = (const uint16_t *) &q2_packed;
2466
2467 #pragma unroll
2468 for (int l = 0; l < QR2_XS; ++l) {
2469 const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l] & 0x000001FF));
2470 const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9));
2471
            const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
            const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
2474
2475#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2476 x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
2477 x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
2478#else
2479 x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
2480 x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
2481#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2482 }
2483
2484 const int ls = bxi->scales[kqsx];
2485 const float d = bxi->d;
2486#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2487 x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2488 x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2489#else
2490 x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2491 x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2492#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2493 }
2494}
2495
2496template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_s(
2497 const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2498 constexpr int nwarps = mmq_get_nwarps_device();
2499 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2500
2501#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2502 int * x_qs = (int *) x_tile;
2503 float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2504#else
2505 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y);
2506 int * x_qs = (int *) x_tile;
2507 float * x_df = (float *) (x_qs + txs.qs);
2508#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2509
2510 constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_S)) / 2;
2511 constexpr int nrows = warp_size / threads_per_row;
2512 const int kqsx = threadIdx.x % threads_per_row;
2513
2514#pragma unroll
2515 for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
2516 int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
2517
2518 if (need_check) {
            i = min(i, i_max);
2520 }
2521
2522 const block_iq2_s * bxi = (const block_iq2_s *) x + kbx0 + i*stride;
2523
2524 const int qs_packed = get_int_b2(bxi->qs, kqsx);
2525 const uint8_t * qs = (const uint8_t *) &qs_packed;
2526
2527 const int qh = bxi->qh[kqsx];
2528
2529 const int signs_packed_32 = get_int_b2(bxi->qs, QK_K/32 + kqsx);
2530 const uint8_t * signs_packed_8 = (const uint8_t *) &signs_packed_32;
2531
2532#pragma unroll
2533 for (int l = 0; l < QR2_S; ++l) {
2534 const int * grid_pos = (const int *)(iq2s_grid + (qs[l] | ((qh << (8-2*l)) & 0x300)));
2535
            const int signs0 = __vcmpne4(((signs_packed_8[l] & 0x03) << 7) | ((signs_packed_8[l] & 0x0C) << 21), 0x00000000);
            const int signs1 = __vcmpne4(((signs_packed_8[l] & 0x30) << 3) | ((signs_packed_8[l] & 0xC0) << 17), 0x00000000);

            const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0);
            const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1);
2541
2542#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2543 x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
2544 x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
2545#else
2546 x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
2547 x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
2548#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2549 }
2550
2551 const int ls = bxi->scales[kqsx];
2552 const float d = bxi->d;
2553#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2554 x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2555 x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2556#else
2557 x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
2558 x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
2559#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2560 }
2561}
2562
2563template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_xxs(
2564 const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2565 constexpr int nwarps = mmq_get_nwarps_device();
2566 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2567
2568#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2569 int * x_qs = (int *) x_tile;
2570 float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2571#else
2572 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y);
2573 int * x_qs = (int *) x_tile;
2574 float * x_df = (float *) (x_qs + txs.qs);
2575#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2576
2577 constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_XXS)) / 2;
2578 constexpr int nrows = warp_size / threads_per_row;
2579 const int kqsx = threadIdx.x % threads_per_row;
2580
2581#pragma unroll
2582 for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
2583 int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
2584
2585 if (need_check) {
            i = min(i, i_max);
2587 }
2588
2589 const block_iq3_xxs * bxi = (const block_iq3_xxs *) x + kbx0 + i*stride;
2590
2591 const int2 q3_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1));
2592 const uint8_t * q3 = (const uint8_t *) &q3_packed;
2593 const uint32_t aux32 = get_int_b2(bxi->qs, QK_K/16 + kqsx);
2594
2595#pragma unroll
2596 for (int l = 0; l < QR3_XXS; ++l) {
2597 const int2 grid_pos = make_int2(iq3xxs_grid[q3[2*l+0]], iq3xxs_grid[q3[2*l+1]]);
2598
2599 const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l)) & 0x7F));
2600
            const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
            const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
2603
2604#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2605 x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
2606 x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h;
2607#else
2608 x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
2609 x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
2610#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2611 }
2612
2613 const int ls = aux32 >> 28;
2614 const float d = bxi->d;
2615#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2616 x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2;
2617#else
2618 x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/2;
2619#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2620 }
2621}
2622
2623template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_s(
2624 const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2625 constexpr int nwarps = mmq_get_nwarps_device();
2626 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2627
2628#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2629 int * x_qs = (int *) x_tile;
2630 float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2631#else
2632 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
2633 int * x_qs = (int *) x_tile;
2634 float * x_df = (float *) (x_qs + txs.qs);
2635#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2636
2637 constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_S)) / 2;
2638 constexpr int nrows = warp_size / threads_per_row;
2639 const int kqsx = threadIdx.x % threads_per_row;
2640
2641#pragma unroll
2642 for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
2643 int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
2644
2645 if (need_check) {
            i = min(i, i_max);
2647 }
2648
2649 const block_iq3_s * bxi = (const block_iq3_s *) x + kbx0 + i*stride;
2650
2651 const int2 qs_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1));
2652 const uint8_t * qs = (const uint8_t *) &qs_packed;
2653
2654 const int qh = bxi->qh[kqsx];
2655
2656 const int signs_packed_32 = get_int_b2(bxi->signs, kqsx);
2657 const uint8_t * signs_packed_8 = (const uint8_t *) &signs_packed_32;
2658
2659#pragma unroll
2660 for (int l = 0; l < QR3_S; ++l) {
2661 const int2 grid_pos = make_int2(
2662 iq3s_grid[qs[2*l+0] | ((qh << (8 - 2*l)) & 0x100)],
2663 iq3s_grid[qs[2*l+1] | ((qh << (7 - 2*l)) & 0x100)]);
2664
            const int signs0 = __vcmpne4(((signs_packed_8[l] & 0x03) << 7) | ((signs_packed_8[l] & 0x0C) << 21), 0x00000000);
            const int signs1 = __vcmpne4(((signs_packed_8[l] & 0x30) << 3) | ((signs_packed_8[l] & 0xC0) << 17), 0x00000000);

            const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
            const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
2670
2671#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2672 x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l;
2673 x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h;
2674#else
2675 x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid_l;
2676 x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid_h;
2677#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2678 }
2679
2680 const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F);
2681 const float d = bxi->d;
2682#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2683 x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d;
2684#else
2685 x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = ls*d;
2686#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2687 }
2688}
2689
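// Load a tile of iq1_s data: the grid values are split into low/high nibbles and the per-block
// scale is stored as half2 (d, d*delta), with delta = -1 +/- IQ1S_DELTA selected by bit 15 of qh.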
2690template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq1_s(
2691 const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2692 constexpr int nwarps = mmq_get_nwarps_device();
2693 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2694
2695#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2696 int * x_qs = (int *) x_tile;
2697 half2 * x_ds = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
2698#else
2699 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
2700 int * x_qs = (int *) x_tile;
2701 half2 * x_ds = (half2 *) (x_qs + txs.qs);
2702#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2703
2704 constexpr int threads_per_row = MMQ_ITER_K / (4 * QR1_S);
2705 constexpr int nrows = warp_size / threads_per_row;
2706 const int kqsx = threadIdx.x % threads_per_row;
2707
2708#pragma unroll
2709 for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
2710 int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
2711
2712 if (need_check) {
            i = min(i, i_max);
2714 }
2715
2716 const block_iq1_s * bxi = (const block_iq1_s *) x + kbx0 + i*stride;
2717
2718 const int qs_packed = get_int_b2(bxi->qs, kqsx);
2719 const uint8_t * qs = (const uint8_t *) &qs_packed;
2720
2721 const int qh = bxi->qh[kqsx];
2722
2723 #pragma unroll
2724 for (int l = 0; l < QR1_S/2; ++l) {
2725 const int grid = iq1s_grid_gpu[qs[l] | (((qh >> (3*l)) & 0x07) << 8)];
2726
2727 const int grid0 = (grid >> 0) & 0x0F0F0F0F;
2728 const int grid1 = (grid >> 4) & 0x0F0F0F0F;
2729
2730#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2731 x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0;
2732 x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1;
2733#else
2734 x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid0;
2735 x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid1;
2736#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2737 }
2738
2739 const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1);
2740 const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);
2741
2742#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2743 x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta);
2744#else
        x_ds[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta);
2746#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2747 }
2748}
2749
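// Load a tile of iq4_xs data: like iq4_nl the 4-bit indices go through the kvalues_iq4nl table,
// but the 6-bit subblock scales are assembled from scales_l/scales_h and offset by -32.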
2750template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_xs(
2751 const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
2752 constexpr int nwarps = mmq_get_nwarps_device();
2753 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2754
2755#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2756 int * x_qs = (int *) x_tile;
2757 float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
2758#else
2759 constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
2760 int * x_qs = (int *) x_tile;
2761 float * x_df = (float *) (x_qs + txs.qs);
2762#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2763
2764 constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_XS);
2765 constexpr int nrows = warp_size / threads_per_row;
2766 const int kqsx = threadIdx.x % threads_per_row;
2767
2768#pragma unroll
2769 for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
2770 int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
2771
2772 if (need_check) {
            i = min(i, i_max);
2774 }
2775
2776 const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride;
2777
2778 const int aux_q4 = get_int_b4(bxi->qs, kqsx);
2779 const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
2780 const int k0 = 8 * (kqsx / 4) + kqsx % 4;
2781
2782#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2783 x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
2784 x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
2785#else
2786 x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
2787 x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 4] = v.y;
2788#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2789 }
2790
2791 constexpr int rows_per_warp = warp_size / 8;
2792#pragma unroll
2793 for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
2794 int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / (MMQ_TILE_NE_K/4);
2795
2796 if (need_check) {
            i = min(i, i_max);
2798 }
2799
2800 const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride;
2801
2802 const float d = __half2float(bxi->d);
2803
2804 const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F)
2805 | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4);
2806
2807#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2808 x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32);
2809#else
2810 x_df[i*(MMQ_TILE_NE_K/4) + i/4 + threadIdx.x % 8] = d * (ls - 32);
2811#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2812 }
2813}
2814
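// Write the accumulated sums back to dst for the dp4a path; ids_dst maps the tile-local j index
// to the destination row and writes beyond i_max/j_max are skipped.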
2815template<int mmq_x, int mmq_y, bool need_check>
2816static __device__ __forceinline__ void mmq_write_back_dp4a(
2817 const float * __restrict__ sum, const int32_t * __restrict__ ids_dst, float * __restrict__ dst,
2818 const int stride, const int i_max, const int j_max) {
2819 constexpr int nwarps = mmq_get_nwarps_device();
2820 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
2821
2822#pragma unroll
2823 for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
2824 const int j = j0 + threadIdx.y;
2825
2826 if (j > j_max) {
2827 return;
2828 }
2829
2830#pragma unroll
2831 for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
2832 const int i = i0 + threadIdx.x;
2833
2834 if (need_check && i > i_max) {
2835 continue;
2836 }
2837
2838 dst[ids_dst[j]*stride + i] = sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
2839 }
2840 }
2841}
2842
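// Write-back for the MMA path: the fragment lanes are mapped back to (i, j) coordinates
// via tile_C::get_i/get_j before scattering into dst.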
2843template<ggml_type type, int mmq_x, int mmq_y, bool need_check>
2844static __device__ __forceinline__ void mmq_write_back_mma(
2845 const float * __restrict__ sum, const int * __restrict__ ids_dst, float * __restrict__ dst,
2846 const int stride, const int i_max, const int j_max) {
2847
2848 constexpr int granularity = mmq_get_granularity_device(mmq_x);
2849 constexpr int nwarps = mmq_get_nwarps_device();
2850
2851#if defined(AMD_MFMA_AVAILABLE)
2852 constexpr int tileC_IJ = mmq_get_granularity_device(0);
2853 typedef tile<tileC_IJ, tileC_IJ, int> tile_C;
2854 constexpr int rows_per_warp = granularity;
2855#else
2856 typedef tile<16, 8, int> tile_C;
2857 constexpr int rows_per_warp = 2 * granularity;
2858#endif // defined(AMD_MFMA_AVAILABLE)
2859 constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
2860
2861 const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I);
2862#if defined(TURING_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
2863 static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y");
2864#else
2865 GGML_UNUSED(nwarps);
2866#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
2867
2868#pragma unroll
2869 for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
2870#pragma unroll
2871 for (int n = 0; n < ntx; ++n) {
2872#pragma unroll
2873 for (int l = 0; l < tile_C::ne; ++l) {
2874 const int j = j0 + (threadIdx.y % ntx) * tile_C::J + tile_C::get_j(l);
2875
2876 if (j > j_max) {
2877 continue;
2878 }
2879
2880 const int i = i0 + n*tile_C::I + tile_C::get_i(l);
2881
2882 if (need_check && i > i_max) {
2883 continue;
2884 }
2885
2886 dst[ids_dst[j]*stride + i] = sum[(j0/tile_C::J + n)*tile_C::ne + l];
2887 }
2888 }
2889 }
2890}
2891
2892// -------------------------------------------------------------------------------------------------------------------------------------
2893
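// Per-type traits: for each quantized type this selects the tile loader and the matching
// MMA and dp4a dot-product kernels used by the generic MMQ kernel.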
2894template <int mmq_x, int mmq_y, bool need_check, ggml_type type>
2895struct mmq_type_traits;
2896
2897template <int mmq_x, int mmq_y, bool need_check>
2898struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_0> {
2899 static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
2900 static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, need_check>;
2901 static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_DS4>;
2902 static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y>;
2903};
2904
2905template <int mmq_x, int mmq_y, bool need_check>
2906struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_1> {
2907 static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ;
2908 static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, need_check>;
2909 static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
2910 static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y>;
2911};
2912
2913template <int mmq_x, int mmq_y, bool need_check>
2914struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_0> {
2915 static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ;
2916 static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, need_check>;
2917 static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
2918 static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
2919};
2920
2921template <int mmq_x, int mmq_y, bool need_check>
2922struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_1> {
2923 static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ;
2924 static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, need_check>;
2925 static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
2926 static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y>;
2927};
2928
2929template <int mmq_x, int mmq_y, bool need_check>
2930struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q8_0> {
2931 static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
2932 static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, need_check>;
2933 static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
2934 static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
2935};
2936
2937template <int mmq_x, int mmq_y, bool need_check>
2938struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
2939 static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ;
2940 static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4<mmq_y, need_check>;
2941 static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
2942 static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
2943};
2944
2945template <int mmq_x, int mmq_y, bool need_check>
2946struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q2_K> {
2947 static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
2948 static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, need_check>;
2949 static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q2_K_q8_1_mma<mmq_x, mmq_y>;
2950 static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a<mmq_x, mmq_y>;
2951};
2952
2953template <int mmq_x, int mmq_y, bool need_check>
2954struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q3_K> {
2955 static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ;
2956 static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, need_check>;
2957 static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
2958 static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a<mmq_x, mmq_y>;
2959};
2960
2961template <int mmq_x, int mmq_y, bool need_check>
2962struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_K> {
2963 static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ;
2964 static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, need_check>;
2965 static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
2966 static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y>;
2967};
2968
2969template <int mmq_x, int mmq_y, bool need_check>
2970struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_K> {
2971 static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ;
2972 static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, need_check>;
2973 static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
2974 static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y>;
2975};
2976
2977template <int mmq_x, int mmq_y, bool need_check>
2978struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q6_K> {
2979 static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ;
2980 static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, need_check>;
2981 static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q6_K_q8_1_mma<mmq_x, mmq_y>;
2982 static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y>;
2983};
2984
2985template <int mmq_x, int mmq_y, bool need_check>
2986struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_XXS> {
2987 static constexpr int vdr = VDR_IQ2_XXS_Q8_1_MMQ;
2988 static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xxs<mmq_y, need_check>;
2989 static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
2990 static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
2991};
2992
2993template <int mmq_x, int mmq_y, bool need_check>
2994struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_XS> {
2995 static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ;
2996 static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xs<mmq_y, need_check>;
2997 static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
2998 static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
2999};
3000
3001template <int mmq_x, int mmq_y, bool need_check>
3002struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_S> {
3003 static constexpr int vdr = VDR_IQ2_S_Q8_1_MMQ;
3004 static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_s<mmq_y, need_check>;
3005 static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
3006 static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
3007};
3008
3009template <int mmq_x, int mmq_y, bool need_check>
3010struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ3_XXS> {
3011 static constexpr int vdr = VDR_IQ3_XXS_Q8_1_MMQ;
3012 static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_xxs<mmq_y, need_check>;
3013 static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
3014 static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
3015};
3016
3017template <int mmq_x, int mmq_y, bool need_check>
3018struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ3_S> {
3019 static constexpr int vdr = VDR_IQ3_S_Q8_1_MMQ;
3020 static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_s<mmq_y, need_check>;
3021 static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
3022 static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
3023};
3024
3025template <int mmq_x, int mmq_y, bool need_check>
3026struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ1_S> {
3027 static constexpr int vdr = VDR_IQ1_S_Q8_1_MMQ;
3028 static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s<mmq_y, need_check>;
3029 static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
3030 static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y>;
3031};
3032
3033template <int mmq_x, int mmq_y, bool need_check>
3034struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ4_NL> {
3035 static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ;
3036 static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl<mmq_y, need_check>;
3037 static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
3038 static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
3039};
3040
3041template <int mmq_x, int mmq_y, bool need_check>
3042struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ4_XS> {
3043 static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ;
3044 static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs<mmq_y, need_check>;
3045 static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
3046 static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
3047};
3048
3049template <ggml_type type, int mmq_x, bool need_check, bool fixup>
3050static __device__ __forceinline__ void mul_mat_q_process_tile(
3051 const char * __restrict__ x, const int offset_x, const int * __restrict__ y,
3052 const int * __restrict__ ids_dst, float * __restrict__ dst, float * __restrict__ tmp_fixup,
3053 const int stride_row_x, const int ncols_y, const int stride_col_dst,
3054 const int tile_x_max_i, const int tile_y_max_j, const int kb0_start, const int kb0_stop) {
3055
3056 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
3057 constexpr int nwarps = mmq_get_nwarps_device();
3058 constexpr int qk = ggml_cuda_type_traits<type>::qk;
3059 constexpr int mmq_y = get_mmq_y_device();
3060 constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, need_check, type>::load_tiles;
3061
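    // Shared memory layout (in ints): the first mmq_x values hold ids_dst_shared as written by mul_mat_q,
    // followed by the y tile of mmq_x*MMQ_TILE_Y_K ints padded up to a multiple of nwarps*warp_size
    // (so the unconditional copy loops below can always write full nwarps*warp_size chunks),
    // followed by the x tile filled by load_tiles.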
3062 extern __shared__ int data_mul_mat_q[];
3063 int * tile_y = data_mul_mat_q + mmq_x;
3064 int * tile_x = tile_y + GGML_PAD(mmq_x*MMQ_TILE_Y_K, nwarps*warp_size);
3065
3066#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
3067 constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_mma;
3068 constexpr mmq_write_back_t write_back = mmq_write_back_mma<type, mmq_x, mmq_y, need_check>;
3069#else
3070 constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_dp4a;
3071 constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, need_check>;
3072#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
3073
3074 constexpr int blocks_per_iter = MMQ_ITER_K / qk;
3075
3076 float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
3077
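    // Each iteration of the loop below consumes MMQ_ITER_K k-values per column: load_tiles stages the
    // x tile once, while the corresponding y data (two block_q8_1_mmq of 128 values each per column) is
    // copied to shared memory and consumed in two halves by the two vec_dot calls.
    // For qk == 32 the per-column offset term kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int)))
    // works out to kb0*9 ints, i.e. one 36 int block_q8_1_mmq per 4 x blocks, and the second half starts
    // another sizeof(block_q8_1_mmq)/sizeof(int) == 36 ints further (both terms are scaled by ncols_y).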
3078 for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) {
3079 load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x);
3080
3081 {
3082 const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int));
3083#pragma unroll
3084 for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) {
3085 int l = l0 + threadIdx.y*warp_size + threadIdx.x;
3086
3087 tile_y[l] = by0[l];
3088 }
3089 }
3090
3091 __syncthreads();
3092
3093 vec_dot(tile_x, tile_y, sum, 0);
3094
3095 __syncthreads();
3096
3097 {
3098 const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int));
3099#pragma unroll
3100 for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) {
3101 int l = l0 + threadIdx.y*warp_size + threadIdx.x;
3102
3103 tile_y[l] = by0[l];
3104 }
3105 }
3106
3107 __syncthreads();
3108
3109 vec_dot(tile_x, tile_y, sum, MMQ_TILE_NE_K);
3110
3111 __syncthreads();
3112 }
3113
3114 if (fixup) {
3115 write_back(sum, ids_dst, tmp_fixup + blockIdx.x*(mmq_x*mmq_y), mmq_y, mmq_y, mmq_x);
3116 } else {
3117 write_back(sum, ids_dst, dst, stride_col_dst, tile_x_max_i, tile_y_max_j);
3118 }
3119}
3120
3121
3122// The mul_mat_q kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598
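// With stream-k the total amount of work (nsamples_y*nchannels_y*ntx*nty output tiles, each consisting
// of blocks_per_ne00 k-blocks) is flattened into one contiguous index space and split evenly across the
// gridDim.x CUDA blocks. Output tiles that a block finishes are written directly to dst; if a block's
// range ends in the middle of a tile, its partial sums for that tile are written to tmp_fixup and added
// to dst afterwards by mul_mat_q_stream_k_fixup.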
3123
3124template <ggml_type type, int mmq_x, bool need_check>
3125#if defined(GGML_USE_HIP)
3126#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
3127 __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2)
3128#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
3129#else
3130#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
3131 __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 1)
3132#else
3133 __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2)
3134#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
3135#endif // defined(GGML_USE_HIP)
3136static __global__ void mul_mat_q(
3137 const char * __restrict__ x, const int * __restrict__ y, const int32_t * __restrict__ ids_dst,
3138 const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup,
3139 const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst,
3140 const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
3141 const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
3142 const int ncols_max) {
3143
3144 // Skip unused template specializations for faster compilation:
3145 if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) {
3146 NO_DEVICE_CODE;
3147 return;
3148 }
3149
3150 constexpr int nwarps = mmq_get_nwarps_device();
3151 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
3152
3153 constexpr int qk = ggml_cuda_type_traits<type>::qk;
3154 constexpr int mmq_y = get_mmq_y_device();
3155
3156 const int ntx = (ncols_max + mmq_x - 1) / mmq_x; // Number of tiles x
3157 const int nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y
3158
3159 // Initialize the ids for writing back data with just the index.
3160 // For regular matrix multiplications this is never changed.
3161 // For MoE the correct indices are loaded from ids_dst.
3162 extern __shared__ int ids_dst_shared[]; // Stored at beginning of shared memory.
3163#pragma unroll
3164 for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
3165 const int j = j0 + threadIdx.y*warp_size + threadIdx.x;
3166
3167 if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
3168 break;
3169 }
3170
3171 ids_dst_shared[j] = j;
3172 }
3173 __syncthreads();
3174
3175    // On non-CDNA AMD GPUs and on CUDA architectures older than Volta stream-k performed worse, so conventional tiling is used instead:
3176#if (defined(GGML_USE_HIP) && !defined(CDNA)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
3177 {
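        // Conventional tiling: blockIdx.x selects the row tile of x/dst, blockIdx.y the column tile and
        // blockIdx.z the flattened (sample, channel) index, so each CUDA block computes one full output tile.
        // This matches the grid set up in launch_mul_mat_q below:
        //     const dim3 block_nums_xy_tiling((nrows_x + mmq_y - 1)/mmq_y, (ncols_max + mmq_x - 1)/mmq_x, nchannels_y*nsamples_y);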
3178 const int wt = blockIdx.z / nchannels_y;
3179 const int zt = blockIdx.z - wt*nchannels_y;
3180 const int jt = blockIdx.y;
3181 const int it = blockIdx.x;
3182
3183 // Defaults for regular matrix multiplication:
3184 int col_low = 0;
3185 int col_high = ncols_dst;
3186 int col_diff = ncols_dst;
3187 int offset_y = wt*stride_sample_y + zt*stride_channel_y;
3188 int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;
3189
3190 if (ids_dst) {
3191 col_low = expert_bounds[zt + 0];
3192 col_high = expert_bounds[zt + 1];
3193 col_diff = col_high - col_low;
3194
3195 offset_y = 0;
3196 offset_dst = 0;
3197
3198 if (jt*mmq_x >= col_diff) {
3199 return;
3200 }
3201
3202 // __syncthreads(); // There is no previous tile that could cause a race condition.
3203#pragma unroll
3204 for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
3205 const int j = j0 + threadIdx.y*warp_size + threadIdx.x;
3206
3207 if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
3208 break;
3209 }
3210
3211 ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];
3212 }
3213 __syncthreads();
3214 }
3215
3216 offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
3217 offset_dst += it*mmq_y;
3218
3219 const int tile_x_max_i = nrows_x - it*mmq_y - 1;
3220 const int tile_y_max_j = col_diff - jt*mmq_x - 1;
3221
3222 const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
3223
3224 constexpr bool fixup = false;
3225 mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
3226 (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
3227 tile_x_max_i, tile_y_max_j, 0, ncols_x/qk);
3228 return;
3229 }
3230#endif // (defined(GGML_USE_HIP) && !defined(CDNA)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
3231
3232 const int64_t blocks_per_ne00 = ncols_x / qk;
3233 constexpr int blocks_per_iter = MMQ_ITER_K / qk;
3234
3235 // kbc == k block continuous, current index in continuous ijk space.
3236 int64_t kbc = (int64_t) blockIdx.x *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
3237 int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
3238
3239 kbc -= (kbc % blocks_per_ne00) % blocks_per_iter;
3240 kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_iter;
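    // Round both boundaries down so that the offset within an output tile is a multiple of blocks_per_iter;
    // the k loop in mul_mat_q_process_tile always advances in whole MMQ_ITER_K iterations.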
3241
3242 // kb0 == k index when doing the matrix multiplication for an output tile.
3243 int kb0_start = kbc % blocks_per_ne00;
3244    int kb0_stop = min(blocks_per_ne00, kb0_start + kbc_stop - kbc);
3245 while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) {
3246 int tmp = kbc;
3247 const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
3248 tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
3249 const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);
3250 tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);
3251 const int zt = tmp / (ntx*blocks_per_ne00);
3252 tmp -= zt * (ntx*blocks_per_ne00);
3253 const int jt = tmp / blocks_per_ne00;
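        // Decompose the flattened index into tile coordinates:
        // it = row tile of x/dst, wt = sample, zt = channel, jt = column tile of y/dst.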
3254
3255 // Defaults for regular matrix multiplication:
3256 int col_low = 0;
3257 int col_high = ncols_dst;
3258 int col_diff = ncols_dst;
3259 int offset_y = wt*stride_sample_y + zt*stride_channel_y;
3260 int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;
3261
3262 if (ids_dst) {
3263 col_low = expert_bounds[zt + 0];
3264 col_high = expert_bounds[zt + 1];
3265 col_diff = col_high - col_low;
3266
3267 offset_y = 0;
3268 offset_dst = 0;
3269
3270 if (jt*mmq_x >= col_diff) {
3271 kbc += blocks_per_ne00;
3272 kbc -= kbc % blocks_per_ne00;
3273
3274 kb0_start = 0;
3275                kb0_stop = min(blocks_per_ne00, kbc_stop - kbc);
3276
3277 continue;
3278 }
3279
3280 __syncthreads();
3281#pragma unroll
3282 for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
3283 const int j = j0 + threadIdx.y*warp_size + threadIdx.x;
3284
3285 if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
3286 break;
3287 }
3288
3289 ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];
3290 }
3291 __syncthreads();
3292 }
3293
3294 offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
3295 offset_dst += it*mmq_y;
3296
3297 const int tile_x_max_i = nrows_x - it*mmq_y - 1;
3298 const int tile_y_max_j = col_diff - jt*mmq_x - 1;
3299
3300 const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
3301
3302 constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
3303 mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
3304 (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
3305 tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
3306
3307 kbc += blocks_per_ne00;
3308 kbc -= kbc % blocks_per_ne00;
3309
3310 kb0_start = 0;
3311        kb0_stop = min(blocks_per_ne00, kbc_stop - kbc);
3312 }
3313
3314 if (kbc >= kbc_stop) {
3315 return;
3316 }
3317
3318 int tmp = kbc;
3319 const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
3320 tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
3321 const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);
3322 tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);
3323 const int zt = tmp / (ntx*blocks_per_ne00);
3324 tmp -= zt * (ntx*blocks_per_ne00);
3325 const int jt = tmp / blocks_per_ne00;
3326
3327 // Defaults for regular matrix multiplication:
3328 int col_low = 0;
3329 int col_high = ncols_dst;
3330 int col_diff = ncols_dst;
3331 int offset_y = wt*stride_sample_y + zt*stride_channel_y;
3332 int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;
3333
3334 if (ids_dst) {
3335 col_low = expert_bounds[zt + 0];
3336 col_high = expert_bounds[zt + 1];
3337 col_diff = col_high - col_low;
3338
3339 offset_y = 0;
3340 offset_dst = 0;
3341
3342 if (jt*mmq_x >= col_diff) {
3343 return;
3344 }
3345
3346 // The memory layout for the fixup buffer is always contiguous, therefore reset ids:
3347 __syncthreads();
3348#pragma unroll
3349 for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
3350 const int j = j0 + threadIdx.y*warp_size + threadIdx.x;
3351
3352 if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
3353 break;
3354 }
3355
3356 ids_dst_shared[j] = j;
3357 }
3358 __syncthreads();
3359 }
3360
3361 offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
3362 offset_dst += it*mmq_y;
3363
3364 const int tile_x_max_i = nrows_x - it*mmq_y - 1;
3365 const int tile_y_max_j = col_diff - jt*mmq_x - 1;
3366
3367 const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
3368
3369 constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
3370 mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
3371 (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
3372 tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
3373}
3374
3375
3376template <ggml_type type, int mmq_x, bool need_check>
3377static __global__ void mul_mat_q_stream_k_fixup(
3378 const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile,
3379 const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst,
3380 const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst,
3381 const int ncols_max) {
3382 constexpr int mmq_y = get_mmq_y_device();
3383 constexpr int qk = ggml_cuda_type_traits<type>::qk;
3384 constexpr int blocks_per_iter = MMQ_ITER_K / qk;
3385 const int64_t blocks_per_ne00 = ncols_x / qk;
3386
3387 constexpr int nwarps = mmq_get_nwarps_device();
3388 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
3389
3390 float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
3391
3392 const int ntx = (ncols_max + mmq_x - 1) / mmq_x;
3393 const int nty = (nrows_x + mmq_y - 1) / mmq_y;
3394
3395 const int bidx0 = blockIdx.x;
3396
3397 // kbc == k block continuous, current index in continuous ijk space.
3398 int64_t kbc0 = (int64_t) bidx0 *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
3399 int64_t kbc0_stop = (int64_t)(bidx0 + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
3400
3401 kbc0 -= (kbc0 % blocks_per_ne00) % blocks_per_iter;
3402 kbc0_stop -= (kbc0_stop % blocks_per_ne00) % blocks_per_iter;
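    // Recompute the same per-block k range as in mul_mat_q: this block only has to do fixup work if its
    // first output tile was started by preceding blocks and it wrote that tile's final k-blocks itself.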
3403
3404 const bool did_not_have_any_data = kbc0 == kbc0_stop;
3405 const bool wrote_beginning_of_tile = kbc0 % blocks_per_ne00 == 0;
3406 const bool did_not_write_last = kbc0/blocks_per_ne00 == kbc0_stop/blocks_per_ne00 && kbc0_stop % blocks_per_ne00 != 0;
3407 if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
3408 return;
3409 }
3410
3411 bool any_fixup = false;
3412
3413 // Iterate over previous blocks and sum up partial sums written to fixup buffer.
3414 // All CUDA blocks that get here must have a previous block that needs a fixup.
3415 int64_t bidx = bidx0 - 1;
3416 int64_t kbc_stop = kbc0;
3417 while(true) {
3418 int64_t kbc = bidx*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
3419 kbc -= (kbc % blocks_per_ne00) % blocks_per_iter;
3420
3421 if (kbc == kbc_stop) { // Did not have any data.
3422 bidx--;
3423 kbc_stop = kbc;
3424 continue;
3425 }
3426
3427 any_fixup = true;
3428
3429#pragma unroll
3430 for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
3431 const int j = j0 + threadIdx.y;
3432
3433#pragma unroll
3434 for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
3435 const int i = i0 + threadIdx.x;
3436
3437 sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
3438 }
3439 }
3440
3441 // If this block started in a previous tile we are done and don't need to combine additional partial results.
3442 if (kbc % blocks_per_ne00 == 0 || kbc/blocks_per_ne00 < kbc0/blocks_per_ne00) {
3443 break;
3444 }
3445 bidx--;
3446 kbc_stop = kbc;
3447 }
3448
3449 if (!any_fixup) {
3450 return;
3451 }
3452
3453 int tmp = kbc0;
3454 const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
3455 tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
3456 const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);
3457 tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);
3458 const int zt = tmp / (ntx*blocks_per_ne00);
3459 tmp -= zt * (ntx*blocks_per_ne00);
3460 const int jt = tmp / blocks_per_ne00;
3461
3462 if (!ids_dst) {
3463 const int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst + it*mmq_y;
3464 dst += offset_dst;
3465
3466 const int i_max = nrows_x - it*mmq_y - 1;
3467 const int j_max = ncols_dst - jt*mmq_x - 1;
3468
3469#pragma unroll
3470 for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
3471 const int j = j0 + threadIdx.y;
3472
3473 if (j > j_max) {
3474 return;
3475 }
3476
3477#pragma unroll
3478 for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
3479 const int i = i0 + threadIdx.x;
3480
3481 if (need_check && i > i_max) {
3482 continue;
3483 }
3484
3485 dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
3486 }
3487 }
3488 return;
3489 }
3490
3491 __shared__ int ids_dst_shared[mmq_x];
3492 const int col_low = expert_bounds[zt + 0];
3493 const int col_high = expert_bounds[zt + 1];
3494 const int col_diff = col_high - col_low;
3495
3496 for (int j = threadIdx.y*warp_size + threadIdx.x; j < mmq_x; j += nwarps*warp_size) {
3497 ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];
3498 }
3499 __syncthreads();
3500
3501 const int offset_dst = it*mmq_y;
3502 dst += offset_dst;
3503
3504 const int i_max = nrows_x - it*mmq_y - 1;
3505 const int j_max = col_diff - jt*mmq_x - 1;
3506
3507#pragma unroll
3508 for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
3509 const int j = j0 + threadIdx.y;
3510
3511 if (j > j_max) {
3512 return;
3513 }
3514
3515#pragma unroll
3516 for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
3517 const int i = i0 + threadIdx.x;
3518
3519 if (need_check && i > i_max) {
3520 continue;
3521 }
3522
3523 dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
3524 }
3525 }
3526}
3527
3528struct mmq_args {
3529 const char * x; ggml_type type_x; const int * y; const int32_t * ids_dst; const int32_t * expert_bounds; float * dst;
3530 int64_t ncols_x; int64_t nrows_x; int64_t ncols_dst; int64_t stride_row_x; int64_t ncols_y; int64_t nrows_dst;
3531 int64_t nchannels_x; int64_t nchannels_y; int64_t stride_channel_x; int64_t stride_channel_y; int64_t stride_channel_dst;
3532 int64_t nsamples_x; int64_t nsamples_y; int64_t stride_sample_x; int64_t stride_sample_y; int64_t stride_sample_dst;
3533 bool use_stream_k; int64_t ncols_max;
3534};
3535
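// Shared memory needed per CUDA block: mmq_x ints for ids_dst_shared, the x tile (MMA layout if tensor
// core/MFMA kernels are used, dp4a layout otherwise) and the y tile, padded to a multiple of
// nwarps*warp_size ints to match the unconditional copy loops in mul_mat_q_process_tile.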
3536template<ggml_type type>
3537static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int cc, const int warp_size, const int nwarps) {
3538 const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
3539 const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
3540 const size_t nbs_ids = mmq_x*sizeof(int);
3541 const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
3542 const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq);
3543 return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int));
3544}
3545
3546template <ggml_type type, int mmq_x>
3547static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
3548 const int id = ggml_cuda_get_device();
3549 const int cc = ggml_cuda_info().devices[id].cc;
3550 const int nsm = ggml_cuda_info().devices[id].nsm;
3551 const int warp_size = ggml_cuda_info().devices[id].warp_size;
3552 const int nwarps = mmq_get_nwarps_host(cc, warp_size);
3553 const int mmq_y = get_mmq_y_host(cc);
3554
3555 const dim3 block_dims(warp_size, nwarps, 1);
3556
3557 const int nbytes_shared = mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc, warp_size, nwarps);
3558
3559 CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, false>), nbytes_shared);
3560 CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, true>), nbytes_shared);
3561
3562 const int nty = (args.nrows_x + mmq_y - 1) / mmq_y;
3563 const int ntx = (args.ncols_max + mmq_x - 1) / mmq_x;
3564 const int ntzw = args.nchannels_y * args.nsamples_y;
3565 const dim3 block_nums_xy_tiling(nty, ntx, ntzw);
3566
3567 GGML_ASSERT(args.nchannels_y % args.nchannels_x == 0);
3568 GGML_ASSERT(args.nsamples_y % args.nsamples_x == 0);
3569 const int channel_ratio = args.nchannels_y / args.nchannels_x;
3570 const int sample_ratio = args.nsamples_y / args.nsamples_x;
3571
3572 if (!args.use_stream_k) {
3573 if (args.nrows_x % mmq_y == 0) {
3574 constexpr bool need_check = false;
3575            mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
3576 (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
3577 args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
3578 channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
3579 sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
3580 args.ncols_max);
3581 } else {
3582 constexpr bool need_check = true;
3583            mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
3584 (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
3585 args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
3586 channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
3587 sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
3588 args.ncols_max);
3589 }
3590 return;
3591 }
3592
3593 const dim3 block_nums_stream_k(nsm, 1, 1);
3594 const bool fixup_needed = ntx*nty*ntzw % nsm != 0;
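    // If the number of output tiles is a multiple of the number of SMs, every stream-k partition ends
    // exactly on a tile boundary and the fixup pass can be skipped entirely.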
3595
3596    ggml_cuda_pool & pool = ctx.pool(id);
3597 ggml_cuda_pool_alloc<float> tmp_fixup(pool);
3598 if (fixup_needed) {
3599        tmp_fixup.alloc(block_nums_stream_k.x * mmq_x*mmq_y);
3600 }
3601
3602 if (args.nrows_x % mmq_y == 0) {
3603 constexpr bool need_check = false;
3604        mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
3605 (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
3606 args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
3607 channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
3608 sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
3609 args.ncols_max);
3610
3611 if (!fixup_needed) {
3612 return;
3613 }
3614
3615        mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
3616 (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
3617 args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
3618 args.ncols_max);
3619 } else {
3620 constexpr bool need_check = true;
3621        mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
3622 (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
3623 args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
3624 channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
3625 sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
3626 args.ncols_max);
3627
3628 if (!fixup_needed) {
3629 return;
3630 }
3631
3632        mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
3633 (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
3634 args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
3635 args.ncols_max);
3636 }
3637}
3638
3639template <ggml_type type>
3640void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
3641 const int id = ggml_cuda_get_device();
3642 const int cc = ggml_cuda_info().devices[id].cc;
3643 const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
3644 const int warp_size = ggml_cuda_info().devices[id].warp_size;
3645 const int nwarps = mmq_get_nwarps_host(cc, warp_size);
3646
3647 const int mmq_x_max = get_mmq_x_max_host(cc);
3648 const int mmq_y = get_mmq_y_host(cc);
3649
3650 int mmq_x_best = 0;
3651 int ntiles_x_best = INT_MAX;
3652
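    // Tile width heuristic: among all mmq_x values that are a multiple of the kernel granularity and whose
    // shared memory requirement fits into smpbo, pick the smallest one that minimizes the number of x tiles;
    // the search stops early once a single tile covers ncols_max.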
3653 for (int mmq_x = 8; mmq_x <= mmq_x_max && ntiles_x_best > 1; mmq_x += 8) {
3654 const int granularity = mmq_get_granularity_host(mmq_x, cc);
3655
3656 if (mmq_x % granularity != 0 || mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc, warp_size, nwarps) > smpbo) {
3657 continue;
3658 }
3659
3660 const int ntiles_x = (args.ncols_max + mmq_x - 1) / mmq_x;
3661
3662 if (ntiles_x < ntiles_x_best) {
3663 mmq_x_best = mmq_x;
3664 ntiles_x_best = ntiles_x;
3665 }
3666 }
3667
3668 switch (mmq_x_best) {
3669 case 8:
3670 launch_mul_mat_q<type, 8>(ctx, args, stream);
3671 break;
3672 case 16:
3673 launch_mul_mat_q<type, 16>(ctx, args, stream);
3674 break;
3675 case 24:
3676 launch_mul_mat_q<type, 24>(ctx, args, stream);
3677 break;
3678 case 32:
3679 launch_mul_mat_q<type, 32>(ctx, args, stream);
3680 break;
3681 case 40:
3682 launch_mul_mat_q<type, 40>(ctx, args, stream);
3683 break;
3684 case 48:
3685 launch_mul_mat_q<type, 48>(ctx, args, stream);
3686 break;
3687 case 56:
3688 launch_mul_mat_q<type, 56>(ctx, args, stream);
3689 break;
3690 case 64:
3691 launch_mul_mat_q<type, 64>(ctx, args, stream);
3692 break;
3693 case 72:
3694 launch_mul_mat_q<type, 72>(ctx, args, stream);
3695 break;
3696 case 80:
3697 launch_mul_mat_q<type, 80>(ctx, args, stream);
3698 break;
3699 case 88:
3700 launch_mul_mat_q<type, 88>(ctx, args, stream);
3701 break;
3702 case 96:
3703 launch_mul_mat_q<type, 96>(ctx, args, stream);
3704 break;
3705 case 104:
3706 launch_mul_mat_q<type, 104>(ctx, args, stream);
3707 break;
3708 case 112:
3709 launch_mul_mat_q<type, 112>(ctx, args, stream);
3710 break;
3711 case 120:
3712 launch_mul_mat_q<type, 120>(ctx, args, stream);
3713 break;
3714 case 128:
3715 launch_mul_mat_q<type, 128>(ctx, args, stream);
3716 break;
3717 default:
3718            fprintf(stderr, "mmq_x_best=%d\n", mmq_x_best);
3719 GGML_ABORT("fatal error");
3720 break;
3721 }
3722}
3723
3724#define DECL_MMQ_CASE(type) \
3725 template void mul_mat_q_case<type>(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) \
3726
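// Extern declarations for the explicit instantiations of mul_mat_q_case; the matching DECL_MMQ_CASE
// definitions are expected to live in separate per-type compilation units so that the heavily templated
// kernels can be compiled in parallel.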
3727extern DECL_MMQ_CASE(GGML_TYPE_Q4_0);
3728extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);
3729extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
3730extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
3731extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
3732extern DECL_MMQ_CASE(GGML_TYPE_MXFP4);
3733extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
3734extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
3735extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
3736extern DECL_MMQ_CASE(GGML_TYPE_Q5_K);
3737extern DECL_MMQ_CASE(GGML_TYPE_Q6_K);
3738extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS);
3739extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XS);
3740extern DECL_MMQ_CASE(GGML_TYPE_IQ2_S);
3741extern DECL_MMQ_CASE(GGML_TYPE_IQ3_XXS);
3742extern DECL_MMQ_CASE(GGML_TYPE_IQ3_S);
3743extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S);
3744extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
3745extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);
3746
3747// -------------------------------------------------------------------------------------------------------------------------
3748
3749void ggml_cuda_mul_mat_q(
3750 ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
3751
3752void ggml_cuda_op_mul_mat_q(
3753 ggml_backend_cuda_context & ctx,
3754 const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
3755 const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
3756 const int64_t src1_padded_row_size, cudaStream_t stream);
3757
3758bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11);
3759