From 3500e11065d616f4653ea8ba8c979b29c69a00d7 Mon Sep 17 00:00:00 2001 From: Aviad Cohen Date: Sun, 2 Apr 2023 12:12:15 +0300 Subject: [PATCH] [mlir][tosa] Add InferTensorType interface to tosa reduce operations When this interface is used, a call to inferReturnTypeComponents() is generated on creation and verification of the op. Reviewed By: jpienaar, eric-k256 Differential Revision: https://reviews.llvm.org/D147407 --- mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 60 +++++++++++++++++++--------- mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 32 +++++++++------ mlir/test/Dialect/Tosa/invalid.mlir | 32 +++++++++++++++ 3 files changed, 93 insertions(+), 31 deletions(-) diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index 043098f..287e624 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -1262,9 +1262,7 @@ def Tosa_GreaterEqualOp : Tosa_Op<"greater_equal", [ // Operator: reduce_all //===----------------------------------------------------------------------===// def Tosa_ReduceAllOp : Tosa_Op<"reduce_all", [ - DeclareOpInterfaceMethods, - Pure]> { + InferTensorType, Pure]> { let summary = "Reduce All operator"; let description = [{ @@ -1281,15 +1279,19 @@ def Tosa_ReduceAllOp : Tosa_Op<"reduce_all", [ ); let hasFolder = 1; + + let extraClassDeclaration = [{ + /// Returns true when two result types are compatible for this op; + /// Method used by InferTypeOpInterface. + static bool isCompatibleReturnTypes(TypeRange l, TypeRange r); + }]; } //===----------------------------------------------------------------------===// // Operator: reduce_any //===----------------------------------------------------------------------===// def Tosa_ReduceAnyOp : Tosa_Op<"reduce_any", [ - DeclareOpInterfaceMethods, - Pure]> { + InferTensorType, Pure]> { let summary = "Reduce Any operator"; let description = [{ @@ -1306,15 +1308,19 @@ def Tosa_ReduceAnyOp : Tosa_Op<"reduce_any", [ ); let hasFolder = 1; + + let extraClassDeclaration = [{ + /// Returns true when two result types are compatible for this op; + /// Method used by InferTypeOpInterface. + static bool isCompatibleReturnTypes(TypeRange l, TypeRange r); + }]; } //===----------------------------------------------------------------------===// // Operator: reduce_max //===----------------------------------------------------------------------===// def Tosa_ReduceMaxOp : Tosa_Op<"reduce_max", [ - DeclareOpInterfaceMethods, - Pure]> { + InferTensorType, Pure]> { let summary = "Reduce Max operator"; let description = [{ @@ -1331,15 +1337,19 @@ def Tosa_ReduceMaxOp : Tosa_Op<"reduce_max", [ ); let hasFolder = 1; + + let extraClassDeclaration = [{ + /// Returns true when two result types are compatible for this op; + /// Method used by InferTypeOpInterface. + static bool isCompatibleReturnTypes(TypeRange l, TypeRange r); + }]; } //===----------------------------------------------------------------------===// // Operator: reduce_min //===----------------------------------------------------------------------===// def Tosa_ReduceMinOp : Tosa_Op<"reduce_min", [ - DeclareOpInterfaceMethods, - Pure]> { + InferTensorType, Pure]> { let summary = "Reduce Min operator"; let description = [{ @@ -1356,15 +1366,19 @@ def Tosa_ReduceMinOp : Tosa_Op<"reduce_min", [ ); let hasFolder = 1; + + let extraClassDeclaration = [{ + /// Returns true when two result types are compatible for this op; + /// Method used by InferTypeOpInterface. + static bool isCompatibleReturnTypes(TypeRange l, TypeRange r); + }]; } //===----------------------------------------------------------------------===// // Operator: reduce_prod //===----------------------------------------------------------------------===// def Tosa_ReduceProdOp : Tosa_Op<"reduce_prod", [ - DeclareOpInterfaceMethods, - Pure]> { + InferTensorType, Pure]> { let summary = "Reduce Prod operator"; let description = [{ @@ -1381,15 +1395,19 @@ def Tosa_ReduceProdOp : Tosa_Op<"reduce_prod", [ ); let hasFolder = 1; + + let extraClassDeclaration = [{ + /// Returns true when two result types are compatible for this op; + /// Method used by InferTypeOpInterface. + static bool isCompatibleReturnTypes(TypeRange l, TypeRange r); + }]; } //===----------------------------------------------------------------------===// // Operator: reduce_sum //===----------------------------------------------------------------------===// def Tosa_ReduceSumOp : Tosa_Op<"reduce_sum", [ - DeclareOpInterfaceMethods, - Pure]> { + InferTensorType, Pure]> { let summary = "Reduce Sum operator"; let description = [{ @@ -1406,6 +1424,12 @@ def Tosa_ReduceSumOp : Tosa_Op<"reduce_sum", [ ); let hasFolder = 1; + + let extraClassDeclaration = [{ + /// Returns true when two result types are compatible for this op; + /// Method used by InferTypeOpInterface. + static bool isCompatibleReturnTypes(TypeRange l, TypeRange r); + }]; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 13a4351..b22bd65 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -422,14 +422,6 @@ LogicalResult tosa::FFT2dOp::inferReturnTypeComponents( return success(); } -bool tosa::ConcatOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { - if (l.size() != r.size() || l.size() != 1) - return false; - if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0])) - return false; - return succeeded(verifyCompatibleShape(l[0], r[0])); -} - LogicalResult tosa::ConcatOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, @@ -913,10 +905,10 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents( } static LogicalResult ReduceInferReturnTypes( - ShapeAdaptor operandShape, IntegerAttr axis, + ShapeAdaptor operandShape, Type inputType, IntegerAttr axis, SmallVectorImpl &inferredReturnShapes) { if (!operandShape.hasRank()) { - inferredReturnShapes.push_back(ShapedTypeComponents()); + inferredReturnShapes.push_back(ShapedTypeComponents(inputType)); return success(); } @@ -924,20 +916,32 @@ static LogicalResult ReduceInferReturnTypes( operandShape.getDims(outputShape); int64_t axisVal = axis.getValue().getSExtValue(); outputShape[axisVal] = 1; - inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); + inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType)); return success(); } +#define COMPATIBLE_RETURN_TYPES(OP) \ + bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) { \ + if (l.size() != r.size() || l.size() != 1) \ + return false; \ + if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0])) \ + return false; \ + return succeeded(verifyCompatibleShape(l[0], r[0])); \ + } + #define REDUCE_SHAPE_INFER(OP) \ LogicalResult OP::inferReturnTypeComponents( \ MLIRContext *context, ::std::optional location, \ ValueShapeRange operands, DictionaryAttr attributes, \ RegionRange regions, \ SmallVectorImpl &inferredReturnShapes) { \ - return ReduceInferReturnTypes(operands.getShape(0), \ + Type inputType = \ + operands.getType()[0].cast().getElementType(); \ + return ReduceInferReturnTypes(operands.getShape(0), inputType, \ attributes.get("axis").cast(), \ inferredReturnShapes); \ - } + } \ + COMPATIBLE_RETURN_TYPES(OP) REDUCE_SHAPE_INFER(tosa::ReduceAllOp) REDUCE_SHAPE_INFER(tosa::ReduceAnyOp) @@ -946,6 +950,8 @@ REDUCE_SHAPE_INFER(tosa::ReduceMinOp) REDUCE_SHAPE_INFER(tosa::ReduceProdOp) REDUCE_SHAPE_INFER(tosa::ReduceSumOp) #undef REDUCE_SHAPE_INFER +COMPATIBLE_RETURN_TYPES(tosa::ConcatOp) +#undef COMPATIBLE_RETURN_TYPES static LogicalResult NAryInferReturnTypes( const ValueShapeRange &operands, diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index 5a120ee..c05a1c4 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -96,3 +96,35 @@ func.func @test_fully_connected_non_const(%arg0: tensor<13x21x3xf32>, %arg1: ten %2 = "tosa.fully_connected"(%1, %0, %arg1) : (tensor<273x3xf32>, tensor<2x3xf32>, tensor<2xf32>) -> tensor<273x2xf32> return %2 : tensor<273x2xf32> } + +// ----- + +func.func @test_reduce_sum_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () { + // expected-error@+1 {{'tosa.reduce_sum' op inferred type(s) 'tensor<1x3x4x5xf32>' are incompatible with return type(s) of operation 'tensor<1x3x4x5xi32>'}} + %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<2x3x4x5xf32>) -> tensor<1x3x4x5xi32> + return +} + +// ----- + +func.func @test_reduce_max_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () { + // expected-error@+1 {{'tosa.reduce_max' op inferred type(s) 'tensor<2x3x4x1xf32>' are incompatible with return type(s) of operation 'tensor<2x3x4x1xi32>'}} + %0 = "tosa.reduce_max"(%arg0) {axis = 3 : i64} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x1xi32> + return +} + +// ----- + +func.func @test_reduce_min_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () { + // expected-error@+1 {{'tosa.reduce_min' op inferred type(s) 'tensor<2x1x4x5xf32>' are incompatible with return type(s) of operation 'tensor<2x1x4x5xi32>'}} + %0 = "tosa.reduce_min"(%arg0) {axis = 1 : i64} : (tensor<2x3x4x5xf32>) -> tensor<2x1x4x5xi32> + return +} + +// ----- + +func.func @test_reduce_prod_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () { + // expected-error@+1 {{'tosa.reduce_prod' op inferred type(s) 'tensor<2x1x4x5xf32>' are incompatible with return type(s) of operation 'tensor<2x3x4x5xf32>'}} + %0 = "tosa.reduce_prod"(%arg0) {axis = 1 : i64} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32> + return +} -- 2.7.4