1 | // Copyright (c) 2018 Google LLC. |
2 | // |
3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | // you may not use this file except in compliance with the License. |
5 | // You may obtain a copy of the License at |
6 | // |
7 | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | // |
9 | // Unless required by applicable law or agreed to in writing, software |
10 | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | // See the License for the specific language governing permissions and |
13 | // limitations under the License. |
14 | |
15 | #ifndef SOURCE_OPT_SCALAR_ANALYSIS_H_ |
16 | #define SOURCE_OPT_SCALAR_ANALYSIS_H_ |
17 | |
#include <algorithm>
#include <cstdint>
#include <map>
#include <memory>
#include <type_traits>
#include <unordered_set>
#include <utility>
#include <vector>

#include "source/opt/basic_block.h"
#include "source/opt/instruction.h"
#include "source/opt/scalar_analysis_nodes.h"
29 | |
30 | namespace spvtools { |
31 | namespace opt { |
32 | |
33 | class IRContext; |
34 | class Loop; |
35 | |
36 | // Manager for the Scalar Evolution analysis. Creates and maintains a DAG of |
37 | // scalar operations generated from analysing the use def graph from incoming |
38 | // instructions. Each node is hashed as it is added so like node (for instance, |
39 | // two induction variables i=0,i++ and j=0,j++) become the same node. After |
40 | // creating a DAG with AnalyzeInstruction it can the be simplified into a more |
41 | // usable form with SimplifyExpression. |
42 | class ScalarEvolutionAnalysis { |
43 | public: |
44 | explicit ScalarEvolutionAnalysis(IRContext* context); |
45 | |
46 | // Create a unary negative node on |operand|. |
47 | SENode* CreateNegation(SENode* operand); |
48 | |
49 | // Creates a subtraction between the two operands by adding |operand_1| to the |
50 | // negation of |operand_2|. |
51 | SENode* CreateSubtraction(SENode* operand_1, SENode* operand_2); |
52 | |
53 | // Create an addition node between two operands. The |simplify| when set will |
54 | // allow the function to return an SEConstant instead of an addition if the |
55 | // two input operands are also constant. |
56 | SENode* CreateAddNode(SENode* operand_1, SENode* operand_2); |
57 | |
58 | // Create a multiply node between two operands. |
59 | SENode* CreateMultiplyNode(SENode* operand_1, SENode* operand_2); |
60 | |
61 | // Create a node representing a constant integer. |
62 | SENode* CreateConstant(int64_t integer); |
63 | |
64 | // Create a value unknown node, such as a load. |
65 | SENode* CreateValueUnknownNode(const Instruction* inst); |
66 | |
67 | // Create a CantComputeNode. Used to exit out of analysis. |
68 | SENode* CreateCantComputeNode(); |
69 | |
70 | // Create a new recurrent node with |offset| and |coefficient|, with respect |
71 | // to |loop|. |
72 | SENode* CreateRecurrentExpression(const Loop* loop, SENode* offset, |
73 | SENode* coefficient); |
74 | |
75 | // Construct the DAG by traversing use def chain of |inst|. |
76 | SENode* AnalyzeInstruction(const Instruction* inst); |
77 | |
78 | // Simplify the |node| by grouping like terms or if contains a recurrent |
79 | // expression, rewrite the graph so the whole DAG (from |node| down) is in |
80 | // terms of that recurrent expression. |
81 | // |
82 | // For example. |
83 | // Induction variable i=0, i++ would produce Rec(0,1) so i+1 could be |
84 | // transformed into Rec(1,1). |
85 | // |
86 | // X+X*2+Y-Y+34-17 would be transformed into 3*X + 17, where X and Y are |
87 | // ValueUnknown nodes (such as a load instruction). |
88 | SENode* SimplifyExpression(SENode* node); |
89 | |
90 | // Add |prospective_node| into the cache and return a raw pointer to it. If |
91 | // |prospective_node| is already in the cache just return the raw pointer. |
92 | SENode* GetCachedOrAdd(std::unique_ptr<SENode> prospective_node); |
93 | |
94 | // Checks that the graph starting from |node| is invariant to the |loop|. |
95 | bool IsLoopInvariant(const Loop* loop, const SENode* node) const; |
96 | |
97 | // Sets |is_gt_zero| to true if |node| represent a value always strictly |
98 | // greater than 0. The result of |is_gt_zero| is valid only if the function |
99 | // returns true. |
100 | bool IsAlwaysGreaterThanZero(SENode* node, bool* is_gt_zero) const; |
101 | |
102 | // Sets |is_ge_zero| to true if |node| represent a value greater or equals to |
103 | // 0. The result of |is_ge_zero| is valid only if the function returns true. |
104 | bool IsAlwaysGreaterOrEqualToZero(SENode* node, bool* is_ge_zero) const; |
105 | |
106 | // Find the recurrent term belonging to |loop| in the graph starting from |
107 | // |node| and return the coefficient of that recurrent term. Constant zero |
108 | // will be returned if no recurrent could be found. |node| should be in |
109 | // simplest form. |
110 | SENode* GetCoefficientFromRecurrentTerm(SENode* node, const Loop* loop); |
111 | |
112 | // Return a rebuilt graph starting from |node| with the recurrent expression |
113 | // belonging to |loop| being zeroed out. Returned node will be simplified. |
114 | SENode* BuildGraphWithoutRecurrentTerm(SENode* node, const Loop* loop); |
115 | |
116 | // Return the recurrent term belonging to |loop| if it appears in the graph |
117 | // starting at |node| or null if it doesn't. |
118 | SERecurrentNode* GetRecurrentTerm(SENode* node, const Loop* loop); |
119 | |
120 | SENode* UpdateChildNode(SENode* parent, SENode* child, SENode* new_child); |
121 | |
122 | // The loops in |loop_pair| will be considered the same when constructing |
123 | // SERecurrentNode objects. This enables analysing dependencies that will be |
124 | // created during loop fusion. |
125 | void AddLoopsToPretendAreTheSame( |
126 | const std::pair<const Loop*, const Loop*>& loop_pair) { |
127 | pretend_equal_[std::get<1>(loop_pair)] = std::get<0>(loop_pair); |
128 | } |
129 | |
130 | private: |
131 | SENode* AnalyzeConstant(const Instruction* inst); |
132 | |
133 | // Handles both addition and subtraction. If the |instruction| is OpISub |
134 | // then the resulting node will be op1+(-op2) otherwise if it is OpIAdd then |
135 | // the result will be op1+op2. |instruction| must be OpIAdd or OpISub. |
136 | SENode* AnalyzeAddOp(const Instruction* instruction); |
137 | |
138 | SENode* AnalyzeMultiplyOp(const Instruction* multiply); |
139 | |
140 | SENode* AnalyzePhiInstruction(const Instruction* phi); |
141 | |
142 | IRContext* context_; |
143 | |
144 | // A map of instructions to SENodes. This is used to track recurrent |
145 | // expressions as they are added when analyzing instructions. Recurrent |
146 | // expressions come from phi nodes which by nature can include recursion so we |
147 | // check if nodes have already been built when analyzing instructions. |
148 | std::map<const Instruction*, SENode*> recurrent_node_map_; |
149 | |
150 | // On creation we create and cache the CantCompute node so we not need to |
151 | // perform a needless create step. |
152 | SENode* cached_cant_compute_; |
153 | |
154 | // Helper functor to allow two unique_ptr to nodes to be compare. Only |
155 | // needed |
156 | // for the unordered_set implementation. |
157 | struct NodePointersEquality { |
158 | bool operator()(const std::unique_ptr<SENode>& lhs, |
159 | const std::unique_ptr<SENode>& rhs) const { |
160 | return *lhs == *rhs; |
161 | } |
162 | }; |
163 | |
164 | // Cache of nodes. All pointers to the nodes are references to the memory |
165 | // managed by they set. |
166 | std::unordered_set<std::unique_ptr<SENode>, SENodeHash, NodePointersEquality> |
167 | node_cache_; |
168 | |
169 | // Loops that should be considered the same for performing analysis for loop |
170 | // fusion. |
171 | std::map<const Loop*, const Loop*> pretend_equal_; |
172 | }; |
173 | |
174 | // Wrapping class to manipulate SENode pointer using + - * / operators. |
175 | class SExpression { |
176 | public: |
177 | // Implicit on purpose ! |
178 | SExpression(SENode* node) |
179 | : node_(node->GetParentAnalysis()->SimplifyExpression(node)), |
180 | scev_(node->GetParentAnalysis()) {} |
181 | |
182 | inline operator SENode*() const { return node_; } |
183 | inline SENode* operator->() const { return node_; } |
184 | const SENode& operator*() const { return *node_; } |
185 | |
186 | inline ScalarEvolutionAnalysis* GetScalarEvolutionAnalysis() const { |
187 | return scev_; |
188 | } |
189 | |
190 | inline SExpression operator+(SENode* rhs) const; |
191 | template <typename T, |
192 | typename std::enable_if<std::is_integral<T>::value, int>::type = 0> |
193 | inline SExpression operator+(T integer) const; |
194 | inline SExpression operator+(SExpression rhs) const; |
195 | |
196 | inline SExpression operator-() const; |
197 | inline SExpression operator-(SENode* rhs) const; |
198 | template <typename T, |
199 | typename std::enable_if<std::is_integral<T>::value, int>::type = 0> |
200 | inline SExpression operator-(T integer) const; |
201 | inline SExpression operator-(SExpression rhs) const; |
202 | |
203 | inline SExpression operator*(SENode* rhs) const; |
204 | template <typename T, |
205 | typename std::enable_if<std::is_integral<T>::value, int>::type = 0> |
206 | inline SExpression operator*(T integer) const; |
207 | inline SExpression operator*(SExpression rhs) const; |
208 | |
209 | template <typename T, |
210 | typename std::enable_if<std::is_integral<T>::value, int>::type = 0> |
211 | inline std::pair<SExpression, int64_t> operator/(T integer) const; |
212 | // Try to perform a division. Returns the pair <this.node_ / rhs, division |
213 | // remainder>. If it fails to simplify it, the function returns a |
214 | // CanNotCompute node. |
215 | std::pair<SExpression, int64_t> operator/(SExpression rhs) const; |
216 | |
217 | private: |
218 | SENode* node_; |
219 | ScalarEvolutionAnalysis* scev_; |
220 | }; |
221 | |
222 | inline SExpression SExpression::operator+(SENode* rhs) const { |
223 | return scev_->CreateAddNode(node_, rhs); |
224 | } |
225 | |
226 | template <typename T, |
227 | typename std::enable_if<std::is_integral<T>::value, int>::type> |
228 | inline SExpression SExpression::operator+(T integer) const { |
229 | return *this + scev_->CreateConstant(integer); |
230 | } |
231 | |
232 | inline SExpression SExpression::operator+(SExpression rhs) const { |
233 | return *this + rhs.node_; |
234 | } |
235 | |
236 | inline SExpression SExpression::operator-() const { |
237 | return scev_->CreateNegation(node_); |
238 | } |
239 | |
240 | inline SExpression SExpression::operator-(SENode* rhs) const { |
241 | return *this + scev_->CreateNegation(rhs); |
242 | } |
243 | |
244 | template <typename T, |
245 | typename std::enable_if<std::is_integral<T>::value, int>::type> |
246 | inline SExpression SExpression::operator-(T integer) const { |
247 | return *this - scev_->CreateConstant(integer); |
248 | } |
249 | |
250 | inline SExpression SExpression::operator-(SExpression rhs) const { |
251 | return *this - rhs.node_; |
252 | } |
253 | |
254 | inline SExpression SExpression::operator*(SENode* rhs) const { |
255 | return scev_->CreateMultiplyNode(node_, rhs); |
256 | } |
257 | |
258 | template <typename T, |
259 | typename std::enable_if<std::is_integral<T>::value, int>::type> |
260 | inline SExpression SExpression::operator*(T integer) const { |
261 | return *this * scev_->CreateConstant(integer); |
262 | } |
263 | |
264 | inline SExpression SExpression::operator*(SExpression rhs) const { |
265 | return *this * rhs.node_; |
266 | } |
267 | |
268 | template <typename T, |
269 | typename std::enable_if<std::is_integral<T>::value, int>::type> |
270 | inline std::pair<SExpression, int64_t> SExpression::operator/(T integer) const { |
271 | return *this / scev_->CreateConstant(integer); |
272 | } |
273 | |
274 | template <typename T, |
275 | typename std::enable_if<std::is_integral<T>::value, int>::type> |
276 | inline SExpression operator+(T lhs, SExpression rhs) { |
277 | return rhs + lhs; |
278 | } |
279 | inline SExpression operator+(SENode* lhs, SExpression rhs) { return rhs + lhs; } |
280 | |
281 | template <typename T, |
282 | typename std::enable_if<std::is_integral<T>::value, int>::type> |
283 | inline SExpression operator-(T lhs, SExpression rhs) { |
284 | // NOLINTNEXTLINE(whitespace/braces) |
285 | return SExpression{rhs.GetScalarEvolutionAnalysis()->CreateConstant(lhs)} - |
286 | rhs; |
287 | } |
288 | inline SExpression operator-(SENode* lhs, SExpression rhs) { |
289 | // NOLINTNEXTLINE(whitespace/braces) |
290 | return SExpression{lhs} - rhs; |
291 | } |
292 | |
293 | template <typename T, |
294 | typename std::enable_if<std::is_integral<T>::value, int>::type> |
295 | inline SExpression operator*(T lhs, SExpression rhs) { |
296 | return rhs * lhs; |
297 | } |
298 | inline SExpression operator*(SENode* lhs, SExpression rhs) { return rhs * lhs; } |
299 | |
300 | template <typename T, |
301 | typename std::enable_if<std::is_integral<T>::value, int>::type> |
302 | inline std::pair<SExpression, int64_t> operator/(T lhs, SExpression rhs) { |
303 | // NOLINTNEXTLINE(whitespace/braces) |
304 | return SExpression{rhs.GetScalarEvolutionAnalysis()->CreateConstant(lhs)} / |
305 | rhs; |
306 | } |
307 | inline std::pair<SExpression, int64_t> operator/(SENode* lhs, SExpression rhs) { |
308 | // NOLINTNEXTLINE(whitespace/braces) |
309 | return SExpression{lhs} / rhs; |
310 | } |
311 | |
312 | } // namespace opt |
313 | } // namespace spvtools |
314 | #endif // SOURCE_OPT_SCALAR_ANALYSIS_H_ |
315 | |