1 | /* |
2 | * Copyright 2016 Google Inc. |
3 | * |
4 | * Use of this source code is governed by a BSD-style license that can be |
5 | * found in the LICENSE file. |
6 | */ |
7 | |
8 | #include "src/sksl/SkSLMetalCodeGenerator.h" |
9 | |
10 | #include "src/sksl/SkSLCompiler.h" |
11 | #include "src/sksl/ir/SkSLExpressionStatement.h" |
12 | #include "src/sksl/ir/SkSLExtension.h" |
13 | #include "src/sksl/ir/SkSLIndexExpression.h" |
14 | #include "src/sksl/ir/SkSLModifiersDeclaration.h" |
15 | #include "src/sksl/ir/SkSLNop.h" |
16 | #include "src/sksl/ir/SkSLVariableReference.h" |
17 | |
18 | namespace SkSL { |
19 | |
20 | void MetalCodeGenerator::setupIntrinsics() { |
21 | #define METAL(x) std::make_pair(kMetal_IntrinsicKind, k ## x ## _MetalIntrinsic) |
22 | #define SPECIAL(x) std::make_pair(kSpecial_IntrinsicKind, k ## x ## _SpecialIntrinsic) |
23 | fIntrinsicMap[String("sample" )] = SPECIAL(Texture); |
24 | fIntrinsicMap[String("mod" )] = SPECIAL(Mod); |
25 | fIntrinsicMap[String("equal" )] = METAL(Equal); |
26 | fIntrinsicMap[String("notEqual" )] = METAL(NotEqual); |
27 | fIntrinsicMap[String("lessThan" )] = METAL(LessThan); |
28 | fIntrinsicMap[String("lessThanEqual" )] = METAL(LessThanEqual); |
29 | fIntrinsicMap[String("greaterThan" )] = METAL(GreaterThan); |
30 | fIntrinsicMap[String("greaterThanEqual" )] = METAL(GreaterThanEqual); |
31 | } |
32 | |
33 | void MetalCodeGenerator::write(const char* s) { |
34 | if (!s[0]) { |
35 | return; |
36 | } |
37 | if (fAtLineStart) { |
38 | for (int i = 0; i < fIndentation; i++) { |
39 | fOut->writeText(" " ); |
40 | } |
41 | } |
42 | fOut->writeText(s); |
43 | fAtLineStart = false; |
44 | } |
45 | |
46 | void MetalCodeGenerator::writeLine(const char* s) { |
47 | this->write(s); |
48 | fOut->writeText(fLineEnding); |
49 | fAtLineStart = true; |
50 | } |
51 | |
52 | void MetalCodeGenerator::write(const String& s) { |
53 | this->write(s.c_str()); |
54 | } |
55 | |
56 | void MetalCodeGenerator::writeLine(const String& s) { |
57 | this->writeLine(s.c_str()); |
58 | } |
59 | |
60 | void MetalCodeGenerator::writeLine() { |
61 | this->writeLine("" ); |
62 | } |
63 | |
64 | void MetalCodeGenerator::writeExtension(const Extension& ext) { |
65 | this->writeLine("#extension " + ext.fName + " : enable" ); |
66 | } |
67 | |
68 | String MetalCodeGenerator::typeName(const Type& type) { |
69 | switch (type.kind()) { |
70 | case Type::kVector_Kind: |
71 | return this->typeName(type.componentType()) + to_string(type.columns()); |
72 | case Type::kMatrix_Kind: |
73 | return this->typeName(type.componentType()) + to_string(type.columns()) + "x" + |
74 | to_string(type.rows()); |
75 | case Type::kSampler_Kind: |
76 | return "texture2d<float>" ; // FIXME - support other texture types; |
77 | default: |
78 | if (type == *fContext.fHalf_Type) { |
79 | // FIXME - Currently only supporting floats in MSL to avoid type coercion issues. |
80 | return fContext.fFloat_Type->name(); |
81 | } else if (type == *fContext.fByte_Type) { |
82 | return "char" ; |
83 | } else if (type == *fContext.fUByte_Type) { |
84 | return "uchar" ; |
85 | } else { |
86 | return type.name(); |
87 | } |
88 | } |
89 | } |
90 | |
91 | void MetalCodeGenerator::writeType(const Type& type) { |
92 | if (type.kind() == Type::kStruct_Kind) { |
93 | for (const Type* search : fWrittenStructs) { |
94 | if (*search == type) { |
95 | // already written |
96 | this->write(type.name()); |
97 | return; |
98 | } |
99 | } |
100 | fWrittenStructs.push_back(&type); |
101 | this->writeLine("struct " + type.name() + " {" ); |
102 | fIndentation++; |
103 | this->writeFields(type.fields(), type.fOffset); |
104 | fIndentation--; |
105 | this->write("}" ); |
106 | } else { |
107 | this->write(this->typeName(type)); |
108 | } |
109 | } |
110 | |
111 | void MetalCodeGenerator::writeExpression(const Expression& expr, Precedence parentPrecedence) { |
112 | switch (expr.fKind) { |
113 | case Expression::kBinary_Kind: |
114 | this->writeBinaryExpression((BinaryExpression&) expr, parentPrecedence); |
115 | break; |
116 | case Expression::kBoolLiteral_Kind: |
117 | this->writeBoolLiteral((BoolLiteral&) expr); |
118 | break; |
119 | case Expression::kConstructor_Kind: |
120 | this->writeConstructor((Constructor&) expr, parentPrecedence); |
121 | break; |
122 | case Expression::kIntLiteral_Kind: |
123 | this->writeIntLiteral((IntLiteral&) expr); |
124 | break; |
125 | case Expression::kFieldAccess_Kind: |
126 | this->writeFieldAccess(((FieldAccess&) expr)); |
127 | break; |
128 | case Expression::kFloatLiteral_Kind: |
129 | this->writeFloatLiteral(((FloatLiteral&) expr)); |
130 | break; |
131 | case Expression::kFunctionCall_Kind: |
132 | this->writeFunctionCall((FunctionCall&) expr); |
133 | break; |
134 | case Expression::kPrefix_Kind: |
135 | this->writePrefixExpression((PrefixExpression&) expr, parentPrecedence); |
136 | break; |
137 | case Expression::kPostfix_Kind: |
138 | this->writePostfixExpression((PostfixExpression&) expr, parentPrecedence); |
139 | break; |
140 | case Expression::kSetting_Kind: |
141 | this->writeSetting((Setting&) expr); |
142 | break; |
143 | case Expression::kSwizzle_Kind: |
144 | this->writeSwizzle((Swizzle&) expr); |
145 | break; |
146 | case Expression::kVariableReference_Kind: |
147 | this->writeVariableReference((VariableReference&) expr); |
148 | break; |
149 | case Expression::kTernary_Kind: |
150 | this->writeTernaryExpression((TernaryExpression&) expr, parentPrecedence); |
151 | break; |
152 | case Expression::kIndex_Kind: |
153 | this->writeIndexExpression((IndexExpression&) expr); |
154 | break; |
155 | default: |
156 | #ifdef SK_DEBUG |
157 | ABORT("unsupported expression: %s" , expr.description().c_str()); |
158 | #endif |
159 | break; |
160 | } |
161 | } |
162 | |
163 | void MetalCodeGenerator::writeIntrinsicCall(const FunctionCall& c) { |
164 | auto i = fIntrinsicMap.find(c.fFunction.fName); |
165 | SkASSERT(i != fIntrinsicMap.end()); |
166 | Intrinsic intrinsic = i->second; |
167 | int32_t intrinsicId = intrinsic.second; |
168 | switch (intrinsic.first) { |
169 | case kSpecial_IntrinsicKind: |
170 | return this->writeSpecialIntrinsic(c, (SpecialIntrinsic) intrinsicId); |
171 | break; |
172 | case kMetal_IntrinsicKind: |
173 | this->writeExpression(*c.fArguments[0], kSequence_Precedence); |
174 | switch ((MetalIntrinsic) intrinsicId) { |
175 | case kEqual_MetalIntrinsic: |
176 | this->write(" == " ); |
177 | break; |
178 | case kNotEqual_MetalIntrinsic: |
179 | this->write(" != " ); |
180 | break; |
181 | case kLessThan_MetalIntrinsic: |
182 | this->write(" < " ); |
183 | break; |
184 | case kLessThanEqual_MetalIntrinsic: |
185 | this->write(" <= " ); |
186 | break; |
187 | case kGreaterThan_MetalIntrinsic: |
188 | this->write(" > " ); |
189 | break; |
190 | case kGreaterThanEqual_MetalIntrinsic: |
191 | this->write(" >= " ); |
192 | break; |
193 | default: |
194 | ABORT("unsupported metal intrinsic kind" ); |
195 | } |
196 | this->writeExpression(*c.fArguments[1], kSequence_Precedence); |
197 | break; |
198 | default: |
199 | ABORT("unsupported intrinsic kind" ); |
200 | } |
201 | } |
202 | |
203 | void MetalCodeGenerator::writeFunctionCall(const FunctionCall& c) { |
204 | const auto& entry = fIntrinsicMap.find(c.fFunction.fName); |
205 | if (entry != fIntrinsicMap.end()) { |
206 | this->writeIntrinsicCall(c); |
207 | return; |
208 | } |
209 | if (c.fFunction.fBuiltin && "atan" == c.fFunction.fName && 2 == c.fArguments.size()) { |
210 | this->write("atan2" ); |
211 | } else if (c.fFunction.fBuiltin && "inversesqrt" == c.fFunction.fName) { |
212 | this->write("rsqrt" ); |
213 | } else if (c.fFunction.fBuiltin && "inverse" == c.fFunction.fName) { |
214 | SkASSERT(c.fArguments.size() == 1); |
215 | this->writeInverseHack(*c.fArguments[0]); |
216 | } else if (c.fFunction.fBuiltin && "dFdx" == c.fFunction.fName) { |
217 | this->write("dfdx" ); |
218 | } else if (c.fFunction.fBuiltin && "dFdy" == c.fFunction.fName) { |
219 | // Flipping Y also negates the Y derivatives. |
220 | this->write((fProgram.fSettings.fFlipY) ? "-dfdy" : "dfdy" ); |
221 | } else { |
222 | this->writeName(c.fFunction.fName); |
223 | } |
224 | this->write("(" ); |
225 | const char* separator = "" ; |
226 | if (this->requirements(c.fFunction) & kInputs_Requirement) { |
227 | this->write("_in" ); |
228 | separator = ", " ; |
229 | } |
230 | if (this->requirements(c.fFunction) & kOutputs_Requirement) { |
231 | this->write(separator); |
232 | this->write("_out" ); |
233 | separator = ", " ; |
234 | } |
235 | if (this->requirements(c.fFunction) & kUniforms_Requirement) { |
236 | this->write(separator); |
237 | this->write("_uniforms" ); |
238 | separator = ", " ; |
239 | } |
240 | if (this->requirements(c.fFunction) & kGlobals_Requirement) { |
241 | this->write(separator); |
242 | this->write("_globals" ); |
243 | separator = ", " ; |
244 | } |
245 | if (this->requirements(c.fFunction) & kFragCoord_Requirement) { |
246 | this->write(separator); |
247 | this->write("_fragCoord" ); |
248 | separator = ", " ; |
249 | } |
250 | for (size_t i = 0; i < c.fArguments.size(); ++i) { |
251 | const Expression& arg = *c.fArguments[i]; |
252 | this->write(separator); |
253 | separator = ", " ; |
254 | if (c.fFunction.fParameters[i]->fModifiers.fFlags & Modifiers::kOut_Flag) { |
255 | this->write("&" ); |
256 | } |
257 | this->writeExpression(arg, kSequence_Precedence); |
258 | } |
259 | this->write(")" ); |
260 | } |
261 | |
262 | void MetalCodeGenerator::writeInverseHack(const Expression& mat) { |
263 | String typeName = mat.fType.name(); |
264 | String name = typeName + "_inverse" ; |
265 | if (mat.fType == *fContext.fFloat2x2_Type || mat.fType == *fContext.fHalf2x2_Type) { |
266 | if (fWrittenIntrinsics.find(name) == fWrittenIntrinsics.end()) { |
267 | fWrittenIntrinsics.insert(name); |
268 | fExtraFunctions.writeText(( |
269 | typeName + " " + name + "(" + typeName + " m) {" |
270 | " return float2x2(m[1][1], -m[0][1], -m[1][0], m[0][0]) * (1/determinant(m));" |
271 | "}" |
272 | ).c_str()); |
273 | } |
274 | } |
275 | else if (mat.fType == *fContext.fFloat3x3_Type || mat.fType == *fContext.fHalf3x3_Type) { |
276 | if (fWrittenIntrinsics.find(name) == fWrittenIntrinsics.end()) { |
277 | fWrittenIntrinsics.insert(name); |
278 | fExtraFunctions.writeText(( |
279 | typeName + " " + name + "(" + typeName + " m) {" |
280 | " float a00 = m[0][0], a01 = m[0][1], a02 = m[0][2];" |
281 | " float a10 = m[1][0], a11 = m[1][1], a12 = m[1][2];" |
282 | " float a20 = m[2][0], a21 = m[2][1], a22 = m[2][2];" |
283 | " float b01 = a22 * a11 - a12 * a21;" |
284 | " float b11 = -a22 * a10 + a12 * a20;" |
285 | " float b21 = a21 * a10 - a11 * a20;" |
286 | " float det = a00 * b01 + a01 * b11 + a02 * b21;" |
287 | " return " + typeName + |
288 | " (b01, (-a22 * a01 + a02 * a21), (a12 * a01 - a02 * a11)," |
289 | " b11, (a22 * a00 - a02 * a20), (-a12 * a00 + a02 * a10)," |
290 | " b21, (-a21 * a00 + a01 * a20), (a11 * a00 - a01 * a10)) * " |
291 | " (1/det);" |
292 | "}" |
293 | ).c_str()); |
294 | } |
295 | } |
296 | else if (mat.fType == *fContext.fFloat4x4_Type || mat.fType == *fContext.fHalf4x4_Type) { |
297 | if (fWrittenIntrinsics.find(name) == fWrittenIntrinsics.end()) { |
298 | fWrittenIntrinsics.insert(name); |
299 | fExtraFunctions.writeText(( |
300 | typeName + " " + name + "(" + typeName + " m) {" |
301 | " float a00 = m[0][0], a01 = m[0][1], a02 = m[0][2], a03 = m[0][3];" |
302 | " float a10 = m[1][0], a11 = m[1][1], a12 = m[1][2], a13 = m[1][3];" |
303 | " float a20 = m[2][0], a21 = m[2][1], a22 = m[2][2], a23 = m[2][3];" |
304 | " float a30 = m[3][0], a31 = m[3][1], a32 = m[3][2], a33 = m[3][3];" |
305 | " float b00 = a00 * a11 - a01 * a10;" |
306 | " float b01 = a00 * a12 - a02 * a10;" |
307 | " float b02 = a00 * a13 - a03 * a10;" |
308 | " float b03 = a01 * a12 - a02 * a11;" |
309 | " float b04 = a01 * a13 - a03 * a11;" |
310 | " float b05 = a02 * a13 - a03 * a12;" |
311 | " float b06 = a20 * a31 - a21 * a30;" |
312 | " float b07 = a20 * a32 - a22 * a30;" |
313 | " float b08 = a20 * a33 - a23 * a30;" |
314 | " float b09 = a21 * a32 - a22 * a31;" |
315 | " float b10 = a21 * a33 - a23 * a31;" |
316 | " float b11 = a22 * a33 - a23 * a32;" |
317 | " float det = b00 * b11 - b01 * b10 + b02 * b09 + b03 * b08 - " |
318 | " b04 * b07 + b05 * b06;" |
319 | " return " + typeName + "(a11 * b11 - a12 * b10 + a13 * b09," |
320 | " a02 * b10 - a01 * b11 - a03 * b09," |
321 | " a31 * b05 - a32 * b04 + a33 * b03," |
322 | " a22 * b04 - a21 * b05 - a23 * b03," |
323 | " a12 * b08 - a10 * b11 - a13 * b07," |
324 | " a00 * b11 - a02 * b08 + a03 * b07," |
325 | " a32 * b02 - a30 * b05 - a33 * b01," |
326 | " a20 * b05 - a22 * b02 + a23 * b01," |
327 | " a10 * b10 - a11 * b08 + a13 * b06," |
328 | " a01 * b08 - a00 * b10 - a03 * b06," |
329 | " a30 * b04 - a31 * b02 + a33 * b00," |
330 | " a21 * b02 - a20 * b04 - a23 * b00," |
331 | " a11 * b07 - a10 * b09 - a12 * b06," |
332 | " a00 * b09 - a01 * b07 + a02 * b06," |
333 | " a31 * b01 - a30 * b03 - a32 * b00," |
334 | " a20 * b03 - a21 * b01 + a22 * b00) / det;" |
335 | "}" |
336 | ).c_str()); |
337 | } |
338 | } |
339 | this->write(name); |
340 | } |
341 | |
342 | void MetalCodeGenerator::writeSpecialIntrinsic(const FunctionCall & c, SpecialIntrinsic kind) { |
343 | switch (kind) { |
344 | case kTexture_SpecialIntrinsic: |
345 | this->writeExpression(*c.fArguments[0], kSequence_Precedence); |
346 | this->write(".sample(" ); |
347 | this->writeExpression(*c.fArguments[0], kSequence_Precedence); |
348 | this->write(SAMPLER_SUFFIX); |
349 | this->write(", " ); |
350 | if (c.fArguments[1]->fType == *fContext.fFloat3_Type) { |
351 | // have to store the vector in a temp variable to avoid double evaluating it |
352 | String tmpVar = "tmpCoord" + to_string(fVarCount++); |
353 | this->fFunctionHeader += " " + this->typeName(c.fArguments[1]->fType) + " " + |
354 | tmpVar + ";\n" ; |
355 | this->write("(" + tmpVar + " = " ); |
356 | this->writeExpression(*c.fArguments[1], kSequence_Precedence); |
357 | this->write(", " + tmpVar + ".xy / " + tmpVar + ".z))" ); |
358 | } else { |
359 | SkASSERT(c.fArguments[1]->fType == *fContext.fFloat2_Type); |
360 | this->writeExpression(*c.fArguments[1], kSequence_Precedence); |
361 | this->write(")" ); |
362 | } |
363 | break; |
364 | case kMod_SpecialIntrinsic: { |
365 | // fmod(x, y) in metal calculates x - y * trunc(x / y) instead of x - y * floor(x / y) |
366 | String tmpX = "tmpX" + to_string(fVarCount++); |
367 | String tmpY = "tmpY" + to_string(fVarCount++); |
368 | this->fFunctionHeader += " " + this->typeName(c.fArguments[0]->fType) + " " + tmpX + |
369 | ", " + tmpY + ";\n" ; |
370 | this->write("(" + tmpX + " = " ); |
371 | this->writeExpression(*c.fArguments[0], kSequence_Precedence); |
372 | this->write(", " + tmpY + " = " ); |
373 | this->writeExpression(*c.fArguments[1], kSequence_Precedence); |
374 | this->write(", " + tmpX + " - " + tmpY + " * floor(" + tmpX + " / " + tmpY + "))" ); |
375 | break; |
376 | } |
377 | default: |
378 | ABORT("unsupported special intrinsic kind" ); |
379 | } |
380 | } |
381 | |
382 | // If it hasn't already been written, writes a constructor for 'matrix' which takes a single value |
383 | // of type 'arg'. |
384 | String MetalCodeGenerator::getMatrixConstructHelper(const Type& matrix, const Type& arg) { |
385 | String key = matrix.name() + arg.name(); |
386 | auto found = fHelpers.find(key); |
387 | if (found != fHelpers.end()) { |
388 | return found->second; |
389 | } |
390 | String name; |
391 | int columns = matrix.columns(); |
392 | int rows = matrix.rows(); |
393 | if (arg.isNumber()) { |
394 | // creating a matrix from a single scalar value |
395 | name = "float" + to_string(columns) + "x" + to_string(rows) + "_from_float" ; |
396 | fExtraFunctions.printf("float%dx%d %s(float x) {\n" , |
397 | columns, rows, name.c_str()); |
398 | fExtraFunctions.printf(" return float%dx%d(" , columns, rows); |
399 | for (int i = 0; i < columns; ++i) { |
400 | if (i > 0) { |
401 | fExtraFunctions.writeText(", " ); |
402 | } |
403 | fExtraFunctions.printf("float%d(" , rows); |
404 | for (int j = 0; j < rows; ++j) { |
405 | if (j > 0) { |
406 | fExtraFunctions.writeText(", " ); |
407 | } |
408 | if (i == j) { |
409 | fExtraFunctions.writeText("x" ); |
410 | } else { |
411 | fExtraFunctions.writeText("0" ); |
412 | } |
413 | } |
414 | fExtraFunctions.writeText(")" ); |
415 | } |
416 | fExtraFunctions.writeText(");\n}\n" ); |
417 | } else if (arg.kind() == Type::kMatrix_Kind) { |
418 | // creating a matrix from another matrix |
419 | int argColumns = arg.columns(); |
420 | int argRows = arg.rows(); |
421 | name = "float" + to_string(columns) + "x" + to_string(rows) + "_from_float" + |
422 | to_string(argColumns) + "x" + to_string(argRows); |
423 | fExtraFunctions.printf("float%dx%d %s(float%dx%d m) {\n" , |
424 | columns, rows, name.c_str(), argColumns, argRows); |
425 | fExtraFunctions.printf(" return float%dx%d(" , columns, rows); |
426 | for (int i = 0; i < columns; ++i) { |
427 | if (i > 0) { |
428 | fExtraFunctions.writeText(", " ); |
429 | } |
430 | fExtraFunctions.printf("float%d(" , rows); |
431 | for (int j = 0; j < rows; ++j) { |
432 | if (j > 0) { |
433 | fExtraFunctions.writeText(", " ); |
434 | } |
435 | if (i < argColumns && j < argRows) { |
436 | fExtraFunctions.printf("m[%d][%d]" , i, j); |
437 | } else { |
438 | fExtraFunctions.writeText("0" ); |
439 | } |
440 | } |
441 | fExtraFunctions.writeText(")" ); |
442 | } |
443 | fExtraFunctions.writeText(");\n}\n" ); |
444 | } else if (matrix.rows() == 2 && matrix.columns() == 2 && arg == *fContext.fFloat4_Type) { |
445 | // float2x2(float4) doesn't work, need to split it into float2x2(float2, float2) |
446 | name = "float2x2_from_float4" ; |
447 | fExtraFunctions.printf( |
448 | "float2x2 %s(float4 v) {\n" |
449 | " return float2x2(float2(v[0], v[1]), float2(v[2], v[3]));\n" |
450 | "}\n" , |
451 | name.c_str() |
452 | ); |
453 | } else { |
454 | SkASSERT(false); |
455 | name = "<error>" ; |
456 | } |
457 | fHelpers[key] = name; |
458 | return name; |
459 | } |
460 | |
461 | bool MetalCodeGenerator::canCoerce(const Type& t1, const Type& t2) { |
462 | if (t1.columns() != t2.columns() || t1.rows() != t2.rows()) { |
463 | return false; |
464 | } |
465 | if (t1.columns() > 1) { |
466 | return this->canCoerce(t1.componentType(), t2.componentType()); |
467 | } |
468 | return t1.isFloat() && t2.isFloat(); |
469 | } |
470 | |
471 | void MetalCodeGenerator::writeConstructor(const Constructor& c, Precedence parentPrecedence) { |
472 | if (c.fArguments.size() == 1 && this->canCoerce(c.fType, c.fArguments[0]->fType)) { |
473 | this->writeExpression(*c.fArguments[0], parentPrecedence); |
474 | return; |
475 | } |
476 | if (c.fType.kind() == Type::kMatrix_Kind && c.fArguments.size() == 1) { |
477 | const Expression& arg = *c.fArguments[0]; |
478 | String name = this->getMatrixConstructHelper(c.fType, arg.fType); |
479 | this->write(name); |
480 | this->write("(" ); |
481 | this->writeExpression(arg, kSequence_Precedence); |
482 | this->write(")" ); |
483 | } else { |
484 | this->writeType(c.fType); |
485 | this->write("(" ); |
486 | const char* separator = "" ; |
487 | int scalarCount = 0; |
488 | for (const auto& arg : c.fArguments) { |
489 | this->write(separator); |
490 | separator = ", " ; |
491 | if (Type::kMatrix_Kind == c.fType.kind() && arg->fType.columns() != c.fType.rows()) { |
492 | // merge scalars and smaller vectors together |
493 | if (!scalarCount) { |
494 | this->writeType(c.fType.componentType()); |
495 | this->write(to_string(c.fType.rows())); |
496 | this->write("(" ); |
497 | } |
498 | scalarCount += arg->fType.columns(); |
499 | } |
500 | this->writeExpression(*arg, kSequence_Precedence); |
501 | if (scalarCount && scalarCount == c.fType.rows()) { |
502 | this->write(")" ); |
503 | scalarCount = 0; |
504 | } |
505 | } |
506 | this->write(")" ); |
507 | } |
508 | } |
509 | |
510 | void MetalCodeGenerator::writeFragCoord() { |
511 | if (fRTHeightName.length()) { |
512 | this->write("float4(_fragCoord.x, " ); |
513 | this->write(fRTHeightName.c_str()); |
514 | this->write(" - _fragCoord.y, 0.0, _fragCoord.w)" ); |
515 | } else { |
516 | this->write("float4(_fragCoord.x, _fragCoord.y, 0.0, _fragCoord.w)" ); |
517 | } |
518 | } |
519 | |
520 | void MetalCodeGenerator::writeVariableReference(const VariableReference& ref) { |
521 | switch (ref.fVariable.fModifiers.fLayout.fBuiltin) { |
522 | case SK_FRAGCOLOR_BUILTIN: |
523 | this->write("_out->sk_FragColor" ); |
524 | break; |
525 | case SK_FRAGCOORD_BUILTIN: |
526 | this->writeFragCoord(); |
527 | break; |
528 | case SK_VERTEXID_BUILTIN: |
529 | this->write("sk_VertexID" ); |
530 | break; |
531 | case SK_INSTANCEID_BUILTIN: |
532 | this->write("sk_InstanceID" ); |
533 | break; |
534 | case SK_CLOCKWISE_BUILTIN: |
535 | // We'd set the front facing winding in the MTLRenderCommandEncoder to be counter |
536 | // clockwise to match Skia convention. |
537 | this->write(fProgram.fSettings.fFlipY ? "_frontFacing" : "(!_frontFacing)" ); |
538 | break; |
539 | default: |
540 | if (Variable::kGlobal_Storage == ref.fVariable.fStorage) { |
541 | if (ref.fVariable.fModifiers.fFlags & Modifiers::kIn_Flag) { |
542 | this->write("_in." ); |
543 | } else if (ref.fVariable.fModifiers.fFlags & Modifiers::kOut_Flag) { |
544 | this->write("_out->" ); |
545 | } else if (ref.fVariable.fModifiers.fFlags & Modifiers::kUniform_Flag && |
546 | ref.fVariable.fType.kind() != Type::kSampler_Kind) { |
547 | this->write("_uniforms." ); |
548 | } else { |
549 | this->write("_globals->" ); |
550 | } |
551 | } |
552 | this->writeName(ref.fVariable.fName); |
553 | } |
554 | } |
555 | |
556 | void MetalCodeGenerator::writeIndexExpression(const IndexExpression& expr) { |
557 | this->writeExpression(*expr.fBase, kPostfix_Precedence); |
558 | this->write("[" ); |
559 | this->writeExpression(*expr.fIndex, kTopLevel_Precedence); |
560 | this->write("]" ); |
561 | } |
562 | |
563 | void MetalCodeGenerator::writeFieldAccess(const FieldAccess& f) { |
564 | const Type::Field* field = &f.fBase->fType.fields()[f.fFieldIndex]; |
565 | if (FieldAccess::kDefault_OwnerKind == f.fOwnerKind) { |
566 | this->writeExpression(*f.fBase, kPostfix_Precedence); |
567 | this->write("." ); |
568 | } |
569 | switch (field->fModifiers.fLayout.fBuiltin) { |
570 | case SK_CLIPDISTANCE_BUILTIN: |
571 | this->write("gl_ClipDistance" ); |
572 | break; |
573 | case SK_POSITION_BUILTIN: |
574 | this->write("_out->sk_Position" ); |
575 | break; |
576 | default: |
577 | if (field->fName == "sk_PointSize" ) { |
578 | this->write("_out->sk_PointSize" ); |
579 | } else { |
580 | if (FieldAccess::kAnonymousInterfaceBlock_OwnerKind == f.fOwnerKind) { |
581 | this->write("_globals->" ); |
582 | this->write(fInterfaceBlockNameMap[fInterfaceBlockMap[field]]); |
583 | this->write("->" ); |
584 | } |
585 | this->writeName(field->fName); |
586 | } |
587 | } |
588 | } |
589 | |
590 | void MetalCodeGenerator::writeSwizzle(const Swizzle& swizzle) { |
591 | int last = swizzle.fComponents.back(); |
592 | if (last == SKSL_SWIZZLE_0 || last == SKSL_SWIZZLE_1) { |
593 | this->writeType(swizzle.fType); |
594 | this->write("(" ); |
595 | } |
596 | this->writeExpression(*swizzle.fBase, kPostfix_Precedence); |
597 | this->write("." ); |
598 | for (int c : swizzle.fComponents) { |
599 | if (c >= 0) { |
600 | this->write(&("x\0y\0z\0w\0" [c * 2])); |
601 | } |
602 | } |
603 | if (last == SKSL_SWIZZLE_0) { |
604 | this->write(", 0)" ); |
605 | } |
606 | else if (last == SKSL_SWIZZLE_1) { |
607 | this->write(", 1)" ); |
608 | } |
609 | } |
610 | |
611 | MetalCodeGenerator::Precedence MetalCodeGenerator::GetBinaryPrecedence(Token::Kind op) { |
612 | switch (op) { |
613 | case Token::STAR: // fall through |
614 | case Token::SLASH: // fall through |
615 | case Token::PERCENT: return MetalCodeGenerator::kMultiplicative_Precedence; |
616 | case Token::PLUS: // fall through |
617 | case Token::MINUS: return MetalCodeGenerator::kAdditive_Precedence; |
618 | case Token::SHL: // fall through |
619 | case Token::SHR: return MetalCodeGenerator::kShift_Precedence; |
620 | case Token::LT: // fall through |
621 | case Token::GT: // fall through |
622 | case Token::LTEQ: // fall through |
623 | case Token::GTEQ: return MetalCodeGenerator::kRelational_Precedence; |
624 | case Token::EQEQ: // fall through |
625 | case Token::NEQ: return MetalCodeGenerator::kEquality_Precedence; |
626 | case Token::BITWISEAND: return MetalCodeGenerator::kBitwiseAnd_Precedence; |
627 | case Token::BITWISEXOR: return MetalCodeGenerator::kBitwiseXor_Precedence; |
628 | case Token::BITWISEOR: return MetalCodeGenerator::kBitwiseOr_Precedence; |
629 | case Token::LOGICALAND: return MetalCodeGenerator::kLogicalAnd_Precedence; |
630 | case Token::LOGICALXOR: return MetalCodeGenerator::kLogicalXor_Precedence; |
631 | case Token::LOGICALOR: return MetalCodeGenerator::kLogicalOr_Precedence; |
632 | case Token::EQ: // fall through |
633 | case Token::PLUSEQ: // fall through |
634 | case Token::MINUSEQ: // fall through |
635 | case Token::STAREQ: // fall through |
636 | case Token::SLASHEQ: // fall through |
637 | case Token::PERCENTEQ: // fall through |
638 | case Token::SHLEQ: // fall through |
639 | case Token::SHREQ: // fall through |
640 | case Token::LOGICALANDEQ: // fall through |
641 | case Token::LOGICALXOREQ: // fall through |
642 | case Token::LOGICALOREQ: // fall through |
643 | case Token::BITWISEANDEQ: // fall through |
644 | case Token::BITWISEXOREQ: // fall through |
645 | case Token::BITWISEOREQ: return MetalCodeGenerator::kAssignment_Precedence; |
646 | case Token::COMMA: return MetalCodeGenerator::kSequence_Precedence; |
647 | default: ABORT("unsupported binary operator" ); |
648 | } |
649 | } |
650 | |
651 | void MetalCodeGenerator::writeMatrixTimesEqualHelper(const Type& left, const Type& right, |
652 | const Type& result) { |
653 | String key = "TimesEqual" + left.name() + right.name(); |
654 | if (fHelpers.find(key) == fHelpers.end()) { |
655 | fExtraFunctions.printf("%s operator*=(thread %s& left, thread const %s& right) {\n" |
656 | " left = left * right;\n" |
657 | " return left;\n" |
658 | "}" , result.name().c_str(), left.name().c_str(), |
659 | right.name().c_str()); |
660 | } |
661 | } |
662 | |
663 | void MetalCodeGenerator::writeBinaryExpression(const BinaryExpression& b, |
664 | Precedence parentPrecedence) { |
665 | Precedence precedence = GetBinaryPrecedence(b.fOperator); |
666 | bool needParens = precedence >= parentPrecedence; |
667 | switch (b.fOperator) { |
668 | case Token::EQEQ: |
669 | if (b.fLeft->fType.kind() == Type::kVector_Kind) { |
670 | this->write("all" ); |
671 | needParens = true; |
672 | } |
673 | break; |
674 | case Token::NEQ: |
675 | if (b.fLeft->fType.kind() == Type::kVector_Kind) { |
676 | this->write("any" ); |
677 | needParens = true; |
678 | } |
679 | break; |
680 | default: |
681 | break; |
682 | } |
683 | if (needParens) { |
684 | this->write("(" ); |
685 | } |
686 | if (Compiler::IsAssignment(b.fOperator) && |
687 | Expression::kVariableReference_Kind == b.fLeft->fKind && |
688 | Variable::kParameter_Storage == ((VariableReference&) *b.fLeft).fVariable.fStorage && |
689 | (((VariableReference&) *b.fLeft).fVariable.fModifiers.fFlags & Modifiers::kOut_Flag)) { |
690 | // writing to an out parameter. Since we have to turn those into pointers, we have to |
691 | // dereference it here. |
692 | this->write("*" ); |
693 | } |
694 | if (b.fOperator == Token::STAREQ && b.fLeft->fType.kind() == Type::kMatrix_Kind && |
695 | b.fRight->fType.kind() == Type::kMatrix_Kind) { |
696 | this->writeMatrixTimesEqualHelper(b.fLeft->fType, b.fRight->fType, b.fType); |
697 | } |
698 | this->writeExpression(*b.fLeft, precedence); |
699 | if (b.fOperator != Token::EQ && Compiler::IsAssignment(b.fOperator) && |
700 | Expression::kSwizzle_Kind == b.fLeft->fKind && !b.fLeft->hasSideEffects()) { |
701 | // This doesn't compile in Metal: |
702 | // float4 x = float4(1); |
703 | // x.xy *= float2x2(...); |
704 | // with the error message "non-const reference cannot bind to vector element", |
705 | // but switching it to x.xy = x.xy * float2x2(...) fixes it. We perform this tranformation |
706 | // as long as the LHS has no side effects, and hope for the best otherwise. |
707 | this->write(" = " ); |
708 | this->writeExpression(*b.fLeft, kAssignment_Precedence); |
709 | this->write(" " ); |
710 | String op = Compiler::OperatorName(b.fOperator); |
711 | SkASSERT(op.endsWith("=" )); |
712 | this->write(op.substr(0, op.size() - 1).c_str()); |
713 | this->write(" " ); |
714 | } else { |
715 | this->write(String(" " ) + Compiler::OperatorName(b.fOperator) + " " ); |
716 | } |
717 | this->writeExpression(*b.fRight, precedence); |
718 | if (needParens) { |
719 | this->write(")" ); |
720 | } |
721 | } |
722 | |
723 | void MetalCodeGenerator::writeTernaryExpression(const TernaryExpression& t, |
724 | Precedence parentPrecedence) { |
725 | if (kTernary_Precedence >= parentPrecedence) { |
726 | this->write("(" ); |
727 | } |
728 | this->writeExpression(*t.fTest, kTernary_Precedence); |
729 | this->write(" ? " ); |
730 | this->writeExpression(*t.fIfTrue, kTernary_Precedence); |
731 | this->write(" : " ); |
732 | this->writeExpression(*t.fIfFalse, kTernary_Precedence); |
733 | if (kTernary_Precedence >= parentPrecedence) { |
734 | this->write(")" ); |
735 | } |
736 | } |
737 | |
738 | void MetalCodeGenerator::writePrefixExpression(const PrefixExpression& p, |
739 | Precedence parentPrecedence) { |
740 | if (kPrefix_Precedence >= parentPrecedence) { |
741 | this->write("(" ); |
742 | } |
743 | this->write(Compiler::OperatorName(p.fOperator)); |
744 | this->writeExpression(*p.fOperand, kPrefix_Precedence); |
745 | if (kPrefix_Precedence >= parentPrecedence) { |
746 | this->write(")" ); |
747 | } |
748 | } |
749 | |
750 | void MetalCodeGenerator::writePostfixExpression(const PostfixExpression& p, |
751 | Precedence parentPrecedence) { |
752 | if (kPostfix_Precedence >= parentPrecedence) { |
753 | this->write("(" ); |
754 | } |
755 | this->writeExpression(*p.fOperand, kPostfix_Precedence); |
756 | this->write(Compiler::OperatorName(p.fOperator)); |
757 | if (kPostfix_Precedence >= parentPrecedence) { |
758 | this->write(")" ); |
759 | } |
760 | } |
761 | |
762 | void MetalCodeGenerator::writeBoolLiteral(const BoolLiteral& b) { |
763 | this->write(b.fValue ? "true" : "false" ); |
764 | } |
765 | |
766 | void MetalCodeGenerator::writeIntLiteral(const IntLiteral& i) { |
767 | if (i.fType == *fContext.fUInt_Type) { |
768 | this->write(to_string(i.fValue & 0xffffffff) + "u" ); |
769 | } else { |
770 | this->write(to_string((int32_t) i.fValue)); |
771 | } |
772 | } |
773 | |
774 | void MetalCodeGenerator::writeFloatLiteral(const FloatLiteral& f) { |
775 | this->write(to_string(f.fValue)); |
776 | } |
777 | |
778 | void MetalCodeGenerator::writeSetting(const Setting& s) { |
779 | ABORT("internal error; setting was not folded to a constant during compilation\n" ); |
780 | } |
781 | |
782 | void MetalCodeGenerator::writeFunction(const FunctionDefinition& f) { |
783 | fRTHeightName = fProgram.fInputs.fRTHeight ? "_globals->_anonInterface0->u_skRTHeight" : "" ; |
784 | const char* separator = "" ; |
785 | if ("main" == f.fDeclaration.fName) { |
786 | switch (fProgram.fKind) { |
787 | case Program::kFragment_Kind: |
788 | this->write("fragment Outputs fragmentMain" ); |
789 | break; |
790 | case Program::kVertex_Kind: |
791 | this->write("vertex Outputs vertexMain" ); |
792 | break; |
793 | default: |
794 | SkASSERT(false); |
795 | } |
796 | this->write("(Inputs _in [[stage_in]]" ); |
797 | if (-1 != fUniformBuffer) { |
798 | this->write(", constant Uniforms& _uniforms [[buffer(" + |
799 | to_string(fUniformBuffer) + ")]]" ); |
800 | } |
801 | for (const auto& e : fProgram) { |
802 | if (ProgramElement::kVar_Kind == e.fKind) { |
803 | VarDeclarations& decls = (VarDeclarations&) e; |
804 | if (!decls.fVars.size()) { |
805 | continue; |
806 | } |
807 | for (const auto& stmt: decls.fVars) { |
808 | VarDeclaration& var = (VarDeclaration&) *stmt; |
809 | if (var.fVar->fType.kind() == Type::kSampler_Kind) { |
810 | this->write(", texture2d<float> " ); // FIXME - support other texture types |
811 | this->writeName(var.fVar->fName); |
812 | this->write("[[texture(" ); |
813 | this->write(to_string(var.fVar->fModifiers.fLayout.fBinding)); |
814 | this->write(")]]" ); |
815 | this->write(", sampler " ); |
816 | this->writeName(var.fVar->fName); |
817 | this->write(SAMPLER_SUFFIX); |
818 | this->write("[[sampler(" ); |
819 | this->write(to_string(var.fVar->fModifiers.fLayout.fBinding)); |
820 | this->write(")]]" ); |
821 | } |
822 | } |
823 | } else if (ProgramElement::kInterfaceBlock_Kind == e.fKind) { |
824 | InterfaceBlock& intf = (InterfaceBlock&) e; |
825 | if ("sk_PerVertex" == intf.fTypeName) { |
826 | continue; |
827 | } |
828 | this->write(", constant " ); |
829 | this->writeType(intf.fVariable.fType); |
830 | this->write("& " ); |
831 | this->write(fInterfaceBlockNameMap[&intf]); |
832 | this->write(" [[buffer(" ); |
833 | this->write(to_string(intf.fVariable.fModifiers.fLayout.fBinding)); |
834 | this->write(")]]" ); |
835 | } |
836 | } |
837 | if (fProgram.fKind == Program::kFragment_Kind) { |
838 | if (fProgram.fInputs.fRTHeight && fInterfaceBlockNameMap.empty()) { |
839 | this->write(", constant sksl_synthetic_uniforms& _anonInterface0 [[buffer(1)]]" ); |
840 | fRTHeightName = "_anonInterface0.u_skRTHeight" ; |
841 | } |
842 | this->write(", bool _frontFacing [[front_facing]]" ); |
843 | this->write(", float4 _fragCoord [[position]]" ); |
844 | } else if (fProgram.fKind == Program::kVertex_Kind) { |
845 | this->write(", uint sk_VertexID [[vertex_id]], uint sk_InstanceID [[instance_id]]" ); |
846 | } |
847 | separator = ", " ; |
848 | } else { |
849 | this->writeType(f.fDeclaration.fReturnType); |
850 | this->write(" " ); |
851 | this->writeName(f.fDeclaration.fName); |
852 | this->write("(" ); |
853 | Requirements requirements = this->requirements(f.fDeclaration); |
854 | if (requirements & kInputs_Requirement) { |
855 | this->write("Inputs _in" ); |
856 | separator = ", " ; |
857 | } |
858 | if (requirements & kOutputs_Requirement) { |
859 | this->write(separator); |
860 | this->write("thread Outputs* _out" ); |
861 | separator = ", " ; |
862 | } |
863 | if (requirements & kUniforms_Requirement) { |
864 | this->write(separator); |
865 | this->write("Uniforms _uniforms" ); |
866 | separator = ", " ; |
867 | } |
868 | if (requirements & kGlobals_Requirement) { |
869 | this->write(separator); |
870 | this->write("thread Globals* _globals" ); |
871 | separator = ", " ; |
872 | } |
873 | if (requirements & kFragCoord_Requirement) { |
874 | this->write(separator); |
875 | this->write("float4 _fragCoord" ); |
876 | separator = ", " ; |
877 | } |
878 | } |
879 | for (const auto& param : f.fDeclaration.fParameters) { |
880 | this->write(separator); |
881 | separator = ", " ; |
882 | this->writeModifiers(param->fModifiers, false); |
883 | std::vector<int> sizes; |
884 | const Type* type = ¶m->fType; |
885 | while (Type::kArray_Kind == type->kind()) { |
886 | sizes.push_back(type->columns()); |
887 | type = &type->componentType(); |
888 | } |
889 | this->writeType(*type); |
890 | if (param->fModifiers.fFlags & Modifiers::kOut_Flag) { |
891 | this->write("*" ); |
892 | } |
893 | this->write(" " ); |
894 | this->writeName(param->fName); |
895 | for (int s : sizes) { |
896 | if (s <= 0) { |
897 | this->write("[]" ); |
898 | } else { |
899 | this->write("[" + to_string(s) + "]" ); |
900 | } |
901 | } |
902 | } |
903 | this->writeLine(") {" ); |
904 | |
905 | SkASSERT(!fProgram.fSettings.fFragColorIsInOut); |
906 | |
907 | if ("main" == f.fDeclaration.fName) { |
908 | if (fNeedsGlobalStructInit) { |
909 | this->writeLine(" Globals globalStruct{" ); |
910 | const char* separator = "" ; |
911 | for (const auto& intf: fInterfaceBlockNameMap) { |
912 | const auto& intfName = intf.second; |
913 | this->write(separator); |
914 | separator = ", " ; |
915 | this->write("&" ); |
916 | this->writeName(intfName); |
917 | } |
918 | for (const auto& var: fInitNonConstGlobalVars) { |
919 | this->write(separator); |
920 | separator = ", " ; |
921 | this->writeVarInitializer(*var->fVar, *var->fValue); |
922 | } |
923 | for (const auto& texture: fTextures) { |
924 | this->write(separator); |
925 | separator = ", " ; |
926 | this->writeName(texture->fName); |
927 | this->write(separator); |
928 | this->writeName(texture->fName); |
929 | this->write(SAMPLER_SUFFIX); |
930 | } |
931 | this->writeLine("};" ); |
932 | this->writeLine(" thread Globals* _globals = &globalStruct;" ); |
933 | this->writeLine(" (void)_globals;" ); |
934 | } |
935 | this->writeLine(" Outputs _outputStruct;" ); |
936 | this->writeLine(" thread Outputs* _out = &_outputStruct;" ); |
937 | } |
938 | fFunctionHeader = "" ; |
939 | OutputStream* oldOut = fOut; |
940 | StringStream buffer; |
941 | fOut = &buffer; |
942 | fIndentation++; |
943 | this->writeStatements(((Block&) *f.fBody).fStatements); |
944 | if ("main" == f.fDeclaration.fName) { |
945 | switch (fProgram.fKind) { |
946 | case Program::kFragment_Kind: |
947 | this->writeLine("return *_out;" ); |
948 | break; |
949 | case Program::kVertex_Kind: |
950 | this->writeLine("_out->sk_Position.y = -_out->sk_Position.y;" ); |
951 | this->writeLine("return *_out;" ); // FIXME - detect if function already has return |
952 | break; |
953 | default: |
954 | SkASSERT(false); |
955 | } |
956 | } |
957 | fIndentation--; |
958 | this->writeLine("}" ); |
959 | |
960 | fOut = oldOut; |
961 | this->write(fFunctionHeader); |
962 | this->write(buffer.str()); |
963 | } |
964 | |
965 | void MetalCodeGenerator::writeModifiers(const Modifiers& modifiers, |
966 | bool globalContext) { |
967 | if (modifiers.fFlags & Modifiers::kOut_Flag) { |
968 | this->write("thread " ); |
969 | } |
970 | if (modifiers.fFlags & Modifiers::kConst_Flag) { |
971 | this->write("constant " ); |
972 | } |
973 | } |
974 | |
// Emits a Metal struct definition for an SkSL interface block, except the builtin
// sk_PerVertex block, which is skipped. It also records the name the block will be
// referenced by in fInterfaceBlockNameMap: the instance name if present, otherwise a
// generated "_anonInterfaceN".
void MetalCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf) {
    if ("sk_PerVertex" == intf.fTypeName) {
        return;
    }
    this->writeModifiers(intf.fVariable.fModifiers, true);
    this->write("struct ");
    this->writeLine(intf.fTypeName + " {");
    const Type* structType = &intf.fVariable.fType;
    // Note: the (possibly array) variable type is recorded before array wrappers are peeled.
    fWrittenStructs.push_back(structType);
    while (Type::kArray_Kind == structType->kind()) {
        structType = &structType->componentType();
    }
    fIndentation++;
    writeFields(structType->fields(), structType->fOffset, &intf);
    // sk_RTHeight is appended to the block when the program reads it.
    // NOTE(review): this happens for every non-sk_PerVertex interface block, while
    // writeFunction reads it via _anonInterface0. Confirm this is intended for programs with
    // several interface blocks.
    if (fProgram.fInputs.fRTHeight) {
        this->writeLine("float u_skRTHeight;");
    }
    fIndentation--;
    this->write("}");
    if (intf.fInstanceName.size()) {
        this->write(" ");
        this->write(intf.fInstanceName);
        for (const auto& size : intf.fSizes) {
            this->write("[");
            if (size) {
                this->writeExpression(*size, kTopLevel_Precedence);
            }
            this->write("]");
        }
        fInterfaceBlockNameMap[&intf] = intf.fInstanceName;
    } else {
        fInterfaceBlockNameMap[&intf] = "_anonInterface" + to_string(fAnonInterfaceCount++);
    }
    this->writeLine(";");
}
1010 | |
// Writes one member declaration per field, using the Metal memory layout to track the byte
// offset reached so far.
//
// For fields with an explicit layout(offset=...):
//   - a `char padN[...]` member is inserted to reach that offset;
//   - an error is reported if the offset is already behind the current position;
//   - an error is reported if the offset is not a multiple of the field's alignment.
//
// `parentOffset` is the source position used for error reporting. When `parentIntf` is
// non-null, each field is recorded in fInterfaceBlockMap as belonging to that block.
void MetalCodeGenerator::writeFields(const std::vector<Type::Field>& fields, int parentOffset,
                                     const InterfaceBlock* parentIntf) {
    MemoryLayout memoryLayout(MemoryLayout::kMetal_Standard);
    int currentOffset = 0;
    for (const auto& field: fields) {
        int fieldOffset = field.fModifiers.fLayout.fOffset;
        const Type* fieldType = field.fType;
        if (fieldOffset != -1) {
            if (currentOffset > fieldOffset) {
                fErrors.error(parentOffset,
                              "offset of field '" + field.fName + "' must be at least " +
                              to_string((int) currentOffset));
            } else if (currentOffset < fieldOffset) {
                // Explicit gap: fill it with an anonymous padding array.
                this->write("char pad");
                this->write(to_string(fPaddingCount++));
                this->write("[");
                this->write(to_string(fieldOffset - currentOffset));
                this->writeLine("];");
                currentOffset = fieldOffset;
            }
            int alignment = memoryLayout.alignment(*fieldType);
            if (fieldOffset % alignment) {
                fErrors.error(parentOffset,
                              "offset of field '" + field.fName + "' must be a multiple of " +
                              to_string((int) alignment));
            }
        }
        // Size is taken from the full (possibly array) type, before array wrappers are peeled.
        currentOffset += memoryLayout.size(*fieldType);
        std::vector<int> sizes;
        while (fieldType->kind() == Type::kArray_Kind) {
            sizes.push_back(fieldType->columns());
            fieldType = &fieldType->componentType();
        }
        this->writeModifiers(field.fModifiers, false);
        this->writeType(*fieldType);
        this->write(" ");
        this->writeName(field.fName);
        for (int s : sizes) {
            if (s <= 0) {
                this->write("[]");
            } else {
                this->write("[" + to_string(s) + "]");
            }
        }
        this->writeLine(";");
        if (parentIntf) {
            fInterfaceBlockMap[&field] = parentIntf;
        }
    }
}
1061 | |
1062 | void MetalCodeGenerator::writeVarInitializer(const Variable& var, const Expression& value) { |
1063 | this->writeExpression(value, kTopLevel_Precedence); |
1064 | } |
1065 | |
1066 | void MetalCodeGenerator::writeName(const String& name) { |
1067 | if (fReservedWords.find(name) != fReservedWords.end()) { |
1068 | this->write("_" ); // adding underscore before name to avoid conflict with reserved words |
1069 | } |
1070 | this->write(name); |
1071 | } |
1072 | |
1073 | void MetalCodeGenerator::writeVarDeclarations(const VarDeclarations& decl, bool global) { |
1074 | SkASSERT(decl.fVars.size() > 0); |
1075 | bool wroteType = false; |
1076 | for (const auto& stmt : decl.fVars) { |
1077 | VarDeclaration& var = (VarDeclaration&) *stmt; |
1078 | if (global && !(var.fVar->fModifiers.fFlags & Modifiers::kConst_Flag)) { |
1079 | continue; |
1080 | } |
1081 | if (wroteType) { |
1082 | this->write(", " ); |
1083 | } else { |
1084 | this->writeModifiers(var.fVar->fModifiers, global); |
1085 | this->writeType(decl.fBaseType); |
1086 | this->write(" " ); |
1087 | wroteType = true; |
1088 | } |
1089 | this->writeName(var.fVar->fName); |
1090 | for (const auto& size : var.fSizes) { |
1091 | this->write("[" ); |
1092 | if (size) { |
1093 | this->writeExpression(*size, kTopLevel_Precedence); |
1094 | } |
1095 | this->write("]" ); |
1096 | } |
1097 | if (var.fValue) { |
1098 | this->write(" = " ); |
1099 | this->writeVarInitializer(*var.fVar, *var.fValue); |
1100 | } |
1101 | } |
1102 | if (wroteType) { |
1103 | this->write(";" ); |
1104 | } |
1105 | } |
1106 | |
// Dispatches a single statement to its Metal writer. No trailing newline is written; callers
// such as writeStatements end the line. discard maps to Metal's discard_fragment(), and a nop
// becomes an empty statement. Unknown kinds abort in debug builds and emit nothing otherwise.
void MetalCodeGenerator::writeStatement(const Statement& s) {
    switch (s.fKind) {
        case Statement::kBlock_Kind:
            this->writeBlock((Block&) s);
            break;
        case Statement::kExpression_Kind:
            this->writeExpression(*((ExpressionStatement&) s).fExpression, kTopLevel_Precedence);
            this->write(";");
            break;
        case Statement::kReturn_Kind:
            this->writeReturnStatement((ReturnStatement&) s);
            break;
        case Statement::kVarDeclarations_Kind:
            // Local (non-global) declarations: every variable is written.
            this->writeVarDeclarations(*((VarDeclarationsStatement&) s).fDeclaration, false);
            break;
        case Statement::kIf_Kind:
            this->writeIfStatement((IfStatement&) s);
            break;
        case Statement::kFor_Kind:
            this->writeForStatement((ForStatement&) s);
            break;
        case Statement::kWhile_Kind:
            this->writeWhileStatement((WhileStatement&) s);
            break;
        case Statement::kDo_Kind:
            this->writeDoStatement((DoStatement&) s);
            break;
        case Statement::kSwitch_Kind:
            this->writeSwitchStatement((SwitchStatement&) s);
            break;
        case Statement::kBreak_Kind:
            this->write("break;");
            break;
        case Statement::kContinue_Kind:
            this->write("continue;");
            break;
        case Statement::kDiscard_Kind:
            this->write("discard_fragment();");
            break;
        case Statement::kNop_Kind:
            this->write(";");
            break;
        default:
#ifdef SK_DEBUG
            ABORT("unsupported statement: %s", s.description().c_str());
#endif
            break;
    }
}
1156 | |
1157 | void MetalCodeGenerator::writeStatements(const std::vector<std::unique_ptr<Statement>>& statements) { |
1158 | for (const auto& s : statements) { |
1159 | if (!s->isEmpty()) { |
1160 | this->writeStatement(*s); |
1161 | this->writeLine(); |
1162 | } |
1163 | } |
1164 | } |
1165 | |
1166 | void MetalCodeGenerator::writeBlock(const Block& b) { |
1167 | this->writeLine("{" ); |
1168 | fIndentation++; |
1169 | this->writeStatements(b.fStatements); |
1170 | fIndentation--; |
1171 | this->write("}" ); |
1172 | } |
1173 | |
1174 | void MetalCodeGenerator::writeIfStatement(const IfStatement& stmt) { |
1175 | this->write("if (" ); |
1176 | this->writeExpression(*stmt.fTest, kTopLevel_Precedence); |
1177 | this->write(") " ); |
1178 | this->writeStatement(*stmt.fIfTrue); |
1179 | if (stmt.fIfFalse) { |
1180 | this->write(" else " ); |
1181 | this->writeStatement(*stmt.fIfFalse); |
1182 | } |
1183 | } |
1184 | |
1185 | void MetalCodeGenerator::writeForStatement(const ForStatement& f) { |
1186 | this->write("for (" ); |
1187 | if (f.fInitializer && !f.fInitializer->isEmpty()) { |
1188 | this->writeStatement(*f.fInitializer); |
1189 | } else { |
1190 | this->write("; " ); |
1191 | } |
1192 | if (f.fTest) { |
1193 | this->writeExpression(*f.fTest, kTopLevel_Precedence); |
1194 | } |
1195 | this->write("; " ); |
1196 | if (f.fNext) { |
1197 | this->writeExpression(*f.fNext, kTopLevel_Precedence); |
1198 | } |
1199 | this->write(") " ); |
1200 | this->writeStatement(*f.fStatement); |
1201 | } |
1202 | |
1203 | void MetalCodeGenerator::writeWhileStatement(const WhileStatement& w) { |
1204 | this->write("while (" ); |
1205 | this->writeExpression(*w.fTest, kTopLevel_Precedence); |
1206 | this->write(") " ); |
1207 | this->writeStatement(*w.fStatement); |
1208 | } |
1209 | |
1210 | void MetalCodeGenerator::writeDoStatement(const DoStatement& d) { |
1211 | this->write("do " ); |
1212 | this->writeStatement(*d.fStatement); |
1213 | this->write(" while (" ); |
1214 | this->writeExpression(*d.fTest, kTopLevel_Precedence); |
1215 | this->write(");" ); |
1216 | } |
1217 | |
1218 | void MetalCodeGenerator::writeSwitchStatement(const SwitchStatement& s) { |
1219 | this->write("switch (" ); |
1220 | this->writeExpression(*s.fValue, kTopLevel_Precedence); |
1221 | this->writeLine(") {" ); |
1222 | fIndentation++; |
1223 | for (const auto& c : s.fCases) { |
1224 | if (c->fValue) { |
1225 | this->write("case " ); |
1226 | this->writeExpression(*c->fValue, kTopLevel_Precedence); |
1227 | this->writeLine(":" ); |
1228 | } else { |
1229 | this->writeLine("default:" ); |
1230 | } |
1231 | fIndentation++; |
1232 | for (const auto& stmt : c->fStatements) { |
1233 | this->writeStatement(*stmt); |
1234 | this->writeLine(); |
1235 | } |
1236 | fIndentation--; |
1237 | } |
1238 | fIndentation--; |
1239 | this->write("}" ); |
1240 | } |
1241 | |
1242 | void MetalCodeGenerator::writeReturnStatement(const ReturnStatement& r) { |
1243 | this->write("return" ); |
1244 | if (r.fExpression) { |
1245 | this->write(" " ); |
1246 | this->writeExpression(*r.fExpression, kTopLevel_Precedence); |
1247 | } |
1248 | this->write(";" ); |
1249 | } |
1250 | |
1251 | void MetalCodeGenerator::() { |
1252 | this->write("#include <metal_stdlib>\n" ); |
1253 | this->write("#include <simd/simd.h>\n" ); |
1254 | this->write("using namespace metal;\n" ); |
1255 | } |
1256 | |
1257 | void MetalCodeGenerator::writeUniformStruct() { |
1258 | for (const auto& e : fProgram) { |
1259 | if (ProgramElement::kVar_Kind == e.fKind) { |
1260 | VarDeclarations& decls = (VarDeclarations&) e; |
1261 | if (!decls.fVars.size()) { |
1262 | continue; |
1263 | } |
1264 | const Variable& first = *((VarDeclaration&) *decls.fVars[0]).fVar; |
1265 | if (first.fModifiers.fFlags & Modifiers::kUniform_Flag && |
1266 | first.fType.kind() != Type::kSampler_Kind) { |
1267 | if (-1 == fUniformBuffer) { |
1268 | this->write("struct Uniforms {\n" ); |
1269 | fUniformBuffer = first.fModifiers.fLayout.fSet; |
1270 | if (-1 == fUniformBuffer) { |
1271 | fErrors.error(decls.fOffset, "Metal uniforms must have 'layout(set=...)'" ); |
1272 | } |
1273 | } else if (first.fModifiers.fLayout.fSet != fUniformBuffer) { |
1274 | if (-1 == fUniformBuffer) { |
1275 | fErrors.error(decls.fOffset, "Metal backend requires all uniforms to have " |
1276 | "the same 'layout(set=...)'" ); |
1277 | } |
1278 | } |
1279 | this->write(" " ); |
1280 | this->writeType(first.fType); |
1281 | this->write(" " ); |
1282 | for (const auto& stmt : decls.fVars) { |
1283 | VarDeclaration& var = (VarDeclaration&) *stmt; |
1284 | this->writeName(var.fVar->fName); |
1285 | } |
1286 | this->write(";\n" ); |
1287 | } |
1288 | } |
1289 | } |
1290 | if (-1 != fUniformBuffer) { |
1291 | this->write("};\n" ); |
1292 | } |
1293 | } |
1294 | |
1295 | void MetalCodeGenerator::writeInputStruct() { |
1296 | this->write("struct Inputs {\n" ); |
1297 | for (const auto& e : fProgram) { |
1298 | if (ProgramElement::kVar_Kind == e.fKind) { |
1299 | VarDeclarations& decls = (VarDeclarations&) e; |
1300 | if (!decls.fVars.size()) { |
1301 | continue; |
1302 | } |
1303 | const Variable& first = *((VarDeclaration&) *decls.fVars[0]).fVar; |
1304 | if (first.fModifiers.fFlags & Modifiers::kIn_Flag && |
1305 | -1 == first.fModifiers.fLayout.fBuiltin) { |
1306 | this->write(" " ); |
1307 | this->writeType(first.fType); |
1308 | this->write(" " ); |
1309 | for (const auto& stmt : decls.fVars) { |
1310 | VarDeclaration& var = (VarDeclaration&) *stmt; |
1311 | this->writeName(var.fVar->fName); |
1312 | if (-1 != var.fVar->fModifiers.fLayout.fLocation) { |
1313 | if (fProgram.fKind == Program::kVertex_Kind) { |
1314 | this->write(" [[attribute(" + |
1315 | to_string(var.fVar->fModifiers.fLayout.fLocation) + ")]]" ); |
1316 | } else if (fProgram.fKind == Program::kFragment_Kind) { |
1317 | this->write(" [[user(locn" + |
1318 | to_string(var.fVar->fModifiers.fLayout.fLocation) + ")]]" ); |
1319 | } |
1320 | } |
1321 | } |
1322 | this->write(";\n" ); |
1323 | } |
1324 | } |
1325 | } |
1326 | this->write("};\n" ); |
1327 | } |
1328 | |
1329 | void MetalCodeGenerator::writeOutputStruct() { |
1330 | this->write("struct Outputs {\n" ); |
1331 | if (fProgram.fKind == Program::kVertex_Kind) { |
1332 | this->write(" float4 sk_Position [[position]];\n" ); |
1333 | } else if (fProgram.fKind == Program::kFragment_Kind) { |
1334 | this->write(" float4 sk_FragColor [[color(0)]];\n" ); |
1335 | } |
1336 | for (const auto& e : fProgram) { |
1337 | if (ProgramElement::kVar_Kind == e.fKind) { |
1338 | VarDeclarations& decls = (VarDeclarations&) e; |
1339 | if (!decls.fVars.size()) { |
1340 | continue; |
1341 | } |
1342 | const Variable& first = *((VarDeclaration&) *decls.fVars[0]).fVar; |
1343 | if (first.fModifiers.fFlags & Modifiers::kOut_Flag && |
1344 | -1 == first.fModifiers.fLayout.fBuiltin) { |
1345 | this->write(" " ); |
1346 | this->writeType(first.fType); |
1347 | this->write(" " ); |
1348 | for (const auto& stmt : decls.fVars) { |
1349 | VarDeclaration& var = (VarDeclaration&) *stmt; |
1350 | this->writeName(var.fVar->fName); |
1351 | if (fProgram.fKind == Program::kVertex_Kind) { |
1352 | this->write(" [[user(locn" + |
1353 | to_string(var.fVar->fModifiers.fLayout.fLocation) + ")]]" ); |
1354 | } else if (fProgram.fKind == Program::kFragment_Kind) { |
1355 | this->write(" [[color(" + |
1356 | to_string(var.fVar->fModifiers.fLayout.fLocation) +")" ); |
1357 | int colorIndex = var.fVar->fModifiers.fLayout.fIndex; |
1358 | if (colorIndex) { |
1359 | this->write(", index(" + to_string(colorIndex) + ")" ); |
1360 | } |
1361 | this->write("]]" ); |
1362 | } |
1363 | } |
1364 | this->write(";\n" ); |
1365 | } |
1366 | } |
1367 | } |
1368 | if (fProgram.fKind == Program::kVertex_Kind) { |
1369 | this->write(" float sk_PointSize;\n" ); |
1370 | } |
1371 | this->write("};\n" ); |
1372 | } |
1373 | |
1374 | void MetalCodeGenerator::writeInterfaceBlocks() { |
1375 | bool wroteInterfaceBlock = false; |
1376 | for (const auto& e : fProgram) { |
1377 | if (ProgramElement::kInterfaceBlock_Kind == e.fKind) { |
1378 | this->writeInterfaceBlock((InterfaceBlock&) e); |
1379 | wroteInterfaceBlock = true; |
1380 | } |
1381 | } |
1382 | if (!wroteInterfaceBlock && fProgram.fInputs.fRTHeight) { |
1383 | this->writeLine("struct sksl_synthetic_uniforms {" ); |
1384 | this->writeLine(" float u_skRTHeight;" ); |
1385 | this->writeLine("};" ); |
1386 | } |
1387 | } |
1388 | |
1389 | void MetalCodeGenerator::writeGlobalStruct() { |
1390 | bool wroteStructDecl = false; |
1391 | for (const auto& intf : fInterfaceBlockNameMap) { |
1392 | if (!wroteStructDecl) { |
1393 | this->write("struct Globals {\n" ); |
1394 | wroteStructDecl = true; |
1395 | } |
1396 | fNeedsGlobalStructInit = true; |
1397 | const auto& intfType = intf.first; |
1398 | const auto& intfName = intf.second; |
1399 | this->write(" constant " ); |
1400 | this->write(intfType->fTypeName); |
1401 | this->write("* " ); |
1402 | this->writeName(intfName); |
1403 | this->write(";\n" ); |
1404 | } |
1405 | for (const auto& e : fProgram) { |
1406 | if (ProgramElement::kVar_Kind == e.fKind) { |
1407 | VarDeclarations& decls = (VarDeclarations&) e; |
1408 | if (!decls.fVars.size()) { |
1409 | continue; |
1410 | } |
1411 | const Variable& first = *((VarDeclaration&) *decls.fVars[0]).fVar; |
1412 | if ((!first.fModifiers.fFlags && -1 == first.fModifiers.fLayout.fBuiltin) || |
1413 | first.fType.kind() == Type::kSampler_Kind) { |
1414 | if (!wroteStructDecl) { |
1415 | this->write("struct Globals {\n" ); |
1416 | wroteStructDecl = true; |
1417 | } |
1418 | fNeedsGlobalStructInit = true; |
1419 | this->write(" " ); |
1420 | this->writeType(first.fType); |
1421 | this->write(" " ); |
1422 | for (const auto& stmt : decls.fVars) { |
1423 | VarDeclaration& var = (VarDeclaration&) *stmt; |
1424 | this->writeName(var.fVar->fName); |
1425 | if (var.fVar->fType.kind() == Type::kSampler_Kind) { |
1426 | fTextures.push_back(var.fVar); |
1427 | this->write(";\n" ); |
1428 | this->write(" sampler " ); |
1429 | this->writeName(var.fVar->fName); |
1430 | this->write(SAMPLER_SUFFIX); |
1431 | } |
1432 | if (var.fValue) { |
1433 | fInitNonConstGlobalVars.push_back(&var); |
1434 | } |
1435 | } |
1436 | this->write(";\n" ); |
1437 | } |
1438 | } |
1439 | } |
1440 | if (wroteStructDecl) { |
1441 | this->write("};\n" ); |
1442 | } |
1443 | } |
1444 | |
1445 | void MetalCodeGenerator::writeProgramElement(const ProgramElement& e) { |
1446 | switch (e.fKind) { |
1447 | case ProgramElement::kExtension_Kind: |
1448 | break; |
1449 | case ProgramElement::kVar_Kind: { |
1450 | VarDeclarations& decl = (VarDeclarations&) e; |
1451 | if (decl.fVars.size() > 0) { |
1452 | int builtin = ((VarDeclaration&) *decl.fVars[0]).fVar->fModifiers.fLayout.fBuiltin; |
1453 | if (-1 == builtin) { |
1454 | // normal var |
1455 | this->writeVarDeclarations(decl, true); |
1456 | this->writeLine(); |
1457 | } else if (SK_FRAGCOLOR_BUILTIN == builtin) { |
1458 | // ignore |
1459 | } |
1460 | } |
1461 | break; |
1462 | } |
1463 | case ProgramElement::kInterfaceBlock_Kind: |
1464 | // handled in writeInterfaceBlocks, do nothing |
1465 | break; |
1466 | case ProgramElement::kFunction_Kind: |
1467 | this->writeFunction((FunctionDefinition&) e); |
1468 | break; |
1469 | case ProgramElement::kModifiers_Kind: |
1470 | this->writeModifiers(((ModifiersDeclaration&) e).fModifiers, true); |
1471 | this->writeLine(";" ); |
1472 | break; |
1473 | default: |
1474 | #ifdef SK_DEBUG |
1475 | ABORT("unsupported program element: %s\n" , e.description().c_str()); |
1476 | #endif |
1477 | break; |
1478 | } |
1479 | } |
1480 | |
// Computes which hidden parameters (_in, _out, _uniforms, _globals, _fragCoord) the code
// emitted for expression `e` will reference. It recurses through subexpressions and, for
// calls, through the callee's requirements.
MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Expression& e) {
    switch (e.fKind) {
        case Expression::kFunctionCall_Kind: {
            // A call needs whatever the callee needs plus whatever its arguments need.
            const FunctionCall& f = (const FunctionCall&) e;
            Requirements result = this->requirements(f.fFunction);
            for (const auto& e : f.fArguments) {
                result |= this->requirements(*e);
            }
            return result;
        }
        case Expression::kConstructor_Kind: {
            const Constructor& c = (const Constructor&) e;
            Requirements result = kNo_Requirements;
            for (const auto& e : c.fArguments) {
                result |= this->requirements(*e);
            }
            return result;
        }
        case Expression::kFieldAccess_Kind: {
            // Fields of anonymous interface blocks are reached through the Globals struct.
            const FieldAccess& f = (const FieldAccess&) e;
            if (FieldAccess::kAnonymousInterfaceBlock_OwnerKind == f.fOwnerKind) {
                return kGlobals_Requirement;
            }
            return this->requirements(*((const FieldAccess&) e).fBase);
        }
        case Expression::kSwizzle_Kind:
            return this->requirements(*((const Swizzle&) e).fBase);
        case Expression::kBinary_Kind: {
            const BinaryExpression& b = (const BinaryExpression&) e;
            return this->requirements(*b.fLeft) | this->requirements(*b.fRight);
        }
        case Expression::kIndex_Kind: {
            const IndexExpression& idx = (const IndexExpression&) e;
            return this->requirements(*idx.fBase) | this->requirements(*idx.fIndex);
        }
        case Expression::kPrefix_Kind:
            return this->requirements(*((const PrefixExpression&) e).fOperand);
        case Expression::kPostfix_Kind:
            return this->requirements(*((const PostfixExpression&) e).fOperand);
        case Expression::kTernary_Kind: {
            const TernaryExpression& t = (const TernaryExpression&) e;
            return this->requirements(*t.fTest) | this->requirements(*t.fIfTrue) |
                   this->requirements(*t.fIfFalse);
        }
        case Expression::kVariableReference_Kind: {
            // sk_FragCoord needs both _globals and _fragCoord.
            // Other globals map by storage class: in -> _in, out -> _out,
            // non-sampler uniform -> _uniforms, anything else (including samplers) -> _globals.
            // Non-global variables need nothing.
            const VariableReference& v = (const VariableReference&) e;
            Requirements result = kNo_Requirements;
            if (v.fVariable.fModifiers.fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN) {
                result = kGlobals_Requirement | kFragCoord_Requirement;
            } else if (Variable::kGlobal_Storage == v.fVariable.fStorage) {
                if (v.fVariable.fModifiers.fFlags & Modifiers::kIn_Flag) {
                    result = kInputs_Requirement;
                } else if (v.fVariable.fModifiers.fFlags & Modifiers::kOut_Flag) {
                    result = kOutputs_Requirement;
                } else if (v.fVariable.fModifiers.fFlags & Modifiers::kUniform_Flag &&
                           v.fVariable.fType.kind() != Type::kSampler_Kind) {
                    result = kUniforms_Requirement;
                } else {
                    result = kGlobals_Requirement;
                }
            }
            return result;
        }
        default:
            return kNo_Requirements;
    }
}
1548 | |
1549 | MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Statement& s) { |
1550 | switch (s.fKind) { |
1551 | case Statement::kBlock_Kind: { |
1552 | Requirements result = kNo_Requirements; |
1553 | for (const auto& child : ((const Block&) s).fStatements) { |
1554 | result |= this->requirements(*child); |
1555 | } |
1556 | return result; |
1557 | } |
1558 | case Statement::kVarDeclaration_Kind: { |
1559 | Requirements result = kNo_Requirements; |
1560 | const VarDeclaration& var = (const VarDeclaration&) s; |
1561 | if (var.fValue) { |
1562 | result = this->requirements(*var.fValue); |
1563 | } |
1564 | return result; |
1565 | } |
1566 | case Statement::kVarDeclarations_Kind: { |
1567 | Requirements result = kNo_Requirements; |
1568 | const VarDeclarations& decls = *((const VarDeclarationsStatement&) s).fDeclaration; |
1569 | for (const auto& stmt : decls.fVars) { |
1570 | result |= this->requirements(*stmt); |
1571 | } |
1572 | return result; |
1573 | } |
1574 | case Statement::kExpression_Kind: |
1575 | return this->requirements(*((const ExpressionStatement&) s).fExpression); |
1576 | case Statement::kReturn_Kind: { |
1577 | const ReturnStatement& r = (const ReturnStatement&) s; |
1578 | if (r.fExpression) { |
1579 | return this->requirements(*r.fExpression); |
1580 | } |
1581 | return kNo_Requirements; |
1582 | } |
1583 | case Statement::kIf_Kind: { |
1584 | const IfStatement& i = (const IfStatement&) s; |
1585 | return this->requirements(*i.fTest) | |
1586 | this->requirements(*i.fIfTrue) | |
1587 | (i.fIfFalse ? this->requirements(*i.fIfFalse) : 0); |
1588 | } |
1589 | case Statement::kFor_Kind: { |
1590 | const ForStatement& f = (const ForStatement&) s; |
1591 | return this->requirements(*f.fInitializer) | |
1592 | this->requirements(*f.fTest) | |
1593 | this->requirements(*f.fNext) | |
1594 | this->requirements(*f.fStatement); |
1595 | } |
1596 | case Statement::kWhile_Kind: { |
1597 | const WhileStatement& w = (const WhileStatement&) s; |
1598 | return this->requirements(*w.fTest) | |
1599 | this->requirements(*w.fStatement); |
1600 | } |
1601 | case Statement::kDo_Kind: { |
1602 | const DoStatement& d = (const DoStatement&) s; |
1603 | return this->requirements(*d.fTest) | |
1604 | this->requirements(*d.fStatement); |
1605 | } |
1606 | case Statement::kSwitch_Kind: { |
1607 | const SwitchStatement& sw = (const SwitchStatement&) s; |
1608 | Requirements result = this->requirements(*sw.fValue); |
1609 | for (const auto& c : sw.fCases) { |
1610 | for (const auto& st : c->fStatements) { |
1611 | result |= this->requirements(*st); |
1612 | } |
1613 | } |
1614 | return result; |
1615 | } |
1616 | default: |
1617 | return kNo_Requirements; |
1618 | } |
1619 | } |
1620 | |
1621 | MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const FunctionDeclaration& f) { |
1622 | if (f.fBuiltin) { |
1623 | return kNo_Requirements; |
1624 | } |
1625 | auto found = fRequirements.find(&f); |
1626 | if (found == fRequirements.end()) { |
1627 | fRequirements[&f] = kNo_Requirements; |
1628 | for (const auto& e : fProgram) { |
1629 | if (ProgramElement::kFunction_Kind == e.fKind) { |
1630 | const FunctionDefinition& def = (const FunctionDefinition&) e; |
1631 | if (&def.fDeclaration == &f) { |
1632 | Requirements reqs = this->requirements(*def.fBody); |
1633 | fRequirements[&f] = reqs; |
1634 | return reqs; |
1635 | } |
1636 | } |
1637 | } |
1638 | } |
1639 | return found->second; |
1640 | } |
1641 | |
1642 | bool MetalCodeGenerator::generateCode() { |
1643 | OutputStream* rawOut = fOut; |
1644 | fOut = &fHeader; |
1645 | fProgramKind = fProgram.fKind; |
1646 | this->writeHeader(); |
1647 | this->writeUniformStruct(); |
1648 | this->writeInputStruct(); |
1649 | this->writeOutputStruct(); |
1650 | this->writeInterfaceBlocks(); |
1651 | this->writeGlobalStruct(); |
1652 | StringStream body; |
1653 | fOut = &body; |
1654 | for (const auto& e : fProgram) { |
1655 | this->writeProgramElement(e); |
1656 | } |
1657 | fOut = rawOut; |
1658 | |
1659 | write_stringstream(fHeader, *rawOut); |
1660 | write_stringstream(fExtraFunctions, *rawOut); |
1661 | write_stringstream(body, *rawOut); |
1662 | return true; |
1663 | } |
1664 | |
1665 | } |
1666 | |