1#include "duckdb/common/dl.hpp"
2#include "duckdb/common/virtual_file_system.hpp"
3#include "duckdb/main/extension_helper.hpp"
4#include "duckdb/main/error_manager.hpp"
5#include "mbedtls_wrapper.hpp"
6
7#ifndef DUCKDB_NO_THREADS
8#include <thread>
9#endif // DUCKDB_NO_THREADS
10
11#ifdef WASM_LOADABLE_EXTENSIONS
12#include <emscripten.h>
13#endif
14
15namespace duckdb {
16
17//===--------------------------------------------------------------------===//
18// Load External Extension
19//===--------------------------------------------------------------------===//
20typedef void (*ext_init_fun_t)(DatabaseInstance &);
21typedef const char *(*ext_version_fun_t)(void);
22typedef bool (*ext_is_storage_t)(void);
23
24template <class T>
25static T LoadFunctionFromDLL(void *dll, const string &function_name, const string &filename) {
26 auto function = dlsym(handle: dll, name: function_name.c_str());
27 if (!function) {
28 throw IOException("File \"%s\" did not contain function \"%s\": %s", filename, function_name, GetDLError());
29 }
30 return (T)function;
31}
32
33static void ComputeSHA256String(const std::string &to_hash, std::string *res) {
34 // Invoke MbedTls function to actually compute sha256
35 *res = duckdb_mbedtls::MbedTlsWrapper::ComputeSha256Hash(file_content: to_hash);
36}
37
38static void ComputeSHA256FileSegment(FileHandle *handle, const idx_t start, const idx_t end, std::string *res) {
39 const idx_t len = end - start;
40 string file_content;
41 file_content.resize(n: len);
42 handle->Read(buffer: (void *)file_content.data(), nr_bytes: len, location: start);
43
44 ComputeSHA256String(to_hash: file_content, res);
45}
46
47bool ExtensionHelper::TryInitialLoad(DBConfig &config, FileSystem &fs, const string &extension,
48 ExtensionInitResult &result, string &error) {
49 if (!config.options.enable_external_access) {
50 throw PermissionException("Loading external extensions is disabled through configuration");
51 }
52 auto filename = fs.ConvertSeparators(path: extension);
53
54 // shorthand case
55 if (!ExtensionHelper::IsFullPath(extension)) {
56 string local_path =
57 !config.options.extension_directory.empty() ? config.options.extension_directory : fs.GetHomeDirectory();
58
59 // convert random separators to platform-canonic
60 local_path = fs.ConvertSeparators(path: local_path);
61 // expand ~ in extension directory
62 local_path = fs.ExpandPath(path: local_path);
63 auto path_components = PathComponents();
64 for (auto &path_ele : path_components) {
65 local_path = fs.JoinPath(a: local_path, path: path_ele);
66 }
67 string extension_name = ApplyExtensionAlias(extension_name: extension);
68 filename = fs.JoinPath(a: local_path, path: extension_name + ".duckdb_extension");
69 }
70 if (!fs.FileExists(filename)) {
71 string message;
72 bool exact_match = ExtensionHelper::CreateSuggestions(extension_name: extension, message);
73 if (exact_match) {
74 message += "\nInstall it first using \"INSTALL " + extension + "\".";
75 }
76 error = StringUtil::Format(fmt_str: "Extension \"%s\" not found.\n%s", params: filename, params: message);
77 return false;
78 }
79 if (!config.options.allow_unsigned_extensions) {
80 auto handle = fs.OpenFile(path: filename, flags: FileFlags::FILE_FLAGS_READ);
81
82 // signature is the last 256 bytes of the file
83
84 string signature;
85 signature.resize(n: 256);
86
87 auto signature_offset = handle->GetFileSize() - signature.size();
88
89 const idx_t maxLenChunks = 1024ULL * 1024ULL;
90 const idx_t numChunks = (signature_offset + maxLenChunks - 1) / maxLenChunks;
91 std::vector<std::string> hash_chunks(numChunks);
92 std::vector<idx_t> splits(numChunks + 1);
93
94 for (idx_t i = 0; i < numChunks; i++) {
95 splits[i] = maxLenChunks * i;
96 }
97 splits.back() = signature_offset;
98
99#ifndef DUCKDB_NO_THREADS
100 std::vector<std::thread> threads;
101 threads.reserve(n: numChunks);
102 for (idx_t i = 0; i < numChunks; i++) {
103 threads.emplace_back(args&: ComputeSHA256FileSegment, args: handle.get(), args&: splits[i], args&: splits[i + 1], args: &hash_chunks[i]);
104 }
105
106 for (auto &thread : threads) {
107 thread.join();
108 }
109#else
110 for (idx_t i = 0; i < numChunks; i++) {
111 ComputeSHA256FileSegment(handle.get(), splits[i], splits[i + 1], &hash_chunks[i]);
112 }
113#endif // DUCKDB_NO_THREADS
114
115 string hash_concatenation;
116 hash_concatenation.reserve(res_arg: 32 * numChunks); // 256 bits -> 32 bytes per chunk
117
118 for (auto &hash_chunk : hash_chunks) {
119 hash_concatenation += hash_chunk;
120 }
121
122 string two_level_hash;
123 ComputeSHA256String(to_hash: hash_concatenation, res: &two_level_hash);
124
125 // TODO maybe we should do a stream read / hash update here
126 handle->Read(buffer: (void *)signature.data(), nr_bytes: signature.size(), location: signature_offset);
127
128 bool any_valid = false;
129 for (auto &key : ExtensionHelper::GetPublicKeys()) {
130 if (duckdb_mbedtls::MbedTlsWrapper::IsValidSha256Signature(pubkey: key, signature, sha256_hash: two_level_hash)) {
131 any_valid = true;
132 break;
133 }
134 }
135 if (!any_valid) {
136 throw IOException(config.error_manager->FormatException(error_type: ErrorType::UNSIGNED_EXTENSION, params: filename));
137 }
138 }
139 auto basename = fs.ExtractBaseName(path: filename);
140
141#ifdef WASM_LOADABLE_EXTENSIONS
142 EM_ASM(
143 {
144 // Next few lines should argubly in separate JavaScript-land function call
145 // TODO: move them out / have them configurable
146 const xhr = new XMLHttpRequest();
147 xhr.open("GET", UTF8ToString($0), false);
148 xhr.responseType = "arraybuffer";
149 xhr.send(null);
150 var uInt8Array = xhr.response;
151 WebAssembly.validate(uInt8Array);
152 console.log('Loading extension ', UTF8ToString($1));
153
154 // Here we add the uInt8Array to Emscripten's filesystem, for it to be found by dlopen
155 FS.writeFile(UTF8ToString($1), new Uint8Array(uInt8Array));
156 },
157 filename.c_str(), basename.c_str());
158 auto dopen_from = basename;
159#else
160 auto dopen_from = filename;
161#endif
162
163 auto lib_hdl = dlopen(file: dopen_from.c_str(), RTLD_NOW | RTLD_LOCAL);
164 if (!lib_hdl) {
165 throw IOException("Extension \"%s\" could not be loaded: %s", filename, GetDLError());
166 }
167
168 ext_version_fun_t version_fun;
169 auto version_fun_name = basename + "_version";
170
171 version_fun = LoadFunctionFromDLL<ext_version_fun_t>(dll: lib_hdl, function_name: version_fun_name, filename);
172
173 std::string engine_version = std::string(DuckDB::LibraryVersion());
174
175 auto version_fun_result = (*version_fun)();
176 if (version_fun_result == nullptr) {
177 throw InvalidInputException("Extension \"%s\" returned a nullptr", filename);
178 }
179 std::string extension_version = std::string(version_fun_result);
180
181 // Trim v's if necessary
182 std::string extension_version_trimmed = extension_version;
183 std::string engine_version_trimmed = engine_version;
184 if (extension_version.length() > 0 && extension_version[0] == 'v') {
185 extension_version_trimmed = extension_version.substr(pos: 1);
186 }
187 if (engine_version.length() > 0 && engine_version[0] == 'v') {
188 engine_version_trimmed = engine_version.substr(pos: 1);
189 }
190
191 if (extension_version_trimmed != engine_version_trimmed) {
192 throw InvalidInputException("Extension \"%s\" version (%s) does not match DuckDB version (%s)", filename,
193 extension_version, engine_version);
194 }
195
196 result.basename = basename;
197 result.filename = filename;
198 result.lib_hdl = lib_hdl;
199 return true;
200}
201
202ExtensionInitResult ExtensionHelper::InitialLoad(DBConfig &config, FileSystem &fs, const string &extension) {
203 string error;
204 ExtensionInitResult result;
205 if (!TryInitialLoad(config, fs, extension, result, error)) {
206 if (!ExtensionHelper::AllowAutoInstall(extension)) {
207 throw IOException(error);
208 }
209 // the extension load failed - try installing the extension
210 ExtensionHelper::InstallExtension(config, fs, extension, force_install: false);
211 // try loading again
212 if (!TryInitialLoad(config, fs, extension, result, error)) {
213 throw IOException(error);
214 }
215 }
216 return result;
217}
218
219bool ExtensionHelper::IsFullPath(const string &extension) {
220 return StringUtil::Contains(haystack: extension, needle: ".") || StringUtil::Contains(haystack: extension, needle: "/") ||
221 StringUtil::Contains(haystack: extension, needle: "\\");
222}
223
224string ExtensionHelper::GetExtensionName(const string &original_name) {
225 auto extension = StringUtil::Lower(str: original_name);
226 if (!IsFullPath(extension)) {
227 return ExtensionHelper::ApplyExtensionAlias(extension_name: extension);
228 }
229 auto splits = StringUtil::Split(str: StringUtil::Replace(source: extension, from: "\\", to: "/"), delimiter: '/');
230 if (splits.empty()) {
231 return ExtensionHelper::ApplyExtensionAlias(extension_name: extension);
232 }
233 splits = StringUtil::Split(str: splits.back(), delimiter: '.');
234 if (splits.empty()) {
235 return ExtensionHelper::ApplyExtensionAlias(extension_name: extension);
236 }
237 return ExtensionHelper::ApplyExtensionAlias(extension_name: splits.front());
238}
239
240void ExtensionHelper::LoadExternalExtension(DatabaseInstance &db, FileSystem &fs, const string &extension) {
241 if (db.ExtensionIsLoaded(name: extension)) {
242 return;
243 }
244
245 auto res = InitialLoad(config&: DBConfig::GetConfig(db), fs, extension);
246 auto init_fun_name = res.basename + "_init";
247
248 ext_init_fun_t init_fun;
249 init_fun = LoadFunctionFromDLL<ext_init_fun_t>(dll: res.lib_hdl, function_name: init_fun_name, filename: res.filename);
250
251 try {
252 (*init_fun)(db);
253 } catch (std::exception &e) {
254 throw InvalidInputException("Initialization function \"%s\" from file \"%s\" threw an exception: \"%s\"",
255 init_fun_name, res.filename, e.what());
256 }
257
258 db.SetExtensionLoaded(extension);
259}
260
261void ExtensionHelper::LoadExternalExtension(ClientContext &context, const string &extension) {
262 LoadExternalExtension(db&: DatabaseInstance::GetDatabase(context), fs&: FileSystem::GetFileSystem(context), extension);
263}
264
265string ExtensionHelper::ExtractExtensionPrefixFromPath(const string &path) {
266 auto first_colon = path.find(c: ':');
267 if (first_colon == string::npos || first_colon < 2) { // needs to be at least two characters because windows c: ...
268 return "";
269 }
270 auto extension = path.substr(pos: 0, n: first_colon);
271
272 if (path.substr(pos: first_colon, n: 3) == "://") {
273 // these are not extensions
274 return "";
275 }
276
277 D_ASSERT(extension.size() > 1);
278 // needs to be alphanumeric
279 for (auto &ch : extension) {
280 if (!isalnum(ch) && ch != '_') {
281 return "";
282 }
283 }
284 return extension;
285}
286
287} // namespace duckdb
288