[mlir][spirv] Support OpenCL when lowering memref load/store
authorStanley Winata <stanley@nod-labs.com>
Mon, 19 Sep 2022 17:14:58 +0000 (13:14 -0400)
committerLei Zhang <antiagainst@google.com>
Mon, 19 Sep 2022 17:24:21 +0000 (13:24 -0400)
-Add awareness to Kernel vs Shader capability for memref to SPIR-V
 lowering.
-Add lowering using spv.PtrAccessChain for Kernel capability.
-Enable lowering from scalar pointee types for kernel capabilities.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D132714

mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
mlir/test/Conversion/GPUToSPIRV/module-opencl.mlir
mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir

index ff1cdac..9b480f6 100644 (file)
@@ -14,6 +14,7 @@
 #define MLIR_DIALECT_SPIRV_TRANSFORMS_SPIRVCONVERSION_H
 
 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -72,6 +73,9 @@ public:
   /// Returns the options controlling the SPIR-V type converter.
   const SPIRVConversionOptions &getOptions() const { return options; }
 
+  /// Checks if the SPIR-V capability inquired is supported.
+  bool allows(spirv::Capability capability);
+
 private:
   spirv::TargetEnv targetEnv;
   SPIRVConversionOptions options;
@@ -151,10 +155,19 @@ Value linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
 
 // TODO: This method assumes that the `baseType` is a MemRefType with AffineMap
 // that has static strides. Extend to handle dynamic strides.
-spirv::AccessChainOp getElementPtr(SPIRVTypeConverter &typeConverter,
-                                   MemRefType baseType, Value basePtr,
-                                   ValueRange indices, Location loc,
-                                   OpBuilder &builder);
+Value getElementPtr(SPIRVTypeConverter &typeConverter, MemRefType baseType,
+                    Value basePtr, ValueRange indices, Location loc,
+                    OpBuilder &builder);
+
+// GetElementPtr implementation for Kernel/OpenCL flavored SPIR-V.
+Value getOpenCLElementPtr(SPIRVTypeConverter &typeConverter,
+                          MemRefType baseType, Value basePtr,
+                          ValueRange indices, Location loc, OpBuilder &builder);
+
+// GetElementPtr implementation for Vulkan/Shader flavored SPIR-V.
+Value getVulkanElementPtr(SPIRVTypeConverter &typeConverter,
+                          MemRefType baseType, Value basePtr,
+                          ValueRange indices, Location loc, OpBuilder &builder);
 
 } // namespace spirv
 } // namespace mlir
index d72802b..766d42b 100644 (file)
@@ -192,7 +192,7 @@ public:
                   ConversionPatternRewriter &rewriter) const override;
 };
 
-/// Converts memref.load to spv.Load.
+/// Converts memref.load to spv.Load + spv.AccessChain on integers.
 class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
 public:
   using OpConversionPattern<memref::LoadOp>::OpConversionPattern;
@@ -202,7 +202,7 @@ public:
                   ConversionPatternRewriter &rewriter) const override;
 };
 
-/// Converts memref.load to spv.Load.
+/// Converts memref.load to spv.Load + spv.AccessChain.
 class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
 public:
   using OpConversionPattern<memref::LoadOp>::OpConversionPattern;
@@ -319,11 +319,11 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
     return failure();
 
   auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
-  spirv::AccessChainOp accessChainOp =
+  Value accessChain =
       spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                            adaptor.getIndices(), loc, rewriter);
 
-  if (!accessChainOp)
+  if (!accessChain)
     return failure();
 
   int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
@@ -333,27 +333,41 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   Type pointeeType = typeConverter.convertType(memrefType)
                          .cast<spirv::PointerType>()
                          .getPointeeType();
-  Type structElemType = pointeeType.cast<spirv::StructType>().getElementType(0);
   Type dstType;
