[mlir][vector] let transfer_read and transfer_write take non-zero addrspace.
authorWen-Heng (Jack) Chung <whchung@gmail.com>
Wed, 29 Apr 2020 15:06:44 +0000 (17:06 +0200)
committerAlex Zinenko <zinenko@google.com>
Wed, 29 Apr 2020 15:11:48 +0000 (17:11 +0200)
Enhance lowering logic and tests so vector.transfer_read and
vector.transfer_write take memrefs on non-zero addrspaces.

Differential Revision: https://reviews.llvm.org/D79023

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

index 6ed29f1..341ccfa 100644 (file)
@@ -803,7 +803,7 @@ bool isMinorIdentity(AffineMap map, unsigned rank) {
 
 /// Conversion pattern that converts a 1-D vector transfer read/write op in a
 /// sequence of:
-/// 1. Bitcast to vector form.
+/// 1. Bitcast or addrspacecast to vector form.
 /// 2. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
 /// 3. Create a mask where offsetVector is compared against memref upper bound.
 /// 4. Rewrite op as a masked read or write.
@@ -835,13 +835,21 @@ public:
     MemRefType memRefType = xferOp.getMemRefType();
 
     // 1. Get the source/dst address as an LLVM vector pointer.
+    //    The vector pointer would always be on address space 0, therefore
+    //    addrspacecast shall be used when source/dst memrefs are not on
+    //    address space 0.
     // TODO: support alignment when possible.
     Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
                                adaptor.indices(), rewriter, getModule());
     auto vecTy =
         toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
-    auto vectorDataPtr =
-        rewriter.create<LLVM::BitcastOp>(loc, vecTy.getPointerTo(), dataPtr);
+    Value vectorDataPtr;
+    if (memRefType.getMemorySpace() == 0)
+      vectorDataPtr =
+          rewriter.create<LLVM::BitcastOp>(loc, vecTy.getPointerTo(), dataPtr);
+    else
+      vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
+          loc, vecTy.getPointerTo(), dataPtr);
 
     // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
     unsigned vecWidth = vecTy.getVectorNumElements();
index bd0e4ec..ac454bf 100644 (file)
@@ -864,3 +864,32 @@ func @transfer_read_2d_to_1d(%A : memref<?x?xf32>, %base0: index, %base1: index)
 //  CHECK-SAME:  0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32, 0 : i32,
 //  CHECK-SAME:  0 : i32, 0 : i32, 0 : i32] :
 //  CHECK-SAME: !llvm<"<17 x i64>">, !llvm<"<17 x i64>">
+
+func @transfer_read_1d_non_zero_addrspace(%A : memref<?xf32, 3>, %base: index) -> vector<17xf32> {
+  %f7 = constant 7.0: f32
+  %f = vector.transfer_read %A[%base], %f7
+      {permutation_map = affine_map<(d0) -> (d0)>} :
+    memref<?xf32, 3>, vector<17xf32>
+  vector.transfer_write %f, %A[%base]
+      {permutation_map = affine_map<(d0) -> (d0)>} :
+    vector<17xf32>, memref<?xf32, 3>
+  return %f: vector<17xf32>
+}
+// CHECK-LABEL: func @transfer_read_1d_non_zero_addrspace
+//  CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: !llvm.i64) -> !llvm<"<17 x float>">
+//
+// 1. Check address space for GEP is correct.
+//       CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
+//  CHECK-SAME: (!llvm<"float addrspace(3)*">, !llvm.i64) -> !llvm<"float addrspace(3)*">
+//       CHECK: %[[vecPtr:.*]] = llvm.addrspacecast %[[gep]] :
+//  CHECK-SAME: !llvm<"float addrspace(3)*"> to !llvm<"<17 x float>*">
+//
+// 2. Check address space of the memref is correct.
+//       CHECK: %[[DIM:.*]] = llvm.extractvalue %{{.*}}[3, 0] :
+//  CHECK-SAME: !llvm<"{ float addrspace(3)*, float addrspace(3)*, i64, [1 x i64], [1 x i64] }">
+//
+// 3. Check address apce for GEP is correct.
+//       CHECK: %[[gep_b:.*]] = llvm.getelementptr {{.*}} :
+//  CHECK-SAME: (!llvm<"float addrspace(3)*">, !llvm.i64) -> !llvm<"float addrspace(3)*">
+//       CHECK: %[[vecPtr_b:.*]] = llvm.addrspacecast %[[gep_b]] :
+//  CHECK-SAME: !llvm<"float addrspace(3)*"> to !llvm<"<17 x float>*">