From a315534e52fd5c534fadc1e62101543aaf1537a2 Mon Sep 17 00:00:00 2001 From: "a.puschin" Date: Tue, 3 Jan 2023 11:37:59 -0800 Subject: [PATCH] [mlir][tosa] Fix out-of-boundaries iteration for tosa-to-linalg When the number of elements of two shapes are not equal, a Reshape operation cannot be used to transfer one into another Function findIntermediateShape(...) can cause out-of-boundaries operator[] call if the abovementioned condition strikes The test-case I used now causes no error as its root-cause was an issue in Tosa dialect with padded Conv2D operations lowering which is already solved in commit 69c984b6 Reviewed By: rsuderman Differential Revision: https://reviews.llvm.org/D140013 --- mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 1 + mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp | 8 ++++++-- mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 15 +++++++++++++++ 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index 9317828..8054c91 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -1431,6 +1431,7 @@ def Tosa_ReshapeOp: Tosa_Op<"reshape", [ let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; let arguments = (ins Tosa_Tensor:$input1, diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index 7cb72f4..c4d8b68 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -864,10 +864,14 @@ static bool findIntermediateShape(ArrayRef lhsShape, currRhsDim < rhsShape.size()) { if (lhsSize < rhsSize) { currLhsDim++; - lhsSize *= lhsShape[currLhsDim]; + if (currLhsDim < lhsShape.size()) { + lhsSize *= lhsShape[currLhsDim]; + } } else { currRhsDim++; - rhsSize *= rhsShape[currRhsDim]; + if (currRhsDim < rhsShape.size()) { + rhsSize *= rhsShape[currRhsDim]; + } } } if (lhsSize == rhsSize) { diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 1a3ba90..35c8152 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -700,6 +700,21 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents( return success(); } +mlir::LogicalResult tosa::ReshapeOp::verify() { + ShapedType inputType = getInput1().getType().cast(); + ShapedType outputType = getType().cast(); + + if (inputType.hasStaticShape() && outputType.hasStaticShape()) { + int64_t inputElementsNum = inputType.getNumElements(); + int64_t outputElementsNum = outputType.getNumElements(); + if (inputElementsNum != outputElementsNum) { + return emitOpError() << "Cannot reshape " << inputElementsNum + << " elements into " << outputElementsNum; + } + } + return mlir::success(); +} + LogicalResult tosa::TransposeOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, -- 2.7.4