1//===- OffloadWrapper.cpp ---------------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "OffloadWrapper.h"
10#include "llvm/ADT/ArrayRef.h"
11#include "llvm/IR/Constants.h"
12#include "llvm/IR/GlobalVariable.h"
13#include "llvm/IR/IRBuilder.h"
14#include "llvm/IR/LLVMContext.h"
15#include "llvm/IR/Module.h"
16#include "llvm/Object/OffloadBinary.h"
17#include "llvm/Support/Error.h"
18#include "llvm/TargetParser/Triple.h"
19#include "llvm/Transforms/Utils/ModuleUtils.h"
20
21using namespace llvm;
22
23namespace {
/// Magic number stored in the `magic` field of the CUDA fatbinary wrapper
/// (`fatbin_wrapper`), which begins the wrapper section.
constexpr unsigned CudaFatMagic = 0x466243b1;
/// Magic number used instead of CudaFatMagic when wrapping a HIP fatbinary.
/// Its bytes spell "HIPF" in ASCII.
constexpr unsigned HIPFatMagic = 0x48495046;

/// Kind of an entry in the CUDA/HIP offloading entries table, as stored in the
/// `flags` field of `__tgt_offload_entry`. Copied from clang/CGCudaRuntime.h.
enum OffloadEntryKindFlag : uint32_t {
  /// Mark the entry as a global entry. This indicates the presence of a
  /// kernel if the size field is zero and a variable otherwise.
  OffloadGlobalEntry = 0x0,
  /// Mark the entry as a managed global variable.
  OffloadGlobalManagedEntry = 0x1,
  /// Mark the entry as a surface variable.
  OffloadGlobalSurfaceEntry = 0x2,
  /// Mark the entry as a texture variable.
  OffloadGlobalTextureEntry = 0x3,
};
40
41IntegerType *getSizeTTy(Module &M) {
42 LLVMContext &C = M.getContext();
43 switch (M.getDataLayout().getPointerTypeSize(Type::getInt8PtrTy(C))) {
44 case 4u:
45 return Type::getInt32Ty(C);
46 case 8u:
47 return Type::getInt64Ty(C);
48 }
49 llvm_unreachable("unsupported pointer type size");
50}
51
52// struct __tgt_offload_entry {
53// void *addr;
54// char *name;
55// size_t size;
56// int32_t flags;
57// int32_t reserved;
58// };
59StructType *getEntryTy(Module &M) {
60 LLVMContext &C = M.getContext();
61 StructType *EntryTy = StructType::getTypeByName(C, "__tgt_offload_entry");
62 if (!EntryTy)
63 EntryTy = StructType::create("__tgt_offload_entry", Type::getInt8PtrTy(C),
64 Type::getInt8PtrTy(C), getSizeTTy(M),
65 Type::getInt32Ty(C), Type::getInt32Ty(C));
66 return EntryTy;
67}
68
69PointerType *getEntryPtrTy(Module &M) {
70 return PointerType::getUnqual(getEntryTy(M));
71}
72
73// struct __tgt_device_image {
74// void *ImageStart;
75// void *ImageEnd;
76// __tgt_offload_entry *EntriesBegin;
77// __tgt_offload_entry *EntriesEnd;
78// };
79StructType *getDeviceImageTy(Module &M) {
80 LLVMContext &C = M.getContext();
81 StructType *ImageTy = StructType::getTypeByName(C, "__tgt_device_image");
82 if (!ImageTy)
83 ImageTy = StructType::create("__tgt_device_image", Type::getInt8PtrTy(C),
84 Type::getInt8PtrTy(C), getEntryPtrTy(M),
85 getEntryPtrTy(M));
86 return ImageTy;
87}
88
89PointerType *getDeviceImagePtrTy(Module &M) {
90 return PointerType::getUnqual(getDeviceImageTy(M));
91}
92
93// struct __tgt_bin_desc {
94// int32_t NumDeviceImages;
95// __tgt_device_image *DeviceImages;
96// __tgt_offload_entry *HostEntriesBegin;
97// __tgt_offload_entry *HostEntriesEnd;
98// };
99StructType *getBinDescTy(Module &M) {
100 LLVMContext &C = M.getContext();
101 StructType *DescTy = StructType::getTypeByName(C, "__tgt_bin_desc");
102 if (!DescTy)
103 DescTy = StructType::create("__tgt_bin_desc", Type::getInt32Ty(C),
104 getDeviceImagePtrTy(M), getEntryPtrTy(M),
105 getEntryPtrTy(M));
106 return DescTy;
107}
108
109PointerType *getBinDescPtrTy(Module &M) {
110 return PointerType::getUnqual(getBinDescTy(M));
111}
112
113/// Creates binary descriptor for the given device images. Binary descriptor
114/// is an object that is passed to the offloading runtime at program startup
115/// and it describes all device images available in the executable or shared
116/// library. It is defined as follows
117///
118/// __attribute__((visibility("hidden")))
119/// extern __tgt_offload_entry *__start_omp_offloading_entries;
120/// __attribute__((visibility("hidden")))
121/// extern __tgt_offload_entry *__stop_omp_offloading_entries;
122///
123/// static const char Image0[] = { <Bufs.front() contents> };
124/// ...
125/// static const char ImageN[] = { <Bufs.back() contents> };
126///
127/// static const __tgt_device_image Images[] = {
128/// {
129/// Image0, /*ImageStart*/
130/// Image0 + sizeof(Image0), /*ImageEnd*/
131/// __start_omp_offloading_entries, /*EntriesBegin*/
132/// __stop_omp_offloading_entries /*EntriesEnd*/
133/// },
134/// ...
135/// {
136/// ImageN, /*ImageStart*/
137/// ImageN + sizeof(ImageN), /*ImageEnd*/
138/// __start_omp_offloading_entries, /*EntriesBegin*/
139/// __stop_omp_offloading_entries /*EntriesEnd*/
140/// }
141/// };
142///
143/// static const __tgt_bin_desc BinDesc = {
144/// sizeof(Images) / sizeof(Images[0]), /*NumDeviceImages*/
145/// Images, /*DeviceImages*/
146/// __start_omp_offloading_entries, /*HostEntriesBegin*/
147/// __stop_omp_offloading_entries /*HostEntriesEnd*/
148/// };
149///
150/// Global variable that represents BinDesc is returned.
151GlobalVariable *createBinDesc(Module &M, ArrayRef<ArrayRef<char>> Bufs) {
152 LLVMContext &C = M.getContext();
153 // Create external begin/end symbols for the offload entries table.
154 auto *EntriesB = new GlobalVariable(
155 M, getEntryTy(M), /*isConstant*/ true, GlobalValue::ExternalLinkage,
156 /*Initializer*/ nullptr, "__start_omp_offloading_entries");
157 EntriesB->setVisibility(GlobalValue::HiddenVisibility);
158 auto *EntriesE = new GlobalVariable(
159 M, getEntryTy(M), /*isConstant*/ true, GlobalValue::ExternalLinkage,
160 /*Initializer*/ nullptr, "__stop_omp_offloading_entries");
161 EntriesE->setVisibility(GlobalValue::HiddenVisibility);
162
163 // We assume that external begin/end symbols that we have created above will
164 // be defined by the linker. But linker will do that only if linker inputs
165 // have section with "omp_offloading_entries" name which is not guaranteed.
166 // So, we just create dummy zero sized object in the offload entries section
167 // to force linker to define those symbols.
168 auto *DummyInit =
169 ConstantAggregateZero::get(ArrayType::get(getEntryTy(M), 0u));
170 auto *DummyEntry = new GlobalVariable(
171 M, DummyInit->getType(), true, GlobalVariable::ExternalLinkage, DummyInit,
172 "__dummy.omp_offloading.entry");
173 DummyEntry->setSection("omp_offloading_entries");
174 DummyEntry->setVisibility(GlobalValue::HiddenVisibility);
175
176 auto *Zero = ConstantInt::get(getSizeTTy(M), 0u);
177 Constant *ZeroZero[] = {Zero, Zero};
178
179 // Create initializer for the images array.
180 SmallVector<Constant *, 4u> ImagesInits;
181 ImagesInits.reserve(Bufs.size());
182 for (ArrayRef<char> Buf : Bufs) {
183 auto *Data = ConstantDataArray::get(C, Buf);
184 auto *Image = new GlobalVariable(M, Data->getType(), /*isConstant*/ true,
185 GlobalVariable::InternalLinkage, Data,
186 ".omp_offloading.device_image");
187 Image->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
188 Image->setSection(".llvm.offloading");
189 Image->setAlignment(Align(object::OffloadBinary::getAlignment()));
190
191 auto *Size = ConstantInt::get(getSizeTTy(M), Buf.size());
192 Constant *ZeroSize[] = {Zero, Size};
193
194 auto *ImageB =
195 ConstantExpr::getGetElementPtr(Image->getValueType(), Image, ZeroZero);
196 auto *ImageE =
197 ConstantExpr::getGetElementPtr(Image->getValueType(), Image, ZeroSize);
198
199 ImagesInits.push_back(ConstantStruct::get(getDeviceImageTy(M), ImageB,
200 ImageE, EntriesB, EntriesE));
201 }
202
203 // Then create images array.
204 auto *ImagesData = ConstantArray::get(
205 ArrayType::get(getDeviceImageTy(M), ImagesInits.size()), ImagesInits);
206
207 auto *Images =
208 new GlobalVariable(M, ImagesData->getType(), /*isConstant*/ true,
209 GlobalValue::InternalLinkage, ImagesData,
210 ".omp_offloading.device_images");
211 Images->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
212
213 auto *ImagesB =
214 ConstantExpr::getGetElementPtr(Images->getValueType(), Images, ZeroZero);
215
216 // And finally create the binary descriptor object.
217 auto *DescInit = ConstantStruct::get(
218 getBinDescTy(M),
219 ConstantInt::get(Type::getInt32Ty(C), ImagesInits.size()), ImagesB,
220 EntriesB, EntriesE);
221
222 return new GlobalVariable(M, DescInit->getType(), /*isConstant*/ true,
223 GlobalValue::InternalLinkage, DescInit,
224 ".omp_offloading.descriptor");
225}
226
227void createRegisterFunction(Module &M, GlobalVariable *BinDesc) {
228 LLVMContext &C = M.getContext();
229 auto *FuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
230 auto *Func = Function::Create(FuncTy, GlobalValue::InternalLinkage,
231 ".omp_offloading.descriptor_reg", &M);
232 Func->setSection(".text.startup");
233
234 // Get __tgt_register_lib function declaration.
235 auto *RegFuncTy = FunctionType::get(Type::getVoidTy(C), getBinDescPtrTy(M),
236 /*isVarArg*/ false);
237 FunctionCallee RegFuncC =
238 M.getOrInsertFunction("__tgt_register_lib", RegFuncTy);
239
240 // Construct function body
241 IRBuilder<> Builder(BasicBlock::Create(C, "entry", Func));
242 Builder.CreateCall(RegFuncC, BinDesc);
243 Builder.CreateRetVoid();
244
245 // Add this function to constructors.
246 // Set priority to 1 so that __tgt_register_lib is executed AFTER
247 // __tgt_register_requires (we want to know what requirements have been
248 // asked for before we load a libomptarget plugin so that by the time the
249 // plugin is loaded it can report how many devices there are which can
250 // satisfy these requirements).
251 appendToGlobalCtors(M, Func, /*Priority*/ 1);
252}
253
254void createUnregisterFunction(Module &M, GlobalVariable *BinDesc) {
255 LLVMContext &C = M.getContext();
256 auto *FuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
257 auto *Func = Function::Create(FuncTy, GlobalValue::InternalLinkage,
258 ".omp_offloading.descriptor_unreg", &M);
259 Func->setSection(".text.startup");
260
261 // Get __tgt_unregister_lib function declaration.
262 auto *UnRegFuncTy = FunctionType::get(Type::getVoidTy(C), getBinDescPtrTy(M),
263 /*isVarArg*/ false);
264 FunctionCallee UnRegFuncC =
265 M.getOrInsertFunction("__tgt_unregister_lib", UnRegFuncTy);
266
267 // Construct function body
268 IRBuilder<> Builder(BasicBlock::Create(C, "entry", Func));
269 Builder.CreateCall(UnRegFuncC, BinDesc);
270 Builder.CreateRetVoid();
271
272 // Add this function to global destructors.
273 // Match priority of __tgt_register_lib
274 appendToGlobalDtors(M, Func, /*Priority*/ 1);
275}
276
277// struct fatbin_wrapper {
278// int32_t magic;
279// int32_t version;
280// void *image;
281// void *reserved;
282//};
283StructType *getFatbinWrapperTy(Module &M) {
284 LLVMContext &C = M.getContext();
285 StructType *FatbinTy = StructType::getTypeByName(C, "fatbin_wrapper");
286 if (!FatbinTy)
287 FatbinTy = StructType::create("fatbin_wrapper", Type::getInt32Ty(C),
288 Type::getInt32Ty(C), Type::getInt8PtrTy(C),
289 Type::getInt8PtrTy(C));
290 return FatbinTy;
291}
292
293/// Embed the image \p Image into the module \p M so it can be found by the
294/// runtime.
295GlobalVariable *createFatbinDesc(Module &M, ArrayRef<char> Image, bool IsHIP) {
296 LLVMContext &C = M.getContext();
297 llvm::Type *Int8PtrTy = Type::getInt8PtrTy(C);
298 llvm::Triple Triple = llvm::Triple(M.getTargetTriple());
299
300 // Create the global string containing the fatbinary.
301 StringRef FatbinConstantSection =
302 IsHIP ? ".hip_fatbin"
303 : (Triple.isMacOSX() ? "__NV_CUDA,__nv_fatbin" : ".nv_fatbin");
304 auto *Data = ConstantDataArray::get(C, Image);
305 auto *Fatbin = new GlobalVariable(M, Data->getType(), /*isConstant*/ true,
306 GlobalVariable::InternalLinkage, Data,
307 ".fatbin_image");
308 Fatbin->setSection(FatbinConstantSection);
309
310 // Create the fatbinary wrapper
311 StringRef FatbinWrapperSection = IsHIP ? ".hipFatBinSegment"
312 : Triple.isMacOSX() ? "__NV_CUDA,__fatbin"
313 : ".nvFatBinSegment";
314 Constant *FatbinWrapper[] = {
315 ConstantInt::get(Type::getInt32Ty(C), IsHIP ? HIPFatMagic : CudaFatMagic),
316 ConstantInt::get(Type::getInt32Ty(C), 1),
317 ConstantExpr::getPointerBitCastOrAddrSpaceCast(Fatbin, Int8PtrTy),
318 ConstantPointerNull::get(Type::getInt8PtrTy(C))};
319
320 Constant *FatbinInitializer =
321 ConstantStruct::get(getFatbinWrapperTy(M), FatbinWrapper);
322
323 auto *FatbinDesc =
324 new GlobalVariable(M, getFatbinWrapperTy(M),
325 /*isConstant*/ true, GlobalValue::InternalLinkage,
326 FatbinInitializer, ".fatbin_wrapper");
327 FatbinDesc->setSection(FatbinWrapperSection);
328 FatbinDesc->setAlignment(Align(8));
329
330 // We create a dummy entry to ensure the linker will define the begin / end
331 // symbols. The CUDA runtime should ignore the null address if we attempt to
332 // register it.
333 auto *DummyInit =
334 ConstantAggregateZero::get(ArrayType::get(getEntryTy(M), 0u));
335 auto *DummyEntry = new GlobalVariable(
336 M, DummyInit->getType(), true, GlobalVariable::ExternalLinkage, DummyInit,
337 IsHIP ? "__dummy.hip_offloading.entry" : "__dummy.cuda_offloading.entry");
338 DummyEntry->setVisibility(GlobalValue::HiddenVisibility);
339 DummyEntry->setSection(IsHIP ? "hip_offloading_entries"
340 : "cuda_offloading_entries");
341
342 return FatbinDesc;
343}
344
345/// Create the register globals function. We will iterate all of the offloading
346/// entries stored at the begin / end symbols and register them according to
347/// their type. This creates the following function in IR:
348///
349/// extern struct __tgt_offload_entry __start_cuda_offloading_entries;
350/// extern struct __tgt_offload_entry __stop_cuda_offloading_entries;
351///
352/// extern void __cudaRegisterFunction(void **, void *, void *, void *, int,
353/// void *, void *, void *, void *, int *);
354/// extern void __cudaRegisterVar(void **, void *, void *, void *, int32_t,
355/// int64_t, int32_t, int32_t);
356///
357/// void __cudaRegisterTest(void **fatbinHandle) {
358/// for (struct __tgt_offload_entry *entry = &__start_cuda_offloading_entries;
359/// entry != &__stop_cuda_offloading_entries; ++entry) {
360/// if (!entry->size)
361/// __cudaRegisterFunction(fatbinHandle, entry->addr, entry->name,
362/// entry->name, -1, 0, 0, 0, 0, 0);
363/// else
364/// __cudaRegisterVar(fatbinHandle, entry->addr, entry->name, entry->name,
365/// 0, entry->size, 0, 0);
366/// }
367/// }
368Function *createRegisterGlobalsFunction(Module &M, bool IsHIP) {
369 LLVMContext &C = M.getContext();
370 // Get the __cudaRegisterFunction function declaration.
371 auto *RegFuncTy = FunctionType::get(
372 Type::getInt32Ty(C),
373 {Type::getInt8PtrTy(C)->getPointerTo(), Type::getInt8PtrTy(C),
374 Type::getInt8PtrTy(C), Type::getInt8PtrTy(C), Type::getInt32Ty(C),
375 Type::getInt8PtrTy(C), Type::getInt8PtrTy(C), Type::getInt8PtrTy(C),
376 Type::getInt8PtrTy(C), Type::getInt32PtrTy(C)},
377 /*isVarArg*/ false);
378 FunctionCallee RegFunc = M.getOrInsertFunction(
379 IsHIP ? "__hipRegisterFunction" : "__cudaRegisterFunction", RegFuncTy);
380
381 // Get the __cudaRegisterVar function declaration.
382 auto *RegVarTy = FunctionType::get(
383 Type::getVoidTy(C),
384 {Type::getInt8PtrTy(C)->getPointerTo(), Type::getInt8PtrTy(C),
385 Type::getInt8PtrTy(C), Type::getInt8PtrTy(C), Type::getInt32Ty(C),
386 getSizeTTy(M), Type::getInt32Ty(C), Type::getInt32Ty(C)},
387 /*isVarArg*/ false);
388 FunctionCallee RegVar = M.getOrInsertFunction(
389 IsHIP ? "__hipRegisterVar" : "__cudaRegisterVar", RegVarTy);
390
391 // Create the references to the start / stop symbols defined by the linker.
392 auto *EntriesB =
393 new GlobalVariable(M, ArrayType::get(getEntryTy(M), 0),
394 /*isConstant*/ true, GlobalValue::ExternalLinkage,
395 /*Initializer*/ nullptr,
396 IsHIP ? "__start_hip_offloading_entries"
397 : "__start_cuda_offloading_entries");
398 EntriesB->setVisibility(GlobalValue::HiddenVisibility);
399 auto *EntriesE =
400 new GlobalVariable(M, ArrayType::get(getEntryTy(M), 0),
401 /*isConstant*/ true, GlobalValue::ExternalLinkage,
402 /*Initializer*/ nullptr,
403 IsHIP ? "__stop_hip_offloading_entries"
404 : "__stop_cuda_offloading_entries");
405 EntriesE->setVisibility(GlobalValue::HiddenVisibility);
406
407 auto *RegGlobalsTy = FunctionType::get(Type::getVoidTy(C),
408 Type::getInt8PtrTy(C)->getPointerTo(),
409 /*isVarArg*/ false);
410 auto *RegGlobalsFn =
411 Function::Create(RegGlobalsTy, GlobalValue::InternalLinkage,
412 IsHIP ? ".hip.globals_reg" : ".cuda.globals_reg", &M);
413 RegGlobalsFn->setSection(".text.startup");
414
415 // Create the loop to register all the entries.
416 IRBuilder<> Builder(BasicBlock::Create(C, "entry", RegGlobalsFn));
417 auto *EntryBB = BasicBlock::Create(C, "while.entry", RegGlobalsFn);
418 auto *IfThenBB = BasicBlock::Create(C, "if.then", RegGlobalsFn);
419 auto *IfElseBB = BasicBlock::Create(C, "if.else", RegGlobalsFn);
420 auto *SwGlobalBB = BasicBlock::Create(C, "sw.global", RegGlobalsFn);
421 auto *SwManagedBB = BasicBlock::Create(C, "sw.managed", RegGlobalsFn);
422 auto *SwSurfaceBB = BasicBlock::Create(C, "sw.surface", RegGlobalsFn);
423 auto *SwTextureBB = BasicBlock::Create(C, "sw.texture", RegGlobalsFn);
424 auto *IfEndBB = BasicBlock::Create(C, "if.end", RegGlobalsFn);
425 auto *ExitBB = BasicBlock::Create(C, "while.end", RegGlobalsFn);
426
427 auto *EntryCmp = Builder.CreateICmpNE(EntriesB, EntriesE);
428 Builder.CreateCondBr(EntryCmp, EntryBB, ExitBB);
429 Builder.SetInsertPoint(EntryBB);
430 auto *Entry = Builder.CreatePHI(getEntryPtrTy(M), 2, "entry");
431 auto *AddrPtr =
432 Builder.CreateInBoundsGEP(getEntryTy(M), Entry,
433 {ConstantInt::get(getSizeTTy(M), 0),
434 ConstantInt::get(Type::getInt32Ty(C), 0)});
435 auto *Addr = Builder.CreateLoad(Type::getInt8PtrTy(C), AddrPtr, "addr");
436 auto *NamePtr =
437 Builder.CreateInBoundsGEP(getEntryTy(M), Entry,
438 {ConstantInt::get(getSizeTTy(M), 0),
439 ConstantInt::get(Type::getInt32Ty(C), 1)});
440 auto *Name = Builder.CreateLoad(Type::getInt8PtrTy(C), NamePtr, "name");
441 auto *SizePtr =
442 Builder.CreateInBoundsGEP(getEntryTy(M), Entry,
443 {ConstantInt::get(getSizeTTy(M), 0),
444 ConstantInt::get(Type::getInt32Ty(C), 2)});
445 auto *Size = Builder.CreateLoad(getSizeTTy(M), SizePtr, "size");
446 auto *FlagsPtr =
447 Builder.CreateInBoundsGEP(getEntryTy(M), Entry,
448 {ConstantInt::get(getSizeTTy(M), 0),
449 ConstantInt::get(Type::getInt32Ty(C), 3)});
450 auto *Flags = Builder.CreateLoad(Type::getInt32Ty(C), FlagsPtr, "flag");
451 auto *FnCond =
452 Builder.CreateICmpEQ(Size, ConstantInt::getNullValue(getSizeTTy(M)));
453 Builder.CreateCondBr(FnCond, IfThenBB, IfElseBB);
454
455 // Create kernel registration code.
456 Builder.SetInsertPoint(IfThenBB);
457 Builder.CreateCall(RegFunc,
458 {RegGlobalsFn->arg_begin(), Addr, Name, Name,
459 ConstantInt::get(Type::getInt32Ty(C), -1),
460 ConstantPointerNull::get(Type::getInt8PtrTy(C)),
461 ConstantPointerNull::get(Type::getInt8PtrTy(C)),
462 ConstantPointerNull::get(Type::getInt8PtrTy(C)),
463 ConstantPointerNull::get(Type::getInt8PtrTy(C)),
464 ConstantPointerNull::get(Type::getInt32PtrTy(C))});
465 Builder.CreateBr(IfEndBB);
466 Builder.SetInsertPoint(IfElseBB);
467
468 auto *Switch = Builder.CreateSwitch(Flags, IfEndBB);
469 // Create global variable registration code.
470 Builder.SetInsertPoint(SwGlobalBB);
471 Builder.CreateCall(RegVar, {RegGlobalsFn->arg_begin(), Addr, Name, Name,
472 ConstantInt::get(Type::getInt32Ty(C), 0), Size,
473 ConstantInt::get(Type::getInt32Ty(C), 0),
474 ConstantInt::get(Type::getInt32Ty(C), 0)});
475 Builder.CreateBr(IfEndBB);
476 Switch->addCase(Builder.getInt32(OffloadGlobalEntry), SwGlobalBB);
477
478 // Create managed variable registration code.
479 Builder.SetInsertPoint(SwManagedBB);
480 Builder.CreateBr(IfEndBB);
481 Switch->addCase(Builder.getInt32(OffloadGlobalManagedEntry), SwManagedBB);
482
483 // Create surface variable registration code.
484 Builder.SetInsertPoint(SwSurfaceBB);
485 Builder.CreateBr(IfEndBB);
486 Switch->addCase(Builder.getInt32(OffloadGlobalSurfaceEntry), SwSurfaceBB);
487
488 // Create texture variable registration code.
489 Builder.SetInsertPoint(SwTextureBB);
490 Builder.CreateBr(IfEndBB);
491 Switch->addCase(Builder.getInt32(OffloadGlobalTextureEntry), SwTextureBB);
492
493 Builder.SetInsertPoint(IfEndBB);
494 auto *NewEntry = Builder.CreateInBoundsGEP(
495 getEntryTy(M), Entry, ConstantInt::get(getSizeTTy(M), 1));
496 auto *Cmp = Builder.CreateICmpEQ(
497 NewEntry,
498 ConstantExpr::getInBoundsGetElementPtr(
499 ArrayType::get(getEntryTy(M), 0), EntriesE,
500 ArrayRef<Constant *>({ConstantInt::get(getSizeTTy(M), 0),
501 ConstantInt::get(getSizeTTy(M), 0)})));
502 Entry->addIncoming(
503 ConstantExpr::getInBoundsGetElementPtr(
504 ArrayType::get(getEntryTy(M), 0), EntriesB,
505 ArrayRef<Constant *>({ConstantInt::get(getSizeTTy(M), 0),
506 ConstantInt::get(getSizeTTy(M), 0)})),
507 &RegGlobalsFn->getEntryBlock());
508 Entry->addIncoming(NewEntry, IfEndBB);
509 Builder.CreateCondBr(Cmp, ExitBB, EntryBB);
510 Builder.SetInsertPoint(ExitBB);
511 Builder.CreateRetVoid();
512
513 return RegGlobalsFn;
514}
515
516// Create the constructor and destructor to register the fatbinary with the CUDA
517// runtime.
518void createRegisterFatbinFunction(Module &M, GlobalVariable *FatbinDesc,
519 bool IsHIP) {
520 LLVMContext &C = M.getContext();
521 auto *CtorFuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
522 auto *CtorFunc =
523 Function::Create(CtorFuncTy, GlobalValue::InternalLinkage,
524 IsHIP ? ".hip.fatbin_reg" : ".cuda.fatbin_reg", &M);
525 CtorFunc->setSection(".text.startup");
526
527 auto *DtorFuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false);
528 auto *DtorFunc =
529 Function::Create(DtorFuncTy, GlobalValue::InternalLinkage,
530 IsHIP ? ".hip.fatbin_unreg" : ".cuda.fatbin_unreg", &M);
531 DtorFunc->setSection(".text.startup");
532
533 // Get the __cudaRegisterFatBinary function declaration.
534 auto *RegFatTy = FunctionType::get(Type::getInt8PtrTy(C)->getPointerTo(),
535 Type::getInt8PtrTy(C),
536 /*isVarArg*/ false);
537 FunctionCallee RegFatbin = M.getOrInsertFunction(
538 IsHIP ? "__hipRegisterFatBinary" : "__cudaRegisterFatBinary", RegFatTy);
539 // Get the __cudaRegisterFatBinaryEnd function declaration.
540 auto *RegFatEndTy = FunctionType::get(Type::getVoidTy(C),
541 Type::getInt8PtrTy(C)->getPointerTo(),
542 /*isVarArg*/ false);
543 FunctionCallee RegFatbinEnd =
544 M.getOrInsertFunction("__cudaRegisterFatBinaryEnd", RegFatEndTy);
545 // Get the __cudaUnregisterFatBinary function declaration.
546 auto *UnregFatTy = FunctionType::get(Type::getVoidTy(C),
547 Type::getInt8PtrTy(C)->getPointerTo(),
548 /*isVarArg*/ false);
549 FunctionCallee UnregFatbin = M.getOrInsertFunction(
550 IsHIP ? "__hipUnregisterFatBinary" : "__cudaUnregisterFatBinary",
551 UnregFatTy);
552
553 auto *AtExitTy =
554 FunctionType::get(Type::getInt32Ty(C), DtorFuncTy->getPointerTo(),
555 /*isVarArg*/ false);
556 FunctionCallee AtExit = M.getOrInsertFunction("atexit", AtExitTy);
557
558 auto *BinaryHandleGlobal = new llvm::GlobalVariable(
559 M, Type::getInt8PtrTy(C)->getPointerTo(), false,
560 llvm::GlobalValue::InternalLinkage,
561 llvm::ConstantPointerNull::get(Type::getInt8PtrTy(C)->getPointerTo()),
562 IsHIP ? ".hip.binary_handle" : ".cuda.binary_handle");
563
564 // Create the constructor to register this image with the runtime.
565 IRBuilder<> CtorBuilder(BasicBlock::Create(C, "entry", CtorFunc));
566 CallInst *Handle = CtorBuilder.CreateCall(
567 RegFatbin, ConstantExpr::getPointerBitCastOrAddrSpaceCast(
568 FatbinDesc, Type::getInt8PtrTy(C)));
569 CtorBuilder.CreateAlignedStore(
570 Handle, BinaryHandleGlobal,
571 Align(M.getDataLayout().getPointerTypeSize(Type::getInt8PtrTy(C))));
572 CtorBuilder.CreateCall(createRegisterGlobalsFunction(M, IsHIP), Handle);
573 if (!IsHIP)
574 CtorBuilder.CreateCall(RegFatbinEnd, Handle);
575 CtorBuilder.CreateCall(AtExit, DtorFunc);
576 CtorBuilder.CreateRetVoid();
577
578 // Create the destructor to unregister the image with the runtime. We cannot
579 // use a standard global destructor after CUDA 9.2 so this must be called by
580 // `atexit()` intead.
581 IRBuilder<> DtorBuilder(BasicBlock::Create(C, "entry", DtorFunc));
582 LoadInst *BinaryHandle = DtorBuilder.CreateAlignedLoad(
583 Type::getInt8PtrTy(C)->getPointerTo(), BinaryHandleGlobal,
584 Align(M.getDataLayout().getPointerTypeSize(Type::getInt8PtrTy(C))));
585 DtorBuilder.CreateCall(UnregFatbin, BinaryHandle);
586 DtorBuilder.CreateRetVoid();
587
588 // Add this function to constructors.
589 appendToGlobalCtors(M, CtorFunc, /*Priority*/ 1);
590}
591
592} // namespace
593
594Error wrapOpenMPBinaries(Module &M, ArrayRef<ArrayRef<char>> Images) {
595 GlobalVariable *Desc = createBinDesc(M, Images);
596 if (!Desc)
597 return createStringError(inconvertibleErrorCode(),
598 "No binary descriptors created.");
599 createRegisterFunction(M, Desc);
600 createUnregisterFunction(M, Desc);
601 return Error::success();
602}
603
604Error wrapCudaBinary(Module &M, ArrayRef<char> Image) {
605 GlobalVariable *Desc = createFatbinDesc(M, Image, /* IsHIP */ false);
606 if (!Desc)
607 return createStringError(inconvertibleErrorCode(),
608 "No fatinbary section created.");
609
610 createRegisterFatbinFunction(M, Desc, /* IsHIP */ false);
611 return Error::success();
612}
613
614Error wrapHIPBinary(Module &M, ArrayRef<char> Image) {
615 GlobalVariable *Desc = createFatbinDesc(M, Image, /* IsHIP */ true);
616 if (!Desc)
617 return createStringError(inconvertibleErrorCode(),
618 "No fatinbary section created.");
619
620 createRegisterFatbinFunction(M, Desc, /* IsHIP */ true);
621 return Error::success();
622}
623