1 | // Licensed to the Apache Software Foundation (ASF) under one |
2 | // or more contributor license agreements. See the NOTICE file |
3 | // distributed with this work for additional information |
4 | // regarding copyright ownership. The ASF licenses this file |
5 | // to you under the Apache License, Version 2.0 (the |
6 | // "License"); you may not use this file except in compliance |
7 | // with the License. You may obtain a copy of the License at |
8 | // |
9 | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | // |
11 | // Unless required by applicable law or agreed to in writing, |
12 | // software distributed under the License is distributed on an |
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | // KIND, either express or implied. See the License for the |
15 | // specific language governing permissions and limitations |
16 | // under the License. |
17 | |
18 | #include "arrow/util/compression_zstd.h" |
19 | |
20 | #include <cstddef> |
21 | #include <cstdint> |
22 | #include <sstream> |
23 | |
24 | #include <zstd.h> |
25 | |
26 | #include "arrow/status.h" |
27 | #include "arrow/util/logging.h" |
28 | #include "arrow/util/macros.h" |
29 | |
30 | using std::size_t; |
31 | |
32 | namespace arrow { |
33 | namespace util { |
34 | |
// Compression level passed to zstd by both the streaming compressor and the
// one-shot ZSTDCodec::Compress path.
// XXX: level 1 probably doesn't compress very much; consider a higher default.
constexpr int kZSTDDefaultCompressionLevel{1};
37 | |
38 | static Status ZSTDError(size_t ret, const char* prefix_msg) { |
39 | return Status::IOError(prefix_msg, ZSTD_getErrorName(ret)); |
40 | } |
41 | |
42 | // ---------------------------------------------------------------------- |
43 | // ZSTD decompressor implementation |
44 | |
45 | class ZSTDDecompressor : public Decompressor { |
46 | public: |
47 | ZSTDDecompressor() : stream_(ZSTD_createDStream()) {} |
48 | |
49 | ~ZSTDDecompressor() override { ZSTD_freeDStream(stream_); } |
50 | |
51 | Status Init() { |
52 | finished_ = false; |
53 | size_t ret = ZSTD_initDStream(stream_); |
54 | if (ZSTD_isError(ret)) { |
55 | return ZSTDError(ret, "ZSTD init failed: " ); |
56 | } else { |
57 | return Status::OK(); |
58 | } |
59 | } |
60 | |
61 | Status Decompress(int64_t input_len, const uint8_t* input, int64_t output_len, |
62 | uint8_t* output, int64_t* bytes_read, int64_t* bytes_written, |
63 | bool* need_more_output) override { |
64 | ZSTD_inBuffer in_buf; |
65 | ZSTD_outBuffer out_buf; |
66 | |
67 | in_buf.src = input; |
68 | in_buf.size = static_cast<size_t>(input_len); |
69 | in_buf.pos = 0; |
70 | out_buf.dst = output; |
71 | out_buf.size = static_cast<size_t>(output_len); |
72 | out_buf.pos = 0; |
73 | |
74 | size_t ret; |
75 | ret = ZSTD_decompressStream(stream_, &out_buf, &in_buf); |
76 | if (ZSTD_isError(ret)) { |
77 | return ZSTDError(ret, "ZSTD decompress failed: " ); |
78 | } |
79 | *bytes_read = static_cast<int64_t>(in_buf.pos); |
80 | *bytes_written = static_cast<int64_t>(out_buf.pos); |
81 | *need_more_output = *bytes_read == 0 && *bytes_written == 0; |
82 | finished_ = (ret == 0); |
83 | return Status::OK(); |
84 | } |
85 | |
86 | bool IsFinished() override { return finished_; } |
87 | |
88 | protected: |
89 | ZSTD_DStream* stream_; |
90 | bool finished_; |
91 | }; |
92 | |
93 | // ---------------------------------------------------------------------- |
94 | // ZSTD compressor implementation |
95 | |
96 | class ZSTDCompressor : public Compressor { |
97 | public: |
98 | ZSTDCompressor() : stream_(ZSTD_createCStream()) {} |
99 | |
100 | ~ZSTDCompressor() override { ZSTD_freeCStream(stream_); } |
101 | |
102 | Status Init() { |
103 | size_t ret = ZSTD_initCStream(stream_, kZSTDDefaultCompressionLevel); |
104 | if (ZSTD_isError(ret)) { |
105 | return ZSTDError(ret, "ZSTD init failed: " ); |
106 | } else { |
107 | return Status::OK(); |
108 | } |
109 | } |
110 | |
111 | Status Compress(int64_t input_len, const uint8_t* input, int64_t output_len, |
112 | uint8_t* output, int64_t* bytes_read, int64_t* bytes_written) override; |
113 | |
114 | Status Flush(int64_t output_len, uint8_t* output, int64_t* bytes_written, |
115 | bool* should_retry) override; |
116 | |
117 | Status End(int64_t output_len, uint8_t* output, int64_t* bytes_written, |
118 | bool* should_retry) override; |
119 | |
120 | protected: |
121 | ZSTD_CStream* stream_; |
122 | }; |
123 | |
124 | Status ZSTDCompressor::Compress(int64_t input_len, const uint8_t* input, |
125 | int64_t output_len, uint8_t* output, int64_t* bytes_read, |
126 | int64_t* bytes_written) { |
127 | ZSTD_inBuffer in_buf; |
128 | ZSTD_outBuffer out_buf; |
129 | |
130 | in_buf.src = input; |
131 | in_buf.size = static_cast<size_t>(input_len); |
132 | in_buf.pos = 0; |
133 | out_buf.dst = output; |
134 | out_buf.size = static_cast<size_t>(output_len); |
135 | out_buf.pos = 0; |
136 | |
137 | size_t ret; |
138 | ret = ZSTD_compressStream(stream_, &out_buf, &in_buf); |
139 | if (ZSTD_isError(ret)) { |
140 | return ZSTDError(ret, "ZSTD compress failed: " ); |
141 | } |
142 | *bytes_read = static_cast<int64_t>(in_buf.pos); |
143 | *bytes_written = static_cast<int64_t>(out_buf.pos); |
144 | return Status::OK(); |
145 | } |
146 | |
147 | Status ZSTDCompressor::Flush(int64_t output_len, uint8_t* output, int64_t* bytes_written, |
148 | bool* should_retry) { |
149 | ZSTD_outBuffer out_buf; |
150 | |
151 | out_buf.dst = output; |
152 | out_buf.size = static_cast<size_t>(output_len); |
153 | out_buf.pos = 0; |
154 | |
155 | size_t ret; |
156 | ret = ZSTD_flushStream(stream_, &out_buf); |
157 | if (ZSTD_isError(ret)) { |
158 | return ZSTDError(ret, "ZSTD flush failed: " ); |
159 | } |
160 | *bytes_written = static_cast<int64_t>(out_buf.pos); |
161 | *should_retry = ret > 0; |
162 | return Status::OK(); |
163 | } |
164 | |
165 | Status ZSTDCompressor::End(int64_t output_len, uint8_t* output, int64_t* bytes_written, |
166 | bool* should_retry) { |
167 | ZSTD_outBuffer out_buf; |
168 | |
169 | out_buf.dst = output; |
170 | out_buf.size = static_cast<size_t>(output_len); |
171 | out_buf.pos = 0; |
172 | |
173 | size_t ret; |
174 | ret = ZSTD_endStream(stream_, &out_buf); |
175 | if (ZSTD_isError(ret)) { |
176 | return ZSTDError(ret, "ZSTD end failed: " ); |
177 | } |
178 | *bytes_written = static_cast<int64_t>(out_buf.pos); |
179 | *should_retry = ret > 0; |
180 | return Status::OK(); |
181 | } |
182 | |
183 | // ---------------------------------------------------------------------- |
184 | // ZSTD codec implementation |
185 | |
186 | Status ZSTDCodec::MakeCompressor(std::shared_ptr<Compressor>* out) { |
187 | auto ptr = std::make_shared<ZSTDCompressor>(); |
188 | RETURN_NOT_OK(ptr->Init()); |
189 | *out = ptr; |
190 | return Status::OK(); |
191 | } |
192 | |
193 | Status ZSTDCodec::MakeDecompressor(std::shared_ptr<Decompressor>* out) { |
194 | auto ptr = std::make_shared<ZSTDDecompressor>(); |
195 | RETURN_NOT_OK(ptr->Init()); |
196 | *out = ptr; |
197 | return Status::OK(); |
198 | } |
199 | |
200 | Status ZSTDCodec::Decompress(int64_t input_len, const uint8_t* input, |
201 | int64_t output_buffer_len, uint8_t* output_buffer) { |
202 | return Decompress(input_len, input, output_buffer_len, output_buffer, nullptr); |
203 | } |
204 | |
205 | Status ZSTDCodec::Decompress(int64_t input_len, const uint8_t* input, |
206 | int64_t output_buffer_len, uint8_t* output_buffer, |
207 | int64_t* output_len) { |
208 | if (output_buffer == nullptr) { |
209 | // We may pass a NULL 0-byte output buffer but some zstd versions demand |
210 | // a valid pointer: https://github.com/facebook/zstd/issues/1385 |
211 | static uint8_t empty_buffer[1]; |
212 | DCHECK_EQ(output_buffer_len, 0); |
213 | output_buffer = empty_buffer; |
214 | } |
215 | |
216 | size_t ret = ZSTD_decompress(output_buffer, static_cast<size_t>(output_buffer_len), |
217 | input, static_cast<size_t>(input_len)); |
218 | if (ZSTD_isError(ret)) { |
219 | return ZSTDError(ret, "ZSTD decompression failed: " ); |
220 | } |
221 | if (static_cast<int64_t>(ret) != output_buffer_len) { |
222 | return Status::IOError("Corrupt ZSTD compressed data." ); |
223 | } |
224 | if (output_len) { |
225 | *output_len = static_cast<int64_t>(ret); |
226 | } |
227 | return Status::OK(); |
228 | } |
229 | |
230 | int64_t ZSTDCodec::MaxCompressedLen(int64_t input_len, |
231 | const uint8_t* ARROW_ARG_UNUSED(input)) { |
232 | return ZSTD_compressBound(input_len); |
233 | } |
234 | |
235 | Status ZSTDCodec::Compress(int64_t input_len, const uint8_t* input, |
236 | int64_t output_buffer_len, uint8_t* output_buffer, |
237 | int64_t* output_len) { |
238 | size_t ret = |
239 | ZSTD_compress(output_buffer, static_cast<size_t>(output_buffer_len), input, |
240 | static_cast<size_t>(input_len), kZSTDDefaultCompressionLevel); |
241 | if (ZSTD_isError(ret)) { |
242 | return ZSTDError(ret, "ZSTD compression failed: " ); |
243 | } |
244 | *output_len = static_cast<int64_t>(ret); |
245 | return Status::OK(); |
246 | } |
247 | |
248 | } // namespace util |
249 | } // namespace arrow |
250 | |