| 1 | // Copyright (c) 2015-2016 The Khronos Group Inc. |
| 2 | // |
| 3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | // you may not use this file except in compliance with the License. |
| 5 | // You may obtain a copy of the License at |
| 6 | // |
| 7 | // http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | // |
| 9 | // Unless required by applicable law or agreed to in writing, software |
| 10 | // distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | // See the License for the specific language governing permissions and |
| 13 | // limitations under the License. |
| 14 | |
| 15 | #include "source/val/function.h" |
| 16 | |
| 17 | #include <algorithm> |
| 18 | #include <cassert> |
| 19 | #include <sstream> |
| 20 | #include <unordered_map> |
| 21 | #include <unordered_set> |
| 22 | #include <utility> |
| 23 | |
| 24 | #include "source/cfa.h" |
| 25 | #include "source/val/basic_block.h" |
| 26 | #include "source/val/construct.h" |
| 27 | #include "source/val/validate.h" |
| 28 | |
| 29 | namespace spvtools { |
| 30 | namespace val { |
| 31 | |
// Id given to the pseudo-exit block. It is the SPIR-V universal limit on
// result ids (0x3FFFFF) plus one, i.e. just outside the permitted id range.
constexpr uint32_t kInvalidId = 0x400000;
| 34 | |
| 35 | Function::Function(uint32_t function_id, uint32_t result_type_id, |
| 36 | SpvFunctionControlMask function_control, |
| 37 | uint32_t function_type_id) |
| 38 | : id_(function_id), |
| 39 | function_type_id_(function_type_id), |
| 40 | result_type_id_(result_type_id), |
| 41 | function_control_(function_control), |
| 42 | declaration_type_(FunctionDecl::kFunctionDeclUnknown), |
| 43 | end_has_been_registered_(false), |
| 44 | blocks_(), |
| 45 | current_block_(nullptr), |
| 46 | pseudo_entry_block_(0), |
| 47 | pseudo_exit_block_(kInvalidId), |
| 48 | cfg_constructs_(), |
| 49 | variable_ids_(), |
| 50 | parameter_ids_() {} |
| 51 | |
| 52 | bool Function::IsFirstBlock(uint32_t block_id) const { |
| 53 | return !ordered_blocks_.empty() && *first_block() == block_id; |
| 54 | } |
| 55 | |
| 56 | spv_result_t Function::RegisterFunctionParameter(uint32_t parameter_id, |
| 57 | uint32_t type_id) { |
| 58 | assert(current_block_ == nullptr && |
| 59 | "RegisterFunctionParameter can only be called when parsing the binary " |
| 60 | "ouside of a block" ); |
| 61 | // TODO(umar): Validate function parameter type order and count |
| 62 | // TODO(umar): Use these variables to validate parameter type |
| 63 | (void)parameter_id; |
| 64 | (void)type_id; |
| 65 | return SPV_SUCCESS; |
| 66 | } |
| 67 | |
| 68 | spv_result_t Function::RegisterLoopMerge(uint32_t merge_id, |
| 69 | uint32_t continue_id) { |
| 70 | RegisterBlock(merge_id, false); |
| 71 | RegisterBlock(continue_id, false); |
| 72 | BasicBlock& merge_block = blocks_.at(merge_id); |
| 73 | BasicBlock& continue_target_block = blocks_.at(continue_id); |
| 74 | assert(current_block_ && |
| 75 | "RegisterLoopMerge must be called when called within a block" ); |
| 76 | |
| 77 | current_block_->set_type(kBlockTypeLoop); |
| 78 | merge_block.set_type(kBlockTypeMerge); |
| 79 | continue_target_block.set_type(kBlockTypeContinue); |
| 80 | Construct& loop_construct = |
| 81 | AddConstruct({ConstructType::kLoop, current_block_, &merge_block}); |
| 82 | Construct& continue_construct = |
| 83 | AddConstruct({ConstructType::kContinue, &continue_target_block}); |
| 84 | |
| 85 | continue_construct.set_corresponding_constructs({&loop_construct}); |
| 86 | loop_construct.set_corresponding_constructs({&continue_construct}); |
| 87 | merge_block_header_[&merge_block] = current_block_; |
| 88 | if (continue_target_headers_.find(&continue_target_block) == |
| 89 | continue_target_headers_.end()) { |
| 90 | continue_target_headers_[&continue_target_block] = {current_block_}; |
| 91 | } else { |
| 92 | continue_target_headers_[&continue_target_block].push_back(current_block_); |
| 93 | } |
| 94 | |
| 95 | return SPV_SUCCESS; |
| 96 | } |
| 97 | |
| 98 | spv_result_t Function::RegisterSelectionMerge(uint32_t merge_id) { |
| 99 | RegisterBlock(merge_id, false); |
| 100 | BasicBlock& merge_block = blocks_.at(merge_id); |
| 101 | current_block_->set_type(kBlockTypeSelection); |
| 102 | merge_block.set_type(kBlockTypeMerge); |
| 103 | merge_block_header_[&merge_block] = current_block_; |
| 104 | |
| 105 | AddConstruct({ConstructType::kSelection, current_block(), &merge_block}); |
| 106 | |
| 107 | return SPV_SUCCESS; |
| 108 | } |
| 109 | |
| 110 | spv_result_t Function::RegisterSetFunctionDeclType(FunctionDecl type) { |
| 111 | assert(declaration_type_ == FunctionDecl::kFunctionDeclUnknown); |
| 112 | declaration_type_ = type; |
| 113 | return SPV_SUCCESS; |
| 114 | } |
| 115 | |
| 116 | spv_result_t Function::RegisterBlock(uint32_t block_id, bool is_definition) { |
| 117 | assert( |
| 118 | declaration_type_ == FunctionDecl::kFunctionDeclDefinition && |
| 119 | "RegisterBlocks can only be called after declaration_type_ is defined" ); |
| 120 | |
| 121 | std::unordered_map<uint32_t, BasicBlock>::iterator inserted_block; |
| 122 | bool success = false; |
| 123 | tie(inserted_block, success) = |
| 124 | blocks_.insert({block_id, BasicBlock(block_id)}); |
| 125 | if (is_definition) { // new block definition |
| 126 | assert(current_block_ == nullptr && |
| 127 | "Register Block can only be called when parsing a binary outside of " |
| 128 | "a BasicBlock" ); |
| 129 | |
| 130 | undefined_blocks_.erase(block_id); |
| 131 | current_block_ = &inserted_block->second; |
| 132 | ordered_blocks_.push_back(current_block_); |
| 133 | if (IsFirstBlock(block_id)) current_block_->set_reachable(true); |
| 134 | } else if (success) { // Block doesn't exsist but this is not a definition |
| 135 | undefined_blocks_.insert(block_id); |
| 136 | } |
| 137 | |
| 138 | return SPV_SUCCESS; |
| 139 | } |
| 140 | |
| 141 | void Function::RegisterBlockEnd(std::vector<uint32_t> next_list, |
| 142 | SpvOp branch_instruction) { |
| 143 | assert( |
| 144 | current_block_ && |
| 145 | "RegisterBlockEnd can only be called when parsing a binary in a block" ); |
| 146 | std::vector<BasicBlock*> next_blocks; |
| 147 | next_blocks.reserve(next_list.size()); |
| 148 | |
| 149 | std::unordered_map<uint32_t, BasicBlock>::iterator inserted_block; |
| 150 | bool success; |
| 151 | for (uint32_t successor_id : next_list) { |
| 152 | tie(inserted_block, success) = |
| 153 | blocks_.insert({successor_id, BasicBlock(successor_id)}); |
| 154 | if (success) { |
| 155 | undefined_blocks_.insert(successor_id); |
| 156 | } |
| 157 | next_blocks.push_back(&inserted_block->second); |
| 158 | } |
| 159 | |
| 160 | if (current_block_->is_type(kBlockTypeLoop)) { |
| 161 | // For each loop header, record the set of its successors, and include |
| 162 | // its continue target if the continue target is not the loop header |
| 163 | // itself. |
| 164 | std::vector<BasicBlock*>& next_blocks_plus_continue_target = |
| 165 | loop_header_successors_plus_continue_target_map_[current_block_]; |
| 166 | next_blocks_plus_continue_target = next_blocks; |
| 167 | auto continue_target = |
| 168 | FindConstructForEntryBlock(current_block_, ConstructType::kLoop) |
| 169 | .corresponding_constructs() |
| 170 | .back() |
| 171 | ->entry_block(); |
| 172 | if (continue_target != current_block_) { |
| 173 | next_blocks_plus_continue_target.push_back(continue_target); |
| 174 | } |
| 175 | } |
| 176 | |
| 177 | current_block_->RegisterBranchInstruction(branch_instruction); |
| 178 | current_block_->RegisterSuccessors(next_blocks); |
| 179 | current_block_ = nullptr; |
| 180 | return; |
| 181 | } |
| 182 | |
| 183 | void Function::RegisterFunctionEnd() { |
| 184 | if (!end_has_been_registered_) { |
| 185 | end_has_been_registered_ = true; |
| 186 | |
| 187 | ComputeAugmentedCFG(); |
| 188 | } |
| 189 | } |
| 190 | |
| 191 | size_t Function::block_count() const { return blocks_.size(); } |
| 192 | |
| 193 | size_t Function::undefined_block_count() const { |
| 194 | return undefined_blocks_.size(); |
| 195 | } |
| 196 | |
| 197 | const std::vector<BasicBlock*>& Function::ordered_blocks() const { |
| 198 | return ordered_blocks_; |
| 199 | } |
| 200 | std::vector<BasicBlock*>& Function::ordered_blocks() { return ordered_blocks_; } |
| 201 | |
| 202 | const BasicBlock* Function::current_block() const { return current_block_; } |
| 203 | BasicBlock* Function::current_block() { return current_block_; } |
| 204 | |
| 205 | const std::list<Construct>& Function::constructs() const { |
| 206 | return cfg_constructs_; |
| 207 | } |
| 208 | std::list<Construct>& Function::constructs() { return cfg_constructs_; } |
| 209 | |
| 210 | const BasicBlock* Function::first_block() const { |
| 211 | if (ordered_blocks_.empty()) return nullptr; |
| 212 | return ordered_blocks_[0]; |
| 213 | } |
| 214 | BasicBlock* Function::first_block() { |
| 215 | if (ordered_blocks_.empty()) return nullptr; |
| 216 | return ordered_blocks_[0]; |
| 217 | } |
| 218 | |
| 219 | bool Function::IsBlockType(uint32_t merge_block_id, BlockType type) const { |
| 220 | bool ret = false; |
| 221 | const BasicBlock* block; |
| 222 | std::tie(block, std::ignore) = GetBlock(merge_block_id); |
| 223 | if (block) { |
| 224 | ret = block->is_type(type); |
| 225 | } |
| 226 | return ret; |
| 227 | } |
| 228 | |
| 229 | std::pair<const BasicBlock*, bool> Function::GetBlock(uint32_t block_id) const { |
| 230 | const auto b = blocks_.find(block_id); |
| 231 | if (b != end(blocks_)) { |
| 232 | const BasicBlock* block = &(b->second); |
| 233 | bool defined = |
| 234 | undefined_blocks_.find(block->id()) == std::end(undefined_blocks_); |
| 235 | return std::make_pair(block, defined); |
| 236 | } else { |
| 237 | return std::make_pair(nullptr, false); |
| 238 | } |
| 239 | } |
| 240 | |
| 241 | std::pair<BasicBlock*, bool> Function::GetBlock(uint32_t block_id) { |
| 242 | const BasicBlock* out; |
| 243 | bool defined; |
| 244 | std::tie(out, defined) = |
| 245 | const_cast<const Function*>(this)->GetBlock(block_id); |
| 246 | return std::make_pair(const_cast<BasicBlock*>(out), defined); |
| 247 | } |
| 248 | |
| 249 | Function::GetBlocksFunction Function::AugmentedCFGSuccessorsFunction() const { |
| 250 | return [this](const BasicBlock* block) { |
| 251 | auto where = augmented_successors_map_.find(block); |
| 252 | return where == augmented_successors_map_.end() ? block->successors() |
| 253 | : &(*where).second; |
| 254 | }; |
| 255 | } |
| 256 | |
| 257 | Function::GetBlocksFunction |
| 258 | Function::() const { |
| 259 | return [this](const BasicBlock* block) { |
| 260 | auto where = loop_header_successors_plus_continue_target_map_.find(block); |
| 261 | return where == loop_header_successors_plus_continue_target_map_.end() |
| 262 | ? AugmentedCFGSuccessorsFunction()(block) |
| 263 | : &(*where).second; |
| 264 | }; |
| 265 | } |
| 266 | |
| 267 | Function::GetBlocksFunction Function::AugmentedCFGPredecessorsFunction() const { |
| 268 | return [this](const BasicBlock* block) { |
| 269 | auto where = augmented_predecessors_map_.find(block); |
| 270 | return where == augmented_predecessors_map_.end() ? block->predecessors() |
| 271 | : &(*where).second; |
| 272 | }; |
| 273 | } |
| 274 | |
| 275 | void Function::ComputeAugmentedCFG() { |
| 276 | // Compute the successors of the pseudo-entry block, and |
| 277 | // the predecessors of the pseudo exit block. |
| 278 | auto succ_func = [](const BasicBlock* b) { return b->successors(); }; |
| 279 | auto pred_func = [](const BasicBlock* b) { return b->predecessors(); }; |
| 280 | CFA<BasicBlock>::ComputeAugmentedCFG( |
| 281 | ordered_blocks_, &pseudo_entry_block_, &pseudo_exit_block_, |
| 282 | &augmented_successors_map_, &augmented_predecessors_map_, succ_func, |
| 283 | pred_func); |
| 284 | } |
| 285 | |
| 286 | Construct& Function::AddConstruct(const Construct& new_construct) { |
| 287 | cfg_constructs_.push_back(new_construct); |
| 288 | auto& result = cfg_constructs_.back(); |
| 289 | entry_block_to_construct_[std::make_pair(new_construct.entry_block(), |
| 290 | new_construct.type())] = &result; |
| 291 | return result; |
| 292 | } |
| 293 | |
| 294 | Construct& Function::FindConstructForEntryBlock(const BasicBlock* entry_block, |
| 295 | ConstructType type) { |
| 296 | auto where = |
| 297 | entry_block_to_construct_.find(std::make_pair(entry_block, type)); |
| 298 | assert(where != entry_block_to_construct_.end()); |
| 299 | auto construct_ptr = (*where).second; |
| 300 | assert(construct_ptr); |
| 301 | return *construct_ptr; |
| 302 | } |
| 303 | |
| 304 | int Function::GetBlockDepth(BasicBlock* bb) { |
| 305 | // Guard against nullptr. |
| 306 | if (!bb) { |
| 307 | return 0; |
| 308 | } |
| 309 | // Only calculate the depth if it's not already calculated. |
| 310 | // This function uses memoization to avoid duplicate CFG depth calculations. |
| 311 | if (block_depth_.find(bb) != block_depth_.end()) { |
| 312 | return block_depth_[bb]; |
| 313 | } |
| 314 | |
| 315 | BasicBlock* bb_dom = bb->immediate_dominator(); |
| 316 | if (!bb_dom || bb == bb_dom) { |
| 317 | // This block has no dominator, so it's at depth 0. |
| 318 | block_depth_[bb] = 0; |
| 319 | } else if (bb->is_type(kBlockTypeContinue)) { |
| 320 | // This rule must precede the rule for merge blocks in order to set up |
| 321 | // depths correctly. If a block is both a merge and continue then the merge |
| 322 | // is nested within the continue's loop (or the graph is incorrect). |
| 323 | // The depth of the continue block entry point is 1 + loop header depth. |
| 324 | Construct* continue_construct = |
| 325 | entry_block_to_construct_[std::make_pair(bb, ConstructType::kContinue)]; |
| 326 | assert(continue_construct); |
| 327 | // Continue construct has only 1 corresponding construct (loop header). |
| 328 | Construct* loop_construct = |
| 329 | continue_construct->corresponding_constructs()[0]; |
| 330 | assert(loop_construct); |
| 331 | BasicBlock* = loop_construct->entry_block(); |
| 332 | // The continue target may be the loop itself (while 1). |
| 333 | // In such cases, the depth of the continue block is: 1 + depth of the |
| 334 | // loop's dominator block. |
| 335 | if (loop_header == bb) { |
| 336 | block_depth_[bb] = 1 + GetBlockDepth(bb_dom); |
| 337 | } else { |
| 338 | block_depth_[bb] = 1 + GetBlockDepth(loop_header); |
| 339 | } |
| 340 | } else if (bb->is_type(kBlockTypeMerge)) { |
| 341 | // If this is a merge block, its depth is equal to the block before |
| 342 | // branching. |
| 343 | BasicBlock* = merge_block_header_[bb]; |
| 344 | assert(header); |
| 345 | block_depth_[bb] = GetBlockDepth(header); |
| 346 | } else if (bb_dom->is_type(kBlockTypeSelection) || |
| 347 | bb_dom->is_type(kBlockTypeLoop)) { |
| 348 | // The dominator of the given block is a header block. So, the nesting |
| 349 | // depth of this block is: 1 + nesting depth of the header. |
| 350 | block_depth_[bb] = 1 + GetBlockDepth(bb_dom); |
| 351 | } else { |
| 352 | block_depth_[bb] = GetBlockDepth(bb_dom); |
| 353 | } |
| 354 | return block_depth_[bb]; |
| 355 | } |
| 356 | |
| 357 | void Function::RegisterExecutionModelLimitation(SpvExecutionModel model, |
| 358 | const std::string& message) { |
| 359 | execution_model_limitations_.push_back( |
| 360 | [model, message](SpvExecutionModel in_model, std::string* out_message) { |
| 361 | if (model != in_model) { |
| 362 | if (out_message) { |
| 363 | *out_message = message; |
| 364 | } |
| 365 | return false; |
| 366 | } |
| 367 | return true; |
| 368 | }); |
| 369 | } |
| 370 | |
| 371 | bool Function::IsCompatibleWithExecutionModel(SpvExecutionModel model, |
| 372 | std::string* reason) const { |
| 373 | bool return_value = true; |
| 374 | std::stringstream ss_reason; |
| 375 | |
| 376 | for (const auto& is_compatible : execution_model_limitations_) { |
| 377 | std::string message; |
| 378 | if (!is_compatible(model, &message)) { |
| 379 | if (!reason) return false; |
| 380 | return_value = false; |
| 381 | if (!message.empty()) { |
| 382 | ss_reason << message << "\n" ; |
| 383 | } |
| 384 | } |
| 385 | } |
| 386 | |
| 387 | if (!return_value && reason) { |
| 388 | *reason = ss_reason.str(); |
| 389 | } |
| 390 | |
| 391 | return return_value; |
| 392 | } |
| 393 | |
| 394 | bool Function::CheckLimitations(const ValidationState_t& _, |
| 395 | const Function* entry_point, |
| 396 | std::string* reason) const { |
| 397 | bool return_value = true; |
| 398 | std::stringstream ss_reason; |
| 399 | |
| 400 | for (const auto& is_compatible : limitations_) { |
| 401 | std::string message; |
| 402 | if (!is_compatible(_, entry_point, &message)) { |
| 403 | if (!reason) return false; |
| 404 | return_value = false; |
| 405 | if (!message.empty()) { |
| 406 | ss_reason << message << "\n" ; |
| 407 | } |
| 408 | } |
| 409 | } |
| 410 | |
| 411 | if (!return_value && reason) { |
| 412 | *reason = ss_reason.str(); |
| 413 | } |
| 414 | |
| 415 | return return_value; |
| 416 | } |
| 417 | |
| 418 | } // namespace val |
| 419 | } // namespace spvtools |
| 420 | |