1// Copyright (c) 2019 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "fix_storage_class.h"
16
17#include <set>
18
19#include "source/opt/instruction.h"
20#include "source/opt/ir_context.h"
21
22namespace spvtools {
23namespace opt {
24
25Pass::Status FixStorageClass::Process() {
26 bool modified = false;
27
28 get_module()->ForEachInst([this, &modified](Instruction* inst) {
29 if (inst->opcode() == SpvOpVariable) {
30 std::set<uint32_t> seen;
31 std::vector<std::pair<Instruction*, uint32_t>> uses;
32 get_def_use_mgr()->ForEachUse(inst,
33 [&uses](Instruction* use, uint32_t op_idx) {
34 uses.push_back({use, op_idx});
35 });
36
37 for (auto& use : uses) {
38 modified |= PropagateStorageClass(
39 use.first,
40 static_cast<SpvStorageClass>(inst->GetSingleWordInOperand(0)),
41 &seen);
42 assert(seen.empty() && "Seen was not properly reset.");
43 modified |=
44 PropagateType(use.first, inst->type_id(), use.second, &seen);
45 assert(seen.empty() && "Seen was not properly reset.");
46 }
47 }
48 });
49 return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
50}
51
52bool FixStorageClass::PropagateStorageClass(Instruction* inst,
53 SpvStorageClass storage_class,
54 std::set<uint32_t>* seen) {
55 if (!IsPointerResultType(inst)) {
56 return false;
57 }
58
59 if (IsPointerToStorageClass(inst, storage_class)) {
60 if (inst->opcode() == SpvOpPhi) {
61 if (!seen->insert(inst->result_id()).second) {
62 return false;
63 }
64 }
65
66 bool modified = false;
67 std::vector<Instruction*> uses;
68 get_def_use_mgr()->ForEachUser(
69 inst, [&uses](Instruction* use) { uses.push_back(use); });
70 for (Instruction* use : uses) {
71 modified |= PropagateStorageClass(use, storage_class, seen);
72 }
73
74 if (inst->opcode() == SpvOpPhi) {
75 seen->erase(inst->result_id());
76 }
77 return modified;
78 }
79
80 switch (inst->opcode()) {
81 case SpvOpAccessChain:
82 case SpvOpPtrAccessChain:
83 case SpvOpInBoundsAccessChain:
84 case SpvOpCopyObject:
85 case SpvOpPhi:
86 case SpvOpSelect:
87 FixInstructionStorageClass(inst, storage_class, seen);
88 return true;
89 case SpvOpFunctionCall:
90 // We cannot be sure of the actual connection between the storage class
91 // of the parameter and the storage class of the result, so we should not
92 // do anything. If the result type needs to be fixed, the function call
93 // should be inlined.
94 return false;
95 case SpvOpImageTexelPointer:
96 case SpvOpLoad:
97 case SpvOpStore:
98 case SpvOpCopyMemory:
99 case SpvOpCopyMemorySized:
100 case SpvOpVariable:
101 case SpvOpBitcast:
102 // Nothing to change for these opcode. The result type is the same
103 // regardless of the storage class of the operand.
104 return false;
105 default:
106 assert(false &&
107 "Not expecting instruction to have a pointer result type.");
108 return false;
109 }
110}
111
112void FixStorageClass::FixInstructionStorageClass(Instruction* inst,
113 SpvStorageClass storage_class,
114 std::set<uint32_t>* seen) {
115 assert(IsPointerResultType(inst) &&
116 "The result type of the instruction must be a pointer.");
117
118 ChangeResultStorageClass(inst, storage_class);
119
120 std::vector<Instruction*> uses;
121 get_def_use_mgr()->ForEachUser(
122 inst, [&uses](Instruction* use) { uses.push_back(use); });
123 for (Instruction* use : uses) {
124 PropagateStorageClass(use, storage_class, seen);
125 }
126}
127
128void FixStorageClass::ChangeResultStorageClass(
129 Instruction* inst, SpvStorageClass storage_class) const {
130 analysis::TypeManager* type_mgr = context()->get_type_mgr();
131 Instruction* result_type_inst = get_def_use_mgr()->GetDef(inst->type_id());
132 assert(result_type_inst->opcode() == SpvOpTypePointer);
133 uint32_t pointee_type_id = result_type_inst->GetSingleWordInOperand(1);
134 uint32_t new_result_type_id =
135 type_mgr->FindPointerToType(pointee_type_id, storage_class);
136 inst->SetResultType(new_result_type_id);
137 context()->UpdateDefUse(inst);
138}
139
140bool FixStorageClass::IsPointerResultType(Instruction* inst) {
141 if (inst->type_id() == 0) {
142 return false;
143 }
144 const analysis::Type* ret_type =
145 context()->get_type_mgr()->GetType(inst->type_id());
146 return ret_type->AsPointer() != nullptr;
147}
148
149bool FixStorageClass::IsPointerToStorageClass(Instruction* inst,
150 SpvStorageClass storage_class) {
151 analysis::TypeManager* type_mgr = context()->get_type_mgr();
152 analysis::Type* pType = type_mgr->GetType(inst->type_id());
153 const analysis::Pointer* result_type = pType->AsPointer();
154
155 if (result_type == nullptr) {
156 return false;
157 }
158
159 return (result_type->storage_class() == storage_class);
160}
161
162bool FixStorageClass::ChangeResultType(Instruction* inst,
163 uint32_t new_type_id) {
164 if (inst->type_id() == new_type_id) {
165 return false;
166 }
167
168 context()->ForgetUses(inst);
169 inst->SetResultType(new_type_id);
170 context()->AnalyzeUses(inst);
171 return true;
172}
173
174bool FixStorageClass::PropagateType(Instruction* inst, uint32_t type_id,
175 uint32_t op_idx, std::set<uint32_t>* seen) {
176 assert(type_id != 0 && "Not given a valid type in PropagateType");
177 bool modified = false;
178
179 // If the type of operand |op_idx| forces the result type of |inst| to a
180 // particular type, then we want find that type.
181 uint32_t new_type_id = 0;
182 switch (inst->opcode()) {
183 case SpvOpAccessChain:
184 case SpvOpPtrAccessChain:
185 case SpvOpInBoundsAccessChain:
186 case SpvOpInBoundsPtrAccessChain:
187 if (op_idx == 2) {
188 new_type_id = WalkAccessChainType(inst, type_id);
189 }
190 break;
191 case SpvOpCopyObject:
192 new_type_id = type_id;
193 break;
194 case SpvOpPhi:
195 if (seen->insert(inst->result_id()).second) {
196 new_type_id = type_id;
197 }
198 break;
199 case SpvOpSelect:
200 if (op_idx > 2) {
201 new_type_id = type_id;
202 }
203 break;
204 case SpvOpFunctionCall:
205 // We cannot be sure of the actual connection between the type
206 // of the parameter and the type of the result, so we should not
207 // do anything. If the result type needs to be fixed, the function call
208 // should be inlined.
209 return false;
210 case SpvOpLoad: {
211 Instruction* type_inst = get_def_use_mgr()->GetDef(type_id);
212 new_type_id = type_inst->GetSingleWordInOperand(1);
213 break;
214 }
215 case SpvOpStore: {
216 uint32_t obj_id = inst->GetSingleWordInOperand(1);
217 Instruction* obj_inst = get_def_use_mgr()->GetDef(obj_id);
218 uint32_t obj_type_id = obj_inst->type_id();
219
220 uint32_t ptr_id = inst->GetSingleWordInOperand(0);
221 Instruction* ptr_inst = get_def_use_mgr()->GetDef(ptr_id);
222 uint32_t pointee_type_id = GetPointeeTypeId(ptr_inst);
223
224 if (obj_type_id != pointee_type_id) {
225 uint32_t copy_id = GenerateCopy(obj_inst, pointee_type_id, inst);
226 inst->SetInOperand(1, {copy_id});
227 context()->UpdateDefUse(inst);
228 }
229 } break;
230 case SpvOpCopyMemory:
231 case SpvOpCopyMemorySized:
232 // TODO: May need to expand the copy as we do with the stores.
233 break;
234 case SpvOpCompositeConstruct:
235 case SpvOpCompositeExtract:
236 case SpvOpCompositeInsert:
237 // TODO: DXC does not seem to generate code that will require changes to
238 // these opcode. The can be implemented when they come up.
239 break;
240 case SpvOpImageTexelPointer:
241 case SpvOpBitcast:
242 // Nothing to change for these opcode. The result type is the same
243 // regardless of the type of the operand.
244 return false;
245 default:
246 // I expect the remaining instructions to act on types that are guaranteed
247 // to be unique, so no change will be necessary.
248 break;
249 }
250
251 // If the operand forces the result type, then make sure the result type
252 // matches, and update the uses of |inst|. We do not have to check the uses
253 // of |inst| in the result type is not forced because we are only looking for
254 // issue that come from mismatches between function formal and actual
255 // parameters after the function has been inlined. These parameters are
256 // pointers. Once the type no longer depends on the type of the parameter,
257 // then the types should have be correct.
258 if (new_type_id != 0) {
259 modified = ChangeResultType(inst, new_type_id);
260
261 std::vector<std::pair<Instruction*, uint32_t>> uses;
262 get_def_use_mgr()->ForEachUse(inst,
263 [&uses](Instruction* use, uint32_t idx) {
264 uses.push_back({use, idx});
265 });
266
267 for (auto& use : uses) {
268 PropagateType(use.first, new_type_id, use.second, seen);
269 }
270
271 if (inst->opcode() == SpvOpPhi) {
272 seen->erase(inst->result_id());
273 }
274 }
275 return modified;
276}
277
278uint32_t FixStorageClass::WalkAccessChainType(Instruction* inst, uint32_t id) {
279 uint32_t start_idx = 0;
280 switch (inst->opcode()) {
281 case SpvOpAccessChain:
282 case SpvOpInBoundsAccessChain:
283 start_idx = 1;
284 break;
285 case SpvOpPtrAccessChain:
286 case SpvOpInBoundsPtrAccessChain:
287 start_idx = 2;
288 break;
289 default:
290 assert(false);
291 break;
292 }
293
294 Instruction* orig_type_inst = get_def_use_mgr()->GetDef(id);
295 assert(orig_type_inst->opcode() == SpvOpTypePointer);
296 id = orig_type_inst->GetSingleWordInOperand(1);
297
298 for (uint32_t i = start_idx; i < inst->NumInOperands(); ++i) {
299 Instruction* type_inst = get_def_use_mgr()->GetDef(id);
300 switch (type_inst->opcode()) {
301 case SpvOpTypeArray:
302 case SpvOpTypeRuntimeArray:
303 case SpvOpTypeMatrix:
304 case SpvOpTypeVector:
305 id = type_inst->GetSingleWordInOperand(0);
306 break;
307 case SpvOpTypeStruct: {
308 const analysis::Constant* index_const =
309 context()->get_constant_mgr()->FindDeclaredConstant(
310 inst->GetSingleWordInOperand(i));
311 uint32_t index = index_const->GetU32();
312 id = type_inst->GetSingleWordInOperand(index);
313 break;
314 }
315 default:
316 break;
317 }
318 assert(id != 0 &&
319 "Tried to extract from an object where it cannot be done.");
320 }
321
322 return context()->get_type_mgr()->FindPointerToType(
323 id,
324 static_cast<SpvStorageClass>(orig_type_inst->GetSingleWordInOperand(0)));
325}
326
327// namespace opt
328
329} // namespace opt
330} // namespace spvtools
331