1 | // Copyright (c) 2018 Google LLC |
2 | // |
3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | // you may not use this file except in compliance with the License. |
5 | // You may obtain a copy of the License at |
6 | // |
7 | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | // |
9 | // Unless required by applicable law or agreed to in writing, software |
10 | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | // See the License for the specific language governing permissions and |
13 | // limitations under the License. |
14 | |
15 | #include "source/opt/combine_access_chains.h" |
16 | |
17 | #include <utility> |
18 | |
19 | #include "source/opt/constants.h" |
20 | #include "source/opt/ir_builder.h" |
21 | #include "source/opt/ir_context.h" |
22 | |
23 | namespace spvtools { |
24 | namespace opt { |
25 | |
26 | Pass::Status CombineAccessChains::Process() { |
27 | bool modified = false; |
28 | |
29 | for (auto& function : *get_module()) { |
30 | modified |= ProcessFunction(function); |
31 | } |
32 | |
33 | return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange); |
34 | } |
35 | |
36 | bool CombineAccessChains::ProcessFunction(Function& function) { |
37 | bool modified = false; |
38 | |
39 | cfg()->ForEachBlockInReversePostOrder( |
40 | function.entry().get(), [&modified, this](BasicBlock* block) { |
41 | block->ForEachInst([&modified, this](Instruction* inst) { |
42 | switch (inst->opcode()) { |
43 | case SpvOpAccessChain: |
44 | case SpvOpInBoundsAccessChain: |
45 | case SpvOpPtrAccessChain: |
46 | case SpvOpInBoundsPtrAccessChain: |
47 | modified |= CombineAccessChain(inst); |
48 | break; |
49 | default: |
50 | break; |
51 | } |
52 | }); |
53 | }); |
54 | |
55 | return modified; |
56 | } |
57 | |
58 | uint32_t CombineAccessChains::GetConstantValue( |
59 | const analysis::Constant* constant_inst) { |
60 | if (constant_inst->type()->AsInteger()->width() <= 32) { |
61 | if (constant_inst->type()->AsInteger()->IsSigned()) { |
62 | return static_cast<uint32_t>(constant_inst->GetS32()); |
63 | } else { |
64 | return constant_inst->GetU32(); |
65 | } |
66 | } else { |
67 | assert(false); |
68 | return 0u; |
69 | } |
70 | } |
71 | |
72 | uint32_t CombineAccessChains::GetArrayStride(const Instruction* inst) { |
73 | uint32_t array_stride = 0; |
74 | context()->get_decoration_mgr()->WhileEachDecoration( |
75 | inst->type_id(), SpvDecorationArrayStride, |
76 | [&array_stride](const Instruction& decoration) { |
77 | assert(decoration.opcode() != SpvOpDecorateId); |
78 | if (decoration.opcode() == SpvOpDecorate) { |
79 | array_stride = decoration.GetSingleWordInOperand(1); |
80 | } else { |
81 | array_stride = decoration.GetSingleWordInOperand(2); |
82 | } |
83 | return false; |
84 | }); |
85 | return array_stride; |
86 | } |
87 | |
88 | const analysis::Type* CombineAccessChains::GetIndexedType(Instruction* inst) { |
89 | analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); |
90 | analysis::TypeManager* type_mgr = context()->get_type_mgr(); |
91 | |
92 | Instruction* base_ptr = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0)); |
93 | const analysis::Type* type = type_mgr->GetType(base_ptr->type_id()); |
94 | assert(type->AsPointer()); |
95 | type = type->AsPointer()->pointee_type(); |
96 | std::vector<uint32_t> element_indices; |
97 | uint32_t starting_index = 1; |
98 | if (IsPtrAccessChain(inst->opcode())) { |
99 | // Skip the first index of OpPtrAccessChain as it does not affect type |
100 | // resolution. |
101 | starting_index = 2; |
102 | } |
103 | for (uint32_t i = starting_index; i < inst->NumInOperands(); ++i) { |
104 | Instruction* index_inst = |
105 | def_use_mgr->GetDef(inst->GetSingleWordInOperand(i)); |
106 | const analysis::Constant* index_constant = |
107 | context()->get_constant_mgr()->GetConstantFromInst(index_inst); |
108 | if (index_constant) { |
109 | uint32_t index_value = GetConstantValue(index_constant); |
110 | element_indices.push_back(index_value); |
111 | } else { |
112 | // This index must not matter to resolve the type in valid SPIR-V. |
113 | element_indices.push_back(0); |
114 | } |
115 | } |
116 | type = type_mgr->GetMemberType(type, element_indices); |
117 | return type; |
118 | } |
119 | |
120 | bool CombineAccessChains::CombineIndices(Instruction* ptr_input, |
121 | Instruction* inst, |
122 | std::vector<Operand>* new_operands) { |
123 | analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); |
124 | analysis::ConstantManager* constant_mgr = context()->get_constant_mgr(); |
125 | |
126 | Instruction* last_index_inst = def_use_mgr->GetDef( |
127 | ptr_input->GetSingleWordInOperand(ptr_input->NumInOperands() - 1)); |
128 | const analysis::Constant* last_index_constant = |
129 | constant_mgr->GetConstantFromInst(last_index_inst); |
130 | |
131 | Instruction* element_inst = |
132 | def_use_mgr->GetDef(inst->GetSingleWordInOperand(1)); |
133 | const analysis::Constant* element_constant = |
134 | constant_mgr->GetConstantFromInst(element_inst); |
135 | |
136 | // Combine the last index of the AccessChain (|ptr_inst|) with the element |
137 | // operand of the PtrAccessChain (|inst|). |
138 | const bool combining_element_operands = |
139 | IsPtrAccessChain(inst->opcode()) && |
140 | IsPtrAccessChain(ptr_input->opcode()) && ptr_input->NumInOperands() == 2; |
141 | uint32_t new_value_id = 0; |
142 | const analysis::Type* type = GetIndexedType(ptr_input); |
143 | if (last_index_constant && element_constant) { |
144 | // Combine the constants. |
145 | uint32_t new_value = GetConstantValue(last_index_constant) + |
146 | GetConstantValue(element_constant); |
147 | const analysis::Constant* new_value_constant = |
148 | constant_mgr->GetConstant(last_index_constant->type(), {new_value}); |
149 | Instruction* new_value_inst = |
150 | constant_mgr->GetDefiningInstruction(new_value_constant); |
151 | new_value_id = new_value_inst->result_id(); |
152 | } else if (!type->AsStruct() || combining_element_operands) { |
153 | // Generate an addition of the two indices. |
154 | InstructionBuilder builder( |
155 | context(), inst, |
156 | IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); |
157 | Instruction* addition = builder.AddIAdd(last_index_inst->type_id(), |
158 | last_index_inst->result_id(), |
159 | element_inst->result_id()); |
160 | new_value_id = addition->result_id(); |
161 | } else { |
162 | // Indexing into structs must be constant, so bail out here. |
163 | return false; |
164 | } |
165 | new_operands->push_back({SPV_OPERAND_TYPE_ID, {new_value_id}}); |
166 | return true; |
167 | } |
168 | |
169 | bool CombineAccessChains::CreateNewInputOperands( |
170 | Instruction* ptr_input, Instruction* inst, |
171 | std::vector<Operand>* new_operands) { |
172 | // Start by copying all the input operands of the feeder access chain. |
173 | for (uint32_t i = 0; i != ptr_input->NumInOperands() - 1; ++i) { |
174 | new_operands->push_back(ptr_input->GetInOperand(i)); |
175 | } |
176 | |
177 | // Deal with the last index of the feeder access chain. |
178 | if (IsPtrAccessChain(inst->opcode())) { |
179 | // The last index of the feeder should be combined with the element operand |
180 | // of |inst|. |
181 | if (!CombineIndices(ptr_input, inst, new_operands)) return false; |
182 | } else { |
183 | // The indices aren't being combined so now add the last index operand of |
184 | // |ptr_input|. |
185 | new_operands->push_back( |
186 | ptr_input->GetInOperand(ptr_input->NumInOperands() - 1)); |
187 | } |
188 | |
189 | // Copy the remaining index operands. |
190 | uint32_t starting_index = IsPtrAccessChain(inst->opcode()) ? 2 : 1; |
191 | for (uint32_t i = starting_index; i < inst->NumInOperands(); ++i) { |
192 | new_operands->push_back(inst->GetInOperand(i)); |
193 | } |
194 | |
195 | return true; |
196 | } |
197 | |
198 | bool CombineAccessChains::CombineAccessChain(Instruction* inst) { |
199 | assert((inst->opcode() == SpvOpPtrAccessChain || |
200 | inst->opcode() == SpvOpAccessChain || |
201 | inst->opcode() == SpvOpInBoundsAccessChain || |
202 | inst->opcode() == SpvOpInBoundsPtrAccessChain) && |
203 | "Wrong opcode. Expected an access chain." ); |
204 | |
205 | Instruction* ptr_input = |
206 | context()->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0)); |
207 | if (ptr_input->opcode() != SpvOpAccessChain && |
208 | ptr_input->opcode() != SpvOpInBoundsAccessChain && |
209 | ptr_input->opcode() != SpvOpPtrAccessChain && |
210 | ptr_input->opcode() != SpvOpInBoundsPtrAccessChain) { |
211 | return false; |
212 | } |
213 | |
214 | if (Has64BitIndices(inst) || Has64BitIndices(ptr_input)) return false; |
215 | |
216 | // Handles the following cases: |
217 | // 1. |ptr_input| is an index-less access chain. Replace the pointer |
218 | // in |inst| with |ptr_input|'s pointer. |
219 | // 2. |inst| is a index-less access chain. Change |inst| to an |
220 | // OpCopyObject. |
221 | // 3. |inst| is not a pointer access chain. |
222 | // |inst|'s indices are appended to |ptr_input|'s indices. |
223 | // 4. |ptr_input| is not pointer access chain. |
224 | // |inst| is a pointer access chain. |
225 | // |inst|'s element operand is combined with the last index in |
226 | // |ptr_input| to form a new operand. |
227 | // 5. |ptr_input| is a pointer access chain. |
228 | // Like the above scenario, |inst|'s element operand is combined |
229 | // with |ptr_input|'s last index. This results is either a |
230 | // combined element operand or combined regular index. |
231 | |
232 | // TODO(alan-baker): Support this properly. Requires analyzing the |
233 | // size/alignment of the type and converting the stride into an element |
234 | // index. |
235 | uint32_t array_stride = GetArrayStride(ptr_input); |
236 | if (array_stride != 0) return false; |
237 | |
238 | if (ptr_input->NumInOperands() == 1) { |
239 | // The input is effectively a no-op. |
240 | inst->SetInOperand(0, {ptr_input->GetSingleWordInOperand(0)}); |
241 | context()->AnalyzeUses(inst); |
242 | } else if (inst->NumInOperands() == 1) { |
243 | // |inst| is a no-op, change it to a copy. Instruction simplification will |
244 | // clean it up. |
245 | inst->SetOpcode(SpvOpCopyObject); |
246 | } else { |
247 | std::vector<Operand> new_operands; |
248 | if (!CreateNewInputOperands(ptr_input, inst, &new_operands)) return false; |
249 | |
250 | // Update the instruction. |
251 | inst->SetOpcode(UpdateOpcode(inst->opcode(), ptr_input->opcode())); |
252 | inst->SetInOperands(std::move(new_operands)); |
253 | context()->AnalyzeUses(inst); |
254 | } |
255 | return true; |
256 | } |
257 | |
258 | SpvOp CombineAccessChains::UpdateOpcode(SpvOp base_opcode, SpvOp input_opcode) { |
259 | auto IsInBounds = [](SpvOp opcode) { |
260 | return opcode == SpvOpInBoundsPtrAccessChain || |
261 | opcode == SpvOpInBoundsAccessChain; |
262 | }; |
263 | |
264 | if (input_opcode == SpvOpInBoundsPtrAccessChain) { |
265 | if (!IsInBounds(base_opcode)) return SpvOpPtrAccessChain; |
266 | } else if (input_opcode == SpvOpInBoundsAccessChain) { |
267 | if (!IsInBounds(base_opcode)) return SpvOpAccessChain; |
268 | } |
269 | |
270 | return input_opcode; |
271 | } |
272 | |
273 | bool CombineAccessChains::IsPtrAccessChain(SpvOp opcode) { |
274 | return opcode == SpvOpPtrAccessChain || opcode == SpvOpInBoundsPtrAccessChain; |
275 | } |
276 | |
277 | bool CombineAccessChains::Has64BitIndices(Instruction* inst) { |
278 | for (uint32_t i = 1; i < inst->NumInOperands(); ++i) { |
279 | Instruction* index_inst = |
280 | context()->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(i)); |
281 | const analysis::Type* index_type = |
282 | context()->get_type_mgr()->GetType(index_inst->type_id()); |
283 | if (!index_type->AsInteger() || index_type->AsInteger()->width() != 32) |
284 | return true; |
285 | } |
286 | return false; |
287 | } |
288 | |
289 | } // namespace opt |
290 | } // namespace spvtools |
291 | |