1 | // Copyright (c) 2016 Google Inc. |
2 | // |
3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | // you may not use this file except in compliance with the License. |
5 | // You may obtain a copy of the License at |
6 | // |
7 | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | // |
9 | // Unless required by applicable law or agreed to in writing, software |
10 | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | // See the License for the specific language governing permissions and |
13 | // limitations under the License. |
14 | |
15 | #include "source/opt/set_spec_constant_default_value_pass.h" |
16 | |
17 | #include <algorithm> |
18 | #include <cctype> |
19 | #include <cstring> |
20 | #include <tuple> |
21 | #include <vector> |
22 | |
23 | #include "source/opt/def_use_manager.h" |
24 | #include "source/opt/ir_context.h" |
25 | #include "source/opt/type_manager.h" |
26 | #include "source/opt/types.h" |
27 | #include "source/util/make_unique.h" |
28 | #include "source/util/parse_number.h" |
29 | #include "spirv-tools/libspirv.h" |
30 | |
31 | namespace spvtools { |
32 | namespace opt { |
33 | |
34 | namespace { |
35 | using utils::EncodeNumberStatus; |
36 | using utils::NumberType; |
37 | using utils::ParseAndEncodeNumber; |
38 | using utils::ParseNumber; |
39 | |
40 | // Given a numeric value in a null-terminated c string and the expected type of |
41 | // the value, parses the string and encodes it in a vector of words. If the |
42 | // value is a scalar integer or floating point value, encodes the value in |
43 | // SPIR-V encoding format. If the value is 'false' or 'true', returns a vector |
44 | // with single word with value 0 or 1 respectively. Returns the vector |
45 | // containing the encoded value on success. Otherwise returns an empty vector. |
46 | std::vector<uint32_t> ParseDefaultValueStr(const char* text, |
47 | const analysis::Type* type) { |
48 | std::vector<uint32_t> result; |
49 | if (!strcmp(text, "true" ) && type->AsBool()) { |
50 | result.push_back(1u); |
51 | } else if (!strcmp(text, "false" ) && type->AsBool()) { |
52 | result.push_back(0u); |
53 | } else { |
54 | NumberType number_type = {32, SPV_NUMBER_UNSIGNED_INT}; |
55 | if (const auto* IT = type->AsInteger()) { |
56 | number_type.bitwidth = IT->width(); |
57 | number_type.kind = |
58 | IT->IsSigned() ? SPV_NUMBER_SIGNED_INT : SPV_NUMBER_UNSIGNED_INT; |
59 | } else if (const auto* FT = type->AsFloat()) { |
60 | number_type.bitwidth = FT->width(); |
61 | number_type.kind = SPV_NUMBER_FLOATING; |
62 | } else { |
63 | // Does not handle types other then boolean, integer or float. Returns |
64 | // empty vector. |
65 | result.clear(); |
66 | return result; |
67 | } |
68 | EncodeNumberStatus rc = ParseAndEncodeNumber( |
69 | text, number_type, [&result](uint32_t word) { result.push_back(word); }, |
70 | nullptr); |
71 | // Clear the result vector on failure. |
72 | if (rc != EncodeNumberStatus::kSuccess) { |
73 | result.clear(); |
74 | } |
75 | } |
76 | return result; |
77 | } |
78 | |
79 | // Given a bit pattern and a type, checks if the bit pattern is compatible |
80 | // with the type. If so, returns the bit pattern, otherwise returns an empty |
81 | // bit pattern. If the given bit pattern is empty, returns an empty bit |
82 | // pattern. If the given type represents a SPIR-V Boolean type, the bit pattern |
83 | // to be returned is determined with the following standard: |
84 | // If any words in the input bit pattern are non zero, returns a bit pattern |
85 | // with 0x1, which represents a 'true'. |
86 | // If all words in the bit pattern are zero, returns a bit pattern with 0x0, |
87 | // which represents a 'false'. |
88 | std::vector<uint32_t> ParseDefaultValueBitPattern( |
89 | const std::vector<uint32_t>& input_bit_pattern, |
90 | const analysis::Type* type) { |
91 | std::vector<uint32_t> result; |
92 | if (type->AsBool()) { |
93 | if (std::any_of(input_bit_pattern.begin(), input_bit_pattern.end(), |
94 | [](uint32_t i) { return i != 0; })) { |
95 | result.push_back(1u); |
96 | } else { |
97 | result.push_back(0u); |
98 | } |
99 | return result; |
100 | } else if (const auto* IT = type->AsInteger()) { |
101 | if (IT->width() == input_bit_pattern.size() * sizeof(uint32_t) * 8) { |
102 | return std::vector<uint32_t>(input_bit_pattern); |
103 | } |
104 | } else if (const auto* FT = type->AsFloat()) { |
105 | if (FT->width() == input_bit_pattern.size() * sizeof(uint32_t) * 8) { |
106 | return std::vector<uint32_t>(input_bit_pattern); |
107 | } |
108 | } |
109 | result.clear(); |
110 | return result; |
111 | } |
112 | |
113 | // Returns true if the given instruction's result id could have a SpecId |
114 | // decoration. |
115 | bool CanHaveSpecIdDecoration(const Instruction& inst) { |
116 | switch (inst.opcode()) { |
117 | case SpvOp::SpvOpSpecConstant: |
118 | case SpvOp::SpvOpSpecConstantFalse: |
119 | case SpvOp::SpvOpSpecConstantTrue: |
120 | return true; |
121 | default: |
122 | return false; |
123 | } |
124 | } |
125 | |
126 | // Given a decoration group defining instruction that is decorated with SpecId |
127 | // decoration, finds the spec constant defining instruction which is the real |
128 | // target of the SpecId decoration. Returns the spec constant defining |
129 | // instruction if such an instruction is found, otherwise returns a nullptr. |
130 | Instruction* GetSpecIdTargetFromDecorationGroup( |
131 | const Instruction& decoration_group_defining_inst, |
132 | analysis::DefUseManager* def_use_mgr) { |
133 | // Find the OpGroupDecorate instruction which consumes the given decoration |
134 | // group. Note that the given decoration group has SpecId decoration, which |
135 | // is unique for different spec constants. So the decoration group cannot be |
136 | // consumed by different OpGroupDecorate instructions. Therefore we only need |
137 | // the first OpGroupDecoration instruction that uses the given decoration |
138 | // group. |
139 | Instruction* group_decorate_inst = nullptr; |
140 | if (def_use_mgr->WhileEachUser(&decoration_group_defining_inst, |
141 | [&group_decorate_inst](Instruction* user) { |
142 | if (user->opcode() == |
143 | SpvOp::SpvOpGroupDecorate) { |
144 | group_decorate_inst = user; |
145 | return false; |
146 | } |
147 | return true; |
148 | })) |
149 | return nullptr; |
150 | |
151 | // Scan through the target ids of the OpGroupDecorate instruction. There |
152 | // should be only one spec constant target consumes the SpecId decoration. |
153 | // If multiple target ids are presented in the OpGroupDecorate instruction, |
154 | // they must be the same one that defined by an eligible spec constant |
155 | // instruction. If the OpGroupDecorate instruction has different target ids |
156 | // or a target id is not defined by an eligible spec cosntant instruction, |
157 | // returns a nullptr. |
158 | Instruction* target_inst = nullptr; |
159 | for (uint32_t i = 1; i < group_decorate_inst->NumInOperands(); i++) { |
160 | // All the operands of a OpGroupDecorate instruction should be of type |
161 | // SPV_OPERAND_TYPE_ID. |
162 | uint32_t candidate_id = group_decorate_inst->GetSingleWordInOperand(i); |
163 | Instruction* candidate_inst = def_use_mgr->GetDef(candidate_id); |
164 | |
165 | if (!candidate_inst) { |
166 | continue; |
167 | } |
168 | |
169 | if (!target_inst) { |
170 | // If the spec constant target has not been found yet, check if the |
171 | // candidate instruction is the target. |
172 | if (CanHaveSpecIdDecoration(*candidate_inst)) { |
173 | target_inst = candidate_inst; |
174 | } else { |
175 | // Spec id decoration should not be applied on other instructions. |
176 | // TODO(qining): Emit an error message in the invalid case once the |
177 | // error handling is done. |
178 | return nullptr; |
179 | } |
180 | } else { |
181 | // If the spec constant target has been found, check if the candidate |
182 | // instruction is the same one as the target. The module is invalid if |
183 | // the candidate instruction is different with the found target. |
184 | // TODO(qining): Emit an error messaage in the invalid case once the |
185 | // error handling is done. |
186 | if (candidate_inst != target_inst) return nullptr; |
187 | } |
188 | } |
189 | return target_inst; |
190 | } |
191 | } // namespace |
192 | |
193 | Pass::Status SetSpecConstantDefaultValuePass::Process() { |
194 | // The operand index of decoration target in an OpDecorate instruction. |
195 | const uint32_t kTargetIdOperandIndex = 0; |
196 | // The operand index of the decoration literal in an OpDecorate instruction. |
197 | const uint32_t kDecorationOperandIndex = 1; |
198 | // The operand index of Spec id literal value in an OpDecorate SpecId |
199 | // instruction. |
200 | const uint32_t kSpecIdLiteralOperandIndex = 2; |
201 | // The number of operands in an OpDecorate SpecId instruction. |
202 | const uint32_t kOpDecorateSpecIdNumOperands = 3; |
203 | // The in-operand index of the default value in a OpSpecConstant instruction. |
204 | const uint32_t kOpSpecConstantLiteralInOperandIndex = 0; |
205 | |
206 | bool modified = false; |
207 | // Scan through all the annotation instructions to find 'OpDecorate SpecId' |
208 | // instructions. Then extract the decoration target of those instructions. |
209 | // The decoration targets should be spec constant defining instructions with |
210 | // opcode: OpSpecConstant{|True|False}. The spec id of those spec constants |
211 | // will be used to look up their new default values in the mapping from |
212 | // spec id to new default value strings. Once a new default value string |
213 | // is found for a spec id, the string will be parsed according to the target |
214 | // spec constant type. The parsed value will be used to replace the original |
215 | // default value of the target spec constant. |
216 | for (Instruction& inst : context()->annotations()) { |
217 | // Only process 'OpDecorate SpecId' instructions |
218 | if (inst.opcode() != SpvOp::SpvOpDecorate) continue; |
219 | if (inst.NumOperands() != kOpDecorateSpecIdNumOperands) continue; |
220 | if (inst.GetSingleWordInOperand(kDecorationOperandIndex) != |
221 | uint32_t(SpvDecoration::SpvDecorationSpecId)) { |
222 | continue; |
223 | } |
224 | |
225 | // 'inst' is an OpDecorate SpecId instruction. |
226 | uint32_t spec_id = inst.GetSingleWordOperand(kSpecIdLiteralOperandIndex); |
227 | uint32_t target_id = inst.GetSingleWordOperand(kTargetIdOperandIndex); |
228 | |
229 | // Find the spec constant defining instruction. Note that the |
230 | // target_id might be a decoration group id. |
231 | Instruction* spec_inst = nullptr; |
232 | if (Instruction* target_inst = get_def_use_mgr()->GetDef(target_id)) { |
233 | if (target_inst->opcode() == SpvOp::SpvOpDecorationGroup) { |
234 | spec_inst = |
235 | GetSpecIdTargetFromDecorationGroup(*target_inst, get_def_use_mgr()); |
236 | } else { |
237 | spec_inst = target_inst; |
238 | } |
239 | } else { |
240 | continue; |
241 | } |
242 | if (!spec_inst) continue; |
243 | |
244 | // Get the default value bit pattern for this spec id. |
245 | std::vector<uint32_t> bit_pattern; |
246 | |
247 | if (spec_id_to_value_str_.size() != 0) { |
248 | // Search for the new string-form default value for this spec id. |
249 | auto iter = spec_id_to_value_str_.find(spec_id); |
250 | if (iter == spec_id_to_value_str_.end()) { |
251 | continue; |
252 | } |
253 | |
254 | // Gets the string of the default value and parses it to bit pattern |
255 | // with the type of the spec constant. |
256 | const std::string& default_value_str = iter->second; |
257 | bit_pattern = ParseDefaultValueStr( |
258 | default_value_str.c_str(), |
259 | context()->get_type_mgr()->GetType(spec_inst->type_id())); |
260 | |
261 | } else { |
262 | // Search for the new bit-pattern-form default value for this spec id. |
263 | auto iter = spec_id_to_value_bit_pattern_.find(spec_id); |
264 | if (iter == spec_id_to_value_bit_pattern_.end()) { |
265 | continue; |
266 | } |
267 | |
268 | // Gets the bit-pattern of the default value from the map directly. |
269 | bit_pattern = ParseDefaultValueBitPattern( |
270 | iter->second, |
271 | context()->get_type_mgr()->GetType(spec_inst->type_id())); |
272 | } |
273 | |
274 | if (bit_pattern.empty()) continue; |
275 | |
276 | // Update the operand bit patterns of the spec constant defining |
277 | // instruction. |
278 | switch (spec_inst->opcode()) { |
279 | case SpvOp::SpvOpSpecConstant: |
280 | // If the new value is the same with the original value, no |
281 | // need to do anything. Otherwise update the operand words. |
282 | if (spec_inst->GetInOperand(kOpSpecConstantLiteralInOperandIndex) |
283 | .words != bit_pattern) { |
284 | spec_inst->SetInOperand(kOpSpecConstantLiteralInOperandIndex, |
285 | std::move(bit_pattern)); |
286 | modified = true; |
287 | } |
288 | break; |
289 | case SpvOp::SpvOpSpecConstantTrue: |
290 | // If the new value is also 'true', no need to change anything. |
291 | // Otherwise, set the opcode to OpSpecConstantFalse; |
292 | if (!static_cast<bool>(bit_pattern.front())) { |
293 | spec_inst->SetOpcode(SpvOp::SpvOpSpecConstantFalse); |
294 | modified = true; |
295 | } |
296 | break; |
297 | case SpvOp::SpvOpSpecConstantFalse: |
298 | // If the new value is also 'false', no need to change anything. |
299 | // Otherwise, set the opcode to OpSpecConstantTrue; |
300 | if (static_cast<bool>(bit_pattern.front())) { |
301 | spec_inst->SetOpcode(SpvOp::SpvOpSpecConstantTrue); |
302 | modified = true; |
303 | } |
304 | break; |
305 | default: |
306 | break; |
307 | } |
308 | // No need to update the DefUse manager, as this pass does not change any |
309 | // ids. |
310 | } |
311 | return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; |
312 | } |
313 | |
// Returns true if |ch| terminates a spec id or value token in a default-values
// string: ':', the terminating '\0', or a character std::isspace classifies
// as whitespace (in the "C" locale: ' ', '\t', '\n', '\v', '\f' and '\r').
bool IsSeparator(char ch) {
  // std::isspace requires a value representable as unsigned char (or EOF);
  // passing a negative char, e.g. a byte of UTF-8 text, is undefined
  // behavior, so convert through unsigned char first.
  return ch == ':' || ch == '\0' ||
         std::isspace(static_cast<unsigned char>(ch)) != 0;
}
319 | |
320 | std::unique_ptr<SetSpecConstantDefaultValuePass::SpecIdToValueStrMap> |
321 | SetSpecConstantDefaultValuePass::ParseDefaultValuesString(const char* str) { |
322 | if (!str) return nullptr; |
323 | |
324 | auto spec_id_to_value = MakeUnique<SpecIdToValueStrMap>(); |
325 | |
326 | // The parsing loop, break when points to the end. |
327 | while (*str) { |
328 | // Find the spec id. |
329 | while (std::isspace(*str)) str++; // skip leading spaces. |
330 | const char* entry_begin = str; |
331 | while (!IsSeparator(*str)) str++; |
332 | const char* entry_end = str; |
333 | std::string spec_id_str(entry_begin, entry_end - entry_begin); |
334 | uint32_t spec_id = 0; |
335 | if (!ParseNumber(spec_id_str.c_str(), &spec_id)) { |
336 | // The spec id is not a valid uint32 number. |
337 | return nullptr; |
338 | } |
339 | auto iter = spec_id_to_value->find(spec_id); |
340 | if (iter != spec_id_to_value->end()) { |
341 | // Same spec id has been defined before |
342 | return nullptr; |
343 | } |
344 | // Find the ':', spaces between the spec id and the ':' are not allowed. |
345 | if (*str++ != ':') { |
346 | // ':' not found |
347 | return nullptr; |
348 | } |
349 | // Find the value string |
350 | const char* val_begin = str; |
351 | while (!IsSeparator(*str)) str++; |
352 | const char* val_end = str; |
353 | if (val_end == val_begin) { |
354 | // Value string is empty. |
355 | return nullptr; |
356 | } |
357 | // Update the mapping with spec id and value string. |
358 | (*spec_id_to_value)[spec_id] = std::string(val_begin, val_end - val_begin); |
359 | |
360 | // Skip trailing spaces. |
361 | while (std::isspace(*str)) str++; |
362 | } |
363 | |
364 | return spec_id_to_value; |
365 | } |
366 | |
367 | } // namespace opt |
368 | } // namespace spvtools |
369 | |