Update rive-cpp to 2.0 version
[platform/core/uifw/rive-tizen.git] / submodule / skia / src / sksl / codegen / SkSLMetalCodeGenerator.cpp
1 /*
2  * Copyright 2016 Google Inc.
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7
8 #include "src/sksl/codegen/SkSLMetalCodeGenerator.h"
9
10 #include "include/core/SkSpan.h"
11 #include "include/core/SkTypes.h"
12 #include "include/private/SkSLLayout.h"
13 #include "include/private/SkSLModifiers.h"
14 #include "include/private/SkSLProgramElement.h"
15 #include "include/private/SkSLStatement.h"
16 #include "include/private/SkSLString.h"
17 #include "include/sksl/SkSLErrorReporter.h"
18 #include "include/sksl/SkSLPosition.h"
19 #include "src/core/SkScopeExit.h"
20 #include "src/sksl/SkSLAnalysis.h"
21 #include "src/sksl/SkSLBuiltinTypes.h"
22 #include "src/sksl/SkSLCompiler.h"
23 #include "src/sksl/SkSLContext.h"
24 #include "src/sksl/SkSLMemoryLayout.h"
25 #include "src/sksl/SkSLOutputStream.h"
26 #include "src/sksl/SkSLProgramSettings.h"
27 #include "src/sksl/SkSLUtil.h"
28 #include "src/sksl/ir/SkSLBinaryExpression.h"
29 #include "src/sksl/ir/SkSLBlock.h"
30 #include "src/sksl/ir/SkSLConstructor.h"
31 #include "src/sksl/ir/SkSLConstructorArrayCast.h"
32 #include "src/sksl/ir/SkSLConstructorCompound.h"
33 #include "src/sksl/ir/SkSLConstructorMatrixResize.h"
34 #include "src/sksl/ir/SkSLDoStatement.h"
35 #include "src/sksl/ir/SkSLExpression.h"
36 #include "src/sksl/ir/SkSLExpressionStatement.h"
37 #include "src/sksl/ir/SkSLExtension.h"
38 #include "src/sksl/ir/SkSLFieldAccess.h"
39 #include "src/sksl/ir/SkSLForStatement.h"
40 #include "src/sksl/ir/SkSLFunctionCall.h"
41 #include "src/sksl/ir/SkSLFunctionDeclaration.h"
42 #include "src/sksl/ir/SkSLFunctionDefinition.h"
43 #include "src/sksl/ir/SkSLFunctionPrototype.h"
44 #include "src/sksl/ir/SkSLIfStatement.h"
45 #include "src/sksl/ir/SkSLIndexExpression.h"
46 #include "src/sksl/ir/SkSLInterfaceBlock.h"
47 #include "src/sksl/ir/SkSLLiteral.h"
48 #include "src/sksl/ir/SkSLModifiersDeclaration.h"
49 #include "src/sksl/ir/SkSLNop.h"
50 #include "src/sksl/ir/SkSLPostfixExpression.h"
51 #include "src/sksl/ir/SkSLPrefixExpression.h"
52 #include "src/sksl/ir/SkSLProgram.h"
53 #include "src/sksl/ir/SkSLReturnStatement.h"
54 #include "src/sksl/ir/SkSLSetting.h"
55 #include "src/sksl/ir/SkSLStructDefinition.h"
56 #include "src/sksl/ir/SkSLSwitchCase.h"
57 #include "src/sksl/ir/SkSLSwitchStatement.h"
58 #include "src/sksl/ir/SkSLSwizzle.h"
59 #include "src/sksl/ir/SkSLTernaryExpression.h"
60 #include "src/sksl/ir/SkSLVarDeclarations.h"
61 #include "src/sksl/ir/SkSLVariable.h"
62 #include "src/sksl/ir/SkSLVariableReference.h"
63 #include "src/sksl/spirv.h"
64
65 #include <algorithm>
66 #include <functional>
67 #include <limits>
68 #include <memory>
69 #include <type_traits>
70
71 namespace SkSL {
72
73 static const char* operator_name(Operator op) {
74     switch (op.kind()) {
75         case Operator::Kind::LOGICALXOR:  return " != ";
76         default:                          return op.operatorName();
77     }
78 }
79
// Abstract visitor over program globals, with one callback per category of global:
// interface blocks, textures, samplers, and other variables. Concrete visitors override all
// four callbacks.
// NOTE(review): judging by the name, the visited globals are presumably gathered into a
// generated "globals" struct — confirm against the concrete visitor implementations.
class MetalCodeGenerator::GlobalStructVisitor {
public:
    virtual ~GlobalStructVisitor() = default;
    // Called for an interface block; `blockName` is the name it is visited under.
    virtual void visitInterfaceBlock(const InterfaceBlock& block, std::string_view blockName) = 0;
    // Called for a texture global of the given type and name.
    virtual void visitTexture(const Type& type, std::string_view name) = 0;
    // Called for a sampler global of the given type and name.
    virtual void visitSampler(const Type& type, std::string_view name) = 0;
    // Called for any other global variable. `value` is passed as a pointer.
    // NOTE(review): presumably null when the variable has no initializer — confirm.
    virtual void visitVariable(const Variable& var, const Expression* value) = 0;
};
88
89 void MetalCodeGenerator::write(std::string_view s) {
90     if (s.empty()) {
91         return;
92     }
93     if (fAtLineStart) {
94         for (int i = 0; i < fIndentation; i++) {
95             fOut->writeText("    ");
96         }
97     }
98     fOut->writeText(std::string(s).c_str());
99     fAtLineStart = false;
100 }
101
102 void MetalCodeGenerator::writeLine(std::string_view s) {
103     this->write(s);
104     fOut->writeText(fLineEnding);
105     fAtLineStart = true;
106 }
107
108 void MetalCodeGenerator::finishLine() {
109     if (!fAtLineStart) {
110         this->writeLine();
111     }
112 }
113
114 void MetalCodeGenerator::writeExtension(const Extension& ext) {
115     this->writeLine("#extension " + std::string(ext.name()) + " : enable");
116 }
117
118 std::string MetalCodeGenerator::typeName(const Type& type) {
119     switch (type.typeKind()) {
120         case Type::TypeKind::kArray:
121             SkASSERTF(type.columns() > 0, "invalid array size: %s", type.description().c_str());
122             return String::printf("array<%s, %d>",
123                                   this->typeName(type.componentType()).c_str(), type.columns());
124
125         case Type::TypeKind::kVector:
126             return this->typeName(type.componentType()) + std::to_string(type.columns());
127
128         case Type::TypeKind::kMatrix:
129             return this->typeName(type.componentType()) + std::to_string(type.columns()) + "x" +
130                                   std::to_string(type.rows());
131
132         case Type::TypeKind::kSampler:
133             return "texture2d<half>"; // FIXME - support other texture types
134
135         default:
136             return std::string(type.name());
137     }
138 }
139
140 void MetalCodeGenerator::writeStructDefinition(const StructDefinition& s) {
141     const Type& type = s.type();
142     this->writeLine("struct " + type.displayName() + " {");
143     fIndentation++;
144     this->writeFields(type.fields(), type.fPosition);
145     fIndentation--;
146     this->writeLine("};");
147 }
148
149 void MetalCodeGenerator::writeType(const Type& type) {
150     this->write(this->typeName(type));
151 }
152
153 void MetalCodeGenerator::writeExpression(const Expression& expr, Precedence parentPrecedence) {
154     switch (expr.kind()) {
155         case Expression::Kind::kBinary:
156             this->writeBinaryExpression(expr.as<BinaryExpression>(), parentPrecedence);
157             break;
158         case Expression::Kind::kConstructorArray:
159         case Expression::Kind::kConstructorStruct:
160             this->writeAnyConstructor(expr.asAnyConstructor(), "{", "}", parentPrecedence);
161             break;
162         case Expression::Kind::kConstructorArrayCast:
163             this->writeConstructorArrayCast(expr.as<ConstructorArrayCast>(), parentPrecedence);
164             break;
165         case Expression::Kind::kConstructorCompound:
166             this->writeConstructorCompound(expr.as<ConstructorCompound>(), parentPrecedence);
167             break;
168         case Expression::Kind::kConstructorDiagonalMatrix:
169         case Expression::Kind::kConstructorSplat:
170             this->writeAnyConstructor(expr.asAnyConstructor(), "(", ")", parentPrecedence);
171             break;
172         case Expression::Kind::kConstructorMatrixResize:
173             this->writeConstructorMatrixResize(expr.as<ConstructorMatrixResize>(),
174                                                parentPrecedence);
175             break;
176         case Expression::Kind::kConstructorScalarCast:
177         case Expression::Kind::kConstructorCompoundCast:
178             this->writeCastConstructor(expr.asAnyConstructor(), "(", ")", parentPrecedence);
179             break;
180         case Expression::Kind::kFieldAccess:
181             this->writeFieldAccess(expr.as<FieldAccess>());
182             break;
183         case Expression::Kind::kLiteral:
184             this->writeLiteral(expr.as<Literal>());
185             break;
186         case Expression::Kind::kFunctionCall:
187             this->writeFunctionCall(expr.as<FunctionCall>());
188             break;
189         case Expression::Kind::kPrefix:
190             this->writePrefixExpression(expr.as<PrefixExpression>(), parentPrecedence);
191             break;
192         case Expression::Kind::kPostfix:
193             this->writePostfixExpression(expr.as<PostfixExpression>(), parentPrecedence);
194             break;
195         case Expression::Kind::kSetting:
196             this->writeSetting(expr.as<Setting>());
197             break;
198         case Expression::Kind::kSwizzle:
199             this->writeSwizzle(expr.as<Swizzle>());
200             break;
201         case Expression::Kind::kVariableReference:
202             this->writeVariableReference(expr.as<VariableReference>());
203             break;
204         case Expression::Kind::kTernary:
205             this->writeTernaryExpression(expr.as<TernaryExpression>(), parentPrecedence);
206             break;
207         case Expression::Kind::kIndex:
208             this->writeIndexExpression(expr.as<IndexExpression>());
209             break;
210         default:
211             SkDEBUGFAILF("unsupported expression: %s", expr.description().c_str());
212             break;
213     }
214 }
215
// Emulates GLSL out-parameter semantics for one call site. Metal has no copy-out-on-return
// semantics, and it cannot bind a swizzle to a `floatN&`. This function writes a wrapper
// function into fExtraFunctions and returns the wrapper's name, which is unique per helper
// because of fSwizzleHelperCount. The caller then emits `name(...)` in place of the original
// call.
//
// `outVars` must have one entry per argument and per parameter (asserted below). Entry i is
// the root variable of argument i when that argument feeds an out-parameter, and null otherwise.
//
// The wrapper:
//   - takes the function-requirement params, then one param per argument. Out-params are taken
//     by reference using the root variable's type and name. If the same variable appears more
//     than once, only its first occurrence is named. Other params are taken as `_var<index>`.
//   - declares a local `_var<index>` for every out-param position, initialized from the
//     original argument expression only when the parameter is also `in`.
//   - calls the original function with the `_var`s.
//   - writes each `_var` back through the original argument expression, which may be a
//     swizzle or an index.
//   - returns the call's result if the call's type is not `void`.
std::string MetalCodeGenerator::getOutParamHelper(const FunctionCall& call,
                                             const ExpressionArray& arguments,
                                             const SkTArray<VariableReference*>& outVars) {
    // Redirect all output in this scope into fExtraFunctions.
    // NOTE(review): AutoOutputStream presumably restores the previous stream and indentation on
    // scope exit — confirm.
    AutoOutputStream outputToExtraFunctions(this, &fExtraFunctions, &fIndentation);
    const FunctionDeclaration& function = call.function();

    std::string name = "_skOutParamHelper" + std::to_string(fSwizzleHelperCount++) +
                       "_" + function.mangledName();
    const char* separator = "";

    // Emit a prototype for the function we'll be calling through to in our helper.
    if (!function.isBuiltin()) {
        this->writeFunctionDeclaration(function);
        this->writeLine(";");
    }

    // Synthesize a helper function that takes the same inputs as `function`, except in places where
    // `outVars` is non-null; in those places, we take the type of the VariableReference.
    //
    // float _skOutParamHelper0_originalFuncName(float _var0, float _var1, float& outParam) {
    this->writeType(call.type());
    this->write(" ");
    this->write(name);
    this->write("(");
    this->writeFunctionRequirementParams(function, separator);

    SkASSERT(outVars.size() == arguments.size());
    SkASSERT(outVars.size() == function.parameters().size());

    // We need to detect cases where the caller passes the same variable as an out-param more than
    // once, and avoid reusing the variable name. (In those cases we can actually just ignore the
    // redundant input parameter entirely, and not give it any name.)
    SkTHashSet<const Variable*> writtenVars;

    for (int index = 0; index < arguments.count(); ++index) {
        this->write(separator);
        separator = ", ";

        const Variable* param = function.parameters()[index];
        this->writeModifiers(param->modifiers());

        // Out-param positions use the root variable's type; others use the argument's type.
        const Type* type = outVars[index] ? &outVars[index]->type() : &arguments[index]->type();
        this->writeType(*type);

        if (param->modifiers().fFlags & Modifiers::kOut_Flag) {
            this->write("&");
        }
        if (outVars[index]) {
            const Variable* var = outVars[index]->variable();
            if (!writtenVars.contains(var)) {
                writtenVars.add(var);

                this->write(" ");
                // Write just the variable's name here, with fIgnoreVariableReferenceModifiers
                // set while writeVariableReference() runs.
                fIgnoreVariableReferenceModifiers = true;
                this->writeVariableReference(*outVars[index]);
                fIgnoreVariableReferenceModifiers = false;
            }
        } else {
            this->write(" _var");
            this->write(std::to_string(index));
        }
    }
    this->writeLine(") {");

    ++fIndentation;
    // Declare a local temporary for each out-param position. It is seeded from the original
    // argument expression only when the parameter is also `in` (i.e. `inout`).
    for (int index = 0; index < outVars.count(); ++index) {
        if (!outVars[index]) {
            continue;
        }
        // float3 _var2[ = outParam.zyx];
        this->writeType(arguments[index]->type());
        this->write(" _var");
        this->write(std::to_string(index));

        const Variable* param = function.parameters()[index];
        if (param->modifiers().fFlags & Modifiers::kIn_Flag) {
            this->write(" = ");
            fIgnoreVariableReferenceModifiers = true;
            this->writeExpression(*arguments[index], Precedence::kAssignment);
            fIgnoreVariableReferenceModifiers = false;
        }

        this->writeLine(";");
    }

    // [int _skResult = ] myFunction(inputs, outputs, _globals, _var0, _var1, _var2, _var3);
    bool hasResult = (call.type().name() != "void");
    if (hasResult) {
        this->writeType(call.type());
        this->write(" _skResult = ");
    }

    this->writeName(function.mangledName());
    this->write("(");
    separator = "";
    this->writeFunctionRequirementArgs(function, separator);

    for (int index = 0; index < arguments.count(); ++index) {
        this->write(separator);
        separator = ", ";

        this->write("_var");
        this->write(std::to_string(index));
    }
    this->writeLine(");");

    // Copy each temporary back through the original (possibly swizzled/indexed) expression.
    for (int index = 0; index < outVars.count(); ++index) {
        if (!outVars[index]) {
            continue;
        }
        // outParam.zyx = _var2;
        fIgnoreVariableReferenceModifiers = true;
        this->writeExpression(*arguments[index], Precedence::kAssignment);
        fIgnoreVariableReferenceModifiers = false;
        this->write(" = _var");
        this->write(std::to_string(index));
        this->writeLine(";");
    }

    if (hasResult) {
        this->writeLine("return _skResult;");
    }

    --fIndentation;
    this->writeLine("}");

    return name;
}
344
345 std::string MetalCodeGenerator::getBitcastIntrinsic(const Type& outType) {
346     return "as_type<" +  outType.displayName() + ">";
347 }
348
349 void MetalCodeGenerator::writeFunctionCall(const FunctionCall& c) {
350     const FunctionDeclaration& function = c.function();
351
352     // Many intrinsics need to be rewritten in Metal.
353     if (function.isIntrinsic()) {
354         if (this->writeIntrinsicCall(c, function.intrinsicKind())) {
355             return;
356         }
357     }
358
359     // Determine whether or not we need to emulate GLSL's out-param semantics for Metal using a
360     // helper function. (Specifically, out-parameters in GLSL are only written back to the original
361     // variable at the end of the function call; also, swizzles are supported, whereas Metal doesn't
362     // allow a swizzle to be passed to a `floatN&`.)
363     const ExpressionArray& arguments = c.arguments();
364     const std::vector<const Variable*>& parameters = function.parameters();
365     SkASSERT(arguments.size() == parameters.size());
366
367     bool foundOutParam = false;
368     SkSTArray<16, VariableReference*> outVars;
369     outVars.push_back_n(arguments.count(), (VariableReference*)nullptr);
370
371     for (int index = 0; index < arguments.count(); ++index) {
372         // If this is an out parameter...
373         if (parameters[index]->modifiers().fFlags & Modifiers::kOut_Flag) {
374             // Find the expression's inner variable being written to.
375             Analysis::AssignmentInfo info;
376             // Assignability was verified at IRGeneration time, so this should always succeed.
377             SkAssertResult(Analysis::IsAssignable(*arguments[index], &info));
378             outVars[index] = info.fAssignedVar;
379             foundOutParam = true;
380         }
381     }
382
383     if (foundOutParam) {
384         // Out parameters need to be written back to at the end of the function. To do this, we
385         // synthesize a helper function which evaluates the out-param expression into a temporary
386         // variable, calls the original function, then writes the temp var back into the out param
387         // using the original out-param expression. (This lets us support things like swizzles and
388         // array indices.)
389         this->write(getOutParamHelper(c, arguments, outVars));
390     } else {
391         this->write(function.mangledName());
392     }
393
394     this->write("(");
395     const char* separator = "";
396     this->writeFunctionRequirementArgs(function, separator);
397     for (int i = 0; i < arguments.count(); ++i) {
398         this->write(separator);
399         separator = ", ";
400
401         if (outVars[i]) {
402             this->writeExpression(*outVars[i], Precedence::kSequence);
403         } else {
404             this->writeExpression(*arguments[i], Precedence::kSequence);
405         }
406     }
407     this->write(")");
408 }
409
// Metal source for a generic 2x2 matrix inverse: the adjugate (built column by column from the
// column-major `m`) scaled by 1/determinant(m). getInversePolyfill() writes it into the extra
// functions the first time a 2x2 `inverse` is needed. No singular-matrix guard is applied;
// det == 0 divides by zero.
static constexpr char kInverse2x2[] = R"(
template <typename T>
matrix<T, 2, 2> mat2_inverse(matrix<T, 2, 2> m) {
    return matrix<T, 2, 2>(m[1][1], -m[0][1], -m[1][0], m[0][0]) * (1/determinant(m));
}
)";
416
// Metal source for a generic 3x3 matrix inverse via cofactors. b01/b11/b21 are the cofactors
// of the first column; they form the first column of the result and are reused for the
// determinant. getInversePolyfill() writes it into the extra functions the first time a 3x3
// `inverse` is needed. No singular-matrix guard is applied.
static constexpr char kInverse3x3[] = R"(
template <typename T>
matrix<T, 3, 3> mat3_inverse(matrix<T, 3, 3> m) {
    T a00 = m[0][0], a01 = m[0][1], a02 = m[0][2];
    T a10 = m[1][0], a11 = m[1][1], a12 = m[1][2];
    T a20 = m[2][0], a21 = m[2][1], a22 = m[2][2];
    T b01 =  a22*a11 - a12*a21;
    T b11 = -a22*a10 + a12*a20;
    T b21 =  a21*a10 - a11*a20;
    T det = a00*b01 + a01*b11 + a02*b21;
    return matrix<T, 3, 3>(b01, (-a22*a01 + a02*a21), ( a12*a01 - a02*a11),
                           b11, ( a22*a00 - a02*a20), (-a12*a00 + a02*a10),
                           b21, (-a21*a00 + a01*a20), ( a11*a00 - a01*a10)) * (1/det);
}
)";
432
// Metal source for a generic 4x4 matrix inverse. b00..b11 are the twelve 2x2 sub-determinants
// of columns {0,1} and {2,3}. Their pairwise products give both the determinant and the 16
// cofactors, which are scaled by 1/det. getInversePolyfill() writes it into the extra functions
// the first time a 4x4 `inverse` is needed. No singular-matrix guard is applied.
static constexpr char kInverse4x4[] = R"(
template <typename T>
matrix<T, 4, 4> mat4_inverse(matrix<T, 4, 4> m) {
    T a00 = m[0][0], a01 = m[0][1], a02 = m[0][2], a03 = m[0][3];
    T a10 = m[1][0], a11 = m[1][1], a12 = m[1][2], a13 = m[1][3];
    T a20 = m[2][0], a21 = m[2][1], a22 = m[2][2], a23 = m[2][3];
    T a30 = m[3][0], a31 = m[3][1], a32 = m[3][2], a33 = m[3][3];
    T b00 = a00*a11 - a01*a10;
    T b01 = a00*a12 - a02*a10;
    T b02 = a00*a13 - a03*a10;
    T b03 = a01*a12 - a02*a11;
    T b04 = a01*a13 - a03*a11;
    T b05 = a02*a13 - a03*a12;
    T b06 = a20*a31 - a21*a30;
    T b07 = a20*a32 - a22*a30;
    T b08 = a20*a33 - a23*a30;
    T b09 = a21*a32 - a22*a31;
    T b10 = a21*a33 - a23*a31;
    T b11 = a22*a33 - a23*a32;
    T det = b00*b11 - b01*b10 + b02*b09 + b03*b08 - b04*b07 + b05*b06;
    return matrix<T, 4, 4>(a11*b11 - a12*b10 + a13*b09,
                           a02*b10 - a01*b11 - a03*b09,
                           a31*b05 - a32*b04 + a33*b03,
                           a22*b04 - a21*b05 - a23*b03,
                           a12*b08 - a10*b11 - a13*b07,
                           a00*b11 - a02*b08 + a03*b07,
                           a32*b02 - a30*b05 - a33*b01,
                           a20*b05 - a22*b02 + a23*b01,
                           a10*b10 - a11*b08 + a13*b06,
                           a01*b08 - a00*b10 - a03*b06,
                           a30*b04 - a31*b02 + a33*b00,
                           a21*b02 - a20*b04 - a23*b00,
                           a11*b07 - a10*b09 - a12*b06,
                           a00*b09 - a01*b07 + a02*b06,
                           a31*b01 - a30*b03 - a32*b00,
                           a20*b03 - a21*b01 + a22*b00) * (1/det);
}
)";
471
472 std::string MetalCodeGenerator::getInversePolyfill(const ExpressionArray& arguments) {
473     // Only use polyfills for a function taking a single-argument square matrix.
474     if (arguments.size() == 1) {
475         const Type& type = arguments.front()->type();
476         if (type.isMatrix() && type.rows() == type.columns()) {
477             // Inject the correct polyfill based on the matrix size.
478             auto name = String::printf("mat%d_inverse", type.columns());
479             auto [iter, didInsert] = fWrittenIntrinsics.insert(name);
480             if (didInsert) {
481                 switch (type.rows()) {
482                     case 2:
483                         fExtraFunctions.writeText(kInverse2x2);
484                         break;
485                     case 3:
486                         fExtraFunctions.writeText(kInverse3x3);
487                         break;
488                     case 4:
489                         fExtraFunctions.writeText(kInverse4x4);
490                         break;
491                 }
492             }
493             return name;
494         }
495     }
496     // This isn't the built-in `inverse`. We don't want to polyfill it at all.
497     return "inverse";
498 }
499
500 void MetalCodeGenerator::writeMatrixCompMult() {
501     static constexpr char kMatrixCompMult[] = R"(
502 template <typename T, int C, int R>
503 matrix<T, C, R> matrixCompMult(matrix<T, C, R> a, const matrix<T, C, R> b) {
504     for (int c = 0; c < C; ++c) {
505         a[c] *= b[c];
506     }
507     return a;
508 }
509 )";
510
511     std::string name = "matrixCompMult";
512     if (fWrittenIntrinsics.find(name) == fWrittenIntrinsics.end()) {
513         fWrittenIntrinsics.insert(name);
514         fExtraFunctions.writeText(kMatrixCompMult);
515     }
516 }
517
518 void MetalCodeGenerator::writeOuterProduct() {
519     static constexpr char kOuterProduct[] = R"(
520 template <typename T, int C, int R>
521 matrix<T, C, R> outerProduct(const vec<T, R> a, const vec<T, C> b) {
522     matrix<T, C, R> result;
523     for (int c = 0; c < C; ++c) {
524         result[c] = a * b[c];
525     }
526     return result;
527 }
528 )";
529
530     std::string name = "outerProduct";
531     if (fWrittenIntrinsics.find(name) == fWrittenIntrinsics.end()) {
532         fWrittenIntrinsics.insert(name);
533         fExtraFunctions.writeText(kOuterProduct);
534     }
535 }
536
537 std::string MetalCodeGenerator::getTempVariable(const Type& type) {
538     std::string tempVar = "_skTemp" + std::to_string(fVarCount++);
539     this->fFunctionHeader += "    " + this->typeName(type) + " " + tempVar + ";\n";
540     return tempVar;
541 }
542
543 void MetalCodeGenerator::writeSimpleIntrinsic(const FunctionCall& c) {
544     // Write out an intrinsic function call exactly as-is. No muss no fuss.
545     this->write(c.function().name());
546     this->writeArgumentList(c.arguments());
547 }
548
549 void MetalCodeGenerator::writeArgumentList(const ExpressionArray& arguments) {
550     this->write("(");
551     const char* separator = "";
552     for (const std::unique_ptr<Expression>& arg : arguments) {
553         this->write(separator);
554         separator = ", ";
555         this->writeExpression(*arg, Precedence::kSequence);
556     }
557     this->write(")");
558 }
559
560 bool MetalCodeGenerator::writeIntrinsicCall(const FunctionCall& c, IntrinsicKind kind) {
561     const ExpressionArray& arguments = c.arguments();
562     switch (kind) {
563         case k_sample_IntrinsicKind: {
564             this->writeExpression(*arguments[0], Precedence::kSequence);
565             this->write(".sample(");
566             this->writeExpression(*arguments[0], Precedence::kSequence);
567             this->write(SAMPLER_SUFFIX);
568             this->write(", ");
569             const Type& arg1Type = arguments[1]->type();
570             if (arg1Type.columns() == 3) {
571                 // have to store the vector in a temp variable to avoid double evaluating it
572                 std::string tmpVar = this->getTempVariable(arg1Type);
573                 this->write("(" + tmpVar + " = ");
574                 this->writeExpression(*arguments[1], Precedence::kSequence);
575                 this->write(", " + tmpVar + ".xy / " + tmpVar + ".z))");
576             } else {
577                 SkASSERT(arg1Type.columns() == 2);
578                 this->writeExpression(*arguments[1], Precedence::kSequence);
579                 this->write(")");
580             }
581             return true;
582         }
583         case k_mod_IntrinsicKind: {
584             // fmod(x, y) in metal calculates x - y * trunc(x / y) instead of x - y * floor(x / y)
585             std::string tmpX = this->getTempVariable(arguments[0]->type());
586             std::string tmpY = this->getTempVariable(arguments[1]->type());
587             this->write("(" + tmpX + " = ");
588             this->writeExpression(*arguments[0], Precedence::kSequence);
589             this->write(", " + tmpY + " = ");
590             this->writeExpression(*arguments[1], Precedence::kSequence);
591             this->write(", " + tmpX + " - " + tmpY + " * floor(" + tmpX + " / " + tmpY + "))");
592             return true;
593         }
594         // GLSL declares scalar versions of most geometric intrinsics, but these don't exist in MSL
595         case k_distance_IntrinsicKind: {
596             if (arguments[0]->type().columns() == 1) {
597                 this->write("abs(");
598                 this->writeExpression(*arguments[0], Precedence::kAdditive);
599                 this->write(" - ");
600                 this->writeExpression(*arguments[1], Precedence::kAdditive);
601                 this->write(")");
602             } else {
603                 this->writeSimpleIntrinsic(c);
604             }
605             return true;
606         }
607         case k_dot_IntrinsicKind: {
608             if (arguments[0]->type().columns() == 1) {
609                 this->write("(");
610                 this->writeExpression(*arguments[0], Precedence::kMultiplicative);
611                 this->write(" * ");
612                 this->writeExpression(*arguments[1], Precedence::kMultiplicative);
613                 this->write(")");
614             } else {
615                 this->writeSimpleIntrinsic(c);
616             }
617             return true;
618         }
619         case k_faceforward_IntrinsicKind: {
620             if (arguments[0]->type().columns() == 1) {
621                 // ((((Nref) * (I) < 0) ? 1 : -1) * (N))
622                 this->write("((((");
623                 this->writeExpression(*arguments[2], Precedence::kSequence);
624                 this->write(") * (");
625                 this->writeExpression(*arguments[1], Precedence::kSequence);
626                 this->write(") < 0) ? 1 : -1) * (");
627                 this->writeExpression(*arguments[0], Precedence::kSequence);
628                 this->write("))");
629             } else {
630                 this->writeSimpleIntrinsic(c);
631             }
632             return true;
633         }
634         case k_length_IntrinsicKind: {
635             this->write(arguments[0]->type().columns() == 1 ? "abs(" : "length(");
636             this->writeExpression(*arguments[0], Precedence::kSequence);
637             this->write(")");
638             return true;
639         }
640         case k_normalize_IntrinsicKind: {
641             this->write(arguments[0]->type().columns() == 1 ? "sign(" : "normalize(");
642             this->writeExpression(*arguments[0], Precedence::kSequence);
643             this->write(")");
644             return true;
645         }
646         case k_packUnorm2x16_IntrinsicKind: {
647             this->write("pack_float_to_unorm2x16(");
648             this->writeExpression(*arguments[0], Precedence::kSequence);
649             this->write(")");
650             return true;
651         }
652         case k_unpackUnorm2x16_IntrinsicKind: {
653             this->write("unpack_unorm2x16_to_float(");
654             this->writeExpression(*arguments[0], Precedence::kSequence);
655             this->write(")");
656             return true;
657         }
658         case k_packSnorm2x16_IntrinsicKind: {
659             this->write("pack_float_to_snorm2x16(");
660             this->writeExpression(*arguments[0], Precedence::kSequence);
661             this->write(")");
662             return true;
663         }
664         case k_unpackSnorm2x16_IntrinsicKind: {
665             this->write("unpack_snorm2x16_to_float(");
666             this->writeExpression(*arguments[0], Precedence::kSequence);
667             this->write(")");
668             return true;
669         }
670         case k_packUnorm4x8_IntrinsicKind: {
671             this->write("pack_float_to_unorm4x8(");
672             this->writeExpression(*arguments[0], Precedence::kSequence);
673             this->write(")");
674             return true;
675         }
676         case k_unpackUnorm4x8_IntrinsicKind: {
677             this->write("unpack_unorm4x8_to_float(");
678             this->writeExpression(*arguments[0], Precedence::kSequence);
679             this->write(")");
680             return true;
681         }
682         case k_packSnorm4x8_IntrinsicKind: {
683             this->write("pack_float_to_snorm4x8(");
684             this->writeExpression(*arguments[0], Precedence::kSequence);
685             this->write(")");
686             return true;
687         }
688         case k_unpackSnorm4x8_IntrinsicKind: {
689             this->write("unpack_snorm4x8_to_float(");
690             this->writeExpression(*arguments[0], Precedence::kSequence);
691             this->write(")");
692             return true;
693         }
694         case k_packHalf2x16_IntrinsicKind: {
695             this->write("as_type<uint>(half2(");
696             this->writeExpression(*arguments[0], Precedence::kSequence);
697             this->write("))");
698             return true;
699         }
700         case k_unpackHalf2x16_IntrinsicKind: {
701             this->write("float2(as_type<half2>(");
702             this->writeExpression(*arguments[0], Precedence::kSequence);
703             this->write("))");
704             return true;
705         }
706         case k_floatBitsToInt_IntrinsicKind:
707         case k_floatBitsToUint_IntrinsicKind:
708         case k_intBitsToFloat_IntrinsicKind:
709         case k_uintBitsToFloat_IntrinsicKind: {
710             this->write(this->getBitcastIntrinsic(c.type()));
711             this->write("(");
712             this->writeExpression(*arguments[0], Precedence::kSequence);
713             this->write(")");
714             return true;
715         }
716         case k_degrees_IntrinsicKind: {
717             this->write("((");
718             this->writeExpression(*arguments[0], Precedence::kSequence);
719             this->write(") * 57.2957795)");
720             return true;
721         }
722         case k_radians_IntrinsicKind: {
723             this->write("((");
724             this->writeExpression(*arguments[0], Precedence::kSequence);
725             this->write(") * 0.0174532925)");
726             return true;
727         }
728         case k_dFdx_IntrinsicKind: {
729             this->write("dfdx");
730             this->writeArgumentList(c.arguments());
731             return true;
732         }
733         case k_dFdy_IntrinsicKind: {
734             if (!fRTFlipName.empty()) {
735                 this->write("(" + fRTFlipName + ".y * dfdy");
736             } else {
737                 this->write("(dfdy");
738             }
739             this->writeArgumentList(c.arguments());
740             this->write(")");
741             return true;
742         }
743         case k_inverse_IntrinsicKind: {
744             this->write(this->getInversePolyfill(arguments));
745             this->writeArgumentList(c.arguments());
746             return true;
747         }
748         case k_inversesqrt_IntrinsicKind: {
749             this->write("rsqrt");
750             this->writeArgumentList(c.arguments());
751             return true;
752         }
753         case k_atan_IntrinsicKind: {
754             this->write(c.arguments().size() == 2 ? "atan2" : "atan");
755             this->writeArgumentList(c.arguments());
756             return true;
757         }
758         case k_reflect_IntrinsicKind: {
759             if (arguments[0]->type().columns() == 1) {
760                 // We need to synthesize `I - 2 * N * I * N`.
761                 std::string tmpI = this->getTempVariable(arguments[0]->type());
762                 std::string tmpN = this->getTempVariable(arguments[1]->type());
763
764                 // (_skTempI = ...
765                 this->write("(" + tmpI + " = ");
766                 this->writeExpression(*arguments[0], Precedence::kSequence);
767
768                 // , _skTempN = ...
769                 this->write(", " + tmpN + " = ");
770                 this->writeExpression(*arguments[1], Precedence::kSequence);
771
772                 // , _skTempI - 2 * _skTempN * _skTempI * _skTempN)
773                 this->write(", " + tmpI + " - 2 * " + tmpN + " * " + tmpI + " * " + tmpN + ")");
774             } else {
775                 this->writeSimpleIntrinsic(c);
776             }
777             return true;
778         }
779         case k_refract_IntrinsicKind: {
780             if (arguments[0]->type().columns() == 1) {
781                 // Metal does implement refract for vectors; rather than reimplementing refract from
782                 // scratch, we can replace the call with `refract(float2(I,0), float2(N,0), eta).x`.
783                 this->write("(refract(float2(");
784                 this->writeExpression(*arguments[0], Precedence::kSequence);
785                 this->write(", 0), float2(");
786                 this->writeExpression(*arguments[1], Precedence::kSequence);
787                 this->write(", 0), ");
788                 this->writeExpression(*arguments[2], Precedence::kSequence);
789                 this->write(").x)");
790             } else {
791                 this->writeSimpleIntrinsic(c);
792             }
793             return true;
794         }
795         case k_roundEven_IntrinsicKind: {
796             this->write("rint");
797             this->writeArgumentList(c.arguments());
798             return true;
799         }
800         case k_bitCount_IntrinsicKind: {
801             this->write("popcount(");
802             this->writeExpression(*arguments[0], Precedence::kSequence);
803             this->write(")");
804             return true;
805         }
806         case k_findLSB_IntrinsicKind: {
807             // Create a temp variable to store the expression, to avoid double-evaluating it.
808             std::string skTemp = this->getTempVariable(arguments[0]->type());
809             std::string exprType = this->typeName(arguments[0]->type());
810
811             // ctz returns numbits(type) on zero inputs; GLSL documents it as generating -1 instead.
812             // Use select to detect zero inputs and force a -1 result.
813
814             // (_skTemp1 = (.....), select(ctz(_skTemp1), int4(-1), _skTemp1 == int4(0)))
815             this->write("(");
816             this->write(skTemp);
817             this->write(" = (");
818             this->writeExpression(*arguments[0], Precedence::kSequence);
819             this->write("), select(ctz(");
820             this->write(skTemp);
821             this->write("), ");
822             this->write(exprType);
823             this->write("(-1), ");
824             this->write(skTemp);
825             this->write(" == ");
826             this->write(exprType);
827             this->write("(0)))");
828             return true;
829         }
830         case k_findMSB_IntrinsicKind: {
831             // Create a temp variable to store the expression, to avoid double-evaluating it.
832             std::string skTemp1 = this->getTempVariable(arguments[0]->type());
833             std::string exprType = this->typeName(arguments[0]->type());
834
835             // GLSL findMSB is actually quite different from Metal's clz:
836             // - For signed negative numbers, it returns the first zero bit, not the first one bit!
837             // - For an empty input (0/~0 depending on sign), findMSB gives -1; clz is numbits(type)
838
839             // (_skTemp1 = (.....),
840             this->write("(");
841             this->write(skTemp1);
842             this->write(" = (");
843             this->writeExpression(*arguments[0], Precedence::kSequence);
844             this->write("), ");
845
846             // Signed input types might be negative; we need another helper variable to negate the
847             // input (since we can only find one bits, not zero bits).
848             std::string skTemp2;
849             if (arguments[0]->type().isSigned()) {
850                 // ... _skTemp2 = (select(_skTemp1, ~_skTemp1, _skTemp1 < 0)),
851                 skTemp2 = this->getTempVariable(arguments[0]->type());
852                 this->write(skTemp2);
853                 this->write(" = (select(");
854                 this->write(skTemp1);
855                 this->write(", ~");
856                 this->write(skTemp1);
857                 this->write(", ");
858                 this->write(skTemp1);
859                 this->write(" < 0)), ");
860             } else {
861                 skTemp2 = skTemp1;
862             }
863
864             // ... select(int4(clz(_skTemp2)), int4(-1), _skTemp2 == int4(0)))
865             this->write("select(");
866             this->write(this->typeName(c.type()));
867             this->write("(clz(");
868             this->write(skTemp2);
869             this->write(")), ");
870             this->write(this->typeName(c.type()));
871             this->write("(-1), ");
872             this->write(skTemp2);
873             this->write(" == ");
874             this->write(exprType);
875             this->write("(0)))");
876             return true;
877         }
878         case k_sign_IntrinsicKind: {
879             if (arguments[0]->type().componentType().isInteger()) {
880                 // Create a temp variable to store the expression, to avoid double-evaluating it.
881                 std::string skTemp = this->getTempVariable(arguments[0]->type());
882                 std::string exprType = this->typeName(arguments[0]->type());
883
884                 // (_skTemp = (.....),
885                 this->write("(");
886                 this->write(skTemp);
887                 this->write(" = (");
888                 this->writeExpression(*arguments[0], Precedence::kSequence);
889                 this->write("), ");
890
891                 // ... select(select(int4(0), int4(-1), _skTemp < 0), int4(1), _skTemp > 0))
892                 this->write("select(select(");
893                 this->write(exprType);
894                 this->write("(0), ");
895                 this->write(exprType);
896                 this->write("(-1), ");
897                 this->write(skTemp);
898                 this->write(" < 0), ");
899                 this->write(exprType);
900                 this->write("(1), ");
901                 this->write(skTemp);
902                 this->write(" > 0))");
903             } else {
904                 this->writeSimpleIntrinsic(c);
905             }
906             return true;
907         }
908         case k_matrixCompMult_IntrinsicKind: {
909             this->writeMatrixCompMult();
910             this->writeSimpleIntrinsic(c);
911             return true;
912         }
913         case k_outerProduct_IntrinsicKind: {
914             this->writeOuterProduct();
915             this->writeSimpleIntrinsic(c);
916             return true;
917         }
918         case k_mix_IntrinsicKind: {
919             SkASSERT(c.arguments().size() == 3);
920             if (arguments[2]->type().componentType().isBoolean()) {
921                 // The Boolean forms of GLSL mix() use the select() intrinsic in Metal.
922                 this->write("select");
923                 this->writeArgumentList(c.arguments());
924                 return true;
925             }
926             // The basic form of mix() is supported by Metal as-is.
927             this->writeSimpleIntrinsic(c);
928             return true;
929         }
930         case k_equal_IntrinsicKind:
931         case k_greaterThan_IntrinsicKind:
932         case k_greaterThanEqual_IntrinsicKind:
933         case k_lessThan_IntrinsicKind:
934         case k_lessThanEqual_IntrinsicKind:
935         case k_notEqual_IntrinsicKind: {
936             this->write("(");
937             this->writeExpression(*c.arguments()[0], Precedence::kRelational);
938             switch (kind) {
939                 case k_equal_IntrinsicKind:
940                     this->write(" == ");
941                     break;
942                 case k_notEqual_IntrinsicKind:
943                     this->write(" != ");
944                     break;
945                 case k_lessThan_IntrinsicKind:
946                     this->write(" < ");
947                     break;
948                 case k_lessThanEqual_IntrinsicKind:
949                     this->write(" <= ");
950                     break;
951                 case k_greaterThan_IntrinsicKind:
952                     this->write(" > ");
953                     break;
954                 case k_greaterThanEqual_IntrinsicKind:
955                     this->write(" >= ");
956                     break;
957                 default:
958                     SK_ABORT("unsupported comparison intrinsic kind");
959             }
960             this->writeExpression(*c.arguments()[1], Precedence::kRelational);
961             this->write(")");
962             return true;
963         }
964         default:
965             return false;
966     }
967 }
968
969 // Assembles a matrix of type floatRxC by resizing another matrix named `x0`.
970 // Cells that don't exist in the source matrix will be populated with identity-matrix values.
971 void MetalCodeGenerator::assembleMatrixFromMatrix(const Type& sourceMatrix, int rows, int columns) {
972     SkASSERT(rows <= 4);
973     SkASSERT(columns <= 4);
974
975     std::string matrixType = this->typeName(sourceMatrix.componentType());
976
977     const char* separator = "";
978     for (int c = 0; c < columns; ++c) {
979         fExtraFunctions.printf("%s%s%d(", separator, matrixType.c_str(), rows);
980         separator = "), ";
981
982         // Determine how many values to take from the source matrix for this row.
983         int swizzleLength = 0;
984         if (c < sourceMatrix.columns()) {
985             swizzleLength = std::min<>(rows, sourceMatrix.rows());
986         }
987
988         // Emit all the values from the source matrix row.
989         bool firstItem;
990         switch (swizzleLength) {
991             case 0:  firstItem = true;                                            break;
992             case 1:  firstItem = false; fExtraFunctions.printf("x0[%d].x", c);    break;
993             case 2:  firstItem = false; fExtraFunctions.printf("x0[%d].xy", c);   break;
994             case 3:  firstItem = false; fExtraFunctions.printf("x0[%d].xyz", c);  break;
995             case 4:  firstItem = false; fExtraFunctions.printf("x0[%d].xyzw", c); break;
996             default: SkUNREACHABLE;
997         }
998
999         // Emit the placeholder identity-matrix cells.
1000         for (int r = swizzleLength; r < rows; ++r) {
1001             fExtraFunctions.printf("%s%s", firstItem ? "" : ", ", (r == c) ? "1.0" : "0.0");
1002             firstItem = false;
1003         }
1004     }
1005
1006     fExtraFunctions.writeText(")");
1007 }
1008
1009 // Assembles a matrix of type floatCxR by concatenating an arbitrary mix of values, named `x0`,
1010 // `x1`, etc. An error is written if the expression list don't contain exactly C*R scalars.
1011 void MetalCodeGenerator::assembleMatrixFromExpressions(const AnyConstructor& ctor,
1012                                                        int columns, int rows) {
1013     SkASSERT(rows <= 4);
1014     SkASSERT(columns <= 4);
1015
1016     std::string matrixType = this->typeName(ctor.type().componentType());
1017     size_t argIndex = 0;
1018     int argPosition = 0;
1019     auto args = ctor.argumentSpan();
1020
1021     static constexpr char kSwizzle[] = "xyzw";
1022     const char* separator = "";
1023     for (int c = 0; c < columns; ++c) {
1024         fExtraFunctions.printf("%s%s%d(", separator, matrixType.c_str(), rows);
1025         separator = "), ";
1026
1027         const char* columnSeparator = "";
1028         for (int r = 0; r < rows;) {
1029             fExtraFunctions.writeText(columnSeparator);
1030             columnSeparator = ", ";
1031
1032             if (argIndex < args.size()) {
1033                 const Type& argType = args[argIndex]->type();
1034                 switch (argType.typeKind()) {
1035                     case Type::TypeKind::kScalar: {
1036                         fExtraFunctions.printf("x%zu", argIndex);
1037                         ++r;
1038                         ++argPosition;
1039                         break;
1040                     }
1041                     case Type::TypeKind::kVector: {
1042                         fExtraFunctions.printf("x%zu.", argIndex);
1043                         do {
1044                             fExtraFunctions.write8(kSwizzle[argPosition]);
1045                             ++r;
1046                             ++argPosition;
1047                         } while (r < rows && argPosition < argType.columns());
1048                         break;
1049                     }
1050                     case Type::TypeKind::kMatrix: {
1051                         fExtraFunctions.printf("x%zu[%d].", argIndex, argPosition / argType.rows());
1052                         do {
1053                             fExtraFunctions.write8(kSwizzle[argPosition]);
1054                             ++r;
1055                             ++argPosition;
1056                         } while (r < rows && (argPosition % argType.rows()) != 0);
1057                         break;
1058                     }
1059                     default: {
1060                         SkDEBUGFAIL("incorrect type of argument for matrix constructor");
1061                         fExtraFunctions.writeText("<error>");
1062                         break;
1063                     }
1064                 }
1065
1066                 if (argPosition >= argType.columns() * argType.rows()) {
1067                     ++argIndex;
1068                     argPosition = 0;
1069                 }
1070             } else {
1071                 SkDEBUGFAIL("not enough arguments for matrix constructor");
1072                 fExtraFunctions.writeText("<error>");
1073             }
1074         }
1075     }
1076
1077     if (argPosition != 0 || argIndex != args.size()) {
1078         SkDEBUGFAIL("incorrect number of arguments for matrix constructor");
1079         fExtraFunctions.writeText(", <error>");
1080     }
1081
1082     fExtraFunctions.writeText(")");
1083 }
1084
1085 // Generates a constructor for 'matrix' which reorganizes the input arguments into the proper shape.
1086 // Keeps track of previously generated constructors so that we won't generate more than one
1087 // constructor for any given permutation of input argument types. Returns the name of the
1088 // generated constructor method.
1089 std::string MetalCodeGenerator::getMatrixConstructHelper(const AnyConstructor& c) {
1090     const Type& type = c.type();
1091     int columns = type.columns();
1092     int rows = type.rows();
1093     auto args = c.argumentSpan();
1094     std::string typeName = this->typeName(type);
1095
1096     // Create the helper-method name and use it as our lookup key.
1097     std::string name = String::printf("%s_from", typeName.c_str());
1098     for (const std::unique_ptr<Expression>& expr : args) {
1099         String::appendf(&name, "_%s", this->typeName(expr->type()).c_str());
1100     }
1101
1102     // If a helper-method has not been synthesized yet, create it now.
1103     if (!fHelpers.contains(name)) {
1104         fHelpers.add(name);
1105
1106         // Unlike GLSL, Metal requires that matrices are initialized with exactly R vectors of C
1107         // components apiece. (In Metal 2.0, you can also supply R*C scalars, but you still cannot
1108         // supply a mixture of scalars and vectors.)
1109         fExtraFunctions.printf("%s %s(", typeName.c_str(), name.c_str());
1110
1111         size_t argIndex = 0;
1112         const char* argSeparator = "";
1113         for (const std::unique_ptr<Expression>& expr : args) {
1114             fExtraFunctions.printf("%s%s x%zu", argSeparator,
1115                                    this->typeName(expr->type()).c_str(), argIndex++);
1116             argSeparator = ", ";
1117         }
1118
1119         fExtraFunctions.printf(") {\n    return %s(", typeName.c_str());
1120
1121         if (args.size() == 1 && args.front()->type().isMatrix()) {
1122             this->assembleMatrixFromMatrix(args.front()->type(), rows, columns);
1123         } else {
1124             this->assembleMatrixFromExpressions(c, columns, rows);
1125         }
1126
1127         fExtraFunctions.writeText(");\n}\n");
1128     }
1129     return name;
1130 }
1131
1132 bool MetalCodeGenerator::matrixConstructHelperIsNeeded(const ConstructorCompound& c) {
1133     SkASSERT(c.type().isMatrix());
1134
1135     // GLSL is fairly free-form about inputs to its matrix constructors, but Metal is not; it
1136     // expects exactly R vectors of C components apiece. (Metal 2.0 also allows a list of R*C
1137     // scalars.) Some cases are simple to translate and so we handle those inline--e.g. a list of
1138     // scalars can be constructed trivially. In more complex cases, we generate a helper function
1139     // that converts our inputs into a properly-shaped matrix.
1140     // A matrix construct helper method is always used if any input argument is a matrix.
1141     // Helper methods are also necessary when any argument would span multiple rows. For instance:
1142     //
1143     // float2 x = (1, 2);
1144     // float3x2(x, 3, 4, 5, 6) = | 1 3 5 | = no helper needed; conversion can be done inline
1145     //                           | 2 4 6 |
1146     //
1147     // float2 x = (2, 3);
1148     // float3x2(1, x, 4, 5, 6) = | 1 3 5 | = x spans multiple rows; a helper method will be used
1149     //                           | 2 4 6 |
1150     //
1151     // float4 x = (1, 2, 3, 4);
1152     // float2x2(x) = | 1 3 | = x spans multiple rows; a helper method will be used
1153     //               | 2 4 |
1154     //
1155
1156     int position = 0;
1157     for (const std::unique_ptr<Expression>& expr : c.arguments()) {
1158         // If an input argument is a matrix, we need a helper function.
1159         if (expr->type().isMatrix()) {
1160             return true;
1161         }
1162         position += expr->type().columns();
1163         if (position > c.type().rows()) {
1164             // An input argument would span multiple rows; a helper function is required.
1165             return true;
1166         }
1167         if (position == c.type().rows()) {
1168             // We've advanced to the end of a row. Wrap to the start of the next row.
1169             position = 0;
1170         }
1171     }
1172
1173     return false;
1174 }
1175
1176 void MetalCodeGenerator::writeConstructorMatrixResize(const ConstructorMatrixResize& c,
1177                                                       Precedence parentPrecedence) {
1178     // Matrix-resize via casting doesn't natively exist in Metal at all, so we always need to use a
1179     // matrix-construct helper here.
1180     this->write(this->getMatrixConstructHelper(c));
1181     this->write("(");
1182     this->writeExpression(*c.argument(), Precedence::kSequence);
1183     this->write(")");
1184 }
1185
1186 void MetalCodeGenerator::writeConstructorCompound(const ConstructorCompound& c,
1187                                                   Precedence parentPrecedence) {
1188     if (c.type().isVector()) {
1189         this->writeConstructorCompoundVector(c, parentPrecedence);
1190     } else if (c.type().isMatrix()) {
1191         this->writeConstructorCompoundMatrix(c, parentPrecedence);
1192     } else {
1193         fContext.fErrors->error(c.fPosition, "unsupported compound constructor");
1194     }
1195 }
1196
1197 void MetalCodeGenerator::writeConstructorArrayCast(const ConstructorArrayCast& c,
1198                                                    Precedence parentPrecedence) {
1199     const Type& inType = c.argument()->type().componentType();
1200     const Type& outType = c.type().componentType();
1201     std::string inTypeName = this->typeName(inType);
1202     std::string outTypeName = this->typeName(outType);
1203
1204     std::string name = "array_of_" + outTypeName + "_from_" + inTypeName;
1205     if (!fHelpers.contains(name)) {
1206         fHelpers.add(name);
1207         fExtraFunctions.printf(R"(
1208 template <size_t N>
1209 array<%s, N> %s(thread const array<%s, N>& x) {
1210     array<%s, N> result;
1211     for (int i = 0; i < N; ++i) {
1212         result[i] = %s(x[i]);
1213     }
1214     return result;
1215 }
1216 )",
1217                                outTypeName.c_str(), name.c_str(), inTypeName.c_str(),
1218                                outTypeName.c_str(),
1219                                outTypeName.c_str());
1220     }
1221
1222     this->write(name);
1223     this->write("(");
1224     this->writeExpression(*c.argument(), Precedence::kSequence);
1225     this->write(")");
1226 }
1227
1228 std::string MetalCodeGenerator::getVectorFromMat2x2ConstructorHelper(const Type& matrixType) {
1229     SkASSERT(matrixType.isMatrix());
1230     SkASSERT(matrixType.rows() == 2);
1231     SkASSERT(matrixType.columns() == 2);
1232
1233     std::string baseType = this->typeName(matrixType.componentType());
1234     std::string name = String::printf("%s4_from_%s2x2", baseType.c_str(), baseType.c_str());
1235     if (!fHelpers.contains(name)) {
1236         fHelpers.add(name);
1237
1238         fExtraFunctions.printf(R"(
1239 %s4 %s(%s2x2 x) {
1240     return %s4(x[0].xy, x[1].xy);
1241 }
1242 )", baseType.c_str(), name.c_str(), baseType.c_str(), baseType.c_str());
1243     }
1244
1245     return name;
1246 }
1247
1248 void MetalCodeGenerator::writeConstructorCompoundVector(const ConstructorCompound& c,
1249                                                         Precedence parentPrecedence) {
1250     SkASSERT(c.type().isVector());
1251
1252     // Metal supports constructing vectors from a mix of scalars and vectors, but not matrices.
1253     // GLSL supports vec4(mat2x2), so we detect that case here and emit a helper function.
1254     if (c.type().columns() == 4 && c.argumentSpan().size() == 1) {
1255         const Expression& expr = *c.argumentSpan().front();
1256         if (expr.type().isMatrix()) {
1257             this->write(this->getVectorFromMat2x2ConstructorHelper(expr.type()));
1258             this->write("(");
1259             this->writeExpression(expr, Precedence::kSequence);
1260             this->write(")");
1261             return;
1262         }
1263     }
1264
1265     this->writeAnyConstructor(c, "(", ")", parentPrecedence);
1266 }
1267
1268 void MetalCodeGenerator::writeConstructorCompoundMatrix(const ConstructorCompound& c,
1269                                                         Precedence parentPrecedence) {
1270     SkASSERT(c.type().isMatrix());
1271
1272     // Emit and invoke a matrix-constructor helper method if one is necessary.
1273     if (this->matrixConstructHelperIsNeeded(c)) {
1274         this->write(this->getMatrixConstructHelper(c));
1275         this->write("(");
1276         const char* separator = "";
1277         for (const std::unique_ptr<Expression>& expr : c.arguments()) {
1278             this->write(separator);
1279             separator = ", ";
1280             this->writeExpression(*expr, Precedence::kSequence);
1281         }
1282         this->write(")");
1283         return;
1284     }
1285
1286     // Metal doesn't allow creating matrices by passing in scalars and vectors in a jumble; it
1287     // requires your scalars to be grouped up into columns. Because `matrixConstructHelperIsNeeded`
1288     // returned false, we know that none of our scalars/vectors "wrap" across across a column, so we
1289     // can group our inputs up and synthesize a constructor for each column.
1290     const Type& matrixType = c.type();
1291     const Type& columnType = matrixType.componentType().toCompound(
1292             fContext, /*columns=*/matrixType.rows(), /*rows=*/1);
1293
1294     this->writeType(matrixType);
1295     this->write("(");
1296     const char* separator = "";
1297     int scalarCount = 0;
1298     for (const std::unique_ptr<Expression>& arg : c.arguments()) {
1299         this->write(separator);
1300         separator = ", ";
1301         if (arg->type().columns() < matrixType.rows()) {
1302             // Write a `floatN(` constructor to group scalars and smaller vectors together.
1303             if (!scalarCount) {
1304                 this->writeType(columnType);
1305                 this->write("(");
1306             }
1307             scalarCount += arg->type().columns();
1308         }
1309         this->writeExpression(*arg, Precedence::kSequence);
1310         if (scalarCount && scalarCount == matrixType.rows()) {
1311             // Close our `floatN(...` constructor block from above.
1312             this->write(")");
1313             scalarCount = 0;
1314         }
1315     }
1316     this->write(")");
1317 }
1318
1319 void MetalCodeGenerator::writeAnyConstructor(const AnyConstructor& c,
1320                                              const char* leftBracket,
1321                                              const char* rightBracket,
1322                                              Precedence parentPrecedence) {
1323     this->writeType(c.type());
1324     this->write(leftBracket);
1325     const char* separator = "";
1326     for (const std::unique_ptr<Expression>& arg : c.argumentSpan()) {
1327         this->write(separator);
1328         separator = ", ";
1329         this->writeExpression(*arg, Precedence::kSequence);
1330     }
1331     this->write(rightBracket);
1332 }
1333
1334 void MetalCodeGenerator::writeCastConstructor(const AnyConstructor& c,
1335                                               const char* leftBracket,
1336                                               const char* rightBracket,
1337                                               Precedence parentPrecedence) {
1338     return this->writeAnyConstructor(c, leftBracket, rightBracket, parentPrecedence);
1339 }
1340
1341 void MetalCodeGenerator::writeFragCoord() {
1342     if (!fRTFlipName.empty()) {
1343         this->write("float4(_fragCoord.x, ");
1344         this->write(fRTFlipName.c_str());
1345         this->write(".x + ");
1346         this->write(fRTFlipName.c_str());
1347         this->write(".y * _fragCoord.y, 0.0, _fragCoord.w)");
1348     } else {
1349         this->write("float4(_fragCoord.x, _fragCoord.y, 0.0, _fragCoord.w)");
1350     }
1351 }
1352
1353 void MetalCodeGenerator::writeVariableReference(const VariableReference& ref) {
1354     // When assembling out-param helper functions, we copy variables into local clones with matching
1355     // names. We never want to prepend "_in." or "_globals." when writing these variables since
1356     // we're actually targeting the clones.
1357     if (fIgnoreVariableReferenceModifiers) {
1358         this->writeName(ref.variable()->name());
1359         return;
1360     }
1361
1362     switch (ref.variable()->modifiers().fLayout.fBuiltin) {
1363         case SK_FRAGCOLOR_BUILTIN:
1364             this->write("_out.sk_FragColor");
1365             break;
1366         case SK_FRAGCOORD_BUILTIN:
1367             this->writeFragCoord();
1368             break;
1369         case SK_VERTEXID_BUILTIN:
1370             this->write("sk_VertexID");
1371             break;
1372         case SK_INSTANCEID_BUILTIN:
1373             this->write("sk_InstanceID");
1374             break;
1375         case SK_CLOCKWISE_BUILTIN:
1376             // We'd set the front facing winding in the MTLRenderCommandEncoder to be counter
1377             // clockwise to match Skia convention.
1378             if (!fRTFlipName.empty()) {
1379                 this->write("(" + fRTFlipName + ".y < 0 ? _frontFacing : !_frontFacing)");
1380             } else {
1381                 this->write("_frontFacing");
1382             }
1383             break;
1384         default:
1385             const Variable& var = *ref.variable();
1386             if (var.storage() == Variable::Storage::kGlobal) {
1387                 if (var.modifiers().fFlags & Modifiers::kIn_Flag) {
1388                     this->write("_in.");
1389                 } else if (var.modifiers().fFlags & Modifiers::kOut_Flag) {
1390                     this->write("_out.");
1391                 } else if (var.modifiers().fFlags & Modifiers::kUniform_Flag &&
1392                            var.type().typeKind() != Type::TypeKind::kSampler) {
1393                     this->write("_uniforms.");
1394                 } else {
1395                     this->write("_globals.");
1396                 }
1397             }
1398             this->writeName(var.name());
1399     }
1400 }
1401
1402 void MetalCodeGenerator::writeIndexExpression(const IndexExpression& expr) {
1403     this->writeExpression(*expr.base(), Precedence::kPostfix);
1404     this->write("[");
1405     this->writeExpression(*expr.index(), Precedence::kTopLevel);
1406     this->write("]");
1407 }
1408
1409 void MetalCodeGenerator::writeFieldAccess(const FieldAccess& f) {
1410     const Type::Field* field = &f.base()->type().fields()[f.fieldIndex()];
1411     if (FieldAccess::OwnerKind::kDefault == f.ownerKind()) {
1412         this->writeExpression(*f.base(), Precedence::kPostfix);
1413         this->write(".");
1414     }
1415     switch (field->fModifiers.fLayout.fBuiltin) {
1416         case SK_POSITION_BUILTIN:
1417             this->write("_out.sk_Position");
1418             break;
1419         case SK_POINTSIZE_BUILTIN:
1420             this->write("_out.sk_PointSize");
1421             break;
1422         default:
1423             if (FieldAccess::OwnerKind::kAnonymousInterfaceBlock == f.ownerKind()) {
1424                 this->write("_globals.");
1425                 this->write(fInterfaceBlockNameMap[fInterfaceBlockMap[field]]);
1426                 this->write("->");
1427             }
1428             this->writeName(field->fName);
1429     }
1430 }
1431
1432 void MetalCodeGenerator::writeSwizzle(const Swizzle& swizzle) {
1433     this->writeExpression(*swizzle.base(), Precedence::kPostfix);
1434     this->write(".");
1435     for (int c : swizzle.components()) {
1436         SkASSERT(c >= 0 && c <= 3);
1437         this->write(&("x\0y\0z\0w\0"[c * 2]));
1438     }
1439 }
1440
1441 void MetalCodeGenerator::writeMatrixTimesEqualHelper(const Type& left, const Type& right,
1442                                                      const Type& result) {
1443     SkASSERT(left.isMatrix());
1444     SkASSERT(right.isMatrix());
1445     SkASSERT(result.isMatrix());
1446     SkASSERT(left.rows() == right.rows());
1447     SkASSERT(left.columns() == right.columns());
1448     SkASSERT(left.rows() == result.rows());
1449     SkASSERT(left.columns() == result.columns());
1450
1451     std::string key = "Matrix *= " + this->typeName(left) + ":" + this->typeName(right);
1452
1453     if (!fHelpers.contains(key)) {
1454         fHelpers.add(key);
1455         fExtraFunctions.printf("thread %s& operator*=(thread %s& left, thread const %s& right) {\n"
1456                                "    left = left * right;\n"
1457                                "    return left;\n"
1458                                "}\n",
1459                                this->typeName(result).c_str(), this->typeName(left).c_str(),
1460                                this->typeName(right).c_str());
1461     }
1462 }
1463
1464 void MetalCodeGenerator::writeMatrixEqualityHelpers(const Type& left, const Type& right) {
1465     SkASSERT(left.isMatrix());
1466     SkASSERT(right.isMatrix());
1467     SkASSERT(left.rows() == right.rows());
1468     SkASSERT(left.columns() == right.columns());
1469
1470     std::string key = "Matrix == " + this->typeName(left) + ":" + this->typeName(right);
1471
1472     if (!fHelpers.contains(key)) {
1473         fHelpers.add(key);
1474         fExtraFunctionPrototypes.printf(R"(
1475 thread bool operator==(const %s left, const %s right);
1476 thread bool operator!=(const %s left, const %s right);
1477 )",
1478                                         this->typeName(left).c_str(),
1479                                         this->typeName(right).c_str(),
1480                                         this->typeName(left).c_str(),
1481                                         this->typeName(right).c_str());
1482
1483         fExtraFunctions.printf(
1484                 "thread bool operator==(const %s left, const %s right) {\n"
1485                 "    return ",
1486                 this->typeName(left).c_str(), this->typeName(right).c_str());
1487
1488         const char* separator = "";
1489         for (int index=0; index<left.columns(); ++index) {
1490             fExtraFunctions.printf("%sall(left[%d] == right[%d])", separator, index, index);
1491             separator = " &&\n           ";
1492         }
1493
1494         fExtraFunctions.printf(
1495                 ";\n"
1496                 "}\n"
1497                 "thread bool operator!=(const %s left, const %s right) {\n"
1498                 "    return !(left == right);\n"
1499                 "}\n",
1500                 this->typeName(left).c_str(), this->typeName(right).c_str());
1501     }
1502 }
1503
1504 void MetalCodeGenerator::writeMatrixDivisionHelpers(const Type& type) {
1505     SkASSERT(type.isMatrix());
1506
1507     std::string key = "Matrix / " + this->typeName(type);
1508
1509     if (!fHelpers.contains(key)) {
1510         fHelpers.add(key);
1511         std::string typeName = this->typeName(type);
1512
1513         fExtraFunctions.printf(
1514                 "thread %s operator/(const %s left, const %s right) {\n"
1515                 "    return %s(",
1516                 typeName.c_str(), typeName.c_str(), typeName.c_str(), typeName.c_str());
1517
1518         const char* separator = "";
1519         for (int index=0; index<type.columns(); ++index) {
1520             fExtraFunctions.printf("%sleft[%d] / right[%d]", separator, index, index);
1521             separator = ", ";
1522         }
1523
1524         fExtraFunctions.printf(");\n"
1525                                "}\n"
1526                                "thread %s& operator/=(thread %s& left, thread const %s& right) {\n"
1527                                "    left = left / right;\n"
1528                                "    return left;\n"
1529                                "}\n",
1530                                typeName.c_str(), typeName.c_str(), typeName.c_str());
1531     }
1532 }
1533
1534 void MetalCodeGenerator::writeArrayEqualityHelpers(const Type& type) {
1535     SkASSERT(type.isArray());
1536
1537     // If the array's component type needs a helper as well, we need to emit that one first.
1538     this->writeEqualityHelpers(type.componentType(), type.componentType());
1539
1540     std::string key = "ArrayEquality []";
1541     if (!fHelpers.contains(key)) {
1542         fHelpers.add(key);
1543         fExtraFunctionPrototypes.writeText(R"(
1544 template <typename T1, typename T2, size_t N>
1545 bool operator==(thread const array<T1, N>& left, thread const array<T2, N>& right);
1546 template <typename T1, typename T2, size_t N>
1547 bool operator!=(thread const array<T1, N>& left, thread const array<T2, N>& right);
1548 )");
1549         fExtraFunctions.writeText(R"(
1550 template <typename T1, typename T2, size_t N>
1551 bool operator==(thread const array<T1, N>& left, thread const array<T2, N>& right) {
1552     for (size_t index = 0; index < N; ++index) {
1553         if (!all(left[index] == right[index])) {
1554             return false;
1555         }
1556     }
1557     return true;
1558 }
1559
1560 template <typename T1, typename T2, size_t N>
1561 bool operator!=(thread const array<T1, N>& left, thread const array<T2, N>& right) {
1562     return !(left == right);
1563 }
1564 )");
1565     }
1566 }
1567
1568 void MetalCodeGenerator::writeStructEqualityHelpers(const Type& type) {
1569     SkASSERT(type.isStruct());
1570     std::string key = "StructEquality " + this->typeName(type);
1571
1572     if (!fHelpers.contains(key)) {
1573         fHelpers.add(key);
1574         // If one of the struct's fields needs a helper as well, we need to emit that one first.
1575         for (const Type::Field& field : type.fields()) {
1576             this->writeEqualityHelpers(*field.fType, *field.fType);
1577         }
1578
1579         // Write operator== and operator!= for this struct, since those are assumed to exist in SkSL
1580         // and GLSL but do not exist by default in Metal.
1581         fExtraFunctionPrototypes.printf(R"(
1582 thread bool operator==(thread const %s& left, thread const %s& right);
1583 thread bool operator!=(thread const %s& left, thread const %s& right);
1584 )",
1585                                         this->typeName(type).c_str(),
1586                                         this->typeName(type).c_str(),
1587                                         this->typeName(type).c_str(),
1588                                         this->typeName(type).c_str());
1589
1590         fExtraFunctions.printf(
1591                 "thread bool operator==(thread const %s& left, thread const %s& right) {\n"
1592                 "    return ",
1593                 this->typeName(type).c_str(),
1594                 this->typeName(type).c_str());
1595
1596         const char* separator = "";
1597         for (const Type::Field& field : type.fields()) {
1598             fExtraFunctions.printf("%sall(left.%.*s == right.%.*s)",
1599                                    separator,
1600                                    (int)field.fName.size(), field.fName.data(),
1601                                    (int)field.fName.size(), field.fName.data());
1602             separator = " &&\n           ";
1603         }
1604         fExtraFunctions.printf(
1605                 ";\n"
1606                 "}\n"
1607                 "thread bool operator!=(thread const %s& left, thread const %s& right) {\n"
1608                 "    return !(left == right);\n"
1609                 "}\n",
1610                 this->typeName(type).c_str(),
1611                 this->typeName(type).c_str());
1612     }
1613 }
1614
1615 void MetalCodeGenerator::writeEqualityHelpers(const Type& leftType, const Type& rightType) {
1616     if (leftType.isArray() && rightType.isArray()) {
1617         this->writeArrayEqualityHelpers(leftType);
1618         return;
1619     }
1620     if (leftType.isStruct() && rightType.isStruct()) {
1621         this->writeStructEqualityHelpers(leftType);
1622         return;
1623     }
1624     if (leftType.isMatrix() && rightType.isMatrix()) {
1625         this->writeMatrixEqualityHelpers(leftType, rightType);
1626         return;
1627     }
1628 }
1629
1630 void MetalCodeGenerator::writeNumberAsMatrix(const Expression& expr, const Type& matrixType) {
1631     SkASSERT(expr.type().isNumber());
1632     SkASSERT(matrixType.isMatrix());
1633
1634     // Componentwise multiply the scalar against a matrix of the desired size which contains all 1s.
1635     this->write("(");
1636     this->writeType(matrixType);
1637     this->write("(");
1638
1639     const char* separator = "";
1640     for (int index = matrixType.slotCount(); index--;) {
1641         this->write(separator);
1642         this->write("1.0");
1643         separator = ", ";
1644     }
1645
1646     this->write(") * ");
1647     this->writeExpression(expr, Precedence::kMultiplicative);
1648     this->write(")");
1649 }
1650
1651 void MetalCodeGenerator::writeBinaryExpression(const BinaryExpression& b,
1652                                                Precedence parentPrecedence) {
1653     const Expression& left = *b.left();
1654     const Expression& right = *b.right();
1655     const Type& leftType = left.type();
1656     const Type& rightType = right.type();
1657     Operator op = b.getOperator();
1658     Precedence precedence = op.getBinaryPrecedence();
1659     bool needParens = precedence >= parentPrecedence;
1660     switch (op.kind()) {
1661         case Operator::Kind::EQEQ:
1662             this->writeEqualityHelpers(leftType, rightType);
1663             if (leftType.isVector()) {
1664                 this->write("all");
1665                 needParens = true;
1666             }
1667             break;
1668         case Operator::Kind::NEQ:
1669             this->writeEqualityHelpers(leftType, rightType);
1670             if (leftType.isVector()) {
1671                 this->write("any");
1672                 needParens = true;
1673             }
1674             break;
1675         default:
1676             break;
1677     }
1678     if (leftType.isMatrix() && rightType.isMatrix() && op.kind() == Operator::Kind::STAREQ) {
1679         this->writeMatrixTimesEqualHelper(leftType, rightType, b.type());
1680     }
1681     if (op.removeAssignment().kind() == Operator::Kind::SLASH &&
1682         ((leftType.isMatrix() && rightType.isMatrix()) ||
1683          (leftType.isScalar() && rightType.isMatrix()) ||
1684          (leftType.isMatrix() && rightType.isScalar()))) {
1685         this->writeMatrixDivisionHelpers(leftType.isMatrix() ? leftType : rightType);
1686     }
1687     if (needParens) {
1688         this->write("(");
1689     }
1690     bool needMatrixSplatOnScalar = rightType.isMatrix() && leftType.isNumber() &&
1691                                    op.isValidForMatrixOrVector() &&
1692                                    op.removeAssignment().kind() != Operator::Kind::STAR;
1693     if (needMatrixSplatOnScalar) {
1694         this->writeNumberAsMatrix(left, rightType);
1695     } else {
1696         this->writeExpression(left, precedence);
1697     }
1698     if (op.kind() != Operator::Kind::EQ && op.isAssignment() &&
1699         left.kind() == Expression::Kind::kSwizzle && !left.hasSideEffects()) {
1700         // This doesn't compile in Metal:
1701         // float4 x = float4(1);
1702         // x.xy *= float2x2(...);
1703         // with the error message "non-const reference cannot bind to vector element",
1704         // but switching it to x.xy = x.xy * float2x2(...) fixes it. We perform this tranformation
1705         // as long as the LHS has no side effects, and hope for the best otherwise.
1706         this->write(" = ");
1707         this->writeExpression(left, Precedence::kAssignment);
1708         this->write(operator_name(op.removeAssignment()));
1709     } else {
1710         this->write(operator_name(op));
1711     }
1712
1713     needMatrixSplatOnScalar = leftType.isMatrix() && rightType.isNumber() &&
1714                               op.isValidForMatrixOrVector() &&
1715                               op.removeAssignment().kind() != Operator::Kind::STAR;
1716     if (needMatrixSplatOnScalar) {
1717         this->writeNumberAsMatrix(right, leftType);
1718     } else {
1719         this->writeExpression(right, precedence);
1720     }
1721     if (needParens) {
1722         this->write(")");
1723     }
1724 }
1725
1726 void MetalCodeGenerator::writeTernaryExpression(const TernaryExpression& t,
1727                                                Precedence parentPrecedence) {
1728     if (Precedence::kTernary >= parentPrecedence) {
1729         this->write("(");
1730     }
1731     this->writeExpression(*t.test(), Precedence::kTernary);
1732     this->write(" ? ");
1733     this->writeExpression(*t.ifTrue(), Precedence::kTernary);
1734     this->write(" : ");
1735     this->writeExpression(*t.ifFalse(), Precedence::kTernary);
1736     if (Precedence::kTernary >= parentPrecedence) {
1737         this->write(")");
1738     }
1739 }
1740
1741 void MetalCodeGenerator::writePrefixExpression(const PrefixExpression& p,
1742                                                Precedence parentPrecedence) {
1743     // According to the MSL specification, the arithmetic unary operators (+ and â€“) do not act
1744     // upon matrix type operands. We treat the unary "+" as NOP for all operands.
1745     const Operator op = p.getOperator();
1746     if (op.kind() == Operator::Kind::PLUS) {
1747         return this->writeExpression(*p.operand(), Precedence::kPrefix);
1748     }
1749
1750     const bool matrixNegation =
1751             op.kind() == Operator::Kind::MINUS && p.operand()->type().isMatrix();
1752     const bool needParens = Precedence::kPrefix >= parentPrecedence || matrixNegation;
1753
1754     if (needParens) {
1755         this->write("(");
1756     }
1757
1758     // Transform the unary "-" on a matrix type to a multiplication by -1.
1759     if (matrixNegation) {
1760         this->write("-1.0 * ");
1761     } else {
1762         this->write(p.getOperator().tightOperatorName());
1763     }
1764     this->writeExpression(*p.operand(), Precedence::kPrefix);
1765
1766     if (needParens) {
1767         this->write(")");
1768     }
1769 }
1770
1771 void MetalCodeGenerator::writePostfixExpression(const PostfixExpression& p,
1772                                                 Precedence parentPrecedence) {
1773     if (Precedence::kPostfix >= parentPrecedence) {
1774         this->write("(");
1775     }
1776     this->writeExpression(*p.operand(), Precedence::kPostfix);
1777     this->write(p.getOperator().tightOperatorName());
1778     if (Precedence::kPostfix >= parentPrecedence) {
1779         this->write(")");
1780     }
1781 }
1782
1783 void MetalCodeGenerator::writeLiteral(const Literal& l) {
1784     const Type& type = l.type();
1785     if (type.isFloat()) {
1786         this->write(skstd::to_string(l.floatValue()));
1787         if (!l.type().highPrecision()) {
1788             this->write("h");
1789         }
1790         return;
1791     }
1792     if (type.isInteger()) {
1793         if (type.matches(*fContext.fTypes.fUInt)) {
1794             this->write(std::to_string(l.intValue() & 0xffffffff));
1795             this->write("u");
1796         } else if (type.matches(*fContext.fTypes.fUShort)) {
1797             this->write(std::to_string(l.intValue() & 0xffff));
1798             this->write("u");
1799         } else {
1800             this->write(std::to_string(l.intValue()));
1801         }
1802         return;
1803     }
1804     SkASSERT(type.isBoolean());
1805     this->write(l.boolValue() ? "true" : "false");
1806 }
1807
1808 void MetalCodeGenerator::writeSetting(const Setting& s) {
1809     SK_ABORT("internal error; setting was not folded to a constant during compilation\n");
1810 }
1811
1812 void MetalCodeGenerator::writeFunctionRequirementArgs(const FunctionDeclaration& f,
1813                                                       const char*& separator) {
1814     Requirements requirements = this->requirements(f);
1815     if (requirements & kInputs_Requirement) {
1816         this->write(separator);
1817         this->write("_in");
1818         separator = ", ";
1819     }
1820     if (requirements & kOutputs_Requirement) {
1821         this->write(separator);
1822         this->write("_out");
1823         separator = ", ";
1824     }
1825     if (requirements & kUniforms_Requirement) {
1826         this->write(separator);
1827         this->write("_uniforms");
1828         separator = ", ";
1829     }
1830     if (requirements & kGlobals_Requirement) {
1831         this->write(separator);
1832         this->write("_globals");
1833         separator = ", ";
1834     }
1835     if (requirements & kFragCoord_Requirement) {
1836         this->write(separator);
1837         this->write("_fragCoord");
1838         separator = ", ";
1839     }
1840 }
1841
1842 void MetalCodeGenerator::writeFunctionRequirementParams(const FunctionDeclaration& f,
1843                                                         const char*& separator) {
1844     Requirements requirements = this->requirements(f);
1845     if (requirements & kInputs_Requirement) {
1846         this->write(separator);
1847         this->write("Inputs _in");
1848         separator = ", ";
1849     }
1850     if (requirements & kOutputs_Requirement) {
1851         this->write(separator);
1852         this->write("thread Outputs& _out");
1853         separator = ", ";
1854     }
1855     if (requirements & kUniforms_Requirement) {
1856         this->write(separator);
1857         this->write("Uniforms _uniforms");
1858         separator = ", ";
1859     }
1860     if (requirements & kGlobals_Requirement) {
1861         this->write(separator);
1862         this->write("thread Globals& _globals");
1863         separator = ", ";
1864     }
1865     if (requirements & kFragCoord_Requirement) {
1866         this->write(separator);
1867         this->write("float4 _fragCoord");
1868         separator = ", ";
1869     }
1870 }
1871
1872 int MetalCodeGenerator::getUniformBinding(const Modifiers& m) {
1873     return (m.fLayout.fBinding >= 0) ? m.fLayout.fBinding
1874                                      : fProgram.fConfig->fSettings.fDefaultUniformBinding;
1875 }
1876
1877 int MetalCodeGenerator::getUniformSet(const Modifiers& m) {
1878     return (m.fLayout.fSet >= 0) ? m.fLayout.fSet
1879                                  : fProgram.fConfig->fSettings.fDefaultUniformSet;
1880 }
1881
1882 bool MetalCodeGenerator::writeFunctionDeclaration(const FunctionDeclaration& f) {
1883     fRTFlipName = fProgram.fInputs.fUseFlipRTUniform
1884                           ? "_globals._anonInterface0->" SKSL_RTFLIP_NAME
1885                           : "";
1886     const char* separator = "";
1887     if (f.isMain()) {
1888         if (ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
1889             this->write("fragment Outputs fragmentMain");
1890         } else if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
1891             this->write("vertex Outputs vertexMain");
1892         } else {
1893             fContext.fErrors->error(Position(), "unsupported kind of program");
1894             return false;
1895         }
1896         this->write("(Inputs _in [[stage_in]]");
1897         if (-1 != fUniformBuffer) {
1898             this->write(", constant Uniforms& _uniforms [[buffer(" +
1899                         std::to_string(fUniformBuffer) + ")]]");
1900         }
1901         for (const ProgramElement* e : fProgram.elements()) {
1902             if (e->is<GlobalVarDeclaration>()) {
1903                 const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
1904                 const VarDeclaration& var = decls.declaration()->as<VarDeclaration>();
1905                 if (var.var().type().typeKind() == Type::TypeKind::kSampler) {
1906                     if (var.var().type().dimensions() != SpvDim2D) {
1907                         // Not yet implemented--Skia currently only uses 2D textures.
1908                         fContext.fErrors->error(decls.fPosition, "Unsupported texture dimensions");
1909                         return false;
1910                     }
1911                     int binding = getUniformBinding(var.var().modifiers());
1912                     this->write(", texture2d<half> ");
1913                     this->writeName(var.var().name());
1914                     this->write("[[texture(");
1915                     this->write(std::to_string(binding));
1916                     this->write(")]]");
1917                     this->write(", sampler ");
1918                     this->writeName(var.var().name());
1919                     this->write(SAMPLER_SUFFIX);
1920                     this->write("[[sampler(");
1921                     this->write(std::to_string(binding));
1922                     this->write(")]]");
1923                 }
1924             } else if (e->is<InterfaceBlock>()) {
1925                 const InterfaceBlock& intf = e->as<InterfaceBlock>();
1926                 if (intf.typeName() == "sk_PerVertex") {
1927                     continue;
1928                 }
1929                 this->write(", constant ");
1930                 this->writeType(intf.variable().type());
1931                 this->write("& " );
1932                 this->write(fInterfaceBlockNameMap[&intf]);
1933                 this->write(" [[buffer(");
1934                 this->write(std::to_string(this->getUniformBinding(intf.variable().modifiers())));
1935                 this->write(")]]");
1936             }
1937         }
1938         if (ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
1939             if (fProgram.fInputs.fUseFlipRTUniform && fInterfaceBlockNameMap.empty()) {
1940                 this->write(", constant sksl_synthetic_uniforms& _anonInterface0 [[buffer(1)]]");
1941                 fRTFlipName = "_anonInterface0." SKSL_RTFLIP_NAME;
1942             }
1943             this->write(", bool _frontFacing [[front_facing]]");
1944             this->write(", float4 _fragCoord [[position]]");
1945         } else if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
1946             this->write(", uint sk_VertexID [[vertex_id]], uint sk_InstanceID [[instance_id]]");
1947         }
1948         separator = ", ";
1949     } else {
1950         this->writeType(f.returnType());
1951         this->write(" ");
1952         this->writeName(f.mangledName());
1953         this->write("(");
1954         this->writeFunctionRequirementParams(f, separator);
1955     }
1956     for (const auto& param : f.parameters()) {
1957         if (f.isMain() && param->modifiers().fLayout.fBuiltin != -1) {
1958             continue;
1959         }
1960         this->write(separator);
1961         separator = ", ";
1962         this->writeModifiers(param->modifiers());
1963         const Type* type = &param->type();
1964         this->writeType(*type);
1965         if (param->modifiers().fFlags & Modifiers::kOut_Flag) {
1966             this->write("&");
1967         }
1968         this->write(" ");
1969         this->writeName(param->name());
1970     }
1971     this->write(")");
1972     return true;
1973 }
1974
1975 void MetalCodeGenerator::writeFunctionPrototype(const FunctionPrototype& f) {
1976     this->writeFunctionDeclaration(f.declaration());
1977     this->writeLine(";");
1978 }
1979
1980 static bool is_block_ending_with_return(const Statement* stmt) {
1981     // This function detects (potentially nested) blocks that end in a return statement.
1982     if (!stmt->is<Block>()) {
1983         return false;
1984     }
1985     const StatementArray& block = stmt->as<Block>().children();
1986     for (int index = block.count(); index--; ) {
1987         stmt = block[index].get();
1988         if (stmt->is<ReturnStatement>()) {
1989             return true;
1990         }
1991         if (stmt->is<Block>()) {
1992             return is_block_ending_with_return(stmt);
1993         }
1994         if (!stmt->is<Nop>()) {
1995             break;
1996         }
1997     }
1998     return false;
1999 }
2000
2001 void MetalCodeGenerator::writeFunction(const FunctionDefinition& f) {
2002     SkASSERT(!fProgram.fConfig->fSettings.fFragColorIsInOut);
2003
2004     if (!this->writeFunctionDeclaration(f.declaration())) {
2005         return;
2006     }
2007
2008     fCurrentFunction = &f.declaration();
2009     SkScopeExit clearCurrentFunction([&] { fCurrentFunction = nullptr; });
2010
2011     this->writeLine(" {");
2012
2013     if (f.declaration().isMain()) {
2014         this->writeGlobalInit();
2015         this->writeLine("    Outputs _out;");
2016         this->writeLine("    (void)_out;");
2017     }
2018
2019     fFunctionHeader.clear();
2020     StringStream buffer;
2021     {
2022         AutoOutputStream outputToBuffer(this, &buffer);
2023         fIndentation++;
2024         for (const std::unique_ptr<Statement>& stmt : f.body()->as<Block>().children()) {
2025             if (!stmt->isEmpty()) {
2026                 this->writeStatement(*stmt);
2027                 this->finishLine();
2028             }
2029         }
2030         if (f.declaration().isMain()) {
2031             // If the main function doesn't end with a return, we need to synthesize one here.
2032             if (!is_block_ending_with_return(f.body().get())) {
2033                 this->writeReturnStatementFromMain();
2034                 this->finishLine();
2035             }
2036         }
2037         fIndentation--;
2038         this->writeLine("}");
2039     }
2040     this->write(fFunctionHeader);
2041     this->write(buffer.str());
2042 }
2043
2044 void MetalCodeGenerator::writeModifiers(const Modifiers& modifiers) {
2045     if (modifiers.fFlags & Modifiers::kOut_Flag) {
2046         this->write("thread ");
2047     }
2048     if (modifiers.fFlags & Modifiers::kConst_Flag) {
2049         this->write("const ");
2050     }
2051 }
2052
2053 void MetalCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf) {
2054     if ("sk_PerVertex" == intf.typeName()) {
2055         return;
2056     }
2057     this->writeModifiers(intf.variable().modifiers());
2058     this->write("struct ");
2059     this->writeLine(std::string(intf.typeName()) + " {");
2060     const Type* structType = &intf.variable().type();
2061     if (structType->isArray()) {
2062         structType = &structType->componentType();
2063     }
2064     fIndentation++;
2065     this->writeFields(structType->fields(), structType->fPosition, &intf);
2066     if (fProgram.fInputs.fUseFlipRTUniform) {
2067         this->writeLine("float2 " SKSL_RTFLIP_NAME ";");
2068     }
2069     fIndentation--;
2070     this->write("}");
2071     if (intf.instanceName().size()) {
2072         this->write(" ");
2073         this->write(intf.instanceName());
2074         if (intf.arraySize() > 0) {
2075             this->write("[");
2076             this->write(std::to_string(intf.arraySize()));
2077             this->write("]");
2078         }
2079         fInterfaceBlockNameMap.set(&intf, intf.instanceName());
2080     } else {
2081         fInterfaceBlockNameMap.set(&intf, *fProgram.fSymbols->takeOwnershipOfString(
2082                 "_anonInterface" + std::to_string(fAnonInterfaceCount++)));
2083     }
2084     this->writeLine(";");
2085 }
2086
2087 void MetalCodeGenerator::writeFields(const std::vector<Type::Field>& fields, Position parentPos,
2088         const InterfaceBlock* parentIntf) {
2089     MemoryLayout memoryLayout(MemoryLayout::kMetal_Standard);
2090     int currentOffset = 0;
2091     for (const Type::Field& field : fields) {
2092         int fieldOffset = field.fModifiers.fLayout.fOffset;
2093         const Type* fieldType = field.fType;
2094         if (!MemoryLayout::LayoutIsSupported(*fieldType)) {
2095             fContext.fErrors->error(parentPos, "type '" + std::string(fieldType->name()) +
2096                                                 "' is not permitted here");
2097             return;
2098         }
2099         if (fieldOffset != -1) {
2100             if (currentOffset > fieldOffset) {
2101                 fContext.fErrors->error(field.fPosition,
2102                                         "offset of field '" + std::string(field.fName) +
2103                                         "' must be at least " + std::to_string(currentOffset));
2104                 return;
2105             } else if (currentOffset < fieldOffset) {
2106                 this->write("char pad");
2107                 this->write(std::to_string(fPaddingCount++));
2108                 this->write("[");
2109                 this->write(std::to_string(fieldOffset - currentOffset));
2110                 this->writeLine("];");
2111                 currentOffset = fieldOffset;
2112             }
2113             int alignment = memoryLayout.alignment(*fieldType);
2114             if (fieldOffset % alignment) {
2115                 fContext.fErrors->error(field.fPosition,
2116                                         "offset of field '" + std::string(field.fName) +
2117                                         "' must be a multiple of " + std::to_string(alignment));
2118                 return;
2119             }
2120         }
2121         size_t fieldSize = memoryLayout.size(*fieldType);
2122         if (fieldSize > static_cast<size_t>(std::numeric_limits<int>::max() - currentOffset)) {
2123             fContext.fErrors->error(parentPos, "field offset overflow");
2124             return;
2125         }
2126         currentOffset += fieldSize;
2127         this->writeModifiers(field.fModifiers);
2128         this->writeType(*fieldType);
2129         this->write(" ");
2130         this->writeName(field.fName);
2131         this->writeLine(";");
2132         if (parentIntf) {
2133             fInterfaceBlockMap.set(&field, parentIntf);
2134         }
2135     }
2136 }
2137
2138 void MetalCodeGenerator::writeVarInitializer(const Variable& var, const Expression& value) {
2139     this->writeExpression(value, Precedence::kTopLevel);
2140 }
2141
2142 void MetalCodeGenerator::writeName(std::string_view name) {
2143     if (fReservedWords.contains(name)) {
2144         this->write("_"); // adding underscore before name to avoid conflict with reserved words
2145     }
2146     this->write(name);
2147 }
2148
2149 void MetalCodeGenerator::writeVarDeclaration(const VarDeclaration& varDecl) {
2150     this->writeModifiers(varDecl.var().modifiers());
2151     this->writeType(varDecl.var().type());
2152     this->write(" ");
2153     this->writeName(varDecl.var().name());
2154     if (varDecl.value()) {
2155         this->write(" = ");
2156         this->writeVarInitializer(varDecl.var(), *varDecl.value());
2157     }
2158     this->write(";");
2159 }
2160
2161 void MetalCodeGenerator::writeStatement(const Statement& s) {
2162     switch (s.kind()) {
2163         case Statement::Kind::kBlock:
2164             this->writeBlock(s.as<Block>());
2165             break;
2166         case Statement::Kind::kExpression:
2167             this->writeExpressionStatement(s.as<ExpressionStatement>());
2168             break;
2169         case Statement::Kind::kReturn:
2170             this->writeReturnStatement(s.as<ReturnStatement>());
2171             break;
2172         case Statement::Kind::kVarDeclaration:
2173             this->writeVarDeclaration(s.as<VarDeclaration>());
2174             break;
2175         case Statement::Kind::kIf:
2176             this->writeIfStatement(s.as<IfStatement>());
2177             break;
2178         case Statement::Kind::kFor:
2179             this->writeForStatement(s.as<ForStatement>());
2180             break;
2181         case Statement::Kind::kDo:
2182             this->writeDoStatement(s.as<DoStatement>());
2183             break;
2184         case Statement::Kind::kSwitch:
2185             this->writeSwitchStatement(s.as<SwitchStatement>());
2186             break;
2187         case Statement::Kind::kBreak:
2188             this->write("break;");
2189             break;
2190         case Statement::Kind::kContinue:
2191             this->write("continue;");
2192             break;
2193         case Statement::Kind::kDiscard:
2194             this->write("discard_fragment();");
2195             break;
2196         case Statement::Kind::kNop:
2197             this->write(";");
2198             break;
2199         default:
2200             SkDEBUGFAILF("unsupported statement: %s", s.description().c_str());
2201             break;
2202     }
2203 }
2204
2205 void MetalCodeGenerator::writeBlock(const Block& b) {
2206     // Write scope markers if this block is a scope, or if the block is empty (since we need to emit
2207     // something here to make the code valid).
2208     bool isScope = b.isScope() || b.isEmpty();
2209     if (isScope) {
2210         this->writeLine("{");
2211         fIndentation++;
2212     }
2213     for (const std::unique_ptr<Statement>& stmt : b.children()) {
2214         if (!stmt->isEmpty()) {
2215             this->writeStatement(*stmt);
2216             this->finishLine();
2217         }
2218     }
2219     if (isScope) {
2220         fIndentation--;
2221         this->write("}");
2222     }
2223 }
2224
2225 void MetalCodeGenerator::writeIfStatement(const IfStatement& stmt) {
2226     this->write("if (");
2227     this->writeExpression(*stmt.test(), Precedence::kTopLevel);
2228     this->write(") ");
2229     this->writeStatement(*stmt.ifTrue());
2230     if (stmt.ifFalse()) {
2231         this->write(" else ");
2232         this->writeStatement(*stmt.ifFalse());
2233     }
2234 }
2235
2236 void MetalCodeGenerator::writeForStatement(const ForStatement& f) {
2237     // Emit loops of the form 'for(;test;)' as 'while(test)', which is probably how they started
2238     if (!f.initializer() && f.test() && !f.next()) {
2239         this->write("while (");
2240         this->writeExpression(*f.test(), Precedence::kTopLevel);
2241         this->write(") ");
2242         this->writeStatement(*f.statement());
2243         return;
2244     }
2245
2246     this->write("for (");
2247     if (f.initializer() && !f.initializer()->isEmpty()) {
2248         this->writeStatement(*f.initializer());
2249     } else {
2250         this->write("; ");
2251     }
2252     if (f.test()) {
2253         this->writeExpression(*f.test(), Precedence::kTopLevel);
2254     }
2255     this->write("; ");
2256     if (f.next()) {
2257         this->writeExpression(*f.next(), Precedence::kTopLevel);
2258     }
2259     this->write(") ");
2260     this->writeStatement(*f.statement());
2261 }
2262
2263 void MetalCodeGenerator::writeDoStatement(const DoStatement& d) {
2264     this->write("do ");
2265     this->writeStatement(*d.statement());
2266     this->write(" while (");
2267     this->writeExpression(*d.test(), Precedence::kTopLevel);
2268     this->write(");");
2269 }
2270
2271 void MetalCodeGenerator::writeExpressionStatement(const ExpressionStatement& s) {
2272     if (fProgram.fConfig->fSettings.fOptimize && !s.expression()->hasSideEffects()) {
2273         // Don't emit dead expressions.
2274         return;
2275     }
2276     this->writeExpression(*s.expression(), Precedence::kTopLevel);
2277     this->write(";");
2278 }
2279
2280 void MetalCodeGenerator::writeSwitchStatement(const SwitchStatement& s) {
2281     this->write("switch (");
2282     this->writeExpression(*s.value(), Precedence::kTopLevel);
2283     this->writeLine(") {");
2284     fIndentation++;
2285     for (const std::unique_ptr<Statement>& stmt : s.cases()) {
2286         const SwitchCase& c = stmt->as<SwitchCase>();
2287         if (c.isDefault()) {
2288             this->writeLine("default:");
2289         } else {
2290             this->write("case ");
2291             this->write(std::to_string(c.value()));
2292             this->writeLine(":");
2293         }
2294         if (!c.statement()->isEmpty()) {
2295             fIndentation++;
2296             this->writeStatement(*c.statement());
2297             this->finishLine();
2298             fIndentation--;
2299         }
2300     }
2301     fIndentation--;
2302     this->write("}");
2303 }
2304
2305 void MetalCodeGenerator::writeReturnStatementFromMain() {
2306     // main functions in Metal return a magic _out parameter that doesn't exist in SkSL.
2307     if (ProgramConfig::IsVertex(fProgram.fConfig->fKind) ||
2308         ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
2309         this->write("return _out;");
2310     } else {
2311         SkDEBUGFAIL("unsupported kind of program");
2312     }
2313 }
2314
2315 void MetalCodeGenerator::writeReturnStatement(const ReturnStatement& r) {
2316     if (fCurrentFunction && fCurrentFunction->isMain()) {
2317         if (r.expression()) {
2318             if (r.expression()->type().matches(*fContext.fTypes.fHalf4)) {
2319                 this->write("_out.sk_FragColor = ");
2320                 this->writeExpression(*r.expression(), Precedence::kTopLevel);
2321                 this->writeLine(";");
2322             } else {
2323                 fContext.fErrors->error(r.fPosition,
2324                         "Metal does not support returning '" +
2325                         r.expression()->type().description() + "' from main()");
2326             }
2327         }
2328         this->writeReturnStatementFromMain();
2329         return;
2330     }
2331
2332     this->write("return");
2333     if (r.expression()) {
2334         this->write(" ");
2335         this->writeExpression(*r.expression(), Precedence::kTopLevel);
2336     }
2337     this->write(";");
2338 }
2339
2340 void MetalCodeGenerator::writeHeader() {
2341     this->write("#include <metal_stdlib>\n");
2342     this->write("#include <simd/simd.h>\n");
2343     this->write("using namespace metal;\n");
2344 }
2345
2346 void MetalCodeGenerator::writeUniformStruct() {
2347     for (const ProgramElement* e : fProgram.elements()) {
2348         if (e->is<GlobalVarDeclaration>()) {
2349             const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
2350             const Variable& var = decls.declaration()->as<VarDeclaration>().var();
2351             if (var.modifiers().fFlags & Modifiers::kUniform_Flag &&
2352                 var.type().typeKind() != Type::TypeKind::kSampler) {
2353                 int uniformSet = this->getUniformSet(var.modifiers());
2354                 // Make sure that the program's uniform-set value is consistent throughout.
2355                 if (-1 == fUniformBuffer) {
2356                     this->write("struct Uniforms {\n");
2357                     fUniformBuffer = uniformSet;
2358                 } else if (uniformSet != fUniformBuffer) {
2359                     fContext.fErrors->error(decls.fPosition,
2360                             "Metal backend requires all uniforms to have the same "
2361                             "'layout(set=...)'");
2362                 }
2363                 this->write("    ");
2364                 this->writeType(var.type());
2365                 this->write(" ");
2366                 this->writeName(var.name());
2367                 this->write(";\n");
2368             }
2369         }
2370     }
2371     if (-1 != fUniformBuffer) {
2372         this->write("};\n");
2373     }
2374 }
2375
2376 void MetalCodeGenerator::writeInputStruct() {
2377     this->write("struct Inputs {\n");
2378     for (const ProgramElement* e : fProgram.elements()) {
2379         if (e->is<GlobalVarDeclaration>()) {
2380             const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
2381             const Variable& var = decls.declaration()->as<VarDeclaration>().var();
2382             if (var.modifiers().fFlags & Modifiers::kIn_Flag &&
2383                 -1 == var.modifiers().fLayout.fBuiltin) {
2384                 this->write("    ");
2385                 this->writeType(var.type());
2386                 this->write(" ");
2387                 this->writeName(var.name());
2388                 if (-1 != var.modifiers().fLayout.fLocation) {
2389                     if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
2390                         this->write("  [[attribute(" +
2391                                     std::to_string(var.modifiers().fLayout.fLocation) + ")]]");
2392                     } else if (ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
2393                         this->write("  [[user(locn" +
2394                                     std::to_string(var.modifiers().fLayout.fLocation) + ")]]");
2395                     }
2396                 }
2397                 this->write(";\n");
2398             }
2399         }
2400     }
2401     this->write("};\n");
2402 }
2403
2404 void MetalCodeGenerator::writeOutputStruct() {
2405     this->write("struct Outputs {\n");
2406     if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
2407         this->write("    float4 sk_Position [[position]];\n");
2408     } else if (ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
2409         this->write("    half4 sk_FragColor [[color(0)]];\n");
2410     }
2411     for (const ProgramElement* e : fProgram.elements()) {
2412         if (e->is<GlobalVarDeclaration>()) {
2413             const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
2414             const Variable& var = decls.declaration()->as<VarDeclaration>().var();
2415             if (var.modifiers().fFlags & Modifiers::kOut_Flag &&
2416                 -1 == var.modifiers().fLayout.fBuiltin) {
2417                 this->write("    ");
2418                 this->writeType(var.type());
2419                 this->write(" ");
2420                 this->writeName(var.name());
2421
2422                 int location = var.modifiers().fLayout.fLocation;
2423                 if (location < 0) {
2424                     fContext.fErrors->error(var.fPosition,
2425                             "Metal out variables must have 'layout(location=...)'");
2426                 } else if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
2427                     this->write(" [[user(locn" + std::to_string(location) + ")]]");
2428                 } else if (ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
2429                     this->write(" [[color(" + std::to_string(location) + ")");
2430                     int colorIndex = var.modifiers().fLayout.fIndex;
2431                     if (colorIndex) {
2432                         this->write(", index(" + std::to_string(colorIndex) + ")");
2433                     }
2434                     this->write("]]");
2435                 }
2436                 this->write(";\n");
2437             }
2438         }
2439     }
2440     if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
2441         this->write("    float sk_PointSize [[point_size]];\n");
2442     }
2443     this->write("};\n");
2444 }
2445
2446 void MetalCodeGenerator::writeInterfaceBlocks() {
2447     bool wroteInterfaceBlock = false;
2448     for (const ProgramElement* e : fProgram.elements()) {
2449         if (e->is<InterfaceBlock>()) {
2450             this->writeInterfaceBlock(e->as<InterfaceBlock>());
2451             wroteInterfaceBlock = true;
2452         }
2453     }
2454     if (!wroteInterfaceBlock && fProgram.fInputs.fUseFlipRTUniform) {
2455         this->writeLine("struct sksl_synthetic_uniforms {");
2456         this->writeLine("    float2 " SKSL_RTFLIP_NAME ";");
2457         this->writeLine("};");
2458     }
2459 }
2460
2461 void MetalCodeGenerator::writeStructDefinitions() {
2462     for (const ProgramElement* e : fProgram.elements()) {
2463         if (e->is<StructDefinition>()) {
2464             this->writeStructDefinition(e->as<StructDefinition>());
2465         }
2466     }
2467 }
2468
2469 void MetalCodeGenerator::visitGlobalStruct(GlobalStructVisitor* visitor) {
2470     // Visit the interface blocks.
2471     for (const auto& [interfaceType, interfaceName] : fInterfaceBlockNameMap) {
2472         visitor->visitInterfaceBlock(*interfaceType, interfaceName);
2473     }
2474     for (const ProgramElement* element : fProgram.elements()) {
2475         if (!element->is<GlobalVarDeclaration>()) {
2476             continue;
2477         }
2478         const GlobalVarDeclaration& global = element->as<GlobalVarDeclaration>();
2479         const VarDeclaration& decl = global.declaration()->as<VarDeclaration>();
2480         const Variable& var = decl.var();
2481         if (var.type().typeKind() == Type::TypeKind::kSampler) {
2482             // Samplers are represented as a "texture/sampler" duo in the global struct.
2483             visitor->visitTexture(var.type(), var.name());
2484             visitor->visitSampler(var.type(), std::string(var.name()) + SAMPLER_SUFFIX);
2485             continue;
2486         }
2487
2488         if (!(var.modifiers().fFlags & ~Modifiers::kConst_Flag) &&
2489             -1 == var.modifiers().fLayout.fBuiltin) {
2490             // Visit a regular variable.
2491             visitor->visitVariable(var, decl.value().get());
2492         }
2493     }
2494 }
2495
2496 void MetalCodeGenerator::writeGlobalStruct() {
2497     class : public GlobalStructVisitor {
2498     public:
2499         void visitInterfaceBlock(const InterfaceBlock& block,
2500                                  std::string_view blockName) override {
2501             this->addElement();
2502             fCodeGen->write("    constant ");
2503             fCodeGen->write(block.typeName());
2504             fCodeGen->write("* ");
2505             fCodeGen->writeName(blockName);
2506             fCodeGen->write(";\n");
2507         }
2508         void visitTexture(const Type& type, std::string_view name) override {
2509             this->addElement();
2510             fCodeGen->write("    ");
2511             fCodeGen->writeType(type);
2512             fCodeGen->write(" ");
2513             fCodeGen->writeName(name);
2514             fCodeGen->write(";\n");
2515         }
2516         void visitSampler(const Type&, std::string_view name) override {
2517             this->addElement();
2518             fCodeGen->write("    sampler ");
2519             fCodeGen->writeName(name);
2520             fCodeGen->write(";\n");
2521         }
2522         void visitVariable(const Variable& var, const Expression* value) override {
2523             this->addElement();
2524             fCodeGen->write("    ");
2525             fCodeGen->writeModifiers(var.modifiers());
2526             fCodeGen->writeType(var.type());
2527             fCodeGen->write(" ");
2528             fCodeGen->writeName(var.name());
2529             fCodeGen->write(";\n");
2530         }
2531         void addElement() {
2532             if (fFirst) {
2533                 fCodeGen->write("struct Globals {\n");
2534                 fFirst = false;
2535             }
2536         }
2537         void finish() {
2538             if (!fFirst) {
2539                 fCodeGen->writeLine("};");
2540                 fFirst = true;
2541             }
2542         }
2543
2544         MetalCodeGenerator* fCodeGen = nullptr;
2545         bool fFirst = true;
2546     } visitor;
2547
2548     visitor.fCodeGen = this;
2549     this->visitGlobalStruct(&visitor);
2550     visitor.finish();
2551 }
2552
2553 void MetalCodeGenerator::writeGlobalInit() {
2554     class : public GlobalStructVisitor {
2555     public:
2556         void visitInterfaceBlock(const InterfaceBlock& blockType,
2557                                  std::string_view blockName) override {
2558             this->addElement();
2559             fCodeGen->write("&");
2560             fCodeGen->writeName(blockName);
2561         }
2562         void visitTexture(const Type&, std::string_view name) override {
2563             this->addElement();
2564             fCodeGen->writeName(name);
2565         }
2566         void visitSampler(const Type&, std::string_view name) override {
2567             this->addElement();
2568             fCodeGen->writeName(name);
2569         }
2570         void visitVariable(const Variable& var, const Expression* value) override {
2571             this->addElement();
2572             if (value) {
2573                 fCodeGen->writeVarInitializer(var, *value);
2574             } else {
2575                 fCodeGen->write("{}");
2576             }
2577         }
2578         void addElement() {
2579             if (fFirst) {
2580                 fCodeGen->write("    Globals _globals{");
2581                 fFirst = false;
2582             } else {
2583                 fCodeGen->write(", ");
2584             }
2585         }
2586         void finish() {
2587             if (!fFirst) {
2588                 fCodeGen->writeLine("};");
2589                 fCodeGen->writeLine("    (void)_globals;");
2590             }
2591         }
2592         MetalCodeGenerator* fCodeGen = nullptr;
2593         bool fFirst = true;
2594     } visitor;
2595
2596     visitor.fCodeGen = this;
2597     this->visitGlobalStruct(&visitor);
2598     visitor.finish();
2599 }
2600
2601 void MetalCodeGenerator::writeProgramElement(const ProgramElement& e) {
2602     switch (e.kind()) {
2603         case ProgramElement::Kind::kExtension:
2604             break;
2605         case ProgramElement::Kind::kGlobalVar:
2606             break;
2607         case ProgramElement::Kind::kInterfaceBlock:
2608             // handled in writeInterfaceBlocks, do nothing
2609             break;
2610         case ProgramElement::Kind::kStructDefinition:
2611             // Handled in writeStructDefinitions. Do nothing.
2612             break;
2613         case ProgramElement::Kind::kFunction:
2614             this->writeFunction(e.as<FunctionDefinition>());
2615             break;
2616         case ProgramElement::Kind::kFunctionPrototype:
2617             this->writeFunctionPrototype(e.as<FunctionPrototype>());
2618             break;
2619         case ProgramElement::Kind::kModifiers:
2620             this->writeModifiers(e.as<ModifiersDeclaration>().modifiers());
2621             this->writeLine(";");
2622             break;
2623         default:
2624             SkDEBUGFAILF("unsupported program element: %s\n", e.description().c_str());
2625             break;
2626     }
2627 }
2628
2629 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Expression* e) {
2630     if (!e) {
2631         return kNo_Requirements;
2632     }
2633     switch (e->kind()) {
2634         case Expression::Kind::kFunctionCall: {
2635             const FunctionCall& f = e->as<FunctionCall>();
2636             Requirements result = this->requirements(f.function());
2637             for (const auto& arg : f.arguments()) {
2638                 result |= this->requirements(arg.get());
2639             }
2640             return result;
2641         }
2642         case Expression::Kind::kConstructorCompound:
2643         case Expression::Kind::kConstructorCompoundCast:
2644         case Expression::Kind::kConstructorArray:
2645         case Expression::Kind::kConstructorArrayCast:
2646         case Expression::Kind::kConstructorDiagonalMatrix:
2647         case Expression::Kind::kConstructorScalarCast:
2648         case Expression::Kind::kConstructorSplat:
2649         case Expression::Kind::kConstructorStruct: {
2650             const AnyConstructor& c = e->asAnyConstructor();
2651             Requirements result = kNo_Requirements;
2652             for (const auto& arg : c.argumentSpan()) {
2653                 result |= this->requirements(arg.get());
2654             }
2655             return result;
2656         }
2657         case Expression::Kind::kFieldAccess: {
2658             const FieldAccess& f = e->as<FieldAccess>();
2659             if (FieldAccess::OwnerKind::kAnonymousInterfaceBlock == f.ownerKind()) {
2660                 return kGlobals_Requirement;
2661             }
2662             return this->requirements(f.base().get());
2663         }
2664         case Expression::Kind::kSwizzle:
2665             return this->requirements(e->as<Swizzle>().base().get());
2666         case Expression::Kind::kBinary: {
2667             const BinaryExpression& bin = e->as<BinaryExpression>();
2668             return this->requirements(bin.left().get()) |
2669                    this->requirements(bin.right().get());
2670         }
2671         case Expression::Kind::kIndex: {
2672             const IndexExpression& idx = e->as<IndexExpression>();
2673             return this->requirements(idx.base().get()) | this->requirements(idx.index().get());
2674         }
2675         case Expression::Kind::kPrefix:
2676             return this->requirements(e->as<PrefixExpression>().operand().get());
2677         case Expression::Kind::kPostfix:
2678             return this->requirements(e->as<PostfixExpression>().operand().get());
2679         case Expression::Kind::kTernary: {
2680             const TernaryExpression& t = e->as<TernaryExpression>();
2681             return this->requirements(t.test().get()) | this->requirements(t.ifTrue().get()) |
2682                    this->requirements(t.ifFalse().get());
2683         }
2684         case Expression::Kind::kVariableReference: {
2685             const VariableReference& v = e->as<VariableReference>();
2686             const Modifiers& modifiers = v.variable()->modifiers();
2687             Requirements result = kNo_Requirements;
2688             if (modifiers.fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN) {
2689                 result = kGlobals_Requirement | kFragCoord_Requirement;
2690             } else if (Variable::Storage::kGlobal == v.variable()->storage()) {
2691                 if (modifiers.fFlags & Modifiers::kIn_Flag) {
2692                     result = kInputs_Requirement;
2693                 } else if (modifiers.fFlags & Modifiers::kOut_Flag) {
2694                     result = kOutputs_Requirement;
2695                 } else if (modifiers.fFlags & Modifiers::kUniform_Flag &&
2696                            v.variable()->type().typeKind() != Type::TypeKind::kSampler) {
2697                     result = kUniforms_Requirement;
2698                 } else {
2699                     result = kGlobals_Requirement;
2700                 }
2701             }
2702             return result;
2703         }
2704         default:
2705             return kNo_Requirements;
2706     }
2707 }
2708
2709 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Statement* s) {
2710     if (!s) {
2711         return kNo_Requirements;
2712     }
2713     switch (s->kind()) {
2714         case Statement::Kind::kBlock: {
2715             Requirements result = kNo_Requirements;
2716             for (const std::unique_ptr<Statement>& child : s->as<Block>().children()) {
2717                 result |= this->requirements(child.get());
2718             }
2719             return result;
2720         }
2721         case Statement::Kind::kVarDeclaration: {
2722             const VarDeclaration& var = s->as<VarDeclaration>();
2723             return this->requirements(var.value().get());
2724         }
2725         case Statement::Kind::kExpression:
2726             return this->requirements(s->as<ExpressionStatement>().expression().get());
2727         case Statement::Kind::kReturn: {
2728             const ReturnStatement& r = s->as<ReturnStatement>();
2729             return this->requirements(r.expression().get());
2730         }
2731         case Statement::Kind::kIf: {
2732             const IfStatement& i = s->as<IfStatement>();
2733             return this->requirements(i.test().get()) |
2734                    this->requirements(i.ifTrue().get()) |
2735                    this->requirements(i.ifFalse().get());
2736         }
2737         case Statement::Kind::kFor: {
2738             const ForStatement& f = s->as<ForStatement>();
2739             return this->requirements(f.initializer().get()) |
2740                    this->requirements(f.test().get()) |
2741                    this->requirements(f.next().get()) |
2742                    this->requirements(f.statement().get());
2743         }
2744         case Statement::Kind::kDo: {
2745             const DoStatement& d = s->as<DoStatement>();
2746             return this->requirements(d.test().get()) |
2747                    this->requirements(d.statement().get());
2748         }
2749         case Statement::Kind::kSwitch: {
2750             const SwitchStatement& sw = s->as<SwitchStatement>();
2751             Requirements result = this->requirements(sw.value().get());
2752             for (const std::unique_ptr<Statement>& sc : sw.cases()) {
2753                 result |= this->requirements(sc->as<SwitchCase>().statement().get());
2754             }
2755             return result;
2756         }
2757         default:
2758             return kNo_Requirements;
2759     }
2760 }
2761
2762 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const FunctionDeclaration& f) {
2763     Requirements* found = fRequirements.find(&f);
2764     if (!found) {
2765         fRequirements.set(&f, kNo_Requirements);
2766         for (const ProgramElement* e : fProgram.elements()) {
2767             if (e->is<FunctionDefinition>()) {
2768                 const FunctionDefinition& def = e->as<FunctionDefinition>();
2769                 if (&def.declaration() == &f) {
2770                     Requirements reqs = this->requirements(def.body().get());
2771                     fRequirements.set(&f, reqs);
2772                     return reqs;
2773                 }
2774             }
2775         }
2776         // We never found a definition for this declared function, but it's legal to prototype a
2777         // function without ever giving a definition, as long as you don't call it.
2778         return kNo_Requirements;
2779     }
2780     return *found;
2781 }
2782
2783 bool MetalCodeGenerator::generateCode() {
2784     StringStream header;
2785     {
2786         AutoOutputStream outputToHeader(this, &header, &fIndentation);
2787         this->writeHeader();
2788         this->writeStructDefinitions();
2789         this->writeUniformStruct();
2790         this->writeInputStruct();
2791         this->writeOutputStruct();
2792         this->writeInterfaceBlocks();
2793         this->writeGlobalStruct();
2794     }
2795     StringStream body;
2796     {
2797         AutoOutputStream outputToBody(this, &body, &fIndentation);
2798         for (const ProgramElement* e : fProgram.elements()) {
2799             this->writeProgramElement(*e);
2800         }
2801     }
2802     write_stringstream(header, *fOut);
2803     write_stringstream(fExtraFunctionPrototypes, *fOut);
2804     write_stringstream(fExtraFunctions, *fOut);
2805     write_stringstream(body, *fOut);
2806     fContext.fErrors->reportPendingErrors(Position());
2807     return fContext.fErrors->errorCount() == 0;
2808 }
2809
2810 }  // namespace SkSL