247766f51e5897dc39ed341573fdfa08835d0926
[platform/upstream/libSkiaSharp.git] / src / sksl / SkSLIRGenerator.cpp
1 /*
2  * Copyright 2016 Google Inc.
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7
8 #include "SkSLIRGenerator.h"
9
10 #include "limits.h"
11
12 #include "SkSLCompiler.h"
13 #include "ast/SkSLASTBoolLiteral.h"
14 #include "ast/SkSLASTFieldSuffix.h"
15 #include "ast/SkSLASTFloatLiteral.h"
16 #include "ast/SkSLASTIndexSuffix.h"
17 #include "ast/SkSLASTIntLiteral.h"
18 #include "ir/SkSLBinaryExpression.h"
19 #include "ir/SkSLBoolLiteral.h"
20 #include "ir/SkSLBreakStatement.h"
21 #include "ir/SkSLConstructor.h"
22 #include "ir/SkSLContinueStatement.h"
23 #include "ir/SkSLDiscardStatement.h"
24 #include "ir/SkSLDoStatement.h"
25 #include "ir/SkSLExpressionStatement.h"
26 #include "ir/SkSLField.h"
27 #include "ir/SkSLFieldAccess.h"
28 #include "ir/SkSLFloatLiteral.h"
29 #include "ir/SkSLForStatement.h"
30 #include "ir/SkSLFunctionCall.h"
31 #include "ir/SkSLFunctionDeclaration.h"
32 #include "ir/SkSLFunctionDefinition.h"
33 #include "ir/SkSLFunctionReference.h"
34 #include "ir/SkSLIfStatement.h"
35 #include "ir/SkSLIndexExpression.h"
36 #include "ir/SkSLInterfaceBlock.h"
37 #include "ir/SkSLIntLiteral.h"
38 #include "ir/SkSLLayout.h"
39 #include "ir/SkSLPostfixExpression.h"
40 #include "ir/SkSLPrefixExpression.h"
41 #include "ir/SkSLReturnStatement.h"
42 #include "ir/SkSLSwizzle.h"
43 #include "ir/SkSLTernaryExpression.h"
44 #include "ir/SkSLUnresolvedFunction.h"
45 #include "ir/SkSLVariable.h"
46 #include "ir/SkSLVarDeclarations.h"
47 #include "ir/SkSLVarDeclarationsStatement.h"
48 #include "ir/SkSLVariableReference.h"
49 #include "ir/SkSLWhileStatement.h"
50
51 namespace SkSL {
52
53 class AutoSymbolTable {
54 public:
55     AutoSymbolTable(IRGenerator* ir)
56     : fIR(ir)
57     , fPrevious(fIR->fSymbolTable) {
58         fIR->pushSymbolTable();
59     }
60
61     ~AutoSymbolTable() {
62         fIR->popSymbolTable();
63         ASSERT(fPrevious == fIR->fSymbolTable);
64     }
65
66     IRGenerator* fIR;
67     std::shared_ptr<SymbolTable> fPrevious;
68 };
69
70 class AutoLoopLevel {
71 public:
72     AutoLoopLevel(IRGenerator* ir)
73     : fIR(ir) {
74         fIR->fLoopLevel++;
75     }
76
77     ~AutoLoopLevel() {
78         fIR->fLoopLevel--;
79     }
80
81     IRGenerator* fIR;
82 };
83
84 IRGenerator::IRGenerator(const Context* context, std::shared_ptr<SymbolTable> symbolTable,
85                          ErrorReporter& errorReporter)
86 : fContext(*context)
87 , fCurrentFunction(nullptr)
88 , fSymbolTable(std::move(symbolTable))
89 , fLoopLevel(0)
90 , fErrors(errorReporter) {}
91
92 void IRGenerator::pushSymbolTable() {
93     fSymbolTable.reset(new SymbolTable(std::move(fSymbolTable), fErrors));
94 }
95
96 void IRGenerator::popSymbolTable() {
97     fSymbolTable = fSymbolTable->fParent;
98 }
99
100 static void fill_caps(const GrShaderCaps& caps, std::unordered_map<SkString, CapValue>* capsMap) {
101 #define CAP(name) capsMap->insert(std::make_pair(SkString(#name), CapValue(caps.name())));
102     CAP(fbFetchSupport);
103     CAP(fbFetchNeedsCustomOutput);
104     CAP(bindlessTextureSupport);
105     CAP(dropsTileOnZeroDivide);
106     CAP(flatInterpolationSupport);
107     CAP(noperspectiveInterpolationSupport);
108     CAP(multisampleInterpolationSupport);
109     CAP(sampleVariablesSupport);
110     CAP(sampleMaskOverrideCoverageSupport);
111     CAP(externalTextureSupport);
112     CAP(texelFetchSupport);
113     CAP(imageLoadStoreSupport);
114     CAP(mustEnableAdvBlendEqs);
115     CAP(mustEnableSpecificAdvBlendEqs);
116     CAP(mustDeclareFragmentShaderOutput);
117     CAP(canUseAnyFunctionInShader);
118 #undef CAP
119 }
120
121 void IRGenerator::start(const Program::Settings* settings) {
122     fSettings = settings;
123     fCapsMap.clear();
124     if (settings->fCaps) {
125         fill_caps(*settings->fCaps, &fCapsMap);
126     }
127     this->pushSymbolTable();
128     fInputs.reset();
129 }
130
131 void IRGenerator::finish() {
132     this->popSymbolTable();
133     fSettings = nullptr;
134 }
135
136 std::unique_ptr<Extension> IRGenerator::convertExtension(const ASTExtension& extension) {
137     return std::unique_ptr<Extension>(new Extension(extension.fPosition, extension.fName));
138 }
139
140 std::unique_ptr<Statement> IRGenerator::convertStatement(const ASTStatement& statement) {
141     switch (statement.fKind) {
142         case ASTStatement::kBlock_Kind:
143             return this->convertBlock((ASTBlock&) statement);
144         case ASTStatement::kVarDeclaration_Kind:
145             return this->convertVarDeclarationStatement((ASTVarDeclarationStatement&) statement);
146         case ASTStatement::kExpression_Kind:
147             return this->convertExpressionStatement((ASTExpressionStatement&) statement);
148         case ASTStatement::kIf_Kind:
149             return this->convertIf((ASTIfStatement&) statement);
150         case ASTStatement::kFor_Kind:
151             return this->convertFor((ASTForStatement&) statement);
152         case ASTStatement::kWhile_Kind:
153             return this->convertWhile((ASTWhileStatement&) statement);
154         case ASTStatement::kDo_Kind:
155             return this->convertDo((ASTDoStatement&) statement);
156         case ASTStatement::kReturn_Kind:
157             return this->convertReturn((ASTReturnStatement&) statement);
158         case ASTStatement::kBreak_Kind:
159             return this->convertBreak((ASTBreakStatement&) statement);
160         case ASTStatement::kContinue_Kind:
161             return this->convertContinue((ASTContinueStatement&) statement);
162         case ASTStatement::kDiscard_Kind:
163             return this->convertDiscard((ASTDiscardStatement&) statement);
164         default:
165             ABORT("unsupported statement type: %d\n", statement.fKind);
166     }
167 }
168
169 std::unique_ptr<Block> IRGenerator::convertBlock(const ASTBlock& block) {
170     AutoSymbolTable table(this);
171     std::vector<std::unique_ptr<Statement>> statements;
172     for (size_t i = 0; i < block.fStatements.size(); i++) {
173         std::unique_ptr<Statement> statement = this->convertStatement(*block.fStatements[i]);
174         if (!statement) {
175             return nullptr;
176         }
177         statements.push_back(std::move(statement));
178     }
179     return std::unique_ptr<Block>(new Block(block.fPosition, std::move(statements), fSymbolTable));
180 }
181
182 std::unique_ptr<Statement> IRGenerator::convertVarDeclarationStatement(
183                                                               const ASTVarDeclarationStatement& s) {
184     auto decl = this->convertVarDeclarations(*s.fDeclarations, Variable::kLocal_Storage);
185     if (!decl) {
186         return nullptr;
187     }
188     return std::unique_ptr<Statement>(new VarDeclarationsStatement(std::move(decl)));
189 }
190
191 std::unique_ptr<VarDeclarations> IRGenerator::convertVarDeclarations(const ASTVarDeclarations& decl,
192                                                                      Variable::Storage storage) {
193     std::vector<VarDeclaration> variables;
194     const Type* baseType = this->convertType(*decl.fType);
195     if (!baseType) {
196         return nullptr;
197     }
198     for (const auto& varDecl : decl.fVars) {
199         const Type* type = baseType;
200         std::vector<std::unique_ptr<Expression>> sizes;
201         for (const auto& rawSize : varDecl.fSizes) {
202             if (rawSize) {
203                 auto size = this->coerce(this->convertExpression(*rawSize), *fContext.fInt_Type);
204                 if (!size) {
205                     return nullptr;
206                 }
207                 SkString name = type->fName;
208                 int64_t count;
209                 if (size->fKind == Expression::kIntLiteral_Kind) {
210                     count = ((IntLiteral&) *size).fValue;
211                     if (count <= 0) {
212                         fErrors.error(size->fPosition, "array size must be positive");
213                     }
214                     name += "[" + to_string(count) + "]";
215                 } else {
216                     count = -1;
217                     name += "[]";
218                 }
219                 type = new Type(name, Type::kArray_Kind, *type, (int) count);
220                 fSymbolTable->takeOwnership((Type*) type);
221                 sizes.push_back(std::move(size));
222             } else {
223                 type = new Type(type->fName + "[]", Type::kArray_Kind, *type, -1);
224                 fSymbolTable->takeOwnership((Type*) type);
225                 sizes.push_back(nullptr);
226             }
227         }
228         auto var = std::unique_ptr<Variable>(new Variable(decl.fPosition, decl.fModifiers,
229                                                           varDecl.fName, *type, storage));
230         std::unique_ptr<Expression> value;
231         if (varDecl.fValue) {
232             value = this->convertExpression(*varDecl.fValue);
233             if (!value) {
234                 return nullptr;
235             }
236             value = this->coerce(std::move(value), *type);
237         }
238         if (storage == Variable::kGlobal_Storage && varDecl.fName == SkString("sk_FragColor") &&
239             (*fSymbolTable)[varDecl.fName]) {
240             // already defined, ignore
241         } else if (storage == Variable::kGlobal_Storage && (*fSymbolTable)[varDecl.fName] &&
242                    (*fSymbolTable)[varDecl.fName]->fKind == Symbol::kVariable_Kind &&
243                    ((Variable*) (*fSymbolTable)[varDecl.fName])->fModifiers.fLayout.fBuiltin >= 0) {
244             // already defined, just update the modifiers
245             Variable* old = (Variable*) (*fSymbolTable)[varDecl.fName];
246             old->fModifiers = var->fModifiers;
247         } else {
248             variables.emplace_back(var.get(), std::move(sizes), std::move(value));
249             fSymbolTable->add(varDecl.fName, std::move(var));
250         }
251     }
252     return std::unique_ptr<VarDeclarations>(new VarDeclarations(decl.fPosition,
253                                                                 baseType,
254                                                                 std::move(variables)));
255 }
256
257 std::unique_ptr<ModifiersDeclaration> IRGenerator::convertModifiersDeclaration(
258                                                                  const ASTModifiersDeclaration& m) {
259     return std::unique_ptr<ModifiersDeclaration>(new ModifiersDeclaration(m.fModifiers));
260 }
261
262 std::unique_ptr<Statement> IRGenerator::convertIf(const ASTIfStatement& s) {
263     std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*s.fTest),
264                                                     *fContext.fBool_Type);
265     if (!test) {
266         return nullptr;
267     }
268     std::unique_ptr<Statement> ifTrue = this->convertStatement(*s.fIfTrue);
269     if (!ifTrue) {
270         return nullptr;
271     }
272     std::unique_ptr<Statement> ifFalse;
273     if (s.fIfFalse) {
274         ifFalse = this->convertStatement(*s.fIfFalse);
275         if (!ifFalse) {
276             return nullptr;
277         }
278     }
279     if (test->fKind == Expression::kBoolLiteral_Kind) {
280         // static boolean value, fold down to a single branch
281         if (((BoolLiteral&) *test).fValue) {
282             return ifTrue;
283         } else if (s.fIfFalse) {
284             return ifFalse;
285         } else {
286             // False & no else clause. Not an error, so don't return null!
287             std::vector<std::unique_ptr<Statement>> empty;
288             return std::unique_ptr<Statement>(new Block(s.fPosition, std::move(empty),
289                                                         fSymbolTable));
290         }
291     }
292     return std::unique_ptr<Statement>(new IfStatement(s.fPosition, std::move(test),
293                                                       std::move(ifTrue), std::move(ifFalse)));
294 }
295
296 std::unique_ptr<Statement> IRGenerator::convertFor(const ASTForStatement& f) {
297     AutoLoopLevel level(this);
298     AutoSymbolTable table(this);
299     std::unique_ptr<Statement> initializer;
300     if (f.fInitializer) {
301         initializer = this->convertStatement(*f.fInitializer);
302         if (!initializer) {
303             return nullptr;
304         }
305     }
306     std::unique_ptr<Expression> test;
307     if (f.fTest) {
308         test = this->coerce(this->convertExpression(*f.fTest), *fContext.fBool_Type);
309         if (!test) {
310             return nullptr;
311         }
312     }
313     std::unique_ptr<Expression> next;
314     if (f.fNext) {
315         next = this->convertExpression(*f.fNext);
316         if (!next) {
317             return nullptr;
318         }
319         this->checkValid(*next);
320     }
321     std::unique_ptr<Statement> statement = this->convertStatement(*f.fStatement);
322     if (!statement) {
323         return nullptr;
324     }
325     return std::unique_ptr<Statement>(new ForStatement(f.fPosition, std::move(initializer),
326                                                        std::move(test), std::move(next),
327                                                        std::move(statement), fSymbolTable));
328 }
329
330 std::unique_ptr<Statement> IRGenerator::convertWhile(const ASTWhileStatement& w) {
331     AutoLoopLevel level(this);
332     std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*w.fTest),
333                                                     *fContext.fBool_Type);
334     if (!test) {
335         return nullptr;
336     }
337     std::unique_ptr<Statement> statement = this->convertStatement(*w.fStatement);
338     if (!statement) {
339         return nullptr;
340     }
341     return std::unique_ptr<Statement>(new WhileStatement(w.fPosition, std::move(test),
342                                                          std::move(statement)));
343 }
344
345 std::unique_ptr<Statement> IRGenerator::convertDo(const ASTDoStatement& d) {
346     AutoLoopLevel level(this);
347     std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*d.fTest),
348                                                     *fContext.fBool_Type);
349     if (!test) {
350         return nullptr;
351     }
352     std::unique_ptr<Statement> statement = this->convertStatement(*d.fStatement);
353     if (!statement) {
354         return nullptr;
355     }
356     return std::unique_ptr<Statement>(new DoStatement(d.fPosition, std::move(statement),
357                                                       std::move(test)));
358 }
359
360 std::unique_ptr<Statement> IRGenerator::convertExpressionStatement(
361                                                                   const ASTExpressionStatement& s) {
362     std::unique_ptr<Expression> e = this->convertExpression(*s.fExpression);
363     if (!e) {
364         return nullptr;
365     }
366     this->checkValid(*e);
367     return std::unique_ptr<Statement>(new ExpressionStatement(std::move(e)));
368 }
369
370 std::unique_ptr<Statement> IRGenerator::convertReturn(const ASTReturnStatement& r) {
371     ASSERT(fCurrentFunction);
372     if (r.fExpression) {
373         std::unique_ptr<Expression> result = this->convertExpression(*r.fExpression);
374         if (!result) {
375             return nullptr;
376         }
377         if (fCurrentFunction->fReturnType == *fContext.fVoid_Type) {
378             fErrors.error(result->fPosition, "may not return a value from a void function");
379         } else {
380             result = this->coerce(std::move(result), fCurrentFunction->fReturnType);
381             if (!result) {
382                 return nullptr;
383             }
384         }
385         return std::unique_ptr<Statement>(new ReturnStatement(std::move(result)));
386     } else {
387         if (fCurrentFunction->fReturnType != *fContext.fVoid_Type) {
388             fErrors.error(r.fPosition, "expected function to return '" +
389                                        fCurrentFunction->fReturnType.description() + "'");
390         }
391         return std::unique_ptr<Statement>(new ReturnStatement(r.fPosition));
392     }
393 }
394
395 std::unique_ptr<Statement> IRGenerator::convertBreak(const ASTBreakStatement& b) {
396     if (fLoopLevel > 0) {
397         return std::unique_ptr<Statement>(new BreakStatement(b.fPosition));
398     } else {
399         fErrors.error(b.fPosition, "break statement must be inside a loop");
400         return nullptr;
401     }
402 }
403
404 std::unique_ptr<Statement> IRGenerator::convertContinue(const ASTContinueStatement& c) {
405     if (fLoopLevel > 0) {
406         return std::unique_ptr<Statement>(new ContinueStatement(c.fPosition));
407     } else {
408         fErrors.error(c.fPosition, "continue statement must be inside a loop");
409         return nullptr;
410     }
411 }
412
413 std::unique_ptr<Statement> IRGenerator::convertDiscard(const ASTDiscardStatement& d) {
414     return std::unique_ptr<Statement>(new DiscardStatement(d.fPosition));
415 }
416
417 std::unique_ptr<FunctionDefinition> IRGenerator::convertFunction(const ASTFunction& f) {
418     const Type* returnType = this->convertType(*f.fReturnType);
419     if (!returnType) {
420         return nullptr;
421     }
422     std::vector<const Variable*> parameters;
423     for (const auto& param : f.fParameters) {
424         const Type* type = this->convertType(*param->fType);
425         if (!type) {
426             return nullptr;
427         }
428         for (int j = (int) param->fSizes.size() - 1; j >= 0; j--) {
429             int size = param->fSizes[j];
430             SkString name = type->name() + "[" + to_string(size) + "]";
431             Type* newType = new Type(std::move(name), Type::kArray_Kind, *type, size);
432             fSymbolTable->takeOwnership(newType);
433             type = newType;
434         }
435         SkString name = param->fName;
436         Position pos = param->fPosition;
437         Variable* var = new Variable(pos, param->fModifiers, std::move(name), *type,
438                                      Variable::kParameter_Storage);
439         fSymbolTable->takeOwnership(var);
440         parameters.push_back(var);
441     }
442
443     // find existing declaration
444     const FunctionDeclaration* decl = nullptr;
445     auto entry = (*fSymbolTable)[f.fName];
446     if (entry) {
447         std::vector<const FunctionDeclaration*> functions;
448         switch (entry->fKind) {
449             case Symbol::kUnresolvedFunction_Kind:
450                 functions = ((UnresolvedFunction*) entry)->fFunctions;
451                 break;
452             case Symbol::kFunctionDeclaration_Kind:
453                 functions.push_back((FunctionDeclaration*) entry);
454                 break;
455             default:
456                 fErrors.error(f.fPosition, "symbol '" + f.fName + "' was already defined");
457                 return nullptr;
458         }
459         for (const auto& other : functions) {
460             ASSERT(other->fName == f.fName);
461             if (parameters.size() == other->fParameters.size()) {
462                 bool match = true;
463                 for (size_t i = 0; i < parameters.size(); i++) {
464                     if (parameters[i]->fType != other->fParameters[i]->fType) {
465                         match = false;
466                         break;
467                     }
468                 }
469                 if (match) {
470                     if (*returnType != other->fReturnType) {
471                         FunctionDeclaration newDecl(f.fPosition, f.fName, parameters, *returnType);
472                         fErrors.error(f.fPosition, "functions '" + newDecl.description() +
473                                                    "' and '" + other->description() +
474                                                    "' differ only in return type");
475                         return nullptr;
476                     }
477                     decl = other;
478                     for (size_t i = 0; i < parameters.size(); i++) {
479                         if (parameters[i]->fModifiers != other->fParameters[i]->fModifiers) {
480                             fErrors.error(f.fPosition, "modifiers on parameter " +
481                                                        to_string((uint64_t) i + 1) +
482                                                        " differ between declaration and "
483                                                        "definition");
484                             return nullptr;
485                         }
486                     }
487                     if (other->fDefined) {
488                         fErrors.error(f.fPosition, "duplicate definition of " +
489                                                    other->description());
490                     }
491                     break;
492                 }
493             }
494         }
495     }
496     if (!decl) {
497         // couldn't find an existing declaration
498         auto newDecl = std::unique_ptr<FunctionDeclaration>(new FunctionDeclaration(f.fPosition,
499                                                                                     f.fName,
500                                                                                     parameters,
501                                                                                     *returnType));
502         decl = newDecl.get();
503         fSymbolTable->add(decl->fName, std::move(newDecl));
504     }
505     if (f.fBody) {
506         ASSERT(!fCurrentFunction);
507         fCurrentFunction = decl;
508         decl->fDefined = true;
509         std::shared_ptr<SymbolTable> old = fSymbolTable;
510         AutoSymbolTable table(this);
511         for (size_t i = 0; i < parameters.size(); i++) {
512             fSymbolTable->addWithoutOwnership(parameters[i]->fName, decl->fParameters[i]);
513         }
514         std::unique_ptr<Block> body = this->convertBlock(*f.fBody);
515         fCurrentFunction = nullptr;
516         if (!body) {
517             return nullptr;
518         }
519         return std::unique_ptr<FunctionDefinition>(new FunctionDefinition(f.fPosition, *decl,
520                                                                           std::move(body)));
521     }
522     return nullptr;
523 }
524
525 std::unique_ptr<InterfaceBlock> IRGenerator::convertInterfaceBlock(const ASTInterfaceBlock& intf) {
526     std::shared_ptr<SymbolTable> old = fSymbolTable;
527     AutoSymbolTable table(this);
528     std::vector<Type::Field> fields;
529     for (size_t i = 0; i < intf.fDeclarations.size(); i++) {
530         std::unique_ptr<VarDeclarations> decl = this->convertVarDeclarations(
531                                                                          *intf.fDeclarations[i],
532                                                                          Variable::kGlobal_Storage);
533         if (!decl) {
534             return nullptr;
535         }
536         for (const auto& var : decl->fVars) {
537             fields.push_back(Type::Field(var.fVar->fModifiers, var.fVar->fName,
538                                          &var.fVar->fType));
539             if (var.fValue) {
540                 fErrors.error(decl->fPosition,
541                               "initializers are not permitted on interface block fields");
542             }
543             if (var.fVar->fModifiers.fFlags & (Modifiers::kIn_Flag |
544                                                Modifiers::kOut_Flag |
545                                                Modifiers::kUniform_Flag |
546                                                Modifiers::kConst_Flag)) {
547                 fErrors.error(decl->fPosition,
548                               "interface block fields may not have storage qualifiers");
549             }
550         }
551     }
552     Type* type = new Type(intf.fPosition, intf.fTypeName, fields);
553     old->takeOwnership(type);
554     std::vector<std::unique_ptr<Expression>> sizes;
555     for (const auto& size : intf.fSizes) {
556         if (size) {
557             std::unique_ptr<Expression> converted = this->convertExpression(*size);
558             if (!converted) {
559                 return nullptr;
560             }
561             SkString name = type->fName;
562             int64_t count;
563             if (converted->fKind == Expression::kIntLiteral_Kind) {
564                 count = ((IntLiteral&) *converted).fValue;
565                 if (count <= 0) {
566                     fErrors.error(converted->fPosition, "array size must be positive");
567                 }
568                 name += "[" + to_string(count) + "]";
569             } else {
570                 count = -1;
571                 name += "[]";
572             }
573             type = new Type(name, Type::kArray_Kind, *type, (int) count);
574             fSymbolTable->takeOwnership((Type*) type);
575             sizes.push_back(std::move(converted));
576         } else {
577             type = new Type(type->fName + "[]", Type::kArray_Kind, *type, -1);
578             fSymbolTable->takeOwnership((Type*) type);
579             sizes.push_back(nullptr);
580         }
581     }
582     Variable* var = new Variable(intf.fPosition, intf.fModifiers,
583                                  intf.fInstanceName.size() ? intf.fInstanceName : intf.fTypeName,
584                                  *type, Variable::kGlobal_Storage);
585     old->takeOwnership(var);
586     if (intf.fInstanceName.size()) {
587         old->addWithoutOwnership(intf.fInstanceName, var);
588     } else {
589         for (size_t i = 0; i < fields.size(); i++) {
590             old->add(fields[i].fName, std::unique_ptr<Field>(new Field(intf.fPosition, *var,
591                                                                        (int) i)));
592         }
593     }
594     return std::unique_ptr<InterfaceBlock>(new InterfaceBlock(intf.fPosition, *var,
595                                                               intf.fTypeName,
596                                                               intf.fInstanceName,
597                                                               std::move(sizes),
598                                                               fSymbolTable));
599 }
600
601 const Type* IRGenerator::convertType(const ASTType& type) {
602     const Symbol* result = (*fSymbolTable)[type.fName];
603     if (result && result->fKind == Symbol::kType_Kind) {
604         for (int size : type.fSizes) {
605             SkString name = result->fName + "[";
606             if (size != -1) {
607                 name += to_string(size);
608             }
609             name += "]";
610             result = new Type(name, Type::kArray_Kind, (const Type&) *result, size);
611             fSymbolTable->takeOwnership((Type*) result);
612         }
613         return (const Type*) result;
614     }
615     fErrors.error(type.fPosition, "unknown type '" + type.fName + "'");
616     return nullptr;
617 }
618
619 std::unique_ptr<Expression> IRGenerator::convertExpression(const ASTExpression& expr) {
620     switch (expr.fKind) {
621         case ASTExpression::kIdentifier_Kind:
622             return this->convertIdentifier((ASTIdentifier&) expr);
623         case ASTExpression::kBool_Kind:
624             return std::unique_ptr<Expression>(new BoolLiteral(fContext, expr.fPosition,
625                                                                ((ASTBoolLiteral&) expr).fValue));
626         case ASTExpression::kInt_Kind:
627             return std::unique_ptr<Expression>(new IntLiteral(fContext, expr.fPosition,
628                                                               ((ASTIntLiteral&) expr).fValue));
629         case ASTExpression::kFloat_Kind:
630             return std::unique_ptr<Expression>(new FloatLiteral(fContext, expr.fPosition,
631                                                                 ((ASTFloatLiteral&) expr).fValue));
632         case ASTExpression::kBinary_Kind:
633             return this->convertBinaryExpression((ASTBinaryExpression&) expr);
634         case ASTExpression::kPrefix_Kind:
635             return this->convertPrefixExpression((ASTPrefixExpression&) expr);
636         case ASTExpression::kSuffix_Kind:
637             return this->convertSuffixExpression((ASTSuffixExpression&) expr);
638         case ASTExpression::kTernary_Kind:
639             return this->convertTernaryExpression((ASTTernaryExpression&) expr);
640         default:
641             ABORT("unsupported expression type: %d\n", expr.fKind);
642     }
643 }
644
645 std::unique_ptr<Expression> IRGenerator::convertIdentifier(const ASTIdentifier& identifier) {
646     const Symbol* result = (*fSymbolTable)[identifier.fText];
647     if (!result) {
648         fErrors.error(identifier.fPosition, "unknown identifier '" + identifier.fText + "'");
649         return nullptr;
650     }
651     switch (result->fKind) {
652         case Symbol::kFunctionDeclaration_Kind: {
653             std::vector<const FunctionDeclaration*> f = {
654                 (const FunctionDeclaration*) result
655             };
656             return std::unique_ptr<FunctionReference>(new FunctionReference(fContext,
657                                                                             identifier.fPosition,
658                                                                             f));
659         }
660         case Symbol::kUnresolvedFunction_Kind: {
661             const UnresolvedFunction* f = (const UnresolvedFunction*) result;
662             return std::unique_ptr<FunctionReference>(new FunctionReference(fContext,
663                                                                             identifier.fPosition,
664                                                                             f->fFunctions));
665         }
666         case Symbol::kVariable_Kind: {
667             const Variable* var = (const Variable*) result;
668             if (var->fModifiers.fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN) {
669                 fInputs.fFlipY = true;
670                 if (fSettings->fFlipY &&
671                     (!fSettings->fCaps ||
672                      !fSettings->fCaps->fragCoordConventionsExtensionString())) {
673                     fInputs.fRTHeight = true;
674                 }
675             }
676             // default to kRead_RefKind; this will be corrected later if the variable is written to
677             return std::unique_ptr<VariableReference>(new VariableReference(
678                                                                  identifier.fPosition,
679                                                                  *var,
680                                                                  VariableReference::kRead_RefKind));
681         }
682         case Symbol::kField_Kind: {
683             const Field* field = (const Field*) result;
684             VariableReference* base = new VariableReference(identifier.fPosition, field->fOwner,
685                                                             VariableReference::kRead_RefKind);
686             return std::unique_ptr<Expression>(new FieldAccess(
687                                                   std::unique_ptr<Expression>(base),
688                                                   field->fFieldIndex,
689                                                   FieldAccess::kAnonymousInterfaceBlock_OwnerKind));
690         }
691         case Symbol::kType_Kind: {
692             const Type* t = (const Type*) result;
693             return std::unique_ptr<TypeReference>(new TypeReference(fContext, identifier.fPosition,
694                                                                     *t));
695         }
696         default:
697             ABORT("unsupported symbol type %d\n", result->fKind);
698     }
699
700 }
701
702 std::unique_ptr<Expression> IRGenerator::coerce(std::unique_ptr<Expression> expr,
703                                                 const Type& type) {
704     if (!expr) {
705         return nullptr;
706     }
707     if (expr->fType == type) {
708         return expr;
709     }
710     this->checkValid(*expr);
711     if (expr->fType == *fContext.fInvalid_Type) {
712         return nullptr;
713     }
714     if (!expr->fType.canCoerceTo(type)) {
715         fErrors.error(expr->fPosition, "expected '" + type.description() + "', but found '" +
716                                         expr->fType.description() + "'");
717         return nullptr;
718     }
719     if (type.kind() == Type::kScalar_Kind) {
720         std::vector<std::unique_ptr<Expression>> args;
721         args.push_back(std::move(expr));
722         ASTIdentifier id(Position(), type.description());
723         std::unique_ptr<Expression> ctor = this->convertIdentifier(id);
724         ASSERT(ctor);
725         return this->call(Position(), std::move(ctor), std::move(args));
726     }
727     std::vector<std::unique_ptr<Expression>> args;
728     args.push_back(std::move(expr));
729     return std::unique_ptr<Expression>(new Constructor(Position(), type, std::move(args)));
730 }
731
732 static bool is_matrix_multiply(const Type& left, const Type& right) {
733     if (left.kind() == Type::kMatrix_Kind) {
734         return right.kind() == Type::kMatrix_Kind || right.kind() == Type::kVector_Kind;
735     }
736     return left.kind() == Type::kVector_Kind && right.kind() == Type::kMatrix_Kind;
737 }
738
/**
 * Determines the operand and result types of a binary expression. Returns true if the expression is
 * legal, false otherwise. If false, the values of the out parameters are undefined.
 *
 * context       - supplies the bool type and is used to build vector/matrix types via toCompound
 * op            - the binary operator token
 * left, right   - the (unconverted) operand types
 * outLeftType   - receives the type the left operand should be coerced to
 * outRightType  - receives the type the right operand should be coerced to
 * outResultType - receives the type of the whole expression
 * tryFlipped    - if no rule matches, retry once with the operands swapped (and the left/right
 *                 out parameters swapped accordingly); recursive calls always pass false
 */
