1// Copyright (c) 2017 Google Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "source/opt/loop_descriptor.h"
16
17#include <algorithm>
18#include <iostream>
19#include <limits>
20#include <stack>
21#include <type_traits>
22#include <utility>
23#include <vector>
24
25#include "source/opt/cfg.h"
26#include "source/opt/constants.h"
27#include "source/opt/dominator_tree.h"
28#include "source/opt/ir_builder.h"
29#include "source/opt/ir_context.h"
30#include "source/opt/iterator.h"
31#include "source/opt/tree_iterator.h"
32#include "source/util/make_unique.h"
33
34namespace spvtools {
35namespace opt {
36
37// Takes in a phi instruction |induction| and the loop |header| and returns the
38// step operation of the loop.
39Instruction* Loop::GetInductionStepOperation(
40 const Instruction* induction) const {
41 // Induction must be a phi instruction.
42 assert(induction->opcode() == SpvOpPhi);
43
44 Instruction* step = nullptr;
45
46 analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr();
47
48 // Traverse the incoming operands of the phi instruction.
49 for (uint32_t operand_id = 1; operand_id < induction->NumInOperands();
50 operand_id += 2) {
51 // Incoming edge.
52 BasicBlock* incoming_block =
53 context_->cfg()->block(induction->GetSingleWordInOperand(operand_id));
54
55 // Check if the block is dominated by header, and thus coming from within
56 // the loop.
57 if (IsInsideLoop(incoming_block)) {
58 step = def_use_manager->GetDef(
59 induction->GetSingleWordInOperand(operand_id - 1));
60 break;
61 }
62 }
63
64 if (!step || !IsSupportedStepOp(step->opcode())) {
65 return nullptr;
66 }
67
68 // The induction variable which binds the loop must only be modified once.
69 uint32_t lhs = step->GetSingleWordInOperand(0);
70 uint32_t rhs = step->GetSingleWordInOperand(1);
71
72 // One of the left hand side or right hand side of the step instruction must
73 // be the induction phi and the other must be an OpConstant.
74 if (lhs != induction->result_id() && rhs != induction->result_id()) {
75 return nullptr;
76 }
77
78 if (def_use_manager->GetDef(lhs)->opcode() != SpvOp::SpvOpConstant &&
79 def_use_manager->GetDef(rhs)->opcode() != SpvOp::SpvOpConstant) {
80 return nullptr;
81 }
82
83 return step;
84}
85
86// Returns true if the |step| operation is an induction variable step operation
87// which is currently handled.
88bool Loop::IsSupportedStepOp(SpvOp step) const {
89 switch (step) {
90 case SpvOp::SpvOpISub:
91 case SpvOp::SpvOpIAdd:
92 return true;
93 default:
94 return false;
95 }
96}
97
98bool Loop::IsSupportedCondition(SpvOp condition) const {
99 switch (condition) {
100 // <
101 case SpvOp::SpvOpULessThan:
102 case SpvOp::SpvOpSLessThan:
103 // >
104 case SpvOp::SpvOpUGreaterThan:
105 case SpvOp::SpvOpSGreaterThan:
106
107 // >=
108 case SpvOp::SpvOpSGreaterThanEqual:
109 case SpvOp::SpvOpUGreaterThanEqual:
110 // <=
111 case SpvOp::SpvOpSLessThanEqual:
112 case SpvOp::SpvOpULessThanEqual:
113
114 return true;
115 default:
116 return false;
117 }
118}
119
120int64_t Loop::GetResidualConditionValue(SpvOp condition, int64_t initial_value,
121 int64_t step_value,
122 size_t number_of_iterations,
123 size_t factor) {
124 int64_t remainder =
125 initial_value + (number_of_iterations % factor) * step_value;
126
127 // We subtract or add one as the above formula calculates the remainder if the
128 // loop where just less than or greater than. Adding or subtracting one should
129 // give a functionally equivalent value.
130 switch (condition) {
131 case SpvOp::SpvOpSGreaterThanEqual:
132 case SpvOp::SpvOpUGreaterThanEqual: {
133 remainder -= 1;
134 break;
135 }
136 case SpvOp::SpvOpSLessThanEqual:
137 case SpvOp::SpvOpULessThanEqual: {
138 remainder += 1;
139 break;
140 }
141
142 default:
143 break;
144 }
145 return remainder;
146}
147
148Instruction* Loop::GetConditionInst() const {
149 BasicBlock* condition_block = FindConditionBlock();
150 if (!condition_block) {
151 return nullptr;
152 }
153 Instruction* branch_conditional = &*condition_block->tail();
154 if (!branch_conditional ||
155 branch_conditional->opcode() != SpvOpBranchConditional) {
156 return nullptr;
157 }
158 Instruction* condition_inst = context_->get_def_use_mgr()->GetDef(
159 branch_conditional->GetSingleWordInOperand(0));
160 if (IsSupportedCondition(condition_inst->opcode())) {
161 return condition_inst;
162 }
163
164 return nullptr;
165}
166
167// Extract the initial value from the |induction| OpPhi instruction and store it
168// in |value|. If the function couldn't find the initial value of |induction|
169// return false.
170bool Loop::GetInductionInitValue(const Instruction* induction,
171 int64_t* value) const {
172 Instruction* constant_instruction = nullptr;
173 analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr();
174
175 for (uint32_t operand_id = 0; operand_id < induction->NumInOperands();
176 operand_id += 2) {
177 BasicBlock* bb = context_->cfg()->block(
178 induction->GetSingleWordInOperand(operand_id + 1));
179
180 if (!IsInsideLoop(bb)) {
181 constant_instruction = def_use_manager->GetDef(
182 induction->GetSingleWordInOperand(operand_id));
183 }
184 }
185
186 if (!constant_instruction) return false;
187
188 const analysis::Constant* constant =
189 context_->get_constant_mgr()->FindDeclaredConstant(
190 constant_instruction->result_id());
191 if (!constant) return false;
192
193 if (value) {
194 const analysis::Integer* type =
195 constant->AsIntConstant()->type()->AsInteger();
196
197 if (type->IsSigned()) {
198 *value = constant->AsIntConstant()->GetS32BitValue();
199 } else {
200 *value = constant->AsIntConstant()->GetU32BitValue();
201 }
202 }
203
204 return true;
205}
206
// Builds the descriptor for the loop headed by |header| with the given
// continue target and merge block. It then computes and caches the loop
// preheader (null when the header has no unique preheader) and the latch
// block.
//
// NOTE(review): the member-initializer order is kept as-is; it presumably
// mirrors the member declaration order in loop_descriptor.h (not visible
// here). Confirm before reordering.
Loop::Loop(IRContext* context, DominatorAnalysis* dom_analysis,
           BasicBlock* header, BasicBlock* continue_target,
           BasicBlock* merge_target)
    : context_(context),
      loop_header_(header),
      loop_continue_(continue_target),
      loop_merge_(merge_target),
      loop_preheader_(nullptr),
      parent_(nullptr),
      loop_is_marked_for_removal_(false) {
  assert(context);
  assert(dom_analysis);
  // Both lookups read loop_header_ (and FindLatchBlock also loop_continue_),
  // so they must run after the initializers above.
  loop_preheader_ = FindLoopPreheader(dom_analysis);
  loop_latch_ = FindLatchBlock();
}
222
223BasicBlock* Loop::FindLoopPreheader(DominatorAnalysis* dom_analysis) {
224 CFG* cfg = context_->cfg();
225 DominatorTree& dom_tree = dom_analysis->GetDomTree();
226 DominatorTreeNode* header_node = dom_tree.GetTreeNode(loop_header_);
227
228 // The loop predecessor.
229 BasicBlock* loop_pred = nullptr;
230
231 auto header_pred = cfg->preds(loop_header_->id());
232 for (uint32_t p_id : header_pred) {
233 DominatorTreeNode* node = dom_tree.GetTreeNode(p_id);
234 if (node && !dom_tree.Dominates(header_node, node)) {
235 // The predecessor is not part of the loop, so potential loop preheader.
236 if (loop_pred && node->bb_ != loop_pred) {
237 // If we saw 2 distinct predecessors that are outside the loop, we don't
238 // have a loop preheader.
239 return nullptr;
240 }
241 loop_pred = node->bb_;
242 }
243 }
244 // Safe guard against invalid code, SPIR-V spec forbids loop with the entry
245 // node as header.
246 assert(loop_pred && "The header node is the entry block ?");
247
248 // So we have a unique basic block that can enter this loop.
249 // If this loop is the unique successor of this block, then it is a loop
250 // preheader.
251 bool is_preheader = true;
252 uint32_t loop_header_id = loop_header_->id();
253 const auto* const_loop_pred = loop_pred;
254 const_loop_pred->ForEachSuccessorLabel(
255 [&is_preheader, loop_header_id](const uint32_t id) {
256 if (id != loop_header_id) is_preheader = false;
257 });
258 if (is_preheader) return loop_pred;
259 return nullptr;
260}
261
262bool Loop::IsInsideLoop(Instruction* inst) const {
263 const BasicBlock* parent_block = context_->get_instr_block(inst);
264 if (!parent_block) return false;
265 return IsInsideLoop(parent_block);
266}
267
268bool Loop::IsBasicBlockInLoopSlow(const BasicBlock* bb) {
269 assert(bb->GetParent() && "The basic block does not belong to a function");
270 DominatorAnalysis* dom_analysis =
271 context_->GetDominatorAnalysis(bb->GetParent());
272 if (dom_analysis->IsReachable(bb) &&
273 !dom_analysis->Dominates(GetHeaderBlock(), bb))
274 return false;
275
276 return true;
277}
278
279BasicBlock* Loop::GetOrCreatePreHeaderBlock() {
280 if (loop_preheader_) return loop_preheader_;
281
282 CFG* cfg = context_->cfg();
283 loop_header_ = cfg->SplitLoopHeader(loop_header_);
284 return loop_preheader_;
285}
286
287void Loop::SetContinueBlock(BasicBlock* continue_block) {
288 assert(IsInsideLoop(continue_block));
289 loop_continue_ = continue_block;
290}
291
292void Loop::SetLatchBlock(BasicBlock* latch) {
293#ifndef NDEBUG
294 assert(latch->GetParent() && "The basic block does not belong to a function");
295
296 const auto* const_latch = latch;
297 const_latch->ForEachSuccessorLabel([this](uint32_t id) {
298 assert((!IsInsideLoop(id) || id == GetHeaderBlock()->id()) &&
299 "A predecessor of the continue block does not belong to the loop");
300 });
301#endif // NDEBUG
302 assert(IsInsideLoop(latch) && "The continue block is not in the loop");
303
304 SetLatchBlockImpl(latch);
305}
306
307void Loop::SetMergeBlock(BasicBlock* merge) {
308#ifndef NDEBUG
309 assert(merge->GetParent() && "The basic block does not belong to a function");
310#endif // NDEBUG
311 assert(!IsInsideLoop(merge) && "The merge block is in the loop");
312
313 SetMergeBlockImpl(merge);
314 if (GetHeaderBlock()->GetLoopMergeInst()) {
315 UpdateLoopMergeInst();
316 }
317}
318
319void Loop::SetPreHeaderBlock(BasicBlock* preheader) {
320 if (preheader) {
321 assert(!IsInsideLoop(preheader) && "The preheader block is in the loop");
322 assert(preheader->tail()->opcode() == SpvOpBranch &&
323 "The preheader block does not unconditionally branch to the header "
324 "block");
325 assert(preheader->tail()->GetSingleWordOperand(0) ==
326 GetHeaderBlock()->id() &&
327 "The preheader block does not unconditionally branch to the header "
328 "block");
329 }
330 loop_preheader_ = preheader;
331}
332
333BasicBlock* Loop::FindLatchBlock() {
334 CFG* cfg = context_->cfg();
335
336 DominatorAnalysis* dominator_analysis =
337 context_->GetDominatorAnalysis(loop_header_->GetParent());
338
339 // Look at the predecessors of the loop header to find a predecessor block
340 // which is dominated by the loop continue target. There should only be one
341 // block which meets this criteria and this is the latch block, as per the
342 // SPIR-V spec.
343 for (uint32_t block_id : cfg->preds(loop_header_->id())) {
344 if (dominator_analysis->Dominates(loop_continue_->id(), block_id)) {
345 return cfg->block(block_id);
346 }
347 }
348
349 assert(
350 false &&
351 "Every loop should have a latch block dominated by the continue target");
352 return nullptr;
353}
354
355void Loop::GetExitBlocks(std::unordered_set<uint32_t>* exit_blocks) const {
356 CFG* cfg = context_->cfg();
357 exit_blocks->clear();
358
359 for (uint32_t bb_id : GetBlocks()) {
360 const BasicBlock* bb = cfg->block(bb_id);
361 bb->ForEachSuccessorLabel([exit_blocks, this](uint32_t succ) {
362 if (!IsInsideLoop(succ)) {
363 exit_blocks->insert(succ);
364 }
365 });
366 }
367}
368
369void Loop::GetMergingBlocks(
370 std::unordered_set<uint32_t>* merging_blocks) const {
371 assert(GetMergeBlock() && "This loop is not structured");
372 CFG* cfg = context_->cfg();
373 merging_blocks->clear();
374
375 std::stack<const BasicBlock*> to_visit;
376 to_visit.push(GetMergeBlock());
377 while (!to_visit.empty()) {
378 const BasicBlock* bb = to_visit.top();
379 to_visit.pop();
380 merging_blocks->insert(bb->id());
381 for (uint32_t pred_id : cfg->preds(bb->id())) {
382 if (!IsInsideLoop(pred_id) && !merging_blocks->count(pred_id)) {
383 to_visit.push(cfg->block(pred_id));
384 }
385 }
386 }
387}
388
389namespace {
390
391static inline bool IsBasicBlockSafeToClone(IRContext* context, BasicBlock* bb) {
392 for (Instruction& inst : *bb) {
393 if (!inst.IsBranch() && !context->IsCombinatorInstruction(&inst))
394 return false;
395 }
396
397 return true;
398}
399
400} // namespace
401
402bool Loop::IsSafeToClone() const {
403 CFG& cfg = *context_->cfg();
404
405 for (uint32_t bb_id : GetBlocks()) {
406 BasicBlock* bb = cfg.block(bb_id);
407 assert(bb);
408 if (!IsBasicBlockSafeToClone(context_, bb)) return false;
409 }
410
411 // Look at the merge construct.
412 if (GetHeaderBlock()->GetLoopMergeInst()) {
413 std::unordered_set<uint32_t> blocks;
414 GetMergingBlocks(&blocks);
415 blocks.erase(GetMergeBlock()->id());
416 for (uint32_t bb_id : blocks) {
417 BasicBlock* bb = cfg.block(bb_id);
418 assert(bb);
419 if (!IsBasicBlockSafeToClone(context_, bb)) return false;
420 }
421 }
422
423 return true;
424}
425
426bool Loop::IsLCSSA() const {
427 CFG* cfg = context_->cfg();
428 analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();
429
430 std::unordered_set<uint32_t> exit_blocks;
431 GetExitBlocks(&exit_blocks);
432
433 // Declare ir_context so we can capture context_ in the below lambda
434 IRContext* ir_context = context_;
435
436 for (uint32_t bb_id : GetBlocks()) {
437 for (Instruction& insn : *cfg->block(bb_id)) {
438 // All uses must be either:
439 // - In the loop;
440 // - In an exit block and in a phi instruction.
441 if (!def_use_mgr->WhileEachUser(
442 &insn,
443 [&exit_blocks, ir_context, this](Instruction* use) -> bool {
444 BasicBlock* parent = ir_context->get_instr_block(use);
445 assert(parent && "Invalid analysis");
446 if (IsInsideLoop(parent)) return true;
447 if (use->opcode() != SpvOpPhi) return false;
448 return exit_blocks.count(parent->id());
449 }))
450 return false;
451 }
452 }
453 return true;
454}
455
456bool Loop::ShouldHoistInstruction(IRContext* context, Instruction* inst) {
457 return AreAllOperandsOutsideLoop(context, inst) &&
458 inst->IsOpcodeCodeMotionSafe();
459}
460
461bool Loop::AreAllOperandsOutsideLoop(IRContext* context, Instruction* inst) {
462 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
463 bool all_outside_loop = true;
464
465 const std::function<void(uint32_t*)> operand_outside_loop =
466 [this, &def_use_mgr, &all_outside_loop](uint32_t* id) {
467 if (this->IsInsideLoop(def_use_mgr->GetDef(*id))) {
468 all_outside_loop = false;
469 return;
470 }
471 };
472
473 inst->ForEachInId(operand_outside_loop);
474 return all_outside_loop;
475}
476
477void Loop::ComputeLoopStructuredOrder(
478 std::vector<BasicBlock*>* ordered_loop_blocks, bool include_pre_header,
479 bool include_merge) const {
480 CFG& cfg = *context_->cfg();
481
482 // Reserve the memory: all blocks in the loop + extra if needed.
483 ordered_loop_blocks->reserve(GetBlocks().size() + include_pre_header +
484 include_merge);
485
486 if (include_pre_header && GetPreHeaderBlock())
487 ordered_loop_blocks->push_back(loop_preheader_);
488 cfg.ForEachBlockInReversePostOrder(
489 loop_header_, [ordered_loop_blocks, this](BasicBlock* bb) {
490 if (IsInsideLoop(bb)) ordered_loop_blocks->push_back(bb);
491 });
492 if (include_merge && GetMergeBlock())
493 ordered_loop_blocks->push_back(loop_merge_);
494}
495
496LoopDescriptor::LoopDescriptor(IRContext* context, const Function* f)
497 : loops_(), dummy_top_loop_(nullptr) {
498 PopulateList(context, f);
499}
500
501LoopDescriptor::~LoopDescriptor() { ClearLoops(); }
502
// (Re)builds the loop forest of |f|.
//
// Every block carrying an OpLoopMerge that has at least one reachable back
// edge (a reachable predecessor dominated by the block) becomes a Loop. Loops
// are discovered in a post-order walk of the dominator tree, so inner headers
// are visited before the headers that dominate them. Each new loop therefore
// adopts as children the not-yet-parented loops it encloses. A loop's blocks
// are all dominator-tree descendants of its header that its merge block does
// not dominate. Finally, every parentless loop is attached to
// dummy_top_loop_.
void LoopDescriptor::PopulateList(IRContext* context, const Function* f) {
  DominatorAnalysis* dom_analysis = context->GetDominatorAnalysis(f);

  ClearLoops();

  // Post-order traversal of the dominator tree to find all the OpLoopMerge
  // instructions.
  DominatorTree& dom_tree = dom_analysis->GetDomTree();
  for (DominatorTreeNode& node :
       make_range(dom_tree.post_begin(), dom_tree.post_end())) {
    Instruction* merge_inst = node.bb_->GetLoopMergeInst();
    if (merge_inst) {
      // A back edge is a reachable predecessor dominated by the header.
      bool all_backedge_unreachable = true;
      for (uint32_t pid : context->cfg()->preds(node.bb_->id())) {
        if (dom_analysis->IsReachable(pid) &&
            dom_analysis->Dominates(node.bb_->id(), pid)) {
          all_backedge_unreachable = false;
          break;
        }
      }
      if (all_backedge_unreachable)
        continue;  // ignore this one, we actually never branch back.

      // The id of the merge basic block of this loop.
      uint32_t merge_bb_id = merge_inst->GetSingleWordOperand(0);

      // The id of the continue basic block of this loop.
      uint32_t continue_bb_id = merge_inst->GetSingleWordOperand(1);

      // The merge target of this loop.
      BasicBlock* merge_bb = context->cfg()->block(merge_bb_id);

      // The continue target of this loop.
      BasicBlock* continue_bb = context->cfg()->block(continue_bb_id);

      // The basic block containing the merge instruction.
      BasicBlock* header_bb = context->get_instr_block(merge_inst);

      // Add the loop to the list of all the loops in the function.
      Loop* current_loop =
          new Loop(context, dom_analysis, header_bb, continue_bb, merge_bb);
      loops_.push_back(current_loop);

      // We have a bottom-up construction, so if this loop has nested-loops,
      // they are by construction at the tail of the loop list.
      // (rbegin() + 1 skips |current_loop| itself, pushed just above.)
      for (auto itr = loops_.rbegin() + 1; itr != loops_.rend(); ++itr) {
        Loop* previous_loop = *itr;

        // If the loop already has a parent, then it has been processed.
        if (previous_loop->HasParent()) continue;

        // If the current loop does not dominates the previous loop then it is
        // not nested loop.
        if (!dom_analysis->Dominates(header_bb,
                                     previous_loop->GetHeaderBlock()))
          continue;
        // If the current loop merge dominates the previous loop then it is
        // not nested loop.
        if (dom_analysis->Dominates(merge_bb, previous_loop->GetHeaderBlock()))
          continue;

        current_loop->AddNestedLoop(previous_loop);
      }
      // Loop body: every dominator-tree descendant of the header (visited
      // depth-first from |node|) that is not dominated by the merge block.
      DominatorTreeNode* dom_merge_node = dom_tree.GetTreeNode(merge_bb);
      for (DominatorTreeNode& loop_node :
           make_range(node.df_begin(), node.df_end())) {
        // Check if we are in the loop.
        if (dom_tree.Dominates(dom_merge_node, &loop_node)) continue;
        current_loop->AddBasicBlock(loop_node.bb_);
        // insert() does not overwrite an existing mapping (assuming a std map
        // type), so a block already mapped to an inner loop -- processed
        // earlier in post-order -- presumably keeps its innermost loop.
        basic_block_to_loop_.insert(
            std::make_pair(loop_node.bb_->id(), current_loop));
      }
    }
  }
  // Outermost loops become children of the placeholder top loop.
  for (Loop* loop : loops_) {
    if (!loop->HasParent()) dummy_top_loop_.nested_loops_.push_back(loop);
  }
}
581
582std::vector<Loop*> LoopDescriptor::GetLoopsInBinaryLayoutOrder() {
583 std::vector<uint32_t> ids{};
584
585 for (size_t i = 0; i < NumLoops(); ++i) {
586 ids.push_back(GetLoopByIndex(i).GetHeaderBlock()->id());
587 }
588
589 std::vector<Loop*> loops{};
590 if (!ids.empty()) {
591 auto function = GetLoopByIndex(0).GetHeaderBlock()->GetParent();
592 for (const auto& block : *function) {
593 auto block_id = block.id();
594
595 auto element = std::find(std::begin(ids), std::end(ids), block_id);
596 if (element != std::end(ids)) {
597 loops.push_back(&GetLoopByIndex(element - std::begin(ids)));
598 }
599 }
600 }
601
602 return loops;
603}
604
605BasicBlock* Loop::FindConditionBlock() const {
606 if (!loop_merge_) {
607 return nullptr;
608 }
609 BasicBlock* condition_block = nullptr;
610
611 uint32_t in_loop_pred = 0;
612 for (uint32_t p : context_->cfg()->preds(loop_merge_->id())) {
613 if (IsInsideLoop(p)) {
614 if (in_loop_pred) {
615 // 2 in-loop predecessors.
616 return nullptr;
617 }
618 in_loop_pred = p;
619 }
620 }
621 if (!in_loop_pred) {
622 // Merge block is unreachable.
623 return nullptr;
624 }
625
626 BasicBlock* bb = context_->cfg()->block(in_loop_pred);
627
628 if (!bb) return nullptr;
629
630 const Instruction& branch = *bb->ctail();
631
632 // Make sure the branch is a conditional branch.
633 if (branch.opcode() != SpvOpBranchConditional) return nullptr;
634
635 // Make sure one of the two possible branches is to the merge block.
636 if (branch.GetSingleWordInOperand(1) == loop_merge_->id() ||
637 branch.GetSingleWordInOperand(2) == loop_merge_->id()) {
638 condition_block = bb;
639 }
640
641 return condition_block;
642}
643
644bool Loop::FindNumberOfIterations(const Instruction* induction,
645 const Instruction* branch_inst,
646 size_t* iterations_out,
647 int64_t* step_value_out,
648 int64_t* init_value_out) const {
649 // From the branch instruction find the branch condition.
650 analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr();
651
652 // Condition instruction from the OpConditionalBranch.
653 Instruction* condition =
654 def_use_manager->GetDef(branch_inst->GetSingleWordOperand(0));
655
656 assert(IsSupportedCondition(condition->opcode()));
657
658 // Get the constant manager from the ir context.
659 analysis::ConstantManager* const_manager = context_->get_constant_mgr();
660
661 // Find the constant value used by the condition variable. Exit out if it
662 // isn't a constant int.
663 const analysis::Constant* upper_bound =
664 const_manager->FindDeclaredConstant(condition->GetSingleWordOperand(3));
665 if (!upper_bound) return false;
666
667 // Must be integer because of the opcode on the condition.
668 int64_t condition_value = 0;
669
670 const analysis::Integer* type =
671 upper_bound->AsIntConstant()->type()->AsInteger();
672
673 if (type->width() > 32) {
674 return false;
675 }
676
677 if (type->IsSigned()) {
678 condition_value = upper_bound->AsIntConstant()->GetS32BitValue();
679 } else {
680 condition_value = upper_bound->AsIntConstant()->GetU32BitValue();
681 }
682
683 // Find the instruction which is stepping through the loop.
684 Instruction* step_inst = GetInductionStepOperation(induction);
685 if (!step_inst) return false;
686
687 // Find the constant value used by the condition variable.
688 const analysis::Constant* step_constant =
689 const_manager->FindDeclaredConstant(step_inst->GetSingleWordOperand(3));
690 if (!step_constant) return false;
691
692 // Must be integer because of the opcode on the condition.
693 int64_t step_value = 0;
694
695 const analysis::Integer* step_type =
696 step_constant->AsIntConstant()->type()->AsInteger();
697
698 if (step_type->IsSigned()) {
699 step_value = step_constant->AsIntConstant()->GetS32BitValue();
700 } else {
701 step_value = step_constant->AsIntConstant()->GetU32BitValue();
702 }
703
704 // If this is a subtraction step we should negate the step value.
705 if (step_inst->opcode() == SpvOp::SpvOpISub) {
706 step_value = -step_value;
707 }
708
709 // Find the inital value of the loop and make sure it is a constant integer.
710 int64_t init_value = 0;
711 if (!GetInductionInitValue(induction, &init_value)) return false;
712
713 // If iterations is non null then store the value in that.
714 int64_t num_itrs = GetIterations(condition->opcode(), condition_value,
715 init_value, step_value);
716
717 // If the loop body will not be reached return false.
718 if (num_itrs <= 0) {
719 return false;
720 }
721
722 if (iterations_out) {
723 assert(static_cast<size_t>(num_itrs) <= std::numeric_limits<size_t>::max());
724 *iterations_out = static_cast<size_t>(num_itrs);
725 }
726
727 if (step_value_out) {
728 *step_value_out = step_value;
729 }
730
731 if (init_value_out) {
732 *init_value_out = init_value;
733 }
734
735 return true;
736}
737
738// We retrieve the number of iterations using the following formula, diff /
739// |step_value| where diff is calculated differently according to the
740// |condition| and uses the |condition_value| and |init_value|. If diff /
741// |step_value| is NOT cleanly divisable then we add one to the sum.
742int64_t Loop::GetIterations(SpvOp condition, int64_t condition_value,
743 int64_t init_value, int64_t step_value) const {
744 int64_t diff = 0;
745
746 switch (condition) {
747 case SpvOp::SpvOpSLessThan:
748 case SpvOp::SpvOpULessThan: {
749 // If the condition is not met to begin with the loop will never iterate.
750 if (!(init_value < condition_value)) return 0;
751
752 diff = condition_value - init_value;
753
754 // If the operation is a less then operation then the diff and step must
755 // have the same sign otherwise the induction will never cross the
756 // condition (either never true or always true).
757 if ((diff < 0 && step_value > 0) || (diff > 0 && step_value < 0)) {
758 return 0;
759 }
760
761 break;
762 }
763 case SpvOp::SpvOpSGreaterThan:
764 case SpvOp::SpvOpUGreaterThan: {
765 // If the condition is not met to begin with the loop will never iterate.
766 if (!(init_value > condition_value)) return 0;
767
768 diff = init_value - condition_value;
769
770 // If the operation is a greater than operation then the diff and step
771 // must have opposite signs. Otherwise the condition will always be true
772 // or will never be true.
773 if ((diff < 0 && step_value < 0) || (diff > 0 && step_value > 0)) {
774 return 0;
775 }
776
777 break;
778 }
779
780 case SpvOp::SpvOpSGreaterThanEqual:
781 case SpvOp::SpvOpUGreaterThanEqual: {
782 // If the condition is not met to begin with the loop will never iterate.
783 if (!(init_value >= condition_value)) return 0;
784
785 // We subract one to make it the same as SpvOpGreaterThan as it is
786 // functionally equivalent.
787 diff = init_value - (condition_value - 1);
788
789 // If the operation is a greater than operation then the diff and step
790 // must have opposite signs. Otherwise the condition will always be true
791 // or will never be true.
792 if ((diff > 0 && step_value > 0) || (diff < 0 && step_value < 0)) {
793 return 0;
794 }
795
796 break;
797 }
798
799 case SpvOp::SpvOpSLessThanEqual:
800 case SpvOp::SpvOpULessThanEqual: {
801 // If the condition is not met to begin with the loop will never iterate.
802 if (!(init_value <= condition_value)) return 0;
803
804 // We add one to make it the same as SpvOpLessThan as it is functionally
805 // equivalent.
806 diff = (condition_value + 1) - init_value;
807
808 // If the operation is a less than operation then the diff and step must
809 // have the same sign otherwise the induction will never cross the
810 // condition (either never true or always true).
811 if ((diff < 0 && step_value > 0) || (diff > 0 && step_value < 0)) {
812 return 0;
813 }
814
815 break;
816 }
817
818 default:
819 assert(false &&
820 "Could not retrieve number of iterations from the loop condition. "
821 "Condition is not supported.");
822 }
823
824 // Take the abs of - step values.
825 step_value = llabs(step_value);
826 diff = llabs(diff);
827 int64_t result = diff / step_value;
828
829 if (diff % step_value != 0) {
830 result += 1;
831 }
832 return result;
833}
834
835// Returns the list of induction variables within the loop.
836void Loop::GetInductionVariables(
837 std::vector<Instruction*>& induction_variables) const {
838 for (Instruction& inst : *loop_header_) {
839 if (inst.opcode() == SpvOp::SpvOpPhi) {
840 induction_variables.push_back(&inst);
841 }
842 }
843}
844
845Instruction* Loop::FindConditionVariable(
846 const BasicBlock* condition_block) const {
847 // Find the branch instruction.
848 const Instruction& branch_inst = *condition_block->ctail();
849
850 Instruction* induction = nullptr;
851 // Verify that the branch instruction is a conditional branch.
852 if (branch_inst.opcode() == SpvOp::SpvOpBranchConditional) {
853 // From the branch instruction find the branch condition.
854 analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr();
855
856 // Find the instruction representing the condition used in the conditional
857 // branch.
858 Instruction* condition =
859 def_use_manager->GetDef(branch_inst.GetSingleWordOperand(0));
860
861 // Ensure that the condition is a less than operation.
862 if (condition && IsSupportedCondition(condition->opcode())) {
863 // The left hand side operand of the operation.
864 Instruction* variable_inst =
865 def_use_manager->GetDef(condition->GetSingleWordOperand(2));
866
867 // Make sure the variable instruction used is a phi.
868 if (!variable_inst || variable_inst->opcode() != SpvOpPhi) return nullptr;
869
870 // Make sure the phi instruction only has two incoming blocks. Each
871 // incoming block will be represented by two in operands in the phi
872 // instruction, the value and the block which that value came from. We
873 // assume the cannocalised phi will have two incoming values, one from the
874 // preheader and one from the continue block.
875 size_t max_supported_operands = 4;
876 if (variable_inst->NumInOperands() == max_supported_operands) {
877 // The operand index of the first incoming block label.
878 uint32_t operand_label_1 = 1;
879
880 // The operand index of the second incoming block label.
881 uint32_t operand_label_2 = 3;
882
883 // Make sure one of them is the preheader.
884 if (!IsInsideLoop(
885 variable_inst->GetSingleWordInOperand(operand_label_1)) &&
886 !IsInsideLoop(
887 variable_inst->GetSingleWordInOperand(operand_label_2))) {
888 return nullptr;
889 }
890
891 // And make sure that the other is the latch block.
892 if (variable_inst->GetSingleWordInOperand(operand_label_1) !=
893 loop_latch_->id() &&
894 variable_inst->GetSingleWordInOperand(operand_label_2) !=
895 loop_latch_->id()) {
896 return nullptr;
897 }
898 } else {
899 return nullptr;
900 }
901
902 if (!FindNumberOfIterations(variable_inst, &branch_inst, nullptr))
903 return nullptr;
904 induction = variable_inst;
905 }
906 }
907
908 return induction;
909}
910
911bool LoopDescriptor::CreatePreHeaderBlocksIfMissing() {
912 auto modified = false;
913
914 for (auto& loop : *this) {
915 if (!loop.GetPreHeaderBlock()) {
916 modified = true;
917 // TODO(1841): Handle failure to create pre-header.
918 loop.GetOrCreatePreHeaderBlock();
919 }
920 }
921
922 return modified;
923}
924
// Add and remove loops which have been marked for addition and removal to
// maintain the state of the loop descriptor class.
//
// First, every loop in |loops_| flagged by IsMarkedForRemoval() is detached
// from its parent loop (if any), erased from |loops_| and deleted. Then every
// pending (parent, loop) entry of |loops_to_add_| is appended to |loops_|.
// When a parent is recorded, the loop is first nested under it and its blocks
// are added to that parent. |loops_to_add_| is cleared at the end.
//
// NOTE(review): a removed loop without a parent is not erased from
// dummy_top_loop_.nested_loops_, and an added loop without a parent is not
// appended to it. basic_block_to_loop_ is not updated here either. Confirm
// callers rebuild or do not rely on these afterwards.
void LoopDescriptor::PostModificationCleanup() {
  // Collect first: erasing from |loops_| while iterating it would invalidate
  // the range-for iterator.
  LoopContainerType loops_to_remove_;
  for (Loop* loop : loops_) {
    if (loop->IsMarkedForRemoval()) {
      loops_to_remove_.push_back(loop);
      if (loop->HasParent()) {
        loop->GetParent()->RemoveChildLoop(loop);
      }
    }
  }

  for (Loop* loop : loops_to_remove_) {
    loops_.erase(std::find(loops_.begin(), loops_.end(), loop));
    delete loop;
  }

  for (auto& pair : loops_to_add_) {
    Loop* parent = pair.first;
    // Take ownership; released into |loops_| below.
    std::unique_ptr<Loop> loop = std::move(pair.second);

    if (parent) {
      // Parent is cleared before AddNestedLoop; presumably AddNestedLoop
      // expects (or asserts) an unparented loop -- confirm in the header.
      loop->SetParent(nullptr);
      parent->AddNestedLoop(loop.get());

      for (uint32_t block_id : loop->GetBlocks()) {
        parent->AddBasicBlock(block_id);
      }
    }

    loops_.emplace_back(loop.release());
  }

  loops_to_add_.clear();
}
961
962void LoopDescriptor::ClearLoops() {
963 for (Loop* loop : loops_) {
964 delete loop;
965 }
966 loops_.clear();
967}
968
969// Adds a new loop nest to the descriptor set.
970Loop* LoopDescriptor::AddLoopNest(std::unique_ptr<Loop> new_loop) {
971 Loop* loop = new_loop.release();
972 if (!loop->HasParent()) dummy_top_loop_.nested_loops_.push_back(loop);
973 // Iterate from inner to outer most loop, adding basic block to loop mapping
974 // as we go.
975 for (Loop& current_loop :
976 make_range(iterator::begin(loop), iterator::end(nullptr))) {
977 loops_.push_back(&current_loop);
978 for (uint32_t bb_id : current_loop.GetBlocks())
979 basic_block_to_loop_.insert(std::make_pair(bb_id, &current_loop));
980 }
981
982 return loop;
983}
984
// Removes |loop| from the descriptor and deletes it.
//
// |loop| is detached from its parent (or from the placeholder top loop when
// it has none), and its children are re-parented one level up. Each block
// whose innermost loop was |loop| is re-mapped to |loop|'s parent.
//
// NOTE(review): blocks of |loop| whose innermost loop is a *different* loop
// (a nested loop that survives this call) have their mapping dropped via
// ForgetBasicBlock rather than kept. Confirm this is intended.
void LoopDescriptor::RemoveLoop(Loop* loop) {
  // Top-level loops are children of the placeholder dummy_top_loop_.
  Loop* parent = loop->GetParent() ? loop->GetParent() : &dummy_top_loop_;
  // Detach |loop| from its parent's child list...
  parent->nested_loops_.erase(std::find(parent->nested_loops_.begin(),
                                        parent->nested_loops_.end(), loop));
  // ...and hoist its children up to that parent.
  std::for_each(
      loop->nested_loops_.begin(), loop->nested_loops_.end(),
      [loop](Loop* sub_loop) { sub_loop->SetParent(loop->GetParent()); });
  parent->nested_loops_.insert(parent->nested_loops_.end(),
                               loop->nested_loops_.begin(),
                               loop->nested_loops_.end());
  for (uint32_t bb_id : loop->GetBlocks()) {
    Loop* l = FindLoopForBasicBlock(bb_id);
    if (l == loop) {
      // For a top-level |loop| this passes nullptr; presumably that clears
      // the mapping -- confirm in SetBasicBlockToLoop.
      SetBasicBlockToLoop(bb_id, l->GetParent());
    } else {
      ForgetBasicBlock(bb_id);
    }
  }

  LoopContainerType::iterator it =
      std::find(loops_.begin(), loops_.end(), loop);
  assert(it != loops_.end());
  // Deleting the pointee does not invalidate |it|, so erasing afterwards is
  // safe.
  delete loop;
  loops_.erase(it);
}
1010
1011} // namespace opt
1012} // namespace spvtools
1013