[mlir][arith] Improve EmulateWideInt diagnostics
authorJakub Kuderski <kubak@google.com>
Tue, 11 Oct 2022 18:24:38 +0000 (14:24 -0400)
committerJakub Kuderski <kubak@google.com>
Tue, 11 Oct 2022 18:24:54 +0000 (14:24 -0400)
Print unsupported types on match failures.

Suggested by @Mogball and @jpienaar in D135204.

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D135673

mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp

index c53abbc..9784f0d 100644 (file)
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/MathExtras.h"
 #include <cassert>
 
@@ -264,7 +265,8 @@ struct ConvertAddI final : OpConversionPattern<arith::AddIOp> {
                      ->convertType(op.getType())
                      .dyn_cast_or_null<VectorType>();
     if (!newTy)
-      return rewriter.notifyMatchFailure(loc, "expected scalar or vector type");
+      return rewriter.notifyMatchFailure(
+          loc, llvm::formatv("unsupported type: {0}", op.getType()));
 
     Type newElemTy = reduceInnermostDim(newTy);
 
@@ -305,7 +307,8 @@ struct ConvertBitwiseBinary final : OpConversionPattern<BinaryOp> {
                      ->convertType(op.getType())
                      .template dyn_cast_or_null<VectorType>();
     if (!newTy)
-      return rewriter.notifyMatchFailure(loc, "unsupported type");
+      return rewriter.notifyMatchFailure(
+          loc, llvm::formatv("unsupported type: {0}", op.getType()));
 
     auto [lhsElem0, lhsElem1] =
         extractLastDimHalves(rewriter, loc, adaptor.getLhs());
@@ -336,7 +339,8 @@ struct ConvertMulI final : OpConversionPattern<arith::MulIOp> {
                      ->convertType(op.getType())
                      .dyn_cast_or_null<VectorType>();
     if (!newTy)
-      return rewriter.notifyMatchFailure(loc, "expected scalar or vector type");
+      return rewriter.notifyMatchFailure(
+          loc, llvm::formatv("unsupported type: {0}", op.getType()));
 
     Type newElemTy = reduceInnermostDim(newTy);
     unsigned newBitWidth = newTy.getElementTypeBitWidth();
@@ -430,7 +434,8 @@ struct ConvertExtSI final : OpConversionPattern<arith::ExtSIOp> {
                      ->convertType(op.getType())
                      .dyn_cast_or_null<VectorType>();
     if (!newTy)
-      return rewriter.notifyMatchFailure(loc, "unsupported type");
+      return rewriter.notifyMatchFailure(
+          loc, llvm::formatv("unsupported type: {0}", op.getType()));
 
     Type newResultComponentTy = reduceInnermostDim(newTy);
 
@@ -469,7 +474,8 @@ struct ConvertExtUI final : OpConversionPattern<arith::ExtUIOp> {
                      ->convertType(op.getType())
                      .dyn_cast_or_null<VectorType>();
     if (!newTy)
-      return rewriter.notifyMatchFailure(loc, "unsupported type");
+      return rewriter.notifyMatchFailure(
+          loc, llvm::formatv("unsupported type: {0}", op.getType()));
 
     Type newResultComponentTy = reduceInnermostDim(newTy);
 
@@ -501,7 +507,8 @@ struct ConvertShLI final : OpConversionPattern<arith::ShLIOp> {
     auto newTy =
         getTypeConverter()->convertType(oldTy).dyn_cast_or_null<VectorType>();
     if (!newTy)
-      return rewriter.notifyMatchFailure(loc, "unsupported type");
+      return rewriter.notifyMatchFailure(
+          loc, llvm::formatv("unsupported type: {0}", op.getType()));
 
     Type newOperandTy = reduceInnermostDim(newTy);
     // `oldBitWidth` == `2 * newBitWidth`
@@ -590,7 +597,8 @@ struct ConvertShRUI final : OpConversionPattern<arith::ShRUIOp> {
     auto newTy =
         getTypeConverter()->convertType(oldTy).dyn_cast_or_null<VectorType>();
     if (!newTy)
-      return rewriter.notifyMatchFailure(loc, "unsupported type");
+      return rewriter.notifyMatchFailure(
+          loc, llvm::formatv("unsupported type: {0}", op.getType()));
 
     Type newOperandTy = reduceInnermostDim(newTy);
     // `oldBitWidth` == `2 * newBitWidth`
@@ -677,8 +685,9 @@ struct ConvertTruncI final : OpConversionPattern<arith::TruncIOp> {
     // Check if the result type is legal for this target. Currently, we do not
     // support truncation to types wider than supported by the target.
     if (!getTypeConverter()->isLegal(op.getType()))
-      return rewriter.notifyMatchFailure(loc,
-                                         "unsupported truncation result type");
+      return rewriter.notifyMatchFailure(
+          loc, llvm::formatv("unsupported truncation result type: {0}",
+                             op.getType()));
 
     // Discard the high half of the input. Truncate the low half, if necessary.
     Value extracted = extractLastDimSlice(rewriter, loc, adaptor.getIn(), 0);