[mlir][spirv] Fix UnifyAliasedResourcePass for 64-bit index
authorLei Zhang <antiagainst@gmail.com>
Tue, 14 Mar 2023 23:45:42 +0000 (23:45 +0000)
committerLei Zhang <antiagainst@google.com>
Tue, 14 Mar 2023 23:54:27 +0000 (23:54 +0000)
Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D145079

mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir

index 3acfba2..1713c44 100644 (file)
@@ -366,7 +366,6 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
     }
 
     Location loc = acOp.getLoc();
-    auto i32Type = rewriter.getI32Type();
 
     if (srcElemType.isIntOrFloat() && dstElemType.isa<VectorType>()) {
       // The source indices are for a buffer with scalar element types. Rewrite
@@ -376,16 +375,19 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
       int srcNumBytes = *srcElemType.getSizeInBytes();
       int dstNumBytes = *dstElemType.getSizeInBytes();
       assert(dstNumBytes >= srcNumBytes && dstNumBytes % srcNumBytes == 0);
-      int ratio = dstNumBytes / srcNumBytes;
-      auto ratioValue = rewriter.create<spirv::ConstantOp>(
-          loc, i32Type, rewriter.getI32IntegerAttr(ratio));
 
       auto indices = llvm::to_vector<4>(acOp.getIndices());
       Value oldIndex = indices.back();
+      Type indexType = oldIndex.getType();
+
+      int ratio = dstNumBytes / srcNumBytes;
+      auto ratioValue = rewriter.create<spirv::ConstantOp>(
+          loc, indexType, rewriter.getIntegerAttr(indexType, ratio));
+
       indices.back() =
-          rewriter.create<spirv::SDivOp>(loc, i32Type, oldIndex, ratioValue);
+          rewriter.create<spirv::SDivOp>(loc, indexType, oldIndex, ratioValue);
       indices.push_back(
-          rewriter.create<spirv::SModOp>(loc, i32Type, oldIndex, ratioValue));
+          rewriter.create<spirv::SModOp>(loc, indexType, oldIndex, ratioValue));
 
       rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
           acOp, adaptor.getBasePtr(), indices);
@@ -400,14 +402,17 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
       int srcNumBytes = *srcElemType.getSizeInBytes();
       int dstNumBytes = *dstElemType.getSizeInBytes();
       assert(srcNumBytes >= dstNumBytes && srcNumBytes % dstNumBytes == 0);
-      int ratio = srcNumBytes / dstNumBytes;
-      auto ratioValue = rewriter.create<spirv::ConstantOp>(
-          loc, i32Type, rewriter.getI32IntegerAttr(ratio));
 
       auto indices = llvm::to_vector<4>(acOp.getIndices());
       Value oldIndex = indices.back();
+      Type indexType = oldIndex.getType();
+
+      int ratio = srcNumBytes / dstNumBytes;
+      auto ratioValue = rewriter.create<spirv::ConstantOp>(
+          loc, indexType, rewriter.getIntegerAttr(indexType, ratio));
+
       indices.back() =
-          rewriter.create<spirv::IMulOp>(loc, i32Type, oldIndex, ratioValue);
+          rewriter.create<spirv::IMulOp>(loc, indexType, oldIndex, ratioValue);
 
       rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
           acOp, adaptor.getBasePtr(), indices);
index 1d532f3..8801fdb 100644 (file)
@@ -36,6 +36,33 @@ spirv.module Logical GLSL450 {
   spirv.GlobalVariable @var01s bind(0, 1) {aliased} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32, stride=4> [0])>, StorageBuffer>
   spirv.GlobalVariable @var01v bind(0, 1) {aliased} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
 
+  spirv.func @load_store_scalar_64bit(%index: i64) -> f32 "None" {
+    %c0 = spirv.Constant 0 : i64
+    %addr = spirv.mlir.addressof @var01s : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+    %ac = spirv.AccessChain %addr[%c0, %index] : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i64, i64
+    %value = spirv.Load "StorageBuffer" %ac : f32
+    spirv.Store "StorageBuffer" %ac, %value : f32
+    spirv.ReturnValue %value : f32
+  }
+}
+
+// CHECK-LABEL: spirv.module
+
+// CHECK-NOT: @var01s
+//     CHECK: spirv.GlobalVariable @var01v bind(0, 1) : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+// CHECK-NOT: @var01s
+
+//     CHECK: spirv.func @load_store_scalar_64bit(%[[INDEX:.+]]: i64)
+// CHECK-DAG:   %[[C4:.+]] = spirv.Constant 4 : i64
+//     CHECK:   spirv.SDiv %[[INDEX]], %[[C4]] : i64
+//     CHECK:   spirv.SMod %[[INDEX]], %[[C4]] : i64
+
+// -----
+
+spirv.module Logical GLSL450 {
+  spirv.GlobalVariable @var01s bind(0, 1) {aliased} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+  spirv.GlobalVariable @var01v bind(0, 1) {aliased} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+
   spirv.func @multiple_uses(%i0: i32, %i1: i32) -> f32 "None" {
     %c0 = spirv.Constant 0 : i32
     %addr = spirv.mlir.addressof @var01s : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32, stride=4> [0])>, StorageBuffer>