[MLIR] Replace splitBlock() with createBlock in GenericAtomicRMWOp lowering.
authorAlexander Belyaev <pifon@google.com>
Sat, 25 Apr 2020 15:35:49 +0000 (17:35 +0200)
committerAlexander Belyaev <pifon@google.com>
Sat, 25 Apr 2020 15:42:45 +0000 (17:42 +0200)
`addArgument()` is not undoable and should not be used in
ConversionPattern, therefore replacing `splitBlock()` with
`createBlock()`, that creates a block with specified args.

Differential Revision: https://reviews.llvm.org/D78731

mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir

index 74488cc..b60b2fa 100644 (file)
@@ -2791,11 +2791,15 @@ struct GenericAtomicRMWOpLowering
 
     // Split the block into initial, loop, and ending parts.
     auto *initBlock = rewriter.getInsertionBlock();
-    auto initPosition = rewriter.getInsertionPoint();
-    auto *loopBlock = rewriter.splitBlock(initBlock, initPosition);
-    auto loopArgument = loopBlock->addArgument(valueType);
-    auto loopPosition = rewriter.getInsertionPoint();
-    auto *endBlock = rewriter.splitBlock(loopBlock, loopPosition);
+    auto *loopBlock =
+        rewriter.createBlock(initBlock->getParent(),
+                             std::next(Region::iterator(initBlock)), valueType);
+    auto *endBlock = rewriter.createBlock(
+        loopBlock->getParent(), std::next(Region::iterator(loopBlock)));
+
+    // Operations range to be moved to `endBlock`.
+    auto opsToMoveStart = atomicOp.getOperation()->getIterator();
+    auto opsToMoveEnd = initBlock->back().getIterator();
 
     // Compute the loaded value and branch to the loop block.
     rewriter.setInsertionPointToEnd(initBlock);
@@ -2807,9 +2811,9 @@ struct GenericAtomicRMWOpLowering
 
     // Prepare the body of the loop block.
     rewriter.setInsertionPointToStart(loopBlock);
-    auto boolType = LLVM::LLVMType::getInt1Ty(&getDialect());
 
     // Clone the GenericAtomicRMWOp region and extract the result.
+    auto loopArgument = loopBlock->getArgument(0);
     BlockAndValueMapping mapping;
     mapping.map(atomicOp.getCurrentValue(), loopArgument);
     Block &entryBlock = atomicOp.body().front();
@@ -2820,10 +2824,10 @@ struct GenericAtomicRMWOpLowering
     Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));
 
     // Prepare the epilog of the loop block.
-    rewriter.setInsertionPointToEnd(loopBlock);
     // Append the cmpxchg op to the end of the loop block.
     auto successOrdering = LLVM::AtomicOrdering::acq_rel;
     auto failureOrdering = LLVM::AtomicOrdering::monotonic;
+    auto boolType = LLVM::LLVMType::getInt1Ty(&getDialect());
     auto pairType = LLVM::LLVMType::getStructTy(valueType, boolType);
     auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
         loc, pairType, dataPtr, loopArgument, result, successOrdering,
@@ -2838,11 +2842,31 @@ struct GenericAtomicRMWOpLowering
     rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
                                     loopBlock, newLoaded);
 
+    rewriter.setInsertionPointToEnd(endBlock);
+    MoveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
+                 std::next(opsToMoveEnd), rewriter);
+
     // The 'result' of the atomic_rmw op is the newly loaded value.
     rewriter.replaceOp(op, {newLoaded});
 
     return success();
   }
+
+private:
+  // Clones a segment of ops [start, end) and erases the original.
+  void MoveOpsRange(ValueRange oldResult, ValueRange newResult,
+                    Block::iterator start, Block::iterator end,
+                    ConversionPatternRewriter &rewriter) const {
+    BlockAndValueMapping mapping;
+    mapping.map(oldResult, newResult);
+    SmallVector<Operation *, 2> opsToErase;
+    for (auto it = start; it != end; ++it) {
+      rewriter.clone(*it, mapping);
+      opsToErase.push_back(&*it);
+    }
+    for (auto *it : opsToErase)
+      rewriter.eraseOp(it);
+  }
 };
 
 } // namespace
index 2b6038d..330165b 100644 (file)
@@ -1030,7 +1030,6 @@ func @cmpxchg(%F : memref<10xf32>, %fval : f32, %i : index) -> f32 {
 // -----
 
 // CHECK-LABEL: func @generic_atomic_rmw
-// CHECK32-LABEL: func @generic_atomic_rmw
 func @generic_atomic_rmw(%I : memref<10xf32>, %i : index) -> f32 {
   %x = generic_atomic_rmw %I[%i] : memref<10xf32> {
     ^bb0(%old_value : f32):
@@ -1047,8 +1046,12 @@ func @generic_atomic_rmw(%I : memref<10xf32>, %i : index) -> f32 {
   // CHECK-NEXT: [[ok:%.*]] = llvm.extractvalue [[pair]][1]
   // CHECK-NEXT: llvm.cond_br [[ok]], ^bb2, ^bb1([[new]] : !llvm.float)
   // CHECK-NEXT: ^bb2:
-  return %x : f32
-  // CHECK-NEXT: llvm.return [[new]]
+  %c2 = constant 2.0 : f32
+  %add = addf %c2, %x : f32
+  return %add : f32
+  // CHECK-NEXT: [[c2:%.*]] = llvm.mlir.constant(2.000000e+00 : f32)
+  // CHECK-NEXT: [[add:%.*]] = llvm.fadd [[c2]], [[new]] : !llvm.float
+  // CHECK-NEXT: llvm.return [[add]]
 }
 
 // -----