1 | #include <Compression/CompressionCodecZSTD.h> |
2 | #include <Compression/CompressionInfo.h> |
3 | #include <IO/ReadHelpers.h> |
4 | #include <Compression/CompressionFactory.h> |
5 | #include <zstd.h> |
6 | #include <Core/Field.h> |
7 | #include <Parsers/IAST.h> |
8 | #include <Parsers/ASTLiteral.h> |
9 | #include <Common/typeid_cast.h> |
10 | #include <IO/WriteHelpers.h> |
11 | |
12 | |
13 | namespace DB |
14 | { |
15 | |
16 | namespace ErrorCodes |
17 | { |
18 | extern const int CANNOT_COMPRESS; |
19 | extern const int CANNOT_DECOMPRESS; |
20 | extern const int ILLEGAL_SYNTAX_FOR_CODEC_TYPE; |
21 | extern const int ILLEGAL_CODEC_PARAMETER; |
22 | } |
23 | |
24 | UInt8 CompressionCodecZSTD::getMethodByte() const |
25 | { |
26 | return static_cast<UInt8>(CompressionMethodByte::ZSTD); |
27 | } |
28 | |
29 | String CompressionCodecZSTD::getCodecDesc() const |
30 | { |
31 | return "ZSTD(" + toString(level) + ")" ; |
32 | } |
33 | |
34 | UInt32 CompressionCodecZSTD::getMaxCompressedDataSize(UInt32 uncompressed_size) const |
35 | { |
36 | return ZSTD_compressBound(uncompressed_size); |
37 | } |
38 | |
39 | |
40 | UInt32 CompressionCodecZSTD::doCompressData(const char * source, UInt32 source_size, char * dest) const |
41 | { |
42 | size_t compressed_size = ZSTD_compress(dest, ZSTD_compressBound(source_size), source, source_size, level); |
43 | |
44 | if (ZSTD_isError(compressed_size)) |
45 | throw Exception("Cannot compress block with ZSTD: " + std::string(ZSTD_getErrorName(compressed_size)), ErrorCodes::CANNOT_COMPRESS); |
46 | |
47 | return compressed_size; |
48 | } |
49 | |
50 | |
51 | void CompressionCodecZSTD::doDecompressData(const char * source, UInt32 source_size, char * dest, UInt32 uncompressed_size) const |
52 | { |
53 | size_t res = ZSTD_decompress(dest, uncompressed_size, source, source_size); |
54 | |
55 | if (ZSTD_isError(res)) |
56 | throw Exception("Cannot ZSTD_decompress: " + std::string(ZSTD_getErrorName(res)), ErrorCodes::CANNOT_DECOMPRESS); |
57 | } |
58 | |
59 | CompressionCodecZSTD::CompressionCodecZSTD(int level_) |
60 | :level(level_) |
61 | { |
62 | } |
63 | |
64 | void registerCodecZSTD(CompressionCodecFactory & factory) |
65 | { |
66 | UInt8 method_code = UInt8(CompressionMethodByte::ZSTD); |
67 | factory.registerCompressionCodec("ZSTD" , method_code, [&](const ASTPtr & arguments) -> CompressionCodecPtr |
68 | { |
69 | int level = CompressionCodecZSTD::ZSTD_DEFAULT_LEVEL; |
70 | if (arguments && !arguments->children.empty()) |
71 | { |
72 | if (arguments->children.size() > 1) |
73 | throw Exception("ZSTD codec must have 1 parameter, given " + std::to_string(arguments->children.size()), ErrorCodes::ILLEGAL_SYNTAX_FOR_CODEC_TYPE); |
74 | |
75 | const auto children = arguments->children; |
76 | const auto * literal = children[0]->as<ASTLiteral>(); |
77 | level = literal->value.safeGet<UInt64>(); |
78 | if (level > ZSTD_maxCLevel()) |
79 | throw Exception("ZSTD codec can't have level more that " + toString(ZSTD_maxCLevel()) + ", given " + toString(level), ErrorCodes::ILLEGAL_CODEC_PARAMETER); |
80 | } |
81 | |
82 | return std::make_shared<CompressionCodecZSTD>(level); |
83 | }); |
84 | } |
85 | |
86 | } |
87 | |