1// Copyright (c) 2020 The Khronos Group Inc.
2// Copyright (c) 2020 Valve Corporation
3// Copyright (c) 2020 LunarG Inc.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16
17#include "inst_debug_printf_pass.h"
18
19#include "spirv/unified1/NonSemanticDebugPrintf.h"
20
21namespace spvtools {
22namespace opt {
23
24void InstDebugPrintfPass::GenOutputValues(Instruction* val_inst,
25 std::vector<uint32_t>* val_ids,
26 InstructionBuilder* builder) {
27 uint32_t val_ty_id = val_inst->type_id();
28 analysis::TypeManager* type_mgr = context()->get_type_mgr();
29 analysis::Type* val_ty = type_mgr->GetType(val_ty_id);
30 switch (val_ty->kind()) {
31 case analysis::Type::kVector: {
32 analysis::Vector* v_ty = val_ty->AsVector();
33 const analysis::Type* c_ty = v_ty->element_type();
34 uint32_t c_ty_id = type_mgr->GetId(c_ty);
35 for (uint32_t c = 0; c < v_ty->element_count(); ++c) {
36 Instruction* c_inst = builder->AddIdLiteralOp(
37 c_ty_id, SpvOpCompositeExtract, val_inst->result_id(), c);
38 GenOutputValues(c_inst, val_ids, builder);
39 }
40 return;
41 }
42 case analysis::Type::kBool: {
43 // Select between uint32 zero or one
44 uint32_t zero_id = builder->GetUintConstantId(0);
45 uint32_t one_id = builder->GetUintConstantId(1);
46 Instruction* sel_inst = builder->AddTernaryOp(
47 GetUintId(), SpvOpSelect, val_inst->result_id(), one_id, zero_id);
48 val_ids->push_back(sel_inst->result_id());
49 return;
50 }
51 case analysis::Type::kFloat: {
52 analysis::Float* f_ty = val_ty->AsFloat();
53 switch (f_ty->width()) {
54 case 16: {
55 // Convert float16 to float32 and recurse
56 Instruction* f32_inst = builder->AddUnaryOp(
57 GetFloatId(), SpvOpFConvert, val_inst->result_id());
58 GenOutputValues(f32_inst, val_ids, builder);
59 return;
60 }
61 case 64: {
62 // Bitcast float64 to uint64 and recurse
63 Instruction* ui64_inst = builder->AddUnaryOp(
64 GetUint64Id(), SpvOpBitcast, val_inst->result_id());
65 GenOutputValues(ui64_inst, val_ids, builder);
66 return;
67 }
68 case 32: {
69 // Bitcase float32 to uint32
70 Instruction* bc_inst = builder->AddUnaryOp(GetUintId(), SpvOpBitcast,
71 val_inst->result_id());
72 val_ids->push_back(bc_inst->result_id());
73 return;
74 }
75 default:
76 assert(false && "unsupported float width");
77 return;
78 }
79 }
80 case analysis::Type::kInteger: {
81 analysis::Integer* i_ty = val_ty->AsInteger();
82 switch (i_ty->width()) {
83 case 64: {
84 Instruction* ui64_inst = val_inst;
85 if (i_ty->IsSigned()) {
86 // Bitcast sint64 to uint64
87 ui64_inst = builder->AddUnaryOp(GetUint64Id(), SpvOpBitcast,
88 val_inst->result_id());
89 }
90 // Break uint64 into 2x uint32
91 Instruction* lo_ui64_inst = builder->AddUnaryOp(
92 GetUintId(), SpvOpUConvert, ui64_inst->result_id());
93 Instruction* rshift_ui64_inst = builder->AddBinaryOp(
94 GetUint64Id(), SpvOpShiftRightLogical, ui64_inst->result_id(),
95 builder->GetUintConstantId(32));
96 Instruction* hi_ui64_inst = builder->AddUnaryOp(
97 GetUintId(), SpvOpUConvert, rshift_ui64_inst->result_id());
98 val_ids->push_back(lo_ui64_inst->result_id());
99 val_ids->push_back(hi_ui64_inst->result_id());
100 return;
101 }
102 case 8: {
103 Instruction* ui8_inst = val_inst;
104 if (i_ty->IsSigned()) {
105 // Bitcast sint8 to uint8
106 ui8_inst = builder->AddUnaryOp(GetUint8Id(), SpvOpBitcast,
107 val_inst->result_id());
108 }
109 // Convert uint8 to uint32
110 Instruction* ui32_inst = builder->AddUnaryOp(
111 GetUintId(), SpvOpUConvert, ui8_inst->result_id());
112 val_ids->push_back(ui32_inst->result_id());
113 return;
114 }
115 case 32: {
116 Instruction* ui32_inst = val_inst;
117 if (i_ty->IsSigned()) {
118 // Bitcast sint32 to uint32
119 ui32_inst = builder->AddUnaryOp(GetUintId(), SpvOpBitcast,
120 val_inst->result_id());
121 }
122 // uint32 needs no further processing
123 val_ids->push_back(ui32_inst->result_id());
124 return;
125 }
126 default:
127 // TODO(greg-lunarg): Support non-32-bit int
128 assert(false && "unsupported int width");
129 return;
130 }
131 }
132 default:
133 assert(false && "unsupported type");
134 return;
135 }
136}
137
138void InstDebugPrintfPass::GenOutputCode(
139 Instruction* printf_inst, uint32_t stage_idx,
140 std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
141 BasicBlock* back_blk_ptr = &*new_blocks->back();
142 InstructionBuilder builder(
143 context(), back_blk_ptr,
144 IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
145 // Gen debug printf record validation-specific values. The format string
146 // will have its id written. Vectors will need to be broken down into
147 // component values. float16 will need to be converted to float32. Pointer
148 // and uint64 will need to be converted to two uint32 values. float32 will
149 // need to be bitcast to uint32. int32 will need to be bitcast to uint32.
150 std::vector<uint32_t> val_ids;
151 bool is_first_operand = false;
152 printf_inst->ForEachInId(
153 [&is_first_operand, &val_ids, &builder, this](const uint32_t* iid) {
154 // skip set operand
155 if (!is_first_operand) {
156 is_first_operand = true;
157 return;
158 }
159 Instruction* opnd_inst = get_def_use_mgr()->GetDef(*iid);
160 if (opnd_inst->opcode() == SpvOpString) {
161 uint32_t string_id_id = builder.GetUintConstantId(*iid);
162 val_ids.push_back(string_id_id);
163 } else {
164 GenOutputValues(opnd_inst, &val_ids, &builder);
165 }
166 });
167 GenDebugStreamWrite(uid2offset_[printf_inst->unique_id()], stage_idx, val_ids,
168 &builder);
169 context()->KillInst(printf_inst);
170}
171
172void InstDebugPrintfPass::GenDebugPrintfCode(
173 BasicBlock::iterator ref_inst_itr,
174 UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
175 std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
176 // If not DebugPrintf OpExtInst, return.
177 Instruction* printf_inst = &*ref_inst_itr;
178 if (printf_inst->opcode() != SpvOpExtInst) return;
179 if (printf_inst->GetSingleWordInOperand(0) != ext_inst_printf_id_) return;
180 if (printf_inst->GetSingleWordInOperand(1) !=
181 NonSemanticDebugPrintfDebugPrintf)
182 return;
183 // Initialize DefUse manager before dismantling module
184 (void)get_def_use_mgr();
185 // Move original block's preceding instructions into first new block
186 std::unique_ptr<BasicBlock> new_blk_ptr;
187 MovePreludeCode(ref_inst_itr, ref_block_itr, &new_blk_ptr);
188 new_blocks->push_back(std::move(new_blk_ptr));
189 // Generate instructions to output printf args to printf buffer
190 GenOutputCode(printf_inst, stage_idx, new_blocks);
191 // Caller expects at least two blocks with last block containing remaining
192 // code, so end block after instrumentation, create remainder block, and
193 // branch to it
194 uint32_t rem_blk_id = TakeNextId();
195 std::unique_ptr<Instruction> rem_label(NewLabel(rem_blk_id));
196 BasicBlock* back_blk_ptr = &*new_blocks->back();
197 InstructionBuilder builder(
198 context(), back_blk_ptr,
199 IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
200 (void)builder.AddBranch(rem_blk_id);
201 // Gen remainder block
202 new_blk_ptr.reset(new BasicBlock(std::move(rem_label)));
203 builder.SetInsertPoint(&*new_blk_ptr);
204 // Move original block's remaining code into remainder block and add
205 // to new blocks
206 MovePostludeCode(ref_block_itr, &*new_blk_ptr);
207 new_blocks->push_back(std::move(new_blk_ptr));
208}
209
210void InstDebugPrintfPass::InitializeInstDebugPrintf() {
211 // Initialize base class
212 InitializeInstrument();
213}
214
215Pass::Status InstDebugPrintfPass::ProcessImpl() {
216 // Perform printf instrumentation on each entry point function in module
217 InstProcessFunction pfn =
218 [this](BasicBlock::iterator ref_inst_itr,
219 UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
220 std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
221 return GenDebugPrintfCode(ref_inst_itr, ref_block_itr, stage_idx,
222 new_blocks);
223 };
224 (void)InstProcessEntryPointCallTree(pfn);
225 // Remove DebugPrintf OpExtInstImport instruction
226 Instruction* ext_inst_import_inst =
227 get_def_use_mgr()->GetDef(ext_inst_printf_id_);
228 context()->KillInst(ext_inst_import_inst);
229 // If no remaining non-semantic instruction sets, remove non-semantic debug
230 // info extension from module and feature manager
231 bool non_sem_set_seen = false;
232 for (auto c_itr = context()->module()->ext_inst_import_begin();
233 c_itr != context()->module()->ext_inst_import_end(); ++c_itr) {
234 const char* set_name =
235 reinterpret_cast<const char*>(&c_itr->GetInOperand(0).words[0]);
236 const char* non_sem_str = "NonSemantic.";
237 if (!strncmp(set_name, non_sem_str, strlen(non_sem_str))) {
238 non_sem_set_seen = true;
239 break;
240 }
241 }
242 if (!non_sem_set_seen) {
243 for (auto c_itr = context()->module()->extension_begin();
244 c_itr != context()->module()->extension_end(); ++c_itr) {
245 const char* ext_name =
246 reinterpret_cast<const char*>(&c_itr->GetInOperand(0).words[0]);
247 if (!strcmp(ext_name, "SPV_KHR_non_semantic_info")) {
248 context()->KillInst(&*c_itr);
249 break;
250 }
251 }
252 context()->get_feature_mgr()->RemoveExtension(kSPV_KHR_non_semantic_info);
253 }
254 return Status::SuccessWithChange;
255}
256
257Pass::Status InstDebugPrintfPass::Process() {
258 ext_inst_printf_id_ =
259 get_module()->GetExtInstImportId("NonSemantic.DebugPrintf");
260 if (ext_inst_printf_id_ == 0) return Status::SuccessWithoutChange;
261 InitializeInstDebugPrintf();
262 return ProcessImpl();
263}
264
265} // namespace opt
266} // namespace spvtools
267