| 1 | #include "unary.cuh" |
| 2 | #include "convert.cuh" |
| 3 | |
// Absolute value: clears the sign of x.
static __device__ __forceinline__ float op_abs(float x) {
    const float magnitude = fabsf(x);
    return magnitude;
}
| 7 | |
// Sign function: +1 for positive, -1 for negative. Zero (of either sign) and NaN
// fail both comparisons and therefore map to 0.
static __device__ __forceinline__ float op_sgn(float x) {
    if (x > 0.f) {
        return 1.f;
    }
    if (x < 0.f) {
        return -1.f;
    }
    return 0.f;
}
| 11 | |
// Negation (flips the sign bit, so +0 becomes -0).
static __device__ __forceinline__ float op_neg(float x) {
    const float negated = -x;
    return negated;
}
| 15 | |
// Heaviside step: 1 for strictly positive input, 0 otherwise (including 0 and NaN).
static __device__ __forceinline__ float op_step(float x) {
    return (x > 0.0f) ? 1.0f : 0.0f;
}
| 19 | |
// GELU, delegated to the shared ggml_cuda_op_gelu_single helper (declared outside
// this file; presumably the tanh approximation, as distinct from op_gelu_erf — confirm).
static __device__ __forceinline__ float op_gelu(float x) {
    const float y = ggml_cuda_op_gelu_single(x);
    return y;
}
| 23 | |
// Exact (erf-based) GELU: 0.5 * x * (1 + erf(x / sqrt(2))).
static __device__ __forceinline__ float op_gelu_erf(float x) {
    constexpr float inv_sqrt2 = 0.70710678118654752440084436210484f;

    const float one_plus_erf = 1.0f + erff(x*inv_sqrt2);
    return 0.5f*x*one_plus_erf;
}
| 29 | |
// "Quick" GELU: x * sigmoid(1.702 * x), written as x / (1 + exp(-1.702 * x)).
static __device__ __forceinline__ float op_gelu_quick(float x) {
    constexpr float coef = -1.702f;

    const float sig = 1.0f / (1.0f + expf(coef * x));
    return x * sig;
}
| 35 | |
// SiLU, delegated to the shared ggml_cuda_op_silu_single helper (declared outside this file).
static __device__ __forceinline__ float op_silu(float x) {
    const float y = ggml_cuda_op_silu_single(x);
    return y;
}
| 39 | |
// Hyperbolic tangent (single precision).
static __device__ __forceinline__ float op_tanh(float x) {
    const float t = tanhf(x);
    return t;
}
| 43 | |
// ReLU: max(x, 0). fmaxf returns the non-NaN operand, so NaN input yields 0.
static __device__ __forceinline__ float op_relu(float x) {
    return fmaxf(x, 0.0f);
}
| 47 | |
// Logistic sigmoid: 1 / (1 + exp(-x)).
static __device__ __forceinline__ float op_sigmoid(float x) {
    const float denom = 1.0f + expf(-x);
    return 1.0f / denom;
}
| 51 | |
// Hard sigmoid: the linear ramp (x + 3) / 6 clamped to [0, 1].
static __device__ __forceinline__ float op_hardsigmoid(float x) {
    const float ramp = (x + 3.0f) / 6.0f;
    return fminf(1.0f, fmaxf(0.0f, ramp));
}
| 55 | |
// Hard swish: x scaled by the hard-sigmoid gate clamp((x + 3) / 6, 0, 1).
static __device__ __forceinline__ float op_hardswish(float x) {
    const float gate = fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f));
    return x * gate;
}
| 59 | |
// Natural exponential (single precision).
static __device__ __forceinline__ float op_exp(float x) {
    const float e = expf(x);
    return e;
}
| 63 | |
// Square: x * x.
static __device__ __forceinline__ float op_sqr(float x) {
    const float squared = x * x;
    return squared;
}
| 67 | |
// Square root (single precision).
static __device__ __forceinline__ float op_sqrt(float x) {
    const float root = sqrtf(x);
    return root;
}
| 71 | |
// Sine of x in radians (single precision).
static __device__ __forceinline__ float op_sin(float x) {
    const float s = sinf(x);
    return s;
}
| 75 | |
// Cosine of x in radians (single precision).
static __device__ __forceinline__ float op_cos(float x) {
    const float c = cosf(x);
    return c;
}
| 79 | |
// Natural logarithm (single precision).
static __device__ __forceinline__ float op_log(float x) {
    const float l = logf(x);
    return l;
}
| 83 | |
// ELU with alpha = 1: identity for positive x, expm1(x) otherwise
// (NaN fails the comparison and goes through expm1f).
static __device__ __forceinline__ float op_elu(float x) {
    if (x > 0.f) {
        return x;
    }
    return expm1f(x);
}
| 87 | |
// Largest integral value not greater than x.
static __device__ __forceinline__ float op_floor(float x) {
    const float f = floorf(x);
    return f;
}
| 91 | |
// Smallest integral value not less than x.
static __device__ __forceinline__ float op_ceil(float x) {
    const float c = ceilf(x);
    return c;
}
| 95 | |
// Round to nearest integral value, halfway cases away from zero.
// Uses the explicit single-precision roundf, consistent with the other ops in this
// file, rather than relying on overload resolution of the double-named `round`.
static __device__ __forceinline__ float op_round(float x) {
    return roundf(x);
}
| 99 | |
// Round toward zero (drop the fractional part).
// Uses the explicit single-precision truncf, consistent with the other ops in this
// file, rather than relying on overload resolution of the double-named `trunc`.
static __device__ __forceinline__ float op_trunc(float x) {
    return truncf(x);
}
| 103 | |
// Elementwise kernel: dst[i] = op(x[i]), evaluated in float and cast back to T.
// Expects a 1D launch with at least k threads; threads with i >= k exit immediately.
// Index and count are 64-bit so tensors with more than INT_MAX elements are not
// silently truncated (ggml_nelements returns int64_t).
template <float (*op)(float), typename T>
static __global__ void unary_op_kernel(const T * x, T * dst, const int64_t k) {
    const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = (T)op((float)x[i]);
}

