[mlir] Add `Block::eraseArguments` that erases a subrange
authorJeff Niu <jeff@modular.com>
Mon, 29 Aug 2022 21:32:14 +0000 (14:32 -0700)
committerJeff Niu <jeff@modular.com>
Mon, 29 Aug 2022 22:34:21 +0000 (15:34 -0700)
This patch adds a an `eraseArguments` function that erases a subrange of
a block's arguments. This can be used inplace of the terrible pattern

```
block->eraseArguments(llvm::to_vector(llvm::seq(...)));
```

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D132890

mlir/include/mlir/IR/Block.h
mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
mlir/lib/Dialect/Affine/Utils/Utils.cpp
mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
mlir/lib/IR/Block.cpp

index ee4ebcf..1e4e51c 100644 (file)
@@ -105,6 +105,8 @@ public:
 
   /// Erase the argument at 'index' and remove it from the argument list.
   void eraseArgument(unsigned index);
+  /// Erases 'num' arguments from the index 'start'.
+  void eraseArguments(unsigned start, unsigned num);
   /// Erases the arguments listed in `argIndices` and removes them from the
   /// argument list.
   /// `argIndices` is allowed to have duplicates and can be in any order.
index 46e751e..3b261ff 100644 (file)
@@ -35,8 +35,7 @@ static void inlineIfCase(Region &srcRegion, Region &dstRegion,
   rewriter.create<scf::YieldOp>(yield.getLoc(), yield.getInputs());
   rewriter.eraseOp(yield);
 
-  headBlock->eraseArguments(
-      llvm::to_vector<4>(llvm::seq<unsigned>(0, headBlock->getNumArguments())));
+  headBlock->eraseArguments(0, headBlock->getNumArguments());
 }
 
 static void inlineWhileCase(Region &srcRegion, Region &dstRegion,
index 66a0e36..703ff5b 100644 (file)
@@ -398,8 +398,7 @@ mlir::affineParallelize(AffineForOp forOp,
   // "main" induction variable whenc coming from a non-parallel for.
   unsigned numIVs = 1;
   yieldOp->setOperands(reducedValues);
-  newPloop.getBody()->eraseArguments(
-      llvm::to_vector<4>(llvm::seq<unsigned>(numIVs, numReductions + numIVs)));
+  newPloop.getBody()->eraseArguments(numIVs, numReductions);
 
   forOp.erase();
   return success();
index 355049d..9424305 100644 (file)
@@ -162,8 +162,7 @@ mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes,
       thenBlock.getArgument(ivs.index())
           .replaceAllUsesExcept(newIndex, newIndex);
     }
-    thenBlock.eraseArguments(llvm::to_vector<4>(
-        llvm::seq((unsigned)0, thenBlock.getNumArguments())));
+    thenBlock.eraseArguments(0, thenBlock.getNumArguments());
   } else {
     innerLoop.getRegion().takeBody(op.getRegion());
     b.setInsertionPointToStart(innerLoop.getBody());
index 18a79d2..cc84b0d 100644 (file)
@@ -186,6 +186,15 @@ void Block::eraseArgument(unsigned index) {
     arg.setArgNumber(index++);
 }
 
+void Block::eraseArguments(unsigned start, unsigned num) {
+  assert(start + num <= arguments.size());
+  for (unsigned i = 0; i < num; ++i)
+    arguments[start + i].destroy();
+  arguments.erase(arguments.begin() + start, arguments.begin() + start + num);
+  for (BlockArgument arg : llvm::drop_begin(arguments, start))
+    arg.setArgNumber(start++);
+}
+
 void Block::eraseArguments(ArrayRef<unsigned> argIndices) {
   BitVector eraseIndices(getNumArguments());
   for (unsigned i : argIndices)