// Split the block into initial, loop, and ending parts.
auto *initBlock = rewriter.getInsertionBlock();
- auto initPosition = rewriter.getInsertionPoint();
- auto *loopBlock = rewriter.splitBlock(initBlock, initPosition);
- auto loopArgument = loopBlock->addArgument(valueType);
- auto loopPosition = rewriter.getInsertionPoint();
- auto *endBlock = rewriter.splitBlock(loopBlock, loopPosition);
+ auto *loopBlock =
+ rewriter.createBlock(initBlock->getParent(),
+ std::next(Region::iterator(initBlock)), valueType);
+ auto *endBlock = rewriter.createBlock(
+ loopBlock->getParent(), std::next(Region::iterator(loopBlock)));
+
+ // Operations range to be moved to `endBlock`.
+ auto opsToMoveStart = atomicOp.getOperation()->getIterator();
+ auto opsToMoveEnd = initBlock->back().getIterator();
// Compute the loaded value and branch to the loop block.
rewriter.setInsertionPointToEnd(initBlock);
// Prepare the body of the loop block.
rewriter.setInsertionPointToStart(loopBlock);
- auto boolType = LLVM::LLVMType::getInt1Ty(&getDialect());
// Clone the GenericAtomicRMWOp region and extract the result.
+ auto loopArgument = loopBlock->getArgument(0);
BlockAndValueMapping mapping;
mapping.map(atomicOp.getCurrentValue(), loopArgument);
Block &entryBlock = atomicOp.body().front();
Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));
// Prepare the epilog of the loop block.
- rewriter.setInsertionPointToEnd(loopBlock);
// Append the cmpxchg op to the end of the loop block.
auto successOrdering = LLVM::AtomicOrdering::acq_rel;
auto failureOrdering = LLVM::AtomicOrdering::monotonic;
+ auto boolType = LLVM::LLVMType::getInt1Ty(&getDialect());
auto pairType = LLVM::LLVMType::getStructTy(valueType, boolType);
auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
loc, pairType, dataPtr, loopArgument, result, successOrdering,
rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
loopBlock, newLoaded);
+ rewriter.setInsertionPointToEnd(endBlock);
+ MoveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
+ std::next(opsToMoveEnd), rewriter);
+
// The 'result' of the atomic_rmw op is the newly loaded value.
rewriter.replaceOp(op, {newLoaded});
return success();
}
+
+private:
+ // Relocates the operations in the half-open range [start, end).
+ //
+ // Each operation in the range is cloned through `rewriter`, using a value
+ // mapping in which `oldResult` is mapped to `newResult`. The originals are
+ // only collected while the range is being walked; they are erased in a
+ // second pass, after every clone has been created.
+ void MoveOpsRange(ValueRange oldResult, ValueRange newResult,
+ Block::iterator start, Block::iterator end,
+ ConversionPatternRewriter &rewriter) const {
+ BlockAndValueMapping valueRemap;
+ valueRemap.map(oldResult, newResult);
+
+ // First pass: clone every op and remember the original for later removal.
+ SmallVector<Operation *, 2> originals;
+ Block::iterator cursor = start;
+ while (cursor != end) {
+ Operation *current = &*cursor;
+ rewriter.clone(*current, valueRemap);
+ originals.push_back(current);
+ ++cursor;
+ }
+
+ // Second pass: drop the originals now that all clones exist.
+ for (Operation *stale : originals)
+ rewriter.eraseOp(stale);
+ }
};
} // namespace
// -----
// CHECK-LABEL: func @generic_atomic_rmw
-// CHECK32-LABEL: func @generic_atomic_rmw
func @generic_atomic_rmw(%I : memref<10xf32>, %i : index) -> f32 {
%x = generic_atomic_rmw %I[%i] : memref<10xf32> {
^bb0(%old_value : f32):
// CHECK-NEXT: [[ok:%.*]] = llvm.extractvalue [[pair]][1]
// CHECK-NEXT: llvm.cond_br [[ok]], ^bb2, ^bb1([[new]] : !llvm.float)
// CHECK-NEXT: ^bb2:
- return %x : f32
- // CHECK-NEXT: llvm.return [[new]]
+ %c2 = constant 2.0 : f32
+ %add = addf %c2, %x : f32
+ return %add : f32
+ // CHECK-NEXT: [[c2:%.*]] = llvm.mlir.constant(2.000000e+00 : f32)
+ // CHECK-NEXT: [[add:%.*]] = llvm.fadd [[c2]], [[new]] : !llvm.float
+ // CHECK-NEXT: llvm.return [[add]]
}
// -----