| 1 | // Copyright (c) 2018 Google LLC. |
| 2 | // |
| 3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | // you may not use this file except in compliance with the License. |
| 5 | // You may obtain a copy of the License at |
| 6 | // |
| 7 | // http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | // |
| 9 | // Unless required by applicable law or agreed to in writing, software |
| 10 | // distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | // See the License for the specific language governing permissions and |
| 13 | // limitations under the License. |
| 14 | |
| 15 | #include "source/opt/scalar_analysis.h" |
| 16 | |
| 17 | #include <algorithm> |
| 18 | #include <functional> |
| 19 | #include <string> |
| 20 | #include <utility> |
| 21 | |
| 22 | #include "source/opt/ir_context.h" |
| 23 | |
| 24 | // Transforms a given scalar operation instruction into a DAG representation. |
| 25 | // |
| 26 | // 1. Take an instruction and traverse its operands until we reach a |
| 27 | // constant node or an instruction which we do not know how to compute the |
| 28 | // value, such as a load. |
| 29 | // |
| 30 | // 2. Create a new node for each instruction traversed and build the nodes for |
| 31 | // the in operands of that instruction as well. |
| 32 | // |
| 33 | // 3. Add the operand nodes as children of the first and hash the node. Use the |
| 34 | // hash to see if the node is already in the cache. We ensure the children are |
| 35 | // always in sorted order so that two nodes with the same children but inserted |
| 36 | // in a different order have the same hash and so that the overloaded operator== |
| 37 | // will return true. If the node is already in the cache return the cached |
| 38 | // version instead. |
| 39 | // |
| 40 | // 4. The created DAG can then be simplified by |
| 41 | // ScalarAnalysis::SimplifyExpression, implemented in |
| 42 | // scalar_analysis_simplification.cpp. See that file for further information on |
| 43 | // the simplification process. |
| 44 | // |
| 45 | |
| 46 | namespace spvtools { |
| 47 | namespace opt { |
| 48 | |
// Out-of-class definition of the static SENode::NumberOfNodes member,
// initialised to zero.
// NOTE(review): presumably this counts constructed nodes and/or hands out
// unique node ids — confirm against the SENode declaration.
uint32_t SENode::NumberOfNodes = 0;
| 50 | |
| 51 | ScalarEvolutionAnalysis::ScalarEvolutionAnalysis(IRContext* context) |
| 52 | : context_(context), pretend_equal_{} { |
| 53 | // Create and cached the CantComputeNode. |
| 54 | cached_cant_compute_ = |
| 55 | GetCachedOrAdd(std::unique_ptr<SECantCompute>(new SECantCompute(this))); |
| 56 | } |
| 57 | |
| 58 | SENode* ScalarEvolutionAnalysis::CreateNegation(SENode* operand) { |
| 59 | // If operand is can't compute then the whole graph is can't compute. |
| 60 | if (operand->IsCantCompute()) return CreateCantComputeNode(); |
| 61 | |
| 62 | if (operand->GetType() == SENode::Constant) { |
| 63 | return CreateConstant(-operand->AsSEConstantNode()->FoldToSingleValue()); |
| 64 | } |
| 65 | std::unique_ptr<SENode> negation_node{new SENegative(this)}; |
| 66 | negation_node->AddChild(operand); |
| 67 | return GetCachedOrAdd(std::move(negation_node)); |
| 68 | } |
| 69 | |
| 70 | SENode* ScalarEvolutionAnalysis::CreateConstant(int64_t integer) { |
| 71 | return GetCachedOrAdd( |
| 72 | std::unique_ptr<SENode>(new SEConstantNode(this, integer))); |
| 73 | } |
| 74 | |
| 75 | SENode* ScalarEvolutionAnalysis::CreateRecurrentExpression( |
| 76 | const Loop* loop, SENode* offset, SENode* coefficient) { |
| 77 | assert(loop && "Recurrent add expressions must have a valid loop." ); |
| 78 | |
| 79 | // If operands are can't compute then the whole graph is can't compute. |
| 80 | if (offset->IsCantCompute() || coefficient->IsCantCompute()) |
| 81 | return CreateCantComputeNode(); |
| 82 | |
| 83 | const Loop* loop_to_use = nullptr; |
| 84 | if (pretend_equal_[loop]) { |
| 85 | loop_to_use = pretend_equal_[loop]; |
| 86 | } else { |
| 87 | loop_to_use = loop; |
| 88 | } |
| 89 | |
| 90 | std::unique_ptr<SERecurrentNode> phi_node{ |
| 91 | new SERecurrentNode(this, loop_to_use)}; |
| 92 | phi_node->AddOffset(offset); |
| 93 | phi_node->AddCoefficient(coefficient); |
| 94 | |
| 95 | return GetCachedOrAdd(std::move(phi_node)); |
| 96 | } |
| 97 | |
| 98 | SENode* ScalarEvolutionAnalysis::AnalyzeMultiplyOp( |
| 99 | const Instruction* multiply) { |
| 100 | assert(multiply->opcode() == SpvOp::SpvOpIMul && |
| 101 | "Multiply node did not come from a multiply instruction" ); |
| 102 | analysis::DefUseManager* def_use = context_->get_def_use_mgr(); |
| 103 | |
| 104 | SENode* op1 = |
| 105 | AnalyzeInstruction(def_use->GetDef(multiply->GetSingleWordInOperand(0))); |
| 106 | SENode* op2 = |
| 107 | AnalyzeInstruction(def_use->GetDef(multiply->GetSingleWordInOperand(1))); |
| 108 | |
| 109 | return CreateMultiplyNode(op1, op2); |
| 110 | } |
| 111 | |
| 112 | SENode* ScalarEvolutionAnalysis::CreateMultiplyNode(SENode* operand_1, |
| 113 | SENode* operand_2) { |
| 114 | // If operands are can't compute then the whole graph is can't compute. |
| 115 | if (operand_1->IsCantCompute() || operand_2->IsCantCompute()) |
| 116 | return CreateCantComputeNode(); |
| 117 | |
| 118 | if (operand_1->GetType() == SENode::Constant && |
| 119 | operand_2->GetType() == SENode::Constant) { |
| 120 | return CreateConstant(operand_1->AsSEConstantNode()->FoldToSingleValue() * |
| 121 | operand_2->AsSEConstantNode()->FoldToSingleValue()); |
| 122 | } |
| 123 | |
| 124 | std::unique_ptr<SENode> multiply_node{new SEMultiplyNode(this)}; |
| 125 | |
| 126 | multiply_node->AddChild(operand_1); |
| 127 | multiply_node->AddChild(operand_2); |
| 128 | |
| 129 | return GetCachedOrAdd(std::move(multiply_node)); |
| 130 | } |
| 131 | |
| 132 | SENode* ScalarEvolutionAnalysis::CreateSubtraction(SENode* operand_1, |
| 133 | SENode* operand_2) { |
| 134 | // Fold if both operands are constant. |
| 135 | if (operand_1->GetType() == SENode::Constant && |
| 136 | operand_2->GetType() == SENode::Constant) { |
| 137 | return CreateConstant(operand_1->AsSEConstantNode()->FoldToSingleValue() - |
| 138 | operand_2->AsSEConstantNode()->FoldToSingleValue()); |
| 139 | } |
| 140 | |
| 141 | return CreateAddNode(operand_1, CreateNegation(operand_2)); |
| 142 | } |
| 143 | |
| 144 | SENode* ScalarEvolutionAnalysis::CreateAddNode(SENode* operand_1, |
| 145 | SENode* operand_2) { |
| 146 | // Fold if both operands are constant and the |simplify| flag is true. |
| 147 | if (operand_1->GetType() == SENode::Constant && |
| 148 | operand_2->GetType() == SENode::Constant) { |
| 149 | return CreateConstant(operand_1->AsSEConstantNode()->FoldToSingleValue() + |
| 150 | operand_2->AsSEConstantNode()->FoldToSingleValue()); |
| 151 | } |
| 152 | |
| 153 | // If operands are can't compute then the whole graph is can't compute. |
| 154 | if (operand_1->IsCantCompute() || operand_2->IsCantCompute()) |
| 155 | return CreateCantComputeNode(); |
| 156 | |
| 157 | std::unique_ptr<SENode> add_node{new SEAddNode(this)}; |
| 158 | |
| 159 | add_node->AddChild(operand_1); |
| 160 | add_node->AddChild(operand_2); |
| 161 | |
| 162 | return GetCachedOrAdd(std::move(add_node)); |
| 163 | } |
| 164 | |
| 165 | SENode* ScalarEvolutionAnalysis::AnalyzeInstruction(const Instruction* inst) { |
| 166 | auto itr = recurrent_node_map_.find(inst); |
| 167 | if (itr != recurrent_node_map_.end()) return itr->second; |
| 168 | |
| 169 | SENode* output = nullptr; |
| 170 | switch (inst->opcode()) { |
| 171 | case SpvOp::SpvOpPhi: { |
| 172 | output = AnalyzePhiInstruction(inst); |
| 173 | break; |
| 174 | } |
| 175 | case SpvOp::SpvOpConstant: |
| 176 | case SpvOp::SpvOpConstantNull: { |
| 177 | output = AnalyzeConstant(inst); |
| 178 | break; |
| 179 | } |
| 180 | case SpvOp::SpvOpISub: |
| 181 | case SpvOp::SpvOpIAdd: { |
| 182 | output = AnalyzeAddOp(inst); |
| 183 | break; |
| 184 | } |
| 185 | case SpvOp::SpvOpIMul: { |
| 186 | output = AnalyzeMultiplyOp(inst); |
| 187 | break; |
| 188 | } |
| 189 | default: { |
| 190 | output = CreateValueUnknownNode(inst); |
| 191 | break; |
| 192 | } |
| 193 | } |
| 194 | |
| 195 | return output; |
| 196 | } |
| 197 | |
| 198 | SENode* ScalarEvolutionAnalysis::AnalyzeConstant(const Instruction* inst) { |
| 199 | if (inst->opcode() == SpvOp::SpvOpConstantNull) return CreateConstant(0); |
| 200 | |
| 201 | assert(inst->opcode() == SpvOp::SpvOpConstant); |
| 202 | assert(inst->NumInOperands() == 1); |
| 203 | int64_t value = 0; |
| 204 | |
| 205 | // Look up the instruction in the constant manager. |
| 206 | const analysis::Constant* constant = |
| 207 | context_->get_constant_mgr()->FindDeclaredConstant(inst->result_id()); |
| 208 | |
| 209 | if (!constant) return CreateCantComputeNode(); |
| 210 | |
| 211 | const analysis::IntConstant* int_constant = constant->AsIntConstant(); |
| 212 | |
| 213 | // Exit out if it is a 64 bit integer. |
| 214 | if (!int_constant || int_constant->words().size() != 1) |
| 215 | return CreateCantComputeNode(); |
| 216 | |
| 217 | if (int_constant->type()->AsInteger()->IsSigned()) { |
| 218 | value = int_constant->GetS32BitValue(); |
| 219 | } else { |
| 220 | value = int_constant->GetU32BitValue(); |
| 221 | } |
| 222 | |
| 223 | return CreateConstant(value); |
| 224 | } |
| 225 | |
| 226 | // Handles both addition and subtraction. If the |sub| flag is set then the |
| 227 | // addition will be op1+(-op2) otherwise op1+op2. |
| 228 | SENode* ScalarEvolutionAnalysis::AnalyzeAddOp(const Instruction* inst) { |
| 229 | assert((inst->opcode() == SpvOp::SpvOpIAdd || |
| 230 | inst->opcode() == SpvOp::SpvOpISub) && |
| 231 | "Add node must be created from a OpIAdd or OpISub instruction" ); |
| 232 | |
| 233 | analysis::DefUseManager* def_use = context_->get_def_use_mgr(); |
| 234 | |
| 235 | SENode* op1 = |
| 236 | AnalyzeInstruction(def_use->GetDef(inst->GetSingleWordInOperand(0))); |
| 237 | |
| 238 | SENode* op2 = |
| 239 | AnalyzeInstruction(def_use->GetDef(inst->GetSingleWordInOperand(1))); |
| 240 | |
| 241 | // To handle subtraction we wrap the second operand in a unary negation node. |
| 242 | if (inst->opcode() == SpvOp::SpvOpISub) { |
| 243 | op2 = CreateNegation(op2); |
| 244 | } |
| 245 | |
| 246 | return CreateAddNode(op1, op2); |
| 247 | } |
| 248 | |
| 249 | SENode* ScalarEvolutionAnalysis::AnalyzePhiInstruction(const Instruction* phi) { |
| 250 | // The phi should only have two incoming value pairs. |
| 251 | if (phi->NumInOperands() != 4) { |
| 252 | return CreateCantComputeNode(); |
| 253 | } |
| 254 | |
| 255 | analysis::DefUseManager* def_use = context_->get_def_use_mgr(); |
| 256 | |
| 257 | // Get the basic block this instruction belongs to. |
| 258 | BasicBlock* basic_block = |
| 259 | context_->get_instr_block(const_cast<Instruction*>(phi)); |
| 260 | |
| 261 | // And then the function that the basic blocks belongs to. |
| 262 | Function* function = basic_block->GetParent(); |
| 263 | |
| 264 | // Use the function to get the loop descriptor. |
| 265 | LoopDescriptor* loop_descriptor = context_->GetLoopDescriptor(function); |
| 266 | |
| 267 | // We only handle phis in loops at the moment. |
| 268 | if (!loop_descriptor) return CreateCantComputeNode(); |
| 269 | |
| 270 | // Get the innermost loop which this block belongs to. |
| 271 | Loop* loop = (*loop_descriptor)[basic_block->id()]; |
| 272 | |
| 273 | // If the loop doesn't exist or doesn't have a preheader or latch block, exit |
| 274 | // out. |
| 275 | if (!loop || !loop->GetLatchBlock() || !loop->GetPreHeaderBlock() || |
| 276 | loop->GetHeaderBlock() != basic_block) |
| 277 | return recurrent_node_map_[phi] = CreateCantComputeNode(); |
| 278 | |
| 279 | const Loop* loop_to_use = nullptr; |
| 280 | if (pretend_equal_[loop]) { |
| 281 | loop_to_use = pretend_equal_[loop]; |
| 282 | } else { |
| 283 | loop_to_use = loop; |
| 284 | } |
| 285 | std::unique_ptr<SERecurrentNode> phi_node{ |
| 286 | new SERecurrentNode(this, loop_to_use)}; |
| 287 | |
| 288 | // We add the node to this map to allow it to be returned before the node is |
| 289 | // fully built. This is needed as the subsequent call to AnalyzeInstruction |
| 290 | // could lead back to this |phi| instruction so we return the pointer |
| 291 | // immediately in AnalyzeInstruction to break the recursion. |
| 292 | recurrent_node_map_[phi] = phi_node.get(); |
| 293 | |
| 294 | // Traverse the operands of the instruction an create new nodes for each one. |
| 295 | for (uint32_t i = 0; i < phi->NumInOperands(); i += 2) { |
| 296 | uint32_t value_id = phi->GetSingleWordInOperand(i); |
| 297 | uint32_t incoming_label_id = phi->GetSingleWordInOperand(i + 1); |
| 298 | |
| 299 | Instruction* value_inst = def_use->GetDef(value_id); |
| 300 | SENode* value_node = AnalyzeInstruction(value_inst); |
| 301 | |
| 302 | // If any operand is CantCompute then the whole graph is CantCompute. |
| 303 | if (value_node->IsCantCompute()) |
| 304 | return recurrent_node_map_[phi] = CreateCantComputeNode(); |
| 305 | |
| 306 | // If the value is coming from the preheader block then the value is the |
| 307 | // initial value of the phi. |
| 308 | if (incoming_label_id == loop->GetPreHeaderBlock()->id()) { |
| 309 | phi_node->AddOffset(value_node); |
| 310 | } else if (incoming_label_id == loop->GetLatchBlock()->id()) { |
| 311 | // Assumed to be in the form of step + phi. |
| 312 | if (value_node->GetType() != SENode::Add) |
| 313 | return recurrent_node_map_[phi] = CreateCantComputeNode(); |
| 314 | |
| 315 | SENode* step_node = nullptr; |
| 316 | SENode* phi_operand = nullptr; |
| 317 | SENode* operand_1 = value_node->GetChild(0); |
| 318 | SENode* operand_2 = value_node->GetChild(1); |
| 319 | |
| 320 | // Find which node is the step term. |
| 321 | if (!operand_1->AsSERecurrentNode()) |
| 322 | step_node = operand_1; |
| 323 | else if (!operand_2->AsSERecurrentNode()) |
| 324 | step_node = operand_2; |
| 325 | |
| 326 | // Find which node is the recurrent expression. |
| 327 | if (operand_1->AsSERecurrentNode()) |
| 328 | phi_operand = operand_1; |
| 329 | else if (operand_2->AsSERecurrentNode()) |
| 330 | phi_operand = operand_2; |
| 331 | |
| 332 | // If it is not in the form step + phi exit out. |
| 333 | if (!(step_node && phi_operand)) |
| 334 | return recurrent_node_map_[phi] = CreateCantComputeNode(); |
| 335 | |
| 336 | // If the phi operand is not the same phi node exit out. |
| 337 | if (phi_operand != phi_node.get()) |
| 338 | return recurrent_node_map_[phi] = CreateCantComputeNode(); |
| 339 | |
| 340 | if (!IsLoopInvariant(loop, step_node)) |
| 341 | return recurrent_node_map_[phi] = CreateCantComputeNode(); |
| 342 | |
| 343 | phi_node->AddCoefficient(step_node); |
| 344 | } |
| 345 | } |
| 346 | |
| 347 | // Once the node is fully built we update the map with the version from the |
| 348 | // cache (if it has already been added to the cache). |
| 349 | return recurrent_node_map_[phi] = GetCachedOrAdd(std::move(phi_node)); |
| 350 | } |
| 351 | |
| 352 | SENode* ScalarEvolutionAnalysis::CreateValueUnknownNode( |
| 353 | const Instruction* inst) { |
| 354 | std::unique_ptr<SEValueUnknown> load_node{ |
| 355 | new SEValueUnknown(this, inst->result_id())}; |
| 356 | return GetCachedOrAdd(std::move(load_node)); |
| 357 | } |
| 358 | |
// Returns the shared can't-compute node that the constructor created and
// stored in |cached_cant_compute_|; no new node is allocated here.
SENode* ScalarEvolutionAnalysis::CreateCantComputeNode() {
  return cached_cant_compute_;
}
| 362 | |
| 363 | // Add the created node into the cache of nodes. If it already exists return it. |
| 364 | SENode* ScalarEvolutionAnalysis::GetCachedOrAdd( |
| 365 | std::unique_ptr<SENode> prospective_node) { |
| 366 | auto itr = node_cache_.find(prospective_node); |
| 367 | if (itr != node_cache_.end()) { |
| 368 | return (*itr).get(); |
| 369 | } |
| 370 | |
| 371 | SENode* raw_ptr_to_node = prospective_node.get(); |
| 372 | node_cache_.insert(std::move(prospective_node)); |
| 373 | return raw_ptr_to_node; |
| 374 | } |
| 375 | |
| 376 | bool ScalarEvolutionAnalysis::IsLoopInvariant(const Loop* loop, |
| 377 | const SENode* node) const { |
| 378 | for (auto itr = node->graph_cbegin(); itr != node->graph_cend(); ++itr) { |
| 379 | if (const SERecurrentNode* rec = itr->AsSERecurrentNode()) { |
| 380 | const BasicBlock* = rec->GetLoop()->GetHeaderBlock(); |
| 381 | |
| 382 | // If the loop which the recurrent expression belongs to is either |loop |
| 383 | // or a nested loop inside |loop| then we assume it is variant. |
| 384 | if (loop->IsInsideLoop(header)) { |
| 385 | return false; |
| 386 | } |
| 387 | } else if (const SEValueUnknown* unknown = itr->AsSEValueUnknown()) { |
| 388 | // If the instruction is inside the loop we conservatively assume it is |
| 389 | // loop variant. |
| 390 | if (loop->IsInsideLoop(unknown->ResultId())) return false; |
| 391 | } |
| 392 | } |
| 393 | |
| 394 | return true; |
| 395 | } |
| 396 | |
| 397 | SENode* ScalarEvolutionAnalysis::GetCoefficientFromRecurrentTerm( |
| 398 | SENode* node, const Loop* loop) { |
| 399 | // Traverse the DAG to find the recurrent expression belonging to |loop|. |
| 400 | for (auto itr = node->graph_begin(); itr != node->graph_end(); ++itr) { |
| 401 | SERecurrentNode* rec = itr->AsSERecurrentNode(); |
| 402 | if (rec && rec->GetLoop() == loop) { |
| 403 | return rec->GetCoefficient(); |
| 404 | } |
| 405 | } |
| 406 | return CreateConstant(0); |
| 407 | } |
| 408 | |
| 409 | SENode* ScalarEvolutionAnalysis::UpdateChildNode(SENode* parent, |
| 410 | SENode* old_child, |
| 411 | SENode* new_child) { |
| 412 | // Only handles add. |
| 413 | if (parent->GetType() != SENode::Add) return parent; |
| 414 | |
| 415 | std::vector<SENode*> new_children; |
| 416 | for (SENode* child : *parent) { |
| 417 | if (child == old_child) { |
| 418 | new_children.push_back(new_child); |
| 419 | } else { |
| 420 | new_children.push_back(child); |
| 421 | } |
| 422 | } |
| 423 | |
| 424 | std::unique_ptr<SENode> add_node{new SEAddNode(this)}; |
| 425 | for (SENode* child : new_children) { |
| 426 | add_node->AddChild(child); |
| 427 | } |
| 428 | |
| 429 | return SimplifyExpression(GetCachedOrAdd(std::move(add_node))); |
| 430 | } |
| 431 | |
| 432 | // Rebuild the |node| eliminating, if it exists, the recurrent term which |
| 433 | // belongs to the |loop|. |
| 434 | SENode* ScalarEvolutionAnalysis::BuildGraphWithoutRecurrentTerm( |
| 435 | SENode* node, const Loop* loop) { |
| 436 | // If the node is already a recurrent expression belonging to loop then just |
| 437 | // return the offset. |
| 438 | SERecurrentNode* recurrent = node->AsSERecurrentNode(); |
| 439 | if (recurrent) { |
| 440 | if (recurrent->GetLoop() == loop) { |
| 441 | return recurrent->GetOffset(); |
| 442 | } else { |
| 443 | return node; |
| 444 | } |
| 445 | } |
| 446 | |
| 447 | std::vector<SENode*> new_children; |
| 448 | // Otherwise find the recurrent node in the children of this node. |
| 449 | for (auto itr : *node) { |
| 450 | recurrent = itr->AsSERecurrentNode(); |
| 451 | if (recurrent && recurrent->GetLoop() == loop) { |
| 452 | new_children.push_back(recurrent->GetOffset()); |
| 453 | } else { |
| 454 | new_children.push_back(itr); |
| 455 | } |
| 456 | } |
| 457 | |
| 458 | std::unique_ptr<SENode> add_node{new SEAddNode(this)}; |
| 459 | for (SENode* child : new_children) { |
| 460 | add_node->AddChild(child); |
| 461 | } |
| 462 | |
| 463 | return SimplifyExpression(GetCachedOrAdd(std::move(add_node))); |
| 464 | } |
| 465 | |
| 466 | // Return the recurrent term belonging to |loop| if it appears in the graph |
| 467 | // starting at |node| or null if it doesn't. |
| 468 | SERecurrentNode* ScalarEvolutionAnalysis::GetRecurrentTerm(SENode* node, |
| 469 | const Loop* loop) { |
| 470 | for (auto itr = node->graph_begin(); itr != node->graph_end(); ++itr) { |
| 471 | SERecurrentNode* rec = itr->AsSERecurrentNode(); |
| 472 | if (rec && rec->GetLoop() == loop) { |
| 473 | return rec; |
| 474 | } |
| 475 | } |
| 476 | return nullptr; |
| 477 | } |
| 478 | std::string SENode::AsString() const { |
| 479 | switch (GetType()) { |
| 480 | case Constant: |
| 481 | return "Constant" ; |
| 482 | case RecurrentAddExpr: |
| 483 | return "RecurrentAddExpr" ; |
| 484 | case Add: |
| 485 | return "Add" ; |
| 486 | case Negative: |
| 487 | return "Negative" ; |
| 488 | case Multiply: |
| 489 | return "Multiply" ; |
| 490 | case ValueUnknown: |
| 491 | return "Value Unknown" ; |
| 492 | case CanNotCompute: |
| 493 | return "Can not compute" ; |
| 494 | } |
| 495 | return "NULL" ; |
| 496 | } |
| 497 | |
| 498 | bool SENode::operator==(const SENode& other) const { |
| 499 | if (GetType() != other.GetType()) return false; |
| 500 | |
| 501 | if (other.GetChildren().size() != children_.size()) return false; |
| 502 | |
| 503 | const SERecurrentNode* this_as_recurrent = AsSERecurrentNode(); |
| 504 | |
| 505 | // Check the children are the same, for SERecurrentNodes we need to check the |
| 506 | // offset and coefficient manually as the child vector is sorted by ids so the |
| 507 | // offset/coefficient information is lost. |
| 508 | if (!this_as_recurrent) { |
| 509 | for (size_t index = 0; index < children_.size(); ++index) { |
| 510 | if (other.GetChildren()[index] != children_[index]) return false; |
| 511 | } |
| 512 | } else { |
| 513 | const SERecurrentNode* other_as_recurrent = other.AsSERecurrentNode(); |
| 514 | |
| 515 | // We've already checked the types are the same, this should not fail if |
| 516 | // this->AsSERecurrentNode() succeeded. |
| 517 | assert(other_as_recurrent); |
| 518 | |
| 519 | if (this_as_recurrent->GetCoefficient() != |
| 520 | other_as_recurrent->GetCoefficient()) |
| 521 | return false; |
| 522 | |
| 523 | if (this_as_recurrent->GetOffset() != other_as_recurrent->GetOffset()) |
| 524 | return false; |
| 525 | |
| 526 | if (this_as_recurrent->GetLoop() != other_as_recurrent->GetLoop()) |
| 527 | return false; |
| 528 | } |
| 529 | |
| 530 | // If we're dealing with a value unknown node check both nodes were created by |
| 531 | // the same instruction. |
| 532 | if (GetType() == SENode::ValueUnknown) { |
| 533 | if (AsSEValueUnknown()->ResultId() != |
| 534 | other.AsSEValueUnknown()->ResultId()) { |
| 535 | return false; |
| 536 | } |
| 537 | } |
| 538 | |
| 539 | if (AsSEConstantNode()) { |
| 540 | if (AsSEConstantNode()->FoldToSingleValue() != |
| 541 | other.AsSEConstantNode()->FoldToSingleValue()) |
| 542 | return false; |
| 543 | } |
| 544 | |
| 545 | return true; |
| 546 | } |
| 547 | |
| 548 | bool SENode::operator!=(const SENode& other) const { return !(*this == other); } |
| 549 | |
namespace {
// Helpers that append 32 or 64 bit values to the UTF-32 string used as the
// hash key. PushToString deduces the argument type and dispatches on
// sizeof(T) to the matching PushToStringImpl specialisation, which lets
// pointers be appended after they have been converted to uintptr_t.
// Only 4 and 8 byte types are supported; any other size fails to compile.

template <typename T, size_t size_of_t>
struct PushToStringImpl;

// 8 byte values are appended as two 32 bit words, most significant first.
template <typename T>
struct PushToStringImpl<T, 8> {
  void operator()(T value, std::u32string* out) {
    const uint32_t upper_word = static_cast<uint32_t>(value >> 32);
    const uint32_t lower_word = static_cast<uint32_t>(value);
    out->push_back(upper_word);
    out->push_back(lower_word);
  }
};

// 4 byte values are appended as a single 32 bit word.
template <typename T>
struct PushToStringImpl<T, 4> {
  void operator()(T value, std::u32string* out) {
    const uint32_t word = static_cast<uint32_t>(value);
    out->push_back(word);
  }
};

// Appends |value| to |out| using the implementation matching sizeof(T).
template <typename T>
void PushToString(T value, std::u32string* out) {
  PushToStringImpl<T, sizeof(T)> pusher;
  pusher(value, out);
}

}  // namespace
| 581 | |
| 582 | // Implements the hashing of SENodes. |
| 583 | size_t SENodeHash::operator()(const SENode* node) const { |
| 584 | // Concatinate the terms into a string which we can hash. |
| 585 | std::u32string hash_string{}; |
| 586 | |
| 587 | // Hashing the type as a string is safer than hashing the enum as the enum is |
| 588 | // very likely to collide with constants. |
| 589 | for (char ch : node->AsString()) { |
| 590 | hash_string.push_back(static_cast<char32_t>(ch)); |
| 591 | } |
| 592 | |
| 593 | // We just ignore the literal value unless it is a constant. |
| 594 | if (node->GetType() == SENode::Constant) |
| 595 | PushToString(node->AsSEConstantNode()->FoldToSingleValue(), &hash_string); |
| 596 | |
| 597 | const SERecurrentNode* recurrent = node->AsSERecurrentNode(); |
| 598 | |
| 599 | // If we're dealing with a recurrent expression hash the loop as well so that |
| 600 | // nested inductions like i=0,i++ and j=0,j++ correspond to different nodes. |
| 601 | if (recurrent) { |
| 602 | PushToString(reinterpret_cast<uintptr_t>(recurrent->GetLoop()), |
| 603 | &hash_string); |
| 604 | |
| 605 | // Recurrent expressions can't be hashed using the normal method as the |
| 606 | // order of coefficient and offset matters to the hash. |
| 607 | PushToString(reinterpret_cast<uintptr_t>(recurrent->GetCoefficient()), |
| 608 | &hash_string); |
| 609 | PushToString(reinterpret_cast<uintptr_t>(recurrent->GetOffset()), |
| 610 | &hash_string); |
| 611 | |
| 612 | return std::hash<std::u32string>{}(hash_string); |
| 613 | } |
| 614 | |
| 615 | // Hash the result id of the original instruction which created this node if |
| 616 | // it is a value unknown node. |
| 617 | if (node->GetType() == SENode::ValueUnknown) { |
| 618 | PushToString(node->AsSEValueUnknown()->ResultId(), &hash_string); |
| 619 | } |
| 620 | |
| 621 | // Hash the pointers of the child nodes, each SENode has a unique pointer |
| 622 | // associated with it. |
| 623 | const std::vector<SENode*>& children = node->GetChildren(); |
| 624 | for (const SENode* child : children) { |
| 625 | PushToString(reinterpret_cast<uintptr_t>(child), &hash_string); |
| 626 | } |
| 627 | |
| 628 | return std::hash<std::u32string>{}(hash_string); |
| 629 | } |
| 630 | |
| 631 | // This overload is the actual overload used by the node_cache_ set. |
| 632 | size_t SENodeHash::operator()(const std::unique_ptr<SENode>& node) const { |
| 633 | return this->operator()(node.get()); |
| 634 | } |
| 635 | |
| 636 | void SENode::DumpDot(std::ostream& out, bool recurse) const { |
| 637 | size_t unique_id = std::hash<const SENode*>{}(this); |
| 638 | out << unique_id << " [label=\"" << AsString() << " " ; |
| 639 | if (GetType() == SENode::Constant) { |
| 640 | out << "\nwith value: " << this->AsSEConstantNode()->FoldToSingleValue(); |
| 641 | } |
| 642 | out << "\"]\n" ; |
| 643 | for (const SENode* child : children_) { |
| 644 | size_t child_unique_id = std::hash<const SENode*>{}(child); |
| 645 | out << unique_id << " -> " << child_unique_id << " \n" ; |
| 646 | if (recurse) child->DumpDot(out, true); |
| 647 | } |
| 648 | } |
| 649 | |
| 650 | namespace { |
| 651 | class IsGreaterThanZero { |
| 652 | public: |
  // Stores |context|, which is used later to look up the defining instruction
  // and the type of value-unknown nodes.
  explicit IsGreaterThanZero(IRContext* context) : context_(context) {}
| 654 | |
| 655 | // Determine if the value of |node| is always strictly greater than zero if |
| 656 | // |or_equal_zero| is false or greater or equal to zero if |or_equal_zero| is |
| 657 | // true. It returns true is the evaluation was able to conclude something, in |
| 658 | // which case the result is stored in |result|. |
| 659 | // The algorithm work by going through all the nodes and determine the |
| 660 | // sign of each of them. |
| 661 | bool Eval(const SENode* node, bool or_equal_zero, bool* result) { |
| 662 | *result = false; |
| 663 | switch (Visit(node)) { |
| 664 | case Signedness::kPositiveOrNegative: { |
| 665 | return false; |
| 666 | } |
| 667 | case Signedness::kStrictlyNegative: { |
| 668 | *result = false; |
| 669 | break; |
| 670 | } |
| 671 | case Signedness::kNegative: { |
| 672 | if (!or_equal_zero) { |
| 673 | return false; |
| 674 | } |
| 675 | *result = false; |
| 676 | break; |
| 677 | } |
| 678 | case Signedness::kStrictlyPositive: { |
| 679 | *result = true; |
| 680 | break; |
| 681 | } |
| 682 | case Signedness::kPositive: { |
| 683 | if (!or_equal_zero) { |
| 684 | return false; |
| 685 | } |
| 686 | *result = true; |
| 687 | break; |
| 688 | } |
| 689 | } |
| 690 | return true; |
| 691 | } |
| 692 | |
| 693 | private: |
  // Sign information that the evaluation can establish for an expression.
  enum class Signedness {
    kPositiveOrNegative,  // Yield a value positive or negative.
    kStrictlyNegative,    // Yield a value strictly less than 0.
    kNegative,            // Yield a value less or equal to 0.
    kStrictlyPositive,    // Yield a value strictly greater than 0.
    kPositive             // Yield a value greater or equal to 0.
  };

  // Combines the signedness of two operands according to the arithmetic rules
  // of a given operator.
  using Combiner = std::function<Signedness(Signedness, Signedness)>;
| 704 | |
| 705 | // Returns a functor to interpret the signedness of 2 expressions as if they |
| 706 | // were added. |
| 707 | Combiner GetAddCombiner() const { |
| 708 | return [](Signedness lhs, Signedness rhs) { |
| 709 | switch (lhs) { |
| 710 | case Signedness::kPositiveOrNegative: |
| 711 | break; |
| 712 | case Signedness::kStrictlyNegative: |
| 713 | if (rhs == Signedness::kStrictlyNegative || |
| 714 | rhs == Signedness::kNegative) |
| 715 | return lhs; |
| 716 | break; |
| 717 | case Signedness::kNegative: { |
| 718 | if (rhs == Signedness::kStrictlyNegative) |
| 719 | return Signedness::kStrictlyNegative; |
| 720 | if (rhs == Signedness::kNegative) return Signedness::kNegative; |
| 721 | break; |
| 722 | } |
| 723 | case Signedness::kStrictlyPositive: { |
| 724 | if (rhs == Signedness::kStrictlyPositive || |
| 725 | rhs == Signedness::kPositive) { |
| 726 | return Signedness::kStrictlyPositive; |
| 727 | } |
| 728 | break; |
| 729 | } |
| 730 | case Signedness::kPositive: { |
| 731 | if (rhs == Signedness::kStrictlyPositive) |
| 732 | return Signedness::kStrictlyPositive; |
| 733 | if (rhs == Signedness::kPositive) return Signedness::kPositive; |
| 734 | break; |
| 735 | } |
| 736 | } |
| 737 | return Signedness::kPositiveOrNegative; |
| 738 | }; |
| 739 | } |
| 740 | |
| 741 | // Returns a functor to interpret the signedness of 2 expressions as if they |
| 742 | // were multiplied. |
| 743 | Combiner GetMulCombiner() const { |
| 744 | return [](Signedness lhs, Signedness rhs) { |
| 745 | switch (lhs) { |
| 746 | case Signedness::kPositiveOrNegative: |
| 747 | break; |
| 748 | case Signedness::kStrictlyNegative: { |
| 749 | switch (rhs) { |
| 750 | case Signedness::kPositiveOrNegative: { |
| 751 | break; |
| 752 | } |
| 753 | case Signedness::kStrictlyNegative: { |
| 754 | return Signedness::kStrictlyPositive; |
| 755 | } |
| 756 | case Signedness::kNegative: { |
| 757 | return Signedness::kPositive; |
| 758 | } |
| 759 | case Signedness::kStrictlyPositive: { |
| 760 | return Signedness::kStrictlyNegative; |
| 761 | } |
| 762 | case Signedness::kPositive: { |
| 763 | return Signedness::kNegative; |
| 764 | } |
| 765 | } |
| 766 | break; |
| 767 | } |
| 768 | case Signedness::kNegative: { |
| 769 | switch (rhs) { |
| 770 | case Signedness::kPositiveOrNegative: { |
| 771 | break; |
| 772 | } |
| 773 | case Signedness::kStrictlyNegative: |
| 774 | case Signedness::kNegative: { |
| 775 | return Signedness::kPositive; |
| 776 | } |
| 777 | case Signedness::kStrictlyPositive: |
| 778 | case Signedness::kPositive: { |
| 779 | return Signedness::kNegative; |
| 780 | } |
| 781 | } |
| 782 | break; |
| 783 | } |
| 784 | case Signedness::kStrictlyPositive: { |
| 785 | return rhs; |
| 786 | } |
| 787 | case Signedness::kPositive: { |
| 788 | switch (rhs) { |
| 789 | case Signedness::kPositiveOrNegative: { |
| 790 | break; |
| 791 | } |
| 792 | case Signedness::kStrictlyNegative: |
| 793 | case Signedness::kNegative: { |
| 794 | return Signedness::kNegative; |
| 795 | } |
| 796 | case Signedness::kStrictlyPositive: |
| 797 | case Signedness::kPositive: { |
| 798 | return Signedness::kPositive; |
| 799 | } |
| 800 | } |
| 801 | break; |
| 802 | } |
| 803 | } |
| 804 | return Signedness::kPositiveOrNegative; |
| 805 | }; |
| 806 | } |
| 807 | |
| 808 | Signedness Visit(const SENode* node) { |
| 809 | switch (node->GetType()) { |
| 810 | case SENode::Constant: |
| 811 | return Visit(node->AsSEConstantNode()); |
| 812 | break; |
| 813 | case SENode::RecurrentAddExpr: |
| 814 | return Visit(node->AsSERecurrentNode()); |
| 815 | break; |
| 816 | case SENode::Negative: |
| 817 | return Visit(node->AsSENegative()); |
| 818 | break; |
| 819 | case SENode::CanNotCompute: |
| 820 | return Visit(node->AsSECantCompute()); |
| 821 | break; |
| 822 | case SENode::ValueUnknown: |
| 823 | return Visit(node->AsSEValueUnknown()); |
| 824 | break; |
| 825 | case SENode::Add: |
| 826 | return VisitExpr(node, GetAddCombiner()); |
| 827 | break; |
| 828 | case SENode::Multiply: |
| 829 | return VisitExpr(node, GetMulCombiner()); |
| 830 | break; |
| 831 | } |
| 832 | return Signedness::kPositiveOrNegative; |
| 833 | } |
| 834 | |
| 835 | // Returns the signedness of a constant |node|. |
| 836 | Signedness Visit(const SEConstantNode* node) { |
| 837 | if (0 == node->FoldToSingleValue()) return Signedness::kPositive; |
| 838 | if (0 < node->FoldToSingleValue()) return Signedness::kStrictlyPositive; |
| 839 | if (0 > node->FoldToSingleValue()) return Signedness::kStrictlyNegative; |
| 840 | return Signedness::kPositiveOrNegative; |
| 841 | } |
| 842 | |
| 843 | // Returns the signedness of an unknown |node| based on its type. |
| 844 | Signedness Visit(const SEValueUnknown* node) { |
| 845 | Instruction* insn = context_->get_def_use_mgr()->GetDef(node->ResultId()); |
| 846 | analysis::Type* type = context_->get_type_mgr()->GetType(insn->type_id()); |
| 847 | assert(type && "Can't retrieve a type for the instruction" ); |
| 848 | analysis::Integer* int_type = type->AsInteger(); |
| 849 | assert(type && "Can't retrieve an integer type for the instruction" ); |
| 850 | return int_type->IsSigned() ? Signedness::kPositiveOrNegative |
| 851 | : Signedness::kPositive; |
| 852 | } |
| 853 | |
| 854 | // Returns the signedness of a recurring expression. |
| 855 | Signedness Visit(const SERecurrentNode* node) { |
| 856 | Signedness coeff_sign = Visit(node->GetCoefficient()); |
| 857 | // SERecurrentNode represent an affine expression in the range [0, |
| 858 | // loop_bound], so the result cannot be strictly positive or negative. |
| 859 | switch (coeff_sign) { |
| 860 | default: |
| 861 | break; |
| 862 | case Signedness::kStrictlyNegative: |
| 863 | coeff_sign = Signedness::kNegative; |
| 864 | break; |
| 865 | case Signedness::kStrictlyPositive: |
| 866 | coeff_sign = Signedness::kPositive; |
| 867 | break; |
| 868 | } |
| 869 | return GetAddCombiner()(coeff_sign, Visit(node->GetOffset())); |
| 870 | } |
| 871 | |
| 872 | // Returns the signedness of a negation |node|. |
| 873 | Signedness Visit(const SENegative* node) { |
| 874 | switch (Visit(*node->begin())) { |
| 875 | case Signedness::kPositiveOrNegative: { |
| 876 | return Signedness::kPositiveOrNegative; |
| 877 | } |
| 878 | case Signedness::kStrictlyNegative: { |
| 879 | return Signedness::kStrictlyPositive; |
| 880 | } |
| 881 | case Signedness::kNegative: { |
| 882 | return Signedness::kPositive; |
| 883 | } |
| 884 | case Signedness::kStrictlyPositive: { |
| 885 | return Signedness::kStrictlyNegative; |
| 886 | } |
| 887 | case Signedness::kPositive: { |
| 888 | return Signedness::kNegative; |
| 889 | } |
| 890 | } |
| 891 | return Signedness::kPositiveOrNegative; |
| 892 | } |
| 893 | |
| 894 | Signedness Visit(const SECantCompute*) { |
| 895 | return Signedness::kPositiveOrNegative; |
| 896 | } |
| 897 | |
| 898 | // Returns the signedness of a binary expression by using the combiner |
| 899 | // |reduce|. |
| 900 | Signedness VisitExpr( |
| 901 | const SENode* node, |
| 902 | std::function<Signedness(Signedness, Signedness)> reduce) { |
| 903 | Signedness result = Visit(*node->begin()); |
| 904 | for (const SENode* operand : make_range(++node->begin(), node->end())) { |
| 905 | if (result == Signedness::kPositiveOrNegative) { |
| 906 | return Signedness::kPositiveOrNegative; |
| 907 | } |
| 908 | result = reduce(result, Visit(operand)); |
| 909 | } |
| 910 | return result; |
| 911 | } |
| 912 | |
| 913 | IRContext* context_; |
| 914 | }; |
| 915 | } // namespace |
| 916 | |
| 917 | bool ScalarEvolutionAnalysis::IsAlwaysGreaterThanZero(SENode* node, |
| 918 | bool* is_gt_zero) const { |
| 919 | return IsGreaterThanZero(context_).Eval(node, false, is_gt_zero); |
| 920 | } |
| 921 | |
| 922 | bool ScalarEvolutionAnalysis::IsAlwaysGreaterOrEqualToZero( |
| 923 | SENode* node, bool* is_ge_zero) const { |
| 924 | return IsGreaterThanZero(context_).Eval(node, true, is_ge_zero); |
| 925 | } |
| 926 | |
| 927 | namespace { |
| 928 | |
| 929 | // Remove |node| from the |mul| chain (of the form A * ... * |node| * ... * Z), |
| 930 | // if |node| is not in the chain, returns the original chain. |
| 931 | static SENode* RemoveOneNodeFromMultiplyChain(SEMultiplyNode* mul, |
| 932 | const SENode* node) { |
| 933 | SENode* lhs = mul->GetChildren()[0]; |
| 934 | SENode* rhs = mul->GetChildren()[1]; |
| 935 | if (lhs == node) { |
| 936 | return rhs; |
| 937 | } |
| 938 | if (rhs == node) { |
| 939 | return lhs; |
| 940 | } |
| 941 | if (lhs->AsSEMultiplyNode()) { |
| 942 | SENode* res = RemoveOneNodeFromMultiplyChain(lhs->AsSEMultiplyNode(), node); |
| 943 | if (res != lhs) |
| 944 | return mul->GetParentAnalysis()->CreateMultiplyNode(res, rhs); |
| 945 | } |
| 946 | if (rhs->AsSEMultiplyNode()) { |
| 947 | SENode* res = RemoveOneNodeFromMultiplyChain(rhs->AsSEMultiplyNode(), node); |
| 948 | if (res != rhs) |
| 949 | return mul->GetParentAnalysis()->CreateMultiplyNode(res, rhs); |
| 950 | } |
| 951 | |
| 952 | return mul; |
| 953 | } |
| 954 | } // namespace |
| 955 | |
| 956 | std::pair<SExpression, int64_t> SExpression::operator/( |
| 957 | SExpression rhs_wrapper) const { |
| 958 | SENode* lhs = node_; |
| 959 | SENode* rhs = rhs_wrapper.node_; |
| 960 | // Check for division by 0. |
| 961 | if (rhs->AsSEConstantNode() && |
| 962 | !rhs->AsSEConstantNode()->FoldToSingleValue()) { |
| 963 | return {scev_->CreateCantComputeNode(), 0}; |
| 964 | } |
| 965 | |
| 966 | // Trivial case. |
| 967 | if (lhs->AsSEConstantNode() && rhs->AsSEConstantNode()) { |
| 968 | int64_t lhs_value = lhs->AsSEConstantNode()->FoldToSingleValue(); |
| 969 | int64_t rhs_value = rhs->AsSEConstantNode()->FoldToSingleValue(); |
| 970 | return {scev_->CreateConstant(lhs_value / rhs_value), |
| 971 | lhs_value % rhs_value}; |
| 972 | } |
| 973 | |
| 974 | // look for a "c U / U" pattern. |
| 975 | if (lhs->AsSEMultiplyNode()) { |
| 976 | assert(lhs->GetChildren().size() == 2 && |
| 977 | "More than 2 operand for a multiply node." ); |
| 978 | SENode* res = RemoveOneNodeFromMultiplyChain(lhs->AsSEMultiplyNode(), rhs); |
| 979 | if (res != lhs) { |
| 980 | return {res, 0}; |
| 981 | } |
| 982 | } |
| 983 | |
| 984 | return {scev_->CreateCantComputeNode(), 0}; |
| 985 | } |
| 986 | |
| 987 | } // namespace opt |
| 988 | } // namespace spvtools |
| 989 | |