1// Copyright (c) 2017 Google Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// Validates correctness of derivative SPIR-V instructions.
16
17#include "source/val/validate.h"
18
19#include <string>
20
21#include "source/diagnostic.h"
22#include "source/opcode.h"
23#include "source/val/instruction.h"
24#include "source/val/validation_state.h"
25
26namespace spvtools {
27namespace val {
28
29// Validates correctness of derivative instructions.
30spv_result_t DerivativesPass(ValidationState_t& _, const Instruction* inst) {
31 const SpvOp opcode = inst->opcode();
32 const uint32_t result_type = inst->type_id();
33
34 switch (opcode) {
35 case SpvOpDPdx:
36 case SpvOpDPdy:
37 case SpvOpFwidth:
38 case SpvOpDPdxFine:
39 case SpvOpDPdyFine:
40 case SpvOpFwidthFine:
41 case SpvOpDPdxCoarse:
42 case SpvOpDPdyCoarse:
43 case SpvOpFwidthCoarse: {
44 if (!_.IsFloatScalarOrVectorType(result_type)) {
45 return _.diag(SPV_ERROR_INVALID_DATA, inst)
46 << "Expected Result Type to be float scalar or vector type: "
47 << spvOpcodeString(opcode);
48 }
49 if (!_.ContainsSizedIntOrFloatType(result_type, SpvOpTypeFloat, 32)) {
50 return _.diag(SPV_ERROR_INVALID_DATA, inst)
51 << "Result type component width must be 32 bits";
52 }
53
54 const uint32_t p_type = _.GetOperandTypeId(inst, 2);
55 if (p_type != result_type) {
56 return _.diag(SPV_ERROR_INVALID_DATA, inst)
57 << "Expected P type and Result Type to be the same: "
58 << spvOpcodeString(opcode);
59 }
60 _.function(inst->function()->id())
61 ->RegisterExecutionModelLimitation([opcode](SpvExecutionModel model,
62 std::string* message) {
63 if (model != SpvExecutionModelFragment &&
64 model != SpvExecutionModelGLCompute) {
65 if (message) {
66 *message =
67 std::string(
68 "Derivative instructions require Fragment or GLCompute "
69 "execution model: ") +
70 spvOpcodeString(opcode);
71 }
72 return false;
73 }
74 return true;
75 });
76 _.function(inst->function()->id())
77 ->RegisterLimitation([opcode](const ValidationState_t& state,
78 const Function* entry_point,
79 std::string* message) {
80 const auto* models = state.GetExecutionModels(entry_point->id());
81 const auto* modes = state.GetExecutionModes(entry_point->id());
82 if (models->find(SpvExecutionModelGLCompute) != models->end() &&
83 modes->find(SpvExecutionModeDerivativeGroupLinearNV) ==
84 modes->end() &&
85 modes->find(SpvExecutionModeDerivativeGroupQuadsNV) ==
86 modes->end()) {
87 if (message) {
88 *message = std::string(
89 "Derivative instructions require "
90 "DerivativeGroupQuadsNV "
91 "or DerivativeGroupLinearNV execution mode for "
92 "GLCompute execution model: ") +
93 spvOpcodeString(opcode);
94 }
95 return false;
96 }
97 return true;
98 });
99 break;
100 }
101
102 default:
103 break;
104 }
105
106 return SPV_SUCCESS;
107}
108
109} // namespace val
110} // namespace spvtools
111