1 | // Copyright (c) 2018 Google LLC |
2 | // |
3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | // you may not use this file except in compliance with the License. |
5 | // You may obtain a copy of the License at |
6 | // |
7 | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | // |
9 | // Unless required by applicable law or agreed to in writing, software |
10 | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | // See the License for the specific language governing permissions and |
13 | // limitations under the License. |
14 | |
15 | #include "source/opt/const_folding_rules.h" |
16 | |
17 | #include "source/opt/ir_context.h" |
18 | |
19 | namespace spvtools { |
20 | namespace opt { |
21 | namespace { |
22 | |
// Index into the folding rule's |constants| vector of the composite operand
// of an extract instruction (see its use in the extract folding rule below).
const uint32_t kExtractCompositeIdInIdx = 0;
24 | |
25 | // Returns true if |type| is Float or a vector of Float. |
26 | bool HasFloatingPoint(const analysis::Type* type) { |
27 | if (type->AsFloat()) { |
28 | return true; |
29 | } else if (const analysis::Vector* vec_type = type->AsVector()) { |
30 | return vec_type->element_type()->AsFloat() != nullptr; |
31 | } |
32 | |
33 | return false; |
34 | } |
35 | |
36 | // Folds an OpcompositeExtract where input is a composite constant. |
37 | ConstantFoldingRule () { |
38 | return [](IRContext* context, Instruction* inst, |
39 | const std::vector<const analysis::Constant*>& constants) |
40 | -> const analysis::Constant* { |
41 | const analysis::Constant* c = constants[kExtractCompositeIdInIdx]; |
42 | if (c == nullptr) { |
43 | return nullptr; |
44 | } |
45 | |
46 | for (uint32_t i = 1; i < inst->NumInOperands(); ++i) { |
47 | uint32_t element_index = inst->GetSingleWordInOperand(i); |
48 | if (c->AsNullConstant()) { |
49 | // Return Null for the return type. |
50 | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
51 | analysis::TypeManager* type_mgr = context->get_type_mgr(); |
52 | return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), {}); |
53 | } |
54 | |
55 | auto cc = c->AsCompositeConstant(); |
56 | assert(cc != nullptr); |
57 | auto components = cc->GetComponents(); |
58 | // Protect against invalid IR. Refuse to fold if the index is out |
59 | // of bounds. |
60 | if (element_index >= components.size()) return nullptr; |
61 | c = components[element_index]; |
62 | } |
63 | return c; |
64 | }; |
65 | } |
66 | |
67 | ConstantFoldingRule FoldVectorShuffleWithConstants() { |
68 | return [](IRContext* context, Instruction* inst, |
69 | const std::vector<const analysis::Constant*>& constants) |
70 | -> const analysis::Constant* { |
71 | assert(inst->opcode() == SpvOpVectorShuffle); |
72 | const analysis::Constant* c1 = constants[0]; |
73 | const analysis::Constant* c2 = constants[1]; |
74 | if (c1 == nullptr || c2 == nullptr) { |
75 | return nullptr; |
76 | } |
77 | |
78 | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
79 | const analysis::Type* element_type = c1->type()->AsVector()->element_type(); |
80 | |
81 | std::vector<const analysis::Constant*> c1_components; |
82 | if (const analysis::VectorConstant* vec_const = c1->AsVectorConstant()) { |
83 | c1_components = vec_const->GetComponents(); |
84 | } else { |
85 | assert(c1->AsNullConstant()); |
86 | const analysis::Constant* element = |
87 | const_mgr->GetConstant(element_type, {}); |
88 | c1_components.resize(c1->type()->AsVector()->element_count(), element); |
89 | } |
90 | std::vector<const analysis::Constant*> c2_components; |
91 | if (const analysis::VectorConstant* vec_const = c2->AsVectorConstant()) { |
92 | c2_components = vec_const->GetComponents(); |
93 | } else { |
94 | assert(c2->AsNullConstant()); |
95 | const analysis::Constant* element = |
96 | const_mgr->GetConstant(element_type, {}); |
97 | c2_components.resize(c2->type()->AsVector()->element_count(), element); |
98 | } |
99 | |
100 | std::vector<uint32_t> ids; |
101 | const uint32_t undef_literal_value = 0xffffffff; |
102 | for (uint32_t i = 2; i < inst->NumInOperands(); ++i) { |
103 | uint32_t index = inst->GetSingleWordInOperand(i); |
104 | if (index == undef_literal_value) { |
105 | // Don't fold shuffle with undef literal value. |
106 | return nullptr; |
107 | } else if (index < c1_components.size()) { |
108 | Instruction* member_inst = |
109 | const_mgr->GetDefiningInstruction(c1_components[index]); |
110 | ids.push_back(member_inst->result_id()); |
111 | } else { |
112 | Instruction* member_inst = const_mgr->GetDefiningInstruction( |
113 | c2_components[index - c1_components.size()]); |
114 | ids.push_back(member_inst->result_id()); |
115 | } |
116 | } |
117 | |
118 | analysis::TypeManager* type_mgr = context->get_type_mgr(); |
119 | return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids); |
120 | }; |
121 | } |
122 | |
123 | ConstantFoldingRule FoldVectorTimesScalar() { |
124 | return [](IRContext* context, Instruction* inst, |
125 | const std::vector<const analysis::Constant*>& constants) |
126 | -> const analysis::Constant* { |
127 | assert(inst->opcode() == SpvOpVectorTimesScalar); |
128 | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
129 | analysis::TypeManager* type_mgr = context->get_type_mgr(); |
130 | |
131 | if (!inst->IsFloatingPointFoldingAllowed()) { |
132 | if (HasFloatingPoint(type_mgr->GetType(inst->type_id()))) { |
133 | return nullptr; |
134 | } |
135 | } |
136 | |
137 | const analysis::Constant* c1 = constants[0]; |
138 | const analysis::Constant* c2 = constants[1]; |
139 | |
140 | if (c1 && c1->IsZero()) { |
141 | return c1; |
142 | } |
143 | |
144 | if (c2 && c2->IsZero()) { |
145 | // Get or create the NullConstant for this type. |
146 | std::vector<uint32_t> ids; |
147 | return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), ids); |
148 | } |
149 | |
150 | if (c1 == nullptr || c2 == nullptr) { |
151 | return nullptr; |
152 | } |
153 | |
154 | // Check result type. |
155 | const analysis::Type* result_type = type_mgr->GetType(inst->type_id()); |
156 | const analysis::Vector* vector_type = result_type->AsVector(); |
157 | assert(vector_type != nullptr); |
158 | const analysis::Type* element_type = vector_type->element_type(); |
159 | assert(element_type != nullptr); |
160 | const analysis::Float* float_type = element_type->AsFloat(); |
161 | assert(float_type != nullptr); |
162 | |
163 | // Check types of c1 and c2. |
164 | assert(c1->type()->AsVector() == vector_type); |
165 | assert(c1->type()->AsVector()->element_type() == element_type && |
166 | c2->type() == element_type); |
167 | |
168 | // Get a float vector that is the result of vector-times-scalar. |
169 | std::vector<const analysis::Constant*> c1_components = |
170 | c1->GetVectorComponents(const_mgr); |
171 | std::vector<uint32_t> ids; |
172 | if (float_type->width() == 32) { |
173 | float scalar = c2->GetFloat(); |
174 | for (uint32_t i = 0; i < c1_components.size(); ++i) { |
175 | utils::FloatProxy<float> result(c1_components[i]->GetFloat() * scalar); |
176 | std::vector<uint32_t> words = result.GetWords(); |
177 | const analysis::Constant* new_elem = |
178 | const_mgr->GetConstant(float_type, words); |
179 | ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id()); |
180 | } |
181 | return const_mgr->GetConstant(vector_type, ids); |
182 | } else if (float_type->width() == 64) { |
183 | double scalar = c2->GetDouble(); |
184 | for (uint32_t i = 0; i < c1_components.size(); ++i) { |
185 | utils::FloatProxy<double> result(c1_components[i]->GetDouble() * |
186 | scalar); |
187 | std::vector<uint32_t> words = result.GetWords(); |
188 | const analysis::Constant* new_elem = |
189 | const_mgr->GetConstant(float_type, words); |
190 | ids.push_back(const_mgr->GetDefiningInstruction(new_elem)->result_id()); |
191 | } |
192 | return const_mgr->GetConstant(vector_type, ids); |
193 | } |
194 | return nullptr; |
195 | }; |
196 | } |
197 | |
198 | ConstantFoldingRule FoldCompositeWithConstants() { |
199 | // Folds an OpCompositeConstruct where all of the inputs are constants to a |
200 | // constant. A new constant is created if necessary. |
201 | return [](IRContext* context, Instruction* inst, |
202 | const std::vector<const analysis::Constant*>& constants) |
203 | -> const analysis::Constant* { |
204 | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
205 | analysis::TypeManager* type_mgr = context->get_type_mgr(); |
206 | const analysis::Type* new_type = type_mgr->GetType(inst->type_id()); |
207 | Instruction* type_inst = |
208 | context->get_def_use_mgr()->GetDef(inst->type_id()); |
209 | |
210 | std::vector<uint32_t> ids; |
211 | for (uint32_t i = 0; i < constants.size(); ++i) { |
212 | const analysis::Constant* element_const = constants[i]; |
213 | if (element_const == nullptr) { |
214 | return nullptr; |
215 | } |
216 | |
217 | uint32_t component_type_id = 0; |
218 | if (type_inst->opcode() == SpvOpTypeStruct) { |
219 | component_type_id = type_inst->GetSingleWordInOperand(i); |
220 | } else if (type_inst->opcode() == SpvOpTypeArray) { |
221 | component_type_id = type_inst->GetSingleWordInOperand(0); |
222 | } |
223 | |
224 | uint32_t element_id = |
225 | const_mgr->FindDeclaredConstant(element_const, component_type_id); |
226 | if (element_id == 0) { |
227 | return nullptr; |
228 | } |
229 | ids.push_back(element_id); |
230 | } |
231 | return const_mgr->GetConstant(new_type, ids); |
232 | }; |
233 | } |
234 | |
235 | // The interface for a function that returns the result of applying a scalar |
236 | // floating-point binary operation on |a| and |b|. The type of the return value |
237 | // will be |type|. The input constants must also be of type |type|. |
238 | using UnaryScalarFoldingRule = std::function<const analysis::Constant*( |
239 | const analysis::Type* result_type, const analysis::Constant* a, |
240 | analysis::ConstantManager*)>; |
241 | |
242 | // The interface for a function that returns the result of applying a scalar |
243 | // floating-point binary operation on |a| and |b|. The type of the return value |
244 | // will be |type|. The input constants must also be of type |type|. |
245 | using BinaryScalarFoldingRule = std::function<const analysis::Constant*( |
246 | const analysis::Type* result_type, const analysis::Constant* a, |
247 | const analysis::Constant* b, analysis::ConstantManager*)>; |
248 | |
249 | // Returns a |ConstantFoldingRule| that folds unary floating point scalar ops |
250 | // using |scalar_rule| and unary float point vectors ops by applying |
251 | // |scalar_rule| to the elements of the vector. The |ConstantFoldingRule| |
252 | // that is returned assumes that |constants| contains 1 entry. If they are |
253 | // not |nullptr|, then their type is either |Float| or |Integer| or a |Vector| |
254 | // whose element type is |Float| or |Integer|. |
255 | ConstantFoldingRule FoldFPUnaryOp(UnaryScalarFoldingRule scalar_rule) { |
256 | return [scalar_rule](IRContext* context, Instruction* inst, |
257 | const std::vector<const analysis::Constant*>& constants) |
258 | -> const analysis::Constant* { |
259 | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
260 | analysis::TypeManager* type_mgr = context->get_type_mgr(); |
261 | const analysis::Type* result_type = type_mgr->GetType(inst->type_id()); |
262 | const analysis::Vector* vector_type = result_type->AsVector(); |
263 | |
264 | if (!inst->IsFloatingPointFoldingAllowed()) { |
265 | return nullptr; |
266 | } |
267 | |
268 | const analysis::Constant* arg = |
269 | (inst->opcode() == SpvOpExtInst) ? constants[1] : constants[0]; |
270 | |
271 | if (arg == nullptr) { |
272 | return nullptr; |
273 | } |
274 | |
275 | if (vector_type != nullptr) { |
276 | std::vector<const analysis::Constant*> a_components; |
277 | std::vector<const analysis::Constant*> results_components; |
278 | |
279 | a_components = arg->GetVectorComponents(const_mgr); |
280 | |
281 | // Fold each component of the vector. |
282 | for (uint32_t i = 0; i < a_components.size(); ++i) { |
283 | results_components.push_back(scalar_rule(vector_type->element_type(), |
284 | a_components[i], const_mgr)); |
285 | if (results_components[i] == nullptr) { |
286 | return nullptr; |
287 | } |
288 | } |
289 | |
290 | // Build the constant object and return it. |
291 | std::vector<uint32_t> ids; |
292 | for (const analysis::Constant* member : results_components) { |
293 | ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id()); |
294 | } |
295 | return const_mgr->GetConstant(vector_type, ids); |
296 | } else { |
297 | return scalar_rule(result_type, arg, const_mgr); |
298 | } |
299 | }; |
300 | } |
301 | |
302 | // Returns the result of folding the constants in |constants| according the |
303 | // |scalar_rule|. If |result_type| is a vector, then |scalar_rule| is applied |
304 | // per component. |
305 | const analysis::Constant* FoldFPBinaryOp( |
306 | BinaryScalarFoldingRule scalar_rule, uint32_t result_type_id, |
307 | const std::vector<const analysis::Constant*>& constants, |
308 | IRContext* context) { |
309 | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
310 | analysis::TypeManager* type_mgr = context->get_type_mgr(); |
311 | const analysis::Type* result_type = type_mgr->GetType(result_type_id); |
312 | const analysis::Vector* vector_type = result_type->AsVector(); |
313 | |
314 | if (constants[0] == nullptr || constants[1] == nullptr) { |
315 | return nullptr; |
316 | } |
317 | |
318 | if (vector_type != nullptr) { |
319 | std::vector<const analysis::Constant*> a_components; |
320 | std::vector<const analysis::Constant*> b_components; |
321 | std::vector<const analysis::Constant*> results_components; |
322 | |
323 | a_components = constants[0]->GetVectorComponents(const_mgr); |
324 | b_components = constants[1]->GetVectorComponents(const_mgr); |
325 | |
326 | // Fold each component of the vector. |
327 | for (uint32_t i = 0; i < a_components.size(); ++i) { |
328 | results_components.push_back(scalar_rule(vector_type->element_type(), |
329 | a_components[i], b_components[i], |
330 | const_mgr)); |
331 | if (results_components[i] == nullptr) { |
332 | return nullptr; |
333 | } |
334 | } |
335 | |
336 | // Build the constant object and return it. |
337 | std::vector<uint32_t> ids; |
338 | for (const analysis::Constant* member : results_components) { |
339 | ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id()); |
340 | } |
341 | return const_mgr->GetConstant(vector_type, ids); |
342 | } else { |
343 | return scalar_rule(result_type, constants[0], constants[1], const_mgr); |
344 | } |
345 | } |
346 | |
347 | // Returns a |ConstantFoldingRule| that folds floating point scalars using |
348 | // |scalar_rule| and vectors of floating point by applying |scalar_rule| to the |
349 | // elements of the vector. The |ConstantFoldingRule| that is returned assumes |
350 | // that |constants| contains 2 entries. If they are not |nullptr|, then their |
351 | // type is either |Float| or a |Vector| whose element type is |Float|. |
352 | ConstantFoldingRule FoldFPBinaryOp(BinaryScalarFoldingRule scalar_rule) { |
353 | return [scalar_rule](IRContext* context, Instruction* inst, |
354 | const std::vector<const analysis::Constant*>& constants) |
355 | -> const analysis::Constant* { |
356 | if (!inst->IsFloatingPointFoldingAllowed()) { |
357 | return nullptr; |
358 | } |
359 | if (inst->opcode() == SpvOpExtInst) { |
360 | return FoldFPBinaryOp(scalar_rule, inst->type_id(), |
361 | {constants[1], constants[2]}, context); |
362 | } |
363 | return FoldFPBinaryOp(scalar_rule, inst->type_id(), constants, context); |
364 | }; |
365 | } |
366 | |
367 | // This macro defines a |UnaryScalarFoldingRule| that performs float to |
368 | // integer conversion. |
369 | // TODO(greg-lunarg): Support for 64-bit integer types. |
370 | UnaryScalarFoldingRule FoldFToIOp() { |
371 | return [](const analysis::Type* result_type, const analysis::Constant* a, |
372 | analysis::ConstantManager* const_mgr) -> const analysis::Constant* { |
373 | assert(result_type != nullptr && a != nullptr); |
374 | const analysis::Integer* integer_type = result_type->AsInteger(); |
375 | const analysis::Float* float_type = a->type()->AsFloat(); |
376 | assert(float_type != nullptr); |
377 | assert(integer_type != nullptr); |
378 | if (integer_type->width() != 32) return nullptr; |
379 | if (float_type->width() == 32) { |
380 | float fa = a->GetFloat(); |
381 | uint32_t result = integer_type->IsSigned() |
382 | ? static_cast<uint32_t>(static_cast<int32_t>(fa)) |
383 | : static_cast<uint32_t>(fa); |
384 | std::vector<uint32_t> words = {result}; |
385 | return const_mgr->GetConstant(result_type, words); |
386 | } else if (float_type->width() == 64) { |
387 | double fa = a->GetDouble(); |
388 | uint32_t result = integer_type->IsSigned() |
389 | ? static_cast<uint32_t>(static_cast<int32_t>(fa)) |
390 | : static_cast<uint32_t>(fa); |
391 | std::vector<uint32_t> words = {result}; |
392 | return const_mgr->GetConstant(result_type, words); |
393 | } |
394 | return nullptr; |
395 | }; |
396 | } |
397 | |
398 | // This function defines a |UnaryScalarFoldingRule| that performs integer to |
399 | // float conversion. |
400 | // TODO(greg-lunarg): Support for 64-bit integer types. |
401 | UnaryScalarFoldingRule FoldIToFOp() { |
402 | return [](const analysis::Type* result_type, const analysis::Constant* a, |
403 | analysis::ConstantManager* const_mgr) -> const analysis::Constant* { |
404 | assert(result_type != nullptr && a != nullptr); |
405 | const analysis::Integer* integer_type = a->type()->AsInteger(); |
406 | const analysis::Float* float_type = result_type->AsFloat(); |
407 | assert(float_type != nullptr); |
408 | assert(integer_type != nullptr); |
409 | if (integer_type->width() != 32) return nullptr; |
410 | uint32_t ua = a->GetU32(); |
411 | if (float_type->width() == 32) { |
412 | float result_val = integer_type->IsSigned() |
413 | ? static_cast<float>(static_cast<int32_t>(ua)) |
414 | : static_cast<float>(ua); |
415 | utils::FloatProxy<float> result(result_val); |
416 | std::vector<uint32_t> words = {result.data()}; |
417 | return const_mgr->GetConstant(result_type, words); |
418 | } else if (float_type->width() == 64) { |
419 | double result_val = integer_type->IsSigned() |
420 | ? static_cast<double>(static_cast<int32_t>(ua)) |
421 | : static_cast<double>(ua); |
422 | utils::FloatProxy<double> result(result_val); |
423 | std::vector<uint32_t> words = result.GetWords(); |
424 | return const_mgr->GetConstant(result_type, words); |
425 | } |
426 | return nullptr; |
427 | }; |
428 | } |
429 | |
430 | // This defines a |UnaryScalarFoldingRule| that performs |OpQuantizeToF16|. |
431 | UnaryScalarFoldingRule FoldQuantizeToF16Scalar() { |
432 | return [](const analysis::Type* result_type, const analysis::Constant* a, |
433 | analysis::ConstantManager* const_mgr) -> const analysis::Constant* { |
434 | assert(result_type != nullptr && a != nullptr); |
435 | const analysis::Float* float_type = a->type()->AsFloat(); |
436 | assert(float_type != nullptr); |
437 | if (float_type->width() != 32) { |
438 | return nullptr; |
439 | } |
440 | |
441 | float fa = a->GetFloat(); |
442 | utils::HexFloat<utils::FloatProxy<float>> orignal(fa); |
443 | utils::HexFloat<utils::FloatProxy<utils::Float16>> quantized(0); |
444 | utils::HexFloat<utils::FloatProxy<float>> result(0.0f); |
445 | orignal.castTo(quantized, utils::round_direction::kToZero); |
446 | quantized.castTo(result, utils::round_direction::kToZero); |
447 | std::vector<uint32_t> words = {result.getBits()}; |
448 | return const_mgr->GetConstant(result_type, words); |
449 | }; |
450 | } |
451 | |
// This macro expands to a lambda usable as a |BinaryScalarFoldingRule| that
// computes "a op b" for two scalar float constants whose type is the result
// type. The operator |op| must work for both float and double, and use syntax
// "f1 op f2". The lambda yields nullptr for float widths other than 32 and 64.
#define FOLD_FPARITH_OP(op) \
  [](const analysis::Type* result_type_in_macro, const analysis::Constant* a, \
     const analysis::Constant* b, \
     analysis::ConstantManager* const_mgr_in_macro) \
      -> const analysis::Constant* { \
    assert(result_type_in_macro != nullptr && a != nullptr && b != nullptr); \
    assert(result_type_in_macro == a->type() && \
           result_type_in_macro == b->type()); \
    const analysis::Float* float_type_in_macro = \
        result_type_in_macro->AsFloat(); \
    assert(float_type_in_macro != nullptr); \
    switch (float_type_in_macro->width()) { \
      case 32: { \
        float fa = a->GetFloat(); \
        float fb = b->GetFloat(); \
        utils::FloatProxy<float> result_in_macro(fa op fb); \
        std::vector<uint32_t> words_in_macro = result_in_macro.GetWords(); \
        return const_mgr_in_macro->GetConstant(result_type_in_macro, \
                                               words_in_macro); \
      } \
      case 64: { \
        double fa = a->GetDouble(); \
        double fb = b->GetDouble(); \
        utils::FloatProxy<double> result_in_macro(fa op fb); \
        std::vector<uint32_t> words_in_macro = result_in_macro.GetWords(); \
        return const_mgr_in_macro->GetConstant(result_type_in_macro, \
                                               words_in_macro); \
      } \
      default: \
        return nullptr; \
    } \
  }
