1 | // Copyright (c) 2018 Google LLC. |
2 | // |
3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | // you may not use this file except in compliance with the License. |
5 | // You may obtain a copy of the License at |
6 | // |
7 | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | // |
9 | // Unless required by applicable law or agreed to in writing, software |
10 | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | // See the License for the specific language governing permissions and |
13 | // limitations under the License. |
14 | |
15 | #include "source/opcode.h" |
16 | #include "source/val/instruction.h" |
17 | #include "source/val/validate.h" |
18 | #include "source/val/validation_state.h" |
19 | |
20 | namespace spvtools { |
21 | namespace val { |
22 | namespace { |
23 | |
24 | spv_result_t ValidateConstantBool(ValidationState_t& _, |
25 | const Instruction* inst) { |
26 | auto type = _.FindDef(inst->type_id()); |
27 | if (!type || type->opcode() != SpvOpTypeBool) { |
28 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
29 | << "Op" << spvOpcodeString(inst->opcode()) << " Result Type <id> '" |
30 | << _.getIdName(inst->type_id()) << "' is not a boolean type." ; |
31 | } |
32 | |
33 | return SPV_SUCCESS; |
34 | } |
35 | |
36 | spv_result_t ValidateConstantComposite(ValidationState_t& _, |
37 | const Instruction* inst) { |
38 | std::string opcode_name = std::string("Op" ) + spvOpcodeString(inst->opcode()); |
39 | |
40 | const auto result_type = _.FindDef(inst->type_id()); |
41 | if (!result_type || !spvOpcodeIsComposite(result_type->opcode())) { |
42 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
43 | << opcode_name << " Result Type <id> '" |
44 | << _.getIdName(inst->type_id()) << "' is not a composite type." ; |
45 | } |
46 | |
47 | const auto constituent_count = inst->words().size() - 3; |
48 | switch (result_type->opcode()) { |
49 | case SpvOpTypeVector: { |
50 | const auto component_count = result_type->GetOperandAs<uint32_t>(2); |
51 | if (component_count != constituent_count) { |
52 | // TODO: Output ID's on diagnostic |
53 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
54 | << opcode_name |
55 | << " Constituent <id> count does not match " |
56 | "Result Type <id> '" |
57 | << _.getIdName(result_type->id()) |
58 | << "'s vector component count." ; |
59 | } |
60 | const auto component_type = |
61 | _.FindDef(result_type->GetOperandAs<uint32_t>(1)); |
62 | if (!component_type) { |
63 | return _.diag(SPV_ERROR_INVALID_ID, result_type) |
64 | << "Component type is not defined." ; |
65 | } |
66 | for (size_t constituent_index = 2; |
67 | constituent_index < inst->operands().size(); constituent_index++) { |
68 | const auto constituent_id = |
69 | inst->GetOperandAs<uint32_t>(constituent_index); |
70 | const auto constituent = _.FindDef(constituent_id); |
71 | if (!constituent || |
72 | !spvOpcodeIsConstantOrUndef(constituent->opcode())) { |
73 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
74 | << opcode_name << " Constituent <id> '" |
75 | << _.getIdName(constituent_id) |
76 | << "' is not a constant or undef." ; |
77 | } |
78 | const auto constituent_result_type = _.FindDef(constituent->type_id()); |
79 | if (!constituent_result_type || |
80 | component_type->opcode() != constituent_result_type->opcode()) { |
81 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
82 | << opcode_name << " Constituent <id> '" |
83 | << _.getIdName(constituent_id) |
84 | << "'s type does not match Result Type <id> '" |
85 | << _.getIdName(result_type->id()) << "'s vector element type." ; |
86 | } |
87 | } |
88 | } break; |
89 | case SpvOpTypeMatrix: { |
90 | const auto column_count = result_type->GetOperandAs<uint32_t>(2); |
91 | if (column_count != constituent_count) { |
92 | // TODO: Output ID's on diagnostic |
93 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
94 | << opcode_name |
95 | << " Constituent <id> count does not match " |
96 | "Result Type <id> '" |
97 | << _.getIdName(result_type->id()) << "'s matrix column count." ; |
98 | } |
99 | |
100 | const auto column_type = _.FindDef(result_type->words()[2]); |
101 | if (!column_type) { |
102 | return _.diag(SPV_ERROR_INVALID_ID, result_type) |
103 | << "Column type is not defined." ; |
104 | } |
105 | const auto component_count = column_type->GetOperandAs<uint32_t>(2); |
106 | const auto component_type = |
107 | _.FindDef(column_type->GetOperandAs<uint32_t>(1)); |
108 | if (!component_type) { |
109 | return _.diag(SPV_ERROR_INVALID_ID, column_type) |
110 | << "Component type is not defined." ; |
111 | } |
112 | |
113 | for (size_t constituent_index = 2; |
114 | constituent_index < inst->operands().size(); constituent_index++) { |
115 | const auto constituent_id = |
116 | inst->GetOperandAs<uint32_t>(constituent_index); |
117 | const auto constituent = _.FindDef(constituent_id); |
118 | if (!constituent || |
119 | !spvOpcodeIsConstantOrUndef(constituent->opcode())) { |
120 | // The message says "... or undef" because the spec does not say |
121 | // undef is a constant. |
122 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
123 | << opcode_name << " Constituent <id> '" |
124 | << _.getIdName(constituent_id) |
125 | << "' is not a constant or undef." ; |
126 | } |
127 | const auto vector = _.FindDef(constituent->type_id()); |
128 | if (!vector) { |
129 | return _.diag(SPV_ERROR_INVALID_ID, constituent) |
130 | << "Result type is not defined." ; |
131 | } |
132 | if (column_type->opcode() != vector->opcode()) { |
133 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
134 | << opcode_name << " Constituent <id> '" |
135 | << _.getIdName(constituent_id) |
136 | << "' type does not match Result Type <id> '" |
137 | << _.getIdName(result_type->id()) << "'s matrix column type." ; |
138 | } |
139 | const auto vector_component_type = |
140 | _.FindDef(vector->GetOperandAs<uint32_t>(1)); |
141 | if (component_type->id() != vector_component_type->id()) { |
142 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
143 | << opcode_name << " Constituent <id> '" |
144 | << _.getIdName(constituent_id) |
145 | << "' component type does not match Result Type <id> '" |
146 | << _.getIdName(result_type->id()) |
147 | << "'s matrix column component type." ; |
148 | } |
149 | if (component_count != vector->words()[3]) { |
150 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
151 | << opcode_name << " Constituent <id> '" |
152 | << _.getIdName(constituent_id) |
153 | << "' vector component count does not match Result Type <id> '" |
154 | << _.getIdName(result_type->id()) |
155 | << "'s vector component count." ; |
156 | } |
157 | } |
158 | } break; |
159 | case SpvOpTypeArray: { |
160 | auto element_type = _.FindDef(result_type->GetOperandAs<uint32_t>(1)); |
161 | if (!element_type) { |
162 | return _.diag(SPV_ERROR_INVALID_ID, result_type) |
163 | << "Element type is not defined." ; |
164 | } |
165 | const auto length = _.FindDef(result_type->GetOperandAs<uint32_t>(2)); |
166 | if (!length) { |
167 | return _.diag(SPV_ERROR_INVALID_ID, result_type) |
168 | << "Length is not defined." ; |
169 | } |
170 | bool is_int32; |
171 | bool is_const; |
172 | uint32_t value; |
173 | std::tie(is_int32, is_const, value) = _.EvalInt32IfConst(length->id()); |
174 | if (is_int32 && is_const && value != constituent_count) { |
175 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
176 | << opcode_name |
177 | << " Constituent count does not match " |
178 | "Result Type <id> '" |
179 | << _.getIdName(result_type->id()) << "'s array length." ; |
180 | } |
181 | for (size_t constituent_index = 2; |
182 | constituent_index < inst->operands().size(); constituent_index++) { |
183 | const auto constituent_id = |
184 | inst->GetOperandAs<uint32_t>(constituent_index); |
185 | const auto constituent = _.FindDef(constituent_id); |
186 | if (!constituent || |
187 | !spvOpcodeIsConstantOrUndef(constituent->opcode())) { |
188 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
189 | << opcode_name << " Constituent <id> '" |
190 | << _.getIdName(constituent_id) |
191 | << "' is not a constant or undef." ; |
192 | } |
193 | const auto constituent_type = _.FindDef(constituent->type_id()); |
194 | if (!constituent_type) { |
195 | return _.diag(SPV_ERROR_INVALID_ID, constituent) |
196 | << "Result type is not defined." ; |
197 | } |
198 | if (element_type->id() != constituent_type->id()) { |
199 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
200 | << opcode_name << " Constituent <id> '" |
201 | << _.getIdName(constituent_id) |
202 | << "'s type does not match Result Type <id> '" |
203 | << _.getIdName(result_type->id()) << "'s array element type." ; |
204 | } |
205 | } |
206 | } break; |
207 | case SpvOpTypeStruct: { |
208 | const auto member_count = result_type->words().size() - 2; |
209 | if (member_count != constituent_count) { |
210 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
211 | << opcode_name << " Constituent <id> '" |
212 | << _.getIdName(inst->type_id()) |
213 | << "' count does not match Result Type <id> '" |
214 | << _.getIdName(result_type->id()) << "'s struct member count." ; |
215 | } |
216 | for (uint32_t constituent_index = 2, member_index = 1; |
217 | constituent_index < inst->operands().size(); |
218 | constituent_index++, member_index++) { |
219 | const auto constituent_id = |
220 | inst->GetOperandAs<uint32_t>(constituent_index); |
221 | const auto constituent = _.FindDef(constituent_id); |
222 | if (!constituent || |
223 | !spvOpcodeIsConstantOrUndef(constituent->opcode())) { |
224 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
225 | << opcode_name << " Constituent <id> '" |
226 | << _.getIdName(constituent_id) |
227 | << "' is not a constant or undef." ; |
228 | } |
229 | const auto constituent_type = _.FindDef(constituent->type_id()); |
230 | if (!constituent_type) { |
231 | return _.diag(SPV_ERROR_INVALID_ID, constituent) |
232 | << "Result type is not defined." ; |
233 | } |
234 | |
235 | const auto member_type_id = |
236 | result_type->GetOperandAs<uint32_t>(member_index); |
237 | const auto member_type = _.FindDef(member_type_id); |
238 | if (!member_type || member_type->id() != constituent_type->id()) { |
239 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
240 | << opcode_name << " Constituent <id> '" |
241 | << _.getIdName(constituent_id) |
242 | << "' type does not match the Result Type <id> '" |
243 | << _.getIdName(result_type->id()) << "'s member type." ; |
244 | } |
245 | } |
246 | } break; |
247 | case SpvOpTypeCooperativeMatrixNV: { |
248 | if (1 != constituent_count) { |
249 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
250 | << opcode_name << " Constituent <id> '" |
251 | << _.getIdName(inst->type_id()) << "' count must be one." ; |
252 | } |
253 | const auto constituent_id = inst->GetOperandAs<uint32_t>(2); |
254 | const auto constituent = _.FindDef(constituent_id); |
255 | if (!constituent || !spvOpcodeIsConstantOrUndef(constituent->opcode())) { |
256 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
257 | << opcode_name << " Constituent <id> '" |
258 | << _.getIdName(constituent_id) |
259 | << "' is not a constant or undef." ; |
260 | } |
261 | const auto constituent_type = _.FindDef(constituent->type_id()); |
262 | if (!constituent_type) { |
263 | return _.diag(SPV_ERROR_INVALID_ID, constituent) |
264 | << "Result type is not defined." ; |
265 | } |
266 | |
267 | const auto component_type_id = result_type->GetOperandAs<uint32_t>(1); |
268 | const auto component_type = _.FindDef(component_type_id); |
269 | if (!component_type || component_type->id() != constituent_type->id()) { |
270 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
271 | << opcode_name << " Constituent <id> '" |
272 | << _.getIdName(constituent_id) |
273 | << "' type does not match the Result Type <id> '" |
274 | << _.getIdName(result_type->id()) << "'s component type." ; |
275 | } |
276 | } break; |
277 | default: |
278 | break; |
279 | } |
280 | return SPV_SUCCESS; |
281 | } |
282 | |
283 | spv_result_t ValidateConstantSampler(ValidationState_t& _, |
284 | const Instruction* inst) { |
285 | const auto result_type = _.FindDef(inst->type_id()); |
286 | if (!result_type || result_type->opcode() != SpvOpTypeSampler) { |
287 | return _.diag(SPV_ERROR_INVALID_ID, result_type) |
288 | << "OpConstantSampler Result Type <id> '" |
289 | << _.getIdName(inst->type_id()) << "' is not a sampler type." ; |
290 | } |
291 | |
292 | return SPV_SUCCESS; |
293 | } |
294 | |
295 | // True if instruction defines a type that can have a null value, as defined by |
296 | // the SPIR-V spec. Tracks composite-type components through module to check |
297 | // nullability transitively. |
298 | bool IsTypeNullable(const std::vector<uint32_t>& instruction, |
299 | const ValidationState_t& _) { |
300 | uint16_t opcode; |
301 | uint16_t word_count; |
302 | spvOpcodeSplit(instruction[0], &word_count, &opcode); |
303 | switch (static_cast<SpvOp>(opcode)) { |
304 | case SpvOpTypeBool: |
305 | case SpvOpTypeInt: |
306 | case SpvOpTypeFloat: |
307 | case SpvOpTypeEvent: |
308 | case SpvOpTypeDeviceEvent: |
309 | case SpvOpTypeReserveId: |
310 | case SpvOpTypeQueue: |
311 | return true; |
312 | case SpvOpTypeArray: |
313 | case SpvOpTypeMatrix: |
314 | case SpvOpTypeCooperativeMatrixNV: |
315 | case SpvOpTypeVector: { |
316 | auto base_type = _.FindDef(instruction[2]); |
317 | return base_type && IsTypeNullable(base_type->words(), _); |
318 | } |
319 | case SpvOpTypeStruct: { |
320 | for (size_t elementIndex = 2; elementIndex < instruction.size(); |
321 | ++elementIndex) { |
322 | auto element = _.FindDef(instruction[elementIndex]); |
323 | if (!element || !IsTypeNullable(element->words(), _)) return false; |
324 | } |
325 | return true; |
326 | } |
327 | case SpvOpTypePointer: |
328 | if (instruction[2] == SpvStorageClassPhysicalStorageBuffer) { |
329 | return false; |
330 | } |
331 | return true; |
332 | default: |
333 | return false; |
334 | } |
335 | } |
336 | |
337 | spv_result_t ValidateConstantNull(ValidationState_t& _, |
338 | const Instruction* inst) { |
339 | const auto result_type = _.FindDef(inst->type_id()); |
340 | if (!result_type || !IsTypeNullable(result_type->words(), _)) { |
341 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
342 | << "OpConstantNull Result Type <id> '" |
343 | << _.getIdName(inst->type_id()) << "' cannot have a null value." ; |
344 | } |
345 | |
346 | return SPV_SUCCESS; |
347 | } |
348 | |
349 | // Validates that OpSpecConstant specializes to either int or float type. |
350 | spv_result_t ValidateSpecConstant(ValidationState_t& _, |
351 | const Instruction* inst) { |
352 | // Operand 0 is the <id> of the type that we're specializing to. |
353 | auto type_id = inst->GetOperandAs<const uint32_t>(0); |
354 | auto type_instruction = _.FindDef(type_id); |
355 | auto type_opcode = type_instruction->opcode(); |
356 | if (type_opcode != SpvOpTypeInt && type_opcode != SpvOpTypeFloat) { |
357 | return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Specialization constant " |
358 | "must be an integer or " |
359 | "floating-point number." ; |
360 | } |
361 | return SPV_SUCCESS; |
362 | } |
363 | |
364 | spv_result_t ValidateSpecConstantOp(ValidationState_t& _, |
365 | const Instruction* inst) { |
366 | const auto op = inst->GetOperandAs<SpvOp>(2); |
367 | |
368 | // The binary parser already ensures that the op is valid for *some* |
369 | // environment. Here we check restrictions. |
370 | switch (op) { |
371 | case SpvOpQuantizeToF16: |
372 | if (!_.HasCapability(SpvCapabilityShader)) { |
373 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
374 | << "Specialization constant operation " << spvOpcodeString(op) |
375 | << " requires Shader capability" ; |
376 | } |
377 | break; |
378 | |
379 | case SpvOpUConvert: |
380 | if (!_.features().uconvert_spec_constant_op && |
381 | !_.HasCapability(SpvCapabilityKernel)) { |
382 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
383 | << "Prior to SPIR-V 1.4, specialization constant operation " |
384 | "UConvert requires Kernel capability or extension " |
385 | "SPV_AMD_gpu_shader_int16" ; |
386 | } |
387 | break; |
388 | |
389 | case SpvOpConvertFToS: |
390 | case SpvOpConvertSToF: |
391 | case SpvOpConvertFToU: |
392 | case SpvOpConvertUToF: |
393 | case SpvOpConvertPtrToU: |
394 | case SpvOpConvertUToPtr: |
395 | case SpvOpGenericCastToPtr: |
396 | case SpvOpPtrCastToGeneric: |
397 | case SpvOpBitcast: |
398 | case SpvOpFNegate: |
399 | case SpvOpFAdd: |
400 | case SpvOpFSub: |
401 | case SpvOpFMul: |
402 | case SpvOpFDiv: |
403 | case SpvOpFRem: |
404 | case SpvOpFMod: |
405 | case SpvOpAccessChain: |
406 | case SpvOpInBoundsAccessChain: |
407 | case SpvOpPtrAccessChain: |
408 | case SpvOpInBoundsPtrAccessChain: |
409 | if (!_.HasCapability(SpvCapabilityKernel)) { |
410 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
411 | << "Specialization constant operation " << spvOpcodeString(op) |
412 | << " requires Kernel capability" ; |
413 | } |
414 | break; |
415 | |
416 | default: |
417 | break; |
418 | } |
419 | |
420 | // TODO(dneto): Validate result type and arguments to the various operations. |
421 | return SPV_SUCCESS; |
422 | } |
423 | |
424 | } // namespace |
425 | |
426 | spv_result_t ConstantPass(ValidationState_t& _, const Instruction* inst) { |
427 | switch (inst->opcode()) { |
428 | case SpvOpConstantTrue: |
429 | case SpvOpConstantFalse: |
430 | case SpvOpSpecConstantTrue: |
431 | case SpvOpSpecConstantFalse: |
432 | if (auto error = ValidateConstantBool(_, inst)) return error; |
433 | break; |
434 | case SpvOpConstantComposite: |
435 | case SpvOpSpecConstantComposite: |
436 | if (auto error = ValidateConstantComposite(_, inst)) return error; |
437 | break; |
438 | case SpvOpConstantSampler: |
439 | if (auto error = ValidateConstantSampler(_, inst)) return error; |
440 | break; |
441 | case SpvOpConstantNull: |
442 | if (auto error = ValidateConstantNull(_, inst)) return error; |
443 | break; |
444 | case SpvOpSpecConstant: |
445 | if (auto error = ValidateSpecConstant(_, inst)) return error; |
446 | break; |
447 | case SpvOpSpecConstantOp: |
448 | if (auto error = ValidateSpecConstantOp(_, inst)) return error; |
449 | break; |
450 | default: |
451 | break; |
452 | } |
453 | |
454 | // Generally disallow creating 8- or 16-bit constants unless the full |
455 | // capabilities are present. |
456 | if (spvOpcodeIsConstant(inst->opcode()) && |
457 | _.HasCapability(SpvCapabilityShader) && |
458 | !_.IsPointerType(inst->type_id()) && |
459 | _.ContainsLimitedUseIntOrFloatType(inst->type_id())) { |
460 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
461 | << "Cannot form constants of 8- or 16-bit types" ; |
462 | } |
463 | |
464 | return SPV_SUCCESS; |
465 | } |
466 | |
467 | } // namespace val |
468 | } // namespace spvtools |
469 | |