From 9d0b596aada6fb2166dd4f6f58e359fbac483154 Mon Sep 17 00:00:00 2001 From: Kai Sasaki Date: Mon, 13 Feb 2023 12:54:42 -0800 Subject: [PATCH] [mlir][tosa] Fix segmentation fault in case of folding unranked tensor Trying to fold the unranked tensor for "tosa.equal" crashes due to null reference. We need to check the dynamic cast result beforehand. This is reported in https://github.com/llvm/llvm-project/issues/60192. Reviewed By: rsuderman Differential Revision: https://reviews.llvm.org/D143034 --- mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 4 ++-- mlir/test/Dialect/Tosa/constant_folding.mlir | 8 ++++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 74325c8..1a8a578 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -803,8 +803,8 @@ OpFoldResult EqualOp::fold(FoldAdaptor adaptor) { // If we are comparing an integer value to itself it is always true. We can // not do this with float due to float values. - if (lhsTy.getElementType().isa() && resultTy.hasStaticShape() && - lhs == rhs) { + if (lhsTy.getElementType().isa() && resultTy && + resultTy.hasStaticShape() && lhs == rhs) { return DenseElementsAttr::get(resultTy, true); } diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir index 3111b12..259b2ea 100644 --- a/mlir/test/Dialect/Tosa/constant_folding.mlir +++ b/mlir/test/Dialect/Tosa/constant_folding.mlir @@ -6,3 +6,11 @@ func.func @test_const(%arg0 : index) -> tensor<4xi32> { %0 = "tosa.const"() {value = dense<[3, 0, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32> return %0 : tensor<4xi32> } + +// CHECK-LABEL: func @try_fold_equal_with_unranked_tensor +func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tensor) { + // CHECK: "tosa.equal" + // CHECK-NEXT: return + %0 = "tosa.equal"(%arg0, %arg1) : (tensor<4xi32>, tensor) -> tensor<*xi1> + return +} -- 2.7.4