// Launches unary_op_kernel over k contiguous elements on `stream`, with
// CUDA_NEG_BLOCK_SIZE threads per block. Does nothing for k <= 0, because a
// launch with zero blocks is an invalid configuration.
template <float (*op)(float), typename T>
static void unary_cuda(const T * x, T * dst, const int64_t k, cudaStream_t stream) {
    if (k <= 0) {
        return;
    }
    const int64_t num_blocks = (k + CUDA_NEG_BLOCK_SIZE - 1) / CUDA_NEG_BLOCK_SIZE;
    unary_op_kernel<op><<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
| 120 | |
// Shared host entry for the pointwise unary ops.
// Requires src0 (= dst->src[0]) to be contiguous, F32 or F16, and of the same type as dst.
// Then applies `op` to every element of src0, writing dst on the context's stream.
template <float (*op)(float)>
void ggml_cuda_op_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(ggml_is_contiguous(src0));

    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
    GGML_ASSERT(src0->type == dst->type);

    const int64_t n = ggml_nelements(src0);

    switch (src0->type) {
        case GGML_TYPE_F16:
            unary_cuda<op>((const half *) src0->data, (half *) dst->data, n, stream);
            break;
        default: // GGML_TYPE_F32, guaranteed by the asserts above
            unary_cuda<op>((const float *) src0->data, (float *) dst->data, n, stream);
            break;
    }
}
| 140 | |
// Public entry points for the pointwise unary ops. Each one forwards to
// ggml_cuda_op_unary, instantiated with the matching op_* device function.

// dst = |src0|
void ggml_cuda_op_abs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    return ggml_cuda_op_unary<op_abs>(ctx, dst);
}

// dst = sgn(src0)
void ggml_cuda_op_sgn(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    return ggml_cuda_op_unary<op_sgn>(ctx, dst);
}

