| 1 | // Copyright (c) 2018 Google LLC |
| 2 | // |
| 3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | // you may not use this file except in compliance with the License. |
| 5 | // You may obtain a copy of the License at |
| 6 | // |
| 7 | // http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | // |
| 9 | // Unless required by applicable law or agreed to in writing, software |
| 10 | // distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | // See the License for the specific language governing permissions and |
| 13 | // limitations under the License. |
| 14 | |
| 15 | #include "source/opt/combine_access_chains.h" |
| 16 | |
| 17 | #include <utility> |
| 18 | |
| 19 | #include "source/opt/constants.h" |
| 20 | #include "source/opt/ir_builder.h" |
| 21 | #include "source/opt/ir_context.h" |
| 22 | |
| 23 | namespace spvtools { |
| 24 | namespace opt { |
| 25 | |
| 26 | Pass::Status CombineAccessChains::Process() { |
| 27 | bool modified = false; |
| 28 | |
| 29 | for (auto& function : *get_module()) { |
| 30 | modified |= ProcessFunction(function); |
| 31 | } |
| 32 | |
| 33 | return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange); |
| 34 | } |
| 35 | |
| 36 | bool CombineAccessChains::ProcessFunction(Function& function) { |
| 37 | bool modified = false; |
| 38 | |
| 39 | cfg()->ForEachBlockInReversePostOrder( |
| 40 | function.entry().get(), [&modified, this](BasicBlock* block) { |
| 41 | block->ForEachInst([&modified, this](Instruction* inst) { |
| 42 | switch (inst->opcode()) { |
| 43 | case SpvOpAccessChain: |
| 44 | case SpvOpInBoundsAccessChain: |
| 45 | case SpvOpPtrAccessChain: |
| 46 | case SpvOpInBoundsPtrAccessChain: |
| 47 | modified |= CombineAccessChain(inst); |
| 48 | break; |
| 49 | default: |
| 50 | break; |
| 51 | } |
| 52 | }); |
| 53 | }); |
| 54 | |
| 55 | return modified; |
| 56 | } |
| 57 | |
| 58 | uint32_t CombineAccessChains::GetConstantValue( |
| 59 | const analysis::Constant* constant_inst) { |
| 60 | if (constant_inst->type()->AsInteger()->width() <= 32) { |
| 61 | if (constant_inst->type()->AsInteger()->IsSigned()) { |
| 62 | return static_cast<uint32_t>(constant_inst->GetS32()); |
| 63 | } else { |
| 64 | return constant_inst->GetU32(); |
| 65 | } |
| 66 | } else { |
| 67 | assert(false); |
| 68 | return 0u; |
| 69 | } |
| 70 | } |
| 71 | |
| 72 | uint32_t CombineAccessChains::GetArrayStride(const Instruction* inst) { |
| 73 | uint32_t array_stride = 0; |
| 74 | context()->get_decoration_mgr()->WhileEachDecoration( |
| 75 | inst->type_id(), SpvDecorationArrayStride, |
| 76 | [&array_stride](const Instruction& decoration) { |
| 77 | assert(decoration.opcode() != SpvOpDecorateId); |
| 78 | if (decoration.opcode() == SpvOpDecorate) { |
| 79 | array_stride = decoration.GetSingleWordInOperand(1); |
| 80 | } else { |
| 81 | array_stride = decoration.GetSingleWordInOperand(2); |
| 82 | } |
| 83 | return false; |
| 84 | }); |
| 85 | return array_stride; |
| 86 | } |
| 87 | |
| 88 | const analysis::Type* CombineAccessChains::GetIndexedType(Instruction* inst) { |
| 89 | analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); |
| 90 | analysis::TypeManager* type_mgr = context()->get_type_mgr(); |
| 91 | |
| 92 | Instruction* base_ptr = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0)); |
| 93 | const analysis::Type* type = type_mgr->GetType(base_ptr->type_id()); |
| 94 | assert(type->AsPointer()); |
| 95 | type = type->AsPointer()->pointee_type(); |
| 96 | std::vector<uint32_t> element_indices; |
| 97 | uint32_t starting_index = 1; |
| 98 | if (IsPtrAccessChain(inst->opcode())) { |
| 99 | // Skip the first index of OpPtrAccessChain as it does not affect type |
| 100 | // resolution. |
| 101 | starting_index = 2; |
| 102 | } |
| 103 | for (uint32_t i = starting_index; i < inst->NumInOperands(); ++i) { |
| 104 | Instruction* index_inst = |
| 105 | def_use_mgr->GetDef(inst->GetSingleWordInOperand(i)); |
| 106 | const analysis::Constant* index_constant = |
| 107 | context()->get_constant_mgr()->GetConstantFromInst(index_inst); |
| 108 | if (index_constant) { |
| 109 | uint32_t index_value = GetConstantValue(index_constant); |
| 110 | element_indices.push_back(index_value); |
| 111 | } else { |
| 112 | // This index must not matter to resolve the type in valid SPIR-V. |
| 113 | element_indices.push_back(0); |
| 114 | } |
| 115 | } |
| 116 | type = type_mgr->GetMemberType(type, element_indices); |
| 117 | return type; |
| 118 | } |
| 119 | |
| 120 | bool CombineAccessChains::CombineIndices(Instruction* ptr_input, |
| 121 | Instruction* inst, |
| 122 | std::vector<Operand>* new_operands) { |
| 123 | analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); |
| 124 | analysis::ConstantManager* constant_mgr = context()->get_constant_mgr(); |
| 125 | |
| 126 | Instruction* last_index_inst = def_use_mgr->GetDef( |
| 127 | ptr_input->GetSingleWordInOperand(ptr_input->NumInOperands() - 1)); |
| 128 | const analysis::Constant* last_index_constant = |
| 129 | constant_mgr->GetConstantFromInst(last_index_inst); |
| 130 | |
| 131 | Instruction* element_inst = |
| 132 | def_use_mgr->GetDef(inst->GetSingleWordInOperand(1)); |
| 133 | const analysis::Constant* element_constant = |
| 134 | constant_mgr->GetConstantFromInst(element_inst); |
| 135 | |
| 136 | // Combine the last index of the AccessChain (|ptr_inst|) with the element |
| 137 | // operand of the PtrAccessChain (|inst|). |
| 138 | const bool combining_element_operands = |
| 139 | IsPtrAccessChain(inst->opcode()) && |
| 140 | IsPtrAccessChain(ptr_input->opcode()) && ptr_input->NumInOperands() == 2; |
| 141 | uint32_t new_value_id = 0; |
| 142 | const analysis::Type* type = GetIndexedType(ptr_input); |
| 143 | if (last_index_constant && element_constant) { |
| 144 | // Combine the constants. |
| 145 | uint32_t new_value = GetConstantValue(last_index_constant) + |
| 146 | GetConstantValue(element_constant); |
| 147 | const analysis::Constant* new_value_constant = |
| 148 | constant_mgr->GetConstant(last_index_constant->type(), {new_value}); |
| 149 | Instruction* new_value_inst = |
| 150 | constant_mgr->GetDefiningInstruction(new_value_constant); |
| 151 | new_value_id = new_value_inst->result_id(); |
| 152 | } else if (!type->AsStruct() || combining_element_operands) { |
| 153 | // Generate an addition of the two indices. |
| 154 | InstructionBuilder builder( |
| 155 | context(), inst, |
| 156 | IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); |
| 157 | Instruction* addition = builder.AddIAdd(last_index_inst->type_id(), |
| 158 | last_index_inst->result_id(), |
| 159 | element_inst->result_id()); |
| 160 | new_value_id = addition->result_id(); |
| 161 | } else { |
| 162 | // Indexing into structs must be constant, so bail out here. |
| 163 | return false; |
| 164 | } |
| 165 | new_operands->push_back({SPV_OPERAND_TYPE_ID, {new_value_id}}); |
| 166 | return true; |
| 167 | } |
| 168 | |
| 169 | bool CombineAccessChains::CreateNewInputOperands( |
| 170 | Instruction* ptr_input, Instruction* inst, |
| 171 | std::vector<Operand>* new_operands) { |
| 172 | // Start by copying all the input operands of the feeder access chain. |
| 173 | for (uint32_t i = 0; i != ptr_input->NumInOperands() - 1; ++i) { |
| 174 | new_operands->push_back(ptr_input->GetInOperand(i)); |
| 175 | } |
| 176 | |
| 177 | // Deal with the last index of the feeder access chain. |
| 178 | if (IsPtrAccessChain(inst->opcode())) { |
| 179 | // The last index of the feeder should be combined with the element operand |
| 180 | // of |inst|. |
| 181 | if (!CombineIndices(ptr_input, inst, new_operands)) return false; |
| 182 | } else { |
| 183 | // The indices aren't being combined so now add the last index operand of |
| 184 | // |ptr_input|. |
| 185 | new_operands->push_back( |
| 186 | ptr_input->GetInOperand(ptr_input->NumInOperands() - 1)); |
| 187 | } |
| 188 | |
| 189 | // Copy the remaining index operands. |
| 190 | uint32_t starting_index = IsPtrAccessChain(inst->opcode()) ? 2 : 1; |
| 191 | for (uint32_t i = starting_index; i < inst->NumInOperands(); ++i) { |
| 192 | new_operands->push_back(inst->GetInOperand(i)); |
| 193 | } |
| 194 | |
| 195 | return true; |
| 196 | } |
| 197 | |
| 198 | bool CombineAccessChains::CombineAccessChain(Instruction* inst) { |
| 199 | assert((inst->opcode() == SpvOpPtrAccessChain || |
| 200 | inst->opcode() == SpvOpAccessChain || |
| 201 | inst->opcode() == SpvOpInBoundsAccessChain || |
| 202 | inst->opcode() == SpvOpInBoundsPtrAccessChain) && |
| 203 | "Wrong opcode. Expected an access chain." ); |
| 204 | |
| 205 | Instruction* ptr_input = |
| 206 | context()->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0)); |
| 207 | if (ptr_input->opcode() != SpvOpAccessChain && |
| 208 | ptr_input->opcode() != SpvOpInBoundsAccessChain && |
| 209 | ptr_input->opcode() != SpvOpPtrAccessChain && |
| 210 | ptr_input->opcode() != SpvOpInBoundsPtrAccessChain) { |
| 211 | return false; |
| 212 | } |
| 213 | |
| 214 | if (Has64BitIndices(inst) || Has64BitIndices(ptr_input)) return false; |
| 215 | |
| 216 | // Handles the following cases: |
| 217 | // 1. |ptr_input| is an index-less access chain. Replace the pointer |
| 218 | // in |inst| with |ptr_input|'s pointer. |
| 219 | // 2. |inst| is a index-less access chain. Change |inst| to an |
| 220 | // OpCopyObject. |
| 221 | // 3. |inst| is not a pointer access chain. |
| 222 | // |inst|'s indices are appended to |ptr_input|'s indices. |
| 223 | // 4. |ptr_input| is not pointer access chain. |
| 224 | // |inst| is a pointer access chain. |
| 225 | // |inst|'s element operand is combined with the last index in |
| 226 | // |ptr_input| to form a new operand. |
| 227 | // 5. |ptr_input| is a pointer access chain. |
| 228 | // Like the above scenario, |inst|'s element operand is combined |
| 229 | // with |ptr_input|'s last index. This results is either a |
| 230 | // combined element operand or combined regular index. |
| 231 | |
| 232 | // TODO(alan-baker): Support this properly. Requires analyzing the |
| 233 | // size/alignment of the type and converting the stride into an element |
| 234 | // index. |
| 235 | uint32_t array_stride = GetArrayStride(ptr_input); |
| 236 | if (array_stride != 0) return false; |
| 237 | |
| 238 | if (ptr_input->NumInOperands() == 1) { |
| 239 | // The input is effectively a no-op. |
| 240 | inst->SetInOperand(0, {ptr_input->GetSingleWordInOperand(0)}); |
| 241 | context()->AnalyzeUses(inst); |
| 242 | } else if (inst->NumInOperands() == 1) { |
| 243 | // |inst| is a no-op, change it to a copy. Instruction simplification will |
| 244 | // clean it up. |
| 245 | inst->SetOpcode(SpvOpCopyObject); |
| 246 | } else { |
| 247 | std::vector<Operand> new_operands; |
| 248 | if (!CreateNewInputOperands(ptr_input, inst, &new_operands)) return false; |
| 249 | |
| 250 | // Update the instruction. |
| 251 | inst->SetOpcode(UpdateOpcode(inst->opcode(), ptr_input->opcode())); |
| 252 | inst->SetInOperands(std::move(new_operands)); |
| 253 | context()->AnalyzeUses(inst); |
| 254 | } |
| 255 | return true; |
| 256 | } |
| 257 | |
| 258 | SpvOp CombineAccessChains::UpdateOpcode(SpvOp base_opcode, SpvOp input_opcode) { |
| 259 | auto IsInBounds = [](SpvOp opcode) { |
| 260 | return opcode == SpvOpInBoundsPtrAccessChain || |
| 261 | opcode == SpvOpInBoundsAccessChain; |
| 262 | }; |
| 263 | |
| 264 | if (input_opcode == SpvOpInBoundsPtrAccessChain) { |
| 265 | if (!IsInBounds(base_opcode)) return SpvOpPtrAccessChain; |
| 266 | } else if (input_opcode == SpvOpInBoundsAccessChain) { |
| 267 | if (!IsInBounds(base_opcode)) return SpvOpAccessChain; |
| 268 | } |
| 269 | |
| 270 | return input_opcode; |
| 271 | } |
| 272 | |
| 273 | bool CombineAccessChains::IsPtrAccessChain(SpvOp opcode) { |
| 274 | return opcode == SpvOpPtrAccessChain || opcode == SpvOpInBoundsPtrAccessChain; |
| 275 | } |
| 276 | |
| 277 | bool CombineAccessChains::Has64BitIndices(Instruction* inst) { |
| 278 | for (uint32_t i = 1; i < inst->NumInOperands(); ++i) { |
| 279 | Instruction* index_inst = |
| 280 | context()->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(i)); |
| 281 | const analysis::Type* index_type = |
| 282 | context()->get_type_mgr()->GetType(index_inst->type_id()); |
| 283 | if (!index_type->AsInteger() || index_type->AsInteger()->width() != 32) |
| 284 | return true; |
| 285 | } |
| 286 | return false; |
| 287 | } |
| 288 | |
| 289 | } // namespace opt |
| 290 | } // namespace spvtools |
| 291 | |