[mlir] Refactoring dialect and test code to use parseCommaSeparatedList
authorJakub Tucholski <jtucholski00@gmail.com>
Mon, 2 May 2022 17:55:36 +0000 (13:55 -0400)
committerJakub Tucholski <jtucholski00@gmail.com>
Mon, 9 May 2022 16:36:54 +0000 (12:36 -0400)
Issue #55173

Reviewed By: lattner, rriddle

Differential Revision: https://reviews.llvm.org/D124791

mlir/include/mlir/IR/OpImplementation.h
mlir/lib/Dialect/Affine/IR/AffineOps.cpp
mlir/lib/Dialect/DLTI/DLTI.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
mlir/lib/Dialect/PDL/IR/PDL.cpp
mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp

index 0a13118..f484c2b 100644 (file)
@@ -929,13 +929,8 @@ public:
 
   /// Parse a type list.
   ParseResult parseTypeList(SmallVectorImpl<Type> &result) {
-    do {
-      Type type;
-      if (parseType(type))
-        return failure();
-      result.push_back(type);
-    } while (succeeded(parseOptionalComma()));
-    return success();
+    return parseCommaSeparatedList(
+        [&]() { return parseType(result.emplace_back()); });
   }
 
   /// Parse an arrow followed by a type list.
index 5e33d0c..a841048 100644 (file)
@@ -3454,7 +3454,7 @@ static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser,
   SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> flatSymOperands;
   SmallVector<int32_t> numMapsPerGroup;
   SmallVector<OpAsmParser::UnresolvedOperand> mapOperands;
-  do {
+  auto parseOperands = [&]() {
     if (succeeded(parser.parseOptionalKeyword(
             kind == MinMaxKind::Min ? "min" : "max"))) {
       mapOperands.clear();
@@ -3482,9 +3482,9 @@ static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser,
         return failure();
       numMapsPerGroup.push_back(1);
     }
-  } while (succeeded(parser.parseOptionalComma()));
-
-  if (failed(parser.parseRParen()))
+    return success();
+  };
+  if (parser.parseCommaSeparatedList(parseOperands) || parser.parseRParen())
     return failure();
 
   unsigned totalNumDims = 0;
