1// Copyright (c) 2018 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "source/opt/if_conversion.h"
16
17#include <memory>
18#include <vector>
19
20#include "source/opt/value_number_table.h"
21
22namespace spvtools {
23namespace opt {
24
// Replaces eligible OpPhi instructions in the merge blocks of simple two-way
// conditionals, in every function of the module.
//
// For each block accepted by CheckBlock (exactly two predecessors whose
// nearest common dominator |common| ends in OpBranchConditional and carries an
// OpSelectionMerge naming the block), every phi that passes CheckType and
// CheckPhiUsers is handled in one of two ways:
//   * If both incoming values have the same non-zero value number, all uses of
//     the phi are redirected to one of those values. That value (and, when
//     necessary, its operands) is hoisted into |common| if it does not already
//     dominate the block.
//   * Otherwise, if neither incoming value is defined in a block that fails to
//     dominate the merge block, an OpSelect on |common|'s branch condition is
//     inserted after the phis and replaces all uses of the phi.
//
// Returns SuccessWithoutChange when the module lacks the Shader capability or
// nothing was rewritten, and SuccessWithChange otherwise.
Pass::Status IfConversion::Process() {
  // Only modules that declare the Shader capability are transformed.
  if (!context()->get_feature_mgr()->HasCapability(SpvCapabilityShader)) {
    return Status::SuccessWithoutChange;
  }

  const ValueNumberTable& vn_table = *context()->GetValueNumberTable();
  bool modified = false;
  // Phis that were replaced by an OpSelect. They are killed only after every
  // function has been walked (presumably so that ForEachPhiInst never has a
  // phi removed from under it).
  std::vector<Instruction*> to_kill;
  for (auto& func : *get_module()) {
    DominatorAnalysis* dominators = context()->GetDominatorAnalysis(&func);
    for (auto& block : func) {
      // Check if it is possible for |block| to have phis that can be
      // transformed. On success |common| holds the dominating block whose
      // conditional branch chooses between the two incoming edges.
      BasicBlock* common = nullptr;
      if (!CheckBlock(&block, dominators, &common)) continue;

      // Get an insertion point: the first non-phi instruction of |block|, so
      // that anything the builder creates lands after all of the phis.
      // NOTE(review): assumes |block| always ends in a terminator, so |iter|
      // never reaches end() before it is dereferenced below.
      auto iter = block.begin();
      while (iter != block.end() && iter->opcode() == SpvOpPhi) {
        ++iter;
      }

      // The builder is asked to keep the def-use and instruction-to-block
      // analyses current for the instructions it creates.
      InstructionBuilder builder(
          context(), &*iter,
          IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
      block.ForEachPhiInst([this, &builder, &modified, &common, &to_kill,
                            dominators, &block, &vn_table](Instruction* phi) {
        // This phi is not compatible, but subsequent phis might be.
        if (!CheckType(phi->type_id())) return;

        // We cannot transform cases where the phi is used by another phi in the
        // same block due to instruction ordering restrictions.
        // TODO(alan-baker): If all inappropriate uses could also be
        // transformed, we could still remove this phi.
        if (!CheckPhiUsers(phi, &block)) return;

        // Identify the incoming values associated with the true and false
        // branches. If |then_block| dominates |inc0| or if the true edge
        // branches straight to this block and |common| is |inc0|, then |inc0|
        // is on the true branch. Otherwise the |inc1| is on the true branch.
        BasicBlock* inc0 = GetIncomingBlock(phi, 0u);
        Instruction* branch = common->terminator();
        // In-operand 0 of OpBranchConditional is the condition and in-operand
        // 1 is the true-target label.
        uint32_t condition = branch->GetSingleWordInOperand(0u);
        BasicBlock* then_block = GetBlock(branch->GetSingleWordInOperand(1u));
        Instruction* true_value = nullptr;
        Instruction* false_value = nullptr;
        if ((then_block == &block && inc0 == common) ||
            dominators->Dominates(then_block, inc0)) {
          true_value = GetIncomingValue(phi, 0u);
          false_value = GetIncomingValue(phi, 1u);
        } else {
          true_value = GetIncomingValue(phi, 1u);
          false_value = GetIncomingValue(phi, 0u);
        }

        // A null block means the value is not defined inside any basic block.
        BasicBlock* true_def_block = context()->get_instr_block(true_value);
        BasicBlock* false_def_block = context()->get_instr_block(false_value);

        // Equal value numbers mean both edges deliver the same value, so the
        // phi can be replaced by one of them without a select.
        // (A value number of 0 presumably means "no number assigned".)
        uint32_t true_vn = vn_table.GetValueNumber(true_value);
        uint32_t false_vn = vn_table.GetValueNumber(false_value);
        if (true_vn != 0 && true_vn == false_vn) {
          Instruction* inst_to_use = nullptr;

          // Try to pick an instruction that is not in a side node. If we can't,
          // pick either the true or false branch value, as long as it can be
          // legally moved.
          if (!true_def_block ||
              dominators->Dominates(true_def_block, &block)) {
            inst_to_use = true_value;
          } else if (!false_def_block ||
                     dominators->Dominates(false_def_block, &block)) {
            inst_to_use = false_value;
          } else if (CanHoistInstruction(true_value, common, dominators)) {
            inst_to_use = true_value;
          } else if (CanHoistInstruction(false_value, common, dominators)) {
            inst_to_use = false_value;
          }

          // The phi itself is not added to |to_kill| here; it stays in |block|
          // with all of its uses redirected to |inst_to_use|.
          if (inst_to_use != nullptr) {
            modified = true;
            HoistInstruction(inst_to_use, common, dominators);
            context()->KillNamesAndDecorates(phi);
            context()->ReplaceAllUsesWith(phi->result_id(),
                                          inst_to_use->result_id());
          }
          return;
        }

        // If either incoming value is defined in a block that does not dominate
        // this phi, then we cannot eliminate the phi with a select.
        // TODO(alan-baker): Perform code motion where it makes sense to enable
        // the transform in this case.
        if (true_def_block && !dominators->Dominates(true_def_block, &block))
          return;

        if (false_def_block && !dominators->Dominates(false_def_block, &block))
          return;

        // Vector data requires a boolean vector condition of the same width,
        // which SplatCondition builds from the scalar branch condition.
        analysis::Type* data_ty =
            context()->get_type_mgr()->GetType(true_value->type_id());
        if (analysis::Vector* vec_data_ty = data_ty->AsVector()) {
          condition = SplatCondition(vec_data_ty, condition, &builder);
        }

        Instruction* select = builder.AddSelect(phi->type_id(), condition,
                                                true_value->result_id(),
                                                false_value->result_id());
        context()->ReplaceAllUsesWith(phi->result_id(), select->result_id());
        to_kill.push_back(phi);
        modified = true;

        return;
      });
    }
  }

  // Deferred removal of the phis replaced by OpSelect above.
  for (auto inst : to_kill) {
    context()->KillInst(inst);
  }

  return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
}
147
148bool IfConversion::CheckBlock(BasicBlock* block, DominatorAnalysis* dominators,
149 BasicBlock** common) {
150 const std::vector<uint32_t>& preds = cfg()->preds(block->id());
151
152 // TODO(alan-baker): Extend to more than two predecessors
153 if (preds.size() != 2) return false;
154
155 BasicBlock* inc0 = context()->get_instr_block(preds[0]);
156 if (dominators->Dominates(block, inc0)) return false;
157
158 BasicBlock* inc1 = context()->get_instr_block(preds[1]);
159 if (dominators->Dominates(block, inc1)) return false;
160
161 // All phis will have the same common dominator, so cache the result
162 // for this block. If there is no common dominator, then we cannot transform
163 // any phi in this basic block.
164 *common = dominators->CommonDominator(inc0, inc1);
165 if (!*common || cfg()->IsPseudoEntryBlock(*common)) return false;
166 Instruction* branch = (*common)->terminator();
167 if (branch->opcode() != SpvOpBranchConditional) return false;
168 auto merge = (*common)->GetMergeInst();
169 if (!merge || merge->opcode() != SpvOpSelectionMerge) return false;
170 if ((*common)->MergeBlockIdIfAny() != block->id()) return false;
171
172 return true;
173}
174
175bool IfConversion::CheckPhiUsers(Instruction* phi, BasicBlock* block) {
176 return get_def_use_mgr()->WhileEachUser(phi, [block,
177 this](Instruction* user) {
178 if (user->opcode() == SpvOpPhi && context()->get_instr_block(user) == block)
179 return false;
180 return true;
181 });
182}
183
184uint32_t IfConversion::SplatCondition(analysis::Vector* vec_data_ty,
185 uint32_t cond,
186 InstructionBuilder* builder) {
187 // If the data inputs to OpSelect are vectors, the condition for
188 // OpSelect must be a boolean vector with the same number of
189 // components. So splat the condition for the branch into a vector
190 // type.
191 analysis::Bool bool_ty;
192 analysis::Vector bool_vec_ty(&bool_ty, vec_data_ty->element_count());
193 uint32_t bool_vec_id =
194 context()->get_type_mgr()->GetTypeInstruction(&bool_vec_ty);
195 std::vector<uint32_t> ids(vec_data_ty->element_count(), cond);
196 return builder->AddCompositeConstruct(bool_vec_id, ids)->result_id();
197}
198
199bool IfConversion::CheckType(uint32_t id) {
200 Instruction* type = get_def_use_mgr()->GetDef(id);
201 SpvOp op = type->opcode();
202 if (spvOpcodeIsScalarType(op) || op == SpvOpTypePointer ||
203 op == SpvOpTypeVector)
204 return true;
205 return false;
206}
207
208BasicBlock* IfConversion::GetBlock(uint32_t id) {
209 return context()->get_instr_block(get_def_use_mgr()->GetDef(id));
210}
211
212BasicBlock* IfConversion::GetIncomingBlock(Instruction* phi,
213 uint32_t predecessor) {
214 uint32_t in_index = 2 * predecessor + 1;
215 return GetBlock(phi->GetSingleWordInOperand(in_index));
216}
217
218Instruction* IfConversion::GetIncomingValue(Instruction* phi,
219 uint32_t predecessor) {
220 uint32_t in_index = 2 * predecessor;
221 return get_def_use_mgr()->GetDef(phi->GetSingleWordInOperand(in_index));
222}
223
224void IfConversion::HoistInstruction(Instruction* inst, BasicBlock* target_block,
225 DominatorAnalysis* dominators) {
226 BasicBlock* inst_block = context()->get_instr_block(inst);
227 if (!inst_block) {
228 // This is in the header, and dominates everything.
229 return;
230 }
231
232 if (dominators->Dominates(inst_block, target_block)) {
233 // Already in position. No work to do.
234 return;
235 }
236
237 assert(inst->IsOpcodeCodeMotionSafe() &&
238 "Trying to move an instruction that is not safe to move.");
239
240 // First hoist all instructions it depends on.
241 analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
242 inst->ForEachInId(
243 [this, target_block, def_use_mgr, dominators](uint32_t* id) {
244 Instruction* operand_inst = def_use_mgr->GetDef(*id);
245 HoistInstruction(operand_inst, target_block, dominators);
246 });
247
248 Instruction* insertion_pos = target_block->terminator();
249 if ((insertion_pos)->PreviousNode()->opcode() == SpvOpSelectionMerge) {
250 insertion_pos = insertion_pos->PreviousNode();
251 }
252 inst->RemoveFromList();
253 insertion_pos->InsertBefore(std::unique_ptr<Instruction>(inst));
254 context()->set_instr_block(inst, target_block);
255}
256
257bool IfConversion::CanHoistInstruction(Instruction* inst,
258 BasicBlock* target_block,
259 DominatorAnalysis* dominators) {
260 BasicBlock* inst_block = context()->get_instr_block(inst);
261 if (!inst_block) {
262 // This is in the header, and dominates everything.
263 return true;
264 }
265
266 if (dominators->Dominates(inst_block, target_block)) {
267 // Already in position. No work to do.
268 return true;
269 }
270
271 if (!inst->IsOpcodeCodeMotionSafe()) {
272 return false;
273 }
274
275 // Check all instruction |inst| depends on.
276 analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
277 return inst->WhileEachInId(
278 [this, target_block, def_use_mgr, dominators](uint32_t* id) {
279 Instruction* operand_inst = def_use_mgr->GetDef(*id);
280 return CanHoistInstruction(operand_inst, target_block, dominators);
281 });
282}
283
284} // namespace opt
285} // namespace spvtools
286