#include "ggml-cuda/common.cuh"
#include "ggml.h"
#include "topk-moe.cuh"

#include <cmath>
#include <cstring>
#include <initializer_list>
| 7 | |
| 8 | // Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path. |
template <int experts_per_thread, bool use_limit>
__device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
    // In-place softmax over values distributed across one warp: lane `lane` owns
    // vals[i] for logical index idx = lane + i * WARP_SIZE.
    // When use_limit is true only indices < limit participate; inactive slots are
    // zeroed so later reads of vals[] are well-defined.
    float max_val = -INFINITY;

#pragma unroll
    for (int i = 0; i < experts_per_thread; i++) {
        const int  idx    = lane + i * WARP_SIZE;
        const bool active = !use_limit || (idx < limit);
        if (active) {
            max_val = max(max_val, vals[i]);
        }
    }

    // fix: labeled-argument syntax `warp_reduce_max(x: max_val)` is not valid C++
    max_val = warp_reduce_max(max_val);

    float sum = 0.f;

#pragma unroll
    for (int i = 0; i < experts_per_thread; i++) {
        const int  idx    = lane + i * WARP_SIZE;
        const bool active = !use_limit || (idx < limit);
        if (active) {
            // subtract the warp-wide max for numerical stability
            const float val = expf(vals[i] - max_val);
            vals[i]         = val;
            sum += val;
        } else {
            vals[i] = 0.f;
        }
    }

    sum = warp_reduce_sum(sum);

    const float inv_sum = 1.0f / sum;

#pragma unroll
    for (int i = 0; i < experts_per_thread; i++) {
        const int  idx    = lane + i * WARP_SIZE;
        const bool active = !use_limit || (idx < limit);
        if (active) {
            vals[i] *= inv_sum;
        }
    }
}
| 52 | |
| 53 | /* |
| 54 | This kernel does the following: |
| 55 | 1. optionally softmax over the logits per token [n_experts, n_tokens] |
| 56 | 2. argmax reduce over the top-k (n_experts_used) logits |
| 57 | 3. write weights + ids to global memory |
| 58 | 4. optionally normalize the weights or apply softmax over the selected logits |
| 59 | |
| 60 | It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models |
| 61 | */ |
template <int n_experts, bool with_norm, bool delayed_softmax = false>
__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
                                                                  float *       weights,
                                                                  int32_t *     ids,
                                                                  const int     n_rows,
                                                                  const int     n_expert_used,
                                                                  const float   clamp_val) {
    // One warp handles one row (token); blockDim is (WARP_SIZE, rows_per_block).
    const int row = blockIdx.x * blockDim.y + threadIdx.y;
    if (row >= n_rows) {
        return;
    }

    // advance to this row; ids has a stride of n_experts per row (asserted by the caller)
    logits += n_experts * row;
    weights += n_expert_used * row;
    ids += n_experts * row;

    constexpr int experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;

    float wt[experts_per_thread];

    // load the logits into registers, padding out-of-range slots with -inf
#pragma unroll
    for (int i = 0; i < n_experts; i += WARP_SIZE) {
        const int expert   = i + threadIdx.x;
        wt[i / WARP_SIZE]  = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[expert] : -INFINITY;
    }

    if constexpr (!delayed_softmax) {
        softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
    }

    //at this point, each thread holds either a portion of the softmax distribution
    //or the raw logits. We do the argmax reduce over n_expert_used, each time marking
    //the expert weight as -inf to exclude from the next iteration

    float wt_sum = 0.f;

    float output_weights[experts_per_thread];

#pragma unroll
    for (int i = 0; i < experts_per_thread; i++) {
        output_weights[i] = 0.f;
    }

    for (int k = 0; k < n_expert_used; k++) {
        float max_val    = wt[0];
        int   max_expert = threadIdx.x;

#pragma unroll
        for (int i = 1; i < experts_per_thread; i++) {
            const int expert = threadIdx.x + i * WARP_SIZE;
            if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
                max_val    = wt[i];
                max_expert = expert;
            }
        }

        // warp-wide argmax via butterfly shuffle; ties broken by the lower expert index
        // fix: labeled-argument syntax `__shfl_xor_sync(mask: …, val: …, offset: …)` is not valid C++
#pragma unroll
        for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
            const float val    = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
            const int   expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
            if (val > max_val || (val == max_val && expert < max_expert)) {
                max_val    = val;
                max_expert = expert;
            }
        }

        // the lane that owns output slot k keeps the winning weight
        if ((k & (WARP_SIZE - 1)) == threadIdx.x) {
            output_weights[k / WARP_SIZE] = max_val;
        }

        // the lane that owns the winning expert retires it and records its id
        if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
            wt[max_expert / WARP_SIZE] = -INFINITY;

            ids[k] = max_expert;
            if constexpr (with_norm) {
                wt_sum += max_val;
            }
        }
    }

    if constexpr (with_norm) {
        wt_sum              = warp_reduce_sum(wt_sum);
        wt_sum              = max(wt_sum, clamp_val);
        const float inv_sum = 1.0f / wt_sum;

        for (int i = 0; i < experts_per_thread; i++) {
            output_weights[i] *= inv_sum;
        }
    }

    if constexpr (delayed_softmax) {
        softmax_warp_inplace<experts_per_thread, true>(output_weights, n_expert_used, threadIdx.x);
    }

