1// Copyright (c) 2018 The Khronos Group Inc.
2// Copyright (c) 2018 Valve Corporation
3// Copyright (c) 2018 LunarG Inc.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16
17#include "source/opt/dead_insert_elim_pass.h"
18
19#include "source/opt/composite.h"
20#include "source/opt/ir_context.h"
21#include "source/opt/iterator.h"
22#include "spirv/1.2/GLSL.std.450.h"
23
24namespace spvtools {
25namespace opt {
26
namespace {

// In-operand indices passed to Instruction::GetSingleWordInOperand().

// Type declarations.
constexpr uint32_t kTypeVectorCountInIdx = 1;    // OpTypeVector: component count
constexpr uint32_t kTypeMatrixCountInIdx = 1;    // OpTypeMatrix: column count
constexpr uint32_t kTypeArrayLengthIdInIdx = 1;  // OpTypeArray: length id
constexpr uint32_t kTypeIntWidthInIdx = 0;       // OpTypeInt: bit width

// Constants.
constexpr uint32_t kConstantValueInIdx = 0;  // OpConstant: first value word

// OpCompositeInsert.
constexpr uint32_t kInsertObjectIdInIdx = 0;     // object being inserted
constexpr uint32_t kInsertCompositeIdInIdx = 1;  // composite inserted into

}  // anonymous namespace
38
39uint32_t DeadInsertElimPass::NumComponents(Instruction* typeInst) {
40 switch (typeInst->opcode()) {
41 case SpvOpTypeVector: {
42 return typeInst->GetSingleWordInOperand(kTypeVectorCountInIdx);
43 } break;
44 case SpvOpTypeMatrix: {
45 return typeInst->GetSingleWordInOperand(kTypeMatrixCountInIdx);
46 } break;
47 case SpvOpTypeArray: {
48 uint32_t lenId =
49 typeInst->GetSingleWordInOperand(kTypeArrayLengthIdInIdx);
50 Instruction* lenInst = get_def_use_mgr()->GetDef(lenId);
51 if (lenInst->opcode() != SpvOpConstant) return 0;
52 uint32_t lenTypeId = lenInst->type_id();
53 Instruction* lenTypeInst = get_def_use_mgr()->GetDef(lenTypeId);
54 // TODO(greg-lunarg): Support non-32-bit array length
55 if (lenTypeInst->GetSingleWordInOperand(kTypeIntWidthInIdx) != 32)
56 return 0;
57 return lenInst->GetSingleWordInOperand(kConstantValueInIdx);
58 } break;
59 case SpvOpTypeStruct: {
60 return typeInst->NumInOperands();
61 } break;
62 default: { return 0; } break;
63 }
64}
65
66void DeadInsertElimPass::MarkInsertChain(
67 Instruction* insertChain, std::vector<uint32_t>* pExtIndices,
68 uint32_t extOffset, std::unordered_set<uint32_t>* visited_phis) {
69 // Not currently optimizing array inserts.
70 Instruction* typeInst = get_def_use_mgr()->GetDef(insertChain->type_id());
71 if (typeInst->opcode() == SpvOpTypeArray) return;
72 // Insert chains are only composed of inserts and phis
73 if (insertChain->opcode() != SpvOpCompositeInsert &&
74 insertChain->opcode() != SpvOpPhi)
75 return;
76 // If extract indices are empty, mark all subcomponents if type
77 // is constant length.
78 if (pExtIndices == nullptr) {
79 uint32_t cnum = NumComponents(typeInst);
80 if (cnum > 0) {
81 std::vector<uint32_t> extIndices;
82 for (uint32_t i = 0; i < cnum; i++) {
83 extIndices.clear();
84 extIndices.push_back(i);
85 std::unordered_set<uint32_t> sub_visited_phis;
86 MarkInsertChain(insertChain, &extIndices, 0, &sub_visited_phis);
87 }
88 return;
89 }
90 }
91 Instruction* insInst = insertChain;
92 while (insInst->opcode() == SpvOpCompositeInsert) {
93 // If no extract indices, mark insert and inserted object (which might
94 // also be an insert chain) and continue up the chain though the input
95 // composite.
96 //
97 // Note: We mark inserted objects in this function (rather than in
98 // EliminateDeadInsertsOnePass) because in some cases, we can do it
99 // more accurately here.
100 if (pExtIndices == nullptr) {
101 liveInserts_.insert(insInst->result_id());
102 uint32_t objId = insInst->GetSingleWordInOperand(kInsertObjectIdInIdx);
103 std::unordered_set<uint32_t> obj_visited_phis;
104 MarkInsertChain(get_def_use_mgr()->GetDef(objId), nullptr, 0,
105 &obj_visited_phis);
106 // If extract indices match insert, we are done. Mark insert and
107 // inserted object.
108 } else if (ExtInsMatch(*pExtIndices, insInst, extOffset)) {
109 liveInserts_.insert(insInst->result_id());
110 uint32_t objId = insInst->GetSingleWordInOperand(kInsertObjectIdInIdx);
111 std::unordered_set<uint32_t> obj_visited_phis;
112 MarkInsertChain(get_def_use_mgr()->GetDef(objId), nullptr, 0,
113 &obj_visited_phis);
114 break;
115 // If non-matching intersection, mark insert
116 } else if (ExtInsConflict(*pExtIndices, insInst, extOffset)) {
117 liveInserts_.insert(insInst->result_id());
118 // If more extract indices than insert, we are done. Use remaining
119 // extract indices to mark inserted object.
120 uint32_t numInsertIndices = insInst->NumInOperands() - 2;
121 if (pExtIndices->size() - extOffset > numInsertIndices) {
122 uint32_t objId = insInst->GetSingleWordInOperand(kInsertObjectIdInIdx);
123 std::unordered_set<uint32_t> obj_visited_phis;
124 MarkInsertChain(get_def_use_mgr()->GetDef(objId), pExtIndices,
125 extOffset + numInsertIndices, &obj_visited_phis);
126 break;
127 // If fewer extract indices than insert, also mark inserted object and
128 // continue up chain.
129 } else {
130 uint32_t objId = insInst->GetSingleWordInOperand(kInsertObjectIdInIdx);
131 std::unordered_set<uint32_t> obj_visited_phis;
132 MarkInsertChain(get_def_use_mgr()->GetDef(objId), nullptr, 0,
133 &obj_visited_phis);
134 }
135 }
136 // Get next insert in chain
137 const uint32_t compId =
138 insInst->GetSingleWordInOperand(kInsertCompositeIdInIdx);
139 insInst = get_def_use_mgr()->GetDef(compId);
140 }
141 // If insert chain ended with phi, do recursive call on each operand
142 if (insInst->opcode() != SpvOpPhi) return;
143 // Mark phi visited to prevent potential infinite loop. If phi is already
144 // visited, return to avoid infinite loop.
145 if (visited_phis->count(insInst->result_id()) != 0) return;
146 visited_phis->insert(insInst->result_id());
147
148 // Phis may have duplicate inputs values for different edges, prune incoming
149 // ids lists before recursing.
150 std::vector<uint32_t> ids;
151 for (uint32_t i = 0; i < insInst->NumInOperands(); i += 2) {
152 ids.push_back(insInst->GetSingleWordInOperand(i));
153 }
154 std::sort(ids.begin(), ids.end());
155 auto new_end = std::unique(ids.begin(), ids.end());
156 for (auto id_iter = ids.begin(); id_iter != new_end; ++id_iter) {
157 Instruction* pi = get_def_use_mgr()->GetDef(*id_iter);
158 MarkInsertChain(pi, pExtIndices, extOffset, visited_phis);
159 }
160}
161
162bool DeadInsertElimPass::EliminateDeadInserts(Function* func) {
163 bool modified = false;
164 bool lastmodified = true;
165 // Each pass can delete dead instructions, thus potentially revealing
166 // new dead insertions ie insertions with no uses.
167 while (lastmodified) {
168 lastmodified = EliminateDeadInsertsOnePass(func);
169 modified |= lastmodified;
170 }
171 return modified;
172}
173
174bool DeadInsertElimPass::EliminateDeadInsertsOnePass(Function* func) {
175 bool modified = false;
176 liveInserts_.clear();
177 visitedPhis_.clear();
178 // Mark all live inserts
179 for (auto bi = func->begin(); bi != func->end(); ++bi) {
180 for (auto ii = bi->begin(); ii != bi->end(); ++ii) {
181 // Only process Inserts and composite Phis
182 SpvOp op = ii->opcode();
183 Instruction* typeInst = get_def_use_mgr()->GetDef(ii->type_id());
184 if (op != SpvOpCompositeInsert &&
185 (op != SpvOpPhi || !spvOpcodeIsComposite(typeInst->opcode())))
186 continue;
187 // The marking algorithm can be expensive for large arrays and the
188 // efficacy of eliminating dead inserts into arrays is questionable.
189 // Skip optimizing array inserts for now. Just mark them live.
190 // TODO(greg-lunarg): Eliminate dead array inserts
191 if (op == SpvOpCompositeInsert) {
192 if (typeInst->opcode() == SpvOpTypeArray) {
193 liveInserts_.insert(ii->result_id());
194 continue;
195 }
196 }
197 const uint32_t id = ii->result_id();
198 get_def_use_mgr()->ForEachUser(id, [&ii, this](Instruction* user) {
199 switch (user->opcode()) {
200 case SpvOpCompositeInsert:
201 case SpvOpPhi:
202 // Use by insert or phi does not initiate marking
203 break;
204 case SpvOpCompositeExtract: {
205 // Capture extract indices
206 std::vector<uint32_t> extIndices;
207 uint32_t icnt = 0;
208 user->ForEachInOperand([&icnt, &extIndices](const uint32_t* idp) {
209 if (icnt > 0) extIndices.push_back(*idp);
210 ++icnt;
211 });
212 // Mark all inserts in chain that intersect with extract
213 std::unordered_set<uint32_t> visited_phis;
214 MarkInsertChain(&*ii, &extIndices, 0, &visited_phis);
215 } break;
216 default: {
217 // Mark inserts in chain for all components
218 MarkInsertChain(&*ii, nullptr, 0, nullptr);
219 } break;
220 }
221 });
222 }
223 }
224 // Find and disconnect dead inserts
225 std::vector<Instruction*> dead_instructions;
226 for (auto bi = func->begin(); bi != func->end(); ++bi) {
227 for (auto ii = bi->begin(); ii != bi->end(); ++ii) {
228 if (ii->opcode() != SpvOpCompositeInsert) continue;
229 const uint32_t id = ii->result_id();
230 if (liveInserts_.find(id) != liveInserts_.end()) continue;
231 const uint32_t replId =
232 ii->GetSingleWordInOperand(kInsertCompositeIdInIdx);
233 (void)context()->ReplaceAllUsesWith(id, replId);
234 dead_instructions.push_back(&*ii);
235 modified = true;
236 }
237 }
238 // DCE dead inserts
239 while (!dead_instructions.empty()) {
240 Instruction* inst = dead_instructions.back();
241 dead_instructions.pop_back();
242 DCEInst(inst, [&dead_instructions](Instruction* other_inst) {
243 auto i = std::find(dead_instructions.begin(), dead_instructions.end(),
244 other_inst);
245 if (i != dead_instructions.end()) {
246 dead_instructions.erase(i);
247 }
248 });
249 }
250 return modified;
251}
252
253Pass::Status DeadInsertElimPass::Process() {
254 // Process all entry point functions.
255 ProcessFunction pfn = [this](Function* fp) {
256 return EliminateDeadInserts(fp);
257 };
258 bool modified = context()->ProcessEntryPointCallTree(pfn);
259 return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
260}
261
262} // namespace opt
263} // namespace spvtools
264