#pragma once

#include "mma.cuh"
#include "common.cuh"

using namespace ggml_cuda_mma;

#define MMF_ROWS_PER_BLOCK 32

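// Optional data for the compact-ids mul_mat_id path: per-expert compacted src/dst index
// lists, the [start, end) bounds of each expert's entries, the number of experts, and the
// src1 stride (sis1) used to decode a compact entry into a (token, channel) pair.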
struct mmf_ids_data {
    const int32_t * ids_src_compact = nullptr;
    const int32_t * ids_dst_compact = nullptr;
    const int32_t * expert_bounds_dev = nullptr;
    int n_experts = 0;
    int sis1 = 0;
};

void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);

bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, const size_t * src0_nb, const int src1_ncols, bool mul_mat_id);

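// Tensor core matrix multiplication of a float/half2/nv_bfloat162 matrix x with float activations y.
// With has_ids the kernel also handles mul_mat_id: the ids tensor selects, per destination
// column, which tokens are routed to the expert (channel of x) processed by this block.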
template <typename T, int rows_per_block, int cols_per_block, int nwarps, bool has_ids>
__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
static __global__ void mul_mat_f(
        const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
        const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
        const int stride_col_id, const int stride_row_id,
        const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
        const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
    constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
    constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();

    if (!I_16_supported && !I_32_supported) {
        NO_DEVICE_CODE;
        return;
    }

    constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.

    typedef tile<I_preferred, 8, T>     tile_A;
    typedef tile<8,           8, T>     tile_B;
    typedef tile<I_preferred, 8, float> tile_C;

    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
    constexpr int tile_k_padded = warp_size + 4;
    constexpr int ntA = rows_per_block / tile_A::I;
    constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;

    const int row0 = blockIdx.x * rows_per_block;

    int expert_idx = 0;
    int col_base = 0;

    const int channel_dst = has_ids ? 0 : blockIdx.y;

    if constexpr (has_ids) {
        // experts + tiles of ncols_dst are packed in the y dimension
        int col_tiles = (ncols_dst_total + cols_per_block - 1) / cols_per_block;
        const int nchannels_x = gridDim.y / col_tiles;
        const int tile_idx = blockIdx.y / nchannels_x;
        expert_idx = blockIdx.y - tile_idx * nchannels_x;
        col_base = tile_idx * cols_per_block;
    }

    const int channel_x  = has_ids ? expert_idx : (channel_dst / channel_ratio);
    const int channel_y  = channel_dst;
    const int sample_dst = blockIdx.z;
    const int sample_x   = sample_dst / sample_ratio;
    const int sample_y   = sample_dst;

    x   += int64_t(sample_x)  *stride_sample_x   + channel_x  *stride_channel_x   + row0*stride_row;
    y   += int64_t(sample_y)  *stride_sample_y   + (has_ids ? 0 : channel_y  *stride_channel_y);
    dst += int64_t(sample_dst)*stride_sample_dst + (has_ids ? 0 : channel_dst*stride_channel_dst);

    if constexpr (has_ids) {
        constexpr int y_stride_scale = std::is_same_v<T, float> ? 1 : 2;
        const int64_t col_offset = col_base;
        y   += col_offset * stride_col_y * y_stride_scale;
        dst += col_offset * stride_col_dst;
        ids += col_offset * stride_row_id;
    }

    const float2 * y2 = (const float2 *) y;

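    // Shared memory layout: with has_ids a padded slot_map of cols_per_block ints comes first,
    // followed by the buffer used both for staging tiles and for the final per-warp reduction.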
    extern __shared__ char data_mmv[];

    char * shmem_base = data_mmv;
    int  * slot_map = (int *) shmem_base;
    char * compute_base = has_ids ? (shmem_base + GGML_PAD(cols_per_block, 16) * sizeof(int)) : shmem_base;

    tile_C C[ntA][ntB];

    T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded);

    if constexpr (has_ids) {
        int found = 0;

        for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
            const int j = j0 + threadIdx.y;

            if (threadIdx.x == 0) {
                slot_map[j] = -1;
            }

            if (col_base + j >= ncols_dst_total) {
                continue;
            }

            const int32_t * __restrict__ id_row = ids + j*stride_row_id;

            for (int k = threadIdx.x; k < nchannels_dst; k += warp_size) {
                int match = id_row[k*stride_col_id] == expert_idx;

                if (match) {
                    slot_map[j] = k;
                    found = 1;
                    break;
                }
            }
        }

        if (!__syncthreads_or(found)) {
            return;
        }
    }


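    // Loop over the K dimension in steps of warp_size per warp: stage rows of x in shared
    // memory, load them as MMA A tiles via ldmatrix, then stream the matching columns of y
    // as B tiles and accumulate into C.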
    for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
        tile_A A[ntA][warp_size / tile_A::J];
#pragma unroll
        for (int itA = 0; itA < ntA; ++itA) {
#pragma unroll
            for (int i = 0; i < tile_A::I; ++i) {
                tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col];
            }
#pragma unroll
            for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
                load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
            }
        }

#pragma unroll
        for (int itB = 0; itB < ntB; ++itB) {
            if constexpr (std::is_same_v<T, float>) {
#pragma unroll
                for (int j0 = 0; j0 < tile_B::I; ++j0) {
                    const int j = j0 + itB*tile_B::I;

                    if constexpr (!has_ids) {
                        tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
                    } else {
                        const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0;
                        tile_xy[j0*tile_k_padded + threadIdx.x] = valid ? y[slot_map[j]*stride_channel_y + j*stride_col_y + col] : 0.0f;
                    }
                }
            } else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
#pragma unroll
                for (int j0 = 0; j0 < tile_B::I; ++j0) {
                    const int j = j0 + itB*tile_B::I;

                    if constexpr (!has_ids) {
                        const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
                        tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
                    } else {
                        const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0;
                        float2 tmp = valid ? *(const float2 *) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f);
                        tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
                    }
                }
            } else {
                static_assert(std::is_same_v<T, void>, "unsupported type");
            }
#pragma unroll
            for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
                tile_B B;
                load_ldmatrix(B, tile_xy + k0, tile_k_padded);
#pragma unroll
                for (int itA = 0; itA < ntA; ++itA) {
                    mma(C[itA][itB], A[itA][k0/tile_B::J], B);
                }
            }
        }
    }

    float * buf_iw = (float *) compute_base;
    constexpr int kiw = nwarps*rows_per_block + 4;

    if (nwarps > 1) {
        __syncthreads();
    }
#pragma unroll
    for (int itB = 0; itB < ntB; ++itB) {
#pragma unroll
        for (int itA = 0; itA < ntA; ++itA) {
#pragma unroll
            for (int l = 0; l < tile_C::ne; ++l) {
                const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);
                const int j = itB*tile_C::J + tile_C::get_j(l);
                buf_iw[j*kiw + i] = C[itA][itB].x[l];
            }
        }
    }

    if (nwarps > 1) {
        __syncthreads();
    }

#pragma unroll
    for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
        const int j = j0 + threadIdx.y;

        if (j0 + nwarps > cols_per_block && j >= cols_per_block) {
            return;
        }

        float sum = 0.0f;
        static_assert(rows_per_block == warp_size, "need loop/check");
#pragma unroll
        for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
            const int i = i0 + threadIdx.x;

            sum += buf_iw[j*kiw + i];
        }

        if constexpr (!has_ids) {
            dst[j*stride_col_dst + row0 + threadIdx.x] = sum;
        } else {
            const int slot = (j < cols_per_block) ? slot_map[j] : -1;
            if (slot >= 0 && (col_base + j) < ncols_dst_total) {
                dst[slot*stride_channel_dst + j*stride_col_dst + row0 + threadIdx.x] = sum;
            }
        }
    }
#else
    GGML_UNUSED_VARS(x, y, ids, dst,
        ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
        stride_col_id, stride_row_id,
        channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
        sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
    NO_DEVICE_CODE;
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
}

// This kernel is for larger batch sizes of mul_mat_id.
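// It expects ids_src_compact/ids_dst_compact with the routed (token, channel) pairs of each
// expert packed contiguously; expert_bounds gives each expert's [start, end) range in those
// arrays, and sis1_fd/nch_fd are fastdiv constants used to decode the packed entries.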
template <typename T, int rows_per_block, int cols_per_block, int nwarps>
__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
static __global__ void mul_mat_f_ids(
        const T * __restrict__ x, const float * __restrict__ y,
        const int32_t * __restrict__ ids_src_compact, const int32_t * __restrict__ ids_dst_compact,
        const int32_t * __restrict__ expert_bounds, float * __restrict__ dst,
        const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
        const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
        const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
        const uint3 sis1_fd, const uint3 nch_fd) {
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
    constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
    constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();

    if (!I_16_supported && !I_32_supported) {
        NO_DEVICE_CODE;
        return;
    }

    constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.

    typedef tile<I_preferred, 8, T>     tile_A;
    typedef tile<8,           8, T>     tile_B;
    typedef tile<I_preferred, 8, float> tile_C;

    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
    constexpr int tile_k_padded = warp_size + 4;
    constexpr int ntA = rows_per_block / tile_A::I;
    constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;

    const int row0 = blockIdx.x * rows_per_block;

    const int expert_idx   = blockIdx.y;
    const int expert_start = expert_bounds[expert_idx];
    const int expert_end   = expert_bounds[expert_idx + 1];
    const int ncols_expert = expert_end - expert_start;

    const int tiles_for_expert = (ncols_expert + cols_per_block - 1) / cols_per_block;
    const int tile_idx = blockIdx.z;
    if (tile_idx >= tiles_for_expert) {
        return;
    }

    const int col_base = tile_idx * cols_per_block;

    GGML_UNUSED(channel_ratio);

    const int channel_x  = expert_idx;
    const int sample_dst = 0;
    const int sample_x   = sample_dst / sample_ratio;
    const int sample_y   = sample_dst;

    x   += int64_t(sample_x)  *stride_sample_x + channel_x*stride_channel_x + row0*stride_row;
    y   += int64_t(sample_y)  *stride_sample_y;
    dst += int64_t(sample_dst)*stride_sample_dst;

    const int32_t * ids_src_expert = ids_src_compact + expert_start;
    const int32_t * ids_dst_expert = ids_dst_compact + expert_start;

    extern __shared__ char data_mmv[];
    char * compute_base = data_mmv;

    //const float2 * y2 = (const float2 *) y;

    tile_C C[ntA][ntB];

    T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded);

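    // Same K loop as in mul_mat_f, but the columns of y are gathered through the compact id
    // list; the gather for the next B tile is double-buffered so its loads can overlap with
    // the MMA instructions issued for the current tile.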
    for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
        tile_A A[ntA][warp_size / tile_A::J];
#pragma unroll
        for (int itA = 0; itA < ntA; ++itA) {
#pragma unroll
            for (int i = 0; i < tile_A::I; ++i) {
                tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col];
            }
#pragma unroll
            for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
                load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
            }
        }

        if constexpr (std::is_same_v<T, float>) {
            float vals_buf[2][tile_B::I];
            auto gather_tile = [&](int tile_idx_local, float * vals) {
#pragma unroll
                for (int j0 = 0; j0 < tile_B::I; ++j0) {
                    const int j = j0 + tile_idx_local*tile_B::I;
                    const int global_j = col_base + j;
                    float val = 0.0f;
                    if (j < cols_per_block && global_j < ncols_expert) {
                        const int src_entry = ids_src_expert[global_j];
                        const uint2 qrm = fast_div_modulo((uint32_t) src_entry, sis1_fd);
                        const int token   = (int) qrm.x;
                        const int channel = (int) qrm.y;
                        if (token < ncols_dst_total) {
                            val = y[channel*stride_channel_y + token*stride_col_y + col];
                        }
                    }
                    vals[j0] = val;
                }
            };

            gather_tile(0, vals_buf[0]);

            int curr_buf = 0;
            int next_buf = 1;
#pragma unroll
            for (int itB = 0; itB < ntB; ++itB) {
#pragma unroll
                for (int j0 = 0; j0 < tile_B::I; ++j0) {
                    tile_xy[j0*tile_k_padded + threadIdx.x] = vals_buf[curr_buf][j0];
                }

                if (itB + 1 < ntB) {
                    gather_tile(itB + 1, vals_buf[next_buf]);
                }

#pragma unroll
                for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
                    tile_B B;
                    load_ldmatrix(B, tile_xy + k0, tile_k_padded);
#pragma unroll
                    for (int itA = 0; itA < ntA; ++itA) {
                        mma(C[itA][itB], A[itA][k0/tile_B::J], B);
                    }
                }

                if (itB + 1 < ntB) {
                    curr_buf ^= 1;
                    next_buf ^= 1;
                }
            }
        } else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
            float2 vals_buf[2][tile_B::I];
            auto gather_tile = [&](int tile_idx_local, float2 * vals) {
#pragma unroll
                for (int j0 = 0; j0 < tile_B::I; ++j0) {
                    const int j = j0 + tile_idx_local*tile_B::I;
                    const int global_j = col_base + j;
                    float2 tmp = make_float2(0.0f, 0.0f);
                    if (j < cols_per_block && global_j < ncols_expert) {
                        const int src_entry = ids_src_expert[global_j];
                        const uint2 qrm = fast_div_modulo((uint32_t) src_entry, sis1_fd);
                        const int token   = (int) qrm.x;
                        const int channel = (int) qrm.y;
                        if (token < ncols_dst_total) {
                            tmp = *(const float2 *) &y[channel*stride_channel_y + 2*(token*stride_col_y + col)];
                        }
                    }
                    vals[j0] = tmp;
                }
            };

            if (ntB > 0) {
                gather_tile(0, vals_buf[0]);
            }

            int curr_buf = 0;
            int next_buf = 1;
#pragma unroll
            for (int itB = 0; itB < ntB; ++itB) {
#pragma unroll
                for (int j0 = 0; j0 < tile_B::I; ++j0) {
                    const float2 tmp = vals_buf[curr_buf][j0];
                    tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
                }

                if (itB + 1 < ntB) {
                    gather_tile(itB + 1, vals_buf[next_buf]);
                }

#pragma unroll
                for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
                    tile_B B;
                    load_ldmatrix(B, tile_xy + k0, tile_k_padded);
#pragma unroll
                    for (int itA = 0; itA < ntA; ++itA) {
                        mma(C[itA][itB], A[itA][k0/tile_B::J], B);
                    }
                }

                if (itB + 1 < ntB) {
                    curr_buf ^= 1;
                    next_buf ^= 1;
                }
            }
        } else {
            static_assert(std::is_same_v<T, void>, "unsupported type");
        }
    }

    float * buf_iw = (float *) compute_base;
    constexpr int kiw = nwarps*rows_per_block + 4;

    if (nwarps > 1) {
        __syncthreads();
    }
#pragma unroll
    for (int itB = 0; itB < ntB; ++itB) {
#pragma unroll
        for (int itA = 0; itA < ntA; ++itA) {
#pragma unroll
            for (int l = 0; l < tile_C::ne; ++l) {
                const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);
                const int j = itB*tile_C::J + tile_C::get_j(l);
                buf_iw[j*kiw + i] = C[itA][itB].x[l];
            }
        }
    }

    if (nwarps > 1) {
        __syncthreads();
    }

#pragma unroll
    for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
        const int j = j0 + threadIdx.y;

        if (j0 + nwarps > cols_per_block && j >= cols_per_block) {
            return;
        }

        float sum = 0.0f;
        static_assert(rows_per_block == warp_size, "need loop/check");
#pragma unroll
        for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
            const int i = i0 + threadIdx.x;

            sum += buf_iw[j*kiw + i];
        }

        const int global_j = col_base + j;
        if (j < cols_per_block && global_j < ncols_expert && nchannels_dst > 0) {
            const int dst_entry = ids_dst_expert[global_j];
            const uint2 qrm = fast_div_modulo((uint32_t) dst_entry, nch_fd);
            const int token = (int) qrm.x;
            if (token < ncols_dst_total) {
                const int slot = (int) qrm.y;
                dst[slot*stride_channel_dst + token*stride_col_dst + row0 + threadIdx.x] = sum;
            }
        }
    }
#else
    GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst,
        ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
        channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
        sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
    NO_DEVICE_CODE;
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
}

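// Picks one of the three launch paths: the compact-ids kernel for mul_mat_id with many
// columns, the has_ids variant of mul_mat_f otherwise, or the plain kernel when no ids
// tensor is given.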
template<typename T, int cols_per_block, int nwarps>
static inline void mul_mat_f_switch_ids(
        const T * x, const float * y, const int32_t * ids, float * dst,
        const int64_t ncols_x, const int64_t ncols_dst, const int64_t nchannels_dst,
        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
        const int64_t stride_col_id, const int64_t stride_row_id,
        const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
        const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
        const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream,
        const mmf_ids_data * ids_data) {
    const bool has_ids_data = ids_data && ids_data->ids_src_compact;

    // Use the compact-ids kernel only for larger tiles; for small ncols_dst (< 16)
    // we prefer the normal mul_mat_f path with has_ids=true.
    if (has_ids_data && ncols_dst > 16) {
        const int max_tiles = (int) ((ncols_dst + cols_per_block - 1) / cols_per_block);
        if (max_tiles == 0) {
            return;
        }
        dim3 block_nums_ids(block_nums.x, ids_data->n_experts, max_tiles);

        const uint3 sis1_fd = ids_data->sis1 > 0 ? init_fastdiv_values((uint32_t) ids_data->sis1) : make_uint3(0, 0, 1);
        const uint3 nch_fd  = init_fastdiv_values((uint32_t) nchannels_dst);

        mul_mat_f_ids<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
            (x, y, ids_data->ids_src_compact, ids_data->ids_dst_compact, ids_data->expert_bounds_dev, dst,
             ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
             channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
             sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst,
             sis1_fd, nch_fd);
    } else if (ids) {
        const int64_t col_tiles = (ncols_dst + cols_per_block - 1) / cols_per_block;
        dim3 block_nums_ids = block_nums;
        block_nums_ids.y *= col_tiles;

        mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
            (x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
             stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
             sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
    } else {
        mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
            (x, y, ids, dst, ncols_x, cols_per_block, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
             stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
             sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
    }
}

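// Host-side launcher: picks the number of warps per block that minimizes the number of
// iterations over ncols_x, computes the shared memory requirement, and dispatches to the
// matching mul_mat_f_switch_ids instantiation.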
template <typename T, int cols_per_block>
void mul_mat_f_cuda(
        const T * x, const float * y, const int32_t * ids, float * dst,
        const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
        const int64_t stride_col_id, const int64_t stride_row_id,
        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
        cudaStream_t stream, const mmf_ids_data * ids_data) {
    typedef tile<16, 8, T> tile_A_16;
    typedef tile<32, 8, T> tile_A_32;
    typedef tile< 8, 8, T> tile_B;

    GGML_ASSERT(ncols_x      % 2 == 0);
    GGML_ASSERT(stride_row   % 2 == 0);
    GGML_ASSERT(stride_col_y % 2 == 0);
    GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
    GGML_ASSERT(       nsamples_dst  % nsamples_x  == 0);
    const int64_t channel_ratio = nchannels_dst / nchannels_x;
    const int64_t sample_ratio  = nsamples_dst  / nsamples_x;

    const int device    = ggml_cuda_get_device();
    const int cc        = ggml_cuda_info().devices[device].cc;
    const int warp_size = ggml_cuda_info().devices[device].warp_size;

    int64_t nwarps_best    = 1;
    int64_t niter_best     = (ncols_x + warp_size*2 - 1) / (warp_size*2);
    int64_t max_block_size = 256;
    for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) {
        const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2);
        if (niter < niter_best) {
            niter_best  = niter;
            nwarps_best = nwarps;
        }
    }

    constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
    const int nbytes_shared_iter    = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4;
    const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4;
    const int nbytes_shared         = std::max(nbytes_shared_iter, nbytes_shared_combine);
    const int nbytes_slotmap        = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
    const int nbytes_shared_total   = nbytes_shared + nbytes_slotmap;
    const int64_t grid_y = ids ? nchannels_x : nchannels_dst;

    const dim3 block_nums(nrows_x/rows_per_block, grid_y, nsamples_dst);
    const dim3 block_dims(warp_size, nwarps_best, 1);

    switch (nwarps_best) {
        case 1: {
            mul_mat_f_switch_ids<T, cols_per_block, 1>(
                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
                ids_data);
        } break;
        case 2: {
            mul_mat_f_switch_ids<T, cols_per_block, 2>(
                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
                ids_data);
        } break;
        case 3: {
            mul_mat_f_switch_ids<T, cols_per_block, 3>(
                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
                ids_data);
        } break;
        case 4: {
            mul_mat_f_switch_ids<T, cols_per_block, 4>(
                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
                ids_data);
        } break;
        case 5: {
            mul_mat_f_switch_ids<T, cols_per_block, 5>(
                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
                ids_data);
        } break;
        case 6: {
            mul_mat_f_switch_ids<T, cols_per_block, 6>(
                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
                ids_data);
        } break;
        case 7: {
            mul_mat_f_switch_ids<T, cols_per_block, 7>(
                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
                ids_data);
        } break;
        case 8: {
            mul_mat_f_switch_ids<T, cols_per_block, 8>(
                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
                ids_data);
        } break;
        default: {
            GGML_ABORT("fatal error");
        } break;
    }

    GGML_UNUSED_VARS(nchannels_y);
}

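// Dispatches on the number of destination columns; for mul_mat_id with more than 16 columns
// the 16-column instantiation is used and the kernels tile over the remaining columns.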
template <typename T>
static void mul_mat_f_switch_cols_per_block(
        const T * x, const float * y, const int32_t * ids, float * dst,
        const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
        const int64_t stride_col_id, const int stride_row_id,
        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
        cudaStream_t stream, const mmf_ids_data * ids_data) {

    const int ncols_case = (ids && ncols_dst > 16) ? 16 : ncols_dst;

    GGML_ASSERT(ids || ncols_dst <= 16);

    switch (ncols_case) {
        case 1: {
            mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 2: {
            mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 3: {
            mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 4: {
            mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 5: {
            mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 6: {
            mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 7: {
            mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 8: {
            mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 9: {
            mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 10: {
            mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 11: {
            mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 12: {
            mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 13: {
            mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 14: {
            mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 15: {
            mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 16: {
            mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        default: {
            GGML_ABORT("fatal error");
        } break;
    }
}

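// Declares (extern) or defines the explicit template instantiations of mul_mat_f_cuda for
// float, half2, and nv_bfloat162 at each supported compile-time ncols_dst.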
#define DECL_MMF_CASE_HELPER(T, ncols_dst) \
    template void mul_mat_f_cuda<T, ncols_dst>( \
        const T * x, const float * y, const int32_t * ids, float * dst, \
        const int64_t ncols_x, const int64_t nrows_x, int64_t ncols_dst_total, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \
        const int64_t stride_col_id, const int64_t stride_row_id, \
        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, \
        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, \
        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \
        cudaStream_t stream, const mmf_ids_data * ids_data);

#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
#define DECL_MMF_CASE_EXTERN(ncols_dst) \
    extern DECL_MMF_CASE_HELPER(float, ncols_dst) \
    extern DECL_MMF_CASE_HELPER(half2, ncols_dst) \
    extern DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst)

#define DECL_MMF_CASE(ncols_dst) \
    DECL_MMF_CASE_HELPER(float, ncols_dst) \
    DECL_MMF_CASE_HELPER(half2, ncols_dst) \
    DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst)

DECL_MMF_CASE_EXTERN(1);
DECL_MMF_CASE_EXTERN(2);
DECL_MMF_CASE_EXTERN(3);
DECL_MMF_CASE_EXTERN(4);
DECL_MMF_CASE_EXTERN(5);
DECL_MMF_CASE_EXTERN(6);
DECL_MMF_CASE_EXTERN(7);
DECL_MMF_CASE_EXTERN(8);
DECL_MMF_CASE_EXTERN(9);
DECL_MMF_CASE_EXTERN(10);
DECL_MMF_CASE_EXTERN(11);
DECL_MMF_CASE_EXTERN(12);
DECL_MMF_CASE_EXTERN(13);
DECL_MMF_CASE_EXTERN(14);
DECL_MMF_CASE_EXTERN(15);
DECL_MMF_CASE_EXTERN(16);
#else
#define DECL_MMF_CASE(ncols_dst)
#endif