1// Copyright (c) 2018 Google LLC.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#ifndef SOURCE_OPT_SCALAR_ANALYSIS_H_
16#define SOURCE_OPT_SCALAR_ANALYSIS_H_
17
#include <algorithm>
#include <cstdint>
#include <map>
#include <memory>
#include <type_traits>
#include <unordered_set>
#include <utility>
#include <vector>
25
26#include "source/opt/basic_block.h"
27#include "source/opt/instruction.h"
28#include "source/opt/scalar_analysis_nodes.h"
29
30namespace spvtools {
31namespace opt {
32
33class IRContext;
34class Loop;
35
36// Manager for the Scalar Evolution analysis. Creates and maintains a DAG of
37// scalar operations generated from analysing the use def graph from incoming
38// instructions. Each node is hashed as it is added so like node (for instance,
39// two induction variables i=0,i++ and j=0,j++) become the same node. After
40// creating a DAG with AnalyzeInstruction it can the be simplified into a more
41// usable form with SimplifyExpression.
42class ScalarEvolutionAnalysis {
43 public:
44 explicit ScalarEvolutionAnalysis(IRContext* context);
45
46 // Create a unary negative node on |operand|.
47 SENode* CreateNegation(SENode* operand);
48
49 // Creates a subtraction between the two operands by adding |operand_1| to the
50 // negation of |operand_2|.
51 SENode* CreateSubtraction(SENode* operand_1, SENode* operand_2);
52
53 // Create an addition node between two operands. The |simplify| when set will
54 // allow the function to return an SEConstant instead of an addition if the
55 // two input operands are also constant.
56 SENode* CreateAddNode(SENode* operand_1, SENode* operand_2);
57
58 // Create a multiply node between two operands.
59 SENode* CreateMultiplyNode(SENode* operand_1, SENode* operand_2);
60
61 // Create a node representing a constant integer.
62 SENode* CreateConstant(int64_t integer);
63
64 // Create a value unknown node, such as a load.
65 SENode* CreateValueUnknownNode(const Instruction* inst);
66
67 // Create a CantComputeNode. Used to exit out of analysis.
68 SENode* CreateCantComputeNode();
69
70 // Create a new recurrent node with |offset| and |coefficient|, with respect
71 // to |loop|.
72 SENode* CreateRecurrentExpression(const Loop* loop, SENode* offset,
73 SENode* coefficient);
74
75 // Construct the DAG by traversing use def chain of |inst|.
76 SENode* AnalyzeInstruction(const Instruction* inst);
77
78 // Simplify the |node| by grouping like terms or if contains a recurrent
79 // expression, rewrite the graph so the whole DAG (from |node| down) is in
80 // terms of that recurrent expression.
81 //
82 // For example.
83 // Induction variable i=0, i++ would produce Rec(0,1) so i+1 could be
84 // transformed into Rec(1,1).
85 //
86 // X+X*2+Y-Y+34-17 would be transformed into 3*X + 17, where X and Y are
87 // ValueUnknown nodes (such as a load instruction).
88 SENode* SimplifyExpression(SENode* node);
89
90 // Add |prospective_node| into the cache and return a raw pointer to it. If
91 // |prospective_node| is already in the cache just return the raw pointer.
92 SENode* GetCachedOrAdd(std::unique_ptr<SENode> prospective_node);
93
94 // Checks that the graph starting from |node| is invariant to the |loop|.
95 bool IsLoopInvariant(const Loop* loop, const SENode* node) const;
96
97 // Sets |is_gt_zero| to true if |node| represent a value always strictly
98 // greater than 0. The result of |is_gt_zero| is valid only if the function
99 // returns true.
100 bool IsAlwaysGreaterThanZero(SENode* node, bool* is_gt_zero) const;
101
102 // Sets |is_ge_zero| to true if |node| represent a value greater or equals to
103 // 0. The result of |is_ge_zero| is valid only if the function returns true.
104 bool IsAlwaysGreaterOrEqualToZero(SENode* node, bool* is_ge_zero) const;
105
106 // Find the recurrent term belonging to |loop| in the graph starting from
107 // |node| and return the coefficient of that recurrent term. Constant zero
108 // will be returned if no recurrent could be found. |node| should be in
109 // simplest form.
110 SENode* GetCoefficientFromRecurrentTerm(SENode* node, const Loop* loop);
111
112 // Return a rebuilt graph starting from |node| with the recurrent expression
113 // belonging to |loop| being zeroed out. Returned node will be simplified.
114 SENode* BuildGraphWithoutRecurrentTerm(SENode* node, const Loop* loop);
115
116 // Return the recurrent term belonging to |loop| if it appears in the graph
117 // starting at |node| or null if it doesn't.
118 SERecurrentNode* GetRecurrentTerm(SENode* node, const Loop* loop);
119
120 SENode* UpdateChildNode(SENode* parent, SENode* child, SENode* new_child);
121
122 // The loops in |loop_pair| will be considered the same when constructing
123 // SERecurrentNode objects. This enables analysing dependencies that will be
124 // created during loop fusion.
125 void AddLoopsToPretendAreTheSame(
126 const std::pair<const Loop*, const Loop*>& loop_pair) {
127 pretend_equal_[std::get<1>(loop_pair)] = std::get<0>(loop_pair);
128 }
129
130 private:
131 SENode* AnalyzeConstant(const Instruction* inst);
132
133 // Handles both addition and subtraction. If the |instruction| is OpISub
134 // then the resulting node will be op1+(-op2) otherwise if it is OpIAdd then
135 // the result will be op1+op2. |instruction| must be OpIAdd or OpISub.
136 SENode* AnalyzeAddOp(const Instruction* instruction);
137
138 SENode* AnalyzeMultiplyOp(const Instruction* multiply);
139
140 SENode* AnalyzePhiInstruction(const Instruction* phi);
141
142 IRContext* context_;
143
144 // A map of instructions to SENodes. This is used to track recurrent
145 // expressions as they are added when analyzing instructions. Recurrent
146 // expressions come from phi nodes which by nature can include recursion so we
147 // check if nodes have already been built when analyzing instructions.
148 std::map<const Instruction*, SENode*> recurrent_node_map_;
149
150 // On creation we create and cache the CantCompute node so we not need to
151 // perform a needless create step.
152 SENode* cached_cant_compute_;
153
154 // Helper functor to allow two unique_ptr to nodes to be compare. Only
155 // needed
156 // for the unordered_set implementation.
157 struct NodePointersEquality {
158 bool operator()(const std::unique_ptr<SENode>& lhs,
159 const std::unique_ptr<SENode>& rhs) const {
160 return *lhs == *rhs;
161 }
162 };
163
164 // Cache of nodes. All pointers to the nodes are references to the memory
165 // managed by they set.
166 std::unordered_set<std::unique_ptr<SENode>, SENodeHash, NodePointersEquality>
167 node_cache_;
168
169 // Loops that should be considered the same for performing analysis for loop
170 // fusion.
171 std::map<const Loop*, const Loop*> pretend_equal_;
172};
173
174// Wrapping class to manipulate SENode pointer using + - * / operators.
175class SExpression {
176 public:
177 // Implicit on purpose !
178 SExpression(SENode* node)
179 : node_(node->GetParentAnalysis()->SimplifyExpression(node)),
180 scev_(node->GetParentAnalysis()) {}
181
182 inline operator SENode*() const { return node_; }
183 inline SENode* operator->() const { return node_; }
184 const SENode& operator*() const { return *node_; }
185
186 inline ScalarEvolutionAnalysis* GetScalarEvolutionAnalysis() const {
187 return scev_;
188 }
189
190 inline SExpression operator+(SENode* rhs) const;
191 template <typename T,
192 typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
193 inline SExpression operator+(T integer) const;
194 inline SExpression operator+(SExpression rhs) const;
195
196 inline SExpression operator-() const;
197 inline SExpression operator-(SENode* rhs) const;
198 template <typename T,
199 typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
200 inline SExpression operator-(T integer) const;
201 inline SExpression operator-(SExpression rhs) const;
202
203 inline SExpression operator*(SENode* rhs) const;
204 template <typename T,
205 typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
206 inline SExpression operator*(T integer) const;
207 inline SExpression operator*(SExpression rhs) const;
208
209 template <typename T,
210 typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
211 inline std::pair<SExpression, int64_t> operator/(T integer) const;
212 // Try to perform a division. Returns the pair <this.node_ / rhs, division
213 // remainder>. If it fails to simplify it, the function returns a
214 // CanNotCompute node.
215 std::pair<SExpression, int64_t> operator/(SExpression rhs) const;
216
217 private:
218 SENode* node_;
219 ScalarEvolutionAnalysis* scev_;
220};
221
222inline SExpression SExpression::operator+(SENode* rhs) const {
223 return scev_->CreateAddNode(node_, rhs);
224}
225
226template <typename T,
227 typename std::enable_if<std::is_integral<T>::value, int>::type>
228inline SExpression SExpression::operator+(T integer) const {
229 return *this + scev_->CreateConstant(integer);
230}
231
232inline SExpression SExpression::operator+(SExpression rhs) const {
233 return *this + rhs.node_;
234}
235
236inline SExpression SExpression::operator-() const {
237 return scev_->CreateNegation(node_);
238}
239
240inline SExpression SExpression::operator-(SENode* rhs) const {
241 return *this + scev_->CreateNegation(rhs);
242}
243
244template <typename T,
245 typename std::enable_if<std::is_integral<T>::value, int>::type>
246inline SExpression SExpression::operator-(T integer) const {
247 return *this - scev_->CreateConstant(integer);
248}
249
250inline SExpression SExpression::operator-(SExpression rhs) const {
251 return *this - rhs.node_;
252}
253
254inline SExpression SExpression::operator*(SENode* rhs) const {
255 return scev_->CreateMultiplyNode(node_, rhs);
256}
257
258template <typename T,
259 typename std::enable_if<std::is_integral<T>::value, int>::type>
260inline SExpression SExpression::operator*(T integer) const {
261 return *this * scev_->CreateConstant(integer);
262}
263
264inline SExpression SExpression::operator*(SExpression rhs) const {
265 return *this * rhs.node_;
266}
267
268template <typename T,
269 typename std::enable_if<std::is_integral<T>::value, int>::type>
270inline std::pair<SExpression, int64_t> SExpression::operator/(T integer) const {
271 return *this / scev_->CreateConstant(integer);
272}
273
274template <typename T,
275 typename std::enable_if<std::is_integral<T>::value, int>::type>
276inline SExpression operator+(T lhs, SExpression rhs) {
277 return rhs + lhs;
278}
279inline SExpression operator+(SENode* lhs, SExpression rhs) { return rhs + lhs; }
280
281template <typename T,
282 typename std::enable_if<std::is_integral<T>::value, int>::type>
283inline SExpression operator-(T lhs, SExpression rhs) {
284 // NOLINTNEXTLINE(whitespace/braces)
285 return SExpression{rhs.GetScalarEvolutionAnalysis()->CreateConstant(lhs)} -
286 rhs;
287}
288inline SExpression operator-(SENode* lhs, SExpression rhs) {
289 // NOLINTNEXTLINE(whitespace/braces)
290 return SExpression{lhs} - rhs;
291}
292
293template <typename T,
294 typename std::enable_if<std::is_integral<T>::value, int>::type>
295inline SExpression operator*(T lhs, SExpression rhs) {
296 return rhs * lhs;
297}
298inline SExpression operator*(SENode* lhs, SExpression rhs) { return rhs * lhs; }
299
300template <typename T,
301 typename std::enable_if<std::is_integral<T>::value, int>::type>
302inline std::pair<SExpression, int64_t> operator/(T lhs, SExpression rhs) {
303 // NOLINTNEXTLINE(whitespace/braces)
304 return SExpression{rhs.GetScalarEvolutionAnalysis()->CreateConstant(lhs)} /
305 rhs;
306}
307inline std::pair<SExpression, int64_t> operator/(SENode* lhs, SExpression rhs) {
308 // NOLINTNEXTLINE(whitespace/braces)
309 return SExpression{lhs} / rhs;
310}
311
312} // namespace opt
313} // namespace spvtools
314#endif // SOURCE_OPT_SCALAR_ANALYSIS_H_
315