1 | // Copyright (c) 2020 The Khronos Group Inc. |
2 | // Copyright (c) 2020 Valve Corporation |
3 | // Copyright (c) 2020 LunarG Inc. |
4 | // |
5 | // Licensed under the Apache License, Version 2.0 (the "License"); |
6 | // you may not use this file except in compliance with the License. |
7 | // You may obtain a copy of the License at |
8 | // |
9 | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | // |
11 | // Unless required by applicable law or agreed to in writing, software |
12 | // distributed under the License is distributed on an "AS IS" BASIS, |
13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
14 | // See the License for the specific language governing permissions and |
15 | // limitations under the License. |
16 | |
17 | #include "inst_debug_printf_pass.h" |
18 | |
19 | #include "spirv/unified1/NonSemanticDebugPrintf.h" |
20 | |
21 | namespace spvtools { |
22 | namespace opt { |
23 | |
24 | void InstDebugPrintfPass::GenOutputValues(Instruction* val_inst, |
25 | std::vector<uint32_t>* val_ids, |
26 | InstructionBuilder* builder) { |
27 | uint32_t val_ty_id = val_inst->type_id(); |
28 | analysis::TypeManager* type_mgr = context()->get_type_mgr(); |
29 | analysis::Type* val_ty = type_mgr->GetType(val_ty_id); |
30 | switch (val_ty->kind()) { |
31 | case analysis::Type::kVector: { |
32 | analysis::Vector* v_ty = val_ty->AsVector(); |
33 | const analysis::Type* c_ty = v_ty->element_type(); |
34 | uint32_t c_ty_id = type_mgr->GetId(c_ty); |
35 | for (uint32_t c = 0; c < v_ty->element_count(); ++c) { |
36 | Instruction* c_inst = builder->AddIdLiteralOp( |
37 | c_ty_id, SpvOpCompositeExtract, val_inst->result_id(), c); |
38 | GenOutputValues(c_inst, val_ids, builder); |
39 | } |
40 | return; |
41 | } |
42 | case analysis::Type::kBool: { |
43 | // Select between uint32 zero or one |
44 | uint32_t zero_id = builder->GetUintConstantId(0); |
45 | uint32_t one_id = builder->GetUintConstantId(1); |
46 | Instruction* sel_inst = builder->AddTernaryOp( |
47 | GetUintId(), SpvOpSelect, val_inst->result_id(), one_id, zero_id); |
48 | val_ids->push_back(sel_inst->result_id()); |
49 | return; |
50 | } |
51 | case analysis::Type::kFloat: { |
52 | analysis::Float* f_ty = val_ty->AsFloat(); |
53 | switch (f_ty->width()) { |
54 | case 16: { |
55 | // Convert float16 to float32 and recurse |
56 | Instruction* f32_inst = builder->AddUnaryOp( |
57 | GetFloatId(), SpvOpFConvert, val_inst->result_id()); |
58 | GenOutputValues(f32_inst, val_ids, builder); |
59 | return; |
60 | } |
61 | case 64: { |
62 | // Bitcast float64 to uint64 and recurse |
63 | Instruction* ui64_inst = builder->AddUnaryOp( |
64 | GetUint64Id(), SpvOpBitcast, val_inst->result_id()); |
65 | GenOutputValues(ui64_inst, val_ids, builder); |
66 | return; |
67 | } |
68 | case 32: { |
69 | // Bitcase float32 to uint32 |
70 | Instruction* bc_inst = builder->AddUnaryOp(GetUintId(), SpvOpBitcast, |
71 | val_inst->result_id()); |
72 | val_ids->push_back(bc_inst->result_id()); |
73 | return; |
74 | } |
75 | default: |
76 | assert(false && "unsupported float width" ); |
77 | return; |
78 | } |
79 | } |
80 | case analysis::Type::kInteger: { |
81 | analysis::Integer* i_ty = val_ty->AsInteger(); |
82 | switch (i_ty->width()) { |
83 | case 64: { |
84 | Instruction* ui64_inst = val_inst; |
85 | if (i_ty->IsSigned()) { |
86 | // Bitcast sint64 to uint64 |
87 | ui64_inst = builder->AddUnaryOp(GetUint64Id(), SpvOpBitcast, |
88 | val_inst->result_id()); |
89 | } |
90 | // Break uint64 into 2x uint32 |
91 | Instruction* lo_ui64_inst = builder->AddUnaryOp( |
92 | GetUintId(), SpvOpUConvert, ui64_inst->result_id()); |
93 | Instruction* rshift_ui64_inst = builder->AddBinaryOp( |
94 | GetUint64Id(), SpvOpShiftRightLogical, ui64_inst->result_id(), |
95 | builder->GetUintConstantId(32)); |
96 | Instruction* hi_ui64_inst = builder->AddUnaryOp( |
97 | GetUintId(), SpvOpUConvert, rshift_ui64_inst->result_id()); |
98 | val_ids->push_back(lo_ui64_inst->result_id()); |
99 | val_ids->push_back(hi_ui64_inst->result_id()); |
100 | return; |
101 | } |
102 | case 8: { |
103 | Instruction* ui8_inst = val_inst; |
104 | if (i_ty->IsSigned()) { |
105 | // Bitcast sint8 to uint8 |
106 | ui8_inst = builder->AddUnaryOp(GetUint8Id(), SpvOpBitcast, |
107 | val_inst->result_id()); |
108 | } |
109 | // Convert uint8 to uint32 |
110 | Instruction* ui32_inst = builder->AddUnaryOp( |
111 | GetUintId(), SpvOpUConvert, ui8_inst->result_id()); |
112 | val_ids->push_back(ui32_inst->result_id()); |
113 | return; |
114 | } |
115 | case 32: { |
116 | Instruction* ui32_inst = val_inst; |
117 | if (i_ty->IsSigned()) { |
118 | // Bitcast sint32 to uint32 |
119 | ui32_inst = builder->AddUnaryOp(GetUintId(), SpvOpBitcast, |
120 | val_inst->result_id()); |
121 | } |
122 | // uint32 needs no further processing |
123 | val_ids->push_back(ui32_inst->result_id()); |
124 | return; |
125 | } |
126 | default: |
127 | // TODO(greg-lunarg): Support non-32-bit int |
128 | assert(false && "unsupported int width" ); |
129 | return; |
130 | } |
131 | } |
132 | default: |
133 | assert(false && "unsupported type" ); |
134 | return; |
135 | } |
136 | } |
137 | |
138 | void InstDebugPrintfPass::GenOutputCode( |
139 | Instruction* printf_inst, uint32_t stage_idx, |
140 | std::vector<std::unique_ptr<BasicBlock>>* new_blocks) { |
141 | BasicBlock* back_blk_ptr = &*new_blocks->back(); |
142 | InstructionBuilder builder( |
143 | context(), back_blk_ptr, |
144 | IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); |
145 | // Gen debug printf record validation-specific values. The format string |
146 | // will have its id written. Vectors will need to be broken down into |
147 | // component values. float16 will need to be converted to float32. Pointer |
148 | // and uint64 will need to be converted to two uint32 values. float32 will |
149 | // need to be bitcast to uint32. int32 will need to be bitcast to uint32. |
150 | std::vector<uint32_t> val_ids; |
151 | bool is_first_operand = false; |
152 | printf_inst->ForEachInId( |
153 | [&is_first_operand, &val_ids, &builder, this](const uint32_t* iid) { |
154 | // skip set operand |
155 | if (!is_first_operand) { |
156 | is_first_operand = true; |
157 | return; |
158 | } |
159 | Instruction* opnd_inst = get_def_use_mgr()->GetDef(*iid); |
160 | if (opnd_inst->opcode() == SpvOpString) { |
161 | uint32_t string_id_id = builder.GetUintConstantId(*iid); |
162 | val_ids.push_back(string_id_id); |
163 | } else { |
164 | GenOutputValues(opnd_inst, &val_ids, &builder); |
165 | } |
166 | }); |
167 | GenDebugStreamWrite(uid2offset_[printf_inst->unique_id()], stage_idx, val_ids, |
168 | &builder); |
169 | context()->KillInst(printf_inst); |
170 | } |
171 | |
172 | void InstDebugPrintfPass::GenDebugPrintfCode( |
173 | BasicBlock::iterator ref_inst_itr, |
174 | UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx, |
175 | std::vector<std::unique_ptr<BasicBlock>>* new_blocks) { |
176 | // If not DebugPrintf OpExtInst, return. |
177 | Instruction* printf_inst = &*ref_inst_itr; |
178 | if (printf_inst->opcode() != SpvOpExtInst) return; |
179 | if (printf_inst->GetSingleWordInOperand(0) != ext_inst_printf_id_) return; |
180 | if (printf_inst->GetSingleWordInOperand(1) != |
181 | NonSemanticDebugPrintfDebugPrintf) |
182 | return; |
183 | // Initialize DefUse manager before dismantling module |
184 | (void)get_def_use_mgr(); |
185 | // Move original block's preceding instructions into first new block |
186 | std::unique_ptr<BasicBlock> new_blk_ptr; |
187 | MovePreludeCode(ref_inst_itr, ref_block_itr, &new_blk_ptr); |
188 | new_blocks->push_back(std::move(new_blk_ptr)); |
189 | // Generate instructions to output printf args to printf buffer |
190 | GenOutputCode(printf_inst, stage_idx, new_blocks); |
191 | // Caller expects at least two blocks with last block containing remaining |
192 | // code, so end block after instrumentation, create remainder block, and |
193 | // branch to it |
194 | uint32_t rem_blk_id = TakeNextId(); |
195 | std::unique_ptr<Instruction> rem_label(NewLabel(rem_blk_id)); |
196 | BasicBlock* back_blk_ptr = &*new_blocks->back(); |
197 | InstructionBuilder builder( |
198 | context(), back_blk_ptr, |
199 | IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); |
200 | (void)builder.AddBranch(rem_blk_id); |
201 | // Gen remainder block |
202 | new_blk_ptr.reset(new BasicBlock(std::move(rem_label))); |
203 | builder.SetInsertPoint(&*new_blk_ptr); |
204 | // Move original block's remaining code into remainder block and add |
205 | // to new blocks |
206 | MovePostludeCode(ref_block_itr, &*new_blk_ptr); |
207 | new_blocks->push_back(std::move(new_blk_ptr)); |
208 | } |
209 | |
210 | void InstDebugPrintfPass::InitializeInstDebugPrintf() { |
211 | // Initialize base class |
212 | InitializeInstrument(); |
213 | } |
214 | |
215 | Pass::Status InstDebugPrintfPass::ProcessImpl() { |
216 | // Perform printf instrumentation on each entry point function in module |
217 | InstProcessFunction pfn = |
218 | [this](BasicBlock::iterator ref_inst_itr, |
219 | UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx, |
220 | std::vector<std::unique_ptr<BasicBlock>>* new_blocks) { |
221 | return GenDebugPrintfCode(ref_inst_itr, ref_block_itr, stage_idx, |
222 | new_blocks); |
223 | }; |
224 | (void)InstProcessEntryPointCallTree(pfn); |
225 | // Remove DebugPrintf OpExtInstImport instruction |
226 | Instruction* ext_inst_import_inst = |
227 | get_def_use_mgr()->GetDef(ext_inst_printf_id_); |
228 | context()->KillInst(ext_inst_import_inst); |
229 | // If no remaining non-semantic instruction sets, remove non-semantic debug |
230 | // info extension from module and feature manager |
231 | bool non_sem_set_seen = false; |
232 | for (auto c_itr = context()->module()->ext_inst_import_begin(); |
233 | c_itr != context()->module()->ext_inst_import_end(); ++c_itr) { |
234 | const char* set_name = |
235 | reinterpret_cast<const char*>(&c_itr->GetInOperand(0).words[0]); |
236 | const char* non_sem_str = "NonSemantic." ; |
237 | if (!strncmp(set_name, non_sem_str, strlen(non_sem_str))) { |
238 | non_sem_set_seen = true; |
239 | break; |
240 | } |
241 | } |
242 | if (!non_sem_set_seen) { |
243 | for (auto c_itr = context()->module()->extension_begin(); |
244 | c_itr != context()->module()->extension_end(); ++c_itr) { |
245 | const char* ext_name = |
246 | reinterpret_cast<const char*>(&c_itr->GetInOperand(0).words[0]); |
247 | if (!strcmp(ext_name, "SPV_KHR_non_semantic_info" )) { |
248 | context()->KillInst(&*c_itr); |
249 | break; |
250 | } |
251 | } |
252 | context()->get_feature_mgr()->RemoveExtension(kSPV_KHR_non_semantic_info); |
253 | } |
254 | return Status::SuccessWithChange; |
255 | } |
256 | |
257 | Pass::Status InstDebugPrintfPass::Process() { |
258 | ext_inst_printf_id_ = |
259 | get_module()->GetExtInstImportId("NonSemantic.DebugPrintf" ); |
260 | if (ext_inst_printf_id_ == 0) return Status::SuccessWithoutChange; |
261 | InitializeInstDebugPrintf(); |
262 | return ProcessImpl(); |
263 | } |
264 | |
265 | } // namespace opt |
266 | } // namespace spvtools |
267 | |