1// Copyright (c) 2018 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "source/opt/folding_rules.h"
16
17#include <limits>
18#include <memory>
19#include <utility>
20
21#include "ir_builder.h"
22#include "source/latest_version_glsl_std_450_header.h"
23#include "source/opt/ir_context.h"
24
25namespace spvtools {
26namespace opt {
27namespace {
28
// Named in-operand indices used by the folding rules in this file. Each name
// identifies the instruction and the operand it addresses.
constexpr uint32_t kExtractCompositeIdInIdx = 0;
constexpr uint32_t kInsertObjectIdInIdx = 0;
constexpr uint32_t kInsertCompositeIdInIdx = 1;
// OpExtInst: in-operands 0 and 1 are the set and the instruction, so the FMix
// arguments x, y and a follow at 2, 3 and 4.
constexpr uint32_t kExtInstSetIdInIdx = 0;
constexpr uint32_t kExtInstInstructionInIdx = 1;
constexpr uint32_t kFMixXIdInIdx = 2;
constexpr uint32_t kFMixYIdInIdx = 3;
constexpr uint32_t kFMixAIdInIdx = 4;
constexpr uint32_t kStoreObjectInIdx = 1;
38
39// Some image instructions may contain an "image operands" argument.
40// Returns the operand index for the "image operands".
41// Returns -1 if the instruction does not have image operands.
42int32_t ImageOperandsMaskInOperandIndex(Instruction* inst) {
43 const auto opcode = inst->opcode();
44 switch (opcode) {
45 case SpvOpImageSampleImplicitLod:
46 case SpvOpImageSampleExplicitLod:
47 case SpvOpImageSampleProjImplicitLod:
48 case SpvOpImageSampleProjExplicitLod:
49 case SpvOpImageFetch:
50 case SpvOpImageRead:
51 case SpvOpImageSparseSampleImplicitLod:
52 case SpvOpImageSparseSampleExplicitLod:
53 case SpvOpImageSparseSampleProjImplicitLod:
54 case SpvOpImageSparseSampleProjExplicitLod:
55 case SpvOpImageSparseFetch:
56 case SpvOpImageSparseRead:
57 return inst->NumOperands() > 4 ? 2 : -1;
58 case SpvOpImageSampleDrefImplicitLod:
59 case SpvOpImageSampleDrefExplicitLod:
60 case SpvOpImageSampleProjDrefImplicitLod:
61 case SpvOpImageSampleProjDrefExplicitLod:
62 case SpvOpImageGather:
63 case SpvOpImageDrefGather:
64 case SpvOpImageSparseSampleDrefImplicitLod:
65 case SpvOpImageSparseSampleDrefExplicitLod:
66 case SpvOpImageSparseSampleProjDrefImplicitLod:
67 case SpvOpImageSparseSampleProjDrefExplicitLod:
68 case SpvOpImageSparseGather:
69 case SpvOpImageSparseDrefGather:
70 return inst->NumOperands() > 5 ? 3 : -1;
71 case SpvOpImageWrite:
72 return inst->NumOperands() > 3 ? 3 : -1;
73 default:
74 return -1;
75 }
76}
77
78// Returns the element width of |type|.
79uint32_t ElementWidth(const analysis::Type* type) {
80 if (const analysis::Vector* vec_type = type->AsVector()) {
81 return ElementWidth(vec_type->element_type());
82 } else if (const analysis::Float* float_type = type->AsFloat()) {
83 return float_type->width();
84 } else {
85 assert(type->AsInteger());
86 return type->AsInteger()->width();
87 }
88}
89
90// Returns true if |type| is Float or a vector of Float.
91bool HasFloatingPoint(const analysis::Type* type) {
92 if (type->AsFloat()) {
93 return true;
94 } else if (const analysis::Vector* vec_type = type->AsVector()) {
95 return vec_type->element_type()->AsFloat() != nullptr;
96 }
97
98 return false;
99}
100
// Returns false if |val| is NaN, infinite or subnormal; any other
// classification (zero, normal, ...) is accepted.
template <typename T>
bool IsValidResult(T val) {
  const int category = std::fpclassify(val);
  const bool rejected = category == FP_NAN || category == FP_INFINITE ||
                        category == FP_SUBNORMAL;
  return !rejected;
}
114
115const analysis::Constant* ConstInput(
116 const std::vector<const analysis::Constant*>& constants) {
117 return constants[0] ? constants[0] : constants[1];
118}
119
120Instruction* NonConstInput(IRContext* context, const analysis::Constant* c,
121 Instruction* inst) {
122 uint32_t in_op = c ? 1u : 0u;
123 return context->get_def_use_mgr()->GetDef(
124 inst->GetSingleWordInOperand(in_op));
125}
126
127// Returns the negation of |c|. |c| must be a 32 or 64 bit floating point
128// constant.
129uint32_t NegateFloatingPointConstant(analysis::ConstantManager* const_mgr,
130 const analysis::Constant* c) {
131 assert(c);
132 assert(c->type()->AsFloat());
133 uint32_t width = c->type()->AsFloat()->width();
134 assert(width == 32 || width == 64);
135 std::vector<uint32_t> words;
136 if (width == 64) {
137 utils::FloatProxy<double> result(c->GetDouble() * -1.0);
138 words = result.GetWords();
139 } else {
140 utils::FloatProxy<float> result(c->GetFloat() * -1.0f);
141 words = result.GetWords();
142 }
143
144 const analysis::Constant* negated_const =
145 const_mgr->GetConstant(c->type(), std::move(words));
146 return const_mgr->GetDefiningInstruction(negated_const)->result_id();
147}
148
// Splits the 64-bit |val| into two 32-bit words, low-order word first.
std::vector<uint32_t> ExtractInts(uint64_t val) {
  const uint32_t low_word = static_cast<uint32_t>(val & 0xFFFFFFFFull);
  const uint32_t high_word = static_cast<uint32_t>(val >> 32);
  return std::vector<uint32_t>{low_word, high_word};
}
155
156// Negates the integer constant |c|. Returns the id of the defining instruction.
157uint32_t NegateIntegerConstant(analysis::ConstantManager* const_mgr,
158 const analysis::Constant* c) {
159 assert(c);
160 assert(c->type()->AsInteger());
161 uint32_t width = c->type()->AsInteger()->width();
162 assert(width == 32 || width == 64);
163 std::vector<uint32_t> words;
164 if (width == 64) {
165 uint64_t uval = static_cast<uint64_t>(0 - c->GetU64());
166 words = ExtractInts(uval);
167 } else {
168 words.push_back(static_cast<uint32_t>(0 - c->GetU32()));
169 }
170
171 const analysis::Constant* negated_const =
172 const_mgr->GetConstant(c->type(), std::move(words));
173 return const_mgr->GetDefiningInstruction(negated_const)->result_id();
174}
175
176// Negates the vector constant |c|. Returns the id of the defining instruction.
177uint32_t NegateVectorConstant(analysis::ConstantManager* const_mgr,
178 const analysis::Constant* c) {
179 assert(const_mgr && c);
180 assert(c->type()->AsVector());
181 if (c->AsNullConstant()) {
182 // 0.0 vs -0.0 shouldn't matter.
183 return const_mgr->GetDefiningInstruction(c)->result_id();
184 } else {
185 const analysis::Type* component_type =
186 c->AsVectorConstant()->component_type();
187 std::vector<uint32_t> words;
188 for (auto& comp : c->AsVectorConstant()->GetComponents()) {
189 if (component_type->AsFloat()) {
190 words.push_back(NegateFloatingPointConstant(const_mgr, comp));
191 } else {
192 assert(component_type->AsInteger());
193 words.push_back(NegateIntegerConstant(const_mgr, comp));
194 }
195 }
196
197 const analysis::Constant* negated_const =
198 const_mgr->GetConstant(c->type(), std::move(words));
199 return const_mgr->GetDefiningInstruction(negated_const)->result_id();
200 }
201}
202
203// Negates |c|. Returns the id of the defining instruction.
204uint32_t NegateConstant(analysis::ConstantManager* const_mgr,
205 const analysis::Constant* c) {
206 if (c->type()->AsVector()) {
207 return NegateVectorConstant(const_mgr, c);
208 } else if (c->type()->AsFloat()) {
209 return NegateFloatingPointConstant(const_mgr, c);
210 } else {
211 assert(c->type()->AsInteger());
212 return NegateIntegerConstant(const_mgr, c);
213 }
214}
215
216// Takes the reciprocal of |c|. |c|'s type must be Float or a vector of Float.
217// Returns 0 if the reciprocal is NaN, infinite or subnormal.
218uint32_t Reciprocal(analysis::ConstantManager* const_mgr,
219 const analysis::Constant* c) {
220 assert(const_mgr && c);
221 assert(c->type()->AsFloat());
222
223 uint32_t width = c->type()->AsFloat()->width();
224 assert(width == 32 || width == 64);
225 std::vector<uint32_t> words;
226 if (width == 64) {
227 spvtools::utils::FloatProxy<double> result(1.0 / c->GetDouble());
228 if (!IsValidResult(result.getAsFloat())) return 0;
229 words = result.GetWords();
230 } else {
231 spvtools::utils::FloatProxy<float> result(1.0f / c->GetFloat());
232 if (!IsValidResult(result.getAsFloat())) return 0;
233 words = result.GetWords();
234 }
235
236 const analysis::Constant* negated_const =
237 const_mgr->GetConstant(c->type(), std::move(words));
238 return const_mgr->GetDefiningInstruction(negated_const)->result_id();
239}
240
241// Replaces fdiv where second operand is constant with fmul.
242FoldingRule ReciprocalFDiv() {
243 return [](IRContext* context, Instruction* inst,
244 const std::vector<const analysis::Constant*>& constants) {
245 assert(inst->opcode() == SpvOpFDiv);
246 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
247 const analysis::Type* type =
248 context->get_type_mgr()->GetType(inst->type_id());
249 if (!inst->IsFloatingPointFoldingAllowed()) return false;
250
251 uint32_t width = ElementWidth(type);
252 if (width != 32 && width != 64) return false;
253
254 if (constants[1] != nullptr) {
255 uint32_t id = 0;
256 if (const analysis::VectorConstant* vector_const =
257 constants[1]->AsVectorConstant()) {
258 std::vector<uint32_t> neg_ids;
259 for (auto& comp : vector_const->GetComponents()) {
260 id = Reciprocal(const_mgr, comp);
261 if (id == 0) return false;
262 neg_ids.push_back(id);
263 }
264 const analysis::Constant* negated_const =
265 const_mgr->GetConstant(constants[1]->type(), std::move(neg_ids));
266 id = const_mgr->GetDefiningInstruction(negated_const)->result_id();
267 } else if (constants[1]->AsFloatConstant()) {
268 id = Reciprocal(const_mgr, constants[1]);
269 if (id == 0) return false;
270 } else {
271 // Don't fold a null constant.
272 return false;
273 }
274 inst->SetOpcode(SpvOpFMul);
275 inst->SetInOperands(
276 {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0u)}},
277 {SPV_OPERAND_TYPE_ID, {id}}});
278 return true;
279 }
280
281 return false;
282 };
283}
284
285// Elides consecutive negate instructions.
286FoldingRule MergeNegateArithmetic() {
287 return [](IRContext* context, Instruction* inst,
288 const std::vector<const analysis::Constant*>& constants) {
289 assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
290 (void)constants;
291 const analysis::Type* type =
292 context->get_type_mgr()->GetType(inst->type_id());
293 if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
294 return false;
295
296 Instruction* op_inst =
297 context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
298 if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
299 return false;
300
301 if (op_inst->opcode() == inst->opcode()) {
302 // Elide negates.
303 inst->SetOpcode(SpvOpCopyObject);
304 inst->SetInOperands(
305 {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0u)}}});
306 return true;
307 }
308
309 return false;
310 };
311}
312
313// Merges negate into a mul or div operation if that operation contains a
314// constant operand.
315// Cases:
316// -(x * 2) = x * -2
317// -(2 * x) = x * -2
318// -(x / 2) = x / -2
319// -(2 / x) = -2 / x
320FoldingRule MergeNegateMulDivArithmetic() {
321 return [](IRContext* context, Instruction* inst,
322 const std::vector<const analysis::Constant*>& constants) {
323 assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
324 (void)constants;
325 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
326 const analysis::Type* type =
327 context->get_type_mgr()->GetType(inst->type_id());
328 if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
329 return false;
330
331 Instruction* op_inst =
332 context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
333 if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
334 return false;
335
336 uint32_t width = ElementWidth(type);
337 if (width != 32 && width != 64) return false;
338
339 SpvOp opcode = op_inst->opcode();
340 if (opcode == SpvOpFMul || opcode == SpvOpFDiv || opcode == SpvOpIMul ||
341 opcode == SpvOpSDiv || opcode == SpvOpUDiv) {
342 std::vector<const analysis::Constant*> op_constants =
343 const_mgr->GetOperandConstants(op_inst);
344 // Merge negate into mul or div if one operand is constant.
345 if (op_constants[0] || op_constants[1]) {
346 bool zero_is_variable = op_constants[0] == nullptr;
347 const analysis::Constant* c = ConstInput(op_constants);
348 uint32_t neg_id = NegateConstant(const_mgr, c);
349 uint32_t non_const_id = zero_is_variable
350 ? op_inst->GetSingleWordInOperand(0u)
351 : op_inst->GetSingleWordInOperand(1u);
352 // Change this instruction to a mul/div.
353 inst->SetOpcode(op_inst->opcode());
354 if (opcode == SpvOpFDiv || opcode == SpvOpUDiv || opcode == SpvOpSDiv) {
355 uint32_t op0 = zero_is_variable ? non_const_id : neg_id;
356 uint32_t op1 = zero_is_variable ? neg_id : non_const_id;
357 inst->SetInOperands(
358 {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
359 } else {
360 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
361 {SPV_OPERAND_TYPE_ID, {neg_id}}});
362 }
363 return true;
364 }
365 }
366
367 return false;
368 };
369}
370
371// Merges negate into a add or sub operation if that operation contains a
372// constant operand.
373// Cases:
374// -(x + 2) = -2 - x
375// -(2 + x) = -2 - x
376// -(x - 2) = 2 - x
377// -(2 - x) = x - 2
378FoldingRule MergeNegateAddSubArithmetic() {
379 return [](IRContext* context, Instruction* inst,
380 const std::vector<const analysis::Constant*>& constants) {
381 assert(inst->opcode() == SpvOpFNegate || inst->opcode() == SpvOpSNegate);
382 (void)constants;
383 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
384 const analysis::Type* type =
385 context->get_type_mgr()->GetType(inst->type_id());
386 if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
387 return false;
388
389 Instruction* op_inst =
390 context->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
391 if (HasFloatingPoint(type) && !op_inst->IsFloatingPointFoldingAllowed())
392 return false;
393
394 uint32_t width = ElementWidth(type);
395 if (width != 32 && width != 64) return false;
396
397 if (op_inst->opcode() == SpvOpFAdd || op_inst->opcode() == SpvOpFSub ||
398 op_inst->opcode() == SpvOpIAdd || op_inst->opcode() == SpvOpISub) {
399 std::vector<const analysis::Constant*> op_constants =
400 const_mgr->GetOperandConstants(op_inst);
401 if (op_constants[0] || op_constants[1]) {
402 bool zero_is_variable = op_constants[0] == nullptr;
403 bool is_add = (op_inst->opcode() == SpvOpFAdd) ||
404 (op_inst->opcode() == SpvOpIAdd);
405 bool swap_operands = !is_add || zero_is_variable;
406 bool negate_const = is_add;
407 const analysis::Constant* c = ConstInput(op_constants);
408 uint32_t const_id = 0;
409 if (negate_const) {
410 const_id = NegateConstant(const_mgr, c);
411 } else {
412 const_id = zero_is_variable ? op_inst->GetSingleWordInOperand(1u)
413 : op_inst->GetSingleWordInOperand(0u);
414 }
415
416 // Swap operands if necessary and make the instruction a subtraction.
417 uint32_t op0 =
418 zero_is_variable ? op_inst->GetSingleWordInOperand(0u) : const_id;
419 uint32_t op1 =
420 zero_is_variable ? const_id : op_inst->GetSingleWordInOperand(1u);
421 if (swap_operands) std::swap(op0, op1);
422 inst->SetOpcode(HasFloatingPoint(type) ? SpvOpFSub : SpvOpISub);
423 inst->SetInOperands(
424 {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
425 return true;
426 }
427 }
428
429 return false;
430 };
431}
432
433// Returns true if |c| has a zero element.
434bool HasZero(const analysis::Constant* c) {
435 if (c->AsNullConstant()) {
436 return true;
437 }
438 if (const analysis::VectorConstant* vec_const = c->AsVectorConstant()) {
439 for (auto& comp : vec_const->GetComponents())
440 if (HasZero(comp)) return true;
441 } else {
442 assert(c->AsScalarConstant());
443 return c->AsScalarConstant()->IsZero();
444 }
445
446 return false;
447}
448
449// Performs |input1| |opcode| |input2| and returns the merged constant result
450// id. Returns 0 if the result is not a valid value. The input types must be
451// Float.
452uint32_t PerformFloatingPointOperation(analysis::ConstantManager* const_mgr,
453 SpvOp opcode,
454 const analysis::Constant* input1,
455 const analysis::Constant* input2) {
456 const analysis::Type* type = input1->type();
457 assert(type->AsFloat());
458 uint32_t width = type->AsFloat()->width();
459 assert(width == 32 || width == 64);
460 std::vector<uint32_t> words;
461#define FOLD_OP(op) \
462 if (width == 64) { \
463 utils::FloatProxy<double> val = \
464 input1->GetDouble() op input2->GetDouble(); \
465 double dval = val.getAsFloat(); \
466 if (!IsValidResult(dval)) return 0; \
467 words = val.GetWords(); \
468 } else { \
469 utils::FloatProxy<float> val = input1->GetFloat() op input2->GetFloat(); \
470 float fval = val.getAsFloat(); \
471 if (!IsValidResult(fval)) return 0; \
472 words = val.GetWords(); \
473 }
474 switch (opcode) {
475 case SpvOpFMul:
476 FOLD_OP(*);
477 break;
478 case SpvOpFDiv:
479 if (HasZero(input2)) return 0;
480 FOLD_OP(/);
481 break;
482 case SpvOpFAdd:
483 FOLD_OP(+);
484 break;
485 case SpvOpFSub:
486 FOLD_OP(-);
487 break;
488 default:
489 assert(false && "Unexpected operation");
490 break;
491 }
492#undef FOLD_OP
493 const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
494 return const_mgr->GetDefiningInstruction(merged_const)->result_id();
495}
496
497// Performs |input1| |opcode| |input2| and returns the merged constant result
498// id. Returns 0 if the result is not a valid value. The input types must be
499// Integers.
500uint32_t PerformIntegerOperation(analysis::ConstantManager* const_mgr,
501 SpvOp opcode, const analysis::Constant* input1,
502 const analysis::Constant* input2) {
503 assert(input1->type()->AsInteger());
504 const analysis::Integer* type = input1->type()->AsInteger();
505 uint32_t width = type->AsInteger()->width();
506 assert(width == 32 || width == 64);
507 std::vector<uint32_t> words;
508#define FOLD_OP(op) \
509 if (width == 64) { \
510 if (type->IsSigned()) { \
511 int64_t val = input1->GetS64() op input2->GetS64(); \
512 words = ExtractInts(static_cast<uint64_t>(val)); \
513 } else { \
514 uint64_t val = input1->GetU64() op input2->GetU64(); \
515 words = ExtractInts(val); \
516 } \
517 } else { \
518 if (type->IsSigned()) { \
519 int32_t val = input1->GetS32() op input2->GetS32(); \
520 words.push_back(static_cast<uint32_t>(val)); \
521 } else { \
522 uint32_t val = input1->GetU32() op input2->GetU32(); \
523 words.push_back(val); \
524 } \
525 }
526 switch (opcode) {
527 case SpvOpIMul:
528 FOLD_OP(*);
529 break;
530 case SpvOpSDiv:
531 case SpvOpUDiv:
532 assert(false && "Should not merge integer division");
533 break;
534 case SpvOpIAdd:
535 FOLD_OP(+);
536 break;
537 case SpvOpISub:
538 FOLD_OP(-);
539 break;
540 default:
541 assert(false && "Unexpected operation");
542 break;
543 }
544#undef FOLD_OP
545 const analysis::Constant* merged_const = const_mgr->GetConstant(type, words);
546 return const_mgr->GetDefiningInstruction(merged_const)->result_id();
547}
548
549// Performs |input1| |opcode| |input2| and returns the merged constant result
550// id. Returns 0 if the result is not a valid value. The input types must be
551// Integers, Floats or Vectors of such.
552uint32_t PerformOperation(analysis::ConstantManager* const_mgr, SpvOp opcode,
553 const analysis::Constant* input1,
554 const analysis::Constant* input2) {
555 assert(input1 && input2);
556 const analysis::Type* type = input1->type();
557 std::vector<uint32_t> words;
558 if (const analysis::Vector* vector_type = type->AsVector()) {
559 const analysis::Type* ele_type = vector_type->element_type();
560 for (uint32_t i = 0; i != vector_type->element_count(); ++i) {
561 uint32_t id = 0;
562
563 const analysis::Constant* input1_comp = nullptr;
564 if (const analysis::VectorConstant* input1_vector =
565 input1->AsVectorConstant()) {
566 input1_comp = input1_vector->GetComponents()[i];
567 } else {
568 assert(input1->AsNullConstant());
569 input1_comp = const_mgr->GetConstant(ele_type, {});
570 }
571
572 const analysis::Constant* input2_comp = nullptr;
573 if (const analysis::VectorConstant* input2_vector =
574 input2->AsVectorConstant()) {
575 input2_comp = input2_vector->GetComponents()[i];
576 } else {
577 assert(input2->AsNullConstant());
578 input2_comp = const_mgr->GetConstant(ele_type, {});
579 }
580
581 if (ele_type->AsFloat()) {
582 id = PerformFloatingPointOperation(const_mgr, opcode, input1_comp,
583 input2_comp);
584 } else {
585 assert(ele_type->AsInteger());
586 id = PerformIntegerOperation(const_mgr, opcode, input1_comp,
587 input2_comp);
588 }
589 if (id == 0) return 0;
590 words.push_back(id);
591 }
592 const analysis::Constant* merged_const =
593 const_mgr->GetConstant(type, words);
594 return const_mgr->GetDefiningInstruction(merged_const)->result_id();
595 } else if (type->AsFloat()) {
596 return PerformFloatingPointOperation(const_mgr, opcode, input1, input2);
597 } else {
598 assert(type->AsInteger());
599 return PerformIntegerOperation(const_mgr, opcode, input1, input2);
600 }
601}
602
603// Merges consecutive multiplies where each contains one constant operand.
604// Cases:
605// 2 * (x * 2) = x * 4
606// 2 * (2 * x) = x * 4
607// (x * 2) * 2 = x * 4
608// (2 * x) * 2 = x * 4
609FoldingRule MergeMulMulArithmetic() {
610 return [](IRContext* context, Instruction* inst,
611 const std::vector<const analysis::Constant*>& constants) {
612 assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul);
613 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
614 const analysis::Type* type =
615 context->get_type_mgr()->GetType(inst->type_id());
616 if (HasFloatingPoint(type) && !inst->IsFloatingPointFoldingAllowed())
617 return false;
618
619 uint32_t width = ElementWidth(type);
620 if (width != 32 && width != 64) return false;
621
622 // Determine the constant input and the variable input in |inst|.
623 const analysis::Constant* const_input1 = ConstInput(constants);
624 if (!const_input1) return false;
625 Instruction* other_inst = NonConstInput(context, constants[0], inst);
626 if (HasFloatingPoint(type) && !other_inst->IsFloatingPointFoldingAllowed())
627 return false;
628
629 if (other_inst->opcode() == inst->opcode()) {
630 std::vector<const analysis::Constant*> other_constants =
631 const_mgr->GetOperandConstants(other_inst);
632 const analysis::Constant* const_input2 = ConstInput(other_constants);
633 if (!const_input2) return false;
634
635 bool other_first_is_variable = other_constants[0] == nullptr;
636 uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
637 const_input1, const_input2);
638 if (merged_id == 0) return false;
639
640 uint32_t non_const_id = other_first_is_variable
641 ? other_inst->GetSingleWordInOperand(0u)
642 : other_inst->GetSingleWordInOperand(1u);
643 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
644 {SPV_OPERAND_TYPE_ID, {merged_id}}});
645 return true;
646 }
647
648 return false;
649 };
650}
651
652// Merges divides into subsequent multiplies if each instruction contains one
653// constant operand. Does not support integer operations.
654// Cases:
655// 2 * (x / 2) = x * 1
656// 2 * (2 / x) = 4 / x
657// (x / 2) * 2 = x * 1
658// (2 / x) * 2 = 4 / x
659// (y / x) * x = y
660// x * (y / x) = y
661FoldingRule MergeMulDivArithmetic() {
662 return [](IRContext* context, Instruction* inst,
663 const std::vector<const analysis::Constant*>& constants) {
664 assert(inst->opcode() == SpvOpFMul);
665 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
666 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
667
668 const analysis::Type* type =
669 context->get_type_mgr()->GetType(inst->type_id());
670 if (!inst->IsFloatingPointFoldingAllowed()) return false;
671
672 uint32_t width = ElementWidth(type);
673 if (width != 32 && width != 64) return false;
674
675 for (uint32_t i = 0; i < 2; i++) {
676 uint32_t op_id = inst->GetSingleWordInOperand(i);
677 Instruction* op_inst = def_use_mgr->GetDef(op_id);
678 if (op_inst->opcode() == SpvOpFDiv) {
679 if (op_inst->GetSingleWordInOperand(1) ==
680 inst->GetSingleWordInOperand(1 - i)) {
681 inst->SetOpcode(SpvOpCopyObject);
682 inst->SetInOperands(
683 {{SPV_OPERAND_TYPE_ID, {op_inst->GetSingleWordInOperand(0)}}});
684 return true;
685 }
686 }
687 }
688
689 const analysis::Constant* const_input1 = ConstInput(constants);
690 if (!const_input1) return false;
691 Instruction* other_inst = NonConstInput(context, constants[0], inst);
692 if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
693
694 if (other_inst->opcode() == SpvOpFDiv) {
695 std::vector<const analysis::Constant*> other_constants =
696 const_mgr->GetOperandConstants(other_inst);
697 const analysis::Constant* const_input2 = ConstInput(other_constants);
698 if (!const_input2 || HasZero(const_input2)) return false;
699
700 bool other_first_is_variable = other_constants[0] == nullptr;
701 // If the variable value is the second operand of the divide, multiply
702 // the constants together. Otherwise divide the constants.
703 uint32_t merged_id = PerformOperation(
704 const_mgr,
705 other_first_is_variable ? other_inst->opcode() : inst->opcode(),
706 const_input1, const_input2);
707 if (merged_id == 0) return false;
708
709 uint32_t non_const_id = other_first_is_variable
710 ? other_inst->GetSingleWordInOperand(0u)
711 : other_inst->GetSingleWordInOperand(1u);
712
713 // If the variable value is on the second operand of the div, then this
714 // operation is a div. Otherwise it should be a multiply.
715 inst->SetOpcode(other_first_is_variable ? inst->opcode()
716 : other_inst->opcode());
717 if (other_first_is_variable) {
718 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
719 {SPV_OPERAND_TYPE_ID, {merged_id}}});
720 } else {
721 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {merged_id}},
722 {SPV_OPERAND_TYPE_ID, {non_const_id}}});
723 }
724 return true;
725 }
726
727 return false;
728 };
729}
730
731// Merges multiply of constant and negation.
732// Cases:
733// (-x) * 2 = x * -2
734// 2 * (-x) = x * -2
735FoldingRule MergeMulNegateArithmetic() {
736 return [](IRContext* context, Instruction* inst,
737 const std::vector<const analysis::Constant*>& constants) {
738 assert(inst->opcode() == SpvOpFMul || inst->opcode() == SpvOpIMul);
739 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
740 const analysis::Type* type =
741 context->get_type_mgr()->GetType(inst->type_id());
742 bool uses_float = HasFloatingPoint(type);
743 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
744
745 uint32_t width = ElementWidth(type);
746 if (width != 32 && width != 64) return false;
747
748 const analysis::Constant* const_input1 = ConstInput(constants);
749 if (!const_input1) return false;
750 Instruction* other_inst = NonConstInput(context, constants[0], inst);
751 if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
752 return false;
753
754 if (other_inst->opcode() == SpvOpFNegate ||
755 other_inst->opcode() == SpvOpSNegate) {
756 uint32_t neg_id = NegateConstant(const_mgr, const_input1);
757
758 inst->SetInOperands(
759 {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
760 {SPV_OPERAND_TYPE_ID, {neg_id}}});
761 return true;
762 }
763
764 return false;
765 };
766}
767
768// Merges consecutive divides if each instruction contains one constant operand.
769// Does not support integer division.
770// Cases:
771// 2 / (x / 2) = 4 / x
772// 4 / (2 / x) = 2 * x
773// (4 / x) / 2 = 2 / x
774// (x / 2) / 2 = x / 4
775FoldingRule MergeDivDivArithmetic() {
776 return [](IRContext* context, Instruction* inst,
777 const std::vector<const analysis::Constant*>& constants) {
778 assert(inst->opcode() == SpvOpFDiv);
779 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
780 const analysis::Type* type =
781 context->get_type_mgr()->GetType(inst->type_id());
782 if (!inst->IsFloatingPointFoldingAllowed()) return false;
783
784 uint32_t width = ElementWidth(type);
785 if (width != 32 && width != 64) return false;
786
787 const analysis::Constant* const_input1 = ConstInput(constants);
788 if (!const_input1 || HasZero(const_input1)) return false;
789 Instruction* other_inst = NonConstInput(context, constants[0], inst);
790 if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
791
792 bool first_is_variable = constants[0] == nullptr;
793 if (other_inst->opcode() == inst->opcode()) {
794 std::vector<const analysis::Constant*> other_constants =
795 const_mgr->GetOperandConstants(other_inst);
796 const analysis::Constant* const_input2 = ConstInput(other_constants);
797 if (!const_input2 || HasZero(const_input2)) return false;
798
799 bool other_first_is_variable = other_constants[0] == nullptr;
800
801 SpvOp merge_op = inst->opcode();
802 if (other_first_is_variable) {
803 // Constants magnify.
804 merge_op = SpvOpFMul;
805 }
806
807 // This is an x / (*) case. Swap the inputs. Doesn't harm multiply
808 // because it is commutative.
809 if (first_is_variable) std::swap(const_input1, const_input2);
810 uint32_t merged_id =
811 PerformOperation(const_mgr, merge_op, const_input1, const_input2);
812 if (merged_id == 0) return false;
813
814 uint32_t non_const_id = other_first_is_variable
815 ? other_inst->GetSingleWordInOperand(0u)
816 : other_inst->GetSingleWordInOperand(1u);
817
818 SpvOp op = inst->opcode();
819 if (!first_is_variable && !other_first_is_variable) {
820 // Effectively div of 1/x, so change to multiply.
821 op = SpvOpFMul;
822 }
823
824 uint32_t op1 = merged_id;
825 uint32_t op2 = non_const_id;
826 if (first_is_variable && other_first_is_variable) std::swap(op1, op2);
827 inst->SetOpcode(op);
828 inst->SetInOperands(
829 {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
830 return true;
831 }
832
833 return false;
834 };
835}
836
837// Fold multiplies succeeded by divides where each instruction contains a
838// constant operand. Does not support integer divide.
839// Cases:
840// 4 / (x * 2) = 2 / x
841// 4 / (2 * x) = 2 / x
842// (x * 4) / 2 = x * 2
843// (4 * x) / 2 = x * 2
844// (x * y) / x = y
845// (y * x) / x = y
846FoldingRule MergeDivMulArithmetic() {
847 return [](IRContext* context, Instruction* inst,
848 const std::vector<const analysis::Constant*>& constants) {
849 assert(inst->opcode() == SpvOpFDiv);
850 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
851 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
852
853 const analysis::Type* type =
854 context->get_type_mgr()->GetType(inst->type_id());
855 if (!inst->IsFloatingPointFoldingAllowed()) return false;
856
857 uint32_t width = ElementWidth(type);
858 if (width != 32 && width != 64) return false;
859
860 uint32_t op_id = inst->GetSingleWordInOperand(0);
861 Instruction* op_inst = def_use_mgr->GetDef(op_id);
862
863 if (op_inst->opcode() == SpvOpFMul) {
864 for (uint32_t i = 0; i < 2; i++) {
865 if (op_inst->GetSingleWordInOperand(i) ==
866 inst->GetSingleWordInOperand(1)) {
867 inst->SetOpcode(SpvOpCopyObject);
868 inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
869 {op_inst->GetSingleWordInOperand(1 - i)}}});
870 return true;
871 }
872 }
873 }
874
875 const analysis::Constant* const_input1 = ConstInput(constants);
876 if (!const_input1 || HasZero(const_input1)) return false;
877 Instruction* other_inst = NonConstInput(context, constants[0], inst);
878 if (!other_inst->IsFloatingPointFoldingAllowed()) return false;
879
880 bool first_is_variable = constants[0] == nullptr;
881 if (other_inst->opcode() == SpvOpFMul) {
882 std::vector<const analysis::Constant*> other_constants =
883 const_mgr->GetOperandConstants(other_inst);
884 const analysis::Constant* const_input2 = ConstInput(other_constants);
885 if (!const_input2) return false;
886
887 bool other_first_is_variable = other_constants[0] == nullptr;
888
889 // This is an x / (*) case. Swap the inputs.
890 if (first_is_variable) std::swap(const_input1, const_input2);
891 uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
892 const_input1, const_input2);
893 if (merged_id == 0) return false;
894
895 uint32_t non_const_id = other_first_is_variable
896 ? other_inst->GetSingleWordInOperand(0u)
897 : other_inst->GetSingleWordInOperand(1u);
898
899 uint32_t op1 = merged_id;
900 uint32_t op2 = non_const_id;
901 if (first_is_variable) std::swap(op1, op2);
902
903 // Convert to multiply
904 if (first_is_variable) inst->SetOpcode(other_inst->opcode());
905 inst->SetInOperands(
906 {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
907 return true;
908 }
909
910 return false;
911 };
912}
913
914// Fold divides of a constant and a negation.
915// Cases:
916// (-x) / 2 = x / -2
917// 2 / (-x) = 2 / -x
918FoldingRule MergeDivNegateArithmetic() {
919 return [](IRContext* context, Instruction* inst,
920 const std::vector<const analysis::Constant*>& constants) {
921 assert(inst->opcode() == SpvOpFDiv || inst->opcode() == SpvOpSDiv ||
922 inst->opcode() == SpvOpUDiv);
923 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
924 const analysis::Type* type =
925 context->get_type_mgr()->GetType(inst->type_id());
926 bool uses_float = HasFloatingPoint(type);
927 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
928
929 uint32_t width = ElementWidth(type);
930 if (width != 32 && width != 64) return false;
931
932 const analysis::Constant* const_input1 = ConstInput(constants);
933 if (!const_input1) return false;
934 Instruction* other_inst = NonConstInput(context, constants[0], inst);
935 if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
936 return false;
937
938 bool first_is_variable = constants[0] == nullptr;
939 if (other_inst->opcode() == SpvOpFNegate ||
940 other_inst->opcode() == SpvOpSNegate) {
941 uint32_t neg_id = NegateConstant(const_mgr, const_input1);
942
943 if (first_is_variable) {
944 inst->SetInOperands(
945 {{SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}},
946 {SPV_OPERAND_TYPE_ID, {neg_id}}});
947 } else {
948 inst->SetInOperands(
949 {{SPV_OPERAND_TYPE_ID, {neg_id}},
950 {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
951 }
952 return true;
953 }
954
955 return false;
956 };
957}
958
959// Folds addition of a constant and a negation.
960// Cases:
961// (-x) + 2 = 2 - x
962// 2 + (-x) = 2 - x
963FoldingRule MergeAddNegateArithmetic() {
964 return [](IRContext* context, Instruction* inst,
965 const std::vector<const analysis::Constant*>& constants) {
966 assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
967 const analysis::Type* type =
968 context->get_type_mgr()->GetType(inst->type_id());
969 bool uses_float = HasFloatingPoint(type);
970 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
971
972 const analysis::Constant* const_input1 = ConstInput(constants);
973 if (!const_input1) return false;
974 Instruction* other_inst = NonConstInput(context, constants[0], inst);
975 if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
976 return false;
977
978 if (other_inst->opcode() == SpvOpSNegate ||
979 other_inst->opcode() == SpvOpFNegate) {
980 inst->SetOpcode(HasFloatingPoint(type) ? SpvOpFSub : SpvOpISub);
981 uint32_t const_id = constants[0] ? inst->GetSingleWordInOperand(0u)
982 : inst->GetSingleWordInOperand(1u);
983 inst->SetInOperands(
984 {{SPV_OPERAND_TYPE_ID, {const_id}},
985 {SPV_OPERAND_TYPE_ID, {other_inst->GetSingleWordInOperand(0u)}}});
986 return true;
987 }
988 return false;
989 };
990}
991
992// Folds subtraction of a constant and a negation.
993// Cases:
994// (-x) - 2 = -2 - x
995// 2 - (-x) = x + 2
996FoldingRule MergeSubNegateArithmetic() {
997 return [](IRContext* context, Instruction* inst,
998 const std::vector<const analysis::Constant*>& constants) {
999 assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
1000 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1001 const analysis::Type* type =
1002 context->get_type_mgr()->GetType(inst->type_id());
1003 bool uses_float = HasFloatingPoint(type);
1004 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1005
1006 uint32_t width = ElementWidth(type);
1007 if (width != 32 && width != 64) return false;
1008
1009 const analysis::Constant* const_input1 = ConstInput(constants);
1010 if (!const_input1) return false;
1011 Instruction* other_inst = NonConstInput(context, constants[0], inst);
1012 if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1013 return false;
1014
1015 if (other_inst->opcode() == SpvOpSNegate ||
1016 other_inst->opcode() == SpvOpFNegate) {
1017 uint32_t op1 = 0;
1018 uint32_t op2 = 0;
1019 SpvOp opcode = inst->opcode();
1020 if (constants[0] != nullptr) {
1021 op1 = other_inst->GetSingleWordInOperand(0u);
1022 op2 = inst->GetSingleWordInOperand(0u);
1023 opcode = HasFloatingPoint(type) ? SpvOpFAdd : SpvOpIAdd;
1024 } else {
1025 op1 = NegateConstant(const_mgr, const_input1);
1026 op2 = other_inst->GetSingleWordInOperand(0u);
1027 }
1028
1029 inst->SetOpcode(opcode);
1030 inst->SetInOperands(
1031 {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1032 return true;
1033 }
1034 return false;
1035 };
1036}
1037
1038// Folds addition of an addition where each operation has a constant operand.
1039// Cases:
1040// (x + 2) + 2 = x + 4
1041// (2 + x) + 2 = x + 4
1042// 2 + (x + 2) = x + 4
1043// 2 + (2 + x) = x + 4
1044FoldingRule MergeAddAddArithmetic() {
1045 return [](IRContext* context, Instruction* inst,
1046 const std::vector<const analysis::Constant*>& constants) {
1047 assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
1048 const analysis::Type* type =
1049 context->get_type_mgr()->GetType(inst->type_id());
1050 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1051 bool uses_float = HasFloatingPoint(type);
1052 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1053
1054 uint32_t width = ElementWidth(type);
1055 if (width != 32 && width != 64) return false;
1056
1057 const analysis::Constant* const_input1 = ConstInput(constants);
1058 if (!const_input1) return false;
1059 Instruction* other_inst = NonConstInput(context, constants[0], inst);
1060 if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1061 return false;
1062
1063 if (other_inst->opcode() == SpvOpFAdd ||
1064 other_inst->opcode() == SpvOpIAdd) {
1065 std::vector<const analysis::Constant*> other_constants =
1066 const_mgr->GetOperandConstants(other_inst);
1067 const analysis::Constant* const_input2 = ConstInput(other_constants);
1068 if (!const_input2) return false;
1069
1070 Instruction* non_const_input =
1071 NonConstInput(context, other_constants[0], other_inst);
1072 uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
1073 const_input1, const_input2);
1074 if (merged_id == 0) return false;
1075
1076 inst->SetInOperands(
1077 {{SPV_OPERAND_TYPE_ID, {non_const_input->result_id()}},
1078 {SPV_OPERAND_TYPE_ID, {merged_id}}});
1079 return true;
1080 }
1081 return false;
1082 };
1083}
1084
1085// Folds addition of a subtraction where each operation has a constant operand.
1086// Cases:
1087// (x - 2) + 2 = x + 0
1088// (2 - x) + 2 = 4 - x
1089// 2 + (x - 2) = x + 0
1090// 2 + (2 - x) = 4 - x
1091FoldingRule MergeAddSubArithmetic() {
1092 return [](IRContext* context, Instruction* inst,
1093 const std::vector<const analysis::Constant*>& constants) {
1094 assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
1095 const analysis::Type* type =
1096 context->get_type_mgr()->GetType(inst->type_id());
1097 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1098 bool uses_float = HasFloatingPoint(type);
1099 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1100
1101 uint32_t width = ElementWidth(type);
1102 if (width != 32 && width != 64) return false;
1103
1104 const analysis::Constant* const_input1 = ConstInput(constants);
1105 if (!const_input1) return false;
1106 Instruction* other_inst = NonConstInput(context, constants[0], inst);
1107 if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1108 return false;
1109
1110 if (other_inst->opcode() == SpvOpFSub ||
1111 other_inst->opcode() == SpvOpISub) {
1112 std::vector<const analysis::Constant*> other_constants =
1113 const_mgr->GetOperandConstants(other_inst);
1114 const analysis::Constant* const_input2 = ConstInput(other_constants);
1115 if (!const_input2) return false;
1116
1117 bool first_is_variable = other_constants[0] == nullptr;
1118 SpvOp op = inst->opcode();
1119 uint32_t op1 = 0;
1120 uint32_t op2 = 0;
1121 if (first_is_variable) {
1122 // Subtract constants. Non-constant operand is first.
1123 op1 = other_inst->GetSingleWordInOperand(0u);
1124 op2 = PerformOperation(const_mgr, other_inst->opcode(), const_input1,
1125 const_input2);
1126 } else {
1127 // Add constants. Constant operand is first. Change the opcode.
1128 op1 = PerformOperation(const_mgr, inst->opcode(), const_input1,
1129 const_input2);
1130 op2 = other_inst->GetSingleWordInOperand(1u);
1131 op = other_inst->opcode();
1132 }
1133 if (op1 == 0 || op2 == 0) return false;
1134
1135 inst->SetOpcode(op);
1136 inst->SetInOperands(
1137 {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1138 return true;
1139 }
1140 return false;
1141 };
1142}
1143
1144// Folds subtraction of an addition where each operand has a constant operand.
1145// Cases:
1146// (x + 2) - 2 = x + 0
1147// (2 + x) - 2 = x + 0
1148// 2 - (x + 2) = 0 - x
1149// 2 - (2 + x) = 0 - x
1150FoldingRule MergeSubAddArithmetic() {
1151 return [](IRContext* context, Instruction* inst,
1152 const std::vector<const analysis::Constant*>& constants) {
1153 assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
1154 const analysis::Type* type =
1155 context->get_type_mgr()->GetType(inst->type_id());
1156 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1157 bool uses_float = HasFloatingPoint(type);
1158 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1159
1160 uint32_t width = ElementWidth(type);
1161 if (width != 32 && width != 64) return false;
1162
1163 const analysis::Constant* const_input1 = ConstInput(constants);
1164 if (!const_input1) return false;
1165 Instruction* other_inst = NonConstInput(context, constants[0], inst);
1166 if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1167 return false;
1168
1169 if (other_inst->opcode() == SpvOpFAdd ||
1170 other_inst->opcode() == SpvOpIAdd) {
1171 std::vector<const analysis::Constant*> other_constants =
1172 const_mgr->GetOperandConstants(other_inst);
1173 const analysis::Constant* const_input2 = ConstInput(other_constants);
1174 if (!const_input2) return false;
1175
1176 Instruction* non_const_input =
1177 NonConstInput(context, other_constants[0], other_inst);
1178
1179 // If the first operand of the sub is not a constant, swap the constants
1180 // so the subtraction has the correct operands.
1181 if (constants[0] == nullptr) std::swap(const_input1, const_input2);
1182 // Subtract the constants.
1183 uint32_t merged_id = PerformOperation(const_mgr, inst->opcode(),
1184 const_input1, const_input2);
1185 SpvOp op = inst->opcode();
1186 uint32_t op1 = 0;
1187 uint32_t op2 = 0;
1188 if (constants[0] == nullptr) {
1189 // Non-constant operand is first. Change the opcode.
1190 op1 = non_const_input->result_id();
1191 op2 = merged_id;
1192 op = other_inst->opcode();
1193 } else {
1194 // Constant operand is first.
1195 op1 = merged_id;
1196 op2 = non_const_input->result_id();
1197 }
1198 if (op1 == 0 || op2 == 0) return false;
1199
1200 inst->SetOpcode(op);
1201 inst->SetInOperands(
1202 {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1203 return true;
1204 }
1205 return false;
1206 };
1207}
1208
1209// Folds subtraction of a subtraction where each operand has a constant operand.
1210// Cases:
1211// (x - 2) - 2 = x - 4
1212// (2 - x) - 2 = 0 - x
1213// 2 - (x - 2) = 4 - x
1214// 2 - (2 - x) = x + 0
1215FoldingRule MergeSubSubArithmetic() {
1216 return [](IRContext* context, Instruction* inst,
1217 const std::vector<const analysis::Constant*>& constants) {
1218 assert(inst->opcode() == SpvOpFSub || inst->opcode() == SpvOpISub);
1219 const analysis::Type* type =
1220 context->get_type_mgr()->GetType(inst->type_id());
1221 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1222 bool uses_float = HasFloatingPoint(type);
1223 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1224
1225 uint32_t width = ElementWidth(type);
1226 if (width != 32 && width != 64) return false;
1227
1228 const analysis::Constant* const_input1 = ConstInput(constants);
1229 if (!const_input1) return false;
1230 Instruction* other_inst = NonConstInput(context, constants[0], inst);
1231 if (uses_float && !other_inst->IsFloatingPointFoldingAllowed())
1232 return false;
1233
1234 if (other_inst->opcode() == SpvOpFSub ||
1235 other_inst->opcode() == SpvOpISub) {
1236 std::vector<const analysis::Constant*> other_constants =
1237 const_mgr->GetOperandConstants(other_inst);
1238 const analysis::Constant* const_input2 = ConstInput(other_constants);
1239 if (!const_input2) return false;
1240
1241 Instruction* non_const_input =
1242 NonConstInput(context, other_constants[0], other_inst);
1243
1244 // Merge the constants.
1245 uint32_t merged_id = 0;
1246 SpvOp merge_op = inst->opcode();
1247 if (other_constants[0] == nullptr) {
1248 merge_op = uses_float ? SpvOpFAdd : SpvOpIAdd;
1249 } else if (constants[0] == nullptr) {
1250 std::swap(const_input1, const_input2);
1251 }
1252 merged_id =
1253 PerformOperation(const_mgr, merge_op, const_input1, const_input2);
1254 if (merged_id == 0) return false;
1255
1256 SpvOp op = inst->opcode();
1257 if (constants[0] != nullptr && other_constants[0] != nullptr) {
1258 // Change the operation.
1259 op = uses_float ? SpvOpFAdd : SpvOpIAdd;
1260 }
1261
1262 uint32_t op1 = 0;
1263 uint32_t op2 = 0;
1264 if ((constants[0] == nullptr) ^ (other_constants[0] == nullptr)) {
1265 op1 = merged_id;
1266 op2 = non_const_input->result_id();
1267 } else {
1268 op1 = non_const_input->result_id();
1269 op2 = merged_id;
1270 }
1271
1272 inst->SetOpcode(op);
1273 inst->SetInOperands(
1274 {{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}});
1275 return true;
1276 }
1277 return false;
1278 };
1279}
1280
1281// Helper function for MergeGenericAddSubArithmetic. If |addend| and
1282// subtrahend of |sub| is the same, merge to copy of minuend of |sub|.
1283bool MergeGenericAddendSub(uint32_t addend, uint32_t sub, Instruction* inst) {
1284 IRContext* context = inst->context();
1285 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1286 Instruction* sub_inst = def_use_mgr->GetDef(sub);
1287 if (sub_inst->opcode() != SpvOpFSub && sub_inst->opcode() != SpvOpISub)
1288 return false;
1289 if (sub_inst->opcode() == SpvOpFSub &&
1290 !sub_inst->IsFloatingPointFoldingAllowed())
1291 return false;
1292 if (addend != sub_inst->GetSingleWordInOperand(1)) return false;
1293 inst->SetOpcode(SpvOpCopyObject);
1294 inst->SetInOperands(
1295 {{SPV_OPERAND_TYPE_ID, {sub_inst->GetSingleWordInOperand(0)}}});
1296 context->UpdateDefUse(inst);
1297 return true;
1298}
1299
1300// Folds addition of a subtraction where the subtrahend is equal to the
1301// other addend. Return a copy of the minuend. Accepts generic (const and
1302// non-const) operands.
1303// Cases:
1304// (a - b) + b = a
1305// b + (a - b) = a
1306FoldingRule MergeGenericAddSubArithmetic() {
1307 return [](IRContext* context, Instruction* inst,
1308 const std::vector<const analysis::Constant*>&) {
1309 assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
1310 const analysis::Type* type =
1311 context->get_type_mgr()->GetType(inst->type_id());
1312 bool uses_float = HasFloatingPoint(type);
1313 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1314
1315 uint32_t width = ElementWidth(type);
1316 if (width != 32 && width != 64) return false;
1317
1318 uint32_t add_op0 = inst->GetSingleWordInOperand(0);
1319 uint32_t add_op1 = inst->GetSingleWordInOperand(1);
1320 if (MergeGenericAddendSub(add_op0, add_op1, inst)) return true;
1321 return MergeGenericAddendSub(add_op1, add_op0, inst);
1322 };
1323}
1324
1325// Helper function for FactorAddMuls. If |factor0_0| is the same as |factor1_0|,
1326// generate |factor0_0| * (|factor0_1| + |factor1_1|).
1327bool FactorAddMulsOpnds(uint32_t factor0_0, uint32_t factor0_1,
1328 uint32_t factor1_0, uint32_t factor1_1,
1329 Instruction* inst) {
1330 IRContext* context = inst->context();
1331 if (factor0_0 != factor1_0) return false;
1332 InstructionBuilder ir_builder(
1333 context, inst,
1334 IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
1335 Instruction* new_add_inst = ir_builder.AddBinaryOp(
1336 inst->type_id(), inst->opcode(), factor0_1, factor1_1);
1337 inst->SetOpcode(inst->opcode() == SpvOpFAdd ? SpvOpFMul : SpvOpIMul);
1338 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {factor0_0}},
1339 {SPV_OPERAND_TYPE_ID, {new_add_inst->result_id()}}});
1340 context->UpdateDefUse(inst);
1341 return true;
1342}
1343
1344// Perform the following factoring identity, handling all operand order
1345// combinations: (a * b) + (a * c) = a * (b + c)
1346FoldingRule FactorAddMuls() {
1347 return [](IRContext* context, Instruction* inst,
1348 const std::vector<const analysis::Constant*>&) {
1349 assert(inst->opcode() == SpvOpFAdd || inst->opcode() == SpvOpIAdd);
1350 const analysis::Type* type =
1351 context->get_type_mgr()->GetType(inst->type_id());
1352 bool uses_float = HasFloatingPoint(type);
1353 if (uses_float && !inst->IsFloatingPointFoldingAllowed()) return false;
1354
1355 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1356 uint32_t add_op0 = inst->GetSingleWordInOperand(0);
1357 Instruction* add_op0_inst = def_use_mgr->GetDef(add_op0);
1358 if (add_op0_inst->opcode() != SpvOpFMul &&
1359 add_op0_inst->opcode() != SpvOpIMul)
1360 return false;
1361 uint32_t add_op1 = inst->GetSingleWordInOperand(1);
1362 Instruction* add_op1_inst = def_use_mgr->GetDef(add_op1);
1363 if (add_op1_inst->opcode() != SpvOpFMul &&
1364 add_op1_inst->opcode() != SpvOpIMul)
1365 return false;
1366
1367 // Only perform this optimization if both of the muls only have one use.
1368 // Otherwise this is a deoptimization in size and performance.
1369 if (def_use_mgr->NumUses(add_op0_inst) > 1) return false;
1370 if (def_use_mgr->NumUses(add_op1_inst) > 1) return false;
1371
1372 if (add_op0_inst->opcode() == SpvOpFMul &&
1373 (!add_op0_inst->IsFloatingPointFoldingAllowed() ||
1374 !add_op1_inst->IsFloatingPointFoldingAllowed()))
1375 return false;
1376
1377 for (int i = 0; i < 2; i++) {
1378 for (int j = 0; j < 2; j++) {
1379 // Check if operand i in add_op0_inst matches operand j in add_op1_inst.
1380 if (FactorAddMulsOpnds(add_op0_inst->GetSingleWordInOperand(i),
1381 add_op0_inst->GetSingleWordInOperand(1 - i),
1382 add_op1_inst->GetSingleWordInOperand(j),
1383 add_op1_inst->GetSingleWordInOperand(1 - j),
1384 inst))
1385 return true;
1386 }
1387 }
1388 return false;
1389 };
1390}
1391
1392FoldingRule IntMultipleBy1() {
1393 return [](IRContext*, Instruction* inst,
1394 const std::vector<const analysis::Constant*>& constants) {
1395 assert(inst->opcode() == SpvOpIMul && "Wrong opcode. Should be OpIMul.");
1396 for (uint32_t i = 0; i < 2; i++) {
1397 if (constants[i] == nullptr) {
1398 continue;
1399 }
1400 const analysis::IntConstant* int_constant = constants[i]->AsIntConstant();
1401 if (int_constant) {
1402 uint32_t width = ElementWidth(int_constant->type());
1403 if (width != 32 && width != 64) return false;
1404 bool is_one = (width == 32) ? int_constant->GetU32BitValue() == 1u
1405 : int_constant->GetU64BitValue() == 1ull;
1406 if (is_one) {
1407 inst->SetOpcode(SpvOpCopyObject);
1408 inst->SetInOperands(
1409 {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1 - i)}}});
1410 return true;
1411 }
1412 }
1413 }
1414 return false;
1415 };
1416}
1417
1418FoldingRule CompositeConstructFeedingExtract() {
1419 return [](IRContext* context, Instruction* inst,
1420 const std::vector<const analysis::Constant*>&) {
1421 // If the input to an OpCompositeExtract is an OpCompositeConstruct,
1422 // then we can simply use the appropriate element in the construction.
1423 assert(inst->opcode() == SpvOpCompositeExtract &&
1424 "Wrong opcode. Should be OpCompositeExtract.");
1425 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1426 analysis::TypeManager* type_mgr = context->get_type_mgr();
1427
1428 // If there are no index operands, then this rule cannot do anything.
1429 if (inst->NumInOperands() <= 1) {
1430 return false;
1431 }
1432
1433 uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1434 Instruction* cinst = def_use_mgr->GetDef(cid);
1435
1436 if (cinst->opcode() != SpvOpCompositeConstruct) {
1437 return false;
1438 }
1439
1440 std::vector<Operand> operands;
1441 analysis::Type* composite_type = type_mgr->GetType(cinst->type_id());
1442 if (composite_type->AsVector() == nullptr) {
1443 // Get the element being extracted from the OpCompositeConstruct
1444 // Since it is not a vector, it is simple to extract the single element.
1445 uint32_t element_index = inst->GetSingleWordInOperand(1);
1446 uint32_t element_id = cinst->GetSingleWordInOperand(element_index);
1447 operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}});
1448
1449 // Add the remaining indices for extraction.
1450 for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
1451 operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
1452 {inst->GetSingleWordInOperand(i)}});
1453 }
1454
1455 } else {
1456 // With vectors we have to handle the case where it is concatenating
1457 // vectors.
1458 assert(inst->NumInOperands() == 2 &&
1459 "Expecting a vector of scalar values.");
1460
1461 uint32_t element_index = inst->GetSingleWordInOperand(1);
1462 for (uint32_t construct_index = 0;
1463 construct_index < cinst->NumInOperands(); ++construct_index) {
1464 uint32_t element_id = cinst->GetSingleWordInOperand(construct_index);
1465 Instruction* element_def = def_use_mgr->GetDef(element_id);
1466 analysis::Vector* element_type =
1467 type_mgr->GetType(element_def->type_id())->AsVector();
1468 if (element_type) {
1469 uint32_t vector_size = element_type->element_count();
1470 if (vector_size < element_index) {
1471 // The element we want comes after this vector.
1472 element_index -= vector_size;
1473 } else {
1474 // We want an element of this vector.
1475 operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}});
1476 operands.push_back(
1477 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {element_index}});
1478 break;
1479 }
1480 } else {
1481 if (element_index == 0) {
1482 // This is a scalar, and we this is the element we are extracting.
1483 operands.push_back({SPV_OPERAND_TYPE_ID, {element_id}});
1484 break;
1485 } else {
1486 // Skip over this scalar value.
1487 --element_index;
1488 }
1489 }
1490 }
1491 }
1492
1493 // If there were no extra indices, then we have the final object. No need
1494 // to extract even more.
1495 if (operands.size() == 1) {
1496 inst->SetOpcode(SpvOpCopyObject);
1497 }
1498
1499 inst->SetInOperands(std::move(operands));
1500 return true;
1501 };
1502}
1503
1504// If the OpCompositeConstruct is simply putting back together elements that
1505// where extracted from the same source, we can simply reuse the source.
1506//
1507// This is a common code pattern because of the way that scalar replacement
1508// works.
1509bool CompositeExtractFeedingConstruct(
1510 IRContext* context, Instruction* inst,
1511 const std::vector<const analysis::Constant*>&) {
1512 assert(inst->opcode() == SpvOpCompositeConstruct &&
1513 "Wrong opcode. Should be OpCompositeConstruct.");
1514 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1515 uint32_t original_id = 0;
1516
1517 if (inst->NumInOperands() == 0) {
1518 // The struct being constructed has no members.
1519 return false;
1520 }
1521
1522 // Check each element to make sure they are:
1523 // - extractions
1524 // - extracting the same position they are inserting
1525 // - all extract from the same id.
1526 for (uint32_t i = 0; i < inst->NumInOperands(); ++i) {
1527 const uint32_t element_id = inst->GetSingleWordInOperand(i);
1528 Instruction* element_inst = def_use_mgr->GetDef(element_id);
1529
1530 if (element_inst->opcode() != SpvOpCompositeExtract) {
1531 return false;
1532 }
1533
1534 if (element_inst->NumInOperands() != 2) {
1535 return false;
1536 }
1537
1538 if (element_inst->GetSingleWordInOperand(1) != i) {
1539 return false;
1540 }
1541
1542 if (i == 0) {
1543 original_id =
1544 element_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1545 } else if (original_id !=
1546 element_inst->GetSingleWordInOperand(kExtractCompositeIdInIdx)) {
1547 return false;
1548 }
1549 }
1550
1551 // The last check it to see that the object being extracted from is the
1552 // correct type.
1553 Instruction* original_inst = def_use_mgr->GetDef(original_id);
1554 if (original_inst->type_id() != inst->type_id()) {
1555 return false;
1556 }
1557
1558 // Simplify by using the original object.
1559 inst->SetOpcode(SpvOpCopyObject);
1560 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {original_id}}});
1561 return true;
1562}
1563
1564FoldingRule InsertFeedingExtract() {
1565 return [](IRContext* context, Instruction* inst,
1566 const std::vector<const analysis::Constant*>&) {
1567 assert(inst->opcode() == SpvOpCompositeExtract &&
1568 "Wrong opcode. Should be OpCompositeExtract.");
1569 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1570 uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1571 Instruction* cinst = def_use_mgr->GetDef(cid);
1572
1573 if (cinst->opcode() != SpvOpCompositeInsert) {
1574 return false;
1575 }
1576
1577 // Find the first position where the list of insert and extract indicies
1578 // differ, if at all.
1579 uint32_t i;
1580 for (i = 1; i < inst->NumInOperands(); ++i) {
1581 if (i + 1 >= cinst->NumInOperands()) {
1582 break;
1583 }
1584
1585 if (inst->GetSingleWordInOperand(i) !=
1586 cinst->GetSingleWordInOperand(i + 1)) {
1587 break;
1588 }
1589 }
1590
1591 // We are extracting the element that was inserted.
1592 if (i == inst->NumInOperands() && i + 1 == cinst->NumInOperands()) {
1593 inst->SetOpcode(SpvOpCopyObject);
1594 inst->SetInOperands(
1595 {{SPV_OPERAND_TYPE_ID,
1596 {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}}});
1597 return true;
1598 }
1599
1600 // Extracting the value that was inserted along with values for the base
1601 // composite. Cannot do anything.
1602 if (i == inst->NumInOperands()) {
1603 return false;
1604 }
1605
1606 // Extracting an element of the value that was inserted. Extract from
1607 // that value directly.
1608 if (i + 1 == cinst->NumInOperands()) {
1609 std::vector<Operand> operands;
1610 operands.push_back(
1611 {SPV_OPERAND_TYPE_ID,
1612 {cinst->GetSingleWordInOperand(kInsertObjectIdInIdx)}});
1613 for (; i < inst->NumInOperands(); ++i) {
1614 operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
1615 {inst->GetSingleWordInOperand(i)}});
1616 }
1617 inst->SetInOperands(std::move(operands));
1618 return true;
1619 }
1620
1621 // Extracting a value that is disjoint from the element being inserted.
1622 // Rewrite the extract to use the composite input to the insert.
1623 std::vector<Operand> operands;
1624 operands.push_back(
1625 {SPV_OPERAND_TYPE_ID,
1626 {cinst->GetSingleWordInOperand(kInsertCompositeIdInIdx)}});
1627 for (i = 1; i < inst->NumInOperands(); ++i) {
1628 operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER,
1629 {inst->GetSingleWordInOperand(i)}});
1630 }
1631 inst->SetInOperands(std::move(operands));
1632 return true;
1633 };
1634}
1635
1636// When a VectorShuffle is feeding an Extract, we can extract from one of the
1637// operands of the VectorShuffle. We just need to adjust the index in the
1638// extract instruction.
1639FoldingRule VectorShuffleFeedingExtract() {
1640 return [](IRContext* context, Instruction* inst,
1641 const std::vector<const analysis::Constant*>&) {
1642 assert(inst->opcode() == SpvOpCompositeExtract &&
1643 "Wrong opcode. Should be OpCompositeExtract.");
1644 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1645 analysis::TypeManager* type_mgr = context->get_type_mgr();
1646 uint32_t cid = inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1647 Instruction* cinst = def_use_mgr->GetDef(cid);
1648
1649 if (cinst->opcode() != SpvOpVectorShuffle) {
1650 return false;
1651 }
1652
1653 // Find the size of the first vector operand of the VectorShuffle
1654 Instruction* first_input =
1655 def_use_mgr->GetDef(cinst->GetSingleWordInOperand(0));
1656 analysis::Type* first_input_type =
1657 type_mgr->GetType(first_input->type_id());
1658 assert(first_input_type->AsVector() &&
1659 "Input to vector shuffle should be vectors.");
1660 uint32_t first_input_size = first_input_type->AsVector()->element_count();
1661
1662 // Get index of the element the vector shuffle is placing in the position
1663 // being extracted.
1664 uint32_t new_index =
1665 cinst->GetSingleWordInOperand(2 + inst->GetSingleWordInOperand(1));
1666
1667 // Extracting an undefined value so fold this extract into an undef.
1668 const uint32_t undef_literal_value = 0xffffffff;
1669 if (new_index == undef_literal_value) {
1670 inst->SetOpcode(SpvOpUndef);
1671 inst->SetInOperands({});
1672 return true;
1673 }
1674
1675 // Get the id of the of the vector the elemtent comes from, and update the
1676 // index if needed.
1677 uint32_t new_vector = 0;
1678 if (new_index < first_input_size) {
1679 new_vector = cinst->GetSingleWordInOperand(0);
1680 } else {
1681 new_vector = cinst->GetSingleWordInOperand(1);
1682 new_index -= first_input_size;
1683 }
1684
1685 // Update the extract instruction.
1686 inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
1687 inst->SetInOperand(1, {new_index});
1688 return true;
1689 };
1690}
1691
1692// When an FMix with is feeding an Extract that extracts an element whose
1693// corresponding |a| in the FMix is 0 or 1, we can extract from one of the
1694// operands of the FMix.
1695FoldingRule FMixFeedingExtract() {
1696 return [](IRContext* context, Instruction* inst,
1697 const std::vector<const analysis::Constant*>&) {
1698 assert(inst->opcode() == SpvOpCompositeExtract &&
1699 "Wrong opcode. Should be OpCompositeExtract.");
1700 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
1701 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
1702
1703 uint32_t composite_id =
1704 inst->GetSingleWordInOperand(kExtractCompositeIdInIdx);
1705 Instruction* composite_inst = def_use_mgr->GetDef(composite_id);
1706
1707 if (composite_inst->opcode() != SpvOpExtInst) {
1708 return false;
1709 }
1710
1711 uint32_t inst_set_id =
1712 context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
1713
1714 if (composite_inst->GetSingleWordInOperand(kExtInstSetIdInIdx) !=
1715 inst_set_id ||
1716 composite_inst->GetSingleWordInOperand(kExtInstInstructionInIdx) !=
1717 GLSLstd450FMix) {
1718 return false;
1719 }
1720
1721 // Get the |a| for the FMix instruction.
1722 uint32_t a_id = composite_inst->GetSingleWordInOperand(kFMixAIdInIdx);
1723 std::unique_ptr<Instruction> a(inst->Clone(context));
1724 a->SetInOperand(kExtractCompositeIdInIdx, {a_id});
1725 context->get_instruction_folder().FoldInstruction(a.get());
1726
1727 if (a->opcode() != SpvOpCopyObject) {
1728 return false;
1729 }
1730
1731 const analysis::Constant* a_const =
1732 const_mgr->FindDeclaredConstant(a->GetSingleWordInOperand(0));
1733
1734 if (!a_const) {
1735 return false;
1736 }
1737
1738 bool use_x = false;
1739
1740 assert(a_const->type()->AsFloat());
1741 double element_value = a_const->GetValueAsDouble();
1742 if (element_value == 0.0) {
1743 use_x = true;
1744 } else if (element_value == 1.0) {
1745 use_x = false;
1746 } else {
1747 return false;
1748 }
1749
1750 // Get the id of the of the vector the element comes from.
1751 uint32_t new_vector = 0;
1752 if (use_x) {
1753 new_vector = composite_inst->GetSingleWordInOperand(kFMixXIdInIdx);
1754 } else {
1755 new_vector = composite_inst->GetSingleWordInOperand(kFMixYIdInIdx);
1756 }
1757
1758 // Update the extract instruction.
1759 inst->SetInOperand(kExtractCompositeIdInIdx, {new_vector});
1760 return true;
1761 };
1762}
1763
1764FoldingRule RedundantPhi() {
1765 // An OpPhi instruction where all values are the same or the result of the phi
1766 // itself, can be replaced by the value itself.
1767 return [](IRContext*, Instruction* inst,
1768 const std::vector<const analysis::Constant*>&) {
1769 assert(inst->opcode() == SpvOpPhi && "Wrong opcode. Should be OpPhi.");
1770
1771 uint32_t incoming_value = 0;
1772
1773 for (uint32_t i = 0; i < inst->NumInOperands(); i += 2) {
1774 uint32_t op_id = inst->GetSingleWordInOperand(i);
1775 if (op_id == inst->result_id()) {
1776 continue;
1777 }
1778
1779 if (incoming_value == 0) {
1780 incoming_value = op_id;
1781 } else if (op_id != incoming_value) {
1782 // Found two possible value. Can't simplify.
1783 return false;
1784 }
1785 }
1786
1787 if (incoming_value == 0) {
1788 // Code looks invalid. Don't do anything.
1789 return false;
1790 }
1791
1792 // We have a single incoming value. Simplify using that value.
1793 inst->SetOpcode(SpvOpCopyObject);
1794 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {incoming_value}}});
1795 return true;
1796 };
1797}
1798
1799FoldingRule RedundantSelect() {
1800 // An OpSelect instruction where both values are the same or the condition is
1801 // constant can be replaced by one of the values
1802 return [](IRContext*, Instruction* inst,
1803 const std::vector<const analysis::Constant*>& constants) {
1804 assert(inst->opcode() == SpvOpSelect &&
1805 "Wrong opcode. Should be OpSelect.");
1806 assert(inst->NumInOperands() == 3);
1807 assert(constants.size() == 3);
1808
1809 uint32_t true_id = inst->GetSingleWordInOperand(1);
1810 uint32_t false_id = inst->GetSingleWordInOperand(2);
1811
1812 if (true_id == false_id) {
1813 // Both results are the same, condition doesn't matter
1814 inst->SetOpcode(SpvOpCopyObject);
1815 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
1816 return true;
1817 } else if (constants[0]) {
1818 const analysis::Type* type = constants[0]->type();
1819 if (type->AsBool()) {
1820 // Scalar constant value, select the corresponding value.
1821 inst->SetOpcode(SpvOpCopyObject);
1822 if (constants[0]->AsNullConstant() ||
1823 !constants[0]->AsBoolConstant()->value()) {
1824 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
1825 } else {
1826 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {true_id}}});
1827 }
1828 return true;
1829 } else {
1830 assert(type->AsVector());
1831 if (constants[0]->AsNullConstant()) {
1832 // All values come from false id.
1833 inst->SetOpcode(SpvOpCopyObject);
1834 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {false_id}}});
1835 return true;
1836 } else {
1837 // Convert to a vector shuffle.
1838 std::vector<Operand> ops;
1839 ops.push_back({SPV_OPERAND_TYPE_ID, {true_id}});
1840 ops.push_back({SPV_OPERAND_TYPE_ID, {false_id}});
1841 const analysis::VectorConstant* vector_const =
1842 constants[0]->AsVectorConstant();
1843 uint32_t size =
1844 static_cast<uint32_t>(vector_const->GetComponents().size());
1845 for (uint32_t i = 0; i != size; ++i) {
1846 const analysis::Constant* component =
1847 vector_const->GetComponents()[i];
1848 if (component->AsNullConstant() ||
1849 !component->AsBoolConstant()->value()) {
1850 // Selecting from the false vector which is the second input
1851 // vector to the shuffle. Offset the index by |size|.
1852 ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i + size}});
1853 } else {
1854 // Selecting from true vector which is the first input vector to
1855 // the shuffle.
1856 ops.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}});
1857 }
1858 }
1859
1860 inst->SetOpcode(SpvOpVectorShuffle);
1861 inst->SetInOperands(std::move(ops));
1862 return true;
1863 }
1864 }
1865 }
1866
1867 return false;
1868 };
1869}
1870
1871enum class FloatConstantKind { Unknown, Zero, One };
1872
1873FloatConstantKind getFloatConstantKind(const analysis::Constant* constant) {
1874 if (constant == nullptr) {
1875 return FloatConstantKind::Unknown;
1876 }
1877
1878 assert(HasFloatingPoint(constant->type()) && "Unexpected constant type");
1879
1880 if (constant->AsNullConstant()) {
1881 return FloatConstantKind::Zero;
1882 } else if (const analysis::VectorConstant* vc =
1883 constant->AsVectorConstant()) {
1884 const std::vector<const analysis::Constant*>& components =
1885 vc->GetComponents();
1886 assert(!components.empty());
1887
1888 FloatConstantKind kind = getFloatConstantKind(components[0]);
1889
1890 for (size_t i = 1; i < components.size(); ++i) {
1891 if (getFloatConstantKind(components[i]) != kind) {
1892 return FloatConstantKind::Unknown;
1893 }
1894 }
1895
1896 return kind;
1897 } else if (const analysis::FloatConstant* fc = constant->AsFloatConstant()) {
1898 if (fc->IsZero()) return FloatConstantKind::Zero;
1899
1900 uint32_t width = fc->type()->AsFloat()->width();
1901 if (width != 32 && width != 64) return FloatConstantKind::Unknown;
1902
1903 double value = (width == 64) ? fc->GetDoubleValue() : fc->GetFloatValue();
1904
1905 if (value == 0.0) {
1906 return FloatConstantKind::Zero;
1907 } else if (value == 1.0) {
1908 return FloatConstantKind::One;
1909 } else {
1910 return FloatConstantKind::Unknown;
1911 }
1912 } else {
1913 return FloatConstantKind::Unknown;
1914 }
1915}
1916
1917FoldingRule RedundantFAdd() {
1918 return [](IRContext*, Instruction* inst,
1919 const std::vector<const analysis::Constant*>& constants) {
1920 assert(inst->opcode() == SpvOpFAdd && "Wrong opcode. Should be OpFAdd.");
1921 assert(constants.size() == 2);
1922
1923 if (!inst->IsFloatingPointFoldingAllowed()) {
1924 return false;
1925 }
1926
1927 FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
1928 FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
1929
1930 if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
1931 inst->SetOpcode(SpvOpCopyObject);
1932 inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
1933 {inst->GetSingleWordInOperand(
1934 kind0 == FloatConstantKind::Zero ? 1 : 0)}}});
1935 return true;
1936 }
1937
1938 return false;
1939 };
1940}
1941
1942FoldingRule RedundantFSub() {
1943 return [](IRContext*, Instruction* inst,
1944 const std::vector<const analysis::Constant*>& constants) {
1945 assert(inst->opcode() == SpvOpFSub && "Wrong opcode. Should be OpFSub.");
1946 assert(constants.size() == 2);
1947
1948 if (!inst->IsFloatingPointFoldingAllowed()) {
1949 return false;
1950 }
1951
1952 FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
1953 FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
1954
1955 if (kind0 == FloatConstantKind::Zero) {
1956 inst->SetOpcode(SpvOpFNegate);
1957 inst->SetInOperands(
1958 {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1)}}});
1959 return true;
1960 }
1961
1962 if (kind1 == FloatConstantKind::Zero) {
1963 inst->SetOpcode(SpvOpCopyObject);
1964 inst->SetInOperands(
1965 {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
1966 return true;
1967 }
1968
1969 return false;
1970 };
1971}
1972
1973FoldingRule RedundantFMul() {
1974 return [](IRContext*, Instruction* inst,
1975 const std::vector<const analysis::Constant*>& constants) {
1976 assert(inst->opcode() == SpvOpFMul && "Wrong opcode. Should be OpFMul.");
1977 assert(constants.size() == 2);
1978
1979 if (!inst->IsFloatingPointFoldingAllowed()) {
1980 return false;
1981 }
1982
1983 FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
1984 FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
1985
1986 if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::Zero) {
1987 inst->SetOpcode(SpvOpCopyObject);
1988 inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
1989 {inst->GetSingleWordInOperand(
1990 kind0 == FloatConstantKind::Zero ? 0 : 1)}}});
1991 return true;
1992 }
1993
1994 if (kind0 == FloatConstantKind::One || kind1 == FloatConstantKind::One) {
1995 inst->SetOpcode(SpvOpCopyObject);
1996 inst->SetInOperands({{SPV_OPERAND_TYPE_ID,
1997 {inst->GetSingleWordInOperand(
1998 kind0 == FloatConstantKind::One ? 1 : 0)}}});
1999 return true;
2000 }
2001
2002 return false;
2003 };
2004}
2005
2006FoldingRule RedundantFDiv() {
2007 return [](IRContext*, Instruction* inst,
2008 const std::vector<const analysis::Constant*>& constants) {
2009 assert(inst->opcode() == SpvOpFDiv && "Wrong opcode. Should be OpFDiv.");
2010 assert(constants.size() == 2);
2011
2012 if (!inst->IsFloatingPointFoldingAllowed()) {
2013 return false;
2014 }
2015
2016 FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
2017 FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
2018
2019 if (kind0 == FloatConstantKind::Zero) {
2020 inst->SetOpcode(SpvOpCopyObject);
2021 inst->SetInOperands(
2022 {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
2023 return true;
2024 }
2025
2026 if (kind1 == FloatConstantKind::One) {
2027 inst->SetOpcode(SpvOpCopyObject);
2028 inst->SetInOperands(
2029 {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
2030 return true;
2031 }
2032
2033 return false;
2034 };
2035}
2036
2037FoldingRule RedundantFMix() {
2038 return [](IRContext* context, Instruction* inst,
2039 const std::vector<const analysis::Constant*>& constants) {
2040 assert(inst->opcode() == SpvOpExtInst &&
2041 "Wrong opcode. Should be OpExtInst.");
2042
2043 if (!inst->IsFloatingPointFoldingAllowed()) {
2044 return false;
2045 }
2046
2047 uint32_t instSetId =
2048 context->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
2049
2050 if (inst->GetSingleWordInOperand(kExtInstSetIdInIdx) == instSetId &&
2051 inst->GetSingleWordInOperand(kExtInstInstructionInIdx) ==
2052 GLSLstd450FMix) {
2053 assert(constants.size() == 5);
2054
2055 FloatConstantKind kind4 = getFloatConstantKind(constants[4]);
2056
2057 if (kind4 == FloatConstantKind::Zero || kind4 == FloatConstantKind::One) {
2058 inst->SetOpcode(SpvOpCopyObject);
2059 inst->SetInOperands(
2060 {{SPV_OPERAND_TYPE_ID,
2061 {inst->GetSingleWordInOperand(kind4 == FloatConstantKind::Zero
2062 ? kFMixXIdInIdx
2063 : kFMixYIdInIdx)}}});
2064 return true;
2065 }
2066 }
2067
2068 return false;
2069 };
2070}
2071
2072// This rule handles addition of zero for integers.
2073FoldingRule RedundantIAdd() {
2074 return [](IRContext* context, Instruction* inst,
2075 const std::vector<const analysis::Constant*>& constants) {
2076 assert(inst->opcode() == SpvOpIAdd && "Wrong opcode. Should be OpIAdd.");
2077
2078 uint32_t operand = std::numeric_limits<uint32_t>::max();
2079 const analysis::Type* operand_type = nullptr;
2080 if (constants[0] && constants[0]->IsZero()) {
2081 operand = inst->GetSingleWordInOperand(1);
2082 operand_type = constants[0]->type();
2083 } else if (constants[1] && constants[1]->IsZero()) {
2084 operand = inst->GetSingleWordInOperand(0);
2085 operand_type = constants[1]->type();
2086 }
2087
2088 if (operand != std::numeric_limits<uint32_t>::max()) {
2089 const analysis::Type* inst_type =
2090 context->get_type_mgr()->GetType(inst->type_id());
2091 if (inst_type->IsSame(operand_type)) {
2092 inst->SetOpcode(SpvOpCopyObject);
2093 } else {
2094 inst->SetOpcode(SpvOpBitcast);
2095 }
2096 inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {operand}}});
2097 return true;
2098 }
2099 return false;
2100 };
2101}
2102
2103// This rule look for a dot with a constant vector containing a single 1 and
2104// the rest 0s. This is the same as doing an extract.
2105FoldingRule DotProductDoingExtract() {
2106 return [](IRContext* context, Instruction* inst,
2107 const std::vector<const analysis::Constant*>& constants) {
2108 assert(inst->opcode() == SpvOpDot && "Wrong opcode. Should be OpDot.");
2109
2110 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
2111
2112 if (!inst->IsFloatingPointFoldingAllowed()) {
2113 return false;
2114 }
2115
2116 for (int i = 0; i < 2; ++i) {
2117 if (!constants[i]) {
2118 continue;
2119 }
2120
2121 const analysis::Vector* vector_type = constants[i]->type()->AsVector();
2122 assert(vector_type && "Inputs to OpDot must be vectors.");
2123 const analysis::Float* element_type =
2124 vector_type->element_type()->AsFloat();
2125 assert(element_type && "Inputs to OpDot must be vectors of floats.");
2126 uint32_t element_width = element_type->width();
2127 if (element_width != 32 && element_width != 64) {
2128 return false;
2129 }
2130
2131 std::vector<const analysis::Constant*> components;
2132 components = constants[i]->GetVectorComponents(const_mgr);
2133
2134 const uint32_t kNotFound = std::numeric_limits<uint32_t>::max();
2135
2136 uint32_t component_with_one = kNotFound;
2137 bool all_others_zero = true;
2138 for (uint32_t j = 0; j < components.size(); ++j) {
2139 const analysis::Constant* element = components[j];
2140 double value =
2141 (element_width == 32 ? element->GetFloat() : element->GetDouble());
2142 if (value == 0.0) {
2143 continue;
2144 } else if (value == 1.0) {
2145 if (component_with_one == kNotFound) {
2146 component_with_one = j;
2147 } else {
2148 component_with_one = kNotFound;
2149 break;
2150 }
2151 } else {
2152 all_others_zero = false;
2153 break;
2154 }
2155 }
2156
2157 if (!all_others_zero || component_with_one == kNotFound) {
2158 continue;
2159 }
2160
2161 std::vector<Operand> operands;
2162 operands.push_back(
2163 {SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(1u - i)}});
2164 operands.push_back(
2165 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_with_one}});
2166
2167 inst->SetOpcode(SpvOpCompositeExtract);
2168 inst->SetInOperands(std::move(operands));
2169 return true;
2170 }
2171 return false;
2172 };
2173}
2174
2175// If we are storing an undef, then we can remove the store.
2176//
2177// TODO: We can do something similar for OpImageWrite, but checking for volatile
2178// is complicated. Waiting to see if it is needed.
2179FoldingRule StoringUndef() {
2180 return [](IRContext* context, Instruction* inst,
2181 const std::vector<const analysis::Constant*>&) {
2182 assert(inst->opcode() == SpvOpStore && "Wrong opcode. Should be OpStore.");
2183
2184 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
2185
2186 // If this is a volatile store, the store cannot be removed.
2187 if (inst->NumInOperands() == 3) {
2188 if (inst->GetSingleWordInOperand(2) & SpvMemoryAccessVolatileMask) {
2189 return false;
2190 }
2191 }
2192
2193 uint32_t object_id = inst->GetSingleWordInOperand(kStoreObjectInIdx);
2194 Instruction* object_inst = def_use_mgr->GetDef(object_id);
2195 if (object_inst->opcode() == SpvOpUndef) {
2196 inst->ToNop();
2197 return true;
2198 }
2199 return false;
2200 };
2201}
2202
2203FoldingRule VectorShuffleFeedingShuffle() {
2204 return [](IRContext* context, Instruction* inst,
2205 const std::vector<const analysis::Constant*>&) {
2206 assert(inst->opcode() == SpvOpVectorShuffle &&
2207 "Wrong opcode. Should be OpVectorShuffle.");
2208
2209 analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
2210 analysis::TypeManager* type_mgr = context->get_type_mgr();
2211
2212 Instruction* feeding_shuffle_inst =
2213 def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
2214 analysis::Vector* op0_type =
2215 type_mgr->GetType(feeding_shuffle_inst->type_id())->AsVector();
2216 uint32_t op0_length = op0_type->element_count();
2217
2218 bool feeder_is_op0 = true;
2219 if (feeding_shuffle_inst->opcode() != SpvOpVectorShuffle) {
2220 feeding_shuffle_inst =
2221 def_use_mgr->GetDef(inst->GetSingleWordInOperand(1));
2222 feeder_is_op0 = false;
2223 }
2224
2225 if (feeding_shuffle_inst->opcode() != SpvOpVectorShuffle) {
2226 return false;
2227 }
2228
2229 Instruction* feeder2 =
2230 def_use_mgr->GetDef(feeding_shuffle_inst->GetSingleWordInOperand(0));
2231 analysis::Vector* feeder_op0_type =
2232 type_mgr->GetType(feeder2->type_id())->AsVector();
2233 uint32_t feeder_op0_length = feeder_op0_type->element_count();
2234
2235 uint32_t new_feeder_id = 0;
2236 std::vector<Operand> new_operands;
2237 new_operands.resize(
2238 2, {SPV_OPERAND_TYPE_ID, {0}}); // Place holders for vector operands.
2239 const uint32_t undef_literal = 0xffffffff;
2240 for (uint32_t op = 2; op < inst->NumInOperands(); ++op) {
2241 uint32_t component_index = inst->GetSingleWordInOperand(op);
2242
2243 // Do not interpret the undefined value literal as coming from operand 1.
2244 if (component_index != undef_literal &&
2245 feeder_is_op0 == (component_index < op0_length)) {
2246 // This component comes from the feeding_shuffle_inst. Update
2247 // |component_index| to be the index into the operand of the feeder.
2248
2249 // Adjust component_index to get the index into the operands of the
2250 // feeding_shuffle_inst.
2251 if (component_index >= op0_length) {
2252 component_index -= op0_length;
2253 }
2254 component_index =
2255 feeding_shuffle_inst->GetSingleWordInOperand(component_index + 2);
2256
2257 // Check if we are using a component from the first or second operand of
2258 // the feeding instruction.
2259 if (component_index < feeder_op0_length) {
2260 if (new_feeder_id == 0) {
2261 // First time through, save the id of the operand the element comes
2262 // from.
2263 new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(0);
2264 } else if (new_feeder_id !=
2265 feeding_shuffle_inst->GetSingleWordInOperand(0)) {
2266 // We need both elements of the feeding_shuffle_inst, so we cannot
2267 // fold.
2268 return false;
2269 }
2270 } else {
2271 if (new_feeder_id == 0) {
2272 // First time through, save the id of the operand the element comes
2273 // from.
2274 new_feeder_id = feeding_shuffle_inst->GetSingleWordInOperand(1);
2275 } else if (new_feeder_id !=
2276 feeding_shuffle_inst->GetSingleWordInOperand(1)) {
2277 // We need both elements of the feeding_shuffle_inst, so we cannot
2278 // fold.
2279 return false;
2280 }
2281 component_index -= feeder_op0_length;
2282 }
2283
2284 if (!feeder_is_op0) {
2285 component_index += op0_length;
2286 }
2287 }
2288 new_operands.push_back(
2289 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component_index}});
2290 }
2291
2292 if (new_feeder_id == 0) {
2293 analysis::ConstantManager* const_mgr = context->get_constant_mgr();
2294 const analysis::Type* type =
2295 type_mgr->GetType(feeding_shuffle_inst->type_id());
2296 const analysis::Constant* null_const = const_mgr->GetConstant(type, {});
2297 new_feeder_id =
2298 const_mgr->GetDefiningInstruction(null_const, 0)->result_id();
2299 }
2300
2301 if (feeder_is_op0) {
2302 // If the size of the first vector operand changed then the indices
2303 // referring to the second operand need to be adjusted.
2304 Instruction* new_feeder_inst = def_use_mgr->GetDef(new_feeder_id);
2305 analysis::Type* new_feeder_type =
2306 type_mgr->GetType(new_feeder_inst->type_id());
2307 uint32_t new_op0_size = new_feeder_type->AsVector()->element_count();
2308 int32_t adjustment = op0_length - new_op0_size;
2309
2310 if (adjustment != 0) {
2311 for (uint32_t i = 2; i < new_operands.size(); i++) {
2312 if (inst->GetSingleWordInOperand(i) >= op0_length) {
2313 new_operands[i].words[0] -= adjustment;
2314 }
2315 }
2316 }
2317
2318 new_operands[0].words[0] = new_feeder_id;
2319 new_operands[1] = inst->GetInOperand(1);
2320 } else {
2321 new_operands[1].words[0] = new_feeder_id;
2322 new_operands[0] = inst->GetInOperand(0);
2323 }
2324
2325 inst->SetInOperands(std::move(new_operands));
2326 return true;
2327 };
2328}
2329
2330// Removes duplicate ids from the interface list of an OpEntryPoint
2331// instruction.
2332FoldingRule RemoveRedundantOperands() {
2333 return [](IRContext*, Instruction* inst,
2334 const std::vector<const analysis::Constant*>&) {
2335 assert(inst->opcode() == SpvOpEntryPoint &&
2336 "Wrong opcode. Should be OpEntryPoint.");
2337 bool has_redundant_operand = false;
2338 std::unordered_set<uint32_t> seen_operands;
2339 std::vector<Operand> new_operands;
2340
2341 new_operands.emplace_back(inst->GetOperand(0));
2342 new_operands.emplace_back(inst->GetOperand(1));
2343 new_operands.emplace_back(inst->GetOperand(2));
2344 for (uint32_t i = 3; i < inst->NumOperands(); ++i) {
2345 if (seen_operands.insert(inst->GetSingleWordOperand(i)).second) {
2346 new_operands.emplace_back(inst->GetOperand(i));
2347 } else {
2348 has_redundant_operand = true;
2349 }
2350 }
2351
2352 if (!has_redundant_operand) {
2353 return false;
2354 }
2355
2356 inst->SetInOperands(std::move(new_operands));
2357 return true;
2358 };
2359}
2360
2361// If an image instruction's operand is a constant, updates the image operand
2362// flag from Offset to ConstOffset.
2363FoldingRule UpdateImageOperands() {
2364 return [](IRContext*, Instruction* inst,
2365 const std::vector<const analysis::Constant*>& constants) {
2366 const auto opcode = inst->opcode();
2367 (void)opcode;
2368 assert((opcode == SpvOpImageSampleImplicitLod ||
2369 opcode == SpvOpImageSampleExplicitLod ||
2370 opcode == SpvOpImageSampleDrefImplicitLod ||
2371 opcode == SpvOpImageSampleDrefExplicitLod ||
2372 opcode == SpvOpImageSampleProjImplicitLod ||
2373 opcode == SpvOpImageSampleProjExplicitLod ||
2374 opcode == SpvOpImageSampleProjDrefImplicitLod ||
2375 opcode == SpvOpImageSampleProjDrefExplicitLod ||
2376 opcode == SpvOpImageFetch || opcode == SpvOpImageGather ||
2377 opcode == SpvOpImageDrefGather || opcode == SpvOpImageRead ||
2378 opcode == SpvOpImageWrite ||
2379 opcode == SpvOpImageSparseSampleImplicitLod ||
2380 opcode == SpvOpImageSparseSampleExplicitLod ||
2381 opcode == SpvOpImageSparseSampleDrefImplicitLod ||
2382 opcode == SpvOpImageSparseSampleDrefExplicitLod ||
2383 opcode == SpvOpImageSparseSampleProjImplicitLod ||
2384 opcode == SpvOpImageSparseSampleProjExplicitLod ||
2385 opcode == SpvOpImageSparseSampleProjDrefImplicitLod ||
2386 opcode == SpvOpImageSparseSampleProjDrefExplicitLod ||
2387 opcode == SpvOpImageSparseFetch ||
2388 opcode == SpvOpImageSparseGather ||
2389 opcode == SpvOpImageSparseDrefGather ||
2390 opcode == SpvOpImageSparseRead) &&
2391 "Wrong opcode. Should be an image instruction.");
2392
2393 int32_t operand_index = ImageOperandsMaskInOperandIndex(inst);
2394 if (operand_index >= 0) {
2395 auto image_operands = inst->GetSingleWordInOperand(operand_index);
2396 if (image_operands & SpvImageOperandsOffsetMask) {
2397 uint32_t offset_operand_index = operand_index + 1;
2398 if (image_operands & SpvImageOperandsBiasMask) offset_operand_index++;
2399 if (image_operands & SpvImageOperandsLodMask) offset_operand_index++;
2400 if (image_operands & SpvImageOperandsGradMask)
2401 offset_operand_index += 2;
2402 assert(((image_operands & SpvImageOperandsConstOffsetMask) == 0) &&
2403 "Offset and ConstOffset may not be used together");
2404 if (offset_operand_index < inst->NumOperands()) {
2405 if (constants[offset_operand_index]) {
2406 image_operands = image_operands | SpvImageOperandsConstOffsetMask;
2407 image_operands = image_operands & ~SpvImageOperandsOffsetMask;
2408 inst->SetInOperand(operand_index, {image_operands});
2409 return true;
2410 }
2411 }
2412 }
2413 }
2414
2415 return false;
2416 };
2417}
2418
2419} // namespace
2420
// Populates |rules_| (keyed by opcode) with the folding rules defined in this
// file. When GetExtInstImportId_GLSLstd450() yields a nonzero id, it also
// populates |ext_rules_| (keyed by {import id, GLSL.std.450 instruction}).
void FoldingRules::AddFoldingRules() {
  // Add all folding rules to the list for the opcodes to which they apply.
  // Note that the order in which rules are added to the list matters. If a rule
  // applies to the instruction, the rest of the rules will not be attempted.
  // Take that into consideration.
  //
  // NOTE(review): CompositeExtractFeedingConstruct is pushed without "()",
  // unlike every other rule here. Presumably it is a rule function itself
  // rather than a factory returning one; its definition is outside this view.
  rules_[SpvOpCompositeConstruct].push_back(CompositeExtractFeedingConstruct);

  rules_[SpvOpCompositeExtract].push_back(InsertFeedingExtract());
  rules_[SpvOpCompositeExtract].push_back(CompositeConstructFeedingExtract());
  rules_[SpvOpCompositeExtract].push_back(VectorShuffleFeedingExtract());
  rules_[SpvOpCompositeExtract].push_back(FMixFeedingExtract());

  rules_[SpvOpDot].push_back(DotProductDoingExtract());

  rules_[SpvOpEntryPoint].push_back(RemoveRedundantOperands());

  // The Redundant* rule for each float opcode is registered first, so the
  // cheap identity simplifications are tried before the merge rules.
  rules_[SpvOpFAdd].push_back(RedundantFAdd());
  rules_[SpvOpFAdd].push_back(MergeAddNegateArithmetic());
  rules_[SpvOpFAdd].push_back(MergeAddAddArithmetic());
  rules_[SpvOpFAdd].push_back(MergeAddSubArithmetic());
  rules_[SpvOpFAdd].push_back(MergeGenericAddSubArithmetic());
  rules_[SpvOpFAdd].push_back(FactorAddMuls());

  rules_[SpvOpFDiv].push_back(RedundantFDiv());
  rules_[SpvOpFDiv].push_back(ReciprocalFDiv());
  rules_[SpvOpFDiv].push_back(MergeDivDivArithmetic());
  rules_[SpvOpFDiv].push_back(MergeDivMulArithmetic());
  rules_[SpvOpFDiv].push_back(MergeDivNegateArithmetic());

  rules_[SpvOpFMul].push_back(RedundantFMul());
  rules_[SpvOpFMul].push_back(MergeMulMulArithmetic());
  rules_[SpvOpFMul].push_back(MergeMulDivArithmetic());
  rules_[SpvOpFMul].push_back(MergeMulNegateArithmetic());

  rules_[SpvOpFNegate].push_back(MergeNegateArithmetic());
  rules_[SpvOpFNegate].push_back(MergeNegateAddSubArithmetic());
  rules_[SpvOpFNegate].push_back(MergeNegateMulDivArithmetic());

  rules_[SpvOpFSub].push_back(RedundantFSub());
  rules_[SpvOpFSub].push_back(MergeSubNegateArithmetic());
  rules_[SpvOpFSub].push_back(MergeSubAddArithmetic());
  rules_[SpvOpFSub].push_back(MergeSubSubArithmetic());

  rules_[SpvOpIAdd].push_back(RedundantIAdd());
  rules_[SpvOpIAdd].push_back(MergeAddNegateArithmetic());
  rules_[SpvOpIAdd].push_back(MergeAddAddArithmetic());
  rules_[SpvOpIAdd].push_back(MergeAddSubArithmetic());
  rules_[SpvOpIAdd].push_back(MergeGenericAddSubArithmetic());
  rules_[SpvOpIAdd].push_back(FactorAddMuls());

  rules_[SpvOpIMul].push_back(IntMultipleBy1());
  rules_[SpvOpIMul].push_back(MergeMulMulArithmetic());
  rules_[SpvOpIMul].push_back(MergeMulNegateArithmetic());

  rules_[SpvOpISub].push_back(MergeSubNegateArithmetic());
  rules_[SpvOpISub].push_back(MergeSubAddArithmetic());
  rules_[SpvOpISub].push_back(MergeSubSubArithmetic());

  rules_[SpvOpPhi].push_back(RedundantPhi());

  rules_[SpvOpSDiv].push_back(MergeDivNegateArithmetic());

  // NOTE(review): unlike OpFNegate above, MergeNegateMulDivArithmetic is tried
  // before MergeNegateAddSubArithmetic here -- confirm the difference is
  // intentional.
  rules_[SpvOpSNegate].push_back(MergeNegateArithmetic());
  rules_[SpvOpSNegate].push_back(MergeNegateMulDivArithmetic());
  rules_[SpvOpSNegate].push_back(MergeNegateAddSubArithmetic());

  rules_[SpvOpSelect].push_back(RedundantSelect());

  rules_[SpvOpStore].push_back(StoringUndef());

  rules_[SpvOpUDiv].push_back(MergeDivNegateArithmetic());

  rules_[SpvOpVectorShuffle].push_back(VectorShuffleFeedingShuffle());

  // Every image instruction accepted by UpdateImageOperands' assert gets that
  // rule, which turns a constant Offset image operand into ConstOffset.
  rules_[SpvOpImageSampleImplicitLod].push_back(UpdateImageOperands());
  rules_[SpvOpImageSampleExplicitLod].push_back(UpdateImageOperands());
  rules_[SpvOpImageSampleDrefImplicitLod].push_back(UpdateImageOperands());
  rules_[SpvOpImageSampleDrefExplicitLod].push_back(UpdateImageOperands());
  rules_[SpvOpImageSampleProjImplicitLod].push_back(UpdateImageOperands());
  rules_[SpvOpImageSampleProjExplicitLod].push_back(UpdateImageOperands());
  rules_[SpvOpImageSampleProjDrefImplicitLod].push_back(UpdateImageOperands());
  rules_[SpvOpImageSampleProjDrefExplicitLod].push_back(UpdateImageOperands());
  rules_[SpvOpImageFetch].push_back(UpdateImageOperands());
  rules_[SpvOpImageGather].push_back(UpdateImageOperands());
  rules_[SpvOpImageDrefGather].push_back(UpdateImageOperands());
  rules_[SpvOpImageRead].push_back(UpdateImageOperands());
  rules_[SpvOpImageWrite].push_back(UpdateImageOperands());
  rules_[SpvOpImageSparseSampleImplicitLod].push_back(UpdateImageOperands());
  rules_[SpvOpImageSparseSampleExplicitLod].push_back(UpdateImageOperands());
  rules_[SpvOpImageSparseSampleDrefImplicitLod].push_back(
      UpdateImageOperands());
  rules_[SpvOpImageSparseSampleDrefExplicitLod].push_back(
      UpdateImageOperands());
  rules_[SpvOpImageSparseSampleProjImplicitLod].push_back(
      UpdateImageOperands());
  rules_[SpvOpImageSparseSampleProjExplicitLod].push_back(
      UpdateImageOperands());
  rules_[SpvOpImageSparseSampleProjDrefImplicitLod].push_back(
      UpdateImageOperands());
  rules_[SpvOpImageSparseSampleProjDrefExplicitLod].push_back(
      UpdateImageOperands());
  rules_[SpvOpImageSparseFetch].push_back(UpdateImageOperands());
  rules_[SpvOpImageSparseGather].push_back(UpdateImageOperands());
  rules_[SpvOpImageSparseDrefGather].push_back(UpdateImageOperands());
  rules_[SpvOpImageSparseRead].push_back(UpdateImageOperands());

  FeatureManager* feature_manager = context_->get_feature_mgr();
  // Add rules for GLSLstd450. They are keyed by the module's import id, so
  // they are only registered when that id is nonzero.
  uint32_t ext_inst_glslstd450_id =
      feature_manager->GetExtInstImportId_GLSLstd450();
  if (ext_inst_glslstd450_id != 0) {
    ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMix}].push_back(
        RedundantFMix());
  }
}
2536} // namespace opt
2537} // namespace spvtools
2538