1 | // Copyright (c) 2017 Google Inc. |
2 | // |
3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | // you may not use this file except in compliance with the License. |
5 | // You may obtain a copy of the License at |
6 | // |
7 | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | // |
9 | // Unless required by applicable law or agreed to in writing, software |
10 | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | // See the License for the specific language governing permissions and |
13 | // limitations under the License. |
14 | |
#include "source/opt/fold.h"

#include <cassert>
#include <cstdint>
#include <limits>
#include <vector>

#include "source/opt/const_folding_rules.h"
#include "source/opt/def_use_manager.h"
#include "source/opt/folding_rules.h"
#include "source/opt/ir_builder.h"
#include "source/opt/ir_context.h"
26 | |
27 | namespace spvtools { |
28 | namespace opt { |
namespace {

// Fallback definitions of the 32-bit integer limit macros. Each one is only
// used when the macro has not already been defined (e.g. by <cstdint>).

#ifndef INT32_MIN
// Spelled as (-2147483647 - 1) rather than (-2147483648): the literal
// 2147483648 does not fit in a 32-bit int, so negating it would yield a wider
// (or, on some older compilers, unsigned) type instead of int.
#define INT32_MIN (-2147483647 - 1)
#endif

#ifndef INT32_MAX
#define INT32_MAX 2147483647
#endif

#ifndef UINT32_MAX
#define UINT32_MAX 0xffffffff /* 4294967295U */
#endif

}  // namespace
44 | |
45 | uint32_t InstructionFolder::UnaryOperate(SpvOp opcode, uint32_t operand) const { |
46 | switch (opcode) { |
47 | // Arthimetics |
48 | case SpvOp::SpvOpSNegate: { |
49 | int32_t s_operand = static_cast<int32_t>(operand); |
50 | if (s_operand == std::numeric_limits<int32_t>::min()) { |
51 | return s_operand; |
52 | } |
53 | return -s_operand; |
54 | } |
55 | case SpvOp::SpvOpNot: |
56 | return ~operand; |
57 | case SpvOp::SpvOpLogicalNot: |
58 | return !static_cast<bool>(operand); |
59 | case SpvOp::SpvOpUConvert: |
60 | return operand; |
61 | case SpvOp::SpvOpSConvert: |
62 | return operand; |
63 | default: |
64 | assert(false && |
65 | "Unsupported unary operation for OpSpecConstantOp instruction" ); |
66 | return 0u; |
67 | } |
68 | } |
69 | |
70 | uint32_t InstructionFolder::BinaryOperate(SpvOp opcode, uint32_t a, |
71 | uint32_t b) const { |
72 | switch (opcode) { |
73 | // Arthimetics |
74 | case SpvOp::SpvOpIAdd: |
75 | return a + b; |
76 | case SpvOp::SpvOpISub: |
77 | return a - b; |
78 | case SpvOp::SpvOpIMul: |
79 | return a * b; |
80 | case SpvOp::SpvOpUDiv: |
81 | if (b != 0) { |
82 | return a / b; |
83 | } else { |
84 | // Dividing by 0 is undefined, so we will just pick 0. |
85 | return 0; |
86 | } |
87 | case SpvOp::SpvOpSDiv: |
88 | if (b != 0u) { |
89 | return (static_cast<int32_t>(a)) / (static_cast<int32_t>(b)); |
90 | } else { |
91 | // Dividing by 0 is undefined, so we will just pick 0. |
92 | return 0; |
93 | } |
94 | case SpvOp::SpvOpSRem: { |
95 | // The sign of non-zero result comes from the first operand: a. This is |
96 | // guaranteed by C++11 rules for integer division operator. The division |
97 | // result is rounded toward zero, so the result of '%' has the sign of |
98 | // the first operand. |
99 | if (b != 0u) { |
100 | return static_cast<int32_t>(a) % static_cast<int32_t>(b); |
101 | } else { |
102 | // Remainder when dividing with 0 is undefined, so we will just pick 0. |
103 | return 0; |
104 | } |
105 | } |
106 | case SpvOp::SpvOpSMod: { |
107 | // The sign of non-zero result comes from the second operand: b |
108 | if (b != 0u) { |
109 | int32_t rem = BinaryOperate(SpvOp::SpvOpSRem, a, b); |
110 | int32_t b_prim = static_cast<int32_t>(b); |
111 | return (rem + b_prim) % b_prim; |
112 | } else { |
113 | // Mod with 0 is undefined, so we will just pick 0. |
114 | return 0; |
115 | } |
116 | } |
117 | case SpvOp::SpvOpUMod: |
118 | if (b != 0u) { |
119 | return (a % b); |
120 | } else { |
121 | // Mod with 0 is undefined, so we will just pick 0. |
122 | return 0; |
123 | } |
124 | |
125 | // Shifting |
126 | case SpvOp::SpvOpShiftRightLogical: |
127 | if (b >= 32) { |
128 | // This is undefined behaviour when |b| > 32. Choose 0 for consistency. |
129 | // When |b| == 32, doing the shift in C++ in undefined, but the result |
130 | // will be 0, so just return that value. |
131 | return 0; |
132 | } |
133 | return a >> b; |
134 | case SpvOp::SpvOpShiftRightArithmetic: |
135 | if (b > 32) { |
136 | // This is undefined behaviour. Choose 0 for consistency. |
137 | return 0; |
138 | } |
139 | if (b == 32) { |
140 | // Doing the shift in C++ is undefined, but the result is defined in the |
141 | // spir-v spec. Find that value another way. |
142 | if (static_cast<int32_t>(a) >= 0) { |
143 | return 0; |
144 | } else { |
145 | return static_cast<uint32_t>(-1); |
146 | } |
147 | } |
148 | return (static_cast<int32_t>(a)) >> b; |
149 | case SpvOp::SpvOpShiftLeftLogical: |
150 | if (b >= 32) { |
151 | // This is undefined behaviour when |b| > 32. Choose 0 for consistency. |
152 | // When |b| == 32, doing the shift in C++ in undefined, but the result |
153 | // will be 0, so just return that value. |
154 | return 0; |
155 | } |
156 | return a << b; |
157 | |
158 | // Bitwise operations |
159 | case SpvOp::SpvOpBitwiseOr: |
160 | return a | b; |
161 | case SpvOp::SpvOpBitwiseAnd: |
162 | return a & b; |
163 | case SpvOp::SpvOpBitwiseXor: |
164 | return a ^ b; |
165 | |
166 | // Logical |
167 | case SpvOp::SpvOpLogicalEqual: |
168 | return (static_cast<bool>(a)) == (static_cast<bool>(b)); |
169 | case SpvOp::SpvOpLogicalNotEqual: |
170 | return (static_cast<bool>(a)) != (static_cast<bool>(b)); |
171 | case SpvOp::SpvOpLogicalOr: |
172 | return (static_cast<bool>(a)) || (static_cast<bool>(b)); |
173 | case SpvOp::SpvOpLogicalAnd: |
174 | return (static_cast<bool>(a)) && (static_cast<bool>(b)); |
175 | |
176 | // Comparison |
177 | case SpvOp::SpvOpIEqual: |
178 | return a == b; |
179 | case SpvOp::SpvOpINotEqual: |
180 | return a != b; |
181 | case SpvOp::SpvOpULessThan: |
182 | return a < b; |
183 | case SpvOp::SpvOpSLessThan: |
184 | return (static_cast<int32_t>(a)) < (static_cast<int32_t>(b)); |
185 | case SpvOp::SpvOpUGreaterThan: |
186 | return a > b; |
187 | case SpvOp::SpvOpSGreaterThan: |
188 | return (static_cast<int32_t>(a)) > (static_cast<int32_t>(b)); |
189 | case SpvOp::SpvOpULessThanEqual: |
190 | return a <= b; |
191 | case SpvOp::SpvOpSLessThanEqual: |
192 | return (static_cast<int32_t>(a)) <= (static_cast<int32_t>(b)); |
193 | case SpvOp::SpvOpUGreaterThanEqual: |
194 | return a >= b; |
195 | case SpvOp::SpvOpSGreaterThanEqual: |
196 | return (static_cast<int32_t>(a)) >= (static_cast<int32_t>(b)); |
197 | default: |
198 | assert(false && |
199 | "Unsupported binary operation for OpSpecConstantOp instruction" ); |
200 | return 0u; |
201 | } |
202 | } |
203 | |
204 | uint32_t InstructionFolder::TernaryOperate(SpvOp opcode, uint32_t a, uint32_t b, |
205 | uint32_t c) const { |
206 | switch (opcode) { |
207 | case SpvOp::SpvOpSelect: |
208 | return (static_cast<bool>(a)) ? b : c; |
209 | default: |
210 | assert(false && |
211 | "Unsupported ternary operation for OpSpecConstantOp instruction" ); |
212 | return 0u; |
213 | } |
214 | } |
215 | |
216 | uint32_t InstructionFolder::OperateWords( |
217 | SpvOp opcode, const std::vector<uint32_t>& operand_words) const { |
218 | switch (operand_words.size()) { |
219 | case 1: |
220 | return UnaryOperate(opcode, operand_words.front()); |
221 | case 2: |
222 | return BinaryOperate(opcode, operand_words.front(), operand_words.back()); |
223 | case 3: |
224 | return TernaryOperate(opcode, operand_words[0], operand_words[1], |
225 | operand_words[2]); |
226 | default: |
227 | assert(false && "Invalid number of operands" ); |
228 | return 0; |
229 | } |
230 | } |
231 | |
232 | bool InstructionFolder::FoldInstructionInternal(Instruction* inst) const { |
233 | auto identity_map = [](uint32_t id) { return id; }; |
234 | Instruction* folded_inst = FoldInstructionToConstant(inst, identity_map); |
235 | if (folded_inst != nullptr) { |
236 | inst->SetOpcode(SpvOpCopyObject); |
237 | inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {folded_inst->result_id()}}}); |
238 | return true; |
239 | } |
240 | |
241 | analysis::ConstantManager* const_manager = context_->get_constant_mgr(); |
242 | std::vector<const analysis::Constant*> constants = |
243 | const_manager->GetOperandConstants(inst); |
244 | |
245 | for (const FoldingRule& rule : |
246 | GetFoldingRules().GetRulesForInstruction(inst)) { |
247 | if (rule(context_, inst, constants)) { |
248 | return true; |
249 | } |
250 | } |
251 | return false; |
252 | } |
253 | |
254 | // Returns the result of performing an operation on scalar constant operands. |
255 | // This function extracts the operand values as 32 bit words and returns the |
256 | // result in 32 bit word. Scalar constants with longer than 32-bit width are |
257 | // not accepted in this function. |
258 | uint32_t InstructionFolder::FoldScalars( |
259 | SpvOp opcode, |
260 | const std::vector<const analysis::Constant*>& operands) const { |
261 | assert(IsFoldableOpcode(opcode) && |
262 | "Unhandled instruction opcode in FoldScalars" ); |
263 | std::vector<uint32_t> operand_values_in_raw_words; |
264 | for (const auto& operand : operands) { |
265 | if (const analysis::ScalarConstant* scalar = operand->AsScalarConstant()) { |
266 | const auto& scalar_words = scalar->words(); |
267 | assert(scalar_words.size() == 1 && |
268 | "Scalar constants with longer than 32-bit width are not allowed " |
269 | "in FoldScalars()" ); |
270 | operand_values_in_raw_words.push_back(scalar_words.front()); |
271 | } else if (operand->AsNullConstant()) { |
272 | operand_values_in_raw_words.push_back(0u); |
273 | } else { |
274 | assert(false && |
275 | "FoldScalars() only accepts ScalarConst or NullConst type of " |
276 | "constant" ); |
277 | } |
278 | } |
279 | return OperateWords(opcode, operand_values_in_raw_words); |
280 | } |
281 | |
282 | bool InstructionFolder::FoldBinaryIntegerOpToConstant( |
283 | Instruction* inst, const std::function<uint32_t(uint32_t)>& id_map, |
284 | uint32_t* result) const { |
285 | SpvOp opcode = inst->opcode(); |
286 | analysis::ConstantManager* const_manger = context_->get_constant_mgr(); |
287 | |
288 | uint32_t ids[2]; |
289 | const analysis::IntConstant* constants[2]; |
290 | for (uint32_t i = 0; i < 2; i++) { |
291 | const Operand* operand = &inst->GetInOperand(i); |
292 | if (operand->type != SPV_OPERAND_TYPE_ID) { |
293 | return false; |
294 | } |
295 | ids[i] = id_map(operand->words[0]); |
296 | const analysis::Constant* constant = |
297 | const_manger->FindDeclaredConstant(ids[i]); |
298 | constants[i] = (constant != nullptr ? constant->AsIntConstant() : nullptr); |
299 | } |
300 | |
301 | switch (opcode) { |
302 | // Arthimetics |
303 | case SpvOp::SpvOpIMul: |
304 | for (uint32_t i = 0; i < 2; i++) { |
305 | if (constants[i] != nullptr && constants[i]->IsZero()) { |
306 | *result = 0; |
307 | return true; |
308 | } |
309 | } |
310 | break; |
311 | case SpvOp::SpvOpUDiv: |
312 | case SpvOp::SpvOpSDiv: |
313 | case SpvOp::SpvOpSRem: |
314 | case SpvOp::SpvOpSMod: |
315 | case SpvOp::SpvOpUMod: |
316 | // This changes undefined behaviour (ie divide by 0) into a 0. |
317 | for (uint32_t i = 0; i < 2; i++) { |
318 | if (constants[i] != nullptr && constants[i]->IsZero()) { |
319 | *result = 0; |
320 | return true; |
321 | } |
322 | } |
323 | break; |
324 | |
325 | // Shifting |
326 | case SpvOp::SpvOpShiftRightLogical: |
327 | case SpvOp::SpvOpShiftLeftLogical: |
328 | if (constants[1] != nullptr) { |
329 | // When shifting by a value larger than the size of the result, the |
330 | // result is undefined. We are setting the undefined behaviour to a |
331 | // result of 0. If the shift amount is the same as the size of the |
332 | // result, then the result is defined, and it 0. |
333 | uint32_t shift_amount = constants[1]->GetU32BitValue(); |
334 | if (shift_amount >= 32) { |
335 | *result = 0; |
336 | return true; |
337 | } |
338 | } |
339 | break; |
340 | |
341 | // Bitwise operations |
342 | case SpvOp::SpvOpBitwiseOr: |
343 | for (uint32_t i = 0; i < 2; i++) { |
344 | if (constants[i] != nullptr) { |
345 | // TODO: Change the mask against a value based on the bit width of the |
346 | // instruction result type. This way we can handle say 16-bit values |
347 | // as well. |
348 | uint32_t mask = constants[i]->GetU32BitValue(); |
349 | if (mask == 0xFFFFFFFF) { |
350 | *result = 0xFFFFFFFF; |
351 | return true; |
352 | } |
353 | } |
354 | } |
355 | break; |
356 | case SpvOp::SpvOpBitwiseAnd: |
357 | for (uint32_t i = 0; i < 2; i++) { |
358 | if (constants[i] != nullptr) { |
359 | if (constants[i]->IsZero()) { |
360 | *result = 0; |
361 | return true; |
362 | } |
363 | } |
364 | } |
365 | break; |
366 | |
367 | // Comparison |
368 | case SpvOp::SpvOpULessThan: |
369 | if (constants[0] != nullptr && |
370 | constants[0]->GetU32BitValue() == UINT32_MAX) { |
371 | *result = false; |
372 | return true; |
373 | } |
374 | if (constants[1] != nullptr && constants[1]->GetU32BitValue() == 0) { |
375 | *result = false; |
376 | return true; |
377 | } |
378 | break; |
379 | case SpvOp::SpvOpSLessThan: |
380 | if (constants[0] != nullptr && |
381 | constants[0]->GetS32BitValue() == INT32_MAX) { |
382 | *result = false; |
383 | return true; |
384 | } |
385 | if (constants[1] != nullptr && |
386 | constants[1]->GetS32BitValue() == INT32_MIN) { |
387 | *result = false; |
388 | return true; |
389 | } |
390 | break; |
391 | case SpvOp::SpvOpUGreaterThan: |
392 | if (constants[0] != nullptr && constants[0]->IsZero()) { |
393 | *result = false; |
394 | return true; |
395 | } |
396 | if (constants[1] != nullptr && |
397 | constants[1]->GetU32BitValue() == UINT32_MAX) { |
398 | *result = false; |
399 | return true; |
400 | } |
401 | break; |
402 | case SpvOp::SpvOpSGreaterThan: |
403 | if (constants[0] != nullptr && |
404 | constants[0]->GetS32BitValue() == INT32_MIN) { |
405 | *result = false; |
406 | return true; |
407 | } |
408 | if (constants[1] != nullptr && |
409 | constants[1]->GetS32BitValue() == INT32_MAX) { |
410 | *result = false; |
411 | return true; |
412 | } |
413 | break; |
414 | case SpvOp::SpvOpULessThanEqual: |
415 | if (constants[0] != nullptr && constants[0]->IsZero()) { |
416 | *result = true; |
417 | return true; |
418 | } |
419 | if (constants[1] != nullptr && |
420 | constants[1]->GetU32BitValue() == UINT32_MAX) { |
421 | *result = true; |
422 | return true; |
423 | } |
424 | break; |
425 | case SpvOp::SpvOpSLessThanEqual: |
426 | if (constants[0] != nullptr && |
427 | constants[0]->GetS32BitValue() == INT32_MIN) { |
428 | *result = true; |
429 | return true; |
430 | } |
431 | if (constants[1] != nullptr && |
432 | constants[1]->GetS32BitValue() == INT32_MAX) { |
433 | *result = true; |
434 | return true; |
435 | } |
436 | break; |
437 | case SpvOp::SpvOpUGreaterThanEqual: |
438 | if (constants[0] != nullptr && |
439 | constants[0]->GetU32BitValue() == UINT32_MAX) { |
440 | *result = true; |
441 | return true; |
442 | } |
443 | if (constants[1] != nullptr && constants[1]->GetU32BitValue() == 0) { |
444 | *result = true; |
445 | return true; |
446 | } |
447 | break; |
448 | case SpvOp::SpvOpSGreaterThanEqual: |
449 | if (constants[0] != nullptr && |
450 | constants[0]->GetS32BitValue() == INT32_MAX) { |
451 | *result = true; |
452 | return true; |
453 | } |
454 | if (constants[1] != nullptr && |
455 | constants[1]->GetS32BitValue() == INT32_MIN) { |
456 | *result = true; |
457 | return true; |
458 | } |
459 | break; |
460 | default: |
461 | break; |
462 | } |
463 | return false; |
464 | } |
465 | |
466 | bool InstructionFolder::FoldBinaryBooleanOpToConstant( |
467 | Instruction* inst, const std::function<uint32_t(uint32_t)>& id_map, |
468 | uint32_t* result) const { |
469 | SpvOp opcode = inst->opcode(); |
470 | analysis::ConstantManager* const_manger = context_->get_constant_mgr(); |
471 | |
472 | uint32_t ids[2]; |
473 | const analysis::BoolConstant* constants[2]; |
474 | for (uint32_t i = 0; i < 2; i++) { |
475 | const Operand* operand = &inst->GetInOperand(i); |
476 | if (operand->type != SPV_OPERAND_TYPE_ID) { |
477 | return false; |
478 | } |
479 | ids[i] = id_map(operand->words[0]); |
480 | const analysis::Constant* constant = |
481 | const_manger->FindDeclaredConstant(ids[i]); |
482 | constants[i] = (constant != nullptr ? constant->AsBoolConstant() : nullptr); |
483 | } |
484 | |
485 | switch (opcode) { |
486 | // Logical |
487 | case SpvOp::SpvOpLogicalOr: |
488 | for (uint32_t i = 0; i < 2; i++) { |
489 | if (constants[i] != nullptr) { |
490 | if (constants[i]->value()) { |
491 | *result = true; |
492 | return true; |
493 | } |
494 | } |
495 | } |
496 | break; |
497 | case SpvOp::SpvOpLogicalAnd: |
498 | for (uint32_t i = 0; i < 2; i++) { |
499 | if (constants[i] != nullptr) { |
500 | if (!constants[i]->value()) { |
501 | *result = false; |
502 | return true; |
503 | } |
504 | } |
505 | } |
506 | break; |
507 | |
508 | default: |
509 | break; |
510 | } |
511 | return false; |
512 | } |
513 | |
514 | bool InstructionFolder::FoldIntegerOpToConstant( |
515 | Instruction* inst, const std::function<uint32_t(uint32_t)>& id_map, |
516 | uint32_t* result) const { |
517 | assert(IsFoldableOpcode(inst->opcode()) && |
518 | "Unhandled instruction opcode in FoldScalars" ); |
519 | switch (inst->NumInOperands()) { |
520 | case 2: |
521 | return FoldBinaryIntegerOpToConstant(inst, id_map, result) || |
522 | FoldBinaryBooleanOpToConstant(inst, id_map, result); |
523 | default: |
524 | return false; |
525 | } |
526 | } |
527 | |
528 | std::vector<uint32_t> InstructionFolder::FoldVectors( |
529 | SpvOp opcode, uint32_t num_dims, |
530 | const std::vector<const analysis::Constant*>& operands) const { |
531 | assert(IsFoldableOpcode(opcode) && |
532 | "Unhandled instruction opcode in FoldVectors" ); |
533 | std::vector<uint32_t> result; |
534 | for (uint32_t d = 0; d < num_dims; d++) { |
535 | std::vector<uint32_t> operand_values_for_one_dimension; |
536 | for (const auto& operand : operands) { |
537 | if (const analysis::VectorConstant* vector_operand = |
538 | operand->AsVectorConstant()) { |
539 | // Extract the raw value of the scalar component constants |
540 | // in 32-bit words here. The reason of not using FoldScalars() here |
541 | // is that we do not create temporary null constants as components |
542 | // when the vector operand is a NullConstant because Constant creation |
543 | // may need extra checks for the validity and that is not manageed in |
544 | // here. |
545 | if (const analysis::ScalarConstant* scalar_component = |
546 | vector_operand->GetComponents().at(d)->AsScalarConstant()) { |
547 | const auto& scalar_words = scalar_component->words(); |
548 | assert( |
549 | scalar_words.size() == 1 && |
550 | "Vector components with longer than 32-bit width are not allowed " |
551 | "in FoldVectors()" ); |
552 | operand_values_for_one_dimension.push_back(scalar_words.front()); |
553 | } else if (operand->AsNullConstant()) { |
554 | operand_values_for_one_dimension.push_back(0u); |
555 | } else { |
556 | assert(false && |
557 | "VectorConst should only has ScalarConst or NullConst as " |
558 | "components" ); |
559 | } |
560 | } else if (operand->AsNullConstant()) { |
561 | operand_values_for_one_dimension.push_back(0u); |
562 | } else { |
563 | assert(false && |
564 | "FoldVectors() only accepts VectorConst or NullConst type of " |
565 | "constant" ); |
566 | } |
567 | } |
568 | result.push_back(OperateWords(opcode, operand_values_for_one_dimension)); |
569 | } |
570 | return result; |
571 | } |
572 | |
573 | bool InstructionFolder::IsFoldableOpcode(SpvOp opcode) const { |
574 | // NOTE: Extend to more opcodes as new cases are handled in the folder |
575 | // functions. |
576 | switch (opcode) { |
577 | case SpvOp::SpvOpBitwiseAnd: |
578 | case SpvOp::SpvOpBitwiseOr: |
579 | case SpvOp::SpvOpBitwiseXor: |
580 | case SpvOp::SpvOpIAdd: |
581 | case SpvOp::SpvOpIEqual: |
582 | case SpvOp::SpvOpIMul: |
583 | case SpvOp::SpvOpINotEqual: |
584 | case SpvOp::SpvOpISub: |
585 | case SpvOp::SpvOpLogicalAnd: |
586 | case SpvOp::SpvOpLogicalEqual: |
587 | case SpvOp::SpvOpLogicalNot: |
588 | case SpvOp::SpvOpLogicalNotEqual: |
589 | case SpvOp::SpvOpLogicalOr: |
590 | case SpvOp::SpvOpNot: |
591 | case SpvOp::SpvOpSDiv: |
592 | case SpvOp::SpvOpSelect: |
593 | case SpvOp::SpvOpSGreaterThan: |
594 | case SpvOp::SpvOpSGreaterThanEqual: |
595 | case SpvOp::SpvOpShiftLeftLogical: |
596 | case SpvOp::SpvOpShiftRightArithmetic: |
597 | case SpvOp::SpvOpShiftRightLogical: |
598 | case SpvOp::SpvOpSLessThan: |
599 | case SpvOp::SpvOpSLessThanEqual: |
600 | case SpvOp::SpvOpSMod: |
601 | case SpvOp::SpvOpSNegate: |
602 | case SpvOp::SpvOpSRem: |
603 | case SpvOp::SpvOpSConvert: |
604 | case SpvOp::SpvOpUConvert: |
605 | case SpvOp::SpvOpUDiv: |
606 | case SpvOp::SpvOpUGreaterThan: |
607 | case SpvOp::SpvOpUGreaterThanEqual: |
608 | case SpvOp::SpvOpULessThan: |
609 | case SpvOp::SpvOpULessThanEqual: |
610 | case SpvOp::SpvOpUMod: |
611 | return true; |
612 | default: |
613 | return false; |
614 | } |
615 | } |
616 | |
617 | bool InstructionFolder::IsFoldableConstant( |
618 | const analysis::Constant* cst) const { |
619 | // Currently supported constants are 32-bit values or null constants. |
620 | if (const analysis::ScalarConstant* scalar = cst->AsScalarConstant()) |
621 | return scalar->words().size() == 1; |
622 | else |
623 | return cst->AsNullConstant() != nullptr; |
624 | } |
625 | |
626 | Instruction* InstructionFolder::FoldInstructionToConstant( |
627 | Instruction* inst, std::function<uint32_t(uint32_t)> id_map) const { |
628 | analysis::ConstantManager* const_mgr = context_->get_constant_mgr(); |
629 | |
630 | if (!inst->IsFoldableByFoldScalar() && |
631 | !GetConstantFoldingRules().HasFoldingRule(inst)) { |
632 | return nullptr; |
633 | } |
634 | // Collect the values of the constant parameters. |
635 | std::vector<const analysis::Constant*> constants; |
636 | bool missing_constants = false; |
637 | inst->ForEachInId([&constants, &missing_constants, const_mgr, |
638 | &id_map](uint32_t* op_id) { |
639 | uint32_t id = id_map(*op_id); |
640 | const analysis::Constant* const_op = const_mgr->FindDeclaredConstant(id); |
641 | if (!const_op) { |
642 | constants.push_back(nullptr); |
643 | missing_constants = true; |
644 | } else { |
645 | constants.push_back(const_op); |
646 | } |
647 | }); |
648 | |
649 | const analysis::Constant* folded_const = nullptr; |
650 | for (auto rule : GetConstantFoldingRules().GetRulesForInstruction(inst)) { |
651 | folded_const = rule(context_, inst, constants); |
652 | if (folded_const != nullptr) { |
653 | Instruction* const_inst = |
654 | const_mgr->GetDefiningInstruction(folded_const, inst->type_id()); |
655 | if (const_inst == nullptr) { |
656 | return nullptr; |
657 | } |
658 | assert(const_inst->type_id() == inst->type_id()); |
659 | // May be a new instruction that needs to be analysed. |
660 | context_->UpdateDefUse(const_inst); |
661 | return const_inst; |
662 | } |
663 | } |
664 | |
665 | uint32_t result_val = 0; |
666 | bool successful = false; |
667 | // If all parameters are constant, fold the instruction to a constant. |
668 | if (!missing_constants && inst->IsFoldableByFoldScalar()) { |
669 | result_val = FoldScalars(inst->opcode(), constants); |
670 | successful = true; |
671 | } |
672 | |
673 | if (!successful && inst->IsFoldableByFoldScalar()) { |
674 | successful = FoldIntegerOpToConstant(inst, id_map, &result_val); |
675 | } |
676 | |
677 | if (successful) { |
678 | const analysis::Constant* result_const = |
679 | const_mgr->GetConstant(const_mgr->GetType(inst), {result_val}); |
680 | Instruction* folded_inst = |
681 | const_mgr->GetDefiningInstruction(result_const, inst->type_id()); |
682 | return folded_inst; |
683 | } |
684 | return nullptr; |
685 | } |
686 | |
687 | bool InstructionFolder::IsFoldableType(Instruction* type_inst) const { |
688 | // Support 32-bit integers. |
689 | if (type_inst->opcode() == SpvOpTypeInt) { |
690 | return type_inst->GetSingleWordInOperand(0) == 32; |
691 | } |
692 | // Support booleans. |
693 | if (type_inst->opcode() == SpvOpTypeBool) { |
694 | return true; |
695 | } |
696 | // Nothing else yet. |
697 | return false; |
698 | } |
699 | |
700 | bool InstructionFolder::FoldInstruction(Instruction* inst) const { |
701 | bool modified = false; |
702 | Instruction* folded_inst(inst); |
703 | while (folded_inst->opcode() != SpvOpCopyObject && |
704 | FoldInstructionInternal(&*folded_inst)) { |
705 | modified = true; |
706 | } |
707 | return modified; |
708 | } |
709 | |
710 | } // namespace opt |
711 | } // namespace spvtools |
712 | |