| 1 | #include "duckdb/common/dl.hpp" | 
| 2 | #include "duckdb/common/virtual_file_system.hpp" | 
| 3 | #include "duckdb/main/extension_helper.hpp" | 
| 4 | #include "duckdb/main/error_manager.hpp" | 
| 5 | #include "mbedtls_wrapper.hpp" | 
| 6 |  | 
| 7 | #ifndef DUCKDB_NO_THREADS | 
| 8 | #include <thread> | 
| 9 | #endif // DUCKDB_NO_THREADS | 
| 10 |  | 
| 11 | #ifdef WASM_LOADABLE_EXTENSIONS | 
| 12 | #include <emscripten.h> | 
| 13 | #endif | 
| 14 |  | 
| 15 | namespace duckdb { | 
| 16 |  | 
| 17 | //===--------------------------------------------------------------------===// | 
| 18 | // Load External Extension | 
| 19 | //===--------------------------------------------------------------------===// | 
| 20 | typedef void (*ext_init_fun_t)(DatabaseInstance &); | 
| 21 | typedef const char *(*ext_version_fun_t)(void); | 
| 22 | typedef bool (*ext_is_storage_t)(void); | 
| 23 |  | 
| 24 | template <class T> | 
| 25 | static T LoadFunctionFromDLL(void *dll, const string &function_name, const string &filename) { | 
| 26 | 	auto function = dlsym(handle: dll, name: function_name.c_str()); | 
| 27 | 	if (!function) { | 
| 28 | 		throw IOException("File \"%s\" did not contain function \"%s\": %s" , filename, function_name, GetDLError()); | 
| 29 | 	} | 
| 30 | 	return (T)function; | 
| 31 | } | 
| 32 |  | 
| 33 | static void ComputeSHA256String(const std::string &to_hash, std::string *res) { | 
| 34 | 	// Invoke MbedTls function to actually compute sha256 | 
| 35 | 	*res = duckdb_mbedtls::MbedTlsWrapper::ComputeSha256Hash(file_content: to_hash); | 
| 36 | } | 
| 37 |  | 
| 38 | static void ComputeSHA256FileSegment(FileHandle *handle, const idx_t start, const idx_t end, std::string *res) { | 
| 39 | 	const idx_t len = end - start; | 
| 40 | 	string file_content; | 
| 41 | 	file_content.resize(n: len); | 
| 42 | 	handle->Read(buffer: (void *)file_content.data(), nr_bytes: len, location: start); | 
| 43 |  | 
| 44 | 	ComputeSHA256String(to_hash: file_content, res); | 
| 45 | } | 
| 46 |  | 
| 47 | bool ExtensionHelper::TryInitialLoad(DBConfig &config, FileSystem &fs, const string &extension, | 
| 48 |                                      ExtensionInitResult &result, string &error) { | 
| 49 | 	if (!config.options.enable_external_access) { | 
| 50 | 		throw PermissionException("Loading external extensions is disabled through configuration" ); | 
| 51 | 	} | 
| 52 | 	auto filename = fs.ConvertSeparators(path: extension); | 
| 53 |  | 
| 54 | 	// shorthand case | 
| 55 | 	if (!ExtensionHelper::IsFullPath(extension)) { | 
| 56 | 		string local_path = | 
| 57 | 		    !config.options.extension_directory.empty() ? config.options.extension_directory : fs.GetHomeDirectory(); | 
| 58 |  | 
| 59 | 		// convert random separators to platform-canonic | 
| 60 | 		local_path = fs.ConvertSeparators(path: local_path); | 
| 61 | 		// expand ~ in extension directory | 
| 62 | 		local_path = fs.ExpandPath(path: local_path); | 
| 63 | 		auto path_components = PathComponents(); | 
| 64 | 		for (auto &path_ele : path_components) { | 
| 65 | 			local_path = fs.JoinPath(a: local_path, path: path_ele); | 
| 66 | 		} | 
| 67 | 		string extension_name = ApplyExtensionAlias(extension_name: extension); | 
| 68 | 		filename = fs.JoinPath(a: local_path, path: extension_name + ".duckdb_extension" ); | 
| 69 | 	} | 
| 70 | 	if (!fs.FileExists(filename)) { | 
| 71 | 		string message; | 
| 72 | 		bool exact_match = ExtensionHelper::CreateSuggestions(extension_name: extension, message); | 
| 73 | 		if (exact_match) { | 
| 74 | 			message += "\nInstall it first using \"INSTALL "  + extension + "\"." ; | 
| 75 | 		} | 
| 76 | 		error = StringUtil::Format(fmt_str: "Extension \"%s\" not found.\n%s" , params: filename, params: message); | 
| 77 | 		return false; | 
| 78 | 	} | 
| 79 | 	if (!config.options.allow_unsigned_extensions) { | 
| 80 | 		auto handle = fs.OpenFile(path: filename, flags: FileFlags::FILE_FLAGS_READ); | 
| 81 |  | 
| 82 | 		// signature is the last 256 bytes of the file | 
| 83 |  | 
| 84 | 		string signature; | 
| 85 | 		signature.resize(n: 256); | 
| 86 |  | 
| 87 | 		auto signature_offset = handle->GetFileSize() - signature.size(); | 
| 88 |  | 
| 89 | 		const idx_t maxLenChunks = 1024ULL * 1024ULL; | 
| 90 | 		const idx_t numChunks = (signature_offset + maxLenChunks - 1) / maxLenChunks; | 
| 91 | 		std::vector<std::string> hash_chunks(numChunks); | 
| 92 | 		std::vector<idx_t> splits(numChunks + 1); | 
| 93 |  | 
| 94 | 		for (idx_t i = 0; i < numChunks; i++) { | 
| 95 | 			splits[i] = maxLenChunks * i; | 
| 96 | 		} | 
| 97 | 		splits.back() = signature_offset; | 
| 98 |  | 
| 99 | #ifndef DUCKDB_NO_THREADS | 
| 100 | 		std::vector<std::thread> threads; | 
| 101 | 		threads.reserve(n: numChunks); | 
| 102 | 		for (idx_t i = 0; i < numChunks; i++) { | 
| 103 | 			threads.emplace_back(args&: ComputeSHA256FileSegment, args: handle.get(), args&: splits[i], args&: splits[i + 1], args: &hash_chunks[i]); | 
| 104 | 		} | 
| 105 |  | 
| 106 | 		for (auto &thread : threads) { | 
| 107 | 			thread.join(); | 
| 108 | 		} | 
| 109 | #else | 
| 110 | 		for (idx_t i = 0; i < numChunks; i++) { | 
| 111 | 			ComputeSHA256FileSegment(handle.get(), splits[i], splits[i + 1], &hash_chunks[i]); | 
| 112 | 		} | 
| 113 | #endif // DUCKDB_NO_THREADS | 
| 114 |  | 
| 115 | 		string hash_concatenation; | 
| 116 | 		hash_concatenation.reserve(res_arg: 32 * numChunks); // 256 bits -> 32 bytes per chunk | 
| 117 |  | 
| 118 | 		for (auto &hash_chunk : hash_chunks) { | 
| 119 | 			hash_concatenation += hash_chunk; | 
| 120 | 		} | 
| 121 |  | 
| 122 | 		string two_level_hash; | 
| 123 | 		ComputeSHA256String(to_hash: hash_concatenation, res: &two_level_hash); | 
| 124 |  | 
| 125 | 		// TODO maybe we should do a stream read / hash update here | 
| 126 | 		handle->Read(buffer: (void *)signature.data(), nr_bytes: signature.size(), location: signature_offset); | 
| 127 |  | 
| 128 | 		bool any_valid = false; | 
| 129 | 		for (auto &key : ExtensionHelper::GetPublicKeys()) { | 
| 130 | 			if (duckdb_mbedtls::MbedTlsWrapper::IsValidSha256Signature(pubkey: key, signature, sha256_hash: two_level_hash)) { | 
| 131 | 				any_valid = true; | 
| 132 | 				break; | 
| 133 | 			} | 
| 134 | 		} | 
| 135 | 		if (!any_valid) { | 
| 136 | 			throw IOException(config.error_manager->FormatException(error_type: ErrorType::UNSIGNED_EXTENSION, params: filename)); | 
| 137 | 		} | 
| 138 | 	} | 
| 139 | 	auto basename = fs.ExtractBaseName(path: filename); | 
| 140 |  | 
| 141 | #ifdef WASM_LOADABLE_EXTENSIONS | 
| 142 | 	EM_ASM( | 
| 143 | 	    { | 
| 144 | 		    // Next few lines should argubly in separate JavaScript-land function call | 
| 145 | 		    // TODO: move them out / have them configurable | 
| 146 | 		    const xhr = new XMLHttpRequest(); | 
| 147 | 		    xhr.open("GET" , UTF8ToString($0), false); | 
| 148 | 		    xhr.responseType = "arraybuffer" ; | 
| 149 | 		    xhr.send(null); | 
| 150 | 		    var uInt8Array = xhr.response; | 
| 151 | 		    WebAssembly.validate(uInt8Array); | 
| 152 | 		    console.log('Loading extension ', UTF8ToString($1)); | 
| 153 |  | 
| 154 | 		    // Here we add the uInt8Array to Emscripten's filesystem, for it to be found by dlopen | 
| 155 | 		    FS.writeFile(UTF8ToString($1), new Uint8Array(uInt8Array)); | 
| 156 | 	    }, | 
| 157 | 	    filename.c_str(), basename.c_str()); | 
| 158 | 	auto dopen_from = basename; | 
| 159 | #else | 
| 160 | 	auto dopen_from = filename; | 
| 161 | #endif | 
| 162 |  | 
| 163 | 	auto lib_hdl = dlopen(file: dopen_from.c_str(), RTLD_NOW | RTLD_LOCAL); | 
| 164 | 	if (!lib_hdl) { | 
| 165 | 		throw IOException("Extension \"%s\" could not be loaded: %s" , filename, GetDLError()); | 
| 166 | 	} | 
| 167 |  | 
| 168 | 	ext_version_fun_t version_fun; | 
| 169 | 	auto version_fun_name = basename + "_version" ; | 
| 170 |  | 
| 171 | 	version_fun = LoadFunctionFromDLL<ext_version_fun_t>(dll: lib_hdl, function_name: version_fun_name, filename); | 
| 172 |  | 
| 173 | 	std::string engine_version = std::string(DuckDB::LibraryVersion()); | 
| 174 |  | 
| 175 | 	auto version_fun_result = (*version_fun)(); | 
| 176 | 	if (version_fun_result == nullptr) { | 
| 177 | 		throw InvalidInputException("Extension \"%s\" returned a nullptr" , filename); | 
| 178 | 	} | 
| 179 | 	std::string extension_version = std::string(version_fun_result); | 
| 180 |  | 
| 181 | 	// Trim v's if necessary | 
| 182 | 	std::string extension_version_trimmed = extension_version; | 
| 183 | 	std::string engine_version_trimmed = engine_version; | 
| 184 | 	if (extension_version.length() > 0 && extension_version[0] == 'v') { | 
| 185 | 		extension_version_trimmed = extension_version.substr(pos: 1); | 
| 186 | 	} | 
| 187 | 	if (engine_version.length() > 0 && engine_version[0] == 'v') { | 
| 188 | 		engine_version_trimmed = engine_version.substr(pos: 1); | 
| 189 | 	} | 
| 190 |  | 
| 191 | 	if (extension_version_trimmed != engine_version_trimmed) { | 
| 192 | 		throw InvalidInputException("Extension \"%s\" version (%s) does not match DuckDB version (%s)" , filename, | 
| 193 | 		                            extension_version, engine_version); | 
| 194 | 	} | 
| 195 |  | 
| 196 | 	result.basename = basename; | 
| 197 | 	result.filename = filename; | 
| 198 | 	result.lib_hdl = lib_hdl; | 
| 199 | 	return true; | 
| 200 | } | 
| 201 |  | 
| 202 | ExtensionInitResult ExtensionHelper::InitialLoad(DBConfig &config, FileSystem &fs, const string &extension) { | 
| 203 | 	string error; | 
| 204 | 	ExtensionInitResult result; | 
| 205 | 	if (!TryInitialLoad(config, fs, extension, result, error)) { | 
| 206 | 		if (!ExtensionHelper::AllowAutoInstall(extension)) { | 
| 207 | 			throw IOException(error); | 
| 208 | 		} | 
| 209 | 		// the extension load failed - try installing the extension | 
| 210 | 		ExtensionHelper::InstallExtension(config, fs, extension, force_install: false); | 
| 211 | 		// try loading again | 
| 212 | 		if (!TryInitialLoad(config, fs, extension, result, error)) { | 
| 213 | 			throw IOException(error); | 
| 214 | 		} | 
| 215 | 	} | 
| 216 | 	return result; | 
| 217 | } | 
| 218 |  | 
| 219 | bool ExtensionHelper::IsFullPath(const string &extension) { | 
| 220 | 	return StringUtil::Contains(haystack: extension, needle: "." ) || StringUtil::Contains(haystack: extension, needle: "/" ) || | 
| 221 | 	       StringUtil::Contains(haystack: extension, needle: "\\" ); | 
| 222 | } | 
| 223 |  | 
| 224 | string ExtensionHelper::GetExtensionName(const string &original_name) { | 
| 225 | 	auto extension = StringUtil::Lower(str: original_name); | 
| 226 | 	if (!IsFullPath(extension)) { | 
| 227 | 		return ExtensionHelper::ApplyExtensionAlias(extension_name: extension); | 
| 228 | 	} | 
| 229 | 	auto splits = StringUtil::Split(str: StringUtil::Replace(source: extension, from: "\\" , to: "/" ), delimiter: '/'); | 
| 230 | 	if (splits.empty()) { | 
| 231 | 		return ExtensionHelper::ApplyExtensionAlias(extension_name: extension); | 
| 232 | 	} | 
| 233 | 	splits = StringUtil::Split(str: splits.back(), delimiter: '.'); | 
| 234 | 	if (splits.empty()) { | 
| 235 | 		return ExtensionHelper::ApplyExtensionAlias(extension_name: extension); | 
| 236 | 	} | 
| 237 | 	return ExtensionHelper::ApplyExtensionAlias(extension_name: splits.front()); | 
| 238 | } | 
| 239 |  | 
| 240 | void ExtensionHelper::LoadExternalExtension(DatabaseInstance &db, FileSystem &fs, const string &extension) { | 
| 241 | 	if (db.ExtensionIsLoaded(name: extension)) { | 
| 242 | 		return; | 
| 243 | 	} | 
| 244 |  | 
| 245 | 	auto res = InitialLoad(config&: DBConfig::GetConfig(db), fs, extension); | 
| 246 | 	auto init_fun_name = res.basename + "_init" ; | 
| 247 |  | 
| 248 | 	ext_init_fun_t init_fun; | 
| 249 | 	init_fun = LoadFunctionFromDLL<ext_init_fun_t>(dll: res.lib_hdl, function_name: init_fun_name, filename: res.filename); | 
| 250 |  | 
| 251 | 	try { | 
| 252 | 		(*init_fun)(db); | 
| 253 | 	} catch (std::exception &e) { | 
| 254 | 		throw InvalidInputException("Initialization function \"%s\" from file \"%s\" threw an exception: \"%s\"" , | 
| 255 | 		                            init_fun_name, res.filename, e.what()); | 
| 256 | 	} | 
| 257 |  | 
| 258 | 	db.SetExtensionLoaded(extension); | 
| 259 | } | 
| 260 |  | 
| 261 | void ExtensionHelper::LoadExternalExtension(ClientContext &context, const string &extension) { | 
| 262 | 	LoadExternalExtension(db&: DatabaseInstance::GetDatabase(context), fs&: FileSystem::GetFileSystem(context), extension); | 
| 263 | } | 
| 264 |  | 
| 265 | string ExtensionHelper::(const string &path) { | 
| 266 | 	auto first_colon = path.find(c: ':'); | 
| 267 | 	if (first_colon == string::npos || first_colon < 2) { // needs to be at least two characters because windows c: ... | 
| 268 | 		return "" ; | 
| 269 | 	} | 
| 270 | 	auto extension = path.substr(pos: 0, n: first_colon); | 
| 271 |  | 
| 272 | 	if (path.substr(pos: first_colon, n: 3) == "://" ) { | 
| 273 | 		// these are not extensions | 
| 274 | 		return "" ; | 
| 275 | 	} | 
| 276 |  | 
| 277 | 	D_ASSERT(extension.size() > 1); | 
| 278 | 	// needs to be alphanumeric | 
| 279 | 	for (auto &ch : extension) { | 
| 280 | 		if (!isalnum(ch) && ch != '_') { | 
| 281 | 			return "" ; | 
| 282 | 		} | 
| 283 | 	} | 
| 284 | 	return extension; | 
| 285 | } | 
| 286 |  | 
| 287 | } // namespace duckdb | 
| 288 |  |