[shape] Add inferReturnTypes to a couple ops.
authorSean Silva <silvasean@google.com>
Fri, 24 Apr 2020 22:54:22 +0000 (15:54 -0700)
committerSean Silva <silvasean@google.com>
Fri, 24 Apr 2020 23:10:20 +0000 (16:10 -0700)
- ShapeOfOp
- BroadcastOp

Differential Revision: https://reviews.llvm.org/D78822

mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Dialect/Shape/IR/Shape.cpp

index f54456b..fa277f4 100644 (file)
@@ -130,7 +130,8 @@ def Shape_AddOp : Shape_Op<"add", [SameOperandsAndResultType]> {
   let results = (outs Shape_SizeType:$result);
 }
 
-def Shape_BroadcastOp : Shape_Op<"broadcast", []> {
+def Shape_BroadcastOp : Shape_Op<"broadcast",
+    [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "Returns the broadcasted output shape of two inputs";
   let description = [{
     Computes the broadcasted output shape following:
@@ -317,7 +318,8 @@ def Shape_ReduceOp : Shape_Op<"reduce", []> {
   let regions = (region SizedRegion<1>:$body);
 }
 
-def Shape_ShapeOfOp : Shape_Op<"shape_of", []> {
+def Shape_ShapeOfOp : Shape_Op<"shape_of",
+    [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "Returns shape of a value or shaped type operand";
 
   let arguments = (ins AnyTypeOf<[AnyShaped, Shape_ValueShapeType]>:$arg);
index 4a1c0f1..10e766f 100644 (file)
@@ -92,6 +92,14 @@ void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
 // BroadcastOp
 //===----------------------------------------------------------------------===//
 
+LogicalResult BroadcastOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    ArrayRef<NamedAttribute> attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  inferredReturnTypes.push_back(ShapeType::get(context));
+  return success();
+}
+
 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
   if (!operands[0] || !operands[1])
     return nullptr;
@@ -175,6 +183,14 @@ LogicalResult ConstSizeOp::inferReturnTypes(
 // ShapeOfOp
 //===----------------------------------------------------------------------===//
 
+LogicalResult ShapeOfOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    ArrayRef<NamedAttribute> attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  inferredReturnTypes.push_back(ShapeType::get(context));
+  return success();
+}
+
 OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
   auto type = getOperand().getType().dyn_cast<ShapedType>();
   if (!type || !type.hasStaticShape())