1// Copyright (c) 2019 Google LLC.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "source/opt/split_invalid_unreachable_pass.h"
16
17#include "source/opt/ir_builder.h"
18#include "source/opt/ir_context.h"
19
20namespace spvtools {
21namespace opt {
22
23Pass::Status SplitInvalidUnreachablePass::Process() {
24 bool changed = false;
25 std::unordered_set<uint32_t> entry_points;
26 for (auto entry_point : context()->module()->entry_points()) {
27 entry_points.insert(entry_point.GetSingleWordOperand(1));
28 }
29
30 for (auto func = context()->module()->begin();
31 func != context()->module()->end(); ++func) {
32 if (entry_points.find(func->result_id()) == entry_points.end()) continue;
33 std::unordered_set<uint32_t> continue_targets;
34 std::unordered_set<uint32_t> merge_blocks;
35 std::unordered_set<BasicBlock*> unreachable_blocks;
36 for (auto block = func->begin(); block != func->end(); ++block) {
37 unreachable_blocks.insert(&*block);
38 uint32_t continue_target = block->ContinueBlockIdIfAny();
39 if (continue_target != 0) continue_targets.insert(continue_target);
40 uint32_t merge_block = block->MergeBlockIdIfAny();
41 if (merge_block != 0) merge_blocks.insert(merge_block);
42 }
43
44 cfg()->ForEachBlockInPostOrder(
45 func->entry().get(), [&unreachable_blocks](BasicBlock* inner_block) {
46 unreachable_blocks.erase(inner_block);
47 });
48
49 for (auto unreachable : unreachable_blocks) {
50 uint32_t block_id = unreachable->id();
51 if (continue_targets.find(block_id) == continue_targets.end() ||
52 merge_blocks.find(block_id) == merge_blocks.end()) {
53 continue;
54 }
55
56 std::vector<std::tuple<Instruction*, uint32_t>> usages;
57 context()->get_def_use_mgr()->ForEachUse(
58 unreachable->GetLabelInst(),
59 [&usages](Instruction* use, uint32_t idx) {
60 if ((use->opcode() == SpvOpLoopMerge && idx == 0) ||
61 use->opcode() == SpvOpSelectionMerge) {
62 usages.push_back(std::make_pair(use, idx));
63 }
64 });
65
66 for (auto usage : usages) {
67 Instruction* use;
68 uint32_t idx;
69 std::tie(use, idx) = usage;
70 uint32_t new_id = context()->TakeNextId();
71 std::unique_ptr<Instruction> new_label(
72 new Instruction(context(), SpvOpLabel, 0, new_id, {}));
73 get_def_use_mgr()->AnalyzeInstDefUse(new_label.get());
74 std::unique_ptr<BasicBlock> new_block(
75 new BasicBlock(std::move(new_label)));
76 auto* block_ptr = new_block.get();
77 InstructionBuilder builder(context(), new_block.get(),
78 IRContext::kAnalysisDefUse |
79 IRContext::kAnalysisInstrToBlockMapping);
80 builder.AddUnreachable();
81 cfg()->RegisterBlock(block_ptr);
82 (&*func)->InsertBasicBlockBefore(std::move(new_block), unreachable);
83 use->SetInOperand(0, {new_id});
84 get_def_use_mgr()->UpdateDefUse(use);
85 cfg()->AddEdges(block_ptr);
86 changed = true;
87 }
88 }
89 }
90
91 return changed ? Status::SuccessWithChange : Status::SuccessWithoutChange;
92}
93
94} // namespace opt
95} // namespace spvtools
96