[mlir][MemRef] Fix MemRefCopyOpLowering to use correct number of bytes
authorAdrian Kuegel <akuegel@google.com>
Fri, 11 Feb 2022 11:53:47 +0000 (12:53 +0100)
committerAdrian Kuegel <akuegel@google.com>
Fri, 11 Feb 2022 12:59:08 +0000 (13:59 +0100)
When lowering to memrefCopy call, the size for i1 type was calculated as 0.
Instead of using getTypeSizeInBits() and dividing by 8, we should just use getTypeSize().

Differential Revision: https://reviews.llvm.org/D119540

mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir

index 4507b10..a8910c2 100644 (file)
@@ -914,10 +914,10 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
     auto sourcePtr = promote(unrankedSource);
     auto targetPtr = promote(unrankedTarget);
 
-    unsigned bitwidth = mlir::DataLayout::closest(op).getTypeSizeInBits(
-        srcType.getElementType());
+    unsigned typeSize =
+        mlir::DataLayout::closest(op).getTypeSize(srcType.getElementType());
     auto elemSize = rewriter.create<LLVM::ConstantOp>(
-        loc, getIndexType(), rewriter.getIndexAttr(bitwidth / 8));
+        loc, getIndexType(), rewriter.getIndexAttr(typeSize));
     auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
         op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
     rewriter.create<LLVM::CallOp>(loc, copyFn,
index ee7d360..2fc6905 100644 (file)
@@ -911,3 +911,62 @@ func @generic_atomic_rmw(%I : memref<10xi32>, %i : index) {
   // CHECK-NEXT: llvm.cond_br [[ok]], ^bb2, ^bb1([[new]] : i32)
   llvm.return
 }
+
+// -----
+
+// CHECK-LABEL: func @memref_copy_ranked
+func @memref_copy_ranked() {
+  %0 = memref.alloc() : memref<2xf32>
+  // CHECK: llvm.mlir.constant(2 : index) : i64
+  // CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+  %1 = memref.cast %0 : memref<2xf32> to memref<?xf32>
+  %2 = memref.alloc() : memref<2xf32>
+  // CHECK: llvm.mlir.constant(2 : index) : i64
+  // CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+  %3 = memref.cast %2 : memref<2xf32> to memref<?xf32>
+  memref.copy %1, %3 : memref<?xf32> to memref<?xf32>
+  // CHECK: [[ONE:%.*]] = llvm.mlir.constant(1 : index) : i64
+  // CHECK: [[EXTRACT0:%.*]] = llvm.extractvalue {{%.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: [[MUL:%.*]] = llvm.mul [[ONE]], [[EXTRACT0]] : i64
+  // CHECK: [[NULL:%.*]] = llvm.mlir.null : !llvm.ptr<f32>
+  // CHECK: [[ONE2:%.*]] = llvm.mlir.constant(1 : index) : i64
+  // CHECK: [[GEP:%.*]] = llvm.getelementptr [[NULL]][[[ONE2]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+  // CHECK: [[PTRTOINT:%.*]] = llvm.ptrtoint [[GEP]] : !llvm.ptr<f32> to i64
+  // CHECK: [[SIZE:%.*]] = llvm.mul [[MUL]], [[PTRTOINT]] : i64
+  // CHECK: [[EXTRACT1:%.*]] = llvm.extractvalue {{%.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: [[EXTRACT2:%.*]] = llvm.extractvalue {{%.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: [[VOLATILE:%.*]] = llvm.mlir.constant(false) : i1
+  // CHECK: "llvm.intr.memcpy"([[EXTRACT2]], [[EXTRACT1]], [[SIZE]], [[VOLATILE]]) : (!llvm.ptr<f32>, !llvm.ptr<f32>, i64, i1) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_copy_unranked
+func @memref_copy_unranked() {
+  %0 = memref.alloc() : memref<2xi1>
+  // CHECK: llvm.mlir.constant(2 : index) : i64
+  // CHECK: llvm.mlir.undef : !llvm.struct<(ptr<i1>, ptr<i1>, i64, array<1 x i64>, array<1 x i64>)>
+  %1 = memref.cast %0 : memref<2xi1> to memref<*xi1>
+  %2 = memref.alloc() : memref<2xi1>
+  // CHECK: llvm.mlir.constant(2 : index) : i64
+  // CHECK: llvm.mlir.undef : !llvm.struct<(ptr<i1>, ptr<i1>, i64, array<1 x i64>, array<1 x i64>)>
+  %3 = memref.cast %2 : memref<2xi1> to memref<*xi1>
+  memref.copy %1, %3 : memref<*xi1> to memref<*xi1>
+  // CHECK: [[ONE:%.*]] = llvm.mlir.constant(1 : index) : i64
+  // CHECK: [[ALLOCA:%.*]] = llvm.alloca %35 x !llvm.struct<(ptr<i1>, ptr<i1>, i64, array<1 x i64>, array<1 x i64>)> : (i64) -> !llvm.ptr<struct<(ptr<i1>, ptr<i1>, i64, array<1 x i64>, array<1 x i64>)>>
+  // CHECK: llvm.store {{%.*}}, [[ALLOCA]] : !llvm.ptr<struct<(ptr<i1>, ptr<i1>, i64, array<1 x i64>, array<1 x i64>)>>
+  // CHECK: [[BITCAST:%.*]] = llvm.bitcast [[ALLOCA]] : !llvm.ptr<struct<(ptr<i1>, ptr<i1>, i64, array<1 x i64>, array<1 x i64>)>> to !llvm.ptr<i8>
+  // CHECK: [[RANK:%.*]] = llvm.mlir.constant(1 : index) : i64
+  // CHECK: [[UNDEF:%.*]] = llvm.mlir.undef : !llvm.struct<(i64, ptr<i8>)>
+  // CHECK: [[INSERT:%.*]] = llvm.insertvalue [[RANK]], [[UNDEF]][0] : !llvm.struct<(i64, ptr<i8>)>
+  // CHECK: [[INSERT2:%.*]] = llvm.insertvalue [[BITCAST]], [[INSERT]][1] : !llvm.struct<(i64, ptr<i8>)>
+  // CHECK: [[RANK2:%.*]] = llvm.mlir.constant(1 : index) : i64
+  // CHECK: [[ALLOCA2:%.*]] = llvm.alloca [[RANK2]] x !llvm.struct<(i64, ptr<i8>)> : (i64) -> !llvm.ptr<struct<(i64, ptr<i8>)>>
+  // CHECK: llvm.store {{%.*}}, [[ALLOCA2]] : !llvm.ptr<struct<(i64, ptr<i8>)>>
+  // CHECK: [[ALLOCA3:%.*]] = llvm.alloca [[RANK2]] x !llvm.struct<(i64, ptr<i8>)> : (i64) -> !llvm.ptr<struct<(i64, ptr<i8>)>>
+  // CHECK: llvm.store [[INSERT2]], [[ALLOCA3]] : !llvm.ptr<struct<(i64, ptr<i8>)>>
+  // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(1 : index) : i64
+  // CHECK: llvm.call @memrefCopy([[SIZE]], [[ALLOCA2]], [[ALLOCA3]]) : (i64, !llvm.ptr<struct<(i64, ptr<i8>)>>, !llvm.ptr<struct<(i64, ptr<i8>)>>) -> ()
+  return
+}