1 | // Copyright (c) 2017 The Khronos Group Inc. |
2 | // Copyright (c) 2017 Valve Corporation |
3 | // Copyright (c) 2017 LunarG Inc. |
4 | // |
5 | // Licensed under the Apache License, Version 2.0 (the "License"); |
6 | // you may not use this file except in compliance with the License. |
7 | // You may obtain a copy of the License at |
8 | // |
9 | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | // |
11 | // Unless required by applicable law or agreed to in writing, software |
12 | // distributed under the License is distributed on an "AS IS" BASIS, |
13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
14 | // See the License for the specific language governing permissions and |
15 | // limitations under the License. |
16 | |
17 | #include "source/opt/local_access_chain_convert_pass.h" |
18 | |
19 | #include "ir_builder.h" |
20 | #include "ir_context.h" |
21 | #include "iterator.h" |
22 | |
23 | namespace spvtools { |
24 | namespace opt { |
25 | |
namespace {

// In-operand index of the value being stored by an OpStore.
constexpr uint32_t kStoreValIdInIdx = 1;
// In-operand index of the base pointer of an access chain.
constexpr uint32_t kAccessChainPtrIdInIdx = 0;
// In-operand index of the literal value word of an OpConstant.
constexpr uint32_t kConstantValueInIdx = 0;
// In-operand index of the width literal of an OpTypeInt.
constexpr uint32_t kTypeIntWidthInIdx = 0;

}  // anonymous namespace
34 | |
35 | void LocalAccessChainConvertPass::BuildAndAppendInst( |
36 | SpvOp opcode, uint32_t typeId, uint32_t resultId, |
37 | const std::vector<Operand>& in_opnds, |
38 | std::vector<std::unique_ptr<Instruction>>* newInsts) { |
39 | std::unique_ptr<Instruction> newInst( |
40 | new Instruction(context(), opcode, typeId, resultId, in_opnds)); |
41 | get_def_use_mgr()->AnalyzeInstDefUse(&*newInst); |
42 | newInsts->emplace_back(std::move(newInst)); |
43 | } |
44 | |
45 | uint32_t LocalAccessChainConvertPass::BuildAndAppendVarLoad( |
46 | const Instruction* ptrInst, uint32_t* varId, uint32_t* varPteTypeId, |
47 | std::vector<std::unique_ptr<Instruction>>* newInsts) { |
48 | const uint32_t ldResultId = TakeNextId(); |
49 | if (ldResultId == 0) { |
50 | return 0; |
51 | } |
52 | |
53 | *varId = ptrInst->GetSingleWordInOperand(kAccessChainPtrIdInIdx); |
54 | const Instruction* varInst = get_def_use_mgr()->GetDef(*varId); |
55 | assert(varInst->opcode() == SpvOpVariable); |
56 | *varPteTypeId = GetPointeeTypeId(varInst); |
57 | BuildAndAppendInst(SpvOpLoad, *varPteTypeId, ldResultId, |
58 | {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {*varId}}}, |
59 | newInsts); |
60 | return ldResultId; |
61 | } |
62 | |
63 | void LocalAccessChainConvertPass::AppendConstantOperands( |
64 | const Instruction* ptrInst, std::vector<Operand>* in_opnds) { |
65 | uint32_t iidIdx = 0; |
66 | ptrInst->ForEachInId([&iidIdx, &in_opnds, this](const uint32_t* iid) { |
67 | if (iidIdx > 0) { |
68 | const Instruction* cInst = get_def_use_mgr()->GetDef(*iid); |
69 | uint32_t val = cInst->GetSingleWordInOperand(kConstantValueInIdx); |
70 | in_opnds->push_back( |
71 | {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {val}}); |
72 | } |
73 | ++iidIdx; |
74 | }); |
75 | } |
76 | |
77 | bool LocalAccessChainConvertPass::ReplaceAccessChainLoad( |
78 | const Instruction* address_inst, Instruction* original_load) { |
79 | // Build and append load of variable in ptrInst |
80 | std::vector<std::unique_ptr<Instruction>> new_inst; |
81 | uint32_t varId; |
82 | uint32_t varPteTypeId; |
83 | const uint32_t ldResultId = |
84 | BuildAndAppendVarLoad(address_inst, &varId, &varPteTypeId, &new_inst); |
85 | if (ldResultId == 0) { |
86 | return false; |
87 | } |
88 | |
89 | context()->get_decoration_mgr()->CloneDecorations( |
90 | original_load->result_id(), ldResultId, {SpvDecorationRelaxedPrecision}); |
91 | original_load->InsertBefore(std::move(new_inst)); |
92 | |
93 | // Rewrite |original_load| into an extract. |
94 | Instruction::OperandList new_operands; |
95 | |
96 | // copy the result id and the type id to the new operand list. |
97 | new_operands.emplace_back(original_load->GetOperand(0)); |
98 | new_operands.emplace_back(original_load->GetOperand(1)); |
99 | |
100 | new_operands.emplace_back( |
101 | Operand({spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ldResultId}})); |
102 | AppendConstantOperands(address_inst, &new_operands); |
103 | original_load->SetOpcode(SpvOpCompositeExtract); |
104 | original_load->ReplaceOperands(new_operands); |
105 | context()->UpdateDefUse(original_load); |
106 | return true; |
107 | } |
108 | |
109 | bool LocalAccessChainConvertPass::GenAccessChainStoreReplacement( |
110 | const Instruction* ptrInst, uint32_t valId, |
111 | std::vector<std::unique_ptr<Instruction>>* newInsts) { |
112 | // Build and append load of variable in ptrInst |
113 | uint32_t varId; |
114 | uint32_t varPteTypeId; |
115 | const uint32_t ldResultId = |
116 | BuildAndAppendVarLoad(ptrInst, &varId, &varPteTypeId, newInsts); |
117 | if (ldResultId == 0) { |
118 | return false; |
119 | } |
120 | |
121 | context()->get_decoration_mgr()->CloneDecorations( |
122 | varId, ldResultId, {SpvDecorationRelaxedPrecision}); |
123 | |
124 | // Build and append Insert |
125 | const uint32_t insResultId = TakeNextId(); |
126 | if (insResultId == 0) { |
127 | return false; |
128 | } |
129 | std::vector<Operand> ins_in_opnds = { |
130 | {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {valId}}, |
131 | {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ldResultId}}}; |
132 | AppendConstantOperands(ptrInst, &ins_in_opnds); |
133 | BuildAndAppendInst(SpvOpCompositeInsert, varPteTypeId, insResultId, |
134 | ins_in_opnds, newInsts); |
135 | |
136 | context()->get_decoration_mgr()->CloneDecorations( |
137 | varId, insResultId, {SpvDecorationRelaxedPrecision}); |
138 | |
139 | // Build and append Store |
140 | BuildAndAppendInst(SpvOpStore, 0, 0, |
141 | {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {varId}}, |
142 | {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {insResultId}}}, |
143 | newInsts); |
144 | return true; |
145 | } |
146 | |
147 | bool LocalAccessChainConvertPass::IsConstantIndexAccessChain( |
148 | const Instruction* acp) const { |
149 | uint32_t inIdx = 0; |
150 | return acp->WhileEachInId([&inIdx, this](const uint32_t* tid) { |
151 | if (inIdx > 0) { |
152 | Instruction* opInst = get_def_use_mgr()->GetDef(*tid); |
153 | if (opInst->opcode() != SpvOpConstant) return false; |
154 | } |
155 | ++inIdx; |
156 | return true; |
157 | }); |
158 | } |
159 | |
160 | bool LocalAccessChainConvertPass::HasOnlySupportedRefs(uint32_t ptrId) { |
161 | if (supported_ref_ptrs_.find(ptrId) != supported_ref_ptrs_.end()) return true; |
162 | if (get_def_use_mgr()->WhileEachUser(ptrId, [this](Instruction* user) { |
163 | SpvOp op = user->opcode(); |
164 | if (IsNonPtrAccessChain(op) || op == SpvOpCopyObject) { |
165 | if (!HasOnlySupportedRefs(user->result_id())) { |
166 | return false; |
167 | } |
168 | } else if (op != SpvOpStore && op != SpvOpLoad && op != SpvOpName && |
169 | !IsNonTypeDecorate(op)) { |
170 | return false; |
171 | } |
172 | return true; |
173 | })) { |
174 | supported_ref_ptrs_.insert(ptrId); |
175 | return true; |
176 | } |
177 | return false; |
178 | } |
179 | |
180 | void LocalAccessChainConvertPass::FindTargetVars(Function* func) { |
181 | for (auto bi = func->begin(); bi != func->end(); ++bi) { |
182 | for (auto ii = bi->begin(); ii != bi->end(); ++ii) { |
183 | switch (ii->opcode()) { |
184 | case SpvOpStore: |
185 | case SpvOpLoad: { |
186 | uint32_t varId; |
187 | Instruction* ptrInst = GetPtr(&*ii, &varId); |
188 | if (!IsTargetVar(varId)) break; |
189 | const SpvOp op = ptrInst->opcode(); |
190 | // Rule out variables with non-supported refs eg function calls |
191 | if (!HasOnlySupportedRefs(varId)) { |
192 | seen_non_target_vars_.insert(varId); |
193 | seen_target_vars_.erase(varId); |
194 | break; |
195 | } |
196 | // Rule out variables with nested access chains |
197 | // TODO(): Convert nested access chains |
198 | if (IsNonPtrAccessChain(op) && ptrInst->GetSingleWordInOperand( |
199 | kAccessChainPtrIdInIdx) != varId) { |
200 | seen_non_target_vars_.insert(varId); |
201 | seen_target_vars_.erase(varId); |
202 | break; |
203 | } |
204 | // Rule out variables accessed with non-constant indices |
205 | if (!IsConstantIndexAccessChain(ptrInst)) { |
206 | seen_non_target_vars_.insert(varId); |
207 | seen_target_vars_.erase(varId); |
208 | break; |
209 | } |
210 | } break; |
211 | default: |
212 | break; |
213 | } |
214 | } |
215 | } |
216 | } |
217 | |
218 | Pass::Status LocalAccessChainConvertPass::ConvertLocalAccessChains( |
219 | Function* func) { |
220 | FindTargetVars(func); |
221 | // Replace access chains of all targeted variables with equivalent |
222 | // extract and insert sequences |
223 | bool modified = false; |
224 | for (auto bi = func->begin(); bi != func->end(); ++bi) { |
225 | std::vector<Instruction*> dead_instructions; |
226 | for (auto ii = bi->begin(); ii != bi->end(); ++ii) { |
227 | switch (ii->opcode()) { |
228 | case SpvOpLoad: { |
229 | uint32_t varId; |
230 | Instruction* ptrInst = GetPtr(&*ii, &varId); |
231 | if (!IsNonPtrAccessChain(ptrInst->opcode())) break; |
232 | if (!IsTargetVar(varId)) break; |
233 | std::vector<std::unique_ptr<Instruction>> newInsts; |
234 | if (!ReplaceAccessChainLoad(ptrInst, &*ii)) { |
235 | return Status::Failure; |
236 | } |
237 | modified = true; |
238 | } break; |
239 | case SpvOpStore: { |
240 | uint32_t varId; |
241 | Instruction* ptrInst = GetPtr(&*ii, &varId); |
242 | if (!IsNonPtrAccessChain(ptrInst->opcode())) break; |
243 | if (!IsTargetVar(varId)) break; |
244 | std::vector<std::unique_ptr<Instruction>> newInsts; |
245 | uint32_t valId = ii->GetSingleWordInOperand(kStoreValIdInIdx); |
246 | if (!GenAccessChainStoreReplacement(ptrInst, valId, &newInsts)) { |
247 | return Status::Failure; |
248 | } |
249 | dead_instructions.push_back(&*ii); |
250 | ++ii; |
251 | ii = ii.InsertBefore(std::move(newInsts)); |
252 | ++ii; |
253 | ++ii; |
254 | modified = true; |
255 | } break; |
256 | default: |
257 | break; |
258 | } |
259 | } |
260 | |
261 | while (!dead_instructions.empty()) { |
262 | Instruction* inst = dead_instructions.back(); |
263 | dead_instructions.pop_back(); |
264 | DCEInst(inst, [&dead_instructions](Instruction* other_inst) { |
265 | auto i = std::find(dead_instructions.begin(), dead_instructions.end(), |
266 | other_inst); |
267 | if (i != dead_instructions.end()) { |
268 | dead_instructions.erase(i); |
269 | } |
270 | }); |
271 | } |
272 | } |
273 | return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange); |
274 | } |
275 | |
276 | void LocalAccessChainConvertPass::Initialize() { |
277 | // Initialize Target Variable Caches |
278 | seen_target_vars_.clear(); |
279 | seen_non_target_vars_.clear(); |
280 | |
281 | // Initialize collections |
282 | supported_ref_ptrs_.clear(); |
283 | |
284 | // Initialize extension whitelist |
285 | InitExtensions(); |
286 | } |
287 | |
288 | bool LocalAccessChainConvertPass::AllExtensionsSupported() const { |
289 | // This capability can now exist without the extension, so we have to check |
290 | // for the capability. This pass is only looking at function scope symbols, |
291 | // so we do not care if there are variable pointers on storage buffers. |
292 | if (context()->get_feature_mgr()->HasCapability( |
293 | SpvCapabilityVariablePointers)) |
294 | return false; |
295 | // If any extension not in whitelist, return false |
296 | for (auto& ei : get_module()->extensions()) { |
297 | const char* extName = |
298 | reinterpret_cast<const char*>(&ei.GetInOperand(0).words[0]); |
299 | if (extensions_whitelist_.find(extName) == extensions_whitelist_.end()) |
300 | return false; |
301 | } |
302 | return true; |
303 | } |
304 | |
305 | Pass::Status LocalAccessChainConvertPass::ProcessImpl() { |
306 | // If non-32-bit integer type in module, terminate processing |
307 | // TODO(): Handle non-32-bit integer constants in access chains |
308 | for (const Instruction& inst : get_module()->types_values()) |
309 | if (inst.opcode() == SpvOpTypeInt && |
310 | inst.GetSingleWordInOperand(kTypeIntWidthInIdx) != 32) |
311 | return Status::SuccessWithoutChange; |
312 | // Do not process if module contains OpGroupDecorate. Additional |
313 | // support required in KillNamesAndDecorates(). |
314 | // TODO(greg-lunarg): Add support for OpGroupDecorate |
315 | for (auto& ai : get_module()->annotations()) |
316 | if (ai.opcode() == SpvOpGroupDecorate) return Status::SuccessWithoutChange; |
317 | // Do not process if any disallowed extensions are enabled |
318 | if (!AllExtensionsSupported()) return Status::SuccessWithoutChange; |
319 | |
320 | // Process all functions in the module. |
321 | Status status = Status::SuccessWithoutChange; |
322 | for (Function& func : *get_module()) { |
323 | status = CombineStatus(status, ConvertLocalAccessChains(&func)); |
324 | if (status == Status::Failure) { |
325 | break; |
326 | } |
327 | } |
328 | return status; |
329 | } |
330 | |
331 | LocalAccessChainConvertPass::LocalAccessChainConvertPass() {} |
332 | |
333 | Pass::Status LocalAccessChainConvertPass::Process() { |
334 | Initialize(); |
335 | return ProcessImpl(); |
336 | } |
337 | |
338 | void LocalAccessChainConvertPass::InitExtensions() { |
339 | extensions_whitelist_.clear(); |
340 | extensions_whitelist_.insert({ |
341 | "SPV_AMD_shader_explicit_vertex_parameter" , |
342 | "SPV_AMD_shader_trinary_minmax" , |
343 | "SPV_AMD_gcn_shader" , |
344 | "SPV_KHR_shader_ballot" , |
345 | "SPV_AMD_shader_ballot" , |
346 | "SPV_AMD_gpu_shader_half_float" , |
347 | "SPV_KHR_shader_draw_parameters" , |
348 | "SPV_KHR_subgroup_vote" , |
349 | "SPV_KHR_16bit_storage" , |
350 | "SPV_KHR_device_group" , |
351 | "SPV_KHR_multiview" , |
352 | "SPV_NVX_multiview_per_view_attributes" , |
353 | "SPV_NV_viewport_array2" , |
354 | "SPV_NV_stereo_view_rendering" , |
355 | "SPV_NV_sample_mask_override_coverage" , |
356 | "SPV_NV_geometry_shader_passthrough" , |
357 | "SPV_AMD_texture_gather_bias_lod" , |
358 | "SPV_KHR_storage_buffer_storage_class" , |
359 | // SPV_KHR_variable_pointers |
360 | // Currently do not support extended pointer expressions |
361 | "SPV_AMD_gpu_shader_int16" , |
362 | "SPV_KHR_post_depth_coverage" , |
363 | "SPV_KHR_shader_atomic_counter_ops" , |
364 | "SPV_EXT_shader_stencil_export" , |
365 | "SPV_EXT_shader_viewport_index_layer" , |
366 | "SPV_AMD_shader_image_load_store_lod" , |
367 | "SPV_AMD_shader_fragment_mask" , |
368 | "SPV_EXT_fragment_fully_covered" , |
369 | "SPV_AMD_gpu_shader_half_float_fetch" , |
370 | "SPV_GOOGLE_decorate_string" , |
371 | "SPV_GOOGLE_hlsl_functionality1" , |
372 | "SPV_GOOGLE_user_type" , |
373 | "SPV_NV_shader_subgroup_partitioned" , |
374 | "SPV_EXT_demote_to_helper_invocation" , |
375 | "SPV_EXT_descriptor_indexing" , |
376 | "SPV_NV_fragment_shader_barycentric" , |
377 | "SPV_NV_compute_shader_derivatives" , |
378 | "SPV_NV_shader_image_footprint" , |
379 | "SPV_NV_shading_rate" , |
380 | "SPV_NV_mesh_shader" , |
381 | "SPV_NV_ray_tracing" , |
382 | "SPV_KHR_ray_tracing" , |
383 | "SPV_KHR_ray_query" , |
384 | "SPV_EXT_fragment_invocation_density" , |
385 | }); |
386 | } |
387 | |
388 | } // namespace opt |
389 | } // namespace spvtools |
390 | |