[mlir] Use splitBlock instread of createBlock in GenericAtomicRMWLowering.
authorAlexander Belyaev <pifon@google.com>
Mon, 13 Mar 2023 07:35:21 +0000 (08:35 +0100)
committerAlexander Belyaev <pifon@google.com>
Mon, 13 Mar 2023 17:14:04 +0000 (18:14 +0100)
When generic_atomic_rmw is inside of memref.alloca_scope, then the pattern would fail.

Differential Revision: https://reviews.llvm.org/D145901

mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir

index 7da7b66ab1f350c3aba6213b6ebd3c376574c7c7..2af5a2522566d2bb2772518367e8f4ef38c9bc7a 100644 (file)
@@ -580,15 +580,11 @@ struct GenericAtomicRMWOpLowering
 
     // Split the block into initial, loop, and ending parts.
     auto *initBlock = rewriter.getInsertionBlock();
-    auto *loopBlock = rewriter.createBlock(
-        initBlock->getParent(), std::next(Region::iterator(initBlock)),
-        valueType, loc);
-    auto *endBlock = rewriter.createBlock(
-        loopBlock->getParent(), std::next(Region::iterator(loopBlock)));
+    auto *loopBlock = rewriter.splitBlock(initBlock, Block::iterator(atomicOp));
+    loopBlock->addArgument(valueType, loc);
 
-    // Operations range to be moved to `endBlock`.
-    auto opsToMoveStart = atomicOp->getIterator();
-    auto opsToMoveEnd = initBlock->back().getIterator();
+    auto *endBlock =
+        rewriter.splitBlock(loopBlock, Block::iterator(atomicOp)++);
 
     // Compute the loaded value and branch to the loop block.
     rewriter.setInsertionPointToEnd(initBlock);
@@ -628,30 +624,12 @@ struct GenericAtomicRMWOpLowering
                                     loopBlock, newLoaded);
 
     rewriter.setInsertionPointToEnd(endBlock);
-    moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
-                 std::next(opsToMoveEnd), rewriter);
 
     // The 'result' of the atomic_rmw op is the newly loaded value.
     rewriter.replaceOp(atomicOp, {newLoaded});
 
     return success();
   }
-
-private:
-  // Clones a segment of ops [start, end) and erases the original.
-  void moveOpsRange(ValueRange oldResult, ValueRange newResult,
-                    Block::iterator start, Block::iterator end,
-                    ConversionPatternRewriter &rewriter) const {
-    IRMapping mapping;
-    mapping.map(oldResult, newResult);
-    SmallVector<Operation *, 2> opsToErase;
-    for (auto it = start; it != end; ++it) {
-      rewriter.clone(*it, mapping);
-      opsToErase.push_back(&*it);
-    }
-    for (auto *it : opsToErase)
-      rewriter.eraseOp(it);
-  }
 };
 
 /// Returns the LLVM type of the global variable given the memref type `type`.
index f6dc44cf4571fdf800eee4feccd19141e8193b57..4b4b3836a007580cfbeb31e1c5ed3dc7f2d75615 100644 (file)
@@ -362,16 +362,47 @@ func.func @generic_atomic_rmw(%I : memref<10xi32>, %i : index) {
     ^bb0(%old_value : i32):
       memref.atomic_yield %old_value : i32
   }
-  // CHECK: [[init:%.*]] = llvm.load %{{.*}} : !llvm.ptr -> i32
-  // CHECK-NEXT: llvm.br ^bb1([[init]] : i32)
-  // CHECK-NEXT: ^bb1([[loaded:%.*]]: i32):
-  // CHECK-NEXT: [[pair:%.*]] = llvm.cmpxchg %{{.*}}, [[loaded]], [[loaded]]
-  // CHECK-SAME:                    acq_rel monotonic : !llvm.ptr, i32
-  // CHECK-NEXT: [[new:%.*]] = llvm.extractvalue [[pair]][0]
-  // CHECK-NEXT: [[ok:%.*]] = llvm.extractvalue [[pair]][1]
-  // CHECK-NEXT: llvm.cond_br [[ok]], ^bb2, ^bb1([[new]] : i32)
   llvm.return
 }
+// CHECK:        %[[INIT:.*]] = llvm.load %{{.*}} : !llvm.ptr -> i32
+// CHECK-NEXT:   llvm.br ^bb1(%[[INIT]] : i32)
+// CHECK-NEXT: ^bb1(%[[LOADED:.*]]: i32):
+// CHECK-NEXT:   %[[PAIR:.*]] = llvm.cmpxchg %{{.*}}, %[[LOADED]], %[[LOADED]]
+// CHECK-SAME:                      acq_rel monotonic : !llvm.ptr, i32
+// CHECK-NEXT:   %[[NEW:.*]] = llvm.extractvalue %[[PAIR]][0]
+// CHECK-NEXT:   %[[OK:.*]] = llvm.extractvalue %[[PAIR]][1]
+// CHECK-NEXT:   llvm.cond_br %[[OK]], ^bb2, ^bb1(%[[NEW]] : i32)
+
+// -----
+
+// CHECK-LABEL: func @generic_atomic_rmw_in_alloca_scope
+func.func @generic_atomic_rmw_in_alloca_scope(){
+  %c1 = arith.constant 1 : index
+  %alloc = memref.alloc() : memref<2x3xi32>
+  memref.alloca_scope  {
+    %0 = memref.generic_atomic_rmw %alloc[%c1, %c1] : memref<2x3xi32> {
+    ^bb0(%arg0: i32):
+      memref.atomic_yield %arg0 : i32
+    }
+  }
+  return
+}
+// CHECK:        %[[STACK_SAVE:.*]] = llvm.intr.stacksave : !llvm.ptr
+// CHECK-NEXT:   llvm.br ^bb1
+// CHECK:      ^bb1:
+// CHECK:        %[[INIT:.*]] = llvm.load %[[BUF:.*]] : !llvm.ptr -> i32
+// CHECK-NEXT:   llvm.br ^bb2(%[[INIT]] : i32)
+// CHECK-NEXT: ^bb2(%[[LOADED:.*]]: i32):
+// CHECK-NEXT:   %[[PAIR:.*]] = llvm.cmpxchg %[[BUF]], %[[LOADED]], %[[LOADED]]
+// CHECK-SAME:     acq_rel monotonic : !llvm.ptr, i32
+// CHECK-NEXT:   %[[NEW:.*]] = llvm.extractvalue %[[PAIR]][0]
+// CHECK-NEXT:   %[[OK:.*]] = llvm.extractvalue %[[PAIR]][1]
+// CHECK-NEXT:   llvm.cond_br %[[OK]], ^bb3, ^bb2(%[[NEW]] : i32)
+// CHECK-NEXT: ^bb3:
+// CHECK-NEXT:   llvm.intr.stackrestore %[[STACK_SAVE]] : !llvm.ptr
+// CHECK-NEXT:   llvm.br ^bb4
+// CHECK-NEXT: ^bb4:
+// CHECK-NEXT:   return
 
 // -----