#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
}
};
+/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
+/// retried until it succeeds in atomically storing a new value into memory.
+///
+/// +---------------------------------+
+/// | <code before the AtomicRMWOp> |
+/// | <compute initial %loaded> |
+/// | br loop(%loaded) |
+/// +---------------------------------+
+/// |
+/// -------| |
+/// | v v
+/// | +--------------------------------+
+/// | | loop(%loaded): |
+/// | | <body contents> |
+/// | | %pair = cmpxchg |
+/// | | %new = %pair[0] |
+/// | | %ok = %pair[1] |
+/// | | cond_br %ok, end, loop(%new) |
+/// | +--------------------------------+
+/// | | |
+/// |----------- |
+/// v
+/// +--------------------------------+
+/// | end: |
+/// | <code after the AtomicRMWOp> |
+/// +--------------------------------+
+///
+struct GenericAtomicRMWOpLowering
+    : public LoadStoreOpLowering<GenericAtomicRMWOp> {
+  using Base::Base;
+
+  /// Rewrites a GenericAtomicRMWOp into the CFG sketched above: load the
+  /// initial value in the block preceding the op, clone the op's body region
+  /// into a new loop block, and retry via llvm.cmpxchg until the exchange
+  /// succeeds. The op's result is replaced by the value loaded back from the
+  /// successful cmpxchg.
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto atomicOp = cast<GenericAtomicRMWOp>(op);
+
+    auto loc = op->getLoc();
+    OperandAdaptor<GenericAtomicRMWOp> adaptor(operands);
+    LLVM::LLVMType valueType =
+        typeConverter.convertType(atomicOp.getResult().getType())
+            .cast<LLVM::LLVMType>();
+
+    // Split the block into initial, loop, and ending parts.
+    auto *initBlock = rewriter.getInsertionBlock();
+    auto initPosition = rewriter.getInsertionPoint();
+    auto *loopBlock = rewriter.splitBlock(initBlock, initPosition);
+    // The loop block carries the most recently loaded value as its argument.
+    auto loopArgument = loopBlock->addArgument(valueType);
+    auto loopPosition = rewriter.getInsertionPoint();
+    auto *endBlock = rewriter.splitBlock(loopBlock, loopPosition);
+
+    // Compute the loaded value and branch to the loop block.
+    rewriter.setInsertionPointToEnd(initBlock);
+    auto memRefType = atomicOp.memref().getType().cast<MemRefType>();
+    auto dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
+                              adaptor.indices(), rewriter, getModule());
+    Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
+    rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
+
+    // Prepare the body of the loop block.
+    rewriter.setInsertionPointToStart(loopBlock);
+    auto boolType = LLVM::LLVMType::getInt1Ty(&getDialect());
+
+    // Clone the GenericAtomicRMWOp region and extract the result.
+    BlockAndValueMapping mapping;
+    // The body's block argument (the "current value") is remapped to read the
+    // value carried by the loop block argument.
+    mapping.map(atomicOp.getCurrentValue(), loopArgument);
+    Block &entryBlock = atomicOp.body().front();
+    for (auto &nestedOp : entryBlock.without_terminator()) {
+      Operation *clone = rewriter.clone(nestedOp, mapping);
+      mapping.map(nestedOp.getResults(), clone->getResults());
+    }
+    // The value yielded by the body's terminator, after remapping, is the
+    // candidate to store back into memory.
+    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));
+
+    // Prepare the epilog of the loop block.
+    rewriter.setInsertionPointToEnd(loopBlock);
+    // Append the cmpxchg op to the end of the loop block.
+    // Orderings are fixed: acq_rel on success, monotonic (relaxed) when the
+    // compare fails and the loop retries.
+    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
+    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
+    auto pairType = LLVM::LLVMType::getStructTy(valueType, boolType);
+    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
+        loc, pairType, dataPtr, loopArgument, result, successOrdering,
+        failureOrdering);
+    // Extract the %new_loaded (element 0) and %ok (element 1) values from the
+    // {valueType, i1} pair produced by cmpxchg.
+    Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
+        loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
+    Value ok = rewriter.create<LLVM::ExtractValueOp>(
+        loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));
+
+    // Conditionally branch to the end or back to the loop depending on %ok.
+    rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
+                                    loopBlock, newLoaded);
+
+    // The 'result' of the atomic_rmw op is the newly loaded value.
+    rewriter.replaceOp(op, {newLoaded});
+
+    return success();
+  }
+};
+
} // namespace
/// Collect a set of patterns to convert from the Standard dialect to LLVM.
DivFOpLowering,
ExpOpLowering,
Exp2OpLowering,
+ GenericAtomicRMWOpLowering,
LogOpLowering,
Log10OpLowering,
Log2OpLowering,
// -----
+// Lowering of generic_atomic_rmw: the op's body (yielding the constant 1.0)
+// is inlined into an llvm.cmpxchg retry loop, and the op's result is the
+// value loaded back by the successful exchange.
+// CHECK-LABEL: func @generic_atomic_rmw
+// CHECK32-LABEL: func @generic_atomic_rmw
+func @generic_atomic_rmw(%I : memref<10xf32>, %i : index) -> f32 {
+  %x = generic_atomic_rmw %I[%i] : memref<10xf32> {
+    ^bb0(%old_value : f32):
+      %c1 = constant 1.0 : f32
+      atomic_yield %c1 : f32
+  }
+  // CHECK: [[init:%.*]] = llvm.load %{{.*}} : !llvm<"float*">
+  // CHECK-NEXT: llvm.br ^bb1([[init]] : !llvm.float)
+  // CHECK-NEXT: ^bb1([[loaded:%.*]]: !llvm.float):
+  // CHECK-NEXT: [[c1:%.*]] = llvm.mlir.constant(1.000000e+00 : f32)
+  // CHECK-NEXT: [[pair:%.*]] = llvm.cmpxchg %{{.*}}, [[loaded]], [[c1]]
+  // CHECK-SAME: acq_rel monotonic : !llvm.float
+  // CHECK-NEXT: [[new:%.*]] = llvm.extractvalue [[pair]][0]
+  // CHECK-NEXT: [[ok:%.*]] = llvm.extractvalue [[pair]][1]
+  // CHECK-NEXT: llvm.cond_br [[ok]], ^bb2, ^bb1([[new]] : !llvm.float)
+  // CHECK-NEXT: ^bb2:
+  return %x : f32
+  // CHECK-NEXT: llvm.return [[new]]
+}
+
+// -----
+
// CHECK-LABEL: func @assume_alignment
func @assume_alignment(%0 : memref<4x4xf16>) {
// CHECK: %[[PTR:.*]] = llvm.extractvalue %[[MEMREF:.*]][1] : !llvm<"{ half*, half*, i64, [2 x i64], [2 x i64] }">