1// Copyright (c) 2019 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "code_sink.h"
16
17#include <set>
18#include <vector>
19
20#include "source/opt/instruction.h"
21#include "source/opt/ir_builder.h"
22#include "source/opt/ir_context.h"
23#include "source/util/bit_vector.h"
24
25namespace spvtools {
26namespace opt {
27
28Pass::Status CodeSinkingPass::Process() {
29 bool modified = false;
30 for (Function& function : *get_module()) {
31 cfg()->ForEachBlockInPostOrder(function.entry().get(),
32 [&modified, this](BasicBlock* bb) {
33 if (SinkInstructionsInBB(bb)) {
34 modified = true;
35 }
36 });
37 }
38 return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
39}
40
41bool CodeSinkingPass::SinkInstructionsInBB(BasicBlock* bb) {
42 bool modified = false;
43 for (auto inst = bb->rbegin(); inst != bb->rend(); ++inst) {
44 if (SinkInstruction(&*inst)) {
45 inst = bb->rbegin();
46 modified = true;
47 }
48 }
49 return modified;
50}
51
52bool CodeSinkingPass::SinkInstruction(Instruction* inst) {
53 if (inst->opcode() != SpvOpLoad && inst->opcode() != SpvOpAccessChain) {
54 return false;
55 }
56
57 if (ReferencesMutableMemory(inst)) {
58 return false;
59 }
60
61 if (BasicBlock* target_bb = FindNewBasicBlockFor(inst)) {
62 Instruction* pos = &*target_bb->begin();
63 while (pos->opcode() == SpvOpPhi) {
64 pos = pos->NextNode();
65 }
66
67 inst->InsertBefore(pos);
68 context()->set_instr_block(inst, target_bb);
69 return true;
70 }
71 return false;
72}
73
74BasicBlock* CodeSinkingPass::FindNewBasicBlockFor(Instruction* inst) {
75 assert(inst->result_id() != 0 && "Instruction should have a result.");
76 BasicBlock* original_bb = context()->get_instr_block(inst);
77 BasicBlock* bb = original_bb;
78
79 std::unordered_set<uint32_t> bbs_with_uses;
80 get_def_use_mgr()->ForEachUse(
81 inst, [&bbs_with_uses, this](Instruction* use, uint32_t idx) {
82 if (use->opcode() != SpvOpPhi) {
83 BasicBlock* use_bb = context()->get_instr_block(use);
84 if (use_bb) {
85 bbs_with_uses.insert(use_bb->id());
86 }
87 } else {
88 bbs_with_uses.insert(use->GetSingleWordOperand(idx + 1));
89 }
90 });
91
92 while (true) {
93 // If |inst| is used in |bb|, then |inst| cannot be moved any further.
94 if (bbs_with_uses.count(bb->id())) {
95 break;
96 }
97
98 // If |bb| has one successor (succ_bb), and |bb| is the only predecessor
99 // of succ_bb, then |inst| can be moved to succ_bb. If succ_bb, has move
100 // then one predecessor, then moving |inst| into succ_bb could cause it to
101 // be executed more often, so the search has to stop.
102 if (bb->terminator()->opcode() == SpvOpBranch) {
103 uint32_t succ_bb_id = bb->terminator()->GetSingleWordInOperand(0);
104 if (cfg()->preds(succ_bb_id).size() == 1) {
105 bb = context()->get_instr_block(succ_bb_id);
106 continue;
107 } else {
108 break;
109 }
110 }
111
112 // The remaining checks need to know the merge node. If there is no merge
113 // instruction or an OpLoopMerge, then it is a break or continue. We could
114 // figure it out, but not worth doing it now.
115 Instruction* merge_inst = bb->GetMergeInst();
116 if (merge_inst == nullptr || merge_inst->opcode() != SpvOpSelectionMerge) {
117 break;
118 }
119
120 // Check all of the successors of |bb| it see which lead to a use of |inst|
121 // before reaching the merge node.
122 bool used_in_multiple_blocks = false;
123 uint32_t bb_used_in = 0;
124 bb->ForEachSuccessorLabel([this, bb, &bb_used_in, &used_in_multiple_blocks,
125 &bbs_with_uses](uint32_t* succ_bb_id) {
126 if (IntersectsPath(*succ_bb_id, bb->MergeBlockIdIfAny(), bbs_with_uses)) {
127 if (bb_used_in == 0) {
128 bb_used_in = *succ_bb_id;
129 } else {
130 used_in_multiple_blocks = true;
131 }
132 }
133 });
134
135 // If more than one successor, which is not the merge block, uses |inst|
136 // then we have to leave |inst| in bb because there is none of the
137 // successors dominate all uses of |inst|.
138 if (used_in_multiple_blocks) {
139 break;
140 }
141
142 if (bb_used_in == 0) {
143 // If |inst| is not used before reaching the merge node, then we can move
144 // |inst| to the merge node.
145 bb = context()->get_instr_block(bb->MergeBlockIdIfAny());
146 } else {
147 // If the only successor that leads to a used of |inst| has more than 1
148 // predecessor, then moving |inst| could cause it to be executed more
149 // often, so we cannot move it.
150 if (cfg()->preds(bb_used_in).size() != 1) {
151 break;
152 }
153
154 // If |inst| is used after the merge block, then |bb_used_in| does not
155 // dominate all of the uses. So we cannot move |inst| any further.
156 if (IntersectsPath(bb->MergeBlockIdIfAny(), original_bb->id(),
157 bbs_with_uses)) {
158 break;
159 }
160
161 // Otherwise, |bb_used_in| dominates all uses, so move |inst| into that
162 // block.
163 bb = context()->get_instr_block(bb_used_in);
164 }
165 continue;
166 }
167 return (bb != original_bb ? bb : nullptr);
168}
169
170bool CodeSinkingPass::ReferencesMutableMemory(Instruction* inst) {
171 if (!inst->IsLoad()) {
172 return false;
173 }
174
175 Instruction* base_ptr = inst->GetBaseAddress();
176 if (base_ptr->opcode() != SpvOpVariable) {
177 return true;
178 }
179
180 if (base_ptr->IsReadOnlyVariable()) {
181 return false;
182 }
183
184 if (HasUniformMemorySync()) {
185 return true;
186 }
187
188 if (base_ptr->GetSingleWordInOperand(0) != SpvStorageClassUniform) {
189 return true;
190 }
191
192 return HasPossibleStore(base_ptr);
193}
194
195bool CodeSinkingPass::HasUniformMemorySync() {
196 if (checked_for_uniform_sync_) {
197 return has_uniform_sync_;
198 }
199
200 bool has_sync = false;
201 get_module()->ForEachInst([this, &has_sync](Instruction* inst) {
202 switch (inst->opcode()) {
203 case SpvOpMemoryBarrier: {
204 uint32_t mem_semantics_id = inst->GetSingleWordInOperand(1);
205 if (IsSyncOnUniform(mem_semantics_id)) {
206 has_sync = true;
207 }
208 break;
209 }
210 case SpvOpControlBarrier:
211 case SpvOpAtomicLoad:
212 case SpvOpAtomicStore:
213 case SpvOpAtomicExchange:
214 case SpvOpAtomicIIncrement:
215 case SpvOpAtomicIDecrement:
216 case SpvOpAtomicIAdd:
217 case SpvOpAtomicISub:
218 case SpvOpAtomicSMin:
219 case SpvOpAtomicUMin:
220 case SpvOpAtomicSMax:
221 case SpvOpAtomicUMax:
222 case SpvOpAtomicAnd:
223 case SpvOpAtomicOr:
224 case SpvOpAtomicXor:
225 case SpvOpAtomicFlagTestAndSet:
226 case SpvOpAtomicFlagClear: {
227 uint32_t mem_semantics_id = inst->GetSingleWordInOperand(2);
228 if (IsSyncOnUniform(mem_semantics_id)) {
229 has_sync = true;
230 }
231 break;
232 }
233 case SpvOpAtomicCompareExchange:
234 case SpvOpAtomicCompareExchangeWeak:
235 if (IsSyncOnUniform(inst->GetSingleWordInOperand(2)) ||
236 IsSyncOnUniform(inst->GetSingleWordInOperand(3))) {
237 has_sync = true;
238 }
239 break;
240 default:
241 break;
242 }
243 });
244 has_uniform_sync_ = has_sync;
245 return has_sync;
246}
247
248bool CodeSinkingPass::IsSyncOnUniform(uint32_t mem_semantics_id) const {
249 const analysis::Constant* mem_semantics_const =
250 context()->get_constant_mgr()->FindDeclaredConstant(mem_semantics_id);
251 assert(mem_semantics_const != nullptr &&
252 "Expecting memory semantics id to be a constant.");
253 assert(mem_semantics_const->AsIntConstant() &&
254 "Memory semantics should be an integer.");
255 uint32_t mem_semantics_int = mem_semantics_const->GetU32();
256
257 // If it does not affect uniform memory, then it is does not apply to uniform
258 // memory.
259 if ((mem_semantics_int & SpvMemorySemanticsUniformMemoryMask) == 0) {
260 return false;
261 }
262
263 // Check if there is an acquire or release. If so not, this it does not add
264 // any memory constraints.
265 return (mem_semantics_int & (SpvMemorySemanticsAcquireMask |
266 SpvMemorySemanticsAcquireReleaseMask |
267 SpvMemorySemanticsReleaseMask)) != 0;
268}
269
270bool CodeSinkingPass::HasPossibleStore(Instruction* var_inst) {
271 assert(var_inst->opcode() == SpvOpVariable ||
272 var_inst->opcode() == SpvOpAccessChain ||
273 var_inst->opcode() == SpvOpPtrAccessChain);
274
275 return get_def_use_mgr()->WhileEachUser(var_inst, [this](Instruction* use) {
276 switch (use->opcode()) {
277 case SpvOpStore:
278 return true;
279 case SpvOpAccessChain:
280 case SpvOpPtrAccessChain:
281 return HasPossibleStore(use);
282 default:
283 return false;
284 }
285 });
286}
287
288bool CodeSinkingPass::IntersectsPath(uint32_t start, uint32_t end,
289 const std::unordered_set<uint32_t>& set) {
290 std::vector<uint32_t> worklist;
291 worklist.push_back(start);
292 std::unordered_set<uint32_t> already_done;
293 already_done.insert(start);
294
295 while (!worklist.empty()) {
296 BasicBlock* bb = context()->get_instr_block(worklist.back());
297 worklist.pop_back();
298
299 if (bb->id() == end) {
300 continue;
301 }
302
303 if (set.count(bb->id())) {
304 return true;
305 }
306
307 bb->ForEachSuccessorLabel([&already_done, &worklist](uint32_t* succ_bb_id) {
308 if (already_done.insert(*succ_bb_id).second) {
309 worklist.push_back(*succ_bb_id);
310 }
311 });
312 }
313 return false;
314}
315
316// namespace opt
317
318} // namespace opt
319} // namespace spvtools
320