#include "vec.h"

#include <cassert>

// precomputed gelu table for f16 (128 KB)
ggml_fp16_t ggml_table_gelu_f16[1 << 16];

// precomputed quick gelu table for f16 (128 KB)
ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16];
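
// Illustrative sketch (not part of this file): these tables are meant to be
// indexed by the raw 16-bit pattern of an f16 value, so an activation lookup
// reduces to a single load per element. Assuming the table is filled elsewhere
// (e.g. at init time) as table[t] = f16(gelu(f32_from_bits(t))), a lookup via a
// hypothetical helper would look roughly like:
//
//     static inline ggml_fp16_t gelu_f16_lookup(ggml_fp16_t v) {
//         uint16_t t;
//         memcpy(&t, &v, sizeof(uint16_t)); // reinterpret the f16 bits as a table index
//         return ggml_table_gelu_f16[t];
//     }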

void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * GGML_RESTRICT x, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc) {
    assert(nrc == 1);
    GGML_UNUSED(nrc);
    GGML_UNUSED(bx);
    GGML_UNUSED(by);
    GGML_UNUSED(bs);

#if defined(GGML_SIMD)
    float sumf = 0.0f;

    #if defined(__ARM_FEATURE_SVE)
        const int sve_register_length = ggml_cpu_get_sve_cnt() * 8;
        const int ggml_f32_epr = sve_register_length / 32; // elements per register: SVE128:4, SVE256:8, SVE512:16
        const int ggml_f32_step = 8 * ggml_f32_epr; // choose 8 SVE registers

        const int np = (n & ~(ggml_f32_step - 1));
        svfloat32_t sum1 = svdup_n_f32(0.0f);
        svfloat32_t sum2 = svdup_n_f32(0.0f);
        svfloat32_t sum3 = svdup_n_f32(0.0f);
        svfloat32_t sum4 = svdup_n_f32(0.0f);
        svfloat32_t sum5 = svdup_n_f32(0.0f);
        svfloat32_t sum6 = svdup_n_f32(0.0f);
        svfloat32_t sum7 = svdup_n_f32(0.0f);
        svfloat32_t sum8 = svdup_n_f32(0.0f);
        svfloat32_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
        svfloat32_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
        for (int i = 0; i < np; i += ggml_f32_step) {
            ax1 = GGML_F32_VEC_LOAD(x + i);
            ay1 = GGML_F32_VEC_LOAD(y + i);
            sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1);

            ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr);
            ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
            sum2 = GGML_F32_VEC_FMA(sum2, ax2, ay2);

            ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr);
            ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr);
            sum3 = GGML_F32_VEC_FMA(sum3, ax3, ay3);

            ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr);
            ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr);
            sum4 = GGML_F32_VEC_FMA(sum4, ax4, ay4);

            ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr);
            ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr);
            sum5 = GGML_F32_VEC_FMA(sum5, ax5, ay5);

            ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr);
            ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr);
            sum6 = GGML_F32_VEC_FMA(sum6, ax6, ay6);

            ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr);
            ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr);
            sum7 = GGML_F32_VEC_FMA(sum7, ax7, ay7);

            ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr);
            ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr);
            sum8 = GGML_F32_VEC_FMA(sum8, ax8, ay8);
        }
        // leftovers
        // since the loop above is unrolled 8x, the leftovers lie in the range [0, ggml_f32_step) and are handled by the loop below
        const int np2 = (n & ~(ggml_f32_epr - 1));
        for (int i = np; i < np2; i += ggml_f32_epr) {
            ax1 = GGML_F32_VEC_LOAD(x + i);
            ay1 = GGML_F32_VEC_LOAD(y + i);
            sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1);
        }
        // at most ggml_f32_epr - 1 elements remain; apply a predicated svmad on the available elements only
        if (np2 < n) {
            svbool_t pg = svwhilelt_b32(np2, n);
            ax1 = svld1_f32(pg, x + np2);
            ay1 = svld1_f32(pg, y + np2);
            sum1 = svmad_f32_m(pg, ax1, ay1, sum1);
        }
        // reduce sum1..sum8 into sumf
        GGML_F32_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8);
    #elif defined(__riscv_v_intrinsic)
        int vl = __riscv_vsetvlmax_e32m8();
        vfloat32m1_t vs = __riscv_vfmv_v_f_f32m1(0.0f, 1);
        vfloat32m8_t vsum;
        vfloat32m8_t ax;
        vfloat32m8_t ay;
        vsum = __riscv_vfmv_v_f_f32m8_tu(vsum, 0.0f, vl);
        for (int i = 0; i < n; i += vl) {
            vl = __riscv_vsetvl_e32m8(n - i);
            ax = __riscv_vle32_v_f32m8_tu(ax, &x[i], vl);
            ay = __riscv_vle32_v_f32m8_tu(ay, &y[i], vl);
            vsum = __riscv_vfmacc_vv_f32m8_tu(vsum, ax, ay, vl);
        }
        vl = __riscv_vsetvlmax_e32m8();
        vs = __riscv_vfredusum_vs_f32m8_f32m1(vsum, vs, vl);
        sumf += __riscv_vfmv_f_s_f32m1_f32(vs);
    #else
        const int np = (n & ~(GGML_F32_STEP - 1));

        GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };

        GGML_F32_VEC ax[GGML_F32_ARR];
        GGML_F32_VEC ay[GGML_F32_ARR];

        for (int i = 0; i < np; i += GGML_F32_STEP) {
            for (int j = 0; j < GGML_F32_ARR; j++) {
                ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
                ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);

                sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]);
            }
        }

        // reduce sum0..sum3 to sum0
        GGML_F32_VEC_REDUCE(sumf, sum);

        // leftovers
        for (int i = np; i < n; ++i) {
            sumf += x[i]*y[i];
        }
    #endif
#else
    // scalar
    ggml_float sumf = 0.0;
    for (int i = 0; i < n; ++i) {
        sumf += (ggml_float)(x[i]*y[i]);
    }
#endif

    *s = sumf;
}
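
// Illustrative sketch (not part of this file): with nrc == 1 the stride
// arguments bs/bx/by are unused and the call reduces to a plain dot product,
// e.g. (values chosen for illustration):
//
//     float x[3] = {1.0f, 2.0f, 3.0f};
//     float y[3] = {4.0f, 5.0f, 6.0f};
//     float d = 0.0f;
//     ggml_vec_dot_f32(3, &d, 0, x, 0, y, 0, 1); // d == 1*4 + 2*5 + 3*6 == 32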

void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * GGML_RESTRICT x, size_t bx, ggml_bf16_t * GGML_RESTRICT y, size_t by, int nrc) {
    assert(nrc == 1);
    GGML_UNUSED(nrc);
    GGML_UNUSED(bx);
    GGML_UNUSED(by);
    GGML_UNUSED(bs);
    int i = 0;
    ggml_float sumf = 0;

#if defined(__AVX512BF16__)
    __m512 c1 = _mm512_setzero_ps();
    __m512 c2 = _mm512_setzero_ps();
    for (; i + 64 <= n; i += 64) {
        c1 = _mm512_dpbf16_ps(c1, m512bh(_mm512_loadu_si512((x + i))),
                                  m512bh(_mm512_loadu_si512((y + i))));
        c2 = _mm512_dpbf16_ps(c2, m512bh(_mm512_loadu_si512((x + i + 32))),
                                  m512bh(_mm512_loadu_si512((y + i + 32))));
    }
    sumf += (ggml_float)_mm512_reduce_add_ps(c1);
    sumf += (ggml_float)_mm512_reduce_add_ps(c2);

#elif defined(__AVX512F__)
#define LOAD(p) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(p))), 16))
    __m512 c1 = _mm512_setzero_ps();
    __m512 c2 = _mm512_setzero_ps();
    for (; i + 32 <= n; i += 32) {
        c1 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i), LOAD(y + i)), c1);
        c2 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c2);
    }
    sumf += (ggml_float)_mm512_reduce_add_ps(c1);
    sumf += (ggml_float)_mm512_reduce_add_ps(c2);

#undef LOAD
#elif defined(__AVX2__) || defined(__AVX__)
#if defined(__AVX2__)
#define LOAD(p) _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16))
#else
#define LOAD(p) _mm256_castsi256_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16)), (_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_bsrli_si128(_mm_loadu_si128((const __m128i *)(p)), 8)), 16)), 1))
#endif
    __m256 c1 = _mm256_setzero_ps();
    __m256 c2 = _mm256_setzero_ps();
    __m256 c3 = _mm256_setzero_ps();
    __m256 c4 = _mm256_setzero_ps();
    for (; i + 32 <= n; i += 32) {
        c1 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i), LOAD(y + i)), c1);
        c2 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 8), LOAD(y + i + 8)), c2);
        c3 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c3);
        c4 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 24), LOAD(y + i + 24)), c4);
    }
    __m128 g;
    c1 = _mm256_add_ps(_mm256_add_ps(c1, c3),
                       _mm256_add_ps(c2, c4));
    g = _mm_add_ps(_mm256_extractf128_ps(c1, 1),
                   _mm256_castps256_ps128(c1));
    g = _mm_add_ps(g, _mm_movehl_ps(g, g));
    g = _mm_add_ss(g, _mm_movehdup_ps(g));
    sumf += (ggml_float)_mm_cvtss_f32(g);

#undef LOAD
#endif

    for (; i < n; ++i) {
        sumf += (ggml_float)(GGML_BF16_TO_FP32(x[i]) *
                             GGML_BF16_TO_FP32(y[i]));
    }
    *s = sumf;
}
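
// Illustrative note (not part of this file): bf16 keeps the top 16 bits of an
// IEEE-754 fp32, so the LOAD macros above convert by zero-extending each
// 16-bit value and shifting it into the high half of a 32-bit lane, e.g.:
//
//     uint16_t b = 0x3FC0;              // bf16 bits of 1.5f (fp32 bits 0x3FC00000)
//     uint32_t u = (uint32_t)b << 16;   // 0x3FC00000
//     float    f;
//     memcpy(&f, &u, sizeof(f));        // f == 1.5f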

void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc) {
    assert(nrc == 1);
    GGML_UNUSED(nrc);
    GGML_UNUSED(bx);
    GGML_UNUSED(by);
    GGML_UNUSED(bs);

    ggml_float sumf = 0.0;

#if defined(GGML_SIMD)
    #if defined(__ARM_FEATURE_SVE)
        const int sve_register_length = svcntb() * 8; // SVE vector length in bits
        const int ggml_f16_epr = sve_register_length / 16; // f16 elements per register
        const int ggml_f16_step = 8 * ggml_f16_epr; // choose 8 SVE registers

        const int np = (n & ~(ggml_f16_step - 1));
        svfloat16_t sum1 = svdup_n_f16(0.0f);
        svfloat16_t sum2 = svdup_n_f16(0.0f);
        svfloat16_t sum3 = svdup_n_f16(0.0f);
        svfloat16_t sum4 = svdup_n_f16(0.0f);

        svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
        svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
        for (int i = 0; i < np; i += ggml_f16_step) {
            ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0);
            ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0);
            sum1 = GGML_F16x_VEC_FMA(sum1, ax1, ay1);

            ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1);
            ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1);
            sum2 = GGML_F16x_VEC_FMA(sum2, ax2, ay2);

            ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2);
            ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2);
            sum3 = GGML_F16x_VEC_FMA(sum3, ax3, ay3);

            ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3);
            ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
            sum4 = GGML_F16x_VEC_FMA(sum4, ax4, ay4);

            ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4);
            ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4);
            sum1 = GGML_F16x_VEC_FMA(sum1, ax5, ay5);

            ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5);
            ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5);
            sum2 = GGML_F16x_VEC_FMA(sum2, ax6, ay6);

            ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6);
            ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6);
            sum3 = GGML_F16x_VEC_FMA(sum3, ax7, ay7);

            ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7);
            ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7);
            sum4 = GGML_F16x_VEC_FMA(sum4, ax8, ay8);
        }

        const int np2 = (n & ~(ggml_f16_epr - 1)); // round down to a multiple of ggml_f16_epr
        for (int k = np; k < np2; k += ggml_f16_epr) {
            svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0);
            svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);
            sum1 = GGML_F16x_VEC_FMA(sum1, rx, ry);
        }

        if (np2 < n) {
            svbool_t pg = svwhilelt_b16(np2, n);
            svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2));
            svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));

            sum1 = svmad_f16_x(pg, hx, hy, sum1);
        }
        GGML_F16x_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4);
    #elif defined(__riscv_v_intrinsic)
        #if defined(__riscv_zvfh)
            int vl = __riscv_vsetvlmax_e32m2();
            vfloat32m1_t vs = __riscv_vfmv_v_f_f32m1(0.0f, 1);
            vfloat32m2_t vsum;
            vfloat16m1_t ax;
            vfloat16m1_t ay;
            vsum = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vmv_v_x_u32m2(0, vl));
            for (int i = 0; i < n; i += vl) {
                vl = __riscv_vsetvl_e16m1(n - i);
                ax = __riscv_vle16_v_f16m1_tu(ax, (const _Float16 *)&x[i], vl);
                ay = __riscv_vle16_v_f16m1_tu(ay, (const _Float16 *)&y[i], vl);
                vsum = __riscv_vfwmacc_vv_f32m2_tu(vsum, ax, ay, vl);
            }
            vl = __riscv_vsetvlmax_e32m1();
            vfloat32m1_t ac0 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(vsum, 0), __riscv_vget_v_f32m2_f32m1(vsum, 1), vl);
            vs = __riscv_vfredusum_vs_f32m1_f32m1(ac0, vs, vl);
            sumf += __riscv_vfmv_f_s_f32m1_f32(vs);
        #else
            for (int i = 0; i < n; ++i) {
                sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));
            }
        #endif // __riscv_zvfh
    #else
        const int np = (n & ~(GGML_F16_STEP - 1));

        GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };

        GGML_F16_VEC ax[GGML_F16_ARR];
        GGML_F16_VEC ay[GGML_F16_ARR];

        for (int i = 0; i < np; i += GGML_F16_STEP) {
            for (int j = 0; j < GGML_F16_ARR; j++) {
                ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
                ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);

                sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]);
            }
        }

        // reduce sum0..sum3 to sum0
        GGML_F16_VEC_REDUCE(sumf, sum);

        // leftovers
        for (int i = np; i < n; ++i) {
            sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));
        }
        // if you hit this, you are likely running outside the FP range
        assert(!isnan(sumf) && !isinf(sumf));
    #endif
#else
    for (int i = 0; i < n; ++i) {
        sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));
    }
#endif // GGML_SIMD

    *s = sumf;
}
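
// Illustrative note (not part of this file): some of the F16 paths above keep
// their accumulators in half precision, whose largest finite value is 65504,
// so long dot products over large inputs can saturate to +/-inf even though
// every individual product is representable (e.g. accumulating 300.0f about
// 256 times already exceeds 65504). The isnan/isinf assert above is meant to
// flag exactly that kind of out-of-range accumulation.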

void ggml_vec_silu_f32(const int n, float * y, const float * x) {
    int i = 0;
#if defined(__AVX512F__) && defined(__AVX512DQ__)
    for (; i + 15 < n; i += 16) {
        _mm512_storeu_ps(y + i, ggml_v_silu(_mm512_loadu_ps(x + i)));
    }
#elif defined(__AVX2__) && defined(__FMA__)
    for (; i + 7 < n; i += 8) {
        _mm256_storeu_ps(y + i, ggml_v_silu(_mm256_loadu_ps(x + i)));
    }
#elif defined(__SSE2__)
    for (; i + 3 < n; i += 4) {
        _mm_storeu_ps(y + i, ggml_v_silu(_mm_loadu_ps(x + i)));
    }
#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
    const int vlen = svcntw();
    for (; i < n; i += vlen) {
        const svbool_t pg = svwhilelt_b32_s32(i, n);
        svst1_f32(pg, y + i, ggml_v_silu(pg, svld1_f32(pg, x + i)));
    }
#elif defined(__ARM_NEON) && defined(__aarch64__)
    for (; i + 3 < n; i += 4) {
        vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i)));
    }
#endif
    for (; i < n; ++i) {
        y[i] = ggml_silu_f32(x[i]);
    }
}
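
// Illustrative sketch (not part of this file): every path above computes the
// same scalar SiLU, silu(x) = x * sigmoid(x) = x / (1 + exp(-x)). A plain
// reference version of a hypothetical helper, using only <math.h>:
//
//     static inline float silu_ref(float x) {
//         return x / (1.0f + expf(-x));
//     }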

void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * g) {
    int i = 0;
#if defined(__AVX512F__) && defined(__AVX512DQ__)
    for (; i + 15 < n; i += 16) {
        _mm512_storeu_ps(y + i, _mm512_mul_ps(ggml_v_silu(_mm512_loadu_ps(x + i)), _mm512_loadu_ps(g + i)));
    }
#elif defined(__AVX2__) && defined(__FMA__)
    for (; i + 7 < n; i += 8) {
        _mm256_storeu_ps(y + i, _mm256_mul_ps(ggml_v_silu(_mm256_loadu_ps(x + i)), _mm256_loadu_ps(g + i)));
    }
#elif defined(__SSE2__)
    for (; i + 3 < n; i += 4) {
        _mm_storeu_ps(y + i, _mm_mul_ps(ggml_v_silu(_mm_loadu_ps(x + i)), _mm_loadu_ps(g + i)));
    }
#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
    const int vlen = svcntw();
    for (; i < n; i += vlen) {
        const svbool_t pg = svwhilelt_b32_s32(i, n);
        svst1_f32(pg, y + i, svmul_f32_x(pg, ggml_v_silu(pg, svld1_f32(pg, x + i)), svld1_f32(pg, g + i)));
    }
#elif defined(__ARM_NEON) && defined(__aarch64__)
    for (; i + 3 < n; i += 4) {
        vst1q_f32(y + i, vmulq_f32(ggml_v_silu(vld1q_f32(x + i)), vld1q_f32(g + i)));
    }
#elif defined(__riscv_v_intrinsic)
    for (int vl; i < n; i += vl) {
        vl = __riscv_vsetvl_e32m2(n - i);
        vfloat32m2_t vx = __riscv_vle32_v_f32m2(&x[i], vl);
        vfloat32m2_t vg = __riscv_vle32_v_f32m2(&g[i], vl);
        vfloat32m2_t vy = __riscv_vfmul_vv_f32m2(ggml_v_silu_m2(vx, vl), vg, vl);
        __riscv_vse32_v_f32m2(&y[i], vy, vl);
    }
#endif
    for (; i < n; ++i) {
        y[i] = ggml_silu_f32(x[i]) * g[i];
    }
}
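
// Illustrative sketch (not part of this file): SwiGLU is the gated form of the
// activation above, y[i] = silu(x[i]) * g[i]. A scalar reference via a
// hypothetical helper:
//
//     static inline float swiglu_ref(float x, float g) {
//         return (x / (1.0f + expf(-x))) * g;
//     }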

ggml_float ggml_vec_cvar_f32(const int n, float * y, const float * x, const float mean) {
    int i = 0;
    ggml_float sum = 0;
// TODO: optimize to process the remaining elements in groups using the smaller vector sizes from AVX2 and SSE
// ref: https://github.com/ggml-org/llama.cpp/pull/15953#pullrequestreview-3310928344
#if defined(__AVX512F__) && defined(__AVX512DQ__)
    for (; i + 15 < n; i += 16) {
        __m512 val = _mm512_sub_ps(_mm512_loadu_ps(x + i),
                                   _mm512_set1_ps(mean));
        _mm512_storeu_ps(y + i, val);
        sum += (ggml_float)_mm512_reduce_add_ps(_mm512_mul_ps(val, val));
    }
#elif defined(__AVX2__) && defined(__FMA__)
    for (; i + 7 < n; i += 8) {
        __m256 val = _mm256_sub_ps(_mm256_loadu_ps(x + i),
                                   _mm256_set1_ps(mean));
        _mm256_storeu_ps(y + i, val);
        val = _mm256_mul_ps(val, val);
        __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1),
                                 _mm256_castps256_ps128(val));
        val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
        val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
        sum += (ggml_float)_mm_cvtss_f32(val2);
    }
#elif defined(__SSE2__)
    for (; i + 3 < n; i += 4) {
        __m128 val = _mm_sub_ps(_mm_loadu_ps(x + i),
                                _mm_set1_ps(mean));
        _mm_storeu_ps(y + i, val);
        val = _mm_mul_ps(val, val);
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
        val = _mm_add_ps(val, _mm_movehl_ps(val, val));
        val = _mm_add_ss(val, _mm_movehdup_ps(val));
#else
        __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1));
        val = _mm_add_ps(val, tmp);
        tmp = _mm_movehl_ps(tmp, val);
        val = _mm_add_ss(val, tmp);
#endif // __AVX__ || __AVX2__ || __AVX512F__
        sum += (ggml_float)_mm_cvtss_f32(val);
    }
#elif defined(__ARM_NEON) && defined(__aarch64__)
    for (; i + 3 < n; i += 4) {
        float32x4_t val = vsubq_f32(vld1q_f32(x + i),
                                    vdupq_n_f32(mean));
        vst1q_f32(y + i, val);
        val = vmulq_f32(val, val);
        sum += (ggml_float)vaddvq_f32(val);
    }
#elif defined(__VXE__) || defined(__VXE2__)
    for (; i + 3 < n; i += 4) {
        float32x4_t val = vec_sub(vec_xl(0, x + i), vec_splats(mean));
        vec_xst(val, 0, y + i);
        val = vec_mul(val, val);
        sum += (ggml_float)vec_hsum_f32x4(val);
    }
#endif
    for (; i < n; ++i) {
        float val = x[i] - mean;
        y[i] = val;
        val *= val;
        sum += (ggml_float)val;
    }
    return sum/n;
}
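
// Illustrative sketch (not part of this file): the function above writes the
// centered values y[i] = x[i] - mean and returns the biased variance
// sum((x[i] - mean)^2) / n, i.e. the scalar tail loop is equivalent to this
// hypothetical reference:
//
//     static ggml_float cvar_ref(const int n, float * y, const float * x, const float mean) {
//         ggml_float sum = 0;
//         for (int i = 0; i < n; ++i) {
//             y[i] = x[i] - mean;
//             sum += (ggml_float)(y[i]*y[i]);
//         }
//         return sum/n;
//     }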

ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
    int i = 0;
    ggml_float sum = 0;
#if defined(__AVX512F__) && defined(__AVX512DQ__)
    for (; i + 15 < n; i += 16) {
        __m512 val = ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i),
                                               _mm512_set1_ps(max)));
        _mm512_storeu_ps(y + i, val);
        sum += (ggml_float)_mm512_reduce_add_ps(val);
    }
#elif defined(__AVX2__) && defined(__FMA__)
    for (; i + 7 < n; i += 8) {
        __m256 val = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i),
                                               _mm256_set1_ps(max)));
        _mm256_storeu_ps(y + i, val);
        __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1),
                                 _mm256_castps256_ps128(val));
        val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
        val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
        sum += (ggml_float)_mm_cvtss_f32(val2);
    }
#elif defined(__SSE2__)
    for (; i + 3 < n; i += 4) {
        __m128 val = ggml_v_expf(_mm_sub_ps(_mm_loadu_ps(x + i),
                                            _mm_set1_ps(max)));
        _mm_storeu_ps(y + i, val);
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
        val = _mm_add_ps(val, _mm_movehl_ps(val, val));
        val = _mm_add_ss(val, _mm_movehdup_ps(val));
#else
        __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1));
        val = _mm_add_ps(val, tmp);
        tmp = _mm_movehl_ps(tmp, val);
        val = _mm_add_ss(val, tmp);
#endif
        sum += (ggml_float)_mm_cvtss_f32(val);
    }
#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
    const int vlen = svcntw();
    for (; i < n; i += vlen) {
        const svbool_t pg = svwhilelt_b32_s32(i, n);
        svfloat32_t val = ggml_v_expf(pg, svsub_f32_x(pg, svld1_f32(pg, x + i),
                                                      svdup_n_f32_x(pg, max)));
        svst1_f32(pg, y + i, val);
        sum += (ggml_float)svaddv_f32(pg, val);
    }
#elif defined(__ARM_NEON) && defined(__aarch64__)
    for (; i + 3 < n; i += 4) {
        float32x4_t val = ggml_v_expf(vsubq_f32(vld1q_f32(x + i),
                                                vdupq_n_f32(max)));
        vst1q_f32(y + i, val);
        sum += (ggml_float)vaddvq_f32(val);
    }
#elif defined(__riscv_v_intrinsic)
    vfloat64m1_t vsum = __riscv_vfmv_v_f_f64m1(0, 1);
    for (int avl; i < n; i += avl) {
        avl = __riscv_vsetvl_e32m2(n - i);
        vfloat32m2_t val = ggml_v_expf_m2(__riscv_vfsub_vf_f32m2(__riscv_vle32_v_f32m2(&x[i], avl), max, avl), avl);
        __riscv_vse32_v_f32m2(&y[i], val, avl);
        vsum = __riscv_vfwredusum_vs_f32m2_f64m1(val, vsum, avl);
    }
    return (ggml_float)__riscv_vfmv_f_s_f64m1_f64(vsum);
#endif
    for (; i < n; ++i) {
        float val = expf(x[i] - max);
        sum += (ggml_float)val;
        y[i] = val;
    }
    return sum;
}
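
// Illustrative sketch (not part of this file): the function above stores
// y[i] = exp(x[i] - max) and returns the sum of those terms; a caller then
// normalizes by that sum to obtain the softmax, roughly:
//
//     // float max = ...; // max of x[0..n-1], computed elsewhere
//     ggml_float sum = ggml_vec_soft_max_f32(n, y, x, max);
//     const float scale = 1.0f/(float)sum;
//     for (int i = 0; i < n; ++i) {
//         y[i] *= scale;
//     }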

ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max) {
    // log(soft_max_i) = log(exp(logit_i - max) / sum_j exp(logit_j - max)) = (logit_i - max) - log(sum_j exp(logit_j - max))

    int i = 0;
    ggml_float sum = 0;
    for (; i < n; ++i) {
        float val = x[i] - max;
        y[i] = val;
        sum += (ggml_float)expf(val);
    }
    return (ggml_float)logf(sum);
}
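
// Illustrative note (not part of this file): the function above leaves
// y[i] = x[i] - max and returns log(sum_j exp(x[j] - max)), so a caller obtains
// the log-softmax as y[i] minus the returned value, roughly:
//
//     ggml_float lse = ggml_vec_log_soft_max_f32(n, y, x, max);
//     for (int i = 0; i < n; ++i) {
//         y[i] -= (float)lse;
//     }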