1#include <cstdio>
2#include <type_traits>
3#include <vector>
4#include <random>
5#include <chrono>
6#include <cstdlib>
7#include <cmath>
8#include <cassert>
9#include <cstring>
10#include <array>
11#include <type_traits>
12
13#include <ggml.h>
14#include <ggml-cpu.h>
15
// Number of quantized blocks in each benchmark vector (2^16).
constexpr int kVecSize = 65536;
17
// Q4_0 block: QK4_0 values packed two per byte as 4-bit nibbles, plus one
// float delta. Originally copy-pasted from ggml.c.
// NOTE(review): arrays of this struct are passed to ggml's vec_dot in main(),
// so the layout presumably has to match the linked ggml build - confirm.
#define QK4_0 32
struct block_q4_0 {
    float   d;              // delta
    uint8_t qs[QK4_0 / 2];  // nibbles / quants
};
static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");
25
// Q4_1 block: QK4_1 values packed two per byte as 4-bit nibbles, plus a float
// delta and a float minimum.
// NOTE(review): arrays of this struct are passed to ggml's vec_dot in main(),
// so the layout presumably has to match the linked ggml build - confirm.
#define QK4_1 32
struct block_q4_1 {
    float   d;              // delta
    float   m;              // min
    uint8_t qs[QK4_1 / 2];  // nibbles / quants
};
static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
33
// Q8_0 block: QK8_0 signed 8-bit quants, a float delta, and the cached value
// d * sum(qs[i]). Originally copy-pasted from ggml.c.
// NOTE(review): arrays of this struct are passed to ggml's vec_dot in main(),
// so the layout presumably has to match the linked ggml build - confirm.
#define QK8_0 32
struct block_q8_0 {
    float  d;          // delta
    float  s;          // d * sum(qs[i])
    int8_t qs[QK8_0];  // quants
};
static_assert(sizeof(block_q8_0) == 2*sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
42
43static_assert(QK4_1 == QK8_0, "QK4_1 and QK8_0 must be the same");
44static_assert(QK4_0 == QK8_0, "QK4_0 and QK8_0 must be the same");
45
// Fills every block in `blocks` with a unit delta and random 4-bit quants.
//
// T is any block type with an assignable `d` member and a byte array `qs`
// (block_q4_0 and block_q4_1 in this file). Each qs byte packs two values:
// the low nibble comes from the top four bits of one rndm() draw, the high
// nibble from the top four bits of the next draw. Members other than `d` and
// `qs` (for example block_q4_1::m) are left untouched.
//
// blocks - blocks to overwrite in place
// rndm   - random engine; advanced by two draws per qs byte
template <typename T>
static void fillQ4blocks(std::vector<T>& blocks, std::mt19937& rndm) {
    for (auto& b : blocks) {
        b.d = 1;
        // Take the byte count from the block itself instead of the QK4_1
        // macro, so the loop is right for whichever block type T is.
        const int nbytes = static_cast<int>(sizeof(b.qs) / sizeof(b.qs[0]));
        for (int i = 0; i < nbytes; ++i) {
            const uint8_t lo = static_cast<uint8_t>(rndm() >> 28);
            const uint8_t hi = static_cast<uint8_t>(rndm() >> 28);
            b.qs[i] = static_cast<uint8_t>(lo | (hi << 4));
        }
    }
}
57
58static void fillQ80blocks(std::vector<block_q8_0>& blocks, std::mt19937& rndm) {
59 for (auto& b : blocks) {
60 b.d = 1;
61 int sum = 0;
62 for (int i=0; i<QK8_0; ++i) {
63 b.qs[i] = (rndm() >> 24) - 128;
64 sum += b.qs[i];
65 }
66 b.s = b.d * sum;
67 }
68}
69
70static float simpleDot(const block_q4_0& x, const block_q8_0& y) {
71 int s1 = 0; //, s2 = 0;
72 for (int i=0; i<QK4_1/2; i+=2) {
73 int v1 = x.qs[i+0] & 0xf;
74 int v2 = x.qs[i+0] >> 4;
75 int v3 = x.qs[i+1] & 0xf;
76 int v4 = x.qs[i+1] >> 4;
77 int j = 2*i;
78 s1 += v1*y.qs[j] + v2*y.qs[j+1] + v3*y.qs[j+2] + v4*y.qs[j+3];
79 //s2 += y.qs[j] + y.qs[j+1] + y.qs[j+2] + y.qs[j+3];
80 }
81 return y.d * x.d * s1 - 8 * x.d * y.s;
82 //return y.d * x.d * (s1 - 8 * s2);
83}
84
85static float simpleDot(const block_q4_1& x, const block_q8_0& y) {
86 int s1 = 0; //, s2 = 0;
87 for (int i=0; i<QK4_1/2; i+=2) {
88 int v1 = x.qs[i+0] & 0xf;
89 int v2 = x.qs[i+0] >> 4;
90 int v3 = x.qs[i+1] & 0xf;
91 int v4 = x.qs[i+1] >> 4;
92 int j = 2*i;
93 s1 += v1*y.qs[j] + v2*y.qs[j+1] + v3*y.qs[j+2] + v4*y.qs[j+3];
94 //s2 += y.qs[j] + y.qs[j+1] + y.qs[j+2] + y.qs[j+3];
95 }
96 return y.d * x.d * s1 + y.s * x.m;
97 //return y.d * (x.d * s1 + x.m * s2);
98}
99
// Running statistics for one benchmarked implementation: the sum of the
// computed dot products plus the sum, sum of squares and maximum of the
// per-iteration timings.
struct Stat {
    double sum = 0, sumt = 0, sumt2 = 0, maxt = 0;
    int nloop = 0;

    // Records one iteration.
    //   s - dot-product result of the iteration
    //   t - time the iteration took (reported as microseconds)
    void addResult(double s, double t) {
        sum += s;
        sumt += t;
        sumt2 += t*t;
        if (t > maxt) maxt = t;
        ++nloop;
    }

    // Prints the mean dot product, the mean time with its standard deviation,
    // and the maximum time under the heading `title`. Prints a "no result"
    // line instead when nothing has been recorded.
    void reportResult(const char* title) const {
        if (nloop < 1) {
            printf("%s(%s): no result\n", __func__, title);
            return;
        }
        printf("============ %s\n", title);
        printf("<dot> = %g\n", sum/nloop);
        const double mean = sumt/nloop;
        // E[t^2] - E[t]^2 can come out slightly negative through rounding;
        // report a spread of 0 in that case rather than a negative number.
        const double variance = sumt2/nloop - mean*mean;
        const double sigma = variance > 0 ? std::sqrt(variance) : 0.0;
        printf("<time> = %g +/- %g us. Max. time = %g us.\n", mean, sigma, maxt);
    }
};
120
121
122int main(int argc, char** argv) {
123
124 int nloop = argc > 1 ? atoi(nptr: argv[1]) : 10;
125 int type = argc > 2 ? atoi(nptr: argv[2]) : 1;
126
127 std::mt19937 rndm(1234);
128
129 std::vector<block_q4_1> x41;
130 std::vector<block_q4_0> x40;
131 std::vector<block_q8_0> y(kVecSize);
132 if (type == 0) x40.resize(new_size: kVecSize);
133 else {
134 x41.resize(new_size: kVecSize);
135 for (auto& b : x41) b.m = 1;
136 }
137
138 auto ggml_type = type == 0 ? GGML_TYPE_Q4_0 : GGML_TYPE_Q4_1;
139
140 const auto * funcs = ggml_get_type_traits_cpu(type: ggml_type);
141
142 Stat simple, ggml;
143
144 for (int iloop=0; iloop<nloop; ++iloop) {
145
146 if (type == 0) fillQ4blocks(blocks&: x40, rndm);
147 else fillQ4blocks(blocks&: x41, rndm);
148 fillQ80blocks(blocks&: y, rndm);
149
150 auto t1 = std::chrono::high_resolution_clock::now();
151 double s = 0;
152 if (type == 0) for (int i=0; i<kVecSize; ++i) s += simpleDot(x: x40[i], y: y[i]);
153 else for (int i=0; i<kVecSize; ++i) s += simpleDot(x: x41[i], y: y[i]);
154 auto t2 = std::chrono::high_resolution_clock::now();
155 auto t = 1e-3*std::chrono::duration_cast<std::chrono::nanoseconds>(d: t2-t1).count();
156 if (iloop > 3) simple.addResult(s, t);
157
158 t1 = std::chrono::high_resolution_clock::now();
159 float fs;
160 if (type == 0) funcs->vec_dot(kVecSize * QK4_1, &fs, 0, x40.data(), 0, y.data(), 0, 1);
161 else funcs->vec_dot(kVecSize * QK4_1, &fs, 0, x41.data(), 0, y.data(), 0, 1);
162 t2 = std::chrono::high_resolution_clock::now();
163 t = 1e-3*std::chrono::duration_cast<std::chrono::nanoseconds>(d: t2-t1).count();
164 if (iloop > 3) ggml.addResult(s: fs, t);
165
166 }
167
168 // Report the time (and the average of the dot products so the compiler does not come up with the idea
169 // of optimizing away the function calls after figuring that the result is not used).
170 simple.reportResult(title: "Simple");
171 ggml.reportResult(title: "ggml");
172 return 0;
173}
174