1// Copyright (c) 2019 Google LLC.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "source/opt/generate_webgpu_initializers_pass.h"
16#include "source/opt/ir_context.h"
17
18namespace spvtools {
19namespace opt {
20
// Shorthand for an iterator over an InstructionList.
// NOTE(review): this alias is not referenced anywhere in this file; it looks
// like a candidate for removal — confirm nothing else relies on it.
using inst_iterator = InstructionList::iterator;
22
23namespace {
24
25bool NeedsWebGPUInitializer(Instruction* inst) {
26 if (inst->opcode() != SpvOpVariable) return false;
27
28 auto storage_class = inst->GetSingleWordOperand(2);
29 if (storage_class != SpvStorageClassOutput &&
30 storage_class != SpvStorageClassPrivate &&
31 storage_class != SpvStorageClassFunction) {
32 return false;
33 }
34
35 if (inst->NumOperands() > 3) return false;
36
37 return true;
38}
39
40} // namespace
41
42Pass::Status GenerateWebGPUInitializersPass::Process() {
43 auto* module = context()->module();
44 bool changed = false;
45
46 // Handle global/module scoped variables
47 for (auto iter = module->types_values_begin();
48 iter != module->types_values_end(); ++iter) {
49 Instruction* inst = &(*iter);
50
51 if (inst->opcode() == SpvOpConstantNull) {
52 null_constant_type_map_[inst->type_id()] = inst;
53 seen_null_constants_.insert(inst);
54 continue;
55 }
56
57 if (!NeedsWebGPUInitializer(inst)) continue;
58
59 changed = true;
60
61 auto* constant_inst = GetNullConstantForVariable(inst);
62 if (!constant_inst) return Status::Failure;
63
64 if (seen_null_constants_.find(constant_inst) ==
65 seen_null_constants_.end()) {
66 constant_inst->InsertBefore(inst);
67 null_constant_type_map_[inst->type_id()] = inst;
68 seen_null_constants_.insert(inst);
69 }
70 AddNullInitializerToVariable(constant_inst, inst);
71 }
72
73 // Handle local/function scoped variables
74 for (auto func = module->begin(); func != module->end(); ++func) {
75 auto block = func->entry().get();
76 for (auto iter = block->begin();
77 iter != block->end() && iter->opcode() == SpvOpVariable; ++iter) {
78 Instruction* inst = &(*iter);
79 if (!NeedsWebGPUInitializer(inst)) continue;
80
81 changed = true;
82 auto* constant_inst = GetNullConstantForVariable(inst);
83 if (!constant_inst) return Status::Failure;
84
85 AddNullInitializerToVariable(constant_inst, inst);
86 }
87 }
88
89 return changed ? Status::SuccessWithChange : Status::SuccessWithoutChange;
90}
91
92Instruction* GenerateWebGPUInitializersPass::GetNullConstantForVariable(
93 Instruction* variable_inst) {
94 auto constant_mgr = context()->get_constant_mgr();
95 auto* def_use_mgr = get_def_use_mgr();
96
97 auto* ptr_inst = def_use_mgr->GetDef(variable_inst->type_id());
98 auto type_id = ptr_inst->GetInOperand(1).words[0];
99 if (null_constant_type_map_.find(type_id) == null_constant_type_map_.end()) {
100 auto* constant_type = context()->get_type_mgr()->GetType(type_id);
101 auto* constant = constant_mgr->GetConstant(constant_type, {});
102 return constant_mgr->GetDefiningInstruction(constant, type_id);
103 } else {
104 return null_constant_type_map_[type_id];
105 }
106}
107
108void GenerateWebGPUInitializersPass::AddNullInitializerToVariable(
109 Instruction* constant_inst, Instruction* variable_inst) {
110 auto constant_id = constant_inst->result_id();
111 variable_inst->AddOperand(Operand(SPV_OPERAND_TYPE_ID, {constant_id}));
112 get_def_use_mgr()->AnalyzeInstUse(variable_inst);
113}
114
115} // namespace opt
116} // namespace spvtools
117