1 | // Copyright (c) 2017 The Khronos Group Inc. |
2 | // Copyright (c) 2017 Valve Corporation |
3 | // Copyright (c) 2017 LunarG Inc. |
4 | // |
5 | // Licensed under the Apache License, Version 2.0 (the "License"); |
6 | // you may not use this file except in compliance with the License. |
7 | // You may obtain a copy of the License at |
8 | // |
9 | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | // |
11 | // Unless required by applicable law or agreed to in writing, software |
12 | // distributed under the License is distributed on an "AS IS" BASIS, |
13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
14 | // See the License for the specific language governing permissions and |
15 | // limitations under the License. |
16 | |
17 | #include "source/opt/mem_pass.h" |
18 | |
19 | #include <memory> |
20 | #include <set> |
21 | #include <vector> |
22 | |
23 | #include "source/cfa.h" |
24 | #include "source/opt/basic_block.h" |
25 | #include "source/opt/dominator_analysis.h" |
26 | #include "source/opt/ir_context.h" |
27 | #include "source/opt/iterator.h" |
28 | |
29 | namespace spvtools { |
30 | namespace opt { |
31 | |
namespace {

// In-operand index of the source id of an OpCopyObject.
constexpr uint32_t kCopyObjectOperandInIdx = 0;
// In-operand index of the Storage Class word of an OpTypePointer.
constexpr uint32_t kTypePointerStorageClassInIdx = 0;
// In-operand index of the pointee Type id of an OpTypePointer.
constexpr uint32_t kTypePointerTypeIdInIdx = 1;

}  // namespace
39 | |
40 | bool MemPass::IsBaseTargetType(const Instruction* typeInst) const { |
41 | switch (typeInst->opcode()) { |
42 | case SpvOpTypeInt: |
43 | case SpvOpTypeFloat: |
44 | case SpvOpTypeBool: |
45 | case SpvOpTypeVector: |
46 | case SpvOpTypeMatrix: |
47 | case SpvOpTypeImage: |
48 | case SpvOpTypeSampler: |
49 | case SpvOpTypeSampledImage: |
50 | case SpvOpTypePointer: |
51 | return true; |
52 | default: |
53 | break; |
54 | } |
55 | return false; |
56 | } |
57 | |
58 | bool MemPass::IsTargetType(const Instruction* typeInst) const { |
59 | if (IsBaseTargetType(typeInst)) return true; |
60 | if (typeInst->opcode() == SpvOpTypeArray) { |
61 | if (!IsTargetType( |
62 | get_def_use_mgr()->GetDef(typeInst->GetSingleWordOperand(1)))) { |
63 | return false; |
64 | } |
65 | return true; |
66 | } |
67 | if (typeInst->opcode() != SpvOpTypeStruct) return false; |
68 | // All struct members must be math type |
69 | return typeInst->WhileEachInId([this](const uint32_t* tid) { |
70 | Instruction* compTypeInst = get_def_use_mgr()->GetDef(*tid); |
71 | if (!IsTargetType(compTypeInst)) return false; |
72 | return true; |
73 | }); |
74 | } |
75 | |
76 | bool MemPass::IsNonPtrAccessChain(const SpvOp opcode) const { |
77 | return opcode == SpvOpAccessChain || opcode == SpvOpInBoundsAccessChain; |
78 | } |
79 | |
80 | bool MemPass::IsPtr(uint32_t ptrId) { |
81 | uint32_t varId = ptrId; |
82 | Instruction* ptrInst = get_def_use_mgr()->GetDef(varId); |
83 | while (ptrInst->opcode() == SpvOpCopyObject) { |
84 | varId = ptrInst->GetSingleWordInOperand(kCopyObjectOperandInIdx); |
85 | ptrInst = get_def_use_mgr()->GetDef(varId); |
86 | } |
87 | const SpvOp op = ptrInst->opcode(); |
88 | if (op == SpvOpVariable || IsNonPtrAccessChain(op)) return true; |
89 | if (op != SpvOpFunctionParameter) return false; |
90 | const uint32_t varTypeId = ptrInst->type_id(); |
91 | const Instruction* varTypeInst = get_def_use_mgr()->GetDef(varTypeId); |
92 | return varTypeInst->opcode() == SpvOpTypePointer; |
93 | } |
94 | |
95 | Instruction* MemPass::GetPtr(uint32_t ptrId, uint32_t* varId) { |
96 | *varId = ptrId; |
97 | Instruction* ptrInst = get_def_use_mgr()->GetDef(*varId); |
98 | Instruction* varInst; |
99 | |
100 | if (ptrInst->opcode() != SpvOpVariable && |
101 | ptrInst->opcode() != SpvOpFunctionParameter) { |
102 | varInst = ptrInst->GetBaseAddress(); |
103 | } else { |
104 | varInst = ptrInst; |
105 | } |
106 | if (varInst->opcode() == SpvOpVariable) { |
107 | *varId = varInst->result_id(); |
108 | } else { |
109 | *varId = 0; |
110 | } |
111 | |
112 | while (ptrInst->opcode() == SpvOpCopyObject) { |
113 | uint32_t temp = ptrInst->GetSingleWordInOperand(0); |
114 | ptrInst = get_def_use_mgr()->GetDef(temp); |
115 | } |
116 | |
117 | return ptrInst; |
118 | } |
119 | |
120 | Instruction* MemPass::GetPtr(Instruction* ip, uint32_t* varId) { |
121 | assert(ip->opcode() == SpvOpStore || ip->opcode() == SpvOpLoad || |
122 | ip->opcode() == SpvOpImageTexelPointer || ip->IsAtomicWithLoad()); |
123 | |
124 | // All of these opcode place the pointer in position 0. |
125 | const uint32_t ptrId = ip->GetSingleWordInOperand(0); |
126 | return GetPtr(ptrId, varId); |
127 | } |
128 | |
129 | bool MemPass::HasOnlyNamesAndDecorates(uint32_t id) const { |
130 | return get_def_use_mgr()->WhileEachUser(id, [this](Instruction* user) { |
131 | SpvOp op = user->opcode(); |
132 | if (op != SpvOpName && !IsNonTypeDecorate(op)) { |
133 | return false; |
134 | } |
135 | return true; |
136 | }); |
137 | } |
138 | |
// Kills the instructions of |bp| by delegating to BasicBlock::KillAllInsts,
// forwarding |killLabel| unchanged.
// NOTE(review): |killLabel| presumably selects whether the block's OpLabel is
// killed too — confirm against BasicBlock::KillAllInsts.
void MemPass::KillAllInsts(BasicBlock* bp, bool killLabel) {
  bp->KillAllInsts(killLabel);
}
142 | |
143 | bool MemPass::HasLoads(uint32_t varId) const { |
144 | return !get_def_use_mgr()->WhileEachUser(varId, [this](Instruction* user) { |
145 | SpvOp op = user->opcode(); |
146 | // TODO(): The following is slightly conservative. Could be |
147 | // better handling of non-store/name. |
148 | if (IsNonPtrAccessChain(op) || op == SpvOpCopyObject) { |
149 | if (HasLoads(user->result_id())) { |
150 | return false; |
151 | } |
152 | } else if (op != SpvOpStore && op != SpvOpName && !IsNonTypeDecorate(op)) { |
153 | return false; |
154 | } |
155 | return true; |
156 | }); |
157 | } |
158 | |
159 | bool MemPass::IsLiveVar(uint32_t varId) const { |
160 | const Instruction* varInst = get_def_use_mgr()->GetDef(varId); |
161 | // assume live if not a variable eg. function parameter |
162 | if (varInst->opcode() != SpvOpVariable) return true; |
163 | // non-function scope vars are live |
164 | const uint32_t varTypeId = varInst->type_id(); |
165 | const Instruction* varTypeInst = get_def_use_mgr()->GetDef(varTypeId); |
166 | if (varTypeInst->GetSingleWordInOperand(kTypePointerStorageClassInIdx) != |
167 | SpvStorageClassFunction) |
168 | return true; |
169 | // test if variable is loaded from |
170 | return HasLoads(varId); |
171 | } |
172 | |
173 | void MemPass::AddStores(uint32_t ptr_id, std::queue<Instruction*>* insts) { |
174 | get_def_use_mgr()->ForEachUser(ptr_id, [this, insts](Instruction* user) { |
175 | SpvOp op = user->opcode(); |
176 | if (IsNonPtrAccessChain(op)) { |
177 | AddStores(user->result_id(), insts); |
178 | } else if (op == SpvOpStore) { |
179 | insts->push(user); |
180 | } |
181 | }); |
182 | } |
183 | |
// Kills |inst| and then, transitively, instructions that become dead because
// of it, using a worklist:
//  - the definition of each in-operand of a killed instruction is queued once
//    that id has only OpName/non-type-decoration users left and IRContext
//    classifies the definition as a combinator;
//  - when a killed instruction is an OpLoad and its variable is no longer live
//    (IsLiveVar), every store to that variable is queued (AddStores).
// OpLabel instructions are never killed. If |call_back| is non-empty it is
// invoked on each instruction immediately before that instruction is killed.
void MemPass::DCEInst(Instruction* inst,
                      const std::function<void(Instruction*)>& call_back) {
  std::queue<Instruction*> deadInsts;
  deadInsts.push(inst);
  while (!deadInsts.empty()) {
    Instruction* di = deadInsts.front();
    // Don't delete labels
    if (di->opcode() == SpvOpLabel) {
      deadInsts.pop();
      continue;
    }
    // Remember operands; they must be captured before |di| is killed.
    std::set<uint32_t> ids;
    di->ForEachInId([&ids](uint32_t* iid) { ids.insert(*iid); });
    uint32_t varId = 0;
    // Remember variable if dead load (also must happen before the kill).
    if (di->opcode() == SpvOpLoad) (void)GetPtr(di, &varId);
    if (call_back) {
      call_back(di);
    }
    context()->KillInst(di);
    // For all operands with no remaining uses, add their instruction
    // to the dead instruction queue.
    for (auto id : ids)
      if (HasOnlyNamesAndDecorates(id)) {
        Instruction* odi = get_def_use_mgr()->GetDef(id);
        if (context()->IsCombinatorInstruction(odi)) deadInsts.push(odi);
      }
    // if a load was deleted and it was the variable's
    // last load, add all its stores to dead queue
    if (varId != 0 && !IsLiveVar(varId)) AddStores(varId, &deadInsts);
    // |di| may already be freed by KillInst; popping the pointer does not
    // dereference it.
    deadInsts.pop();
  }
}
218 | |
219 | MemPass::MemPass() {} |
220 | |
221 | bool MemPass::HasOnlySupportedRefs(uint32_t varId) { |
222 | return get_def_use_mgr()->WhileEachUser(varId, [this](Instruction* user) { |
223 | SpvOp op = user->opcode(); |
224 | if (op != SpvOpStore && op != SpvOpLoad && op != SpvOpName && |
225 | !IsNonTypeDecorate(op)) { |
226 | return false; |
227 | } |
228 | return true; |
229 | }); |
230 | } |
231 | |
232 | uint32_t MemPass::Type2Undef(uint32_t type_id) { |
233 | const auto uitr = type2undefs_.find(type_id); |
234 | if (uitr != type2undefs_.end()) return uitr->second; |
235 | const uint32_t undefId = TakeNextId(); |
236 | if (undefId == 0) { |
237 | return 0; |
238 | } |
239 | |
240 | std::unique_ptr<Instruction> undef_inst( |
241 | new Instruction(context(), SpvOpUndef, type_id, undefId, {})); |
242 | get_def_use_mgr()->AnalyzeInstDefUse(&*undef_inst); |
243 | get_module()->AddGlobalValue(std::move(undef_inst)); |
244 | type2undefs_[type_id] = undefId; |
245 | return undefId; |
246 | } |
247 | |
248 | bool MemPass::IsTargetVar(uint32_t varId) { |
249 | if (varId == 0) { |
250 | return false; |
251 | } |
252 | |
253 | if (seen_non_target_vars_.find(varId) != seen_non_target_vars_.end()) |
254 | return false; |
255 | if (seen_target_vars_.find(varId) != seen_target_vars_.end()) return true; |
256 | const Instruction* varInst = get_def_use_mgr()->GetDef(varId); |
257 | if (varInst->opcode() != SpvOpVariable) return false; |
258 | const uint32_t varTypeId = varInst->type_id(); |
259 | const Instruction* varTypeInst = get_def_use_mgr()->GetDef(varTypeId); |
260 | if (varTypeInst->GetSingleWordInOperand(kTypePointerStorageClassInIdx) != |
261 | SpvStorageClassFunction) { |
262 | seen_non_target_vars_.insert(varId); |
263 | return false; |
264 | } |
265 | const uint32_t varPteTypeId = |
266 | varTypeInst->GetSingleWordInOperand(kTypePointerTypeIdInIdx); |
267 | Instruction* varPteTypeInst = get_def_use_mgr()->GetDef(varPteTypeId); |
268 | if (!IsTargetType(varPteTypeInst)) { |
269 | seen_non_target_vars_.insert(varId); |
270 | return false; |
271 | } |
272 | seen_target_vars_.insert(varId); |
273 | return true; |
274 | } |
275 | |
276 | // Remove all |phi| operands coming from unreachable blocks (i.e., blocks not in |
277 | // |reachable_blocks|). There are two types of removal that this function can |
278 | // perform: |
279 | // |
280 | // 1- Any operand that comes directly from an unreachable block is completely |
281 | // removed. Since the block is unreachable, the edge between the unreachable |
282 | // block and the block holding |phi| has been removed. |
283 | // |
284 | // 2- Any operand that comes via a live block and was defined at an unreachable |
285 | // block gets its value replaced with an OpUndef value. Since the argument |
286 | // was generated in an unreachable block, it no longer exists, so it cannot |
287 | // be referenced. However, since the value does not reach |phi| directly |
288 | // from the unreachable block, the operand cannot be removed from |phi|. |
289 | // Therefore, we replace the argument value with OpUndef. |
290 | // |
291 | // For example, in the switch() below, assume that we want to remove the |
292 | // argument with value %11 coming from block %41. |
293 | // |
294 | // [ ... ] |
295 | // %41 = OpLabel <--- Unreachable block |
296 | // %11 = OpLoad %int %y |
297 | // [ ... ] |
298 | // OpSelectionMerge %16 None |
299 | // OpSwitch %12 %16 10 %13 13 %14 18 %15 |
300 | // %13 = OpLabel |
301 | // OpBranch %16 |
302 | // %14 = OpLabel |
303 | // OpStore %outparm %int_14 |
304 | // OpBranch %16 |
305 | // %15 = OpLabel |
306 | // OpStore %outparm %int_15 |
307 | // OpBranch %16 |
308 | // %16 = OpLabel |
309 | // %30 = OpPhi %int %11 %41 %int_42 %13 %11 %14 %11 %15 |
310 | // |
311 | // Since %41 is now an unreachable block, the first operand of |phi| needs to |
312 | // be removed completely. But the operands (%11 %14) and (%11 %15) cannot be |
313 | // removed because %14 and %15 are reachable blocks. Since %11 no longer exist, |
314 | // in those arguments, we replace all references to %11 with an OpUndef value. |
315 | // This results in |phi| looking like: |
316 | // |
317 | // %50 = OpUndef %int |
318 | // [ ... ] |
319 | // %30 = OpPhi %int %int_42 %13 %50 %14 %50 %15 |
320 | void MemPass::RemovePhiOperands( |
321 | Instruction* phi, const std::unordered_set<BasicBlock*>& reachable_blocks) { |
322 | std::vector<Operand> keep_operands; |
323 | uint32_t type_id = 0; |
324 | // The id of an undefined value we've generated. |
325 | uint32_t undef_id = 0; |
326 | |
327 | // Traverse all the operands in |phi|. Build the new operand vector by adding |
328 | // all the original operands from |phi| except the unwanted ones. |
329 | for (uint32_t i = 0; i < phi->NumOperands();) { |
330 | if (i < 2) { |
331 | // The first two arguments are always preserved. |
332 | keep_operands.push_back(phi->GetOperand(i)); |
333 | ++i; |
334 | continue; |
335 | } |
336 | |
337 | // The remaining Phi arguments come in pairs. Index 'i' contains the |
338 | // variable id, index 'i + 1' is the originating block id. |
339 | assert(i % 2 == 0 && i < phi->NumOperands() - 1 && |
340 | "malformed Phi arguments" ); |
341 | |
342 | BasicBlock* in_block = cfg()->block(phi->GetSingleWordOperand(i + 1)); |
343 | if (reachable_blocks.find(in_block) == reachable_blocks.end()) { |
344 | // If the incoming block is unreachable, remove both operands as this |
345 | // means that the |phi| has lost an incoming edge. |
346 | i += 2; |
347 | continue; |
348 | } |
349 | |
350 | // In all other cases, the operand must be kept but may need to be changed. |
351 | uint32_t arg_id = phi->GetSingleWordOperand(i); |
352 | Instruction* arg_def_instr = get_def_use_mgr()->GetDef(arg_id); |
353 | BasicBlock* def_block = context()->get_instr_block(arg_def_instr); |
354 | if (def_block && |
355 | reachable_blocks.find(def_block) == reachable_blocks.end()) { |
356 | // If the current |phi| argument was defined in an unreachable block, it |
357 | // means that this |phi| argument is no longer defined. Replace it with |
358 | // |undef_id|. |
359 | if (!undef_id) { |
360 | type_id = arg_def_instr->type_id(); |
361 | undef_id = Type2Undef(type_id); |
362 | } |
363 | keep_operands.push_back( |
364 | Operand(spv_operand_type_t::SPV_OPERAND_TYPE_ID, {undef_id})); |
365 | } else { |
366 | // Otherwise, the argument comes from a reachable block or from no block |
367 | // at all (meaning that it was defined in the global section of the |
368 | // program). In both cases, keep the argument intact. |
369 | keep_operands.push_back(phi->GetOperand(i)); |
370 | } |
371 | |
372 | keep_operands.push_back(phi->GetOperand(i + 1)); |
373 | |
374 | i += 2; |
375 | } |
376 | |
377 | context()->ForgetUses(phi); |
378 | phi->ReplaceOperands(keep_operands); |
379 | context()->AnalyzeUses(phi); |
380 | } |
381 | |
382 | void MemPass::RemoveBlock(Function::iterator* bi) { |
383 | auto& rm_block = **bi; |
384 | |
385 | // Remove instructions from the block. |
386 | rm_block.ForEachInst([&rm_block, this](Instruction* inst) { |
387 | // Note that we do not kill the block label instruction here. The label |
388 | // instruction is needed to identify the block, which is needed by the |
389 | // removal of phi operands. |
390 | if (inst != rm_block.GetLabelInst()) { |
391 | context()->KillInst(inst); |
392 | } |
393 | }); |
394 | |
395 | // Remove the label instruction last. |
396 | auto label = rm_block.GetLabelInst(); |
397 | context()->KillInst(label); |
398 | |
399 | *bi = bi->Erase(); |
400 | } |
401 | |
402 | bool MemPass::RemoveUnreachableBlocks(Function* func) { |
403 | bool modified = false; |
404 | |
405 | // Mark reachable all blocks reachable from the function's entry block. |
406 | std::unordered_set<BasicBlock*> reachable_blocks; |
407 | std::unordered_set<BasicBlock*> visited_blocks; |
408 | std::queue<BasicBlock*> worklist; |
409 | reachable_blocks.insert(func->entry().get()); |
410 | |
411 | // Initially mark the function entry point as reachable. |
412 | worklist.push(func->entry().get()); |
413 | |
414 | auto mark_reachable = [&reachable_blocks, &visited_blocks, &worklist, |
415 | this](uint32_t label_id) { |
416 | auto successor = cfg()->block(label_id); |
417 | if (visited_blocks.count(successor) == 0) { |
418 | reachable_blocks.insert(successor); |
419 | worklist.push(successor); |
420 | visited_blocks.insert(successor); |
421 | } |
422 | }; |
423 | |
424 | // Transitively mark all blocks reachable from the entry as reachable. |
425 | while (!worklist.empty()) { |
426 | BasicBlock* block = worklist.front(); |
427 | worklist.pop(); |
428 | |
429 | // All the successors of a live block are also live. |
430 | static_cast<const BasicBlock*>(block)->ForEachSuccessorLabel( |
431 | mark_reachable); |
432 | |
433 | // All the Merge and ContinueTarget blocks of a live block are also live. |
434 | block->ForMergeAndContinueLabel(mark_reachable); |
435 | } |
436 | |
437 | // Update operands of Phi nodes that reference unreachable blocks. |
438 | for (auto& block : *func) { |
439 | // If the block is about to be removed, don't bother updating its |
440 | // Phi instructions. |
441 | if (reachable_blocks.count(&block) == 0) { |
442 | continue; |
443 | } |
444 | |
445 | // If the block is reachable and has Phi instructions, remove all |
446 | // operands from its Phi instructions that reference unreachable blocks. |
447 | // If the block has no Phi instructions, this is a no-op. |
448 | block.ForEachPhiInst([&reachable_blocks, this](Instruction* phi) { |
449 | RemovePhiOperands(phi, reachable_blocks); |
450 | }); |
451 | } |
452 | |
453 | // Erase unreachable blocks. |
454 | for (auto ebi = func->begin(); ebi != func->end();) { |
455 | if (reachable_blocks.count(&*ebi) == 0) { |
456 | RemoveBlock(&ebi); |
457 | modified = true; |
458 | } else { |
459 | ++ebi; |
460 | } |
461 | } |
462 | |
463 | return modified; |
464 | } |
465 | |
466 | bool MemPass::CFGCleanup(Function* func) { |
467 | bool modified = false; |
468 | modified |= RemoveUnreachableBlocks(func); |
469 | return modified; |
470 | } |
471 | |
472 | void MemPass::CollectTargetVars(Function* func) { |
473 | seen_target_vars_.clear(); |
474 | seen_non_target_vars_.clear(); |
475 | type2undefs_.clear(); |
476 | |
477 | // Collect target (and non-) variable sets. Remove variables with |
478 | // non-load/store refs from target variable set |
479 | for (auto& blk : *func) { |
480 | for (auto& inst : blk) { |
481 | switch (inst.opcode()) { |
482 | case SpvOpStore: |
483 | case SpvOpLoad: { |
484 | uint32_t varId; |
485 | (void)GetPtr(&inst, &varId); |
486 | if (!IsTargetVar(varId)) break; |
487 | if (HasOnlySupportedRefs(varId)) break; |
488 | seen_non_target_vars_.insert(varId); |
489 | seen_target_vars_.erase(varId); |
490 | } break; |
491 | default: |
492 | break; |
493 | } |
494 | } |
495 | } |
496 | } |
497 | |
498 | } // namespace opt |
499 | } // namespace spvtools |
500 | |