1 | /* |
2 | * Copyright 2016 Google Inc. |
3 | * |
4 | * Use of this source code is governed by a BSD-style license that can be |
5 | * found in the LICENSE file. |
6 | */ |
7 | |
8 | #include "src/sksl/SkSLMetalCodeGenerator.h" |
9 | |
10 | #include "src/sksl/SkSLCompiler.h" |
11 | #include "src/sksl/ir/SkSLExpressionStatement.h" |
12 | #include "src/sksl/ir/SkSLExtension.h" |
13 | #include "src/sksl/ir/SkSLIndexExpression.h" |
14 | #include "src/sksl/ir/SkSLModifiersDeclaration.h" |
15 | #include "src/sksl/ir/SkSLNop.h" |
16 | #include "src/sksl/ir/SkSLVariableReference.h" |
17 | |
18 | #include <algorithm> |
19 | |
20 | namespace SkSL { |
21 | |
22 | class MetalCodeGenerator::GlobalStructVisitor { |
23 | public: |
24 | virtual ~GlobalStructVisitor() = default; |
25 | virtual void VisitInterfaceBlock(const InterfaceBlock& block, const String& blockName) = 0; |
26 | virtual void VisitTexture(const Type& type, const String& name) = 0; |
27 | virtual void VisitSampler(const Type& type, const String& name) = 0; |
28 | virtual void VisitVariable(const Variable& var, const Expression* value) = 0; |
29 | }; |
30 | |
31 | void MetalCodeGenerator::setupIntrinsics() { |
32 | #define METAL(x) std::make_pair(kMetal_IntrinsicKind, k ## x ## _MetalIntrinsic) |
33 | #define SPECIAL(x) std::make_pair(kSpecial_IntrinsicKind, k ## x ## _SpecialIntrinsic) |
34 | fIntrinsicMap[String("sample" )] = SPECIAL(Texture); |
35 | fIntrinsicMap[String("mod" )] = SPECIAL(Mod); |
36 | fIntrinsicMap[String("equal" )] = METAL(Equal); |
37 | fIntrinsicMap[String("notEqual" )] = METAL(NotEqual); |
38 | fIntrinsicMap[String("lessThan" )] = METAL(LessThan); |
39 | fIntrinsicMap[String("lessThanEqual" )] = METAL(LessThanEqual); |
40 | fIntrinsicMap[String("greaterThan" )] = METAL(GreaterThan); |
41 | fIntrinsicMap[String("greaterThanEqual" )] = METAL(GreaterThanEqual); |
42 | } |
43 | |
44 | void MetalCodeGenerator::write(const char* s) { |
45 | if (!s[0]) { |
46 | return; |
47 | } |
48 | if (fAtLineStart) { |
49 | for (int i = 0; i < fIndentation; i++) { |
50 | fOut->writeText(" " ); |
51 | } |
52 | } |
53 | fOut->writeText(s); |
54 | fAtLineStart = false; |
55 | } |
56 | |
57 | void MetalCodeGenerator::writeLine(const char* s) { |
58 | this->write(s); |
59 | fOut->writeText(fLineEnding); |
60 | fAtLineStart = true; |
61 | } |
62 | |
63 | void MetalCodeGenerator::write(const String& s) { |
64 | this->write(s.c_str()); |
65 | } |
66 | |
67 | void MetalCodeGenerator::writeLine(const String& s) { |
68 | this->writeLine(s.c_str()); |
69 | } |
70 | |
71 | void MetalCodeGenerator::writeLine() { |
72 | this->writeLine("" ); |
73 | } |
74 | |
75 | void MetalCodeGenerator::writeExtension(const Extension& ext) { |
76 | this->writeLine("#extension " + ext.fName + " : enable" ); |
77 | } |
78 | |
79 | String MetalCodeGenerator::typeName(const Type& type) { |
80 | switch (type.kind()) { |
81 | case Type::kVector_Kind: |
82 | return this->typeName(type.componentType()) + to_string(type.columns()); |
83 | case Type::kMatrix_Kind: |
84 | return this->typeName(type.componentType()) + to_string(type.columns()) + "x" + |
85 | to_string(type.rows()); |
86 | case Type::kSampler_Kind: |
87 | return "texture2d<float>" ; // FIXME - support other texture types; |
88 | default: |
89 | if (type == *fContext.fHalf_Type) { |
90 | // FIXME - Currently only supporting floats in MSL to avoid type coercion issues. |
91 | return fContext.fFloat_Type->name(); |
92 | } else if (type == *fContext.fByte_Type) { |
93 | return "char" ; |
94 | } else if (type == *fContext.fUByte_Type) { |
95 | return "uchar" ; |
96 | } else { |
97 | return type.name(); |
98 | } |
99 | } |
100 | } |
101 | |
102 | void MetalCodeGenerator::writeType(const Type& type) { |
103 | if (type.kind() == Type::kStruct_Kind) { |
104 | for (const Type* search : fWrittenStructs) { |
105 | if (*search == type) { |
106 | // already written |
107 | this->write(type.name()); |
108 | return; |
109 | } |
110 | } |
111 | fWrittenStructs.push_back(&type); |
112 | this->writeLine("struct " + type.name() + " {" ); |
113 | fIndentation++; |
114 | this->writeFields(type.fields(), type.fOffset); |
115 | fIndentation--; |
116 | this->write("}" ); |
117 | } else { |
118 | this->write(this->typeName(type)); |
119 | } |
120 | } |
121 | |
122 | void MetalCodeGenerator::writeExpression(const Expression& expr, Precedence parentPrecedence) { |
123 | switch (expr.fKind) { |
124 | case Expression::kBinary_Kind: |
125 | this->writeBinaryExpression((BinaryExpression&) expr, parentPrecedence); |
126 | break; |
127 | case Expression::kBoolLiteral_Kind: |
128 | this->writeBoolLiteral((BoolLiteral&) expr); |
129 | break; |
130 | case Expression::kConstructor_Kind: |
131 | this->writeConstructor((Constructor&) expr, parentPrecedence); |
132 | break; |
133 | case Expression::kIntLiteral_Kind: |
134 | this->writeIntLiteral((IntLiteral&) expr); |
135 | break; |
136 | case Expression::kFieldAccess_Kind: |
137 | this->writeFieldAccess(((FieldAccess&) expr)); |
138 | break; |
139 | case Expression::kFloatLiteral_Kind: |
140 | this->writeFloatLiteral(((FloatLiteral&) expr)); |
141 | break; |
142 | case Expression::kFunctionCall_Kind: |
143 | this->writeFunctionCall((FunctionCall&) expr); |
144 | break; |
145 | case Expression::kPrefix_Kind: |
146 | this->writePrefixExpression((PrefixExpression&) expr, parentPrecedence); |
147 | break; |
148 | case Expression::kPostfix_Kind: |
149 | this->writePostfixExpression((PostfixExpression&) expr, parentPrecedence); |
150 | break; |
151 | case Expression::kSetting_Kind: |
152 | this->writeSetting((Setting&) expr); |
153 | break; |
154 | case Expression::kSwizzle_Kind: |
155 | this->writeSwizzle((Swizzle&) expr); |
156 | break; |
157 | case Expression::kVariableReference_Kind: |
158 | this->writeVariableReference((VariableReference&) expr); |
159 | break; |
160 | case Expression::kTernary_Kind: |
161 | this->writeTernaryExpression((TernaryExpression&) expr, parentPrecedence); |
162 | break; |
163 | case Expression::kIndex_Kind: |
164 | this->writeIndexExpression((IndexExpression&) expr); |
165 | break; |
166 | default: |
167 | #ifdef SK_DEBUG |
168 | ABORT("unsupported expression: %s" , expr.description().c_str()); |
169 | #endif |
170 | break; |
171 | } |
172 | } |
173 | |
174 | void MetalCodeGenerator::writeIntrinsicCall(const FunctionCall& c) { |
175 | auto i = fIntrinsicMap.find(c.fFunction.fName); |
176 | SkASSERT(i != fIntrinsicMap.end()); |
177 | Intrinsic intrinsic = i->second; |
178 | int32_t intrinsicId = intrinsic.second; |
179 | switch (intrinsic.first) { |
180 | case kSpecial_IntrinsicKind: |
181 | return this->writeSpecialIntrinsic(c, (SpecialIntrinsic) intrinsicId); |
182 | break; |
183 | case kMetal_IntrinsicKind: |
184 | this->writeExpression(*c.fArguments[0], kSequence_Precedence); |
185 | switch ((MetalIntrinsic) intrinsicId) { |
186 | case kEqual_MetalIntrinsic: |
187 | this->write(" == " ); |
188 | break; |
189 | case kNotEqual_MetalIntrinsic: |
190 | this->write(" != " ); |
191 | break; |
192 | case kLessThan_MetalIntrinsic: |
193 | this->write(" < " ); |
194 | break; |
195 | case kLessThanEqual_MetalIntrinsic: |
196 | this->write(" <= " ); |
197 | break; |
198 | case kGreaterThan_MetalIntrinsic: |
199 | this->write(" > " ); |
200 | break; |
201 | case kGreaterThanEqual_MetalIntrinsic: |
202 | this->write(" >= " ); |
203 | break; |
204 | default: |
205 | ABORT("unsupported metal intrinsic kind" ); |
206 | } |
207 | this->writeExpression(*c.fArguments[1], kSequence_Precedence); |
208 | break; |
209 | default: |
210 | ABORT("unsupported intrinsic kind" ); |
211 | } |
212 | } |
213 | |
214 | void MetalCodeGenerator::writeFunctionCall(const FunctionCall& c) { |
215 | const auto& entry = fIntrinsicMap.find(c.fFunction.fName); |
216 | if (entry != fIntrinsicMap.end()) { |
217 | this->writeIntrinsicCall(c); |
218 | return; |
219 | } |
220 | if (c.fFunction.fBuiltin && "atan" == c.fFunction.fName && 2 == c.fArguments.size()) { |
221 | this->write("atan2" ); |
222 | } else if (c.fFunction.fBuiltin && "inversesqrt" == c.fFunction.fName) { |
223 | this->write("rsqrt" ); |
224 | } else if (c.fFunction.fBuiltin && "inverse" == c.fFunction.fName) { |
225 | SkASSERT(c.fArguments.size() == 1); |
226 | this->writeInverseHack(*c.fArguments[0]); |
227 | } else if (c.fFunction.fBuiltin && "dFdx" == c.fFunction.fName) { |
228 | this->write("dfdx" ); |
229 | } else if (c.fFunction.fBuiltin && "dFdy" == c.fFunction.fName) { |
230 | // Flipping Y also negates the Y derivatives. |
231 | this->write((fProgram.fSettings.fFlipY) ? "-dfdy" : "dfdy" ); |
232 | } else { |
233 | this->writeName(c.fFunction.fName); |
234 | } |
235 | this->write("(" ); |
236 | const char* separator = "" ; |
237 | if (this->requirements(c.fFunction) & kInputs_Requirement) { |
238 | this->write("_in" ); |
239 | separator = ", " ; |
240 | } |
241 | if (this->requirements(c.fFunction) & kOutputs_Requirement) { |
242 | this->write(separator); |
243 | this->write("_out" ); |
244 | separator = ", " ; |
245 | } |
246 | if (this->requirements(c.fFunction) & kUniforms_Requirement) { |
247 | this->write(separator); |
248 | this->write("_uniforms" ); |
249 | separator = ", " ; |
250 | } |
251 | if (this->requirements(c.fFunction) & kGlobals_Requirement) { |
252 | this->write(separator); |
253 | this->write("_globals" ); |
254 | separator = ", " ; |
255 | } |
256 | if (this->requirements(c.fFunction) & kFragCoord_Requirement) { |
257 | this->write(separator); |
258 | this->write("_fragCoord" ); |
259 | separator = ", " ; |
260 | } |
261 | for (size_t i = 0; i < c.fArguments.size(); ++i) { |
262 | const Expression& arg = *c.fArguments[i]; |
263 | this->write(separator); |
264 | separator = ", " ; |
265 | if (c.fFunction.fParameters[i]->fModifiers.fFlags & Modifiers::kOut_Flag) { |
266 | this->write("&" ); |
267 | } |
268 | this->writeExpression(arg, kSequence_Precedence); |
269 | } |
270 | this->write(")" ); |
271 | } |
272 | |
273 | void MetalCodeGenerator::writeInverseHack(const Expression& mat) { |
274 | String typeName = mat.fType.name(); |
275 | String name = typeName + "_inverse" ; |
276 | if (mat.fType == *fContext.fFloat2x2_Type || mat.fType == *fContext.fHalf2x2_Type) { |
277 | if (fWrittenIntrinsics.find(name) == fWrittenIntrinsics.end()) { |
278 | fWrittenIntrinsics.insert(name); |
279 | fExtraFunctions.writeText(( |
280 | typeName + " " + name + "(" + typeName + " m) {" |
281 | " return float2x2(m[1][1], -m[0][1], -m[1][0], m[0][0]) * (1/determinant(m));" |
282 | "}" |
283 | ).c_str()); |
284 | } |
285 | } |
286 | else if (mat.fType == *fContext.fFloat3x3_Type || mat.fType == *fContext.fHalf3x3_Type) { |
287 | if (fWrittenIntrinsics.find(name) == fWrittenIntrinsics.end()) { |
288 | fWrittenIntrinsics.insert(name); |
289 | fExtraFunctions.writeText(( |
290 | typeName + " " + name + "(" + typeName + " m) {" |
291 | " float a00 = m[0][0], a01 = m[0][1], a02 = m[0][2];" |
292 | " float a10 = m[1][0], a11 = m[1][1], a12 = m[1][2];" |
293 | " float a20 = m[2][0], a21 = m[2][1], a22 = m[2][2];" |
294 | " float b01 = a22 * a11 - a12 * a21;" |
295 | " float b11 = -a22 * a10 + a12 * a20;" |
296 | " float b21 = a21 * a10 - a11 * a20;" |
297 | " float det = a00 * b01 + a01 * b11 + a02 * b21;" |
298 | " return " + typeName + |
299 | " (b01, (-a22 * a01 + a02 * a21), (a12 * a01 - a02 * a11)," |
300 | " b11, (a22 * a00 - a02 * a20), (-a12 * a00 + a02 * a10)," |
301 | " b21, (-a21 * a00 + a01 * a20), (a11 * a00 - a01 * a10)) * " |
302 | " (1/det);" |
303 | "}" |
304 | ).c_str()); |
305 | } |
306 | } |
307 | else if (mat.fType == *fContext.fFloat4x4_Type || mat.fType == *fContext.fHalf4x4_Type) { |
308 | if (fWrittenIntrinsics.find(name) == fWrittenIntrinsics.end()) { |
309 | fWrittenIntrinsics.insert(name); |
310 | fExtraFunctions.writeText(( |
311 | typeName + " " + name + "(" + typeName + " m) {" |
312 | " float a00 = m[0][0], a01 = m[0][1], a02 = m[0][2], a03 = m[0][3];" |
313 | " float a10 = m[1][0], a11 = m[1][1], a12 = m[1][2], a13 = m[1][3];" |
314 | " float a20 = m[2][0], a21 = m[2][1], a22 = m[2][2], a23 = m[2][3];" |
315 | " float a30 = m[3][0], a31 = m[3][1], a32 = m[3][2], a33 = m[3][3];" |
316 | " float b00 = a00 * a11 - a01 * a10;" |
317 | " float b01 = a00 * a12 - a02 * a10;" |
318 | " float b02 = a00 * a13 - a03 * a10;" |
319 | " float b03 = a01 * a12 - a02 * a11;" |
320 | " float b04 = a01 * a13 - a03 * a11;" |
321 | " float b05 = a02 * a13 - a03 * a12;" |
322 | " float b06 = a20 * a31 - a21 * a30;" |
323 | " float b07 = a20 * a32 - a22 * a30;" |
324 | " float b08 = a20 * a33 - a23 * a30;" |
325 | " float b09 = a21 * a32 - a22 * a31;" |
326 | " float b10 = a21 * a33 - a23 * a31;" |
327 | " float b11 = a22 * a33 - a23 * a32;" |
328 | " float det = b00 * b11 - b01 * b10 + b02 * b09 + b03 * b08 - " |
329 | " b04 * b07 + b05 * b06;" |
330 | " return " + typeName + "(a11 * b11 - a12 * b10 + a13 * b09," |
331 | " a02 * b10 - a01 * b11 - a03 * b09," |
332 | " a31 * b05 - a32 * b04 + a33 * b03," |
333 | " a22 * b04 - a21 * b05 - a23 * b03," |
334 | " a12 * b08 - a10 * b11 - a13 * b07," |
335 | " a00 * b11 - a02 * b08 + a03 * b07," |
336 | " a32 * b02 - a30 * b05 - a33 * b01," |
337 | " a20 * b05 - a22 * b02 + a23 * b01," |
338 | " a10 * b10 - a11 * b08 + a13 * b06," |
339 | " a01 * b08 - a00 * b10 - a03 * b06," |
340 | " a30 * b04 - a31 * b02 + a33 * b00," |
341 | " a21 * b02 - a20 * b04 - a23 * b00," |
342 | " a11 * b07 - a10 * b09 - a12 * b06," |
343 | " a00 * b09 - a01 * b07 + a02 * b06," |
344 | " a31 * b01 - a30 * b03 - a32 * b00," |
345 | " a20 * b03 - a21 * b01 + a22 * b00) / det;" |
346 | "}" |
347 | ).c_str()); |
348 | } |
349 | } |
350 | this->write(name); |
351 | } |
352 | |
353 | void MetalCodeGenerator::writeSpecialIntrinsic(const FunctionCall & c, SpecialIntrinsic kind) { |
354 | switch (kind) { |
355 | case kTexture_SpecialIntrinsic: |
356 | this->writeExpression(*c.fArguments[0], kSequence_Precedence); |
357 | this->write(".sample(" ); |
358 | this->writeExpression(*c.fArguments[0], kSequence_Precedence); |
359 | this->write(SAMPLER_SUFFIX); |
360 | this->write(", " ); |
361 | if (c.fArguments[1]->fType == *fContext.fFloat3_Type) { |
362 | // have to store the vector in a temp variable to avoid double evaluating it |
363 | String tmpVar = "tmpCoord" + to_string(fVarCount++); |
364 | this->fFunctionHeader += " " + this->typeName(c.fArguments[1]->fType) + " " + |
365 | tmpVar + ";\n" ; |
366 | this->write("(" + tmpVar + " = " ); |
367 | this->writeExpression(*c.fArguments[1], kSequence_Precedence); |
368 | this->write(", " + tmpVar + ".xy / " + tmpVar + ".z))" ); |
369 | } else { |
370 | SkASSERT(c.fArguments[1]->fType == *fContext.fFloat2_Type); |
371 | this->writeExpression(*c.fArguments[1], kSequence_Precedence); |
372 | this->write(")" ); |
373 | } |
374 | break; |
375 | case kMod_SpecialIntrinsic: { |
376 | // fmod(x, y) in metal calculates x - y * trunc(x / y) instead of x - y * floor(x / y) |
377 | String tmpX = "tmpX" + to_string(fVarCount++); |
378 | String tmpY = "tmpY" + to_string(fVarCount++); |
379 | this->fFunctionHeader += " " + this->typeName(c.fArguments[0]->fType) + " " + tmpX + |
380 | ", " + tmpY + ";\n" ; |
381 | this->write("(" + tmpX + " = " ); |
382 | this->writeExpression(*c.fArguments[0], kSequence_Precedence); |
383 | this->write(", " + tmpY + " = " ); |
384 | this->writeExpression(*c.fArguments[1], kSequence_Precedence); |
385 | this->write(", " + tmpX + " - " + tmpY + " * floor(" + tmpX + " / " + tmpY + "))" ); |
386 | break; |
387 | } |
388 | default: |
389 | ABORT("unsupported special intrinsic kind" ); |
390 | } |
391 | } |
392 | |
393 | // Assembles a matrix of type floatRxC by resizing another matrix named `x0`. |
394 | // Cells that don't exist in the source matrix will be populated with identity-matrix values. |
395 | void MetalCodeGenerator::assembleMatrixFromMatrix(const Type& sourceMatrix, int rows, int columns) { |
396 | SkASSERT(rows <= 4); |
397 | SkASSERT(columns <= 4); |
398 | |
399 | const char* columnSeparator = "" ; |
400 | for (int c = 0; c < columns; ++c) { |
401 | fExtraFunctions.printf("%sfloat%d(" , columnSeparator, rows); |
402 | columnSeparator = "), " ; |
403 | |
404 | // Determine how many values to take from the source matrix for this row. |
405 | int swizzleLength = 0; |
406 | if (c < sourceMatrix.columns()) { |
407 | swizzleLength = std::min<>(rows, sourceMatrix.rows()); |
408 | } |
409 | |
410 | // Emit all the values from the source matrix row. |
411 | bool firstItem; |
412 | switch (swizzleLength) { |
413 | case 0: firstItem = true; break; |
414 | case 1: firstItem = false; fExtraFunctions.printf("x0[%d].x" , c); break; |
415 | case 2: firstItem = false; fExtraFunctions.printf("x0[%d].xy" , c); break; |
416 | case 3: firstItem = false; fExtraFunctions.printf("x0[%d].xyz" , c); break; |
417 | case 4: firstItem = false; fExtraFunctions.printf("x0[%d].xyzw" , c); break; |
418 | default: SkUNREACHABLE; |
419 | } |
420 | |
421 | // Emit the placeholder identity-matrix cells. |
422 | for (int r = swizzleLength; r < rows; ++r) { |
423 | fExtraFunctions.printf("%s%s" , firstItem ? "" : ", " , (r == c) ? "1.0" : "0.0" ); |
424 | firstItem = false; |
425 | } |
426 | } |
427 | |
428 | fExtraFunctions.writeText(")" ); |
429 | } |
430 | |
431 | // Assembles a matrix of type floatRxC by concatenating an arbitrary mix of values, named `x0`, |
432 | // `x1`, etc. An error is written if the expression list don't contain exactly R*C scalars. |
433 | void MetalCodeGenerator::assembleMatrixFromExpressions( |
434 | const std::vector<std::unique_ptr<Expression>>& args, int rows, int columns) { |
435 | size_t argIndex = 0; |
436 | int argPosition = 0; |
437 | |
438 | const char* columnSeparator = "" ; |
439 | for (int c = 0; c < columns; ++c) { |
440 | fExtraFunctions.printf("%sfloat%d(" , columnSeparator, rows); |
441 | columnSeparator = "), " ; |
442 | |
443 | const char* rowSeparator = "" ; |
444 | for (int r = 0; r < rows; ++r) { |
445 | fExtraFunctions.writeText(rowSeparator); |
446 | rowSeparator = ", " ; |
447 | |
448 | if (argIndex < args.size()) { |
449 | const Type& argType = args[argIndex]->fType; |
450 | switch (argType.kind()) { |
451 | case Type::kScalar_Kind: { |
452 | fExtraFunctions.printf("x%zu" , argIndex); |
453 | break; |
454 | } |
455 | case Type::kVector_Kind: { |
456 | fExtraFunctions.printf("x%zu[%d]" , argIndex, argPosition); |
457 | break; |
458 | } |
459 | case Type::kMatrix_Kind: { |
460 | fExtraFunctions.printf("x%zu[%d][%d]" , argIndex, |
461 | argPosition / argType.rows(), |
462 | argPosition % argType.rows()); |
463 | break; |
464 | } |
465 | default: { |
466 | SkDEBUGFAIL("incorrect type of argument for matrix constructor" ); |
467 | fExtraFunctions.writeText("<error>" ); |
468 | break; |
469 | } |
470 | } |
471 | |
472 | ++argPosition; |
473 | if (argPosition >= argType.columns() * argType.rows()) { |
474 | ++argIndex; |
475 | argPosition = 0; |
476 | } |
477 | } else { |
478 | SkDEBUGFAIL("not enough arguments for matrix constructor" ); |
479 | fExtraFunctions.writeText("<error>" ); |
480 | } |
481 | } |
482 | } |
483 | |
484 | if (argPosition != 0 || argIndex != args.size()) { |
485 | SkDEBUGFAIL("incorrect number of arguments for matrix constructor" ); |
486 | fExtraFunctions.writeText(", <error>" ); |
487 | } |
488 | |
489 | fExtraFunctions.writeText(")" ); |
490 | } |
491 | |
492 | // Generates a constructor for 'matrix' which reorganizes the input arguments into the proper shape. |
493 | // Keeps track of previously generated constructors so that we won't generate more than one |
494 | // constructor for any given permutation of input argument types. Returns the name of the |
495 | // generated constructor method. |
496 | String MetalCodeGenerator::getMatrixConstructHelper(const Constructor& c) { |
497 | const Type& matrix = c.fType; |
498 | int columns = matrix.columns(); |
499 | int rows = matrix.rows(); |
500 | const std::vector<std::unique_ptr<Expression>>& args = c.fArguments; |
501 | |
502 | // Create the helper-method name and use it as our lookup key. |
503 | String name; |
504 | name.appendf("float%dx%d_from" , columns, rows); |
505 | for (const std::unique_ptr<Expression>& expr : args) { |
506 | name.appendf("_%s" , expr->fType.displayName().c_str()); |
507 | } |
508 | |
509 | // If a helper-method has already been synthesized, we don't need to synthesize it again. |
510 | auto [iter, newlyCreated] = fHelpers.insert(name); |
511 | if (!newlyCreated) { |
512 | return name; |
513 | } |
514 | |
515 | // Unlike GLSL, Metal requires that matrices are initialized with exactly R vectors of C |
516 | // components apiece. (In Metal 2.0, you can also supply R*C scalars, but you still cannot |
517 | // supply a mixture of scalars and vectors.) |
518 | fExtraFunctions.printf("float%dx%d %s(" , columns, rows, name.c_str()); |
519 | |
520 | size_t argIndex = 0; |
521 | const char* argSeparator = "" ; |
522 | for (const std::unique_ptr<Expression>& expr : args) { |
523 | fExtraFunctions.printf("%s%s x%zu" , argSeparator, |
524 | expr->fType.displayName().c_str(), argIndex++); |
525 | argSeparator = ", " ; |
526 | } |
527 | |
528 | fExtraFunctions.printf(") {\n return float%dx%d(" , columns, rows); |
529 | |
530 | if (args.size() == 1 && args.front()->fType.kind() == Type::kMatrix_Kind) { |
531 | this->assembleMatrixFromMatrix(args.front()->fType, rows, columns); |
532 | } else { |
533 | this->assembleMatrixFromExpressions(args, rows, columns); |
534 | } |
535 | |
536 | fExtraFunctions.writeText(");\n}\n" ); |
537 | return name; |
538 | } |
539 | |
540 | bool MetalCodeGenerator::canCoerce(const Type& t1, const Type& t2) { |
541 | if (t1.columns() != t2.columns() || t1.rows() != t2.rows()) { |
542 | return false; |
543 | } |
544 | if (t1.columns() > 1) { |
545 | return this->canCoerce(t1.componentType(), t2.componentType()); |
546 | } |
547 | return t1.isFloat() && t2.isFloat(); |
548 | } |
549 | |
550 | bool MetalCodeGenerator::matrixConstructHelperIsNeeded(const Constructor& c) { |
551 | // A matrix construct helper is only necessary if we are, in fact, constructing a matrix. |
552 | if (c.fType.kind() != Type::kMatrix_Kind) { |
553 | return false; |
554 | } |
555 | |
556 | // GLSL is fairly free-form about inputs to its matrix constructors, but Metal is not; it |
557 | // expects exactly R vectors of C components apiece. (Metal 2.0 also allows a list of R*C |
558 | // scalars.) Some cases are simple to translate and so we handle those inline--e.g. a list of |
559 | // scalars can be constructed trivially. In more complex cases, we generate a helper function |
560 | // that converts our inputs into a properly-shaped matrix. |
561 | // A matrix construct helper method is always used if any input argument is a matrix. |
562 | // Helper methods are also necessary when any argument would span multiple rows. For instance: |
563 | // |
564 | // float2 x = (1, 2); |
565 | // float3x2(x, 3, 4, 5, 6) = | 1 3 5 | = no helper needed; conversion can be done inline |
566 | // | 2 4 6 | |
567 | // |
568 | // float2 x = (2, 3); |
569 | // float3x2(1, x, 4, 5, 6) = | 1 3 5 | = x spans multiple rows; a helper method will be used |
570 | // | 2 4 6 | |
571 | // |
572 | // float4 x = (1, 2, 3, 4); |
573 | // float2x2(x) = | 1 3 | = x spans multiple rows; a helper method will be used |
574 | // | 2 4 | |
575 | // |
576 | |
577 | int position = 0; |
578 | for (const std::unique_ptr<Expression>& expr : c.fArguments) { |
579 | // If an input argument is a matrix, we need a helper function. |
580 | if (expr->fType.kind() == Type::kMatrix_Kind) { |
581 | return true; |
582 | } |
583 | position += expr->fType.columns(); |
584 | if (position > c.fType.rows()) { |
585 | // An input argument would span multiple rows; a helper function is required. |
586 | return true; |
587 | } |
588 | if (position == c.fType.rows()) { |
589 | // We've advanced to the end of a row. Wrap to the start of the next row. |
590 | position = 0; |
591 | } |
592 | } |
593 | |
594 | return false; |
595 | } |
596 | |
597 | void MetalCodeGenerator::writeConstructor(const Constructor& c, Precedence parentPrecedence) { |
598 | // Handle special cases for single-argument constructors. |
599 | if (c.fArguments.size() == 1) { |
600 | // If the type is coercible, emit it directly. |
601 | const Expression& arg = *c.fArguments.front(); |
602 | if (this->canCoerce(c.fType, arg.fType)) { |
603 | this->writeExpression(arg, parentPrecedence); |
604 | return; |
605 | } |
606 | |
607 | // Metal supports creating matrices with a scalar on the diagonal via the single-argument |
608 | // matrix constructor. |
609 | if (c.fType.kind() == Type::kMatrix_Kind && arg.fType.isNumber()) { |
610 | const Type& matrix = c.fType; |
611 | this->write("float" ); |
612 | this->write(to_string(matrix.columns())); |
613 | this->write("x" ); |
614 | this->write(to_string(matrix.rows())); |
615 | this->write("(" ); |
616 | this->writeExpression(arg, parentPrecedence); |
617 | this->write(")" ); |
618 | return; |
619 | } |
620 | } |
621 | |
622 | // Emit and invoke a matrix-constructor helper method if one is necessary. |
623 | if (this->matrixConstructHelperIsNeeded(c)) { |
624 | this->write(this->getMatrixConstructHelper(c)); |
625 | this->write("(" ); |
626 | const char* separator = "" ; |
627 | for (const std::unique_ptr<Expression>& expr : c.fArguments) { |
628 | this->write(separator); |
629 | separator = ", " ; |
630 | this->writeExpression(*expr, kSequence_Precedence); |
631 | } |
632 | this->write(")" ); |
633 | return; |
634 | } |
635 | |
636 | // Explicitly invoke the constructor, passing in the necessary arguments. |
637 | this->writeType(c.fType); |
638 | this->write("(" ); |
639 | const char* separator = "" ; |
640 | int scalarCount = 0; |
641 | for (const std::unique_ptr<Expression>& arg : c.fArguments) { |
642 | this->write(separator); |
643 | separator = ", " ; |
644 | if (Type::kMatrix_Kind == c.fType.kind() && arg->fType.columns() < c.fType.rows()) { |
645 | // Merge scalars and smaller vectors together. |
646 | if (!scalarCount) { |
647 | this->writeType(c.fType.componentType()); |
648 | this->write(to_string(c.fType.rows())); |
649 | this->write("(" ); |
650 | } |
651 | scalarCount += arg->fType.columns(); |
652 | } |
653 | this->writeExpression(*arg, kSequence_Precedence); |
654 | if (scalarCount && scalarCount == c.fType.rows()) { |
655 | this->write(")" ); |
656 | scalarCount = 0; |
657 | } |
658 | } |
659 | this->write(")" ); |
660 | } |
661 | |
662 | void MetalCodeGenerator::writeFragCoord() { |
663 | if (fRTHeightName.length()) { |
664 | this->write("float4(_fragCoord.x, " ); |
665 | this->write(fRTHeightName.c_str()); |
666 | this->write(" - _fragCoord.y, 0.0, _fragCoord.w)" ); |
667 | } else { |
668 | this->write("float4(_fragCoord.x, _fragCoord.y, 0.0, _fragCoord.w)" ); |
669 | } |
670 | } |
671 | |
672 | void MetalCodeGenerator::writeVariableReference(const VariableReference& ref) { |
673 | switch (ref.fVariable.fModifiers.fLayout.fBuiltin) { |
674 | case SK_FRAGCOLOR_BUILTIN: |
675 | this->write("_out->sk_FragColor" ); |
676 | break; |
677 | case SK_FRAGCOORD_BUILTIN: |
678 | this->writeFragCoord(); |
679 | break; |
680 | case SK_VERTEXID_BUILTIN: |
681 | this->write("sk_VertexID" ); |
682 | break; |
683 | case SK_INSTANCEID_BUILTIN: |
684 | this->write("sk_InstanceID" ); |
685 | break; |
686 | case SK_CLOCKWISE_BUILTIN: |
687 | // We'd set the front facing winding in the MTLRenderCommandEncoder to be counter |
688 | // clockwise to match Skia convention. |
689 | this->write(fProgram.fSettings.fFlipY ? "_frontFacing" : "(!_frontFacing)" ); |
690 | break; |
691 | default: |
692 | if (Variable::kGlobal_Storage == ref.fVariable.fStorage) { |
693 | if (ref.fVariable.fModifiers.fFlags & Modifiers::kIn_Flag) { |
694 | this->write("_in." ); |
695 | } else if (ref.fVariable.fModifiers.fFlags & Modifiers::kOut_Flag) { |
696 | this->write("_out->" ); |
697 | } else if (ref.fVariable.fModifiers.fFlags & Modifiers::kUniform_Flag && |
698 | ref.fVariable.fType.kind() != Type::kSampler_Kind) { |
699 | this->write("_uniforms." ); |
700 | } else { |
701 | this->write("_globals->" ); |
702 | } |
703 | } |
704 | this->writeName(ref.fVariable.fName); |
705 | } |
706 | } |
707 | |
708 | void MetalCodeGenerator::writeIndexExpression(const IndexExpression& expr) { |
709 | this->writeExpression(*expr.fBase, kPostfix_Precedence); |
710 | this->write("[" ); |
711 | this->writeExpression(*expr.fIndex, kTopLevel_Precedence); |
712 | this->write("]" ); |
713 | } |
714 | |
715 | void MetalCodeGenerator::writeFieldAccess(const FieldAccess& f) { |
716 | const Type::Field* field = &f.fBase->fType.fields()[f.fFieldIndex]; |
717 | if (FieldAccess::kDefault_OwnerKind == f.fOwnerKind) { |
718 | this->writeExpression(*f.fBase, kPostfix_Precedence); |
719 | this->write("." ); |
720 | } |
721 | switch (field->fModifiers.fLayout.fBuiltin) { |
722 | case SK_CLIPDISTANCE_BUILTIN: |
723 | this->write("gl_ClipDistance" ); |
724 | break; |
725 | case SK_POSITION_BUILTIN: |
726 | this->write("_out->sk_Position" ); |
727 | break; |
728 | default: |
729 | if (field->fName == "sk_PointSize" ) { |
730 | this->write("_out->sk_PointSize" ); |
731 | } else { |
732 | if (FieldAccess::kAnonymousInterfaceBlock_OwnerKind == f.fOwnerKind) { |
733 | this->write("_globals->" ); |
734 | this->write(fInterfaceBlockNameMap[fInterfaceBlockMap[field]]); |
735 | this->write("->" ); |
736 | } |
737 | this->writeName(field->fName); |
738 | } |
739 | } |
740 | } |
741 | |
742 | void MetalCodeGenerator::writeSwizzle(const Swizzle& swizzle) { |
743 | int last = swizzle.fComponents.back(); |
744 | if (last == SKSL_SWIZZLE_0 || last == SKSL_SWIZZLE_1) { |
745 | this->writeType(swizzle.fType); |
746 | this->write("(" ); |
747 | } |
748 | this->writeExpression(*swizzle.fBase, kPostfix_Precedence); |
749 | this->write("." ); |
750 | for (int c : swizzle.fComponents) { |
751 | if (c >= 0) { |
752 | this->write(&("x\0y\0z\0w\0" [c * 2])); |
753 | } |
754 | } |
755 | if (last == SKSL_SWIZZLE_0) { |
756 | this->write(", 0)" ); |
757 | } |
758 | else if (last == SKSL_SWIZZLE_1) { |
759 | this->write(", 1)" ); |
760 | } |
761 | } |
762 | |
763 | MetalCodeGenerator::Precedence MetalCodeGenerator::GetBinaryPrecedence(Token::Kind op) { |
764 | switch (op) { |
765 | case Token::Kind::TK_STAR: // fall through |
766 | case Token::Kind::TK_SLASH: // fall through |
767 | case Token::Kind::TK_PERCENT: return MetalCodeGenerator::kMultiplicative_Precedence; |
768 | case Token::Kind::TK_PLUS: // fall through |
769 | case Token::Kind::TK_MINUS: return MetalCodeGenerator::kAdditive_Precedence; |
770 | case Token::Kind::TK_SHL: // fall through |
771 | case Token::Kind::TK_SHR: return MetalCodeGenerator::kShift_Precedence; |
772 | case Token::Kind::TK_LT: // fall through |
773 | case Token::Kind::TK_GT: // fall through |
774 | case Token::Kind::TK_LTEQ: // fall through |
775 | case Token::Kind::TK_GTEQ: return MetalCodeGenerator::kRelational_Precedence; |
776 | case Token::Kind::TK_EQEQ: // fall through |
777 | case Token::Kind::TK_NEQ: return MetalCodeGenerator::kEquality_Precedence; |
778 | case Token::Kind::TK_BITWISEAND: return MetalCodeGenerator::kBitwiseAnd_Precedence; |
779 | case Token::Kind::TK_BITWISEXOR: return MetalCodeGenerator::kBitwiseXor_Precedence; |
780 | case Token::Kind::TK_BITWISEOR: return MetalCodeGenerator::kBitwiseOr_Precedence; |
781 | case Token::Kind::TK_LOGICALAND: return MetalCodeGenerator::kLogicalAnd_Precedence; |
782 | case Token::Kind::TK_LOGICALXOR: return MetalCodeGenerator::kLogicalXor_Precedence; |
783 | case Token::Kind::TK_LOGICALOR: return MetalCodeGenerator::kLogicalOr_Precedence; |
784 | case Token::Kind::TK_EQ: // fall through |
785 | case Token::Kind::TK_PLUSEQ: // fall through |
786 | case Token::Kind::TK_MINUSEQ: // fall through |
787 | case Token::Kind::TK_STAREQ: // fall through |
788 | case Token::Kind::TK_SLASHEQ: // fall through |
789 | case Token::Kind::TK_PERCENTEQ: // fall through |
790 | case Token::Kind::TK_SHLEQ: // fall through |
791 | case Token::Kind::TK_SHREQ: // fall through |
792 | case Token::Kind::TK_LOGICALANDEQ: // fall through |
793 | case Token::Kind::TK_LOGICALXOREQ: // fall through |
794 | case Token::Kind::TK_LOGICALOREQ: // fall through |
795 | case Token::Kind::TK_BITWISEANDEQ: // fall through |
796 | case Token::Kind::TK_BITWISEXOREQ: // fall through |
797 | case Token::Kind::TK_BITWISEOREQ: return MetalCodeGenerator::kAssignment_Precedence; |
798 | case Token::Kind::TK_COMMA: return MetalCodeGenerator::kSequence_Precedence; |
799 | default: ABORT("unsupported binary operator" ); |
800 | } |
801 | } |
802 | |
803 | void MetalCodeGenerator::writeMatrixTimesEqualHelper(const Type& left, const Type& right, |
804 | const Type& result) { |
805 | String key = "TimesEqual" + left.name() + right.name(); |
806 | if (fHelpers.find(key) == fHelpers.end()) { |
807 | fExtraFunctions.printf("%s operator*=(thread %s& left, thread const %s& right) {\n" |
808 | " left = left * right;\n" |
809 | " return left;\n" |
810 | "}" , result.name().c_str(), left.name().c_str(), |
811 | right.name().c_str()); |
812 | } |
813 | } |
814 | |
815 | void MetalCodeGenerator::writeBinaryExpression(const BinaryExpression& b, |
816 | Precedence parentPrecedence) { |
817 | Precedence precedence = GetBinaryPrecedence(b.fOperator); |
818 | bool needParens = precedence >= parentPrecedence; |
819 | switch (b.fOperator) { |
820 | case Token::Kind::TK_EQEQ: |
821 | if (b.fLeft->fType.kind() == Type::kVector_Kind) { |
822 | this->write("all" ); |
823 | needParens = true; |
824 | } |
825 | break; |
826 | case Token::Kind::TK_NEQ: |
827 | if (b.fLeft->fType.kind() == Type::kVector_Kind) { |
828 | this->write("any" ); |
829 | needParens = true; |
830 | } |
831 | break; |
832 | default: |
833 | break; |
834 | } |
835 | if (needParens) { |
836 | this->write("(" ); |
837 | } |
838 | if (Compiler::IsAssignment(b.fOperator) && |
839 | Expression::kVariableReference_Kind == b.fLeft->fKind && |
840 | Variable::kParameter_Storage == ((VariableReference&) *b.fLeft).fVariable.fStorage && |
841 | (((VariableReference&) *b.fLeft).fVariable.fModifiers.fFlags & Modifiers::kOut_Flag)) { |
842 | // writing to an out parameter. Since we have to turn those into pointers, we have to |
843 | // dereference it here. |
844 | this->write("*" ); |
845 | } |
846 | if (b.fOperator == Token::Kind::TK_STAREQ && b.fLeft->fType.kind() == Type::kMatrix_Kind && |
847 | b.fRight->fType.kind() == Type::kMatrix_Kind) { |
848 | this->writeMatrixTimesEqualHelper(b.fLeft->fType, b.fRight->fType, b.fType); |
849 | } |
850 | this->writeExpression(*b.fLeft, precedence); |
851 | if (b.fOperator != Token::Kind::TK_EQ && Compiler::IsAssignment(b.fOperator) && |
852 | Expression::kSwizzle_Kind == b.fLeft->fKind && !b.fLeft->hasSideEffects()) { |
853 | // This doesn't compile in Metal: |
854 | // float4 x = float4(1); |
855 | // x.xy *= float2x2(...); |
856 | // with the error message "non-const reference cannot bind to vector element", |
857 | // but switching it to x.xy = x.xy * float2x2(...) fixes it. We perform this tranformation |
858 | // as long as the LHS has no side effects, and hope for the best otherwise. |
859 | this->write(" = " ); |
860 | this->writeExpression(*b.fLeft, kAssignment_Precedence); |
861 | this->write(" " ); |
862 | String op = Compiler::OperatorName(b.fOperator); |
863 | SkASSERT(op.endsWith("=" )); |
864 | this->write(op.substr(0, op.size() - 1).c_str()); |
865 | this->write(" " ); |
866 | } else { |
867 | this->write(String(" " ) + Compiler::OperatorName(b.fOperator) + " " ); |
868 | } |
869 | this->writeExpression(*b.fRight, precedence); |
870 | if (needParens) { |
871 | this->write(")" ); |
872 | } |
873 | } |
874 | |
875 | void MetalCodeGenerator::writeTernaryExpression(const TernaryExpression& t, |
876 | Precedence parentPrecedence) { |
877 | if (kTernary_Precedence >= parentPrecedence) { |
878 | this->write("(" ); |
879 | } |
880 | this->writeExpression(*t.fTest, kTernary_Precedence); |
881 | this->write(" ? " ); |
882 | this->writeExpression(*t.fIfTrue, kTernary_Precedence); |
883 | this->write(" : " ); |
884 | this->writeExpression(*t.fIfFalse, kTernary_Precedence); |
885 | if (kTernary_Precedence >= parentPrecedence) { |
886 | this->write(")" ); |
887 | } |
888 | } |
889 | |
890 | void MetalCodeGenerator::writePrefixExpression(const PrefixExpression& p, |
891 | Precedence parentPrecedence) { |
892 | if (kPrefix_Precedence >= parentPrecedence) { |
893 | this->write("(" ); |
894 | } |
895 | this->write(Compiler::OperatorName(p.fOperator)); |
896 | this->writeExpression(*p.fOperand, kPrefix_Precedence); |
897 | if (kPrefix_Precedence >= parentPrecedence) { |
898 | this->write(")" ); |
899 | } |
900 | } |
901 | |
902 | void MetalCodeGenerator::writePostfixExpression(const PostfixExpression& p, |
903 | Precedence parentPrecedence) { |
904 | if (kPostfix_Precedence >= parentPrecedence) { |
905 | this->write("(" ); |
906 | } |
907 | this->writeExpression(*p.fOperand, kPostfix_Precedence); |
908 | this->write(Compiler::OperatorName(p.fOperator)); |
909 | if (kPostfix_Precedence >= parentPrecedence) { |
910 | this->write(")" ); |
911 | } |
912 | } |
913 | |
914 | void MetalCodeGenerator::writeBoolLiteral(const BoolLiteral& b) { |
915 | this->write(b.fValue ? "true" : "false" ); |
916 | } |
917 | |
918 | void MetalCodeGenerator::writeIntLiteral(const IntLiteral& i) { |
919 | if (i.fType == *fContext.fUInt_Type) { |
920 | this->write(to_string(i.fValue & 0xffffffff) + "u" ); |
921 | } else { |
922 | this->write(to_string((int32_t) i.fValue)); |
923 | } |
924 | } |
925 | |
926 | void MetalCodeGenerator::writeFloatLiteral(const FloatLiteral& f) { |
927 | this->write(to_string(f.fValue)); |
928 | } |
929 | |
930 | void MetalCodeGenerator::writeSetting(const Setting& s) { |
931 | ABORT("internal error; setting was not folded to a constant during compilation\n" ); |
932 | } |
933 | |
934 | void MetalCodeGenerator::writeFunction(const FunctionDefinition& f) { |
935 | fRTHeightName = fProgram.fInputs.fRTHeight ? "_globals->_anonInterface0->u_skRTHeight" : "" ; |
936 | const char* separator = "" ; |
937 | if ("main" == f.fDeclaration.fName) { |
938 | switch (fProgram.fKind) { |
939 | case Program::kFragment_Kind: |
940 | this->write("fragment Outputs fragmentMain" ); |
941 | break; |
942 | case Program::kVertex_Kind: |
943 | this->write("vertex Outputs vertexMain" ); |
944 | break; |
945 | default: |
946 | SkDEBUGFAIL("unsupported kind of program" ); |
947 | } |
948 | this->write("(Inputs _in [[stage_in]]" ); |
949 | if (-1 != fUniformBuffer) { |
950 | this->write(", constant Uniforms& _uniforms [[buffer(" + |
951 | to_string(fUniformBuffer) + ")]]" ); |
952 | } |
953 | for (const auto& e : fProgram) { |
954 | if (ProgramElement::kVar_Kind == e.fKind) { |
955 | VarDeclarations& decls = (VarDeclarations&) e; |
956 | if (!decls.fVars.size()) { |
957 | continue; |
958 | } |
959 | for (const auto& stmt: decls.fVars) { |
960 | VarDeclaration& var = (VarDeclaration&) *stmt; |
961 | if (var.fVar->fType.kind() == Type::kSampler_Kind) { |
962 | if (var.fVar->fModifiers.fLayout.fBinding < 0) { |
963 | fErrors.error(decls.fOffset, |
964 | "Metal samplers must have 'layout(binding=...)'" ); |
965 | } |
966 | this->write(", texture2d<float> " ); // FIXME - support other texture types |
967 | this->writeName(var.fVar->fName); |
968 | this->write("[[texture(" ); |
969 | this->write(to_string(var.fVar->fModifiers.fLayout.fBinding)); |
970 | this->write(")]]" ); |
971 | this->write(", sampler " ); |
972 | this->writeName(var.fVar->fName); |
973 | this->write(SAMPLER_SUFFIX); |
974 | this->write("[[sampler(" ); |
975 | this->write(to_string(var.fVar->fModifiers.fLayout.fBinding)); |
976 | this->write(")]]" ); |
977 | } |
978 | } |
979 | } else if (ProgramElement::kInterfaceBlock_Kind == e.fKind) { |
980 | InterfaceBlock& intf = (InterfaceBlock&) e; |
981 | if ("sk_PerVertex" == intf.fTypeName) { |
982 | continue; |
983 | } |
984 | this->write(", constant " ); |
985 | this->writeType(intf.fVariable.fType); |
986 | this->write("& " ); |
987 | this->write(fInterfaceBlockNameMap[&intf]); |
988 | this->write(" [[buffer(" ); |
989 | this->write(to_string(intf.fVariable.fModifiers.fLayout.fBinding)); |
990 | this->write(")]]" ); |
991 | } |
992 | } |
993 | if (fProgram.fKind == Program::kFragment_Kind) { |
994 | if (fProgram.fInputs.fRTHeight && fInterfaceBlockNameMap.empty()) { |
995 | this->write(", constant sksl_synthetic_uniforms& _anonInterface0 [[buffer(1)]]" ); |
996 | fRTHeightName = "_anonInterface0.u_skRTHeight" ; |
997 | } |
998 | this->write(", bool _frontFacing [[front_facing]]" ); |
999 | this->write(", float4 _fragCoord [[position]]" ); |
1000 | } else if (fProgram.fKind == Program::kVertex_Kind) { |
1001 | this->write(", uint sk_VertexID [[vertex_id]], uint sk_InstanceID [[instance_id]]" ); |
1002 | } |
1003 | separator = ", " ; |
1004 | } else { |
1005 | this->writeType(f.fDeclaration.fReturnType); |
1006 | this->write(" " ); |
1007 | this->writeName(f.fDeclaration.fName); |
1008 | this->write("(" ); |
1009 | Requirements requirements = this->requirements(f.fDeclaration); |
1010 | if (requirements & kInputs_Requirement) { |
1011 | this->write("Inputs _in" ); |
1012 | separator = ", " ; |
1013 | } |
1014 | if (requirements & kOutputs_Requirement) { |
1015 | this->write(separator); |
1016 | this->write("thread Outputs* _out" ); |
1017 | separator = ", " ; |
1018 | } |
1019 | if (requirements & kUniforms_Requirement) { |
1020 | this->write(separator); |
1021 | this->write("Uniforms _uniforms" ); |
1022 | separator = ", " ; |
1023 | } |
1024 | if (requirements & kGlobals_Requirement) { |
1025 | this->write(separator); |
1026 | this->write("thread Globals* _globals" ); |
1027 | separator = ", " ; |
1028 | } |
1029 | if (requirements & kFragCoord_Requirement) { |
1030 | this->write(separator); |
1031 | this->write("float4 _fragCoord" ); |
1032 | separator = ", " ; |
1033 | } |
1034 | } |
1035 | for (const auto& param : f.fDeclaration.fParameters) { |
1036 | this->write(separator); |
1037 | separator = ", " ; |
1038 | this->writeModifiers(param->fModifiers, false); |
1039 | std::vector<int> sizes; |
1040 | const Type* type = ¶m->fType; |
1041 | while (Type::kArray_Kind == type->kind()) { |
1042 | sizes.push_back(type->columns()); |
1043 | type = &type->componentType(); |
1044 | } |
1045 | this->writeType(*type); |
1046 | if (param->fModifiers.fFlags & Modifiers::kOut_Flag) { |
1047 | this->write("*" ); |
1048 | } |
1049 | this->write(" " ); |
1050 | this->writeName(param->fName); |
1051 | for (int s : sizes) { |
1052 | if (s <= 0) { |
1053 | this->write("[]" ); |
1054 | } else { |
1055 | this->write("[" + to_string(s) + "]" ); |
1056 | } |
1057 | } |
1058 | } |
1059 | this->writeLine(") {" ); |
1060 | |
1061 | SkASSERT(!fProgram.fSettings.fFragColorIsInOut); |
1062 | |
1063 | if ("main" == f.fDeclaration.fName) { |
1064 | this->writeGlobalInit(); |
1065 | this->writeLine(" Outputs _outputStruct;" ); |
1066 | this->writeLine(" thread Outputs* _out = &_outputStruct;" ); |
1067 | } |
1068 | |
1069 | fFunctionHeader = "" ; |
1070 | OutputStream* oldOut = fOut; |
1071 | StringStream buffer; |
1072 | fOut = &buffer; |
1073 | fIndentation++; |
1074 | this->writeStatements(((Block&) *f.fBody).fStatements); |
1075 | if ("main" == f.fDeclaration.fName) { |
1076 | switch (fProgram.fKind) { |
1077 | case Program::kFragment_Kind: |
1078 | this->writeLine("return *_out;" ); |
1079 | break; |
1080 | case Program::kVertex_Kind: |
1081 | this->writeLine("_out->sk_Position.y = -_out->sk_Position.y;" ); |
1082 | this->writeLine("return *_out;" ); // FIXME - detect if function already has return |
1083 | break; |
1084 | default: |
1085 | SkDEBUGFAIL("unsupported kind of program" ); |
1086 | } |
1087 | } |
1088 | fIndentation--; |
1089 | this->writeLine("}" ); |
1090 | |
1091 | fOut = oldOut; |
1092 | this->write(fFunctionHeader); |
1093 | this->write(buffer.str()); |
1094 | } |
1095 | |
1096 | void MetalCodeGenerator::writeModifiers(const Modifiers& modifiers, |
1097 | bool globalContext) { |
1098 | if (modifiers.fFlags & Modifiers::kOut_Flag) { |
1099 | this->write("thread " ); |
1100 | } |
1101 | if (modifiers.fFlags & Modifiers::kConst_Flag) { |
1102 | this->write("constant " ); |
1103 | } |
1104 | } |
1105 | |
1106 | void MetalCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf) { |
1107 | if ("sk_PerVertex" == intf.fTypeName) { |
1108 | return; |
1109 | } |
1110 | this->writeModifiers(intf.fVariable.fModifiers, true); |
1111 | this->write("struct " ); |
1112 | this->writeLine(intf.fTypeName + " {" ); |
1113 | const Type* structType = &intf.fVariable.fType; |
1114 | fWrittenStructs.push_back(structType); |
1115 | while (Type::kArray_Kind == structType->kind()) { |
1116 | structType = &structType->componentType(); |
1117 | } |
1118 | fIndentation++; |
1119 | writeFields(structType->fields(), structType->fOffset, &intf); |
1120 | if (fProgram.fInputs.fRTHeight) { |
1121 | this->writeLine("float u_skRTHeight;" ); |
1122 | } |
1123 | fIndentation--; |
1124 | this->write("}" ); |
1125 | if (intf.fInstanceName.size()) { |
1126 | this->write(" " ); |
1127 | this->write(intf.fInstanceName); |
1128 | for (const auto& size : intf.fSizes) { |
1129 | this->write("[" ); |
1130 | if (size) { |
1131 | this->writeExpression(*size, kTopLevel_Precedence); |
1132 | } |
1133 | this->write("]" ); |
1134 | } |
1135 | fInterfaceBlockNameMap[&intf] = intf.fInstanceName; |
1136 | } else { |
1137 | fInterfaceBlockNameMap[&intf] = "_anonInterface" + to_string(fAnonInterfaceCount++); |
1138 | } |
1139 | this->writeLine(";" ); |
1140 | } |
1141 | |
// Writes one member declaration per field of a struct or interface block.
//
// A running byte offset is tracked using the Metal memory layout (MemoryLayout::kMetal_Standard).
// For a field with an explicit `layout(offset=...)`:
//  * an offset behind the running offset is reported as an error;
//  * an offset ahead of it is reached by emitting a `char padN[gap]` member;
//  * an offset that is not a multiple of the field's alignment is reported as an error.
// `parentOffset` is only used as the position for these error reports.
//
// Array fields are written as `elementType name[N]...`; a non-positive size is written as
// `[]`. When `parentIntf` is non-null, each field is recorded in fInterfaceBlockMap as
// belonging to that block.
//
// NOTE(review): for fields without an explicit offset the running offset only adds the field
// size and is never rounded up to the field's alignment. It may therefore lag the offset Metal
// actually assigns when implicit and explicit offsets are mixed; confirm before relying on
// that.
void MetalCodeGenerator::writeFields(const std::vector<Type::Field>& fields, int parentOffset,
                                     const InterfaceBlock* parentIntf) {
    MemoryLayout memoryLayout(MemoryLayout::kMetal_Standard);
    int currentOffset = 0;
    for (const auto& field: fields) {
        int fieldOffset = field.fModifiers.fLayout.fOffset;
        const Type* fieldType = field.fType;
        // -1 means no explicit offset was given.
        if (fieldOffset != -1) {
            if (currentOffset > fieldOffset) {
                fErrors.error(parentOffset,
                              "offset of field '" + field.fName + "' must be at least " +
                              to_string((int) currentOffset));
            } else if (currentOffset < fieldOffset) {
                // Fill the gap up to the requested offset with an explicit padding array.
                this->write("char pad" );
                this->write(to_string(fPaddingCount++));
                this->write("[" );
                this->write(to_string(fieldOffset - currentOffset));
                this->writeLine("];" );
                currentOffset = fieldOffset;
            }
            int alignment = memoryLayout.alignment(*fieldType);
            if (fieldOffset % alignment) {
                fErrors.error(parentOffset,
                              "offset of field '" + field.fName + "' must be a multiple of " +
                              to_string((int) alignment));
            }
        }
        currentOffset += memoryLayout.size(*fieldType);
        // Peel array dimensions so they can be written after the field name.
        std::vector<int> sizes;
        while (fieldType->kind() == Type::kArray_Kind) {
            sizes.push_back(fieldType->columns());
            fieldType = &fieldType->componentType();
        }
        this->writeModifiers(field.fModifiers, false);
        this->writeType(*fieldType);
        this->write(" " );
        this->writeName(field.fName);
        for (int s : sizes) {
            if (s <= 0) {
                this->write("[]" );
            } else {
                this->write("[" + to_string(s) + "]" );
            }
        }
        this->writeLine(";" );
        if (parentIntf) {
            fInterfaceBlockMap[&field] = parentIntf;
        }
    }
}
1192 | |
1193 | void MetalCodeGenerator::writeVarInitializer(const Variable& var, const Expression& value) { |
1194 | this->writeExpression(value, kTopLevel_Precedence); |
1195 | } |
1196 | |
1197 | void MetalCodeGenerator::writeName(const String& name) { |
1198 | if (fReservedWords.find(name) != fReservedWords.end()) { |
1199 | this->write("_" ); // adding underscore before name to avoid conflict with reserved words |
1200 | } |
1201 | this->write(name); |
1202 | } |
1203 | |
1204 | void MetalCodeGenerator::writeVarDeclarations(const VarDeclarations& decl, bool global) { |
1205 | SkASSERT(decl.fVars.size() > 0); |
1206 | bool wroteType = false; |
1207 | for (const auto& stmt : decl.fVars) { |
1208 | VarDeclaration& var = (VarDeclaration&) *stmt; |
1209 | if (global && !(var.fVar->fModifiers.fFlags & Modifiers::kConst_Flag)) { |
1210 | continue; |
1211 | } |
1212 | if (wroteType) { |
1213 | this->write(", " ); |
1214 | } else { |
1215 | this->writeModifiers(var.fVar->fModifiers, global); |
1216 | this->writeType(decl.fBaseType); |
1217 | this->write(" " ); |
1218 | wroteType = true; |
1219 | } |
1220 | this->writeName(var.fVar->fName); |
1221 | for (const auto& size : var.fSizes) { |
1222 | this->write("[" ); |
1223 | if (size) { |
1224 | this->writeExpression(*size, kTopLevel_Precedence); |
1225 | } |
1226 | this->write("]" ); |
1227 | } |
1228 | if (var.fValue) { |
1229 | this->write(" = " ); |
1230 | this->writeVarInitializer(*var.fVar, *var.fValue); |
1231 | } |
1232 | } |
1233 | if (wroteType) { |
1234 | this->write(";" ); |
1235 | } |
1236 | } |
1237 | |
1238 | void MetalCodeGenerator::writeStatement(const Statement& s) { |
1239 | switch (s.fKind) { |
1240 | case Statement::kBlock_Kind: |
1241 | this->writeBlock((Block&) s); |
1242 | break; |
1243 | case Statement::kExpression_Kind: |
1244 | this->writeExpression(*((ExpressionStatement&) s).fExpression, kTopLevel_Precedence); |
1245 | this->write(";" ); |
1246 | break; |
1247 | case Statement::kReturn_Kind: |
1248 | this->writeReturnStatement((ReturnStatement&) s); |
1249 | break; |
1250 | case Statement::kVarDeclarations_Kind: |
1251 | this->writeVarDeclarations(*((VarDeclarationsStatement&) s).fDeclaration, false); |
1252 | break; |
1253 | case Statement::kIf_Kind: |
1254 | this->writeIfStatement((IfStatement&) s); |
1255 | break; |
1256 | case Statement::kFor_Kind: |
1257 | this->writeForStatement((ForStatement&) s); |
1258 | break; |
1259 | case Statement::kWhile_Kind: |
1260 | this->writeWhileStatement((WhileStatement&) s); |
1261 | break; |
1262 | case Statement::kDo_Kind: |
1263 | this->writeDoStatement((DoStatement&) s); |
1264 | break; |
1265 | case Statement::kSwitch_Kind: |
1266 | this->writeSwitchStatement((SwitchStatement&) s); |
1267 | break; |
1268 | case Statement::kBreak_Kind: |
1269 | this->write("break;" ); |
1270 | break; |
1271 | case Statement::kContinue_Kind: |
1272 | this->write("continue;" ); |
1273 | break; |
1274 | case Statement::kDiscard_Kind: |
1275 | this->write("discard_fragment();" ); |
1276 | break; |
1277 | case Statement::kNop_Kind: |
1278 | this->write(";" ); |
1279 | break; |
1280 | default: |
1281 | #ifdef SK_DEBUG |
1282 | ABORT("unsupported statement: %s" , s.description().c_str()); |
1283 | #endif |
1284 | break; |
1285 | } |
1286 | } |
1287 | |
1288 | void MetalCodeGenerator::writeStatements(const std::vector<std::unique_ptr<Statement>>& statements) { |
1289 | for (const auto& s : statements) { |
1290 | if (!s->isEmpty()) { |
1291 | this->writeStatement(*s); |
1292 | this->writeLine(); |
1293 | } |
1294 | } |
1295 | } |
1296 | |
1297 | void MetalCodeGenerator::writeBlock(const Block& b) { |
1298 | if (b.fIsScope) { |
1299 | this->writeLine("{" ); |
1300 | fIndentation++; |
1301 | } |
1302 | this->writeStatements(b.fStatements); |
1303 | if (b.fIsScope) { |
1304 | fIndentation--; |
1305 | this->write("}" ); |
1306 | } |
1307 | } |
1308 | |
1309 | void MetalCodeGenerator::writeIfStatement(const IfStatement& stmt) { |
1310 | this->write("if (" ); |
1311 | this->writeExpression(*stmt.fTest, kTopLevel_Precedence); |
1312 | this->write(") " ); |
1313 | this->writeStatement(*stmt.fIfTrue); |
1314 | if (stmt.fIfFalse) { |
1315 | this->write(" else " ); |
1316 | this->writeStatement(*stmt.fIfFalse); |
1317 | } |
1318 | } |
1319 | |
1320 | void MetalCodeGenerator::writeForStatement(const ForStatement& f) { |
1321 | this->write("for (" ); |
1322 | if (f.fInitializer && !f.fInitializer->isEmpty()) { |
1323 | this->writeStatement(*f.fInitializer); |
1324 | } else { |
1325 | this->write("; " ); |
1326 | } |
1327 | if (f.fTest) { |
1328 | this->writeExpression(*f.fTest, kTopLevel_Precedence); |
1329 | } |
1330 | this->write("; " ); |
1331 | if (f.fNext) { |
1332 | this->writeExpression(*f.fNext, kTopLevel_Precedence); |
1333 | } |
1334 | this->write(") " ); |
1335 | this->writeStatement(*f.fStatement); |
1336 | } |
1337 | |
1338 | void MetalCodeGenerator::writeWhileStatement(const WhileStatement& w) { |
1339 | this->write("while (" ); |
1340 | this->writeExpression(*w.fTest, kTopLevel_Precedence); |
1341 | this->write(") " ); |
1342 | this->writeStatement(*w.fStatement); |
1343 | } |
1344 | |
1345 | void MetalCodeGenerator::writeDoStatement(const DoStatement& d) { |
1346 | this->write("do " ); |
1347 | this->writeStatement(*d.fStatement); |
1348 | this->write(" while (" ); |
1349 | this->writeExpression(*d.fTest, kTopLevel_Precedence); |
1350 | this->write(");" ); |
1351 | } |
1352 | |
1353 | void MetalCodeGenerator::writeSwitchStatement(const SwitchStatement& s) { |
1354 | this->write("switch (" ); |
1355 | this->writeExpression(*s.fValue, kTopLevel_Precedence); |
1356 | this->writeLine(") {" ); |
1357 | fIndentation++; |
1358 | for (const auto& c : s.fCases) { |
1359 | if (c->fValue) { |
1360 | this->write("case " ); |
1361 | this->writeExpression(*c->fValue, kTopLevel_Precedence); |
1362 | this->writeLine(":" ); |
1363 | } else { |
1364 | this->writeLine("default:" ); |
1365 | } |
1366 | fIndentation++; |
1367 | for (const auto& stmt : c->fStatements) { |
1368 | this->writeStatement(*stmt); |
1369 | this->writeLine(); |
1370 | } |
1371 | fIndentation--; |
1372 | } |
1373 | fIndentation--; |
1374 | this->write("}" ); |
1375 | } |
1376 | |
1377 | void MetalCodeGenerator::writeReturnStatement(const ReturnStatement& r) { |
1378 | this->write("return" ); |
1379 | if (r.fExpression) { |
1380 | this->write(" " ); |
1381 | this->writeExpression(*r.fExpression, kTopLevel_Precedence); |
1382 | } |
1383 | this->write(";" ); |
1384 | } |
1385 | |
1386 | void MetalCodeGenerator::() { |
1387 | this->write("#include <metal_stdlib>\n" ); |
1388 | this->write("#include <simd/simd.h>\n" ); |
1389 | this->write("using namespace metal;\n" ); |
1390 | } |
1391 | |
1392 | void MetalCodeGenerator::writeUniformStruct() { |
1393 | for (const auto& e : fProgram) { |
1394 | if (ProgramElement::kVar_Kind == e.fKind) { |
1395 | VarDeclarations& decls = (VarDeclarations&) e; |
1396 | if (!decls.fVars.size()) { |
1397 | continue; |
1398 | } |
1399 | const Variable& first = *((VarDeclaration&) *decls.fVars[0]).fVar; |
1400 | if (first.fModifiers.fFlags & Modifiers::kUniform_Flag && |
1401 | first.fType.kind() != Type::kSampler_Kind) { |
1402 | if (-1 == fUniformBuffer) { |
1403 | this->write("struct Uniforms {\n" ); |
1404 | fUniformBuffer = first.fModifiers.fLayout.fSet; |
1405 | if (-1 == fUniformBuffer) { |
1406 | fErrors.error(decls.fOffset, "Metal uniforms must have 'layout(set=...)'" ); |
1407 | } |
1408 | } else if (first.fModifiers.fLayout.fSet != fUniformBuffer) { |
1409 | if (-1 == fUniformBuffer) { |
1410 | fErrors.error(decls.fOffset, "Metal backend requires all uniforms to have " |
1411 | "the same 'layout(set=...)'" ); |
1412 | } |
1413 | } |
1414 | this->write(" " ); |
1415 | this->writeType(first.fType); |
1416 | this->write(" " ); |
1417 | for (const auto& stmt : decls.fVars) { |
1418 | VarDeclaration& var = (VarDeclaration&) *stmt; |
1419 | this->writeName(var.fVar->fName); |
1420 | } |
1421 | this->write(";\n" ); |
1422 | } |
1423 | } |
1424 | } |
1425 | if (-1 != fUniformBuffer) { |
1426 | this->write("};\n" ); |
1427 | } |
1428 | } |
1429 | |
1430 | void MetalCodeGenerator::writeInputStruct() { |
1431 | this->write("struct Inputs {\n" ); |
1432 | for (const auto& e : fProgram) { |
1433 | if (ProgramElement::kVar_Kind == e.fKind) { |
1434 | VarDeclarations& decls = (VarDeclarations&) e; |
1435 | if (!decls.fVars.size()) { |
1436 | continue; |
1437 | } |
1438 | const Variable& first = *((VarDeclaration&) *decls.fVars[0]).fVar; |
1439 | if (first.fModifiers.fFlags & Modifiers::kIn_Flag && |
1440 | -1 == first.fModifiers.fLayout.fBuiltin) { |
1441 | this->write(" " ); |
1442 | this->writeType(first.fType); |
1443 | this->write(" " ); |
1444 | for (const auto& stmt : decls.fVars) { |
1445 | VarDeclaration& var = (VarDeclaration&) *stmt; |
1446 | this->writeName(var.fVar->fName); |
1447 | if (-1 != var.fVar->fModifiers.fLayout.fLocation) { |
1448 | if (fProgram.fKind == Program::kVertex_Kind) { |
1449 | this->write(" [[attribute(" + |
1450 | to_string(var.fVar->fModifiers.fLayout.fLocation) + ")]]" ); |
1451 | } else if (fProgram.fKind == Program::kFragment_Kind) { |
1452 | this->write(" [[user(locn" + |
1453 | to_string(var.fVar->fModifiers.fLayout.fLocation) + ")]]" ); |
1454 | } |
1455 | } |
1456 | } |
1457 | this->write(";\n" ); |
1458 | } |
1459 | } |
1460 | } |
1461 | this->write("};\n" ); |
1462 | } |
1463 | |
1464 | void MetalCodeGenerator::writeOutputStruct() { |
1465 | this->write("struct Outputs {\n" ); |
1466 | if (fProgram.fKind == Program::kVertex_Kind) { |
1467 | this->write(" float4 sk_Position [[position]];\n" ); |
1468 | } else if (fProgram.fKind == Program::kFragment_Kind) { |
1469 | this->write(" float4 sk_FragColor [[color(0)]];\n" ); |
1470 | } |
1471 | for (const auto& e : fProgram) { |
1472 | if (ProgramElement::kVar_Kind == e.fKind) { |
1473 | VarDeclarations& decls = (VarDeclarations&) e; |
1474 | if (!decls.fVars.size()) { |
1475 | continue; |
1476 | } |
1477 | const Variable& first = *((VarDeclaration&) *decls.fVars[0]).fVar; |
1478 | if (first.fModifiers.fFlags & Modifiers::kOut_Flag && |
1479 | -1 == first.fModifiers.fLayout.fBuiltin) { |
1480 | this->write(" " ); |
1481 | this->writeType(first.fType); |
1482 | this->write(" " ); |
1483 | for (const auto& stmt : decls.fVars) { |
1484 | VarDeclaration& var = (VarDeclaration&) *stmt; |
1485 | this->writeName(var.fVar->fName); |
1486 | if (fProgram.fKind == Program::kVertex_Kind) { |
1487 | this->write(" [[user(locn" + |
1488 | to_string(var.fVar->fModifiers.fLayout.fLocation) + ")]]" ); |
1489 | } else if (fProgram.fKind == Program::kFragment_Kind) { |
1490 | this->write(" [[color(" + |
1491 | to_string(var.fVar->fModifiers.fLayout.fLocation) +")" ); |
1492 | int colorIndex = var.fVar->fModifiers.fLayout.fIndex; |
1493 | if (colorIndex) { |
1494 | this->write(", index(" + to_string(colorIndex) + ")" ); |
1495 | } |
1496 | this->write("]]" ); |
1497 | } |
1498 | } |
1499 | this->write(";\n" ); |
1500 | } |
1501 | } |
1502 | } |
1503 | if (fProgram.fKind == Program::kVertex_Kind) { |
1504 | this->write(" float sk_PointSize;\n" ); |
1505 | } |
1506 | this->write("};\n" ); |
1507 | } |
1508 | |
1509 | void MetalCodeGenerator::writeInterfaceBlocks() { |
1510 | bool wroteInterfaceBlock = false; |
1511 | for (const auto& e : fProgram) { |
1512 | if (ProgramElement::kInterfaceBlock_Kind == e.fKind) { |
1513 | this->writeInterfaceBlock((InterfaceBlock&) e); |
1514 | wroteInterfaceBlock = true; |
1515 | } |
1516 | } |
1517 | if (!wroteInterfaceBlock && fProgram.fInputs.fRTHeight) { |
1518 | this->writeLine("struct sksl_synthetic_uniforms {" ); |
1519 | this->writeLine(" float u_skRTHeight;" ); |
1520 | this->writeLine("};" ); |
1521 | } |
1522 | } |
1523 | |
1524 | void MetalCodeGenerator::visitGlobalStruct(GlobalStructVisitor* visitor) { |
1525 | // Visit the interface blocks. |
1526 | for (const auto& [interfaceType, interfaceName] : fInterfaceBlockNameMap) { |
1527 | visitor->VisitInterfaceBlock(*interfaceType, interfaceName); |
1528 | } |
1529 | for (const ProgramElement& element : fProgram) { |
1530 | if (element.fKind != ProgramElement::kVar_Kind) { |
1531 | continue; |
1532 | } |
1533 | const VarDeclarations& decls = static_cast<const VarDeclarations&>(element); |
1534 | if (decls.fVars.empty()) { |
1535 | continue; |
1536 | } |
1537 | const Variable& first = *((VarDeclaration&) *decls.fVars[0]).fVar; |
1538 | if ((!first.fModifiers.fFlags && -1 == first.fModifiers.fLayout.fBuiltin) || |
1539 | first.fType.kind() == Type::kSampler_Kind) { |
1540 | for (const auto& stmt : decls.fVars) { |
1541 | VarDeclaration& var = static_cast<VarDeclaration&>(*stmt); |
1542 | |
1543 | if (var.fVar->fType.kind() == Type::kSampler_Kind) { |
1544 | // Samplers are represented as a "texture/sampler" duo in the global struct. |
1545 | visitor->VisitTexture(first.fType, var.fVar->fName); |
1546 | visitor->VisitSampler(first.fType, String(var.fVar->fName) + SAMPLER_SUFFIX); |
1547 | } else { |
1548 | // Visit a regular variable. |
1549 | visitor->VisitVariable(*var.fVar, var.fValue.get()); |
1550 | } |
1551 | } |
1552 | } |
1553 | } |
1554 | } |
1555 | |
1556 | void MetalCodeGenerator::writeGlobalStruct() { |
1557 | class : public GlobalStructVisitor { |
1558 | public: |
1559 | void VisitInterfaceBlock(const InterfaceBlock& block, const String& blockName) override { |
1560 | this->AddElement(); |
1561 | fCodeGen->write(" constant " ); |
1562 | fCodeGen->write(block.fTypeName); |
1563 | fCodeGen->write("* " ); |
1564 | fCodeGen->writeName(blockName); |
1565 | fCodeGen->write(";\n" ); |
1566 | } |
1567 | void VisitTexture(const Type& type, const String& name) override { |
1568 | this->AddElement(); |
1569 | fCodeGen->write(" " ); |
1570 | fCodeGen->writeType(type); |
1571 | fCodeGen->write(" " ); |
1572 | fCodeGen->writeName(name); |
1573 | fCodeGen->write(";\n" ); |
1574 | } |
1575 | void VisitSampler(const Type&, const String& name) override { |
1576 | this->AddElement(); |
1577 | fCodeGen->write(" sampler " ); |
1578 | fCodeGen->writeName(name); |
1579 | fCodeGen->write(";\n" ); |
1580 | } |
1581 | void VisitVariable(const Variable& var, const Expression* value) override { |
1582 | this->AddElement(); |
1583 | fCodeGen->write(" " ); |
1584 | fCodeGen->writeType(var.fType); |
1585 | fCodeGen->write(" " ); |
1586 | fCodeGen->writeName(var.fName); |
1587 | fCodeGen->write(";\n" ); |
1588 | } |
1589 | void AddElement() { |
1590 | if (fFirst) { |
1591 | fCodeGen->write("struct Globals {\n" ); |
1592 | fFirst = false; |
1593 | } |
1594 | } |
1595 | void Finish() { |
1596 | if (!fFirst) { |
1597 | fCodeGen->write("};" ); |
1598 | fFirst = true; |
1599 | } |
1600 | } |
1601 | |
1602 | MetalCodeGenerator* fCodeGen = nullptr; |
1603 | bool fFirst = true; |
1604 | } visitor; |
1605 | |
1606 | visitor.fCodeGen = this; |
1607 | this->visitGlobalStruct(&visitor); |
1608 | visitor.Finish(); |
1609 | } |
1610 | |
1611 | void MetalCodeGenerator::writeGlobalInit() { |
1612 | class : public GlobalStructVisitor { |
1613 | public: |
1614 | void VisitInterfaceBlock(const InterfaceBlock& blockType, |
1615 | const String& blockName) override { |
1616 | this->AddElement(); |
1617 | fCodeGen->write("&" ); |
1618 | fCodeGen->writeName(blockName); |
1619 | } |
1620 | void VisitTexture(const Type&, const String& name) override { |
1621 | this->AddElement(); |
1622 | fCodeGen->writeName(name); |
1623 | } |
1624 | void VisitSampler(const Type&, const String& name) override { |
1625 | this->AddElement(); |
1626 | fCodeGen->writeName(name); |
1627 | } |
1628 | void VisitVariable(const Variable& var, const Expression* value) override { |
1629 | this->AddElement(); |
1630 | if (value) { |
1631 | fCodeGen->writeVarInitializer(var, *value); |
1632 | } else { |
1633 | fCodeGen->write("{}" ); |
1634 | } |
1635 | } |
1636 | void AddElement() { |
1637 | if (fFirst) { |
1638 | fCodeGen->write(" Globals globalStruct{" ); |
1639 | fFirst = false; |
1640 | } else { |
1641 | fCodeGen->write(", " ); |
1642 | } |
1643 | } |
1644 | void Finish() { |
1645 | if (!fFirst) { |
1646 | fCodeGen->writeLine("};" ); |
1647 | fCodeGen->writeLine(" thread Globals* _globals = &globalStruct;" ); |
1648 | fCodeGen->writeLine(" (void)_globals;" ); |
1649 | } |
1650 | } |
1651 | MetalCodeGenerator* fCodeGen = nullptr; |
1652 | bool fFirst = true; |
1653 | } visitor; |
1654 | |
1655 | visitor.fCodeGen = this; |
1656 | this->visitGlobalStruct(&visitor); |
1657 | visitor.Finish(); |
1658 | } |
1659 | |
1660 | void MetalCodeGenerator::writeProgramElement(const ProgramElement& e) { |
1661 | switch (e.fKind) { |
1662 | case ProgramElement::kExtension_Kind: |
1663 | break; |
1664 | case ProgramElement::kVar_Kind: { |
1665 | VarDeclarations& decl = (VarDeclarations&) e; |
1666 | if (decl.fVars.size() > 0) { |
1667 | int builtin = ((VarDeclaration&) *decl.fVars[0]).fVar->fModifiers.fLayout.fBuiltin; |
1668 | if (-1 == builtin) { |
1669 | // normal var |
1670 | this->writeVarDeclarations(decl, true); |
1671 | this->writeLine(); |
1672 | } else if (SK_FRAGCOLOR_BUILTIN == builtin) { |
1673 | // ignore |
1674 | } |
1675 | } |
1676 | break; |
1677 | } |
1678 | case ProgramElement::kInterfaceBlock_Kind: |
1679 | // handled in writeInterfaceBlocks, do nothing |
1680 | break; |
1681 | case ProgramElement::kFunction_Kind: |
1682 | this->writeFunction((FunctionDefinition&) e); |
1683 | break; |
1684 | case ProgramElement::kModifiers_Kind: |
1685 | this->writeModifiers(((ModifiersDeclaration&) e).fModifiers, true); |
1686 | this->writeLine(";" ); |
1687 | break; |
1688 | default: |
1689 | #ifdef SK_DEBUG |
1690 | ABORT("unsupported program element: %s\n" , e.description().c_str()); |
1691 | #endif |
1692 | break; |
1693 | } |
1694 | } |
1695 | |
1696 | MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Expression* e) { |
1697 | if (!e) { |
1698 | return kNo_Requirements; |
1699 | } |
1700 | switch (e->fKind) { |
1701 | case Expression::kFunctionCall_Kind: { |
1702 | const FunctionCall& f = (const FunctionCall&) *e; |
1703 | Requirements result = this->requirements(f.fFunction); |
1704 | for (const auto& arg : f.fArguments) { |
1705 | result |= this->requirements(arg.get()); |
1706 | } |
1707 | return result; |
1708 | } |
1709 | case Expression::kConstructor_Kind: { |
1710 | const Constructor& c = (const Constructor&) *e; |
1711 | Requirements result = kNo_Requirements; |
1712 | for (const auto& arg : c.fArguments) { |
1713 | result |= this->requirements(arg.get()); |
1714 | } |
1715 | return result; |
1716 | } |
1717 | case Expression::kFieldAccess_Kind: { |
1718 | const FieldAccess& f = (const FieldAccess&) *e; |
1719 | if (FieldAccess::kAnonymousInterfaceBlock_OwnerKind == f.fOwnerKind) { |
1720 | return kGlobals_Requirement; |
1721 | } |
1722 | return this->requirements(f.fBase.get()); |
1723 | } |
1724 | case Expression::kSwizzle_Kind: |
1725 | return this->requirements(((const Swizzle&) *e).fBase.get()); |
1726 | case Expression::kBinary_Kind: { |
1727 | const BinaryExpression& b = (const BinaryExpression&) *e; |
1728 | return this->requirements(b.fLeft.get()) | this->requirements(b.fRight.get()); |
1729 | } |
1730 | case Expression::kIndex_Kind: { |
1731 | const IndexExpression& idx = (const IndexExpression&) *e; |
1732 | return this->requirements(idx.fBase.get()) | this->requirements(idx.fIndex.get()); |
1733 | } |
1734 | case Expression::kPrefix_Kind: |
1735 | return this->requirements(((const PrefixExpression&) *e).fOperand.get()); |
1736 | case Expression::kPostfix_Kind: |
1737 | return this->requirements(((const PostfixExpression&) *e).fOperand.get()); |
1738 | case Expression::kTernary_Kind: { |
1739 | const TernaryExpression& t = (const TernaryExpression&) *e; |
1740 | return this->requirements(t.fTest.get()) | this->requirements(t.fIfTrue.get()) | |
1741 | this->requirements(t.fIfFalse.get()); |
1742 | } |
1743 | case Expression::kVariableReference_Kind: { |
1744 | const VariableReference& v = (const VariableReference&) *e; |
1745 | Requirements result = kNo_Requirements; |
1746 | if (v.fVariable.fModifiers.fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN) { |
1747 | result = kGlobals_Requirement | kFragCoord_Requirement; |
1748 | } else if (Variable::kGlobal_Storage == v.fVariable.fStorage) { |
1749 | if (v.fVariable.fModifiers.fFlags & Modifiers::kIn_Flag) { |
1750 | result = kInputs_Requirement; |
1751 | } else if (v.fVariable.fModifiers.fFlags & Modifiers::kOut_Flag) { |
1752 | result = kOutputs_Requirement; |
1753 | } else if (v.fVariable.fModifiers.fFlags & Modifiers::kUniform_Flag && |
1754 | v.fVariable.fType.kind() != Type::kSampler_Kind) { |
1755 | result = kUniforms_Requirement; |
1756 | } else { |
1757 | result = kGlobals_Requirement; |
1758 | } |
1759 | } |
1760 | return result; |
1761 | } |
1762 | default: |
1763 | return kNo_Requirements; |
1764 | } |
1765 | } |
1766 | |
1767 | MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Statement* s) { |
1768 | if (!s) { |
1769 | return kNo_Requirements; |
1770 | } |
1771 | switch (s->fKind) { |
1772 | case Statement::kBlock_Kind: { |
1773 | Requirements result = kNo_Requirements; |
1774 | for (const auto& child : ((const Block*) s)->fStatements) { |
1775 | result |= this->requirements(child.get()); |
1776 | } |
1777 | return result; |
1778 | } |
1779 | case Statement::kVarDeclaration_Kind: { |
1780 | const VarDeclaration& var = (const VarDeclaration&) *s; |
1781 | return this->requirements(var.fValue.get()); |
1782 | } |
1783 | case Statement::kVarDeclarations_Kind: { |
1784 | Requirements result = kNo_Requirements; |
1785 | const VarDeclarations& decls = *((const VarDeclarationsStatement&) *s).fDeclaration; |
1786 | for (const auto& stmt : decls.fVars) { |
1787 | result |= this->requirements(stmt.get()); |
1788 | } |
1789 | return result; |
1790 | } |
1791 | case Statement::kExpression_Kind: |
1792 | return this->requirements(((const ExpressionStatement&) *s).fExpression.get()); |
1793 | case Statement::kReturn_Kind: { |
1794 | const ReturnStatement& r = (const ReturnStatement&) *s; |
1795 | return this->requirements(r.fExpression.get()); |
1796 | } |
1797 | case Statement::kIf_Kind: { |
1798 | const IfStatement& i = (const IfStatement&) *s; |
1799 | return this->requirements(i.fTest.get()) | |
1800 | this->requirements(i.fIfTrue.get()) | |
1801 | this->requirements(i.fIfFalse.get()); |
1802 | } |
1803 | case Statement::kFor_Kind: { |
1804 | const ForStatement& f = (const ForStatement&) *s; |
1805 | return this->requirements(f.fInitializer.get()) | |
1806 | this->requirements(f.fTest.get()) | |
1807 | this->requirements(f.fNext.get()) | |
1808 | this->requirements(f.fStatement.get()); |
1809 | } |
1810 | case Statement::kWhile_Kind: { |
1811 | const WhileStatement& w = (const WhileStatement&) *s; |
1812 | return this->requirements(w.fTest.get()) | |
1813 | this->requirements(w.fStatement.get()); |
1814 | } |
1815 | case Statement::kDo_Kind: { |
1816 | const DoStatement& d = (const DoStatement&) *s; |
1817 | return this->requirements(d.fTest.get()) | |
1818 | this->requirements(d.fStatement.get()); |
1819 | } |
1820 | case Statement::kSwitch_Kind: { |
1821 | const SwitchStatement& sw = (const SwitchStatement&) *s; |
1822 | Requirements result = this->requirements(sw.fValue.get()); |
1823 | for (const auto& c : sw.fCases) { |
1824 | for (const auto& st : c->fStatements) { |
1825 | result |= this->requirements(st.get()); |
1826 | } |
1827 | } |
1828 | return result; |
1829 | } |
1830 | default: |
1831 | return kNo_Requirements; |
1832 | } |
1833 | } |
1834 | |
1835 | MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const FunctionDeclaration& f) { |
1836 | if (f.fBuiltin) { |
1837 | return kNo_Requirements; |
1838 | } |
1839 | auto found = fRequirements.find(&f); |
1840 | if (found == fRequirements.end()) { |
1841 | fRequirements[&f] = kNo_Requirements; |
1842 | for (const auto& e : fProgram) { |
1843 | if (ProgramElement::kFunction_Kind == e.fKind) { |
1844 | const FunctionDefinition& def = (const FunctionDefinition&) e; |
1845 | if (&def.fDeclaration == &f) { |
1846 | Requirements reqs = this->requirements(def.fBody.get()); |
1847 | fRequirements[&f] = reqs; |
1848 | return reqs; |
1849 | } |
1850 | } |
1851 | } |
1852 | } |
1853 | return found->second; |
1854 | } |
1855 | |
1856 | bool MetalCodeGenerator::generateCode() { |
1857 | OutputStream* rawOut = fOut; |
1858 | fOut = &fHeader; |
1859 | fProgramKind = fProgram.fKind; |
1860 | this->writeHeader(); |
1861 | this->writeUniformStruct(); |
1862 | this->writeInputStruct(); |
1863 | this->writeOutputStruct(); |
1864 | this->writeInterfaceBlocks(); |
1865 | this->writeGlobalStruct(); |
1866 | StringStream body; |
1867 | fOut = &body; |
1868 | for (const auto& e : fProgram) { |
1869 | this->writeProgramElement(e); |
1870 | } |
1871 | fOut = rawOut; |
1872 | |
1873 | write_stringstream(fHeader, *rawOut); |
1874 | write_stringstream(fExtraFunctions, *rawOut); |
1875 | write_stringstream(body, *rawOut); |
1876 | return true; |
1877 | } |
1878 | |
1879 | } // namespace SkSL |
1880 | |