// dst = -src0
void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    return ggml_cuda_op_unary<op_neg>(ctx, dst);
}

// dst = step(src0)
void ggml_cuda_op_step(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    return ggml_cuda_op_unary<op_step>(ctx, dst);
}

// dst = gelu(src0) via op_gelu
void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    return ggml_cuda_op_unary<op_gelu>(ctx, dst);
}

// dst = erf-based gelu(src0)
void ggml_cuda_op_gelu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    return ggml_cuda_op_unary<op_gelu_erf>(ctx, dst);
}

// dst = quick gelu(src0)
void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    return ggml_cuda_op_unary<op_gelu_quick>(ctx, dst);
}

// dst = silu(src0)
void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    return ggml_cuda_op_unary<op_silu>(ctx, dst);
}

// dst = tanh(src0)
void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    return ggml_cuda_op_unary<op_tanh>(ctx, dst);
}

// dst = relu(src0)
void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    return ggml_cuda_op_unary<op_relu>(ctx, dst);
}

// dst = sigmoid(src0)
void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    return ggml_cuda_op_unary<op_sigmoid>(ctx, dst);
}

// dst = hardsigmoid(src0)
void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    return ggml_cuda_op_unary<op_hardsigmoid>(ctx, dst);
}

// dst = hardswish(src0)
void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    return ggml_cuda_op_unary<op_hardswish>(ctx, dst);
}

// dst = exp(src0)
void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    return ggml_cuda_op_unary<op_exp>(ctx, dst);
}

// dst = src0 * src0
void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    return ggml_cuda_op_unary<op_sqr>(ctx, dst);
}

// dst = sqrt(src0)
void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    return ggml_cuda_op_unary<op_sqrt>(ctx, dst);
}

// dst = sin(src0)
void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    return ggml_cuda_op_unary<op_sin>(ctx, dst);
}

// dst = cos(src0)
void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    return ggml_cuda_op_unary<op_cos>(ctx, dst);
}

// dst = log(src0)
void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    return ggml_cuda_op_unary<op_log>(ctx, dst);
}

// dst = elu(src0)
void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    return ggml_cuda_op_unary<op_elu>(ctx, dst);
}

// dst = floor(src0)
void ggml_cuda_op_floor(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    return ggml_cuda_op_unary<op_floor>(ctx, dst);
}

// dst = ceil(src0)
void ggml_cuda_op_ceil(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    return ggml_cuda_op_unary<op_ceil>(ctx, dst);
}

// dst = round(src0)
void ggml_cuda_op_round(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    return ggml_cuda_op_unary<op_round>(ctx, dst);
}

