1// Copyright (c) 2018 Google LLC.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "source/opt/struct_cfg_analysis.h"
16
17#include "source/opt/ir_context.h"
18
namespace {
// In-operand positions used when reading structured merge instructions.
// Both OpSelectionMerge and OpLoopMerge name the merge block first.
constexpr uint32_t kMergeNodeIndex = 0;
// OpLoopMerge names the continue target as its second in-operand.
constexpr uint32_t kContinueNodeIndex = 1;
}  // namespace
23
24namespace spvtools {
25namespace opt {
26
27StructuredCFGAnalysis::StructuredCFGAnalysis(IRContext* ctx) : context_(ctx) {
28 // If this is not a shader, there are no merge instructions, and not
29 // structured CFG to analyze.
30 if (!context_->get_feature_mgr()->HasCapability(SpvCapabilityShader)) {
31 return;
32 }
33
34 for (auto& func : *context_->module()) {
35 AddBlocksInFunction(&func);
36 }
37}
38
39void StructuredCFGAnalysis::AddBlocksInFunction(Function* func) {
40 if (func->begin() == func->end()) return;
41
42 std::list<BasicBlock*> order;
43 context_->cfg()->ComputeStructuredOrder(func, &*func->begin(), &order);
44
45 struct TraversalInfo {
46 ConstructInfo cinfo;
47 uint32_t merge_node;
48 uint32_t continue_node;
49 };
50
51 // Set up a stack to keep track of currently active constructs.
52 std::vector<TraversalInfo> state;
53 state.emplace_back();
54 state[0].cinfo.containing_construct = 0;
55 state[0].cinfo.containing_loop = 0;
56 state[0].cinfo.containing_switch = 0;
57 state[0].cinfo.in_continue = false;
58 state[0].merge_node = 0;
59 state[0].continue_node = 0;
60
61 for (BasicBlock* block : order) {
62 if (context_->cfg()->IsPseudoEntryBlock(block) ||
63 context_->cfg()->IsPseudoExitBlock(block)) {
64 continue;
65 }
66
67 if (block->id() == state.back().merge_node) {
68 state.pop_back();
69 }
70
71 // This works because the structured order is designed to keep the blocks in
72 // the continue construct between the continue header and the merge node.
73 if (block->id() == state.back().continue_node) {
74 state.back().cinfo.in_continue = true;
75 }
76
77 bb_to_construct_.emplace(std::make_pair(block->id(), state.back().cinfo));
78
79 if (Instruction* merge_inst = block->GetMergeInst()) {
80 TraversalInfo new_state;
81 new_state.merge_node =
82 merge_inst->GetSingleWordInOperand(kMergeNodeIndex);
83 new_state.cinfo.containing_construct = block->id();
84
85 if (merge_inst->opcode() == SpvOpLoopMerge) {
86 new_state.cinfo.containing_loop = block->id();
87 new_state.cinfo.containing_switch = 0;
88 new_state.cinfo.in_continue = false;
89 new_state.continue_node =
90 merge_inst->GetSingleWordInOperand(kContinueNodeIndex);
91 } else {
92 new_state.cinfo.containing_loop = state.back().cinfo.containing_loop;
93 new_state.cinfo.in_continue = state.back().cinfo.in_continue;
94 new_state.continue_node = state.back().continue_node;
95
96 if (merge_inst->NextNode()->opcode() == SpvOpSwitch) {
97 new_state.cinfo.containing_switch = block->id();
98 } else {
99 new_state.cinfo.containing_switch =
100 state.back().cinfo.containing_switch;
101 }
102 }
103
104 state.emplace_back(new_state);
105 merge_blocks_.Set(new_state.merge_node);
106 }
107 }
108}
109
110uint32_t StructuredCFGAnalysis::ContainingConstruct(Instruction* inst) {
111 uint32_t bb = context_->get_instr_block(inst)->id();
112 return ContainingConstruct(bb);
113}
114
115uint32_t StructuredCFGAnalysis::MergeBlock(uint32_t bb_id) {
116 uint32_t header_id = ContainingConstruct(bb_id);
117 if (header_id == 0) {
118 return 0;
119 }
120
121 BasicBlock* header = context_->cfg()->block(header_id);
122 Instruction* merge_inst = header->GetMergeInst();
123 return merge_inst->GetSingleWordInOperand(kMergeNodeIndex);
124}
125
126uint32_t StructuredCFGAnalysis::LoopMergeBlock(uint32_t bb_id) {
127 uint32_t header_id = ContainingLoop(bb_id);
128 if (header_id == 0) {
129 return 0;
130 }
131
132 BasicBlock* header = context_->cfg()->block(header_id);
133 Instruction* merge_inst = header->GetMergeInst();
134 return merge_inst->GetSingleWordInOperand(kMergeNodeIndex);
135}
136
137uint32_t StructuredCFGAnalysis::LoopContinueBlock(uint32_t bb_id) {
138 uint32_t header_id = ContainingLoop(bb_id);
139 if (header_id == 0) {
140 return 0;
141 }
142
143 BasicBlock* header = context_->cfg()->block(header_id);
144 Instruction* merge_inst = header->GetMergeInst();
145 return merge_inst->GetSingleWordInOperand(kContinueNodeIndex);
146}
147
148uint32_t StructuredCFGAnalysis::SwitchMergeBlock(uint32_t bb_id) {
149 uint32_t header_id = ContainingSwitch(bb_id);
150 if (header_id == 0) {
151 return 0;
152 }
153
154 BasicBlock* header = context_->cfg()->block(header_id);
155 Instruction* merge_inst = header->GetMergeInst();
156 return merge_inst->GetSingleWordInOperand(kMergeNodeIndex);
157}
158
159bool StructuredCFGAnalysis::IsContinueBlock(uint32_t bb_id) {
160 assert(bb_id != 0);
161 return LoopContinueBlock(bb_id) == bb_id;
162}
163
164bool StructuredCFGAnalysis::IsInContainingLoopsContinueConstruct(
165 uint32_t bb_id) {
166 auto it = bb_to_construct_.find(bb_id);
167 if (it == bb_to_construct_.end()) {
168 return false;
169 }
170 return it->second.in_continue;
171}
172
173bool StructuredCFGAnalysis::IsInContinueConstruct(uint32_t bb_id) {
174 while (bb_id != 0) {
175 if (IsInContainingLoopsContinueConstruct(bb_id)) {
176 return true;
177 }
178 bb_id = ContainingLoop(bb_id);
179 }
180 return false;
181}
182
183bool StructuredCFGAnalysis::IsMergeBlock(uint32_t bb_id) {
184 return merge_blocks_.Get(bb_id);
185}
186
187std::unordered_set<uint32_t>
188StructuredCFGAnalysis::FindFuncsCalledFromContinue() {
189 std::unordered_set<uint32_t> called_from_continue;
190 std::queue<uint32_t> funcs_to_process;
191
192 // First collect the functions that are called directly from a continue
193 // construct.
194 for (Function& func : *context_->module()) {
195 for (auto& bb : func) {
196 if (IsInContainingLoopsContinueConstruct(bb.id())) {
197 for (const Instruction& inst : bb) {
198 if (inst.opcode() == SpvOpFunctionCall) {
199 funcs_to_process.push(inst.GetSingleWordInOperand(0));
200 }
201 }
202 }
203 }
204 }
205
206 // Now collect all of the functions that are indirectly called as well.
207 while (!funcs_to_process.empty()) {
208 uint32_t func_id = funcs_to_process.front();
209 funcs_to_process.pop();
210 Function* func = context_->GetFunction(func_id);
211 if (called_from_continue.insert(func_id).second) {
212 context_->AddCalls(func, &funcs_to_process);
213 }
214 }
215 return called_from_continue;
216}
217
218} // namespace opt
219} // namespace spvtools
220