1 | // Copyright (c) 2017 Google Inc. |
2 | // |
3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | // you may not use this file except in compliance with the License. |
5 | // You may obtain a copy of the License at |
6 | // |
7 | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | // |
9 | // Unless required by applicable law or agreed to in writing, software |
10 | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | // See the License for the specific language governing permissions and |
13 | // limitations under the License. |
14 | |
15 | #include "source/opt/scalar_replacement_pass.h" |
16 | |
17 | #include <algorithm> |
18 | #include <queue> |
19 | #include <tuple> |
20 | #include <utility> |
21 | |
22 | #include "source/enum_string_mapping.h" |
23 | #include "source/extensions.h" |
24 | #include "source/opt/reflect.h" |
25 | #include "source/opt/types.h" |
26 | #include "source/util/make_unique.h" |
27 | |
28 | namespace spvtools { |
29 | namespace opt { |
30 | |
31 | Pass::Status ScalarReplacementPass::Process() { |
32 | Status status = Status::SuccessWithoutChange; |
33 | for (auto& f : *get_module()) { |
34 | Status functionStatus = ProcessFunction(&f); |
35 | if (functionStatus == Status::Failure) |
36 | return functionStatus; |
37 | else if (functionStatus == Status::SuccessWithChange) |
38 | status = functionStatus; |
39 | } |
40 | |
41 | return status; |
42 | } |
43 | |
44 | Pass::Status ScalarReplacementPass::ProcessFunction(Function* function) { |
45 | std::queue<Instruction*> worklist; |
46 | BasicBlock& entry = *function->begin(); |
47 | for (auto iter = entry.begin(); iter != entry.end(); ++iter) { |
48 | // Function storage class OpVariables must appear as the first instructions |
49 | // of the entry block. |
50 | if (iter->opcode() != SpvOpVariable) break; |
51 | |
52 | Instruction* varInst = &*iter; |
53 | if (CanReplaceVariable(varInst)) { |
54 | worklist.push(varInst); |
55 | } |
56 | } |
57 | |
58 | Status status = Status::SuccessWithoutChange; |
59 | while (!worklist.empty()) { |
60 | Instruction* varInst = worklist.front(); |
61 | worklist.pop(); |
62 | |
63 | Status var_status = ReplaceVariable(varInst, &worklist); |
64 | if (var_status == Status::Failure) |
65 | return var_status; |
66 | else if (var_status == Status::SuccessWithChange) |
67 | status = var_status; |
68 | } |
69 | |
70 | return status; |
71 | } |
72 | |
// Scalarizes the OpVariable |inst|.
//
// Creates one replacement per element of |inst|'s storage type, then rewrites
// each non-annotation user of |inst|:
//  - whole-composite loads and stores are split into per-element operations;
//  - access chains are re-rooted on the replacement picked by their first
//    index.
// The rewritten users and |inst| are then killed. Replacement variables left
// without users are killed too. Replacement variables that are themselves
// replaceable are pushed onto |worklist| so nested aggregates get scalarized.
//
// Returns Failure if the replacements cannot be created or any user cannot be
// rewritten. Otherwise it returns SuccessWithChange.
Pass::Status ScalarReplacementPass::ReplaceVariable(
    Instruction* inst, std::queue<Instruction*>* worklist) {
  std::vector<Instruction*> replacements;
  if (!CreateReplacementVariables(inst, &replacements)) {
    return Status::Failure;
  }

  // Users that were rewritten. They are killed after the user walk finishes,
  // not during it.
  std::vector<Instruction*> dead;
  bool replaced_all_uses = get_def_use_mgr()->WhileEachUser(
      inst, [this, &replacements, &dead](Instruction* user) {
        // Annotation users are skipped here.
        if (!IsAnnotationInst(user->opcode())) {
          switch (user->opcode()) {
            case SpvOpLoad:
              if (ReplaceWholeLoad(user, replacements)) {
                dead.push_back(user);
              } else {
                return false;
              }
              break;
            case SpvOpStore:
              if (ReplaceWholeStore(user, replacements)) {
                dead.push_back(user);
              } else {
                return false;
              }
              break;
            case SpvOpAccessChain:
            case SpvOpInBoundsAccessChain:
              if (ReplaceAccessChain(user, replacements))
                dead.push_back(user);
              else
                return false;
              break;
            case SpvOpName:
            case SpvOpMemberName:
              // Debug names need no rewriting.
              break;
            default:
              // CheckUses rejects any other user before we get here.
              // NOTE(review): in release builds such a user would be left
              // in place and the walk would continue.
              assert(false && "Unexpected opcode" );
              break;
          }
        }
        // Returning true continues the walk over the remaining users.
        return true;
      });

  if (replaced_all_uses) {
    dead.push_back(inst);
  } else {
    return Status::Failure;
  }

  // If there are no dead instructions to clean up, return with no changes.
  // NOTE(review): |inst| was just pushed, so |dead| is never empty here and
  // this early return appears to be unreachable.
  if (dead.empty()) return Status::SuccessWithoutChange;

  // Clean up some dead code. Entries are popped from the back, so |inst|
  // (pushed last) is killed first.
  while (!dead.empty()) {
    Instruction* toKill = dead.back();
    dead.pop_back();
    context()->KillInst(toKill);
  }

  // Attempt to further scalarize. Only entries that are OpVariables are
  // considered; null-constant replacements are left alone.
  for (auto var : replacements) {
    if (var->opcode() == SpvOpVariable) {
      if (get_def_use_mgr()->NumUsers(var) == 0) {
        context()->KillInst(var);
      } else if (CanReplaceVariable(var)) {
        worklist->push(var);
      }
    }
  }

  return Status::SuccessWithChange;
}
146 | |
// Replaces |load|, which reads the entire composite being scalarized.
//
// The replacement is a load from each replacement variable, followed by an
// OpCompositeConstruct of |load|'s result type. All uses of |load| are then
// redirected to the construct. Replacement entries that are not OpVariables
// (the null constants CreateReplacementVariables uses for unused elements)
// are used directly as constituents.
//
// New instructions are inserted before |load|. |load| itself is not removed
// here. Returns false if a fresh id cannot be allocated.
bool ScalarReplacementPass::ReplaceWholeLoad(
    Instruction* load, const std::vector<Instruction*>& replacements) {
  // Replaces the load of the entire composite with a load from each replacement
  // variable followed by a composite construction.
  BasicBlock* block = context()->get_instr_block(load);
  std::vector<Instruction*> loads;
  loads.reserve(replacements.size());
  BasicBlock::iterator where(load);
  for (auto var : replacements) {
    // Create a load of each replacement variable.
    if (var->opcode() != SpvOpVariable) {
      // Not a variable: use the value itself as the constituent.
      loads.push_back(var);
      continue;
    }

    Instruction* type = GetStorageType(var);
    uint32_t loadId = TakeNextId();
    if (loadId == 0) {
      return false;
    }
    std::unique_ptr<Instruction> newLoad(
        new Instruction(context(), SpvOpLoad, type->result_id(), loadId,
                        std::initializer_list<Operand>{
                            {SPV_OPERAND_TYPE_ID, {var->result_id()}}}));
    // Copy memory access attributes which start at index 1. Index 0 is the
    // pointer to load.
    for (uint32_t i = 1; i < load->NumInOperands(); ++i) {
      Operand copy(load->GetInOperand(i));
      newLoad->AddOperand(std::move(copy));
    }
    // |where| moves to the instruction just inserted, so each later load is
    // placed before the previous one. The loads end up in reverse order.
    where = where.InsertBefore(std::move(newLoad));
    get_def_use_mgr()->AnalyzeInstDefUse(&*where);
    context()->set_instr_block(&*where, block);
    loads.push_back(&*where);
  }

  // Construct a new composite.
  uint32_t compositeId = TakeNextId();
  if (compositeId == 0) {
    return false;
  }
  // Reset the insertion point so the construct lands immediately before
  // |load|, after all the new loads.
  where = load;
  std::unique_ptr<Instruction> compositeConstruct(new Instruction(
      context(), SpvOpCompositeConstruct, load->type_id(), compositeId, {}));
  for (auto l : loads) {
    Operand op(SPV_OPERAND_TYPE_ID,
               std::initializer_list<uint32_t>{l->result_id()});
    compositeConstruct->AddOperand(std::move(op));
  }
  where = where.InsertBefore(std::move(compositeConstruct));
  get_def_use_mgr()->AnalyzeInstDefUse(&*where);
  context()->set_instr_block(&*where, block);
  context()->ReplaceAllUsesWith(load->result_id(), compositeId);
  return true;
}
202 | |
203 | bool ScalarReplacementPass::ReplaceWholeStore( |
204 | Instruction* store, const std::vector<Instruction*>& replacements) { |
205 | // Replaces a store to the whole composite with a series of extract and stores |
206 | // to each element. |
207 | uint32_t storeInput = store->GetSingleWordInOperand(1u); |
208 | BasicBlock* block = context()->get_instr_block(store); |
209 | BasicBlock::iterator where(store); |
210 | uint32_t elementIndex = 0; |
211 | for (auto var : replacements) { |
212 | // Create the extract. |
213 | if (var->opcode() != SpvOpVariable) { |
214 | elementIndex++; |
215 | continue; |
216 | } |
217 | |
218 | Instruction* type = GetStorageType(var); |
219 | uint32_t = TakeNextId(); |
220 | if (extractId == 0) { |
221 | return false; |
222 | } |
223 | std::unique_ptr<Instruction> (new Instruction( |
224 | context(), SpvOpCompositeExtract, type->result_id(), extractId, |
225 | std::initializer_list<Operand>{ |
226 | {SPV_OPERAND_TYPE_ID, {storeInput}}, |
227 | {SPV_OPERAND_TYPE_LITERAL_INTEGER, {elementIndex++}}})); |
228 | auto iter = where.InsertBefore(std::move(extract)); |
229 | get_def_use_mgr()->AnalyzeInstDefUse(&*iter); |
230 | context()->set_instr_block(&*iter, block); |
231 | |
232 | // Create the store. |
233 | std::unique_ptr<Instruction> newStore( |
234 | new Instruction(context(), SpvOpStore, 0, 0, |
235 | std::initializer_list<Operand>{ |
236 | {SPV_OPERAND_TYPE_ID, {var->result_id()}}, |
237 | {SPV_OPERAND_TYPE_ID, {extractId}}})); |
238 | // Copy memory access attributes which start at index 2. Index 0 is the |
239 | // pointer and index 1 is the data. |
240 | for (uint32_t i = 2; i < store->NumInOperands(); ++i) { |
241 | Operand copy(store->GetInOperand(i)); |
242 | newStore->AddOperand(std::move(copy)); |
243 | } |
244 | iter = where.InsertBefore(std::move(newStore)); |
245 | get_def_use_mgr()->AnalyzeInstDefUse(&*iter); |
246 | context()->set_instr_block(&*iter, block); |
247 | } |
248 | return true; |
249 | } |
250 | |
251 | bool ScalarReplacementPass::ReplaceAccessChain( |
252 | Instruction* chain, const std::vector<Instruction*>& replacements) { |
253 | // Replaces the access chain with either another access chain (with one fewer |
254 | // indexes) or a direct use of the replacement variable. |
255 | uint32_t indexId = chain->GetSingleWordInOperand(1u); |
256 | const Instruction* index = get_def_use_mgr()->GetDef(indexId); |
257 | int64_t indexValue = context() |
258 | ->get_constant_mgr() |
259 | ->GetConstantFromInst(index) |
260 | ->GetSignExtendedValue(); |
261 | if (indexValue < 0 || |
262 | indexValue >= static_cast<int64_t>(replacements.size())) { |
263 | // Out of bounds access, this is illegal IR. Notice that OpAccessChain |
264 | // indexing is 0-based, so we should also reject index == size-of-array. |
265 | return false; |
266 | } else { |
267 | const Instruction* var = replacements[static_cast<size_t>(indexValue)]; |
268 | if (chain->NumInOperands() > 2) { |
269 | // Replace input access chain with another access chain. |
270 | BasicBlock::iterator chainIter(chain); |
271 | uint32_t replacementId = TakeNextId(); |
272 | if (replacementId == 0) { |
273 | return false; |
274 | } |
275 | std::unique_ptr<Instruction> replacementChain(new Instruction( |
276 | context(), chain->opcode(), chain->type_id(), replacementId, |
277 | std::initializer_list<Operand>{ |
278 | {SPV_OPERAND_TYPE_ID, {var->result_id()}}})); |
279 | // Add the remaining indexes. |
280 | for (uint32_t i = 2; i < chain->NumInOperands(); ++i) { |
281 | Operand copy(chain->GetInOperand(i)); |
282 | replacementChain->AddOperand(std::move(copy)); |
283 | } |
284 | auto iter = chainIter.InsertBefore(std::move(replacementChain)); |
285 | get_def_use_mgr()->AnalyzeInstDefUse(&*iter); |
286 | context()->set_instr_block(&*iter, context()->get_instr_block(chain)); |
287 | context()->ReplaceAllUsesWith(chain->result_id(), replacementId); |
288 | } else { |
289 | // Replace with a use of the variable. |
290 | context()->ReplaceAllUsesWith(chain->result_id(), var->result_id()); |
291 | } |
292 | } |
293 | |
294 | return true; |
295 | } |
296 | |
297 | bool ScalarReplacementPass::CreateReplacementVariables( |
298 | Instruction* inst, std::vector<Instruction*>* replacements) { |
299 | Instruction* type = GetStorageType(inst); |
300 | |
301 | std::unique_ptr<std::unordered_set<int64_t>> components_used = |
302 | GetUsedComponents(inst); |
303 | |
304 | uint32_t elem = 0; |
305 | switch (type->opcode()) { |
306 | case SpvOpTypeStruct: |
307 | type->ForEachInOperand( |
308 | [this, inst, &elem, replacements, &components_used](uint32_t* id) { |
309 | if (!components_used || components_used->count(elem)) { |
310 | CreateVariable(*id, inst, elem, replacements); |
311 | } else { |
312 | replacements->push_back(CreateNullConstant(*id)); |
313 | } |
314 | elem++; |
315 | }); |
316 | break; |
317 | case SpvOpTypeArray: |
318 | for (uint32_t i = 0; i != GetArrayLength(type); ++i) { |
319 | if (!components_used || components_used->count(i)) { |
320 | CreateVariable(type->GetSingleWordInOperand(0u), inst, i, |
321 | replacements); |
322 | } else { |
323 | replacements->push_back( |
324 | CreateNullConstant(type->GetSingleWordInOperand(0u))); |
325 | } |
326 | } |
327 | break; |
328 | |
329 | case SpvOpTypeMatrix: |
330 | case SpvOpTypeVector: |
331 | for (uint32_t i = 0; i != GetNumElements(type); ++i) { |
332 | CreateVariable(type->GetSingleWordInOperand(0u), inst, i, replacements); |
333 | } |
334 | break; |
335 | |
336 | default: |
337 | assert(false && "Unexpected type." ); |
338 | break; |
339 | } |
340 | |
341 | TransferAnnotations(inst, replacements); |
342 | return std::find(replacements->begin(), replacements->end(), nullptr) == |
343 | replacements->end(); |
344 | } |
345 | |
346 | void ScalarReplacementPass::TransferAnnotations( |
347 | const Instruction* source, std::vector<Instruction*>* replacements) { |
348 | // Only transfer invariant and restrict decorations on the variable. There are |
349 | // no type or member decorations that are necessary to transfer. |
350 | for (auto inst : |
351 | get_decoration_mgr()->GetDecorationsFor(source->result_id(), false)) { |
352 | assert(inst->opcode() == SpvOpDecorate); |
353 | uint32_t decoration = inst->GetSingleWordInOperand(1u); |
354 | if (decoration == SpvDecorationInvariant || |
355 | decoration == SpvDecorationRestrict) { |
356 | for (auto var : *replacements) { |
357 | if (var == nullptr) { |
358 | continue; |
359 | } |
360 | |
361 | std::unique_ptr<Instruction> annotation( |
362 | new Instruction(context(), SpvOpDecorate, 0, 0, |
363 | std::initializer_list<Operand>{ |
364 | {SPV_OPERAND_TYPE_ID, {var->result_id()}}, |
365 | {SPV_OPERAND_TYPE_DECORATION, {decoration}}})); |
366 | for (uint32_t i = 2; i < inst->NumInOperands(); ++i) { |
367 | Operand copy(inst->GetInOperand(i)); |
368 | annotation->AddOperand(std::move(copy)); |
369 | } |
370 | context()->AddAnnotationInst(std::move(annotation)); |
371 | get_def_use_mgr()->AnalyzeInstUse(&*--context()->annotation_end()); |
372 | } |
373 | } |
374 | } |
375 | } |
376 | |
377 | void ScalarReplacementPass::CreateVariable( |
378 | uint32_t typeId, Instruction* varInst, uint32_t index, |
379 | std::vector<Instruction*>* replacements) { |
380 | uint32_t ptrId = GetOrCreatePointerType(typeId); |
381 | uint32_t id = TakeNextId(); |
382 | |
383 | if (id == 0) { |
384 | replacements->push_back(nullptr); |
385 | } |
386 | |
387 | std::unique_ptr<Instruction> variable(new Instruction( |
388 | context(), SpvOpVariable, ptrId, id, |
389 | std::initializer_list<Operand>{ |
390 | {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}}})); |
391 | |
392 | BasicBlock* block = context()->get_instr_block(varInst); |
393 | block->begin().InsertBefore(std::move(variable)); |
394 | Instruction* inst = &*block->begin(); |
395 | |
396 | // If varInst was initialized, make sure to initialize its replacement. |
397 | GetOrCreateInitialValue(varInst, index, inst); |
398 | get_def_use_mgr()->AnalyzeInstDefUse(inst); |
399 | context()->set_instr_block(inst, block); |
400 | |
401 | // Copy decorations from the member to the new variable. |
402 | Instruction* typeInst = GetStorageType(varInst); |
403 | for (auto dec_inst : |
404 | get_decoration_mgr()->GetDecorationsFor(typeInst->result_id(), false)) { |
405 | uint32_t decoration; |
406 | if (dec_inst->opcode() != SpvOpMemberDecorate) { |
407 | continue; |
408 | } |
409 | |
410 | if (dec_inst->GetSingleWordInOperand(1) != index) { |
411 | continue; |
412 | } |
413 | |
414 | decoration = dec_inst->GetSingleWordInOperand(2u); |
415 | switch (decoration) { |
416 | case SpvDecorationRelaxedPrecision: { |
417 | std::unique_ptr<Instruction> new_dec_inst( |
418 | new Instruction(context(), SpvOpDecorate, 0, 0, {})); |
419 | new_dec_inst->AddOperand(Operand(SPV_OPERAND_TYPE_ID, {id})); |
420 | for (uint32_t i = 2; i < dec_inst->NumInOperandWords(); ++i) { |
421 | new_dec_inst->AddOperand(Operand(dec_inst->GetInOperand(i))); |
422 | } |
423 | context()->AddAnnotationInst(std::move(new_dec_inst)); |
424 | } break; |
425 | default: |
426 | break; |
427 | } |
428 | } |
429 | |
430 | replacements->push_back(inst); |
431 | } |
432 | |
433 | uint32_t ScalarReplacementPass::GetOrCreatePointerType(uint32_t id) { |
434 | auto iter = pointee_to_pointer_.find(id); |
435 | if (iter != pointee_to_pointer_.end()) return iter->second; |
436 | |
437 | analysis::Type* pointeeTy; |
438 | std::unique_ptr<analysis::Pointer> pointerTy; |
439 | std::tie(pointeeTy, pointerTy) = |
440 | context()->get_type_mgr()->GetTypeAndPointerType(id, |
441 | SpvStorageClassFunction); |
442 | uint32_t ptrId = 0; |
443 | if (pointeeTy->IsUniqueType()) { |
444 | // Non-ambiguous type, just ask the type manager for an id. |
445 | ptrId = context()->get_type_mgr()->GetTypeInstruction(pointerTy.get()); |
446 | pointee_to_pointer_[id] = ptrId; |
447 | return ptrId; |
448 | } |
449 | |
450 | // Ambiguous type. We must perform a linear search to try and find the right |
451 | // type. |
452 | for (auto global : context()->types_values()) { |
453 | if (global.opcode() == SpvOpTypePointer && |
454 | global.GetSingleWordInOperand(0u) == SpvStorageClassFunction && |
455 | global.GetSingleWordInOperand(1u) == id) { |
456 | if (get_decoration_mgr()->GetDecorationsFor(id, false).empty()) { |
457 | // Only reuse a decoration-less pointer of the correct type. |
458 | ptrId = global.result_id(); |
459 | break; |
460 | } |
461 | } |
462 | } |
463 | |
464 | if (ptrId != 0) { |
465 | pointee_to_pointer_[id] = ptrId; |
466 | return ptrId; |
467 | } |
468 | |
469 | ptrId = TakeNextId(); |
470 | context()->AddType(MakeUnique<Instruction>( |
471 | context(), SpvOpTypePointer, 0, ptrId, |
472 | std::initializer_list<Operand>{ |
473 | {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}}, |
474 | {SPV_OPERAND_TYPE_ID, {id}}})); |
475 | Instruction* ptr = &*--context()->types_values_end(); |
476 | get_def_use_mgr()->AnalyzeInstDefUse(ptr); |
477 | pointee_to_pointer_[id] = ptrId; |
478 | // Register with the type manager if necessary. |
479 | context()->get_type_mgr()->RegisterType(ptrId, *pointerTy); |
480 | |
481 | return ptrId; |
482 | } |
483 | |
// If the OpVariable |source| has an initializer (in-operand 1), derives the
// initializer for element |index| and appends it to |newVar| as its
// initializer operand. Does nothing when |source| is uninitialized.
//
// Supported initializer forms:
//  - OpConstantNull: uses an OpConstantNull of |newVar|'s storage type. It is
//    created once per type and cached in |type_to_null_|.
//  - spec constants: emits a new OpSpecConstantOp CompositeExtract of element
//    |index| of the initializer.
//  - OpConstantComposite: uses constituent |index| directly. If that
//    constituent is OpUndef, |newVar| is left uninitialized.
// Any other initializer opcode trips the assert.
//
// NOTE(review): the TakeNextId() results below are not checked for 0.
void ScalarReplacementPass::GetOrCreateInitialValue(Instruction* source,
                                                    uint32_t index,
                                                    Instruction* newVar) {
  assert(source->opcode() == SpvOpVariable);
  if (source->NumInOperands() < 2) return;

  uint32_t initId = source->GetSingleWordInOperand(1u);
  uint32_t storageId = GetStorageType(newVar)->result_id();
  Instruction* init = get_def_use_mgr()->GetDef(initId);
  uint32_t newInitId = 0;
  // TODO(dnovillo): Refactor this with constant propagation.
  if (init->opcode() == SpvOpConstantNull) {
    // Initialize to appropriate NULL.
    auto iter = type_to_null_.find(storageId);
    if (iter == type_to_null_.end()) {
      newInitId = TakeNextId();
      type_to_null_[storageId] = newInitId;
      context()->AddGlobalValue(
          MakeUnique<Instruction>(context(), SpvOpConstantNull, storageId,
                                  newInitId, std::initializer_list<Operand>{}));
      Instruction* newNull = &*--context()->types_values_end();
      get_def_use_mgr()->AnalyzeInstDefUse(newNull);
    } else {
      newInitId = iter->second;
    }
  } else if (IsSpecConstantInst(init->opcode())) {
    // Create a new constant extract.
    newInitId = TakeNextId();
    context()->AddGlobalValue(MakeUnique<Instruction>(
        context(), SpvOpSpecConstantOp, storageId, newInitId,
        std::initializer_list<Operand>{
            {SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER, {SpvOpCompositeExtract}},
            {SPV_OPERAND_TYPE_ID, {init->result_id()}},
            {SPV_OPERAND_TYPE_LITERAL_INTEGER, {index}}}));
    Instruction* newSpecConst = &*--context()->types_values_end();
    get_def_use_mgr()->AnalyzeInstDefUse(newSpecConst);
  } else if (init->opcode() == SpvOpConstantComposite) {
    // Get the appropriate index constant. The in-operands of
    // OpConstantComposite are its constituents, in element order.
    newInitId = init->GetSingleWordInOperand(index);
    Instruction* element = get_def_use_mgr()->GetDef(newInitId);
    if (element->opcode() == SpvOpUndef) {
      // Undef is not a valid initializer for a variable.
      newInitId = 0;
    }
  } else {
    assert(false);
  }

  // An id of 0 means "leave the replacement uninitialized".
  if (newInitId != 0) {
    newVar->AddOperand({SPV_OPERAND_TYPE_ID, {newInitId}});
  }
}
536 | |
537 | uint64_t ScalarReplacementPass::GetArrayLength( |
538 | const Instruction* arrayType) const { |
539 | assert(arrayType->opcode() == SpvOpTypeArray); |
540 | const Instruction* length = |
541 | get_def_use_mgr()->GetDef(arrayType->GetSingleWordInOperand(1u)); |
542 | return context() |
543 | ->get_constant_mgr() |
544 | ->GetConstantFromInst(length) |
545 | ->GetZeroExtendedValue(); |
546 | } |
547 | |
548 | uint64_t ScalarReplacementPass::GetNumElements(const Instruction* type) const { |
549 | assert(type->opcode() == SpvOpTypeVector || |
550 | type->opcode() == SpvOpTypeMatrix); |
551 | const Operand& op = type->GetInOperand(1u); |
552 | assert(op.words.size() <= 2); |
553 | uint64_t len = 0; |
554 | for (size_t i = 0; i != op.words.size(); ++i) { |
555 | len |= (static_cast<uint64_t>(op.words[i]) << (32ull * i)); |
556 | } |
557 | return len; |
558 | } |
559 | |
560 | bool ScalarReplacementPass::IsSpecConstant(uint32_t id) const { |
561 | const Instruction* inst = get_def_use_mgr()->GetDef(id); |
562 | assert(inst); |
563 | return spvOpcodeIsSpecConstant(inst->opcode()); |
564 | } |
565 | |
566 | Instruction* ScalarReplacementPass::GetStorageType( |
567 | const Instruction* inst) const { |
568 | assert(inst->opcode() == SpvOpVariable); |
569 | |
570 | uint32_t ptrTypeId = inst->type_id(); |
571 | uint32_t typeId = |
572 | get_def_use_mgr()->GetDef(ptrTypeId)->GetSingleWordInOperand(1u); |
573 | return get_def_use_mgr()->GetDef(typeId); |
574 | } |
575 | |
576 | bool ScalarReplacementPass::CanReplaceVariable( |
577 | const Instruction* varInst) const { |
578 | assert(varInst->opcode() == SpvOpVariable); |
579 | |
580 | // Can only replace function scope variables. |
581 | if (varInst->GetSingleWordInOperand(0u) != SpvStorageClassFunction) { |
582 | return false; |
583 | } |
584 | |
585 | if (!CheckTypeAnnotations(get_def_use_mgr()->GetDef(varInst->type_id()))) { |
586 | return false; |
587 | } |
588 | |
589 | const Instruction* typeInst = GetStorageType(varInst); |
590 | if (!CheckType(typeInst)) { |
591 | return false; |
592 | } |
593 | |
594 | if (!CheckAnnotations(varInst)) { |
595 | return false; |
596 | } |
597 | |
598 | if (!CheckUses(varInst)) { |
599 | return false; |
600 | } |
601 | |
602 | return true; |
603 | } |
604 | |
605 | bool ScalarReplacementPass::CheckType(const Instruction* typeInst) const { |
606 | if (!CheckTypeAnnotations(typeInst)) { |
607 | return false; |
608 | } |
609 | |
610 | switch (typeInst->opcode()) { |
611 | case SpvOpTypeStruct: |
612 | // Don't bother with empty structs or very large structs. |
613 | if (typeInst->NumInOperands() == 0 || |
614 | IsLargerThanSizeLimit(typeInst->NumInOperands())) { |
615 | return false; |
616 | } |
617 | return true; |
618 | case SpvOpTypeArray: |
619 | if (IsSpecConstant(typeInst->GetSingleWordInOperand(1u))) { |
620 | return false; |
621 | } |
622 | if (IsLargerThanSizeLimit(GetArrayLength(typeInst))) { |
623 | return false; |
624 | } |
625 | return true; |
626 | // TODO(alanbaker): Develop some heuristics for when this should be |
627 | // re-enabled. |
628 | //// Specifically including matrix and vector in an attempt to reduce the |
629 | //// number of vector registers required. |
630 | // case SpvOpTypeMatrix: |
631 | // case SpvOpTypeVector: |
632 | // if (IsLargerThanSizeLimit(GetNumElements(typeInst))) return false; |
633 | // return true; |
634 | |
635 | case SpvOpTypeRuntimeArray: |
636 | default: |
637 | return false; |
638 | } |
639 | } |
640 | |
641 | bool ScalarReplacementPass::CheckTypeAnnotations( |
642 | const Instruction* typeInst) const { |
643 | for (auto inst : |
644 | get_decoration_mgr()->GetDecorationsFor(typeInst->result_id(), false)) { |
645 | uint32_t decoration; |
646 | if (inst->opcode() == SpvOpDecorate) { |
647 | decoration = inst->GetSingleWordInOperand(1u); |
648 | } else { |
649 | assert(inst->opcode() == SpvOpMemberDecorate); |
650 | decoration = inst->GetSingleWordInOperand(2u); |
651 | } |
652 | |
653 | switch (decoration) { |
654 | case SpvDecorationRowMajor: |
655 | case SpvDecorationColMajor: |
656 | case SpvDecorationArrayStride: |
657 | case SpvDecorationMatrixStride: |
658 | case SpvDecorationCPacked: |
659 | case SpvDecorationInvariant: |
660 | case SpvDecorationRestrict: |
661 | case SpvDecorationOffset: |
662 | case SpvDecorationAlignment: |
663 | case SpvDecorationAlignmentId: |
664 | case SpvDecorationMaxByteOffset: |
665 | case SpvDecorationRelaxedPrecision: |
666 | break; |
667 | default: |
668 | return false; |
669 | } |
670 | } |
671 | |
672 | return true; |
673 | } |
674 | |
675 | bool ScalarReplacementPass::CheckAnnotations(const Instruction* varInst) const { |
676 | for (auto inst : |
677 | get_decoration_mgr()->GetDecorationsFor(varInst->result_id(), false)) { |
678 | assert(inst->opcode() == SpvOpDecorate); |
679 | uint32_t decoration = inst->GetSingleWordInOperand(1u); |
680 | switch (decoration) { |
681 | case SpvDecorationInvariant: |
682 | case SpvDecorationRestrict: |
683 | case SpvDecorationAlignment: |
684 | case SpvDecorationAlignmentId: |
685 | case SpvDecorationMaxByteOffset: |
686 | break; |
687 | default: |
688 | return false; |
689 | } |
690 | } |
691 | |
692 | return true; |
693 | } |
694 | |
695 | bool ScalarReplacementPass::CheckUses(const Instruction* inst) const { |
696 | VariableStats stats = {0, 0}; |
697 | bool ok = CheckUses(inst, &stats); |
698 | |
699 | // TODO(alanbaker/greg-lunarg): Add some meaningful heuristics about when |
700 | // SRoA is costly, such as when the structure has many (unaccessed?) |
701 | // members. |
702 | |
703 | return ok; |
704 | } |
705 | |
// Returns true if every non-annotation use of |inst| can be rewritten by this
// pass. Accesses are tallied into |stats|. Per user:
//  - OpAccessChain / OpInBoundsAccessChain: |inst| must be the base operand
//    (operand index 2) and the chain must have at least one index. The first
//    index must be a constant below GetMaxLegalIndex(inst), and the chain's
//    own uses must pass CheckUsesRelaxed. Qualifying chains count as partial
//    accesses.
//  - OpLoad / OpStore: must pass CheckLoad / CheckStore. Counted as full
//    accesses.
//  - OpName / OpMemberName: ignored.
//  - anything else: makes |inst| non-replaceable.
// Every use is visited, even after a failure has been recorded.
bool ScalarReplacementPass::CheckUses(const Instruction* inst,
                                      VariableStats* stats) const {
  uint64_t max_legal_index = GetMaxLegalIndex(inst);

  bool ok = true;
  get_def_use_mgr()->ForEachUse(inst, [this, max_legal_index, stats, &ok](
                                          const Instruction* user,
                                          uint32_t index) {
    // Annotations are checked as a group separately (see CheckAnnotations).
    if (!IsAnnotationInst(user->opcode())) {
      switch (user->opcode()) {
        case SpvOpAccessChain:
        case SpvOpInBoundsAccessChain:
          if (index == 2u && user->NumInOperands() > 1) {
            // In-operand 1 is the first index, which selects the element.
            uint32_t id = user->GetSingleWordInOperand(1u);
            const Instruction* opInst = get_def_use_mgr()->GetDef(id);
            const auto* constant =
                context()->get_constant_mgr()->GetConstantFromInst(opInst);
            if (!constant) {
              // A dynamic index cannot be mapped to a single replacement.
              ok = false;
            } else if (constant->GetZeroExtendedValue() >= max_legal_index) {
              ok = false;
            } else {
              if (!CheckUsesRelaxed(user)) ok = false;
            }
            stats->num_partial_accesses++;
          } else {
            ok = false;
          }
          break;
        case SpvOpLoad:
          if (!CheckLoad(user, index)) ok = false;
          stats->num_full_accesses++;
          break;
        case SpvOpStore:
          if (!CheckStore(user, index)) ok = false;
          stats->num_full_accesses++;
          break;
        case SpvOpName:
        case SpvOpMemberName:
          break;
        default:
          ok = false;
          break;
      }
    }
  });

  return ok;
}
756 | |
757 | bool ScalarReplacementPass::CheckUsesRelaxed(const Instruction* inst) const { |
758 | bool ok = true; |
759 | get_def_use_mgr()->ForEachUse( |
760 | inst, [this, &ok](const Instruction* user, uint32_t index) { |
761 | switch (user->opcode()) { |
762 | case SpvOpAccessChain: |
763 | case SpvOpInBoundsAccessChain: |
764 | if (index != 2u) { |
765 | ok = false; |
766 | } else { |
767 | if (!CheckUsesRelaxed(user)) ok = false; |
768 | } |
769 | break; |
770 | case SpvOpLoad: |
771 | if (!CheckLoad(user, index)) ok = false; |
772 | break; |
773 | case SpvOpStore: |
774 | if (!CheckStore(user, index)) ok = false; |
775 | break; |
776 | default: |
777 | ok = false; |
778 | break; |
779 | } |
780 | }); |
781 | |
782 | return ok; |
783 | } |
784 | |
785 | bool ScalarReplacementPass::CheckLoad(const Instruction* inst, |
786 | uint32_t index) const { |
787 | if (index != 2u) return false; |
788 | if (inst->NumInOperands() >= 2 && |
789 | inst->GetSingleWordInOperand(1u) & SpvMemoryAccessVolatileMask) |
790 | return false; |
791 | return true; |
792 | } |
793 | |
794 | bool ScalarReplacementPass::CheckStore(const Instruction* inst, |
795 | uint32_t index) const { |
796 | if (index != 0u) return false; |
797 | if (inst->NumInOperands() >= 3 && |
798 | inst->GetSingleWordInOperand(2u) & SpvMemoryAccessVolatileMask) |
799 | return false; |
800 | return true; |
801 | } |
802 | bool ScalarReplacementPass::IsLargerThanSizeLimit(uint64_t length) const { |
803 | if (max_num_elements_ == 0) { |
804 | return false; |
805 | } |
806 | return length > max_num_elements_; |
807 | } |
808 | |
// Returns the set of element indices of the variable |inst| that may be read,
// or nullptr if that cannot be determined. CreateReplacementVariables treats
// nullptr as "every element is used". Per user of |inst|:
//  - OpLoad: every user of the loaded value must be an OpCompositeExtract
//    with at least one index. The first index of each extract is recorded.
//    Any other user of the load makes the result unknown.
//  - OpStore, OpName, OpMemberName: contribute no indices.
//  - OpAccessChain / OpInBoundsAccessChain: the first index is recorded if it
//    is a declared constant. Otherwise the result is unknown.
//  - anything else: the result is unknown.
// The scan stops at the first user that makes the result unknown.
std::unique_ptr<std::unordered_set<int64_t>>
ScalarReplacementPass::GetUsedComponents(Instruction* inst) {
  std::unique_ptr<std::unordered_set<int64_t>> result(
      new std::unordered_set<int64_t>());

  analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();

  def_use_mgr->WhileEachUser(inst, [&result, def_use_mgr,
                                    this](Instruction* use) {
    switch (use->opcode()) {
      case SpvOpLoad: {
        // Look for extract from the load.
        std::vector<uint32_t> t;
        if (def_use_mgr->WhileEachUser(use, [&t](Instruction* use2) {
              if (use2->opcode() != SpvOpCompositeExtract ||
                  use2->NumInOperands() <= 1) {
                return false;
              }
              // In-operand 1 is the first (outermost) extract index.
              t.push_back(use2->GetSingleWordInOperand(1));
              return true;
            })) {
          result->insert(t.begin(), t.end());
          return true;
        } else {
          result.reset(nullptr);
          return false;
        }
      }
      case SpvOpName:
      case SpvOpMemberName:
      case SpvOpStore:
        // No components are used.
        return true;
      case SpvOpAccessChain:
      case SpvOpInBoundsAccessChain: {
        // Add the first index it if is a constant.
        // TODO: Could be improved by checking if the address is used in a load.
        analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
        uint32_t index_id = use->GetSingleWordInOperand(1);
        const analysis::Constant* index_const =
            const_mgr->FindDeclaredConstant(index_id);
        if (index_const) {
          result->insert(index_const->GetSignExtendedValue());
          return true;
        } else {
          // Could be any element. Assuming all are used.
          result.reset(nullptr);
          return false;
        }
      }
      default:
        // We do not know what is happening. Have to assume the worst.
        result.reset(nullptr);
        return false;
    }
  });

  return result;
}
868 | |
869 | Instruction* ScalarReplacementPass::CreateNullConstant(uint32_t type_id) { |
870 | analysis::TypeManager* type_mgr = context()->get_type_mgr(); |
871 | analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); |
872 | |
873 | const analysis::Type* type = type_mgr->GetType(type_id); |
874 | const analysis::Constant* null_const = const_mgr->GetConstant(type, {}); |
875 | Instruction* null_inst = |
876 | const_mgr->GetDefiningInstruction(null_const, type_id); |
877 | if (null_inst != nullptr) { |
878 | context()->UpdateDefUse(null_inst); |
879 | } |
880 | return null_inst; |
881 | } |
882 | |
883 | uint64_t ScalarReplacementPass::GetMaxLegalIndex( |
884 | const Instruction* var_inst) const { |
885 | assert(var_inst->opcode() == SpvOpVariable && |
886 | "|var_inst| must be a variable instruction." ); |
887 | Instruction* type = GetStorageType(var_inst); |
888 | switch (type->opcode()) { |
889 | case SpvOpTypeStruct: |
890 | return type->NumInOperands(); |
891 | case SpvOpTypeArray: |
892 | return GetArrayLength(type); |
893 | case SpvOpTypeMatrix: |
894 | case SpvOpTypeVector: |
895 | return GetNumElements(type); |
896 | default: |
897 | return 0; |
898 | } |
899 | return 0; |
900 | } |
901 | |
902 | } // namespace opt |
903 | } // namespace spvtools |
904 | |