1#include "mmvq.cuh"
2#include "quantize.cuh"
3#include "unary.cuh"
4#include "vecdotq.cuh"
5
6#include <cstdint>
7
8typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs);
9
10static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
11 switch (type) {
12 case GGML_TYPE_Q4_0: return vec_dot_q4_0_q8_1;
13 case GGML_TYPE_Q4_1: return vec_dot_q4_1_q8_1;
14 case GGML_TYPE_Q5_0: return vec_dot_q5_0_q8_1;
15 case GGML_TYPE_Q5_1: return vec_dot_q5_1_q8_1;
16 case GGML_TYPE_Q8_0: return vec_dot_q8_0_q8_1;
17 case GGML_TYPE_MXFP4: return vec_dot_mxfp4_q8_1;
18 case GGML_TYPE_Q2_K: return vec_dot_q2_K_q8_1;
19 case GGML_TYPE_Q3_K: return vec_dot_q3_K_q8_1;
20 case GGML_TYPE_Q4_K: return vec_dot_q4_K_q8_1;
21 case GGML_TYPE_Q5_K: return vec_dot_q5_K_q8_1;
22 case GGML_TYPE_Q6_K: return vec_dot_q6_K_q8_1;
23 case GGML_TYPE_IQ2_XXS: return vec_dot_iq2_xxs_q8_1;
24 case GGML_TYPE_IQ2_XS: return vec_dot_iq2_xs_q8_1;
25 case GGML_TYPE_IQ2_S: return vec_dot_iq2_s_q8_1;
26 case GGML_TYPE_IQ3_XXS: return vec_dot_iq3_xxs_q8_1;
27 case GGML_TYPE_IQ1_S: return vec_dot_iq1_s_q8_1;
28 case GGML_TYPE_IQ1_M: return vec_dot_iq1_m_q8_1;
29 case GGML_TYPE_IQ4_NL: return vec_dot_iq4_nl_q8_1;
30 case GGML_TYPE_IQ4_XS: return vec_dot_iq4_xs_q8_1;
31 case GGML_TYPE_IQ3_S: return vec_dot_iq3_s_q8_1;
32 default: return nullptr;
33 }
34}
35
36static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
37 switch (type) {
38 case GGML_TYPE_Q4_0: return VDR_Q4_0_Q8_1_MMVQ;
39 case GGML_TYPE_Q4_1: return VDR_Q4_1_Q8_1_MMVQ;
40 case GGML_TYPE_Q5_0: return VDR_Q5_0_Q8_1_MMVQ;
41 case GGML_TYPE_Q5_1: return VDR_Q5_1_Q8_1_MMVQ;
42 case GGML_TYPE_Q8_0: return VDR_Q8_0_Q8_1_MMVQ;
43 case GGML_TYPE_MXFP4: return VDR_MXFP4_Q8_1_MMVQ;
44 case GGML_TYPE_Q2_K: return VDR_Q2_K_Q8_1_MMVQ;
45 case GGML_TYPE_Q3_K: return VDR_Q3_K_Q8_1_MMVQ;
46 case GGML_TYPE_Q4_K: return VDR_Q4_K_Q8_1_MMVQ;
47 case GGML_TYPE_Q5_K: return VDR_Q5_K_Q8_1_MMVQ;
48 case GGML_TYPE_Q6_K: return VDR_Q6_K_Q8_1_MMVQ;
49 case GGML_TYPE_IQ2_XXS: return VDR_IQ2_XXS_Q8_1_MMVQ;
50 case GGML_TYPE_IQ2_XS: return VDR_IQ2_XS_Q8_1_MMVQ;
51 case GGML_TYPE_IQ2_S: return VDR_IQ2_S_Q8_1_MMVQ;
52 case GGML_TYPE_IQ3_XXS: return VDR_IQ3_XXS_Q8_1_MMVQ;
53 case GGML_TYPE_IQ3_S: return VDR_IQ3_S_Q8_1_MMVQ;
54 case GGML_TYPE_IQ4_NL: return VDR_IQ4_NL_Q8_1_MMVQ;
55 case GGML_TYPE_IQ4_XS: return VDR_IQ4_XS_Q8_1_MMVQ;
56 default: return 1;
57 }
58}
59
60enum mmvq_parameter_table_id {
61 MMVQ_PARAMETERS_GENERIC = 0,
62 MMVQ_PARAMETERS_GCN,
63 MMVQ_PARAMETERS_RDNA2
64};
65
66static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {
67#if defined(RDNA2) || defined(RDNA3) || defined(RDNA4)
68 return MMVQ_PARAMETERS_RDNA2;
69#elif defined(GCN) || defined(CDNA)
70 return MMVQ_PARAMETERS_GCN;
71#else
72 return MMVQ_PARAMETERS_GENERIC;
73#endif
74}
75
76static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
77 if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
78 return MMVQ_PARAMETERS_RDNA2;
79 }
80 if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
81 return MMVQ_PARAMETERS_GCN;
82 }
83 return MMVQ_PARAMETERS_GENERIC;
84}
85
86static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_parameter_table_id table_id) {
87 if (table_id == MMVQ_PARAMETERS_GENERIC) {
88 switch (ncols_dst) {
89 case 1:
90 case 2:
91 case 3:
92 case 4:
93 return 4;
94 case 5:
95 case 6:
96 case 7:
97 case 8:
98 return 2;
99 default:
100 return 1;
101 }
102 } else if (table_id == MMVQ_PARAMETERS_GCN) {
103 switch (ncols_dst) {
104 case 1:
105 case 2:
106 case 3:
107 case 4:
108 return 2;
109 case 5:
110 case 6:
111 case 7:
112 case 8:
113 default:
114 return 1;
115 }
116 }
117 return 1;
118}
119
120static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id) {
121 if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) {
122 switch (ncols_dst) {
123 case 1:
124 return 1;
125 case 2:
126 case 3:
127 case 4:
128 case 5:
129 case 6:
130 case 7:
131 case 8:
132 return 2;
133 default:
134 return 1;
135 }
136 }
137 return 1;
138}
139
140// tell the compiler to use as many registers as it wants, see nwarps definition below
141template <ggml_type type, int ncols_dst, bool has_fusion>
142__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
143static __global__ void mul_mat_vec_q(
144 const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
145 const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
146 const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
147 const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
148 const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst) {
149
150 constexpr int qk = ggml_cuda_type_traits<type>::qk;
151 constexpr int qi = ggml_cuda_type_traits<type>::qi;
152 constexpr int vdr = get_vdr_mmvq(type);
153 constexpr mmvq_parameter_table_id table_id = get_device_table_id();
154 constexpr int nwarps = calc_nwarps(ncols_dst, table_id);
155 constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id);
156 constexpr int warp_size = ggml_cuda_get_physical_warp_size();
157
158 constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
159
160 const int tid = warp_size*threadIdx.y + threadIdx.x;
161 const int row0 = rows_per_cuda_block*blockIdx.x;
162 const int blocks_per_row_x = ncols_x / qk;
163 constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
164
165 // The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1.
166 const uint32_t channel_dst = blockIdx.y;
167 const uint32_t channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(n: channel_dst, fastdiv_values: channel_ratio);
168 const uint32_t channel_y = ncols_dst == 1 && ids ? fastmodulo(n: channel_dst, fastdiv_values: nchannels_y) : channel_dst;
169 const uint32_t sample_dst = blockIdx.z;
170 const uint32_t sample_x = fastdiv(n: sample_dst, fastdiv_values: sample_ratio);
171 const uint32_t sample_y = sample_dst;
172
173 bool use_gate = false;
174 bool use_bias = false;
175 bool use_gate_bias = false;
176 const void * vgate = nullptr;
177 const float * x_bias = nullptr;
178 const float * gate_bias = nullptr;
179 ggml_glu_op active_glu;
180
181 if constexpr (has_fusion) {
182 use_gate = fusion.gate != nullptr;
183 use_bias = fusion.x_bias != nullptr;
184 use_gate_bias = fusion.gate_bias != nullptr && use_gate;
185 vgate = fusion.gate;
186 x_bias = (const float *) fusion.x_bias;
187 gate_bias = (const float *) fusion.gate_bias;
188 active_glu = fusion.glu_op;
189 }
190
191 const uint32_t channel_bias = ids ? channel_x : channel_dst;
192
193 float x_biases[ncols_dst] = { 0.0f };
194 float gate_biases[ncols_dst] = { 0.0f };
195 if constexpr (has_fusion) {
196 if (use_bias) {
197 x_bias = x_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
198 // 1. Hide latency by prefetching bias and gate here
199 // 2. load only on threads that won't die after partial sum calculation
200 if (threadIdx.x < rows_per_cuda_block && threadIdx.y == 0 &&
201 (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
202#pragma unroll
203 for (int j = 0; j < ncols_dst; ++j) {
204 x_biases[j] = x_bias[j * stride_col_dst + threadIdx.x];
205 }
206 }
207 }
208 if (use_gate_bias) {
209 gate_bias = gate_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
210 if (threadIdx.x < rows_per_cuda_block && threadIdx.y == 0 &&
211 (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
212#pragma unroll
213 for (int j = 0; j < ncols_dst; ++j) {
214 gate_biases[j] = gate_bias[j * stride_col_dst + threadIdx.x];
215 }
216 }
217 }
218 }
219
220 // partial sum for each thread
221 float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}};
222 float tmp_gate[ncols_dst][rows_per_cuda_block] = {{0.0f}};
223
224 const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y;
225 const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x;
226
227 for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
228 const int kby = kbx * (qk/QK8_1); // y block index that aligns with kbx
229
230 // x block quant index when casting the quants to int
231 const int kqs = vdr * (tid % (qi/vdr));
232
233#pragma unroll
234 for (int j = 0; j < ncols_dst; ++j) {
235#pragma unroll
236 for (int i = 0; i < rows_per_cuda_block; ++i) {
237 tmp[j][i] += vec_dot_q_cuda(
238 vx, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs);
239 if constexpr (has_fusion) {
240 if (use_gate) {
241 tmp_gate[j][i] += vec_dot_q_cuda(
242 vgate, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs);
243 }
244 }
245 }
246 }
247 }
248
249 __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
250 __shared__ float tmp_shared_gate[(has_fusion && (nwarps-1 > 0)) ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
251 if constexpr (!has_fusion) {
252 (void) tmp_shared_gate;
253 } else if (!use_gate) {
254 (void) tmp_shared_gate;
255 }
256
257 if (threadIdx.y > 0) {
258#pragma unroll
259 for (int j = 0; j < ncols_dst; ++j) {
260#pragma unroll
261 for (int i = 0; i < rows_per_cuda_block; ++i) {
262 tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i];
263 if constexpr (has_fusion) {
264 if (use_gate) {
265 tmp_shared_gate[threadIdx.y-1][j][i][threadIdx.x] = tmp_gate[j][i];
266 }
267 }
268 }
269 }
270 }
271 __syncthreads();
272 if (threadIdx.y > 0) {
273 return;
274 }
275
276 dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + row0;
277
278 // sum up partial sums and write back result
279#pragma unroll
280 for (int j = 0; j < ncols_dst; ++j) {
281#pragma unroll
282 for (int i = 0; i < rows_per_cuda_block; ++i) {
283#pragma unroll
284 for (int l = 0; l < nwarps-1; ++l) {
285 tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
286 if constexpr (has_fusion) {
287 if (use_gate) {
288 tmp_gate[j][i] += tmp_shared_gate[l][j][i][threadIdx.x];
289 }
290 }
291 }
292 tmp[j][i] = warp_reduce_sum<warp_size>(tmp[j][i]);
293 if constexpr (has_fusion) {
294 if (use_gate) {
295 tmp_gate[j][i] = warp_reduce_sum<warp_size>(tmp_gate[j][i]);
296 }
297 }
298 }
299
300 if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
301 float result = tmp[j][threadIdx.x];
302 if constexpr (has_fusion) {
303 if (use_bias) {
304 result += x_biases[j];
305 }
306 if (use_gate) {
307 float gate_value = tmp_gate[j][threadIdx.x];
308 if (use_gate_bias) {
309 gate_value += gate_biases[j];
310 }
311 switch (active_glu) {
312 case GGML_GLU_OP_SWIGLU:
313 result *= ggml_cuda_op_silu_single(x: gate_value);
314 break;
315 case GGML_GLU_OP_GEGLU:
316 result *= ggml_cuda_op_gelu_single(x: gate_value);
317 break;
318 case GGML_GLU_OP_SWIGLU_OAI: {
319 result = ggml_cuda_op_swiglu_oai_single(x: gate_value, g: result);
320 break;
321 }
322 default:
323 result = result * gate_value;
324 break;
325 }
326 }
327 }
328 dst[j*stride_col_dst + threadIdx.x] = result;
329 }
330 }
331
332 if constexpr (!has_fusion) {
333 GGML_UNUSED_VARS(use_gate, use_bias, use_gate_bias, active_glu, gate_bias, x_bias, tmp_gate);
334 }
335}
336
337static std::pair<dim3, dim3> calc_launch_params(
338 const int ncols_dst, const int nrows_x, const int nchannels_y, const int nsamples_y,
339 const int warp_size, const mmvq_parameter_table_id table_id) {
340 const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id);
341 const dim3 block_nums(nblocks, nchannels_y, nsamples_y);
342 const dim3 block_dims(warp_size, calc_nwarps(ncols_dst, table_id), 1);
343 return {block_nums, block_dims};
344}
345
346template<ggml_type type, int c_ncols_dst>
347static void mul_mat_vec_q_switch_fusion(
348 const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
349 const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
350 const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
351 const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
352 const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst,
353 const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared, cudaStream_t stream) {
354
355 const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
356 if constexpr (c_ncols_dst == 1) {
357 if (has_fusion) {
358 mul_mat_vec_q<type, c_ncols_dst, true><<<gridDim: block_nums, blockDim: block_dims, sharedMem: nbytes_shared, stream>>>
359 (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
360 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
361 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
362 return;
363 }
364 }
365
366 GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
367
368 mul_mat_vec_q<type, c_ncols_dst, false><<<gridDim: block_nums, blockDim: block_dims, sharedMem: nbytes_shared, stream>>>
369 (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
370 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
371 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
372}
373
374template <ggml_type type>
375static void mul_mat_vec_q_switch_ncols_dst(
376 const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
377 const int ncols_x, const int nrows_x, const int ncols_dst,
378 const int stride_row_x, const int stride_col_y, const int stride_col_dst,
379 const int nchannels_x, const int nchannels_y, const int nchannels_dst,
380 const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
381 const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
382 cudaStream_t stream) {
383
384 GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
385 GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE);
386
387 const uint3 nchannels_y_fd = ids ? init_fastdiv_values(d_64: nchannels_y) : make_uint3(x: 0, y: 0, z: 0);
388 const uint3 channel_ratio_fd = ids ? make_uint3(x: 0, y: 0, z: 0) : init_fastdiv_values(d_64: nchannels_dst / nchannels_x);
389 const uint3 sample_ratio_fd = init_fastdiv_values(d_64: nsamples_dst / nsamples_x);
390
391 const int device = ggml_cuda_get_device();
392 const int warp_size = ggml_cuda_info().devices[device].warp_size;
393 const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);
394
395 const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
396
397 GGML_ASSERT(!ids || ncols_dst == 1);
398 switch (ncols_dst) {
399 case 1: {
400 constexpr int c_ncols_dst = 1;
401 std::pair<dim3, dim3> dims = calc_launch_params(ncols_dst: c_ncols_dst, nrows_x, nchannels_y: nchannels_dst, nsamples_y: nsamples_dst, warp_size, table_id);
402 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
403 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
404 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
405 dims.first, dims.second, 0, stream);
406 } break;
407 case 2: {
408 constexpr int c_ncols_dst = 2;
409 std::pair<dim3, dim3> dims = calc_launch_params(ncols_dst: c_ncols_dst, nrows_x, nchannels_y: nchannels_dst, nsamples_y: nsamples_dst, warp_size, table_id);
410 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
411 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
412 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
413 dims.first, dims.second, 0, stream);
414 } break;
415 case 3: {
416 constexpr int c_ncols_dst = 3;
417 std::pair<dim3, dim3> dims = calc_launch_params(ncols_dst: c_ncols_dst, nrows_x, nchannels_y: nchannels_dst, nsamples_y: nsamples_dst, warp_size, table_id);
418 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
419 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
420 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
421 dims.first, dims.second, 0, stream);
422 } break;
423 case 4: {
424 constexpr int c_ncols_dst = 4;
425 std::pair<dim3, dim3> dims = calc_launch_params(ncols_dst: c_ncols_dst, nrows_x, nchannels_y: nchannels_dst, nsamples_y: nsamples_dst, warp_size, table_id);
426 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
427 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
428 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
429 dims.first, dims.second, 0, stream);
430 } break;
431 case 5: {
432 constexpr int c_ncols_dst = 5;
433 std::pair<dim3, dim3> dims = calc_launch_params(ncols_dst: c_ncols_dst, nrows_x, nchannels_y: nchannels_dst, nsamples_y: nsamples_dst, warp_size, table_id);
434 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
435 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
436 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
437 dims.first, dims.second, 0, stream);
438 } break;
439 case 6: {
440 constexpr int c_ncols_dst = 6;
441 std::pair<dim3, dim3> dims = calc_launch_params(ncols_dst: c_ncols_dst, nrows_x, nchannels_y: nchannels_dst, nsamples_y: nsamples_dst, warp_size, table_id);
442 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
443 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
444 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
445 dims.first, dims.second, 0, stream);
446 } break;
447 case 7: {
448 constexpr int c_ncols_dst = 7;
449 std::pair<dim3, dim3> dims = calc_launch_params(ncols_dst: c_ncols_dst, nrows_x, nchannels_y: nchannels_dst, nsamples_y: nsamples_dst, warp_size, table_id);
450 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
451 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
452 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
453 dims.first, dims.second, 0, stream);
454 } break;
455 case 8: {
456 constexpr int c_ncols_dst = 8;
457 std::pair<dim3, dim3> dims = calc_launch_params(ncols_dst: c_ncols_dst, nrows_x, nchannels_y: nchannels_dst, nsamples_y: nsamples_dst, warp_size, table_id);
458 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
459 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
460 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
461 dims.first, dims.second, 0, stream);
462 } break;
463 default:
464 GGML_ABORT("fatal error");
465 break;
466 }
467
468 GGML_UNUSED(has_fusion);
469}
470static void mul_mat_vec_q_switch_type(
471 const void * vx, const ggml_type type_x, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
472 const int ncols_x, const int nrows_x, const int ncols_dst,
473 const int stride_row_x, const int stride_col_y, const int stride_col_dst,
474 const int nchannels_x, const int nchannels_y, const int nchannels_dst,
475 const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
476 const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
477 cudaStream_t stream) {
478 switch (type_x) {
479 case GGML_TYPE_Q4_0:
480 mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_0>
481 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
482 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
483 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
484 break;
485 case GGML_TYPE_Q4_1:
486 mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_1>
487 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
488 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
489 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
490 break;
491 case GGML_TYPE_Q5_0:
492 mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_0>
493 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
494 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
495 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
496 break;
497 case GGML_TYPE_Q5_1:
498 mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_1>
499 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
500 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
501 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
502 break;
503 case GGML_TYPE_Q8_0:
504 mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q8_0>
505 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
506 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
507 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
508 break;
509 case GGML_TYPE_MXFP4:
510 mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_MXFP4>
511 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
512 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
513 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
514 break;
515 case GGML_TYPE_Q2_K:
516 mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K>
517 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
518 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
519 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
520 break;
521 case GGML_TYPE_Q3_K:
522 mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q3_K>
523 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
524 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
525 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
526 break;
527 case GGML_TYPE_Q4_K:
528 mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_K>
529 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
530 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
531 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
532 break;
533 case GGML_TYPE_Q5_K:
534 mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_K>
535 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
536 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
537 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
538 break;
539 case GGML_TYPE_Q6_K:
540 mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q6_K>
541 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
542 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
543 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
544 break;
545 case GGML_TYPE_IQ2_XXS:
546 mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XXS>
547 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
548 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
549 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
550 break;
551 case GGML_TYPE_IQ2_XS:
552 mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XS>
553 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
554 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
555 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
556 break;
557 case GGML_TYPE_IQ2_S:
558 mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_S>
559 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
560 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
561 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
562 break;
563 case GGML_TYPE_IQ3_XXS:
564 mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_XXS>
565 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
566 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
567 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
568 break;
569 case GGML_TYPE_IQ1_S:
570 mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_S>
571 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
572 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
573 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
574 break;
575 case GGML_TYPE_IQ1_M:
576 mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_M>
577 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
578 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
579 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
580 break;
581 case GGML_TYPE_IQ4_NL:
582 mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_NL>
583 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
584 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
585 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
586 break;
587 case GGML_TYPE_IQ4_XS:
588 mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_XS>
589 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
590 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
591 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
592 break;
593 case GGML_TYPE_IQ3_S:
594 mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_S>
595 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
596 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
597 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
598 break;
599 default:
600 GGML_ABORT("fatal error");
601 break;
602 }
603}
604
605void ggml_cuda_mul_mat_vec_q(
606 ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
607 const ggml_cuda_mm_fusion_args_host * fusion) {
608 GGML_ASSERT( src1->type == GGML_TYPE_F32);
609 GGML_ASSERT( dst->type == GGML_TYPE_F32);
610 GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); // Optional, used for batched GGML_MUL_MAT_ID.
611
612 GGML_TENSOR_BINARY_OP_LOCALS;
613
614 cudaStream_t stream = ctx.stream();
615
616 const size_t ts_src0 = ggml_type_size(src0->type);
617 const size_t ts_src1 = ggml_type_size(src1->type);
618 const size_t ts_dst = ggml_type_size(dst->type);
619
620 GGML_ASSERT( nb00 == ts_src0);
621 GGML_ASSERT( nb10 == ts_src1);
622 GGML_ASSERT( nb0 == ts_dst);
623 GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
624
625 GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1.
626
627 const float * src1_d = (const float *) src1->data;
628 const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
629 float * dst_d = (float *) dst->data;
630
631 ggml_cuda_mm_fusion_args_device fusion_local{};
632
633 if (fusion) {
634 GGML_ASSERT( !ids || dst->ne[2] == 1);
635 GGML_ASSERT( ids || dst->ne[1] == 1);
636
637 if (fusion->x_bias) {
638 GGML_ASSERT(fusion->x_bias->type == GGML_TYPE_F32);
639 GGML_ASSERT(fusion->x_bias->ne[0] == dst->ne[0]);
640 GGML_ASSERT(!ids || fusion->x_bias->ne[1] == src0->ne[2]);
641 fusion_local.x_bias = fusion->x_bias->data;
642 }
643 if (fusion->gate) {
644 GGML_ASSERT(fusion->gate->type == src0->type && ggml_are_same_stride(fusion->gate, src0));
645 fusion_local.gate = fusion->gate->data;
646 }
647 if (fusion->gate_bias) {
648 GGML_ASSERT(fusion->gate_bias->type == GGML_TYPE_F32);
649 GGML_ASSERT(fusion->gate_bias->ne[0] == dst->ne[0]);
650 GGML_ASSERT(!ids || fusion->gate_bias->ne[1] == src0->ne[2]);
651 fusion_local.gate_bias = fusion->gate_bias->data;
652 }
653 fusion_local.glu_op = fusion->glu_op;
654 }
655
656 // If src0 is a temporary compute buffer, clear any potential padding.
657 if (ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) {
658 const size_t size_data = ggml_nbytes(src0);
659 const size_t size_alloc = ggml_backend_buffer_get_alloc_size(src0->buffer, src0);
660 if (size_alloc > size_data) {
661 GGML_ASSERT(ggml_is_contiguously_allocated(src0));
662 GGML_ASSERT(!src0->view_src);
663 CUDA_CHECK(cudaMemsetAsync((char *) src0->data + size_data, 0, size_alloc - size_data, stream));
664 }
665 }
666
667 const int64_t ne10_padded = GGML_PAD(ne10, MATRIX_ROW_PADDING);
668 ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1);
669 {
670 const int64_t s11 = src1->nb[1] / ts_src1;
671 const int64_t s12 = src1->nb[2] / ts_src1;
672 const int64_t s13 = src1->nb[3] / ts_src1;
673 quantize_row_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream);
674 }
675
676 const int64_t s01 = src0->nb[1] / ts_src0;
677 const int64_t s11 = ne10_padded / QK8_1;
678 const int64_t s1 = dst->nb[1] / ts_dst;
679 const int64_t s02 = src0->nb[2] / ts_src0;
680 const int64_t s2 = dst->nb[2] / ts_dst;
681 const int64_t s03 = src0->nb[3] / ts_src0;
682 const int64_t s3 = dst->nb[3] / ts_dst;
683
684 const int64_t s12 = ne11*s11;
685 const int64_t s13 = ne12*s12;
686
687 // For MUL_MAT_ID the memory layout is different than for MUL_MAT:
688 const int64_t ncols_dst = ids ? ne2 : ne1;
689 const int64_t nchannels_y = ids ? ne11 : ne12;
690 const int64_t nchannels_dst = ids ? ne1 : ne2;
691 const int64_t stride_col_dst = ids ? s2 : s1;
692 const int64_t stride_col_y = ids ? s12 : s11;
693 const int64_t stride_channel_dst = ids ? s1 : s2;
694 const int64_t stride_channel_y = ids ? s11 : s12;
695
696 mul_mat_vec_q_switch_type(
697 src0->data, src0->type, src1_q8_1.get(), ids_d, fusion_local, dst_d, ne00,
698 ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
699 ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
700 ne03, ne3, s03, s13, s3, stream);
701}
702
703void ggml_cuda_op_mul_mat_vec_q(
704 ggml_backend_cuda_context & ctx,
705 const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
706 const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
707 const int64_t src1_padded_row_size, cudaStream_t stream) {
708
709 const int64_t ne00 = src0->ne[0];
710 const int64_t row_diff = row_high - row_low;
711
712 const int64_t ne10 = src1->ne[0];
713 GGML_ASSERT(ne10 % QK8_1 == 0);
714
715 const int64_t ne0 = dst->ne[0];
716
717 int id = ggml_cuda_get_device();
718
719 // the main device has a larger memory buffer to hold the results from all GPUs
720 // nrows_dst == nrows of the matrix that the kernel writes into
721 const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
722
723 const int stride_row_x = ne00 / ggml_blck_size(src0->type);
724 const int stride_col_y = src1_padded_row_size / QK8_1;
725
726 ggml_cuda_mm_fusion_args_device fusion_local{};
727 mul_mat_vec_q_switch_type(
728 src0_dd_i, src0->type, src1_ddq_i, nullptr, fusion_local, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst,
729 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, stream);
730
731 GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_ncols, src1_padded_row_size);
732}
733