Fix crash in ConvertVectorToLLVM.cpp pattern
authorMehdi Amini <joker.eph@gmail.com>
Fri, 3 Mar 2023 10:14:17 +0000 (11:14 +0100)
committerMehdi Amini <joker.eph@gmail.com>
Fri, 3 Mar 2023 10:16:43 +0000 (11:16 +0100)
Fixes #61094

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/test/Conversion/GPUCommon/lower-vector.mlir [new file with mode: 0644]

index b73c01a..d1b78bf 100644 (file)
@@ -21,6 +21,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Target/LLVMIR/TypeToLLVM.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/Casting.h"
 #include <optional>
 
 using namespace mlir;
@@ -820,11 +821,10 @@ public:
   matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override final {
     // Match against the maskable operation kind.
-    Operation *maskableOp = maskOp.getMaskableOp();
-    if (!isa<MaskedOp>(maskableOp))
+    auto maskedOp = llvm::dyn_cast_or_null<MaskedOp>(maskOp.getMaskableOp());
+    if (!maskedOp)
       return failure();
-    return matchAndRewriteMaskableOp(
-        maskOp, cast<MaskedOp>(maskOp.getMaskableOp()), rewriter);
+    return matchAndRewriteMaskableOp(maskOp, maskedOp, rewriter);
   }
 
 protected:
diff --git a/mlir/test/Conversion/GPUCommon/lower-vector.mlir b/mlir/test/Conversion/GPUCommon/lower-vector.mlir
new file mode 100644 (file)
index 0000000..44deb45
--- /dev/null
@@ -0,0 +1,11 @@
+// RUN: mlir-opt %s --gpu-to-llvm | FileCheck %s
+
+module {
+  func.func @func(%arg: vector<11xf32>) {
+    %cst_41 = arith.constant dense<true> : vector<11xi1>
+    // CHECK: vector.mask
+    // CHECK-SAME: vector.yield %arg0
+    %127 = vector.mask %cst_41 { vector.yield %arg : vector<11xf32> } : vector<11xi1> -> vector<11xf32>
+    return
+  }
+}