482 | |
483 | // Define the folding rule for conversion between floating point and integer |
484 | ConstantFoldingRule FoldFToI() { return FoldFPUnaryOp(FoldFToIOp()); } |
485 | ConstantFoldingRule FoldIToF() { return FoldFPUnaryOp(FoldIToFOp()); } |
486 | ConstantFoldingRule FoldQuantizeToF16() { |
487 | return FoldFPUnaryOp(FoldQuantizeToF16Scalar()); |
488 | } |
489 | |
490 | // Define the folding rules for subtraction, addition, multiplication, and |
491 | // division for floating point values. |
492 | ConstantFoldingRule FoldFSub() { return FoldFPBinaryOp(FOLD_FPARITH_OP(-)); } |
493 | ConstantFoldingRule FoldFAdd() { return FoldFPBinaryOp(FOLD_FPARITH_OP(+)); } |
494 | ConstantFoldingRule FoldFMul() { return FoldFPBinaryOp(FOLD_FPARITH_OP(*)); } |
495 | ConstantFoldingRule FoldFDiv() { return FoldFPBinaryOp(FOLD_FPARITH_OP(/)); } |
496 | |
// Combines the raw outcome of a floating-point comparison with its NaN
// status. |op_result| is the result of "Operand 1 |op| Operand 2" and
// |op_unordered| tells whether the operands are unordered.
//   - |need_ordered| true:  true iff the operands are ordered and |op_result|.
//   - |need_ordered| false: true iff the operands are unordered or
//     |op_result|.
bool CompareFloatingPoint(bool op_result, bool op_unordered,
                          bool need_ordered) {
  return need_ordered ? (!op_unordered && op_result)
                      : (op_unordered || op_result);
}
507 | |
// This macro expands to a lambda usable as a |BinaryScalarFoldingRule| that
// compares two scalar float constants of the same type with |op| and yields a
// constant of the bool result type. The operator |op| must work for both
// float and double, and use syntax "f1 op f2". |ord| is forwarded to
// |CompareFloatingPoint| as |need_ordered|: true gives the ordered
// comparison, false the unordered one. The lambda yields nullptr for float
// widths other than 32 and 64.
#define FOLD_FPCMP_OP(op, ord) \
  [](const analysis::Type* result_type, const analysis::Constant* a, \
     const analysis::Constant* b, \
     analysis::ConstantManager* const_mgr) -> const analysis::Constant* { \
    assert(result_type != nullptr && a != nullptr && b != nullptr); \
    assert(result_type->AsBool()); \
    assert(a->type() == b->type()); \
    const analysis::Float* float_type = a->type()->AsFloat(); \
    assert(float_type != nullptr); \
    switch (float_type->width()) { \
      case 32: { \
        float fa = a->GetFloat(); \
        float fb = b->GetFloat(); \
        bool result = CompareFloatingPoint( \
            fa op fb, std::isnan(fa) || std::isnan(fb), ord); \
        std::vector<uint32_t> words = {uint32_t(result)}; \
        return const_mgr->GetConstant(result_type, words); \
      } \
      case 64: { \
        double fa = a->GetDouble(); \
        double fb = b->GetDouble(); \
        bool result = CompareFloatingPoint( \
            fa op fb, std::isnan(fa) || std::isnan(fb), ord); \
        std::vector<uint32_t> words = {uint32_t(result)}; \
        return const_mgr->GetConstant(result_type, words); \
      } \
      default: \
        return nullptr; \
    } \
  }
