1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18#include "arrow/util/compression_zstd.h"
19
20#include <cstddef>
21#include <cstdint>
22#include <sstream>
23
24#include <zstd.h>
25
26#include "arrow/status.h"
27#include "arrow/util/logging.h"
28#include "arrow/util/macros.h"
29
30using std::size_t;
31
32namespace arrow {
33namespace util {
34
// Compression level passed to zstd by both the streaming compressor
// (ZSTD_initCStream) and the one-shot ZSTDCodec::Compress (ZSTD_compress).
// XXX level = 1 probably doesn't compress very much
constexpr int kZSTDDefaultCompressionLevel = 1;
37
38static Status ZSTDError(size_t ret, const char* prefix_msg) {
39 return Status::IOError(prefix_msg, ZSTD_getErrorName(ret));
40}
41
42// ----------------------------------------------------------------------
43// ZSTD decompressor implementation
44
45class ZSTDDecompressor : public Decompressor {
46 public:
47 ZSTDDecompressor() : stream_(ZSTD_createDStream()) {}
48
49 ~ZSTDDecompressor() override { ZSTD_freeDStream(stream_); }
50
51 Status Init() {
52 finished_ = false;
53 size_t ret = ZSTD_initDStream(stream_);
54 if (ZSTD_isError(ret)) {
55 return ZSTDError(ret, "ZSTD init failed: ");
56 } else {
57 return Status::OK();
58 }
59 }
60
61 Status Decompress(int64_t input_len, const uint8_t* input, int64_t output_len,
62 uint8_t* output, int64_t* bytes_read, int64_t* bytes_written,
63 bool* need_more_output) override {
64 ZSTD_inBuffer in_buf;
65 ZSTD_outBuffer out_buf;
66
67 in_buf.src = input;
68 in_buf.size = static_cast<size_t>(input_len);
69 in_buf.pos = 0;
70 out_buf.dst = output;
71 out_buf.size = static_cast<size_t>(output_len);
72 out_buf.pos = 0;
73
74 size_t ret;
75 ret = ZSTD_decompressStream(stream_, &out_buf, &in_buf);
76 if (ZSTD_isError(ret)) {
77 return ZSTDError(ret, "ZSTD decompress failed: ");
78 }
79 *bytes_read = static_cast<int64_t>(in_buf.pos);
80 *bytes_written = static_cast<int64_t>(out_buf.pos);
81 *need_more_output = *bytes_read == 0 && *bytes_written == 0;
82 finished_ = (ret == 0);
83 return Status::OK();
84 }
85
86 bool IsFinished() override { return finished_; }
87
88 protected:
89 ZSTD_DStream* stream_;
90 bool finished_;
91};
92
93// ----------------------------------------------------------------------
94// ZSTD compressor implementation
95
96class ZSTDCompressor : public Compressor {
97 public:
98 ZSTDCompressor() : stream_(ZSTD_createCStream()) {}
99
100 ~ZSTDCompressor() override { ZSTD_freeCStream(stream_); }
101
102 Status Init() {
103 size_t ret = ZSTD_initCStream(stream_, kZSTDDefaultCompressionLevel);
104 if (ZSTD_isError(ret)) {
105 return ZSTDError(ret, "ZSTD init failed: ");
106 } else {
107 return Status::OK();
108 }
109 }
110
111 Status Compress(int64_t input_len, const uint8_t* input, int64_t output_len,
112 uint8_t* output, int64_t* bytes_read, int64_t* bytes_written) override;
113
114 Status Flush(int64_t output_len, uint8_t* output, int64_t* bytes_written,
115 bool* should_retry) override;
116
117 Status End(int64_t output_len, uint8_t* output, int64_t* bytes_written,
118 bool* should_retry) override;
119
120 protected:
121 ZSTD_CStream* stream_;
122};
123
124Status ZSTDCompressor::Compress(int64_t input_len, const uint8_t* input,
125 int64_t output_len, uint8_t* output, int64_t* bytes_read,
126 int64_t* bytes_written) {
127 ZSTD_inBuffer in_buf;
128 ZSTD_outBuffer out_buf;
129
130 in_buf.src = input;
131 in_buf.size = static_cast<size_t>(input_len);
132 in_buf.pos = 0;
133 out_buf.dst = output;
134 out_buf.size = static_cast<size_t>(output_len);
135 out_buf.pos = 0;
136
137 size_t ret;
138 ret = ZSTD_compressStream(stream_, &out_buf, &in_buf);
139 if (ZSTD_isError(ret)) {
140 return ZSTDError(ret, "ZSTD compress failed: ");
141 }
142 *bytes_read = static_cast<int64_t>(in_buf.pos);
143 *bytes_written = static_cast<int64_t>(out_buf.pos);
144 return Status::OK();
145}
146
147Status ZSTDCompressor::Flush(int64_t output_len, uint8_t* output, int64_t* bytes_written,
148 bool* should_retry) {
149 ZSTD_outBuffer out_buf;
150
151 out_buf.dst = output;
152 out_buf.size = static_cast<size_t>(output_len);
153 out_buf.pos = 0;
154
155 size_t ret;
156 ret = ZSTD_flushStream(stream_, &out_buf);
157 if (ZSTD_isError(ret)) {
158 return ZSTDError(ret, "ZSTD flush failed: ");
159 }
160 *bytes_written = static_cast<int64_t>(out_buf.pos);
161 *should_retry = ret > 0;
162 return Status::OK();
163}
164
165Status ZSTDCompressor::End(int64_t output_len, uint8_t* output, int64_t* bytes_written,
166 bool* should_retry) {
167 ZSTD_outBuffer out_buf;
168
169 out_buf.dst = output;
170 out_buf.size = static_cast<size_t>(output_len);
171 out_buf.pos = 0;
172
173 size_t ret;
174 ret = ZSTD_endStream(stream_, &out_buf);
175 if (ZSTD_isError(ret)) {
176 return ZSTDError(ret, "ZSTD end failed: ");
177 }
178 *bytes_written = static_cast<int64_t>(out_buf.pos);
179 *should_retry = ret > 0;
180 return Status::OK();
181}
182
183// ----------------------------------------------------------------------
184// ZSTD codec implementation
185
186Status ZSTDCodec::MakeCompressor(std::shared_ptr<Compressor>* out) {
187 auto ptr = std::make_shared<ZSTDCompressor>();
188 RETURN_NOT_OK(ptr->Init());
189 *out = ptr;
190 return Status::OK();
191}
192
193Status ZSTDCodec::MakeDecompressor(std::shared_ptr<Decompressor>* out) {
194 auto ptr = std::make_shared<ZSTDDecompressor>();
195 RETURN_NOT_OK(ptr->Init());
196 *out = ptr;
197 return Status::OK();
198}
199
200Status ZSTDCodec::Decompress(int64_t input_len, const uint8_t* input,
201 int64_t output_buffer_len, uint8_t* output_buffer) {
202 return Decompress(input_len, input, output_buffer_len, output_buffer, nullptr);
203}
204
205Status ZSTDCodec::Decompress(int64_t input_len, const uint8_t* input,
206 int64_t output_buffer_len, uint8_t* output_buffer,
207 int64_t* output_len) {
208 if (output_buffer == nullptr) {
209 // We may pass a NULL 0-byte output buffer but some zstd versions demand
210 // a valid pointer: https://github.com/facebook/zstd/issues/1385
211 static uint8_t empty_buffer[1];
212 DCHECK_EQ(output_buffer_len, 0);
213 output_buffer = empty_buffer;
214 }
215
216 size_t ret = ZSTD_decompress(output_buffer, static_cast<size_t>(output_buffer_len),
217 input, static_cast<size_t>(input_len));
218 if (ZSTD_isError(ret)) {
219 return ZSTDError(ret, "ZSTD decompression failed: ");
220 }
221 if (static_cast<int64_t>(ret) != output_buffer_len) {
222 return Status::IOError("Corrupt ZSTD compressed data.");
223 }
224 if (output_len) {
225 *output_len = static_cast<int64_t>(ret);
226 }
227 return Status::OK();
228}
229
230int64_t ZSTDCodec::MaxCompressedLen(int64_t input_len,
231 const uint8_t* ARROW_ARG_UNUSED(input)) {
232 return ZSTD_compressBound(input_len);
233}
234
235Status ZSTDCodec::Compress(int64_t input_len, const uint8_t* input,
236 int64_t output_buffer_len, uint8_t* output_buffer,
237 int64_t* output_len) {
238 size_t ret =
239 ZSTD_compress(output_buffer, static_cast<size_t>(output_buffer_len), input,
240 static_cast<size_t>(input_len), kZSTDDefaultCompressionLevel);
241 if (ZSTD_isError(ret)) {
242 return ZSTDError(ret, "ZSTD compression failed: ");
243 }
244 *output_len = static_cast<int64_t>(ret);
245 return Status::OK();
246}
247
248} // namespace util
249} // namespace arrow
250