From 2dd396c18bc035f8f87fb7ca2c33b8f00c287759 Mon Sep 17 00:00:00 2001 From: Aviad Cohen Date: Mon, 17 Apr 2023 09:44:06 +0300 Subject: [PATCH] [mlir] tosa.reshape - Add InferTensorType interface When this interface is used, a call to inferReturnTypeComponents() is generated on creation and verification of the op. Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D148498 --- mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 13 ++++++++----- mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 13 +++++++++++-- mlir/test/Dialect/Tosa/invalid.mlir | 8 ++++++++ 3 files changed, 27 insertions(+), 7 deletions(-) diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index 287e624..e36ab18 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -1441,8 +1441,7 @@ def Tosa_ReduceSumOp : Tosa_Op<"reduce_sum", [ // Operator: concat //===----------------------------------------------------------------------===// def Tosa_ConcatOp : Tosa_Op<"concat", [ - InferTensorType, - Pure]> { + InferTensorType, Pure]> { let summary = "Concatenates tensors along one dimension."; let description = [{ @@ -1503,9 +1502,7 @@ def Tosa_PadOp : Tosa_Op<"pad", [ // Operator: reshape //===----------------------------------------------------------------------===// def Tosa_ReshapeOp: Tosa_Op<"reshape", [ - DeclareOpInterfaceMethods, - Pure]> { + InferTensorType, Pure]> { let summary = "Reshape operator"; let description = [{ @@ -1526,6 +1523,12 @@ def Tosa_ReshapeOp: Tosa_Op<"reshape", [ let results = (outs Tosa_RankedTensor:$output ); + + let extraClassDeclaration = [{ + /// Returns true when two result types are compatible for this op; + /// Method used by InferTypeOpInterface. + static bool isCompatibleReturnTypes(TypeRange l, TypeRange r); + }]; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index b22bd65..2da687b 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -674,19 +674,27 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents( return success(); } +bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { + if (l.size() != r.size() || l.size() != 1) + return false; + return getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]); +} + LogicalResult tosa::ReshapeOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { ReshapeOpAdaptor adaptor(operands, attributes); ShapeAdaptor inputShape = operands.getShape(0); + Type inputType = getElementTypeOrSelf(operands.getType()[0]); llvm::SmallVector newShapeValue = convertToMlirShape(adaptor.getNewShape()); // We cannot infer from the total number of elements so we must take the // shape attribute as exact. if (!inputShape.hasRank() || !inputShape.hasStaticShape()) { - inferredReturnShapes.push_back(ShapedTypeComponents(newShapeValue)); + inferredReturnShapes.push_back( + ShapedTypeComponents(newShapeValue, inputType)); return success(); } @@ -707,7 +715,8 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents( val = numElements / staticMul; } - inferredReturnShapes.push_back(ShapedTypeComponents(newShapeValue)); + inferredReturnShapes.push_back( + ShapedTypeComponents(newShapeValue, inputType)); return success(); } diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index c05a1c4..27661f4 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -128,3 +128,11 @@ func.func @test_reduce_prod_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () { %0 = "tosa.reduce_prod"(%arg0) {axis = 1 : i64} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32> return } + +// ----- + +func.func @test_reshape_type_mismatch(%arg0 : tensor<13x21x3xf32>) -> () { + // expected-error@+1 {{'tosa.reshape' op inferred type(s) 'tensor<13x21x3x1xf32>' are incompatible with return type(s) of operation 'tensor<13x21x3x1xi32>'}} + %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor<13x21x3xf32>) -> tensor<13x21x3x1xi32> + return +} -- 2.7.4