#pragma unroll
    for (int i = 0; i < experts_per_thread; i++) {
        const int idx = i * WARP_SIZE + threadIdx.x;
        if (idx < n_expert_used) {
            weights[idx] = output_weights[i];
        }
    }

    if (!with_norm) {
        GGML_UNUSED(clamp_val);
    }
}
| 168 | |
// Host-side dispatcher: picks the template instantiation matching n_expert
// (must be a power of two <= 512) and launches on the context's stream.
template <bool with_norm, bool delayed_softmax = false>
static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                 const float *               logits,
                                 float *                     weights,
                                 int32_t *                   ids,
                                 const int                   n_rows,
                                 const int                   n_expert,
                                 const int                   n_expert_used,
                                 const float                 clamp_val) {
    static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization");
    const int    rows_per_block = 4;
    dim3         grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
    dim3         block_dims(WARP_SIZE, rows_per_block, 1);
    cudaStream_t stream = ctx.stream();

    // fix: labeled launch configuration `<<<gridDim: …, blockDim: …, sharedMem: 0, stream>>>`
    // is not valid CUDA C++ execution-configuration syntax
    switch (n_expert) {
        case 1:
            topk_moe_cuda<1, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
            break;
        case 2:
            topk_moe_cuda<2, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
            break;
        case 4:
            topk_moe_cuda<4, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
            break;
        case 8:
            topk_moe_cuda<8, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
            break;
        case 16:
            topk_moe_cuda<16, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
            break;
        case 32:
            topk_moe_cuda<32, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
            break;
        case 64:
            topk_moe_cuda<64, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
            break;
        case 128:
            topk_moe_cuda<128, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
            break;
        case 256:
            topk_moe_cuda<256, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
            break;
        case 512:
            topk_moe_cuda<512, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
            break;
        default:
            GGML_ASSERT(false && "fatal error");
            break;
    }
}
| 230 | |
// Entry point for the fused softmax -> top-k -> get_rows MoE routing op.
// `clamp` (optional, norm path only) carries the lower clamp bound for the
// weight sum in its op_params.
void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                           const ggml_tensor *         logits,
                           ggml_tensor *               weights,
                           ggml_tensor *               ids,
                           const bool                  with_norm,
                           const bool                  delayed_softmax,
                           ggml_tensor *               clamp) {
    GGML_ASSERT(logits->type == GGML_TYPE_F32);
    GGML_ASSERT(weights->type == GGML_TYPE_F32);
    GGML_ASSERT(ids->type == GGML_TYPE_I32);

    const int n_experts = logits->ne[0];
    const int n_rows    = logits->ne[1];

    const float * logits_d  = (const float *) logits->data;
    float *       weights_d = (float *) weights->data;
    int32_t *     ids_d     = (int32_t *) ids->data;

    // the kernel assumes a row stride of n_experts in ids
    GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);

    const int n_expert_used = weights->ne[1];

    float clamp_val = -INFINITY;
    // fix: labeled-argument call syntax `launch_topk_moe_cuda<true>(ctx, logits: logits_d, …)`
    // is not valid C++
    if (with_norm) {
        if (clamp) {
            clamp_val = ggml_get_op_params_f32(clamp, 0);
        }
        launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used, clamp_val);
    } else {
        GGML_ASSERT(clamp == nullptr);
        if (delayed_softmax) {
            launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
                                              clamp_val);
        } else {
            launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
                                               clamp_val);
        }
    }
}
| 270 | |
// Decide whether the softmax->argsort->... subgraph can be fused into the
// top-k MoE kernel: contiguous inputs, plain softmax (scale 1, no bias, no
// mask/sinks), power-of-two expert count <= 512, and (if present) a clamp
// that only bounds from below.
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp) {
    float scale    = 1.0f;
    float max_bias = 0.0f;

    // fix: labeled-argument syntax `memcpy(dest: …, src: …, n: …)` is not valid C++
    memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float));
    memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float));

    if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
        return false;
    }

    if (scale != 1.0f || max_bias != 0.0f) {
        return false;
    }

    // don't fuse when masks or sinks are present
    if (softmax->src[1] || softmax->src[2]) {
        return false;
    }

    const int n_expert = softmax->ne[0];
    // n_expert must be a power of 2
    if ((n_expert & (n_expert - 1)) != 0 || n_expert > 512) {
        return false;
    }

    if (clamp) {
        if (clamp->op != GGML_OP_CLAMP) {
            return false;
        }
        // only a lower bound is supported: the upper bound must be +inf
        float max_val = ggml_get_op_params_f32(clamp, 1);

        if (max_val != INFINITY) {
            return false;
        }
    }

    return true;
}
| 311 | |
// Returns the exact op sequence that the fusion pass must match for each
// top-k MoE variant (with normalization, plain, or delayed softmax).
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) {
    static std::initializer_list<enum ggml_op> ops_with_norm = {
        GGML_OP_SOFT_MAX, GGML_OP_RESHAPE,  GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS,
        GGML_OP_RESHAPE,  GGML_OP_SUM_ROWS, GGML_OP_CLAMP,   GGML_OP_DIV,  GGML_OP_RESHAPE
    };

    static std::initializer_list<enum ggml_op> ops_plain = {
        GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS
    };

    static std::initializer_list<enum ggml_op> ops_delayed = {
        GGML_OP_ARGSORT, GGML_OP_VIEW,     GGML_OP_GET_ROWS,
        GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE
    };

    // normalization and delayed softmax are mutually exclusive variants
    GGML_ASSERT(!norm || !delayed_softmax);

    if (delayed_softmax) {
        return ops_delayed;
    } else if (norm) {
        return ops_with_norm;
    } else {
        return ops_plain;
    }
}
| 337 | |