Add parsing/printing for new affine.load and affine.store operations.
authorAndy Davis <andydavis@google.com>
Mon, 24 Jun 2019 14:31:52 +0000 (07:31 -0700)
committerjpienaar <jpienaar@google.com>
Mon, 24 Jun 2019 20:45:09 +0000 (13:45 -0700)
The new operations affine.load and affine.store will take composed affine maps by construction.
These operations will eventually replace load and store operations currently used in affine regions and operated on by affine transformation and analysis passes.

PiperOrigin-RevId: 254754048

mlir/include/mlir/AffineOps/AffineOps.h
mlir/include/mlir/IR/OpImplementation.h
mlir/lib/AffineOps/AffineOps.cpp
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/Parser/Parser.cpp
mlir/lib/Parser/TokenKinds.def
mlir/test/AffineOps/load-store.mlir [new file with mode: 0644]

index 4e048eb..0aeb985 100644 (file)
@@ -346,6 +346,107 @@ public:
   static StringRef getOperationName() { return "affine.terminator"; }
 };
 
+/// The "affine.load" op reads an element from a memref, where the index
+/// for each memref dimension is an affine expression of loop induction
+/// variables and symbols. The output of 'affine.load' is a new value with the
+/// same type as the elements of the memref. An affine expression of loop IVs
+/// and symbols must be specified for each dimension of the memref. The keyword
+/// 'symbol' can be used to indicate SSA identifiers which are symbolic.
+//
+//  Example 1:
+//
+//    %1 = affine.load %0[%i0 + 3, %i1 + 7] : memref<100x100xf32>
+//
+//  Example 2: Uses 'symbol' keyword for symbols '%n' and '%m'.
+//
+//    %1 = affine.load %0[%i0 + symbol(%n), %i1 + symbol(%m)]
+//      : memref<100x100xf32>
+//
+class AffineLoadOp : public Op<AffineLoadOp, OpTrait::OneResult,
+                               OpTrait::AtLeastNOperands<1>::Impl> {
+public:
+  using Op::Op;
+
+  /// Builds an affine load op with the specified map and operands.
+  static void build(Builder *builder, OperationState *result, AffineMap map,
+                    ArrayRef<Value *> operands);
+
+  /// Get memref operand.
+  Value *getMemRef() { return getOperand(0); }
+  void setMemRef(Value *value) { setOperand(0, value); }
+  MemRefType getMemRefType() {
+    return getMemRef()->getType().cast<MemRefType>();
+  }
+
+  /// Get affine map operands.
+  operand_range getIndices() { return llvm::drop_begin(getOperands(), 1); }
+
+  /// Returns the affine map used to index the memref for this operation.
+  AffineMap getAffineMap() {
+    return getAttrOfType<AffineMapAttr>("map").getValue();
+  }
+
+  static StringRef getOperationName() { return "affine.load"; }
+
+  // Hooks to customize behavior of this op.
+  static ParseResult parse(OpAsmParser *parser, OperationState *result);
+  void print(OpAsmPrinter *p);
+  LogicalResult verify();
+};
+
+/// The "affine.store" op writes an element to a memref, where the index
+/// for each memref dimension is an affine expression of loop induction
+/// variables and symbols. The 'affine.store' op stores a new value which is the
+/// same type as the elements of the memref. An affine expression of loop IVs
+/// and symbols must be specified for each dimension of the memref. The keyword
+/// 'symbol' can be used to indicate SSA identifiers which are symbolic.
+//
+//  Example 1:
+//
+//    affine.store %v0, %0[%i0 + 3, %i1 + 7] : memref<100x100xf32>
+//
+//  Example 2: Uses 'symbol' keyword for symbols '%n' and '%m'.
+//
+//    affine.store %v0, %0[%i0 + symbol(%n), %i1 + symbol(%m)]
+//      : memref<100x100xf32>
+//
+class AffineStoreOp : public Op<AffineStoreOp, OpTrait::ZeroResult,
+                                OpTrait::AtLeastNOperands<1>::Impl> {
+public:
+  using Op::Op;
+
+  /// Builds an affine store operation with the specified map and operands.
+  static void build(Builder *builder, OperationState *result,
+                    Value *valueToStore, AffineMap map,
+                    ArrayRef<Value *> operands);
+
+  /// Get value to be stored by store operation.
+  Value *getValueToStore() { return getOperand(0); }
+
+  /// Get memref operand.
+  Value *getMemRef() { return getOperand(1); }
+  void setMemRef(Value *value) { setOperand(1, value); }
+
+  MemRefType getMemRefType() {
+    return getMemRef()->getType().cast<MemRefType>();
+  }
+
+  /// Get affine map operands.
+  operand_range getIndices() { return llvm::drop_begin(getOperands(), 2); }
+
+  /// Returns the affine map used to index the memref for this operation.
+  AffineMap getAffineMap() {
+    return getAttrOfType<AffineMapAttr>("map").getValue();
+  }
+
+  static StringRef getOperationName() { return "affine.store"; }
+
+  // Hooks to customize behavior of this op.
+  static ParseResult parse(OpAsmParser *parser, OperationState *result);
+  void print(OpAsmPrinter *p);
+  LogicalResult verify();
+};
+
 /// Returns true if the given Value can be used as a dimension id.
 bool isValidDim(Value *value);
 
