1// Copyright (c) 2019 The Khronos Group Inc.
2// Copyright (c) 2019 Valve Corporation
3// Copyright (c) 2019 LunarG Inc.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16
17#include "convert_to_half_pass.h"
18
19#include "source/opt/ir_builder.h"
20
namespace {

// In-operand index of the depth-reference (Dref) id in the image instructions
// that carry one.
constexpr int kImageSampleDrefIdInIdx = 2;

}  // anonymous namespace
27
28namespace spvtools {
29namespace opt {
30
31bool ConvertToHalfPass::IsArithmetic(Instruction* inst) {
32 return target_ops_core_.count(inst->opcode()) != 0 ||
33 (inst->opcode() == SpvOpExtInst &&
34 inst->GetSingleWordInOperand(0) ==
35 context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450() &&
36 target_ops_450_.count(inst->GetSingleWordInOperand(1)) != 0);
37}
38
39bool ConvertToHalfPass::IsFloat(Instruction* inst, uint32_t width) {
40 uint32_t ty_id = inst->type_id();
41 if (ty_id == 0) return false;
42 return Pass::IsFloat(ty_id, width);
43}
44
45bool ConvertToHalfPass::IsDecoratedRelaxed(Instruction* inst) {
46 uint32_t r_id = inst->result_id();
47 for (auto r_inst : get_decoration_mgr()->GetDecorationsFor(r_id, false))
48 if (r_inst->opcode() == SpvOpDecorate &&
49 r_inst->GetSingleWordInOperand(1) == SpvDecorationRelaxedPrecision)
50 return true;
51 return false;
52}
53
54bool ConvertToHalfPass::IsRelaxed(uint32_t id) {
55 return relaxed_ids_set_.count(id) > 0;
56}
57
58void ConvertToHalfPass::AddRelaxed(uint32_t id) { relaxed_ids_set_.insert(id); }
59
60analysis::Type* ConvertToHalfPass::FloatScalarType(uint32_t width) {
61 analysis::Float float_ty(width);
62 return context()->get_type_mgr()->GetRegisteredType(&float_ty);
63}
64
65analysis::Type* ConvertToHalfPass::FloatVectorType(uint32_t v_len,
66 uint32_t width) {
67 analysis::Type* reg_float_ty = FloatScalarType(width);
68 analysis::Vector vec_ty(reg_float_ty, v_len);
69 return context()->get_type_mgr()->GetRegisteredType(&vec_ty);
70}
71
72analysis::Type* ConvertToHalfPass::FloatMatrixType(uint32_t v_cnt,
73 uint32_t vty_id,
74 uint32_t width) {
75 Instruction* vty_inst = get_def_use_mgr()->GetDef(vty_id);
76 uint32_t v_len = vty_inst->GetSingleWordInOperand(1);
77 analysis::Type* reg_vec_ty = FloatVectorType(v_len, width);
78 analysis::Matrix mat_ty(reg_vec_ty, v_cnt);
79 return context()->get_type_mgr()->GetRegisteredType(&mat_ty);
80}
81
82uint32_t ConvertToHalfPass::EquivFloatTypeId(uint32_t ty_id, uint32_t width) {
83 analysis::Type* reg_equiv_ty;
84 Instruction* ty_inst = get_def_use_mgr()->GetDef(ty_id);
85 if (ty_inst->opcode() == SpvOpTypeMatrix)
86 reg_equiv_ty = FloatMatrixType(ty_inst->GetSingleWordInOperand(1),
87 ty_inst->GetSingleWordInOperand(0), width);
88 else if (ty_inst->opcode() == SpvOpTypeVector)
89 reg_equiv_ty = FloatVectorType(ty_inst->GetSingleWordInOperand(1), width);
90 else // SpvOpTypeFloat
91 reg_equiv_ty = FloatScalarType(width);
92 return context()->get_type_mgr()->GetTypeInstruction(reg_equiv_ty);
93}
94
95void ConvertToHalfPass::GenConvert(uint32_t* val_idp, uint32_t width,
96 Instruction* inst) {
97 Instruction* val_inst = get_def_use_mgr()->GetDef(*val_idp);
98 uint32_t ty_id = val_inst->type_id();
99 uint32_t nty_id = EquivFloatTypeId(ty_id, width);
100 if (nty_id == ty_id) return;
101 Instruction* cvt_inst;
102 InstructionBuilder builder(
103 context(), inst,
104 IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
105 if (val_inst->opcode() == SpvOpUndef)
106 cvt_inst = builder.AddNullaryOp(nty_id, SpvOpUndef);
107 else
108 cvt_inst = builder.AddUnaryOp(nty_id, SpvOpFConvert, *val_idp);
109 *val_idp = cvt_inst->result_id();
110}
111
112bool ConvertToHalfPass::MatConvertCleanup(Instruction* inst) {
113 if (inst->opcode() != SpvOpFConvert) return false;
114 uint32_t mty_id = inst->type_id();
115 Instruction* mty_inst = get_def_use_mgr()->GetDef(mty_id);
116 if (mty_inst->opcode() != SpvOpTypeMatrix) return false;
117 uint32_t vty_id = mty_inst->GetSingleWordInOperand(0);
118 uint32_t v_cnt = mty_inst->GetSingleWordInOperand(1);
119 Instruction* vty_inst = get_def_use_mgr()->GetDef(vty_id);
120 uint32_t cty_id = vty_inst->GetSingleWordInOperand(0);
121 Instruction* cty_inst = get_def_use_mgr()->GetDef(cty_id);
122 InstructionBuilder builder(
123 context(), inst,
124 IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
125 // Convert each component vector, combine them with OpCompositeConstruct
126 // and replace original instruction.
127 uint32_t orig_width = (cty_inst->GetSingleWordInOperand(0) == 16) ? 32 : 16;
128 uint32_t orig_mat_id = inst->GetSingleWordInOperand(0);
129 uint32_t orig_vty_id = EquivFloatTypeId(vty_id, orig_width);
130 std::vector<Operand> opnds = {};
131 for (uint32_t vidx = 0; vidx < v_cnt; ++vidx) {
132 Instruction* ext_inst = builder.AddIdLiteralOp(
133 orig_vty_id, SpvOpCompositeExtract, orig_mat_id, vidx);
134 Instruction* cvt_inst =
135 builder.AddUnaryOp(vty_id, SpvOpFConvert, ext_inst->result_id());
136 opnds.push_back({SPV_OPERAND_TYPE_ID, {cvt_inst->result_id()}});
137 }
138 uint32_t mat_id = TakeNextId();
139 std::unique_ptr<Instruction> mat_inst(new Instruction(
140 context(), SpvOpCompositeConstruct, mty_id, mat_id, opnds));
141 (void)builder.AddInstruction(std::move(mat_inst));
142 context()->ReplaceAllUsesWith(inst->result_id(), mat_id);
143 // Turn original instruction into copy so it is valid.
144 inst->SetOpcode(SpvOpCopyObject);
145 inst->SetResultType(EquivFloatTypeId(mty_id, orig_width));
146 get_def_use_mgr()->AnalyzeInstUse(inst);
147 return true;
148}
149
150void ConvertToHalfPass::RemoveRelaxedDecoration(uint32_t id) {
151 context()->get_decoration_mgr()->RemoveDecorationsFrom(
152 id, [](const Instruction& dec) {
153 if (dec.opcode() == SpvOpDecorate &&
154 dec.GetSingleWordInOperand(1u) == SpvDecorationRelaxedPrecision)
155 return true;
156 else
157 return false;
158 });
159}
160
161bool ConvertToHalfPass::GenHalfArith(Instruction* inst) {
162 bool modified = false;
163 // Convert all float32 based operands to float16 equivalent and change
164 // instruction type to float16 equivalent.
165 inst->ForEachInId([&inst, &modified, this](uint32_t* idp) {
166 Instruction* op_inst = get_def_use_mgr()->GetDef(*idp);
167 if (!IsFloat(op_inst, 32)) return;
168 GenConvert(idp, 16, inst);
169 modified = true;
170 });
171 if (IsFloat(inst, 32)) {
172 inst->SetResultType(EquivFloatTypeId(inst->type_id(), 16));
173 converted_ids_.insert(inst->result_id());
174 modified = true;
175 }
176 if (modified) get_def_use_mgr()->AnalyzeInstUse(inst);
177 return modified;
178}
179
180bool ConvertToHalfPass::ProcessPhi(Instruction* inst) {
181 // Add float16 converts of any float32 operands and change type
182 // of phi to float16 equivalent. Operand converts need to be added to
183 // preceeding blocks.
184 uint32_t ocnt = 0;
185 uint32_t* prev_idp;
186 inst->ForEachInId([&ocnt, &prev_idp, this](uint32_t* idp) {
187 if (ocnt % 2 == 0) {
188 prev_idp = idp;
189 } else {
190 Instruction* val_inst = get_def_use_mgr()->GetDef(*prev_idp);
191 if (IsFloat(val_inst, 32)) {
192 BasicBlock* bp = context()->get_instr_block(*idp);
193 auto insert_before = bp->tail();
194 if (insert_before != bp->begin()) {
195 --insert_before;
196 if (insert_before->opcode() != SpvOpSelectionMerge &&
197 insert_before->opcode() != SpvOpLoopMerge)
198 ++insert_before;
199 }
200 GenConvert(prev_idp, 16, &*insert_before);
201 }
202 }
203 ++ocnt;
204 });
205 inst->SetResultType(EquivFloatTypeId(inst->type_id(), 16));
206 get_def_use_mgr()->AnalyzeInstUse(inst);
207 converted_ids_.insert(inst->result_id());
208 return true;
209}
210
211bool ConvertToHalfPass::ProcessConvert(Instruction* inst) {
212 // If float32 and relaxed, change to float16 convert
213 if (IsFloat(inst, 32) && IsRelaxed(inst->result_id())) {
214 inst->SetResultType(EquivFloatTypeId(inst->type_id(), 16));
215 get_def_use_mgr()->AnalyzeInstUse(inst);
216 converted_ids_.insert(inst->result_id());
217 }
218 // If operand and result types are the same, change FConvert to CopyObject to
219 // keep validator happy; simplification and DCE will clean it up
220 // One way this can happen is if an FConvert generated during this pass
221 // (likely by ProcessPhi) is later encountered here and its operand has been
222 // changed to half.
223 uint32_t val_id = inst->GetSingleWordInOperand(0);
224 Instruction* val_inst = get_def_use_mgr()->GetDef(val_id);
225 if (inst->type_id() == val_inst->type_id()) inst->SetOpcode(SpvOpCopyObject);
226 return true; // modified
227}
228
229bool ConvertToHalfPass::ProcessImageRef(Instruction* inst) {
230 bool modified = false;
231 // If image reference, only need to convert dref args back to float32
232 if (dref_image_ops_.count(inst->opcode()) != 0) {
233 uint32_t dref_id = inst->GetSingleWordInOperand(kImageSampleDrefIdInIdx);
234 if (converted_ids_.count(dref_id) > 0) {
235 GenConvert(&dref_id, 32, inst);
236 inst->SetInOperand(kImageSampleDrefIdInIdx, {dref_id});
237 get_def_use_mgr()->AnalyzeInstUse(inst);
238 modified = true;
239 }
240 }
241 return modified;
242}
243
244bool ConvertToHalfPass::ProcessDefault(Instruction* inst) {
245 bool modified = false;
246 // If non-relaxed instruction has changed operands, need to convert
247 // them back to float32
248 inst->ForEachInId([&inst, &modified, this](uint32_t* idp) {
249 if (converted_ids_.count(*idp) == 0) return;
250 uint32_t old_id = *idp;
251 GenConvert(idp, 32, inst);
252 if (*idp != old_id) modified = true;
253 });
254 if (modified) get_def_use_mgr()->AnalyzeInstUse(inst);
255 return modified;
256}
257
258bool ConvertToHalfPass::GenHalfInst(Instruction* inst) {
259 bool modified = false;
260 // Remember id for later deletion of RelaxedPrecision decoration
261 bool inst_relaxed = IsRelaxed(inst->result_id());
262 if (IsArithmetic(inst) && inst_relaxed)
263 modified = GenHalfArith(inst);
264 else if (inst->opcode() == SpvOpPhi && inst_relaxed)
265 modified = ProcessPhi(inst);
266 else if (inst->opcode() == SpvOpFConvert)
267 modified = ProcessConvert(inst);
268 else if (image_ops_.count(inst->opcode()) != 0)
269 modified = ProcessImageRef(inst);
270 else
271 modified = ProcessDefault(inst);
272 return modified;
273}
274
275bool ConvertToHalfPass::CloseRelaxInst(Instruction* inst) {
276 if (inst->result_id() == 0) return false;
277 if (IsRelaxed(inst->result_id())) return false;
278 if (!IsFloat(inst, 32)) return false;
279 if (IsDecoratedRelaxed(inst)) {
280 AddRelaxed(inst->result_id());
281 return true;
282 }
283 if (closure_ops_.count(inst->opcode()) == 0) return false;
284 // Can relax if all float operands are relaxed
285 bool relax = true;
286 inst->ForEachInId([&relax, this](uint32_t* idp) {
287 Instruction* op_inst = get_def_use_mgr()->GetDef(*idp);
288 if (!IsFloat(op_inst, 32)) return;
289 if (!IsRelaxed(*idp)) relax = false;
290 });
291 if (relax) {
292 AddRelaxed(inst->result_id());
293 return true;
294 }
295 // Can relax if all uses are relaxed
296 relax = true;
297 get_def_use_mgr()->ForEachUser(inst, [&relax, this](Instruction* uinst) {
298 if (uinst->result_id() == 0 || !IsFloat(uinst, 32) ||
299 (!IsDecoratedRelaxed(uinst) && !IsRelaxed(uinst->result_id()))) {
300 relax = false;
301 return;
302 }
303 });
304 if (relax) {
305 AddRelaxed(inst->result_id());
306 return true;
307 }
308 return false;
309}
310
311bool ConvertToHalfPass::ProcessFunction(Function* func) {
312 // Do a closure of Relaxed on composite and phi instructions
313 bool changed = true;
314 while (changed) {
315 changed = false;
316 cfg()->ForEachBlockInReversePostOrder(
317 func->entry().get(), [&changed, this](BasicBlock* bb) {
318 for (auto ii = bb->begin(); ii != bb->end(); ++ii)
319 changed |= CloseRelaxInst(&*ii);
320 });
321 }
322 // Do convert of relaxed instructions to half precision
323 bool modified = false;
324 cfg()->ForEachBlockInReversePostOrder(
325 func->entry().get(), [&modified, this](BasicBlock* bb) {
326 for (auto ii = bb->begin(); ii != bb->end(); ++ii)
327 modified |= GenHalfInst(&*ii);
328 });
329 // Replace invalid converts of matrix into equivalent vector extracts,
330 // converts and finally a composite construct
331 cfg()->ForEachBlockInReversePostOrder(
332 func->entry().get(), [&modified, this](BasicBlock* bb) {
333 for (auto ii = bb->begin(); ii != bb->end(); ++ii)
334 modified |= MatConvertCleanup(&*ii);
335 });
336 return modified;
337}
338
339Pass::Status ConvertToHalfPass::ProcessImpl() {
340 Pass::ProcessFunction pfn = [this](Function* fp) {
341 return ProcessFunction(fp);
342 };
343 bool modified = context()->ProcessEntryPointCallTree(pfn);
344 // If modified, make sure module has Float16 capability
345 if (modified) context()->AddCapability(SpvCapabilityFloat16);
346 // Remove all RelaxedPrecision decorations from instructions and globals
347 for (auto c_id : relaxed_ids_set_) RemoveRelaxedDecoration(c_id);
348 for (auto& val : get_module()->types_values()) {
349 uint32_t v_id = val.result_id();
350 if (v_id != 0) RemoveRelaxedDecoration(v_id);
351 }
352 return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
353}
354
355Pass::Status ConvertToHalfPass::Process() {
356 Initialize();
357 return ProcessImpl();
358}
359
// Populates the opcode tables consulted by the rest of the pass and clears the
// relaxed and converted id sets.
void ConvertToHalfPass::Initialize() {
  // Core opcodes that IsArithmetic() accepts.
  target_ops_core_ = {
      SpvOpVectorExtractDynamic,
      SpvOpVectorInsertDynamic,
      SpvOpVectorShuffle,
      SpvOpCompositeConstruct,
      SpvOpCompositeInsert,
      SpvOpCompositeExtract,
      SpvOpCopyObject,
      SpvOpTranspose,
      SpvOpConvertSToF,
      SpvOpConvertUToF,
      // Not listed; GenHalfInst() dispatches OpFConvert to ProcessConvert().
      // SpvOpFConvert,
      // SpvOpQuantizeToF16,
      SpvOpFNegate,
      SpvOpFAdd,
      SpvOpFSub,
      SpvOpFMul,
      SpvOpFDiv,
      SpvOpFMod,
      SpvOpVectorTimesScalar,
      SpvOpMatrixTimesScalar,
      SpvOpVectorTimesMatrix,
      SpvOpMatrixTimesVector,
      SpvOpMatrixTimesMatrix,
      SpvOpOuterProduct,
      SpvOpDot,
      SpvOpSelect,
      SpvOpFOrdEqual,
      SpvOpFUnordEqual,
      SpvOpFOrdNotEqual,
      SpvOpFUnordNotEqual,
      SpvOpFOrdLessThan,
      SpvOpFUnordLessThan,
      SpvOpFOrdGreaterThan,
      SpvOpFUnordGreaterThan,
      SpvOpFOrdLessThanEqual,
      SpvOpFUnordLessThanEqual,
      SpvOpFOrdGreaterThanEqual,
      SpvOpFUnordGreaterThanEqual,
  };
  // GLSL.std.450 extended instructions that IsArithmetic() accepts when they
  // are invoked through OpExtInst.
  target_ops_450_ = {
      GLSLstd450Round, GLSLstd450RoundEven, GLSLstd450Trunc, GLSLstd450FAbs,
      GLSLstd450FSign, GLSLstd450Floor, GLSLstd450Ceil, GLSLstd450Fract,
      GLSLstd450Radians, GLSLstd450Degrees, GLSLstd450Sin, GLSLstd450Cos,
      GLSLstd450Tan, GLSLstd450Asin, GLSLstd450Acos, GLSLstd450Atan,
      GLSLstd450Sinh, GLSLstd450Cosh, GLSLstd450Tanh, GLSLstd450Asinh,
      GLSLstd450Acosh, GLSLstd450Atanh, GLSLstd450Atan2, GLSLstd450Pow,
      GLSLstd450Exp, GLSLstd450Log, GLSLstd450Exp2, GLSLstd450Log2,
      GLSLstd450Sqrt, GLSLstd450InverseSqrt, GLSLstd450Determinant,
      GLSLstd450MatrixInverse,
      // TODO(greg-lunarg): GLSLstd450ModfStruct,
      GLSLstd450FMin, GLSLstd450FMax, GLSLstd450FClamp, GLSLstd450FMix,
      GLSLstd450Step, GLSLstd450SmoothStep, GLSLstd450Fma,
      // TODO(greg-lunarg): GLSLstd450FrexpStruct,
      GLSLstd450Ldexp, GLSLstd450Length, GLSLstd450Distance, GLSLstd450Cross,
      GLSLstd450Normalize, GLSLstd450FaceForward, GLSLstd450Reflect,
      GLSLstd450Refract, GLSLstd450NMin, GLSLstd450NMax, GLSLstd450NClamp};
  // Image opcodes that GenHalfInst() routes to ProcessImageRef().
  image_ops_ = {SpvOpImageSampleImplicitLod,
                SpvOpImageSampleExplicitLod,
                SpvOpImageSampleDrefImplicitLod,
                SpvOpImageSampleDrefExplicitLod,
                SpvOpImageSampleProjImplicitLod,
                SpvOpImageSampleProjExplicitLod,
                SpvOpImageSampleProjDrefImplicitLod,
                SpvOpImageSampleProjDrefExplicitLod,
                SpvOpImageFetch,
                SpvOpImageGather,
                SpvOpImageDrefGather,
                SpvOpImageRead,
                SpvOpImageSparseSampleImplicitLod,
                SpvOpImageSparseSampleExplicitLod,
                SpvOpImageSparseSampleDrefImplicitLod,
                SpvOpImageSparseSampleDrefExplicitLod,
                SpvOpImageSparseSampleProjImplicitLod,
                SpvOpImageSparseSampleProjExplicitLod,
                SpvOpImageSparseSampleProjDrefImplicitLod,
                SpvOpImageSparseSampleProjDrefExplicitLod,
                SpvOpImageSparseFetch,
                SpvOpImageSparseGather,
                SpvOpImageSparseDrefGather,
                SpvOpImageSparseTexelsResident,
                SpvOpImageSparseRead};
  // Image opcodes whose Dref operand ProcessImageRef() converts back to
  // float32 when that operand was converted to half.
  dref_image_ops_ = {
      SpvOpImageSampleDrefImplicitLod,
      SpvOpImageSampleDrefExplicitLod,
      SpvOpImageSampleProjDrefImplicitLod,
      SpvOpImageSampleProjDrefExplicitLod,
      SpvOpImageDrefGather,
      SpvOpImageSparseSampleDrefImplicitLod,
      SpvOpImageSparseSampleDrefExplicitLod,
      SpvOpImageSparseSampleProjDrefImplicitLod,
      SpvOpImageSparseSampleProjDrefExplicitLod,
      SpvOpImageSparseDrefGather,
  };
  // Opcodes for which CloseRelaxInst() may infer relaxed status from the
  // instruction's operands or users.
  closure_ops_ = {
      SpvOpVectorExtractDynamic,
      SpvOpVectorInsertDynamic,
      SpvOpVectorShuffle,
      SpvOpCompositeConstruct,
      SpvOpCompositeInsert,
      SpvOpCompositeExtract,
      SpvOpCopyObject,
      SpvOpTranspose,
      SpvOpPhi,
  };
  // Start with no relaxed ids and no converted ids.
  relaxed_ids_set_.clear();
  converted_ids_.clear();
}
469
470} // namespace opt
471} // namespace spvtools
472