[mlir][tosa] Updates tosa.equal to use the InferTensorType interface
authornot-jenni <jennik@google.com>
Mon, 8 Aug 2022 22:19:51 +0000 (15:19 -0700)
committerJacques Pienaar <jpienaar@google.com>
Mon, 8 Aug 2022 23:11:30 +0000 (16:11 -0700)
Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D130373

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

index c88f872..3f00587 100644 (file)
@@ -1139,10 +1139,8 @@ def Tosa_SelectOp : Tosa_Op<"select", [
 //===----------------------------------------------------------------------===//
 // Operator: equal
 //===----------------------------------------------------------------------===//
-def Tosa_EqualOp : Tosa_Op<"equal", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
-    ResultsBroadcastableShape, Commutative, NoSideEffect]> {
+def Tosa_EqualOp : Tosa_Op<"equal", [InferTensorType, ResultsBroadcastableShape,
+    Commutative, NoSideEffect]> {
   let summary = "Returns the truth value of (x == y) element-wise.";
 
   let description = [{
@@ -1157,6 +1155,12 @@ def Tosa_EqualOp : Tosa_Op<"equal", [
   let results = (outs
     I1Tensor:$output
   );
+
+  let extraClassDeclaration = [{
+    /// Returns when two result types are compatible for this op; method used by
+    /// InferTypeOpInterface.
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+  }];
 }
 
 //===----------------------------------------------------------------------===//
index 11ff2a0..90fe70d 100644 (file)
@@ -21,6 +21,7 @@
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -339,6 +340,44 @@ static void getF64Values(ArrayAttr arrayAttr, SmallVector<double> &values) {
   }
 }
 
+static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
+                                           SmallVector<int64_t> &outShape) {
+  int64_t outRank = 0;
+  for (int i = 0, e = operands.size(); i != e; ++i) {
+    auto shape = operands.getShape(i);
+    if (!shape.hasRank()) {
+      // TODO(jennik): Update function to have better case handling for invalid
+      // operands and for ranked tensors.
+      return failure();
+    }
+    outRank = std::max<int64_t>(outRank, shape.getRank());
+  }
+
+  outShape.resize(outRank, 1);
+
+  for (int i = 0, e = operands.size(); i != e; ++i) {
+    auto shape = operands.getShape(i);
+    auto rankDiff = outShape.size() - shape.getRank();
+
+    for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
+      auto dim1 = outShape[i + rankDiff];
+      auto dim2 = shape.getDimSize(i);
+      auto resolvedDim = dim1;
+
+      if (dim1 == 1) {
+        resolvedDim = dim2;
+      } else if (dim2 == 1) {
+        resolvedDim = dim1;
+      } else if (dim1 != dim2) {
+        return failure();
+      }
+      outShape[i + rankDiff] = resolvedDim;
+    }
+  }
+
+  return success();
+}
+
 LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
     MLIRContext *context, ::llvm::Optional<Location> location,
     ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
@@ -421,6 +460,27 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tosa::EqualOp::inferReturnTypeComponents(
+    MLIRContext *context, ::llvm::Optional<Location> location,
+    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  llvm::SmallVector<int64_t> outShape;
+  if (resolveBroadcastShape(operands, outShape).failed()) {
+    inferredReturnShapes.push_back(ShapedTypeComponents());
+    return success();
+  }
+
+  inferredReturnShapes.push_back(
+      ShapedTypeComponents(outShape, IntegerType::get(context, /*width=*/1)));
+  return success();
+}
+
+bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+  if (l.size() != r.size() || l.size() != 1)
+    return false;
+  return succeeded(verifyCompatibleShape(l[0], r[0]));
+}
+
 LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
     MLIRContext *context, ::llvm::Optional<Location> location,
     ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
@@ -870,42 +930,6 @@ REDUCE_SHAPE_INFER(tosa::ReduceProdOp)
 REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
 #undef REDUCE_SHAPE_INFER
 
-static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
-                                           SmallVector<int64_t> &outShape) {
-  int64_t outRank = 0;
-  for (int i = 0, e = operands.size(); i != e; ++i) {
-    auto shape = operands.getShape(i);
-    if (!shape.hasRank()) {
-      return failure();
-    }
-    outRank = std::max<int64_t>(outRank, shape.getRank());
-  }
-
-  outShape.resize(outRank, 1);
-
-  for (int i = 0, e = operands.size(); i != e; ++i) {
-    auto shape = operands.getShape(i);
-    auto rankDiff = outShape.size() - shape.getRank();
-
-    for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
-      auto dim1 = outShape[i + rankDiff];
-      auto dim2 = shape.getDimSize(i);
-      auto resolvedDim = dim1;
-
-      if (dim1 == 1) {
-        resolvedDim = dim2;
-      } else if (dim2 == 1) {
-        resolvedDim = dim1;
-      } else if (dim1 != dim2) {
-        return failure();
-      }
-      outShape[i + rankDiff] = resolvedDim;
-    }
-  }
-
-  return success();
-}
-
 static LogicalResult NAryInferReturnTypes(
     const ValueShapeRange &operands,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
@@ -939,7 +963,6 @@ NARY_SHAPE_INFER(tosa::CeilOp)
 NARY_SHAPE_INFER(tosa::ClampOp)
 NARY_SHAPE_INFER(tosa::ClzOp)
 NARY_SHAPE_INFER(tosa::DivOp)
-NARY_SHAPE_INFER(tosa::EqualOp)
 NARY_SHAPE_INFER(tosa::ExpOp)
 NARY_SHAPE_INFER(tosa::FloorOp)
 NARY_SHAPE_INFER(tosa::GreaterEqualOp)