1// Copyright (c) 2019 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "source/opt/desc_sroa.h"
16
17#include "source/util/string_utils.h"
18
19namespace spvtools {
20namespace opt {
21
22Pass::Status DescriptorScalarReplacement::Process() {
23 bool modified = false;
24
25 std::vector<Instruction*> vars_to_kill;
26
27 for (Instruction& var : context()->types_values()) {
28 if (IsCandidate(&var)) {
29 modified = true;
30 if (!ReplaceCandidate(&var)) {
31 return Status::Failure;
32 }
33 vars_to_kill.push_back(&var);
34 }
35 }
36
37 for (Instruction* var : vars_to_kill) {
38 context()->KillInst(var);
39 }
40
41 return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
42}
43
44bool DescriptorScalarReplacement::IsCandidate(Instruction* var) {
45 if (var->opcode() != SpvOpVariable) {
46 return false;
47 }
48
49 uint32_t ptr_type_id = var->type_id();
50 Instruction* ptr_type_inst =
51 context()->get_def_use_mgr()->GetDef(ptr_type_id);
52 if (ptr_type_inst->opcode() != SpvOpTypePointer) {
53 return false;
54 }
55
56 uint32_t var_type_id = ptr_type_inst->GetSingleWordInOperand(1);
57 Instruction* var_type_inst =
58 context()->get_def_use_mgr()->GetDef(var_type_id);
59 if (var_type_inst->opcode() != SpvOpTypeArray) {
60 return false;
61 }
62
63 bool has_desc_set_decoration = false;
64 context()->get_decoration_mgr()->ForEachDecoration(
65 var->result_id(), SpvDecorationDescriptorSet,
66 [&has_desc_set_decoration](const Instruction&) {
67 has_desc_set_decoration = true;
68 });
69 if (!has_desc_set_decoration) {
70 return false;
71 }
72
73 bool has_binding_decoration = false;
74 context()->get_decoration_mgr()->ForEachDecoration(
75 var->result_id(), SpvDecorationBinding,
76 [&has_binding_decoration](const Instruction&) {
77 has_binding_decoration = true;
78 });
79 if (!has_binding_decoration) {
80 return false;
81 }
82
83 return true;
84}
85
86bool DescriptorScalarReplacement::ReplaceCandidate(Instruction* var) {
87 std::vector<Instruction*> work_list;
88 bool failed = !get_def_use_mgr()->WhileEachUser(
89 var->result_id(), [this, &work_list](Instruction* use) {
90 if (use->opcode() == SpvOpName) {
91 return true;
92 }
93
94 if (use->IsDecoration()) {
95 return true;
96 }
97
98 switch (use->opcode()) {
99 case SpvOpAccessChain:
100 case SpvOpInBoundsAccessChain:
101 work_list.push_back(use);
102 return true;
103 default:
104 context()->EmitErrorMessage(
105 "Variable cannot be replaced: invalid instruction", use);
106 return false;
107 }
108 return true;
109 });
110
111 if (failed) {
112 return false;
113 }
114
115 for (Instruction* use : work_list) {
116 if (!ReplaceAccessChain(var, use)) {
117 return false;
118 }
119 }
120 return true;
121}
122
123bool DescriptorScalarReplacement::ReplaceAccessChain(Instruction* var,
124 Instruction* use) {
125 if (use->NumInOperands() <= 1) {
126 context()->EmitErrorMessage(
127 "Variable cannot be replaced: invalid instruction", use);
128 return false;
129 }
130
131 uint32_t idx_id = use->GetSingleWordInOperand(1);
132 const analysis::Constant* idx_const =
133 context()->get_constant_mgr()->FindDeclaredConstant(idx_id);
134 if (idx_const == nullptr) {
135 context()->EmitErrorMessage("Variable cannot be replaced: invalid index",
136 use);
137 return false;
138 }
139
140 uint32_t idx = idx_const->GetU32();
141 uint32_t replacement_var = GetReplacementVariable(var, idx);
142
143 if (use->NumInOperands() == 2) {
144 // We are not indexing into the replacement variable. We can replaces the
145 // access chain with the replacement varibale itself.
146 context()->ReplaceAllUsesWith(use->result_id(), replacement_var);
147 context()->KillInst(use);
148 return true;
149 }
150
151 // We need to build a new access chain with the replacement variable as the
152 // base address.
153 Instruction::OperandList new_operands;
154
155 // Same result id and result type.
156 new_operands.emplace_back(use->GetOperand(0));
157 new_operands.emplace_back(use->GetOperand(1));
158
159 // Use the replacement variable as the base address.
160 new_operands.push_back({SPV_OPERAND_TYPE_ID, {replacement_var}});
161
162 // Drop the first index because it is consumed by the replacment, and copy the
163 // rest.
164 for (uint32_t i = 4; i < use->NumOperands(); i++) {
165 new_operands.emplace_back(use->GetOperand(i));
166 }
167
168 use->ReplaceOperands(new_operands);
169 context()->UpdateDefUse(use);
170 return true;
171}
172
173uint32_t DescriptorScalarReplacement::GetReplacementVariable(Instruction* var,
174 uint32_t idx) {
175 auto replacement_vars = replacement_variables_.find(var);
176 if (replacement_vars == replacement_variables_.end()) {
177 uint32_t ptr_type_id = var->type_id();
178 Instruction* ptr_type_inst = get_def_use_mgr()->GetDef(ptr_type_id);
179 assert(ptr_type_inst->opcode() == SpvOpTypePointer &&
180 "Variable should be a pointer to an array.");
181 uint32_t arr_type_id = ptr_type_inst->GetSingleWordInOperand(1);
182 Instruction* arr_type_inst = get_def_use_mgr()->GetDef(arr_type_id);
183 assert(arr_type_inst->opcode() == SpvOpTypeArray &&
184 "Variable should be a pointer to an array.");
185
186 uint32_t array_len_id = arr_type_inst->GetSingleWordInOperand(1);
187 const analysis::Constant* array_len_const =
188 context()->get_constant_mgr()->FindDeclaredConstant(array_len_id);
189 assert(array_len_const != nullptr && "Array length must be a constant.");
190 uint32_t array_len = array_len_const->GetU32();
191
192 replacement_vars = replacement_variables_
193 .insert({var, std::vector<uint32_t>(array_len, 0)})
194 .first;
195 }
196
197 if (replacement_vars->second[idx] == 0) {
198 replacement_vars->second[idx] = CreateReplacementVariable(var, idx);
199 }
200
201 return replacement_vars->second[idx];
202}
203
204uint32_t DescriptorScalarReplacement::CreateReplacementVariable(
205 Instruction* var, uint32_t idx) {
206 // The storage class for the new variable is the same as the original.
207 SpvStorageClass storage_class =
208 static_cast<SpvStorageClass>(var->GetSingleWordInOperand(0));
209
210 // The type for the new variable will be a pointer to type of the elements of
211 // the array.
212 uint32_t ptr_type_id = var->type_id();
213 Instruction* ptr_type_inst = get_def_use_mgr()->GetDef(ptr_type_id);
214 assert(ptr_type_inst->opcode() == SpvOpTypePointer &&
215 "Variable should be a pointer to an array.");
216 uint32_t arr_type_id = ptr_type_inst->GetSingleWordInOperand(1);
217 Instruction* arr_type_inst = get_def_use_mgr()->GetDef(arr_type_id);
218 assert(arr_type_inst->opcode() == SpvOpTypeArray &&
219 "Variable should be a pointer to an array.");
220 uint32_t element_type_id = arr_type_inst->GetSingleWordInOperand(0);
221
222 uint32_t ptr_element_type_id = context()->get_type_mgr()->FindPointerToType(
223 element_type_id, storage_class);
224
225 // Create the variable.
226 uint32_t id = TakeNextId();
227 std::unique_ptr<Instruction> variable(
228 new Instruction(context(), SpvOpVariable, ptr_element_type_id, id,
229 std::initializer_list<Operand>{
230 {SPV_OPERAND_TYPE_STORAGE_CLASS,
231 {static_cast<uint32_t>(storage_class)}}}));
232 context()->AddGlobalValue(std::move(variable));
233
234 // Copy all of the decorations to the new variable. The only difference is
235 // the Binding decoration needs to be adjusted.
236 for (auto old_decoration :
237 get_decoration_mgr()->GetDecorationsFor(var->result_id(), true)) {
238 assert(old_decoration->opcode() == SpvOpDecorate);
239 std::unique_ptr<Instruction> new_decoration(
240 old_decoration->Clone(context()));
241 new_decoration->SetInOperand(0, {id});
242
243 uint32_t decoration = new_decoration->GetSingleWordInOperand(1u);
244 if (decoration == SpvDecorationBinding) {
245 uint32_t new_binding = new_decoration->GetSingleWordInOperand(2) + idx;
246 new_decoration->SetInOperand(2, {new_binding});
247 }
248 context()->AddAnnotationInst(std::move(new_decoration));
249 }
250
251 // Create a new OpName for the replacement variable.
252 for (auto p : context()->GetNames(var->result_id())) {
253 Instruction* name_inst = p.second;
254 std::string name_str = utils::MakeString(name_inst->GetOperand(1).words);
255 name_str += "[";
256 name_str += utils::ToString(idx);
257 name_str += "]";
258
259 std::unique_ptr<Instruction> new_name(new Instruction(
260 context(), SpvOpName, 0, 0,
261 std::initializer_list<Operand>{
262 {SPV_OPERAND_TYPE_ID, {id}},
263 {SPV_OPERAND_TYPE_LITERAL_STRING, utils::MakeVector(name_str)}}));
264 Instruction* new_name_inst = new_name.get();
265 context()->AddDebug2Inst(std::move(new_name));
266 get_def_use_mgr()->AnalyzeInstDefUse(new_name_inst);
267 }
268
269 return id;
270}
271
272} // namespace opt
273} // namespace spvtools
274