Add a few utility overloads for OpAsmParser methods:
authorRiver Riddle <riverriddle@google.com>
Wed, 5 Jun 2019 06:33:18 +0000 (23:33 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 9 Jun 2019 23:18:21 +0000 (16:18 -0700)
* Add a getCurrentLocation that returns the location directly.
* Add parseOperandList/parseTrailingOperandList overloads without the required operand count.

PiperOrigin-RevId: 251585488

mlir/examples/Linalg/Linalg1/lib/ViewOp.cpp
mlir/include/mlir/IR/OpImplementation.h
mlir/lib/AffineOps/AffineOps.cpp
mlir/lib/GPU/IR/GPUDialect.cpp
mlir/lib/IR/Function.cpp
mlir/lib/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Linalg/IR/LinalgOps.cpp
mlir/lib/Parser/Parser.cpp
mlir/lib/StandardOps/Ops.cpp
mlir/lib/VectorOps/VectorOps.cpp

index 0e9a4c3..f1017b1 100644 (file)
@@ -94,8 +94,7 @@ ParseResult linalg::ViewOp::parse(OpAsmParser *parser, OperationState *result) {
   SmallVector<OpAsmParser::OperandType, 8> indexingsInfo;
   SmallVector<Type, 8> types;
   if (parser->parseOperand(memRefInfo) ||
-      parser->parseOperandList(indexingsInfo, -1,
-                               OpAsmParser::Delimiter::Square) ||
+      parser->parseOperandList(indexingsInfo, OpAsmParser::Delimiter::Square) ||
       parser->parseOptionalAttributeDict(result->attributes) ||
       parser->parseColonTypeList(types))
     return failure();
index 1ebed7c..d0c445f 100644 (file)
@@ -152,7 +152,11 @@ public:
 
   /// Get the location of the next token and store it into the argument.  This
   /// always succeeds.
-  virtual ParseResult getCurrentLocation(llvm::SMLoc *loc) = 0;
+  virtual llvm::SMLoc getCurrentLocation() = 0;
+  ParseResult getCurrentLocation(llvm::SMLoc *loc) {
+    *loc = getCurrentLocation();
+    return success();
+  }
 
   /// This parses... a comma!
   virtual ParseResult parseComma() = 0;
@@ -189,8 +193,7 @@ public:
 
   /// Parse a type of a specific kind, e.g. a FunctionType.
   template <typename TypeType> ParseResult parseColonType(TypeType &result) {
-    llvm::SMLoc loc;
-    getCurrentLocation(&loc);
+    llvm::SMLoc loc = getCurrentLocation();
 
     // Parse any kind of type.
     Type type;
@@ -261,8 +264,7 @@ public:
   template <typename AttrType>
   ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName,
                              SmallVectorImpl<NamedAttribute> &attrs) {
-    llvm::SMLoc loc;
-    getCurrentLocation(&loc);
+    llvm::SMLoc loc = getCurrentLocation();
 
     // Parse any kind of attribute.
     Attribute attr;
@@ -298,7 +300,7 @@ public:
 
   /// These are the supported delimiters around operand lists, used by
   /// parseOperandList.
-  enum Delimiter {
+  enum class Delimiter {
     /// Zero or more operands with no delimiters.
     None,
     /// Parens surrounding zero or more operands.
@@ -317,6 +319,10 @@ public:
   parseOperandList(SmallVectorImpl<OperandType> &result,
                    int requiredOperandCount = -1,
                    Delimiter delimiter = Delimiter::None) = 0;
+  ParseResult parseOperandList(SmallVectorImpl<OperandType> &result,
+                               Delimiter delimiter) {
+    return parseOperandList(result, /*requiredOperandCount=*/-1, delimiter);
+  }
 
   /// Parse zero or more trailing SSA comma-separated trailing operand
   /// references with a specified surrounding delimiter, and an optional
@@ -325,6 +331,11 @@ public:
   parseTrailingOperandList(SmallVectorImpl<OperandType> &result,
                            int requiredOperandCount = -1,
                            Delimiter delimiter = Delimiter::None) = 0;
+  ParseResult parseTrailingOperandList(SmallVectorImpl<OperandType> &result,
+                                       Delimiter delimiter) {
+    return parseTrailingOperandList(result, /*requiredOperandCount=*/-1,
+                                    delimiter);
+  }
 
   /// Parses a region. Any parsed blocks are appended to "region" and must be
   /// moved to the op regions after the op is created. The first block of the
index 9189acf..fc70807 100644 (file)
@@ -842,8 +842,7 @@ static ParseResult parseBound(bool isLower, OperationState *result,
   }
 
   // Get the attribute location.
-  llvm::SMLoc attrLoc;
-  p->getCurrentLocation(&attrLoc);
+  llvm::SMLoc attrLoc = p->getCurrentLocation();
 
   Attribute boundAttr;
   if (p->parseAttribute(boundAttr, builder.getIndexType(), boundAttrName,
@@ -916,10 +915,9 @@ ParseResult AffineForOp::parse(OpAsmParser *parser, OperationState *result) {
         getStepAttrName(),
         builder.getIntegerAttr(builder.getIndexType(), /*value=*/1));
   } else {
-    llvm::SMLoc stepLoc;
+    llvm::SMLoc stepLoc = parser->getCurrentLocation();
     IntegerAttr stepAttr;
-    if (parser->getCurrentLocation(&stepLoc) ||
-        parser->parseAttribute(stepAttr, builder.getIndexType(),
+    if (parser->parseAttribute(stepAttr, builder.getIndexType(),
                                getStepAttrName().data(), result->attributes))
       return failure();
 
index 5c0539a..c252bec 100644 (file)
@@ -252,11 +252,11 @@ ParseResult LaunchOp::parse(OpAsmParser *parser, OperationState *result) {
   // to resolve the operands passed to the kernel arguments.
   SmallVector<Type, 4> dataTypes;
   if (!parser->parseOptionalKeyword(getArgsKeyword().data())) {
-    llvm::SMLoc argsLoc;
+    llvm::SMLoc argsLoc = parser->getCurrentLocation();
 
     regionArgs.push_back({});
     dataOperands.push_back({});
-    if (parser->getCurrentLocation(&argsLoc) || parser->parseLParen() ||
+    if (parser->parseLParen() ||
         parser->parseRegionArgument(regionArgs.back()) ||
         parser->parseEqual() || parser->parseOperand(dataOperands.back()))
       return failure();
index 6ab5a6f..78cce5f 100644 (file)
@@ -241,8 +241,7 @@ parseArgumentList(OpAsmParser *parser, SmallVectorImpl<Type> &argTypes,
   // types, or just be a type list.  It isn't ok to sometimes have SSA ID's and
   // sometimes not.
   auto parseArgument = [&]() -> ParseResult {
-    llvm::SMLoc loc;
-    parser->getCurrentLocation(&loc);
+    llvm::SMLoc loc = parser->getCurrentLocation();
 
     // Parse argument name if present.
     OpAsmParser::OperandType argument;
index e281702..e27994d 100644 (file)
@@ -196,8 +196,7 @@ static ParseResult parseGEPOp(OpAsmParser *parser, OperationState *result) {
   Type type;
   llvm::SMLoc trailingTypeLoc;
   if (parser->parseOperand(base) ||
-      parser->parseOperandList(indices, /*requiredOperandCount=*/-1,
-                               OpAsmParser::Delimiter::Square) ||
+      parser->parseOperandList(indices, OpAsmParser::Delimiter::Square) ||
       parser->parseOptionalAttributeDict(attrs) || parser->parseColon() ||
       parser->getCurrentLocation(&trailingTypeLoc) || parser->parseType(type))
     return failure();
@@ -352,8 +351,7 @@ static ParseResult parseCallOp(OpAsmParser *parser, OperationState *result) {
     if (parser->parseAttribute(funcAttr, "callee", attrs))
       return failure();
 
-  if (parser->parseOperandList(operands, /*requiredOperandCount=*/-1,
-                               OpAsmParser::Delimiter::Paren) ||
+  if (parser->parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
       parser->parseOptionalAttributeDict(attrs) || parser->parseColon() ||
       parser->getCurrentLocation(&trailingTypeLoc) || parser->parseType(type))
     return failure();
index 3b3a040..7ccb0ff 100644 (file)
@@ -268,7 +268,7 @@ ParseResult mlir::linalg::LoadOp::parse(OpAsmParser *parser,
   auto affineIntTy = parser->getBuilder().getIndexType();
   return failure(
       parser->parseOperand(viewInfo) ||
-      parser->parseOperandList(indexInfo, -1, OpAsmParser::Delimiter::Square) ||
+      parser->parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
       parser->parseOptionalAttributeDict(result->attributes) ||
       parser->parseColonType(type) ||
       parser->resolveOperand(viewInfo, type, result->operands) ||
@@ -391,8 +391,7 @@ ParseResult mlir::linalg::SliceOp::parse(OpAsmParser *parser,
   SmallVector<OpAsmParser::OperandType, 8> indexingsInfo;
   SmallVector<Type, 8> types;
   if (parser->parseOperand(baseInfo) ||
-      parser->parseOperandList(indexingsInfo, -1,
-                               OpAsmParser::Delimiter::Square) ||
+      parser->parseOperandList(indexingsInfo, OpAsmParser::Delimiter::Square) ||
       parser->parseOptionalAttributeDict(result->attributes) ||
       parser->parseColonTypeList(types))
     return failure();
@@ -500,7 +499,7 @@ ParseResult mlir::linalg::StoreOp::parse(OpAsmParser *parser,
   return failure(
       parser->parseOperand(storeValueInfo) || parser->parseComma() ||
       parser->parseOperand(viewInfo) ||
-      parser->parseOperandList(indexInfo, -1, OpAsmParser::Delimiter::Square) ||
+      parser->parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
       parser->parseOptionalAttributeDict(result->attributes) ||
       parser->parseColonType(viewType) ||
       parser->resolveOperand(storeValueInfo, viewType.getElementType(),
@@ -576,8 +575,7 @@ ParseResult mlir::linalg::ViewOp::parse(OpAsmParser *parser,
   SmallVector<OpAsmParser::OperandType, 8> indexingsInfo;
   Type type;
   if (parser->parseOperand(bufferInfo) ||
-      parser->parseOperandList(indexingsInfo, -1,
-                               OpAsmParser::Delimiter::Square) ||
+      parser->parseOperandList(indexingsInfo, OpAsmParser::Delimiter::Square) ||
       parser->parseOptionalAttributeDict(result->attributes) ||
       parser->parseColonType(type))
     return failure();
@@ -712,12 +710,11 @@ static ParseResult parseLinalgLibraryOp(OpAsmParser *parser,
                                         OperationState *result) {
   SmallVector<OpAsmParser::OperandType, 3> ops;
   SmallVector<Type, 3> types;
-  return failure(
-      parser->parseOperandList(ops, -1, OpAsmParser::Delimiter::Paren) ||
-      parser->parseOptionalAttributeDict(result->attributes) ||
-      parser->parseColonTypeList(types) ||
-      parser->resolveOperands(ops, types, parser->getNameLoc(),
-                              result->operands));
+  return failure(parser->parseOperandList(ops, OpAsmParser::Delimiter::Paren) ||
+                 parser->parseOptionalAttributeDict(result->attributes) ||
+                 parser->parseColonTypeList(types) ||
+                 parser->resolveOperands(ops, types, parser->getNameLoc(),
+                                         result->operands));
 }
 
 namespace mlir {
index a3d44f9..2ca1da9 100644 (file)
@@ -3193,9 +3193,8 @@ public:
   // High level parsing methods.
   //===--------------------------------------------------------------------===//
 
-  ParseResult getCurrentLocation(llvm::SMLoc *loc) override {
-    *loc = parser.getToken().getLoc();
-    return success();
+  llvm::SMLoc getCurrentLocation() override {
+    return parser.getToken().getLoc();
   }
 
   ParseResult parseComma() override {
index da0a48b..4b2940a 100644 (file)
@@ -108,14 +108,14 @@ ParseResult mlir::parseDimAndSymbolList(OpAsmParser *parser,
                                         SmallVector<Value *, 4> &operands,
                                         unsigned &numDims) {
   SmallVector<OpAsmParser::OperandType, 8> opInfos;
-  if (parser->parseOperandList(opInfos, -1, OpAsmParser::Delimiter::Paren))
+  if (parser->parseOperandList(opInfos, OpAsmParser::Delimiter::Paren))
     return failure();
   // Store number of dimensions for validation by caller.
   numDims = opInfos.size();
 
   // Parse the optional symbol operands.
   auto affineIntTy = parser->getBuilder().getIndexType();
-  if (parser->parseOperandList(opInfos, -1,
+  if (parser->parseOperandList(opInfos,
                                OpAsmParser::Delimiter::OptionalSquare) ||
       parser->resolveOperands(opInfos, affineIntTy, operands))
     return failure();
@@ -415,8 +415,7 @@ static ParseResult parseCallOp(OpAsmParser *parser, OperationState *result) {
   SmallVector<OpAsmParser::OperandType, 4> operands;
   auto calleeLoc = parser->getNameLoc();
   if (parser->parseAttribute(calleeAttr, "callee", result->attributes) ||
-      parser->parseOperandList(operands, /*requiredOperandCount=*/-1,
-                               OpAsmParser::Delimiter::Paren) ||
+      parser->parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
       parser->parseOptionalAttributeDict(result->attributes) ||
       parser->parseColonType(calleeType) ||
       parser->addTypesToList(calleeType.getResults(), result->types) ||
@@ -507,8 +506,7 @@ static ParseResult parseCallIndirectOp(OpAsmParser *parser,
   return failure(
       parser->parseOperand(callee) ||
       parser->getCurrentLocation(&operandsLoc) ||
-      parser->parseOperandList(operands, /*requiredOperandCount=*/-1,
-                               OpAsmParser::Delimiter::Paren) ||
+      parser->parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
       parser->parseOptionalAttributeDict(result->attributes) ||
       parser->parseColonType(calleeType) ||
       parser->resolveOperand(callee, calleeType, result->operands) ||
@@ -1406,15 +1404,12 @@ ParseResult DmaStartOp::parse(OpAsmParser *parser, OperationState *result) {
   // *) destination memref followed by its indices (in square brackets).
   // *) dma size in KiB.
   if (parser->parseOperand(srcMemRefInfo) ||
-      parser->parseOperandList(srcIndexInfos, -1,
-                               OpAsmParser::Delimiter::Square) ||
+      parser->parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
       parser->parseComma() || parser->parseOperand(dstMemRefInfo) ||
-      parser->parseOperandList(dstIndexInfos, -1,
-                               OpAsmParser::Delimiter::Square) ||
+      parser->parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
       parser->parseComma() || parser->parseOperand(numElementsInfo) ||
       parser->parseComma() || parser->parseOperand(tagMemrefInfo) ||
-      parser->parseOperandList(tagIndexInfos, -1,
-                               OpAsmParser::Delimiter::Square))
+      parser->parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
     return failure();
 
   // Parse optional stride and elements per stride.
@@ -1537,8 +1532,7 @@ ParseResult DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) {
 
   // Parse tag memref, its indices, and dma size.
   if (parser->parseOperand(tagMemrefInfo) ||
-      parser->parseOperandList(tagIndexInfos, -1,
-                               OpAsmParser::Delimiter::Square) ||
+      parser->parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square) ||
       parser->parseComma() || parser->parseOperand(numElementsInfo) ||
       parser->parseColonType(type) ||
       parser->resolveOperand(tagMemrefInfo, type, result->operands) ||
@@ -1586,7 +1580,7 @@ static ParseResult parseExtractElementOp(OpAsmParser *parser,
   auto affineIntTy = parser->getBuilder().getIndexType();
   return failure(
       parser->parseOperand(aggregateInfo) ||
-      parser->parseOperandList(indexInfo, -1, OpAsmParser::Delimiter::Square) ||
+      parser->parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
       parser->parseOptionalAttributeDict(result->attributes) ||
       parser->parseColonType(type) ||
       parser->resolveOperand(aggregateInfo, type, result->operands) ||
@@ -1662,7 +1656,7 @@ static ParseResult parseLoadOp(OpAsmParser *parser, OperationState *result) {
   auto affineIntTy = parser->getBuilder().getIndexType();
   return failure(
       parser->parseOperand(memrefInfo) ||
-      parser->parseOperandList(indexInfo, -1, OpAsmParser::Delimiter::Square) ||
+      parser->parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
       parser->parseOptionalAttributeDict(result->attributes) ||
       parser->parseColonType(type) ||
       parser->resolveOperand(memrefInfo, type, result->operands) ||
@@ -1845,9 +1839,8 @@ OpFoldResult RemIUOp::fold(ArrayRef<Attribute> operands) {
 static ParseResult parseReturnOp(OpAsmParser *parser, OperationState *result) {
   SmallVector<OpAsmParser::OperandType, 2> opInfo;
   SmallVector<Type, 2> types;
-  llvm::SMLoc loc;
-  return failure(parser->getCurrentLocation(&loc) ||
-                 parser->parseOperandList(opInfo) ||
+  llvm::SMLoc loc = parser->getCurrentLocation();
+  return failure(parser->parseOperandList(opInfo) ||
                  (!opInfo.empty() && parser->parseColonTypeList(types)) ||
                  parser->resolveOperands(opInfo, types, loc, result->operands));
 }
@@ -1963,7 +1956,7 @@ static ParseResult parseStoreOp(OpAsmParser *parser, OperationState *result) {
   return failure(
       parser->parseOperand(storeValueInfo) || parser->parseComma() ||
       parser->parseOperand(memrefInfo) ||
-      parser->parseOperandList(indexInfo, -1, OpAsmParser::Delimiter::Square) ||
+      parser->parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
       parser->parseOptionalAttributeDict(result->attributes) ||
       parser->parseColonType(memrefType) ||
       parser->resolveOperand(storeValueInfo, memrefType.getElementType(),
index 23b2f99..580dd66 100644 (file)
@@ -129,8 +129,8 @@ ParseResult VectorTransferReadOp::parse(OpAsmParser *parser,
 
   // Parsing with support for optional paddingValue.
   if (parser->parseOperand(memrefInfo) ||
-      parser->parseOperandList(indexInfo, -1, OpAsmParser::Delimiter::Square) ||
-      parser->parseTrailingOperandList(paddingInfo, -1,
+      parser->parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
+      parser->parseTrailingOperandList(paddingInfo,
                                        OpAsmParser::Delimiter::Paren) ||
       parser->parseOptionalAttributeDict(result->attributes) ||
       parser->parseColonTypeList(types))
@@ -293,7 +293,7 @@ ParseResult VectorTransferWriteOp::parse(OpAsmParser *parser,
   auto indexType = parser->getBuilder().getIndexType();
   if (parser->parseOperand(storeValueInfo) || parser->parseComma() ||
       parser->parseOperand(memrefInfo) ||
-      parser->parseOperandList(indexInfo, -1, OpAsmParser::Delimiter::Square) ||
+      parser->parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
       parser->parseOptionalAttributeDict(result->attributes) ||
       parser->parseColonTypeList(types))
     return failure();