1 | #include <Interpreters/ExpressionJIT.h> |
2 | |
3 | #if USE_EMBEDDED_COMPILER |
4 | |
5 | #include <optional> |
6 | |
7 | #include <Columns/ColumnConst.h> |
8 | #include <Columns/ColumnNullable.h> |
9 | #include <Columns/ColumnVector.h> |
10 | #include <Common/LRUCache.h> |
11 | #include <Common/typeid_cast.h> |
12 | #include <Common/assert_cast.h> |
13 | #include <Common/ProfileEvents.h> |
14 | #include <Common/Stopwatch.h> |
15 | #include <DataTypes/DataTypeNullable.h> |
16 | #include <DataTypes/DataTypesNumber.h> |
17 | #include <DataTypes/Native.h> |
18 | #include <Functions/IFunctionAdaptors.h> |
19 | |
20 | #pragma GCC diagnostic push |
21 | #pragma GCC diagnostic ignored "-Wunused-parameter" |
22 | #pragma GCC diagnostic ignored "-Wnon-virtual-dtor" |
23 | |
24 | #include <llvm/Analysis/TargetTransformInfo.h> |
25 | #include <llvm/IR/BasicBlock.h> |
26 | #include <llvm/IR/DataLayout.h> |
27 | #include <llvm/IR/DerivedTypes.h> |
28 | #include <llvm/IR/Function.h> |
29 | #include <llvm/IR/IRBuilder.h> |
30 | #include <llvm/IR/LLVMContext.h> |
31 | #include <llvm/IR/Mangler.h> |
32 | #include <llvm/IR/Module.h> |
33 | #include <llvm/IR/Type.h> |
34 | #include <llvm/IR/LegacyPassManager.h> |
35 | #include <llvm/ExecutionEngine/ExecutionEngine.h> |
36 | #include <llvm/ExecutionEngine/JITSymbol.h> |
37 | #include <llvm/ExecutionEngine/SectionMemoryManager.h> |
38 | #include <llvm/ExecutionEngine/Orc/CompileUtils.h> |
39 | #include <llvm/ExecutionEngine/Orc/IRCompileLayer.h> |
40 | #include <llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h> |
41 | #include <llvm/Target/TargetMachine.h> |
42 | #include <llvm/MC/SubtargetFeature.h> |
43 | #include <llvm/Support/DynamicLibrary.h> |
44 | #include <llvm/Support/Host.h> |
45 | #include <llvm/Support/TargetRegistry.h> |
46 | #include <llvm/Support/TargetSelect.h> |
47 | #include <llvm/Transforms/IPO/PassManagerBuilder.h> |
48 | |
49 | #pragma GCC diagnostic pop |
50 | |
51 | /// 'LegacyRTDyldObjectLinkingLayer' is deprecated: ORCv1 layers (layers with the 'Legacy' prefix) are deprecated. Please use ORCv2 |
52 | /// 'LegacyIRCompileLayer' is deprecated: ORCv1 layers (layers with the 'Legacy' prefix) are deprecated. Please use the ORCv2 IRCompileLayer instead |
53 | #pragma GCC diagnostic ignored "-Wdeprecated-declarations" |
54 | |
55 | |
/// Profile events defined elsewhere and referenced by the expression JIT in this file.
namespace ProfileEvents
{
    extern const Event CompileFunction;
    extern const Event CompileExpressionsMicroseconds;
    extern const Event CompileExpressionsBytes;
}
62 | |
63 | namespace DB |
64 | { |
65 | |
/// Error codes (defined elsewhere) used by the exceptions thrown in this file.
namespace ErrorCodes
{
    extern const int LOGICAL_ERROR;
    extern const int CANNOT_COMPILE_CODE;
}
71 | |
72 | namespace |
73 | { |
74 | struct ColumnData |
75 | { |
76 | const char * data = nullptr; |
77 | const char * null = nullptr; |
78 | size_t stride = 0; |
79 | }; |
80 | |
81 | struct ColumnDataPlaceholder |
82 | { |
83 | llvm::Value * data_init; /// first row |
84 | llvm::Value * null_init; |
85 | llvm::Value * stride; |
86 | llvm::PHINode * data; /// current row |
87 | llvm::PHINode * null; |
88 | }; |
89 | } |
90 | |
91 | static ColumnData getColumnData(const IColumn * column) |
92 | { |
93 | ColumnData result; |
94 | const bool is_const = isColumnConst(*column); |
95 | if (is_const) |
96 | column = &reinterpret_cast<const ColumnConst *>(column)->getDataColumn(); |
97 | if (auto * nullable = typeid_cast<const ColumnNullable *>(column)) |
98 | { |
99 | result.null = nullable->getNullMapColumn().getRawData().data; |
100 | column = &nullable->getNestedColumn(); |
101 | } |
102 | result.data = column->getRawData().data; |
103 | result.stride = is_const ? 0 : column->sizeOfValueIfFixed(); |
104 | return result; |
105 | } |
106 | |
107 | static void applyFunction(IFunctionBase & function, Field & value) |
108 | { |
109 | const auto & type = function.getArgumentTypes().at(0); |
110 | Block block = {{ type->createColumnConst(1, value), type, "x" }, { nullptr, function.getReturnType(), "y" }}; |
111 | function.execute(block, {0}, 1, 1); |
112 | block.safeGetByPosition(1).column->get(0, value); |
113 | } |
114 | |
115 | static llvm::TargetMachine * getNativeMachine() |
116 | { |
117 | std::string error; |
118 | auto cpu = llvm::sys::getHostCPUName(); |
119 | auto triple = llvm::sys::getProcessTriple(); |
120 | auto target = llvm::TargetRegistry::lookupTarget(triple, error); |
121 | if (!target) |
122 | throw Exception("Could not initialize native target: " + error, ErrorCodes::CANNOT_COMPILE_CODE); |
123 | llvm::SubtargetFeatures features; |
124 | llvm::StringMap<bool> feature_map; |
125 | if (llvm::sys::getHostCPUFeatures(feature_map)) |
126 | for (auto & f : feature_map) |
127 | features.AddFeature(f.first(), f.second); |
128 | llvm::TargetOptions options; |
129 | return target->createTargetMachine( |
130 | triple, cpu, features.getString(), options, llvm::None, |
131 | llvm::None, llvm::CodeGenOpt::Default, /*jit=*/true |
132 | ); |
133 | } |
134 | |
135 | |
136 | struct SymbolResolver : public llvm::orc::SymbolResolver |
137 | { |
138 | llvm::LegacyJITSymbolResolver & impl; |
139 | |
140 | SymbolResolver(llvm::LegacyJITSymbolResolver & impl_) : impl(impl_) {} |
141 | |
142 | llvm::orc::SymbolNameSet getResponsibilitySet(const llvm::orc::SymbolNameSet & symbols) final |
143 | { |
144 | return symbols; |
145 | } |
146 | |
147 | llvm::orc::SymbolNameSet lookup(std::shared_ptr<llvm::orc::AsynchronousSymbolQuery> query, llvm::orc::SymbolNameSet symbols) final |
148 | { |
149 | llvm::orc::SymbolNameSet missing; |
150 | for (const auto & symbol : symbols) |
151 | { |
152 | bool has_resolved = false; |
153 | impl.lookup({*symbol}, [&](llvm::Expected<llvm::JITSymbolResolver::LookupResult> resolved) |
154 | { |
155 | if (resolved && resolved->size()) |
156 | { |
157 | query->notifySymbolMetRequiredState(symbol, resolved->begin()->second); |
158 | has_resolved = true; |
159 | } |
160 | }); |
161 | |
162 | if (!has_resolved) |
163 | missing.insert(symbol); |
164 | } |
165 | return missing; |
166 | } |
167 | }; |
168 | |
169 | |
170 | struct LLVMContext |
171 | { |
172 | std::shared_ptr<llvm::LLVMContext> context {std::make_shared<llvm::LLVMContext>()}; |
173 | std::unique_ptr<llvm::Module> module {std::make_unique<llvm::Module>("jit" , *context)}; |
174 | std::unique_ptr<llvm::TargetMachine> machine {getNativeMachine()}; |
175 | llvm::DataLayout layout {machine->createDataLayout()}; |
176 | llvm::IRBuilder<> builder {*context}; |
177 | |
178 | llvm::orc::ExecutionSession execution_session; |
179 | |
180 | std::shared_ptr<llvm::SectionMemoryManager> memory_manager; |
181 | llvm::orc::LegacyRTDyldObjectLinkingLayer object_layer; |
182 | llvm::orc::LegacyIRCompileLayer<decltype(object_layer), llvm::orc::SimpleCompiler> compile_layer; |
183 | |
184 | std::unordered_map<std::string, void *> symbols; |
185 | |
186 | LLVMContext() |
187 | : memory_manager(std::make_shared<llvm::SectionMemoryManager>()) |
188 | , object_layer(execution_session, [this](llvm::orc::VModuleKey) |
189 | { |
190 | return llvm::orc::LegacyRTDyldObjectLinkingLayer::Resources{memory_manager, std::make_shared<SymbolResolver>(*memory_manager)}; |
191 | }) |
192 | , compile_layer(object_layer, llvm::orc::SimpleCompiler(*machine)) |
193 | { |
194 | module->setDataLayout(layout); |
195 | module->setTargetTriple(machine->getTargetTriple().getTriple()); |
196 | } |
197 | |
198 | /// returns used memory |
199 | void compileAllFunctionsToNativeCode() |
200 | { |
201 | if (!module->size()) |
202 | return; |
203 | llvm::PassManagerBuilder pass_manager_builder; |
204 | llvm::legacy::PassManager mpm; |
205 | llvm::legacy::FunctionPassManager fpm(module.get()); |
206 | pass_manager_builder.OptLevel = 3; |
207 | pass_manager_builder.SLPVectorize = true; |
208 | pass_manager_builder.LoopVectorize = true; |
209 | pass_manager_builder.RerollLoops = true; |
210 | pass_manager_builder.VerifyInput = true; |
211 | pass_manager_builder.VerifyOutput = true; |
212 | machine->adjustPassManager(pass_manager_builder); |
213 | fpm.add(llvm::createTargetTransformInfoWrapperPass(machine->getTargetIRAnalysis())); |
214 | mpm.add(llvm::createTargetTransformInfoWrapperPass(machine->getTargetIRAnalysis())); |
215 | pass_manager_builder.populateFunctionPassManager(fpm); |
216 | pass_manager_builder.populateModulePassManager(mpm); |
217 | fpm.doInitialization(); |
218 | for (auto & function : *module) |
219 | fpm.run(function); |
220 | fpm.doFinalization(); |
221 | mpm.run(*module); |
222 | |
223 | std::vector<std::string> functions; |
224 | functions.reserve(module->size()); |
225 | for (const auto & function : *module) |
226 | functions.emplace_back(function.getName()); |
227 | |
228 | llvm::orc::VModuleKey module_key = execution_session.allocateVModule(); |
229 | if (compile_layer.addModule(module_key, std::move(module))) |
230 | throw Exception("Cannot add module to compile layer" , ErrorCodes::CANNOT_COMPILE_CODE); |
231 | |
232 | for (const auto & name : functions) |
233 | { |
234 | std::string mangled_name; |
235 | llvm::raw_string_ostream mangled_name_stream(mangled_name); |
236 | llvm::Mangler::getNameWithPrefix(mangled_name_stream, name, layout); |
237 | mangled_name_stream.flush(); |
238 | auto symbol = compile_layer.findSymbol(mangled_name, false); |
239 | if (!symbol) |
240 | continue; /// external function (e.g. an intrinsic that calls into libc) |
241 | auto address = symbol.getAddress(); |
242 | if (!address) |
243 | throw Exception("Function " + name + " failed to link" , ErrorCodes::CANNOT_COMPILE_CODE); |
244 | symbols[name] = reinterpret_cast<void *>(*address); |
245 | } |
246 | } |
247 | }; |
248 | |
249 | |
250 | template <typename... Ts, typename F> |
251 | static bool castToEither(IColumn * column, F && f) |
252 | { |
253 | return ((typeid_cast<Ts *>(column) ? f(*typeid_cast<Ts *>(column)) : false) || ...); |
254 | } |
255 | |
256 | class LLVMExecutableFunction : public IExecutableFunctionImpl |
257 | { |
258 | std::string name; |
259 | void * function; |
260 | |
261 | public: |
262 | LLVMExecutableFunction(const std::string & name_, const std::unordered_map<std::string, void *> & symbols) |
263 | : name(name_) |
264 | { |
265 | auto it = symbols.find(name); |
266 | if (symbols.end() == it) |
267 | throw Exception("Cannot find symbol " + name + " in LLVMContext" , ErrorCodes::LOGICAL_ERROR); |
268 | function = it->second; |
269 | } |
270 | |
271 | String getName() const override { return name; } |
272 | |
273 | bool useDefaultImplementationForNulls() const override { return false; } |
274 | |
275 | bool useDefaultImplementationForConstants() const override { return true; } |
276 | |
277 | void execute(Block & block, const ColumnNumbers & arguments, size_t result, size_t block_size) override |
278 | { |
279 | auto col_res = block.getByPosition(result).type->createColumn(); |
280 | |
281 | if (block_size) |
282 | { |
283 | if (!castToEither< |
284 | ColumnUInt8, ColumnUInt16, ColumnUInt32, ColumnUInt64, |
285 | ColumnInt8, ColumnInt16, ColumnInt32, ColumnInt64, |
286 | ColumnFloat32, ColumnFloat64>(col_res.get(), [block_size](auto & col) { col.getData().resize(block_size); return true; })) |
287 | throw Exception("Unexpected column in LLVMExecutableFunction: " + col_res->getName(), ErrorCodes::LOGICAL_ERROR); |
288 | |
289 | std::vector<ColumnData> columns(arguments.size() + 1); |
290 | for (size_t i = 0; i < arguments.size(); ++i) |
291 | { |
292 | auto * column = block.getByPosition(arguments[i]).column.get(); |
293 | if (!column) |
294 | throw Exception("Column " + block.getByPosition(arguments[i]).name + " is missing" , ErrorCodes::LOGICAL_ERROR); |
295 | columns[i] = getColumnData(column); |
296 | } |
297 | columns[arguments.size()] = getColumnData(col_res.get()); |
298 | reinterpret_cast<void (*) (size_t, ColumnData *)>(function)(block_size, columns.data()); |
299 | } |
300 | |
301 | block.getByPosition(result).column = std::move(col_res); |
302 | } |
303 | }; |
304 | |
/// Emits into `context.module` a function named `f.getName()` with the native signature
/// `void (size_t rows, ColumnData * columns)` that evaluates `f` once per row.
/// `columns` holds one entry per argument of `f` followed by one entry for the result.
/// The generated loop counts `rows` down and stops after the iteration that started at 1, so it
/// must not be called with rows == 0. LLVMExecutableFunction::execute only calls it when
/// `block_size` is non-zero.
static void compileFunctionToLLVMByteCode(LLVMContext & context, const IFunctionBaseImpl & f)
{
    ProfileEvents::increment(ProfileEvents::CompileFunction);

    auto & arg_types = f.getArgumentTypes();
    auto & b = context.builder;
    auto * size_type = b.getIntNTy(sizeof(size_t) * 8);
    /// LLVM counterpart of ColumnData: { data pointer, null map pointer, stride }.
    auto * data_type = llvm::StructType::get(b.getInt8PtrTy(), b.getInt8PtrTy(), size_type);
    /// void (size_t rows, ColumnData * columns)
    auto * func_type = llvm::FunctionType::get(b.getVoidTy(), { size_type, data_type->getPointerTo() }, /*isVarArg=*/false);
    auto * func = llvm::Function::Create(func_type, llvm::Function::ExternalLinkage, f.getName(), context.module.get());
    auto args = func->args().begin();
    llvm::Value * counter_arg = &*args++;
    llvm::Value * columns_arg = &*args++;

    /// Entry block: unpack every ColumnData (arguments first, then the result).
    /// Each data pointer is cast to a pointer to the column's native element type.
    auto * entry = llvm::BasicBlock::Create(b.getContext(), "entry", func);
    b.SetInsertPoint(entry);
    std::vector<ColumnDataPlaceholder> columns(arg_types.size() + 1);
    for (size_t i = 0; i <= arg_types.size(); ++i)
    {
        auto & type = i == arg_types.size() ? f.getReturnType() : arg_types[i];
        auto * data = b.CreateLoad(b.CreateConstInBoundsGEP1_32(data_type, columns_arg, i));
        columns[i].data_init = b.CreatePointerCast(b.CreateExtractValue(data, {0}), toNativeType(b, removeNullable(type))->getPointerTo());
        /// The null map pointer is only used for Nullable columns; it stays nullptr otherwise.
        columns[i].null_init = type->isNullable() ? b.CreateExtractValue(data, {1}) : nullptr;
        columns[i].stride = b.CreateExtractValue(data, {2});
    }

    /// assume nonzero initial value in `counter_arg`
    /// Loop block: one iteration per row. The PHI nodes carry the remaining row count and the
    /// current data / null map pointer of every column.
    auto * loop = llvm::BasicBlock::Create(b.getContext(), "loop", func);
    b.CreateBr(loop);
    b.SetInsertPoint(loop);
    auto * counter_phi = b.CreatePHI(counter_arg->getType(), 2);
    counter_phi->addIncoming(counter_arg, entry);
    for (auto & col : columns)
    {
        col.data = b.CreatePHI(col.data_init->getType(), 2);
        col.data->addIncoming(col.data_init, entry);
        if (col.null_init)
        {
            col.null = b.CreatePHI(col.null_init->getType(), 2);
            col.null->addIncoming(col.null_init, entry);
        }
    }
    /// Placeholders that load the current row of each argument.
    /// A Nullable argument is packed into its native { value, is_null } struct.
    ValuePlaceholders arguments(arg_types.size());
    for (size_t i = 0; i < arguments.size(); ++i)
    {
        arguments[i] = [&b, &col = columns[i], &type = arg_types[i]]() -> llvm::Value *
        {
            auto * value = b.CreateLoad(col.data);
            if (!col.null)
                return value;
            auto * is_null = b.CreateICmpNE(b.CreateLoad(col.null), b.getInt8(0));
            auto * nullable = llvm::Constant::getNullValue(toNativeType(b, type));
            return b.CreateInsertValue(b.CreateInsertValue(nullable, value, {0}), is_null, {1});
        };
    }
    /// Store the current row's result. For a Nullable result, the struct is split into the value
    /// and a 0/1 null map byte.
    auto * result = f.compile(b, std::move(arguments));
    if (columns.back().null)
    {
        b.CreateStore(b.CreateExtractValue(result, {0}), columns.back().data);
        b.CreateStore(b.CreateSelect(b.CreateExtractValue(result, {1}), b.getInt8(1), b.getInt8(0)), columns.back().null);
    }
    else
    {
        b.CreateStore(result, columns.back().data);
    }
    /// Advance every column pointer by one element, except for constant inputs (stride 0).
    auto * cur_block = b.GetInsertBlock();
    for (auto & col : columns)
    {
        /// stride is either 0 or size of native type; output column is never constant; neither is at least one input
        /// Note: `||` binds tighter than `?:`. The result column, and the only input when there is exactly one,
        /// are treated as non-constant without a runtime check.
        /// NOTE(review): presumably guaranteed by useDefaultImplementationForConstants — confirm.
        auto * is_const = &col == &columns.back() || columns.size() <= 2 ? b.getFalse() : b.CreateICmpEQ(col.stride, llvm::ConstantInt::get(size_type, 0));
        col.data->addIncoming(b.CreateSelect(is_const, col.data, b.CreateConstInBoundsGEP1_32(nullptr, col.data, 1)), cur_block);
        if (col.null)
            col.null->addIncoming(b.CreateSelect(is_const, col.null, b.CreateConstInBoundsGEP1_32(nullptr, col.null, 1)), cur_block);
    }
    counter_phi->addIncoming(b.CreateSub(counter_phi, llvm::ConstantInt::get(size_type, 1)), cur_block);

    /// Keep looping unless this iteration started with exactly one row left.
    auto * end = llvm::BasicBlock::Create(b.getContext(), "end", func);
    b.CreateCondBr(b.CreateICmpNE(counter_phi, llvm::ConstantInt::get(size_type, 1)), loop, end);
    b.SetInsertPoint(end);
    b.CreateRetVoid();
}
386 | |
387 | static llvm::Constant * getNativeValue(llvm::Type * type, const IColumn & column, size_t i) |
388 | { |
389 | if (!type || column.size() <= i) |
390 | return nullptr; |
391 | if (auto * constant = typeid_cast<const ColumnConst *>(&column)) |
392 | return getNativeValue(type, constant->getDataColumn(), 0); |
393 | if (auto * nullable = typeid_cast<const ColumnNullable *>(&column)) |
394 | { |
395 | auto * value = getNativeValue(type->getContainedType(0), nullable->getNestedColumn(), i); |
396 | auto * is_null = llvm::ConstantInt::get(type->getContainedType(1), nullable->isNullAt(i)); |
397 | return value ? llvm::ConstantStruct::get(static_cast<llvm::StructType *>(type), value, is_null) : nullptr; |
398 | } |
399 | if (type->isFloatTy()) |
400 | return llvm::ConstantFP::get(type, assert_cast<const ColumnVector<Float32> &>(column).getElement(i)); |
401 | if (type->isDoubleTy()) |
402 | return llvm::ConstantFP::get(type, assert_cast<const ColumnVector<Float64> &>(column).getElement(i)); |
403 | if (type->isIntegerTy()) |
404 | return llvm::ConstantInt::get(type, column.getUInt(i)); |
405 | /// TODO: if (type->isVectorTy()) |
406 | return nullptr; |
407 | } |
408 | |
/// Same as IFunctionBase::compile, but also for constants and input columns.
/// Given a builder and the input placeholders of the function being compiled, emits IR for one
/// node of the fused expression and returns the resulting value.
using CompilableExpression = std::function<llvm::Value * (llvm::IRBuilderBase &, const ValuePlaceholders &)>;
411 | |
412 | static CompilableExpression subexpression(ColumnPtr c, DataTypePtr type) |
413 | { |
414 | return [=](llvm::IRBuilderBase & b, const ValuePlaceholders &) { return getNativeValue(toNativeType(b, type), *c, 0); }; |
415 | } |
416 | |
417 | static CompilableExpression subexpression(size_t i) |
418 | { |
419 | return [=](llvm::IRBuilderBase &, const ValuePlaceholders & inputs) { return inputs[i](); }; |
420 | } |
421 | |
422 | static CompilableExpression subexpression(const IFunctionBase & f, std::vector<CompilableExpression> args) |
423 | { |
424 | return [&, args = std::move(args)](llvm::IRBuilderBase & builder, const ValuePlaceholders & inputs) |
425 | { |
426 | ValuePlaceholders input; |
427 | for (const auto & arg : args) |
428 | input.push_back([&]() { return arg(builder, inputs); }); |
429 | auto * result = f.compile(builder, input); |
430 | if (result->getType() != toNativeType(builder, f.getReturnType())) |
431 | throw Exception("Function " + f.getName() + " generated an llvm::Value of invalid type" , ErrorCodes::LOGICAL_ERROR); |
432 | return result; |
433 | }; |
434 | } |
435 | |
/// State kept by an LLVMFunction after the temporary LLVMContext used to build it is destroyed.
/// It holds the addresses of the compiled functions and shared ownership of the LLVM context and
/// memory manager used to build them.
struct LLVMModuleState
{
    /// Unmangled function name -> native entry point.
    std::unordered_map<std::string, void *> symbols;
    std::shared_ptr<llvm::LLVMContext> major_context;
    /// Memory manager the object-linking layer used for the compiled objects.
    std::shared_ptr<llvm::SectionMemoryManager> memory_manager;
};
442 | |
443 | LLVMFunction::LLVMFunction(const ExpressionActions::Actions & actions, const Block & sample_block) |
444 | : name(actions.back().result_name) |
445 | , module_state(std::make_unique<LLVMModuleState>()) |
446 | { |
447 | LLVMContext context; |
448 | for (const auto & c : sample_block) |
449 | /// TODO: implement `getNativeValue` for all types & replace the check with `c.column && toNativeType(...)` |
450 | if (c.column && getNativeValue(toNativeType(context.builder, c.type), *c.column, 0)) |
451 | subexpressions[c.name] = subexpression(c.column, c.type); |
452 | for (const auto & action : actions) |
453 | { |
454 | const auto & names = action.argument_names; |
455 | const auto & types = action.function_base->getArgumentTypes(); |
456 | std::vector<CompilableExpression> args; |
457 | for (size_t i = 0; i < names.size(); ++i) |
458 | { |
459 | auto inserted = subexpressions.emplace(names[i], subexpression(arg_names.size())); |
460 | if (inserted.second) |
461 | { |
462 | arg_names.push_back(names[i]); |
463 | arg_types.push_back(types[i]); |
464 | } |
465 | args.push_back(inserted.first->second); |
466 | } |
467 | subexpressions[action.result_name] = subexpression(*action.function_base, std::move(args)); |
468 | originals.push_back(action.function_base); |
469 | } |
470 | compileFunctionToLLVMByteCode(context, *this); |
471 | context.compileAllFunctionsToNativeCode(); |
472 | |
473 | module_state->symbols = context.symbols; |
474 | module_state->major_context = context.context; |
475 | module_state->memory_manager = context.memory_manager; |
476 | } |
477 | |
478 | llvm::Value * LLVMFunction::compile(llvm::IRBuilderBase & builder, ValuePlaceholders values) const |
479 | { |
480 | auto it = subexpressions.find(name); |
481 | if (subexpressions.end() == it) |
482 | throw Exception("Cannot find subexpression " + name + " in LLVMFunction" , ErrorCodes::LOGICAL_ERROR); |
483 | return it->second(builder, values); |
484 | } |
485 | |
486 | ExecutableFunctionImplPtr LLVMFunction::prepare(const Block &, const ColumnNumbers &, size_t) const { return std::make_unique<LLVMExecutableFunction>(name, module_state->symbols); } |
487 | |
488 | bool LLVMFunction::isDeterministic() const |
489 | { |
490 | for (const auto & f : originals) |
491 | if (!f->isDeterministic()) |
492 | return false; |
493 | return true; |
494 | } |
495 | |
496 | bool LLVMFunction::isDeterministicInScopeOfQuery() const |
497 | { |
498 | for (const auto & f : originals) |
499 | if (!f->isDeterministicInScopeOfQuery()) |
500 | return false; |
501 | return true; |
502 | } |
503 | |
504 | bool LLVMFunction::isSuitableForConstantFolding() const |
505 | { |
506 | for (const auto & f : originals) |
507 | if (!f->isSuitableForConstantFolding()) |
508 | return false; |
509 | return true; |
510 | } |
511 | |
512 | bool LLVMFunction::isInjective(const Block & sample_block) |
513 | { |
514 | for (const auto & f : originals) |
515 | if (!f->isInjective(sample_block)) |
516 | return false; |
517 | return true; |
518 | } |
519 | |
520 | bool LLVMFunction::hasInformationAboutMonotonicity() const |
521 | { |
522 | for (const auto & f : originals) |
523 | if (!f->hasInformationAboutMonotonicity()) |
524 | return false; |
525 | return true; |
526 | } |
527 | |
/// Monotonicity of the whole chain of inlined functions on [left, right] for an argument of `type`.
/// Each function is queried on the range produced by the functions before it.
/// Between steps, each bound is passed through the function with applyFunction and the bounds are
/// swapped when the function is decreasing. Bounds equal to a default Field() are left untouched
/// (presumably meaning an unbounded side — confirm).
/// The result:
/// - is non-monotonic as soon as any step is (that step's result is returned as is);
/// - is increasing iff an even number of steps are decreasing;
/// - is always monotonic iff every step is.
LLVMFunction::Monotonicity LLVMFunction::getMonotonicityForRange(const IDataType & type, const Field & left, const Field & right) const
{
    const IDataType * type_ = &type;
    Field left_ = left;
    Field right_ = right;
    Monotonicity result(true, true, true);
    /// monotonicity is only defined for unary functions, so the chain must describe a sequence of nested calls
    for (size_t i = 0; i < originals.size(); ++i)
    {
        Monotonicity m = originals[i]->getMonotonicityForRange(*type_, left_, right_);
        if (!m.is_monotonic)
            return m;
        /// Every decreasing step flips the overall direction.
        result.is_positive ^= !m.is_positive;
        result.is_always_monotonic &= m.is_always_monotonic;
        if (i + 1 < originals.size())
        {
            /// Map the range through this function to obtain the input range of the next one.
            if (left_ != Field())
                applyFunction(*originals[i], left_);
            if (right_ != Field())
                applyFunction(*originals[i], right_);
            if (!m.is_positive)
                std::swap(left_, right_);
            type_ = originals[i]->getReturnType().get();
        }
    }
    return result;
}
555 | |
556 | |
557 | static bool isCompilable(const IFunctionBase & function) |
558 | { |
559 | if (!canBeNativeType(*function.getReturnType())) |
560 | return false; |
561 | for (const auto & type : function.getArgumentTypes()) |
562 | if (!canBeNativeType(*type)) |
563 | return false; |
564 | return function.isCompilable(); |
565 | } |
566 | |
/// Walks `actions` backwards and, for every APPLY_FUNCTION action, takes a snapshot of the
/// consumers of its result at that point.
/// Each consumer is either:
/// - the index of a compilable APPLY_FUNCTION that takes the result as an argument, or
/// - an empty optional ("poisoned") when the result is also needed by something that cannot absorb it.
/// Poisoning consumers are: an output column, a non-compilable function, any other kind of action,
/// or any use of a column that occurs after a REMOVE_COLUMN.
/// Entries for non-APPLY_FUNCTION actions stay empty.
static std::vector<std::unordered_set<std::optional<size_t>>> getActionsDependents(const ExpressionActions::Actions & actions, const Names & output_columns)
{
    /// an empty optional is a poisoned value prohibiting the column's producer from being removed
    /// (which it could be, if it was inlined into every dependent function).
    std::unordered_map<std::string, std::unordered_set<std::optional<size_t>>> current_dependents;
    for (const auto & name : output_columns)
        current_dependents[name].emplace();
    /// a snapshot of each compilable function's dependents at the time of its execution.
    std::vector<std::unordered_set<std::optional<size_t>>> dependents(actions.size());
    for (size_t i = actions.size(); i--;)
    {
        switch (actions[i].type)
        {
            case ExpressionAction::REMOVE_COLUMN:
                current_dependents.erase(actions[i].source_name);
                /// poison every other column used after this point so that inlining chains do not cross it.
                for (auto & dep : current_dependents)
                    dep.second.emplace();
                break;

            case ExpressionAction::PROJECT:
                /// Only the projected source columns are used after a projection.
                current_dependents.clear();
                for (const auto & proj : actions[i].projection)
                    current_dependents[proj.first].emplace();
                break;

            case ExpressionAction::ADD_ALIASES:
                for (const auto & proj : actions[i].projection)
                    current_dependents[proj.first].emplace();
                break;

            case ExpressionAction::ADD_COLUMN:
            case ExpressionAction::COPY_COLUMN:
            case ExpressionAction::ARRAY_JOIN:
            case ExpressionAction::JOIN:
            {
                Names columns = actions[i].getNeededColumns();
                for (const auto & column : columns)
                    current_dependents[column].emplace();
                break;
            }

            case ExpressionAction::APPLY_FUNCTION:
            {
                dependents[i] = current_dependents[actions[i].result_name];
                const bool compilable = isCompilable(*actions[i].function_base);
                for (const auto & name : actions[i].argument_names)
                {
                    /// A compilable consumer can inline the argument's producer; anything else poisons it.
                    if (compilable)
                        current_dependents[name].emplace(i);
                    else
                        current_dependents[name].emplace();
                }
                break;
            }
        }
    }
    return dependents;
}
626 | |
/// Replaces chains of compilable APPLY_FUNCTION actions with single JIT-compiled functions, in place.
///
/// Based on getActionsDependents, a compilable function is handled in one of two ways:
/// - If its result is used only by other compilable functions, it is appended (inlined) into each
///   of those dependents' fused chains.
/// - If its result is needed elsewhere (a poisoned dependent), it becomes the root of one
///   LLVMFunction covering itself and everything inlined into it.
/// A root is compiled only when both of these hold:
/// - more than one function was fused;
/// - the same fused chain (by ActionsHash) was already encountered at least
///   `min_count_to_compile_expression` times. Counts live in a process-wide map guarded by a mutex.
/// When `compilation_cache` is given, compiled functions are looked up in / added to it.
/// Compiled actions get the LLVMFunction, its argument names, `is_function_compiled` and a prepared
/// executable. Inlined actions are left in `actions` as they are.
/// NOTE(review): presumably removed later as unused — confirm.
void compileFunctions(
    ExpressionActions::Actions & actions,
    const Names & output_columns,
    const Block & sample_block,
    std::shared_ptr<CompiledExpressionCache> compilation_cache,
    size_t min_count_to_compile_expression)
{
    /// Number of times each fused chain has been seen, keyed by its hash.
    static std::unordered_map<UInt128, UInt32, UInt128Hash> counter;
    static std::mutex mutex;

    /// One-time LLVM setup for the host target.
    struct LLVMTargetInitializer
    {
        LLVMTargetInitializer()
        {
            llvm::InitializeNativeTarget();
            llvm::InitializeNativeTargetAsmPrinter();
            llvm::sys::DynamicLibrary::LoadLibraryPermanently(nullptr);
        }
    };

    static LLVMTargetInitializer initializer;

    auto dependents = getActionsDependents(actions, output_columns);
    /// fused[i]: action i preceded by everything inlined into it so far.
    std::vector<ExpressionActions::Actions> fused(actions.size());
    for (size_t i = 0; i < actions.size(); ++i)
    {
        if (actions[i].type != ExpressionAction::APPLY_FUNCTION || !isCompilable(*actions[i].function_base))
            continue;

        fused[i].push_back(actions[i]);
        /// A poisoned dependent means the result must stay materialized: compile the chain here.
        if (dependents[i].find({}) != dependents[i].end())
        {
            /// the result of compiling one function in isolation is pretty much the same as its `execute` method.
            if (fused[i].size() == 1)
                continue;

            auto hash_key = ExpressionActions::ActionsHash{}(fused[i]);
            {
                std::lock_guard lock(mutex);
                /// Skip while the chain has been seen fewer than `min_count_to_compile_expression` times before.
                if (counter[hash_key]++ < min_count_to_compile_expression)
                    continue;
            }

            FunctionBasePtr fn;
            if (compilation_cache)
            {
                std::tie(fn, std::ignore) = compilation_cache->getOrSet(hash_key, [&inlined_func=std::as_const(fused[i]), &sample_block] ()
                {
                    Stopwatch watch;
                    FunctionBasePtr result_fn;
                    result_fn = std::make_shared<FunctionBaseAdaptor>(std::make_unique<LLVMFunction>(inlined_func, sample_block));
                    ProfileEvents::increment(ProfileEvents::CompileExpressionsMicroseconds, watch.elapsedMicroseconds());
                    return result_fn;
                });
            }
            else
            {
                Stopwatch watch;
                fn = std::make_shared<FunctionBaseAdaptor>(std::make_unique<LLVMFunction>(fused[i], sample_block));
                ProfileEvents::increment(ProfileEvents::CompileExpressionsMicroseconds, watch.elapsedMicroseconds());
            }

            actions[i].function_base = fn;
            actions[i].argument_names = typeid_cast<const LLVMFunction *>(typeid_cast<const FunctionBaseAdaptor *>(fn.get())->getImpl())->getArgumentNames();
            actions[i].is_function_compiled = true;

            continue;
        }

        /// TODO: determine whether it's profitable to inline the function if there's more than one dependent.
        /// All dependents are compilable here (no poisoned entry), so each optional holds an index.
        for (const auto & dep : dependents[i])
            fused[*dep].insert(fused[*dep].end(), fused[i].begin(), fused[i].end());
    }

    for (size_t i = 0; i < actions.size(); ++i)
    {
        if (actions[i].type == ExpressionAction::APPLY_FUNCTION && actions[i].is_function_compiled)
            actions[i].function = actions[i].function_base->prepare({}, {}, 0); /// Arguments are not used for LLVMFunction.
    }
}
707 | |
708 | } |
709 | |
710 | #endif |
711 | |