index c16a9c4..162ed11 100644 (file)
@@ -85,6 +85,11 @@ public:
   virtual void printRegion(Region &blocks, bool printEntryBlockArgs = true,
                            bool printBlockTerminators = true) = 0;
 
+  /// Prints an affine map of SSA ids, where SSA id names are used in place
+  /// of dims/symbols.
+  virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
+                                      ArrayRef<Value *> operands) = 0;
+
   /// Print an optional arrow followed by a type list.
   void printOptionalArrowTypeList(ArrayRef<Type> types) {
     if (types.empty())
@@ -236,6 +241,12 @@ public:
   /// Parse a `)` token if present.
   virtual ParseResult parseOptionalRParen() = 0;
 
+  /// Parse a `[` token.
+  virtual ParseResult parseLSquare() = 0;
+
+  /// Parse a `]` token.
+  virtual ParseResult parseRSquare() = 0;
+
   //===--------------------------------------------------------------------===//
   // Attribute Parsing
   //===--------------------------------------------------------------------===//
@@ -362,6 +373,12 @@ public:
     return success();
   }
 
+  /// Parses an affine map attribute where dims and symbols are SSA operands.
+  virtual ParseResult
+  parseAffineMapOfSSAIds(SmallVectorImpl<OperandType> &operands, Attribute &map,
+                         StringRef attrName,
+                         SmallVectorImpl<NamedAttribute> &attrs) = 0;
+
   //===--------------------------------------------------------------------===//
   // Region Parsing
   //===--------------------------------------------------------------------===//
index f886c5a..e35847f 100644 (file)
@@ -37,7 +37,8 @@ using llvm::dbgs;
 
 AffineOpsDialect::AffineOpsDialect(MLIRContext *context)
     : Dialect(getDialectNamespace(), context) {
-  addOperations<AffineApplyOp, AffineForOp, AffineIfOp, AffineTerminatorOp>();
+  addOperations<AffineApplyOp, AffineForOp, AffineIfOp, AffineLoadOp,
+                AffineStoreOp, AffineTerminatorOp>();
 }
 
 /// A utility function to check if a value is defined at the top level of a
@@ -1332,3 +1333,115 @@ Region &AffineIfOp::getThenBlocks() { return getOperation()->getRegion(0); }
 
 /// Returns the list of 'else' blocks.
 Region &AffineIfOp::getElseBlocks() { return getOperation()->getRegion(1); }
