1// Copyright (c) 2018 Google LLC.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "source/opcode.h"
16#include "source/val/instruction.h"
17#include "source/val/validate.h"
18#include "source/val/validation_state.h"
19
20namespace spvtools {
21namespace val {
22namespace {
23
24spv_result_t ValidateConstantBool(ValidationState_t& _,
25 const Instruction* inst) {
26 auto type = _.FindDef(inst->type_id());
27 if (!type || type->opcode() != SpvOpTypeBool) {
28 return _.diag(SPV_ERROR_INVALID_ID, inst)
29 << "Op" << spvOpcodeString(inst->opcode()) << " Result Type <id> '"
30 << _.getIdName(inst->type_id()) << "' is not a boolean type.";
31 }
32
33 return SPV_SUCCESS;
34}
35
36spv_result_t ValidateConstantComposite(ValidationState_t& _,
37 const Instruction* inst) {
38 std::string opcode_name = std::string("Op") + spvOpcodeString(inst->opcode());
39
40 const auto result_type = _.FindDef(inst->type_id());
41 if (!result_type || !spvOpcodeIsComposite(result_type->opcode())) {
42 return _.diag(SPV_ERROR_INVALID_ID, inst)
43 << opcode_name << " Result Type <id> '"
44 << _.getIdName(inst->type_id()) << "' is not a composite type.";
45 }
46
47 const auto constituent_count = inst->words().size() - 3;
48 switch (result_type->opcode()) {
49 case SpvOpTypeVector: {
50 const auto component_count = result_type->GetOperandAs<uint32_t>(2);
51 if (component_count != constituent_count) {
52 // TODO: Output ID's on diagnostic
53 return _.diag(SPV_ERROR_INVALID_ID, inst)
54 << opcode_name
55 << " Constituent <id> count does not match "
56 "Result Type <id> '"
57 << _.getIdName(result_type->id())
58 << "'s vector component count.";
59 }
60 const auto component_type =
61 _.FindDef(result_type->GetOperandAs<uint32_t>(1));
62 if (!component_type) {
63 return _.diag(SPV_ERROR_INVALID_ID, result_type)
64 << "Component type is not defined.";
65 }
66 for (size_t constituent_index = 2;
67 constituent_index < inst->operands().size(); constituent_index++) {
68 const auto constituent_id =
69 inst->GetOperandAs<uint32_t>(constituent_index);
70 const auto constituent = _.FindDef(constituent_id);
71 if (!constituent ||
72 !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
73 return _.diag(SPV_ERROR_INVALID_ID, inst)
74 << opcode_name << " Constituent <id> '"
75 << _.getIdName(constituent_id)
76 << "' is not a constant or undef.";
77 }
78 const auto constituent_result_type = _.FindDef(constituent->type_id());
79 if (!constituent_result_type ||
80 component_type->opcode() != constituent_result_type->opcode()) {
81 return _.diag(SPV_ERROR_INVALID_ID, inst)
82 << opcode_name << " Constituent <id> '"
83 << _.getIdName(constituent_id)
84 << "'s type does not match Result Type <id> '"
85 << _.getIdName(result_type->id()) << "'s vector element type.";
86 }
87 }
88 } break;
89 case SpvOpTypeMatrix: {
90 const auto column_count = result_type->GetOperandAs<uint32_t>(2);
91 if (column_count != constituent_count) {
92 // TODO: Output ID's on diagnostic
93 return _.diag(SPV_ERROR_INVALID_ID, inst)
94 << opcode_name
95 << " Constituent <id> count does not match "
96 "Result Type <id> '"
97 << _.getIdName(result_type->id()) << "'s matrix column count.";
98 }
99
100 const auto column_type = _.FindDef(result_type->words()[2]);
101 if (!column_type) {
102 return _.diag(SPV_ERROR_INVALID_ID, result_type)
103 << "Column type is not defined.";
104 }
105 const auto component_count = column_type->GetOperandAs<uint32_t>(2);
106 const auto component_type =
107 _.FindDef(column_type->GetOperandAs<uint32_t>(1));
108 if (!component_type) {
109 return _.diag(SPV_ERROR_INVALID_ID, column_type)
110 << "Component type is not defined.";
111 }
112
113 for (size_t constituent_index = 2;
114 constituent_index < inst->operands().size(); constituent_index++) {
115 const auto constituent_id =
116 inst->GetOperandAs<uint32_t>(constituent_index);
117 const auto constituent = _.FindDef(constituent_id);
118 if (!constituent ||
119 !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
120 // The message says "... or undef" because the spec does not say
121 // undef is a constant.
122 return _.diag(SPV_ERROR_INVALID_ID, inst)
123 << opcode_name << " Constituent <id> '"
124 << _.getIdName(constituent_id)
125 << "' is not a constant or undef.";
126 }
127 const auto vector = _.FindDef(constituent->type_id());
128 if (!vector) {
129 return _.diag(SPV_ERROR_INVALID_ID, constituent)
130 << "Result type is not defined.";
131 }
132 if (column_type->opcode() != vector->opcode()) {
133 return _.diag(SPV_ERROR_INVALID_ID, inst)
134 << opcode_name << " Constituent <id> '"
135 << _.getIdName(constituent_id)
136 << "' type does not match Result Type <id> '"
137 << _.getIdName(result_type->id()) << "'s matrix column type.";
138 }
139 const auto vector_component_type =
140 _.FindDef(vector->GetOperandAs<uint32_t>(1));
141 if (component_type->id() != vector_component_type->id()) {
142 return _.diag(SPV_ERROR_INVALID_ID, inst)
143 << opcode_name << " Constituent <id> '"
144 << _.getIdName(constituent_id)
145 << "' component type does not match Result Type <id> '"
146 << _.getIdName(result_type->id())
147 << "'s matrix column component type.";
148 }
149 if (component_count != vector->words()[3]) {
150 return _.diag(SPV_ERROR_INVALID_ID, inst)
151 << opcode_name << " Constituent <id> '"
152 << _.getIdName(constituent_id)
153 << "' vector component count does not match Result Type <id> '"
154 << _.getIdName(result_type->id())
155 << "'s vector component count.";
156 }
157 }
158 } break;
159 case SpvOpTypeArray: {
160 auto element_type = _.FindDef(result_type->GetOperandAs<uint32_t>(1));
161 if (!element_type) {
162 return _.diag(SPV_ERROR_INVALID_ID, result_type)
163 << "Element type is not defined.";
164 }
165 const auto length = _.FindDef(result_type->GetOperandAs<uint32_t>(2));
166 if (!length) {
167 return _.diag(SPV_ERROR_INVALID_ID, result_type)
168 << "Length is not defined.";
169 }
170 bool is_int32;
171 bool is_const;
172 uint32_t value;
173 std::tie(is_int32, is_const, value) = _.EvalInt32IfConst(length->id());
174 if (is_int32 && is_const && value != constituent_count) {
175 return _.diag(SPV_ERROR_INVALID_ID, inst)
176 << opcode_name
177 << " Constituent count does not match "
178 "Result Type <id> '"
179 << _.getIdName(result_type->id()) << "'s array length.";
180 }
181 for (size_t constituent_index = 2;
182 constituent_index < inst->operands().size(); constituent_index++) {
183 const auto constituent_id =
184 inst->GetOperandAs<uint32_t>(constituent_index);
185 const auto constituent = _.FindDef(constituent_id);
186 if (!constituent ||
187 !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
188 return _.diag(SPV_ERROR_INVALID_ID, inst)
189 << opcode_name << " Constituent <id> '"
190 << _.getIdName(constituent_id)
191 << "' is not a constant or undef.";
192 }
193 const auto constituent_type = _.FindDef(constituent->type_id());
194 if (!constituent_type) {
195 return _.diag(SPV_ERROR_INVALID_ID, constituent)
196 << "Result type is not defined.";
197 }
198 if (element_type->id() != constituent_type->id()) {
199 return _.diag(SPV_ERROR_INVALID_ID, inst)
200 << opcode_name << " Constituent <id> '"
201 << _.getIdName(constituent_id)
202 << "'s type does not match Result Type <id> '"
203 << _.getIdName(result_type->id()) << "'s array element type.";
204 }
205 }
206 } break;
207 case SpvOpTypeStruct: {
208 const auto member_count = result_type->words().size() - 2;
209 if (member_count != constituent_count) {
210 return _.diag(SPV_ERROR_INVALID_ID, inst)
211 << opcode_name << " Constituent <id> '"
212 << _.getIdName(inst->type_id())
213 << "' count does not match Result Type <id> '"
214 << _.getIdName(result_type->id()) << "'s struct member count.";
215 }
216 for (uint32_t constituent_index = 2, member_index = 1;
217 constituent_index < inst->operands().size();
218 constituent_index++, member_index++) {
219 const auto constituent_id =
220 inst->GetOperandAs<uint32_t>(constituent_index);
221 const auto constituent = _.FindDef(constituent_id);
222 if (!constituent ||
223 !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
224 return _.diag(SPV_ERROR_INVALID_ID, inst)
225 << opcode_name << " Constituent <id> '"
226 << _.getIdName(constituent_id)
227 << "' is not a constant or undef.";
228 }
229 const auto constituent_type = _.FindDef(constituent->type_id());
230 if (!constituent_type) {
231 return _.diag(SPV_ERROR_INVALID_ID, constituent)
232 << "Result type is not defined.";
233 }
234
235 const auto member_type_id =
236 result_type->GetOperandAs<uint32_t>(member_index);
237 const auto member_type = _.FindDef(member_type_id);
238 if (!member_type || member_type->id() != constituent_type->id()) {
239 return _.diag(SPV_ERROR_INVALID_ID, inst)
240 << opcode_name << " Constituent <id> '"
241 << _.getIdName(constituent_id)
242 << "' type does not match the Result Type <id> '"
243 << _.getIdName(result_type->id()) << "'s member type.";
244 }
245 }
246 } break;
247 case SpvOpTypeCooperativeMatrixNV: {
248 if (1 != constituent_count) {
249 return _.diag(SPV_ERROR_INVALID_ID, inst)
250 << opcode_name << " Constituent <id> '"
251 << _.getIdName(inst->type_id()) << "' count must be one.";
252 }
253 const auto constituent_id = inst->GetOperandAs<uint32_t>(2);
254 const auto constituent = _.FindDef(constituent_id);
255 if (!constituent || !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
256 return _.diag(SPV_ERROR_INVALID_ID, inst)
257 << opcode_name << " Constituent <id> '"
258 << _.getIdName(constituent_id)
259 << "' is not a constant or undef.";
260 }
261 const auto constituent_type = _.FindDef(constituent->type_id());
262 if (!constituent_type) {
263 return _.diag(SPV_ERROR_INVALID_ID, constituent)
264 << "Result type is not defined.";
265 }
266
267 const auto component_type_id = result_type->GetOperandAs<uint32_t>(1);
268 const auto component_type = _.FindDef(component_type_id);
269 if (!component_type || component_type->id() != constituent_type->id()) {
270 return _.diag(SPV_ERROR_INVALID_ID, inst)
271 << opcode_name << " Constituent <id> '"
272 << _.getIdName(constituent_id)
273 << "' type does not match the Result Type <id> '"
274 << _.getIdName(result_type->id()) << "'s component type.";
275 }
276 } break;
277 default:
278 break;
279 }
280 return SPV_SUCCESS;
281}
282
283spv_result_t ValidateConstantSampler(ValidationState_t& _,
284 const Instruction* inst) {
285 const auto result_type = _.FindDef(inst->type_id());
286 if (!result_type || result_type->opcode() != SpvOpTypeSampler) {
287 return _.diag(SPV_ERROR_INVALID_ID, result_type)
288 << "OpConstantSampler Result Type <id> '"
289 << _.getIdName(inst->type_id()) << "' is not a sampler type.";
290 }
291
292 return SPV_SUCCESS;
293}
294
295// True if instruction defines a type that can have a null value, as defined by
296// the SPIR-V spec. Tracks composite-type components through module to check
297// nullability transitively.
298bool IsTypeNullable(const std::vector<uint32_t>& instruction,
299 const ValidationState_t& _) {
300 uint16_t opcode;
301 uint16_t word_count;
302 spvOpcodeSplit(instruction[0], &word_count, &opcode);
303 switch (static_cast<SpvOp>(opcode)) {
304 case SpvOpTypeBool:
305 case SpvOpTypeInt:
306 case SpvOpTypeFloat:
307 case SpvOpTypeEvent:
308 case SpvOpTypeDeviceEvent:
309 case SpvOpTypeReserveId:
310 case SpvOpTypeQueue:
311 return true;
312 case SpvOpTypeArray:
313 case SpvOpTypeMatrix:
314 case SpvOpTypeCooperativeMatrixNV:
315 case SpvOpTypeVector: {
316 auto base_type = _.FindDef(instruction[2]);
317 return base_type && IsTypeNullable(base_type->words(), _);
318 }
319 case SpvOpTypeStruct: {
320 for (size_t elementIndex = 2; elementIndex < instruction.size();
321 ++elementIndex) {
322 auto element = _.FindDef(instruction[elementIndex]);
323 if (!element || !IsTypeNullable(element->words(), _)) return false;
324 }
325 return true;
326 }
327 case SpvOpTypePointer:
328 if (instruction[2] == SpvStorageClassPhysicalStorageBuffer) {
329 return false;
330 }
331 return true;
332 default:
333 return false;
334 }
335}
336
337spv_result_t ValidateConstantNull(ValidationState_t& _,
338 const Instruction* inst) {
339 const auto result_type = _.FindDef(inst->type_id());
340 if (!result_type || !IsTypeNullable(result_type->words(), _)) {
341 return _.diag(SPV_ERROR_INVALID_ID, inst)
342 << "OpConstantNull Result Type <id> '"
343 << _.getIdName(inst->type_id()) << "' cannot have a null value.";
344 }
345
346 return SPV_SUCCESS;
347}
348
349// Validates that OpSpecConstant specializes to either int or float type.
350spv_result_t ValidateSpecConstant(ValidationState_t& _,
351 const Instruction* inst) {
352 // Operand 0 is the <id> of the type that we're specializing to.
353 auto type_id = inst->GetOperandAs<const uint32_t>(0);
354 auto type_instruction = _.FindDef(type_id);
355 auto type_opcode = type_instruction->opcode();
356 if (type_opcode != SpvOpTypeInt && type_opcode != SpvOpTypeFloat) {
357 return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Specialization constant "
358 "must be an integer or "
359 "floating-point number.";
360 }
361 return SPV_SUCCESS;
362}
363
364spv_result_t ValidateSpecConstantOp(ValidationState_t& _,
365 const Instruction* inst) {
366 const auto op = inst->GetOperandAs<SpvOp>(2);
367
368 // The binary parser already ensures that the op is valid for *some*
369 // environment. Here we check restrictions.
370 switch (op) {
371 case SpvOpQuantizeToF16:
372 if (!_.HasCapability(SpvCapabilityShader)) {
373 return _.diag(SPV_ERROR_INVALID_ID, inst)
374 << "Specialization constant operation " << spvOpcodeString(op)
375 << " requires Shader capability";
376 }
377 break;
378
379 case SpvOpUConvert:
380 if (!_.features().uconvert_spec_constant_op &&
381 !_.HasCapability(SpvCapabilityKernel)) {
382 return _.diag(SPV_ERROR_INVALID_ID, inst)
383 << "Prior to SPIR-V 1.4, specialization constant operation "
384 "UConvert requires Kernel capability or extension "
385 "SPV_AMD_gpu_shader_int16";
386 }
387 break;
388
389 case SpvOpConvertFToS:
390 case SpvOpConvertSToF:
391 case SpvOpConvertFToU:
392 case SpvOpConvertUToF:
393 case SpvOpConvertPtrToU:
394 case SpvOpConvertUToPtr:
395 case SpvOpGenericCastToPtr:
396 case SpvOpPtrCastToGeneric:
397 case SpvOpBitcast:
398 case SpvOpFNegate:
399 case SpvOpFAdd:
400 case SpvOpFSub:
401 case SpvOpFMul:
402 case SpvOpFDiv:
403 case SpvOpFRem:
404 case SpvOpFMod:
405 case SpvOpAccessChain:
406 case SpvOpInBoundsAccessChain:
407 case SpvOpPtrAccessChain:
408 case SpvOpInBoundsPtrAccessChain:
409 if (!_.HasCapability(SpvCapabilityKernel)) {
410 return _.diag(SPV_ERROR_INVALID_ID, inst)
411 << "Specialization constant operation " << spvOpcodeString(op)
412 << " requires Kernel capability";
413 }
414 break;
415
416 default:
417 break;
418 }
419
420 // TODO(dneto): Validate result type and arguments to the various operations.
421 return SPV_SUCCESS;
422}
423
424} // namespace
425
426spv_result_t ConstantPass(ValidationState_t& _, const Instruction* inst) {
427 switch (inst->opcode()) {
428 case SpvOpConstantTrue:
429 case SpvOpConstantFalse:
430 case SpvOpSpecConstantTrue:
431 case SpvOpSpecConstantFalse:
432 if (auto error = ValidateConstantBool(_, inst)) return error;
433 break;
434 case SpvOpConstantComposite:
435 case SpvOpSpecConstantComposite:
436 if (auto error = ValidateConstantComposite(_, inst)) return error;
437 break;
438 case SpvOpConstantSampler:
439 if (auto error = ValidateConstantSampler(_, inst)) return error;
440 break;
441 case SpvOpConstantNull:
442 if (auto error = ValidateConstantNull(_, inst)) return error;
443 break;
444 case SpvOpSpecConstant:
445 if (auto error = ValidateSpecConstant(_, inst)) return error;
446 break;
447 case SpvOpSpecConstantOp:
448 if (auto error = ValidateSpecConstantOp(_, inst)) return error;
449 break;
450 default:
451 break;
452 }
453
454 // Generally disallow creating 8- or 16-bit constants unless the full
455 // capabilities are present.
456 if (spvOpcodeIsConstant(inst->opcode()) &&
457 _.HasCapability(SpvCapabilityShader) &&
458 !_.IsPointerType(inst->type_id()) &&
459 _.ContainsLimitedUseIntOrFloatType(inst->type_id())) {
460 return _.diag(SPV_ERROR_INVALID_ID, inst)
461 << "Cannot form constants of 8- or 16-bit types";
462 }
463
464 return SPV_SUCCESS;
465}
466
467} // namespace val
468} // namespace spvtools
469