#pragma once

#include "ggml.h"
#include "ggml-cpu-impl.h"

#include <algorithm>
#include <memory>
#include <type_traits>

#if defined(GGML_USE_OPENMP)
#include <omp.h>
#endif

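// tile shape of the AMX matmul kernels: a (TILE_M x TILE_K) "A" tile is
// multiplied with a (TILE_K x TILE_N) "B" tile into a (TILE_M x TILE_N)
// accumulator; VNNI_BLK is the number of values packed together along K
// in the VNNI layout of the "B" tile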
#define TILE_M 16
#define TILE_N 16
#define TILE_K 32
#define VNNI_BLK 4

#define AMX_BLK_SIZE 32

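// indices of the eight AMX tile registers (tmm0..tmm7)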
#define TMM0 0
#define TMM1 1
#define TMM2 2
#define TMM3 3
#define TMM4 4
#define TMM5 5
#define TMM6 6
#define TMM7 7

// parallel routines

// integer ceiling division, e.g. div_up(10, 4) == 3
template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
inline T div_up(T x, T y) { return (x + y - 1) / y; }

// split n units of work into the contiguous range [n_start, n_end)
// owned by thread ith out of nth threads
template <typename T>
inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) {
#if 0
    // onednn partition pattern
    T& n_my = n_end;
    if (nth <= 1 || n == 0) {
        n_start = 0;
        n_my = n;
    } else {
        T n1 = div_up(n, nth);
        T n2 = n1 - 1;
        T T1 = n - n2 * nth;
        n_my = ith < T1 ? n1 : n2;
        n_start = ith <= T1 ? ith * n1 : T1 * n1 + (ith - T1) * n2;
    }
    n_end += n_start;
#else
    // pytorch aten partition pattern
    T n_my = div_up(n, nth);
    n_start = ith * n_my;
    n_end = std::min(n_start + n_my, n);
#endif
}
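
// Worked example for the aten pattern, assuming n = 10 and nth = 4:
// each thread owns div_up(10, 4) = 3 units, giving the ranges
// [0,3), [3,6), [6,9), [9,10). When nth * n_my exceeds n, trailing
// threads may get an empty range, so callers should iterate with
// `for (i = n_start; i < n_end; ++i)`.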

// run f(begin, end) over [0, n) partitioned across the OpenMP threads,
// or serially over the whole range when OpenMP is disabled
template <typename func_t>
inline void parallel_for(int n, const func_t& f) {
#if defined(GGML_USE_OPENMP)
#pragma omp parallel
{
    int nth = omp_get_num_threads();
    int ith = omp_get_thread_num();
    int tbegin, tend;
    balance211(n, nth, ith, tbegin, tend);
    f(tbegin, tend);
}
#else
    f(0, n);
#endif
}
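
// usage sketch (hypothetical caller; process_row() is illustrative and
// not part of this header):
//   parallel_for(nrows, [&](int begin, int end) {
//       for (int i = begin; i < end; ++i) process_row(i);
//   });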

// same partitioning as parallel_for, but driven by the ggml scheduler:
// every worker thread calls this with its own ith/nth from params
template <typename func_t>
inline void parallel_for_ggml(const ggml_compute_params * params, int n, const func_t & f) {
    int tbegin, tend;
    balance211(n, params->nth, params->ith, tbegin, tend);
    f(tbegin, tend);
}
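
// usage sketch (hypothetical kernel; pack_row() is illustrative and not
// part of this header):
//   parallel_for_ggml(params, nrows, [&](int begin, int end) {
//       for (int i = begin; i < end; ++i) pack_row(i);
//   });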

// quantized types that have AMX support
inline bool qtype_has_amx_kernels(const enum ggml_type type) {
    // TODO: fix padding for vnni format
    return (type == GGML_TYPE_Q4_0) ||
        (type == GGML_TYPE_Q4_1) ||
        (type == GGML_TYPE_Q8_0) ||
        (type == GGML_TYPE_Q4_K) ||
        (type == GGML_TYPE_Q5_K) ||
        (type == GGML_TYPE_Q6_K) ||
        (type == GGML_TYPE_IQ4_XS);
}
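
// usage sketch (hypothetical caller): gate the AMX path on type support
// and fall back to the generic CPU kernels otherwise
//   if (qtype_has_amx_kernels(tensor->type)) {
//       // repack the weights into the AMX/VNNI blocked layout
//   }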