1// Copyright (c) 2016 Google Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "source/opt/set_spec_constant_default_value_pass.h"
16
17#include <algorithm>
18#include <cctype>
19#include <cstring>
20#include <tuple>
21#include <vector>
22
23#include "source/opt/def_use_manager.h"
24#include "source/opt/ir_context.h"
25#include "source/opt/type_manager.h"
26#include "source/opt/types.h"
27#include "source/util/make_unique.h"
28#include "source/util/parse_number.h"
29#include "spirv-tools/libspirv.h"
30
31namespace spvtools {
32namespace opt {
33
34namespace {
35using utils::EncodeNumberStatus;
36using utils::NumberType;
37using utils::ParseAndEncodeNumber;
38using utils::ParseNumber;
39
40// Given a numeric value in a null-terminated c string and the expected type of
41// the value, parses the string and encodes it in a vector of words. If the
42// value is a scalar integer or floating point value, encodes the value in
43// SPIR-V encoding format. If the value is 'false' or 'true', returns a vector
44// with single word with value 0 or 1 respectively. Returns the vector
45// containing the encoded value on success. Otherwise returns an empty vector.
46std::vector<uint32_t> ParseDefaultValueStr(const char* text,
47 const analysis::Type* type) {
48 std::vector<uint32_t> result;
49 if (!strcmp(text, "true") && type->AsBool()) {
50 result.push_back(1u);
51 } else if (!strcmp(text, "false") && type->AsBool()) {
52 result.push_back(0u);
53 } else {
54 NumberType number_type = {32, SPV_NUMBER_UNSIGNED_INT};
55 if (const auto* IT = type->AsInteger()) {
56 number_type.bitwidth = IT->width();
57 number_type.kind =
58 IT->IsSigned() ? SPV_NUMBER_SIGNED_INT : SPV_NUMBER_UNSIGNED_INT;
59 } else if (const auto* FT = type->AsFloat()) {
60 number_type.bitwidth = FT->width();
61 number_type.kind = SPV_NUMBER_FLOATING;
62 } else {
63 // Does not handle types other then boolean, integer or float. Returns
64 // empty vector.
65 result.clear();
66 return result;
67 }
68 EncodeNumberStatus rc = ParseAndEncodeNumber(
69 text, number_type, [&result](uint32_t word) { result.push_back(word); },
70 nullptr);
71 // Clear the result vector on failure.
72 if (rc != EncodeNumberStatus::kSuccess) {
73 result.clear();
74 }
75 }
76 return result;
77}
78
79// Given a bit pattern and a type, checks if the bit pattern is compatible
80// with the type. If so, returns the bit pattern, otherwise returns an empty
81// bit pattern. If the given bit pattern is empty, returns an empty bit
82// pattern. If the given type represents a SPIR-V Boolean type, the bit pattern
83// to be returned is determined with the following standard:
84// If any words in the input bit pattern are non zero, returns a bit pattern
85// with 0x1, which represents a 'true'.
86// If all words in the bit pattern are zero, returns a bit pattern with 0x0,
87// which represents a 'false'.
88std::vector<uint32_t> ParseDefaultValueBitPattern(
89 const std::vector<uint32_t>& input_bit_pattern,
90 const analysis::Type* type) {
91 std::vector<uint32_t> result;
92 if (type->AsBool()) {
93 if (std::any_of(input_bit_pattern.begin(), input_bit_pattern.end(),
94 [](uint32_t i) { return i != 0; })) {
95 result.push_back(1u);
96 } else {
97 result.push_back(0u);
98 }
99 return result;
100 } else if (const auto* IT = type->AsInteger()) {
101 if (IT->width() == input_bit_pattern.size() * sizeof(uint32_t) * 8) {
102 return std::vector<uint32_t>(input_bit_pattern);
103 }
104 } else if (const auto* FT = type->AsFloat()) {
105 if (FT->width() == input_bit_pattern.size() * sizeof(uint32_t) * 8) {
106 return std::vector<uint32_t>(input_bit_pattern);
107 }
108 }
109 result.clear();
110 return result;
111}
112
113// Returns true if the given instruction's result id could have a SpecId
114// decoration.
115bool CanHaveSpecIdDecoration(const Instruction& inst) {
116 switch (inst.opcode()) {
117 case SpvOp::SpvOpSpecConstant:
118 case SpvOp::SpvOpSpecConstantFalse:
119 case SpvOp::SpvOpSpecConstantTrue:
120 return true;
121 default:
122 return false;
123 }
124}
125
126// Given a decoration group defining instruction that is decorated with SpecId
127// decoration, finds the spec constant defining instruction which is the real
128// target of the SpecId decoration. Returns the spec constant defining
129// instruction if such an instruction is found, otherwise returns a nullptr.
130Instruction* GetSpecIdTargetFromDecorationGroup(
131 const Instruction& decoration_group_defining_inst,
132 analysis::DefUseManager* def_use_mgr) {
133 // Find the OpGroupDecorate instruction which consumes the given decoration
134 // group. Note that the given decoration group has SpecId decoration, which
135 // is unique for different spec constants. So the decoration group cannot be
136 // consumed by different OpGroupDecorate instructions. Therefore we only need
137 // the first OpGroupDecoration instruction that uses the given decoration
138 // group.
139 Instruction* group_decorate_inst = nullptr;
140 if (def_use_mgr->WhileEachUser(&decoration_group_defining_inst,
141 [&group_decorate_inst](Instruction* user) {
142 if (user->opcode() ==
143 SpvOp::SpvOpGroupDecorate) {
144 group_decorate_inst = user;
145 return false;
146 }
147 return true;
148 }))
149 return nullptr;
150
151 // Scan through the target ids of the OpGroupDecorate instruction. There
152 // should be only one spec constant target consumes the SpecId decoration.
153 // If multiple target ids are presented in the OpGroupDecorate instruction,
154 // they must be the same one that defined by an eligible spec constant
155 // instruction. If the OpGroupDecorate instruction has different target ids
156 // or a target id is not defined by an eligible spec cosntant instruction,
157 // returns a nullptr.
158 Instruction* target_inst = nullptr;
159 for (uint32_t i = 1; i < group_decorate_inst->NumInOperands(); i++) {
160 // All the operands of a OpGroupDecorate instruction should be of type
161 // SPV_OPERAND_TYPE_ID.
162 uint32_t candidate_id = group_decorate_inst->GetSingleWordInOperand(i);
163 Instruction* candidate_inst = def_use_mgr->GetDef(candidate_id);
164
165 if (!candidate_inst) {
166 continue;
167 }
168
169 if (!target_inst) {
170 // If the spec constant target has not been found yet, check if the
171 // candidate instruction is the target.
172 if (CanHaveSpecIdDecoration(*candidate_inst)) {
173 target_inst = candidate_inst;
174 } else {
175 // Spec id decoration should not be applied on other instructions.
176 // TODO(qining): Emit an error message in the invalid case once the
177 // error handling is done.
178 return nullptr;
179 }
180 } else {
181 // If the spec constant target has been found, check if the candidate
182 // instruction is the same one as the target. The module is invalid if
183 // the candidate instruction is different with the found target.
184 // TODO(qining): Emit an error messaage in the invalid case once the
185 // error handling is done.
186 if (candidate_inst != target_inst) return nullptr;
187 }
188 }
189 return target_inst;
190}
191} // namespace
192
193Pass::Status SetSpecConstantDefaultValuePass::Process() {
194 // The operand index of decoration target in an OpDecorate instruction.
195 const uint32_t kTargetIdOperandIndex = 0;
196 // The operand index of the decoration literal in an OpDecorate instruction.
197 const uint32_t kDecorationOperandIndex = 1;
198 // The operand index of Spec id literal value in an OpDecorate SpecId
199 // instruction.
200 const uint32_t kSpecIdLiteralOperandIndex = 2;
201 // The number of operands in an OpDecorate SpecId instruction.
202 const uint32_t kOpDecorateSpecIdNumOperands = 3;
203 // The in-operand index of the default value in a OpSpecConstant instruction.
204 const uint32_t kOpSpecConstantLiteralInOperandIndex = 0;
205
206 bool modified = false;
207 // Scan through all the annotation instructions to find 'OpDecorate SpecId'
208 // instructions. Then extract the decoration target of those instructions.
209 // The decoration targets should be spec constant defining instructions with
210 // opcode: OpSpecConstant{|True|False}. The spec id of those spec constants
211 // will be used to look up their new default values in the mapping from
212 // spec id to new default value strings. Once a new default value string
213 // is found for a spec id, the string will be parsed according to the target
214 // spec constant type. The parsed value will be used to replace the original
215 // default value of the target spec constant.
216 for (Instruction& inst : context()->annotations()) {
217 // Only process 'OpDecorate SpecId' instructions
218 if (inst.opcode() != SpvOp::SpvOpDecorate) continue;
219 if (inst.NumOperands() != kOpDecorateSpecIdNumOperands) continue;
220 if (inst.GetSingleWordInOperand(kDecorationOperandIndex) !=
221 uint32_t(SpvDecoration::SpvDecorationSpecId)) {
222 continue;
223 }
224
225 // 'inst' is an OpDecorate SpecId instruction.
226 uint32_t spec_id = inst.GetSingleWordOperand(kSpecIdLiteralOperandIndex);
227 uint32_t target_id = inst.GetSingleWordOperand(kTargetIdOperandIndex);
228
229 // Find the spec constant defining instruction. Note that the
230 // target_id might be a decoration group id.
231 Instruction* spec_inst = nullptr;
232 if (Instruction* target_inst = get_def_use_mgr()->GetDef(target_id)) {
233 if (target_inst->opcode() == SpvOp::SpvOpDecorationGroup) {
234 spec_inst =
235 GetSpecIdTargetFromDecorationGroup(*target_inst, get_def_use_mgr());
236 } else {
237 spec_inst = target_inst;
238 }
239 } else {
240 continue;
241 }
242 if (!spec_inst) continue;
243
244 // Get the default value bit pattern for this spec id.
245 std::vector<uint32_t> bit_pattern;
246
247 if (spec_id_to_value_str_.size() != 0) {
248 // Search for the new string-form default value for this spec id.
249 auto iter = spec_id_to_value_str_.find(spec_id);
250 if (iter == spec_id_to_value_str_.end()) {
251 continue;
252 }
253
254 // Gets the string of the default value and parses it to bit pattern
255 // with the type of the spec constant.
256 const std::string& default_value_str = iter->second;
257 bit_pattern = ParseDefaultValueStr(
258 default_value_str.c_str(),
259 context()->get_type_mgr()->GetType(spec_inst->type_id()));
260
261 } else {
262 // Search for the new bit-pattern-form default value for this spec id.
263 auto iter = spec_id_to_value_bit_pattern_.find(spec_id);
264 if (iter == spec_id_to_value_bit_pattern_.end()) {
265 continue;
266 }
267
268 // Gets the bit-pattern of the default value from the map directly.
269 bit_pattern = ParseDefaultValueBitPattern(
270 iter->second,
271 context()->get_type_mgr()->GetType(spec_inst->type_id()));
272 }
273
274 if (bit_pattern.empty()) continue;
275
276 // Update the operand bit patterns of the spec constant defining
277 // instruction.
278 switch (spec_inst->opcode()) {
279 case SpvOp::SpvOpSpecConstant:
280 // If the new value is the same with the original value, no
281 // need to do anything. Otherwise update the operand words.
282 if (spec_inst->GetInOperand(kOpSpecConstantLiteralInOperandIndex)
283 .words != bit_pattern) {
284 spec_inst->SetInOperand(kOpSpecConstantLiteralInOperandIndex,
285 std::move(bit_pattern));
286 modified = true;
287 }
288 break;
289 case SpvOp::SpvOpSpecConstantTrue:
290 // If the new value is also 'true', no need to change anything.
291 // Otherwise, set the opcode to OpSpecConstantFalse;
292 if (!static_cast<bool>(bit_pattern.front())) {
293 spec_inst->SetOpcode(SpvOp::SpvOpSpecConstantFalse);
294 modified = true;
295 }
296 break;
297 case SpvOp::SpvOpSpecConstantFalse:
298 // If the new value is also 'false', no need to change anything.
299 // Otherwise, set the opcode to OpSpecConstantTrue;
300 if (static_cast<bool>(bit_pattern.front())) {
301 spec_inst->SetOpcode(SpvOp::SpvOpSpecConstantTrue);
302 modified = true;
303 }
304 break;
305 default:
306 break;
307 }
308 // No need to update the DefUse manager, as this pass does not change any
309 // ids.
310 }
311 return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
312}
313
// Returns true if the given char is ':', '\0' or considered blank space by
// std::isspace (in the default "C" locale: '\n', '\r', '\v', '\t', '\f' and
// ' ').
bool IsSeparator(char ch) {
  // std::isspace has undefined behavior for negative values other than EOF.
  // A plain char can be negative for bytes >= 0x80 where char is signed, so
  // widen through unsigned char first.
  return ch == ':' || ch == '\0' ||
         std::isspace(static_cast<unsigned char>(ch)) != 0;
}
319
320std::unique_ptr<SetSpecConstantDefaultValuePass::SpecIdToValueStrMap>
321SetSpecConstantDefaultValuePass::ParseDefaultValuesString(const char* str) {
322 if (!str) return nullptr;
323
324 auto spec_id_to_value = MakeUnique<SpecIdToValueStrMap>();
325
326 // The parsing loop, break when points to the end.
327 while (*str) {
328 // Find the spec id.
329 while (std::isspace(*str)) str++; // skip leading spaces.
330 const char* entry_begin = str;
331 while (!IsSeparator(*str)) str++;
332 const char* entry_end = str;
333 std::string spec_id_str(entry_begin, entry_end - entry_begin);
334 uint32_t spec_id = 0;
335 if (!ParseNumber(spec_id_str.c_str(), &spec_id)) {
336 // The spec id is not a valid uint32 number.
337 return nullptr;
338 }
339 auto iter = spec_id_to_value->find(spec_id);
340 if (iter != spec_id_to_value->end()) {
341 // Same spec id has been defined before
342 return nullptr;
343 }
344 // Find the ':', spaces between the spec id and the ':' are not allowed.
345 if (*str++ != ':') {
346 // ':' not found
347 return nullptr;
348 }
349 // Find the value string
350 const char* val_begin = str;
351 while (!IsSeparator(*str)) str++;
352 const char* val_end = str;
353 if (val_end == val_begin) {
354 // Value string is empty.
355 return nullptr;
356 }
357 // Update the mapping with spec id and value string.
358 (*spec_id_to_value)[spec_id] = std::string(val_begin, val_end - val_begin);
359
360 // Skip trailing spaces.
361 while (std::isspace(*str)) str++;
362 }
363
364 return spec_id_to_value;
365}
366
367} // namespace opt
368} // namespace spvtools
369