1 | // Copyright (c) 2018 The Khronos Group Inc. |
2 | // Copyright (c) 2018 Valve Corporation |
3 | // Copyright (c) 2018 LunarG Inc. |
4 | // |
5 | // Licensed under the Apache License, Version 2.0 (the "License"); |
6 | // you may not use this file except in compliance with the License. |
7 | // You may obtain a copy of the License at |
8 | // |
9 | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | // |
11 | // Unless required by applicable law or agreed to in writing, software |
12 | // distributed under the License is distributed on an "AS IS" BASIS, |
13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
14 | // See the License for the specific language governing permissions and |
15 | // limitations under the License. |
16 | |
17 | #include "source/opt/dead_insert_elim_pass.h" |
18 | |
19 | #include "source/opt/composite.h" |
20 | #include "source/opt/ir_context.h" |
21 | #include "source/opt/iterator.h" |
22 | #include "spirv/1.2/GLSL.std.450.h" |
23 | |
24 | namespace spvtools { |
25 | namespace opt { |
26 | |
namespace {

// In-operand indices (result type / result id excluded) used to read the
// SPIR-V instructions this pass inspects.

// OpTypeVector: <component type> <component count>
constexpr uint32_t kTypeVectorCountInIdx = 1;
// OpTypeMatrix: <column type> <column count>
constexpr uint32_t kTypeMatrixCountInIdx = 1;
// OpTypeArray: <element type> <length id>
constexpr uint32_t kTypeArrayLengthIdInIdx = 1;
// OpTypeInt: <width> <signedness>
constexpr uint32_t kTypeIntWidthInIdx = 0;
// OpConstant: first word of the literal value
constexpr uint32_t kConstantValueInIdx = 0;
// OpCompositeInsert: <object> <composite> <indexes...>
constexpr uint32_t kInsertObjectIdInIdx = 0;
constexpr uint32_t kInsertCompositeIdInIdx = 1;

}  // anonymous namespace
38 | |
39 | uint32_t DeadInsertElimPass::NumComponents(Instruction* typeInst) { |
40 | switch (typeInst->opcode()) { |
41 | case SpvOpTypeVector: { |
42 | return typeInst->GetSingleWordInOperand(kTypeVectorCountInIdx); |
43 | } break; |
44 | case SpvOpTypeMatrix: { |
45 | return typeInst->GetSingleWordInOperand(kTypeMatrixCountInIdx); |
46 | } break; |
47 | case SpvOpTypeArray: { |
48 | uint32_t lenId = |
49 | typeInst->GetSingleWordInOperand(kTypeArrayLengthIdInIdx); |
50 | Instruction* lenInst = get_def_use_mgr()->GetDef(lenId); |
51 | if (lenInst->opcode() != SpvOpConstant) return 0; |
52 | uint32_t lenTypeId = lenInst->type_id(); |
53 | Instruction* lenTypeInst = get_def_use_mgr()->GetDef(lenTypeId); |
54 | // TODO(greg-lunarg): Support non-32-bit array length |
55 | if (lenTypeInst->GetSingleWordInOperand(kTypeIntWidthInIdx) != 32) |
56 | return 0; |
57 | return lenInst->GetSingleWordInOperand(kConstantValueInIdx); |
58 | } break; |
59 | case SpvOpTypeStruct: { |
60 | return typeInst->NumInOperands(); |
61 | } break; |
62 | default: { return 0; } break; |
63 | } |
64 | } |
65 | |
66 | void DeadInsertElimPass::MarkInsertChain( |
67 | Instruction* insertChain, std::vector<uint32_t>* pExtIndices, |
68 | uint32_t extOffset, std::unordered_set<uint32_t>* visited_phis) { |
69 | // Not currently optimizing array inserts. |
70 | Instruction* typeInst = get_def_use_mgr()->GetDef(insertChain->type_id()); |
71 | if (typeInst->opcode() == SpvOpTypeArray) return; |
72 | // Insert chains are only composed of inserts and phis |
73 | if (insertChain->opcode() != SpvOpCompositeInsert && |
74 | insertChain->opcode() != SpvOpPhi) |
75 | return; |
76 | // If extract indices are empty, mark all subcomponents if type |
77 | // is constant length. |
78 | if (pExtIndices == nullptr) { |
79 | uint32_t cnum = NumComponents(typeInst); |
80 | if (cnum > 0) { |
81 | std::vector<uint32_t> extIndices; |
82 | for (uint32_t i = 0; i < cnum; i++) { |
83 | extIndices.clear(); |
84 | extIndices.push_back(i); |
85 | std::unordered_set<uint32_t> sub_visited_phis; |
86 | MarkInsertChain(insertChain, &extIndices, 0, &sub_visited_phis); |
87 | } |
88 | return; |
89 | } |
90 | } |
91 | Instruction* insInst = insertChain; |
92 | while (insInst->opcode() == SpvOpCompositeInsert) { |
93 | // If no extract indices, mark insert and inserted object (which might |
94 | // also be an insert chain) and continue up the chain though the input |
95 | // composite. |
96 | // |
97 | // Note: We mark inserted objects in this function (rather than in |
98 | // EliminateDeadInsertsOnePass) because in some cases, we can do it |
99 | // more accurately here. |
100 | if (pExtIndices == nullptr) { |
101 | liveInserts_.insert(insInst->result_id()); |
102 | uint32_t objId = insInst->GetSingleWordInOperand(kInsertObjectIdInIdx); |
103 | std::unordered_set<uint32_t> obj_visited_phis; |
104 | MarkInsertChain(get_def_use_mgr()->GetDef(objId), nullptr, 0, |
105 | &obj_visited_phis); |
106 | // If extract indices match insert, we are done. Mark insert and |
107 | // inserted object. |
108 | } else if (ExtInsMatch(*pExtIndices, insInst, extOffset)) { |
109 | liveInserts_.insert(insInst->result_id()); |
110 | uint32_t objId = insInst->GetSingleWordInOperand(kInsertObjectIdInIdx); |
111 | std::unordered_set<uint32_t> obj_visited_phis; |
112 | MarkInsertChain(get_def_use_mgr()->GetDef(objId), nullptr, 0, |
113 | &obj_visited_phis); |
114 | break; |
115 | // If non-matching intersection, mark insert |
116 | } else if (ExtInsConflict(*pExtIndices, insInst, extOffset)) { |
117 | liveInserts_.insert(insInst->result_id()); |
118 | // If more extract indices than insert, we are done. Use remaining |
119 | // extract indices to mark inserted object. |
120 | uint32_t numInsertIndices = insInst->NumInOperands() - 2; |
121 | if (pExtIndices->size() - extOffset > numInsertIndices) { |
122 | uint32_t objId = insInst->GetSingleWordInOperand(kInsertObjectIdInIdx); |
123 | std::unordered_set<uint32_t> obj_visited_phis; |
124 | MarkInsertChain(get_def_use_mgr()->GetDef(objId), pExtIndices, |
125 | extOffset + numInsertIndices, &obj_visited_phis); |
126 | break; |
127 | // If fewer extract indices than insert, also mark inserted object and |
128 | // continue up chain. |
129 | } else { |
130 | uint32_t objId = insInst->GetSingleWordInOperand(kInsertObjectIdInIdx); |
131 | std::unordered_set<uint32_t> obj_visited_phis; |
132 | MarkInsertChain(get_def_use_mgr()->GetDef(objId), nullptr, 0, |
133 | &obj_visited_phis); |
134 | } |
135 | } |
136 | // Get next insert in chain |
137 | const uint32_t compId = |
138 | insInst->GetSingleWordInOperand(kInsertCompositeIdInIdx); |
139 | insInst = get_def_use_mgr()->GetDef(compId); |
140 | } |
141 | // If insert chain ended with phi, do recursive call on each operand |
142 | if (insInst->opcode() != SpvOpPhi) return; |
143 | // Mark phi visited to prevent potential infinite loop. If phi is already |
144 | // visited, return to avoid infinite loop. |
145 | if (visited_phis->count(insInst->result_id()) != 0) return; |
146 | visited_phis->insert(insInst->result_id()); |
147 | |
148 | // Phis may have duplicate inputs values for different edges, prune incoming |
149 | // ids lists before recursing. |
150 | std::vector<uint32_t> ids; |
151 | for (uint32_t i = 0; i < insInst->NumInOperands(); i += 2) { |
152 | ids.push_back(insInst->GetSingleWordInOperand(i)); |
153 | } |
154 | std::sort(ids.begin(), ids.end()); |
155 | auto new_end = std::unique(ids.begin(), ids.end()); |
156 | for (auto id_iter = ids.begin(); id_iter != new_end; ++id_iter) { |
157 | Instruction* pi = get_def_use_mgr()->GetDef(*id_iter); |
158 | MarkInsertChain(pi, pExtIndices, extOffset, visited_phis); |
159 | } |
160 | } |
161 | |
162 | bool DeadInsertElimPass::EliminateDeadInserts(Function* func) { |
163 | bool modified = false; |
164 | bool lastmodified = true; |
165 | // Each pass can delete dead instructions, thus potentially revealing |
166 | // new dead insertions ie insertions with no uses. |
167 | while (lastmodified) { |
168 | lastmodified = EliminateDeadInsertsOnePass(func); |
169 | modified |= lastmodified; |
170 | } |
171 | return modified; |
172 | } |
173 | |
174 | bool DeadInsertElimPass::EliminateDeadInsertsOnePass(Function* func) { |
175 | bool modified = false; |
176 | liveInserts_.clear(); |
177 | visitedPhis_.clear(); |
178 | // Mark all live inserts |
179 | for (auto bi = func->begin(); bi != func->end(); ++bi) { |
180 | for (auto ii = bi->begin(); ii != bi->end(); ++ii) { |
181 | // Only process Inserts and composite Phis |
182 | SpvOp op = ii->opcode(); |
183 | Instruction* typeInst = get_def_use_mgr()->GetDef(ii->type_id()); |
184 | if (op != SpvOpCompositeInsert && |
185 | (op != SpvOpPhi || !spvOpcodeIsComposite(typeInst->opcode()))) |
186 | continue; |
187 | // The marking algorithm can be expensive for large arrays and the |
188 | // efficacy of eliminating dead inserts into arrays is questionable. |
189 | // Skip optimizing array inserts for now. Just mark them live. |
190 | // TODO(greg-lunarg): Eliminate dead array inserts |
191 | if (op == SpvOpCompositeInsert) { |
192 | if (typeInst->opcode() == SpvOpTypeArray) { |
193 | liveInserts_.insert(ii->result_id()); |
194 | continue; |
195 | } |
196 | } |
197 | const uint32_t id = ii->result_id(); |
198 | get_def_use_mgr()->ForEachUser(id, [&ii, this](Instruction* user) { |
199 | switch (user->opcode()) { |
200 | case SpvOpCompositeInsert: |
201 | case SpvOpPhi: |
202 | // Use by insert or phi does not initiate marking |
203 | break; |
204 | case SpvOpCompositeExtract: { |
205 | // Capture extract indices |
206 | std::vector<uint32_t> extIndices; |
207 | uint32_t icnt = 0; |
208 | user->ForEachInOperand([&icnt, &extIndices](const uint32_t* idp) { |
209 | if (icnt > 0) extIndices.push_back(*idp); |
210 | ++icnt; |
211 | }); |
212 | // Mark all inserts in chain that intersect with extract |
213 | std::unordered_set<uint32_t> visited_phis; |
214 | MarkInsertChain(&*ii, &extIndices, 0, &visited_phis); |
215 | } break; |
216 | default: { |
217 | // Mark inserts in chain for all components |
218 | MarkInsertChain(&*ii, nullptr, 0, nullptr); |
219 | } break; |
220 | } |
221 | }); |
222 | } |
223 | } |
224 | // Find and disconnect dead inserts |
225 | std::vector<Instruction*> dead_instructions; |
226 | for (auto bi = func->begin(); bi != func->end(); ++bi) { |
227 | for (auto ii = bi->begin(); ii != bi->end(); ++ii) { |
228 | if (ii->opcode() != SpvOpCompositeInsert) continue; |
229 | const uint32_t id = ii->result_id(); |
230 | if (liveInserts_.find(id) != liveInserts_.end()) continue; |
231 | const uint32_t replId = |
232 | ii->GetSingleWordInOperand(kInsertCompositeIdInIdx); |
233 | (void)context()->ReplaceAllUsesWith(id, replId); |
234 | dead_instructions.push_back(&*ii); |
235 | modified = true; |
236 | } |
237 | } |
238 | // DCE dead inserts |
239 | while (!dead_instructions.empty()) { |
240 | Instruction* inst = dead_instructions.back(); |
241 | dead_instructions.pop_back(); |
242 | DCEInst(inst, [&dead_instructions](Instruction* other_inst) { |
243 | auto i = std::find(dead_instructions.begin(), dead_instructions.end(), |
244 | other_inst); |
245 | if (i != dead_instructions.end()) { |
246 | dead_instructions.erase(i); |
247 | } |
248 | }); |
249 | } |
250 | return modified; |
251 | } |
252 | |
253 | Pass::Status DeadInsertElimPass::Process() { |
254 | // Process all entry point functions. |
255 | ProcessFunction pfn = [this](Function* fp) { |
256 | return EliminateDeadInserts(fp); |
257 | }; |
258 | bool modified = context()->ProcessEntryPointCallTree(pfn); |
259 | return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; |
260 | } |
261 | |
262 | } // namespace opt |
263 | } // namespace spvtools |
264 | |