// Copyright 2018 The SwiftShader Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "VkPipeline.hpp"

#include "VkDevice.hpp"
#include "VkPipelineCache.hpp"
#include "VkPipelineLayout.hpp"
#include "VkShaderModule.hpp"
#include "VkRenderPass.hpp"
#include "Pipeline/ComputeProgram.hpp"
#include "Pipeline/SpirvShader.hpp"

#include "marl/trace.h"

#include "spirv-tools/optimizer.hpp"

#include <iostream>

namespace
{

sw::StreamType getStreamType(VkFormat format)
{
	switch(format)
	{
	case VK_FORMAT_R8_UNORM:
	case VK_FORMAT_R8G8_UNORM:
	case VK_FORMAT_R8G8B8A8_UNORM:
	case VK_FORMAT_R8_UINT:
	case VK_FORMAT_R8G8_UINT:
	case VK_FORMAT_R8G8B8A8_UINT:
	case VK_FORMAT_A8B8G8R8_UNORM_PACK32:
	case VK_FORMAT_A8B8G8R8_UINT_PACK32:
		return sw::STREAMTYPE_BYTE;
	case VK_FORMAT_B8G8R8A8_UNORM:
		return sw::STREAMTYPE_COLOR;
	case VK_FORMAT_R8_SNORM:
	case VK_FORMAT_R8_SINT:
	case VK_FORMAT_R8G8_SNORM:
	case VK_FORMAT_R8G8_SINT:
	case VK_FORMAT_R8G8B8A8_SNORM:
	case VK_FORMAT_R8G8B8A8_SINT:
	case VK_FORMAT_A8B8G8R8_SNORM_PACK32:
	case VK_FORMAT_A8B8G8R8_SINT_PACK32:
		return sw::STREAMTYPE_SBYTE;
	case VK_FORMAT_A2B10G10R10_UNORM_PACK32:
		return sw::STREAMTYPE_2_10_10_10_UINT;
	case VK_FORMAT_R16_UNORM:
	case VK_FORMAT_R16_UINT:
	case VK_FORMAT_R16G16_UNORM:
	case VK_FORMAT_R16G16_UINT:
	case VK_FORMAT_R16G16B16A16_UNORM:
	case VK_FORMAT_R16G16B16A16_UINT:
		return sw::STREAMTYPE_USHORT;
	case VK_FORMAT_R16_SNORM:
	case VK_FORMAT_R16_SINT:
	case VK_FORMAT_R16G16_SNORM:
	case VK_FORMAT_R16G16_SINT:
	case VK_FORMAT_R16G16B16A16_SNORM:
	case VK_FORMAT_R16G16B16A16_SINT:
		return sw::STREAMTYPE_SHORT;
	case VK_FORMAT_R16_SFLOAT:
	case VK_FORMAT_R16G16_SFLOAT:
	case VK_FORMAT_R16G16B16A16_SFLOAT:
		return sw::STREAMTYPE_HALF;
	case VK_FORMAT_R32_UINT:
	case VK_FORMAT_R32G32_UINT:
	case VK_FORMAT_R32G32B32_UINT:
	case VK_FORMAT_R32G32B32A32_UINT:
		return sw::STREAMTYPE_UINT;
	case VK_FORMAT_R32_SINT:
	case VK_FORMAT_R32G32_SINT:
	case VK_FORMAT_R32G32B32_SINT:
	case VK_FORMAT_R32G32B32A32_SINT:
		return sw::STREAMTYPE_INT;
	case VK_FORMAT_R32_SFLOAT:
	case VK_FORMAT_R32G32_SFLOAT:
	case VK_FORMAT_R32G32B32_SFLOAT:
	case VK_FORMAT_R32G32B32A32_SFLOAT:
		return sw::STREAMTYPE_FLOAT;
	default:
		UNIMPLEMENTED("format");
	}

	return sw::STREAMTYPE_BYTE;
}

unsigned char getNumberOfChannels(VkFormat format)
{
	switch(format)
	{
	case VK_FORMAT_R8_UNORM:
	case VK_FORMAT_R8_SNORM:
	case VK_FORMAT_R8_UINT:
	case VK_FORMAT_R8_SINT:
	case VK_FORMAT_R16_UNORM:
	case VK_FORMAT_R16_SNORM:
	case VK_FORMAT_R16_UINT:
	case VK_FORMAT_R16_SINT:
	case VK_FORMAT_R16_SFLOAT:
	case VK_FORMAT_R32_UINT:
	case VK_FORMAT_R32_SINT:
	case VK_FORMAT_R32_SFLOAT:
		return 1;
	case VK_FORMAT_R8G8_UNORM:
	case VK_FORMAT_R8G8_SNORM:
	case VK_FORMAT_R8G8_UINT:
	case VK_FORMAT_R8G8_SINT:
	case VK_FORMAT_R16G16_UNORM:
	case VK_FORMAT_R16G16_SNORM:
	case VK_FORMAT_R16G16_UINT:
	case VK_FORMAT_R16G16_SINT:
	case VK_FORMAT_R16G16_SFLOAT:
	case VK_FORMAT_R32G32_UINT:
	case VK_FORMAT_R32G32_SINT:
	case VK_FORMAT_R32G32_SFLOAT:
		return 2;
	case VK_FORMAT_R32G32B32_UINT:
	case VK_FORMAT_R32G32B32_SINT:
	case VK_FORMAT_R32G32B32_SFLOAT:
		return 3;
	case VK_FORMAT_R8G8B8A8_UNORM:
	case VK_FORMAT_R8G8B8A8_SNORM:
	case VK_FORMAT_R8G8B8A8_UINT:
	case VK_FORMAT_R8G8B8A8_SINT:
	case VK_FORMAT_B8G8R8A8_UNORM:
	case VK_FORMAT_A8B8G8R8_UNORM_PACK32:
	case VK_FORMAT_A8B8G8R8_SNORM_PACK32:
	case VK_FORMAT_A8B8G8R8_UINT_PACK32:
	case VK_FORMAT_A8B8G8R8_SINT_PACK32:
	case VK_FORMAT_A2B10G10R10_UNORM_PACK32:
	case VK_FORMAT_R16G16B16A16_UNORM:
	case VK_FORMAT_R16G16B16A16_SNORM:
	case VK_FORMAT_R16G16B16A16_UINT:
	case VK_FORMAT_R16G16B16A16_SINT:
	case VK_FORMAT_R16G16B16A16_SFLOAT:
	case VK_FORMAT_R32G32B32A32_UINT:
	case VK_FORMAT_R32G32B32A32_SINT:
	case VK_FORMAT_R32G32B32A32_SFLOAT:
		return 4;
	default:
		UNIMPLEMENTED("format");
	}

	return 0;
}

// preprocessSpirv applies and freezes specializations into constants, and inlines all functions.
std::vector<uint32_t> preprocessSpirv(
	std::vector<uint32_t> const &code,
	VkSpecializationInfo const *specializationInfo)
{
	spvtools::Optimizer opt{SPV_ENV_VULKAN_1_1};

	opt.SetMessageConsumer([](spv_message_level_t level, const char*, const spv_position_t& p, const char* m) {
		switch (level)
		{
		case SPV_MSG_FATAL:          vk::warn("SPIR-V FATAL: %d:%d %s\n", int(p.line), int(p.column), m); break;
		case SPV_MSG_INTERNAL_ERROR: vk::warn("SPIR-V INTERNAL_ERROR: %d:%d %s\n", int(p.line), int(p.column), m); break;
		case SPV_MSG_ERROR:          vk::warn("SPIR-V ERROR: %d:%d %s\n", int(p.line), int(p.column), m); break;
		case SPV_MSG_WARNING:        vk::warn("SPIR-V WARNING: %d:%d %s\n", int(p.line), int(p.column), m); break;
		case SPV_MSG_INFO:           vk::trace("SPIR-V INFO: %d:%d %s\n", int(p.line), int(p.column), m); break;
		case SPV_MSG_DEBUG:          vk::trace("SPIR-V DEBUG: %d:%d %s\n", int(p.line), int(p.column), m); break;
		default:                     vk::trace("SPIR-V MESSAGE: %d:%d %s\n", int(p.line), int(p.column), m); break;
		}
	});

	// If the pipeline uses specialization constants, apply them before freezing.
	if (specializationInfo)
	{
		std::unordered_map<uint32_t, std::vector<uint32_t>> specializations;
		for (auto i = 0u; i < specializationInfo->mapEntryCount; ++i)
		{
			auto const &e = specializationInfo->pMapEntries[i];
			auto value_ptr =
				static_cast<uint32_t const *>(specializationInfo->pData) + e.offset / sizeof(uint32_t);
			specializations.emplace(e.constantID,
				std::vector<uint32_t>{value_ptr, value_ptr + e.size / sizeof(uint32_t)});
		}
		opt.RegisterPass(spvtools::CreateSetSpecConstantDefaultValuePass(specializations));
	}

	// Full optimization list taken from spirv-opt.
	opt.RegisterPass(spvtools::CreateWrapOpKillPass())
		.RegisterPass(spvtools::CreateDeadBranchElimPass())
		.RegisterPass(spvtools::CreateMergeReturnPass())
		.RegisterPass(spvtools::CreateInlineExhaustivePass())
		.RegisterPass(spvtools::CreateAggressiveDCEPass())
		.RegisterPass(spvtools::CreatePrivateToLocalPass())
		.RegisterPass(spvtools::CreateLocalSingleBlockLoadStoreElimPass())
		.RegisterPass(spvtools::CreateLocalSingleStoreElimPass())
		.RegisterPass(spvtools::CreateAggressiveDCEPass())
		.RegisterPass(spvtools::CreateScalarReplacementPass())
		.RegisterPass(spvtools::CreateLocalAccessChainConvertPass())
		.RegisterPass(spvtools::CreateLocalSingleBlockLoadStoreElimPass())
		.RegisterPass(spvtools::CreateLocalSingleStoreElimPass())
		.RegisterPass(spvtools::CreateAggressiveDCEPass())
		.RegisterPass(spvtools::CreateLocalMultiStoreElimPass())
		.RegisterPass(spvtools::CreateAggressiveDCEPass())
		.RegisterPass(spvtools::CreateCCPPass())
		.RegisterPass(spvtools::CreateAggressiveDCEPass())
		.RegisterPass(spvtools::CreateRedundancyEliminationPass())
		.RegisterPass(spvtools::CreateCombineAccessChainsPass())
		.RegisterPass(spvtools::CreateSimplificationPass())
		.RegisterPass(spvtools::CreateVectorDCEPass())
		.RegisterPass(spvtools::CreateDeadInsertElimPass())
		.RegisterPass(spvtools::CreateDeadBranchElimPass())
		.RegisterPass(spvtools::CreateSimplificationPass())
		.RegisterPass(spvtools::CreateIfConversionPass())
		.RegisterPass(spvtools::CreateCopyPropagateArraysPass())
		.RegisterPass(spvtools::CreateReduceLoadSizePass())
		.RegisterPass(spvtools::CreateAggressiveDCEPass())
		.RegisterPass(spvtools::CreateBlockMergePass())
		.RegisterPass(spvtools::CreateRedundancyEliminationPass())
		.RegisterPass(spvtools::CreateDeadBranchElimPass())
		.RegisterPass(spvtools::CreateBlockMergePass())
		.RegisterPass(spvtools::CreateSimplificationPass());

	std::vector<uint32_t> optimized;
	opt.Run(code.data(), code.size(), &optimized);

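	// The block below is disabled by default; flip the condition to dump the
	// disassembled SPIR-V before and after optimization when debugging.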
	if (false) {
		spvtools::SpirvTools core(SPV_ENV_VULKAN_1_1);
		std::string preOpt;
		core.Disassemble(code, &preOpt, SPV_BINARY_TO_TEXT_OPTION_NONE);
		std::string postOpt;
		core.Disassemble(optimized, &postOpt, SPV_BINARY_TO_TEXT_OPTION_NONE);
		std::cout << "PRE-OPT: " << preOpt << std::endl
			<< "POST-OPT: " << postOpt << std::endl;
	}

	return optimized;
}

std::shared_ptr<sw::SpirvShader> createShader(const vk::PipelineCache::SpirvShaderKey& key, const vk::ShaderModule *module, bool robustBufferAccess)
{
	auto code = preprocessSpirv(key.getInsns(), key.getSpecializationInfo());
	ASSERT(code.size() > 0);

	// If the pipeline has specialization constants, assume they're unique and
	// use a new serial ID so the shader gets recompiled.
	uint32_t codeSerialID = (key.getSpecializationInfo() ? vk::ShaderModule::nextSerialID() : module->getSerialID());

	// TODO(b/119409619): use allocator.
	return std::make_shared<sw::SpirvShader>(codeSerialID, key.getPipelineStage(), key.getEntryPointName().c_str(),
		code, key.getRenderPass(), key.getSubpassIndex(), robustBufferAccess);
}

std::shared_ptr<sw::ComputeProgram> createProgram(const vk::PipelineCache::ComputeProgramKey& key)
{
	MARL_SCOPED_EVENT("createProgram");

	vk::DescriptorSet::Bindings descriptorSets; // FIXME(b/129523279): Delay code generation until invoke time.
	// TODO(b/119409619): use allocator.
	auto program = std::make_shared<sw::ComputeProgram>(key.getShader(), key.getLayout(), descriptorSets);
	program->generate();
	program->finalize();
	return program;
}

} // anonymous namespace

namespace vk
{

Pipeline::Pipeline(PipelineLayout const *layout, const Device *device)
	: layout(layout),
	  robustBufferAccess(device->getEnabledFeatures().robustBufferAccess)
{
}

GraphicsPipeline::GraphicsPipeline(const VkGraphicsPipelineCreateInfo* pCreateInfo, void* mem, const Device *device)
	: Pipeline(vk::Cast(pCreateInfo->layout), device)
{
	context.robustBufferAccess = robustBufferAccess;

	if(((pCreateInfo->flags &
		~(VK_PIPELINE_CREATE_DISABLE_OPTIMIZATION_BIT |
		  VK_PIPELINE_CREATE_DERIVATIVE_BIT |
		  VK_PIPELINE_CREATE_ALLOW_DERIVATIVES_BIT)) != 0) ||
	   (pCreateInfo->pTessellationState != nullptr))
	{
		UNIMPLEMENTED("pCreateInfo settings");
	}

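	// Each supported dynamic state is recorded as a bit in dynamicStateFlags;
	// hasDynamicState() consults these bits when deciding whether to take static
	// values from the create info or defer to values set at command-buffer time.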
	if(pCreateInfo->pDynamicState)
	{
		for(uint32_t i = 0; i < pCreateInfo->pDynamicState->dynamicStateCount; i++)
		{
			VkDynamicState dynamicState = pCreateInfo->pDynamicState->pDynamicStates[i];
			switch(dynamicState)
			{
			case VK_DYNAMIC_STATE_VIEWPORT:
			case VK_DYNAMIC_STATE_SCISSOR:
			case VK_DYNAMIC_STATE_LINE_WIDTH:
			case VK_DYNAMIC_STATE_DEPTH_BIAS:
			case VK_DYNAMIC_STATE_BLEND_CONSTANTS:
			case VK_DYNAMIC_STATE_DEPTH_BOUNDS:
			case VK_DYNAMIC_STATE_STENCIL_COMPARE_MASK:
			case VK_DYNAMIC_STATE_STENCIL_WRITE_MASK:
			case VK_DYNAMIC_STATE_STENCIL_REFERENCE:
				ASSERT(dynamicState < (sizeof(dynamicStateFlags) * 8));
				dynamicStateFlags |= (1 << dynamicState);
				break;
			default:
				UNIMPLEMENTED("dynamic state");
			}
		}
	}

	const VkPipelineVertexInputStateCreateInfo* vertexInputState = pCreateInfo->pVertexInputState;
	if(vertexInputState->flags != 0)
	{
		UNIMPLEMENTED("vertexInputState->flags");
	}

	// Context must always have a PipelineLayout set.
	context.pipelineLayout = layout;

	// Temporary in-binding-order representation of buffer strides, to be consumed below
	// when considering attributes. TODO: unfuse buffers from attributes in the backend;
	// the fused representation is a holdover from the old GL model.
	uint32_t vertexStrides[MAX_VERTEX_INPUT_BINDINGS];
	uint32_t instanceStrides[MAX_VERTEX_INPUT_BINDINGS];
	for(uint32_t i = 0; i < vertexInputState->vertexBindingDescriptionCount; i++)
	{
		auto const & desc = vertexInputState->pVertexBindingDescriptions[i];
		vertexStrides[desc.binding] = desc.inputRate == VK_VERTEX_INPUT_RATE_VERTEX ? desc.stride : 0;
		instanceStrides[desc.binding] = desc.inputRate == VK_VERTEX_INPUT_RATE_INSTANCE ? desc.stride : 0;
	}

	for(uint32_t i = 0; i < vertexInputState->vertexAttributeDescriptionCount; i++)
	{
		auto const & desc = vertexInputState->pVertexAttributeDescriptions[i];
		sw::Stream& input = context.input[desc.location];
		input.count = getNumberOfChannels(desc.format);
		input.type = getStreamType(desc.format);
		input.normalized = !vk::Format(desc.format).isNonNormalizedInteger();
		input.offset = desc.offset;
		input.binding = desc.binding;
		input.vertexStride = vertexStrides[desc.binding];
		input.instanceStride = instanceStrides[desc.binding];
	}

	const VkPipelineInputAssemblyStateCreateInfo* assemblyState = pCreateInfo->pInputAssemblyState;
	if(assemblyState->flags != 0)
	{
		UNIMPLEMENTED("pCreateInfo->pInputAssemblyState settings");
	}

	primitiveRestartEnable = (assemblyState->primitiveRestartEnable != VK_FALSE);
	context.topology = assemblyState->topology;

	const VkPipelineViewportStateCreateInfo* viewportState = pCreateInfo->pViewportState;
	if(viewportState)
	{
		if((viewportState->flags != 0) ||
		   (viewportState->viewportCount != 1) ||
		   (viewportState->scissorCount != 1))
		{
			UNIMPLEMENTED("pCreateInfo->pViewportState settings");
		}

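		// Static viewport and scissor are only captured here when they are not
		// declared dynamic; dynamic values are supplied later via command buffer state.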
		if(!hasDynamicState(VK_DYNAMIC_STATE_SCISSOR))
		{
			scissor = viewportState->pScissors[0];
		}

		if(!hasDynamicState(VK_DYNAMIC_STATE_VIEWPORT))
		{
			viewport = viewportState->pViewports[0];
		}
	}

	const VkPipelineRasterizationStateCreateInfo* rasterizationState = pCreateInfo->pRasterizationState;
	if((rasterizationState->flags != 0) ||
	   (rasterizationState->depthClampEnable != VK_FALSE))
	{
		UNIMPLEMENTED("pCreateInfo->pRasterizationState settings");
	}

	context.rasterizerDiscard = (rasterizationState->rasterizerDiscardEnable == VK_TRUE);
	context.cullMode = rasterizationState->cullMode;
	context.frontFace = rasterizationState->frontFace;
	context.polygonMode = rasterizationState->polygonMode;
	context.depthBias = (rasterizationState->depthBiasEnable != VK_FALSE) ? rasterizationState->depthBiasConstantFactor : 0.0f;
	context.slopeDepthBias = (rasterizationState->depthBiasEnable != VK_FALSE) ? rasterizationState->depthBiasSlopeFactor : 0.0f;

	const VkBaseInStructure* extensionCreateInfo = reinterpret_cast<const VkBaseInStructure*>(rasterizationState->pNext);
	while(extensionCreateInfo)
	{
		// Casting to a long since some structures, such as
		// VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROVOKING_VERTEX_FEATURES_EXT,
		// are not enumerated in the official Vulkan header.
		switch((long)(extensionCreateInfo->sType))
		{
		case VK_STRUCTURE_TYPE_PIPELINE_RASTERIZATION_LINE_STATE_CREATE_INFO_EXT:
		{
			const VkPipelineRasterizationLineStateCreateInfoEXT* lineStateCreateInfo = reinterpret_cast<const VkPipelineRasterizationLineStateCreateInfoEXT*>(extensionCreateInfo);
			context.lineRasterizationMode = lineStateCreateInfo->lineRasterizationMode;
		}
		break;
		case VK_STRUCTURE_TYPE_PIPELINE_RASTERIZATION_PROVOKING_VERTEX_STATE_CREATE_INFO_EXT:
		{
			const VkPipelineRasterizationProvokingVertexStateCreateInfoEXT* provokingVertexModeCreateInfo =
				reinterpret_cast<const VkPipelineRasterizationProvokingVertexStateCreateInfoEXT*>(extensionCreateInfo);
			context.provokingVertexMode = provokingVertexModeCreateInfo->provokingVertexMode;
		}
		break;
		default:
			UNIMPLEMENTED("extensionCreateInfo->sType");
			break;
		}

		extensionCreateInfo = extensionCreateInfo->pNext;
	}

	const VkPipelineMultisampleStateCreateInfo* multisampleState = pCreateInfo->pMultisampleState;
	if(multisampleState)
	{
		switch (multisampleState->rasterizationSamples)
		{
		case VK_SAMPLE_COUNT_1_BIT:
			context.sampleCount = 1;
			break;
		case VK_SAMPLE_COUNT_4_BIT:
			context.sampleCount = 4;
			break;
		default:
			UNIMPLEMENTED("Unsupported sample count");
		}

		if (multisampleState->pSampleMask)
		{
			context.sampleMask = multisampleState->pSampleMask[0];
		}

		context.alphaToCoverage = (multisampleState->alphaToCoverageEnable == VK_TRUE);

		if((multisampleState->flags != 0) ||
		   (multisampleState->sampleShadingEnable != VK_FALSE) ||
		   (multisampleState->alphaToOneEnable != VK_FALSE))
		{
			UNIMPLEMENTED("multisampleState");
		}
	}
	else
	{
		context.sampleCount = 1;
	}

	const VkPipelineDepthStencilStateCreateInfo* depthStencilState = pCreateInfo->pDepthStencilState;
	if(depthStencilState)
	{
		if((depthStencilState->flags != 0) ||
		   (depthStencilState->depthBoundsTestEnable != VK_FALSE))
		{
			UNIMPLEMENTED("depthStencilState");
		}

		context.depthBoundsTestEnable = (depthStencilState->depthBoundsTestEnable == VK_TRUE);
		context.depthBufferEnable = (depthStencilState->depthTestEnable == VK_TRUE);
		context.depthWriteEnable = (depthStencilState->depthWriteEnable == VK_TRUE);
		context.depthCompareMode = depthStencilState->depthCompareOp;

		context.stencilEnable = (depthStencilState->stencilTestEnable == VK_TRUE);
		if(context.stencilEnable)
		{
			context.frontStencil = depthStencilState->front;
			context.backStencil = depthStencilState->back;
		}
	}

	const VkPipelineColorBlendStateCreateInfo* colorBlendState = pCreateInfo->pColorBlendState;
	if(colorBlendState)
	{
		if((colorBlendState->flags != 0) ||
		   (colorBlendState->logicOpEnable != VK_FALSE))
		{
			UNIMPLEMENTED("colorBlendState");
		}

		if(!hasDynamicState(VK_DYNAMIC_STATE_BLEND_CONSTANTS))
		{
			blendConstants.r = colorBlendState->blendConstants[0];
			blendConstants.g = colorBlendState->blendConstants[1];
			blendConstants.b = colorBlendState->blendConstants[2];
			blendConstants.a = colorBlendState->blendConstants[3];
		}

		for (auto i = 0u; i < colorBlendState->attachmentCount; i++)
		{
			const VkPipelineColorBlendAttachmentState& attachment = colorBlendState->pAttachments[i];
			context.colorWriteMask[i] = attachment.colorWriteMask;

			context.setBlendState(i, { (attachment.blendEnable == VK_TRUE),
				attachment.srcColorBlendFactor, attachment.dstColorBlendFactor, attachment.colorBlendOp,
				attachment.srcAlphaBlendFactor, attachment.dstAlphaBlendFactor, attachment.alphaBlendOp });
		}
	}

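	// Restrict the application-provided sample mask to the low bits covered by
	// the pipeline's sample count.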
	context.multiSampleMask = context.sampleMask & ((unsigned) 0xFFFFFFFF >> (32 - context.sampleCount));
}

void GraphicsPipeline::destroyPipeline(const VkAllocationCallbacks* pAllocator)
{
	vertexShader.reset();
	fragmentShader.reset();
}

size_t GraphicsPipeline::ComputeRequiredAllocationSize(const VkGraphicsPipelineCreateInfo* pCreateInfo)
{
	return 0;
}

void GraphicsPipeline::setShader(const VkShaderStageFlagBits& stage, const std::shared_ptr<sw::SpirvShader> spirvShader)
{
	switch(stage)
	{
	case VK_SHADER_STAGE_VERTEX_BIT:
		ASSERT(vertexShader.get() == nullptr);
		vertexShader = spirvShader;
		context.vertexShader = vertexShader.get();
		break;

	case VK_SHADER_STAGE_FRAGMENT_BIT:
		ASSERT(fragmentShader.get() == nullptr);
		fragmentShader = spirvShader;
		context.pixelShader = fragmentShader.get();
		break;

	default:
		UNSUPPORTED("Unsupported stage");
		break;
	}
}

const std::shared_ptr<sw::SpirvShader> GraphicsPipeline::getShader(const VkShaderStageFlagBits& stage) const
{
	switch(stage)
	{
	case VK_SHADER_STAGE_VERTEX_BIT:
		return vertexShader;
	case VK_SHADER_STAGE_FRAGMENT_BIT:
		return fragmentShader;
	default:
		UNSUPPORTED("Unsupported stage");
		return fragmentShader;
	}
}

void GraphicsPipeline::compileShaders(const VkAllocationCallbacks* pAllocator, const VkGraphicsPipelineCreateInfo* pCreateInfo, PipelineCache* pPipelineCache)
{
	for (auto pStage = pCreateInfo->pStages; pStage != pCreateInfo->pStages + pCreateInfo->stageCount; pStage++)
	{
		if (pStage->flags != 0)
		{
			UNIMPLEMENTED("pStage->flags");
		}

		const ShaderModule *module = vk::Cast(pStage->module);
		const PipelineCache::SpirvShaderKey key(pStage->stage, pStage->pName, module->getCode(),
			vk::Cast(pCreateInfo->renderPass), pCreateInfo->subpass,
			pStage->pSpecializationInfo);
		auto pipelineStage = key.getPipelineStage();

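		// When a pipeline cache is provided, look the shader up under the cache's
		// shader mutex and only compile (and insert) it on a miss.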
		if(pPipelineCache)
		{
			PipelineCache& pipelineCache = *pPipelineCache;
			{
				std::unique_lock<std::mutex> lock(pipelineCache.getShaderMutex());
				const std::shared_ptr<sw::SpirvShader>* spirvShader = pipelineCache[key];
				if(!spirvShader)
				{
					auto shader = createShader(key, module, robustBufferAccess);
					setShader(pipelineStage, shader);
					pipelineCache.insert(key, getShader(pipelineStage));
				}
				else
				{
					setShader(pipelineStage, *spirvShader);
				}
			}
		}
		else
		{
			auto shader = createShader(key, module, robustBufferAccess);
			setShader(pipelineStage, shader);
		}
	}
}

uint32_t GraphicsPipeline::computePrimitiveCount(uint32_t vertexCount) const
{
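	// List topologies yield one primitive per fixed-size group of vertices;
	// strips and fans yield one primitive per vertex beyond the first (lines)
	// or first two (triangles), clamped so the count never underflows.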
	switch(context.topology)
	{
	case VK_PRIMITIVE_TOPOLOGY_POINT_LIST:
		return vertexCount;
	case VK_PRIMITIVE_TOPOLOGY_LINE_LIST:
		return vertexCount / 2;
	case VK_PRIMITIVE_TOPOLOGY_LINE_STRIP:
		return std::max<uint32_t>(vertexCount, 1) - 1;
	case VK_PRIMITIVE_TOPOLOGY_TRIANGLE_LIST:
		return vertexCount / 3;
	case VK_PRIMITIVE_TOPOLOGY_TRIANGLE_STRIP:
		return std::max<uint32_t>(vertexCount, 2) - 2;
	case VK_PRIMITIVE_TOPOLOGY_TRIANGLE_FAN:
		return std::max<uint32_t>(vertexCount, 2) - 2;
	default:
		UNIMPLEMENTED("context.topology %d", int(context.topology));
	}

	return 0;
}

const sw::Context& GraphicsPipeline::getContext() const
{
	return context;
}

const VkRect2D& GraphicsPipeline::getScissor() const
{
	return scissor;
}

const VkViewport& GraphicsPipeline::getViewport() const
{
	return viewport;
}

const sw::Color<float>& GraphicsPipeline::getBlendConstants() const
{
	return blendConstants;
}

bool GraphicsPipeline::hasDynamicState(VkDynamicState dynamicState) const
{
	return (dynamicStateFlags & (1 << dynamicState)) != 0;
}

ComputePipeline::ComputePipeline(const VkComputePipelineCreateInfo* pCreateInfo, void* mem, const Device *device)
	: Pipeline(vk::Cast(pCreateInfo->layout), device)
{
}

void ComputePipeline::destroyPipeline(const VkAllocationCallbacks* pAllocator)
{
	shader.reset();
	program.reset();
}

size_t ComputePipeline::ComputeRequiredAllocationSize(const VkComputePipelineCreateInfo* pCreateInfo)
{
	return 0;
}

void ComputePipeline::compileShaders(const VkAllocationCallbacks* pAllocator, const VkComputePipelineCreateInfo* pCreateInfo, PipelineCache* pPipelineCache)
{
	auto &stage = pCreateInfo->stage;
	const ShaderModule *module = vk::Cast(stage.module);

	ASSERT(shader.get() == nullptr);
	ASSERT(program.get() == nullptr);

	const PipelineCache::SpirvShaderKey shaderKey(
		stage.stage, stage.pName, module->getCode(), nullptr, 0, stage.pSpecializationInfo);
	if(pPipelineCache)
	{
		PipelineCache& pipelineCache = *pPipelineCache;
		{
			std::unique_lock<std::mutex> lock(pipelineCache.getShaderMutex());
			const std::shared_ptr<sw::SpirvShader>* spirvShader = pipelineCache[shaderKey];
			if(!spirvShader)
			{
				shader = createShader(shaderKey, module, robustBufferAccess);
				pipelineCache.insert(shaderKey, shader);
			}
			else
			{
				shader = *spirvShader;
			}
		}

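		// Compute pipelines also cache the generated ComputeProgram, keyed on the
		// shader and pipeline layout, under a separate program mutex.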
		{
			const PipelineCache::ComputeProgramKey programKey(shader.get(), layout);
			std::unique_lock<std::mutex> lock(pipelineCache.getProgramMutex());
			const std::shared_ptr<sw::ComputeProgram>* computeProgram = pipelineCache[programKey];
			if(!computeProgram)
			{
				program = createProgram(programKey);
				pipelineCache.insert(programKey, program);
			}
			else
			{
				program = *computeProgram;
			}
		}
	}
	else
	{
		shader = createShader(shaderKey, module, robustBufferAccess);
		const PipelineCache::ComputeProgramKey programKey(shader.get(), layout);
		program = createProgram(programKey);
	}
}

void ComputePipeline::run(uint32_t baseGroupX, uint32_t baseGroupY, uint32_t baseGroupZ,
	uint32_t groupCountX, uint32_t groupCountY, uint32_t groupCountZ,
	vk::DescriptorSet::Bindings const &descriptorSets,
	vk::DescriptorSet::DynamicOffsets const &descriptorDynamicOffsets,
	sw::PushConstantStorage const &pushConstants)
{
	ASSERT_OR_RETURN(program != nullptr);
	program->run(
		descriptorSets, descriptorDynamicOffsets, pushConstants,
		baseGroupX, baseGroupY, baseGroupZ,
		groupCountX, groupCountY, groupCountZ);
}

} // namespace vk