1// Copyright (c) 2018 Google LLC.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#ifndef SOURCE_OPT_SCALAR_ANALYSIS_NODES_H_
16#define SOURCE_OPT_SCALAR_ANALYSIS_NODES_H_
17
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iosfwd>
#include <memory>
#include <string>
#include <vector>

#include "source/opt/tree_iterator.h"
24
25namespace spvtools {
26namespace opt {
27
// Forward declarations. Loop and ScalarEvolutionAnalysis are only used through
// pointers in this header.
class Loop;
class ScalarEvolutionAnalysis;

// Concrete scalar evolution node types, declared up front so that SENode can
// name them in its casting methods.
class SEAddNode;
class SECantCompute;
class SEConstantNode;
class SEMultiplyNode;
class SENegative;
class SERecurrentNode;
class SEValueUnknown;
37
38// Abstract class representing a node in the scalar evolution DAG. Each node
39// contains a vector of pointers to its children and each subclass of SENode
40// implements GetType and an As method to allow casting. SENodes can be hashed
41// using the SENodeHash functor. The vector of children is sorted when a node is
42// added. This is important as it allows the hash of X+Y to be the same as Y+X.
43class SENode {
44 public:
45 enum SENodeType {
46 Constant,
47 RecurrentAddExpr,
48 Add,
49 Multiply,
50 Negative,
51 ValueUnknown,
52 CanNotCompute
53 };
54
55 using ChildContainerType = std::vector<SENode*>;
56
57 explicit SENode(ScalarEvolutionAnalysis* parent_analysis)
58 : parent_analysis_(parent_analysis), unique_id_(++NumberOfNodes) {}
59
60 virtual SENodeType GetType() const = 0;
61
62 virtual ~SENode() {}
63
64 virtual inline void AddChild(SENode* child) {
65 // If this is a constant node, assert.
66 if (AsSEConstantNode()) {
67 assert(false && "Trying to add a child node to a constant!");
68 }
69
70 // Find the first point in the vector where |child| is greater than the node
71 // currently in the vector.
72 auto find_first_less_than = [child](const SENode* node) {
73 return child->unique_id_ <= node->unique_id_;
74 };
75
76 auto position = std::find_if_not(children_.begin(), children_.end(),
77 find_first_less_than);
78 // Children are sorted so the hashing and equality operator will be the same
79 // for a node with the same children. X+Y should be the same as Y+X.
80 children_.insert(position, child);
81 }
82
83 // Get the type as an std::string. This is used to represent the node in the
84 // dot output and is used to hash the type as well.
85 std::string AsString() const;
86
87 // Dump the SENode and its immediate children, if |recurse| is true then it
88 // will recurse through all children to print the DAG starting from this node
89 // as a root.
90 void DumpDot(std::ostream& out, bool recurse = false) const;
91
92 // Checks if two nodes are the same by hashing them.
93 bool operator==(const SENode& other) const;
94
95 // Checks if two nodes are not the same by comparing the hashes.
96 bool operator!=(const SENode& other) const;
97
98 // Return the child node at |index|.
99 inline SENode* GetChild(size_t index) { return children_[index]; }
100 inline const SENode* GetChild(size_t index) const { return children_[index]; }
101
102 // Iterator to iterate over the child nodes.
103 using iterator = ChildContainerType::iterator;
104 using const_iterator = ChildContainerType::const_iterator;
105
106 // Iterate over immediate child nodes.
107 iterator begin() { return children_.begin(); }
108 iterator end() { return children_.end(); }
109
110 // Constant overloads for iterating over immediate child nodes.
111 const_iterator begin() const { return children_.cbegin(); }
112 const_iterator end() const { return children_.cend(); }
113 const_iterator cbegin() { return children_.cbegin(); }
114 const_iterator cend() { return children_.cend(); }
115
116 // Collect all the recurrent nodes in this SENode
117 std::vector<SERecurrentNode*> CollectRecurrentNodes() {
118 std::vector<SERecurrentNode*> recurrent_nodes{};
119
120 if (auto recurrent_node = AsSERecurrentNode()) {
121 recurrent_nodes.push_back(recurrent_node);
122 }
123
124 for (auto child : GetChildren()) {
125 auto child_recurrent_nodes = child->CollectRecurrentNodes();
126 recurrent_nodes.insert(recurrent_nodes.end(),
127 child_recurrent_nodes.begin(),
128 child_recurrent_nodes.end());
129 }
130
131 return recurrent_nodes;
132 }
133
134 // Collect all the value unknown nodes in this SENode
135 std::vector<SEValueUnknown*> CollectValueUnknownNodes() {
136 std::vector<SEValueUnknown*> value_unknown_nodes{};
137
138 if (auto value_unknown_node = AsSEValueUnknown()) {
139 value_unknown_nodes.push_back(value_unknown_node);
140 }
141
142 for (auto child : GetChildren()) {
143 auto child_value_unknown_nodes = child->CollectValueUnknownNodes();
144 value_unknown_nodes.insert(value_unknown_nodes.end(),
145 child_value_unknown_nodes.begin(),
146 child_value_unknown_nodes.end());
147 }
148
149 return value_unknown_nodes;
150 }
151
152 // Iterator to iterate over the entire DAG. Even though we are using the tree
153 // iterator it should still be safe to iterate over. However, nodes with
154 // multiple parents will be visited multiple times, unlike in a tree.
155 using dag_iterator = TreeDFIterator<SENode>;
156 using const_dag_iterator = TreeDFIterator<const SENode>;
157
158 // Iterate over all child nodes in the graph.
159 dag_iterator graph_begin() { return dag_iterator(this); }
160 dag_iterator graph_end() { return dag_iterator(); }
161 const_dag_iterator graph_begin() const { return graph_cbegin(); }
162 const_dag_iterator graph_end() const { return graph_cend(); }
163 const_dag_iterator graph_cbegin() const { return const_dag_iterator(this); }
164 const_dag_iterator graph_cend() const { return const_dag_iterator(); }
165
166 // Return the vector of immediate children.
167 const ChildContainerType& GetChildren() const { return children_; }
168 ChildContainerType& GetChildren() { return children_; }
169
170 // Return true if this node is a cant compute node.
171 bool IsCantCompute() const { return GetType() == CanNotCompute; }
172
173// Implements a casting method for each type.
174// clang-format off
175#define DeclareCastMethod(target) \
176 virtual target* As##target() { return nullptr; } \
177 virtual const target* As##target() const { return nullptr; }
178 DeclareCastMethod(SEConstantNode)
179 DeclareCastMethod(SERecurrentNode)
180 DeclareCastMethod(SEAddNode)
181 DeclareCastMethod(SEMultiplyNode)
182 DeclareCastMethod(SENegative)
183 DeclareCastMethod(SEValueUnknown)
184 DeclareCastMethod(SECantCompute)
185#undef DeclareCastMethod
186
187 // Get the analysis which has this node in its cache.
188 inline ScalarEvolutionAnalysis* GetParentAnalysis() const {
189 return parent_analysis_;
190 }
191
192 protected:
193 ChildContainerType children_;
194
195 ScalarEvolutionAnalysis* parent_analysis_;
196
197 // The unique id of this node, assigned on creation by incrementing the static
198 // node count.
199 uint32_t unique_id_;
200
201 // The number of nodes created.
202 static uint32_t NumberOfNodes;
203};
204// clang-format on
205
206// Function object to handle the hashing of SENodes. Hashing algorithm hashes
207// the type (as a string), the literal value of any constants, and the child
208// pointers which are assumed to be unique.
209struct SENodeHash {
210 size_t operator()(const std::unique_ptr<SENode>& node) const;
211 size_t operator()(const SENode* node) const;
212};
213
214// A node representing a constant integer.
215class SEConstantNode : public SENode {
216 public:
217 SEConstantNode(ScalarEvolutionAnalysis* parent_analysis, int64_t value)
218 : SENode(parent_analysis), literal_value_(value) {}
219
220 SENodeType GetType() const final { return Constant; }
221
222 int64_t FoldToSingleValue() const { return literal_value_; }
223
224 SEConstantNode* AsSEConstantNode() override { return this; }
225 const SEConstantNode* AsSEConstantNode() const override { return this; }
226
227 inline void AddChild(SENode*) final {
228 assert(false && "Attempting to add a child to a constant node!");
229 }
230
231 protected:
232 int64_t literal_value_;
233};
234
235// A node representing a recurrent expression in the code. A recurrent
236// expression is an expression whose value can be expressed as a linear
237// expression of the loop iterations. Such as an induction variable. The actual
238// value of a recurrent expression is coefficent_ * iteration + offset_, hence
239// an induction variable i=0, i++ becomes a recurrent expression with an offset
240// of zero and a coefficient of one.
241class SERecurrentNode : public SENode {
242 public:
243 SERecurrentNode(ScalarEvolutionAnalysis* parent_analysis, const Loop* loop)
244 : SENode(parent_analysis), loop_(loop) {}
245
246 SENodeType GetType() const final { return RecurrentAddExpr; }
247
248 inline void AddCoefficient(SENode* child) {
249 coefficient_ = child;
250 SENode::AddChild(child);
251 }
252
253 inline void AddOffset(SENode* child) {
254 offset_ = child;
255 SENode::AddChild(child);
256 }
257
258 inline const SENode* GetCoefficient() const { return coefficient_; }
259 inline SENode* GetCoefficient() { return coefficient_; }
260
261 inline const SENode* GetOffset() const { return offset_; }
262 inline SENode* GetOffset() { return offset_; }
263
264 // Return the loop which this recurrent expression is recurring within.
265 const Loop* GetLoop() const { return loop_; }
266
267 SERecurrentNode* AsSERecurrentNode() override { return this; }
268 const SERecurrentNode* AsSERecurrentNode() const override { return this; }
269
270 private:
271 SENode* coefficient_;
272 SENode* offset_;
273 const Loop* loop_;
274};
275
276// A node representing an addition operation between child nodes.
277class SEAddNode : public SENode {
278 public:
279 explicit SEAddNode(ScalarEvolutionAnalysis* parent_analysis)
280 : SENode(parent_analysis) {}
281
282 SENodeType GetType() const final { return Add; }
283
284 SEAddNode* AsSEAddNode() override { return this; }
285 const SEAddNode* AsSEAddNode() const override { return this; }
286};
287
288// A node representing a multiply operation between child nodes.
289class SEMultiplyNode : public SENode {
290 public:
291 explicit SEMultiplyNode(ScalarEvolutionAnalysis* parent_analysis)
292 : SENode(parent_analysis) {}
293
294 SENodeType GetType() const final { return Multiply; }
295
296 SEMultiplyNode* AsSEMultiplyNode() override { return this; }
297 const SEMultiplyNode* AsSEMultiplyNode() const override { return this; }
298};
299
300// A node representing a unary negative operation.
301class SENegative : public SENode {
302 public:
303 explicit SENegative(ScalarEvolutionAnalysis* parent_analysis)
304 : SENode(parent_analysis) {}
305
306 SENodeType GetType() const final { return Negative; }
307
308 SENegative* AsSENegative() override { return this; }
309 const SENegative* AsSENegative() const override { return this; }
310};
311
312// A node representing a value which we do not know the value of, such as a load
313// instruction.
314class SEValueUnknown : public SENode {
315 public:
316 // SEValueUnknowns must come from an instruction |unique_id| is the unique id
317 // of that instruction. This is so we cancompare value unknowns and have a
318 // unique value unknown for each instruction.
319 SEValueUnknown(ScalarEvolutionAnalysis* parent_analysis, uint32_t result_id)
320 : SENode(parent_analysis), result_id_(result_id) {}
321
322 SENodeType GetType() const final { return ValueUnknown; }
323
324 SEValueUnknown* AsSEValueUnknown() override { return this; }
325 const SEValueUnknown* AsSEValueUnknown() const override { return this; }
326
327 inline uint32_t ResultId() const { return result_id_; }
328
329 private:
330 uint32_t result_id_;
331};
332
333// A node which we cannot reason about at all.
334class SECantCompute : public SENode {
335 public:
336 explicit SECantCompute(ScalarEvolutionAnalysis* parent_analysis)
337 : SENode(parent_analysis) {}
338
339 SENodeType GetType() const final { return CanNotCompute; }
340
341 SECantCompute* AsSECantCompute() override { return this; }
342 const SECantCompute* AsSECantCompute() const override { return this; }
343};
344
345} // namespace opt
346} // namespace spvtools
347#endif // SOURCE_OPT_SCALAR_ANALYSIS_NODES_H_
348