[mlir][arith][spirv] Hard fail in `-convert-arith-to-spirv`
authorJakub Kuderski <kubak@google.com>
Mon, 31 Oct 2022 21:00:45 +0000 (17:00 -0400)
committerJakub Kuderski <kubak@google.com>
Mon, 31 Oct 2022 21:01:21 +0000 (17:01 -0400)
Turn legalization failures into hard failures to make sure that we do
not miss conversion pattern application failures.

Add a message on type conversion failure.

Move unsupported cases into a separate test file.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D137102

mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
mlir/lib/Conversion/SPIRVCommon/Pattern.h
mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir [new file with mode: 0644]
mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir

index 24bead8..2452928 100644 (file)
@@ -1046,6 +1046,9 @@ struct ConvertArithToSPIRVPass
     typeConverter.addTargetMaterialization(addUnrealizedCast);
     target->addLegalOp<UnrealizedConversionCastOp>();
 
+    // Fail hard when there are any remaining 'arith' ops.
+    target->addIllegalDialect<arith::ArithDialect>();
+
     RewritePatternSet patterns(&getContext());
     arith::populateArithToSPIRVPatterns(typeConverter, patterns);
 
index 5d32fa8..ed859a8 100644 (file)
@@ -11,6 +11,7 @@
 
 #include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/FormatVariadic.h"
 
 namespace mlir {
 namespace spirv {
@@ -26,9 +27,13 @@ public:
   matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     assert(adaptor.getOperands().size() <= 3);
-    auto dstType = this->getTypeConverter()->convertType(op.getType());
-    if (!dstType)
-      return failure();
+    Type dstType = this->getTypeConverter()->convertType(op.getType());
+    if (!dstType) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(),
+          llvm::formatv("failed to convert type {0} for SPIR-V", op.getType()));
+    }
+
     if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
         !op.getType().isIndex() && dstType != op.getType()) {
       return op.emitError(
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
new file mode 100644 (file)
index 0000000..967adbc
--- /dev/null
@@ -0,0 +1,72 @@
+// RUN: mlir-opt -split-input-file -convert-arith-to-spirv -verify-diagnostics %s
+
+///===----------------------------------------------------------------------===//
+// Binary ops
+//===----------------------------------------------------------------------===//
+
+// -----
+
+module attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64, Shader], []>, #spirv.resource_limits<>>
+} {
+
+func.func @unsupported_5elem_vector(%arg0: vector<5xi32>) {
+  // expected-error@+1 {{failed to legalize operation 'arith.subi'}}
+  %1 = arith.subi %arg0, %arg0: vector<5xi32>
+  return
+}
+
+} // end module
+
+// -----
+
+module attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64, Shader], []>, #spirv.resource_limits<>>
+} {
+
+func.func @unsupported_2x2elem_vector(%arg0: vector<2x2xi32>) {
+  // expected-error@+1 {{failed to legalize operation 'arith.muli'}}
+  %2 = arith.muli %arg0, %arg0: vector<2x2xi32>
+  return
+}
+
+} // end module
+
+// -----
+
+func.func @int_vector4_invalid(%arg0: vector<2xi16>) {
+  // expected-error @+2 {{failed to legalize operation 'arith.divui'}}
+  // expected-error @+1 {{bitwidth emulation is not implemented yet on unsigned op}}
+  %0 = arith.divui %arg0, %arg0: vector<2xi16>
+  return
+}
+
+///===----------------------------------------------------------------------===//
+// Constant ops
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @unsupported_constant_0() {
+  // expected-error @+1 {{failed to legalize operation 'arith.constant'}}
+  %0 = arith.constant 4294967296 : i64 // 2^32
+  return
+}
+
+// -----
+
+func.func @unsupported_constant_1() {
+  // expected-error @+1 {{failed to legalize operation 'arith.constant'}}
+  %1 = arith.constant -2147483649 : i64 // -2^31 - 1
+  return
+}
+
+// -----
+
+func.func @unsupported_constant_2() {
+  // expected-error @+1 {{failed to legalize operation 'arith.constant'}}
+  %2 = arith.constant -2147483649 : i64 // -2^31 - 1
+  return
+}
index bd5238f..df6806a 100644 (file)
@@ -163,63 +163,6 @@ func.func @one_elem_vector(%arg0: vector<1xi32>) {
   return
 }
 
