1// Copyright (c) 2019 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "source/opt/wrap_opkill.h"
16
17#include "ir_builder.h"
18
19namespace spvtools {
20namespace opt {
21
22Pass::Status WrapOpKill::Process() {
23 bool modified = false;
24
25 auto func_to_process =
26 context()->GetStructuredCFGAnalysis()->FindFuncsCalledFromContinue();
27 for (uint32_t func_id : func_to_process) {
28 Function* func = context()->GetFunction(func_id);
29 bool successful = func->WhileEachInst([this, &modified](Instruction* inst) {
30 if (inst->opcode() == SpvOpKill) {
31 modified = true;
32 if (!ReplaceWithFunctionCall(inst)) {
33 return false;
34 }
35 }
36 return true;
37 });
38
39 if (!successful) {
40 return Status::Failure;
41 }
42 }
43
44 if (opkill_function_ != nullptr) {
45 assert(modified &&
46 "The function should only be generated if something was modified.");
47 context()->AddFunction(std::move(opkill_function_));
48 }
49 return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
50}
51
52bool WrapOpKill::ReplaceWithFunctionCall(Instruction* inst) {
53 assert(inst->opcode() == SpvOpKill &&
54 "|inst| must be an OpKill instruction.");
55 InstructionBuilder ir_builder(
56 context(), inst,
57 IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
58 uint32_t func_id = GetOpKillFuncId();
59 if (func_id == 0) {
60 return false;
61 }
62 if (ir_builder.AddFunctionCall(GetVoidTypeId(), func_id, {}) == nullptr) {
63 return false;
64 }
65
66 Instruction* return_inst = nullptr;
67 uint32_t return_type_id = GetOwningFunctionsReturnType(inst);
68 if (return_type_id != GetVoidTypeId()) {
69 Instruction* undef = ir_builder.AddNullaryOp(return_type_id, SpvOpUndef);
70 if (undef == nullptr) {
71 return false;
72 }
73 return_inst =
74 ir_builder.AddUnaryOp(0, SpvOpReturnValue, undef->result_id());
75 } else {
76 return_inst = ir_builder.AddNullaryOp(0, SpvOpReturn);
77 }
78
79 if (return_inst == nullptr) {
80 return false;
81 }
82
83 context()->KillInst(inst);
84 return true;
85}
86
87uint32_t WrapOpKill::GetVoidTypeId() {
88 if (void_type_id_ != 0) {
89 return void_type_id_;
90 }
91
92 analysis::TypeManager* type_mgr = context()->get_type_mgr();
93 analysis::Void void_type;
94 void_type_id_ = type_mgr->GetTypeInstruction(&void_type);
95 return void_type_id_;
96}
97
98uint32_t WrapOpKill::GetVoidFunctionTypeId() {
99 analysis::TypeManager* type_mgr = context()->get_type_mgr();
100 analysis::Void void_type;
101 const analysis::Type* registered_void_type =
102 type_mgr->GetRegisteredType(&void_type);
103
104 analysis::Function func_type(registered_void_type, {});
105 return type_mgr->GetTypeInstruction(&func_type);
106}
107
108uint32_t WrapOpKill::GetOpKillFuncId() {
109 if (opkill_function_ != nullptr) {
110 return opkill_function_->result_id();
111 }
112
113 uint32_t opkill_func_id = TakeNextId();
114 if (opkill_func_id == 0) {
115 return 0;
116 }
117
118 uint32_t void_type_id = GetVoidTypeId();
119 if (void_type_id == 0) {
120 return 0;
121 }
122
123 // Generate the function start instruction
124 std::unique_ptr<Instruction> func_start(new Instruction(
125 context(), SpvOpFunction, void_type_id, opkill_func_id, {}));
126 func_start->AddOperand({SPV_OPERAND_TYPE_FUNCTION_CONTROL, {0}});
127 func_start->AddOperand({SPV_OPERAND_TYPE_ID, {GetVoidFunctionTypeId()}});
128 opkill_function_.reset(new Function(std::move(func_start)));
129
130 // Generate the function end instruction
131 std::unique_ptr<Instruction> func_end(
132 new Instruction(context(), SpvOpFunctionEnd, 0, 0, {}));
133 opkill_function_->SetFunctionEnd(std::move(func_end));
134
135 // Create the one basic block for the function.
136 uint32_t lab_id = TakeNextId();
137 if (lab_id == 0) {
138 return 0;
139 }
140 std::unique_ptr<Instruction> label_inst(
141 new Instruction(context(), SpvOpLabel, 0, lab_id, {}));
142 std::unique_ptr<BasicBlock> bb(new BasicBlock(std::move(label_inst)));
143
144 // Add the OpKill to the basic block
145 std::unique_ptr<Instruction> kill_inst(
146 new Instruction(context(), SpvOpKill, 0, 0, {}));
147 bb->AddInstruction(std::move(kill_inst));
148
149 // Add the bb to the function
150 bb->SetParent(opkill_function_.get());
151 opkill_function_->AddBasicBlock(std::move(bb));
152
153 // Add the function to the module.
154 if (context()->AreAnalysesValid(IRContext::kAnalysisDefUse)) {
155 opkill_function_->ForEachInst(
156 [this](Instruction* inst) { context()->AnalyzeDefUse(inst); });
157 }
158
159 if (context()->AreAnalysesValid(IRContext::kAnalysisInstrToBlockMapping)) {
160 for (BasicBlock& basic_block : *opkill_function_) {
161 context()->set_instr_block(basic_block.GetLabelInst(), &basic_block);
162 for (Instruction& inst : basic_block) {
163 context()->set_instr_block(&inst, &basic_block);
164 }
165 }
166 }
167
168 return opkill_function_->result_id();
169}
170
171uint32_t WrapOpKill::GetOwningFunctionsReturnType(Instruction* inst) {
172 BasicBlock* bb = context()->get_instr_block(inst);
173 if (bb == nullptr) {
174 return 0;
175 }
176
177 Function* func = bb->GetParent();
178 return func->type_id();
179}
180
181} // namespace opt
182} // namespace spvtools
183