1// SPDX-FileCopyrightText: 2023 UnionTech Software Technology Co., Ltd.
2//
3// SPDX-License-Identifier: GPL-3.0-or-later
4
5#include "zstd_writer.h"
6#include "easylogging++.h"
7#include <algorithm>
8#include <errno.h>
9#include <assert.h>
10
11using namespace std;
12
13ZstdWriter::ZstdWriter(void)
14{
15 pos = 0;
16 total = 0;
17 fout = nullptr;
18 cstream = nullptr;
19 buffInSize = 0;
20 buffOutSize = 0;
21}
22
23ZstdWriter::~ZstdWriter(void)
24{
25 close();
26}
27
28int ZstdWriter::open(const char* outName, int compress_level)
29{
30 fout = fopen(outName, "wb");
31 if (NULL == fout) {
32 LOG(ERROR) << "failed create file : " << outName << ", error:" << errno;
33 }
34 buffInSize = ZSTD_CStreamInSize(); /* can always read one full block */
35 pos = 0;
36 total = buffInSize*10;
37 buffIn = make_unique<char[]>(total);
38 buffOutSize = ZSTD_CStreamOutSize(); /* can always flush a full block */
39 buffOut = make_unique<char[]>(buffOutSize);
40
41 cstream = ZSTD_createCStream();
42 if (cstream==NULL) {
43 LOG(ERROR) << "ZSTD_createCStream() error";
44 exit(10);
45 }
46 size_t const initResult = ZSTD_initCStream(cstream, compress_level);
47 if (ZSTD_isError(initResult)) {
48 LOG(ERROR) << "ZSTD_initCStream() error : " << ZSTD_getErrorName(initResult);
49 exit(11);
50 }
51
52 return 0;
53}
54
55int ZstdWriter::write(const void* buf, size_t size)
56{
57 if (total < pos + size) {
58 flush(buffIn.get(), pos);
59 pos = 0;
60 }
61
62 if (total < size) {
63 assert (0 == pos);
64 flush(buf, size);
65 }
66 else {
67 memcpy(buffIn.get() + pos, buf, size);
68 pos += size;
69 }
70
71 return static_cast<int>(size);
72}
73
74int ZstdWriter::close()
75{
76 if (cstream) {
77 flush(buffIn.get(), pos);
78
79 ZSTD_outBuffer output = { buffOut.get(), buffOutSize, 0 };
80 /* close frame */
81 size_t const remainingToFlush = ZSTD_endStream(cstream, &output);
82 if (remainingToFlush) {
83 LOG(ERROR) << "not fully flushed";
84 }
85 fwrite(buffOut.get(), 1, output.pos, fout);
86 fclose(fout);
87 fout = nullptr;
88
89 ZSTD_freeCStream(cstream);
90 cstream = nullptr;
91 }
92
93 return 0;
94}
95
96/****************** private methods */
97int ZstdWriter::flush(const void* buf, size_t size)
98{
99 size_t toRead = buffInSize;
100 const char* walk = reinterpret_cast<const char*>(buf);
101 const char* end = walk + size;
102
103#if 0
104 {
105 // debug
106 char name[128];
107 sprintf(name, "/tmp/strace2-%d.bin", size);
108 FILE* pf = fopen(name, "wb");
109 if (pf) {
110 fwrite(buf, 1, size, pf);
111 fclose(pf);
112 }
113 }
114#endif
115
116 while (walk < end) {
117 ZSTD_inBuffer input = {walk, 0, 0 };
118
119 if (walk + toRead > end) input.size = (end - walk);
120 else input.size = toRead;
121 walk += input.size;
122
123 while (input.pos < input.size) {
124 ZSTD_outBuffer output = { buffOut.get(), buffOutSize, 0 };
125
126 /* toRead is guaranteed to be <= ZSTD_CStreamInSize() */
127 toRead = ZSTD_compressStream(cstream, &output , &input);
128 if (ZSTD_isError(toRead)) {
129 LOG(ERROR) << "ZSTD_compressStream() error : "
130 << ZSTD_getErrorName(toRead);
131 return -1;
132 }
133
134 /* Safely handle case when `buffInSize` is manually changed to a value <
135 ZSTD_CStreamInSize()*/
136 if (toRead > buffInSize) toRead = buffInSize;
137
138 /* Safely handle case when `buffInSize` is manually changed to
139 * a value < ZSTD_CStreamInSize()*/
140 fwrite(buffOut.get(), 1, output.pos, fout);
141 }
142 }
143
144 return 0;
145}
146
147