| 1 | #include "norm.cuh" |
| 2 | #include <cstdint> |
| 3 | |
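// norm_f32 normalizes one row per block: the block reduces the row's sum and sum of
// squares to get mean and variance, then writes (x - mean) * rsqrt(variance + eps).
// The grid is (nrows, nchannels, nsamples); the stride arguments allow non-contiguous
// source tensors, while the destination is written contiguously.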
template <int block_size>
static __global__ void norm_f32(
        const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
        const int64_t stride_sample, const float eps) {
    const int nrows     = gridDim.x;
    const int nchannels = gridDim.y;

    const int row     = blockIdx.x;
    const int channel = blockIdx.y;
    const int sample  = blockIdx.z;
    const int tid     = threadIdx.x;

    x   += sample*stride_sample + channel*stride_channel + row*stride_row;
    dst += ((sample*nchannels + channel)*nrows + row)*ncols;

    float2 mean_var = make_float2(0.0f, 0.0f);

    for (int col = tid; col < ncols; col += block_size) {
        const float xi = x[col];
        mean_var.x += xi;
        mean_var.y += xi * xi;
    }

    // sum up partial sums
    mean_var = warp_reduce_sum(mean_var);
    if constexpr (block_size > WARP_SIZE) {
        static_assert(block_size == 1024, "unexpected block_size");
        __shared__ float2 s_sum[32];
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = mean_var;
        }
        __syncthreads();
        mean_var = s_sum[lane_id];
        mean_var = warp_reduce_sum(mean_var);
    }

    const float mean    = mean_var.x / ncols;
    const float var     = mean_var.y / ncols - mean * mean;
    const float inv_std = rsqrtf(var + eps);

    for (int col = tid; col < ncols; col += block_size) {
        dst[col] = (x[col] - mean) * inv_std;
    }
}

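// group_norm_f32 normalizes one group of contiguous elements per block: a first block-wide
// reduction yields the group mean, a second yields the variance of the centered values
// (stored temporarily in dst), and the centered values are then scaled by rsqrt(variance + eps).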
template <int block_size>
static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {
    // blockIdx.x: num_groups idx
    // threadIdx.x: block_size idx
    const int start = blockIdx.x*group_size + threadIdx.x;
    const int end   = min(blockIdx.x*group_size + group_size, ne_elements);

    float tmp = 0.0f; // partial sum for thread in warp

    for (int j = start; j < end; j += block_size) {
        tmp += x[j];
    }

    tmp = warp_reduce_sum(tmp);
    if constexpr (block_size > WARP_SIZE) {
        static_assert(block_size == 1024, "unexpected block_size");
        __shared__ float s_sum[32];
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = tmp;
        }
        __syncthreads();
        tmp = s_sum[lane_id];
        tmp = warp_reduce_sum(tmp);
    }

    const float mean = tmp / group_size;
    tmp = 0.0f;

    for (int j = start; j < end; j += block_size) {
        const float xi = x[j] - mean;
        dst[j] = xi;
        tmp += xi * xi;
    }

    tmp = warp_reduce_sum(tmp);
    if (block_size > WARP_SIZE) {
        __shared__ float s_sum[32];
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = tmp;
        }
        __syncthreads();
        tmp = s_sum[lane_id];
        tmp = warp_reduce_sum(tmp);
    }

    const float variance = tmp / group_size;
    const float scale    = rsqrtf(variance + eps);
    for (int j = start; j < end; j += block_size) {
        dst[j] *= scale;
    }
}

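// rms_norm_f32 scales each row by rsqrt(mean(x^2) + eps). With do_multiply/do_add the
// elementwise multiply (and optional add) that follows the norm can be fused into the same
// kernel; the *_packed fastdiv values implement broadcasting by reducing each index modulo
// the corresponding dimension of the mul/add operand.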
template <int block_size, bool do_multiply = false, bool do_add = false>
static __global__ void rms_norm_f32(const float * x,
                                    float * dst,
                                    const int ncols,
                                    const int64_t stride_row,
                                    const int64_t stride_channel,
                                    const int64_t stride_sample,
                                    const float eps,
                                    const float * mul = nullptr,
                                    const int64_t mul_stride_row = 0,
                                    const int64_t mul_stride_channel = 0,
                                    const int64_t mul_stride_sample = 0,
                                    const uint3 mul_ncols_packed = make_uint3(0, 0, 0),
                                    const uint3 mul_nrows_packed = make_uint3(0, 0, 0),
                                    const uint3 mul_nchannels_packed = make_uint3(0, 0, 0),
                                    const uint3 mul_nsamples_packed = make_uint3(0, 0, 0),
                                    const float * add = nullptr,
                                    const int64_t add_stride_row = 0,
                                    const int64_t add_stride_channel = 0,
                                    const int64_t add_stride_sample = 0,
                                    const uint3 add_ncols_packed = make_uint3(0, 0, 0),
                                    const uint3 add_nrows_packed = make_uint3(0, 0, 0),
                                    const uint3 add_nchannels_packed = make_uint3(0, 0, 0),
                                    const uint3 add_nsamples_packed = make_uint3(0, 0, 0)) {
    const int nrows     = gridDim.x;
    const int nchannels = gridDim.y;

    const int row     = blockIdx.x;
    const int channel = blockIdx.y;
    const int sample  = blockIdx.z;
    const int tid     = threadIdx.x;

    static_assert(!do_add || do_multiply, "fusing add is not supported without multiplying");

    x   += sample*stride_sample + channel*stride_channel + row*stride_row;
    dst += ((sample*nchannels + channel)*nrows + row)*ncols;

    if constexpr (do_multiply) {
        const uint32_t mul_row     = fastmodulo(row, mul_nrows_packed);
        const uint32_t mul_channel = fastmodulo(channel, mul_nchannels_packed);
        const uint32_t mul_sample  = fastmodulo(sample, mul_nsamples_packed);
        mul += mul_sample * mul_stride_sample + mul_channel * mul_stride_channel + mul_row * mul_stride_row;
    }

    if constexpr (do_add) {
        const int add_row     = fastmodulo(row, add_nrows_packed);
        const int add_channel = fastmodulo(channel, add_nchannels_packed);
        const int add_sample  = fastmodulo(sample, add_nsamples_packed);
        add += add_sample * add_stride_sample + add_channel * add_stride_channel + add_row * add_stride_row;
    }

    float tmp = 0.0f; // partial sum for thread in warp

    for (int col = tid; col < ncols; col += block_size) {
        const float xi = x[col];
        tmp += xi * xi;
    }

    // sum up partial sums
    tmp = warp_reduce_sum(tmp);
    if constexpr (block_size > WARP_SIZE) {
        static_assert((block_size <= 1024) && (block_size % 32 == 0), "unexpected block_size");
        __shared__ float s_sum[32];
        const int warp_id = tid / WARP_SIZE;
        const int lane_id = tid % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = tmp;
        }
        __syncthreads();
        tmp = 0.0f;
        if (lane_id < (block_size / WARP_SIZE)) {
            tmp = s_sum[lane_id];
        }
        tmp = warp_reduce_sum(tmp);
    }

    const float mean  = tmp / ncols;
    const float scale = rsqrtf(mean + eps);

    for (int col = tid; col < ncols; col += block_size) {
        if constexpr (do_multiply && do_add) {
            const int mul_col = fastmodulo(col, mul_ncols_packed);
            const int add_col = fastmodulo(col, add_ncols_packed);
            dst[col] = scale * x[col] * mul[mul_col] + add[add_col];
        } else if constexpr (do_multiply) {
            const int mul_col = fastmodulo(col, mul_ncols_packed);
            dst[col] = scale * x[col] * mul[mul_col];
        } else {
            dst[col] = scale * x[col];
        }
    }
}

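// rms_norm_back_f32 computes the gradient of RMS norm. With m = sum(x^2)/ncols + eps the
// forward pass is y = x * rsqrt(m), so
//   dL/dx = grad * m^(-1/2) - x * sum(grad*x) / (ncols * m^(3/2)),
// which is what scale_grad and scale_x implement below.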
template <int block_size>
static __global__ void rms_norm_back_f32(
        const float * grad, const float * xf, float * dst, const int ncols, const float eps) {
    const int row = blockIdx.x*blockDim.y + threadIdx.y;
    const int tid = threadIdx.x;

    grad += int64_t(row)*ncols;
    xf   += int64_t(row)*ncols;
    dst  += int64_t(row)*ncols;

    float sum_xx = 0.0f; // sum for squares of x, equivalent to forward pass
    float sum_xg = 0.0f; // sum for x * gradient, needed because RMS norm mixes inputs

    for (int col = tid; col < ncols; col += block_size) {
        const float xfi = xf[col];
        sum_xx += xfi * xfi;
        sum_xg += xfi * grad[col];
    }

    // sum up partial sums
    sum_xx = warp_reduce_sum(sum_xx);
    sum_xg = warp_reduce_sum(sum_xg);
    if constexpr (block_size > WARP_SIZE) {
        static_assert(block_size == 1024, "unexpected block_size");
        __shared__ float s_sum_xx[32];
        __shared__ float s_sum_xg[32];
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum_xx[warp_id] = sum_xx;
            s_sum_xg[warp_id] = sum_xg;
        }
        __syncthreads();

        sum_xx = s_sum_xx[lane_id];
        sum_xx = warp_reduce_sum(sum_xx);

        sum_xg = s_sum_xg[lane_id];
        sum_xg = warp_reduce_sum(sum_xg);
    }

    const float mean_eps = sum_xx / ncols + eps;
    const float sum_eps  = sum_xx + ncols*eps;

    const float scale_grad = rsqrtf(mean_eps);
    const float scale_x    = -scale_grad * sum_xg/sum_eps;

    for (int col = tid; col < ncols; col += block_size) {
        dst[col] = scale_grad*grad[col] + scale_x*xf[col];
    }
}

// template <int block_size>
// static __global__ void l2_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
//     const int row = blockIdx.x*blockDim.y + threadIdx.y;
//     const int tid = threadIdx.x;

//     float tmp = 0.0f; // partial sum for thread in warp

//     for (int col = tid; col < ncols; col += block_size) {
//         const float xi = x[row*ncols + col];
//         tmp += xi * xi;
//     }

//     // sum up partial sums
//     tmp = warp_reduce_sum(tmp);
//     if (block_size > WARP_SIZE) {
//         __shared__ float s_sum[32];
//         int warp_id = threadIdx.x / WARP_SIZE;
//         int lane_id = threadIdx.x % WARP_SIZE;
//         if (lane_id == 0) {
//             s_sum[warp_id] = tmp;
//         }
//         __syncthreads();
//         tmp = s_sum[lane_id];
//         tmp = warp_reduce_sum(tmp);
//     }

//     // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
//     const float scale = rsqrtf(fmaxf(tmp, eps * eps));

//     for (int col = tid; col < ncols; col += block_size) {
//         dst[row*ncols + col] = scale * x[row*ncols + col];
//     }
// }

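// l2_norm_f32 scales each row by rsqrt(max(sum(x^2), eps^2)), i.e. divides by
// max(||x||_2, eps), matching torch.nn.functional.normalize with p=2.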
template <int block_size>
static __global__ void l2_norm_f32(
        const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
        const int64_t stride_sample, const float eps) {
    const int nrows     = gridDim.x;
    const int nchannels = gridDim.y;

    const int row     = blockIdx.x;
    const int channel = blockIdx.y;
    const int sample  = blockIdx.z;
    const int tid     = threadIdx.x;

    x   += sample*stride_sample + channel*stride_channel + row*stride_row;
    dst += ((sample*nchannels + channel)*nrows + row)*ncols;

    float tmp = 0.0f; // partial sum for thread in warp

    for (int col = tid; col < ncols; col += block_size) {
        const float xi = x[col];
        tmp += xi * xi;
    }

    // sum up partial sums
    tmp = warp_reduce_sum(tmp);
    if constexpr (block_size > WARP_SIZE) {
        static_assert(block_size == 1024, "unexpected block_size");
        __shared__ float s_sum[32];
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = tmp;
        }
        __syncthreads();
        tmp = s_sum[lane_id];
        tmp = warp_reduce_sum(tmp);
    }

    // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
    const float scale = rsqrtf(fmaxf(tmp, eps * eps));

    for (int col = tid; col < ncols; col += block_size) {
        dst[col] = scale * x[col];
    }
}

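// Host-side launchers: each picks the kernel's block_size template parameter based on the
// reduction length, using a small block (WARP_SIZE, or 256 for RMS norm) for short
// rows/groups and 1024 threads otherwise.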
static void norm_f32_cuda(
        const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
    const dim3 blocks_num(nrows, nchannels, nsamples);
    if (ncols < 1024) {
        const dim3 block_dims(WARP_SIZE, 1, 1);
        norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
    } else {
        const dim3 block_dims(1024, 1, 1);
        norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
    }
}

static void group_norm_f32_cuda(
        const float * x, float * dst, const int num_groups, const float eps, const int group_size, const int ne_elements, cudaStream_t stream) {
    if (group_size < 1024) {
        const dim3 block_dims(WARP_SIZE, 1, 1);
        group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
    } else {
        const dim3 block_dims(1024, 1, 1);
        group_norm_f32<1024><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
    }
}

static void rms_norm_f32_cuda(
        const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
    const dim3 blocks_num(nrows, nchannels, nsamples);
    if (ncols < 1024) {
        const dim3 block_dims(256, 1, 1);
        rms_norm_f32<256, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
    } else {
        const dim3 block_dims(1024, 1, 1);
        rms_norm_f32<1024, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
    }
}

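// rms_norm_mul_f32_cuda dispatches the fused RMS norm: plain rms_norm_f32 when mul is null,
// the <block_size, true> variant when only mul is given, and <block_size, true, true> when
// both mul and add are given. Broadcast dimensions are packed into fastdiv values on the
// host before launch.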
static void rms_norm_mul_f32_cuda(const float * x,
                                  const float * mul,
                                  const float * add,
                                  float * dst,
                                  const int ncols,
                                  const int nrows,
                                  const int nchannels,
                                  const int nsamples,
                                  const int64_t stride_row,
                                  const int64_t stride_channel,
                                  const int64_t stride_sample,
                                  const int64_t mul_stride_row,
                                  const int64_t mul_stride_channel,
                                  const int64_t mul_stride_sample,
                                  const uint32_t mul_ncols,
                                  const uint32_t mul_nrows,
                                  const uint32_t mul_nchannels,
                                  const uint32_t mul_nsamples,
                                  const int64_t add_stride_row,
                                  const int64_t add_stride_channel,
                                  const int64_t add_stride_sample,
                                  const uint32_t add_ncols,
                                  const uint32_t add_nrows,
                                  const uint32_t add_nchannels,
                                  const uint32_t add_nsamples,
                                  const float eps,
                                  cudaStream_t stream) {
    const dim3 blocks_num(nrows, nchannels, nsamples);
    if (mul == nullptr) {
        rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream);
        return;
    }
    if (add == nullptr) {
        const uint3 mul_ncols_packed     = init_fastdiv_values(mul_ncols);
        const uint3 mul_nrows_packed     = init_fastdiv_values(mul_nrows);
        const uint3 mul_nchannels_packed = init_fastdiv_values(mul_nchannels);
        const uint3 mul_nsamples_packed  = init_fastdiv_values(mul_nsamples);
        if (ncols < 1024) {
            const dim3 block_dims(256, 1, 1);
            rms_norm_f32<256, true><<<blocks_num, block_dims, 0, stream>>>(
                x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
                mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
        } else {
            const dim3 block_dims(1024, 1, 1);
            rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(
                x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
                mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
        }
    } else {
        const uint3 mul_ncols_packed     = init_fastdiv_values(mul_ncols);
        const uint3 mul_nrows_packed     = init_fastdiv_values(mul_nrows);
        const uint3 mul_nchannels_packed = init_fastdiv_values(mul_nchannels);
        const uint3 mul_nsamples_packed  = init_fastdiv_values(mul_nsamples);

        const uint3 add_ncols_packed     = init_fastdiv_values(add_ncols);
        const uint3 add_nrows_packed     = init_fastdiv_values(add_nrows);
        const uint3 add_nchannels_packed = init_fastdiv_values(add_nchannels);
        const uint3 add_nsamples_packed  = init_fastdiv_values(add_nsamples);
        if (ncols < 1024) {
            const dim3 block_dims(256, 1, 1);
            rms_norm_f32<256, true, true><<<blocks_num, block_dims, 0, stream>>>(
                x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
                mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
                add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
                add_nchannels_packed, add_nsamples_packed);
        } else {
            const dim3 block_dims(1024, 1, 1);
            rms_norm_f32<1024, true, true><<<blocks_num, block_dims, 0, stream>>>(
                x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
                mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
                add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
                add_nchannels_packed, add_nsamples_packed);
        }
    }
}

static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
    if (ncols < 1024) {
        const dim3 block_dims(WARP_SIZE, 1, 1);
        rms_norm_back_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(grad, xf, dst, ncols, eps);
    } else {
        const dim3 block_dims(1024, 1, 1);
        rms_norm_back_f32<1024><<<nrows, block_dims, 0, stream>>>(grad, xf, dst, ncols, eps);
    }
}

static void l2_norm_f32_cuda(
        const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
    const dim3 blocks_num(nrows, nchannels, nsamples);
    if (ncols < 1024) {
        const dim3 block_dims(WARP_SIZE, 1, 1);
        l2_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
    } else {
        const dim3 block_dims(1024, 1, 1);
        l2_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
    }
}

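// ggml operator entry points: they unpack shapes, element strides and eps from the
// ggml_tensor metadata, validate the types, and dispatch to the launchers above.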
void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *) src0->data;
    float * dst_d = (float *) dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    GGML_TENSOR_UNARY_OP_LOCALS;

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));
    GGML_ASSERT(eps >= 0.0f);

    const size_t ts0 = ggml_type_size(src0->type);
    GGML_ASSERT(nb00 == ts0);
    const int64_t s01 = nb01 / ts0;
    const int64_t s02 = nb02 / ts0;
    const int64_t s03 = nb03 / ts0;

    norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
}

void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *) src0->data;
    float * dst_d = (float *) dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    int num_groups = dst->op_params[0];

    float eps;
    memcpy(&eps, dst->op_params + 1, sizeof(float));
    GGML_ASSERT(eps >= 0.0f);

    int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
    group_norm_f32_cuda(src0_d, dst_d, num_groups * src0->ne[3], eps, group_size, ggml_nelements(src0), stream);
}

void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *) src0->data;
    float * dst_d = (float *) dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    GGML_TENSOR_UNARY_OP_LOCALS;

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));
    GGML_ASSERT(eps >= 0.0f);

    const size_t ts0 = ggml_type_size(src0->type);
    GGML_ASSERT(nb00 == ts0);
    const int64_t s01 = nb01 / ts0;
    const int64_t s02 = nb02 / ts0;
    const int64_t s03 = nb03 / ts0;

    rms_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
}

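// ggml_cuda_op_rms_norm_fused handles an RMS_NORM followed by a MUL: dst is the norm node,
// mul_tensor the consuming MUL node. The multiply operand is whichever MUL source is not
// the norm output, and the result is written directly into the MUL tensor's buffer.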
void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor) {
    const ggml_tensor * rms_norm_src = (ggml_tensor *) dst->src[0];
    float eps = 0.0f;

    memcpy(&eps, dst->op_params, sizeof(float));

    const float * src0_d = (const float *) rms_norm_src->data;
    const float * mul_d = nullptr;
    const ggml_tensor * mul_src = nullptr;

    if (mul_tensor->src[0] == dst) {
        mul_d = (float *) mul_tensor->src[1]->data;
        mul_src = mul_tensor->src[1];
    } else if (mul_tensor->src[1] == dst) {
        mul_d = (float *) mul_tensor->src[0]->data;
        mul_src = mul_tensor->src[0];
    } else {
        GGML_ASSERT(false);
    }

    float * dst_d = (float *) mul_tensor->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(rms_norm_src->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
    GGML_ASSERT(mul_tensor->type == GGML_TYPE_F32);
    GGML_ASSERT(eps >= 0.0f);

    const int64_t ne00 = rms_norm_src->ne[0];
    const int64_t ne01 = rms_norm_src->ne[1];
    const int64_t ne02 = rms_norm_src->ne[2];
    const int64_t ne03 = rms_norm_src->ne[3];

    const size_t ts0 = ggml_type_size(rms_norm_src->type);
    GGML_ASSERT(rms_norm_src->nb[0] == ts0);
    const int64_t s01 = rms_norm_src->nb[1] / ts0;
    const int64_t s02 = rms_norm_src->nb[2] / ts0;
    const int64_t s03 = rms_norm_src->nb[3] / ts0;

    const size_t ts_mul = ggml_type_size(mul_src->type);
    GGML_ASSERT(mul_src->nb[0] == ts_mul);
    const int64_t mul_s01 = mul_src->nb[1] / ts_mul;
    const int64_t mul_s02 = mul_src->nb[2] / ts_mul;
    const int64_t mul_s03 = mul_src->nb[3] / ts_mul;

    const int mul_ncols     = mul_src->ne[0];
    const int mul_nrows     = mul_src->ne[1];
    const int mul_nchannels = mul_src->ne[2];
    const int mul_nsamples  = mul_src->ne[3];

    rms_norm_mul_f32_cuda(src0_d, mul_d, nullptr, dst_d,
                          ne00, ne01, ne02, ne03,
                          /*s00*/ s01, s02, s03,
                          /*mul_s00*/ mul_s01, mul_s02, mul_s03,
                          mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
                          /*add_s00*/ 0, 0, 0,
                          0, 0, 0, 0,
                          eps, stream);
}

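// ggml_cuda_op_rms_norm_fused_add extends the fusion to RMS_NORM + MUL + ADD: the extra
// operand of each consumer is found by checking which src slot points back to its producer,
// and the final result is written into the ADD tensor's buffer.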
void ggml_cuda_op_rms_norm_fused_add(ggml_backend_cuda_context & ctx,
                                     ggml_tensor * dst,
                                     ggml_tensor * mul_tensor,
                                     ggml_tensor * add_tensor) {
    const ggml_tensor * rms_norm_src = (ggml_tensor *) dst->src[0];
    float eps = 0.0f;

    memcpy(&eps, dst->op_params, sizeof(float));

    const float * src0_d = (const float *) rms_norm_src->data;
    const float * mul_d = nullptr;
    const ggml_tensor * mul_src = nullptr;

    if (mul_tensor->src[0] == dst) {
        mul_d = (float *) mul_tensor->src[1]->data;
        mul_src = mul_tensor->src[1];
    } else if (mul_tensor->src[1] == dst) {
        mul_d = (float *) mul_tensor->src[0]->data;
        mul_src = mul_tensor->src[0];
    } else {
        GGML_ASSERT(false);
    }

    const float * add_d = nullptr;
    const ggml_tensor * add_src = nullptr;

    if (add_tensor->src[0] == mul_tensor) {
        add_d = (float *) add_tensor->src[1]->data;
        add_src = add_tensor->src[1];
    } else if (add_tensor->src[1] == mul_tensor) {
        add_d = (float *) add_tensor->src[0]->data;
        add_src = add_tensor->src[0];
    } else {
        GGML_ASSERT(false);
    }

    float * dst_d = (float *) add_tensor->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(rms_norm_src->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
    GGML_ASSERT(mul_tensor->type == GGML_TYPE_F32);
    GGML_ASSERT(add_tensor->type == GGML_TYPE_F32);
    GGML_ASSERT(eps >= 0.0f);

    const int64_t ne00 = rms_norm_src->ne[0];
    const int64_t ne01 = rms_norm_src->ne[1];
    const int64_t ne02 = rms_norm_src->ne[2];
    const int64_t ne03 = rms_norm_src->ne[3];

    const size_t ts0 = ggml_type_size(rms_norm_src->type);
    GGML_ASSERT(rms_norm_src->nb[0] == ts0);
    const int64_t s01 = rms_norm_src->nb[1] / ts0;
    const int64_t s02 = rms_norm_src->nb[2] / ts0;
    const int64_t s03 = rms_norm_src->nb[3] / ts0;

    const size_t ts_mul = ggml_type_size(mul_src->type);
    GGML_ASSERT(mul_src->nb[0] == ts_mul);
    const int64_t mul_s01 = mul_src->nb[1] / ts_mul;
    const int64_t mul_s02 = mul_src->nb[2] / ts_mul;
    const int64_t mul_s03 = mul_src->nb[3] / ts_mul;

    const int mul_ncols     = mul_src->ne[0];
    const int mul_nrows     = mul_src->ne[1];
    const int mul_nchannels = mul_src->ne[2];
    const int mul_nsamples  = mul_src->ne[3];

    const size_t ts_add = ggml_type_size(add_src->type);
    GGML_ASSERT(add_src->nb[0] == ts_add);
    const int64_t add_s01 = add_src->nb[1] / ts_add;
    const int64_t add_s02 = add_src->nb[2] / ts_add;
    const int64_t add_s03 = add_src->nb[3] / ts_add;

    const int add_ncols     = add_src->ne[0];
    const int add_nrows     = add_src->ne[1];
    const int add_nchannels = add_src->ne[2];
    const int add_nsamples  = add_src->ne[3];

    rms_norm_mul_f32_cuda(src0_d, mul_d, add_d, dst_d,
                          ne00, ne01, ne02, ne03,
                          /*s00*/ s01, s02, s03,
                          /*mul_s00*/ mul_s01, mul_s02, mul_s03,
                          mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
                          /*add_s00*/ add_s01, add_s02, add_s03,
                          add_ncols, add_nrows, add_nchannels, add_nsamples,
                          eps, stream);
}

void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * grad  = dst->src[0]; // gradients
    const ggml_tensor * src0f = dst->src[1]; // src0 from forward pass

    const float * grad_d  = (const float *) grad->data;
    const float * src0f_d = (const float *) src0f->data;
    float * dst_d = (float *) dst->data;

    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(ggml_is_contiguous(grad));

    GGML_ASSERT( grad->type == GGML_TYPE_F32);
    GGML_ASSERT(src0f->type == GGML_TYPE_F32);
    GGML_ASSERT(  dst->type == GGML_TYPE_F32);

    const int64_t ne00  = src0f->ne[0];
    const int64_t nrows = ggml_nrows(src0f);

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));
    GGML_ASSERT(eps >= 0.0f);

    rms_norm_back_f32_cuda(grad_d, src0f_d, dst_d, ne00, nrows, eps, stream);
}

void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *) src0->data;
    float * dst_d = (float *) dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    GGML_TENSOR_UNARY_OP_LOCALS;

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));
    GGML_ASSERT(eps >= 0.0f);

    const size_t ts0 = ggml_type_size(src0->type);
    GGML_ASSERT(nb00 == ts0);
    const int64_t s01 = nb01 / ts0;
    const int64_t s02 = nb02 / ts0;
    const int64_t s03 = nb03 / ts0;

    l2_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
}