[mlir] tosa.concat - Add InferTensorType interface
authorMaya Amrami <maya.amrami@mobileye.com>
Thu, 9 Mar 2023 11:51:27 +0000 (13:51 +0200)
committerMaya Amrami <mayaam88@gmail.com>
Tue, 21 Mar 2023 15:01:08 +0000 (17:01 +0200)
When this interface is used, a call to inferReturnTypeComponents()
is generated on creation and verification of the op.
A few changes were required in inferReturnTypeComponents():
- Emit error when it fails.
  The verifier calls this method now, and it is preferable to
  indicate what caused the failure.
- Fix the inferred return shapes so they have a type too.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D146132

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/test/Dialect/Tosa/invalid.mlir
mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

index be5720c..7c8018a 100644 (file)
@@ -1419,8 +1419,7 @@ def Tosa_ReduceSumOp : Tosa_Op<"reduce_sum", [
 // Operator: concat
 //===----------------------------------------------------------------------===//
 def Tosa_ConcatOp : Tosa_Op<"concat", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
+                              InferTensorType,
     Pure]> {
   let summary = "Concatenates tensors along one dimension.";
 
@@ -1439,6 +1438,12 @@ def Tosa_ConcatOp : Tosa_Op<"concat", [
   );
 
   let hasCanonicalizer = 1;
+
+  let extraClassDeclaration = [{
+    /// Returns true when two result types are compatible for this op;
+    /// Method used by InferTypeOpInterface.
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
 //===----------------------------------------------------------------------===//
index d7bb6d0..0a09cdd 100644 (file)
@@ -422,6 +422,12 @@ LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
   return success();
 }
 
+bool tosa::ConcatOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+  if (l.size() != r.size() || l.size() != 1)
+    return false;
+  return succeeded(verifyCompatibleShape(l[0], r[0]));
+}
+
 LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
@@ -447,14 +453,17 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
       if (outputShape[i] == ShapedType::kDynamic)
         outputShape[i] = operandShape.getDimSize(i);
       if (outputShape[i] != operandShape.getDimSize(i))
-        return failure();
+        return emitOptionalError(location,
+                                 "Cannot concat tensors with different sizes"
+                                 " on the non-axis dimension ",
+                                 i);
     }
 
     hasRankedInput = true;
   }
-
+  Type inputType = operands.getType()[0].cast<TensorType>().getElementType();
   if (!hasRankedInput) {
-    inferredReturnShapes.push_back(ShapedTypeComponents());
+    inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
     return success();
   }
 
@@ -475,7 +484,7 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
 
   outputShape[axis] = concatDimSize;
 
-  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
   return success();
 }
 
index c81b196..9f9c6ca 100644 (file)
@@ -36,4 +36,10 @@ func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<16x3x3x4xi8>,
   return %0 : tensor<1x27x27x16xi8>
 }
 
+// -----
 
+func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xf32> {
+  // expected-error@+1 {{Cannot concat tensors with different sizes on the non-axis dimension 1}}
+  %0 = "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<2x1xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
index 94eea3b..5053507 100644 (file)
@@ -491,16 +491,6 @@ func.func @test_concat_axis_1(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>)
 
 // -----
 
-// CHECK-LABEL: @test_concat_failure
-func.func @test_concat_failure(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> () {
-  // CHECK: "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<2x1xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
-  %0 = "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<2x1xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
-
-  return
-}
-
-// -----
-
 // CHECK-LABEL: @test_padding_no_const
 func.func @test_padding_no_const(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xi32>) -> () {
   // CHECK: "tosa.pad"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<2x2xi32>) -> tensor<?x?xf32>