1#include "common.cuh"
2
// Dequantizes the two 4-bit quants packed into entry `iqs` of q4_0 block `ib`.
//
//   vx  - base pointer, reinterpreted as an array of block_q4_0
//   ib  - index of the block inside that array
//   iqs - index into the block's qs array
//   v   - output pair: v.x comes from the low 4 bits of qs[iqs],
//         v.y from the bits above them
//
// Each output is (q - 8) * d, where d is the scale stored in the block.
static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
    const block_q4_0 & blk = ((const block_q4_0 *) vx)[ib];

    const float scale  = blk.d;
    const int   packed = blk.qs[iqs];

    // Split the packed entry into its two nibbles before touching the output.
    const float lo = packed & 0xF;
    const float hi = packed >> 4;

    // Re-center around zero (offset of 8) and apply the block scale.
    v.x = (lo - 8.0f) * scale;
    v.y = (hi - 8.0f) * scale;
}
16
// Dequantizes the two 4-bit quants packed into entry `iqs` of q4_1 block `ib`.
//
//   vx  - base pointer, reinterpreted as an array of block_q4_1
//   ib  - index of the block inside that array
//   iqs - index into the block's qs array
//   v   - output pair: v.x comes from the low 4 bits of qs[iqs],
//         v.y from the bits above them
//
// The block's dm pair is widened to float2; each output is q * dm.x + dm.y,
// i.e. dm.x acts as the scale and dm.y as the additive offset.
static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int64_t ib, const int iqs, float2 & v){
    const block_q4_1 & blk = ((const block_q4_1 *) vx)[ib];

    const float2 scale_off = __half22float2(blk.dm);
    const int    packed    = blk.qs[iqs];

    // Split the packed entry into its two nibbles.
    const float lo = packed & 0xF;
    const float hi = packed >> 4;

    v.x = (lo * scale_off.x) + scale_off.y;
    v.y = (hi * scale_off.x) + scale_off.y;
}
30
// Dequantizes the two 5-bit quants addressed by `iqs` in q5_0 block `ib`.
//
//   vx  - base pointer, reinterpreted as an array of block_q5_0
//   ib  - index of the block inside that array
//   iqs - index into the block's qs array; also selects which bits of qh
//         supply the fifth (high) bit of each quant
//   v   - output pair: v.x is built from the low nibble of qs[iqs],
//         v.y from the bits above it
//
// The low 4 bits of each quant live in qs; the fifth bit is taken from the
// 32-bit word copied out of qh: bit `iqs` for v.x and bit `iqs + 16` for v.y.
// Each output is (q - 16) * d, where d is the scale stored in the block.
// NOTE(review): presumably iqs is in [0, 16) so both shifts stay below 32 - confirm against callers.
static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
    const block_q5_0 & blk = ((const block_q5_0 *) vx)[ib];

    const float scale = blk.d;

    // Copy the high-bit storage into one 32-bit word via memcpy.
    uint32_t high_bits;
    memcpy(&high_bits, blk.qh, sizeof(high_bits));

    // Move the relevant qh bit into position 4 (value 0x10) for each quant.
    const int fifth_lo = ((high_bits >> iqs) & 1) << 4;
    const int fifth_hi = (high_bits >> (iqs + 12)) & 0x10;

    const auto packed = blk.qs[iqs];

    const float lo = (packed & 0xf) | fifth_lo;
    const float hi = (packed >>  4) | fifth_hi;

    // Re-center around zero (offset of 16) and apply the block scale.
    v.x = (lo - 16.0f) * scale;
    v.y = (hi - 16.0f) * scale;
}
48
// Dequantizes the two 5-bit quants addressed by `iqs` in q5_1 block `ib`.
//
//   vx  - base pointer, reinterpreted as an array of block_q5_1
//   ib  - index of the block inside that array
//   iqs - index into the block's qs array; also selects which bits of qh
//         supply the fifth (high) bit of each quant
//   v   - output pair: v.x is built from the low nibble of qs[iqs],
//         v.y from the bits above it
//
// The low 4 bits of each quant live in qs; the fifth bit is taken from the
// 32-bit word copied out of qh: bit `iqs` for v.x and bit `iqs + 16` for v.y.
// The block's dm pair is widened to float2; each output is q * dm.x + dm.y.
// NOTE(review): presumably iqs is in [0, 16) so both shifts stay below 32 - confirm against callers.
static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int64_t ib, const int iqs, float2 & v){
    const block_q5_1 & blk = ((const block_q5_1 *) vx)[ib];

    const float2 scale_off = __half22float2(blk.dm);

    // Copy the high-bit storage into one 32-bit word via memcpy.
    uint32_t high_bits;
    memcpy(&high_bits, blk.qh, sizeof(high_bits));

    // Move the relevant qh bit into position 4 (value 0x10) for each quant.
    const int fifth_lo = ((high_bits >> iqs) & 1) << 4;
    const int fifth_hi = (high_bits >> (iqs + 12)) & 0x10;

    const auto packed = blk.qs[iqs];

    const float lo = (packed & 0xf) | fifth_lo;
    const float hi = (packed >>  4) | fifth_hi;

    v.x = (lo * scale_off.x) + scale_off.y;
    v.y = (hi * scale_off.x) + scale_off.y;
}
66
// Dequantizes two consecutive quants, qs[iqs] and qs[iqs + 1], of q8_0 block `ib`.
//
//   vx  - base pointer, reinterpreted as an array of block_q8_0
//   ib  - index of the block inside that array
//   iqs - index of the first of the two quants within the block's qs array
//   v   - output pair: v.x = qs[iqs] * d, v.y = qs[iqs + 1] * d
//
// d is the scale stored in the block; no offset is applied.
static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
    const block_q8_0 & blk = ((const block_q8_0 *) vx)[ib];

    const float scale = blk.d;

    // Widen both quants to float first, then scale.
    const float q0 = blk.qs[iqs + 0];
    const float q1 = blk.qs[iqs + 1];

    v.x = q0 * scale;
    v.y = q1 * scale;
}
78