#include "ggml-cuda/common.cuh"
#include "ggml.h"
#include "topk-moe.cuh"

#include <cmath>
#include <initializer_list>

// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
template <int experts_per_thread, bool use_limit>
__device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
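    // Each lane owns experts_per_thread values, strided by WARP_SIZE across the warp
    // (slot i on lane `lane` corresponds to index lane + i * WARP_SIZE). When use_limit
    // is set, indices >= limit are treated as inactive and excluded from the reduction.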
    float max_val = -INFINITY;

#pragma unroll
    for (int i = 0; i < experts_per_thread; i++) {
        const int  idx    = lane + i * WARP_SIZE;
        const bool active = !use_limit || (idx < limit);
        if (active) {
            max_val = max(max_val, vals[i]);
        }
    }

    max_val = warp_reduce_max(max_val);
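    // Exponentiate relative to the warp-wide max for numerical stability;
    // inactive slots are zeroed so they do not contribute to the sum.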
    float sum = 0.f;

#pragma unroll
    for (int i = 0; i < experts_per_thread; i++) {
        const int  idx    = lane + i * WARP_SIZE;
        const bool active = !use_limit || (idx < limit);
        if (active) {
            const float val = expf(vals[i] - max_val);
            vals[i]         = val;
            sum += val;
        } else {
            vals[i] = 0.f;
        }
    }

    sum = warp_reduce_sum(sum);

    const float inv_sum = 1.0f / sum;
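    // Scale the active values by 1/sum to produce the normalized distribution.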
#pragma unroll
    for (int i = 0; i < experts_per_thread; i++) {
        const int  idx    = lane + i * WARP_SIZE;
        const bool active = !use_limit || (idx < limit);
        if (active) {
            vals[i] *= inv_sum;
        }
    }
}

/*
    This kernel does the following:
    1. optionally softmax over the logits per token [n_experts, n_tokens]
    2. argmax reduce over the top-k (n_expert_used) logits
    3. write weights + ids to global memory
    4. optionally normalize the weights or apply softmax over the selected logits

    It is intended as a fusion of the softmax->top-k->get_rows pipeline for MoE models
*/
template <int n_experts, bool with_norm, bool delayed_softmax = false>
__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
                                                                  float *       weights,
                                                                  int32_t *     ids,
                                                                  const int     n_rows,
                                                                  const int     n_expert_used,
                                                                  const float   clamp_val) {
    const int row = blockIdx.x * blockDim.y + threadIdx.y;
    if (row >= n_rows) {
        return;
    }

    logits  += n_experts * row;
    weights += n_expert_used * row;
    ids     += n_experts * row;
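    // Each lane owns n_experts / WARP_SIZE expert slots (at least one), strided by
    // WARP_SIZE across the warp, mirroring the layout used by softmax_warp_inplace.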
    constexpr int experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;

    float wt[experts_per_thread];

#pragma unroll
    for (int i = 0; i < n_experts; i += WARP_SIZE) {
        const int expert  = i + threadIdx.x;
        wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[expert] : -INFINITY;
    }

    if constexpr (!delayed_softmax) {
        softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
    }

    // At this point each thread holds either a portion of the softmax distribution
    // or the raw logits (delayed softmax). We do an argmax reduction n_expert_used times,
    // each time marking the selected expert's weight as -inf to exclude it from the next iteration.
    float wt_sum = 0.f;

    float output_weights[experts_per_thread];

#pragma unroll
    for (int i = 0; i < experts_per_thread; i++) {
        output_weights[i] = 0.f;
    }

    for (int k = 0; k < n_expert_used; k++) {
        float max_val    = wt[0];
        int   max_expert = threadIdx.x;

#pragma unroll
        for (int i = 1; i < experts_per_thread; i++) {
            const int expert = threadIdx.x + i * WARP_SIZE;
            if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
                max_val    = wt[i];
                max_expert = expert;
            }
        }
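        // Warp-wide argmax via butterfly shuffles; ties are broken toward the lower
        // expert index so every lane agrees on the same winner.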
#pragma unroll
        for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
            const float val    = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
            const int   expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
            if (val > max_val || (val == max_val && expert < max_expert)) {
                max_val    = val;
                max_expert = expert;
            }
        }
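        // The k-th selected weight is kept by lane (k % WARP_SIZE) in its local output_weights slot.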
        if ((k & (WARP_SIZE - 1)) == threadIdx.x) {
            output_weights[k / WARP_SIZE] = max_val;
        }
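        // The lane that owns the winning expert masks it out for the next iteration,
        // writes its id, and accumulates the sum used for optional normalization.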
        if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
            wt[max_expert / WARP_SIZE] = -INFINITY;

            ids[k] = max_expert;
            if constexpr (with_norm) {
                wt_sum += max_val;
            }
        }
    }
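    // Normalize the selected weights; clamp_val bounds the denominator from below
    // (it is the min parameter of the fused CLAMP op).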
    if constexpr (with_norm) {
        wt_sum = warp_reduce_sum(wt_sum);
        wt_sum = max(wt_sum, clamp_val);
        const float inv_sum = 1.0f / wt_sum;

        for (int i = 0; i < experts_per_thread; i++) {
            output_weights[i] *= inv_sum;
        }
    }
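    // Delayed softmax: normalize over just the n_expert_used selected logits.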
    if constexpr (delayed_softmax) {
        softmax_warp_inplace<experts_per_thread, true>(output_weights, n_expert_used, threadIdx.x);
    }

#pragma unroll
    for (int i = 0; i < experts_per_thread; i++) {
        const int idx = i * WARP_SIZE + threadIdx.x;
        if (idx < n_expert_used) {
            weights[idx] = output_weights[i];
        }
    }

    if (!with_norm) {
        GGML_UNUSED(clamp_val);
    }
}

template <bool with_norm, bool delayed_softmax = false>
static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                 const float *               logits,
                                 float *                     weights,
                                 int32_t *                   ids,
                                 const int                   n_rows,
                                 const int                   n_expert,
                                 const int                   n_expert_used,
                                 const float                 clamp_val) {
    static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization");
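    // One warp per row, 4 rows per block (matches __launch_bounds__(4 * WARP_SIZE, 1)).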
    const int rows_per_block = 4;
    dim3      grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
    dim3      block_dims(WARP_SIZE, rows_per_block, 1);
    cudaStream_t stream = ctx.stream();

    switch (n_expert) {
        case 1:
            topk_moe_cuda<1, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
            break;
        case 2:
            topk_moe_cuda<2, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
            break;
        case 4:
            topk_moe_cuda<4, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
            break;
        case 8:
            topk_moe_cuda<8, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
            break;
        case 16:
            topk_moe_cuda<16, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
            break;
        case 32:
            topk_moe_cuda<32, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
            break;
        case 64:
            topk_moe_cuda<64, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
            break;
        case 128:
            topk_moe_cuda<128, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
            break;
        case 256:
            topk_moe_cuda<256, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
            break;
        case 512:
            topk_moe_cuda<512, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
            break;
        default:
            GGML_ASSERT(false && "fatal error");
            break;
    }
}

void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                           const ggml_tensor *         logits,
                           ggml_tensor *               weights,
                           ggml_tensor *               ids,
                           const bool                  with_norm,
                           const bool                  delayed_softmax,
                           ggml_tensor *               clamp) {
    GGML_ASSERT(logits->type == GGML_TYPE_F32);
    GGML_ASSERT(weights->type == GGML_TYPE_F32);
    GGML_ASSERT(ids->type == GGML_TYPE_I32);

    const int n_experts = logits->ne[0];
    const int n_rows    = logits->ne[1];

    const float * logits_d  = (const float *) logits->data;
    float *       weights_d = (float *) weights->data;
    int32_t *     ids_d     = (int32_t *) ids->data;
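    // ids keeps the full [n_experts, n_tokens] row stride even though only the first
    // n_expert_used entries of each row are written by the kernel.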
    GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);

    const int n_expert_used = weights->ne[1];
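    // -INFINITY makes max(wt_sum, clamp_val) a no-op when no CLAMP op is fused.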
    float clamp_val = -INFINITY;
    if (with_norm) {
        if (clamp) {
            clamp_val = ggml_get_op_params_f32(clamp, 0);
        }
        launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used, clamp_val);
    } else {
        GGML_ASSERT(clamp == nullptr);
        if (delayed_softmax) {
            launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
                                              clamp_val);
        } else {
            launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
                                               clamp_val);
        }
    }
}

bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp) {
    float scale    = 1.0f;
    float max_bias = 0.0f;

    memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float));
    memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float));

    if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
        return false;
    }

    if (scale != 1.0f || max_bias != 0.0f) {
        return false;
    }

    // don't fuse when masks or sinks are present
    if (softmax->src[1] || softmax->src[2]) {
        return false;
    }

    const int n_expert = softmax->ne[0];
    // n_expert must be a power of 2
    if ((n_expert & (n_expert - 1)) != 0 || n_expert > 512) {
        return false;
    }
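    // A fused clamp must be a pure lower bound: a CLAMP op whose max is +INF,
    // so only its min parameter is consumed as clamp_val by the kernel.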
    if (clamp) {
        if (clamp->op != GGML_OP_CLAMP) {
            return false;
        }
        float max_val = ggml_get_op_params_f32(clamp, 1);

        if (max_val != INFINITY) {
            return false;
        }
    }

    return true;
}
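// Op lists describing the fusable topk-moe patterns (see the kernel comment above):
// with weight normalization, without, and with delayed softmax.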
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) {
    static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                                            GGML_OP_VIEW,     GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
                                                            GGML_OP_SUM_ROWS, GGML_OP_CLAMP,    GGML_OP_DIV,
                                                            GGML_OP_RESHAPE };

    static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                                               GGML_OP_VIEW,     GGML_OP_GET_ROWS };

    static std::initializer_list<enum ggml_op> delayed_softmax_ops = { GGML_OP_ARGSORT,  GGML_OP_VIEW,
                                                                       GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
                                                                       GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };

    GGML_ASSERT(!norm || !delayed_softmax);

    if (delayed_softmax) {
        return delayed_softmax_ops;
    }

    if (norm) {
        return norm_ops;
    }

    return no_norm_ops;
}