From: Jakub Kuderski Date: Thu, 15 Dec 2022 00:32:40 +0000 (-0500) Subject: [mlir][arith][spirv] Account for possible type conversion failures X-Git-Tag: upstream/17.0.6~23765 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=4f47677dee24b78548b49f3abca2c1ea65a79c8a;p=platform%2Fupstream%2Fllvm.git [mlir][arith][spirv] Account for possible type conversion failures Check results of all type conversions in `--convert-arith-to-spirv`. Fixes: https://github.com/llvm/llvm-project/issues/59496 Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D140033 --- diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp index ed5d044..9533494 100644 --- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp @@ -320,10 +320,13 @@ static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, /// Returns true if the given `type` is a boolean scalar or vector type. static bool isBoolScalarOrVector(Type type) { + assert(type && "Not a valid type"); if (type.isInteger(1)) return true; + if (auto vecType = type.dyn_cast()) return vecType.getElementType().isInteger(1); + return false; } @@ -343,6 +346,22 @@ static bool hasSameBitwidth(Type a, Type b) { return aBW != 0 && bBW != 0 && aBW == bBW; } +/// Returns a source type conversion failure for `srcType` and operation `op`. +static LogicalResult +getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op, + Type srcType) { + return rewriter.notifyMatchFailure( + op->getLoc(), + llvm::formatv("failed to convert source type '{0}'", srcType)); +} + +/// Returns a source type conversion failure for the result type of `op`. +static LogicalResult +getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op) { + assert(op->getNumResults() == 1); + return getTypeConversionFailure(rewriter, op, op->getResultTypes().front()); +} + //===----------------------------------------------------------------------===// // ConstantOp with composite type //===----------------------------------------------------------------------===// @@ -562,10 +581,10 @@ BitwiseOpPattern::matchAndRewrite( Op op, typename Op::Adaptor adaptor, ConversionPatternRewriter &rewriter) const { assert(adaptor.getOperands().size() == 2); - auto dstType = - this->getTypeConverter()->convertType(op.getResult().getType()); + Type dstType = this->getTypeConverter()->convertType(op.getType()); if (!dstType) - return failure(); + return getTypeConversionFailure(rewriter, op); + if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) { rewriter.template replaceOpWithNewOp(op, dstType, adaptor.getOperands()); @@ -590,7 +609,8 @@ LogicalResult XOrIOpLogicalPattern::matchAndRewrite( Type dstType = getTypeConverter()->convertType(op.getType()); if (!dstType) - return failure(); + return getTypeConversionFailure(rewriter, op); + rewriter.replaceOpWithNewOp(op, dstType, adaptor.getOperands()); @@ -611,7 +631,8 @@ LogicalResult XOrIOpBooleanPattern::matchAndRewrite( Type dstType = getTypeConverter()->convertType(op.getType()); if (!dstType) - return failure(); + return getTypeConversionFailure(rewriter, op); + rewriter.replaceOpWithNewOp(op, dstType, adaptor.getOperands()); return success(); @@ -628,7 +649,10 @@ UIToFPI1Pattern::matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor, if (!isBoolScalarOrVector(srcType)) return failure(); - Type dstType = getTypeConverter()->convertType(op.getResult().getType()); + Type dstType = getTypeConverter()->convertType(op.getType()); + if (!dstType) + return getTypeConversionFailure(rewriter, op); + Location loc = op.getLoc(); Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); @@ -649,7 +673,9 @@ ExtSII1Pattern::matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor, return failure(); Location loc = op.getLoc(); - Type dstType = getTypeConverter()->convertType(op.getResult().getType()); + Type dstType = getTypeConverter()->convertType(op.getType()); + if (!dstType) + return getTypeConversionFailure(rewriter, op); Value allOnes; if (auto intTy = dstType.dyn_cast()) { @@ -684,7 +710,10 @@ ExtUII1Pattern::matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, if (!isBoolScalarOrVector(srcType)) return failure(); - Type dstType = getTypeConverter()->convertType(op.getResult().getType()); + Type dstType = getTypeConverter()->convertType(op.getType()); + if (!dstType) + return getTypeConversionFailure(rewriter, op); + Location loc = op.getLoc(); Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); @@ -700,7 +729,10 @@ ExtUII1Pattern::matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, LogicalResult TruncII1Pattern::matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - Type dstType = getTypeConverter()->convertType(op.getResult().getType()); + Type dstType = getTypeConverter()->convertType(op.getType()); + if (!dstType) + return getTypeConversionFailure(rewriter, op); + if (!isBoolScalarOrVector(dstType)) return failure(); @@ -728,10 +760,13 @@ LogicalResult TypeCastingOpPattern::matchAndRewrite( ConversionPatternRewriter &rewriter) const { assert(adaptor.getOperands().size() == 1); Type srcType = adaptor.getOperands().front().getType(); - Type dstType = - this->getTypeConverter()->convertType(op.getResult().getType()); + Type dstType = this->getTypeConverter()->convertType(op.getType()); + if (!dstType) + return getTypeConversionFailure(rewriter, op); + if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType)) return failure(); + if (dstType == srcType) { // Due to type conversion, we are seeing the same source and target type. // Then we can just erase this operation by forwarding its operand. @@ -755,7 +790,7 @@ LogicalResult CmpIOpBooleanPattern::matchAndRewrite( return failure(); Type dstType = getTypeConverter()->convertType(srcType); if (!dstType) - return failure(); + return getTypeConversionFailure(rewriter, op, srcType); switch (op.getPredicate()) { case arith::CmpIPredicate::eq: { @@ -804,7 +839,7 @@ CmpIOpPattern::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, return failure(); Type dstType = getTypeConverter()->convertType(srcType); if (!dstType) - return failure(); + return getTypeConversionFailure(rewriter, op, srcType); switch (op.getPredicate()) { #define DISPATCH(cmpPredicate, spirvOp) \ @@ -999,7 +1034,7 @@ LogicalResult MinMaxFOpPattern::matchAndRewrite( auto *converter = this->template getTypeConverter(); Type dstType = converter->convertType(op.getType()); if (!dstType) - return failure(); + return getTypeConversionFailure(rewriter, op); // arith.maxf/minf: // "if one of the arguments is NaN, then the result is also NaN." diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir index f6e84e8..0d92a8e 100644 --- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir +++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir @@ -130,3 +130,19 @@ func.func @unsupported_f64(%arg0: f64) { } } // end module + +// ----- + +module attributes { + spirv.target_env = #spirv.target_env< + #spirv.vce, #spirv.resource_limits<>> +} { + +// i64 is not a valid result type in this target env. +func.func @type_conversion_failure(%arg0: i32) { + // expected-error@+1 {{failed to legalize operation 'arith.extsi'}} + %2 = arith.extsi %arg0 : i32 to i64 + return +} + +} // end module