2 * Copyright 2016 Google Inc.
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
8 #include "src/sksl/codegen/SkSLMetalCodeGenerator.h"
10 #include "include/core/SkSpan.h"
11 #include "include/core/SkTypes.h"
12 #include "include/private/SkSLLayout.h"
13 #include "include/private/SkSLModifiers.h"
14 #include "include/private/SkSLProgramElement.h"
15 #include "include/private/SkSLStatement.h"
16 #include "include/private/SkSLString.h"
17 #include "include/sksl/SkSLErrorReporter.h"
18 #include "include/sksl/SkSLPosition.h"
19 #include "src/core/SkScopeExit.h"
20 #include "src/sksl/SkSLAnalysis.h"
21 #include "src/sksl/SkSLBuiltinTypes.h"
22 #include "src/sksl/SkSLCompiler.h"
23 #include "src/sksl/SkSLContext.h"
24 #include "src/sksl/SkSLMemoryLayout.h"
25 #include "src/sksl/SkSLOutputStream.h"
26 #include "src/sksl/SkSLProgramSettings.h"
27 #include "src/sksl/SkSLUtil.h"
28 #include "src/sksl/ir/SkSLBinaryExpression.h"
29 #include "src/sksl/ir/SkSLBlock.h"
30 #include "src/sksl/ir/SkSLConstructor.h"
31 #include "src/sksl/ir/SkSLConstructorArrayCast.h"
32 #include "src/sksl/ir/SkSLConstructorCompound.h"
33 #include "src/sksl/ir/SkSLConstructorMatrixResize.h"
34 #include "src/sksl/ir/SkSLDoStatement.h"
35 #include "src/sksl/ir/SkSLExpression.h"
36 #include "src/sksl/ir/SkSLExpressionStatement.h"
37 #include "src/sksl/ir/SkSLExtension.h"
38 #include "src/sksl/ir/SkSLFieldAccess.h"
39 #include "src/sksl/ir/SkSLForStatement.h"
40 #include "src/sksl/ir/SkSLFunctionCall.h"
41 #include "src/sksl/ir/SkSLFunctionDeclaration.h"
42 #include "src/sksl/ir/SkSLFunctionDefinition.h"
43 #include "src/sksl/ir/SkSLFunctionPrototype.h"
44 #include "src/sksl/ir/SkSLIfStatement.h"
45 #include "src/sksl/ir/SkSLIndexExpression.h"
46 #include "src/sksl/ir/SkSLInterfaceBlock.h"
47 #include "src/sksl/ir/SkSLLiteral.h"
48 #include "src/sksl/ir/SkSLModifiersDeclaration.h"
49 #include "src/sksl/ir/SkSLNop.h"
50 #include "src/sksl/ir/SkSLPostfixExpression.h"
51 #include "src/sksl/ir/SkSLPrefixExpression.h"
52 #include "src/sksl/ir/SkSLProgram.h"
53 #include "src/sksl/ir/SkSLReturnStatement.h"
54 #include "src/sksl/ir/SkSLSetting.h"
55 #include "src/sksl/ir/SkSLStructDefinition.h"
56 #include "src/sksl/ir/SkSLSwitchCase.h"
57 #include "src/sksl/ir/SkSLSwitchStatement.h"
58 #include "src/sksl/ir/SkSLSwizzle.h"
59 #include "src/sksl/ir/SkSLTernaryExpression.h"
60 #include "src/sksl/ir/SkSLVarDeclarations.h"
61 #include "src/sksl/ir/SkSLVariable.h"
62 #include "src/sksl/ir/SkSLVariableReference.h"
63 #include "src/sksl/spirv.h"
69 #include <type_traits>
// Returns the Metal spelling of binary operator `op`. Logical XOR (`^^`) is emitted as " != ",
// which is equivalent for boolean operands; every other operator uses op.operatorName().
73 static const char* operator_name(Operator op) {
75 case Operator::Kind::LOGICALXOR: return " != ";
76 default: return op.operatorName();
// Visitor interface over the members of the program's global struct. Subclasses implement one
// callback each for interface blocks, textures, samplers, and global variables.
80 class MetalCodeGenerator::GlobalStructVisitor {
82 virtual ~GlobalStructVisitor() = default;
// Called for an interface block; `blockName` is the name the block is referred to by.
83 virtual void visitInterfaceBlock(const InterfaceBlock& block, std::string_view blockName) = 0;
// Called for a texture named `name` of type `type`.
84 virtual void visitTexture(const Type& type, std::string_view name) = 0;
// Called for a sampler named `name` of type `type`.
85 virtual void visitSampler(const Type& type, std::string_view name) = 0;
// Called for a global variable. `value` is passed by pointer; presumably it is null when the
// variable has no initial value — TODO confirm against the callers.
86 virtual void visitVariable(const Variable& var, const Expression* value) = 0;
// Writes `s` to the current output stream fOut. The loop below runs once per fIndentation
// level; presumably it emits one indentation unit per level before the text — TODO confirm.
89 void MetalCodeGenerator::write(std::string_view s) {
94 for (int i = 0; i < fIndentation; i++) {
// A string_view is not guaranteed to be NUL-terminated, so copy it into a std::string to get
// a C string for writeText().
98 fOut->writeText(std::string(s).c_str());
// Writes one line of output: `s` (presumably via write()) followed by the configured line
// terminator fLineEnding.
102 void MetalCodeGenerator::writeLine(std::string_view s) {
104 fOut->writeText(fLineEnding);
// Presumably terminates the current output line if any text has been written to it — TODO
// confirm against the function body.
108 void MetalCodeGenerator::finishLine() {
// Emits a `#extension <name> : enable` directive line for `ext`.
114 void MetalCodeGenerator::writeExtension(const Extension& ext) {
115 this->writeLine("#extension " + std::string(ext.name()) + " : enable");
// Returns the Metal spelling of `type`:
//   arrays   -> `array<T, N>` (N must be positive; asserted)
//   vectors  -> component type name followed by the column count, e.g. `float3`
//   matrices -> component type name followed by `<columns>x<rows>`, e.g. `float2x3`
//   samplers -> `texture2d<half>` (other texture types are not supported yet)
//   other    -> the type's own name
118 std::string MetalCodeGenerator::typeName(const Type& type) {
119 switch (type.typeKind()) {
120 case Type::TypeKind::kArray:
121 SkASSERTF(type.columns() > 0, "invalid array size: %s", type.description().c_str());
// Array types report their element count through columns().
122 return String::printf("array<%s, %d>",
123 this->typeName(type.componentType()).c_str(), type.columns());
125 case Type::TypeKind::kVector:
126 return this->typeName(type.componentType()) + std::to_string(type.columns());
128 case Type::TypeKind::kMatrix:
129 return this->typeName(type.componentType()) + std::to_string(type.columns()) + "x" +
130 std::to_string(type.rows());
132 case Type::TypeKind::kSampler:
133 return "texture2d<half>"; // FIXME - support other texture types
136 return std::string(type.name());
// Emits a Metal `struct` declaration for the type defined by `s`. The struct is named after the
// type's display name, and its members are written by writeFields().
140 void MetalCodeGenerator::writeStructDefinition(const StructDefinition& s) {
141 const Type& type = s.type();
142 this->writeLine("struct " + type.displayName() + " {");
144 this->writeFields(type.fields(), type.fPosition);
// Metal, like C++, requires a semicolon after a struct definition.
146 this->writeLine("};");
// Writes the Metal spelling of `type`, as produced by typeName().
149 void MetalCodeGenerator::writeType(const Type& type) {
150 this->write(this->typeName(type));
// Writes `expr` as Metal source by dispatching on its expression kind. `parentPrecedence` is
// the precedence of the enclosing expression. It is forwarded to the writers that accept it so
// they can decide whether to parenthesize. Unsupported kinds trigger a debug-build failure.
153 void MetalCodeGenerator::writeExpression(const Expression& expr, Precedence parentPrecedence) {
154 switch (expr.kind()) {
155 case Expression::Kind::kBinary:
156 this->writeBinaryExpression(expr.as<BinaryExpression>(), parentPrecedence);
// Array and struct constructors are written with `{` / `}` delimiters.
158 case Expression::Kind::kConstructorArray:
159 case Expression::Kind::kConstructorStruct:
160 this->writeAnyConstructor(expr.asAnyConstructor(), "{", "}", parentPrecedence);
162 case Expression::Kind::kConstructorArrayCast:
163 this->writeConstructorArrayCast(expr.as<ConstructorArrayCast>(), parentPrecedence);
165 case Expression::Kind::kConstructorCompound:
166 this->writeConstructorCompound(expr.as<ConstructorCompound>(), parentPrecedence);
// Diagonal-matrix and splat constructors are written with `(` / `)` delimiters.
168 case Expression::Kind::kConstructorDiagonalMatrix:
169 case Expression::Kind::kConstructorSplat:
170 this->writeAnyConstructor(expr.asAnyConstructor(), "(", ")", parentPrecedence);
172 case Expression::Kind::kConstructorMatrixResize:
173 this->writeConstructorMatrixResize(expr.as<ConstructorMatrixResize>(),
176 case Expression::Kind::kConstructorScalarCast:
177 case Expression::Kind::kConstructorCompoundCast:
178 this->writeCastConstructor(expr.asAnyConstructor(), "(", ")", parentPrecedence);
180 case Expression::Kind::kFieldAccess:
181 this->writeFieldAccess(expr.as<FieldAccess>());
183 case Expression::Kind::kLiteral:
184 this->writeLiteral(expr.as<Literal>());
186 case Expression::Kind::kFunctionCall:
187 this->writeFunctionCall(expr.as<FunctionCall>());
189 case Expression::Kind::kPrefix:
190 this->writePrefixExpression(expr.as<PrefixExpression>(), parentPrecedence);
192 case Expression::Kind::kPostfix:
193 this->writePostfixExpression(expr.as<PostfixExpression>(), parentPrecedence);
195 case Expression::Kind::kSetting:
196 this->writeSetting(expr.as<Setting>());
198 case Expression::Kind::kSwizzle:
199 this->writeSwizzle(expr.as<Swizzle>());
201 case Expression::Kind::kVariableReference:
202 this->writeVariableReference(expr.as<VariableReference>());
204 case Expression::Kind::kTernary:
205 this->writeTernaryExpression(expr.as<TernaryExpression>(), parentPrecedence);
207 case Expression::Kind::kIndex:
208 this->writeIndexExpression(expr.as<IndexExpression>());
211 SkDEBUGFAILF("unsupported expression: %s", expr.description().c_str());
// Synthesizes a helper function, written into fExtraFunctions, that emulates GLSL out-parameter
// semantics for `call`. Returns the helper's name, which the caller writes in place of the
// callee's name. The helper does three things:
//   - declares a local `_varN` for each argument that is not an out-param variable (initialized
//     from the argument expression when the parameter has the `in` flag);
//   - calls the original function;
//   - writes each such `_varN` back through the original argument expression.
// `outVars[i]` is the variable assigned through argument `i`, or null when argument `i` does not
// bind to an out-parameter.
216 std::string MetalCodeGenerator::getOutParamHelper(const FunctionCall& call,
217 const ExpressionArray& arguments,
218 const SkTArray<VariableReference*>& outVars) {
// Redirects output to fExtraFunctions while this guard is alive (presumably restored on scope
// exit — TODO confirm AutoOutputStream semantics).
219 AutoOutputStream outputToExtraFunctions(this, &fExtraFunctions, &fIndentation);
220 const FunctionDeclaration& function = call.function();
// fSwizzleHelperCount makes each synthesized helper name unique.
222 std::string name = "_skOutParamHelper" + std::to_string(fSwizzleHelperCount++) +
223 "_" + function.mangledName();
224 const char* separator = "";
226 // Emit a prototype for the function we'll be calling through to in our helper.
227 if (!function.isBuiltin()) {
228 this->writeFunctionDeclaration(function);
229 this->writeLine(";");
232 // Synthesize a helper function that takes the same inputs as `function`, except in places where
233 // `outVars` is non-null; in those places, we take the type of the VariableReference.
235 // float _skOutParamHelper0_originalFuncName(float _var0, float _var1, float& outParam) {
236 this->writeType(call.type());
240 this->writeFunctionRequirementParams(function, separator);
242 SkASSERT(outVars.size() == arguments.size());
243 SkASSERT(outVars.size() == function.parameters().size());
245 // We need to detect cases where the caller passes the same variable as an out-param more than
246 // once, and avoid reusing the variable name. (In those cases we can actually just ignore the
247 // redundant input parameter entirely, and not give it any name.)
248 SkTHashSet<const Variable*> writtenVars;
250 for (int index = 0; index < arguments.count(); ++index) {
251 this->write(separator);
254 const Variable* param = function.parameters()[index];
255 this->writeModifiers(param->modifiers());
// Out-param slots take the type of the assigned variable; other slots take the argument's type.
257 const Type* type = outVars[index] ? &outVars[index]->type() : &arguments[index]->type();
258 this->writeType(*type);
260 if (param->modifiers().fFlags & Modifiers::kOut_Flag) {
263 if (outVars[index]) {
264 const Variable* var = outVars[index]->variable();
265 if (!writtenVars.contains(var)) {
266 writtenVars.add(var);
// fIgnoreVariableReferenceModifiers is set only for the duration of this write. Presumably it
// makes the reference render as the bare variable name — TODO confirm.
269 fIgnoreVariableReferenceModifiers = true;
270 this->writeVariableReference(*outVars[index]);
271 fIgnoreVariableReferenceModifiers = false;
274 this->write(" _var");
275 this->write(std::to_string(index));
278 this->writeLine(") {");
281 for (int index = 0; index < outVars.count(); ++index) {
282 if (!outVars[index]) {
285 // float3 _var2[ = outParam.zyx];
286 this->writeType(arguments[index]->type());
287 this->write(" _var");
288 this->write(std::to_string(index));
290 const Variable* param = function.parameters()[index];
291 if (param->modifiers().fFlags & Modifiers::kIn_Flag) {
293 fIgnoreVariableReferenceModifiers = true;
294 this->writeExpression(*arguments[index], Precedence::kAssignment);
295 fIgnoreVariableReferenceModifiers = false;
298 this->writeLine(";");
301 // [int _skResult = ] myFunction(inputs, outputs, _globals, _var0, _var1, _var2, _var3);
// A void callee has no result to capture in _skResult.
302 bool hasResult = (call.type().name() != "void");
304 this->writeType(call.type());
305 this->write(" _skResult = ");
308 this->writeName(function.mangledName());
311 this->writeFunctionRequirementArgs(function, separator);
313 for (int index = 0; index < arguments.count(); ++index) {
314 this->write(separator);
318 this->write(std::to_string(index));
320 this->writeLine(");");
322 for (int index = 0; index < outVars.count(); ++index) {
323 if (!outVars[index]) {
326 // outParam.zyx = _var2;
327 fIgnoreVariableReferenceModifiers = true;
328 this->writeExpression(*arguments[index], Precedence::kAssignment);
329 fIgnoreVariableReferenceModifiers = false;
330 this->write(" = _var");
331 this->write(std::to_string(index));
332 this->writeLine(";");
336 this->writeLine("return _skResult;");
340 this->writeLine("}");
// Returns the Metal bit-reinterpretation function for `outType`, e.g. `as_type<float>`.
345 std::string MetalCodeGenerator::getBitcastIntrinsic(const Type& outType) {
346 return "as_type<" + outType.displayName() + ">";
// Writes the function call `c`.
// - Intrinsics are first offered to writeIntrinsicCall(), which may emit a Metal-specific rewrite.
// - If any argument binds to an `out` parameter, the call is routed through a helper synthesized
//   by getOutParamHelper(). This preserves GLSL copy-out semantics and allows swizzled
//   out-params.
// - Otherwise the callee's mangled name is called directly.
// Function-requirement arguments are written before the user arguments.
349 void MetalCodeGenerator::writeFunctionCall(const FunctionCall& c) {
350 const FunctionDeclaration& function = c.function();
352 // Many intrinsics need to be rewritten in Metal.
353 if (function.isIntrinsic()) {
354 if (this->writeIntrinsicCall(c, function.intrinsicKind())) {
359 // Determine whether or not we need to emulate GLSL's out-param semantics for Metal using a
360 // helper function. (Specifically, out-parameters in GLSL are only written back to the original
361 // variable at the end of the function call; also, swizzles are supported, whereas Metal doesn't
362 // allow a swizzle to be passed to a `floatN&`.)
363 const ExpressionArray& arguments = c.arguments();
364 const std::vector<const Variable*>& parameters = function.parameters();
365 SkASSERT(arguments.size() == parameters.size());
367 bool foundOutParam = false;
// One slot per argument; a slot stays null unless that argument binds to an out-parameter.
368 SkSTArray<16, VariableReference*> outVars;
369 outVars.push_back_n(arguments.count(), (VariableReference*)nullptr);
371 for (int index = 0; index < arguments.count(); ++index) {
372 // If this is an out parameter...
373 if (parameters[index]->modifiers().fFlags & Modifiers::kOut_Flag) {
374 // Find the expression's inner variable being written to.
375 Analysis::AssignmentInfo info;
376 // Assignability was verified at IRGeneration time, so this should always succeed.
377 SkAssertResult(Analysis::IsAssignable(*arguments[index], &info));
378 outVars[index] = info.fAssignedVar;
379 foundOutParam = true;
384 // Out parameters need to be written back to at the end of the function. To do this, we
385 // synthesize a helper function which evaluates the out-param expression into a temporary
386 // variable, calls the original function, then writes the temp var back into the out param
387 // using the original out-param expression. (This lets us support things like swizzles and
389 this->write(getOutParamHelper(c, arguments, outVars));
391 this->write(function.mangledName());
395 const char* separator = "";
396 this->writeFunctionRequirementArgs(function, separator);
397 for (int i = 0; i < arguments.count(); ++i) {
398 this->write(separator);
// Presumably: the out-param helper receives the assigned variable itself, while other arguments
// are passed as written — TODO confirm the elided condition.
402 this->writeExpression(*outVars[i], Precedence::kSequence);
404 this->writeExpression(*arguments[i], Precedence::kSequence);
// MSL source for the 2x2 matrix-inverse polyfill (emitted by getInversePolyfill()). It returns
// the adjugate matrix scaled by 1/determinant(m).
410 static constexpr char kInverse2x2[] = R"(
411 template <typename T>
412 matrix<T, 2, 2> mat2_inverse(matrix<T, 2, 2> m) {
413 return matrix<T, 2, 2>(m[1][1], -m[0][1], -m[1][0], m[0][0]) * (1/determinant(m));
// MSL source for the 3x3 matrix-inverse polyfill (emitted by getInversePolyfill()).
// b01, b11 and b21 are the cofactors of m[0][0], m[0][1] and m[0][2]. `det` is the cofactor
// expansion along m[0], and the result is the adjugate scaled by 1/det.
417 static constexpr char kInverse3x3[] = R"(
418 template <typename T>
419 matrix<T, 3, 3> mat3_inverse(matrix<T, 3, 3> m) {
420 T a00 = m[0][0], a01 = m[0][1], a02 = m[0][2];
421 T a10 = m[1][0], a11 = m[1][1], a12 = m[1][2];
422 T a20 = m[2][0], a21 = m[2][1], a22 = m[2][2];
423 T b01 = a22*a11 - a12*a21;
424 T b11 = -a22*a10 + a12*a20;
425 T b21 = a21*a10 - a11*a20;
426 T det = a00*b01 + a01*b11 + a02*b21;
427 return matrix<T, 3, 3>(b01, (-a22*a01 + a02*a21), ( a12*a01 - a02*a11),
428 b11, ( a22*a00 - a02*a20), (-a12*a00 + a02*a10),
429 b21, (-a21*a00 + a01*a20), ( a11*a00 - a01*a10)) * (1/det);
// MSL source for the 4x4 matrix-inverse polyfill (emitted by getInversePolyfill()).
// b00..b05 are the 2x2 determinants formed from m[0] and m[1]; b06..b11 are formed from m[2]
// and m[3]. They are combined into `det`, and into the adjugate entries, which are scaled by
// 1/det.
433 static constexpr char kInverse4x4[] = R"(
434 template <typename T>
435 matrix<T, 4, 4> mat4_inverse(matrix<T, 4, 4> m) {
436 T a00 = m[0][0], a01 = m[0][1], a02 = m[0][2], a03 = m[0][3];
437 T a10 = m[1][0], a11 = m[1][1], a12 = m[1][2], a13 = m[1][3];
438 T a20 = m[2][0], a21 = m[2][1], a22 = m[2][2], a23 = m[2][3];
439 T a30 = m[3][0], a31 = m[3][1], a32 = m[3][2], a33 = m[3][3];
440 T b00 = a00*a11 - a01*a10;
441 T b01 = a00*a12 - a02*a10;
442 T b02 = a00*a13 - a03*a10;
443 T b03 = a01*a12 - a02*a11;
444 T b04 = a01*a13 - a03*a11;
445 T b05 = a02*a13 - a03*a12;
446 T b06 = a20*a31 - a21*a30;
447 T b07 = a20*a32 - a22*a30;
448 T b08 = a20*a33 - a23*a30;
449 T b09 = a21*a32 - a22*a31;
450 T b10 = a21*a33 - a23*a31;
451 T b11 = a22*a33 - a23*a32;
452 T det = b00*b11 - b01*b10 + b02*b09 + b03*b08 - b04*b07 + b05*b06;
453 return matrix<T, 4, 4>(a11*b11 - a12*b10 + a13*b09,
454 a02*b10 - a01*b11 - a03*b09,
455 a31*b05 - a32*b04 + a33*b03,
456 a22*b04 - a21*b05 - a23*b03,
457 a12*b08 - a10*b11 - a13*b07,
458 a00*b11 - a02*b08 + a03*b07,
459 a32*b02 - a30*b05 - a33*b01,
460 a20*b05 - a22*b02 + a23*b01,
461 a10*b10 - a11*b08 + a13*b06,
462 a01*b08 - a00*b10 - a03*b06,
463 a30*b04 - a31*b02 + a33*b00,
464 a21*b02 - a20*b04 - a23*b00,
465 a11*b07 - a10*b09 - a12*b06,
466 a00*b09 - a01*b07 + a02*b06,
467 a31*b01 - a30*b03 - a32*b00,
468 a20*b03 - a21*b01 + a22*b00) * (1/det);
// Returns the function name to write for an `inverse(...)` call with `arguments`.
// - For a single square-matrix argument, the name is `matN_inverse`. The first time a given
//   name is requested (tracked in fWrittenIntrinsics), the matching polyfill source
//   (kInverse2x2/3x3/4x4) is appended to fExtraFunctions.
// - Any other argument list does not use a polyfill.
472 std::string MetalCodeGenerator::getInversePolyfill(const ExpressionArray& arguments) {
473 // Only use polyfills for a function taking a single-argument square matrix.
474 if (arguments.size() == 1) {
475 const Type& type = arguments.front()->type();
476 if (type.isMatrix() && type.rows() == type.columns()) {
477 // Inject the correct polyfill based on the matrix size.
478 auto name = String::printf("mat%d_inverse", type.columns());
// `didInsert` is true only on the first request for this name; presumably it gates the switch
// below so each polyfill is emitted at most once.
479 auto [iter, didInsert] = fWrittenIntrinsics.insert(name);
// rows() == columns() here, so switching on rows() selects the same size used in `name`.
481 switch (type.rows()) {
483 fExtraFunctions.writeText(kInverse2x2);
486 fExtraFunctions.writeText(kInverse3x3);
489 fExtraFunctions.writeText(kInverse4x4);
496 // This isn't the built-in `inverse`. We don't want to polyfill it at all.
// Emits a templated MSL `matrixCompMult` helper into fExtraFunctions, at most once per generator
// (tracked in fWrittenIntrinsics). MSL has no built-in matrixCompMult, and calls to it are
// written verbatim by writeSimpleIntrinsic().
500 void MetalCodeGenerator::writeMatrixCompMult() {
501 static constexpr char kMatrixCompMult[] = R"(
502 template <typename T, int C, int R>
503 matrix<T, C, R> matrixCompMult(matrix<T, C, R> a, const matrix<T, C, R> b) {
504 for (int c = 0; c < C; ++c) {
511 std::string name = "matrixCompMult";
// Only emit the template the first time it is requested.
512 if (fWrittenIntrinsics.find(name) == fWrittenIntrinsics.end()) {
513 fWrittenIntrinsics.insert(name);
514 fExtraFunctions.writeText(kMatrixCompMult);
// Emits a templated MSL `outerProduct` helper into fExtraFunctions, at most once per generator
// (tracked in fWrittenIntrinsics). Column `c` of the result is `a * b[c]`, giving a C-column,
// R-row matrix from an R-vector `a` and a C-vector `b`.
518 void MetalCodeGenerator::writeOuterProduct() {
519 static constexpr char kOuterProduct[] = R"(
520 template <typename T, int C, int R>
521 matrix<T, C, R> outerProduct(const vec<T, R> a, const vec<T, C> b) {
522 matrix<T, C, R> result;
523 for (int c = 0; c < C; ++c) {
524 result[c] = a * b[c];
530 std::string name = "outerProduct";
// Only emit the template the first time it is requested.
531 if (fWrittenIntrinsics.find(name) == fWrittenIntrinsics.end()) {
532 fWrittenIntrinsics.insert(name);
533 fExtraFunctions.writeText(kOuterProduct);
// Declares a new uniquely-numbered local `_skTempN` of `type` and returns its name. The
// declaration is appended to fFunctionHeader. Intrinsic rewrites use these temporaries to
// evaluate an argument once while referencing it several times.
537 std::string MetalCodeGenerator::getTempVariable(const Type& type) {
538 std::string tempVar = "_skTemp" + std::to_string(fVarCount++);
539 this->fFunctionHeader += " " + this->typeName(type) + " " + tempVar + ";\n";
// Writes intrinsic call `c` unchanged: the function's SkSL name followed by its argument list.
543 void MetalCodeGenerator::writeSimpleIntrinsic(const FunctionCall& c) {
544 // Write out an intrinsic function call exactly as-is. No muss no fuss.
545 this->write(c.function().name());
546 this->writeArgumentList(c.arguments());
// Writes `arguments` as a separator-delimited call argument list (presumably wrapped in
// parentheses — TODO confirm). Each argument is written at sequence (comma) precedence, so
// any comma expression inside an argument is parenthesized.
549 void MetalCodeGenerator::writeArgumentList(const ExpressionArray& arguments) {
551 const char* separator = "";
552 for (const std::unique_ptr<Expression>& arg : arguments) {
553 this->write(separator);
555 this->writeExpression(*arg, Precedence::kSequence);
// Writes a Metal rewrite of intrinsic call `c`, whose intrinsic is `kind`. It handles
// intrinsics whose GLSL form is missing from Metal or has different semantics there. Returns a
// bool; presumably true means the call was fully written and false means writeFunctionCall()
// should fall back to its generic path — TODO confirm.
560 bool MetalCodeGenerator::writeIntrinsicCall(const FunctionCall& c, IntrinsicKind kind) {
561 const ExpressionArray& arguments = c.arguments();
// sample(tex, coord) becomes `tex.sample(tex<SAMPLER_SUFFIX>, coord)`. A 3-component coordinate
// is projected to `.xy / .z` through a temporary.
563 case k_sample_IntrinsicKind: {
564 this->writeExpression(*arguments[0], Precedence::kSequence);
565 this->write(".sample(");
566 this->writeExpression(*arguments[0], Precedence::kSequence);
567 this->write(SAMPLER_SUFFIX);
569 const Type& arg1Type = arguments[1]->type();
570 if (arg1Type.columns() == 3) {
571 // have to store the vector in a temp variable to avoid double evaluating it
572 std::string tmpVar = this->getTempVariable(arg1Type);
573 this->write("(" + tmpVar + " = ");
574 this->writeExpression(*arguments[1], Precedence::kSequence);
575 this->write(", " + tmpVar + ".xy / " + tmpVar + ".z))");
577 SkASSERT(arg1Type.columns() == 2);
578 this->writeExpression(*arguments[1], Precedence::kSequence);
// mod(x, y) is expanded to `x - y * floor(x / y)`, with x and y each evaluated once via temps.
583 case k_mod_IntrinsicKind: {
584 // fmod(x, y) in metal calculates x - y * trunc(x / y) instead of x - y * floor(x / y)
585 std::string tmpX = this->getTempVariable(arguments[0]->type());
586 std::string tmpY = this->getTempVariable(arguments[1]->type());
587 this->write("(" + tmpX + " = ");
588 this->writeExpression(*arguments[0], Precedence::kSequence);
589 this->write(", " + tmpY + " = ");
590 this->writeExpression(*arguments[1], Precedence::kSequence);
591 this->write(", " + tmpX + " - " + tmpY + " * floor(" + tmpX + " / " + tmpY + "))");
594 // GLSL declares scalar versions of most geometric intrinsics, but these don't exist in MSL
595 case k_distance_IntrinsicKind: {
596 if (arguments[0]->type().columns() == 1) {
598 this->writeExpression(*arguments[0], Precedence::kAdditive);
600 this->writeExpression(*arguments[1], Precedence::kAdditive);
603 this->writeSimpleIntrinsic(c);
607 case k_dot_IntrinsicKind: {
608 if (arguments[0]->type().columns() == 1) {
610 this->writeExpression(*arguments[0], Precedence::kMultiplicative);
612 this->writeExpression(*arguments[1], Precedence::kMultiplicative);
615 this->writeSimpleIntrinsic(c);
619 case k_faceforward_IntrinsicKind: {
620 if (arguments[0]->type().columns() == 1) {
621 // ((((Nref) * (I) < 0) ? 1 : -1) * (N))
623 this->writeExpression(*arguments[2], Precedence::kSequence);
624 this->write(") * (");
625 this->writeExpression(*arguments[1], Precedence::kSequence);
626 this->write(") < 0) ? 1 : -1) * (");
627 this->writeExpression(*arguments[0], Precedence::kSequence);
630 this->writeSimpleIntrinsic(c);
// For scalars, length(x) == abs(x) and normalize(x) == sign(x).
634 case k_length_IntrinsicKind: {
635 this->write(arguments[0]->type().columns() == 1 ? "abs(" : "length(");
636 this->writeExpression(*arguments[0], Precedence::kSequence);
640 case k_normalize_IntrinsicKind: {
641 this->write(arguments[0]->type().columns() == 1 ? "sign(" : "normalize(");
642 this->writeExpression(*arguments[0], Precedence::kSequence);
// The GLSL pack/unpack intrinsics map onto MSL's pack_float_to_* / unpack_*_to_float functions.
646 case k_packUnorm2x16_IntrinsicKind: {
647 this->write("pack_float_to_unorm2x16(");
648 this->writeExpression(*arguments[0], Precedence::kSequence);
652 case k_unpackUnorm2x16_IntrinsicKind: {
653 this->write("unpack_unorm2x16_to_float(");
654 this->writeExpression(*arguments[0], Precedence::kSequence);
658 case k_packSnorm2x16_IntrinsicKind: {
659 this->write("pack_float_to_snorm2x16(");
660 this->writeExpression(*arguments[0], Precedence::kSequence);
664 case k_unpackSnorm2x16_IntrinsicKind: {
665 this->write("unpack_snorm2x16_to_float(");
666 this->writeExpression(*arguments[0], Precedence::kSequence);
670 case k_packUnorm4x8_IntrinsicKind: {
671 this->write("pack_float_to_unorm4x8(");
672 this->writeExpression(*arguments[0], Precedence::kSequence);
676 case k_unpackUnorm4x8_IntrinsicKind: {
677 this->write("unpack_unorm4x8_to_float(");
678 this->writeExpression(*arguments[0], Precedence::kSequence);
682 case k_packSnorm4x8_IntrinsicKind: {
683 this->write("pack_float_to_snorm4x8(");
684 this->writeExpression(*arguments[0], Precedence::kSequence);
688 case k_unpackSnorm4x8_IntrinsicKind: {
689 this->write("unpack_snorm4x8_to_float(");
690 this->writeExpression(*arguments[0], Precedence::kSequence);
// packHalf2x16 converts to half2 and reinterprets the bits as a uint; unpackHalf2x16 reverses it.
694 case k_packHalf2x16_IntrinsicKind: {
695 this->write("as_type<uint>(half2(");
696 this->writeExpression(*arguments[0], Precedence::kSequence);
700 case k_unpackHalf2x16_IntrinsicKind: {
701 this->write("float2(as_type<half2>(");
702 this->writeExpression(*arguments[0], Precedence::kSequence);
// The float/int bit-reinterpretation intrinsics all become as_type<ResultType>(x).
706 case k_floatBitsToInt_IntrinsicKind:
707 case k_floatBitsToUint_IntrinsicKind:
708 case k_intBitsToFloat_IntrinsicKind:
709 case k_uintBitsToFloat_IntrinsicKind: {
710 this->write(this->getBitcastIntrinsic(c.type()));
712 this->writeExpression(*arguments[0], Precedence::kSequence);
// 57.2957795 is approximately 180/pi, and 0.0174532925 is approximately pi/180.
716 case k_degrees_IntrinsicKind: {
718 this->writeExpression(*arguments[0], Precedence::kSequence);
719 this->write(") * 57.2957795)");
722 case k_radians_IntrinsicKind: {
724 this->writeExpression(*arguments[0], Precedence::kSequence);
725 this->write(") * 0.0174532925)");
728 case k_dFdx_IntrinsicKind: {
730 this->writeArgumentList(c.arguments());
// When fRTFlipName is set, dfdy is multiplied by its `.y` component. Presumably this corrects
// the derivative's sign for a flipped render-target Y axis — TODO confirm.
733 case k_dFdy_IntrinsicKind: {
734 if (!fRTFlipName.empty()) {
735 this->write("(" + fRTFlipName + ".y * dfdy");
737 this->write("(dfdy");
739 this->writeArgumentList(c.arguments());
743 case k_inverse_IntrinsicKind: {
744 this->write(this->getInversePolyfill(arguments));
745 this->writeArgumentList(c.arguments());
748 case k_inversesqrt_IntrinsicKind: {
749 this->write("rsqrt");
750 this->writeArgumentList(c.arguments());
// The two-argument form of GLSL atan is MSL atan2.
753 case k_atan_IntrinsicKind: {
754 this->write(c.arguments().size() == 2 ? "atan2" : "atan");
755 this->writeArgumentList(c.arguments());
758 case k_reflect_IntrinsicKind: {
759 if (arguments[0]->type().columns() == 1) {
760 // We need to synthesize `I - 2 * N * I * N`.
761 std::string tmpI = this->getTempVariable(arguments[0]->type());
762 std::string tmpN = this->getTempVariable(arguments[1]->type());
765 this->write("(" + tmpI + " = ");
766 this->writeExpression(*arguments[0], Precedence::kSequence);
769 this->write(", " + tmpN + " = ");
770 this->writeExpression(*arguments[1], Precedence::kSequence);
772 // , _skTempI - 2 * _skTempN * _skTempI * _skTempN)
773 this->write(", " + tmpI + " - 2 * " + tmpN + " * " + tmpI + " * " + tmpN + ")");
775 this->writeSimpleIntrinsic(c);
779 case k_refract_IntrinsicKind: {
780 if (arguments[0]->type().columns() == 1) {
781 // Metal does implement refract for vectors; rather than reimplementing refract from
782 // scratch, we can replace the call with `refract(float2(I,0), float2(N,0), eta).x`.
783 this->write("(refract(float2(");
784 this->writeExpression(*arguments[0], Precedence::kSequence);
785 this->write(", 0), float2(");
786 this->writeExpression(*arguments[1], Precedence::kSequence);
787 this->write(", 0), ");
788 this->writeExpression(*arguments[2], Precedence::kSequence);
791 this->writeSimpleIntrinsic(c);
795 case k_roundEven_IntrinsicKind: {
797 this->writeArgumentList(c.arguments());
800 case k_bitCount_IntrinsicKind: {
801 this->write("popcount(");
802 this->writeExpression(*arguments[0], Precedence::kSequence);
806 case k_findLSB_IntrinsicKind: {
807 // Create a temp variable to store the expression, to avoid double-evaluating it.
808 std::string skTemp = this->getTempVariable(arguments[0]->type());
809 std::string exprType = this->typeName(arguments[0]->type());
811 // ctz returns numbits(type) on zero inputs; GLSL documents it as generating -1 instead.
812 // Use select to detect zero inputs and force a -1 result.
814 // (_skTemp1 = (.....), select(ctz(_skTemp1), int4(-1), _skTemp1 == int4(0)))
818 this->writeExpression(*arguments[0], Precedence::kSequence);
819 this->write("), select(ctz(");
822 this->write(exprType);
823 this->write("(-1), ");
826 this->write(exprType);
827 this->write("(0)))");
830 case k_findMSB_IntrinsicKind: {
831 // Create a temp variable to store the expression, to avoid double-evaluating it.
832 std::string skTemp1 = this->getTempVariable(arguments[0]->type());
833 std::string exprType = this->typeName(arguments[0]->type());
835 // GLSL findMSB is actually quite different from Metal's clz:
836 // - For signed negative numbers, it returns the first zero bit, not the first one bit!
837 // - For an empty input (0/~0 depending on sign), findMSB gives -1; clz is numbits(type)
839 // (_skTemp1 = (.....),
841 this->write(skTemp1);
843 this->writeExpression(*arguments[0], Precedence::kSequence);
846 // Signed input types might be negative; we need another helper variable to negate the
847 // input (since we can only find one bits, not zero bits).
849 if (arguments[0]->type().isSigned()) {
850 // ... _skTemp2 = (select(_skTemp1, ~_skTemp1, _skTemp1 < 0)),
851 skTemp2 = this->getTempVariable(arguments[0]->type());
852 this->write(skTemp2);
853 this->write(" = (select(");
854 this->write(skTemp1);
856 this->write(skTemp1);
858 this->write(skTemp1);
859 this->write(" < 0)), ");
// NOTE(review): GLSL findMSB returns the bit *index* of the most significant set bit,
// i.e. (bitWidth - 1) - clz(x). The expression below emits clz(x) directly, so findMSB(1)
// would yield 31 instead of 0 for 32-bit ints. Confirm whether an elided line adjusts for this;
// otherwise this is a correctness bug.
864 // ... select(int4(clz(_skTemp2)), int4(-1), _skTemp2 == int4(0)))
865 this->write("select(");
866 this->write(this->typeName(c.type()));
867 this->write("(clz(");
868 this->write(skTemp2);
870 this->write(this->typeName(c.type()));
871 this->write("(-1), ");
872 this->write(skTemp2);
874 this->write(exprType);
875 this->write("(0)))");
878 case k_sign_IntrinsicKind: {
879 if (arguments[0]->type().componentType().isInteger()) {
880 // Create a temp variable to store the expression, to avoid double-evaluating it.
881 std::string skTemp = this->getTempVariable(arguments[0]->type());
882 std::string exprType = this->typeName(arguments[0]->type());
884 // (_skTemp = (.....),
888 this->writeExpression(*arguments[0], Precedence::kSequence);
891 // ... select(select(int4(0), int4(-1), _skTemp < 0), int4(1), _skTemp > 0))
892 this->write("select(select(");
893 this->write(exprType);
894 this->write("(0), ");
895 this->write(exprType);
896 this->write("(-1), ");
898 this->write(" < 0), ");
899 this->write(exprType);
900 this->write("(1), ");
902 this->write(" > 0))");
904 this->writeSimpleIntrinsic(c);
// matrixCompMult and outerProduct are called by name after their MSL templates are emitted.
908 case k_matrixCompMult_IntrinsicKind: {
909 this->writeMatrixCompMult();
910 this->writeSimpleIntrinsic(c);
913 case k_outerProduct_IntrinsicKind: {
914 this->writeOuterProduct();
915 this->writeSimpleIntrinsic(c);
918 case k_mix_IntrinsicKind: {
919 SkASSERT(c.arguments().size() == 3);
920 if (arguments[2]->type().componentType().isBoolean()) {
921 // The Boolean forms of GLSL mix() use the select() intrinsic in Metal.
922 this->write("select");
923 this->writeArgumentList(c.arguments());
926 // The basic form of mix() is supported by Metal as-is.
927 this->writeSimpleIntrinsic(c);
// Comparison intrinsics are written as `arg0 <op> arg1`, with each operand at relational
// precedence (the operator itself is selected by the inner switch). An unexpected kind aborts.
930 case k_equal_IntrinsicKind:
931 case k_greaterThan_IntrinsicKind:
932 case k_greaterThanEqual_IntrinsicKind:
933 case k_lessThan_IntrinsicKind:
934 case k_lessThanEqual_IntrinsicKind:
935 case k_notEqual_IntrinsicKind: {
937 this->writeExpression(*c.arguments()[0], Precedence::kRelational);
939 case k_equal_IntrinsicKind:
942 case k_notEqual_IntrinsicKind:
945 case k_lessThan_IntrinsicKind:
948 case k_lessThanEqual_IntrinsicKind:
951 case k_greaterThan_IntrinsicKind:
954 case k_greaterThanEqual_IntrinsicKind:
958 SK_ABORT("unsupported comparison intrinsic kind");
960 this->writeExpression(*c.arguments()[1], Precedence::kRelational);
969 // Assembles a matrix of type floatCxR (C = `columns` columns of R = `rows` components each) by
// resizing another matrix named `x0`, writing the column vectors into fExtraFunctions.
970 // Cells that don't exist in the source matrix will be populated with identity-matrix values.
971 void MetalCodeGenerator::assembleMatrixFromMatrix(const Type& sourceMatrix, int rows, int columns) {
973 SkASSERT(columns <= 4);
975 std::string matrixType = this->typeName(sourceMatrix.componentType());
977 const char* separator = "";
978 for (int c = 0; c < columns; ++c) {
// Each column is written as a `<component><rows>(...)` vector constructor.
979 fExtraFunctions.printf("%s%s%d(", separator, matrixType.c_str(), rows);
982 // Determine how many values to take from the source matrix for this column.
983 int swizzleLength = 0;
984 if (c < sourceMatrix.columns()) {
985 swizzleLength = std::min<>(rows, sourceMatrix.rows());
988 // Emit all the values available from the corresponding source matrix column.
990 switch (swizzleLength) {
991 case 0: firstItem = true; break;
992 case 1: firstItem = false; fExtraFunctions.printf("x0[%d].x", c); break;
993 case 2: firstItem = false; fExtraFunctions.printf("x0[%d].xy", c); break;
994 case 3: firstItem = false; fExtraFunctions.printf("x0[%d].xyz", c); break;
995 case 4: firstItem = false; fExtraFunctions.printf("x0[%d].xyzw", c); break;
996 default: SkUNREACHABLE;
999 // Emit the placeholder identity-matrix cells.
1000 for (int r = swizzleLength; r < rows; ++r) {
1001 fExtraFunctions.printf("%s%s", firstItem ? "" : ", ", (r == c) ? "1.0" : "0.0");
1006 fExtraFunctions.writeText(")");
1009 // Assembles a matrix of type floatCxR by concatenating an arbitrary mix of values, named `x0`,
1010 // `x1`, etc. If the expression list doesn't contain exactly C*R scalars, a debug assertion
// fires and an `<error>` marker is written into the output.
1011 void MetalCodeGenerator::assembleMatrixFromExpressions(const AnyConstructor& ctor,
1012 int columns, int rows) {
1013 SkASSERT(rows <= 4);
1014 SkASSERT(columns <= 4);
1016 std::string matrixType = this->typeName(ctor.type().componentType());
// argIndex is the constructor argument being consumed; argPosition is the next scalar slot
// within that argument.
1017 size_t argIndex = 0;
1018 int argPosition = 0;
1019 auto args = ctor.argumentSpan();
1021 static constexpr char kSwizzle[] = "xyzw";
1022 const char* separator = "";
1023 for (int c = 0; c < columns; ++c) {
1024 fExtraFunctions.printf("%s%s%d(", separator, matrixType.c_str(), rows);
1027 const char* columnSeparator = "";
// `r` is advanced inside the loop body, because one argument may supply several components.
1028 for (int r = 0; r < rows;) {
1029 fExtraFunctions.writeText(columnSeparator);
1030 columnSeparator = ", ";
1032 if (argIndex < args.size()) {
1033 const Type& argType = args[argIndex]->type();
1034 switch (argType.typeKind()) {
1035 case Type::TypeKind::kScalar: {
1036 fExtraFunctions.printf("x%zu", argIndex);
1041 case Type::TypeKind::kVector: {
1042 fExtraFunctions.printf("x%zu.", argIndex);
1044 fExtraFunctions.write8(kSwizzle[argPosition]);
1047 } while (r < rows && argPosition < argType.columns());
// NOTE(review): argPosition counts across the whole matrix argument, so past its first column
// `kSwizzle[argPosition]` selects the wrong component (e.g. `.zw` for column 1 of a float2x2).
// At argPosition >= 4 it reads past "xyzw". This probably needs
// `kSwizzle[argPosition % argType.rows()]`. Confirm whether this path is reachable.
1050 case Type::TypeKind::kMatrix: {
1051 fExtraFunctions.printf("x%zu[%d].", argIndex, argPosition / argType.rows());
1053 fExtraFunctions.write8(kSwizzle[argPosition]);
1056 } while (r < rows && (argPosition % argType.rows()) != 0);
1060 SkDEBUGFAIL("incorrect type of argument for matrix constructor");
1061 fExtraFunctions.writeText("<error>");
// Once every scalar of the current argument has been consumed, presumably the generator moves
// on to the next argument.
1066 if (argPosition >= argType.columns() * argType.rows()) {
1071 SkDEBUGFAIL("not enough arguments for matrix constructor");
1072 fExtraFunctions.writeText("<error>");
// Every argument must be consumed exactly, with no partially-used argument left over.
1077 if (argPosition != 0 || argIndex != args.size()) {
1078 SkDEBUGFAIL("incorrect number of arguments for matrix constructor");
1079 fExtraFunctions.writeText(", <error>");
1082 fExtraFunctions.writeText(")");
1085 // Generates a constructor for 'matrix' which reorganizes the input arguments into the proper shape.
1086 // Keeps track of previously generated constructors so that we won't generate more than one
1087 // constructor for any given permutation of input argument types. Returns the name of the
1088 // generated constructor method.
1089 std::string MetalCodeGenerator::getMatrixConstructHelper(const AnyConstructor& c) {
1090 const Type& type = c.type();
1091 int columns = type.columns();
1092 int rows = type.rows();
1093 auto args = c.argumentSpan();
1094 std::string typeName = this->typeName(type);
1096 // Create the helper-method name and use it as our lookup key.
// The name is `<type>_from` followed by `_<argType>` for each argument, in order.
1097 std::string name = String::printf("%s_from", typeName.c_str());
1098 for (const std::unique_ptr<Expression>& expr : args) {
1099 String::appendf(&name, "_%s", this->typeName(expr->type()).c_str());
1102 // If a helper-method has not been synthesized yet, create it now.
1103 if (!fHelpers.contains(name)) {
1106 // Unlike GLSL, Metal requires that matrices are initialized with exactly C vectors of R
1107 // components apiece. (In Metal 2.0, you can also supply R*C scalars, but you still cannot
1108 // supply a mixture of scalars and vectors.)
1109 fExtraFunctions.printf("%s %s(", typeName.c_str(), name.c_str());
// Parameters are named x0, x1, ... matching the `x<N>` names that
// assembleMatrixFromExpressions emits.
1111 size_t argIndex = 0;
1112 const char* argSeparator = "";
1113 for (const std::unique_ptr<Expression>& expr : args) {
1114 fExtraFunctions.printf("%s%s x%zu", argSeparator,
1115 this->typeName(expr->type()).c_str(), argIndex++);
1116 argSeparator = ", ";
1119 fExtraFunctions.printf(") {\n return %s(", typeName.c_str());
// A lone matrix argument is reshaped with assembleMatrixFromMatrix; other argument lists
// are assembled scalar-by-scalar with assembleMatrixFromExpressions.
1121 if (args.size() == 1 && args.front()->type().isMatrix()) {
1122 this->assembleMatrixFromMatrix(args.front()->type(), rows, columns);
1124 this->assembleMatrixFromExpressions(c, columns, rows);
1127 fExtraFunctions.writeText(");\n}\n");
// Decides whether a matrix ConstructorCompound can be written inline (false) or must be routed
// through a helper produced by getMatrixConstructHelper (true).
1132 bool MetalCodeGenerator::matrixConstructHelperIsNeeded(const ConstructorCompound& c) {
1133 SkASSERT(c.type().isMatrix());
1135 // GLSL is fairly free-form about inputs to its matrix constructors, but Metal is not; it
1136 // expects exactly C vectors of R components apiece. (Metal 2.0 also allows a list of R*C
1137 // scalars.) Some cases are simple to translate and so we handle those inline--e.g. a list of
1138 // scalars can be constructed trivially. In more complex cases, we generate a helper function
1139 // that converts our inputs into a properly-shaped matrix.
1140 // A matrix construct helper method is always used if any input argument is a matrix.
1141 // Helper methods are also necessary when any argument would span multiple columns. For instance:
1143 // float2 x = (1, 2);
1144 // float3x2(x, 3, 4, 5, 6) = | 1 3 5 | = no helper needed; conversion can be done inline
1147 // float2 x = (2, 3);
1148 // float3x2(1, x, 4, 5, 6) = | 1 3 5 | = x spans multiple columns; a helper method will be used
1151 // float4 x = (1, 2, 3, 4);
1152 // float2x2(x) = | 1 3 | = x spans multiple columns; a helper method will be used
// `position` counts the scalars placed so far in the current column; a column holds rows()
// scalars.
1157 for (const std::unique_ptr<Expression>& expr : c.arguments()) {
1158 // If an input argument is a matrix, we need a helper function.
1159 if (expr->type().isMatrix()) {
1162 position += expr->type().columns();
1163 if (position > c.type().rows()) {
1164 // An input argument would span multiple columns; a helper function is required.
1167 if (position == c.type().rows()) {
1168 // We've filled an entire column. Wrap to the start of the next column.
// Writes a matrix-resize constructor as a call to a synthesized helper (whose name comes from
// getMatrixConstructHelper) applied to the single source-matrix argument.
1176 void MetalCodeGenerator::writeConstructorMatrixResize(const ConstructorMatrixResize& c,
1177 Precedence parentPrecedence) {
1178 // Matrix-resize via casting doesn't natively exist in Metal at all, so we always need to use a
1179 // matrix-construct helper here.
1180 this->write(this->getMatrixConstructHelper(c));
1182 this->writeExpression(*c.argument(), Precedence::kSequence);
1186 void MetalCodeGenerator::writeConstructorCompound(const ConstructorCompound& c,
1187 Precedence parentPrecedence) {
1188 if (c.type().isVector()) {
1189 this->writeConstructorCompoundVector(c, parentPrecedence);
1190 } else if (c.type().isMatrix()) {
1191 this->writeConstructorCompoundMatrix(c, parentPrecedence);
1193 fContext.fErrors->error(c.fPosition, "unsupported compound constructor");
// Writes an array cast through a helper named `array_of_<out>_from_<in>`. The helper is emitted
// into fExtraFunctions once per element-type pair. It is generic over the array length `N` and
// converts element by element (`result[i] = <out>(x[i])`).
1197 void MetalCodeGenerator::writeConstructorArrayCast(const ConstructorArrayCast& c,
1198 Precedence parentPrecedence) {
1199 const Type& inType = c.argument()->type().componentType();
1200 const Type& outType = c.type().componentType();
1201 std::string inTypeName = this->typeName(inType);
1202 std::string outTypeName = this->typeName(outType);
1204 std::string name = "array_of_" + outTypeName + "_from_" + inTypeName;
1205 if (!fHelpers.contains(name)) {
1207 fExtraFunctions.printf(R"(
1209 array<%s, N> %s(thread const array<%s, N>& x) {
1210 array<%s, N> result;
1211 for (int i = 0; i < N; ++i) {
1212 result[i] = %s(x[i]);
1217 outTypeName.c_str(), name.c_str(), inTypeName.c_str(),
1218 outTypeName.c_str(),
1219 outTypeName.c_str());
// The source array expression is written as the helper's operand.
1224 this->writeExpression(*c.argument(), Precedence::kSequence);
// Provides the name of a helper, `<base>4_from_<base>2x2`, that flattens a 2x2 matrix into a
// 4-component vector: column 0's xy followed by column 1's xy. The helper is written to
// fExtraFunctions the first time it is requested. Callers write the returned name followed by
// the matrix argument.
1228 std::string MetalCodeGenerator::getVectorFromMat2x2ConstructorHelper(const Type& matrixType) {
1229 SkASSERT(matrixType.isMatrix());
1230 SkASSERT(matrixType.rows() == 2);
1231 SkASSERT(matrixType.columns() == 2);
1233 std::string baseType = this->typeName(matrixType.componentType());
1234 std::string name = String::printf("%s4_from_%s2x2", baseType.c_str(), baseType.c_str());
1235 if (!fHelpers.contains(name)) {
1238 fExtraFunctions.printf(R"(
1240 return %s4(x[0].xy, x[1].xy);
1242 )", baseType.c_str(), name.c_str(), baseType.c_str(), baseType.c_str());
// Writes a vector ConstructorCompound. The single-matrix-argument form of a 4-vector goes
// through a flattening helper. Every other form is written as a plain `<type>(args...)`.
1248 void MetalCodeGenerator::writeConstructorCompoundVector(const ConstructorCompound& c,
1249 Precedence parentPrecedence) {
1250 SkASSERT(c.type().isVector());
1252 // Metal supports constructing vectors from a mix of scalars and vectors, but not matrices.
1253 // GLSL supports vec4(mat2x2), so we detect that case here and emit a helper function.
1254 if (c.type().columns() == 4 && c.argumentSpan().size() == 1) {
1255 const Expression& expr = *c.argumentSpan().front();
1256 if (expr.type().isMatrix()) {
1257 this->write(this->getVectorFromMat2x2ConstructorHelper(expr.type()));
1259 this->writeExpression(expr, Precedence::kSequence);
1265 this->writeAnyConstructor(c, "(", ")", parentPrecedence);
// Writes a matrix ConstructorCompound. It becomes either a call to a synthesized helper, or an
// inline Metal matrix constructor whose scalar/vector arguments are grouped into column vectors.
1268 void MetalCodeGenerator::writeConstructorCompoundMatrix(const ConstructorCompound& c,
1269 Precedence parentPrecedence) {
1270 SkASSERT(c.type().isMatrix());
1272 // Emit and invoke a matrix-constructor helper method if one is necessary.
1273 if (this->matrixConstructHelperIsNeeded(c)) {
1274 this->write(this->getMatrixConstructHelper(c));
1276 const char* separator = "";
1277 for (const std::unique_ptr<Expression>& expr : c.arguments()) {
1278 this->write(separator);
1280 this->writeExpression(*expr, Precedence::kSequence);
1286 // Metal doesn't allow creating matrices by passing in scalars and vectors in a jumble; it
1287 // requires your scalars to be grouped up into columns. Because `matrixConstructHelperIsNeeded`
1288 // returned false, we know that none of our scalars/vectors "wrap" across a column, so we
1289 // can group our inputs up and synthesize a constructor for each column.
1290 const Type& matrixType = c.type();
1291 const Type& columnType = matrixType.componentType().toCompound(
1292 fContext, /*columns=*/matrixType.rows(), /*rows=*/1);
1294 this->writeType(matrixType);
1296 const char* separator = "";
// Number of scalars placed so far in the column currently being built (a column holds
// matrixType.rows() scalars).
1297 int scalarCount = 0;
1298 for (const std::unique_ptr<Expression>& arg : c.arguments()) {
1299 this->write(separator);
1301 if (arg->type().columns() < matrixType.rows()) {
1302 // Write a `floatN(` constructor to group scalars and smaller vectors together.
1304 this->writeType(columnType);
1307 scalarCount += arg->type().columns();
1309 this->writeExpression(*arg, Precedence::kSequence);
1310 if (scalarCount && scalarCount == matrixType.rows()) {
1311 // Close our `floatN(...` constructor block from above.
// Generic constructor writer. Writes the constructed type, then `leftBracket`, then each
// argument at sequence precedence (preceded by the running separator), then `rightBracket`.
1319 void MetalCodeGenerator::writeAnyConstructor(const AnyConstructor& c,
1320 const char* leftBracket,
1321 const char* rightBracket,
1322 Precedence parentPrecedence) {
1323 this->writeType(c.type());
1324 this->write(leftBracket);
1325 const char* separator = "";
1326 for (const std::unique_ptr<Expression>& arg : c.argumentSpan()) {
1327 this->write(separator);
1329 this->writeExpression(*arg, Precedence::kSequence);
1331 this->write(rightBracket);
1334 void MetalCodeGenerator::writeCastConstructor(const AnyConstructor& c,
1335 const char* leftBracket,
1336 const char* rightBracket,
1337 Precedence parentPrecedence) {
1338 return this->writeAnyConstructor(c, leftBracket, rightBracket, parentPrecedence);
1341 void MetalCodeGenerator::writeFragCoord() {
1342 if (!fRTFlipName.empty()) {
1343 this->write("float4(_fragCoord.x, ");
1344 this->write(fRTFlipName.c_str());
1345 this->write(".x + ");
1346 this->write(fRTFlipName.c_str());
1347 this->write(".y * _fragCoord.y, 0.0, _fragCoord.w)");
1349 this->write("float4(_fragCoord.x, _fragCoord.y, 0.0, _fragCoord.w)");
// Writes a reference to a variable. Builtins are mapped to their Metal-side spelling. Other
// globals are prefixed with the entry-point struct that holds them (`_in.`, `_out.`,
// `_uniforms.` or `_globals.`).
1353 void MetalCodeGenerator::writeVariableReference(const VariableReference& ref) {
1354 // When assembling out-param helper functions, we copy variables into local clones with matching
1355 // names. We never want to prepend "_in." or "_globals." when writing these variables since
1356 // we're actually targeting the clones.
1357 if (fIgnoreVariableReferenceModifiers) {
1358 this->writeName(ref.variable()->name());
1362 switch (ref.variable()->modifiers().fLayout.fBuiltin) {
1363 case SK_FRAGCOLOR_BUILTIN:
1364 this->write("_out.sk_FragColor");
1366 case SK_FRAGCOORD_BUILTIN:
1367 this->writeFragCoord();
1369 case SK_VERTEXID_BUILTIN:
1370 this->write("sk_VertexID");
1372 case SK_INSTANCEID_BUILTIN:
1373 this->write("sk_InstanceID");
1375 case SK_CLOCKWISE_BUILTIN:
1376 // We'd set the front facing winding in the MTLRenderCommandEncoder to be counter
1377 // clockwise to match Skia convention.
// With an RT-flip uniform, the sign of its y component selects between _frontFacing and
// its negation.
1378 if (!fRTFlipName.empty()) {
1379 this->write("(" + fRTFlipName + ".y < 0 ? _frontFacing : !_frontFacing)");
1381 this->write("_frontFacing");
1385 const Variable& var = *ref.variable();
1386 if (var.storage() == Variable::Storage::kGlobal) {
1387 if (var.modifiers().fFlags & Modifiers::kIn_Flag) {
1388 this->write("_in.");
1389 } else if (var.modifiers().fFlags & Modifiers::kOut_Flag) {
1390 this->write("_out.");
// Uniform samplers are deliberately excluded from the `_uniforms.` prefix.
1391 } else if (var.modifiers().fFlags & Modifiers::kUniform_Flag &&
1392 var.type().typeKind() != Type::TypeKind::kSampler) {
1393 this->write("_uniforms.");
1395 this->write("_globals.");
1398 this->writeName(var.name());
// Writes an index expression: the base at postfix precedence, then the index expression at
// top-level precedence.
1402 void MetalCodeGenerator::writeIndexExpression(const IndexExpression& expr) {
1403 this->writeExpression(*expr.base(), Precedence::kPostfix);
1405 this->writeExpression(*expr.index(), Precedence::kTopLevel);
// Writes a field access. Builtin position/point-size fields map to members of `_out`. Fields of
// an anonymous interface block are reached through `_globals.` plus the name recorded for that
// block in fInterfaceBlockNameMap.
1409 void MetalCodeGenerator::writeFieldAccess(const FieldAccess& f) {
1410 const Type::Field* field = &f.base()->type().fields()[f.fieldIndex()];
1411 if (FieldAccess::OwnerKind::kDefault == f.ownerKind()) {
1412 this->writeExpression(*f.base(), Precedence::kPostfix);
1415 switch (field->fModifiers.fLayout.fBuiltin) {
1416 case SK_POSITION_BUILTIN:
1417 this->write("_out.sk_Position");
1419 case SK_POINTSIZE_BUILTIN:
1420 this->write("_out.sk_PointSize");
1423 if (FieldAccess::OwnerKind::kAnonymousInterfaceBlock == f.ownerKind()) {
1424 this->write("_globals.");
1425 this->write(fInterfaceBlockNameMap[fInterfaceBlockMap[field]]);
1428 this->writeName(field->fName);
// Writes the swizzle's base expression followed by its component letters.
1432 void MetalCodeGenerator::writeSwizzle(const Swizzle& swizzle) {
1433 this->writeExpression(*swizzle.base(), Precedence::kPostfix);
1435 for (int c : swizzle.components()) {
1436 SkASSERT(c >= 0 && c <= 3);
// Each letter in the literal is followed by a NUL, so taking the address at c * 2 yields a
// one-character C string: "x", "y", "z" or "w".
1437 this->write(&("x\0y\0z\0w\0"[c * 2]));
// Emits, once per (left, right) matrix type pair, a Metal `operator*=` implemented as
// `left = left * right`. The function text goes to fExtraFunctions, keyed in fHelpers.
// NOTE(review): presumably needed because Metal provides no matrix `*=`; confirm against MSL.
1441 void MetalCodeGenerator::writeMatrixTimesEqualHelper(const Type& left, const Type& right,
1442 const Type& result) {
1443 SkASSERT(left.isMatrix());
1444 SkASSERT(right.isMatrix());
1445 SkASSERT(result.isMatrix());
1446 SkASSERT(left.rows() == right.rows());
1447 SkASSERT(left.columns() == right.columns());
1448 SkASSERT(left.rows() == result.rows());
1449 SkASSERT(left.columns() == result.columns());
1451 std::string key = "Matrix *= " + this->typeName(left) + ":" + this->typeName(right);
1453 if (!fHelpers.contains(key)) {
1455 fExtraFunctions.printf("thread %s& operator*=(thread %s& left, thread const %s& right) {\n"
1456 " left = left * right;\n"
1459 this->typeName(result).c_str(), this->typeName(left).c_str(),
1460 this->typeName(right).c_str());
// Emits, once per (left, right) matrix type pair, prototypes (to fExtraFunctionPrototypes) and
// definitions (to fExtraFunctions) of `operator==` and `operator!=`. Two matrices are equal when
// every column compares equal under `all(...)`; `!=` is the negation of `==`.
1464 void MetalCodeGenerator::writeMatrixEqualityHelpers(const Type& left, const Type& right) {
1465 SkASSERT(left.isMatrix());
1466 SkASSERT(right.isMatrix());
1467 SkASSERT(left.rows() == right.rows());
1468 SkASSERT(left.columns() == right.columns());
1470 std::string key = "Matrix == " + this->typeName(left) + ":" + this->typeName(right);
1472 if (!fHelpers.contains(key)) {
1474 fExtraFunctionPrototypes.printf(R"(
1475 thread bool operator==(const %s left, const %s right);
1476 thread bool operator!=(const %s left, const %s right);
1478 this->typeName(left).c_str(),
1479 this->typeName(right).c_str(),
1480 this->typeName(left).c_str(),
1481 this->typeName(right).c_str());
1483 fExtraFunctions.printf(
1484 "thread bool operator==(const %s left, const %s right) {\n"
1486 this->typeName(left).c_str(), this->typeName(right).c_str());
// Join one `all(left[i] == right[i])` term per column with `&&`.
1488 const char* separator = "";
1489 for (int index=0; index<left.columns(); ++index) {
1490 fExtraFunctions.printf("%sall(left[%d] == right[%d])", separator, index, index);
1491 separator = " &&\n ";
1494 fExtraFunctions.printf(
1497 "thread bool operator!=(const %s left, const %s right) {\n"
1498 " return !(left == right);\n"
1500 this->typeName(left).c_str(), this->typeName(right).c_str());
// Emits, once per matrix type, Metal `operator/` and `operator/=` for two matrices of that type.
// Division is componentwise, built column by column (`left[i] / right[i]`).
1504 void MetalCodeGenerator::writeMatrixDivisionHelpers(const Type& type) {
1505 SkASSERT(type.isMatrix());
1507 std::string key = "Matrix / " + this->typeName(type);
1509 if (!fHelpers.contains(key)) {
1511 std::string typeName = this->typeName(type);
1513 fExtraFunctions.printf(
1514 "thread %s operator/(const %s left, const %s right) {\n"
1516 typeName.c_str(), typeName.c_str(), typeName.c_str(), typeName.c_str());
// One `left[i] / right[i]` column expression per matrix column.
1518 const char* separator = "";
1519 for (int index=0; index<type.columns(); ++index) {
1520 fExtraFunctions.printf("%sleft[%d] / right[%d]", separator, index, index);
1524 fExtraFunctions.printf(");\n"
1526 "thread %s& operator/=(thread %s& left, thread const %s& right) {\n"
1527 " left = left / right;\n"
1530 typeName.c_str(), typeName.c_str(), typeName.c_str());
// Emits templated `operator==`/`operator!=` for Metal arrays. Equality compares element by
// element with `all(...)`. Any helpers the element type needs are emitted first.
1534 void MetalCodeGenerator::writeArrayEqualityHelpers(const Type& type) {
1535 SkASSERT(type.isArray());
1537 // If the array's component type needs a helper as well, we need to emit that one first.
1538 this->writeEqualityHelpers(type.componentType(), type.componentType());
// One template pair serves every array type, so a single fixed key suffices.
1540 std::string key = "ArrayEquality []";
1541 if (!fHelpers.contains(key)) {
1543 fExtraFunctionPrototypes.writeText(R"(
1544 template <typename T1, typename T2, size_t N>
1545 bool operator==(thread const array<T1, N>& left, thread const array<T2, N>& right);
1546 template <typename T1, typename T2, size_t N>
1547 bool operator!=(thread const array<T1, N>& left, thread const array<T2, N>& right);
1549 fExtraFunctions.writeText(R"(
1550 template <typename T1, typename T2, size_t N>
1551 bool operator==(thread const array<T1, N>& left, thread const array<T2, N>& right) {
1552 for (size_t index = 0; index < N; ++index) {
1553 if (!all(left[index] == right[index])) {
1560 template <typename T1, typename T2, size_t N>
1561 bool operator!=(thread const array<T1, N>& left, thread const array<T2, N>& right) {
1562 return !(left == right);
// Emits, once per struct type, prototypes and definitions of `operator==`/`operator!=`.
// Equality ANDs an `all(left.f == right.f)` term for every field. Helpers needed by the
// field types are emitted first.
1568 void MetalCodeGenerator::writeStructEqualityHelpers(const Type& type) {
1569 SkASSERT(type.isStruct());
1570 std::string key = "StructEquality " + this->typeName(type);
1572 if (!fHelpers.contains(key)) {
1574 // If one of the struct's fields needs a helper as well, we need to emit that one first.
1575 for (const Type::Field& field : type.fields()) {
1576 this->writeEqualityHelpers(*field.fType, *field.fType);
1579 // Write operator== and operator!= for this struct, since those are assumed to exist in SkSL
1580 // and GLSL but do not exist by default in Metal.
1581 fExtraFunctionPrototypes.printf(R"(
1582 thread bool operator==(thread const %s& left, thread const %s& right);
1583 thread bool operator!=(thread const %s& left, thread const %s& right);
1585 this->typeName(type).c_str(),
1586 this->typeName(type).c_str(),
1587 this->typeName(type).c_str(),
1588 this->typeName(type).c_str());
1590 fExtraFunctions.printf(
1591 "thread bool operator==(thread const %s& left, thread const %s& right) {\n"
1593 this->typeName(type).c_str(),
1594 this->typeName(type).c_str());
// Join one `all(left.<field> == right.<field>)` term per field with `&&`.
1596 const char* separator = "";
1597 for (const Type::Field& field : type.fields()) {
1598 fExtraFunctions.printf("%sall(left.%.*s == right.%.*s)",
1600 (int)field.fName.size(), field.fName.data(),
1601 (int)field.fName.size(), field.fName.data());
1602 separator = " &&\n ";
1604 fExtraFunctions.printf(
1607 "thread bool operator!=(thread const %s& left, thread const %s& right) {\n"
1608 " return !(left == right);\n"
1610 this->typeName(type).c_str(),
1611 this->typeName(type).c_str());
1615 void MetalCodeGenerator::writeEqualityHelpers(const Type& leftType, const Type& rightType) {
1616 if (leftType.isArray() && rightType.isArray()) {
1617 this->writeArrayEqualityHelpers(leftType);
1620 if (leftType.isStruct() && rightType.isStruct()) {
1621 this->writeStructEqualityHelpers(leftType);
1624 if (leftType.isMatrix() && rightType.isMatrix()) {
1625 this->writeMatrixEqualityHelpers(leftType, rightType);
// Writes a scalar as a value of `matrixType` by multiplying it with an all-ones matrix of that
// type. writeBinaryExpression uses this to splat a scalar operand when the other operand is a
// matrix.
1630 void MetalCodeGenerator::writeNumberAsMatrix(const Expression& expr, const Type& matrixType) {
1631 SkASSERT(expr.type().isNumber());
1632 SkASSERT(matrixType.isMatrix());
1634 // Componentwise multiply the scalar against a matrix of the desired size which contains all 1s.
1636 this->writeType(matrixType);
// One entry is written per slot of the matrix (slotCount() iterations).
1639 const char* separator = "";
1640 for (int index = matrixType.slotCount(); index--;) {
1641 this->write(separator);
1646 this->write(") * ");
1647 this->writeExpression(expr, Precedence::kMultiplicative);
// Writes a binary expression. It first makes sure any Metal helper operators it relies on have
// been emitted (equality, matrix `*=`, matrix division). It splats scalars into matrices where a
// scalar meets a matrix, and it rewrites compound assignment to a side-effect-free swizzle.
1651 void MetalCodeGenerator::writeBinaryExpression(const BinaryExpression& b,
1652 Precedence parentPrecedence) {
1653 const Expression& left = *b.left();
1654 const Expression& right = *b.right();
1655 const Type& leftType = left.type();
1656 const Type& rightType = right.type();
1657 Operator op = b.getOperator();
1658 Precedence precedence = op.getBinaryPrecedence();
1659 bool needParens = precedence >= parentPrecedence;
1660 switch (op.kind()) {
1661 case Operator::Kind::EQEQ:
1662 this->writeEqualityHelpers(leftType, rightType);
1663 if (leftType.isVector()) {
1668 case Operator::Kind::NEQ:
1669 this->writeEqualityHelpers(leftType, rightType);
1670 if (leftType.isVector()) {
// Matrix `*=` and matrix division rely on synthesized operator overloads.
1678 if (leftType.isMatrix() && rightType.isMatrix() && op.kind() == Operator::Kind::STAREQ) {
1679 this->writeMatrixTimesEqualHelper(leftType, rightType, b.type());
// Scalar/matrix division also needs the matrix/matrix helper, because the scalar side is
// splatted into a matrix below.
1681 if (op.removeAssignment().kind() == Operator::Kind::SLASH &&
1682 ((leftType.isMatrix() && rightType.isMatrix()) ||
1683 (leftType.isScalar() && rightType.isMatrix()) ||
1684 (leftType.isMatrix() && rightType.isScalar()))) {
1685 this->writeMatrixDivisionHelpers(leftType.isMatrix() ? leftType : rightType);
// A scalar left operand combined with a matrix by any operator other than `*` is written as
// a full matrix.
1690 bool needMatrixSplatOnScalar = rightType.isMatrix() && leftType.isNumber() &&
1691 op.isValidForMatrixOrVector() &&
1692 op.removeAssignment().kind() != Operator::Kind::STAR;
1693 if (needMatrixSplatOnScalar) {
1694 this->writeNumberAsMatrix(left, rightType);
1696 this->writeExpression(left, precedence);
1698 if (op.kind() != Operator::Kind::EQ && op.isAssignment() &&
1699 left.kind() == Expression::Kind::kSwizzle && !left.hasSideEffects()) {
1700 // This doesn't compile in Metal:
1701 // float4 x = float4(1);
1702 // x.xy *= float2x2(...);
1703 // with the error message "non-const reference cannot bind to vector element",
1704 // but switching it to x.xy = x.xy * float2x2(...) fixes it. We perform this transformation
1705 // as long as the LHS has no side effects, and hope for the best otherwise.
1707 this->writeExpression(left, Precedence::kAssignment);
1708 this->write(operator_name(op.removeAssignment()));
1710 this->write(operator_name(op));
// Same splat rule for a scalar right operand against a matrix left operand.
1713 needMatrixSplatOnScalar = leftType.isMatrix() && rightType.isNumber() &&
1714 op.isValidForMatrixOrVector() &&
1715 op.removeAssignment().kind() != Operator::Kind::STAR;
1716 if (needMatrixSplatOnScalar) {
1717 this->writeNumberAsMatrix(right, leftType);
1719 this->writeExpression(right, precedence);
// Writes a ternary expression with each operand at ternary precedence. The whole expression is
// parenthesized when the parent binds at least as tightly as a ternary.
1726 void MetalCodeGenerator::writeTernaryExpression(const TernaryExpression& t,
1727 Precedence parentPrecedence) {
1728 if (Precedence::kTernary >= parentPrecedence) {
1731 this->writeExpression(*t.test(), Precedence::kTernary);
1733 this->writeExpression(*t.ifTrue(), Precedence::kTernary);
1735 this->writeExpression(*t.ifFalse(), Precedence::kTernary);
1736 if (Precedence::kTernary >= parentPrecedence) {
// Writes a prefix expression. Unary `+` is dropped and matrix negation becomes `-1.0 * m`.
// Every other prefix operator is written with its tight spelling.
1741 void MetalCodeGenerator::writePrefixExpression(const PrefixExpression& p,
1742 Precedence parentPrecedence) {
1743 // According to the MSL specification, the arithmetic unary operators (+ and –) do not act
1744 // upon matrix type operands. We treat the unary "+" as NOP for all operands.
1745 const Operator op = p.getOperator();
1746 if (op.kind() == Operator::Kind::PLUS) {
1747 return this->writeExpression(*p.operand(), Precedence::kPrefix);
1750 const bool matrixNegation =
1751 op.kind() == Operator::Kind::MINUS && p.operand()->type().isMatrix();
// Matrix negation is always parenthesized because it expands into a multiplication.
1752 const bool needParens = Precedence::kPrefix >= parentPrecedence || matrixNegation;
1758 // Transform the unary "-" on a matrix type to a multiplication by -1.
1759 if (matrixNegation) {
1760 this->write("-1.0 * ");
1762 this->write(p.getOperator().tightOperatorName());
1764 this->writeExpression(*p.operand(), Precedence::kPrefix);
// Writes a postfix expression (operand, then the operator's tight spelling). It is
// parenthesized when the parent binds at least as tightly as a postfix operator.
1771 void MetalCodeGenerator::writePostfixExpression(const PostfixExpression& p,
1772 Precedence parentPrecedence) {
1773 if (Precedence::kPostfix >= parentPrecedence) {
1776 this->writeExpression(*p.operand(), Precedence::kPostfix);
1777 this->write(p.getOperator().tightOperatorName());
1778 if (Precedence::kPostfix >= parentPrecedence) {
// Writes a literal value:
//  - floats via skstd::to_string, with a separate branch for non-high-precision types;
//  - uint/ushort literals masked to 32/16 bits, so the stored value prints as the matching
//    unsigned number;
//  - other integers as-is;
//  - bools as `true`/`false`.
1783 void MetalCodeGenerator::writeLiteral(const Literal& l) {
1784 const Type& type = l.type();
1785 if (type.isFloat()) {
1786 this->write(skstd::to_string(l.floatValue()));
1787 if (!l.type().highPrecision()) {
1792 if (type.isInteger()) {
1793 if (type.matches(*fContext.fTypes.fUInt)) {
1794 this->write(std::to_string(l.intValue() & 0xffffffff));
1796 } else if (type.matches(*fContext.fTypes.fUShort)) {
1797 this->write(std::to_string(l.intValue() & 0xffff));
1800 this->write(std::to_string(l.intValue()));
1804 SkASSERT(type.isBoolean());
1805 this->write(l.boolValue() ? "true" : "false");
// Settings must already have been folded into constants before code generation. Reaching this
// is an internal error, so it aborts unconditionally.
1808 void MetalCodeGenerator::writeSetting(const Setting& s) {
1809 SK_ABORT("internal error; setting was not folded to a constant during compilation\n");
// At a call site, writes the implicit arguments that `f` requires according to
// this->requirements(f), in a fixed order: inputs, outputs, uniforms, globals, fragCoord.
// `separator` is taken by reference so the caller can continue the argument list afterwards.
1812 void MetalCodeGenerator::writeFunctionRequirementArgs(const FunctionDeclaration& f,
1813 const char*& separator) {
1814 Requirements requirements = this->requirements(f);
1815 if (requirements & kInputs_Requirement) {
1816 this->write(separator);
1820 if (requirements & kOutputs_Requirement) {
1821 this->write(separator);
1822 this->write("_out");
1825 if (requirements & kUniforms_Requirement) {
1826 this->write(separator);
1827 this->write("_uniforms");
1830 if (requirements & kGlobals_Requirement) {
1831 this->write(separator);
1832 this->write("_globals");
1835 if (requirements & kFragCoord_Requirement) {
1836 this->write(separator);
1837 this->write("_fragCoord");
// In a function signature, declares the implicit parameters matching
// writeFunctionRequirementArgs, in the same order. Inputs and uniforms are taken by value.
// Outputs and globals are taken as `thread` references so the callee can modify them.
1842 void MetalCodeGenerator::writeFunctionRequirementParams(const FunctionDeclaration& f,
1843 const char*& separator) {
1844 Requirements requirements = this->requirements(f);
1845 if (requirements & kInputs_Requirement) {
1846 this->write(separator);
1847 this->write("Inputs _in");
1850 if (requirements & kOutputs_Requirement) {
1851 this->write(separator);
1852 this->write("thread Outputs& _out");
1855 if (requirements & kUniforms_Requirement) {
1856 this->write(separator);
1857 this->write("Uniforms _uniforms");
1860 if (requirements & kGlobals_Requirement) {
1861 this->write(separator);
1862 this->write("thread Globals& _globals");
1865 if (requirements & kFragCoord_Requirement) {
1866 this->write(separator);
1867 this->write("float4 _fragCoord");
1872 int MetalCodeGenerator::getUniformBinding(const Modifiers& m) {
1873 return (m.fLayout.fBinding >= 0) ? m.fLayout.fBinding
1874 : fProgram.fConfig->fSettings.fDefaultUniformBinding;
1877 int MetalCodeGenerator::getUniformSet(const Modifiers& m) {
1878 return (m.fLayout.fSet >= 0) ? m.fLayout.fSet
1879 : fProgram.fConfig->fSettings.fDefaultUniformSet;
// Writes a function signature.
// - Entry point: a Metal `fragment`/`vertex` function whose parameters bind the stage inputs,
//   the uniform buffer, each 2D texture with its sampler, and each interface block, plus the
//   per-stage builtins.
// - Otherwise: return type, mangled name, implicit requirement parameters, and then the
//   declared parameters.
// Also sets fRTFlipName, which writeFragCoord and writeVariableReference read.
// NOTE(review): the meaning of the bool result is not visible here; presumably false
// signals an error — confirm.
1882 bool MetalCodeGenerator::writeFunctionDeclaration(const FunctionDeclaration& f) {
1883 fRTFlipName = fProgram.fInputs.fUseFlipRTUniform
1884 ? "_globals._anonInterface0->" SKSL_RTFLIP_NAME
1886 const char* separator = "";
1888 if (ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
1889 this->write("fragment Outputs fragmentMain");
1890 } else if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
1891 this->write("vertex Outputs vertexMain");
1893 fContext.fErrors->error(Position(), "unsupported kind of program");
1896 this->write("(Inputs _in [[stage_in]]");
1897 if (-1 != fUniformBuffer) {
1898 this->write(", constant Uniforms& _uniforms [[buffer(" +
1899 std::to_string(fUniformBuffer) + ")]]");
// Each sampler global contributes a texture parameter and a matching sampler parameter.
// Both use the same binding index.
1901 for (const ProgramElement* e : fProgram.elements()) {
1902 if (e->is<GlobalVarDeclaration>()) {
1903 const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
1904 const VarDeclaration& var = decls.declaration()->as<VarDeclaration>();
1905 if (var.var().type().typeKind() == Type::TypeKind::kSampler) {
1906 if (var.var().type().dimensions() != SpvDim2D) {
1907 // Not yet implemented--Skia currently only uses 2D textures.
1908 fContext.fErrors->error(decls.fPosition, "Unsupported texture dimensions");
1911 int binding = getUniformBinding(var.var().modifiers());
1912 this->write(", texture2d<half> ");
1913 this->writeName(var.var().name());
1914 this->write("[[texture(");
1915 this->write(std::to_string(binding));
1917 this->write(", sampler ");
1918 this->writeName(var.var().name());
1919 this->write(SAMPLER_SUFFIX);
1920 this->write("[[sampler(");
1921 this->write(std::to_string(binding));
1924 } else if (e->is<InterfaceBlock>()) {
1925 const InterfaceBlock& intf = e->as<InterfaceBlock>();
1926 if (intf.typeName() == "sk_PerVertex") {
1929 this->write(", constant ");
1930 this->writeType(intf.variable().type());
1932 this->write(fInterfaceBlockNameMap[&intf]);
1933 this->write(" [[buffer(");
1934 this->write(std::to_string(this->getUniformBinding(intf.variable().modifiers())));
// Stage-specific builtins. If the RT-flip uniform is used but no interface block exists to
// hold it, a synthetic uniform buffer is bound at buffer(1).
1938 if (ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
1939 if (fProgram.fInputs.fUseFlipRTUniform && fInterfaceBlockNameMap.empty()) {
1940 this->write(", constant sksl_synthetic_uniforms& _anonInterface0 [[buffer(1)]]");
1941 fRTFlipName = "_anonInterface0." SKSL_RTFLIP_NAME;
1943 this->write(", bool _frontFacing [[front_facing]]");
1944 this->write(", float4 _fragCoord [[position]]");
1945 } else if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
1946 this->write(", uint sk_VertexID [[vertex_id]], uint sk_InstanceID [[instance_id]]");
// Ordinary function signature. Builtin parameters of main are skipped because they are
// supplied through the entry point's own parameters.
1950 this->writeType(f.returnType());
1952 this->writeName(f.mangledName());
1954 this->writeFunctionRequirementParams(f, separator);
1956 for (const auto& param : f.parameters()) {
1957 if (f.isMain() && param->modifiers().fLayout.fBuiltin != -1) {
1960 this->write(separator);
1962 this->writeModifiers(param->modifiers());
1963 const Type* type = ¶m->type();
1964 this->writeType(*type);
1965 if (param->modifiers().fFlags & Modifiers::kOut_Flag) {
1969 this->writeName(param->name());
// Writes a forward declaration: the function's signature followed by `;`.
1975 void MetalCodeGenerator::writeFunctionPrototype(const FunctionPrototype& f) {
1976 this->writeFunctionDeclaration(f.declaration());
1977 this->writeLine(";");
// Returns whether `stmt` is a Block whose last meaningful statement is a return. Children are
// scanned from the end, and a trailing nested Block is checked recursively.
// NOTE(review): trailing Nop statements appear to be skipped; confirm.
1980 static bool is_block_ending_with_return(const Statement* stmt) {
1981 // This function detects (potentially nested) blocks that end in a return statement.
1982 if (!stmt->is<Block>()) {
1985 const StatementArray& block = stmt->as<Block>().children();
1986 for (int index = block.count(); index--; ) {
1987 stmt = block[index].get();
1988 if (stmt->is<ReturnStatement>()) {
1991 if (stmt->is<Block>()) {
1992 return is_block_ending_with_return(stmt);
1994 if (!stmt->is<Nop>()) {
// Writes a complete function definition. For main, it also writes global initialization, the
// `_out` local, and a synthesized return when the body doesn't end with one.
2001 void MetalCodeGenerator::writeFunction(const FunctionDefinition& f) {
2002 SkASSERT(!fProgram.fConfig->fSettings.fFragColorIsInOut);
2004 if (!this->writeFunctionDeclaration(f.declaration())) {
2008 fCurrentFunction = &f.declaration();
2009 SkScopeExit clearCurrentFunction([&] { fCurrentFunction = nullptr; });
2011 this->writeLine(" {");
2013 if (f.declaration().isMain()) {
2014 this->writeGlobalInit();
2015 this->writeLine(" Outputs _out;");
2016 this->writeLine(" (void)_out;");
// Body statements go into `buffer` first. Whatever accumulates in fFunctionHeader while they
// are emitted can then be written ahead of the body.
2019 fFunctionHeader.clear();
2020 StringStream buffer;
2022 AutoOutputStream outputToBuffer(this, &buffer);
2024 for (const std::unique_ptr<Statement>& stmt : f.body()->as<Block>().children()) {
2025 if (!stmt->isEmpty()) {
2026 this->writeStatement(*stmt);
2030 if (f.declaration().isMain()) {
2031 // If the main function doesn't end with a return, we need to synthesize one here.
2032 if (!is_block_ending_with_return(f.body().get())) {
2033 this->writeReturnStatementFromMain();
2038 this->writeLine("}");
2040 this->write(fFunctionHeader);
2041 this->write(buffer.str());
2044 void MetalCodeGenerator::writeModifiers(const Modifiers& modifiers) {
2045 if (modifiers.fFlags & Modifiers::kOut_Flag) {
2046 this->write("thread ");
2048 if (modifiers.fFlags & Modifiers::kConst_Flag) {
2049 this->write("const ");
// Writes an interface block as a Metal struct, except `sk_PerVertex`, which is skipped. It also
// records the block's variable name in fInterfaceBlockNameMap: the instance name when present,
// otherwise a synthesized `_anonInterfaceN`.
2053 void MetalCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf) {
2054 if ("sk_PerVertex" == intf.typeName()) {
2057 this->writeModifiers(intf.variable().modifiers());
2058 this->write("struct ");
2059 this->writeLine(std::string(intf.typeName()) + " {");
// For an arrayed block, the struct describes a single element.
2060 const Type* structType = &intf.variable().type();
2061 if (structType->isArray()) {
2062 structType = &structType->componentType();
2065 this->writeFields(structType->fields(), structType->fPosition, &intf);
// The RT-flip uniform is appended as an extra float2 member when the program uses it.
2066 if (fProgram.fInputs.fUseFlipRTUniform) {
2067 this->writeLine("float2 " SKSL_RTFLIP_NAME ";");
2071 if (intf.instanceName().size()) {
2073 this->write(intf.instanceName());
2074 if (intf.arraySize() > 0) {
2076 this->write(std::to_string(intf.arraySize()));
2079 fInterfaceBlockNameMap.set(&intf, intf.instanceName());
2081 fInterfaceBlockNameMap.set(&intf, *fProgram.fSymbols->takeOwnershipOfString(
2082 "_anonInterface" + std::to_string(fAnonInterfaceCount++)));
2084 this->writeLine(";");
// Writes struct fields laid out with MemoryLayout::kMetal_Standard. Explicit layout(offset=N)
// is honored by inserting `char padN[...]` arrays. Errors are reported for:
//  - unsupported field types;
//  - offsets that would move backwards;
//  - misaligned offsets;
//  - offset overflow.
// Each field is mapped to `parentIntf` in fInterfaceBlockMap, which writeFieldAccess reads.
2087 void MetalCodeGenerator::writeFields(const std::vector<Type::Field>& fields, Position parentPos,
2088 const InterfaceBlock* parentIntf) {
2089 MemoryLayout memoryLayout(MemoryLayout::kMetal_Standard);
// Running byte offset of the next field in the struct being written.
2090 int currentOffset = 0;
2091 for (const Type::Field& field : fields) {
2092 int fieldOffset = field.fModifiers.fLayout.fOffset;
2093 const Type* fieldType = field.fType;
2094 if (!MemoryLayout::LayoutIsSupported(*fieldType)) {
2095 fContext.fErrors->error(parentPos, "type '" + std::string(fieldType->name()) +
2096 "' is not permitted here");
// An explicit offset may not precede the running offset; any gap is filled with padding.
2099 if (fieldOffset != -1) {
2100 if (currentOffset > fieldOffset) {
2101 fContext.fErrors->error(field.fPosition,
2102 "offset of field '" + std::string(field.fName) +
2103 "' must be at least " + std::to_string(currentOffset));
2105 } else if (currentOffset < fieldOffset) {
2106 this->write("char pad");
2107 this->write(std::to_string(fPaddingCount++));
2109 this->write(std::to_string(fieldOffset - currentOffset));
2110 this->writeLine("];");
2111 currentOffset = fieldOffset;
2113 int alignment = memoryLayout.alignment(*fieldType);
2114 if (fieldOffset % alignment) {
2115 fContext.fErrors->error(field.fPosition,
2116 "offset of field '" + std::string(field.fName) +
2117 "' must be a multiple of " + std::to_string(alignment));
// Guard the int running offset against overflow before adding this field's size.
2121 size_t fieldSize = memoryLayout.size(*fieldType);
2122 if (fieldSize > static_cast<size_t>(std::numeric_limits<int>::max() - currentOffset)) {
2123 fContext.fErrors->error(parentPos, "field offset overflow");
2126 currentOffset += fieldSize;
2127 this->writeModifiers(field.fModifiers);
2128 this->writeType(*fieldType);
2130 this->writeName(field.fName);
2131 this->writeLine(";");
2133 fInterfaceBlockMap.set(&field, parentIntf);
2138 void MetalCodeGenerator::writeVarInitializer(const Variable& var, const Expression& value) {
2139 this->writeExpression(value, Precedence::kTopLevel);
2142 void MetalCodeGenerator::writeName(std::string_view name) {
2143 if (fReservedWords.contains(name)) {
2144 this->write("_"); // adding underscore before name to avoid conflict with reserved words
2149 void MetalCodeGenerator::writeVarDeclaration(const VarDeclaration& varDecl) {
2150 this->writeModifiers(varDecl.var().modifiers());
2151 this->writeType(varDecl.var().type());
2153 this->writeName(varDecl.var().name());
2154 if (varDecl.value()) {
2156 this->writeVarInitializer(varDecl.var(), *varDecl.value());
2161 void MetalCodeGenerator::writeStatement(const Statement& s) {
2163 case Statement::Kind::kBlock:
2164 this->writeBlock(s.as<Block>());
2166 case Statement::Kind::kExpression:
2167 this->writeExpressionStatement(s.as<ExpressionStatement>());
2169 case Statement::Kind::kReturn:
2170 this->writeReturnStatement(s.as<ReturnStatement>());
2172 case Statement::Kind::kVarDeclaration:
2173 this->writeVarDeclaration(s.as<VarDeclaration>());
2175 case Statement::Kind::kIf:
2176 this->writeIfStatement(s.as<IfStatement>());
2178 case Statement::Kind::kFor:
2179 this->writeForStatement(s.as<ForStatement>());
2181 case Statement::Kind::kDo:
2182 this->writeDoStatement(s.as<DoStatement>());
2184 case Statement::Kind::kSwitch:
2185 this->writeSwitchStatement(s.as<SwitchStatement>());
2187 case Statement::Kind::kBreak:
2188 this->write("break;");
2190 case Statement::Kind::kContinue:
2191 this->write("continue;");
2193 case Statement::Kind::kDiscard:
2194 this->write("discard_fragment();");
2196 case Statement::Kind::kNop:
2200 SkDEBUGFAILF("unsupported statement: %s", s.description().c_str());
2205 void MetalCodeGenerator::writeBlock(const Block& b) {
2206 // Write scope markers if this block is a scope, or if the block is empty (since we need to emit
2207 // something here to make the code valid).
2208 bool isScope = b.isScope() || b.isEmpty();
2210 this->writeLine("{");
2213 for (const std::unique_ptr<Statement>& stmt : b.children()) {
2214 if (!stmt->isEmpty()) {
2215 this->writeStatement(*stmt);
2225 void MetalCodeGenerator::writeIfStatement(const IfStatement& stmt) {
2226 this->write("if (");
2227 this->writeExpression(*stmt.test(), Precedence::kTopLevel);
2229 this->writeStatement(*stmt.ifTrue());
2230 if (stmt.ifFalse()) {
2231 this->write(" else ");
2232 this->writeStatement(*stmt.ifFalse());
2236 void MetalCodeGenerator::writeForStatement(const ForStatement& f) {
2237 // Emit loops of the form 'for(;test;)' as 'while(test)', which is probably how they started
2238 if (!f.initializer() && f.test() && !f.next()) {
2239 this->write("while (");
2240 this->writeExpression(*f.test(), Precedence::kTopLevel);
2242 this->writeStatement(*f.statement());
2246 this->write("for (");
2247 if (f.initializer() && !f.initializer()->isEmpty()) {
2248 this->writeStatement(*f.initializer());
2253 this->writeExpression(*f.test(), Precedence::kTopLevel);
2257 this->writeExpression(*f.next(), Precedence::kTopLevel);
2260 this->writeStatement(*f.statement());
2263 void MetalCodeGenerator::writeDoStatement(const DoStatement& d) {
2265 this->writeStatement(*d.statement());
2266 this->write(" while (");
2267 this->writeExpression(*d.test(), Precedence::kTopLevel);
2271 void MetalCodeGenerator::writeExpressionStatement(const ExpressionStatement& s) {
2272 if (fProgram.fConfig->fSettings.fOptimize && !s.expression()->hasSideEffects()) {
2273 // Don't emit dead expressions.
2276 this->writeExpression(*s.expression(), Precedence::kTopLevel);
2280 void MetalCodeGenerator::writeSwitchStatement(const SwitchStatement& s) {
2281 this->write("switch (");
2282 this->writeExpression(*s.value(), Precedence::kTopLevel);
2283 this->writeLine(") {");
2285 for (const std::unique_ptr<Statement>& stmt : s.cases()) {
2286 const SwitchCase& c = stmt->as<SwitchCase>();
2287 if (c.isDefault()) {
2288 this->writeLine("default:");
2290 this->write("case ");
2291 this->write(std::to_string(c.value()));
2292 this->writeLine(":");
2294 if (!c.statement()->isEmpty()) {
2296 this->writeStatement(*c.statement());
2305 void MetalCodeGenerator::writeReturnStatementFromMain() {
2306 // main functions in Metal return a magic _out parameter that doesn't exist in SkSL.
2307 if (ProgramConfig::IsVertex(fProgram.fConfig->fKind) ||
2308 ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
2309 this->write("return _out;");
2311 SkDEBUGFAIL("unsupported kind of program");
2315 void MetalCodeGenerator::writeReturnStatement(const ReturnStatement& r) {
2316 if (fCurrentFunction && fCurrentFunction->isMain()) {
2317 if (r.expression()) {
2318 if (r.expression()->type().matches(*fContext.fTypes.fHalf4)) {
2319 this->write("_out.sk_FragColor = ");
2320 this->writeExpression(*r.expression(), Precedence::kTopLevel);
2321 this->writeLine(";");
2323 fContext.fErrors->error(r.fPosition,
2324 "Metal does not support returning '" +
2325 r.expression()->type().description() + "' from main()");
2328 this->writeReturnStatementFromMain();
2332 this->write("return");
2333 if (r.expression()) {
2335 this->writeExpression(*r.expression(), Precedence::kTopLevel);
2340 void MetalCodeGenerator::writeHeader() {
2341 this->write("#include <metal_stdlib>\n");
2342 this->write("#include <simd/simd.h>\n");
2343 this->write("using namespace metal;\n");
2346 void MetalCodeGenerator::writeUniformStruct() {
2347 for (const ProgramElement* e : fProgram.elements()) {
2348 if (e->is<GlobalVarDeclaration>()) {
2349 const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
2350 const Variable& var = decls.declaration()->as<VarDeclaration>().var();
2351 if (var.modifiers().fFlags & Modifiers::kUniform_Flag &&
2352 var.type().typeKind() != Type::TypeKind::kSampler) {
2353 int uniformSet = this->getUniformSet(var.modifiers());
2354 // Make sure that the program's uniform-set value is consistent throughout.
2355 if (-1 == fUniformBuffer) {
2356 this->write("struct Uniforms {\n");
2357 fUniformBuffer = uniformSet;
2358 } else if (uniformSet != fUniformBuffer) {
2359 fContext.fErrors->error(decls.fPosition,
2360 "Metal backend requires all uniforms to have the same "
2361 "'layout(set=...)'");
2364 this->writeType(var.type());
2366 this->writeName(var.name());
2371 if (-1 != fUniformBuffer) {
2372 this->write("};\n");
2376 void MetalCodeGenerator::writeInputStruct() {
2377 this->write("struct Inputs {\n");
2378 for (const ProgramElement* e : fProgram.elements()) {
2379 if (e->is<GlobalVarDeclaration>()) {
2380 const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
2381 const Variable& var = decls.declaration()->as<VarDeclaration>().var();
2382 if (var.modifiers().fFlags & Modifiers::kIn_Flag &&
2383 -1 == var.modifiers().fLayout.fBuiltin) {
2385 this->writeType(var.type());
2387 this->writeName(var.name());
2388 if (-1 != var.modifiers().fLayout.fLocation) {
2389 if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
2390 this->write(" [[attribute(" +
2391 std::to_string(var.modifiers().fLayout.fLocation) + ")]]");
2392 } else if (ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
2393 this->write(" [[user(locn" +
2394 std::to_string(var.modifiers().fLayout.fLocation) + ")]]");
2401 this->write("};\n");
2404 void MetalCodeGenerator::writeOutputStruct() {
2405 this->write("struct Outputs {\n");
2406 if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
2407 this->write(" float4 sk_Position [[position]];\n");
2408 } else if (ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
2409 this->write(" half4 sk_FragColor [[color(0)]];\n");
2411 for (const ProgramElement* e : fProgram.elements()) {
2412 if (e->is<GlobalVarDeclaration>()) {
2413 const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
2414 const Variable& var = decls.declaration()->as<VarDeclaration>().var();
2415 if (var.modifiers().fFlags & Modifiers::kOut_Flag &&
2416 -1 == var.modifiers().fLayout.fBuiltin) {
2418 this->writeType(var.type());
2420 this->writeName(var.name());
2422 int location = var.modifiers().fLayout.fLocation;
2424 fContext.fErrors->error(var.fPosition,
2425 "Metal out variables must have 'layout(location=...)'");
2426 } else if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
2427 this->write(" [[user(locn" + std::to_string(location) + ")]]");
2428 } else if (ProgramConfig::IsFragment(fProgram.fConfig->fKind)) {
2429 this->write(" [[color(" + std::to_string(location) + ")");
2430 int colorIndex = var.modifiers().fLayout.fIndex;
2432 this->write(", index(" + std::to_string(colorIndex) + ")");
2440 if (ProgramConfig::IsVertex(fProgram.fConfig->fKind)) {
2441 this->write(" float sk_PointSize [[point_size]];\n");
2443 this->write("};\n");
2446 void MetalCodeGenerator::writeInterfaceBlocks() {
2447 bool wroteInterfaceBlock = false;
2448 for (const ProgramElement* e : fProgram.elements()) {
2449 if (e->is<InterfaceBlock>()) {
2450 this->writeInterfaceBlock(e->as<InterfaceBlock>());
2451 wroteInterfaceBlock = true;
2454 if (!wroteInterfaceBlock && fProgram.fInputs.fUseFlipRTUniform) {
2455 this->writeLine("struct sksl_synthetic_uniforms {");
2456 this->writeLine(" float2 " SKSL_RTFLIP_NAME ";");
2457 this->writeLine("};");
2461 void MetalCodeGenerator::writeStructDefinitions() {
2462 for (const ProgramElement* e : fProgram.elements()) {
2463 if (e->is<StructDefinition>()) {
2464 this->writeStructDefinition(e->as<StructDefinition>());
2469 void MetalCodeGenerator::visitGlobalStruct(GlobalStructVisitor* visitor) {
2470 // Visit the interface blocks.
2471 for (const auto& [interfaceType, interfaceName] : fInterfaceBlockNameMap) {
2472 visitor->visitInterfaceBlock(*interfaceType, interfaceName);
2474 for (const ProgramElement* element : fProgram.elements()) {
2475 if (!element->is<GlobalVarDeclaration>()) {
2478 const GlobalVarDeclaration& global = element->as<GlobalVarDeclaration>();
2479 const VarDeclaration& decl = global.declaration()->as<VarDeclaration>();
2480 const Variable& var = decl.var();
2481 if (var.type().typeKind() == Type::TypeKind::kSampler) {
2482 // Samplers are represented as a "texture/sampler" duo in the global struct.
2483 visitor->visitTexture(var.type(), var.name());
2484 visitor->visitSampler(var.type(), std::string(var.name()) + SAMPLER_SUFFIX);
2488 if (!(var.modifiers().fFlags & ~Modifiers::kConst_Flag) &&
2489 -1 == var.modifiers().fLayout.fBuiltin) {
2490 // Visit a regular variable.
2491 visitor->visitVariable(var, decl.value().get());
2496 void MetalCodeGenerator::writeGlobalStruct() {
2497 class : public GlobalStructVisitor {
2499 void visitInterfaceBlock(const InterfaceBlock& block,
2500 std::string_view blockName) override {
2502 fCodeGen->write(" constant ");
2503 fCodeGen->write(block.typeName());
2504 fCodeGen->write("* ");
2505 fCodeGen->writeName(blockName);
2506 fCodeGen->write(";\n");
2508 void visitTexture(const Type& type, std::string_view name) override {
2510 fCodeGen->write(" ");
2511 fCodeGen->writeType(type);
2512 fCodeGen->write(" ");
2513 fCodeGen->writeName(name);
2514 fCodeGen->write(";\n");
2516 void visitSampler(const Type&, std::string_view name) override {
2518 fCodeGen->write(" sampler ");
2519 fCodeGen->writeName(name);
2520 fCodeGen->write(";\n");
2522 void visitVariable(const Variable& var, const Expression* value) override {
2524 fCodeGen->write(" ");
2525 fCodeGen->writeModifiers(var.modifiers());
2526 fCodeGen->writeType(var.type());
2527 fCodeGen->write(" ");
2528 fCodeGen->writeName(var.name());
2529 fCodeGen->write(";\n");
2533 fCodeGen->write("struct Globals {\n");
2539 fCodeGen->writeLine("};");
2544 MetalCodeGenerator* fCodeGen = nullptr;
2548 visitor.fCodeGen = this;
2549 this->visitGlobalStruct(&visitor);
2553 void MetalCodeGenerator::writeGlobalInit() {
2554 class : public GlobalStructVisitor {
2556 void visitInterfaceBlock(const InterfaceBlock& blockType,
2557 std::string_view blockName) override {
2559 fCodeGen->write("&");
2560 fCodeGen->writeName(blockName);
2562 void visitTexture(const Type&, std::string_view name) override {
2564 fCodeGen->writeName(name);
2566 void visitSampler(const Type&, std::string_view name) override {
2568 fCodeGen->writeName(name);
2570 void visitVariable(const Variable& var, const Expression* value) override {
2573 fCodeGen->writeVarInitializer(var, *value);
2575 fCodeGen->write("{}");
2580 fCodeGen->write(" Globals _globals{");
2583 fCodeGen->write(", ");
2588 fCodeGen->writeLine("};");
2589 fCodeGen->writeLine(" (void)_globals;");
2592 MetalCodeGenerator* fCodeGen = nullptr;
2596 visitor.fCodeGen = this;
2597 this->visitGlobalStruct(&visitor);
2601 void MetalCodeGenerator::writeProgramElement(const ProgramElement& e) {
2603 case ProgramElement::Kind::kExtension:
2605 case ProgramElement::Kind::kGlobalVar:
2607 case ProgramElement::Kind::kInterfaceBlock:
2608 // handled in writeInterfaceBlocks, do nothing
2610 case ProgramElement::Kind::kStructDefinition:
2611 // Handled in writeStructDefinitions. Do nothing.
2613 case ProgramElement::Kind::kFunction:
2614 this->writeFunction(e.as<FunctionDefinition>());
2616 case ProgramElement::Kind::kFunctionPrototype:
2617 this->writeFunctionPrototype(e.as<FunctionPrototype>());
2619 case ProgramElement::Kind::kModifiers:
2620 this->writeModifiers(e.as<ModifiersDeclaration>().modifiers());
2621 this->writeLine(";");
2624 SkDEBUGFAILF("unsupported program element: %s\n", e.description().c_str());
2629 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Expression* e) {
2631 return kNo_Requirements;
2633 switch (e->kind()) {
2634 case Expression::Kind::kFunctionCall: {
2635 const FunctionCall& f = e->as<FunctionCall>();
2636 Requirements result = this->requirements(f.function());
2637 for (const auto& arg : f.arguments()) {
2638 result |= this->requirements(arg.get());
2642 case Expression::Kind::kConstructorCompound:
2643 case Expression::Kind::kConstructorCompoundCast:
2644 case Expression::Kind::kConstructorArray:
2645 case Expression::Kind::kConstructorArrayCast:
2646 case Expression::Kind::kConstructorDiagonalMatrix:
2647 case Expression::Kind::kConstructorScalarCast:
2648 case Expression::Kind::kConstructorSplat:
2649 case Expression::Kind::kConstructorStruct: {
2650 const AnyConstructor& c = e->asAnyConstructor();
2651 Requirements result = kNo_Requirements;
2652 for (const auto& arg : c.argumentSpan()) {
2653 result |= this->requirements(arg.get());
2657 case Expression::Kind::kFieldAccess: {
2658 const FieldAccess& f = e->as<FieldAccess>();
2659 if (FieldAccess::OwnerKind::kAnonymousInterfaceBlock == f.ownerKind()) {
2660 return kGlobals_Requirement;
2662 return this->requirements(f.base().get());
2664 case Expression::Kind::kSwizzle:
2665 return this->requirements(e->as<Swizzle>().base().get());
2666 case Expression::Kind::kBinary: {
2667 const BinaryExpression& bin = e->as<BinaryExpression>();
2668 return this->requirements(bin.left().get()) |
2669 this->requirements(bin.right().get());
2671 case Expression::Kind::kIndex: {
2672 const IndexExpression& idx = e->as<IndexExpression>();
2673 return this->requirements(idx.base().get()) | this->requirements(idx.index().get());
2675 case Expression::Kind::kPrefix:
2676 return this->requirements(e->as<PrefixExpression>().operand().get());
2677 case Expression::Kind::kPostfix:
2678 return this->requirements(e->as<PostfixExpression>().operand().get());
2679 case Expression::Kind::kTernary: {
2680 const TernaryExpression& t = e->as<TernaryExpression>();
2681 return this->requirements(t.test().get()) | this->requirements(t.ifTrue().get()) |
2682 this->requirements(t.ifFalse().get());
2684 case Expression::Kind::kVariableReference: {
2685 const VariableReference& v = e->as<VariableReference>();
2686 const Modifiers& modifiers = v.variable()->modifiers();
2687 Requirements result = kNo_Requirements;
2688 if (modifiers.fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN) {
2689 result = kGlobals_Requirement | kFragCoord_Requirement;
2690 } else if (Variable::Storage::kGlobal == v.variable()->storage()) {
2691 if (modifiers.fFlags & Modifiers::kIn_Flag) {
2692 result = kInputs_Requirement;
2693 } else if (modifiers.fFlags & Modifiers::kOut_Flag) {
2694 result = kOutputs_Requirement;
2695 } else if (modifiers.fFlags & Modifiers::kUniform_Flag &&
2696 v.variable()->type().typeKind() != Type::TypeKind::kSampler) {
2697 result = kUniforms_Requirement;
2699 result = kGlobals_Requirement;
2705 return kNo_Requirements;
2709 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Statement* s) {
2711 return kNo_Requirements;
2713 switch (s->kind()) {
2714 case Statement::Kind::kBlock: {
2715 Requirements result = kNo_Requirements;
2716 for (const std::unique_ptr<Statement>& child : s->as<Block>().children()) {
2717 result |= this->requirements(child.get());
2721 case Statement::Kind::kVarDeclaration: {
2722 const VarDeclaration& var = s->as<VarDeclaration>();
2723 return this->requirements(var.value().get());
2725 case Statement::Kind::kExpression:
2726 return this->requirements(s->as<ExpressionStatement>().expression().get());
2727 case Statement::Kind::kReturn: {
2728 const ReturnStatement& r = s->as<ReturnStatement>();
2729 return this->requirements(r.expression().get());
2731 case Statement::Kind::kIf: {
2732 const IfStatement& i = s->as<IfStatement>();
2733 return this->requirements(i.test().get()) |
2734 this->requirements(i.ifTrue().get()) |
2735 this->requirements(i.ifFalse().get());
2737 case Statement::Kind::kFor: {
2738 const ForStatement& f = s->as<ForStatement>();
2739 return this->requirements(f.initializer().get()) |
2740 this->requirements(f.test().get()) |
2741 this->requirements(f.next().get()) |
2742 this->requirements(f.statement().get());
2744 case Statement::Kind::kDo: {
2745 const DoStatement& d = s->as<DoStatement>();
2746 return this->requirements(d.test().get()) |
2747 this->requirements(d.statement().get());
2749 case Statement::Kind::kSwitch: {
2750 const SwitchStatement& sw = s->as<SwitchStatement>();
2751 Requirements result = this->requirements(sw.value().get());
2752 for (const std::unique_ptr<Statement>& sc : sw.cases()) {
2753 result |= this->requirements(sc->as<SwitchCase>().statement().get());
2758 return kNo_Requirements;
2762 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const FunctionDeclaration& f) {
2763 Requirements* found = fRequirements.find(&f);
2765 fRequirements.set(&f, kNo_Requirements);
2766 for (const ProgramElement* e : fProgram.elements()) {
2767 if (e->is<FunctionDefinition>()) {
2768 const FunctionDefinition& def = e->as<FunctionDefinition>();
2769 if (&def.declaration() == &f) {
2770 Requirements reqs = this->requirements(def.body().get());
2771 fRequirements.set(&f, reqs);
2776 // We never found a definition for this declared function, but it's legal to prototype a
2777 // function without ever giving a definition, as long as you don't call it.
2778 return kNo_Requirements;
2783 bool MetalCodeGenerator::generateCode() {
2784 StringStream header;
2786 AutoOutputStream outputToHeader(this, &header, &fIndentation);
2787 this->writeHeader();
2788 this->writeStructDefinitions();
2789 this->writeUniformStruct();
2790 this->writeInputStruct();
2791 this->writeOutputStruct();
2792 this->writeInterfaceBlocks();
2793 this->writeGlobalStruct();
2797 AutoOutputStream outputToBody(this, &body, &fIndentation);
2798 for (const ProgramElement* e : fProgram.elements()) {
2799 this->writeProgramElement(*e);
2802 write_stringstream(header, *fOut);
2803 write_stringstream(fExtraFunctionPrototypes, *fOut);
2804 write_stringstream(fExtraFunctions, *fOut);
2805 write_stringstream(body, *fOut);
2806 fContext.fErrors->reportPendingErrors(Position());
2807 return fContext.fErrors->errorCount() == 0;