-  if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>())
-    dstType = arrayType.getElementType();
-  else
-    dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType();
-
+  if (typeConverter.allows(spirv::Capability::Kernel)) {
+    // For OpenCL Kernel, pointer will be directly pointing to the element.
+    dstType = pointeeType;
+  } else {
+    // For Vulkan we need to extract element from wrapping struct and array.
+    Type structElemType =
+        pointeeType.cast<spirv::StructType>().getElementType(0);
+    if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>())
+      dstType = arrayType.getElementType();
+    else
+      dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType();
+  }
   int dstBits = dstType.getIntOrFloatBitWidth();
   assert(dstBits % srcBits == 0);
 
   // If the rewrited load op has the same bit width, use the loading value
   // directly.
   if (srcBits == dstBits) {
-    Value loadVal =
-        rewriter.create<spirv::LoadOp>(loc, accessChainOp.getResult());
+    Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain);
     if (isBool)
       loadVal = castIntNToBool(loc, loadVal, rewriter);
     rewriter.replaceOp(loadOp, loadVal);
     return success();
   }
 
+  // Bitcasting is currently unsupported for Kernel capability /
+  // spv.PtrAccessChain.
+  if (typeConverter.allows(spirv::Capability::Kernel))
+    return failure();
+
+  auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
+  if (!accessChainOp)
+    return failure();
+
   // Assume that getElementPtr() works linearizely. If it's a scalar, the method
   // still returns a linearized accessing. If the accessing is not linearized,
   // there will be offset issues.
@@ -432,11 +446,11 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
 
   auto loc = storeOp.getLoc();
   auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
-  spirv::AccessChainOp accessChainOp =
+  Value accessChain =
       spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                            adaptor.getIndices(), loc, rewriter);
 
-  if (!accessChainOp)
+  if (!accessChain)
     return failure();
 
   int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
@@ -448,12 +462,19 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
   Type pointeeType = typeConverter.convertType(memrefType)
                          .cast<spirv::PointerType>()
                          .getPointeeType();
-  Type structElemType = pointeeType.cast<spirv::StructType>().getElementType(0);
   Type dstType;
-  if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>())
-    dstType = arrayType.getElementType();
-  else
-    dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType();
+  if (typeConverter.allows(spirv::Capability::Kernel)) {
+    // For OpenCL Kernel, pointer will be directly pointing to the element.
+    dstType = pointeeType;
+  } else {
+    // For Vulkan we need to extract element from wrapping struct and array.
+    Type structElemType =
+        pointeeType.cast<spirv::StructType>().getElementType(0);
+    if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>())
+      dstType = arrayType.getElementType();
+    else
+      dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType();
+  }
 
   int dstBits = dstType.getIntOrFloatBitWidth();
   assert(dstBits % srcBits == 0);
@@ -462,11 +483,19 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
     Value storeVal = adaptor.getValue();
     if (isBool)
       storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
-    rewriter.replaceOpWithNewOp<spirv::StoreOp>(
-        storeOp, accessChainOp.getResult(), storeVal);
+    rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal);
     return success();
   }
 
+  // Bitcasting is currently unsupported for Kernel capability /
+  // spv.PtrAccessChain.
+  if (typeConverter.allows(spirv::Capability::Kernel))
+    return failure();
+
+  auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
+  if (!accessChainOp)
+    return failure();
+
   // Since there are multi threads in the processing, the emulation will be done
   // with atomic operations. E.g., if the storing value is i8, rewrite the
   // StoreOp to
index 8d0bde6..b56b8c0 100644 (file)
@@ -122,6 +122,10 @@ MLIRContext *SPIRVTypeConverter::getContext() const {
   return targetEnv.getAttr().getContext();
 }
 
+bool SPIRVTypeConverter::allows(spirv::Capability capability) {
+  return targetEnv.allows(capability);
+}
+
 // TODO: This is a utility function that should probably be exposed by the
 // SPIR-V dialect. Keeping it local till the use case arises.
 static Optional<int64_t> getTypeNumBytes(const SPIRVConversionOptions &options,
@@ -334,6 +338,12 @@ static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
     return nullptr;
   }
 
+  // For OpenCL Kernel we can just emit a pointer pointing to the element.
+  if (targetEnv.allows(spirv::Capability::Kernel))
+    return spirv::PointerType::get(arrayElemType, storageClass);
+
+  // For Vulkan we need extra wrapping struct and array to satisfy interface
+  // needs.
   if (!type.hasStaticShape()) {
     int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
     auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
@@ -393,6 +403,12 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
     return nullptr;
   }
 
