1/*
2 * Copyright 2016 Google Inc.
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8#ifndef SKSL_SPIRVCODEGENERATOR
9#define SKSL_SPIRVCODEGENERATOR
10
#include <cstdint>
#include <cstring>
#include <stack>
#include <tuple>
#include <unordered_map>

#include "src/sksl/SkSLCodeGenerator.h"
#include "src/sksl/SkSLMemoryLayout.h"
#include "src/sksl/SkSLStringStream.h"
#include "src/sksl/ir/SkSLBinaryExpression.h"
#include "src/sksl/ir/SkSLBoolLiteral.h"
#include "src/sksl/ir/SkSLConstructor.h"
#include "src/sksl/ir/SkSLDoStatement.h"
#include "src/sksl/ir/SkSLFieldAccess.h"
#include "src/sksl/ir/SkSLFloatLiteral.h"
#include "src/sksl/ir/SkSLForStatement.h"
#include "src/sksl/ir/SkSLFunctionCall.h"
#include "src/sksl/ir/SkSLFunctionDeclaration.h"
#include "src/sksl/ir/SkSLFunctionDefinition.h"
#include "src/sksl/ir/SkSLIfStatement.h"
#include "src/sksl/ir/SkSLIndexExpression.h"
#include "src/sksl/ir/SkSLIntLiteral.h"
#include "src/sksl/ir/SkSLInterfaceBlock.h"
#include "src/sksl/ir/SkSLPostfixExpression.h"
#include "src/sksl/ir/SkSLPrefixExpression.h"
#include "src/sksl/ir/SkSLProgramElement.h"
#include "src/sksl/ir/SkSLReturnStatement.h"
#include "src/sksl/ir/SkSLStatement.h"
#include "src/sksl/ir/SkSLSwitchStatement.h"
#include "src/sksl/ir/SkSLSwizzle.h"
#include "src/sksl/ir/SkSLTernaryExpression.h"
#include "src/sksl/ir/SkSLVarDeclarations.h"
#include "src/sksl/ir/SkSLVarDeclarationsStatement.h"
#include "src/sksl/ir/SkSLVariableReference.h"
#include "src/sksl/ir/SkSLWhileStatement.h"
#include "src/sksl/spirv.h"
45
/**
 * The 64-bit payload of a numeric constant: either an integer or a double sharing the same
 * storage. Paired with a ConstantType, it forms the key of SPIRVCodeGenerator::fNumberConstants.
 */
union ConstantValue {
    // Both members must span the whole union so that a bytewise comparison covers exactly the
    // stored value, no more and no less.
    static_assert(sizeof(int64_t) == sizeof(double), "ConstantValue members must be 8 bytes");

    ConstantValue(int64_t i)
        : fInt(i) {}

    ConstantValue(double d)
        : fDouble(d) {}

    /**
     * Bitwise equality. This compares the union's object representation instead of reading
     * fInt. Reading fInt would access an inactive member (undefined behavior) whenever fDouble
     * was the member last written. The semantics match a raw-bits comparison: 0.0 and -0.0 are
     * distinct, and a NaN equals another NaN with the identical bit pattern.
     */
    bool operator==(const ConstantValue& other) const {
        return std::memcmp(this, &other, sizeof(ConstantValue)) == 0;
    }

    int64_t fInt;
    double fDouble;
};
60
/**
 * Numeric type tag that accompanies a ConstantValue in the key of
 * SPIRVCodeGenerator::fNumberConstants, so that identical bit patterns of different types map to
 * different constant ids. Values are spelled out explicitly; they match the implicit numbering.
 */
