[mlir][StandardToSPIRV] Fix signedness issue in bitwidth emulation.
authorHanhan Wang <hanchung@google.com>
Tue, 19 May 2020 17:55:17 +0000 (10:55 -0700)
committerHanhan Wang <hanchung@google.com>
Tue, 19 May 2020 18:00:01 +0000 (11:00 -0700)
Summary:
Previously, after applying the mask, a negative number would convert to a
positive number because the sign flag was forgotten. This patch adds two more
shift operations to do the sign extension. This assumes that we're using two's
complement.

This patch applies sign extension unconditionally when loading a unspported integer width, and it relies the pattern to do the casting because the signedness semantic is carried by operator itself.

Differential Revision: https://reviews.llvm.org/D79753

mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir

index fbe0256..560bc4a 100644 (file)
@@ -147,14 +147,44 @@ static Value adjustAccessChainForBitwidth(SPIRVTypeConverter &typeConverter,
 }
 
 /// Returns the shifted `targetBits`-bit value with the given offset.
-Value shiftValue(Location loc, Value value, Value offset, Value mask,
-                 int targetBits, OpBuilder &builder) {
+static Value shiftValue(Location loc, Value value, Value offset, Value mask,
+                        int targetBits, OpBuilder &builder) {
   Type targetType = builder.getIntegerType(targetBits);
   Value result = builder.create<spirv::BitwiseAndOp>(loc, value, mask);
   return builder.create<spirv::ShiftLeftLogicalOp>(loc, targetType, result,
                                                    offset);
 }
 
+/// Returns true if the operator is operating on unsigned integers.
+/// TODO: Have a TreatOperandsAsUnsignedInteger trait and bake the information
+/// to the ops themselves.
+template <typename SPIRVOp>
+bool isUnsignedOp() {
+  return false;
+}
+
+#define CHECK_UNSIGNED_OP(SPIRVOp)                                             \
+  template <>                                                                  \
+  bool isUnsignedOp<SPIRVOp>() {                                               \
+    return true;                                                               \
+  }
+
+CHECK_UNSIGNED_OP(spirv::AtomicUMaxOp);
+CHECK_UNSIGNED_OP(spirv::AtomicUMinOp);
+CHECK_UNSIGNED_OP(spirv::BitFieldUExtractOp);
+CHECK_UNSIGNED_OP(spirv::ConvertUToFOp);
+CHECK_UNSIGNED_OP(spirv::GroupNonUniformUMaxOp);
+CHECK_UNSIGNED_OP(spirv::GroupNonUniformUMinOp);
+CHECK_UNSIGNED_OP(spirv::UConvertOp);
+CHECK_UNSIGNED_OP(spirv::UDivOp);
+CHECK_UNSIGNED_OP(spirv::UGreaterThanEqualOp);
+CHECK_UNSIGNED_OP(spirv::UGreaterThanOp);
+CHECK_UNSIGNED_OP(spirv::ULessThanEqualOp);
+CHECK_UNSIGNED_OP(spirv::ULessThanOp);
+CHECK_UNSIGNED_OP(spirv::UModOp);
+
+#undef CHECK_UNSIGNED_OP
+
 //===----------------------------------------------------------------------===//
 // Operation conversion
 //===----------------------------------------------------------------------===//
@@ -178,6 +208,10 @@ public:
     auto dstType = this->typeConverter.convertType(operation.getType());
     if (!dstType)
       return failure();
+    if (isUnsignedOp<SPIRVOp>() && dstType != operation.getType()) {
+      return operation.emitError(
+          "bitwidth emulation is not implemented yet on unsigned op");
+    }
     rewriter.template replaceOpWithNewOp<SPIRVOp>(operation, dstType, operands,
                                                   ArrayRef<NamedAttribute>());
     return success();
@@ -581,6 +615,11 @@ CmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
   switch (cmpIOp.getPredicate()) {
 #define DISPATCH(cmpPredicate, spirvOp)                                        \
   case cmpPredicate:                                                           \
+    if (isUnsignedOp<spirvOp>() &&                                             \
+        operandType != this->typeConverter.convertType(operandType)) {         \
+      return cmpIOp.emitError(                                                 \
+          "bitwidth emulation is not implemented yet on unsigned op");         \
+    }                                                                          \
     rewriter.replaceOpWithNewOp<spirvOp>(cmpIOp, cmpIOp.getResult().getType(), \
                                          cmpIOpOperands.lhs(),                 \
                                          cmpIOpOperands.rhs());                \
@@ -661,6 +700,18 @@ IntLoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
   Value mask = rewriter.create<spirv::ConstantOp>(
       loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
   result = rewriter.create<spirv::BitwiseAndOp>(loc, dstType, result, mask);
+
+  // Apply sign extension on the loading value unconditionally. The signedness
+  // semantic is carried in the operator itself, we relies other pattern to
+  // handle the casting.
+  IntegerAttr shiftValueAttr =
+      rewriter.getIntegerAttr(dstType, dstBits - srcBits);
+  Value shiftValue =
+      rewriter.create<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
+  result = rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, result,
+                                                      shiftValue);
+  result = rewriter.create<spirv::ShiftRightArithmeticOp>(loc, dstType, result,
+                                                          shiftValue);
   rewriter.replaceOp(loadOp, result);
 
   assert(accessChainOp.use_empty());
index 1663366..bf54dba 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -convert-std-to-spirv %s -o - | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -convert-std-to-spirv -verify-diagnostics %s -o - | FileCheck %s
 
 //===----------------------------------------------------------------------===//
 // std arithmetic ops
@@ -128,14 +128,12 @@ module attributes {
      max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
 } {
 
-// CHECK-LABEL: @int_vector234
-func @int_vector234(%arg0: vector<2xi8>, %arg1: vector<3xi16>, %arg2: vector<4xi64>) {
+// CHECK-LABEL: @int_vector23
+func @int_vector23(%arg0: vector<2xi8>, %arg1: vector<3xi16>) {
   // CHECK: spv.SDiv %{{.*}}, %{{.*}}: vector<2xi32>
   %0 = divi_signed %arg0, %arg0: vector<2xi8>
   // CHECK: spv.SRem %{{.*}}, %{{.*}}: vector<3xi32>
   %1 = remi_signed %arg1, %arg1: vector<3xi16>
-  // CHECK: spv.UDiv %{{.*}}, %{{.*}}: vector<4xi32>
-  %2 = divi_unsigned %arg2, %arg2: vector<4xi64>
   return
 }
 
@@ -152,6 +150,27 @@ func @float_scalar(%arg0: f16, %arg1: f64) {
 
 // -----
 
+// Check that types are converted to 32-bit when no special capabilities that
+// are not supported.
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [], []>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
+// CHECK-LEBEL: @int_vector4_invalid
+func @int_vector4_invalid(%arg0: vector<4xi64>) {
+  // expected-error @+2 {{bitwidth emulation is not implemented yet on unsigned op}}
+  // expected-error @+1 {{op requires the same type for all operands and results}}
+  %0 = divi_unsigned %arg0, %arg0: vector<4xi64>
+  return
+}
+
+} // end module
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // std bit ops
 //===----------------------------------------------------------------------===//
@@ -717,7 +736,10 @@ func @load_i8(%arg0: memref<i8>) {
   //     CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32
   //     CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
   //     CHECK: %[[MASK:.+]] = spv.constant 255 : i32
-  //     CHECK: spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  //     CHECK: %[[T1:.+]] = spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  //     CHECK: %[[T2:.+]] = spv.constant 24 : i32
+  //     CHECK: %[[T3:.+]] = spv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
+  //     CHECK: spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
   %0 = load %arg0[] : memref<i8>
   return
 }
@@ -738,7 +760,10 @@ func @load_i16(%arg0: memref<10xi16>, %index : index) {
   //     CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[SIXTEEN]] : i32
   //     CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
   //     CHECK: %[[MASK:.+]] = spv.constant 65535 : i32
-  //     CHECK: spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  //     CHECK: %[[T1:.+]] = spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  //     CHECK: %[[T2:.+]] = spv.constant 16 : i32
+  //     CHECK: %[[T3:.+]] = spv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
+  //     CHECK: spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
   %0 = load %arg0[%index] : memref<10xi16>
   return
 }
@@ -852,7 +877,10 @@ func @load_i8(%arg0: memref<i8>) {
   //     CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32
   //     CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
   //     CHECK: %[[MASK:.+]] = spv.constant 255 : i32
-  //     CHECK: spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  //     CHECK: %[[T1:.+]] = spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  //     CHECK: %[[T2:.+]] = spv.constant 24 : i32
+  //     CHECK: %[[T3:.+]] = spv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
+  //     CHECK: spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
   %0 = load %arg0[] : memref<i8>
   return
 }