1// Copyright (c) 2018 Google LLC.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include <algorithm>
16
17#include "source/opcode.h"
18#include "source/val/instruction.h"
19#include "source/val/validate.h"
20#include "source/val/validation_state.h"
21
22namespace spvtools {
23namespace val {
24namespace {
25
26// Returns true if |a| and |b| are instructions defining pointers that point to
27// types logically match and the decorations that apply to |b| are a subset
28// of the decorations that apply to |a|.
29bool DoPointeesLogicallyMatch(val::Instruction* a, val::Instruction* b,
30 ValidationState_t& _) {
31 if (a->opcode() != SpvOpTypePointer || b->opcode() != SpvOpTypePointer) {
32 return false;
33 }
34
35 const auto& dec_a = _.id_decorations(a->id());
36 const auto& dec_b = _.id_decorations(b->id());
37 for (const auto& dec : dec_b) {
38 if (std::find(dec_a.begin(), dec_a.end(), dec) == dec_a.end()) {
39 return false;
40 }
41 }
42
43 uint32_t a_type = a->GetOperandAs<uint32_t>(2);
44 uint32_t b_type = b->GetOperandAs<uint32_t>(2);
45
46 if (a_type == b_type) {
47 return true;
48 }
49
50 Instruction* a_type_inst = _.FindDef(a_type);
51 Instruction* b_type_inst = _.FindDef(b_type);
52
53 return _.LogicallyMatch(a_type_inst, b_type_inst, true);
54}
55
56spv_result_t ValidateFunction(ValidationState_t& _, const Instruction* inst) {
57 const auto function_type_id = inst->GetOperandAs<uint32_t>(3);
58 const auto function_type = _.FindDef(function_type_id);
59 if (!function_type || SpvOpTypeFunction != function_type->opcode()) {
60 return _.diag(SPV_ERROR_INVALID_ID, inst)
61 << "OpFunction Function Type <id> '" << _.getIdName(function_type_id)
62 << "' is not a function type.";
63 }
64
65 const auto return_id = function_type->GetOperandAs<uint32_t>(1);
66 if (return_id != inst->type_id()) {
67 return _.diag(SPV_ERROR_INVALID_ID, inst)
68 << "OpFunction Result Type <id> '" << _.getIdName(inst->type_id())
69 << "' does not match the Function Type's return type <id> '"
70 << _.getIdName(return_id) << "'.";
71 }
72
73 const std::vector<SpvOp> acceptable = {
74 SpvOpDecorate,
75 SpvOpEnqueueKernel,
76 SpvOpEntryPoint,
77 SpvOpExecutionMode,
78 SpvOpExecutionModeId,
79 SpvOpFunctionCall,
80 SpvOpGetKernelNDrangeSubGroupCount,
81 SpvOpGetKernelNDrangeMaxSubGroupSize,
82 SpvOpGetKernelWorkGroupSize,
83 SpvOpGetKernelPreferredWorkGroupSizeMultiple,
84 SpvOpGetKernelLocalSizeForSubgroupCount,
85 SpvOpGetKernelMaxNumSubgroups,
86 SpvOpName};
87 for (auto& pair : inst->uses()) {
88 const auto* use = pair.first;
89 if (std::find(acceptable.begin(), acceptable.end(), use->opcode()) ==
90 acceptable.end() &&
91 !use->IsNonSemantic() && !use->IsDebugInfo()) {
92 return _.diag(SPV_ERROR_INVALID_ID, use)
93 << "Invalid use of function result id " << _.getIdName(inst->id())
94 << ".";
95 }
96 }
97
98 return SPV_SUCCESS;
99}
100
101spv_result_t ValidateFunctionParameter(ValidationState_t& _,
102 const Instruction* inst) {
103 // NOTE: Find OpFunction & ensure OpFunctionParameter is not out of place.
104 size_t param_index = 0;
105 size_t inst_num = inst->LineNum() - 1;
106 if (inst_num == 0) {
107 return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
108 << "Function parameter cannot be the first instruction.";
109 }
110
111 auto func_inst = &_.ordered_instructions()[inst_num];
112 while (--inst_num) {
113 func_inst = &_.ordered_instructions()[inst_num];
114 if (func_inst->opcode() == SpvOpFunction) {
115 break;
116 } else if (func_inst->opcode() == SpvOpFunctionParameter) {
117 ++param_index;
118 }
119 }
120
121 if (func_inst->opcode() != SpvOpFunction) {
122 return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
123 << "Function parameter must be preceded by a function.";
124 }
125
126 const auto function_type_id = func_inst->GetOperandAs<uint32_t>(3);
127 const auto function_type = _.FindDef(function_type_id);
128 if (!function_type) {
129 return _.diag(SPV_ERROR_INVALID_ID, func_inst)
130 << "Missing function type definition.";
131 }
132 if (param_index >= function_type->words().size() - 3) {
133 return _.diag(SPV_ERROR_INVALID_ID, inst)
134 << "Too many OpFunctionParameters for " << func_inst->id()
135 << ": expected " << function_type->words().size() - 3
136 << " based on the function's type";
137 }
138
139 const auto param_type =
140 _.FindDef(function_type->GetOperandAs<uint32_t>(param_index + 2));
141 if (!param_type || inst->type_id() != param_type->id()) {
142 return _.diag(SPV_ERROR_INVALID_ID, inst)
143 << "OpFunctionParameter Result Type <id> '"
144 << _.getIdName(inst->type_id())
145 << "' does not match the OpTypeFunction parameter "
146 "type of the same index.";
147 }
148
149 // Validate that PhysicalStorageBufferEXT have one of Restrict, Aliased,
150 // RestrictPointerEXT, or AliasedPointerEXT.
151 auto param_nonarray_type_id = param_type->id();
152 while (_.GetIdOpcode(param_nonarray_type_id) == SpvOpTypeArray) {
153 param_nonarray_type_id =
154 _.FindDef(param_nonarray_type_id)->GetOperandAs<uint32_t>(1u);
155 }
156 if (_.GetIdOpcode(param_nonarray_type_id) == SpvOpTypePointer) {
157 auto param_nonarray_type = _.FindDef(param_nonarray_type_id);
158 if (param_nonarray_type->GetOperandAs<uint32_t>(1u) ==
159 SpvStorageClassPhysicalStorageBufferEXT) {
160 // check for Aliased or Restrict
161 const auto& decorations = _.id_decorations(inst->id());
162
163 bool foundAliased = std::any_of(
164 decorations.begin(), decorations.end(), [](const Decoration& d) {
165 return SpvDecorationAliased == d.dec_type();
166 });
167
168 bool foundRestrict = std::any_of(
169 decorations.begin(), decorations.end(), [](const Decoration& d) {
170 return SpvDecorationRestrict == d.dec_type();
171 });
172
173 if (!foundAliased && !foundRestrict) {
174 return _.diag(SPV_ERROR_INVALID_ID, inst)
175 << "OpFunctionParameter " << inst->id()
176 << ": expected Aliased or Restrict for PhysicalStorageBufferEXT "
177 "pointer.";
178 }
179 if (foundAliased && foundRestrict) {
180 return _.diag(SPV_ERROR_INVALID_ID, inst)
181 << "OpFunctionParameter " << inst->id()
182 << ": can't specify both Aliased and Restrict for "
183 "PhysicalStorageBufferEXT pointer.";
184 }
185 } else {
186 const auto pointee_type_id =
187 param_nonarray_type->GetOperandAs<uint32_t>(2);
188 const auto pointee_type = _.FindDef(pointee_type_id);
189 if (SpvOpTypePointer == pointee_type->opcode() &&
190 pointee_type->GetOperandAs<uint32_t>(1u) ==
191 SpvStorageClassPhysicalStorageBufferEXT) {
192 // check for AliasedPointerEXT/RestrictPointerEXT
193 const auto& decorations = _.id_decorations(inst->id());
194
195 bool foundAliased = std::any_of(
196 decorations.begin(), decorations.end(), [](const Decoration& d) {
197 return SpvDecorationAliasedPointerEXT == d.dec_type();
198 });
199
200 bool foundRestrict = std::any_of(
201 decorations.begin(), decorations.end(), [](const Decoration& d) {
202 return SpvDecorationRestrictPointerEXT == d.dec_type();
203 });
204
205 if (!foundAliased && !foundRestrict) {
206 return _.diag(SPV_ERROR_INVALID_ID, inst)
207 << "OpFunctionParameter " << inst->id()
208 << ": expected AliasedPointerEXT or RestrictPointerEXT for "
209 "PhysicalStorageBufferEXT pointer.";
210 }
211 if (foundAliased && foundRestrict) {
212 return _.diag(SPV_ERROR_INVALID_ID, inst)
213 << "OpFunctionParameter " << inst->id()
214 << ": can't specify both AliasedPointerEXT and "
215 "RestrictPointerEXT for PhysicalStorageBufferEXT pointer.";
216 }
217 }
218 }
219 }
220
221 return SPV_SUCCESS;
222}
223
224spv_result_t ValidateFunctionCall(ValidationState_t& _,
225 const Instruction* inst) {
226 const auto function_id = inst->GetOperandAs<uint32_t>(2);
227 const auto function = _.FindDef(function_id);
228 if (!function || SpvOpFunction != function->opcode()) {
229 return _.diag(SPV_ERROR_INVALID_ID, inst)
230 << "OpFunctionCall Function <id> '" << _.getIdName(function_id)
231 << "' is not a function.";
232 }
233
234 auto return_type = _.FindDef(function->type_id());
235 if (!return_type || return_type->id() != inst->type_id()) {
236 return _.diag(SPV_ERROR_INVALID_ID, inst)
237 << "OpFunctionCall Result Type <id> '"
238 << _.getIdName(inst->type_id())
239 << "'s type does not match Function <id> '"
240 << _.getIdName(return_type->id()) << "'s return type.";
241 }
242
243 const auto function_type_id = function->GetOperandAs<uint32_t>(3);
244 const auto function_type = _.FindDef(function_type_id);
245 if (!function_type || function_type->opcode() != SpvOpTypeFunction) {
246 return _.diag(SPV_ERROR_INVALID_ID, inst)
247 << "Missing function type definition.";
248 }
249
250 const auto function_call_arg_count = inst->words().size() - 4;
251 const auto function_param_count = function_type->words().size() - 3;
252 if (function_param_count != function_call_arg_count) {
253 return _.diag(SPV_ERROR_INVALID_ID, inst)
254 << "OpFunctionCall Function <id>'s parameter count does not match "
255 "the argument count.";
256 }
257
258 for (size_t argument_index = 3, param_index = 2;
259 argument_index < inst->operands().size();
260 argument_index++, param_index++) {
261 const auto argument_id = inst->GetOperandAs<uint32_t>(argument_index);
262 const auto argument = _.FindDef(argument_id);
263 if (!argument) {
264 return _.diag(SPV_ERROR_INVALID_ID, inst)
265 << "Missing argument " << argument_index - 3 << " definition.";
266 }
267
268 const auto argument_type = _.FindDef(argument->type_id());
269 if (!argument_type) {
270 return _.diag(SPV_ERROR_INVALID_ID, inst)
271 << "Missing argument " << argument_index - 3
272 << " type definition.";
273 }
274
275 const auto parameter_type_id =
276 function_type->GetOperandAs<uint32_t>(param_index);
277 const auto parameter_type = _.FindDef(parameter_type_id);
278 if (!parameter_type || argument_type->id() != parameter_type->id()) {
279 if (!_.options()->before_hlsl_legalization ||
280 !DoPointeesLogicallyMatch(argument_type, parameter_type, _)) {
281 return _.diag(SPV_ERROR_INVALID_ID, inst)
282 << "OpFunctionCall Argument <id> '" << _.getIdName(argument_id)
283 << "'s type does not match Function <id> '"
284 << _.getIdName(parameter_type_id) << "'s parameter type.";
285 }
286 }
287
288 if (_.addressing_model() == SpvAddressingModelLogical) {
289 if (parameter_type->opcode() == SpvOpTypePointer &&
290 !_.options()->relax_logical_pointer) {
291 SpvStorageClass sc = parameter_type->GetOperandAs<SpvStorageClass>(1u);
292 // Validate which storage classes can be pointer operands.
293 switch (sc) {
294 case SpvStorageClassUniformConstant:
295 case SpvStorageClassFunction:
296 case SpvStorageClassPrivate:
297 case SpvStorageClassWorkgroup:
298 case SpvStorageClassAtomicCounter:
299 // These are always allowed.
300 break;
301 case SpvStorageClassStorageBuffer:
302 if (!_.features().variable_pointers_storage_buffer) {
303 return _.diag(SPV_ERROR_INVALID_ID, inst)
304 << "StorageBuffer pointer operand "
305 << _.getIdName(argument_id)
306 << " requires a variable pointers capability";
307 }
308 break;
309 default:
310 return _.diag(SPV_ERROR_INVALID_ID, inst)
311 << "Invalid storage class for pointer operand "
312 << _.getIdName(argument_id);
313 }
314
315 // Validate memory object declaration requirements.
316 if (argument->opcode() != SpvOpVariable &&
317 argument->opcode() != SpvOpFunctionParameter) {
318 const bool ssbo_vptr =
319 _.features().variable_pointers_storage_buffer &&
320 sc == SpvStorageClassStorageBuffer;
321 const bool wg_vptr =
322 _.features().variable_pointers && sc == SpvStorageClassWorkgroup;
323 const bool uc_ptr = sc == SpvStorageClassUniformConstant;
324 if (!ssbo_vptr && !wg_vptr && !uc_ptr) {
325 return _.diag(SPV_ERROR_INVALID_ID, inst)
326 << "Pointer operand " << _.getIdName(argument_id)
327 << " must be a memory object declaration";
328 }
329 }
330 }
331 }
332 }
333 return SPV_SUCCESS;
334}
335
336} // namespace
337
338spv_result_t FunctionPass(ValidationState_t& _, const Instruction* inst) {
339 switch (inst->opcode()) {
340 case SpvOpFunction:
341 if (auto error = ValidateFunction(_, inst)) return error;
342 break;
343 case SpvOpFunctionParameter:
344 if (auto error = ValidateFunctionParameter(_, inst)) return error;
345 break;
346 case SpvOpFunctionCall:
347 if (auto error = ValidateFunctionCall(_, inst)) return error;
348 break;
349 default:
350 break;
351 }
352
353 return SPV_SUCCESS;
354}
355
356} // namespace val
357} // namespace spvtools
358