+
+//===----------------------------------------------------------------------===//
+// AffineLoadOp
+//===----------------------------------------------------------------------===//
+
+void AffineLoadOp::build(Builder *builder, OperationState *result,
+                         AffineMap map, ArrayRef<Value *> operands) {
+  // TODO(b/133776335) Check that map operands are loop IVs or symbols.
+  result->addOperands(operands);
+  result->addAttribute("map", builder->getAffineMapAttr(map));
+}
+
+ParseResult AffineLoadOp::parse(OpAsmParser *parser, OperationState *result) {
+  auto &builder = parser->getBuilder();
+  auto affineIntTy = builder.getIndexType();
+
+  MemRefType type;
+  OpAsmParser::OperandType memrefInfo;
+  AffineMapAttr mapAttr;
+  SmallVector<OpAsmParser::OperandType, 1> mapOperands;
+  return failure(
+      parser->parseOperand(memrefInfo) || parser->parseLSquare() ||
+      parser->parseAffineMapOfSSAIds(mapOperands, mapAttr, "map",
+                                     result->attributes) ||
+      parser->parseRSquare() || parser->parseColonType(type) ||
+      parser->resolveOperand(memrefInfo, type, result->operands) ||
+      parser->resolveOperands(mapOperands, affineIntTy, result->operands) ||
+      parser->addTypeToList(type.getElementType(), result->types));
+}
+
+void AffineLoadOp::print(OpAsmPrinter *p) {
+  *p << "affine.load " << *getMemRef() << '[';
+  AffineMapAttr mapAttr = getAttrOfType<AffineMapAttr>("map");
+  SmallVector<Value *, 2> operands(getIndices());
+  p->printAffineMapOfSSAIds(mapAttr, operands);
+  *p << "] : " << getMemRefType();
+}
+
+LogicalResult AffineLoadOp::verify() {
+  if (getType() != getMemRefType().getElementType())
+    return emitOpError("result type must match element type of memref");
+
+  AffineMap map = getAttrOfType<AffineMapAttr>("map").getValue();
+  if (map.getNumResults() != getMemRefType().getRank())
+    return emitOpError("affine.load affine map num results must equal memref "
+                       "rank");
+
+  for (auto *idx : getIndices())
+    if (!idx->getType().isIndex())
+      return emitOpError("index to load must have 'index' type");
+  // TODO(b/133776335) Verify that map operands are loop IVs or symbols.
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// AffineStoreOp
+//===----------------------------------------------------------------------===//
+
+void AffineStoreOp::build(Builder *builder, OperationState *result,
+                          Value *valueToStore, AffineMap map,
+                          ArrayRef<Value *> operands) {
+  // TODO(b/133776335) Check that map operands are loop IVs or symbols.
+  result->addOperands(valueToStore);
+  result->addOperands(operands);
+  result->addAttribute("map", builder->getAffineMapAttr(map));
+}
+
+ParseResult AffineStoreOp::parse(OpAsmParser *parser, OperationState *result) {
+  auto affineIntTy = parser->getBuilder().getIndexType();
+
+  MemRefType type;
+  OpAsmParser::OperandType storeValueInfo;
+  OpAsmParser::OperandType memrefInfo;
+  AffineMapAttr mapAttr;
+  SmallVector<OpAsmParser::OperandType, 1> mapOperands;
+  return failure(
+      parser->parseOperand(storeValueInfo) || parser->parseComma() ||
+      parser->parseOperand(memrefInfo) || parser->parseLSquare() ||
+      parser->parseAffineMapOfSSAIds(mapOperands, mapAttr, "map",
+                                     result->attributes) ||
+      parser->parseRSquare() || parser->parseColonType(type) ||
+      parser->resolveOperand(storeValueInfo, type.getElementType(),
+                             result->operands) ||
+      parser->resolveOperand(memrefInfo, type, result->operands) ||
+      parser->resolveOperands(mapOperands, affineIntTy, result->operands));
+}
+
+void AffineStoreOp::print(OpAsmPrinter *p) {
+  *p << "affine.store " << *getValueToStore();
+  *p << ", " << *getMemRef() << '[';
+  AffineMapAttr mapAttr = getAttrOfType<AffineMapAttr>("map");
+  SmallVector<Value *, 2> operands(getIndices());
+  p->printAffineMapOfSSAIds(mapAttr, operands);
+  *p << "] : " << getMemRefType();
+}
+
+LogicalResult AffineStoreOp::verify() {
+  // First operand must have same type as memref element type.
+  if (getValueToStore()->getType() != getMemRefType().getElementType())
+    return emitOpError("first operand must have same type memref element type");
+
+  AffineMap map = getAttrOfType<AffineMapAttr>("map").getValue();
+  if (map.getNumResults() != getMemRefType().getRank())
+    return emitOpError("affine.store affine map num results must equal memref "
+                       "rank");
+
+  for (auto *idx : getIndices())
+    if (!idx->getType().isIndex())
+      return emitOpError("index to load must have 'index' type");
+  // TODO(b/133776335) Verify that map operands are loop IVs or symbols.
+  return success();
+}
index a56c118..efa1de3 100644 (file)
@@ -347,7 +347,8 @@ public:
   void printLocation(LocationAttr loc);
 
   void printAffineMap(AffineMap map);
-  void printAffineExpr(AffineExpr expr);
+  void printAffineExpr(AffineExpr expr, ArrayRef<StringRef> dimValueNames = {},
+                       ArrayRef<StringRef> symbolValueNames = {});
   void printAffineConstraint(AffineExpr expr, bool isEq);
   void printIntegerSet(IntegerSet set);
 
@@ -370,7 +371,9 @@ protected:
     Strong, // All other binary operators.
   };
   void printAffineExprInternal(AffineExpr expr,
-                               BindingStrength enclosingTightness);
+                               BindingStrength enclosingTightness,
+                               ArrayRef<StringRef> dimValueNames = {},
+                               ArrayRef<StringRef> symbolValueNames = {});
 };
 } // end anonymous namespace
 
