1// Copyright (c) 2015-2016 The Khronos Group Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "source/val/validation_state.h"
16
17#include <cassert>
18#include <stack>
19#include <utility>
20
21#include "source/opcode.h"
22#include "source/spirv_constant.h"
23#include "source/spirv_target_env.h"
24#include "source/val/basic_block.h"
25#include "source/val/construct.h"
26#include "source/val/function.h"
27#include "spirv-tools/libspirv.h"
28
29namespace spvtools {
30namespace val {
31namespace {
32
33bool IsInstructionInLayoutSection(ModuleLayoutSection layout, SpvOp op) {
34 // See Section 2.4
35 bool out = false;
36 // clang-format off
37 switch (layout) {
38 case kLayoutCapabilities: out = op == SpvOpCapability; break;
39 case kLayoutExtensions: out = op == SpvOpExtension; break;
40 case kLayoutExtInstImport: out = op == SpvOpExtInstImport; break;
41 case kLayoutMemoryModel: out = op == SpvOpMemoryModel; break;
42 case kLayoutEntryPoint: out = op == SpvOpEntryPoint; break;
43 case kLayoutExecutionMode:
44 out = op == SpvOpExecutionMode || op == SpvOpExecutionModeId;
45 break;
46 case kLayoutDebug1:
47 switch (op) {
48 case SpvOpSourceContinued:
49 case SpvOpSource:
50 case SpvOpSourceExtension:
51 case SpvOpString:
52 out = true;
53 break;
54 default: break;
55 }
56 break;
57 case kLayoutDebug2:
58 switch (op) {
59 case SpvOpName:
60 case SpvOpMemberName:
61 out = true;
62 break;
63 default: break;
64 }
65 break;
66 case kLayoutDebug3:
67 // Only OpModuleProcessed is allowed here.
68 out = (op == SpvOpModuleProcessed);
69 break;
70 case kLayoutAnnotations:
71 switch (op) {
72 case SpvOpDecorate:
73 case SpvOpMemberDecorate:
74 case SpvOpGroupDecorate:
75 case SpvOpGroupMemberDecorate:
76 case SpvOpDecorationGroup:
77 case SpvOpDecorateId:
78 case SpvOpDecorateStringGOOGLE:
79 case SpvOpMemberDecorateStringGOOGLE:
80 out = true;
81 break;
82 default: break;
83 }
84 break;
85 case kLayoutTypes:
86 if (spvOpcodeGeneratesType(op) || spvOpcodeIsConstant(op)) {
87 out = true;
88 break;
89 }
90 switch (op) {
91 case SpvOpTypeForwardPointer:
92 case SpvOpVariable:
93 case SpvOpLine:
94 case SpvOpNoLine:
95 case SpvOpUndef:
96 // SpvOpExtInst is only allowed here for certain extended instruction
97 // sets. This will be checked separately
98 case SpvOpExtInst:
99 out = true;
100 break;
101 default: break;
102 }
103 break;
104 case kLayoutFunctionDeclarations:
105 case kLayoutFunctionDefinitions:
106 // NOTE: These instructions should NOT be in these layout sections
107 if (spvOpcodeGeneratesType(op) || spvOpcodeIsConstant(op)) {
108 out = false;
109 break;
110 }
111 switch (op) {
112 case SpvOpCapability:
113 case SpvOpExtension:
114 case SpvOpExtInstImport:
115 case SpvOpMemoryModel:
116 case SpvOpEntryPoint:
117 case SpvOpExecutionMode:
118 case SpvOpExecutionModeId:
119 case SpvOpSourceContinued:
120 case SpvOpSource:
121 case SpvOpSourceExtension:
122 case SpvOpString:
123 case SpvOpName:
124 case SpvOpMemberName:
125 case SpvOpModuleProcessed:
126 case SpvOpDecorate:
127 case SpvOpMemberDecorate:
128 case SpvOpGroupDecorate:
129 case SpvOpGroupMemberDecorate:
130 case SpvOpDecorationGroup:
131 case SpvOpTypeForwardPointer:
132 out = false;
133 break;
134 default:
135 out = true;
136 break;
137 }
138 }
139 // clang-format on
140 return out;
141}
142
143// Counts the number of instructions and functions in the file.
144spv_result_t CountInstructions(void* user_data,
145 const spv_parsed_instruction_t* inst) {
146 ValidationState_t& _ = *(reinterpret_cast<ValidationState_t*>(user_data));
147 if (inst->opcode == SpvOpFunction) _.increment_total_functions();
148 _.increment_total_instructions();
149
150 return SPV_SUCCESS;
151}
152
153spv_result_t setHeader(void* user_data, spv_endianness_t, uint32_t,
154 uint32_t version, uint32_t generator, uint32_t id_bound,
155 uint32_t) {
156 ValidationState_t& vstate =
157 *(reinterpret_cast<ValidationState_t*>(user_data));
158 vstate.setIdBound(id_bound);
159 vstate.setGenerator(generator);
160 vstate.setVersion(version);
161
162 return SPV_SUCCESS;
163}
164
165// Add features based on SPIR-V core version number.
166void UpdateFeaturesBasedOnSpirvVersion(ValidationState_t::Feature* features,
167 uint32_t version) {
168 assert(features);
169 if (version >= SPV_SPIRV_VERSION_WORD(1, 4)) {
170 features->select_between_composites = true;
171 features->copy_memory_permits_two_memory_accesses = true;
172 features->uconvert_spec_constant_op = true;
173 features->nonwritable_var_in_function_or_private = true;
174 }
175}
176
177} // namespace
178
179ValidationState_t::ValidationState_t(const spv_const_context ctx,
180 const spv_const_validator_options opt,
181 const uint32_t* words,
182 const size_t num_words,
183 const uint32_t max_warnings)
184 : context_(ctx),
185 options_(opt),
186 words_(words),
187 num_words_(num_words),
188 unresolved_forward_ids_{},
189 operand_names_{},
190 current_layout_section_(kLayoutCapabilities),
191 module_functions_(),
192 module_capabilities_(),
193 module_extensions_(),
194 ordered_instructions_(),
195 all_definitions_(),
196 global_vars_(),
197 local_vars_(),
198 struct_nesting_depth_(),
199 struct_has_nested_blockorbufferblock_struct_(),
200 grammar_(ctx),
201 addressing_model_(SpvAddressingModelMax),
202 memory_model_(SpvMemoryModelMax),
203 pointer_size_and_alignment_(0),
204 in_function_(false),
205 num_of_warnings_(0),
206 max_num_of_warnings_(max_warnings) {
207 assert(opt && "Validator options may not be Null.");
208
209 const auto env = context_->target_env;
210
211 if (spvIsVulkanEnv(env)) {
212 // Vulkan 1.1 includes VK_KHR_relaxed_block_layout in core.
213 if (env != SPV_ENV_VULKAN_1_0) {
214 features_.env_relaxed_block_layout = true;
215 }
216 }
217
218 // Only attempt to count if we have words, otherwise let the other validation
219 // fail and generate an error.
220 if (num_words > 0) {
221 // Count the number of instructions in the binary.
222 // This parse should not produce any error messages. Hijack the context and
223 // replace the message consumer so that we do not pollute any state in input
224 // consumer.
225 spv_context_t hijacked_context = *ctx;
226 hijacked_context.consumer = [](spv_message_level_t, const char*,
227 const spv_position_t&, const char*) {};
228 spvBinaryParse(&hijacked_context, this, words, num_words, setHeader,
229 CountInstructions,
230 /* diagnostic = */ nullptr);
231 preallocateStorage();
232 }
233 UpdateFeaturesBasedOnSpirvVersion(&features_, version_);
234
235 friendly_mapper_ = spvtools::MakeUnique<spvtools::FriendlyNameMapper>(
236 context_, words_, num_words_);
237 name_mapper_ = friendly_mapper_->GetNameMapper();
238}
239
240void ValidationState_t::preallocateStorage() {
241 ordered_instructions_.reserve(total_instructions_);
242 module_functions_.reserve(total_functions_);
243}
244
245spv_result_t ValidationState_t::ForwardDeclareId(uint32_t id) {
246 unresolved_forward_ids_.insert(id);
247 return SPV_SUCCESS;
248}
249
250spv_result_t ValidationState_t::RemoveIfForwardDeclared(uint32_t id) {
251 unresolved_forward_ids_.erase(id);
252 return SPV_SUCCESS;
253}
254
255spv_result_t ValidationState_t::RegisterForwardPointer(uint32_t id) {
256 forward_pointer_ids_.insert(id);
257 return SPV_SUCCESS;
258}
259
260bool ValidationState_t::IsForwardPointer(uint32_t id) const {
261 return (forward_pointer_ids_.find(id) != forward_pointer_ids_.end());
262}
263
264void ValidationState_t::AssignNameToId(uint32_t id, std::string name) {
265 operand_names_[id] = name;
266}
267
268std::string ValidationState_t::getIdName(uint32_t id) const {
269 const std::string id_name = name_mapper_(id);
270
271 std::stringstream out;
272 out << id << "[%" << id_name << "]";
273 return out.str();
274}
275
276size_t ValidationState_t::unresolved_forward_id_count() const {
277 return unresolved_forward_ids_.size();
278}
279
280std::vector<uint32_t> ValidationState_t::UnresolvedForwardIds() const {
281 std::vector<uint32_t> out(std::begin(unresolved_forward_ids_),
282 std::end(unresolved_forward_ids_));
283 return out;
284}
285
286bool ValidationState_t::IsDefinedId(uint32_t id) const {
287 return all_definitions_.find(id) != std::end(all_definitions_);
288}
289
290const Instruction* ValidationState_t::FindDef(uint32_t id) const {
291 auto it = all_definitions_.find(id);
292 if (it == all_definitions_.end()) return nullptr;
293 return it->second;
294}
295
296Instruction* ValidationState_t::FindDef(uint32_t id) {
297 auto it = all_definitions_.find(id);
298 if (it == all_definitions_.end()) return nullptr;
299 return it->second;
300}
301
302ModuleLayoutSection ValidationState_t::current_layout_section() const {
303 return current_layout_section_;
304}
305
306void ValidationState_t::ProgressToNextLayoutSectionOrder() {
307 // Guard against going past the last element(kLayoutFunctionDefinitions)
308 if (current_layout_section_ <= kLayoutFunctionDefinitions) {
309 current_layout_section_ =
310 static_cast<ModuleLayoutSection>(current_layout_section_ + 1);
311 }
312}
313
314bool ValidationState_t::IsOpcodeInCurrentLayoutSection(SpvOp op) {
315 return IsInstructionInLayoutSection(current_layout_section_, op);
316}
317
318DiagnosticStream ValidationState_t::diag(spv_result_t error_code,
319 const Instruction* inst) {
320 if (error_code == SPV_WARNING) {
321 if (num_of_warnings_ == max_num_of_warnings_) {
322 DiagnosticStream({0, 0, 0}, context_->consumer, "", error_code)
323 << "Other warnings have been suppressed.\n";
324 }
325 if (num_of_warnings_ >= max_num_of_warnings_) {
326 return DiagnosticStream({0, 0, 0}, nullptr, "", error_code);
327 }
328 ++num_of_warnings_;
329 }
330
331 std::string disassembly;
332 if (inst) disassembly = Disassemble(*inst);
333
334 return DiagnosticStream({0, 0, inst ? inst->LineNum() : 0},
335 context_->consumer, disassembly, error_code);
336}
337
338std::vector<Function>& ValidationState_t::functions() {
339 return module_functions_;
340}
341
342Function& ValidationState_t::current_function() {
343 assert(in_function_body());
344 return module_functions_.back();
345}
346
347const Function& ValidationState_t::current_function() const {
348 assert(in_function_body());
349 return module_functions_.back();
350}
351
352const Function* ValidationState_t::function(uint32_t id) const {
353 const auto it = id_to_function_.find(id);
354 if (it == id_to_function_.end()) return nullptr;
355 return it->second;
356}
357
358Function* ValidationState_t::function(uint32_t id) {
359 auto it = id_to_function_.find(id);
360 if (it == id_to_function_.end()) return nullptr;
361 return it->second;
362}
363
364bool ValidationState_t::in_function_body() const { return in_function_; }
365
366bool ValidationState_t::in_block() const {
367 return module_functions_.empty() == false &&
368 module_functions_.back().current_block() != nullptr;
369}
370
371void ValidationState_t::RegisterCapability(SpvCapability cap) {
372 // Avoid redundant work. Otherwise the recursion could induce work
373 // quadrdatic in the capability dependency depth. (Ok, not much, but
374 // it's something.)
375 if (module_capabilities_.Contains(cap)) return;
376
377 module_capabilities_.Add(cap);
378 spv_operand_desc desc;
379 if (SPV_SUCCESS ==
380 grammar_.lookupOperand(SPV_OPERAND_TYPE_CAPABILITY, cap, &desc)) {
381 CapabilitySet(desc->numCapabilities, desc->capabilities)
382 .ForEach([this](SpvCapability c) { RegisterCapability(c); });
383 }
384
385 switch (cap) {
386 case SpvCapabilityKernel:
387 features_.group_ops_reduce_and_scans = true;
388 break;
389 case SpvCapabilityInt8:
390 features_.use_int8_type = true;
391 features_.declare_int8_type = true;
392 break;
393 case SpvCapabilityStorageBuffer8BitAccess:
394 case SpvCapabilityUniformAndStorageBuffer8BitAccess:
395 case SpvCapabilityStoragePushConstant8:
396 features_.declare_int8_type = true;
397 break;
398 case SpvCapabilityInt16:
399 features_.declare_int16_type = true;
400 break;
401 case SpvCapabilityFloat16:
402 case SpvCapabilityFloat16Buffer:
403 features_.declare_float16_type = true;
404 break;
405 case SpvCapabilityStorageUniformBufferBlock16:
406 case SpvCapabilityStorageUniform16:
407 case SpvCapabilityStoragePushConstant16:
408 case SpvCapabilityStorageInputOutput16:
409 features_.declare_int16_type = true;
410 features_.declare_float16_type = true;
411 features_.free_fp_rounding_mode = true;
412 break;
413 case SpvCapabilityVariablePointers:
414 features_.variable_pointers = true;
415 features_.variable_pointers_storage_buffer = true;
416 break;
417 case SpvCapabilityVariablePointersStorageBuffer:
418 features_.variable_pointers_storage_buffer = true;
419 break;
420 default:
421 break;
422 }
423}
424
425void ValidationState_t::RegisterExtension(Extension ext) {
426 if (module_extensions_.Contains(ext)) return;
427
428 module_extensions_.Add(ext);
429
430 switch (ext) {
431 case kSPV_AMD_gpu_shader_half_float:
432 case kSPV_AMD_gpu_shader_half_float_fetch:
433 // SPV_AMD_gpu_shader_half_float enables float16 type.
434 // https://github.com/KhronosGroup/SPIRV-Tools/issues/1375
435 features_.declare_float16_type = true;
436 break;
437 case kSPV_AMD_gpu_shader_int16:
438 // This is not yet in the extension, but it's recommended for it.
439 // See https://github.com/KhronosGroup/glslang/issues/848
440 features_.uconvert_spec_constant_op = true;
441 break;
442 case kSPV_AMD_shader_ballot:
443 // The grammar doesn't encode the fact that SPV_AMD_shader_ballot
444 // enables the use of group operations Reduce, InclusiveScan,
445 // and ExclusiveScan. Enable it manually.
446 // https://github.com/KhronosGroup/SPIRV-Tools/issues/991
447 features_.group_ops_reduce_and_scans = true;
448 break;
449 default:
450 break;
451 }
452}
453
454bool ValidationState_t::HasAnyOfCapabilities(
455 const CapabilitySet& capabilities) const {
456 return module_capabilities_.HasAnyOf(capabilities);
457}
458
459bool ValidationState_t::HasAnyOfExtensions(
460 const ExtensionSet& extensions) const {
461 return module_extensions_.HasAnyOf(extensions);
462}
463
464void ValidationState_t::set_addressing_model(SpvAddressingModel am) {
465 addressing_model_ = am;
466 switch (am) {
467 case SpvAddressingModelPhysical32:
468 pointer_size_and_alignment_ = 4;
469 break;
470 default:
471 // fall through
472 case SpvAddressingModelPhysical64:
473 case SpvAddressingModelPhysicalStorageBuffer64EXT:
474 pointer_size_and_alignment_ = 8;
475 break;
476 }
477}
478
479SpvAddressingModel ValidationState_t::addressing_model() const {
480 return addressing_model_;
481}
482
483void ValidationState_t::set_memory_model(SpvMemoryModel mm) {
484 memory_model_ = mm;
485}
486
487SpvMemoryModel ValidationState_t::memory_model() const { return memory_model_; }
488
489spv_result_t ValidationState_t::RegisterFunction(
490 uint32_t id, uint32_t ret_type_id, SpvFunctionControlMask function_control,
491 uint32_t function_type_id) {
492 assert(in_function_body() == false &&
493 "RegisterFunction can only be called when parsing the binary outside "
494 "of another function");
495 in_function_ = true;
496 module_functions_.emplace_back(id, ret_type_id, function_control,
497 function_type_id);
498 id_to_function_.emplace(id, &current_function());
499
500 // TODO(umar): validate function type and type_id
501
502 return SPV_SUCCESS;
503}
504
505spv_result_t ValidationState_t::RegisterFunctionEnd() {
506 assert(in_function_body() == true &&
507 "RegisterFunctionEnd can only be called when parsing the binary "
508 "inside of another function");
509 assert(in_block() == false &&
510 "RegisterFunctionParameter can only be called when parsing the binary "
511 "ouside of a block");
512 current_function().RegisterFunctionEnd();
513 in_function_ = false;
514 return SPV_SUCCESS;
515}
516
517Instruction* ValidationState_t::AddOrderedInstruction(
518 const spv_parsed_instruction_t* inst) {
519 ordered_instructions_.emplace_back(inst);
520 ordered_instructions_.back().SetLineNum(ordered_instructions_.size());
521 return &ordered_instructions_.back();
522}
523
524// Improves diagnostic messages by collecting names of IDs
525void ValidationState_t::RegisterDebugInstruction(const Instruction* inst) {
526 switch (inst->opcode()) {
527 case SpvOpName: {
528 const auto target = inst->GetOperandAs<uint32_t>(0);
529 const auto* str = reinterpret_cast<const char*>(inst->words().data() +
530 inst->operand(1).offset);
531 AssignNameToId(target, str);
532 break;
533 }
534 case SpvOpMemberName: {
535 const auto target = inst->GetOperandAs<uint32_t>(0);
536 const auto* str = reinterpret_cast<const char*>(inst->words().data() +
537 inst->operand(2).offset);
538 AssignNameToId(target, str);
539 break;
540 }
541 case SpvOpSourceContinued:
542 case SpvOpSource:
543 case SpvOpSourceExtension:
544 case SpvOpString:
545 case SpvOpLine:
546 case SpvOpNoLine:
547 default:
548 break;
549 }
550}
551
552void ValidationState_t::RegisterInstruction(Instruction* inst) {
553 if (inst->id()) all_definitions_.insert(std::make_pair(inst->id(), inst));
554
555 // If the instruction is using an OpTypeSampledImage as an operand, it should
556 // be recorded. The validator will ensure that all usages of an
557 // OpTypeSampledImage and its definition are in the same basic block.
558 for (uint16_t i = 0; i < inst->operands().size(); ++i) {
559 const spv_parsed_operand_t& operand = inst->operand(i);
560 if (SPV_OPERAND_TYPE_ID == operand.type) {
561 const uint32_t operand_word = inst->word(operand.offset);
562 Instruction* operand_inst = FindDef(operand_word);
563 if (operand_inst && SpvOpSampledImage == operand_inst->opcode()) {
564 RegisterSampledImageConsumer(operand_word, inst);
565 }
566 }
567 }
568}
569
570std::vector<Instruction*> ValidationState_t::getSampledImageConsumers(
571 uint32_t sampled_image_id) const {
572 std::vector<Instruction*> result;
573 auto iter = sampled_image_consumers_.find(sampled_image_id);
574 if (iter != sampled_image_consumers_.end()) {
575 result = iter->second;
576 }
577 return result;
578}
579
580void ValidationState_t::RegisterSampledImageConsumer(uint32_t sampled_image_id,
581 Instruction* consumer) {
582 sampled_image_consumers_[sampled_image_id].push_back(consumer);
583}
584
585uint32_t ValidationState_t::getIdBound() const { return id_bound_; }
586
587void ValidationState_t::setIdBound(const uint32_t bound) { id_bound_ = bound; }
588
589bool ValidationState_t::RegisterUniqueTypeDeclaration(const Instruction* inst) {
590 std::vector<uint32_t> key;
591 key.push_back(static_cast<uint32_t>(inst->opcode()));
592 for (size_t index = 0; index < inst->operands().size(); ++index) {
593 const spv_parsed_operand_t& operand = inst->operand(index);
594
595 if (operand.type == SPV_OPERAND_TYPE_RESULT_ID) continue;
596
597 const int words_begin = operand.offset;
598 const int words_end = words_begin + operand.num_words;
599 assert(words_end <= static_cast<int>(inst->words().size()));
600
601 key.insert(key.end(), inst->words().begin() + words_begin,
602 inst->words().begin() + words_end);
603 }
604
605 return unique_type_declarations_.insert(std::move(key)).second;
606}
607
608uint32_t ValidationState_t::GetTypeId(uint32_t id) const {
609 const Instruction* inst = FindDef(id);
610 return inst ? inst->type_id() : 0;
611}
612
613SpvOp ValidationState_t::GetIdOpcode(uint32_t id) const {
614 const Instruction* inst = FindDef(id);
615 return inst ? inst->opcode() : SpvOpNop;
616}
617
618uint32_t ValidationState_t::GetComponentType(uint32_t id) const {
619 const Instruction* inst = FindDef(id);
620 assert(inst);
621
622 switch (inst->opcode()) {
623 case SpvOpTypeFloat:
624 case SpvOpTypeInt:
625 case SpvOpTypeBool:
626 return id;
627
628 case SpvOpTypeVector:
629 return inst->word(2);
630
631 case SpvOpTypeMatrix:
632 return GetComponentType(inst->word(2));
633
634 case SpvOpTypeCooperativeMatrixNV:
635 return inst->word(2);
636
637 default:
638 break;
639 }
640
641 if (inst->type_id()) return GetComponentType(inst->type_id());
642
643 assert(0);
644 return 0;
645}
646
647uint32_t ValidationState_t::GetDimension(uint32_t id) const {
648 const Instruction* inst = FindDef(id);
649 assert(inst);
650
651 switch (inst->opcode()) {
652 case SpvOpTypeFloat:
653 case SpvOpTypeInt:
654 case SpvOpTypeBool:
655 return 1;
656
657 case SpvOpTypeVector:
658 case SpvOpTypeMatrix:
659 return inst->word(3);
660
661 case SpvOpTypeCooperativeMatrixNV:
662 // Actual dimension isn't known, return 0
663 return 0;
664
665 default:
666 break;
667 }
668
669 if (inst->type_id()) return GetDimension(inst->type_id());
670
671 assert(0);
672 return 0;
673}
674
675uint32_t ValidationState_t::GetBitWidth(uint32_t id) const {
676 const uint32_t component_type_id = GetComponentType(id);
677 const Instruction* inst = FindDef(component_type_id);
678 assert(inst);
679
680 if (inst->opcode() == SpvOpTypeFloat || inst->opcode() == SpvOpTypeInt)
681 return inst->word(2);
682
683 if (inst->opcode() == SpvOpTypeBool) return 1;
684
685 assert(0);
686 return 0;
687}
688
689bool ValidationState_t::IsVoidType(uint32_t id) const {
690 const Instruction* inst = FindDef(id);
691 assert(inst);
692 return inst->opcode() == SpvOpTypeVoid;
693}
694
695bool ValidationState_t::IsFloatScalarType(uint32_t id) const {
696 const Instruction* inst = FindDef(id);
697 assert(inst);
698 return inst->opcode() == SpvOpTypeFloat;
699}
700
701bool ValidationState_t::IsFloatVectorType(uint32_t id) const {
702 const Instruction* inst = FindDef(id);
703 assert(inst);
704
705 if (inst->opcode() == SpvOpTypeVector) {
706 return IsFloatScalarType(GetComponentType(id));
707 }
708
709 return false;
710}
711
712bool ValidationState_t::IsFloatScalarOrVectorType(uint32_t id) const {
713 const Instruction* inst = FindDef(id);
714 assert(inst);
715
716 if (inst->opcode() == SpvOpTypeFloat) {
717 return true;
718 }
719
720 if (inst->opcode() == SpvOpTypeVector) {
721 return IsFloatScalarType(GetComponentType(id));
722 }
723
724 return false;
725}
726
727bool ValidationState_t::IsIntScalarType(uint32_t id) const {
728 const Instruction* inst = FindDef(id);
729 assert(inst);
730 return inst->opcode() == SpvOpTypeInt;
731}
732
733bool ValidationState_t::IsIntVectorType(uint32_t id) const {
734 const Instruction* inst = FindDef(id);
735 assert(inst);
736
737 if (inst->opcode() == SpvOpTypeVector) {
738 return IsIntScalarType(GetComponentType(id));
739 }
740
741 return false;
742}
743
744bool ValidationState_t::IsIntScalarOrVectorType(uint32_t id) const {
745 const Instruction* inst = FindDef(id);
746 assert(inst);
747
748 if (inst->opcode() == SpvOpTypeInt) {
749 return true;
750 }
751
752 if (inst->opcode() == SpvOpTypeVector) {
753 return IsIntScalarType(GetComponentType(id));
754 }
755
756 return false;
757}
758
759bool ValidationState_t::IsUnsignedIntScalarType(uint32_t id) const {
760 const Instruction* inst = FindDef(id);
761 assert(inst);
762 return inst->opcode() == SpvOpTypeInt && inst->word(3) == 0;
763}
764
765bool ValidationState_t::IsUnsignedIntVectorType(uint32_t id) const {
766 const Instruction* inst = FindDef(id);
767 assert(inst);
768
769 if (inst->opcode() == SpvOpTypeVector) {
770 return IsUnsignedIntScalarType(GetComponentType(id));
771 }
772
773 return false;
774}
775
776bool ValidationState_t::IsSignedIntScalarType(uint32_t id) const {
777 const Instruction* inst = FindDef(id);
778 assert(inst);
779 return inst->opcode() == SpvOpTypeInt && inst->word(3) == 1;
780}
781
782bool ValidationState_t::IsSignedIntVectorType(uint32_t id) const {
783 const Instruction* inst = FindDef(id);
784 assert(inst);
785
786 if (inst->opcode() == SpvOpTypeVector) {
787 return IsSignedIntScalarType(GetComponentType(id));
788 }
789
790 return false;
791}
792
793bool ValidationState_t::IsBoolScalarType(uint32_t id) const {
794 const Instruction* inst = FindDef(id);
795 assert(inst);
796 return inst->opcode() == SpvOpTypeBool;
797}
798
799bool ValidationState_t::IsBoolVectorType(uint32_t id) const {
800 const Instruction* inst = FindDef(id);
801 assert(inst);
802
803 if (inst->opcode() == SpvOpTypeVector) {
804 return IsBoolScalarType(GetComponentType(id));
805 }
806
807 return false;
808}
809
810bool ValidationState_t::IsBoolScalarOrVectorType(uint32_t id) const {
811 const Instruction* inst = FindDef(id);
812 assert(inst);
813
814 if (inst->opcode() == SpvOpTypeBool) {
815 return true;
816 }
817
818 if (inst->opcode() == SpvOpTypeVector) {
819 return IsBoolScalarType(GetComponentType(id));
820 }
821
822 return false;
823}
824
825bool ValidationState_t::IsFloatMatrixType(uint32_t id) const {
826 const Instruction* inst = FindDef(id);
827 assert(inst);
828
829 if (inst->opcode() == SpvOpTypeMatrix) {
830 return IsFloatScalarType(GetComponentType(id));
831 }
832
833 return false;
834}
835
836bool ValidationState_t::GetMatrixTypeInfo(uint32_t id, uint32_t* num_rows,
837 uint32_t* num_cols,
838 uint32_t* column_type,
839 uint32_t* component_type) const {
840 if (!id) return false;
841
842 const Instruction* mat_inst = FindDef(id);
843 assert(mat_inst);
844 if (mat_inst->opcode() != SpvOpTypeMatrix) return false;
845
846 const uint32_t vec_type = mat_inst->word(2);
847 const Instruction* vec_inst = FindDef(vec_type);
848 assert(vec_inst);
849
850 if (vec_inst->opcode() != SpvOpTypeVector) {
851 assert(0);
852 return false;
853 }
854
855 *num_cols = mat_inst->word(3);
856 *num_rows = vec_inst->word(3);
857 *column_type = mat_inst->word(2);
858 *component_type = vec_inst->word(2);
859
860 return true;
861}
862
863bool ValidationState_t::GetStructMemberTypes(
864 uint32_t struct_type_id, std::vector<uint32_t>* member_types) const {
865 member_types->clear();
866 if (!struct_type_id) return false;
867
868 const Instruction* inst = FindDef(struct_type_id);
869 assert(inst);
870 if (inst->opcode() != SpvOpTypeStruct) return false;
871
872 *member_types =
873 std::vector<uint32_t>(inst->words().cbegin() + 2, inst->words().cend());
874
875 if (member_types->empty()) return false;
876
877 return true;
878}
879
880bool ValidationState_t::IsPointerType(uint32_t id) const {
881 const Instruction* inst = FindDef(id);
882 assert(inst);
883 return inst->opcode() == SpvOpTypePointer;
884}
885
886bool ValidationState_t::GetPointerTypeInfo(uint32_t id, uint32_t* data_type,
887 uint32_t* storage_class) const {
888 if (!id) return false;
889
890 const Instruction* inst = FindDef(id);
891 assert(inst);
892 if (inst->opcode() != SpvOpTypePointer) return false;
893
894 *storage_class = inst->word(2);
895 *data_type = inst->word(3);
896 return true;
897}
898
899bool ValidationState_t::IsCooperativeMatrixType(uint32_t id) const {
900 const Instruction* inst = FindDef(id);
901 assert(inst);
902 return inst->opcode() == SpvOpTypeCooperativeMatrixNV;
903}
904
905bool ValidationState_t::IsFloatCooperativeMatrixType(uint32_t id) const {
906 if (!IsCooperativeMatrixType(id)) return false;
907 return IsFloatScalarType(FindDef(id)->word(2));
908}
909
910bool ValidationState_t::IsIntCooperativeMatrixType(uint32_t id) const {
911 if (!IsCooperativeMatrixType(id)) return false;
912 return IsIntScalarType(FindDef(id)->word(2));
913}
914
915bool ValidationState_t::IsUnsignedIntCooperativeMatrixType(uint32_t id) const {
916 if (!IsCooperativeMatrixType(id)) return false;
917 return IsUnsignedIntScalarType(FindDef(id)->word(2));
918}
919
920spv_result_t ValidationState_t::CooperativeMatrixShapesMatch(
921 const Instruction* inst, uint32_t m1, uint32_t m2) {
922 const auto m1_type = FindDef(m1);
923 const auto m2_type = FindDef(m2);
924
925 if (m1_type->opcode() != SpvOpTypeCooperativeMatrixNV ||
926 m2_type->opcode() != SpvOpTypeCooperativeMatrixNV) {
927 return diag(SPV_ERROR_INVALID_DATA, inst)
928 << "Expected cooperative matrix types";
929 }
930
931 uint32_t m1_scope_id = m1_type->GetOperandAs<uint32_t>(2);
932 uint32_t m1_rows_id = m1_type->GetOperandAs<uint32_t>(3);
933 uint32_t m1_cols_id = m1_type->GetOperandAs<uint32_t>(4);
934
935 uint32_t m2_scope_id = m2_type->GetOperandAs<uint32_t>(2);
936 uint32_t m2_rows_id = m2_type->GetOperandAs<uint32_t>(3);
937 uint32_t m2_cols_id = m2_type->GetOperandAs<uint32_t>(4);
938
939 bool m1_is_int32 = false, m1_is_const_int32 = false, m2_is_int32 = false,
940 m2_is_const_int32 = false;
941 uint32_t m1_value = 0, m2_value = 0;
942
943 std::tie(m1_is_int32, m1_is_const_int32, m1_value) =
944 EvalInt32IfConst(m1_scope_id);
945 std::tie(m2_is_int32, m2_is_const_int32, m2_value) =
946 EvalInt32IfConst(m2_scope_id);
947
948 if (m1_is_const_int32 && m2_is_const_int32 && m1_value != m2_value) {
949 return diag(SPV_ERROR_INVALID_DATA, inst)
950 << "Expected scopes of Matrix and Result Type to be "
951 << "identical";
952 }
953
954 std::tie(m1_is_int32, m1_is_const_int32, m1_value) =
955 EvalInt32IfConst(m1_rows_id);
956 std::tie(m2_is_int32, m2_is_const_int32, m2_value) =
957 EvalInt32IfConst(m2_rows_id);
958
959 if (m1_is_const_int32 && m2_is_const_int32 && m1_value != m2_value) {
960 return diag(SPV_ERROR_INVALID_DATA, inst)
961 << "Expected rows of Matrix type and Result Type to be "
962 << "identical";
963 }
964
965 std::tie(m1_is_int32, m1_is_const_int32, m1_value) =
966 EvalInt32IfConst(m1_cols_id);
967 std::tie(m2_is_int32, m2_is_const_int32, m2_value) =
968 EvalInt32IfConst(m2_cols_id);
969
970 if (m1_is_const_int32 && m2_is_const_int32 && m1_value != m2_value) {
971 return diag(SPV_ERROR_INVALID_DATA, inst)
972 << "Expected columns of Matrix type and Result Type to be "
973 << "identical";
974 }
975
976 return SPV_SUCCESS;
977}
978
979uint32_t ValidationState_t::GetOperandTypeId(const Instruction* inst,
980 size_t operand_index) const {
981 return GetTypeId(inst->GetOperandAs<uint32_t>(operand_index));
982}
983
984bool ValidationState_t::GetConstantValUint64(uint32_t id, uint64_t* val) const {
985 const Instruction* inst = FindDef(id);
986 if (!inst) {
987 assert(0 && "Instruction not found");
988 return false;
989 }
990
991 if (inst->opcode() != SpvOpConstant && inst->opcode() != SpvOpSpecConstant)
992 return false;
993
994 if (!IsIntScalarType(inst->type_id())) return false;
995
996 if (inst->words().size() == 4) {
997 *val = inst->word(3);
998 } else {
999 assert(inst->words().size() == 5);
1000 *val = inst->word(3);
1001 *val |= uint64_t(inst->word(4)) << 32;
1002 }
1003 return true;
1004}
1005
1006std::tuple<bool, bool, uint32_t> ValidationState_t::EvalInt32IfConst(
1007 uint32_t id) const {
1008 const Instruction* const inst = FindDef(id);
1009 assert(inst);
1010 const uint32_t type = inst->type_id();
1011
1012 if (type == 0 || !IsIntScalarType(type) || GetBitWidth(type) != 32) {
1013 return std::make_tuple(false, false, 0);
1014 }
1015
1016 // Spec constant values cannot be evaluated so don't consider constant for
1017 // the purpose of this method.
1018 if (!spvOpcodeIsConstant(inst->opcode()) ||
1019 spvOpcodeIsSpecConstant(inst->opcode())) {
1020 return std::make_tuple(true, false, 0);
1021 }
1022
1023 if (inst->opcode() == SpvOpConstantNull) {
1024 return std::make_tuple(true, true, 0);
1025 }
1026
1027 assert(inst->words().size() == 4);
1028 return std::make_tuple(true, true, inst->word(3));
1029}
1030
1031void ValidationState_t::ComputeFunctionToEntryPointMapping() {
1032 for (const uint32_t entry_point : entry_points()) {
1033 std::stack<uint32_t> call_stack;
1034 std::set<uint32_t> visited;
1035 call_stack.push(entry_point);
1036 while (!call_stack.empty()) {
1037 const uint32_t called_func_id = call_stack.top();
1038 call_stack.pop();
1039 if (!visited.insert(called_func_id).second) continue;
1040
1041 function_to_entry_points_[called_func_id].push_back(entry_point);
1042
1043 const Function* called_func = function(called_func_id);
1044 if (called_func) {
1045 // Other checks should error out on this invalid SPIR-V.
1046 for (const uint32_t new_call : called_func->function_call_targets()) {
1047 call_stack.push(new_call);
1048 }
1049 }
1050 }
1051 }
1052}
1053
1054void ValidationState_t::ComputeRecursiveEntryPoints() {
1055 for (const Function& func : functions()) {
1056 std::stack<uint32_t> call_stack;
1057 std::set<uint32_t> visited;
1058
1059 for (const uint32_t new_call : func.function_call_targets()) {
1060 call_stack.push(new_call);
1061 }
1062
1063 while (!call_stack.empty()) {
1064 const uint32_t called_func_id = call_stack.top();
1065 call_stack.pop();
1066
1067 if (!visited.insert(called_func_id).second) continue;
1068
1069 if (called_func_id == func.id()) {
1070 for (const uint32_t entry_point :
1071 function_to_entry_points_[called_func_id])
1072 recursive_entry_points_.insert(entry_point);
1073 break;
1074 }
1075
1076 const Function* called_func = function(called_func_id);
1077 if (called_func) {
1078 // Other checks should error out on this invalid SPIR-V.
1079 for (const uint32_t new_call : called_func->function_call_targets()) {
1080 call_stack.push(new_call);
1081 }
1082 }
1083 }
1084 }
1085}
1086
1087const std::vector<uint32_t>& ValidationState_t::FunctionEntryPoints(
1088 uint32_t func) const {
1089 auto iter = function_to_entry_points_.find(func);
1090 if (iter == function_to_entry_points_.end()) {
1091 return empty_ids_;
1092 } else {
1093 return iter->second;
1094 }
1095}
1096
1097std::set<uint32_t> ValidationState_t::EntryPointReferences(uint32_t id) const {
1098 std::set<uint32_t> referenced_entry_points;
1099 const auto inst = FindDef(id);
1100 if (!inst) return referenced_entry_points;
1101
1102 std::vector<const Instruction*> stack;
1103 stack.push_back(inst);
1104 while (!stack.empty()) {
1105 const auto current_inst = stack.back();
1106 stack.pop_back();
1107
1108 if (const auto func = current_inst->function()) {
1109 // Instruction lives in a function, we can stop searching.
1110 const auto function_entry_points = FunctionEntryPoints(func->id());
1111 referenced_entry_points.insert(function_entry_points.begin(),
1112 function_entry_points.end());
1113 } else {
1114 // Instruction is in the global scope, keep searching its uses.
1115 for (auto pair : current_inst->uses()) {
1116 const auto next_inst = pair.first;
1117 stack.push_back(next_inst);
1118 }
1119 }
1120 }
1121
1122 return referenced_entry_points;
1123}
1124
1125std::string ValidationState_t::Disassemble(const Instruction& inst) const {
1126 const spv_parsed_instruction_t& c_inst(inst.c_inst());
1127 return Disassemble(c_inst.words, c_inst.num_words);
1128}
1129
1130std::string ValidationState_t::Disassemble(const uint32_t* words,
1131 uint16_t num_words) const {
1132 uint32_t disassembly_options = SPV_BINARY_TO_TEXT_OPTION_NO_HEADER |
1133 SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES;
1134
1135 return spvInstructionBinaryToText(context()->target_env, words, num_words,
1136 words_, num_words_, disassembly_options);
1137}
1138
1139bool ValidationState_t::LogicallyMatch(const Instruction* lhs,
1140 const Instruction* rhs,
1141 bool check_decorations) {
1142 if (lhs->opcode() != rhs->opcode()) {
1143 return false;
1144 }
1145
1146 if (check_decorations) {
1147 const auto& dec_a = id_decorations(lhs->id());
1148 const auto& dec_b = id_decorations(rhs->id());
1149
1150 for (const auto& dec : dec_b) {
1151 if (std::find(dec_a.begin(), dec_a.end(), dec) == dec_a.end()) {
1152 return false;
1153 }
1154 }
1155 }
1156
1157 if (lhs->opcode() == SpvOpTypeArray) {
1158 // Size operands must match.
1159 if (lhs->GetOperandAs<uint32_t>(2u) != rhs->GetOperandAs<uint32_t>(2u)) {
1160 return false;
1161 }
1162
1163 // Elements must match or logically match.
1164 const auto lhs_ele_id = lhs->GetOperandAs<uint32_t>(1u);
1165 const auto rhs_ele_id = rhs->GetOperandAs<uint32_t>(1u);
1166 if (lhs_ele_id == rhs_ele_id) {
1167 return true;
1168 }
1169
1170 const auto lhs_ele = FindDef(lhs_ele_id);
1171 const auto rhs_ele = FindDef(rhs_ele_id);
1172 if (!lhs_ele || !rhs_ele) {
1173 return false;
1174 }
1175 return LogicallyMatch(lhs_ele, rhs_ele, check_decorations);
1176 } else if (lhs->opcode() == SpvOpTypeStruct) {
1177 // Number of elements must match.
1178 if (lhs->operands().size() != rhs->operands().size()) {
1179 return false;
1180 }
1181
1182 for (size_t i = 1u; i < lhs->operands().size(); ++i) {
1183 const auto lhs_ele_id = lhs->GetOperandAs<uint32_t>(i);
1184 const auto rhs_ele_id = rhs->GetOperandAs<uint32_t>(i);
1185 // Elements must match or logically match.
1186 if (lhs_ele_id == rhs_ele_id) {
1187 continue;
1188 }
1189
1190 const auto lhs_ele = FindDef(lhs_ele_id);
1191 const auto rhs_ele = FindDef(rhs_ele_id);
1192 if (!lhs_ele || !rhs_ele) {
1193 return false;
1194 }
1195
1196 if (!LogicallyMatch(lhs_ele, rhs_ele, check_decorations)) {
1197 return false;
1198 }
1199 }
1200
1201 // All checks passed.
1202 return true;
1203 }
1204
1205 // No other opcodes are acceptable at this point. Arrays and structs are
1206 // caught above and if they're elements are not arrays or structs they are
1207 // required to match exactly.
1208 return false;
1209}
1210
1211const Instruction* ValidationState_t::TracePointer(
1212 const Instruction* inst) const {
1213 auto base_ptr = inst;
1214 while (base_ptr->opcode() == SpvOpAccessChain ||
1215 base_ptr->opcode() == SpvOpInBoundsAccessChain ||
1216 base_ptr->opcode() == SpvOpPtrAccessChain ||
1217 base_ptr->opcode() == SpvOpInBoundsPtrAccessChain ||
1218 base_ptr->opcode() == SpvOpCopyObject) {
1219 base_ptr = FindDef(base_ptr->GetOperandAs<uint32_t>(2u));
1220 }
1221 return base_ptr;
1222}
1223
1224bool ValidationState_t::ContainsSizedIntOrFloatType(uint32_t id, SpvOp type,
1225 uint32_t width) const {
1226 if (type != SpvOpTypeInt && type != SpvOpTypeFloat) return false;
1227
1228 const auto inst = FindDef(id);
1229 if (!inst) return false;
1230
1231 if (inst->opcode() == type) {
1232 return inst->GetOperandAs<uint32_t>(1u) == width;
1233 }
1234
1235 switch (inst->opcode()) {
1236 case SpvOpTypeArray:
1237 case SpvOpTypeRuntimeArray:
1238 case SpvOpTypeVector:
1239 case SpvOpTypeMatrix:
1240 case SpvOpTypeImage:
1241 case SpvOpTypeSampledImage:
1242 case SpvOpTypeCooperativeMatrixNV:
1243 return ContainsSizedIntOrFloatType(inst->GetOperandAs<uint32_t>(1u), type,
1244 width);
1245 case SpvOpTypePointer:
1246 if (IsForwardPointer(id)) return false;
1247 return ContainsSizedIntOrFloatType(inst->GetOperandAs<uint32_t>(2u), type,
1248 width);
1249 case SpvOpTypeFunction:
1250 case SpvOpTypeStruct: {
1251 for (uint32_t i = 1; i < inst->operands().size(); ++i) {
1252 if (ContainsSizedIntOrFloatType(inst->GetOperandAs<uint32_t>(i), type,
1253 width))
1254 return true;
1255 }
1256 return false;
1257 }
1258 default:
1259 return false;
1260 }
1261}
1262
1263bool ValidationState_t::ContainsLimitedUseIntOrFloatType(uint32_t id) const {
1264 if ((!HasCapability(SpvCapabilityInt16) &&
1265 ContainsSizedIntOrFloatType(id, SpvOpTypeInt, 16)) ||
1266 (!HasCapability(SpvCapabilityInt8) &&
1267 ContainsSizedIntOrFloatType(id, SpvOpTypeInt, 8)) ||
1268 (!HasCapability(SpvCapabilityFloat16) &&
1269 ContainsSizedIntOrFloatType(id, SpvOpTypeFloat, 16))) {
1270 return true;
1271 }
1272 return false;
1273}
1274
1275bool ValidationState_t::IsValidStorageClass(
1276 SpvStorageClass storage_class) const {
1277 if (spvIsWebGPUEnv(context()->target_env)) {
1278 switch (storage_class) {
1279 case SpvStorageClassUniformConstant:
1280 case SpvStorageClassUniform:
1281 case SpvStorageClassStorageBuffer:
1282 case SpvStorageClassInput:
1283 case SpvStorageClassOutput:
1284 case SpvStorageClassImage:
1285 case SpvStorageClassWorkgroup:
1286 case SpvStorageClassPrivate:
1287 case SpvStorageClassFunction:
1288 return true;
1289 default:
1290 return false;
1291 }
1292 }
1293
1294 if (spvIsVulkanEnv(context()->target_env)) {
1295 switch (storage_class) {
1296 case SpvStorageClassUniformConstant:
1297 case SpvStorageClassUniform:
1298 case SpvStorageClassStorageBuffer:
1299 case SpvStorageClassInput:
1300 case SpvStorageClassOutput:
1301 case SpvStorageClassImage:
1302 case SpvStorageClassWorkgroup:
1303 case SpvStorageClassPrivate:
1304 case SpvStorageClassFunction:
1305 case SpvStorageClassPushConstant:
1306 case SpvStorageClassPhysicalStorageBuffer:
1307 case SpvStorageClassRayPayloadNV:
1308 case SpvStorageClassIncomingRayPayloadNV:
1309 case SpvStorageClassHitAttributeNV:
1310 case SpvStorageClassCallableDataNV:
1311 case SpvStorageClassIncomingCallableDataNV:
1312 case SpvStorageClassShaderRecordBufferNV:
1313 return true;
1314 default:
1315 return false;
1316 }
1317 }
1318
1319 return true;
1320}
1321
1322} // namespace val
1323} // namespace spvtools
1324