1// Copyright (c) 2017 Google Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// Performs validation of arithmetic instructions.
16
17#include "source/val/validate.h"
18
19#include <vector>
20
21#include "source/diagnostic.h"
22#include "source/opcode.h"
23#include "source/val/instruction.h"
24#include "source/val/validation_state.h"
25
26namespace spvtools {
27namespace val {
28
29// Validates correctness of arithmetic instructions.
30spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
31 const SpvOp opcode = inst->opcode();
32 const uint32_t result_type = inst->type_id();
33
34 switch (opcode) {
35 case SpvOpFAdd:
36 case SpvOpFSub:
37 case SpvOpFMul:
38 case SpvOpFDiv:
39 case SpvOpFRem:
40 case SpvOpFMod:
41 case SpvOpFNegate: {
42 bool supportsCoopMat =
43 (opcode != SpvOpFMul && opcode != SpvOpFRem && opcode != SpvOpFMod);
44 if (!_.IsFloatScalarType(result_type) &&
45 !_.IsFloatVectorType(result_type) &&
46 !(supportsCoopMat && _.IsFloatCooperativeMatrixType(result_type)))
47 return _.diag(SPV_ERROR_INVALID_DATA, inst)
48 << "Expected floating scalar or vector type as Result Type: "
49 << spvOpcodeString(opcode);
50
51 for (size_t operand_index = 2; operand_index < inst->operands().size();
52 ++operand_index) {
53 if (_.GetOperandTypeId(inst, operand_index) != result_type)
54 return _.diag(SPV_ERROR_INVALID_DATA, inst)
55 << "Expected arithmetic operands to be of Result Type: "
56 << spvOpcodeString(opcode) << " operand index "
57 << operand_index;
58 }
59 break;
60 }
61
62 case SpvOpUDiv:
63 case SpvOpUMod: {
64 bool supportsCoopMat = (opcode == SpvOpUDiv);
65 if (!_.IsUnsignedIntScalarType(result_type) &&
66 !_.IsUnsignedIntVectorType(result_type) &&
67 !(supportsCoopMat &&
68 _.IsUnsignedIntCooperativeMatrixType(result_type)))
69 return _.diag(SPV_ERROR_INVALID_DATA, inst)
70 << "Expected unsigned int scalar or vector type as Result Type: "
71 << spvOpcodeString(opcode);
72
73 for (size_t operand_index = 2; operand_index < inst->operands().size();
74 ++operand_index) {
75 if (_.GetOperandTypeId(inst, operand_index) != result_type)
76 return _.diag(SPV_ERROR_INVALID_DATA, inst)
77 << "Expected arithmetic operands to be of Result Type: "
78 << spvOpcodeString(opcode) << " operand index "
79 << operand_index;
80 }
81 break;
82 }
83
84 case SpvOpISub:
85 case SpvOpIAdd:
86 case SpvOpIMul:
87 case SpvOpSDiv:
88 case SpvOpSMod:
89 case SpvOpSRem:
90 case SpvOpSNegate: {
91 bool supportsCoopMat =
92 (opcode != SpvOpIMul && opcode != SpvOpSRem && opcode != SpvOpSMod);
93 if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type) &&
94 !(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type)))
95 return _.diag(SPV_ERROR_INVALID_DATA, inst)
96 << "Expected int scalar or vector type as Result Type: "
97 << spvOpcodeString(opcode);
98
99 const uint32_t dimension = _.GetDimension(result_type);
100 const uint32_t bit_width = _.GetBitWidth(result_type);
101
102 for (size_t operand_index = 2; operand_index < inst->operands().size();
103 ++operand_index) {
104 const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
105 if (!type_id ||
106 (!_.IsIntScalarType(type_id) && !_.IsIntVectorType(type_id) &&
107 !(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type))))
108 return _.diag(SPV_ERROR_INVALID_DATA, inst)
109 << "Expected int scalar or vector type as operand: "
110 << spvOpcodeString(opcode) << " operand index "
111 << operand_index;
112
113 if (_.GetDimension(type_id) != dimension)
114 return _.diag(SPV_ERROR_INVALID_DATA, inst)
115 << "Expected arithmetic operands to have the same dimension "
116 << "as Result Type: " << spvOpcodeString(opcode)
117 << " operand index " << operand_index;
118
119 if (_.GetBitWidth(type_id) != bit_width)
120 return _.diag(SPV_ERROR_INVALID_DATA, inst)
121 << "Expected arithmetic operands to have the same bit width "
122 << "as Result Type: " << spvOpcodeString(opcode)
123 << " operand index " << operand_index;
124 }
125 break;
126 }
127
128 case SpvOpDot: {
129 if (!_.IsFloatScalarType(result_type))
130 return _.diag(SPV_ERROR_INVALID_DATA, inst)
131 << "Expected float scalar type as Result Type: "
132 << spvOpcodeString(opcode);
133
134 uint32_t first_vector_num_components = 0;
135
136 for (size_t operand_index = 2; operand_index < inst->operands().size();
137 ++operand_index) {
138 const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
139
140 if (!type_id || !_.IsFloatVectorType(type_id))
141 return _.diag(SPV_ERROR_INVALID_DATA, inst)
142 << "Expected float vector as operand: "
143 << spvOpcodeString(opcode) << " operand index "
144 << operand_index;
145
146 const uint32_t component_type = _.GetComponentType(type_id);
147 if (component_type != result_type)
148 return _.diag(SPV_ERROR_INVALID_DATA, inst)
149 << "Expected component type to be equal to Result Type: "
150 << spvOpcodeString(opcode) << " operand index "
151 << operand_index;
152
153 const uint32_t num_components = _.GetDimension(type_id);
154 if (operand_index == 2) {
155 first_vector_num_components = num_components;
156 } else if (num_components != first_vector_num_components) {
157 return _.diag(SPV_ERROR_INVALID_DATA, inst)
158 << "Expected operands to have the same number of componenets: "
159 << spvOpcodeString(opcode);
160 }
161 }
162 break;
163 }
164
165 case SpvOpVectorTimesScalar: {
166 if (!_.IsFloatVectorType(result_type))
167 return _.diag(SPV_ERROR_INVALID_DATA, inst)
168 << "Expected float vector type as Result Type: "
169 << spvOpcodeString(opcode);
170
171 const uint32_t vector_type_id = _.GetOperandTypeId(inst, 2);
172 if (result_type != vector_type_id)
173 return _.diag(SPV_ERROR_INVALID_DATA, inst)
174 << "Expected vector operand type to be equal to Result Type: "
175 << spvOpcodeString(opcode);
176
177 const uint32_t component_type = _.GetComponentType(vector_type_id);
178
179 const uint32_t scalar_type_id = _.GetOperandTypeId(inst, 3);
180 if (component_type != scalar_type_id)
181 return _.diag(SPV_ERROR_INVALID_DATA, inst)
182 << "Expected scalar operand type to be equal to the component "
183 << "type of the vector operand: " << spvOpcodeString(opcode);
184
185 break;
186 }
187
188 case SpvOpMatrixTimesScalar: {
189 if (!_.IsFloatMatrixType(result_type) &&
190 !_.IsCooperativeMatrixType(result_type))
191 return _.diag(SPV_ERROR_INVALID_DATA, inst)
192 << "Expected float matrix type as Result Type: "
193 << spvOpcodeString(opcode);
194
195 const uint32_t matrix_type_id = _.GetOperandTypeId(inst, 2);
196 if (result_type != matrix_type_id)
197 return _.diag(SPV_ERROR_INVALID_DATA, inst)
198 << "Expected matrix operand type to be equal to Result Type: "
199 << spvOpcodeString(opcode);
200
201 const uint32_t component_type = _.GetComponentType(matrix_type_id);
202
203 const uint32_t scalar_type_id = _.GetOperandTypeId(inst, 3);
204 if (component_type != scalar_type_id)
205 return _.diag(SPV_ERROR_INVALID_DATA, inst)
206 << "Expected scalar operand type to be equal to the component "
207 << "type of the matrix operand: " << spvOpcodeString(opcode);
208
209 break;
210 }
211
212 case SpvOpVectorTimesMatrix: {
213 const uint32_t vector_type_id = _.GetOperandTypeId(inst, 2);
214 const uint32_t matrix_type_id = _.GetOperandTypeId(inst, 3);
215
216 if (!_.IsFloatVectorType(result_type))
217 return _.diag(SPV_ERROR_INVALID_DATA, inst)
218 << "Expected float vector type as Result Type: "
219 << spvOpcodeString(opcode);
220
221 const uint32_t res_component_type = _.GetComponentType(result_type);
222
223 if (!vector_type_id || !_.IsFloatVectorType(vector_type_id))
224 return _.diag(SPV_ERROR_INVALID_DATA, inst)
225 << "Expected float vector type as left operand: "
226 << spvOpcodeString(opcode);
227
228 if (res_component_type != _.GetComponentType(vector_type_id))
229 return _.diag(SPV_ERROR_INVALID_DATA, inst)
230 << "Expected component types of Result Type and vector to be "
231 << "equal: " << spvOpcodeString(opcode);
232
233 uint32_t matrix_num_rows = 0;
234 uint32_t matrix_num_cols = 0;
235 uint32_t matrix_col_type = 0;
236 uint32_t matrix_component_type = 0;
237 if (!_.GetMatrixTypeInfo(matrix_type_id, &matrix_num_rows,
238 &matrix_num_cols, &matrix_col_type,
239 &matrix_component_type))
240 return _.diag(SPV_ERROR_INVALID_DATA, inst)
241 << "Expected float matrix type as right operand: "
242 << spvOpcodeString(opcode);
243
244 if (res_component_type != matrix_component_type)
245 return _.diag(SPV_ERROR_INVALID_DATA, inst)
246 << "Expected component types of Result Type and matrix to be "
247 << "equal: " << spvOpcodeString(opcode);
248
249 if (matrix_num_cols != _.GetDimension(result_type))
250 return _.diag(SPV_ERROR_INVALID_DATA, inst)
251 << "Expected number of columns of the matrix to be equal to "
252 << "Result Type vector size: " << spvOpcodeString(opcode);
253
254 if (matrix_num_rows != _.GetDimension(vector_type_id))
255 return _.diag(SPV_ERROR_INVALID_DATA, inst)
256 << "Expected number of rows of the matrix to be equal to the "
257 << "vector operand size: " << spvOpcodeString(opcode);
258
259 break;
260 }
261
262 case SpvOpMatrixTimesVector: {
263 const uint32_t matrix_type_id = _.GetOperandTypeId(inst, 2);
264 const uint32_t vector_type_id = _.GetOperandTypeId(inst, 3);
265
266 if (!_.IsFloatVectorType(result_type))
267 return _.diag(SPV_ERROR_INVALID_DATA, inst)
268 << "Expected float vector type as Result Type: "
269 << spvOpcodeString(opcode);
270
271 uint32_t matrix_num_rows = 0;
272 uint32_t matrix_num_cols = 0;
273 uint32_t matrix_col_type = 0;
274 uint32_t matrix_component_type = 0;
275 if (!_.GetMatrixTypeInfo(matrix_type_id, &matrix_num_rows,
276 &matrix_num_cols, &matrix_col_type,
277 &matrix_component_type))
278 return _.diag(SPV_ERROR_INVALID_DATA, inst)
279 << "Expected float matrix type as left operand: "
280 << spvOpcodeString(opcode);
281
282 if (result_type != matrix_col_type)
283 return _.diag(SPV_ERROR_INVALID_DATA, inst)
284 << "Expected column type of the matrix to be equal to Result "
285 "Type: "
286 << spvOpcodeString(opcode);
287
288 if (!vector_type_id || !_.IsFloatVectorType(vector_type_id))
289 return _.diag(SPV_ERROR_INVALID_DATA, inst)
290 << "Expected float vector type as right operand: "
291 << spvOpcodeString(opcode);
292
293 if (matrix_component_type != _.GetComponentType(vector_type_id))
294 return _.diag(SPV_ERROR_INVALID_DATA, inst)
295 << "Expected component types of the operands to be equal: "
296 << spvOpcodeString(opcode);
297
298 if (matrix_num_cols != _.GetDimension(vector_type_id))
299 return _.diag(SPV_ERROR_INVALID_DATA, inst)
300 << "Expected number of columns of the matrix to be equal to the "
301 << "vector size: " << spvOpcodeString(opcode);
302
303 break;
304 }
305
306 case SpvOpMatrixTimesMatrix: {
307 const uint32_t left_type_id = _.GetOperandTypeId(inst, 2);
308 const uint32_t right_type_id = _.GetOperandTypeId(inst, 3);
309
310 uint32_t res_num_rows = 0;
311 uint32_t res_num_cols = 0;
312 uint32_t res_col_type = 0;
313 uint32_t res_component_type = 0;
314 if (!_.GetMatrixTypeInfo(result_type, &res_num_rows, &res_num_cols,
315 &res_col_type, &res_component_type))
316 return _.diag(SPV_ERROR_INVALID_DATA, inst)
317 << "Expected float matrix type as Result Type: "
318 << spvOpcodeString(opcode);
319
320 uint32_t left_num_rows = 0;
321 uint32_t left_num_cols = 0;
322 uint32_t left_col_type = 0;
323 uint32_t left_component_type = 0;
324 if (!_.GetMatrixTypeInfo(left_type_id, &left_num_rows, &left_num_cols,
325 &left_col_type, &left_component_type))
326 return _.diag(SPV_ERROR_INVALID_DATA, inst)
327 << "Expected float matrix type as left operand: "
328 << spvOpcodeString(opcode);
329
330 uint32_t right_num_rows = 0;
331 uint32_t right_num_cols = 0;
332 uint32_t right_col_type = 0;
333 uint32_t right_component_type = 0;
334 if (!_.GetMatrixTypeInfo(right_type_id, &right_num_rows, &right_num_cols,
335 &right_col_type, &right_component_type))
336 return _.diag(SPV_ERROR_INVALID_DATA, inst)
337 << "Expected float matrix type as right operand: "
338 << spvOpcodeString(opcode);
339
340 if (!_.IsFloatScalarType(res_component_type))
341 return _.diag(SPV_ERROR_INVALID_DATA, inst)
342 << "Expected float matrix type as Result Type: "
343 << spvOpcodeString(opcode);
344
345 if (res_col_type != left_col_type)
346 return _.diag(SPV_ERROR_INVALID_DATA, inst)
347 << "Expected column types of Result Type and left matrix to be "
348 << "equal: " << spvOpcodeString(opcode);
349
350 if (res_component_type != right_component_type)
351 return _.diag(SPV_ERROR_INVALID_DATA, inst)
352 << "Expected component types of Result Type and right matrix to "
353 "be "
354 << "equal: " << spvOpcodeString(opcode);
355
356 if (res_num_cols != right_num_cols)
357 return _.diag(SPV_ERROR_INVALID_DATA, inst)
358 << "Expected number of columns of Result Type and right matrix "
359 "to "
360 << "be equal: " << spvOpcodeString(opcode);
361
362 if (left_num_cols != right_num_rows)
363 return _.diag(SPV_ERROR_INVALID_DATA, inst)
364 << "Expected number of columns of left matrix and number of "
365 "rows "
366 << "of right matrix to be equal: " << spvOpcodeString(opcode);
367
368 assert(left_num_rows == res_num_rows);
369 break;
370 }
371
372 case SpvOpOuterProduct: {
373 const uint32_t left_type_id = _.GetOperandTypeId(inst, 2);
374 const uint32_t right_type_id = _.GetOperandTypeId(inst, 3);
375
376 uint32_t res_num_rows = 0;
377 uint32_t res_num_cols = 0;
378 uint32_t res_col_type = 0;
379 uint32_t res_component_type = 0;
380 if (!_.GetMatrixTypeInfo(result_type, &res_num_rows, &res_num_cols,
381 &res_col_type, &res_component_type))
382 return _.diag(SPV_ERROR_INVALID_DATA, inst)
383 << "Expected float matrix type as Result Type: "
384 << spvOpcodeString(opcode);
385
386 if (left_type_id != res_col_type)
387 return _.diag(SPV_ERROR_INVALID_DATA, inst)
388 << "Expected column type of Result Type to be equal to the type "
389 << "of the left operand: " << spvOpcodeString(opcode);
390
391 if (!right_type_id || !_.IsFloatVectorType(right_type_id))
392 return _.diag(SPV_ERROR_INVALID_DATA, inst)
393 << "Expected float vector type as right operand: "
394 << spvOpcodeString(opcode);
395
396 if (res_component_type != _.GetComponentType(right_type_id))
397 return _.diag(SPV_ERROR_INVALID_DATA, inst)
398 << "Expected component types of the operands to be equal: "
399 << spvOpcodeString(opcode);
400
401 if (res_num_cols != _.GetDimension(right_type_id))
402 return _.diag(SPV_ERROR_INVALID_DATA, inst)
403 << "Expected number of columns of the matrix to be equal to the "
404 << "vector size of the right operand: "
405 << spvOpcodeString(opcode);
406
407 break;
408 }
409
410 case SpvOpIAddCarry:
411 case SpvOpISubBorrow:
412 case SpvOpUMulExtended:
413 case SpvOpSMulExtended: {
414 std::vector<uint32_t> result_types;
415 if (!_.GetStructMemberTypes(result_type, &result_types))
416 return _.diag(SPV_ERROR_INVALID_DATA, inst)
417 << "Expected a struct as Result Type: "
418 << spvOpcodeString(opcode);
419
420 if (result_types.size() != 2)
421 return _.diag(SPV_ERROR_INVALID_DATA, inst)
422 << "Expected Result Type struct to have two members: "
423 << spvOpcodeString(opcode);
424
425 if (opcode == SpvOpSMulExtended) {
426 if (!_.IsIntScalarType(result_types[0]) &&
427 !_.IsIntVectorType(result_types[0]))
428 return _.diag(SPV_ERROR_INVALID_DATA, inst)
429 << "Expected Result Type struct member types to be integer "
430 "scalar "
431 << "or vector: " << spvOpcodeString(opcode);
432 } else {
433 if (!_.IsUnsignedIntScalarType(result_types[0]) &&
434 !_.IsUnsignedIntVectorType(result_types[0]))
435 return _.diag(SPV_ERROR_INVALID_DATA, inst)
436 << "Expected Result Type struct member types to be unsigned "
437 << "integer scalar or vector: " << spvOpcodeString(opcode);
438 }
439
440 if (result_types[0] != result_types[1])
441 return _.diag(SPV_ERROR_INVALID_DATA, inst)
442 << "Expected Result Type struct member types to be identical: "
443 << spvOpcodeString(opcode);
444
445 const uint32_t left_type_id = _.GetOperandTypeId(inst, 2);
446 const uint32_t right_type_id = _.GetOperandTypeId(inst, 3);
447
448 if (left_type_id != result_types[0] || right_type_id != result_types[0])
449 return _.diag(SPV_ERROR_INVALID_DATA, inst)
450 << "Expected both operands to be of Result Type member type: "
451 << spvOpcodeString(opcode);
452
453 break;
454 }
455
456 case SpvOpCooperativeMatrixMulAddNV: {
457 const uint32_t D_type_id = _.GetOperandTypeId(inst, 1);
458 const uint32_t A_type_id = _.GetOperandTypeId(inst, 2);
459 const uint32_t B_type_id = _.GetOperandTypeId(inst, 3);
460 const uint32_t C_type_id = _.GetOperandTypeId(inst, 4);
461
462 if (!_.IsCooperativeMatrixType(A_type_id)) {
463 return _.diag(SPV_ERROR_INVALID_DATA, inst)
464 << "Expected cooperative matrix type as A Type: "
465 << spvOpcodeString(opcode);
466 }
467 if (!_.IsCooperativeMatrixType(B_type_id)) {
468 return _.diag(SPV_ERROR_INVALID_DATA, inst)
469 << "Expected cooperative matrix type as B Type: "
470 << spvOpcodeString(opcode);
471 }
472 if (!_.IsCooperativeMatrixType(C_type_id)) {
473 return _.diag(SPV_ERROR_INVALID_DATA, inst)
474 << "Expected cooperative matrix type as C Type: "
475 << spvOpcodeString(opcode);
476 }
477 if (!_.IsCooperativeMatrixType(D_type_id)) {
478 return _.diag(SPV_ERROR_INVALID_DATA, inst)
479 << "Expected cooperative matrix type as Result Type: "
480 << spvOpcodeString(opcode);
481 }
482
483 const auto A = _.FindDef(A_type_id);
484 const auto B = _.FindDef(B_type_id);
485 const auto C = _.FindDef(C_type_id);
486 const auto D = _.FindDef(D_type_id);
487
488 std::tuple<bool, bool, uint32_t> A_scope, B_scope, C_scope, D_scope,
489 A_rows, B_rows, C_rows, D_rows, A_cols, B_cols, C_cols, D_cols;
490
491 A_scope = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(2));
492 B_scope = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(2));
493 C_scope = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(2));
494 D_scope = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(2));
495
496 A_rows = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(3));
497 B_rows = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(3));
498 C_rows = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(3));
499 D_rows = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(3));
500
501 A_cols = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(4));
502 B_cols = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(4));
503 C_cols = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(4));
504 D_cols = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(4));
505
506 const auto notEqual = [](std::tuple<bool, bool, uint32_t> X,
507 std::tuple<bool, bool, uint32_t> Y) {
508 return (std::get<1>(X) && std::get<1>(Y) &&
509 std::get<2>(X) != std::get<2>(Y));
510 };
511
512 if (notEqual(A_scope, B_scope) || notEqual(A_scope, C_scope) ||
513 notEqual(A_scope, D_scope) || notEqual(B_scope, C_scope) ||
514 notEqual(B_scope, D_scope) || notEqual(C_scope, D_scope)) {
515 return _.diag(SPV_ERROR_INVALID_DATA, inst)
516 << "Cooperative matrix scopes must match: "
517 << spvOpcodeString(opcode);
518 }
519
520 if (notEqual(A_rows, C_rows) || notEqual(A_rows, D_rows) ||
521 notEqual(C_rows, D_rows)) {
522 return _.diag(SPV_ERROR_INVALID_DATA, inst)
523 << "Cooperative matrix 'M' mismatch: "
524 << spvOpcodeString(opcode);
525 }
526
527 if (notEqual(B_cols, C_cols) || notEqual(B_cols, D_cols) ||
528 notEqual(C_cols, D_cols)) {
529 return _.diag(SPV_ERROR_INVALID_DATA, inst)
530 << "Cooperative matrix 'N' mismatch: "
531 << spvOpcodeString(opcode);
532 }
533
534 if (notEqual(A_cols, B_rows)) {
535 return _.diag(SPV_ERROR_INVALID_DATA, inst)
536 << "Cooperative matrix 'K' mismatch: "
537 << spvOpcodeString(opcode);
538 }
539 break;
540 }
541
542 default:
543 break;
544 }
545
546 return SPV_SUCCESS;
547}
548
549} // namespace val
550} // namespace spvtools
551