/*
 * Copyright 2016 Google Inc.
 *
 * Use of this source code is governed by a BSD-style license that can be
 * found in the LICENSE file.
 */
8 #ifndef SKSL_METALCODEGENERATOR
9 #define SKSL_METALCODEGENERATOR
11 #include "include/private/SkSLDefines.h"
12 #include "include/private/SkTArray.h"
13 #include "include/private/SkTHash.h"
14 #include "include/sksl/SkSLOperator.h"
15 #include "src/sksl/SkSLStringStream.h"
16 #include "src/sksl/codegen/SkSLCodeGenerator.h"
17 #include "src/sksl/ir/SkSLType.h"
20 #include <initializer_list>
23 #include <string_view>
// Forward declarations. Each of these types appears below only as a reference,
// pointer, or by-value enum parameter in a member declaration, so this header
// does not need their full definitions.
// NOTE(review): other types used below (e.g. FunctionCall, Swizzle, Literal) are
// not forward-declared in the visible lines -- presumably they come from lines
// not shown here or from an included header; confirm.
29 class BinaryExpression;
31 class ConstructorArrayCast;
32 class ConstructorCompound;
33 class ConstructorMatrixResize;
37 class ExpressionStatement;
42 class FunctionDeclaration;
43 class FunctionDefinition;
44 class FunctionPrototype;
50 class PostfixExpression;
51 class PrefixExpression;
53 class ReturnStatement;
56 class StructDefinition;
57 class SwitchStatement;
58 class TernaryExpression;
61 class VariableReference;
// Opaque enum declaration: an enum can only be forward-declared when its
// underlying type (here int8_t) is spelled out.
62 enum IntrinsicKind : int8_t;
// Declared with the `struct` key, unlike the `class` declarations above.
63 struct IndexExpression;
/**
 * Converts a Program into Metal code.
 */
