From f35ac8a4ffbecd1fee09731e5a9a242e6425df80 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Tue, 28 Feb 2023 15:00:45 -0500 Subject: [PATCH] Fix SimplifyAllocConst pattern when we have alloc of negative sizes This is UB, but we shouldn't crash the compiler either. Fixes #61056 Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D144978 --- mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 13 ++++++++----- mlir/test/Dialect/MemRef/canonicalize.mlir | 12 ++++++++++++ 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 6814aa5..45fac0c 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -284,7 +284,10 @@ struct SimplifyAllocConst : public OpRewritePattern { // Check to see if any dimensions operands are constants. If so, we can // substitute and drop them. if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) { - return matchPattern(operand, matchConstantIndex()); + APInt constSizeArg; + if (!matchPattern(operand, m_ConstantInt(&constSizeArg))) + return false; + return constSizeArg.isNonNegative(); })) return failure(); @@ -305,11 +308,11 @@ struct SimplifyAllocConst : public OpRewritePattern { continue; } auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos]; - auto *defOp = dynamicSize.getDefiningOp(); - if (auto constantIndexOp = - dyn_cast_or_null(defOp)) { + APInt constSizeArg; + if (matchPattern(dynamicSize, m_ConstantInt(&constSizeArg)) && + constSizeArg.isNonNegative()) { // Dynamic shape dimension will be folded. - newShapeConstants.push_back(constantIndexOp.value()); + newShapeConstants.push_back(constSizeArg.getZExtValue()); } else { // Dynamic shape dimension not folded; copy dynamicSize from old memref. newShapeConstants.push_back(ShapedType::kDynamic); diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir index 4295947..b65426c 100644 --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -928,3 +928,15 @@ func.func @fold_multiple_memory_space_cast(%arg : memref) -> memref to memref return %1 : memref } + +// ----- + +// CHECK-lABEL: func @ub_negative_alloc_size +func.func private @ub_negative_alloc_size() -> memref { + %idx1 = index.constant 1 + %c-2 = arith.constant -2 : index + %c15 = arith.constant 15 : index +// CHECK: %[[ALLOC:.*]] = memref.alloc(%c-2) : memref<15x?x1xi1> + %alloc = memref.alloc(%c15, %c-2, %idx1) : memref + return %alloc : memref +} -- 2.7.4