| 1 | // Copyright (c) 2016 Google Inc. |
| 2 | // |
| 3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | // you may not use this file except in compliance with the License. |
| 5 | // You may obtain a copy of the License at |
| 6 | // |
| 7 | // http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | // |
| 9 | // Unless required by applicable law or agreed to in writing, software |
| 10 | // distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | // See the License for the specific language governing permissions and |
| 13 | // limitations under the License. |
| 14 | |
#include "source/opt/basic_block.h"

#include <iostream>
#include <ostream>
#include <sstream>

#include "source/opt/function.h"
#include "source/opt/ir_context.h"
#include "source/opt/module.h"
#include "source/opt/reflect.h"
#include "source/util/make_unique.h"
| 24 | |
| 25 | namespace spvtools { |
| 26 | namespace opt { |
namespace {

// In-operand indices used with GetSingleWordInOperand() when reading the
// block ids out of a merge instruction.

// OpLoopMerge: in-operand 0 is the merge block, in-operand 1 the continue
// target.
constexpr uint32_t kLoopMergeMergeBlockIdInIdx = 0;
constexpr uint32_t kLoopMergeContinueBlockIdInIdx = 1;

// OpSelectionMerge: in-operand 0 is the merge block.
constexpr uint32_t kSelectionMergeMergeBlockIdInIdx = 0;

}  // namespace
| 34 | |
| 35 | BasicBlock* BasicBlock::Clone(IRContext* context) const { |
| 36 | BasicBlock* clone = new BasicBlock( |
| 37 | std::unique_ptr<Instruction>(GetLabelInst()->Clone(context))); |
| 38 | for (const auto& inst : insts_) { |
| 39 | // Use the incoming context |
| 40 | clone->AddInstruction(std::unique_ptr<Instruction>(inst.Clone(context))); |
| 41 | } |
| 42 | |
| 43 | if (context->AreAnalysesValid( |
| 44 | IRContext::Analysis::kAnalysisInstrToBlockMapping)) { |
| 45 | for (auto& inst : *clone) { |
| 46 | context->set_instr_block(&inst, clone); |
| 47 | } |
| 48 | } |
| 49 | |
| 50 | return clone; |
| 51 | } |
| 52 | |
| 53 | const Instruction* BasicBlock::GetMergeInst() const { |
| 54 | const Instruction* result = nullptr; |
| 55 | // If it exists, the merge instruction immediately precedes the |
| 56 | // terminator. |
| 57 | auto iter = ctail(); |
| 58 | if (iter != cbegin()) { |
| 59 | --iter; |
| 60 | const auto opcode = iter->opcode(); |
| 61 | if (opcode == SpvOpLoopMerge || opcode == SpvOpSelectionMerge) { |
| 62 | result = &*iter; |
| 63 | } |
| 64 | } |
| 65 | return result; |
| 66 | } |
| 67 | |
| 68 | Instruction* BasicBlock::GetMergeInst() { |
| 69 | Instruction* result = nullptr; |
| 70 | // If it exists, the merge instruction immediately precedes the |
| 71 | // terminator. |
| 72 | auto iter = tail(); |
| 73 | if (iter != begin()) { |
| 74 | --iter; |
| 75 | const auto opcode = iter->opcode(); |
| 76 | if (opcode == SpvOpLoopMerge || opcode == SpvOpSelectionMerge) { |
| 77 | result = &*iter; |
| 78 | } |
| 79 | } |
| 80 | return result; |
| 81 | } |
| 82 | |
| 83 | const Instruction* BasicBlock::GetLoopMergeInst() const { |
| 84 | if (auto* merge = GetMergeInst()) { |
| 85 | if (merge->opcode() == SpvOpLoopMerge) { |
| 86 | return merge; |
| 87 | } |
| 88 | } |
| 89 | return nullptr; |
| 90 | } |
| 91 | |
| 92 | Instruction* BasicBlock::GetLoopMergeInst() { |
| 93 | if (auto* merge = GetMergeInst()) { |
| 94 | if (merge->opcode() == SpvOpLoopMerge) { |
| 95 | return merge; |
| 96 | } |
| 97 | } |
| 98 | return nullptr; |
| 99 | } |
| 100 | |
| 101 | void BasicBlock::KillAllInsts(bool killLabel) { |
| 102 | ForEachInst([killLabel](Instruction* ip) { |
| 103 | if (killLabel || ip->opcode() != SpvOpLabel) { |
| 104 | ip->context()->KillInst(ip); |
| 105 | } |
| 106 | }); |
| 107 | } |
| 108 | |
| 109 | void BasicBlock::ForEachSuccessorLabel( |
| 110 | const std::function<void(const uint32_t)>& f) const { |
| 111 | WhileEachSuccessorLabel([f](const uint32_t l) { |
| 112 | f(l); |
| 113 | return true; |
| 114 | }); |
| 115 | } |
| 116 | |
| 117 | bool BasicBlock::WhileEachSuccessorLabel( |
| 118 | const std::function<bool(const uint32_t)>& f) const { |
| 119 | const auto br = &insts_.back(); |
| 120 | switch (br->opcode()) { |
| 121 | case SpvOpBranch: |
| 122 | return f(br->GetOperand(0).words[0]); |
| 123 | case SpvOpBranchConditional: |
| 124 | case SpvOpSwitch: { |
| 125 | bool is_first = true; |
| 126 | return br->WhileEachInId([&is_first, &f](const uint32_t* idp) { |
| 127 | if (!is_first) return f(*idp); |
| 128 | is_first = false; |
| 129 | return true; |
| 130 | }); |
| 131 | } |
| 132 | default: |
| 133 | return true; |
| 134 | } |
| 135 | } |
| 136 | |
| 137 | void BasicBlock::ForEachSuccessorLabel( |
| 138 | const std::function<void(uint32_t*)>& f) { |
| 139 | auto br = &insts_.back(); |
| 140 | switch (br->opcode()) { |
| 141 | case SpvOpBranch: { |
| 142 | uint32_t tmp_id = br->GetOperand(0).words[0]; |
| 143 | f(&tmp_id); |
| 144 | if (tmp_id != br->GetOperand(0).words[0]) br->SetOperand(0, {tmp_id}); |
| 145 | } break; |
| 146 | case SpvOpBranchConditional: |
| 147 | case SpvOpSwitch: { |
| 148 | bool is_first = true; |
| 149 | br->ForEachInId([&is_first, &f](uint32_t* idp) { |
| 150 | if (!is_first) f(idp); |
| 151 | is_first = false; |
| 152 | }); |
| 153 | } break; |
| 154 | default: |
| 155 | break; |
| 156 | } |
| 157 | } |
| 158 | |
| 159 | bool BasicBlock::IsSuccessor(const BasicBlock* block) const { |
| 160 | uint32_t succId = block->id(); |
| 161 | bool isSuccessor = false; |
| 162 | ForEachSuccessorLabel([&isSuccessor, succId](const uint32_t label) { |
| 163 | if (label == succId) isSuccessor = true; |
| 164 | }); |
| 165 | return isSuccessor; |
| 166 | } |
| 167 | |
| 168 | void BasicBlock::ForMergeAndContinueLabel( |
| 169 | const std::function<void(const uint32_t)>& f) { |
| 170 | auto ii = insts_.end(); |
| 171 | --ii; |
| 172 | if (ii == insts_.begin()) return; |
| 173 | --ii; |
| 174 | if (ii->opcode() == SpvOpSelectionMerge || ii->opcode() == SpvOpLoopMerge) { |
| 175 | ii->ForEachInId([&f](const uint32_t* idp) { f(*idp); }); |
| 176 | } |
| 177 | } |
| 178 | |
| 179 | uint32_t BasicBlock::MergeBlockIdIfAny() const { |
| 180 | auto merge_ii = cend(); |
| 181 | --merge_ii; |
| 182 | uint32_t mbid = 0; |
| 183 | if (merge_ii != cbegin()) { |
| 184 | --merge_ii; |
| 185 | if (merge_ii->opcode() == SpvOpLoopMerge) { |
| 186 | mbid = merge_ii->GetSingleWordInOperand(kLoopMergeMergeBlockIdInIdx); |
| 187 | } else if (merge_ii->opcode() == SpvOpSelectionMerge) { |
| 188 | mbid = merge_ii->GetSingleWordInOperand(kSelectionMergeMergeBlockIdInIdx); |
| 189 | } |
| 190 | } |
| 191 | |
| 192 | return mbid; |
| 193 | } |
| 194 | |
| 195 | uint32_t BasicBlock::MergeBlockId() const { |
| 196 | uint32_t mbid = MergeBlockIdIfAny(); |
| 197 | assert(mbid && "Expected block to have a corresponding merge block" ); |
| 198 | return mbid; |
| 199 | } |
| 200 | |
| 201 | uint32_t BasicBlock::ContinueBlockIdIfAny() const { |
| 202 | auto merge_ii = cend(); |
| 203 | --merge_ii; |
| 204 | uint32_t cbid = 0; |
| 205 | if (merge_ii != cbegin()) { |
| 206 | --merge_ii; |
| 207 | if (merge_ii->opcode() == SpvOpLoopMerge) { |
| 208 | cbid = merge_ii->GetSingleWordInOperand(kLoopMergeContinueBlockIdInIdx); |
| 209 | } |
| 210 | } |
| 211 | return cbid; |
| 212 | } |
| 213 | |
| 214 | uint32_t BasicBlock::ContinueBlockId() const { |
| 215 | uint32_t cbid = ContinueBlockIdIfAny(); |
| 216 | assert(cbid && "Expected block to have a corresponding continue target" ); |
| 217 | return cbid; |
| 218 | } |
| 219 | |
| 220 | std::ostream& operator<<(std::ostream& str, const BasicBlock& block) { |
| 221 | str << block.PrettyPrint(); |
| 222 | return str; |
| 223 | } |
| 224 | |
| 225 | void BasicBlock::Dump() const { |
| 226 | std::cerr << "Basic block #" << id() << "\n" << *this << "\n " ; |
| 227 | } |
| 228 | |
| 229 | std::string BasicBlock::PrettyPrint(uint32_t options) const { |
| 230 | std::ostringstream str; |
| 231 | ForEachInst([&str, options](const Instruction* inst) { |
| 232 | str << inst->PrettyPrint(options); |
| 233 | if (!IsTerminatorInst(inst->opcode())) { |
| 234 | str << std::endl; |
| 235 | } |
| 236 | }); |
| 237 | return str.str(); |
| 238 | } |
| 239 | |
| 240 | BasicBlock* BasicBlock::SplitBasicBlock(IRContext* context, uint32_t label_id, |
| 241 | iterator iter) { |
| 242 | assert(!insts_.empty()); |
| 243 | |
| 244 | std::unique_ptr<BasicBlock> new_block_temp = |
| 245 | MakeUnique<BasicBlock>(MakeUnique<Instruction>( |
| 246 | context, SpvOpLabel, 0, label_id, std::initializer_list<Operand>{})); |
| 247 | BasicBlock* new_block = new_block_temp.get(); |
| 248 | function_->InsertBasicBlockAfter(std::move(new_block_temp), this); |
| 249 | |
| 250 | new_block->insts_.Splice(new_block->end(), &insts_, iter, end()); |
| 251 | new_block->SetParent(GetParent()); |
| 252 | |
| 253 | context->AnalyzeDefUse(new_block->GetLabelInst()); |
| 254 | |
| 255 | // Update the phi nodes in the successor blocks to reference the new block id. |
| 256 | const_cast<const BasicBlock*>(new_block)->ForEachSuccessorLabel( |
| 257 | [new_block, this, context](const uint32_t label) { |
| 258 | BasicBlock* target_bb = context->get_instr_block(label); |
| 259 | target_bb->ForEachPhiInst( |
| 260 | [this, new_block, context](Instruction* phi_inst) { |
| 261 | bool changed = false; |
| 262 | for (uint32_t i = 1; i < phi_inst->NumInOperands(); i += 2) { |
| 263 | if (phi_inst->GetSingleWordInOperand(i) == this->id()) { |
| 264 | changed = true; |
| 265 | phi_inst->SetInOperand(i, {new_block->id()}); |
| 266 | } |
| 267 | } |
| 268 | |
| 269 | if (changed) { |
| 270 | context->UpdateDefUse(phi_inst); |
| 271 | } |
| 272 | }); |
| 273 | }); |
| 274 | |
| 275 | if (context->AreAnalysesValid(IRContext::kAnalysisInstrToBlockMapping)) { |
| 276 | context->set_instr_block(new_block->GetLabelInst(), new_block); |
| 277 | new_block->ForEachInst([new_block, context](Instruction* inst) { |
| 278 | context->set_instr_block(inst, new_block); |
| 279 | }); |
| 280 | } |
| 281 | |
| 282 | return new_block; |
| 283 | } |
| 284 | |
| 285 | } // namespace opt |
| 286 | } // namespace spvtools |
| 287 | |