From 3df510bf425dbb4f2f554ff7170826fa62f64efc Mon Sep 17 00:00:00 2001 From: Andy Davis Date: Mon, 24 Jun 2019 07:31:52 -0700 Subject: [PATCH] Add parsing/printing for new affine.load and affine.store operations. The new operations affine.load and affine.store will take composed affine maps by construction. These operations will eventually replace load and store operations currently used in affine regions and operated on by affine transformation and analysis passes. PiperOrigin-RevId: 254754048 --- mlir/include/mlir/AffineOps/AffineOps.h | 101 +++++++++++++++++++ mlir/include/mlir/IR/OpImplementation.h | 17 ++++ mlir/lib/AffineOps/AffineOps.cpp | 115 +++++++++++++++++++++- mlir/lib/IR/AsmPrinter.cpp | 85 ++++++++++++---- mlir/lib/Parser/Parser.cpp | 128 +++++++++++++++++++++++- mlir/lib/Parser/TokenKinds.def | 1 + mlir/test/AffineOps/load-store.mlir | 168 ++++++++++++++++++++++++++++++++ 7 files changed, 593 insertions(+), 22 deletions(-) create mode 100644 mlir/test/AffineOps/load-store.mlir diff --git a/mlir/include/mlir/AffineOps/AffineOps.h b/mlir/include/mlir/AffineOps/AffineOps.h index 4e048eb..0aeb985 100644 --- a/mlir/include/mlir/AffineOps/AffineOps.h +++ b/mlir/include/mlir/AffineOps/AffineOps.h @@ -346,6 +346,107 @@ public: static StringRef getOperationName() { return "affine.terminator"; } }; +/// The "affine.load" op reads an element from a memref, where the index +/// for each memref dimension is an affine expression of loop induction +/// variables and symbols. The output of 'affine.load' is a new value with the +/// same type as the elements of the memref. An affine expression of loop IVs +/// and symbols must be specified for each dimension of the memref. The keyword +/// 'symbol' can be used to indicate SSA identifiers which are symbolic. +// +// Example 1: +// +// %1 = affine.load %0[%i0 + 3, %i1 + 7] : memref<100x100xf32> +// +// Example 2: Uses 'symbol' keyword for symbols '%n' and '%m'. +// +// %1 = affine.load %0[%i0 + symbol(%n), %i1 + symbol(%m)] +// : memref<100x100xf32> +// +class AffineLoadOp : public Op::Impl> { +public: + using Op::Op; + + /// Builds an affine load op with the specified map and operands. + static void build(Builder *builder, OperationState *result, AffineMap map, + ArrayRef operands); + + /// Get memref operand. + Value *getMemRef() { return getOperand(0); } + void setMemRef(Value *value) { setOperand(0, value); } + MemRefType getMemRefType() { + return getMemRef()->getType().cast(); + } + + /// Get affine map operands. + operand_range getIndices() { return llvm::drop_begin(getOperands(), 1); } + + /// Returns the affine map used to index the memref for this operation. + AffineMap getAffineMap() { + return getAttrOfType("map").getValue(); + } + + static StringRef getOperationName() { return "affine.load"; } + + // Hooks to customize behavior of this op. + static ParseResult parse(OpAsmParser *parser, OperationState *result); + void print(OpAsmPrinter *p); + LogicalResult verify(); +}; + +/// The "affine.store" op writes an element to a memref, where the index +/// for each memref dimension is an affine expression of loop induction +/// variables and symbols. The 'affine.store' op stores a new value which is the +/// same type as the elements of the memref. An affine expression of loop IVs +/// and symbols must be specified for each dimension of the memref. The keyword +/// 'symbol' can be used to indicate SSA identifiers which are symbolic. +// +// Example 1: +// +// affine.store %v0, %0[%i0 + 3, %i1 + 7] : memref<100x100xf32> +// +// Example 2: Uses 'symbol' keyword for symbols '%n' and '%m'. +// +// affine.store %v0, %0[%i0 + symbol(%n), %i1 + symbol(%m)] +// : memref<100x100xf32> +// +class AffineStoreOp : public Op::Impl> { +public: + using Op::Op; + + /// Builds an affine store operation with the specified map and operands. + static void build(Builder *builder, OperationState *result, + Value *valueToStore, AffineMap map, + ArrayRef operands); + + /// Get value to be stored by store operation. + Value *getValueToStore() { return getOperand(0); } + + /// Get memref operand. + Value *getMemRef() { return getOperand(1); } + void setMemRef(Value *value) { setOperand(1, value); } + + MemRefType getMemRefType() { + return getMemRef()->getType().cast(); + } + + /// Get affine map operands. + operand_range getIndices() { return llvm::drop_begin(getOperands(), 2); } + + /// Returns the affine map used to index the memref for this operation. + AffineMap getAffineMap() { + return getAttrOfType("map").getValue(); + } + + static StringRef getOperationName() { return "affine.store"; } + + // Hooks to customize behavior of this op. + static ParseResult parse(OpAsmParser *parser, OperationState *result); + void print(OpAsmPrinter *p); + LogicalResult verify(); +}; + /// Returns true if the given Value can be used as a dimension id. bool isValidDim(Value *value); diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index c16a9c4..162ed11 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -85,6 +85,11 @@ public: virtual void printRegion(Region &blocks, bool printEntryBlockArgs = true, bool printBlockTerminators = true) = 0; + /// Prints an affine map of SSA ids, where SSA id names are used in place + /// of dims/symbols. + virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr, + ArrayRef operands) = 0; + /// Print an optional arrow followed by a type list. void printOptionalArrowTypeList(ArrayRef types) { if (types.empty()) @@ -236,6 +241,12 @@ public: /// Parse a `)` token if present. virtual ParseResult parseOptionalRParen() = 0; + /// Parse a `[` token. + virtual ParseResult parseLSquare() = 0; + + /// Parse a `]` token. + virtual ParseResult parseRSquare() = 0; + //===--------------------------------------------------------------------===// // Attribute Parsing //===--------------------------------------------------------------------===// @@ -362,6 +373,12 @@ public: return success(); } + /// Parses an affine map attribute where dims and symbols are SSA operands. + virtual ParseResult + parseAffineMapOfSSAIds(SmallVectorImpl &operands, Attribute &map, + StringRef attrName, + SmallVectorImpl &attrs) = 0; + //===--------------------------------------------------------------------===// // Region Parsing //===--------------------------------------------------------------------===// diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index f886c5a..e35847f 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -37,7 +37,8 @@ using llvm::dbgs; AffineOpsDialect::AffineOpsDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context) { - addOperations(); + addOperations(); } /// A utility function to check if a value is defined at the top level of a @@ -1332,3 +1333,115 @@ Region &AffineIfOp::getThenBlocks() { return getOperation()->getRegion(0); } /// Returns the list of 'else' blocks. Region &AffineIfOp::getElseBlocks() { return getOperation()->getRegion(1); } + +//===----------------------------------------------------------------------===// +// AffineLoadOp +//===----------------------------------------------------------------------===// + +void AffineLoadOp::build(Builder *builder, OperationState *result, + AffineMap map, ArrayRef operands) { + // TODO(b/133776335) Check that map operands are loop IVs or symbols. + result->addOperands(operands); + result->addAttribute("map", builder->getAffineMapAttr(map)); +} + +ParseResult AffineLoadOp::parse(OpAsmParser *parser, OperationState *result) { + auto &builder = parser->getBuilder(); + auto affineIntTy = builder.getIndexType(); + + MemRefType type; + OpAsmParser::OperandType memrefInfo; + AffineMapAttr mapAttr; + SmallVector mapOperands; + return failure( + parser->parseOperand(memrefInfo) || parser->parseLSquare() || + parser->parseAffineMapOfSSAIds(mapOperands, mapAttr, "map", + result->attributes) || + parser->parseRSquare() || parser->parseColonType(type) || + parser->resolveOperand(memrefInfo, type, result->operands) || + parser->resolveOperands(mapOperands, affineIntTy, result->operands) || + parser->addTypeToList(type.getElementType(), result->types)); +} + +void AffineLoadOp::print(OpAsmPrinter *p) { + *p << "affine.load " << *getMemRef() << '['; + AffineMapAttr mapAttr = getAttrOfType("map"); + SmallVector operands(getIndices()); + p->printAffineMapOfSSAIds(mapAttr, operands); + *p << "] : " << getMemRefType(); +} + +LogicalResult AffineLoadOp::verify() { + if (getType() != getMemRefType().getElementType()) + return emitOpError("result type must match element type of memref"); + + AffineMap map = getAttrOfType("map").getValue(); + if (map.getNumResults() != getMemRefType().getRank()) + return emitOpError("affine.load affine map num results must equal memref " + "rank"); + + for (auto *idx : getIndices()) + if (!idx->getType().isIndex()) + return emitOpError("index to load must have 'index' type"); + // TODO(b/133776335) Verify that map operands are loop IVs or symbols. + return success(); +} + +//===----------------------------------------------------------------------===// +// AffineStoreOp +//===----------------------------------------------------------------------===// + +void AffineStoreOp::build(Builder *builder, OperationState *result, + Value *valueToStore, AffineMap map, + ArrayRef operands) { + // TODO(b/133776335) Check that map operands are loop IVs or symbols. + result->addOperands(valueToStore); + result->addOperands(operands); + result->addAttribute("map", builder->getAffineMapAttr(map)); +} + +ParseResult AffineStoreOp::parse(OpAsmParser *parser, OperationState *result) { + auto affineIntTy = parser->getBuilder().getIndexType(); + + MemRefType type; + OpAsmParser::OperandType storeValueInfo; + OpAsmParser::OperandType memrefInfo; + AffineMapAttr mapAttr; + SmallVector mapOperands; + return failure( + parser->parseOperand(storeValueInfo) || parser->parseComma() || + parser->parseOperand(memrefInfo) || parser->parseLSquare() || + parser->parseAffineMapOfSSAIds(mapOperands, mapAttr, "map", + result->attributes) || + parser->parseRSquare() || parser->parseColonType(type) || + parser->resolveOperand(storeValueInfo, type.getElementType(), + result->operands) || + parser->resolveOperand(memrefInfo, type, result->operands) || + parser->resolveOperands(mapOperands, affineIntTy, result->operands)); +} + +void AffineStoreOp::print(OpAsmPrinter *p) { + *p << "affine.store " << *getValueToStore(); + *p << ", " << *getMemRef() << '['; + AffineMapAttr mapAttr = getAttrOfType("map"); + SmallVector operands(getIndices()); + p->printAffineMapOfSSAIds(mapAttr, operands); + *p << "] : " << getMemRefType(); +} + +LogicalResult AffineStoreOp::verify() { + // First operand must have same type as memref element type. + if (getValueToStore()->getType() != getMemRefType().getElementType()) + return emitOpError("first operand must have same type memref element type"); + + AffineMap map = getAttrOfType("map").getValue(); + if (map.getNumResults() != getMemRefType().getRank()) + return emitOpError("affine.store affine map num results must equal memref " + "rank"); + + for (auto *idx : getIndices()) + if (!idx->getType().isIndex()) + return emitOpError("index to load must have 'index' type"); + // TODO(b/133776335) Verify that map operands are loop IVs or symbols. + return success(); +} diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index a56c118..efa1de3 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -347,7 +347,8 @@ public: void printLocation(LocationAttr loc); void printAffineMap(AffineMap map); - void printAffineExpr(AffineExpr expr); + void printAffineExpr(AffineExpr expr, ArrayRef dimValueNames = {}, + ArrayRef symbolValueNames = {}); void printAffineConstraint(AffineExpr expr, bool isEq); void printIntegerSet(IntegerSet set); @@ -370,7 +371,9 @@ protected: Strong, // All other binary operators. }; void printAffineExprInternal(AffineExpr expr, - BindingStrength enclosingTightness); + BindingStrength enclosingTightness, + ArrayRef dimValueNames = {}, + ArrayRef symbolValueNames = {}); }; } // end anonymous namespace @@ -918,20 +921,34 @@ void ModulePrinter::printType(Type type) { // Affine expressions and maps //===----------------------------------------------------------------------===// -void ModulePrinter::printAffineExpr(AffineExpr expr) { - printAffineExprInternal(expr, BindingStrength::Weak); +void ModulePrinter::printAffineExpr(AffineExpr expr, + ArrayRef dimValueNames, + ArrayRef symbolValueNames) { + printAffineExprInternal(expr, BindingStrength::Weak, dimValueNames, + symbolValueNames); } void ModulePrinter::printAffineExprInternal( - AffineExpr expr, BindingStrength enclosingTightness) { + AffineExpr expr, BindingStrength enclosingTightness, + ArrayRef dimValueNames, ArrayRef symbolValueNames) { const char *binopSpelling = nullptr; switch (expr.getKind()) { - case AffineExprKind::SymbolId: - os << 's' << expr.cast().getPosition(); + case AffineExprKind::SymbolId: { + unsigned pos = expr.cast().getPosition(); + if (pos < symbolValueNames.size()) + os << "symbol(%" << symbolValueNames[pos] << ')'; + else + os << 's' << pos; return; - case AffineExprKind::DimId: - os << 'd' << expr.cast().getPosition(); + } + case AffineExprKind::DimId: { + unsigned pos = expr.cast().getPosition(); + if (pos < dimValueNames.size()) + os << '%' << dimValueNames[pos]; + else + os << 'd' << pos; return; + } case AffineExprKind::Constant: os << expr.cast().getValue(); return; @@ -965,13 +982,16 @@ void ModulePrinter::printAffineExprInternal( auto rhsConst = rhsExpr.dyn_cast(); if (rhsConst && rhsConst.getValue() == -1) { os << "-"; - printAffineExprInternal(lhsExpr, BindingStrength::Strong); + printAffineExprInternal(lhsExpr, BindingStrength::Strong, dimValueNames, + symbolValueNames); return; } - printAffineExprInternal(lhsExpr, BindingStrength::Strong); + printAffineExprInternal(lhsExpr, BindingStrength::Strong, dimValueNames, + symbolValueNames); os << binopSpelling; - printAffineExprInternal(rhsExpr, BindingStrength::Strong); + printAffineExprInternal(rhsExpr, BindingStrength::Strong, dimValueNames, + symbolValueNames); if (enclosingTightness == BindingStrength::Strong) os << ')'; @@ -989,12 +1009,15 @@ void ModulePrinter::printAffineExprInternal( AffineExpr rrhsExpr = rhs.getRHS(); if (auto rrhs = rrhsExpr.dyn_cast()) { if (rrhs.getValue() == -1) { - printAffineExprInternal(lhsExpr, BindingStrength::Weak); + printAffineExprInternal(lhsExpr, BindingStrength::Weak, dimValueNames, + symbolValueNames); os << " - "; if (rhs.getLHS().getKind() == AffineExprKind::Add) { - printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong); + printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong, + dimValueNames, symbolValueNames); } else { - printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak); + printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak, + dimValueNames, symbolValueNames); } if (enclosingTightness == BindingStrength::Strong) @@ -1003,9 +1026,11 @@ void ModulePrinter::printAffineExprInternal( } if (rrhs.getValue() < -1) { - printAffineExprInternal(lhsExpr, BindingStrength::Weak); + printAffineExprInternal(lhsExpr, BindingStrength::Weak, dimValueNames, + symbolValueNames); os << " - "; - printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong); + printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong, + dimValueNames, symbolValueNames); os << " * " << -rrhs.getValue(); if (enclosingTightness == BindingStrength::Strong) os << ')'; @@ -1018,7 +1043,8 @@ void ModulePrinter::printAffineExprInternal( // Pretty print addition to a negative number as a subtraction. if (auto rhsConst = rhsExpr.dyn_cast()) { if (rhsConst.getValue() < 0) { - printAffineExprInternal(lhsExpr, BindingStrength::Weak); + printAffineExprInternal(lhsExpr, BindingStrength::Weak, dimValueNames, + symbolValueNames); os << " - " << -rhsConst.getValue(); if (enclosingTightness == BindingStrength::Strong) os << ')'; @@ -1026,9 +1052,11 @@ void ModulePrinter::printAffineExprInternal( } } - printAffineExprInternal(lhsExpr, BindingStrength::Weak); + printAffineExprInternal(lhsExpr, BindingStrength::Weak, dimValueNames, + symbolValueNames); os << " + "; - printAffineExprInternal(rhsExpr, BindingStrength::Weak); + printAffineExprInternal(rhsExpr, BindingStrength::Weak, dimValueNames, + symbolValueNames); if (enclosingTightness == BindingStrength::Strong) os << ')'; @@ -1210,6 +1238,23 @@ public: os.indent(currentIndent) << "}"; } + void printAffineMapOfSSAIds(AffineMapAttr mapAttr, + ArrayRef operands) { + AffineMap map = mapAttr.getValue(); + unsigned numDims = map.getNumDims(); + SmallVector dimValueNames; + SmallVector symbolValueNames; + for (unsigned i = 0, e = operands.size(); i < e; ++i) { + if (i < numDims) + dimValueNames.push_back(valueNames[operands[i]]); + else + symbolValueNames.push_back(valueNames[operands[i]]); + } + interleaveComma(map.getResults(), [&](AffineExpr expr) { + printAffineExpr(expr, dimValueNames, symbolValueNames); + }); + } + // Number of spaces used for indenting nested operations. const static unsigned indentWidth = 2; diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index f5108b5..d4428d4 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -273,6 +273,11 @@ public: ParseResult parseAffineMapOrIntegerSetReference(AffineMap &map, IntegerSet &set); + /// Parse an AffineMap where the dim and symbol identifiers are SSA ids. + ParseResult + parseAffineMapOfSSAIds(AffineMap &map, + llvm::function_ref parseElement); + private: /// The Parser is subclassed and reinstantiated. Do not add additional /// non-trivial state here, add it to the ParserState class. @@ -1668,11 +1673,17 @@ namespace { /// bodies. class AffineParser : public Parser { public: - explicit AffineParser(ParserState &state) : Parser(state) {} + AffineParser(ParserState &state, bool allowParsingSSAIds = false, + llvm::function_ref parseElement = nullptr) + : Parser(state), allowParsingSSAIds(allowParsingSSAIds), + parseElement(parseElement), numDimOperands(0), numSymbolOperands(0) {} AffineMap parseAffineMapRange(unsigned numDims, unsigned numSymbols); ParseResult parseAffineMapOrIntegerSetInline(AffineMap &map, IntegerSet &set); IntegerSet parseIntegerSetConstraints(unsigned numDims, unsigned numSymbols); + ParseResult parseAffineMapOfSSAIds(AffineMap &map); + void getDimsAndSymbolSSAIds(SmallVectorImpl &dimAndSymbolSSAIds, + unsigned &numDims); private: // Binary affine op parsing. @@ -1691,6 +1702,8 @@ private: AffineExpr parseNegateExpression(AffineExpr lhs); AffineExpr parseIntegerExpr(); AffineExpr parseBareIdExpr(); + AffineExpr parseSSAIdExpr(bool isSymbol); + AffineExpr parseSymbolSSAIdExpr(); AffineExpr getAffineBinaryOpExpr(AffineHighPrecOp op, AffineExpr lhs, AffineExpr rhs, SMLoc opLoc); @@ -1703,6 +1716,10 @@ private: AffineExpr parseAffineConstraint(bool *isEq); private: + bool allowParsingSSAIds; + llvm::function_ref parseElement; + unsigned numDimOperands; + unsigned numSymbolOperands; SmallVector, 4> dimsAndSymbols; }; } // end anonymous namespace @@ -1892,6 +1909,42 @@ AffineExpr AffineParser::parseBareIdExpr() { return (emitError("use of undeclared identifier"), nullptr); } +/// Parse an SSA id which may appear in an affine expression. +AffineExpr AffineParser::parseSSAIdExpr(bool isSymbol) { + if (!allowParsingSSAIds) + return (emitError("unexpected ssa identifier"), nullptr); + if (getToken().isNot(Token::percent_identifier)) + return (emitError("expected ssa identifier"), nullptr); + auto name = getTokenSpelling(); + // Check if we already parsed this SSA id. + for (auto entry : dimsAndSymbols) { + if (entry.first == name) { + consumeToken(Token::percent_identifier); + return entry.second; + } + } + // Parse the SSA id and add an AffineDim/SymbolExpr to represent it. + if (parseElement(isSymbol)) + return (emitError("failed to parse ssa identifier"), nullptr); + auto idExpr = isSymbol + ? getAffineSymbolExpr(numSymbolOperands++, getContext()) + : getAffineDimExpr(numDimOperands++, getContext()); + dimsAndSymbols.push_back({name, idExpr}); + return idExpr; +} + +AffineExpr AffineParser::parseSymbolSSAIdExpr() { + if (parseToken(Token::kw_symbol, "expected symbol keyword") || + parseToken(Token::l_paren, "expected '(' at start of SSA symbol")) + return nullptr; + AffineExpr symbolExpr = parseSSAIdExpr(/*isSymbol=*/true); + if (!symbolExpr) + return nullptr; + if (parseToken(Token::r_paren, "expected ')' at end of SSA symbol")) + return nullptr; + return symbolExpr; +} + /// Parse a positive integral constant appearing in an affine expression. /// /// affine-expr ::= integer-literal @@ -1917,6 +1970,10 @@ AffineExpr AffineParser::parseAffineOperandExpr(AffineExpr lhs) { switch (getToken().getKind()) { case Token::bare_identifier: return parseBareIdExpr(); + case Token::kw_symbol: + return parseSymbolSSAIdExpr(); + case Token::percent_identifier: + return parseSSAIdExpr(/*isSymbol=*/false); case Token::integer: return parseIntegerExpr(); case Token::l_paren: @@ -2109,6 +2166,26 @@ ParseResult AffineParser::parseAffineMapOrIntegerSetInline(AffineMap &map, return failure(); } +/// Parse an AffineMap where the dim and symbol identifiers are SSA ids. +ParseResult AffineParser::parseAffineMapOfSSAIds(AffineMap &map) { + SmallVector exprs; + auto parseElt = [&]() -> ParseResult { + auto elt = parseAffineExpr(); + exprs.push_back(elt); + return elt ? success() : failure(); + }; + + // Parse a multi-dimensional affine expression (a comma-separated list of + // 1-d affine expressions); the list cannot be empty. Grammar: + // multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `) + if (parseCommaSeparatedList(parseElt)) + return failure(); + // Parsed a valid affine map. + map = builder.getAffineMap(numDimOperands, + dimsAndSymbols.size() - numDimOperands, exprs); + return success(); +} + /// Parse the range and sizes affine map definition inline. /// /// affine-map ::= dim-and-symbol-id-lists `->` multi-dim-affine-expr @@ -2221,6 +2298,14 @@ ParseResult Parser::parseAffineMapOrIntegerSetReference(AffineMap &map, return AffineParser(state).parseAffineMapOrIntegerSetInline(map, set); } +/// Parse an AffineMap of SSA ids. The callback 'parseElement' is used to +/// parse SSA value uses encountered while parsing affine expressions. +ParseResult Parser::parseAffineMapOfSSAIds( + AffineMap &map, llvm::function_ref parseElement) { + return AffineParser(state, /*allowParsingSSAIds=*/true, parseElement) + .parseAffineMapOfSSAIds(map); +} + //===----------------------------------------------------------------------===// // OperationParser //===----------------------------------------------------------------------===// @@ -3014,6 +3099,16 @@ public: return success(parser.consumeIf(Token::r_paren)); } + /// Parse a `[` token. + ParseResult parseLSquare() override { + return parser.parseToken(Token::l_square, "expected '['"); + } + + /// Parse a `]` token. + ParseResult parseRSquare() override { + return parser.parseToken(Token::r_square, "expected ']'"); + } + //===--------------------------------------------------------------------===// // Attribute Parsing //===--------------------------------------------------------------------===// @@ -3155,6 +3250,37 @@ public: return failure(); } + /// Parse an AffineMap of SSA ids. + ParseResult parseAffineMapOfSSAIds(SmallVectorImpl &operands, + Attribute &mapAttr, StringRef attrName, + SmallVectorImpl &attrs) { + SmallVector dimOperands; + SmallVector symOperands; + + auto parseElement = [&](bool isSymbol) -> ParseResult { + OperandType operand; + if (parseOperand(operand)) + return failure(); + if (isSymbol) + symOperands.push_back(operand); + else + dimOperands.push_back(operand); + return success(); + }; + + AffineMap map; + if (parser.parseAffineMapOfSSAIds(map, parseElement)) + return failure(); + // Add AffineMap attribute. + mapAttr = parser.builder.getAffineMapAttr(map); + attrs.push_back(parser.builder.getNamedAttr(attrName, mapAttr)); + + // Add dim operands before symbol operands in 'operands'. + operands.assign(dimOperands.begin(), dimOperands.end()); + operands.append(symOperands.begin(), symOperands.end()); + return success(); + } + //===--------------------------------------------------------------------===// // Region Parsing //===--------------------------------------------------------------------===// diff --git a/mlir/lib/Parser/TokenKinds.def b/mlir/lib/Parser/TokenKinds.def index 98645cb..18067a8 100644 --- a/mlir/lib/Parser/TokenKinds.def +++ b/mlir/lib/Parser/TokenKinds.def @@ -113,6 +113,7 @@ TOK_KEYWORD(opaque) TOK_KEYWORD(size) TOK_KEYWORD(sparse) TOK_KEYWORD(step) +TOK_KEYWORD(symbol) TOK_KEYWORD(tensor) TOK_KEYWORD(to) TOK_KEYWORD(true) diff --git a/mlir/test/AffineOps/load-store.mlir b/mlir/test/AffineOps/load-store.mlir new file mode 100644 index 0000000..fdb32f8 --- /dev/null +++ b/mlir/test/AffineOps/load-store.mlir @@ -0,0 +1,168 @@ +// RUN: mlir-opt %s -split-input-file | FileCheck %s + +// ----- + +// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (d0, d1) + +// Test with just loop IVs. +func @test0(%arg0 : index, %arg1 : index) { + %0 = alloc() : memref<100x100xf32> + affine.for %i0 = 0 to 10 { + affine.for %i1 = 0 to 10 { + %1 = affine.load %0[%i0, %i1] : memref<100x100xf32> +// CHECK: %1 = affine.load %0[%i0, %i1] : memref<100x100xf32> + } + } + return +} + +// ----- + +// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (d0 + 3, d1 + 7) + +// Test with loop IVs and constants. +func @test1(%arg0 : index, %arg1 : index) { + %0 = alloc() : memref<100x100xf32> + affine.for %i0 = 0 to 10 { + affine.for %i1 = 0 to 10 { + %1 = affine.load %0[%i0 + 3, %i1 + 7] : memref<100x100xf32> + affine.store %1, %0[%i0 + 3, %i1 + 7] : memref<100x100xf32> +// CHECK: %1 = affine.load %0[%i0 + 3, %i1 + 7] : memref<100x100xf32> +// CHECK: affine.store %1, %0[%i0 + 3, %i1 + 7] : memref<100x100xf32> + } + } + return +} + +// ----- + +// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1, d2, d3) -> (d0 + d1, d2 + d3) + +// Test with loop IVs and function args without 'symbol' keyword (should +// be parsed as dim identifiers). +func @test2(%arg0 : index, %arg1 : index) { + %0 = alloc() : memref<100x100xf32> + affine.for %i0 = 0 to 10 { + affine.for %i1 = 0 to 10 { + %1 = affine.load %0[%i0 + %arg0, %i1 + %arg1] : memref<100x100xf32> + affine.store %1, %0[%i0 + %arg0, %i1 + %arg1] : memref<100x100xf32> +// CHECK: %1 = affine.load %0[%i0 + %arg0, %i1 + %arg1] : memref<100x100xf32> +// CHECK: affine.store %1, %0[%i0 + %arg0, %i1 + %arg1] : memref<100x100xf32> + } + } + return +} + +// ----- + +// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1)[s0, s1] -> (d0 + s0, d1 + s1) + +// Test with loop IVs and function args with 'symbol' keyword (should +// be parsed as symbol identifiers). +func @test3(%arg0 : index, %arg1 : index) { + %0 = alloc() : memref<100x100xf32> + affine.for %i0 = 0 to 10 { + affine.for %i1 = 0 to 10 { + %1 = affine.load %0[%i0 + symbol(%arg0), %i1 + symbol(%arg1)] + : memref<100x100xf32> + affine.store %1, %0[%i0 + symbol(%arg0), %i1 + symbol(%arg1)] + : memref<100x100xf32> +// CHECK: %1 = affine.load %0[%i0 + symbol(%arg0), %i1 + symbol(%arg1)] : memref<100x100xf32> +// CHECK: affine.store %1, %0[%i0 + symbol(%arg0), %i1 + symbol(%arg1)] : memref<100x100xf32> + } + } + return +} + +// ----- + +// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1)[s0, s1] -> ((d0 + s0) floordiv 3 + 11, (d1 + s1) mod 4 + 7) + +// Test with loop IVs, symbols and constants in nested affine expressions. +func @test4(%arg0 : index, %arg1 : index) { + %0 = alloc() : memref<100x100xf32> + affine.for %i0 = 0 to 10 { + affine.for %i1 = 0 to 10 { + %1 = affine.load %0[(%i0 + symbol(%arg0)) floordiv 3 + 11, + (%i1 + symbol(%arg1)) mod 4 + 7] : memref<100x100xf32> + affine.store %1, %0[(%i0 + symbol(%arg0)) floordiv 3 + 11, + (%i1 + symbol(%arg1)) mod 4 + 7] : memref<100x100xf32> +// CHECK: %1 = affine.load %0[(%i0 + symbol(%arg0)) floordiv 3 + 11, (%i1 + symbol(%arg1)) mod 4 + 7] : memref<100x100xf32> +// CHECK: affine.store %1, %0[(%i0 + symbol(%arg0)) floordiv 3 + 11, (%i1 + symbol(%arg1)) mod 4 + 7] : memref<100x100xf32> + } + } + return +} + +// ----- + +// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1, d2) -> (d0, d1, d2) + +// Test with swizzled loop IVs. +func @test5(%arg0 : index, %arg1 : index) { + %0 = alloc() : memref<10x10x10xf32> + affine.for %i0 = 0 to 10 { + affine.for %i1 = 0 to 10 { + affine.for %i2 = 0 to 10 { + %1 = affine.load %0[%i2, %i0, %i1] : memref<10x10x10xf32> + affine.store %1, %0[%i2, %i0, %i1] : memref<10x10x10xf32> +// CHECK: %1 = affine.load %0[%i2, %i0, %i1] : memref<10x10x10xf32> +// CHECK: affine.store %1, %0[%i2, %i0, %i1] : memref<10x10x10xf32> + } + } + } + return +} + +// ----- + +// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1, d2, d3, d4) -> (d0 + d1, d2 + d3, d3 + d1 + d4) + +// Test with swizzled loop IVs, duplicate args, and function args used as dims. +// Dim identifiers are assigned in parse order: +// d0 = %i2, d1 = %arg0, d2 = %i0, d3 = %i1, d4 = %arg1 +func @test6(%arg0 : index, %arg1 : index) { + %0 = alloc() : memref<10x10x10xf32> + affine.for %i0 = 0 to 10 { + affine.for %i1 = 0 to 10 { + affine.for %i2 = 0 to 10 { + %1 = affine.load %0[%i2 + %arg0, %i0 + %i1, %i1 + %arg0 + %arg1] + : memref<10x10x10xf32> + affine.store %1, %0[%i2 + %arg0, %i0 + %i1, %i1 + %arg0 + %arg1] + : memref<10x10x10xf32> +// CHECK: %1 = affine.load %0[%i2 + %arg0, %i0 + %i1, %i1 + %arg0 + %arg1] : memref<10x10x10xf32> +// CHECK: affine.store %1, %0[%i2 + %arg0, %i0 + %i1, %i1 + %arg0 + %arg1] : memref<10x10x10xf32> + } + } + } + return +} + +// ----- + +// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1, d2)[s0, s1] -> (d0 + s0, d1 + d2, d2 + s0 + s1) + +// Test with swizzled loop IVs, duplicate args, and function args used as syms. +// Dim and symbol identifiers are assigned in parse order: +// d0 = %i2, d1 = %i0, d2 = %i1 +// s0 = %arg0, s1 = %arg1 +func @test6(%arg0 : index, %arg1 : index) { + %0 = alloc() : memref<10x10x10xf32> + affine.for %i0 = 0 to 10 { + affine.for %i1 = 0 to 10 { + affine.for %i2 = 0 to 10 { + %1 = affine.load %0[%i2 + symbol(%arg0), + %i0 + %i1, + %i1 + symbol(%arg0) + symbol(%arg1)] + : memref<10x10x10xf32> + affine.store %1, %0[%i2 + symbol(%arg0), + %i0 + %i1, + %i1 + symbol(%arg0) + symbol(%arg1)] + : memref<10x10x10xf32> +// CHECK: %1 = affine.load %0[%i2 + symbol(%arg0), %i0 + %i1, %i1 + symbol(%arg0) + symbol(%arg1)] : memref<10x10x10xf32> +// CHECK: affine.store %1, %0[%i2 + symbol(%arg0), %i0 + %i1, %i1 + symbol(%arg0) + symbol(%arg1)] : memref<10x10x10xf32> + } + } + } + return +} \ No newline at end of file -- 2.7.4