1#define GGML_COMMON_IMPL_C
2#include "ggml-common.h"
3#include "ggml-quants.h"
4#include "ggml-impl.h"
5#include "ggml-cpu.h"
6#include "simd-mappings.h"
7
8#include "../../quants.h"
9#include "../../ggml-cpu-impl.h"
10
11#include <math.h>
12#include <string.h>
13#include <assert.h>
14#include <stdlib.h> // for qsort
15#include <stdio.h> // for GGML_ASSERT
16
17#define GROUP_MAX_EPS 1e-15f
18#define GROUP_MAX_EPS_IQ3_XXS 1e-8f
19#define GROUP_MAX_EPS_IQ2_S 1e-8f
20#define GROUP_MAX_EPS_IQ1_M 1e-7f
21#define GROUP_MAX_EPS_IQ1_S 1e-12f
22
23#define UNUSED GGML_UNUSED
24
25// some compilers don't provide _mm256_set_m128i, e.g. gcc 7
26#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
27
28#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
29// multiply int8_t, add results pairwise twice
30static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
31 // Get absolute values of x vectors
32 const __m128i ax = _mm_sign_epi8(a: x, b: x);
33 // Sign the values of the y vectors
34 const __m128i sy = _mm_sign_epi8(a: y, b: x);
35 // Perform multiplication and create 16-bit values
36 const __m128i dot = _mm_maddubs_epi16(a: ax, b: sy);
37 const __m128i ones = _mm_set1_epi16(w: 1);
38 return _mm_madd_epi16(a: ones, b: dot);
39}
40
41#if __AVX__ || __AVX2__ || __AVX512F__
42// horizontally add 8 floats
43static inline float hsum_float_8(const __m256 x) {
44 __m128 res = _mm256_extractf128_ps(x, 1);
45 res = _mm_add_ps(a: res, b: _mm256_castps256_ps128(a: x));
46 res = _mm_add_ps(a: res, b: _mm_movehl_ps(a: res, b: res));
47 res = _mm_add_ss(a: res, b: _mm_movehdup_ps(a: res));
48 return _mm_cvtss_f32(a: res);
49}
50
51// horizontally add 8 int32_t
52static inline int hsum_i32_8(const __m256i a) {
53 const __m128i sum128 = _mm_add_epi32(a: _mm256_castsi256_si128(a: a), _mm256_extractf128_si256(a, 1));
54 const __m128i hi64 = _mm_unpackhi_epi64(a: sum128, b: sum128);
55 const __m128i sum64 = _mm_add_epi32(a: hi64, b: sum128);
56 const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
57 return _mm_cvtsi128_si32(a: _mm_add_epi32(a: sum64, b: hi32));
58}
59
60// horizontally add 4 int32_t
61static inline int hsum_i32_4(const __m128i a) {
62 const __m128i hi64 = _mm_unpackhi_epi64(a: a, b: a);
63 const __m128i sum64 = _mm_add_epi32(a: hi64, b: a);
64 const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
65 return _mm_cvtsi128_si32(a: _mm_add_epi32(a: sum64, b: hi32));
66}
67
68#if defined(__AVX2__) || defined(__AVX512F__)
69static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
70 const __m256i ax = _mm256_sign_epi8(a: x, b: x);
71 const __m256i sy = _mm256_sign_epi8(a: y, b: x);
72 return _mm256_maddubs_epi16(a: ax, b: sy);
73}
74
75// spread 32 bits to 32 bytes { 0x00, 0xFF }
76static inline __m256i bytes_from_bits_32(const uint8_t * x) {
77 uint32_t x32;
78 memcpy(dest: &x32, src: x, n: sizeof(uint32_t));
79 const __m256i shuf_mask = _mm256_set_epi64x(
80 a: 0x0303030303030303, b: 0x0202020202020202,
81 c: 0x0101010101010101, d: 0x0000000000000000);
82 __m256i bytes = _mm256_shuffle_epi8(a: _mm256_set1_epi32(i: x32), b: shuf_mask);
83 const __m256i bit_mask = _mm256_set1_epi64x(q: 0x7fbfdfeff7fbfdfe);
84 bytes = _mm256_or_si256(a: bytes, b: bit_mask);
85 return _mm256_cmpeq_epi8(a: bytes, b: _mm256_set1_epi64x(q: -1));
86}
87
88// Unpack 32 4-bit fields into 32 bytes
89// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
90static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
91{
92 const __m128i tmp = _mm_loadu_si128(p: (const __m128i *)rsi);
93 const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp);
94 const __m256i lowMask = _mm256_set1_epi8( b: 0xF );
95 return _mm256_and_si256(a: lowMask, b: bytes);
96}
97
98// add int16_t pairwise and return as float vector
99static inline __m256 sum_i16_pairs_float(const __m256i x) {
100 const __m256i ones = _mm256_set1_epi16(w: 1);
101 const __m256i summed_pairs = _mm256_madd_epi16(a: ones, b: x);
102 return _mm256_cvtepi32_ps(a: summed_pairs);
103}
104
105static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
106#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
107 const __m256i zero = _mm256_setzero_si256();
108 const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
109 return _mm256_cvtepi32_ps(summed_pairs);
110#elif defined(__AVXVNNI__)
111 const __m256i zero = _mm256_setzero_si256();
112 const __m256i summed_pairs = _mm256_dpbusd_avx_epi32(S: zero, A: ax, B: sy);
113 return _mm256_cvtepi32_ps(a: summed_pairs);
114#else
115 // Perform multiplication and create 16-bit values
116 const __m256i dot = _mm256_maddubs_epi16(ax, sy);
117 return sum_i16_pairs_float(dot);
118#endif
119}
120
121// multiply int8_t, add results pairwise twice and return as float vector
122static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
123#if __AVXVNNIINT8__
124 const __m256i zero = _mm256_setzero_si256();
125 const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y);
126 return _mm256_cvtepi32_ps(summed_pairs);
127#else
128 // Get absolute values of x vectors
129 const __m256i ax = _mm256_sign_epi8(a: x, b: x);
130 // Sign the values of the y vectors
131 const __m256i sy = _mm256_sign_epi8(a: y, b: x);
132 return mul_sum_us8_pairs_float(ax, sy);
133#endif
134}
135
136static inline __m128i packNibbles( __m256i bytes )
137{
138 // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
139#if __AVX512F__
140 const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4); // 0000_0000_abcd_0000
141 bytes = _mm256_or_si256(bytes, bytes_srli_4); // 0000_abcd_abcd_efgh
142 return _mm256_cvtepi16_epi8(bytes); // abcd_efgh
143#else
144 const __m256i lowByte = _mm256_set1_epi16( w: 0xFF );
145 __m256i high = _mm256_andnot_si256( a: lowByte, b: bytes );
146 __m256i low = _mm256_and_si256( a: lowByte, b: bytes );
147 high = _mm256_srli_epi16( a: high, count: 4 );
148 bytes = _mm256_or_si256( a: low, b: high );
149
150 // Compress uint16_t lanes into bytes
151 __m128i r0 = _mm256_castsi256_si128( a: bytes );
152 __m128i r1 = _mm256_extracti128_si256( bytes, 1 );
153 return _mm_packus_epi16( a: r0, b: r1 );
154#endif
155}
156#elif defined(__AVX__)
157static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
158{
159 // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
160 const __m128i lowByte = _mm_set1_epi16( 0xFF );
161 __m128i high = _mm_andnot_si128( lowByte, bytes1 );
162 __m128i low = _mm_and_si128( lowByte, bytes1 );
163 high = _mm_srli_epi16( high, 4 );
164 bytes1 = _mm_or_si128( low, high );
165 high = _mm_andnot_si128( lowByte, bytes2 );
166 low = _mm_and_si128( lowByte, bytes2 );
167 high = _mm_srli_epi16( high, 4 );
168 bytes2 = _mm_or_si128( low, high );
169
170 return _mm_packus_epi16( bytes1, bytes2);
171}
172
173static inline __m128i mul_add_epi8_sse(const __m128i x, const __m128i y) {
174 const __m128i ax = _mm_sign_epi8(x, x);
175 const __m128i sy = _mm_sign_epi8(y, x);
176 return _mm_maddubs_epi16(ax, sy);
177}
178
179// spread 32 bits to 32 bytes { 0x00, 0xFF }
180static inline __m256i bytes_from_bits_32(const uint8_t * x) {
181 uint32_t x32;
182 memcpy(&x32, x, sizeof(uint32_t));
183 const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
184 const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303, 0x0202020202020202);
185 __m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl);
186 __m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh);
187 const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfe);
188 bytesl = _mm_or_si128(bytesl, bit_mask);
189 bytesh = _mm_or_si128(bytesh, bit_mask);
190 bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1));
191 bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1));
192 return MM256_SET_M128I(bytesh, bytesl);
193}
194
195// Unpack 32 4-bit fields into 32 bytes
196// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
197static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
198{
199 // Load 16 bytes from memory
200 __m128i tmpl = _mm_loadu_si128((const __m128i *)rsi);
201 __m128i tmph = _mm_srli_epi16(tmpl, 4);
202 const __m128i lowMask = _mm_set1_epi8(0xF);
203 tmpl = _mm_and_si128(lowMask, tmpl);
204 tmph = _mm_and_si128(lowMask, tmph);
205 return MM256_SET_M128I(tmph, tmpl);
206}
207
208// add int16_t pairwise and return as float vector
209static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
210 const __m128i ones = _mm_set1_epi16(1);
211 const __m128i summed_pairsl = _mm_madd_epi16(ones, xl);
212 const __m128i summed_pairsh = _mm_madd_epi16(ones, xh);
213 const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl);
214 return _mm256_cvtepi32_ps(summed_pairs);
215}
216
217static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
218 const __m128i axl = _mm256_castsi256_si128(ax);
219 const __m128i axh = _mm256_extractf128_si256(ax, 1);
220 const __m128i syl = _mm256_castsi256_si128(sy);
221 const __m128i syh = _mm256_extractf128_si256(sy, 1);
222 // Perform multiplication and create 16-bit values
223 const __m128i dotl = _mm_maddubs_epi16(axl, syl);
224 const __m128i doth = _mm_maddubs_epi16(axh, syh);
225 return sum_i16_pairs_float(doth, dotl);
226}
227
228// multiply int8_t, add results pairwise twice and return as float vector
229static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
230 const __m128i xl = _mm256_castsi256_si128(x);
231 const __m128i xh = _mm256_extractf128_si256(x, 1);
232 const __m128i yl = _mm256_castsi256_si128(y);
233 const __m128i yh = _mm256_extractf128_si256(y, 1);
234 // Get absolute values of x vectors
235 const __m128i axl = _mm_sign_epi8(xl, xl);
236 const __m128i axh = _mm_sign_epi8(xh, xh);
237 // Sign the values of the y vectors
238 const __m128i syl = _mm_sign_epi8(yl, xl);
239 const __m128i syh = _mm_sign_epi8(yh, xh);
240 // Perform multiplication and create 16-bit values
241 const __m128i dotl = _mm_maddubs_epi16(axl, syl);
242 const __m128i doth = _mm_maddubs_epi16(axh, syh);
243 return sum_i16_pairs_float(doth, dotl);
244}
245
246// larger version of mul_sum_i8_pairs_float where x and y are each represented by four 128-bit vectors
247static inline __m256 mul_sum_i8_quad_float(const __m128i x_1_0, const __m128i x_1_1, const __m128i x_2_0, const __m128i x_2_1,
248 const __m128i y_1_0, const __m128i y_1_1, const __m128i y_2_0, const __m128i y_2_1) {
249 const __m128i mone = _mm_set1_epi16(1);
250
251 const __m128i p16_1_0 = mul_add_epi8_sse(x_1_0, y_1_0);
252 const __m128i p16_1_1 = mul_add_epi8_sse(x_1_1, y_1_1);
253 const __m128i p16_2_0 = mul_add_epi8_sse(x_2_0, y_2_0);
254 const __m128i p16_2_1 = mul_add_epi8_sse(x_2_1, y_2_1);
255 const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, mone);
256 const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, mone);
257 const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, mone);
258 const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, mone);
259 const __m128i p_1 = _mm_add_epi32(p_1_0, p_1_1);
260 const __m128i p_2 = _mm_add_epi32(p_2_0, p_2_1);
261 return _mm256_cvtepi32_ps(MM256_SET_M128I(p_2, p_1));
262}
263
264// quad fp16 delta calculation
265static inline __m256 quad_fp16_delta_float(const float x0, const float y0, const float x1, const float y1) {
266 // GGML_CPU_FP16_TO_FP32 is faster than Intel F16C
267 return _mm256_set_m128(_mm_set1_ps(GGML_CPU_FP16_TO_FP32(x1) * GGML_CPU_FP16_TO_FP32(y1)),
268 _mm_set1_ps(GGML_CPU_FP16_TO_FP32(x0) * GGML_CPU_FP16_TO_FP32(y0)));
269}
270
271static inline __m256 quad_mx_delta_float(const int8_t x0, const float y0, const int8_t x1, const float y1) {
272 return _mm256_set_m128(_mm_set1_ps(GGML_E8M0_TO_FP32_HALF(x1) * GGML_CPU_FP16_TO_FP32(y1)),
273 _mm_set1_ps(GGML_E8M0_TO_FP32_HALF(x0) * GGML_CPU_FP16_TO_FP32(y0)));
274}
275#endif
276#elif defined(__SSSE3__)
277// horizontally add 4x4 floats
278static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) {
279 __m128 res_0 =_mm_hadd_ps(a, b);
280 __m128 res_1 =_mm_hadd_ps(c, d);
281 __m128 res =_mm_hadd_ps(res_0, res_1);
282 res =_mm_hadd_ps(res, res);
283 res =_mm_hadd_ps(res, res);
284
285 return _mm_cvtss_f32(res);
286}
287#endif // __AVX__ || __AVX2__ || __AVX512F__
288#endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
289
290void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
291 assert(QK8_0 == 32);
292 assert(k % QK8_0 == 0);
293 const int nb = k / QK8_0;
294
295 block_q8_0 * GGML_RESTRICT y = vy;
296
297#if defined(__AVX2__) || defined(__AVX__)
298 for (int i = 0; i < nb; i++) {
299 // Load elements into 4 AVX vectors
300 __m256 v0 = _mm256_loadu_ps( p: x );
301 __m256 v1 = _mm256_loadu_ps( p: x + 8 );
302 __m256 v2 = _mm256_loadu_ps( p: x + 16 );
303 __m256 v3 = _mm256_loadu_ps( p: x + 24 );
304 x += 32;
305
306 // Compute max(abs(e)) for the block
307 const __m256 signBit = _mm256_set1_ps( w: -0.0f );
308 __m256 maxAbs = _mm256_andnot_ps( a: signBit, b: v0 );
309 maxAbs = _mm256_max_ps( a: maxAbs, b: _mm256_andnot_ps( a: signBit, b: v1 ) );
310 maxAbs = _mm256_max_ps( a: maxAbs, b: _mm256_andnot_ps( a: signBit, b: v2 ) );
311 maxAbs = _mm256_max_ps( a: maxAbs, b: _mm256_andnot_ps( a: signBit, b: v3 ) );
312
313 __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), b: _mm256_castps256_ps128( a: maxAbs ) );
314 max4 = _mm_max_ps( a: max4, b: _mm_movehl_ps( a: max4, b: max4 ) );
315 max4 = _mm_max_ss( a: max4, b: _mm_movehdup_ps( a: max4 ) );
316 const float maxScalar = _mm_cvtss_f32( a: max4 );
317
318 // Quantize these floats
319 const float d = maxScalar / 127.f;
320 y[i].d = GGML_CPU_FP32_TO_FP16(d);
321 const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
322 const __m256 mul = _mm256_set1_ps( w: id );
323
324 // Apply the multiplier
325 v0 = _mm256_mul_ps( a: v0, b: mul );
326 v1 = _mm256_mul_ps( a: v1, b: mul );
327 v2 = _mm256_mul_ps( a: v2, b: mul );
328 v3 = _mm256_mul_ps( a: v3, b: mul );
329
330 // Round to nearest integer
331 v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
332 v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
333 v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
334 v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
335
336 // Convert floats to integers
337 __m256i i0 = _mm256_cvtps_epi32( a: v0 );
338 __m256i i1 = _mm256_cvtps_epi32( a: v1 );
339 __m256i i2 = _mm256_cvtps_epi32( a: v2 );
340 __m256i i3 = _mm256_cvtps_epi32( a: v3 );
341
342#if defined(__AVX2__)
343 // Convert int32 to int16
344 i0 = _mm256_packs_epi32( a: i0, b: i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
345 i2 = _mm256_packs_epi32( a: i2, b: i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
346 // Convert int16 to int8
347 i0 = _mm256_packs_epi16( a: i0, b: i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
348
349 // We got our precious signed bytes, but the order is now wrong
350 // These AVX2 pack instructions process 16-byte pieces independently
351 // The following instruction is fixing the order
352 const __m256i perm = _mm256_setr_epi32( i0: 0, i1: 4, i2: 1, i3: 5, i4: 2, i5: 6, i6: 3, i7: 7 );
353 i0 = _mm256_permutevar8x32_epi32( a: i0, b: perm );
354
355 _mm256_storeu_si256(p: (__m256i *)y[i].qs, a: i0);
356#else
357 // Since we don't have in AVX some necessary functions,
358 // we split the registers in half and call AVX2 analogs from SSE
359 __m128i ni0 = _mm256_castsi256_si128( i0 );
360 __m128i ni1 = _mm256_extractf128_si256( i0, 1);
361 __m128i ni2 = _mm256_castsi256_si128( i1 );
362 __m128i ni3 = _mm256_extractf128_si256( i1, 1);
363 __m128i ni4 = _mm256_castsi256_si128( i2 );
364 __m128i ni5 = _mm256_extractf128_si256( i2, 1);
365 __m128i ni6 = _mm256_castsi256_si128( i3 );
366 __m128i ni7 = _mm256_extractf128_si256( i3, 1);
367
368 // Convert int32 to int16
369 ni0 = _mm_packs_epi32( ni0, ni1 );
370 ni2 = _mm_packs_epi32( ni2, ni3 );
371 ni4 = _mm_packs_epi32( ni4, ni5 );
372 ni6 = _mm_packs_epi32( ni6, ni7 );
373 // Convert int16 to int8
374 ni0 = _mm_packs_epi16( ni0, ni2 );
375 ni4 = _mm_packs_epi16( ni4, ni6 );
376
377 _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
378 _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
379#endif
380 }
381#else
382 GGML_UNUSED(nb);
383 // scalar
384 quantize_row_q8_0_ref(x, y, k);
385#endif
386}
387
388void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
389 assert(k % QK8_1 == 0);
390 const int nb = k / QK8_1;
391
392 block_q8_1 * GGML_RESTRICT y = vy;
393#if defined(__AVX2__) || defined(__AVX__)
394 for (int i = 0; i < nb; i++) {
395 // Load elements into 4 AVX vectors
396 __m256 v0 = _mm256_loadu_ps( p: x );
397 __m256 v1 = _mm256_loadu_ps( p: x + 8 );
398 __m256 v2 = _mm256_loadu_ps( p: x + 16 );
399 __m256 v3 = _mm256_loadu_ps( p: x + 24 );
400 x += 32;
401
402 // Compute max(abs(e)) for the block
403 const __m256 signBit = _mm256_set1_ps( w: -0.0f );
404 __m256 maxAbs = _mm256_andnot_ps( a: signBit, b: v0 );
405 maxAbs = _mm256_max_ps( a: maxAbs, b: _mm256_andnot_ps( a: signBit, b: v1 ) );
406 maxAbs = _mm256_max_ps( a: maxAbs, b: _mm256_andnot_ps( a: signBit, b: v2 ) );
407 maxAbs = _mm256_max_ps( a: maxAbs, b: _mm256_andnot_ps( a: signBit, b: v3 ) );
408
409 __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), b: _mm256_castps256_ps128( a: maxAbs ) );
410 max4 = _mm_max_ps( a: max4, b: _mm_movehl_ps( a: max4, b: max4 ) );
411 max4 = _mm_max_ss( a: max4, b: _mm_movehdup_ps( a: max4 ) );
412 const float max_scalar = _mm_cvtss_f32( a: max4 );
413
414 // Quantize these floats
415 const float d = max_scalar / 127.f;
416 y[i].d = GGML_CPU_FP32_TO_FP16(d);
417 const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f;
418 const __m256 mul = _mm256_set1_ps( w: id );
419
420 // Apply the multiplier
421 v0 = _mm256_mul_ps( a: v0, b: mul );
422 v1 = _mm256_mul_ps( a: v1, b: mul );
423 v2 = _mm256_mul_ps( a: v2, b: mul );
424 v3 = _mm256_mul_ps( a: v3, b: mul );
425
426 // Round to nearest integer
427 v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
428 v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
429 v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
430 v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
431
432 // Convert floats to integers
433 __m256i i0 = _mm256_cvtps_epi32( a: v0 );
434 __m256i i1 = _mm256_cvtps_epi32( a: v1 );
435 __m256i i2 = _mm256_cvtps_epi32( a: v2 );
436 __m256i i3 = _mm256_cvtps_epi32( a: v3 );
437
438#if defined(__AVX2__)
439 // Compute the sum of the quants and set y[i].s
440 y[i].s = GGML_CPU_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))));
441
442 // Convert int32 to int16
443 i0 = _mm256_packs_epi32( a: i0, b: i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
444 i2 = _mm256_packs_epi32( a: i2, b: i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
445 // Convert int16 to int8
446 i0 = _mm256_packs_epi16( a: i0, b: i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
447
448 // We got our precious signed bytes, but the order is now wrong
449 // These AVX2 pack instructions process 16-byte pieces independently
450 // The following instruction is fixing the order
451 const __m256i perm = _mm256_setr_epi32( i0: 0, i1: 4, i2: 1, i3: 5, i4: 2, i5: 6, i6: 3, i7: 7 );
452 i0 = _mm256_permutevar8x32_epi32( a: i0, b: perm );
453
454 _mm256_storeu_si256(p: (__m256i *)y[i].qs, a: i0);
455#else
456 // Since we don't have in AVX some necessary functions,
457 // we split the registers in half and call AVX2 analogs from SSE
458 __m128i ni0 = _mm256_castsi256_si128( i0 );
459 __m128i ni1 = _mm256_extractf128_si256( i0, 1);
460 __m128i ni2 = _mm256_castsi256_si128( i1 );
461 __m128i ni3 = _mm256_extractf128_si256( i1, 1);
462 __m128i ni4 = _mm256_castsi256_si128( i2 );
463 __m128i ni5 = _mm256_extractf128_si256( i2, 1);
464 __m128i ni6 = _mm256_castsi256_si128( i3 );
465 __m128i ni7 = _mm256_extractf128_si256( i3, 1);
466
467 // Compute the sum of the quants and set y[i].s
468 const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3));
469 const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7));
470 y[i].s = GGML_CPU_FP32_TO_FP16(d * hsum_i32_4(_mm_add_epi32(s0, s1)));
471
472 // Convert int32 to int16
473 ni0 = _mm_packs_epi32( ni0, ni1 );
474 ni2 = _mm_packs_epi32( ni2, ni3 );
475 ni4 = _mm_packs_epi32( ni4, ni5 );
476 ni6 = _mm_packs_epi32( ni6, ni7 );
477 // Convert int16 to int8
478 ni0 = _mm_packs_epi16( ni0, ni2 );
479 ni4 = _mm_packs_epi16( ni4, ni6 );
480
481 _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
482 _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
483#endif
484 }
485#else
486 GGML_UNUSED(nb);
487 // scalar
488 quantize_row_q8_1_ref(x, y, k);
489#endif
490}
491
492// placeholder implementation for Apple targets
493void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
494 quantize_row_q8_K_ref(x, y, k);
495}
496
497//===================================== Dot products =================================
498
499//
500// Helper functions
501//
502
503#if __AVX__ || __AVX2__ || __AVX512F__
504
505// shuffles to pick the required scales in dot products
506static inline __m256i get_scale_shuffle_q3k(int i) {
507 static const uint8_t k_shuffle[128] = {
508 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
509 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
510 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
511 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,
512 };
513 return _mm256_loadu_si256(p: (const __m256i*)k_shuffle + i);
514}
515static inline __m256i get_scale_shuffle_k4(int i) {
516 static const uint8_t k_shuffle[256] = {
517 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
518 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
519 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,
520 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
521 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,
522 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
523 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,
524 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15
525 };
526 return _mm256_loadu_si256(p: (const __m256i*)k_shuffle + i);
527}
528static inline __m128i get_scale_shuffle(int i) {
529 static const uint8_t k_shuffle[128] = {
530 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
531 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
532 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5,
533 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7,
534 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9,
535 10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11,
536 12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13,
537 14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15
538 };
539 return _mm_loadu_si128(p: (const __m128i*)k_shuffle + i);
540}
541#endif
542
543void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
544 const int qk = QK8_0;
545 const int nb = n / qk;
546
547 assert(n % qk == 0);
548 assert(nrc == 1);
549 UNUSED(nrc);
550 UNUSED(bx);
551 UNUSED(by);
552 UNUSED(bs);
553
554 const block_q4_0 * GGML_RESTRICT x = vx;
555 const block_q8_0 * GGML_RESTRICT y = vy;
556
557 int ib = 0;
558 float sumf = 0;
559
560#if defined(__AVX2__)
561 // Initialize accumulator with zeros
562 __m256 acc = _mm256_setzero_ps();
563
564 // Main loop
565 for (; ib < nb; ++ib) {
566 /* Compute combined scale for the block */
567 const __m256 d = _mm256_set1_ps( GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d) );
568
569 __m256i qx = bytes_from_nibbles_32(rsi: x[ib].qs);
570
571 // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
572 const __m256i off = _mm256_set1_epi8( b: 8 );
573 qx = _mm256_sub_epi8( a: qx, b: off );
574
575 __m256i qy = _mm256_loadu_si256(p: (const __m256i *)y[ib].qs);
576
577 const __m256 q = mul_sum_i8_pairs_float(x: qx, y: qy);
578
579 /* Multiply q with scale and accumulate */
580 acc = _mm256_fmadd_ps( A: d, B: q, C: acc );
581 }
582
583 sumf = hsum_float_8(x: acc);
584#elif defined(__AVX__)
585 __m256 accum = _mm256_setzero_ps();
586 for (; ib + 1 < nb; ib += 2) {
587 const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs);
588 const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
589 const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs);
590 const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1);
591 const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
592 const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);
593
594 const __m128i q4b_1_0 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), q4bits_1), _mm_set1_epi8(8));
595 const __m128i q4b_1_1 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(q4bits_1, 4)), _mm_set1_epi8(8));
596 const __m128i q4b_2_0 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), q4bits_2), _mm_set1_epi8(8));
597 const __m128i q4b_2_1 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(q4bits_2, 4)), _mm_set1_epi8(8));
598
599 const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0);
600 const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1);
601 const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0);
602 const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1);
603 const __m128i p_1 = _mm_add_epi16(p16_1_0, p16_1_1);
604 const __m128i p_2 = _mm_add_epi16(p16_2_0, p16_2_1);
605 const __m256 p = sum_i16_pairs_float(p_2, p_1);
606
607 const __m256 deltas = quad_fp16_delta_float(x[ib].d, y[ib].d, x[ib + 1].d, y[ib + 1].d);
608 accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum);
609 }
610
611 sumf = hsum_float_8(accum);
612#elif defined(__SSSE3__)
613 // set constants
614 const __m128i lowMask = _mm_set1_epi8(0xF);
615 const __m128i off = _mm_set1_epi8(8);
616
617 // Initialize accumulator with zeros
618 __m128 acc_0 = _mm_setzero_ps();
619 __m128 acc_1 = _mm_setzero_ps();
620 __m128 acc_2 = _mm_setzero_ps();
621 __m128 acc_3 = _mm_setzero_ps();
622
623 for (; ib + 1 < nb; ib += 2) {
624 _mm_prefetch(&x[ib] + sizeof(block_q4_0), _MM_HINT_T0);
625 _mm_prefetch(&y[ib] + sizeof(block_q8_0), _MM_HINT_T0);
626
627 // Compute combined scale for the block 0 and 1
628 const __m128 d_0_1 = _mm_set1_ps( GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d) );
629
630 const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[ib].qs);
631
632 __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
633 __m128i by_0 = _mm_loadu_si128((const __m128i *)y[ib].qs);
634 bx_0 = _mm_sub_epi8(bx_0, off);
635 const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
636
637 __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
638 __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[ib].qs + 16));
639 bx_1 = _mm_sub_epi8(bx_1, off);
640 const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
641
642 _mm_prefetch(&x[ib] + 2 * sizeof(block_q4_0), _MM_HINT_T0);
643 _mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
644
645 // Compute combined scale for the block 2 and 3
646 const __m128 d_2_3 = _mm_set1_ps( GGML_CPU_FP16_TO_FP32(x[ib + 1].d) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) );
647
648 const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
649
650 __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
651 __m128i by_2 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
652 bx_2 = _mm_sub_epi8(bx_2, off);
653 const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
654
655 __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
656 __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[ib + 1].qs + 16));
657 bx_3 = _mm_sub_epi8(bx_3, off);
658 const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
659
660 // Convert int32_t to float
661 __m128 p0 = _mm_cvtepi32_ps(i32_0);
662 __m128 p1 = _mm_cvtepi32_ps(i32_1);
663 __m128 p2 = _mm_cvtepi32_ps(i32_2);
664 __m128 p3 = _mm_cvtepi32_ps(i32_3);
665
666 // Apply the scale
667 __m128 p0_d = _mm_mul_ps( d_0_1, p0 );
668 __m128 p1_d = _mm_mul_ps( d_0_1, p1 );
669 __m128 p2_d = _mm_mul_ps( d_2_3, p2 );
670 __m128 p3_d = _mm_mul_ps( d_2_3, p3 );
671
672 // Acummulate
673 acc_0 = _mm_add_ps(p0_d, acc_0);
674 acc_1 = _mm_add_ps(p1_d, acc_1);
675 acc_2 = _mm_add_ps(p2_d, acc_2);
676 acc_3 = _mm_add_ps(p3_d, acc_3);
677 }
678
679 sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
680
681#endif
682 for (; ib < nb; ++ib) {
683 int sumi0 = 0;
684 int sumi1 = 0;
685
686 for (int j = 0; j < qk/2; ++j) {
687 const int v0 = (x[ib].qs[j] & 0x0F) - 8;
688 const int v1 = (x[ib].qs[j] >> 4) - 8;
689
690 sumi0 += (v0 * y[ib].qs[j]);
691 sumi1 += (v1 * y[ib].qs[j + qk/2]);
692 }
693
694 int sumi = sumi0 + sumi1;
695 sumf += sumi*GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d);
696 }
697
698 *s = sumf;
699}
700
701void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
702 const int qk = QK8_1;
703 const int nb = n / qk;
704
705 assert(n % qk == 0);
706 assert(nrc == 1);
707 UNUSED(nrc);
708 UNUSED(bx);
709 UNUSED(by);
710 UNUSED(bs);
711
712 const block_q4_1 * GGML_RESTRICT x = vx;
713 const block_q8_1 * GGML_RESTRICT y = vy;
714
715 int ib = 0;
716
717#if defined(__AVX2__) || defined(__AVX__)
718 // Initialize accumulator with zeros
719 __m256 acc = _mm256_setzero_ps();
720
721 float summs = 0;
722
723 // Main loop
724 for (; ib < nb; ++ib) {
725 const float d0 = GGML_CPU_FP16_TO_FP32(x[ib].d);
726 const float d1 = GGML_CPU_FP16_TO_FP32(y[ib].d);
727
728 summs += GGML_CPU_FP16_TO_FP32(x[ib].m) * GGML_CPU_FP16_TO_FP32(y[ib].s);
729
730 const __m256 d0v = _mm256_set1_ps( w: d0 );
731 const __m256 d1v = _mm256_set1_ps( w: d1 );
732
733 // Compute combined scales
734 const __m256 d0d1 = _mm256_mul_ps( a: d0v, b: d1v );
735
736 // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
737 const __m256i qx = bytes_from_nibbles_32(rsi: x[ib].qs);
738 const __m256i qy = _mm256_loadu_si256( p: (const __m256i *)y[ib].qs );
739
740 const __m256 xy = mul_sum_us8_pairs_float(ax: qx, sy: qy);
741
742 // Accumulate d0*d1*x*y
743#if defined(__AVX2__)
744 acc = _mm256_fmadd_ps( A: d0d1, B: xy, C: acc );
745#else
746 acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc );
747#endif
748 }
749
750 *s = hsum_float_8(x: acc) + summs;
751#else
752 UNUSED(nb);
753 UNUSED(x);
754 UNUSED(y);
755 UNUSED(ib);
756 ggml_vec_dot_q4_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);
757#endif
758}
759
760void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
761 assert(nrc == 1);
762 UNUSED(nrc);
763 UNUSED(bx);
764 UNUSED(by);
765 UNUSED(bs);
766 assert(n % QK_MXFP4 == 0);
767 static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");
768
769 const block_mxfp4 * GGML_RESTRICT x = vx;
770 const block_q8_0 * GGML_RESTRICT y = vy;
771
772 const int nb = n / QK_MXFP4;
773
774 int ib = 0;
775 float sumf = 0;
776
777#if defined __AVX2__
778
779 const __m128i values128 = _mm_loadu_si128(p: (const __m128i*)kvalues_mxfp4);
780 const __m128i m4b = _mm_set1_epi8(b: 0x0f);
781 const __m256i mone = _mm256_set1_epi16(w: 1);
782
783 __m256 accum1 = _mm256_setzero_ps();
784 __m256 accum2 = _mm256_setzero_ps();
785 for (; ib + 1 < nb; ib += 2) {
786 const __m128i q4bits_1 = _mm_loadu_si128(p: (const __m128i*)x[ib + 0].qs);
787 const __m128i q4bits_2 = _mm_loadu_si128(p: (const __m128i*)x[ib + 1].qs);
788 const __m256i q8b_1 = _mm256_loadu_si256(p: (const __m256i *)y[ib + 0].qs);
789 const __m256i q8b_2 = _mm256_loadu_si256(p: (const __m256i *)y[ib + 1].qs);
790 const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
791 _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
792 const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
793 _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
794 const __m256i p16_1 = mul_add_epi8(x: q4b_1, y: q8b_1);
795 const __m256i p16_2 = mul_add_epi8(x: q4b_2, y: q8b_2);
796 const __m256i p_1 = _mm256_madd_epi16(a: p16_1, b: mone);
797 const __m256i p_2 = _mm256_madd_epi16(a: p16_2, b: mone);
798 accum1 = _mm256_fmadd_ps(A: _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 0].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 0].e)),
799 B: _mm256_cvtepi32_ps(a: p_1), C: accum1);
800 accum2 = _mm256_fmadd_ps(A: _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 1].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 1].e)),
801 B: _mm256_cvtepi32_ps(a: p_2), C: accum2);
802 }
803
804 sumf = hsum_float_8(x: _mm256_add_ps(a: accum1, b: accum2));
805
806#elif defined __AVX__
807 const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
808 const __m128i m4b = _mm_set1_epi8(0x0f);
809
810 __m256 accum = _mm256_setzero_ps();
811 for (; ib + 1 < nb; ib += 2) {
812 const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs);
813 const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
814 const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs);
815 const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1);
816 const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
817 const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);
818
819 const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b));
820 const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b));
821 const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b));
822 const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b));
823
824 const __m256 p = mul_sum_i8_quad_float(q4b_1_0, q4b_1_1, q4b_2_0, q4b_2_1, q8b_1_0, q8b_1_1, q8b_2_0, q8b_2_1);
825 const __m256 deltas = quad_mx_delta_float(x[ib].e, y[ib].d, x[ib + 1].e, y[ib + 1].d);
826 accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum);
827 }
828
829 sumf = hsum_float_8(accum);
830
831#endif
832 for (; ib < nb; ++ib) {
833 const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e);
834 int sumi1 = 0;
835 int sumi2 = 0;
836 for (int j = 0; j < QK_MXFP4/2; ++j) {
837 sumi1 += y[ib].qs[j + 0] * kvalues_mxfp4[x[ib].qs[j] & 0xf];
838 sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >> 4];
839 }
840 sumf += d * (sumi1 + sumi2);
841 }
842 *s = sumf;
843}
844
845void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
846 const int qk = QK8_0;
847 const int nb = n / qk;
848
849 int ib = 0;
850
851 assert(n % qk == 0);
852 assert(qk == QK5_0);
853 assert(nrc == 1);
854 UNUSED(nrc);
855 UNUSED(bx);
856 UNUSED(by);
857 UNUSED(bs);
858
859 const block_q5_0 * GGML_RESTRICT x = vx;
860 const block_q8_0 * GGML_RESTRICT y = vy;
861
862#if defined(__AVX2__)
863 // Initialize accumulator with zeros
864 __m256 acc = _mm256_setzero_ps();
865
866 // Main loop
867 for (; ib < nb; ++ib) {
868 /* Compute combined scale for the block */
869 const __m256 d = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d));
870
871 __m256i qx = bytes_from_nibbles_32(rsi: x[ib].qs);
872 __m256i bxhi = bytes_from_bits_32(x: x[ib].qh);
873 bxhi = _mm256_andnot_si256(a: bxhi, b: _mm256_set1_epi8(b: (char)0xF0));
874 qx = _mm256_or_si256(a: qx, b: bxhi);
875
876 __m256i qy = _mm256_loadu_si256(p: (const __m256i *)y[ib].qs);
877
878 const __m256 q = mul_sum_i8_pairs_float(x: qx, y: qy);
879
880 /* Multiply q with scale and accumulate */
881 acc = _mm256_fmadd_ps(A: d, B: q, C: acc);
882 }
883
884 *s = hsum_float_8(x: acc);
885#elif defined(__AVX__)
886 // Initialize accumulator with zeros
887 __m256 acc = _mm256_setzero_ps();
888 __m128i mask = _mm_set1_epi8((char)0xF0);
889
890 // Main loop
891 for (; ib < nb; ++ib) {
892 /* Compute combined scale for the block */
893 const __m256 d = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d));
894
895 __m256i bx_0 = bytes_from_nibbles_32(x[ib].qs);
896 const __m256i bxhi = bytes_from_bits_32(x[ib].qh);
897 __m128i bxhil = _mm256_castsi256_si128(bxhi);
898 __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
899 bxhil = _mm_andnot_si128(bxhil, mask);
900 bxhih = _mm_andnot_si128(bxhih, mask);
901 __m128i bxl = _mm256_castsi256_si128(bx_0);
902 __m128i bxh = _mm256_extractf128_si256(bx_0, 1);
903 bxl = _mm_or_si128(bxl, bxhil);
904 bxh = _mm_or_si128(bxh, bxhih);
905 bx_0 = MM256_SET_M128I(bxh, bxl);
906
907 const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[ib].qs);
908
909 const __m256 q = mul_sum_i8_pairs_float(bx_0, by_0);
910
911 /* Multiply q with scale and accumulate */
912 acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc);
913 }
914
915 *s = hsum_float_8(acc);
916#else
917 UNUSED(nb);
918 UNUSED(ib);
919 UNUSED(x);
920 UNUSED(y);
921 ggml_vec_dot_q5_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
922#endif
923}
924
925void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
926 const int qk = QK8_1;
927 const int nb = n / qk;
928
929 int ib = 0;
930
931 assert(n % qk == 0);
932 assert(qk == QK5_1);
933 assert(nrc == 1);
934 UNUSED(nrc);
935 UNUSED(bx);
936 UNUSED(by);
937 UNUSED(bs);
938
939 const block_q5_1 * GGML_RESTRICT x = vx;
940 const block_q8_1 * GGML_RESTRICT y = vy;
941
942#if defined(__AVX2__)
943 // Initialize accumulator with zeros
944 __m256 acc = _mm256_setzero_ps();
945
946 float summs = 0.0f;
947
948 // Main loop
949 for (; ib < nb; ++ib) {
950 const __m256 dx = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(x[ib].d));
951
952 summs += GGML_CPU_FP16_TO_FP32(x[ib].m) * GGML_CPU_FP16_TO_FP32(y[ib].s);
953
954 __m256i qx = bytes_from_nibbles_32(rsi: x[ib].qs);
955 __m256i bxhi = bytes_from_bits_32(x: x[ib].qh);
956 bxhi = _mm256_and_si256(a: bxhi, b: _mm256_set1_epi8(b: 0x10));
957 qx = _mm256_or_si256(a: qx, b: bxhi);
958
959 const __m256 dy = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib].d));
960 const __m256i qy = _mm256_loadu_si256(p: (const __m256i *)y[ib].qs);
961
962 const __m256 q = mul_sum_us8_pairs_float(ax: qx, sy: qy);
963
964 acc = _mm256_fmadd_ps(A: q, B: _mm256_mul_ps(a: dx, b: dy), C: acc);
965 }
966
967 *s = hsum_float_8(x: acc) + summs;
968#elif defined(__AVX__)
969 // Initialize accumulator with zeros
970 __m256 acc = _mm256_setzero_ps();
971 __m128i mask = _mm_set1_epi8(0x10);
972
973 float summs = 0.0f;
974
975 // Main loop
976 for (; ib < nb; ++ib) {
977 const __m256 dx = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(x[ib].d));
978
979 summs += GGML_CPU_FP16_TO_FP32(x[ib].m) * GGML_CPU_FP16_TO_FP32(y[ib].s);
980
981 __m256i bx_0 = bytes_from_nibbles_32(x[ib].qs);
982 const __m256i bxhi = bytes_from_bits_32(x[ib].qh);
983 __m128i bxhil = _mm256_castsi256_si128(bxhi);
984 __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
985 bxhil = _mm_and_si128(bxhil, mask);
986 bxhih = _mm_and_si128(bxhih, mask);
987 __m128i bxl = _mm256_castsi256_si128(bx_0);
988 __m128i bxh = _mm256_extractf128_si256(bx_0, 1);
989 bxl = _mm_or_si128(bxl, bxhil);
990 bxh = _mm_or_si128(bxh, bxhih);
991 bx_0 = MM256_SET_M128I(bxh, bxl);
992
993 const __m256 dy = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib].d));
994 const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[ib].qs);
995
996 const __m256 q = mul_sum_us8_pairs_float(bx_0, by_0);
997
998 acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc);
999 }
1000
1001 *s = hsum_float_8(acc) + summs;
1002#else
1003 UNUSED(nb);
1004 UNUSED(ib);
1005 UNUSED(x);
1006 UNUSED(y);
1007 ggml_vec_dot_q5_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);
1008#endif
1009}
1010
1011void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1012 const int qk = QK8_0;
1013 const int nb = n / qk;
1014
1015 assert(n % qk == 0);
1016 assert(nrc == 1);
1017 UNUSED(nrc);
1018 UNUSED(bx);
1019 UNUSED(by);
1020 UNUSED(bs);
1021
1022 const block_q8_0 * GGML_RESTRICT x = vx;
1023 const block_q8_0 * GGML_RESTRICT y = vy;
1024
1025 int ib = 0;
1026 float sumf = 0;
1027
1028#if defined(__AVX2__)
1029 // Initialize accumulator with zeros
1030 __m256 acc = _mm256_setzero_ps();
1031
1032 // Main loop
1033 for (; ib < nb; ++ib) {
1034 // Compute combined scale for the block
1035 const __m256 d = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d));
1036 __m256i qx = _mm256_loadu_si256(p: (const __m256i *)x[ib].qs);
1037 __m256i qy = _mm256_loadu_si256(p: (const __m256i *)y[ib].qs);
1038
1039 const __m256 q = mul_sum_i8_pairs_float(x: qx, y: qy);
1040
1041 // Multiply q with scale and accumulate
1042 acc = _mm256_fmadd_ps( A: d, B: q, C: acc );
1043 }
1044
1045 sumf = hsum_float_8(x: acc);
1046#elif defined(__AVX__)
1047 __m256 accum = _mm256_setzero_ps();
1048
1049 for (; ib + 1 < nb; ib += 2) {
1050 const __m128i qx_1_0 = _mm_loadu_si128((const __m128i *)x[ib].qs);
1051 const __m128i qx_1_1 = _mm_loadu_si128((const __m128i *)x[ib].qs + 1);
1052 const __m128i qx_2_0 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
1053 const __m128i qx_2_1 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs + 1);
1054 const __m128i qy_1_0 = _mm_loadu_si128((const __m128i *)y[ib].qs);
1055 const __m128i qy_1_1 = _mm_loadu_si128((const __m128i *)y[ib].qs + 1);
1056 const __m128i qy_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
1057 const __m128i qy_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);
1058
1059 const __m256 p = mul_sum_i8_quad_float(qx_1_0, qx_1_1, qx_2_0, qx_2_1, qy_1_0, qy_1_1, qy_2_0, qy_2_1);
1060 const __m256 deltas = quad_fp16_delta_float(x[ib].d, y[ib].d, x[ib + 1].d, y[ib + 1].d);
1061 accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum);
1062 }
1063
1064 sumf = hsum_float_8(accum);
1065#endif
1066 for (; ib < nb; ++ib) {
1067 int sumi = 0;
1068
1069 for (int j = 0; j < qk; j++) {
1070 sumi += x[ib].qs[j]*y[ib].qs[j];
1071 }
1072
1073 sumf += sumi*(GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d));
1074 }
1075
1076 *s = sumf;
1077}
1078
1079void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1080 assert(nrc == 1);
1081 UNUSED(nrc);
1082 UNUSED(bx);
1083 UNUSED(by);
1084 UNUSED(bs);
1085
1086 const block_tq1_0 * GGML_RESTRICT x = vx;
1087 const block_q8_K * GGML_RESTRICT y = vy;
1088
1089 const int nb = n / QK_K;
1090
1091#if defined(__AVX2__)
1092 __m256 sumf = _mm256_setzero_ps();
1093
1094 for (int i = 0; i < nb; ++i) {
1095 // 16-bit sums
1096 __m256i sumi0 = _mm256_setzero_si256();
1097 __m256i sumi1 = _mm256_setzero_si256();
1098 __m256i sumi2 = _mm256_setzero_si256();
1099
1100 // first 32 bytes of 5 elements
1101 {
1102 __m256i qx0 = _mm256_loadu_si256(p: (const __m256i *) (x[i].qs));
1103 // 8-bit multiplies with shifts, masks and adds
1104 __m256i qx1 = _mm256_add_epi8(a: qx0, b: _mm256_add_epi8(a: qx0, b: qx0)); // 1 * 3
1105 __m256i qx2 = _mm256_add_epi8(a: _mm256_and_si256(a: _mm256_slli_epi16(a: qx0, count: 3), b: _mm256_set1_epi8(b: -8)), b: qx0); // 1 * 9
1106 __m256i qx3 = _mm256_add_epi8(a: _mm256_and_si256(a: _mm256_slli_epi16(a: qx1, count: 3), b: _mm256_set1_epi8(b: -8)), b: qx1); // 3 * 9
1107 __m256i qx4 = _mm256_add_epi8(a: _mm256_and_si256(a: _mm256_slli_epi16(a: qx2, count: 3), b: _mm256_set1_epi8(b: -8)), b: qx2); // 9 * 9
1108
1109 // TODO: can _mm256_mulhi_epu16 be faster even if 16-bits?
1110
1111 // Cancel the +1 from avg so that it behaves like a halving add
1112 qx0 = _mm256_subs_epu8(a: qx0, b: _mm256_set1_epi8(b: 1));
1113 qx1 = _mm256_subs_epu8(a: qx1, b: _mm256_set1_epi8(b: 1));
1114 qx2 = _mm256_subs_epu8(a: qx2, b: _mm256_set1_epi8(b: 1));
1115 qx3 = _mm256_subs_epu8(a: qx3, b: _mm256_set1_epi8(b: 1));
1116 qx4 = _mm256_subs_epu8(a: qx4, b: _mm256_set1_epi8(b: 1));
1117 // Multiply by 3 and get the top 2 bits
1118 qx0 = _mm256_avg_epu8(a: qx0, b: _mm256_avg_epu8(a: qx0, b: _mm256_setzero_si256()));
1119 qx1 = _mm256_avg_epu8(a: qx1, b: _mm256_avg_epu8(a: qx1, b: _mm256_setzero_si256()));
1120 qx2 = _mm256_avg_epu8(a: qx2, b: _mm256_avg_epu8(a: qx2, b: _mm256_setzero_si256()));
1121 qx3 = _mm256_avg_epu8(a: qx3, b: _mm256_avg_epu8(a: qx3, b: _mm256_setzero_si256()));
1122 qx4 = _mm256_avg_epu8(a: qx4, b: _mm256_avg_epu8(a: qx4, b: _mm256_setzero_si256()));
1123 qx0 = _mm256_and_si256(a: _mm256_srli_epi16(a: qx0, count: 6), b: _mm256_set1_epi8(b: 3));
1124 qx1 = _mm256_and_si256(a: _mm256_srli_epi16(a: qx1, count: 6), b: _mm256_set1_epi8(b: 3));
1125 qx2 = _mm256_and_si256(a: _mm256_srli_epi16(a: qx2, count: 6), b: _mm256_set1_epi8(b: 3));
1126 qx3 = _mm256_and_si256(a: _mm256_srli_epi16(a: qx3, count: 6), b: _mm256_set1_epi8(b: 3));
1127 qx4 = _mm256_and_si256(a: _mm256_srli_epi16(a: qx4, count: 6), b: _mm256_set1_epi8(b: 3));
1128
1129 const __m256i qy0 = _mm256_loadu_si256(p: (const __m256i *) (y[i].qs + 0));
1130 const __m256i qy1 = _mm256_loadu_si256(p: (const __m256i *) (y[i].qs + 32));
1131 const __m256i qy2 = _mm256_loadu_si256(p: (const __m256i *) (y[i].qs + 64));
1132 const __m256i qy3 = _mm256_loadu_si256(p: (const __m256i *) (y[i].qs + 96));
1133 const __m256i qy4 = _mm256_loadu_si256(p: (const __m256i *) (y[i].qs + 128));
1134
1135 qx0 = _mm256_maddubs_epi16(a: qx0, b: qy0);
1136 qx1 = _mm256_maddubs_epi16(a: qx1, b: qy1);
1137 qx2 = _mm256_maddubs_epi16(a: qx2, b: qy2);
1138 qx3 = _mm256_maddubs_epi16(a: qx3, b: qy3);
1139 qx4 = _mm256_maddubs_epi16(a: qx4, b: qy4);
1140
1141 sumi0 = _mm256_add_epi16(a: sumi0, b: _mm256_add_epi16(a: qx0, b: qx1));
1142 sumi1 = _mm256_add_epi16(a: sumi1, b: _mm256_add_epi16(a: qx2, b: qx3));
1143 sumi2 = _mm256_add_epi16(a: sumi2, b: qx4);
1144 }
1145
1146 // last 16 bytes of 5-element, along with the 4 bytes of 4 elements
1147 {
1148 __m128i qx0 = _mm_loadu_si128(p: (const __m128i *) (x[i].qs + 32));
1149 uint32_t qh;
1150 memcpy(dest: &qh, src: x[i].qh, n: sizeof(qh)); // potentially unaligned
1151 __m256i qx5_l = _mm256_cvtepu8_epi16(V: _mm_set1_epi32(i: qh));
1152 __m128i qx1 = _mm_add_epi8(a: qx0, b: _mm_add_epi8(a: qx0, b: qx0)); // 1 * 3
1153 __m128i qx2 = _mm_add_epi8(a: _mm_and_si128(a: _mm_slli_epi16(a: qx0, count: 3), b: _mm_set1_epi8(b: -8)), b: qx0); // 1 * 9
1154 __m128i qx3 = _mm_add_epi8(a: _mm_and_si128(a: _mm_slli_epi16(a: qx1, count: 3), b: _mm_set1_epi8(b: -8)), b: qx1); // 3 * 9
1155 __m128i qx4 = _mm_add_epi8(a: _mm_and_si128(a: _mm_slli_epi16(a: qx2, count: 3), b: _mm_set1_epi8(b: -8)), b: qx2); // 9 * 9
1156 __m256i qx01 = MM256_SET_M128I(qx1, qx0);
1157 __m256i qx23 = MM256_SET_M128I(qx3, qx2);
1158
1159 // avx2 does not have 8-bit multiplies, so 16-bit it is.
1160 qx5_l = _mm256_mullo_epi16(a: qx5_l, b: _mm256_set_epi16(w15: 27, w14: 27, w13: 27, w12: 27, w11: 9, w10: 9, w09: 9, w08: 9, w07: 3, w06: 3, w05: 3, w04: 3, w03: 1, w02: 1, w01: 1, w00: 1));
1161 qx5_l = _mm256_and_si256(a: qx5_l, b: _mm256_set1_epi16(w: 0xFF));
1162 __m128i qx5 = _mm_packus_epi16(a: _mm256_castsi256_si128(a: qx5_l), _mm256_extracti128_si256(qx5_l, 1));
1163
1164 __m256i qx45 = MM256_SET_M128I(qx5, qx4);
1165
1166 // Cancel the +1 from avg so that it behaves like a halving add
1167 qx01 = _mm256_subs_epu8(a: qx01, b: _mm256_set1_epi8(b: 1));
1168 qx23 = _mm256_subs_epu8(a: qx23, b: _mm256_set1_epi8(b: 1));
1169 qx45 = _mm256_subs_epu8(a: qx45, b: _mm256_set1_epi8(b: 1));
1170 // Multiply by 3 and get the top 2 bits
1171 qx01 = _mm256_avg_epu8(a: qx01, b: _mm256_avg_epu8(a: qx01, b: _mm256_setzero_si256()));
1172 qx23 = _mm256_avg_epu8(a: qx23, b: _mm256_avg_epu8(a: qx23, b: _mm256_setzero_si256()));
1173 qx45 = _mm256_avg_epu8(a: qx45, b: _mm256_avg_epu8(a: qx45, b: _mm256_setzero_si256()));
1174 qx01 = _mm256_and_si256(a: _mm256_srli_epi16(a: qx01, count: 6), b: _mm256_set1_epi8(b: 3));
1175 qx23 = _mm256_and_si256(a: _mm256_srli_epi16(a: qx23, count: 6), b: _mm256_set1_epi8(b: 3));
1176 qx45 = _mm256_and_si256(a: _mm256_srli_epi16(a: qx45, count: 6), b: _mm256_set1_epi8(b: 3));
1177
1178 const __m256i qy01 = _mm256_loadu_si256(p: (const __m256i *) (y[i].qs + 160));
1179 const __m256i qy23 = _mm256_loadu_si256(p: (const __m256i *) (y[i].qs + 192));
1180 const __m256i qy45 = _mm256_loadu_si256(p: (const __m256i *) (y[i].qs + 224));
1181
1182 qx01 = _mm256_maddubs_epi16(a: qx01, b: qy01);
1183 qx23 = _mm256_maddubs_epi16(a: qx23, b: qy23);
1184 qx45 = _mm256_maddubs_epi16(a: qx45, b: qy45);
1185
1186 sumi0 = _mm256_add_epi16(a: sumi0, b: qx01);
1187 sumi1 = _mm256_add_epi16(a: sumi1, b: qx23);
1188 sumi2 = _mm256_add_epi16(a: sumi2, b: qx45);
1189 }
1190
1191 const __m256i ysum = _mm256_loadu_si256(p: (const __m256i *) y[i].bsums);
1192 const __m256 d = _mm256_set1_ps(w: y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d));
1193
1194 sumi0 = _mm256_sub_epi16(a: sumi0, b: ysum);
1195 sumi0 = _mm256_add_epi16(a: sumi0, b: _mm256_add_epi16(a: sumi1, b: sumi2));
1196 sumi0 = _mm256_madd_epi16(a: sumi0, b: _mm256_set1_epi16(w: 1));
1197
1198 sumf = _mm256_add_ps(a: _mm256_mul_ps(a: _mm256_cvtepi32_ps(a: sumi0), b: d), b: sumf);
1199 }
1200
1201 *s = hsum_float_8(x: sumf);
1202
1203#else
1204 UNUSED(x);
1205 UNUSED(y);
1206 UNUSED(nb);
1207 ggml_vec_dot_tq1_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
1208#endif
1209}
1210
1211void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1212 assert(nrc == 1);
1213 UNUSED(nrc);
1214 UNUSED(bx);
1215 UNUSED(by);
1216 UNUSED(bs);
1217
1218 const block_tq2_0 * GGML_RESTRICT x = vx;
1219 const block_q8_K * GGML_RESTRICT y = vy;
1220
1221 const int nb = n / QK_K;
1222
1223#if defined(__AVX2__)
1224 __m256 sumf = _mm256_setzero_ps();
1225
1226 for (int i = 0; i < nb; ++i) {
1227 // 16-bit sums, because 256*127 still fits
1228 __m256i sumi0 = _mm256_setzero_si256();
1229 __m256i sumi1 = _mm256_setzero_si256();
1230
1231 for (size_t j = 0; j < sizeof(x->qs); j += 32) {
1232 __m256i qx0 = _mm256_loadu_si256(p: (const __m256i *) (x[i].qs + j));
1233 __m256i qx1 = _mm256_srli_epi16(a: qx0, count: 2);
1234 __m256i qx2 = _mm256_srli_epi16(a: qx0, count: 4);
1235 __m256i qx3 = _mm256_srli_epi16(a: qx0, count: 6);
1236
1237 // 0, 1, 2 (should not be 3)
1238 qx0 = _mm256_and_si256(a: qx0, b: _mm256_set1_epi8(b: 3));
1239 qx1 = _mm256_and_si256(a: qx1, b: _mm256_set1_epi8(b: 3));
1240 qx2 = _mm256_and_si256(a: qx2, b: _mm256_set1_epi8(b: 3));
1241 qx3 = _mm256_and_si256(a: qx3, b: _mm256_set1_epi8(b: 3));
1242
1243 const __m256i qy0 = _mm256_loadu_si256(p: (const __m256i *) (y[i].qs + j*4 + 0));
1244 const __m256i qy1 = _mm256_loadu_si256(p: (const __m256i *) (y[i].qs + j*4 + 32));
1245 const __m256i qy2 = _mm256_loadu_si256(p: (const __m256i *) (y[i].qs + j*4 + 64));
1246 const __m256i qy3 = _mm256_loadu_si256(p: (const __m256i *) (y[i].qs + j*4 + 96));
1247
1248 qx0 = _mm256_maddubs_epi16(a: qx0, b: qy0);
1249 qx1 = _mm256_maddubs_epi16(a: qx1, b: qy1);
1250 qx2 = _mm256_maddubs_epi16(a: qx2, b: qy2);
1251 qx3 = _mm256_maddubs_epi16(a: qx3, b: qy3);
1252
1253 sumi0 = _mm256_add_epi16(a: sumi0, b: _mm256_add_epi16(a: qx0, b: qx1));
1254 sumi1 = _mm256_add_epi16(a: sumi1, b: _mm256_add_epi16(a: qx2, b: qx3));
1255 }
1256
1257 const __m256i ysum = _mm256_loadu_si256(p: (const __m256i *) y[i].bsums);
1258 const __m256 d = _mm256_set1_ps(w: y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d));
1259
1260 sumi0 = _mm256_add_epi16(a: sumi0, b: sumi1);
1261 sumi0 = _mm256_sub_epi16(a: sumi0, b: ysum);
1262 sumi0 = _mm256_madd_epi16(a: sumi0, b: _mm256_set1_epi16(w: 1));
1263
1264 sumf = _mm256_add_ps(a: _mm256_mul_ps(a: _mm256_cvtepi32_ps(a: sumi0), b: d), b: sumf);
1265 }
1266
1267 *s = hsum_float_8(x: sumf);
1268
1269#else
1270 UNUSED(x);
1271 UNUSED(y);
1272 UNUSED(nb);
1273 ggml_vec_dot_tq2_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
1274#endif
1275}
1276
1277void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1278 assert(nrc == 1);
1279 UNUSED(nrc);
1280 UNUSED(bx);
1281 UNUSED(by);
1282 UNUSED(bs);
1283
1284 const block_q2_K * GGML_RESTRICT x = vx;
1285 const block_q8_K * GGML_RESTRICT y = vy;
1286
1287 const int nb = n / QK_K;
1288
1289#if defined __AVX2__
1290
1291 const __m256i m3 = _mm256_set1_epi8(b: 3);
1292 const __m128i m4 = _mm_set1_epi8(b: 0xF);
1293
1294 __m256 acc = _mm256_setzero_ps();
1295
1296 for (int i = 0; i < nb; ++i) {
1297
1298 const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1299 const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
1300
1301 const uint8_t * GGML_RESTRICT q2 = x[i].qs;
1302 const int8_t * GGML_RESTRICT q8 = y[i].qs;
1303
1304 const __m128i mins_and_scales = _mm_loadu_si128(p: (const __m128i*)x[i].scales);
1305 const __m128i scales8 = _mm_and_si128(a: mins_and_scales, b: m4);
1306 const __m128i mins8 = _mm_and_si128(a: _mm_srli_epi16(a: mins_and_scales, count: 4), b: m4);
1307 const __m256i mins = _mm256_cvtepi8_epi16(V: mins8);
1308 const __m256i prod = _mm256_madd_epi16(a: mins, b: _mm256_loadu_si256(p: (const __m256i*)y[i].bsums));
1309
1310 acc = _mm256_fmadd_ps(A: _mm256_broadcast_ss(a: &dmin), B: _mm256_cvtepi32_ps(a: prod), C: acc);
1311
1312 const __m256i all_scales = _mm256_cvtepi8_epi16(V: scales8);
1313 const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
1314 const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
1315 const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)};
1316
1317 __m256i sumi = _mm256_setzero_si256();
1318
1319 for (int j = 0; j < QK_K/128; ++j) {
1320
1321 const __m256i q2bits = _mm256_loadu_si256(p: (const __m256i*)q2); q2 += 32;
1322
1323 const __m256i q8_0 = _mm256_loadu_si256(p: (const __m256i*)q8); q8 += 32;
1324 const __m256i q8_1 = _mm256_loadu_si256(p: (const __m256i*)q8); q8 += 32;
1325 const __m256i q8_2 = _mm256_loadu_si256(p: (const __m256i*)q8); q8 += 32;
1326 const __m256i q8_3 = _mm256_loadu_si256(p: (const __m256i*)q8); q8 += 32;
1327
1328 const __m256i q2_0 = _mm256_and_si256(a: q2bits, b: m3);
1329 const __m256i q2_1 = _mm256_and_si256(a: _mm256_srli_epi16(a: q2bits, count: 2), b: m3);
1330 const __m256i q2_2 = _mm256_and_si256(a: _mm256_srli_epi16(a: q2bits, count: 4), b: m3);
1331 const __m256i q2_3 = _mm256_and_si256(a: _mm256_srli_epi16(a: q2bits, count: 6), b: m3);
1332
1333 __m256i p0 = _mm256_maddubs_epi16(a: q2_0, b: q8_0);
1334 __m256i p1 = _mm256_maddubs_epi16(a: q2_1, b: q8_1);
1335 __m256i p2 = _mm256_maddubs_epi16(a: q2_2, b: q8_2);
1336 __m256i p3 = _mm256_maddubs_epi16(a: q2_3, b: q8_3);
1337
1338 p0 = _mm256_madd_epi16(a: _mm256_shuffle_epi8(a: scales[j], b: get_scale_shuffle_q3k(i: 0)), b: p0);
1339 p1 = _mm256_madd_epi16(a: _mm256_shuffle_epi8(a: scales[j], b: get_scale_shuffle_q3k(i: 1)), b: p1);
1340 p2 = _mm256_madd_epi16(a: _mm256_shuffle_epi8(a: scales[j], b: get_scale_shuffle_q3k(i: 2)), b: p2);
1341 p3 = _mm256_madd_epi16(a: _mm256_shuffle_epi8(a: scales[j], b: get_scale_shuffle_q3k(i: 3)), b: p3);
1342
1343 p0 = _mm256_add_epi32(a: p0, b: p1);
1344 p2 = _mm256_add_epi32(a: p2, b: p3);
1345
1346 sumi = _mm256_add_epi32(a: sumi, b: _mm256_add_epi32(a: p0, b: p2));
1347 }
1348
1349 acc = _mm256_fmadd_ps(A: _mm256_broadcast_ss(a: &d), B: _mm256_cvtepi32_ps(a: sumi), C: acc);
1350
1351 }
1352
1353 *s = hsum_float_8(x: acc);
1354
1355#elif defined __AVX__
1356
1357 const __m128i m3 = _mm_set1_epi8(0x3);
1358 const __m128i m4 = _mm_set1_epi8(0xF);
1359 const __m128i m2 = _mm_set1_epi8(0x2);
1360
1361 __m256 acc = _mm256_setzero_ps();
1362
1363 for (int i = 0; i < nb; ++i) {
1364
1365 const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1366 const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
1367
1368 const uint8_t * GGML_RESTRICT q2 = x[i].qs;
1369 const int8_t * GGML_RESTRICT q8 = y[i].qs;
1370
1371 // load mins and scales from block_q2_K.scales[QK_K/16]
1372 const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);
1373 const __m128i scales16 = _mm_and_si128(mins_and_scales, m4);
1374 const __m128i mins16 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);
1375 const __m128i mins_0 = _mm_cvtepi8_epi16(mins16);
1376 const __m128i mins_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(mins16, mins16));
1377
1378 // summs = y[i].bsums * (x[i].scales >> 4) in 16bits*8*2 to 32bits*4*2
1379 const __m128i summs_0 = _mm_madd_epi16(mins_0, _mm_loadu_si128((const __m128i*)&y[i].bsums[0]));
1380 const __m128i summs_1 = _mm_madd_epi16(mins_1, _mm_loadu_si128((const __m128i*)&y[i].bsums[8]));
1381
1382 // sumf += -dmin * summs in 32bits*8
1383 acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(MM256_SET_M128I(summs_1, summs_0))), acc);
1384
1385 const __m128i scales_0 = _mm_cvtepi8_epi16(scales16);
1386 const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales16, scales16));
1387 const __m128i scales[2] = { scales_0, scales_1 };
1388
1389 __m128i sumi_0 = _mm_setzero_si128();
1390 __m128i sumi_1 = _mm_setzero_si128();
1391
1392 for (int j = 0; j < QK_K/128; ++j) {
1393
1394 // load Q8 quants int8*16*8 from block_q8_K.qs[QK_K]
1395 const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1396 const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1397 const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1398 const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1399 const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1400 const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1401 const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1402 const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1403
1404 // load 2bits*16*8 from block_q2_K.qs[QK_K/4]
1405 __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16;
1406 const __m128i q2_0 = _mm_and_si128(q2bits, m3);
1407 const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3);
1408 const __m128i q2_4 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3);
1409 const __m128i q2_6 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3);
1410 q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16;
1411 const __m128i q2_1 = _mm_and_si128(q2bits, m3);
1412 const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3);
1413 const __m128i q2_5 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3);
1414 const __m128i q2_7 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3);
1415
1416 // isuml = q8[l] * ((q2[l] >> shift) & 3) in 8bits*16*8 to 16bits*8*8
1417 __m128i p0 = _mm_maddubs_epi16(q2_0, q8_0);
1418 __m128i p1 = _mm_maddubs_epi16(q2_1, q8_1);
1419 __m128i p2 = _mm_maddubs_epi16(q2_2, q8_2);
1420 __m128i p3 = _mm_maddubs_epi16(q2_3, q8_3);
1421 __m128i p4 = _mm_maddubs_epi16(q2_4, q8_4);
1422 __m128i p5 = _mm_maddubs_epi16(q2_5, q8_5);
1423 __m128i p6 = _mm_maddubs_epi16(q2_6, q8_6);
1424 __m128i p7 = _mm_maddubs_epi16(q2_7, q8_7);
1425
1426 // isum += (x[i].scales[is++] & 0xF) * isuml in 16bits*8*8 to 32bits*4*8
1427 __m128i shuffle = _mm_set1_epi16(0x0100);
1428 p0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p0);
1429 shuffle = _mm_add_epi16(shuffle, m2);
1430 p1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p1);
1431 shuffle = _mm_add_epi16(shuffle, m2);
1432 p2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p2);
1433 shuffle = _mm_add_epi16(shuffle, m2);
1434 p3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p3);
1435 shuffle = _mm_add_epi16(shuffle, m2);
1436 p4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p4);
1437 shuffle = _mm_add_epi16(shuffle, m2);
1438 p5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p5);
1439 shuffle = _mm_add_epi16(shuffle, m2);
1440 p6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p6);
1441 shuffle = _mm_add_epi16(shuffle, m2);
1442 p7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p7);
1443
1444 p0 = _mm_add_epi32(p0, p1);
1445 p2 = _mm_add_epi32(p2, p3);
1446 p4 = _mm_add_epi32(p4, p5);
1447 p6 = _mm_add_epi32(p6, p7);
1448
1449 // isum in 32bits*4*2
1450 sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p0, p2));
1451 sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p4, p6));
1452 }
1453
1454 // sumf += dall * isum - dmin * summs in 32bits
1455 __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
1456 acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dall), _mm256_cvtepi32_ps(sumi)), acc);
1457 }
1458
1459 *s = hsum_float_8(acc);
1460
1461#else
1462 UNUSED(x);
1463 UNUSED(y);
1464 UNUSED(nb);
1465 ggml_vec_dot_q2_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
1466#endif
1467}
1468
1469void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1470 assert(n % QK_K == 0);
1471 assert(nrc == 1);
1472 UNUSED(nrc);
1473 UNUSED(bx);
1474 UNUSED(by);
1475 UNUSED(bs);
1476
1477 const uint32_t kmask1 = 0x03030303;
1478 const uint32_t kmask2 = 0x0f0f0f0f;
1479
1480 const block_q3_K * GGML_RESTRICT x = vx;
1481 const block_q8_K * GGML_RESTRICT y = vy;
1482
1483 const int nb = n / QK_K;
1484
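    // q3_K stores each quant as 2 low bits in qs plus 1 high bit in hmask; the decoded value is
    // (low | high << 2) - 4, i.e. low - 4 when the high bit is clear and low when it is set
    // (see the generic reference implementation). The 16 six-bit group scales are packed into the
    // 12-byte scales field and are unpacked below with kmask1/kmask2, then re-centered by subtracting 32.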
1485#if defined __AVX2__
1486
    const __m256i m3 = _mm256_set1_epi8(3);
    const __m256i mone = _mm256_set1_epi8(1);
    const __m128i m32 = _mm_set1_epi8(32);
1490
1491 __m256 acc = _mm256_setzero_ps();
1492
1493 uint32_t aux[3];
1494
1495 for (int i = 0; i < nb; ++i) {
1496
1497 const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1498
1499 const uint8_t * GGML_RESTRICT q3 = x[i].qs;
1500 const int8_t * GGML_RESTRICT q8 = y[i].qs;
1501
1502 // Set up scales
        memcpy(aux, x[i].scales, 12);
        __m128i scales128 = _mm_set_epi32(
            ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4),
            ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4),
            (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
            (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
        scales128 = _mm_sub_epi8(scales128, m32);
        const __m256i all_scales = _mm256_cvtepi8_epi16(scales128);
1511 const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
1512 const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
1513 const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)};
1514
1515 // high bit
        const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].hmask);
1517
1518 // integer accumulator
1519 __m256i sumi = _mm256_setzero_si256();
1520
1521 int bit = 0;
1522 int is = 0;
1523
1524 for (int j = 0; j < QK_K/128; ++j) {
1525 // load low 2 bits
            const __m256i q3bits = _mm256_loadu_si256((const __m256i*)q3); q3 += 32;

            // prepare low and high bits
            const __m256i q3l_0 = _mm256_and_si256(q3bits, m3);
            const __m256i q3h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
            ++bit;

            const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 2), m3);
            const __m256i q3h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
            ++bit;

            const __m256i q3l_2 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 4), m3);
            const __m256i q3h_2 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
            ++bit;

            const __m256i q3l_3 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 6), m3);
            const __m256i q3h_3 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
1543 ++bit;
1544
1545 // load Q8 quants
            const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
            const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
            const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
            const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;

            // Dot product: the 2 low bits and the 1 high bit are multiplied with q8 separately so that
            // _mm256_maddubs_epi16 can be used, and the two products are then subtracted. q3h_x is 4 when
            // the high bit is clear and 0 when it is set, so (q3l - q3h) is the signed quant in [-4, 3].
            __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0);
            __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1);
            __m256i q8s_2 = _mm256_maddubs_epi16(q3h_2, q8_2);
            __m256i q8s_3 = _mm256_maddubs_epi16(q3h_3, q8_3);

            __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0);
            __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1);
            __m256i p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2);
            __m256i p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3);

            p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
            p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
            p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
            p16_3 = _mm256_sub_epi16(p16_3, q8s_3);

            // multiply with scales
            p16_0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0);
            p16_1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1);
            p16_2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2);
            p16_3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3);

            // accumulate
            p16_0 = _mm256_add_epi32(p16_0, p16_1);
            p16_2 = _mm256_add_epi32(p16_2, p16_3);
            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_2));
1579
1580 }
1581
1582 // multiply with block scale and accumulate
        acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
1584
1585 }
1586
    *s = hsum_float_8(acc);
1588
1589#elif defined __AVX__
1590
1591 const __m128i m3 = _mm_set1_epi8(3);
1592 const __m128i mone = _mm_set1_epi8(1);
1593 const __m128i m32 = _mm_set1_epi8(32);
1594 const __m128i m2 = _mm_set1_epi8(2);
1595
1596 __m256 acc = _mm256_setzero_ps();
1597
1598 const uint32_t *aux;
1599
1600 for (int i = 0; i < nb; ++i) {
1601
1602 const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1603
1604 const uint8_t * GGML_RESTRICT q3 = x[i].qs;
1605 const int8_t * GGML_RESTRICT q8 = y[i].qs;
1606
1607 // Set up scales
1608 aux = (const uint32_t *)x[i].scales;
1609 __m128i scales128 = _mm_set_epi32(
1610 ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4),
1611 ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4),
1612 (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
1613 (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
1614 scales128 = _mm_sub_epi8(scales128, m32);
1615 const __m128i scales_0 = _mm_cvtepi8_epi16(scales128);
1616 const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales128, scales128));
1617 const __m128i scales[2] = { scales_0, scales_1 };
1618
1619 // high bit *128*2 from block_q3_K.hmask[QK_K/8]
1620 const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].hmask[0]);
1621 const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].hmask[16]);
1622
1623 // integer accumulator
1624 __m128i sumi_0 = _mm_setzero_si128();
1625 __m128i sumi_1 = _mm_setzero_si128();
1626
1627 for (int j = 0; j < QK_K/128; ++j) {
1628 // load low 2 bits *64*2 from block_q3_K.qs[QK_K/4]
1629 const __m128i q3bits_0 = _mm_loadu_si128((const __m128i*)q3); q3 += 16;
1630 const __m128i q3bits_1 = _mm_loadu_si128((const __m128i*)q3); q3 += 16;
1631
1632 // prepare low and high bits
1633 const int bit = j << 2;
1634
1635 const __m128i q3l_0 = _mm_and_si128(q3bits_0, m3);
1636 const __m128i q3l_1 = _mm_and_si128(q3bits_1, m3);
1637 const __m128i q3h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit)), bit), 2);
1638 const __m128i q3h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit)), bit), 2);
1639
1640 const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 2), m3);
1641 const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 2), m3);
1642 const __m128i q3h_2 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+1)), bit+1), 2);
1643 const __m128i q3h_3 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+1)), bit+1), 2);
1644
1645 const __m128i q3l_4 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 4), m3);
1646 const __m128i q3l_5 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 4), m3);
1647 const __m128i q3h_4 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+2)), bit+2), 2);
1648 const __m128i q3h_5 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+2)), bit+2), 2);
1649
1650 const __m128i q3l_6 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 6), m3);
1651 const __m128i q3l_7 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 6), m3);
1652 const __m128i q3h_6 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+3)), bit+3), 2);
1653 const __m128i q3h_7 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+3)), bit+3), 2);
1654
1655 // load Q8 quants from block_q8_K.qs[QK_K]
1656 const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1657 const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1658 const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1659 const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1660 const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1661 const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1662 const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1663 const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1664
            // Dot product: the 2 low bits and the 1 high bit are multiplied with q8 separately so that
            // _mm_maddubs_epi16 can be used, and the two products are then subtracted. q3h_x is 4 when
            // the high bit is clear and 0 when it is set, so (q3l - q3h) is the signed quant in [-4, 3].
1668 __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, q8_0);
1669 __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, q8_1);
1670 __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, q8_2);
1671 __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, q8_3);
1672 __m128i q8s_4 = _mm_maddubs_epi16(q3h_4, q8_4);
1673 __m128i q8s_5 = _mm_maddubs_epi16(q3h_5, q8_5);
1674 __m128i q8s_6 = _mm_maddubs_epi16(q3h_6, q8_6);
1675 __m128i q8s_7 = _mm_maddubs_epi16(q3h_7, q8_7);
1676
1677 __m128i p16_0 = _mm_maddubs_epi16(q3l_0, q8_0);
1678 __m128i p16_1 = _mm_maddubs_epi16(q3l_1, q8_1);
1679 __m128i p16_2 = _mm_maddubs_epi16(q3l_2, q8_2);
1680 __m128i p16_3 = _mm_maddubs_epi16(q3l_3, q8_3);
1681 __m128i p16_4 = _mm_maddubs_epi16(q3l_4, q8_4);
1682 __m128i p16_5 = _mm_maddubs_epi16(q3l_5, q8_5);
1683 __m128i p16_6 = _mm_maddubs_epi16(q3l_6, q8_6);
1684 __m128i p16_7 = _mm_maddubs_epi16(q3l_7, q8_7);
1685
1686 p16_0 = _mm_sub_epi16(p16_0, q8s_0);
1687 p16_1 = _mm_sub_epi16(p16_1, q8s_1);
1688 p16_2 = _mm_sub_epi16(p16_2, q8s_2);
1689 p16_3 = _mm_sub_epi16(p16_3, q8s_3);
1690 p16_4 = _mm_sub_epi16(p16_4, q8s_4);
1691 p16_5 = _mm_sub_epi16(p16_5, q8s_5);
1692 p16_6 = _mm_sub_epi16(p16_6, q8s_6);
1693 p16_7 = _mm_sub_epi16(p16_7, q8s_7);
1694
1695 // multiply with scales
1696 __m128i shuffle = _mm_set1_epi16(0x0100);
1697 p16_0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_0);
1698 shuffle = _mm_add_epi16(shuffle, m2);
1699 p16_1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_1);
1700 shuffle = _mm_add_epi16(shuffle, m2);
1701 p16_2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_2);
1702 shuffle = _mm_add_epi16(shuffle, m2);
1703 p16_3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_3);
1704 shuffle = _mm_add_epi16(shuffle, m2);
1705 p16_4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_4);
1706 shuffle = _mm_add_epi16(shuffle, m2);
1707 p16_5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_5);
1708 shuffle = _mm_add_epi16(shuffle, m2);
1709 p16_6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_6);
1710 shuffle = _mm_add_epi16(shuffle, m2);
1711 p16_7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_7);
1712
1713 // accumulate
1714 p16_0 = _mm_add_epi32(p16_0, p16_1);
1715 p16_2 = _mm_add_epi32(p16_2, p16_3);
1716 p16_4 = _mm_add_epi32(p16_4, p16_5);
1717 p16_6 = _mm_add_epi32(p16_6, p16_7);
1718 sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
1719 sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_4, p16_6));
1720
1721 }
1722
1723 // multiply with block scale and accumulate
1724 __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
1725 acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
1726
1727 }
1728
1729 *s = hsum_float_8(acc);
1730
1731#else
1732 UNUSED(kmask1);
1733 UNUSED(kmask2);
1734 UNUSED(x);
1735 UNUSED(y);
1736 UNUSED(nb);
1737 ggml_vec_dot_q3_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
1738#endif
1739}
1740
1741void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1742 assert(n % QK_K == 0);
1743 assert(nrc == 1);
1744 UNUSED(nrc);
1745 UNUSED(bx);
1746 UNUSED(by);
1747 UNUSED(bs);
1748
1749 const block_q4_K * GGML_RESTRICT x = vx;
1750 const block_q8_K * GGML_RESTRICT y = vy;
1751
1752 const int nb = n / QK_K;
1753
1754 static const uint32_t kmask1 = 0x3f3f3f3f;
1755 static const uint32_t kmask2 = 0x0f0f0f0f;
1756 static const uint32_t kmask3 = 0x03030303;
1757
1758 uint32_t utmp[4];
1759
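    // q4_K packs eight 6-bit scales and eight 6-bit mins into the 12-byte scales field; the shifts and
    // kmask operations below repack them into utmp so that utmp[0..1] holds one scale per byte and
    // utmp[2..3] one min per byte. Per super-block the result is d * sum_g(sc_g * q4.q8) minus
    // dmin * sum_g(min_g * bsum_g), the same structure as q2_K above (dmin is again pre-negated).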
1760#if defined __AVX2__
1761
    const __m256i m4 = _mm256_set1_epi8(0xF);
1763
1764 __m256 acc = _mm256_setzero_ps();
1765 __m128 acc_m = _mm_setzero_ps();
1766
1767 for (int i = 0; i < nb; ++i) {
1768
1769 const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1770 const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
1771
        memcpy(utmp, x[i].scales, 12);
1773 utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
1774 const uint32_t uaux = utmp[1] & kmask1;
1775 utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
1776 utmp[2] = uaux;
1777 utmp[0] &= kmask1;
1778
1779 const uint8_t * GGML_RESTRICT q4 = x[i].qs;
1780 const int8_t * GGML_RESTRICT q8 = y[i].qs;
1781
        const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));

        const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums);
        const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
        const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
        acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m);
1788
1789 const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
1790 const __m256i scales = MM256_SET_M128I(sc128, sc128);
1791
1792 __m256i sumi = _mm256_setzero_si256();
1793
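        // each inner iteration covers 64 quants: the low nibbles of one 32-byte q4 load are the first
        // 32 values and the high nibbles the next 32, each dotted with its own q8 load and weighted by
        // the per-32 scale broadcast through get_scale_shuffle_k4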
1794 for (int j = 0; j < QK_K/64; ++j) {
1795
            const __m256i scale_l = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0));
            const __m256i scale_h = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1));

            const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
            const __m256i q4l = _mm256_and_si256(q4bits, m4);
            const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4);

            const __m256i q8l = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
            __m256i p16l = _mm256_maddubs_epi16(q4l, q8l);
            p16l = _mm256_madd_epi16(scale_l, p16l);

            const __m256i q8h = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
            __m256i p16h = _mm256_maddubs_epi16(q4h, q8h);
            p16h = _mm256_madd_epi16(scale_h, p16h);
            const __m256i sumj = _mm256_add_epi32(p16l, p16h);

            sumi = _mm256_add_epi32(sumi, sumj);
        }

        __m256 vd = _mm256_set1_ps(d);
        acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
1817
1818 }
1819
    acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));
    acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));

    *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m);
1824
1825#elif defined __AVX__
1826
1827 const __m128i m4 = _mm_set1_epi8(0xF);
1828 const __m128i m2 = _mm_set1_epi8(0x2);
1829
1830 __m256 acc = _mm256_setzero_ps();
1831 __m128 acc_m = _mm_setzero_ps();
1832
1833 for (int i = 0; i < nb; ++i) {
1834
1835 const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1836 const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
1837
1838 const uint8_t * GGML_RESTRICT q4 = x[i].qs;
1839 const int8_t * GGML_RESTRICT q8 = y[i].qs;
1840
1841 memcpy(utmp, x[i].scales, 12);
1842 utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
1843 const uint32_t uaux = utmp[1] & kmask1;
1844 utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
1845 utmp[2] = uaux;
1846 utmp[0] &= kmask1;
1847
1848 const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]);
1849 const __m128i scales = _mm_cvtepu8_epi16(utmps);
1850 const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps));
1851
1852 const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]);
1853 const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]);
1854 const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1);
1855 const __m128i prod = _mm_madd_epi16(mins, q8s);
1856 acc_m = _mm_add_ps(_mm_mul_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod)), acc_m);
1857
1858 __m128i sumi_0 = _mm_setzero_si128();
1859 __m128i sumi_1 = _mm_setzero_si128();
1860
1861 __m128i shuffle = _mm_set1_epi16(0x0100);
1862 for (int j = 0; j < QK_K/64; ++j) {
1863
1864 const __m128i scale_l = _mm_shuffle_epi8(scales, shuffle);
1865 shuffle = _mm_add_epi16(shuffle, m2);
1866 const __m128i scale_h = _mm_shuffle_epi8(scales, shuffle);
1867 shuffle = _mm_add_epi16(shuffle, m2);
1868
1869 __m128i q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
1870 const __m128i q4l_0 = _mm_and_si128(q4bits, m4);
1871 const __m128i q4h_0 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4);
1872 q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
1873 const __m128i q4l_1 = _mm_and_si128(q4bits, m4);
1874 const __m128i q4h_1 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4);
1875
1876 const __m128i q8l_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1877 __m128i p16l = _mm_maddubs_epi16(q4l_0, q8l_0);
1878 p16l = _mm_madd_epi16(scale_l, p16l);
1879 sumi_0 = _mm_add_epi32(sumi_0, p16l);
1880 const __m128i q8l_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1881 p16l = _mm_maddubs_epi16(q4l_1, q8l_1);
1882 p16l = _mm_madd_epi16(scale_l, p16l);
1883 sumi_1 = _mm_add_epi32(sumi_1, p16l);
1884
1885 const __m128i q8h_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1886 __m128i p16h = _mm_maddubs_epi16(q4h_0, q8h_0);
1887 p16h = _mm_madd_epi16(scale_h, p16h);
1888 sumi_0 = _mm_add_epi32(sumi_0, p16h);
1889 const __m128i q8h_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
1890 p16h = _mm_maddubs_epi16(q4h_1, q8h_1);
1891 p16h = _mm_madd_epi16(scale_h, p16h);
1892 sumi_1 = _mm_add_epi32(sumi_1, p16h);
1893
1894 }
1895
1896 __m256 vd = _mm256_set1_ps(d);
1897 __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
1898 acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);
1899
1900 }
1901
1902 acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));
1903 acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));
1904
1905 *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m);
1906
1907#else
1908 UNUSED(x);
1909 UNUSED(y);
1910 UNUSED(nb);
1911 UNUSED(kmask1);
1912 UNUSED(kmask2);
1913 UNUSED(kmask3);
1914 UNUSED(utmp);
1915 ggml_vec_dot_q4_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
1916#endif
1917}
1918
1919void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1920 assert(n % QK_K == 0);
1921 assert(nrc == 1);
1922 UNUSED(nrc);
1923 UNUSED(bx);
1924 UNUSED(by);
1925 UNUSED(bs);
1926
1927 const block_q5_K * GGML_RESTRICT x = vx;
1928 const block_q8_K * GGML_RESTRICT y = vy;
1929
1930 const int nb = n / QK_K;
1931
1932 static const uint32_t kmask1 = 0x3f3f3f3f;
1933 static const uint32_t kmask2 = 0x0f0f0f0f;
1934 static const uint32_t kmask3 = 0x03030303;
1935
1936 uint32_t utmp[4];
1937
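    // q5_K reuses the q4_K scale/min packing (unpacked into utmp below) and adds one high bit per quant
    // in qh: q5h places the selected high bit at bit 4, so the decoded quant is the low nibble plus 16
    // whenever that bit is set.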
1938#if defined __AVX2__
1939
    const __m256i m4 = _mm256_set1_epi8(0xF);
    const __m128i mzero = _mm_setzero_si128();
    const __m256i mone = _mm256_set1_epi8(1);
1943
1944 __m256 acc = _mm256_setzero_ps();
1945
1946 float summs = 0.f;
1947
1948 for (int i = 0; i < nb; ++i) {
1949 const uint8_t * GGML_RESTRICT q5 = x[i].qs;
1950 const int8_t * GGML_RESTRICT q8 = y[i].qs;
1951
1952 const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1953 const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
1954
        memcpy(utmp, x[i].scales, 12);
1956 utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
1957 const uint32_t uaux = utmp[1] & kmask1;
1958 utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
1959 utmp[2] = uaux;
1960 utmp[0] &= kmask1;
1961
        const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));

        const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums);
        const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
        const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
        const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero);
1968 summs += dmin * _mm_extract_epi32(hsum, 0);
1969
1970 const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
1971 const __m256i scales = MM256_SET_M128I(sc128, sc128);
1972
        const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].qh);
1974 __m256i hmask = mone;
1975
1976 __m256i sumi = _mm256_setzero_si256();
1977
1978 int bit = 0;
1979
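        // each inner iteration covers 64 quants: low and high nibbles of one 32-byte q5 load, with
        // hmask/bit advancing once per 32-value half so the matching plane of the high bits is folded in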
1980 for (int j = 0; j < QK_K/64; ++j) {
1981
            const __m256i scale_0 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0));
            const __m256i scale_1 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1));

            const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); q5 += 32;

            const __m256i q5l_0 = _mm256_and_si256(q5bits, m4);
            const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4);
            const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0);
            hmask = _mm256_slli_epi16(hmask, 1);

            const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4);
            const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4);
            const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1);
            hmask = _mm256_slli_epi16(hmask, 1);

            const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
            const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;

            __m256i p16_0 = _mm256_maddubs_epi16(q5_0, q8_0);
            __m256i p16_1 = _mm256_maddubs_epi16(q5_1, q8_1);

            p16_0 = _mm256_madd_epi16(scale_0, p16_0);
            p16_1 = _mm256_madd_epi16(scale_1, p16_1);

            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
2007
2008 }
2009
        __m256 vd = _mm256_set1_ps(d);
        acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
2012
2013 }
2014
    *s = hsum_float_8(acc) + summs;
2016
2017#elif defined __AVX__
2018
2019 const __m128i m4 = _mm_set1_epi8(0xF);
2020 const __m128i mzero = _mm_setzero_si128();
2021 const __m128i mone = _mm_set1_epi8(1);
2022 const __m128i m2 = _mm_set1_epi8(2);
2023
2024 __m256 acc = _mm256_setzero_ps();
2025
2026 float summs = 0.f;
2027
2028 for (int i = 0; i < nb; ++i) {
2029
2030 const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
2031 const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
2032
2033 const uint8_t * GGML_RESTRICT q5 = x[i].qs;
2034 const int8_t * GGML_RESTRICT q8 = y[i].qs;
2035
2036 memcpy(utmp, x[i].scales, 12);
2037 utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
2038 const uint32_t uaux = utmp[1] & kmask1;
2039 utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
2040 utmp[2] = uaux;
2041 utmp[0] &= kmask1;
2042
2043 const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]);
2044 const __m128i scales = _mm_cvtepu8_epi16(utmps);
2045 const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps));
2046
2047 const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]);
2048 const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]);
2049 const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1);
2050 const __m128i prod = _mm_madd_epi16(mins, q8s);
2051 const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero);
2052 summs += dmin * _mm_extract_epi32(hsum, 0);
2053
2054 const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].qh[0]);
2055 const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].qh[16]);
2056 __m128i hmask = mone;
2057
2058 __m128i sumi_0 = _mm_setzero_si128();
2059 __m128i sumi_1 = _mm_setzero_si128();
2060
2061 int bit = 0;
2062
2063 __m128i shuffle = _mm_set1_epi16(0x0100);
2064 for (int j = 0; j < QK_K/64; ++j) {
2065
2066 const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle);
2067 shuffle = _mm_add_epi16(shuffle, m2);
2068 const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle);
2069 shuffle = _mm_add_epi16(shuffle, m2);
2070
2071 const __m128i q5bits_0 = _mm_loadu_si128((const __m128i*)q5); q5 += 16;
2072 const __m128i q5bits_1 = _mm_loadu_si128((const __m128i*)q5); q5 += 16;
2073
2074 __m128i q5l_0 = _mm_and_si128(q5bits_0, m4);
2075 __m128i q5l_1 = _mm_and_si128(q5bits_1, m4);
2076 __m128i q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4);
2077 __m128i q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4);
2078 __m128i q5_0 = _mm_add_epi8(q5l_0, q5h_0);
2079 __m128i q5_1 = _mm_add_epi8(q5l_1, q5h_1);
2080 hmask = _mm_slli_epi16(hmask, 1);
2081
2082 __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2083 __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2084 __m128i p16_0 = _mm_maddubs_epi16(q5_0, q8_0);
2085 __m128i p16_1 = _mm_maddubs_epi16(q5_1, q8_1);
2086 p16_0 = _mm_madd_epi16(scale_0, p16_0);
2087 p16_1 = _mm_madd_epi16(scale_0, p16_1);
2088
2089 q5l_0 = _mm_and_si128(_mm_srli_epi16(q5bits_0, 4), m4);
2090 q5l_1 = _mm_and_si128(_mm_srli_epi16(q5bits_1, 4), m4);
2091 q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4);
2092 q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4);
2093 q5_0 = _mm_add_epi8(q5l_0, q5h_0);
2094 q5_1 = _mm_add_epi8(q5l_1, q5h_1);
2095 hmask = _mm_slli_epi16(hmask, 1);
2096
2097 q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2098 q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2099 __m128i p16_2 = _mm_maddubs_epi16(q5_0, q8_0);
2100 __m128i p16_3 = _mm_maddubs_epi16(q5_1, q8_1);
2101 p16_2 = _mm_madd_epi16(scale_1, p16_2);
2102 p16_3 = _mm_madd_epi16(scale_1, p16_3);
2103
2104 sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
2105 sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
2106
2107 }
2108
2109 __m256 vd = _mm256_set1_ps(d);
2110 __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
2111 acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);
2112
2113 }
2114
2115 *s = hsum_float_8(acc) + summs;
2116
2117#else
2118 UNUSED(x);
2119 UNUSED(y);
2120 UNUSED(nb);
2121 UNUSED(kmask1);
2122 UNUSED(kmask2);
2123 UNUSED(kmask3);
2124 UNUSED(utmp);
2125 ggml_vec_dot_q5_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
2126#endif
2127}
2128
2129void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
2130 assert(n % QK_K == 0);
2131 assert(nrc == 1);
2132 UNUSED(nrc);
2133 UNUSED(bx);
2134 UNUSED(by);
2135 UNUSED(bs);
2136
2137 const block_q6_K * GGML_RESTRICT x = vx;
2138 const block_q8_K * GGML_RESTRICT y = vy;
2139
2140 const int nb = n / QK_K;
2141
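    // q6_K quants are 6-bit: 4 low bits in ql and 2 high bits in qh, decoded as (low | high << 4) - 32
    // and weighted by an 8-bit group scale. The AVX2 path folds the -32 offset in by subtracting
    // maddubs(32, q8) per 32-value slice; the AVX path instead uses
    //     sum((q - 32) * q8) = sum(q * q8) - 32 * sum(q8)
    // and removes 32 * sum_g(scale_g * bsum_g) once per super-block (the << 5 below).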
2142#if defined __AVX2__
2143
    const __m256i m4 = _mm256_set1_epi8(0xF);
    const __m256i m2 = _mm256_set1_epi8(3);
    const __m256i m32s = _mm256_set1_epi8(32);
2147
2148 __m256 acc = _mm256_setzero_ps();
2149
2150 for (int i = 0; i < nb; ++i) {
2151
2152 const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
2153
2154 const uint8_t * GGML_RESTRICT q4 = x[i].ql;
2155 const uint8_t * GGML_RESTRICT qh = x[i].qh;
2156 const int8_t * GGML_RESTRICT q8 = y[i].qs;
2157
        const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);
2159
2160 __m256i sumi = _mm256_setzero_si256();
2161
2162 int is = 0;
2163
2164 for (int j = 0; j < QK_K/128; ++j) {
2165
            const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0));
            const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1));
            const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2));
            const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3));
            is += 4;

            const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
            const __m256i q4bits2 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
            const __m256i q4bitsH = _mm256_loadu_si256((const __m256i*)qh); qh += 32;

            const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m2), 4);
            const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 2), m2), 4);
            const __m256i q4h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 4), m2), 4);
            const __m256i q4h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 6), m2), 4);

            const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);
            const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1);
            const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2);
            const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3);

            const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
            const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
            const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
            const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;

            __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0);
            __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1);
            __m256i q8s_2 = _mm256_maddubs_epi16(m32s, q8_2);
            __m256i q8s_3 = _mm256_maddubs_epi16(m32s, q8_3);

            __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0);
            __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1);
            __m256i p16_2 = _mm256_maddubs_epi16(q4_2, q8_2);
            __m256i p16_3 = _mm256_maddubs_epi16(q4_3, q8_3);

            p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
            p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
            p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
            p16_3 = _mm256_sub_epi16(p16_3, q8s_3);

            p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0);
            p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1);
            p16_2 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_2), p16_2);
            p16_3 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_3), p16_3);

            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_2, p16_3));
2213
2214 }
2215
        acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
2217 }
2218
    *s = hsum_float_8(acc);
2220
2221#elif defined __AVX__
2222
2223 const __m128i m3 = _mm_set1_epi8(3);
2224 const __m128i m15 = _mm_set1_epi8(15);
2225
2226 __m256 acc = _mm256_setzero_ps();
2227
2228 for (int i = 0; i < nb; ++i) {
2229
2230 const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
2231
2232 const uint8_t * GGML_RESTRICT q4 = x[i].ql;
2233 const uint8_t * GGML_RESTRICT qh = x[i].qh;
2234 const int8_t * GGML_RESTRICT q8 = y[i].qs;
2235
2236 // handle the q6_k -32 offset separately using bsums
2237 const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)y[i].bsums);
2238 const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)y[i].bsums + 1);
2239 const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);
2240 const __m128i scales_16_0 = _mm_cvtepi8_epi16(scales);
2241 const __m128i scales_16_1 = _mm_cvtepi8_epi16(_mm_bsrli_si128(scales, 8));
2242 const __m128i q8sclsub_0 = _mm_slli_epi32(_mm_madd_epi16(q8sums_0, scales_16_0), 5);
2243 const __m128i q8sclsub_1 = _mm_slli_epi32(_mm_madd_epi16(q8sums_1, scales_16_1), 5);
2244
2245 __m128i sumi_0 = _mm_setzero_si128();
2246 __m128i sumi_1 = _mm_setzero_si128();
2247
2248 int is = 0;
2249
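        // each inner iteration covers 128 quants: two 16-byte qh loads provide the 2-bit high planes
        // (masks 3, 12, 48 and 0xC0 with matching shifts place them at bits 4..5), which are OR-ed onto
        // the low nibbles taken from four 16-byte ql loads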
2250 for (int j = 0; j < QK_K/128; ++j) {
2251
2252 const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i*)qh); qh += 16;
2253 const __m128i q4bitsH_1 = _mm_loadu_si128((const __m128i*)qh); qh += 16;
2254
2255 const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4);
2256 const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4);
2257 const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(12)), 2);
2258 const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(12)), 2);
2259 const __m128i q4h_4 = _mm_and_si128(q4bitsH_0, _mm_set1_epi8(48));
2260 const __m128i q4h_5 = _mm_and_si128(q4bitsH_1, _mm_set1_epi8(48));
2261 const __m128i q4h_6 = _mm_srli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(-64)), 2);
2262 const __m128i q4h_7 = _mm_srli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(-64)), 2);
2263
2264 const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
2265 const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
2266 const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
2267 const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
2268
2269 const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m15), q4h_0);
2270 const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m15), q4h_1);
2271 const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m15), q4h_2);
2272 const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m15), q4h_3);
2273 const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m15), q4h_4);
2274 const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m15), q4h_5);
2275 const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m15), q4h_6);
2276 const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m15), q4h_7);
2277
2278 const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2279 const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2280 const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2281 const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2282 const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2283 const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2284 const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2285 const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
2286
2287 __m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0);
2288 __m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1);
2289 __m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2);
2290 __m128i p16_3 = _mm_maddubs_epi16(q4_3, q8_3);
2291 __m128i p16_4 = _mm_maddubs_epi16(q4_4, q8_4);
2292 __m128i p16_5 = _mm_maddubs_epi16(q4_5, q8_5);
2293 __m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6);
2294 __m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7);
2295
2296 const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0));
2297 const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1));
2298 const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2));
2299 const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3));
2300 is += 4;
2301
2302 p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
2303 p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_0, 8)), p16_1);
2304 p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
2305 p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_1, 8)), p16_3);
2306 p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4);
2307 p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_2, 8)), p16_5);
2308 p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6);
2309 p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_3, 8)), p16_7);
2310
2311 sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
2312 sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
2313 sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_4, p16_6));
2314 sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_5, p16_7));
2315
2316 }
2317
2318 sumi_0 = _mm_sub_epi32(sumi_0, q8sclsub_0);
2319 sumi_1 = _mm_sub_epi32(sumi_1, q8sclsub_1);
2320 const __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
2321 acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi)), acc);
2322 }
2323
2324 *s = hsum_float_8(acc);
2325
2326#else
2327 UNUSED(x);
2328 UNUSED(y);
2329 UNUSED(nb);
2330 ggml_vec_dot_q6_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
2331#endif
2332}
2333
2334#if defined (__AVX__) || defined (__AVX2__)
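// 128 sign patterns of 8 bytes each, indexed by a 7-bit field: bit k of the index negates byte k and the
// sign of the 8th byte is chosen so that every pattern contains an even number of -1s. This is how the
// iq2 formats store the signs of 8 values in only 7 bits (the last sign is implied by parity).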
2335static const int8_t keven_signs_q2xs[1024] = {
2336 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1,
2337 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1,
2338 1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, -1,
2339 1, 1, -1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, 1,
2340 1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1,
2341 1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, 1,
2342 1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, 1,
2343 1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1,
2344 1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, -1,
2345 1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, 1,
2346 1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, 1,
2347 1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, -1,
2348 1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, 1,
2349 1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, -1,
2350 1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, -1,
2351 1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, 1,
2352 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, -1,
2353 1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, 1,
2354 1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, 1,
2355 1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, -1,
2356 1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, 1,
2357 1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, -1,
2358 1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1,
2359 1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, 1,
2360 1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, 1,
2361 1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, -1,
2362 1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1,
2363 1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, 1,
2364 1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, -1,
2365 1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, 1,
2366 1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1,
2367 1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1,
2368};
2369#endif
2370
2371void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
2372 assert(n % QK_K == 0);
2373 assert(nrc == 1);
2374 UNUSED(nrc);
2375 UNUSED(bx);
2376 UNUSED(by);
2377 UNUSED(bs);
2378
2379 const block_iq2_xxs * GGML_RESTRICT x = vx;
2380 const block_q8_K * GGML_RESTRICT y = vy;
2381
2382 const int nb = n / QK_K;
2383
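    // every group of 32 quants is described by a pair of 32-bit words: the first packs four byte indices
    // into iq2xxs_grid (8 values each), the second packs four 7-bit selectors into keven_signs_q2xs plus
    // a 4-bit sub-block scale in bits 28..31. The (2*ls+1) weight combined with the trailing 0.125f
    // reproduces the format's d * (2*ls+1) / 8 sub-block scale.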
2384#if defined(__AVX2__)
2385
2386 const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
2387
2388 uint32_t aux32[4];
2389 const uint8_t * aux8 = (const uint8_t *)aux32;
2390
2391 __m256 accumf = _mm256_setzero_ps();
2392 for (int i = 0; i < nb; ++i) {
2393 const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
2394 const uint16_t * GGML_RESTRICT q2 = x[i].qs;
2395 const int8_t * GGML_RESTRICT q8 = y[i].qs;
2396 __m256i sumi1 = _mm256_setzero_si256();
2397 __m256i sumi2 = _mm256_setzero_si256();
2398 for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
            const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
            const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
            memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
            const __m256i q2_1 = _mm256_set_epi64x(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
            const __m256i q2_2 = _mm256_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]);
            const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
                                                   signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
            const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127],
                                                   signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]);
            const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
            const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
            const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
            const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
            const uint16_t ls1 = aux32[1] >> 28;
            const uint16_t ls2 = aux32[3] >> 28;
            const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));
            const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));
            sumi1 = _mm256_add_epi32(sumi1, p1);
            sumi2 = _mm256_add_epi32(sumi2, p2);
2418 }
2419
        accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
2421
2422 }
2423
    *s = 0.125f * hsum_float_8(accumf);
2425
2426#elif defined(__AVX__)
2427 const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
2428
2429 uint32_t aux32[4];
2430 const uint8_t * aux8 = (const uint8_t *)aux32;
2431
2432 __m256 accumf = _mm256_setzero_ps();
2433 for (int i = 0; i < nb; ++i) {
2434 const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
2435 const uint16_t * GGML_RESTRICT q2 = x[i].qs;
2436 const int8_t * GGML_RESTRICT q8 = y[i].qs;
2437 __m128i sumi1_0 = _mm_setzero_si128();
2438 __m128i sumi1_1 = _mm_setzero_si128();
2439 __m128i sumi2_0 = _mm_setzero_si128();
2440 __m128i sumi2_1 = _mm_setzero_si128();
2441 for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
2442 const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2443 const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2444 const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2445 const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2446 memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
2447 const __m128i q2_1_0 = _mm_set_epi64x(iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
2448 const __m128i q2_1_1 = _mm_set_epi64x(iq2xxs_grid[aux8[3]], iq2xxs_grid[aux8[2]]);
2449 const __m128i q2_2_0 = _mm_set_epi64x(iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]);
2450 const __m128i q2_2_1 = _mm_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]]);
2451 const __m128i s2_1_0 = _mm_set_epi64x(signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
2452 const __m128i s2_1_1 = _mm_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127]);
2453 const __m128i s2_2_0 = _mm_set_epi64x(signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]);
2454 const __m128i s2_2_1 = _mm_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127]);
2455 const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, s2_1_0);
2456 const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, s2_1_1);
2457 const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, s2_2_0);
2458 const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, s2_2_1);
2459 const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
2460 const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
2461 const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
2462 const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
2463 const uint16_t ls1 = aux32[1] >> 28;
2464 const uint16_t ls2 = aux32[3] >> 28;
2465 const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1));
2466 const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1));
2467 const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1));
2468 const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1));
2469 sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
2470 sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
2471 sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
2472 sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
2473 }
2474
2475 accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
2476
2477 }
2478
2479 *s = 0.125f * hsum_float_8(accumf);
2480
2481#else
2482 UNUSED(x);
2483 UNUSED(y);
2484 UNUSED(nb);
2485 ggml_vec_dot_iq2_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
2486#endif
2487}
2488
2489void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
2490 assert(n % QK_K == 0);
2491 assert(nrc == 1);
2492 UNUSED(nrc);
2493 UNUSED(bx);
2494 UNUSED(by);
2495 UNUSED(bs);
2496
2497 const block_iq2_xs * GGML_RESTRICT x = vx;
2498 const block_q8_K * GGML_RESTRICT y = vy;
2499
2500 const int nb = n / QK_K;
2501
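    // iq2_xs widens the grid index to 9 bits: each 16-bit entry of qs keeps a grid index in bits 0..8
    // (the m511 mask below) with 7 explicit sign bits above it; the implied 8th (parity) sign bit is
    // reconstructed with k_bit_helper. Sub-block scales are 4-bit values in x[i].scales, applied as
    // 2*s+1 together with the trailing 0.125f, as in iq2_xxs.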
2502#if defined(__AVX2__)
2503
    const __m256i mone = _mm256_set1_epi8(1);
2505 static const char block_sign_shuffle_mask_1[32] = {
2506 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
2507 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06,
2508 };
2509 static const char block_sign_shuffle_mask_2[32] = {
2510 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a,
2511 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e,
2512 };
2513 static const uint8_t bit_selector_mask_bytes[32] = {
2514 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
2515 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
2516 };
2517
    const __m256i bit_selector_mask = _mm256_loadu_si256((const __m256i*)bit_selector_mask_bytes);
    const __m256i block_sign_shuffle_1 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_1);
    const __m256i block_sign_shuffle_2 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_2);
2521
2522 static const uint8_t k_bit_helper[32] = {
2523 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
2524 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
2525 };
    const __m256i bit_helper = _mm256_loadu_si256((const __m256i*)k_bit_helper);
    const __m256i m511 = _mm256_set1_epi16(511);
    const __m128i m4 = _mm_set1_epi8(0xf);
    const __m128i m1 = _mm_set1_epi8(1);
2530
2531 uint64_t aux64;
2532
2533 // somewhat hacky, but gives a significant boost in performance
2534 __m256i aux_gindex;
2535 const uint16_t * gindex = (const uint16_t *)&aux_gindex;
2536
2537 __m256 accumf = _mm256_setzero_ps();
2538 for (int i = 0; i < nb; ++i) {
2539 const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
2540 const uint16_t * GGML_RESTRICT q2 = x[i].qs;
2541 const int8_t * GGML_RESTRICT q8 = y[i].qs;
2542
        memcpy(&aux64, x[i].scales, 8);
        __m128i stmp = _mm_set1_epi64x(aux64);
        stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4));
        const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1);
2547
2548 __m256i sumi1 = _mm256_setzero_si256();
2549 __m256i sumi2 = _mm256_setzero_si256();
2550 for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) {
2551
            const __m256i q2_data = _mm256_loadu_si256((const __m256i*)q2); q2 += 16;
            aux_gindex = _mm256_and_si256(q2_data, m511);

            const __m256i partial_sign_bits = _mm256_srli_epi16(q2_data, 9);
            const __m256i partial_sign_bits_upper = _mm256_srli_epi16(q2_data, 13);
            const __m256i partial_sign_bits_for_counting = _mm256_xor_si256(partial_sign_bits, partial_sign_bits_upper);

            const __m256i odd_bits = _mm256_shuffle_epi8(bit_helper, partial_sign_bits_for_counting);
            const __m256i full_sign_bits = _mm256_or_si256(partial_sign_bits, odd_bits);

            const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
            const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
            const __m256i q8_3 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
            const __m256i q8_4 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;

            const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[ 3]], iq2xs_grid[gindex[ 2]],
                                                   iq2xs_grid[gindex[ 1]], iq2xs_grid[gindex[ 0]]);
            const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[gindex[ 7]], iq2xs_grid[gindex[ 6]],
                                                   iq2xs_grid[gindex[ 5]], iq2xs_grid[gindex[ 4]]);
            const __m256i q2_3 = _mm256_set_epi64x(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]],
                                                   iq2xs_grid[gindex[ 9]], iq2xs_grid[gindex[ 8]]);
            const __m256i q2_4 = _mm256_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]],
                                                   iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]);

            const __m128i full_signs_l = _mm256_castsi256_si128(full_sign_bits);
2577 const __m128i full_signs_h = _mm256_extractf128_si256(full_sign_bits, 1);
2578 const __m256i full_signs_1 = MM256_SET_M128I(full_signs_l, full_signs_l);
2579 const __m256i full_signs_2 = MM256_SET_M128I(full_signs_h, full_signs_h);
2580
2581 __m256i signs;
2582 signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_1);
2583 signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
2584 const __m256i q8s_1 = _mm256_sign_epi8(q8_1, _mm256_or_si256(signs, mone));
2585
2586 signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_2);
2587 signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
2588 const __m256i q8s_2 = _mm256_sign_epi8(q8_2, _mm256_or_si256(signs, mone));
2589
2590 signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_1);
2591 signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
2592 const __m256i q8s_3 = _mm256_sign_epi8(q8_3, _mm256_or_si256(signs, mone));
2593
2594 signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_2);
2595 signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
2596 const __m256i q8s_4 = _mm256_sign_epi8(q8_4, _mm256_or_si256(signs, mone));
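// Each sign byte (8 signs for one grid value) is broadcast over the 8 matching q8
// bytes by the block_sign shuffles, turned into a 0x00/0xFF byte mask by the per-bit
// compare against bit_selector_mask, and _mm256_sign_epi8 with (mask | mone) then
// negates exactly those q8 values whose sign bit is set.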
2597
2598 const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
2599 const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
2600 const __m256i dot3 = _mm256_maddubs_epi16(q2_3, q8s_3);
2601 const __m256i dot4 = _mm256_maddubs_epi16(q2_4, q8s_4);
2602
2603 const __m256i sc1 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0)));
2604 const __m256i sc2 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1)));
2605 const __m256i sc3 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2)));
2606 const __m256i sc4 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3)));
2607
2608 sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot1, sc1));
2609 sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot2, sc2));
2610 sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot3, sc3));
2611 sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot4, sc4));
2612 }
2613
2614 accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
2615
2616 }
2617
2618 *s = 0.125f * hsum_float_8(accumf);
2619
2620#elif defined(__AVX__)
2621 const __m128i mone = _mm_set1_epi8(1);
2622 static const char block_sign_shuffle_mask_1[32] = {
2623 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
2624 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06,
2625 };
2626 static const char block_sign_shuffle_mask_2[32] = {
2627 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a,
2628 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e,
2629 };
2630 static const uint8_t bit_selector_mask_bytes[32] = {
2631 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
2632 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
2633 };
2634
2635 const __m128i bit_selector_mask_0 = _mm_loadu_si128((const __m128i*)bit_selector_mask_bytes);
2636 const __m128i bit_selector_mask_1 = _mm_loadu_si128((const __m128i*)bit_selector_mask_bytes + 1);
2637 const __m128i block_sign_shuffle_1_0 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_1);
2638 const __m128i block_sign_shuffle_1_1 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_1 + 1);
2639 const __m128i block_sign_shuffle_2_0 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_2);
2640 const __m128i block_sign_shuffle_2_1 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_2 + 1);
2641
2642 static const uint8_t k_bit_helper[32] = {
2643 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
2644 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
2645 };
2646 const __m128i bit_helper_0 = _mm_loadu_si128((const __m128i*)k_bit_helper);
2647 const __m128i bit_helper_1 = _mm_loadu_si128((const __m128i*)k_bit_helper + 1);
2648 const __m128i m511 = _mm_set1_epi16(511);
2649 const __m128i m4 = _mm_set1_epi8(0xf);
2650 const __m128i m1 = _mm_set1_epi8(1);
2651
2652 uint64_t aux64;
2653
2654 // somewhat hacky, but gives a significant boost in performance
2655 __m256i aux_gindex;
2656 const uint16_t * gindex = (const uint16_t *)&aux_gindex;
2657
2658 __m256 accumf = _mm256_setzero_ps();
2659 for (int i = 0; i < nb; ++i) {
2660 const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
2661 const uint16_t * GGML_RESTRICT q2 = x[i].qs;
2662 const int8_t * GGML_RESTRICT q8 = y[i].qs;
2663
2664 memcpy(&aux64, x[i].scales, 8);
2665 __m128i stmp = _mm_set1_epi64x(aux64);
2666 stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4));
2667 const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1);
2668
2669 __m128i sumi1_0 = _mm_setzero_si128();
2670 __m128i sumi1_1 = _mm_setzero_si128();
2671 __m128i sumi2_0 = _mm_setzero_si128();
2672 __m128i sumi2_1 = _mm_setzero_si128();
2673 for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) {
2674
2675 const __m128i q2_data_0 = _mm_loadu_si128((const __m128i*)q2);
2676 const __m128i q2_data_1 = _mm_loadu_si128((const __m128i*)q2 + 1); q2 += 16;
2677 aux_gindex = MM256_SET_M128I(_mm_and_si128(q2_data_1, m511), _mm_and_si128(q2_data_0, m511));
2678
2679 const __m128i partial_sign_bits_0 = _mm_srli_epi16(q2_data_0, 9);
2680 const __m128i partial_sign_bits_1 = _mm_srli_epi16(q2_data_1, 9);
2681 const __m128i partial_sign_bits_upper_0 = _mm_srli_epi16(q2_data_0, 13);
2682 const __m128i partial_sign_bits_upper_1 = _mm_srli_epi16(q2_data_1, 13);
2683 const __m128i partial_sign_bits_for_counting_0 = _mm_xor_si128(partial_sign_bits_0, partial_sign_bits_upper_0);
2684 const __m128i partial_sign_bits_for_counting_1 = _mm_xor_si128(partial_sign_bits_1, partial_sign_bits_upper_1);
2685
2686 const __m128i odd_bits_0 = _mm_shuffle_epi8(bit_helper_0, partial_sign_bits_for_counting_0);
2687 const __m128i odd_bits_1 = _mm_shuffle_epi8(bit_helper_1, partial_sign_bits_for_counting_1);
2688 const __m128i full_sign_bits_0 = _mm_or_si128(partial_sign_bits_0, odd_bits_0);
2689 const __m128i full_sign_bits_1 = _mm_or_si128(partial_sign_bits_1, odd_bits_1);
2690
2691 const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2692 const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2693 const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2694 const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2695 const __m128i q8_3_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2696 const __m128i q8_3_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2697 const __m128i q8_4_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2698 const __m128i q8_4_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2699
2700 const __m128i q2_1_0 = _mm_set_epi64x(iq2xs_grid[gindex[1]], iq2xs_grid[gindex[0]]);
2701 const __m128i q2_1_1 = _mm_set_epi64x(iq2xs_grid[gindex[3]], iq2xs_grid[gindex[2]]);
2702 const __m128i q2_2_0 = _mm_set_epi64x(iq2xs_grid[gindex[5]], iq2xs_grid[gindex[4]]);
2703 const __m128i q2_2_1 = _mm_set_epi64x(iq2xs_grid[gindex[7]], iq2xs_grid[gindex[6]]);
2704 const __m128i q2_3_0 = _mm_set_epi64x(iq2xs_grid[gindex[9]], iq2xs_grid[gindex[8]]);
2705 const __m128i q2_3_1 = _mm_set_epi64x(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]]);
2706 const __m128i q2_4_0 = _mm_set_epi64x(iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]);
2707 const __m128i q2_4_1 = _mm_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]]);
2708
2709 // AVX2 full_signs_1 is full_sign_bits_0 here
2710 // AVX2 full_signs_2 is full_sign_bits_1 here
2711 __m128i signs_0, signs_1;
2712 signs_0 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_1_0);
2713 signs_1 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_1_1);
2714 signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
2715 signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
2716 const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, _mm_or_si128(signs_0, mone));
2717 const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, _mm_or_si128(signs_1, mone));
2718
2719 signs_0 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_2_0);
2720 signs_1 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_2_1);
2721 signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
2722 signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
2723 const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, _mm_or_si128(signs_0, mone));
2724 const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, _mm_or_si128(signs_1, mone));
2725
2726 signs_0 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_1_0);
2727 signs_1 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_1_1);
2728 signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
2729 signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
2730 const __m128i q8s_3_0 = _mm_sign_epi8(q8_3_0, _mm_or_si128(signs_0, mone));
2731 const __m128i q8s_3_1 = _mm_sign_epi8(q8_3_1, _mm_or_si128(signs_1, mone));
2732
2733 signs_0 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_2_0);
2734 signs_1 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_2_1);
2735 signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
2736 signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
2737 const __m128i q8s_4_0 = _mm_sign_epi8(q8_4_0, _mm_or_si128(signs_0, mone));
2738 const __m128i q8s_4_1 = _mm_sign_epi8(q8_4_1, _mm_or_si128(signs_1, mone));
2739
2740 const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
2741 const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
2742 const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
2743 const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
2744 const __m128i dot3_0 = _mm_maddubs_epi16(q2_3_0, q8s_3_0);
2745 const __m128i dot3_1 = _mm_maddubs_epi16(q2_3_1, q8s_3_1);
2746 const __m128i dot4_0 = _mm_maddubs_epi16(q2_4_0, q8s_4_0);
2747 const __m128i dot4_1 = _mm_maddubs_epi16(q2_4_1, q8s_4_1);
2748
2749 __m128i sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0));
2750 const __m128i sc1_0 = _mm_cvtepi8_epi16(sc_tmp);
2751 const __m128i sc1_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
2752 sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1));
2753 const __m128i sc2_0 = _mm_cvtepi8_epi16(sc_tmp);
2754 const __m128i sc2_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
2755 sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2));
2756 const __m128i sc3_0 = _mm_cvtepi8_epi16(sc_tmp);
2757 const __m128i sc3_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
2758 sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3));
2759 const __m128i sc4_0 = _mm_cvtepi8_epi16(sc_tmp);
2760 const __m128i sc4_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
2761
2762 sumi1_0 = _mm_add_epi32(sumi1_0, _mm_madd_epi16(dot1_0, sc1_0));
2763 sumi1_1 = _mm_add_epi32(sumi1_1, _mm_madd_epi16(dot1_1, sc1_1));
2764 sumi2_0 = _mm_add_epi32(sumi2_0, _mm_madd_epi16(dot2_0, sc2_0));
2765 sumi2_1 = _mm_add_epi32(sumi2_1, _mm_madd_epi16(dot2_1, sc2_1));
2766 sumi1_0 = _mm_add_epi32(sumi1_0, _mm_madd_epi16(dot3_0, sc3_0));
2767 sumi1_1 = _mm_add_epi32(sumi1_1, _mm_madd_epi16(dot3_1, sc3_1));
2768 sumi2_0 = _mm_add_epi32(sumi2_0, _mm_madd_epi16(dot4_0, sc4_0));
2769 sumi2_1 = _mm_add_epi32(sumi2_1, _mm_madd_epi16(dot4_1, sc4_1));
2770 }
2771
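// Combine the four 128-bit partial sums into one 256-bit vector and fold them into
// the float accumulator scaled by this block's d, mirroring the AVX2 path above.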
2772 accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
2773
2774 }
2775
2776 *s = 0.125f * hsum_float_8(accumf);
2777
2778#else
2779 UNUSED(x);
2780 UNUSED(y);
2781 UNUSED(nb);
2782 ggml_vec_dot_iq2_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
2783#endif
2784}
2785
2786void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
2787 assert(n % QK_K == 0);
2788 assert(nrc == 1);
2789 UNUSED(nrc);
2790 UNUSED(bx);
2791 UNUSED(by);
2792 UNUSED(bs);
2793
2794 const block_iq2_s * GGML_RESTRICT x = vx;
2795 const block_q8_K * GGML_RESTRICT y = vy;
2796
2797 const int nb = n / QK_K;
2798
2799#if defined(__AVX2__)
2800
2801 static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
2802 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
2803 };
2804
2805 static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
2806 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
2807 };
2808
2809 const __m128i m4 = _mm_set1_epi8(0xf);
2810 const __m128i m1 = _mm_set1_epi8(1);
2811
2812 const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1);
2813 const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2);
2814
2815 uint64_t aux64;
2816
2817 __m256 accumf = _mm256_setzero_ps();
2818 for (int i = 0; i < nb; ++i) {
2819 const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
2820 const uint8_t * GGML_RESTRICT qs = x[i].qs;
2821 const uint8_t * GGML_RESTRICT qh = x[i].qh;
2822 const uint16_t * GGML_RESTRICT signs = (const uint16_t *)(x[i].qs + QK_K/8);
2823 const int8_t * GGML_RESTRICT q8 = y[i].qs;
2824
2825 memcpy(&aux64, x[i].scales, 8);
2826 const __m128i scales8 = _mm_add_epi8(_mm_slli_epi16(_mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), m4), 1), m1);
2827 const __m256i scales16 = _mm256_cvtepi8_epi16(scales8); // 0 2 4 6 8 10 12 14 1 3 5 7 9 11 13 15
2828
2829 __m256i sumi1 = _mm256_setzero_si256();
2830 __m256i sumi2 = _mm256_setzero_si256();
2831 for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
2832 const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
2833 const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
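// For iq2_s each grid index is 10 bits: the low 8 bits come from qs and the top 2
// come from the qh byte of the 32-value block, shifted into bits 8..9 below.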
2834 const __m256i q2_1 = _mm256_set_epi64x(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)],
2835 iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)],
2836 iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)],
2837 iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]);
2838 const __m256i q2_2 = _mm256_set_epi64x(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)],
2839 iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)],
2840 iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)],
2841 iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]);
2842 qs += 8;
2843
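// The per-value sign bits are expanded to a 0x00/0xFF byte mask (broadcast, select
// one bit per byte via mask1/mask2, compare) and applied as (q8 ^ mask) - mask,
// i.e. conditional two's-complement negation.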
2844 __m256i aux256 = _mm256_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16));
2845 aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
2846 const __m256i s2_1 = _mm256_cmpeq_epi8(aux256, mask2);
2847 const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(s2_1, q8_1), s2_1);
2848
2849 aux256 = _mm256_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16));
2850 aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
2851 const __m256i s2_2 = _mm256_cmpeq_epi8(aux256, mask2);
2852 const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(s2_2, q8_2), s2_2);
2853
2854 signs += 4;
2855
2856 const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); // blocks 2*ib32+0, 2*ib32+1
2857 const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); // blocks 2*ib32+2, 2*ib32+3
2858
2859 const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_shuffle_epi8(scales16, get_scale_shuffle_k4(ib32+0)));
2860 const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_shuffle_epi8(scales16, get_scale_shuffle_k4(ib32+1)));
2861 sumi1 = _mm256_add_epi32(sumi1, p1);
2862 sumi2 = _mm256_add_epi32(sumi2, p2);
2863 }
2864
2865 accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
2866
2867 }
2868
2869 *s = 0.125f * hsum_float_8(accumf);
2870
2871#elif defined(__AVX__)
2872 static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
2873 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
2874 };
2875
2876 static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
2877 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
2878 };
2879
2880 const __m128i m4 = _mm_set1_epi8(0xf);
2881 const __m128i m1 = _mm_set1_epi8(1);
2882
2883 const __m128i mask1_0 = _mm_loadu_si128((const __m128i*)k_mask1);
2884 const __m128i mask1_1 = _mm_loadu_si128((const __m128i*)k_mask1 + 1);
2885 const __m128i mask2_0 = _mm_loadu_si128((const __m128i*)k_mask2);
2886 const __m128i mask2_1 = _mm_loadu_si128((const __m128i*)k_mask2 + 1);
2887
2888 uint64_t aux64;
2889
2890 __m256 accumf = _mm256_setzero_ps();
2891 for (int i = 0; i < nb; ++i) {
2892 const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
2893 const uint8_t * GGML_RESTRICT qs = x[i].qs;
2894 const uint8_t * GGML_RESTRICT qh = x[i].qh;
2895 const uint16_t * GGML_RESTRICT signs = (const uint16_t *)(x[i].qs + QK_K/8);
2896 const int8_t * GGML_RESTRICT q8 = y[i].qs;
2897
2898 memcpy(&aux64, x[i].scales, 8);
2899 const __m128i scales8 = _mm_add_epi8(_mm_slli_epi16(_mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), m4), 1), m1);
2900 const __m128i scales16_0 = _mm_cvtepi8_epi16(scales8);
2901 const __m128i scales16_1 = _mm_cvtepi8_epi16(_mm_srli_si128(scales8, 8));
2902
2903 __m128i sumi1_0 = _mm_setzero_si128();
2904 __m128i sumi1_1 = _mm_setzero_si128();
2905 __m128i sumi2_0 = _mm_setzero_si128();
2906 __m128i sumi2_1 = _mm_setzero_si128();
2907 for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
2908 const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2909 const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2910 const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2911 const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
2912 const __m128i q2_1_0 = _mm_set_epi64x(iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)],
2913 iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]);
2914 const __m128i q2_1_1 = _mm_set_epi64x(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)],
2915 iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)]);
2916 const __m128i q2_2_0 = _mm_set_epi64x(iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)],
2917 iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]);
2918 const __m128i q2_2_1 = _mm_set_epi64x(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)],
2919 iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)]);
2920 qs += 8;
2921
2922 __m128i aux128_0 = _mm_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16));
2923 __m128i aux128_1 = aux128_0;
2924 aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
2925 aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
2926 const __m128i s2_1_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
2927 const __m128i s2_1_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
2928 const __m128i q8s_1_0 = _mm_sub_epi8(_mm_xor_si128(s2_1_0, q8_1_0), s2_1_0);
2929 const __m128i q8s_1_1 = _mm_sub_epi8(_mm_xor_si128(s2_1_1, q8_1_1), s2_1_1);
2930
2931 aux128_0 = _mm_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16));
2932 aux128_1 = aux128_0;
2933 aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
2934 aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
2935 const __m128i s2_2_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
2936 const __m128i s2_2_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
2937 const __m128i q8s_2_0 = _mm_sub_epi8(_mm_xor_si128(s2_2_0, q8_2_0), s2_2_0);
2938 const __m128i q8s_2_1 = _mm_sub_epi8(_mm_xor_si128(s2_2_1, q8_2_1), s2_2_1);
2939
2940 signs += 4;
2941
2942 const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
2943 const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
2944 const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
2945 const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
2946
2947 const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_shuffle_epi8(scales16_0, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+0), 0)));
2948 const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_shuffle_epi8(scales16_1, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+0), 1)));
2949 const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_shuffle_epi8(scales16_0, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+1), 0)));
2950 const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_shuffle_epi8(scales16_1, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+1), 1)));
2951 sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
2952 sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
2953 sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
2954 sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
2955 }
2956
2957 accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
2958
2959 }
2960
2961 *s = 0.125f * hsum_float_8(accumf);
2962
2963#else
2964 UNUSED(x);
2965 UNUSED(y);
2966 UNUSED(nb);
2967 ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
2968#endif
2969}
2970
2971void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
2972 assert(n % QK_K == 0);
2973 assert(nrc == 1);
2974 UNUSED(nrc);
2975 UNUSED(bx);
2976 UNUSED(by);
2977 UNUSED(bs);
2978
2979 const block_iq3_xxs * GGML_RESTRICT x = vx;
2980 const block_q8_K * GGML_RESTRICT y = vy;
2981
2982 const int nb = n / QK_K;
2983
2984#if defined(__AVX2__)
2985
2986 const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
2987
2988 uint32_t aux32[2];
2989
2990 __m256 accumf = _mm256_setzero_ps();
2991 for (int i = 0; i < nb; ++i) {
2992 const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
2993 const uint8_t * GGML_RESTRICT q3 = x[i].qs;
2994 const uint8_t * GGML_RESTRICT gas = x[i].qs + QK_K/4;
2995 const int8_t * GGML_RESTRICT q8 = y[i].qs;
2996 __m256i sumi1 = _mm256_setzero_si256();
2997 __m256i sumi2 = _mm256_setzero_si256();
2998 for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
2999 const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
3000 const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
3001 const __m256i q2_1 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],
3002 iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
3003 q3 += 8;
3004 const __m256i q2_2 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],
3005 iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
3006 q3 += 8;
3007 memcpy(aux32, gas, 8); gas += 8;
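// Each 32-bit word in gas describes 32 values: four 7-bit indices into the
// precomputed keven_signs table (8 sign bytes per entry) in bits 0..27, and the
// 4-bit block scale in bits 28..31.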
3008 const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127],
3009 signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]);
3010 const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
3011 signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
3012 const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
3013 const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
3014 const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
3015 const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
3016 const uint16_t ls1 = aux32[0] >> 28;
3017 const uint16_t ls2 = aux32[1] >> 28;
3018 const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));
3019 const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));
3020 sumi1 = _mm256_add_epi32(sumi1, p1);
3021 sumi2 = _mm256_add_epi32(sumi2, p2);
3022 }
3023
3024 accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
3025
3026 }
3027
3028 *s = 0.25f * hsum_float_8(accumf);
3029
3030#elif defined(__AVX__)
3031 const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
3032
3033 uint32_t aux32[2];
3034
3035 __m256 accumf = _mm256_setzero_ps();
3036 for (int i = 0; i < nb; ++i) {
3037 const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
3038 const uint8_t * GGML_RESTRICT q3 = x[i].qs;
3039 const uint8_t * GGML_RESTRICT gas = x[i].qs + QK_K/4;
3040 const int8_t * GGML_RESTRICT q8 = y[i].qs;
3041 __m128i sumi1_0 = _mm_setzero_si128();
3042 __m128i sumi1_1 = _mm_setzero_si128();
3043 __m128i sumi2_0 = _mm_setzero_si128();
3044 __m128i sumi2_1 = _mm_setzero_si128();
3045 for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
3046 const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3047 const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3048 const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3049 const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3050 const __m128i q2_1_0 = _mm_set_epi32(iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
3051 const __m128i q2_1_1 = _mm_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]]);
3052 q3 += 8;
3053 const __m128i q2_2_0 = _mm_set_epi32(iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
3054 const __m128i q2_2_1 = _mm_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]]);
3055 q3 += 8;
3056 memcpy(aux32, gas, 8); gas += 8;
3057 const __m128i s2_1_0 = _mm_set_epi64x(signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]);
3058 const __m128i s2_1_1 = _mm_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127]);
3059 const __m128i s2_2_0 = _mm_set_epi64x(signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
3060 const __m128i s2_2_1 = _mm_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127]);
3061 const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, s2_1_0);
3062 const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, s2_1_1);
3063 const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, s2_2_0);
3064 const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, s2_2_1);
3065 const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
3066 const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
3067 const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
3068 const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
3069 const uint16_t ls1 = aux32[0] >> 28;
3070 const uint16_t ls2 = aux32[1] >> 28;
3071 const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1));
3072 const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1));
3073 const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1));
3074 const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1));
3075 sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
3076 sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
3077 sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
3078 sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
3079 }
3080
3081 accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
3082
3083 }
3084
3085 *s = 0.25f * hsum_float_8(accumf);
3086
3087#else
3088 UNUSED(x);
3089 UNUSED(y);
3090 UNUSED(nb);
3091 ggml_vec_dot_iq3_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
3092#endif
3093}
3094
3095void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
3096 assert(n % QK_K == 0);
3097 assert(nrc == 1);
3098 UNUSED(nrc);
3099 UNUSED(bx);
3100 UNUSED(by);
3101 UNUSED(bs);
3102
3103 const block_iq3_s * GGML_RESTRICT x = vx;
3104 const block_q8_K * GGML_RESTRICT y = vy;
3105
3106 const int nb = n / QK_K;
3107
3108#if defined(__AVX2__)
3109
3110 static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
3111 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
3112 };
3113
3114 static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
3115 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
3116 };
3117
3118 const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1);
3119 const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2);
3120
3121 const __m256i idx_shift = _mm256_set_epi32(1, 2, 3, 4, 5, 6, 7, 8);
3122 const __m256i idx_mask = _mm256_set1_epi32(256);
3123
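// The union lets the 9-bit grid indices be assembled with vector instructions and
// then read back as scalar uint32_t values for the table lookups below.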
3124 typedef union {
3125 __m256i vec[2];
3126 uint32_t index[16];
3127 } index_t;
3128
3129 index_t idx;
3130
3131 __m256 accumf = _mm256_setzero_ps();
3132 for (int i = 0; i < nb; ++i) {
3133 const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
3134 const uint8_t * GGML_RESTRICT qs = x[i].qs;
3135 const uint8_t * GGML_RESTRICT qh = x[i].qh;
3136 const uint16_t * GGML_RESTRICT signs = (const uint16_t *)x[i].signs;
3137 const int8_t * GGML_RESTRICT q8 = y[i].qs;
3138 __m256i sumi1 = _mm256_setzero_si256();
3139 __m256i sumi2 = _mm256_setzero_si256();
3140 for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
3141 const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
3142 const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
3143 const __m256i idx_l = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)qs)); qs += 16;
3144 idx.vec[0] = _mm256_set1_epi32(qh[ib32+0]);
3145 idx.vec[1] = _mm256_set1_epi32(qh[ib32+1]);
3146 idx.vec[0] = _mm256_and_si256(_mm256_sllv_epi32(idx.vec[0], idx_shift), idx_mask);
3147 idx.vec[1] = _mm256_and_si256(_mm256_sllv_epi32(idx.vec[1], idx_shift), idx_mask);
3148 idx.vec[0] = _mm256_or_si256(idx.vec[0], _mm256_cvtepi16_epi32(_mm256_castsi256_si128(idx_l)));
3149 idx.vec[1] = _mm256_or_si256(idx.vec[1], _mm256_cvtepi16_epi32(_mm256_extractf128_si256(idx_l, 1)));
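// Each 9-bit index is 8 low bits from qs plus 1 bit from the block's qh byte:
// broadcasting qh and shifting it by a per-lane amount (idx_shift) places bit j of
// qh at bit position 8 for lane j, which idx_mask then isolates.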
3150
3151 // At least on my CPU (Ryzen 7950X), using _mm256_i32gather_epi32 is slower than _mm256_set_epi32. Strange.
3152 //const __m256i q2_1 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[0], 4);
3153 //const __m256i q2_2 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[1], 4);
3154 const __m256i q2_1 = _mm256_set_epi32(
3155 iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]],
3156 iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]]
3157 );
3158 const __m256i q2_2 = _mm256_set_epi32(
3159 iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]],
3160 iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[ 9]], iq3s_grid[idx.index[ 8]]
3161 );
3162
3163 __m256i aux256 = _mm256_set1_epi32(signs[0] | (signs[1] << 16));
3164 aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
3165 const __m256i s2_1 = _mm256_cmpeq_epi8(aux256, mask2);
3166 const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(s2_1, q8_1), s2_1);
3167
3168 aux256 = _mm256_set1_epi32(signs[2] | (signs[3] << 16));
3169 aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
3170 const __m256i s2_2 = _mm256_cmpeq_epi8(aux256, mask2);
3171 const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(s2_2, q8_2), s2_2);
3172
3173 signs += 4;
3174
3175 const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
3176 const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
3177 const uint16_t ls1 = x[i].scales[ib32/2] & 0xf;
3178 const uint16_t ls2 = x[i].scales[ib32/2] >> 4;
3179 const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));
3180 const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));
3181 sumi1 = _mm256_add_epi32(sumi1, p1);
3182 sumi2 = _mm256_add_epi32(sumi2, p2);
3183 }
3184
3185 accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
3186
3187 }
3188
3189 *s = hsum_float_8(accumf);
3190
3191#elif defined(__AVX__)
3192 static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
3193 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
3194 };
3195
3196 static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
3197 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
3198 };
3199
3200 const __m128i mask1_0 = _mm_loadu_si128((const __m128i*)k_mask1);
3201 const __m128i mask1_1 = _mm_loadu_si128((const __m128i*)k_mask1 + 1);
3202 const __m128i mask2_0 = _mm_loadu_si128((const __m128i*)k_mask2);
3203 const __m128i mask2_1 = _mm_loadu_si128((const __m128i*)k_mask2 + 1);
3204
3205 const __m128i idx_mul_0 = _mm_set_epi32(32, 64, 128, 256);
3206 const __m128i idx_mul_1 = _mm_set_epi32(2, 4, 8, 16);
3207 const __m128i idx_mask = _mm_set1_epi32(256);
3208
3209 typedef union {
3210 __m128i vec[4];
3211 uint32_t index[16];
3212 } index_t;
3213
3214 index_t idx;
3215
3216 __m256 accumf = _mm256_setzero_ps();
3217 for (int i = 0; i < nb; ++i) {
3218 const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
3219 const uint8_t * GGML_RESTRICT qs = x[i].qs;
3220 const uint8_t * GGML_RESTRICT qh = x[i].qh;
3221 const uint16_t * GGML_RESTRICT signs = (const uint16_t *)x[i].signs;
3222 const int8_t * GGML_RESTRICT q8 = y[i].qs;
3223 __m128i sumi1_0 = _mm_setzero_si128();
3224 __m128i sumi1_1 = _mm_setzero_si128();
3225 __m128i sumi2_0 = _mm_setzero_si128();
3226 __m128i sumi2_1 = _mm_setzero_si128();
3227 for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
3228 const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3229 const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3230 const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3231 const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3232 const __m128i qs_tmp = _mm_loadu_si128((const __m128i *)qs);
3233 const __m128i idx_l_0 = _mm_cvtepu8_epi16(qs_tmp);
3234 const __m128i idx_l_1 = _mm_cvtepu8_epi16(_mm_srli_si128(qs_tmp, 8)); qs += 16;
3235 idx.vec[0] = _mm_set1_epi32(qh[ib32+0]);
3236 idx.vec[1] = idx.vec[0];
3237 idx.vec[2] = _mm_set1_epi32(qh[ib32+1]);
3238 idx.vec[3] = idx.vec[2];
3239
3240 idx.vec[0] = _mm_and_si128(_mm_mullo_epi32(idx.vec[0], idx_mul_0), idx_mask);
3241 idx.vec[1] = _mm_and_si128(_mm_mullo_epi32(idx.vec[1], idx_mul_1), idx_mask);
3242 idx.vec[2] = _mm_and_si128(_mm_mullo_epi32(idx.vec[2], idx_mul_0), idx_mask);
3243 idx.vec[3] = _mm_and_si128(_mm_mullo_epi32(idx.vec[3], idx_mul_1), idx_mask);
3244
3245 idx.vec[0] = _mm_or_si128(idx.vec[0], _mm_cvtepi16_epi32(idx_l_0));
3246 idx.vec[1] = _mm_or_si128(idx.vec[1], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_0, 8)));
3247 idx.vec[2] = _mm_or_si128(idx.vec[2], _mm_cvtepi16_epi32(idx_l_1));
3248 idx.vec[3] = _mm_or_si128(idx.vec[3], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_1, 8)));
3249
3250 const __m128i q2_1_0 = _mm_set_epi32(iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]]);
3251 const __m128i q2_1_1 = _mm_set_epi32(iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]]);
3252 const __m128i q2_2_0 = _mm_set_epi32(iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[9]], iq3s_grid[idx.index[8]]);
3253 const __m128i q2_2_1 = _mm_set_epi32(iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]]);
3254
3255 __m128i aux128_0 = _mm_set1_epi32(signs[0] | (signs[1] << 16));
3256 __m128i aux128_1 = aux128_0;
3257 aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
3258 aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
3259 const __m128i s2_1_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
3260 const __m128i s2_1_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
3261 const __m128i q8s_1_0 = _mm_sub_epi8(_mm_xor_si128(s2_1_0, q8_1_0), s2_1_0);
3262 const __m128i q8s_1_1 = _mm_sub_epi8(_mm_xor_si128(s2_1_1, q8_1_1), s2_1_1);
3263
3264 aux128_0 = _mm_set1_epi32(signs[2] | (signs[3] << 16));
3265 aux128_1 = aux128_0;
3266 aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
3267 aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
3268 const __m128i s2_2_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
3269 const __m128i s2_2_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
3270 const __m128i q8s_2_0 = _mm_sub_epi8(_mm_xor_si128(s2_2_0, q8_2_0), s2_2_0);
3271 const __m128i q8s_2_1 = _mm_sub_epi8(_mm_xor_si128(s2_2_1, q8_2_1), s2_2_1);
3272
3273 signs += 4;
3274
3275 const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
3276 const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
3277 const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
3278 const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
3279 const uint16_t ls1 = x[i].scales[ib32/2] & 0xf;
3280 const uint16_t ls2 = x[i].scales[ib32/2] >> 4;
3281 const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1));
3282 const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1));
3283 const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1));
3284 const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1));
3285 sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
3286 sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
3287 sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
3288 sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
3289 }
3290
3291 accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
3292
3293 }
3294
3295 *s = hsum_float_8(accumf);
3296
3297#else
3298 UNUSED(x);
3299 UNUSED(y);
3300 UNUSED(nb);
3301 ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
3302#endif
3303}
3304
3305void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
3306 assert(n % QK_K == 0);
3307 assert(nrc == 1);
3308 UNUSED(nrc);
3309 UNUSED(bx);
3310 UNUSED(by);
3311 UNUSED(bs);
3312
3313 const block_iq1_s * GGML_RESTRICT x = vx;
3314 const block_q8_K * GGML_RESTRICT y = vy;
3315
3316 const int nb = n / QK_K;
3317
3318#if defined __AVX2__
3319
3320 __m256 accum = _mm256_setzero_ps();
3321 float accum1 = 0;
3322 for (int i = 0; i < nb; ++i) {
3323
3324 const int8_t * q8 = y[i].qs;
3325 const uint8_t * qs = x[i].qs;
3326 const uint16_t * qh = x[i].qh;
3327
3328 __m256i sumi = _mm256_setzero_si256();
3329 int sumi1 = 0;
3330 for (int ib = 0; ib < QK_K/32; ib += 2) {
3331#ifdef __BMI2__
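// PDEP spreads the four qs bytes into the low byte of four 16-bit lanes and the
// four 3-bit high parts from qh into bits 8..10 of the same lanes, so each 16-bit
// lane holds one complete 11-bit iq1s_grid index.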
3332 const uint64_t packed_idx1 = _pdep_u64(*(const uint32_t *)qs, 0x00ff00ff00ff00ffULL) | _pdep_u64(qh[ib], 0x700070007000700ULL);
3333 const uint64_t packed_idx2 = _pdep_u64(*(const uint32_t *)(qs + 4), 0x00ff00ff00ff00ffULL) | _pdep_u64(qh[ib + 1], 0x700070007000700ULL);
3334 const uint16_t *idx1 = (const uint16_t *)(&packed_idx1);
3335 const uint16_t *idx2 = (const uint16_t *)(&packed_idx2);
3336 const __m256i q1b_1 = _mm256_set_epi64x(iq1s_grid[idx1[3]], iq1s_grid[idx1[2]], iq1s_grid[idx1[1]], iq1s_grid[idx1[0]]);
3337 const __m256i q1b_2 = _mm256_set_epi64x(iq1s_grid[idx2[3]], iq1s_grid[idx2[2]], iq1s_grid[idx2[1]], iq1s_grid[idx2[0]]);
3338#else
3339 const __m256i q1b_1 = _mm256_set_epi64x(iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)],
3340 iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)]);
3341 const __m256i q1b_2 = _mm256_set_epi64x(iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)],
3342 iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)]);
3343#endif
3344 qs += 8;
3345 const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
3346 const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
3347
3348 const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1);
3349 const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2);
3350 const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1;
3351 const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1;
3352 const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(ls1));
3353 const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(ls2));
3354
3355 sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p1, p2));
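// The per-value IQ1S_DELTA offset is handled separately: delta * sum(q8) over each
// 32-value block is obtained from the precomputed q8 block sums (bsums, stored per
// 16 values) with the block's sign bit and scale applied, accumulated in sumi1 and
// added as IQ1S_DELTA * accum1 at the very end.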
3356 sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1
3357 + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2;
3358 }
3359
3360 const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
3361 accum = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi), accum);
3362 accum1 += d * sumi1;
3363
3364 }
3365
3366 *s = hsum_float_8(accum) + IQ1S_DELTA * accum1;
3367
3368#elif defined __AVX__
3369 __m256 accum = _mm256_setzero_ps();
3370 float accum1 = 0;
3371 for (int i = 0; i < nb; ++i) {
3372
3373 const int8_t * q8 = y[i].qs;
3374 const uint8_t * qs = x[i].qs;
3375 const uint16_t * qh = x[i].qh;
3376
3377 __m128i sumi1_0 = _mm_setzero_si128();
3378 __m128i sumi1_1 = _mm_setzero_si128();
3379 int sumi1 = 0;
3380 for (int ib = 0; ib < QK_K/32; ib += 2) {
3381 const __m128i q1b_1_0 = _mm_set_epi64x(iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)]);
3382 const __m128i q1b_1_1 = _mm_set_epi64x(iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)]);
3383 const __m128i q1b_2_0 = _mm_set_epi64x(iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)]);
3384 const __m128i q1b_2_1 = _mm_set_epi64x(iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)]);
3385 qs += 8;
3386 const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3387 const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3388 const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3389 const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3390
3391 const __m128i dot1_0 = mul_add_epi8_sse(q1b_1_0, q8b_1_0);
3392 const __m128i dot1_1 = mul_add_epi8_sse(q1b_1_1, q8b_1_1);
3393 const __m128i dot2_0 = mul_add_epi8_sse(q1b_2_0, q8b_2_0);
3394 const __m128i dot2_1 = mul_add_epi8_sse(q1b_2_1, q8b_2_1);
3395 const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1;
3396 const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1;
3397 const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(ls1));
3398 const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(ls1));
3399 const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(ls2));
3400 const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(ls2));
3401
3402 sumi1_0 = _mm_add_epi32(sumi1_0, _mm_add_epi32(p1_0, p2_0));
3403 sumi1_1 = _mm_add_epi32(sumi1_1, _mm_add_epi32(p1_1, p2_1));
3404 sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1
3405 + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2;
3406 }
3407
3408 const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
3409 accum = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(sumi1_1, sumi1_0))), accum);
3410 accum1 += d * sumi1;
3411
3412 }
3413
3414 *s = hsum_float_8(accum) + IQ1S_DELTA * accum1;
3415
3416#else
3417 UNUSED(x);
3418 UNUSED(y);
3419 UNUSED(nb);
3420 ggml_vec_dot_iq1_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
3421#endif
3422}
3423
3424void ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
3425 assert(n % QK_K == 0);
3426 assert(nrc == 1);
3427 UNUSED(nrc);
3428 UNUSED(bx);
3429 UNUSED(by);
3430 UNUSED(bs);
3431
3432 const block_iq1_m * GGML_RESTRICT x = vx;
3433 const block_q8_K * GGML_RESTRICT y = vy;
3434
3435 const int nb = n / QK_K;
3436
3437 iq1m_scale_t scale;
3438
3439#if defined __AVX2__
3440
3441 const __m256i mask = _mm256_set1_epi16(0x7);
3442 const __m256i mone = _mm256_set1_epi16(1);
3443 const __m256i mone8 = _mm256_set1_epi8(1);
3444 const __m256i mtwo8 = _mm256_set1_epi8(2);
3445 // VPSHUFB cannot cross 128-bit lanes so odd shifts go to upper half.
3446 const __m256i scales_shift = _mm256_set_epi64x(9, 3, 6, 0);
3447
3448 __m256 accum1 = _mm256_setzero_ps();
3449 __m256 accum2 = _mm256_setzero_ps();
3450 for (int i = 0; i < nb; ++i) {
3451
3452 const int8_t * q8 = y[i].qs;
3453 const uint8_t * qs = x[i].qs;
3454 const uint8_t * qh = x[i].qh;
3455 const uint16_t * sc = (const uint16_t *)x[i].scales;
3456
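// The iq1_m super-block scale is an fp16 value split into four 4-bit pieces stored
// in the top nibble of each 16-bit scales word; reassemble it here.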
3457 scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
3458 // Extract 3-bit scales (16 values)
3459 __m256i scales = _mm256_set1_epi64x(*(const uint64_t*)sc);
3460 scales = _mm256_srlv_epi64(scales, scales_shift);
3461 scales = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scales, mask), 1), mone);
3462
3463 // Indices to repeat each scale 8 times.
3464 __m256i scales_idx1 = _mm256_set1_epi16(0x0100);
3465 __m256i scales_idx2 = _mm256_add_epi8(scales_idx1, _mm256_set1_epi8(8));
3466
3467 __m256i sumi1 = _mm256_setzero_si256();
3468 __m256i sumi2 = _mm256_setzero_si256();
3469 for (int ib = 0; ib < QK_K/32; ib += 2) {
3470#ifdef __BMI2__
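// Same PDEP trick as in iq1_s above, but here the 3 high index bits come from the
// qh nibbles (masked with 0x7777) and are deposited into bits 8..11 of each 16-bit
// lane (the masking keeps the top bit zero), giving 11-bit iq1s_grid indices.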
3471 const uint64_t packed_idx1 = _pdep_u64(*(const uint32_t *)qs, 0x00ff00ff00ff00ffULL)
3472 | _pdep_u64(*(const uint16_t*)(qh) & 0x7777, 0xf000f000f000f00ULL);
3473 const uint64_t packed_idx2 = _pdep_u64(*(const uint32_t *)(qs + 4), 0x00ff00ff00ff00ffULL)
3474 | _pdep_u64(*(const uint16_t*)(qh + 2) & 0x7777, 0xf000f000f000f00ULL);
3475 const uint16_t *idx1 = (const uint16_t *)(&packed_idx1);
3476 const uint16_t *idx2 = (const uint16_t *)(&packed_idx2);
3477 const __m256i q1b_1 = _mm256_set_epi64x(iq1s_grid[idx1[3]], iq1s_grid[idx1[2]], iq1s_grid[idx1[1]], iq1s_grid[idx1[0]]);
3478 const __m256i q1b_2 = _mm256_set_epi64x(iq1s_grid[idx2[3]], iq1s_grid[idx2[2]], iq1s_grid[idx2[1]], iq1s_grid[idx2[0]]);
3479
3480 // Convert signs to bytes 0x81 (negative) or 0x01 (positive)
3481 const uint64_t delta_sign = _pdep_u64(*(const uint32_t*)(qh) & 0x88888888, 0xf0f0f0f0f0f0f0f0ULL);
3482 const __m256i delta1 = _mm256_or_si256(mone8, _mm256_cvtepi8_epi64(_mm_set1_epi32(delta_sign)));
3483 const __m256i delta2 = _mm256_or_si256(mone8, _mm256_cvtepi8_epi64(_mm_set1_epi32(delta_sign >> 32)));
3484#else
3485 const __m256i q1b_1 = _mm256_set_epi64x(
3486 iq1s_grid[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)],
3487 iq1s_grid[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)]
3488 );
3489 const __m256i q1b_2 = _mm256_set_epi64x(
3490 iq1s_grid[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)], iq1s_grid[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)],
3491 iq1s_grid[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)], iq1s_grid[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)]
3492 );
3493
3494 const __m256i delta1 = _mm256_set_epi64x(qh[1] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
3495 qh[1] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101,
3496 qh[0] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
3497 qh[0] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
3498 const __m256i delta2 = _mm256_set_epi64x(qh[3] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
3499 qh[3] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101,
3500 qh[2] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
3501 qh[2] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
3502#endif
3503 const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
3504 const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
3505
3506 const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1);
3507 const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2);
3508 const __m256i dot3 = _mm256_maddubs_epi16(mone8, _mm256_sign_epi8(q8b_1, delta1));
3509 const __m256i dot4 = _mm256_maddubs_epi16(mone8, _mm256_sign_epi8(q8b_2, delta2));
3510
3511 __m256i scale1 = _mm256_shuffle_epi8(scales, scales_idx1);
3512 __m256i scale2 = _mm256_shuffle_epi8(scales, scales_idx2);
3513
3514 scales_idx1 = _mm256_add_epi8(scales_idx1, mtwo8);
3515 scales_idx2 = _mm256_add_epi8(scales_idx2, mtwo8);
3516
3517 const __m256i p1 = _mm256_madd_epi16(dot1, scale1);
3518 const __m256i p2 = _mm256_madd_epi16(dot2, scale2);
3519 const __m256i p3 = _mm256_madd_epi16(dot3, scale1);
3520 const __m256i p4 = _mm256_madd_epi16(dot4, scale2);
3521
3522 sumi1 = _mm256_add_epi32(sumi1, _mm256_add_epi32(p1, p2));
3523 sumi2 = _mm256_add_epi32(sumi2, _mm256_add_epi32(p3, p4));
3524
3525 qs += 8; qh += 4;
3526 }
3527
3528 const __m256 d = _mm256_set1_ps(y[i].d * GGML_CPU_FP16_TO_FP32(scale.f16));
3529
3530 accum1 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi1), accum1);
3531 accum2 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi2), accum2);
3532 }
3533
3534 *s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2);
3535
3536#elif defined __AVX__
3537 const __m128i mask = _mm_set1_epi16(0x7);
3538 const __m128i mone = _mm_set1_epi16(1);
3539
3540 __m256 accum1 = _mm256_setzero_ps();
3541 __m256 accum2 = _mm256_setzero_ps();
3542 for (int i = 0; i < nb; ++i) {
3543
3544 const int8_t * q8 = y[i].qs;
3545 const uint8_t * qs = x[i].qs;
3546 const uint8_t * qh = x[i].qh;
3547 const uint16_t * sc = (const uint16_t *)x[i].scales;
3548
3549 scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
3550
3551 __m128i sumi1_0 = _mm_setzero_si128();
3552 __m128i sumi1_1 = _mm_setzero_si128();
3553 __m128i sumi2_0 = _mm_setzero_si128();
3554 __m128i sumi2_1 = _mm_setzero_si128();
3555 for (int ib = 0; ib < QK_K/32; ib += 2) {
3556 const __m128i q1b_1_0 = _mm_set_epi64x(
3557 iq1s_grid[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)]);
3558 const __m128i q1b_1_1 = _mm_set_epi64x(
3559 iq1s_grid[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)]);
3560 const __m128i q1b_2_0 = _mm_set_epi64x(
3561 iq1s_grid[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)], iq1s_grid[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)]);
3562 const __m128i q1b_2_1 = _mm_set_epi64x(
3563 iq1s_grid[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)], iq1s_grid[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)]);
3564 const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3565 const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3566 const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3567 const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
3568
3569 const __m128i dot1_0 = mul_add_epi8_sse(q1b_1_0, q8b_1_0);
3570 const __m128i dot1_1 = mul_add_epi8_sse(q1b_1_1, q8b_1_1);
3571 const __m128i dot2_0 = mul_add_epi8_sse(q1b_2_0, q8b_2_0);
3572 const __m128i dot2_1 = mul_add_epi8_sse(q1b_2_1, q8b_2_1);
3573
3574 const __m128i delta1_0 = _mm_set_epi64x(qh[0] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
3575 qh[0] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
3576 const __m128i delta1_1 = _mm_set_epi64x(qh[1] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
3577 qh[1] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
3578 const __m128i delta2_0 = _mm_set_epi64x(qh[2] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
3579 qh[2] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
3580 const __m128i delta2_1 = _mm_set_epi64x(qh[3] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
3581 qh[3] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
3582
3583 const __m128i dot3_0 = mul_add_epi8_sse(delta1_0, q8b_1_0);
3584 const __m128i dot3_1 = mul_add_epi8_sse(delta1_1, q8b_1_1);
3585 const __m128i dot4_0 = mul_add_epi8_sse(delta2_0, q8b_2_0);
3586 const __m128i dot4_1 = mul_add_epi8_sse(delta2_1, q8b_2_1);
3587
3588 __m128i scale1_0 = _mm_set1_epi16(sc[ib/2] >> 0);
3589 __m128i scale1_1 = _mm_set1_epi16(sc[ib/2] >> 3);
3590 __m128i scale2_0 = _mm_set1_epi16(sc[ib/2] >> 6);
3591 __m128i scale2_1 = _mm_set1_epi16(sc[ib/2] >> 9);
3592
3593 scale1_0 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale1_0, mask), 1), mone);
3594 scale1_1 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale1_1, mask), 1), mone);
3595 scale2_0 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale2_0, mask), 1), mone);
3596 scale2_1 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale2_1, mask), 1), mone);
3597 const __m128i p1_0 = _mm_madd_epi16(dot1_0, scale1_0);
3598 const __m128i p1_1 = _mm_madd_epi16(dot1_1, scale1_1);
3599 const __m128i p2_0 = _mm_madd_epi16(dot2_0, scale2_0);
3600 const __m128i p2_1 = _mm_madd_epi16(dot2_1, scale2_1);
3601 const __m128i p3_0 = _mm_madd_epi16(dot3_0, scale1_0);
3602 const __m128i p3_1 = _mm_madd_epi16(dot3_1, scale1_1);
3603 const __m128i p4_0 = _mm_madd_epi16(dot4_0, scale2_0);
3604 const __m128i p4_1 = _mm_madd_epi16(dot4_1, scale2_1);
3605
3606 sumi1_0 = _mm_add_epi32(sumi1_0, _mm_add_epi32(p1_0, p2_0));
3607 sumi1_1 = _mm_add_epi32(sumi1_1, _mm_add_epi32(p1_1, p2_1));
3608 sumi2_0 = _mm_add_epi32(sumi2_0, _mm_add_epi32(p3_0, p4_0));
3609 sumi2_1 = _mm_add_epi32(sumi2_1, _mm_add_epi32(p3_1, p4_1));
3610
3611 qs += 8; qh += 4;
3612 }
3613
3614 const __m256 d = _mm256_set1_ps(y[i].d * GGML_CPU_FP16_TO_FP32(scale.f16));
3615
3616 accum1 = _mm256_add_ps(_mm256_mul_ps(d, _mm256_cvtepi32_ps(MM256_SET_M128I(sumi1_1, sumi1_0))), accum1);
3617 accum2 = _mm256_add_ps(_mm256_mul_ps(d, _mm256_cvtepi32_ps(MM256_SET_M128I(sumi2_1, sumi2_0))), accum2);
3618 }
3619
3620 *s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2);
3621
3622#else
3623 UNUSED(x);
3624 UNUSED(y);
3625 UNUSED(nb);
3626 UNUSED(scale);
3627 ggml_vec_dot_iq1_m_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
3628#endif
3629}
3630
3631void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
3632 assert(nrc == 1);
3633 UNUSED(nrc);
3634 UNUSED(bx);
3635 UNUSED(by);
3636 UNUSED(bs);
3637 assert(n % QK4_NL == 0);
3638 static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same");
3639
3640 const block_iq4_nl * GGML_RESTRICT x = vx;
3641 const block_q8_0 * GGML_RESTRICT y = vy;
3642
3643 const int nb = n / QK4_NL;
3644
3645 int ib = 0;
3646 float sumf = 0;
3647
3648#if defined __AVX2__
3649
3650 const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
3651 const __m128i m4b = _mm_set1_epi8(0x0f);
3652 const __m256i mone = _mm256_set1_epi16(1);
3653
3654 __m256 accum1 = _mm256_setzero_ps();
3655 __m256 accum2 = _mm256_setzero_ps();
3656 for (; ib + 1 < nb; ib += 2) {
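        // process two blocks per iteration: two 16-byte nibble loads and two 32-byte q8 loads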
        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + 0].qs);
        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs);
        const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[ib + 0].qs);
        const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[ib + 1].qs);
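        // expand low/high nibbles into signed bytes via the kvalues_iq4nl lookup table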
        const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
                                              _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
        const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
                                              _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
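        // int8 dot products widened to int32, then scaled by d(x)*d(y) and accumulated in fp32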
        const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
        const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
        const __m256i p_1 = _mm256_madd_epi16(p16_1, mone);
        const __m256i p_2 = _mm256_madd_epi16(p16_2, mone);
        accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 0].d)*GGML_CPU_FP16_TO_FP32(x[ib + 0].d)),
                _mm256_cvtepi32_ps(p_1), accum1);
        accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 1].d)*GGML_CPU_FP16_TO_FP32(x[ib + 1].d)),
                _mm256_cvtepi32_ps(p_2), accum2);
    }

    sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));

#elif defined __AVX__
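    // AVX (no AVX2): same computation on 128-bit halves; the 8-bit dot products and the
    // fp16 scale products are combined via mul_sum_i8_quad_float and quad_fp16_delta_float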
    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
    const __m128i m4b = _mm_set1_epi8(0x0f);

    __m256 accum = _mm256_setzero_ps();
    for (; ib + 1 < nb; ib += 2) {
        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs);
        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
        const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs);
        const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1);
        const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
        const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);

        const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b));
        const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b));
        const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b));
        const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b));

        const __m256 p = mul_sum_i8_quad_float(q4b_1_0, q4b_1_1, q4b_2_0, q4b_2_1, q8b_1_0, q8b_1_1, q8b_2_0, q8b_2_1);
        const __m256 deltas = quad_fp16_delta_float(x[ib].d, y[ib].d, x[ib + 1].d, y[ib + 1].d);
        accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum);
    }

    sumf = hsum_float_8(accum);

#endif
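    // scalar tail: handles a trailing odd block after the SIMD loops, or all blocks
    // when neither AVX2 nor AVX is available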
    for (; ib < nb; ++ib) {
        const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_CPU_FP16_TO_FP32(x[ib].d);
        int sumi1 = 0, sumi2 = 0;
        for (int j = 0; j < QK4_NL/2; ++j) {
            sumi1 += y[ib].qs[j+ 0] * kvalues_iq4nl[x[ib].qs[j] & 0xf];
            sumi2 += y[ib].qs[j+QK4_NL/2] * kvalues_iq4nl[x[ib].qs[j] >> 4];
        }
        sumf += d * (sumi1 + sumi2);
    }
    *s = sumf;
}

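// iq4_xs: super-blocks of QK_K weights that share one fp16 scale; each group of 32 weights
// carries a 6-bit scale (low 4 bits in the scales_l nibbles, high 2 bits in scales_h),
// offset by -32, and the 4-bit quants index the same kvalues_iq4nl table as iq4_nl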
void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);
    assert(n % QK_K == 0);

    const block_iq4_xs * GGML_RESTRICT x = vx;
    const block_q8_K * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

#if defined __AVX2__

    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
    const __m128i m4b = _mm_set1_epi8(0x0f);

    __m256 accum = _mm256_setzero_ps();
    for (int ibl = 0; ibl < nb; ++ibl) {
        const uint8_t * qs = x[ibl].qs;
        const int8_t  * q8 = y[ibl].qs;
        uint16_t sh = x[ibl].scales_h;
        __m256i sumi1 = _mm256_setzero_si256();
        __m256i sumi2 = _mm256_setzero_si256();
        for (int ib = 0; ib < QK_K/32; ib += 2) {
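            // two groups of 32 quants per iteration: 16 bytes of nibbles each, against 32 q8 bytes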
            const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)qs); qs += 16;
            const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)qs); qs += 16;
            const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
            const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
            const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
                                                  _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
            const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
                                                  _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
            const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
            const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
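            // decode the two 6-bit group scales: low nibble from scales_l, two high bits from
            // the running scales_h word, then apply the -32 offset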
            const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32;
            const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32;
            sh >>= 4;
            const __m256i p_1 = _mm256_madd_epi16(p16_1, _mm256_set1_epi16(ls1));
            const __m256i p_2 = _mm256_madd_epi16(p16_2, _mm256_set1_epi16(ls2));
            sumi1 = _mm256_add_epi32(p_1, sumi1);
            sumi2 = _mm256_add_epi32(p_2, sumi2);
        }
        accum = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(x[ibl].d)*y[ibl].d),
                _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accum);
    }

    *s = hsum_float_8(accum);

#elif defined __AVX__
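    // AVX (no AVX2): identical logic with the 256-bit vectors split into 128-bit halves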
    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
    const __m128i m4b = _mm_set1_epi8(0x0f);

    __m256 accum = _mm256_setzero_ps();
    for (int ibl = 0; ibl < nb; ++ibl) {
        const uint8_t * qs = x[ibl].qs;
        const int8_t  * q8 = y[ibl].qs;
        uint16_t sh = x[ibl].scales_h;
        __m128i sumi1_0 = _mm_setzero_si128();
        __m128i sumi1_1 = _mm_setzero_si128();
        __m128i sumi2_0 = _mm_setzero_si128();
        __m128i sumi2_1 = _mm_setzero_si128();
        for (int ib = 0; ib < QK_K/32; ib += 2) {
            const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)qs); qs += 16;
            const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)qs); qs += 16;
            const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
            const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b));
            const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b));
            const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b));
            const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b));
            const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0);
            const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1);
            const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0);
            const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1);
            const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32;
            const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32;
            sh >>= 4;
            const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, _mm_set1_epi16(ls1));
            const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, _mm_set1_epi16(ls1));
            const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, _mm_set1_epi16(ls2));
            const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, _mm_set1_epi16(ls2));
            sumi1_0 = _mm_add_epi32(p_1_0, sumi1_0);
            sumi1_1 = _mm_add_epi32(p_1_1, sumi1_1);
            sumi2_0 = _mm_add_epi32(p_2_0, sumi2_0);
            sumi2_1 = _mm_add_epi32(p_2_1, sumi2_1);
        }
        __m128i sumi12_0 = _mm_add_epi32(sumi1_0, sumi2_0);
        __m128i sumi12_1 = _mm_add_epi32(sumi1_1, sumi2_1);
        accum = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(x[ibl].d)*y[ibl].d),
                _mm256_cvtepi32_ps(MM256_SET_M128I(sumi12_1, sumi12_0))), accum);
    }

    *s = hsum_float_8(accum);

#else
    UNUSED(x);
    UNUSED(y);
    UNUSED(nb);
    ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
