| 1 | |
| 2 | #if defined(__GNUC__) |
| 3 | #pragma GCC diagnostic ignored "-Wpedantic" |
| 4 | #pragma GCC diagnostic ignored "-Wunused-local-typedefs" |
| 5 | #endif |
| 6 | |
| 7 | #include "amx.h" |
| 8 | #include "mmq.h" |
| 9 | #include "ggml-impl.h" |
| 10 | #include "ggml-cpu-impl.h" |
| 11 | #include "simd-mappings.h" |
| 12 | #include "quants.h" |
| 13 | #include "ggml-quants.h" |
| 14 | #include <algorithm> |
| 15 | #include <type_traits> |
| 16 | |
| 17 | #if defined(__gnu_linux__) |
| 18 | #include <sys/syscall.h> |
| 19 | #include <unistd.h> |
| 20 | #endif |
| 21 | |
| 22 | #if (defined(_WIN32) || defined(_WIN64)) |
| 23 | #define RESTRICT __restrict |
| 24 | #else |
| 25 | #define RESTRICT __restrict__ |
| 26 | #endif |
| 27 | |
| 28 | #if (defined(_WIN32) || defined(_WIN64)) |
| 29 | #define ALWAYS_INLINE __forceinline |
| 30 | #elif __has_attribute(always_inline) || defined(__GNUC__) |
| 31 | #define ALWAYS_INLINE __attribute__((__always_inline__)) inline |
| 32 | #else |
| 33 | #define ALWAYS_INLINE inline |
| 34 | #endif |
| 35 | |
| 36 | #if defined(__AMX_INT8__) && defined(__AVX512VNNI__) |
| 37 | |
| 38 | namespace { |
| 39 | |
| 40 | // Forced unrolling |
| 41 | template <int n> |
| 42 | struct Unroll { |
| 43 | template <typename Func, typename... Args> |
| 44 | ALWAYS_INLINE void operator()(const Func& f, Args... args) const { |
| 45 | Unroll<n - 1>{}(f, args...); |
| 46 | f(std::integral_constant<int, n - 1>{}, args...); |
| 47 | } |
| 48 | }; |
| 49 | |
| 50 | template <> |
| 51 | struct Unroll<1> { |
| 52 | template <typename Func, typename... Args> |
| 53 | ALWAYS_INLINE void operator()(const Func& f, Args... args) const { |
| 54 | f(std::integral_constant<int, 0>{}, args...); |
| 55 | } |
| 56 | }; |
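| | |
| | // Usage sketch (mirrors the lambdas used by the kernels further below): |
| | // |
| | //   auto loadc = [&](auto idx) { vc[idx] = _mm512_setzero_ps(); }; |
| | //   Unroll<ROWS * COLS>{}(loadc);  // expands to loadc(0), loadc(1), ..., loadc(ROWS * COLS - 1) |
| | // |
| | // the std::integral_constant index lets the body use `idx` in constexpr contexts, e.g. `idx / COLS`. |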
| 57 | |
| 58 | // type traits |
| 59 | template <typename T> struct PackedTypes {}; |
| 60 | template <> struct PackedTypes<block_q4_0> { using type = int8_t; }; |
| 61 | template <> struct PackedTypes<block_q4_1> { using type = uint8_t; }; |
| 62 | template <> struct PackedTypes<block_q8_0> { using type = int8_t; }; |
| 63 | template <typename T> using packed_B_type = typename PackedTypes<T>::type; |
| 64 | |
| 65 | template <typename T> |
| 66 | struct do_compensate : std::integral_constant<bool, |
| 67 | std::is_same<T, block_q8_0>::value> {}; |
| 68 | |
| 69 | template <typename T> |
| 70 | struct do_unpack : std::integral_constant<bool, |
| 71 | std::is_same<T, block_q4_0>::value || |
| 72 | std::is_same<T, block_q4_1>::value> {}; |
| 73 | |
| 74 | template <typename T> |
| 75 | struct is_type_qkk : std::integral_constant<bool, |
| 76 | std::is_same<T, block_q4_K>::value || |
| 77 | std::is_same<T, block_q5_K>::value || |
| 78 | std::is_same<T, block_q6_K>::value || |
| 79 | std::is_same<T, block_iq4_xs>::value> {}; |
| 80 | |
| 81 | #define GGML_DISPATCH_FLOATING_TYPES(TYPE, ...) \ |
| 82 | [&] { \ |
| 83 | switch (TYPE) { \ |
| 84 | case GGML_TYPE_F16: { \ |
| 85 | using type = ggml_fp16_t; \ |
| 86 | constexpr int blck_size = 16; \ |
| 87 | return __VA_ARGS__(); \ |
| 88 | } \ |
| 89 | case GGML_TYPE_BF16: { \ |
| 90 | using type = ggml_bf16_t; \ |
| 91 | constexpr int blck_size = 32; \ |
| 92 | return __VA_ARGS__(); \ |
| 93 | } \ |
| 94 | default: \ |
| 95 | fprintf(stderr, "Unsupported floating data type\n"); \ |
| 96 | } \ |
| 97 | }() |
| 98 | |
| 99 | #define GGML_DISPATCH_QTYPES(QT, ...) \ |
| 100 | [&] { \ |
| 101 | switch (QT) { \ |
| 102 | case GGML_TYPE_Q4_0: { \ |
| 103 | using type = block_q4_0; \ |
| 104 | using vec_dot_type = block_q8_0; \ |
| 105 | constexpr int blck_size = QK4_0; \ |
| 106 | return __VA_ARGS__(); \ |
| 107 | } \ |
| 108 | case GGML_TYPE_Q4_1: { \ |
| 109 | using type = block_q4_1; \ |
| 110 | using vec_dot_type = block_q8_1; \ |
| 111 | constexpr int blck_size = QK4_1; \ |
| 112 | return __VA_ARGS__(); \ |
| 113 | } \ |
| 114 | case GGML_TYPE_Q8_0: { \ |
| 115 | using type = block_q8_0; \ |
| 116 | using vec_dot_type = block_q8_0; \ |
| 117 | constexpr int blck_size = QK8_0; \ |
| 118 | return __VA_ARGS__(); \ |
| 119 | } \ |
| 120 | case GGML_TYPE_Q4_K: { \ |
| 121 | using type = block_q4_K; \ |
| 122 | using vec_dot_type = block_q8_K; \ |
| 123 | constexpr int blck_size = QK_K; \ |
| 124 | return __VA_ARGS__(); \ |
| 125 | } \ |
| 126 | case GGML_TYPE_Q5_K: { \ |
| 127 | using type = block_q5_K; \ |
| 128 | using vec_dot_type = block_q8_K; \ |
| 129 | constexpr int blck_size = QK_K; \ |
| 130 | return __VA_ARGS__(); \ |
| 131 | } \ |
| 132 | case GGML_TYPE_Q6_K: { \ |
| 133 | using type = block_q6_K; \ |
| 134 | using vec_dot_type = block_q8_K; \ |
| 135 | constexpr int blck_size = QK_K; \ |
| 136 | return __VA_ARGS__(); \ |
| 137 | } \ |
| 138 | case GGML_TYPE_IQ4_XS: { \ |
| 139 | using type = block_iq4_xs; \ |
| 140 | using vec_dot_type = block_q8_K; \ |
| 141 | constexpr int blck_size = QK_K; \ |
| 142 | return __VA_ARGS__(); \ |
| 143 | } \ |
| 144 | default: \ |
| 145 | fprintf(stderr, "Unsupported quantized data type: %d\n", int(QT)); \ |
| 146 | } \ |
| 147 | }() |
| 148 | |
| 149 | #define GGML_DISPATCH_BOOL(BOOL_V, BOOL_NAME, ...) \ |
| 150 | [&] { \ |
| 151 | if (BOOL_V) { \ |
| 152 | constexpr bool BOOL_NAME = true; \ |
| 153 | return __VA_ARGS__(); \ |
| 154 | } else { \ |
| 155 | constexpr bool BOOL_NAME = false; \ |
| 156 | return __VA_ARGS__(); \ |
| 157 | } \ |
| 158 | }() |
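| | |
| | // Usage sketch (illustrative only; `kernel_impl`, `k` and `use_tail` are hypothetical names): |
| | // |
| | //   GGML_DISPATCH_QTYPES(TYPE, [&] { |
| | //       // `type`, `vec_dot_type` and `blck_size` are bound to the matched case here |
| | //       GGML_DISPATCH_BOOL(k % blck_size != 0, use_tail, [&] { |
| | //           kernel_impl<type, vec_dot_type, use_tail>(/* ... */); |
| | //       }); |
| | //   }); |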
| 159 | |
| 160 | // define amx tile config data structure |
| 161 | struct tile_config_t { |
| 162 | uint8_t palette_id = 0; |
| 163 | uint8_t start_row = 0; |
| 164 | uint8_t reserved_0[14] = {0}; |
| 165 | uint16_t colsb[16] = {0}; |
| 166 | uint8_t rows[16] = {0}; |
| 167 | }; |
| 168 | |
| 169 | // Notes: amx tile config |
| 170 | // |
| 171 | // Typically, TMUL calculates A and B of size 16 x 64 containing INT8 values, |
| 172 | // and accumulates the result into a 16 x 16 matrix C containing INT32 values. |
| 173 | // |
| 174 | // Since many GGUF quantized types have a `block_size` of 32, a 16-16-32 config is used |
| 175 | // instead of the normally used 16-16-64 config. |
| 176 | // |
| 177 | // Block A: {16, 32}, dtype = int8_t |
| 178 | // Block B: {16, 32}, dtype = uint8_t/int8_t |
| 179 | // Block C: {16, 16}, dtype = int32_t |
| 180 | // |
| 181 | // Block B needs to be prepacked to vnni format before feeding into TMUL: |
| 182 | // packed_B: from {n, k} to {k/vnni_blk, n, vnni_blk}, viewed in 2d, we get {8, 64} |
| 183 | // |
| 184 | // Therefore, we get tileconfig: |
| 185 | // A B C |
| 186 | // rows 16 8 16 |
| 187 | // colsb 32 64 16 |
| 188 | // |
| 189 | // For tile distribution, follow a 2-2-4 pattern, e.g. A uses TMM2-TMM3, B uses TMM0-TMM1, |
| 190 | // C uses TMM4-TMM7: |
| 191 | // B TMM0 B TMM1 |
| 192 | // A TMM2 C TMM4 C TMM6 |
| 193 | // A TMM3 C TMM5 C TMM7 |
| 194 | // |
| 195 | // Each `amx` kernel handles 4 blocks at a time: 2MB * 2NB. When m < 2 * BLOCK_M, unpacking A |
| 196 | // will be needed. |
| 197 | // |
| 198 | // Here another commonly used pattern, 1-3-3, is skipped, as it is mostly used when m <= 16; |
| 199 | // and the single batch gemm (m=1) has a special fast path with `avx512-vnni`. |
| 200 | // |
| 201 | // ref: https://www.intel.com/content/www/us/en/developer/articles/code-sample/ |
| 202 | // advanced-matrix-extensions-intrinsics-functions.html |
| 203 | // |
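| | // Illustrative inner-step sketch (not the actual kernel; buffer names here are placeholders): |
| | // |
| | //   _tile_zero(TMM4);                          // C tile starts at zero for the first KB block |
| | //   _tile_loadd(TMM0, packed_B, 64);           // B tile: 8 rows x 64 bytes, vnni packed |
| | //   _tile_loadd(TMM2, A_quants, lda_bytes);    // A tile: 16 rows x 32 bytes of int8 quants |
| | //   _tile_dpbssd(TMM4, TMM2, TMM0);            // C(16x16, int32) += A(16x32) * B(32x16) |
| | //   _tile_stored(TMM4, sumi, TILE_N * 4);      // spill int32 results for scaling |
| | // |
| | //   (AMX also provides _tile_dpbusd / _tile_dpbsud / _tile_dpbuud for other signedness combinations.) |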
| 204 | |
| 205 | #define TC_CONFIG_TILE(i, r, cb) tc.rows[i] = r; tc.colsb[i] = cb |
| 206 | void ggml_tile_config_init(void) { |
| 207 | static thread_local bool is_first_time = true; |
| 208 | |
| 209 | if (!is_first_time) { |
| 210 | return; |
| 211 | } |
| 212 | |
| 213 | static thread_local tile_config_t tc; |
| 214 | tile_config_t current_tc; |
| 215 | _tile_storeconfig(¤t_tc); |
| 216 | |
| 217 | // load only when config changes |
| 218 | if (tc.palette_id == 0 || (memcmp(¤t_tc.colsb, &tc.colsb, sizeof(uint16_t) * 8) != 0 && |
| 219 | memcmp(¤t_tc.rows, &tc.rows, sizeof(uint8_t) * 8) != 0)) { |
| 220 | tc.palette_id = 1; |
| 221 | tc.start_row = 0; |
| 222 | TC_CONFIG_TILE(TMM0, 8, 64); |
| 223 | TC_CONFIG_TILE(TMM1, 8, 64); |
| 224 | TC_CONFIG_TILE(TMM2, 16, 32); |
| 225 | TC_CONFIG_TILE(TMM3, 16, 32); |
| 226 | TC_CONFIG_TILE(TMM4, 16, 64); |
| 227 | TC_CONFIG_TILE(TMM5, 16, 64); |
| 228 | TC_CONFIG_TILE(TMM6, 16, 64); |
| 229 | TC_CONFIG_TILE(TMM7, 16, 64); |
| 230 | _tile_loadconfig(&tc); |
| 231 | } |
| 232 | |
| 233 | is_first_time = false; |
| 234 | } |
| 235 | |
| 236 | // we need an extra 16 * 4B (TILE_N * int32_t) for each NB/KB block for compensation. |
| 237 | // See the notes `s8s8 igemm compensation in avx512-vnni` for details. |
| 238 | template <typename TB> |
| 239 | int get_tile_size() { |
| 240 | int tile_size = TILE_N * sizeof(TB); |
| 241 | if (do_compensate<TB>::value) { |
| 242 | tile_size += TILE_N * sizeof(int32_t); |
| 243 | } |
| 244 | if (std::is_same<TB, block_q4_K>::value || |
| 245 | std::is_same<TB, block_q5_K>::value) { |
| 246 | tile_size += TILE_N * 4; |
| 247 | } |
| 248 | if (std::is_same<TB, block_iq4_xs>::value) { |
| 249 | tile_size += TILE_N * 2; |
| 250 | } |
| 251 | return tile_size; |
| 252 | } |
| 253 | |
| 254 | template <typename TB, int BLOCK_K> |
| 255 | int get_row_size(int K) { |
| 256 | int KB = K / BLOCK_K; |
| 257 | int row_size = KB * sizeof(TB); |
| 258 | if (do_compensate<TB>::value) { |
| 259 | row_size += KB * sizeof(int32_t); |
| 260 | } |
| 261 | if (std::is_same<TB, block_q4_K>::value || |
| 262 | std::is_same<TB, block_q5_K>::value) { |
| 263 | row_size += KB * 4; |
| 264 | } |
| 265 | if (std::is_same<TB, block_iq4_xs>::value) { |
| 266 | row_size += KB * 2; |
| 267 | } |
| 268 | return row_size; |
| 269 | } |
| 270 | |
| 271 | // vectorized dtype conversion |
| 272 | inline float FP16_TO_FP32(ggml_half val) { |
| 273 | __m256i v = _mm256_setr_epi16( |
| 274 | val, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0); |
| 275 | __m512 o = _mm512_cvtph_ps(v); |
| 276 | return _mm512_cvtss_f32(o); |
| 277 | } |
| 278 | |
| 279 | inline __m512 FP16_TO_FP32_VEC(ggml_half val) { |
| 280 | __m256i v = _mm256_set1_epi16(val); |
| 281 | return _mm512_cvtph_ps(v); |
| 282 | } |
| 283 | |
| 284 | // horizontal reduce |
| 285 | inline float _mm512_reduce_max_ps(const __m512 x) { |
| 286 | __m512 v = x; |
| 287 | __m512 v1 = _mm512_shuffle_f32x4(v, v, 0x4E); |
| 288 | v = _mm512_max_ps(v, v1); |
| 289 | v1 = _mm512_shuffle_f32x4(v, v, 0xB1); |
| 290 | v = _mm512_max_ps(v, v1); |
| 291 | v1 = _mm512_shuffle_ps(v, v, 0x4E); |
| 292 | v = _mm512_max_ps(v, v1); |
| 293 | v1 = _mm512_shuffle_ps(v, v, 0xB1); |
| 294 | v = _mm512_max_ps(v, v1); |
| 295 | return _mm512_cvtss_f32(v); |
| 296 | } |
| 297 | |
| 298 | // transpose utils |
| 299 | #define SHUFFLE_EPI32(a, b, mask) \ |
| 300 | _mm256_castps_si256(_mm256_shuffle_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b), mask)) |
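| | // transpose an 8x8 matrix of 32-bit elements: input rows come from v, the transposed rows are written to v1 |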
| 301 | inline void transpose_8x8_32bit(__m256i * v, __m256i * v1) { |
| 302 | // unpack and interleave 32-bit elements |
| 303 | v1[0] = _mm256_unpacklo_epi32(v[0], v[1]); |
| 304 | v1[1] = _mm256_unpackhi_epi32(v[0], v[1]); |
| 305 | v1[2] = _mm256_unpacklo_epi32(v[2], v[3]); |
| 306 | v1[3] = _mm256_unpackhi_epi32(v[2], v[3]); |
| 307 | v1[4] = _mm256_unpacklo_epi32(v[4], v[5]); |
| 308 | v1[5] = _mm256_unpackhi_epi32(v[4], v[5]); |
| 309 | v1[6] = _mm256_unpacklo_epi32(v[6], v[7]); |
| 310 | v1[7] = _mm256_unpackhi_epi32(v[6], v[7]); |
| 311 | |
| 312 | // shuffling the 32-bit elements |
| 313 | v[0] = SHUFFLE_EPI32(v1[0], v1[2], 0x44); |
| 314 | v[1] = SHUFFLE_EPI32(v1[0], v1[2], 0xee); |
| 315 | v[2] = SHUFFLE_EPI32(v1[4], v1[6], 0x44); |
| 316 | v[3] = SHUFFLE_EPI32(v1[4], v1[6], 0xee); |
| 317 | v[4] = SHUFFLE_EPI32(v1[1], v1[3], 0x44); |
| 318 | v[5] = SHUFFLE_EPI32(v1[1], v1[3], 0xee); |
| 319 | v[6] = SHUFFLE_EPI32(v1[5], v1[7], 0x44); |
| 320 | v[7] = SHUFFLE_EPI32(v1[5], v1[7], 0xee); |
| 321 | |
| 322 | // shuffling 128-bit elements |
| 323 | v1[0] = _mm256_permute2f128_si256(v[2], v[0], 0x02); |
| 324 | v1[1] = _mm256_permute2f128_si256(v[3], v[1], 0x02); |
| 325 | v1[2] = _mm256_permute2f128_si256(v[6], v[4], 0x02); |
| 326 | v1[3] = _mm256_permute2f128_si256(v[7], v[5], 0x02); |
| 327 | v1[4] = _mm256_permute2f128_si256(v[2], v[0], 0x13); |
| 328 | v1[5] = _mm256_permute2f128_si256(v[3], v[1], 0x13); |
| 329 | v1[6] = _mm256_permute2f128_si256(v[6], v[4], 0x13); |
| 330 | v1[7] = _mm256_permute2f128_si256(v[7], v[5], 0x13); |
| 331 | } |
| 332 | |
| 333 | inline void transpose_16x4_32bit(__m512i * r, __m512i * d) { |
| 334 | |
| 335 | static const __m512i index1 = _mm512_set_epi32( |
| 336 | 0x0f, 0x0b, 0x07, 0x03, |
| 337 | 0x0e, 0x0a, 0x06, 0x02, |
| 338 | 0x0d, 0x09, 0x05, 0x01, |
| 339 | 0x0c, 0x08, 0x04, 0x00); |
| 340 | |
| 341 | d[0] = _mm512_permutexvar_epi32(index1, r[0]); |
| 342 | d[1] = _mm512_permutexvar_epi32(index1, r[1]); |
| 343 | d[2] = _mm512_permutexvar_epi32(index1, r[2]); |
| 344 | d[3] = _mm512_permutexvar_epi32(index1, r[3]); |
| 345 | |
| 346 | r[0] = _mm512_shuffle_i32x4(d[0], d[1], 0x44); |
| 347 | r[1] = _mm512_shuffle_i32x4(d[0], d[1], 0xee); |
| 348 | r[2] = _mm512_shuffle_i32x4(d[2], d[3], 0x44); |
| 349 | r[3] = _mm512_shuffle_i32x4(d[2], d[3], 0xee); |
| 350 | |
| 351 | d[0] = _mm512_shuffle_i32x4(r[0], r[2], 0x88); |
| 352 | d[1] = _mm512_shuffle_i32x4(r[0], r[2], 0xdd); |
| 353 | d[2] = _mm512_shuffle_i32x4(r[1], r[3], 0x88); |
| 354 | d[3] = _mm512_shuffle_i32x4(r[1], r[3], 0xdd); |
| 355 | } |
| 356 | |
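| | // in-place transpose of a 16x16 matrix of 32-bit elements held in 16 zmm registers |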
| 357 | inline void transpose_16x16_32bit(__m512i * v) { |
| 358 | __m512i v1[16]; |
| 359 | v1[0] = _mm512_unpacklo_epi32(v[0], v[1]); |
| 360 | v1[1] = _mm512_unpackhi_epi32(v[0], v[1]); |
| 361 | v1[2] = _mm512_unpacklo_epi32(v[2], v[3]); |
| 362 | v1[3] = _mm512_unpackhi_epi32(v[2], v[3]); |
| 363 | v1[4] = _mm512_unpacklo_epi32(v[4], v[5]); |
| 364 | v1[5] = _mm512_unpackhi_epi32(v[4], v[5]); |
| 365 | v1[6] = _mm512_unpacklo_epi32(v[6], v[7]); |
| 366 | v1[7] = _mm512_unpackhi_epi32(v[6], v[7]); |
| 367 | v1[8] = _mm512_unpacklo_epi32(v[8], v[9]); |
| 368 | v1[9] = _mm512_unpackhi_epi32(v[8], v[9]); |
| 369 | v1[10] = _mm512_unpacklo_epi32(v[10], v[11]); |
| 370 | v1[11] = _mm512_unpackhi_epi32(v[10], v[11]); |
| 371 | v1[12] = _mm512_unpacklo_epi32(v[12], v[13]); |
| 372 | v1[13] = _mm512_unpackhi_epi32(v[12], v[13]); |
| 373 | v1[14] = _mm512_unpacklo_epi32(v[14], v[15]); |
| 374 | v1[15] = _mm512_unpackhi_epi32(v[14], v[15]); |
| 375 | |
| 376 | v[0] = _mm512_unpacklo_epi64(v1[0], v1[2]); |
| 377 | v[1] = _mm512_unpackhi_epi64(v1[0], v1[2]); |
| 378 | v[2] = _mm512_unpacklo_epi64(v1[1], v1[3]); |
| 379 | v[3] = _mm512_unpackhi_epi64(v1[1], v1[3]); |
| 380 | v[4] = _mm512_unpacklo_epi64(v1[4], v1[6]); |
| 381 | v[5] = _mm512_unpackhi_epi64(v1[4], v1[6]); |
| 382 | v[6] = _mm512_unpacklo_epi64(v1[5], v1[7]); |
| 383 | v[7] = _mm512_unpackhi_epi64(v1[5], v1[7]); |
| 384 | v[8] = _mm512_unpacklo_epi64(v1[8], v1[10]); |
| 385 | v[9] = _mm512_unpackhi_epi64(v1[8], v1[10]); |
| 386 | v[10] = _mm512_unpacklo_epi64(v1[9], v1[11]); |
| 387 | v[11] = _mm512_unpackhi_epi64(v1[9], v1[11]); |
| 388 | v[12] = _mm512_unpacklo_epi64(v1[12], v1[14]); |
| 389 | v[13] = _mm512_unpackhi_epi64(v1[12], v1[14]); |
| 390 | v[14] = _mm512_unpacklo_epi64(v1[13], v1[15]); |
| 391 | v[15] = _mm512_unpackhi_epi64(v1[13], v1[15]); |
| 392 | |
| 393 | v1[0] = _mm512_shuffle_i32x4(v[0], v[4], 0x88); |
| 394 | v1[1] = _mm512_shuffle_i32x4(v[1], v[5], 0x88); |
| 395 | v1[2] = _mm512_shuffle_i32x4(v[2], v[6], 0x88); |
| 396 | v1[3] = _mm512_shuffle_i32x4(v[3], v[7], 0x88); |
| 397 | v1[4] = _mm512_shuffle_i32x4(v[0], v[4], 0xdd); |
| 398 | v1[5] = _mm512_shuffle_i32x4(v[1], v[5], 0xdd); |
| 399 | v1[6] = _mm512_shuffle_i32x4(v[2], v[6], 0xdd); |
| 400 | v1[7] = _mm512_shuffle_i32x4(v[3], v[7], 0xdd); |
| 401 | v1[8] = _mm512_shuffle_i32x4(v[8], v[12], 0x88); |
| 402 | v1[9] = _mm512_shuffle_i32x4(v[9], v[13], 0x88); |
| 403 | v1[10] = _mm512_shuffle_i32x4(v[10], v[14], 0x88); |
| 404 | v1[11] = _mm512_shuffle_i32x4(v[11], v[15], 0x88); |
| 405 | v1[12] = _mm512_shuffle_i32x4(v[8], v[12], 0xdd); |
| 406 | v1[13] = _mm512_shuffle_i32x4(v[9], v[13], 0xdd); |
| 407 | v1[14] = _mm512_shuffle_i32x4(v[10], v[14], 0xdd); |
| 408 | v1[15] = _mm512_shuffle_i32x4(v[11], v[15], 0xdd); |
| 409 | |
| 410 | v[0] = _mm512_shuffle_i32x4(v1[0], v1[8], 0x88); |
| 411 | v[1] = _mm512_shuffle_i32x4(v1[1], v1[9], 0x88); |
| 412 | v[2] = _mm512_shuffle_i32x4(v1[2], v1[10], 0x88); |
| 413 | v[3] = _mm512_shuffle_i32x4(v1[3], v1[11], 0x88); |
| 414 | v[4] = _mm512_shuffle_i32x4(v1[4], v1[12], 0x88); |
| 415 | v[5] = _mm512_shuffle_i32x4(v1[5], v1[13], 0x88); |
| 416 | v[6] = _mm512_shuffle_i32x4(v1[6], v1[14], 0x88); |
| 417 | v[7] = _mm512_shuffle_i32x4(v1[7], v1[15], 0x88); |
| 418 | v[8] = _mm512_shuffle_i32x4(v1[0], v1[8], 0xdd); |
| 419 | v[9] = _mm512_shuffle_i32x4(v1[1], v1[9], 0xdd); |
| 420 | v[10] = _mm512_shuffle_i32x4(v1[2], v1[10], 0xdd); |
| 421 | v[11] = _mm512_shuffle_i32x4(v1[3], v1[11], 0xdd); |
| 422 | v[12] = _mm512_shuffle_i32x4(v1[4], v1[12], 0xdd); |
| 423 | v[13] = _mm512_shuffle_i32x4(v1[5], v1[13], 0xdd); |
| 424 | v[14] = _mm512_shuffle_i32x4(v1[6], v1[14], 0xdd); |
| 425 | v[15] = _mm512_shuffle_i32x4(v1[7], v1[15], 0xdd); |
| 426 | } |
| 427 | |
| 428 | void quantize_row_q8_K_vnni(const float * RESTRICT x, void * RESTRICT vy, int64_t k) { |
| 429 | assert(k % QK_K == 0); |
| 430 | const int KB = k / QK_K; |
| 431 | constexpr int kVecs = QK_K / 16; |
| 432 | |
| 433 | block_q8_K * y = reinterpret_cast<block_q8_K *>(vy); |
| 434 | |
| 435 | // hold 16 float vecs from x |
| 436 | __m512 v[kVecs]; |
| 437 | |
| 438 | // hold the quants vecs |
| 439 | __m512i vq[kVecs / 4]; |
| 440 | |
| 441 | // hold the packed quants vecs |
| 442 | __m512i vq_packed[kVecs / 4]; |
| 443 | |
| 444 | const __m512 signBit = _mm512_set1_ps(-0.f); |
| 445 | |
| 446 | for (int i = 0; i < KB; ++i) { |
| 447 | // Compute max(abs(e)) for the block |
| 448 | __m512 vamax = _mm512_set1_ps(0.f); |
| 449 | for (int j = 0; j < kVecs; ++j) { |
| 450 | v[j] = _mm512_loadu_ps(x); x += 16; |
| 451 | vamax = _mm512_max_ps(vamax, _mm512_andnot_ps(signBit, v[j])); |
| 452 | } |
| 453 | const float amax = _mm512_reduce_max_ps(vamax); |
| 454 | |
| 455 | // Quantize these floats |
| 456 | const float iscale = 127.f / amax; |
| 457 | y[i].d = GGML_CPU_FP32_TO_FP16(1 / iscale); |
| 458 | const float id = ( amax != 0.0f ) ? iscale : 0.f; |
| 459 | const __m512 vscale = _mm512_set1_ps(id); |
| 460 | |
| 461 | // Apply multiplier and round to nearest integer |
| 462 | for (int j = 0; j < kVecs; ++j) { |
| 463 | v[j] = _mm512_mul_ps(v[j], vscale); |
| 464 | v[j] = _mm512_roundscale_ps(v[j], (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); |
| 465 | } |
| 466 | |
| 467 | // Pack to epi8 vecs |
| 468 | for (int j = 0; j < kVecs / 4; ++j) { |
| 469 | __m128i q8_0 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 0])); |
| 470 | __m128i q8_1 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 1])); |
| 471 | __m128i q8_2 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 2])); |
| 472 | __m128i q8_3 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 3])); |
| 473 | |
| 474 | __m256i q8_01 = _mm256_insertf128_si256(_mm256_castsi128_si256(q8_0), (q8_1), 1); |
| 475 | __m256i q8_23 = _mm256_insertf128_si256(_mm256_castsi128_si256(q8_2), (q8_3), 1); |
| 476 | |
| 477 | vq[j] = _mm512_inserti32x8(_mm512_castsi256_si512(q8_01), q8_23, 1); |
| 478 | _mm512_storeu_si512((__m512i *)(y[i].qs + j * 64), vq[j]); |
| 479 | } |
| 480 | |
| 481 | // Compute the bsums with vnni |
| 482 | transpose_16x4_32bit(vq, vq_packed); |
| 483 | |
| 484 | const __m512i one = _mm512_set1_epi8(1); |
| 485 | __m512i sum = _mm512_setzero_si512(); |
| 486 | for (int k = 0; k < 4; ++k) { |
| 487 | sum = _mm512_dpbusd_epi32(sum, one, vq_packed[k]); |
| 488 | } |
| 489 | _mm256_storeu_si256((__m256i *)(y[i].bsums), _mm512_cvtepi32_epi16(sum)); |
| 490 | } |
| 491 | } |
| 492 | |
| 493 | // quantize A from float to `vec_dot_type` |
| 494 | template <typename T> |
| 495 | inline void from_float(const float * x, char * vy, int64_t k); |
| 496 | |
| 497 | template <> |
| 498 | inline void from_float<block_q8_0>(const float * x, char * vy, int64_t k) { |
| 499 | quantize_row_q8_0(x, (block_q8_0 *)vy, k); |
| 500 | } |
| 501 | |
| 502 | template <> |
| 503 | inline void from_float<block_q8_1>(const float * x, char * vy, int64_t k) { |
| 504 | quantize_row_q8_1(x, (block_q8_1 *)vy, k); |
| 505 | } |
| 506 | |
| 507 | template <> |
| 508 | inline void from_float<block_q8_K>(const float * x, char * vy, int64_t k) { |
| 509 | #if 1 |
| 510 | // TODO: this is reference impl! |
| 511 | quantize_row_q8_K_ref(x, (block_q8_K *)vy, k); |
| 512 | #else |
| 513 | quantize_row_q8_K_vnni(x, vy, k); |
| 514 | #endif |
| 515 | } |
| 516 | |
| 517 | // load A from memory into the tile buffer when nrows cannot fill a whole tile |
| 518 | void unpack_A(int8_t * RESTRICT tile, const block_q8_0 * RESTRICT A, int lda, int nr) { |
| 519 | assert(nr != TILE_M); |
| 520 | for (int m = 0; m < nr; ++m) { |
| 521 | const __m256i v = _mm256_loadu_si256((const __m256i *)(A[m * lda].qs)); |
| 522 | _mm256_storeu_si256((__m256i *)(tile + m * TILE_K), v); |
| 523 | } |
| 524 | } |
| 525 | |
| 526 | void unpack_A(int8_t * RESTRICT tile, const block_q8_1 * RESTRICT A, int lda, int nr) { |
| 527 | assert(nr != TILE_M); |
| 528 | for (int m = 0; m < nr; ++m) { |
| 529 | const __m256i v = _mm256_loadu_si256((const __m256i *)(A[m * lda].qs)); |
| 530 | _mm256_storeu_si256((__m256i *)(tile + m * TILE_K), v); |
| 531 | } |
| 532 | } |
| 533 | |
| 534 | template <typename TB> |
| 535 | void unpack_A(int8_t * RESTRICT tile, const block_q8_K * RESTRICT A, int lda, int k, int nr) { |
| 536 | assert(nr <= TILE_M); |
| 537 | for (int m = 0; m < nr; ++m) { |
| 538 | const __m256i v = _mm256_loadu_si256((const __m256i *)(A[m * lda].qs + k * 32)); |
| 539 | _mm256_storeu_si256((__m256i *)(tile + m * TILE_K), v); |
| 540 | } |
| 541 | } |
| 542 | |
| 543 | template <> |
| 544 | void unpack_A<block_q6_K>(int8_t * RESTRICT tile, const block_q8_K * RESTRICT A, int lda, int k, int nr) { |
| 545 | assert(nr <= TILE_M); |
| 546 | // zero-pad k from 16 to 32, so that we don't have to re-configure amx |
| 547 | const __m128i zero = _mm_setzero_si128(); |
| 548 | for (int m = 0; m < nr; ++m) { |
| 549 | const __m128i v = _mm_loadu_si128((const __m128i *)(A[m * lda].qs + k * 16)); |
| 550 | const __m256i r = _mm256_insertf128_si256(_mm256_castsi128_si256(v), zero, 1); |
| 551 | _mm256_storeu_si256((__m256i *)(tile + m * TILE_K), r); |
| 552 | } |
| 553 | } |
| 554 | |
| 555 | #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) |
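| | // expand 32 packed 4-bit values into 32 bytes in [0, 15]; e.g. an input byte 0xAB contributes |
| | // 0x0B to the lower 128-bit half and 0x0A to the upper half of the result |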
| 556 | inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) { |
| 557 | const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi); |
| 558 | const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp); |
| 559 | const __m256i lowMask = _mm256_set1_epi8(0xF); |
| 560 | return _mm256_and_si256(lowMask, bytes); |
| 561 | } |
| 562 | |
| 563 | // used for block_q4_K |
| 564 | inline __m512i bytes_from_nibbles_64(const uint8_t * rsi) { |
| 565 | const __m256i tmp = _mm256_loadu_si256((const __m256i *)rsi); |
| 566 | const __m256i lowMask = _mm256_set1_epi8(0xF); |
| 567 | const __m256i q4l = _mm256_and_si256(tmp, lowMask); |
| 568 | const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(tmp, 4), lowMask); |
| 569 | return _mm512_inserti32x8(_mm512_castsi256_si512(q4l), q4h, 1); |
| 570 | } |
| 571 | |
| 572 | // used for block_q5_K |
| 573 | inline __m512i bytes_from_nibbles_64(const uint8_t * qs, const uint8_t * qh, int k) { |
| 574 | const __m256i lowMask = _mm256_set1_epi8(0xF); |
| 575 | __m256i hmask = _mm256_set1_epi8(1); |
| 576 | hmask = _mm256_slli_epi16(hmask, k); |
| 577 | |
| 578 | const __m256i q5bits = _mm256_loadu_si256((const __m256i *)qs); |
| 579 | const __m256i hbits = _mm256_loadu_si256((const __m256i *)qh); |
| 580 | |
| 581 | const __m256i q5l_0 = _mm256_and_si256(q5bits, lowMask); |
| 582 | const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), k + 0), 4); |
| 583 | const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0); |
| 584 | hmask = _mm256_slli_epi16(hmask, 1); |
| 585 | |
| 586 | const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), lowMask); |
| 587 | const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), k + 1), 4); |
| 588 | const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1); |
| 589 | |
| 590 | return _mm512_inserti32x8(_mm512_castsi256_si512(q5_0), q5_1, 1); |
| 591 | } |
| 592 | |
| 593 | // used for block_q6_K |
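| | // each output byte is (ql & 0xF) | (((qh >> shift) & 0x3) << 4), i.e. a value in [0, 63]; |
| | // the -32 offset of q6_K is applied later (see `unpack_B<block_q6_K>`) |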
| 594 | inline void bytes_from_nibbles_128(__m512i& r0, __m512i& r1, const uint8_t * qs, const uint8_t * qh) { |
| 595 | const __m256i m4 = _mm256_set1_epi8(0xF); |
| 596 | const __m256i m2 = _mm256_set1_epi8(0x3); |
| 597 | |
| 598 | const __m256i q6bits1 = _mm256_loadu_si256((const __m256i *)qs); |
| 599 | const __m256i q6bits2 = _mm256_loadu_si256((const __m256i *)(qs + 32)); |
| 600 | const __m256i q6bitsH = _mm256_loadu_si256((const __m256i *)qh); |
| 601 | |
| 602 | const __m256i q6h_0 = _mm256_slli_epi16(_mm256_and_si256( q6bitsH, m2), 4); |
| 603 | const __m256i q6h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q6bitsH, 2), m2), 4); |
| 604 | const __m256i q6h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q6bitsH, 4), m2), 4); |
| 605 | const __m256i q6h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q6bitsH, 6), m2), 4); |
| 606 | |
| 607 | const __m256i q6_0 = _mm256_or_si256(_mm256_and_si256(q6bits1, m4), q6h_0); |
| 608 | const __m256i q6_1 = _mm256_or_si256(_mm256_and_si256(q6bits2, m4), q6h_1); |
| 609 | const __m256i q6_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q6bits1, 4), m4), q6h_2); |
| 610 | const __m256i q6_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q6bits2, 4), m4), q6h_3); |
| 611 | |
| 612 | r0 = _mm512_inserti32x8(_mm512_castsi256_si512(q6_0), q6_1, 1); |
| 613 | r1 = _mm512_inserti32x8(_mm512_castsi256_si512(q6_2), q6_3, 1); |
| 614 | } |
| 615 | |
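| | // inverse of the nibble expansion above: assuming both inputs hold byte values in [0, 15], |
| | // r0 ends up in the low nibble and r1 in the high nibble of each output byte |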
| 616 | inline __m512i packNibbles(__m512i r0, __m512i r1) { |
| 617 | return _mm512_or_si512(r0, _mm512_slli_epi16(r1, 4)); |
| 618 | } |
| 619 | |
| 620 | template <typename TB> |
| 621 | inline void pack_qs(void * RESTRICT packed_B, const TB * RESTRICT B, int KB) { |
| 622 | int8_t tmp[8 * 64]; |
| 623 | __m256i v[8], v2[8]; |
| 624 | for (int n = 0; n < 8; ++n) { |
| 625 | v[n] = bytes_from_nibbles_32(B[n * KB].qs); |
| 626 | } |
| 627 | transpose_8x8_32bit(v, v2); |
| 628 | for (int n = 0; n < 8; ++n) { |
| 629 | _mm256_storeu_si256((__m256i *)(tmp + n * 64), v2[n]); |
| 630 | } |
| 631 | for (int n = 0; n < 8; ++n) { |
| 632 | v[n] = bytes_from_nibbles_32(B[(n + 8) * KB].qs); |
| 633 | } |
| 634 | transpose_8x8_32bit(v, v2); |
| 635 | for (int n = 0; n < 8; ++n) { |
| 636 | _mm256_storeu_si256((__m256i *)(tmp + n * 64 + 32), v2[n]); |
| 637 | } |
| 638 | |
| 639 | // pack again with 128 to fully utilize vector length |
| 640 | for (int n = 0; n < 8; n += 2) { |
| 641 | __m512i r0 = _mm512_loadu_si512((const __m512i *)(tmp + n * 64)); |
| 642 | __m512i r1 = _mm512_loadu_si512((const __m512i *)(tmp + n * 64 + 64)); |
| 643 | __m512i r1r0 = packNibbles(r0, r1); |
| 644 | _mm512_storeu_si512((__m512i *)((char *)packed_B + n * 32), r1r0); |
| 645 | } |
| 646 | } |
| 647 | |
| 648 | template <> |
| 649 | inline void pack_qs<block_q8_0>(void * RESTRICT packed_B, const block_q8_0 * RESTRICT B, int KB) { |
| 650 | __m256i v[8], v2[8]; |
| 651 | for (int n = 0; n < 8; ++n) { |
| 652 | v[n] = _mm256_loadu_si256((const __m256i *)(B[n * KB].qs)); |
| 653 | } |
| 654 | transpose_8x8_32bit(v, v2); |
| 655 | for (int n = 0; n < 8; ++n) { |
| 656 | _mm256_storeu_si256((__m256i *)((char *)packed_B + n * 64), v2[n]); |
| 657 | } |
| 658 | for (int n = 0; n < 8; ++n) { |
| 659 | v[n] = _mm256_loadu_si256((const __m256i *)(B[(n + 8) * KB].qs)); |
| 660 | } |
| 661 | transpose_8x8_32bit(v, v2); |
| 662 | for (int n = 0; n < 8; ++n) { |
| 663 | _mm256_storeu_si256((__m256i *)((char *)packed_B + n * 64 + 32), v2[n]); |
| 664 | } |
| 665 | } |
| 666 | |
| 667 | template <> |
| 668 | inline void pack_qs<block_q4_K>(void * RESTRICT packed_B, const block_q4_K * RESTRICT B, int KB) { |
| 669 | __m512i v[16]; |
| 670 | // QK_K 256 with 8 groups, handle 2 groups at a time |
| 671 | char * pb = (char *)packed_B; |
| 672 | for (int k = 0; k < QK_K / 64; ++k) { |
| 673 | // pack 2 groups { n, g, k} to {g, k/4, 4n} |
| 674 | // e.g. {16, 2, 32} to {2, 8, 64} |
| 675 | for (int n = 0; n < TILE_N; ++n) { |
| 676 | v[n] = bytes_from_nibbles_64(B[n * KB].qs + k * 32); |
| 677 | } |
| 678 | |
| 679 | transpose_16x16_32bit(v); |
| 680 | |
| 681 | // pack again with 128 to fully utilize vector length |
| 682 | for (int n = 0; n < TILE_N; n += 2) { |
| 683 | _mm512_storeu_si512((__m512i *)pb, packNibbles(v[n], v[n + 1])); |
| 684 | pb += 64; |
| 685 | } |
| 686 | } |
| 687 | } |
| 688 | |
| 689 | template <> |
| 690 | inline void pack_qs<block_q5_K>(void * RESTRICT packed_B, const block_q5_K * RESTRICT B, int KB) { |
| 691 | __m512i v[16]; |
| 692 | const __m512i lowMask = _mm512_set1_epi8(0xF); |
| 693 | // QK_K 256 with 8 groups, handle 2 groups at a time |
| 694 | char * pb = (char *)packed_B; |
| 695 | char * ph = (char *)packed_B + (QK_K / 2) * TILE_N; |
| 696 | for (int k = 0; k < QK_K / 64; ++k) { |
| 697 | // pack 2 groups { n, g, k} to {g, k/4, 4n} |
| 698 | // e.g. {16, 2, 32} to {2, 8, 64} |
| 699 | for (int n = 0; n < TILE_N; ++n) { |
| 700 | v[n] = bytes_from_nibbles_64(B[n * KB].qs + k * 32, B[n * KB].qh, /* group */2 * k); |
| 701 | } |
| 702 | |
| 703 | transpose_16x16_32bit(v); |
| 704 | |
| 705 | // 1. pack lower 4bits with 2 groups |
| 706 | for (int n = 0; n < TILE_N; n += 2) { |
| 707 | // get lower 4 bits |
| 708 | const __m512i r0 = _mm512_and_si512(v[n], lowMask); |
| 709 | const __m512i r1 = _mm512_and_si512(v[n + 1], lowMask); |
| 710 | _mm512_storeu_si512((__m512i *)pb, packNibbles(r0, r1)); pb += 64; |
| 711 | } |
| 712 | |
| 713 | // 2. pack higher 1bit with 2 groups |
| 714 | const __m512i hmask = _mm512_set1_epi8(0x10); |
| 715 | for (int g = 0; g < 2; ++g) { |
| 716 | __m512i hbits = _mm512_setzero_si512(); |
| 717 | hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 0], hmask), 4)); |
| 718 | hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 1], hmask), 3)); |
| 719 | hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 2], hmask), 2)); |
| 720 | hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 3], hmask), 1)); |
| 721 | hbits = _mm512_add_epi8(hbits, _mm512_and_si512(v[g * 8 + 4], hmask) ); |
| 722 | hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 8 + 5], hmask), 1)); |
| 723 | hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 8 + 6], hmask), 2)); |
| 724 | hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 8 + 7], hmask), 3)); |
| 725 | _mm512_storeu_si512((__m512i *)ph, hbits); ph += 64; |
| 726 | } |
| 727 | } |
| 728 | } |
| 729 | |
| 730 | template <> |
| 731 | inline void pack_qs<block_q6_K>(void * RESTRICT packed_B, const block_q6_K * RESTRICT B, int KB) { |
| 732 | __m512i v[32]; |
| 733 | const __m512i lowMask = _mm512_set1_epi8(0xF); |
| 734 | // QK_K 256 with 8 groups, handle 4 groups at a time |
| 735 | char * pb = (char *)packed_B; |
| 736 | char * ph = (char *)packed_B + (QK_K / 2) * TILE_N; |
| 737 | for (int k = 0; k < QK_K / 128; ++k) { |
| 738 | for (int n = 0; n < TILE_N; ++n) { |
| 739 | bytes_from_nibbles_128(v[n], v[n + 16], B[n * KB].ql + k * 64, B[n * KB].qh + k * 32); |
| 740 | } |
| 741 | |
| 742 | // top half: group 0,1 or 4,5; bottom half: group 2,3 or 6,7 |
| 743 | transpose_16x16_32bit(v); |
| 744 | transpose_16x16_32bit(v + 16); |
| 745 | |
| 746 | // 1. pack lower 4bits with 4 groups |
| 747 | for (int n = 0; n < 32; n += 2) { |
| 748 | const __m512i r0 = _mm512_and_si512(v[n], lowMask); |
| 749 | const __m512i r1 = _mm512_and_si512(v[n + 1], lowMask); |
| 750 | _mm512_storeu_si512((__m512i *)pb, packNibbles(r0, r1)); pb += 64; |
| 751 | } |
| 752 | |
| 753 | // 2. pack higher 2bit with 4 groups |
| 754 | const __m512i hmask = _mm512_set1_epi8(0x30); |
| 755 | for (int g = 0; g < 8; ++g) { |
| 756 | __m512i hbits = _mm512_setzero_si512(); |
| 757 | hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 4 + 0], hmask), 4)); |
| 758 | hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 4 + 1], hmask), 2)); |
| 759 | hbits = _mm512_add_epi8(hbits, _mm512_and_si512(v[g * 4 + 2], hmask) ); |
| 760 | hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 4 + 3], hmask), 2)); |
| 761 | _mm512_storeu_si512((__m512i *)ph, hbits); ph += 64; |
| 762 | } |
| 763 | } |
| 764 | } |
| 765 | |
| 766 | template <> |
| 767 | inline void pack_qs<block_iq4_xs>(void * RESTRICT packed_B, const block_iq4_xs * RESTRICT B, int KB) { |
| 768 | __m512i v[16]; |
| 769 | char * pb = (char *)packed_B; |
| 770 | for (int k = 0; k < QK_K / 64; ++k) { |
| 771 | for (int n = 0; n < TILE_N; ++n) { |
| 772 | __m256i r0 = bytes_from_nibbles_32(B[n * KB].qs + k * 32 + 0); |
| 773 | __m256i r1 = bytes_from_nibbles_32(B[n * KB].qs + k * 32 + 16); |
| 774 | v[n] = _mm512_inserti32x8(_mm512_castsi256_si512(r0), r1, 1); |
| 775 | } |
| 776 | |
| 777 | transpose_16x16_32bit(v); |
| 778 | |
| 779 | // pack again with 128 to fully utilize vector length |
| 780 | for (int n = 0; n < TILE_N; n += 2) { |
| 781 | _mm512_storeu_si512((__m512i *)pb, packNibbles(v[n], v[n + 1])); |
| 782 | pb += 64; |
| 783 | } |
| 784 | } |
| 785 | } |
| 786 | |
| 787 | // pack B into vnni format, in 4 bits or 8 bits |
| 788 | void pack_B(void * RESTRICT packed_B, const block_q4_0 * RESTRICT B, int KB) { |
| 789 | pack_qs(packed_B, B, KB); |
| 790 | ggml_half * d0 = reinterpret_cast<ggml_half *>((char *)packed_B + TILE_N * TILE_K / 2); |
| 791 | for (int n = 0; n < TILE_N; ++n) { |
| 792 | d0[n] = B[n * KB].d; |
| 793 | } |
| 794 | } |
| 795 | |
| 796 | void pack_B(void * RESTRICT packed_B, const block_q4_1 * RESTRICT B, int KB) { |
| 797 | pack_qs(packed_B, B, KB); |
| 798 | ggml_half * d0 = reinterpret_cast<ggml_half *>((char *)packed_B + TILE_N * TILE_K / 2); |
| 799 | ggml_half * m0 = d0 + TILE_N; |
| 800 | for (int n = 0; n < TILE_N; ++n) { |
| 801 | d0[n] = B[n * KB].d; |
| 802 | m0[n] = B[n * KB].m; |
| 803 | } |
| 804 | } |
| 805 | |
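| | // s8s8 compensation (see the note referenced above): avx512-vnni `_mm512_dpbusd_epi32` multiplies |
| | // an unsigned operand with a signed one, so the signed A quants get biased by 128 (0x80) to make |
| | // them unsigned; that adds 128 * sum_k(B[n][k]) to every dot product. The `comp` row precomputed |
| | // here holds exactly that term so the kernels can subtract it back out. |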
| 806 | inline void s8s8_compensation(void * RESTRICT packed_B) { |
| 807 | // packed_B layout: |
| 808 | // quants {TILE_N, TILE_K} int8_t |
| 809 | // d0 {TILE_N} ggml_half |
| 810 | // comp {TILE_N} int32_t |
| 811 | const int offset = TILE_N * TILE_K + TILE_N * sizeof(ggml_half); |
| 812 | __m512i vcomp = _mm512_setzero_si512(); |
| 813 | const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80)); |
| 814 | for (int k = 0; k < 8; ++k) { |
| 815 | __m512i vb = _mm512_loadu_si512((const __m512i *)((const char *)packed_B + k * 64)); |
| 816 | vcomp = _mm512_dpbusd_epi32(vcomp, off, vb); |
| 817 | } |
| 818 | _mm512_storeu_si512((__m512i *)((char *)(packed_B) + offset), vcomp); |
| 819 | } |
| 820 | |
| 821 | void pack_B(void * RESTRICT packed_B, const block_q8_0 * RESTRICT B, int KB) { |
| 822 | pack_qs(packed_B, B, KB); |
| 823 | ggml_half * d0 = reinterpret_cast<ggml_half *>((char *)packed_B + TILE_N * TILE_K); |
| 824 | for (int n = 0; n < TILE_N; ++n) { |
| 825 | d0[n] = B[n * KB].d; |
| 826 | } |
| 827 | s8s8_compensation(packed_B); |
| 828 | } |
| 829 | |
| 830 | // convert 8 * {min, scale} from int6 to int8 |
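| | // (q4_K/q5_K pack 8 six-bit scales and 8 six-bit mins into 12 bytes; after this call utmp[0..1] |
| | // hold the 8 scales and utmp[2..3] hold the 8 mins, one byte each, as used by the callers below) |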
| 831 | inline void unpack_mins_and_scales(const uint8_t * scales, uint32_t * utmp) { |
| 832 | const uint32_t kmask1 = 0x3f3f3f3f; |
| 833 | const uint32_t kmask2 = 0x0f0f0f0f; |
| 834 | const uint32_t kmask3 = 0x03030303; |
| 835 | |
| 836 | memcpy(utmp, scales, 12); |
| 837 | utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); |
| 838 | const uint32_t uaux = utmp[1] & kmask1; |
| 839 | utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); |
| 840 | utmp[2] = uaux; |
| 841 | utmp[0] &= kmask1; |
| 842 | } |
| 843 | |
| 844 | // packed_B layout: |
| 845 | // quants {8, TILE_N, 16} uint8 |
| 846 | // scales {8, TILE_N} uint8 |
| 847 | // mins {8, TILE_N} uint8 |
| 848 | // d {TILE_N} ggml_half |
| 849 | // dmin {TILE_N} ggml_half |
| 850 | void pack_B(void * RESTRICT packed_B, const block_q4_K * RESTRICT B, int KB) { |
| 851 | pack_qs(packed_B, B, KB); |
| 852 | |
| 853 | uint8_t * scales = reinterpret_cast<uint8_t *>((char *)packed_B + (QK_K / 2) * TILE_N); |
| 854 | uint8_t * mins = scales + 8 * TILE_N; |
| 855 | ggml_half * d = reinterpret_cast<ggml_half *>(mins + 8 * TILE_N); |
| 856 | ggml_half * dmin = d + TILE_N; |
| 857 | |
| 858 | union { |
| 859 | uint32_t u32[4]; |
| 860 | uint8_t u8[16]; |
| 861 | } s; |
| 862 | |
| 863 | for (int n = 0; n < TILE_N; ++n) { |
| 864 | unpack_mins_and_scales(B[n * KB].scales, s.u32); |
| 865 | for (int k = 0; k < 8; ++k) { |
| 866 | scales[k * TILE_N + n] = s.u8[k]; |
| 867 | mins[(k >> 1) * TILE_N * 2 + n * 2 + (k & 0x1)] = s.u8[k + 8]; |
| 868 | } |
| 869 | d[n] = B[n * KB].d; |
| 870 | dmin[n] = B[n * KB].dmin; |
| 871 | } |
| 872 | } |
| 873 | |
| 874 | // packed_B layout: |
| 875 | // quants {8, TILE_N, 16} uint8 |
| 876 | // qh {8, TILE_N, 4} uint8 |
| 877 | // scales {8, TILE_N} uint8 |
| 878 | // mins {8, TILE_N} uint8 |
| 879 | // d {TILE_N} ggml_half |
| 880 | // dmin {TILE_N} ggml_half |
| 881 | void pack_B(void * RESTRICT packed_B, const block_q5_K * RESTRICT B, int KB) { |
| 882 | pack_qs(packed_B, B, KB); |
| 883 | |
| 884 | uint8_t * scales = reinterpret_cast<uint8_t *>((char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N); |
| 885 | uint8_t * mins = scales + 8 * TILE_N; |
| 886 | ggml_half * d = reinterpret_cast<ggml_half *>(mins + 8 * TILE_N); |
| 887 | ggml_half * dmin = d + TILE_N; |
| 888 | |
| 889 | union { |
| 890 | uint32_t u32[4]; |
| 891 | uint8_t u8[16]; |
| 892 | } s; |
| 893 | |
| 894 | for (int n = 0; n < TILE_N; ++n) { |
| 895 | unpack_mins_and_scales(B[n * KB].scales, s.u32); |
| 896 | for (int k = 0; k < 8; ++k) { |
| 897 | scales[k * TILE_N + n] = s.u8[k]; |
| 898 | mins[(k >> 1) * TILE_N * 2 + n * 2 + (k & 0x1)] = s.u8[k + 8]; |
| 899 | } |
| 900 | d[n] = B[n * KB].d; |
| 901 | dmin[n] = B[n * KB].dmin; |
| 902 | } |
| 903 | } |
| 904 | |
| 905 | // packed_B layout: |
| 906 | // quants {16, TILE_N, 8} uint8 |
| 907 | // qh {16, TILE_N, 4} uint8 |
| 908 | // scales {16, TILE_N} uint8 |
| 909 | // d {TILE_N} ggml_half |
| 910 | void pack_B(void * RESTRICT packed_B, const block_q6_K * RESTRICT B, int KB) { |
| 911 | pack_qs(packed_B, B, KB); |
| 912 | |
| 913 | uint8_t * scales = reinterpret_cast<uint8_t *>((char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N); |
| 914 | ggml_half * d = reinterpret_cast<ggml_half *>(scales + 16 * TILE_N); |
| 915 | for (int n = 0; n < TILE_N; ++n) { |
| 916 | const int8_t * ps = B[n * KB].scales; |
| 917 | for (int k = 0; k < 16; ++k) { |
| 918 | scales[k * TILE_N + n] = ps[k]; |
| 919 | } |
| 920 | d[n] = B[n * KB].d; |
| 921 | } |
| 922 | } |
| 923 | |
| 924 | // packed_B layout: |
| 925 | // quants {8, TILE_N, 16} uint8 |
| 926 | // scales {8, TILE_N} int8 |
| 927 | // d {TILE_N} ggml_half |
| 928 | void pack_B(void * RESTRICT packed_B, const block_iq4_xs * RESTRICT B, int KB) { |
| 929 | pack_qs(packed_B, B, KB); |
| 930 | |
| 931 | int8_t * scales = reinterpret_cast<int8_t *>((char *)packed_B + (QK_K / 2) * TILE_N); |
| 932 | ggml_half * d = reinterpret_cast<ggml_half *>(scales + 8 * TILE_N); |
| 933 | |
| 934 | // pack the scales |
| 935 | for (int n = 0; n < TILE_N; ++n) { |
| 936 | uint16_t sh = B[n * KB].scales_h; |
| 937 | for (int k = 0; k < 8; k += 2) { |
| 938 | const int16_t ls1 = ((B[n * KB].scales_l[k / 2] & 0xf) | ((sh << 4) & 0x30)) - 32; |
| 939 | const int16_t ls2 = ((B[n * KB].scales_l[k / 2] >> 4) | ((sh << 2) & 0x30)) - 32; |
| 940 | scales[(k + 0) * TILE_N + n] = ls1; |
| 941 | scales[(k + 1) * TILE_N + n] = ls2; |
| 942 | sh >>= 4; |
| 943 | } |
| 944 | d[n] = B[n * KB].d; |
| 945 | } |
| 946 | } |
| 947 | |
| 948 | template<typename TB, typename packed_B_t = packed_B_type<TB>> |
| 949 | void unpack_B(packed_B_t * RESTRICT tile, const void * RESTRICT packed_B) { |
| 950 | GGML_UNUSED(tile); |
| 951 | GGML_UNUSED(packed_B); |
| 952 | } |
| 953 | |
| 954 | template <> |
| 955 | void unpack_B<block_q4_0>(int8_t * RESTRICT tile, const void * RESTRICT packed_B) { |
| 956 | const __m512i off = _mm512_set1_epi8(8); |
| 957 | const __m512i lowMask = _mm512_set1_epi8(0xF); |
| 958 | for (int n = 0; n < 8; n += 2) { |
| 959 | __m512i bytes = _mm512_loadu_si512((const __m512i *)((const char *)packed_B + n * 32)); |
| 960 | const __m512i r0 = _mm512_sub_epi8(_mm512_and_si512(bytes, lowMask), off); |
| 961 | const __m512i r1 = _mm512_sub_epi8(_mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask), off); |
| 962 | _mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0); |
| 963 | _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1); |
| 964 | } |
| 965 | } |
| 966 | |
| 967 | template <> |
| 968 | void unpack_B<block_q4_1>(uint8_t * RESTRICT tile, const void * RESTRICT packed_B) { |
| 969 | const __m512i lowMask = _mm512_set1_epi8(0xF); |
| 970 | for (int n = 0; n < 8; n += 2) { |
| 971 | __m512i bytes = _mm512_loadu_si512((const __m512i *)((const char *)packed_B + n * 32)); |
| 972 | const __m512i r0 = _mm512_and_si512(bytes, lowMask); |
| 973 | const __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); |
| 974 | _mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0); |
| 975 | _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1); |
| 976 | } |
| 977 | } |
| 978 | |
| 979 | // packed_B_t for QKK is int8_t |
| 980 | template <typename TB> |
| 981 | void unpack_B(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) { |
| 982 | const int packed_B_group_size = QK_K / 2 * TILE_N / 8; |
| 983 | const char * packed_B_group = (const char *)packed_B + k * packed_B_group_size; |
| 984 | const __m512i lowMask = _mm512_set1_epi8(0xF); |
| 985 | for (int n = 0; n < 8; n += 2) { |
| 986 | __m512i bytes = _mm512_loadu_si512(packed_B_group + n * 32); |
| 987 | const __m512i r0 = _mm512_and_si512(bytes, lowMask); |
| 988 | const __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); |
| 989 | _mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0); |
| 990 | _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1); |
| 991 | } |
| 992 | } |
| 993 | |
| 994 | template <> |
| 995 | void unpack_B<block_q5_K>(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) { |
| 996 | // lower 4bits, stride 256 bytes |
| 997 | const int packed_l4_group_size = QK_K / 2 * TILE_N / 8; |
| 998 | const char * pb = (const char *)packed_B + k * packed_l4_group_size; |
| 999 | |
| 1000 | // higher 1bit, stride 64 bytes |
| 1001 | const int packed_h1_group_size = QK_K / 8 * TILE_N / 8; |
| 1002 | const char * ph = (const char *)packed_B + (QK_K / 2) * TILE_N + k * packed_h1_group_size; |
| 1003 | const __m512i hbits = _mm512_loadu_si512(ph); |
| 1004 | |
| 1005 | const __m512i lowMask = _mm512_set1_epi8(0xF); |
| 1006 | __m512i hmask0 = _mm512_set1_epi8(0x1); |
| 1007 | __m512i hmask1 = _mm512_set1_epi8(0x2); |
| 1008 | |
| 1009 | for (int n = 0; n < 8; n += 2) { |
| 1010 | __m512i bytes = _mm512_loadu_si512(pb + n * 32); |
| 1011 | __m512i r0 = _mm512_and_si512(bytes, lowMask); |
| 1012 | __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); |
| 1013 | __m512i h0 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask0), n), 4); |
| 1014 | __m512i h1 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask1), n + 1), 4); |
| 1015 | |
| 1016 | hmask0 = _mm512_slli_epi16(hmask0, 2); |
| 1017 | hmask1 = _mm512_slli_epi16(hmask1, 2); |
| 1018 | r0 = _mm512_add_epi8(r0, h0); |
| 1019 | r1 = _mm512_add_epi8(r1, h1); |
| 1020 | _mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0); |
| 1021 | _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1); |
| 1022 | } |
| 1023 | } |
| 1024 | |
| 1025 | template <> |
| 1026 | void unpack_B<block_q6_K>(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) { |
| 1027 | // lower 4bits, stride 128 bytes |
| 1028 | const int packed_l4_group_size = QK_K / 2 * TILE_N / 16; |
| 1029 | const char * pb = (const char *)packed_B + k * packed_l4_group_size; |
| 1030 | |
| 1031 | // higher 2bits, stride 64 bytes |
| 1032 | const int packed_h2_group_size = QK_K / 4 * TILE_N / 16; |
| 1033 | const char * ph = (const char *)packed_B + (QK_K / 2) * TILE_N + k * packed_h2_group_size; |
| 1034 | const __m512i hbits = _mm512_loadu_si512(ph); |
| 1035 | |
| 1036 | const __m512i off = _mm512_set1_epi8(32); |
| 1037 | const __m512i lowMask = _mm512_set1_epi8(0xF); |
| 1038 | __m512i hmask0 = _mm512_set1_epi8(0x3); // 0011 |
| 1039 | __m512i hmask1 = _mm512_set1_epi8(0xC); // 1100 |
| 1040 | |
| 1041 | // note: zero padding for rows 4 to 7 is skipped here, as it has already been done in `unpack_A` |
| 1042 | __m512i bytes = _mm512_loadu_si512(pb); |
| 1043 | __m512i r0 = _mm512_and_si512(bytes, lowMask); |
| 1044 | __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); |
| 1045 | __m512i h0 = _mm512_slli_epi16(_mm512_and_si512(hbits, hmask0), 4); |
| 1046 | __m512i h1 = _mm512_slli_epi16(_mm512_and_si512(hbits, hmask1), 2); |
| 1047 | _mm512_storeu_si512((__m512i *)(tile + 0), _mm512_sub_epi8(_mm512_add_epi8(r0, h0), off)); |
| 1048 | _mm512_storeu_si512((__m512i *)(tile + 64), _mm512_sub_epi8(_mm512_add_epi8(r1, h1), off)); |
| 1049 | |
| 1050 | hmask0 = _mm512_slli_epi16(hmask0, 4); |
| 1051 | hmask1 = _mm512_slli_epi16(hmask1, 4); |
| 1052 | |
| 1053 | bytes = _mm512_loadu_si512(pb + 64); |
| 1054 | r0 = _mm512_and_si512(bytes, lowMask); |
| 1055 | r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); |
| 1056 | h0 = _mm512_and_si512(hbits, hmask0); |
| 1057 | h1 = _mm512_srli_epi16(_mm512_and_si512(hbits, hmask1), 2); |
| 1058 | _mm512_storeu_si512((__m512i *)(tile + 128), _mm512_sub_epi8(_mm512_add_epi8(r0, h0), off)); |
| 1059 | _mm512_storeu_si512((__m512i *)(tile + 192), _mm512_sub_epi8(_mm512_add_epi8(r1, h1), off)); |
| 1060 | } |
| 1061 | |
| 1062 | template <> |
| 1063 | void unpack_B<block_iq4_xs>(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) { |
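| | // values128 replicates the 16-entry iq4_nl codebook into each 128-bit lane so that |
| | // _mm512_shuffle_epi8 can map 4-bit indices to their int8 values within each lane |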
| 1064 | static const __m512i values128 = _mm512_set_epi8( |
| 1065 | 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127, |
| 1066 | 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127, |
| 1067 | 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127, |
| 1068 | 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127 |
| 1069 | ); |
| 1070 | |
| 1071 | const int packed_B_group_size = QK_K / 2 * TILE_N / 8; |
| 1072 | const char * pb = (const char *)packed_B + k * packed_B_group_size; |
| 1073 | const __m512i lowMask = _mm512_set1_epi8(0xF); |
| 1074 | |
| 1075 | for (int n = 0; n < 8; n += 2) { |
| 1076 | __m512i bytes = _mm512_loadu_si512(pb + n * 32); |
| 1077 | const __m512i r0 = _mm512_shuffle_epi8(values128, _mm512_and_si512(bytes, lowMask)); |
| 1078 | const __m512i r1 = _mm512_shuffle_epi8(values128, _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask)); |
| 1079 | _mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0); |
| 1080 | _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1); |
| 1081 | } |
| 1082 | } |
| 1083 | |
| 1084 | template <typename TA, typename TB, bool is_acc> |
| 1085 | struct acc_C {}; |
| 1086 | |
| 1087 | template <bool is_acc> |
| 1088 | struct acc_C<block_q8_0, block_q4_0, is_acc> { |
| 1089 | static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_0 * A, int lda, const void * packed_B, int nr) { |
| 1090 | const int offset = TILE_N * TILE_K / 2; |
| 1091 | const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset))); |
| 1092 | |
| 1093 | for (int m = 0; m < nr; ++m) { |
| 1094 | const __m512 vd1 = _mm512_set1_ps(GGML_CPU_FP16_TO_FP32(A[m * lda].d)); |
| 1095 | const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N)); |
| 1096 | |
| 1097 | __m512 vsum; |
| 1098 | if (is_acc) { |
| 1099 | vsum = _mm512_loadu_ps(C + m * ldc); |
| 1100 | } else { |
| 1101 | vsum = _mm512_set1_ps(0.f); |
| 1102 | } |
| 1103 | vsum = _mm512_fmadd_ps(vtile, _mm512_mul_ps(vd0, vd1), vsum); |
| 1104 | _mm512_storeu_ps(C + m * ldc, vsum); |
| 1105 | } |
| 1106 | } |
| 1107 | }; |
| 1108 | |
| 1109 | template <bool is_acc> |
| 1110 | struct acc_C<block_q8_1, block_q4_1, is_acc> { |
| 1111 | static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_1 * A, int lda, const void * packed_B, int nr) { |
| 1112 | const int offset = TILE_N * TILE_K / 2; |
| 1113 | const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset))); |
| 1114 | const __m512 vm0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset + TILE_N * sizeof(ggml_half)))); |
| 1115 | |
| 1116 | for (int m = 0; m < nr; ++m) { |
| 1117 | const __m512 vd1 = _mm512_set1_ps(GGML_CPU_FP16_TO_FP32(A[m * lda].d)); |
| 1118 | const __m512 vs1 = _mm512_set1_ps(GGML_CPU_FP16_TO_FP32(A[m * lda].s)); |
| 1119 | const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N)); |
| 1120 | |
| 1121 | __m512 vsum; |
| 1122 | if (is_acc) { |
| 1123 | vsum = _mm512_loadu_ps(C + m * ldc); |
| 1124 | } else { |
| 1125 | vsum = _mm512_set1_ps(0.f); |
| 1126 | } |
| 1127 | vsum = _mm512_fmadd_ps(vtile, _mm512_mul_ps(vd0, vd1), vsum); |
| 1128 | vsum = _mm512_fmadd_ps(vm0, vs1, vsum); |
| 1129 | _mm512_storeu_ps(C + m * ldc, vsum); |
| 1130 | } |
| 1131 | } |
| 1132 | }; |
| 1133 | |
| 1134 | template <bool is_acc> |
| 1135 | struct acc_C<block_q8_0, block_q8_0, is_acc> { |
| 1136 | static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_0 * A, int lda, const void * packed_B, int nr) { |
| 1137 | const int offset = TILE_N * TILE_K; |
| 1138 | const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset))); |
| 1139 | |
| 1140 | for (int m = 0; m < nr; ++m) { |
| 1141 | const __m512 vd1 = _mm512_set1_ps(GGML_CPU_FP16_TO_FP32(A[m * lda].d)); |
| 1142 | const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N)); |
| 1143 | |
| 1144 | __m512 vsum; |
| 1145 | if (is_acc) { |
| 1146 | vsum = _mm512_loadu_ps(C + m * ldc); |
| 1147 | } else { |
| 1148 | vsum = _mm512_set1_ps(0.f); |
| 1149 | } |
| 1150 | vsum = _mm512_fmadd_ps(vtile, _mm512_mul_ps(vd0, vd1), vsum); |
| 1151 | _mm512_storeu_ps(C + m * ldc, vsum); |
| 1152 | } |
| 1153 | } |
| 1154 | }; |
| 1155 | |
| 1156 | template <bool is_acc> |
| 1157 | struct acc_C<block_q8_K, block_q4_K, is_acc> { |
| 1158 | static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) { |
| 1159 | const uint8_t * scales = reinterpret_cast<const uint8_t *>((const char *)packed_B + (QK_K / 2) * TILE_N); |
| 1160 | const uint8_t * mins = scales + 8 * TILE_N; |
| 1161 | const ggml_half * d0 = reinterpret_cast<const ggml_half *>(mins + 8 * TILE_N); |
| 1162 | const ggml_half * dmin = d0 + TILE_N; |
| 1163 | |
| 1164 | const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0)); |
| 1165 | const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)dmin)); |
| 1166 | |
| 1167 | for (int m = 0; m < nr; ++m) { |
| 1168 | const float d1 = A[m * lda].d; |
| 1169 | const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0); |
| 1170 | const __m512 vdm = _mm512_mul_ps(_mm512_set1_ps(-d1), vdmin); |
| 1171 | const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N)); |
| 1172 | |
| 1173 | __m512 vsum; |
| 1174 | if (is_acc) { |
| 1175 | vsum = _mm512_loadu_ps(C + m * ldc); |
| 1176 | } else { |
| 1177 | vsum = _mm512_set1_ps(0.f); |
| 1178 | } |
| 1179 | |
| 1180 | const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[m * lda].bsums); |
| 1181 | const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); |
| 1182 | |
| 1183 | __m512i acc_m = _mm512_setzero_si512(); |
| 1184 | for (int k = 0; k < 4; ++k) { |
| 1185 | __m512i vmask = _mm512_set1_epi32(k); |
| 1186 | __m512i va = _mm512_permutexvar_epi32(vmask, _mm512_castsi128_si512(q8s)); |
| 1187 | __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(mins + k * 32))); |
| 1188 | acc_m = _mm512_dpwssds_epi32(acc_m, va, vb); |
| 1189 | } |
| 1190 | |
| 1191 | vsum = _mm512_fmadd_ps(vtile, vd, vsum); |
| 1192 | vsum = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc_m), vdm, vsum); |
| 1193 | _mm512_storeu_ps(C + m * ldc, vsum); |
| 1194 | } |
| 1195 | } |
| 1196 | }; |
| 1197 | |
| 1198 | template <bool is_acc> |
| 1199 | struct acc_C<block_q8_K, block_q5_K, is_acc> { |
| 1200 | static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) { |
| 1201 | const uint8_t * scales = reinterpret_cast<const uint8_t *>((const char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N); |
| 1202 | const uint8_t * mins = scales + 8 * TILE_N; |
| 1203 | const ggml_half * d0 = reinterpret_cast<const ggml_half *>(mins + 8 * TILE_N); |
| 1204 | const ggml_half * dmin = d0 + TILE_N; |
| 1205 | |
| 1206 | const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0)); |
| 1207 | const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)dmin)); |
| 1208 | |
| 1209 | for (int m = 0; m < nr; ++m) { |
| 1210 | const float d1 = A[m * lda].d; |
| 1211 | const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0); |
| 1212 | const __m512 vdm = _mm512_mul_ps(_mm512_set1_ps(-d1), vdmin); |
| 1213 | const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N)); |
| 1214 | |
| 1215 | __m512 vsum; |
| 1216 | if (is_acc) { |
| 1217 | vsum = _mm512_loadu_ps(C + m * ldc); |
| 1218 | } else { |
| 1219 | vsum = _mm512_set1_ps(0.f); |
| 1220 | } |
| 1221 | |
| 1222 | const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[m * lda].bsums); |
| 1223 | const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); |
| 1224 | |
| 1225 | __m512i acc_m = _mm512_setzero_si512(); |
| 1226 | for (int k = 0; k < 4; ++k) { |
| 1227 | __m512i vmask = _mm512_set1_epi32(k); |
| 1228 | __m512i va = _mm512_permutexvar_epi32(vmask, _mm512_castsi128_si512(q8s)); |
| 1229 | __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(mins + k * 32))); |
| 1230 | acc_m = _mm512_dpwssds_epi32(acc_m, va, vb); |
| 1231 | } |
| 1232 | |
| 1233 | vsum = _mm512_fmadd_ps(vtile, vd, vsum); |
| 1234 | vsum = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc_m), vdm, vsum); |
| 1235 | _mm512_storeu_ps(C + m * ldc, vsum); |
| 1236 | } |
| 1237 | } |
| 1238 | }; |
| 1239 | |
| 1240 | template <bool is_acc> |
| 1241 | struct acc_C<block_q8_K, block_q6_K, is_acc> { |
| 1242 | static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) { |
| 1243 | const uint8_t * scales = reinterpret_cast<const uint8_t *>((const char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N); |
| 1244 | const ggml_half * d0 = reinterpret_cast<const ggml_half *>(scales + 16 * TILE_N); |
| 1245 | |
| 1246 | const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0)); |
| 1247 | |
| 1248 | for (int m = 0; m < nr; ++m) { |
| 1249 | const float d1 = A[m * lda].d; |
| 1250 | const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0); |
| 1251 | const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N)); |
| 1252 | |
| 1253 | __m512 vsum; |
| 1254 | if (is_acc) { |
| 1255 | vsum = _mm512_loadu_ps(C + m * ldc); |
| 1256 | } else { |
| 1257 | vsum = _mm512_set1_ps(0.f); |
| 1258 | } |
| 1259 | |
| 1260 | vsum = _mm512_fmadd_ps(vtile, vd, vsum); |
| 1261 | _mm512_storeu_ps(C + m * ldc, vsum); |
| 1262 | } |
| 1263 | } |
| 1264 | }; |
| 1265 | |
| 1266 | template <bool is_acc> |
| 1267 | struct acc_C<block_q8_K, block_iq4_xs, is_acc> { |
| 1268 | static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) { |
| 1269 | const int8_t * scales = reinterpret_cast<const int8_t *>((const char *)packed_B + (QK_K / 2) * TILE_N); |
| 1270 | const ggml_half * d0 = reinterpret_cast<const ggml_half *>(scales + 8 * TILE_N); |
| 1271 | |
| 1272 | const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0)); |
| 1273 | |
| 1274 | for (int m = 0; m < nr; ++m) { |
| 1275 | const float d1 = A[m * lda].d; |
| 1276 | const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0); |
| 1277 | const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N)); |
| 1278 | |
| 1279 | __m512 vsum; |
| 1280 | if (is_acc) { |
| 1281 | vsum = _mm512_loadu_ps(C + m * ldc); |
| 1282 | } else { |
| 1283 | vsum = _mm512_set1_ps(0.f); |
| 1284 | } |
| 1285 | |
| 1286 | vsum = _mm512_fmadd_ps(vtile, vd, vsum); |
| 1287 | _mm512_storeu_ps(C + m * ldc, vsum); |
| 1288 | } |
| 1289 | } |
| 1290 | }; |
| 1291 | |
| 1292 | template <typename TB> constexpr int get_quants_size(); |
| 1293 | template <> constexpr int get_quants_size<block_q4_K>() { return (QK_K / 2) * TILE_N; } |
| 1294 | template <> constexpr int get_quants_size<block_q5_K>() { return (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N; } |
| 1295 | template <> constexpr int get_quants_size<block_q6_K>() { return (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N; } |
| 1296 | template <> constexpr int get_quants_size<block_iq4_xs>() { return (QK_K / 2) * TILE_N; } |
| 1297 | |
| 1298 | // used for QKK format |
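|      | // scale_C computes sumi += tile * scale, where scale is B's int8 sub-block scale
|      | // for group k, one value per each of the TILE_N output columns.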
| 1299 | template <typename TB, bool is_acc, |
| 1300 | typename std::enable_if<is_type_qkk<TB>::value, int>::type = 0> |
| 1301 | inline void scale_C(const int32_t * RESTRICT tile, int32_t * RESTRICT sumi, const void * packed_B, int k, int nr) { |
| 1302 | const uint8_t * scales = reinterpret_cast<const uint8_t *>((const char *)packed_B + get_quants_size<TB>()); |
| 1303 | const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(scales + k * TILE_N))); |
| 1304 | |
| 1305 | for (int m = 0; m < nr; ++m) { |
| 1306 | __m512i vsumi; |
| 1307 | if (is_acc) { |
| 1308 | vsumi = _mm512_loadu_si512(sumi + m * TILE_N); |
| 1309 | } else { |
| 1310 | vsumi = _mm512_setzero_si512(); |
| 1311 | } |
| 1312 | __m512i vtile = _mm512_loadu_si512(tile + m * TILE_N); |
| 1313 | vsumi = _mm512_add_epi32(vsumi, _mm512_mullo_epi32(vtile, vscale)); |
| 1314 | _mm512_storeu_si512((__m512i *)(sumi + m * TILE_N), vsumi); |
| 1315 | } |
| 1316 | } |
| 1317 | |
| 1318 | template <typename TA, typename TB, typename TC, int BLOCK_M, int BLOCK_N, int BLOCK_K> |
| 1319 | struct tinygemm_kernel_avx { |
| 1320 | static void apply(int K, const TA * RESTRICT A, const TB * RESTRICT B, TC * RESTRICT C, int ldc) { |
| 1321 | GGML_UNUSED(K); |
| 1322 | GGML_UNUSED(A); |
| 1323 | GGML_UNUSED(B); |
| 1324 | GGML_UNUSED(C); |
| 1325 | GGML_UNUSED(ldc); |
| 1326 | } |
| 1327 | }; |
| 1328 | |
| 1329 | template <int BLOCK_M, int BLOCK_N, int BLOCK_K> |
| 1330 | struct tinygemm_kernel_avx<float, ggml_fp16_t, float, BLOCK_M, BLOCK_N, BLOCK_K> { |
| 1331 | static void apply(int K, const float * RESTRICT A, const ggml_fp16_t * RESTRICT B, float * RESTRICT C, int ldc) { |
| 1332 | constexpr int ROWS = BLOCK_M; |
| 1333 | constexpr int COLS = BLOCK_N; |
| 1334 | assert(BLOCK_K == 16); |
| 1335 | |
| 1336 | __m512 va; |
| 1337 | __m512 vb[COLS]; |
| 1338 | __m512 vc[ROWS * COLS]; |
| 1339 | |
| 1340 | auto loadc = [&](auto idx) { |
| 1341 | vc[idx] = _mm512_setzero_ps(); |
| 1342 | }; |
| 1343 | Unroll<ROWS * COLS>{}(loadc); |
| 1344 | |
| 1345 | auto compute = [&](auto idx, auto k) { |
| 1346 | constexpr int row = idx / COLS; |
| 1347 | constexpr int col = idx % COLS; |
| 1348 | |
| 1349 | if constexpr (col == 0) { |
| 1350 | va = _mm512_loadu_ps(A + row * K + k); |
| 1351 | } |
| 1352 | if constexpr (row == 0) { |
| 1353 | vb[col] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(B + col * K + k))); |
| 1354 | } |
| 1355 | vc[idx] = _mm512_fmadd_ps(va, vb[col], vc[idx]); |
| 1356 | }; |
| 1357 | |
| 1358 | for (int k = 0; k < K; k += 16) { |
| 1359 | Unroll<ROWS * COLS>{}(compute, k); |
| 1360 | } |
| 1361 | |
| 1362 | auto storec = [&](auto idx) { |
| 1363 | constexpr int row = idx / COLS; |
| 1364 | constexpr int col = idx % COLS; |
| 1365 | C[row * ldc + col] = _mm512_reduce_add_ps(vc[idx]); |
| 1366 | }; |
| 1367 | Unroll<ROWS * COLS>{}(storec); |
| 1368 | } |
| 1369 | }; |
| 1370 | |
| 1371 | #define LAUNCH_TINYGEMM_KERNEL_AVX(MB_SIZE, NB_SIZE) \ |
| 1372 | tinygemm_kernel_avx<float, type, float, MB_SIZE, NB_SIZE, blck_size>::apply( \ |
| 1373 | K, (const float *)src1->data + mb_start * K, \ |
| 1374 | (const type *)src0->data + nb_start * K, \ |
| 1375 | (float *)dst->data + mb_start * ldc + nb_start, ldc); |
| 1376 | |
| 1377 | |
| 1378 | // re-organize in the format {NB, KB, TILE_SIZE}: |
| 1379 | #define PACKED_INDEX(n, k, KB, tile_size) (n * KB + k) * tile_size |
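|      | // e.g. the packed tile for (n-block n, k-block k) starts at byte offset
|      | //   (n * KB + k) * tile_size
|      | // from the base of packed_B, i.e. tiles are laid out n-major, then k.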
| 1380 | |
| 1381 | template<typename TB, int BLOCK_K> |
| 1382 | void convert_B_packed_format(void * RESTRICT packed_B, const TB * RESTRICT B, int N, int K) { |
| 1383 | const int NB = N / TILE_N; |
| 1384 | const int KB = K / BLOCK_K; |
| 1385 | const int TILE_SIZE = get_tile_size<TB>(); |
| 1386 | |
| 1387 | // parallel on NB should be enough |
| 1388 | parallel_for(NB, [&](int begin, int end) { |
| 1389 | for (int n = begin; n < end; ++n) { |
| 1390 | for (int k = 0; k < KB; ++k) { |
| 1391 | int n0 = n * TILE_N; |
| 1392 | pack_B((char *)packed_B + PACKED_INDEX(n, k, KB, TILE_SIZE), &B[n0 * KB + k], KB); |
| 1393 | } |
| 1394 | } |
| 1395 | }); |
| 1396 | } |
| 1397 | |
| 1398 | template <typename TA, typename TB, typename TC, int BLOCK_M, int BLOCK_N, int BLOCK_K> |
| 1399 | struct tinygemm_kernel_vnni {}; |
| 1400 | |
| 1401 | template <int BLOCK_M, int BLOCK_N, int BLOCK_K> |
| 1402 | struct tinygemm_kernel_vnni<block_q8_0, block_q4_0, float, BLOCK_M, BLOCK_N, BLOCK_K> { |
| 1403 | static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) { |
| 1404 | |
| 1405 | constexpr int COLS = BLOCK_N / 16; |
| 1406 | const int TILE_SIZE = TILE_N * sizeof(block_q4_0); |
| 1407 | |
| 1408 | const block_q8_0 * RESTRICT A = static_cast<const block_q8_0 *>(_A); |
| 1409 | const char * RESTRICT B = static_cast<const char *>(_B); |
| 1410 | |
| 1411 | __m512i va[8]; |
| 1412 | __m512 vc[COLS]; |
| 1413 | __m512 vd1; |
| 1414 | |
| 1415 | // sum of offsets, shared across COLS |
| 1416 | // |
| 1417 | // avx512-vnni does not have `_mm512_dpbssd_epi32`, |
| 1418 |         // need to transform ss to us:
| 1419 |         // a * (b - 8) is equivalent to b * a - 8 * a
| 1420 | // s u u u s u s |
| 1421 | // |
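|      |         // scalar sketch of the compensation (illustration only):
|      |         //   sum_k a[k] * (q[k] - 8) = sum_k q[k] * a[k]  -  8 * sum_k a[k]
|      |         // the first term maps to _mm512_dpbusd_epi32 (q unsigned, a signed); the second
|      |         // term is accumulated once per K block into vcomp (same value in every lane)
|      |         // and subtracted from each column's sum.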
| 1422 | __m512i vcomp; |
| 1423 | |
| 1424 | const __m512i off = _mm512_set1_epi8(8); |
| 1425 | const __m512i lowMask = _mm512_set1_epi8(0xF); |
| 1426 | |
| 1427 | auto loadc = [&](auto col) { |
| 1428 | vc[col] = _mm512_setzero_ps(); |
| 1429 | }; |
| 1430 | Unroll<COLS>{}(loadc); |
| 1431 | |
| 1432 | auto compute = [&](auto col, auto i) { |
| 1433 | // load a and compute compensation |
| 1434 | if constexpr (col == 0) { |
| 1435 | const int32_t * a_ptr = reinterpret_cast<const int32_t *>(A[0 * KB + i].qs); |
| 1436 | vcomp = _mm512_setzero_si512(); |
| 1437 | for (int k = 0; k < 8; ++k) { |
| 1438 | va[k] = _mm512_set1_epi32(a_ptr[k]); |
| 1439 | vcomp = _mm512_dpbusd_epi32(vcomp, off, va[k]); |
| 1440 | } |
| 1441 | vd1 = _mm512_set1_ps(GGML_CPU_FP16_TO_FP32(A[0 * KB + i].d)); |
| 1442 | } |
| 1443 | |
| 1444 | // load b |
| 1445 | __m512i vsum = _mm512_setzero_si512(); |
| 1446 | const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE); |
| 1447 | for (int k = 0; k < 8; k += 2) { |
| 1448 | __m512i bytes = _mm512_loadu_si512((const __m512i *)(b_ptr + k * 32)); |
| 1449 | __m512i vb0 = _mm512_and_si512(bytes, lowMask); |
| 1450 | vsum = _mm512_dpbusd_epi32(vsum, vb0, va[k + 0]); |
| 1451 | __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); |
| 1452 | vsum = _mm512_dpbusd_epi32(vsum, vb1, va[k + 1]); |
| 1453 | } |
| 1454 | const int offset = TILE_N * TILE_K / 2; |
| 1455 | const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset))); |
| 1456 | vsum = _mm512_sub_epi32(vsum, vcomp); |
| 1457 | |
| 1458 | vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(vsum), _mm512_mul_ps(vd0, vd1), vc[col]); |
| 1459 | }; |
| 1460 | |
| 1461 | for (int i = 0; i < KB; ++i) { |
| 1462 | Unroll<COLS>{}(compute, i); |
| 1463 | } |
| 1464 | |
| 1465 | //store to C |
| 1466 | auto storec = [&](auto col) { |
| 1467 | _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); |
| 1468 | }; |
| 1469 | Unroll<COLS>{}(storec); |
| 1470 | } |
| 1471 | }; |
| 1472 | |
| 1473 | template <int BLOCK_N, int BLOCK_K> |
| 1474 | struct tinygemm_kernel_vnni<block_q8_1, block_q4_1, float, 1, BLOCK_N, BLOCK_K> { |
| 1475 | static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) { |
| 1476 | |
| 1477 | constexpr int COLS = BLOCK_N / 16; |
| 1478 | const int TILE_SIZE = TILE_N * sizeof(block_q4_1); |
| 1479 | |
| 1480 | const block_q8_1 * RESTRICT A = static_cast<const block_q8_1 *>(_A); |
| 1481 | const char * RESTRICT B = static_cast<const char *>(_B); |
| 1482 | |
| 1483 | __m512i va[8]; |
| 1484 | __m512i vb[8]; |
| 1485 | __m512 vc[COLS]; |
| 1486 | __m512 vd1, vs1; |
| 1487 | |
| 1488 | const __m512i lowMask = _mm512_set1_epi8(0xF); |
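|      |         // q4_1 dequantizes as w = d0 * q + m0, so the dot product splits into
|      |         //   d0 * d1 * sum(q * a)  +  m0 * s1,   with s1 = d1 * sum(a) (block_q8_1.s)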
| 1489 | |
| 1490 | auto loadc = [&](auto col) { |
| 1491 | vc[col] = _mm512_setzero_ps(); |
| 1492 | }; |
| 1493 | Unroll<COLS>{}(loadc); |
| 1494 | |
| 1495 | auto compute = [&](auto col, auto i) { |
| 1496 | // load a |
| 1497 | if constexpr (col == 0) { |
| 1498 | const int32_t * a_ptr = reinterpret_cast<const int32_t *>(A[0 * KB + i].qs); |
| 1499 | for (int k = 0; k < 8; ++k) { |
| 1500 | va[k] = _mm512_set1_epi32(a_ptr[k]); |
| 1501 | } |
| 1502 | vd1 = _mm512_set1_ps(GGML_CPU_FP16_TO_FP32(A[0 * KB + i].d)); |
| 1503 | vs1 = _mm512_set1_ps(GGML_CPU_FP16_TO_FP32(A[0 * KB + i].s)); |
| 1504 | } |
| 1505 | |
| 1506 | // load b |
| 1507 | const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE); |
| 1508 | for (int k = 0; k < 8; k += 2) { |
| 1509 | __m512i bytes = _mm512_loadu_si512((const __m512i *)(b_ptr + k * 32)); |
| 1510 | vb[k + 0] = _mm512_and_si512(bytes, lowMask); |
| 1511 | vb[k + 1] = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); |
| 1512 | } |
| 1513 | const int offset = TILE_N * TILE_K / 2; |
| 1514 | const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset))); |
| 1515 | const __m512 vm0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset + TILE_N * sizeof(ggml_half)))); |
| 1516 | |
| 1517 | __m512i vsum = _mm512_setzero_si512(); |
| 1518 | for (int k = 0; k < 8; ++k) { |
| 1519 | vsum = _mm512_dpbusd_epi32(vsum, vb[k], va[k]); |
| 1520 | } |
| 1521 | |
| 1522 | vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(vsum), _mm512_mul_ps(vd0, vd1), vc[col]); |
| 1523 | vc[col] = _mm512_fmadd_ps(vm0, vs1, vc[col]); |
| 1524 | }; |
| 1525 | |
| 1526 | for (int i = 0; i < KB; ++i) { |
| 1527 | Unroll<COLS>{}(compute, i); |
| 1528 | } |
| 1529 | |
| 1530 | //store to C |
| 1531 | auto storec = [&](auto col) { |
| 1532 | _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); |
| 1533 | }; |
| 1534 | Unroll<COLS>{}(storec); |
| 1535 | } |
| 1536 | }; |
| 1537 | |
| 1538 | template <int BLOCK_M, int BLOCK_N, int BLOCK_K> |
| 1539 | struct tinygemm_kernel_vnni<block_q8_0, block_q8_0, float, BLOCK_M, BLOCK_N, BLOCK_K> { |
| 1540 | static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) { |
| 1541 | |
| 1542 | constexpr int COLS = BLOCK_N / 16; |
| 1543 | const int TILE_SIZE = TILE_N * sizeof(block_q8_0) + TILE_N * sizeof(int32_t); |
| 1544 | |
| 1545 | const block_q8_0 * RESTRICT A = static_cast<const block_q8_0 *>(_A); |
| 1546 | const char * RESTRICT B = static_cast<const char *>(_B); |
| 1547 | |
| 1548 | __m512i va[8]; |
| 1549 | __m512i vb[8]; |
| 1550 | __m512 vc[COLS]; |
| 1551 | __m512 vd1; |
| 1552 | |
| 1553 | // Notes: s8s8 igemm compensation in avx512-vnni |
| 1554 |         // change s8s8 to u8s8 with compensation
| 1555 | // a * b = (a + 128) * b - 128 * b |
| 1556 | // s s u s u s |
| 1557 | // |
| 1558 | // (128 * b is pre-computed when packing B to vnni formats) |
| 1559 | // |
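|      |         // scalar sketch (illustration only):
|      |         //   sum_k a[k] * b[k] = sum_k (a[k] + 128) * b[k]  -  sum_k 128 * b[k]
|      |         // (a + 128) fits in u8 so the first term maps to _mm512_dpbusd_epi32;
|      |         // the 128 * b term is the per-column vcomp vector read from the packed tile.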
| 1560 | const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80)); |
| 1561 | |
| 1562 | auto loadc = [&](auto col) { |
| 1563 | vc[col] = _mm512_setzero_ps(); |
| 1564 | }; |
| 1565 | Unroll<COLS>{}(loadc); |
| 1566 | |
| 1567 | auto compute = [&](auto col, auto i) { |
| 1568 | // load a and add offset 128 |
| 1569 | if constexpr (col == 0) { |
| 1570 | const int32_t * a_ptr = reinterpret_cast<const int32_t *>(A[0 * KB + i].qs); |
| 1571 | for (int k = 0; k < 8; ++k) { |
| 1572 | va[k] = _mm512_set1_epi32(a_ptr[k]); |
| 1573 | va[k] = _mm512_add_epi8(va[k], off); |
| 1574 | } |
| 1575 | vd1 = _mm512_set1_ps(GGML_CPU_FP16_TO_FP32(A[0 * KB + i].d)); |
| 1576 | } |
| 1577 | |
| 1578 | // load b |
| 1579 | const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE); |
| 1580 | for (int k = 0; k < 8; ++k) { |
| 1581 | vb[k] = _mm512_loadu_si512((const __m512i *)(b_ptr + k * 64)); |
| 1582 | } |
| 1583 | const int offset = TILE_N * TILE_K; |
| 1584 | const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset))); |
| 1585 | const int offset2 = TILE_N * TILE_K + TILE_N * sizeof(ggml_half); |
| 1586 | const __m512i vcomp = _mm512_loadu_si512((const __m512i *)(b_ptr + offset2)); |
| 1587 | |
| 1588 | __m512i vsum = _mm512_setzero_si512(); |
| 1589 | for (int k = 0; k < 8; ++k) { |
| 1590 | vsum = _mm512_dpbusd_epi32(vsum, va[k], vb[k]); |
| 1591 | } |
| 1592 | vsum = _mm512_sub_epi32(vsum, vcomp); |
| 1593 | |
| 1594 | vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(vsum), _mm512_mul_ps(vd0, vd1), vc[col]); |
| 1595 | }; |
| 1596 | |
| 1597 | for (int i = 0; i < KB; ++i) { |
| 1598 | Unroll<COLS>{}(compute, i); |
| 1599 | } |
| 1600 | |
| 1601 | //store to C |
| 1602 | auto storec = [&](auto col) { |
| 1603 | _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); |
| 1604 | }; |
| 1605 | Unroll<COLS>{}(storec); |
| 1606 | } |
| 1607 | }; |
| 1608 | |
| 1609 | template <int BLOCK_M, int BLOCK_N, int BLOCK_K> |
| 1610 | struct tinygemm_kernel_vnni<block_q8_K, block_q4_K, float, BLOCK_M, BLOCK_N, BLOCK_K> { |
| 1611 | static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) { |
| 1612 | |
| 1613 | constexpr int COLS = BLOCK_N / 16; |
| 1614 | const int TILE_SIZE = TILE_N * sizeof(block_q4_K) + TILE_N * 4; |
| 1615 | |
| 1616 | const block_q8_K * RESTRICT A = static_cast<const block_q8_K *>(_A); |
| 1617 | const char * RESTRICT B = static_cast<const char *>(_B); |
| 1618 | |
| 1619 | // a.qs: 8 groups, 32 bytes each group (m256i) |
| 1620 | __m512i va[8]; |
| 1621 | // a.bsum: 8 groups, 2 bytes each group (m128i) |
| 1622 | __m512i va_bsum; |
| 1623 | __m512 vc[COLS]; |
| 1624 | __m512 vd1; |
| 1625 | |
| 1626 | // packed_B: |
| 1627 | const int offset_scales = (QK_K / 2) * TILE_N; |
| 1628 | const int offset_mins = (QK_K / 2) * TILE_N + 8 * TILE_N; |
| 1629 | const int offset_d0 = (QK_K / 2) * TILE_N + 16 * TILE_N; |
| 1630 | const int offset_dmin = (QK_K / 2) * TILE_N + 16 * TILE_N + TILE_N * sizeof(ggml_half); |
| 1631 | |
| 1632 | const __m512i lowMask = _mm512_set1_epi8(0xF); |
| 1633 | |
| 1634 | auto loadc = [&](auto col) { |
| 1635 | vc[col] = _mm512_setzero_ps(); |
| 1636 | }; |
| 1637 | Unroll<COLS>{}(loadc); |
| 1638 | |
| 1639 | // Notes: vnni formats in QK_K |
| 1640 | // a) quants vnni format |
| 1641 | // int8 {k/4, n, 4}, viewed as 2d {k/4, 4n}, k = 32 |
| 1642 | // from {16, 32} to {8, 64} |
| 1643 | // |
| 1644 | // b) min vnni format |
| 1645 | // int16 {k/2, n, 2}, viewed as 2d {k/2, 2n}, k = 8 |
| 1646 | // from {16, 8} to {4, 32} |
| 1647 | // |
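|      |         // in the inner loop below, va[k_group] holds 32 consecutive int8 values of A;
|      |         // _mm512_permutexvar_epi32(_mm512_set1_epi32(k), va[k_group]) broadcasts the
|      |         // k-th dword (4 int8 values) to all 16 lanes, so one _mm512_dpbusd_epi32
|      |         // pairs those 4 activations with the matching 4-value groups of 16 B columns.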
| 1648 | auto compute = [&](auto col, auto i) { |
| 1649 | // load a |
| 1650 | if constexpr (col == 0) { |
| 1651 | for (int k_group = 0; k_group < QK_K / 32; ++k_group) { |
| 1652 | va[k_group] = _mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)(A[0 * KB + i].qs + k_group * 32))); |
| 1653 | } |
| 1654 | const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums); |
| 1655 | const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); |
| 1656 | va_bsum = _mm512_castsi128_si512(q8s); |
| 1657 | vd1 = _mm512_set1_ps(A[0 * KB + i].d); |
| 1658 | } |
| 1659 | |
| 1660 |             // step 1: accumulate the quants
| 1661 | __m512i acc = _mm512_setzero_si512(); |
| 1662 | const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE); |
| 1663 | const char * b_qs = b_ptr; |
| 1664 | for (int k_group = 0; k_group < QK_K / 32; ++k_group) { |
| 1665 | __m512i vsum = _mm512_setzero_si512(); |
| 1666 | for (int k = 0; k < 8; k += 2) { |
| 1667 | __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 0), va[k_group]); |
| 1668 | __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 1), va[k_group]); |
| 1669 | |
| 1670 | __m512i bytes = _mm512_loadu_si512((const __m512i *)b_qs); |
| 1671 | __m512i vb0 = _mm512_and_si512(bytes, lowMask); |
| 1672 | vsum = _mm512_dpbusd_epi32(vsum, vb0, va0); |
| 1673 | __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); |
| 1674 | vsum = _mm512_dpbusd_epi32(vsum, vb1, va1); |
| 1675 | |
| 1676 | b_qs += 64; |
| 1677 | } |
| 1678 | // vacc += scale * (q8 @ q4) |
| 1679 | const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N))); |
| 1680 | acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale)); |
| 1681 | } |
| 1682 | const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0))); |
| 1683 | vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]); |
| 1684 | |
| 1685 | // step 2: accumulate the mins |
| 1686 | __m512i acc_m = _mm512_setzero_si512(); |
| 1687 | for (int k = 0; k < 4; ++k) { |
| 1688 | __m512i vmask = _mm512_set1_epi32(k); |
| 1689 | __m512i va = _mm512_permutexvar_epi32(vmask, va_bsum); |
| 1690 | __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_mins + k * 32))); |
| 1691 | acc_m = _mm512_dpwssds_epi32(acc_m, va, vb); |
| 1692 | } |
| 1693 | const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_dmin))); |
| 1694 | vc[col] = _mm512_fnmadd_ps(_mm512_cvtepi32_ps(acc_m), _mm512_mul_ps(vdmin, vd1), vc[col]); |
| 1695 | }; |
| 1696 | |
| 1697 | for (int i = 0; i < KB; ++i) { |
| 1698 | Unroll<COLS>{}(compute, i); |
| 1699 | } |
| 1700 | |
| 1701 | //store to C |
| 1702 | auto storec = [&](auto col) { |
| 1703 | _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); |
| 1704 | }; |
| 1705 | Unroll<COLS>{}(storec); |
| 1706 | } |
| 1707 | }; |
| 1708 | |
| 1709 | template <int BLOCK_M, int BLOCK_N, int BLOCK_K> |
| 1710 | struct tinygemm_kernel_vnni<block_q8_K, block_q5_K, float, BLOCK_M, BLOCK_N, BLOCK_K> { |
| 1711 | static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) { |
| 1712 | |
| 1713 | constexpr int COLS = BLOCK_N / 16; |
| 1714 | const int TILE_SIZE = TILE_N * sizeof(block_q5_K) + TILE_N * 4; |
| 1715 | |
| 1716 | const block_q8_K * RESTRICT A = static_cast<const block_q8_K *>(_A); |
| 1717 | const char * RESTRICT B = static_cast<const char *>(_B); |
| 1718 | |
| 1719 | // a.qs: 8 groups, 32 bytes each group (m256i) |
| 1720 | __m512i va[8]; |
| 1721 | // a.bsum: 8 groups, 2 bytes each group (m128i) |
| 1722 | __m512i va_bsum; |
| 1723 | __m512 vc[COLS]; |
| 1724 | __m512 vd1; |
| 1725 | |
| 1726 | // packed_B: |
| 1727 | const int offset_qh = (QK_K / 2) * TILE_N; |
| 1728 | const int offset_scales = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N; |
| 1729 | const int offset_mins = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N + 8 * TILE_N; |
| 1730 | const int offset_d0 = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N + 16 * TILE_N; |
| 1731 | const int offset_dmin = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N + 16 * TILE_N + TILE_N * sizeof(ggml_half); |
| 1732 | |
| 1733 | const __m512i lowMask = _mm512_set1_epi8(0xF); |
| 1734 | |
| 1735 | auto loadc = [&](auto col) { |
| 1736 | vc[col] = _mm512_setzero_ps(); |
| 1737 | }; |
| 1738 | Unroll<COLS>{}(loadc); |
| 1739 | |
| 1740 |         // Q5_K and Q4_K share the same vnni format; refer to the notes above.
| 1741 | auto compute = [&](auto col, auto i) { |
| 1742 | // load a |
| 1743 | if constexpr (col == 0) { |
| 1744 | for (int k_group = 0; k_group < QK_K / 32; ++k_group) { |
| 1745 | va[k_group] = _mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)(A[0 * KB + i].qs + k_group * 32))); |
| 1746 | } |
| 1747 | const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums); |
| 1748 | const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); |
| 1749 | va_bsum = _mm512_castsi128_si512(q8s); |
| 1750 | vd1 = _mm512_set1_ps(A[0 * KB + i].d); |
| 1751 | } |
| 1752 | |
| 1753 |             // step 1: accumulate the quants
| 1754 | __m512i acc = _mm512_setzero_si512(); |
| 1755 | const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE); |
| 1756 | const char * b_qs = b_ptr; |
| 1757 | const char * b_qh = b_ptr + offset_qh; |
| 1758 | for (int k_group = 0; k_group < QK_K / 32; ++k_group) { |
| 1759 | __m512i vsum = _mm512_setzero_si512(); |
| 1760 | __m512i hmask0 = _mm512_set1_epi8(0x1); |
| 1761 | __m512i hmask1 = _mm512_set1_epi8(0x2); |
| 1762 | __m512i hbits = _mm512_loadu_si512((const __m512i *)(b_qh + k_group * 64)); |
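|      |                 // each 5-bit quant = low nibble from b_qs plus one bit of hbits moved to
|      |                 // bit 4; hmask0/hmask1 pick the two high bits for this k pair and advance
|      |                 // by 2 bits per iteration.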
| 1763 | for (int k = 0; k < 8; k += 2) { |
| 1764 | __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 0), va[k_group]); |
| 1765 | __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 1), va[k_group]); |
| 1766 | |
| 1767 | __m512i bytes = _mm512_loadu_si512((const __m512i *)b_qs); |
| 1768 | __m512i vb0 = _mm512_and_si512(bytes, lowMask); |
| 1769 | __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); |
| 1770 | |
| 1771 | __m512i vh0 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask0), k), 4); |
| 1772 | __m512i vh1 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask1), k + 1), 4); |
| 1773 | |
| 1774 | hmask0 = _mm512_slli_epi16(hmask0, 2); |
| 1775 | hmask1 = _mm512_slli_epi16(hmask1, 2); |
| 1776 | vb0 = _mm512_add_epi8(vb0, vh0); |
| 1777 | vb1 = _mm512_add_epi8(vb1, vh1); |
| 1778 | |
| 1779 | vsum = _mm512_dpbusd_epi32(vsum, vb0, va0); |
| 1780 | vsum = _mm512_dpbusd_epi32(vsum, vb1, va1); |
| 1781 | |
| 1782 | b_qs += 64; |
| 1783 | } |
| 1784 | // vacc += scale * (q8 @ q5) |
| 1785 | const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N))); |
| 1786 | acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale)); |
| 1787 | } |
| 1788 | const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0))); |
| 1789 | vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]); |
| 1790 | |
| 1791 | // step 2: accumulate the mins |
| 1792 | __m512i acc_m = _mm512_setzero_si512(); |
| 1793 | for (int k = 0; k < 4; ++k) { |
| 1794 | __m512i vmask = _mm512_set1_epi32(k); |
| 1795 | __m512i va = _mm512_permutexvar_epi32(vmask, va_bsum); |
| 1796 | __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_mins + k * 32))); |
| 1797 | acc_m = _mm512_dpwssds_epi32(acc_m, va, vb); |
| 1798 | } |
| 1799 | const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_dmin))); |
| 1800 | vc[col] = _mm512_fnmadd_ps(_mm512_cvtepi32_ps(acc_m), _mm512_mul_ps(vdmin, vd1), vc[col]); |
| 1801 | }; |
| 1802 | |
| 1803 | for (int i = 0; i < KB; ++i) { |
| 1804 | Unroll<COLS>{}(compute, i); |
| 1805 | } |
| 1806 | |
| 1807 | //store to C |
| 1808 | auto storec = [&](auto col) { |
| 1809 | _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); |
| 1810 | }; |
| 1811 | Unroll<COLS>{}(storec); |
| 1812 | } |
| 1813 | }; |
| 1814 | |
| 1815 | template <int BLOCK_M, int BLOCK_N, int BLOCK_K> |
| 1816 | struct tinygemm_kernel_vnni<block_q8_K, block_q6_K, float, BLOCK_M, BLOCK_N, BLOCK_K> { |
| 1817 | static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) { |
| 1818 | |
| 1819 | constexpr int COLS = BLOCK_N / 16; |
| 1820 | const int TILE_SIZE = TILE_N * sizeof(block_q6_K); |
| 1821 | |
| 1822 | const block_q8_K * RESTRICT A = static_cast<const block_q8_K *>(_A); |
| 1823 | const char * RESTRICT B = static_cast<const char *>(_B); |
| 1824 | |
| 1825 | // load the 256 bytes from A to 4 avx512 vectors |
| 1826 | __m512i va[4]; |
| 1827 | __m512 vc[COLS]; |
| 1828 | __m512 vd1; |
| 1829 | |
| 1830 | // packed_B: |
| 1831 | const int offset_qh = (QK_K / 2) * TILE_N; |
| 1832 | const int offset_scales = (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N; |
| 1833 | const int offset_d0 = (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N + 16 * TILE_N; |
| 1834 | |
| 1835 | // compensation |
| 1836 | __m512i vcomp; |
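|      |         // q6_K encodes its quants as unsigned 6-bit values with an implicit -32 offset;
|      |         // the kernel computes sum(q * a) with _mm512_dpbusd_epi32 and subtracts
|      |         // 32 * sum(a) (vcomp, built from bsums) per 16-element group.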
| 1837 | |
| 1838 | const __m512i m32s = _mm512_set1_epi32(32); |
| 1839 | const __m512i lowMask = _mm512_set1_epi8(0xF); |
| 1840 | |
| 1841 | auto loadc = [&](auto col) { |
| 1842 | vc[col] = _mm512_setzero_ps(); |
| 1843 | }; |
| 1844 | Unroll<COLS>{}(loadc); |
| 1845 | |
| 1846 | auto compute = [&](auto col, auto i) { |
| 1847 | if constexpr (col == 0) { |
| 1848 | // load a |
| 1849 | va[0] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 0)); |
| 1850 | va[1] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 64)); |
| 1851 | va[2] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 128)); |
| 1852 | va[3] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 192)); |
| 1853 | |
| 1854 | const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums); |
| 1855 | vcomp = _mm512_mullo_epi32(_mm512_cvtepi16_epi32(q8sums), m32s); |
| 1856 | vd1 = _mm512_set1_ps(A[0 * KB + i].d); |
| 1857 | } |
| 1858 | |
| 1859 |             // accumulate the quants
| 1860 | __m512i acc = _mm512_setzero_si512(); |
| 1861 | const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE); |
| 1862 | const char * b_qs = b_ptr; |
| 1863 | const char * b_qh = b_ptr + offset_qh; |
| 1864 | int mask = 0; |
| 1865 | for (int k_group = 0; k_group < QK_K / 16; ++k_group) { |
| 1866 | int r = k_group >> 2; |
| 1867 | __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]); |
| 1868 | __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]); |
| 1869 | |
| 1870 | __m512i vsum = _mm512_setzero_si512(); |
| 1871 | __m512i hmask = _mm512_set1_epi8(0x3); |
| 1872 | |
| 1873 | __m512i bytes = _mm512_loadu_si512(b_qs); |
| 1874 | __m512i hbits = _mm512_loadu_si512(b_qh); |
| 1875 | __m512i vb0 = _mm512_and_si512(bytes, lowMask); |
| 1876 | __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); |
| 1877 | __m512i vh0 = _mm512_slli_epi16(_mm512_and_si512(hbits, hmask), 4); |
| 1878 | __m512i vh1 = _mm512_slli_epi16(_mm512_and_si512(hbits, _mm512_slli_epi16(hmask, 2)), 2); |
| 1879 | |
| 1880 | vb0 = _mm512_add_epi8(vb0, vh0); |
| 1881 | vb1 = _mm512_add_epi8(vb1, vh1); |
| 1882 | vsum = _mm512_dpbusd_epi32(vsum, vb0, va0); |
| 1883 | vsum = _mm512_dpbusd_epi32(vsum, vb1, va1); |
| 1884 | b_qs += 64; |
| 1885 | |
| 1886 | va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]); |
| 1887 | va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]); |
| 1888 | |
| 1889 | bytes = _mm512_loadu_si512(b_qs); |
| 1890 | vb0 = _mm512_and_si512(bytes, lowMask); |
| 1891 | vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); |
| 1892 | vh0 = _mm512_and_si512(hbits, _mm512_slli_epi16(hmask, 4)); |
| 1893 | vh1 = _mm512_srli_epi16(_mm512_and_si512(hbits, _mm512_slli_epi16(hmask, 6)), 2); |
| 1894 | vb0 = _mm512_add_epi8(vb0, vh0); |
| 1895 | vb1 = _mm512_add_epi8(vb1, vh1); |
| 1896 | vsum = _mm512_dpbusd_epi32(vsum, vb0, va0); |
| 1897 | vsum = _mm512_dpbusd_epi32(vsum, vb1, va1); |
| 1898 | b_qs += 64; |
| 1899 | b_qh += 64; |
| 1900 | |
| 1901 | // B * A - 32 * A |
| 1902 | __m512i vmask = _mm512_set1_epi32(k_group); |
| 1903 | vsum = _mm512_sub_epi32(vsum, _mm512_permutexvar_epi32(vmask, vcomp)); |
| 1904 | |
| 1905 | // vacc += scale * (q8 @ q6) |
| 1906 | const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N))); |
| 1907 | acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale)); |
| 1908 | } |
| 1909 | const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0))); |
| 1910 | vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]); |
| 1911 | }; |
| 1912 | |
| 1913 | for (int i = 0; i < KB; ++i) { |
| 1914 | Unroll<COLS>{}(compute, i); |
| 1915 | } |
| 1916 | |
| 1917 | //store to C |
| 1918 |         auto storec = [&](auto col) {
| 1919 | _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); |
| 1920 | }; |
| 1921 | Unroll<COLS>{}(storec); |
| 1922 | } |
| 1923 | }; |
| 1924 | |
| 1925 | template <int BLOCK_M, int BLOCK_N, int BLOCK_K> |
| 1926 | struct tinygemm_kernel_vnni<block_q8_K, block_iq4_xs, float, BLOCK_M, BLOCK_N, BLOCK_K> { |
| 1927 | static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) { |
| 1928 | |
| 1929 | constexpr int COLS = BLOCK_N / 16; |
| 1930 | const int TILE_SIZE = TILE_N * sizeof(block_iq4_xs) + TILE_N * 2; |
| 1931 | |
| 1932 | const block_q8_K * RESTRICT A = static_cast<const block_q8_K *>(_A); |
| 1933 | const char * RESTRICT B = static_cast<const char *>(_B); |
| 1934 | |
| 1935 | // load the 256 bytes from A to 4 avx512 vectors |
| 1936 | __m512i va[4]; |
| 1937 | __m512 vc[COLS]; |
| 1938 | __m512 vd1; |
| 1939 | |
| 1940 | // packed_B: |
| 1941 |         const int offset_scales = (QK_K / 2) * TILE_N;
| 1942 | const int offset_d0 = (QK_K / 2) * TILE_N + 8 * TILE_N; |
| 1943 | |
| 1944 | // compensation |
| 1945 | __m512i vcomp; |
| 1946 | |
| 1947 | const __m256i m128s = _mm256_set1_epi16(128); |
| 1948 | const __m512i lowMask = _mm512_set1_epi8(0xF); |
| 1949 | |
| 1950 | const __m512i values128 = _mm512_set_epi8( |
| 1951 | 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127, |
| 1952 | 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127, |
| 1953 | 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127, |
| 1954 | 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127 |
| 1955 | ); |
| 1956 | const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80)); |
| 1957 | const __m512i values256 = _mm512_add_epi8(values128, off); |
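|      |         // iq4_xs stores 4-bit indices into a non-linear signed lookup table (values128,
|      |         // the iq4nl values); adding 0x80 shifts the table into u8 range so
|      |         // _mm512_dpbusd_epi32 applies, and the extra 128 * sum(a) (vcomp, from bsums)
|      |         // is subtracted per 32-element group.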
| 1958 | |
| 1959 | auto loadc = [&](auto col) { |
| 1960 | vc[col] = _mm512_setzero_ps(); |
| 1961 | }; |
| 1962 | Unroll<COLS>{}(loadc); |
| 1963 | |
| 1964 | auto compute = [&](auto col, auto i) { |
| 1965 | if constexpr (col == 0) { |
| 1966 | // load a |
| 1967 | va[0] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 0)); |
| 1968 | va[1] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 64)); |
| 1969 | va[2] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 128)); |
| 1970 | va[3] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 192)); |
| 1971 | |
| 1972 | // compensation: 128 * A |
| 1973 | const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums); |
| 1974 | vcomp = _mm512_castsi256_si512(_mm256_madd_epi16(q8sums, m128s)); |
| 1975 | vd1 = _mm512_set1_ps(A[0 * KB + i].d); |
| 1976 | } |
| 1977 | |
| 1978 |             // accumulate the quants
| 1979 | __m512i acc = _mm512_setzero_si512(); |
| 1980 | const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE); |
| 1981 | const char * b_qs = b_ptr; |
| 1982 | int mask = 0; |
| 1983 | for (int k_group = 0; k_group < QK_K / 32; ++k_group) { |
| 1984 | int r = k_group >> 1; |
| 1985 | __m512i vmask = _mm512_set1_epi32(k_group); |
| 1986 | __m512i vsum = _mm512_setzero_si512(); |
| 1987 | for (int k = 0; k < 8; k += 2) { |
| 1988 | __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]); |
| 1989 | __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]); |
| 1990 | |
| 1991 | __m512i bytes = _mm512_loadu_si512(b_qs); |
| 1992 | __m512i vb0 = _mm512_shuffle_epi8(values256, _mm512_and_si512(bytes, lowMask)); |
| 1993 | __m512i vb1 = _mm512_shuffle_epi8(values256, _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask)); |
| 1994 | |
| 1995 | vsum = _mm512_dpbusd_epi32(vsum, vb0, va0); |
| 1996 | vsum = _mm512_dpbusd_epi32(vsum, vb1, va1); |
| 1997 | b_qs += 64; |
| 1998 | } |
| 1999 | // (B + 128) * A - 128 * A |
| 2000 | vsum = _mm512_sub_epi32(vsum, _mm512_permutexvar_epi32(vmask, vcomp)); |
| 2001 | |
| 2002 | // vacc += scale * (q8 @ q4) |
| 2003 | const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N))); |
| 2004 | acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale)); |
| 2005 | } |
| 2006 | const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0))); |
| 2007 | vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]); |
| 2008 | }; |
| 2009 | |
| 2010 | for (int i = 0; i < KB; ++i) { |
| 2011 | Unroll<COLS>{}(compute, i); |
| 2012 | } |
| 2013 | |
| 2014 | //store to C |
| 2015 | auto storec = [&](auto col) { |
| 2016 | _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); |
| 2017 | }; |
| 2018 | Unroll<COLS>{}(storec); |
| 2019 | } |
| 2020 | }; |
| 2021 | |
| 2022 | #define LAUNCH_TINYGEMM_KERNEL_VNNI(NB_SIZE) \ |
| 2023 | tinygemm_kernel_vnni<vec_dot_type, type, float, 1, NB_SIZE, blck_size>::apply( \ |
| 2024 | KB, (const char *)wdata + 0 * row_size_A, \ |
| 2025 | (const char *)src0->data + PACKED_INDEX(nb * kTilesN, 0, KB, TILE_SIZE), \ |
| 2026 | (float *) dst->data + 0 * N + nb_start, ldc) |
| 2027 | |
| 2028 | template <typename TA, typename TB, typename TC, int BLOCK_K, |
| 2029 | typename std::enable_if<!is_type_qkk<TB>::value, int>::type = 0> |
| 2030 | void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const void * RESTRICT _B, TC * RESTRICT C, int ldc) { |
| 2031 | using packed_B_t = packed_B_type<TB>; |
| 2032 | const int TILE_SIZE = get_tile_size<TB>(); |
| 2033 | const bool need_unpack = do_unpack<TB>::value; |
| 2034 | |
| 2035 | GGML_ASSERT(M <= 2 * TILE_M && N == 2 * TILE_N); |
| 2036 | const TA * RESTRICT A = static_cast<const TA *>(_A); |
| 2037 | const char * RESTRICT B = static_cast<const char *>(_B); |
| 2038 | |
| 2039 | const int m0 = std::min(M, TILE_M); |
| 2040 | const int m1 = std::max(M - TILE_M, 0); |
| 2041 | const int lda = KB * sizeof(TA); |
| 2042 | //const int ldb = KB * sizeof(TB); |
| 2043 | |
| 2044 | static thread_local packed_B_t Tile0[TILE_N * TILE_K]; |
| 2045 | static thread_local packed_B_t Tile1[TILE_N * TILE_K]; |
| 2046 | static thread_local int8_t Tile23[TILE_M * TILE_K]; |
| 2047 | |
| 2048 | static thread_local int32_t TileC0[TILE_M * TILE_N * 4]; |
| 2049 | static thread_local int32_t TileC1[TILE_M * TILE_N * 4]; |
| 2050 | |
| 2051 | // double buffering C to interleave avx512 and amx |
| 2052 | int32_t * C_cur = TileC0; |
| 2053 | int32_t * C_pre = TileC1; |
| 2054 | |
| 2055 | auto Tile4 = [&](int32_t * base) { return base; }; |
| 2056 | auto Tile5 = [&](int32_t * base) { return base + TILE_M * TILE_N; }; |
| 2057 | auto Tile6 = [&](int32_t * base) { return base + 2 * TILE_M * TILE_N; }; |
| 2058 | auto Tile7 = [&](int32_t * base) { return base + 3 * TILE_M * TILE_N; }; |
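|      |     // pipeline: while AMX computes the int32 tiles for K block i into C_cur, the
|      |     // AVX512 epilogue (acc_C) scales the previous block's tiles from C_pre into the
|      |     // float output C; the two buffers are swapped at the end of each iteration, and
|      |     // the last block is accumulated after the loop.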
| 2059 | |
| 2060 | if (M == 2 * TILE_M) { |
| 2061 | // i = 0 |
| 2062 | const char * B_blk0 = B + PACKED_INDEX(0, 0, KB, TILE_SIZE); |
| 2063 | const char * B_blk1 = B + PACKED_INDEX(1, 0, KB, TILE_SIZE); |
| 2064 | if (need_unpack) { |
| 2065 | unpack_B<TB>(Tile0, B_blk0); |
| 2066 | _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK); |
| 2067 | } else { |
| 2068 | _tile_loadd(TMM0, B_blk0, TILE_N * VNNI_BLK); |
| 2069 | } |
| 2070 | |
| 2071 | _tile_zero(TMM4); |
| 2072 | _tile_loadd(TMM2, A[0].qs, lda); |
| 2073 | _tile_dpbssd(TMM4, TMM2, TMM0); |
| 2074 | _tile_stored(TMM4, Tile4(C_pre), TILE_N * sizeof(int32_t)); |
| 2075 | |
| 2076 | _tile_zero(TMM5); |
| 2077 | _tile_loadd(TMM3, A[TILE_M * KB + 0].qs, lda); |
| 2078 | _tile_dpbssd(TMM5, TMM3, TMM0); |
| 2079 | _tile_stored(TMM5, Tile5(C_pre), TILE_N * sizeof(int32_t)); |
| 2080 | |
| 2081 | if (need_unpack) { |
| 2082 |             unpack_B<TB>(Tile1, B_blk1);
| 2083 | _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK); |
| 2084 | } else { |
| 2085 | _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK); |
| 2086 | } |
| 2087 | |
| 2088 | _tile_zero(TMM6); |
| 2089 | _tile_dpbssd(TMM6, TMM2, TMM1); |
| 2090 | _tile_stored(TMM6, Tile6(C_pre), TILE_N * sizeof(int32_t)); |
| 2091 | |
| 2092 | _tile_zero(TMM7); |
| 2093 | _tile_dpbssd(TMM7, TMM3, TMM1); |
| 2094 | _tile_stored(TMM7, Tile7(C_pre), TILE_N * sizeof(int32_t)); |
| 2095 | |
| 2096 | for (int i = 1; i < KB; ++i) { |
| 2097 | // index of previous iter |
| 2098 | const int ii = i - 1; |
| 2099 | const char * B_blk0 = B + PACKED_INDEX(0, i, KB, TILE_SIZE); |
| 2100 | const char * B_blk1 = B + PACKED_INDEX(1, i, KB, TILE_SIZE); |
| 2101 | GGML_DISPATCH_BOOL(ii > 0, is_acc, [&] { |
| 2102 | if (need_unpack) { |
| 2103 | unpack_B<TB>(Tile0, B_blk0); |
| 2104 | _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK); |
| 2105 | } else { |
| 2106 | _tile_loadd(TMM0, B_blk0, TILE_N * VNNI_BLK); |
| 2107 | } |
| 2108 | _tile_zero(TMM4); |
| 2109 | _tile_loadd(TMM2, A[i].qs, lda); |
| 2110 | acc_C<TA, TB, is_acc>::apply(C, ldc, Tile4(C_pre), &A[ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M); |
| 2111 | |
| 2112 | _tile_dpbssd(TMM4, TMM2, TMM0); |
| 2113 | _tile_stored(TMM4, Tile4(C_cur), TILE_N * sizeof(int32_t)); |
| 2114 | |
| 2115 | _tile_zero(TMM5); |
| 2116 | _tile_loadd(TMM3, A[TILE_M * KB + i].qs, lda); |
| 2117 | acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc, ldc, Tile5(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M); |
| 2118 | |
| 2119 | _tile_dpbssd(TMM5, TMM3, TMM0); |
| 2120 | _tile_stored(TMM5, Tile5(C_cur), TILE_N * sizeof(int32_t)); |
| 2121 | |
| 2122 | if (need_unpack) { |
| 2123 | unpack_B<TB>(Tile1, B_blk1); |
| 2124 | _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK); |
| 2125 | } else { |
| 2126 | _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK); |
| 2127 | } |
| 2128 | _tile_zero(TMM6); |
| 2129 | acc_C<TA, TB, is_acc>::apply(C + TILE_N, ldc, Tile6(C_pre), &A[ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M); |
| 2130 | |
| 2131 | _tile_dpbssd(TMM6, TMM2, TMM1); |
| 2132 | _tile_stored(TMM6, Tile6(C_cur), TILE_N * sizeof(int32_t)); |
| 2133 | |
| 2134 | _tile_zero(TMM7); |
| 2135 | acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc + TILE_N, ldc, Tile7(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M); |
| 2136 | |
| 2137 | _tile_dpbssd(TMM7, TMM3, TMM1); |
| 2138 | _tile_stored(TMM7, Tile7(C_cur), TILE_N * sizeof(int32_t)); |
| 2139 | |
| 2140 | std::swap(C_cur, C_pre); |
| 2141 | }); |
| 2142 | } |
| 2143 | // final accumulation |
| 2144 | { |
| 2145 | int ii = KB - 1; |
| 2146 | acc_C<TA, TB, true>::apply(C, ldc, Tile4(C_pre), &A[ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M); |
| 2147 | acc_C<TA, TB, true>::apply(C + TILE_M * ldc, ldc, Tile5(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M); |
| 2148 | acc_C<TA, TB, true>::apply(C + TILE_N, ldc, Tile6(C_pre), &A[ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M); |
| 2149 | acc_C<TA, TB, true>::apply(C + TILE_M * ldc + TILE_N, ldc, Tile7(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M); |
| 2150 | } |
| 2151 | } else { |
| 2152 | for (int i = 0; i < KB; ++i) { |
| 2153 | _tile_zero(TMM4); |
| 2154 | _tile_zero(TMM6); |
| 2155 | if (m1 != 0) { |
| 2156 | _tile_zero(TMM5); |
| 2157 | _tile_zero(TMM7); |
| 2158 | } |
| 2159 | |
| 2160 | const char * B_blk0 = B + PACKED_INDEX(0, i, KB, TILE_SIZE); |
| 2161 | const char * B_blk1 = B + PACKED_INDEX(1, i, KB, TILE_SIZE); |
| 2162 | if (need_unpack) { |
| 2163 | unpack_B<TB>(Tile0, B_blk0); |
| 2164 | _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK); |
| 2165 | } else { |
| 2166 | _tile_loadd(TMM0, B_blk0, TILE_N * VNNI_BLK); |
| 2167 | } |
| 2168 | |
| 2169 | if (need_unpack) { |
| 2170 | unpack_B<TB>(Tile1, B_blk1); |
| 2171 | _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK); |
| 2172 | } else { |
| 2173 | _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK); |
| 2174 | } |
| 2175 | |
| 2176 | if (m0 == TILE_M) { |
| 2177 | _tile_loadd(TMM2, A[i].qs, lda); |
| 2178 | } else { |
| 2179 | unpack_A(Tile23, &A[i], KB, m0); |
| 2180 | _tile_loadd(TMM2, Tile23, TILE_K); |
| 2181 | } |
| 2182 | |
| 2183 | _tile_dpbssd(TMM4, TMM2, TMM0); |
| 2184 | _tile_dpbssd(TMM6, TMM2, TMM1); |
| 2185 | |
| 2186 | _tile_stored(TMM4, Tile4(C_cur), TILE_N * sizeof(int32_t)); |
| 2187 | _tile_stored(TMM6, Tile6(C_cur), TILE_N * sizeof(int32_t)); |
| 2188 | |
| 2189 | GGML_DISPATCH_BOOL(i > 0, is_acc, [&] { |
| 2190 | acc_C<TA, TB, is_acc>::apply(C, ldc, Tile4(C_cur), &A[i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m0); |
| 2191 | acc_C<TA, TB, is_acc>::apply(C + TILE_N, ldc, Tile6(C_cur), &A[i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m0); |
| 2192 | }); |
| 2193 | |
| 2194 | if (m1 != 0) { |
| 2195 | unpack_A(Tile23, &A[TILE_M * KB + i], KB, m1); |
| 2196 | _tile_loadd(TMM3, Tile23, TILE_K); |
| 2197 | |
| 2198 | _tile_dpbssd(TMM5, TMM3, TMM0); |
| 2199 | _tile_dpbssd(TMM7, TMM3, TMM1); |
| 2200 | _tile_stored(TMM5, Tile5(C_cur), TILE_N * sizeof(int32_t)); |
| 2201 | _tile_stored(TMM7, Tile7(C_cur), TILE_N * sizeof(int32_t)); |
| 2202 | GGML_DISPATCH_BOOL(i > 0, is_acc, [&] { |
| 2203 | acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc, ldc, Tile5(C_cur), &A[TILE_M * KB + i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m1); |
| 2204 | acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc + TILE_N, ldc, Tile7(C_cur), &A[TILE_M * KB + i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m1); |
| 2205 | }); |
| 2206 | } |
| 2207 | } |
| 2208 | } |
| 2209 | return; |
| 2210 | } |
| 2211 | |
| 2212 | template <typename TA, typename TB, typename TC, int BLOCK_K, |
| 2213 | typename std::enable_if<is_type_qkk<TB>::value, int>::type = 0> |
| 2214 | void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) { |
| 2215 | static_assert(std::is_same<TA, block_q8_K>::value); |
| 2216 | const int TILE_SIZE = get_tile_size<TB>(); |
| 2217 | |
| 2218 | GGML_ASSERT(M <= 2 * TILE_M && N == 2 * TILE_N); |
| 2219 | const TA * RESTRICT A = static_cast<const TA *>(_A); |
| 2220 | const char * RESTRICT B = static_cast<const char *>(_B); |
| 2221 | |
| 2222 | const int m0 = std::min(M, TILE_M); |
| 2223 | const int m1 = std::max(M - TILE_M, 0); |
| 2224 | //const int lda = KB * sizeof(TA); |
| 2225 | |
| 2226 | static thread_local int8_t Tile0[TILE_N * TILE_K]; |
| 2227 | static thread_local int8_t Tile1[TILE_N * TILE_K]; |
| 2228 | static thread_local int8_t Tile23[TILE_M * TILE_K]; |
| 2229 | |
| 2230 | // mat mul result for each group |
| 2231 | static thread_local int32_t Tile4[TILE_M * TILE_N]; |
| 2232 | static thread_local int32_t Tile5[TILE_M * TILE_N]; |
| 2233 | static thread_local int32_t Tile6[TILE_M * TILE_N]; |
| 2234 | static thread_local int32_t Tile7[TILE_M * TILE_N]; |
| 2235 | |
| 2236 | // sum of each QK_K block, contains 8 groups, int32 |
| 2237 | static thread_local int32_t Sumi4[TILE_M * TILE_N]; |
| 2238 | static thread_local int32_t Sumi5[TILE_M * TILE_N]; |
| 2239 | static thread_local int32_t Sumi6[TILE_M * TILE_N]; |
| 2240 | static thread_local int32_t Sumi7[TILE_M * TILE_N]; |
| 2241 | |
| 2242 | const int k_group_size = std::is_same<TB, block_q6_K>::value ? 16 : 32; |
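|      |     // each QK_K super-block is processed group by group: the AMX tiles produce the
|      |     // raw int32 products per group, scale_C folds the per-group int8 scales into
|      |     // Sumi*, and acc_C applies the fp16 d (and dmin/mins where present) per block.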
| 2243 | for (int i = 0; i < KB; ++i) { |
| 2244 |         // step 1: accumulate the quants, k_group_size values per group (32, or 16 for Q6_K)
| 2245 | for (int k = 0; k < QK_K / k_group_size; ++k) { |
| 2246 | GGML_DISPATCH_BOOL(k > 0, is_acc, [&] { |
| 2247 | _tile_zero(TMM4); |
| 2248 | _tile_zero(TMM6); |
| 2249 | |
| 2250 | unpack_B<TB>(Tile0, B + PACKED_INDEX(0, i, KB, TILE_SIZE), k); |
| 2251 | _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK); |
| 2252 | |
| 2253 | unpack_B<TB>(Tile1, B + PACKED_INDEX(1, i, KB, TILE_SIZE), k); |
| 2254 | _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK); |
| 2255 | |
| 2256 | unpack_A<TB>(Tile23, &A[i], KB, k, m0); |
| 2257 | _tile_loadd(TMM2, Tile23, TILE_K); |
| 2258 | |
| 2259 | _tile_dpbssd(TMM4, TMM2, TMM0); |
| 2260 | _tile_dpbssd(TMM6, TMM2, TMM1); |
| 2261 | |
| 2262 | _tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t)); |
| 2263 | _tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t)); |
| 2264 | |
| 2265 | scale_C<TB, is_acc>(Tile4, Sumi4, B + PACKED_INDEX(0, i, KB, TILE_SIZE), k, m0); |
| 2266 | scale_C<TB, is_acc>(Tile6, Sumi6, B + PACKED_INDEX(1, i, KB, TILE_SIZE), k, m0); |
| 2267 | |
| 2268 | if (m1 != 0) { |
| 2269 | _tile_zero(TMM5); |
| 2270 | _tile_zero(TMM7); |
| 2271 | |
| 2272 | unpack_A<TB>(Tile23, &A[TILE_M * KB + i], KB, k, m1); |
| 2273 | _tile_loadd(TMM3, Tile23, TILE_K); |
| 2274 | |
| 2275 | _tile_dpbssd(TMM5, TMM3, TMM0); |
| 2276 | _tile_dpbssd(TMM7, TMM3, TMM1); |
| 2277 | |
| 2278 | _tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t)); |
| 2279 | _tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t)); |
| 2280 | |
| 2281 | scale_C<TB, is_acc>(Tile5, Sumi5, B + PACKED_INDEX(0, i, KB, TILE_SIZE), k, m1); |
| 2282 | scale_C<TB, is_acc>(Tile7, Sumi7, B + PACKED_INDEX(1, i, KB, TILE_SIZE), k, m1); |
| 2283 | } |
| 2284 | }); |
| 2285 | } |
| 2286 | |
| 2287 |         // step 2: accumulate the mins
| 2288 | GGML_DISPATCH_BOOL(i > 0, is_acc, [&] { |
| 2289 | acc_C<TA, TB, is_acc>::apply(C, ldc, Sumi4, &A[i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m0); |
| 2290 | acc_C<TA, TB, is_acc>::apply(C + TILE_N, ldc, Sumi6, &A[i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m0); |
| 2291 | if (m1 != 0) { |
| 2292 | acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc, ldc, Sumi5, &A[TILE_M * KB + i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m1); |
| 2293 | acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc + TILE_N, ldc, Sumi7, &A[TILE_M * KB + i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m1); |
| 2294 | } |
| 2295 | }); |
| 2296 | } |
| 2297 | return; |
| 2298 | } |
| 2299 | |
| 2300 | } // anonymous namespace |
| 2301 | |
| 2302 | // get the packed tensor size for quantized weights |
| 2303 | size_t ggml_backend_amx_get_alloc_size(const struct ggml_tensor * tensor) { |
| 2304 | const enum ggml_type TYPE = tensor->type; |
| 2305 | |
| 2306 | const int K = tensor->ne[0]; // ne0: in_features |
| 2307 | const int N = tensor->ne[1]; // ne1: out_features |
| 2308 | |
| 2309 | auto get_tensor_size = [&] { |
| 2310 | size_t row_size_B{0}; |
| 2311 | GGML_DISPATCH_QTYPES(TYPE, [&] { |
| 2312 | row_size_B = get_row_size<type, blck_size>(K); |
| 2313 | }); |
| 2314 | return N * row_size_B; |
| 2315 | }; |
| 2316 | |
| 2317 | if (qtype_has_amx_kernels(TYPE)) { |
| 2318 | return get_tensor_size(); |
| 2319 | } else { |
| 2320 | // for f16, bf16 we don't do packing |
| 2321 | return ggml_nbytes(tensor); |
| 2322 | } |
| 2323 | } |
| 2324 | |
| 2325 | // pack weight to vnni format |
| 2326 | void ggml_backend_amx_convert_weight(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { |
| 2327 | GGML_ASSERT(offset == 0 && size == ggml_nbytes(tensor)); // only full tensor conversion is supported for now |
| 2328 | |
| 2329 | const enum ggml_type TYPE = tensor->type; |
| 2330 | |
| 2331 | const int K = tensor->ne[0]; // ne0: in_features |
| 2332 | const int N = tensor->ne[1]; // ne1: out_features |
| 2333 | |
| 2334 | GGML_DISPATCH_QTYPES(TYPE, [&] { |
| 2335 | convert_B_packed_format<type, blck_size>((void *)((char *)tensor->data + offset), (const type *)data, N, K); |
| 2336 | }); |
| 2337 | } |
| 2338 | |
| 2339 | size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst) { |
| 2340 | struct ggml_tensor * src0 = dst->src[0]; |
| 2341 | |
| 2342 | const enum ggml_type TYPE = src0->type; |
| 2343 | |
| 2344 | const bool is_floating_type = TYPE == GGML_TYPE_F16; |
| 2345 | if (is_floating_type) { |
| 2346 | return 0; |
| 2347 | } |
| 2348 | |
| 2349 | const int M = dst->ne[1]; |
| 2350 | const int K = src0->ne[0]; |
| 2351 | |
| 2352 | size_t desired_wsize = 0; |
| 2353 | |
| 2354 | GGML_DISPATCH_QTYPES(TYPE, [&] { |
| 2355 | const size_t row_size_A = K / blck_size * sizeof(vec_dot_type); |
| 2356 | desired_wsize = M * row_size_A; |
| 2357 | }); |
| 2358 | |
| 2359 | return desired_wsize; |
| 2360 | } |
| 2361 | |
| 2362 | // NB: mixed dtype gemm with Advanced Matrix Extensions (Intel AMX) |
| 2363 | // |
| 2364 | // src0: weight in shape of {N, K}, quantized |
| 2365 | // src1: input in shape of {M, K}, float32 |
| 2366 | // dst: output in shape of {M, N}, float32 |
| 2367 | // |
| 2368 | // the function performs: dst = src1 @ src0.T |
| 2369 | // |
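|      | // dispatch: f16 weights go through the avx512 tinygemm path; quantized weights
|      | // first quantize src1 into wdata, then use the vnni kernels when M == 1 and the
|      | // amx tile kernels otherwise.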
| 2370 | void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_tensor * dst) { |
| 2371 | struct ggml_tensor * src0 = dst->src[0]; |
| 2372 | struct ggml_tensor * src1 = dst->src[1]; |
| 2373 | |
| 2374 | const enum ggml_type TYPE = src0->type; |
| 2375 | |
| 2376 | // f16 only has avx512 kernels for now, |
| 2377 | // amx kernels will be added once 6th gen xeon is released. |
| 2378 | const bool is_floating_type = TYPE == GGML_TYPE_F16; |
| 2379 | |
| 2380 | const int M = dst->ne[1]; |
| 2381 | const int N = dst->ne[0]; |
| 2382 | const int K = src0->ne[0]; |
| 2383 | const int ldc = dst->nb[1] / dst->nb[0]; |
| 2384 | |
| 2385 | if (is_floating_type) { |
| 2386 | constexpr int BLOCK_M = 4; |
| 2387 | constexpr int BLOCK_N = 6; |
| 2388 | const int MB = div_up(M, BLOCK_M); |
| 2389 | const int NB = div_up(N, BLOCK_N); |
| 2390 | |
| 2391 | parallel_for_ggml(params, MB * NB, [&](int begin, int end) { |
| 2392 | GGML_DISPATCH_FLOATING_TYPES(TYPE, [&] { |
| 2393 | for (int i = begin; i < end; ++i) { |
| 2394 | int mb = i / NB; |
| 2395 | int nb = i % NB; |
| 2396 | |
| 2397 | int mb_start = mb * BLOCK_M; |
| 2398 | int mb_size = std::min(BLOCK_M, M - mb_start); |
| 2399 | int nb_start = nb * BLOCK_N; |
| 2400 | int nb_size = std::min(BLOCK_N, N - nb_start); |
| 2401 | |
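|      |                     // encode (mb_size, nb_size) in one byte: high nibble = rows,
|      |                     // low nibble = cols, e.g. mb_size = 3, nb_size = 4 -> 0x34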
| 2402 | switch (mb_size << 4 | nb_size) { |
| 2403 | case 0x12: LAUNCH_TINYGEMM_KERNEL_AVX(1, 2); break; |
| 2404 | case 0x14: LAUNCH_TINYGEMM_KERNEL_AVX(1, 4); break; |
| 2405 | case 0x16: LAUNCH_TINYGEMM_KERNEL_AVX(1, 6); break; |
| 2406 | case 0x22: LAUNCH_TINYGEMM_KERNEL_AVX(2, 2); break; |
| 2407 | case 0x24: LAUNCH_TINYGEMM_KERNEL_AVX(2, 4); break; |
| 2408 | case 0x26: LAUNCH_TINYGEMM_KERNEL_AVX(2, 6); break; |
| 2409 | case 0x32: LAUNCH_TINYGEMM_KERNEL_AVX(3, 2); break; |
| 2410 | case 0x34: LAUNCH_TINYGEMM_KERNEL_AVX(3, 4); break; |
| 2411 | case 0x36: LAUNCH_TINYGEMM_KERNEL_AVX(3, 6); break; |
| 2412 | case 0x42: LAUNCH_TINYGEMM_KERNEL_AVX(4, 2); break; |
| 2413 | case 0x44: LAUNCH_TINYGEMM_KERNEL_AVX(4, 4); break; |
| 2414 | case 0x46: LAUNCH_TINYGEMM_KERNEL_AVX(4, 6); break; |
| 2415 | default: fprintf(stderr, "Unexpected block size!\n" ); |
| 2416 | } |
| 2417 | } |
| 2418 | }); |
| 2419 | }); |
| 2420 | return; |
| 2421 | } |
| 2422 | |
| 2423 |     // pointer to work space, used to convert A from float to quantized type
| 2424 | void * wdata = params->wdata; |
| 2425 | |
| 2426 | //TODO: performance improvement: merge quant A |
| 2427 | if (params->ith == 0) { |
| 2428 | GGML_DISPATCH_QTYPES(TYPE, [&] { |
| 2429 | const size_t row_size_A = K / blck_size * sizeof(vec_dot_type); |
| 2430 | const size_t desired_wsize = M * row_size_A; |
| 2431 | if (params->wsize < desired_wsize) { |
| 2432 | GGML_ABORT("insufficient work space size" ); |
| 2433 | } |
| 2434 | |
| 2435 |             // Q4_0, Q4_1, Q8_0 handle 1 TILE_K per blck_size
| 2436 |             // Q4_K, Q5_K, Q6_K, IQ4_XS handle 8 TILE_K per blck_size
| 2437 | GGML_ASSERT(TILE_K == blck_size || TILE_K * 8 == blck_size); |
| 2438 | |
| 2439 | const float * A_data = static_cast<const float *>(src1->data); |
| 2440 | for (int m = 0; m < M; ++m) { |
| 2441 | from_float<vec_dot_type>(A_data + m * K, (char *)wdata + m * row_size_A, K); |
| 2442 | } |
| 2443 | }); |
| 2444 | } |
| 2445 | |
| 2446 | ggml_barrier(params->threadpool); |
| 2447 | |
| 2448 | if (M == 1) { |
| 2449 |         // MB = 1: handle kTilesN tiles in each block
| 2450 | constexpr int kTilesN = 4; |
| 2451 | constexpr int BLOCK_N = TILE_N * kTilesN; |
| 2452 | const int NB = div_up(N, BLOCK_N); |
| 2453 | |
| 2454 | parallel_for_ggml(params, NB, [&](int begin, int end) { |
| 2455 | GGML_DISPATCH_QTYPES(TYPE, [&] { |
| 2456 | const int KB = K / blck_size; |
| 2457 | const int TILE_SIZE = get_tile_size<type>(); |
| 2458 | const int row_size_A = KB * sizeof(vec_dot_type); |
| 2459 | for (int i = begin; i < end; ++i) { |
| 2460 | int nb = i; |
| 2461 | int nb_start = nb * BLOCK_N; |
| 2462 | int nb_size = std::min(BLOCK_N, N - nb_start); // 32, 64, 96 |
| 2463 | |
| 2464 | switch (nb_size) { |
| 2465 | //case 160: LAUNCH_TINYGEMM_KERNEL_VNNI(160); break; |
| 2466 | case 128: LAUNCH_TINYGEMM_KERNEL_VNNI(128); break; |
| 2467 | case 96: LAUNCH_TINYGEMM_KERNEL_VNNI(96); break; |
| 2468 | case 64: LAUNCH_TINYGEMM_KERNEL_VNNI(64); break; |
| 2469 | case 32: LAUNCH_TINYGEMM_KERNEL_VNNI(32); break; |
| 2470 | default: fprintf(stderr, "Unexpected n block size!\n" ); |
| 2471 | } |
| 2472 | } |
| 2473 | }); |
| 2474 | }); |
| 2475 | return; |
| 2476 | } |
| 2477 | |
| 2478 |     // handle 4 tiles at a time (2x2 AMX tiles per block)
| 2479 | constexpr int BLOCK_M = TILE_M * 2; |
| 2480 | constexpr int BLOCK_N = TILE_N * 2; |
| 2481 | const int MB = div_up(M, BLOCK_M); |
| 2482 | const int NB = div_up(N, BLOCK_N); |
| 2483 | |
| 2484 | parallel_for_ggml(params, MB * NB, [&](int begin, int end) { |
| 2485 | // init tile config for each thread |
| 2486 | ggml_tile_config_init(); |
| 2487 | |
| 2488 | GGML_DISPATCH_QTYPES(TYPE, [&] { |
| 2489 | const int KB = K / blck_size; |
| 2490 | const int TILE_SIZE = get_tile_size<type>(); |
| 2491 | const int row_size_A = KB * sizeof(vec_dot_type); |
| 2492 | |
| 2493 | for (int i = begin; i < end; ++i) { |
| 2494 | int mb = i / NB; |
| 2495 | int nb = i % NB; |
| 2496 | |
| 2497 | int mb_start = mb * BLOCK_M; |
| 2498 | int mb_size = std::min(BLOCK_M, M - mb_start); |
| 2499 | int nb_start = nb * BLOCK_N; |
| 2500 | int nb_size = BLOCK_N; |
| 2501 | |
| 2502 | tinygemm_kernel_amx<vec_dot_type, type, float, blck_size>( |
| 2503 | mb_size, nb_size, KB, |
| 2504 | (const char *)wdata + mb_start * row_size_A, |
| 2505 | (const char *)src0->data + PACKED_INDEX(nb * 2, 0, KB, TILE_SIZE), |
| 2506 | (float *) dst->data + mb_start * N + nb_start, ldc); |
| 2507 | } |
| 2508 | }); |
| 2509 | }); |
| 2510 | } |
| 2511 | |
| 2512 | #endif // if defined(__AMX_INT8__) && defined(__AVX512VNNI__) |
| 2513 | |