-// CHECK-LABEL: @unsupported_5elem_vector
-func.func @unsupported_5elem_vector(%arg0: vector<5xi32>) {
-  // CHECK: arith.subi
-  %1 = arith.subi %arg0, %arg0: vector<5xi32>
-  return
-}
-
-// CHECK-LABEL: @unsupported_2x2elem_vector
-func.func @unsupported_2x2elem_vector(%arg0: vector<2x2xi32>) {
-  // CHECK: arith.muli
-  %2 = arith.muli %arg0, %arg0: vector<2x2xi32>
-  return
-}
-
-} // end module
-
-// -----
-
-// Check that types are converted to 32-bit when no special capabilities.
-module attributes {
-  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
-} {
-
-// CHECK-LABEL: @int_vector23
-func.func @int_vector23(%arg0: vector<2xi8>, %arg1: vector<3xi16>) {
-  // CHECK: spirv.SDiv %{{.*}}, %{{.*}}: vector<2xi32>
-  %0 = arith.divsi %arg0, %arg0: vector<2xi8>
-  // CHECK: spirv.SDiv %{{.*}}, %{{.*}}: vector<3xi32>
-  %1 = arith.divsi %arg1, %arg1: vector<3xi16>
-  return
-}
-
-// CHECK-LABEL: @float_scalar
-func.func @float_scalar(%arg0: f16, %arg1: f64) {
-  // CHECK: spirv.FAdd %{{.*}}, %{{.*}}: f32
-  %0 = arith.addf %arg0, %arg0: f16
-  // CHECK: spirv.FMul %{{.*}}, %{{.*}}: f32
-  %1 = arith.mulf %arg1, %arg1: f64
-  return
-}
-
-} // end module
-
-// -----
-
-// Check that types are converted to 32-bit when no special capabilities that
-// are not supported.
-module attributes {
-  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
-} {
-
-func.func @int_vector4_invalid(%arg0: vector<4xi64>) {
-  // expected-error @+1 {{bitwidth emulation is not implemented yet on unsigned op}}
-  %0 = arith.divui %arg0, %arg0: vector<4xi64>
-  return
-}
-
 } // end module
 
 // -----
@@ -643,17 +586,6 @@ func.func @corner_cases() {
   return
 }
 
-// CHECK-LABEL: @unsupported_cases
-func.func @unsupported_cases() {
-  // CHECK: %{{.*}} = arith.constant 4294967296 : i64
-  %0 = arith.constant 4294967296 : i64 // 2^32
-  // CHECK: %{{.*}} = arith.constant -2147483649 : i64
-  %1 = arith.constant -2147483649 : i64 // -2^31 - 1
-  // CHECK: %{{.*}} = arith.constant 1.0000000000000002 : f64
-  %2 = arith.constant 0x3FF0000000000001 : f64 // smallest number > 1
-  return
-}
-
 } // end module
 
 // -----
@@ -1258,20 +1190,6 @@ func.func @one_elem_vector(%arg0: vector<1xi32>) {
   return
 }
 
-// CHECK-LABEL: @unsupported_5elem_vector
-func.func @unsupported_5elem_vector(%arg0: vector<5xi32>) {
-  // CHECK: subi
-  %1 = arith.subi %arg0, %arg0: vector<5xi32>
-  return
-}
-
-// CHECK-LABEL: @unsupported_2x2elem_vector
-func.func @unsupported_2x2elem_vector(%arg0: vector<2x2xi32>) {
-  // CHECK: muli
-  %2 = arith.muli %arg0, %arg0: vector<2x2xi32>
-  return
-}
-
 } // end module
 
 // -----
@@ -1303,22 +1221,6 @@ func.func @float_scalar(%arg0: f16, %arg1: f64) {
 
 // -----
 
-// Check that types are converted to 32-bit when no special capabilities that
-// are not supported.
-module attributes {
-  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
-} {
-
-func.func @int_vector4_invalid(%arg0: vector<4xi64>) {
-  // expected-error@+1 {{bitwidth emulation is not implemented yet on unsigned op}}
-  %0 = arith.divui %arg0, %arg0: vector<4xi64>
-  return
-}
-
-} // end module
-
-// -----
-
 //===----------------------------------------------------------------------===//
 // std bit ops
 //===----------------------------------------------------------------------===//
@@ -1675,17 +1577,6 @@ func.func @corner_cases() {
   return
 }
 
-// CHECK-LABEL: @unsupported_cases
-func.func @unsupported_cases() {
-  // CHECK: %{{.*}} = arith.constant 4294967296 : i64
-  %0 = arith.constant 4294967296 : i64 // 2^32
-  // CHECK: %{{.*}} = arith.constant -2147483649 : i64
-  %1 = arith.constant -2147483649 : i64 // -2^31 - 1
-  // CHECK: %{{.*}} = arith.constant 1.0000000000000002 : f64
-  %2 = arith.constant 0x3FF0000000000001 : f64 // smallest number > 1
-  return
-}
-
 } // end module
 
 // -----