[mlir][spirv] Allow bitwidth emulation on runtime arrays
authorLei Zhang <antiagainst@google.com>
Mon, 12 Apr 2021 20:50:24 +0000 (16:50 -0400)
committerLei Zhang <antiagainst@google.com>
Mon, 12 Apr 2021 21:04:18 +0000 (17:04 -0400)
Runtime arrays are converted from memrefs with unknown
dimensions.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D100335

mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir

index 397d26b..7dea6e8 100644 (file)
@@ -994,13 +994,16 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp,
   bool isBool = srcBits == 1;
   if (isBool)
     srcBits = typeConverter.getOptions().boolNumBits;
-  auto dstType = typeConverter.convertType(memrefType)
-                     .cast<spirv::PointerType>()
-                     .getPointeeType()
-                     .cast<spirv::StructType>()
-                     .getElementType(0)
-                     .cast<spirv::ArrayType>()
-                     .getElementType();
+  Type pointeeType = typeConverter.convertType(memrefType)
+                         .cast<spirv::PointerType>()
+                         .getPointeeType();
+  Type structElemType = pointeeType.cast<spirv::StructType>().getElementType(0);
+  Type dstType;
+  if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>())
+    dstType = arrayType.getElementType();
+  else
+    dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType();
+
   int dstBits = dstType.getIntOrFloatBitWidth();
   assert(dstBits % srcBits == 0);
 
@@ -1136,13 +1139,16 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
   bool isBool = srcBits == 1;
   if (isBool)
     srcBits = typeConverter.getOptions().boolNumBits;
-  auto dstType = typeConverter.convertType(memrefType)
-                     .cast<spirv::PointerType>()
-                     .getPointeeType()
-                     .cast<spirv::StructType>()
-                     .getElementType(0)
-                     .cast<spirv::ArrayType>()
-                     .getElementType();
+  Type pointeeType = typeConverter.convertType(memrefType)
+                         .cast<spirv::PointerType>()
+                         .getPointeeType();
+  Type structElemType = pointeeType.cast<spirv::StructType>().getElementType(0);
+  Type dstType;
+  if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>())
+    dstType = arrayType.getElementType();
+  else
+    dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType();
+
   int dstBits = dstType.getIntOrFloatBitWidth();
   assert(dstBits % srcBits == 0);
 
index 86d390a..82157e0 100644 (file)
@@ -905,6 +905,19 @@ func @load_store_zero_rank_int(%arg0: memref<i32>, %arg1: memref<i32>) {
   return
 }
 
+// CHECK-LABEL: func @load_store_unknown_dim
+// CHECK-SAME: %[[SRC:[a-z0-9]+]]: !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>,
+// CHECK-SAME: %[[DST:[a-z0-9]+]]: !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>)
+func @load_store_unknown_dim(%i: index, %source: memref<?xi32>, %dest: memref<?xi32>) {
+  // CHECK: %[[AC0:.+]] = spv.AccessChain %[[SRC]]
+  // CHECK: spv.Load "StorageBuffer" %[[AC0]]
+  %0 = memref.load %source[%i] : memref<?xi32>
+  // CHECK: %[[AC1:.+]] = spv.AccessChain %[[DST]]
+  // CHECK: spv.Store "StorageBuffer" %[[AC1]]
+  memref.store %0, %dest[%i]: memref<?xi32>
+  return
+}
+
 } // end module
 
 // -----