1 | //===- OffloadWrapper.cpp ---------------------------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "OffloadWrapper.h" |
10 | #include "llvm/ADT/ArrayRef.h" |
11 | #include "llvm/IR/Constants.h" |
12 | #include "llvm/IR/GlobalVariable.h" |
13 | #include "llvm/IR/IRBuilder.h" |
14 | #include "llvm/IR/LLVMContext.h" |
15 | #include "llvm/IR/Module.h" |
16 | #include "llvm/Object/OffloadBinary.h" |
17 | #include "llvm/Support/Error.h" |
18 | #include "llvm/TargetParser/Triple.h" |
19 | #include "llvm/Transforms/Utils/ModuleUtils.h" |
20 | |
21 | using namespace llvm; |
22 | |
23 | namespace { |
/// Magic number that begins the section containing the CUDA fatbinary.
constexpr unsigned CudaFatMagic{0x466243b1};
/// Magic number stored in the wrapper struct for HIP fatbinaries. The bytes
/// 0x48 0x49 0x50 0x46 spell "HIPF" in ASCII.
constexpr unsigned HIPFatMagic{0x48495046};
27 | |
/// Copied from clang/CGCudaRuntime.h.
///
/// Values stored in the `flags` field of an offload entry; they select how a
/// non-kernel entry is registered.
enum OffloadEntryKindFlag : uint32_t {
  /// A plain global entry: a kernel when the entry's size field is zero,
  /// and a variable otherwise.
  OffloadGlobalEntry = 0,
  /// A managed global variable.
  OffloadGlobalManagedEntry = 1,
  /// A surface variable.
  OffloadGlobalSurfaceEntry = 2,
  /// A texture variable.
  OffloadGlobalTextureEntry = 3,
};
40 | |
41 | IntegerType *getSizeTTy(Module &M) { |
42 | LLVMContext &C = M.getContext(); |
43 | switch (M.getDataLayout().getPointerTypeSize(Type::getInt8PtrTy(C))) { |
44 | case 4u: |
45 | return Type::getInt32Ty(C); |
46 | case 8u: |
47 | return Type::getInt64Ty(C); |
48 | } |
49 | llvm_unreachable("unsupported pointer type size" ); |
50 | } |
51 | |
52 | // struct __tgt_offload_entry { |
53 | // void *addr; |
54 | // char *name; |
55 | // size_t size; |
56 | // int32_t flags; |
57 | // int32_t reserved; |
58 | // }; |
59 | StructType *getEntryTy(Module &M) { |
60 | LLVMContext &C = M.getContext(); |
61 | StructType *EntryTy = StructType::getTypeByName(C, "__tgt_offload_entry" ); |
62 | if (!EntryTy) |
63 | EntryTy = StructType::create("__tgt_offload_entry" , Type::getInt8PtrTy(C), |
64 | Type::getInt8PtrTy(C), getSizeTTy(M), |
65 | Type::getInt32Ty(C), Type::getInt32Ty(C)); |
66 | return EntryTy; |
67 | } |
68 | |
69 | PointerType *getEntryPtrTy(Module &M) { |
70 | return PointerType::getUnqual(getEntryTy(M)); |
71 | } |
72 | |
73 | // struct __tgt_device_image { |
74 | // void *ImageStart; |
75 | // void *ImageEnd; |
76 | // __tgt_offload_entry *EntriesBegin; |
77 | // __tgt_offload_entry *EntriesEnd; |
78 | // }; |
79 | StructType *getDeviceImageTy(Module &M) { |
80 | LLVMContext &C = M.getContext(); |
81 | StructType *ImageTy = StructType::getTypeByName(C, "__tgt_device_image" ); |
82 | if (!ImageTy) |
83 | ImageTy = StructType::create("__tgt_device_image" , Type::getInt8PtrTy(C), |
84 | Type::getInt8PtrTy(C), getEntryPtrTy(M), |
85 | getEntryPtrTy(M)); |
86 | return ImageTy; |
87 | } |
88 | |
89 | PointerType *getDeviceImagePtrTy(Module &M) { |
90 | return PointerType::getUnqual(getDeviceImageTy(M)); |
91 | } |
92 | |
93 | // struct __tgt_bin_desc { |
94 | // int32_t NumDeviceImages; |
95 | // __tgt_device_image *DeviceImages; |
96 | // __tgt_offload_entry *HostEntriesBegin; |
97 | // __tgt_offload_entry *HostEntriesEnd; |
98 | // }; |
99 | StructType *getBinDescTy(Module &M) { |
100 | LLVMContext &C = M.getContext(); |
101 | StructType *DescTy = StructType::getTypeByName(C, "__tgt_bin_desc" ); |
102 | if (!DescTy) |
103 | DescTy = StructType::create("__tgt_bin_desc" , Type::getInt32Ty(C), |
104 | getDeviceImagePtrTy(M), getEntryPtrTy(M), |
105 | getEntryPtrTy(M)); |
106 | return DescTy; |
107 | } |
108 | |
109 | PointerType *getBinDescPtrTy(Module &M) { |
110 | return PointerType::getUnqual(getBinDescTy(M)); |
111 | } |
112 | |
113 | /// Creates binary descriptor for the given device images. Binary descriptor |
114 | /// is an object that is passed to the offloading runtime at program startup |
115 | /// and it describes all device images available in the executable or shared |
116 | /// library. It is defined as follows |
117 | /// |
118 | /// __attribute__((visibility("hidden"))) |
119 | /// extern __tgt_offload_entry *__start_omp_offloading_entries; |
120 | /// __attribute__((visibility("hidden"))) |
121 | /// extern __tgt_offload_entry *__stop_omp_offloading_entries; |
122 | /// |
123 | /// static const char Image0[] = { <Bufs.front() contents> }; |
124 | /// ... |
125 | /// static const char ImageN[] = { <Bufs.back() contents> }; |
126 | /// |
127 | /// static const __tgt_device_image Images[] = { |
128 | /// { |
129 | /// Image0, /*ImageStart*/ |
130 | /// Image0 + sizeof(Image0), /*ImageEnd*/ |
131 | /// __start_omp_offloading_entries, /*EntriesBegin*/ |
132 | /// __stop_omp_offloading_entries /*EntriesEnd*/ |
133 | /// }, |
134 | /// ... |
135 | /// { |
136 | /// ImageN, /*ImageStart*/ |
137 | /// ImageN + sizeof(ImageN), /*ImageEnd*/ |
138 | /// __start_omp_offloading_entries, /*EntriesBegin*/ |
139 | /// __stop_omp_offloading_entries /*EntriesEnd*/ |
140 | /// } |
141 | /// }; |
142 | /// |
143 | /// static const __tgt_bin_desc BinDesc = { |
144 | /// sizeof(Images) / sizeof(Images[0]), /*NumDeviceImages*/ |
145 | /// Images, /*DeviceImages*/ |
146 | /// __start_omp_offloading_entries, /*HostEntriesBegin*/ |
147 | /// __stop_omp_offloading_entries /*HostEntriesEnd*/ |
148 | /// }; |
149 | /// |
150 | /// Global variable that represents BinDesc is returned. |
151 | GlobalVariable *createBinDesc(Module &M, ArrayRef<ArrayRef<char>> Bufs) { |
152 | LLVMContext &C = M.getContext(); |
153 | // Create external begin/end symbols for the offload entries table. |
154 | auto *EntriesB = new GlobalVariable( |
155 | M, getEntryTy(M), /*isConstant*/ true, GlobalValue::ExternalLinkage, |
156 | /*Initializer*/ nullptr, "__start_omp_offloading_entries" ); |
157 | EntriesB->setVisibility(GlobalValue::HiddenVisibility); |
158 | auto *EntriesE = new GlobalVariable( |
159 | M, getEntryTy(M), /*isConstant*/ true, GlobalValue::ExternalLinkage, |
160 | /*Initializer*/ nullptr, "__stop_omp_offloading_entries" ); |
161 | EntriesE->setVisibility(GlobalValue::HiddenVisibility); |
162 | |
163 | // We assume that external begin/end symbols that we have created above will |
164 | // be defined by the linker. But linker will do that only if linker inputs |
165 | // have section with "omp_offloading_entries" name which is not guaranteed. |
166 | // So, we just create dummy zero sized object in the offload entries section |
167 | // to force linker to define those symbols. |
168 | auto *DummyInit = |
169 | ConstantAggregateZero::get(ArrayType::get(getEntryTy(M), 0u)); |
170 | auto *DummyEntry = new GlobalVariable( |
171 | M, DummyInit->getType(), true, GlobalVariable::ExternalLinkage, DummyInit, |
172 | "__dummy.omp_offloading.entry" ); |
173 | DummyEntry->setSection("omp_offloading_entries" ); |
174 | DummyEntry->setVisibility(GlobalValue::HiddenVisibility); |
175 | |
176 | auto *Zero = ConstantInt::get(getSizeTTy(M), 0u); |
177 | Constant *ZeroZero[] = {Zero, Zero}; |
178 | |
179 | // Create initializer for the images array. |
180 | SmallVector<Constant *, 4u> ImagesInits; |
181 | ImagesInits.reserve(Bufs.size()); |
182 | for (ArrayRef<char> Buf : Bufs) { |
183 | auto *Data = ConstantDataArray::get(C, Buf); |
184 | auto *Image = new GlobalVariable(M, Data->getType(), /*isConstant*/ true, |
185 | GlobalVariable::InternalLinkage, Data, |
186 | ".omp_offloading.device_image" ); |
187 | Image->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); |
188 | Image->setSection(".llvm.offloading" ); |
189 | Image->setAlignment(Align(object::OffloadBinary::getAlignment())); |
190 | |
191 | auto *Size = ConstantInt::get(getSizeTTy(M), Buf.size()); |
192 | Constant *ZeroSize[] = {Zero, Size}; |
193 | |
194 | auto *ImageB = |
195 | ConstantExpr::getGetElementPtr(Image->getValueType(), Image, ZeroZero); |
196 | auto *ImageE = |
197 | ConstantExpr::getGetElementPtr(Image->getValueType(), Image, ZeroSize); |
198 | |
199 | ImagesInits.push_back(ConstantStruct::get(getDeviceImageTy(M), ImageB, |
200 | ImageE, EntriesB, EntriesE)); |
201 | } |
202 | |
203 | // Then create images array. |
204 | auto *ImagesData = ConstantArray::get( |
205 | ArrayType::get(getDeviceImageTy(M), ImagesInits.size()), ImagesInits); |
206 | |
207 | auto *Images = |
208 | new GlobalVariable(M, ImagesData->getType(), /*isConstant*/ true, |
209 | GlobalValue::InternalLinkage, ImagesData, |
210 | ".omp_offloading.device_images" ); |
211 | Images->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); |
212 | |
213 | auto *ImagesB = |
214 | ConstantExpr::getGetElementPtr(Images->getValueType(), Images, ZeroZero); |
215 | |
216 | // And finally create the binary descriptor object. |
217 | auto *DescInit = ConstantStruct::get( |
218 | getBinDescTy(M), |
219 | ConstantInt::get(Type::getInt32Ty(C), ImagesInits.size()), ImagesB, |
220 | EntriesB, EntriesE); |
221 | |
222 | return new GlobalVariable(M, DescInit->getType(), /*isConstant*/ true, |
223 | GlobalValue::InternalLinkage, DescInit, |
224 | ".omp_offloading.descriptor" ); |
225 | } |
226 | |
227 | void createRegisterFunction(Module &M, GlobalVariable *BinDesc) { |
228 | LLVMContext &C = M.getContext(); |
229 | auto *FuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false); |
230 | auto *Func = Function::Create(FuncTy, GlobalValue::InternalLinkage, |
231 | ".omp_offloading.descriptor_reg" , &M); |
232 | Func->setSection(".text.startup" ); |
233 | |
234 | // Get __tgt_register_lib function declaration. |
235 | auto *RegFuncTy = FunctionType::get(Type::getVoidTy(C), getBinDescPtrTy(M), |
236 | /*isVarArg*/ false); |
237 | FunctionCallee RegFuncC = |
238 | M.getOrInsertFunction("__tgt_register_lib" , RegFuncTy); |
239 | |
240 | // Construct function body |
241 | IRBuilder<> Builder(BasicBlock::Create(C, "entry" , Func)); |
242 | Builder.CreateCall(RegFuncC, BinDesc); |
243 | Builder.CreateRetVoid(); |
244 | |
245 | // Add this function to constructors. |
246 | // Set priority to 1 so that __tgt_register_lib is executed AFTER |
247 | // __tgt_register_requires (we want to know what requirements have been |
248 | // asked for before we load a libomptarget plugin so that by the time the |
249 | // plugin is loaded it can report how many devices there are which can |
250 | // satisfy these requirements). |
251 | appendToGlobalCtors(M, Func, /*Priority*/ 1); |
252 | } |
253 | |
254 | void createUnregisterFunction(Module &M, GlobalVariable *BinDesc) { |
255 | LLVMContext &C = M.getContext(); |
256 | auto *FuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false); |
257 | auto *Func = Function::Create(FuncTy, GlobalValue::InternalLinkage, |
258 | ".omp_offloading.descriptor_unreg" , &M); |
259 | Func->setSection(".text.startup" ); |
260 | |
261 | // Get __tgt_unregister_lib function declaration. |
262 | auto *UnRegFuncTy = FunctionType::get(Type::getVoidTy(C), getBinDescPtrTy(M), |
263 | /*isVarArg*/ false); |
264 | FunctionCallee UnRegFuncC = |
265 | M.getOrInsertFunction("__tgt_unregister_lib" , UnRegFuncTy); |
266 | |
267 | // Construct function body |
268 | IRBuilder<> Builder(BasicBlock::Create(C, "entry" , Func)); |
269 | Builder.CreateCall(UnRegFuncC, BinDesc); |
270 | Builder.CreateRetVoid(); |
271 | |
272 | // Add this function to global destructors. |
273 | // Match priority of __tgt_register_lib |
274 | appendToGlobalDtors(M, Func, /*Priority*/ 1); |
275 | } |
276 | |
277 | // struct fatbin_wrapper { |
278 | // int32_t magic; |
279 | // int32_t version; |
280 | // void *image; |
281 | // void *reserved; |
282 | //}; |
283 | StructType *getFatbinWrapperTy(Module &M) { |
284 | LLVMContext &C = M.getContext(); |
285 | StructType *FatbinTy = StructType::getTypeByName(C, "fatbin_wrapper" ); |
286 | if (!FatbinTy) |
287 | FatbinTy = StructType::create("fatbin_wrapper" , Type::getInt32Ty(C), |
288 | Type::getInt32Ty(C), Type::getInt8PtrTy(C), |
289 | Type::getInt8PtrTy(C)); |
290 | return FatbinTy; |
291 | } |
292 | |
293 | /// Embed the image \p Image into the module \p M so it can be found by the |
294 | /// runtime. |
295 | GlobalVariable *createFatbinDesc(Module &M, ArrayRef<char> Image, bool IsHIP) { |
296 | LLVMContext &C = M.getContext(); |
297 | llvm::Type *Int8PtrTy = Type::getInt8PtrTy(C); |
298 | llvm::Triple Triple = llvm::Triple(M.getTargetTriple()); |
299 | |
300 | // Create the global string containing the fatbinary. |
301 | StringRef FatbinConstantSection = |
302 | IsHIP ? ".hip_fatbin" |
303 | : (Triple.isMacOSX() ? "__NV_CUDA,__nv_fatbin" : ".nv_fatbin" ); |
304 | auto *Data = ConstantDataArray::get(C, Image); |
305 | auto *Fatbin = new GlobalVariable(M, Data->getType(), /*isConstant*/ true, |
306 | GlobalVariable::InternalLinkage, Data, |
307 | ".fatbin_image" ); |
308 | Fatbin->setSection(FatbinConstantSection); |
309 | |
310 | // Create the fatbinary wrapper |
311 | StringRef FatbinWrapperSection = IsHIP ? ".hipFatBinSegment" |
312 | : Triple.isMacOSX() ? "__NV_CUDA,__fatbin" |
313 | : ".nvFatBinSegment" ; |
314 | Constant *FatbinWrapper[] = { |
315 | ConstantInt::get(Type::getInt32Ty(C), IsHIP ? HIPFatMagic : CudaFatMagic), |
316 | ConstantInt::get(Type::getInt32Ty(C), 1), |
317 | ConstantExpr::getPointerBitCastOrAddrSpaceCast(Fatbin, Int8PtrTy), |
318 | ConstantPointerNull::get(Type::getInt8PtrTy(C))}; |
319 | |
320 | Constant *FatbinInitializer = |
321 | ConstantStruct::get(getFatbinWrapperTy(M), FatbinWrapper); |
322 | |
323 | auto *FatbinDesc = |
324 | new GlobalVariable(M, getFatbinWrapperTy(M), |
325 | /*isConstant*/ true, GlobalValue::InternalLinkage, |
326 | FatbinInitializer, ".fatbin_wrapper" ); |
327 | FatbinDesc->setSection(FatbinWrapperSection); |
328 | FatbinDesc->setAlignment(Align(8)); |
329 | |
330 | // We create a dummy entry to ensure the linker will define the begin / end |
331 | // symbols. The CUDA runtime should ignore the null address if we attempt to |
332 | // register it. |
333 | auto *DummyInit = |
334 | ConstantAggregateZero::get(ArrayType::get(getEntryTy(M), 0u)); |
335 | auto *DummyEntry = new GlobalVariable( |
336 | M, DummyInit->getType(), true, GlobalVariable::ExternalLinkage, DummyInit, |
337 | IsHIP ? "__dummy.hip_offloading.entry" : "__dummy.cuda_offloading.entry" ); |
338 | DummyEntry->setVisibility(GlobalValue::HiddenVisibility); |
339 | DummyEntry->setSection(IsHIP ? "hip_offloading_entries" |
340 | : "cuda_offloading_entries" ); |
341 | |
342 | return FatbinDesc; |
343 | } |
344 | |
345 | /// Create the register globals function. We will iterate all of the offloading |
346 | /// entries stored at the begin / end symbols and register them according to |
347 | /// their type. This creates the following function in IR: |
348 | /// |
349 | /// extern struct __tgt_offload_entry __start_cuda_offloading_entries; |
350 | /// extern struct __tgt_offload_entry __stop_cuda_offloading_entries; |
351 | /// |
352 | /// extern void __cudaRegisterFunction(void **, void *, void *, void *, int, |
353 | /// void *, void *, void *, void *, int *); |
354 | /// extern void __cudaRegisterVar(void **, void *, void *, void *, int32_t, |
355 | /// int64_t, int32_t, int32_t); |
356 | /// |
357 | /// void __cudaRegisterTest(void **fatbinHandle) { |
358 | /// for (struct __tgt_offload_entry *entry = &__start_cuda_offloading_entries; |
359 | /// entry != &__stop_cuda_offloading_entries; ++entry) { |
360 | /// if (!entry->size) |
361 | /// __cudaRegisterFunction(fatbinHandle, entry->addr, entry->name, |
362 | /// entry->name, -1, 0, 0, 0, 0, 0); |
363 | /// else |
364 | /// __cudaRegisterVar(fatbinHandle, entry->addr, entry->name, entry->name, |
365 | /// 0, entry->size, 0, 0); |
366 | /// } |
367 | /// } |
368 | Function *createRegisterGlobalsFunction(Module &M, bool IsHIP) { |
369 | LLVMContext &C = M.getContext(); |
370 | // Get the __cudaRegisterFunction function declaration. |
371 | auto *RegFuncTy = FunctionType::get( |
372 | Type::getInt32Ty(C), |
373 | {Type::getInt8PtrTy(C)->getPointerTo(), Type::getInt8PtrTy(C), |
374 | Type::getInt8PtrTy(C), Type::getInt8PtrTy(C), Type::getInt32Ty(C), |
375 | Type::getInt8PtrTy(C), Type::getInt8PtrTy(C), Type::getInt8PtrTy(C), |
376 | Type::getInt8PtrTy(C), Type::getInt32PtrTy(C)}, |
377 | /*isVarArg*/ false); |
378 | FunctionCallee RegFunc = M.getOrInsertFunction( |
379 | IsHIP ? "__hipRegisterFunction" : "__cudaRegisterFunction" , RegFuncTy); |
380 | |
381 | // Get the __cudaRegisterVar function declaration. |
382 | auto *RegVarTy = FunctionType::get( |
383 | Type::getVoidTy(C), |
384 | {Type::getInt8PtrTy(C)->getPointerTo(), Type::getInt8PtrTy(C), |
385 | Type::getInt8PtrTy(C), Type::getInt8PtrTy(C), Type::getInt32Ty(C), |
386 | getSizeTTy(M), Type::getInt32Ty(C), Type::getInt32Ty(C)}, |
387 | /*isVarArg*/ false); |
388 | FunctionCallee RegVar = M.getOrInsertFunction( |
389 | IsHIP ? "__hipRegisterVar" : "__cudaRegisterVar" , RegVarTy); |
390 | |
391 | // Create the references to the start / stop symbols defined by the linker. |
392 | auto *EntriesB = |
393 | new GlobalVariable(M, ArrayType::get(getEntryTy(M), 0), |
394 | /*isConstant*/ true, GlobalValue::ExternalLinkage, |
395 | /*Initializer*/ nullptr, |
396 | IsHIP ? "__start_hip_offloading_entries" |
397 | : "__start_cuda_offloading_entries" ); |
398 | EntriesB->setVisibility(GlobalValue::HiddenVisibility); |
399 | auto *EntriesE = |
400 | new GlobalVariable(M, ArrayType::get(getEntryTy(M), 0), |
401 | /*isConstant*/ true, GlobalValue::ExternalLinkage, |
402 | /*Initializer*/ nullptr, |
403 | IsHIP ? "__stop_hip_offloading_entries" |
404 | : "__stop_cuda_offloading_entries" ); |
405 | EntriesE->setVisibility(GlobalValue::HiddenVisibility); |
406 | |
407 | auto *RegGlobalsTy = FunctionType::get(Type::getVoidTy(C), |
408 | Type::getInt8PtrTy(C)->getPointerTo(), |
409 | /*isVarArg*/ false); |
410 | auto *RegGlobalsFn = |
411 | Function::Create(RegGlobalsTy, GlobalValue::InternalLinkage, |
412 | IsHIP ? ".hip.globals_reg" : ".cuda.globals_reg" , &M); |
413 | RegGlobalsFn->setSection(".text.startup" ); |
414 | |
415 | // Create the loop to register all the entries. |
416 | IRBuilder<> Builder(BasicBlock::Create(C, "entry" , RegGlobalsFn)); |
417 | auto *EntryBB = BasicBlock::Create(C, "while.entry" , RegGlobalsFn); |
418 | auto *IfThenBB = BasicBlock::Create(C, "if.then" , RegGlobalsFn); |
419 | auto *IfElseBB = BasicBlock::Create(C, "if.else" , RegGlobalsFn); |
420 | auto *SwGlobalBB = BasicBlock::Create(C, "sw.global" , RegGlobalsFn); |
421 | auto *SwManagedBB = BasicBlock::Create(C, "sw.managed" , RegGlobalsFn); |
422 | auto *SwSurfaceBB = BasicBlock::Create(C, "sw.surface" , RegGlobalsFn); |
423 | auto *SwTextureBB = BasicBlock::Create(C, "sw.texture" , RegGlobalsFn); |
424 | auto *IfEndBB = BasicBlock::Create(C, "if.end" , RegGlobalsFn); |
425 | auto *ExitBB = BasicBlock::Create(C, "while.end" , RegGlobalsFn); |
426 | |
427 | auto *EntryCmp = Builder.CreateICmpNE(EntriesB, EntriesE); |
428 | Builder.CreateCondBr(EntryCmp, EntryBB, ExitBB); |
429 | Builder.SetInsertPoint(EntryBB); |
430 | auto *Entry = Builder.CreatePHI(getEntryPtrTy(M), 2, "entry" ); |
431 | auto *AddrPtr = |
432 | Builder.CreateInBoundsGEP(getEntryTy(M), Entry, |
433 | {ConstantInt::get(getSizeTTy(M), 0), |
434 | ConstantInt::get(Type::getInt32Ty(C), 0)}); |
435 | auto *Addr = Builder.CreateLoad(Type::getInt8PtrTy(C), AddrPtr, "addr" ); |
436 | auto *NamePtr = |
437 | Builder.CreateInBoundsGEP(getEntryTy(M), Entry, |
438 | {ConstantInt::get(getSizeTTy(M), 0), |
439 | ConstantInt::get(Type::getInt32Ty(C), 1)}); |
440 | auto *Name = Builder.CreateLoad(Type::getInt8PtrTy(C), NamePtr, "name" ); |
441 | auto *SizePtr = |
442 | Builder.CreateInBoundsGEP(getEntryTy(M), Entry, |
443 | {ConstantInt::get(getSizeTTy(M), 0), |
444 | ConstantInt::get(Type::getInt32Ty(C), 2)}); |
445 | auto *Size = Builder.CreateLoad(getSizeTTy(M), SizePtr, "size" ); |
446 | auto *FlagsPtr = |
447 | Builder.CreateInBoundsGEP(getEntryTy(M), Entry, |
448 | {ConstantInt::get(getSizeTTy(M), 0), |
449 | ConstantInt::get(Type::getInt32Ty(C), 3)}); |
450 | auto *Flags = Builder.CreateLoad(Type::getInt32Ty(C), FlagsPtr, "flag" ); |
451 | auto *FnCond = |
452 | Builder.CreateICmpEQ(Size, ConstantInt::getNullValue(getSizeTTy(M))); |
453 | Builder.CreateCondBr(FnCond, IfThenBB, IfElseBB); |
454 | |
455 | // Create kernel registration code. |
456 | Builder.SetInsertPoint(IfThenBB); |
457 | Builder.CreateCall(RegFunc, |
458 | {RegGlobalsFn->arg_begin(), Addr, Name, Name, |
459 | ConstantInt::get(Type::getInt32Ty(C), -1), |
460 | ConstantPointerNull::get(Type::getInt8PtrTy(C)), |
461 | ConstantPointerNull::get(Type::getInt8PtrTy(C)), |
462 | ConstantPointerNull::get(Type::getInt8PtrTy(C)), |
463 | ConstantPointerNull::get(Type::getInt8PtrTy(C)), |
464 | ConstantPointerNull::get(Type::getInt32PtrTy(C))}); |
465 | Builder.CreateBr(IfEndBB); |
466 | Builder.SetInsertPoint(IfElseBB); |
467 | |
468 | auto *Switch = Builder.CreateSwitch(Flags, IfEndBB); |
469 | // Create global variable registration code. |
470 | Builder.SetInsertPoint(SwGlobalBB); |
471 | Builder.CreateCall(RegVar, {RegGlobalsFn->arg_begin(), Addr, Name, Name, |
472 | ConstantInt::get(Type::getInt32Ty(C), 0), Size, |
473 | ConstantInt::get(Type::getInt32Ty(C), 0), |
474 | ConstantInt::get(Type::getInt32Ty(C), 0)}); |
475 | Builder.CreateBr(IfEndBB); |
476 | Switch->addCase(Builder.getInt32(OffloadGlobalEntry), SwGlobalBB); |
477 | |
478 | // Create managed variable registration code. |
479 | Builder.SetInsertPoint(SwManagedBB); |
480 | Builder.CreateBr(IfEndBB); |
481 | Switch->addCase(Builder.getInt32(OffloadGlobalManagedEntry), SwManagedBB); |
482 | |
483 | // Create surface variable registration code. |
484 | Builder.SetInsertPoint(SwSurfaceBB); |
485 | Builder.CreateBr(IfEndBB); |
486 | Switch->addCase(Builder.getInt32(OffloadGlobalSurfaceEntry), SwSurfaceBB); |
487 | |
488 | // Create texture variable registration code. |
489 | Builder.SetInsertPoint(SwTextureBB); |
490 | Builder.CreateBr(IfEndBB); |
491 | Switch->addCase(Builder.getInt32(OffloadGlobalTextureEntry), SwTextureBB); |
492 | |
493 | Builder.SetInsertPoint(IfEndBB); |
494 | auto *NewEntry = Builder.CreateInBoundsGEP( |
495 | getEntryTy(M), Entry, ConstantInt::get(getSizeTTy(M), 1)); |
496 | auto *Cmp = Builder.CreateICmpEQ( |
497 | NewEntry, |
498 | ConstantExpr::getInBoundsGetElementPtr( |
499 | ArrayType::get(getEntryTy(M), 0), EntriesE, |
500 | ArrayRef<Constant *>({ConstantInt::get(getSizeTTy(M), 0), |
501 | ConstantInt::get(getSizeTTy(M), 0)}))); |
502 | Entry->addIncoming( |
503 | ConstantExpr::getInBoundsGetElementPtr( |
504 | ArrayType::get(getEntryTy(M), 0), EntriesB, |
505 | ArrayRef<Constant *>({ConstantInt::get(getSizeTTy(M), 0), |
506 | ConstantInt::get(getSizeTTy(M), 0)})), |
507 | &RegGlobalsFn->getEntryBlock()); |
508 | Entry->addIncoming(NewEntry, IfEndBB); |
509 | Builder.CreateCondBr(Cmp, ExitBB, EntryBB); |
510 | Builder.SetInsertPoint(ExitBB); |
511 | Builder.CreateRetVoid(); |
512 | |
513 | return RegGlobalsFn; |
514 | } |
515 | |
516 | // Create the constructor and destructor to register the fatbinary with the CUDA |
517 | // runtime. |
518 | void createRegisterFatbinFunction(Module &M, GlobalVariable *FatbinDesc, |
519 | bool IsHIP) { |
520 | LLVMContext &C = M.getContext(); |
521 | auto *CtorFuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false); |
522 | auto *CtorFunc = |
523 | Function::Create(CtorFuncTy, GlobalValue::InternalLinkage, |
524 | IsHIP ? ".hip.fatbin_reg" : ".cuda.fatbin_reg" , &M); |
525 | CtorFunc->setSection(".text.startup" ); |
526 | |
527 | auto *DtorFuncTy = FunctionType::get(Type::getVoidTy(C), /*isVarArg*/ false); |
528 | auto *DtorFunc = |
529 | Function::Create(DtorFuncTy, GlobalValue::InternalLinkage, |
530 | IsHIP ? ".hip.fatbin_unreg" : ".cuda.fatbin_unreg" , &M); |
531 | DtorFunc->setSection(".text.startup" ); |
532 | |
533 | // Get the __cudaRegisterFatBinary function declaration. |
534 | auto *RegFatTy = FunctionType::get(Type::getInt8PtrTy(C)->getPointerTo(), |
535 | Type::getInt8PtrTy(C), |
536 | /*isVarArg*/ false); |
537 | FunctionCallee RegFatbin = M.getOrInsertFunction( |
538 | IsHIP ? "__hipRegisterFatBinary" : "__cudaRegisterFatBinary" , RegFatTy); |
539 | // Get the __cudaRegisterFatBinaryEnd function declaration. |
540 | auto *RegFatEndTy = FunctionType::get(Type::getVoidTy(C), |
541 | Type::getInt8PtrTy(C)->getPointerTo(), |
542 | /*isVarArg*/ false); |
543 | FunctionCallee RegFatbinEnd = |
544 | M.getOrInsertFunction("__cudaRegisterFatBinaryEnd" , RegFatEndTy); |
545 | // Get the __cudaUnregisterFatBinary function declaration. |
546 | auto *UnregFatTy = FunctionType::get(Type::getVoidTy(C), |
547 | Type::getInt8PtrTy(C)->getPointerTo(), |
548 | /*isVarArg*/ false); |
549 | FunctionCallee UnregFatbin = M.getOrInsertFunction( |
550 | IsHIP ? "__hipUnregisterFatBinary" : "__cudaUnregisterFatBinary" , |
551 | UnregFatTy); |
552 | |
553 | auto *AtExitTy = |
554 | FunctionType::get(Type::getInt32Ty(C), DtorFuncTy->getPointerTo(), |
555 | /*isVarArg*/ false); |
556 | FunctionCallee AtExit = M.getOrInsertFunction("atexit" , AtExitTy); |
557 | |
558 | auto *BinaryHandleGlobal = new llvm::GlobalVariable( |
559 | M, Type::getInt8PtrTy(C)->getPointerTo(), false, |
560 | llvm::GlobalValue::InternalLinkage, |
561 | llvm::ConstantPointerNull::get(Type::getInt8PtrTy(C)->getPointerTo()), |
562 | IsHIP ? ".hip.binary_handle" : ".cuda.binary_handle" ); |
563 | |
564 | // Create the constructor to register this image with the runtime. |
565 | IRBuilder<> CtorBuilder(BasicBlock::Create(C, "entry" , CtorFunc)); |
566 | CallInst *Handle = CtorBuilder.CreateCall( |
567 | RegFatbin, ConstantExpr::getPointerBitCastOrAddrSpaceCast( |
568 | FatbinDesc, Type::getInt8PtrTy(C))); |
569 | CtorBuilder.CreateAlignedStore( |
570 | Handle, BinaryHandleGlobal, |
571 | Align(M.getDataLayout().getPointerTypeSize(Type::getInt8PtrTy(C)))); |
572 | CtorBuilder.CreateCall(createRegisterGlobalsFunction(M, IsHIP), Handle); |
573 | if (!IsHIP) |
574 | CtorBuilder.CreateCall(RegFatbinEnd, Handle); |
575 | CtorBuilder.CreateCall(AtExit, DtorFunc); |
576 | CtorBuilder.CreateRetVoid(); |
577 | |
578 | // Create the destructor to unregister the image with the runtime. We cannot |
579 | // use a standard global destructor after CUDA 9.2 so this must be called by |
580 | // `atexit()` intead. |
581 | IRBuilder<> DtorBuilder(BasicBlock::Create(C, "entry" , DtorFunc)); |
582 | LoadInst *BinaryHandle = DtorBuilder.CreateAlignedLoad( |
583 | Type::getInt8PtrTy(C)->getPointerTo(), BinaryHandleGlobal, |
584 | Align(M.getDataLayout().getPointerTypeSize(Type::getInt8PtrTy(C)))); |
585 | DtorBuilder.CreateCall(UnregFatbin, BinaryHandle); |
586 | DtorBuilder.CreateRetVoid(); |
587 | |
588 | // Add this function to constructors. |
589 | appendToGlobalCtors(M, CtorFunc, /*Priority*/ 1); |
590 | } |
591 | |
592 | } // namespace |
593 | |
594 | Error wrapOpenMPBinaries(Module &M, ArrayRef<ArrayRef<char>> Images) { |
595 | GlobalVariable *Desc = createBinDesc(M, Images); |
596 | if (!Desc) |
597 | return createStringError(inconvertibleErrorCode(), |
598 | "No binary descriptors created." ); |
599 | createRegisterFunction(M, Desc); |
600 | createUnregisterFunction(M, Desc); |
601 | return Error::success(); |
602 | } |
603 | |
604 | Error wrapCudaBinary(Module &M, ArrayRef<char> Image) { |
605 | GlobalVariable *Desc = createFatbinDesc(M, Image, /* IsHIP */ false); |
606 | if (!Desc) |
607 | return createStringError(inconvertibleErrorCode(), |
608 | "No fatinbary section created." ); |
609 | |
610 | createRegisterFatbinFunction(M, Desc, /* IsHIP */ false); |
611 | return Error::success(); |
612 | } |
613 | |
614 | Error wrapHIPBinary(Module &M, ArrayRef<char> Image) { |
615 | GlobalVariable *Desc = createFatbinDesc(M, Image, /* IsHIP */ true); |
616 | if (!Desc) |
617 | return createStringError(inconvertibleErrorCode(), |
618 | "No fatinbary section created." ); |
619 | |
620 | createRegisterFatbinFunction(M, Desc, /* IsHIP */ true); |
621 | return Error::success(); |
622 | } |
623 | |