// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// This pass injects code in a graphics shader to implement guarantees
// satisfying Vulkan's robustBufferAccess rules. Robust access rules permit
// an out-of-bounds access to be redirected to an access of the same type
// (load, store, etc.) but within the same root object.
//
// We assume baseline functionality in Vulkan, i.e. the module uses
// logical addressing mode, without VK_KHR_variable_pointers.
//
//  - Logical addressing mode implies:
//    - Each root pointer (a pointer that exists other than by the
//      execution of a shader instruction) is the result of an OpVariable.
//
//  - Instructions that result in pointers are:
//      OpVariable
//      OpAccessChain
//      OpInBoundsAccessChain
//      OpFunctionParameter
//      OpImageTexelPointer
//      OpCopyObject
//
//  - Instructions that use a pointer are:
//      OpLoad
//      OpStore
//      OpAccessChain
//      OpInBoundsAccessChain
//      OpFunctionCall
//      OpImageTexelPointer
//      OpCopyMemory
//      OpCopyObject
//      all OpAtomic* instructions
//
// We classify pointer-users into:
//  - Accesses:
//    - OpLoad
//    - OpStore
//    - OpAtomic*
//    - OpCopyMemory
//
//  - Address calculations:
//    - OpAccessChain
//    - OpInBoundsAccessChain
//
//  - Pass-through:
//    - OpFunctionCall
//    - OpFunctionParameter
//    - OpCopyObject
//
// The strategy is:
//
//  - Handle only logical addressing mode. In particular, don't handle a module
//    if it uses one of the variable-pointers capabilities.
//
//  - Don't handle modules using capability RuntimeDescriptorArrayEXT. So the
//    only runtime arrays are those that are the last member in a
//    Block-decorated struct. This allows us to feasibly/easily compute the
//    length of the runtime array. See below.
//
//  - The memory locations accessed by OpLoad, OpStore, OpCopyMemory, and
//    OpAtomic* are determined by their pointer parameter or parameters.
//    Pointers are always (correctly) typed and so the address and number of
//    consecutive locations are fully determined by the pointer.
//
//  - A pointer value originates as one of a few cases:
//
//    - OpVariable for an interface object or an array of them: image,
//      buffer (UBO or SSBO), sampler, sampled-image, push-constant, input
//      variable, output variable. The execution environment is responsible for
//      allocating the correct amount of storage for these, and for ensuring
//      each resource bound to such a variable is big enough to contain the
//      SPIR-V pointee type of the variable.
//
//    - OpVariable for a non-interface object. These are variables in
//      Workgroup, Private, and Function storage classes. The compiler ensures
//      the underlying allocation is big enough to store the entire SPIR-V
//      pointee type of the variable.
//
//    - An OpFunctionParameter. This always maps to a pointer parameter to an
//      OpFunctionCall.
//
//      - In logical addressing mode, these are severely limited:
//        "Any pointer operand to an OpFunctionCall must be:
//          - a memory object declaration, or
//          - a pointer to an element in an array that is a memory object
//            declaration, where the element type is OpTypeSampler or
//            OpTypeImage"
//
//      - This has an important simplifying consequence:
//
//        - When looking for a pointer to the structure containing a runtime
//          array, you begin with a pointer to the runtime array and trace
//          backward in the function. You never have to trace back beyond
//          your function call boundary. So you can't take a partial access
//          chain into an SSBO, then pass that pointer into a function. So
//          we don't resort to using fat pointers to compute array length.
//          We can trace back to a pointer to the containing structure,
//          and use that in an OpArrayLength instruction. (The structure type
//          gives us the member index of the runtime array.)
//
//        - Otherwise, the pointer type fully encodes the range of valid
//          addresses. In particular, the type of a pointer to an aggregate
//          value fully encodes the range of indices when indexing into
//          that aggregate.
//
//    - The pointer is the result of an access chain instruction. We clamp
//      indices contributing to address calculations. As noted above, the
//      valid ranges are either bound by the length of a runtime array, or
//      by the type of the base pointer. The length of a runtime array is
//      the result of an OpArrayLength instruction acting on the pointer of
//      the containing structure as noted above.
//
//    - Access chain indices are always treated as signed, so:
//      - Clamp the upper bound at the signed integer maximum.
//      - Use SClamp for all clamping.
//      (An illustrative example appears at the end of this comment.)
//
//  - TODO(dneto): OpImageTexelPointer:
//    - Clamp the coordinate to the image size returned by OpImageQuerySize.
//    - If multi-sampled, clamp the sample index to the count returned by
//      OpImageQuerySamples.
//    - If not multi-sampled, set the sample index to 0.
//
//  - Rely on the external validator to check that pointers are only
//    used by the instructions as above.
//
//  - Handle OpTypeRuntimeArray:
//    Track the pointer back to the original resource (pointer to struct), so
//    we can query the runtime array size.
//
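// As a rough illustration of the clamping strategy (the IDs below are
// hypothetical and not taken from any particular module), an access chain
// indexing into a 4-element array such as
//
//   %ptr = OpAccessChain %ptr_float %var %i
//
// is conceptually rewritten so the index is first clamped to [0, 3]:
//
//   %i_clamped = OpExtInst %int %glsl_ext SClamp %i %int_0 %int_3
//   %ptr       = OpAccessChain %ptr_float %var %i_clamped
//
// where %glsl_ext imports "GLSL.std.450". (Constant indices are instead
// replaced directly with an in-range constant.)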

#include "graphics_robust_access_pass.h"

#include <algorithm>
#include <cstring>
#include <functional>
#include <initializer_list>
#include <limits>
#include <utility>

#include "constants.h"
#include "def_use_manager.h"
#include "function.h"
#include "ir_context.h"
#include "module.h"
#include "pass.h"
#include "source/diagnostic.h"
#include "source/util/make_unique.h"
#include "spirv-tools/libspirv.h"
#include "spirv/unified1/GLSL.std.450.h"
#include "spirv/unified1/spirv.h"
#include "type_manager.h"
#include "types.h"

namespace spvtools {
namespace opt {

using opt::Instruction;
using opt::Operand;
using spvtools::MakeUnique;

GraphicsRobustAccessPass::GraphicsRobustAccessPass() : module_status_() {}

Pass::Status GraphicsRobustAccessPass::Process() {
  module_status_ = PerModuleState();

  ProcessCurrentModule();

  auto result = module_status_.failed
                    ? Status::Failure
                    : (module_status_.modified ? Status::SuccessWithChange
                                               : Status::SuccessWithoutChange);

  return result;
}

spvtools::DiagnosticStream GraphicsRobustAccessPass::Fail() {
  module_status_.failed = true;
  // We don't really have a position, and we'll ignore the result.
  return std::move(
      spvtools::DiagnosticStream({}, consumer(), "", SPV_ERROR_INVALID_BINARY)
      << name() << ": ");
}

spv_result_t GraphicsRobustAccessPass::IsCompatibleModule() {
  auto* feature_mgr = context()->get_feature_mgr();
  if (!feature_mgr->HasCapability(SpvCapabilityShader))
    return Fail() << "Can only process Shader modules";
  if (feature_mgr->HasCapability(SpvCapabilityVariablePointers))
    return Fail() << "Can't process modules with VariablePointers capability";
  if (feature_mgr->HasCapability(SpvCapabilityVariablePointersStorageBuffer))
    return Fail() << "Can't process modules with VariablePointersStorageBuffer "
                     "capability";
  if (feature_mgr->HasCapability(SpvCapabilityRuntimeDescriptorArrayEXT)) {
    // These have a RuntimeArray outside of a Block-decorated struct. There
    // is no way to compute the array length from within SPIR-V.
    return Fail() << "Can't process modules with RuntimeDescriptorArrayEXT "
                     "capability";
  }

  {
    auto* inst = context()->module()->GetMemoryModel();
    const auto addressing_model = inst->GetSingleWordOperand(0);
    if (addressing_model != SpvAddressingModelLogical)
      return Fail() << "Addressing model must be Logical. Found "
                    << inst->PrettyPrint();
  }
  return SPV_SUCCESS;
}

spv_result_t GraphicsRobustAccessPass::ProcessCurrentModule() {
  auto err = IsCompatibleModule();
  if (err != SPV_SUCCESS) return err;

  ProcessFunction fn = [this](opt::Function* f) { return ProcessAFunction(f); };
  module_status_.modified |= context()->ProcessReachableCallTree(fn);

  // Need something here. It's the price we pay for easier failure paths.
  return SPV_SUCCESS;
}

bool GraphicsRobustAccessPass::ProcessAFunction(opt::Function* function) {
  // Ensure that all pointers computed inside a function are within bounds.
  // Find the access chains in this function before trying to modify them.
  std::vector<Instruction*> access_chains;
  std::vector<Instruction*> image_texel_pointers;
  for (auto& block : *function) {
    for (auto& inst : block) {
      switch (inst.opcode()) {
        case SpvOpAccessChain:
        case SpvOpInBoundsAccessChain:
          access_chains.push_back(&inst);
          break;
        case SpvOpImageTexelPointer:
          image_texel_pointers.push_back(&inst);
          break;
        default:
          break;
      }
    }
  }
  for (auto* inst : access_chains) {
    ClampIndicesForAccessChain(inst);
    if (module_status_.failed) return module_status_.modified;
  }

  for (auto* inst : image_texel_pointers) {
    if (SPV_SUCCESS != ClampCoordinateForImageTexelPointer(inst)) break;
  }
  return module_status_.modified;
}

void GraphicsRobustAccessPass::ClampIndicesForAccessChain(
    Instruction* access_chain) {
  Instruction& inst = *access_chain;

  auto* constant_mgr = context()->get_constant_mgr();
  auto* def_use_mgr = context()->get_def_use_mgr();
  auto* type_mgr = context()->get_type_mgr();
  const bool have_int64_cap =
      context()->get_feature_mgr()->HasCapability(SpvCapabilityInt64);

  // Replaces one of the OpAccessChain index operands with a new value.
  // Updates def-use analysis.
  auto replace_index = [&inst, def_use_mgr](uint32_t operand_index,
                                            Instruction* new_value) {
    inst.SetOperand(operand_index, {new_value->result_id()});
    def_use_mgr->AnalyzeInstUse(&inst);
    return SPV_SUCCESS;
  };

  // Replaces one of the OpAccessChain index operands with a clamped value.
  // Replaces the operand at |operand_index| with the value computed from
  // signed_clamp(%old_value, %min_value, %max_value). It also analyzes
  // the new instruction and records that the module is modified.
  // Assumes %min_value is signed-less-than-or-equal-to %max_value. (All
  // callers use 0 for %min_value.)
  auto clamp_index = [&inst, type_mgr, this, &replace_index](
                         uint32_t operand_index, Instruction* old_value,
                         Instruction* min_value, Instruction* max_value) {
    auto* clamp_inst =
        MakeSClampInst(*type_mgr, old_value, min_value, max_value, &inst);
    return replace_index(operand_index, clamp_inst);
  };

  // Ensures the specified index of access chain |inst| has a value that is
  // at most |count| - 1. If the index is already a constant value less than
  // |count| then no change is made.
  auto clamp_to_literal_count =
      [&inst, this, &constant_mgr, &type_mgr, have_int64_cap, &replace_index,
       &clamp_index](uint32_t operand_index, uint64_t count) -> spv_result_t {
    Instruction* index_inst =
        this->GetDef(inst.GetSingleWordOperand(operand_index));
    const auto* index_type =
        type_mgr->GetType(index_inst->type_id())->AsInteger();
    assert(index_type);
    const auto index_width = index_type->width();

    if (count <= 1) {
      // Replace the index with 0.
      return replace_index(operand_index, GetValueForType(0, index_type));
    }

    uint64_t maxval = count - 1;

    // Compute the bit width of a viable type to hold |maxval|.
    // Look for a bit width, up to 64 bits wide, to fit |maxval|.
    uint32_t maxval_width = index_width;
    while ((maxval_width < 64) && (0 != (maxval >> maxval_width))) {
      maxval_width *= 2;
    }
    // Determine the type for |maxval|.
    analysis::Integer signed_type_for_query(maxval_width, true);
    auto* maxval_type =
        type_mgr->GetRegisteredType(&signed_type_for_query)->AsInteger();
    // Access chain indices are treated as signed, so limit the maximum value
    // of the index so it will always be positive for a signed clamp operation.
    maxval = std::min(maxval, ((uint64_t(1) << (maxval_width - 1)) - 1));
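    // As an illustrative example of the arithmetic above: with a 32-bit index
    // and count = 2^40, maxval = 2^40 - 1 does not fit in 32 bits, so
    // |maxval_width| doubles to 64, and the signed cap of 2^63 - 1 leaves
    // |maxval| unchanged.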

    if (index_width > 64) {
      return this->Fail() << "Can't handle indices wider than 64 bits, found "
                             "constant index with "
                          << index_width << " bits as index number "
                          << operand_index << " of access chain "
                          << inst.PrettyPrint();
    }

    // Split into two cases: the current index is a constant, or not.

    // If the index is a constant then |index_constant| will not be a null
    // pointer. (If the index is an OpConstantNull, then |index_constant| will
    // also not be a null pointer.) Since access chain indices must be scalar
    // integers, this can't be a spec constant.
    if (auto* index_constant = constant_mgr->GetConstantFromInst(index_inst)) {
      auto* int_index_constant = index_constant->AsIntConstant();
      int64_t value = 0;
      // OpAccessChain indices are treated as signed. So get the signed
      // constant value here.
      if (index_width <= 32) {
        value = int64_t(int_index_constant->GetS32BitValue());
      } else if (index_width <= 64) {
        value = int_index_constant->GetS64BitValue();
      }
      if (value < 0) {
        return replace_index(operand_index, GetValueForType(0, index_type));
      } else if (uint64_t(value) <= maxval) {
        // Nothing to do.
        return SPV_SUCCESS;
      } else {
        // Replace with maxval.
        assert(count > 0);  // Already took care of this case above.
        return replace_index(operand_index,
                             GetValueForType(maxval, maxval_type));
      }
    } else {
      // Generate a clamp instruction.
      assert(maxval >= 1);
      assert(index_width <= 64);  // Otherwise, already returned above.
      if (index_width >= 64 && !have_int64_cap) {
        // An inconsistent module.
        return Fail() << "Access chain index is 64 bits wide, but the Int64 "
                         "capability is not declared: "
                      << index_inst->PrettyPrint();
      }
      // Widen the index value if necessary
      if (maxval_width > index_width) {
        // Find the wider type. We only need this case if a constant array
        // bound is too big.

        // From how we calculated maxval_width, widening won't require adding
        // the Int64 capability.
        assert(have_int64_cap || maxval_width <= 32);
        if (!have_int64_cap && maxval_width >= 64) {
          // Be defensive, but this shouldn't happen.
          return this->Fail()
                 << "Clamping index would require adding Int64 capability. "
                 << "Can't clamp 32-bit index " << operand_index
                 << " of access chain " << inst.PrettyPrint();
        }
        index_inst = WidenInteger(index_type->IsSigned(), maxval_width,
                                  index_inst, &inst);
      }

      // Finally, clamp the index.
      return clamp_index(operand_index, index_inst,
                         GetValueForType(0, maxval_type),
                         GetValueForType(maxval, maxval_type));
    }
    return SPV_SUCCESS;
  };

  // Ensures the specified index of access chain |inst| has a value that is at
  // most the value of |count_inst| minus 1, where |count_inst| is treated as an
  // unsigned integer. This can log a failure.
  auto clamp_to_count = [&inst, this, &constant_mgr, &clamp_to_literal_count,
                         &clamp_index,
                         &type_mgr](uint32_t operand_index,
                                    Instruction* count_inst) -> spv_result_t {
    Instruction* index_inst =
        this->GetDef(inst.GetSingleWordOperand(operand_index));
    const auto* index_type =
        type_mgr->GetType(index_inst->type_id())->AsInteger();
    const auto* count_type =
        type_mgr->GetType(count_inst->type_id())->AsInteger();
    assert(index_type);
    if (const auto* count_constant =
            constant_mgr->GetConstantFromInst(count_inst)) {
      uint64_t value = 0;
      const auto width = count_constant->type()->AsInteger()->width();
      if (width <= 32) {
        value = count_constant->AsIntConstant()->GetU32BitValue();
      } else if (width <= 64) {
        value = count_constant->AsIntConstant()->GetU64BitValue();
      } else {
        return this->Fail() << "Can't handle counts wider than 64 bits, found "
                               "constant count with "
                            << width << " bits";
      }
      return clamp_to_literal_count(operand_index, value);
    } else {
      // Widen them to the same width.
      const auto index_width = index_type->width();
      const auto count_width = count_type->width();
      const auto target_width = std::max(index_width, count_width);
      // UConvert requires the result type to have 0 signedness. So enforce
      // that here.
      auto* wider_type = index_width < count_width ? count_type : index_type;
      if (index_type->width() < target_width) {
        // Access chain indices are treated as signed integers.
        index_inst = WidenInteger(true, target_width, index_inst, &inst);
      } else if (count_type->width() < target_width) {
        // Assume type sizes are treated as unsigned.
        count_inst = WidenInteger(false, target_width, count_inst, &inst);
      }
      // Compute count - 1.
      // It doesn't matter if 1 is signed or unsigned.
      auto* one = GetValueForType(1, wider_type);
      auto* count_minus_1 = InsertInst(
          &inst, SpvOpISub, type_mgr->GetId(wider_type), TakeNextId(),
          {{SPV_OPERAND_TYPE_ID, {count_inst->result_id()}},
           {SPV_OPERAND_TYPE_ID, {one->result_id()}}});
      auto* zero = GetValueForType(0, wider_type);
      // Make sure we clamp to an upper bound that is at most the signed max
      // for the target type.
      const uint64_t max_signed_value =
          ((uint64_t(1) << (target_width - 1)) - 1);
      // Use unsigned-min to ensure that the result is always non-negative.
      // That ensures we satisfy the invariant for SClamp, where the "min"
      // argument we give it (zero), is no larger than the third argument.
      auto* upper_bound =
          MakeUMinInst(*type_mgr, count_minus_1,
                       GetValueForType(max_signed_value, wider_type), &inst);
      // Now clamp the index to this upper bound.
      return clamp_index(operand_index, index_inst, zero, upper_bound);
    }
    return SPV_SUCCESS;
  };

  const Instruction* base_inst = GetDef(inst.GetSingleWordInOperand(0));
  const Instruction* base_type = GetDef(base_inst->type_id());
  Instruction* pointee_type = GetDef(base_type->GetSingleWordInOperand(1));

  // Walk the indices from earliest to latest, replacing indices with a
  // clamped value, and updating the pointee_type. The order matters for
  // the case when we have to compute the length of a runtime array. In
  // that case the algorithm relies on the fact that the earlier indices
  // have already been clamped.
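  // For instance (with hypothetical IDs), given
  //   %p = OpAccessChain %ptr %var %int_0 %i %j
  // where %var points to a struct whose member 0 is a 10-element array of
  // 4-element vectors, the walk visits OpTypeStruct (member index %int_0 must
  // be a constant and is range-checked, not clamped), then OpTypeArray
  // (clamp %i to [0, 9]), then OpTypeVector (clamp %j to [0, 3]).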
  const uint32_t num_operands = inst.NumOperands();
  for (uint32_t idx = 3; !module_status_.failed && idx < num_operands; ++idx) {
    const uint32_t index_id = inst.GetSingleWordOperand(idx);
    Instruction* index_inst = GetDef(index_id);

    switch (pointee_type->opcode()) {
      case SpvOpTypeMatrix:  // Use column count
      case SpvOpTypeVector:  // Use component count
      {
        const uint32_t count = pointee_type->GetSingleWordOperand(2);
        clamp_to_literal_count(idx, count);
        pointee_type = GetDef(pointee_type->GetSingleWordOperand(1));
      } break;

      case SpvOpTypeArray: {
        // The array length can be a spec constant, so go through the general
        // case.
        Instruction* array_len = GetDef(pointee_type->GetSingleWordOperand(2));
        clamp_to_count(idx, array_len);
        pointee_type = GetDef(pointee_type->GetSingleWordOperand(1));
      } break;

      case SpvOpTypeStruct: {
        // SPIR-V requires the index to be an OpConstant.
        // We need to know the index literal value so we can compute the next
        // pointee type.
        if (index_inst->opcode() != SpvOpConstant ||
            !constant_mgr->GetConstantFromInst(index_inst)
                 ->type()
                 ->AsInteger()) {
          Fail() << "Member index into struct is not a constant integer: "
                 << index_inst->PrettyPrint(
                        SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES)
                 << "\nin access chain: "
                 << inst.PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
          return;
        }
        const auto num_members = pointee_type->NumInOperands();
        const auto* index_constant =
            constant_mgr->GetConstantFromInst(index_inst);
        // Get the sign-extended value, since access index is always treated as
        // signed.
        const auto index_value = index_constant->GetSignExtendedValue();
        if (index_value < 0 || index_value >= num_members) {
          Fail() << "Member index " << index_value
                 << " is out of bounds for struct type: "
                 << pointee_type->PrettyPrint(
                        SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES)
                 << "\nin access chain: "
                 << inst.PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
          return;
        }
        pointee_type = GetDef(pointee_type->GetSingleWordInOperand(
            static_cast<uint32_t>(index_value)));
        // No need to clamp this index. We just checked that it's valid.
      } break;

      case SpvOpTypeRuntimeArray: {
        auto* array_len = MakeRuntimeArrayLengthInst(&inst, idx);
        if (!array_len) {  // We've already signaled an error.
          return;
        }
        clamp_to_count(idx, array_len);
        if (module_status_.failed) return;
        pointee_type = GetDef(pointee_type->GetSingleWordOperand(1));
      } break;

      default:
        Fail() << " Unhandled pointee type for access chain "
               << pointee_type->PrettyPrint(
                      SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
    }
  }
}

uint32_t GraphicsRobustAccessPass::GetGlslInsts() {
  if (module_status_.glsl_insts_id == 0) {
    // This string serves double-duty as raw data for a string and for a vector
    // of 32-bit words.
    const char glsl[] = "GLSL.std.450\0\0\0\0";
    const size_t glsl_str_byte_len = 16;
    // Use an existing import if we can.
    for (auto& inst : context()->module()->ext_inst_imports()) {
      const auto& name_words = inst.GetInOperand(0).words;
      if (0 == std::strncmp(reinterpret_cast<const char*>(name_words.data()),
                            glsl, glsl_str_byte_len)) {
        module_status_.glsl_insts_id = inst.result_id();
      }
    }
    if (module_status_.glsl_insts_id == 0) {
      // Make a new import instruction.
      module_status_.glsl_insts_id = TakeNextId();
      std::vector<uint32_t> words(glsl_str_byte_len / sizeof(uint32_t));
      std::memcpy(words.data(), glsl, glsl_str_byte_len);
      auto import_inst = MakeUnique<Instruction>(
          context(), SpvOpExtInstImport, 0, module_status_.glsl_insts_id,
          std::initializer_list<Operand>{
              Operand{SPV_OPERAND_TYPE_LITERAL_STRING, std::move(words)}});
      Instruction* inst = import_inst.get();
      context()->module()->AddExtInstImport(std::move(import_inst));
      module_status_.modified = true;
      context()->AnalyzeDefUse(inst);
      // Reanalyze the feature list, since we added an extended instruction
      // set import.
      context()->get_feature_mgr()->Analyze(context()->module());
    }
  }
  return module_status_.glsl_insts_id;
}

opt::Instruction* opt::GraphicsRobustAccessPass::GetValueForType(
    uint64_t value, const analysis::Integer* type) {
  auto* mgr = context()->get_constant_mgr();
  assert(type->width() <= 64);
  std::vector<uint32_t> words;
  words.push_back(uint32_t(value));
  if (type->width() > 32) {
    words.push_back(uint32_t(value >> 32u));
  }
  const auto* constant = mgr->GetConstant(type, words);
  return mgr->GetDefiningInstruction(
      constant, context()->get_type_mgr()->GetTypeInstruction(type));
}

opt::Instruction* opt::GraphicsRobustAccessPass::WidenInteger(
    bool sign_extend, uint32_t bit_width, Instruction* value,
    Instruction* before_inst) {
  analysis::Integer unsigned_type_for_query(bit_width, false);
  auto* type_mgr = context()->get_type_mgr();
  auto* unsigned_type = type_mgr->GetRegisteredType(&unsigned_type_for_query);
  auto type_id = context()->get_type_mgr()->GetId(unsigned_type);
  auto conversion_id = TakeNextId();
  auto* conversion = InsertInst(
      before_inst, (sign_extend ? SpvOpSConvert : SpvOpUConvert), type_id,
      conversion_id, {{SPV_OPERAND_TYPE_ID, {value->result_id()}}});
  return conversion;
}

Instruction* GraphicsRobustAccessPass::MakeUMinInst(
    const analysis::TypeManager& tm, Instruction* x, Instruction* y,
    Instruction* where) {
  // Get IDs of instructions we'll be referencing. Evaluate them before calling
  // the function so we force a deterministic ordering in case both of them
  // need to take a new ID.
  const uint32_t glsl_insts_id = GetGlslInsts();
  uint32_t smin_id = TakeNextId();
  const auto xwidth = tm.GetType(x->type_id())->AsInteger()->width();
  const auto ywidth = tm.GetType(y->type_id())->AsInteger()->width();
  assert(xwidth == ywidth);
  (void)xwidth;
  (void)ywidth;
  auto* smin_inst = InsertInst(
      where, SpvOpExtInst, x->type_id(), smin_id,
      {
          {SPV_OPERAND_TYPE_ID, {glsl_insts_id}},
          {SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER, {GLSLstd450UMin}},
          {SPV_OPERAND_TYPE_ID, {x->result_id()}},
          {SPV_OPERAND_TYPE_ID, {y->result_id()}},
      });
  return smin_inst;
}

Instruction* GraphicsRobustAccessPass::MakeSClampInst(
    const analysis::TypeManager& tm, Instruction* x, Instruction* min,
    Instruction* max, Instruction* where) {
  // Get IDs of instructions we'll be referencing. Evaluate them before calling
  // the function so we force a deterministic ordering in case both of them
  // need to take a new ID.
  const uint32_t glsl_insts_id = GetGlslInsts();
  uint32_t clamp_id = TakeNextId();
  const auto xwidth = tm.GetType(x->type_id())->AsInteger()->width();
  const auto minwidth = tm.GetType(min->type_id())->AsInteger()->width();
  const auto maxwidth = tm.GetType(max->type_id())->AsInteger()->width();
  assert(xwidth == minwidth);
  assert(xwidth == maxwidth);
  (void)xwidth;
  (void)minwidth;
  (void)maxwidth;
  auto* clamp_inst = InsertInst(
      where, SpvOpExtInst, x->type_id(), clamp_id,
      {
          {SPV_OPERAND_TYPE_ID, {glsl_insts_id}},
          {SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER, {GLSLstd450SClamp}},
          {SPV_OPERAND_TYPE_ID, {x->result_id()}},
          {SPV_OPERAND_TYPE_ID, {min->result_id()}},
          {SPV_OPERAND_TYPE_ID, {max->result_id()}},
      });
  return clamp_inst;
}

Instruction* GraphicsRobustAccessPass::MakeRuntimeArrayLengthInst(
    Instruction* access_chain, uint32_t operand_index) {
  // The Index parameter to the access chain at |operand_index| is indexing
  // *into* the runtime-array. To get the number of elements in the runtime
  // array we need a pointer to the Block-decorated struct that contains the
  // runtime array. So conceptually we have to go 2 steps backward in the
  // access chain. The two steps backward might force us to traverse backward
  // across multiple dominating instructions.
  auto* type_mgr = context()->get_type_mgr();

  // How many access chain indices do we have to unwind to find the pointer
  // to the struct containing the runtime array?
  uint32_t steps_remaining = 2;
  // Find or create an instruction computing the pointer to the structure
  // containing the runtime array.
  // Walk backward through pointer address calculations until we either get
  // to exactly the right base pointer, or to an access chain instruction
  // that we can replicate but truncate to compute the address of the right
  // struct.
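  // As an illustration (with hypothetical IDs): if the runtime array is the
  // last member (member 1) of a Block-decorated struct and the chain is
  //   %p = OpAccessChain %ptr %ssbo_var %int_1 %i
  // with |operand_index| selecting %i, then unwinding %i and %int_1 lands on
  // the base %ssbo_var, which is the pointer to the containing struct and can
  // be used directly as the operand of OpArrayLength.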
  Instruction* current_access_chain = access_chain;
  Instruction* pointer_to_containing_struct = nullptr;
  while (steps_remaining > 0) {
    switch (current_access_chain->opcode()) {
      case SpvOpCopyObject:
        // Whoops. Walk right through this one.
        current_access_chain =
            GetDef(current_access_chain->GetSingleWordInOperand(0));
        break;
      case SpvOpAccessChain:
      case SpvOpInBoundsAccessChain: {
        const int first_index_operand = 3;
        // How many indices in this access chain contribute to getting us
        // to an element in the runtime array?
        const auto num_contributing_indices =
            current_access_chain == access_chain
                ? operand_index - (first_index_operand - 1)
                : current_access_chain->NumInOperands() - 1 /* skip the base */;
        Instruction* base =
            GetDef(current_access_chain->GetSingleWordInOperand(0));
        if (num_contributing_indices == steps_remaining) {
          // The base pointer points to the structure.
          pointer_to_containing_struct = base;
          steps_remaining = 0;
          break;
        } else if (num_contributing_indices < steps_remaining) {
          // Peel off the index and keep going backward.
          steps_remaining -= num_contributing_indices;
          current_access_chain = base;
        } else {
          // This access chain has more indices than needed. Generate a new
          // access chain instruction, but truncate the list of indices.
          const int base_operand = 2;
          // We'll use the base pointer and the indices up to but not including
          // the one indexing into the runtime array.
          Instruction::OperandList ops;
          // Use the base pointer.
          ops.push_back(current_access_chain->GetOperand(base_operand));
          const uint32_t num_indices_to_keep =
              num_contributing_indices - steps_remaining - 1;
          for (uint32_t i = 0; i <= num_indices_to_keep; i++) {
            ops.push_back(
                current_access_chain->GetOperand(first_index_operand + i));
          }
          // Compute the type of the result of the new access chain. Start at
          // the base and walk the indices in a forward direction.
          auto* constant_mgr = context()->get_constant_mgr();
          std::vector<uint32_t> indices_for_type;
          for (uint32_t i = 0; i < ops.size() - 1; i++) {
            uint32_t index_for_type_calculation = 0;
            Instruction* index =
                GetDef(current_access_chain->GetSingleWordOperand(
                    first_index_operand + i));
            if (auto* index_constant =
                    constant_mgr->GetConstantFromInst(index)) {
              // We only need 32 bits. For the type calculation, it's sufficient
              // to take the zero-extended value. It only matters for the struct
              // case, and struct member indices are unsigned.
              index_for_type_calculation =
                  uint32_t(index_constant->GetZeroExtendedValue());
            } else {
              // Indexing into a variably-sized thing like an array. Use 0.
              index_for_type_calculation = 0;
            }
            indices_for_type.push_back(index_for_type_calculation);
          }
          auto* base_ptr_type = type_mgr->GetType(base->type_id())->AsPointer();
          auto* base_pointee_type = base_ptr_type->pointee_type();
          auto* new_access_chain_result_pointee_type =
              type_mgr->GetMemberType(base_pointee_type, indices_for_type);
          const uint32_t new_access_chain_type_id = type_mgr->FindPointerToType(
              type_mgr->GetId(new_access_chain_result_pointee_type),
              base_ptr_type->storage_class());

          // Create the instruction and insert it.
          const auto new_access_chain_id = TakeNextId();
          auto* new_access_chain =
              InsertInst(current_access_chain, current_access_chain->opcode(),
                         new_access_chain_type_id, new_access_chain_id, ops);
          pointer_to_containing_struct = new_access_chain;
          steps_remaining = 0;
          break;
        }
      } break;
      default:
        Fail() << "Unhandled access chain in logical addressing mode passes "
                  "through "
               << current_access_chain->PrettyPrint(
                      SPV_BINARY_TO_TEXT_OPTION_SHOW_BYTE_OFFSET |
                      SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
        return nullptr;
    }
  }
  assert(pointer_to_containing_struct);
  auto* pointee_type =
      type_mgr->GetType(pointer_to_containing_struct->type_id())
          ->AsPointer()
          ->pointee_type();

  auto* struct_type = pointee_type->AsStruct();
  const uint32_t member_index_of_runtime_array =
      uint32_t(struct_type->element_types().size() - 1);
  // Create the length-of-array instruction before the original access chain,
  // but after the generation of the pointer to the struct.
  const auto array_len_id = TakeNextId();
  analysis::Integer uint_type_for_query(32, false);
  auto* uint_type = type_mgr->GetRegisteredType(&uint_type_for_query);
  auto* array_len = InsertInst(
      access_chain, SpvOpArrayLength, type_mgr->GetId(uint_type), array_len_id,
      {{SPV_OPERAND_TYPE_ID, {pointer_to_containing_struct->result_id()}},
       {SPV_OPERAND_TYPE_LITERAL_INTEGER, {member_index_of_runtime_array}}});
  return array_len;
}

spv_result_t GraphicsRobustAccessPass::ClampCoordinateForImageTexelPointer(
    opt::Instruction* image_texel_pointer) {
  // TODO(dneto): Write tests for this code.
  // TODO(dneto): Use signed-clamp
  return SPV_SUCCESS;

  // Example:
  //   %texel_ptr = OpImageTexelPointer %texel_ptr_type %image_ptr %coord
  //                %sample
  //
  // We want to clamp %coord components between vector-0 and the result
  // of OpImageQuerySize acting on the underlying image. So insert:
  //   %image = OpLoad %image_type %image_ptr
  //   %query_size = OpImageQuerySize %query_size_type %image
  //
  // For a multi-sampled image, %sample is the sample index, and we need
  // to clamp it between zero and the number of samples in the image.
  //   %sample_count = OpImageQuerySamples %uint %image
  //   %max_sample_index = OpISub %uint %sample_count %uint_1
  // For non-multi-sampled images, the sample index must be constant zero.

  auto* def_use_mgr = context()->get_def_use_mgr();
  auto* type_mgr = context()->get_type_mgr();
  auto* constant_mgr = context()->get_constant_mgr();

  auto* image_ptr = GetDef(image_texel_pointer->GetSingleWordInOperand(0));
  auto* image_ptr_type = GetDef(image_ptr->type_id());
  auto image_type_id = image_ptr_type->GetSingleWordInOperand(1);
  auto* image_type = GetDef(image_type_id);
  auto* coord = GetDef(image_texel_pointer->GetSingleWordInOperand(1));
  auto* samples = GetDef(image_texel_pointer->GetSingleWordInOperand(2));

  // We will modify the module, at least by adding image query instructions.
  module_status_.modified = true;

  // Declare the ImageQuery capability if the module doesn't already have it.
  auto* feature_mgr = context()->get_feature_mgr();
  if (!feature_mgr->HasCapability(SpvCapabilityImageQuery)) {
    auto cap = MakeUnique<Instruction>(
        context(), SpvOpCapability, 0, 0,
        std::initializer_list<Operand>{
            {SPV_OPERAND_TYPE_CAPABILITY, {SpvCapabilityImageQuery}}});
    def_use_mgr->AnalyzeInstDefUse(cap.get());
    context()->AddCapability(std::move(cap));
    feature_mgr->Analyze(context()->module());
  }

  // OpImageTexelPointer is used to translate a coordinate and sample index
  // into an address for use with an atomic operation. That is, it may only
  // be used with what Vulkan calls a "storage image"
  // (OpTypeImage parameter Sampled=2).
  // Note: A storage image never has a level-of-detail associated with it.

  // Constraints on the sample id:
  //  - Only 2D images can be multi-sampled: OpTypeImage parameter MS=1
  //    only if Dim=2D.
  //  - Non-multi-sampled images (OpTypeImage parameter MS=0) must use a
  //    sample ID of constant 0.

  // The coordinate is treated as unsigned, and should be clamped against the
  // image "size", returned by OpImageQuerySize. (Note: OpImageQuerySizeLod
  // is only usable with a sampled image, i.e. its image type has Sampled=1).

  // Determine the result type for the OpImageQuerySize.
  // For non-arrayed images:
  //   non-Cube:
  //     - Always the same as the coordinate type
  //   Cube:
  //     - Use all but the last component of the coordinate (which is the face
  //       index from 0 to 5).
  // For arrayed images (in Vulkan the Dim is 1D, 2D, or Cube):
  //   non-Cube:
  //     - A vector with the components in the coordinate, and one more for
  //       the layer index.
  //   Cube:
  //     - The same as the coordinate type: 3-element integer vector.
  //     - The third component from the size query is the layer count.
  //     - The third component in the texel pointer calculation is
  //       6 * layer + face, where 0 <= face < 6.
  const auto dim = SpvDim(image_type->GetSingleWordInOperand(1));
  const bool arrayed = image_type->GetSingleWordInOperand(3) == 1;
  const bool multisampled = image_type->GetSingleWordInOperand(4) != 0;
  const auto query_num_components = [dim, arrayed, this]() -> int {
    const int arrayness_bonus = arrayed ? 1 : 0;
    int num_coords = 0;
    switch (dim) {
      case SpvDimBuffer:
      case SpvDim1D:
        num_coords = 1;
        break;
      case SpvDimCube:
      // For cube, we need bounds for x, y, but not face.
      case SpvDimRect:
      case SpvDim2D:
        num_coords = 2;
        break;
      case SpvDim3D:
        num_coords = 3;
        break;
      case SpvDimSubpassData:
      case SpvDimMax:
        return Fail() << "Invalid image dimension for OpImageTexelPointer: "
                      << int(dim);
        break;
    }
    return num_coords + arrayness_bonus;
  }();
  const auto* coord_component_type = [type_mgr, coord]() {
    const analysis::Type* coord_type = type_mgr->GetType(coord->type_id());
    if (auto* vector_type = coord_type->AsVector()) {
      return vector_type->element_type()->AsInteger();
    }
    return coord_type->AsInteger();
  }();
  // For now, only handle 32-bit case for coordinates.
  if (!coord_component_type) {
    return Fail() << " Coordinates for OpImageTexelPointer are not integral: "
                  << image_texel_pointer->PrettyPrint(
                         SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
  }
  if (coord_component_type->width() != 32) {
    return Fail() << " Expected OpImageTexelPointer coordinate components to "
                     "be 32-bits wide. They are "
                  << coord_component_type->width() << " bits. "
                  << image_texel_pointer->PrettyPrint(
                         SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
  }
  const auto* query_size_type =
      [type_mgr, coord_component_type,
       query_num_components]() -> const analysis::Type* {
    if (query_num_components == 1) return coord_component_type;
    analysis::Vector proposed(coord_component_type, query_num_components);
    return type_mgr->GetRegisteredType(&proposed);
  }();

  const uint32_t image_id = TakeNextId();
  auto* image =
      InsertInst(image_texel_pointer, SpvOpLoad, image_type_id, image_id,
                 {{SPV_OPERAND_TYPE_ID, {image_ptr->result_id()}}});

  const uint32_t query_size_id = TakeNextId();
  auto* query_size =
      InsertInst(image_texel_pointer, SpvOpImageQuerySize,
                 type_mgr->GetTypeInstruction(query_size_type), query_size_id,
                 {{SPV_OPERAND_TYPE_ID, {image->result_id()}}});

  auto* component_1 = constant_mgr->GetConstant(coord_component_type, {1});
  const uint32_t component_1_id =
      constant_mgr->GetDefiningInstruction(component_1)->result_id();
  auto* component_0 = constant_mgr->GetConstant(coord_component_type, {0});
  const uint32_t component_0_id =
      constant_mgr->GetDefiningInstruction(component_0)->result_id();

  // If the image is a cube array, then the last component of the queried
  // size is the layer count. In the query, we have to accommodate folding
  // in the face index ranging from 0 through 5. The inclusive upper bound
  // on the third coordinate therefore is multiplied by 6.
  auto* query_size_including_faces = query_size;
  if (arrayed && (dim == SpvDimCube)) {
    // Multiply the last coordinate by 6.
    auto* component_6 = constant_mgr->GetConstant(coord_component_type, {6});
    const uint32_t component_6_id =
        constant_mgr->GetDefiningInstruction(component_6)->result_id();
    assert(query_num_components == 3);
    auto* multiplicand = constant_mgr->GetConstant(
        query_size_type, {component_1_id, component_1_id, component_6_id});
    auto* multiplicand_inst =
        constant_mgr->GetDefiningInstruction(multiplicand);
    const auto query_size_including_faces_id = TakeNextId();
    query_size_including_faces = InsertInst(
        image_texel_pointer, SpvOpIMul,
        type_mgr->GetTypeInstruction(query_size_type),
        query_size_including_faces_id,
        {{SPV_OPERAND_TYPE_ID, {query_size_including_faces->result_id()}},
         {SPV_OPERAND_TYPE_ID, {multiplicand_inst->result_id()}}});
  }

  // Make a coordinate-type with all 1 components.
  auto* coordinate_1 =
      query_num_components == 1
          ? component_1
          : constant_mgr->GetConstant(
                query_size_type,
                std::vector<uint32_t>(query_num_components, component_1_id));
  // Make a coordinate-type with all 0 components.
  auto* coordinate_0 =
      query_num_components == 1
          ? component_0
          : constant_mgr->GetConstant(
                query_size_type,
                std::vector<uint32_t>(query_num_components, component_0_id));

  const uint32_t query_max_including_faces_id = TakeNextId();
  auto* query_max_including_faces = InsertInst(
      image_texel_pointer, SpvOpISub,
      type_mgr->GetTypeInstruction(query_size_type),
      query_max_including_faces_id,
      {{SPV_OPERAND_TYPE_ID, {query_size_including_faces->result_id()}},
       {SPV_OPERAND_TYPE_ID,
        {constant_mgr->GetDefiningInstruction(coordinate_1)->result_id()}}});

  // Clamp the coordinate
  auto* clamp_coord = MakeSClampInst(
      *type_mgr, coord, constant_mgr->GetDefiningInstruction(coordinate_0),
      query_max_including_faces, image_texel_pointer);
  image_texel_pointer->SetInOperand(1, {clamp_coord->result_id()});

  // Clamp the sample index
  if (multisampled) {
    // Get the sample count via OpImageQuerySamples
    const auto query_samples_id = TakeNextId();
    auto* query_samples = InsertInst(
        image_texel_pointer, SpvOpImageQuerySamples,
        constant_mgr->GetDefiningInstruction(component_0)->type_id(),
        query_samples_id, {{SPV_OPERAND_TYPE_ID, {image->result_id()}}});

    // Compute the maximum valid sample index: sample count minus 1.
    const auto max_samples_id = TakeNextId();
    auto* max_samples = InsertInst(image_texel_pointer, SpvOpISub,
                                   query_samples->type_id(), max_samples_id,
                                   {{SPV_OPERAND_TYPE_ID, {query_samples_id}},
                                    {SPV_OPERAND_TYPE_ID, {component_1_id}}});

    auto* clamp_samples = MakeSClampInst(
        *type_mgr, samples, constant_mgr->GetDefiningInstruction(component_0),
        max_samples, image_texel_pointer);
    image_texel_pointer->SetInOperand(2, {clamp_samples->result_id()});

  } else {
    // Just replace it with 0. Don't even check what was there before.
    image_texel_pointer->SetInOperand(2, {component_0_id});
  }

  def_use_mgr->AnalyzeInstUse(image_texel_pointer);

  return SPV_SUCCESS;
}

opt::Instruction* GraphicsRobustAccessPass::InsertInst(
    opt::Instruction* where_inst, SpvOp opcode, uint32_t type_id,
    uint32_t result_id, const Instruction::OperandList& operands) {
  module_status_.modified = true;
  auto* result = where_inst->InsertBefore(
      MakeUnique<Instruction>(context(), opcode, type_id, result_id, operands));
  context()->get_def_use_mgr()->AnalyzeInstDefUse(result);
  auto* basic_block = context()->get_instr_block(where_inst);
  context()->set_instr_block(result, basic_block);
  return result;
}

}  // namespace opt
}  // namespace spvtools