1// Copyright (c) 2017 Google Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "source/opt/constants.h"
16
17#include <unordered_map>
18#include <vector>
19
20#include "source/opt/ir_context.h"
21
22namespace spvtools {
23namespace opt {
24namespace analysis {
25
26float Constant::GetFloat() const {
27 assert(type()->AsFloat() != nullptr && type()->AsFloat()->width() == 32);
28
29 if (const FloatConstant* fc = AsFloatConstant()) {
30 return fc->GetFloatValue();
31 } else {
32 assert(AsNullConstant() && "Must be a floating point constant.");
33 return 0.0f;
34 }
35}
36
37double Constant::GetDouble() const {
38 assert(type()->AsFloat() != nullptr && type()->AsFloat()->width() == 64);
39
40 if (const FloatConstant* fc = AsFloatConstant()) {
41 return fc->GetDoubleValue();
42 } else {
43 assert(AsNullConstant() && "Must be a floating point constant.");
44 return 0.0;
45 }
46}
47
48double Constant::GetValueAsDouble() const {
49 assert(type()->AsFloat() != nullptr);
50 if (type()->AsFloat()->width() == 32) {
51 return GetFloat();
52 } else {
53 assert(type()->AsFloat()->width() == 64);
54 return GetDouble();
55 }
56}
57
58uint32_t Constant::GetU32() const {
59 assert(type()->AsInteger() != nullptr);
60 assert(type()->AsInteger()->width() == 32);
61
62 if (const IntConstant* ic = AsIntConstant()) {
63 return ic->GetU32BitValue();
64 } else {
65 assert(AsNullConstant() && "Must be an integer constant.");
66 return 0u;
67 }
68}
69
70uint64_t Constant::GetU64() const {
71 assert(type()->AsInteger() != nullptr);
72 assert(type()->AsInteger()->width() == 64);
73
74 if (const IntConstant* ic = AsIntConstant()) {
75 return ic->GetU64BitValue();
76 } else {
77 assert(AsNullConstant() && "Must be an integer constant.");
78 return 0u;
79 }
80}
81
82int32_t Constant::GetS32() const {
83 assert(type()->AsInteger() != nullptr);
84 assert(type()->AsInteger()->width() == 32);
85
86 if (const IntConstant* ic = AsIntConstant()) {
87 return ic->GetS32BitValue();
88 } else {
89 assert(AsNullConstant() && "Must be an integer constant.");
90 return 0;
91 }
92}
93
94int64_t Constant::GetS64() const {
95 assert(type()->AsInteger() != nullptr);
96 assert(type()->AsInteger()->width() == 64);
97
98 if (const IntConstant* ic = AsIntConstant()) {
99 return ic->GetS64BitValue();
100 } else {
101 assert(AsNullConstant() && "Must be an integer constant.");
102 return 0;
103 }
104}
105
106uint64_t Constant::GetZeroExtendedValue() const {
107 const auto* int_type = type()->AsInteger();
108 assert(int_type != nullptr);
109 const auto width = int_type->width();
110 assert(width <= 64);
111
112 uint64_t value = 0;
113 if (const IntConstant* ic = AsIntConstant()) {
114 if (width <= 32) {
115 value = ic->GetU32BitValue();
116 } else {
117 value = ic->GetU64BitValue();
118 }
119 } else {
120 assert(AsNullConstant() && "Must be an integer constant.");
121 }
122 return value;
123}
124
125int64_t Constant::GetSignExtendedValue() const {
126 const auto* int_type = type()->AsInteger();
127 assert(int_type != nullptr);
128 const auto width = int_type->width();
129 assert(width <= 64);
130
131 int64_t value = 0;
132 if (const IntConstant* ic = AsIntConstant()) {
133 if (width <= 32) {
134 // Let the C++ compiler do the sign extension.
135 value = int64_t(ic->GetS32BitValue());
136 } else {
137 value = ic->GetS64BitValue();
138 }
139 } else {
140 assert(AsNullConstant() && "Must be an integer constant.");
141 }
142 return value;
143}
144
145ConstantManager::ConstantManager(IRContext* ctx) : ctx_(ctx) {
146 // Populate the constant table with values from constant declarations in the
147 // module. The values of each OpConstant declaration is the identity
148 // assignment (i.e., each constant is its own value).
149 for (const auto& inst : ctx_->module()->GetConstants()) {
150 MapInst(inst);
151 }
152}
153
154Type* ConstantManager::GetType(const Instruction* inst) const {
155 return context()->get_type_mgr()->GetType(inst->type_id());
156}
157
158std::vector<const Constant*> ConstantManager::GetOperandConstants(
159 const Instruction* inst) const {
160 std::vector<const Constant*> constants;
161 for (uint32_t i = 0; i < inst->NumInOperands(); i++) {
162 const Operand* operand = &inst->GetInOperand(i);
163 if (operand->type != SPV_OPERAND_TYPE_ID) {
164 constants.push_back(nullptr);
165 } else {
166 uint32_t id = operand->words[0];
167 const analysis::Constant* constant = FindDeclaredConstant(id);
168 constants.push_back(constant);
169 }
170 }
171 return constants;
172}
173
174uint32_t ConstantManager::FindDeclaredConstant(const Constant* c,
175 uint32_t type_id) const {
176 c = FindConstant(c);
177 if (c == nullptr) {
178 return 0;
179 }
180
181 for (auto range = const_val_to_id_.equal_range(c);
182 range.first != range.second; ++range.first) {
183 Instruction* const_def =
184 context()->get_def_use_mgr()->GetDef(range.first->second);
185 if (type_id == 0 || const_def->type_id() == type_id) {
186 return range.first->second;
187 }
188 }
189 return 0;
190}
191
192std::vector<const Constant*> ConstantManager::GetConstantsFromIds(
193 const std::vector<uint32_t>& ids) const {
194 std::vector<const Constant*> constants;
195 for (uint32_t id : ids) {
196 if (const Constant* c = FindDeclaredConstant(id)) {
197 constants.push_back(c);
198 } else {
199 return {};
200 }
201 }
202 return constants;
203}
204
205Instruction* ConstantManager::BuildInstructionAndAddToModule(
206 const Constant* new_const, Module::inst_iterator* pos, uint32_t type_id) {
207 // TODO(1841): Handle id overflow.
208 uint32_t new_id = context()->TakeNextId();
209 if (new_id == 0) {
210 return nullptr;
211 }
212
213 auto new_inst = CreateInstruction(new_id, new_const, type_id);
214 if (!new_inst) {
215 return nullptr;
216 }
217 auto* new_inst_ptr = new_inst.get();
218 *pos = pos->InsertBefore(std::move(new_inst));
219 ++(*pos);
220 context()->get_def_use_mgr()->AnalyzeInstDefUse(new_inst_ptr);
221 MapConstantToInst(new_const, new_inst_ptr);
222 return new_inst_ptr;
223}
224
225Instruction* ConstantManager::GetDefiningInstruction(
226 const Constant* c, uint32_t type_id, Module::inst_iterator* pos) {
227 uint32_t decl_id = FindDeclaredConstant(c, type_id);
228 if (decl_id == 0) {
229 auto iter = context()->types_values_end();
230 if (pos == nullptr) pos = &iter;
231 return BuildInstructionAndAddToModule(c, pos, type_id);
232 } else {
233 auto def = context()->get_def_use_mgr()->GetDef(decl_id);
234 assert(def != nullptr);
235 assert((type_id == 0 || def->type_id() == type_id) &&
236 "This constant already has an instruction with a different type.");
237 return def;
238 }
239}
240
241std::unique_ptr<Constant> ConstantManager::CreateConstant(
242 const Type* type, const std::vector<uint32_t>& literal_words_or_ids) const {
243 if (literal_words_or_ids.size() == 0) {
244 // Constant declared with OpConstantNull
245 return MakeUnique<NullConstant>(type);
246 } else if (auto* bt = type->AsBool()) {
247 assert(literal_words_or_ids.size() == 1 &&
248 "Bool constant should be declared with one operand");
249 return MakeUnique<BoolConstant>(bt, literal_words_or_ids.front());
250 } else if (auto* it = type->AsInteger()) {
251 return MakeUnique<IntConstant>(it, literal_words_or_ids);
252 } else if (auto* ft = type->AsFloat()) {
253 return MakeUnique<FloatConstant>(ft, literal_words_or_ids);
254 } else if (auto* vt = type->AsVector()) {
255 auto components = GetConstantsFromIds(literal_words_or_ids);
256 if (components.empty()) return nullptr;
257 // All components of VectorConstant must be of type Bool, Integer or Float.
258 if (!std::all_of(components.begin(), components.end(),
259 [](const Constant* c) {
260 if (c->type()->AsBool() || c->type()->AsInteger() ||
261 c->type()->AsFloat()) {
262 return true;
263 } else {
264 return false;
265 }
266 }))
267 return nullptr;
268 // All components of VectorConstant must be in the same type.
269 const auto* component_type = components.front()->type();
270 if (!std::all_of(components.begin(), components.end(),
271 [&component_type](const Constant* c) {
272 if (c->type() == component_type) return true;
273 return false;
274 }))
275 return nullptr;
276 return MakeUnique<VectorConstant>(vt, components);
277 } else if (auto* mt = type->AsMatrix()) {
278 auto components = GetConstantsFromIds(literal_words_or_ids);
279 if (components.empty()) return nullptr;
280 return MakeUnique<MatrixConstant>(mt, components);
281 } else if (auto* st = type->AsStruct()) {
282 auto components = GetConstantsFromIds(literal_words_or_ids);
283 if (components.empty()) return nullptr;
284 return MakeUnique<StructConstant>(st, components);
285 } else if (auto* at = type->AsArray()) {
286 auto components = GetConstantsFromIds(literal_words_or_ids);
287 if (components.empty()) return nullptr;
288 return MakeUnique<ArrayConstant>(at, components);
289 } else {
290 return nullptr;
291 }
292}
293
294const Constant* ConstantManager::GetConstantFromInst(const Instruction* inst) {
295 std::vector<uint32_t> literal_words_or_ids;
296
297 // Collect the constant defining literals or component ids.
298 for (uint32_t i = 0; i < inst->NumInOperands(); i++) {
299 literal_words_or_ids.insert(literal_words_or_ids.end(),
300 inst->GetInOperand(i).words.begin(),
301 inst->GetInOperand(i).words.end());
302 }
303
304 switch (inst->opcode()) {
305 // OpConstant{True|False} have the value embedded in the opcode. So they
306 // are not handled by the for-loop above. Here we add the value explicitly.
307 case SpvOp::SpvOpConstantTrue:
308 literal_words_or_ids.push_back(true);
309 break;
310 case SpvOp::SpvOpConstantFalse:
311 literal_words_or_ids.push_back(false);
312 break;
313 case SpvOp::SpvOpConstantNull:
314 case SpvOp::SpvOpConstant:
315 case SpvOp::SpvOpConstantComposite:
316 case SpvOp::SpvOpSpecConstantComposite:
317 break;
318 default:
319 return nullptr;
320 }
321
322 return GetConstant(GetType(inst), literal_words_or_ids);
323}
324
325std::unique_ptr<Instruction> ConstantManager::CreateInstruction(
326 uint32_t id, const Constant* c, uint32_t type_id) const {
327 uint32_t type =
328 (type_id == 0) ? context()->get_type_mgr()->GetId(c->type()) : type_id;
329 if (c->AsNullConstant()) {
330 return MakeUnique<Instruction>(context(), SpvOp::SpvOpConstantNull, type,
331 id, std::initializer_list<Operand>{});
332 } else if (const BoolConstant* bc = c->AsBoolConstant()) {
333 return MakeUnique<Instruction>(
334 context(),
335 bc->value() ? SpvOp::SpvOpConstantTrue : SpvOp::SpvOpConstantFalse,
336 type, id, std::initializer_list<Operand>{});
337 } else if (const IntConstant* ic = c->AsIntConstant()) {
338 return MakeUnique<Instruction>(
339 context(), SpvOp::SpvOpConstant, type, id,
340 std::initializer_list<Operand>{
341 Operand(spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER,
342 ic->words())});
343 } else if (const FloatConstant* fc = c->AsFloatConstant()) {
344 return MakeUnique<Instruction>(
345 context(), SpvOp::SpvOpConstant, type, id,
346 std::initializer_list<Operand>{
347 Operand(spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER,
348 fc->words())});
349 } else if (const CompositeConstant* cc = c->AsCompositeConstant()) {
350 return CreateCompositeInstruction(id, cc, type_id);
351 } else {
352 return nullptr;
353 }
354}
355
356std::unique_ptr<Instruction> ConstantManager::CreateCompositeInstruction(
357 uint32_t result_id, const CompositeConstant* cc, uint32_t type_id) const {
358 std::vector<Operand> operands;
359 Instruction* type_inst = context()->get_def_use_mgr()->GetDef(type_id);
360 uint32_t component_index = 0;
361 for (const Constant* component_const : cc->GetComponents()) {
362 uint32_t component_type_id = 0;
363 if (type_inst && type_inst->opcode() == SpvOpTypeStruct) {
364 component_type_id = type_inst->GetSingleWordInOperand(component_index);
365 } else if (type_inst && type_inst->opcode() == SpvOpTypeArray) {
366 component_type_id = type_inst->GetSingleWordInOperand(0);
367 }
368 uint32_t id = FindDeclaredConstant(component_const, component_type_id);
369
370 if (id == 0) {
371 // Cannot get the id of the component constant, while all components
372 // should have been added to the module prior to the composite constant.
373 // Cannot create OpConstantComposite instruction in this case.
374 return nullptr;
375 }
376 operands.emplace_back(spv_operand_type_t::SPV_OPERAND_TYPE_ID,
377 std::initializer_list<uint32_t>{id});
378 component_index++;
379 }
380 uint32_t type =
381 (type_id == 0) ? context()->get_type_mgr()->GetId(cc->type()) : type_id;
382 return MakeUnique<Instruction>(context(), SpvOp::SpvOpConstantComposite, type,
383 result_id, std::move(operands));
384}
385
386const Constant* ConstantManager::GetConstant(
387 const Type* type, const std::vector<uint32_t>& literal_words_or_ids) {
388 auto cst = CreateConstant(type, literal_words_or_ids);
389 return cst ? RegisterConstant(std::move(cst)) : nullptr;
390}
391
392uint32_t ConstantManager::GetFloatConst(float val) {
393 Type* float_type = context()->get_type_mgr()->GetFloatType();
394 utils::FloatProxy<float> v(val);
395 const Constant* c = GetConstant(float_type, v.GetWords());
396 return GetDefiningInstruction(c)->result_id();
397}
398
399std::vector<const analysis::Constant*> Constant::GetVectorComponents(
400 analysis::ConstantManager* const_mgr) const {
401 std::vector<const analysis::Constant*> components;
402 const analysis::VectorConstant* a = this->AsVectorConstant();
403 const analysis::Vector* vector_type = this->type()->AsVector();
404 assert(vector_type != nullptr);
405 if (a != nullptr) {
406 for (uint32_t i = 0; i < vector_type->element_count(); ++i) {
407 components.push_back(a->GetComponents()[i]);
408 }
409 } else {
410 const analysis::Type* element_type = vector_type->element_type();
411 const analysis::Constant* element_null_const =
412 const_mgr->GetConstant(element_type, {});
413 for (uint32_t i = 0; i < vector_type->element_count(); ++i) {
414 components.push_back(element_null_const);
415 }
416 }
417 return components;
418}
419
420} // namespace analysis
421} // namespace opt
422} // namespace spvtools
423