| 1 | // Copyright (c) 2018 Google LLC. |
| 2 | // |
| 3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | // you may not use this file except in compliance with the License. |
| 5 | // You may obtain a copy of the License at |
| 6 | // |
| 7 | // http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | // |
| 9 | // Unless required by applicable law or agreed to in writing, software |
| 10 | // distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | // See the License for the specific language governing permissions and |
| 13 | // limitations under the License. |
| 14 | |
| 15 | #include <algorithm> |
| 16 | |
| 17 | #include "source/opcode.h" |
| 18 | #include "source/val/instruction.h" |
| 19 | #include "source/val/validate.h" |
| 20 | #include "source/val/validation_state.h" |
| 21 | |
| 22 | namespace spvtools { |
| 23 | namespace val { |
| 24 | namespace { |
| 25 | |
| 26 | // Returns true if |a| and |b| are instructions defining pointers that point to |
| 27 | // types logically match and the decorations that apply to |b| are a subset |
| 28 | // of the decorations that apply to |a|. |
| 29 | bool DoPointeesLogicallyMatch(val::Instruction* a, val::Instruction* b, |
| 30 | ValidationState_t& _) { |
| 31 | if (a->opcode() != SpvOpTypePointer || b->opcode() != SpvOpTypePointer) { |
| 32 | return false; |
| 33 | } |
| 34 | |
| 35 | const auto& dec_a = _.id_decorations(a->id()); |
| 36 | const auto& dec_b = _.id_decorations(b->id()); |
| 37 | for (const auto& dec : dec_b) { |
| 38 | if (std::find(dec_a.begin(), dec_a.end(), dec) == dec_a.end()) { |
| 39 | return false; |
| 40 | } |
| 41 | } |
| 42 | |
| 43 | uint32_t a_type = a->GetOperandAs<uint32_t>(2); |
| 44 | uint32_t b_type = b->GetOperandAs<uint32_t>(2); |
| 45 | |
| 46 | if (a_type == b_type) { |
| 47 | return true; |
| 48 | } |
| 49 | |
| 50 | Instruction* a_type_inst = _.FindDef(a_type); |
| 51 | Instruction* b_type_inst = _.FindDef(b_type); |
| 52 | |
| 53 | return _.LogicallyMatch(a_type_inst, b_type_inst, true); |
| 54 | } |
| 55 | |
| 56 | spv_result_t ValidateFunction(ValidationState_t& _, const Instruction* inst) { |
| 57 | const auto function_type_id = inst->GetOperandAs<uint32_t>(3); |
| 58 | const auto function_type = _.FindDef(function_type_id); |
| 59 | if (!function_type || SpvOpTypeFunction != function_type->opcode()) { |
| 60 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
| 61 | << "OpFunction Function Type <id> '" << _.getIdName(function_type_id) |
| 62 | << "' is not a function type." ; |
| 63 | } |
| 64 | |
| 65 | const auto return_id = function_type->GetOperandAs<uint32_t>(1); |
| 66 | if (return_id != inst->type_id()) { |
| 67 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
| 68 | << "OpFunction Result Type <id> '" << _.getIdName(inst->type_id()) |
| 69 | << "' does not match the Function Type's return type <id> '" |
| 70 | << _.getIdName(return_id) << "'." ; |
| 71 | } |
| 72 | |
| 73 | const std::vector<SpvOp> acceptable = { |
| 74 | SpvOpDecorate, |
| 75 | SpvOpEnqueueKernel, |
| 76 | SpvOpEntryPoint, |
| 77 | SpvOpExecutionMode, |
| 78 | SpvOpExecutionModeId, |
| 79 | SpvOpFunctionCall, |
| 80 | SpvOpGetKernelNDrangeSubGroupCount, |
| 81 | SpvOpGetKernelNDrangeMaxSubGroupSize, |
| 82 | SpvOpGetKernelWorkGroupSize, |
| 83 | SpvOpGetKernelPreferredWorkGroupSizeMultiple, |
| 84 | SpvOpGetKernelLocalSizeForSubgroupCount, |
| 85 | SpvOpGetKernelMaxNumSubgroups, |
| 86 | SpvOpName}; |
| 87 | for (auto& pair : inst->uses()) { |
| 88 | const auto* use = pair.first; |
| 89 | if (std::find(acceptable.begin(), acceptable.end(), use->opcode()) == |
| 90 | acceptable.end() && |
| 91 | !use->IsNonSemantic() && !use->IsDebugInfo()) { |
| 92 | return _.diag(SPV_ERROR_INVALID_ID, use) |
| 93 | << "Invalid use of function result id " << _.getIdName(inst->id()) |
| 94 | << "." ; |
| 95 | } |
| 96 | } |
| 97 | |
| 98 | return SPV_SUCCESS; |
| 99 | } |
| 100 | |
| 101 | spv_result_t ValidateFunctionParameter(ValidationState_t& _, |
| 102 | const Instruction* inst) { |
| 103 | // NOTE: Find OpFunction & ensure OpFunctionParameter is not out of place. |
| 104 | size_t param_index = 0; |
| 105 | size_t inst_num = inst->LineNum() - 1; |
| 106 | if (inst_num == 0) { |
| 107 | return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) |
| 108 | << "Function parameter cannot be the first instruction." ; |
| 109 | } |
| 110 | |
| 111 | auto func_inst = &_.ordered_instructions()[inst_num]; |
| 112 | while (--inst_num) { |
| 113 | func_inst = &_.ordered_instructions()[inst_num]; |
| 114 | if (func_inst->opcode() == SpvOpFunction) { |
| 115 | break; |
| 116 | } else if (func_inst->opcode() == SpvOpFunctionParameter) { |
| 117 | ++param_index; |
| 118 | } |
| 119 | } |
| 120 | |
| 121 | if (func_inst->opcode() != SpvOpFunction) { |
| 122 | return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) |
| 123 | << "Function parameter must be preceded by a function." ; |
| 124 | } |
| 125 | |
| 126 | const auto function_type_id = func_inst->GetOperandAs<uint32_t>(3); |
| 127 | const auto function_type = _.FindDef(function_type_id); |
| 128 | if (!function_type) { |
| 129 | return _.diag(SPV_ERROR_INVALID_ID, func_inst) |
| 130 | << "Missing function type definition." ; |
| 131 | } |
| 132 | if (param_index >= function_type->words().size() - 3) { |
| 133 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
| 134 | << "Too many OpFunctionParameters for " << func_inst->id() |
| 135 | << ": expected " << function_type->words().size() - 3 |
| 136 | << " based on the function's type" ; |
| 137 | } |
| 138 | |
| 139 | const auto param_type = |
| 140 | _.FindDef(function_type->GetOperandAs<uint32_t>(param_index + 2)); |
| 141 | if (!param_type || inst->type_id() != param_type->id()) { |
| 142 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
| 143 | << "OpFunctionParameter Result Type <id> '" |
| 144 | << _.getIdName(inst->type_id()) |
| 145 | << "' does not match the OpTypeFunction parameter " |
| 146 | "type of the same index." ; |
| 147 | } |
| 148 | |
| 149 | // Validate that PhysicalStorageBufferEXT have one of Restrict, Aliased, |
| 150 | // RestrictPointerEXT, or AliasedPointerEXT. |
| 151 | auto param_nonarray_type_id = param_type->id(); |
| 152 | while (_.GetIdOpcode(param_nonarray_type_id) == SpvOpTypeArray) { |
| 153 | param_nonarray_type_id = |
| 154 | _.FindDef(param_nonarray_type_id)->GetOperandAs<uint32_t>(1u); |
| 155 | } |
| 156 | if (_.GetIdOpcode(param_nonarray_type_id) == SpvOpTypePointer) { |
| 157 | auto param_nonarray_type = _.FindDef(param_nonarray_type_id); |
| 158 | if (param_nonarray_type->GetOperandAs<uint32_t>(1u) == |
| 159 | SpvStorageClassPhysicalStorageBufferEXT) { |
| 160 | // check for Aliased or Restrict |
| 161 | const auto& decorations = _.id_decorations(inst->id()); |
| 162 | |
| 163 | bool foundAliased = std::any_of( |
| 164 | decorations.begin(), decorations.end(), [](const Decoration& d) { |
| 165 | return SpvDecorationAliased == d.dec_type(); |
| 166 | }); |
| 167 | |
| 168 | bool foundRestrict = std::any_of( |
| 169 | decorations.begin(), decorations.end(), [](const Decoration& d) { |
| 170 | return SpvDecorationRestrict == d.dec_type(); |
| 171 | }); |
| 172 | |
| 173 | if (!foundAliased && !foundRestrict) { |
| 174 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
| 175 | << "OpFunctionParameter " << inst->id() |
| 176 | << ": expected Aliased or Restrict for PhysicalStorageBufferEXT " |
| 177 | "pointer." ; |
| 178 | } |
| 179 | if (foundAliased && foundRestrict) { |
| 180 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
| 181 | << "OpFunctionParameter " << inst->id() |
| 182 | << ": can't specify both Aliased and Restrict for " |
| 183 | "PhysicalStorageBufferEXT pointer." ; |
| 184 | } |
| 185 | } else { |
| 186 | const auto pointee_type_id = |
| 187 | param_nonarray_type->GetOperandAs<uint32_t>(2); |
| 188 | const auto pointee_type = _.FindDef(pointee_type_id); |
| 189 | if (SpvOpTypePointer == pointee_type->opcode() && |
| 190 | pointee_type->GetOperandAs<uint32_t>(1u) == |
| 191 | SpvStorageClassPhysicalStorageBufferEXT) { |
| 192 | // check for AliasedPointerEXT/RestrictPointerEXT |
| 193 | const auto& decorations = _.id_decorations(inst->id()); |
| 194 | |
| 195 | bool foundAliased = std::any_of( |
| 196 | decorations.begin(), decorations.end(), [](const Decoration& d) { |
| 197 | return SpvDecorationAliasedPointerEXT == d.dec_type(); |
| 198 | }); |
| 199 | |
| 200 | bool foundRestrict = std::any_of( |
| 201 | decorations.begin(), decorations.end(), [](const Decoration& d) { |
| 202 | return SpvDecorationRestrictPointerEXT == d.dec_type(); |
| 203 | }); |
| 204 | |
| 205 | if (!foundAliased && !foundRestrict) { |
| 206 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
| 207 | << "OpFunctionParameter " << inst->id() |
| 208 | << ": expected AliasedPointerEXT or RestrictPointerEXT for " |
| 209 | "PhysicalStorageBufferEXT pointer." ; |
| 210 | } |
| 211 | if (foundAliased && foundRestrict) { |
| 212 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
| 213 | << "OpFunctionParameter " << inst->id() |
| 214 | << ": can't specify both AliasedPointerEXT and " |
| 215 | "RestrictPointerEXT for PhysicalStorageBufferEXT pointer." ; |
| 216 | } |
| 217 | } |
| 218 | } |
| 219 | } |
| 220 | |
| 221 | return SPV_SUCCESS; |
| 222 | } |
| 223 | |
| 224 | spv_result_t ValidateFunctionCall(ValidationState_t& _, |
| 225 | const Instruction* inst) { |
| 226 | const auto function_id = inst->GetOperandAs<uint32_t>(2); |
| 227 | const auto function = _.FindDef(function_id); |
| 228 | if (!function || SpvOpFunction != function->opcode()) { |
| 229 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
| 230 | << "OpFunctionCall Function <id> '" << _.getIdName(function_id) |
| 231 | << "' is not a function." ; |
| 232 | } |
| 233 | |
| 234 | auto return_type = _.FindDef(function->type_id()); |
| 235 | if (!return_type || return_type->id() != inst->type_id()) { |
| 236 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
| 237 | << "OpFunctionCall Result Type <id> '" |
| 238 | << _.getIdName(inst->type_id()) |
| 239 | << "'s type does not match Function <id> '" |
| 240 | << _.getIdName(return_type->id()) << "'s return type." ; |
| 241 | } |
| 242 | |
| 243 | const auto function_type_id = function->GetOperandAs<uint32_t>(3); |
| 244 | const auto function_type = _.FindDef(function_type_id); |
| 245 | if (!function_type || function_type->opcode() != SpvOpTypeFunction) { |
| 246 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
| 247 | << "Missing function type definition." ; |
| 248 | } |
| 249 | |
| 250 | const auto function_call_arg_count = inst->words().size() - 4; |
| 251 | const auto function_param_count = function_type->words().size() - 3; |
| 252 | if (function_param_count != function_call_arg_count) { |
| 253 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
| 254 | << "OpFunctionCall Function <id>'s parameter count does not match " |
| 255 | "the argument count." ; |
| 256 | } |
| 257 | |
| 258 | for (size_t argument_index = 3, param_index = 2; |
| 259 | argument_index < inst->operands().size(); |
| 260 | argument_index++, param_index++) { |
| 261 | const auto argument_id = inst->GetOperandAs<uint32_t>(argument_index); |
| 262 | const auto argument = _.FindDef(argument_id); |
| 263 | if (!argument) { |
| 264 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
| 265 | << "Missing argument " << argument_index - 3 << " definition." ; |
| 266 | } |
| 267 | |
| 268 | const auto argument_type = _.FindDef(argument->type_id()); |
| 269 | if (!argument_type) { |
| 270 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
| 271 | << "Missing argument " << argument_index - 3 |
| 272 | << " type definition." ; |
| 273 | } |
| 274 | |
| 275 | const auto parameter_type_id = |
| 276 | function_type->GetOperandAs<uint32_t>(param_index); |
| 277 | const auto parameter_type = _.FindDef(parameter_type_id); |
| 278 | if (!parameter_type || argument_type->id() != parameter_type->id()) { |
| 279 | if (!_.options()->before_hlsl_legalization || |
| 280 | !DoPointeesLogicallyMatch(argument_type, parameter_type, _)) { |
| 281 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
| 282 | << "OpFunctionCall Argument <id> '" << _.getIdName(argument_id) |
| 283 | << "'s type does not match Function <id> '" |
| 284 | << _.getIdName(parameter_type_id) << "'s parameter type." ; |
| 285 | } |
| 286 | } |
| 287 | |
| 288 | if (_.addressing_model() == SpvAddressingModelLogical) { |
| 289 | if (parameter_type->opcode() == SpvOpTypePointer && |
| 290 | !_.options()->relax_logical_pointer) { |
| 291 | SpvStorageClass sc = parameter_type->GetOperandAs<SpvStorageClass>(1u); |
| 292 | // Validate which storage classes can be pointer operands. |
| 293 | switch (sc) { |
| 294 | case SpvStorageClassUniformConstant: |
| 295 | case SpvStorageClassFunction: |
| 296 | case SpvStorageClassPrivate: |
| 297 | case SpvStorageClassWorkgroup: |
| 298 | case SpvStorageClassAtomicCounter: |
| 299 | // These are always allowed. |
| 300 | break; |
| 301 | case SpvStorageClassStorageBuffer: |
| 302 | if (!_.features().variable_pointers_storage_buffer) { |
| 303 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
| 304 | << "StorageBuffer pointer operand " |
| 305 | << _.getIdName(argument_id) |
| 306 | << " requires a variable pointers capability" ; |
| 307 | } |
| 308 | break; |
| 309 | default: |
| 310 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
| 311 | << "Invalid storage class for pointer operand " |
| 312 | << _.getIdName(argument_id); |
| 313 | } |
| 314 | |
| 315 | // Validate memory object declaration requirements. |
| 316 | if (argument->opcode() != SpvOpVariable && |
| 317 | argument->opcode() != SpvOpFunctionParameter) { |
| 318 | const bool ssbo_vptr = |
| 319 | _.features().variable_pointers_storage_buffer && |
| 320 | sc == SpvStorageClassStorageBuffer; |
| 321 | const bool wg_vptr = |
| 322 | _.features().variable_pointers && sc == SpvStorageClassWorkgroup; |
| 323 | const bool uc_ptr = sc == SpvStorageClassUniformConstant; |
| 324 | if (!ssbo_vptr && !wg_vptr && !uc_ptr) { |
| 325 | return _.diag(SPV_ERROR_INVALID_ID, inst) |
| 326 | << "Pointer operand " << _.getIdName(argument_id) |
| 327 | << " must be a memory object declaration" ; |
| 328 | } |
| 329 | } |
| 330 | } |
| 331 | } |
| 332 | } |
| 333 | return SPV_SUCCESS; |
| 334 | } |
| 335 | |
| 336 | } // namespace |
| 337 | |
| 338 | spv_result_t FunctionPass(ValidationState_t& _, const Instruction* inst) { |
| 339 | switch (inst->opcode()) { |
| 340 | case SpvOpFunction: |
| 341 | if (auto error = ValidateFunction(_, inst)) return error; |
| 342 | break; |
| 343 | case SpvOpFunctionParameter: |
| 344 | if (auto error = ValidateFunctionParameter(_, inst)) return error; |
| 345 | break; |
| 346 | case SpvOpFunctionCall: |
| 347 | if (auto error = ValidateFunctionCall(_, inst)) return error; |
| 348 | break; |
| 349 | default: |
| 350 | break; |
| 351 | } |
| 352 | |
| 353 | return SPV_SUCCESS; |
| 354 | } |
| 355 | |
| 356 | } // namespace val |
| 357 | } // namespace spvtools |
| 358 | |