1 | // Copyright (c) 2018 Google LLC. |
2 | // |
3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | // you may not use this file except in compliance with the License. |
5 | // You may obtain a copy of the License at |
6 | // |
7 | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | // |
9 | // Unless required by applicable law or agreed to in writing, software |
10 | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | // See the License for the specific language governing permissions and |
13 | // limitations under the License. |
14 | |
15 | #include <algorithm> |
16 | |
17 | #include "source/opcode.h" |
18 | #include "source/val/instruction.h" |
19 | #include "source/val/validate.h" |
20 | #include "source/val/validation_state.h" |
21 | |
22 | namespace spvtools { |
23 | namespace val { |
24 | namespace { |
25 | |
26 | // Returns true if |a| and |b| are instructions defining pointers that point to |
27 | // types logically match and the decorations that apply to |b| are a subset |
28 | // of the decorations that apply to |a|. |
29 | bool DoPointeesLogicallyMatch(val::Instruction* a, val::Instruction* b, |
30 | ValidationState_t& _) { |
31 | if (a->opcode() != SpvOpTypePointer || b->opcode() != SpvOpTypePointer) { |
32 | return false; |
33 | } |
34 | |
35 | const auto& dec_a = _.id_decorations(a->id()); |
36 | const auto& dec_b = _.id_decorations(b->id()); |
37 | for (const auto& dec : dec_b) { |
38 | if (std::find(dec_a.begin(), dec_a.end(), dec) == dec_a.end()) { |
39 | return false; |
40 | } |
41 | } |
42 | |
43 | uint32_t a_type = a->GetOperandAs<uint32_t>(2); |
44 | uint32_t b_type = b->GetOperandAs<uint32_t>(2); |
45 | |
46 | if (a_type == b_type) { |
47 | return true; |
48 | } |
49 | |
50 | Instruction* a_type_inst = _.FindDef(a_type); |
51 | Instruction* b_type_inst = _.FindDef(b_type); |
52 | |
53 | return _.LogicallyMatch(a_type_inst, b_type_inst, true); |
54 | } |
55 | |
56 | spv_result_t ValidateFunction(ValidationState_t& _, const Instruction* inst) { |
57 | const auto function_type_id = inst->GetOperandAs<uint32_t>(3); |
58 | const auto function_type = _.FindDef(function_type_id); |
59 | if (!function_type || SpvOpTypeFunction != function_type->opcode()) { |
60 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
61 | << "OpFunction Function Type <id> '" << _.getIdName(function_type_id) |
62 | << "' is not a function type." ; |
63 | } |
64 | |
65 | const auto return_id = function_type->GetOperandAs<uint32_t>(1); |
66 | if (return_id != inst->type_id()) { |
67 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
68 | << "OpFunction Result Type <id> '" << _.getIdName(inst->type_id()) |
69 | << "' does not match the Function Type's return type <id> '" |
70 | << _.getIdName(return_id) << "'." ; |
71 | } |
72 | |
73 | const std::vector<SpvOp> acceptable = { |
74 | SpvOpDecorate, |
75 | SpvOpEnqueueKernel, |
76 | SpvOpEntryPoint, |
77 | SpvOpExecutionMode, |
78 | SpvOpExecutionModeId, |
79 | SpvOpFunctionCall, |
80 | SpvOpGetKernelNDrangeSubGroupCount, |
81 | SpvOpGetKernelNDrangeMaxSubGroupSize, |
82 | SpvOpGetKernelWorkGroupSize, |
83 | SpvOpGetKernelPreferredWorkGroupSizeMultiple, |
84 | SpvOpGetKernelLocalSizeForSubgroupCount, |
85 | SpvOpGetKernelMaxNumSubgroups, |
86 | SpvOpName}; |
87 | for (auto& pair : inst->uses()) { |
88 | const auto* use = pair.first; |
89 | if (std::find(acceptable.begin(), acceptable.end(), use->opcode()) == |
90 | acceptable.end() && |
91 | !use->IsNonSemantic() && !use->IsDebugInfo()) { |
92 | return _.diag(SPV_ERROR_INVALID_ID, use) |
93 | << "Invalid use of function result id " << _.getIdName(inst->id()) |
94 | << "." ; |
95 | } |
96 | } |
97 | |
98 | return SPV_SUCCESS; |
99 | } |
100 | |
101 | spv_result_t ValidateFunctionParameter(ValidationState_t& _, |
102 | const Instruction* inst) { |
103 | // NOTE: Find OpFunction & ensure OpFunctionParameter is not out of place. |
104 | size_t param_index = 0; |
105 | size_t inst_num = inst->LineNum() - 1; |
106 | if (inst_num == 0) { |
107 | return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) |
108 | << "Function parameter cannot be the first instruction." ; |
109 | } |
110 | |
111 | auto func_inst = &_.ordered_instructions()[inst_num]; |
112 | while (--inst_num) { |
113 | func_inst = &_.ordered_instructions()[inst_num]; |
114 | if (func_inst->opcode() == SpvOpFunction) { |
115 | break; |
116 | } else if (func_inst->opcode() == SpvOpFunctionParameter) { |
117 | ++param_index; |
118 | } |
119 | } |
120 | |
121 | if (func_inst->opcode() != SpvOpFunction) { |
122 | return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) |
123 | << "Function parameter must be preceded by a function." ; |
124 | } |
125 | |
126 | const auto function_type_id = func_inst->GetOperandAs<uint32_t>(3); |
127 | const auto function_type = _.FindDef(function_type_id); |
128 | if (!function_type) { |
129 | return _.diag(SPV_ERROR_INVALID_ID, func_inst) |
130 | << "Missing function type definition." ; |
131 | } |
132 | if (param_index >= function_type->words().size() - 3) { |
133 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
134 | << "Too many OpFunctionParameters for " << func_inst->id() |
135 | << ": expected " << function_type->words().size() - 3 |
136 | << " based on the function's type" ; |
137 | } |
138 | |
139 | const auto param_type = |
140 | _.FindDef(function_type->GetOperandAs<uint32_t>(param_index + 2)); |
141 | if (!param_type || inst->type_id() != param_type->id()) { |
142 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
143 | << "OpFunctionParameter Result Type <id> '" |
144 | << _.getIdName(inst->type_id()) |
145 | << "' does not match the OpTypeFunction parameter " |
146 | "type of the same index." ; |
147 | } |
148 | |
149 | // Validate that PhysicalStorageBufferEXT have one of Restrict, Aliased, |
150 | // RestrictPointerEXT, or AliasedPointerEXT. |
151 | auto param_nonarray_type_id = param_type->id(); |
152 | while (_.GetIdOpcode(param_nonarray_type_id) == SpvOpTypeArray) { |
153 | param_nonarray_type_id = |
154 | _.FindDef(param_nonarray_type_id)->GetOperandAs<uint32_t>(1u); |
155 | } |
156 | if (_.GetIdOpcode(param_nonarray_type_id) == SpvOpTypePointer) { |
157 | auto param_nonarray_type = _.FindDef(param_nonarray_type_id); |
158 | if (param_nonarray_type->GetOperandAs<uint32_t>(1u) == |
159 | SpvStorageClassPhysicalStorageBufferEXT) { |
160 | // check for Aliased or Restrict |
161 | const auto& decorations = _.id_decorations(inst->id()); |
162 | |
163 | bool foundAliased = std::any_of( |
164 | decorations.begin(), decorations.end(), [](const Decoration& d) { |
165 | return SpvDecorationAliased == d.dec_type(); |
166 | }); |
167 | |
168 | bool foundRestrict = std::any_of( |
169 | decorations.begin(), decorations.end(), [](const Decoration& d) { |
170 | return SpvDecorationRestrict == d.dec_type(); |
171 | }); |
172 | |
173 | if (!foundAliased && !foundRestrict) { |
174 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
175 | << "OpFunctionParameter " << inst->id() |
176 | << ": expected Aliased or Restrict for PhysicalStorageBufferEXT " |
177 | "pointer." ; |
178 | } |
179 | if (foundAliased && foundRestrict) { |
180 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
181 | << "OpFunctionParameter " << inst->id() |
182 | << ": can't specify both Aliased and Restrict for " |
183 | "PhysicalStorageBufferEXT pointer." ; |
184 | } |
185 | } else { |
186 | const auto pointee_type_id = |
187 | param_nonarray_type->GetOperandAs<uint32_t>(2); |
188 | const auto pointee_type = _.FindDef(pointee_type_id); |
189 | if (SpvOpTypePointer == pointee_type->opcode() && |
190 | pointee_type->GetOperandAs<uint32_t>(1u) == |
191 | SpvStorageClassPhysicalStorageBufferEXT) { |
192 | // check for AliasedPointerEXT/RestrictPointerEXT |
193 | const auto& decorations = _.id_decorations(inst->id()); |
194 | |
195 | bool foundAliased = std::any_of( |
196 | decorations.begin(), decorations.end(), [](const Decoration& d) { |
197 | return SpvDecorationAliasedPointerEXT == d.dec_type(); |
198 | }); |
199 | |
200 | bool foundRestrict = std::any_of( |
201 | decorations.begin(), decorations.end(), [](const Decoration& d) { |
202 | return SpvDecorationRestrictPointerEXT == d.dec_type(); |
203 | }); |
204 | |
205 | if (!foundAliased && !foundRestrict) { |
206 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
207 | << "OpFunctionParameter " << inst->id() |
208 | << ": expected AliasedPointerEXT or RestrictPointerEXT for " |
209 | "PhysicalStorageBufferEXT pointer." ; |
210 | } |
211 | if (foundAliased && foundRestrict) { |
212 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
213 | << "OpFunctionParameter " << inst->id() |
214 | << ": can't specify both AliasedPointerEXT and " |
215 | "RestrictPointerEXT for PhysicalStorageBufferEXT pointer." ; |
216 | } |
217 | } |
218 | } |
219 | } |
220 | |
221 | return SPV_SUCCESS; |
222 | } |
223 | |
224 | spv_result_t ValidateFunctionCall(ValidationState_t& _, |
225 | const Instruction* inst) { |
226 | const auto function_id = inst->GetOperandAs<uint32_t>(2); |
227 | const auto function = _.FindDef(function_id); |
228 | if (!function || SpvOpFunction != function->opcode()) { |
229 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
230 | << "OpFunctionCall Function <id> '" << _.getIdName(function_id) |
231 | << "' is not a function." ; |
232 | } |
233 | |
234 | auto return_type = _.FindDef(function->type_id()); |
235 | if (!return_type || return_type->id() != inst->type_id()) { |
236 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
237 | << "OpFunctionCall Result Type <id> '" |
238 | << _.getIdName(inst->type_id()) |
239 | << "'s type does not match Function <id> '" |
240 | << _.getIdName(return_type->id()) << "'s return type." ; |
241 | } |
242 | |
243 | const auto function_type_id = function->GetOperandAs<uint32_t>(3); |
244 | const auto function_type = _.FindDef(function_type_id); |
245 | if (!function_type || function_type->opcode() != SpvOpTypeFunction) { |
246 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
247 | << "Missing function type definition." ; |
248 | } |
249 | |
250 | const auto function_call_arg_count = inst->words().size() - 4; |
251 | const auto function_param_count = function_type->words().size() - 3; |
252 | if (function_param_count != function_call_arg_count) { |
253 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
254 | << "OpFunctionCall Function <id>'s parameter count does not match " |
255 | "the argument count." ; |
256 | } |
257 | |
258 | for (size_t argument_index = 3, param_index = 2; |
259 | argument_index < inst->operands().size(); |
260 | argument_index++, param_index++) { |
261 | const auto argument_id = inst->GetOperandAs<uint32_t>(argument_index); |
262 | const auto argument = _.FindDef(argument_id); |
263 | if (!argument) { |
264 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
265 | << "Missing argument " << argument_index - 3 << " definition." ; |
266 | } |
267 | |
268 | const auto argument_type = _.FindDef(argument->type_id()); |
269 | if (!argument_type) { |
270 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
271 | << "Missing argument " << argument_index - 3 |
272 | << " type definition." ; |
273 | } |
274 | |
275 | const auto parameter_type_id = |
276 | function_type->GetOperandAs<uint32_t>(param_index); |
277 | const auto parameter_type = _.FindDef(parameter_type_id); |
278 | if (!parameter_type || argument_type->id() != parameter_type->id()) { |
279 | if (!_.options()->before_hlsl_legalization || |
280 | !DoPointeesLogicallyMatch(argument_type, parameter_type, _)) { |
281 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
282 | << "OpFunctionCall Argument <id> '" << _.getIdName(argument_id) |
283 | << "'s type does not match Function <id> '" |
284 | << _.getIdName(parameter_type_id) << "'s parameter type." ; |
285 | } |
286 | } |
287 | |
288 | if (_.addressing_model() == SpvAddressingModelLogical) { |
289 | if (parameter_type->opcode() == SpvOpTypePointer && |
290 | !_.options()->relax_logical_pointer) { |
291 | SpvStorageClass sc = parameter_type->GetOperandAs<SpvStorageClass>(1u); |
292 | // Validate which storage classes can be pointer operands. |
293 | switch (sc) { |
294 | case SpvStorageClassUniformConstant: |
295 | case SpvStorageClassFunction: |
296 | case SpvStorageClassPrivate: |
297 | case SpvStorageClassWorkgroup: |
298 | case SpvStorageClassAtomicCounter: |
299 | // These are always allowed. |
300 | break; |
301 | case SpvStorageClassStorageBuffer: |
302 | if (!_.features().variable_pointers_storage_buffer) { |
303 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
304 | << "StorageBuffer pointer operand " |
305 | << _.getIdName(argument_id) |
306 | << " requires a variable pointers capability" ; |
307 | } |
308 | break; |
309 | default: |
310 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
311 | << "Invalid storage class for pointer operand " |
312 | << _.getIdName(argument_id); |
313 | } |
314 | |
315 | // Validate memory object declaration requirements. |
316 | if (argument->opcode() != SpvOpVariable && |
317 | argument->opcode() != SpvOpFunctionParameter) { |
318 | const bool ssbo_vptr = |
319 | _.features().variable_pointers_storage_buffer && |
320 | sc == SpvStorageClassStorageBuffer; |
321 | const bool wg_vptr = |
322 | _.features().variable_pointers && sc == SpvStorageClassWorkgroup; |
323 | const bool uc_ptr = sc == SpvStorageClassUniformConstant; |
324 | if (!ssbo_vptr && !wg_vptr && !uc_ptr) { |
325 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
326 | << "Pointer operand " << _.getIdName(argument_id) |
327 | << " must be a memory object declaration" ; |
328 | } |
329 | } |
330 | } |
331 | } |
332 | } |
333 | return SPV_SUCCESS; |
334 | } |
335 | |
336 | } // namespace |
337 | |
338 | spv_result_t FunctionPass(ValidationState_t& _, const Instruction* inst) { |
339 | switch (inst->opcode()) { |
340 | case SpvOpFunction: |
341 | if (auto error = ValidateFunction(_, inst)) return error; |
342 | break; |
343 | case SpvOpFunctionParameter: |
344 | if (auto error = ValidateFunctionParameter(_, inst)) return error; |
345 | break; |
346 | case SpvOpFunctionCall: |
347 | if (auto error = ValidateFunctionCall(_, inst)) return error; |
348 | break; |
349 | default: |
350 | break; |
351 | } |
352 | |
353 | return SPV_SUCCESS; |
354 | } |
355 | |
356 | } // namespace val |
357 | } // namespace spvtools |
358 | |