#include <Common/config.h>

#if USE_AWS_S3

#include <IO/S3Common.h>
#include <Storages/StorageFactory.h>
#include <Storages/StorageS3.h>

#include <Interpreters/Context.h>
#include <Interpreters/evaluateConstantExpression.h>
#include <Parsers/ASTLiteral.h>

#include <IO/ReadBufferFromS3.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteBufferFromS3.h>
#include <IO/WriteHelpers.h>

#include <Formats/FormatFactory.h>

#include <DataStreams/IBlockOutputStream.h>
#include <DataStreams/IBlockInputStream.h>
#include <DataStreams/AddingDefaultsBlockInputStream.h>

#include <aws/s3/S3Client.h>


namespace DB
{
namespace ErrorCodes
{
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}


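/// Helper streams that bridge S3 read/write buffers with ClickHouse's block input/output formats.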
namespace
{
    class StorageS3BlockInputStream : public IBlockInputStream
    {
    public:
        StorageS3BlockInputStream(
            const String & format,
            const String & name_,
            const Block & sample_block,
            const Context & context,
            UInt64 max_block_size,
            const CompressionMethod compression_method,
            const std::shared_ptr<Aws::S3::S3Client> & client,
            const String & bucket,
            const String & key)
            : name(name_)
        {
            read_buf = getReadBuffer<ReadBufferFromS3>(compression_method, client, bucket, key);
            reader = FormatFactory::instance().getInput(format, *read_buf, sample_block, context, max_block_size);
        }

        String getName() const override
        {
            return name;
        }

        Block readImpl() override
        {
            return reader->read();
        }

        Block getHeader() const override
        {
            return reader->getHeader();
        }

        void readPrefixImpl() override
        {
            reader->readPrefix();
        }

        void readSuffixImpl() override
        {
            reader->readSuffix();
        }

    private:
        String name;
        std::unique_ptr<ReadBuffer> read_buf;
        BlockInputStreamPtr reader;
    };

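    /// Writes blocks through the requested output format into a (possibly compressed) S3 write buffer.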
    class StorageS3BlockOutputStream : public IBlockOutputStream
    {
    public:
        StorageS3BlockOutputStream(
            const String & format,
            UInt64 min_upload_part_size,
            const Block & sample_block_,
            const Context & context,
            const CompressionMethod compression_method,
            const std::shared_ptr<Aws::S3::S3Client> & client,
            const String & bucket,
            const String & key)
            : sample_block(sample_block_)
        {
            write_buf = getWriteBuffer<WriteBufferFromS3>(compression_method, client, bucket, key, min_upload_part_size);
            writer = FormatFactory::instance().getOutput(format, *write_buf, sample_block, context);
        }

        Block getHeader() const override
        {
            return sample_block;
        }

        void write(const Block & block) override
        {
            writer->write(block);
        }

        void writePrefix() override
        {
            writer->writePrefix();
        }

        void writeSuffix() override
        {
            writer->writeSuffix();
            writer->flush();
            write_buf->finalize();
        }

    private:
        Block sample_block;
        std::unique_ptr<WriteBuffer> write_buf;
        BlockOutputStreamPtr writer;
    };
}


StorageS3::StorageS3(const S3::URI & uri_,
    const String & access_key_id_,
    const String & secret_access_key_,
    const std::string & database_name_,
    const std::string & table_name_,
    const String & format_name_,
    UInt64 min_upload_part_size_,
    const ColumnsDescription & columns_,
    const ConstraintsDescription & constraints_,
    Context & context_,
    const String & compression_method_ = "")
    : IStorage(columns_)
    , uri(uri_)
    , context_global(context_)
    , format_name(format_name_)
    , database_name(database_name_)
    , table_name(table_name_)
    , min_upload_part_size(min_upload_part_size_)
    , compression_method(compression_method_)
    , client(S3::ClientFactory::instance().create(uri_.endpoint, access_key_id_, secret_access_key_))
{
    context_global.getRemoteHostFilter().checkURL(uri_.uri);
    setColumns(columns_);
    setConstraints(constraints_);
}


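/// Reads the object at uri.key as a single input stream; column defaults from the table definition are applied on top when present.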
BlockInputStreams StorageS3::read(
    const Names & column_names,
    const SelectQueryInfo & /*query_info*/,
    const Context & context,
    QueryProcessingStage::Enum /*processed_stage*/,
    size_t max_block_size,
    unsigned /*num_streams*/)
{
    BlockInputStreamPtr block_input = std::make_shared<StorageS3BlockInputStream>(
        format_name,
        getName(),
        getHeaderBlock(column_names),
        context,
        max_block_size,
        /// Detect compression from the object key extension (the endpoint never carries it).
        IStorage::chooseCompressionMethod(uri.key, compression_method),
        client,
        uri.bucket,
        uri.key);

    auto column_defaults = getColumns().getDefaults();
    if (column_defaults.empty())
        return {block_input};
    return {std::make_shared<AddingDefaultsBlockInputStream>(block_input, column_defaults, context)};
}

void StorageS3::rename(const String & /*new_path_to_db*/, const String & new_database_name, const String & new_table_name, TableStructureWriteLockHolder &)
{
    table_name = new_table_name;
    database_name = new_database_name;
}

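/// INSERT writes the object at uri.key; data is buffered and uploaded in parts of at least min_upload_part_size bytes.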
BlockOutputStreamPtr StorageS3::write(const ASTPtr & /*query*/, const Context & /*context*/)
{
    return std::make_shared<StorageS3BlockOutputStream>(
        format_name, min_upload_part_size, getSampleBlock(), context_global,
        IStorage::chooseCompressionMethod(uri.key, compression_method),
        client, uri.bucket, uri.key);
}

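/// Registers the S3 table engine: S3(url [, access_key_id, secret_access_key], format [, compression_method]).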
void registerStorageS3(StorageFactory & factory)
{
    factory.registerStorage("S3", [](const StorageFactory::Arguments & args)
    {
        ASTs & engine_args = args.engine_args;

        if (engine_args.size() < 2 || engine_args.size() > 5)
            throw Exception(
                "Storage S3 requires 2 to 5 arguments: url, [access_key_id, secret_access_key], format name and [compression_method].",
                ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);

        for (size_t i = 0; i < engine_args.size(); ++i)
            engine_args[i] = evaluateConstantExpressionOrIdentifierAsLiteral(engine_args[i], args.local_context);

        String url = engine_args[0]->as<ASTLiteral &>().value.safeGet<String>();
        Poco::URI uri(url);
        S3::URI s3_uri(uri);

        String access_key_id;
        String secret_access_key;
        if (engine_args.size() >= 4)
        {
            access_key_id = engine_args[1]->as<ASTLiteral &>().value.safeGet<String>();
            secret_access_key = engine_args[2]->as<ASTLiteral &>().value.safeGet<String>();
        }

        UInt64 min_upload_part_size = args.local_context.getSettingsRef().s3_min_upload_part_size;

        /// With 3 or 5 arguments the last one is the compression method and the format name precedes it;
        /// otherwise the format name is the last argument and compression is detected automatically.
        String compression_method;
        String format_name;
        if (engine_args.size() == 3 || engine_args.size() == 5)
        {
            compression_method = engine_args.back()->as<ASTLiteral &>().value.safeGet<String>();
            format_name = engine_args[engine_args.size() - 2]->as<ASTLiteral &>().value.safeGet<String>();
        }
        else
        {
            compression_method = "auto";
            format_name = engine_args.back()->as<ASTLiteral &>().value.safeGet<String>();
        }

        return StorageS3::create(
            s3_uri, access_key_id, secret_access_key, args.database_name, args.table_name,
            format_name, min_upload_part_size, args.columns, args.constraints, args.context, compression_method);
    });
}

}

#endif