1#include <Interpreters/ExpressionJIT.h>
2
3#if USE_EMBEDDED_COMPILER
4
5#include <optional>
6
7#include <Columns/ColumnConst.h>
8#include <Columns/ColumnNullable.h>
9#include <Columns/ColumnVector.h>
10#include <Common/LRUCache.h>
11#include <Common/typeid_cast.h>
12#include <Common/assert_cast.h>
13#include <Common/ProfileEvents.h>
14#include <Common/Stopwatch.h>
15#include <DataTypes/DataTypeNullable.h>
16#include <DataTypes/DataTypesNumber.h>
17#include <DataTypes/Native.h>
18#include <Functions/IFunctionAdaptors.h>
19
20#pragma GCC diagnostic push
21#pragma GCC diagnostic ignored "-Wunused-parameter"
22#pragma GCC diagnostic ignored "-Wnon-virtual-dtor"
23
24#include <llvm/Analysis/TargetTransformInfo.h>
25#include <llvm/IR/BasicBlock.h>
26#include <llvm/IR/DataLayout.h>
27#include <llvm/IR/DerivedTypes.h>
28#include <llvm/IR/Function.h>
29#include <llvm/IR/IRBuilder.h>
30#include <llvm/IR/LLVMContext.h>
31#include <llvm/IR/Mangler.h>
32#include <llvm/IR/Module.h>
33#include <llvm/IR/Type.h>
34#include <llvm/IR/LegacyPassManager.h>
35#include <llvm/ExecutionEngine/ExecutionEngine.h>
36#include <llvm/ExecutionEngine/JITSymbol.h>
37#include <llvm/ExecutionEngine/SectionMemoryManager.h>
38#include <llvm/ExecutionEngine/Orc/CompileUtils.h>
39#include <llvm/ExecutionEngine/Orc/IRCompileLayer.h>
40#include <llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h>
41#include <llvm/Target/TargetMachine.h>
42#include <llvm/MC/SubtargetFeature.h>
43#include <llvm/Support/DynamicLibrary.h>
44#include <llvm/Support/Host.h>
45#include <llvm/Support/TargetRegistry.h>
46#include <llvm/Support/TargetSelect.h>
47#include <llvm/Transforms/IPO/PassManagerBuilder.h>
48
49#pragma GCC diagnostic pop
50
51/// 'LegacyRTDyldObjectLinkingLayer' is deprecated: ORCv1 layers (layers with the 'Legacy' prefix) are deprecated. Please use ORCv2
52/// 'LegacyIRCompileLayer' is deprecated: ORCv1 layers (layers with the 'Legacy' prefix) are deprecated. Please use the ORCv2 IRCompileLayer instead
53#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
54
55
/// Profile event counters declared here and defined elsewhere; this file increments them.
/// NOTE(review): CompileExpressionsBytes is declared but not incremented anywhere in this file.
namespace ProfileEvents
{
    extern const Event CompileFunction;
    extern const Event CompileExpressionsMicroseconds;
    extern const Event CompileExpressionsBytes;
}
62
63namespace DB
64{
65
/// Error codes defined elsewhere and thrown by this file.
namespace ErrorCodes
{
    extern const int LOGICAL_ERROR;
    extern const int CANNOT_COMPILE_CODE;
}
71
72namespace
73{
74 struct ColumnData
75 {
76 const char * data = nullptr;
77 const char * null = nullptr;
78 size_t stride = 0;
79 };
80
81 struct ColumnDataPlaceholder
82 {
83 llvm::Value * data_init; /// first row
84 llvm::Value * null_init;
85 llvm::Value * stride;
86 llvm::PHINode * data; /// current row
87 llvm::PHINode * null;
88 };
89}
90
91static ColumnData getColumnData(const IColumn * column)
92{
93 ColumnData result;
94 const bool is_const = isColumnConst(*column);
95 if (is_const)
96 column = &reinterpret_cast<const ColumnConst *>(column)->getDataColumn();
97 if (auto * nullable = typeid_cast<const ColumnNullable *>(column))
98 {
99 result.null = nullable->getNullMapColumn().getRawData().data;
100 column = &nullable->getNestedColumn();
101 }
102 result.data = column->getRawData().data;
103 result.stride = is_const ? 0 : column->sizeOfValueIfFixed();
104 return result;
105}
106
107static void applyFunction(IFunctionBase & function, Field & value)
108{
109 const auto & type = function.getArgumentTypes().at(0);
110 Block block = {{ type->createColumnConst(1, value), type, "x" }, { nullptr, function.getReturnType(), "y" }};
111 function.execute(block, {0}, 1, 1);
112 block.safeGetByPosition(1).column->get(0, value);
113}
114
115static llvm::TargetMachine * getNativeMachine()
116{
117 std::string error;
118 auto cpu = llvm::sys::getHostCPUName();
119 auto triple = llvm::sys::getProcessTriple();
120 auto target = llvm::TargetRegistry::lookupTarget(triple, error);
121 if (!target)
122 throw Exception("Could not initialize native target: " + error, ErrorCodes::CANNOT_COMPILE_CODE);
123 llvm::SubtargetFeatures features;
124 llvm::StringMap<bool> feature_map;
125 if (llvm::sys::getHostCPUFeatures(feature_map))
126 for (auto & f : feature_map)
127 features.AddFeature(f.first(), f.second);
128 llvm::TargetOptions options;
129 return target->createTargetMachine(
130 triple, cpu, features.getString(), options, llvm::None,
131 llvm::None, llvm::CodeGenOpt::Default, /*jit=*/true
132 );
133}
134
135
136struct SymbolResolver : public llvm::orc::SymbolResolver
137{
138 llvm::LegacyJITSymbolResolver & impl;
139
140 SymbolResolver(llvm::LegacyJITSymbolResolver & impl_) : impl(impl_) {}
141
142 llvm::orc::SymbolNameSet getResponsibilitySet(const llvm::orc::SymbolNameSet & symbols) final
143 {
144 return symbols;
145 }
146
147 llvm::orc::SymbolNameSet lookup(std::shared_ptr<llvm::orc::AsynchronousSymbolQuery> query, llvm::orc::SymbolNameSet symbols) final
148 {
149 llvm::orc::SymbolNameSet missing;
150 for (const auto & symbol : symbols)
151 {
152 bool has_resolved = false;
153 impl.lookup({*symbol}, [&](llvm::Expected<llvm::JITSymbolResolver::LookupResult> resolved)
154 {
155 if (resolved && resolved->size())
156 {
157 query->notifySymbolMetRequiredState(symbol, resolved->begin()->second);
158 has_resolved = true;
159 }
160 });
161
162 if (!has_resolved)
163 missing.insert(symbol);
164 }
165 return missing;
166 }
167};
168
169
170struct LLVMContext
171{
172 std::shared_ptr<llvm::LLVMContext> context {std::make_shared<llvm::LLVMContext>()};
173 std::unique_ptr<llvm::Module> module {std::make_unique<llvm::Module>("jit", *context)};
174 std::unique_ptr<llvm::TargetMachine> machine {getNativeMachine()};
175 llvm::DataLayout layout {machine->createDataLayout()};
176 llvm::IRBuilder<> builder {*context};
177
178 llvm::orc::ExecutionSession execution_session;
179
180 std::shared_ptr<llvm::SectionMemoryManager> memory_manager;
181 llvm::orc::LegacyRTDyldObjectLinkingLayer object_layer;
182 llvm::orc::LegacyIRCompileLayer<decltype(object_layer), llvm::orc::SimpleCompiler> compile_layer;
183
184 std::unordered_map<std::string, void *> symbols;
185
186 LLVMContext()
187 : memory_manager(std::make_shared<llvm::SectionMemoryManager>())
188 , object_layer(execution_session, [this](llvm::orc::VModuleKey)
189 {
190 return llvm::orc::LegacyRTDyldObjectLinkingLayer::Resources{memory_manager, std::make_shared<SymbolResolver>(*memory_manager)};
191 })
192 , compile_layer(object_layer, llvm::orc::SimpleCompiler(*machine))
193 {
194 module->setDataLayout(layout);
195 module->setTargetTriple(machine->getTargetTriple().getTriple());
196 }
197
198 /// returns used memory
199 void compileAllFunctionsToNativeCode()
200 {
201 if (!module->size())
202 return;
203 llvm::PassManagerBuilder pass_manager_builder;
204 llvm::legacy::PassManager mpm;
205 llvm::legacy::FunctionPassManager fpm(module.get());
206 pass_manager_builder.OptLevel = 3;
207 pass_manager_builder.SLPVectorize = true;
208 pass_manager_builder.LoopVectorize = true;
209 pass_manager_builder.RerollLoops = true;
210 pass_manager_builder.VerifyInput = true;
211 pass_manager_builder.VerifyOutput = true;
212 machine->adjustPassManager(pass_manager_builder);
213 fpm.add(llvm::createTargetTransformInfoWrapperPass(machine->getTargetIRAnalysis()));
214 mpm.add(llvm::createTargetTransformInfoWrapperPass(machine->getTargetIRAnalysis()));
215 pass_manager_builder.populateFunctionPassManager(fpm);
216 pass_manager_builder.populateModulePassManager(mpm);
217 fpm.doInitialization();
218 for (auto & function : *module)
219 fpm.run(function);
220 fpm.doFinalization();
221 mpm.run(*module);
222
223 std::vector<std::string> functions;
224 functions.reserve(module->size());
225 for (const auto & function : *module)
226 functions.emplace_back(function.getName());
227
228 llvm::orc::VModuleKey module_key = execution_session.allocateVModule();
229 if (compile_layer.addModule(module_key, std::move(module)))
230 throw Exception("Cannot add module to compile layer", ErrorCodes::CANNOT_COMPILE_CODE);
231
232 for (const auto & name : functions)
233 {
234 std::string mangled_name;
235 llvm::raw_string_ostream mangled_name_stream(mangled_name);
236 llvm::Mangler::getNameWithPrefix(mangled_name_stream, name, layout);
237 mangled_name_stream.flush();
238 auto symbol = compile_layer.findSymbol(mangled_name, false);
239 if (!symbol)
240 continue; /// external function (e.g. an intrinsic that calls into libc)
241 auto address = symbol.getAddress();
242 if (!address)
243 throw Exception("Function " + name + " failed to link", ErrorCodes::CANNOT_COMPILE_CODE);
244 symbols[name] = reinterpret_cast<void *>(*address);
245 }
246 }
247};
248
249
250template <typename... Ts, typename F>
251static bool castToEither(IColumn * column, F && f)
252{
253 return ((typeid_cast<Ts *>(column) ? f(*typeid_cast<Ts *>(column)) : false) || ...);
254}
255
256class LLVMExecutableFunction : public IExecutableFunctionImpl
257{
258 std::string name;
259 void * function;
260
261public:
262 LLVMExecutableFunction(const std::string & name_, const std::unordered_map<std::string, void *> & symbols)
263 : name(name_)
264 {
265 auto it = symbols.find(name);
266 if (symbols.end() == it)
267 throw Exception("Cannot find symbol " + name + " in LLVMContext", ErrorCodes::LOGICAL_ERROR);
268 function = it->second;
269 }
270
271 String getName() const override { return name; }
272
273 bool useDefaultImplementationForNulls() const override { return false; }
274
275 bool useDefaultImplementationForConstants() const override { return true; }
276
277 void execute(Block & block, const ColumnNumbers & arguments, size_t result, size_t block_size) override
278 {
279 auto col_res = block.getByPosition(result).type->createColumn();
280
281 if (block_size)
282 {
283 if (!castToEither<
284 ColumnUInt8, ColumnUInt16, ColumnUInt32, ColumnUInt64,
285 ColumnInt8, ColumnInt16, ColumnInt32, ColumnInt64,
286 ColumnFloat32, ColumnFloat64>(col_res.get(), [block_size](auto & col) { col.getData().resize(block_size); return true; }))
287 throw Exception("Unexpected column in LLVMExecutableFunction: " + col_res->getName(), ErrorCodes::LOGICAL_ERROR);
288
289 std::vector<ColumnData> columns(arguments.size() + 1);
290 for (size_t i = 0; i < arguments.size(); ++i)
291 {
292 auto * column = block.getByPosition(arguments[i]).column.get();
293 if (!column)
294 throw Exception("Column " + block.getByPosition(arguments[i]).name + " is missing", ErrorCodes::LOGICAL_ERROR);
295 columns[i] = getColumnData(column);
296 }
297 columns[arguments.size()] = getColumnData(col_res.get());
298 reinterpret_cast<void (*) (size_t, ColumnData *)>(function)(block_size, columns.data());
299 }
300
301 block.getByPosition(result).column = std::move(col_res);
302 }
303};
304
/// Emits into `context.module` an LLVM function named `f.getName()` with the signature
/// `void (size_t rows, ColumnData * columns)`. `columns` holds one ColumnData per argument of `f` followed by one
/// for the result. The generated code evaluates `f.compile(...)` once per row in a loop.
/// NOTE: the loop body runs before the counter is first checked, so the generated function must not be called with
/// rows == 0 (LLVMExecutableFunction::execute skips the call in that case).
static void compileFunctionToLLVMByteCode(LLVMContext & context, const IFunctionBaseImpl & f)
{
    ProfileEvents::increment(ProfileEvents::CompileFunction);

    auto & arg_types = f.getArgumentTypes();
    auto & b = context.builder;
    /// LLVM counterparts of size_t and of `struct ColumnData { const char * data; const char * null; size_t stride; }`.
    auto * size_type = b.getIntNTy(sizeof(size_t) * 8);
    auto * data_type = llvm::StructType::get(b.getInt8PtrTy(), b.getInt8PtrTy(), size_type);
    auto * func_type = llvm::FunctionType::get(b.getVoidTy(), { size_type, data_type->getPointerTo() }, /*isVarArg=*/false);
    auto * func = llvm::Function::Create(func_type, llvm::Function::ExternalLinkage, f.getName(), context.module.get());
    auto args = func->args().begin();
    llvm::Value * counter_arg = &*args++;
    llvm::Value * columns_arg = &*args++;

    /// Entry block: load every ColumnData and cast its data pointer to a pointer to the column's native type.
    auto * entry = llvm::BasicBlock::Create(b.getContext(), "entry", func);
    b.SetInsertPoint(entry);
    std::vector<ColumnDataPlaceholder> columns(arg_types.size() + 1);
    for (size_t i = 0; i <= arg_types.size(); ++i)
    {
        /// The last entry describes the result column.
        auto & type = i == arg_types.size() ? f.getReturnType() : arg_types[i];
        auto * data = b.CreateLoad(b.CreateConstInBoundsGEP1_32(data_type, columns_arg, i));
        columns[i].data_init = b.CreatePointerCast(b.CreateExtractValue(data, {0}), toNativeType(b, removeNullable(type))->getPointerTo());
        /// The null-map pointer is only used for nullable types.
        columns[i].null_init = type->isNullable() ? b.CreateExtractValue(data, {1}) : nullptr;
        columns[i].stride = b.CreateExtractValue(data, {2});
    }

    /// assume nonzero initial value in `counter_arg`
    auto * loop = llvm::BasicBlock::Create(b.getContext(), "loop", func);
    b.CreateBr(loop);
    b.SetInsertPoint(loop);
    /// PHI nodes for the remaining-row counter and the current data / null-map pointers of every column;
    /// the values for the back edge are added at the end of the loop body.
    auto * counter_phi = b.CreatePHI(counter_arg->getType(), 2);
    counter_phi->addIncoming(counter_arg, entry);
    for (auto & col : columns)
    {
        col.data = b.CreatePHI(col.data_init->getType(), 2);
        col.data->addIncoming(col.data_init, entry);
        if (col.null_init)
        {
            col.null = b.CreatePHI(col.null_init->getType(), 2);
            col.null->addIncoming(col.null_init, entry);
        }
    }
    /// Placeholder i emits a load of the current value of argument i. Nullable arguments are packed into the
    /// native nullable struct `{ value, is_null }`, where is_null = (null-map byte != 0).
    ValuePlaceholders arguments(arg_types.size());
    for (size_t i = 0; i < arguments.size(); ++i)
    {
        arguments[i] = [&b, &col = columns[i], &type = arg_types[i]]() -> llvm::Value *
        {
            auto * value = b.CreateLoad(col.data);
            if (!col.null)
                return value;
            auto * is_null = b.CreateICmpNE(b.CreateLoad(col.null), b.getInt8(0));
            auto * nullable = llvm::Constant::getNullValue(toNativeType(b, type));
            return b.CreateInsertValue(b.CreateInsertValue(nullable, value, {0}), is_null, {1});
        };
    }
    /// Store the result of the current row; a nullable result is unpacked into the value and a 0/1 null-map byte.
    auto * result = f.compile(b, std::move(arguments));
    if (columns.back().null)
    {
        b.CreateStore(b.CreateExtractValue(result, {0}), columns.back().data);
        b.CreateStore(b.CreateSelect(b.CreateExtractValue(result, {1}), b.getInt8(1), b.getInt8(0)), columns.back().null);
    }
    else
    {
        b.CreateStore(result, columns.back().data);
    }
    /// `f.compile` may have added blocks, so the back edge comes from the current insert block.
    auto * cur_block = b.GetInsertBlock();
    for (auto & col : columns)
    {
        /// stride is either 0 or size of native type; output column is never constant; neither is at least one input
        /// (so with at most one argument, the stride is not checked and pointers always advance).
        auto * is_const = &col == &columns.back() || columns.size() <= 2 ? b.getFalse() : b.CreateICmpEQ(col.stride, llvm::ConstantInt::get(size_type, 0));
        col.data->addIncoming(b.CreateSelect(is_const, col.data, b.CreateConstInBoundsGEP1_32(nullptr, col.data, 1)), cur_block);
        if (col.null)
            col.null->addIncoming(b.CreateSelect(is_const, col.null, b.CreateConstInBoundsGEP1_32(nullptr, col.null, 1)), cur_block);
    }
    counter_phi->addIncoming(b.CreateSub(counter_phi, llvm::ConstantInt::get(size_type, 1)), cur_block);

    /// Loop again unless this iteration started with counter == 1, so the body runs exactly `rows` times.
    auto * end = llvm::BasicBlock::Create(b.getContext(), "end", func);
    b.CreateCondBr(b.CreateICmpNE(counter_phi, llvm::ConstantInt::get(size_type, 1)), loop, end);
    b.SetInsertPoint(end);
    b.CreateRetVoid();
}
386
387static llvm::Constant * getNativeValue(llvm::Type * type, const IColumn & column, size_t i)
388{
389 if (!type || column.size() <= i)
390 return nullptr;
391 if (auto * constant = typeid_cast<const ColumnConst *>(&column))
392 return getNativeValue(type, constant->getDataColumn(), 0);
393 if (auto * nullable = typeid_cast<const ColumnNullable *>(&column))
394 {
395 auto * value = getNativeValue(type->getContainedType(0), nullable->getNestedColumn(), i);
396 auto * is_null = llvm::ConstantInt::get(type->getContainedType(1), nullable->isNullAt(i));
397 return value ? llvm::ConstantStruct::get(static_cast<llvm::StructType *>(type), value, is_null) : nullptr;
398 }
399 if (type->isFloatTy())
400 return llvm::ConstantFP::get(type, assert_cast<const ColumnVector<Float32> &>(column).getElement(i));
401 if (type->isDoubleTy())
402 return llvm::ConstantFP::get(type, assert_cast<const ColumnVector<Float64> &>(column).getElement(i));
403 if (type->isIntegerTy())
404 return llvm::ConstantInt::get(type, column.getUInt(i));
405 /// TODO: if (type->isVectorTy())
406 return nullptr;
407}
408
/// Same as IFunctionBase::compile, but also for constants and input columns.
/// Receives an IR builder and the placeholders for the compiled function's input columns, emits the IR for one
/// (sub)expression and returns the resulting value.
using CompilableExpression = std::function<llvm::Value * (llvm::IRBuilderBase &, const ValuePlaceholders &)>;
411
412static CompilableExpression subexpression(ColumnPtr c, DataTypePtr type)
413{
414 return [=](llvm::IRBuilderBase & b, const ValuePlaceholders &) { return getNativeValue(toNativeType(b, type), *c, 0); };
415}
416
417static CompilableExpression subexpression(size_t i)
418{
419 return [=](llvm::IRBuilderBase &, const ValuePlaceholders & inputs) { return inputs[i](); };
420}
421
422static CompilableExpression subexpression(const IFunctionBase & f, std::vector<CompilableExpression> args)
423{
424 return [&, args = std::move(args)](llvm::IRBuilderBase & builder, const ValuePlaceholders & inputs)
425 {
426 ValuePlaceholders input;
427 for (const auto & arg : args)
428 input.push_back([&]() { return arg(builder, inputs); });
429 auto * result = f.compile(builder, input);
430 if (result->getType() != toNativeType(builder, f.getReturnType()))
431 throw Exception("Function " + f.getName() + " generated an llvm::Value of invalid type", ErrorCodes::LOGICAL_ERROR);
432 return result;
433 };
434}
435
/// What an LLVMFunction keeps from the LLVMContext that compiled it once that context is destroyed
/// (filled at the end of the LLVMFunction constructor).
/// NOTE: members are destroyed in reverse declaration order (memory_manager first); keep the order unless verified.
struct LLVMModuleState
{
    /// Name of each compiled function -> its native address.
    std::unordered_map<std::string, void *> symbols;
    /// The LLVM context the code was generated in (LLVMContext::context).
    std::shared_ptr<llvm::LLVMContext> major_context;
    /// The memory manager given to the JIT object layer; presumably owns the memory holding the compiled code.
    std::shared_ptr<llvm::SectionMemoryManager> memory_manager;
};
442
/// Builds and JIT-compiles one function equivalent to executing `actions` in order. Its result is the column
/// produced by the last action (`name`).
///
/// Columns of `sample_block` with a value representable as an LLVM constant become compile-time constants.
/// Every other argument name that is not produced by an earlier action of the chain becomes an input of the compiled
/// function, numbered in order of first use (`arg_names` / `arg_types`).
/// NOTE(review): `actions.back()` is used in the initializer list, so `actions` must be non-empty.
LLVMFunction::LLVMFunction(const ExpressionActions::Actions & actions, const Block & sample_block)
    : name(actions.back().result_name)
    , module_state(std::make_unique<LLVMModuleState>())
{
    LLVMContext context;
    for (const auto & c : sample_block)
        /// TODO: implement `getNativeValue` for all types & replace the check with `c.column && toNativeType(...)`
        if (c.column && getNativeValue(toNativeType(context.builder, c.type), *c.column, 0))
            subexpressions[c.name] = subexpression(c.column, c.type);
    for (const auto & action : actions)
    {
        const auto & names = action.argument_names;
        const auto & types = action.function_base->getArgumentTypes();
        std::vector<CompilableExpression> args;
        for (size_t i = 0; i < names.size(); ++i)
        {
            /// An unknown name becomes a new input; its index is arg_names.size() *before* the push_back below.
            auto inserted = subexpressions.emplace(names[i], subexpression(arg_names.size()));
            if (inserted.second)
            {
                arg_names.push_back(names[i]);
                arg_types.push_back(types[i]);
            }
            args.push_back(inserted.first->second);
        }
        subexpressions[action.result_name] = subexpression(*action.function_base, std::move(args));
        /// The subexpression above refers to the function by reference; this shared pointer keeps it alive.
        originals.push_back(action.function_base);
    }
    compileFunctionToLLVMByteCode(context, *this);
    context.compileAllFunctionsToNativeCode();

    /// `context` is destroyed on return; keep the symbol table and the objects the compiled code depends on.
    module_state->symbols = context.symbols;
    module_state->major_context = context.context;
    module_state->memory_manager = context.memory_manager;
}
477
478llvm::Value * LLVMFunction::compile(llvm::IRBuilderBase & builder, ValuePlaceholders values) const
479{
480 auto it = subexpressions.find(name);
481 if (subexpressions.end() == it)
482 throw Exception("Cannot find subexpression " + name + " in LLVMFunction", ErrorCodes::LOGICAL_ERROR);
483 return it->second(builder, values);
484}
485
486ExecutableFunctionImplPtr LLVMFunction::prepare(const Block &, const ColumnNumbers &, size_t) const { return std::make_unique<LLVMExecutableFunction>(name, module_state->symbols); }
487
488bool LLVMFunction::isDeterministic() const
489{
490 for (const auto & f : originals)
491 if (!f->isDeterministic())
492 return false;
493 return true;
494}
495
496bool LLVMFunction::isDeterministicInScopeOfQuery() const
497{
498 for (const auto & f : originals)
499 if (!f->isDeterministicInScopeOfQuery())
500 return false;
501 return true;
502}
503
504bool LLVMFunction::isSuitableForConstantFolding() const
505{
506 for (const auto & f : originals)
507 if (!f->isSuitableForConstantFolding())
508 return false;
509 return true;
510}
511
512bool LLVMFunction::isInjective(const Block & sample_block)
513{
514 for (const auto & f : originals)
515 if (!f->isInjective(sample_block))
516 return false;
517 return true;
518}
519
520bool LLVMFunction::hasInformationAboutMonotonicity() const
521{
522 for (const auto & f : originals)
523 if (!f->hasInformationAboutMonotonicity())
524 return false;
525 return true;
526}
527
528LLVMFunction::Monotonicity LLVMFunction::getMonotonicityForRange(const IDataType & type, const Field & left, const Field & right) const
529{
530 const IDataType * type_ = &type;
531 Field left_ = left;
532 Field right_ = right;
533 Monotonicity result(true, true, true);
534 /// monotonicity is only defined for unary functions, so the chain must describe a sequence of nested calls
535 for (size_t i = 0; i < originals.size(); ++i)
536 {
537 Monotonicity m = originals[i]->getMonotonicityForRange(*type_, left_, right_);
538 if (!m.is_monotonic)
539 return m;
540 result.is_positive ^= !m.is_positive;
541 result.is_always_monotonic &= m.is_always_monotonic;
542 if (i + 1 < originals.size())
543 {
544 if (left_ != Field())
545 applyFunction(*originals[i], left_);
546 if (right_ != Field())
547 applyFunction(*originals[i], right_);
548 if (!m.is_positive)
549 std::swap(left_, right_);
550 type_ = originals[i]->getReturnType().get();
551 }
552 }
553 return result;
554}
555
556
557static bool isCompilable(const IFunctionBase & function)
558{
559 if (!canBeNativeType(*function.getReturnType()))
560 return false;
561 for (const auto & type : function.getArgumentTypes())
562 if (!canBeNativeType(*type))
563 return false;
564 return function.isCompilable();
565}
566
/// For every APPLY_FUNCTION action, computes the set of consumers of its result column as of the moment the action
/// runs. Actions are walked backwards from the end.
/// A consumer is either the index of a compilable APPLY_FUNCTION action that takes the column as an argument, or an
/// empty optional. The empty optional marks any use that cannot absorb the producer by inlining: an output column, a
/// non-compilable function, a use by another action type, or a use after a REMOVE_COLUMN.
/// Returns a vector indexed like `actions`; entries for actions other than APPLY_FUNCTION stay empty.
static std::vector<std::unordered_set<std::optional<size_t>>> getActionsDependents(const ExpressionActions::Actions & actions, const Names & output_columns)
{
    /// an empty optional is a poisoned value prohibiting the column's producer from being removed
    /// (which it could be, if it was inlined into every dependent function).
    std::unordered_map<std::string, std::unordered_set<std::optional<size_t>>> current_dependents;
    for (const auto & name : output_columns)
        current_dependents[name].emplace();
    /// a snapshot of each compilable function's dependents at the time of its execution.
    std::vector<std::unordered_set<std::optional<size_t>>> dependents(actions.size());
    for (size_t i = actions.size(); i--;)
    {
        switch (actions[i].type)
        {
            case ExpressionAction::REMOVE_COLUMN:
                current_dependents.erase(actions[i].source_name);
                /// poison every other column used after this point so that inlining chains do not cross it.
                for (auto & dep : current_dependents)
                    dep.second.emplace();
                break;

            case ExpressionAction::PROJECT:
                /// Only the projected columns survive; each of them is poisoned.
                current_dependents.clear();
                for (const auto & proj : actions[i].projection)
                    current_dependents[proj.first].emplace();
                break;

            case ExpressionAction::ADD_ALIASES:
                for (const auto & proj : actions[i].projection)
                    current_dependents[proj.first].emplace();
                break;

            case ExpressionAction::ADD_COLUMN:
            case ExpressionAction::COPY_COLUMN:
            case ExpressionAction::ARRAY_JOIN:
            case ExpressionAction::JOIN:
            {
                /// Columns read by these actions cannot be produced by inlined code.
                Names columns = actions[i].getNeededColumns();
                for (const auto & column : columns)
                    current_dependents[column].emplace();
                break;
            }

            case ExpressionAction::APPLY_FUNCTION:
            {
                dependents[i] = current_dependents[actions[i].result_name];
                /// Arguments of a compilable function record this action as a consumer that could inline them;
                /// arguments of any other function are poisoned.
                const bool compilable = isCompilable(*actions[i].function_base);
                for (const auto & name : actions[i].argument_names)
                {
                    if (compilable)
                        current_dependents[name].emplace(i);
                    else
                        current_dependents[name].emplace();
                }
                break;
            }
        }
    }
    return dependents;
}
626
/// Replaces chains of compilable APPLY_FUNCTION actions in `actions` with single JIT-compiled functions, in place.
///
/// Actions are visited in execution order. `fused[i]` collects the actions inlined into action `i` (appended while
/// earlier actions were visited), followed by action `i` itself. If some consumer of action `i`'s result cannot
/// inline it (an empty optional in `dependents[i]`), action `i` is the root of a fused chain. Its function is then
/// replaced by an LLVMFunction compiled from `fused[i]`. Otherwise `fused[i]` is appended to `fused` of every consumer.
/// NOTE(review): inlined actions are not removed from `actions` here; presumably that happens later — confirm.
///
/// @param actions                          actions to rewrite in place
/// @param output_columns                   columns needed after the actions; their producers are never inlined away
/// @param sample_block                     block whose constant columns may be folded into the compiled code
/// @param compilation_cache                optional cache of compiled functions keyed by the chain hash; may be null
/// @param min_count_to_compile_expression  a chain is compiled only once its hash has been seen more than this
///                                         many times (counting the current call); single-action chains never are
void compileFunctions(
    ExpressionActions::Actions & actions,
    const Names & output_columns,
    const Block & sample_block,
    std::shared_ptr<CompiledExpressionCache> compilation_cache,
    size_t min_count_to_compile_expression)
{
    /// How many times each fused chain (by hash) has been seen, shared by all callers and guarded by `mutex`.
    /// Entries are never removed.
    static std::unordered_map<UInt128, UInt32, UInt128Hash> counter;
    static std::mutex mutex;

    /// One-time LLVM setup: register the native target and its asm printer. Also load the symbols of the
    /// running program so that JIT-compiled code can call into it.
    struct LLVMTargetInitializer
    {
        LLVMTargetInitializer()
        {
            llvm::InitializeNativeTarget();
            llvm::InitializeNativeTargetAsmPrinter();
            llvm::sys::DynamicLibrary::LoadLibraryPermanently(nullptr);
        }
    };

    static LLVMTargetInitializer initializer;

    auto dependents = getActionsDependents(actions, output_columns);
    std::vector<ExpressionActions::Actions> fused(actions.size());
    for (size_t i = 0; i < actions.size(); ++i)
    {
        if (actions[i].type != ExpressionAction::APPLY_FUNCTION || !isCompilable(*actions[i].function_base))
            continue;

        fused[i].push_back(actions[i]);
        /// Some consumer cannot inline this action: it becomes the root of a compiled chain.
        if (dependents[i].find({}) != dependents[i].end())
        {
            /// the result of compiling one function in isolation is pretty much the same as its `execute` method.
            if (fused[i].size() == 1)
                continue;

            auto hash_key = ExpressionActions::ActionsHash{}(fused[i]);
            {
                std::lock_guard lock(mutex);
                if (counter[hash_key]++ < min_count_to_compile_expression)
                    continue;
            }

            FunctionBasePtr fn;
            if (compilation_cache)
            {
                std::tie(fn, std::ignore) = compilation_cache->getOrSet(hash_key, [&inlined_func=std::as_const(fused[i]), &sample_block] ()
                {
                    Stopwatch watch;
                    FunctionBasePtr result_fn;
                    result_fn = std::make_shared<FunctionBaseAdaptor>(std::make_unique<LLVMFunction>(inlined_func, sample_block));
                    ProfileEvents::increment(ProfileEvents::CompileExpressionsMicroseconds, watch.elapsedMicroseconds());
                    return result_fn;
                });
            }
            else
            {
                Stopwatch watch;
                fn = std::make_shared<FunctionBaseAdaptor>(std::make_unique<LLVMFunction>(fused[i], sample_block));
                ProfileEvents::increment(ProfileEvents::CompileExpressionsMicroseconds, watch.elapsedMicroseconds());
            }

            /// The compiled function reads the chain's external inputs instead of the root's original arguments.
            actions[i].function_base = fn;
            actions[i].argument_names = typeid_cast<const LLVMFunction *>(typeid_cast<const FunctionBaseAdaptor *>(fn.get())->getImpl())->getArgumentNames();
            actions[i].is_function_compiled = true;

            continue;
        }

        /// TODO: determine whether it's profitable to inline the function if there's more than one dependent.
        for (const auto & dep : dependents[i])
            fused[*dep].insert(fused[*dep].end(), fused[i].begin(), fused[i].end());
    }

    for (size_t i = 0; i < actions.size(); ++i)
    {
        if (actions[i].type == ExpressionAction::APPLY_FUNCTION && actions[i].is_function_compiled)
            actions[i].function = actions[i].function_base->prepare({}, {}, 0); /// Arguments are not used for LLVMFunction.
    }
}
707
708}
709
710#endif
711