| 1 | // Copyright (c) 2018 Google LLC. | 
|---|
| 2 | // | 
|---|
| 3 | // Licensed under the Apache License, Version 2.0 (the "License"); | 
|---|
| 4 | // you may not use this file except in compliance with the License. | 
|---|
| 5 | // You may obtain a copy of the License at | 
|---|
| 6 | // | 
|---|
| 7 | //     http://www.apache.org/licenses/LICENSE-2.0 | 
|---|
| 8 | // | 
|---|
| 9 | // Unless required by applicable law or agreed to in writing, software | 
|---|
| 10 | // distributed under the License is distributed on an "AS IS" BASIS, | 
|---|
| 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
|---|
| 12 | // See the License for the specific language governing permissions and | 
|---|
| 13 | // limitations under the License. | 
|---|
| 14 |  | 
|---|
| 15 | #include "source/opt/scalar_analysis.h" | 
|---|
| 16 |  | 
|---|
| 17 | #include <algorithm> | 
|---|
| 18 | #include <functional> | 
|---|
| 19 | #include <string> | 
|---|
| 20 | #include <utility> | 
|---|
| 21 |  | 
|---|
| 22 | #include "source/opt/ir_context.h" | 
|---|
| 23 |  | 
|---|
| 24 | // Transforms a given scalar operation instruction into a DAG representation. | 
|---|
| 25 | // | 
|---|
| 26 | // 1. Take an instruction and traverse its operands until we reach a | 
|---|
| 27 | // constant node or an instruction which we do not know how to compute the | 
|---|
| 28 | // value, such as a load. | 
|---|
| 29 | // | 
|---|
| 30 | // 2. Create a new node for each instruction traversed and build the nodes for | 
|---|
| 31 | // the in operands of that instruction as well. | 
|---|
| 32 | // | 
|---|
| 33 | // 3. Add the operand nodes as children of the first and hash the node. Use the | 
|---|
| 34 | // hash to see if the node is already in the cache. We ensure the children are | 
|---|
| 35 | // always in sorted order so that two nodes with the same children but inserted | 
|---|
| 36 | // in a different order have the same hash and so that the overloaded operator== | 
|---|
| 37 | // will return true. If the node is already in the cache return the cached | 
|---|
| 38 | // version instead. | 
|---|
| 39 | // | 
|---|
| 40 | // 4. The created DAG can then be simplified by | 
|---|
| 41 | // ScalarAnalysis::SimplifyExpression, implemented in | 
|---|
| 42 | // scalar_analysis_simplification.cpp. See that file for further information on | 
|---|
| 43 | // the simplification process. | 
|---|
| 44 | // | 
|---|
| 45 |  | 
|---|
| 46 | namespace spvtools { | 
|---|
| 47 | namespace opt { | 
|---|
| 48 |  | 
|---|
| 49 | uint32_t SENode::NumberOfNodes = 0; | 
|---|
| 50 |  | 
|---|
| 51 | ScalarEvolutionAnalysis::ScalarEvolutionAnalysis(IRContext* context) | 
|---|
| 52 | : context_(context), pretend_equal_{} { | 
|---|
| 53 | // Create and cached the CantComputeNode. | 
|---|
| 54 | cached_cant_compute_ = | 
|---|
| 55 | GetCachedOrAdd(std::unique_ptr<SECantCompute>(new SECantCompute(this))); | 
|---|
| 56 | } | 
|---|
| 57 |  | 
|---|
| 58 | SENode* ScalarEvolutionAnalysis::CreateNegation(SENode* operand) { | 
|---|
| 59 | // If operand is can't compute then the whole graph is can't compute. | 
|---|
| 60 | if (operand->IsCantCompute()) return CreateCantComputeNode(); | 
|---|
| 61 |  | 
|---|
| 62 | if (operand->GetType() == SENode::Constant) { | 
|---|
| 63 | return CreateConstant(-operand->AsSEConstantNode()->FoldToSingleValue()); | 
|---|
| 64 | } | 
|---|
| 65 | std::unique_ptr<SENode> negation_node{new SENegative(this)}; | 
|---|
| 66 | negation_node->AddChild(operand); | 
|---|
| 67 | return GetCachedOrAdd(std::move(negation_node)); | 
|---|
| 68 | } | 
|---|
| 69 |  | 
|---|
| 70 | SENode* ScalarEvolutionAnalysis::CreateConstant(int64_t integer) { | 
|---|
| 71 | return GetCachedOrAdd( | 
|---|
| 72 | std::unique_ptr<SENode>(new SEConstantNode(this, integer))); | 
|---|
| 73 | } | 
|---|
| 74 |  | 
|---|
| 75 | SENode* ScalarEvolutionAnalysis::CreateRecurrentExpression( | 
|---|
| 76 | const Loop* loop, SENode* offset, SENode* coefficient) { | 
|---|
| 77 | assert(loop && "Recurrent add expressions must have a valid loop."); | 
|---|
| 78 |  | 
|---|
| 79 | // If operands are can't compute then the whole graph is can't compute. | 
|---|
| 80 | if (offset->IsCantCompute() || coefficient->IsCantCompute()) | 
|---|
| 81 | return CreateCantComputeNode(); | 
|---|
| 82 |  | 
|---|
| 83 | const Loop* loop_to_use = nullptr; | 
|---|
| 84 | if (pretend_equal_[loop]) { | 
|---|
| 85 | loop_to_use = pretend_equal_[loop]; | 
|---|
| 86 | } else { | 
|---|
| 87 | loop_to_use = loop; | 
|---|
| 88 | } | 
|---|
| 89 |  | 
|---|
| 90 | std::unique_ptr<SERecurrentNode> phi_node{ | 
|---|
| 91 | new SERecurrentNode(this, loop_to_use)}; | 
|---|
| 92 | phi_node->AddOffset(offset); | 
|---|
| 93 | phi_node->AddCoefficient(coefficient); | 
|---|
| 94 |  | 
|---|
| 95 | return GetCachedOrAdd(std::move(phi_node)); | 
|---|
| 96 | } | 
|---|
| 97 |  | 
|---|
| 98 | SENode* ScalarEvolutionAnalysis::AnalyzeMultiplyOp( | 
|---|
| 99 | const Instruction* multiply) { | 
|---|
| 100 | assert(multiply->opcode() == SpvOp::SpvOpIMul && | 
|---|
| 101 | "Multiply node did not come from a multiply instruction"); | 
|---|
| 102 | analysis::DefUseManager* def_use = context_->get_def_use_mgr(); | 
|---|
| 103 |  | 
|---|
| 104 | SENode* op1 = | 
|---|
| 105 | AnalyzeInstruction(def_use->GetDef(multiply->GetSingleWordInOperand(0))); | 
|---|
| 106 | SENode* op2 = | 
|---|
| 107 | AnalyzeInstruction(def_use->GetDef(multiply->GetSingleWordInOperand(1))); | 
|---|
| 108 |  | 
|---|
| 109 | return CreateMultiplyNode(op1, op2); | 
|---|
| 110 | } | 
|---|
| 111 |  | 
|---|
| 112 | SENode* ScalarEvolutionAnalysis::CreateMultiplyNode(SENode* operand_1, | 
|---|
| 113 | SENode* operand_2) { | 
|---|
| 114 | // If operands are can't compute then the whole graph is can't compute. | 
|---|
| 115 | if (operand_1->IsCantCompute() || operand_2->IsCantCompute()) | 
|---|
| 116 | return CreateCantComputeNode(); | 
|---|
| 117 |  | 
|---|
| 118 | if (operand_1->GetType() == SENode::Constant && | 
|---|
| 119 | operand_2->GetType() == SENode::Constant) { | 
|---|
| 120 | return CreateConstant(operand_1->AsSEConstantNode()->FoldToSingleValue() * | 
|---|
| 121 | operand_2->AsSEConstantNode()->FoldToSingleValue()); | 
|---|
| 122 | } | 
|---|
| 123 |  | 
|---|
| 124 | std::unique_ptr<SENode> multiply_node{new SEMultiplyNode(this)}; | 
|---|
| 125 |  | 
|---|
| 126 | multiply_node->AddChild(operand_1); | 
|---|
| 127 | multiply_node->AddChild(operand_2); | 
|---|
| 128 |  | 
|---|
| 129 | return GetCachedOrAdd(std::move(multiply_node)); | 
|---|
| 130 | } | 
|---|
| 131 |  | 
|---|
| 132 | SENode* ScalarEvolutionAnalysis::CreateSubtraction(SENode* operand_1, | 
|---|
| 133 | SENode* operand_2) { | 
|---|
| 134 | // Fold if both operands are constant. | 
|---|
| 135 | if (operand_1->GetType() == SENode::Constant && | 
|---|
| 136 | operand_2->GetType() == SENode::Constant) { | 
|---|
| 137 | return CreateConstant(operand_1->AsSEConstantNode()->FoldToSingleValue() - | 
|---|
| 138 | operand_2->AsSEConstantNode()->FoldToSingleValue()); | 
|---|
| 139 | } | 
|---|
| 140 |  | 
|---|
| 141 | return CreateAddNode(operand_1, CreateNegation(operand_2)); | 
|---|
| 142 | } | 
|---|
| 143 |  | 
|---|
| 144 | SENode* ScalarEvolutionAnalysis::CreateAddNode(SENode* operand_1, | 
|---|
| 145 | SENode* operand_2) { | 
|---|
| 146 | // Fold if both operands are constant and the |simplify| flag is true. | 
|---|
| 147 | if (operand_1->GetType() == SENode::Constant && | 
|---|
| 148 | operand_2->GetType() == SENode::Constant) { | 
|---|
| 149 | return CreateConstant(operand_1->AsSEConstantNode()->FoldToSingleValue() + | 
|---|
| 150 | operand_2->AsSEConstantNode()->FoldToSingleValue()); | 
|---|
| 151 | } | 
|---|
| 152 |  | 
|---|
| 153 | // If operands are can't compute then the whole graph is can't compute. | 
|---|
| 154 | if (operand_1->IsCantCompute() || operand_2->IsCantCompute()) | 
|---|
| 155 | return CreateCantComputeNode(); | 
|---|
| 156 |  | 
|---|
| 157 | std::unique_ptr<SENode> add_node{new SEAddNode(this)}; | 
|---|
| 158 |  | 
|---|
| 159 | add_node->AddChild(operand_1); | 
|---|
| 160 | add_node->AddChild(operand_2); | 
|---|
| 161 |  | 
|---|
| 162 | return GetCachedOrAdd(std::move(add_node)); | 
|---|
| 163 | } | 
|---|
| 164 |  | 
|---|
| 165 | SENode* ScalarEvolutionAnalysis::AnalyzeInstruction(const Instruction* inst) { | 
|---|
| 166 | auto itr = recurrent_node_map_.find(inst); | 
|---|
| 167 | if (itr != recurrent_node_map_.end()) return itr->second; | 
|---|
| 168 |  | 
|---|
| 169 | SENode* output = nullptr; | 
|---|
| 170 | switch (inst->opcode()) { | 
|---|
| 171 | case SpvOp::SpvOpPhi: { | 
|---|
| 172 | output = AnalyzePhiInstruction(inst); | 
|---|
| 173 | break; | 
|---|
| 174 | } | 
|---|
| 175 | case SpvOp::SpvOpConstant: | 
|---|
| 176 | case SpvOp::SpvOpConstantNull: { | 
|---|
| 177 | output = AnalyzeConstant(inst); | 
|---|
| 178 | break; | 
|---|
| 179 | } | 
|---|
| 180 | case SpvOp::SpvOpISub: | 
|---|
| 181 | case SpvOp::SpvOpIAdd: { | 
|---|
| 182 | output = AnalyzeAddOp(inst); | 
|---|
| 183 | break; | 
|---|
| 184 | } | 
|---|
| 185 | case SpvOp::SpvOpIMul: { | 
|---|
| 186 | output = AnalyzeMultiplyOp(inst); | 
|---|
| 187 | break; | 
|---|
| 188 | } | 
|---|
| 189 | default: { | 
|---|
| 190 | output = CreateValueUnknownNode(inst); | 
|---|
| 191 | break; | 
|---|
| 192 | } | 
|---|
| 193 | } | 
|---|
| 194 |  | 
|---|
| 195 | return output; | 
|---|
| 196 | } | 
|---|
| 197 |  | 
|---|
| 198 | SENode* ScalarEvolutionAnalysis::AnalyzeConstant(const Instruction* inst) { | 
|---|
| 199 | if (inst->opcode() == SpvOp::SpvOpConstantNull) return CreateConstant(0); | 
|---|
| 200 |  | 
|---|
| 201 | assert(inst->opcode() == SpvOp::SpvOpConstant); | 
|---|
| 202 | assert(inst->NumInOperands() == 1); | 
|---|
| 203 | int64_t value = 0; | 
|---|
| 204 |  | 
|---|
| 205 | // Look up the instruction in the constant manager. | 
|---|
| 206 | const analysis::Constant* constant = | 
|---|
| 207 | context_->get_constant_mgr()->FindDeclaredConstant(inst->result_id()); | 
|---|
| 208 |  | 
|---|
| 209 | if (!constant) return CreateCantComputeNode(); | 
|---|
| 210 |  | 
|---|
| 211 | const analysis::IntConstant* int_constant = constant->AsIntConstant(); | 
|---|
| 212 |  | 
|---|
| 213 | // Exit out if it is a 64 bit integer. | 
|---|
| 214 | if (!int_constant || int_constant->words().size() != 1) | 
|---|
| 215 | return CreateCantComputeNode(); | 
|---|
| 216 |  | 
|---|
| 217 | if (int_constant->type()->AsInteger()->IsSigned()) { | 
|---|
| 218 | value = int_constant->GetS32BitValue(); | 
|---|
| 219 | } else { | 
|---|
| 220 | value = int_constant->GetU32BitValue(); | 
|---|
| 221 | } | 
|---|
| 222 |  | 
|---|
| 223 | return CreateConstant(value); | 
|---|
| 224 | } | 
|---|
| 225 |  | 
|---|
| 226 | // Handles both addition and subtraction. If the |sub| flag is set then the | 
|---|
| 227 | // addition will be op1+(-op2) otherwise op1+op2. | 
|---|
| 228 | SENode* ScalarEvolutionAnalysis::AnalyzeAddOp(const Instruction* inst) { | 
|---|
| 229 | assert((inst->opcode() == SpvOp::SpvOpIAdd || | 
|---|
| 230 | inst->opcode() == SpvOp::SpvOpISub) && | 
|---|
| 231 | "Add node must be created from a OpIAdd or OpISub instruction"); | 
|---|
| 232 |  | 
|---|
| 233 | analysis::DefUseManager* def_use = context_->get_def_use_mgr(); | 
|---|
| 234 |  | 
|---|
| 235 | SENode* op1 = | 
|---|
| 236 | AnalyzeInstruction(def_use->GetDef(inst->GetSingleWordInOperand(0))); | 
|---|
| 237 |  | 
|---|
| 238 | SENode* op2 = | 
|---|
| 239 | AnalyzeInstruction(def_use->GetDef(inst->GetSingleWordInOperand(1))); | 
|---|
| 240 |  | 
|---|
| 241 | // To handle subtraction we wrap the second operand in a unary negation node. | 
|---|
| 242 | if (inst->opcode() == SpvOp::SpvOpISub) { | 
|---|
| 243 | op2 = CreateNegation(op2); | 
|---|
| 244 | } | 
|---|
| 245 |  | 
|---|
| 246 | return CreateAddNode(op1, op2); | 
|---|
| 247 | } | 
|---|
| 248 |  | 
|---|
| 249 | SENode* ScalarEvolutionAnalysis::AnalyzePhiInstruction(const Instruction* phi) { | 
|---|
| 250 | // The phi should only have two incoming value pairs. | 
|---|
| 251 | if (phi->NumInOperands() != 4) { | 
|---|
| 252 | return CreateCantComputeNode(); | 
|---|
| 253 | } | 
|---|
| 254 |  | 
|---|
| 255 | analysis::DefUseManager* def_use = context_->get_def_use_mgr(); | 
|---|
| 256 |  | 
|---|
| 257 | // Get the basic block this instruction belongs to. | 
|---|
| 258 | BasicBlock* basic_block = | 
|---|
| 259 | context_->get_instr_block(const_cast<Instruction*>(phi)); | 
|---|
| 260 |  | 
|---|
| 261 | // And then the function that the basic blocks belongs to. | 
|---|
| 262 | Function* function = basic_block->GetParent(); | 
|---|
| 263 |  | 
|---|
| 264 | // Use the function to get the loop descriptor. | 
|---|
| 265 | LoopDescriptor* loop_descriptor = context_->GetLoopDescriptor(function); | 
|---|
| 266 |  | 
|---|
| 267 | // We only handle phis in loops at the moment. | 
|---|
| 268 | if (!loop_descriptor) return CreateCantComputeNode(); | 
|---|
| 269 |  | 
|---|
| 270 | // Get the innermost loop which this block belongs to. | 
|---|
| 271 | Loop* loop = (*loop_descriptor)[basic_block->id()]; | 
|---|
| 272 |  | 
|---|
| 273 | // If the loop doesn't exist or doesn't have a preheader or latch block, exit | 
|---|
| 274 | // out. | 
|---|
| 275 | if (!loop || !loop->GetLatchBlock() || !loop->GetPreHeaderBlock() || | 
|---|
| 276 | loop->GetHeaderBlock() != basic_block) | 
|---|
| 277 | return recurrent_node_map_[phi] = CreateCantComputeNode(); | 
|---|
| 278 |  | 
|---|
| 279 | const Loop* loop_to_use = nullptr; | 
|---|
| 280 | if (pretend_equal_[loop]) { | 
|---|
| 281 | loop_to_use = pretend_equal_[loop]; | 
|---|
| 282 | } else { | 
|---|
| 283 | loop_to_use = loop; | 
|---|
| 284 | } | 
|---|
| 285 | std::unique_ptr<SERecurrentNode> phi_node{ | 
|---|
| 286 | new SERecurrentNode(this, loop_to_use)}; | 
|---|
| 287 |  | 
|---|
| 288 | // We add the node to this map to allow it to be returned before the node is | 
|---|
| 289 | // fully built. This is needed as the subsequent call to AnalyzeInstruction | 
|---|
| 290 | // could lead back to this |phi| instruction so we return the pointer | 
|---|
| 291 | // immediately in AnalyzeInstruction to break the recursion. | 
|---|
| 292 | recurrent_node_map_[phi] = phi_node.get(); | 
|---|
| 293 |  | 
|---|
| 294 | // Traverse the operands of the instruction an create new nodes for each one. | 
|---|
| 295 | for (uint32_t i = 0; i < phi->NumInOperands(); i += 2) { | 
|---|
| 296 | uint32_t value_id = phi->GetSingleWordInOperand(i); | 
|---|
| 297 | uint32_t incoming_label_id = phi->GetSingleWordInOperand(i + 1); | 
|---|
| 298 |  | 
|---|
| 299 | Instruction* value_inst = def_use->GetDef(value_id); | 
|---|
| 300 | SENode* value_node = AnalyzeInstruction(value_inst); | 
|---|
| 301 |  | 
|---|
| 302 | // If any operand is CantCompute then the whole graph is CantCompute. | 
|---|
| 303 | if (value_node->IsCantCompute()) | 
|---|
| 304 | return recurrent_node_map_[phi] = CreateCantComputeNode(); | 
|---|
| 305 |  | 
|---|
| 306 | // If the value is coming from the preheader block then the value is the | 
|---|
| 307 | // initial value of the phi. | 
|---|
| 308 | if (incoming_label_id == loop->GetPreHeaderBlock()->id()) { | 
|---|
| 309 | phi_node->AddOffset(value_node); | 
|---|
| 310 | } else if (incoming_label_id == loop->GetLatchBlock()->id()) { | 
|---|
| 311 | // Assumed to be in the form of step + phi. | 
|---|
| 312 | if (value_node->GetType() != SENode::Add) | 
|---|
| 313 | return recurrent_node_map_[phi] = CreateCantComputeNode(); | 
|---|
| 314 |  | 
|---|
| 315 | SENode* step_node = nullptr; | 
|---|
| 316 | SENode* phi_operand = nullptr; | 
|---|
| 317 | SENode* operand_1 = value_node->GetChild(0); | 
|---|
| 318 | SENode* operand_2 = value_node->GetChild(1); | 
|---|
| 319 |  | 
|---|
| 320 | // Find which node is the step term. | 
|---|
| 321 | if (!operand_1->AsSERecurrentNode()) | 
|---|
| 322 | step_node = operand_1; | 
|---|
| 323 | else if (!operand_2->AsSERecurrentNode()) | 
|---|
| 324 | step_node = operand_2; | 
|---|
| 325 |  | 
|---|
| 326 | // Find which node is the recurrent expression. | 
|---|
| 327 | if (operand_1->AsSERecurrentNode()) | 
|---|
| 328 | phi_operand = operand_1; | 
|---|
| 329 | else if (operand_2->AsSERecurrentNode()) | 
|---|
| 330 | phi_operand = operand_2; | 
|---|
| 331 |  | 
|---|
| 332 | // If it is not in the form step + phi exit out. | 
|---|
| 333 | if (!(step_node && phi_operand)) | 
|---|
| 334 | return recurrent_node_map_[phi] = CreateCantComputeNode(); | 
|---|
| 335 |  | 
|---|
| 336 | // If the phi operand is not the same phi node exit out. | 
|---|
| 337 | if (phi_operand != phi_node.get()) | 
|---|
| 338 | return recurrent_node_map_[phi] = CreateCantComputeNode(); | 
|---|
| 339 |  | 
|---|
| 340 | if (!IsLoopInvariant(loop, step_node)) | 
|---|
| 341 | return recurrent_node_map_[phi] = CreateCantComputeNode(); | 
|---|
| 342 |  | 
|---|
| 343 | phi_node->AddCoefficient(step_node); | 
|---|
| 344 | } | 
|---|
| 345 | } | 
|---|
| 346 |  | 
|---|
| 347 | // Once the node is fully built we update the map with the version from the | 
|---|
| 348 | // cache (if it has already been added to the cache). | 
|---|
| 349 | return recurrent_node_map_[phi] = GetCachedOrAdd(std::move(phi_node)); | 
|---|
| 350 | } | 
|---|
| 351 |  | 
|---|
| 352 | SENode* ScalarEvolutionAnalysis::CreateValueUnknownNode( | 
|---|
| 353 | const Instruction* inst) { | 
|---|
| 354 | std::unique_ptr<SEValueUnknown> load_node{ | 
|---|
| 355 | new SEValueUnknown(this, inst->result_id())}; | 
|---|
| 356 | return GetCachedOrAdd(std::move(load_node)); | 
|---|
| 357 | } | 
|---|
| 358 |  | 
|---|
| 359 | SENode* ScalarEvolutionAnalysis::CreateCantComputeNode() { | 
|---|
| 360 | return cached_cant_compute_; | 
|---|
| 361 | } | 
|---|
| 362 |  | 
|---|
| 363 | // Add the created node into the cache of nodes. If it already exists return it. | 
|---|
| 364 | SENode* ScalarEvolutionAnalysis::GetCachedOrAdd( | 
|---|
| 365 | std::unique_ptr<SENode> prospective_node) { | 
|---|
| 366 | auto itr = node_cache_.find(prospective_node); | 
|---|
| 367 | if (itr != node_cache_.end()) { | 
|---|
| 368 | return (*itr).get(); | 
|---|
| 369 | } | 
|---|
| 370 |  | 
|---|
| 371 | SENode* raw_ptr_to_node = prospective_node.get(); | 
|---|
| 372 | node_cache_.insert(std::move(prospective_node)); | 
|---|
| 373 | return raw_ptr_to_node; | 
|---|
| 374 | } | 
|---|
| 375 |  | 
|---|
| 376 | bool ScalarEvolutionAnalysis::IsLoopInvariant(const Loop* loop, | 
|---|
| 377 | const SENode* node) const { | 
|---|
| 378 | for (auto itr = node->graph_cbegin(); itr != node->graph_cend(); ++itr) { | 
|---|
| 379 | if (const SERecurrentNode* rec = itr->AsSERecurrentNode()) { | 
|---|
| 380 | const BasicBlock*  = rec->GetLoop()->GetHeaderBlock(); | 
|---|
| 381 |  | 
|---|
| 382 | // If the loop which the recurrent expression belongs to is either |loop | 
|---|
| 383 | // or a nested loop inside |loop| then we assume it is variant. | 
|---|
| 384 | if (loop->IsInsideLoop(header)) { | 
|---|
| 385 | return false; | 
|---|
| 386 | } | 
|---|
| 387 | } else if (const SEValueUnknown* unknown = itr->AsSEValueUnknown()) { | 
|---|
| 388 | // If the instruction is inside the loop we conservatively assume it is | 
|---|
| 389 | // loop variant. | 
|---|
| 390 | if (loop->IsInsideLoop(unknown->ResultId())) return false; | 
|---|
| 391 | } | 
|---|
| 392 | } | 
|---|
| 393 |  | 
|---|
| 394 | return true; | 
|---|
| 395 | } | 
|---|
| 396 |  | 
|---|
| 397 | SENode* ScalarEvolutionAnalysis::GetCoefficientFromRecurrentTerm( | 
|---|
| 398 | SENode* node, const Loop* loop) { | 
|---|
| 399 | // Traverse the DAG to find the recurrent expression belonging to |loop|. | 
|---|
| 400 | for (auto itr = node->graph_begin(); itr != node->graph_end(); ++itr) { | 
|---|
| 401 | SERecurrentNode* rec = itr->AsSERecurrentNode(); | 
|---|
| 402 | if (rec && rec->GetLoop() == loop) { | 
|---|
| 403 | return rec->GetCoefficient(); | 
|---|
| 404 | } | 
|---|
| 405 | } | 
|---|
| 406 | return CreateConstant(0); | 
|---|
| 407 | } | 
|---|
| 408 |  | 
|---|
| 409 | SENode* ScalarEvolutionAnalysis::UpdateChildNode(SENode* parent, | 
|---|
| 410 | SENode* old_child, | 
|---|
| 411 | SENode* new_child) { | 
|---|
| 412 | // Only handles add. | 
|---|
| 413 | if (parent->GetType() != SENode::Add) return parent; | 
|---|
| 414 |  | 
|---|
| 415 | std::vector<SENode*> new_children; | 
|---|
| 416 | for (SENode* child : *parent) { | 
|---|
| 417 | if (child == old_child) { | 
|---|
| 418 | new_children.push_back(new_child); | 
|---|
| 419 | } else { | 
|---|
| 420 | new_children.push_back(child); | 
|---|
| 421 | } | 
|---|
| 422 | } | 
|---|
| 423 |  | 
|---|
| 424 | std::unique_ptr<SENode> add_node{new SEAddNode(this)}; | 
|---|
| 425 | for (SENode* child : new_children) { | 
|---|
| 426 | add_node->AddChild(child); | 
|---|
| 427 | } | 
|---|
| 428 |  | 
|---|
| 429 | return SimplifyExpression(GetCachedOrAdd(std::move(add_node))); | 
|---|
| 430 | } | 
|---|
| 431 |  | 
|---|
| 432 | // Rebuild the |node| eliminating, if it exists, the recurrent term which | 
|---|
| 433 | // belongs to the |loop|. | 
|---|
| 434 | SENode* ScalarEvolutionAnalysis::BuildGraphWithoutRecurrentTerm( | 
|---|
| 435 | SENode* node, const Loop* loop) { | 
|---|
| 436 | // If the node is already a recurrent expression belonging to loop then just | 
|---|
| 437 | // return the offset. | 
|---|
| 438 | SERecurrentNode* recurrent = node->AsSERecurrentNode(); | 
|---|
| 439 | if (recurrent) { | 
|---|
| 440 | if (recurrent->GetLoop() == loop) { | 
|---|
| 441 | return recurrent->GetOffset(); | 
|---|
| 442 | } else { | 
|---|
| 443 | return node; | 
|---|
| 444 | } | 
|---|
| 445 | } | 
|---|
| 446 |  | 
|---|
| 447 | std::vector<SENode*> new_children; | 
|---|
| 448 | // Otherwise find the recurrent node in the children of this node. | 
|---|
| 449 | for (auto itr : *node) { | 
|---|
| 450 | recurrent = itr->AsSERecurrentNode(); | 
|---|
| 451 | if (recurrent && recurrent->GetLoop() == loop) { | 
|---|
| 452 | new_children.push_back(recurrent->GetOffset()); | 
|---|
| 453 | } else { | 
|---|
| 454 | new_children.push_back(itr); | 
|---|
| 455 | } | 
|---|
| 456 | } | 
|---|
| 457 |  | 
|---|
| 458 | std::unique_ptr<SENode> add_node{new SEAddNode(this)}; | 
|---|
| 459 | for (SENode* child : new_children) { | 
|---|
| 460 | add_node->AddChild(child); | 
|---|
| 461 | } | 
|---|
| 462 |  | 
|---|
| 463 | return SimplifyExpression(GetCachedOrAdd(std::move(add_node))); | 
|---|
| 464 | } | 
|---|
| 465 |  | 
|---|
| 466 | // Return the recurrent term belonging to |loop| if it appears in the graph | 
|---|
| 467 | // starting at |node| or null if it doesn't. | 
|---|
| 468 | SERecurrentNode* ScalarEvolutionAnalysis::GetRecurrentTerm(SENode* node, | 
|---|
| 469 | const Loop* loop) { | 
|---|
| 470 | for (auto itr = node->graph_begin(); itr != node->graph_end(); ++itr) { | 
|---|
| 471 | SERecurrentNode* rec = itr->AsSERecurrentNode(); | 
|---|
| 472 | if (rec && rec->GetLoop() == loop) { | 
|---|
| 473 | return rec; | 
|---|
| 474 | } | 
|---|
| 475 | } | 
|---|
| 476 | return nullptr; | 
|---|
| 477 | } | 
|---|
| 478 | std::string SENode::AsString() const { | 
|---|
| 479 | switch (GetType()) { | 
|---|
| 480 | case Constant: | 
|---|
| 481 | return "Constant"; | 
|---|
| 482 | case RecurrentAddExpr: | 
|---|
| 483 | return "RecurrentAddExpr"; | 
|---|
| 484 | case Add: | 
|---|
| 485 | return "Add"; | 
|---|
| 486 | case Negative: | 
|---|
| 487 | return "Negative"; | 
|---|
| 488 | case Multiply: | 
|---|
| 489 | return "Multiply"; | 
|---|
| 490 | case ValueUnknown: | 
|---|
| 491 | return "Value Unknown"; | 
|---|
| 492 | case CanNotCompute: | 
|---|
| 493 | return "Can not compute"; | 
|---|
| 494 | } | 
|---|
| 495 | return "NULL"; | 
|---|
| 496 | } | 
|---|
| 497 |  | 
|---|
| 498 | bool SENode::operator==(const SENode& other) const { | 
|---|
| 499 | if (GetType() != other.GetType()) return false; | 
|---|
| 500 |  | 
|---|
| 501 | if (other.GetChildren().size() != children_.size()) return false; | 
|---|
| 502 |  | 
|---|
| 503 | const SERecurrentNode* this_as_recurrent = AsSERecurrentNode(); | 
|---|
| 504 |  | 
|---|
| 505 | // Check the children are the same, for SERecurrentNodes we need to check the | 
|---|
| 506 | // offset and coefficient manually as the child vector is sorted by ids so the | 
|---|
| 507 | // offset/coefficient information is lost. | 
|---|
| 508 | if (!this_as_recurrent) { | 
|---|
| 509 | for (size_t index = 0; index < children_.size(); ++index) { | 
|---|
| 510 | if (other.GetChildren()[index] != children_[index]) return false; | 
|---|
| 511 | } | 
|---|
| 512 | } else { | 
|---|
| 513 | const SERecurrentNode* other_as_recurrent = other.AsSERecurrentNode(); | 
|---|
| 514 |  | 
|---|
| 515 | // We've already checked the types are the same, this should not fail if | 
|---|
| 516 | // this->AsSERecurrentNode() succeeded. | 
|---|
| 517 | assert(other_as_recurrent); | 
|---|
| 518 |  | 
|---|
| 519 | if (this_as_recurrent->GetCoefficient() != | 
|---|
| 520 | other_as_recurrent->GetCoefficient()) | 
|---|
| 521 | return false; | 
|---|
| 522 |  | 
|---|
| 523 | if (this_as_recurrent->GetOffset() != other_as_recurrent->GetOffset()) | 
|---|
| 524 | return false; | 
|---|
| 525 |  | 
|---|
| 526 | if (this_as_recurrent->GetLoop() != other_as_recurrent->GetLoop()) | 
|---|
| 527 | return false; | 
|---|
| 528 | } | 
|---|
| 529 |  | 
|---|
| 530 | // If we're dealing with a value unknown node check both nodes were created by | 
|---|
| 531 | // the same instruction. | 
|---|
| 532 | if (GetType() == SENode::ValueUnknown) { | 
|---|
| 533 | if (AsSEValueUnknown()->ResultId() != | 
|---|
| 534 | other.AsSEValueUnknown()->ResultId()) { | 
|---|
| 535 | return false; | 
|---|
| 536 | } | 
|---|
| 537 | } | 
|---|
| 538 |  | 
|---|
| 539 | if (AsSEConstantNode()) { | 
|---|
| 540 | if (AsSEConstantNode()->FoldToSingleValue() != | 
|---|
| 541 | other.AsSEConstantNode()->FoldToSingleValue()) | 
|---|
| 542 | return false; | 
|---|
| 543 | } | 
|---|
| 544 |  | 
|---|
| 545 | return true; | 
|---|
| 546 | } | 
|---|
| 547 |  | 
|---|
| 548 | bool SENode::operator!=(const SENode& other) const { return !(*this == other); } | 
|---|
| 549 |  | 
|---|
namespace {
// Helpers that append a 32 or 64 bit value to a std::u32string so that it can
// take part in a hash. Pointers can be appended by first reinterpreting them
// as uintptr_t. PushToString deduces the argument type and selects the
// PushToStringImpl specialisation matching sizeof(T).

// Primary template; only the 4 and 8 byte specialisations are defined.
template <typename T, size_t size_of_t>
struct PushToStringImpl;

// 8 byte values are appended as two elements: high 32 bits, then low 32 bits.
template <typename T>
struct PushToStringImpl<T, 8> {
  void operator()(T id, std::u32string* str) {
    // Shift the unsigned bit pattern: right-shifting a negative signed value
    // is implementation-defined before C++20.
    const uint64_t bits = static_cast<uint64_t>(id);
    str->push_back(static_cast<char32_t>(static_cast<uint32_t>(bits >> 32)));
    str->push_back(static_cast<char32_t>(static_cast<uint32_t>(bits)));
  }
};

// 4 byte values are appended as a single element.
template <typename T>
struct PushToStringImpl<T, 4> {
  void operator()(T id, std::u32string* str) {
    str->push_back(static_cast<char32_t>(static_cast<uint32_t>(id)));
  }
};

// Appends |id| to |str| using the specialisation selected by sizeof(T).
template <typename T>
void PushToString(T id, std::u32string* str) {
  PushToStringImpl<T, sizeof(T)>{}(id, str);
}

}  // namespace
|---|
| 581 |  | 
|---|
| 582 | // Implements the hashing of SENodes. | 
|---|
| 583 | size_t SENodeHash::operator()(const SENode* node) const { | 
|---|
| 584 | // Concatinate the terms into a string which we can hash. | 
|---|
| 585 | std::u32string hash_string{}; | 
|---|
| 586 |  | 
|---|
| 587 | // Hashing the type as a string is safer than hashing the enum as the enum is | 
|---|
| 588 | // very likely to collide with constants. | 
|---|
| 589 | for (char ch : node->AsString()) { | 
|---|
| 590 | hash_string.push_back(static_cast<char32_t>(ch)); | 
|---|
| 591 | } | 
|---|
| 592 |  | 
|---|
| 593 | // We just ignore the literal value unless it is a constant. | 
|---|
| 594 | if (node->GetType() == SENode::Constant) | 
|---|
| 595 | PushToString(node->AsSEConstantNode()->FoldToSingleValue(), &hash_string); | 
|---|
| 596 |  | 
|---|
| 597 | const SERecurrentNode* recurrent = node->AsSERecurrentNode(); | 
|---|
| 598 |  | 
|---|
| 599 | // If we're dealing with a recurrent expression hash the loop as well so that | 
|---|
| 600 | // nested inductions like i=0,i++ and j=0,j++ correspond to different nodes. | 
|---|
| 601 | if (recurrent) { | 
|---|
| 602 | PushToString(reinterpret_cast<uintptr_t>(recurrent->GetLoop()), | 
|---|
| 603 | &hash_string); | 
|---|
| 604 |  | 
|---|
| 605 | // Recurrent expressions can't be hashed using the normal method as the | 
|---|
| 606 | // order of coefficient and offset matters to the hash. | 
|---|
| 607 | PushToString(reinterpret_cast<uintptr_t>(recurrent->GetCoefficient()), | 
|---|
| 608 | &hash_string); | 
|---|
| 609 | PushToString(reinterpret_cast<uintptr_t>(recurrent->GetOffset()), | 
|---|
| 610 | &hash_string); | 
|---|
| 611 |  | 
|---|
| 612 | return std::hash<std::u32string>{}(hash_string); | 
|---|
| 613 | } | 
|---|
| 614 |  | 
|---|
| 615 | // Hash the result id of the original instruction which created this node if | 
|---|
| 616 | // it is a value unknown node. | 
|---|
| 617 | if (node->GetType() == SENode::ValueUnknown) { | 
|---|
| 618 | PushToString(node->AsSEValueUnknown()->ResultId(), &hash_string); | 
|---|
| 619 | } | 
|---|
| 620 |  | 
|---|
| 621 | // Hash the pointers of the child nodes, each SENode has a unique pointer | 
|---|
| 622 | // associated with it. | 
|---|
| 623 | const std::vector<SENode*>& children = node->GetChildren(); | 
|---|
| 624 | for (const SENode* child : children) { | 
|---|
| 625 | PushToString(reinterpret_cast<uintptr_t>(child), &hash_string); | 
|---|
| 626 | } | 
|---|
| 627 |  | 
|---|
| 628 | return std::hash<std::u32string>{}(hash_string); | 
|---|
| 629 | } | 
|---|
| 630 |  | 
|---|
| 631 | // This overload is the actual overload used by the node_cache_ set. | 
|---|
| 632 | size_t SENodeHash::operator()(const std::unique_ptr<SENode>& node) const { | 
|---|
| 633 | return this->operator()(node.get()); | 
|---|
| 634 | } | 
|---|
| 635 |  | 
|---|
| 636 | void SENode::DumpDot(std::ostream& out, bool recurse) const { | 
|---|
| 637 | size_t unique_id = std::hash<const SENode*>{}(this); | 
|---|
| 638 | out << unique_id << " [label=\""<< AsString() << " "; | 
|---|
| 639 | if (GetType() == SENode::Constant) { | 
|---|
| 640 | out << "\nwith value: "<< this->AsSEConstantNode()->FoldToSingleValue(); | 
|---|
| 641 | } | 
|---|
| 642 | out << "\"]\n"; | 
|---|
| 643 | for (const SENode* child : children_) { | 
|---|
| 644 | size_t child_unique_id = std::hash<const SENode*>{}(child); | 
|---|
| 645 | out << unique_id << " -> "<< child_unique_id << " \n"; | 
|---|
| 646 | if (recurse) child->DumpDot(out, true); | 
|---|
| 647 | } | 
|---|
| 648 | } | 
|---|
| 649 |  | 
|---|
| 650 | namespace { | 
|---|
| 651 | class IsGreaterThanZero { | 
|---|
| 652 | public: | 
|---|
| 653 | explicit IsGreaterThanZero(IRContext* context) : context_(context) {} | 
|---|
| 654 |  | 
|---|
| 655 | // Determine if the value of |node| is always strictly greater than zero if | 
|---|
| 656 | // |or_equal_zero| is false or greater or equal to zero if |or_equal_zero| is | 
|---|
| 657 | // true. It returns true is the evaluation was able to conclude something, in | 
|---|
| 658 | // which case the result is stored in |result|. | 
|---|
| 659 | // The algorithm work by going through all the nodes and determine the | 
|---|
| 660 | // sign of each of them. | 
|---|
| 661 | bool Eval(const SENode* node, bool or_equal_zero, bool* result) { | 
|---|
| 662 | *result = false; | 
|---|
| 663 | switch (Visit(node)) { | 
|---|
| 664 | case Signedness::kPositiveOrNegative: { | 
|---|
| 665 | return false; | 
|---|
| 666 | } | 
|---|
| 667 | case Signedness::kStrictlyNegative: { | 
|---|
| 668 | *result = false; | 
|---|
| 669 | break; | 
|---|
| 670 | } | 
|---|
| 671 | case Signedness::kNegative: { | 
|---|
| 672 | if (!or_equal_zero) { | 
|---|
| 673 | return false; | 
|---|
| 674 | } | 
|---|
| 675 | *result = false; | 
|---|
| 676 | break; | 
|---|
| 677 | } | 
|---|
| 678 | case Signedness::kStrictlyPositive: { | 
|---|
| 679 | *result = true; | 
|---|
| 680 | break; | 
|---|
| 681 | } | 
|---|
| 682 | case Signedness::kPositive: { | 
|---|
| 683 | if (!or_equal_zero) { | 
|---|
| 684 | return false; | 
|---|
| 685 | } | 
|---|
| 686 | *result = true; | 
|---|
| 687 | break; | 
|---|
| 688 | } | 
|---|
| 689 | } | 
|---|
| 690 | return true; | 
|---|
| 691 | } | 
|---|
| 692 |  | 
|---|
| 693 | private: | 
|---|
| 694 | enum class Signedness { | 
|---|
| 695 | kPositiveOrNegative,  // Yield a value positive or negative. | 
|---|
| 696 | kStrictlyNegative,    // Yield a value strictly less than 0. | 
|---|
| 697 | kNegative,            // Yield a value less or equal to 0. | 
|---|
| 698 | kStrictlyPositive,    // Yield a value strictly greater than 0. | 
|---|
| 699 | kPositive             // Yield a value greater or equal to 0. | 
|---|
| 700 | }; | 
|---|
| 701 |  | 
|---|
| 702 | // Combine the signedness according to arithmetic rules of a given operator. | 
|---|
| 703 | using Combiner = std::function<Signedness(Signedness, Signedness)>; | 
|---|
| 704 |  | 
|---|
| 705 | // Returns a functor to interpret the signedness of 2 expressions as if they | 
|---|
| 706 | // were added. | 
|---|
| 707 | Combiner GetAddCombiner() const { | 
|---|
| 708 | return [](Signedness lhs, Signedness rhs) { | 
|---|
| 709 | switch (lhs) { | 
|---|
| 710 | case Signedness::kPositiveOrNegative: | 
|---|
| 711 | break; | 
|---|
| 712 | case Signedness::kStrictlyNegative: | 
|---|
| 713 | if (rhs == Signedness::kStrictlyNegative || | 
|---|
| 714 | rhs == Signedness::kNegative) | 
|---|
| 715 | return lhs; | 
|---|
| 716 | break; | 
|---|
| 717 | case Signedness::kNegative: { | 
|---|
| 718 | if (rhs == Signedness::kStrictlyNegative) | 
|---|
| 719 | return Signedness::kStrictlyNegative; | 
|---|
| 720 | if (rhs == Signedness::kNegative) return Signedness::kNegative; | 
|---|
| 721 | break; | 
|---|
| 722 | } | 
|---|
| 723 | case Signedness::kStrictlyPositive: { | 
|---|
| 724 | if (rhs == Signedness::kStrictlyPositive || | 
|---|
| 725 | rhs == Signedness::kPositive) { | 
|---|
| 726 | return Signedness::kStrictlyPositive; | 
|---|
| 727 | } | 
|---|
| 728 | break; | 
|---|
| 729 | } | 
|---|
| 730 | case Signedness::kPositive: { | 
|---|
| 731 | if (rhs == Signedness::kStrictlyPositive) | 
|---|
| 732 | return Signedness::kStrictlyPositive; | 
|---|
| 733 | if (rhs == Signedness::kPositive) return Signedness::kPositive; | 
|---|
| 734 | break; | 
|---|
| 735 | } | 
|---|
| 736 | } | 
|---|
| 737 | return Signedness::kPositiveOrNegative; | 
|---|
| 738 | }; | 
|---|
| 739 | } | 
|---|
| 740 |  | 
|---|
| 741 | // Returns a functor to interpret the signedness of 2 expressions as if they | 
|---|
| 742 | // were multiplied. | 
|---|
| 743 | Combiner GetMulCombiner() const { | 
|---|
| 744 | return [](Signedness lhs, Signedness rhs) { | 
|---|
| 745 | switch (lhs) { | 
|---|
| 746 | case Signedness::kPositiveOrNegative: | 
|---|
| 747 | break; | 
|---|
| 748 | case Signedness::kStrictlyNegative: { | 
|---|
| 749 | switch (rhs) { | 
|---|
| 750 | case Signedness::kPositiveOrNegative: { | 
|---|
| 751 | break; | 
|---|
| 752 | } | 
|---|
| 753 | case Signedness::kStrictlyNegative: { | 
|---|
| 754 | return Signedness::kStrictlyPositive; | 
|---|
| 755 | } | 
|---|
| 756 | case Signedness::kNegative: { | 
|---|
| 757 | return Signedness::kPositive; | 
|---|
| 758 | } | 
|---|
| 759 | case Signedness::kStrictlyPositive: { | 
|---|
| 760 | return Signedness::kStrictlyNegative; | 
|---|
| 761 | } | 
|---|
| 762 | case Signedness::kPositive: { | 
|---|
| 763 | return Signedness::kNegative; | 
|---|
| 764 | } | 
|---|
| 765 | } | 
|---|
| 766 | break; | 
|---|
| 767 | } | 
|---|
| 768 | case Signedness::kNegative: { | 
|---|
| 769 | switch (rhs) { | 
|---|
| 770 | case Signedness::kPositiveOrNegative: { | 
|---|
| 771 | break; | 
|---|
| 772 | } | 
|---|
| 773 | case Signedness::kStrictlyNegative: | 
|---|
| 774 | case Signedness::kNegative: { | 
|---|
| 775 | return Signedness::kPositive; | 
|---|
| 776 | } | 
|---|
| 777 | case Signedness::kStrictlyPositive: | 
|---|
| 778 | case Signedness::kPositive: { | 
|---|
| 779 | return Signedness::kNegative; | 
|---|
| 780 | } | 
|---|
| 781 | } | 
|---|
| 782 | break; | 
|---|
| 783 | } | 
|---|
| 784 | case Signedness::kStrictlyPositive: { | 
|---|
| 785 | return rhs; | 
|---|
| 786 | } | 
|---|
| 787 | case Signedness::kPositive: { | 
|---|
| 788 | switch (rhs) { | 
|---|
| 789 | case Signedness::kPositiveOrNegative: { | 
|---|
| 790 | break; | 
|---|
| 791 | } | 
|---|
| 792 | case Signedness::kStrictlyNegative: | 
|---|
| 793 | case Signedness::kNegative: { | 
|---|
| 794 | return Signedness::kNegative; | 
|---|
| 795 | } | 
|---|
| 796 | case Signedness::kStrictlyPositive: | 
|---|
| 797 | case Signedness::kPositive: { | 
|---|
| 798 | return Signedness::kPositive; | 
|---|
| 799 | } | 
|---|
| 800 | } | 
|---|
| 801 | break; | 
|---|
| 802 | } | 
|---|
| 803 | } | 
|---|
| 804 | return Signedness::kPositiveOrNegative; | 
|---|
| 805 | }; | 
|---|
| 806 | } | 
|---|
| 807 |  | 
|---|
| 808 | Signedness Visit(const SENode* node) { | 
|---|
| 809 | switch (node->GetType()) { | 
|---|
| 810 | case SENode::Constant: | 
|---|
| 811 | return Visit(node->AsSEConstantNode()); | 
|---|
| 812 | break; | 
|---|
| 813 | case SENode::RecurrentAddExpr: | 
|---|
| 814 | return Visit(node->AsSERecurrentNode()); | 
|---|
| 815 | break; | 
|---|
| 816 | case SENode::Negative: | 
|---|
| 817 | return Visit(node->AsSENegative()); | 
|---|
| 818 | break; | 
|---|
| 819 | case SENode::CanNotCompute: | 
|---|
| 820 | return Visit(node->AsSECantCompute()); | 
|---|
| 821 | break; | 
|---|
| 822 | case SENode::ValueUnknown: | 
|---|
| 823 | return Visit(node->AsSEValueUnknown()); | 
|---|
| 824 | break; | 
|---|
| 825 | case SENode::Add: | 
|---|
| 826 | return VisitExpr(node, GetAddCombiner()); | 
|---|
| 827 | break; | 
|---|
| 828 | case SENode::Multiply: | 
|---|
| 829 | return VisitExpr(node, GetMulCombiner()); | 
|---|
| 830 | break; | 
|---|
| 831 | } | 
|---|
| 832 | return Signedness::kPositiveOrNegative; | 
|---|
| 833 | } | 
|---|
| 834 |  | 
|---|
| 835 | // Returns the signedness of a constant |node|. | 
|---|
| 836 | Signedness Visit(const SEConstantNode* node) { | 
|---|
| 837 | if (0 == node->FoldToSingleValue()) return Signedness::kPositive; | 
|---|
| 838 | if (0 < node->FoldToSingleValue()) return Signedness::kStrictlyPositive; | 
|---|
| 839 | if (0 > node->FoldToSingleValue()) return Signedness::kStrictlyNegative; | 
|---|
| 840 | return Signedness::kPositiveOrNegative; | 
|---|
| 841 | } | 
|---|
| 842 |  | 
|---|
| 843 | // Returns the signedness of an unknown |node| based on its type. | 
|---|
| 844 | Signedness Visit(const SEValueUnknown* node) { | 
|---|
| 845 | Instruction* insn = context_->get_def_use_mgr()->GetDef(node->ResultId()); | 
|---|
| 846 | analysis::Type* type = context_->get_type_mgr()->GetType(insn->type_id()); | 
|---|
| 847 | assert(type && "Can't retrieve a type for the instruction"); | 
|---|
| 848 | analysis::Integer* int_type = type->AsInteger(); | 
|---|
| 849 | assert(type && "Can't retrieve an integer type for the instruction"); | 
|---|
| 850 | return int_type->IsSigned() ? Signedness::kPositiveOrNegative | 
|---|
| 851 | : Signedness::kPositive; | 
|---|
| 852 | } | 
|---|
| 853 |  | 
|---|
| 854 | // Returns the signedness of a recurring expression. | 
|---|
| 855 | Signedness Visit(const SERecurrentNode* node) { | 
|---|
| 856 | Signedness coeff_sign = Visit(node->GetCoefficient()); | 
|---|
| 857 | // SERecurrentNode represent an affine expression in the range [0, | 
|---|
| 858 | // loop_bound], so the result cannot be strictly positive or negative. | 
|---|
| 859 | switch (coeff_sign) { | 
|---|
| 860 | default: | 
|---|
| 861 | break; | 
|---|
| 862 | case Signedness::kStrictlyNegative: | 
|---|
| 863 | coeff_sign = Signedness::kNegative; | 
|---|
| 864 | break; | 
|---|
| 865 | case Signedness::kStrictlyPositive: | 
|---|
| 866 | coeff_sign = Signedness::kPositive; | 
|---|
| 867 | break; | 
|---|
| 868 | } | 
|---|
| 869 | return GetAddCombiner()(coeff_sign, Visit(node->GetOffset())); | 
|---|
| 870 | } | 
|---|
| 871 |  | 
|---|
| 872 | // Returns the signedness of a negation |node|. | 
|---|
| 873 | Signedness Visit(const SENegative* node) { | 
|---|
| 874 | switch (Visit(*node->begin())) { | 
|---|
| 875 | case Signedness::kPositiveOrNegative: { | 
|---|
| 876 | return Signedness::kPositiveOrNegative; | 
|---|
| 877 | } | 
|---|
| 878 | case Signedness::kStrictlyNegative: { | 
|---|
| 879 | return Signedness::kStrictlyPositive; | 
|---|
| 880 | } | 
|---|
| 881 | case Signedness::kNegative: { | 
|---|
| 882 | return Signedness::kPositive; | 
|---|
| 883 | } | 
|---|
| 884 | case Signedness::kStrictlyPositive: { | 
|---|
| 885 | return Signedness::kStrictlyNegative; | 
|---|
| 886 | } | 
|---|
| 887 | case Signedness::kPositive: { | 
|---|
| 888 | return Signedness::kNegative; | 
|---|
| 889 | } | 
|---|
| 890 | } | 
|---|
| 891 | return Signedness::kPositiveOrNegative; | 
|---|
| 892 | } | 
|---|
| 893 |  | 
|---|
| 894 | Signedness Visit(const SECantCompute*) { | 
|---|
| 895 | return Signedness::kPositiveOrNegative; | 
|---|
| 896 | } | 
|---|
| 897 |  | 
|---|
| 898 | // Returns the signedness of a binary expression by using the combiner | 
|---|
| 899 | // |reduce|. | 
|---|
| 900 | Signedness VisitExpr( | 
|---|
| 901 | const SENode* node, | 
|---|
| 902 | std::function<Signedness(Signedness, Signedness)> reduce) { | 
|---|
| 903 | Signedness result = Visit(*node->begin()); | 
|---|
| 904 | for (const SENode* operand : make_range(++node->begin(), node->end())) { | 
|---|
| 905 | if (result == Signedness::kPositiveOrNegative) { | 
|---|
| 906 | return Signedness::kPositiveOrNegative; | 
|---|
| 907 | } | 
|---|
| 908 | result = reduce(result, Visit(operand)); | 
|---|
| 909 | } | 
|---|
| 910 | return result; | 
|---|
| 911 | } | 
|---|
| 912 |  | 
|---|
| 913 | IRContext* context_; | 
|---|
| 914 | }; | 
|---|
| 915 | }  // namespace | 
|---|
| 916 |  | 
|---|
| 917 | bool ScalarEvolutionAnalysis::IsAlwaysGreaterThanZero(SENode* node, | 
|---|
| 918 | bool* is_gt_zero) const { | 
|---|
| 919 | return IsGreaterThanZero(context_).Eval(node, false, is_gt_zero); | 
|---|
| 920 | } | 
|---|
| 921 |  | 
|---|
| 922 | bool ScalarEvolutionAnalysis::IsAlwaysGreaterOrEqualToZero( | 
|---|
| 923 | SENode* node, bool* is_ge_zero) const { | 
|---|
| 924 | return IsGreaterThanZero(context_).Eval(node, true, is_ge_zero); | 
|---|
| 925 | } | 
|---|
| 926 |  | 
|---|
| 927 | namespace { | 
|---|
| 928 |  | 
|---|
| 929 | // Remove |node| from the |mul| chain (of the form A * ... * |node| * ... * Z), | 
|---|
| 930 | // if |node| is not in the chain, returns the original chain. | 
|---|
| 931 | static SENode* RemoveOneNodeFromMultiplyChain(SEMultiplyNode* mul, | 
|---|
| 932 | const SENode* node) { | 
|---|
| 933 | SENode* lhs = mul->GetChildren()[0]; | 
|---|
| 934 | SENode* rhs = mul->GetChildren()[1]; | 
|---|
| 935 | if (lhs == node) { | 
|---|
| 936 | return rhs; | 
|---|
| 937 | } | 
|---|
| 938 | if (rhs == node) { | 
|---|
| 939 | return lhs; | 
|---|
| 940 | } | 
|---|
| 941 | if (lhs->AsSEMultiplyNode()) { | 
|---|
| 942 | SENode* res = RemoveOneNodeFromMultiplyChain(lhs->AsSEMultiplyNode(), node); | 
|---|
| 943 | if (res != lhs) | 
|---|
| 944 | return mul->GetParentAnalysis()->CreateMultiplyNode(res, rhs); | 
|---|
| 945 | } | 
|---|
| 946 | if (rhs->AsSEMultiplyNode()) { | 
|---|
| 947 | SENode* res = RemoveOneNodeFromMultiplyChain(rhs->AsSEMultiplyNode(), node); | 
|---|
| 948 | if (res != rhs) | 
|---|
| 949 | return mul->GetParentAnalysis()->CreateMultiplyNode(res, rhs); | 
|---|
| 950 | } | 
|---|
| 951 |  | 
|---|
| 952 | return mul; | 
|---|
| 953 | } | 
|---|
| 954 | }  // namespace | 
|---|
| 955 |  | 
|---|
| 956 | std::pair<SExpression, int64_t> SExpression::operator/( | 
|---|
| 957 | SExpression rhs_wrapper) const { | 
|---|
| 958 | SENode* lhs = node_; | 
|---|
| 959 | SENode* rhs = rhs_wrapper.node_; | 
|---|
| 960 | // Check for division by 0. | 
|---|
| 961 | if (rhs->AsSEConstantNode() && | 
|---|
| 962 | !rhs->AsSEConstantNode()->FoldToSingleValue()) { | 
|---|
| 963 | return {scev_->CreateCantComputeNode(), 0}; | 
|---|
| 964 | } | 
|---|
| 965 |  | 
|---|
| 966 | // Trivial case. | 
|---|
| 967 | if (lhs->AsSEConstantNode() && rhs->AsSEConstantNode()) { | 
|---|
| 968 | int64_t lhs_value = lhs->AsSEConstantNode()->FoldToSingleValue(); | 
|---|
| 969 | int64_t rhs_value = rhs->AsSEConstantNode()->FoldToSingleValue(); | 
|---|
| 970 | return {scev_->CreateConstant(lhs_value / rhs_value), | 
|---|
| 971 | lhs_value % rhs_value}; | 
|---|
| 972 | } | 
|---|
| 973 |  | 
|---|
| 974 | // look for a "c U / U" pattern. | 
|---|
| 975 | if (lhs->AsSEMultiplyNode()) { | 
|---|
| 976 | assert(lhs->GetChildren().size() == 2 && | 
|---|
| 977 | "More than 2 operand for a multiply node."); | 
|---|
| 978 | SENode* res = RemoveOneNodeFromMultiplyChain(lhs->AsSEMultiplyNode(), rhs); | 
|---|
| 979 | if (res != lhs) { | 
|---|
| 980 | return {res, 0}; | 
|---|
| 981 | } | 
|---|
| 982 | } | 
|---|
| 983 |  | 
|---|
| 984 | return {scev_->CreateCantComputeNode(), 0}; | 
|---|
| 985 | } | 
|---|
| 986 |  | 
|---|
| 987 | }  // namespace opt | 
|---|
| 988 | }  // namespace spvtools | 
|---|
| 989 |  | 
|---|