// Copyright 2018 The SwiftShader Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "VkPipeline.hpp"

#include "VkDevice.hpp"
#include "VkPipelineCache.hpp"
#include "VkPipelineLayout.hpp"
#include "VkRenderPass.hpp"
#include "VkShaderModule.hpp"
#include "Pipeline/ComputeProgram.hpp"
#include "Pipeline/SpirvShader.hpp"

#include "marl/trace.h"

#include "spirv-tools/optimizer.hpp"

#include <algorithm>
#include <iostream>
#include <unordered_map>

namespace
{

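// getStreamType maps a Vulkan vertex attribute format to the renderer's
// per-component stream element type; the component count and normalization
// flag are derived separately (see getNumberOfChannels and sw::Stream below).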
sw::StreamType getStreamType(VkFormat format)
{
    switch(format)
    {
    case VK_FORMAT_R8_UNORM:
    case VK_FORMAT_R8G8_UNORM:
    case VK_FORMAT_R8G8B8A8_UNORM:
    case VK_FORMAT_R8_UINT:
    case VK_FORMAT_R8G8_UINT:
    case VK_FORMAT_R8G8B8A8_UINT:
    case VK_FORMAT_A8B8G8R8_UNORM_PACK32:
    case VK_FORMAT_A8B8G8R8_UINT_PACK32:
        return sw::STREAMTYPE_BYTE;
    case VK_FORMAT_B8G8R8A8_UNORM:
        return sw::STREAMTYPE_COLOR;
    case VK_FORMAT_R8_SNORM:
    case VK_FORMAT_R8_SINT:
    case VK_FORMAT_R8G8_SNORM:
    case VK_FORMAT_R8G8_SINT:
    case VK_FORMAT_R8G8B8A8_SNORM:
    case VK_FORMAT_R8G8B8A8_SINT:
    case VK_FORMAT_A8B8G8R8_SNORM_PACK32:
    case VK_FORMAT_A8B8G8R8_SINT_PACK32:
        return sw::STREAMTYPE_SBYTE;
    case VK_FORMAT_A2B10G10R10_UNORM_PACK32:
        return sw::STREAMTYPE_2_10_10_10_UINT;
    case VK_FORMAT_R16_UNORM:
    case VK_FORMAT_R16_UINT:
    case VK_FORMAT_R16G16_UNORM:
    case VK_FORMAT_R16G16_UINT:
    case VK_FORMAT_R16G16B16A16_UNORM:
    case VK_FORMAT_R16G16B16A16_UINT:
        return sw::STREAMTYPE_USHORT;
    case VK_FORMAT_R16_SNORM:
    case VK_FORMAT_R16_SINT:
    case VK_FORMAT_R16G16_SNORM:
    case VK_FORMAT_R16G16_SINT:
    case VK_FORMAT_R16G16B16A16_SNORM:
    case VK_FORMAT_R16G16B16A16_SINT:
        return sw::STREAMTYPE_SHORT;
    case VK_FORMAT_R16_SFLOAT:
    case VK_FORMAT_R16G16_SFLOAT:
    case VK_FORMAT_R16G16B16A16_SFLOAT:
        return sw::STREAMTYPE_HALF;
    case VK_FORMAT_R32_UINT:
    case VK_FORMAT_R32G32_UINT:
    case VK_FORMAT_R32G32B32_UINT:
    case VK_FORMAT_R32G32B32A32_UINT:
        return sw::STREAMTYPE_UINT;
    case VK_FORMAT_R32_SINT:
    case VK_FORMAT_R32G32_SINT:
    case VK_FORMAT_R32G32B32_SINT:
    case VK_FORMAT_R32G32B32A32_SINT:
        return sw::STREAMTYPE_INT;
    case VK_FORMAT_R32_SFLOAT:
    case VK_FORMAT_R32G32_SFLOAT:
    case VK_FORMAT_R32G32B32_SFLOAT:
    case VK_FORMAT_R32G32B32A32_SFLOAT:
        return sw::STREAMTYPE_FLOAT;
    default:
        UNIMPLEMENTED("format");
    }

    return sw::STREAMTYPE_BYTE;
}

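// getNumberOfChannels returns the component count of a vertex attribute
// format, e.g. VK_FORMAT_R32G32_SFLOAT has two channels.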
unsigned char getNumberOfChannels(VkFormat format)
{
    switch(format)
    {
    case VK_FORMAT_R8_UNORM:
    case VK_FORMAT_R8_SNORM:
    case VK_FORMAT_R8_UINT:
    case VK_FORMAT_R8_SINT:
    case VK_FORMAT_R16_UNORM:
    case VK_FORMAT_R16_SNORM:
    case VK_FORMAT_R16_UINT:
    case VK_FORMAT_R16_SINT:
    case VK_FORMAT_R16_SFLOAT:
    case VK_FORMAT_R32_UINT:
    case VK_FORMAT_R32_SINT:
    case VK_FORMAT_R32_SFLOAT:
        return 1;
    case VK_FORMAT_R8G8_UNORM:
    case VK_FORMAT_R8G8_SNORM:
    case VK_FORMAT_R8G8_UINT:
    case VK_FORMAT_R8G8_SINT:
    case VK_FORMAT_R16G16_UNORM:
    case VK_FORMAT_R16G16_SNORM:
    case VK_FORMAT_R16G16_UINT:
    case VK_FORMAT_R16G16_SINT:
    case VK_FORMAT_R16G16_SFLOAT:
    case VK_FORMAT_R32G32_UINT:
    case VK_FORMAT_R32G32_SINT:
    case VK_FORMAT_R32G32_SFLOAT:
        return 2;
    case VK_FORMAT_R32G32B32_UINT:
    case VK_FORMAT_R32G32B32_SINT:
    case VK_FORMAT_R32G32B32_SFLOAT:
        return 3;
    case VK_FORMAT_R8G8B8A8_UNORM:
    case VK_FORMAT_R8G8B8A8_SNORM:
    case VK_FORMAT_R8G8B8A8_UINT:
    case VK_FORMAT_R8G8B8A8_SINT:
    case VK_FORMAT_B8G8R8A8_UNORM:
    case VK_FORMAT_A8B8G8R8_UNORM_PACK32:
    case VK_FORMAT_A8B8G8R8_SNORM_PACK32:
    case VK_FORMAT_A8B8G8R8_UINT_PACK32:
    case VK_FORMAT_A8B8G8R8_SINT_PACK32:
    case VK_FORMAT_A2B10G10R10_UNORM_PACK32:
    case VK_FORMAT_R16G16B16A16_UNORM:
    case VK_FORMAT_R16G16B16A16_SNORM:
    case VK_FORMAT_R16G16B16A16_UINT:
    case VK_FORMAT_R16G16B16A16_SINT:
    case VK_FORMAT_R16G16B16A16_SFLOAT:
    case VK_FORMAT_R32G32B32A32_UINT:
    case VK_FORMAT_R32G32B32A32_SINT:
    case VK_FORMAT_R32G32B32A32_SFLOAT:
        return 4;
    default:
        UNIMPLEMENTED("format");
    }

    return 0;
}

// preprocessSpirv applies the specialization constants, freezes them, and inlines all functions.
std::vector<uint32_t> preprocessSpirv(
    std::vector<uint32_t> const &code,
    VkSpecializationInfo const *specializationInfo)
{
    spvtools::Optimizer opt{SPV_ENV_VULKAN_1_1};

    opt.SetMessageConsumer([](spv_message_level_t level, const char*, const spv_position_t& p, const char* m) {
        switch(level)
        {
        case SPV_MSG_FATAL:          vk::warn("SPIR-V FATAL: %d:%d %s\n", int(p.line), int(p.column), m); break;
        case SPV_MSG_INTERNAL_ERROR: vk::warn("SPIR-V INTERNAL_ERROR: %d:%d %s\n", int(p.line), int(p.column), m); break;
        case SPV_MSG_ERROR:          vk::warn("SPIR-V ERROR: %d:%d %s\n", int(p.line), int(p.column), m); break;
        case SPV_MSG_WARNING:        vk::warn("SPIR-V WARNING: %d:%d %s\n", int(p.line), int(p.column), m); break;
        case SPV_MSG_INFO:           vk::trace("SPIR-V INFO: %d:%d %s\n", int(p.line), int(p.column), m); break;
        case SPV_MSG_DEBUG:          vk::trace("SPIR-V DEBUG: %d:%d %s\n", int(p.line), int(p.column), m); break;
        default:                     vk::trace("SPIR-V MESSAGE: %d:%d %s\n", int(p.line), int(p.column), m); break;
        }
    });

    // If the pipeline uses specialization, apply the specializations before freezing.
    if(specializationInfo)
    {
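        // Hypothetical example of the mapping built below: a VkSpecializationInfo
        // with pMapEntries = { { constantID = 3, offset = 0, size = 4 } } and
        // pData pointing at { 0x0000000F } yields specializations[3] = { 0xF },
        // which the SetSpecConstantDefaultValue pass uses to overwrite the
        // default of OpSpecConstant %3 before the constants are frozen.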
        std::unordered_map<uint32_t, std::vector<uint32_t>> specializations;
        for(auto i = 0u; i < specializationInfo->mapEntryCount; ++i)
        {
            auto const &e = specializationInfo->pMapEntries[i];
            auto value_ptr =
                static_cast<uint32_t const *>(specializationInfo->pData) + e.offset / sizeof(uint32_t);
            specializations.emplace(e.constantID,
                std::vector<uint32_t>{value_ptr, value_ptr + e.size / sizeof(uint32_t)});
        }
        opt.RegisterPass(spvtools::CreateSetSpecConstantDefaultValuePass(specializations));
    }

    // Full optimization list taken from spirv-opt.
    opt.RegisterPass(spvtools::CreateWrapOpKillPass())
        .RegisterPass(spvtools::CreateDeadBranchElimPass())
        .RegisterPass(spvtools::CreateMergeReturnPass())
        .RegisterPass(spvtools::CreateInlineExhaustivePass())
        .RegisterPass(spvtools::CreateAggressiveDCEPass())
        .RegisterPass(spvtools::CreatePrivateToLocalPass())
        .RegisterPass(spvtools::CreateLocalSingleBlockLoadStoreElimPass())
        .RegisterPass(spvtools::CreateLocalSingleStoreElimPass())
        .RegisterPass(spvtools::CreateAggressiveDCEPass())
        .RegisterPass(spvtools::CreateScalarReplacementPass())
        .RegisterPass(spvtools::CreateLocalAccessChainConvertPass())
        .RegisterPass(spvtools::CreateLocalSingleBlockLoadStoreElimPass())
        .RegisterPass(spvtools::CreateLocalSingleStoreElimPass())
        .RegisterPass(spvtools::CreateAggressiveDCEPass())
        .RegisterPass(spvtools::CreateLocalMultiStoreElimPass())
        .RegisterPass(spvtools::CreateAggressiveDCEPass())
        .RegisterPass(spvtools::CreateCCPPass())
        .RegisterPass(spvtools::CreateAggressiveDCEPass())
        .RegisterPass(spvtools::CreateRedundancyEliminationPass())
        .RegisterPass(spvtools::CreateCombineAccessChainsPass())
        .RegisterPass(spvtools::CreateSimplificationPass())
        .RegisterPass(spvtools::CreateVectorDCEPass())
        .RegisterPass(spvtools::CreateDeadInsertElimPass())
        .RegisterPass(spvtools::CreateDeadBranchElimPass())
        .RegisterPass(spvtools::CreateSimplificationPass())
        .RegisterPass(spvtools::CreateIfConversionPass())
        .RegisterPass(spvtools::CreateCopyPropagateArraysPass())
        .RegisterPass(spvtools::CreateReduceLoadSizePass())
        .RegisterPass(spvtools::CreateAggressiveDCEPass())
        .RegisterPass(spvtools::CreateBlockMergePass())
        .RegisterPass(spvtools::CreateRedundancyEliminationPass())
        .RegisterPass(spvtools::CreateDeadBranchElimPass())
        .RegisterPass(spvtools::CreateBlockMergePass())
        .RegisterPass(spvtools::CreateSimplificationPass());

    std::vector<uint32_t> optimized;
    opt.Run(code.data(), code.size(), &optimized);

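    // Debug aid: flip this condition to true to dump the SPIR-V disassembly
    // before and after optimization.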
    if(false)
    {
        spvtools::SpirvTools core(SPV_ENV_VULKAN_1_1);
        std::string preOpt;
        core.Disassemble(code, &preOpt, SPV_BINARY_TO_TEXT_OPTION_NONE);
        std::string postOpt;
        core.Disassemble(optimized, &postOpt, SPV_BINARY_TO_TEXT_OPTION_NONE);
        std::cout << "PRE-OPT: " << preOpt << std::endl
                  << "POST-OPT: " << postOpt << std::endl;
    }

    return optimized;
}

std::shared_ptr<sw::SpirvShader> createShader(const vk::PipelineCache::SpirvShaderKey& key, const vk::ShaderModule *module, bool robustBufferAccess)
{
    auto code = preprocessSpirv(key.getInsns(), key.getSpecializationInfo());
    ASSERT(code.size() > 0);

    // If the pipeline has specialization constants, assume they're unique and
    // use a new serial ID so the shader gets recompiled.
    uint32_t codeSerialID = (key.getSpecializationInfo() ? vk::ShaderModule::nextSerialID() : module->getSerialID());

    // TODO(b/119409619): use allocator.
    return std::make_shared<sw::SpirvShader>(codeSerialID, key.getPipelineStage(), key.getEntryPointName().c_str(),
        code, key.getRenderPass(), key.getSubpassIndex(), robustBufferAccess);
}

std::shared_ptr<sw::ComputeProgram> createProgram(const vk::PipelineCache::ComputeProgramKey& key)
{
    MARL_SCOPED_EVENT("createProgram");

    vk::DescriptorSet::Bindings descriptorSets;  // FIXME(b/129523279): Delay code generation until invoke time.
    // TODO(b/119409619): use allocator.
    auto program = std::make_shared<sw::ComputeProgram>(key.getShader(), key.getLayout(), descriptorSets);
    program->generate();
    program->finalize();
    return program;
}

}  // anonymous namespace

namespace vk
{

Pipeline::Pipeline(PipelineLayout const *layout, const Device *device)
    : layout(layout),
      robustBufferAccess(device->getEnabledFeatures().robustBufferAccess)
{
}

GraphicsPipeline::GraphicsPipeline(const VkGraphicsPipelineCreateInfo* pCreateInfo, void* mem, const Device *device)
    : Pipeline(vk::Cast(pCreateInfo->layout), device)
{
    context.robustBufferAccess = robustBufferAccess;

    if(((pCreateInfo->flags &
        ~(VK_PIPELINE_CREATE_DISABLE_OPTIMIZATION_BIT |
          VK_PIPELINE_CREATE_DERIVATIVE_BIT |
          VK_PIPELINE_CREATE_ALLOW_DERIVATIVES_BIT)) != 0) ||
       (pCreateInfo->pTessellationState != nullptr))
    {
        UNIMPLEMENTED("pCreateInfo settings");
    }

    if(pCreateInfo->pDynamicState)
    {
        for(uint32_t i = 0; i < pCreateInfo->pDynamicState->dynamicStateCount; i++)
        {
            VkDynamicState dynamicState = pCreateInfo->pDynamicState->pDynamicStates[i];
            switch(dynamicState)
            {
            case VK_DYNAMIC_STATE_VIEWPORT:
            case VK_DYNAMIC_STATE_SCISSOR:
            case VK_DYNAMIC_STATE_LINE_WIDTH:
            case VK_DYNAMIC_STATE_DEPTH_BIAS:
            case VK_DYNAMIC_STATE_BLEND_CONSTANTS:
            case VK_DYNAMIC_STATE_DEPTH_BOUNDS:
            case VK_DYNAMIC_STATE_STENCIL_COMPARE_MASK:
            case VK_DYNAMIC_STATE_STENCIL_WRITE_MASK:
            case VK_DYNAMIC_STATE_STENCIL_REFERENCE:
                ASSERT(dynamicState < (sizeof(dynamicStateFlags) * 8));
                dynamicStateFlags |= (1 << dynamicState);
                break;
            default:
                UNIMPLEMENTED("dynamic state");
            }
        }
    }

    const VkPipelineVertexInputStateCreateInfo* vertexInputState = pCreateInfo->pVertexInputState;
    if(vertexInputState->flags != 0)
    {
        UNIMPLEMENTED("vertexInputState->flags");
    }

    // Context must always have a PipelineLayout set.
    context.pipelineLayout = layout;

    // Temporary in-binding-order representation of buffer strides, consumed
    // below when walking the attributes. TODO: unfuse buffers from attributes
    // in the backend; the fused representation is a legacy GL model.
    uint32_t vertexStrides[MAX_VERTEX_INPUT_BINDINGS];
    uint32_t instanceStrides[MAX_VERTEX_INPUT_BINDINGS];
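    // For each binding, only the stride matching its input rate is retained;
    // the other is zeroed so the unused advance becomes a no-op. E.g. a binding
    // with VK_VERTEX_INPUT_RATE_INSTANCE gets an instanceStride but a
    // vertexStride of zero.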
    for(uint32_t i = 0; i < vertexInputState->vertexBindingDescriptionCount; i++)
    {
        auto const &desc = vertexInputState->pVertexBindingDescriptions[i];
        vertexStrides[desc.binding] = desc.inputRate == VK_VERTEX_INPUT_RATE_VERTEX ? desc.stride : 0;
        instanceStrides[desc.binding] = desc.inputRate == VK_VERTEX_INPUT_RATE_INSTANCE ? desc.stride : 0;
    }

    for(uint32_t i = 0; i < vertexInputState->vertexAttributeDescriptionCount; i++)
    {
        auto const &desc = vertexInputState->pVertexAttributeDescriptions[i];
        sw::Stream& input = context.input[desc.location];
        input.count = getNumberOfChannels(desc.format);
        input.type = getStreamType(desc.format);
        input.normalized = !vk::Format(desc.format).isNonNormalizedInteger();
        input.offset = desc.offset;
        input.binding = desc.binding;
        input.vertexStride = vertexStrides[desc.binding];
        input.instanceStride = instanceStrides[desc.binding];
    }

    const VkPipelineInputAssemblyStateCreateInfo* assemblyState = pCreateInfo->pInputAssemblyState;
    if(assemblyState->flags != 0)
    {
        UNIMPLEMENTED("pCreateInfo->pInputAssemblyState settings");
    }

    primitiveRestartEnable = (assemblyState->primitiveRestartEnable != VK_FALSE);
    context.topology = assemblyState->topology;

    const VkPipelineViewportStateCreateInfo* viewportState = pCreateInfo->pViewportState;
    if(viewportState)
    {
        if((viewportState->flags != 0) ||
           (viewportState->viewportCount != 1) ||
           (viewportState->scissorCount != 1))
        {
            UNIMPLEMENTED("pCreateInfo->pViewportState settings");
        }

        if(!hasDynamicState(VK_DYNAMIC_STATE_SCISSOR))
        {
            scissor = viewportState->pScissors[0];
        }

        if(!hasDynamicState(VK_DYNAMIC_STATE_VIEWPORT))
        {
            viewport = viewportState->pViewports[0];
        }
    }

    const VkPipelineRasterizationStateCreateInfo* rasterizationState = pCreateInfo->pRasterizationState;
    if((rasterizationState->flags != 0) ||
       (rasterizationState->depthClampEnable != VK_FALSE))
    {
        UNIMPLEMENTED("pCreateInfo->pRasterizationState settings");
    }

    context.rasterizerDiscard = (rasterizationState->rasterizerDiscardEnable == VK_TRUE);
    context.cullMode = rasterizationState->cullMode;
    context.frontFace = rasterizationState->frontFace;
    context.polygonMode = rasterizationState->polygonMode;
    context.depthBias = (rasterizationState->depthBiasEnable != VK_FALSE) ? rasterizationState->depthBiasConstantFactor : 0.0f;
    context.slopeDepthBias = (rasterizationState->depthBiasEnable != VK_FALSE) ? rasterizationState->depthBiasSlopeFactor : 0.0f;

    const VkBaseInStructure* extensionCreateInfo = reinterpret_cast<const VkBaseInStructure*>(rasterizationState->pNext);
    while(extensionCreateInfo)
    {
        // Cast the sType to an integer, since some structures, such as
        // VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROVOKING_VERTEX_FEATURES_EXT,
        // are not enumerated in the official Vulkan header.
        switch((long)(extensionCreateInfo->sType))
        {
        case VK_STRUCTURE_TYPE_PIPELINE_RASTERIZATION_LINE_STATE_CREATE_INFO_EXT:
        {
            const VkPipelineRasterizationLineStateCreateInfoEXT* lineStateCreateInfo = reinterpret_cast<const VkPipelineRasterizationLineStateCreateInfoEXT*>(extensionCreateInfo);
            context.lineRasterizationMode = lineStateCreateInfo->lineRasterizationMode;
        }
        break;
        case VK_STRUCTURE_TYPE_PIPELINE_RASTERIZATION_PROVOKING_VERTEX_STATE_CREATE_INFO_EXT:
        {
            const VkPipelineRasterizationProvokingVertexStateCreateInfoEXT* provokingVertexModeCreateInfo =
                reinterpret_cast<const VkPipelineRasterizationProvokingVertexStateCreateInfoEXT*>(extensionCreateInfo);
            context.provokingVertexMode = provokingVertexModeCreateInfo->provokingVertexMode;
        }
        break;
        default:
            UNIMPLEMENTED("extensionCreateInfo->sType");
            break;
        }

        extensionCreateInfo = extensionCreateInfo->pNext;
    }

    const VkPipelineMultisampleStateCreateInfo* multisampleState = pCreateInfo->pMultisampleState;
    if(multisampleState)
    {
        switch(multisampleState->rasterizationSamples)
        {
        case VK_SAMPLE_COUNT_1_BIT:
            context.sampleCount = 1;
            break;
        case VK_SAMPLE_COUNT_4_BIT:
            context.sampleCount = 4;
            break;
        default:
            UNIMPLEMENTED("Unsupported sample count");
        }

        if(multisampleState->pSampleMask)
        {
            context.sampleMask = multisampleState->pSampleMask[0];
        }

        context.alphaToCoverage = (multisampleState->alphaToCoverageEnable == VK_TRUE);

        if((multisampleState->flags != 0) ||
           (multisampleState->sampleShadingEnable != VK_FALSE) ||
           (multisampleState->alphaToOneEnable != VK_FALSE))
        {
            UNIMPLEMENTED("multisampleState");
        }
    }
    else
    {
        context.sampleCount = 1;
    }

    const VkPipelineDepthStencilStateCreateInfo* depthStencilState = pCreateInfo->pDepthStencilState;
    if(depthStencilState)
    {
        if((depthStencilState->flags != 0) ||
           (depthStencilState->depthBoundsTestEnable != VK_FALSE))
        {
            UNIMPLEMENTED("depthStencilState");
        }

        context.depthBoundsTestEnable = (depthStencilState->depthBoundsTestEnable == VK_TRUE);
        context.depthBufferEnable = (depthStencilState->depthTestEnable == VK_TRUE);
        context.depthWriteEnable = (depthStencilState->depthWriteEnable == VK_TRUE);
        context.depthCompareMode = depthStencilState->depthCompareOp;

        context.stencilEnable = (depthStencilState->stencilTestEnable == VK_TRUE);
        if(context.stencilEnable)
        {
            context.frontStencil = depthStencilState->front;
            context.backStencil = depthStencilState->back;
        }
    }

    const VkPipelineColorBlendStateCreateInfo* colorBlendState = pCreateInfo->pColorBlendState;
    if(colorBlendState)
    {
        if((colorBlendState->flags != 0) ||
           (colorBlendState->logicOpEnable != VK_FALSE))
        {
            UNIMPLEMENTED("colorBlendState");
        }

        if(!hasDynamicState(VK_DYNAMIC_STATE_BLEND_CONSTANTS))
        {
            blendConstants.r = colorBlendState->blendConstants[0];
            blendConstants.g = colorBlendState->blendConstants[1];
            blendConstants.b = colorBlendState->blendConstants[2];
            blendConstants.a = colorBlendState->blendConstants[3];
        }

        for(auto i = 0u; i < colorBlendState->attachmentCount; i++)
        {
            const VkPipelineColorBlendAttachmentState& attachment = colorBlendState->pAttachments[i];
            context.colorWriteMask[i] = attachment.colorWriteMask;

            context.setBlendState(i, { (attachment.blendEnable == VK_TRUE),
                attachment.srcColorBlendFactor, attachment.dstColorBlendFactor, attachment.colorBlendOp,
                attachment.srcAlphaBlendFactor, attachment.dstAlphaBlendFactor, attachment.alphaBlendOp });
        }
    }

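    // Restrict the sample mask to the bits covered by the pipeline's sample
    // count, e.g. sampleCount == 4 keeps only the low four bits.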
    context.multiSampleMask = context.sampleMask & ((unsigned) 0xFFFFFFFF >> (32 - context.sampleCount));
}

void GraphicsPipeline::destroyPipeline(const VkAllocationCallbacks* pAllocator)
{
    vertexShader.reset();
    fragmentShader.reset();
}

size_t GraphicsPipeline::ComputeRequiredAllocationSize(const VkGraphicsPipelineCreateInfo* pCreateInfo)
{
    return 0;
}

void GraphicsPipeline::setShader(const VkShaderStageFlagBits& stage, const std::shared_ptr<sw::SpirvShader> spirvShader)
{
    switch(stage)
    {
    case VK_SHADER_STAGE_VERTEX_BIT:
        ASSERT(vertexShader.get() == nullptr);
        vertexShader = spirvShader;
        context.vertexShader = vertexShader.get();
        break;

    case VK_SHADER_STAGE_FRAGMENT_BIT:
        ASSERT(fragmentShader.get() == nullptr);
        fragmentShader = spirvShader;
        context.pixelShader = fragmentShader.get();
        break;

    default:
        UNSUPPORTED("Unsupported stage");
        break;
    }
}

const std::shared_ptr<sw::SpirvShader> GraphicsPipeline::getShader(const VkShaderStageFlagBits& stage) const
{
    switch(stage)
    {
    case VK_SHADER_STAGE_VERTEX_BIT:
        return vertexShader;
    case VK_SHADER_STAGE_FRAGMENT_BIT:
        return fragmentShader;
    default:
        UNSUPPORTED("Unsupported stage");
        return fragmentShader;
    }
}

void GraphicsPipeline::compileShaders(const VkAllocationCallbacks* pAllocator, const VkGraphicsPipelineCreateInfo* pCreateInfo, PipelineCache* pPipelineCache)
{
    for(auto pStage = pCreateInfo->pStages; pStage != pCreateInfo->pStages + pCreateInfo->stageCount; pStage++)
    {
        if(pStage->flags != 0)
        {
            UNIMPLEMENTED("pStage->flags");
        }

        const ShaderModule *module = vk::Cast(pStage->module);
        const PipelineCache::SpirvShaderKey key(pStage->stage, pStage->pName, module->getCode(),
            vk::Cast(pCreateInfo->renderPass), pCreateInfo->subpass,
            pStage->pSpecializationInfo);
        auto pipelineStage = key.getPipelineStage();

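        // Look the shader up in the pipeline cache under its mutex; on a miss,
        // compile it and insert it so pipelines later created against the same
        // cache can reuse the compiled shader.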
        if(pPipelineCache)
        {
            PipelineCache& pipelineCache = *pPipelineCache;
            {
                std::unique_lock<std::mutex> lock(pipelineCache.getShaderMutex());
                const std::shared_ptr<sw::SpirvShader>* spirvShader = pipelineCache[key];
                if(!spirvShader)
                {
                    auto shader = createShader(key, module, robustBufferAccess);
                    setShader(pipelineStage, shader);
                    pipelineCache.insert(key, getShader(pipelineStage));
                }
                else
                {
                    setShader(pipelineStage, *spirvShader);
                }
            }
        }
        else
        {
            auto shader = createShader(key, module, robustBufferAccess);
            setShader(pipelineStage, shader);
        }
    }
}

uint32_t GraphicsPipeline::computePrimitiveCount(uint32_t vertexCount) const
{
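    // Vertex-to-primitive count formulas per topology, e.g. a triangle strip
    // of N vertices yields max(N, 2) - 2 triangles.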
    switch(context.topology)
    {
    case VK_PRIMITIVE_TOPOLOGY_POINT_LIST:
        return vertexCount;
    case VK_PRIMITIVE_TOPOLOGY_LINE_LIST:
        return vertexCount / 2;
    case VK_PRIMITIVE_TOPOLOGY_LINE_STRIP:
        return std::max<uint32_t>(vertexCount, 1) - 1;
    case VK_PRIMITIVE_TOPOLOGY_TRIANGLE_LIST:
        return vertexCount / 3;
    case VK_PRIMITIVE_TOPOLOGY_TRIANGLE_STRIP:
        return std::max<uint32_t>(vertexCount, 2) - 2;
    case VK_PRIMITIVE_TOPOLOGY_TRIANGLE_FAN:
        return std::max<uint32_t>(vertexCount, 2) - 2;
    default:
        UNIMPLEMENTED("context.topology %d", int(context.topology));
    }

    return 0;
}

const sw::Context& GraphicsPipeline::getContext() const
{
    return context;
}

const VkRect2D& GraphicsPipeline::getScissor() const
{
    return scissor;
}

const VkViewport& GraphicsPipeline::getViewport() const
{
    return viewport;
}

const sw::Color<float>& GraphicsPipeline::getBlendConstants() const
{
    return blendConstants;
}

bool GraphicsPipeline::hasDynamicState(VkDynamicState dynamicState) const
{
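    // Each VkDynamicState value indexes one bit in dynamicStateFlags, set in
    // the constructor for every state declared dynamic.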
    return (dynamicStateFlags & (1 << dynamicState)) != 0;
}

ComputePipeline::ComputePipeline(const VkComputePipelineCreateInfo* pCreateInfo, void* mem, const Device *device)
    : Pipeline(vk::Cast(pCreateInfo->layout), device)
{
}

void ComputePipeline::destroyPipeline(const VkAllocationCallbacks* pAllocator)
{
    shader.reset();
    program.reset();
}

size_t ComputePipeline::ComputeRequiredAllocationSize(const VkComputePipelineCreateInfo* pCreateInfo)
{
    return 0;
}

void ComputePipeline::compileShaders(const VkAllocationCallbacks* pAllocator, const VkComputePipelineCreateInfo* pCreateInfo, PipelineCache* pPipelineCache)
{
    auto &stage = pCreateInfo->stage;
    const ShaderModule *module = vk::Cast(stage.module);

    ASSERT(shader.get() == nullptr);
    ASSERT(program.get() == nullptr);

    const PipelineCache::SpirvShaderKey shaderKey(
        stage.stage, stage.pName, module->getCode(), nullptr, 0, stage.pSpecializationInfo);
    if(pPipelineCache)
    {
        PipelineCache& pipelineCache = *pPipelineCache;
        {
            std::unique_lock<std::mutex> lock(pipelineCache.getShaderMutex());
            const std::shared_ptr<sw::SpirvShader>* spirvShader = pipelineCache[shaderKey];
            if(!spirvShader)
            {
                shader = createShader(shaderKey, module, robustBufferAccess);
                pipelineCache.insert(shaderKey, shader);
            }
            else
            {
                shader = *spirvShader;
            }
        }

        {
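            // The generated compute program is cached separately from the
            // SPIR-V shader, keyed on the (shader, pipeline layout) pair.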
            const PipelineCache::ComputeProgramKey programKey(shader.get(), layout);
            std::unique_lock<std::mutex> lock(pipelineCache.getProgramMutex());
            const std::shared_ptr<sw::ComputeProgram>* computeProgram = pipelineCache[programKey];
            if(!computeProgram)
            {
                program = createProgram(programKey);
                pipelineCache.insert(programKey, program);
            }
            else
            {
                program = *computeProgram;
            }
        }
    }
    else
    {
        shader = createShader(shaderKey, module, robustBufferAccess);
        const PipelineCache::ComputeProgramKey programKey(shader.get(), layout);
        program = createProgram(programKey);
    }
}

void ComputePipeline::run(uint32_t baseGroupX, uint32_t baseGroupY, uint32_t baseGroupZ,
    uint32_t groupCountX, uint32_t groupCountY, uint32_t groupCountZ,
    vk::DescriptorSet::Bindings const &descriptorSets,
    vk::DescriptorSet::DynamicOffsets const &descriptorDynamicOffsets,
    sw::PushConstantStorage const &pushConstants)
{
    ASSERT_OR_RETURN(program != nullptr);
    program->run(
        descriptorSets, descriptorDynamicOffsets, pushConstants,
        baseGroupX, baseGroupY, baseGroupZ,
        groupCountX, groupCountY, groupCountZ);
}

}  // namespace vk