// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "source/opt/cfg.h"

#include <memory>
#include <utility>

#include "source/cfa.h"
#include "source/opt/ir_builder.h"
#include "source/opt/ir_context.h"
#include "source/opt/module.h"

namespace spvtools {
namespace opt {
namespace {

using cbb_ptr = const opt::BasicBlock*;

// Universal Limit of ResultID + 1
const int kMaxResultId = 0x400000;

}  // namespace

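// Builds the CFG for |module| by registering every basic block of every
// function.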
CFG::CFG(Module* module)
    : module_(module),
      pseudo_entry_block_(std::unique_ptr<Instruction>(
          new Instruction(module->context(), SpvOpLabel, 0, 0, {}))),
      pseudo_exit_block_(std::unique_ptr<Instruction>(new Instruction(
          module->context(), SpvOpLabel, 0, kMaxResultId, {}))) {
  for (auto& fn : *module) {
    for (auto& blk : fn) {
      RegisterBlock(&blk);
    }
  }
}

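// Registers |blk| as a predecessor of each of its successors, and makes sure
// |blk| itself gets an entry in the predecessor map even if it has no
// predecessors.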
void CFG::AddEdges(BasicBlock* blk) {
  uint32_t blk_id = blk->id();
  // Force the creation of an entry, since not all basic blocks have
  // predecessors (such as the entry blocks and some unreachable blocks).
  label2preds_[blk_id];
  const auto* const_blk = blk;
  const_blk->ForEachSuccessorLabel(
      [blk_id, this](const uint32_t succ_id) { AddEdge(blk_id, succ_id); });
}

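// Rebuilds the predecessor list of |blk_id|, keeping only the predecessors
// that still branch to it.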
void CFG::RemoveNonExistingEdges(uint32_t blk_id) {
  std::vector<uint32_t> updated_pred_list;
  for (uint32_t id : preds(blk_id)) {
    const BasicBlock* pred_blk = block(id);
    bool has_branch = false;
    pred_blk->ForEachSuccessorLabel([&has_branch, blk_id](uint32_t succ) {
      if (succ == blk_id) {
        has_branch = true;
      }
    });
    if (has_branch) updated_pred_list.push_back(id);
  }

  label2preds_.at(blk_id) = std::move(updated_pred_list);
}

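// Fills |order| with the blocks of |func| reachable from |root|, in reverse
// post-order of the structured successor graph. Requires structured control
// flow (the Shader capability).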
void CFG::ComputeStructuredOrder(Function* func, BasicBlock* root,
                                 std::list<BasicBlock*>* order) {
  assert(module_->context()->get_feature_mgr()->HasCapability(
             SpvCapabilityShader) &&
         "This only works on structured control flow");

  // Compute structured successors and do DFS.
  ComputeStructuredSuccessors(func);
  auto ignore_block = [](cbb_ptr) {};
  auto ignore_edge = [](cbb_ptr, cbb_ptr) {};
  auto get_structured_successors = [this](const BasicBlock* b) {
    return &(block2structured_succs_[b]);
  };

  // TODO(greg-lunarg): Get rid of const_cast by moving const out of the
  // cfa.h prototypes and into the invoking code.
  auto post_order = [&](cbb_ptr b) {
    order->push_front(const_cast<BasicBlock*>(b));
  };
  CFA<BasicBlock>::DepthFirstTraversal(root, get_structured_successors,
                                       ignore_block, post_order, ignore_edge);
}

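// Applies |f| to every block reachable from |bb| in post-order, skipping the
// pseudo entry and pseudo exit blocks.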
void CFG::ForEachBlockInPostOrder(BasicBlock* bb,
                                  const std::function<void(BasicBlock*)>& f) {
  std::vector<BasicBlock*> po;
  std::unordered_set<BasicBlock*> seen;
  ComputePostOrderTraversal(bb, &po, &seen);

  for (BasicBlock* current_bb : po) {
    if (!IsPseudoExitBlock(current_bb) && !IsPseudoEntryBlock(current_bb)) {
      f(current_bb);
    }
  }
}

void CFG::ForEachBlockInReversePostOrder(
    BasicBlock* bb, const std::function<void(BasicBlock*)>& f) {
  WhileEachBlockInReversePostOrder(bb, [f](BasicBlock* b) {
    f(b);
    return true;
  });
}

bool CFG::WhileEachBlockInReversePostOrder(
    BasicBlock* bb, const std::function<bool(BasicBlock*)>& f) {
  std::vector<BasicBlock*> po;
  std::unordered_set<BasicBlock*> seen;
  ComputePostOrderTraversal(bb, &po, &seen);

  for (auto current_bb = po.rbegin(); current_bb != po.rend(); ++current_bb) {
    if (!IsPseudoExitBlock(*current_bb) && !IsPseudoEntryBlock(*current_bb)) {
      if (!f(*current_bb)) {
        return false;
      }
    }
  }
  return true;
}

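// Computes the structured successor map for |func|: blocks without
// predecessors become successors of the pseudo entry block, and a header's
// merge block (and continue block, if any) is added before its regular
// successors.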
void CFG::ComputeStructuredSuccessors(Function* func) {
  block2structured_succs_.clear();
  for (auto& blk : *func) {
    // If no predecessors in function, make successor to pseudo entry.
    if (label2preds_[blk.id()].size() == 0)
      block2structured_succs_[&pseudo_entry_block_].push_back(&blk);

    // If header, make merge block first successor and continue block second
    // successor if there is one.
    uint32_t mbid = blk.MergeBlockIdIfAny();
    if (mbid != 0) {
      block2structured_succs_[&blk].push_back(block(mbid));
      uint32_t cbid = blk.ContinueBlockIdIfAny();
      if (cbid != 0) {
        block2structured_succs_[&blk].push_back(block(cbid));
      }
    }

    // Add true successors.
    const auto& const_blk = blk;
    const_blk.ForEachSuccessorLabel([&blk, this](const uint32_t sbid) {
      block2structured_succs_[&blk].push_back(block(sbid));
    });
  }
}

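// Iterative depth-first search rooted at |bb|: appends blocks to |order| in
// post-order, using |seen| to record the blocks already visited.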
void CFG::ComputePostOrderTraversal(BasicBlock* bb,
                                    std::vector<BasicBlock*>* order,
                                    std::unordered_set<BasicBlock*>* seen) {
  std::vector<BasicBlock*> stack;
  stack.push_back(bb);
  while (!stack.empty()) {
    bb = stack.back();
    seen->insert(bb);
    static_cast<const BasicBlock*>(bb)->WhileEachSuccessorLabel(
        [&seen, &stack, this](const uint32_t sbid) {
          BasicBlock* succ_bb = id2block_[sbid];
          if (!seen->count(succ_bb)) {
            stack.push_back(succ_bb);
            return false;
          }
          return true;
        });
    if (stack.back() == bb) {
      order->push_back(bb);
      stack.pop_back();
    }
  }
}

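// Splits the loop header |bb| in two: |bb| becomes the loop preheader, and a
// new block holding the merge instruction and the terminator becomes the new
// header. The original OpPhi instructions move to the new header, where they
// merge the value coming from the latch with the value coming from the
// preheader; if the preheader has several predecessors, a new OpPhi is created
// there to merge their values. Returns the new header, or nullptr if a fresh
// id could not be allocated.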
BasicBlock* CFG::SplitLoopHeader(BasicBlock* bb) {
  assert(bb->GetLoopMergeInst() && "Expecting bb to be the header of a loop.");

  Function* fn = bb->GetParent();
  IRContext* context = module_->context();

  // Get the new header id up front. If we are out of ids, then we cannot split
  // the loop.
  uint32_t new_header_id = context->TakeNextId();
  if (new_header_id == 0) {
    return nullptr;
  }

  // Find the insertion point for the new bb.
  Function::iterator header_it = std::find_if(
      fn->begin(), fn->end(),
      [bb](BasicBlock& block_in_func) { return &block_in_func == bb; });
  assert(header_it != fn->end());

  const std::vector<uint32_t>& pred = preds(bb->id());
  // Find the back edge.
  BasicBlock* latch_block = nullptr;
  Function::iterator latch_block_iter = header_it;
  while (++latch_block_iter != fn->end()) {
    // If blocks are in the proper order, then the only branch that appears
    // after the header is the latch.
    if (std::find(pred.begin(), pred.end(), latch_block_iter->id()) !=
        pred.end()) {
      break;
    }
  }
  assert(latch_block_iter != fn->end() && "Could not find the latch.");
  latch_block = &*latch_block_iter;

  RemoveSuccessorEdges(bb);

  // Create the new header basic block.
  // Leave the phi instructions behind.
  auto iter = bb->begin();
  while (iter->opcode() == SpvOpPhi) {
    ++iter;
  }

  BasicBlock* new_header = bb->SplitBasicBlock(context, new_header_id, iter);
  context->AnalyzeDefUse(new_header->GetLabelInst());

  // Update the CFG.
  RegisterBlock(new_header);

  // Update block mappings.
  context->set_instr_block(new_header->GetLabelInst(), new_header);
  new_header->ForEachInst([new_header, context](Instruction* inst) {
    context->set_instr_block(inst, new_header);
  });

  // Adjust the OpPhi instructions as needed.
  bb->ForEachPhiInst([latch_block, bb, new_header, context](Instruction* phi) {
    std::vector<uint32_t> preheader_phi_ops;
    std::vector<Operand> header_phi_ops;

    // Identify where the inputs of the original OpPhi belong: the header or
    // the preheader.
    for (uint32_t i = 0; i < phi->NumInOperands(); i += 2) {
      uint32_t def_id = phi->GetSingleWordInOperand(i);
      uint32_t branch_id = phi->GetSingleWordInOperand(i + 1);
      if (branch_id == latch_block->id()) {
        header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {def_id}});
        header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {branch_id}});
      } else {
        preheader_phi_ops.push_back(def_id);
        preheader_phi_ops.push_back(branch_id);
      }
    }

    // Create a phi instruction in the preheader if and only if
    // preheader_phi_ops has more than one pair.
    if (preheader_phi_ops.size() > 2) {
      InstructionBuilder builder(
          context, &*bb->begin(),
          IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);

      Instruction* new_phi = builder.AddPhi(phi->type_id(), preheader_phi_ops);

      // Use the new phi's result as the value coming from the preheader in the
      // header's OpPhi.
      header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {new_phi->result_id()}});
      header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {bb->id()}});
    } else {
      // An OpPhi with a single entry is just a copy. In this case, use the
      // incoming value directly in the new header's OpPhi.
      header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {preheader_phi_ops[0]}});
      header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {bb->id()}});
    }

    phi->RemoveFromList();
    std::unique_ptr<Instruction> phi_owner(phi);
    phi->SetInOperands(std::move(header_phi_ops));
    new_header->begin()->InsertBefore(std::move(phi_owner));
    context->set_instr_block(phi, new_header);
    context->AnalyzeUses(phi);
  });

285
286 // Add a branch to the new header.
287 InstructionBuilder branch_builder(
288 context, bb,
289 IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
290 bb->AddInstruction(
291 MakeUnique<Instruction>(context, SpvOpBranch, 0, 0,
292 std::initializer_list<Operand>{
293 {SPV_OPERAND_TYPE_ID, {new_header->id()}}}));
294 context->AnalyzeUses(bb->terminator());
295 context->set_instr_block(bb->terminator(), bb);
296 label2preds_[new_header->id()].push_back(bb->id());
297
298 // Update the latch to branch to the new header.
299 latch_block->ForEachSuccessorLabel([bb, new_header_id](uint32_t* id) {
300 if (*id == bb->id()) {
301 *id = new_header_id;
302 }
303 });
304 Instruction* latch_branch = latch_block->terminator();
305 context->AnalyzeUses(latch_branch);
306 label2preds_[new_header->id()].push_back(latch_block->id());
307
308 auto& block_preds = label2preds_[bb->id()];
309 auto latch_pos =
310 std::find(block_preds.begin(), block_preds.end(), latch_block->id());
311 assert(latch_pos != block_preds.end() && "The cfg was invalid.");
312 block_preds.erase(latch_pos);
313
314 // Update the loop descriptors
315 if (context->AreAnalysesValid(IRContext::kAnalysisLoopAnalysis)) {
316 LoopDescriptor* loop_desc = context->GetLoopDescriptor(bb->GetParent());
317 Loop* loop = (*loop_desc)[bb->id()];
318
319 loop->AddBasicBlock(new_header_id);
320 loop->SetHeaderBlock(new_header);
321 loop_desc->SetBasicBlockToLoop(new_header_id, loop);
322
323 loop->RemoveBasicBlock(bb->id());
324 loop->SetPreHeaderBlock(bb);
325
326 Loop* parent_loop = loop->GetParent();
327 if (parent_loop != nullptr) {
328 parent_loop->AddBasicBlock(bb->id());
329 loop_desc->SetBasicBlockToLoop(bb->id(), parent_loop);
330 } else {
331 loop_desc->SetBasicBlockToLoop(bb->id(), nullptr);
332 }
333 }
334 return new_header;
335}

}  // namespace opt
}  // namespace spvtools