1// Copyright (c) 2018 The Khronos Group Inc.
2// Copyright (c) 2018 Valve Corporation
3// Copyright (c) 2018 LunarG Inc.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16
17#include "inst_bindless_check_pass.h"
18
namespace {

// Operand indices used with Instruction::GetSingleWordInOperand() and
// Instruction::SetInOperand() throughout this pass, grouped by the
// instruction they are applied to.

// Image sample / gather / fetch / read / write / query instructions: the
// image operand (see GetImageId).
constexpr int kSpvImageSampleImageIdInIdx = 0;

// OpSampledImage: the image and sampler operands.
constexpr int kSpvSampledImageImageIdInIdx = 0;
constexpr int kSpvSampledImageSamplerIdInIdx = 1;

// OpImage: the sampled-image operand.
constexpr int kSpvImageSampledImageIdInIdx = 0;

// OpLoad / OpStore: the pointer operand.
constexpr int kSpvLoadPtrIdInIdx = 0;

// OpAccessChain: the base pointer and the first index.
constexpr int kSpvAccessChainBaseIdInIdx = 0;
constexpr int kSpvAccessChainIndex0IdInIdx = 1;

// OpTypePointer: the pointee type.
constexpr int kSpvTypePointerTypeIdInIdx = 1;

// OpTypeArray: the length operand.
constexpr int kSpvTypeArrayLengthIdInIdx = 1;

// OpConstant: the first value word.
constexpr int kSpvConstantValueInIdx = 0;

// OpVariable: the storage class.
constexpr int kSpvVariableStorageClassInIdx = 0;

}  // anonymous namespace
35
36namespace spvtools {
37namespace opt {
38
39uint32_t InstBindlessCheckPass::GenDebugReadLength(
40 uint32_t var_id, InstructionBuilder* builder) {
41 uint32_t desc_set_idx =
42 var2desc_set_[var_id] + kDebugInputBindlessOffsetLengths;
43 uint32_t desc_set_idx_id = builder->GetUintConstantId(desc_set_idx);
44 uint32_t binding_idx_id = builder->GetUintConstantId(var2binding_[var_id]);
45 return GenDebugDirectRead({desc_set_idx_id, binding_idx_id}, builder);
46}
47
48uint32_t InstBindlessCheckPass::GenDebugReadInit(uint32_t var_id,
49 uint32_t desc_idx_id,
50 InstructionBuilder* builder) {
51 uint32_t desc_set_base_id =
52 builder->GetUintConstantId(kDebugInputBindlessInitOffset);
53 uint32_t desc_set_idx_id = builder->GetUintConstantId(var2desc_set_[var_id]);
54 uint32_t binding_idx_id = builder->GetUintConstantId(var2binding_[var_id]);
55 uint32_t u_desc_idx_id = GenUintCastCode(desc_idx_id, builder);
56 return GenDebugDirectRead(
57 {desc_set_base_id, desc_set_idx_id, binding_idx_id, u_desc_idx_id},
58 builder);
59}
60
// Emits, at |builder|'s insertion point, a copy of the reference described by
// |ref|. Returns the result id of the copy, or 0 if the original reference has
// no result id (e.g. an OpStore).
//
// For image-based references (ref->desc_load_id != 0), the descriptor load is
// re-issued first. If the reference went through an OpSampledImage or OpImage
// (ref->image_id != 0), that instruction is rebuilt on top of the new load.
// The cloned reference's image in-operand is then redirected to the newest of
// these. Each new instruction copies its original's uid2offset_ entry, and
// decorations are cloned onto every new result id.
uint32_t InstBindlessCheckPass::CloneOriginalReference(
    ref_analysis* ref, InstructionBuilder* builder) {
  // If original is image based, start by cloning descriptor load
  uint32_t new_image_id = 0;
  if (ref->desc_load_id != 0) {
    Instruction* desc_load_inst = get_def_use_mgr()->GetDef(ref->desc_load_id);
    Instruction* new_load_inst = builder->AddLoad(
        desc_load_inst->type_id(),
        desc_load_inst->GetSingleWordInOperand(kSpvLoadPtrIdInIdx));
    // Carry over the original's uid2offset_ entry (the offset later handed to
    // GenDebugStreamWrite for error reporting).
    uid2offset_[new_load_inst->unique_id()] =
        uid2offset_[desc_load_inst->unique_id()];
    uint32_t new_load_id = new_load_inst->result_id();
    get_decoration_mgr()->CloneDecorations(desc_load_inst->result_id(),
                                           new_load_id);
    // Unless an image wrapper is rebuilt below, the reference consumes the
    // new load directly.
    new_image_id = new_load_id;
    // Clone Image/SampledImage with new load, if needed
    if (ref->image_id != 0) {
      Instruction* image_inst = get_def_use_mgr()->GetDef(ref->image_id);
      if (image_inst->opcode() == SpvOp::SpvOpSampledImage) {
        // Keep the original sampler operand; only the image is replaced.
        Instruction* new_image_inst = builder->AddBinaryOp(
            image_inst->type_id(), SpvOpSampledImage, new_load_id,
            image_inst->GetSingleWordInOperand(kSpvSampledImageSamplerIdInIdx));
        uid2offset_[new_image_inst->unique_id()] =
            uid2offset_[image_inst->unique_id()];
        new_image_id = new_image_inst->result_id();
      } else {
        assert(image_inst->opcode() == SpvOp::SpvOpImage &&
               "expecting OpImage");
        Instruction* new_image_inst =
            builder->AddUnaryOp(image_inst->type_id(), SpvOpImage, new_load_id);
        uid2offset_[new_image_inst->unique_id()] =
            uid2offset_[image_inst->unique_id()];
        new_image_id = new_image_inst->result_id();
      }
      get_decoration_mgr()->CloneDecorations(ref->image_id, new_image_id);
    }
  }
  // Clone original reference
  std::unique_ptr<Instruction> new_ref_inst(ref->ref_inst->Clone(context()));
  uint32_t ref_result_id = ref->ref_inst->result_id();
  uint32_t new_ref_id = 0;
  // Only references that produce a value get a fresh result id.
  if (ref_result_id != 0) {
    new_ref_id = TakeNextId();
    new_ref_inst->SetResultId(new_ref_id);
  }
  // Update new ref with new image if created
  if (new_image_id != 0)
    new_ref_inst->SetInOperand(kSpvImageSampleImageIdInIdx, {new_image_id});
  // Register new reference and add to new block
  Instruction* added_inst = builder->AddInstruction(std::move(new_ref_inst));
  uid2offset_[added_inst->unique_id()] =
      uid2offset_[ref->ref_inst->unique_id()];
  if (new_ref_id != 0)
    get_decoration_mgr()->CloneDecorations(ref_result_id, new_ref_id);
  return new_ref_id;
}
117
118uint32_t InstBindlessCheckPass::GetImageId(Instruction* inst) {
119 switch (inst->opcode()) {
120 case SpvOp::SpvOpImageSampleImplicitLod:
121 case SpvOp::SpvOpImageSampleExplicitLod:
122 case SpvOp::SpvOpImageSampleDrefImplicitLod:
123 case SpvOp::SpvOpImageSampleDrefExplicitLod:
124 case SpvOp::SpvOpImageSampleProjImplicitLod:
125 case SpvOp::SpvOpImageSampleProjExplicitLod:
126 case SpvOp::SpvOpImageSampleProjDrefImplicitLod:
127 case SpvOp::SpvOpImageSampleProjDrefExplicitLod:
128 case SpvOp::SpvOpImageGather:
129 case SpvOp::SpvOpImageDrefGather:
130 case SpvOp::SpvOpImageQueryLod:
131 case SpvOp::SpvOpImageSparseSampleImplicitLod:
132 case SpvOp::SpvOpImageSparseSampleExplicitLod:
133 case SpvOp::SpvOpImageSparseSampleDrefImplicitLod:
134 case SpvOp::SpvOpImageSparseSampleDrefExplicitLod:
135 case SpvOp::SpvOpImageSparseSampleProjImplicitLod:
136 case SpvOp::SpvOpImageSparseSampleProjExplicitLod:
137 case SpvOp::SpvOpImageSparseSampleProjDrefImplicitLod:
138 case SpvOp::SpvOpImageSparseSampleProjDrefExplicitLod:
139 case SpvOp::SpvOpImageSparseGather:
140 case SpvOp::SpvOpImageSparseDrefGather:
141 case SpvOp::SpvOpImageFetch:
142 case SpvOp::SpvOpImageRead:
143 case SpvOp::SpvOpImageQueryFormat:
144 case SpvOp::SpvOpImageQueryOrder:
145 case SpvOp::SpvOpImageQuerySizeLod:
146 case SpvOp::SpvOpImageQuerySize:
147 case SpvOp::SpvOpImageQueryLevels:
148 case SpvOp::SpvOpImageQuerySamples:
149 case SpvOp::SpvOpImageSparseFetch:
150 case SpvOp::SpvOpImageSparseRead:
151 case SpvOp::SpvOpImageWrite:
152 return inst->GetSingleWordInOperand(kSpvImageSampleImageIdInIdx);
153 default:
154 break;
155 }
156 return 0;
157}
158
159Instruction* InstBindlessCheckPass::GetDescriptorTypeInst(
160 Instruction* var_inst) {
161 uint32_t var_type_id = var_inst->type_id();
162 Instruction* var_type_inst = get_def_use_mgr()->GetDef(var_type_id);
163 uint32_t desc_type_id =
164 var_type_inst->GetSingleWordInOperand(kSpvTypePointerTypeIdInIdx);
165 return get_def_use_mgr()->GetDef(desc_type_id);
166}
167
168bool InstBindlessCheckPass::AnalyzeDescriptorReference(Instruction* ref_inst,
169 ref_analysis* ref) {
170 ref->ref_inst = ref_inst;
171 if (ref_inst->opcode() == SpvOpLoad || ref_inst->opcode() == SpvOpStore) {
172 ref->desc_load_id = 0;
173 ref->ptr_id = ref_inst->GetSingleWordInOperand(kSpvLoadPtrIdInIdx);
174 Instruction* ptr_inst = get_def_use_mgr()->GetDef(ref->ptr_id);
175 if (ptr_inst->opcode() != SpvOp::SpvOpAccessChain) return false;
176 ref->var_id = ptr_inst->GetSingleWordInOperand(kSpvAccessChainBaseIdInIdx);
177 Instruction* var_inst = get_def_use_mgr()->GetDef(ref->var_id);
178 if (var_inst->opcode() != SpvOp::SpvOpVariable) return false;
179 uint32_t storage_class =
180 var_inst->GetSingleWordInOperand(kSpvVariableStorageClassInIdx);
181 switch (storage_class) {
182 case SpvStorageClassUniform:
183 case SpvStorageClassUniformConstant:
184 case SpvStorageClassStorageBuffer:
185 break;
186 default:
187 return false;
188 break;
189 }
190 Instruction* desc_type_inst = GetDescriptorTypeInst(var_inst);
191 switch (desc_type_inst->opcode()) {
192 case SpvOpTypeArray:
193 case SpvOpTypeRuntimeArray:
194 // A load through a descriptor array will have at least 3 operands. We
195 // do not want to instrument loads of descriptors here which are part of
196 // an image-based reference.
197 if (ptr_inst->NumInOperands() < 3) return false;
198 ref->index_id =
199 ptr_inst->GetSingleWordInOperand(kSpvAccessChainIndex0IdInIdx);
200 break;
201 default:
202 ref->index_id = 0;
203 break;
204 }
205 return true;
206 }
207 // Reference is not load or store. If not an image-based reference, return.
208 ref->image_id = GetImageId(ref_inst);
209 if (ref->image_id == 0) return false;
210 Instruction* image_inst = get_def_use_mgr()->GetDef(ref->image_id);
211 Instruction* desc_load_inst = nullptr;
212 if (image_inst->opcode() == SpvOp::SpvOpSampledImage) {
213 ref->desc_load_id =
214 image_inst->GetSingleWordInOperand(kSpvSampledImageImageIdInIdx);
215 desc_load_inst = get_def_use_mgr()->GetDef(ref->desc_load_id);
216 } else if (image_inst->opcode() == SpvOp::SpvOpImage) {
217 ref->desc_load_id =
218 image_inst->GetSingleWordInOperand(kSpvImageSampledImageIdInIdx);
219 desc_load_inst = get_def_use_mgr()->GetDef(ref->desc_load_id);
220 } else {
221 ref->desc_load_id = ref->image_id;
222 desc_load_inst = image_inst;
223 ref->image_id = 0;
224 }
225 if (desc_load_inst->opcode() != SpvOp::SpvOpLoad) {
226 // TODO(greg-lunarg): Handle additional possibilities?
227 return false;
228 }
229 ref->ptr_id = desc_load_inst->GetSingleWordInOperand(kSpvLoadPtrIdInIdx);
230 Instruction* ptr_inst = get_def_use_mgr()->GetDef(ref->ptr_id);
231 if (ptr_inst->opcode() == SpvOp::SpvOpVariable) {
232 ref->index_id = 0;
233 ref->var_id = ref->ptr_id;
234 } else if (ptr_inst->opcode() == SpvOp::SpvOpAccessChain) {
235 if (ptr_inst->NumInOperands() != 2) {
236 assert(false && "unexpected bindless index number");
237 return false;
238 }
239 ref->index_id =
240 ptr_inst->GetSingleWordInOperand(kSpvAccessChainIndex0IdInIdx);
241 ref->var_id = ptr_inst->GetSingleWordInOperand(kSpvAccessChainBaseIdInIdx);
242 Instruction* var_inst = get_def_use_mgr()->GetDef(ref->var_id);
243 if (var_inst->opcode() != SpvOpVariable) {
244 assert(false && "unexpected bindless base");
245 return false;
246 }
247 } else {
248 // TODO(greg-lunarg): Handle additional possibilities?
249 return false;
250 }
251 return true;
252}
253
// Wraps the reference described by |ref| in a runtime check on |check_id|.
//
// The current last block of |new_blocks| is terminated with a conditional
// branch on |check_id| (merge_blk_id is passed as the merge target). The two
// targets are:
//   - a "valid" block that re-emits the reference via
//     CloneOriginalReference(), and
//   - an "invalid" block that writes the record {error_id, index (cast to
//     unsigned), length_id} via GenDebugStreamWrite() for stage |stage_idx|.
// Both blocks branch to the merge block. If the reference produces a value,
// the merge block gets an OpPhi choosing between the cloned value and a null
// constant of the same type, and all uses of the original result are
// redirected to that phi. The original reference instruction is then killed.
// The valid, invalid and merge blocks are appended to |new_blocks| in that
// order, so the merge block is last on return.
void InstBindlessCheckPass::GenCheckCode(
    uint32_t check_id, uint32_t error_id, uint32_t length_id,
    uint32_t stage_idx, ref_analysis* ref,
    std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
  BasicBlock* back_blk_ptr = &*new_blocks->back();
  InstructionBuilder builder(
      context(), back_blk_ptr,
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  // Gen conditional branch on check_id. Valid branch generates original
  // reference. Invalid generates debug output and zero result (if needed).
  uint32_t merge_blk_id = TakeNextId();
  uint32_t valid_blk_id = TakeNextId();
  uint32_t invalid_blk_id = TakeNextId();
  std::unique_ptr<Instruction> merge_label(NewLabel(merge_blk_id));
  std::unique_ptr<Instruction> valid_label(NewLabel(valid_blk_id));
  std::unique_ptr<Instruction> invalid_label(NewLabel(invalid_blk_id));
  (void)builder.AddConditionalBranch(check_id, valid_blk_id, invalid_blk_id,
                                     merge_blk_id, SpvSelectionControlMaskNone);
  // Gen valid bounds branch
  std::unique_ptr<BasicBlock> new_blk_ptr(
      new BasicBlock(std::move(valid_label)));
  builder.SetInsertPoint(&*new_blk_ptr);
  uint32_t new_ref_id = CloneOriginalReference(ref, &builder);
  (void)builder.AddBranch(merge_blk_id);
  new_blocks->push_back(std::move(new_blk_ptr));
  // Gen invalid block
  new_blk_ptr.reset(new BasicBlock(std::move(invalid_label)));
  builder.SetInsertPoint(&*new_blk_ptr);
  uint32_t u_index_id = GenUintCastCode(ref->index_id, &builder);
  GenDebugStreamWrite(uid2offset_[ref->ref_inst->unique_id()], stage_idx,
                      {error_id, u_index_id, length_id}, &builder);
  // Remember last invalid block id (the phi's incoming block for the null
  // value)
  uint32_t last_invalid_blk_id = new_blk_ptr->GetLabelInst()->result_id();
  // Result type of the original reference; the null value for the invalid
  // path is created from it in the merge block's phi below.
  uint32_t ref_type_id = ref->ref_inst->type_id();
  (void)builder.AddBranch(merge_blk_id);
  new_blocks->push_back(std::move(new_blk_ptr));
  // Gen merge block
  new_blk_ptr.reset(new BasicBlock(std::move(merge_label)));
  builder.SetInsertPoint(&*new_blk_ptr);
  // Gen phi of new reference and zero, if necessary, and replace the
  // result id of the original reference with that of the Phi. Kill original
  // reference.
  if (new_ref_id != 0) {
    Instruction* phi_inst = builder.AddPhi(
        ref_type_id, {new_ref_id, valid_blk_id, GetNullId(ref_type_id),
                      last_invalid_blk_id});
    context()->ReplaceAllUsesWith(ref->ref_inst->result_id(),
                                  phi_inst->result_id());
  }
  new_blocks->push_back(std::move(new_blk_ptr));
  context()->KillInst(ref->ref_inst);
}
307
// Instruments the instruction at |ref_inst_itr| (in block |ref_block_itr|)
// with a descriptor-array bounds check, if one is needed.
//
// Nothing is emitted in any of these cases:
//   - the instruction is not a descriptor reference through an OpAccessChain;
//   - the descriptor type is an OpTypeArray whose index and length are both
//     OpConstants with index < length (only the first value word of each is
//     compared);
//   - the descriptor type is not an OpTypeArray and is either not an
//     OpTypeRuntimeArray or input_length_enabled_ is false.
// Otherwise the instructions preceding the reference are moved into a new
// first block. The length is taken from the OpTypeArray or, for runtime
// arrays, read from the debug input buffer. GenCheckCode() then guards the
// reference on (index ULessThan length), reporting kInstErrorBindlessBounds on
// failure. The instructions following the reference are moved into the final
// (merge) block.
void InstBindlessCheckPass::GenBoundsCheckCode(
    BasicBlock::iterator ref_inst_itr,
    UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
    std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
  // Look for reference through indexed descriptor. If found, analyze and
  // save components. If not, return.
  ref_analysis ref;
  if (!AnalyzeDescriptorReference(&*ref_inst_itr, &ref)) return;
  Instruction* ptr_inst = get_def_use_mgr()->GetDef(ref.ptr_id);
  if (ptr_inst->opcode() != SpvOp::SpvOpAccessChain) return;
  // If index and bound both compile-time constants and index < bound,
  // return without changing
  Instruction* var_inst = get_def_use_mgr()->GetDef(ref.var_id);
  Instruction* desc_type_inst = GetDescriptorTypeInst(var_inst);
  uint32_t length_id = 0;
  if (desc_type_inst->opcode() == SpvOpTypeArray) {
    length_id =
        desc_type_inst->GetSingleWordInOperand(kSpvTypeArrayLengthIdInIdx);
    Instruction* index_inst = get_def_use_mgr()->GetDef(ref.index_id);
    Instruction* length_inst = get_def_use_mgr()->GetDef(length_id);
    if (index_inst->opcode() == SpvOpConstant &&
        length_inst->opcode() == SpvOpConstant &&
        index_inst->GetSingleWordInOperand(kSpvConstantValueInIdx) <
            length_inst->GetSingleWordInOperand(kSpvConstantValueInIdx))
      return;
  } else if (!input_length_enabled_ ||
             desc_type_inst->opcode() != SpvOpTypeRuntimeArray) {
    // Runtime array lengths are only available when length input is enabled;
    // other descriptor types are not bounds-checked.
    return;
  }
  // Move original block's preceding instructions into first new block
  std::unique_ptr<BasicBlock> new_blk_ptr;
  MovePreludeCode(ref_inst_itr, ref_block_itr, &new_blk_ptr);
  InstructionBuilder builder(
      context(), &*new_blk_ptr,
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  new_blocks->push_back(std::move(new_blk_ptr));
  uint32_t error_id = builder.GetUintConstantId(kInstErrorBindlessBounds);
  // If length id not yet set, descriptor array is runtime size so
  // generate load of length from stage's debug input buffer.
  if (length_id == 0) {
    assert(desc_type_inst->opcode() == SpvOpTypeRuntimeArray &&
           "unexpected bindless type");
    length_id = GenDebugReadLength(ref.var_id, &builder);
  }
  // Generate full runtime bounds test code with true branch
  // being full reference and false branch being debug output and zero
  // for the referenced value.
  Instruction* ult_inst =
      builder.AddBinaryOp(GetBoolId(), SpvOpULessThan, ref.index_id, length_id);
  GenCheckCode(ult_inst->result_id(), error_id, length_id, stage_idx, &ref,
               new_blocks);
  // Move original block's remaining code into remainder/merge block and add
  // to new blocks
  BasicBlock* back_blk_ptr = &*new_blocks->back();
  MovePostludeCode(ref_block_itr, back_blk_ptr);
}
364
// Instruments the instruction at |ref_inst_itr| (in block |ref_block_itr|)
// with a descriptor initialization check. If the instruction is a descriptor
// reference (see AnalyzeDescriptorReference), its descriptor's initialization
// status is read from the debug input buffer. Index 0 is used for a
// non-arrayed binding. GenCheckCode() then guards the reference on
// (status != 0), reporting kInstErrorBindlessUninit with a length field of 0
// on failure. The surrounding instructions are split into the first and final
// new blocks, as in GenBoundsCheckCode().
void InstBindlessCheckPass::GenInitCheckCode(
    BasicBlock::iterator ref_inst_itr,
    UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
    std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
  // Look for reference through descriptor. If not, return.
  ref_analysis ref;
  if (!AnalyzeDescriptorReference(&*ref_inst_itr, &ref)) return;
  // Move original block's preceding instructions into first new block
  std::unique_ptr<BasicBlock> new_blk_ptr;
  MovePreludeCode(ref_inst_itr, ref_block_itr, &new_blk_ptr);
  InstructionBuilder builder(
      context(), &*new_blk_ptr,
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  new_blocks->push_back(std::move(new_blk_ptr));
  // Read initialization status from debug input buffer. If index id not yet
  // set, binding is single descriptor, so set index to constant 0.
  uint32_t zero_id = builder.GetUintConstantId(0u);
  if (ref.index_id == 0) ref.index_id = zero_id;
  uint32_t init_id = GenDebugReadInit(ref.var_id, ref.index_id, &builder);
  // Generate full runtime non-zero init test code with true branch
  // being full reference and false branch being debug output and zero
  // for the referenced value.
  Instruction* uneq_inst =
      builder.AddBinaryOp(GetBoolId(), SpvOpINotEqual, init_id, zero_id);
  uint32_t error_id = builder.GetUintConstantId(kInstErrorBindlessUninit);
  // zero_id doubles as the length field of the error record.
  GenCheckCode(uneq_inst->result_id(), error_id, zero_id, stage_idx, &ref,
               new_blocks);
  // Move original block's remaining code into remainder/merge block and add
  // to new blocks
  BasicBlock* back_blk_ptr = &*new_blocks->back();
  MovePostludeCode(ref_block_itr, back_blk_ptr);
}
397
398void InstBindlessCheckPass::InitializeInstBindlessCheck() {
399 // Initialize base class
400 InitializeInstrument();
401 // If runtime array length support enabled, create variable mappings. Length
402 // support is always enabled if descriptor init check is enabled.
403 if (input_length_enabled_)
404 for (auto& anno : get_module()->annotations())
405 if (anno.opcode() == SpvOpDecorate) {
406 if (anno.GetSingleWordInOperand(1u) == SpvDecorationDescriptorSet)
407 var2desc_set_[anno.GetSingleWordInOperand(0u)] =
408 anno.GetSingleWordInOperand(2u);
409 else if (anno.GetSingleWordInOperand(1u) == SpvDecorationBinding)
410 var2binding_[anno.GetSingleWordInOperand(0u)] =
411 anno.GetSingleWordInOperand(2u);
412 }
413}
414
415Pass::Status InstBindlessCheckPass::ProcessImpl() {
416 // Perform bindless bounds check on each entry point function in module
417 InstProcessFunction pfn =
418 [this](BasicBlock::iterator ref_inst_itr,
419 UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
420 std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
421 return GenBoundsCheckCode(ref_inst_itr, ref_block_itr, stage_idx,
422 new_blocks);
423 };
424 bool modified = InstProcessEntryPointCallTree(pfn);
425 if (input_init_enabled_) {
426 // Perform descriptor initialization check on each entry point function in
427 // module
428 pfn = [this](BasicBlock::iterator ref_inst_itr,
429 UptrVectorIterator<BasicBlock> ref_block_itr,
430 uint32_t stage_idx,
431 std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
432 return GenInitCheckCode(ref_inst_itr, ref_block_itr, stage_idx,
433 new_blocks);
434 };
435 modified |= InstProcessEntryPointCallTree(pfn);
436 }
437 return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
438}
439
440Pass::Status InstBindlessCheckPass::Process() {
441 InitializeInstBindlessCheck();
442 return ProcessImpl();
443}
444
445} // namespace opt
446} // namespace spvtools
447