[mlir] Move `memref.dim` canonicalization using `InferShapedTypeOpInterface` to a...
authorMaheshRavishankar <ravishankarm@google.com>
Thu, 17 Jun 2021 05:12:16 +0000 (22:12 -0700)
committerMaheshRavishankar <ravishankarm@google.com>
Thu, 17 Jun 2021 05:13:11 +0000 (22:13 -0700)
Based on dicussion in
[this](https://llvm.discourse.group/t/remove-canonicalizer-for-memref-dim-via-shapedtypeopinterface/3641)
thread the pattern to resolve the `memref.dim` of a value that is a
result of an operation that implements the
`InferShapedTypeOpInterface` is moved to a separate pass instead of
running it as a canonicalization pass. This allows shape resolution to
happen when explicitly required, instead of automatically through a
canonicalization.

Differential Revision: https://reviews.llvm.org/D104321

15 files changed:
mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp [new file with mode: 0644]
mlir/test/Dialect/Linalg/canonicalize.mlir
mlir/test/Dialect/Linalg/fusion-sequence.mlir
mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir [new file with mode: 0644]
mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir [new file with mode: 0644]
mlir/test/Transforms/test-canonicalize.mlir
mlir/test/lib/Dialect/Test/CMakeLists.txt
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestOps.td

index 1eae023..153991b 100644 (file)
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
+
+class AffineDialect;
+namespace tensor {
+class TensorDialect;
+} // namespace tensor
+namespace vector {
+class VectorDialect;
+} // namespace vector
+
 namespace memref {
 
 //===----------------------------------------------------------------------===//
@@ -26,6 +35,11 @@ namespace memref {
 /// into `patterns`.
 void populateFoldSubViewOpPatterns(RewritePatternSet &patterns);
 
+/// Appends patterns that resolve `memref.dim` operations with values that are
+/// defined by operations that implement the `InferShapedTypeOpInterface`, in
+/// terms of shapes of its input operands.
+void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns);
+
 //===----------------------------------------------------------------------===//
 // Passes
 //===----------------------------------------------------------------------===//
@@ -34,6 +48,11 @@ void populateFoldSubViewOpPatterns(RewritePatternSet &patterns);
 /// load/store ops into `patterns`.
 std::unique_ptr<Pass> createFoldSubViewOpsPass();
 
+/// Creates an operation pass to resolve `memref.dim` operations with values
+/// that are defined by operations that implement the
+/// `InferShapedTypeOpInterface`, in terms of shapes of its input operands.
+std::unique_ptr<Pass> createResolveShapedTypeResultDimsPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
index d98d510..d7a7ddc 100644 (file)
@@ -23,6 +23,18 @@ def FoldSubViewOps : Pass<"fold-memref-subview-ops"> {
   ];
 }
 
+def ResolveShapedTypeResultDims : Pass<"resolve-shaped-type-result-dims"> {
+  let summary = "Resolve memref.dim of result values";
+  let description = [{
+    The pass resolves memref.dim of result of operations that
+    implement the `InferShapedTypeOpInterface` in terms of shapes of
+    its operands.
+  }];
+  let constructor = "mlir::memref::createResolveShapedTypeResultDimsPass()";
+  let dependentDialects = [
+    "memref::MemRefDialect", "tensor::TensorDialect"
+  ];
+}
 
 #endif // MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES
 
index b9f4dc9..9f59738 100644 (file)
@@ -794,84 +794,12 @@ struct DimOfCastOp : public OpRewritePattern<DimOp> {
     return success();
   }
 };
-
-/// Helper method to get the `Value` that is the shape of the `resultIdx`-th
-/// result at dimension `dimIndex` from the `ShapedTypeOpInterface`.
-/// TODO(ravishankarm): This is better put as a interface utility method
-/// somewhere, but that would imply the interface will depend on the `tensor`
-/// dialect. Ideally maybe a utility method in the `tensor` dialect.
-static Value getResultDimFromShapeInterface(OpBuilder &builder, OpResult result,
-                                            int64_t dimIndex) {
-  unsigned resultNumber = result.getResultNumber();
-  auto shapedTypeOp = dyn_cast<InferShapedTypeOpInterface>(result.getOwner());
-  Location loc = result.getOwner()->getLoc();
-  if (!shapedTypeOp)
-    return nullptr;
-
-  // The interface exposes two methods, one that returns the shape of all the
-  // results as `Value` and other that returns the shape as a list of
-  // `SmallVector<Value>`. The former takes precedence over the latter. So first
-  // check if the op implements the first interface method or the second, and
-  // get the value to use appropriately.
-  SmallVector<Value> reifiedResultShapes;
-  if (succeeded(shapedTypeOp.reifyReturnTypeShapes(
-          builder, result.getOwner()->getOperands(), reifiedResultShapes))) {
-    if (reifiedResultShapes.size() <= resultNumber)
-      return nullptr;
-    Value resultShape = reifiedResultShapes[resultNumber];
-    auto resultShapeType = resultShape.getType().dyn_cast<RankedTensorType>();
-    if (!resultShapeType || !resultShapeType.getElementType().isa<IndexType>())
-      return nullptr;
-    return builder.create<tensor::ExtractOp>(
-        loc, resultShape, builder.createOrFold<ConstantIndexOp>(loc, dimIndex));
-  }
-
-  SmallVector<SmallVector<Value>> reifiedResultShapesPerDim;
-  if (failed(shapedTypeOp.reifyReturnTypeShapesPerResultDim(
-          builder, reifiedResultShapesPerDim)))
-    return nullptr;
-  if (reifiedResultShapesPerDim.size() <= resultNumber ||
-      reifiedResultShapesPerDim[resultNumber].size() !=
-          static_cast<size_t>(result.getType().cast<ShapedType>().getRank()))
-    return nullptr;
-  OpFoldResult valueOrAttr = reifiedResultShapesPerDim[resultNumber][dimIndex];
-  if (auto attr = valueOrAttr.dyn_cast<Attribute>())
-    return builder.createOrFold<ConstantIndexOp>(
-        loc, attr.cast<IntegerAttr>().getInt());
-  return valueOrAttr.get<Value>();
-}
-
-/// Fold dim of an operation that implements the InferShapedTypeOpInterface
-struct DimOfShapedTypeOpInterface : public OpRewritePattern<DimOp> {
-  using OpRewritePattern<DimOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(DimOp dimOp,
-                                PatternRewriter &rewriter) const override {
-    OpResult dimValue = dimOp.memrefOrTensor().dyn_cast<OpResult>();
-    if (!dimValue)
-      return failure();
-    auto shapedTypeOp =
-        dyn_cast<InferShapedTypeOpInterface>(dimValue.getOwner());
-    if (!shapedTypeOp)
-      return failure();
-
-    Optional<int64_t> dimIndex = dimOp.getConstantIndex();
-    if (!dimIndex)
-      return failure();
-    Value replacement =
-        getResultDimFromShapeInterface(rewriter, dimValue, *dimIndex);
-    if (!replacement)
-      return failure();
-    rewriter.replaceOp(dimOp, replacement);
-    return success();
-  }
-};
 } // end anonymous namespace.
 
 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
   results.add<DimOfMemRefReshape, DimOfCastOp<BufferCastOp>,
-              DimOfCastOp<tensor::CastOp>, DimOfShapedTypeOpInterface>(context);
+              DimOfCastOp<tensor::CastOp>>(context);
 }
 
 // ---------------------------------------------------------------------------
index e795a86..672d897 100644 (file)
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRMemRefTransforms
   FoldSubViewOps.cpp
