[mlir][tosa] Fix out-of-boundaries iteration for tosa-to-linalg
authora.puschin <a.puschin@yadro.com>
Tue, 3 Jan 2023 19:37:59 +0000 (11:37 -0800)
committerRob Suderman <suderman@google.com>
Tue, 3 Jan 2023 19:52:09 +0000 (11:52 -0800)
When the number of elements of two shapes are not equal, a Reshape operation cannot be used to transfer one into another

Function findIntermediateShape(...) can cause out-of-boundaries operator[] call if the abovementioned condition strikes

The test-case I used now causes no error as its root-cause was an issue in Tosa dialect with padded Conv2D operations lowering which is already solved in commit 69c984b6

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D140013

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

index 9317828..8054c91 100644 (file)
@@ -1431,6 +1431,7 @@ def Tosa_ReshapeOp: Tosa_Op<"reshape", [
 
   let hasCanonicalizer = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 
   let arguments = (ins
     Tosa_Tensor:$input1,
index 7cb72f4..c4d8b68 100644 (file)
@@ -864,10 +864,14 @@ static bool findIntermediateShape(ArrayRef<int64_t> lhsShape,
            currRhsDim < rhsShape.size()) {
       if (lhsSize < rhsSize) {
         currLhsDim++;
-        lhsSize *= lhsShape[currLhsDim];
+        if (currLhsDim < lhsShape.size()) {
+          lhsSize *= lhsShape[currLhsDim];
+        }
       } else {
         currRhsDim++;
-        rhsSize *= rhsShape[currRhsDim];
+        if (currRhsDim < rhsShape.size()) {
+          rhsSize *= rhsShape[currRhsDim];
+        }
       }
     }
     if (lhsSize == rhsSize) {
index 1a3ba90..35c8152 100644 (file)
@@ -700,6 +700,21 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
   return success();
 }
 
+mlir::LogicalResult tosa::ReshapeOp::verify() {
+  ShapedType inputType = getInput1().getType().cast<ShapedType>();
+  ShapedType outputType = getType().cast<ShapedType>();
+
+  if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
+    int64_t inputElementsNum = inputType.getNumElements();
+    int64_t outputElementsNum = outputType.getNumElements();
+    if (inputElementsNum != outputElementsNum) {
+      return emitOpError() << "Cannot reshape " << inputElementsNum
+                           << " elements into " << outputElementsNum;
+    }
+  }
+  return mlir::success();
+}
+
 LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,