#include "norm.cuh"
#include <cstdint>

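// Layer norm over rows: one block per (row, channel, sample). The block accumulates the sum and the
// sum of squares with a warp/block reduction, then writes (x - mean) * rsqrt(var + eps).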
template <int block_size>
static __global__ void norm_f32(
        const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
        const int64_t stride_sample, const float eps) {
    const int nrows     = gridDim.x;
    const int nchannels = gridDim.y;

    const int row     = blockIdx.x;
    const int channel = blockIdx.y;
    const int sample  = blockIdx.z;
    const int tid     = threadIdx.x;

    x   += sample*stride_sample + channel*stride_channel + row*stride_row;
    dst += ((sample*nchannels + channel)*nrows + row)*ncols;

    float2 mean_var = make_float2(0.0f, 0.0f);

    for (int col = tid; col < ncols; col += block_size) {
        const float xi = x[col];
        mean_var.x += xi;
        mean_var.y += xi * xi;
    }

    // sum up partial sums
    mean_var = warp_reduce_sum(mean_var);
    if constexpr (block_size > WARP_SIZE) {
        static_assert(block_size == 1024, "unexpected block_size");
        __shared__ float2 s_sum[32];
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = mean_var;
        }
        __syncthreads();
        mean_var = s_sum[lane_id];
        mean_var = warp_reduce_sum(mean_var);
    }

    const float mean    = mean_var.x / ncols;
    const float var     = mean_var.y / ncols - mean * mean;
    const float inv_std = rsqrtf(var + eps);

    for (int col = tid; col < ncols; col += block_size) {
        dst[col] = (x[col] - mean) * inv_std;
    }
}

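// Group norm: one block per group of contiguous elements. The mean and the variance are computed in
// two passes over the group (the centered values are staged in dst between passes), each pass ending
// in a warp/block reduction.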
template <int block_size>
static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {
    // blockIdx.x: num_groups idx
    // threadIdx.x: block_size idx
    const int start = blockIdx.x*group_size + threadIdx.x;
    const int end   = min(blockIdx.x*group_size + group_size, ne_elements);

    float tmp = 0.0f; // partial sum for thread in warp

    for (int j = start; j < end; j += block_size) {
        tmp += x[j];
    }

    tmp = warp_reduce_sum(tmp);
    if constexpr (block_size > WARP_SIZE) {
        static_assert(block_size == 1024, "unexpected block_size");
        __shared__ float s_sum[32];
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = tmp;
        }
        __syncthreads();
        tmp = s_sum[lane_id];
        tmp = warp_reduce_sum(tmp);
    }

    const float mean = tmp / group_size;
    tmp = 0.0f;

    for (int j = start; j < end; j += block_size) {
        const float xi = x[j] - mean;
        dst[j] = xi;
        tmp += xi * xi;
    }

    tmp = warp_reduce_sum(tmp);
    if (block_size > WARP_SIZE) {
        __shared__ float s_sum[32];
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = tmp;
        }
        __syncthreads();
        tmp = s_sum[lane_id];
        tmp = warp_reduce_sum(tmp);
    }

    const float variance = tmp / group_size;
    const float scale    = rsqrtf(variance + eps);
    for (int j = start; j < end; j += block_size) {
        dst[j] *= scale;
    }
}

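// RMS norm with optional fused multiply and add: one block per row reduces the sum of squares and
// rescales by rsqrt(mean + eps). The mul/add operands may be broadcast, so their sizes are passed as
// fastdiv-packed uint3 values and the per-element indices are computed with fastmodulo.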
template <int block_size, bool do_multiply = false, bool do_add = false>
static __global__ void rms_norm_f32(const float * x,
                                    float * dst,
                                    const int ncols,
                                    const int64_t stride_row,
                                    const int64_t stride_channel,
                                    const int64_t stride_sample,
                                    const float eps,
                                    const float * mul = nullptr,
                                    const int64_t mul_stride_row = 0,
                                    const int64_t mul_stride_channel = 0,
                                    const int64_t mul_stride_sample = 0,
                                    const uint3 mul_ncols_packed = make_uint3(0, 0, 0),
                                    const uint3 mul_nrows_packed = make_uint3(0, 0, 0),
                                    const uint3 mul_nchannels_packed = make_uint3(0, 0, 0),
                                    const uint3 mul_nsamples_packed = make_uint3(0, 0, 0),
                                    const float * add = nullptr,
                                    const int64_t add_stride_row = 0,
                                    const int64_t add_stride_channel = 0,
                                    const int64_t add_stride_sample = 0,
                                    const uint3 add_ncols_packed = make_uint3(0, 0, 0),
                                    const uint3 add_nrows_packed = make_uint3(0, 0, 0),
                                    const uint3 add_nchannels_packed = make_uint3(0, 0, 0),
                                    const uint3 add_nsamples_packed = make_uint3(0, 0, 0)) {
    const int nrows     = gridDim.x;
    const int nchannels = gridDim.y;

    const int row     = blockIdx.x;
    const int channel = blockIdx.y;
    const int sample  = blockIdx.z;
    const int tid     = threadIdx.x;

    static_assert(!do_add || do_multiply, "fusing add is not supported without multiplying");

    x   += sample*stride_sample + channel*stride_channel + row*stride_row;
    dst += ((sample*nchannels + channel)*nrows + row)*ncols;

    if constexpr (do_multiply) {
        const uint32_t mul_row     = fastmodulo(row, mul_nrows_packed);
        const uint32_t mul_channel = fastmodulo(channel, mul_nchannels_packed);
        const uint32_t mul_sample  = fastmodulo(sample, mul_nsamples_packed);
        mul += mul_sample * mul_stride_sample + mul_channel * mul_stride_channel + mul_row * mul_stride_row;
    }

    if constexpr (do_add) {
        const int add_row     = fastmodulo(row, add_nrows_packed);
        const int add_channel = fastmodulo(channel, add_nchannels_packed);
        const int add_sample  = fastmodulo(sample, add_nsamples_packed);
        add += add_sample * add_stride_sample + add_channel * add_stride_channel + add_row * add_stride_row;
    }

    float tmp = 0.0f; // partial sum for thread in warp

    for (int col = tid; col < ncols; col += block_size) {
        const float xi = x[col];
        tmp += xi * xi;
    }

    // sum up partial sums
    tmp = warp_reduce_sum(tmp);
    if constexpr (block_size > WARP_SIZE) {
        static_assert((block_size <= 1024) && (block_size % 32 == 0), "unexpected block_size");
        __shared__ float s_sum[32];
        const int warp_id = tid / WARP_SIZE;
        const int lane_id = tid % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = tmp;
        }
        __syncthreads();
        tmp = 0.0f;
        if (lane_id < (block_size / WARP_SIZE)) {
            tmp = s_sum[lane_id];
        }
        tmp = warp_reduce_sum(tmp);
    }

    const float mean  = tmp / ncols;
    const float scale = rsqrtf(mean + eps);

    for (int col = tid; col < ncols; col += block_size) {
        if constexpr (do_multiply && do_add) {
            const int mul_col = fastmodulo(col, mul_ncols_packed);
            const int add_col = fastmodulo(col, add_ncols_packed);
            dst[col] = scale * x[col] * mul[mul_col] + add[add_col];
        } else if constexpr (do_multiply) {
            const int mul_col = fastmodulo(col, mul_ncols_packed);
            dst[col] = scale * x[col] * mul[mul_col];
        } else {
            dst[col] = scale * x[col];
        }
    }
}

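// Backward pass of RMS norm. Two reductions are needed per row: sum(x*x) (as in the forward pass) and
// sum(x*grad), because the normalization mixes all elements of a row:
//   dst = grad * rsqrt(mean(x*x) + eps) - x * sum(x*grad) / ((sum(x*x) + ncols*eps) * sqrt(mean(x*x) + eps))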
template <int block_size>
static __global__ void rms_norm_back_f32(
        const float * grad, const float * xf, float * dst, const int ncols, const float eps) {
    const int row = blockIdx.x*blockDim.y + threadIdx.y;
    const int tid = threadIdx.x;

    grad += int64_t(row)*ncols;
    xf   += int64_t(row)*ncols;
    dst  += int64_t(row)*ncols;

    float sum_xx = 0.0f; // sum for squares of x, equivalent to forward pass
    float sum_xg = 0.0f; // sum for x * gradient, needed because RMS norm mixes inputs

    for (int col = tid; col < ncols; col += block_size) {
        const float xfi = xf[col];
        sum_xx += xfi * xfi;
        sum_xg += xfi * grad[col];
    }

    // sum up partial sums
    sum_xx = warp_reduce_sum(sum_xx);
    sum_xg = warp_reduce_sum(sum_xg);
    if constexpr (block_size > WARP_SIZE) {
        static_assert(block_size == 1024, "unexpected block_size");
        __shared__ float s_sum_xx[32];
        __shared__ float s_sum_xg[32];
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum_xx[warp_id] = sum_xx;
            s_sum_xg[warp_id] = sum_xg;
        }
        __syncthreads();

        sum_xx = s_sum_xx[lane_id];
        sum_xx = warp_reduce_sum(sum_xx);

        sum_xg = s_sum_xg[lane_id];
        sum_xg = warp_reduce_sum(sum_xg);
    }

    const float mean_eps = sum_xx / ncols + eps;
    const float sum_eps  = sum_xx + ncols*eps;

    const float scale_grad = rsqrtf(mean_eps);
    const float scale_x    = -scale_grad * sum_xg/sum_eps;

    for (int col = tid; col < ncols; col += block_size) {
        dst[col] = scale_grad*grad[col] + scale_x*xf[col];
    }
}

// template <int block_size>
// static __global__ void l2_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
//     const int row = blockIdx.x*blockDim.y + threadIdx.y;
//     const int tid = threadIdx.x;

//     float tmp = 0.0f; // partial sum for thread in warp

//     for (int col = tid; col < ncols; col += block_size) {
//         const float xi = x[row*ncols + col];
//         tmp += xi * xi;
//     }

//     // sum up partial sums
//     tmp = warp_reduce_sum(tmp);
//     if (block_size > WARP_SIZE) {
//         __shared__ float s_sum[32];
//         int warp_id = threadIdx.x / WARP_SIZE;
//         int lane_id = threadIdx.x % WARP_SIZE;
//         if (lane_id == 0) {
//             s_sum[warp_id] = tmp;
//         }
//         __syncthreads();
//         tmp = s_sum[lane_id];
//         tmp = warp_reduce_sum(tmp);
//     }

//     // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
//     const float scale = rsqrtf(fmaxf(tmp, eps * eps));

//     for (int col = tid; col < ncols; col += block_size) {
//         dst[row*ncols + col] = scale * x[row*ncols + col];
//     }
// }

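// L2 normalization: one block per row rescales the row to unit L2 norm,
// dst = x * rsqrt(max(sum(x*x), eps^2)) = x / max(||x||, eps).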
template <int block_size>
static __global__ void l2_norm_f32(
        const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
        const int64_t stride_sample, const float eps) {
    const int nrows     = gridDim.x;
    const int nchannels = gridDim.y;

    const int row     = blockIdx.x;
    const int channel = blockIdx.y;
    const int sample  = blockIdx.z;
    const int tid     = threadIdx.x;

    x   += sample*stride_sample + channel*stride_channel + row*stride_row;
    dst += ((sample*nchannels + channel)*nrows + row)*ncols;

    float tmp = 0.0f; // partial sum for thread in warp

    for (int col = tid; col < ncols; col += block_size) {
        const float xi = x[col];
        tmp += xi * xi;
    }

    // sum up partial sums
    tmp = warp_reduce_sum(tmp);
    if constexpr (block_size > WARP_SIZE) {
        static_assert(block_size == 1024, "unexpected block_size");
        __shared__ float s_sum[32];
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = tmp;
        }
        __syncthreads();
        tmp = s_sum[lane_id];
        tmp = warp_reduce_sum(tmp);
    }

    // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
    const float scale = rsqrtf(fmaxf(tmp, eps * eps));

    for (int col = tid; col < ncols; col += block_size) {
        dst[col] = scale * x[col];
    }
}

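// Host-side launchers. For the strided kernels the grid maps (rows, channels, samples) onto
// (blockIdx.x, .y, .z); short rows use a small block (WARP_SIZE or 256 threads), longer rows a full
// 1024-thread block so the reduction spans multiple warps.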
static void norm_f32_cuda(
        const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
    const dim3 blocks_num(nrows, nchannels, nsamples);
    if (ncols < 1024) {
        const dim3 block_dims(WARP_SIZE, 1, 1);
        norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
    } else {
        const dim3 block_dims(1024, 1, 1);
        norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
    }
}

static void group_norm_f32_cuda(
        const float * x, float * dst, const int num_groups, const float eps, const int group_size, const int ne_elements, cudaStream_t stream) {
    if (group_size < 1024) {
        const dim3 block_dims(WARP_SIZE, 1, 1);
        group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
    } else {
        const dim3 block_dims(1024, 1, 1);
        group_norm_f32<1024><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
    }
}

static void rms_norm_f32_cuda(
        const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
    const dim3 blocks_num(nrows, nchannels, nsamples);
    if (ncols < 1024) {
        const dim3 block_dims(256, 1, 1);
        rms_norm_f32<256, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
    } else {
        const dim3 block_dims(1024, 1, 1);
        rms_norm_f32<1024, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
    }
}

static void rms_norm_mul_f32_cuda(const float * x,
                                  const float * mul,
                                  const float * add,
                                  float * dst,
                                  const int ncols,
                                  const int nrows,
                                  const int nchannels,
                                  const int nsamples,
                                  const int64_t stride_row,
                                  const int64_t stride_channel,
                                  const int64_t stride_sample,
                                  const int64_t mul_stride_row,
                                  const int64_t mul_stride_channel,
                                  const int64_t mul_stride_sample,
                                  const uint32_t mul_ncols,
                                  const uint32_t mul_nrows,
                                  const uint32_t mul_nchannels,
                                  const uint32_t mul_nsamples,
                                  const int64_t add_stride_row,
                                  const int64_t add_stride_channel,
                                  const int64_t add_stride_sample,
                                  const uint32_t add_ncols,
                                  const uint32_t add_nrows,
                                  const uint32_t add_nchannels,
                                  const uint32_t add_nsamples,
                                  const float eps,
                                  cudaStream_t stream) {
    const dim3 blocks_num(nrows, nchannels, nsamples);
    if (mul == nullptr) {
        rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream);
        return;
    }
    if (add == nullptr) {
        const uint3 mul_ncols_packed     = init_fastdiv_values(mul_ncols);
        const uint3 mul_nrows_packed     = init_fastdiv_values(mul_nrows);
        const uint3 mul_nchannels_packed = init_fastdiv_values(mul_nchannels);
        const uint3 mul_nsamples_packed  = init_fastdiv_values(mul_nsamples);
        if (ncols < 1024) {
            const dim3 block_dims(256, 1, 1);
            rms_norm_f32<256, true><<<blocks_num, block_dims, 0, stream>>>(
                x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
                mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
        } else {
            const dim3 block_dims(1024, 1, 1);
            rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(
                x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
                mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
        }
    } else {
        const uint3 mul_ncols_packed     = init_fastdiv_values(mul_ncols);
        const uint3 mul_nrows_packed     = init_fastdiv_values(mul_nrows);
        const uint3 mul_nchannels_packed = init_fastdiv_values(mul_nchannels);
        const uint3 mul_nsamples_packed  = init_fastdiv_values(mul_nsamples);

        const uint3 add_ncols_packed     = init_fastdiv_values(add_ncols);
        const uint3 add_nrows_packed     = init_fastdiv_values(add_nrows);
        const uint3 add_nchannels_packed = init_fastdiv_values(add_nchannels);
        const uint3 add_nsamples_packed  = init_fastdiv_values(add_nsamples);
        if (ncols < 1024) {
            const dim3 block_dims(256, 1, 1);
            rms_norm_f32<256, true, true><<<blocks_num, block_dims, 0, stream>>>(
                x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
                mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
                add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
                add_nchannels_packed, add_nsamples_packed);
        } else {
            const dim3 block_dims(1024, 1, 1);
            rms_norm_f32<1024, true, true><<<blocks_num, block_dims, 0, stream>>>(
                x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
                mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
                add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
                add_nchannels_packed, add_nsamples_packed);
        }
    }
}

static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
    if (ncols < 1024) {
        const dim3 block_dims(WARP_SIZE, 1, 1);
        rms_norm_back_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(grad, xf, dst, ncols, eps);
    } else {
        const dim3 block_dims(1024, 1, 1);
        rms_norm_back_f32<1024><<<nrows, block_dims, 0, stream>>>(grad, xf, dst, ncols, eps);
    }
}

static void l2_norm_f32_cuda(
        const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
    const dim3 blocks_num(nrows, nchannels, nsamples);
    if (ncols < 1024) {
        const dim3 block_dims(WARP_SIZE, 1, 1);
        l2_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
    } else {
        const dim3 block_dims(1024, 1, 1);
        l2_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
    }
}

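// ggml operator entry points: validate types, unpack op_params (eps, num_groups) and convert the byte
// strides nb[i] into element strides before dispatching to the launchers above.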
void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *) src0->data;
    float * dst_d = (float *) dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    GGML_TENSOR_UNARY_OP_LOCALS;

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));
    GGML_ASSERT(eps >= 0.0f);

    const size_t ts0 = ggml_type_size(src0->type);
    GGML_ASSERT(nb00 == ts0);
    const int64_t s01 = nb01 / ts0;
    const int64_t s02 = nb02 / ts0;
    const int64_t s03 = nb03 / ts0;

    norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
}

void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *) src0->data;
    float * dst_d = (float *) dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    int num_groups = dst->op_params[0];

    float eps;
    memcpy(&eps, dst->op_params + 1, sizeof(float));
    GGML_ASSERT(eps >= 0.0f);

    int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
    group_norm_f32_cuda(src0_d, dst_d, num_groups * src0->ne[3], eps, group_size, ggml_nelements(src0), stream);
}

void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *) src0->data;
    float * dst_d = (float *) dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    GGML_TENSOR_UNARY_OP_LOCALS;

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));
    GGML_ASSERT(eps >= 0.0f);

    const size_t ts0 = ggml_type_size(src0->type);
    GGML_ASSERT(nb00 == ts0);
    const int64_t s01 = nb01 / ts0;
    const int64_t s02 = nb02 / ts0;
    const int64_t s03 = nb03 / ts0;

    rms_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
}

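// Fused RMS_NORM + MUL: dst is the RMS_NORM node and mul_tensor is the MUL node that consumes it;
// the result is written directly to mul_tensor->data, so the intermediate RMS_NORM output is not
// written separately.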
void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor) {
    const ggml_tensor * rms_norm_src = (ggml_tensor *) dst->src[0];
    float eps = 0.0f;

    memcpy(&eps, dst->op_params, sizeof(float));

    const float * src0_d = (const float *) rms_norm_src->data;
    const float * mul_d = nullptr;
    const ggml_tensor * mul_src = nullptr;

    if (mul_tensor->src[0] == dst) {
        mul_d = (float *) mul_tensor->src[1]->data;
        mul_src = mul_tensor->src[1];
    } else if (mul_tensor->src[1] == dst) {
        mul_d = (float *) mul_tensor->src[0]->data;
        mul_src = mul_tensor->src[0];
    } else {
        GGML_ASSERT(false);
    }

    float * dst_d = (float *) mul_tensor->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(rms_norm_src->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
    GGML_ASSERT(mul_tensor->type == GGML_TYPE_F32);
    GGML_ASSERT(eps >= 0.0f);

    const int64_t ne00 = rms_norm_src->ne[0];
    const int64_t ne01 = rms_norm_src->ne[1];
    const int64_t ne02 = rms_norm_src->ne[2];
    const int64_t ne03 = rms_norm_src->ne[3];

    const size_t ts0 = ggml_type_size(rms_norm_src->type);
    GGML_ASSERT(rms_norm_src->nb[0] == ts0);
    const int64_t s01 = rms_norm_src->nb[1] / ts0;
    const int64_t s02 = rms_norm_src->nb[2] / ts0;
    const int64_t s03 = rms_norm_src->nb[3] / ts0;

    const size_t ts_mul = ggml_type_size(mul_src->type);
    GGML_ASSERT(mul_src->nb[0] == ts_mul);
    const int64_t mul_s01 = mul_src->nb[1] / ts_mul;
    const int64_t mul_s02 = mul_src->nb[2] / ts_mul;
    const int64_t mul_s03 = mul_src->nb[3] / ts_mul;

    const int mul_ncols = mul_src->ne[0];
    const int mul_nrows = mul_src->ne[1];
    const int mul_nchannels = mul_src->ne[2];
    const int mul_nsamples = mul_src->ne[3];

    rms_norm_mul_f32_cuda(src0_d, mul_d, nullptr, dst_d,
                          ne00, ne01, ne02, ne03,
                          /*s00*/ s01, s02, s03,
                          /*mul_s00*/ mul_s01, mul_s02, mul_s03,
                          mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
                          /*add_s00*/ 0, 0, 0,
                          0, 0, 0, 0,
                          eps, stream);
}

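// Fused RMS_NORM + MUL + ADD: as above, but the ADD node that consumes the MUL result supplies a
// second operand and receives the final output (add_tensor->data).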
void ggml_cuda_op_rms_norm_fused_add(ggml_backend_cuda_context & ctx,
                                     ggml_tensor * dst,
                                     ggml_tensor * mul_tensor,
                                     ggml_tensor * add_tensor) {
    const ggml_tensor * rms_norm_src = (ggml_tensor *) dst->src[0];
    float eps = 0.0f;

    memcpy(&eps, dst->op_params, sizeof(float));

    const float * src0_d = (const float *) rms_norm_src->data;
    const float * mul_d = nullptr;
    const ggml_tensor * mul_src = nullptr;

    if (mul_tensor->src[0] == dst) {
        mul_d = (float *) mul_tensor->src[1]->data;
        mul_src = mul_tensor->src[1];
    } else if (mul_tensor->src[1] == dst) {
        mul_d = (float *) mul_tensor->src[0]->data;
        mul_src = mul_tensor->src[0];
    } else {
        GGML_ASSERT(false);
    }

    const float * add_d = nullptr;
    const ggml_tensor * add_src = nullptr;

    if (add_tensor->src[0] == mul_tensor) {
        add_d = (float *) add_tensor->src[1]->data;
        add_src = add_tensor->src[1];
    } else if (add_tensor->src[1] == mul_tensor) {
        add_d = (float *) add_tensor->src[0]->data;
        add_src = add_tensor->src[0];
    } else {
        GGML_ASSERT(false);
    }

    float * dst_d = (float *) add_tensor->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(rms_norm_src->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
    GGML_ASSERT(mul_tensor->type == GGML_TYPE_F32);
    GGML_ASSERT(add_tensor->type == GGML_TYPE_F32);
    GGML_ASSERT(eps >= 0.0f);

    const int64_t ne00 = rms_norm_src->ne[0];
    const int64_t ne01 = rms_norm_src->ne[1];
    const int64_t ne02 = rms_norm_src->ne[2];
    const int64_t ne03 = rms_norm_src->ne[3];

    const size_t ts0 = ggml_type_size(rms_norm_src->type);
    GGML_ASSERT(rms_norm_src->nb[0] == ts0);
    const int64_t s01 = rms_norm_src->nb[1] / ts0;
    const int64_t s02 = rms_norm_src->nb[2] / ts0;
    const int64_t s03 = rms_norm_src->nb[3] / ts0;

    const size_t ts_mul = ggml_type_size(mul_src->type);
    GGML_ASSERT(mul_src->nb[0] == ts_mul);
    const int64_t mul_s01 = mul_src->nb[1] / ts_mul;
    const int64_t mul_s02 = mul_src->nb[2] / ts_mul;
    const int64_t mul_s03 = mul_src->nb[3] / ts_mul;

    const int mul_ncols = mul_src->ne[0];
    const int mul_nrows = mul_src->ne[1];
    const int mul_nchannels = mul_src->ne[2];
    const int mul_nsamples = mul_src->ne[3];

    const size_t ts_add = ggml_type_size(add_src->type);
    GGML_ASSERT(add_src->nb[0] == ts_add);
    const int64_t add_s01 = add_src->nb[1] / ts_add;
    const int64_t add_s02 = add_src->nb[2] / ts_add;
    const int64_t add_s03 = add_src->nb[3] / ts_add;

    const int add_ncols = add_src->ne[0];
    const int add_nrows = add_src->ne[1];
    const int add_nchannels = add_src->ne[2];
    const int add_nsamples = add_src->ne[3];

    rms_norm_mul_f32_cuda(src0_d, mul_d, add_d, dst_d,
                          ne00, ne01, ne02, ne03,
                          /*s00*/ s01, s02, s03,
                          /*mul_s00*/ mul_s01, mul_s02, mul_s03,
                          mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
                          /*add_s00*/ add_s01, add_s02, add_s03,
                          add_ncols, add_nrows, add_nchannels, add_nsamples,
                          eps, stream);
}

void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * grad  = dst->src[0]; // gradients
    const ggml_tensor * src0f = dst->src[1]; // src0 from forward pass

    const float * grad_d  = (const float *) grad->data;
    const float * src0f_d = (const float *) src0f->data;
    float * dst_d = (float *) dst->data;

    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(ggml_is_contiguous(grad));

    GGML_ASSERT( grad->type == GGML_TYPE_F32);
    GGML_ASSERT(src0f->type == GGML_TYPE_F32);
    GGML_ASSERT(  dst->type == GGML_TYPE_F32);

    const int64_t ne00  = src0f->ne[0];
    const int64_t nrows = ggml_nrows(src0f);

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));
    GGML_ASSERT(eps >= 0.0f);

    rms_norm_back_f32_cuda(grad_d, src0f_d, dst_d, ne00, nrows, eps, stream);
}

void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *) src0->data;
    float * dst_d = (float *) dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    GGML_TENSOR_UNARY_OP_LOCALS;

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));
    GGML_ASSERT(eps >= 0.0f);

    const size_t ts0 = ggml_type_size(src0->type);
    GGML_ASSERT(nb00 == ts0);
    const int64_t s01 = nb01 / ts0;
    const int64_t s02 = nb02 / ts0;
    const int64_t s03 = nb03 / ts0;

    l2_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
}