// dst = trunc(src0)
void ggml_cuda_op_trunc(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    return ggml_cuda_op_unary<op_trunc>(ctx, dst);
}
| 236 | /* gated ops */ |
| 237 | |
// Gated elementwise kernel, one thread per output element:
//   dst[i] = op(x[row*o0 + col]) * g[row*o1 + col],  with row = i / n, col = i % n.
// x and g rows are strided by o0 and o1 elements respectively; dst is dense with
// n elements per row. Threads with i >= k exit immediately.
template <float (*op)(float), typename T>
static __global__ void unary_gated_op_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1) {
    const int64_t idx = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;
    if (idx >= k) {
        return;
    }

    const int64_t row = idx / n;
    const int64_t col = idx % n;

    // When both row strides agree, the gate reuses the input offset.
    const int64_t jx = row*o0 + col;
    const int64_t jg = (o0 == o1) ? jx : row*o1 + col;

    const float activated = op((float)x[jx]);
    dst[idx] = (T)(activated * (float)g[jg]);
}
| 252 | |
// Launches unary_gated_op_kernel over k output elements on `stream`, with
// CUDA_GLU_BLOCK_SIZE threads per block (ceil-divided grid).
template <float (*op)(float), typename T>
static void unary_gated_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, cudaStream_t stream) {
    const int64_t threads    = CUDA_GLU_BLOCK_SIZE;
    const int64_t num_blocks = (k + threads - 1) / threads;

    unary_gated_op_kernel<op><<<num_blocks, CUDA_GLU_BLOCK_SIZE, 0, stream>>>(x, g, dst, k, n, o0, o1);
}
| 258 | |
// Typed half of the gated dispatch: casts the raw pointers to T, applies the
// split-tensor offset when there is no separate gate tensor, and launches the kernel.
// Without src1, x and g are the two halves of each src0 row; `swapped` selects which
// half feeds `op` (0: first half is x, second half is the gate).
template <float (*op)(float), typename T>
static void unary_gated_dispatch(void * src0_d, void * src1_d, void * dst_d, const bool has_src1, const int32_t swapped,
                                 const int64_t k, const int64_t nc, const int64_t src0_o, const int64_t src1_o,
                                 cudaStream_t stream) {
    T * x = (T *) src0_d;
    T * g = (T *) src1_d;

    if (!has_src1) {
        x += swapped ? nc : 0;
        g += swapped ? 0 : nc;
    }

    // row strides are given in bytes (nb[1]); the kernel wants them in elements
    unary_gated_cuda<op>(x, g, (T *) dst_d, k, nc, src0_o / sizeof(T), src1_o / sizeof(T), stream);
}

// Host entry for the gated (GLU-style) ops: dst = op(a) * b, row by row.
// a and b are either the two halves of each src0 row (src1 == nullptr) or src0 and src1.
// Requires F32/F16 inputs matching dst's type, rows contiguous in dim 0, and a contiguous dst.
// op_params[1] holds the `swapped` flag for the split-tensor case.
template <float (*op)(float)>
void ggml_cuda_op_unary_gated(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];
    void * src0_d = src0->data;
    void * src1_d = src1 ? src1->data : src0->data;
    const int64_t src0_o = src0->nb[1];
    const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
    void * dst_d = dst->data;
    const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(ggml_is_contiguous_1(src0));
    GGML_ASSERT(src0->nb[0] == ggml_element_size(src0));
    GGML_ASSERT(ggml_is_contiguous(dst));

    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
    GGML_ASSERT(src0->type == dst->type);
    GGML_ASSERT(dst->ne[0] == nc);
    GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0));

    if (src1) {
        GGML_ASSERT(ggml_is_contiguous_1(src1));
        GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));
        GGML_ASSERT(src1->ne[0] == nc);
        GGML_ASSERT(src0->type == src1->type);
        // the kernel reads one gate row per output row; fewer rows in src1 would read out of bounds
        GGML_ASSERT(ggml_nrows(src1) == ggml_nrows(src0));
    }

    const int32_t swapped = ((const int32_t *) dst->op_params)[1];
    const int64_t k       = ggml_nelements(dst);

    if (src0->type == GGML_TYPE_F16) {
        unary_gated_dispatch<op, half>(src0_d, src1_d, dst_d, src1 != nullptr, swapped, k, nc, src0_o, src1_o, stream);
    } else {
        unary_gated_dispatch<op, float>(src0_d, src1_d, dst_d, src1 != nullptr, swapped, k, nc, src0_o, src1_o, stream);
    }
}
| 312 | |
// Public entry points for the gated ops: each forwards to ggml_cuda_op_unary_gated
// with the activation applied to the non-gate operand.

// ReGLU: relu(a) * b
void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    return ggml_cuda_op_unary_gated<op_relu>(ctx, dst);
}

// GeGLU: gelu(a) * b, using op_gelu
void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    return ggml_cuda_op_unary_gated<op_gelu>(ctx, dst);
}

// SwiGLU: silu(a) * b
void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    return ggml_cuda_op_unary_gated<op_silu>(ctx, dst);
}

// GeGLU with the erf-based GELU
void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    return ggml_cuda_op_unary_gated<op_gelu_erf>(ctx, dst);
}

