1 | // Copyright (c) 2015-2016 The Khronos Group Inc. |
2 | // |
3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | // you may not use this file except in compliance with the License. |
5 | // You may obtain a copy of the License at |
6 | // |
7 | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | // |
9 | // Unless required by applicable law or agreed to in writing, software |
10 | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | // See the License for the specific language governing permissions and |
13 | // limitations under the License. |
14 | |
15 | #include "source/val/construct.h" |
16 | |
#include <cassert>
#include <cstddef>
#include <unordered_set>
#include <utility>
20 | |
21 | #include "source/val/function.h" |
22 | #include "source/val/validation_state.h" |
23 | |
24 | namespace spvtools { |
25 | namespace val { |
26 | |
27 | Construct::Construct(ConstructType construct_type, BasicBlock* entry, |
28 | BasicBlock* exit, std::vector<Construct*> constructs) |
29 | : type_(construct_type), |
30 | corresponding_constructs_(constructs), |
31 | entry_block_(entry), |
32 | exit_block_(exit) {} |
33 | |
34 | ConstructType Construct::type() const { return type_; } |
35 | |
36 | const std::vector<Construct*>& Construct::corresponding_constructs() const { |
37 | return corresponding_constructs_; |
38 | } |
39 | std::vector<Construct*>& Construct::corresponding_constructs() { |
40 | return corresponding_constructs_; |
41 | } |
42 | |
43 | bool ValidateConstructSize(ConstructType type, size_t size) { |
44 | switch (type) { |
45 | case ConstructType::kSelection: |
46 | return size == 0; |
47 | case ConstructType::kContinue: |
48 | return size == 1; |
49 | case ConstructType::kLoop: |
50 | return size == 1; |
51 | case ConstructType::kCase: |
52 | return size >= 1; |
53 | default: |
54 | assert(1 == 0 && "Type not defined" ); |
55 | } |
56 | return false; |
57 | } |
58 | |
59 | void Construct::set_corresponding_constructs( |
60 | std::vector<Construct*> constructs) { |
61 | assert(ValidateConstructSize(type_, constructs.size())); |
62 | corresponding_constructs_ = constructs; |
63 | } |
64 | |
65 | const BasicBlock* Construct::entry_block() const { return entry_block_; } |
66 | BasicBlock* Construct::entry_block() { return entry_block_; } |
67 | |
68 | const BasicBlock* Construct::exit_block() const { return exit_block_; } |
69 | BasicBlock* Construct::exit_block() { return exit_block_; } |
70 | |
71 | void Construct::set_exit(BasicBlock* block) { exit_block_ = block; } |
72 | |
73 | Construct::ConstructBlockSet Construct::blocks(Function* function) const { |
74 | auto = entry_block(); |
75 | auto merge = exit_block(); |
76 | assert(header); |
77 | int = function->GetBlockDepth(const_cast<BasicBlock*>(header)); |
78 | ConstructBlockSet construct_blocks; |
79 | std::unordered_set<BasicBlock*> ; |
80 | for (auto& other : corresponding_constructs()) { |
81 | corresponding_headers.insert(other->entry_block()); |
82 | } |
83 | std::vector<BasicBlock*> stack; |
84 | stack.push_back(const_cast<BasicBlock*>(header)); |
85 | while (!stack.empty()) { |
86 | BasicBlock* block = stack.back(); |
87 | stack.pop_back(); |
88 | |
89 | if (merge == block && ExitBlockIsMergeBlock()) { |
90 | // Merge block is not part of the construct. |
91 | continue; |
92 | } |
93 | |
94 | if (corresponding_headers.count(block)) { |
95 | // Entered a corresponding construct. |
96 | continue; |
97 | } |
98 | |
99 | int block_depth = function->GetBlockDepth(block); |
100 | if (block_depth < header_depth) { |
101 | // Broke to outer construct. |
102 | continue; |
103 | } |
104 | |
105 | // In a loop, the continue target is at a depth of the loop construct + 1. |
106 | // A selection construct nested directly within the loop construct is also |
107 | // at the same depth. It is valid, however, to branch directly to the |
108 | // continue target from within the selection construct. |
109 | if (block != header && block_depth == header_depth && |
110 | type() == ConstructType::kSelection && |
111 | block->is_type(kBlockTypeContinue)) { |
112 | // Continued to outer construct. |
113 | continue; |
114 | } |
115 | |
116 | if (!construct_blocks.insert(block).second) continue; |
117 | |
118 | if (merge != block) { |
119 | for (auto succ : *block->successors()) { |
120 | // All blocks in the construct must be dominated by the header. |
121 | if (header->dominates(*succ)) { |
122 | stack.push_back(succ); |
123 | } |
124 | } |
125 | } |
126 | } |
127 | |
128 | return construct_blocks; |
129 | } |
130 | |
131 | bool Construct::IsStructuredExit(ValidationState_t& _, BasicBlock* dest) const { |
132 | // Structured Exits: |
133 | // - Selection: |
134 | // - branch to its merge |
135 | // - branch to nearest enclosing loop merge or continue |
136 | // - branch to nearest enclosing switch selection merge |
137 | // - Loop: |
138 | // - branch to its merge |
139 | // - branch to its continue |
140 | // - Continue: |
141 | // - branch to loop header |
142 | // - branch to loop merge |
143 | // |
144 | // Note: we will never see a case construct here. |
145 | assert(type() != ConstructType::kCase); |
146 | if (type() == ConstructType::kLoop) { |
147 | auto = entry_block(); |
148 | auto terminator = header->terminator(); |
149 | auto index = terminator - &_.ordered_instructions()[0]; |
150 | auto merge_inst = &_.ordered_instructions()[index - 1]; |
151 | auto merge_block_id = merge_inst->GetOperandAs<uint32_t>(0u); |
152 | auto continue_block_id = merge_inst->GetOperandAs<uint32_t>(1u); |
153 | if (dest->id() == merge_block_id || dest->id() == continue_block_id) { |
154 | return true; |
155 | } |
156 | } else if (type() == ConstructType::kContinue) { |
157 | auto loop_construct = corresponding_constructs()[0]; |
158 | auto = loop_construct->entry_block(); |
159 | auto terminator = header->terminator(); |
160 | auto index = terminator - &_.ordered_instructions()[0]; |
161 | auto merge_inst = &_.ordered_instructions()[index - 1]; |
162 | auto merge_block_id = merge_inst->GetOperandAs<uint32_t>(0u); |
163 | if (dest == header || dest->id() == merge_block_id) { |
164 | return true; |
165 | } |
166 | } else { |
167 | assert(type() == ConstructType::kSelection); |
168 | if (dest == exit_block()) { |
169 | return true; |
170 | } |
171 | |
172 | // The next block in the traversal is either: |
173 | // i. The header block that declares |block| as its merge block. |
174 | // ii. The immediate dominator of |block|. |
175 | auto NextBlock = [](const BasicBlock* block) -> const BasicBlock* { |
176 | for (auto& use : block->label()->uses()) { |
177 | if ((use.first->opcode() == SpvOpLoopMerge || |
178 | use.first->opcode() == SpvOpSelectionMerge) && |
179 | use.second == 1) |
180 | return use.first->block(); |
181 | } |
182 | return block->immediate_dominator(); |
183 | }; |
184 | |
185 | bool seen_switch = false; |
186 | auto = entry_block(); |
187 | auto block = NextBlock(header); |
188 | while (block) { |
189 | auto terminator = block->terminator(); |
190 | auto index = terminator - &_.ordered_instructions()[0]; |
191 | auto merge_inst = &_.ordered_instructions()[index - 1]; |
192 | if (merge_inst->opcode() == SpvOpLoopMerge || |
193 | (header->terminator()->opcode() != SpvOpSwitch && |
194 | merge_inst->opcode() == SpvOpSelectionMerge && |
195 | terminator->opcode() == SpvOpSwitch)) { |
196 | auto merge_target = merge_inst->GetOperandAs<uint32_t>(0u); |
197 | auto merge_block = merge_inst->function()->GetBlock(merge_target).first; |
198 | if (merge_block->dominates(*header)) { |
199 | block = NextBlock(block); |
200 | continue; |
201 | } |
202 | |
203 | if ((!seen_switch || merge_inst->opcode() == SpvOpLoopMerge) && |
204 | dest->id() == merge_target) { |
205 | return true; |
206 | } else if (merge_inst->opcode() == SpvOpLoopMerge) { |
207 | auto continue_target = merge_inst->GetOperandAs<uint32_t>(1u); |
208 | if (dest->id() == continue_target) { |
209 | return true; |
210 | } |
211 | } |
212 | |
213 | if (terminator->opcode() == SpvOpSwitch) { |
214 | seen_switch = true; |
215 | } |
216 | |
217 | // Hit an enclosing loop and didn't break or continue. |
218 | if (merge_inst->opcode() == SpvOpLoopMerge) return false; |
219 | } |
220 | |
221 | block = NextBlock(block); |
222 | } |
223 | } |
224 | |
225 | return false; |
226 | } |
227 | |
228 | } // namespace val |
229 | } // namespace spvtools |
230 | |