1// Copyright (c) 2017 Google Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "source/opt/scalar_replacement_pass.h"
16
17#include <algorithm>
18#include <queue>
19#include <tuple>
20#include <utility>
21
22#include "source/enum_string_mapping.h"
23#include "source/extensions.h"
24#include "source/opt/reflect.h"
25#include "source/opt/types.h"
26#include "source/util/make_unique.h"
27
28namespace spvtools {
29namespace opt {
30
31Pass::Status ScalarReplacementPass::Process() {
32 Status status = Status::SuccessWithoutChange;
33 for (auto& f : *get_module()) {
34 Status functionStatus = ProcessFunction(&f);
35 if (functionStatus == Status::Failure)
36 return functionStatus;
37 else if (functionStatus == Status::SuccessWithChange)
38 status = functionStatus;
39 }
40
41 return status;
42}
43
44Pass::Status ScalarReplacementPass::ProcessFunction(Function* function) {
45 std::queue<Instruction*> worklist;
46 BasicBlock& entry = *function->begin();
47 for (auto iter = entry.begin(); iter != entry.end(); ++iter) {
48 // Function storage class OpVariables must appear as the first instructions
49 // of the entry block.
50 if (iter->opcode() != SpvOpVariable) break;
51
52 Instruction* varInst = &*iter;
53 if (CanReplaceVariable(varInst)) {
54 worklist.push(varInst);
55 }
56 }
57
58 Status status = Status::SuccessWithoutChange;
59 while (!worklist.empty()) {
60 Instruction* varInst = worklist.front();
61 worklist.pop();
62
63 Status var_status = ReplaceVariable(varInst, &worklist);
64 if (var_status == Status::Failure)
65 return var_status;
66 else if (var_status == Status::SuccessWithChange)
67 status = var_status;
68 }
69
70 return status;
71}
72
73Pass::Status ScalarReplacementPass::ReplaceVariable(
74 Instruction* inst, std::queue<Instruction*>* worklist) {
75 std::vector<Instruction*> replacements;
76 if (!CreateReplacementVariables(inst, &replacements)) {
77 return Status::Failure;
78 }
79
80 std::vector<Instruction*> dead;
81 bool replaced_all_uses = get_def_use_mgr()->WhileEachUser(
82 inst, [this, &replacements, &dead](Instruction* user) {
83 if (!IsAnnotationInst(user->opcode())) {
84 switch (user->opcode()) {
85 case SpvOpLoad:
86 if (ReplaceWholeLoad(user, replacements)) {
87 dead.push_back(user);
88 } else {
89 return false;
90 }
91 break;
92 case SpvOpStore:
93 if (ReplaceWholeStore(user, replacements)) {
94 dead.push_back(user);
95 } else {
96 return false;
97 }
98 break;
99 case SpvOpAccessChain:
100 case SpvOpInBoundsAccessChain:
101 if (ReplaceAccessChain(user, replacements))
102 dead.push_back(user);
103 else
104 return false;
105 break;
106 case SpvOpName:
107 case SpvOpMemberName:
108 break;
109 default:
110 assert(false && "Unexpected opcode");
111 break;
112 }
113 }
114 return true;
115 });
116
117 if (replaced_all_uses) {
118 dead.push_back(inst);
119 } else {
120 return Status::Failure;
121 }
122
123 // If there are no dead instructions to clean up, return with no changes.
124 if (dead.empty()) return Status::SuccessWithoutChange;
125
126 // Clean up some dead code.
127 while (!dead.empty()) {
128 Instruction* toKill = dead.back();
129 dead.pop_back();
130 context()->KillInst(toKill);
131 }
132
133 // Attempt to further scalarize.
134 for (auto var : replacements) {
135 if (var->opcode() == SpvOpVariable) {
136 if (get_def_use_mgr()->NumUsers(var) == 0) {
137 context()->KillInst(var);
138 } else if (CanReplaceVariable(var)) {
139 worklist->push(var);
140 }
141 }
142 }
143
144 return Status::SuccessWithChange;
145}
146
147bool ScalarReplacementPass::ReplaceWholeLoad(
148 Instruction* load, const std::vector<Instruction*>& replacements) {
149 // Replaces the load of the entire composite with a load from each replacement
150 // variable followed by a composite construction.
151 BasicBlock* block = context()->get_instr_block(load);
152 std::vector<Instruction*> loads;
153 loads.reserve(replacements.size());
154 BasicBlock::iterator where(load);
155 for (auto var : replacements) {
156 // Create a load of each replacement variable.
157 if (var->opcode() != SpvOpVariable) {
158 loads.push_back(var);
159 continue;
160 }
161
162 Instruction* type = GetStorageType(var);
163 uint32_t loadId = TakeNextId();
164 if (loadId == 0) {
165 return false;
166 }
167 std::unique_ptr<Instruction> newLoad(
168 new Instruction(context(), SpvOpLoad, type->result_id(), loadId,
169 std::initializer_list<Operand>{
170 {SPV_OPERAND_TYPE_ID, {var->result_id()}}}));
171 // Copy memory access attributes which start at index 1. Index 0 is the
172 // pointer to load.
173 for (uint32_t i = 1; i < load->NumInOperands(); ++i) {
174 Operand copy(load->GetInOperand(i));
175 newLoad->AddOperand(std::move(copy));
176 }
177 where = where.InsertBefore(std::move(newLoad));
178 get_def_use_mgr()->AnalyzeInstDefUse(&*where);
179 context()->set_instr_block(&*where, block);
180 loads.push_back(&*where);
181 }
182
183 // Construct a new composite.
184 uint32_t compositeId = TakeNextId();
185 if (compositeId == 0) {
186 return false;
187 }
188 where = load;
189 std::unique_ptr<Instruction> compositeConstruct(new Instruction(
190 context(), SpvOpCompositeConstruct, load->type_id(), compositeId, {}));
191 for (auto l : loads) {
192 Operand op(SPV_OPERAND_TYPE_ID,
193 std::initializer_list<uint32_t>{l->result_id()});
194 compositeConstruct->AddOperand(std::move(op));
195 }
196 where = where.InsertBefore(std::move(compositeConstruct));
197 get_def_use_mgr()->AnalyzeInstDefUse(&*where);
198 context()->set_instr_block(&*where, block);
199 context()->ReplaceAllUsesWith(load->result_id(), compositeId);
200 return true;
201}
202
203bool ScalarReplacementPass::ReplaceWholeStore(
204 Instruction* store, const std::vector<Instruction*>& replacements) {
205 // Replaces a store to the whole composite with a series of extract and stores
206 // to each element.
207 uint32_t storeInput = store->GetSingleWordInOperand(1u);
208 BasicBlock* block = context()->get_instr_block(store);
209 BasicBlock::iterator where(store);
210 uint32_t elementIndex = 0;
211 for (auto var : replacements) {
212 // Create the extract.
213 if (var->opcode() != SpvOpVariable) {
214 elementIndex++;
215 continue;
216 }
217
218 Instruction* type = GetStorageType(var);
219 uint32_t extractId = TakeNextId();
220 if (extractId == 0) {
221 return false;
222 }
223 std::unique_ptr<Instruction> extract(new Instruction(
224 context(), SpvOpCompositeExtract, type->result_id(), extractId,
225 std::initializer_list<Operand>{
226 {SPV_OPERAND_TYPE_ID, {storeInput}},
227 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {elementIndex++}}}));
228 auto iter = where.InsertBefore(std::move(extract));
229 get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
230 context()->set_instr_block(&*iter, block);
231
232 // Create the store.
233 std::unique_ptr<Instruction> newStore(
234 new Instruction(context(), SpvOpStore, 0, 0,
235 std::initializer_list<Operand>{
236 {SPV_OPERAND_TYPE_ID, {var->result_id()}},
237 {SPV_OPERAND_TYPE_ID, {extractId}}}));
238 // Copy memory access attributes which start at index 2. Index 0 is the
239 // pointer and index 1 is the data.
240 for (uint32_t i = 2; i < store->NumInOperands(); ++i) {
241 Operand copy(store->GetInOperand(i));
242 newStore->AddOperand(std::move(copy));
243 }
244 iter = where.InsertBefore(std::move(newStore));
245 get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
246 context()->set_instr_block(&*iter, block);
247 }
248 return true;
249}
250
251bool ScalarReplacementPass::ReplaceAccessChain(
252 Instruction* chain, const std::vector<Instruction*>& replacements) {
253 // Replaces the access chain with either another access chain (with one fewer
254 // indexes) or a direct use of the replacement variable.
255 uint32_t indexId = chain->GetSingleWordInOperand(1u);
256 const Instruction* index = get_def_use_mgr()->GetDef(indexId);
257 int64_t indexValue = context()
258 ->get_constant_mgr()
259 ->GetConstantFromInst(index)
260 ->GetSignExtendedValue();
261 if (indexValue < 0 ||
262 indexValue >= static_cast<int64_t>(replacements.size())) {
263 // Out of bounds access, this is illegal IR. Notice that OpAccessChain
264 // indexing is 0-based, so we should also reject index == size-of-array.
265 return false;
266 } else {
267 const Instruction* var = replacements[static_cast<size_t>(indexValue)];
268 if (chain->NumInOperands() > 2) {
269 // Replace input access chain with another access chain.
270 BasicBlock::iterator chainIter(chain);
271 uint32_t replacementId = TakeNextId();
272 if (replacementId == 0) {
273 return false;
274 }
275 std::unique_ptr<Instruction> replacementChain(new Instruction(
276 context(), chain->opcode(), chain->type_id(), replacementId,
277 std::initializer_list<Operand>{
278 {SPV_OPERAND_TYPE_ID, {var->result_id()}}}));
279 // Add the remaining indexes.
280 for (uint32_t i = 2; i < chain->NumInOperands(); ++i) {
281 Operand copy(chain->GetInOperand(i));
282 replacementChain->AddOperand(std::move(copy));
283 }
284 auto iter = chainIter.InsertBefore(std::move(replacementChain));
285 get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
286 context()->set_instr_block(&*iter, context()->get_instr_block(chain));
287 context()->ReplaceAllUsesWith(chain->result_id(), replacementId);
288 } else {
289 // Replace with a use of the variable.
290 context()->ReplaceAllUsesWith(chain->result_id(), var->result_id());
291 }
292 }
293
294 return true;
295}
296
297bool ScalarReplacementPass::CreateReplacementVariables(
298 Instruction* inst, std::vector<Instruction*>* replacements) {
299 Instruction* type = GetStorageType(inst);
300
301 std::unique_ptr<std::unordered_set<int64_t>> components_used =
302 GetUsedComponents(inst);
303
304 uint32_t elem = 0;
305 switch (type->opcode()) {
306 case SpvOpTypeStruct:
307 type->ForEachInOperand(
308 [this, inst, &elem, replacements, &components_used](uint32_t* id) {
309 if (!components_used || components_used->count(elem)) {
310 CreateVariable(*id, inst, elem, replacements);
311 } else {
312 replacements->push_back(CreateNullConstant(*id));
313 }
314 elem++;
315 });
316 break;
317 case SpvOpTypeArray:
318 for (uint32_t i = 0; i != GetArrayLength(type); ++i) {
319 if (!components_used || components_used->count(i)) {
320 CreateVariable(type->GetSingleWordInOperand(0u), inst, i,
321 replacements);
322 } else {
323 replacements->push_back(
324 CreateNullConstant(type->GetSingleWordInOperand(0u)));
325 }
326 }
327 break;
328
329 case SpvOpTypeMatrix:
330 case SpvOpTypeVector:
331 for (uint32_t i = 0; i != GetNumElements(type); ++i) {
332 CreateVariable(type->GetSingleWordInOperand(0u), inst, i, replacements);
333 }
334 break;
335
336 default:
337 assert(false && "Unexpected type.");
338 break;
339 }
340
341 TransferAnnotations(inst, replacements);
342 return std::find(replacements->begin(), replacements->end(), nullptr) ==
343 replacements->end();
344}
345
346void ScalarReplacementPass::TransferAnnotations(
347 const Instruction* source, std::vector<Instruction*>* replacements) {
348 // Only transfer invariant and restrict decorations on the variable. There are
349 // no type or member decorations that are necessary to transfer.
350 for (auto inst :
351 get_decoration_mgr()->GetDecorationsFor(source->result_id(), false)) {
352 assert(inst->opcode() == SpvOpDecorate);
353 uint32_t decoration = inst->GetSingleWordInOperand(1u);
354 if (decoration == SpvDecorationInvariant ||
355 decoration == SpvDecorationRestrict) {
356 for (auto var : *replacements) {
357 if (var == nullptr) {
358 continue;
359 }
360
361 std::unique_ptr<Instruction> annotation(
362 new Instruction(context(), SpvOpDecorate, 0, 0,
363 std::initializer_list<Operand>{
364 {SPV_OPERAND_TYPE_ID, {var->result_id()}},
365 {SPV_OPERAND_TYPE_DECORATION, {decoration}}}));
366 for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
367 Operand copy(inst->GetInOperand(i));
368 annotation->AddOperand(std::move(copy));
369 }
370 context()->AddAnnotationInst(std::move(annotation));
371 get_def_use_mgr()->AnalyzeInstUse(&*--context()->annotation_end());
372 }
373 }
374 }
375}
376
377void ScalarReplacementPass::CreateVariable(
378 uint32_t typeId, Instruction* varInst, uint32_t index,
379 std::vector<Instruction*>* replacements) {
380 uint32_t ptrId = GetOrCreatePointerType(typeId);
381 uint32_t id = TakeNextId();
382
383 if (id == 0) {
384 replacements->push_back(nullptr);
385 }
386
387 std::unique_ptr<Instruction> variable(new Instruction(
388 context(), SpvOpVariable, ptrId, id,
389 std::initializer_list<Operand>{
390 {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}}}));
391
392 BasicBlock* block = context()->get_instr_block(varInst);
393 block->begin().InsertBefore(std::move(variable));
394 Instruction* inst = &*block->begin();
395
396 // If varInst was initialized, make sure to initialize its replacement.
397 GetOrCreateInitialValue(varInst, index, inst);
398 get_def_use_mgr()->AnalyzeInstDefUse(inst);
399 context()->set_instr_block(inst, block);
400
401 // Copy decorations from the member to the new variable.
402 Instruction* typeInst = GetStorageType(varInst);
403 for (auto dec_inst :
404 get_decoration_mgr()->GetDecorationsFor(typeInst->result_id(), false)) {
405 uint32_t decoration;
406 if (dec_inst->opcode() != SpvOpMemberDecorate) {
407 continue;
408 }
409
410 if (dec_inst->GetSingleWordInOperand(1) != index) {
411 continue;
412 }
413
414 decoration = dec_inst->GetSingleWordInOperand(2u);
415 switch (decoration) {
416 case SpvDecorationRelaxedPrecision: {
417 std::unique_ptr<Instruction> new_dec_inst(
418 new Instruction(context(), SpvOpDecorate, 0, 0, {}));
419 new_dec_inst->AddOperand(Operand(SPV_OPERAND_TYPE_ID, {id}));
420 for (uint32_t i = 2; i < dec_inst->NumInOperandWords(); ++i) {
421 new_dec_inst->AddOperand(Operand(dec_inst->GetInOperand(i)));
422 }
423 context()->AddAnnotationInst(std::move(new_dec_inst));
424 } break;
425 default:
426 break;
427 }
428 }
429
430 replacements->push_back(inst);
431}
432
433uint32_t ScalarReplacementPass::GetOrCreatePointerType(uint32_t id) {
434 auto iter = pointee_to_pointer_.find(id);
435 if (iter != pointee_to_pointer_.end()) return iter->second;
436
437 analysis::Type* pointeeTy;
438 std::unique_ptr<analysis::Pointer> pointerTy;
439 std::tie(pointeeTy, pointerTy) =
440 context()->get_type_mgr()->GetTypeAndPointerType(id,
441 SpvStorageClassFunction);
442 uint32_t ptrId = 0;
443 if (pointeeTy->IsUniqueType()) {
444 // Non-ambiguous type, just ask the type manager for an id.
445 ptrId = context()->get_type_mgr()->GetTypeInstruction(pointerTy.get());
446 pointee_to_pointer_[id] = ptrId;
447 return ptrId;
448 }
449
450 // Ambiguous type. We must perform a linear search to try and find the right
451 // type.
452 for (auto global : context()->types_values()) {
453 if (global.opcode() == SpvOpTypePointer &&
454 global.GetSingleWordInOperand(0u) == SpvStorageClassFunction &&
455 global.GetSingleWordInOperand(1u) == id) {
456 if (get_decoration_mgr()->GetDecorationsFor(id, false).empty()) {
457 // Only reuse a decoration-less pointer of the correct type.
458 ptrId = global.result_id();
459 break;
460 }
461 }
462 }
463
464 if (ptrId != 0) {
465 pointee_to_pointer_[id] = ptrId;
466 return ptrId;
467 }
468
469 ptrId = TakeNextId();
470 context()->AddType(MakeUnique<Instruction>(
471 context(), SpvOpTypePointer, 0, ptrId,
472 std::initializer_list<Operand>{
473 {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}},
474 {SPV_OPERAND_TYPE_ID, {id}}}));
475 Instruction* ptr = &*--context()->types_values_end();
476 get_def_use_mgr()->AnalyzeInstDefUse(ptr);
477 pointee_to_pointer_[id] = ptrId;
478 // Register with the type manager if necessary.
479 context()->get_type_mgr()->RegisterType(ptrId, *pointerTy);
480
481 return ptrId;
482}
483
484void ScalarReplacementPass::GetOrCreateInitialValue(Instruction* source,
485 uint32_t index,
486 Instruction* newVar) {
487 assert(source->opcode() == SpvOpVariable);
488 if (source->NumInOperands() < 2) return;
489
490 uint32_t initId = source->GetSingleWordInOperand(1u);
491 uint32_t storageId = GetStorageType(newVar)->result_id();
492 Instruction* init = get_def_use_mgr()->GetDef(initId);
493 uint32_t newInitId = 0;
494 // TODO(dnovillo): Refactor this with constant propagation.
495 if (init->opcode() == SpvOpConstantNull) {
496 // Initialize to appropriate NULL.
497 auto iter = type_to_null_.find(storageId);
498 if (iter == type_to_null_.end()) {
499 newInitId = TakeNextId();
500 type_to_null_[storageId] = newInitId;
501 context()->AddGlobalValue(
502 MakeUnique<Instruction>(context(), SpvOpConstantNull, storageId,
503 newInitId, std::initializer_list<Operand>{}));
504 Instruction* newNull = &*--context()->types_values_end();
505 get_def_use_mgr()->AnalyzeInstDefUse(newNull);
506 } else {
507 newInitId = iter->second;
508 }
509 } else if (IsSpecConstantInst(init->opcode())) {
510 // Create a new constant extract.
511 newInitId = TakeNextId();
512 context()->AddGlobalValue(MakeUnique<Instruction>(
513 context(), SpvOpSpecConstantOp, storageId, newInitId,
514 std::initializer_list<Operand>{
515 {SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER, {SpvOpCompositeExtract}},
516 {SPV_OPERAND_TYPE_ID, {init->result_id()}},
517 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {index}}}));
518 Instruction* newSpecConst = &*--context()->types_values_end();
519 get_def_use_mgr()->AnalyzeInstDefUse(newSpecConst);
520 } else if (init->opcode() == SpvOpConstantComposite) {
521 // Get the appropriate index constant.
522 newInitId = init->GetSingleWordInOperand(index);
523 Instruction* element = get_def_use_mgr()->GetDef(newInitId);
524 if (element->opcode() == SpvOpUndef) {
525 // Undef is not a valid initializer for a variable.
526 newInitId = 0;
527 }
528 } else {
529 assert(false);
530 }
531
532 if (newInitId != 0) {
533 newVar->AddOperand({SPV_OPERAND_TYPE_ID, {newInitId}});
534 }
535}
536
537uint64_t ScalarReplacementPass::GetArrayLength(
538 const Instruction* arrayType) const {
539 assert(arrayType->opcode() == SpvOpTypeArray);
540 const Instruction* length =
541 get_def_use_mgr()->GetDef(arrayType->GetSingleWordInOperand(1u));
542 return context()
543 ->get_constant_mgr()
544 ->GetConstantFromInst(length)
545 ->GetZeroExtendedValue();
546}
547
548uint64_t ScalarReplacementPass::GetNumElements(const Instruction* type) const {
549 assert(type->opcode() == SpvOpTypeVector ||
550 type->opcode() == SpvOpTypeMatrix);
551 const Operand& op = type->GetInOperand(1u);
552 assert(op.words.size() <= 2);
553 uint64_t len = 0;
554 for (size_t i = 0; i != op.words.size(); ++i) {
555 len |= (static_cast<uint64_t>(op.words[i]) << (32ull * i));
556 }
557 return len;
558}
559
560bool ScalarReplacementPass::IsSpecConstant(uint32_t id) const {
561 const Instruction* inst = get_def_use_mgr()->GetDef(id);
562 assert(inst);
563 return spvOpcodeIsSpecConstant(inst->opcode());
564}
565
566Instruction* ScalarReplacementPass::GetStorageType(
567 const Instruction* inst) const {
568 assert(inst->opcode() == SpvOpVariable);
569
570 uint32_t ptrTypeId = inst->type_id();
571 uint32_t typeId =
572 get_def_use_mgr()->GetDef(ptrTypeId)->GetSingleWordInOperand(1u);
573 return get_def_use_mgr()->GetDef(typeId);
574}
575
576bool ScalarReplacementPass::CanReplaceVariable(
577 const Instruction* varInst) const {
578 assert(varInst->opcode() == SpvOpVariable);
579
580 // Can only replace function scope variables.
581 if (varInst->GetSingleWordInOperand(0u) != SpvStorageClassFunction) {
582 return false;
583 }
584
585 if (!CheckTypeAnnotations(get_def_use_mgr()->GetDef(varInst->type_id()))) {
586 return false;
587 }
588
589 const Instruction* typeInst = GetStorageType(varInst);
590 if (!CheckType(typeInst)) {
591 return false;
592 }
593
594 if (!CheckAnnotations(varInst)) {
595 return false;
596 }
597
598 if (!CheckUses(varInst)) {
599 return false;
600 }
601
602 return true;
603}
604
605bool ScalarReplacementPass::CheckType(const Instruction* typeInst) const {
606 if (!CheckTypeAnnotations(typeInst)) {
607 return false;
608 }
609
610 switch (typeInst->opcode()) {
611 case SpvOpTypeStruct:
612 // Don't bother with empty structs or very large structs.
613 if (typeInst->NumInOperands() == 0 ||
614 IsLargerThanSizeLimit(typeInst->NumInOperands())) {
615 return false;
616 }
617 return true;
618 case SpvOpTypeArray:
619 if (IsSpecConstant(typeInst->GetSingleWordInOperand(1u))) {
620 return false;
621 }
622 if (IsLargerThanSizeLimit(GetArrayLength(typeInst))) {
623 return false;
624 }
625 return true;
626 // TODO(alanbaker): Develop some heuristics for when this should be
627 // re-enabled.
628 //// Specifically including matrix and vector in an attempt to reduce the
629 //// number of vector registers required.
630 // case SpvOpTypeMatrix:
631 // case SpvOpTypeVector:
632 // if (IsLargerThanSizeLimit(GetNumElements(typeInst))) return false;
633 // return true;
634
635 case SpvOpTypeRuntimeArray:
636 default:
637 return false;
638 }
639}
640
641bool ScalarReplacementPass::CheckTypeAnnotations(
642 const Instruction* typeInst) const {
643 for (auto inst :
644 get_decoration_mgr()->GetDecorationsFor(typeInst->result_id(), false)) {
645 uint32_t decoration;
646 if (inst->opcode() == SpvOpDecorate) {
647 decoration = inst->GetSingleWordInOperand(1u);
648 } else {
649 assert(inst->opcode() == SpvOpMemberDecorate);
650 decoration = inst->GetSingleWordInOperand(2u);
651 }
652
653 switch (decoration) {
654 case SpvDecorationRowMajor:
655 case SpvDecorationColMajor:
656 case SpvDecorationArrayStride:
657 case SpvDecorationMatrixStride:
658 case SpvDecorationCPacked:
659 case SpvDecorationInvariant:
660 case SpvDecorationRestrict:
661 case SpvDecorationOffset:
662 case SpvDecorationAlignment:
663 case SpvDecorationAlignmentId:
664 case SpvDecorationMaxByteOffset:
665 case SpvDecorationRelaxedPrecision:
666 break;
667 default:
668 return false;
669 }
670 }
671
672 return true;
673}
674
675bool ScalarReplacementPass::CheckAnnotations(const Instruction* varInst) const {
676 for (auto inst :
677 get_decoration_mgr()->GetDecorationsFor(varInst->result_id(), false)) {
678 assert(inst->opcode() == SpvOpDecorate);
679 uint32_t decoration = inst->GetSingleWordInOperand(1u);
680 switch (decoration) {
681 case SpvDecorationInvariant:
682 case SpvDecorationRestrict:
683 case SpvDecorationAlignment:
684 case SpvDecorationAlignmentId:
685 case SpvDecorationMaxByteOffset:
686 break;
687 default:
688 return false;
689 }
690 }
691
692 return true;
693}
694
695bool ScalarReplacementPass::CheckUses(const Instruction* inst) const {
696 VariableStats stats = {0, 0};
697 bool ok = CheckUses(inst, &stats);
698
699 // TODO(alanbaker/greg-lunarg): Add some meaningful heuristics about when
700 // SRoA is costly, such as when the structure has many (unaccessed?)
701 // members.
702
703 return ok;
704}
705
706bool ScalarReplacementPass::CheckUses(const Instruction* inst,
707 VariableStats* stats) const {
708 uint64_t max_legal_index = GetMaxLegalIndex(inst);
709
710 bool ok = true;
711 get_def_use_mgr()->ForEachUse(inst, [this, max_legal_index, stats, &ok](
712 const Instruction* user,
713 uint32_t index) {
714 // Annotations are check as a group separately.
715 if (!IsAnnotationInst(user->opcode())) {
716 switch (user->opcode()) {
717 case SpvOpAccessChain:
718 case SpvOpInBoundsAccessChain:
719 if (index == 2u && user->NumInOperands() > 1) {
720 uint32_t id = user->GetSingleWordInOperand(1u);
721 const Instruction* opInst = get_def_use_mgr()->GetDef(id);
722 const auto* constant =
723 context()->get_constant_mgr()->GetConstantFromInst(opInst);
724 if (!constant) {
725 ok = false;
726 } else if (constant->GetZeroExtendedValue() >= max_legal_index) {
727 ok = false;
728 } else {
729 if (!CheckUsesRelaxed(user)) ok = false;
730 }
731 stats->num_partial_accesses++;
732 } else {
733 ok = false;
734 }
735 break;
736 case SpvOpLoad:
737 if (!CheckLoad(user, index)) ok = false;
738 stats->num_full_accesses++;
739 break;
740 case SpvOpStore:
741 if (!CheckStore(user, index)) ok = false;
742 stats->num_full_accesses++;
743 break;
744 case SpvOpName:
745 case SpvOpMemberName:
746 break;
747 default:
748 ok = false;
749 break;
750 }
751 }
752 });
753
754 return ok;
755}
756
757bool ScalarReplacementPass::CheckUsesRelaxed(const Instruction* inst) const {
758 bool ok = true;
759 get_def_use_mgr()->ForEachUse(
760 inst, [this, &ok](const Instruction* user, uint32_t index) {
761 switch (user->opcode()) {
762 case SpvOpAccessChain:
763 case SpvOpInBoundsAccessChain:
764 if (index != 2u) {
765 ok = false;
766 } else {
767 if (!CheckUsesRelaxed(user)) ok = false;
768 }
769 break;
770 case SpvOpLoad:
771 if (!CheckLoad(user, index)) ok = false;
772 break;
773 case SpvOpStore:
774 if (!CheckStore(user, index)) ok = false;
775 break;
776 default:
777 ok = false;
778 break;
779 }
780 });
781
782 return ok;
783}
784
785bool ScalarReplacementPass::CheckLoad(const Instruction* inst,
786 uint32_t index) const {
787 if (index != 2u) return false;
788 if (inst->NumInOperands() >= 2 &&
789 inst->GetSingleWordInOperand(1u) & SpvMemoryAccessVolatileMask)
790 return false;
791 return true;
792}
793
794bool ScalarReplacementPass::CheckStore(const Instruction* inst,
795 uint32_t index) const {
796 if (index != 0u) return false;
797 if (inst->NumInOperands() >= 3 &&
798 inst->GetSingleWordInOperand(2u) & SpvMemoryAccessVolatileMask)
799 return false;
800 return true;
801}
802bool ScalarReplacementPass::IsLargerThanSizeLimit(uint64_t length) const {
803 if (max_num_elements_ == 0) {
804 return false;
805 }
806 return length > max_num_elements_;
807}
808
809std::unique_ptr<std::unordered_set<int64_t>>
810ScalarReplacementPass::GetUsedComponents(Instruction* inst) {
811 std::unique_ptr<std::unordered_set<int64_t>> result(
812 new std::unordered_set<int64_t>());
813
814 analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
815
816 def_use_mgr->WhileEachUser(inst, [&result, def_use_mgr,
817 this](Instruction* use) {
818 switch (use->opcode()) {
819 case SpvOpLoad: {
820 // Look for extract from the load.
821 std::vector<uint32_t> t;
822 if (def_use_mgr->WhileEachUser(use, [&t](Instruction* use2) {
823 if (use2->opcode() != SpvOpCompositeExtract ||
824 use2->NumInOperands() <= 1) {
825 return false;
826 }
827 t.push_back(use2->GetSingleWordInOperand(1));
828 return true;
829 })) {
830 result->insert(t.begin(), t.end());
831 return true;
832 } else {
833 result.reset(nullptr);
834 return false;
835 }
836 }
837 case SpvOpName:
838 case SpvOpMemberName:
839 case SpvOpStore:
840 // No components are used.
841 return true;
842 case SpvOpAccessChain:
843 case SpvOpInBoundsAccessChain: {
844 // Add the first index it if is a constant.
845 // TODO: Could be improved by checking if the address is used in a load.
846 analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
847 uint32_t index_id = use->GetSingleWordInOperand(1);
848 const analysis::Constant* index_const =
849 const_mgr->FindDeclaredConstant(index_id);
850 if (index_const) {
851 result->insert(index_const->GetSignExtendedValue());
852 return true;
853 } else {
854 // Could be any element. Assuming all are used.
855 result.reset(nullptr);
856 return false;
857 }
858 }
859 default:
860 // We do not know what is happening. Have to assume the worst.
861 result.reset(nullptr);
862 return false;
863 }
864 });
865
866 return result;
867}
868
869Instruction* ScalarReplacementPass::CreateNullConstant(uint32_t type_id) {
870 analysis::TypeManager* type_mgr = context()->get_type_mgr();
871 analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
872
873 const analysis::Type* type = type_mgr->GetType(type_id);
874 const analysis::Constant* null_const = const_mgr->GetConstant(type, {});
875 Instruction* null_inst =
876 const_mgr->GetDefiningInstruction(null_const, type_id);
877 if (null_inst != nullptr) {
878 context()->UpdateDefUse(null_inst);
879 }
880 return null_inst;
881}
882
883uint64_t ScalarReplacementPass::GetMaxLegalIndex(
884 const Instruction* var_inst) const {
885 assert(var_inst->opcode() == SpvOpVariable &&
886 "|var_inst| must be a variable instruction.");
887 Instruction* type = GetStorageType(var_inst);
888 switch (type->opcode()) {
889 case SpvOpTypeStruct:
890 return type->NumInOperands();
891 case SpvOpTypeArray:
892 return GetArrayLength(type);
893 case SpvOpTypeMatrix:
894 case SpvOpTypeVector:
895 return GetNumElements(type);
896 default:
897 return 0;
898 }
899 return 0;
900}
901
902} // namespace opt
903} // namespace spvtools
904