536 | |
537 | // Define the folding rules for ordered and unordered comparison for floating |
538 | // point values. |
539 | ConstantFoldingRule FoldFOrdEqual() { |
540 | return FoldFPBinaryOp(FOLD_FPCMP_OP(==, true)); |
541 | } |
542 | ConstantFoldingRule FoldFUnordEqual() { |
543 | return FoldFPBinaryOp(FOLD_FPCMP_OP(==, false)); |
544 | } |
545 | ConstantFoldingRule FoldFOrdNotEqual() { |
546 | return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, true)); |
547 | } |
548 | ConstantFoldingRule FoldFUnordNotEqual() { |
549 | return FoldFPBinaryOp(FOLD_FPCMP_OP(!=, false)); |
550 | } |
551 | ConstantFoldingRule FoldFOrdLessThan() { |
552 | return FoldFPBinaryOp(FOLD_FPCMP_OP(<, true)); |
553 | } |
554 | ConstantFoldingRule FoldFUnordLessThan() { |
555 | return FoldFPBinaryOp(FOLD_FPCMP_OP(<, false)); |
556 | } |
557 | ConstantFoldingRule FoldFOrdGreaterThan() { |
558 | return FoldFPBinaryOp(FOLD_FPCMP_OP(>, true)); |
559 | } |
560 | ConstantFoldingRule FoldFUnordGreaterThan() { |
561 | return FoldFPBinaryOp(FOLD_FPCMP_OP(>, false)); |
562 | } |
563 | ConstantFoldingRule FoldFOrdLessThanEqual() { |
564 | return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, true)); |
565 | } |
566 | ConstantFoldingRule FoldFUnordLessThanEqual() { |
567 | return FoldFPBinaryOp(FOLD_FPCMP_OP(<=, false)); |
568 | } |
569 | ConstantFoldingRule FoldFOrdGreaterThanEqual() { |
570 | return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, true)); |
571 | } |
572 | ConstantFoldingRule FoldFUnordGreaterThanEqual() { |
573 | return FoldFPBinaryOp(FOLD_FPCMP_OP(>=, false)); |
574 | } |
575 | |
576 | // Folds an OpDot where all of the inputs are constants to a |
577 | // constant. A new constant is created if necessary. |
578 | ConstantFoldingRule FoldOpDotWithConstants() { |
579 | return [](IRContext* context, Instruction* inst, |
580 | const std::vector<const analysis::Constant*>& constants) |
581 | -> const analysis::Constant* { |
582 | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
583 | analysis::TypeManager* type_mgr = context->get_type_mgr(); |
584 | const analysis::Type* new_type = type_mgr->GetType(inst->type_id()); |
585 | assert(new_type->AsFloat() && "OpDot should have a float return type." ); |
586 | const analysis::Float* float_type = new_type->AsFloat(); |
587 | |
588 | if (!inst->IsFloatingPointFoldingAllowed()) { |
589 | return nullptr; |
590 | } |
591 | |
592 | // If one of the operands is 0, then the result is 0. |
593 | bool has_zero_operand = false; |
594 | |
595 | for (int i = 0; i < 2; ++i) { |
596 | if (constants[i]) { |
597 | if (constants[i]->AsNullConstant() || |
598 | constants[i]->AsVectorConstant()->IsZero()) { |
599 | has_zero_operand = true; |
600 | break; |
601 | } |
602 | } |
603 | } |
604 | |
605 | if (has_zero_operand) { |
606 | if (float_type->width() == 32) { |
607 | utils::FloatProxy<float> result(0.0f); |
608 | std::vector<uint32_t> words = result.GetWords(); |
609 | return const_mgr->GetConstant(float_type, words); |
610 | } |
611 | if (float_type->width() == 64) { |
612 | utils::FloatProxy<double> result(0.0); |
613 | std::vector<uint32_t> words = result.GetWords(); |
614 | return const_mgr->GetConstant(float_type, words); |
615 | } |
616 | return nullptr; |
617 | } |
618 | |
619 | if (constants[0] == nullptr || constants[1] == nullptr) { |
620 | return nullptr; |
621 | } |
622 | |
623 | std::vector<const analysis::Constant*> a_components; |
624 | std::vector<const analysis::Constant*> b_components; |
625 | |
626 | a_components = constants[0]->GetVectorComponents(const_mgr); |
627 | b_components = constants[1]->GetVectorComponents(const_mgr); |
628 | |
629 | utils::FloatProxy<double> result(0.0); |
630 | std::vector<uint32_t> words = result.GetWords(); |
631 | const analysis::Constant* result_const = |
632 | const_mgr->GetConstant(float_type, words); |
633 | for (uint32_t i = 0; i < a_components.size() && result_const != nullptr; |
634 | ++i) { |
635 | if (a_components[i] == nullptr || b_components[i] == nullptr) { |
636 | return nullptr; |
637 | } |
638 | |
639 | const analysis::Constant* component = FOLD_FPARITH_OP(*)( |
640 | new_type, a_components[i], b_components[i], const_mgr); |
641 | if (component == nullptr) { |
642 | return nullptr; |
643 | } |
644 | result_const = |
645 | FOLD_FPARITH_OP(+)(new_type, result_const, component, const_mgr); |
646 | } |
647 | return result_const; |
648 | }; |
649 | } |
650 | |
651 | // This function defines a |UnaryScalarFoldingRule| that subtracts the constant |
652 | // from zero. |
653 | UnaryScalarFoldingRule FoldFNegateOp() { |
654 | return [](const analysis::Type* result_type, const analysis::Constant* a, |
655 | analysis::ConstantManager* const_mgr) -> const analysis::Constant* { |
656 | assert(result_type != nullptr && a != nullptr); |
657 | assert(result_type == a->type()); |
658 | const analysis::Float* float_type = result_type->AsFloat(); |
659 | assert(float_type != nullptr); |
660 | if (float_type->width() == 32) { |
661 | float fa = a->GetFloat(); |
662 | utils::FloatProxy<float> result(-fa); |
663 | std::vector<uint32_t> words = result.GetWords(); |
664 | return const_mgr->GetConstant(result_type, words); |
665 | } else if (float_type->width() == 64) { |
666 | double da = a->GetDouble(); |
667 | utils::FloatProxy<double> result(-da); |
668 | std::vector<uint32_t> words = result.GetWords(); |
669 | return const_mgr->GetConstant(result_type, words); |
670 | } |
671 | return nullptr; |
672 | }; |
673 | } |
674 | |
675 | ConstantFoldingRule FoldFNegate() { return FoldFPUnaryOp(FoldFNegateOp()); } |
676 | |
677 | ConstantFoldingRule FoldFClampFeedingCompare(uint32_t cmp_opcode) { |
678 | return [cmp_opcode](IRContext* context, Instruction* inst, |
679 | const std::vector<const analysis::Constant*>& constants) |
680 | -> const analysis::Constant* { |
681 | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
682 | analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); |
683 | |
684 | if (!inst->IsFloatingPointFoldingAllowed()) { |
685 | return nullptr; |
686 | } |
687 | |
688 | uint32_t non_const_idx = (constants[0] ? 1 : 0); |
689 | uint32_t operand_id = inst->GetSingleWordInOperand(non_const_idx); |
690 | Instruction* operand_inst = def_use_mgr->GetDef(operand_id); |
691 | |
692 | analysis::TypeManager* type_mgr = context->get_type_mgr(); |
693 | const analysis::Type* operand_type = |
694 | type_mgr->GetType(operand_inst->type_id()); |
695 | |
696 | if (!operand_type->AsFloat()) { |
697 | return nullptr; |
698 | } |
699 | |
700 | if (operand_type->AsFloat()->width() != 32 && |
701 | operand_type->AsFloat()->width() != 64) { |
702 | return nullptr; |
703 | } |
704 | |
705 | if (operand_inst->opcode() != SpvOpExtInst) { |
706 | return nullptr; |
707 | } |
708 | |
709 | if (operand_inst->GetSingleWordInOperand(1) != GLSLstd450FClamp) { |
710 | return nullptr; |
711 | } |
712 | |
713 | if (constants[1] == nullptr && constants[0] == nullptr) { |
714 | return nullptr; |
715 | } |
716 | |
717 | uint32_t max_id = operand_inst->GetSingleWordInOperand(4); |
718 | const analysis::Constant* max_const = |
719 | const_mgr->FindDeclaredConstant(max_id); |
720 | |
721 | uint32_t min_id = operand_inst->GetSingleWordInOperand(3); |
722 | const analysis::Constant* min_const = |
723 | const_mgr->FindDeclaredConstant(min_id); |
724 | |
725 | bool found_result = false; |
726 | bool result = false; |
727 | |
728 | switch (cmp_opcode) { |
729 | case SpvOpFOrdLessThan: |
730 | case SpvOpFUnordLessThan: |
731 | case SpvOpFOrdGreaterThanEqual: |
732 | case SpvOpFUnordGreaterThanEqual: |
733 | if (constants[0]) { |
734 | if (min_const) { |
735 | if (constants[0]->GetValueAsDouble() < |
736 | min_const->GetValueAsDouble()) { |
737 | found_result = true; |
738 | result = (cmp_opcode == SpvOpFOrdLessThan || |
739 | cmp_opcode == SpvOpFUnordLessThan); |
740 | } |
741 | } |
742 | if (max_const) { |
743 | if (constants[0]->GetValueAsDouble() >= |
744 | max_const->GetValueAsDouble()) { |
745 | found_result = true; |
746 | result = !(cmp_opcode == SpvOpFOrdLessThan || |
747 | cmp_opcode == SpvOpFUnordLessThan); |
748 | } |
749 | } |
750 | } |
751 | |
752 | if (constants[1]) { |
753 | if (max_const) { |
754 | if (max_const->GetValueAsDouble() < |
755 | constants[1]->GetValueAsDouble()) { |
756 | found_result = true; |
757 | result = (cmp_opcode == SpvOpFOrdLessThan || |
758 | cmp_opcode == SpvOpFUnordLessThan); |
759 | } |
760 | } |
761 | |
762 | if (min_const) { |
763 | if (min_const->GetValueAsDouble() >= |
764 | constants[1]->GetValueAsDouble()) { |
765 | found_result = true; |
766 | result = !(cmp_opcode == SpvOpFOrdLessThan || |
767 | cmp_opcode == SpvOpFUnordLessThan); |
768 | } |
769 | } |
770 | } |
771 | break; |
772 | case SpvOpFOrdGreaterThan: |
773 | case SpvOpFUnordGreaterThan: |
774 | case SpvOpFOrdLessThanEqual: |
775 | case SpvOpFUnordLessThanEqual: |
776 | if (constants[0]) { |
777 | if (min_const) { |
778 | if (constants[0]->GetValueAsDouble() <= |
779 | min_const->GetValueAsDouble()) { |
780 | found_result = true; |
781 | result = (cmp_opcode == SpvOpFOrdLessThanEqual || |
782 | cmp_opcode == SpvOpFUnordLessThanEqual); |
783 | } |
784 | } |
785 | if (max_const) { |
786 | if (constants[0]->GetValueAsDouble() > |
787 | max_const->GetValueAsDouble()) { |
788 | found_result = true; |
789 | result = !(cmp_opcode == SpvOpFOrdLessThanEqual || |
790 | cmp_opcode == SpvOpFUnordLessThanEqual); |
791 | } |
792 | } |
793 | } |
794 | |
795 | if (constants[1]) { |
796 | if (max_const) { |
797 | if (max_const->GetValueAsDouble() <= |
798 | constants[1]->GetValueAsDouble()) { |
799 | found_result = true; |
800 | result = (cmp_opcode == SpvOpFOrdLessThanEqual || |
801 | cmp_opcode == SpvOpFUnordLessThanEqual); |
802 | } |
803 | } |
804 | |
805 | if (min_const) { |
806 | if (min_const->GetValueAsDouble() > |
807 | constants[1]->GetValueAsDouble()) { |
808 | found_result = true; |
809 | result = !(cmp_opcode == SpvOpFOrdLessThanEqual || |
810 | cmp_opcode == SpvOpFUnordLessThanEqual); |
811 | } |
812 | } |
813 | } |
814 | break; |
815 | default: |
816 | return nullptr; |
817 | } |
818 | |
819 | if (!found_result) { |
820 | return nullptr; |
821 | } |
822 | |
823 | const analysis::Type* bool_type = |
824 | context->get_type_mgr()->GetType(inst->type_id()); |
825 | const analysis::Constant* result_const = |
826 | const_mgr->GetConstant(bool_type, {static_cast<uint32_t>(result)}); |
827 | assert(result_const); |
828 | return result_const; |
829 | }; |
830 | } |
831 | |
832 | ConstantFoldingRule FoldFMix() { |
833 | return [](IRContext* context, Instruction* inst, |
834 | const std::vector<const analysis::Constant*>& constants) |
835 | -> const analysis::Constant* { |
836 | analysis::ConstantManager* const_mgr = context->get_constant_mgr(); |
837 | assert(inst->opcode() == SpvOpExtInst && |
838 | "Expecting an extended instruction." ); |
839 | assert(inst->GetSingleWordInOperand(0) == |
840 | context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() && |
841 | "Expecting a GLSLstd450 extended instruction." ); |
842 | assert(inst->GetSingleWordInOperand(1) == GLSLstd450FMix && |
843 | "Expecting and FMix instruction." ); |
844 | |
845 | if (!inst->IsFloatingPointFoldingAllowed()) { |
846 | return nullptr; |
847 | } |
848 | |
849 | // Make sure all FMix operands are constants. |
850 | for (uint32_t i = 1; i < 4; i++) { |
851 | if (constants[i] == nullptr) { |
852 | return nullptr; |
853 | } |
854 | } |
855 | |
856 | const analysis::Constant* one; |
857 | bool is_vector = false; |
858 | const analysis::Type* result_type = constants[1]->type(); |
859 | const analysis::Type* base_type = result_type; |
860 | if (base_type->AsVector()) { |
861 | is_vector = true; |
862 | base_type = base_type->AsVector()->element_type(); |
863 | } |
864 | assert(base_type->AsFloat() != nullptr && |
865 | "FMix is suppose to act on floats or vectors of floats." ); |
866 | |
867 | if (base_type->AsFloat()->width() == 32) { |
868 | one = const_mgr->GetConstant(base_type, |
869 | utils::FloatProxy<float>(1.0f).GetWords()); |
870 | } else { |
871 | one = const_mgr->GetConstant(base_type, |
872 | utils::FloatProxy<double>(1.0).GetWords()); |
873 | } |
874 | |
875 | if (is_vector) { |
876 | uint32_t one_id = const_mgr->GetDefiningInstruction(one)->result_id(); |
877 | one = |
878 | const_mgr->GetConstant(result_type, std::vector<uint32_t>(4, one_id)); |
879 | } |
880 | |
881 | const analysis::Constant* temp1 = FoldFPBinaryOp( |
882 | FOLD_FPARITH_OP(-), inst->type_id(), {one, constants[3]}, context); |
883 | if (temp1 == nullptr) { |
884 | return nullptr; |
885 | } |
886 | |
887 | const analysis::Constant* temp2 = FoldFPBinaryOp( |
888 | FOLD_FPARITH_OP(*), inst->type_id(), {constants[1], temp1}, context); |
889 | if (temp2 == nullptr) { |
890 | return nullptr; |
891 | } |
892 | const analysis::Constant* temp3 = |
893 | FoldFPBinaryOp(FOLD_FPARITH_OP(*), inst->type_id(), |
894 | {constants[2], constants[3]}, context); |
895 | if (temp3 == nullptr) { |
896 | return nullptr; |
897 | } |
898 | return FoldFPBinaryOp(FOLD_FPARITH_OP(+), inst->type_id(), {temp2, temp3}, |
899 | context); |
900 | }; |
901 | } |
902 | |
// Clamps |x| against |min_val| and then against |max_val|.
//
// The lower bound is applied first and the upper bound second, so when the
// lower-bounded value still exceeds |max_val| the result is |max_val|.
template <class IntType>
IntType FoldIClamp(IntType x, IntType min_val, IntType max_val) {
  const IntType raised = (x < min_val) ? min_val : x;
  return (raised > max_val) ? max_val : raised;
}
913 | |
914 | const analysis::Constant* FoldMin(const analysis::Type* result_type, |
915 | const analysis::Constant* a, |
916 | const analysis::Constant* b, |
917 | analysis::ConstantManager*) { |
918 | if (const analysis::Integer* int_type = result_type->AsInteger()) { |
919 | if (int_type->width() == 32) { |
920 | if (int_type->IsSigned()) { |
921 | int32_t va = a->GetS32(); |
922 | int32_t vb = b->GetS32(); |
923 | return (va < vb ? a : b); |
924 | } else { |
925 | uint32_t va = a->GetU32(); |
926 | uint32_t vb = b->GetU32(); |
927 | return (va < vb ? a : b); |
928 | } |
929 | } else if (int_type->width() == 64) { |
930 | if (int_type->IsSigned()) { |
931 | int64_t va = a->GetS64(); |
932 | int64_t vb = b->GetS64(); |
933 | return (va < vb ? a : b); |
934 | } else { |
935 | uint64_t va = a->GetU64(); |
936 | uint64_t vb = b->GetU64(); |
937 | return (va < vb ? a : b); |
938 | } |
939 | } |
940 | } else if (const analysis::Float* float_type = result_type->AsFloat()) { |
941 | if (float_type->width() == 32) { |
942 | float va = a->GetFloat(); |
943 | float vb = b->GetFloat(); |
944 | return (va < vb ? a : b); |
945 | } else if (float_type->width() == 64) { |
946 | double va = a->GetDouble(); |
947 | double vb = b->GetDouble(); |
948 | return (va < vb ? a : b); |
949 | } |
950 | } |
951 | return nullptr; |
952 | } |
953 | |
954 | const analysis::Constant* FoldMax(const analysis::Type* result_type, |
955 | const analysis::Constant* a, |
956 | const analysis::Constant* b, |
957 | analysis::ConstantManager*) { |
958 | if (const analysis::Integer* int_type = result_type->AsInteger()) { |
959 | if (int_type->width() == 32) { |
960 | if (int_type->IsSigned()) { |
961 | int32_t va = a->GetS32(); |
962 | int32_t vb = b->GetS32(); |
963 | return (va > vb ? a : b); |
964 | } else { |
965 | uint32_t va = a->GetU32(); |
966 | uint32_t vb = b->GetU32(); |
967 | return (va > vb ? a : b); |
968 | } |
969 | } else if (int_type->width() == 64) { |
970 | if (int_type->IsSigned()) { |
971 | int64_t va = a->GetS64(); |
972 | int64_t vb = b->GetS64(); |
973 | return (va > vb ? a : b); |
974 | } else { |
975 | uint64_t va = a->GetU64(); |
976 | uint64_t vb = b->GetU64(); |
977 | return (va > vb ? a : b); |
978 | } |
979 | } |
980 | } else if (const analysis::Float* float_type = result_type->AsFloat()) { |
981 | if (float_type->width() == 32) { |
982 | float va = a->GetFloat(); |
983 | float vb = b->GetFloat(); |
984 | return (va > vb ? a : b); |
985 | } else if (float_type->width() == 64) { |
986 | double va = a->GetDouble(); |
987 | double vb = b->GetDouble(); |
988 | return (va > vb ? a : b); |
989 | } |
990 | } |
991 | return nullptr; |
992 | } |
993 | |
994 | // Fold an clamp instruction when all three operands are constant. |
995 | const analysis::Constant* FoldClamp1( |
996 | IRContext* context, Instruction* inst, |
997 | const std::vector<const analysis::Constant*>& constants) { |
998 | assert(inst->opcode() == SpvOpExtInst && |
999 | "Expecting an extended instruction." ); |
1000 | assert(inst->GetSingleWordInOperand(0) == |
1001 | context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() && |
1002 | "Expecting a GLSLstd450 extended instruction." ); |
1003 | |
1004 | // Make sure all Clamp operands are constants. |
1005 | for (uint32_t i = 1; i < 3; i++) { |
1006 | if (constants[i] == nullptr) { |
1007 | return nullptr; |
1008 | } |
1009 | } |
1010 | |
1011 | const analysis::Constant* temp = FoldFPBinaryOp( |
1012 | FoldMax, inst->type_id(), {constants[1], constants[2]}, context); |
1013 | if (temp == nullptr) { |
1014 | return nullptr; |
1015 | } |
1016 | return FoldFPBinaryOp(FoldMin, inst->type_id(), {temp, constants[3]}, |
1017 | context); |
1018 | } |
1019 | |
1020 | // Fold a clamp instruction when |x >= min_val|. |
1021 | const analysis::Constant* FoldClamp2( |
1022 | IRContext* context, Instruction* inst, |
1023 | const std::vector<const analysis::Constant*>& constants) { |
1024 | assert(inst->opcode() == SpvOpExtInst && |
1025 | "Expecting an extended instruction." ); |
1026 | assert(inst->GetSingleWordInOperand(0) == |
1027 | context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() && |
1028 | "Expecting a GLSLstd450 extended instruction." ); |
1029 | |
1030 | const analysis::Constant* x = constants[1]; |
1031 | const analysis::Constant* min_val = constants[2]; |
1032 | |
1033 | if (x == nullptr || min_val == nullptr) { |
1034 | return nullptr; |
1035 | } |
1036 | |
1037 | const analysis::Constant* temp = |
1038 | FoldFPBinaryOp(FoldMax, inst->type_id(), {x, min_val}, context); |
1039 | if (temp == min_val) { |
1040 | // We can assume that |min_val| is less than |max_val|. Therefore, if the |
1041 | // result of the max operation is |min_val|, we know the result of the min |
1042 | // operation, even if |max_val| is not a constant. |
1043 | return min_val; |
1044 | } |
1045 | return nullptr; |
1046 | } |
1047 | |
1048 | // Fold a clamp instruction when |x >= max_val|. |
1049 | const analysis::Constant* FoldClamp3( |
1050 | IRContext* context, Instruction* inst, |
1051 | const std::vector<const analysis::Constant*>& constants) { |
1052 | assert(inst->opcode() == SpvOpExtInst && |
1053 | "Expecting an extended instruction." ); |
1054 | assert(inst->GetSingleWordInOperand(0) == |
1055 | context->get_feature_mgr()->GetExtInstImportId_GLSLstd450() && |
1056 | "Expecting a GLSLstd450 extended instruction." ); |
1057 | |
1058 | const analysis::Constant* x = constants[1]; |
1059 | const analysis::Constant* max_val = constants[3]; |
1060 | |
1061 | if (x == nullptr || max_val == nullptr) { |
1062 | return nullptr; |
1063 | } |
1064 | |
1065 | const analysis::Constant* temp = |
1066 | FoldFPBinaryOp(FoldMin, inst->type_id(), {x, max_val}, context); |
1067 | if (temp == max_val) { |
1068 | // We can assume that |min_val| is less than |max_val|. Therefore, if the |
1069 | // result of the max operation is |min_val|, we know the result of the min |
1070 | // operation, even if |max_val| is not a constant. |
1071 | return max_val; |
1072 | } |
1073 | return nullptr; |
1074 | } |
1075 | |
1076 | UnaryScalarFoldingRule FoldFTranscendentalUnary(double (*fp)(double)) { |
1077 | return |
1078 | [fp](const analysis::Type* result_type, const analysis::Constant* a, |
1079 | analysis::ConstantManager* const_mgr) -> const analysis::Constant* { |
1080 | assert(result_type != nullptr && a != nullptr); |
1081 | const analysis::Float* float_type = a->type()->AsFloat(); |
1082 | assert(float_type != nullptr); |
1083 | assert(float_type == result_type->AsFloat()); |
1084 | if (float_type->width() == 32) { |
1085 | float fa = a->GetFloat(); |
1086 | float res = static_cast<float>(fp(fa)); |
1087 | utils::FloatProxy<float> result(res); |
1088 | std::vector<uint32_t> words = result.GetWords(); |
1089 | return const_mgr->GetConstant(result_type, words); |
1090 | } else if (float_type->width() == 64) { |
1091 | double fa = a->GetDouble(); |
1092 | double res = fp(fa); |
1093 | utils::FloatProxy<double> result(res); |
1094 | std::vector<uint32_t> words = result.GetWords(); |
1095 | return const_mgr->GetConstant(result_type, words); |
1096 | } |
1097 | return nullptr; |
1098 | }; |
1099 | } |
1100 | |
1101 | BinaryScalarFoldingRule FoldFTranscendentalBinary(double (*fp)(double, |
1102 | double)) { |
1103 | return |
1104 | [fp](const analysis::Type* result_type, const analysis::Constant* a, |
1105 | const analysis::Constant* b, |
1106 | analysis::ConstantManager* const_mgr) -> const analysis::Constant* { |
1107 | assert(result_type != nullptr && a != nullptr); |
1108 | const analysis::Float* float_type = a->type()->AsFloat(); |
1109 | assert(float_type != nullptr); |
1110 | assert(float_type == result_type->AsFloat()); |
1111 | assert(float_type == b->type()->AsFloat()); |
1112 | if (float_type->width() == 32) { |
1113 | float fa = a->GetFloat(); |
1114 | float fb = b->GetFloat(); |
1115 | float res = static_cast<float>(fp(fa, fb)); |
1116 | utils::FloatProxy<float> result(res); |
1117 | std::vector<uint32_t> words = result.GetWords(); |
1118 | return const_mgr->GetConstant(result_type, words); |
1119 | } else if (float_type->width() == 64) { |
1120 | double fa = a->GetDouble(); |
1121 | double fb = b->GetDouble(); |
1122 | double res = fp(fa, fb); |
1123 | utils::FloatProxy<double> result(res); |
1124 | std::vector<uint32_t> words = result.GetWords(); |
1125 | return const_mgr->GetConstant(result_type, words); |
1126 | } |
1127 | return nullptr; |
1128 | }; |
1129 | } |
1130 | } // namespace |
1131 | |
1132 | void ConstantFoldingRules::AddFoldingRules() { |
1133 | // Add all folding rules to the list for the opcodes to which they apply. |
1134 | // Note that the order in which rules are added to the list matters. If a rule |
1135 | // applies to the instruction, the rest of the rules will not be attempted. |
1136 | // Take that into consideration. |
1137 | |
1138 | rules_[SpvOpCompositeConstruct].push_back(FoldCompositeWithConstants()); |
1139 | |
1140 | rules_[SpvOpCompositeExtract].push_back(FoldExtractWithConstants()); |
1141 | |
1142 | rules_[SpvOpConvertFToS].push_back(FoldFToI()); |
1143 | rules_[SpvOpConvertFToU].push_back(FoldFToI()); |
1144 | rules_[SpvOpConvertSToF].push_back(FoldIToF()); |
1145 | rules_[SpvOpConvertUToF].push_back(FoldIToF()); |
1146 | |
1147 | rules_[SpvOpDot].push_back(FoldOpDotWithConstants()); |
1148 | rules_[SpvOpFAdd].push_back(FoldFAdd()); |
1149 | rules_[SpvOpFDiv].push_back(FoldFDiv()); |
1150 | rules_[SpvOpFMul].push_back(FoldFMul()); |
1151 | rules_[SpvOpFSub].push_back(FoldFSub()); |
1152 | |
1153 | rules_[SpvOpFOrdEqual].push_back(FoldFOrdEqual()); |
1154 | |
1155 | rules_[SpvOpFUnordEqual].push_back(FoldFUnordEqual()); |
1156 | |
1157 | rules_[SpvOpFOrdNotEqual].push_back(FoldFOrdNotEqual()); |
1158 | |
1159 | rules_[SpvOpFUnordNotEqual].push_back(FoldFUnordNotEqual()); |
1160 | |
1161 | rules_[SpvOpFOrdLessThan].push_back(FoldFOrdLessThan()); |
1162 | rules_[SpvOpFOrdLessThan].push_back( |
1163 | FoldFClampFeedingCompare(SpvOpFOrdLessThan)); |
1164 | |
1165 | rules_[SpvOpFUnordLessThan].push_back(FoldFUnordLessThan()); |
1166 | rules_[SpvOpFUnordLessThan].push_back( |
1167 | FoldFClampFeedingCompare(SpvOpFUnordLessThan)); |
1168 | |
1169 | rules_[SpvOpFOrdGreaterThan].push_back(FoldFOrdGreaterThan()); |
1170 | rules_[SpvOpFOrdGreaterThan].push_back( |
1171 | FoldFClampFeedingCompare(SpvOpFOrdGreaterThan)); |
1172 | |
1173 | rules_[SpvOpFUnordGreaterThan].push_back(FoldFUnordGreaterThan()); |
1174 | rules_[SpvOpFUnordGreaterThan].push_back( |
1175 | FoldFClampFeedingCompare(SpvOpFUnordGreaterThan)); |
1176 | |
1177 | rules_[SpvOpFOrdLessThanEqual].push_back(FoldFOrdLessThanEqual()); |
1178 | rules_[SpvOpFOrdLessThanEqual].push_back( |
1179 | FoldFClampFeedingCompare(SpvOpFOrdLessThanEqual)); |
1180 | |
1181 | rules_[SpvOpFUnordLessThanEqual].push_back(FoldFUnordLessThanEqual()); |
1182 | rules_[SpvOpFUnordLessThanEqual].push_back( |
1183 | FoldFClampFeedingCompare(SpvOpFUnordLessThanEqual)); |
1184 | |
1185 | rules_[SpvOpFOrdGreaterThanEqual].push_back(FoldFOrdGreaterThanEqual()); |
1186 | rules_[SpvOpFOrdGreaterThanEqual].push_back( |
1187 | FoldFClampFeedingCompare(SpvOpFOrdGreaterThanEqual)); |
1188 | |
1189 | rules_[SpvOpFUnordGreaterThanEqual].push_back(FoldFUnordGreaterThanEqual()); |
1190 | rules_[SpvOpFUnordGreaterThanEqual].push_back( |
1191 | FoldFClampFeedingCompare(SpvOpFUnordGreaterThanEqual)); |
1192 | |
1193 | rules_[SpvOpVectorShuffle].push_back(FoldVectorShuffleWithConstants()); |
1194 | rules_[SpvOpVectorTimesScalar].push_back(FoldVectorTimesScalar()); |
1195 | |
1196 | rules_[SpvOpFNegate].push_back(FoldFNegate()); |
1197 | rules_[SpvOpQuantizeToF16].push_back(FoldQuantizeToF16()); |
1198 | |
1199 | // Add rules for GLSLstd450 |
1200 | FeatureManager* feature_manager = context_->get_feature_mgr(); |
1201 | uint32_t ext_inst_glslstd450_id = |
1202 | feature_manager->GetExtInstImportId_GLSLstd450(); |
1203 | if (ext_inst_glslstd450_id != 0) { |
1204 | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMix}].push_back(FoldFMix()); |
1205 | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SMin}].push_back( |
1206 | FoldFPBinaryOp(FoldMin)); |
1207 | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UMin}].push_back( |
1208 | FoldFPBinaryOp(FoldMin)); |
1209 | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMin}].push_back( |
1210 | FoldFPBinaryOp(FoldMin)); |
1211 | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SMax}].push_back( |
1212 | FoldFPBinaryOp(FoldMax)); |
1213 | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UMax}].push_back( |
1214 | FoldFPBinaryOp(FoldMax)); |
1215 | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FMax}].push_back( |
1216 | FoldFPBinaryOp(FoldMax)); |
1217 | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back( |
1218 | FoldClamp1); |
1219 | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back( |
1220 | FoldClamp2); |
1221 | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450UClamp}].push_back( |
1222 | FoldClamp3); |
1223 | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SClamp}].push_back( |
1224 | FoldClamp1); |
1225 | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SClamp}].push_back( |
1226 | FoldClamp2); |
1227 | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450SClamp}].push_back( |
1228 | FoldClamp3); |
1229 | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back( |
1230 | FoldClamp1); |
1231 | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back( |
1232 | FoldClamp2); |
1233 | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450FClamp}].push_back( |
1234 | FoldClamp3); |
1235 | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Sin}].push_back( |
1236 | FoldFPUnaryOp(FoldFTranscendentalUnary(std::sin))); |
1237 | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Cos}].push_back( |
1238 | FoldFPUnaryOp(FoldFTranscendentalUnary(std::cos))); |
1239 | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Tan}].push_back( |
1240 | FoldFPUnaryOp(FoldFTranscendentalUnary(std::tan))); |
1241 | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Asin}].push_back( |
1242 | FoldFPUnaryOp(FoldFTranscendentalUnary(std::asin))); |
1243 | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Acos}].push_back( |
1244 | FoldFPUnaryOp(FoldFTranscendentalUnary(std::acos))); |
1245 | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Atan}].push_back( |
1246 | FoldFPUnaryOp(FoldFTranscendentalUnary(std::atan))); |
1247 | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp}].push_back( |
1248 | FoldFPUnaryOp(FoldFTranscendentalUnary(std::exp))); |
1249 | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log}].push_back( |
1250 | FoldFPUnaryOp(FoldFTranscendentalUnary(std::log))); |
1251 | |
1252 | #ifdef __ANDROID__ |
1253 | // Android NDK r15c tageting ABI 15 doesn't have full support for C++11 |
1254 | // (no std::exp2/log2). ::exp2 is available from C99 but ::log2 isn't |
1255 | // available up until ABI 18 so we use a shim |
1256 | auto log2_shim = [](double v) -> double { return log(v) / log(2.0); }; |
1257 | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp2}].push_back( |
1258 | FoldFPUnaryOp(FoldFTranscendentalUnary(::exp2))); |
1259 | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log2}].push_back( |
1260 | FoldFPUnaryOp(FoldFTranscendentalUnary(log2_shim))); |
1261 | #else |
1262 | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Exp2}].push_back( |
1263 | FoldFPUnaryOp(FoldFTranscendentalUnary(std::exp2))); |
1264 | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Log2}].push_back( |
1265 | FoldFPUnaryOp(FoldFTranscendentalUnary(std::log2))); |
1266 | #endif |
1267 | |
1268 | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Sqrt}].push_back( |
1269 | FoldFPUnaryOp(FoldFTranscendentalUnary(std::sqrt))); |
1270 | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Atan2}].push_back( |
1271 | FoldFPBinaryOp(FoldFTranscendentalBinary(std::atan2))); |
1272 | ext_rules_[{ext_inst_glslstd450_id, GLSLstd450Pow}].push_back( |
1273 | FoldFPBinaryOp(FoldFTranscendentalBinary(std::pow))); |
1274 | } |
1275 | } |
1276 | } // namespace opt |
1277 | } // namespace spvtools |
1278 | |