static bool determine_binary_type(const Context& context,
                                  Token::Kind op,
                                  const Type& left,
                                  const Type& right,
                                  const Type** outLeftType,
                                  const Type** outRightType,
                                  const Type** outResultType,
                                  bool tryFlipped) {
    // isLogical: the operator yields bool rather than the operand type.
    // validMatrixOrVectorOp: the operator may be applied directly to vector/matrix operands.
    bool isLogical;
    bool validMatrixOrVectorOp;
    switch (op) {
        case Token::EQ:
            // plain assignment: everything takes the left (target) type
            *outLeftType = &left;
            *outRightType = &left;
            *outResultType = &left;
            return right.canCoerceTo(left);
        case Token::EQEQ: // fall through
        case Token::NEQ:
            isLogical = true;
            validMatrixOrVectorOp = true;
            break;
        case Token::LT:   // fall through
        case Token::GT:   // fall through
        case Token::LTEQ: // fall through
        case Token::GTEQ:
            // ordering comparisons are only accepted on scalars
            isLogical = true;
            validMatrixOrVectorOp = false;
            break;
        case Token::LOGICALOR: // fall through
        case Token::LOGICALAND: // fall through
        case Token::LOGICALXOR: // fall through
        case Token::LOGICALOREQ: // fall through
        case Token::LOGICALANDEQ: // fall through
        case Token::LOGICALXOREQ:
            // logical operators: both operands and the result are bool
            *outLeftType = context.fBool_Type.get();
            *outRightType = context.fBool_Type.get();
            *outResultType = context.fBool_Type.get();
            return left.canCoerceTo(*context.fBool_Type) &&
                   right.canCoerceTo(*context.fBool_Type);
        case Token::STAR: // fall through
        case Token::STAREQ:
            if (is_matrix_multiply(left, right)) {
                // determine final component type
                if (determine_binary_type(context, Token::STAR, left.componentType(),
                                          right.componentType(), outLeftType, outRightType,
                                          outResultType, false)) {
                    // both operands keep their shape but use the common component type
                    *outLeftType = &(*outResultType)->toCompound(context, left.columns(),
                                                                 left.rows());;
                    *outRightType = &(*outResultType)->toCompound(context, right.columns(),
                                                                  right.rows());;
                    int leftColumns = left.columns();
                    int leftRows = left.rows();
                    int rightColumns;
                    int rightRows;
                    if (right.kind() == Type::kVector_Kind) {
                        // matrix * vector treats the vector as a column vector, so we need to
                        // transpose it
                        rightColumns = right.rows();
                        rightRows = right.columns();
                        ASSERT(rightColumns == 1);
                    } else {
                        rightColumns = right.columns();
                        rightRows = right.rows();
                    }
                    if (rightColumns > 1) {
                        *outResultType = &(*outResultType)->toCompound(context, rightColumns,
                                                                       leftRows);
                    } else {
                        // result was a column vector, transpose it back to a row
                        *outResultType = &(*outResultType)->toCompound(context, leftRows,
                                                                       rightColumns);
                    }
                    // the product is only defined when the inner dimensions agree
                    return leftColumns == rightRows;
                } else {
                    return false;
                }
            }
            // not a linear-algebra product: treat as component-wise multiplication
            isLogical = false;
            validMatrixOrVectorOp = true;
            break;
        case Token::PLUS:    // fall through
        case Token::PLUSEQ:  // fall through
        case Token::MINUS:   // fall through
        case Token::MINUSEQ: // fall through
        case Token::SLASH:   // fall through
        case Token::SLASHEQ:
            isLogical = false;
            validMatrixOrVectorOp = true;
            break;
        default:
            // everything else (e.g. %, bitwise ops, shifts) is only accepted on scalars here
            isLogical = false;
            validMatrixOrVectorOp = false;
    }
    bool isVectorOrMatrix = left.kind() == Type::kVector_Kind || left.kind() == Type::kMatrix_Kind;
    // FIXME: incorrect for shift
    // Same-shaped operands: the right operand is coerced to the left operand's type.
    if (right.canCoerceTo(left) && (left.kind() == Type::kScalar_Kind ||
                                   (isVectorOrMatrix && validMatrixOrVectorOp))) {
        *outLeftType = &left;
        *outRightType = &left;
        if (isLogical) {
            // note: comparisons yield a single bool even for vector/matrix operands
            *outResultType = context.fBool_Type.get();
        } else {
            *outResultType = &left;
        }
        return true;
    }
    // vector/matrix op scalar: type the operation per component, then widen the left operand
    // (and, for non-logical operators, the result) back to the left operand's shape; the right
    // operand stays a scalar.
    if ((left.kind() == Type::kVector_Kind || left.kind() == Type::kMatrix_Kind) &&
        (right.kind() == Type::kScalar_Kind)) {
        if (determine_binary_type(context, op, left.componentType(), right, outLeftType,
                                  outRightType, outResultType, false)) {
            *outLeftType = &(*outLeftType)->toCompound(context, left.columns(), left.rows());
            if (!isLogical) {
                *outResultType = &(*outResultType)->toCompound(context, left.columns(),
                                                               left.rows());
            }
            return true;
        }
        return false;
    }
    // e.g. scalar op vector: retry as vector op scalar, with the out parameters swapped back
    if (tryFlipped) {
        return determine_binary_type(context, op, right, left, outRightType, outLeftType,
                                     outResultType, false);
    }
    return false;
}
868
869 std::unique_ptr<Expression> IRGenerator::constantFold(const Expression& left,
870                                                       Token::Kind op,
871                                                       const Expression& right) const {
872     // Note that we expressly do not worry about precision and overflow here -- we use the maximum
873     // precision to calculate the results and hope the result makes sense. The plan is to move the
874     // Skia caps into SkSL, so we have access to all of them including the precisions of the various
875     // types, which will let us be more intelligent about this.
876     if (left.fKind == Expression::kBoolLiteral_Kind &&
877         right.fKind == Expression::kBoolLiteral_Kind) {
878         bool leftVal  = ((BoolLiteral&) left).fValue;
879         bool rightVal = ((BoolLiteral&) right).fValue;
880         bool result;
881         switch (op) {
882             case Token::LOGICALAND: result = leftVal && rightVal; break;
883             case Token::LOGICALOR:  result = leftVal || rightVal; break;
884             case Token::LOGICALXOR: result = leftVal ^  rightVal; break;
885             default: return nullptr;
886         }
887         return std::unique_ptr<Expression>(new BoolLiteral(fContext, left.fPosition, result));
888     }
889     #define RESULT(t, op) std::unique_ptr<Expression>(new t ## Literal(fContext, left.fPosition, \
890                                                                        leftVal op rightVal))
891     if (left.fKind == Expression::kIntLiteral_Kind && right.fKind == Expression::kIntLiteral_Kind) {
892         int64_t leftVal  = ((IntLiteral&) left).fValue;
893         int64_t rightVal = ((IntLiteral&) right).fValue;
894         switch (op) {
895             case Token::PLUS:       return RESULT(Int,  +);
896             case Token::MINUS:      return RESULT(Int,  -);
897             case Token::STAR:       return RESULT(Int,  *);
898             case Token::SLASH:
899                 if (rightVal) {
900                     return RESULT(Int, /);
901                 }
902                 fErrors.error(right.fPosition, "division by zero");
903                 return nullptr;
904             case Token::PERCENT:
905                 if (rightVal) {
906                     return RESULT(Int, %);
907                 }
908                 fErrors.error(right.fPosition, "division by zero");
909                 return nullptr;
910             case Token::BITWISEAND: return RESULT(Int,  &);
911             case Token::BITWISEOR:  return RESULT(Int,  |);
912             case Token::BITWISEXOR: return RESULT(Int,  ^);
913             case Token::SHL:        return RESULT(Int,  <<);
914             case Token::SHR:        return RESULT(Int,  >>);
915             case Token::EQEQ:       return RESULT(Bool, ==);
916             case Token::NEQ:        return RESULT(Bool, !=);
917             case Token::GT:         return RESULT(Bool, >);
918             case Token::GTEQ:       return RESULT(Bool, >=);
919             case Token::LT:         return RESULT(Bool, <);
920             case Token::LTEQ:       return RESULT(Bool, <=);
921             default:                return nullptr;
922         }
923     }
924     if (left.fKind == Expression::kFloatLiteral_Kind &&
925         right.fKind == Expression::kFloatLiteral_Kind) {
926         double leftVal  = ((FloatLiteral&) left).fValue;
927         double rightVal = ((FloatLiteral&) right).fValue;
928         switch (op) {
929             case Token::PLUS:       return RESULT(Float, +);
930             case Token::MINUS:      return RESULT(Float, -);
931             case Token::STAR:       return RESULT(Float, *);
932             case Token::SLASH:
933                 if (rightVal) {
934                     return RESULT(Float, /);
935                 }
936                 fErrors.error(right.fPosition, "division by zero");
937                 return nullptr;
938             case Token::EQEQ:       return RESULT(Bool,  ==);
939             case Token::NEQ:        return RESULT(Bool,  !=);
940             case Token::GT:         return RESULT(Bool,  >);
941             case Token::GTEQ:       return RESULT(Bool,  >=);
942             case Token::LT:         return RESULT(Bool,  <);
943             case Token::LTEQ:       return RESULT(Bool,  <=);
944             default:                return nullptr;
945         }
946     }
947     #undef RESULT
948     return nullptr;
949 }
950
951 std::unique_ptr<Expression> IRGenerator::convertBinaryExpression(
952                                                             const ASTBinaryExpression& expression) {
953     std::unique_ptr<Expression> left = this->convertExpression(*expression.fLeft);
954     if (!left) {
955         return nullptr;
956     }
957     std::unique_ptr<Expression> right = this->convertExpression(*expression.fRight);
958     if (!right) {
959         return nullptr;
960     }
961     const Type* leftType;
962     const Type* rightType;
963     const Type* resultType;
964     if (!determine_binary_type(fContext, expression.fOperator, left->fType, right->fType, &leftType,
965                                &rightType, &resultType,
966                                !Token::IsAssignment(expression.fOperator))) {
967         fErrors.error(expression.fPosition, "type mismatch: '" +
968                                             Token::OperatorName(expression.fOperator) +
969                                             "' cannot operate on '" + left->fType.fName +
970                                             "', '" + right->fType.fName + "'");
971         return nullptr;
972     }
973     if (Token::IsAssignment(expression.fOperator)) {
974         this->markWrittenTo(*left, expression.fOperator != Token::EQ);
975     }
976     left = this->coerce(std::move(left), *leftType);
977     right = this->coerce(std::move(right), *rightType);
978     if (!left || !right) {
979         return nullptr;
980     }
981     std::unique_ptr<Expression> result = this->constantFold(*left.get(), expression.fOperator,
982                                                             *right.get());
983     if (!result) {
984         result = std::unique_ptr<Expression>(new BinaryExpression(expression.fPosition,
985                                                                   std::move(left),
986                                                                   expression.fOperator,
987                                                                   std::move(right),
988                                                                   *resultType));
989     }
990     return result;
991 }
992
993 std::unique_ptr<Expression> IRGenerator::convertTernaryExpression(
994                                                            const ASTTernaryExpression& expression) {
995     std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*expression.fTest),
996                                                     *fContext.fBool_Type);
997     if (!test) {
998         return nullptr;
999     }
1000     std::unique_ptr<Expression> ifTrue = this->convertExpression(*expression.fIfTrue);
1001     if (!ifTrue) {
1002         return nullptr;
1003     }
1004     std::unique_ptr<Expression> ifFalse = this->convertExpression(*expression.fIfFalse);
1005     if (!ifFalse) {
1006         return nullptr;
1007     }
1008     const Type* trueType;
1009     const Type* falseType;
1010     const Type* resultType;
1011     if (!determine_binary_type(fContext, Token::EQEQ, ifTrue->fType, ifFalse->fType, &trueType,
1012                                &falseType, &resultType, true) || trueType != falseType) {
1013         fErrors.error(expression.fPosition, "ternary operator result mismatch: '" +
1014                                             ifTrue->fType.fName + "', '" +
1015                                             ifFalse->fType.fName + "'");
1016         return nullptr;
1017     }
1018     ifTrue = this->coerce(std::move(ifTrue), *trueType);
1019     if (!ifTrue) {
1020         return nullptr;
1021     }
1022     ifFalse = this->coerce(std::move(ifFalse), *falseType);
1023     if (!ifFalse) {
1024         return nullptr;
1025     }
1026     if (test->fKind == Expression::kBoolLiteral_Kind) {
1027         // static boolean test, just return one of the branches
1028         if (((BoolLiteral&) *test).fValue) {
1029             return ifTrue;
1030         } else {
1031             return ifFalse;
1032         }
1033     }
1034     return std::unique_ptr<Expression>(new TernaryExpression(expression.fPosition,
1035                                                              std::move(test),
1036                                                              std::move(ifTrue),
1037                                                              std::move(ifFalse)));
1038 }
1039
1040 std::unique_ptr<Expression> IRGenerator::call(Position position,
1041                                               const FunctionDeclaration& function,
1042                                               std::vector<std::unique_ptr<Expression>> arguments) {
1043     if (function.fParameters.size() != arguments.size()) {
1044         SkString msg = "call to '" + function.fName + "' expected " +
1045                                  to_string((uint64_t) function.fParameters.size()) +
1046                                  " argument";
1047         if (function.fParameters.size() != 1) {
1048             msg += "s";
1049         }
1050         msg += ", but found " + to_string((uint64_t) arguments.size());
1051         fErrors.error(position, msg);
1052         return nullptr;
1053     }
1054     std::vector<const Type*> types;
1055     const Type* returnType;
1056     if (!function.determineFinalTypes(arguments, &types, &returnType)) {
1057         SkString msg = "no match for " + function.fName + "(";
1058         SkString separator;
1059         for (size_t i = 0; i < arguments.size(); i++) {
1060             msg += separator;
1061             separator = ", ";
1062             msg += arguments[i]->fType.description();
1063         }
1064         msg += ")";
1065         fErrors.error(position, msg);
1066         return nullptr;
1067     }
1068     for (size_t i = 0; i < arguments.size(); i++) {
1069         arguments[i] = this->coerce(std::move(arguments[i]), *types[i]);
1070         if (!arguments[i]) {
1071             return nullptr;
1072         }
1073         if (arguments[i] && (function.fParameters[i]->fModifiers.fFlags & Modifiers::kOut_Flag)) {
1074             this->markWrittenTo(*arguments[i], true);
1075         }
1076     }
1077     return std::unique_ptr<FunctionCall>(new FunctionCall(position, *returnType, function,
1078                                                           std::move(arguments)));
1079 }
1080
1081 /**
1082  * Determines the cost of coercing the arguments of a function to the required types. Returns true
1083  * if the cost could be computed, false if the call is not valid. Cost has no particular meaning
1084  * other than "lower costs are preferred".
1085  */
1086 bool IRGenerator::determineCallCost(const FunctionDeclaration& function,
1087                                     const std::vector<std::unique_ptr<Expression>>& arguments,
1088                                     int* outCost) {
1089     if (function.fParameters.size() != arguments.size()) {
1090         return false;
1091     }
1092     int total = 0;
1093     std::vector<const Type*> types;
1094     const Type* ignored;
1095     if (!function.determineFinalTypes(arguments, &types, &ignored)) {
1096         return false;
1097     }
1098     for (size_t i = 0; i < arguments.size(); i++) {
1099         int cost;
1100         if (arguments[i]->fType.determineCoercionCost(*types[i], &cost)) {
1101             total += cost;
1102         } else {
1103             return false;
1104         }
1105     }
1106     *outCost = total;
1107     return true;
1108 }
1109
1110 std::unique_ptr<Expression> IRGenerator::call(Position position,
1111                                               std::unique_ptr<Expression> functionValue,
1112                                               std::vector<std::unique_ptr<Expression>> arguments) {
1113     if (functionValue->fKind == Expression::kTypeReference_Kind) {
1114         return this->convertConstructor(position,
1115                                         ((TypeReference&) *functionValue).fValue,
1116                                         std::move(arguments));
1117     }
1118     if (functionValue->fKind != Expression::kFunctionReference_Kind) {
1119         fErrors.error(position, "'" + functionValue->description() + "' is not a function");
1120         return nullptr;
1121     }
1122     FunctionReference* ref = (FunctionReference*) functionValue.get();
1123     int bestCost = INT_MAX;
1124     const FunctionDeclaration* best = nullptr;
1125     if (ref->fFunctions.size() > 1) {
1126         for (const auto& f : ref->fFunctions) {
1127             int cost;
1128             if (this->determineCallCost(*f, arguments, &cost) && cost < bestCost) {
1129                 bestCost = cost;
1130                 best = f;
1131             }
1132         }
1133         if (best) {
1134             return this->call(position, *best, std::move(arguments));
1135         }
1136         SkString msg = "no match for " + ref->fFunctions[0]->fName + "(";
1137         SkString separator;
1138         for (size_t i = 0; i < arguments.size(); i++) {
1139             msg += separator;
1140             separator = ", ";
1141             msg += arguments[i]->fType.description();
1142         }
1143         msg += ")";
1144         fErrors.error(position, msg);
1145         return nullptr;
1146     }
1147     return this->call(position, *ref->fFunctions[0], std::move(arguments));
1148 }
1149
1150 std::unique_ptr<Expression> IRGenerator::convertNumberConstructor(
1151                                                     Position position,
1152                                                     const Type& type,
1153                                                     std::vector<std::unique_ptr<Expression>> args) {
1154     ASSERT(type.isNumber());
1155     if (args.size() != 1) {
1156         fErrors.error(position, "invalid arguments to '" + type.description() +
1157                                 "' constructor, (expected exactly 1 argument, but found " +
1158                                 to_string((uint64_t) args.size()) + ")");
1159         return nullptr;
1160     }
1161     if (type == *fContext.fFloat_Type && args.size() == 1 &&
1162         args[0]->fKind == Expression::kIntLiteral_Kind) {
1163         int64_t value = ((IntLiteral&) *args[0]).fValue;
1164         return std::unique_ptr<Expression>(new FloatLiteral(fContext, position, (double) value));
1165     }
1166     if (args[0]->fKind == Expression::kIntLiteral_Kind && (type == *fContext.fInt_Type ||
1167         type == *fContext.fUInt_Type)) {
1168         return std::unique_ptr<Expression>(new IntLiteral(fContext,
1169                                                           position,
1170                                                           ((IntLiteral&) *args[0]).fValue,
1171                                                           &type));
1172     }
1173     if (args[0]->fType == *fContext.fBool_Type) {
1174         std::unique_ptr<IntLiteral> zero(new IntLiteral(fContext, position, 0));
1175         std::unique_ptr<IntLiteral> one(new IntLiteral(fContext, position, 1));
1176         return std::unique_ptr<Expression>(
1177                                      new TernaryExpression(position, std::move(args[0]),
1178                                                            this->coerce(std::move(one), type),
1179                                                            this->coerce(std::move(zero),
1180                                                                         type)));
1181     }
1182     if (!args[0]->fType.isNumber()) {
1183         fErrors.error(position, "invalid argument to '" + type.description() +
1184                                 "' constructor (expected a number or bool, but found '" +
1185                                 args[0]->fType.description() + "')");
1186         return nullptr;
1187     }
1188     return std::unique_ptr<Expression>(new Constructor(position, std::move(type), std::move(args)));
1189 }
1190
1191 int component_count(const Type& type) {
1192     switch (type.kind()) {
1193         case Type::kVector_Kind:
1194             return type.columns();
1195         case Type::kMatrix_Kind:
1196             return type.columns() * type.rows();
1197         default:
1198             return 1;
1199     }
1200 }
1201
1202 std::unique_ptr<Expression> IRGenerator::convertCompoundConstructor(
1203                                                     Position position,
1204                                                     const Type& type,
1205                                                     std::vector<std::unique_ptr<Expression>> args) {
1206     ASSERT(type.kind() == Type::kVector_Kind || type.kind() == Type::kMatrix_Kind);
1207     if (type.kind() == Type::kMatrix_Kind && args.size() == 1 &&
1208         args[0]->fType.kind() == Type::kMatrix_Kind) {
1209         // matrix from matrix is always legal
1210         return std::unique_ptr<Expression>(new Constructor(position, std::move(type),
1211                                                            std::move(args)));
1212     }
1213     int actual = 0;
1214     int expected = type.rows() * type.columns();
1215     if (args.size() != 1 || expected != component_count(args[0]->fType) ||
1216         type.componentType().isNumber() != args[0]->fType.componentType().isNumber()) {
1217         for (size_t i = 0; i < args.size(); i++) {
1218             if (args[i]->fType.kind() == Type::kVector_Kind) {
1219                 if (type.componentType().isNumber() !=
1220                     args[i]->fType.componentType().isNumber()) {
1221                     fErrors.error(position, "'" + args[i]->fType.description() + "' is not a valid "
1222                                             "parameter to '" + type.description() +
1223                                             "' constructor");
1224                     return nullptr;
1225                 }
1226                 actual += args[i]->fType.columns();
1227             } else if (args[i]->fType.kind() == Type::kScalar_Kind) {
1228                 actual += 1;
1229                 if (type.kind() != Type::kScalar_Kind) {
1230                     args[i] = this->coerce(std::move(args[i]), type.componentType());
1231                     if (!args[i]) {
1232                         return nullptr;
1233                     }
1234                 }
1235             } else {
1236                 fErrors.error(position, "'" + args[i]->fType.description() + "' is not a valid "
1237                                         "parameter to '" + type.description() + "' constructor");
1238                 return nullptr;
1239             }
1240         }
1241         if (actual != 1 && actual != expected) {
1242             fErrors.error(position, "invalid arguments to '" + type.description() +
1243                                     "' constructor (expected " + to_string(expected) +
1244                                     " scalars, but found " + to_string(actual) + ")");
1245             return nullptr;
1246         }
1247     }
1248     return std::unique_ptr<Expression>(new Constructor(position, std::move(type), std::move(args)));
1249 }
1250
1251 std::unique_ptr<Expression> IRGenerator::convertConstructor(
1252                                                     Position position,
1253                                                     const Type& type,
1254                                                     std::vector<std::unique_ptr<Expression>> args) {
1255     // FIXME: add support for structs
1256     Type::Kind kind = type.kind();
1257     if (args.size() == 1 && args[0]->fType == type) {
1258         // argument is already the right type, just return it
1259         return std::move(args[0]);
1260     }
1261     if (type.isNumber()) {
1262         return this->convertNumberConstructor(position, type, std::move(args));
1263     } else if (kind == Type::kArray_Kind) {
1264         const Type& base = type.componentType();
1265         for (size_t i = 0; i < args.size(); i++) {
1266             args[i] = this->coerce(std::move(args[i]), base);
1267             if (!args[i]) {
1268                 return nullptr;
1269             }
1270         }
1271         return std::unique_ptr<Expression>(new Constructor(position, std::move(type),
1272                                                            std::move(args)));
1273     } else if (kind == Type::kVector_Kind || kind == Type::kMatrix_Kind) {
1274         return this->convertCompoundConstructor(position, type, std::move(args));
1275     } else {
1276         fErrors.error(position, "cannot construct '" + type.description() + "'");
1277         return nullptr;
1278     }
1279 }
1280
1281 std::unique_ptr<Expression> IRGenerator::convertPrefixExpression(
1282                                                             const ASTPrefixExpression& expression) {
1283     std::unique_ptr<Expression> base = this->convertExpression(*expression.fOperand);
1284     if (!base) {
1285         return nullptr;
1286     }
1287     switch (expression.fOperator) {
1288         case Token::PLUS:
1289             if (!base->fType.isNumber() && base->fType.kind() != Type::kVector_Kind) {
1290                 fErrors.error(expression.fPosition,
1291                               "'+' cannot operate on '" + base->fType.description() + "'");
1292                 return nullptr;
1293             }
1294             return base;
1295         case Token::MINUS:
1296             if (!base->fType.isNumber() && base->fType.kind() != Type::kVector_Kind) {
1297                 fErrors.error(expression.fPosition,
1298                               "'-' cannot operate on '" + base->fType.description() + "'");
1299                 return nullptr;
1300             }
1301             if (base->fKind == Expression::kIntLiteral_Kind) {
1302                 return std::unique_ptr<Expression>(new IntLiteral(fContext, base->fPosition,
1303                                                                   -((IntLiteral&) *base).fValue));
1304             }
1305             if (base->fKind == Expression::kFloatLiteral_Kind) {
1306                 double value = -((FloatLiteral&) *base).fValue;
1307                 return std::unique_ptr<Expression>(new FloatLiteral(fContext, base->fPosition,
1308                                                                     value));
1309             }
1310             return std::unique_ptr<Expression>(new PrefixExpression(Token::MINUS, std::move(base)));
1311         case Token::PLUSPLUS:
1312             if (!base->fType.isNumber()) {
1313                 fErrors.error(expression.fPosition,
1314                               "'" + Token::OperatorName(expression.fOperator) +
1315                               "' cannot operate on '" + base->fType.description() + "'");
1316                 return nullptr;
1317             }
1318             this->markWrittenTo(*base, true);
1319             break;
1320         case Token::MINUSMINUS:
1321             if (!base->fType.isNumber()) {
1322                 fErrors.error(expression.fPosition,
1323                               "'" + Token::OperatorName(expression.fOperator) +
1324                               "' cannot operate on '" + base->fType.description() + "'");
1325                 return nullptr;
1326             }
1327             this->markWrittenTo(*base, true);
1328             break;
1329         case Token::LOGICALNOT:
1330             if (base->fType != *fContext.fBool_Type) {
1331                 fErrors.error(expression.fPosition,
1332                               "'" + Token::OperatorName(expression.fOperator) +
1333                               "' cannot operate on '" + base->fType.description() + "'");
1334                 return nullptr;
1335             }
1336             if (base->fKind == Expression::kBoolLiteral_Kind) {
1337                 return std::unique_ptr<Expression>(new BoolLiteral(fContext, base->fPosition,
1338                                                                    !((BoolLiteral&) *base).fValue));
1339             }
1340             break;
1341         case Token::BITWISENOT:
1342             if (base->fType != *fContext.fInt_Type) {
1343                 fErrors.error(expression.fPosition,
1344                               "'" + Token::OperatorName(expression.fOperator) +
1345                               "' cannot operate on '" + base->fType.description() + "'");
1346                 return nullptr;
1347             }
1348             break;
1349         default:
1350             ABORT("unsupported prefix operator\n");
1351     }
1352     return std::unique_ptr<Expression>(new PrefixExpression(expression.fOperator,
1353                                                             std::move(base)));
1354 }
1355
1356 std::unique_ptr<Expression> IRGenerator::convertIndex(std::unique_ptr<Expression> base,
1357                                                       const ASTExpression& index) {
1358     if (base->fKind == Expression::kTypeReference_Kind) {
1359         if (index.fKind == ASTExpression::kInt_Kind) {
1360             const Type& oldType = ((TypeReference&) *base).fValue;
1361             int64_t size = ((const ASTIntLiteral&) index).fValue;
1362             Type* newType = new Type(oldType.name() + "[" + to_string(size) + "]",
1363                                      Type::kArray_Kind, oldType, size);
1364             fSymbolTable->takeOwnership(newType);
1365             return std::unique_ptr<Expression>(new TypeReference(fContext, base->fPosition,
1366                                                                  *newType));
1367
1368         } else {
1369             fErrors.error(base->fPosition, "array size must be a constant");
1370             return nullptr;
1371         }
1372     }
1373     if (base->fType.kind() != Type::kArray_Kind && base->fType.kind() != Type::kMatrix_Kind &&
1374             base->fType.kind() != Type::kVector_Kind) {
1375         fErrors.error(base->fPosition, "expected array, but found '" + base->fType.description() +
1376                                        "'");
1377         return nullptr;
1378     }
1379     std::unique_ptr<Expression> converted = this->convertExpression(index);
1380     if (!converted) {
1381         return nullptr;
1382     }
1383     if (converted->fType != *fContext.fUInt_Type) {
1384         converted = this->coerce(std::move(converted), *fContext.fInt_Type);
1385         if (!converted) {
1386             return nullptr;
1387         }
1388     }
1389     return std::unique_ptr<Expression>(new IndexExpression(fContext, std::move(base),
1390                                                            std::move(converted)));
1391 }
1392
1393 std::unique_ptr<Expression> IRGenerator::convertField(std::unique_ptr<Expression> base,
1394                                                       const SkString& field) {
1395     auto fields = base->fType.fields();
1396     for (size_t i = 0; i < fields.size(); i++) {
1397         if (fields[i].fName == field) {
1398             return std::unique_ptr<Expression>(new FieldAccess(std::move(base), (int) i));
1399         }
1400     }
1401     fErrors.error(base->fPosition, "type '" + base->fType.description() + "' does not have a "
1402                                    "field named '" + field + "");
1403     return nullptr;
1404 }
1405
1406 std::unique_ptr<Expression> IRGenerator::convertSwizzle(std::unique_ptr<Expression> base,
1407                                                         const SkString& fields) {
1408     if (base->fType.kind() != Type::kVector_Kind) {
1409         fErrors.error(base->fPosition, "cannot swizzle type '" + base->fType.description() + "'");
1410         return nullptr;
1411     }
1412     std::vector<int> swizzleComponents;
1413     for (size_t i = 0; i < fields.size(); i++) {
1414         switch (fields[i]) {
1415             case 'x': // fall through
1416             case 'r': // fall through
1417             case 's':
1418                 swizzleComponents.push_back(0);
1419                 break;
1420             case 'y': // fall through
1421             case 'g': // fall through
1422             case 't':
1423                 if (base->fType.columns() >= 2) {
1424                     swizzleComponents.push_back(1);
1425                     break;
1426                 }
1427                 // fall through
1428             case 'z': // fall through
1429             case 'b': // fall through
1430             case 'p':
1431                 if (base->fType.columns() >= 3) {
1432                     swizzleComponents.push_back(2);
1433                     break;
1434                 }
1435                 // fall through
1436             case 'w': // fall through
1437             case 'a': // fall through
1438             case 'q':
1439                 if (base->fType.columns() >= 4) {
1440                     swizzleComponents.push_back(3);
1441                     break;
1442                 }
1443                 // fall through
1444             default:
1445                 fErrors.error(base->fPosition, SkStringPrintf("invalid swizzle component '%c'",
1446                                                               fields[i]));
1447                 return nullptr;
1448         }
1449     }
1450     ASSERT(swizzleComponents.size() > 0);
1451     if (swizzleComponents.size() > 4) {
1452         fErrors.error(base->fPosition, "too many components in swizzle mask '" + fields + "'");
1453         return nullptr;
1454     }
1455     return std::unique_ptr<Expression>(new Swizzle(fContext, std::move(base), swizzleComponents));
1456 }
1457
1458 std::unique_ptr<Expression> IRGenerator::getCap(Position position, SkString name) {
1459     auto found = fCapsMap.find(name);
1460     if (found == fCapsMap.end()) {
1461         fErrors.error(position, "unknown capability flag '" + name + "'");
1462         return nullptr;
1463     }
1464     switch (found->second.fKind) {
1465         case CapValue::kBool_Kind:
1466             return std::unique_ptr<Expression>(new BoolLiteral(fContext, position,
1467                                                                (bool) found->second.fValue));
1468         case CapValue::kInt_Kind:
1469             return std::unique_ptr<Expression>(new IntLiteral(fContext, position,
1470                                                               found->second.fValue));
1471     }
1472     ASSERT(false);
1473     return nullptr;
1474 }
1475
1476 std::unique_ptr<Expression> IRGenerator::convertSuffixExpression(
1477                                                             const ASTSuffixExpression& expression) {
1478     std::unique_ptr<Expression> base = this->convertExpression(*expression.fBase);
1479     if (!base) {
1480         return nullptr;
1481     }
1482     switch (expression.fSuffix->fKind) {
1483         case ASTSuffix::kIndex_Kind: {
1484             const ASTExpression* expr = ((ASTIndexSuffix&) *expression.fSuffix).fExpression.get();
1485             if (expr) {
1486                 return this->convertIndex(std::move(base), *expr);
1487             } else if (base->fKind == Expression::kTypeReference_Kind) {
1488                 const Type& oldType = ((TypeReference&) *base).fValue;
1489                 Type* newType = new Type(oldType.name() + "[]", Type::kArray_Kind, oldType,
1490                                          -1);
1491                 fSymbolTable->takeOwnership(newType);
1492                 return std::unique_ptr<Expression>(new TypeReference(fContext, base->fPosition,
1493                                                                      *newType));
1494             } else {
1495                 fErrors.error(expression.fPosition, "'[]' must follow a type name");
1496                 return nullptr;
1497             }
1498         }
1499         case ASTSuffix::kCall_Kind: {
1500             auto rawArguments = &((ASTCallSuffix&) *expression.fSuffix).fArguments;
1501             std::vector<std::unique_ptr<Expression>> arguments;
1502             for (size_t i = 0; i < rawArguments->size(); i++) {
1503                 std::unique_ptr<Expression> converted =
1504                         this->convertExpression(*(*rawArguments)[i]);
1505                 if (!converted) {
1506                     return nullptr;
1507                 }
1508                 arguments.push_back(std::move(converted));
1509             }
1510             return this->call(expression.fPosition, std::move(base), std::move(arguments));
1511         }
1512         case ASTSuffix::kField_Kind: {
1513             if (base->fType == *fContext.fSkCaps_Type) {
1514                 return this->getCap(expression.fPosition,
1515                                     ((ASTFieldSuffix&) *expression.fSuffix).fField);
1516             }
1517             switch (base->fType.kind()) {
1518                 case Type::kVector_Kind:
1519                     return this->convertSwizzle(std::move(base),
1520                                                 ((ASTFieldSuffix&) *expression.fSuffix).fField);
1521                 case Type::kStruct_Kind:
1522                     return this->convertField(std::move(base),
1523                                               ((ASTFieldSuffix&) *expression.fSuffix).fField);
1524                 default:
1525                     fErrors.error(base->fPosition, "cannot swizzle value of type '" +
1526                                                    base->fType.description() + "'");
1527                     return nullptr;
1528             }
1529         }
1530         case ASTSuffix::kPostIncrement_Kind:
1531             if (!base->fType.isNumber()) {
1532                 fErrors.error(expression.fPosition,
1533                               "'++' cannot operate on '" + base->fType.description() + "'");
1534                 return nullptr;
1535             }
1536             this->markWrittenTo(*base, true);
1537             return std::unique_ptr<Expression>(new PostfixExpression(std::move(base),
1538                                                                      Token::PLUSPLUS));
1539         case ASTSuffix::kPostDecrement_Kind:
1540             if (!base->fType.isNumber()) {
1541                 fErrors.error(expression.fPosition,
1542                               "'--' cannot operate on '" + base->fType.description() + "'");
1543                 return nullptr;
1544             }
1545             this->markWrittenTo(*base, true);
1546             return std::unique_ptr<Expression>(new PostfixExpression(std::move(base),
1547                                                                      Token::MINUSMINUS));
1548         default:
1549             ABORT("unsupported suffix operator");
1550     }
1551 }
1552
1553 void IRGenerator::checkValid(const Expression& expr) {
1554     switch (expr.fKind) {
1555         case Expression::kFunctionReference_Kind:
1556             fErrors.error(expr.fPosition, "expected '(' to begin function call");
1557             break;
1558         case Expression::kTypeReference_Kind:
1559             fErrors.error(expr.fPosition, "expected '(' to begin constructor invocation");
1560             break;
1561         default:
1562             if (expr.fType == *fContext.fInvalid_Type) {
1563                 fErrors.error(expr.fPosition, "invalid expression");
1564             }
1565     }
1566 }
1567
1568 static bool has_duplicates(const Swizzle& swizzle) {
1569     int bits = 0;
1570     for (int idx : swizzle.fComponents) {
1571         ASSERT(idx >= 0 && idx <= 3);
1572         int bit = 1 << idx;
1573         if (bits & bit) {
1574             return true;
1575         }
1576         bits |= bit;
1577     }
1578     return false;
1579 }
1580
1581 void IRGenerator::markWrittenTo(const Expression& expr, bool readWrite) {
1582     switch (expr.fKind) {
1583         case Expression::kVariableReference_Kind: {
1584             const Variable& var = ((VariableReference&) expr).fVariable;
1585             if (var.fModifiers.fFlags & (Modifiers::kConst_Flag | Modifiers::kUniform_Flag)) {
1586                 fErrors.error(expr.fPosition,
1587                               "cannot modify immutable variable '" + var.fName + "'");
1588             }
1589             ((VariableReference&) expr).setRefKind(readWrite ? VariableReference::kReadWrite_RefKind
1590                                                              : VariableReference::kWrite_RefKind);
1591             break;
1592         }
1593         case Expression::kFieldAccess_Kind:
1594             this->markWrittenTo(*((FieldAccess&) expr).fBase, readWrite);
1595             break;
1596         case Expression::kSwizzle_Kind:
1597             if (has_duplicates((Swizzle&) expr)) {
1598                 fErrors.error(expr.fPosition,
1599                               "cannot write to the same swizzle field more than once");
1600             }
1601             this->markWrittenTo(*((Swizzle&) expr).fBase, readWrite);
1602             break;
1603         case Expression::kIndex_Kind:
1604             this->markWrittenTo(*((IndexExpression&) expr).fBase, readWrite);
1605             break;
1606         default:
1607             fErrors.error(expr.fPosition, "cannot assign to '" + expr.description() + "'");
1608             break;
1609     }
1610 }
1611
1612 }