1// Copyright (c) 2017 Google Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "source/opt/strength_reduction_pass.h"
16
17#include <algorithm>
18#include <cstdio>
19#include <cstring>
20#include <memory>
21#include <unordered_map>
22#include <unordered_set>
23#include <utility>
24#include <vector>
25
26#include "source/opt/def_use_manager.h"
27#include "source/opt/ir_context.h"
28#include "source/opt/log.h"
29#include "source/opt/reflect.h"
30
namespace {
// Returns the number of trailing zero bits in the binary representation of
// |constVal|, i.e. the index of its least significant set bit.
// For |constVal| == 0 every bit is zero, so the full bit width (32) is
// returned. The explicit guard is required: without it the loop below would
// never terminate on a zero input.
uint32_t CountTrailingZeros(uint32_t constVal) {
  if (constVal == 0) return 32;

  // A hardware count-trailing-zeros instruction (or a lookup table) would be
  // faster; a simple shift loop keeps this portable.
  uint32_t shiftAmount = 0;
  while ((constVal & 1u) == 0) {
    ++shiftAmount;
    constVal >>= 1;
  }
  return shiftAmount;
}

// Returns true if |val| is a power of 2 (exactly one bit set).
// Subtracting 1 flips the lowest set bit and every bit below it, so the AND
// clears the lowest set bit; the result is 0 only when that was the sole set
// bit. Zero has no set bits and is rejected explicitly.
bool IsPowerOf2(uint32_t val) {
  if (val == 0) return false;
  return ((val - 1) & val) == 0;
}

}  // namespace
55
56namespace spvtools {
57namespace opt {
58
59Pass::Status StrengthReductionPass::Process() {
60 // Initialize the member variables on a per module basis.
61 bool modified = false;
62 int32_type_id_ = 0;
63 uint32_type_id_ = 0;
64 std::memset(constant_ids_, 0, sizeof(constant_ids_));
65
66 FindIntTypesAndConstants();
67 modified = ScanFunctions();
68 return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
69}
70
71bool StrengthReductionPass::ReplaceMultiplyByPowerOf2(
72 BasicBlock::iterator* inst) {
73 assert((*inst)->opcode() == SpvOp::SpvOpIMul &&
74 "Only works for multiplication of integers.");
75 bool modified = false;
76
77 // Currently only works on 32-bit integers.
78 if ((*inst)->type_id() != int32_type_id_ &&
79 (*inst)->type_id() != uint32_type_id_) {
80 return modified;
81 }
82
83 // Check the operands for a constant that is a power of 2.
84 for (int i = 0; i < 2; i++) {
85 uint32_t opId = (*inst)->GetSingleWordInOperand(i);
86 Instruction* opInst = get_def_use_mgr()->GetDef(opId);
87 if (opInst->opcode() == SpvOp::SpvOpConstant) {
88 // We found a constant operand.
89 uint32_t constVal = opInst->GetSingleWordOperand(2);
90
91 if (IsPowerOf2(constVal)) {
92 modified = true;
93 uint32_t shiftAmount = CountTrailingZeros(constVal);
94 uint32_t shiftConstResultId = GetConstantId(shiftAmount);
95
96 // Create the new instruction.
97 uint32_t newResultId = TakeNextId();
98 std::vector<Operand> newOperands;
99 newOperands.push_back((*inst)->GetInOperand(1 - i));
100 Operand shiftOperand(spv_operand_type_t::SPV_OPERAND_TYPE_ID,
101 {shiftConstResultId});
102 newOperands.push_back(shiftOperand);
103 std::unique_ptr<Instruction> newInstruction(
104 new Instruction(context(), SpvOp::SpvOpShiftLeftLogical,
105 (*inst)->type_id(), newResultId, newOperands));
106
107 // Insert the new instruction and update the data structures.
108 (*inst) = (*inst).InsertBefore(std::move(newInstruction));
109 get_def_use_mgr()->AnalyzeInstDefUse(&*(*inst));
110 ++(*inst);
111 context()->ReplaceAllUsesWith((*inst)->result_id(), newResultId);
112
113 // Remove the old instruction.
114 Instruction* inst_to_delete = &*(*inst);
115 --(*inst);
116 context()->KillInst(inst_to_delete);
117
118 // We do not want to replace the instruction twice if both operands
119 // are constants that are a power of 2. So we break here.
120 break;
121 }
122 }
123 }
124
125 return modified;
126}
127
128void StrengthReductionPass::FindIntTypesAndConstants() {
129 analysis::Integer int32(32, true);
130 int32_type_id_ = context()->get_type_mgr()->GetId(&int32);
131 analysis::Integer uint32(32, false);
132 uint32_type_id_ = context()->get_type_mgr()->GetId(&uint32);
133 for (auto iter = get_module()->types_values_begin();
134 iter != get_module()->types_values_end(); ++iter) {
135 switch (iter->opcode()) {
136 case SpvOp::SpvOpConstant:
137 if (iter->type_id() == uint32_type_id_) {
138 uint32_t value = iter->GetSingleWordOperand(2);
139 if (value <= 32) constant_ids_[value] = iter->result_id();
140 }
141 break;
142 default:
143 break;
144 }
145 }
146}
147
148uint32_t StrengthReductionPass::GetConstantId(uint32_t val) {
149 assert(val <= 32 &&
150 "This function does not handle constants larger than 32.");
151
152 if (constant_ids_[val] == 0) {
153 if (uint32_type_id_ == 0) {
154 analysis::Integer uint(32, false);
155 uint32_type_id_ = context()->get_type_mgr()->GetTypeInstruction(&uint);
156 }
157
158 // Construct the constant.
159 uint32_t resultId = TakeNextId();
160 Operand constant(spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
161 {val});
162 std::unique_ptr<Instruction> newConstant(
163 new Instruction(context(), SpvOp::SpvOpConstant, uint32_type_id_,
164 resultId, {constant}));
165 get_module()->AddGlobalValue(std::move(newConstant));
166
167 // Notify the DefUseManager about this constant.
168 auto constantIter = --get_module()->types_values_end();
169 get_def_use_mgr()->AnalyzeInstDef(&*constantIter);
170
171 // Store the result id for next time.
172 constant_ids_[val] = resultId;
173 }
174
175 return constant_ids_[val];
176}
177
178bool StrengthReductionPass::ScanFunctions() {
179 // I did not use |ForEachInst| in the module because the function that acts on
180 // the instruction gets a pointer to the instruction. We cannot use that to
181 // insert a new instruction. I want an iterator.
182 bool modified = false;
183 for (auto& func : *get_module()) {
184 for (auto& bb : func) {
185 for (auto inst = bb.begin(); inst != bb.end(); ++inst) {
186 switch (inst->opcode()) {
187 case SpvOp::SpvOpIMul:
188 if (ReplaceMultiplyByPowerOf2(&inst)) modified = true;
189 break;
190 default:
191 break;
192 }
193 }
194 }
195 }
196 return modified;
197}
198
199} // namespace opt
200} // namespace spvtools
201