| 1 | #include <Interpreters/ExpressionJIT.h> |
| 2 | |
| 3 | #if USE_EMBEDDED_COMPILER |
| 4 | |
| 5 | #include <optional> |
| 6 | |
| 7 | #include <Columns/ColumnConst.h> |
| 8 | #include <Columns/ColumnNullable.h> |
| 9 | #include <Columns/ColumnVector.h> |
| 10 | #include <Common/LRUCache.h> |
| 11 | #include <Common/typeid_cast.h> |
| 12 | #include <Common/assert_cast.h> |
| 13 | #include <Common/ProfileEvents.h> |
| 14 | #include <Common/Stopwatch.h> |
| 15 | #include <DataTypes/DataTypeNullable.h> |
| 16 | #include <DataTypes/DataTypesNumber.h> |
| 17 | #include <DataTypes/Native.h> |
| 18 | #include <Functions/IFunctionAdaptors.h> |
| 19 | |
| 20 | #pragma GCC diagnostic push |
| 21 | #pragma GCC diagnostic ignored "-Wunused-parameter" |
| 22 | #pragma GCC diagnostic ignored "-Wnon-virtual-dtor" |
| 23 | |
| 24 | #include <llvm/Analysis/TargetTransformInfo.h> |
| 25 | #include <llvm/IR/BasicBlock.h> |
| 26 | #include <llvm/IR/DataLayout.h> |
| 27 | #include <llvm/IR/DerivedTypes.h> |
| 28 | #include <llvm/IR/Function.h> |
| 29 | #include <llvm/IR/IRBuilder.h> |
| 30 | #include <llvm/IR/LLVMContext.h> |
| 31 | #include <llvm/IR/Mangler.h> |
| 32 | #include <llvm/IR/Module.h> |
| 33 | #include <llvm/IR/Type.h> |
| 34 | #include <llvm/IR/LegacyPassManager.h> |
| 35 | #include <llvm/ExecutionEngine/ExecutionEngine.h> |
| 36 | #include <llvm/ExecutionEngine/JITSymbol.h> |
| 37 | #include <llvm/ExecutionEngine/SectionMemoryManager.h> |
| 38 | #include <llvm/ExecutionEngine/Orc/CompileUtils.h> |
| 39 | #include <llvm/ExecutionEngine/Orc/IRCompileLayer.h> |
| 40 | #include <llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h> |
| 41 | #include <llvm/Target/TargetMachine.h> |
| 42 | #include <llvm/MC/SubtargetFeature.h> |
| 43 | #include <llvm/Support/DynamicLibrary.h> |
| 44 | #include <llvm/Support/Host.h> |
| 45 | #include <llvm/Support/TargetRegistry.h> |
| 46 | #include <llvm/Support/TargetSelect.h> |
| 47 | #include <llvm/Transforms/IPO/PassManagerBuilder.h> |
| 48 | |
| 49 | #pragma GCC diagnostic pop |
| 50 | |
| 51 | /// 'LegacyRTDyldObjectLinkingLayer' is deprecated: ORCv1 layers (layers with the 'Legacy' prefix) are deprecated. Please use ORCv2 |
| 52 | /// 'LegacyIRCompileLayer' is deprecated: ORCv1 layers (layers with the 'Legacy' prefix) are deprecated. Please use the ORCv2 IRCompileLayer instead |
| 53 | #pragma GCC diagnostic ignored "-Wdeprecated-declarations" |
| 54 | |
| 55 | |
| 56 | namespace ProfileEvents |
| 57 | { |
| 58 | extern const Event CompileFunction; |
| 59 | extern const Event CompileExpressionsMicroseconds; |
| 60 | extern const Event CompileExpressionsBytes; |
| 61 | } |
| 62 | |
| 63 | namespace DB |
| 64 | { |
| 65 | |
/// Error codes thrown from this file; they are defined elsewhere.
namespace ErrorCodes
{
    extern const int CANNOT_COMPILE_CODE;
    extern const int LOGICAL_ERROR;
}
| 71 | |
| 72 | namespace |
| 73 | { |
| 74 | struct ColumnData |
| 75 | { |
| 76 | const char * data = nullptr; |
| 77 | const char * null = nullptr; |
| 78 | size_t stride = 0; |
| 79 | }; |
| 80 | |
| 81 | struct ColumnDataPlaceholder |
| 82 | { |
| 83 | llvm::Value * data_init; /// first row |
| 84 | llvm::Value * null_init; |
| 85 | llvm::Value * stride; |
| 86 | llvm::PHINode * data; /// current row |
| 87 | llvm::PHINode * null; |
| 88 | }; |
| 89 | } |
| 90 | |
| 91 | static ColumnData getColumnData(const IColumn * column) |
| 92 | { |
| 93 | ColumnData result; |
| 94 | const bool is_const = isColumnConst(*column); |
| 95 | if (is_const) |
| 96 | column = &reinterpret_cast<const ColumnConst *>(column)->getDataColumn(); |
| 97 | if (auto * nullable = typeid_cast<const ColumnNullable *>(column)) |
| 98 | { |
| 99 | result.null = nullable->getNullMapColumn().getRawData().data; |
| 100 | column = &nullable->getNestedColumn(); |
| 101 | } |
| 102 | result.data = column->getRawData().data; |
| 103 | result.stride = is_const ? 0 : column->sizeOfValueIfFixed(); |
| 104 | return result; |
| 105 | } |
| 106 | |
| 107 | static void applyFunction(IFunctionBase & function, Field & value) |
| 108 | { |
| 109 | const auto & type = function.getArgumentTypes().at(0); |
| 110 | Block block = {{ type->createColumnConst(1, value), type, "x" }, { nullptr, function.getReturnType(), "y" }}; |
| 111 | function.execute(block, {0}, 1, 1); |
| 112 | block.safeGetByPosition(1).column->get(0, value); |
| 113 | } |
| 114 | |
| 115 | static llvm::TargetMachine * getNativeMachine() |
| 116 | { |
| 117 | std::string error; |
| 118 | auto cpu = llvm::sys::getHostCPUName(); |
| 119 | auto triple = llvm::sys::getProcessTriple(); |
| 120 | auto target = llvm::TargetRegistry::lookupTarget(triple, error); |
| 121 | if (!target) |
| 122 | throw Exception("Could not initialize native target: " + error, ErrorCodes::CANNOT_COMPILE_CODE); |
| 123 | llvm::SubtargetFeatures features; |
| 124 | llvm::StringMap<bool> feature_map; |
| 125 | if (llvm::sys::getHostCPUFeatures(feature_map)) |
| 126 | for (auto & f : feature_map) |
| 127 | features.AddFeature(f.first(), f.second); |
| 128 | llvm::TargetOptions options; |
| 129 | return target->createTargetMachine( |
| 130 | triple, cpu, features.getString(), options, llvm::None, |
| 131 | llvm::None, llvm::CodeGenOpt::Default, /*jit=*/true |
| 132 | ); |
| 133 | } |
| 134 | |
| 135 | |
| 136 | struct SymbolResolver : public llvm::orc::SymbolResolver |
| 137 | { |
| 138 | llvm::LegacyJITSymbolResolver & impl; |
| 139 | |
| 140 | SymbolResolver(llvm::LegacyJITSymbolResolver & impl_) : impl(impl_) {} |
| 141 | |
| 142 | llvm::orc::SymbolNameSet getResponsibilitySet(const llvm::orc::SymbolNameSet & symbols) final |
| 143 | { |
| 144 | return symbols; |
| 145 | } |
| 146 | |
| 147 | llvm::orc::SymbolNameSet lookup(std::shared_ptr<llvm::orc::AsynchronousSymbolQuery> query, llvm::orc::SymbolNameSet symbols) final |
| 148 | { |
| 149 | llvm::orc::SymbolNameSet missing; |
| 150 | for (const auto & symbol : symbols) |
| 151 | { |
| 152 | bool has_resolved = false; |
| 153 | impl.lookup({*symbol}, [&](llvm::Expected<llvm::JITSymbolResolver::LookupResult> resolved) |
| 154 | { |
| 155 | if (resolved && resolved->size()) |
| 156 | { |
| 157 | query->notifySymbolMetRequiredState(symbol, resolved->begin()->second); |
| 158 | has_resolved = true; |
| 159 | } |
| 160 | }); |
| 161 | |
| 162 | if (!has_resolved) |
| 163 | missing.insert(symbol); |
| 164 | } |
| 165 | return missing; |
| 166 | } |
| 167 | }; |
| 168 | |
| 169 | |
| 170 | struct LLVMContext |
| 171 | { |
| 172 | std::shared_ptr<llvm::LLVMContext> context {std::make_shared<llvm::LLVMContext>()}; |
| 173 | std::unique_ptr<llvm::Module> module {std::make_unique<llvm::Module>("jit" , *context)}; |
| 174 | std::unique_ptr<llvm::TargetMachine> machine {getNativeMachine()}; |
| 175 | llvm::DataLayout layout {machine->createDataLayout()}; |
| 176 | llvm::IRBuilder<> builder {*context}; |
| 177 | |
| 178 | llvm::orc::ExecutionSession execution_session; |
| 179 | |
| 180 | std::shared_ptr<llvm::SectionMemoryManager> memory_manager; |
| 181 | llvm::orc::LegacyRTDyldObjectLinkingLayer object_layer; |
| 182 | llvm::orc::LegacyIRCompileLayer<decltype(object_layer), llvm::orc::SimpleCompiler> compile_layer; |
| 183 | |
| 184 | std::unordered_map<std::string, void *> symbols; |
| 185 | |
| 186 | LLVMContext() |
| 187 | : memory_manager(std::make_shared<llvm::SectionMemoryManager>()) |
| 188 | , object_layer(execution_session, [this](llvm::orc::VModuleKey) |
| 189 | { |
| 190 | return llvm::orc::LegacyRTDyldObjectLinkingLayer::Resources{memory_manager, std::make_shared<SymbolResolver>(*memory_manager)}; |
| 191 | }) |
| 192 | , compile_layer(object_layer, llvm::orc::SimpleCompiler(*machine)) |
| 193 | { |
| 194 | module->setDataLayout(layout); |
| 195 | module->setTargetTriple(machine->getTargetTriple().getTriple()); |
| 196 | } |
| 197 | |
| 198 | /// returns used memory |
| 199 | void compileAllFunctionsToNativeCode() |
| 200 | { |
| 201 | if (!module->size()) |
| 202 | return; |
| 203 | llvm::PassManagerBuilder pass_manager_builder; |
| 204 | llvm::legacy::PassManager mpm; |
| 205 | llvm::legacy::FunctionPassManager fpm(module.get()); |
| 206 | pass_manager_builder.OptLevel = 3; |
| 207 | pass_manager_builder.SLPVectorize = true; |
| 208 | pass_manager_builder.LoopVectorize = true; |
| 209 | pass_manager_builder.RerollLoops = true; |
| 210 | pass_manager_builder.VerifyInput = true; |
| 211 | pass_manager_builder.VerifyOutput = true; |
| 212 | machine->adjustPassManager(pass_manager_builder); |
| 213 | fpm.add(llvm::createTargetTransformInfoWrapperPass(machine->getTargetIRAnalysis())); |
| 214 | mpm.add(llvm::createTargetTransformInfoWrapperPass(machine->getTargetIRAnalysis())); |
| 215 | pass_manager_builder.populateFunctionPassManager(fpm); |
| 216 | pass_manager_builder.populateModulePassManager(mpm); |
| 217 | fpm.doInitialization(); |
| 218 | for (auto & function : *module) |
| 219 | fpm.run(function); |
| 220 | fpm.doFinalization(); |
| 221 | mpm.run(*module); |
| 222 | |
| 223 | std::vector<std::string> functions; |
| 224 | functions.reserve(module->size()); |
| 225 | for (const auto & function : *module) |
| 226 | functions.emplace_back(function.getName()); |
| 227 | |
| 228 | llvm::orc::VModuleKey module_key = execution_session.allocateVModule(); |
| 229 | if (compile_layer.addModule(module_key, std::move(module))) |
| 230 | throw Exception("Cannot add module to compile layer" , ErrorCodes::CANNOT_COMPILE_CODE); |
| 231 | |
| 232 | for (const auto & name : functions) |
| 233 | { |
| 234 | std::string mangled_name; |
| 235 | llvm::raw_string_ostream mangled_name_stream(mangled_name); |
| 236 | llvm::Mangler::getNameWithPrefix(mangled_name_stream, name, layout); |
| 237 | mangled_name_stream.flush(); |
| 238 | auto symbol = compile_layer.findSymbol(mangled_name, false); |
| 239 | if (!symbol) |
| 240 | continue; /// external function (e.g. an intrinsic that calls into libc) |
| 241 | auto address = symbol.getAddress(); |
| 242 | if (!address) |
| 243 | throw Exception("Function " + name + " failed to link" , ErrorCodes::CANNOT_COMPILE_CODE); |
| 244 | symbols[name] = reinterpret_cast<void *>(*address); |
| 245 | } |
| 246 | } |
| 247 | }; |
| 248 | |
| 249 | |
| 250 | template <typename... Ts, typename F> |
| 251 | static bool castToEither(IColumn * column, F && f) |
| 252 | { |
| 253 | return ((typeid_cast<Ts *>(column) ? f(*typeid_cast<Ts *>(column)) : false) || ...); |
| 254 | } |
| 255 | |
| 256 | class LLVMExecutableFunction : public IExecutableFunctionImpl |
| 257 | { |
| 258 | std::string name; |
| 259 | void * function; |
| 260 | |
| 261 | public: |
| 262 | LLVMExecutableFunction(const std::string & name_, const std::unordered_map<std::string, void *> & symbols) |
| 263 | : name(name_) |
| 264 | { |
| 265 | auto it = symbols.find(name); |
| 266 | if (symbols.end() == it) |
| 267 | throw Exception("Cannot find symbol " + name + " in LLVMContext" , ErrorCodes::LOGICAL_ERROR); |
| 268 | function = it->second; |
| 269 | } |
| 270 | |
| 271 | String getName() const override { return name; } |
| 272 | |
| 273 | bool useDefaultImplementationForNulls() const override { return false; } |
| 274 | |
| 275 | bool useDefaultImplementationForConstants() const override { return true; } |
| 276 | |
| 277 | void execute(Block & block, const ColumnNumbers & arguments, size_t result, size_t block_size) override |
| 278 | { |
| 279 | auto col_res = block.getByPosition(result).type->createColumn(); |
| 280 | |
| 281 | if (block_size) |
| 282 | { |
| 283 | if (!castToEither< |
| 284 | ColumnUInt8, ColumnUInt16, ColumnUInt32, ColumnUInt64, |
| 285 | ColumnInt8, ColumnInt16, ColumnInt32, ColumnInt64, |
| 286 | ColumnFloat32, ColumnFloat64>(col_res.get(), [block_size](auto & col) { col.getData().resize(block_size); return true; })) |
| 287 | throw Exception("Unexpected column in LLVMExecutableFunction: " + col_res->getName(), ErrorCodes::LOGICAL_ERROR); |
| 288 | |
| 289 | std::vector<ColumnData> columns(arguments.size() + 1); |
| 290 | for (size_t i = 0; i < arguments.size(); ++i) |
| 291 | { |
| 292 | auto * column = block.getByPosition(arguments[i]).column.get(); |
| 293 | if (!column) |
| 294 | throw Exception("Column " + block.getByPosition(arguments[i]).name + " is missing" , ErrorCodes::LOGICAL_ERROR); |
| 295 | columns[i] = getColumnData(column); |
| 296 | } |
| 297 | columns[arguments.size()] = getColumnData(col_res.get()); |
| 298 | reinterpret_cast<void (*) (size_t, ColumnData *)>(function)(block_size, columns.data()); |
| 299 | } |
| 300 | |
| 301 | block.getByPosition(result).column = std::move(col_res); |
| 302 | } |
| 303 | }; |
| 304 | |
/// Emits into `context.module` an externally visible function named `f.getName()` with the signature
/// `void(size_t rows, ColumnData * columns)`. `columns` has one entry per argument of `f`, followed by
/// one for the result. For each row, the function evaluates the IR produced by `f.compile` and stores
/// the value (and, for a Nullable result, a 0/1 null byte) into the result column.
/// The caller must pass rows >= 1: the loop body runs before the exit test.
static void compileFunctionToLLVMByteCode(LLVMContext & context, const IFunctionBaseImpl & f)
{
    ProfileEvents::increment(ProfileEvents::CompileFunction);

    auto & arg_types = f.getArgumentTypes();
    auto & b = context.builder;
    /// IR mirror of ColumnData: { i8 * data, i8 * null, size_t stride }.
    auto * size_type = b.getIntNTy(sizeof(size_t) * 8);
    auto * data_type = llvm::StructType::get(b.getInt8PtrTy(), b.getInt8PtrTy(), size_type);
    auto * func_type = llvm::FunctionType::get(b.getVoidTy(), { size_type, data_type->getPointerTo() }, /*isVarArg=*/false);
    auto * func = llvm::Function::Create(func_type, llvm::Function::ExternalLinkage, f.getName(), context.module.get());
    auto args = func->args().begin();
    llvm::Value * counter_arg = &*args++;
    llvm::Value * columns_arg = &*args++;

    /// Entry block: load every ColumnData and derive typed pointers to the first row (and first null byte).
    auto * entry = llvm::BasicBlock::Create(b.getContext(), "entry" , func);
    b.SetInsertPoint(entry);
    std::vector<ColumnDataPlaceholder> columns(arg_types.size() + 1);
    for (size_t i = 0; i <= arg_types.size(); ++i)
    {
        /// The last descriptor belongs to the result column.
        auto & type = i == arg_types.size() ? f.getReturnType() : arg_types[i];
        auto * data = b.CreateLoad(b.CreateConstInBoundsGEP1_32(data_type, columns_arg, i));
        columns[i].data_init = b.CreatePointerCast(b.CreateExtractValue(data, {0}), toNativeType(b, removeNullable(type))->getPointerTo());
        /// The null-map pointer is only used for nullable types; otherwise the placeholder stays nullptr.
        columns[i].null_init = type->isNullable() ? b.CreateExtractValue(data, {1}) : nullptr;
        columns[i].stride = b.CreateExtractValue(data, {2});
    }

    /// assume nonzero initial value in `counter_arg`
    auto * loop = llvm::BasicBlock::Create(b.getContext(), "loop" , func);
    b.CreateBr(loop);
    b.SetInsertPoint(loop);
    /// PHI nodes for the row counter and the per-column current-row pointers; the back-edge incoming
    /// values are added after the loop body has been emitted.
    auto * counter_phi = b.CreatePHI(counter_arg->getType(), 2);
    counter_phi->addIncoming(counter_arg, entry);
    for (auto & col : columns)
    {
        col.data = b.CreatePHI(col.data_init->getType(), 2);
        col.data->addIncoming(col.data_init, entry);
        if (col.null_init)
        {
            col.null = b.CreatePHI(col.null_init->getType(), 2);
            col.null->addIncoming(col.null_init, entry);
        }
    }
    /// Lazy argument loaders: a nullable argument is materialized as the native struct { value, is_null },
    /// where is_null is true when the null-map byte is non-zero.
    ValuePlaceholders arguments(arg_types.size());
    for (size_t i = 0; i < arguments.size(); ++i)
    {
        arguments[i] = [&b, &col = columns[i], &type = arg_types[i]]() -> llvm::Value *
        {
            auto * value = b.CreateLoad(col.data);
            if (!col.null)
                return value;
            auto * is_null = b.CreateICmpNE(b.CreateLoad(col.null), b.getInt8(0));
            auto * nullable = llvm::Constant::getNullValue(toNativeType(b, type));
            return b.CreateInsertValue(b.CreateInsertValue(nullable, value, {0}), is_null, {1});
        };
    }
    auto * result = f.compile(b, std::move(arguments));
    /// Store the row result; for a nullable result, split the { value, is_null } struct into the data and null map.
    if (columns.back().null)
    {
        b.CreateStore(b.CreateExtractValue(result, {0}), columns.back().data);
        b.CreateStore(b.CreateSelect(b.CreateExtractValue(result, {1}), b.getInt8(1), b.getInt8(0)), columns.back().null);
    }
    else
    {
        b.CreateStore(result, columns.back().data);
    }
    /// `f.compile` may have created new blocks, so the back edge starts from the current insert block.
    auto * cur_block = b.GetInsertBlock();
    for (auto & col : columns)
    {
        /// stride is either 0 or size of native type; output column is never constant; neither is at least one input
        /// A constant column (stride 0) keeps pointing at its single value; every other pointer advances by one element.
        auto * is_const = &col == &columns.back() || columns.size() <= 2 ? b.getFalse() : b.CreateICmpEQ(col.stride, llvm::ConstantInt::get(size_type, 0));
        col.data->addIncoming(b.CreateSelect(is_const, col.data, b.CreateConstInBoundsGEP1_32(nullptr, col.data, 1)), cur_block);
        if (col.null)
            col.null->addIncoming(b.CreateSelect(is_const, col.null, b.CreateConstInBoundsGEP1_32(nullptr, col.null, 1)), cur_block);
    }
    /// Count down: the loop exits after the iteration in which the counter equals 1, i.e. it runs `counter_arg` times.
    counter_phi->addIncoming(b.CreateSub(counter_phi, llvm::ConstantInt::get(size_type, 1)), cur_block);

    auto * end = llvm::BasicBlock::Create(b.getContext(), "end" , func);
    b.CreateCondBr(b.CreateICmpNE(counter_phi, llvm::ConstantInt::get(size_type, 1)), loop, end);
    b.SetInsertPoint(end);
    b.CreateRetVoid();
}
| 386 | |
| 387 | static llvm::Constant * getNativeValue(llvm::Type * type, const IColumn & column, size_t i) |
| 388 | { |
| 389 | if (!type || column.size() <= i) |
| 390 | return nullptr; |
| 391 | if (auto * constant = typeid_cast<const ColumnConst *>(&column)) |
| 392 | return getNativeValue(type, constant->getDataColumn(), 0); |
| 393 | if (auto * nullable = typeid_cast<const ColumnNullable *>(&column)) |
| 394 | { |
| 395 | auto * value = getNativeValue(type->getContainedType(0), nullable->getNestedColumn(), i); |
| 396 | auto * is_null = llvm::ConstantInt::get(type->getContainedType(1), nullable->isNullAt(i)); |
| 397 | return value ? llvm::ConstantStruct::get(static_cast<llvm::StructType *>(type), value, is_null) : nullptr; |
| 398 | } |
| 399 | if (type->isFloatTy()) |
| 400 | return llvm::ConstantFP::get(type, assert_cast<const ColumnVector<Float32> &>(column).getElement(i)); |
| 401 | if (type->isDoubleTy()) |
| 402 | return llvm::ConstantFP::get(type, assert_cast<const ColumnVector<Float64> &>(column).getElement(i)); |
| 403 | if (type->isIntegerTy()) |
| 404 | return llvm::ConstantInt::get(type, column.getUInt(i)); |
| 405 | /// TODO: if (type->isVectorTy()) |
| 406 | return nullptr; |
| 407 | } |
| 408 | |
| 409 | /// Same as IFunctionBase::compile, but also for constants and input columns. |
| 410 | using CompilableExpression = std::function<llvm::Value * (llvm::IRBuilderBase &, const ValuePlaceholders &)>; |
| 411 | |
| 412 | static CompilableExpression subexpression(ColumnPtr c, DataTypePtr type) |
| 413 | { |
| 414 | return [=](llvm::IRBuilderBase & b, const ValuePlaceholders &) { return getNativeValue(toNativeType(b, type), *c, 0); }; |
| 415 | } |
| 416 | |
| 417 | static CompilableExpression subexpression(size_t i) |
| 418 | { |
| 419 | return [=](llvm::IRBuilderBase &, const ValuePlaceholders & inputs) { return inputs[i](); }; |
| 420 | } |
| 421 | |
| 422 | static CompilableExpression subexpression(const IFunctionBase & f, std::vector<CompilableExpression> args) |
| 423 | { |
| 424 | return [&, args = std::move(args)](llvm::IRBuilderBase & builder, const ValuePlaceholders & inputs) |
| 425 | { |
| 426 | ValuePlaceholders input; |
| 427 | for (const auto & arg : args) |
| 428 | input.push_back([&]() { return arg(builder, inputs); }); |
| 429 | auto * result = f.compile(builder, input); |
| 430 | if (result->getType() != toNativeType(builder, f.getReturnType())) |
| 431 | throw Exception("Function " + f.getName() + " generated an llvm::Value of invalid type" , ErrorCodes::LOGICAL_ERROR); |
| 432 | return result; |
| 433 | }; |
| 434 | } |
| 435 | |
| 436 | struct LLVMModuleState |
| 437 | { |
| 438 | std::unordered_map<std::string, void *> symbols; |
| 439 | std::shared_ptr<llvm::LLVMContext> major_context; |
| 440 | std::shared_ptr<llvm::SectionMemoryManager> memory_manager; |
| 441 | }; |
| 442 | |
| 443 | LLVMFunction::LLVMFunction(const ExpressionActions::Actions & actions, const Block & sample_block) |
| 444 | : name(actions.back().result_name) |
| 445 | , module_state(std::make_unique<LLVMModuleState>()) |
| 446 | { |
| 447 | LLVMContext context; |
| 448 | for (const auto & c : sample_block) |
| 449 | /// TODO: implement `getNativeValue` for all types & replace the check with `c.column && toNativeType(...)` |
| 450 | if (c.column && getNativeValue(toNativeType(context.builder, c.type), *c.column, 0)) |
| 451 | subexpressions[c.name] = subexpression(c.column, c.type); |
| 452 | for (const auto & action : actions) |
| 453 | { |
| 454 | const auto & names = action.argument_names; |
| 455 | const auto & types = action.function_base->getArgumentTypes(); |
| 456 | std::vector<CompilableExpression> args; |
| 457 | for (size_t i = 0; i < names.size(); ++i) |
| 458 | { |
| 459 | auto inserted = subexpressions.emplace(names[i], subexpression(arg_names.size())); |
| 460 | if (inserted.second) |
| 461 | { |
| 462 | arg_names.push_back(names[i]); |
| 463 | arg_types.push_back(types[i]); |
| 464 | } |
| 465 | args.push_back(inserted.first->second); |
| 466 | } |
| 467 | subexpressions[action.result_name] = subexpression(*action.function_base, std::move(args)); |
| 468 | originals.push_back(action.function_base); |
| 469 | } |
| 470 | compileFunctionToLLVMByteCode(context, *this); |
| 471 | context.compileAllFunctionsToNativeCode(); |
| 472 | |
| 473 | module_state->symbols = context.symbols; |
| 474 | module_state->major_context = context.context; |
| 475 | module_state->memory_manager = context.memory_manager; |
| 476 | } |
| 477 | |
| 478 | llvm::Value * LLVMFunction::compile(llvm::IRBuilderBase & builder, ValuePlaceholders values) const |
| 479 | { |
| 480 | auto it = subexpressions.find(name); |
| 481 | if (subexpressions.end() == it) |
| 482 | throw Exception("Cannot find subexpression " + name + " in LLVMFunction" , ErrorCodes::LOGICAL_ERROR); |
| 483 | return it->second(builder, values); |
| 484 | } |
| 485 | |
| 486 | ExecutableFunctionImplPtr LLVMFunction::prepare(const Block &, const ColumnNumbers &, size_t) const { return std::make_unique<LLVMExecutableFunction>(name, module_state->symbols); } |
| 487 | |
| 488 | bool LLVMFunction::isDeterministic() const |
| 489 | { |
| 490 | for (const auto & f : originals) |
| 491 | if (!f->isDeterministic()) |
| 492 | return false; |
| 493 | return true; |
| 494 | } |
| 495 | |
| 496 | bool LLVMFunction::isDeterministicInScopeOfQuery() const |
| 497 | { |
| 498 | for (const auto & f : originals) |
| 499 | if (!f->isDeterministicInScopeOfQuery()) |
| 500 | return false; |
| 501 | return true; |
| 502 | } |
| 503 | |
| 504 | bool LLVMFunction::isSuitableForConstantFolding() const |
| 505 | { |
| 506 | for (const auto & f : originals) |
| 507 | if (!f->isSuitableForConstantFolding()) |
| 508 | return false; |
| 509 | return true; |
| 510 | } |
| 511 | |
| 512 | bool LLVMFunction::isInjective(const Block & sample_block) |
| 513 | { |
| 514 | for (const auto & f : originals) |
| 515 | if (!f->isInjective(sample_block)) |
| 516 | return false; |
| 517 | return true; |
| 518 | } |
| 519 | |
| 520 | bool LLVMFunction::hasInformationAboutMonotonicity() const |
| 521 | { |
| 522 | for (const auto & f : originals) |
| 523 | if (!f->hasInformationAboutMonotonicity()) |
| 524 | return false; |
| 525 | return true; |
| 526 | } |
| 527 | |
| 528 | LLVMFunction::Monotonicity LLVMFunction::getMonotonicityForRange(const IDataType & type, const Field & left, const Field & right) const |
| 529 | { |
| 530 | const IDataType * type_ = &type; |
| 531 | Field left_ = left; |
| 532 | Field right_ = right; |
| 533 | Monotonicity result(true, true, true); |
| 534 | /// monotonicity is only defined for unary functions, so the chain must describe a sequence of nested calls |
| 535 | for (size_t i = 0; i < originals.size(); ++i) |
| 536 | { |
| 537 | Monotonicity m = originals[i]->getMonotonicityForRange(*type_, left_, right_); |
| 538 | if (!m.is_monotonic) |
| 539 | return m; |
| 540 | result.is_positive ^= !m.is_positive; |
| 541 | result.is_always_monotonic &= m.is_always_monotonic; |
| 542 | if (i + 1 < originals.size()) |
| 543 | { |
| 544 | if (left_ != Field()) |
| 545 | applyFunction(*originals[i], left_); |
| 546 | if (right_ != Field()) |
| 547 | applyFunction(*originals[i], right_); |
| 548 | if (!m.is_positive) |
| 549 | std::swap(left_, right_); |
| 550 | type_ = originals[i]->getReturnType().get(); |
| 551 | } |
| 552 | } |
| 553 | return result; |
| 554 | } |
| 555 | |
| 556 | |
| 557 | static bool isCompilable(const IFunctionBase & function) |
| 558 | { |
| 559 | if (!canBeNativeType(*function.getReturnType())) |
| 560 | return false; |
| 561 | for (const auto & type : function.getArgumentTypes()) |
| 562 | if (!canBeNativeType(*type)) |
| 563 | return false; |
| 564 | return function.isCompilable(); |
| 565 | } |
| 566 | |
/// For every APPLY_FUNCTION action `i`, computes the set of consumers of its result column.
/// The actions are walked backwards while `current_dependents` tracks, per column name, who uses it later:
///   * a value `j` means the column is an argument of the compilable APPLY_FUNCTION action `j`;
///   * an empty optional ("poison") means the column is needed by something that cannot absorb its producer:
///     an output column, a non-compilable function, a non-function action, or any use across a REMOVE_COLUMN.
/// The result is indexed by action position; entries for actions that are not APPLY_FUNCTION stay empty.
static std::vector<std::unordered_set<std::optional<size_t>>> getActionsDependents(const ExpressionActions::Actions & actions, const Names & output_columns)
{
    /// an empty optional is a poisoned value prohibiting the column's producer from being removed
    /// (which it could be, if it was inlined into every dependent function).
    std::unordered_map<std::string, std::unordered_set<std::optional<size_t>>> current_dependents;
    for (const auto & name : output_columns)
        current_dependents[name].emplace();
    /// a snapshot of each function's dependents (compilable or not) at the time of its execution.
    std::vector<std::unordered_set<std::optional<size_t>>> dependents(actions.size());
    /// Iterate i = actions.size() - 1 down to 0.
    for (size_t i = actions.size(); i--;)
    {
        switch (actions[i].type)
        {
            case ExpressionAction::REMOVE_COLUMN:
                current_dependents.erase(actions[i].source_name);
                /// poison every other column used after this point so that inlining chains do not cross it.
                for (auto & dep : current_dependents)
                    dep.second.emplace();
                break;

            case ExpressionAction::PROJECT:
                /// Only the projected columns survive; each of them is needed as is.
                current_dependents.clear();
                for (const auto & proj : actions[i].projection)
                    current_dependents[proj.first].emplace();
                break;

            case ExpressionAction::ADD_ALIASES:
                for (const auto & proj : actions[i].projection)
                    current_dependents[proj.first].emplace();
                break;

            case ExpressionAction::ADD_COLUMN:
            case ExpressionAction::COPY_COLUMN:
            case ExpressionAction::ARRAY_JOIN:
            case ExpressionAction::JOIN:
            {
                /// Inputs of these actions cannot be inlined anywhere.
                Names columns = actions[i].getNeededColumns();
                for (const auto & column : columns)
                    current_dependents[column].emplace();
                break;
            }

            case ExpressionAction::APPLY_FUNCTION:
            {
                dependents[i] = current_dependents[actions[i].result_name];
                const bool compilable = isCompilable(*actions[i].function_base);
                for (const auto & name : actions[i].argument_names)
                {
                    /// A compilable consumer may absorb the argument's producer; a non-compilable one cannot.
                    if (compilable)
                        current_dependents[name].emplace(i);
                    else
                        current_dependents[name].emplace();
                }
                break;
            }
        }
    }
    return dependents;
}
| 626 | |
/// Fuses chains of compilable APPLY_FUNCTION actions and replaces the last action of each chain, in place,
/// with a single JIT-compiled LLVMFunction.
/// A compilable action is inlined into all of its consumers (see getActionsDependents) unless some consumer
/// cannot absorb it. In that case the chain ending at it is compiled, provided it contains more than one action
/// and the same chain (by ActionsHash) has already been seen more than `min_count_to_compile_expression` times.
/// If `compilation_cache` is set, compiled functions are looked up in and stored into it by that hash.
void compileFunctions(
    ExpressionActions::Actions & actions,
    const Names & output_columns,
    const Block & sample_block,
    std::shared_ptr<CompiledExpressionCache> compilation_cache,
    size_t min_count_to_compile_expression)
{
    /// Process-wide count of how often each fused chain has been encountered; guarded by `mutex`.
    /// Entries are never removed here.
    static std::unordered_map<UInt128, UInt32, UInt128Hash> counter;
    static std::mutex mutex;

    /// One-time LLVM setup (function-local static): native target, its asm printer, and the symbols of
    /// the current process made available to the JIT.
    struct LLVMTargetInitializer
    {
        LLVMTargetInitializer()
        {
            llvm::InitializeNativeTarget();
            llvm::InitializeNativeTargetAsmPrinter();
            llvm::sys::DynamicLibrary::LoadLibraryPermanently(nullptr);
        }
    };

    static LLVMTargetInitializer initializer;

    auto dependents = getActionsDependents(actions, output_columns);
    /// fused[i]: the actions to be compiled together, in execution order, ending with actions[i].
    std::vector<ExpressionActions::Actions> fused(actions.size());
    for (size_t i = 0; i < actions.size(); ++i)
    {
        if (actions[i].type != ExpressionAction::APPLY_FUNCTION || !isCompilable(*actions[i].function_base))
            continue;

        fused[i].push_back(actions[i]);
        /// A poisoned dependent means actions[i] must produce its column itself: compile the chain here.
        if (dependents[i].find({}) != dependents[i].end())
        {
            /// the result of compiling one function in isolation is pretty much the same as its `execute` method.
            if (fused[i].size() == 1)
                continue;

            auto hash_key = ExpressionActions::ActionsHash{}(fused[i]);
            {
                std::lock_guard lock(mutex);
                /// Compile only once this chain has been seen more than `min_count_to_compile_expression` times.
                if (counter[hash_key]++ < min_count_to_compile_expression)
                    continue;
            }

            FunctionBasePtr fn;
            if (compilation_cache)
            {
                std::tie(fn, std::ignore) = compilation_cache->getOrSet(hash_key, [&inlined_func=std::as_const(fused[i]), &sample_block] ()
                {
                    Stopwatch watch;
                    FunctionBasePtr result_fn;
                    result_fn = std::make_shared<FunctionBaseAdaptor>(std::make_unique<LLVMFunction>(inlined_func, sample_block));
                    ProfileEvents::increment(ProfileEvents::CompileExpressionsMicroseconds, watch.elapsedMicroseconds());
                    return result_fn;
                });
            }
            else
            {
                Stopwatch watch;
                fn = std::make_shared<FunctionBaseAdaptor>(std::make_unique<LLVMFunction>(fused[i], sample_block));
                ProfileEvents::increment(ProfileEvents::CompileExpressionsMicroseconds, watch.elapsedMicroseconds());
            }

            /// The compiled function now reads the chain's external inputs directly.
            actions[i].function_base = fn;
            actions[i].argument_names = typeid_cast<const LLVMFunction *>(typeid_cast<const FunctionBaseAdaptor *>(fn.get())->getImpl())->getArgumentNames();
            actions[i].is_function_compiled = true;

            continue;
        }

        /// No poisoned dependent: inline this chain into every (compilable) consumer.
        /// TODO: determine whether it's profitable to inline the function if there's more than one dependent.
        for (const auto & dep : dependents[i])
            fused[*dep].insert(fused[*dep].end(), fused[i].begin(), fused[i].end());
    }

    for (size_t i = 0; i < actions.size(); ++i)
    {
        if (actions[i].type == ExpressionAction::APPLY_FUNCTION && actions[i].is_function_compiled)
            actions[i].function = actions[i].function_base->prepare({}, {}, 0); /// Arguments are not used for LLVMFunction.
    }
}
| 707 | |
| 708 | } |
| 709 | |
| 710 | #endif |
| 711 | |