#pragma once

#include "mma.cuh"
#include "common.cuh"

using namespace ggml_cuda_mma;

#define MMF_ROWS_PER_BLOCK 32

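// Compacted expert-routing data for the mul_mat_id path: ids_src_compact/ids_dst_compact hold
// per-expert contiguous lists of source/destination ids, expert_bounds_dev holds the [start, end)
// offsets of each expert's slice, and sis1 is the divisor used to split a compact source id into
// a token index and a channel via fastdiv (see mul_mat_f_ids below).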
struct mmf_ids_data {
    const int32_t * ids_src_compact = nullptr;
    const int32_t * ids_dst_compact = nullptr;
    const int32_t * expert_bounds_dev = nullptr;
    int n_experts = 0;
    int sis1 = 0;
};

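// Host entry points: ggml_cuda_mul_mat_f launches the FP32/FP16/BF16 tensor core matrix multiplication
// (ids != nullptr selects the mul_mat_id path); ggml_cuda_should_use_mmf reports whether this kernel
// family should be used for the given type, compute capability, and operand shapes.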
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);

bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, const size_t * src0_nb, const int src1_ncols, bool mul_mat_id);

template <typename T, int rows_per_block, int cols_per_block, int nwarps, bool has_ids>
__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
static __global__ void mul_mat_f(
        const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
        const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
        const int stride_col_id, const int stride_row_id,
        const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
        const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
    constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
    constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();

    if (!I_16_supported && !I_32_supported) {
        NO_DEVICE_CODE;
        return;
    }

    constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.

    typedef tile<I_preferred, 8, T>     tile_A;
    typedef tile<8, 8, T>               tile_B;
    typedef tile<I_preferred, 8, float> tile_C;

    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
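    // The staging buffer's K stride is padded by 4 elements, which helps avoid shared-memory bank conflicts.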
    constexpr int tile_k_padded = warp_size + 4;
    constexpr int ntA = rows_per_block / tile_A::I;
    constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;

    const int row0 = blockIdx.x * rows_per_block;

    int expert_idx = 0;
    int col_base = 0;

    const int channel_dst = has_ids ? 0 : blockIdx.y;

    if constexpr (has_ids) {
        // experts + tiles of ncols_dst are packed in the y dimension
        int col_tiles = (ncols_dst_total + cols_per_block - 1) / cols_per_block;
        const int nchannels_x = gridDim.y / col_tiles;
        const int tile_idx = blockIdx.y / nchannels_x;
        expert_idx = blockIdx.y - tile_idx * nchannels_x;
        col_base = tile_idx * cols_per_block;
    }

    const int channel_x = has_ids ? expert_idx : (channel_dst / channel_ratio);
    const int channel_y = channel_dst;
    const int sample_dst = blockIdx.z;
    const int sample_x = sample_dst / sample_ratio;
    const int sample_y = sample_dst;

    x   += int64_t(sample_x) *stride_sample_x   + channel_x  *stride_channel_x   + row0*stride_row;
    y   += int64_t(sample_y) *stride_sample_y   + (has_ids ? 0 : channel_y  *stride_channel_y);
    dst += int64_t(sample_dst)*stride_sample_dst + (has_ids ? 0 : channel_dst*stride_channel_dst);

    if constexpr (has_ids) {
        constexpr int y_stride_scale = std::is_same_v<T, float> ? 1 : 2;
        const int64_t col_offset = col_base;
        y   += col_offset * stride_col_y * y_stride_scale;
        dst += col_offset * stride_col_dst;
        ids += col_offset * stride_row_id;
    }

    const float2 * y2 = (const float2 *) y;

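    // Shared memory layout: with has_ids, a padded array of per-column slot indices comes first;
    // the remainder is used both to stage x/y tiles for ldmatrix and later to combine the
    // per-warp partial sums.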
    extern __shared__ char data_mmv[];

    char * shmem_base = data_mmv;
    int  * slot_map   = (int *) shmem_base;
    char * compute_base = has_ids ? (shmem_base + GGML_PAD(cols_per_block, 16) * sizeof(int)) : shmem_base;

    tile_C C[ntA][ntB];

    T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded);

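    // For mul_mat_id, build slot_map: for each output column of this tile, scan its id row for the
    // slot (channel) that selects expert_idx; columns without a match keep -1. If no column in the
    // tile uses this expert, the whole block exits early.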
    if constexpr (has_ids) {
        int found = 0;

        for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
            const int j = j0 + threadIdx.y;

            if (threadIdx.x == 0) {
                slot_map[j] = -1;
            }

            if (col_base + j >= ncols_dst_total) {
                continue;
            }

            const int32_t * __restrict__ id_row = ids + j*stride_row_id;

            for (int k = threadIdx.x; k < nchannels_dst; k += warp_size) {
                int match = id_row[k*stride_col_id] == expert_idx;

                if (match) {
                    slot_map[j] = k;
                    found = 1;
                    break;
                }
            }
        }

        if (!__syncthreads_or(found)) {
            return;
        }
    }

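    // Main loop over the reduction dimension: each warp stages a strip of x in shared memory,
    // loads it into A fragments with ldmatrix, then streams the matching y values through shared
    // memory into B fragments and accumulates into the C fragments with mma.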
    for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
        tile_A A[ntA][warp_size / tile_A::J];
#pragma unroll
        for (int itA = 0; itA < ntA; ++itA) {
#pragma unroll
            for (int i = 0; i < tile_A::I; ++i) {
                tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col];
            }
#pragma unroll
            for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
                load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
            }
        }

#pragma unroll
        for (int itB = 0; itB < ntB; ++itB) {
            if constexpr (std::is_same_v<T, float>) {
#pragma unroll
                for (int j0 = 0; j0 < tile_B::I; ++j0) {
                    const int j = j0 + itB*tile_B::I;

                    if constexpr (!has_ids) {
                        tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
                    } else {
                        const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0;
                        tile_xy[j0*tile_k_padded + threadIdx.x] = valid ? y[slot_map[j]*stride_channel_y + j*stride_col_y + col] : 0.0f;
                    }
                }
            } else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
#pragma unroll
                for (int j0 = 0; j0 < tile_B::I; ++j0) {
                    const int j = j0 + itB*tile_B::I;

                    if constexpr (!has_ids) {
                        const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
                        tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
                    } else {
                        const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0;
                        float2 tmp = valid ? *(const float2 *) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f);
                        tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
                    }
                }
            } else {
                static_assert(std::is_same_v<T, void>, "unsupported type");
            }
#pragma unroll
            for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
                tile_B B;
                load_ldmatrix(B, tile_xy + k0, tile_k_padded);
#pragma unroll
                for (int itA = 0; itA < ntA; ++itA) {
                    mma(C[itA][itB], A[itA][k0/tile_B::J], B);
                }
            }
        }
    }

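    // Combine the results: each warp accumulated partial sums over a disjoint slice of the
    // reduction dimension, so stage the C fragments in shared memory and sum them across warps
    // before writing the final values to dst.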
    float * buf_iw = (float *) compute_base;
    constexpr int kiw = nwarps*rows_per_block + 4;

    if (nwarps > 1) {
        __syncthreads();
    }
#pragma unroll
    for (int itB = 0; itB < ntB; ++itB) {
#pragma unroll
        for (int itA = 0; itA < ntA; ++itA) {
#pragma unroll
            for (int l = 0; l < tile_C::ne; ++l) {
                const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);
                const int j = itB*tile_C::J + tile_C::get_j(l);
                buf_iw[j*kiw + i] = C[itA][itB].x[l];
            }
        }
    }

    if (nwarps > 1) {
        __syncthreads();
    }

#pragma unroll
    for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
        const int j = j0 + threadIdx.y;

        if (j0 + nwarps > cols_per_block && j >= cols_per_block) {
            return;
        }

        float sum = 0.0f;
        static_assert(rows_per_block == warp_size, "need loop/check");
#pragma unroll
        for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
            const int i = i0 + threadIdx.x;

            sum += buf_iw[j*kiw + i];
        }

        if constexpr (!has_ids) {
            dst[j*stride_col_dst + row0 + threadIdx.x] = sum;
        } else {
            const int slot = (j < cols_per_block) ? slot_map[j] : -1;
            if (slot >= 0 && (col_base + j) < ncols_dst_total) {
                dst[slot*stride_channel_dst + j*stride_col_dst + row0 + threadIdx.x] = sum;
            }
        }
    }
#else
    GGML_UNUSED_VARS(x, y, ids, dst,
        ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
        stride_col_id, stride_row_id,
        channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
        sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
    NO_DEVICE_CODE;
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
}

// This kernel is for larger batch sizes of mul_mat_id.
template <typename T, int rows_per_block, int cols_per_block, int nwarps>
__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
static __global__ void mul_mat_f_ids(
        const T * __restrict__ x, const float * __restrict__ y,
        const int32_t * __restrict__ ids_src_compact, const int32_t * __restrict__ ids_dst_compact,
        const int32_t * __restrict__ expert_bounds, float * __restrict__ dst,
        const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
        const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
        const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
        const uint3 sis1_fd, const uint3 nch_fd) {
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
    constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
    constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();

    if (!I_16_supported && !I_32_supported) {
        NO_DEVICE_CODE;
        return;
    }

    constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.

    typedef tile<I_preferred, 8, T>     tile_A;
    typedef tile<8, 8, T>               tile_B;
    typedef tile<I_preferred, 8, float> tile_C;

    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
    constexpr int tile_k_padded = warp_size + 4;
    constexpr int ntA = rows_per_block / tile_A::I;
    constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;

    const int row0 = blockIdx.x * rows_per_block;

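    // Grid mapping: blockIdx.x selects the row block, blockIdx.y the expert, and blockIdx.z the
    // column tile within that expert's compacted id range; blocks past the expert's tile count exit.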
    const int expert_idx = blockIdx.y;
    const int expert_start = expert_bounds[expert_idx];
    const int expert_end = expert_bounds[expert_idx + 1];
    const int ncols_expert = expert_end - expert_start;

    const int tiles_for_expert = (ncols_expert + cols_per_block - 1) / cols_per_block;
    const int tile_idx = blockIdx.z;
    if (tile_idx >= tiles_for_expert) {
        return;
    }

    const int col_base = tile_idx * cols_per_block;

    GGML_UNUSED(channel_ratio);

    const int channel_x = expert_idx;
    const int sample_dst = 0;
    const int sample_x = sample_dst / sample_ratio;
    const int sample_y = sample_dst;

    x   += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row0*stride_row;
    y   += int64_t(sample_y) *stride_sample_y;
    dst += int64_t(sample_dst)*stride_sample_dst;

    const int32_t * ids_src_expert = ids_src_compact + expert_start;
    const int32_t * ids_dst_expert = ids_dst_compact + expert_start;

    extern __shared__ char data_mmv[];
    char * compute_base = data_mmv;

    //const float2 * y2 = (const float2 *) y;

    tile_C C[ntA][ntB];

    T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded);

    for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
        tile_A A[ntA][warp_size / tile_A::J];
#pragma unroll
        for (int itA = 0; itA < ntA; ++itA) {
#pragma unroll
            for (int i = 0; i < tile_A::I; ++i) {
                tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col];
            }
#pragma unroll
            for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
                load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
            }
        }

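        // Double-buffered gather: while the y values of the current B tile feed the MMA, the values
        // for the next tile are gathered into the second buffer to overlap loads with tensor core work.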
        if constexpr (std::is_same_v<T, float>) {
            float vals_buf[2][tile_B::I];
            auto gather_tile = [&](int tile_idx_local, float * vals) {
#pragma unroll
                for (int j0 = 0; j0 < tile_B::I; ++j0) {
                    const int j = j0 + tile_idx_local*tile_B::I;
                    const int global_j = col_base + j;
                    float val = 0.0f;
                    if (j < cols_per_block && global_j < ncols_expert) {
                        const int src_entry = ids_src_expert[global_j];
                        const uint2 qrm = fast_div_modulo((uint32_t) src_entry, sis1_fd);
                        const int token = (int) qrm.x;
                        const int channel = (int) qrm.y;
                        if (token < ncols_dst_total) {
                            val = y[channel*stride_channel_y + token*stride_col_y + col];
                        }
                    }
                    vals[j0] = val;
                }
            };

            gather_tile(0, vals_buf[0]);

            int curr_buf = 0;
            int next_buf = 1;
#pragma unroll
            for (int itB = 0; itB < ntB; ++itB) {
#pragma unroll
                for (int j0 = 0; j0 < tile_B::I; ++j0) {
                    tile_xy[j0*tile_k_padded + threadIdx.x] = vals_buf[curr_buf][j0];
                }

                if (itB + 1 < ntB) {
                    gather_tile(itB + 1, vals_buf[next_buf]);
                }

#pragma unroll
                for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
                    tile_B B;
                    load_ldmatrix(B, tile_xy + k0, tile_k_padded);
#pragma unroll
                    for (int itA = 0; itA < ntA; ++itA) {
                        mma(C[itA][itB], A[itA][k0/tile_B::J], B);
                    }
                }

                if (itB + 1 < ntB) {
                    curr_buf ^= 1;
                    next_buf ^= 1;
                }
            }
        } else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
            float2 vals_buf[2][tile_B::I];
            auto gather_tile = [&](int tile_idx_local, float2 * vals) {
#pragma unroll
                for (int j0 = 0; j0 < tile_B::I; ++j0) {
                    const int j = j0 + tile_idx_local*tile_B::I;
                    const int global_j = col_base + j;
                    float2 tmp = make_float2(0.0f, 0.0f);
                    if (j < cols_per_block && global_j < ncols_expert) {
                        const int src_entry = ids_src_expert[global_j];
                        const uint2 qrm = fast_div_modulo((uint32_t) src_entry, sis1_fd);
                        const int token = (int) qrm.x;
                        const int channel = (int) qrm.y;
                        if (token < ncols_dst_total) {
                            tmp = *(const float2 *) &y[channel*stride_channel_y + 2*(token*stride_col_y + col)];
                        }
                    }
                    vals[j0] = tmp;
                }
            };

            if (ntB > 0) {
                gather_tile(0, vals_buf[0]);
            }

            int curr_buf = 0;
            int next_buf = 1;
#pragma unroll
            for (int itB = 0; itB < ntB; ++itB) {
#pragma unroll
                for (int j0 = 0; j0 < tile_B::I; ++j0) {
                    const float2 tmp = vals_buf[curr_buf][j0];
                    tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
                }

                if (itB + 1 < ntB) {
                    gather_tile(itB + 1, vals_buf[next_buf]);
                }

#pragma unroll
                for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
                    tile_B B;
                    load_ldmatrix(B, tile_xy + k0, tile_k_padded);
#pragma unroll
                    for (int itA = 0; itA < ntA; ++itA) {
                        mma(C[itA][itB], A[itA][k0/tile_B::J], B);
                    }
                }

                if (itB + 1 < ntB) {
                    curr_buf ^= 1;
                    next_buf ^= 1;
                }
            }
        } else {
            static_assert(std::is_same_v<T, void>, "unsupported type");
        }
    }

    float * buf_iw = (float *) compute_base;
    constexpr int kiw = nwarps*rows_per_block + 4;

    if (nwarps > 1) {
        __syncthreads();
    }
#pragma unroll
    for (int itB = 0; itB < ntB; ++itB) {
#pragma unroll
        for (int itA = 0; itA < ntA; ++itA) {
#pragma unroll
            for (int l = 0; l < tile_C::ne; ++l) {
                const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);
                const int j = itB*tile_C::J + tile_C::get_j(l);
                buf_iw[j*kiw + i] = C[itA][itB].x[l];
            }
        }
    }

    if (nwarps > 1) {
        __syncthreads();
    }

#pragma unroll
    for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
        const int j = j0 + threadIdx.y;

        if (j0 + nwarps > cols_per_block && j >= cols_per_block) {
            return;
        }

        float sum = 0.0f;
        static_assert(rows_per_block == warp_size, "need loop/check");
#pragma unroll
        for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
            const int i = i0 + threadIdx.x;

            sum += buf_iw[j*kiw + i];
        }

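        // Decode the compact dst id into a token index and expert slot with fastdiv by
        // nchannels_dst, then scatter the reduced value into dst.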
        const int global_j = col_base + j;
        if (j < cols_per_block && global_j < ncols_expert && nchannels_dst > 0) {
            const int dst_entry = ids_dst_expert[global_j];
            const uint2 qrm = fast_div_modulo((uint32_t) dst_entry, nch_fd);
            const int token = (int) qrm.x;
            if (token < ncols_dst_total) {
                const int slot = (int) qrm.y;
                dst[slot*stride_channel_dst + token*stride_col_dst + row0 + threadIdx.x] = sum;
            }
        }
    }
#else
    GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst,
        ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
        channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
        sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
    NO_DEVICE_CODE;
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
}

template<typename T, int cols_per_block, int nwarps>
static inline void mul_mat_f_switch_ids(
        const T * x, const float * y, const int32_t * ids, float * dst,
        const int64_t ncols_x, const int64_t ncols_dst, const int64_t nchannels_dst,
        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
        const int64_t stride_col_id, const int64_t stride_row_id,
        const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
        const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
        const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream,
        const mmf_ids_data * ids_data) {
    const bool has_ids_data = ids_data && ids_data->ids_src_compact;

    // Use the compact-ids kernel only for larger batches; for small ncols_dst (<= 16)
    // the regular mul_mat_f path with has_ids=true is preferred.
    if (has_ids_data && ncols_dst > 16) {
        const int max_tiles = (int) ((ncols_dst + cols_per_block - 1) / cols_per_block);
        if (max_tiles == 0) {
            return;
        }
        dim3 block_nums_ids(block_nums.x, ids_data->n_experts, max_tiles);

        const uint3 sis1_fd = ids_data->sis1 > 0 ? init_fastdiv_values((uint32_t) ids_data->sis1) : make_uint3(0, 0, 1);
        const uint3 nch_fd  = init_fastdiv_values((uint32_t) nchannels_dst);

        mul_mat_f_ids<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
            (x, y, ids_data->ids_src_compact, ids_data->ids_dst_compact, ids_data->expert_bounds_dev, dst,
             ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
             channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
             sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst,
             sis1_fd, nch_fd);
    } else if (ids) {
        const int64_t col_tiles = (ncols_dst + cols_per_block - 1) / cols_per_block;
        dim3 block_nums_ids = block_nums;
        block_nums_ids.y *= col_tiles;

        mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
            (x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
             stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
             sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
    } else {
        mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
            (x, y, ids, dst, ncols_x, cols_per_block, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
             stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
             sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
    }
}

template <typename T, int cols_per_block>
void mul_mat_f_cuda(
        const T * x, const float * y, const int32_t * ids, float * dst,
        const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
        const int64_t stride_col_id, const int64_t stride_row_id,
        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
        cudaStream_t stream, const mmf_ids_data * ids_data) {
    typedef tile<16, 8, T> tile_A_16;
    typedef tile<32, 8, T> tile_A_32;
    typedef tile< 8, 8, T> tile_B;

    GGML_ASSERT(ncols_x % 2 == 0);
    GGML_ASSERT(stride_row % 2 == 0);
    GGML_ASSERT(stride_col_y % 2 == 0);
    GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
    GGML_ASSERT(       nsamples_dst % nsamples_x == 0);
    const int64_t channel_ratio = nchannels_dst / nchannels_x;
    const int64_t sample_ratio  = nsamples_dst  / nsamples_x;

    const int device = ggml_cuda_get_device();
    const int cc = ggml_cuda_info().devices[device].cc;
    const int warp_size = ggml_cuda_info().devices[device].warp_size;

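    // Pick the number of warps per block that minimizes the number of iterations of the kernel's
    // main loop over the reduction dimension (ncols_x), capped by max_block_size threads per block.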
    int64_t nwarps_best = 1;
    int64_t niter_best = (ncols_x + warp_size*2 - 1) / (warp_size*2);
    int64_t max_block_size = 256;
    for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) {
        const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2);
        if (niter < niter_best) {
            niter_best = niter;
            nwarps_best = nwarps;
        }
    }

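    // Shared memory has to fit both the per-iteration tile staging buffer and the per-warp combine
    // buffer; the mul_mat_id path additionally needs room for the slot_map.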
    constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
    const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4;
    const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4;
    const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
    const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
    const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
    const int64_t grid_y = ids ? nchannels_x : nchannels_dst;

    const dim3 block_nums(nrows_x/rows_per_block, grid_y, nsamples_dst);
    const dim3 block_dims(warp_size, nwarps_best, 1);

    switch (nwarps_best) {
        case 1: {
            mul_mat_f_switch_ids<T, cols_per_block, 1>(
                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
                ids_data);
        } break;
        case 2: {
            mul_mat_f_switch_ids<T, cols_per_block, 2>(
                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
                ids_data);
        } break;
        case 3: {
            mul_mat_f_switch_ids<T, cols_per_block, 3>(
                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
                ids_data);
        } break;
        case 4: {
            mul_mat_f_switch_ids<T, cols_per_block, 4>(
                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
                ids_data);
        } break;
        case 5: {
            mul_mat_f_switch_ids<T, cols_per_block, 5>(
                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
                ids_data);
        } break;
        case 6: {
            mul_mat_f_switch_ids<T, cols_per_block, 6>(
                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
                ids_data);
        } break;
        case 7: {
            mul_mat_f_switch_ids<T, cols_per_block, 7>(
                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
                ids_data);
        } break;
        case 8: {
            mul_mat_f_switch_ids<T, cols_per_block, 8>(
                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
                ids_data);
        } break;
        default: {
            GGML_ABORT("fatal error");
        } break;
    }

    GGML_UNUSED_VARS(nchannels_y);
}

template <typename T>
static void mul_mat_f_switch_cols_per_block(
        const T * x, const float * y, const int32_t * ids, float * dst,
        const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
        const int64_t stride_col_id, const int stride_row_id,
        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
        cudaStream_t stream, const mmf_ids_data * ids_data) {

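    // mul_mat_id batches with more than 16 dst columns use the 16-column instantiation; the extra
    // columns are covered by the column tiling inside mul_mat_f_switch_ids / mul_mat_f_ids.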
    const int ncols_case = (ids && ncols_dst > 16) ? 16 : ncols_dst;

    GGML_ASSERT(ids || ncols_dst <= 16);

    switch (ncols_case) {
        case 1: {
            mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 2: {
            mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 3: {
            mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 4: {
            mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 5: {
            mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 6: {
            mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 7: {
            mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 8: {
            mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 9: {
            mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 10: {
            mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 11: {
            mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 12: {
            mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 13: {
            mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 14: {
            mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 15: {
            mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 16: {
            mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        default: {
            GGML_ABORT("fatal error");
        } break;
    }
}

#define DECL_MMF_CASE_HELPER(T, ncols_dst) \
    template void mul_mat_f_cuda<T, ncols_dst>( \
        const T * x, const float * y, const int32_t * ids, float * dst, \
        const int64_t ncols_x, const int64_t nrows_x, int64_t ncols_dst_total, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \
        const int64_t stride_col_id, const int64_t stride_row_id, \
        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, \
        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, \
        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \
        cudaStream_t stream, const mmf_ids_data * ids_data);

#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
#define DECL_MMF_CASE_EXTERN(ncols_dst) \
    extern DECL_MMF_CASE_HELPER(float, ncols_dst) \
    extern DECL_MMF_CASE_HELPER(half2, ncols_dst) \
    extern DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst)

#define DECL_MMF_CASE(ncols_dst) \
    DECL_MMF_CASE_HELPER(float, ncols_dst) \
    DECL_MMF_CASE_HELPER(half2, ncols_dst) \
    DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst)

DECL_MMF_CASE_EXTERN(1);
DECL_MMF_CASE_EXTERN(2);
DECL_MMF_CASE_EXTERN(3);
DECL_MMF_CASE_EXTERN(4);
DECL_MMF_CASE_EXTERN(5);
DECL_MMF_CASE_EXTERN(6);
DECL_MMF_CASE_EXTERN(7);
DECL_MMF_CASE_EXTERN(8);
DECL_MMF_CASE_EXTERN(9);
DECL_MMF_CASE_EXTERN(10);
DECL_MMF_CASE_EXTERN(11);
DECL_MMF_CASE_EXTERN(12);
DECL_MMF_CASE_EXTERN(13);
DECL_MMF_CASE_EXTERN(14);
DECL_MMF_CASE_EXTERN(15);
DECL_MMF_CASE_EXTERN(16);
#else
#define DECL_MMF_CASE(ncols_dst)
#endif