// GeGLU with the quick GELU
void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    return ggml_cuda_op_unary_gated<op_gelu_quick>(ctx, dst);
}
| 332 | |
| 333 | // swiglu_oai |
| 334 | |
// SwiGLU-OAI kernel, one thread per output element. The input/gate pair for element i
// is read at row = i / n, col = i % n, using row strides o0 (x) and o1 (g) in elements.
// The pair is combined by ggml_cuda_op_swiglu_oai_single with the alpha/limit parameters.
template <typename T>
static __global__ void swiglu_oai_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, float alpha, float limit) {
    const int64_t idx = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;
    if (idx >= k) {
        return;
    }

    const int64_t row = idx / n;
    const int64_t col = idx % n;
    const int64_t jx  = row*o0 + col;
    const int64_t jg  = (o0 == o1) ? jx : row*o1 + col;

    const float xv = x[jx];
    const float gv = g[jg];

    dst[idx] = ggml_cuda_op_swiglu_oai_single(xv, gv, alpha, limit);
}
| 352 | |
// Launches swiglu_oai_kernel over k output elements on `stream`, with
// CUDA_GLU_BLOCK_SIZE threads per block (ceil-divided grid).
template <typename T>
static void swiglu_oai_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, const float alpha, const float limit, cudaStream_t stream) {
    const int64_t threads    = CUDA_GLU_BLOCK_SIZE;
    const int64_t num_blocks = (k + threads - 1) / threads;

    swiglu_oai_kernel<<<num_blocks, CUDA_GLU_BLOCK_SIZE, 0, stream>>>(x, g, dst, k, n, o0, o1, alpha, limit);
}
| 358 | |
// Host entry for SwiGLU-OAI (F32 only): dst = swiglu_oai(a, b; alpha, limit), row by row.
// a and b are either the two halves of each src0 row (src1 == nullptr) or src0 and src1.
// op_params layout read here: [1] swapped flag, [2] alpha (f32), [3] limit (f32).
void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];
    void * src0_d = src0->data;
    void * src1_d = src1 ? src1->data : src0->data;
    const int64_t src0_o = src0->nb[1];
    const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
    void * dst_d = dst->data;
    const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(ggml_is_contiguous_1(src0));
    GGML_ASSERT(src0->nb[0] == ggml_element_size(src0));
    GGML_ASSERT(ggml_is_contiguous(dst));

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);
    GGML_ASSERT(src0->type == dst->type);
    GGML_ASSERT(dst->ne[0] == nc);
    GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0));

    if (src1) {
        GGML_ASSERT(ggml_is_contiguous_1(src1));
        GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));
        GGML_ASSERT(src1->ne[0] == nc);
        GGML_ASSERT(src0->type == src1->type);
        // the kernel reads one gate row per output row; fewer rows in src1 would read out of bounds
        GGML_ASSERT(ggml_nrows(src1) == ggml_nrows(src0));
    }

    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
    const float   alpha   = ggml_get_op_params_f32(dst, 2);
    const float   limit   = ggml_get_op_params_f32(dst, 3);

    float * src0_p = (float *) src0_d;
    float * src1_p = (float *) src1_d;

    // split-tensor case: `swapped` selects which half of each row is the activation input
    if (!src1) {
        src0_p += swapped ? nc : 0;
        src1_p += swapped ? 0 : nc;
    }

    // row strides are given in bytes (nb[1]); the kernel wants them in elements
    swiglu_oai_cuda(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), alpha, limit, stream);
}
| 402 | |
| 403 | /* CUDA kernel + launcher for xIELU */ |
| 404 | |
// xIELU activation, one thread per element, evaluated in float:
//   x >  0 : alpha_p*x*x + beta*x
//   x <= 0 : (expm1(min(x, eps)) - x)*alpha_n + beta*x
// The branch is chosen with a select rather than by blending both results with a
// 0/1 weight. The blend evaluates 0*inf = NaN whenever the unused branch overflows
// (e.g. alpha_p*x*x for a very negative x), which poisons the output.
// Index and count are 64-bit so tensors with more than INT_MAX elements are not truncated.
template <typename T>
static __global__ void xielu_kernel(const T * x, T * dst, const int64_t k, float alpha_n, float alpha_p, float beta, float eps) {
    const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    const float xi = ggml_cuda_cast<float>(x[i]);

    float out;
    if (xi > 0.0f) {
        out = alpha_p * xi * xi + beta * xi;
    } else {
        const float min_v_eps = fminf(xi, eps);
        out = (expm1f(min_v_eps) - xi) * alpha_n + beta * xi;
    }

    dst[i] = ggml_cuda_cast<T>(out);
}

