1 | // Copyright (c) 2019 Google LLC. |
2 | // |
3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | // you may not use this file except in compliance with the License. |
5 | // You may obtain a copy of the License at |
6 | // |
7 | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | // |
9 | // Unless required by applicable law or agreed to in writing, software |
10 | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | // See the License for the specific language governing permissions and |
13 | // limitations under the License. |
14 | |
15 | #include "source/opt/amd_ext_to_khr.h" |
16 | |
17 | #include <set> |
18 | #include <string> |
19 | |
20 | #include "ir_builder.h" |
21 | #include "source/opt/ir_context.h" |
22 | #include "spv-amd-shader-ballot.insts.inc" |
23 | #include "type_manager.h" |
24 | |
25 | namespace spvtools { |
26 | namespace opt { |
27 | |
28 | namespace { |
29 | |
// Extended-instruction numbers of the SPV_AMD_shader_ballot instruction set
// handled by this pass. The values are the instruction numbers that appear as
// the second in-operand of an OpExtInst that uses this set.
enum AmdShaderBallotExtOpcodes {
  AmdShaderBallotSwizzleInvocationsAMD = 1,        // SwizzleInvocationsAMD
  AmdShaderBallotSwizzleInvocationsMaskedAMD = 2,  // SwizzleInvocationsMaskedAMD
  AmdShaderBallotWriteInvocationAMD = 3,           // WriteInvocationAMD
  AmdShaderBallotMbcntAMD = 4,                     // MbcntAMD
};
36 | |
// Extended-instruction numbers of the SPV_AMD_shader_trinary_minmax
// instruction set. Each three-operand operation comes in an F, U and S
// flavour, mirroring the GLSL.std.450 FMin/UMin/SMin naming.
enum AmdShaderTrinaryMinMaxExtOpCodes {
  // min3(a, b, c)
  FMin3AMD = 1,
  UMin3AMD = 2,
  SMin3AMD = 3,

  // max3(a, b, c)
  FMax3AMD = 4,
  UMax3AMD = 5,
  SMax3AMD = 6,

  // mid3(a, b, c): the median of the three operands.
  FMid3AMD = 7,
  UMid3AMD = 8,
  SMid3AMD = 9,
};
48 | |
// Extended-instruction numbers of the SPV_AMD_gcn_shader instruction set,
// listed in numeric order.
enum AmdGcnShader {
  CubeFaceIndexAMD = 1,
  CubeFaceCoordAMD = 2,
  TimeAMD = 3
};
50 | |
51 | analysis::Type* GetUIntType(IRContext* ctx) { |
52 | analysis::Integer int_type(32, false); |
53 | return ctx->get_type_mgr()->GetRegisteredType(&int_type); |
54 | } |
55 | |
56 | // Returns a folding rule that replaces |op(a,b,c)| by |op(op(a,b),c)|, where |
57 | // |op| is either min or max. |opcode| is the binary opcode in the GLSLstd450 |
58 | // extended instruction set that corresponds to the trinary instruction being |
59 | // replaced. |
60 | template <GLSLstd450 opcode> |
61 | bool ReplaceTrinaryMinMax(IRContext* ctx, Instruction* inst, |
62 | const std::vector<const analysis::Constant*>&) { |
63 | uint32_t glsl405_ext_inst_id = |
64 | ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450(); |
65 | if (glsl405_ext_inst_id == 0) { |
66 | ctx->AddExtInstImport("GLSL.std.450" ); |
67 | glsl405_ext_inst_id = |
68 | ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450(); |
69 | } |
70 | |
71 | InstructionBuilder ir_builder( |
72 | ctx, inst, |
73 | IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); |
74 | |
75 | uint32_t op1 = inst->GetSingleWordInOperand(2); |
76 | uint32_t op2 = inst->GetSingleWordInOperand(3); |
77 | uint32_t op3 = inst->GetSingleWordInOperand(4); |
78 | |
79 | Instruction* temp = ir_builder.AddNaryExtendedInstruction( |
80 | inst->type_id(), glsl405_ext_inst_id, opcode, {op1, op2}); |
81 | |
82 | Instruction::OperandList new_operands; |
83 | new_operands.push_back({SPV_OPERAND_TYPE_ID, {glsl405_ext_inst_id}}); |
84 | new_operands.push_back({SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER, |
85 | {static_cast<uint32_t>(opcode)}}); |
86 | new_operands.push_back({SPV_OPERAND_TYPE_ID, {temp->result_id()}}); |
87 | new_operands.push_back({SPV_OPERAND_TYPE_ID, {op3}}); |
88 | |
89 | inst->SetInOperands(std::move(new_operands)); |
90 | ctx->UpdateDefUse(inst); |
91 | return true; |
92 | } |
93 | |
94 | // Returns a folding rule that replaces |mid(a,b,c)| by |clamp(a, min(b,c), |
95 | // max(b,c)|. The three parameters are the opcode that correspond to the min, |
96 | // max, and clamp operations for the type of the instruction being replaced. |
97 | template <GLSLstd450 min_opcode, GLSLstd450 max_opcode, GLSLstd450 clamp_opcode> |
98 | bool ReplaceTrinaryMid(IRContext* ctx, Instruction* inst, |
99 | const std::vector<const analysis::Constant*>&) { |
100 | uint32_t glsl405_ext_inst_id = |
101 | ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450(); |
102 | if (glsl405_ext_inst_id == 0) { |
103 | ctx->AddExtInstImport("GLSL.std.450" ); |
104 | glsl405_ext_inst_id = |
105 | ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450(); |
106 | } |
107 | |
108 | InstructionBuilder ir_builder( |
109 | ctx, inst, |
110 | IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); |
111 | |
112 | uint32_t op1 = inst->GetSingleWordInOperand(2); |
113 | uint32_t op2 = inst->GetSingleWordInOperand(3); |
114 | uint32_t op3 = inst->GetSingleWordInOperand(4); |
115 | |
116 | Instruction* min = ir_builder.AddNaryExtendedInstruction( |
117 | inst->type_id(), glsl405_ext_inst_id, static_cast<uint32_t>(min_opcode), |
118 | {op2, op3}); |
119 | Instruction* max = ir_builder.AddNaryExtendedInstruction( |
120 | inst->type_id(), glsl405_ext_inst_id, static_cast<uint32_t>(max_opcode), |
121 | {op2, op3}); |
122 | |
123 | Instruction::OperandList new_operands; |
124 | new_operands.push_back({SPV_OPERAND_TYPE_ID, {glsl405_ext_inst_id}}); |
125 | new_operands.push_back({SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER, |
126 | {static_cast<uint32_t>(clamp_opcode)}}); |
127 | new_operands.push_back({SPV_OPERAND_TYPE_ID, {op1}}); |
128 | new_operands.push_back({SPV_OPERAND_TYPE_ID, {min->result_id()}}); |
129 | new_operands.push_back({SPV_OPERAND_TYPE_ID, {max->result_id()}}); |
130 | |
131 | inst->SetInOperands(std::move(new_operands)); |
132 | ctx->UpdateDefUse(inst); |
133 | return true; |
134 | } |
135 | |
136 | // Returns a folding rule that will replace the opcode with |opcode| and add |
137 | // the capabilities required. The folding rule assumes it is folding an |
138 | // OpGroup*NonUniformAMD instruction from the SPV_AMD_shader_ballot extension. |
139 | template <SpvOp new_opcode> |
140 | bool ReplaceGroupNonuniformOperationOpCode( |
141 | IRContext* ctx, Instruction* inst, |
142 | const std::vector<const analysis::Constant*>&) { |
143 | switch (new_opcode) { |
144 | case SpvOpGroupNonUniformIAdd: |
145 | case SpvOpGroupNonUniformFAdd: |
146 | case SpvOpGroupNonUniformUMin: |
147 | case SpvOpGroupNonUniformSMin: |
148 | case SpvOpGroupNonUniformFMin: |
149 | case SpvOpGroupNonUniformUMax: |
150 | case SpvOpGroupNonUniformSMax: |
151 | case SpvOpGroupNonUniformFMax: |
152 | break; |
153 | default: |
154 | assert( |
155 | false && |
156 | "Should be replacing with a group non uniform arithmetic operation." ); |
157 | } |
158 | |
159 | switch (inst->opcode()) { |
160 | case SpvOpGroupIAddNonUniformAMD: |
161 | case SpvOpGroupFAddNonUniformAMD: |
162 | case SpvOpGroupUMinNonUniformAMD: |
163 | case SpvOpGroupSMinNonUniformAMD: |
164 | case SpvOpGroupFMinNonUniformAMD: |
165 | case SpvOpGroupUMaxNonUniformAMD: |
166 | case SpvOpGroupSMaxNonUniformAMD: |
167 | case SpvOpGroupFMaxNonUniformAMD: |
168 | break; |
169 | default: |
170 | assert(false && |
171 | "Should be replacing a group non uniform arithmetic operation." ); |
172 | } |
173 | |
174 | ctx->AddCapability(SpvCapabilityGroupNonUniformArithmetic); |
175 | inst->SetOpcode(new_opcode); |
176 | return true; |
177 | } |
178 | |
// Folding rule that replaces the SwizzleInvocationsAMD extended instruction
// from the SPV_AMD_shader_ballot extension with core group non-uniform
// operations. Always returns true.
//
// The instruction
//
//   %offset = OpConstantComposite %v4uint %x %y %z %w
//   %result = OpExtInst %type %1 SwizzleInvocationsAMD %data %offset
//
// is replaced with
//
// potentially new constants and types
//
// clang-format off
//         %uint_max = OpConstant %uint 0xFFFFFFFF
//           %v4uint = OpTypeVector %uint 4
//     %ballot_value = OpConstantComposite %v4uint %uint_max %uint_max %uint_max %uint_max
//             %null = OpConstantNull %type
// clang-format on
//
// and the following code in the function body
//
// clang-format off
//         %id = OpLoad %uint %SubgroupLocalInvocationId
//   %quad_idx = OpBitwiseAnd %uint %id %uint_3
//   %quad_ldr = OpBitwiseXor %uint %id %quad_idx
//  %my_offset = OpVectorExtractDynamic %uint %offset %quad_idx
// %target_inv = OpIAdd %uint %quad_ldr %my_offset
//  %is_active = OpGroupNonUniformBallotBitExtract %bool %uint_3 %ballot_value %target_inv
//    %shuffle = OpGroupNonUniformShuffle %type %uint_3 %data %target_inv
//     %result = OpSelect %type %is_active %shuffle %null
// clang-format on
//
// In the two group operations, %uint_3 is the Subgroup scope
// (SpvScopeSubgroup).
//
// Also adds the SPV_KHR_shader_ballot extension and the GroupNonUniformBallot
// and GroupNonUniformShuffle capabilities. The SubgroupLocalInvocationId
// builtin input variable is obtained through GetBuiltinInputVarId.
//
// NOTE(review): %ballot_value is a constant with every bit set, not a ballot
// of the currently active invocations. %is_active therefore does not appear
// to reflect whether the target invocation is active. Confirm that this is
// the intended semantics.
bool ReplaceSwizzleInvocations(IRContext* ctx, Instruction* inst,
                               const std::vector<const analysis::Constant*>&) {
  analysis::TypeManager* type_mgr = ctx->get_type_mgr();
  analysis::ConstantManager* const_mgr = ctx->get_constant_mgr();

  ctx->AddExtension("SPV_KHR_shader_ballot");
  ctx->AddCapability(SpvCapabilityGroupNonUniformBallot);
  ctx->AddCapability(SpvCapabilityGroupNonUniformShuffle);

  InstructionBuilder ir_builder(
      ctx, inst,
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);

  // In-operands 0 and 1 are the extended set and the instruction number.
  uint32_t data_id = inst->GetSingleWordInOperand(2);
  uint32_t offset_id = inst->GetSingleWordInOperand(3);

  // Get the subgroup invocation id.
  uint32_t var_id =
      ctx->GetBuiltinInputVarId(SpvBuiltInSubgroupLocalInvocationId);
  assert(var_id != 0 && "Could not get SubgroupLocalInvocationId variable.");
  Instruction* var_inst = ctx->get_def_use_mgr()->GetDef(var_id);
  Instruction* var_ptr_type =
      ctx->get_def_use_mgr()->GetDef(var_inst->type_id());
  // In-operand 1 of the OpTypePointer is the pointee type, i.e. the type of
  // the value loaded from the variable below.
  uint32_t uint_type_id = var_ptr_type->GetSingleWordInOperand(1);

  Instruction* id = ir_builder.AddLoad(uint_type_id, var_id);

  uint32_t quad_mask = ir_builder.GetUintConstantId(3);

  // This gives the offset in the group of 4 of this invocation.
  Instruction* quad_idx = ir_builder.AddBinaryOp(uint_type_id, SpvOpBitwiseAnd,
                                                 id->result_id(), quad_mask);

  // Get the invocation id of the first invocation in the group of 4:
  // id ^ (id & 3) clears the two low bits, i.e. it equals id & ~3.
  Instruction* quad_ldr = ir_builder.AddBinaryOp(
      uint_type_id, SpvOpBitwiseXor, id->result_id(), quad_idx->result_id());

  // Get the offset of the target invocation from the offset vector.
  Instruction* my_offset =
      ir_builder.AddBinaryOp(uint_type_id, SpvOpVectorExtractDynamic, offset_id,
                             quad_idx->result_id());

  // Determine the index of the invocation to read from.
  Instruction* target_inv = ir_builder.AddBinaryOp(
      uint_type_id, SpvOpIAdd, quad_ldr->result_id(), my_offset->result_id());

  // Do the group operations. The ballot value is a v4uint constant with all
  // 128 bits set.
  uint32_t uint_max_id = ir_builder.GetUintConstantId(0xFFFFFFFF);
  uint32_t subgroup_scope = ir_builder.GetUintConstantId(SpvScopeSubgroup);
  const auto* ballot_value_const = const_mgr->GetConstant(
      type_mgr->GetUIntVectorType(4),
      {uint_max_id, uint_max_id, uint_max_id, uint_max_id});
  Instruction* ballot_value =
      const_mgr->GetDefiningInstruction(ballot_value_const);
  Instruction* is_active = ir_builder.AddNaryOp(
      type_mgr->GetBoolTypeId(), SpvOpGroupNonUniformBallotBitExtract,
      {subgroup_scope, ballot_value->result_id(), target_inv->result_id()});
  Instruction* shuffle =
      ir_builder.AddNaryOp(inst->type_id(), SpvOpGroupNonUniformShuffle,
                           {subgroup_scope, data_id, target_inv->result_id()});

  // Create the null constant to use in the select. An empty literal list
  // requests the null constant of the result type.
  const auto* null = const_mgr->GetConstant(type_mgr->GetType(inst->type_id()),
                                            std::vector<uint32_t>());
  Instruction* null_inst = const_mgr->GetDefiningInstruction(null);

  // Build the select by reusing |inst|, so existing uses of its result id
  // stay valid.
  inst->SetOpcode(SpvOpSelect);
  Instruction::OperandList new_operands;
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {is_active->result_id()}});
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {shuffle->result_id()}});
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {null_inst->result_id()}});

  inst->SetInOperands(std::move(new_operands));
  ctx->UpdateDefUse(inst);
  return true;
}
289 | |
// Folding rule that replaces the SwizzleInvocationsMaskedAMD extended
// instruction from the SPV_AMD_shader_ballot extension with core group
// non-uniform operations. Always returns true.
//
// The instruction
//
//   %mask = OpConstantComposite %v3uint %uint_x %uint_y %uint_z
//   %result = OpExtInst %type %1 SwizzleInvocationsMaskedAMD %data %mask
//
// is replaced with
//
// potentially new constants and types
//
// clang-format off
//   %uint_mask_extend = OpConstant %uint 0xFFFFFFE0
//           %uint_max = OpConstant %uint 0xFFFFFFFF
//             %v4uint = OpTypeVector %uint 4
//       %ballot_value = OpConstantComposite %v4uint %uint_max %uint_max %uint_max %uint_max
//               %null = OpConstantNull %type
// clang-format on
//
// and the following code in the function body
//
// clang-format off
//         %id = OpLoad %uint %SubgroupLocalInvocationId
//   %and_mask = OpBitwiseOr %uint %uint_x %uint_mask_extend
//        %and = OpBitwiseAnd %uint %id %and_mask
//         %or = OpBitwiseOr %uint %and %uint_y
// %target_inv = OpBitwiseXor %uint %or %uint_z
//  %is_active = OpGroupNonUniformBallotBitExtract %bool %uint_3 %ballot_value %target_inv
//    %shuffle = OpGroupNonUniformShuffle %type %uint_3 %data %target_inv
//     %result = OpSelect %type %is_active %shuffle %null
// clang-format on
//
// %mask must be an OpConstantComposite with exactly three components; this is
// asserted in debug builds. In the group operations, %uint_3 is the Subgroup
// scope (SpvScopeSubgroup).
//
// Also adds the GroupNonUniformBallot and GroupNonUniformShuffle capabilities
// and the SPV_KHR_shader_ballot extension. The SubgroupLocalInvocationId
// builtin input variable is obtained through GetBuiltinInputVarId.
bool ReplaceSwizzleInvocationsMasked(
    IRContext* ctx, Instruction* inst,
    const std::vector<const analysis::Constant*>&) {
  analysis::TypeManager* type_mgr = ctx->get_type_mgr();
  analysis::DefUseManager* def_use_mgr = ctx->get_def_use_mgr();
  analysis::ConstantManager* const_mgr = ctx->get_constant_mgr();

  ctx->AddCapability(SpvCapabilityGroupNonUniformBallot);
  ctx->AddCapability(SpvCapabilityGroupNonUniformShuffle);

  InstructionBuilder ir_builder(
      ctx, inst,
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);

  // Get the operands to inst, and the components of the mask
  uint32_t data_id = inst->GetSingleWordInOperand(2);

  Instruction* mask_inst = def_use_mgr->GetDef(inst->GetSingleWordInOperand(3));
  assert(mask_inst->opcode() == SpvOpConstantComposite &&
         "The mask is suppose to be a vector constant.");
  assert(mask_inst->NumInOperands() == 3 &&
         "The mask is suppose to have 3 components.");

  // Ids of the three scalar constants that make up the mask: the AND, OR and
  // XOR masks, in that order.
  uint32_t uint_x = mask_inst->GetSingleWordInOperand(0);
  uint32_t uint_y = mask_inst->GetSingleWordInOperand(1);
  uint32_t uint_z = mask_inst->GetSingleWordInOperand(2);

  // Get the subgroup invocation id.
  uint32_t var_id =
      ctx->GetBuiltinInputVarId(SpvBuiltInSubgroupLocalInvocationId);
  ctx->AddExtension("SPV_KHR_shader_ballot");
  assert(var_id != 0 && "Could not get SubgroupLocalInvocationId variable.");
  Instruction* var_inst = ctx->get_def_use_mgr()->GetDef(var_id);
  Instruction* var_ptr_type =
      ctx->get_def_use_mgr()->GetDef(var_inst->type_id());
  // In-operand 1 of the OpTypePointer is the pointee (loaded) type.
  uint32_t uint_type_id = var_ptr_type->GetSingleWordInOperand(1);

  Instruction* id = ir_builder.AddLoad(uint_type_id, var_id);

  // Do the bitwise operations: target = ((id & (x | 0xFFFFFFE0)) | y) ^ z.
  // OR-ing 0xFFFFFFE0 into the AND mask lets bits 5 and above of the
  // invocation id pass through the AND unchanged, so the AND only affects the
  // low five bits.
  uint32_t mask_extended = ir_builder.GetUintConstantId(0xFFFFFFE0);
  Instruction* and_mask = ir_builder.AddBinaryOp(uint_type_id, SpvOpBitwiseOr,
                                                 uint_x, mask_extended);
  Instruction* and_result = ir_builder.AddBinaryOp(
      uint_type_id, SpvOpBitwiseAnd, id->result_id(), and_mask->result_id());
  Instruction* or_result = ir_builder.AddBinaryOp(
      uint_type_id, SpvOpBitwiseOr, and_result->result_id(), uint_y);
  Instruction* target_inv = ir_builder.AddBinaryOp(
      uint_type_id, SpvOpBitwiseXor, or_result->result_id(), uint_z);

  // Do the group operations. The ballot value is a v4uint constant with all
  // 128 bits set.
  uint32_t uint_max_id = ir_builder.GetUintConstantId(0xFFFFFFFF);
  uint32_t subgroup_scope = ir_builder.GetUintConstantId(SpvScopeSubgroup);
  const auto* ballot_value_const = const_mgr->GetConstant(
      type_mgr->GetUIntVectorType(4),
      {uint_max_id, uint_max_id, uint_max_id, uint_max_id});
  Instruction* ballot_value =
      const_mgr->GetDefiningInstruction(ballot_value_const);
  Instruction* is_active = ir_builder.AddNaryOp(
      type_mgr->GetBoolTypeId(), SpvOpGroupNonUniformBallotBitExtract,
      {subgroup_scope, ballot_value->result_id(), target_inv->result_id()});
  Instruction* shuffle =
      ir_builder.AddNaryOp(inst->type_id(), SpvOpGroupNonUniformShuffle,
                           {subgroup_scope, data_id, target_inv->result_id()});

  // Create the null constant to use in the select. An empty literal list
  // requests the null constant of the result type.
  const auto* null = const_mgr->GetConstant(type_mgr->GetType(inst->type_id()),
                                            std::vector<uint32_t>());
  Instruction* null_inst = const_mgr->GetDefiningInstruction(null);

  // Build the select by reusing |inst|, so existing uses of its result id
  // stay valid.
  inst->SetOpcode(SpvOpSelect);
  Instruction::OperandList new_operands;
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {is_active->result_id()}});
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {shuffle->result_id()}});
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {null_inst->result_id()}});

  inst->SetInOperands(std::move(new_operands));
  ctx->UpdateDefUse(inst);
  return true;
}
404 | |
405 | // Returns a folding rule that will replace the WriteInvocationAMD extended |
406 | // instruction in the SPV_AMD_shader_ballot extension. |
407 | // |
408 | // The instruction |
409 | // |
410 | // clang-format off |
411 | // %result = OpExtInst %type %1 WriteInvocationAMD %input_value %write_value %invocation_index |
412 | // clang-format on |
413 | // |
414 | // with |
415 | // |
416 | // %id = OpLoad %uint %SubgroupLocalInvocationId |
417 | // %cmp = OpIEqual %bool %id %invocation_index |
418 | // %result = OpSelect %type %cmp %write_value %input_value |
419 | // |
420 | // Also adding the capabilities and builtins that are needed. |
421 | bool ReplaceWriteInvocation(IRContext* ctx, Instruction* inst, |
422 | const std::vector<const analysis::Constant*>&) { |
423 | uint32_t var_id = |
424 | ctx->GetBuiltinInputVarId(SpvBuiltInSubgroupLocalInvocationId); |
425 | ctx->AddCapability(SpvCapabilitySubgroupBallotKHR); |
426 | ctx->AddExtension("SPV_KHR_shader_ballot" ); |
427 | assert(var_id != 0 && "Could not get SubgroupLocalInvocationId variable." ); |
428 | Instruction* var_inst = ctx->get_def_use_mgr()->GetDef(var_id); |
429 | Instruction* var_ptr_type = |
430 | ctx->get_def_use_mgr()->GetDef(var_inst->type_id()); |
431 | |
432 | InstructionBuilder ir_builder( |
433 | ctx, inst, |
434 | IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); |
435 | Instruction* t = |
436 | ir_builder.AddLoad(var_ptr_type->GetSingleWordInOperand(1), var_id); |
437 | analysis::Bool bool_type; |
438 | uint32_t bool_type_id = ctx->get_type_mgr()->GetTypeInstruction(&bool_type); |
439 | Instruction* cmp = |
440 | ir_builder.AddBinaryOp(bool_type_id, SpvOpIEqual, t->result_id(), |
441 | inst->GetSingleWordInOperand(4)); |
442 | |
443 | // Build a select. |
444 | inst->SetOpcode(SpvOpSelect); |
445 | Instruction::OperandList new_operands; |
446 | new_operands.push_back({SPV_OPERAND_TYPE_ID, {cmp->result_id()}}); |
447 | new_operands.push_back(inst->GetInOperand(3)); |
448 | new_operands.push_back(inst->GetInOperand(2)); |
449 | |
450 | inst->SetInOperands(std::move(new_operands)); |
451 | ctx->UpdateDefUse(inst); |
452 | return true; |
453 | } |
454 | |
455 | // Returns a folding rule that will replace the MbcntAMD extended instruction in |
456 | // the SPV_AMD_shader_ballot extension. |
457 | // |
458 | // The instruction |
459 | // |
460 | // %result = OpExtInst %uint %1 MbcntAMD %mask |
461 | // |
462 | // with |
463 | // |
464 | // Get SubgroupLtMask and convert the first 64-bits into a uint64_t because |
465 | // AMD's shader compiler expects a 64-bit integer mask. |
466 | // |
467 | // %var = OpLoad %v4uint %SubgroupLtMaskKHR |
468 | // %shuffle = OpVectorShuffle %v2uint %var %var 0 1 |
469 | // %cast = OpBitcast %ulong %shuffle |
470 | // |
471 | // Perform the mask and count the bits. |
472 | // |
473 | // %and = OpBitwiseAnd %ulong %cast %mask |
474 | // %result = OpBitCount %uint %and |
475 | // |
476 | // Also adding the capabilities and builtins that are needed. |
477 | bool ReplaceMbcnt(IRContext* context, Instruction* inst, |
478 | const std::vector<const analysis::Constant*>&) { |
479 | analysis::TypeManager* type_mgr = context->get_type_mgr(); |
480 | analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); |
481 | |
482 | uint32_t var_id = context->GetBuiltinInputVarId(SpvBuiltInSubgroupLtMask); |
483 | assert(var_id != 0 && "Could not get SubgroupLtMask variable." ); |
484 | context->AddCapability(SpvCapabilityGroupNonUniformBallot); |
485 | Instruction* var_inst = def_use_mgr->GetDef(var_id); |
486 | Instruction* var_ptr_type = def_use_mgr->GetDef(var_inst->type_id()); |
487 | Instruction* var_type = |
488 | def_use_mgr->GetDef(var_ptr_type->GetSingleWordInOperand(1)); |
489 | assert(var_type->opcode() == SpvOpTypeVector && |
490 | "Variable is suppose to be a vector of 4 ints" ); |
491 | |
492 | // Get the type for the shuffle. |
493 | analysis::Vector temp_type(GetUIntType(context), 2); |
494 | const analysis::Type* shuffle_type = |
495 | context->get_type_mgr()->GetRegisteredType(&temp_type); |
496 | uint32_t shuffle_type_id = type_mgr->GetTypeInstruction(shuffle_type); |
497 | |
498 | uint32_t mask_id = inst->GetSingleWordInOperand(2); |
499 | Instruction* mask_inst = def_use_mgr->GetDef(mask_id); |
500 | |
501 | // Testing with amd's shader compiler shows that a 64-bit mask is expected. |
502 | assert(type_mgr->GetType(mask_inst->type_id())->AsInteger() != nullptr); |
503 | assert(type_mgr->GetType(mask_inst->type_id())->AsInteger()->width() == 64); |
504 | |
505 | InstructionBuilder ir_builder( |
506 | context, inst, |
507 | IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); |
508 | Instruction* load = ir_builder.AddLoad(var_type->result_id(), var_id); |
509 | Instruction* shuffle = ir_builder.AddVectorShuffle( |
510 | shuffle_type_id, load->result_id(), load->result_id(), {0, 1}); |
511 | Instruction* bitcast = ir_builder.AddUnaryOp( |
512 | mask_inst->type_id(), SpvOpBitcast, shuffle->result_id()); |
513 | Instruction* t = ir_builder.AddBinaryOp(mask_inst->type_id(), SpvOpBitwiseAnd, |
514 | bitcast->result_id(), mask_id); |
515 | |
516 | inst->SetOpcode(SpvOpBitCount); |
517 | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {t->result_id()}}}); |
518 | context->UpdateDefUse(inst); |
519 | return true; |
520 | } |
521 | |
// Folding rule that replaces the CubeFaceCoordAMD extended instruction from
// the SPV_AMD_gcn_shader extended instruction set. Always returns true.
//
// The instruction
//
//   %result = OpExtInst %v2float %1 CubeFaceCoordAMD %input
//
// is replaced with
//
// clang-format off
//              %x = OpCompositeExtract %float %input 0
//              %y = OpCompositeExtract %float %input 1
//              %z = OpCompositeExtract %float %input 2
//             %nx = OpFNegate %float %x
//             %ny = OpFNegate %float %y
//             %nz = OpFNegate %float %z
//             %ax = OpExtInst %float %n_1 FAbs %x
//             %ay = OpExtInst %float %n_1 FAbs %y
//             %az = OpExtInst %float %n_1 FAbs %z
//       %is_z_neg = OpFOrdLessThan %bool %z %float_0
//       %is_y_neg = OpFOrdLessThan %bool %y %float_0
//       %is_x_neg = OpFOrdLessThan %bool %x %float_0
//       %amax_x_y = OpExtInst %float %n_1 FMax %ax %ay
//           %amax = OpExtInst %float %n_1 FMax %az %amax_x_y
//         %cubema = OpFMul %float %float_2 %amax
//       %is_z_max = OpFOrdGreaterThanEqual %bool %az %amax_x_y
//   %not_is_z_max = OpLogicalNot %bool %is_z_max
//         %y_gt_x = OpFOrdGreaterThanEqual %bool %ay %ax
//       %is_y_max = OpLogicalAnd %bool %not_is_z_max %y_gt_x
//  %cubesc_case_1 = OpSelect %float %is_z_neg %nx %x
//  %cubesc_case_2 = OpSelect %float %is_x_neg %z %nz
//            %sel = OpSelect %float %is_y_max %x %cubesc_case_2
//         %cubesc = OpSelect %float %is_z_max %cubesc_case_1 %sel
//  %cubetc_case_1 = OpSelect %float %is_y_neg %nz %z
//         %cubetc = OpSelect %float %is_y_max %cubetc_case_1 %ny
//           %cube = OpCompositeConstruct %v2float %cubesc %cubetc
//          %denom = OpCompositeConstruct %v2float %cubema %cubema
//            %div = OpFDiv %v2float %cube %denom
//         %result = OpFAdd %v2float %div %const
// clang-format on
//
// where %n_1 is the GLSL.std.450 import (added when missing), %float_0 and
// %float_2 are the float constants 0.0 and 2.0, and %const is the %v2float
// constant (0.5, 0.5).
//
// (%cubesc, %cubetc) are the in-face coordinates for the axis with the
// largest magnitude. Ties favour z, then y. The result is
// (%cubesc, %cubetc) / (2 * max(|x|, |y|, |z|)) + 0.5.
//
// No capabilities or builtins are added. The type and constant managers may
// create the float, bool and vector types and the constants used here.
bool ReplaceCubeFaceCoord(IRContext* ctx, Instruction* inst,
                          const std::vector<const analysis::Constant*>&) {
  analysis::TypeManager* type_mgr = ctx->get_type_mgr();
  analysis::ConstantManager* const_mgr = ctx->get_constant_mgr();

  uint32_t float_type_id = type_mgr->GetFloatTypeId();
  const analysis::Type* v2_float_type = type_mgr->GetFloatVectorType(2);
  uint32_t v2_float_type_id = type_mgr->GetId(v2_float_type);
  uint32_t bool_id = type_mgr->GetBoolTypeId();

  InstructionBuilder ir_builder(
      ctx, inst,
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);

  // In-operands 0 and 1 are the extended set and instruction number; the
  // direction vector follows.
  uint32_t input_id = inst->GetSingleWordInOperand(2);
  // Id of the GLSL.std.450 import ("405" in the name is a transposition of
  // 450). The import is added if the module does not have one yet.
  uint32_t glsl405_ext_inst_id =
      ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
  if (glsl405_ext_inst_id == 0) {
    ctx->AddExtInstImport("GLSL.std.450");
    glsl405_ext_inst_id =
        ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
  }

  // Get the constants that will be used.
  uint32_t f0_const_id = const_mgr->GetFloatConst(0.0);
  uint32_t f2_const_id = const_mgr->GetFloatConst(2.0);
  uint32_t f0_5_const_id = const_mgr->GetFloatConst(0.5);
  const analysis::Constant* vec_const =
      const_mgr->GetConstant(v2_float_type, {f0_5_const_id, f0_5_const_id});
  uint32_t vec_const_id =
      const_mgr->GetDefiningInstruction(vec_const)->result_id();

  // Extract the input values.
  Instruction* x = ir_builder.AddCompositeExtract(float_type_id, input_id, {0});
  Instruction* y = ir_builder.AddCompositeExtract(float_type_id, input_id, {1});
  Instruction* z = ir_builder.AddCompositeExtract(float_type_id, input_id, {2});

  // Negate the input values.
  Instruction* nx =
      ir_builder.AddUnaryOp(float_type_id, SpvOpFNegate, x->result_id());
  Instruction* ny =
      ir_builder.AddUnaryOp(float_type_id, SpvOpFNegate, y->result_id());
  Instruction* nz =
      ir_builder.AddUnaryOp(float_type_id, SpvOpFNegate, z->result_id());

  // Get the absolute values of the inputs.
  Instruction* ax = ir_builder.AddNaryExtendedInstruction(
      float_type_id, glsl405_ext_inst_id, GLSLstd450FAbs, {x->result_id()});
  Instruction* ay = ir_builder.AddNaryExtendedInstruction(
      float_type_id, glsl405_ext_inst_id, GLSLstd450FAbs, {y->result_id()});
  Instruction* az = ir_builder.AddNaryExtendedInstruction(
      float_type_id, glsl405_ext_inst_id, GLSLstd450FAbs, {z->result_id()});

  // Find which values are negative. Used in later computations.
  Instruction* is_z_neg = ir_builder.AddBinaryOp(bool_id, SpvOpFOrdLessThan,
                                                 z->result_id(), f0_const_id);
  Instruction* is_y_neg = ir_builder.AddBinaryOp(bool_id, SpvOpFOrdLessThan,
                                                 y->result_id(), f0_const_id);
  Instruction* is_x_neg = ir_builder.AddBinaryOp(bool_id, SpvOpFOrdLessThan,
                                                 x->result_id(), f0_const_id);

  // Compute cubema = 2 * max(|x|, |y|, |z|).
  Instruction* amax_x_y = ir_builder.AddNaryExtendedInstruction(
      float_type_id, glsl405_ext_inst_id, GLSLstd450FMax,
      {ax->result_id(), ay->result_id()});
  Instruction* amax = ir_builder.AddNaryExtendedInstruction(
      float_type_id, glsl405_ext_inst_id, GLSLstd450FMax,
      {az->result_id(), amax_x_y->result_id()});
  Instruction* cubema = ir_builder.AddBinaryOp(float_type_id, SpvOpFMul,
                                               f2_const_id, amax->result_id());

  // Do the comparisons needed for computing cubesc and cubetc. z is the major
  // axis when |z| >= max(|x|, |y|). Otherwise y is the major axis when
  // |y| >= |x|, and x is the major axis in the remaining case.
  Instruction* is_z_max =
      ir_builder.AddBinaryOp(bool_id, SpvOpFOrdGreaterThanEqual,
                             az->result_id(), amax_x_y->result_id());
  Instruction* not_is_z_max =
      ir_builder.AddUnaryOp(bool_id, SpvOpLogicalNot, is_z_max->result_id());
  Instruction* y_gr_x = ir_builder.AddBinaryOp(
      bool_id, SpvOpFOrdGreaterThanEqual, ay->result_id(), ax->result_id());
  Instruction* is_y_max = ir_builder.AddBinaryOp(
      bool_id, SpvOpLogicalAnd, not_is_z_max->result_id(), y_gr_x->result_id());

  // Select the correct value for cubesc:
  //   z major: z < 0 ? -x : x
  //   y major: x
  //   x major: x < 0 ? z : -z
  Instruction* cubesc_case_1 = ir_builder.AddSelect(
      float_type_id, is_z_neg->result_id(), nx->result_id(), x->result_id());
  Instruction* cubesc_case_2 = ir_builder.AddSelect(
      float_type_id, is_x_neg->result_id(), z->result_id(), nz->result_id());
  Instruction* sel =
      ir_builder.AddSelect(float_type_id, is_y_max->result_id(), x->result_id(),
                           cubesc_case_2->result_id());
  Instruction* cubesc =
      ir_builder.AddSelect(float_type_id, is_z_max->result_id(),
                           cubesc_case_1->result_id(), sel->result_id());

  // Select the correct value for cubetc:
  //   y major: y < 0 ? -z : z
  //   otherwise: -y
  Instruction* cubetc_case_1 = ir_builder.AddSelect(
      float_type_id, is_y_neg->result_id(), nz->result_id(), z->result_id());
  Instruction* cubetc =
      ir_builder.AddSelect(float_type_id, is_y_max->result_id(),
                           cubetc_case_1->result_id(), ny->result_id());

  // Do the division: (cubesc, cubetc) / (cubema, cubema).
  Instruction* cube = ir_builder.AddCompositeConstruct(
      v2_float_type_id, {cubesc->result_id(), cubetc->result_id()});
  Instruction* denom = ir_builder.AddCompositeConstruct(
      v2_float_type_id, {cubema->result_id(), cubema->result_id()});
  Instruction* div = ir_builder.AddBinaryOp(
      v2_float_type_id, SpvOpFDiv, cube->result_id(), denom->result_id());

  // Get the final result by adding 0.5 to |div|. |inst| is reused so that
  // existing uses of its result id stay valid.
  inst->SetOpcode(SpvOpFAdd);
  Instruction::OperandList new_operands;
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {div->result_id()}});
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {vec_const_id}});

  inst->SetInOperands(std::move(new_operands));
  ctx->UpdateDefUse(inst);
  return true;
}
682 | |
683 | // A folding rule that will replace the CubeFaceIndexAMD extended |
684 | // instruction in the SPV_AMD_gcn_shader_ballot. Returns true if the folding |
685 | // is successful. |
686 | // |
687 | // The instruction |
688 | // |
689 | // %result = OpExtInst %float %1 CubeFaceIndexAMD %input |
690 | // |
691 | // with |
692 | // |
693 | // %x = OpCompositeExtract %float %input 0 |
694 | // %y = OpCompositeExtract %float %input 1 |
695 | // %z = OpCompositeExtract %float %input 2 |
696 | // %ax = OpExtInst %float %n_1 FAbs %x |
697 | // %ay = OpExtInst %float %n_1 FAbs %y |
698 | // %az = OpExtInst %float %n_1 FAbs %z |
699 | // %is_z_neg = OpFOrdLessThan %bool %z %float_0 |
700 | // %is_y_neg = OpFOrdLessThan %bool %y %float_0 |
701 | // %is_x_neg = OpFOrdLessThan %bool %x %float_0 |
702 | // %amax_x_y = OpExtInst %float %n_1 FMax %ax %ay |
703 | // %is_z_max = OpFOrdGreaterThanEqual %bool %az %amax_x_y |
704 | // %y_gt_x = OpFOrdGreaterThanEqual %bool %ay %ax |
705 | // %case_z = OpSelect %float %is_z_neg %float_5 %float4 |
706 | // %case_y = OpSelect %float %is_y_neg %float_3 %float2 |
707 | // %case_x = OpSelect %float %is_x_neg %float_1 %float0 |
708 | // %sel = OpSelect %float %y_gt_x %case_y %case_x |
709 | // %result = OpSelect %float %is_z_max %case_z %sel |
710 | // |
711 | // Also adding the capabilities and builtins that are needed. |
712 | bool ReplaceCubeFaceIndex(IRContext* ctx, Instruction* inst, |
713 | const std::vector<const analysis::Constant*>&) { |
714 | analysis::TypeManager* type_mgr = ctx->get_type_mgr(); |
715 | analysis::ConstantManager* const_mgr = ctx->get_constant_mgr(); |
716 | |
717 | uint32_t float_type_id = type_mgr->GetFloatTypeId(); |
718 | uint32_t bool_id = type_mgr->GetBoolTypeId(); |
719 | |
720 | InstructionBuilder ir_builder( |
721 | ctx, inst, |
722 | IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); |
723 | |
724 | uint32_t input_id = inst->GetSingleWordInOperand(2); |
725 | uint32_t glsl405_ext_inst_id = |
726 | ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450(); |
727 | if (glsl405_ext_inst_id == 0) { |
728 | ctx->AddExtInstImport("GLSL.std.450" ); |
729 | glsl405_ext_inst_id = |
730 | ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450(); |
731 | } |
732 | |
733 | // Get the constants that will be used. |
734 | uint32_t f0_const_id = const_mgr->GetFloatConst(0.0); |
735 | uint32_t f1_const_id = const_mgr->GetFloatConst(1.0); |
736 | uint32_t f2_const_id = const_mgr->GetFloatConst(2.0); |
737 | uint32_t f3_const_id = const_mgr->GetFloatConst(3.0); |
738 | uint32_t f4_const_id = const_mgr->GetFloatConst(4.0); |
739 | uint32_t f5_const_id = const_mgr->GetFloatConst(5.0); |
740 | |
741 | // Extract the input values. |
742 | Instruction* x = ir_builder.AddCompositeExtract(float_type_id, input_id, {0}); |
743 | Instruction* y = ir_builder.AddCompositeExtract(float_type_id, input_id, {1}); |
744 | Instruction* z = ir_builder.AddCompositeExtract(float_type_id, input_id, {2}); |
745 | |
746 | // Get the absolute values of the inputs. |
747 | Instruction* ax = ir_builder.AddNaryExtendedInstruction( |
748 | float_type_id, glsl405_ext_inst_id, GLSLstd450FAbs, {x->result_id()}); |
749 | Instruction* ay = ir_builder.AddNaryExtendedInstruction( |
750 | float_type_id, glsl405_ext_inst_id, GLSLstd450FAbs, {y->result_id()}); |
751 | Instruction* az = ir_builder.AddNaryExtendedInstruction( |
752 | float_type_id, glsl405_ext_inst_id, GLSLstd450FAbs, {z->result_id()}); |
753 | |
754 | // Find which values are negative. Used in later computations. |
755 | Instruction* is_z_neg = ir_builder.AddBinaryOp(bool_id, SpvOpFOrdLessThan, |
756 | z->result_id(), f0_const_id); |
757 | Instruction* is_y_neg = ir_builder.AddBinaryOp(bool_id, SpvOpFOrdLessThan, |
758 | y->result_id(), f0_const_id); |
759 | Instruction* is_x_neg = ir_builder.AddBinaryOp(bool_id, SpvOpFOrdLessThan, |
760 | x->result_id(), f0_const_id); |
761 | |
762 | // Find the max value. |
763 | Instruction* amax_x_y = ir_builder.AddNaryExtendedInstruction( |
764 | float_type_id, glsl405_ext_inst_id, GLSLstd450FMax, |
765 | {ax->result_id(), ay->result_id()}); |
766 | Instruction* is_z_max = |
767 | ir_builder.AddBinaryOp(bool_id, SpvOpFOrdGreaterThanEqual, |
768 | az->result_id(), amax_x_y->result_id()); |
769 | Instruction* y_gr_x = ir_builder.AddBinaryOp( |
770 | bool_id, SpvOpFOrdGreaterThanEqual, ay->result_id(), ax->result_id()); |
771 | |
772 | // Get the value for each case. |
773 | Instruction* case_z = ir_builder.AddSelect( |
774 | float_type_id, is_z_neg->result_id(), f5_const_id, f4_const_id); |
775 | Instruction* case_y = ir_builder.AddSelect( |
776 | float_type_id, is_y_neg->result_id(), f3_const_id, f2_const_id); |
777 | Instruction* case_x = ir_builder.AddSelect( |
778 | float_type_id, is_x_neg->result_id(), f1_const_id, f0_const_id); |
779 | |
780 | // Select the correct case. |
781 | Instruction* sel = |
782 | ir_builder.AddSelect(float_type_id, y_gr_x->result_id(), |
783 | case_y->result_id(), case_x->result_id()); |
784 | |
785 | // Get the final result by adding 0.5 to |div|. |
786 | inst->SetOpcode(SpvOpSelect); |
787 | Instruction::OperandList new_operands; |
788 | new_operands.push_back({SPV_OPERAND_TYPE_ID, {is_z_max->result_id()}}); |
789 | new_operands.push_back({SPV_OPERAND_TYPE_ID, {case_z->result_id()}}); |
790 | new_operands.push_back({SPV_OPERAND_TYPE_ID, {sel->result_id()}}); |
791 | |
792 | inst->SetInOperands(std::move(new_operands)); |
793 | ctx->UpdateDefUse(inst); |
794 | return true; |
795 | } |
796 | |
797 | // A folding rule that will replace the TimeAMD extended instruction in the |
798 | // SPV_AMD_gcn_shader_ballot. It returns true if the folding is successful. |
799 | // It returns False, otherwise. |
800 | // |
801 | // The instruction |
802 | // |
803 | // %result = OpExtInst %uint64 %1 TimeAMD |
804 | // |
805 | // with |
806 | // |
807 | // %result = OpReadClockKHR %uint64 %uint_3 |
808 | // |
809 | // NOTE: TimeAMD uses subgroup scope (it is not a real time clock). |
810 | bool ReplaceTimeAMD(IRContext* ctx, Instruction* inst, |
811 | const std::vector<const analysis::Constant*>&) { |
812 | InstructionBuilder ir_builder( |
813 | ctx, inst, |
814 | IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); |
815 | ctx->AddExtension("SPV_KHR_shader_clock" ); |
816 | ctx->AddCapability(SpvCapabilityShaderClockKHR); |
817 | |
818 | inst->SetOpcode(SpvOpReadClockKHR); |
819 | Instruction::OperandList args; |
820 | uint32_t subgroup_scope_id = ir_builder.GetUintConstantId(SpvScopeSubgroup); |
821 | args.push_back({SPV_OPERAND_TYPE_ID, {subgroup_scope_id}}); |
822 | inst->SetInOperands(std::move(args)); |
823 | ctx->UpdateDefUse(inst); |
824 | |
825 | return true; |
826 | } |
827 | |
828 | class AmdExtFoldingRules : public FoldingRules { |
829 | public: |
830 | explicit AmdExtFoldingRules(IRContext* ctx) : FoldingRules(ctx) {} |
831 | |
832 | protected: |
833 | virtual void AddFoldingRules() override { |
834 | rules_[SpvOpGroupIAddNonUniformAMD].push_back( |
835 | ReplaceGroupNonuniformOperationOpCode<SpvOpGroupNonUniformIAdd>); |
836 | rules_[SpvOpGroupFAddNonUniformAMD].push_back( |
837 | ReplaceGroupNonuniformOperationOpCode<SpvOpGroupNonUniformFAdd>); |
838 | rules_[SpvOpGroupUMinNonUniformAMD].push_back( |
839 | ReplaceGroupNonuniformOperationOpCode<SpvOpGroupNonUniformUMin>); |
840 | rules_[SpvOpGroupSMinNonUniformAMD].push_back( |
841 | ReplaceGroupNonuniformOperationOpCode<SpvOpGroupNonUniformSMin>); |
842 | rules_[SpvOpGroupFMinNonUniformAMD].push_back( |
843 | ReplaceGroupNonuniformOperationOpCode<SpvOpGroupNonUniformFMin>); |
844 | rules_[SpvOpGroupUMaxNonUniformAMD].push_back( |
845 | ReplaceGroupNonuniformOperationOpCode<SpvOpGroupNonUniformUMax>); |
846 | rules_[SpvOpGroupSMaxNonUniformAMD].push_back( |
847 | ReplaceGroupNonuniformOperationOpCode<SpvOpGroupNonUniformSMax>); |
848 | rules_[SpvOpGroupFMaxNonUniformAMD].push_back( |
849 | ReplaceGroupNonuniformOperationOpCode<SpvOpGroupNonUniformFMax>); |
850 | |
851 | uint32_t extension_id = |
852 | context()->module()->GetExtInstImportId("SPV_AMD_shader_ballot" ); |
853 | |
854 | if (extension_id != 0) { |
855 | ext_rules_[{extension_id, AmdShaderBallotSwizzleInvocationsAMD}] |
856 | .push_back(ReplaceSwizzleInvocations); |
857 | ext_rules_[{extension_id, AmdShaderBallotSwizzleInvocationsMaskedAMD}] |
858 | .push_back(ReplaceSwizzleInvocationsMasked); |
859 | ext_rules_[{extension_id, AmdShaderBallotWriteInvocationAMD}].push_back( |
860 | ReplaceWriteInvocation); |
861 | ext_rules_[{extension_id, AmdShaderBallotMbcntAMD}].push_back( |
862 | ReplaceMbcnt); |
863 | } |
864 | |
865 | extension_id = context()->module()->GetExtInstImportId( |
866 | "SPV_AMD_shader_trinary_minmax" ); |
867 | |
868 | if (extension_id != 0) { |
869 | ext_rules_[{extension_id, FMin3AMD}].push_back( |
870 | ReplaceTrinaryMinMax<GLSLstd450FMin>); |
871 | ext_rules_[{extension_id, UMin3AMD}].push_back( |
872 | ReplaceTrinaryMinMax<GLSLstd450UMin>); |
873 | ext_rules_[{extension_id, SMin3AMD}].push_back( |
874 | ReplaceTrinaryMinMax<GLSLstd450SMin>); |
875 | ext_rules_[{extension_id, FMax3AMD}].push_back( |
876 | ReplaceTrinaryMinMax<GLSLstd450FMax>); |
877 | ext_rules_[{extension_id, UMax3AMD}].push_back( |
878 | ReplaceTrinaryMinMax<GLSLstd450UMax>); |
879 | ext_rules_[{extension_id, SMax3AMD}].push_back( |
880 | ReplaceTrinaryMinMax<GLSLstd450SMax>); |
881 | ext_rules_[{extension_id, FMid3AMD}].push_back( |
882 | ReplaceTrinaryMid<GLSLstd450FMin, GLSLstd450FMax, GLSLstd450FClamp>); |
883 | ext_rules_[{extension_id, UMid3AMD}].push_back( |
884 | ReplaceTrinaryMid<GLSLstd450UMin, GLSLstd450UMax, GLSLstd450UClamp>); |
885 | ext_rules_[{extension_id, SMid3AMD}].push_back( |
886 | ReplaceTrinaryMid<GLSLstd450SMin, GLSLstd450SMax, GLSLstd450SClamp>); |
887 | } |
888 | |
889 | extension_id = |
890 | context()->module()->GetExtInstImportId("SPV_AMD_gcn_shader" ); |
891 | |
892 | if (extension_id != 0) { |
893 | ext_rules_[{extension_id, CubeFaceCoordAMD}].push_back( |
894 | ReplaceCubeFaceCoord); |
895 | ext_rules_[{extension_id, CubeFaceIndexAMD}].push_back( |
896 | ReplaceCubeFaceIndex); |
897 | ext_rules_[{extension_id, TimeAMD}].push_back(ReplaceTimeAMD); |
898 | } |
899 | } |
900 | }; |
901 | |
902 | class AmdExtConstFoldingRules : public ConstantFoldingRules { |
903 | public: |
904 | AmdExtConstFoldingRules(IRContext* ctx) : ConstantFoldingRules(ctx) {} |
905 | |
906 | protected: |
907 | virtual void AddFoldingRules() override {} |
908 | }; |
909 | |
910 | } // namespace |
911 | |
912 | Pass::Status AmdExtensionToKhrPass::Process() { |
913 | bool changed = false; |
914 | |
915 | // Traverse the body of the functions to replace instructions that require |
916 | // the extensions. |
917 | InstructionFolder folder( |
918 | context(), |
919 | std::unique_ptr<AmdExtFoldingRules>(new AmdExtFoldingRules(context())), |
920 | MakeUnique<AmdExtConstFoldingRules>(context())); |
921 | for (Function& func : *get_module()) { |
922 | func.ForEachInst([&changed, &folder](Instruction* inst) { |
923 | if (folder.FoldInstruction(inst)) { |
924 | changed = true; |
925 | } |
926 | }); |
927 | } |
928 | |
929 | // Now that instruction that require the extensions have been removed, we can |
930 | // remove the extension instructions. |
931 | std::set<std::string> ext_to_remove = {"SPV_AMD_shader_ballot" , |
932 | "SPV_AMD_shader_trinary_minmax" , |
933 | "SPV_AMD_gcn_shader" }; |
934 | |
935 | std::vector<Instruction*> to_be_killed; |
936 | for (Instruction& inst : context()->module()->extensions()) { |
937 | if (inst.opcode() == SpvOpExtension) { |
938 | if (ext_to_remove.count(reinterpret_cast<const char*>( |
939 | &(inst.GetInOperand(0).words[0]))) != 0) { |
940 | to_be_killed.push_back(&inst); |
941 | } |
942 | } |
943 | } |
944 | |
945 | for (Instruction& inst : context()->ext_inst_imports()) { |
946 | if (inst.opcode() == SpvOpExtInstImport) { |
947 | if (ext_to_remove.count(reinterpret_cast<const char*>( |
948 | &(inst.GetInOperand(0).words[0]))) != 0) { |
949 | to_be_killed.push_back(&inst); |
950 | } |
951 | } |
952 | } |
953 | |
954 | for (Instruction* inst : to_be_killed) { |
955 | context()->KillInst(inst); |
956 | changed = true; |
957 | } |
958 | |
959 | // The replacements that take place use instructions that are missing before |
960 | // SPIR-V 1.3. If we changed something, we will have to make sure the version |
961 | // is at least SPIR-V 1.3 to make sure those instruction can be used. |
962 | if (changed) { |
963 | uint32_t version = get_module()->version(); |
964 | if (version < 0x00010300 /*1.3*/) { |
965 | get_module()->set_version(0x00010300); |
966 | } |
967 | } |
968 | return changed ? Status::SuccessWithChange : Status::SuccessWithoutChange; |
969 | } |
970 | |
971 | } // namespace opt |
972 | } // namespace spvtools |
973 | |