@@ -918,20 +921,34 @@ void ModulePrinter::printType(Type type) {
 // Affine expressions and maps
 //===----------------------------------------------------------------------===//
 
-void ModulePrinter::printAffineExpr(AffineExpr expr) {
-  printAffineExprInternal(expr, BindingStrength::Weak);
+void ModulePrinter::printAffineExpr(AffineExpr expr,
+                                    ArrayRef<StringRef> dimValueNames,
+                                    ArrayRef<StringRef> symbolValueNames) {
+  printAffineExprInternal(expr, BindingStrength::Weak, dimValueNames,
+                          symbolValueNames);
 }
 
 void ModulePrinter::printAffineExprInternal(
-    AffineExpr expr, BindingStrength enclosingTightness) {
+    AffineExpr expr, BindingStrength enclosingTightness,
+    ArrayRef<StringRef> dimValueNames, ArrayRef<StringRef> symbolValueNames) {
   const char *binopSpelling = nullptr;
   switch (expr.getKind()) {
-  case AffineExprKind::SymbolId:
-    os << 's' << expr.cast<AffineSymbolExpr>().getPosition();
+  case AffineExprKind::SymbolId: {
+    unsigned pos = expr.cast<AffineSymbolExpr>().getPosition();
+    if (pos < symbolValueNames.size())
+      os << "symbol(%" << symbolValueNames[pos] << ')';
+    else
+      os << 's' << pos;
     return;
-  case AffineExprKind::DimId:
-    os << 'd' << expr.cast<AffineDimExpr>().getPosition();
+  }
+  case AffineExprKind::DimId: {
+    unsigned pos = expr.cast<AffineDimExpr>().getPosition();
+    if (pos < dimValueNames.size())
+      os << '%' << dimValueNames[pos];
+    else
+      os << 'd' << pos;
     return;
+  }
   case AffineExprKind::Constant:
     os << expr.cast<AffineConstantExpr>().getValue();
     return;
@@ -965,13 +982,16 @@ void ModulePrinter::printAffineExprInternal(
     auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>();
     if (rhsConst && rhsConst.getValue() == -1) {
       os << "-";
-      printAffineExprInternal(lhsExpr, BindingStrength::Strong);
+      printAffineExprInternal(lhsExpr, BindingStrength::Strong, dimValueNames,
+                              symbolValueNames);
       return;
     }
 
-    printAffineExprInternal(lhsExpr, BindingStrength::Strong);
+    printAffineExprInternal(lhsExpr, BindingStrength::Strong, dimValueNames,
+                            symbolValueNames);
     os << binopSpelling;
-    printAffineExprInternal(rhsExpr, BindingStrength::Strong);
+    printAffineExprInternal(rhsExpr, BindingStrength::Strong, dimValueNames,
+                            symbolValueNames);
 
     if (enclosingTightness == BindingStrength::Strong)
       os << ')';
@@ -989,12 +1009,15 @@ void ModulePrinter::printAffineExprInternal(
       AffineExpr rrhsExpr = rhs.getRHS();
       if (auto rrhs = rrhsExpr.dyn_cast<AffineConstantExpr>()) {
         if (rrhs.getValue() == -1) {
-          printAffineExprInternal(lhsExpr, BindingStrength::Weak);
+          printAffineExprInternal(lhsExpr, BindingStrength::Weak, dimValueNames,
+                                  symbolValueNames);
           os << " - ";
           if (rhs.getLHS().getKind() == AffineExprKind::Add) {
-            printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong);
+            printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
+                                    dimValueNames, symbolValueNames);
           } else {
-            printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak);
+            printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak,
+                                    dimValueNames, symbolValueNames);
           }
 
           if (enclosingTightness == BindingStrength::Strong)
@@ -1003,9 +1026,11 @@ void ModulePrinter::printAffineExprInternal(
         }
 
         if (rrhs.getValue() < -1) {
-          printAffineExprInternal(lhsExpr, BindingStrength::Weak);
+          printAffineExprInternal(lhsExpr, BindingStrength::Weak, dimValueNames,
+                                  symbolValueNames);
           os << " - ";
-          printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong);
+          printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
+                                  dimValueNames, symbolValueNames);
           os << " * " << -rrhs.getValue();
           if (enclosingTightness == BindingStrength::Strong)
             os << ')';
@@ -1018,7 +1043,8 @@ void ModulePrinter::printAffineExprInternal(
   // Pretty print addition to a negative number as a subtraction.
   if (auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>()) {
     if (rhsConst.getValue() < 0) {
-      printAffineExprInternal(lhsExpr, BindingStrength::Weak);
+      printAffineExprInternal(lhsExpr, BindingStrength::Weak, dimValueNames,
+                              symbolValueNames);
       os << " - " << -rhsConst.getValue();
       if (enclosingTightness == BindingStrength::Strong)
         os << ')';
@@ -1026,9 +1052,11 @@ void ModulePrinter::printAffineExprInternal(
     }
   }
 
-  printAffineExprInternal(lhsExpr, BindingStrength::Weak);
+  printAffineExprInternal(lhsExpr, BindingStrength::Weak, dimValueNames,
+                          symbolValueNames);
   os << " + ";
-  printAffineExprInternal(rhsExpr, BindingStrength::Weak);
+  printAffineExprInternal(rhsExpr, BindingStrength::Weak, dimValueNames,
+                          symbolValueNames);
 
   if (enclosingTightness == BindingStrength::Strong)
     os << ')';
@@ -1210,6 +1238,23 @@ public:
     os.indent(currentIndent) << "}";
   }
 
+  void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
+                              ArrayRef<Value *> operands) {
+    AffineMap map = mapAttr.getValue();
+    unsigned numDims = map.getNumDims();
+    SmallVector<StringRef, 2> dimValueNames;
+    SmallVector<StringRef, 1> symbolValueNames;
+    for (unsigned i = 0, e = operands.size(); i < e; ++i) {
+      if (i < numDims)
+        dimValueNames.push_back(valueNames[operands[i]]);
+      else
+        symbolValueNames.push_back(valueNames[operands[i]]);
+    }
+    interleaveComma(map.getResults(), [&](AffineExpr expr) {
+      printAffineExpr(expr, dimValueNames, symbolValueNames);
+    });
+  }
+
   // Number of spaces used for indenting nested operations.
   const static unsigned indentWidth = 2;
 
index f5108b5..d4428d4 100644 (file)
@@ -273,6 +273,11 @@ public:
   ParseResult parseAffineMapOrIntegerSetReference(AffineMap &map,
                                                   IntegerSet &set);
 
+  /// Parse an AffineMap where the dim and symbol identifiers are SSA ids.
+  ParseResult
+  parseAffineMapOfSSAIds(AffineMap &map,
+                         llvm::function_ref<ParseResult(bool)> parseElement);
+
 private:
   /// The Parser is subclassed and reinstantiated.  Do not add additional
   /// non-trivial state here, add it to the ParserState class.
@@ -1668,11 +1673,17 @@ namespace {
 /// bodies.
 class AffineParser : public Parser {
 public:
-  explicit AffineParser(ParserState &state) : Parser(state) {}
+  AffineParser(ParserState &state, bool allowParsingSSAIds = false,
+               llvm::function_ref<ParseResult(bool)> parseElement = nullptr)
+      : Parser(state), allowParsingSSAIds(allowParsingSSAIds),
+        parseElement(parseElement), numDimOperands(0), numSymbolOperands(0) {}
 
   AffineMap parseAffineMapRange(unsigned numDims, unsigned numSymbols);
   ParseResult parseAffineMapOrIntegerSetInline(AffineMap &map, IntegerSet &set);
   IntegerSet parseIntegerSetConstraints(unsigned numDims, unsigned numSymbols);
+  ParseResult parseAffineMapOfSSAIds(AffineMap &map);
+  void getDimsAndSymbolSSAIds(SmallVectorImpl<StringRef> &dimAndSymbolSSAIds,
+                              unsigned &numDims);
 
 private:
   // Binary affine op parsing.
@@ -1691,6 +1702,8 @@ private:
   AffineExpr parseNegateExpression(AffineExpr lhs);
   AffineExpr parseIntegerExpr();
   AffineExpr parseBareIdExpr();
+  AffineExpr parseSSAIdExpr(bool isSymbol);
+  AffineExpr parseSymbolSSAIdExpr();
 
   AffineExpr getAffineBinaryOpExpr(AffineHighPrecOp op, AffineExpr lhs,
                                    AffineExpr rhs, SMLoc opLoc);
@@ -1703,6 +1716,10 @@ private:
   AffineExpr parseAffineConstraint(bool *isEq);
 
 private:
+  bool allowParsingSSAIds;
+  llvm::function_ref<ParseResult(bool)> parseElement;
+  unsigned numDimOperands;
+  unsigned numSymbolOperands;
   SmallVector<std::pair<StringRef, AffineExpr>, 4> dimsAndSymbols;
 };
 } // end anonymous namespace
@@ -1892,6 +1909,42 @@ AffineExpr AffineParser::parseBareIdExpr() {
   return (emitError("use of undeclared identifier"), nullptr);
 }
 
+/// Parse an SSA id which may appear in an affine expression.
+AffineExpr AffineParser::parseSSAIdExpr(bool isSymbol) {
+  if (!allowParsingSSAIds)
+    return (emitError("unexpected ssa identifier"), nullptr);
+  if (getToken().isNot(Token::percent_identifier))
+    return (emitError("expected ssa identifier"), nullptr);
+  auto name = getTokenSpelling();
+  // Check if we already parsed this SSA id.
+  for (auto entry : dimsAndSymbols) {
+    if (entry.first == name) {
+      consumeToken(Token::percent_identifier);
+      return entry.second;
+    }
+  }
+  // Parse the SSA id and add an AffineDim/SymbolExpr to represent it.
+  if (parseElement(isSymbol))
+    return (emitError("failed to parse ssa identifier"), nullptr);
+  auto idExpr = isSymbol
+                    ? getAffineSymbolExpr(numSymbolOperands++, getContext())
+                    : getAffineDimExpr(numDimOperands++, getContext());
+  dimsAndSymbols.push_back({name, idExpr});
+  return idExpr;
+}
+
+AffineExpr AffineParser::parseSymbolSSAIdExpr() {
+  if (parseToken(Token::kw_symbol, "expected symbol keyword") ||
+      parseToken(Token::l_paren, "expected '(' at start of SSA symbol"))
+    return nullptr;
+  AffineExpr symbolExpr = parseSSAIdExpr(/*isSymbol=*/true);
+  if (!symbolExpr)
+    return nullptr;
+  if (parseToken(Token::r_paren, "expected ')' at end of SSA symbol"))
+    return nullptr;
+  return symbolExpr;
+}
+
 /// Parse a positive integral constant appearing in an affine expression.
 ///
 ///   affine-expr ::= integer-literal
@@ -1917,6 +1970,10 @@ AffineExpr AffineParser::parseAffineOperandExpr(AffineExpr lhs) {
   switch (getToken().getKind()) {
   case Token::bare_identifier:
     return parseBareIdExpr();
+  case Token::kw_symbol:
+    return parseSymbolSSAIdExpr();
+  case Token::percent_identifier:
+    return parseSSAIdExpr(/*isSymbol=*/false);
   case Token::integer:
     return parseIntegerExpr();
   case Token::l_paren:
@@ -2109,6 +2166,26 @@ ParseResult AffineParser::parseAffineMapOrIntegerSetInline(AffineMap &map,
   return failure();
 }
 
+/// Parse an AffineMap where the dim and symbol identifiers are SSA ids.
+ParseResult AffineParser::parseAffineMapOfSSAIds(AffineMap &map) {
+  SmallVector<AffineExpr, 4> exprs;
+  auto parseElt = [&]() -> ParseResult {
+    auto elt = parseAffineExpr();
+    exprs.push_back(elt);
+    return elt ? success() : failure();
+  };
+
+  // Parse a multi-dimensional affine expression (a comma-separated list of
+  // 1-d affine expressions); the list cannot be empty. Grammar:
+  // multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)
+  if (parseCommaSeparatedList(parseElt))
+    return failure();
+  // Parsed a valid affine map.
+  map = builder.getAffineMap(numDimOperands,
+                             dimsAndSymbols.size() - numDimOperands, exprs);
+  return success();
+}
+
 /// Parse the range and sizes affine map definition inline.
 ///
 ///  affine-map ::= dim-and-symbol-id-lists `->` multi-dim-affine-expr
@@ -2221,6 +2298,14 @@ ParseResult Parser::parseAffineMapOrIntegerSetReference(AffineMap &map,
   return AffineParser(state).parseAffineMapOrIntegerSetInline(map, set);
 }
 
+/// Parse an AffineMap of SSA ids. The callback 'parseElement' is used to
+/// parse SSA value uses encountered while parsing affine expressions.
+ParseResult Parser::parseAffineMapOfSSAIds(
+    AffineMap &map, llvm::function_ref<ParseResult(bool)> parseElement) {
+  return AffineParser(state, /*allowParsingSSAIds=*/true, parseElement)
+      .parseAffineMapOfSSAIds(map);
+}
+
 //===----------------------------------------------------------------------===//
 // OperationParser
 //===----------------------------------------------------------------------===//
@@ -3014,6 +3099,16 @@ public:
     return success(parser.consumeIf(Token::r_paren));
   }
 
+  /// Parse a `[` token.
+  ParseResult parseLSquare() override {
+    return parser.parseToken(Token::l_square, "expected '['");
+  }
+
+  /// Parse a `]` token.
+  ParseResult parseRSquare() override {
+    return parser.parseToken(Token::r_square, "expected ']'");
+  }
+
   //===--------------------------------------------------------------------===//
   // Attribute Parsing
   //===--------------------------------------------------------------------===//
@@ -3155,6 +3250,37 @@ public:
     return failure();
   }
 
+  /// Parse an AffineMap of SSA ids.
+  ParseResult parseAffineMapOfSSAIds(SmallVectorImpl<OperandType> &operands,
+                                     Attribute &mapAttr, StringRef attrName,
+                                     SmallVectorImpl<NamedAttribute> &attrs) {
+    SmallVector<OperandType, 2> dimOperands;
+    SmallVector<OperandType, 1> symOperands;
+
+    auto parseElement = [&](bool isSymbol) -> ParseResult {
+      OperandType operand;
+      if (parseOperand(operand))
+        return failure();
+      if (isSymbol)
+        symOperands.push_back(operand);
+      else
+        dimOperands.push_back(operand);
+      return success();
+    };
+
+    AffineMap map;
+    if (parser.parseAffineMapOfSSAIds(map, parseElement))
+      return failure();
+    // Add AffineMap attribute.
+    mapAttr = parser.builder.getAffineMapAttr(map);
+    attrs.push_back(parser.builder.getNamedAttr(attrName, mapAttr));
+
+    // Add dim operands before symbol operands in 'operands'.
+    operands.assign(dimOperands.begin(), dimOperands.end());
+    operands.append(symOperands.begin(), symOperands.end());
+    return success();
+  }
+
   //===--------------------------------------------------------------------===//
   // Region Parsing
   //===--------------------------------------------------------------------===//
index 98645cb..18067a8 100644 (file)
@@ -113,6 +113,7 @@ TOK_KEYWORD(opaque)
 TOK_KEYWORD(size)
 TOK_KEYWORD(sparse)
 TOK_KEYWORD(step)
+TOK_KEYWORD(symbol)
 TOK_KEYWORD(tensor)
 TOK_KEYWORD(to)
 TOK_KEYWORD(true)
diff --git a/mlir/test/AffineOps/load-store.mlir b/mlir/test/AffineOps/load-store.mlir
new file mode 100644 (file)
index 0000000..fdb32f8
--- /dev/null
@@ -0,0 +1,168 @@
+// RUN: mlir-opt %s -split-input-file | FileCheck %s
+
+// -----
+
+// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (d0, d1)
+
+// Test with just loop IVs.
+func @test0(%arg0 : index, %arg1 : index) {
+  %0 = alloc() : memref<100x100xf32>
+  affine.for %i0 = 0 to 10 {
+    affine.for %i1 = 0 to 10 {
+      %1 = affine.load %0[%i0, %i1] : memref<100x100xf32>
+// CHECK: %1 = affine.load %0[%i0, %i1] : memref<100x100xf32>
+    }
+  }
+  return
+}
+
+// -----
+
+// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1) -> (d0 + 3, d1 + 7)
+
+// Test with loop IVs and constants.
+func @test1(%arg0 : index, %arg1 : index) {
+  %0 = alloc() : memref<100x100xf32>
+  affine.for %i0 = 0 to 10 {
+    affine.for %i1 = 0 to 10 {
+      %1 = affine.load %0[%i0 + 3, %i1 + 7] : memref<100x100xf32>
+      affine.store %1, %0[%i0 + 3, %i1 + 7] : memref<100x100xf32>
+// CHECK: %1 = affine.load %0[%i0 + 3, %i1 + 7] : memref<100x100xf32>
+// CHECK: affine.store %1, %0[%i0 + 3, %i1 + 7] : memref<100x100xf32>
+    }
+  }
+  return
+}
+
+// -----
+
+// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1, d2, d3) -> (d0 + d1, d2 + d3)
+
+// Test with loop IVs and function args without 'symbol' keyword (should
+// be parsed as dim identifiers).
+func @test2(%arg0 : index, %arg1 : index) {
+  %0 = alloc() : memref<100x100xf32>
+  affine.for %i0 = 0 to 10 {
+    affine.for %i1 = 0 to 10 {
+      %1 = affine.load %0[%i0 + %arg0, %i1 + %arg1] : memref<100x100xf32>
+      affine.store %1, %0[%i0 + %arg0, %i1 + %arg1] : memref<100x100xf32>
+// CHECK: %1 = affine.load %0[%i0 + %arg0, %i1 + %arg1] : memref<100x100xf32>
+// CHECK: affine.store %1, %0[%i0 + %arg0, %i1 + %arg1] : memref<100x100xf32>
+    }
+  }
+  return
+}
+
+// -----
+
+// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1)[s0, s1] -> (d0 + s0, d1 + s1)
+
+// Test with loop IVs and function args with 'symbol' keyword (should
+// be parsed as symbol identifiers).
+func @test3(%arg0 : index, %arg1 : index) {
+  %0 = alloc() : memref<100x100xf32>
+  affine.for %i0 = 0 to 10 {
+    affine.for %i1 = 0 to 10 {
+      %1 = affine.load %0[%i0 + symbol(%arg0), %i1 + symbol(%arg1)]
+        : memref<100x100xf32>
+      affine.store %1, %0[%i0 + symbol(%arg0), %i1 + symbol(%arg1)]
+        : memref<100x100xf32>
+// CHECK: %1 = affine.load %0[%i0 + symbol(%arg0), %i1 + symbol(%arg1)] : memref<100x100xf32>
+// CHECK: affine.store %1, %0[%i0 + symbol(%arg0), %i1 + symbol(%arg1)] : memref<100x100xf32>
+    }
+  }
+  return
+}
+
+// -----
+
+// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1)[s0, s1] -> ((d0 + s0) floordiv 3 + 11, (d1 + s1) mod 4 + 7)
+
+// Test with loop IVs, symbols and constants in nested affine expressions.
+func @test4(%arg0 : index, %arg1 : index) {
+  %0 = alloc() : memref<100x100xf32>
+  affine.for %i0 = 0 to 10 {
+    affine.for %i1 = 0 to 10 {
+      %1 = affine.load %0[(%i0 + symbol(%arg0)) floordiv 3 + 11,
+                          (%i1 + symbol(%arg1)) mod 4 + 7] : memref<100x100xf32>
+      affine.store %1, %0[(%i0 + symbol(%arg0)) floordiv 3 + 11,
+                          (%i1 + symbol(%arg1)) mod 4 + 7] : memref<100x100xf32>
+// CHECK: %1 = affine.load %0[(%i0 + symbol(%arg0)) floordiv 3 + 11, (%i1 + symbol(%arg1)) mod 4 + 7] : memref<100x100xf32>
+// CHECK: affine.store %1, %0[(%i0 + symbol(%arg0)) floordiv 3 + 11, (%i1 + symbol(%arg1)) mod 4 + 7] : memref<100x100xf32>
+    }
+  }
+  return
+}
+
+// -----
+
+// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1, d2) -> (d0, d1, d2)
+
+// Test with swizzled loop IVs.
+func @test5(%arg0 : index, %arg1 : index) {
+  %0 = alloc() : memref<10x10x10xf32>
+  affine.for %i0 = 0 to 10 {
+    affine.for %i1 = 0 to 10 {
+      affine.for %i2 = 0 to 10 {
+        %1 = affine.load %0[%i2, %i0, %i1] : memref<10x10x10xf32>
+        affine.store %1, %0[%i2, %i0, %i1] : memref<10x10x10xf32>
+// CHECK: %1 = affine.load %0[%i2, %i0, %i1] : memref<10x10x10xf32>
+// CHECK: affine.store %1, %0[%i2, %i0, %i1] : memref<10x10x10xf32>
+      }
+    }
+  }
+  return
+}
+
+// -----
+
+// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1, d2, d3, d4) -> (d0 + d1, d2 + d3, d3 + d1 + d4)
+
+// Test with swizzled loop IVs, duplicate args, and function args used as dims.
+// Dim identifiers are assigned in parse order:
+// d0 = %i2, d1 = %arg0, d2 = %i0, d3 = %i1, d4 = %arg1
+func @test6(%arg0 : index, %arg1 : index) {
+  %0 = alloc() : memref<10x10x10xf32>
+  affine.for %i0 = 0 to 10 {
+    affine.for %i1 = 0 to 10 {
+      affine.for %i2 = 0 to 10 {
+        %1 = affine.load %0[%i2 + %arg0, %i0 + %i1, %i1 + %arg0 + %arg1]
+          : memref<10x10x10xf32>
+        affine.store %1, %0[%i2 + %arg0, %i0 + %i1, %i1 + %arg0 + %arg1]
+          : memref<10x10x10xf32>
+// CHECK: %1 = affine.load %0[%i2 + %arg0, %i0 + %i1, %i1 + %arg0 + %arg1] : memref<10x10x10xf32>
+// CHECK: affine.store %1, %0[%i2 + %arg0, %i0 + %i1, %i1 + %arg0 + %arg1] : memref<10x10x10xf32>
+      }
+    }
+  }
+  return
+}
+
+// -----
+
+// CHECK: [[MAP0:#map[0-9]+]] = (d0, d1, d2)[s0, s1] -> (d0 + s0, d1 + d2, d2 + s0 + s1)
+
+// Test with swizzled loop IVs, duplicate args, and function args used as syms.
+// Dim and symbol identifiers are assigned in parse order:
+// d0 = %i2, d1 = %i0, d2 = %i1
+// s0 = %arg0, s1 = %arg1
+func @test6(%arg0 : index, %arg1 : index) {
+  %0 = alloc() : memref<10x10x10xf32>
+  affine.for %i0 = 0 to 10 {
+    affine.for %i1 = 0 to 10 {
+      affine.for %i2 = 0 to 10 {
+        %1 = affine.load %0[%i2 + symbol(%arg0),
+                            %i0 + %i1,
+                            %i1 + symbol(%arg0) + symbol(%arg1)]
+                              : memref<10x10x10xf32>
+        affine.store %1, %0[%i2 + symbol(%arg0),
+                             %i0 + %i1,
+                             %i1 + symbol(%arg0) + symbol(%arg1)]
+                              : memref<10x10x10xf32>
+// CHECK: %1 = affine.load %0[%i2 + symbol(%arg0), %i0 + %i1, %i1 + symbol(%arg0) + symbol(%arg1)] : memref<10x10x10xf32>
+// CHECK: affine.store %1, %0[%i2 + symbol(%arg0), %i0 + %i1, %i1 + symbol(%arg0) + symbol(%arg1)] : memref<10x10x10xf32>
+      }
+    }
+  }
+  return
+}
\ No newline at end of file