#include "binbcast.cuh"
#include <cstdint>
#include <utility>

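// Scalar device-side binary operators applied element-wise by the broadcast kernels below.
// op_repeat simply forwards src1, so the same kernel machinery can implement GGML_OP_REPEAT.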
static __device__ __forceinline__ float op_repeat(const float a, const float b) {
    return b;
    GGML_UNUSED(a);
}

static __device__ __forceinline__ float op_add(const float a, const float b) {
    return a + b;
}

static __device__ __forceinline__ float op_sub(const float a, const float b) {
    return a - b;
}

static __device__ __forceinline__ float op_mul(const float a, const float b) {
    return a * b;
}

static __device__ __forceinline__ float op_div(const float a, const float b) {
    return a / b;
}

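// k_bin_bcast: computes dst = bin_op(src0, src1) with broadcasting of src1 over dst.
// The uint3 arguments (ne3, ne10..ne13) carry precomputed fastdiv/fastmodulo constants so the
// broadcast indices can be derived without hardware integer division; the .z member holds the
// original dimension value. When the src1_ptrs parameter pack is non-empty, bin_op is applied
// once per extra src1 pointer (fused operations); otherwise the single src1 argument is used.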
template <float (*bin_op)(const float, const float),
          typename src0_t,
          typename src1_t,
          typename dst_t,
          typename... src1_ptrs>
static __global__ void k_bin_bcast(const src0_t * src0,
                                   const src1_t * src1,
                                   dst_t * dst,
                                   const int ne0,
                                   const int ne1,
                                   const int ne2,
                                   const uint3 ne3,
                                   const uint3 ne10,
                                   const uint3 ne11,
                                   const uint3 ne12,
                                   const uint3 ne13,
                                   /*int s0, */ const int s1,
                                   const int s2,
                                   const int s3,
                                   /*int s00,*/ const int s01,
                                   const int s02,
                                   const int s03,
                                   /*int s10,*/ const int s11,
                                   const int s12,
                                   const int s13,
                                   src1_ptrs... src1s) {
    const uint32_t i0s = blockDim.x * blockIdx.x + threadIdx.x;
    const uint32_t i1  = (blockDim.y * blockIdx.y + threadIdx.y);
    const uint32_t i2  = fastdiv((blockDim.z * blockIdx.z + threadIdx.z), ne3);
    const uint32_t i3  = (blockDim.z * blockIdx.z + threadIdx.z) - (i2 * ne3.z);

    if (i0s >= (uint32_t)ne0 || i1 >= (uint32_t)ne1 || i2 >= (uint32_t)ne2 || i3 >= ne3.z) {
        return;
    }

    // wrap the dst indices into src1's extents (broadcasting)
    const uint32_t i11 = fastmodulo(i1, ne11);
    const uint32_t i12 = fastmodulo(i2, ne12);
    const uint32_t i13 = fastmodulo(i3, ne13);

    const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
    const size_t i_dst  = i3*s3 + i2*s2 + i1*s1;

    const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
    dst_t * dst_row = dst + i_dst;

    for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) {
        const uint32_t i10 = fastmodulo(i0, ne10);

        float result = src0_row ? (float) src0_row[i0] : 0.0f;
        if constexpr (sizeof...(src1_ptrs) > 0) {
            // fold expression: apply bin_op once per fused src1 pointer
            result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
        } else {
            result = bin_op(result, (float)src1[i_src1 + i10]);
        }

        dst_row[i0] = (dst_t) result;
    }
}

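// k_bin_bcast_unravel: same computation as k_bin_bcast, but the whole index space is flattened
// into a single 1D launch and unraveled per thread via fastdiv. This sidesteps the 65535 limit
// on gridDim.y/z that the 3D variant can hit for large tensors.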
template <float (*bin_op)(const float, const float),
          typename src0_t,
          typename src1_t,
          typename dst_t,
          typename... src1_ptrs>
static __global__ void k_bin_bcast_unravel(const src0_t * src0,
                                           const src1_t * src1,
                                           dst_t * dst,
                                           const uint3 ne0,
                                           const uint3 ne1,
                                           const uint3 ne2,
                                           const uint32_t ne3,
                                           const uint3 prod_012,
                                           const uint3 prod_01,
                                           const uint3 ne10,
                                           const uint3 ne11,
                                           const uint3 ne12,
                                           const uint3 ne13,
                                           /*int s0, */ const int s1,
                                           const int s2,
                                           const int s3,
                                           /*int s00,*/ const int s01,
                                           const int s02,
                                           const int s03,
                                           /*int s10,*/ const int s11,
                                           const int s12,
                                           const int s13,
                                           src1_ptrs... src1s) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    // unravel the flat index i into (i0, i1, i2, i3)
    const uint32_t i3 = fastdiv(i, prod_012);
    const uint32_t i2 = fastdiv(i - i3 * prod_012.z, prod_01);
    const uint32_t i1 = fastdiv(i - i3 * prod_012.z - i2 * prod_01.z, ne0);
    const uint32_t i0 = i - i3 * prod_012.z - i2 * prod_01.z - i1 * ne0.z;

    if (i0 >= ne0.z || i1 >= ne1.z || i2 >= ne2.z || i3 >= ne3) {
        return;
    }

    const int i11 = fastmodulo(i1, ne11);
    const int i12 = fastmodulo(i2, ne12);
    const int i13 = fastmodulo(i3, ne13);

    const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
    const size_t i_dst  = i3*s3 + i2*s2 + i1*s1;

    const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
    dst_t * dst_row = dst + i_dst;

    const int i10 = fastmodulo(i0, ne10);

    float result = src0_row ? (float) src0_row[i0] : 0.0f;
    if constexpr (sizeof...(src1_ptrs) > 0) {
        result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
    } else {
        result = bin_op(result, (float)src1[i_src1 + i10]);
    }

    dst_row[i0] = (dst_t) result;
}

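// launch_bin_bcast_pack: host-side launcher. It collapses leading non-broadcast dimensions when
// all tensors are contiguous, converts byte strides to element strides, precomputes fastdiv
// constants, and dispatches to k_bin_bcast or k_bin_bcast_unravel depending on whether the 3D
// grid would exceed CUDA's 65535 limit on gridDim.y/z. The index_sequence I... selects which of
// dst->src[1..] are passed as extra fused src1 pointers.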
template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, size_t... I>
static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
                                  const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,
                                  cudaStream_t stream, std::index_sequence<I...>) {
    GGML_TENSOR_BINARY_OP_LOCALS

    int nr0 = ne10 / ne0;
    int nr1 = ne11 / ne1;
    int nr2 = ne12 / ne2;
    int nr3 = ne13 / ne3;

    int nr[4] = { nr0, nr1, nr2, nr3 };

    int64_t cne[]  = { ne0, ne1, ne2, ne3 };
    int64_t cne0[] = { ne00, ne01, ne02, ne03 };
    int64_t cne1[] = { ne10, ne11, ne12, ne13 };

    size_t cnb[]  = { nb0, nb1, nb2, nb3 };
    size_t cnb0[] = { nb00, nb01, nb02, nb03 };
    size_t cnb1[] = { nb10, nb11, nb12, nb13 };

    auto collapse = [](int64_t cne[]) {
        cne[0] *= cne[1];
        cne[1]  = cne[2];
        cne[2]  = cne[3];
        cne[3]  = 1;
    };

    auto collapse_nb = [](size_t cnb[], const int64_t cne[]) {
        cnb[1] *= cne[1];
        cnb[2] *= cne[2];
        cnb[3] *= cne[3];
    };

    if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
        for (int i = 0; i < 4; i++) {
            if (nr[i] != 1) {
                break;
            }
            if (i > 0) {
                collapse_nb(cnb, cne);
                collapse_nb(cnb0, cne0);
                collapse_nb(cnb1, cne1);
                collapse(cne);
                collapse(cne0);
                collapse(cne1);
            }
        }
    }

    {
        int64_t ne0 = cne[0];
        int64_t ne1 = cne[1];
        int64_t ne2 = cne[2];
        int64_t ne3 = cne[3];

        //int64_t ne00 = cne0[0]; GGML_UNUSED(ne00);
        //int64_t ne01 = cne0[1]; GGML_UNUSED(ne01);
        //int64_t ne02 = cne0[2]; GGML_UNUSED(ne02);
        //int64_t ne03 = cne0[3]; GGML_UNUSED(ne03);

        size_t nb0 = cnb[0];
        size_t nb1 = cnb[1];
        size_t nb2 = cnb[2];
        size_t nb3 = cnb[3];

        size_t nb00 = cnb0[0];
        size_t nb01 = cnb0[1];
        size_t nb02 = cnb0[2];
        size_t nb03 = cnb0[3];

        size_t nb10 = cnb1[0];
        size_t nb11 = cnb1[1];
        size_t nb12 = cnb1[2];
        size_t nb13 = cnb1[3];

        size_t s0 = nb0 / sizeof(dst_t);
        size_t s1 = nb1 / sizeof(dst_t);
        size_t s2 = nb2 / sizeof(dst_t);
        size_t s3 = nb3 / sizeof(dst_t);

        size_t s10 = nb10 / sizeof(src1_t);
        size_t s11 = nb11 / sizeof(src1_t);
        size_t s12 = nb12 / sizeof(src1_t);
        size_t s13 = nb13 / sizeof(src1_t);

        size_t s00 = nb00 / sizeof(src0_t);
        size_t s01 = nb01 / sizeof(src0_t);
        size_t s02 = nb02 / sizeof(src0_t);
        size_t s03 = nb03 / sizeof(src0_t);

        GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
        GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
        GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
        GGML_ASSERT(nb3 % sizeof(dst_t) == 0);

        GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
        GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
        GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
        GGML_ASSERT(nb03 % sizeof(src0_t) == 0);

        GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
        GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
        GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
        GGML_ASSERT(nb13 % sizeof(src1_t) == 0);

        GGML_ASSERT(s0 == 1);
        GGML_ASSERT(s00 == 1);
        GGML_ASSERT(s10 == 1);

        const int block_size = 128;

        int64_t hne0 = std::max(ne0 / 2LL, 1LL);

        dim3 block_dims;
        block_dims.x = std::min<unsigned int>(hne0, block_size);
        block_dims.y = std::min<unsigned int>(ne1, block_size / block_dims.x);
        block_dims.z = std::min(std::min<unsigned int>(ne2 * ne3, block_size / block_dims.x / block_dims.y), 64U);

        dim3 block_nums((hne0 + block_dims.x - 1) / block_dims.x, (ne1 + block_dims.y - 1) / block_dims.y,
                        (ne2 * ne3 + block_dims.z - 1) / block_dims.z);

        const uint3 ne10 = init_fastdiv_values((uint32_t) cne1[0]);
        const uint3 ne11 = init_fastdiv_values((uint32_t) cne1[1]);
        const uint3 ne12 = init_fastdiv_values((uint32_t) cne1[2]);
        const uint3 ne13 = init_fastdiv_values((uint32_t) cne1[3]);

        if (block_nums.z > 65535 || block_nums.y > 65535) {
            // flattened 1D launch to stay within the gridDim.y/z limits
            int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
            const uint3 prod_012    = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2));
            const uint3 prod_01     = init_fastdiv_values((uint32_t) (ne0 * ne1));
            const uint3 ne0_fastdiv = init_fastdiv_values((uint32_t) ne0);
            const uint3 ne1_fastdiv = init_fastdiv_values((uint32_t) ne1);
            const uint3 ne2_fastdiv = init_fastdiv_values((uint32_t) ne2);

            if constexpr (sizeof...(I) > 0) {
                k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t><<<block_num, block_size, 0, stream>>>(
                    src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11,
                    ne12, ne13,
                    /* s0, */ s1, s2, s3,
                    /* s00,*/ s01, s02, s03,
                    /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
            } else {
                k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
                    <<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv,
                                                           ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13,
                                                           /* s0, */ s1, s2, s3,
                                                           /* s00,*/ s01, s02, s03,
                                                           /* s10,*/ s11, s12, s13);
            }
        } else {
            const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3);
            if constexpr (sizeof...(I) > 0) {
                k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
                    src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
                    /* s0, */ s1, s2, s3,
                    /* s00,*/ s01, s02, s03,
                    /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
            } else {
                k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
                    src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
                    /* s0, */ s1, s2, s3,
                    /* s00,*/ s01, s02, s03,
                    /* s10,*/ s11, s12, s13);
            }
        }
    }
}

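// k_repeat_back: reduction kernel for GGML_OP_REPEAT_BACK. Each thread owns one destination
// element and sums all source elements that map onto it under the forward repeat, striding
// through src with the (possibly non-contiguous) element strides s00..s03.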
template <typename T>
static __global__ void k_repeat_back(
    const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
    const size_t s00, const size_t s01, const size_t s02, const size_t s03,
    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3) {

    const int64_t tid0  = int64_t(blockIdx.x)*blockDim.x + threadIdx.x;
    const int64_t tid1  = int64_t(blockIdx.y)*blockDim.y + threadIdx.y;
    const int64_t tid23 = int64_t(blockIdx.z)*blockDim.z + threadIdx.z;
    const int64_t tid2  = tid23 % ne2;
    const int64_t tid3  = tid23 / ne2;

    if (tid0 >= ne0) {
        return;
    }

    T sum = 0;
    for (int64_t i3 = tid3; i3 < ne03; i3 += ne3) {
        for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
            for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
                for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
                    sum += src[i3*s03 + i2*s02 + i1*s01 + i0*s00];
                }
            }
        }
    }
    dst[tid3*ne2*ne1*ne0 + tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
}

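// bin_bcast_cuda: functor that binds a scalar operator and a fuse count n_fuse, forwarding to
// launch_bin_bcast_pack with std::make_index_sequence<n_fuse> so the launcher knows how many
// extra src1 tensors to pull from dst->src[].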
template <float (*bin_op)(const float, const float), int n_fuse = 1>
struct bin_bcast_cuda {
    template<typename src0_t, typename src1_t, typename dst_t>
    void operator()(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst,
                    const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,
                    cudaStream_t stream) {
        launch_bin_bcast_pack<bin_op, src0_t, src1_t, dst_t>(
            src0, src1, dst, src0_dd, src1_dd, dst_dd, stream, std::make_index_sequence<n_fuse>{});
    }
};

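// repeat_back_cuda: host wrapper that launches k_repeat_back with one warp-wide block per ne0
// chunk and one grid slot per (i1, i2*i3) pair.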
template <typename T>
static void repeat_back_cuda(
        const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const size_t s00, const size_t s01, const size_t s02, const size_t s03,
        const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {

    const dim3 block_dims(WARP_SIZE, 1, 1);
    const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2*ne3);
    k_repeat_back<T><<<block_nums, block_dims, 0, stream>>>
        (src, dst, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3);
}

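// ggml_cuda_op_bin_bcast: dispatches on the (src0, src1, dst) type combination. Only the
// F32/F16 combinations below are supported; anything else aborts.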
template<class op>
static void ggml_cuda_op_bin_bcast(
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
    const void * src0_dd, const void * src1_dd, void * dst_dd, cudaStream_t stream) {

    GGML_ASSERT(src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);

    if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
        op()(src0, src1, dst, (const float *) src0_dd, (const float *) src1_dd, (float *) dst_dd, stream);
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
        op()(src0, src1, dst, (const half *) src0_dd, (const half *) src1_dd, (half *) dst_dd, stream);
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
        op()(src0, src1, dst, (const half *) src0_dd, (const float *) src1_dd, (half *) dst_dd, stream);
    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
        op()(src0, src1, dst, (const half *) src0_dd, (const float *) src1_dd, (float *) dst_dd, stream);
    } else {
        fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
            ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
        GGML_ABORT("fatal error");
    }
}

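// Public entry points for the unfused broadcast operators. For GGML_OP_REPEAT the source tensor
// is passed in the src1 slot with n_fuse = 0, so op_repeat just broadcasts it into dst.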
void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_repeat, 0>>(dst, dst->src[0], dst, nullptr, dst->src[0]->data, dst->data, ctx.stream());
}

void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
}

void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_sub>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
}

void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_mul>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
}

void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_div>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
}

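// ggml_cuda_op_fused_binbcast_impl: same type dispatch as ggml_cuda_op_bin_bcast, but instantiates
// launch_bin_bcast_pack with an index sequence of length n_fuse so that n_fuse consecutive src1
// tensors (dst->src[1..n_fuse]) are combined in a single kernel launch.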
template <float (*op)(const float, const float), int n_fuse>
static void ggml_cuda_op_fused_binbcast_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    cudaStream_t stream = ctx.stream();

    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
        launch_bin_bcast_pack<op, float, float, float>(src0, src1, dst,
            (const float *) src0->data, (const float *) src1->data, (float *) dst->data,
            stream, std::make_index_sequence<n_fuse>{});
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
        launch_bin_bcast_pack<op, half, half, half>(src0, src1, dst,
            (const half *) src0->data, (const half *) src1->data, (half *) dst->data,
            stream, std::make_index_sequence<n_fuse>{});
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
        launch_bin_bcast_pack<op, half, float, half>(src0, src1, dst,
            (const half *) src0->data, (const float *) src1->data, (half *) dst->data,
            stream, std::make_index_sequence<n_fuse>{});
    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
        launch_bin_bcast_pack<op, half, float, float>(src0, src1, dst,
            (const half *) src0->data, (const float *) src1->data, (float *) dst->data,
            stream, std::make_index_sequence<n_fuse>{});
    } else {
        fprintf(stderr,
                "%s: unsupported types for fusion: dst: %s, src0: %s, src1: %s\n",
                __func__, ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
        GGML_ABORT("fatal error");
    }
}

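// ggml_cuda_op_fused_add: the runtime fuse count (2..8) is mapped to the compile-time template
// parameter through an explicit switch, since the launcher needs n_fuse as a template argument.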
void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse) {
    GGML_ASSERT(2 <= n_fuse && n_fuse <= 8);

    switch (n_fuse) {
        case 2:
            ggml_cuda_op_fused_binbcast_impl<op_add, 2>(ctx, dst);
            break;
        case 3:
            ggml_cuda_op_fused_binbcast_impl<op_add, 3>(ctx, dst);
            break;
        case 4:
            ggml_cuda_op_fused_binbcast_impl<op_add, 4>(ctx, dst);
            break;
        case 5:
            ggml_cuda_op_fused_binbcast_impl<op_add, 5>(ctx, dst);
            break;
        case 6:
            ggml_cuda_op_fused_binbcast_impl<op_add, 6>(ctx, dst);
            break;
        case 7:
            ggml_cuda_op_fused_binbcast_impl<op_add, 7>(ctx, dst);
            break;
        case 8:
            ggml_cuda_op_fused_binbcast_impl<op_add, 8>(ctx, dst);
            break;
        default:
            GGML_ASSERT(false && "Unsupported n_fuse value");
    }
}

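// ggml_cuda_op_repeat_back: checks that dst can be repeated to src0's shape, converts byte
// strides to element strides, and launches the reduction kernel. Only F32 is currently handled.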
void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(src0->type == dst->type);
    GGML_ASSERT(ggml_is_contiguous(dst));
    GGML_ASSERT(ggml_can_repeat(dst, src0));

    cudaStream_t stream = ctx.stream();

    GGML_TENSOR_UNARY_OP_LOCALS;

    GGML_ASSERT(ne2*ne3 <= (1 << 15));

    const size_t ts  = ggml_type_size(src0->type);
    const size_t s00 = nb00 / ts;
    const size_t s01 = nb01 / ts;
    const size_t s02 = nb02 / ts;
    const size_t s03 = nb03 / ts;

    switch (dst->type) {
        case GGML_TYPE_F32: {
            const float * src0_d = (const float *) src0->data;
            float * dst_d = (float *) dst->data;
            repeat_back_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3, stream);
        } break;
        default: {
            GGML_ASSERT(false);
        } break;
    }
}
| 503 | |