[mlir] Add reifyReturnShape to shaped type OpInterface
authorJacques Pienaar <jpienaar@google.com>
Fri, 28 Feb 2020 16:37:09 +0000 (08:37 -0800)
committerJacques Pienaar <jpienaar@google.com>
Fri, 28 Feb 2020 16:41:18 +0000 (08:41 -0800)
This call results in inserting operations that compute the return shape
dynamically for the operation.

mlir/include/mlir/Analysis/InferTypeOpInterface.h
mlir/include/mlir/Analysis/InferTypeOpInterface.td
mlir/test/lib/TestDialect/TestDialect.cpp
mlir/test/lib/TestDialect/TestPatterns.cpp
mlir/test/mlir-tblgen/return-types.mlir

index 2a64983..4c26285 100644 (file)
@@ -15,6 +15,7 @@
 #define MLIR_ANALYSIS_INFERTYPEOPINTERFACE_H_
 
 #include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Support/LLVM.h"
index 621d586..548cd09 100644 (file)
@@ -97,6 +97,18 @@ def InferShapedTypeOpInterface : OpInterface<"InferShapedTypeOpInterface"> {
                     "SmallVectorImpl<ShapedTypeComponents>&":
                       $inferedReturnShapes)
     >,
+    InterfaceMethod<
+      /*desc=*/[{Reify the shape computation for the operation.
+
+      Insert operations using the given OpBulder that computes the result shape.
+      }],
+      /*retTy=*/"LogicalResult",
+      /*methodName=*/"reifyReturnTypeShapes",
+      /*args=*/(ins "OpBuilder&":$builder,
+                    "SmallVectorImpl<Value>&":$reifiedReturnShapes),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{ return failure(); }]
+    >,
   ];
 }
 
index 330b804..12ec279 100644 (file)
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "TestDialect.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Function.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/PatternMatch.h"
@@ -312,24 +313,24 @@ LogicalResult mlir::OpWithInferTypeInterfaceOp::inferReturnTypes(
 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
     MLIRContext *context, Optional<Location> location, ValueRange operands,
     ArrayRef<NamedAttribute> attributes, RegionRange regions,
-    SmallVectorImpl<ShapedTypeComponents> &inferedComponents) {
-  // Create return type consisting of the first element of each shape of the
-  // input operands or unknown for unranked operand.
-  std::vector<int64_t> shape;
-  shape.reserve(operands.size());
-  for (auto operandType : operands.getTypes()) {
-    if (auto sval = operandType.dyn_cast<ShapedType>()) {
-      if (sval.hasRank())
-        shape.push_back(sval.getShape().front());
-      else
-        shape.push_back(ShapedType::kDynamicSize);
-    } else {
-      return emitOptionalError(location, "only shaped type operands allowed");
-    }
+    SmallVectorImpl<ShapedTypeComponents> &inferedReturnShapes) {
+  // Create return type consisting of the last element of the first operand.
+  auto operandType = *operands.getTypes().begin();
+  auto sval = operandType.dyn_cast<ShapedType>();
+  if (!sval) {
+    return emitOptionalError(location, "only shaped type operands allowed");
   }
-  inferedComponents.reserve(1);
+  int64_t dim =
+      sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
   auto type = IntegerType::get(17, context);
-  inferedComponents.emplace_back(shape, type);
+  inferedReturnShapes.push_back(ShapedTypeComponents({dim}, type));
+  return success();
+}
+
+LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
+    OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) {
+  shapes = SmallVector<Value, 1>{
+      builder.createOrFold<mlir::DimOp>(getLoc(), getOperand(0), 0)};
   return success();
 }
 
index f899876..decb5e2 100644 (file)
@@ -82,6 +82,19 @@ static void invokeCreateWithInferedReturnType(Operation *op) {
   }
 }
 
+static void reifyReturnShape(Operation *op) {
+  OpBuilder b(op);
+
+  // Use permutations of 2 args as operands.
+  auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op);
+  SmallVector<Value, 2> shapes;
+  if (failed(shapedOp.reifyReturnTypeShapes(b, shapes)))
+    return;
+  for (auto it : llvm::enumerate(shapes))
+    op->emitRemark() << "value " << it.index() << ": "
+                     << it.value().getDefiningOp();
+}
+
 struct TestReturnTypeDriver : public FunctionPass<TestReturnTypeDriver> {
   void runOnFunction() override {
     if (getFunction().getName() == "testCreateFunctions") {
@@ -100,6 +113,16 @@ struct TestReturnTypeDriver : public FunctionPass<TestReturnTypeDriver> {
       };
       return;
     }
+    if (getFunction().getName() == "testReifyFunctions") {
+      std::vector<Operation *> ops;
+      // Collect ops to avoid triggering on inserted ops.
+      for (auto &op : getFunction().getBody().front())
+        if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op))
+          ops.push_back(&op);
+      // Generate test patterns for each, but skip terminator.
+      for (auto *op : ops)
+        reifyReturnShape(op);
+    }
   }
 };
 } // end anonymous namespace
index 3fcb223..d0eb364 100644 (file)
@@ -7,13 +7,13 @@ func @testCreateFunctions(%arg0 : tensor<10xf32>, %arg1 : tensor<20xi32>) {
 // CHECK: "test.no_attributes"
   %good = "test.no_attributes"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
 // CHECK: "test.op_with_shaped_type_infer_type_if"
-// CHECK-SAME: (tensor<10xf32>, tensor<10xf32>) -> tensor<10x10xi17>
+// CHECK-SAME: (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi17>
 // CHECK: "test.op_with_shaped_type_infer_type_if"
-// CHECK-SAME: (tensor<10xf32>, tensor<20xi32>) -> tensor<10x20xi17>
+// CHECK-SAME: (tensor<10xf32>, tensor<20xi32>) -> tensor<10xi17>
 // CHECK: "test.op_with_shaped_type_infer_type_if"
-// CHECK-SAME: (tensor<20xi32>, tensor<10xf32>) -> tensor<20x10xi17>
+// CHECK-SAME: (tensor<20xi32>, tensor<10xf32>) -> tensor<20xi17>
 // CHECK: "test.op_with_shaped_type_infer_type_if"
-// CHECK-SAME: (tensor<20xi32>, tensor<20xi32>) -> tensor<20x20xi17>
+// CHECK-SAME: (tensor<20xi32>, tensor<20xi32>) -> tensor<20xi17>
 // CHECK: "test.op_with_infer_type_if"
 // CHECK-SAME: (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
 // CHECK: "test.op_with_infer_type_if"
@@ -36,3 +36,14 @@ func @testReturnTypeOpInterfaceMismatch(%arg0 : tensor<10xf32>, %arg1 : tensor<2
   %bad = "test.op_with_infer_type_if"(%arg0, %arg1) : (tensor<10xf32>, tensor<20xf32>) -> tensor<*xf32>
   return
 }
+
+// -----
+
+// CHECK-LABEL: testReifyFunctions
+func @testReifyFunctions(%arg0 : tensor<10xf32>, %arg1 : tensor<20xf32>) {
+  // expected-remark@+1 {{constant 10}}
+  %0 = "test.op_with_shaped_type_infer_type_if"(%arg0, %arg1) : (tensor<10xf32>, tensor<20xf32>) -> tensor<10xi17>
+  // expected-remark@+1 {{constant 20}}
+  %1 = "test.op_with_shaped_type_infer_type_if"(%arg1, %arg0) : (tensor<20xf32>, tensor<10xf32>) -> tensor<20xi17>
+  return
+}