1#pragma once
2
3#include "common.cuh"
4#include "mmq.cuh"
5
6#include <cstdint>
7
// Block-size constants for the two q8_1 quantization paths declared below.
// NOTE(review): by name these look like threads-per-block for the row-wise and
// MMQ quantization kernels respectively -- confirm against the launch sites.
#define CUDA_QUANTIZE_BLOCK_SIZE     256
#define CUDA_QUANTIZE_BLOCK_SIZE_MMQ 128

// MATRIX_ROW_PADDING (provided by an included header) must be an exact multiple
// of the span each quantization block covers; otherwise the assertion messages
// warn of out-of-bounds access. The MMQ span is 4x its block size -- presumably
// each MMQ-quantize thread handles 4 values (TODO confirm in the kernel).
static_assert((MATRIX_ROW_PADDING % CUDA_QUANTIZE_BLOCK_SIZE) == 0,
              "Risk of out-of-bounds access.");
static_assert((MATRIX_ROW_PADDING % (4 * CUDA_QUANTIZE_BLOCK_SIZE_MMQ)) == 0,
              "Risk of out-of-bounds access.");
13
// Pointer-to-function type for the q8_1 quantization entry points.
// quantize_row_q8_1_cuda and quantize_mmq_q8_1_cuda in this header share exactly
// this signature, so either can be stored in / called through this type.
// Parameter notes (hedged -- semantics live in the implementation):
//   x         : float input data
//   ids       : int32 row-index array; TODO confirm whether nullptr is allowed
//   vy        : untyped destination buffer for the quantized output
//   type_src0 : a ggml_type; NOTE(review): presumably the other operand's type
//   ne00      : presumably the source row length (ggml "ne" naming)
//   s01..s03  : source strides for dims 1..3 (units presumably elements)
//   ne0..ne3  : destination extents (ne0 presumably the padded row length)
//   stream    : CUDA stream handed to the implementation
using quantize_cuda_t = void (*)(
    const float * x, const int32_t * ids, void * vy,
    ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03,
    int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream);
18
// Row-wise q8_1 quantization entry point; its signature matches quantize_cuda_t.
// NOTE(review): by name, this converts the float rows of `x` into q8_1 data in
// `vy` using `stream`; the definition is not in this header -- confirm there how
// `ids` and `type_src0` are consumed.
void quantize_row_q8_1_cuda(
    const float   * x,
    const int32_t * ids,
    void          * vy,
    ggml_type       type_src0,
    int64_t         ne00,
    int64_t         s01,
    int64_t         s02,
    int64_t         s03,
    int64_t         ne0,
    int64_t         ne1,
    int64_t         ne2,
    int64_t         ne3,
    cudaStream_t    stream);
23
// MMQ-path q8_1 quantization entry point; its signature matches quantize_cuda_t.
// NOTE(review): by name, this prepares q8_1 data in `vy` for the MMQ
// matrix-multiply path, presumably in an MMQ-specific layout that may depend on
// `type_src0` -- confirm against the definition, which is not in this header.
void quantize_mmq_q8_1_cuda(
    const float   * x,
    const int32_t * ids,
    void          * vy,
    ggml_type       type_src0,
    int64_t         ne00,
    int64_t         s01,
    int64_t         s02,
    int64_t         s03,
    int64_t         ne0,
    int64_t         ne1,
    int64_t         ne2,
    int64_t         ne3,
    cudaStream_t    stream);
28