1// Copyright (c) 2018 Google LLC.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "source/opt/scalar_analysis.h"
16
17#include <algorithm>
18#include <functional>
19#include <string>
20#include <utility>
21
22#include "source/opt/ir_context.h"
23
24// Transforms a given scalar operation instruction into a DAG representation.
25//
// 1. Take an instruction and traverse its operands until we reach a
// constant node or an instruction whose value we do not know how to compute,
// such as a load.
29//
30// 2. Create a new node for each instruction traversed and build the nodes for
31// the in operands of that instruction as well.
32//
// 3. Add the operand nodes as children of the newly created node and hash
// it. Use the hash to see if an equal node is already in the cache. We ensure
// the children are always kept in sorted order so that two nodes with the same
// children inserted in a different order have the same hash and so that the
// overloaded operator== returns true. If the node is already in the cache,
// return the cached version instead.
39//
40// 4. The created DAG can then be simplified by
41// ScalarAnalysis::SimplifyExpression, implemented in
42// scalar_analysis_simplification.cpp. See that file for further information on
43// the simplification process.
44//
45
46namespace spvtools {
47namespace opt {
48
49uint32_t SENode::NumberOfNodes = 0;
50
51ScalarEvolutionAnalysis::ScalarEvolutionAnalysis(IRContext* context)
52 : context_(context), pretend_equal_{} {
53 // Create and cached the CantComputeNode.
54 cached_cant_compute_ =
55 GetCachedOrAdd(std::unique_ptr<SECantCompute>(new SECantCompute(this)));
56}
57
58SENode* ScalarEvolutionAnalysis::CreateNegation(SENode* operand) {
59 // If operand is can't compute then the whole graph is can't compute.
60 if (operand->IsCantCompute()) return CreateCantComputeNode();
61
62 if (operand->GetType() == SENode::Constant) {
63 return CreateConstant(-operand->AsSEConstantNode()->FoldToSingleValue());
64 }
65 std::unique_ptr<SENode> negation_node{new SENegative(this)};
66 negation_node->AddChild(operand);
67 return GetCachedOrAdd(std::move(negation_node));
68}
69
70SENode* ScalarEvolutionAnalysis::CreateConstant(int64_t integer) {
71 return GetCachedOrAdd(
72 std::unique_ptr<SENode>(new SEConstantNode(this, integer)));
73}
74
75SENode* ScalarEvolutionAnalysis::CreateRecurrentExpression(
76 const Loop* loop, SENode* offset, SENode* coefficient) {
77 assert(loop && "Recurrent add expressions must have a valid loop.");
78
79 // If operands are can't compute then the whole graph is can't compute.
80 if (offset->IsCantCompute() || coefficient->IsCantCompute())
81 return CreateCantComputeNode();
82
83 const Loop* loop_to_use = nullptr;
84 if (pretend_equal_[loop]) {
85 loop_to_use = pretend_equal_[loop];
86 } else {
87 loop_to_use = loop;
88 }
89
90 std::unique_ptr<SERecurrentNode> phi_node{
91 new SERecurrentNode(this, loop_to_use)};
92 phi_node->AddOffset(offset);
93 phi_node->AddCoefficient(coefficient);
94
95 return GetCachedOrAdd(std::move(phi_node));
96}
97
98SENode* ScalarEvolutionAnalysis::AnalyzeMultiplyOp(
99 const Instruction* multiply) {
100 assert(multiply->opcode() == SpvOp::SpvOpIMul &&
101 "Multiply node did not come from a multiply instruction");
102 analysis::DefUseManager* def_use = context_->get_def_use_mgr();
103
104 SENode* op1 =
105 AnalyzeInstruction(def_use->GetDef(multiply->GetSingleWordInOperand(0)));
106 SENode* op2 =
107 AnalyzeInstruction(def_use->GetDef(multiply->GetSingleWordInOperand(1)));
108
109 return CreateMultiplyNode(op1, op2);
110}
111
112SENode* ScalarEvolutionAnalysis::CreateMultiplyNode(SENode* operand_1,
113 SENode* operand_2) {
114 // If operands are can't compute then the whole graph is can't compute.
115 if (operand_1->IsCantCompute() || operand_2->IsCantCompute())
116 return CreateCantComputeNode();
117
118 if (operand_1->GetType() == SENode::Constant &&
119 operand_2->GetType() == SENode::Constant) {
120 return CreateConstant(operand_1->AsSEConstantNode()->FoldToSingleValue() *
121 operand_2->AsSEConstantNode()->FoldToSingleValue());
122 }
123
124 std::unique_ptr<SENode> multiply_node{new SEMultiplyNode(this)};
125
126 multiply_node->AddChild(operand_1);
127 multiply_node->AddChild(operand_2);
128
129 return GetCachedOrAdd(std::move(multiply_node));
130}
131
132SENode* ScalarEvolutionAnalysis::CreateSubtraction(SENode* operand_1,
133 SENode* operand_2) {
134 // Fold if both operands are constant.
135 if (operand_1->GetType() == SENode::Constant &&
136 operand_2->GetType() == SENode::Constant) {
137 return CreateConstant(operand_1->AsSEConstantNode()->FoldToSingleValue() -
138 operand_2->AsSEConstantNode()->FoldToSingleValue());
139 }
140
141 return CreateAddNode(operand_1, CreateNegation(operand_2));
142}
143
144SENode* ScalarEvolutionAnalysis::CreateAddNode(SENode* operand_1,
145 SENode* operand_2) {
146 // Fold if both operands are constant and the |simplify| flag is true.
147 if (operand_1->GetType() == SENode::Constant &&
148 operand_2->GetType() == SENode::Constant) {
149 return CreateConstant(operand_1->AsSEConstantNode()->FoldToSingleValue() +
150 operand_2->AsSEConstantNode()->FoldToSingleValue());
151 }
152
153 // If operands are can't compute then the whole graph is can't compute.
154 if (operand_1->IsCantCompute() || operand_2->IsCantCompute())
155 return CreateCantComputeNode();
156
157 std::unique_ptr<SENode> add_node{new SEAddNode(this)};
158
159 add_node->AddChild(operand_1);
160 add_node->AddChild(operand_2);
161
162 return GetCachedOrAdd(std::move(add_node));
163}
164
165SENode* ScalarEvolutionAnalysis::AnalyzeInstruction(const Instruction* inst) {
166 auto itr = recurrent_node_map_.find(inst);
167 if (itr != recurrent_node_map_.end()) return itr->second;
168
169 SENode* output = nullptr;
170 switch (inst->opcode()) {
171 case SpvOp::SpvOpPhi: {
172 output = AnalyzePhiInstruction(inst);
173 break;
174 }
175 case SpvOp::SpvOpConstant:
176 case SpvOp::SpvOpConstantNull: {
177 output = AnalyzeConstant(inst);
178 break;
179 }
180 case SpvOp::SpvOpISub:
181 case SpvOp::SpvOpIAdd: {
182 output = AnalyzeAddOp(inst);
183 break;
184 }
185 case SpvOp::SpvOpIMul: {
186 output = AnalyzeMultiplyOp(inst);
187 break;
188 }
189 default: {
190 output = CreateValueUnknownNode(inst);
191 break;
192 }
193 }
194
195 return output;
196}
197
198SENode* ScalarEvolutionAnalysis::AnalyzeConstant(const Instruction* inst) {
199 if (inst->opcode() == SpvOp::SpvOpConstantNull) return CreateConstant(0);
200
201 assert(inst->opcode() == SpvOp::SpvOpConstant);
202 assert(inst->NumInOperands() == 1);
203 int64_t value = 0;
204
205 // Look up the instruction in the constant manager.
206 const analysis::Constant* constant =
207 context_->get_constant_mgr()->FindDeclaredConstant(inst->result_id());
208
209 if (!constant) return CreateCantComputeNode();
210
211 const analysis::IntConstant* int_constant = constant->AsIntConstant();
212
213 // Exit out if it is a 64 bit integer.
214 if (!int_constant || int_constant->words().size() != 1)
215 return CreateCantComputeNode();
216
217 if (int_constant->type()->AsInteger()->IsSigned()) {
218 value = int_constant->GetS32BitValue();
219 } else {
220 value = int_constant->GetU32BitValue();
221 }
222
223 return CreateConstant(value);
224}
225
226// Handles both addition and subtraction. If the |sub| flag is set then the
227// addition will be op1+(-op2) otherwise op1+op2.
228SENode* ScalarEvolutionAnalysis::AnalyzeAddOp(const Instruction* inst) {
229 assert((inst->opcode() == SpvOp::SpvOpIAdd ||
230 inst->opcode() == SpvOp::SpvOpISub) &&
231 "Add node must be created from a OpIAdd or OpISub instruction");
232
233 analysis::DefUseManager* def_use = context_->get_def_use_mgr();
234
235 SENode* op1 =
236 AnalyzeInstruction(def_use->GetDef(inst->GetSingleWordInOperand(0)));
237
238 SENode* op2 =
239 AnalyzeInstruction(def_use->GetDef(inst->GetSingleWordInOperand(1)));
240
241 // To handle subtraction we wrap the second operand in a unary negation node.
242 if (inst->opcode() == SpvOp::SpvOpISub) {
243 op2 = CreateNegation(op2);
244 }
245
246 return CreateAddNode(op1, op2);
247}
248
249SENode* ScalarEvolutionAnalysis::AnalyzePhiInstruction(const Instruction* phi) {
250 // The phi should only have two incoming value pairs.
251 if (phi->NumInOperands() != 4) {
252 return CreateCantComputeNode();
253 }
254
255 analysis::DefUseManager* def_use = context_->get_def_use_mgr();
256
257 // Get the basic block this instruction belongs to.
258 BasicBlock* basic_block =
259 context_->get_instr_block(const_cast<Instruction*>(phi));
260
261 // And then the function that the basic blocks belongs to.
262 Function* function = basic_block->GetParent();
263
264 // Use the function to get the loop descriptor.
265 LoopDescriptor* loop_descriptor = context_->GetLoopDescriptor(function);
266
267 // We only handle phis in loops at the moment.
268 if (!loop_descriptor) return CreateCantComputeNode();
269
270 // Get the innermost loop which this block belongs to.
271 Loop* loop = (*loop_descriptor)[basic_block->id()];
272
273 // If the loop doesn't exist or doesn't have a preheader or latch block, exit
274 // out.
275 if (!loop || !loop->GetLatchBlock() || !loop->GetPreHeaderBlock() ||
276 loop->GetHeaderBlock() != basic_block)
277 return recurrent_node_map_[phi] = CreateCantComputeNode();
278
279 const Loop* loop_to_use = nullptr;
280 if (pretend_equal_[loop]) {
281 loop_to_use = pretend_equal_[loop];
282 } else {
283 loop_to_use = loop;
284 }
285 std::unique_ptr<SERecurrentNode> phi_node{
286 new SERecurrentNode(this, loop_to_use)};
287
288 // We add the node to this map to allow it to be returned before the node is
289 // fully built. This is needed as the subsequent call to AnalyzeInstruction
290 // could lead back to this |phi| instruction so we return the pointer
291 // immediately in AnalyzeInstruction to break the recursion.
292 recurrent_node_map_[phi] = phi_node.get();
293
294 // Traverse the operands of the instruction an create new nodes for each one.
295 for (uint32_t i = 0; i < phi->NumInOperands(); i += 2) {
296 uint32_t value_id = phi->GetSingleWordInOperand(i);
297 uint32_t incoming_label_id = phi->GetSingleWordInOperand(i + 1);
298
299 Instruction* value_inst = def_use->GetDef(value_id);
300 SENode* value_node = AnalyzeInstruction(value_inst);
301
302 // If any operand is CantCompute then the whole graph is CantCompute.
303 if (value_node->IsCantCompute())
304 return recurrent_node_map_[phi] = CreateCantComputeNode();
305
306 // If the value is coming from the preheader block then the value is the
307 // initial value of the phi.
308 if (incoming_label_id == loop->GetPreHeaderBlock()->id()) {
309 phi_node->AddOffset(value_node);
310 } else if (incoming_label_id == loop->GetLatchBlock()->id()) {
311 // Assumed to be in the form of step + phi.
312 if (value_node->GetType() != SENode::Add)
313 return recurrent_node_map_[phi] = CreateCantComputeNode();
314
315 SENode* step_node = nullptr;
316 SENode* phi_operand = nullptr;
317 SENode* operand_1 = value_node->GetChild(0);
318 SENode* operand_2 = value_node->GetChild(1);
319
320 // Find which node is the step term.
321 if (!operand_1->AsSERecurrentNode())
322 step_node = operand_1;
323 else if (!operand_2->AsSERecurrentNode())
324 step_node = operand_2;
325
326 // Find which node is the recurrent expression.
327 if (operand_1->AsSERecurrentNode())
328 phi_operand = operand_1;
329 else if (operand_2->AsSERecurrentNode())
330 phi_operand = operand_2;
331
332 // If it is not in the form step + phi exit out.
333 if (!(step_node && phi_operand))
334 return recurrent_node_map_[phi] = CreateCantComputeNode();
335
336 // If the phi operand is not the same phi node exit out.
337 if (phi_operand != phi_node.get())
338 return recurrent_node_map_[phi] = CreateCantComputeNode();
339
340 if (!IsLoopInvariant(loop, step_node))
341 return recurrent_node_map_[phi] = CreateCantComputeNode();
342
343 phi_node->AddCoefficient(step_node);
344 }
345 }
346
347 // Once the node is fully built we update the map with the version from the
348 // cache (if it has already been added to the cache).
349 return recurrent_node_map_[phi] = GetCachedOrAdd(std::move(phi_node));
350}
351
352SENode* ScalarEvolutionAnalysis::CreateValueUnknownNode(
353 const Instruction* inst) {
354 std::unique_ptr<SEValueUnknown> load_node{
355 new SEValueUnknown(this, inst->result_id())};
356 return GetCachedOrAdd(std::move(load_node));
357}
358
359SENode* ScalarEvolutionAnalysis::CreateCantComputeNode() {
360 return cached_cant_compute_;
361}
362
363// Add the created node into the cache of nodes. If it already exists return it.
364SENode* ScalarEvolutionAnalysis::GetCachedOrAdd(
365 std::unique_ptr<SENode> prospective_node) {
366 auto itr = node_cache_.find(prospective_node);
367 if (itr != node_cache_.end()) {
368 return (*itr).get();
369 }
370
371 SENode* raw_ptr_to_node = prospective_node.get();
372 node_cache_.insert(std::move(prospective_node));
373 return raw_ptr_to_node;
374}
375
376bool ScalarEvolutionAnalysis::IsLoopInvariant(const Loop* loop,
377 const SENode* node) const {
378 for (auto itr = node->graph_cbegin(); itr != node->graph_cend(); ++itr) {
379 if (const SERecurrentNode* rec = itr->AsSERecurrentNode()) {
380 const BasicBlock* header = rec->GetLoop()->GetHeaderBlock();
381
382 // If the loop which the recurrent expression belongs to is either |loop
383 // or a nested loop inside |loop| then we assume it is variant.
384 if (loop->IsInsideLoop(header)) {
385 return false;
386 }
387 } else if (const SEValueUnknown* unknown = itr->AsSEValueUnknown()) {
388 // If the instruction is inside the loop we conservatively assume it is
389 // loop variant.
390 if (loop->IsInsideLoop(unknown->ResultId())) return false;
391 }
392 }
393
394 return true;
395}
396
397SENode* ScalarEvolutionAnalysis::GetCoefficientFromRecurrentTerm(
398 SENode* node, const Loop* loop) {
399 // Traverse the DAG to find the recurrent expression belonging to |loop|.
400 for (auto itr = node->graph_begin(); itr != node->graph_end(); ++itr) {
401 SERecurrentNode* rec = itr->AsSERecurrentNode();
402 if (rec && rec->GetLoop() == loop) {
403 return rec->GetCoefficient();
404 }
405 }
406 return CreateConstant(0);
407}
408
409SENode* ScalarEvolutionAnalysis::UpdateChildNode(SENode* parent,
410 SENode* old_child,
411 SENode* new_child) {
412 // Only handles add.
413 if (parent->GetType() != SENode::Add) return parent;
414
415 std::vector<SENode*> new_children;
416 for (SENode* child : *parent) {
417 if (child == old_child) {
418 new_children.push_back(new_child);
419 } else {
420 new_children.push_back(child);
421 }
422 }
423
424 std::unique_ptr<SENode> add_node{new SEAddNode(this)};
425 for (SENode* child : new_children) {
426 add_node->AddChild(child);
427 }
428
429 return SimplifyExpression(GetCachedOrAdd(std::move(add_node)));
430}
431
432// Rebuild the |node| eliminating, if it exists, the recurrent term which
433// belongs to the |loop|.
434SENode* ScalarEvolutionAnalysis::BuildGraphWithoutRecurrentTerm(
435 SENode* node, const Loop* loop) {
436 // If the node is already a recurrent expression belonging to loop then just
437 // return the offset.
438 SERecurrentNode* recurrent = node->AsSERecurrentNode();
439 if (recurrent) {
440 if (recurrent->GetLoop() == loop) {
441 return recurrent->GetOffset();
442 } else {
443 return node;
444 }
445 }
446
447 std::vector<SENode*> new_children;
448 // Otherwise find the recurrent node in the children of this node.
449 for (auto itr : *node) {
450 recurrent = itr->AsSERecurrentNode();
451 if (recurrent && recurrent->GetLoop() == loop) {
452 new_children.push_back(recurrent->GetOffset());
453 } else {
454 new_children.push_back(itr);
455 }
456 }
457
458 std::unique_ptr<SENode> add_node{new SEAddNode(this)};
459 for (SENode* child : new_children) {
460 add_node->AddChild(child);
461 }
462
463 return SimplifyExpression(GetCachedOrAdd(std::move(add_node)));
464}
465
466// Return the recurrent term belonging to |loop| if it appears in the graph
467// starting at |node| or null if it doesn't.
468SERecurrentNode* ScalarEvolutionAnalysis::GetRecurrentTerm(SENode* node,
469 const Loop* loop) {
470 for (auto itr = node->graph_begin(); itr != node->graph_end(); ++itr) {
471 SERecurrentNode* rec = itr->AsSERecurrentNode();
472 if (rec && rec->GetLoop() == loop) {
473 return rec;
474 }
475 }
476 return nullptr;
477}
478std::string SENode::AsString() const {
479 switch (GetType()) {
480 case Constant:
481 return "Constant";
482 case RecurrentAddExpr:
483 return "RecurrentAddExpr";
484 case Add:
485 return "Add";
486 case Negative:
487 return "Negative";
488 case Multiply:
489 return "Multiply";
490 case ValueUnknown:
491 return "Value Unknown";
492 case CanNotCompute:
493 return "Can not compute";
494 }
495 return "NULL";
496}
497
498bool SENode::operator==(const SENode& other) const {
499 if (GetType() != other.GetType()) return false;
500
501 if (other.GetChildren().size() != children_.size()) return false;
502
503 const SERecurrentNode* this_as_recurrent = AsSERecurrentNode();
504
505 // Check the children are the same, for SERecurrentNodes we need to check the
506 // offset and coefficient manually as the child vector is sorted by ids so the
507 // offset/coefficient information is lost.
508 if (!this_as_recurrent) {
509 for (size_t index = 0; index < children_.size(); ++index) {
510 if (other.GetChildren()[index] != children_[index]) return false;
511 }
512 } else {
513 const SERecurrentNode* other_as_recurrent = other.AsSERecurrentNode();
514
515 // We've already checked the types are the same, this should not fail if
516 // this->AsSERecurrentNode() succeeded.
517 assert(other_as_recurrent);
518
519 if (this_as_recurrent->GetCoefficient() !=
520 other_as_recurrent->GetCoefficient())
521 return false;
522
523 if (this_as_recurrent->GetOffset() != other_as_recurrent->GetOffset())
524 return false;
525
526 if (this_as_recurrent->GetLoop() != other_as_recurrent->GetLoop())
527 return false;
528 }
529
530 // If we're dealing with a value unknown node check both nodes were created by
531 // the same instruction.
532 if (GetType() == SENode::ValueUnknown) {
533 if (AsSEValueUnknown()->ResultId() !=
534 other.AsSEValueUnknown()->ResultId()) {
535 return false;
536 }
537 }
538
539 if (AsSEConstantNode()) {
540 if (AsSEConstantNode()->FoldToSingleValue() !=
541 other.AsSEConstantNode()->FoldToSingleValue())
542 return false;
543 }
544
545 return true;
546}
547
548bool SENode::operator!=(const SENode& other) const { return !(*this == other); }
549
namespace {
// Helpers that append 32/64 bit values to the 32 bit hash string. Pointers
// can be added by first reinterpreting them as uintptr_t. PushToString
// deduces the value's type and uses sizeof(T) to select the matching
// PushToStringImpl specialization.

template <typename T, size_t size_of_t>
struct PushToStringImpl;

// 8-byte values are appended as two 32-bit code units, high half first.
template <typename T>
struct PushToStringImpl<T, 8> {
  void operator()(T id, std::u32string* str) {
    // Convert to unsigned before shifting. Right-shifting a negative signed
    // value (e.g. a negative int64_t constant) is implementation-defined
    // before C++20.
    const uint64_t bits = static_cast<uint64_t>(id);
    str->push_back(static_cast<char32_t>(bits >> 32));
    str->push_back(static_cast<char32_t>(bits & 0xFFFFFFFFu));
  }
};

// 4-byte values are appended as a single 32-bit code unit.
template <typename T>
struct PushToStringImpl<T, 4> {
  void operator()(T id, std::u32string* str) {
    str->push_back(static_cast<char32_t>(static_cast<uint32_t>(id)));
  }
};

template <typename T>
void PushToString(T id, std::u32string* str) {
  PushToStringImpl<T, sizeof(T)>{}(id, str);
}

}  // namespace
581
582// Implements the hashing of SENodes.
583size_t SENodeHash::operator()(const SENode* node) const {
584 // Concatinate the terms into a string which we can hash.
585 std::u32string hash_string{};
586
587 // Hashing the type as a string is safer than hashing the enum as the enum is
588 // very likely to collide with constants.
589 for (char ch : node->AsString()) {
590 hash_string.push_back(static_cast<char32_t>(ch));
591 }
592
593 // We just ignore the literal value unless it is a constant.
594 if (node->GetType() == SENode::Constant)
595 PushToString(node->AsSEConstantNode()->FoldToSingleValue(), &hash_string);
596
597 const SERecurrentNode* recurrent = node->AsSERecurrentNode();
598
599 // If we're dealing with a recurrent expression hash the loop as well so that
600 // nested inductions like i=0,i++ and j=0,j++ correspond to different nodes.
601 if (recurrent) {
602 PushToString(reinterpret_cast<uintptr_t>(recurrent->GetLoop()),
603 &hash_string);
604
605 // Recurrent expressions can't be hashed using the normal method as the
606 // order of coefficient and offset matters to the hash.
607 PushToString(reinterpret_cast<uintptr_t>(recurrent->GetCoefficient()),
608 &hash_string);
609 PushToString(reinterpret_cast<uintptr_t>(recurrent->GetOffset()),
610 &hash_string);
611
612 return std::hash<std::u32string>{}(hash_string);
613 }
614
615 // Hash the result id of the original instruction which created this node if
616 // it is a value unknown node.
617 if (node->GetType() == SENode::ValueUnknown) {
618 PushToString(node->AsSEValueUnknown()->ResultId(), &hash_string);
619 }
620
621 // Hash the pointers of the child nodes, each SENode has a unique pointer
622 // associated with it.
623 const std::vector<SENode*>& children = node->GetChildren();
624 for (const SENode* child : children) {
625 PushToString(reinterpret_cast<uintptr_t>(child), &hash_string);
626 }
627
628 return std::hash<std::u32string>{}(hash_string);
629}
630
631// This overload is the actual overload used by the node_cache_ set.
632size_t SENodeHash::operator()(const std::unique_ptr<SENode>& node) const {
633 return this->operator()(node.get());
634}
635
636void SENode::DumpDot(std::ostream& out, bool recurse) const {
637 size_t unique_id = std::hash<const SENode*>{}(this);
638 out << unique_id << " [label=\"" << AsString() << " ";
639 if (GetType() == SENode::Constant) {
640 out << "\nwith value: " << this->AsSEConstantNode()->FoldToSingleValue();
641 }
642 out << "\"]\n";
643 for (const SENode* child : children_) {
644 size_t child_unique_id = std::hash<const SENode*>{}(child);
645 out << unique_id << " -> " << child_unique_id << " \n";
646 if (recurse) child->DumpDot(out, true);
647 }
648}
649
650namespace {
651class IsGreaterThanZero {
652 public:
653 explicit IsGreaterThanZero(IRContext* context) : context_(context) {}
654
655 // Determine if the value of |node| is always strictly greater than zero if
656 // |or_equal_zero| is false or greater or equal to zero if |or_equal_zero| is
657 // true. It returns true is the evaluation was able to conclude something, in
658 // which case the result is stored in |result|.
659 // The algorithm work by going through all the nodes and determine the
660 // sign of each of them.
661 bool Eval(const SENode* node, bool or_equal_zero, bool* result) {
662 *result = false;
663 switch (Visit(node)) {
664 case Signedness::kPositiveOrNegative: {
665 return false;
666 }
667 case Signedness::kStrictlyNegative: {
668 *result = false;
669 break;
670 }
671 case Signedness::kNegative: {
672 if (!or_equal_zero) {
673 return false;
674 }
675 *result = false;
676 break;
677 }
678 case Signedness::kStrictlyPositive: {
679 *result = true;
680 break;
681 }
682 case Signedness::kPositive: {
683 if (!or_equal_zero) {
684 return false;
685 }
686 *result = true;
687 break;
688 }
689 }
690 return true;
691 }
692
693 private:
694 enum class Signedness {
695 kPositiveOrNegative, // Yield a value positive or negative.
696 kStrictlyNegative, // Yield a value strictly less than 0.
697 kNegative, // Yield a value less or equal to 0.
698 kStrictlyPositive, // Yield a value strictly greater than 0.
699 kPositive // Yield a value greater or equal to 0.
700 };
701
702 // Combine the signedness according to arithmetic rules of a given operator.
703 using Combiner = std::function<Signedness(Signedness, Signedness)>;
704
705 // Returns a functor to interpret the signedness of 2 expressions as if they
706 // were added.
707 Combiner GetAddCombiner() const {
708 return [](Signedness lhs, Signedness rhs) {
709 switch (lhs) {
710 case Signedness::kPositiveOrNegative:
711 break;
712 case Signedness::kStrictlyNegative:
713 if (rhs == Signedness::kStrictlyNegative ||
714 rhs == Signedness::kNegative)
715 return lhs;
716 break;
717 case Signedness::kNegative: {
718 if (rhs == Signedness::kStrictlyNegative)
719 return Signedness::kStrictlyNegative;
720 if (rhs == Signedness::kNegative) return Signedness::kNegative;
721 break;
722 }
723 case Signedness::kStrictlyPositive: {
724 if (rhs == Signedness::kStrictlyPositive ||
725 rhs == Signedness::kPositive) {
726 return Signedness::kStrictlyPositive;
727 }
728 break;
729 }
730 case Signedness::kPositive: {
731 if (rhs == Signedness::kStrictlyPositive)
732 return Signedness::kStrictlyPositive;
733 if (rhs == Signedness::kPositive) return Signedness::kPositive;
734 break;
735 }
736 }
737 return Signedness::kPositiveOrNegative;
738 };
739 }
740
741 // Returns a functor to interpret the signedness of 2 expressions as if they
742 // were multiplied.
743 Combiner GetMulCombiner() const {
744 return [](Signedness lhs, Signedness rhs) {
745 switch (lhs) {
746 case Signedness::kPositiveOrNegative:
747 break;
748 case Signedness::kStrictlyNegative: {
749 switch (rhs) {
750 case Signedness::kPositiveOrNegative: {
751 break;
752 }
753 case Signedness::kStrictlyNegative: {
754 return Signedness::kStrictlyPositive;
755 }
756 case Signedness::kNegative: {
757 return Signedness::kPositive;
758 }
759 case Signedness::kStrictlyPositive: {
760 return Signedness::kStrictlyNegative;
761 }
762 case Signedness::kPositive: {
763 return Signedness::kNegative;
764 }
765 }
766 break;
767 }
768 case Signedness::kNegative: {
769 switch (rhs) {
770 case Signedness::kPositiveOrNegative: {
771 break;
772 }
773 case Signedness::kStrictlyNegative:
774 case Signedness::kNegative: {
775 return Signedness::kPositive;
776 }
777 case Signedness::kStrictlyPositive:
778 case Signedness::kPositive: {
779 return Signedness::kNegative;
780 }
781 }
782 break;
783 }
784 case Signedness::kStrictlyPositive: {
785 return rhs;
786 }
787 case Signedness::kPositive: {
788 switch (rhs) {
789 case Signedness::kPositiveOrNegative: {
790 break;
791 }
792 case Signedness::kStrictlyNegative:
793 case Signedness::kNegative: {
794 return Signedness::kNegative;
795 }
796 case Signedness::kStrictlyPositive:
797 case Signedness::kPositive: {
798 return Signedness::kPositive;
799 }
800 }
801 break;
802 }
803 }
804 return Signedness::kPositiveOrNegative;
805 };
806 }
807
808 Signedness Visit(const SENode* node) {
809 switch (node->GetType()) {
810 case SENode::Constant:
811 return Visit(node->AsSEConstantNode());
812 break;
813 case SENode::RecurrentAddExpr:
814 return Visit(node->AsSERecurrentNode());
815 break;
816 case SENode::Negative:
817 return Visit(node->AsSENegative());
818 break;
819 case SENode::CanNotCompute:
820 return Visit(node->AsSECantCompute());
821 break;
822 case SENode::ValueUnknown:
823 return Visit(node->AsSEValueUnknown());
824 break;
825 case SENode::Add:
826 return VisitExpr(node, GetAddCombiner());
827 break;
828 case SENode::Multiply:
829 return VisitExpr(node, GetMulCombiner());
830 break;
831 }
832 return Signedness::kPositiveOrNegative;
833 }
834
835 // Returns the signedness of a constant |node|.
836 Signedness Visit(const SEConstantNode* node) {
837 if (0 == node->FoldToSingleValue()) return Signedness::kPositive;
838 if (0 < node->FoldToSingleValue()) return Signedness::kStrictlyPositive;
839 if (0 > node->FoldToSingleValue()) return Signedness::kStrictlyNegative;
840 return Signedness::kPositiveOrNegative;
841 }
842
843 // Returns the signedness of an unknown |node| based on its type.
844 Signedness Visit(const SEValueUnknown* node) {
845 Instruction* insn = context_->get_def_use_mgr()->GetDef(node->ResultId());
846 analysis::Type* type = context_->get_type_mgr()->GetType(insn->type_id());
847 assert(type && "Can't retrieve a type for the instruction");
848 analysis::Integer* int_type = type->AsInteger();
849 assert(type && "Can't retrieve an integer type for the instruction");
850 return int_type->IsSigned() ? Signedness::kPositiveOrNegative
851 : Signedness::kPositive;
852 }
853
854 // Returns the signedness of a recurring expression.
855 Signedness Visit(const SERecurrentNode* node) {
856 Signedness coeff_sign = Visit(node->GetCoefficient());
857 // SERecurrentNode represent an affine expression in the range [0,
858 // loop_bound], so the result cannot be strictly positive or negative.
859 switch (coeff_sign) {
860 default:
861 break;
862 case Signedness::kStrictlyNegative:
863 coeff_sign = Signedness::kNegative;
864 break;
865 case Signedness::kStrictlyPositive:
866 coeff_sign = Signedness::kPositive;
867 break;
868 }
869 return GetAddCombiner()(coeff_sign, Visit(node->GetOffset()));
870 }
871
872 // Returns the signedness of a negation |node|.
873 Signedness Visit(const SENegative* node) {
874 switch (Visit(*node->begin())) {
875 case Signedness::kPositiveOrNegative: {
876 return Signedness::kPositiveOrNegative;
877 }
878 case Signedness::kStrictlyNegative: {
879 return Signedness::kStrictlyPositive;
880 }
881 case Signedness::kNegative: {
882 return Signedness::kPositive;
883 }
884 case Signedness::kStrictlyPositive: {
885 return Signedness::kStrictlyNegative;
886 }
887 case Signedness::kPositive: {
888 return Signedness::kNegative;
889 }
890 }
891 return Signedness::kPositiveOrNegative;
892 }
893
894 Signedness Visit(const SECantCompute*) {
895 return Signedness::kPositiveOrNegative;
896 }
897
898 // Returns the signedness of a binary expression by using the combiner
899 // |reduce|.
900 Signedness VisitExpr(
901 const SENode* node,
902 std::function<Signedness(Signedness, Signedness)> reduce) {
903 Signedness result = Visit(*node->begin());
904 for (const SENode* operand : make_range(++node->begin(), node->end())) {
905 if (result == Signedness::kPositiveOrNegative) {
906 return Signedness::kPositiveOrNegative;
907 }
908 result = reduce(result, Visit(operand));
909 }
910 return result;
911 }
912
913 IRContext* context_;
914};
915} // namespace
916
917bool ScalarEvolutionAnalysis::IsAlwaysGreaterThanZero(SENode* node,
918 bool* is_gt_zero) const {
919 return IsGreaterThanZero(context_).Eval(node, false, is_gt_zero);
920}
921
922bool ScalarEvolutionAnalysis::IsAlwaysGreaterOrEqualToZero(
923 SENode* node, bool* is_ge_zero) const {
924 return IsGreaterThanZero(context_).Eval(node, true, is_ge_zero);
925}
926
927namespace {
928
929// Remove |node| from the |mul| chain (of the form A * ... * |node| * ... * Z),
930// if |node| is not in the chain, returns the original chain.
931static SENode* RemoveOneNodeFromMultiplyChain(SEMultiplyNode* mul,
932 const SENode* node) {
933 SENode* lhs = mul->GetChildren()[0];
934 SENode* rhs = mul->GetChildren()[1];
935 if (lhs == node) {
936 return rhs;
937 }
938 if (rhs == node) {
939 return lhs;
940 }
941 if (lhs->AsSEMultiplyNode()) {
942 SENode* res = RemoveOneNodeFromMultiplyChain(lhs->AsSEMultiplyNode(), node);
943 if (res != lhs)
944 return mul->GetParentAnalysis()->CreateMultiplyNode(res, rhs);
945 }
946 if (rhs->AsSEMultiplyNode()) {
947 SENode* res = RemoveOneNodeFromMultiplyChain(rhs->AsSEMultiplyNode(), node);
948 if (res != rhs)
949 return mul->GetParentAnalysis()->CreateMultiplyNode(res, rhs);
950 }
951
952 return mul;
953}
954} // namespace
955
956std::pair<SExpression, int64_t> SExpression::operator/(
957 SExpression rhs_wrapper) const {
958 SENode* lhs = node_;
959 SENode* rhs = rhs_wrapper.node_;
960 // Check for division by 0.
961 if (rhs->AsSEConstantNode() &&
962 !rhs->AsSEConstantNode()->FoldToSingleValue()) {
963 return {scev_->CreateCantComputeNode(), 0};
964 }
965
966 // Trivial case.
967 if (lhs->AsSEConstantNode() && rhs->AsSEConstantNode()) {
968 int64_t lhs_value = lhs->AsSEConstantNode()->FoldToSingleValue();
969 int64_t rhs_value = rhs->AsSEConstantNode()->FoldToSingleValue();
970 return {scev_->CreateConstant(lhs_value / rhs_value),
971 lhs_value % rhs_value};
972 }
973
974 // look for a "c U / U" pattern.
975 if (lhs->AsSEMultiplyNode()) {
976 assert(lhs->GetChildren().size() == 2 &&
977 "More than 2 operand for a multiply node.");
978 SENode* res = RemoveOneNodeFromMultiplyChain(lhs->AsSEMultiplyNode(), rhs);
979 if (res != lhs) {
980 return {res, 0};
981 }
982 }
983
984 return {scev_->CreateCantComputeNode(), 0};
985}
986
987} // namespace opt
988} // namespace spvtools
989