72 class MetalCodeGenerator : public CodeGenerator {
// Fixed string fragments: the suffix "Smplr" and the prefix "packed_".
// NOTE(review): presumably appended to sampler names and prepended to packed
// type names in the generated Metal source -- confirm against the .cpp.
74 inline static constexpr const char* SAMPLER_SUFFIX = "Smplr";
75 inline static constexpr const char* PACKED_PREFIX = "packed_";
// Constructs a generator from a context, a program and an output stream.
// All three arguments are forwarded unchanged to the CodeGenerator base class
// (INHERITED is an alias for it). fReservedWords is seeded with a fixed list of
// identifiers and fLineEnding is set to "\n". The reserved words are string
// literals, so the string_views stored in the set refer to static storage.
// NOTE(review): the seeded words look like names that must not be reused as
// identifiers in emitted Metal -- confirm how fReservedWords is consulted.
77 MetalCodeGenerator(const Context* context, const Program* program, OutputStream* out)
78 : INHERITED(context, program, out)
79 , fReservedWords({"atan2", "rsqrt", "rint", "dfdx", "dfdy", "vertex", "fragment"})
80 , fLineEnding("\n") {}
// Overrides a virtual function of the base class (enforced by `override`).
// NOTE(review): the meaning of the bool result is not visible here --
// presumably true on success; confirm against CodeGenerator.
82 bool generateCode() override;
// Shorthand for Operator::Precedence; used as the `parentPrecedence` parameter
// type by the expression-writing members declared below.
85 using Precedence = Operator::Precedence;
// Bit-flag type returned by the requirements() overloads and stored in
// fRequirements. Every flag occupies its own bit, so flags can be combined with
// bitwise OR and tested with bitwise AND; kNo_Requirements is the empty set.
// Declared as a `using` alias to match the other aliases in this class
// (Precedence, INHERITED); the underlying type is still int.
using Requirements = int;
inline static constexpr Requirements kNo_Requirements = 0;
inline static constexpr Requirements kInputs_Requirement = 1 << 0;
inline static constexpr Requirements kOutputs_Requirement = 1 << 1;
inline static constexpr Requirements kUniforms_Requirement = 1 << 2;
inline static constexpr Requirements kGlobals_Requirement = 1 << 3;
inline static constexpr Requirements kFragCoord_Requirement = 1 << 4;
// Nested visitor type; only declared in the visible lines. visitGlobalStruct
// takes a (non-const) pointer to one.
// NOTE(review): which entities the visitor is invoked on is defined in the .cpp.
95 class GlobalStructVisitor;
96 void visitGlobalStruct(GlobalStructVisitor* visitor);
// Text emitters taking a string_view. writeLine's argument defaults to an empty
// string_view, so it can be called with no arguments.
// NOTE(review): presumably these are the low-level output primitives and honor
// fIndentation / fAtLineStart / fLineEnding -- confirm in the .cpp.
98 void write(std::string_view s);
100 void writeLine(std::string_view s = std::string_view());
// Argument-less emitters; judging by their names they produce the uniform,
// input and output structs, the interface blocks and the struct definitions.
// NOTE(review): what exactly is emitted, and in which order these are called,
// is not visible in this header.
106 void writeUniformStruct();
108 void writeInputStruct();
110 void writeOutputStruct();
112 void writeInterfaceBlocks();
114 void writeStructDefinitions();
// Takes a list of struct fields and a source position. parentIntf is optional
// and defaults to nullptr.
116 void writeFields(const std::vector<Type::Field>& fields, Position pos,
117 const InterfaceBlock* parentIntf = nullptr);
// Const queries returning an int for a type; isPacked selects between two
// variants of the answer.
// NOTE(review): presumably byte size / byte alignment for packed vs. unpacked
// Metal layout -- units and rules are not shown here; confirm.
119 int size(const Type* type, bool isPacked) const;
121 int alignment(const Type* type, bool isPacked) const;
// More argument-less emitters (global struct, its initialization, precision
// modifier, judging by name only).
123 void writeGlobalStruct();
125 void writeGlobalInit();
127 void writePrecisionModifier();
// Returns a name for `type` as an owning std::string; writeType takes the same
// kind of argument but returns nothing.
// NOTE(review): presumably writeType emits what typeName returns -- confirm.
129 std::string typeName(const Type& type);
131 void writeStructDefinition(const StructDefinition& s);
133 void writeType(const Type& type);
135 void writeExtension(const Extension& ext);
137 void writeInterfaceBlock(const InterfaceBlock& intf);
// Both take `separator` as a reference to a pointer, so the callee is able to
// change which string the caller uses as the next separator.
// NOTE(review): presumably these append the implicit parameters / arguments
// implied by requirements(f) -- confirm in the .cpp.
139 void writeFunctionRequirementParams(const FunctionDeclaration& f,
140 const char*& separator);
142 void writeFunctionRequirementArgs(const FunctionDeclaration& f, const char*& separator);
// NOTE(review): the meaning of the bool result is not visible here --
// presumably false when the declaration could not be emitted; confirm.
144 bool writeFunctionDeclaration(const FunctionDeclaration& f);
146 void writeFunction(const FunctionDefinition& f);
148 void writeFunctionPrototype(const FunctionPrototype& f);
150 void writeLayout(const Layout& layout);
152 void writeModifiers(const Modifiers& modifiers);
154 void writeVarInitializer(const Variable& var, const Expression& value);
// NOTE(review): presumably the point where fReservedWords is checked so that
// emitted identifiers do not collide with Metal names -- confirm.
156 void writeName(std::string_view name);
158 void writeVarDeclaration(const VarDeclaration& decl);
160 void writeFragCoord();
162 void writeVariableReference(const VariableReference& ref);
// Expression emitter; the caller supplies the precedence of the enclosing
// expression.
// NOTE(review): presumably used to decide whether parentheses are needed.
164 void writeExpression(const Expression& expr, Precedence parentPrecedence);
// Unlike the other emitters, both expressions are taken by non-const reference.
// NOTE(review): whether they are actually mutated is not visible here.
166 void writeMinAbsHack(Expression& absExpr, Expression& otherExpr);
// The get*() members below return an owning std::string.
// NOTE(review): presumably the name of a helper function / temporary that is
// emitted on demand (see fExtraFunctions, fHelpers) -- confirm in the .cpp.
168 std::string getOutParamHelper(const FunctionCall& c,
169 const ExpressionArray& arguments,
170 const SkTArray<VariableReference*>& outVars);
172 std::string getInversePolyfill(const ExpressionArray& arguments);
174 std::string getBitcastIntrinsic(const Type& outType);
176 std::string getTempVariable(const Type& varType);
178 void writeFunctionCall(const FunctionCall& c);
// Matrix-construction support: a predicate on a compound constructor, a
// string-returning helper lookup, and two assemble* members parameterized by
// the target row and column counts.
// NOTE(review): presumably the predicate decides whether the constructor can be
// emitted directly or must go through the helper named by
// getMatrixConstructHelper -- confirm.
180 bool matrixConstructHelperIsNeeded(const ConstructorCompound& c);
181 std::string getMatrixConstructHelper(const AnyConstructor& c);
182 void assembleMatrixFromMatrix(const Type& sourceMatrix, int rows, int columns);
183 void assembleMatrixFromExpressions(const AnyConstructor& ctor, int rows, int columns);
// Helper emitters for matrix operations, keyed by the operand types they take.
// NOTE(review): judging by name these write polyfill functions for operations
// Metal lacks; whether each is emitted at most once is not visible here.
185 void writeMatrixCompMult();
187 void writeOuterProduct();
189 void writeMatrixTimesEqualHelper(const Type& left, const Type& right, const Type& result);
191 void writeMatrixDivisionHelpers(const Type& type);
193 void writeMatrixEqualityHelpers(const Type& left, const Type& right);
195 std::string getVectorFromMat2x2ConstructorHelper(const Type& matrixType);
// Equality-helper emitters for arrays, structs, and a general two-type form.
197 void writeArrayEqualityHelpers(const Type& type);
199 void writeStructEqualityHelpers(const Type& type);
201 void writeEqualityHelpers(const Type& leftType, const Type& rightType);
203 void writeArgumentList(const ExpressionArray& arguments);
205 void writeSimpleIntrinsic(const FunctionCall& c);
// Takes the call plus its IntrinsicKind.
// NOTE(review): the meaning of the bool result is not visible here --
// presumably whether the call was handled as an intrinsic; confirm.
207 bool writeIntrinsicCall(const FunctionCall& c, IntrinsicKind kind);
// Constructor emitters. Each receives the precedence of the enclosing
// expression, like writeExpression.
209 void writeConstructorCompound(const ConstructorCompound& c, Precedence parentPrecedence);
211 void writeConstructorCompoundVector(const ConstructorCompound& c, Precedence parentPrecedence);
213 void writeConstructorCompoundMatrix(const ConstructorCompound& c, Precedence parentPrecedence);
215 void writeConstructorMatrixResize(const ConstructorMatrixResize& c,
216 Precedence parentPrecedence);
// These two additionally take the opening and closing bracket strings from the
// caller.
// NOTE(review): presumably "(" / ")" versus "{" / "}" depending on the
// constructed type -- confirm at the call sites.
218 void writeAnyConstructor(const AnyConstructor& c,
219 const char* leftBracket,
220 const char* rightBracket,
221 Precedence parentPrecedence);
223 void writeCastConstructor(const AnyConstructor& c,
224 const char* leftBracket,
225 const char* rightBracket,
226 Precedence parentPrecedence);
228 void writeConstructorArrayCast(const ConstructorArrayCast& c, Precedence parentPrecedence);
230 void writeFieldAccess(const FieldAccess& f);
232 void writeSwizzle(const Swizzle& swizzle);
234 // Splats a scalar expression across a matrix of arbitrary size.
235 void writeNumberAsMatrix(const Expression& expr, const Type& matrixType);
// Per-node expression emitters. The binary, ternary, prefix and postfix forms
// receive the precedence of the enclosing expression; the index, literal and
// setting forms do not.
237 void writeBinaryExpression(const BinaryExpression& b, Precedence parentPrecedence);
239 void writeTernaryExpression(const TernaryExpression& t, Precedence parentPrecedence);
241 void writeIndexExpression(const IndexExpression& expr);
243 void writePrefixExpression(const PrefixExpression& p, Precedence parentPrecedence);
245 void writePostfixExpression(const PostfixExpression& p, Precedence parentPrecedence);
247 void writeLiteral(const Literal& f);
249 void writeSetting(const Setting& s);
// Statement emitters: a general entry point taking any Statement, a list form,
// and one member per concrete statement type.
// NOTE(review): presumably writeStatement dispatches to the specific members
// below based on the statement kind -- confirm in the .cpp.
251 void writeStatement(const Statement& s);
253 void writeStatements(const StatementArray& statements);
255 void writeBlock(const Block& b);
257 void writeIfStatement(const IfStatement& stmt);
259 void writeForStatement(const ForStatement& f);
261 void writeDoStatement(const DoStatement& d);
263 void writeExpressionStatement(const ExpressionStatement& s);
265 void writeSwitchStatement(const SwitchStatement& s);
// Takes no ReturnStatement argument, unlike writeReturnStatement below.
// NOTE(review): presumably emits the fixed return used at the end of main().
267 void writeReturnStatementFromMain();
269 void writeReturnStatement(const ReturnStatement& r);
271 void writeProgramElement(const ProgramElement& e);
// Overload set returning a Requirements bitmask for a function declaration, an
// expression, or a statement. The expression and statement overloads take
// pointers rather than references.
// NOTE(review): whether a null pointer is accepted, and whether results are
// cached in fRequirements, is not visible in this header.
273 Requirements requirements(const FunctionDeclaration& f);
275 Requirements requirements(const Expression* e);
277 Requirements requirements(const Statement* s);
// Derive an int from a set of modifiers.
// NOTE(review): presumably the layout binding / set index of a uniform;
// behavior when the layout does not specify one is not shown here.
279 int getUniformBinding(const Modifiers& m);
281 int getUniformSet(const Modifiers& m);
// Seeded by the constructor with a fixed list of string literals.
283 SkTHashSet<std::string_view> fReservedWords;
// Maps a struct field to an interface block.
// NOTE(review): presumably the block that contains the field -- confirm.
284 SkTHashMap<const Type::Field*, const InterfaceBlock*> fInterfaceBlockMap;
// Maps an interface block to a name. The values are non-owning string_views, so
// the character data they refer to must outlive this map.
285 SkTHashMap<const InterfaceBlock*, std::string_view> fInterfaceBlockNameMap;
// Counters, both starting at 0.
// NOTE(review): presumably used to generate unique names for anonymous
// interface blocks and padding fields -- confirm.
286 int fAnonInterfaceCount = 0;
287 int fPaddingCount = 0;
// Set to "\n" by the constructor.
288 const char* fLineEnding;
289 std::string fFunctionHeader;
// Side streams.
// NOTE(review): presumably collect helper functions and their prototypes for
// output separate from the main stream -- confirm.
290 StringStream fExtraFunctions;
291 StringStream fExtraFunctionPrototypes;
// Output formatting state; indentation starts at 0 and fAtLineStart at false.
293 int fIndentation = 0;
294 bool fAtLineStart = false;
// Set of string names.
// NOTE(review): uses std::set, whereas fHelpers below uses SkTHashSet for what
// looks like a similar purpose.
295 std::set<std::string> fWrittenIntrinsics;
296 // true if we have run into usages of dFdx / dFdy
297 bool fFoundDerivatives = false;
// Maps a function declaration to a Requirements bitmask.
// NOTE(review): presumably a cache for requirements(const FunctionDeclaration&).
298 SkTHashMap<const FunctionDeclaration*, Requirements> fRequirements;
299 SkTHashSet<std::string> fHelpers;
// Starts at -1.
// NOTE(review): presumably meaning "no uniform buffer assigned yet" -- confirm.
300 int fUniformBuffer = -1;
301 std::string fRTFlipName;
// Null until assigned.
// NOTE(review): presumably the function whose body is currently being written.
302 const FunctionDeclaration* fCurrentFunction = nullptr;
303 int fSwizzleHelperCount = 0;
304 bool fIgnoreVariableReferenceModifiers = false;
// Alias for the base class; the constructor's initializer list uses it to
// forward its arguments to CodeGenerator.
306 using INHERITED = CodeGenerator;