| 1 | #include <cstdio> |
| 2 | #include <type_traits> |
| 3 | #include <vector> |
| 4 | #include <random> |
| 5 | #include <chrono> |
| 6 | #include <cstdlib> |
| 7 | #include <cmath> |
| 8 | #include <cassert> |
| 9 | #include <cstring> |
| 10 | #include <array> |
| 11 | #include <type_traits> |
| 12 | |
| 13 | #include <ggml.h> |
| 14 | #include <ggml-cpu.h> |
| 15 | |
// Number of quantized blocks in each benchmark vector (2^16).
constexpr int kVecSize = 65536;
| 17 | |
// Local mirror of the q4_0 block (originally copied from ggml.c): QK4_0 4-bit quants
// packed two per byte, plus one float scale.
// NOTE(review): main hands buffers of this type straight to ggml's vec_dot, so this
// layout must stay byte-identical to ggml's own q4_0 definition — confirm it has not
// drifted from the linked ggml version.
#define QK4_0 32
struct block_q4_0 {
    float   d;              // delta (scale)
    uint8_t qs[QK4_0 / 2];  // nibbles / quants
};
static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");
| 25 | |
// Local mirror of the q4_1 block: QK4_1 4-bit quants packed two per byte, plus a
// float scale (d) and a float minimum (m).
// NOTE(review): main hands buffers of this type straight to ggml's vec_dot, so this
// layout must stay byte-identical to ggml's own q4_1 definition — confirm it has not
// drifted from the linked ggml version.
#define QK4_1 32
struct block_q4_1 {
    float   d;              // delta (scale)
    float   m;              // min
    uint8_t qs[QK4_1 / 2];  // nibbles / quants
};
static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
| 33 | |
// Local mirror of the q8_0 block (originally copied from ggml.c): QK8_0 signed 8-bit
// quants, a float scale d, and a cached s = d * sum(qs) that lets the 4-bit dot
// products apply their offset/minimum without re-summing the quants.
// NOTE(review): main passes buffers of this type to ggml's vec_dot as the second
// operand; confirm this layout still matches the type ggml's kernel expects there.
#define QK8_0 32
struct block_q8_0 {
    float  d;          // delta
    float  s;          // d * sum(qs[i])
    int8_t qs[QK8_0];  // quants
};
static_assert(sizeof(block_q8_0) == 2*sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
| 42 | |
| 43 | static_assert(QK4_1 == QK8_0, "QK4_1 and QK8_0 must be the same" ); |
| 44 | static_assert(QK4_0 == QK8_0, "QK4_0 and QK8_0 must be the same" ); |
| 45 | |
// Fills every 4-bit block in `blocks` with random data.
//
// The scale `d` is set to 1, and every byte of `qs` receives two random nibbles, with
// the low nibble drawn first. Other fields (e.g. a q4_1 minimum) are left untouched.
// `rndm() >> 28` keeps the top 4 bits of a 32-bit mt19937 draw, i.e. a value in [0, 15].
//
// The byte count comes from the block type itself instead of a hard-coded block-size
// constant, so the template is correct for whatever block type it is instantiated with.
template <typename T>
static void fillQ4blocks(std::vector<T>& blocks, std::mt19937& rndm) {
    for (auto& b : blocks) {
        b.d = 1;
        const size_t nbytes = sizeof(b.qs) / sizeof(b.qs[0]);
        for (size_t i = 0; i < nbytes; ++i) {
            // Separate statements fix the draw order (low nibble first); inside one
            // expression the evaluation order of the two rndm() calls is unspecified.
            const uint8_t lo = rndm() >> 28;
            const uint8_t hi = rndm() >> 28;
            b.qs[i] = lo | (hi << 4);
        }
    }
}
| 57 | |
| 58 | static void fillQ80blocks(std::vector<block_q8_0>& blocks, std::mt19937& rndm) { |
| 59 | for (auto& b : blocks) { |
| 60 | b.d = 1; |
| 61 | int sum = 0; |
| 62 | for (int i=0; i<QK8_0; ++i) { |
| 63 | b.qs[i] = (rndm() >> 24) - 128; |
| 64 | sum += b.qs[i]; |
| 65 | } |
| 66 | b.s = b.d * sum; |
| 67 | } |
| 68 | } |
| 69 | |
| 70 | static float simpleDot(const block_q4_0& x, const block_q8_0& y) { |
| 71 | int s1 = 0; //, s2 = 0; |
| 72 | for (int i=0; i<QK4_1/2; i+=2) { |
| 73 | int v1 = x.qs[i+0] & 0xf; |
| 74 | int v2 = x.qs[i+0] >> 4; |
| 75 | int v3 = x.qs[i+1] & 0xf; |
| 76 | int v4 = x.qs[i+1] >> 4; |
| 77 | int j = 2*i; |
| 78 | s1 += v1*y.qs[j] + v2*y.qs[j+1] + v3*y.qs[j+2] + v4*y.qs[j+3]; |
| 79 | //s2 += y.qs[j] + y.qs[j+1] + y.qs[j+2] + y.qs[j+3]; |
| 80 | } |
| 81 | return y.d * x.d * s1 - 8 * x.d * y.s; |
| 82 | //return y.d * x.d * (s1 - 8 * s2); |
| 83 | } |
| 84 | |
| 85 | static float simpleDot(const block_q4_1& x, const block_q8_0& y) { |
| 86 | int s1 = 0; //, s2 = 0; |
| 87 | for (int i=0; i<QK4_1/2; i+=2) { |
| 88 | int v1 = x.qs[i+0] & 0xf; |
| 89 | int v2 = x.qs[i+0] >> 4; |
| 90 | int v3 = x.qs[i+1] & 0xf; |
| 91 | int v4 = x.qs[i+1] >> 4; |
| 92 | int j = 2*i; |
| 93 | s1 += v1*y.qs[j] + v2*y.qs[j+1] + v3*y.qs[j+2] + v4*y.qs[j+3]; |
| 94 | //s2 += y.qs[j] + y.qs[j+1] + y.qs[j+2] + y.qs[j+3]; |
| 95 | } |
| 96 | return y.d * x.d * s1 + y.s * x.m; |
| 97 | //return y.d * (x.d * s1 + x.m * s2); |
| 98 | } |
| 99 | |
// Running statistics for one benchmark variant.
//
// It accumulates the dot-product results and per-iteration timings, then prints the
// mean result, the mean and standard deviation of the timings, and the largest timing.
struct Stat {
    double sum = 0;    // sum of the recorded dot-product results
    double sumt = 0;   // sum of the recorded timings
    double sumt2 = 0;  // sum of the squared timings
    double maxt = 0;   // largest timing recorded so far
    int nloop = 0;     // number of recorded iterations

    // Records one iteration: dot-product result `s` and elapsed time `t`.
    void addResult(double s, double t) {
        sum += s;
        sumt += t;
        sumt2 += t*t;
        if (t > maxt) maxt = t;
        ++nloop;
    }

    // Mean of the recorded timings; 0 when nothing has been recorded.
    double meanTime() const { return nloop > 0 ? sumt/nloop : 0; }

    // Population standard deviation of the recorded timings; 0 when nothing has been
    // recorded. E[t^2] - E[t]^2 can come out slightly negative through floating-point
    // cancellation (e.g. near-identical timings), so the variance is clamped to 0
    // instead of being printed as a negative spread.
    double stddevTime() const {
        if (nloop < 1) return 0;
        const double mean = sumt/nloop;
        const double var = sumt2/nloop - mean*mean;
        return var > 0 ? std::sqrt(var) : 0;
    }

    // Prints a summary headed by `title`, or a "no result" line if nothing was recorded.
    void reportResult(const char* title) const {
        if (nloop < 1) {
            printf("%s(%s): no result\n", __func__, title);
            return;
        }
        printf("============ %s\n", title);
        printf("<dot> = %g\n", sum/nloop);
        printf("<time> = %g +/- %g us. Max. time = %g us.\n", meanTime(), stddevTime(), maxt);
    }
};
| 120 | |
| 121 | |
| 122 | int main(int argc, char** argv) { |
| 123 | |
| 124 | int nloop = argc > 1 ? atoi(nptr: argv[1]) : 10; |
| 125 | int type = argc > 2 ? atoi(nptr: argv[2]) : 1; |
| 126 | |
| 127 | std::mt19937 rndm(1234); |
| 128 | |
| 129 | std::vector<block_q4_1> x41; |
| 130 | std::vector<block_q4_0> x40; |
| 131 | std::vector<block_q8_0> y(kVecSize); |
| 132 | if (type == 0) x40.resize(new_size: kVecSize); |
| 133 | else { |
| 134 | x41.resize(new_size: kVecSize); |
| 135 | for (auto& b : x41) b.m = 1; |
| 136 | } |
| 137 | |
| 138 | auto ggml_type = type == 0 ? GGML_TYPE_Q4_0 : GGML_TYPE_Q4_1; |
| 139 | |
| 140 | const auto * funcs = ggml_get_type_traits_cpu(type: ggml_type); |
| 141 | |
| 142 | Stat simple, ggml; |
| 143 | |
| 144 | for (int iloop=0; iloop<nloop; ++iloop) { |
| 145 | |
| 146 | if (type == 0) fillQ4blocks(blocks&: x40, rndm); |
| 147 | else fillQ4blocks(blocks&: x41, rndm); |
| 148 | fillQ80blocks(blocks&: y, rndm); |
| 149 | |
| 150 | auto t1 = std::chrono::high_resolution_clock::now(); |
| 151 | double s = 0; |
| 152 | if (type == 0) for (int i=0; i<kVecSize; ++i) s += simpleDot(x: x40[i], y: y[i]); |
| 153 | else for (int i=0; i<kVecSize; ++i) s += simpleDot(x: x41[i], y: y[i]); |
| 154 | auto t2 = std::chrono::high_resolution_clock::now(); |
| 155 | auto t = 1e-3*std::chrono::duration_cast<std::chrono::nanoseconds>(d: t2-t1).count(); |
| 156 | if (iloop > 3) simple.addResult(s, t); |
| 157 | |
| 158 | t1 = std::chrono::high_resolution_clock::now(); |
| 159 | float fs; |
| 160 | if (type == 0) funcs->vec_dot(kVecSize * QK4_1, &fs, 0, x40.data(), 0, y.data(), 0, 1); |
| 161 | else funcs->vec_dot(kVecSize * QK4_1, &fs, 0, x41.data(), 0, y.data(), 0, 1); |
| 162 | t2 = std::chrono::high_resolution_clock::now(); |
| 163 | t = 1e-3*std::chrono::duration_cast<std::chrono::nanoseconds>(d: t2-t1).count(); |
| 164 | if (iloop > 3) ggml.addResult(s: fs, t); |
| 165 | |
| 166 | } |
| 167 | |
| 168 | // Report the time (and the average of the dot products so the compiler does not come up with the idea |
| 169 | // of optimizing away the function calls after figuring that the result is not used). |
| 170 | simple.reportResult(title: "Simple" ); |
| 171 | ggml.reportResult(title: "ggml" ); |
| 172 | return 0; |
| 173 | } |
| 174 | |