1 | #include "duckdb/common/dl.hpp" |
2 | #include "duckdb/common/virtual_file_system.hpp" |
3 | #include "duckdb/main/extension_helper.hpp" |
4 | #include "duckdb/main/error_manager.hpp" |
5 | #include "mbedtls_wrapper.hpp" |
6 | |
7 | #ifndef DUCKDB_NO_THREADS |
8 | #include <thread> |
9 | #endif // DUCKDB_NO_THREADS |
10 | |
11 | #ifdef WASM_LOADABLE_EXTENSIONS |
12 | #include <emscripten.h> |
13 | #endif |
14 | |
15 | namespace duckdb { |
16 | |
17 | //===--------------------------------------------------------------------===// |
18 | // Load External Extension |
19 | //===--------------------------------------------------------------------===// |
20 | typedef void (*ext_init_fun_t)(DatabaseInstance &); |
21 | typedef const char *(*ext_version_fun_t)(void); |
22 | typedef bool (*ext_is_storage_t)(void); |
23 | |
24 | template <class T> |
25 | static T LoadFunctionFromDLL(void *dll, const string &function_name, const string &filename) { |
26 | auto function = dlsym(handle: dll, name: function_name.c_str()); |
27 | if (!function) { |
28 | throw IOException("File \"%s\" did not contain function \"%s\": %s" , filename, function_name, GetDLError()); |
29 | } |
30 | return (T)function; |
31 | } |
32 | |
33 | static void ComputeSHA256String(const std::string &to_hash, std::string *res) { |
34 | // Invoke MbedTls function to actually compute sha256 |
35 | *res = duckdb_mbedtls::MbedTlsWrapper::ComputeSha256Hash(file_content: to_hash); |
36 | } |
37 | |
38 | static void ComputeSHA256FileSegment(FileHandle *handle, const idx_t start, const idx_t end, std::string *res) { |
39 | const idx_t len = end - start; |
40 | string file_content; |
41 | file_content.resize(n: len); |
42 | handle->Read(buffer: (void *)file_content.data(), nr_bytes: len, location: start); |
43 | |
44 | ComputeSHA256String(to_hash: file_content, res); |
45 | } |
46 | |
47 | bool ExtensionHelper::TryInitialLoad(DBConfig &config, FileSystem &fs, const string &extension, |
48 | ExtensionInitResult &result, string &error) { |
49 | if (!config.options.enable_external_access) { |
50 | throw PermissionException("Loading external extensions is disabled through configuration" ); |
51 | } |
52 | auto filename = fs.ConvertSeparators(path: extension); |
53 | |
54 | // shorthand case |
55 | if (!ExtensionHelper::IsFullPath(extension)) { |
56 | string local_path = |
57 | !config.options.extension_directory.empty() ? config.options.extension_directory : fs.GetHomeDirectory(); |
58 | |
59 | // convert random separators to platform-canonic |
60 | local_path = fs.ConvertSeparators(path: local_path); |
61 | // expand ~ in extension directory |
62 | local_path = fs.ExpandPath(path: local_path); |
63 | auto path_components = PathComponents(); |
64 | for (auto &path_ele : path_components) { |
65 | local_path = fs.JoinPath(a: local_path, path: path_ele); |
66 | } |
67 | string extension_name = ApplyExtensionAlias(extension_name: extension); |
68 | filename = fs.JoinPath(a: local_path, path: extension_name + ".duckdb_extension" ); |
69 | } |
70 | if (!fs.FileExists(filename)) { |
71 | string message; |
72 | bool exact_match = ExtensionHelper::CreateSuggestions(extension_name: extension, message); |
73 | if (exact_match) { |
74 | message += "\nInstall it first using \"INSTALL " + extension + "\"." ; |
75 | } |
76 | error = StringUtil::Format(fmt_str: "Extension \"%s\" not found.\n%s" , params: filename, params: message); |
77 | return false; |
78 | } |
79 | if (!config.options.allow_unsigned_extensions) { |
80 | auto handle = fs.OpenFile(path: filename, flags: FileFlags::FILE_FLAGS_READ); |
81 | |
82 | // signature is the last 256 bytes of the file |
83 | |
84 | string signature; |
85 | signature.resize(n: 256); |
86 | |
87 | auto signature_offset = handle->GetFileSize() - signature.size(); |
88 | |
89 | const idx_t maxLenChunks = 1024ULL * 1024ULL; |
90 | const idx_t numChunks = (signature_offset + maxLenChunks - 1) / maxLenChunks; |
91 | std::vector<std::string> hash_chunks(numChunks); |
92 | std::vector<idx_t> splits(numChunks + 1); |
93 | |
94 | for (idx_t i = 0; i < numChunks; i++) { |
95 | splits[i] = maxLenChunks * i; |
96 | } |
97 | splits.back() = signature_offset; |
98 | |
99 | #ifndef DUCKDB_NO_THREADS |
100 | std::vector<std::thread> threads; |
101 | threads.reserve(n: numChunks); |
102 | for (idx_t i = 0; i < numChunks; i++) { |
103 | threads.emplace_back(args&: ComputeSHA256FileSegment, args: handle.get(), args&: splits[i], args&: splits[i + 1], args: &hash_chunks[i]); |
104 | } |
105 | |
106 | for (auto &thread : threads) { |
107 | thread.join(); |
108 | } |
109 | #else |
110 | for (idx_t i = 0; i < numChunks; i++) { |
111 | ComputeSHA256FileSegment(handle.get(), splits[i], splits[i + 1], &hash_chunks[i]); |
112 | } |
113 | #endif // DUCKDB_NO_THREADS |
114 | |
115 | string hash_concatenation; |
116 | hash_concatenation.reserve(res_arg: 32 * numChunks); // 256 bits -> 32 bytes per chunk |
117 | |
118 | for (auto &hash_chunk : hash_chunks) { |
119 | hash_concatenation += hash_chunk; |
120 | } |
121 | |
122 | string two_level_hash; |
123 | ComputeSHA256String(to_hash: hash_concatenation, res: &two_level_hash); |
124 | |
125 | // TODO maybe we should do a stream read / hash update here |
126 | handle->Read(buffer: (void *)signature.data(), nr_bytes: signature.size(), location: signature_offset); |
127 | |
128 | bool any_valid = false; |
129 | for (auto &key : ExtensionHelper::GetPublicKeys()) { |
130 | if (duckdb_mbedtls::MbedTlsWrapper::IsValidSha256Signature(pubkey: key, signature, sha256_hash: two_level_hash)) { |
131 | any_valid = true; |
132 | break; |
133 | } |
134 | } |
135 | if (!any_valid) { |
136 | throw IOException(config.error_manager->FormatException(error_type: ErrorType::UNSIGNED_EXTENSION, params: filename)); |
137 | } |
138 | } |
139 | auto basename = fs.ExtractBaseName(path: filename); |
140 | |
141 | #ifdef WASM_LOADABLE_EXTENSIONS |
142 | EM_ASM( |
143 | { |
144 | // Next few lines should argubly in separate JavaScript-land function call |
145 | // TODO: move them out / have them configurable |
146 | const xhr = new XMLHttpRequest(); |
147 | xhr.open("GET" , UTF8ToString($0), false); |
148 | xhr.responseType = "arraybuffer" ; |
149 | xhr.send(null); |
150 | var uInt8Array = xhr.response; |
151 | WebAssembly.validate(uInt8Array); |
152 | console.log('Loading extension ', UTF8ToString($1)); |
153 | |
154 | // Here we add the uInt8Array to Emscripten's filesystem, for it to be found by dlopen |
155 | FS.writeFile(UTF8ToString($1), new Uint8Array(uInt8Array)); |
156 | }, |
157 | filename.c_str(), basename.c_str()); |
158 | auto dopen_from = basename; |
159 | #else |
160 | auto dopen_from = filename; |
161 | #endif |
162 | |
163 | auto lib_hdl = dlopen(file: dopen_from.c_str(), RTLD_NOW | RTLD_LOCAL); |
164 | if (!lib_hdl) { |
165 | throw IOException("Extension \"%s\" could not be loaded: %s" , filename, GetDLError()); |
166 | } |
167 | |
168 | ext_version_fun_t version_fun; |
169 | auto version_fun_name = basename + "_version" ; |
170 | |
171 | version_fun = LoadFunctionFromDLL<ext_version_fun_t>(dll: lib_hdl, function_name: version_fun_name, filename); |
172 | |
173 | std::string engine_version = std::string(DuckDB::LibraryVersion()); |
174 | |
175 | auto version_fun_result = (*version_fun)(); |
176 | if (version_fun_result == nullptr) { |
177 | throw InvalidInputException("Extension \"%s\" returned a nullptr" , filename); |
178 | } |
179 | std::string extension_version = std::string(version_fun_result); |
180 | |
181 | // Trim v's if necessary |
182 | std::string extension_version_trimmed = extension_version; |
183 | std::string engine_version_trimmed = engine_version; |
184 | if (extension_version.length() > 0 && extension_version[0] == 'v') { |
185 | extension_version_trimmed = extension_version.substr(pos: 1); |
186 | } |
187 | if (engine_version.length() > 0 && engine_version[0] == 'v') { |
188 | engine_version_trimmed = engine_version.substr(pos: 1); |
189 | } |
190 | |
191 | if (extension_version_trimmed != engine_version_trimmed) { |
192 | throw InvalidInputException("Extension \"%s\" version (%s) does not match DuckDB version (%s)" , filename, |
193 | extension_version, engine_version); |
194 | } |
195 | |
196 | result.basename = basename; |
197 | result.filename = filename; |
198 | result.lib_hdl = lib_hdl; |
199 | return true; |
200 | } |
201 | |
202 | ExtensionInitResult ExtensionHelper::InitialLoad(DBConfig &config, FileSystem &fs, const string &extension) { |
203 | string error; |
204 | ExtensionInitResult result; |
205 | if (!TryInitialLoad(config, fs, extension, result, error)) { |
206 | if (!ExtensionHelper::AllowAutoInstall(extension)) { |
207 | throw IOException(error); |
208 | } |
209 | // the extension load failed - try installing the extension |
210 | ExtensionHelper::InstallExtension(config, fs, extension, force_install: false); |
211 | // try loading again |
212 | if (!TryInitialLoad(config, fs, extension, result, error)) { |
213 | throw IOException(error); |
214 | } |
215 | } |
216 | return result; |
217 | } |
218 | |
219 | bool ExtensionHelper::IsFullPath(const string &extension) { |
220 | return StringUtil::Contains(haystack: extension, needle: "." ) || StringUtil::Contains(haystack: extension, needle: "/" ) || |
221 | StringUtil::Contains(haystack: extension, needle: "\\" ); |
222 | } |
223 | |
224 | string ExtensionHelper::GetExtensionName(const string &original_name) { |
225 | auto extension = StringUtil::Lower(str: original_name); |
226 | if (!IsFullPath(extension)) { |
227 | return ExtensionHelper::ApplyExtensionAlias(extension_name: extension); |
228 | } |
229 | auto splits = StringUtil::Split(str: StringUtil::Replace(source: extension, from: "\\" , to: "/" ), delimiter: '/'); |
230 | if (splits.empty()) { |
231 | return ExtensionHelper::ApplyExtensionAlias(extension_name: extension); |
232 | } |
233 | splits = StringUtil::Split(str: splits.back(), delimiter: '.'); |
234 | if (splits.empty()) { |
235 | return ExtensionHelper::ApplyExtensionAlias(extension_name: extension); |
236 | } |
237 | return ExtensionHelper::ApplyExtensionAlias(extension_name: splits.front()); |
238 | } |
239 | |
240 | void ExtensionHelper::LoadExternalExtension(DatabaseInstance &db, FileSystem &fs, const string &extension) { |
241 | if (db.ExtensionIsLoaded(name: extension)) { |
242 | return; |
243 | } |
244 | |
245 | auto res = InitialLoad(config&: DBConfig::GetConfig(db), fs, extension); |
246 | auto init_fun_name = res.basename + "_init" ; |
247 | |
248 | ext_init_fun_t init_fun; |
249 | init_fun = LoadFunctionFromDLL<ext_init_fun_t>(dll: res.lib_hdl, function_name: init_fun_name, filename: res.filename); |
250 | |
251 | try { |
252 | (*init_fun)(db); |
253 | } catch (std::exception &e) { |
254 | throw InvalidInputException("Initialization function \"%s\" from file \"%s\" threw an exception: \"%s\"" , |
255 | init_fun_name, res.filename, e.what()); |
256 | } |
257 | |
258 | db.SetExtensionLoaded(extension); |
259 | } |
260 | |
261 | void ExtensionHelper::LoadExternalExtension(ClientContext &context, const string &extension) { |
262 | LoadExternalExtension(db&: DatabaseInstance::GetDatabase(context), fs&: FileSystem::GetFileSystem(context), extension); |
263 | } |
264 | |
265 | string ExtensionHelper::(const string &path) { |
266 | auto first_colon = path.find(c: ':'); |
267 | if (first_colon == string::npos || first_colon < 2) { // needs to be at least two characters because windows c: ... |
268 | return "" ; |
269 | } |
270 | auto extension = path.substr(pos: 0, n: first_colon); |
271 | |
272 | if (path.substr(pos: first_colon, n: 3) == "://" ) { |
273 | // these are not extensions |
274 | return "" ; |
275 | } |
276 | |
277 | D_ASSERT(extension.size() > 1); |
278 | // needs to be alphanumeric |
279 | for (auto &ch : extension) { |
280 | if (!isalnum(ch) && ch != '_') { |
281 | return "" ; |
282 | } |
283 | } |
284 | return extension; |
285 | } |
286 | |
287 | } // namespace duckdb |
288 | |