Split `InferShapedTypeOpInterface` to create `ReifyRankedShapedTypeInterface`.
author    MaheshRavishankar <ravishankarm@google.com>
          Mon, 19 Jul 2021 21:35:20 +0000 (14:35 -0700)
committer MaheshRavishankar <ravishankarm@google.com>
          Mon, 19 Jul 2021 21:44:52 +0000 (14:35 -0700)
The `reifyReturnTypeShapesPerResultDim` method supports shape
inference for results that are ranked types. It is used lower in the
code-generation stack than its counterpart `reifyReturnTypeShapes`,
which also supports unranked types and is more suited for use higher
up the compilation stack. To keep these concerns separate, the method
is split out into its own interface.
See discussion: https://llvm.discourse.group/t/better-layering-for-infershapedtypeopinterface/3823

Differential Revision: https://reviews.llvm.org/D106133
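
As a minimal sketch of the consumer side (not part of this patch), the helper
below shows how a caller can recover the size of one result dimension in terms
of the op's operands via the new interface; the function name
`getResultDimValue` is made up, and the calls mirror the
`DimOfReifyRankedShapedTypeOpInterface` pattern added further down:

    #include "mlir/IR/Builders.h"
    #include "mlir/Interfaces/InferTypeOpInterface.h"

    using namespace mlir;

    // Returns the Value holding the size of dimension `dimIndex` of result
    // `resultNumber` of `op`, or nullptr if the shape cannot be reified.
    static Value getResultDimValue(OpBuilder &builder, Operation *op,
                                   unsigned resultNumber, int64_t dimIndex) {
      auto reifyOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
      if (!reifyOp)
        return nullptr;
      // One vector per result, one Value per dimension of that result.
      ReifiedRankedShapedTypeDims reifiedShapes;
      if (failed(reifyOp.reifyResultShapes(builder, reifiedShapes)))
        return nullptr;
      if (reifiedShapes.size() <= resultNumber ||
          reifiedShapes[resultNumber].size() <= static_cast<size_t>(dimIndex))
        return nullptr;
      return reifiedShapes[resultNumber][dimIndex];
    }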

17 files changed:
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/include/mlir/Interfaces/InferTypeOpInterface.h
mlir/include/mlir/Interfaces/InferTypeOpInterface.td
mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestOps.td
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index 418e609..4837fc8 100644
@@ -19,6 +19,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
 
 namespace mlir {
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index e1f096d..bc87771 100644
@@ -928,8 +928,8 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
 
     /// Returns the value that expresses the shape of the output in terms of
     /// shape of the input operands where possible
-    LogicalResult reifyReturnTypeShapesPerResultDim(OpBuilder &b,
-        SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes);
+    LogicalResult reifyResultShapes(OpBuilder &b,
+        ReifiedRankedShapedTypeDims &reifiedReturnShapes);
 
     //========================================================================//
     // Helper functions to mutate the `operand_segment_sizes` attribute.
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 8d14880..8cf9dc3 100644
@@ -36,8 +36,7 @@ class Linalg_Op<string mnemonic, list<OpTrait> traits = []> :
 
 def Linalg_InitTensorOp : Linalg_Op<"init_tensor",
     [NoSideEffect,
-     DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-         ["reifyReturnTypeShapesPerResultDim"]>]> {
+     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
   let summary = "operation to define a tensor of particular value";
 
   let description = [{
@@ -130,10 +129,8 @@ def Linalg_InitTensorOp : Linalg_Op<"init_tensor",
 }
 
 def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
-    [AttrSizedOperandSegments,
-     DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-         ["reifyReturnTypeShapesPerResultDim"]>,
-     NoSideEffect]> {
+    [AttrSizedOperandSegments, NoSideEffect,
+     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
   let summary = "tensor pad operation";
   let description = [{
     `linalg.pad_tensor` is an operation that pads the `source` tensor
@@ -398,8 +395,7 @@ def IndexListArrayAttr :
 
 class Linalg_TensorReshapeOp<string mnemonic> : Linalg_ReshapeLikeOp<
     mnemonic,
-    [DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-      ["reifyReturnTypeShapesPerResultDim"]>]>,
+    [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]>,
     Arguments<(ins AnyTensor:$src,
                    IndexListArrayAttr:$reassociation)>,
     Results<(outs AnyTensor:$result)> {
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index fa17237..9055f3c 100644
@@ -26,7 +26,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
 // depending on the specific Linalg op.
 class LinalgStructuredBase_Op<string mnemonic, list<OpTrait> props>
   : Op<Linalg_Dialect, mnemonic, !listconcat(props, [
-       LinalgStructuredInterface, InferShapedTypeOpInterface])> {
+       LinalgStructuredInterface, ReifyRankedShapedTypeOpInterface])> {
   code structuredOpsBaseDecls = [{
     // Return whether the op accesses the iteration indices.
     bool hasIndexSemantics() {
@@ -36,9 +36,9 @@ class LinalgStructuredBase_Op<string mnemonic, list<OpTrait> props>
       return !op->getRegion(0).front().getOps<IndexOp>().empty();
     }
 
-    LogicalResult reifyReturnTypeShapesPerResultDim(OpBuilder &b,
-        SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
-      return cast<LinalgOp>(getOperation()).reifyReturnTypeShapesPerResultDim(b,
+    LogicalResult reifyResultShapes(OpBuilder &b,
+        ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+      return cast<LinalgOp>(getOperation()).reifyResultShapes(b,
           reifiedReturnShapes);
     }
   }];
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
index 153991b..186782c 100644
@@ -36,6 +36,13 @@ namespace memref {
 void populateFoldSubViewOpPatterns(RewritePatternSet &patterns);
 
 /// Appends patterns that resolve `memref.dim` operations with values that are
+/// defined by operations that implement the
+/// `ReifyRankedShapedTypeOpInterface`, in terms of shapes of its input
+/// operands.
+void populateResolveRankedShapeTypeResultDimsPatterns(
+    RewritePatternSet &patterns);
+
+/// Appends patterns that resolve `memref.dim` operations with values that are
 /// defined by operations that implement the `InferShapedTypeOpInterface`, in
 /// terms of shapes of its input operands.
 void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns);
@@ -50,7 +57,14 @@ std::unique_ptr<Pass> createFoldSubViewOpsPass();
 
 /// Creates an operation pass to resolve `memref.dim` operations with values
 /// that are defined by operations that implement the
-/// `InferShapedTypeOpInterface`, in terms of shapes of its input operands.
+/// `ReifyRankedShapedTypeOpInterface`, in terms of shapes of its input
+/// operands.
+std::unique_ptr<Pass> createResolveRankedShapeTypeResultDimsPass();
+
+/// Creates an operation pass to resolve `memref.dim` operations with values
+/// that are defined by operations that implement the
+/// `InferShapedTypeOpInterface` or the `ReifyRankedShapedTypeOpInterface`,
+/// in terms of shapes of its input operands.
 std::unique_ptr<Pass> createResolveShapedTypeResultDimsPass();
 
 //===----------------------------------------------------------------------===//
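
For context, a usage sketch (not part of the patch) of how the new pass could
be added to a pipeline built in C++; the helper name is hypothetical and the
nesting under `FuncOp` is an assumption:

    #include "mlir/Dialect/MemRef/Transforms/Passes.h"
    #include "mlir/IR/BuiltinOps.h"
    #include "mlir/Pass/PassManager.h"

    // Hypothetical pipeline setup.
    static void addShapeResolutionPasses(mlir::PassManager &pm) {
      // Resolves memref.dim/tensor.dim of results of ops that implement the
      // new ReifyRankedShapedTypeOpInterface.
      pm.addNestedPass<mlir::FuncOp>(
          mlir::memref::createResolveRankedShapeTypeResultDimsPass());
      // Alternatively, createResolveShapedTypeResultDimsPass() now covers ops
      // implementing either interface (it pulls in both pattern sets).
    }
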
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
index d7a7ddc..026f345 100644
@@ -23,12 +23,28 @@ def FoldSubViewOps : Pass<"fold-memref-subview-ops"> {
   ];
 }
 
+def ResolveRankedShapeTypeResultDims :
+    Pass<"resolve-ranked-shaped-type-result-dims"> {
+  let summary = "Resolve memref.dim of result values of ranked shape type";
+  let description = [{
+    The pass resolves memref.dim of result of operations that
+    implement the `ReifyRankedShapedTypeOpInterface` in terms of
+    shapes of its operands.
+  }];
+  let constructor =
+      "mlir::memref::createResolveRankedShapeTypeResultDimsPass()";
+  let dependentDialects = [
+    "memref::MemRefDialect", "tensor::TensorDialect"
+  ];
+}
+
 def ResolveShapedTypeResultDims : Pass<"resolve-shaped-type-result-dims"> {
   let summary = "Resolve memref.dim of result values";
   let description = [{
     The pass resolves memref.dim of result of operations that
-    implement the `InferShapedTypeOpInterface` in terms of shapes of
-    its operands.
+    implement the `InferShapedTypeOpInterface` or
+    `ReifyRankedShapedTypeOpInterface` in terms of shapes of its
+    operands.
   }];
   let constructor = "mlir::memref::createResolveShapedTypeResultDimsPass()";
   let dependentDialects = [
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index c256854..74cc361 100644
@@ -432,8 +432,7 @@ def Tensor_InsertOp : Tensor_Op<"insert",
 def Tensor_InsertSliceOp : BaseOpWithOffsetSizesAndStrides<
     Tensor_Dialect, "insert_slice",
     [NoSideEffect, AttrSizedOperandSegments, OffsetSizeAndStrideOpInterface,
-     DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-         ["reifyReturnTypeShapesPerResultDim"]>,
+     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
      TypesMatchWith<"expected result type to match dest type",
                     "dest", "result", "$_self">]> {
   let summary = "insert_slice operation";
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
index 1ae4aa6..9f83c59 100644
@@ -23,6 +23,8 @@
 
 namespace mlir {
 
+using ReifiedRankedShapedTypeDims = SmallVector<SmallVector<Value>>;
+
 /// ShapedTypeComponents that represents the components of a ShapedType.
 /// The components consist of
 ///  - A ranked or unranked shape with the dimension specification match those
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
index 4d0271d..8ec3a12 100644
@@ -105,9 +105,7 @@ def InferShapedTypeOpInterface : OpInterface<"InferShapedTypeOpInterface"> {
       /*desc=*/[{Reify the shape computation for the operation.
 
       Insert operations using the given OpBuilder that computes the
-      result shape. Only one of this method or
-      `reifyReturnTypeShapesPerResultDim` needs to be overriden by the
-      operation. This interface is supposed to be workable during dialect
+      result shape. This interface is supposed to be workable during dialect
       conversion (e.g. convert from tensor world to buffer world),
       where `getOperand` may be invalid. For example, some ops (e.g.
       dynamic_reshape(input, target_shape)) may depend on their operands
@@ -127,34 +125,6 @@ def InferShapedTypeOpInterface : OpInterface<"InferShapedTypeOpInterface"> {
           "::mlir::SmallVectorImpl<::mlir::Value> &":$reifiedReturnShapes),
       /*methodBody=*/[{}],
       /*defaultImplementation=*/[{ return ::mlir::failure(); }]
-    >,
-    InterfaceMethod<
-      /*desc=*/[{Reify the shape computation for the operation.
-
-      Insert operations using the given OpBuilder that computes the
-      result shape. The `reifiedReturnShapes` is expected to be
-      populated with as many vectors as the number of results of the
-      op (empty if the shape of a result value cannot be computed). If
-      the returned shape for a result is not empty, its size must
-      match the rank of the shaped type returned. Consequently, this
-      interface can only be overridden if the return types are ranked.
-
-      If both this method and `reifyReturnTypeShapes` are overridden
-      by the operation, `reifyReturnTypeShapes` takes precedence. This
-      method is intended to be used when the shape of each result, dim
-      pair can be computed independently. Using this method avoids
-      adding additional instructions to aggregate individual dimension
-      of a result shape into an single `Value` (and consequently
-      avoids the need to extract the value from the shape on the
-      client side).
-      }],
-      /*retTy=*/"::mlir::LogicalResult",
-      /*methodName=*/"reifyReturnTypeShapesPerResultDim",
-      /*args=*/(ins "::mlir::OpBuilder&":$builder,
-          "::mlir::SmallVectorImpl<::mlir::SmallVector<::mlir::Value>>&"
-          :$reifiedReturnShapes),
-      /*methodBody=*/[{}],
-      /*defaultImplementation=*/[{ return ::mlir::failure(); }]
     >
   ];
 }
@@ -176,4 +146,35 @@ class InferTensorType<list<string> overridenMethods = []> {
 defvar InferTensorTypeWithReify = InferTensorType<[
     "inferReturnTypeComponents", "reifyReturnTypeShapes"]>;
 
+
+def ReifyRankedShapedTypeOpInterface :
+    OpInterface<"ReifyRankedShapedTypeOpInterface"> {
+  let description = [{
+    Interface to compute the shape of the result of an operation when
+    the result is a ranked shape type, i.e. `RankedTensorType` or
+    `MemRefType`.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Reify the shape of the result of an operation (typically in
+        terms of the shapes of its operands).
+
+        Insert operations using the given `OpBuilder` that computes
+        the result shape. The `reifiedReturnShapes` is expected to be
+        populated with as many vectors as the number of results of the
+        op. Each of these vectors is expected to be of size equal to
+        rank of the corresponding result. If the shape of a particular
+        result cannot be computed it must be empty.
+      }],
+      /*retTy=*/"LogicalResult",
+      /*methodName=*/"reifyResultShapes",
+      /*args=*/(ins "::mlir::OpBuilder &":$builder,
+        "ReifiedRankedShapedTypeDims &":$reifiedReturnShapes)
+    >
+  ];
+}
+
 #endif // MLIR_INFERTYPEOPINTERFACE
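
For illustration only (not part of the patch), a hypothetical single-result
tensor op could implement the new method along these lines, mirroring the
test-dialect implementation of `OpWithResultShapePerDimInterfaceOp` further
down; the op name `MyCopyLikeOp` and its `source()` accessor are made up:

    #include "mlir/Dialect/Tensor/IR/Tensor.h"
    #include "mlir/Interfaces/InferTypeOpInterface.h"

    using namespace mlir;

    // Reify each dimension of the single ranked result from the source
    // operand, assuming the result shape matches the source shape.
    LogicalResult MyCopyLikeOp::reifyResultShapes(
        OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
      Location loc = getLoc();
      auto sourceType = source().getType().cast<RankedTensorType>();
      SmallVector<Value> shape;
      for (int64_t dim = 0, rank = sourceType.getRank(); dim < rank; ++dim)
        // tensor.dim folds to a constant index for static dimensions.
        shape.push_back(
            builder.createOrFold<tensor::DimOp>(loc, source(), dim));
      reifiedReturnShapes.emplace_back(std::move(shape));
      return success();
    }
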
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 7d22cfd..fb1d1cc 100644
@@ -274,8 +274,9 @@ private:
   llvm::SmallSet<unsigned, 4> positions;
 };
 
-LogicalResult LinalgOp::reifyReturnTypeShapesPerResultDim(
-    OpBuilder &b, SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
+LogicalResult
+LinalgOp::reifyResultShapes(OpBuilder &b,
+                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
   // An example that helps understand the logic below.
   // Consider the following expression O(i+j, j) += A(i,k) * B(k, j)
   // We want to express the shape of dim 0 of O in terms of shape of the inputs.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 16ff071..23927a0 100644
@@ -779,9 +779,8 @@ struct FoldInitTensorWithTensorReshapeOp
     if (!reshapeOp.src().template getDefiningOp<InitTensorOp>())
       return failure();
     Location loc = reshapeOp.getLoc();
-    SmallVector<SmallVector<Value>, 4> resultShapes;
-    if (failed(reshapeOp.reifyReturnTypeShapesPerResultDim(rewriter,
-                                                           resultShapes)) ||
+    ReifiedRankedShapedTypeDims resultShapes;
+    if (failed(reshapeOp.reifyResultShapes(rewriter, resultShapes)) ||
         !llvm::hasSingleElement(resultShapes))
       return failure();
     Value initTensor = rewriter.create<InitTensorOp>(
@@ -825,9 +824,8 @@ void InitTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
               ReplaceStaticShapeDims>(context);
 }
 
-LogicalResult InitTensorOp::reifyReturnTypeShapesPerResultDim(
-    OpBuilder &builder,
-    SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
+LogicalResult InitTensorOp::reifyResultShapes(
+    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
   auto shapes = llvm::to_vector<4>(llvm::map_range(
       llvm::seq<int64_t>(0, getType().getRank()), [&](int64_t dim) -> Value {
         if (isDynamicSize(dim))
@@ -1003,8 +1001,8 @@ PadTensorOp PadTensorOp::createPadHighOp(Type type, Value source, Value pad,
                                         builder);
 }
 
-LogicalResult PadTensorOp::reifyReturnTypeShapesPerResultDim(
-    OpBuilder &b, SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
+LogicalResult PadTensorOp::reifyResultShapes(
+    OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
   Location loc = getLoc();
   auto lowPad = getMixedLowPad();
   auto highPad = getMixedHighPad();
@@ -1429,8 +1427,8 @@ void TensorCollapseShapeOp::getCanonicalizationPatterns(
            FoldReshapeWithConstant<TensorCollapseShapeOp>>(context);
 }
 
-LogicalResult TensorExpandShapeOp::reifyReturnTypeShapesPerResultDim(
-    OpBuilder &b, SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
+LogicalResult TensorExpandShapeOp::reifyResultShapes(
+    OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
   auto resultShape =
       getAsValues(b, getLoc(),
                   getReshapeOutputShapeFromInputShape(
@@ -1440,8 +1438,8 @@ LogicalResult TensorExpandShapeOp::reifyReturnTypeShapesPerResultDim(
   return success();
 }
 
-LogicalResult TensorCollapseShapeOp::reifyReturnTypeShapesPerResultDim(
-    OpBuilder &b, SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
+LogicalResult TensorCollapseShapeOp::reifyResultShapes(
+    OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
   auto resultShape =
       getAsValues(b, getLoc(),
                   getReshapeOutputShapeFromInputShape(
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index 9c70734..d62286c 100644
@@ -1,5 +1,4 @@
-//===- ResolveShapedTypeResultDims.cpp - Resolve memref.dim ops of result values
-//-------===//
+//===- ResolveShapedTypeResultDims.cpp - Resolve dim ops of result values -===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
 
 using namespace mlir;
 
-/// Helper method to get the `Value` that is the shape of the `resultIdx`-th
-/// result at dimension `dimIndex` from the `ShapedTypeOpInterface`.
-/// TODO(ravishankarm): This is better put as a interface utility method
-/// somewhere, but that would imply the interface will depend on the `tensor`
-/// dialect. Ideally maybe a utility method in the `tensor` dialect.
-static Value getResultDimFromShapeInterface(OpBuilder &builder, OpResult result,
-                                            int64_t dimIndex) {
-  unsigned resultNumber = result.getResultNumber();
-  auto shapedTypeOp = dyn_cast<InferShapedTypeOpInterface>(result.getOwner());
-  Location loc = result.getOwner()->getLoc();
-  if (!shapedTypeOp)
-    return nullptr;
-
-  // The interface exposes two methods, one that returns the shape of all the
-  // results as `Value` and other that returns the shape as a list of
-  // `SmallVector<Value>`. The former takes precedence over the latter. So first
-  // check if the op implements the first interface method or the second, and
-  // get the value to use appropriately.
-  SmallVector<Value> reifiedResultShapes;
-  if (succeeded(shapedTypeOp.reifyReturnTypeShapes(
-          builder, result.getOwner()->getOperands(), reifiedResultShapes))) {
-    if (reifiedResultShapes.size() <= resultNumber)
-      return nullptr;
-    Value resultShape = reifiedResultShapes[resultNumber];
-    auto resultShapeType = resultShape.getType().dyn_cast<RankedTensorType>();
-    if (!resultShapeType || !resultShapeType.getElementType().isa<IndexType>())
-      return nullptr;
-    return builder.create<tensor::ExtractOp>(
-        loc, resultShape, builder.createOrFold<ConstantIndexOp>(loc, dimIndex));
-  }
-
-  SmallVector<SmallVector<Value>> reifiedResultShapesPerDim;
-  if (failed(shapedTypeOp.reifyReturnTypeShapesPerResultDim(
-          builder, reifiedResultShapesPerDim)))
-    return nullptr;
-  if (reifiedResultShapesPerDim.size() <= resultNumber ||
-      reifiedResultShapesPerDim[resultNumber].size() !=
-          static_cast<size_t>(result.getType().cast<ShapedType>().getRank()))
-    return nullptr;
-  OpFoldResult valueOrAttr = reifiedResultShapesPerDim[resultNumber][dimIndex];
-  if (auto attr = valueOrAttr.dyn_cast<Attribute>())
-    return builder.createOrFold<ConstantIndexOp>(
-        loc, attr.cast<IntegerAttr>().getInt());
-  return valueOrAttr.get<Value>();
-}
-
 namespace {
 /// Fold dim of an operation that implements the InferShapedTypeOpInterface
 template <typename OpTy>
@@ -86,11 +39,62 @@ struct DimOfShapedTypeOpInterface : public OpRewritePattern<OpTy> {
     Optional<int64_t> dimIndex = dimOp.getConstantIndex();
     if (!dimIndex)
       return failure();
-    Value replacement =
-        getResultDimFromShapeInterface(rewriter, dimValue, *dimIndex);
-    if (!replacement)
+
+    SmallVector<Value> reifiedResultShapes;
+    if (failed(shapedTypeOp.reifyReturnTypeShapes(
+            rewriter, shapedTypeOp->getOperands(), reifiedResultShapes)))
+      return failure();
+
+    if (reifiedResultShapes.size() != shapedTypeOp->getNumResults())
       return failure();
-    rewriter.replaceOp(dimOp, replacement);
+
+    Value resultShape = reifiedResultShapes[dimValue.getResultNumber()];
+    auto resultShapeType = resultShape.getType().dyn_cast<RankedTensorType>();
+    if (!resultShapeType || !resultShapeType.getElementType().isa<IndexType>())
+      return failure();
+
+    Location loc = dimOp->getLoc();
+    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
+        dimOp, resultShape,
+        rewriter.createOrFold<ConstantIndexOp>(loc, *dimIndex));
+    return success();
+  }
+};
+
+/// Fold dim of an operation that implements the ReifyRankedShapedTypeOpInterface
+template <typename OpTy>
+struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy dimOp,
+                                PatternRewriter &rewriter) const override {
+    OpResult dimValue = dimOp.source().template dyn_cast<OpResult>();
+    if (!dimValue)
+      return failure();
+    auto rankedShapeTypeOp =
+        dyn_cast<ReifyRankedShapedTypeOpInterface>(dimValue.getOwner());
+    if (!rankedShapeTypeOp)
+      return failure();
+
+    Optional<int64_t> dimIndex = dimOp.getConstantIndex();
+    if (!dimIndex)
+      return failure();
+
+    SmallVector<SmallVector<Value>> reifiedResultShapes;
+    if (failed(
+            rankedShapeTypeOp.reifyResultShapes(rewriter, reifiedResultShapes)))
+      return failure();
+
+    if (reifiedResultShapes.size() != rankedShapeTypeOp->getNumResults())
+      return failure();
+
+    unsigned resultNumber = dimValue.getResultNumber();
+    auto sourceType = dimValue.getType().dyn_cast<RankedTensorType>();
+    if (reifiedResultShapes[resultNumber].size() !=
+        static_cast<size_t>(sourceType.getRank()))
+      return failure();
+
+    rewriter.replaceOp(dimOp, reifiedResultShapes[resultNumber][*dimIndex]);
     return success();
   }
 };
@@ -104,12 +108,26 @@ namespace {
 #define GEN_PASS_CLASSES
 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
 
+struct ResolveRankedShapeTypeResultDimsPass final
+    : public ResolveRankedShapeTypeResultDimsBase<
+          ResolveRankedShapeTypeResultDimsPass> {
+  void runOnOperation() override;
+};
+
 struct ResolveShapedTypeResultDimsPass final
     : public ResolveShapedTypeResultDimsBase<ResolveShapedTypeResultDimsPass> {
   void runOnOperation() override;
 };
+
 } // namespace
 
+void memref::populateResolveRankedShapeTypeResultDimsPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<DimOfReifyRankedShapedTypeOpInterface<memref::DimOp>,
+               DimOfReifyRankedShapedTypeOpInterface<tensor::DimOp>>(
+      patterns.getContext());
+}
+
 void memref::populateResolveShapedTypeResultDimsPatterns(
     RewritePatternSet &patterns) {
   // TODO: Move tensor::DimOp pattern to the Tensor dialect.
@@ -118,8 +136,17 @@ void memref::populateResolveShapedTypeResultDimsPatterns(
       patterns.getContext());
 }
 
+void ResolveRankedShapeTypeResultDimsPass::runOnOperation() {
+  RewritePatternSet patterns(&getContext());
+  memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
+  if (failed(applyPatternsAndFoldGreedily(getOperation()->getRegions(),
+                                          std::move(patterns))))
+    return signalPassFailure();
+}
+
 void ResolveShapedTypeResultDimsPass::runOnOperation() {
   RewritePatternSet patterns(&getContext());
+  memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
   memref::populateResolveShapedTypeResultDimsPatterns(patterns);
   if (failed(applyPatternsAndFoldGreedily(getOperation()->getRegions(),
                                           std::move(patterns))))
@@ -129,3 +156,7 @@ void ResolveShapedTypeResultDimsPass::runOnOperation() {
 std::unique_ptr<Pass> memref::createResolveShapedTypeResultDimsPass() {
   return std::make_unique<ResolveShapedTypeResultDimsPass>();
 }
+
+std::unique_ptr<Pass> memref::createResolveRankedShapeTypeResultDimsPass() {
+  return std::make_unique<ResolveRankedShapeTypeResultDimsPass>();
+}
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index fd02faf..5dd3127 100644
@@ -1042,9 +1042,8 @@ OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute>) {
   return OpFoldResult();
 }
 
-LogicalResult InsertSliceOp::reifyReturnTypeShapesPerResultDim(
-    OpBuilder &builder,
-    SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
+LogicalResult InsertSliceOp::reifyResultShapes(
+    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
   reifiedReturnShapes.resize(1, SmallVector<Value>(getType().getRank()));
   for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
     reifiedReturnShapes[0][dim] =
diff --git a/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir b/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir
index db3fea6..1e71b2c 100644
@@ -55,34 +55,3 @@ func @result_shape_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
 //   CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[ARG_1]], %[[C0]]
 //   CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[ARG_0]], %[[C2]]
 //       CHECK:   return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]]
-
-// -----
-
-func @result_shape_and_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
-    -> (index, index, index, index, index) {
-  %c0 = constant 0 : index
-  %c1 = constant 1 : index
-  %c2 = constant 2 : index
-  %0:2 = "test.op_with_result_shape_and_per_dim_interface"(%arg0, %arg1)
-      : (tensor<2x3x?xf32>, tensor<?x5xf32>) -> (tensor<?x5xf32>, tensor<2x3x?xf32>)
-  %1 = tensor.dim %0#0, %c0 : tensor<?x5xf32>
-  %2 = tensor.dim %0#0, %c1 : tensor<?x5xf32>
-  %3 = tensor.dim %0#1, %c0 : tensor<2x3x?xf32>
-  %4 = tensor.dim %0#1, %c1 : tensor<2x3x?xf32>
-  %5 = tensor.dim %0#1, %c2 : tensor<2x3x?xf32>
-  return %1, %2, %3, %4, %5 : index, index, index, index, index
-}
-// CHECK-LABEL: func @result_shape_and_per_dim(
-//  CHECK-SAME:   %[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32>
-//  CHECK-SAME:   %[[ARG_1:[a-z0-9]*]]: tensor<?x5xf32>)
-//   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
-//   CHECK-DAG:   %[[C2:.+]] = constant 2 : index
-//   CHECK-DAG:   %[[C3:.+]] = constant 3 : index
-//   CHECK-DAG:   %[[C5:.+]] = constant 5 : index
-//   CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[ARG_1]], %[[C0]]
-//   CHECK-DAG:   %[[S0:.+]] = tensor.from_elements %[[D0]], %[[C5]]
-//   CHECK-DAG:   %[[D0_OUT:.+]] = tensor.extract %[[S0]][%[[C0]]]
-//   CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[ARG_0]], %[[C2]]
-//   CHECK-DAG:   %[[S1:.+]] = tensor.from_elements %[[C2]], %[[C3]], %[[D1]]
-//   CHECK-DAG:   %[[D1_OUT:.+]] = tensor.extract %[[S1]][%[[C2]]]
-//       CHECK:   return %[[D0_OUT]], %[[C5]], %[[C2]], %[[C3]], %[[D1_OUT]]
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 67fc2f7..7c470dd 100644
@@ -822,46 +822,8 @@ LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
   return success();
 }
 
-LogicalResult
-OpWithResultShapePerDimInterfaceOp ::reifyReturnTypeShapesPerResultDim(
-    OpBuilder &builder,
-    llvm::SmallVectorImpl<llvm::SmallVector<Value>> &shapes) {
-  Location loc = getLoc();
-  shapes.reserve(getNumOperands());
-  for (Value operand : llvm::reverse(getOperands())) {
-    auto currShape = llvm::to_vector<4>(llvm::map_range(
-        llvm::seq<int64_t>(
-            0, operand.getType().cast<RankedTensorType>().getRank()),
-        [&](int64_t dim) -> Value {
-          return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
-        }));
-    shapes.emplace_back(std::move(currShape));
-  }
-  return success();
-}
-
-LogicalResult OpWithResultShapeAndPerDimInterfaceOp::reifyReturnTypeShapes(
-    OpBuilder &builder, ValueRange operands,
-    llvm::SmallVectorImpl<Value> &shapes) {
-  Location loc = getLoc();
-  shapes.reserve(operands.size());
-  for (Value operand : llvm::reverse(operands)) {
-    auto currShape = llvm::to_vector<4>(llvm::map_range(
-        llvm::seq<int64_t>(
-            0, operand.getType().cast<RankedTensorType>().getRank()),
-        [&](int64_t dim) -> Value {
-          return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
-        }));
-    shapes.push_back(builder.create<tensor::FromElementsOp>(
-        getLoc(), builder.getIndexType(), currShape));
-  }
-  return success();
-}
-
-LogicalResult
-OpWithResultShapeAndPerDimInterfaceOp ::reifyReturnTypeShapesPerResultDim(
-    OpBuilder &builder,
-    llvm::SmallVectorImpl<llvm::SmallVector<Value>> &shapes) {
+LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
+    OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
   Location loc = getLoc();
   shapes.reserve(getNumOperands());
   for (Value operand : llvm::reverse(getOperands())) {
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index c59b9fa..16e141e 100644
@@ -579,16 +579,7 @@ def OpWithResultShapeInterfaceOp : TEST_Op<"op_with_result_shape_interface",
 
 def OpWithResultShapePerDimInterfaceOp :
     TEST_Op<"op_with_result_shape_per_dim_interface",
-        [DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-             ["reifyReturnTypeShapesPerResultDim"]>]> {
-  let arguments = (ins AnyRankedTensor:$operand1, AnyRankedTensor:$operand2);
-  let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2);
-}
-
-def OpWithResultShapeAndPerDimInterfaceOp :
-    TEST_Op<"op_with_result_shape_and_per_dim_interface",
-        [DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-             ["reifyReturnTypeShapes", "reifyReturnTypeShapesPerResultDim"]>]> {
+        [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
   let arguments = (ins AnyRankedTensor:$operand1, AnyRankedTensor:$operand2);
   let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2);
 }
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 470e3cb..3b8cc4b 100644
@@ -2046,6 +2046,7 @@ cc_library(
         ":Affine",
         ":DialectUtils",
         ":IR",
+        ":InferTypeOpInterface",
         ":LinalgInterfacesIncGen",
         ":LinalgStructuredOpsIncGen",
         ":MemRefDialect",