// Launches xielu_kernel over k elements on `stream`, with CUDA_XIELU_BLOCK_SIZE threads per block.
// The grid is ceil(k / block); the previous `(k + block) / block` launched one surplus
// block whenever k was a multiple of the block size. Does nothing for k <= 0, because
// a launch with zero blocks is an invalid configuration.
template <typename T>
static void xielu_cuda(const T * x, T * dst, const int64_t k, float alpha_n, float alpha_p, float beta, float eps, cudaStream_t stream) {
    if (k <= 0) {
        return;
    }
    const int64_t num_blocks = (k + CUDA_XIELU_BLOCK_SIZE - 1) / CUDA_XIELU_BLOCK_SIZE;
    xielu_kernel<<<num_blocks, CUDA_XIELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, alpha_n, alpha_p, beta, eps);
}
| 429 | |
// Host entry for xIELU. Requires src0 to be contiguous, F32 or F16, and of dst's type.
// Parameters come from op_params slots 1..4 (read as f32): alpha_n, alpha_p, beta, eps.
void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(ggml_is_contiguous(src0));

    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
    GGML_ASSERT(src0->type == dst->type);

    const float alpha_n = ggml_get_op_params_f32(dst, 1);
    const float alpha_p = ggml_get_op_params_f32(dst, 2);
    const float beta    = ggml_get_op_params_f32(dst, 3);
    const float eps     = ggml_get_op_params_f32(dst, 4);

    const int64_t n = ggml_nelements(src0);

    if (src0->type == GGML_TYPE_F32) {
        xielu_cuda((const float *) src0->data, (float *) dst->data, n, alpha_n, alpha_p, beta, eps, stream);
    } else {
        xielu_cuda((const half *) src0->data, (half *) dst->data, n, alpha_n, alpha_p, beta, eps, stream);
    }
}
| 453 | |
| 454 | |
| 455 | |
| 456 | /* silu_back */ |
| 457 | |
// Backward of silu(x) = x * sigmoid(x).
// With s = sigmoid(x), d/dx silu(x) = s * (1 + x * (1 - s)); the result is scaled by grad.
static __device__ __forceinline__ float op_silu_back(float grad, float x) {
    const float sig = 1.0f / (1.0f + expf(-x));
    return grad * sig * (1.0f + x * (1.0f - sig));
}
| 462 | |
// Elementwise SiLU backward: dst[i] = op_silu_back(grad[i], xf[i]), evaluated in float.
// Expects a 1D launch with at least k threads; threads with i >= k exit immediately.
// Index and count are 64-bit so tensors with more than INT_MAX elements are not truncated.
template <class T>
static __global__ void silu_back_kernel(const T * grad, const T * xf, T * dst, const int64_t k) {
    const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = (T)op_silu_back((float)grad[i], (float)xf[i]);
}

