1// Copyright (c) 2017 Google Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "source/opt/private_to_local_pass.h"
16
17#include <memory>
18#include <utility>
19#include <vector>
20
21#include "source/opt/ir_context.h"
22#include "source/spirv_constant.h"
23
24namespace spvtools {
25namespace opt {
namespace {

// In-operand index of the storage-class operand of an OpVariable.
constexpr uint32_t kVariableStorageClassInIdx = 0;

// In-operand index of the pointee type id of an OpTypePointer
// (in-operand 0 is the storage class).
constexpr uint32_t kSpvTypePointerTypeIdInIdx = 1;

}  // namespace
32
33Pass::Status PrivateToLocalPass::Process() {
34 bool modified = false;
35
36 // Private variables require the shader capability. If this is not a shader,
37 // there is no work to do.
38 if (context()->get_feature_mgr()->HasCapability(SpvCapabilityAddresses))
39 return Status::SuccessWithoutChange;
40
41 std::vector<std::pair<Instruction*, Function*>> variables_to_move;
42 std::unordered_set<uint32_t> localized_variables;
43 for (auto& inst : context()->types_values()) {
44 if (inst.opcode() != SpvOpVariable) {
45 continue;
46 }
47
48 if (inst.GetSingleWordInOperand(kVariableStorageClassInIdx) !=
49 SpvStorageClassPrivate) {
50 continue;
51 }
52
53 Function* target_function = FindLocalFunction(inst);
54 if (target_function != nullptr) {
55 variables_to_move.push_back({&inst, target_function});
56 }
57 }
58
59 modified = !variables_to_move.empty();
60 for (auto p : variables_to_move) {
61 if (!MoveVariable(p.first, p.second)) {
62 return Status::Failure;
63 }
64 localized_variables.insert(p.first->result_id());
65 }
66
67 if (get_module()->version() >= SPV_SPIRV_VERSION_WORD(1, 4)) {
68 // In SPIR-V 1.4 and later entry points must list private storage class
69 // variables that are statically used by the entry point. Go through the
70 // entry points and remove any references to variables that were localized.
71 for (auto& entry : get_module()->entry_points()) {
72 std::vector<Operand> new_operands;
73 for (uint32_t i = 0; i < entry.NumInOperands(); ++i) {
74 // Execution model, function id and name are always kept.
75 if (i < 3 ||
76 !localized_variables.count(entry.GetSingleWordInOperand(i))) {
77 new_operands.push_back(entry.GetInOperand(i));
78 }
79 }
80 if (new_operands.size() != entry.NumInOperands()) {
81 entry.SetInOperands(std::move(new_operands));
82 context()->AnalyzeUses(&entry);
83 }
84 }
85 }
86
87 return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
88}
89
90Function* PrivateToLocalPass::FindLocalFunction(const Instruction& inst) const {
91 bool found_first_use = false;
92 Function* target_function = nullptr;
93 context()->get_def_use_mgr()->ForEachUser(
94 inst.result_id(),
95 [&target_function, &found_first_use, this](Instruction* use) {
96 BasicBlock* current_block = context()->get_instr_block(use);
97 if (current_block == nullptr) {
98 return;
99 }
100
101 if (!IsValidUse(use)) {
102 found_first_use = true;
103 target_function = nullptr;
104 return;
105 }
106 Function* current_function = current_block->GetParent();
107 if (!found_first_use) {
108 found_first_use = true;
109 target_function = current_function;
110 } else if (target_function != current_function) {
111 target_function = nullptr;
112 }
113 });
114 return target_function;
115} // namespace opt
116
117bool PrivateToLocalPass::MoveVariable(Instruction* variable,
118 Function* function) {
119 // The variable needs to be removed from the global section, and placed in the
120 // header of the function. First step remove from the global list.
121 variable->RemoveFromList();
122 std::unique_ptr<Instruction> var(variable); // Take ownership.
123 context()->ForgetUses(variable);
124
125 // Update the storage class of the variable.
126 variable->SetInOperand(kVariableStorageClassInIdx, {SpvStorageClassFunction});
127
128 // Update the type as well.
129 uint32_t new_type_id = GetNewType(variable->type_id());
130 if (new_type_id == 0) {
131 return false;
132 }
133 variable->SetResultType(new_type_id);
134
135 // Place the variable at the start of the first basic block.
136 context()->AnalyzeUses(variable);
137 context()->set_instr_block(variable, &*function->begin());
138 function->begin()->begin()->InsertBefore(move(var));
139
140 // Update uses where the type may have changed.
141 return UpdateUses(variable->result_id());
142}
143
144uint32_t PrivateToLocalPass::GetNewType(uint32_t old_type_id) {
145 auto type_mgr = context()->get_type_mgr();
146 Instruction* old_type_inst = get_def_use_mgr()->GetDef(old_type_id);
147 uint32_t pointee_type_id =
148 old_type_inst->GetSingleWordInOperand(kSpvTypePointerTypeIdInIdx);
149 uint32_t new_type_id =
150 type_mgr->FindPointerToType(pointee_type_id, SpvStorageClassFunction);
151 if (new_type_id != 0) {
152 context()->UpdateDefUse(context()->get_def_use_mgr()->GetDef(new_type_id));
153 }
154 return new_type_id;
155}
156
157bool PrivateToLocalPass::IsValidUse(const Instruction* inst) const {
158 // The cases in this switch have to match the cases in |UpdateUse|.
159 // If we don't know how to update it, it is not valid.
160 switch (inst->opcode()) {
161 case SpvOpLoad:
162 case SpvOpStore:
163 case SpvOpImageTexelPointer: // Treat like a load
164 return true;
165 case SpvOpAccessChain:
166 return context()->get_def_use_mgr()->WhileEachUser(
167 inst, [this](const Instruction* user) {
168 if (!IsValidUse(user)) return false;
169 return true;
170 });
171 case SpvOpName:
172 return true;
173 default:
174 return spvOpcodeIsDecoration(inst->opcode());
175 }
176}
177
178bool PrivateToLocalPass::UpdateUse(Instruction* inst) {
179 // The cases in this switch have to match the cases in |IsValidUse|. If we
180 // don't think it is valid, the optimization will not view the variable as a
181 // candidate, and therefore the use will not be updated.
182 switch (inst->opcode()) {
183 case SpvOpLoad:
184 case SpvOpStore:
185 case SpvOpImageTexelPointer: // Treat like a load
186 // The type is fine because it is the type pointed to, and that does not
187 // change.
188 break;
189 case SpvOpAccessChain: {
190 context()->ForgetUses(inst);
191 uint32_t new_type_id = GetNewType(inst->type_id());
192 if (new_type_id == 0) {
193 return false;
194 }
195 inst->SetResultType(new_type_id);
196 context()->AnalyzeUses(inst);
197
198 // Update uses where the type may have changed.
199 if (!UpdateUses(inst->result_id())) {
200 return false;
201 }
202 } break;
203 case SpvOpName:
204 case SpvOpEntryPoint: // entry points will be updated separately.
205 break;
206 default:
207 assert(spvOpcodeIsDecoration(inst->opcode()) &&
208 "Do not know how to update the type for this instruction.");
209 break;
210 }
211 return true;
212}
213
214bool PrivateToLocalPass::UpdateUses(uint32_t id) {
215 std::vector<Instruction*> uses;
216 context()->get_def_use_mgr()->ForEachUser(
217 id, [&uses](Instruction* use) { uses.push_back(use); });
218
219 for (Instruction* use : uses) {
220 if (!UpdateUse(use)) {
221 return false;
222 }
223 }
224 return true;
225}
226
227} // namespace opt
228} // namespace spvtools
229