enum class ConstantType {
    kInt    = 0,
    kUInt   = 1,
    kShort  = 2,
    kUShort = 3,
    kFloat  = 4,
    kDouble = 5,
    kHalf   = 6,
};
70
71namespace std {
72
73template <>
74struct hash<std::pair<ConstantValue, ConstantType>> {
75 size_t operator()(const std::pair<ConstantValue, ConstantType>& key) const {
76 return key.first.fInt ^ (int) key.second;
77 }
78};
79
80}
81
82namespace SkSL {
83
// Alias naming SpvCapabilityMultiViewport as the "last" capability. NOTE(review): presumably used
// in the .cpp as the upper bound when scanning the uint64_t fCapabilities set — confirm there.
// Being a macro, this name is NOT scoped to namespace SkSL even though it is defined inside it.
#define kLast_Capability SpvCapabilityMultiViewport
85
86/**
87 * Converts a Program into a SPIR-V binary.
88 */
89class SPIRVCodeGenerator : public CodeGenerator {
90public:
91 class LValue {
92 public:
93 virtual ~LValue() {}
94
95 // returns a pointer to the lvalue, if possible. If the lvalue cannot be directly referenced
96 // by a pointer (e.g. vector swizzles), returns 0.
97 virtual SpvId getPointer() = 0;
98
99 virtual SpvId load(OutputStream& out) = 0;
100
101 virtual void store(SpvId value, OutputStream& out) = 0;
102 };
103
104 SPIRVCodeGenerator(const Context* context, const Program* program, ErrorReporter* errors,
105 OutputStream* out)
106 : INHERITED(program, errors, out)
107 , fContext(*context)
108 , fDefaultLayout(MemoryLayout::k140_Standard)
109 , fCapabilities(0)
110 , fIdCount(1)
111 , fBoolTrue(0)
112 , fBoolFalse(0)
113 , fSetupFragPosition(false)
114 , fCurrentBlock(0)
115 , fSynthetics(nullptr, errors) {
116 this->setupIntrinsics();
117 }
118
119 bool generateCode() override;
120
121private:
122 enum IntrinsicKind {
123 kGLSL_STD_450_IntrinsicKind,
124 kSPIRV_IntrinsicKind,
125 kSpecial_IntrinsicKind
126 };
127
128 enum SpecialIntrinsic {
129 kAtan_SpecialIntrinsic,
130 kClamp_SpecialIntrinsic,
131 kMax_SpecialIntrinsic,
132 kMin_SpecialIntrinsic,
133 kMix_SpecialIntrinsic,
134 kMod_SpecialIntrinsic,
135 kDFdy_SpecialIntrinsic,
136 kSaturate_SpecialIntrinsic,
137 kSampledImage_SpecialIntrinsic,
138 kSubpassLoad_SpecialIntrinsic,
139 kTexture_SpecialIntrinsic,
140 };
141
142 enum class Precision {
143 kLow,
144 kHigh,
145 };
146
147 void setupIntrinsics();
148
149 SpvId nextId();
150
151 Type getActualType(const Type& type);
152
153 SpvId getType(const Type& type);
154
155 SpvId getType(const Type& type, const MemoryLayout& layout);
156
157 SpvId getImageType(const Type& type);
158
159 SpvId getFunctionType(const FunctionDeclaration& function);
160
161 SpvId getPointerType(const Type& type, SpvStorageClass_ storageClass);
162
163 SpvId getPointerType(const Type& type, const MemoryLayout& layout,
164 SpvStorageClass_ storageClass);
165
166 void writePrecisionModifier(Precision precision, SpvId id);
167
168 void writePrecisionModifier(const Type& type, SpvId id);
169
170 std::vector<SpvId> getAccessChain(const Expression& expr, OutputStream& out);
171
172 void writeLayout(const Layout& layout, SpvId target);
173
174 void writeLayout(const Layout& layout, SpvId target, int member);
175
176 void writeStruct(const Type& type, const MemoryLayout& layout, SpvId resultId);
177
178 void writeProgramElement(const ProgramElement& pe, OutputStream& out);
179
180 SpvId writeInterfaceBlock(const InterfaceBlock& intf);
181
182 SpvId writeFunctionStart(const FunctionDeclaration& f, OutputStream& out);
183
184 SpvId writeFunctionDeclaration(const FunctionDeclaration& f, OutputStream& out);
185
186 SpvId writeFunction(const FunctionDefinition& f, OutputStream& out);
187
188 void writeGlobalVars(Program::Kind kind, const VarDeclarations& v, OutputStream& out);
189
190 void writeVarDeclarations(const VarDeclarations& decl, OutputStream& out);
191
192 SpvId writeVariableReference(const VariableReference& ref, OutputStream& out);
193
194 std::unique_ptr<LValue> getLValue(const Expression& value, OutputStream& out);
195
196 SpvId writeExpression(const Expression& expr, OutputStream& out);
197
198 SpvId writeIntrinsicCall(const FunctionCall& c, OutputStream& out);
199
200 SpvId writeFunctionCall(const FunctionCall& c, OutputStream& out);
201
202
203 void writeGLSLExtendedInstruction(const Type& type, SpvId id, SpvId floatInst,
204 SpvId signedInst, SpvId unsignedInst,
205 const std::vector<SpvId>& args, OutputStream& out);
206
207 /**
208 * Given a list of potentially mixed scalars and vectors, promotes the scalars to match the
209 * size of the vectors and returns the ids of the written expressions. e.g. given (float, vec2),
210 * returns (vec2(float), vec2). It is an error to use mismatched vector sizes, e.g. (float,
211 * vec2, vec3).
212 */
213 std::vector<SpvId> vectorize(const std::vector<std::unique_ptr<Expression>>& args,
214 OutputStream& out);
215
216 SpvId writeSpecialIntrinsic(const FunctionCall& c, SpecialIntrinsic kind, OutputStream& out);
217
218 SpvId writeConstantVector(const Constructor& c);
219
220 SpvId writeFloatConstructor(const Constructor& c, OutputStream& out);
221
222 SpvId writeIntConstructor(const Constructor& c, OutputStream& out);
223
224 SpvId writeUIntConstructor(const Constructor& c, OutputStream& out);
225
226 /**
227 * Writes a matrix with the diagonal entries all equal to the provided expression, and all other
228 * entries equal to zero.
229 */
230 void writeUniformScaleMatrix(SpvId id, SpvId diagonal, const Type& type, OutputStream& out);
231
232 /**
233 * Writes a potentially-different-sized copy of a matrix. Entries which do not exist in the
234 * source matrix are filled with zero; entries which do not exist in the destination matrix are
235 * ignored.
236 */
237 void writeMatrixCopy(SpvId id, SpvId src, const Type& srcType, const Type& dstType,
238 OutputStream& out);
239
240 void addColumnEntry(SpvId columnType, Precision precision, std::vector<SpvId>* currentColumn,
241 std::vector<SpvId>* columnIds, int* currentCount, int rows, SpvId entry,
242 OutputStream& out);
243
244 SpvId writeMatrixConstructor(const Constructor& c, OutputStream& out);
245
246 SpvId writeVectorConstructor(const Constructor& c, OutputStream& out);
247
248 SpvId writeArrayConstructor(const Constructor& c, OutputStream& out);
249
250 SpvId writeConstructor(const Constructor& c, OutputStream& out);
251
252 SpvId writeFieldAccess(const FieldAccess& f, OutputStream& out);
253
254 SpvId writeSwizzle(const Swizzle& swizzle, OutputStream& out);
255
256 /**
257 * Folds the potentially-vector result of a logical operation down to a single bool. If
258 * operandType is a vector type, assumes that the intermediate result in id is a bvec of the
259 * same dimensions, and applys all() to it to fold it down to a single bool value. Otherwise,
260 * returns the original id value.
261 */
262 SpvId foldToBool(SpvId id, const Type& operandType, SpvOp op, OutputStream& out);
263
264 SpvId writeMatrixComparison(const Type& operandType, SpvId lhs, SpvId rhs, SpvOp_ floatOperator,
265 SpvOp_ intOperator, SpvOp_ vectorMergeOperator,
266 SpvOp_ mergeOperator, OutputStream& out);
267
268 SpvId writeComponentwiseMatrixBinary(const Type& operandType, SpvId lhs, SpvId rhs,
269 SpvOp_ floatOperator, SpvOp_ intOperator,
270 OutputStream& out);
271
272 SpvId writeBinaryOperation(const Type& resultType, const Type& operandType, SpvId lhs,
273 SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt, SpvOp_ ifUInt,
274 SpvOp_ ifBool, OutputStream& out);
275
276 SpvId writeBinaryOperation(const BinaryExpression& expr, SpvOp_ ifFloat, SpvOp_ ifInt,
277 SpvOp_ ifUInt, OutputStream& out);
278
279 SpvId writeBinaryExpression(const Type& leftType, SpvId lhs, Token::Kind op,
280 const Type& rightType, SpvId rhs, const Type& resultType,
281 OutputStream& out);
282
283 SpvId writeBinaryExpression(const BinaryExpression& b, OutputStream& out);
284
285 SpvId writeTernaryExpression(const TernaryExpression& t, OutputStream& out);
286
287 SpvId writeIndexExpression(const IndexExpression& expr, OutputStream& out);
288
289 SpvId writeLogicalAnd(const BinaryExpression& b, OutputStream& out);
290
291 SpvId writeLogicalOr(const BinaryExpression& o, OutputStream& out);
292
293 SpvId writePrefixExpression(const PrefixExpression& p, OutputStream& out);
294
295 SpvId writePostfixExpression(const PostfixExpression& p, OutputStream& out);
296
297 SpvId writeBoolLiteral(const BoolLiteral& b);
298
299 SpvId writeIntLiteral(const IntLiteral& i);
300
301 SpvId writeFloatLiteral(const FloatLiteral& f);
302
303 void writeStatement(const Statement& s, OutputStream& out);
304
305 void writeBlock(const Block& b, OutputStream& out);
306
307 void writeIfStatement(const IfStatement& stmt, OutputStream& out);
308
309 void writeForStatement(const ForStatement& f, OutputStream& out);
310
311 void writeWhileStatement(const WhileStatement& w, OutputStream& out);
312
313 void writeDoStatement(const DoStatement& d, OutputStream& out);
314
315 void writeSwitchStatement(const SwitchStatement& s, OutputStream& out);
316
317 void writeReturnStatement(const ReturnStatement& r, OutputStream& out);
318
319 void writeCapabilities(OutputStream& out);
320
321 void writeInstructions(const Program& program, OutputStream& out);
322
323 void writeOpCode(SpvOp_ opCode, int length, OutputStream& out);
324
325 void writeWord(int32_t word, OutputStream& out);
326
327 void writeString(const char* string, size_t length, OutputStream& out);
328
329 void writeLabel(SpvId id, OutputStream& out);
330
331 void writeInstruction(SpvOp_ opCode, OutputStream& out);
332
333 void writeInstruction(SpvOp_ opCode, StringFragment string, OutputStream& out);
334
335 void writeInstruction(SpvOp_ opCode, int32_t word1, OutputStream& out);
336
337 void writeInstruction(SpvOp_ opCode, int32_t word1, StringFragment string, OutputStream& out);
338
339 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, StringFragment string,
340 OutputStream& out);
341
342 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, OutputStream& out);
343
344 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3,
345 OutputStream& out);
346
347 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
348 OutputStream& out);
349
350 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
351 int32_t word5, OutputStream& out);
352
353 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
354 int32_t word5, int32_t word6, OutputStream& out);
355
356 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
357 int32_t word5, int32_t word6, int32_t word7, OutputStream& out);
358
359 void writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, int32_t word3, int32_t word4,
360 int32_t word5, int32_t word6, int32_t word7, int32_t word8,
361 OutputStream& out);
362
363 void writeGeometryShaderExecutionMode(SpvId entryPoint, OutputStream& out);
364
365 const Context& fContext;
366 const MemoryLayout fDefaultLayout;
367
368 uint64_t fCapabilities;
369 SpvId fIdCount;
370 SpvId fGLSLExtendedInstructions;
371 typedef std::tuple<IntrinsicKind, int32_t, int32_t, int32_t, int32_t> Intrinsic;
372 std::unordered_map<String, Intrinsic> fIntrinsicMap;
373 std::unordered_map<const FunctionDeclaration*, SpvId> fFunctionMap;
374 std::unordered_map<const Variable*, SpvId> fVariableMap;
375 std::unordered_map<const Variable*, int32_t> fInterfaceBlockMap;
376 std::unordered_map<String, SpvId> fImageTypeMap;
377 std::unordered_map<String, SpvId> fTypeMap;
378 StringStream fCapabilitiesBuffer;
379 StringStream fGlobalInitializersBuffer;
380 StringStream fConstantBuffer;
381 StringStream fExtraGlobalsBuffer;
382 StringStream fExternalFunctionsBuffer;
383 StringStream fVariableBuffer;
384 StringStream fNameBuffer;
385 StringStream fDecorationBuffer;
386
387 SpvId fBoolTrue;
388 SpvId fBoolFalse;
389 std::unordered_map<std::pair<ConstantValue, ConstantType>, SpvId> fNumberConstants;
390 // The constant float2(0, 1), used in swizzling
391 SpvId fConstantZeroOneVector = 0;
392 bool fSetupFragPosition;
393 // label of the current block, or 0 if we are not in a block
394 SpvId fCurrentBlock;
395 std::stack<SpvId> fBreakTarget;
396 std::stack<SpvId> fContinueTarget;
397 SpvId fRTHeightStructId = (SpvId) -1;
398 SpvId fRTHeightFieldIndex = (SpvId) -1;
399 // holds variables synthesized during output, for lifetime purposes
400 SymbolTable fSynthetics;
401 int fSkInCount = 1;
402
403 friend class PointerLValue;
404 friend class SwizzleLValue;
405
406 typedef CodeGenerator INHERITED;
407};
408
409}
410
411#endif
412