// Launches silu_back_kernel over k elements on `stream`.
// The grid is sized with the same block size it is launched with. The previous version
// divided by CUDA_SILU_BLOCK_SIZE but launched CUDA_SILU_BACK_BLOCK_SIZE threads per block,
// which under-covers or over-launches if the two constants ever differ. Does nothing for
// k <= 0, because a launch with zero blocks is an invalid configuration.
template <class T>
static void silu_back_cuda(const T * grad, const T * x, T * dst, const int64_t k, cudaStream_t stream) {
    if (k <= 0) {
        return;
    }
    const int64_t num_blocks = (k + CUDA_SILU_BACK_BLOCK_SIZE - 1) / CUDA_SILU_BACK_BLOCK_SIZE;
    silu_back_kernel<<<num_blocks, CUDA_SILU_BACK_BLOCK_SIZE, 0, stream>>>(grad, x, dst, k);
}
| 479 | |
// Host entry for SILU_BACK: dst = silu'(x) * grad, elementwise.
// The launch below passes src0 as the incoming gradient and src1 as the forward-pass
// input x (silu_back_cuda(grad, x, ...)).
// NOTE(review): the previous comments labelled these the other way round; confirm the
// order against the graph builder that creates SILU_BACK nodes.
// src0, src1 and dst are indexed with the same flat index, so src1 must match src0 in
// type and element count and must also be contiguous.
void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0]; // passed as `grad`
    const ggml_tensor * src1 = dst->src[1]; // passed as `x` (forward-pass input)

    const void * src0_d = src0->data;
    const void * src1_d = src1->data;
    void       * dst_d  = dst->data;

    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(ggml_is_contiguous(src1));

    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
    GGML_ASSERT(src0->type == dst->type);
    GGML_ASSERT(src1->type == src0->type);
    GGML_ASSERT(ggml_nelements(src1) == ggml_nelements(src0));

    const int64_t k = ggml_nelements(src0);

    if (src0->type == GGML_TYPE_F16) {
        silu_back_cuda((const half *)src0_d, (const half *)src1_d, (half *)dst_d, k, stream);
    } else {
        silu_back_cuda((const float *)src0_d, (const float *)src1_d, (float *)dst_d, k, stream);
    }
}
| 502 | |
| 503 | /* leaky relu */ |
| 504 | |
// Leaky ReLU, computed branch-free as max(x, 0) + min(x, 0) * negative_slope.
static __device__ __forceinline__ float op_leaky_relu(float x, const float negative_slope) {
    const float pos_part = fmaxf(x, 0);
    const float neg_part = fminf(x, 0.0f);
    return pos_part + neg_part * negative_slope;
}
| 508 | |
// Elementwise leaky ReLU: dst[i] = op_leaky_relu(x[i], negative_slope), evaluated in float.
// Expects a 1D launch with at least k threads; threads with i >= k exit immediately.
// Index and count are 64-bit so tensors with more than INT_MAX elements are not truncated.
template <class T>
static __global__ void leaky_relu_kernel(const T * x, T * dst, const int64_t k, const float negative_slope) {
    const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = (T)op_leaky_relu((float)x[i], negative_slope);
}

// Launches leaky_relu_kernel over k elements on `stream`, with CUDA_RELU_BLOCK_SIZE
// threads per block. Does nothing for k <= 0, because a launch with zero blocks is an
// invalid configuration.
template <class T>
static void leaky_relu_cuda(const T * x, T * dst, const int64_t k, const float negative_slope, cudaStream_t stream) {
    if (k <= 0) {
        return;
    }
    const int64_t num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
    leaky_relu_kernel<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
}
| 525 | |
// Host entry for leaky ReLU. Requires src0 to be contiguous, F32 or F16, and of dst's type.
// The negative slope is the first float stored in dst->op_params.
void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(ggml_is_contiguous(src0));

    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
    GGML_ASSERT(src0->type == dst->type);

    // copy the bytes out instead of dereferencing a reinterpreted pointer
    float negative_slope;
    memcpy(&negative_slope, dst->op_params, sizeof(float));

    const int64_t n = ggml_nelements(src0);

    if (src0->type == GGML_TYPE_F32) {
        leaky_relu_cuda((const float *) src0->data, (float *) dst->data, n, negative_slope, stream);
    } else {
        leaky_relu_cuda((const half *) src0->data, (half *) dst->data, n, negative_slope, stream);
    }
}
| 547 | |