1#include "common.cuh"
2#include "mmid.cuh"
3
// To reduce shared memory use, pack "it" and "iex_used" into a single 32-bit
// word: the low 22 bits hold the token index, the high 10 bits hold the index
// at which the expert is used.
struct mm_ids_helper_store {
    uint32_t data; // bits [0, 22): it, bits [22, 32): iex_used

    __device__ mm_ids_helper_store(const uint32_t it, const uint32_t iex_used)
        : data((iex_used << 22) | (it & ((1u << 22) - 1))) {}

    // Token index (22 bits).
    __device__ uint32_t it() const {
        return data & ((1u << 22) - 1);
    }

    // Position within the per-token expert list (10 bits).
    __device__ uint32_t iex_used() const {
        return data >> 22;
    }
};
static_assert(sizeof(mm_ids_helper_store) == 4, "unexpected size for mm_ids_helper_store");
21
// Helper function for mul_mat_id, converts ids to a more convenient format.
// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert.
// ids_dst describes the same mapping but for the dst tensor.
// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1].
//
// Launch contract: one block per expert (blockIdx.x == expert), one warp per block,
// and n_tokens*sizeof(mm_ids_helper_store) bytes of dynamic shared memory.
// n_expert_used_template == 0 selects the generic path with a runtime expert count.
template <int n_expert_used_template>
__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mm_ids_helper(
        const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
        const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) {
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
    const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template;
    const int expert = blockIdx.x;

    extern __shared__ char data_mm_ids_helper[];
    mm_ids_helper_store * store = (mm_ids_helper_store *) data_mm_ids_helper;

    int nex_prev   = 0; // Number of columns for experts with a lower index.
    int it_compact = 0; // Running index for the compact slice of this expert.

    if constexpr (n_expert_used_template == 0) {
        // Generic implementation:
        for (int it = 0; it < n_tokens; ++it) {
            int iex_used = -1; // The index at which the expert is used, if any.
            for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) {
                const int expert_used = ids[it*si1 + iex];
                nex_prev += expert_used < expert;
                if (expert_used == expert) {
                    iex_used = iex;
                }
            }

            // At most one lane can have found this expert for this token, so the
            // single store per token is race-free.
            if (iex_used != -1) {
                store[it_compact] = mm_ids_helper_store(it, iex_used);
            }

            if (warp_reduce_any<warp_size>(iex_used != -1)) {
                it_compact++;
            }
        }
    } else {
        // Implementation optimized for specific numbers of experts used:
        // neu_padded is used as a template argument below, so it must be a compile-time
        // constant derived from n_expert_used_template (not the runtime n_expert_used).
        static_assert(n_expert_used_template == 6 || warp_size % n_expert_used_template == 0, "bad n_expert_used");
        constexpr int neu_padded = n_expert_used_template == 6 ? 8 : n_expert_used_template; // Padded to next higher power of 2.
        for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) {
            const int it = it0 + threadIdx.x / neu_padded;

            const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any.
            // Out-of-range lanes load INT_MAX so they never match any expert and never count towards nex_prev.
            const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ?
                ids[it*si1 + iex] : INT_MAX;
            const int iex_used = expert_used == expert ? iex : -1;
            nex_prev += expert_used < expert;

            // Whether the threads at this token position have used the expert:
            const int it_compact_add_self = warp_reduce_any<neu_padded>(iex_used != -1);

            // Do a scan over threads at lower token positions in warp to get the correct index for writing data:
            int it_compact_add_lower = 0;
#pragma unroll
            for (int offset = neu_padded; offset < warp_size; offset += neu_padded) {
                const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size);
                if (threadIdx.x >= static_cast<unsigned int>(offset)) {
                    it_compact_add_lower += tmp;
                }
            }

            if (iex_used != -1) {
                store[it_compact + it_compact_add_lower] = mm_ids_helper_store(it, iex_used);
            }

            // The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads:
            it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size);
        }
    }
    // Total number of compact columns belonging to experts with a lower index, summed over all lanes.
    nex_prev = warp_reduce_sum<warp_size>(nex_prev);

    // Write the compact mappings for this expert's slice [nex_prev, nex_prev + it_compact).
    for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) {
        const mm_ids_helper_store store_it = store[itc];
        const int it       = store_it.it();
        const int iex_used = store_it.iex_used();
        ids_src1[nex_prev + itc] = it*sis1          + iex_used % nchannels_y;
        ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;
    }

    if (threadIdx.x != 0) {
        return;
    }

    expert_bounds[expert] = nex_prev;

    // The last block additionally writes the closing bound.
    if (expert < static_cast<int>(gridDim.x) - 1) {
        return;
    }

    expert_bounds[gridDim.x] = nex_prev + it_compact;
}
117
// Launches mm_ids_helper with one block per expert and one warp per block.
// Preconditions: n_tokens fits in 22 bits and n_expert_used_var in 10 bits
// (see mm_ids_helper_store), and the dynamic shared memory required to hold
// one store entry per token fits within the device's opt-in limit.
template <int n_expert_used_template>
static void launch_mm_ids_helper(
        const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
        const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
    GGML_ASSERT(n_tokens          < (1 << 22) && "too few bits in mm_ids_helper_store");
    GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mm_ids_helper_store");

    const int    id        = ggml_cuda_get_device();
    const int    warp_size = ggml_cuda_info().devices[id].warp_size;
    const size_t smpbo     = ggml_cuda_info().devices[id].smpbo;
    // Opt in to the maximum dynamic shared memory per block for large token counts.
    CUDA_SET_SHARED_MEMORY_LIMIT(mm_ids_helper<n_expert_used_template>, smpbo);

    const dim3 num_blocks(n_experts, 1, 1);
    const dim3 block_size(warp_size, 1, 1);
    const size_t nbytes_shared = n_tokens*sizeof(mm_ids_helper_store);
    GGML_ASSERT(nbytes_shared <= smpbo);
    mm_ids_helper<n_expert_used_template><<<num_blocks, block_size, nbytes_shared, stream>>>
        (ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
}
137
// Public entry point: dispatches to a template specialization of the helper
// kernel for common values of n_expert_used; any other value falls back to
// the generic (n_expert_used_template == 0) implementation.
void ggml_cuda_launch_mm_ids_helper(
        const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
        const int n_experts, const int n_tokens, const int n_expert_used, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
    switch (n_expert_used) {
        case 2:
            launch_mm_ids_helper< 2>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
            break;
        case 4:
            launch_mm_ids_helper< 4>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
            break;
        case 6:
            launch_mm_ids_helper< 6>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
            break;
        case 8:
            launch_mm_ids_helper< 8>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
            break;
        case 16:
            launch_mm_ids_helper<16>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
            break;
        case 32:
            launch_mm_ids_helper<32>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
            break;
        default:
            launch_mm_ids_helper< 0>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
            break;
    }
}
165