1// Copyright (c) 2018 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "source/opt/combine_access_chains.h"
16
17#include <utility>
18
19#include "source/opt/constants.h"
20#include "source/opt/ir_builder.h"
21#include "source/opt/ir_context.h"
22
23namespace spvtools {
24namespace opt {
25
26Pass::Status CombineAccessChains::Process() {
27 bool modified = false;
28
29 for (auto& function : *get_module()) {
30 modified |= ProcessFunction(function);
31 }
32
33 return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
34}
35
36bool CombineAccessChains::ProcessFunction(Function& function) {
37 bool modified = false;
38
39 cfg()->ForEachBlockInReversePostOrder(
40 function.entry().get(), [&modified, this](BasicBlock* block) {
41 block->ForEachInst([&modified, this](Instruction* inst) {
42 switch (inst->opcode()) {
43 case SpvOpAccessChain:
44 case SpvOpInBoundsAccessChain:
45 case SpvOpPtrAccessChain:
46 case SpvOpInBoundsPtrAccessChain:
47 modified |= CombineAccessChain(inst);
48 break;
49 default:
50 break;
51 }
52 });
53 });
54
55 return modified;
56}
57
58uint32_t CombineAccessChains::GetConstantValue(
59 const analysis::Constant* constant_inst) {
60 if (constant_inst->type()->AsInteger()->width() <= 32) {
61 if (constant_inst->type()->AsInteger()->IsSigned()) {
62 return static_cast<uint32_t>(constant_inst->GetS32());
63 } else {
64 return constant_inst->GetU32();
65 }
66 } else {
67 assert(false);
68 return 0u;
69 }
70}
71
72uint32_t CombineAccessChains::GetArrayStride(const Instruction* inst) {
73 uint32_t array_stride = 0;
74 context()->get_decoration_mgr()->WhileEachDecoration(
75 inst->type_id(), SpvDecorationArrayStride,
76 [&array_stride](const Instruction& decoration) {
77 assert(decoration.opcode() != SpvOpDecorateId);
78 if (decoration.opcode() == SpvOpDecorate) {
79 array_stride = decoration.GetSingleWordInOperand(1);
80 } else {
81 array_stride = decoration.GetSingleWordInOperand(2);
82 }
83 return false;
84 });
85 return array_stride;
86}
87
88const analysis::Type* CombineAccessChains::GetIndexedType(Instruction* inst) {
89 analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
90 analysis::TypeManager* type_mgr = context()->get_type_mgr();
91
92 Instruction* base_ptr = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
93 const analysis::Type* type = type_mgr->GetType(base_ptr->type_id());
94 assert(type->AsPointer());
95 type = type->AsPointer()->pointee_type();
96 std::vector<uint32_t> element_indices;
97 uint32_t starting_index = 1;
98 if (IsPtrAccessChain(inst->opcode())) {
99 // Skip the first index of OpPtrAccessChain as it does not affect type
100 // resolution.
101 starting_index = 2;
102 }
103 for (uint32_t i = starting_index; i < inst->NumInOperands(); ++i) {
104 Instruction* index_inst =
105 def_use_mgr->GetDef(inst->GetSingleWordInOperand(i));
106 const analysis::Constant* index_constant =
107 context()->get_constant_mgr()->GetConstantFromInst(index_inst);
108 if (index_constant) {
109 uint32_t index_value = GetConstantValue(index_constant);
110 element_indices.push_back(index_value);
111 } else {
112 // This index must not matter to resolve the type in valid SPIR-V.
113 element_indices.push_back(0);
114 }
115 }
116 type = type_mgr->GetMemberType(type, element_indices);
117 return type;
118}
119
120bool CombineAccessChains::CombineIndices(Instruction* ptr_input,
121 Instruction* inst,
122 std::vector<Operand>* new_operands) {
123 analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
124 analysis::ConstantManager* constant_mgr = context()->get_constant_mgr();
125
126 Instruction* last_index_inst = def_use_mgr->GetDef(
127 ptr_input->GetSingleWordInOperand(ptr_input->NumInOperands() - 1));
128 const analysis::Constant* last_index_constant =
129 constant_mgr->GetConstantFromInst(last_index_inst);
130
131 Instruction* element_inst =
132 def_use_mgr->GetDef(inst->GetSingleWordInOperand(1));
133 const analysis::Constant* element_constant =
134 constant_mgr->GetConstantFromInst(element_inst);
135
136 // Combine the last index of the AccessChain (|ptr_inst|) with the element
137 // operand of the PtrAccessChain (|inst|).
138 const bool combining_element_operands =
139 IsPtrAccessChain(inst->opcode()) &&
140 IsPtrAccessChain(ptr_input->opcode()) && ptr_input->NumInOperands() == 2;
141 uint32_t new_value_id = 0;
142 const analysis::Type* type = GetIndexedType(ptr_input);
143 if (last_index_constant && element_constant) {
144 // Combine the constants.
145 uint32_t new_value = GetConstantValue(last_index_constant) +
146 GetConstantValue(element_constant);
147 const analysis::Constant* new_value_constant =
148 constant_mgr->GetConstant(last_index_constant->type(), {new_value});
149 Instruction* new_value_inst =
150 constant_mgr->GetDefiningInstruction(new_value_constant);
151 new_value_id = new_value_inst->result_id();
152 } else if (!type->AsStruct() || combining_element_operands) {
153 // Generate an addition of the two indices.
154 InstructionBuilder builder(
155 context(), inst,
156 IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
157 Instruction* addition = builder.AddIAdd(last_index_inst->type_id(),
158 last_index_inst->result_id(),
159 element_inst->result_id());
160 new_value_id = addition->result_id();
161 } else {
162 // Indexing into structs must be constant, so bail out here.
163 return false;
164 }
165 new_operands->push_back({SPV_OPERAND_TYPE_ID, {new_value_id}});
166 return true;
167}
168
169bool CombineAccessChains::CreateNewInputOperands(
170 Instruction* ptr_input, Instruction* inst,
171 std::vector<Operand>* new_operands) {
172 // Start by copying all the input operands of the feeder access chain.
173 for (uint32_t i = 0; i != ptr_input->NumInOperands() - 1; ++i) {
174 new_operands->push_back(ptr_input->GetInOperand(i));
175 }
176
177 // Deal with the last index of the feeder access chain.
178 if (IsPtrAccessChain(inst->opcode())) {
179 // The last index of the feeder should be combined with the element operand
180 // of |inst|.
181 if (!CombineIndices(ptr_input, inst, new_operands)) return false;
182 } else {
183 // The indices aren't being combined so now add the last index operand of
184 // |ptr_input|.
185 new_operands->push_back(
186 ptr_input->GetInOperand(ptr_input->NumInOperands() - 1));
187 }
188
189 // Copy the remaining index operands.
190 uint32_t starting_index = IsPtrAccessChain(inst->opcode()) ? 2 : 1;
191 for (uint32_t i = starting_index; i < inst->NumInOperands(); ++i) {
192 new_operands->push_back(inst->GetInOperand(i));
193 }
194
195 return true;
196}
197
198bool CombineAccessChains::CombineAccessChain(Instruction* inst) {
199 assert((inst->opcode() == SpvOpPtrAccessChain ||
200 inst->opcode() == SpvOpAccessChain ||
201 inst->opcode() == SpvOpInBoundsAccessChain ||
202 inst->opcode() == SpvOpInBoundsPtrAccessChain) &&
203 "Wrong opcode. Expected an access chain.");
204
205 Instruction* ptr_input =
206 context()->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0));
207 if (ptr_input->opcode() != SpvOpAccessChain &&
208 ptr_input->opcode() != SpvOpInBoundsAccessChain &&
209 ptr_input->opcode() != SpvOpPtrAccessChain &&
210 ptr_input->opcode() != SpvOpInBoundsPtrAccessChain) {
211 return false;
212 }
213
214 if (Has64BitIndices(inst) || Has64BitIndices(ptr_input)) return false;
215
216 // Handles the following cases:
217 // 1. |ptr_input| is an index-less access chain. Replace the pointer
218 // in |inst| with |ptr_input|'s pointer.
219 // 2. |inst| is a index-less access chain. Change |inst| to an
220 // OpCopyObject.
221 // 3. |inst| is not a pointer access chain.
222 // |inst|'s indices are appended to |ptr_input|'s indices.
223 // 4. |ptr_input| is not pointer access chain.
224 // |inst| is a pointer access chain.
225 // |inst|'s element operand is combined with the last index in
226 // |ptr_input| to form a new operand.
227 // 5. |ptr_input| is a pointer access chain.
228 // Like the above scenario, |inst|'s element operand is combined
229 // with |ptr_input|'s last index. This results is either a
230 // combined element operand or combined regular index.
231
232 // TODO(alan-baker): Support this properly. Requires analyzing the
233 // size/alignment of the type and converting the stride into an element
234 // index.
235 uint32_t array_stride = GetArrayStride(ptr_input);
236 if (array_stride != 0) return false;
237
238 if (ptr_input->NumInOperands() == 1) {
239 // The input is effectively a no-op.
240 inst->SetInOperand(0, {ptr_input->GetSingleWordInOperand(0)});
241 context()->AnalyzeUses(inst);
242 } else if (inst->NumInOperands() == 1) {
243 // |inst| is a no-op, change it to a copy. Instruction simplification will
244 // clean it up.
245 inst->SetOpcode(SpvOpCopyObject);
246 } else {
247 std::vector<Operand> new_operands;
248 if (!CreateNewInputOperands(ptr_input, inst, &new_operands)) return false;
249
250 // Update the instruction.
251 inst->SetOpcode(UpdateOpcode(inst->opcode(), ptr_input->opcode()));
252 inst->SetInOperands(std::move(new_operands));
253 context()->AnalyzeUses(inst);
254 }
255 return true;
256}
257
258SpvOp CombineAccessChains::UpdateOpcode(SpvOp base_opcode, SpvOp input_opcode) {
259 auto IsInBounds = [](SpvOp opcode) {
260 return opcode == SpvOpInBoundsPtrAccessChain ||
261 opcode == SpvOpInBoundsAccessChain;
262 };
263
264 if (input_opcode == SpvOpInBoundsPtrAccessChain) {
265 if (!IsInBounds(base_opcode)) return SpvOpPtrAccessChain;
266 } else if (input_opcode == SpvOpInBoundsAccessChain) {
267 if (!IsInBounds(base_opcode)) return SpvOpAccessChain;
268 }
269
270 return input_opcode;
271}
272
273bool CombineAccessChains::IsPtrAccessChain(SpvOp opcode) {
274 return opcode == SpvOpPtrAccessChain || opcode == SpvOpInBoundsPtrAccessChain;
275}
276
277bool CombineAccessChains::Has64BitIndices(Instruction* inst) {
278 for (uint32_t i = 1; i < inst->NumInOperands(); ++i) {
279 Instruction* index_inst =
280 context()->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(i));
281 const analysis::Type* index_type =
282 context()->get_type_mgr()->GetType(index_inst->type_id());
283 if (!index_type->AsInteger() || index_type->AsInteger()->width() != 32)
284 return true;
285 }
286 return false;
287}
288
289} // namespace opt
290} // namespace spvtools
291