1 | // Copyright (c) 2019 Google LLC. |
2 | // |
3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | // you may not use this file except in compliance with the License. |
5 | // You may obtain a copy of the License at |
6 | // |
7 | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | // |
9 | // Unless required by applicable law or agreed to in writing, software |
10 | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | // See the License for the specific language governing permissions and |
13 | // limitations under the License. |
14 | |
15 | #include "source/opt/split_invalid_unreachable_pass.h" |
16 | |
17 | #include "source/opt/ir_builder.h" |
18 | #include "source/opt/ir_context.h" |
19 | |
20 | namespace spvtools { |
21 | namespace opt { |
22 | |
23 | Pass::Status SplitInvalidUnreachablePass::Process() { |
24 | bool changed = false; |
25 | std::unordered_set<uint32_t> entry_points; |
26 | for (auto entry_point : context()->module()->entry_points()) { |
27 | entry_points.insert(entry_point.GetSingleWordOperand(1)); |
28 | } |
29 | |
30 | for (auto func = context()->module()->begin(); |
31 | func != context()->module()->end(); ++func) { |
32 | if (entry_points.find(func->result_id()) == entry_points.end()) continue; |
33 | std::unordered_set<uint32_t> continue_targets; |
34 | std::unordered_set<uint32_t> merge_blocks; |
35 | std::unordered_set<BasicBlock*> unreachable_blocks; |
36 | for (auto block = func->begin(); block != func->end(); ++block) { |
37 | unreachable_blocks.insert(&*block); |
38 | uint32_t continue_target = block->ContinueBlockIdIfAny(); |
39 | if (continue_target != 0) continue_targets.insert(continue_target); |
40 | uint32_t merge_block = block->MergeBlockIdIfAny(); |
41 | if (merge_block != 0) merge_blocks.insert(merge_block); |
42 | } |
43 | |
44 | cfg()->ForEachBlockInPostOrder( |
45 | func->entry().get(), [&unreachable_blocks](BasicBlock* inner_block) { |
46 | unreachable_blocks.erase(inner_block); |
47 | }); |
48 | |
49 | for (auto unreachable : unreachable_blocks) { |
50 | uint32_t block_id = unreachable->id(); |
51 | if (continue_targets.find(block_id) == continue_targets.end() || |
52 | merge_blocks.find(block_id) == merge_blocks.end()) { |
53 | continue; |
54 | } |
55 | |
56 | std::vector<std::tuple<Instruction*, uint32_t>> usages; |
57 | context()->get_def_use_mgr()->ForEachUse( |
58 | unreachable->GetLabelInst(), |
59 | [&usages](Instruction* use, uint32_t idx) { |
60 | if ((use->opcode() == SpvOpLoopMerge && idx == 0) || |
61 | use->opcode() == SpvOpSelectionMerge) { |
62 | usages.push_back(std::make_pair(use, idx)); |
63 | } |
64 | }); |
65 | |
66 | for (auto usage : usages) { |
67 | Instruction* use; |
68 | uint32_t idx; |
69 | std::tie(use, idx) = usage; |
70 | uint32_t new_id = context()->TakeNextId(); |
71 | std::unique_ptr<Instruction> new_label( |
72 | new Instruction(context(), SpvOpLabel, 0, new_id, {})); |
73 | get_def_use_mgr()->AnalyzeInstDefUse(new_label.get()); |
74 | std::unique_ptr<BasicBlock> new_block( |
75 | new BasicBlock(std::move(new_label))); |
76 | auto* block_ptr = new_block.get(); |
77 | InstructionBuilder builder(context(), new_block.get(), |
78 | IRContext::kAnalysisDefUse | |
79 | IRContext::kAnalysisInstrToBlockMapping); |
80 | builder.AddUnreachable(); |
81 | cfg()->RegisterBlock(block_ptr); |
82 | (&*func)->InsertBasicBlockBefore(std::move(new_block), unreachable); |
83 | use->SetInOperand(0, {new_id}); |
84 | get_def_use_mgr()->UpdateDefUse(use); |
85 | cfg()->AddEdges(block_ptr); |
86 | changed = true; |
87 | } |
88 | } |
89 | } |
90 | |
91 | return changed ? Status::SuccessWithChange : Status::SuccessWithoutChange; |
92 | } |
93 | |
94 | } // namespace opt |
95 | } // namespace spvtools |
96 | |