+  ResolveShapedTypeResultDims.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MemRef
@@ -9,9 +10,11 @@ add_mlir_dialect_library(MLIRMemRefTransforms
 
   LINK_LIBS PUBLIC
   MLIRAffine
+  MLIRInferTypeOpInterface
   MLIRMemRef
   MLIRPass
   MLIRStandard
+  MLIRTensor
   MLIRTransforms
   MLIRVector
 )
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
new file mode 100644 (file)
index 0000000..1b16efe
--- /dev/null
@@ -0,0 +1,127 @@
+//===- ResolveShapedTypeResultDims.cpp - Resolve memref.dim ops of result values
+//-------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass resolves `memref.dim` operations of result values in terms of
+// shapes of their operands using the `InferShapedTypeOpInterface`.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+/// Helper method to get the `Value` that is the shape of the `resultIdx`-th
+/// result at dimension `dimIndex` from the `ShapedTypeOpInterface`.
+/// TODO(ravishankarm): This is better put as a interface utility method
+/// somewhere, but that would imply the interface will depend on the `tensor`
+/// dialect. Ideally maybe a utility method in the `tensor` dialect.
+static Value getResultDimFromShapeInterface(OpBuilder &builder, OpResult result,
+                                            int64_t dimIndex) {
+  unsigned resultNumber = result.getResultNumber();
+  auto shapedTypeOp = dyn_cast<InferShapedTypeOpInterface>(result.getOwner());
+  Location loc = result.getOwner()->getLoc();
+  if (!shapedTypeOp)
+    return nullptr;
+
+  // The interface exposes two methods, one that returns the shape of all the
+  // results as `Value` and other that returns the shape as a list of
+  // `SmallVector<Value>`. The former takes precedence over the latter. So first
+  // check if the op implements the first interface method or the second, and
+  // get the value to use appropriately.
+  SmallVector<Value> reifiedResultShapes;
+  if (succeeded(shapedTypeOp.reifyReturnTypeShapes(
+          builder, result.getOwner()->getOperands(), reifiedResultShapes))) {
+    if (reifiedResultShapes.size() <= resultNumber)
+      return nullptr;
+    Value resultShape = reifiedResultShapes[resultNumber];
+    auto resultShapeType = resultShape.getType().dyn_cast<RankedTensorType>();
+    if (!resultShapeType || !resultShapeType.getElementType().isa<IndexType>())
+      return nullptr;
+    return builder.create<tensor::ExtractOp>(
+        loc, resultShape, builder.createOrFold<ConstantIndexOp>(loc, dimIndex));
+  }
+
+  SmallVector<SmallVector<Value>> reifiedResultShapesPerDim;
+  if (failed(shapedTypeOp.reifyReturnTypeShapesPerResultDim(
+          builder, reifiedResultShapesPerDim)))
+    return nullptr;
+  if (reifiedResultShapesPerDim.size() <= resultNumber ||
+      reifiedResultShapesPerDim[resultNumber].size() !=
+          static_cast<size_t>(result.getType().cast<ShapedType>().getRank()))
+    return nullptr;
+  OpFoldResult valueOrAttr = reifiedResultShapesPerDim[resultNumber][dimIndex];
+  if (auto attr = valueOrAttr.dyn_cast<Attribute>())
+    return builder.createOrFold<ConstantIndexOp>(
+        loc, attr.cast<IntegerAttr>().getInt());
+  return valueOrAttr.get<Value>();
+}
+
+namespace {
+/// Fold dim of an operation that implements the InferShapedTypeOpInterface
+struct DimOfShapedTypeOpInterface : public OpRewritePattern<memref::DimOp> {
+  using OpRewritePattern<memref::DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::DimOp dimOp,
+                                PatternRewriter &rewriter) const override {
+    OpResult dimValue = dimOp.memrefOrTensor().dyn_cast<OpResult>();
+    if (!dimValue)
+      return failure();
+    auto shapedTypeOp =
+        dyn_cast<InferShapedTypeOpInterface>(dimValue.getOwner());
+    if (!shapedTypeOp)
+      return failure();
+
+    Optional<int64_t> dimIndex = dimOp.getConstantIndex();
+    if (!dimIndex)
+      return failure();
+    Value replacement =
+        getResultDimFromShapeInterface(rewriter, dimValue, *dimIndex);
+    if (!replacement)
+      return failure();
+    rewriter.replaceOp(dimOp, replacement);
+    return success();
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pass registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+#define GEN_PASS_CLASSES
+#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
+
+struct ResolveShapedTypeResultDimsPass final
+    : public ResolveShapedTypeResultDimsBase<ResolveShapedTypeResultDimsPass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void memref::populateResolveShapedTypeResultDimsPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<DimOfShapedTypeOpInterface>(patterns.getContext());
+}
+
+void ResolveShapedTypeResultDimsPass::runOnOperation() {
+  RewritePatternSet patterns(&getContext());
+  memref::populateResolveShapedTypeResultDimsPatterns(patterns);
+  if (failed(applyPatternsAndFoldGreedily(getOperation()->getRegions(),
+                                          std::move(patterns))))
+    return signalPassFailure();
+}
+
+std::unique_ptr<Pass> memref::createResolveShapedTypeResultDimsPass() {
+  return std::make_unique<ResolveShapedTypeResultDimsPass>();
+}
index 029ac62..1689559 100644 (file)
@@ -532,205 +532,6 @@ func @init_tensor_canonicalize() -> (tensor<4x5x?xf32>) {
 
 // -----
 
-func @init_tensor_static_dim() -> (index, index) {
-  %c0 = constant 0 : index
-  %c2 = constant 2 : index
-  %c6 = constant 6 : index
-  %0 = linalg.init_tensor [4, 5, %c6] : tensor<4x5x?xf32>
-  %1 = memref.dim %0, %c2 : tensor<4x5x?xf32>
-  %2 = memref.dim %0, %c0 : tensor<4x5x?xf32>
-  return %1, %2 : index, index
-}
-//      CHECK: func @init_tensor_static_dim
-//  CHECK-DAG:   %[[C4:.+]] = constant 4 : index
-//  CHECK-DAG:   %[[C6:.+]] = constant 6 : index
-//      CHECK:   return %[[C6]], %[[C4]]
-
-// -----
-
-func @init_tensor_dynamic_dim(%arg0 : index) -> (index) {
-  %c2 = constant 2 : index
-  %0 = linalg.init_tensor [4, 5, %arg0] : tensor<4x5x?xf32>
-  %1 = memref.dim %0, %c2 : tensor<4x5x?xf32>
-  return %1 : index
-}
-//      CHECK: func @init_tensor_dynamic_dim
-// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: index
-//      CHECK:   return %[[ARG0]]
-
-// -----
-
-func @init_tensor_dynamic_dim2(%arg0 : index, %arg1 : index) -> (index, index) {
-  %c0 = constant 0 : index
-  %c1 = constant 1 : index
-  %0 = linalg.init_tensor [%arg0, %arg1] : tensor<?x?xf32>
-  %1 = memref.dim %0, %c0 : tensor<?x?xf32>
-  %2 = memref.dim %0, %c1 : tensor<?x?xf32>
-  return %1, %2 : index, index
-}
-//      CHECK: func @init_tensor_dynamic_dim2
-// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
-//      CHECK:   return %[[ARG0]], %[[ARG1]]
-
-// -----
-
-func @remove_dim_result_uses
-  (%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
-   %arg2 : tensor<?x?xf32>) -> (index, index) {
-  %c0 = constant 0 : index
-  %c1 = constant 1 : index
-  %0 = linalg.generic
-    {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
-                      affine_map<(d0, d1, d2) -> (d2, d1)>,
-                      affine_map<(d0, d1, d2) -> (d0 + d1, d1 - d0)>],
-     iterator_types = ["parallel", "parallel", "reduction"]}
-    ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
-    outs(%arg2 : tensor<?x?xf32>) {
-    ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
-      %1 = mulf %arg3, %arg4 : f32
-      %2 = addf %1, %arg5 : f32
-      linalg.yield %2 : f32
-    } -> tensor<?x?xf32>
-  %3 = memref.dim %0, %c0 : tensor<?x?xf32>
-  %4 = memref.dim %0, %c1 : tensor<?x?xf32>
-  return %3, %4 : index, index
-}
-//       CHECK: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
-//       CHECK: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (-s0 + s1)>
-//       CHECK: func @remove_dim_result_uses
-//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-//   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
-//   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
-//   CHECK-DAG:   %[[T0:.+]] = memref.dim %[[ARG0]], %[[C0]]
-//   CHECK-DAG:   %[[T1:.+]] = memref.dim %[[ARG1]], %[[C1]]
-//       CHECK:   %[[T2:.+]] = affine.apply #[[MAP0]]()[%[[T0]], %[[T1]]]
-//   CHECK-DAG:   %[[T3:.+]] = memref.dim %[[ARG0]], %[[C0]]
-//   CHECK-DAG:   %[[T4:.+]] = memref.dim %[[ARG1]], %[[C1]]
-//       CHECK:   %[[T5:.+]] = affine.apply #[[MAP1]]()[%[[T3]], %[[T4]]]
-//       CHECK:   return %[[T2]], %[[T5]]
-
-// -----
-
-func @remove_dim_result_uses_outs
-  (%arg0 : tensor<?xf32>, %arg1 : index) -> (index) {
-  %c0 = constant 0 : index
-  %c1 = constant 1 : index
-  %d0 = memref.dim %arg0, %c0 : tensor<?xf32>
-  %0 = linalg.init_tensor [%d0, %arg1] : tensor<?x?xf32>
-  %1 = linalg.generic
-    {indexing_maps = [affine_map<(d0, d1) -> (d0)>,
-                      affine_map<(d0, d1) -> (d0, d1)>],
-     iterator_types = ["parallel", "parallel"]}
-    ins(%arg0 : tensor<?xf32>) outs(%0 : tensor<?x?xf32>) {
-    ^bb0(%arg2: f32, %arg3: f32) :
-      linalg.yield %arg2 : f32
-    } -> tensor<?x?xf32>
-  %2 = memref.dim %1, %c1 : tensor<?x?xf32>
-  return %2 : index
-}
-//      CHECK: func @remove_dim_result_uses_outs
-// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
-//      CHECK:   return %[[ARG1]]
-
-// -----
-
-func @remove_dim_result_uses_sequence
-  (%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
-   %arg2 : tensor<?x?xf32>) -> (index, index, index, index) {
-  %c0 = constant 0 : index
-  %c1 = constant 1 : index
-  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
-    outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
-  %1 = memref.dim %0, %c0 : tensor<?x?xf32>
-  %2 = memref.dim %0, %c1 : tensor<?x?xf32>
-  %3 = linalg.generic
-    {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0)>,
-                      affine_map<(d0, d1, d2) -> (d0, d2)>,
-                      affine_map<(d0, d1, d2) -> (d0, d2)>],
-     iterator_types = ["parallel", "reduction", "parallel"]}
-    ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
-    outs(%0 : tensor<?x?xf32>) {
-    ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
-      %4 = mulf %arg3, %arg4 : f32
-      %5 = addf %4, %arg5 : f32
-      linalg.yield %5 : f32
-    } -> tensor<?x?xf32>
-  %6 = memref.dim %3, %c0 : tensor<?x?xf32>
-  %7 = memref.dim %3, %c1 : tensor<?x?xf32>
-  return %1, %2, %6, %7 : index, index, index, index
-}
-// CHECK-LABEL: func @remove_dim_result_uses_sequence
-//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-//   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
-//   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
-//   CHECK-DAG:   %[[T0:.+]] = memref.dim %[[ARG0]], %[[C0]]
-//   CHECK-DAG:   %[[T1:.+]] = memref.dim %[[ARG1]], %[[C1]]
-//   CHECK-DAG:   %[[T2:.+]] = memref.dim %[[ARG0]], %[[C1]]
-//   CHECK-DAG:   %[[T3:.+]] = memref.dim %[[ARG1]], %[[C1]]
-//       CHECK:   return %[[T0]], %[[T1]], %[[T2]], %[[T3]]
-
-// -----
-
-func @keep_result_dim_uses_sequence2
-  (%arg0 : tensor<?xf32>, %arg1 : index) -> (index, index) {
-  %c0 = constant 0 : index
-  %c1 = constant 1 : index
-  %d0 = memref.dim %arg0, %c0 : tensor<?xf32>
-  %0 = linalg.init_tensor [%d0, %arg1] : tensor<?x?xf32>
-  %1 = linalg.generic
-    {indexing_maps = [affine_map<(d0, d1) -> (d0)>,
-                      affine_map<(d0, d1) -> (d0, d1)>],
-     iterator_types = ["parallel", "parallel"]}
-    ins(%arg0 : tensor<?xf32>) outs(%0 : tensor<?x?xf32>) {
-    ^bb0(%arg2: f32, %arg3 : f32):
-      linalg.yield %arg2 : f32
-    } -> tensor<?x?xf32>
-  %2 = memref.dim %1, %c0 : tensor<?x?xf32>
-  %3 = memref.dim %1, %c1 : tensor<?x?xf32>
-  return %2, %3 : index, index
-}
-//       CHECK: func @keep_result_dim_uses_sequence2
-//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?xf32>
-//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
-//   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
-//   CHECK-DAG:   %[[T0:.+]] = memref.dim %[[ARG0]], %[[C0]]
-//       CHECK:   return %[[T0]], %[[ARG1]]
-
-// -----
-
-#map = affine_map<(d0) -> (d0)>
-
-func @init_tensor_dim_of_linalg_result(%arg_0 : tensor<?xf32>,
-    %arg_1: tensor<?xf32>) -> (index, index) {
-  %0, %1 = linalg.generic {
-    indexing_maps = [#map, #map, #map],
-    iterator_types = ["parallel"]
-  } ins(%arg_0 : tensor<?xf32>)
-    outs(%arg_0, %arg_1 : tensor<?xf32>, tensor<?xf32>) {
-  ^bb0(%in: f32, %out_0: f32, %out_1: f32):
-    linalg.yield %in, %in : f32, f32
-  } -> (tensor<?xf32>, tensor<?xf32>)
-
-  %c0 = constant 0 : index
-  %num_elem_0 = memref.dim %0, %c0 : tensor<?xf32>
-
-  %num_elem_1 = memref.dim %1, %c0 : tensor<?xf32>
-  return %num_elem_0, %num_elem_1 : index, index
-}
-//      CHECK: func @init_tensor_dim_of_linalg_result(
-// CHECK-SAME:   %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?xf32>
-// CHECK-SAME:   %[[ARG_1:[a-zA-Z0-9_]+]]: tensor<?xf32>)
-//      CHECK:   %[[R0:.+]] = memref.dim %[[ARG_0]]
-//      CHECK:   %[[R1:.+]] = memref.dim %[[ARG_0]]
-//      CHECK:   return %[[R0]], %[[R1]]
-
-// -----
-
 func @init_tensor_reshape_expansion(%arg0 : index) -> tensor<2x3x5x4x?x7xf32> {
   %0 = linalg.init_tensor [6, 5, %arg0] : tensor<6x5x?xf32>
   %1 = linalg.tensor_expand_shape %0 [[0, 1], [2], [3, 4, 5]]
@@ -740,9 +541,12 @@ func @init_tensor_reshape_expansion(%arg0 : index) -> tensor<2x3x5x4x?x7xf32> {
 //      CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
 //      CHECK: func @init_tensor_reshape_expansion
 // CHECK-SAME:   %[[ARG0:.+]]: index
-//      CHECK:   %[[T0:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
-//      CHECK:   %[[T1:.+]] = linalg.init_tensor [2, 3, 5, 4, %[[T0]], 7]
-//      CHECK:   return %[[T1]]
+//      CHECK:   %[[C2:.+]] = constant 2
+//      CHECK:   %[[INIT1:.+]] = linalg.init_tensor [6, 5, %[[ARG0]]]
+//      CHECK:   %[[D0:.+]] = memref.dim %[[INIT1]], %[[C2]]
+//      CHECK:   %[[T0:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
+//      CHECK:   %[[INIT2:.+]] = linalg.init_tensor [2, 3, 5, 4, %[[T0]], 7]
+//      CHECK:   return %[[INIT2]]
 
 // -----
 
@@ -755,9 +559,12 @@ func @init_tensor_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> {
 //      CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 28)>
 //      CHECK: func @init_tensor_reshape_collapse
 // CHECK-SAME:   %[[ARG0:.+]]: index
-//      CHECK:   %[[T0:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
-//      CHECK:   %[[T1:.+]] = linalg.init_tensor [6, 5, %[[T0]]]
-//      CHECK:   return %[[T1]]
+//      CHECK:   %[[C4:.+]] = constant 4
+//      CHECK:   %[[INIT1:.+]] = linalg.init_tensor [2, 3, 5, 4, %[[ARG0]], 7]
+//      CHECK:   %[[D0:.+]] = memref.dim %[[INIT1]], %[[C4]]
+//      CHECK:   %[[T0:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
+//      CHECK:   %[[INIT2:.+]] = linalg.init_tensor [6, 5, %[[T0]]]
+//      CHECK:   return %[[INIT2]]
 
 // -----
 
@@ -906,54 +713,6 @@ func @pad_tensor_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
   } : tensor<5x6xf32> to tensor<5x6xf32>
   return %0 : tensor<5x6xf32>
 }
-
-// -----
-
-func @dim_reshape_expansion(%arg0 : tensor<6x5x?xf32>) -> (index, index, index)
-{
-  %c1 = constant 1 : index
-  %c3 = constant 3 : index
-  %c4 = constant 4 : index
-  %0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2], [3, 4, 5]]
-      : tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
-  %1 = memref.dim %0, %c1 : tensor<2x3x5x4x?x7xf32>
-  %2 = memref.dim %0, %c3 : tensor<2x3x5x4x?x7xf32>
-  %3 = memref.dim %0, %c4 : tensor<2x3x5x4x?x7xf32>
-  return %1, %2, %3 : index, index, index
-}
-//      CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
-//      CHECK: func @dim_reshape_expansion
-// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<6x5x?xf32>
-//  CHECK-DAG:   %[[C2:.+]] = constant 2 : index
-//  CHECK-DAG:   %[[C3:.+]] = constant 3 : index
-//  CHECK-DAG:   %[[C4:.+]] = constant 4 : index
-//      CHECK:   %[[D0:.+]] = memref.dim %[[ARG0]], %[[C2]]
-//      CHECK:   %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
-//      CHECK:   return %[[C3]], %[[C4]], %[[D1]]
-
-// -----
-
-func @dim_reshape_collapse(%arg0 : tensor<2x3x5x4x?x7xf32>) -> (index, index)
-{
-  %c1 = constant 1 : index
-  %c2 = constant 2 : index
-  %0 = linalg.tensor_collapse_shape %arg0 [[0, 1], [2], [3, 4, 5]]
-      : tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32>
-  %1 = memref.dim %0, %c1 : tensor<6x5x?xf32>
-  %2 = memref.dim %0, %c2 : tensor<6x5x?xf32>
-  return %1, %2 : index, index
-}
-//      CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 28)>
-//      CHECK: func @dim_reshape_collapse
-// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x3x5x4x?x7xf32>
-//  CHECK-DAG:   %[[C4:.+]] = constant 4 : index
-//  CHECK-DAG:   %[[C5:.+]] = constant 5 : index
-//      CHECK:   %[[D0:.+]] = memref.dim %[[ARG0]], %[[C4]]
-//      CHECK:   %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
-//      CHECK:   return %[[C5]], %[[D1]]
-
-// -----
-
 func @propogate_casts(%arg0 : tensor<?x?xf32>, %arg1 : f32, %arg2 : index,
     %arg3 : index) -> tensor<?x?xf32> {
   %c0 = constant 0 : index
@@ -1083,41 +842,6 @@ func @fold_tiled_loop_inputs(%A: memref<192xf32>, %A_tensor: tensor<192xf32>,
 
 // -----
 
-func @dim_of_pad_op(%arg0 : tensor<2x?x?xf32>, %arg1 : index, %arg2 : index,
-    %arg3: f32) -> (index, index, index)
-{
-   %c0 = constant 0 : index
-   %c1 = constant 1 : index
-   %c2 = constant 2 : index
-   %c3 = constant 3 : index
-   %c4 = constant 4 : index
-   %c5 = constant 5 : index
-   %0 = linalg.pad_tensor %arg0 low[%c3, %arg1, %c4] high[7, %c5, %arg2] {
-     ^bb0(%arg4: index, %arg5: index, %arg6: index):
-       linalg.yield %arg3 : f32
-   } : tensor<2x?x?xf32> to tensor<?x?x?xf32>
-   %1 = memref.dim %0, %c0 : tensor<?x?x?xf32>
-   %2 = memref.dim %0, %c1 : tensor<?x?x?xf32>
-   %3 = memref.dim %0, %c2 : tensor<?x?x?xf32>
-   return %1, %2, %3 : index, index, index
-}
-//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 5)>
-//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 4)>
-//      CHECK: func @dim_of_pad_op
-// CHECK-SAME:   %[[ARG0:[A-Za-z0-9_]+]]: tensor<2x?x?xf32>
-// CHECK-SAME:   %[[ARG1:[A-Za-z0-9_]+]]: index
-// CHECK-SAME:   %[[ARG2:[A-Za-z0-9_]+]]: index
-//  CHECK-DAG:   %[[C1:.+]] = constant 1 : index
-//  CHECK-DAG:   %[[C2:.+]] = constant 2 : index
-//  CHECK-DAG:   %[[C12:.+]] = constant 12 : index
-//      CHECK:   %[[IN_DIM1:.+]] = memref.dim %[[ARG0]], %[[C1]]
-//      CHECK:   %[[OUT_DIM1:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]], %[[IN_DIM1]]]
-//      CHECK:   %[[IN_DIM2:.+]] = memref.dim %[[ARG0]], %[[C2]]
-//      CHECK:   %[[OUT_DIM2:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[IN_DIM2]]]
-//      CHECK:   return %[[C12]], %[[OUT_DIM1]], %[[OUT_DIM2]]
-
-// -----
-
 #map = affine_map<(d0, d1) -> (d0, d1)>
 
 func @indexed_generic(%arg0: memref<?x?xindex>, %arg1: memref<?x?xindex>) {
index 3321595..8455991 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -pass-pipeline="func(test-linalg-tile-and-fuse{tile-sizes=16,32,64}),canonicalize,cse" -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -pass-pipeline="func(test-linalg-tile-and-fuse{tile-sizes=16,32,64}),resolve-shaped-type-result-dims,canonicalize,cse" -split-input-file %s | FileCheck %s
 
 module {
   func @three_op_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
index 730482d..4ea25df 100644 (file)
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -test-linalg-tensor-fusion-transform-patterns -canonicalize -cse --split-input-file | FileCheck %s
-// RUN: mlir-opt %s -test-linalg-tiled-loop-fusion-transform-patterns -canonicalize -cse  --split-input-file | FileCheck %s --check-prefix=TLOOP
+// RUN: mlir-opt %s -test-linalg-tensor-fusion-transform-patterns -resolve-shaped-type-result-dims -canonicalize -cse --split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-tiled-loop-fusion-transform-patterns -resolve-shaped-type-result-dims -canonicalize -cse  --split-input-file | FileCheck %s --check-prefix=TLOOP
 
 module {
   func @matmul_fusion(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
diff --git a/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir b/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
new file mode 100644 (file)
index 0000000..0bcf601
--- /dev/null
@@ -0,0 +1,278 @@
+// RUN: mlir-opt -resolve-shaped-type-result-dims -split-input-file %s | FileCheck %s
+
+func @init_tensor_static_dim() -> (index, index) {
+  %c0 = constant 0 : index
+  %c2 = constant 2 : index
+  %c6 = constant 6 : index
+  %0 = linalg.init_tensor [4, 5, %c6] : tensor<4x5x?xf32>
+  %1 = memref.dim %0, %c2 : tensor<4x5x?xf32>
+  %2 = memref.dim %0, %c0 : tensor<4x5x?xf32>
+  return %1, %2 : index, index
+}
+//      CHECK: func @init_tensor_static_dim
+//  CHECK-DAG:   %[[C4:.+]] = constant 4 : index
+//  CHECK-DAG:   %[[C6:.+]] = constant 6 : index
+//      CHECK:   return %[[C6]], %[[C4]]
+
+// -----
+
+func @init_tensor_dynamic_dim(%arg0 : index) -> (index) {
+  %c2 = constant 2 : index
+  %0 = linalg.init_tensor [4, 5, %arg0] : tensor<4x5x?xf32>
+  %1 = memref.dim %0, %c2 : tensor<4x5x?xf32>
+  return %1 : index
+}
+//      CHECK: func @init_tensor_dynamic_dim
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: index
+//      CHECK:   return %[[ARG0]]
+
+// -----
+
+func @init_tensor_dynamic_dim2(%arg0 : index, %arg1 : index) -> (index, index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %0 = linalg.init_tensor [%arg0, %arg1] : tensor<?x?xf32>
+  %1 = memref.dim %0, %c0 : tensor<?x?xf32>
+  %2 = memref.dim %0, %c1 : tensor<?x?xf32>
+  return %1, %2 : index, index
+}
+//      CHECK: func @init_tensor_dynamic_dim2
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+//      CHECK:   return %[[ARG0]], %[[ARG1]]
+
+// -----
+
+func @remove_dim_result_uses
+  (%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+   %arg2 : tensor<?x?xf32>) -> (index, index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %0 = linalg.generic
+    {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                      affine_map<(d0, d1, d2) -> (d2, d1)>,
+                      affine_map<(d0, d1, d2) -> (d0 + d1, d1 - d0)>],
+     iterator_types = ["parallel", "parallel", "reduction"]}
+    ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%arg2 : tensor<?x?xf32>) {
+    ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
+      %1 = mulf %arg3, %arg4 : f32
+      %2 = addf %1, %arg5 : f32
+      linalg.yield %2 : f32
+    } -> tensor<?x?xf32>
+  %3 = memref.dim %0, %c0 : tensor<?x?xf32>
+  %4 = memref.dim %0, %c1 : tensor<?x?xf32>
+  return %3, %4 : index, index
+}
+//       CHECK: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+//       CHECK: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s1 - s0)>
+//       CHECK: func @remove_dim_result_uses
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//   CHECK-DAG:   %[[T0:.+]] = memref.dim %[[ARG0]], %[[C0]]
+//   CHECK-DAG:   %[[T1:.+]] = memref.dim %[[ARG1]], %[[C1]]
+//       CHECK:   %[[T2:.+]] = affine.apply #[[MAP0]]()[%[[T0]], %[[T1]]]
+//   CHECK-DAG:   %[[T3:.+]] = memref.dim %[[ARG0]], %[[C0]]
+//   CHECK-DAG:   %[[T4:.+]] = memref.dim %[[ARG1]], %[[C1]]
+//       CHECK:   %[[T5:.+]] = affine.apply #[[MAP1]]()[%[[T3]], %[[T4]]]
+//       CHECK:   return %[[T2]], %[[T5]]
+
+// -----
+
+func @remove_dim_result_uses_outs
+  (%arg0 : tensor<?xf32>, %arg1 : index) -> (index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %d0 = memref.dim %arg0, %c0 : tensor<?xf32>
+  %0 = linalg.init_tensor [%d0, %arg1] : tensor<?x?xf32>
+  %1 = linalg.generic
+    {indexing_maps = [affine_map<(d0, d1) -> (d0)>,
+                      affine_map<(d0, d1) -> (d0, d1)>],
+     iterator_types = ["parallel", "parallel"]}
+    ins(%arg0 : tensor<?xf32>) outs(%0 : tensor<?x?xf32>) {
+    ^bb0(%arg2: f32, %arg3: f32) :
+      linalg.yield %arg2 : f32
+    } -> tensor<?x?xf32>
+  %2 = memref.dim %1, %c1 : tensor<?x?xf32>
+  return %2 : index
+}
+//      CHECK: func @remove_dim_result_uses_outs
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+//      CHECK:   return %[[ARG1]]
+
+// -----
+
+func @remove_dim_result_uses_sequence
+  (%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+   %arg2 : tensor<?x?xf32>) -> (index, index, index, index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %1 = memref.dim %0, %c0 : tensor<?x?xf32>
+  %2 = memref.dim %0, %c1 : tensor<?x?xf32>
+  %3 = linalg.generic
+    {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0)>,
+                      affine_map<(d0, d1, d2) -> (d0, d2)>,
+                      affine_map<(d0, d1, d2) -> (d0, d2)>],
+     iterator_types = ["parallel", "reduction", "parallel"]}
+    ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%0 : tensor<?x?xf32>) {
+    ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
+      %4 = mulf %arg3, %arg4 : f32
+      %5 = addf %4, %arg5 : f32
+      linalg.yield %5 : f32
+    } -> tensor<?x?xf32>
+  %6 = memref.dim %3, %c0 : tensor<?x?xf32>
+  %7 = memref.dim %3, %c1 : tensor<?x?xf32>
+  return %1, %2, %6, %7 : index, index, index, index
+}
+// CHECK-LABEL: func @remove_dim_result_uses_sequence
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//   CHECK-DAG:   %[[T0:.+]] = memref.dim %[[ARG0]], %[[C0]]
+//   CHECK-DAG:   %[[T1:.+]] = memref.dim %[[ARG1]], %[[C1]]
+//   CHECK-DAG:   %[[T2:.+]] = memref.dim %[[ARG0]], %[[C1]]
+//   CHECK-DAG:   %[[T3:.+]] = memref.dim %[[ARG1]], %[[C1]]
+//       CHECK:   return %[[T0]], %[[T1]], %[[T2]], %[[T3]]
+
+// -----
+
+func @keep_result_dim_uses_sequence2
+  (%arg0 : tensor<?xf32>, %arg1 : index) -> (index, index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %d0 = memref.dim %arg0, %c0 : tensor<?xf32>
+  %0 = linalg.init_tensor [%d0, %arg1] : tensor<?x?xf32>
+  %1 = linalg.generic
+    {indexing_maps = [affine_map<(d0, d1) -> (d0)>,
+                      affine_map<(d0, d1) -> (d0, d1)>],
+     iterator_types = ["parallel", "parallel"]}
+    ins(%arg0 : tensor<?xf32>) outs(%0 : tensor<?x?xf32>) {
+    ^bb0(%arg2: f32, %arg3 : f32):
+      linalg.yield %arg2 : f32
+    } -> tensor<?x?xf32>
+  %2 = memref.dim %1, %c0 : tensor<?x?xf32>
+  %3 = memref.dim %1, %c1 : tensor<?x?xf32>
+  return %2, %3 : index, index
+}
+//       CHECK: func @keep_result_dim_uses_sequence2
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+//   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//   CHECK-DAG:   %[[T0:.+]] = memref.dim %[[ARG0]], %[[C0]]
+//       CHECK:   return %[[T0]], %[[ARG1]]
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+
+func @init_tensor_dim_of_linalg_result(%arg_0 : tensor<?xf32>,
+    %arg_1: tensor<?xf32>) -> (index, index) {
+  %0, %1 = linalg.generic {
+    indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel"]
+  } ins(%arg_0 : tensor<?xf32>)
+    outs(%arg_0, %arg_1 : tensor<?xf32>, tensor<?xf32>) {
+  ^bb0(%in: f32, %out_0: f32, %out_1: f32):
+    linalg.yield %in, %in : f32, f32
+  } -> (tensor<?xf32>, tensor<?xf32>)
+
+  %c0 = constant 0 : index
+  %num_elem_0 = memref.dim %0, %c0 : tensor<?xf32>
+
+  %num_elem_1 = memref.dim %1, %c0 : tensor<?xf32>
+  return %num_elem_0, %num_elem_1 : index, index
+}
+//      CHECK: func @init_tensor_dim_of_linalg_result(
+// CHECK-SAME:   %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?xf32>
+// CHECK-SAME:   %[[ARG_1:[a-zA-Z0-9_]+]]: tensor<?xf32>)
+//      CHECK:   %[[R0:.+]] = memref.dim %[[ARG_0]]
+//      CHECK:   %[[R1:.+]] = memref.dim %[[ARG_0]]
+//      CHECK:   return %[[R0]], %[[R1]]
+
+// -----
+
+func @dim_reshape_expansion(%arg0 : tensor<6x5x?xf32>) -> (index, index, index)
+{
+  %c1 = constant 1 : index
+  %c3 = constant 3 : index
+  %c4 = constant 4 : index
+  %0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2], [3, 4, 5]]
+      : tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
+  %1 = memref.dim %0, %c1 : tensor<2x3x5x4x?x7xf32>
+  %2 = memref.dim %0, %c3 : tensor<2x3x5x4x?x7xf32>
+  %3 = memref.dim %0, %c4 : tensor<2x3x5x4x?x7xf32>
+  return %1, %2, %3 : index, index, index
+}
+//      CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
+//      CHECK: func @dim_reshape_expansion
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<6x5x?xf32>
+//  CHECK-DAG:   %[[C2:.+]] = constant 2 : index
+//  CHECK-DAG:   %[[C3:.+]] = constant 3 : index
+//  CHECK-DAG:   %[[C4:.+]] = constant 4 : index
+//      CHECK:   %[[D0:.+]] = memref.dim %[[ARG0]], %[[C2]]
+//      CHECK:   %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
+//      CHECK:   return %[[C3]], %[[C4]], %[[D1]]
+
+// -----
+
+func @dim_reshape_collapse(%arg0 : tensor<2x3x5x4x?x7xf32>) -> (index, index)
+{
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %0 = linalg.tensor_collapse_shape %arg0 [[0, 1], [2], [3, 4, 5]]
+      : tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32>
+  %1 = memref.dim %0, %c1 : tensor<6x5x?xf32>
+  %2 = memref.dim %0, %c2 : tensor<6x5x?xf32>
+  return %1, %2 : index, index
+}
+//      CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 28)>
+//      CHECK: func @dim_reshape_collapse
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x3x5x4x?x7xf32>
+//  CHECK-DAG:   %[[C4:.+]] = constant 4 : index
+//  CHECK-DAG:   %[[C5:.+]] = constant 5 : index
+//      CHECK:   %[[D0:.+]] = memref.dim %[[ARG0]], %[[C4]]
+//      CHECK:   %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
+//      CHECK:   return %[[C5]], %[[D1]]
+
+// -----
+
+func @dim_of_pad_op(%arg0 : tensor<2x?x?xf32>, %arg1 : index, %arg2 : index,
+    %arg3: f32) -> (index, index, index)
+{
+   %c0 = constant 0 : index
+   %c1 = constant 1 : index
+   %c2 = constant 2 : index
+   %c3 = constant 3 : index
+   %c4 = constant 4 : index
+   %c5 = constant 5 : index
+   %0 = linalg.pad_tensor %arg0 low[%c3, %arg1, %c4] high[7, %c5, %arg2] {
+     ^bb0(%arg4: index, %arg5: index, %arg6: index):
+       linalg.yield %arg3 : f32
+   } : tensor<2x?x?xf32> to tensor<?x?x?xf32>
+   %1 = memref.dim %0, %c0 : tensor<?x?x?xf32>
+   %2 = memref.dim %0, %c1 : tensor<?x?x?xf32>
+   %3 = memref.dim %0, %c2 : tensor<?x?x?xf32>
+   return %1, %2, %3 : index, index, index
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s1 + s0 + 5)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s1 + s0 + 4)>
+//      CHECK: func @dim_of_pad_op
+// CHECK-SAME:   %[[ARG0:[A-Za-z0-9_]+]]: tensor<2x?x?xf32>
+// CHECK-SAME:   %[[ARG1:[A-Za-z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG2:[A-Za-z0-9_]+]]: index
+//  CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//  CHECK-DAG:   %[[C2:.+]] = constant 2 : index
+//  CHECK-DAG:   %[[C12:.+]] = constant 12 : index
+//      CHECK:   %[[IN_DIM1:.+]] = memref.dim %[[ARG0]], %[[C1]]
+//      CHECK:   %[[OUT_DIM1:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]], %[[IN_DIM1]]]
+//      CHECK:   %[[IN_DIM2:.+]] = memref.dim %[[ARG0]], %[[C2]]
+//      CHECK:   %[[OUT_DIM2:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[IN_DIM2]]]
+//      CHECK:   return %[[C12]], %[[OUT_DIM1]], %[[OUT_DIM2]]
index 9526521..bffbca1 100644 (file)
@@ -205,16 +205,14 @@ func @conv_tensors_dynamic(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?x
 
 // CHECK: #[[BOUND8_MAP:.+]] = affine_map<(d0)[s0] -> (8, -d0 + s0)>
 // CHECK: #[[BOUND8_MAP_2:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s0, 8, -d0 + s1)>
-// CHECK: #[[BOUND8_MAP_3:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 8)>
 // CHECK: #[[BOUND16_MAP:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
 // CHECK: #[[X2_MAP:.+]] = affine_map<(d0) -> (d0 * 2)>
 // CHECK: #[[INPUT_BOUND:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * 2 + s0 - 2, d1 * -2 + s1)>
-// CHECK: #[[BOUND16_MAP_2:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 16)>
+// CHECK: #[[BOUND16_MAP_2:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s0, 16, -d0 + s1)>
 // CHECK: #[[BOUND4_MAP:.+]] = affine_map<(d0)[s0] -> (4, -d0 + s0)>
 // CHECK: #[[BOUND2_MAP:.+]] = affine_map<(d0)[s0] -> (2, -d0 + s0)>
-// CHECK: #[[BOUND4_MAP_2:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 4)>
+// CHECK: #[[BOUND4_MAP_2:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s0, 4, -d0 + s1)>
 // CHECK: #[[BOUND2_MAP_2:.+]] = affine_map<(d0, d1)[s0, s1] -> (-d0 + s0, 2, -d1 + s1)>
-// CHECK: #[[BOUND2_MAP_3:.+]] = affine_map<(d0, d1)[s0] -> (-d0 + s0, 2, -d1 + s0)>
 
 //      CHECK: func @conv_tensors_dynamic
 // CHECK-SAME: (%[[INPUT]]: tensor<?x?x?x?xf32>, %[[FILTER]]: tensor<?x?x?x?xf32>, %[[ELEM]]: tensor<?x?x?x?xf32>)
@@ -240,16 +238,20 @@ func @conv_tensors_dynamic(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?x
 //  CHECK-DAG:   %[[INPUT_C:.+]] = memref.dim %[[INPUT]], %[[C3]] : tensor<?x?x?x?xf32>
 //  CHECK-DAG:   %[[FILTER_IC:.+]] = memref.dim %[[FILTER]], %[[C2]] : tensor<?x?x?x?xf32>
 //  CHECK-DAG:   %[[FILTER_OC:.+]] = memref.dim %[[FILTER]], %[[C3]] : tensor<?x?x?x?xf32>
+//  CHECK-DAG:   %[[FILL_N:.+]] = memref.dim %[[FILL]], %[[C0]] : tensor<?x?x?x?xf32>
+//  CHECK-DAG:   %[[FILL_H:.+]] = memref.dim %[[FILL]], %[[C1]] : tensor<?x?x?x?xf32>
+//  CHECK-DAG:   %[[FILL_W:.+]] = memref.dim %[[FILL]], %[[C2]] : tensor<?x?x?x?xf32>
+//  CHECK-DAG:   %[[FILL_C:.+]] = memref.dim %[[FILL]], %[[C3]] : tensor<?x?x?x?xf32>
 
 //      CHECK:   scf.for %[[IV0:.+]] = %{{.+}} to %[[ELEM_N]] step %{{.+}} iter_args(%{{.+}} = %[[FILL]])
 // CHECK-NEXT:     %[[SIZE_ELEM_N:.+]] = affine.min #[[BOUND8_MAP]](%[[IV0]])[%[[ELEM_N]]]
 // CHECK-NEXT:     %[[SIZE_INPUT_N:.+]] = affine.min #[[BOUND8_MAP_2]](%[[IV0]])[%[[INPUT_N]], %[[ELEM_N]]]
-// CHECK-NEXT:     %[[SIZE_ELEM_N_2:.+]] = affine.min #[[BOUND8_MAP_3]](%[[IV0]])[%[[ELEM_N]]]
+// CHECK-NEXT:     %[[SIZE_ELEM_N_2:.+]] = affine.min #[[BOUND8_MAP_2]](%[[IV0]])[%[[FILL_N]], %[[ELEM_N]]]
 // CHECK-NEXT:     scf.for %[[IV1:.+]] = %{{.+}} to %[[ELEM_OH]]
 // CHECK-NEXT:       %[[SIZE_ELEM_OH:.+]] = affine.min #[[BOUND16_MAP]](%[[IV1]])[%[[ELEM_OH]]]
 // CHECK-NEXT:       %[[OFFSET_OH:.+]] = affine.apply #[[X2_MAP]](%[[IV1]])
 // CHECK-NEXT:       %[[SIZE_INPUT_H:.+]] = affine.min #[[INPUT_BOUND]](%[[SIZE_ELEM_OH]], %[[IV1]])[%[[FILTER_H]], %[[INPUT_H]]]
-// CHECK-NEXT:       %[[SIZE_ELEM_OH_2:.+]] = affine.min #[[BOUND16_MAP_2]](%[[IV1]])[%[[ELEM_OH]]]
+// CHECK-NEXT:       %[[SIZE_ELEM_OH_2:.+]] = affine.min #[[BOUND16_MAP_2]](%[[IV1]])[%[[FILL_H]], %[[ELEM_OH]]]
 // CHECK-NEXT:       scf.for %[[IV2:.+]] = %{{.+}} to %[[ELEM_OW]]
 // CHECK-NEXT:         %[[SIZE_ELEM_OW:.+]] = affine.min #[[BOUND4_MAP]](%[[IV2]])[%[[ELEM_OW]]]
 // CHECK-NEXT:         %[[SIZE_ELEM_OC:.+]] = affine.min #[[BOUND2_MAP]](%[[IV2]])[%[[ELEM_OC]]]
@@ -257,7 +259,7 @@ func @conv_tensors_dynamic(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?x
 // CHECK-NEXT:         %[[SIZE_INPUT_W:.+]] = affine.min #[[INPUT_BOUND]](%[[SIZE_ELEM_OW]], %[[IV2]])[%[[FILTER_W]], %[[INPUT_W]]]
 // CHECK-NEXT:         %[[ST_INPUT:.+]] = subtensor %[[INPUT]][%[[IV0]], %[[OFFSET_OH]], %[[OFFSET_OW]], 0]
 // CHECK-SAME:               [%[[SIZE_INPUT_N]], %[[SIZE_INPUT_H]], %[[SIZE_INPUT_W]], %[[INPUT_C]]]
-// CHECK-NEXT:         %[[SIZE_ELEM_OW_2:.+]] = affine.min #[[BOUND4_MAP_2]](%[[IV2]])[%[[ELEM_OW]]]
+// CHECK-NEXT:         %[[SIZE_ELEM_OW_2:.+]] = affine.min #[[BOUND4_MAP_2]](%[[IV2]])[%[[FILL_W]], %[[ELEM_OW]]]
 // CHECK-NEXT:         scf.for %[[IV3:.+]] = %{{.+}} to %[[ELEM_OC]] step %{{.+}} iter_args(%[[ARG:[a-z0-9]+]]
 // CHECK-NEXT:           %[[ST_ELEM:.+]] = subtensor %[[ELEM]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
 // CHECK-SAME:                 [%[[SIZE_ELEM_N]], %[[SIZE_ELEM_OH]], %[[SIZE_ELEM_OW]], %[[SIZE_ELEM_OC]]]
@@ -266,7 +268,7 @@ func @conv_tensors_dynamic(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?x
 // CHECK-NEXT:           %[[SIZE_ELEM_OC_2:.+]] = affine.min #[[BOUND2_MAP_2]](%[[IV3]], %[[IV2]])[%[[FILTER_OC]], %[[ELEM_OC]]]
 // CHECK-NEXT:           %[[ST_FILTER:.+]] = subtensor %[[FILTER]][0, 0, 0, %[[IV3]]]
 // CHECK-SAME:                 [%[[FILTER_H]], %[[FILTER_W]], %[[FILTER_IC]], %[[SIZE_ELEM_OC_2]]]
-// CHECK-NEXT:           %[[SIZE_ELEM_OC_3:.+]] = affine.min #[[BOUND2_MAP_3]](%[[IV3]], %[[IV2]])[%[[ELEM_OC]]]
+// CHECK-NEXT:           %[[SIZE_ELEM_OC_3:.+]] = affine.min #[[BOUND2_MAP_2]](%[[IV3]], %[[IV2]])[%[[FILL_C]], %[[ELEM_OC]]]
 // CHECK-NEXT:           %[[ST_FILL:.+]] = subtensor %[[FILL]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
 // CHECK-SAME:                 [%[[SIZE_ELEM_N_2]], %[[SIZE_ELEM_OH_2]], %[[SIZE_ELEM_OW_2]], %[[SIZE_ELEM_OC_3]]]
 // CHECK-NEXT:           %[[ST_CONV:.+]] = linalg.conv_2d_input_nhwc_filter_hwcf
diff --git a/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir b/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir
new file mode 100644 (file)
index 0000000..2568e23
--- /dev/null
@@ -0,0 +1,88 @@
+// RUN: mlir-opt %s -resolve-shaped-type-result-dims -split-input-file | FileCheck %s
+
+func @result_shape(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
+    -> (index, index, index, index, index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %0:2 = "test.op_with_result_shape_interface"(%arg0, %arg1)
+      : (tensor<2x3x?xf32>, tensor<?x5xf32>) -> (tensor<?x5xf32>, tensor<2x3x?xf32>)
+  %1 = memref.dim %0#0, %c0 : tensor<?x5xf32>
+  %2 = memref.dim %0#0, %c1 : tensor<?x5xf32>
+  %3 = memref.dim %0#1, %c0 : tensor<2x3x?xf32>
+  %4 = memref.dim %0#1, %c1 : tensor<2x3x?xf32>
+  %5 = memref.dim %0#1, %c2 : tensor<2x3x?xf32>
+  return %1, %2, %3, %4, %5 : index, index, index, index, index
+}
+// CHECK-LABEL: func @result_shape(
+//  CHECK-SAME:   %[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32>
+//  CHECK-SAME:   %[[ARG_1:[a-z0-9]*]]: tensor<?x5xf32>)
+//   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//   CHECK-DAG:   %[[C2:.+]] = constant 2 : index
+//   CHECK-DAG:   %[[C3:.+]] = constant 3 : index
+//   CHECK-DAG:   %[[C5:.+]] = constant 5 : index
+//   CHECK-DAG:   %[[D0:.+]] = memref.dim %[[ARG_1]], %[[C0]]
+//   CHECK-DAG:   %[[S0:.+]] = tensor.from_elements %[[D0]], %[[C5]]
+//   CHECK-DAG:   %[[D0_OUT:.+]] = tensor.extract %[[S0]][%[[C0]]]
+//   CHECK-DAG:   %[[D1:.+]] = memref.dim %[[ARG_0]], %[[C2]]
+//   CHECK-DAG:   %[[S1:.+]] = tensor.from_elements %[[C2]], %[[C3]], %[[D1]]
+//   CHECK-DAG:   %[[D1_OUT:.+]] = tensor.extract %[[S1]][%[[C2]]]
+//       CHECK:   return %[[D0_OUT]], %[[C5]], %[[C2]], %[[C3]], %[[D1_OUT]]
+
+// -----
+
+func @result_shape_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
+    -> (index, index, index, index, index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %0:2 = "test.op_with_result_shape_per_dim_interface"(%arg0, %arg1)
+      : (tensor<2x3x?xf32>, tensor<?x5xf32>) -> (tensor<?x5xf32>, tensor<2x3x?xf32>)
+  %1 = memref.dim %0#0, %c0 : tensor<?x5xf32>
+  %2 = memref.dim %0#0, %c1 : tensor<?x5xf32>
+  %3 = memref.dim %0#1, %c0 : tensor<2x3x?xf32>
+  %4 = memref.dim %0#1, %c1 : tensor<2x3x?xf32>
+  %5 = memref.dim %0#1, %c2 : tensor<2x3x?xf32>
+  return %1, %2, %3, %4, %5 : index, index, index, index, index
+}
+// CHECK-LABEL: func @result_shape_per_dim(
+//  CHECK-SAME:   %[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32>
+//  CHECK-SAME:   %[[ARG_1:[a-z0-9]*]]: tensor<?x5xf32>)
+//   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//   CHECK-DAG:   %[[C2:.+]] = constant 2 : index
+//   CHECK-DAG:   %[[C3:.+]] = constant 3 : index
+//   CHECK-DAG:   %[[C5:.+]] = constant 5 : index
+//   CHECK-DAG:   %[[D0:.+]] = memref.dim %[[ARG_1]], %[[C0]]
+//   CHECK-DAG:   %[[D1:.+]] = memref.dim %[[ARG_0]], %[[C2]]
+//       CHECK:   return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]]
+
+// -----
+
+func @result_shape_and_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
+    -> (index, index, index, index, index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %0:2 = "test.op_with_result_shape_and_per_dim_interface"(%arg0, %arg1)
+      : (tensor<2x3x?xf32>, tensor<?x5xf32>) -> (tensor<?x5xf32>, tensor<2x3x?xf32>)
+  %1 = memref.dim %0#0, %c0 : tensor<?x5xf32>
+  %2 = memref.dim %0#0, %c1 : tensor<?x5xf32>
+  %3 = memref.dim %0#1, %c0 : tensor<2x3x?xf32>
+  %4 = memref.dim %0#1, %c1 : tensor<2x3x?xf32>
+  %5 = memref.dim %0#1, %c2 : tensor<2x3x?xf32>
+  return %1, %2, %3, %4, %5 : index, index, index, index, index
+}
+// CHECK-LABEL: func @result_shape_and_per_dim(
+//  CHECK-SAME:   %[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32>
+//  CHECK-SAME:   %[[ARG_1:[a-z0-9]*]]: tensor<?x5xf32>)
+//   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//   CHECK-DAG:   %[[C2:.+]] = constant 2 : index
+//   CHECK-DAG:   %[[C3:.+]] = constant 3 : index
+//   CHECK-DAG:   %[[C5:.+]] = constant 5 : index
+//   CHECK-DAG:   %[[D0:.+]] = memref.dim %[[ARG_1]], %[[C0]]
+//   CHECK-DAG:   %[[S0:.+]] = tensor.from_elements %[[D0]], %[[C5]]
+//   CHECK-DAG:   %[[D0_OUT:.+]] = tensor.extract %[[S0]][%[[C0]]]
+//   CHECK-DAG:   %[[D1:.+]] = memref.dim %[[ARG_0]], %[[C2]]
+//   CHECK-DAG:   %[[S1:.+]] = tensor.from_elements %[[C2]], %[[C3]], %[[D1]]
+//   CHECK-DAG:   %[[D1_OUT:.+]] = tensor.extract %[[S1]][%[[C2]]]
+//       CHECK:   return %[[D0_OUT]], %[[C5]], %[[C2]], %[[C3]], %[[D1_OUT]]
index d245a56..8c919da 100644 (file)
@@ -82,30 +82,6 @@ func @typemismatch() -> i32 {
   return %0 : i32
 }
 
-// CHECK-LABEL: func @result_shape_per_dim
-// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32>, %[[ARG_1:[a-z0-9]*]]: tensor<?x5xf32>)
-func @result_shape_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
-    -> (index, index, index, index, index) {
-  // CHECK-DAG: %[[C0:.+]] = constant 0 : index
-  // CHECK-DAG: %[[C2:.+]] = constant 2 : index
-  // CHECK-DAG: %[[C3:.+]] = constant 3 : index
-  // CHECK-DAG: %[[C5:.+]] = constant 5 : index
-  %c0 = constant 0 : index
-  %c1 = constant 1 : index
-  %c2 = constant 2 : index
-  %0:2 = "test.op_with_result_shape_per_dim_interface"(%arg0, %arg1)
-      : (tensor<2x3x?xf32>, tensor<?x5xf32>) -> (tensor<?x5xf32>, tensor<2x3x?xf32>)
-  %1 = memref.dim %0#0, %c0 : tensor<?x5xf32>
-  %2 = memref.dim %0#0, %c1 : tensor<?x5xf32>
-  %3 = memref.dim %0#1, %c0 : tensor<2x3x?xf32>
-  %4 = memref.dim %0#1, %c1 : tensor<2x3x?xf32>
-  %5 = memref.dim %0#1, %c2 : tensor<2x3x?xf32>
-  // CHECK-DAG: %[[D0:.+]] = memref.dim %[[ARG_1]], %[[C0]]
-  // CHECK-DAG: %[[D1:.+]] = memref.dim %[[ARG_0]], %[[C2]]
-  // CHECK: return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]]
-  return %1, %2, %3, %4, %5 : index, index, index, index, index
-}
-
 // CHECK-LABEL: test_dialect_canonicalizer
 func @test_dialect_canonicalizer() -> (i32) {
   %0 = "test.dialect_canonicalizable"() : () -> (i32)
index c5bda7b..a591ab5 100644 (file)
@@ -65,6 +65,7 @@ add_mlir_library(MLIRTestDialect
   MLIRReduce
   MLIRStandard
   MLIRStandardOpsTransforms
+  MLIRTensor
   MLIRTransformUtils
   MLIRTransforms
 )
index 8ef6ec6..ca6f180 100644 (file)
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/PatternMatch.h"
@@ -802,22 +803,75 @@ LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
   return success();
 }
 
+LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
+    OpBuilder &builder, ValueRange operands,
+    llvm::SmallVectorImpl<Value> &shapes) {
+  Location loc = getLoc();
+  shapes.reserve(operands.size());
+  for (Value operand : llvm::reverse(operands)) {
+    auto currShape = llvm::to_vector<4>(llvm::map_range(
+        llvm::seq<int64_t>(
+            0, operand.getType().cast<RankedTensorType>().getRank()),
+        [&](int64_t dim) -> Value {
+          return builder.createOrFold<memref::DimOp>(loc, operand, dim);
+        }));
+    shapes.push_back(builder.create<tensor::FromElementsOp>(
+        getLoc(), builder.getIndexType(), currShape));
+  }
+  return success();
+}
+
 LogicalResult
 OpWithResultShapePerDimInterfaceOp ::reifyReturnTypeShapesPerResultDim(
     OpBuilder &builder,
     llvm::SmallVectorImpl<llvm::SmallVector<Value>> &shapes) {
-  SmallVector<Value> operand1Shape, operand2Shape;
   Location loc = getLoc();
-  for (auto i :
-       llvm::seq<int>(0, operand1().getType().cast<ShapedType>().getRank())) {
-    operand1Shape.push_back(builder.create<memref::DimOp>(loc, operand1(), i));
+  shapes.reserve(getNumOperands());
+  for (Value operand : llvm::reverse(getOperands())) {
+    auto currShape = llvm::to_vector<4>(llvm::map_range(
+        llvm::seq<int64_t>(
+            0, operand.getType().cast<RankedTensorType>().getRank()),
+        [&](int64_t dim) -> Value {
+          return builder.createOrFold<memref::DimOp>(loc, operand, dim);
+        }));
+    shapes.emplace_back(std::move(currShape));
   }
-  for (auto i :
-       llvm::seq<int>(0, operand2().getType().cast<ShapedType>().getRank())) {
-    operand2Shape.push_back(builder.create<memref::DimOp>(loc, operand2(), i));
+  return success();
+}
+
+LogicalResult OpWithResultShapeAndPerDimInterfaceOp::reifyReturnTypeShapes(
+    OpBuilder &builder, ValueRange operands,
+    llvm::SmallVectorImpl<Value> &shapes) {
+  Location loc = getLoc();
+  shapes.reserve(operands.size());
+  for (Value operand : llvm::reverse(operands)) {
+    auto currShape = llvm::to_vector<4>(llvm::map_range(
+        llvm::seq<int64_t>(
+            0, operand.getType().cast<RankedTensorType>().getRank()),
+        [&](int64_t dim) -> Value {
+          return builder.createOrFold<memref::DimOp>(loc, operand, dim);
+        }));
+    shapes.push_back(builder.create<tensor::FromElementsOp>(
+        getLoc(), builder.getIndexType(), currShape));
+  }
+  return success();
+}
+
+LogicalResult
+OpWithResultShapeAndPerDimInterfaceOp ::reifyReturnTypeShapesPerResultDim(
+    OpBuilder &builder,
+    llvm::SmallVectorImpl<llvm::SmallVector<Value>> &shapes) {
+  Location loc = getLoc();
+  shapes.reserve(getNumOperands());
+  for (Value operand : llvm::reverse(getOperands())) {
+    auto currShape = llvm::to_vector<4>(llvm::map_range(
+        llvm::seq<int64_t>(
+            0, operand.getType().cast<RankedTensorType>().getRank()),
+        [&](int64_t dim) -> Value {
+          return builder.createOrFold<memref::DimOp>(loc, operand, dim);
+        }));
+    shapes.emplace_back(std::move(currShape));
   }
-  shapes.emplace_back(std::move(operand2Shape));
-  shapes.emplace_back(std::move(operand1Shape));
   return success();
 }
 
index ea39b9c..0f1775f 100644 (file)
@@ -571,9 +571,25 @@ def OpWithShapedTypeInferTypeInterfaceOp : TEST_Op<"op_with_shaped_type_infer_ty
   let results = (outs AnyTensor);
 }
 
-def OpWithResultShapePerDimInterfaceOp : TEST_Op<"op_with_result_shape_per_dim_interface",
+def OpWithResultShapeInterfaceOp : TEST_Op<"op_with_result_shape_interface",
       [DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-           ["reifyReturnTypeShapesPerResultDim"]>]> {
+          ["reifyReturnTypeShapes"]>]> {
+  let arguments = (ins AnyRankedTensor:$operand1, AnyRankedTensor:$operand2);
+  let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2);
+}
+
+def OpWithResultShapePerDimInterfaceOp :
+    TEST_Op<"op_with_result_shape_per_dim_interface",
+        [DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+             ["reifyReturnTypeShapesPerResultDim"]>]> {
+  let arguments = (ins AnyRankedTensor:$operand1, AnyRankedTensor:$operand2);
+  let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2);
+}
+
+def OpWithResultShapeAndPerDimInterfaceOp :
+    TEST_Op<"op_with_result_shape_and_per_dim_interface",
+        [DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+             ["reifyReturnTypeShapes", "reifyReturnTypeShapesPerResultDim"]>]> {
   let arguments = (ins AnyRankedTensor:$operand1, AnyRankedTensor:$operand2);
   let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2);
 }