[mlir][StandardToSPIRV] Use spv.UMod for index re-calculation
authorLei Zhang <antiagainst@google.com>
Mon, 13 Jul 2020 20:20:59 +0000 (16:20 -0400)
committerLei Zhang <antiagainst@google.com>
Wed, 5 Aug 2020 18:52:04 +0000 (14:52 -0400)
Per Vulkan's SPIR-V environment spec: "While the OpSRem and OpSMod
instructions are supported by the Vulkan environment, they require
non-negative values and thus do not enable additional functionality
beyond what OpUMod provides."

The `getOffsetForBitwidth` function is used for lowering std.load
and std.store, whose indices are of `index` type and cannot be
negative. So we should be okay to use spv.UMod directly here to
be exact. Also made the comment explicit about the assumption.

Differential Revision: https://reviews.llvm.org/D83714

mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir

index 543b23a..268139f 100644 (file)
@@ -126,9 +126,12 @@ static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
   return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate);
 }
 
-/// Returns the offset of the value in `targetBits` representation. `srcIdx` is
-/// an index into a 1-D array with each element having `sourceBits`. When
-/// accessing an element in the array treating as having elements of
+/// Returns the offset of the value in `targetBits` representation.
+///
+/// `srcIdx` is an index into a 1-D array with each element having `sourceBits`.
+/// It's assumed to be non-negative.
+///
+/// When accessing an element in the array treating as having elements of
 /// `targetBits`, multiple values are loaded in the same time. The method
 /// returns the offset where the `srcIdx` locates in the value. For example, if
 /// `sourceBits` equals to 8 and `targetBits` equals to 32, the x-th element is
@@ -144,7 +147,7 @@ static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
   IntegerAttr srcBitsAttr = builder.getIntegerAttr(targetType, sourceBits);
   auto srcBitsValue =
       builder.create<spirv::ConstantOp>(loc, targetType, srcBitsAttr);
-  auto m = builder.create<spirv::SModOp>(loc, srcIdx, idx);
+  auto m = builder.create<spirv::UModOp>(loc, srcIdx, idx);
   return builder.create<spirv::IMulOp>(loc, targetType, m, srcBitsValue);
 }
 
index 7352237..e85f78f 100644 (file)
@@ -762,7 +762,7 @@ func @load_i8(%arg0: memref<i8>) {
   //     CHECK: %[[LOAD:.+]] = spv.Load  "StorageBuffer" %[[PTR]]
   //     CHECK: %[[FOUR2:.+]] = spv.constant 4 : i32
   //     CHECK: %[[EIGHT:.+]] = spv.constant 8 : i32
-  //     CHECK: %[[IDX:.+]] = spv.SMod %[[ZERO]], %[[FOUR2]] : i32
+  //     CHECK: %[[IDX:.+]] = spv.UMod %[[ZERO]], %[[FOUR2]] : i32
   //     CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32
   //     CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
   //     CHECK: %[[MASK:.+]] = spv.constant 255 : i32
@@ -788,7 +788,7 @@ func @load_i16(%arg0: memref<10xi16>, %index : index) {
   //     CHECK: %[[LOAD:.+]] = spv.Load  "StorageBuffer" %[[PTR]]
   //     CHECK: %[[TWO2:.+]] = spv.constant 2 : i32
   //     CHECK: %[[SIXTEEN:.+]] = spv.constant 16 : i32
-  //     CHECK: %[[IDX:.+]] = spv.SMod %[[FLAT_IDX]], %[[TWO2]] : i32
+  //     CHECK: %[[IDX:.+]] = spv.UMod %[[FLAT_IDX]], %[[TWO2]] : i32
   //     CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[SIXTEEN]] : i32
   //     CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
   //     CHECK: %[[MASK:.+]] = spv.constant 65535 : i32
@@ -824,7 +824,7 @@ func @store_i8(%arg0: memref<i8>, %value: i8) {
   //     CHECK: %[[ZERO:.+]] = spv.constant 0 : i32
   //     CHECK: %[[FOUR:.+]] = spv.constant 4 : i32
   //     CHECK: %[[EIGHT:.+]] = spv.constant 8 : i32
-  //     CHECK: %[[IDX:.+]] = spv.SMod %[[ZERO]], %[[FOUR]] : i32
+  //     CHECK: %[[IDX:.+]] = spv.UMod %[[ZERO]], %[[FOUR]] : i32
   //     CHECK: %[[OFFSET:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32
   //     CHECK: %[[MASK1:.+]] = spv.constant 255 : i32
   //     CHECK: %[[TMP1:.+]] = spv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32
@@ -850,7 +850,7 @@ func @store_i16(%arg0: memref<10xi16>, %index: index, %value: i16) {
   //     CHECK: %[[FLAT_IDX:.+]] = spv.IAdd %[[OFFSET]], %[[UPDATE]] : i32
   //     CHECK: %[[TWO:.+]] = spv.constant 2 : i32
   //     CHECK: %[[SIXTEEN:.+]] = spv.constant 16 : i32
-  //     CHECK: %[[IDX:.+]] = spv.SMod %[[FLAT_IDX]], %[[TWO]] : i32
+  //     CHECK: %[[IDX:.+]] = spv.UMod %[[FLAT_IDX]], %[[TWO]] : i32
   //     CHECK: %[[OFFSET:.+]] = spv.IMul %[[IDX]], %[[SIXTEEN]] : i32
   //     CHECK: %[[MASK1:.+]] = spv.constant 65535 : i32
   //     CHECK: %[[TMP1:.+]] = spv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32
@@ -907,7 +907,7 @@ func @load_i8(%arg0: memref<i8>) {
   //     CHECK: %[[LOAD:.+]] = spv.Load  "StorageBuffer" %[[PTR]]
   //     CHECK: %[[FOUR2:.+]] = spv.constant 4 : i32
   //     CHECK: %[[EIGHT:.+]] = spv.constant 8 : i32
-  //     CHECK: %[[IDX:.+]] = spv.SMod %[[ZERO]], %[[FOUR2]] : i32
+  //     CHECK: %[[IDX:.+]] = spv.UMod %[[ZERO]], %[[FOUR2]] : i32
   //     CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32
   //     CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
   //     CHECK: %[[MASK:.+]] = spv.constant 255 : i32
@@ -934,7 +934,7 @@ func @store_i8(%arg0: memref<i8>, %value: i8) {
   //     CHECK: %[[ZERO:.+]] = spv.constant 0 : i32
   //     CHECK: %[[FOUR:.+]] = spv.constant 4 : i32
   //     CHECK: %[[EIGHT:.+]] = spv.constant 8 : i32
-  //     CHECK: %[[IDX:.+]] = spv.SMod %[[ZERO]], %[[FOUR]] : i32
+  //     CHECK: %[[IDX:.+]] = spv.UMod %[[ZERO]], %[[FOUR]] : i32
   //     CHECK: %[[OFFSET:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32
   //     CHECK: %[[MASK1:.+]] = spv.constant 255 : i32
   //     CHECK: %[[TMP1:.+]] = spv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32