[mlir][memref] Implement fast lowering of memref.copy
authorStephan Herhut <herhut@google.com>
Fri, 7 Jan 2022 09:00:19 +0000 (10:00 +0100)
committerStephan Herhut <herhut@google.com>
Fri, 14 Jan 2022 13:22:15 +0000 (14:22 +0100)
In the absence of maps, we can lower memref.copy to a memcpy.

Differential Revision: https://reviews.llvm.org/D116099

mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/test/mlir-cpu-runner/copy.mlir

index 3e3adb7..0bc6eb9 100644 (file)
@@ -706,12 +706,52 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
   }
 };
 
+/// Pattern to lower a `memref.copy` to llvm.
+///
+/// For memrefs with identity layouts, the copy is lowered to the llvm
+/// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
+/// to the generic `MemrefCopyFn`.
 struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
   using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
+  lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
+                          ConversionPatternRewriter &rewriter) const {
+    auto loc = op.getLoc();
+    auto srcType = op.source().getType().dyn_cast<MemRefType>();
+
+    MemRefDescriptor srcDesc(adaptor.source());
+
+    // Compute number of elements.
+    Value numElements;
+    for (int pos = 0; pos < srcType.getRank(); ++pos) {
+      auto size = srcDesc.size(rewriter, loc, pos);
+      numElements = numElements
+                        ? rewriter.create<LLVM::MulOp>(loc, numElements, size)
+                        : size;
+    }
+    // Get element size.
+    auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
+    // Compute total.
+    Value totalSize =
+        rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);
+
+    Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
+    MemRefDescriptor targetDesc(adaptor.target());
+    Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
+    Value isVolatile = rewriter.create<LLVM::ConstantOp>(
+        loc, typeConverter->convertType(rewriter.getI1Type()),
+        rewriter.getBoolAttr(false));
+    rewriter.create<LLVM::MemcpyOp>(loc, targetBasePtr, srcBasePtr, totalSize,
+                                    isVolatile);
+    rewriter.eraseOp(op);
+
+    return success();
+  }
+
+  LogicalResult
+  lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
+                             ConversionPatternRewriter &rewriter) const {
     auto loc = op.getLoc();
     auto srcType = op.source().getType().cast<BaseMemRefType>();
     auto targetType = op.target().getType().cast<BaseMemRefType>();
@@ -765,6 +805,21 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
 
     return success();
   }
+
+  LogicalResult
+  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto srcType = op.source().getType().cast<BaseMemRefType>();
+    auto targetType = op.target().getType().cast<BaseMemRefType>();
+
+    if (srcType.hasRank() &&
+        srcType.cast<MemRefType>().getLayout().isIdentity() &&
+        targetType.hasRank() &&
+        targetType.cast<MemRefType>().getLayout().isIdentity())
+      return lowerToMemCopyIntrinsic(op, adaptor, rewriter);
+
+    return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
+  }
 };
 
 /// Extracts allocated, aligned pointers and offset from a ranked or unranked
index e5a471f..8581f13 100644 (file)
@@ -35,7 +35,7 @@ func @main() -> () {
   // CHECK-NEXT: [3,   4,   5]
 
   %copy_two = memref.alloc() : memref<3x2xf32>
-  %copy_two_casted = memref.reinterpret_cast %copy_two to offset: [0], sizes: [2, 3], strides:[1, 2]
+  %copy_two_casted = memref.reinterpret_cast %copy_two to offset: [0], sizes: [2, 3], strides: [1, 2]
     : memref<3x2xf32> to memref<2x3xf32, offset: 0, strides: [1, 2]>
   memref.copy %input, %copy_two_casted : memref<2x3xf32> to memref<2x3xf32, offset: 0, strides: [1, 2]>
   %unranked_copy_two = memref.cast %copy_two : memref<3x2xf32> to memref<*xf32>
@@ -49,6 +49,13 @@ func @main() -> () {
   %copy_empty = memref.alloc() : memref<3x0x1xf32>
   // Copying an empty shape should do nothing (and should not crash).
   memref.copy %input_empty, %copy_empty : memref<3x0x1xf32> to memref<3x0x1xf32>
+
+  %input_empty_casted = memref.reinterpret_cast %input_empty to offset: [0], sizes: [0, 3, 1], strides: [3, 1, 1]
+    : memref<3x0x1xf32> to memref<0x3x1xf32, offset: 0, strides: [3, 1, 1]>
+  %copy_empty_casted = memref.alloc() : memref<0x3x1xf32>
+  // Copying a casted empty shape should do nothing (and should not crash).
+  memref.copy %input_empty_casted, %copy_empty_casted : memref<0x3x1xf32, offset: 0, strides: [3, 1, 1]> to memref<0x3x1xf32>
+
   memref.dealloc %copy_empty : memref<3x0x1xf32>
   memref.dealloc %input_empty : memref<3x0x1xf32>
   memref.dealloc %copy_two : memref<3x2xf32>