[tosa] Improve inferred shapes of TOSA operations
authorSpenser Bauman <sbauman@mathworks.com>
Thu, 25 May 2023 23:06:50 +0000 (16:06 -0700)
committerEric Kunze <eric.kunze@arm.com>
Thu, 25 May 2023 23:27:13 +0000 (16:27 -0700)
The TosaInferShapes pass avoids updating the shapes of tensor operators
when the consumers are not TOSA operations, limiting the efficacy of
TosaInferShapes when the IR is a mix of TOSA and other operations.
This change attempts to update the result shapes when the consumers
themselves have reasonable type/shape inference methods.

Reviewed By: eric-k256

Differential Revision: https://reviews.llvm.org/D151228

mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

index 65b66d2..9c49cd5 100644 (file)
@@ -21,6 +21,7 @@
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -201,6 +202,16 @@ void propagateShapesInRegion(Region &region) {
     return it->second;
   };
 
+  // Check whether this use case is replaceable. We define an op as
+  // being replaceable if it is used by a ReturnOp, a TosaOp, or an op with a
+  // type-inference related interface.
+  auto isReplaceableUser = [](Operation *user) -> bool {
+    return isa<func::ReturnOp>(user) ||
+           user->getDialect()->getNamespace() ==
+               TosaDialect::getDialectNamespace() ||
+           isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
+  };
+
   for (auto &block : region) {
     for (Operation &op : block) {
       if (op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
@@ -227,18 +238,8 @@ void propagateShapesInRegion(Region &region) {
           Value result = std::get<0>(it);
           ShapedTypeComponents predictedShape = std::get<1>(it);
 
-          // Check whether this use case is replaceable. We define an op as
-          // being replaceable if it is used by a ReturnOp or a TosaOp.
-          bool replaceable = true;
-          for (auto *user : result.getUsers()) {
-            if (isa<func::ReturnOp>(user))
-              continue;
-            if (user->getDialect()->getNamespace() ==
-                TosaDialect::getDialectNamespace())
-              continue;
-
-            replaceable = false;
-          }
+          if (!llvm::all_of(result.getUsers(), isReplaceableUser))
+            continue;
 
           // Determine the knowledge based on the output type.
           // TODO: should also query WIP type probably
@@ -256,9 +257,6 @@ void propagateShapesInRegion(Region &region) {
             }
           }
 
-          if (!replaceable)
-            continue;
-
           // Compute the new type based on the joined version.
           auto newKnowledge =
               ValueKnowledge::join(currentKnowledge, inferredKnowledge);
index 5bbb6e1..bf91336 100644 (file)
@@ -1237,3 +1237,33 @@ func.func @test_unranked_equal(%arg0 : tensor<*xf32>, %arg1 : tensor<f32>) -> ()
 
   return
 }
+
+// -----
+
+// CHECK-LABEL: test_non_tosa_consumer_shape
+func.func @test_non_tosa_consumer_shape(%arg0: tensor<4x4xf32>) -> !shape.shape {
+  // CHECK: "tosa.log"(%arg0) : (tensor<4x4xf32>) -> tensor<4x4xf32>
+  %0 = "tosa.log"(%arg0) : (tensor<4x4xf32>) -> tensor<*xf32>
+  %1 = shape.shape_of %0 : tensor<*xf32> -> !shape.shape
+  return %1 : !shape.shape
+}
+
+// -----
+
+// CHECK-LABEL: test_non_tosa_consumer_shape2
+func.func @test_non_tosa_consumer_shape2(%arg0: tensor<4x4xf32>) -> tensor<?xindex> {
+  // CHECK: "tosa.log"(%arg0) : (tensor<4x4xf32>) -> tensor<4x4xf32>
+  %0 = "tosa.log"(%arg0) : (tensor<4x4xf32>) -> tensor<*xf32>
+  %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
+  return %1 : tensor<?xindex>
+}
+
+// -----
+
+// CHECK-LABEL: test_non_tosa_consumer_extract
+func.func @test_non_tosa_consumer_extract(%arg0: tensor<4x4xf32>, %arg1: index) -> f32 {
+  // CHECK: "tosa.log"(%arg0) : (tensor<4x4xf32>) -> tensor<4x4xf32>
+  %0 = "tosa.log"(%arg0) : (tensor<4x4xf32>) -> tensor<?x?xf32>
+  %1 = tensor.extract %0[%arg1, %arg1] : tensor<?x?xf32>
+  return %1 : f32
+}