1// Copyright (c) 2018 Google LLC.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "source/opt/licm_pass.h"
16
17#include <queue>
18#include <utility>
19
20#include "source/opt/module.h"
21#include "source/opt/pass.h"
22
23namespace spvtools {
24namespace opt {
25
26Pass::Status LICMPass::Process() { return ProcessIRContext(); }
27
28Pass::Status LICMPass::ProcessIRContext() {
29 Status status = Status::SuccessWithoutChange;
30 Module* module = get_module();
31
32 // Process each function in the module
33 for (auto func = module->begin();
34 func != module->end() && status != Status::Failure; ++func) {
35 status = CombineStatus(status, ProcessFunction(&*func));
36 }
37 return status;
38}
39
40Pass::Status LICMPass::ProcessFunction(Function* f) {
41 Status status = Status::SuccessWithoutChange;
42 LoopDescriptor* loop_descriptor = context()->GetLoopDescriptor(f);
43
44 // Process each loop in the function
45 for (auto it = loop_descriptor->begin();
46 it != loop_descriptor->end() && status != Status::Failure; ++it) {
47 Loop& loop = *it;
48 // Ignore nested loops, as we will process them in order in ProcessLoop
49 if (loop.IsNested()) {
50 continue;
51 }
52 status = CombineStatus(status, ProcessLoop(&loop, f));
53 }
54 return status;
55}
56
57Pass::Status LICMPass::ProcessLoop(Loop* loop, Function* f) {
58 Status status = Status::SuccessWithoutChange;
59
60 // Process all nested loops first
61 for (auto nl = loop->begin(); nl != loop->end() && status != Status::Failure;
62 ++nl) {
63 Loop* nested_loop = *nl;
64 status = CombineStatus(status, ProcessLoop(nested_loop, f));
65 }
66
67 std::vector<BasicBlock*> loop_bbs{};
68 status = CombineStatus(
69 status,
70 AnalyseAndHoistFromBB(loop, f, loop->GetHeaderBlock(), &loop_bbs));
71
72 for (size_t i = 0; i < loop_bbs.size() && status != Status::Failure; ++i) {
73 BasicBlock* bb = loop_bbs[i];
74 // do not delete the element
75 status =
76 CombineStatus(status, AnalyseAndHoistFromBB(loop, f, bb, &loop_bbs));
77 }
78
79 return status;
80}
81
82Pass::Status LICMPass::AnalyseAndHoistFromBB(
83 Loop* loop, Function* f, BasicBlock* bb,
84 std::vector<BasicBlock*>* loop_bbs) {
85 bool modified = false;
86 std::function<bool(Instruction*)> hoist_inst =
87 [this, &loop, &modified](Instruction* inst) {
88 if (loop->ShouldHoistInstruction(this->context(), inst)) {
89 if (!HoistInstruction(loop, inst)) {
90 return false;
91 }
92 modified = true;
93 }
94 return true;
95 };
96
97 if (IsImmediatelyContainedInLoop(loop, f, bb)) {
98 if (!bb->WhileEachInst(hoist_inst, false)) {
99 return Status::Failure;
100 }
101 }
102
103 DominatorAnalysis* dom_analysis = context()->GetDominatorAnalysis(f);
104 DominatorTree& dom_tree = dom_analysis->GetDomTree();
105
106 for (DominatorTreeNode* child_dom_tree_node : *dom_tree.GetTreeNode(bb)) {
107 if (loop->IsInsideLoop(child_dom_tree_node->bb_)) {
108 loop_bbs->push_back(child_dom_tree_node->bb_);
109 }
110 }
111
112 return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
113}
114
115bool LICMPass::IsImmediatelyContainedInLoop(Loop* loop, Function* f,
116 BasicBlock* bb) {
117 LoopDescriptor* loop_descriptor = context()->GetLoopDescriptor(f);
118 return loop == (*loop_descriptor)[bb->id()];
119}
120
121bool LICMPass::HoistInstruction(Loop* loop, Instruction* inst) {
122 // TODO(1841): Handle failure to create pre-header.
123 BasicBlock* pre_header_bb = loop->GetOrCreatePreHeaderBlock();
124 if (!pre_header_bb) {
125 return false;
126 }
127 Instruction* insertion_point = &*pre_header_bb->tail();
128 Instruction* previous_node = insertion_point->PreviousNode();
129 if (previous_node && (previous_node->opcode() == SpvOpLoopMerge ||
130 previous_node->opcode() == SpvOpSelectionMerge)) {
131 insertion_point = previous_node;
132 }
133
134 inst->InsertBefore(insertion_point);
135 context()->set_instr_block(inst, pre_header_bb);
136 return true;
137}
138
139} // namespace opt
140} // namespace spvtools
141