+  // For OpenCL Kernel we can just emit a pointer pointing to the element.
+  if (targetEnv.allows(spirv::Capability::Kernel))
+    return spirv::PointerType::get(arrayElemType, storageClass);
+
+  // For Vulkan we need extra wrapping struct and array to satisfy interface
+  // needs.
   if (!type.hasStaticShape()) {
     int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
     auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
@@ -712,9 +728,10 @@ Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
   return linearizedIndex;
 }
 
-spirv::AccessChainOp mlir::spirv::getElementPtr(
-    SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr,
-    ValueRange indices, Location loc, OpBuilder &builder) {
+Value mlir::spirv::getVulkanElementPtr(SPIRVTypeConverter &typeConverter,
+                                       MemRefType baseType, Value basePtr,
+                                       ValueRange indices, Location loc,
+                                       OpBuilder &builder) {
   // Get base and offset of the MemRefType and verify they are static.
 
   int64_t offset;
@@ -742,6 +759,50 @@ spirv::AccessChainOp mlir::spirv::getElementPtr(
   return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
 }
 
+Value mlir::spirv::getOpenCLElementPtr(SPIRVTypeConverter &typeConverter,
+                                       MemRefType baseType, Value basePtr,
+                                       ValueRange indices, Location loc,
+                                       OpBuilder &builder) {
+  // Get base and offset of the MemRefType and verify they are static.
+
+  int64_t offset;
+  SmallVector<int64_t, 4> strides;
+  if (failed(getStridesAndOffset(baseType, strides, offset)) ||
+      llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()) ||
+      offset == MemRefType::getDynamicStrideOrOffset()) {
+    return nullptr;
+  }
+
+  auto indexType = typeConverter.getIndexType();
+
+  SmallVector<Value, 2> linearizedIndices;
+  auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
+
+  Value linearIndex;
+  if (baseType.getRank() == 0) {
+    linearIndex = zero;
+  } else {
+    linearIndex =
+        linearizeIndex(indices, strides, offset, indexType, loc, builder);
+  }
+  return builder.create<spirv::PtrAccessChainOp>(loc, basePtr, linearIndex,
+                                                 linearizedIndices);
+}
+
+Value mlir::spirv::getElementPtr(SPIRVTypeConverter &typeConverter,
+                                 MemRefType baseType, Value basePtr,
+                                 ValueRange indices, Location loc,
+                                 OpBuilder &builder) {
+
+  if (typeConverter.allows(spirv::Capability::Kernel)) {
+    return getOpenCLElementPtr(typeConverter, baseType, basePtr, indices, loc,
+                               builder);
+  }
+
+  return getVulkanElementPtr(typeConverter, baseType, basePtr, indices, loc,
+                             builder);
+}
+
 //===----------------------------------------------------------------------===//
 // SPIR-V ConversionTarget
 //===----------------------------------------------------------------------===//
index ac22d13..5413157 100644 (file)
@@ -9,7 +9,7 @@ module attributes {
     //       CHECK:   spv.func
     //  CHECK-SAME:     {{%.*}}: f32
     //   CHECK-NOT:     spv.interface_var_abi
-    //  CHECK-SAME:     {{%.*}}: !spv.ptr<!spv.struct<(!spv.array<12 x f32>)>, CrossWorkgroup>
+    //  CHECK-SAME:     {{%.*}}: !spv.ptr<f32, CrossWorkgroup>
     //   CHECK-NOT:     spv.interface_var_abi
     //  CHECK-SAME:     spv.entry_point_abi = #spv.entry_point_abi<local_size = dense<[32, 4, 1]> : vector<3xi32>>
     gpu.func @basic_module_structure(%arg0 : f32, %arg1 : memref<12xf32, #spv.storage_class<CrossWorkgroup>>) kernel
index 212363c..98e44b9 100644 (file)
@@ -109,6 +109,111 @@ func.func @store_i1(%dst: memref<4xi1, #spv.storage_class<StorageBuffer>>, %i: i
 
 // -----
 
+// Check for Kernel capability, that with proper compute and storage extensions, we don't need to
+// perform special tricks.
+
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0,
+      [
+        Kernel, Addresses, Int8, Int16, Int64, Float16, Float64], []>, #spv.resource_limits<>>
+} {
+
+// CHECK-LABEL: @load_store_zero_rank_float
+func.func @load_store_zero_rank_float(%arg0: memref<f32, #spv.storage_class<CrossWorkgroup>>, %arg1: memref<f32, #spv.storage_class<CrossWorkgroup>>) {
+  //      CHECK: [[ARG0:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<f32, #spv.storage_class<CrossWorkgroup>> to !spv.ptr<f32, CrossWorkgroup>
+  //      CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<f32, #spv.storage_class<CrossWorkgroup>> to !spv.ptr<f32, CrossWorkgroup>
+  //      CHECK: [[ZERO1:%.*]] = spv.Constant 0 : i32
+  //      CHECK: spv.PtrAccessChain [[ARG0]][
+  // CHECK-SAME: [[ZERO1]]
+  // CHECK-SAME: ] :
+  //      CHECK: spv.Load "CrossWorkgroup" %{{.*}} : f32
+  %0 = memref.load %arg0[] : memref<f32, #spv.storage_class<CrossWorkgroup>>
+  //      CHECK: [[ZERO2:%.*]] = spv.Constant 0 : i32
+  //      CHECK: spv.PtrAccessChain [[ARG1]][
+  // CHECK-SAME: [[ZERO2]]
+  // CHECK-SAME: ] :
+  //      CHECK: spv.Store "CrossWorkgroup" %{{.*}} : f32
+  memref.store %0, %arg1[] : memref<f32, #spv.storage_class<CrossWorkgroup>>
+  return
+}
+
+// CHECK-LABEL: @load_store_zero_rank_int
+func.func @load_store_zero_rank_int(%arg0: memref<i32, #spv.storage_class<CrossWorkgroup>>, %arg1: memref<i32, #spv.storage_class<CrossWorkgroup>>) {
+  //      CHECK: [[ARG0:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<i32, #spv.storage_class<CrossWorkgroup>> to  !spv.ptr<i32, CrossWorkgroup>
+  //      CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<i32, #spv.storage_class<CrossWorkgroup>> to  !spv.ptr<i32, CrossWorkgroup>
+  //      CHECK: [[ZERO1:%.*]] = spv.Constant 0 : i32
+  //      CHECK: spv.PtrAccessChain [[ARG0]][
+  // CHECK-SAME: [[ZERO1]]
+  // CHECK-SAME: ] :
+  //      CHECK: spv.Load "CrossWorkgroup" %{{.*}} : i32
+  %0 = memref.load %arg0[] : memref<i32, #spv.storage_class<CrossWorkgroup>>
+  //      CHECK: [[ZERO2:%.*]] = spv.Constant 0 : i32
+  //      CHECK: spv.PtrAccessChain [[ARG1]][
+  // CHECK-SAME: [[ZERO2]]
+  // CHECK-SAME: ] :
+  //      CHECK: spv.Store "CrossWorkgroup" %{{.*}} : i32
+  memref.store %0, %arg1[] : memref<i32, #spv.storage_class<CrossWorkgroup>>
+  return
+}
+
+// CHECK-LABEL: func @load_store_unknown_dim
+func.func @load_store_unknown_dim(%i: index, %source: memref<?xi32, #spv.storage_class<CrossWorkgroup>>, %dest: memref<?xi32, #spv.storage_class<CrossWorkgroup>>) {
+  // CHECK: %[[SRC:.+]] = builtin.unrealized_conversion_cast {{.+}} : memref<?xi32, #spv.storage_class<CrossWorkgroup>> to !spv.ptr<i32, CrossWorkgroup>
+  // CHECK: %[[DST:.+]] = builtin.unrealized_conversion_cast {{.+}} : memref<?xi32, #spv.storage_class<CrossWorkgroup>> to !spv.ptr<i32, CrossWorkgroup>
+  // CHECK: %[[AC0:.+]] = spv.PtrAccessChain %[[SRC]]
+  // CHECK: spv.Load "CrossWorkgroup" %[[AC0]]
+  %0 = memref.load %source[%i] : memref<?xi32, #spv.storage_class<CrossWorkgroup>>
+  // CHECK: %[[AC1:.+]] = spv.PtrAccessChain %[[DST]]
+  // CHECK: spv.Store "CrossWorkgroup" %[[AC1]]
+  memref.store %0, %dest[%i]: memref<?xi32, #spv.storage_class<CrossWorkgroup>>
+  return
+}
+
+// CHECK-LABEL: func @load_i1
+//  CHECK-SAME: (%[[SRC:.+]]: memref<4xi1, #spv.storage_class<CrossWorkgroup>>, %[[IDX:.+]]: index)
+func.func @load_i1(%src: memref<4xi1, #spv.storage_class<CrossWorkgroup>>, %i : index) -> i1 {
+  // CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1, #spv.storage_class<CrossWorkgroup>> to !spv.ptr<i8, CrossWorkgroup>
+  // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
+  // CHECK: %[[ZERO_0:.+]] = spv.Constant 0 : i32
+  // CHECK: %[[ZERO_1:.+]] = spv.Constant 0 : i32
+  // CHECK: %[[ONE:.+]] = spv.Constant 1 : i32
+  // CHECK: %[[MUL:.+]] = spv.IMul %[[ONE]], %[[IDX_CAST]] : i32
+  // CHECK: %[[ADD:.+]] = spv.IAdd %[[ZERO_1]], %[[MUL]] : i32
+  // CHECK: %[[ADDR:.+]] = spv.PtrAccessChain %[[SRC_CAST]][%[[ADD]]]
+  // CHECK: %[[VAL:.+]] = spv.Load "CrossWorkgroup" %[[ADDR]] : i8
+  // CHECK: %[[ONE_I8:.+]] = spv.Constant 1 : i8
+  // CHECK: %[[BOOL:.+]] = spv.IEqual %[[VAL]], %[[ONE_I8]] : i8
+  %0 = memref.load %src[%i] : memref<4xi1, #spv.storage_class<CrossWorkgroup>>
+  // CHECK: return %[[BOOL]]
+  return %0: i1
+}
+
+// CHECK-LABEL: func @store_i1
+//  CHECK-SAME: %[[DST:.+]]: memref<4xi1, #spv.storage_class<CrossWorkgroup>>,
+//  CHECK-SAME: %[[IDX:.+]]: index
+func.func @store_i1(%dst: memref<4xi1, #spv.storage_class<CrossWorkgroup>>, %i: index) {
+  %true = arith.constant true
+  // CHECK-DAG: %[[DST_CAST:.+]] = builtin.unrealized_conversion_cast %[[DST]] : memref<4xi1, #spv.storage_class<CrossWorkgroup>> to !spv.ptr<i8, CrossWorkgroup>
+  // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
+  // CHECK: %[[ZERO_0:.+]] = spv.Constant 0 : i32
+  // CHECK: %[[ZERO_1:.+]] = spv.Constant 0 : i32
+  // CHECK: %[[ONE:.+]] = spv.Constant 1 : i32
+  // CHECK: %[[MUL:.+]] = spv.IMul %[[ONE]], %[[IDX_CAST]] : i32
+  // CHECK: %[[ADD:.+]] = spv.IAdd %[[ZERO_1]], %[[MUL]] : i32
+  // CHECK: %[[ADDR:.+]] = spv.PtrAccessChain %[[DST_CAST]][%[[ADD]]]
+  // CHECK: %[[ZERO_I8:.+]] = spv.Constant 0 : i8
+  // CHECK: %[[ONE_I8:.+]] = spv.Constant 1 : i8
+  // CHECK: %[[RES:.+]] = spv.Select %{{.+}}, %[[ONE_I8]], %[[ZERO_I8]] : i1, i8
+  // CHECK: spv.Store "CrossWorkgroup" %[[ADDR]], %[[RES]] : i8
+  memref.store %true, %dst[%i]: memref<4xi1, #spv.storage_class<CrossWorkgroup>>
+  return
+}
+
+} // end module
+
+// -----
+
 // Check that access chain indices are properly adjusted if non-32-bit types are
 // emulated via 32-bit types.
 // TODO: Test i64 types.