@@ -3572,7 +3572,7 @@ ParseResult AffineParallelOp::parse(OpAsmParser &parser,
   if (succeeded(parser.parseOptionalKeyword("reduce"))) {
     if (parser.parseLParen())
       return failure();
-    do {
+    auto parseAttributes = [&]() -> ParseResult {
       // Parse a single quoted string via the attribute parsing, and then
       // verify it is a member of the enum and convert to it's integer
       // representation.
@@ -3589,8 +3589,9 @@ ParseResult AffineParallelOp::parse(OpAsmParser &parser,
       reductions.push_back(builder.getI64IntegerAttr(
           static_cast<int64_t>(reduction.getValue())));
       // While we keep getting commas, keep parsing.
-    } while (succeeded(parser.parseOptionalComma()));
-    if (parser.parseRParen())
+      return success();
+    };
+    if (parser.parseCommaSeparatedList(parseAttributes) || parser.parseRParen())
       return failure();
   }
   result.addAttribute(AffineParallelOp::getReductionsAttrName(),
index 7382fba..046152a 100644 (file)
@@ -286,14 +286,11 @@ DataLayoutSpecAttr DataLayoutSpecAttr::parse(AsmParser &parser) {
     return get(parser.getContext(), {});
 
   SmallVector<DataLayoutEntryInterface> entries;
-  do {
-    entries.emplace_back();
-    if (failed(parser.parseAttribute(entries.back())))
-      return {};
-  } while (succeeded(parser.parseOptionalComma()));
-
-  if (failed(parser.parseGreater()))
+  if (parser.parseCommaSeparatedList(
+          [&]() { return parser.parseAttribute(entries.emplace_back()); }) ||
+      parser.parseGreater())
     return {};
+
   return getChecked([&] { return parser.emitError(parser.getNameLoc()); },
                     parser.getContext(), entries);
 }
index f80a429..3724b49 100644 (file)
@@ -540,7 +540,7 @@ parseGEPIndices(OpAsmParser &parser,
                 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &indices,
                 DenseIntElementsAttr &structIndices) {
   SmallVector<int32_t> constantIndices;
-  do {
+  parser.parseCommaSeparatedList([&]() -> ParseResult {
     int32_t constantIndex;
     OptionalParseResult parsedInteger =
         parser.parseOptionalInteger(constantIndex);
@@ -548,13 +548,12 @@ parseGEPIndices(OpAsmParser &parser,
       if (failed(parsedInteger.getValue()))
         return failure();
       constantIndices.push_back(constantIndex);
-      continue;
+      return success();
     }
 
     constantIndices.push_back(LLVM::GEPOp::kDynamicIndex);
-    if (failed(parser.parseOperand(indices.emplace_back())))
-      return failure();
-  } while (succeeded(parser.parseOptionalComma()));
+    return parser.parseOperand(indices.emplace_back());
+  });
 
   structIndices = parser.getBuilder().getI32TensorAttr(constantIndices);
   return success();
@@ -2868,22 +2867,21 @@ Attribute FMFAttr::parse(AsmParser &parser, Type type) {
 
   FastmathFlags flags = {};
   if (failed(parser.parseOptionalGreater())) {
-    do {
+    auto parseFlags = [&]() -> ParseResult {
       StringRef elemName;
       if (failed(parser.parseKeyword(&elemName)))
-        return {};
+        return failure();
 
       auto elem = symbolizeFastmathFlags(elemName);
-      if (!elem) {
-        parser.emitError(parser.getNameLoc(), "Unknown fastmath flag: ")
-            << elemName;
-        return {};
-      }
+      if (!elem)
+        return parser.emitError(parser.getNameLoc(), "Unknown fastmath flag: ")
+               << elemName;
 
       flags = flags | *elem;
-    } while (succeeded(parser.parseOptionalComma()));
-
-    if (failed(parser.parseGreater()))
+      return success();
+    };
+    if (failed(parser.parseCommaSeparatedList(parseFlags)) ||
+        parser.parseGreater())
       return {};
   }
 
@@ -3031,23 +3029,19 @@ Attribute LoopOptionsAttr::parse(AsmParser &parser, Type type) {
 
   SmallVector<std::pair<LoopOptionCase, int64_t>> options;
   llvm::SmallDenseSet<LoopOptionCase> seenOptions;
-  do {
+  auto parseLoopOptions = [&]() -> ParseResult {
     StringRef optionName;
     if (parser.parseKeyword(&optionName))
-      return {};
+      return failure();
 
     auto option = symbolizeLoopOptionCase(optionName);
-    if (!option) {
-      parser.emitError(parser.getNameLoc(), "unknown loop option: ")
-          << optionName;
-      return {};
-    }
-    if (!seenOptions.insert(*option).second) {
-      parser.emitError(parser.getNameLoc(), "loop option present twice");
-      return {};
-    }
+    if (!option)
+      return parser.emitError(parser.getNameLoc(), "unknown loop option: ")
+             << optionName;
+    if (!seenOptions.insert(*option).second)
+      return parser.emitError(parser.getNameLoc(), "loop option present twice");
     if (failed(parser.parseEqual()))
-      return {};
+      return failure();
 
     int64_t value;
     switch (*option) {
@@ -3059,22 +3053,20 @@ Attribute LoopOptionsAttr::parse(AsmParser &parser, Type type) {
       else if (succeeded(parser.parseOptionalKeyword("false")))
         value = 0;
       else {
-        parser.emitError(parser.getNameLoc(),
-                         "expected boolean value 'true' or 'false'");
-        return {};
+        return parser.emitError(parser.getNameLoc(),
+                                "expected boolean value 'true' or 'false'");
       }
       break;
     case LoopOptionCase::interleave_count:
     case LoopOptionCase::pipeline_initiation_interval:
-      if (failed(parser.parseInteger(value))) {
-        parser.emitError(parser.getNameLoc(), "expected integer value");
-        return {};
-      }
+      if (failed(parser.parseInteger(value)))
+        return parser.emitError(parser.getNameLoc(), "expected integer value");
       break;
     }
     options.push_back(std::make_pair(*option, value));
-  } while (succeeded(parser.parseOptionalComma()));
-  if (failed(parser.parseGreater()))
+    return success();
+  };
+  if (parser.parseCommaSeparatedList(parseLoopOptions) || parser.parseGreater())
     return {};
 
   llvm::sort(options, llvm::less_first());
index 2a25cef..76de3a3 100644 (file)
@@ -66,19 +66,19 @@ parseOperandList(OpAsmParser &parser, StringRef keyword,
   if (succeeded(parser.parseOptionalRParen()))
     return success();
 
-  do {
-    OpAsmParser::UnresolvedOperand arg;
-    Type type;
-
-    if (parser.parseOperand(arg, /*allowResultNumber=*/false) ||
-        parser.parseColonType(type))
-      return failure();
-
-    args.push_back(arg);
-    argTypes.push_back(type);
-  } while (succeeded(parser.parseOptionalComma()));
-
-  if (failed(parser.parseRParen()))
+  if (failed(parser.parseCommaSeparatedList([&]() {
+        OpAsmParser::UnresolvedOperand arg;
+        Type type;
+
+        if (parser.parseOperand(arg, /*allowResultNumber=*/false) ||
+            parser.parseColonType(type))
+          return failure();
+
+        args.push_back(arg);
+        argTypes.push_back(type);
+        return success();
+      })) ||
+      failed(parser.parseRParen()))
     return failure();
 
   return parser.resolveOperands(args, argTypes, parser.getCurrentLocation(),
index 15e2836..565744c 100644 (file)
@@ -76,7 +76,7 @@ static ParseResult parseAllocateAndAllocator(
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operandsAllocator,
     SmallVectorImpl<Type> &typesAllocator) {
 
-  return parser.parseCommaSeparatedList([&]() -> ParseResult {
+  return parser.parseCommaSeparatedList([&]() {
     OpAsmParser::UnresolvedOperand operand;
     Type type;
     if (parser.parseOperand(operand) || parser.parseColonType(type))
@@ -142,7 +142,7 @@ parseLinearClause(OpAsmParser &parser,
                   SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
                   SmallVectorImpl<Type> &types,
                   SmallVectorImpl<OpAsmParser::UnresolvedOperand> &stepVars) {
-  do {
+  return parser.parseCommaSeparatedList([&]() {
     OpAsmParser::UnresolvedOperand var;
     Type type;
     OpAsmParser::UnresolvedOperand stepVar;
@@ -153,8 +153,8 @@ parseLinearClause(OpAsmParser &parser,
     vars.push_back(var);
     types.push_back(type);
     stepVars.push_back(stepVar);
-  } while (succeeded(parser.parseOptionalComma()));
-  return success();
+    return success();
+  });
 }
 
 /// Print Linear Clause
@@ -304,12 +304,15 @@ parseReductionVarList(OpAsmParser &parser,
                       SmallVectorImpl<Type> &types,
                       ArrayAttr &redcuctionSymbols) {
   SmallVector<SymbolRefAttr> reductionVec;
-  do {
-    if (parser.parseAttribute(reductionVec.emplace_back()) ||
-        parser.parseArrow() || parser.parseOperand(operands.emplace_back()) ||
-        parser.parseColonType(types.emplace_back()))
-      return failure();
-  } while (succeeded(parser.parseOptionalComma()));
+  if (failed(parser.parseCommaSeparatedList([&]() {
+        if (parser.parseAttribute(reductionVec.emplace_back()) ||
+            parser.parseArrow() ||
+            parser.parseOperand(operands.emplace_back()) ||
+            parser.parseColonType(types.emplace_back()))
+          return failure();
+        return success();
+      })))
+    return failure();
   SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end());
   redcuctionSymbols = ArrayAttr::get(parser.getContext(), reductions);
   return success();
@@ -386,7 +389,7 @@ static ParseResult parseSynchronizationHint(OpAsmParser &parser,
     hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
     return success();
   }
-  do {
+  auto parseKeyword = [&]() -> ParseResult {
     if (failed(parser.parseKeyword(&hintKeyword)))
       return failure();
     if (hintKeyword == "uncontended")
@@ -400,7 +403,10 @@ static ParseResult parseSynchronizationHint(OpAsmParser &parser,
     else
       return parser.emitError(parser.getCurrentLocation())
              << hintKeyword << " is not a valid hint";
-  } while (succeeded(parser.parseOptionalComma()));
+    return success();
+  };
+  if (parser.parseCommaSeparatedList(parseKeyword))
+    return failure();
   hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
   return success();
 }
index 417a2da..7fba6ca 100644 (file)
@@ -148,7 +148,7 @@ static ParseResult parseOperationOpAttributes(
   Builder &builder = p.getBuilder();
   SmallVector<Attribute, 4> attrNames;
   if (succeeded(p.parseOptionalLBrace())) {
-    do {
+    auto parseOperands = [&]() {
       StringAttr nameAttr;
       OpAsmParser::UnresolvedOperand operand;
       if (p.parseAttribute(nameAttr) || p.parseEqual() ||
@@ -156,8 +156,9 @@ static ParseResult parseOperationOpAttributes(
         return failure();
       attrNames.push_back(nameAttr);
       attrOperands.push_back(operand);
-    } while (succeeded(p.parseOptionalComma()));
-    if (p.parseRBrace())
+      return success();
+    };
+    if (p.parseCommaSeparatedList(parseOperands) || p.parseRBrace())
       return failure();
   }
   attrNamesAttr = builder.getArrayAttr(attrNames);
index 40e5599..01670e3 100644 (file)
@@ -71,7 +71,7 @@ static ParseResult parseCreateOperationOpAttributes(
   Builder &builder = p.getBuilder();
   SmallVector<Attribute, 4> attrNames;
   if (succeeded(p.parseOptionalLBrace())) {
-    do {
+    auto parseOperands = [&]() {
       StringAttr nameAttr;
       OpAsmParser::UnresolvedOperand operand;
       if (p.parseAttribute(nameAttr) || p.parseEqual() ||
@@ -79,8 +79,9 @@ static ParseResult parseCreateOperationOpAttributes(
         return failure();
       attrNames.push_back(nameAttr);
       attrOperands.push_back(operand);
-    } while (succeeded(p.parseOptionalComma()));
-    if (p.parseRBrace())
+      return success();
+    };
+    if (p.parseCommaSeparatedList(parseOperands) || p.parseRBrace())
       return failure();
   }
   attrNamesAttr = builder.getArrayAttr(attrNames);
index 5745aee..927ff54 100644 (file)
@@ -592,7 +592,7 @@ static ParseResult parseStructMemberDecorations(
     return failure();
 
   // Check for spirv::Decorations.
-  do {
+  auto parseDecorations = [&]() {
     auto memberDecoration = parseAndVerify<spirv::Decoration>(dialect, parser);
     if (!memberDecoration)
       return failure();
@@ -613,10 +613,13 @@ static ParseResult parseStructMemberDecorations(
           static_cast<uint32_t>(memberTypes.size() - 1), 0,
           memberDecoration.getValue(), 0);
     }
+    return success();
+  };
+  if (failed(parser.parseCommaSeparatedList(parseDecorations)) ||
+      failed(parser.parseRSquare()))
+    return failure();
 
-  } while (succeeded(parser.parseOptionalComma()));
-
-  return parser.parseRSquare();
+  return success();
 }
 
 // struct-member-decoration ::= integer-literal? spirv-decoration*
@@ -885,17 +888,16 @@ static ParseResult parseKeywordList(
 
   // Keep parsing the keyword and an optional comma following it. If the comma
   // is successfully parsed, then we have more keywords to parse.
-  do {
-    auto loc = parser.getCurrentLocation();
-    StringRef keyword;
-    if (parser.parseKeyword(&keyword) || failed(processKeyword(loc, keyword)))
-      return failure();
-  } while (succeeded(parser.parseOptionalComma()));
-
-  if (parser.parseRSquare())
+  if (failed(parser.parseCommaSeparatedList([&]() {
+        auto loc = parser.getCurrentLocation();
+        StringRef keyword;
+        if (parser.parseKeyword(&keyword) ||
+            failed(processKeyword(loc, keyword)))
+          return failure();
+        return success();
+      })))
     return failure();
-
-  return success();
+  return parser.parseRSquare();
 }
 
 /// Parses a spirv::InterfaceVarABIAttr.
index 4be64e4..ab144cd 100644 (file)
@@ -212,11 +212,12 @@ struct DLTestDialect : Dialect {
       return CustomDataLayoutSpec::get(parser.getContext(), {});
 
     SmallVector<DataLayoutEntryInterface> entries;
-    do {
+    parser.parseCommaSeparatedList([&]() {
       entries.emplace_back();
       ok = succeeded(parser.parseAttribute(entries.back()));
       assert(ok);
-    } while (succeeded(parser.parseOptionalComma()));
+      return success();
+    });
     ok = succeeded(parser.parseGreater());
     assert(ok);
     return CustomDataLayoutSpec::get(parser.getContext(), entries);