| 1 | #include "duckdb/common/dl.hpp" |
| 2 | #include "duckdb/common/virtual_file_system.hpp" |
| 3 | #include "duckdb/main/extension_helper.hpp" |
| 4 | #include "duckdb/main/error_manager.hpp" |
| 5 | #include "mbedtls_wrapper.hpp" |
| 6 | |
| 7 | #ifndef DUCKDB_NO_THREADS |
| 8 | #include <thread> |
| 9 | #endif // DUCKDB_NO_THREADS |
| 10 | |
| 11 | #ifdef WASM_LOADABLE_EXTENSIONS |
| 12 | #include <emscripten.h> |
| 13 | #endif |
| 14 | |
| 15 | namespace duckdb { |
| 16 | |
| 17 | //===--------------------------------------------------------------------===// |
| 18 | // Load External Extension |
| 19 | //===--------------------------------------------------------------------===// |
| 20 | typedef void (*ext_init_fun_t)(DatabaseInstance &); |
| 21 | typedef const char *(*ext_version_fun_t)(void); |
| 22 | typedef bool (*ext_is_storage_t)(void); |
| 23 | |
| 24 | template <class T> |
| 25 | static T LoadFunctionFromDLL(void *dll, const string &function_name, const string &filename) { |
| 26 | auto function = dlsym(handle: dll, name: function_name.c_str()); |
| 27 | if (!function) { |
| 28 | throw IOException("File \"%s\" did not contain function \"%s\": %s" , filename, function_name, GetDLError()); |
| 29 | } |
| 30 | return (T)function; |
| 31 | } |
| 32 | |
| 33 | static void ComputeSHA256String(const std::string &to_hash, std::string *res) { |
| 34 | // Invoke MbedTls function to actually compute sha256 |
| 35 | *res = duckdb_mbedtls::MbedTlsWrapper::ComputeSha256Hash(file_content: to_hash); |
| 36 | } |
| 37 | |
| 38 | static void ComputeSHA256FileSegment(FileHandle *handle, const idx_t start, const idx_t end, std::string *res) { |
| 39 | const idx_t len = end - start; |
| 40 | string file_content; |
| 41 | file_content.resize(n: len); |
| 42 | handle->Read(buffer: (void *)file_content.data(), nr_bytes: len, location: start); |
| 43 | |
| 44 | ComputeSHA256String(to_hash: file_content, res); |
| 45 | } |
| 46 | |
| 47 | bool ExtensionHelper::TryInitialLoad(DBConfig &config, FileSystem &fs, const string &extension, |
| 48 | ExtensionInitResult &result, string &error) { |
| 49 | if (!config.options.enable_external_access) { |
| 50 | throw PermissionException("Loading external extensions is disabled through configuration" ); |
| 51 | } |
| 52 | auto filename = fs.ConvertSeparators(path: extension); |
| 53 | |
| 54 | // shorthand case |
| 55 | if (!ExtensionHelper::IsFullPath(extension)) { |
| 56 | string local_path = |
| 57 | !config.options.extension_directory.empty() ? config.options.extension_directory : fs.GetHomeDirectory(); |
| 58 | |
| 59 | // convert random separators to platform-canonic |
| 60 | local_path = fs.ConvertSeparators(path: local_path); |
| 61 | // expand ~ in extension directory |
| 62 | local_path = fs.ExpandPath(path: local_path); |
| 63 | auto path_components = PathComponents(); |
| 64 | for (auto &path_ele : path_components) { |
| 65 | local_path = fs.JoinPath(a: local_path, path: path_ele); |
| 66 | } |
| 67 | string extension_name = ApplyExtensionAlias(extension_name: extension); |
| 68 | filename = fs.JoinPath(a: local_path, path: extension_name + ".duckdb_extension" ); |
| 69 | } |
| 70 | if (!fs.FileExists(filename)) { |
| 71 | string message; |
| 72 | bool exact_match = ExtensionHelper::CreateSuggestions(extension_name: extension, message); |
| 73 | if (exact_match) { |
| 74 | message += "\nInstall it first using \"INSTALL " + extension + "\"." ; |
| 75 | } |
| 76 | error = StringUtil::Format(fmt_str: "Extension \"%s\" not found.\n%s" , params: filename, params: message); |
| 77 | return false; |
| 78 | } |
| 79 | if (!config.options.allow_unsigned_extensions) { |
| 80 | auto handle = fs.OpenFile(path: filename, flags: FileFlags::FILE_FLAGS_READ); |
| 81 | |
| 82 | // signature is the last 256 bytes of the file |
| 83 | |
| 84 | string signature; |
| 85 | signature.resize(n: 256); |
| 86 | |
| 87 | auto signature_offset = handle->GetFileSize() - signature.size(); |
| 88 | |
| 89 | const idx_t maxLenChunks = 1024ULL * 1024ULL; |
| 90 | const idx_t numChunks = (signature_offset + maxLenChunks - 1) / maxLenChunks; |
| 91 | std::vector<std::string> hash_chunks(numChunks); |
| 92 | std::vector<idx_t> splits(numChunks + 1); |
| 93 | |
| 94 | for (idx_t i = 0; i < numChunks; i++) { |
| 95 | splits[i] = maxLenChunks * i; |
| 96 | } |
| 97 | splits.back() = signature_offset; |
| 98 | |
| 99 | #ifndef DUCKDB_NO_THREADS |
| 100 | std::vector<std::thread> threads; |
| 101 | threads.reserve(n: numChunks); |
| 102 | for (idx_t i = 0; i < numChunks; i++) { |
| 103 | threads.emplace_back(args&: ComputeSHA256FileSegment, args: handle.get(), args&: splits[i], args&: splits[i + 1], args: &hash_chunks[i]); |
| 104 | } |
| 105 | |
| 106 | for (auto &thread : threads) { |
| 107 | thread.join(); |
| 108 | } |
| 109 | #else |
| 110 | for (idx_t i = 0; i < numChunks; i++) { |
| 111 | ComputeSHA256FileSegment(handle.get(), splits[i], splits[i + 1], &hash_chunks[i]); |
| 112 | } |
| 113 | #endif // DUCKDB_NO_THREADS |
| 114 | |
| 115 | string hash_concatenation; |
| 116 | hash_concatenation.reserve(res_arg: 32 * numChunks); // 256 bits -> 32 bytes per chunk |
| 117 | |
| 118 | for (auto &hash_chunk : hash_chunks) { |
| 119 | hash_concatenation += hash_chunk; |
| 120 | } |
| 121 | |
| 122 | string two_level_hash; |
| 123 | ComputeSHA256String(to_hash: hash_concatenation, res: &two_level_hash); |
| 124 | |
| 125 | // TODO maybe we should do a stream read / hash update here |
| 126 | handle->Read(buffer: (void *)signature.data(), nr_bytes: signature.size(), location: signature_offset); |
| 127 | |
| 128 | bool any_valid = false; |
| 129 | for (auto &key : ExtensionHelper::GetPublicKeys()) { |
| 130 | if (duckdb_mbedtls::MbedTlsWrapper::IsValidSha256Signature(pubkey: key, signature, sha256_hash: two_level_hash)) { |
| 131 | any_valid = true; |
| 132 | break; |
| 133 | } |
| 134 | } |
| 135 | if (!any_valid) { |
| 136 | throw IOException(config.error_manager->FormatException(error_type: ErrorType::UNSIGNED_EXTENSION, params: filename)); |
| 137 | } |
| 138 | } |
| 139 | auto basename = fs.ExtractBaseName(path: filename); |
| 140 | |
| 141 | #ifdef WASM_LOADABLE_EXTENSIONS |
| 142 | EM_ASM( |
| 143 | { |
| 144 | // Next few lines should argubly in separate JavaScript-land function call |
| 145 | // TODO: move them out / have them configurable |
| 146 | const xhr = new XMLHttpRequest(); |
| 147 | xhr.open("GET" , UTF8ToString($0), false); |
| 148 | xhr.responseType = "arraybuffer" ; |
| 149 | xhr.send(null); |
| 150 | var uInt8Array = xhr.response; |
| 151 | WebAssembly.validate(uInt8Array); |
| 152 | console.log('Loading extension ', UTF8ToString($1)); |
| 153 | |
| 154 | // Here we add the uInt8Array to Emscripten's filesystem, for it to be found by dlopen |
| 155 | FS.writeFile(UTF8ToString($1), new Uint8Array(uInt8Array)); |
| 156 | }, |
| 157 | filename.c_str(), basename.c_str()); |
| 158 | auto dopen_from = basename; |
| 159 | #else |
| 160 | auto dopen_from = filename; |
| 161 | #endif |
| 162 | |
| 163 | auto lib_hdl = dlopen(file: dopen_from.c_str(), RTLD_NOW | RTLD_LOCAL); |
| 164 | if (!lib_hdl) { |
| 165 | throw IOException("Extension \"%s\" could not be loaded: %s" , filename, GetDLError()); |
| 166 | } |
| 167 | |
| 168 | ext_version_fun_t version_fun; |
| 169 | auto version_fun_name = basename + "_version" ; |
| 170 | |
| 171 | version_fun = LoadFunctionFromDLL<ext_version_fun_t>(dll: lib_hdl, function_name: version_fun_name, filename); |
| 172 | |
| 173 | std::string engine_version = std::string(DuckDB::LibraryVersion()); |
| 174 | |
| 175 | auto version_fun_result = (*version_fun)(); |
| 176 | if (version_fun_result == nullptr) { |
| 177 | throw InvalidInputException("Extension \"%s\" returned a nullptr" , filename); |
| 178 | } |
| 179 | std::string extension_version = std::string(version_fun_result); |
| 180 | |
| 181 | // Trim v's if necessary |
| 182 | std::string extension_version_trimmed = extension_version; |
| 183 | std::string engine_version_trimmed = engine_version; |
| 184 | if (extension_version.length() > 0 && extension_version[0] == 'v') { |
| 185 | extension_version_trimmed = extension_version.substr(pos: 1); |
| 186 | } |
| 187 | if (engine_version.length() > 0 && engine_version[0] == 'v') { |
| 188 | engine_version_trimmed = engine_version.substr(pos: 1); |
| 189 | } |
| 190 | |
| 191 | if (extension_version_trimmed != engine_version_trimmed) { |
| 192 | throw InvalidInputException("Extension \"%s\" version (%s) does not match DuckDB version (%s)" , filename, |
| 193 | extension_version, engine_version); |
| 194 | } |
| 195 | |
| 196 | result.basename = basename; |
| 197 | result.filename = filename; |
| 198 | result.lib_hdl = lib_hdl; |
| 199 | return true; |
| 200 | } |
| 201 | |
| 202 | ExtensionInitResult ExtensionHelper::InitialLoad(DBConfig &config, FileSystem &fs, const string &extension) { |
| 203 | string error; |
| 204 | ExtensionInitResult result; |
| 205 | if (!TryInitialLoad(config, fs, extension, result, error)) { |
| 206 | if (!ExtensionHelper::AllowAutoInstall(extension)) { |
| 207 | throw IOException(error); |
| 208 | } |
| 209 | // the extension load failed - try installing the extension |
| 210 | ExtensionHelper::InstallExtension(config, fs, extension, force_install: false); |
| 211 | // try loading again |
| 212 | if (!TryInitialLoad(config, fs, extension, result, error)) { |
| 213 | throw IOException(error); |
| 214 | } |
| 215 | } |
| 216 | return result; |
| 217 | } |
| 218 | |
| 219 | bool ExtensionHelper::IsFullPath(const string &extension) { |
| 220 | return StringUtil::Contains(haystack: extension, needle: "." ) || StringUtil::Contains(haystack: extension, needle: "/" ) || |
| 221 | StringUtil::Contains(haystack: extension, needle: "\\" ); |
| 222 | } |
| 223 | |
| 224 | string ExtensionHelper::GetExtensionName(const string &original_name) { |
| 225 | auto extension = StringUtil::Lower(str: original_name); |
| 226 | if (!IsFullPath(extension)) { |
| 227 | return ExtensionHelper::ApplyExtensionAlias(extension_name: extension); |
| 228 | } |
| 229 | auto splits = StringUtil::Split(str: StringUtil::Replace(source: extension, from: "\\" , to: "/" ), delimiter: '/'); |
| 230 | if (splits.empty()) { |
| 231 | return ExtensionHelper::ApplyExtensionAlias(extension_name: extension); |
| 232 | } |
| 233 | splits = StringUtil::Split(str: splits.back(), delimiter: '.'); |
| 234 | if (splits.empty()) { |
| 235 | return ExtensionHelper::ApplyExtensionAlias(extension_name: extension); |
| 236 | } |
| 237 | return ExtensionHelper::ApplyExtensionAlias(extension_name: splits.front()); |
| 238 | } |
| 239 | |
| 240 | void ExtensionHelper::LoadExternalExtension(DatabaseInstance &db, FileSystem &fs, const string &extension) { |
| 241 | if (db.ExtensionIsLoaded(name: extension)) { |
| 242 | return; |
| 243 | } |
| 244 | |
| 245 | auto res = InitialLoad(config&: DBConfig::GetConfig(db), fs, extension); |
| 246 | auto init_fun_name = res.basename + "_init" ; |
| 247 | |
| 248 | ext_init_fun_t init_fun; |
| 249 | init_fun = LoadFunctionFromDLL<ext_init_fun_t>(dll: res.lib_hdl, function_name: init_fun_name, filename: res.filename); |
| 250 | |
| 251 | try { |
| 252 | (*init_fun)(db); |
| 253 | } catch (std::exception &e) { |
| 254 | throw InvalidInputException("Initialization function \"%s\" from file \"%s\" threw an exception: \"%s\"" , |
| 255 | init_fun_name, res.filename, e.what()); |
| 256 | } |
| 257 | |
| 258 | db.SetExtensionLoaded(extension); |
| 259 | } |
| 260 | |
| 261 | void ExtensionHelper::LoadExternalExtension(ClientContext &context, const string &extension) { |
| 262 | LoadExternalExtension(db&: DatabaseInstance::GetDatabase(context), fs&: FileSystem::GetFileSystem(context), extension); |
| 263 | } |
| 264 | |
| 265 | string ExtensionHelper::(const string &path) { |
| 266 | auto first_colon = path.find(c: ':'); |
| 267 | if (first_colon == string::npos || first_colon < 2) { // needs to be at least two characters because windows c: ... |
| 268 | return "" ; |
| 269 | } |
| 270 | auto extension = path.substr(pos: 0, n: first_colon); |
| 271 | |
| 272 | if (path.substr(pos: first_colon, n: 3) == "://" ) { |
| 273 | // these are not extensions |
| 274 | return "" ; |
| 275 | } |
| 276 | |
| 277 | D_ASSERT(extension.size() > 1); |
| 278 | // needs to be alphanumeric |
| 279 | for (auto &ch : extension) { |
| 280 | if (!isalnum(ch) && ch != '_') { |
| 281 | return "" ; |
| 282 | } |
| 283 | } |
| 284 | return extension; |
| 285 | } |
| 286 | |
| 287 | } // namespace duckdb |
| 288 | |