[fir][NFC] Update fir.iterate_while op
authorEric Schweitz <eschweitz@nvidia.com>
Wed, 29 Sep 2021 16:12:17 +0000 (18:12 +0200)
committerValentin Clement <clementval@gmail.com>
Wed, 29 Sep 2021 16:13:43 +0000 (18:13 +0200)
Add getFinalValueAttrName() and remove specified number of
inlined elements for SmallVector. This patch is mainly motivated
to help the upstreaming effort.

This patch is part of the upstreaming effort from fir-dev branch.

Co-authored-by: Jean Perier <jperier@nvidia.com>
Co-authored-by: Valentin Clement <clementval@gmail.com>
Reviewed By: kiranchandramohan

Differential Revision: https://reviews.llvm.org/D110710

flang/include/flang/Optimizer/Dialect/FIROps.td
flang/lib/Optimizer/Dialect/FIROps.cpp

index 279699b97df95d8b08fb73fa385d2a64131c775d..69a0b692833af520a2189bb6c848e78908ef70e7 100644 (file)
@@ -2577,6 +2577,9 @@ def fir_IterWhileOp : region_Op<"iterate_while",
   ];
 
   let extraClassDeclaration = [{
+    static constexpr llvm::StringRef getFinalValueAttrName() {
+      return "finalValue";
+    }
     mlir::Block *getBody() { return &region().front(); }
     mlir::Value getIterateVar() { return getBody()->getArgument(1); }
     mlir::Value getInductionVar() { return getBody()->getArgument(0); }
index b9e0d9a2a07b9330e4bdbf905717416d1e54caed..7f645e393bd047e2d5c0ad95c7aeec4636574170 100644 (file)
@@ -799,7 +799,7 @@ void fir::IterWhileOp::build(mlir::OpBuilder &builder,
   result.addOperands({lb, ub, step, iterate});
   if (finalCountValue) {
     result.addTypes(builder.getIndexType());
-    result.addAttribute(finalValueAttrName(result.name), builder.getUnitAttr());
+    result.addAttribute(getFinalValueAttrName(), builder.getUnitAttr());
   }
   result.addTypes(iterate.getType());
   result.addOperands(iterArgs);
@@ -841,7 +841,7 @@ static mlir::ParseResult parseIterWhileOp(mlir::OpAsmParser &parser,
     return mlir::failure();
 
   // Parse the initial iteration arguments.
-  llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> regionArgs;
+  llvm::SmallVector<mlir::OpAsmParser::OperandType> regionArgs;
   auto prependCount = false;
 
   // Induction variable.
@@ -849,8 +849,8 @@ static mlir::ParseResult parseIterWhileOp(mlir::OpAsmParser &parser,
   regionArgs.push_back(iterateVar);
 
   if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
-    llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> operands;
-    llvm::SmallVector<mlir::Type, 4> regionTypes;
+    llvm::SmallVector<mlir::OpAsmParser::OperandType> operands;
+    llvm::SmallVector<mlir::Type> regionTypes;
     // Parse assignment list and results type list.
     if (parser.parseAssignmentList(regionArgs, operands) ||
         parser.parseArrowTypeList(regionTypes))
@@ -860,9 +860,9 @@ static mlir::ParseResult parseIterWhileOp(mlir::OpAsmParser &parser,
     llvm::ArrayRef<mlir::Type> resTypes = regionTypes;
     resTypes = prependCount ? resTypes.drop_front(2) : resTypes;
     // Resolve input operands.
-    for (auto operand_type : llvm::zip(operands, resTypes))
-      if (parser.resolveOperand(std::get<0>(operand_type),
-                                std::get<1>(operand_type), result.operands))
+    for (auto operandType : llvm::zip(operands, resTypes))
+      if (parser.resolveOperand(std::get<0>(operandType),
+                                std::get<1>(operandType), result.operands))
         return failure();
     if (prependCount) {
       result.addTypes(regionTypes);
@@ -871,7 +871,7 @@ static mlir::ParseResult parseIterWhileOp(mlir::OpAsmParser &parser,
       result.addTypes(resTypes);
     }
   } else if (succeeded(parser.parseOptionalArrow())) {
-    llvm::SmallVector<mlir::Type, 4> typeList;
+    llvm::SmallVector<mlir::Type> typeList;
     if (parser.parseLParen() || parser.parseTypeList(typeList) ||
         parser.parseRParen())
       return failure();
@@ -888,10 +888,10 @@ static mlir::ParseResult parseIterWhileOp(mlir::OpAsmParser &parser,
   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
     return mlir::failure();
 
-  llvm::SmallVector<mlir::Type, 4> argTypes;
+  llvm::SmallVector<mlir::Type> argTypes;
   // Induction variable (hidden)
   if (prependCount)
-    result.addAttribute(IterWhileOp::finalValueAttrName(result.name),
+    result.addAttribute(IterWhileOp::getFinalValueAttrName(),
                         builder.getUnitAttr());
   else
     argTypes.push_back(indexType);
@@ -984,7 +984,8 @@ static void print(mlir::OpAsmPrinter &p, fir::IterWhileOp op) {
   } else if (op.finalValue()) {
     p << " -> (" << op.getResultTypes() << ')';
   }
-  p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"finalValue"});
+  p.printOptionalAttrDictWithKeyword(op->getAttrs(),
+                                     {IterWhileOp::getFinalValueAttrName()});
   p.printRegion(op.region(), /*printEntryBlockArgs=*/false,
                 /*printBlockTerminators=*/true);
 }
@@ -997,7 +998,7 @@ bool fir::IterWhileOp::isDefinedOutsideOfLoop(mlir::Value value) {
 
 mlir::LogicalResult
 fir::IterWhileOp::moveOutOfLoop(llvm::ArrayRef<mlir::Operation *> ops) {
-  for (auto op : ops)
+  for (auto *op : ops)
     op->moveBefore(*this);
   return success();
 }