#include "amx.h"
#include "common.h"
#include "mmq.h"
#include "ggml-backend-impl.h"
#include "ggml-backend.h"
#include "ggml-impl.h"
#include "ggml-cpu.h"
#include "traits.h"

#if defined(__linux__)
#include <sys/syscall.h>
#include <unistd.h>
#endif

#include <cstdlib>
#include <cstring>
#include <memory>

#if defined(__AMX_INT8__) && defined(__AVX512VNNI__)

// AMX type_traits
namespace ggml::cpu::amx {
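// tensor_traits hooks AMX into the CPU backend's op dispatch: work_size()
// reports the scratch space the AMX GEMM needs, and compute_forward() routes
// MUL_MAT on tensors stored in this buffer type to the AMX kernel.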
class tensor_traits : public ggml::cpu::tensor_traits {
    bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
        size = ggml_backend_amx_desired_wsize(op);
        return true;
    }

    bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override {
        if (op->op == GGML_OP_MUL_MAT) {
            ggml_backend_amx_mul_mat(params, op);
            return true;
        }
        return false;
    }
};
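
// a single stateless traits instance is shared by every tensor stored in an AMX buffer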
static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struct ggml_tensor *) {
    static tensor_traits traits;
    return &traits;
}
} // namespace ggml::cpu::amx

// AMX buffer interface
static void ggml_backend_amx_buffer_free_buffer(ggml_backend_buffer_t buffer) {
    // the buffer was allocated with ggml_aligned_malloc, so it must be released
    // with the matching ggml_aligned_free (plain free() is not valid for
    // aligned allocations on all platforms, e.g. Windows)
    ggml_aligned_free(buffer->context, buffer->size);
}

static void * ggml_backend_amx_buffer_get_base(ggml_backend_buffer_t buffer) {
    return (void *) (buffer->context);
}

static enum ggml_status ggml_backend_amx_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
    tensor->extra = (void *) ggml::cpu::amx::get_tensor_traits(buffer, tensor);

    GGML_UNUSED(buffer);
    return GGML_STATUS_SUCCESS;
}

static void ggml_backend_amx_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
                                                  uint8_t value, size_t offset, size_t size) {
    memset((char *) tensor->data + offset, value, size);

    GGML_UNUSED(buffer);
}
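
// on upload, weight types that have AMX kernels are repacked from the generic
// ggml layout into the tiled layout expected by the AMX GEMM; all other data
// is copied verbatim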
static void ggml_backend_amx_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
                                               const void * data, size_t offset, size_t size) {
    if (qtype_has_amx_kernels(tensor->type)) {
        GGML_LOG_DEBUG("%s: amx repack tensor %s of type %s\n", __func__, tensor->name, ggml_type_name(tensor->type));
        ggml_backend_amx_convert_weight(tensor, data, offset, size);
    } else {
        memcpy((char *) tensor->data + offset, data, size);
    }

    GGML_UNUSED(buffer);
}

/*
// need to figure out what to do with buffer->extra
static void ggml_backend_amx_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
    GGML_ASSERT(!qtype_has_amx_kernels(tensor->type));
    memcpy(data, (const char *) tensor->data + offset, size);

    GGML_UNUSED(buffer);
}

static bool ggml_backend_amx_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
    if (ggml_backend_buffer_is_host(src->buffer)) {
        if (qtype_has_amx_kernels(src->type)) {
            ggml_backend_amx_convert_weight(dst, src->data, 0, ggml_nbytes(dst));
        } else {
            memcpy(dst->data, src->data, ggml_nbytes(src));
        }
        return true;
    }
    return false;

    GGML_UNUSED(buffer);
}
*/

static void ggml_backend_amx_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
    memset(buffer->context, value, buffer->size);
}

static ggml_backend_buffer_i ggml_backend_amx_buffer_interface = {
    /* .free_buffer     = */ ggml_backend_amx_buffer_free_buffer,
    /* .get_base        = */ ggml_backend_amx_buffer_get_base,
    /* .init_tensor     = */ ggml_backend_amx_buffer_init_tensor,
    /* .memset_tensor   = */ ggml_backend_amx_buffer_memset_tensor,
    /* .set_tensor      = */ ggml_backend_amx_buffer_set_tensor,
    /* .get_tensor      = */ nullptr,
    /* .cpy_tensor      = */ nullptr,
    /* .clear           = */ ggml_backend_amx_buffer_clear,
    /* .reset           = */ nullptr,
};
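
// get_tensor and cpy_tensor are left as nullptr for now: once a weight has
// been repacked, its bytes no longer match the generic ggml layout, so a
// byte-for-byte read-back would be wrong (candidate implementations are kept
// in the commented-out block above)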

static const char * ggml_backend_amx_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
    return "AMX";

    GGML_UNUSED(buft);
}

static ggml_backend_buffer_t ggml_backend_amx_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
    void * data = ggml_aligned_malloc(size);
    if (data == NULL) {
        fprintf(stderr, "%s: failed to allocate buffer of size %zu\n", __func__, size);
        return NULL;
    }

    return ggml_backend_buffer_init(buft, ggml_backend_amx_buffer_interface, data, size);
}

static size_t ggml_backend_amx_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
    return TENSOR_ALIGNMENT;

    GGML_UNUSED(buft);
}

namespace ggml::cpu::amx {
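// extra_buffer_type advertises which ops the AMX path can take over: only 2D,
// contiguous MUL_MAT where src0 lives in an AMX buffer, the K/N dimensions
// match the tile kernels' blocking, and src1 is F32 in a host buffer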
class extra_buffer_type : ggml::cpu::extra_buffer_type {
    bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
        // handle only 2D GEMM for now
        auto is_contiguous_2d = [](const struct ggml_tensor * t) {
            return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1;
        };

        if (op->op == GGML_OP_MUL_MAT && is_contiguous_2d(op->src[0]) &&  // src0 must be contiguous
            is_contiguous_2d(op->src[1]) &&                               // src1 must be contiguous
            op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_amx_buffer_type() &&
            op->src[0]->ne[0] % (TILE_K * 2 * 32) == 0 &&                 // TODO: not sure if correct (https://github.com/ggml-org/llama.cpp/pull/16315)
            op->ne[0] % (TILE_N * 2) == 0 &&                              // out_features must be a multiple of TILE_N * 2 (= 32)
            (qtype_has_amx_kernels(op->src[0]->type) || (op->src[0]->type == GGML_TYPE_F16))) {
            // src1 must be host buffer
            if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
                return false;
            }
            // src1 must be float32
            if (op->src[1]->type == GGML_TYPE_F32) {
                return true;
            }
        }
        return false;
    }

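    // returns the traits object that init_tensor stashed in tensor->extra, so
    // the CPU backend can forward work_size/compute_forward to it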
    ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
        if (op->op == GGML_OP_MUL_MAT && op->src[0]->buffer &&
            op->src[0]->buffer->buft == ggml_backend_amx_buffer_type()) {
            return (ggml::cpu::tensor_traits *) op->src[0]->extra;
        }

        return nullptr;
    }
};
} // namespace ggml::cpu::amx

static size_t ggml_backend_amx_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
    return ggml_backend_amx_get_alloc_size(tensor);

    GGML_UNUSED(buft);
}

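// these values mirror the Linux arch_prctl interface: the ARCH_* requests come
// from asm/prctl.h and the XFEATURE_* numbers are the x86 XSAVE feature bits
// for the AMX tile config/data state; presumably they are defined here so the
// file also builds against older headers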
#define ARCH_GET_XCOMP_PERM     0x1022
#define ARCH_REQ_XCOMP_PERM     0x1023
#define XFEATURE_XTILECFG       17
#define XFEATURE_XTILEDATA      18

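// on Linux, AMX tile state is opt-in: a process must request permission to use
// the (large) XTILEDATA state from the kernel before executing tile
// instructions; on Windows no explicit request is assumed to be needed, and
// other platforms are treated as unsupported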
static bool ggml_amx_init() {
#if defined(__linux__)
    if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)) {
        fprintf(stderr, "AMX is not ready to be used!\n");
        return false;
    }
    return true;
#elif defined(_WIN32)
    return true;
#else
    return false;
#endif
}

ggml_backend_buffer_type_t ggml_backend_amx_buffer_type() {
    static struct ggml_backend_buffer_type ggml_backend_buffer_type_amx = {
        /* .iface = */ {
            /* .get_name         = */ ggml_backend_amx_buffer_type_get_name,
            /* .alloc_buffer     = */ ggml_backend_amx_buffer_type_alloc_buffer,
            /* .get_alignment    = */ ggml_backend_amx_buffer_type_get_alignment,
            /* .get_max_size     = */ nullptr, // defaults to SIZE_MAX
            /* .get_alloc_size   = */ ggml_backend_amx_buffer_type_get_alloc_size,
            /* .is_host          = */ nullptr,
        },
        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
        /* .context = */ new ggml::cpu::amx::extra_buffer_type(),
    };

    if (!ggml_amx_init()) {
        return nullptr;
    }

    return &ggml_backend_buffer_type_amx;
}
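
// Usage sketch (illustrative only; `nbytes` is a placeholder for the required size):
//
//   ggml_backend_buffer_type_t buft = ggml_backend_amx_buffer_type();
//   if (buft != nullptr) {
//       // weights uploaded into this buffer get repacked for the AMX kernels
//       ggml_backend_buffer_t buf = ggml_backend_buft_alloc_buffer(buft, nbytes);
//   }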

#endif // defined(__AMX_INT8__) && defined(__AVX512VNNI__)