1// Copyright (c) 2018 Google LLC.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "source/opt/copy_prop_arrays.h"
16
17#include <utility>
18
19#include "source/opt/ir_builder.h"
20
21namespace spvtools {
22namespace opt {
namespace {

// In-operand positions of the instruction operands read by this pass.
constexpr uint32_t kLoadPointerInOperand = 0;             // OpLoad: Pointer
constexpr uint32_t kStorePointerInOperand = 0;            // OpStore: Pointer
constexpr uint32_t kStoreObjectInOperand = 1;             // OpStore: Object
constexpr uint32_t kCompositeExtractObjectInOperand = 0;  // OpCompositeExtract: Composite
constexpr uint32_t kTypePointerStorageClassInIdx = 0;     // OpTypePointer: Storage Class
constexpr uint32_t kTypePointerPointeeInIdx = 1;          // OpTypePointer: pointee Type

}  // namespace
33
34Pass::Status CopyPropagateArrays::Process() {
35 bool modified = false;
36 for (Function& function : *get_module()) {
37 BasicBlock* entry_bb = &*function.begin();
38
39 for (auto var_inst = entry_bb->begin(); var_inst->opcode() == SpvOpVariable;
40 ++var_inst) {
41 if (!IsPointerToArrayType(var_inst->type_id())) {
42 continue;
43 }
44
45 // Find the only store to the entire memory location, if it exists.
46 Instruction* store_inst = FindStoreInstruction(&*var_inst);
47
48 if (!store_inst) {
49 continue;
50 }
51
52 std::unique_ptr<MemoryObject> source_object =
53 FindSourceObjectIfPossible(&*var_inst, store_inst);
54
55 if (source_object != nullptr) {
56 if (CanUpdateUses(&*var_inst, source_object->GetPointerTypeId(this))) {
57 modified = true;
58 PropagateObject(&*var_inst, source_object.get(), store_inst);
59 }
60 }
61 }
62 }
63 return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
64}
65
66std::unique_ptr<CopyPropagateArrays::MemoryObject>
67CopyPropagateArrays::FindSourceObjectIfPossible(Instruction* var_inst,
68 Instruction* store_inst) {
69 assert(var_inst->opcode() == SpvOpVariable && "Expecting a variable.");
70
71 // Check that the variable is a composite object where |store_inst|
72 // dominates all of its loads.
73 if (!store_inst) {
74 return nullptr;
75 }
76
77 // Look at the loads to ensure they are dominated by the store.
78 if (!HasValidReferencesOnly(var_inst, store_inst)) {
79 return nullptr;
80 }
81
82 // If so, look at the store to see if it is the copy of an object.
83 std::unique_ptr<MemoryObject> source = GetSourceObjectIfAny(
84 store_inst->GetSingleWordInOperand(kStoreObjectInOperand));
85
86 if (!source) {
87 return nullptr;
88 }
89
90 // Ensure that |source| does not change between the point at which it is
91 // loaded, and the position in which |var_inst| is loaded.
92 //
93 // For now we will go with the easy to implement approach, and check that the
94 // entire variable (not just the specific component) is never written to.
95
96 if (!HasNoStores(source->GetVariable())) {
97 return nullptr;
98 }
99 return source;
100}
101
102Instruction* CopyPropagateArrays::FindStoreInstruction(
103 const Instruction* var_inst) const {
104 Instruction* store_inst = nullptr;
105 get_def_use_mgr()->WhileEachUser(
106 var_inst, [&store_inst, var_inst](Instruction* use) {
107 if (use->opcode() == SpvOpStore &&
108 use->GetSingleWordInOperand(kStorePointerInOperand) ==
109 var_inst->result_id()) {
110 if (store_inst == nullptr) {
111 store_inst = use;
112 } else {
113 store_inst = nullptr;
114 return false;
115 }
116 }
117 return true;
118 });
119 return store_inst;
120}
121
122void CopyPropagateArrays::PropagateObject(Instruction* var_inst,
123 MemoryObject* source,
124 Instruction* insertion_point) {
125 assert(var_inst->opcode() == SpvOpVariable &&
126 "This function propagates variables.");
127
128 Instruction* new_access_chain = BuildNewAccessChain(insertion_point, source);
129 context()->KillNamesAndDecorates(var_inst);
130 UpdateUses(var_inst, new_access_chain);
131}
132
133Instruction* CopyPropagateArrays::BuildNewAccessChain(
134 Instruction* insertion_point,
135 CopyPropagateArrays::MemoryObject* source) const {
136 InstructionBuilder builder(
137 context(), insertion_point,
138 IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
139
140 if (source->AccessChain().size() == 0) {
141 return source->GetVariable();
142 }
143
144 return builder.AddAccessChain(source->GetPointerTypeId(this),
145 source->GetVariable()->result_id(),
146 source->AccessChain());
147}
148
149bool CopyPropagateArrays::HasNoStores(Instruction* ptr_inst) {
150 return get_def_use_mgr()->WhileEachUser(ptr_inst, [this](Instruction* use) {
151 if (use->opcode() == SpvOpLoad) {
152 return true;
153 } else if (use->opcode() == SpvOpAccessChain) {
154 return HasNoStores(use);
155 } else if (use->IsDecoration() || use->opcode() == SpvOpName) {
156 return true;
157 } else if (use->opcode() == SpvOpStore) {
158 return false;
159 } else if (use->opcode() == SpvOpImageTexelPointer) {
160 return true;
161 }
162 // Some other instruction. Be conservative.
163 return false;
164 });
165}
166
167bool CopyPropagateArrays::HasValidReferencesOnly(Instruction* ptr_inst,
168 Instruction* store_inst) {
169 BasicBlock* store_block = context()->get_instr_block(store_inst);
170 DominatorAnalysis* dominator_analysis =
171 context()->GetDominatorAnalysis(store_block->GetParent());
172
173 return get_def_use_mgr()->WhileEachUser(
174 ptr_inst,
175 [this, store_inst, dominator_analysis, ptr_inst](Instruction* use) {
176 if (use->opcode() == SpvOpLoad ||
177 use->opcode() == SpvOpImageTexelPointer) {
178 // TODO: If there are many load in the same BB as |store_inst| the
179 // time to do the multiple traverses can add up. Consider collecting
180 // those loads and doing a single traversal.
181 return dominator_analysis->Dominates(store_inst, use);
182 } else if (use->opcode() == SpvOpAccessChain) {
183 return HasValidReferencesOnly(use, store_inst);
184 } else if (use->IsDecoration() || use->opcode() == SpvOpName) {
185 return true;
186 } else if (use->opcode() == SpvOpStore) {
187 // If we are storing to part of the object it is not an candidate.
188 return ptr_inst->opcode() == SpvOpVariable &&
189 store_inst->GetSingleWordInOperand(kStorePointerInOperand) ==
190 ptr_inst->result_id();
191 }
192 // Some other instruction. Be conservative.
193 return false;
194 });
195}
196
197std::unique_ptr<CopyPropagateArrays::MemoryObject>
198CopyPropagateArrays::GetSourceObjectIfAny(uint32_t result) {
199 Instruction* result_inst = context()->get_def_use_mgr()->GetDef(result);
200
201 switch (result_inst->opcode()) {
202 case SpvOpLoad:
203 return BuildMemoryObjectFromLoad(result_inst);
204 case SpvOpCompositeExtract:
205 return BuildMemoryObjectFromExtract(result_inst);
206 case SpvOpCompositeConstruct:
207 return BuildMemoryObjectFromCompositeConstruct(result_inst);
208 case SpvOpCopyObject:
209 return GetSourceObjectIfAny(result_inst->GetSingleWordInOperand(0));
210 case SpvOpCompositeInsert:
211 return BuildMemoryObjectFromInsert(result_inst);
212 default:
213 return nullptr;
214 }
215}
216
217std::unique_ptr<CopyPropagateArrays::MemoryObject>
218CopyPropagateArrays::BuildMemoryObjectFromLoad(Instruction* load_inst) {
219 std::vector<uint32_t> components_in_reverse;
220 analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
221
222 Instruction* current_inst = def_use_mgr->GetDef(
223 load_inst->GetSingleWordInOperand(kLoadPointerInOperand));
224
225 // Build the access chain for the memory object by collecting the indices used
226 // in the OpAccessChain instructions. If we find a variable index, then
227 // return |nullptr| because we cannot know for sure which memory location is
228 // used.
229 //
230 // It is built in reverse order because the different |OpAccessChain|
231 // instructions are visited in reverse order from which they are applied.
232 while (current_inst->opcode() == SpvOpAccessChain) {
233 for (uint32_t i = current_inst->NumInOperands() - 1; i >= 1; --i) {
234 uint32_t element_index_id = current_inst->GetSingleWordInOperand(i);
235 components_in_reverse.push_back(element_index_id);
236 }
237 current_inst = def_use_mgr->GetDef(current_inst->GetSingleWordInOperand(0));
238 }
239
240 // If the address in the load is not constructed from an |OpVariable|
241 // instruction followed by a series of |OpAccessChain| instructions, then
242 // return |nullptr| because we cannot identify the owner or access chain
243 // exactly.
244 if (current_inst->opcode() != SpvOpVariable) {
245 return nullptr;
246 }
247
248 // Build the memory object. Use |rbegin| and |rend| to put the access chain
249 // back in the correct order.
250 return std::unique_ptr<CopyPropagateArrays::MemoryObject>(
251 new MemoryObject(current_inst, components_in_reverse.rbegin(),
252 components_in_reverse.rend()));
253}
254
255std::unique_ptr<CopyPropagateArrays::MemoryObject>
256CopyPropagateArrays::BuildMemoryObjectFromExtract(Instruction* extract_inst) {
257 assert(extract_inst->opcode() == SpvOpCompositeExtract &&
258 "Expecting an OpCompositeExtract instruction.");
259 analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
260
261 std::unique_ptr<MemoryObject> result = GetSourceObjectIfAny(
262 extract_inst->GetSingleWordInOperand(kCompositeExtractObjectInOperand));
263
264 if (result) {
265 analysis::Integer int_type(32, false);
266 const analysis::Type* uint32_type =
267 context()->get_type_mgr()->GetRegisteredType(&int_type);
268
269 std::vector<uint32_t> components;
270 // Convert the indices in the extract instruction to a series of ids that
271 // can be used by the |OpAccessChain| instruction.
272 for (uint32_t i = 1; i < extract_inst->NumInOperands(); ++i) {
273 uint32_t index = extract_inst->GetSingleWordInOperand(i);
274 const analysis::Constant* index_const =
275 const_mgr->GetConstant(uint32_type, {index});
276 components.push_back(
277 const_mgr->GetDefiningInstruction(index_const)->result_id());
278 }
279 result->GetMember(components);
280 return result;
281 }
282 return nullptr;
283}
284
285std::unique_ptr<CopyPropagateArrays::MemoryObject>
286CopyPropagateArrays::BuildMemoryObjectFromCompositeConstruct(
287 Instruction* conststruct_inst) {
288 assert(conststruct_inst->opcode() == SpvOpCompositeConstruct &&
289 "Expecting an OpCompositeConstruct instruction.");
290
291 // If every operand in the instruction are part of the same memory object, and
292 // are being combined in the same order, then the result is the same as the
293 // parent.
294
295 std::unique_ptr<MemoryObject> memory_object =
296 GetSourceObjectIfAny(conststruct_inst->GetSingleWordInOperand(0));
297
298 if (!memory_object) {
299 return nullptr;
300 }
301
302 if (!memory_object->IsMember()) {
303 return nullptr;
304 }
305
306 analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
307 const analysis::Constant* last_access =
308 const_mgr->FindDeclaredConstant(memory_object->AccessChain().back());
309 if (!last_access || !last_access->type()->AsInteger()) {
310 return nullptr;
311 }
312
313 if (last_access->GetU32() != 0) {
314 return nullptr;
315 }
316
317 memory_object->GetParent();
318
319 if (memory_object->GetNumberOfMembers() !=
320 conststruct_inst->NumInOperands()) {
321 return nullptr;
322 }
323
324 for (uint32_t i = 1; i < conststruct_inst->NumInOperands(); ++i) {
325 std::unique_ptr<MemoryObject> member_object =
326 GetSourceObjectIfAny(conststruct_inst->GetSingleWordInOperand(i));
327
328 if (!member_object) {
329 return nullptr;
330 }
331
332 if (!member_object->IsMember()) {
333 return nullptr;
334 }
335
336 if (!memory_object->Contains(member_object.get())) {
337 return nullptr;
338 }
339
340 last_access =
341 const_mgr->FindDeclaredConstant(member_object->AccessChain().back());
342 if (!last_access || !last_access->type()->AsInteger()) {
343 return nullptr;
344 }
345
346 if (last_access->GetU32() != i) {
347 return nullptr;
348 }
349 }
350 return memory_object;
351}
352
353std::unique_ptr<CopyPropagateArrays::MemoryObject>
354CopyPropagateArrays::BuildMemoryObjectFromInsert(Instruction* insert_inst) {
355 assert(insert_inst->opcode() == SpvOpCompositeInsert &&
356 "Expecting an OpCompositeInsert instruction.");
357
358 analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
359 analysis::TypeManager* type_mgr = context()->get_type_mgr();
360 analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
361 const analysis::Type* result_type = type_mgr->GetType(insert_inst->type_id());
362
363 uint32_t number_of_elements = 0;
364 if (const analysis::Struct* struct_type = result_type->AsStruct()) {
365 number_of_elements =
366 static_cast<uint32_t>(struct_type->element_types().size());
367 } else if (const analysis::Array* array_type = result_type->AsArray()) {
368 const analysis::Constant* length_const =
369 const_mgr->FindDeclaredConstant(array_type->LengthId());
370 number_of_elements = length_const->GetU32();
371 } else if (const analysis::Vector* vector_type = result_type->AsVector()) {
372 number_of_elements = vector_type->element_count();
373 } else if (const analysis::Matrix* matrix_type = result_type->AsMatrix()) {
374 number_of_elements = matrix_type->element_count();
375 }
376
377 if (number_of_elements == 0) {
378 return nullptr;
379 }
380
381 if (insert_inst->NumInOperands() != 3) {
382 return nullptr;
383 }
384
385 if (insert_inst->GetSingleWordInOperand(2) != number_of_elements - 1) {
386 return nullptr;
387 }
388
389 std::unique_ptr<MemoryObject> memory_object =
390 GetSourceObjectIfAny(insert_inst->GetSingleWordInOperand(0));
391
392 if (!memory_object) {
393 return nullptr;
394 }
395
396 if (!memory_object->IsMember()) {
397 return nullptr;
398 }
399
400 const analysis::Constant* last_access =
401 const_mgr->FindDeclaredConstant(memory_object->AccessChain().back());
402 if (!last_access || !last_access->type()->AsInteger()) {
403 return nullptr;
404 }
405
406 if (last_access->GetU32() != number_of_elements - 1) {
407 return nullptr;
408 }
409
410 memory_object->GetParent();
411
412 Instruction* current_insert =
413 def_use_mgr->GetDef(insert_inst->GetSingleWordInOperand(1));
414 for (uint32_t i = number_of_elements - 1; i > 0; --i) {
415 if (current_insert->opcode() != SpvOpCompositeInsert) {
416 return nullptr;
417 }
418
419 if (current_insert->NumInOperands() != 3) {
420 return nullptr;
421 }
422
423 if (current_insert->GetSingleWordInOperand(2) != i - 1) {
424 return nullptr;
425 }
426
427 std::unique_ptr<MemoryObject> current_memory_object =
428 GetSourceObjectIfAny(current_insert->GetSingleWordInOperand(0));
429
430 if (!current_memory_object) {
431 return nullptr;
432 }
433
434 if (!current_memory_object->IsMember()) {
435 return nullptr;
436 }
437
438 if (memory_object->AccessChain().size() + 1 !=
439 current_memory_object->AccessChain().size()) {
440 return nullptr;
441 }
442
443 if (!memory_object->Contains(current_memory_object.get())) {
444 return nullptr;
445 }
446
447 const analysis::Constant* current_last_access =
448 const_mgr->FindDeclaredConstant(
449 current_memory_object->AccessChain().back());
450 if (!current_last_access || !current_last_access->type()->AsInteger()) {
451 return nullptr;
452 }
453
454 if (current_last_access->GetU32() != i - 1) {
455 return nullptr;
456 }
457 current_insert =
458 def_use_mgr->GetDef(current_insert->GetSingleWordInOperand(1));
459 }
460
461 return memory_object;
462}
463
464bool CopyPropagateArrays::IsPointerToArrayType(uint32_t type_id) {
465 analysis::TypeManager* type_mgr = context()->get_type_mgr();
466 analysis::Pointer* pointer_type = type_mgr->GetType(type_id)->AsPointer();
467 if (pointer_type) {
468 return pointer_type->pointee_type()->kind() == analysis::Type::kArray ||
469 pointer_type->pointee_type()->kind() == analysis::Type::kImage;
470 }
471 return false;
472}
473
474bool CopyPropagateArrays::CanUpdateUses(Instruction* original_ptr_inst,
475 uint32_t type_id) {
476 analysis::TypeManager* type_mgr = context()->get_type_mgr();
477 analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
478 analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
479
480 analysis::Type* type = type_mgr->GetType(type_id);
481 if (type->AsRuntimeArray()) {
482 return false;
483 }
484
485 if (!type->AsStruct() && !type->AsArray() && !type->AsPointer()) {
486 // If the type is not an aggregate, then the desired type must be the
487 // same as the current type. No work to do, and we can do that.
488 return true;
489 }
490
491 return def_use_mgr->WhileEachUse(original_ptr_inst, [this, type_mgr,
492 const_mgr,
493 type](Instruction* use,
494 uint32_t) {
495 switch (use->opcode()) {
496 case SpvOpLoad: {
497 analysis::Pointer* pointer_type = type->AsPointer();
498 uint32_t new_type_id = type_mgr->GetId(pointer_type->pointee_type());
499
500 if (new_type_id != use->type_id()) {
501 return CanUpdateUses(use, new_type_id);
502 }
503 return true;
504 }
505 case SpvOpAccessChain: {
506 analysis::Pointer* pointer_type = type->AsPointer();
507 const analysis::Type* pointee_type = pointer_type->pointee_type();
508
509 std::vector<uint32_t> access_chain;
510 for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
511 const analysis::Constant* index_const =
512 const_mgr->FindDeclaredConstant(use->GetSingleWordInOperand(i));
513 if (index_const) {
514 access_chain.push_back(index_const->GetU32());
515 } else {
516 // Variable index means the type is a type where every element
517 // is the same type. Use element 0 to get the type.
518 access_chain.push_back(0);
519 }
520 }
521
522 const analysis::Type* new_pointee_type =
523 type_mgr->GetMemberType(pointee_type, access_chain);
524 analysis::Pointer pointerTy(new_pointee_type,
525 pointer_type->storage_class());
526 uint32_t new_pointer_type_id =
527 context()->get_type_mgr()->GetTypeInstruction(&pointerTy);
528 if (new_pointer_type_id == 0) {
529 return false;
530 }
531
532 if (new_pointer_type_id != use->type_id()) {
533 return CanUpdateUses(use, new_pointer_type_id);
534 }
535 return true;
536 }
537 case SpvOpCompositeExtract: {
538 std::vector<uint32_t> access_chain;
539 for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
540 access_chain.push_back(use->GetSingleWordInOperand(i));
541 }
542
543 const analysis::Type* new_type =
544 type_mgr->GetMemberType(type, access_chain);
545 uint32_t new_type_id = type_mgr->GetTypeInstruction(new_type);
546 if (new_type_id == 0) {
547 return false;
548 }
549
550 if (new_type_id != use->type_id()) {
551 return CanUpdateUses(use, new_type_id);
552 }
553 return true;
554 }
555 case SpvOpStore:
556 // If needed, we can create an element-by-element copy to change the
557 // type of the value being stored. This way we can always handled
558 // stores.
559 return true;
560 case SpvOpImageTexelPointer:
561 case SpvOpName:
562 return true;
563 default:
564 return use->IsDecoration();
565 }
566 });
567}
568void CopyPropagateArrays::UpdateUses(Instruction* original_ptr_inst,
569 Instruction* new_ptr_inst) {
570 analysis::TypeManager* type_mgr = context()->get_type_mgr();
571 analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
572 analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
573
574 std::vector<std::pair<Instruction*, uint32_t> > uses;
575 def_use_mgr->ForEachUse(original_ptr_inst,
576 [&uses](Instruction* use, uint32_t index) {
577 uses.push_back({use, index});
578 });
579
580 for (auto pair : uses) {
581 Instruction* use = pair.first;
582 uint32_t index = pair.second;
583 switch (use->opcode()) {
584 case SpvOpLoad: {
585 // Replace the actual use.
586 context()->ForgetUses(use);
587 use->SetOperand(index, {new_ptr_inst->result_id()});
588
589 // Update the type.
590 Instruction* pointer_type_inst =
591 def_use_mgr->GetDef(new_ptr_inst->type_id());
592 uint32_t new_type_id =
593 pointer_type_inst->GetSingleWordInOperand(kTypePointerPointeeInIdx);
594 if (new_type_id != use->type_id()) {
595 use->SetResultType(new_type_id);
596 context()->AnalyzeUses(use);
597 UpdateUses(use, use);
598 } else {
599 context()->AnalyzeUses(use);
600 }
601 } break;
602 case SpvOpAccessChain: {
603 // Update the actual use.
604 context()->ForgetUses(use);
605 use->SetOperand(index, {new_ptr_inst->result_id()});
606
607 // Convert the ids on the OpAccessChain to indices that can be used to
608 // get the specific member.
609 std::vector<uint32_t> access_chain;
610 for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
611 const analysis::Constant* index_const =
612 const_mgr->FindDeclaredConstant(use->GetSingleWordInOperand(i));
613 if (index_const) {
614 access_chain.push_back(index_const->GetU32());
615 } else {
616 // Variable index means the type is an type where every element
617 // is the same type. Use element 0 to get the type.
618 access_chain.push_back(0);
619 }
620 }
621
622 Instruction* pointer_type_inst =
623 get_def_use_mgr()->GetDef(new_ptr_inst->type_id());
624
625 uint32_t new_pointee_type_id = GetMemberTypeId(
626 pointer_type_inst->GetSingleWordInOperand(kTypePointerPointeeInIdx),
627 access_chain);
628
629 SpvStorageClass storage_class = static_cast<SpvStorageClass>(
630 pointer_type_inst->GetSingleWordInOperand(
631 kTypePointerStorageClassInIdx));
632
633 uint32_t new_pointer_type_id =
634 type_mgr->FindPointerToType(new_pointee_type_id, storage_class);
635
636 if (new_pointer_type_id != use->type_id()) {
637 use->SetResultType(new_pointer_type_id);
638 context()->AnalyzeUses(use);
639 UpdateUses(use, use);
640 } else {
641 context()->AnalyzeUses(use);
642 }
643 } break;
644 case SpvOpCompositeExtract: {
645 // Update the actual use.
646 context()->ForgetUses(use);
647 use->SetOperand(index, {new_ptr_inst->result_id()});
648
649 uint32_t new_type_id = new_ptr_inst->type_id();
650 std::vector<uint32_t> access_chain;
651 for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
652 access_chain.push_back(use->GetSingleWordInOperand(i));
653 }
654
655 new_type_id = GetMemberTypeId(new_type_id, access_chain);
656
657 if (new_type_id != use->type_id()) {
658 use->SetResultType(new_type_id);
659 context()->AnalyzeUses(use);
660 UpdateUses(use, use);
661 } else {
662 context()->AnalyzeUses(use);
663 }
664 } break;
665 case SpvOpStore:
666 // If the use is the pointer, then it is the single store to that
667 // variable. We do not want to replace it. Instead, it will become
668 // dead after all of the loads are removed, and ADCE will get rid of it.
669 //
670 // If the use is the object being stored, we will create a copy of the
671 // object turning it into the correct type. The copy is done by
672 // decomposing the object into the base type, which must be the same,
673 // and then rebuilding them.
674 if (index == 1) {
675 Instruction* target_pointer = def_use_mgr->GetDef(
676 use->GetSingleWordInOperand(kStorePointerInOperand));
677 Instruction* pointer_type =
678 def_use_mgr->GetDef(target_pointer->type_id());
679 uint32_t pointee_type_id =
680 pointer_type->GetSingleWordInOperand(kTypePointerPointeeInIdx);
681 uint32_t copy = GenerateCopy(original_ptr_inst, pointee_type_id, use);
682
683 context()->ForgetUses(use);
684 use->SetInOperand(index, {copy});
685 context()->AnalyzeUses(use);
686 }
687 break;
688 case SpvOpImageTexelPointer:
689 // We treat an OpImageTexelPointer as a load. The result type should
690 // always have the Image storage class, and should not need to be
691 // updated.
692
693 // Replace the actual use.
694 context()->ForgetUses(use);
695 use->SetOperand(index, {new_ptr_inst->result_id()});
696 context()->AnalyzeUses(use);
697 break;
698 default:
699 assert(false && "Don't know how to rewrite instruction");
700 break;
701 }
702 }
703}
704
705uint32_t CopyPropagateArrays::GetMemberTypeId(
706 uint32_t id, const std::vector<uint32_t>& access_chain) const {
707 for (uint32_t element_index : access_chain) {
708 Instruction* type_inst = get_def_use_mgr()->GetDef(id);
709 switch (type_inst->opcode()) {
710 case SpvOpTypeArray:
711 case SpvOpTypeRuntimeArray:
712 case SpvOpTypeMatrix:
713 case SpvOpTypeVector:
714 id = type_inst->GetSingleWordInOperand(0);
715 break;
716 case SpvOpTypeStruct:
717 id = type_inst->GetSingleWordInOperand(element_index);
718 break;
719 default:
720 break;
721 }
722 assert(id != 0 &&
723 "Tried to extract from an object where it cannot be done.");
724 }
725 return id;
726}
727
728void CopyPropagateArrays::MemoryObject::GetMember(
729 const std::vector<uint32_t>& access_chain) {
730 access_chain_.insert(access_chain_.end(), access_chain.begin(),
731 access_chain.end());
732}
733
734uint32_t CopyPropagateArrays::MemoryObject::GetNumberOfMembers() {
735 IRContext* context = variable_inst_->context();
736 analysis::TypeManager* type_mgr = context->get_type_mgr();
737
738 const analysis::Type* type = type_mgr->GetType(variable_inst_->type_id());
739 type = type->AsPointer()->pointee_type();
740
741 std::vector<uint32_t> access_indices = GetAccessIds();
742 type = type_mgr->GetMemberType(type, access_indices);
743
744 if (const analysis::Struct* struct_type = type->AsStruct()) {
745 return static_cast<uint32_t>(struct_type->element_types().size());
746 } else if (const analysis::Array* array_type = type->AsArray()) {
747 const analysis::Constant* length_const =
748 context->get_constant_mgr()->FindDeclaredConstant(
749 array_type->LengthId());
750 assert(length_const->type()->AsInteger());
751 return length_const->GetU32();
752 } else if (const analysis::Vector* vector_type = type->AsVector()) {
753 return vector_type->element_count();
754 } else if (const analysis::Matrix* matrix_type = type->AsMatrix()) {
755 return matrix_type->element_count();
756 } else {
757 return 0;
758 }
759}
760
761template <class iterator>
762CopyPropagateArrays::MemoryObject::MemoryObject(Instruction* var_inst,
763 iterator begin, iterator end)
764 : variable_inst_(var_inst), access_chain_(begin, end) {}
765
766std::vector<uint32_t> CopyPropagateArrays::MemoryObject::GetAccessIds() const {
767 analysis::ConstantManager* const_mgr =
768 variable_inst_->context()->get_constant_mgr();
769
770 std::vector<uint32_t> access_indices;
771 for (uint32_t id : AccessChain()) {
772 const analysis::Constant* element_index_const =
773 const_mgr->FindDeclaredConstant(id);
774 if (!element_index_const) {
775 access_indices.push_back(0);
776 } else {
777 access_indices.push_back(element_index_const->GetU32());
778 }
779 }
780 return access_indices;
781}
782
783bool CopyPropagateArrays::MemoryObject::Contains(
784 CopyPropagateArrays::MemoryObject* other) {
785 if (this->GetVariable() != other->GetVariable()) {
786 return false;
787 }
788
789 if (AccessChain().size() > other->AccessChain().size()) {
790 return false;
791 }
792
793 for (uint32_t i = 0; i < AccessChain().size(); i++) {
794 if (AccessChain()[i] != other->AccessChain()[i]) {
795 return false;
796 }
797 }
798 return true;
799}
800
801} // namespace opt
802} // namespace spvtools
803