| 1 | // Copyright (c) 2019 Google LLC. | 
|---|
| 2 | // | 
|---|
| 3 | // Licensed under the Apache License, Version 2.0 (the "License"); | 
|---|
| 4 | // you may not use this file except in compliance with the License. | 
|---|
| 5 | // You may obtain a copy of the License at | 
|---|
| 6 | // | 
|---|
| 7 | //     http://www.apache.org/licenses/LICENSE-2.0 | 
|---|
| 8 | // | 
|---|
| 9 | // Unless required by applicable law or agreed to in writing, software | 
|---|
| 10 | // distributed under the License is distributed on an "AS IS" BASIS, | 
|---|
| 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
|---|
| 12 | // See the License for the specific language governing permissions and | 
|---|
| 13 | // limitations under the License. | 
|---|
| 14 |  | 
|---|
| 15 | #include "source/opt/amd_ext_to_khr.h" | 
|---|
| 16 |  | 
|---|
| 17 | #include <set> | 
|---|
| 18 | #include <string> | 
|---|
| 19 |  | 
|---|
| 20 | #include "ir_builder.h" | 
|---|
| 21 | #include "source/opt/ir_context.h" | 
|---|
| 22 | #include "spv-amd-shader-ballot.insts.inc" | 
|---|
| 23 | #include "type_manager.h" | 
|---|
| 24 |  | 
|---|
| 25 | namespace spvtools { | 
|---|
| 26 | namespace opt { | 
|---|
| 27 |  | 
|---|
| 28 | namespace { | 
|---|
| 29 |  | 
|---|
// Extended instruction numbers of the SPV_AMD_shader_ballot instructions that
// the folding rules in this file replace.  The numeric values identify the
// instructions in the binary and must not change.
enum AmdShaderBallotExtOpcodes {
  // OpExtInst %type %set SwizzleInvocationsAMD %data %offset
  AmdShaderBallotSwizzleInvocationsAMD = 1,
  // OpExtInst %type %set SwizzleInvocationsMaskedAMD %data %mask
  AmdShaderBallotSwizzleInvocationsMaskedAMD = 2,
  // OpExtInst %type %set WriteInvocationAMD %input_value %write_value
  //     %invocation_index
  AmdShaderBallotWriteInvocationAMD = 3,
  // OpExtInst %uint %set MbcntAMD %mask
  AmdShaderBallotMbcntAMD = 4
};
|---|
| 36 |  | 
|---|
// Extended instruction numbers of the AMD trinary min/max/mid instructions.
// NOTE(review): the F/U/S prefixes presumably select float, unsigned and
// signed operands -- confirm against the extension specification.
enum AmdShaderTrinaryMinMaxExtOpCodes {
  // Minimum of three values.
  FMin3AMD = 1,
  UMin3AMD = 2,
  SMin3AMD = 3,

  // Maximum of three values.
  FMax3AMD = 4,
  UMax3AMD = 5,
  SMax3AMD = 6,

  // Middle value of three values.
  FMid3AMD = 7,
  UMid3AMD = 8,
  SMid3AMD = 9
};
|---|
| 48 |  | 
|---|
// Extended instruction numbers of the AMD GCN shader instructions, listed in
// numeric order.
enum AmdGcnShader {
  CubeFaceIndexAMD = 1,
  CubeFaceCoordAMD = 2,
  TimeAMD = 3
};
|---|
| 50 |  | 
|---|
| 51 | analysis::Type* GetUIntType(IRContext* ctx) { | 
|---|
| 52 | analysis::Integer int_type(32, false); | 
|---|
| 53 | return ctx->get_type_mgr()->GetRegisteredType(&int_type); | 
|---|
| 54 | } | 
|---|
| 55 |  | 
|---|
| 56 | // Returns a folding rule that replaces |op(a,b,c)| by |op(op(a,b),c)|, where | 
|---|
| 57 | // |op| is either min or max. |opcode| is the binary opcode in the GLSLstd450 | 
|---|
| 58 | // extended instruction set that corresponds to the trinary instruction being | 
|---|
| 59 | // replaced. | 
|---|
| 60 | template <GLSLstd450 opcode> | 
|---|
| 61 | bool ReplaceTrinaryMinMax(IRContext* ctx, Instruction* inst, | 
|---|
| 62 | const std::vector<const analysis::Constant*>&) { | 
|---|
| 63 | uint32_t glsl405_ext_inst_id = | 
|---|
| 64 | ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450(); | 
|---|
| 65 | if (glsl405_ext_inst_id == 0) { | 
|---|
| 66 | ctx->AddExtInstImport( "GLSL.std.450"); | 
|---|
| 67 | glsl405_ext_inst_id = | 
|---|
| 68 | ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450(); | 
|---|
| 69 | } | 
|---|
| 70 |  | 
|---|
| 71 | InstructionBuilder ir_builder( | 
|---|
| 72 | ctx, inst, | 
|---|
| 73 | IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); | 
|---|
| 74 |  | 
|---|
| 75 | uint32_t op1 = inst->GetSingleWordInOperand(2); | 
|---|
| 76 | uint32_t op2 = inst->GetSingleWordInOperand(3); | 
|---|
| 77 | uint32_t op3 = inst->GetSingleWordInOperand(4); | 
|---|
| 78 |  | 
|---|
| 79 | Instruction* temp = ir_builder.AddNaryExtendedInstruction( | 
|---|
| 80 | inst->type_id(), glsl405_ext_inst_id, opcode, {op1, op2}); | 
|---|
| 81 |  | 
|---|
| 82 | Instruction::OperandList new_operands; | 
|---|
| 83 | new_operands.push_back({SPV_OPERAND_TYPE_ID, {glsl405_ext_inst_id}}); | 
|---|
| 84 | new_operands.push_back({SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER, | 
|---|
| 85 | {static_cast<uint32_t>(opcode)}}); | 
|---|
| 86 | new_operands.push_back({SPV_OPERAND_TYPE_ID, {temp->result_id()}}); | 
|---|
| 87 | new_operands.push_back({SPV_OPERAND_TYPE_ID, {op3}}); | 
|---|
| 88 |  | 
|---|
| 89 | inst->SetInOperands(std::move(new_operands)); | 
|---|
| 90 | ctx->UpdateDefUse(inst); | 
|---|
| 91 | return true; | 
|---|
| 92 | } | 
|---|
| 93 |  | 
|---|
| 94 | // Returns a folding rule that replaces |mid(a,b,c)| by |clamp(a, min(b,c), | 
|---|
| 95 | // max(b,c)|. The three parameters are the opcode that correspond to the min, | 
|---|
| 96 | // max, and clamp operations for the type of the instruction being replaced. | 
|---|
| 97 | template <GLSLstd450 min_opcode, GLSLstd450 max_opcode, GLSLstd450 clamp_opcode> | 
|---|
| 98 | bool ReplaceTrinaryMid(IRContext* ctx, Instruction* inst, | 
|---|
| 99 | const std::vector<const analysis::Constant*>&) { | 
|---|
| 100 | uint32_t glsl405_ext_inst_id = | 
|---|
| 101 | ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450(); | 
|---|
| 102 | if (glsl405_ext_inst_id == 0) { | 
|---|
| 103 | ctx->AddExtInstImport( "GLSL.std.450"); | 
|---|
| 104 | glsl405_ext_inst_id = | 
|---|
| 105 | ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450(); | 
|---|
| 106 | } | 
|---|
| 107 |  | 
|---|
| 108 | InstructionBuilder ir_builder( | 
|---|
| 109 | ctx, inst, | 
|---|
| 110 | IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); | 
|---|
| 111 |  | 
|---|
| 112 | uint32_t op1 = inst->GetSingleWordInOperand(2); | 
|---|
| 113 | uint32_t op2 = inst->GetSingleWordInOperand(3); | 
|---|
| 114 | uint32_t op3 = inst->GetSingleWordInOperand(4); | 
|---|
| 115 |  | 
|---|
| 116 | Instruction* min = ir_builder.AddNaryExtendedInstruction( | 
|---|
| 117 | inst->type_id(), glsl405_ext_inst_id, static_cast<uint32_t>(min_opcode), | 
|---|
| 118 | {op2, op3}); | 
|---|
| 119 | Instruction* max = ir_builder.AddNaryExtendedInstruction( | 
|---|
| 120 | inst->type_id(), glsl405_ext_inst_id, static_cast<uint32_t>(max_opcode), | 
|---|
| 121 | {op2, op3}); | 
|---|
| 122 |  | 
|---|
| 123 | Instruction::OperandList new_operands; | 
|---|
| 124 | new_operands.push_back({SPV_OPERAND_TYPE_ID, {glsl405_ext_inst_id}}); | 
|---|
| 125 | new_operands.push_back({SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER, | 
|---|
| 126 | {static_cast<uint32_t>(clamp_opcode)}}); | 
|---|
| 127 | new_operands.push_back({SPV_OPERAND_TYPE_ID, {op1}}); | 
|---|
| 128 | new_operands.push_back({SPV_OPERAND_TYPE_ID, {min->result_id()}}); | 
|---|
| 129 | new_operands.push_back({SPV_OPERAND_TYPE_ID, {max->result_id()}}); | 
|---|
| 130 |  | 
|---|
| 131 | inst->SetInOperands(std::move(new_operands)); | 
|---|
| 132 | ctx->UpdateDefUse(inst); | 
|---|
| 133 | return true; | 
|---|
| 134 | } | 
|---|
| 135 |  | 
|---|
| 136 | // Returns a folding rule that will replace the opcode with |opcode| and add | 
|---|
| 137 | // the capabilities required.  The folding rule assumes it is folding an | 
|---|
| 138 | // OpGroup*NonUniformAMD instruction from the SPV_AMD_shader_ballot extension. | 
|---|
| 139 | template <SpvOp new_opcode> | 
|---|
| 140 | bool ReplaceGroupNonuniformOperationOpCode( | 
|---|
| 141 | IRContext* ctx, Instruction* inst, | 
|---|
| 142 | const std::vector<const analysis::Constant*>&) { | 
|---|
| 143 | switch (new_opcode) { | 
|---|
| 144 | case SpvOpGroupNonUniformIAdd: | 
|---|
| 145 | case SpvOpGroupNonUniformFAdd: | 
|---|
| 146 | case SpvOpGroupNonUniformUMin: | 
|---|
| 147 | case SpvOpGroupNonUniformSMin: | 
|---|
| 148 | case SpvOpGroupNonUniformFMin: | 
|---|
| 149 | case SpvOpGroupNonUniformUMax: | 
|---|
| 150 | case SpvOpGroupNonUniformSMax: | 
|---|
| 151 | case SpvOpGroupNonUniformFMax: | 
|---|
| 152 | break; | 
|---|
| 153 | default: | 
|---|
| 154 | assert( | 
|---|
| 155 | false && | 
|---|
| 156 | "Should be replacing with a group non uniform arithmetic operation."); | 
|---|
| 157 | } | 
|---|
| 158 |  | 
|---|
| 159 | switch (inst->opcode()) { | 
|---|
| 160 | case SpvOpGroupIAddNonUniformAMD: | 
|---|
| 161 | case SpvOpGroupFAddNonUniformAMD: | 
|---|
| 162 | case SpvOpGroupUMinNonUniformAMD: | 
|---|
| 163 | case SpvOpGroupSMinNonUniformAMD: | 
|---|
| 164 | case SpvOpGroupFMinNonUniformAMD: | 
|---|
| 165 | case SpvOpGroupUMaxNonUniformAMD: | 
|---|
| 166 | case SpvOpGroupSMaxNonUniformAMD: | 
|---|
| 167 | case SpvOpGroupFMaxNonUniformAMD: | 
|---|
| 168 | break; | 
|---|
| 169 | default: | 
|---|
| 170 | assert(false && | 
|---|
| 171 | "Should be replacing a group non uniform arithmetic operation."); | 
|---|
| 172 | } | 
|---|
| 173 |  | 
|---|
| 174 | ctx->AddCapability(SpvCapabilityGroupNonUniformArithmetic); | 
|---|
| 175 | inst->SetOpcode(new_opcode); | 
|---|
| 176 | return true; | 
|---|
| 177 | } | 
|---|
| 178 |  | 
|---|
| 179 | // Returns a folding rule that will replace the SwizzleInvocationsAMD extended | 
|---|
| 180 | // instruction in the SPV_AMD_shader_ballot extension. | 
|---|
| 181 | // | 
|---|
| 182 | // The instruction | 
|---|
| 183 | // | 
|---|
| 184 | //  %offset = OpConstantComposite %v3uint %x %y %z %w | 
|---|
| 185 | //  %result = OpExtInst %type %1 SwizzleInvocationsAMD %data %offset | 
|---|
| 186 | // | 
|---|
| 187 | // is replaced with | 
|---|
| 188 | // | 
|---|
| 189 | // potentially new constants and types | 
|---|
| 190 | // | 
|---|
| 191 | // clang-format off | 
|---|
| 192 | //         %uint_max = OpConstant %uint 0xFFFFFFFF | 
|---|
| 193 | //           %v4uint = OpTypeVector %uint 4 | 
|---|
| 194 | //     %ballot_value = OpConstantComposite %v4uint %uint_max %uint_max %uint_max %uint_max | 
|---|
| 195 | //             %null = OpConstantNull %type | 
|---|
| 196 | // clang-format on | 
|---|
| 197 | // | 
|---|
| 198 | // and the following code in the function body | 
|---|
| 199 | // | 
|---|
| 200 | // clang-format off | 
|---|
| 201 | //         %id = OpLoad %uint %SubgroupLocalInvocationId | 
|---|
| 202 | //   %quad_idx = OpBitwiseAnd %uint %id %uint_3 | 
|---|
| 203 | //   %quad_ldr = OpBitwiseXor %uint %id %quad_idx | 
|---|
| 204 | //  %my_offset = OpVectorExtractDynamic %uint %offset %quad_idx | 
|---|
| 205 | // %target_inv = OpIAdd %uint %quad_ldr %my_offset | 
|---|
| 206 | //  %is_active = OpGroupNonUniformBallotBitExtract %bool %uint_3 %ballot_value %target_inv | 
|---|
| 207 | //    %shuffle = OpGroupNonUniformShuffle %type %uint_3 %data %target_inv | 
|---|
| 208 | //     %result = OpSelect %type %is_active %shuffle %null | 
|---|
| 209 | // clang-format on | 
|---|
| 210 | // | 
|---|
| 211 | // Also adding the capabilities and builtins that are needed. | 
|---|
| 212 | bool ReplaceSwizzleInvocations(IRContext* ctx, Instruction* inst, | 
|---|
| 213 | const std::vector<const analysis::Constant*>&) { | 
|---|
| 214 | analysis::TypeManager* type_mgr = ctx->get_type_mgr(); | 
|---|
| 215 | analysis::ConstantManager* const_mgr = ctx->get_constant_mgr(); | 
|---|
| 216 |  | 
|---|
| 217 | ctx->AddExtension( "SPV_KHR_shader_ballot"); | 
|---|
| 218 | ctx->AddCapability(SpvCapabilityGroupNonUniformBallot); | 
|---|
| 219 | ctx->AddCapability(SpvCapabilityGroupNonUniformShuffle); | 
|---|
| 220 |  | 
|---|
| 221 | InstructionBuilder ir_builder( | 
|---|
| 222 | ctx, inst, | 
|---|
| 223 | IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); | 
|---|
| 224 |  | 
|---|
| 225 | uint32_t data_id = inst->GetSingleWordInOperand(2); | 
|---|
| 226 | uint32_t offset_id = inst->GetSingleWordInOperand(3); | 
|---|
| 227 |  | 
|---|
| 228 | // Get the subgroup invocation id. | 
|---|
| 229 | uint32_t var_id = | 
|---|
| 230 | ctx->GetBuiltinInputVarId(SpvBuiltInSubgroupLocalInvocationId); | 
|---|
| 231 | assert(var_id != 0 && "Could not get SubgroupLocalInvocationId variable."); | 
|---|
| 232 | Instruction* var_inst = ctx->get_def_use_mgr()->GetDef(var_id); | 
|---|
| 233 | Instruction* var_ptr_type = | 
|---|
| 234 | ctx->get_def_use_mgr()->GetDef(var_inst->type_id()); | 
|---|
| 235 | uint32_t uint_type_id = var_ptr_type->GetSingleWordInOperand(1); | 
|---|
| 236 |  | 
|---|
| 237 | Instruction* id = ir_builder.AddLoad(uint_type_id, var_id); | 
|---|
| 238 |  | 
|---|
| 239 | uint32_t quad_mask = ir_builder.GetUintConstantId(3); | 
|---|
| 240 |  | 
|---|
| 241 | // This gives the offset in the group of 4 of this invocation. | 
|---|
| 242 | Instruction* quad_idx = ir_builder.AddBinaryOp(uint_type_id, SpvOpBitwiseAnd, | 
|---|
| 243 | id->result_id(), quad_mask); | 
|---|
| 244 |  | 
|---|
| 245 | // Get the invocation id of the first invocation in the group of 4. | 
|---|
| 246 | Instruction* quad_ldr = ir_builder.AddBinaryOp( | 
|---|
| 247 | uint_type_id, SpvOpBitwiseXor, id->result_id(), quad_idx->result_id()); | 
|---|
| 248 |  | 
|---|
| 249 | // Get the offset of the target invocation from the offset vector. | 
|---|
| 250 | Instruction* my_offset = | 
|---|
| 251 | ir_builder.AddBinaryOp(uint_type_id, SpvOpVectorExtractDynamic, offset_id, | 
|---|
| 252 | quad_idx->result_id()); | 
|---|
| 253 |  | 
|---|
| 254 | // Determine the index of the invocation to read from. | 
|---|
| 255 | Instruction* target_inv = ir_builder.AddBinaryOp( | 
|---|
| 256 | uint_type_id, SpvOpIAdd, quad_ldr->result_id(), my_offset->result_id()); | 
|---|
| 257 |  | 
|---|
| 258 | // Do the group operations | 
|---|
| 259 | uint32_t uint_max_id = ir_builder.GetUintConstantId(0xFFFFFFFF); | 
|---|
| 260 | uint32_t subgroup_scope = ir_builder.GetUintConstantId(SpvScopeSubgroup); | 
|---|
| 261 | const auto* ballot_value_const = const_mgr->GetConstant( | 
|---|
| 262 | type_mgr->GetUIntVectorType(4), | 
|---|
| 263 | {uint_max_id, uint_max_id, uint_max_id, uint_max_id}); | 
|---|
| 264 | Instruction* ballot_value = | 
|---|
| 265 | const_mgr->GetDefiningInstruction(ballot_value_const); | 
|---|
| 266 | Instruction* is_active = ir_builder.AddNaryOp( | 
|---|
| 267 | type_mgr->GetBoolTypeId(), SpvOpGroupNonUniformBallotBitExtract, | 
|---|
| 268 | {subgroup_scope, ballot_value->result_id(), target_inv->result_id()}); | 
|---|
| 269 | Instruction* shuffle = | 
|---|
| 270 | ir_builder.AddNaryOp(inst->type_id(), SpvOpGroupNonUniformShuffle, | 
|---|
| 271 | {subgroup_scope, data_id, target_inv->result_id()}); | 
|---|
| 272 |  | 
|---|
| 273 | // Create the null constant to use in the select. | 
|---|
| 274 | const auto* null = const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), | 
|---|
| 275 | std::vector<uint32_t>()); | 
|---|
| 276 | Instruction* null_inst = const_mgr->GetDefiningInstruction(null); | 
|---|
| 277 |  | 
|---|
| 278 | // Build the select. | 
|---|
| 279 | inst->SetOpcode(SpvOpSelect); | 
|---|
| 280 | Instruction::OperandList new_operands; | 
|---|
| 281 | new_operands.push_back({SPV_OPERAND_TYPE_ID, {is_active->result_id()}}); | 
|---|
| 282 | new_operands.push_back({SPV_OPERAND_TYPE_ID, {shuffle->result_id()}}); | 
|---|
| 283 | new_operands.push_back({SPV_OPERAND_TYPE_ID, {null_inst->result_id()}}); | 
|---|
| 284 |  | 
|---|
| 285 | inst->SetInOperands(std::move(new_operands)); | 
|---|
| 286 | ctx->UpdateDefUse(inst); | 
|---|
| 287 | return true; | 
|---|
| 288 | } | 
|---|
| 289 |  | 
|---|
| 290 | // Returns a folding rule that will replace the SwizzleInvocationsMaskedAMD | 
|---|
| 291 | // extended instruction in the SPV_AMD_shader_ballot extension. | 
|---|
| 292 | // | 
|---|
| 293 | // The instruction | 
|---|
| 294 | // | 
|---|
| 295 | //    %mask = OpConstantComposite %v3uint %uint_x %uint_y %uint_z | 
|---|
| 296 | //  %result = OpExtInst %uint %1 SwizzleInvocationsMaskedAMD %data %mask | 
|---|
| 297 | // | 
|---|
| 298 | // is replaced with | 
|---|
| 299 | // | 
|---|
| 300 | // potentially new constants and types | 
|---|
| 301 | // | 
|---|
| 302 | // clang-format off | 
|---|
| 303 | // %uint_mask_extend = OpConstant %uint 0xFFFFFFE0 | 
|---|
| 304 | //         %uint_max = OpConstant %uint 0xFFFFFFFF | 
|---|
| 305 | //           %v4uint = OpTypeVector %uint 4 | 
|---|
| 306 | //     %ballot_value = OpConstantComposite %v4uint %uint_max %uint_max %uint_max %uint_max | 
|---|
| 307 | // clang-format on | 
|---|
| 308 | // | 
|---|
| 309 | // and the following code in the function body | 
|---|
| 310 | // | 
|---|
| 311 | // clang-format off | 
|---|
| 312 | //         %id = OpLoad %uint %SubgroupLocalInvocationId | 
|---|
| 313 | //   %and_mask = OpBitwiseOr %uint %uint_x %uint_mask_extend | 
|---|
| 314 | //        %and = OpBitwiseAnd %uint %id %and_mask | 
|---|
| 315 | //         %or = OpBitwiseOr %uint %and %uint_y | 
|---|
| 316 | // %target_inv = OpBitwiseXor %uint %or %uint_z | 
|---|
| 317 | //  %is_active = OpGroupNonUniformBallotBitExtract %bool %uint_3 %ballot_value %target_inv | 
|---|
| 318 | //    %shuffle = OpGroupNonUniformShuffle %type %uint_3 %data %target_inv | 
|---|
| 319 | //     %result = OpSelect %type %is_active %shuffle %uint_0 | 
|---|
| 320 | // clang-format on | 
|---|
| 321 | // | 
|---|
| 322 | // Also adding the capabilities and builtins that are needed. | 
|---|
| 323 | bool ReplaceSwizzleInvocationsMasked( | 
|---|
| 324 | IRContext* ctx, Instruction* inst, | 
|---|
| 325 | const std::vector<const analysis::Constant*>&) { | 
|---|
| 326 | analysis::TypeManager* type_mgr = ctx->get_type_mgr(); | 
|---|
| 327 | analysis::DefUseManager* def_use_mgr = ctx->get_def_use_mgr(); | 
|---|
| 328 | analysis::ConstantManager* const_mgr = ctx->get_constant_mgr(); | 
|---|
| 329 |  | 
|---|
| 330 | ctx->AddCapability(SpvCapabilityGroupNonUniformBallot); | 
|---|
| 331 | ctx->AddCapability(SpvCapabilityGroupNonUniformShuffle); | 
|---|
| 332 |  | 
|---|
| 333 | InstructionBuilder ir_builder( | 
|---|
| 334 | ctx, inst, | 
|---|
| 335 | IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); | 
|---|
| 336 |  | 
|---|
| 337 | // Get the operands to inst, and the components of the mask | 
|---|
| 338 | uint32_t data_id = inst->GetSingleWordInOperand(2); | 
|---|
| 339 |  | 
|---|
| 340 | Instruction* mask_inst = def_use_mgr->GetDef(inst->GetSingleWordInOperand(3)); | 
|---|
| 341 | assert(mask_inst->opcode() == SpvOpConstantComposite && | 
|---|
| 342 | "The mask is suppose to be a vector constant."); | 
|---|
| 343 | assert(mask_inst->NumInOperands() == 3 && | 
|---|
| 344 | "The mask is suppose to have 3 components."); | 
|---|
| 345 |  | 
|---|
| 346 | uint32_t uint_x = mask_inst->GetSingleWordInOperand(0); | 
|---|
| 347 | uint32_t uint_y = mask_inst->GetSingleWordInOperand(1); | 
|---|
| 348 | uint32_t uint_z = mask_inst->GetSingleWordInOperand(2); | 
|---|
| 349 |  | 
|---|
| 350 | // Get the subgroup invocation id. | 
|---|
| 351 | uint32_t var_id = | 
|---|
| 352 | ctx->GetBuiltinInputVarId(SpvBuiltInSubgroupLocalInvocationId); | 
|---|
| 353 | ctx->AddExtension( "SPV_KHR_shader_ballot"); | 
|---|
| 354 | assert(var_id != 0 && "Could not get SubgroupLocalInvocationId variable."); | 
|---|
| 355 | Instruction* var_inst = ctx->get_def_use_mgr()->GetDef(var_id); | 
|---|
| 356 | Instruction* var_ptr_type = | 
|---|
| 357 | ctx->get_def_use_mgr()->GetDef(var_inst->type_id()); | 
|---|
| 358 | uint32_t uint_type_id = var_ptr_type->GetSingleWordInOperand(1); | 
|---|
| 359 |  | 
|---|
| 360 | Instruction* id = ir_builder.AddLoad(uint_type_id, var_id); | 
|---|
| 361 |  | 
|---|
| 362 | // Do the bitwise operations. | 
|---|
| 363 | uint32_t mask_extended = ir_builder.GetUintConstantId(0xFFFFFFE0); | 
|---|
| 364 | Instruction* and_mask = ir_builder.AddBinaryOp(uint_type_id, SpvOpBitwiseOr, | 
|---|
| 365 | uint_x, mask_extended); | 
|---|
| 366 | Instruction* and_result = ir_builder.AddBinaryOp( | 
|---|
| 367 | uint_type_id, SpvOpBitwiseAnd, id->result_id(), and_mask->result_id()); | 
|---|
| 368 | Instruction* or_result = ir_builder.AddBinaryOp( | 
|---|
| 369 | uint_type_id, SpvOpBitwiseOr, and_result->result_id(), uint_y); | 
|---|
| 370 | Instruction* target_inv = ir_builder.AddBinaryOp( | 
|---|
| 371 | uint_type_id, SpvOpBitwiseXor, or_result->result_id(), uint_z); | 
|---|
| 372 |  | 
|---|
| 373 | // Do the group operations | 
|---|
| 374 | uint32_t uint_max_id = ir_builder.GetUintConstantId(0xFFFFFFFF); | 
|---|
| 375 | uint32_t subgroup_scope = ir_builder.GetUintConstantId(SpvScopeSubgroup); | 
|---|
| 376 | const auto* ballot_value_const = const_mgr->GetConstant( | 
|---|
| 377 | type_mgr->GetUIntVectorType(4), | 
|---|
| 378 | {uint_max_id, uint_max_id, uint_max_id, uint_max_id}); | 
|---|
| 379 | Instruction* ballot_value = | 
|---|
| 380 | const_mgr->GetDefiningInstruction(ballot_value_const); | 
|---|
| 381 | Instruction* is_active = ir_builder.AddNaryOp( | 
|---|
| 382 | type_mgr->GetBoolTypeId(), SpvOpGroupNonUniformBallotBitExtract, | 
|---|
| 383 | {subgroup_scope, ballot_value->result_id(), target_inv->result_id()}); | 
|---|
| 384 | Instruction* shuffle = | 
|---|
| 385 | ir_builder.AddNaryOp(inst->type_id(), SpvOpGroupNonUniformShuffle, | 
|---|
| 386 | {subgroup_scope, data_id, target_inv->result_id()}); | 
|---|
| 387 |  | 
|---|
| 388 | // Create the null constant to use in the select. | 
|---|
| 389 | const auto* null = const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), | 
|---|
| 390 | std::vector<uint32_t>()); | 
|---|
| 391 | Instruction* null_inst = const_mgr->GetDefiningInstruction(null); | 
|---|
| 392 |  | 
|---|
| 393 | // Build the select. | 
|---|
| 394 | inst->SetOpcode(SpvOpSelect); | 
|---|
| 395 | Instruction::OperandList new_operands; | 
|---|
| 396 | new_operands.push_back({SPV_OPERAND_TYPE_ID, {is_active->result_id()}}); | 
|---|
| 397 | new_operands.push_back({SPV_OPERAND_TYPE_ID, {shuffle->result_id()}}); | 
|---|
| 398 | new_operands.push_back({SPV_OPERAND_TYPE_ID, {null_inst->result_id()}}); | 
|---|
| 399 |  | 
|---|
| 400 | inst->SetInOperands(std::move(new_operands)); | 
|---|
| 401 | ctx->UpdateDefUse(inst); | 
|---|
| 402 | return true; | 
|---|
| 403 | } | 
|---|
| 404 |  | 
|---|
| 405 | // Returns a folding rule that will replace the WriteInvocationAMD extended | 
|---|
| 406 | // instruction in the SPV_AMD_shader_ballot extension. | 
|---|
| 407 | // | 
|---|
| 408 | // The instruction | 
|---|
| 409 | // | 
|---|
| 410 | // clang-format off | 
|---|
| 411 | //    %result = OpExtInst %type %1 WriteInvocationAMD %input_value %write_value %invocation_index | 
|---|
| 412 | // clang-format on | 
|---|
| 413 | // | 
|---|
| 414 | // with | 
|---|
| 415 | // | 
|---|
| 416 | //     %id = OpLoad %uint %SubgroupLocalInvocationId | 
|---|
| 417 | //    %cmp = OpIEqual %bool %id %invocation_index | 
|---|
| 418 | // %result = OpSelect %type %cmp %write_value %input_value | 
|---|
| 419 | // | 
|---|
| 420 | // Also adding the capabilities and builtins that are needed. | 
|---|
| 421 | bool ReplaceWriteInvocation(IRContext* ctx, Instruction* inst, | 
|---|
| 422 | const std::vector<const analysis::Constant*>&) { | 
|---|
| 423 | uint32_t var_id = | 
|---|
| 424 | ctx->GetBuiltinInputVarId(SpvBuiltInSubgroupLocalInvocationId); | 
|---|
| 425 | ctx->AddCapability(SpvCapabilitySubgroupBallotKHR); | 
|---|
| 426 | ctx->AddExtension( "SPV_KHR_shader_ballot"); | 
|---|
| 427 | assert(var_id != 0 && "Could not get SubgroupLocalInvocationId variable."); | 
|---|
| 428 | Instruction* var_inst = ctx->get_def_use_mgr()->GetDef(var_id); | 
|---|
| 429 | Instruction* var_ptr_type = | 
|---|
| 430 | ctx->get_def_use_mgr()->GetDef(var_inst->type_id()); | 
|---|
| 431 |  | 
|---|
| 432 | InstructionBuilder ir_builder( | 
|---|
| 433 | ctx, inst, | 
|---|
| 434 | IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); | 
|---|
| 435 | Instruction* t = | 
|---|
| 436 | ir_builder.AddLoad(var_ptr_type->GetSingleWordInOperand(1), var_id); | 
|---|
| 437 | analysis::Bool bool_type; | 
|---|
| 438 | uint32_t bool_type_id = ctx->get_type_mgr()->GetTypeInstruction(&bool_type); | 
|---|
| 439 | Instruction* cmp = | 
|---|
| 440 | ir_builder.AddBinaryOp(bool_type_id, SpvOpIEqual, t->result_id(), | 
|---|
| 441 | inst->GetSingleWordInOperand(4)); | 
|---|
| 442 |  | 
|---|
| 443 | // Build a select. | 
|---|
| 444 | inst->SetOpcode(SpvOpSelect); | 
|---|
| 445 | Instruction::OperandList new_operands; | 
|---|
| 446 | new_operands.push_back({SPV_OPERAND_TYPE_ID, {cmp->result_id()}}); | 
|---|
| 447 | new_operands.push_back(inst->GetInOperand(3)); | 
|---|
| 448 | new_operands.push_back(inst->GetInOperand(2)); | 
|---|
| 449 |  | 
|---|
| 450 | inst->SetInOperands(std::move(new_operands)); | 
|---|
| 451 | ctx->UpdateDefUse(inst); | 
|---|
| 452 | return true; | 
|---|
| 453 | } | 
|---|
| 454 |  | 
|---|
| 455 | // Returns a folding rule that will replace the MbcntAMD extended instruction in | 
|---|
| 456 | // the SPV_AMD_shader_ballot extension. | 
|---|
| 457 | // | 
|---|
| 458 | // The instruction | 
|---|
| 459 | // | 
|---|
| 460 | //  %result = OpExtInst %uint %1 MbcntAMD %mask | 
|---|
| 461 | // | 
|---|
| 462 | // with | 
|---|
| 463 | // | 
|---|
| 464 | // Get SubgroupLtMask and convert the first 64-bits into a uint64_t because | 
|---|
| 465 | // AMD's shader compiler expects a 64-bit integer mask. | 
|---|
| 466 | // | 
|---|
| 467 | //     %var = OpLoad %v4uint %SubgroupLtMaskKHR | 
|---|
| 468 | // %shuffle = OpVectorShuffle %v2uint %var %var 0 1 | 
|---|
| 469 | //    %cast = OpBitcast %ulong %shuffle | 
|---|
| 470 | // | 
|---|
| 471 | // Perform the mask and count the bits. | 
|---|
| 472 | // | 
|---|
| 473 | //     %and = OpBitwiseAnd %ulong %cast %mask | 
|---|
| 474 | //  %result = OpBitCount %uint %and | 
|---|
| 475 | // | 
|---|
| 476 | // Also adding the capabilities and builtins that are needed. | 
|---|
| 477 | bool ReplaceMbcnt(IRContext* context, Instruction* inst, | 
|---|
| 478 | const std::vector<const analysis::Constant*>&) { | 
|---|
| 479 | analysis::TypeManager* type_mgr = context->get_type_mgr(); | 
|---|
| 480 | analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); | 
|---|
| 481 |  | 
|---|
| 482 | uint32_t var_id = context->GetBuiltinInputVarId(SpvBuiltInSubgroupLtMask); | 
|---|
| 483 | assert(var_id != 0 && "Could not get SubgroupLtMask variable."); | 
|---|
| 484 | context->AddCapability(SpvCapabilityGroupNonUniformBallot); | 
|---|
| 485 | Instruction* var_inst = def_use_mgr->GetDef(var_id); | 
|---|
| 486 | Instruction* var_ptr_type = def_use_mgr->GetDef(var_inst->type_id()); | 
|---|
| 487 | Instruction* var_type = | 
|---|
| 488 | def_use_mgr->GetDef(var_ptr_type->GetSingleWordInOperand(1)); | 
|---|
| 489 | assert(var_type->opcode() == SpvOpTypeVector && | 
|---|
| 490 | "Variable is suppose to be a vector of 4 ints"); | 
|---|
| 491 |  | 
|---|
| 492 | // Get the type for the shuffle. | 
|---|
| 493 | analysis::Vector temp_type(GetUIntType(context), 2); | 
|---|
| 494 | const analysis::Type* shuffle_type = | 
|---|
| 495 | context->get_type_mgr()->GetRegisteredType(&temp_type); | 
|---|
| 496 | uint32_t shuffle_type_id = type_mgr->GetTypeInstruction(shuffle_type); | 
|---|
| 497 |  | 
|---|
| 498 | uint32_t mask_id = inst->GetSingleWordInOperand(2); | 
|---|
| 499 | Instruction* mask_inst = def_use_mgr->GetDef(mask_id); | 
|---|
| 500 |  | 
|---|
| 501 | // Testing with amd's shader compiler shows that a 64-bit mask is expected. | 
|---|
| 502 | assert(type_mgr->GetType(mask_inst->type_id())->AsInteger() != nullptr); | 
|---|
| 503 | assert(type_mgr->GetType(mask_inst->type_id())->AsInteger()->width() == 64); | 
|---|
| 504 |  | 
|---|
| 505 | InstructionBuilder ir_builder( | 
|---|
| 506 | context, inst, | 
|---|
| 507 | IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); | 
|---|
| 508 | Instruction* load = ir_builder.AddLoad(var_type->result_id(), var_id); | 
|---|
| 509 | Instruction* shuffle = ir_builder.AddVectorShuffle( | 
|---|
| 510 | shuffle_type_id, load->result_id(), load->result_id(), {0, 1}); | 
|---|
| 511 | Instruction* bitcast = ir_builder.AddUnaryOp( | 
|---|
| 512 | mask_inst->type_id(), SpvOpBitcast, shuffle->result_id()); | 
|---|
| 513 | Instruction* t = ir_builder.AddBinaryOp(mask_inst->type_id(), SpvOpBitwiseAnd, | 
|---|
| 514 | bitcast->result_id(), mask_id); | 
|---|
| 515 |  | 
|---|
| 516 | inst->SetOpcode(SpvOpBitCount); | 
|---|
| 517 | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {t->result_id()}}}); | 
|---|
| 518 | context->UpdateDefUse(inst); | 
|---|
| 519 | return true; | 
|---|
| 520 | } | 
|---|
| 521 |  | 
|---|
| 522 | // A folding rule that will replace the CubeFaceCoordAMD extended | 
|---|
| 523 | // instruction in the SPV_AMD_gcn_shader extension.  Returns true if folding is | 
|---|
| 524 | // successful. | 
|---|
| 525 | // | 
|---|
| 526 | // The instruction | 
|---|
| 527 | // | 
|---|
| 528 | //  %result = OpExtInst %v2float %1 CubeFaceCoordAMD %input | 
|---|
| 529 | // | 
|---|
| 530 | // with | 
|---|
| 531 | // | 
|---|
| 532 | //             %x = OpCompositeExtract %float %input 0 | 
|---|
| 533 | //             %y = OpCompositeExtract %float %input 1 | 
|---|
| 534 | //             %z = OpCompositeExtract %float %input 2 | 
|---|
| 535 | //            %nx = OpFNegate %float %x | 
|---|
| 536 | //            %ny = OpFNegate %float %y | 
|---|
| 537 | //            %nz = OpFNegate %float %z | 
|---|
| 538 | //            %ax = OpExtInst %float %n_1 FAbs %x | 
|---|
| 539 | //            %ay = OpExtInst %float %n_1 FAbs %y | 
|---|
| 540 | //            %az = OpExtInst %float %n_1 FAbs %z | 
|---|
| 541 | //      %amax_x_y = OpExtInst %float %n_1 FMax %ay %ax | 
|---|
| 542 | //          %amax = OpExtInst %float %n_1 FMax %az %amax_x_y | 
|---|
| 543 | //        %cubema = OpFMul %float %float_2 %amax | 
|---|
| 544 | //      %is_z_max = OpFOrdGreaterThanEqual %bool %az %amax_x_y | 
|---|
| 545 | //  %not_is_z_max = OpLogicalNot %bool %is_z_max | 
|---|
| 546 | //        %y_gt_x = OpFOrdGreaterThanEqual %bool %ay %ax | 
|---|
| 547 | //      %is_y_max = OpLogicalAnd %bool %not_is_z_max %y_gt_x | 
|---|
| 548 | //      %is_z_neg = OpFOrdLessThan %bool %z %float_0 | 
|---|
| 549 | // %cubesc_case_1 = OpSelect %float %is_z_neg %nx %x | 
|---|
| 550 | //      %is_x_neg = OpFOrdLessThan %bool %x %float_0 | 
|---|
| 551 | // %cubesc_case_2 = OpSelect %float %is_x_neg %z %nz | 
|---|
| 552 | //           %sel = OpSelect %float %is_y_max %x %cubesc_case_2 | 
|---|
| 553 | //        %cubesc = OpSelect %float %is_z_max %cubesc_case_1 %sel | 
|---|
| 554 | //      %is_y_neg = OpFOrdLessThan %bool %y %float_0 | 
|---|
| 555 | // %cubetc_case_1 = OpSelect %float %is_y_neg %nz %z | 
|---|
| 556 | //        %cubetc = OpSelect %float %is_y_max %cubetc_case_1 %ny | 
|---|
| 557 | //          %cube = OpCompositeConstruct %v2float %cubesc %cubetc | 
|---|
| 558 | //         %denom = OpCompositeConstruct %v2float %cubema %cubema | 
|---|
| 559 | //           %div = OpFDiv %v2float %cube %denom | 
|---|
| 560 | //        %result = OpFAdd %v2float %div %const | 
|---|
| 561 | // | 
|---|
| 562 | // Also adding the capabilities and builtins that are needed. | 
|---|
| 563 | bool ReplaceCubeFaceCoord(IRContext* ctx, Instruction* inst, | 
|---|
| 564 | const std::vector<const analysis::Constant*>&) { | 
|---|
| 565 | analysis::TypeManager* type_mgr = ctx->get_type_mgr(); | 
|---|
| 566 | analysis::ConstantManager* const_mgr = ctx->get_constant_mgr(); | 
|---|
| 567 |  | 
|---|
| 568 | uint32_t float_type_id = type_mgr->GetFloatTypeId(); | 
|---|
| 569 | const analysis::Type* v2_float_type = type_mgr->GetFloatVectorType(2); | 
|---|
| 570 | uint32_t v2_float_type_id = type_mgr->GetId(v2_float_type); | 
|---|
| 571 | uint32_t bool_id = type_mgr->GetBoolTypeId(); | 
|---|
| 572 |  | 
|---|
| 573 | InstructionBuilder ir_builder( | 
|---|
| 574 | ctx, inst, | 
|---|
| 575 | IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); | 
|---|
| 576 |  | 
|---|
| 577 | uint32_t input_id = inst->GetSingleWordInOperand(2); | 
|---|
| 578 | uint32_t glsl405_ext_inst_id = | 
|---|
| 579 | ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450(); | 
|---|
| 580 | if (glsl405_ext_inst_id == 0) { | 
|---|
| 581 | ctx->AddExtInstImport( "GLSL.std.450"); | 
|---|
| 582 | glsl405_ext_inst_id = | 
|---|
| 583 | ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450(); | 
|---|
| 584 | } | 
|---|
| 585 |  | 
|---|
| 586 | // Get the constants that will be used. | 
|---|
| 587 | uint32_t f0_const_id = const_mgr->GetFloatConst(0.0); | 
|---|
| 588 | uint32_t f2_const_id = const_mgr->GetFloatConst(2.0); | 
|---|
| 589 | uint32_t f0_5_const_id = const_mgr->GetFloatConst(0.5); | 
|---|
| 590 | const analysis::Constant* vec_const = | 
|---|
| 591 | const_mgr->GetConstant(v2_float_type, {f0_5_const_id, f0_5_const_id}); | 
|---|
| 592 | uint32_t vec_const_id = | 
|---|
| 593 | const_mgr->GetDefiningInstruction(vec_const)->result_id(); | 
|---|
| 594 |  | 
|---|
| 595 | // Extract the input values. | 
|---|
| 596 | Instruction* x = ir_builder.AddCompositeExtract(float_type_id, input_id, {0}); | 
|---|
| 597 | Instruction* y = ir_builder.AddCompositeExtract(float_type_id, input_id, {1}); | 
|---|
| 598 | Instruction* z = ir_builder.AddCompositeExtract(float_type_id, input_id, {2}); | 
|---|
| 599 |  | 
|---|
| 600 | // Negate the input values. | 
|---|
| 601 | Instruction* nx = | 
|---|
| 602 | ir_builder.AddUnaryOp(float_type_id, SpvOpFNegate, x->result_id()); | 
|---|
| 603 | Instruction* ny = | 
|---|
| 604 | ir_builder.AddUnaryOp(float_type_id, SpvOpFNegate, y->result_id()); | 
|---|
| 605 | Instruction* nz = | 
|---|
| 606 | ir_builder.AddUnaryOp(float_type_id, SpvOpFNegate, z->result_id()); | 
|---|
| 607 |  | 
|---|
| 608 | // Get the abolsute values of the inputs. | 
|---|
| 609 | Instruction* ax = ir_builder.AddNaryExtendedInstruction( | 
|---|
| 610 | float_type_id, glsl405_ext_inst_id, GLSLstd450FAbs, {x->result_id()}); | 
|---|
| 611 | Instruction* ay = ir_builder.AddNaryExtendedInstruction( | 
|---|
| 612 | float_type_id, glsl405_ext_inst_id, GLSLstd450FAbs, {y->result_id()}); | 
|---|
| 613 | Instruction* az = ir_builder.AddNaryExtendedInstruction( | 
|---|
| 614 | float_type_id, glsl405_ext_inst_id, GLSLstd450FAbs, {z->result_id()}); | 
|---|
| 615 |  | 
|---|
| 616 | // Find which values are negative.  Used in later computations. | 
|---|
| 617 | Instruction* is_z_neg = ir_builder.AddBinaryOp(bool_id, SpvOpFOrdLessThan, | 
|---|
| 618 | z->result_id(), f0_const_id); | 
|---|
| 619 | Instruction* is_y_neg = ir_builder.AddBinaryOp(bool_id, SpvOpFOrdLessThan, | 
|---|
| 620 | y->result_id(), f0_const_id); | 
|---|
| 621 | Instruction* is_x_neg = ir_builder.AddBinaryOp(bool_id, SpvOpFOrdLessThan, | 
|---|
| 622 | x->result_id(), f0_const_id); | 
|---|
| 623 |  | 
|---|
| 624 | // Compute cubema | 
|---|
| 625 | Instruction* amax_x_y = ir_builder.AddNaryExtendedInstruction( | 
|---|
| 626 | float_type_id, glsl405_ext_inst_id, GLSLstd450FMax, | 
|---|
| 627 | {ax->result_id(), ay->result_id()}); | 
|---|
| 628 | Instruction* amax = ir_builder.AddNaryExtendedInstruction( | 
|---|
| 629 | float_type_id, glsl405_ext_inst_id, GLSLstd450FMax, | 
|---|
| 630 | {az->result_id(), amax_x_y->result_id()}); | 
|---|
| 631 | Instruction* cubema = ir_builder.AddBinaryOp(float_type_id, SpvOpFMul, | 
|---|
| 632 | f2_const_id, amax->result_id()); | 
|---|
| 633 |  | 
|---|
| 634 | // Do the comparisons needed for computing cubesc and cubetc. | 
|---|
| 635 | Instruction* is_z_max = | 
|---|
| 636 | ir_builder.AddBinaryOp(bool_id, SpvOpFOrdGreaterThanEqual, | 
|---|
| 637 | az->result_id(), amax_x_y->result_id()); | 
|---|
| 638 | Instruction* not_is_z_max = | 
|---|
| 639 | ir_builder.AddUnaryOp(bool_id, SpvOpLogicalNot, is_z_max->result_id()); | 
|---|
| 640 | Instruction* y_gr_x = ir_builder.AddBinaryOp( | 
|---|
| 641 | bool_id, SpvOpFOrdGreaterThanEqual, ay->result_id(), ax->result_id()); | 
|---|
| 642 | Instruction* is_y_max = ir_builder.AddBinaryOp( | 
|---|
| 643 | bool_id, SpvOpLogicalAnd, not_is_z_max->result_id(), y_gr_x->result_id()); | 
|---|
| 644 |  | 
|---|
| 645 | // Select the correct value for cubesc. | 
|---|
| 646 | Instruction* cubesc_case_1 = ir_builder.AddSelect( | 
|---|
| 647 | float_type_id, is_z_neg->result_id(), nx->result_id(), x->result_id()); | 
|---|
| 648 | Instruction* cubesc_case_2 = ir_builder.AddSelect( | 
|---|
| 649 | float_type_id, is_x_neg->result_id(), z->result_id(), nz->result_id()); | 
|---|
| 650 | Instruction* sel = | 
|---|
| 651 | ir_builder.AddSelect(float_type_id, is_y_max->result_id(), x->result_id(), | 
|---|
| 652 | cubesc_case_2->result_id()); | 
|---|
| 653 | Instruction* cubesc = | 
|---|
| 654 | ir_builder.AddSelect(float_type_id, is_z_max->result_id(), | 
|---|
| 655 | cubesc_case_1->result_id(), sel->result_id()); | 
|---|
| 656 |  | 
|---|
| 657 | // Select the correct value for cubetc. | 
|---|
| 658 | Instruction* cubetc_case_1 = ir_builder.AddSelect( | 
|---|
| 659 | float_type_id, is_y_neg->result_id(), nz->result_id(), z->result_id()); | 
|---|
| 660 | Instruction* cubetc = | 
|---|
| 661 | ir_builder.AddSelect(float_type_id, is_y_max->result_id(), | 
|---|
| 662 | cubetc_case_1->result_id(), ny->result_id()); | 
|---|
| 663 |  | 
|---|
| 664 | // Do the division | 
|---|
| 665 | Instruction* cube = ir_builder.AddCompositeConstruct( | 
|---|
| 666 | v2_float_type_id, {cubesc->result_id(), cubetc->result_id()}); | 
|---|
| 667 | Instruction* denom = ir_builder.AddCompositeConstruct( | 
|---|
| 668 | v2_float_type_id, {cubema->result_id(), cubema->result_id()}); | 
|---|
| 669 | Instruction* div = ir_builder.AddBinaryOp( | 
|---|
| 670 | v2_float_type_id, SpvOpFDiv, cube->result_id(), denom->result_id()); | 
|---|
| 671 |  | 
|---|
| 672 | // Get the final result by adding 0.5 to |div|. | 
|---|
| 673 | inst->SetOpcode(SpvOpFAdd); | 
|---|
| 674 | Instruction::OperandList new_operands; | 
|---|
| 675 | new_operands.push_back({SPV_OPERAND_TYPE_ID, {div->result_id()}}); | 
|---|
| 676 | new_operands.push_back({SPV_OPERAND_TYPE_ID, {vec_const_id}}); | 
|---|
| 677 |  | 
|---|
| 678 | inst->SetInOperands(std::move(new_operands)); | 
|---|
| 679 | ctx->UpdateDefUse(inst); | 
|---|
| 680 | return true; | 
|---|
| 681 | } | 
|---|
| 682 |  | 
|---|
| 683 | // A folding rule that will replace the CubeFaceIndexAMD extended | 
|---|
| 684 | // instruction in the SPV_AMD_gcn_shader_ballot.  Returns true if the folding | 
|---|
| 685 | // is successful. | 
|---|
| 686 | // | 
|---|
| 687 | // The instruction | 
|---|
| 688 | // | 
|---|
| 689 | //  %result = OpExtInst %float %1 CubeFaceIndexAMD %input | 
|---|
| 690 | // | 
|---|
| 691 | // with | 
|---|
| 692 | // | 
|---|
| 693 | //             %x = OpCompositeExtract %float %input 0 | 
|---|
| 694 | //             %y = OpCompositeExtract %float %input 1 | 
|---|
| 695 | //             %z = OpCompositeExtract %float %input 2 | 
|---|
| 696 | //            %ax = OpExtInst %float %n_1 FAbs %x | 
|---|
| 697 | //            %ay = OpExtInst %float %n_1 FAbs %y | 
|---|
| 698 | //            %az = OpExtInst %float %n_1 FAbs %z | 
|---|
| 699 | //      %is_z_neg = OpFOrdLessThan %bool %z %float_0 | 
|---|
| 700 | //      %is_y_neg = OpFOrdLessThan %bool %y %float_0 | 
|---|
| 701 | //      %is_x_neg = OpFOrdLessThan %bool %x %float_0 | 
|---|
| 702 | //      %amax_x_y = OpExtInst %float %n_1 FMax %ax %ay | 
|---|
| 703 | //      %is_z_max = OpFOrdGreaterThanEqual %bool %az %amax_x_y | 
|---|
| 704 | //        %y_gt_x = OpFOrdGreaterThanEqual %bool %ay %ax | 
|---|
| 705 | //        %case_z = OpSelect %float %is_z_neg %float_5 %float4 | 
|---|
| 706 | //        %case_y = OpSelect %float %is_y_neg %float_3 %float2 | 
|---|
| 707 | //        %case_x = OpSelect %float %is_x_neg %float_1 %float0 | 
|---|
| 708 | //           %sel = OpSelect %float %y_gt_x %case_y %case_x | 
|---|
| 709 | //        %result = OpSelect %float %is_z_max %case_z %sel | 
|---|
| 710 | // | 
|---|
| 711 | // Also adding the capabilities and builtins that are needed. | 
|---|
| 712 | bool ReplaceCubeFaceIndex(IRContext* ctx, Instruction* inst, | 
|---|
| 713 | const std::vector<const analysis::Constant*>&) { | 
|---|
| 714 | analysis::TypeManager* type_mgr = ctx->get_type_mgr(); | 
|---|
| 715 | analysis::ConstantManager* const_mgr = ctx->get_constant_mgr(); | 
|---|
| 716 |  | 
|---|
| 717 | uint32_t float_type_id = type_mgr->GetFloatTypeId(); | 
|---|
| 718 | uint32_t bool_id = type_mgr->GetBoolTypeId(); | 
|---|
| 719 |  | 
|---|
| 720 | InstructionBuilder ir_builder( | 
|---|
| 721 | ctx, inst, | 
|---|
| 722 | IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); | 
|---|
| 723 |  | 
|---|
| 724 | uint32_t input_id = inst->GetSingleWordInOperand(2); | 
|---|
| 725 | uint32_t glsl405_ext_inst_id = | 
|---|
| 726 | ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450(); | 
|---|
| 727 | if (glsl405_ext_inst_id == 0) { | 
|---|
| 728 | ctx->AddExtInstImport( "GLSL.std.450"); | 
|---|
| 729 | glsl405_ext_inst_id = | 
|---|
| 730 | ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450(); | 
|---|
| 731 | } | 
|---|
| 732 |  | 
|---|
| 733 | // Get the constants that will be used. | 
|---|
| 734 | uint32_t f0_const_id = const_mgr->GetFloatConst(0.0); | 
|---|
| 735 | uint32_t f1_const_id = const_mgr->GetFloatConst(1.0); | 
|---|
| 736 | uint32_t f2_const_id = const_mgr->GetFloatConst(2.0); | 
|---|
| 737 | uint32_t f3_const_id = const_mgr->GetFloatConst(3.0); | 
|---|
| 738 | uint32_t f4_const_id = const_mgr->GetFloatConst(4.0); | 
|---|
| 739 | uint32_t f5_const_id = const_mgr->GetFloatConst(5.0); | 
|---|
| 740 |  | 
|---|
| 741 | // Extract the input values. | 
|---|
| 742 | Instruction* x = ir_builder.AddCompositeExtract(float_type_id, input_id, {0}); | 
|---|
| 743 | Instruction* y = ir_builder.AddCompositeExtract(float_type_id, input_id, {1}); | 
|---|
| 744 | Instruction* z = ir_builder.AddCompositeExtract(float_type_id, input_id, {2}); | 
|---|
| 745 |  | 
|---|
| 746 | // Get the absolute values of the inputs. | 
|---|
| 747 | Instruction* ax = ir_builder.AddNaryExtendedInstruction( | 
|---|
| 748 | float_type_id, glsl405_ext_inst_id, GLSLstd450FAbs, {x->result_id()}); | 
|---|
| 749 | Instruction* ay = ir_builder.AddNaryExtendedInstruction( | 
|---|
| 750 | float_type_id, glsl405_ext_inst_id, GLSLstd450FAbs, {y->result_id()}); | 
|---|
| 751 | Instruction* az = ir_builder.AddNaryExtendedInstruction( | 
|---|
| 752 | float_type_id, glsl405_ext_inst_id, GLSLstd450FAbs, {z->result_id()}); | 
|---|
| 753 |  | 
|---|
| 754 | // Find which values are negative.  Used in later computations. | 
|---|
| 755 | Instruction* is_z_neg = ir_builder.AddBinaryOp(bool_id, SpvOpFOrdLessThan, | 
|---|
| 756 | z->result_id(), f0_const_id); | 
|---|
| 757 | Instruction* is_y_neg = ir_builder.AddBinaryOp(bool_id, SpvOpFOrdLessThan, | 
|---|
| 758 | y->result_id(), f0_const_id); | 
|---|
| 759 | Instruction* is_x_neg = ir_builder.AddBinaryOp(bool_id, SpvOpFOrdLessThan, | 
|---|
| 760 | x->result_id(), f0_const_id); | 
|---|
| 761 |  | 
|---|
| 762 | // Find the max value. | 
|---|
| 763 | Instruction* amax_x_y = ir_builder.AddNaryExtendedInstruction( | 
|---|
| 764 | float_type_id, glsl405_ext_inst_id, GLSLstd450FMax, | 
|---|
| 765 | {ax->result_id(), ay->result_id()}); | 
|---|
| 766 | Instruction* is_z_max = | 
|---|
| 767 | ir_builder.AddBinaryOp(bool_id, SpvOpFOrdGreaterThanEqual, | 
|---|
| 768 | az->result_id(), amax_x_y->result_id()); | 
|---|
| 769 | Instruction* y_gr_x = ir_builder.AddBinaryOp( | 
|---|
| 770 | bool_id, SpvOpFOrdGreaterThanEqual, ay->result_id(), ax->result_id()); | 
|---|
| 771 |  | 
|---|
| 772 | // Get the value for each case. | 
|---|
| 773 | Instruction* case_z = ir_builder.AddSelect( | 
|---|
| 774 | float_type_id, is_z_neg->result_id(), f5_const_id, f4_const_id); | 
|---|
| 775 | Instruction* case_y = ir_builder.AddSelect( | 
|---|
| 776 | float_type_id, is_y_neg->result_id(), f3_const_id, f2_const_id); | 
|---|
| 777 | Instruction* case_x = ir_builder.AddSelect( | 
|---|
| 778 | float_type_id, is_x_neg->result_id(), f1_const_id, f0_const_id); | 
|---|
| 779 |  | 
|---|
| 780 | // Select the correct case. | 
|---|
| 781 | Instruction* sel = | 
|---|
| 782 | ir_builder.AddSelect(float_type_id, y_gr_x->result_id(), | 
|---|
| 783 | case_y->result_id(), case_x->result_id()); | 
|---|
| 784 |  | 
|---|
| 785 | // Get the final result by adding 0.5 to |div|. | 
|---|
| 786 | inst->SetOpcode(SpvOpSelect); | 
|---|
| 787 | Instruction::OperandList new_operands; | 
|---|
| 788 | new_operands.push_back({SPV_OPERAND_TYPE_ID, {is_z_max->result_id()}}); | 
|---|
| 789 | new_operands.push_back({SPV_OPERAND_TYPE_ID, {case_z->result_id()}}); | 
|---|
| 790 | new_operands.push_back({SPV_OPERAND_TYPE_ID, {sel->result_id()}}); | 
|---|
| 791 |  | 
|---|
| 792 | inst->SetInOperands(std::move(new_operands)); | 
|---|
| 793 | ctx->UpdateDefUse(inst); | 
|---|
| 794 | return true; | 
|---|
| 795 | } | 
|---|
| 796 |  | 
|---|
| 797 | // A folding rule that will replace the TimeAMD extended instruction in the | 
|---|
| 798 | // SPV_AMD_gcn_shader_ballot.  It returns true if the folding is successful. | 
|---|
| 799 | // It returns False, otherwise. | 
|---|
| 800 | // | 
|---|
| 801 | // The instruction | 
|---|
| 802 | // | 
|---|
| 803 | //  %result = OpExtInst %uint64 %1 TimeAMD | 
|---|
| 804 | // | 
|---|
| 805 | // with | 
|---|
| 806 | // | 
|---|
| 807 | //  %result = OpReadClockKHR %uint64 %uint_3 | 
|---|
| 808 | // | 
|---|
| 809 | // NOTE: TimeAMD uses subgroup scope (it is not a real time clock). | 
|---|
| 810 | bool ReplaceTimeAMD(IRContext* ctx, Instruction* inst, | 
|---|
| 811 | const std::vector<const analysis::Constant*>&) { | 
|---|
| 812 | InstructionBuilder ir_builder( | 
|---|
| 813 | ctx, inst, | 
|---|
| 814 | IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); | 
|---|
| 815 | ctx->AddExtension( "SPV_KHR_shader_clock"); | 
|---|
| 816 | ctx->AddCapability(SpvCapabilityShaderClockKHR); | 
|---|
| 817 |  | 
|---|
| 818 | inst->SetOpcode(SpvOpReadClockKHR); | 
|---|
| 819 | Instruction::OperandList args; | 
|---|
| 820 | uint32_t subgroup_scope_id = ir_builder.GetUintConstantId(SpvScopeSubgroup); | 
|---|
| 821 | args.push_back({SPV_OPERAND_TYPE_ID, {subgroup_scope_id}}); | 
|---|
| 822 | inst->SetInOperands(std::move(args)); | 
|---|
| 823 | ctx->UpdateDefUse(inst); | 
|---|
| 824 |  | 
|---|
| 825 | return true; | 
|---|
| 826 | } | 
|---|
| 827 |  | 
|---|
| 828 | class AmdExtFoldingRules : public FoldingRules { | 
|---|
| 829 | public: | 
|---|
| 830 | explicit AmdExtFoldingRules(IRContext* ctx) : FoldingRules(ctx) {} | 
|---|
| 831 |  | 
|---|
| 832 | protected: | 
|---|
| 833 | virtual void AddFoldingRules() override { | 
|---|
| 834 | rules_[SpvOpGroupIAddNonUniformAMD].push_back( | 
|---|
| 835 | ReplaceGroupNonuniformOperationOpCode<SpvOpGroupNonUniformIAdd>); | 
|---|
| 836 | rules_[SpvOpGroupFAddNonUniformAMD].push_back( | 
|---|
| 837 | ReplaceGroupNonuniformOperationOpCode<SpvOpGroupNonUniformFAdd>); | 
|---|
| 838 | rules_[SpvOpGroupUMinNonUniformAMD].push_back( | 
|---|
| 839 | ReplaceGroupNonuniformOperationOpCode<SpvOpGroupNonUniformUMin>); | 
|---|
| 840 | rules_[SpvOpGroupSMinNonUniformAMD].push_back( | 
|---|
| 841 | ReplaceGroupNonuniformOperationOpCode<SpvOpGroupNonUniformSMin>); | 
|---|
| 842 | rules_[SpvOpGroupFMinNonUniformAMD].push_back( | 
|---|
| 843 | ReplaceGroupNonuniformOperationOpCode<SpvOpGroupNonUniformFMin>); | 
|---|
| 844 | rules_[SpvOpGroupUMaxNonUniformAMD].push_back( | 
|---|
| 845 | ReplaceGroupNonuniformOperationOpCode<SpvOpGroupNonUniformUMax>); | 
|---|
| 846 | rules_[SpvOpGroupSMaxNonUniformAMD].push_back( | 
|---|
| 847 | ReplaceGroupNonuniformOperationOpCode<SpvOpGroupNonUniformSMax>); | 
|---|
| 848 | rules_[SpvOpGroupFMaxNonUniformAMD].push_back( | 
|---|
| 849 | ReplaceGroupNonuniformOperationOpCode<SpvOpGroupNonUniformFMax>); | 
|---|
| 850 |  | 
|---|
| 851 | uint32_t extension_id = | 
|---|
| 852 | context()->module()->GetExtInstImportId( "SPV_AMD_shader_ballot"); | 
|---|
| 853 |  | 
|---|
| 854 | if (extension_id != 0) { | 
|---|
| 855 | ext_rules_[{extension_id, AmdShaderBallotSwizzleInvocationsAMD}] | 
|---|
| 856 | .push_back(ReplaceSwizzleInvocations); | 
|---|
| 857 | ext_rules_[{extension_id, AmdShaderBallotSwizzleInvocationsMaskedAMD}] | 
|---|
| 858 | .push_back(ReplaceSwizzleInvocationsMasked); | 
|---|
| 859 | ext_rules_[{extension_id, AmdShaderBallotWriteInvocationAMD}].push_back( | 
|---|
| 860 | ReplaceWriteInvocation); | 
|---|
| 861 | ext_rules_[{extension_id, AmdShaderBallotMbcntAMD}].push_back( | 
|---|
| 862 | ReplaceMbcnt); | 
|---|
| 863 | } | 
|---|
| 864 |  | 
|---|
| 865 | extension_id = context()->module()->GetExtInstImportId( | 
|---|
| 866 | "SPV_AMD_shader_trinary_minmax"); | 
|---|
| 867 |  | 
|---|
| 868 | if (extension_id != 0) { | 
|---|
| 869 | ext_rules_[{extension_id, FMin3AMD}].push_back( | 
|---|
| 870 | ReplaceTrinaryMinMax<GLSLstd450FMin>); | 
|---|
| 871 | ext_rules_[{extension_id, UMin3AMD}].push_back( | 
|---|
| 872 | ReplaceTrinaryMinMax<GLSLstd450UMin>); | 
|---|
| 873 | ext_rules_[{extension_id, SMin3AMD}].push_back( | 
|---|
| 874 | ReplaceTrinaryMinMax<GLSLstd450SMin>); | 
|---|
| 875 | ext_rules_[{extension_id, FMax3AMD}].push_back( | 
|---|
| 876 | ReplaceTrinaryMinMax<GLSLstd450FMax>); | 
|---|
| 877 | ext_rules_[{extension_id, UMax3AMD}].push_back( | 
|---|
| 878 | ReplaceTrinaryMinMax<GLSLstd450UMax>); | 
|---|
| 879 | ext_rules_[{extension_id, SMax3AMD}].push_back( | 
|---|
| 880 | ReplaceTrinaryMinMax<GLSLstd450SMax>); | 
|---|
| 881 | ext_rules_[{extension_id, FMid3AMD}].push_back( | 
|---|
| 882 | ReplaceTrinaryMid<GLSLstd450FMin, GLSLstd450FMax, GLSLstd450FClamp>); | 
|---|
| 883 | ext_rules_[{extension_id, UMid3AMD}].push_back( | 
|---|
| 884 | ReplaceTrinaryMid<GLSLstd450UMin, GLSLstd450UMax, GLSLstd450UClamp>); | 
|---|
| 885 | ext_rules_[{extension_id, SMid3AMD}].push_back( | 
|---|
| 886 | ReplaceTrinaryMid<GLSLstd450SMin, GLSLstd450SMax, GLSLstd450SClamp>); | 
|---|
| 887 | } | 
|---|
| 888 |  | 
|---|
| 889 | extension_id = | 
|---|
| 890 | context()->module()->GetExtInstImportId( "SPV_AMD_gcn_shader"); | 
|---|
| 891 |  | 
|---|
| 892 | if (extension_id != 0) { | 
|---|
| 893 | ext_rules_[{extension_id, CubeFaceCoordAMD}].push_back( | 
|---|
| 894 | ReplaceCubeFaceCoord); | 
|---|
| 895 | ext_rules_[{extension_id, CubeFaceIndexAMD}].push_back( | 
|---|
| 896 | ReplaceCubeFaceIndex); | 
|---|
| 897 | ext_rules_[{extension_id, TimeAMD}].push_back(ReplaceTimeAMD); | 
|---|
| 898 | } | 
|---|
| 899 | } | 
|---|
| 900 | }; | 
|---|
| 901 |  | 
|---|
| 902 | class AmdExtConstFoldingRules : public ConstantFoldingRules { | 
|---|
| 903 | public: | 
|---|
| 904 | AmdExtConstFoldingRules(IRContext* ctx) : ConstantFoldingRules(ctx) {} | 
|---|
| 905 |  | 
|---|
| 906 | protected: | 
|---|
| 907 | virtual void AddFoldingRules() override {} | 
|---|
| 908 | }; | 
|---|
| 909 |  | 
|---|
| 910 | }  // namespace | 
|---|
| 911 |  | 
|---|
| 912 | Pass::Status AmdExtensionToKhrPass::Process() { | 
|---|
| 913 | bool changed = false; | 
|---|
| 914 |  | 
|---|
| 915 | // Traverse the body of the functions to replace instructions that require | 
|---|
| 916 | // the extensions. | 
|---|
| 917 | InstructionFolder folder( | 
|---|
| 918 | context(), | 
|---|
| 919 | std::unique_ptr<AmdExtFoldingRules>(new AmdExtFoldingRules(context())), | 
|---|
| 920 | MakeUnique<AmdExtConstFoldingRules>(context())); | 
|---|
| 921 | for (Function& func : *get_module()) { | 
|---|
| 922 | func.ForEachInst([&changed, &folder](Instruction* inst) { | 
|---|
| 923 | if (folder.FoldInstruction(inst)) { | 
|---|
| 924 | changed = true; | 
|---|
| 925 | } | 
|---|
| 926 | }); | 
|---|
| 927 | } | 
|---|
| 928 |  | 
|---|
| 929 | // Now that instruction that require the extensions have been removed, we can | 
|---|
| 930 | // remove the extension instructions. | 
|---|
| 931 | std::set<std::string> ext_to_remove = { "SPV_AMD_shader_ballot", | 
|---|
| 932 | "SPV_AMD_shader_trinary_minmax", | 
|---|
| 933 | "SPV_AMD_gcn_shader"}; | 
|---|
| 934 |  | 
|---|
| 935 | std::vector<Instruction*> to_be_killed; | 
|---|
| 936 | for (Instruction& inst : context()->module()->extensions()) { | 
|---|
| 937 | if (inst.opcode() == SpvOpExtension) { | 
|---|
| 938 | if (ext_to_remove.count(reinterpret_cast<const char*>( | 
|---|
| 939 | &(inst.GetInOperand(0).words[0]))) != 0) { | 
|---|
| 940 | to_be_killed.push_back(&inst); | 
|---|
| 941 | } | 
|---|
| 942 | } | 
|---|
| 943 | } | 
|---|
| 944 |  | 
|---|
| 945 | for (Instruction& inst : context()->ext_inst_imports()) { | 
|---|
| 946 | if (inst.opcode() == SpvOpExtInstImport) { | 
|---|
| 947 | if (ext_to_remove.count(reinterpret_cast<const char*>( | 
|---|
| 948 | &(inst.GetInOperand(0).words[0]))) != 0) { | 
|---|
| 949 | to_be_killed.push_back(&inst); | 
|---|
| 950 | } | 
|---|
| 951 | } | 
|---|
| 952 | } | 
|---|
| 953 |  | 
|---|
| 954 | for (Instruction* inst : to_be_killed) { | 
|---|
| 955 | context()->KillInst(inst); | 
|---|
| 956 | changed = true; | 
|---|
| 957 | } | 
|---|
| 958 |  | 
|---|
| 959 | // The replacements that take place use instructions that are missing before | 
|---|
| 960 | // SPIR-V 1.3. If we changed something, we will have to make sure the version | 
|---|
| 961 | // is at least SPIR-V 1.3 to make sure those instruction can be used. | 
|---|
| 962 | if (changed) { | 
|---|
| 963 | uint32_t version = get_module()->version(); | 
|---|
| 964 | if (version < 0x00010300 /*1.3*/) { | 
|---|
| 965 | get_module()->set_version(0x00010300); | 
|---|
| 966 | } | 
|---|
| 967 | } | 
|---|
| 968 | return changed ? Status::SuccessWithChange : Status::SuccessWithoutChange; | 
|---|
| 969 | } | 
|---|
| 970 |  | 
|---|
| 971 | }  // namespace opt | 
|---|
| 972 | }  // namespace spvtools | 
|---|
| 973 |  | 
|---|