1/*
2 * Copyright 2018-present Facebook, Inc.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16#include <folly/compression/Zstd.h>
17
18#if FOLLY_HAVE_LIBZSTD
19
20#include <stdexcept>
21#include <string>
22
23#include <zstd.h>
24
25#include <folly/Conv.h>
26#include <folly/Range.h>
27#include <folly/ScopeGuard.h>
28#include <folly/compression/Utils.h>
29
30static_assert(
31 ZSTD_VERSION_NUMBER >= 10302,
32 "zstd-1.3.2 is the minimum supported zstd version.");
33
34using folly::io::compression::detail::dataStartsWithLE;
35using folly::io::compression::detail::prefixToStringLE;
36
37namespace folly {
38namespace io {
39namespace zstd {
40namespace {
41
42// Compatibility helpers for zstd versions < 1.3.8.
43#if ZSTD_VERSION_NUMBER < 10308
44
45#define ZSTD_compressStream2 ZSTD_compress_generic
46#define ZSTD_c_compressionLevel ZSTD_p_compressionLevel
47#define ZSTD_c_contentSizeFlag ZSTD_p_contentSizeFlag
48
49void resetCCtxSessionAndParameters(ZSTD_CCtx* cctx) {
50 ZSTD_CCtx_reset(cctx);
51}
52
53void resetDCtxSessionAndParameters(ZSTD_DCtx* dctx) {
54 ZSTD_DCtx_reset(dctx);
55}
56
57#else
58
59void resetCCtxSessionAndParameters(ZSTD_CCtx* cctx) {
60 ZSTD_CCtx_reset(cctx, ZSTD_reset_session_and_parameters);
61}
62
63void resetDCtxSessionAndParameters(ZSTD_DCtx* dctx) {
64 ZSTD_DCtx_reset(dctx, ZSTD_reset_session_and_parameters);
65}
66
67#endif
68
69void zstdFreeCCtx(ZSTD_CCtx* zc) {
70 ZSTD_freeCCtx(zc);
71}
72
73void zstdFreeDCtx(ZSTD_DCtx* zd) {
74 ZSTD_freeDCtx(zd);
75}
76
77size_t zstdThrowIfError(size_t rc) {
78 if (!ZSTD_isError(rc)) {
79 return rc;
80 }
81 throw std::runtime_error(
82 to<std::string>("ZSTD returned an error: ", ZSTD_getErrorName(rc)));
83}
84
85ZSTD_EndDirective zstdTranslateFlush(StreamCodec::FlushOp flush) {
86 switch (flush) {
87 case StreamCodec::FlushOp::NONE:
88 return ZSTD_e_continue;
89 case StreamCodec::FlushOp::FLUSH:
90 return ZSTD_e_flush;
91 case StreamCodec::FlushOp::END:
92 return ZSTD_e_end;
93 default:
94 throw std::invalid_argument("ZSTDStreamCodec: Invalid flush");
95 }
96}
97
98class ZSTDStreamCodec final : public StreamCodec {
99 public:
100 explicit ZSTDStreamCodec(Options options);
101
102 std::vector<std::string> validPrefixes() const override;
103 bool canUncompress(const IOBuf* data, Optional<uint64_t> uncompressedLength)
104 const override;
105
106 private:
107 bool doNeedsUncompressedLength() const override;
108 uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override;
109 Optional<uint64_t> doGetUncompressedLength(
110 IOBuf const* data,
111 Optional<uint64_t> uncompressedLength) const override;
112
113 void doResetStream() override;
114 bool doCompressStream(
115 ByteRange& input,
116 MutableByteRange& output,
117 StreamCodec::FlushOp flushOp) override;
118 bool doUncompressStream(
119 ByteRange& input,
120 MutableByteRange& output,
121 StreamCodec::FlushOp flushOp) override;
122
123 void resetCCtx();
124 void resetDCtx();
125
126 Options options_;
127 bool needReset_{true};
128 std::unique_ptr<
129 ZSTD_CCtx,
130 folly::static_function_deleter<ZSTD_CCtx, &zstdFreeCCtx>>
131 cctx_{nullptr};
132 std::unique_ptr<
133 ZSTD_DCtx,
134 folly::static_function_deleter<ZSTD_DCtx, &zstdFreeDCtx>>
135 dctx_{nullptr};
136};
137
138constexpr uint32_t kZSTDMagicLE = 0xFD2FB528;
139
140std::vector<std::string> ZSTDStreamCodec::validPrefixes() const {
141 return {prefixToStringLE(kZSTDMagicLE)};
142}
143
144bool ZSTDStreamCodec::canUncompress(const IOBuf* data, Optional<uint64_t>)
145 const {
146 return dataStartsWithLE(data, kZSTDMagicLE);
147}
148
149CodecType codecType(Options const& options) {
150 int const level = options.level();
151 DCHECK_NE(level, 0);
152 return level > 0 ? CodecType::ZSTD : CodecType::ZSTD_FAST;
153}
154
155ZSTDStreamCodec::ZSTDStreamCodec(Options options)
156 : StreamCodec(codecType(options), options.level()),
157 options_(std::move(options)) {}
158
159bool ZSTDStreamCodec::doNeedsUncompressedLength() const {
160 return false;
161}
162
163uint64_t ZSTDStreamCodec::doMaxCompressedLength(
164 uint64_t uncompressedLength) const {
165 return ZSTD_compressBound(uncompressedLength);
166}
167
168Optional<uint64_t> ZSTDStreamCodec::doGetUncompressedLength(
169 IOBuf const* data,
170 Optional<uint64_t> uncompressedLength) const {
171 // Read decompressed size from frame if available in first IOBuf.
172 auto const decompressedSize =
173 ZSTD_getFrameContentSize(data->data(), data->length());
174 if (decompressedSize == ZSTD_CONTENTSIZE_UNKNOWN ||
175 decompressedSize == ZSTD_CONTENTSIZE_ERROR) {
176 return uncompressedLength;
177 }
178 if (uncompressedLength && *uncompressedLength != decompressedSize) {
179 throw std::runtime_error("ZSTD: invalid uncompressed length");
180 }
181 return decompressedSize;
182}
183
184void ZSTDStreamCodec::doResetStream() {
185 needReset_ = true;
186}
187
188void ZSTDStreamCodec::resetCCtx() {
189 if (!cctx_) {
190 cctx_.reset(ZSTD_createCCtx());
191 if (!cctx_) {
192 throw std::bad_alloc{};
193 }
194 }
195 resetCCtxSessionAndParameters(cctx_.get());
196 zstdThrowIfError(
197 ZSTD_CCtx_setParametersUsingCCtxParams(cctx_.get(), options_.params()));
198 zstdThrowIfError(ZSTD_CCtx_setPledgedSrcSize(
199 cctx_.get(), uncompressedLength().value_or(ZSTD_CONTENTSIZE_UNKNOWN)));
200}
201
202bool ZSTDStreamCodec::doCompressStream(
203 ByteRange& input,
204 MutableByteRange& output,
205 StreamCodec::FlushOp flushOp) {
206 if (needReset_) {
207 resetCCtx();
208 needReset_ = false;
209 }
210 ZSTD_inBuffer in = {input.data(), input.size(), 0};
211 ZSTD_outBuffer out = {output.data(), output.size(), 0};
212 SCOPE_EXIT {
213 input.uncheckedAdvance(in.pos);
214 output.uncheckedAdvance(out.pos);
215 };
216 size_t const rc = zstdThrowIfError(ZSTD_compressStream2(
217 cctx_.get(), &out, &in, zstdTranslateFlush(flushOp)));
218 switch (flushOp) {
219 case StreamCodec::FlushOp::NONE:
220 return false;
221 case StreamCodec::FlushOp::FLUSH:
222 case StreamCodec::FlushOp::END:
223 return rc == 0;
224 default:
225 throw std::invalid_argument("ZSTD: invalid FlushOp");
226 }
227}
228
229void ZSTDStreamCodec::resetDCtx() {
230 if (!dctx_) {
231 dctx_.reset(ZSTD_createDCtx());
232 if (!dctx_) {
233 throw std::bad_alloc{};
234 }
235 }
236 resetDCtxSessionAndParameters(dctx_.get());
237 if (options_.maxWindowSize() != 0) {
238 zstdThrowIfError(
239 ZSTD_DCtx_setMaxWindowSize(dctx_.get(), options_.maxWindowSize()));
240 }
241}
242
243bool ZSTDStreamCodec::doUncompressStream(
244 ByteRange& input,
245 MutableByteRange& output,
246 StreamCodec::FlushOp) {
247 if (needReset_) {
248 resetDCtx();
249 needReset_ = false;
250 }
251 ZSTD_inBuffer in = {input.data(), input.size(), 0};
252 ZSTD_outBuffer out = {output.data(), output.size(), 0};
253 SCOPE_EXIT {
254 input.uncheckedAdvance(in.pos);
255 output.uncheckedAdvance(out.pos);
256 };
257 size_t const rc =
258 zstdThrowIfError(ZSTD_decompressStream(dctx_.get(), &out, &in));
259 return rc == 0;
260}
261
262} // namespace
263
264Options::Options(int level) : params_(ZSTD_createCCtxParams()), level_(level) {
265 if (params_ == nullptr) {
266 throw std::bad_alloc{};
267 }
268#if ZSTD_VERSION_NUMBER >= 10304
269 zstdThrowIfError(ZSTD_CCtxParams_init(params_.get(), level));
270#else
271 zstdThrowIfError(ZSTD_initCCtxParams(params_.get(), level));
272 set(ZSTD_c_contentSizeFlag, 1);
273#endif
274 // zstd-1.3.4 is buggy and only disables Huffman decompression for negative
275 // compression levels if this call is present. This call is begign in other
276 // versions.
277 set(ZSTD_c_compressionLevel, level);
278}
279
280void Options::set(ZSTD_cParameter param, unsigned value) {
281 zstdThrowIfError(ZSTD_CCtxParam_setParameter(params_.get(), param, value));
282 if (param == ZSTD_c_compressionLevel) {
283 level_ = static_cast<int>(value);
284 }
285}
286
287/* static */ void Options::freeCCtxParams(ZSTD_CCtx_params* params) {
288 ZSTD_freeCCtxParams(params);
289}
290
291std::unique_ptr<Codec> getCodec(Options options) {
292 return std::make_unique<ZSTDStreamCodec>(std::move(options));
293}
294
295std::unique_ptr<StreamCodec> getStreamCodec(Options options) {
296 return std::make_unique<ZSTDStreamCodec>(std::move(options));
297}
298
299} // namespace zstd
300} // namespace io
301} // namespace folly
302
303#endif
304