#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
return it->second;
};
+ // Check whether this use case is replaceable. We define an op as
+ // being replaceable if it is used by a ReturnOp, a TosaOp, or an op with a
+ // type-inference related interface.
+ auto isReplaceableUser = [](Operation *user) -> bool {
+ return isa<func::ReturnOp>(user) ||
+ user->getDialect()->getNamespace() ==
+ TosaDialect::getDialectNamespace() ||
+ isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
+ };
+
for (auto &block : region) {
for (Operation &op : block) {
if (op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
Value result = std::get<0>(it);
ShapedTypeComponents predictedShape = std::get<1>(it);
- // Check whether this use case is replaceable. We define an op as
- // being replaceable if it is used by a ReturnOp or a TosaOp.
- bool replaceable = true;
- for (auto *user : result.getUsers()) {
- if (isa<func::ReturnOp>(user))
- continue;
- if (user->getDialect()->getNamespace() ==
- TosaDialect::getDialectNamespace())
- continue;
-
- replaceable = false;
- }
+ if (!llvm::all_of(result.getUsers(), isReplaceableUser))
+ continue;
// Determine the knowledge based on the output type.
// TODO: should also query WIP type probably
}
}
- if (!replaceable)
- continue;
-
// Compute the new type based on the joined version.
auto newKnowledge =
ValueKnowledge::join(currentKnowledge, inferredKnowledge);
return
}
+
+// -----
+
+// CHECK-LABEL: test_non_tosa_consumer_shape
+func.func @test_non_tosa_consumer_shape(%arg0: tensor<4x4xf32>) -> !shape.shape {
+ // CHECK: "tosa.log"(%arg0) : (tensor<4x4xf32>) -> tensor<4x4xf32>
+ %0 = "tosa.log"(%arg0) : (tensor<4x4xf32>) -> tensor<*xf32>
+ %1 = shape.shape_of %0 : tensor<*xf32> -> !shape.shape
+ return %1 : !shape.shape
+}
+
+// -----
+
+// CHECK-LABEL: test_non_tosa_consumer_shape2
+func.func @test_non_tosa_consumer_shape2(%arg0: tensor<4x4xf32>) -> tensor<?xindex> {
+ // CHECK: "tosa.log"(%arg0) : (tensor<4x4xf32>) -> tensor<4x4xf32>
+ %0 = "tosa.log"(%arg0) : (tensor<4x4xf32>) -> tensor<*xf32>
+ %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
+ return %1 : tensor<?xindex>
+}
+
+// -----
+
+// CHECK-LABEL: test_non_tosa_consumer_extract
+func.func @test_non_tosa_consumer_extract(%arg0: tensor<4x4xf32>, %arg1: index) -> f32 {
+ // CHECK: "tosa.log"(%arg0) : (tensor<4x4xf32>) -> tensor<4x4xf32>
+ %0 = "tosa.log"(%arg0) : (tensor<4x4xf32>) -> tensor<?x?xf32>
+ %1 = tensor.extract %0[%arg1, %arg1] : tensor<?x?xf32>
+ return %1 : f32
+}