From 179978d7b8ec00291401d2ec49fc0a55e7f7bfb3 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Mon, 31 Oct 2022 17:00:45 -0400 Subject: [PATCH] [mlir][arith][spirv] Hard fail in `-convert-arith-to-spirv` Turn legalization failures into hard failures to make sure that we do not miss conversion pattern application failures. Add a message on type conversion failure. Move unsupported cases into a separate test file. Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D137102 --- mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 3 + mlir/lib/Conversion/SPIRVCommon/Pattern.h | 11 ++- .../ArithToSPIRV/arith-to-spirv-unsupported.mlir | 72 ++++++++++++++ .../Conversion/ArithToSPIRV/arith-to-spirv.mlir | 109 --------------------- 4 files changed, 83 insertions(+), 112 deletions(-) create mode 100644 mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp index 24bead8..2452928 100644 --- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp @@ -1046,6 +1046,9 @@ struct ConvertArithToSPIRVPass typeConverter.addTargetMaterialization(addUnrealizedCast); target->addLegalOp(); + // Fail hard when there are any remaining 'arith' ops. + target->addIllegalDialect(); + RewritePatternSet patterns(&getContext()); arith::populateArithToSPIRVPatterns(typeConverter, patterns); diff --git a/mlir/lib/Conversion/SPIRVCommon/Pattern.h b/mlir/lib/Conversion/SPIRVCommon/Pattern.h index 5d32fa8..ed859a8 100644 --- a/mlir/lib/Conversion/SPIRVCommon/Pattern.h +++ b/mlir/lib/Conversion/SPIRVCommon/Pattern.h @@ -11,6 +11,7 @@ #include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/Support/FormatVariadic.h" namespace mlir { namespace spirv { @@ -26,9 +27,13 @@ public: matchAndRewrite(Op op, typename Op::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { assert(adaptor.getOperands().size() <= 3); - auto dstType = this->getTypeConverter()->convertType(op.getType()); - if (!dstType) - return failure(); + Type dstType = this->getTypeConverter()->convertType(op.getType()); + if (!dstType) { + return rewriter.notifyMatchFailure( + op->getLoc(), + llvm::formatv("failed to convert type {0} for SPIR-V", op.getType())); + } + if (SPIRVOp::template hasTrait() && !op.getType().isIndex() && dstType != op.getType()) { return op.emitError( diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir new file mode 100644 index 0000000..967adbc --- /dev/null +++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir @@ -0,0 +1,72 @@ +// RUN: mlir-opt -split-input-file -convert-arith-to-spirv -verify-diagnostics %s + +///===----------------------------------------------------------------------===// +// Binary ops +//===----------------------------------------------------------------------===// + +// ----- + +module attributes { + spirv.target_env = #spirv.target_env< + #spirv.vce, #spirv.resource_limits<>> +} { + +func.func @unsupported_5elem_vector(%arg0: vector<5xi32>) { + // expected-error@+1 {{failed to legalize operation 'arith.subi'}} + %1 = arith.subi %arg0, %arg0: vector<5xi32> + return +} + +} // end module + +// ----- + +module attributes { + spirv.target_env = #spirv.target_env< + #spirv.vce, #spirv.resource_limits<>> +} { + +func.func @unsupported_2x2elem_vector(%arg0: vector<2x2xi32>) { + // expected-error@+1 {{failed to legalize operation 'arith.muli'}} + %2 = arith.muli %arg0, %arg0: vector<2x2xi32> + return +} + +} // end module + +// ----- + +func.func @int_vector4_invalid(%arg0: vector<2xi16>) { + // expected-error @+2 {{failed to legalize operation 'arith.divui'}} + // expected-error @+1 {{bitwidth emulation is not implemented yet on unsigned op}} + %0 = arith.divui %arg0, %arg0: vector<2xi16> + return +} + +///===----------------------------------------------------------------------===// +// Constant ops +//===----------------------------------------------------------------------===// + +// ----- + +func.func @unsupported_constant_0() { + // expected-error @+1 {{failed to legalize operation 'arith.constant'}} + %0 = arith.constant 4294967296 : i64 // 2^32 + return +} + +// ----- + +func.func @unsupported_constant_1() { + // expected-error @+1 {{failed to legalize operation 'arith.constant'}} + %1 = arith.constant -2147483649 : i64 // -2^31 - 1 + return +} + +// ----- + +func.func @unsupported_constant_2() { + // expected-error @+1 {{failed to legalize operation 'arith.constant'}} + %2 = arith.constant -2147483649 : i64 // -2^31 - 1 + return +} diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir index bd5238f..df6806a 100644 --- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir +++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir @@ -163,63 +163,6 @@ func.func @one_elem_vector(%arg0: vector<1xi32>) { return } -// CHECK-LABEL: @unsupported_5elem_vector -func.func @unsupported_5elem_vector(%arg0: vector<5xi32>) { - // CHECK: arith.subi - %1 = arith.subi %arg0, %arg0: vector<5xi32> - return -} - -// CHECK-LABEL: @unsupported_2x2elem_vector -func.func @unsupported_2x2elem_vector(%arg0: vector<2x2xi32>) { - // CHECK: arith.muli - %2 = arith.muli %arg0, %arg0: vector<2x2xi32> - return -} - -} // end module - -// ----- - -// Check that types are converted to 32-bit when no special capabilities. -module attributes { - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> -} { - -// CHECK-LABEL: @int_vector23 -func.func @int_vector23(%arg0: vector<2xi8>, %arg1: vector<3xi16>) { - // CHECK: spirv.SDiv %{{.*}}, %{{.*}}: vector<2xi32> - %0 = arith.divsi %arg0, %arg0: vector<2xi8> - // CHECK: spirv.SDiv %{{.*}}, %{{.*}}: vector<3xi32> - %1 = arith.divsi %arg1, %arg1: vector<3xi16> - return -} - -// CHECK-LABEL: @float_scalar -func.func @float_scalar(%arg0: f16, %arg1: f64) { - // CHECK: spirv.FAdd %{{.*}}, %{{.*}}: f32 - %0 = arith.addf %arg0, %arg0: f16 - // CHECK: spirv.FMul %{{.*}}, %{{.*}}: f32 - %1 = arith.mulf %arg1, %arg1: f64 - return -} - -} // end module - -// ----- - -// Check that types are converted to 32-bit when no special capabilities that -// are not supported. -module attributes { - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> -} { - -func.func @int_vector4_invalid(%arg0: vector<4xi64>) { - // expected-error @+1 {{bitwidth emulation is not implemented yet on unsigned op}} - %0 = arith.divui %arg0, %arg0: vector<4xi64> - return -} - } // end module // ----- @@ -643,17 +586,6 @@ func.func @corner_cases() { return } -// CHECK-LABEL: @unsupported_cases -func.func @unsupported_cases() { - // CHECK: %{{.*}} = arith.constant 4294967296 : i64 - %0 = arith.constant 4294967296 : i64 // 2^32 - // CHECK: %{{.*}} = arith.constant -2147483649 : i64 - %1 = arith.constant -2147483649 : i64 // -2^31 - 1 - // CHECK: %{{.*}} = arith.constant 1.0000000000000002 : f64 - %2 = arith.constant 0x3FF0000000000001 : f64 // smallest number > 1 - return -} - } // end module // ----- @@ -1258,20 +1190,6 @@ func.func @one_elem_vector(%arg0: vector<1xi32>) { return } -// CHECK-LABEL: @unsupported_5elem_vector -func.func @unsupported_5elem_vector(%arg0: vector<5xi32>) { - // CHECK: subi - %1 = arith.subi %arg0, %arg0: vector<5xi32> - return -} - -// CHECK-LABEL: @unsupported_2x2elem_vector -func.func @unsupported_2x2elem_vector(%arg0: vector<2x2xi32>) { - // CHECK: muli - %2 = arith.muli %arg0, %arg0: vector<2x2xi32> - return -} - } // end module // ----- @@ -1303,22 +1221,6 @@ func.func @float_scalar(%arg0: f16, %arg1: f64) { // ----- -// Check that types are converted to 32-bit when no special capabilities that -// are not supported. -module attributes { - spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> -} { - -func.func @int_vector4_invalid(%arg0: vector<4xi64>) { - // expected-error@+1 {{bitwidth emulation is not implemented yet on unsigned op}} - %0 = arith.divui %arg0, %arg0: vector<4xi64> - return -} - -} // end module - -// ----- - //===----------------------------------------------------------------------===// // std bit ops //===----------------------------------------------------------------------===// @@ -1675,17 +1577,6 @@ func.func @corner_cases() { return } -// CHECK-LABEL: @unsupported_cases -func.func @unsupported_cases() { - // CHECK: %{{.*}} = arith.constant 4294967296 : i64 - %0 = arith.constant 4294967296 : i64 // 2^32 - // CHECK: %{{.*}} = arith.constant -2147483649 : i64 - %1 = arith.constant -2147483649 : i64 // -2^31 - 1 - // CHECK: %{{.*}} = arith.constant 1.0000000000000002 : f64 - %2 = arith.constant 0x3FF0000000000001 : f64 // smallest number > 1 - return -} - } // end module // ----- -- 2.7.4