1 | // Copyright (c) 2017 Google Inc. |
2 | // |
3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | // you may not use this file except in compliance with the License. |
5 | // You may obtain a copy of the License at |
6 | // |
7 | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | // |
9 | // Unless required by applicable law or agreed to in writing, software |
10 | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | // See the License for the specific language governing permissions and |
13 | // limitations under the License. |
14 | |
#include "source/opt/constants.h"

#include <algorithm>
#include <unordered_map>
#include <vector>

#include "source/opt/ir_context.h"
21 | |
22 | namespace spvtools { |
23 | namespace opt { |
24 | namespace analysis { |
25 | |
26 | float Constant::GetFloat() const { |
27 | assert(type()->AsFloat() != nullptr && type()->AsFloat()->width() == 32); |
28 | |
29 | if (const FloatConstant* fc = AsFloatConstant()) { |
30 | return fc->GetFloatValue(); |
31 | } else { |
32 | assert(AsNullConstant() && "Must be a floating point constant." ); |
33 | return 0.0f; |
34 | } |
35 | } |
36 | |
37 | double Constant::GetDouble() const { |
38 | assert(type()->AsFloat() != nullptr && type()->AsFloat()->width() == 64); |
39 | |
40 | if (const FloatConstant* fc = AsFloatConstant()) { |
41 | return fc->GetDoubleValue(); |
42 | } else { |
43 | assert(AsNullConstant() && "Must be a floating point constant." ); |
44 | return 0.0; |
45 | } |
46 | } |
47 | |
48 | double Constant::GetValueAsDouble() const { |
49 | assert(type()->AsFloat() != nullptr); |
50 | if (type()->AsFloat()->width() == 32) { |
51 | return GetFloat(); |
52 | } else { |
53 | assert(type()->AsFloat()->width() == 64); |
54 | return GetDouble(); |
55 | } |
56 | } |
57 | |
58 | uint32_t Constant::GetU32() const { |
59 | assert(type()->AsInteger() != nullptr); |
60 | assert(type()->AsInteger()->width() == 32); |
61 | |
62 | if (const IntConstant* ic = AsIntConstant()) { |
63 | return ic->GetU32BitValue(); |
64 | } else { |
65 | assert(AsNullConstant() && "Must be an integer constant." ); |
66 | return 0u; |
67 | } |
68 | } |
69 | |
70 | uint64_t Constant::GetU64() const { |
71 | assert(type()->AsInteger() != nullptr); |
72 | assert(type()->AsInteger()->width() == 64); |
73 | |
74 | if (const IntConstant* ic = AsIntConstant()) { |
75 | return ic->GetU64BitValue(); |
76 | } else { |
77 | assert(AsNullConstant() && "Must be an integer constant." ); |
78 | return 0u; |
79 | } |
80 | } |
81 | |
82 | int32_t Constant::GetS32() const { |
83 | assert(type()->AsInteger() != nullptr); |
84 | assert(type()->AsInteger()->width() == 32); |
85 | |
86 | if (const IntConstant* ic = AsIntConstant()) { |
87 | return ic->GetS32BitValue(); |
88 | } else { |
89 | assert(AsNullConstant() && "Must be an integer constant." ); |
90 | return 0; |
91 | } |
92 | } |
93 | |
94 | int64_t Constant::GetS64() const { |
95 | assert(type()->AsInteger() != nullptr); |
96 | assert(type()->AsInteger()->width() == 64); |
97 | |
98 | if (const IntConstant* ic = AsIntConstant()) { |
99 | return ic->GetS64BitValue(); |
100 | } else { |
101 | assert(AsNullConstant() && "Must be an integer constant." ); |
102 | return 0; |
103 | } |
104 | } |
105 | |
106 | uint64_t Constant::GetZeroExtendedValue() const { |
107 | const auto* int_type = type()->AsInteger(); |
108 | assert(int_type != nullptr); |
109 | const auto width = int_type->width(); |
110 | assert(width <= 64); |
111 | |
112 | uint64_t value = 0; |
113 | if (const IntConstant* ic = AsIntConstant()) { |
114 | if (width <= 32) { |
115 | value = ic->GetU32BitValue(); |
116 | } else { |
117 | value = ic->GetU64BitValue(); |
118 | } |
119 | } else { |
120 | assert(AsNullConstant() && "Must be an integer constant." ); |
121 | } |
122 | return value; |
123 | } |
124 | |
125 | int64_t Constant::GetSignExtendedValue() const { |
126 | const auto* int_type = type()->AsInteger(); |
127 | assert(int_type != nullptr); |
128 | const auto width = int_type->width(); |
129 | assert(width <= 64); |
130 | |
131 | int64_t value = 0; |
132 | if (const IntConstant* ic = AsIntConstant()) { |
133 | if (width <= 32) { |
134 | // Let the C++ compiler do the sign extension. |
135 | value = int64_t(ic->GetS32BitValue()); |
136 | } else { |
137 | value = ic->GetS64BitValue(); |
138 | } |
139 | } else { |
140 | assert(AsNullConstant() && "Must be an integer constant." ); |
141 | } |
142 | return value; |
143 | } |
144 | |
145 | ConstantManager::ConstantManager(IRContext* ctx) : ctx_(ctx) { |
146 | // Populate the constant table with values from constant declarations in the |
147 | // module. The values of each OpConstant declaration is the identity |
148 | // assignment (i.e., each constant is its own value). |
149 | for (const auto& inst : ctx_->module()->GetConstants()) { |
150 | MapInst(inst); |
151 | } |
152 | } |
153 | |
154 | Type* ConstantManager::GetType(const Instruction* inst) const { |
155 | return context()->get_type_mgr()->GetType(inst->type_id()); |
156 | } |
157 | |
158 | std::vector<const Constant*> ConstantManager::GetOperandConstants( |
159 | const Instruction* inst) const { |
160 | std::vector<const Constant*> constants; |
161 | for (uint32_t i = 0; i < inst->NumInOperands(); i++) { |
162 | const Operand* operand = &inst->GetInOperand(i); |
163 | if (operand->type != SPV_OPERAND_TYPE_ID) { |
164 | constants.push_back(nullptr); |
165 | } else { |
166 | uint32_t id = operand->words[0]; |
167 | const analysis::Constant* constant = FindDeclaredConstant(id); |
168 | constants.push_back(constant); |
169 | } |
170 | } |
171 | return constants; |
172 | } |
173 | |
174 | uint32_t ConstantManager::FindDeclaredConstant(const Constant* c, |
175 | uint32_t type_id) const { |
176 | c = FindConstant(c); |
177 | if (c == nullptr) { |
178 | return 0; |
179 | } |
180 | |
181 | for (auto range = const_val_to_id_.equal_range(c); |
182 | range.first != range.second; ++range.first) { |
183 | Instruction* const_def = |
184 | context()->get_def_use_mgr()->GetDef(range.first->second); |
185 | if (type_id == 0 || const_def->type_id() == type_id) { |
186 | return range.first->second; |
187 | } |
188 | } |
189 | return 0; |
190 | } |
191 | |
192 | std::vector<const Constant*> ConstantManager::GetConstantsFromIds( |
193 | const std::vector<uint32_t>& ids) const { |
194 | std::vector<const Constant*> constants; |
195 | for (uint32_t id : ids) { |
196 | if (const Constant* c = FindDeclaredConstant(id)) { |
197 | constants.push_back(c); |
198 | } else { |
199 | return {}; |
200 | } |
201 | } |
202 | return constants; |
203 | } |
204 | |
205 | Instruction* ConstantManager::BuildInstructionAndAddToModule( |
206 | const Constant* new_const, Module::inst_iterator* pos, uint32_t type_id) { |
207 | // TODO(1841): Handle id overflow. |
208 | uint32_t new_id = context()->TakeNextId(); |
209 | if (new_id == 0) { |
210 | return nullptr; |
211 | } |
212 | |
213 | auto new_inst = CreateInstruction(new_id, new_const, type_id); |
214 | if (!new_inst) { |
215 | return nullptr; |
216 | } |
217 | auto* new_inst_ptr = new_inst.get(); |
218 | *pos = pos->InsertBefore(std::move(new_inst)); |
219 | ++(*pos); |
220 | context()->get_def_use_mgr()->AnalyzeInstDefUse(new_inst_ptr); |
221 | MapConstantToInst(new_const, new_inst_ptr); |
222 | return new_inst_ptr; |
223 | } |
224 | |
225 | Instruction* ConstantManager::GetDefiningInstruction( |
226 | const Constant* c, uint32_t type_id, Module::inst_iterator* pos) { |
227 | uint32_t decl_id = FindDeclaredConstant(c, type_id); |
228 | if (decl_id == 0) { |
229 | auto iter = context()->types_values_end(); |
230 | if (pos == nullptr) pos = &iter; |
231 | return BuildInstructionAndAddToModule(c, pos, type_id); |
232 | } else { |
233 | auto def = context()->get_def_use_mgr()->GetDef(decl_id); |
234 | assert(def != nullptr); |
235 | assert((type_id == 0 || def->type_id() == type_id) && |
236 | "This constant already has an instruction with a different type." ); |
237 | return def; |
238 | } |
239 | } |
240 | |
241 | std::unique_ptr<Constant> ConstantManager::CreateConstant( |
242 | const Type* type, const std::vector<uint32_t>& literal_words_or_ids) const { |
243 | if (literal_words_or_ids.size() == 0) { |
244 | // Constant declared with OpConstantNull |
245 | return MakeUnique<NullConstant>(type); |
246 | } else if (auto* bt = type->AsBool()) { |
247 | assert(literal_words_or_ids.size() == 1 && |
248 | "Bool constant should be declared with one operand" ); |
249 | return MakeUnique<BoolConstant>(bt, literal_words_or_ids.front()); |
250 | } else if (auto* it = type->AsInteger()) { |
251 | return MakeUnique<IntConstant>(it, literal_words_or_ids); |
252 | } else if (auto* ft = type->AsFloat()) { |
253 | return MakeUnique<FloatConstant>(ft, literal_words_or_ids); |
254 | } else if (auto* vt = type->AsVector()) { |
255 | auto components = GetConstantsFromIds(literal_words_or_ids); |
256 | if (components.empty()) return nullptr; |
257 | // All components of VectorConstant must be of type Bool, Integer or Float. |
258 | if (!std::all_of(components.begin(), components.end(), |
259 | [](const Constant* c) { |
260 | if (c->type()->AsBool() || c->type()->AsInteger() || |
261 | c->type()->AsFloat()) { |
262 | return true; |
263 | } else { |
264 | return false; |
265 | } |
266 | })) |
267 | return nullptr; |
268 | // All components of VectorConstant must be in the same type. |
269 | const auto* component_type = components.front()->type(); |
270 | if (!std::all_of(components.begin(), components.end(), |
271 | [&component_type](const Constant* c) { |
272 | if (c->type() == component_type) return true; |
273 | return false; |
274 | })) |
275 | return nullptr; |
276 | return MakeUnique<VectorConstant>(vt, components); |
277 | } else if (auto* mt = type->AsMatrix()) { |
278 | auto components = GetConstantsFromIds(literal_words_or_ids); |
279 | if (components.empty()) return nullptr; |
280 | return MakeUnique<MatrixConstant>(mt, components); |
281 | } else if (auto* st = type->AsStruct()) { |
282 | auto components = GetConstantsFromIds(literal_words_or_ids); |
283 | if (components.empty()) return nullptr; |
284 | return MakeUnique<StructConstant>(st, components); |
285 | } else if (auto* at = type->AsArray()) { |
286 | auto components = GetConstantsFromIds(literal_words_or_ids); |
287 | if (components.empty()) return nullptr; |
288 | return MakeUnique<ArrayConstant>(at, components); |
289 | } else { |
290 | return nullptr; |
291 | } |
292 | } |
293 | |
294 | const Constant* ConstantManager::GetConstantFromInst(const Instruction* inst) { |
295 | std::vector<uint32_t> literal_words_or_ids; |
296 | |
297 | // Collect the constant defining literals or component ids. |
298 | for (uint32_t i = 0; i < inst->NumInOperands(); i++) { |
299 | literal_words_or_ids.insert(literal_words_or_ids.end(), |
300 | inst->GetInOperand(i).words.begin(), |
301 | inst->GetInOperand(i).words.end()); |
302 | } |
303 | |
304 | switch (inst->opcode()) { |
305 | // OpConstant{True|False} have the value embedded in the opcode. So they |
306 | // are not handled by the for-loop above. Here we add the value explicitly. |
307 | case SpvOp::SpvOpConstantTrue: |
308 | literal_words_or_ids.push_back(true); |
309 | break; |
310 | case SpvOp::SpvOpConstantFalse: |
311 | literal_words_or_ids.push_back(false); |
312 | break; |
313 | case SpvOp::SpvOpConstantNull: |
314 | case SpvOp::SpvOpConstant: |
315 | case SpvOp::SpvOpConstantComposite: |
316 | case SpvOp::SpvOpSpecConstantComposite: |
317 | break; |
318 | default: |
319 | return nullptr; |
320 | } |
321 | |
322 | return GetConstant(GetType(inst), literal_words_or_ids); |
323 | } |
324 | |
325 | std::unique_ptr<Instruction> ConstantManager::CreateInstruction( |
326 | uint32_t id, const Constant* c, uint32_t type_id) const { |
327 | uint32_t type = |
328 | (type_id == 0) ? context()->get_type_mgr()->GetId(c->type()) : type_id; |
329 | if (c->AsNullConstant()) { |
330 | return MakeUnique<Instruction>(context(), SpvOp::SpvOpConstantNull, type, |
331 | id, std::initializer_list<Operand>{}); |
332 | } else if (const BoolConstant* bc = c->AsBoolConstant()) { |
333 | return MakeUnique<Instruction>( |
334 | context(), |
335 | bc->value() ? SpvOp::SpvOpConstantTrue : SpvOp::SpvOpConstantFalse, |
336 | type, id, std::initializer_list<Operand>{}); |
337 | } else if (const IntConstant* ic = c->AsIntConstant()) { |
338 | return MakeUnique<Instruction>( |
339 | context(), SpvOp::SpvOpConstant, type, id, |
340 | std::initializer_list<Operand>{ |
341 | Operand(spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER, |
342 | ic->words())}); |
343 | } else if (const FloatConstant* fc = c->AsFloatConstant()) { |
344 | return MakeUnique<Instruction>( |
345 | context(), SpvOp::SpvOpConstant, type, id, |
346 | std::initializer_list<Operand>{ |
347 | Operand(spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER, |
348 | fc->words())}); |
349 | } else if (const CompositeConstant* cc = c->AsCompositeConstant()) { |
350 | return CreateCompositeInstruction(id, cc, type_id); |
351 | } else { |
352 | return nullptr; |
353 | } |
354 | } |
355 | |
356 | std::unique_ptr<Instruction> ConstantManager::CreateCompositeInstruction( |
357 | uint32_t result_id, const CompositeConstant* cc, uint32_t type_id) const { |
358 | std::vector<Operand> operands; |
359 | Instruction* type_inst = context()->get_def_use_mgr()->GetDef(type_id); |
360 | uint32_t component_index = 0; |
361 | for (const Constant* component_const : cc->GetComponents()) { |
362 | uint32_t component_type_id = 0; |
363 | if (type_inst && type_inst->opcode() == SpvOpTypeStruct) { |
364 | component_type_id = type_inst->GetSingleWordInOperand(component_index); |
365 | } else if (type_inst && type_inst->opcode() == SpvOpTypeArray) { |
366 | component_type_id = type_inst->GetSingleWordInOperand(0); |
367 | } |
368 | uint32_t id = FindDeclaredConstant(component_const, component_type_id); |
369 | |
370 | if (id == 0) { |
371 | // Cannot get the id of the component constant, while all components |
372 | // should have been added to the module prior to the composite constant. |
373 | // Cannot create OpConstantComposite instruction in this case. |
374 | return nullptr; |
375 | } |
376 | operands.emplace_back(spv_operand_type_t::SPV_OPERAND_TYPE_ID, |
377 | std::initializer_list<uint32_t>{id}); |
378 | component_index++; |
379 | } |
380 | uint32_t type = |
381 | (type_id == 0) ? context()->get_type_mgr()->GetId(cc->type()) : type_id; |
382 | return MakeUnique<Instruction>(context(), SpvOp::SpvOpConstantComposite, type, |
383 | result_id, std::move(operands)); |
384 | } |
385 | |
386 | const Constant* ConstantManager::GetConstant( |
387 | const Type* type, const std::vector<uint32_t>& literal_words_or_ids) { |
388 | auto cst = CreateConstant(type, literal_words_or_ids); |
389 | return cst ? RegisterConstant(std::move(cst)) : nullptr; |
390 | } |
391 | |
392 | uint32_t ConstantManager::GetFloatConst(float val) { |
393 | Type* float_type = context()->get_type_mgr()->GetFloatType(); |
394 | utils::FloatProxy<float> v(val); |
395 | const Constant* c = GetConstant(float_type, v.GetWords()); |
396 | return GetDefiningInstruction(c)->result_id(); |
397 | } |
398 | |
399 | std::vector<const analysis::Constant*> Constant::GetVectorComponents( |
400 | analysis::ConstantManager* const_mgr) const { |
401 | std::vector<const analysis::Constant*> components; |
402 | const analysis::VectorConstant* a = this->AsVectorConstant(); |
403 | const analysis::Vector* vector_type = this->type()->AsVector(); |
404 | assert(vector_type != nullptr); |
405 | if (a != nullptr) { |
406 | for (uint32_t i = 0; i < vector_type->element_count(); ++i) { |
407 | components.push_back(a->GetComponents()[i]); |
408 | } |
409 | } else { |
410 | const analysis::Type* element_type = vector_type->element_type(); |
411 | const analysis::Constant* element_null_const = |
412 | const_mgr->GetConstant(element_type, {}); |
413 | for (uint32_t i = 0; i < vector_type->element_count(); ++i) { |
414 | components.push_back(element_null_const); |
415 | } |
416 | } |
417 | return components; |
418 | } |
419 | |
420 | } // namespace analysis |
421 | } // namespace opt |
422 | } // namespace spvtools |
423 | |