[mlir][tosa] Add InferTensorType interface to tosa reduce operations
authorAviad Cohen <aviadcohen7@gmail.com>
Sun, 2 Apr 2023 09:12:15 +0000 (12:12 +0300)
committerAviad Cohen <aviadcohen7@gmail.com>
Wed, 5 Apr 2023 04:25:10 +0000 (07:25 +0300)
When this interface is used, a call to inferReturnTypeComponents()
is generated on creation and verification of the op.

Reviewed By: jpienaar, eric-k256

Differential Revision: https://reviews.llvm.org/D147407

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/test/Dialect/Tosa/invalid.mlir

index 043098f..287e624 100644 (file)
@@ -1262,9 +1262,7 @@ def Tosa_GreaterEqualOp : Tosa_Op<"greater_equal", [
 // Operator: reduce_all
 //===----------------------------------------------------------------------===//
 def Tosa_ReduceAllOp : Tosa_Op<"reduce_all", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+    InferTensorType, Pure]> {
   let summary = "Reduce All operator";
 
   let description = [{
@@ -1281,15 +1279,19 @@ def Tosa_ReduceAllOp : Tosa_Op<"reduce_all", [
   );
 
   let hasFolder = 1;
+
+  let extraClassDeclaration = [{
+    /// Returns true when two result types are compatible for this op;
+    /// Method used by InferTypeOpInterface.
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
 //===----------------------------------------------------------------------===//
 // Operator: reduce_any
 //===----------------------------------------------------------------------===//
 def Tosa_ReduceAnyOp : Tosa_Op<"reduce_any", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+    InferTensorType, Pure]> {
   let summary = "Reduce Any operator";
 
   let description = [{
@@ -1306,15 +1308,19 @@ def Tosa_ReduceAnyOp : Tosa_Op<"reduce_any", [
   );
 
   let hasFolder = 1;
+
+  let extraClassDeclaration = [{
+    /// Returns true when two result types are compatible for this op;
+    /// Method used by InferTypeOpInterface.
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
 //===----------------------------------------------------------------------===//
 // Operator: reduce_max
 //===----------------------------------------------------------------------===//
 def Tosa_ReduceMaxOp : Tosa_Op<"reduce_max", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+    InferTensorType, Pure]> {
   let summary = "Reduce Max operator";
 
   let description = [{
@@ -1331,15 +1337,19 @@ def Tosa_ReduceMaxOp : Tosa_Op<"reduce_max", [
   );
 
   let hasFolder = 1;
+
+  let extraClassDeclaration = [{
+    /// Returns true when two result types are compatible for this op;
+    /// Method used by InferTypeOpInterface.
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
 //===----------------------------------------------------------------------===//
 // Operator: reduce_min
 //===----------------------------------------------------------------------===//
 def Tosa_ReduceMinOp : Tosa_Op<"reduce_min", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+    InferTensorType, Pure]> {
   let summary = "Reduce Min operator";
 
   let description = [{
@@ -1356,15 +1366,19 @@ def Tosa_ReduceMinOp : Tosa_Op<"reduce_min", [
   );
 
   let hasFolder = 1;
+
+  let extraClassDeclaration = [{
+    /// Returns true when two result types are compatible for this op;
+    /// Method used by InferTypeOpInterface.
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
 //===----------------------------------------------------------------------===//
 // Operator: reduce_prod
 //===----------------------------------------------------------------------===//
 def Tosa_ReduceProdOp : Tosa_Op<"reduce_prod", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+    InferTensorType, Pure]> {
   let summary = "Reduce Prod operator";
 
   let description = [{
@@ -1381,15 +1395,19 @@ def Tosa_ReduceProdOp : Tosa_Op<"reduce_prod", [
   );
 
   let hasFolder = 1;
+
+  let extraClassDeclaration = [{
+    /// Returns true when two result types are compatible for this op;
+    /// Method used by InferTypeOpInterface.
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
 //===----------------------------------------------------------------------===//
 // Operator: reduce_sum
 //===----------------------------------------------------------------------===//
 def Tosa_ReduceSumOp : Tosa_Op<"reduce_sum", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    Pure]> {
+    InferTensorType, Pure]> {
   let summary = "Reduce Sum operator";
 
   let description = [{
@@ -1406,6 +1424,12 @@ def Tosa_ReduceSumOp : Tosa_Op<"reduce_sum", [
   );
 
   let hasFolder = 1;
+
+  let extraClassDeclaration = [{
+    /// Returns true when two result types are compatible for this op;
+    /// Method used by InferTypeOpInterface.
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
 //===----------------------------------------------------------------------===//
index 13a4351..b22bd65 100644 (file)
@@ -422,14 +422,6 @@ LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
   return success();
 }
 
-bool tosa::ConcatOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
-  if (l.size() != r.size() || l.size() != 1)
-    return false;
-  if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0]))
-    return false;
-  return succeeded(verifyCompatibleShape(l[0], r[0]));
-}
-
 LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
@@ -913,10 +905,10 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
 }
 
 static LogicalResult ReduceInferReturnTypes(
-    ShapeAdaptor operandShape, IntegerAttr axis,
+    ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   if (!operandShape.hasRank()) {
-    inferredReturnShapes.push_back(ShapedTypeComponents());
+    inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
     return success();
   }
 
@@ -924,20 +916,32 @@ static LogicalResult ReduceInferReturnTypes(
   operandShape.getDims(outputShape);
   int64_t axisVal = axis.getValue().getSExtValue();
   outputShape[axisVal] = 1;
-  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
   return success();
 }
 
+#define COMPATIBLE_RETURN_TYPES(OP)                                            \
+  bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) {                 \
+    if (l.size() != r.size() || l.size() != 1)                                 \
+      return false;                                                            \
+    if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0]))              \
+      return false;                                                            \
+    return succeeded(verifyCompatibleShape(l[0], r[0]));                       \
+  }
+
 #define REDUCE_SHAPE_INFER(OP)                                                 \
   LogicalResult OP::inferReturnTypeComponents(                                 \
       MLIRContext *context, ::std::optional<Location> location,                \
       ValueShapeRange operands, DictionaryAttr attributes,                     \
       RegionRange regions,                                                     \
       SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
-    return ReduceInferReturnTypes(operands.getShape(0),                        \
+    Type inputType =                                                           \
+        operands.getType()[0].cast<TensorType>().getElementType();             \
+    return ReduceInferReturnTypes(operands.getShape(0), inputType,             \
                                   attributes.get("axis").cast<IntegerAttr>(),  \
                                   inferredReturnShapes);                       \
-  }
+  }                                                                            \
+  COMPATIBLE_RETURN_TYPES(OP)
 
 REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
 REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
@@ -946,6 +950,8 @@ REDUCE_SHAPE_INFER(tosa::ReduceMinOp)
 REDUCE_SHAPE_INFER(tosa::ReduceProdOp)
 REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
 #undef REDUCE_SHAPE_INFER
+COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
+#undef COMPATIBLE_RETURN_TYPES
 
 static LogicalResult NAryInferReturnTypes(
     const ValueShapeRange &operands,
index 5a120ee..c05a1c4 100644 (file)
@@ -96,3 +96,35 @@ func.func @test_fully_connected_non_const(%arg0: tensor<13x21x3xf32>, %arg1: ten
   %2 = "tosa.fully_connected"(%1, %0, %arg1) : (tensor<273x3xf32>, tensor<2x3xf32>, tensor<2xf32>) -> tensor<273x2xf32>
   return %2 : tensor<273x2xf32>
 }
+
+// -----
+
+func.func @test_reduce_sum_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
+  // expected-error@+1 {{'tosa.reduce_sum' op inferred type(s) 'tensor<1x3x4x5xf32>' are incompatible with return type(s) of operation 'tensor<1x3x4x5xi32>'}}
+  %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<2x3x4x5xf32>) -> tensor<1x3x4x5xi32>
+  return
+}
+
+// -----
+
+func.func @test_reduce_max_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
+  // expected-error@+1 {{'tosa.reduce_max' op inferred type(s) 'tensor<2x3x4x1xf32>' are incompatible with return type(s) of operation 'tensor<2x3x4x1xi32>'}}
+  %0 = "tosa.reduce_max"(%arg0) {axis = 3 : i64} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x1xi32>
+  return
+}
+
+// -----
+
+func.func @test_reduce_min_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
+  // expected-error@+1 {{'tosa.reduce_min' op inferred type(s) 'tensor<2x1x4x5xf32>' are incompatible with return type(s) of operation 'tensor<2x1x4x5xi32>'}}
+  %0 = "tosa.reduce_min"(%arg0) {axis = 1 : i64} : (tensor<2x3x4x5xf32>) -> tensor<2x1x4x5xi32>
+  return
+}
+
+// -----
+
+func.func @test_reduce_prod_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
+  // expected-error@+1 {{'tosa.reduce_prod' op inferred type(s) 'tensor<2x1x4x5xf32>' are incompatible with return type(s) of operation 'tensor<2x3x4x5xf32>'}}
+  %0 = "tosa.reduce_prod"(%arg0) {axis = 1 : i64} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32>
+  return
+}