}
Location loc = acOp.getLoc();
- auto i32Type = rewriter.getI32Type();
if (srcElemType.isIntOrFloat() && dstElemType.isa<VectorType>()) {
// The source indices are for a buffer with scalar element types. Rewrite
int srcNumBytes = *srcElemType.getSizeInBytes();
int dstNumBytes = *dstElemType.getSizeInBytes();
assert(dstNumBytes >= srcNumBytes && dstNumBytes % srcNumBytes == 0);
- int ratio = dstNumBytes / srcNumBytes;
- auto ratioValue = rewriter.create<spirv::ConstantOp>(
- loc, i32Type, rewriter.getI32IntegerAttr(ratio));
auto indices = llvm::to_vector<4>(acOp.getIndices());
Value oldIndex = indices.back();
+ Type indexType = oldIndex.getType();
+
+ int ratio = dstNumBytes / srcNumBytes;
+ auto ratioValue = rewriter.create<spirv::ConstantOp>(
+ loc, indexType, rewriter.getIntegerAttr(indexType, ratio));
+
indices.back() =
- rewriter.create<spirv::SDivOp>(loc, i32Type, oldIndex, ratioValue);
+ rewriter.create<spirv::SDivOp>(loc, indexType, oldIndex, ratioValue);
indices.push_back(
- rewriter.create<spirv::SModOp>(loc, i32Type, oldIndex, ratioValue));
+ rewriter.create<spirv::SModOp>(loc, indexType, oldIndex, ratioValue));
rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
acOp, adaptor.getBasePtr(), indices);
int srcNumBytes = *srcElemType.getSizeInBytes();
int dstNumBytes = *dstElemType.getSizeInBytes();
assert(srcNumBytes >= dstNumBytes && srcNumBytes % dstNumBytes == 0);
- int ratio = srcNumBytes / dstNumBytes;
- auto ratioValue = rewriter.create<spirv::ConstantOp>(
- loc, i32Type, rewriter.getI32IntegerAttr(ratio));
auto indices = llvm::to_vector<4>(acOp.getIndices());
Value oldIndex = indices.back();
+ Type indexType = oldIndex.getType();
+
+ int ratio = srcNumBytes / dstNumBytes;
+ auto ratioValue = rewriter.create<spirv::ConstantOp>(
+ loc, indexType, rewriter.getIntegerAttr(indexType, ratio));
+
indices.back() =
- rewriter.create<spirv::IMulOp>(loc, i32Type, oldIndex, ratioValue);
+ rewriter.create<spirv::IMulOp>(loc, indexType, oldIndex, ratioValue);
rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
acOp, adaptor.getBasePtr(), indices);
spirv.GlobalVariable @var01s bind(0, 1) {aliased} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32, stride=4> [0])>, StorageBuffer>
spirv.GlobalVariable @var01v bind(0, 1) {aliased} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+ spirv.func @load_store_scalar_64bit(%index: i64) -> f32 "None" {
+ %c0 = spirv.Constant 0 : i64
+ %addr = spirv.mlir.addressof @var01s : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+ %ac = spirv.AccessChain %addr[%c0, %index] : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i64, i64
+ %value = spirv.Load "StorageBuffer" %ac : f32
+ spirv.Store "StorageBuffer" %ac, %value : f32
+ spirv.ReturnValue %value : f32
+ }
+}
+
+// CHECK-LABEL: spirv.module
+
+// CHECK-NOT: @var01s
+// CHECK: spirv.GlobalVariable @var01v bind(0, 1) : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+// CHECK-NOT: @var01s
+
+// CHECK: spirv.func @load_store_scalar_64bit(%[[INDEX:.+]]: i64)
+// CHECK-DAG: %[[C4:.+]] = spirv.Constant 4 : i64
+// CHECK: spirv.SDiv %[[INDEX]], %[[C4]] : i64
+// CHECK: spirv.SMod %[[INDEX]], %[[C4]] : i64
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.GlobalVariable @var01s bind(0, 1) {aliased} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+ spirv.GlobalVariable @var01v bind(0, 1) {aliased} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+
spirv.func @multiple_uses(%i0: i32, %i1: i32) -> f32 "None" {
%c0 = spirv.Constant 0 : i32
%addr = spirv.mlir.addressof @var01s : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32, stride=4> [0])>, StorageBuffer>