1// Copyright (c) 2018 Google LLC.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "source/opt/scalar_analysis.h"
16
17#include <functional>
18#include <map>
19#include <memory>
20#include <set>
21#include <unordered_set>
22#include <utility>
23#include <vector>
24
25// Simplifies scalar analysis DAGs.
26//
27// 1. Given a node passed to SimplifyExpression we first simplify the graph by
28// calling SimplifyPolynomial. This groups like nodes following basic arithmetic
29// rules, so multiple adds of the same load instruction could be grouped into a
30// single multiply of that instruction. SimplifyPolynomial will traverse the DAG
31// and build up an accumulator buffer for each class of instruction it finds.
32// For example take the loop:
// for (i=0; i<N; i++) { i+B+23+4+B+C; }
34// In this example the expression "i+B+23+4+B+C" has four classes of
35// instruction, induction variable i, the two value unknowns B and C, and the
36// constants. The accumulator buffer is then used to rebuild the graph using
37// the accumulation of each type. This example would then be folded into
38// i+2*B+C+27.
39//
40// This new graph contains a single add node (or if only one type found then
41// just that node) with each of the like terms (or multiplication node) as a
42// child.
43//
44// 2. FoldRecurrentAddExpressions is then called on this new DAG. This will take
45// RecurrentAddExpressions which are with respect to the same loop and fold them
46// into a single new RecurrentAddExpression with respect to that same loop. An
// expression can have multiple RecurrentAddExpressions with respect to
48// different loops in the case of nested loops. These expressions cannot be
49// folded further. For example:
50//
51// for (i=0; i<N;i++) for(j=0,k=1; j<N;++j,++k)
52//
53// The 'j' and 'k' are RecurrentAddExpression with respect to the second loop
54// and 'i' to the first. If 'j' and 'k' are used in an expression together then
55// they will be folded into a new RecurrentAddExpression with respect to the
56// second loop in that expression.
57//
58//
59// 3. If the DAG now only contains a single RecurrentAddExpression we can now
60// perform a final optimization SimplifyRecurrentAddExpression. This will
61// transform the entire DAG into a RecurrentAddExpression. Additions to the
62// RecurrentAddExpression are added to the offset field and multiplications to
63// the coefficient.
64//
65
66namespace spvtools {
67namespace opt {
68
69// Implementation of the functions which are used to simplify the graph. Graphs
70// of unknowns, multiplies, additions, and constants can be turned into a linear
71// add node with each term as a child. For instance a large graph built from, X
72// + X*2 + Y - Y*3 + 4 - 1, would become a single add expression with the
73// children X*3, -Y*2, and the constant 3. Graphs containing a recurrent
74// expression will be simplified to represent the entire graph around a single
75// recurrent expression. So for an induction variable (i=0, i++) if you add 1 to
76// i in an expression we can rewrite the graph of that expression to be a single
77// recurrent expression of (i=1,i++).
78class SENodeSimplifyImpl {
79 public:
80 SENodeSimplifyImpl(ScalarEvolutionAnalysis* analysis,
81 SENode* node_to_simplify)
82 : analysis_(*analysis),
83 node_(node_to_simplify),
84 constant_accumulator_(0) {}
85
86 // Return the result of the simplification.
87 SENode* Simplify();
88
89 private:
90 // Recursively descend through the graph to build up the accumulator objects
91 // which are used to flatten the graph. |child| is the node currenty being
92 // traversed and the |negation| flag is used to signify that this operation
93 // was preceded by a unary negative operation and as such the result should be
94 // negated.
95 void GatherAccumulatorsFromChildNodes(SENode* new_node, SENode* child,
96 bool negation);
97
98 // Given a |multiply| node add to the accumulators for the term type within
99 // the |multiply| expression. Will return true if the accumulators could be
100 // calculated successfully. If the |multiply| is in any form other than
101 // unknown*constant then we return false. |negation| signifies that the
102 // operation was preceded by a unary negative.
103 bool AccumulatorsFromMultiply(SENode* multiply, bool negation);
104
105 SERecurrentNode* UpdateCoefficient(SERecurrentNode* recurrent,
106 int64_t coefficient_update) const;
107
108 // If the graph contains a recurrent expression, ie, an expression with the
109 // loop iterations as a term in the expression, then the whole expression
110 // can be rewritten to be a recurrent expression.
111 SENode* SimplifyRecurrentAddExpression(SERecurrentNode* node);
112
113 // Simplify the whole graph by linking like terms together in a single flat
114 // add node. So X*2 + Y -Y + 3 +6 would become X*2 + 9. Where X and Y are a
115 // ValueUnknown node (i.e, a load) or a recurrent expression.
116 SENode* SimplifyPolynomial();
117
118 // Each recurrent expression is an expression with respect to a specific loop.
119 // If we have two different recurrent terms with respect to the same loop in a
120 // single expression then we can fold those terms into a single new term.
121 // For instance:
122 //
123 // induction i = 0, i++
124 // temp = i*10
125 // array[i+temp]
126 //
127 // We can fold the i + temp into a single expression. Rec(0,1) + Rec(0,10) can
128 // become Rec(0,11).
129 SENode* FoldRecurrentAddExpressions(SENode*);
130
131 // We can eliminate recurrent expressions which have a coefficient of zero by
132 // replacing them with their offset value. We are able to do this because a
133 // recurrent expression represents the equation coefficient*iterations +
134 // offset.
135 SENode* EliminateZeroCoefficientRecurrents(SENode* node);
136
137 // A reference the the analysis which requested the simplification.
138 ScalarEvolutionAnalysis& analysis_;
139
140 // The node being simplified.
141 SENode* node_;
142
143 // An accumulator of the net result of all the constant operations performed
144 // in a graph.
145 int64_t constant_accumulator_;
146
147 // An accumulator for each of the non constant terms in the graph.
148 std::map<SENode*, int64_t> accumulators_;
149};
150
151// From a |multiply| build up the accumulator objects.
152bool SENodeSimplifyImpl::AccumulatorsFromMultiply(SENode* multiply,
153 bool negation) {
154 if (multiply->GetChildren().size() != 2 ||
155 multiply->GetType() != SENode::Multiply)
156 return false;
157
158 SENode* operand_1 = multiply->GetChild(0);
159 SENode* operand_2 = multiply->GetChild(1);
160
161 SENode* value_unknown = nullptr;
162 SENode* constant = nullptr;
163
164 // Work out which operand is the unknown value.
165 if (operand_1->GetType() == SENode::ValueUnknown ||
166 operand_1->GetType() == SENode::RecurrentAddExpr)
167 value_unknown = operand_1;
168 else if (operand_2->GetType() == SENode::ValueUnknown ||
169 operand_2->GetType() == SENode::RecurrentAddExpr)
170 value_unknown = operand_2;
171
172 // Work out which operand is the constant coefficient.
173 if (operand_1->GetType() == SENode::Constant)
174 constant = operand_1;
175 else if (operand_2->GetType() == SENode::Constant)
176 constant = operand_2;
177
178 // If the expression is not a variable multiplied by a constant coefficient,
179 // exit out.
180 if (!(value_unknown && constant)) {
181 return false;
182 }
183
184 int64_t sign = negation ? -1 : 1;
185
186 auto iterator = accumulators_.find(value_unknown);
187 int64_t new_value = constant->AsSEConstantNode()->FoldToSingleValue() * sign;
188 // Add the result of the multiplication to the accumulators.
189 if (iterator != accumulators_.end()) {
190 (*iterator).second += new_value;
191 } else {
192 accumulators_.insert({value_unknown, new_value});
193 }
194
195 return true;
196}
197
198SENode* SENodeSimplifyImpl::Simplify() {
199 // We only handle graphs with an addition, multiplication, or negation, at the
200 // root.
201 if (node_->GetType() != SENode::Add && node_->GetType() != SENode::Multiply &&
202 node_->GetType() != SENode::Negative)
203 return node_;
204
205 SENode* simplified_polynomial = SimplifyPolynomial();
206
207 SERecurrentNode* recurrent_expr = nullptr;
208 node_ = simplified_polynomial;
209
210 // Fold recurrent expressions which are with respect to the same loop into a
211 // single recurrent expression.
212 simplified_polynomial = FoldRecurrentAddExpressions(simplified_polynomial);
213
214 simplified_polynomial =
215 EliminateZeroCoefficientRecurrents(simplified_polynomial);
216
217 // Traverse the immediate children of the new node to find the recurrent
218 // expression. If there is more than one there is nothing further we can do.
219 for (SENode* child : simplified_polynomial->GetChildren()) {
220 if (child->GetType() == SENode::RecurrentAddExpr) {
221 recurrent_expr = child->AsSERecurrentNode();
222 }
223 }
224
225 // We need to count the number of unique recurrent expressions in the DAG to
226 // ensure there is only one.
227 for (auto child_iterator = simplified_polynomial->graph_begin();
228 child_iterator != simplified_polynomial->graph_end(); ++child_iterator) {
229 if (child_iterator->GetType() == SENode::RecurrentAddExpr &&
230 recurrent_expr != child_iterator->AsSERecurrentNode()) {
231 return simplified_polynomial;
232 }
233 }
234
235 if (recurrent_expr) {
236 return SimplifyRecurrentAddExpression(recurrent_expr);
237 }
238
239 return simplified_polynomial;
240}
241
242// Traverse the graph to build up the accumulator objects.
243void SENodeSimplifyImpl::GatherAccumulatorsFromChildNodes(SENode* new_node,
244 SENode* child,
245 bool negation) {
246 int32_t sign = negation ? -1 : 1;
247
248 if (child->GetType() == SENode::Constant) {
249 // Collect all the constants and add them together.
250 constant_accumulator_ +=
251 child->AsSEConstantNode()->FoldToSingleValue() * sign;
252
253 } else if (child->GetType() == SENode::ValueUnknown ||
254 child->GetType() == SENode::RecurrentAddExpr) {
255 // To rebuild the graph of X+X+X*2 into 4*X we count the occurrences of X
256 // and create a new node of count*X after. X can either be a ValueUnknown or
257 // a RecurrentAddExpr. The count for each X is stored in the accumulators_
258 // map.
259
260 auto iterator = accumulators_.find(child);
261 // If we've encountered this term before add to the accumulator for it.
262 if (iterator == accumulators_.end())
263 accumulators_.insert({child, sign});
264 else
265 iterator->second += sign;
266
267 } else if (child->GetType() == SENode::Multiply) {
268 if (!AccumulatorsFromMultiply(child, negation)) {
269 new_node->AddChild(child);
270 }
271
272 } else if (child->GetType() == SENode::Add) {
273 for (SENode* next_child : *child) {
274 GatherAccumulatorsFromChildNodes(new_node, next_child, negation);
275 }
276
277 } else if (child->GetType() == SENode::Negative) {
278 SENode* negated_node = child->GetChild(0);
279 GatherAccumulatorsFromChildNodes(new_node, negated_node, !negation);
280 } else {
281 // If we can't work out how to fold the expression just add it back into
282 // the graph.
283 new_node->AddChild(child);
284 }
285}
286
287SERecurrentNode* SENodeSimplifyImpl::UpdateCoefficient(
288 SERecurrentNode* recurrent, int64_t coefficient_update) const {
289 std::unique_ptr<SERecurrentNode> new_recurrent_node{new SERecurrentNode(
290 recurrent->GetParentAnalysis(), recurrent->GetLoop())};
291
292 SENode* new_coefficient = analysis_.CreateMultiplyNode(
293 recurrent->GetCoefficient(),
294 analysis_.CreateConstant(coefficient_update));
295
296 // See if the node can be simplified.
297 SENode* simplified = analysis_.SimplifyExpression(new_coefficient);
298 if (simplified->GetType() != SENode::CanNotCompute)
299 new_coefficient = simplified;
300
301 if (coefficient_update < 0) {
302 new_recurrent_node->AddOffset(
303 analysis_.CreateNegation(recurrent->GetOffset()));
304 } else {
305 new_recurrent_node->AddOffset(recurrent->GetOffset());
306 }
307
308 new_recurrent_node->AddCoefficient(new_coefficient);
309
310 return analysis_.GetCachedOrAdd(std::move(new_recurrent_node))
311 ->AsSERecurrentNode();
312}
313
314// Simplify all the terms in the polynomial function.
315SENode* SENodeSimplifyImpl::SimplifyPolynomial() {
316 std::unique_ptr<SENode> new_add{new SEAddNode(node_->GetParentAnalysis())};
317
318 // Traverse the graph and gather the accumulators from it.
319 GatherAccumulatorsFromChildNodes(new_add.get(), node_, false);
320
321 // Fold all the constants into a single constant node.
322 if (constant_accumulator_ != 0) {
323 new_add->AddChild(analysis_.CreateConstant(constant_accumulator_));
324 }
325
326 for (auto& pair : accumulators_) {
327 SENode* term = pair.first;
328 int64_t count = pair.second;
329
330 // We can eliminate the term completely.
331 if (count == 0) continue;
332
333 if (count == 1) {
334 new_add->AddChild(term);
335 } else if (count == -1 && term->GetType() != SENode::RecurrentAddExpr) {
336 // If the count is -1 we can just add a negative version of that node,
337 // unless it is a recurrent expression as we would rather the negative
338 // goes on the recurrent expressions children. This makes it easier to
339 // work with in other places.
340 new_add->AddChild(analysis_.CreateNegation(term));
341 } else {
342 // Output value unknown terms as count*term and output recurrent
343 // expression terms as rec(offset, coefficient + count) offset and
344 // coefficient are the same as in the original expression.
345 if (term->GetType() == SENode::ValueUnknown) {
346 SENode* count_as_constant = analysis_.CreateConstant(count);
347 new_add->AddChild(
348 analysis_.CreateMultiplyNode(count_as_constant, term));
349 } else {
350 assert(term->GetType() == SENode::RecurrentAddExpr &&
351 "We only handle value unknowns or recurrent expressions");
352
353 // Create a new recurrent expression by adding the count to the
354 // coefficient of the old one.
355 new_add->AddChild(UpdateCoefficient(term->AsSERecurrentNode(), count));
356 }
357 }
358 }
359
360 // If there is only one term in the addition left just return that term.
361 if (new_add->GetChildren().size() == 1) {
362 return new_add->GetChild(0);
363 }
364
365 // If there are no terms left in the addition just return 0.
366 if (new_add->GetChildren().size() == 0) {
367 return analysis_.CreateConstant(0);
368 }
369
370 return analysis_.GetCachedOrAdd(std::move(new_add));
371}
372
373SENode* SENodeSimplifyImpl::FoldRecurrentAddExpressions(SENode* root) {
374 std::unique_ptr<SEAddNode> new_node{new SEAddNode(&analysis_)};
375
376 // A mapping of loops to the list of recurrent expressions which are with
377 // respect to those loops.
378 std::map<const Loop*, std::vector<std::pair<SERecurrentNode*, bool>>>
379 loops_to_recurrent{};
380
381 bool has_multiple_same_loop_recurrent_terms = false;
382
383 for (SENode* child : *root) {
384 bool negation = false;
385
386 if (child->GetType() == SENode::Negative) {
387 child = child->GetChild(0);
388 negation = true;
389 }
390
391 if (child->GetType() == SENode::RecurrentAddExpr) {
392 const Loop* loop = child->AsSERecurrentNode()->GetLoop();
393
394 SERecurrentNode* rec = child->AsSERecurrentNode();
395 if (loops_to_recurrent.find(loop) == loops_to_recurrent.end()) {
396 loops_to_recurrent[loop] = {std::make_pair(rec, negation)};
397 } else {
398 loops_to_recurrent[loop].push_back(std::make_pair(rec, negation));
399 has_multiple_same_loop_recurrent_terms = true;
400 }
401 } else {
402 new_node->AddChild(child);
403 }
404 }
405
406 if (!has_multiple_same_loop_recurrent_terms) return root;
407
408 for (auto pair : loops_to_recurrent) {
409 std::vector<std::pair<SERecurrentNode*, bool>>& recurrent_expressions =
410 pair.second;
411 const Loop* loop = pair.first;
412
413 std::unique_ptr<SENode> new_coefficient{new SEAddNode(&analysis_)};
414 std::unique_ptr<SENode> new_offset{new SEAddNode(&analysis_)};
415
416 for (auto node_pair : recurrent_expressions) {
417 SERecurrentNode* node = node_pair.first;
418 bool negative = node_pair.second;
419
420 if (!negative) {
421 new_coefficient->AddChild(node->GetCoefficient());
422 new_offset->AddChild(node->GetOffset());
423 } else {
424 new_coefficient->AddChild(
425 analysis_.CreateNegation(node->GetCoefficient()));
426 new_offset->AddChild(analysis_.CreateNegation(node->GetOffset()));
427 }
428 }
429
430 std::unique_ptr<SERecurrentNode> new_recurrent{
431 new SERecurrentNode(&analysis_, loop)};
432
433 SENode* new_coefficient_simplified =
434 analysis_.SimplifyExpression(new_coefficient.get());
435
436 SENode* new_offset_simplified =
437 analysis_.SimplifyExpression(new_offset.get());
438
439 if (new_coefficient_simplified->GetType() == SENode::Constant &&
440 new_coefficient_simplified->AsSEConstantNode()->FoldToSingleValue() ==
441 0) {
442 return new_offset_simplified;
443 }
444
445 new_recurrent->AddCoefficient(new_coefficient_simplified);
446 new_recurrent->AddOffset(new_offset_simplified);
447
448 new_node->AddChild(analysis_.GetCachedOrAdd(std::move(new_recurrent)));
449 }
450
451 // If we only have one child in the add just return that.
452 if (new_node->GetChildren().size() == 1) {
453 return new_node->GetChild(0);
454 }
455
456 return analysis_.GetCachedOrAdd(std::move(new_node));
457}
458
459SENode* SENodeSimplifyImpl::EliminateZeroCoefficientRecurrents(SENode* node) {
460 if (node->GetType() != SENode::Add) return node;
461
462 bool has_change = false;
463
464 std::vector<SENode*> new_children{};
465 for (SENode* child : *node) {
466 if (child->GetType() == SENode::RecurrentAddExpr) {
467 SENode* coefficient = child->AsSERecurrentNode()->GetCoefficient();
468 // If coefficient is zero then we can eliminate the recurrent expression
469 // entirely and just return the offset as the recurrent expression is
470 // representing the equation coefficient*iterations + offset.
471 if (coefficient->GetType() == SENode::Constant &&
472 coefficient->AsSEConstantNode()->FoldToSingleValue() == 0) {
473 new_children.push_back(child->AsSERecurrentNode()->GetOffset());
474 has_change = true;
475 } else {
476 new_children.push_back(child);
477 }
478 } else {
479 new_children.push_back(child);
480 }
481 }
482
483 if (!has_change) return node;
484
485 std::unique_ptr<SENode> new_add{new SEAddNode(node_->GetParentAnalysis())};
486
487 for (SENode* child : new_children) {
488 new_add->AddChild(child);
489 }
490
491 return analysis_.GetCachedOrAdd(std::move(new_add));
492}
493
494SENode* SENodeSimplifyImpl::SimplifyRecurrentAddExpression(
495 SERecurrentNode* recurrent_expr) {
496 const std::vector<SENode*>& children = node_->GetChildren();
497
498 std::unique_ptr<SERecurrentNode> recurrent_node{new SERecurrentNode(
499 recurrent_expr->GetParentAnalysis(), recurrent_expr->GetLoop())};
500
501 // Create and simplify the new offset node.
502 std::unique_ptr<SENode> new_offset{
503 new SEAddNode(recurrent_expr->GetParentAnalysis())};
504 new_offset->AddChild(recurrent_expr->GetOffset());
505
506 for (SENode* child : children) {
507 if (child->GetType() != SENode::RecurrentAddExpr) {
508 new_offset->AddChild(child);
509 }
510 }
511
512 // Simplify the new offset.
513 SENode* simplified_child = analysis_.SimplifyExpression(new_offset.get());
514
515 // If the child can be simplified, add the simplified form otherwise, add it
516 // via the usual caching mechanism.
517 if (simplified_child->GetType() != SENode::CanNotCompute) {
518 recurrent_node->AddOffset(simplified_child);
519 } else {
520 recurrent_expr->AddOffset(analysis_.GetCachedOrAdd(std::move(new_offset)));
521 }
522
523 recurrent_node->AddCoefficient(recurrent_expr->GetCoefficient());
524
525 return analysis_.GetCachedOrAdd(std::move(recurrent_node));
526}
527
528/*
529 * Scalar Analysis simplification public methods.
530 */
531
532SENode* ScalarEvolutionAnalysis::SimplifyExpression(SENode* node) {
533 SENodeSimplifyImpl impl{this, node};
534
535 return impl.Simplify();
536}
537
538} // namespace opt
539} // namespace spvtools
540