1 | // Copyright (c) 2019 Google LLC |
2 | // |
3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | // you may not use this file except in compliance with the License. |
5 | // You may obtain a copy of the License at |
6 | // |
7 | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | // |
9 | // Unless required by applicable law or agreed to in writing, software |
10 | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | // See the License for the specific language governing permissions and |
13 | // limitations under the License. |
14 | |
15 | #include "source/opt/wrap_opkill.h" |
16 | |
17 | #include "ir_builder.h" |
18 | |
19 | namespace spvtools { |
20 | namespace opt { |
21 | |
22 | Pass::Status WrapOpKill::Process() { |
23 | bool modified = false; |
24 | |
25 | auto func_to_process = |
26 | context()->GetStructuredCFGAnalysis()->FindFuncsCalledFromContinue(); |
27 | for (uint32_t func_id : func_to_process) { |
28 | Function* func = context()->GetFunction(func_id); |
29 | bool successful = func->WhileEachInst([this, &modified](Instruction* inst) { |
30 | if (inst->opcode() == SpvOpKill) { |
31 | modified = true; |
32 | if (!ReplaceWithFunctionCall(inst)) { |
33 | return false; |
34 | } |
35 | } |
36 | return true; |
37 | }); |
38 | |
39 | if (!successful) { |
40 | return Status::Failure; |
41 | } |
42 | } |
43 | |
44 | if (opkill_function_ != nullptr) { |
45 | assert(modified && |
46 | "The function should only be generated if something was modified." ); |
47 | context()->AddFunction(std::move(opkill_function_)); |
48 | } |
49 | return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange); |
50 | } |
51 | |
52 | bool WrapOpKill::ReplaceWithFunctionCall(Instruction* inst) { |
53 | assert(inst->opcode() == SpvOpKill && |
54 | "|inst| must be an OpKill instruction." ); |
55 | InstructionBuilder ir_builder( |
56 | context(), inst, |
57 | IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); |
58 | uint32_t func_id = GetOpKillFuncId(); |
59 | if (func_id == 0) { |
60 | return false; |
61 | } |
62 | if (ir_builder.AddFunctionCall(GetVoidTypeId(), func_id, {}) == nullptr) { |
63 | return false; |
64 | } |
65 | |
66 | Instruction* return_inst = nullptr; |
67 | uint32_t return_type_id = GetOwningFunctionsReturnType(inst); |
68 | if (return_type_id != GetVoidTypeId()) { |
69 | Instruction* undef = ir_builder.AddNullaryOp(return_type_id, SpvOpUndef); |
70 | if (undef == nullptr) { |
71 | return false; |
72 | } |
73 | return_inst = |
74 | ir_builder.AddUnaryOp(0, SpvOpReturnValue, undef->result_id()); |
75 | } else { |
76 | return_inst = ir_builder.AddNullaryOp(0, SpvOpReturn); |
77 | } |
78 | |
79 | if (return_inst == nullptr) { |
80 | return false; |
81 | } |
82 | |
83 | context()->KillInst(inst); |
84 | return true; |
85 | } |
86 | |
87 | uint32_t WrapOpKill::GetVoidTypeId() { |
88 | if (void_type_id_ != 0) { |
89 | return void_type_id_; |
90 | } |
91 | |
92 | analysis::TypeManager* type_mgr = context()->get_type_mgr(); |
93 | analysis::Void void_type; |
94 | void_type_id_ = type_mgr->GetTypeInstruction(&void_type); |
95 | return void_type_id_; |
96 | } |
97 | |
98 | uint32_t WrapOpKill::GetVoidFunctionTypeId() { |
99 | analysis::TypeManager* type_mgr = context()->get_type_mgr(); |
100 | analysis::Void void_type; |
101 | const analysis::Type* registered_void_type = |
102 | type_mgr->GetRegisteredType(&void_type); |
103 | |
104 | analysis::Function func_type(registered_void_type, {}); |
105 | return type_mgr->GetTypeInstruction(&func_type); |
106 | } |
107 | |
108 | uint32_t WrapOpKill::GetOpKillFuncId() { |
109 | if (opkill_function_ != nullptr) { |
110 | return opkill_function_->result_id(); |
111 | } |
112 | |
113 | uint32_t opkill_func_id = TakeNextId(); |
114 | if (opkill_func_id == 0) { |
115 | return 0; |
116 | } |
117 | |
118 | uint32_t void_type_id = GetVoidTypeId(); |
119 | if (void_type_id == 0) { |
120 | return 0; |
121 | } |
122 | |
123 | // Generate the function start instruction |
124 | std::unique_ptr<Instruction> func_start(new Instruction( |
125 | context(), SpvOpFunction, void_type_id, opkill_func_id, {})); |
126 | func_start->AddOperand({SPV_OPERAND_TYPE_FUNCTION_CONTROL, {0}}); |
127 | func_start->AddOperand({SPV_OPERAND_TYPE_ID, {GetVoidFunctionTypeId()}}); |
128 | opkill_function_.reset(new Function(std::move(func_start))); |
129 | |
130 | // Generate the function end instruction |
131 | std::unique_ptr<Instruction> func_end( |
132 | new Instruction(context(), SpvOpFunctionEnd, 0, 0, {})); |
133 | opkill_function_->SetFunctionEnd(std::move(func_end)); |
134 | |
135 | // Create the one basic block for the function. |
136 | uint32_t lab_id = TakeNextId(); |
137 | if (lab_id == 0) { |
138 | return 0; |
139 | } |
140 | std::unique_ptr<Instruction> label_inst( |
141 | new Instruction(context(), SpvOpLabel, 0, lab_id, {})); |
142 | std::unique_ptr<BasicBlock> bb(new BasicBlock(std::move(label_inst))); |
143 | |
144 | // Add the OpKill to the basic block |
145 | std::unique_ptr<Instruction> kill_inst( |
146 | new Instruction(context(), SpvOpKill, 0, 0, {})); |
147 | bb->AddInstruction(std::move(kill_inst)); |
148 | |
149 | // Add the bb to the function |
150 | bb->SetParent(opkill_function_.get()); |
151 | opkill_function_->AddBasicBlock(std::move(bb)); |
152 | |
153 | // Add the function to the module. |
154 | if (context()->AreAnalysesValid(IRContext::kAnalysisDefUse)) { |
155 | opkill_function_->ForEachInst( |
156 | [this](Instruction* inst) { context()->AnalyzeDefUse(inst); }); |
157 | } |
158 | |
159 | if (context()->AreAnalysesValid(IRContext::kAnalysisInstrToBlockMapping)) { |
160 | for (BasicBlock& basic_block : *opkill_function_) { |
161 | context()->set_instr_block(basic_block.GetLabelInst(), &basic_block); |
162 | for (Instruction& inst : basic_block) { |
163 | context()->set_instr_block(&inst, &basic_block); |
164 | } |
165 | } |
166 | } |
167 | |
168 | return opkill_function_->result_id(); |
169 | } |
170 | |
171 | uint32_t WrapOpKill::GetOwningFunctionsReturnType(Instruction* inst) { |
172 | BasicBlock* bb = context()->get_instr_block(inst); |
173 | if (bb == nullptr) { |
174 | return 0; |
175 | } |
176 | |
177 | Function* func = bb->GetParent(); |
178 | return func->type_id(); |
179 | } |
180 | |
181 | } // namespace opt |
182 | } // namespace spvtools |
183 | |