1 | // SPDX-FileCopyrightText: 2023 UnionTech Software Technology Co., Ltd. |
2 | // |
3 | // SPDX-License-Identifier: GPL-3.0-or-later |
4 | |
5 | #include "zstd_writer.h" |
6 | #include "easylogging++.h" |
7 | #include <algorithm> |
8 | #include <errno.h> |
9 | #include <assert.h> |
10 | |
11 | using namespace std; |
12 | |
13 | ZstdWriter::ZstdWriter(void) |
14 | { |
15 | pos = 0; |
16 | total = 0; |
17 | fout = nullptr; |
18 | cstream = nullptr; |
19 | buffInSize = 0; |
20 | buffOutSize = 0; |
21 | } |
22 | |
23 | ZstdWriter::~ZstdWriter(void) |
24 | { |
25 | close(); |
26 | } |
27 | |
28 | int ZstdWriter::open(const char* outName, int compress_level) |
29 | { |
30 | fout = fopen(outName, "wb" ); |
31 | if (NULL == fout) { |
32 | LOG(ERROR) << "failed create file : " << outName << ", error:" << errno; |
33 | } |
34 | buffInSize = ZSTD_CStreamInSize(); /* can always read one full block */ |
35 | pos = 0; |
36 | total = buffInSize*10; |
37 | buffIn = make_unique<char[]>(total); |
38 | buffOutSize = ZSTD_CStreamOutSize(); /* can always flush a full block */ |
39 | buffOut = make_unique<char[]>(buffOutSize); |
40 | |
41 | cstream = ZSTD_createCStream(); |
42 | if (cstream==NULL) { |
43 | LOG(ERROR) << "ZSTD_createCStream() error" ; |
44 | exit(10); |
45 | } |
46 | size_t const initResult = ZSTD_initCStream(cstream, compress_level); |
47 | if (ZSTD_isError(initResult)) { |
48 | LOG(ERROR) << "ZSTD_initCStream() error : " << ZSTD_getErrorName(initResult); |
49 | exit(11); |
50 | } |
51 | |
52 | return 0; |
53 | } |
54 | |
55 | int ZstdWriter::write(const void* buf, size_t size) |
56 | { |
57 | if (total < pos + size) { |
58 | flush(buffIn.get(), pos); |
59 | pos = 0; |
60 | } |
61 | |
62 | if (total < size) { |
63 | assert (0 == pos); |
64 | flush(buf, size); |
65 | } |
66 | else { |
67 | memcpy(buffIn.get() + pos, buf, size); |
68 | pos += size; |
69 | } |
70 | |
71 | return static_cast<int>(size); |
72 | } |
73 | |
74 | int ZstdWriter::close() |
75 | { |
76 | if (cstream) { |
77 | flush(buffIn.get(), pos); |
78 | |
79 | ZSTD_outBuffer output = { buffOut.get(), buffOutSize, 0 }; |
80 | /* close frame */ |
81 | size_t const remainingToFlush = ZSTD_endStream(cstream, &output); |
82 | if (remainingToFlush) { |
83 | LOG(ERROR) << "not fully flushed" ; |
84 | } |
85 | fwrite(buffOut.get(), 1, output.pos, fout); |
86 | fclose(fout); |
87 | fout = nullptr; |
88 | |
89 | ZSTD_freeCStream(cstream); |
90 | cstream = nullptr; |
91 | } |
92 | |
93 | return 0; |
94 | } |
95 | |
96 | /****************** private methods */ |
97 | int ZstdWriter::flush(const void* buf, size_t size) |
98 | { |
99 | size_t toRead = buffInSize; |
100 | const char* walk = reinterpret_cast<const char*>(buf); |
101 | const char* end = walk + size; |
102 | |
103 | #if 0 |
104 | { |
105 | // debug |
106 | char name[128]; |
107 | sprintf(name, "/tmp/strace2-%d.bin" , size); |
108 | FILE* pf = fopen(name, "wb" ); |
109 | if (pf) { |
110 | fwrite(buf, 1, size, pf); |
111 | fclose(pf); |
112 | } |
113 | } |
114 | #endif |
115 | |
116 | while (walk < end) { |
117 | ZSTD_inBuffer input = {walk, 0, 0 }; |
118 | |
119 | if (walk + toRead > end) input.size = (end - walk); |
120 | else input.size = toRead; |
121 | walk += input.size; |
122 | |
123 | while (input.pos < input.size) { |
124 | ZSTD_outBuffer output = { buffOut.get(), buffOutSize, 0 }; |
125 | |
126 | /* toRead is guaranteed to be <= ZSTD_CStreamInSize() */ |
127 | toRead = ZSTD_compressStream(cstream, &output , &input); |
128 | if (ZSTD_isError(toRead)) { |
129 | LOG(ERROR) << "ZSTD_compressStream() error : " |
130 | << ZSTD_getErrorName(toRead); |
131 | return -1; |
132 | } |
133 | |
134 | /* Safely handle case when `buffInSize` is manually changed to a value < |
135 | ZSTD_CStreamInSize()*/ |
136 | if (toRead > buffInSize) toRead = buffInSize; |
137 | |
138 | /* Safely handle case when `buffInSize` is manually changed to |
139 | * a value < ZSTD_CStreamInSize()*/ |
140 | fwrite(buffOut.get(), 1, output.pos, fout); |
141 | } |
142 | } |
143 | |
144 | return 0; |
145 | } |
146 | |
147 | |