#include "mlir/Pass/Pass.h"
namespace mlir {
+
+class AffineDialect;
+namespace tensor {
+class TensorDialect;
+} // namespace tensor
+namespace vector {
+class VectorDialect;
+} // namespace vector
+
namespace memref {
//===----------------------------------------------------------------------===//
/// into `patterns`.
void populateFoldSubViewOpPatterns(RewritePatternSet &patterns);
+/// Appends patterns that resolve `memref.dim` operations whose source value is
+/// defined by an operation implementing the `InferShapedTypeOpInterface`,
+/// rewriting them in terms of the shapes of that operation's operands.
+void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns);
+
//===----------------------------------------------------------------------===//
// Passes
//===----------------------------------------------------------------------===//
/// load/store ops into `patterns`.
std::unique_ptr<Pass> createFoldSubViewOpsPass();
+/// Creates an operation pass that resolves `memref.dim` operations whose source
+/// value is defined by an operation implementing the
+/// `InferShapedTypeOpInterface`, in terms of the shapes of that operation's operands.
+std::unique_ptr<Pass> createResolveShapedTypeResultDimsPass();
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
];
}
+// Declares the `resolve-shaped-type-result-dims` pass, which is constructed
+// through `mlir::memref::createResolveShapedTypeResultDimsPass()`.
+// NOTE(review): the rewrite pattern used by this pass also creates
+// `ConstantIndexOp`; its dialect is not listed in `dependentDialects` below —
+// confirm it is guaranteed to be loaded wherever this pass runs.
+def ResolveShapedTypeResultDims : Pass<"resolve-shaped-type-result-dims"> {
+ let summary = "Resolve memref.dim of result values";
+ let description = [{
+ The pass resolves memref.dim of result of operations that
+ implement the `InferShapedTypeOpInterface` in terms of shapes of
+ its operands.
+ }];
+ let constructor = "mlir::memref::createResolveShapedTypeResultDimsPass()";
+ let dependentDialects = [
+ "memref::MemRefDialect", "tensor::TensorDialect"
+ ];
+}
#endif // MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES
return success();
}
};
-
-/// Helper method to get the `Value` that is the shape of the `resultIdx`-th
-/// result at dimension `dimIndex` from the `ShapedTypeOpInterface`.
-/// TODO(ravishankarm): This is better put as a interface utility method
-/// somewhere, but that would imply the interface will depend on the `tensor`
-/// dialect. Ideally maybe a utility method in the `tensor` dialect.
-static Value getResultDimFromShapeInterface(OpBuilder &builder, OpResult result,
- int64_t dimIndex) {
- unsigned resultNumber = result.getResultNumber();
- auto shapedTypeOp = dyn_cast<InferShapedTypeOpInterface>(result.getOwner());
- Location loc = result.getOwner()->getLoc();
- if (!shapedTypeOp)
- return nullptr;
-
- // The interface exposes two methods, one that returns the shape of all the
- // results as `Value` and other that returns the shape as a list of
- // `SmallVector<Value>`. The former takes precedence over the latter. So first
- // check if the op implements the first interface method or the second, and
- // get the value to use appropriately.
- SmallVector<Value> reifiedResultShapes;
- if (succeeded(shapedTypeOp.reifyReturnTypeShapes(
- builder, result.getOwner()->getOperands(), reifiedResultShapes))) {
- if (reifiedResultShapes.size() <= resultNumber)
- return nullptr;
- Value resultShape = reifiedResultShapes[resultNumber];
- auto resultShapeType = resultShape.getType().dyn_cast<RankedTensorType>();
- if (!resultShapeType || !resultShapeType.getElementType().isa<IndexType>())
- return nullptr;
- return builder.create<tensor::ExtractOp>(
- loc, resultShape, builder.createOrFold<ConstantIndexOp>(loc, dimIndex));
- }
-
- SmallVector<SmallVector<Value>> reifiedResultShapesPerDim;
- if (failed(shapedTypeOp.reifyReturnTypeShapesPerResultDim(
- builder, reifiedResultShapesPerDim)))
- return nullptr;
- if (reifiedResultShapesPerDim.size() <= resultNumber ||
- reifiedResultShapesPerDim[resultNumber].size() !=
- static_cast<size_t>(result.getType().cast<ShapedType>().getRank()))
- return nullptr;
- OpFoldResult valueOrAttr = reifiedResultShapesPerDim[resultNumber][dimIndex];
- if (auto attr = valueOrAttr.dyn_cast<Attribute>())
- return builder.createOrFold<ConstantIndexOp>(
- loc, attr.cast<IntegerAttr>().getInt());
- return valueOrAttr.get<Value>();
-}
-
-/// Fold dim of an operation that implements the InferShapedTypeOpInterface
-struct DimOfShapedTypeOpInterface : public OpRewritePattern<DimOp> {
- using OpRewritePattern<DimOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(DimOp dimOp,
- PatternRewriter &rewriter) const override {
- OpResult dimValue = dimOp.memrefOrTensor().dyn_cast<OpResult>();
- if (!dimValue)
- return failure();
- auto shapedTypeOp =
- dyn_cast<InferShapedTypeOpInterface>(dimValue.getOwner());
- if (!shapedTypeOp)
- return failure();
-
- Optional<int64_t> dimIndex = dimOp.getConstantIndex();
- if (!dimIndex)
- return failure();
- Value replacement =
- getResultDimFromShapeInterface(rewriter, dimValue, *dimIndex);
- if (!replacement)
- return failure();
- rewriter.replaceOp(dimOp, replacement);
- return success();
- }
-};
} // end anonymous namespace.
void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<DimOfMemRefReshape, DimOfCastOp<BufferCastOp>,
- DimOfCastOp<tensor::CastOp>, DimOfShapedTypeOpInterface>(context);
+ DimOfCastOp<tensor::CastOp>>(context);
}
// ---------------------------------------------------------------------------
add_mlir_dialect_library(MLIRMemRefTransforms
FoldSubViewOps.cpp
+ ResolveShapedTypeResultDims.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MemRef
LINK_LIBS PUBLIC
MLIRAffine
+ MLIRInferTypeOpInterface
MLIRMemRef
MLIRPass
MLIRStandard
+ MLIRTensor
MLIRTransforms
MLIRVector
)
--- /dev/null
+//===- ResolveShapedTypeResultDims.cpp - Resolve dim ops of result values -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass resolves `memref.dim` operations of result values in terms of
+// shapes of their operands using the `InferShapedTypeOpInterface`.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+/// Returns a `Value` holding the extent of dimension `dimIndex` of `result`,
+/// expressed through the `InferShapedTypeOpInterface` implemented by the
+/// operation that defines `result`.
+///
+/// Returns a null `Value` when:
+///   - `dimIndex` is negative or not a valid dimension of the reified shape,
+///   - the defining operation does not implement the interface,
+///   - neither reification method succeeds, or
+///   - the reified shapes do not have the expected form.
+///
+/// Note that the interface methods may create operations through `builder`
+/// even when this function ends up returning a null `Value`.
+///
+/// TODO(ravishankarm): This is better put as an interface utility method
+/// somewhere, but that would imply the interface will depend on the `tensor`
+/// dialect. Ideally maybe a utility method in the `tensor` dialect.
+static Value getResultDimFromShapeInterface(OpBuilder &builder, OpResult result,
+                                            int64_t dimIndex) {
+  // A negative index can never name a dimension. Bail out before any IR is
+  // created so that it is neither used as a container index nor materialized
+  // as the index of a `tensor.extract`.
+  if (dimIndex < 0)
+    return nullptr;
+
+  Operation *owner = result.getOwner();
+  auto shapedTypeOp = dyn_cast<InferShapedTypeOpInterface>(owner);
+  if (!shapedTypeOp)
+    return nullptr;
+  unsigned resultNumber = result.getResultNumber();
+  Location loc = owner->getLoc();
+
+  // The interface exposes two methods: one returns the shape of every result
+  // as a single shape `Value`, the other returns the shape of every result as
+  // a list of per-dimension `Value`s. The former takes precedence over the
+  // latter, so try it first and only fall back when it fails.
+  SmallVector<Value> reifiedResultShapes;
+  if (succeeded(shapedTypeOp.reifyReturnTypeShapes(
+          builder, owner->getOperands(), reifiedResultShapes))) {
+    if (reifiedResultShapes.size() <= resultNumber)
+      return nullptr;
+    Value resultShape = reifiedResultShapes[resultNumber];
+    // Only a ranked tensor of `index` elements can be read with
+    // `tensor.extract` to obtain the requested extent.
+    auto resultShapeType = resultShape.getType().dyn_cast<RankedTensorType>();
+    if (!resultShapeType || !resultShapeType.getElementType().isa<IndexType>())
+      return nullptr;
+    return builder.create<tensor::ExtractOp>(
+        loc, resultShape, builder.createOrFold<ConstantIndexOp>(loc, dimIndex));
+  }
+
+  // The per-dimension form is only meaningful for ranked results; querying the
+  // rank of an unranked type is not allowed.
+  auto resultType = result.getType().dyn_cast<ShapedType>();
+  if (!resultType || !resultType.hasRank())
+    return nullptr;
+
+  SmallVector<SmallVector<Value>> reifiedResultShapesPerDim;
+  if (failed(shapedTypeOp.reifyReturnTypeShapesPerResultDim(
+          builder, reifiedResultShapesPerDim)))
+    return nullptr;
+  if (reifiedResultShapesPerDim.size() <= resultNumber)
+    return nullptr;
+
+  // The reified shape must provide exactly one value per result dimension and
+  // `dimIndex` must address one of them.
+  const SmallVector<Value> &resultDims = reifiedResultShapesPerDim[resultNumber];
+  if (resultDims.size() != static_cast<size_t>(resultType.getRank()) ||
+      static_cast<size_t>(dimIndex) >= resultDims.size())
+    return nullptr;
+
+  // The entries are stored as `Value`s, so there is no attribute case to
+  // materialize as a constant here.
+  return resultDims[dimIndex];
+}
+
+namespace {
+/// Replaces a `memref.dim` whose source is the result of an operation that
+/// implements the `InferShapedTypeOpInterface` with the value computed by
+/// `getResultDimFromShapeInterface`. The pattern does not apply when the
+/// source is not an operation result, when the dimension index is not a
+/// constant, or when no replacement value could be computed.
+struct DimOfShapedTypeOpInterface : public OpRewritePattern<memref::DimOp> {
+  using OpRewritePattern<memref::DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::DimOp dimOp,
+                                PatternRewriter &rewriter) const override {
+    // The source has to be produced by an operation implementing the
+    // interface.
+    auto source = dimOp.memrefOrTensor().dyn_cast<OpResult>();
+    if (!source || !isa<InferShapedTypeOpInterface>(source.getOwner()))
+      return failure();
+
+    // Only a constant dimension index can be resolved.
+    Optional<int64_t> constantDim = dimOp.getConstantIndex();
+    if (!constantDim.hasValue())
+      return failure();
+
+    if (Value resolved = getResultDimFromShapeInterface(
+            rewriter, source, constantDim.getValue())) {
+      rewriter.replaceOp(dimOp, resolved);
+      return success();
+    }
+    return failure();
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pass registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+#define GEN_PASS_CLASSES
+#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
+
+/// Pass built on the generated `ResolveShapedTypeResultDimsBase`; the work is
+/// done in `runOnOperation`, which is defined out of line.
+class ResolveShapedTypeResultDimsPass final
+    : public ResolveShapedTypeResultDimsBase<ResolveShapedTypeResultDimsPass> {
+public:
+  void runOnOperation() override;
+};
+} // namespace
+
+/// Adds the `DimOfShapedTypeOpInterface` pattern to `patterns`, using the
+/// context the pattern set was created with.
+void memref::populateResolveShapedTypeResultDimsPatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<DimOfShapedTypeOpInterface>(context);
+}
+
+/// Collects the dim-resolution patterns and applies them greedily to every
+/// region of the operation the pass runs on. The pass is marked as failed when
+/// `applyPatternsAndFoldGreedily` reports failure.
+void ResolveShapedTypeResultDimsPass::runOnOperation() {
+  RewritePatternSet resolveDimPatterns(&getContext());
+  memref::populateResolveShapedTypeResultDimsPatterns(resolveDimPatterns);
+  LogicalResult status = applyPatternsAndFoldGreedily(
+      getOperation()->getRegions(), std::move(resolveDimPatterns));
+  if (failed(status))
+    signalPassFailure();
+}
+
+/// Creates a new instance of the pass that resolves `memref.dim` operations of
+/// result values.
+std::unique_ptr<Pass> memref::createResolveShapedTypeResultDimsPass() {
+  std::unique_ptr<Pass> pass =
+      std::make_unique<ResolveShapedTypeResultDimsPass>();
+  return pass;
+}
// -----
-func @init_tensor_static_dim() -> (index, index) {
- %c0 = constant 0 : index
- %c2 = constant 2 : index
- %c6 = constant 6 : index
- %0 = linalg.init_tensor [4, 5, %c6] : tensor<4x5x?xf32>
- %1 = memref.dim %0, %c2 : tensor<4x5x?xf32>
- %2 = memref.dim %0, %c0 : tensor<4x5x?xf32>
- return %1, %2 : index, index
-}
-// CHECK: func @init_tensor_static_dim
-// CHECK-DAG: %[[C4:.+]] = constant 4 : index
-// CHECK-DAG: %[[C6:.+]] = constant 6 : index
-// CHECK: return %[[C6]], %[[C4]]
-
-// -----
-
-func @init_tensor_dynamic_dim(%arg0 : index) -> (index) {
- %c2 = constant 2 : index
- %0 = linalg.init_tensor [4, 5, %arg0] : tensor<4x5x?xf32>
- %1 = memref.dim %0, %c2 : tensor<4x5x?xf32>
- return %1 : index
-}
-// CHECK: func @init_tensor_dynamic_dim
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
-// CHECK: return %[[ARG0]]
-
-// -----
-
-func @init_tensor_dynamic_dim2(%arg0 : index, %arg1 : index) -> (index, index) {
- %c0 = constant 0 : index
- %c1 = constant 1 : index
- %0 = linalg.init_tensor [%arg0, %arg1] : tensor<?x?xf32>
- %1 = memref.dim %0, %c0 : tensor<?x?xf32>
- %2 = memref.dim %0, %c1 : tensor<?x?xf32>
- return %1, %2 : index, index
-}
-// CHECK: func @init_tensor_dynamic_dim2
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
-// CHECK: return %[[ARG0]], %[[ARG1]]
-
-// -----
-
-func @remove_dim_result_uses
- (%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
- %arg2 : tensor<?x?xf32>) -> (index, index) {
- %c0 = constant 0 : index
- %c1 = constant 1 : index
- %0 = linalg.generic
- {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
- affine_map<(d0, d1, d2) -> (d2, d1)>,
- affine_map<(d0, d1, d2) -> (d0 + d1, d1 - d0)>],
- iterator_types = ["parallel", "parallel", "reduction"]}
- ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%arg2 : tensor<?x?xf32>) {
- ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
- %1 = mulf %arg3, %arg4 : f32
- %2 = addf %1, %arg5 : f32
- linalg.yield %2 : f32
- } -> tensor<?x?xf32>
- %3 = memref.dim %0, %c0 : tensor<?x?xf32>
- %4 = memref.dim %0, %c1 : tensor<?x?xf32>
- return %3, %4 : index, index
-}
-// CHECK: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
-// CHECK: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (-s0 + s1)>
-// CHECK: func @remove_dim_result_uses
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-DAG: %[[C0:.+]] = constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK-DAG: %[[T0:.+]] = memref.dim %[[ARG0]], %[[C0]]
-// CHECK-DAG: %[[T1:.+]] = memref.dim %[[ARG1]], %[[C1]]
-// CHECK: %[[T2:.+]] = affine.apply #[[MAP0]]()[%[[T0]], %[[T1]]]
-// CHECK-DAG: %[[T3:.+]] = memref.dim %[[ARG0]], %[[C0]]
-// CHECK-DAG: %[[T4:.+]] = memref.dim %[[ARG1]], %[[C1]]
-// CHECK: %[[T5:.+]] = affine.apply #[[MAP1]]()[%[[T3]], %[[T4]]]
-// CHECK: return %[[T2]], %[[T5]]
-
-// -----
-
-func @remove_dim_result_uses_outs
- (%arg0 : tensor<?xf32>, %arg1 : index) -> (index) {
- %c0 = constant 0 : index
- %c1 = constant 1 : index
- %d0 = memref.dim %arg0, %c0 : tensor<?xf32>
- %0 = linalg.init_tensor [%d0, %arg1] : tensor<?x?xf32>
- %1 = linalg.generic
- {indexing_maps = [affine_map<(d0, d1) -> (d0)>,
- affine_map<(d0, d1) -> (d0, d1)>],
- iterator_types = ["parallel", "parallel"]}
- ins(%arg0 : tensor<?xf32>) outs(%0 : tensor<?x?xf32>) {
- ^bb0(%arg2: f32, %arg3: f32) :
- linalg.yield %arg2 : f32
- } -> tensor<?x?xf32>
- %2 = memref.dim %1, %c1 : tensor<?x?xf32>
- return %2 : index
-}
-// CHECK: func @remove_dim_result_uses_outs
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
-// CHECK: return %[[ARG1]]
-
-// -----
-
-func @remove_dim_result_uses_sequence
- (%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
- %arg2 : tensor<?x?xf32>) -> (index, index, index, index) {
- %c0 = constant 0 : index
- %c1 = constant 1 : index
- %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
- %1 = memref.dim %0, %c0 : tensor<?x?xf32>
- %2 = memref.dim %0, %c1 : tensor<?x?xf32>
- %3 = linalg.generic
- {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0)>,
- affine_map<(d0, d1, d2) -> (d0, d2)>,
- affine_map<(d0, d1, d2) -> (d0, d2)>],
- iterator_types = ["parallel", "reduction", "parallel"]}
- ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%0 : tensor<?x?xf32>) {
- ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
- %4 = mulf %arg3, %arg4 : f32
- %5 = addf %4, %arg5 : f32
- linalg.yield %5 : f32
- } -> tensor<?x?xf32>
- %6 = memref.dim %3, %c0 : tensor<?x?xf32>
- %7 = memref.dim %3, %c1 : tensor<?x?xf32>
- return %1, %2, %6, %7 : index, index, index, index
-}
-// CHECK-LABEL: func @remove_dim_result_uses_sequence
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
-// CHECK-DAG: %[[C0:.+]] = constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK-DAG: %[[T0:.+]] = memref.dim %[[ARG0]], %[[C0]]
-// CHECK-DAG: %[[T1:.+]] = memref.dim %[[ARG1]], %[[C1]]
-// CHECK-DAG: %[[T2:.+]] = memref.dim %[[ARG0]], %[[C1]]
-// CHECK-DAG: %[[T3:.+]] = memref.dim %[[ARG1]], %[[C1]]
-// CHECK: return %[[T0]], %[[T1]], %[[T2]], %[[T3]]
-
-// -----
-
-func @keep_result_dim_uses_sequence2
- (%arg0 : tensor<?xf32>, %arg1 : index) -> (index, index) {
- %c0 = constant 0 : index
- %c1 = constant 1 : index
- %d0 = memref.dim %arg0, %c0 : tensor<?xf32>
- %0 = linalg.init_tensor [%d0, %arg1] : tensor<?x?xf32>
- %1 = linalg.generic
- {indexing_maps = [affine_map<(d0, d1) -> (d0)>,
- affine_map<(d0, d1) -> (d0, d1)>],
- iterator_types = ["parallel", "parallel"]}
- ins(%arg0 : tensor<?xf32>) outs(%0 : tensor<?x?xf32>) {
- ^bb0(%arg2: f32, %arg3 : f32):
- linalg.yield %arg2 : f32
- } -> tensor<?x?xf32>
- %2 = memref.dim %1, %c0 : tensor<?x?xf32>
- %3 = memref.dim %1, %c1 : tensor<?x?xf32>
- return %2, %3 : index, index
-}
-// CHECK: func @keep_result_dim_uses_sequence2
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
-// CHECK-DAG: %[[C0:.+]] = constant 0 : index
-// CHECK-DAG: %[[T0:.+]] = memref.dim %[[ARG0]], %[[C0]]
-// CHECK: return %[[T0]], %[[ARG1]]
-
-// -----
-
-#map = affine_map<(d0) -> (d0)>
-
-func @init_tensor_dim_of_linalg_result(%arg_0 : tensor<?xf32>,
- %arg_1: tensor<?xf32>) -> (index, index) {
- %0, %1 = linalg.generic {
- indexing_maps = [#map, #map, #map],
- iterator_types = ["parallel"]
- } ins(%arg_0 : tensor<?xf32>)
- outs(%arg_0, %arg_1 : tensor<?xf32>, tensor<?xf32>) {
- ^bb0(%in: f32, %out_0: f32, %out_1: f32):
- linalg.yield %in, %in : f32, f32
- } -> (tensor<?xf32>, tensor<?xf32>)
-
- %c0 = constant 0 : index
- %num_elem_0 = memref.dim %0, %c0 : tensor<?xf32>
-
- %num_elem_1 = memref.dim %1, %c0 : tensor<?xf32>
- return %num_elem_0, %num_elem_1 : index, index
-}
-// CHECK: func @init_tensor_dim_of_linalg_result(
-// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?xf32>
-// CHECK-SAME: %[[ARG_1:[a-zA-Z0-9_]+]]: tensor<?xf32>)
-// CHECK: %[[R0:.+]] = memref.dim %[[ARG_0]]
-// CHECK: %[[R1:.+]] = memref.dim %[[ARG_0]]
-// CHECK: return %[[R0]], %[[R1]]
-
-// -----
-
func @init_tensor_reshape_expansion(%arg0 : index) -> tensor<2x3x5x4x?x7xf32> {
%0 = linalg.init_tensor [6, 5, %arg0] : tensor<6x5x?xf32>
%1 = linalg.tensor_expand_shape %0 [[0, 1], [2], [3, 4, 5]]
// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
// CHECK: func @init_tensor_reshape_expansion
// CHECK-SAME: %[[ARG0:.+]]: index
-// CHECK: %[[T0:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
-// CHECK: %[[T1:.+]] = linalg.init_tensor [2, 3, 5, 4, %[[T0]], 7]
-// CHECK: return %[[T1]]
+// CHECK: %[[C2:.+]] = constant 2
+// CHECK: %[[INIT1:.+]] = linalg.init_tensor [6, 5, %[[ARG0]]]
+// CHECK: %[[D0:.+]] = memref.dim %[[INIT1]], %[[C2]]
+// CHECK: %[[T0:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
+// CHECK: %[[INIT2:.+]] = linalg.init_tensor [2, 3, 5, 4, %[[T0]], 7]
+// CHECK: return %[[INIT2]]
// -----
// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 28)>
// CHECK: func @init_tensor_reshape_collapse
// CHECK-SAME: %[[ARG0:.+]]: index
-// CHECK: %[[T0:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
-// CHECK: %[[T1:.+]] = linalg.init_tensor [6, 5, %[[T0]]]
-// CHECK: return %[[T1]]
+// CHECK: %[[C4:.+]] = constant 4
+// CHECK: %[[INIT1:.+]] = linalg.init_tensor [2, 3, 5, 4, %[[ARG0]], 7]
+// CHECK: %[[D0:.+]] = memref.dim %[[INIT1]], %[[C4]]
+// CHECK: %[[T0:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
+// CHECK: %[[INIT2:.+]] = linalg.init_tensor [6, 5, %[[T0]]]
+// CHECK: return %[[INIT2]]
// -----
} : tensor<5x6xf32> to tensor<5x6xf32>
return %0 : tensor<5x6xf32>
}
-
-// -----
-
-func @dim_reshape_expansion(%arg0 : tensor<6x5x?xf32>) -> (index, index, index)
-{
- %c1 = constant 1 : index
- %c3 = constant 3 : index
- %c4 = constant 4 : index
- %0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2], [3, 4, 5]]
- : tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
- %1 = memref.dim %0, %c1 : tensor<2x3x5x4x?x7xf32>
- %2 = memref.dim %0, %c3 : tensor<2x3x5x4x?x7xf32>
- %3 = memref.dim %0, %c4 : tensor<2x3x5x4x?x7xf32>
- return %1, %2, %3 : index, index, index
-}
-// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
-// CHECK: func @dim_reshape_expansion
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<6x5x?xf32>
-// CHECK-DAG: %[[C2:.+]] = constant 2 : index
-// CHECK-DAG: %[[C3:.+]] = constant 3 : index
-// CHECK-DAG: %[[C4:.+]] = constant 4 : index
-// CHECK: %[[D0:.+]] = memref.dim %[[ARG0]], %[[C2]]
-// CHECK: %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
-// CHECK: return %[[C3]], %[[C4]], %[[D1]]
-
-// -----
-
-func @dim_reshape_collapse(%arg0 : tensor<2x3x5x4x?x7xf32>) -> (index, index)
-{
- %c1 = constant 1 : index
- %c2 = constant 2 : index
- %0 = linalg.tensor_collapse_shape %arg0 [[0, 1], [2], [3, 4, 5]]
- : tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32>
- %1 = memref.dim %0, %c1 : tensor<6x5x?xf32>
- %2 = memref.dim %0, %c2 : tensor<6x5x?xf32>
- return %1, %2 : index, index
-}
-// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 28)>
-// CHECK: func @dim_reshape_collapse
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x3x5x4x?x7xf32>
-// CHECK-DAG: %[[C4:.+]] = constant 4 : index
-// CHECK-DAG: %[[C5:.+]] = constant 5 : index
-// CHECK: %[[D0:.+]] = memref.dim %[[ARG0]], %[[C4]]
-// CHECK: %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
-// CHECK: return %[[C5]], %[[D1]]
-
-// -----
-
func @propogate_casts(%arg0 : tensor<?x?xf32>, %arg1 : f32, %arg2 : index,
%arg3 : index) -> tensor<?x?xf32> {
%c0 = constant 0 : index
// -----
-func @dim_of_pad_op(%arg0 : tensor<2x?x?xf32>, %arg1 : index, %arg2 : index,
- %arg3: f32) -> (index, index, index)
-{
- %c0 = constant 0 : index
- %c1 = constant 1 : index
- %c2 = constant 2 : index
- %c3 = constant 3 : index
- %c4 = constant 4 : index
- %c5 = constant 5 : index
- %0 = linalg.pad_tensor %arg0 low[%c3, %arg1, %c4] high[7, %c5, %arg2] {
- ^bb0(%arg4: index, %arg5: index, %arg6: index):
- linalg.yield %arg3 : f32
- } : tensor<2x?x?xf32> to tensor<?x?x?xf32>
- %1 = memref.dim %0, %c0 : tensor<?x?x?xf32>
- %2 = memref.dim %0, %c1 : tensor<?x?x?xf32>
- %3 = memref.dim %0, %c2 : tensor<?x?x?xf32>
- return %1, %2, %3 : index, index, index
-}
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 5)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 4)>
-// CHECK: func @dim_of_pad_op
-// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]+]]: tensor<2x?x?xf32>
-// CHECK-SAME: %[[ARG1:[A-Za-z0-9_]+]]: index
-// CHECK-SAME: %[[ARG2:[A-Za-z0-9_]+]]: index
-// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK-DAG: %[[C2:.+]] = constant 2 : index
-// CHECK-DAG: %[[C12:.+]] = constant 12 : index
-// CHECK: %[[IN_DIM1:.+]] = memref.dim %[[ARG0]], %[[C1]]
-// CHECK: %[[OUT_DIM1:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]], %[[IN_DIM1]]]
-// CHECK: %[[IN_DIM2:.+]] = memref.dim %[[ARG0]], %[[C2]]
-// CHECK: %[[OUT_DIM2:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[IN_DIM2]]]
-// CHECK: return %[[C12]], %[[OUT_DIM1]], %[[OUT_DIM2]]
-
-// -----
-
#map = affine_map<(d0, d1) -> (d0, d1)>
func @indexed_generic(%arg0: memref<?x?xindex>, %arg1: memref<?x?xindex>) {
-// RUN: mlir-opt -pass-pipeline="func(test-linalg-tile-and-fuse{tile-sizes=16,32,64}),canonicalize,cse" -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -pass-pipeline="func(test-linalg-tile-and-fuse{tile-sizes=16,32,64}),resolve-shaped-type-result-dims,canonicalize,cse" -split-input-file %s | FileCheck %s
module {
func @three_op_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
-// RUN: mlir-opt %s -test-linalg-tensor-fusion-transform-patterns -canonicalize -cse --split-input-file | FileCheck %s
-// RUN: mlir-opt %s -test-linalg-tiled-loop-fusion-transform-patterns -canonicalize -cse --split-input-file | FileCheck %s --check-prefix=TLOOP
+// RUN: mlir-opt %s -test-linalg-tensor-fusion-transform-patterns -resolve-shaped-type-result-dims -canonicalize -cse --split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-tiled-loop-fusion-transform-patterns -resolve-shaped-type-result-dims -canonicalize -cse --split-input-file | FileCheck %s --check-prefix=TLOOP
module {
func @matmul_fusion(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
--- /dev/null
+// RUN: mlir-opt -resolve-shaped-type-result-dims -split-input-file %s | FileCheck %s
+
+func @init_tensor_static_dim() -> (index, index) {
+ %c0 = constant 0 : index
+ %c2 = constant 2 : index
+ %c6 = constant 6 : index
+ %0 = linalg.init_tensor [4, 5, %c6] : tensor<4x5x?xf32>
+ %1 = memref.dim %0, %c2 : tensor<4x5x?xf32>
+ %2 = memref.dim %0, %c0 : tensor<4x5x?xf32>
+ return %1, %2 : index, index
+}
+// CHECK: func @init_tensor_static_dim
+// CHECK-DAG: %[[C4:.+]] = constant 4 : index
+// CHECK-DAG: %[[C6:.+]] = constant 6 : index
+// CHECK: return %[[C6]], %[[C4]]
+
+// -----
+
+func @init_tensor_dynamic_dim(%arg0 : index) -> (index) {
+ %c2 = constant 2 : index
+ %0 = linalg.init_tensor [4, 5, %arg0] : tensor<4x5x?xf32>
+ %1 = memref.dim %0, %c2 : tensor<4x5x?xf32>
+ return %1 : index
+}
+// CHECK: func @init_tensor_dynamic_dim
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
+// CHECK: return %[[ARG0]]
+
+// -----
+
+func @init_tensor_dynamic_dim2(%arg0 : index, %arg1 : index) -> (index, index) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %0 = linalg.init_tensor [%arg0, %arg1] : tensor<?x?xf32>
+ %1 = memref.dim %0, %c0 : tensor<?x?xf32>
+ %2 = memref.dim %0, %c1 : tensor<?x?xf32>
+ return %1, %2 : index, index
+}
+// CHECK: func @init_tensor_dynamic_dim2
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK: return %[[ARG0]], %[[ARG1]]
+
+// -----
+
+func @remove_dim_result_uses
+ (%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+ %arg2 : tensor<?x?xf32>) -> (index, index) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %0 = linalg.generic
+ {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>,
+ affine_map<(d0, d1, d2) -> (d0 + d1, d1 - d0)>],
+ iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg2 : tensor<?x?xf32>) {
+ ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
+ %1 = mulf %arg3, %arg4 : f32
+ %2 = addf %1, %arg5 : f32
+ linalg.yield %2 : f32
+ } -> tensor<?x?xf32>
+ %3 = memref.dim %0, %c0 : tensor<?x?xf32>
+ %4 = memref.dim %0, %c1 : tensor<?x?xf32>
+ return %3, %4 : index, index
+}
+// CHECK: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s1 - s0)>
+// CHECK: func @remove_dim_result_uses
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[T0:.+]] = memref.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[T1:.+]] = memref.dim %[[ARG1]], %[[C1]]
+// CHECK: %[[T2:.+]] = affine.apply #[[MAP0]]()[%[[T0]], %[[T1]]]
+// CHECK-DAG: %[[T3:.+]] = memref.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[T4:.+]] = memref.dim %[[ARG1]], %[[C1]]
+// CHECK: %[[T5:.+]] = affine.apply #[[MAP1]]()[%[[T3]], %[[T4]]]
+// CHECK: return %[[T2]], %[[T5]]
+
+// -----
+
+func @remove_dim_result_uses_outs
+ (%arg0 : tensor<?xf32>, %arg1 : index) -> (index) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %d0 = memref.dim %arg0, %c0 : tensor<?xf32>
+ %0 = linalg.init_tensor [%d0, %arg1] : tensor<?x?xf32>
+ %1 = linalg.generic
+ {indexing_maps = [affine_map<(d0, d1) -> (d0)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0 : tensor<?xf32>) outs(%0 : tensor<?x?xf32>) {
+ ^bb0(%arg2: f32, %arg3: f32) :
+ linalg.yield %arg2 : f32
+ } -> tensor<?x?xf32>
+ %2 = memref.dim %1, %c1 : tensor<?x?xf32>
+ return %2 : index
+}
+// CHECK: func @remove_dim_result_uses_outs
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK: return %[[ARG1]]
+
+// -----
+
+func @remove_dim_result_uses_sequence
+ (%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+ %arg2 : tensor<?x?xf32>) -> (index, index, index, index) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %1 = memref.dim %0, %c0 : tensor<?x?xf32>
+ %2 = memref.dim %0, %c1 : tensor<?x?xf32>
+ %3 = linalg.generic
+ {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0)>,
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d2)>],
+ iterator_types = ["parallel", "reduction", "parallel"]}
+ ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%0 : tensor<?x?xf32>) {
+ ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
+ %4 = mulf %arg3, %arg4 : f32
+ %5 = addf %4, %arg5 : f32
+ linalg.yield %5 : f32
+ } -> tensor<?x?xf32>
+ %6 = memref.dim %3, %c0 : tensor<?x?xf32>
+ %7 = memref.dim %3, %c1 : tensor<?x?xf32>
+ return %1, %2, %6, %7 : index, index, index, index
+}
+// CHECK-LABEL: func @remove_dim_result_uses_sequence
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[T0:.+]] = memref.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[T1:.+]] = memref.dim %[[ARG1]], %[[C1]]
+// CHECK-DAG: %[[T2:.+]] = memref.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[T3:.+]] = memref.dim %[[ARG1]], %[[C1]]
+// CHECK: return %[[T0]], %[[T1]], %[[T2]], %[[T3]]
+
+// -----
+
+func @keep_result_dim_uses_sequence2
+ (%arg0 : tensor<?xf32>, %arg1 : index) -> (index, index) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %d0 = memref.dim %arg0, %c0 : tensor<?xf32>
+ %0 = linalg.init_tensor [%d0, %arg1] : tensor<?x?xf32>
+ %1 = linalg.generic
+ {indexing_maps = [affine_map<(d0, d1) -> (d0)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0 : tensor<?xf32>) outs(%0 : tensor<?x?xf32>) {
+ ^bb0(%arg2: f32, %arg3 : f32):
+ linalg.yield %arg2 : f32
+ } -> tensor<?x?xf32>
+ %2 = memref.dim %1, %c0 : tensor<?x?xf32>
+ %3 = memref.dim %1, %c1 : tensor<?x?xf32>
+ return %2, %3 : index, index
+}
+// CHECK: func @keep_result_dim_uses_sequence2
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[T0:.+]] = memref.dim %[[ARG0]], %[[C0]]
+// CHECK: return %[[T0]], %[[ARG1]]
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+
+func @init_tensor_dim_of_linalg_result(%arg_0 : tensor<?xf32>,
+ %arg_1: tensor<?xf32>) -> (index, index) {
+ %0, %1 = linalg.generic {
+ indexing_maps = [#map, #map, #map],
+ iterator_types = ["parallel"]
+ } ins(%arg_0 : tensor<?xf32>)
+ outs(%arg_0, %arg_1 : tensor<?xf32>, tensor<?xf32>) {
+ ^bb0(%in: f32, %out_0: f32, %out_1: f32):
+ linalg.yield %in, %in : f32, f32
+ } -> (tensor<?xf32>, tensor<?xf32>)
+
+ %c0 = constant 0 : index
+ %num_elem_0 = memref.dim %0, %c0 : tensor<?xf32>
+
+ %num_elem_1 = memref.dim %1, %c0 : tensor<?xf32>
+ return %num_elem_0, %num_elem_1 : index, index
+}
+// CHECK: func @init_tensor_dim_of_linalg_result(
+// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?xf32>
+// CHECK-SAME: %[[ARG_1:[a-zA-Z0-9_]+]]: tensor<?xf32>)
+// CHECK: %[[R0:.+]] = memref.dim %[[ARG_0]]
+// CHECK: %[[R1:.+]] = memref.dim %[[ARG_0]]
+// CHECK: return %[[R0]], %[[R1]]
+
+// -----
+
+func @dim_reshape_expansion(%arg0 : tensor<6x5x?xf32>) -> (index, index, index)
+{
+ %c1 = constant 1 : index
+ %c3 = constant 3 : index
+ %c4 = constant 4 : index
+ %0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2], [3, 4, 5]]
+ : tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
+ %1 = memref.dim %0, %c1 : tensor<2x3x5x4x?x7xf32>
+ %2 = memref.dim %0, %c3 : tensor<2x3x5x4x?x7xf32>
+ %3 = memref.dim %0, %c4 : tensor<2x3x5x4x?x7xf32>
+ return %1, %2, %3 : index, index, index
+}
+// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
+// CHECK: func @dim_reshape_expansion
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<6x5x?xf32>
+// CHECK-DAG: %[[C2:.+]] = constant 2 : index
+// CHECK-DAG: %[[C3:.+]] = constant 3 : index
+// CHECK-DAG: %[[C4:.+]] = constant 4 : index
+// CHECK: %[[D0:.+]] = memref.dim %[[ARG0]], %[[C2]]
+// CHECK: %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
+// CHECK: return %[[C3]], %[[C4]], %[[D1]]
+
+// -----
+
+func @dim_reshape_collapse(%arg0 : tensor<2x3x5x4x?x7xf32>) -> (index, index)
+{
+ %c1 = constant 1 : index
+ %c2 = constant 2 : index
+ %0 = linalg.tensor_collapse_shape %arg0 [[0, 1], [2], [3, 4, 5]]
+ : tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32>
+ %1 = memref.dim %0, %c1 : tensor<6x5x?xf32>
+ %2 = memref.dim %0, %c2 : tensor<6x5x?xf32>
+ return %1, %2 : index, index
+}
+// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 28)>
+// CHECK: func @dim_reshape_collapse
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x3x5x4x?x7xf32>
+// CHECK-DAG: %[[C4:.+]] = constant 4 : index
+// CHECK-DAG: %[[C5:.+]] = constant 5 : index
+// CHECK: %[[D0:.+]] = memref.dim %[[ARG0]], %[[C4]]
+// CHECK: %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
+// CHECK: return %[[C5]], %[[D1]]
+
+// -----
+
+func @dim_of_pad_op(%arg0 : tensor<2x?x?xf32>, %arg1 : index, %arg2 : index,
+ %arg3: f32) -> (index, index, index)
+{
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c2 = constant 2 : index
+ %c3 = constant 3 : index
+ %c4 = constant 4 : index
+ %c5 = constant 5 : index
+ %0 = linalg.pad_tensor %arg0 low[%c3, %arg1, %c4] high[7, %c5, %arg2] {
+ ^bb0(%arg4: index, %arg5: index, %arg6: index):
+ linalg.yield %arg3 : f32
+ } : tensor<2x?x?xf32> to tensor<?x?x?xf32>
+ %1 = memref.dim %0, %c0 : tensor<?x?x?xf32>
+ %2 = memref.dim %0, %c1 : tensor<?x?x?xf32>
+ %3 = memref.dim %0, %c2 : tensor<?x?x?xf32>
+ return %1, %2, %3 : index, index, index
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s1 + s0 + 5)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s1 + s0 + 4)>
+// CHECK: func @dim_of_pad_op
+// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]+]]: tensor<2x?x?xf32>
+// CHECK-SAME: %[[ARG1:[A-Za-z0-9_]+]]: index
+// CHECK-SAME: %[[ARG2:[A-Za-z0-9_]+]]: index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = constant 2 : index
+// CHECK-DAG: %[[C12:.+]] = constant 12 : index
+// CHECK: %[[IN_DIM1:.+]] = memref.dim %[[ARG0]], %[[C1]]
+// CHECK: %[[OUT_DIM1:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]], %[[IN_DIM1]]]
+// CHECK: %[[IN_DIM2:.+]] = memref.dim %[[ARG0]], %[[C2]]
+// CHECK: %[[OUT_DIM2:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[IN_DIM2]]]
+// CHECK: return %[[C12]], %[[OUT_DIM1]], %[[OUT_DIM2]]
// CHECK: #[[BOUND8_MAP:.+]] = affine_map<(d0)[s0] -> (8, -d0 + s0)>
// CHECK: #[[BOUND8_MAP_2:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s0, 8, -d0 + s1)>
-// CHECK: #[[BOUND8_MAP_3:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 8)>
// CHECK: #[[BOUND16_MAP:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
// CHECK: #[[X2_MAP:.+]] = affine_map<(d0) -> (d0 * 2)>
// CHECK: #[[INPUT_BOUND:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * 2 + s0 - 2, d1 * -2 + s1)>
-// CHECK: #[[BOUND16_MAP_2:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 16)>
+// CHECK: #[[BOUND16_MAP_2:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s0, 16, -d0 + s1)>
// CHECK: #[[BOUND4_MAP:.+]] = affine_map<(d0)[s0] -> (4, -d0 + s0)>
// CHECK: #[[BOUND2_MAP:.+]] = affine_map<(d0)[s0] -> (2, -d0 + s0)>
-// CHECK: #[[BOUND4_MAP_2:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 4)>
+// CHECK: #[[BOUND4_MAP_2:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s0, 4, -d0 + s1)>
// CHECK: #[[BOUND2_MAP_2:.+]] = affine_map<(d0, d1)[s0, s1] -> (-d0 + s0, 2, -d1 + s1)>
-// CHECK: #[[BOUND2_MAP_3:.+]] = affine_map<(d0, d1)[s0] -> (-d0 + s0, 2, -d1 + s0)>
// CHECK: func @conv_tensors_dynamic
// CHECK-SAME: (%[[INPUT]]: tensor<?x?x?x?xf32>, %[[FILTER]]: tensor<?x?x?x?xf32>, %[[ELEM]]: tensor<?x?x?x?xf32>)
// CHECK-DAG: %[[INPUT_C:.+]] = memref.dim %[[INPUT]], %[[C3]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[FILTER_IC:.+]] = memref.dim %[[FILTER]], %[[C2]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[FILTER_OC:.+]] = memref.dim %[[FILTER]], %[[C3]] : tensor<?x?x?x?xf32>
+// CHECK-DAG: %[[FILL_N:.+]] = memref.dim %[[FILL]], %[[C0]] : tensor<?x?x?x?xf32>
+// CHECK-DAG: %[[FILL_H:.+]] = memref.dim %[[FILL]], %[[C1]] : tensor<?x?x?x?xf32>
+// CHECK-DAG: %[[FILL_W:.+]] = memref.dim %[[FILL]], %[[C2]] : tensor<?x?x?x?xf32>
+// CHECK-DAG: %[[FILL_C:.+]] = memref.dim %[[FILL]], %[[C3]] : tensor<?x?x?x?xf32>
// CHECK: scf.for %[[IV0:.+]] = %{{.+}} to %[[ELEM_N]] step %{{.+}} iter_args(%{{.+}} = %[[FILL]])
// CHECK-NEXT: %[[SIZE_ELEM_N:.+]] = affine.min #[[BOUND8_MAP]](%[[IV0]])[%[[ELEM_N]]]
// CHECK-NEXT: %[[SIZE_INPUT_N:.+]] = affine.min #[[BOUND8_MAP_2]](%[[IV0]])[%[[INPUT_N]], %[[ELEM_N]]]
-// CHECK-NEXT: %[[SIZE_ELEM_N_2:.+]] = affine.min #[[BOUND8_MAP_3]](%[[IV0]])[%[[ELEM_N]]]
+// CHECK-NEXT: %[[SIZE_ELEM_N_2:.+]] = affine.min #[[BOUND8_MAP_2]](%[[IV0]])[%[[FILL_N]], %[[ELEM_N]]]
// CHECK-NEXT: scf.for %[[IV1:.+]] = %{{.+}} to %[[ELEM_OH]]
// CHECK-NEXT: %[[SIZE_ELEM_OH:.+]] = affine.min #[[BOUND16_MAP]](%[[IV1]])[%[[ELEM_OH]]]
// CHECK-NEXT: %[[OFFSET_OH:.+]] = affine.apply #[[X2_MAP]](%[[IV1]])
// CHECK-NEXT: %[[SIZE_INPUT_H:.+]] = affine.min #[[INPUT_BOUND]](%[[SIZE_ELEM_OH]], %[[IV1]])[%[[FILTER_H]], %[[INPUT_H]]]
-// CHECK-NEXT: %[[SIZE_ELEM_OH_2:.+]] = affine.min #[[BOUND16_MAP_2]](%[[IV1]])[%[[ELEM_OH]]]
+// CHECK-NEXT: %[[SIZE_ELEM_OH_2:.+]] = affine.min #[[BOUND16_MAP_2]](%[[IV1]])[%[[FILL_H]], %[[ELEM_OH]]]
// CHECK-NEXT: scf.for %[[IV2:.+]] = %{{.+}} to %[[ELEM_OW]]
// CHECK-NEXT: %[[SIZE_ELEM_OW:.+]] = affine.min #[[BOUND4_MAP]](%[[IV2]])[%[[ELEM_OW]]]
// CHECK-NEXT: %[[SIZE_ELEM_OC:.+]] = affine.min #[[BOUND2_MAP]](%[[IV2]])[%[[ELEM_OC]]]
// CHECK-NEXT: %[[SIZE_INPUT_W:.+]] = affine.min #[[INPUT_BOUND]](%[[SIZE_ELEM_OW]], %[[IV2]])[%[[FILTER_W]], %[[INPUT_W]]]
// CHECK-NEXT: %[[ST_INPUT:.+]] = subtensor %[[INPUT]][%[[IV0]], %[[OFFSET_OH]], %[[OFFSET_OW]], 0]
// CHECK-SAME: [%[[SIZE_INPUT_N]], %[[SIZE_INPUT_H]], %[[SIZE_INPUT_W]], %[[INPUT_C]]]
-// CHECK-NEXT: %[[SIZE_ELEM_OW_2:.+]] = affine.min #[[BOUND4_MAP_2]](%[[IV2]])[%[[ELEM_OW]]]
+// CHECK-NEXT: %[[SIZE_ELEM_OW_2:.+]] = affine.min #[[BOUND4_MAP_2]](%[[IV2]])[%[[FILL_W]], %[[ELEM_OW]]]
// CHECK-NEXT: scf.for %[[IV3:.+]] = %{{.+}} to %[[ELEM_OC]] step %{{.+}} iter_args(%[[ARG:[a-z0-9]+]]
// CHECK-NEXT: %[[ST_ELEM:.+]] = subtensor %[[ELEM]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
// CHECK-SAME: [%[[SIZE_ELEM_N]], %[[SIZE_ELEM_OH]], %[[SIZE_ELEM_OW]], %[[SIZE_ELEM_OC]]]
// CHECK-NEXT: %[[SIZE_ELEM_OC_2:.+]] = affine.min #[[BOUND2_MAP_2]](%[[IV3]], %[[IV2]])[%[[FILTER_OC]], %[[ELEM_OC]]]
// CHECK-NEXT: %[[ST_FILTER:.+]] = subtensor %[[FILTER]][0, 0, 0, %[[IV3]]]
// CHECK-SAME: [%[[FILTER_H]], %[[FILTER_W]], %[[FILTER_IC]], %[[SIZE_ELEM_OC_2]]]
-// CHECK-NEXT: %[[SIZE_ELEM_OC_3:.+]] = affine.min #[[BOUND2_MAP_3]](%[[IV3]], %[[IV2]])[%[[ELEM_OC]]]
+// CHECK-NEXT: %[[SIZE_ELEM_OC_3:.+]] = affine.min #[[BOUND2_MAP_2]](%[[IV3]], %[[IV2]])[%[[FILL_C]], %[[ELEM_OC]]]
// CHECK-NEXT: %[[ST_FILL:.+]] = subtensor %[[FILL]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
// CHECK-SAME: [%[[SIZE_ELEM_N_2]], %[[SIZE_ELEM_OH_2]], %[[SIZE_ELEM_OW_2]], %[[SIZE_ELEM_OC_3]]]
// CHECK-NEXT: %[[ST_CONV:.+]] = linalg.conv_2d_input_nhwc_filter_hwcf
--- /dev/null
+// RUN: mlir-opt %s -resolve-shaped-type-result-dims -split-input-file | FileCheck %s
+
+func @result_shape(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
+ -> (index, index, index, index, index) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c2 = constant 2 : index
+ %0:2 = "test.op_with_result_shape_interface"(%arg0, %arg1)
+ : (tensor<2x3x?xf32>, tensor<?x5xf32>) -> (tensor<?x5xf32>, tensor<2x3x?xf32>)
+ %1 = memref.dim %0#0, %c0 : tensor<?x5xf32>
+ %2 = memref.dim %0#0, %c1 : tensor<?x5xf32>
+ %3 = memref.dim %0#1, %c0 : tensor<2x3x?xf32>
+ %4 = memref.dim %0#1, %c1 : tensor<2x3x?xf32>
+ %5 = memref.dim %0#1, %c2 : tensor<2x3x?xf32>
+ return %1, %2, %3, %4, %5 : index, index, index, index, index
+}
+// CHECK-LABEL: func @result_shape(
+// CHECK-SAME: %[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32>
+// CHECK-SAME: %[[ARG_1:[a-z0-9]*]]: tensor<?x5xf32>)
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C2:.+]] = constant 2 : index
+// CHECK-DAG: %[[C3:.+]] = constant 3 : index
+// CHECK-DAG: %[[C5:.+]] = constant 5 : index
+// CHECK-DAG: %[[D0:.+]] = memref.dim %[[ARG_1]], %[[C0]]
+// CHECK-DAG: %[[S0:.+]] = tensor.from_elements %[[D0]], %[[C5]]
+// CHECK-DAG: %[[D0_OUT:.+]] = tensor.extract %[[S0]][%[[C0]]]
+// CHECK-DAG: %[[D1:.+]] = memref.dim %[[ARG_0]], %[[C2]]
+// CHECK-DAG: %[[S1:.+]] = tensor.from_elements %[[C2]], %[[C3]], %[[D1]]
+// CHECK-DAG: %[[D1_OUT:.+]] = tensor.extract %[[S1]][%[[C2]]]
+// CHECK: return %[[D0_OUT]], %[[C5]], %[[C2]], %[[C3]], %[[D1_OUT]]
+
+// -----
+
+func @result_shape_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
+ -> (index, index, index, index, index) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c2 = constant 2 : index
+ %0:2 = "test.op_with_result_shape_per_dim_interface"(%arg0, %arg1)
+ : (tensor<2x3x?xf32>, tensor<?x5xf32>) -> (tensor<?x5xf32>, tensor<2x3x?xf32>)
+ %1 = memref.dim %0#0, %c0 : tensor<?x5xf32>
+ %2 = memref.dim %0#0, %c1 : tensor<?x5xf32>
+ %3 = memref.dim %0#1, %c0 : tensor<2x3x?xf32>
+ %4 = memref.dim %0#1, %c1 : tensor<2x3x?xf32>
+ %5 = memref.dim %0#1, %c2 : tensor<2x3x?xf32>
+ return %1, %2, %3, %4, %5 : index, index, index, index, index
+}
+// CHECK-LABEL: func @result_shape_per_dim(
+// CHECK-SAME: %[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32>
+// CHECK-SAME: %[[ARG_1:[a-z0-9]*]]: tensor<?x5xf32>)
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C2:.+]] = constant 2 : index
+// CHECK-DAG: %[[C3:.+]] = constant 3 : index
+// CHECK-DAG: %[[C5:.+]] = constant 5 : index
+// CHECK-DAG: %[[D0:.+]] = memref.dim %[[ARG_1]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = memref.dim %[[ARG_0]], %[[C2]]
+// CHECK: return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]]
+
+// -----
+
+func @result_shape_and_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
+ -> (index, index, index, index, index) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c2 = constant 2 : index
+ %0:2 = "test.op_with_result_shape_and_per_dim_interface"(%arg0, %arg1)
+ : (tensor<2x3x?xf32>, tensor<?x5xf32>) -> (tensor<?x5xf32>, tensor<2x3x?xf32>)
+ %1 = memref.dim %0#0, %c0 : tensor<?x5xf32>
+ %2 = memref.dim %0#0, %c1 : tensor<?x5xf32>
+ %3 = memref.dim %0#1, %c0 : tensor<2x3x?xf32>
+ %4 = memref.dim %0#1, %c1 : tensor<2x3x?xf32>
+ %5 = memref.dim %0#1, %c2 : tensor<2x3x?xf32>
+ return %1, %2, %3, %4, %5 : index, index, index, index, index
+}
+// CHECK-LABEL: func @result_shape_and_per_dim(
+// CHECK-SAME: %[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32>
+// CHECK-SAME: %[[ARG_1:[a-z0-9]*]]: tensor<?x5xf32>)
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C2:.+]] = constant 2 : index
+// CHECK-DAG: %[[C3:.+]] = constant 3 : index
+// CHECK-DAG: %[[C5:.+]] = constant 5 : index
+// CHECK-DAG: %[[D0:.+]] = memref.dim %[[ARG_1]], %[[C0]]
+// CHECK-DAG: %[[S0:.+]] = tensor.from_elements %[[D0]], %[[C5]]
+// CHECK-DAG: %[[D0_OUT:.+]] = tensor.extract %[[S0]][%[[C0]]]
+// CHECK-DAG: %[[D1:.+]] = memref.dim %[[ARG_0]], %[[C2]]
+// CHECK-DAG: %[[S1:.+]] = tensor.from_elements %[[C2]], %[[C3]], %[[D1]]
+// CHECK-DAG: %[[D1_OUT:.+]] = tensor.extract %[[S1]][%[[C2]]]
+// CHECK: return %[[D0_OUT]], %[[C5]], %[[C2]], %[[C3]], %[[D1_OUT]]
return %0 : i32
}
-// CHECK-LABEL: func @result_shape_per_dim
-// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32>, %[[ARG_1:[a-z0-9]*]]: tensor<?x5xf32>)
-func @result_shape_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
- -> (index, index, index, index, index) {
- // CHECK-DAG: %[[C0:.+]] = constant 0 : index
- // CHECK-DAG: %[[C2:.+]] = constant 2 : index
- // CHECK-DAG: %[[C3:.+]] = constant 3 : index
- // CHECK-DAG: %[[C5:.+]] = constant 5 : index
- %c0 = constant 0 : index
- %c1 = constant 1 : index
- %c2 = constant 2 : index
- %0:2 = "test.op_with_result_shape_per_dim_interface"(%arg0, %arg1)
- : (tensor<2x3x?xf32>, tensor<?x5xf32>) -> (tensor<?x5xf32>, tensor<2x3x?xf32>)
- %1 = memref.dim %0#0, %c0 : tensor<?x5xf32>
- %2 = memref.dim %0#0, %c1 : tensor<?x5xf32>
- %3 = memref.dim %0#1, %c0 : tensor<2x3x?xf32>
- %4 = memref.dim %0#1, %c1 : tensor<2x3x?xf32>
- %5 = memref.dim %0#1, %c2 : tensor<2x3x?xf32>
- // CHECK-DAG: %[[D0:.+]] = memref.dim %[[ARG_1]], %[[C0]]
- // CHECK-DAG: %[[D1:.+]] = memref.dim %[[ARG_0]], %[[C2]]
- // CHECK: return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]]
- return %1, %2, %3, %4, %5 : index, index, index, index, index
-}
-
// CHECK-LABEL: test_dialect_canonicalizer
func @test_dialect_canonicalizer() -> (i32) {
%0 = "test.dialect_canonicalizable"() : () -> (i32)
MLIRReduce
MLIRStandard
MLIRStandardOpsTransforms
+ MLIRTensor
MLIRTransformUtils
MLIRTransforms
)
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/PatternMatch.h"
return success();
}
+/// Reifies one shape value per operand, visiting `operands` in reverse order,
+/// and appends each to `shapes`. Every shape value is a
+/// `tensor::FromElementsOp` of index type whose elements are the per-dimension
+/// sizes of the corresponding operand. Always returns success.
+LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
+    OpBuilder &builder, ValueRange operands,
+    llvm::SmallVectorImpl<Value> &shapes) {
+  Location loc = getLoc();
+  shapes.reserve(operands.size());
+  // Operands are walked back-to-front, so the first shape pushed is built from
+  // the last operand.
+  for (Value operand : llvm::reverse(operands)) {
+    // `createOrFold` is used so a dimension query may fold to an existing
+    // value instead of always materializing a `memref::DimOp`.
+    auto currShape = llvm::to_vector<4>(llvm::map_range(
+        llvm::seq<int64_t>(
+            0, operand.getType().cast<RankedTensorType>().getRank()),
+        [&](int64_t dim) -> Value {
+          return builder.createOrFold<memref::DimOp>(loc, operand, dim);
+        }));
+    // Reuse the cached location rather than querying the op a second time.
+    shapes.push_back(builder.create<tensor::FromElementsOp>(
+        loc, builder.getIndexType(), currShape));
+  }
+  return success();
+}
+
LogicalResult
OpWithResultShapePerDimInterfaceOp ::reifyReturnTypeShapesPerResultDim(
OpBuilder &builder,
llvm::SmallVectorImpl<llvm::SmallVector<Value>> &shapes) {
- SmallVector<Value> operand1Shape, operand2Shape;
Location loc = getLoc();
- for (auto i :
- llvm::seq<int>(0, operand1().getType().cast<ShapedType>().getRank())) {
- operand1Shape.push_back(builder.create<memref::DimOp>(loc, operand1(), i));
+ shapes.reserve(getNumOperands());
+ for (Value operand : llvm::reverse(getOperands())) {
+ auto currShape = llvm::to_vector<4>(llvm::map_range(
+ llvm::seq<int64_t>(
+ 0, operand.getType().cast<RankedTensorType>().getRank()),
+ [&](int64_t dim) -> Value {
+ return builder.createOrFold<memref::DimOp>(loc, operand, dim);
+ }));
+ shapes.emplace_back(std::move(currShape));
}
- for (auto i :
- llvm::seq<int>(0, operand2().getType().cast<ShapedType>().getRank())) {
- operand2Shape.push_back(builder.create<memref::DimOp>(loc, operand2(), i));
+ return success();
+}
+
+/// Reifies one shape value per operand, visiting `operands` in reverse order,
+/// and appends each to `shapes`. Every shape value is a
+/// `tensor::FromElementsOp` of index type whose elements are the per-dimension
+/// sizes of the corresponding operand. Always returns success.
+LogicalResult OpWithResultShapeAndPerDimInterfaceOp::reifyReturnTypeShapes(
+    OpBuilder &builder, ValueRange operands,
+    llvm::SmallVectorImpl<Value> &shapes) {
+  Location loc = getLoc();
+  shapes.reserve(operands.size());
+  // Operands are walked back-to-front, so the first shape pushed is built from
+  // the last operand.
+  for (Value operand : llvm::reverse(operands)) {
+    // `createOrFold` is used so a dimension query may fold to an existing
+    // value instead of always materializing a `memref::DimOp`.
+    auto currShape = llvm::to_vector<4>(llvm::map_range(
+        llvm::seq<int64_t>(
+            0, operand.getType().cast<RankedTensorType>().getRank()),
+        [&](int64_t dim) -> Value {
+          return builder.createOrFold<memref::DimOp>(loc, operand, dim);
+        }));
+    // Reuse the cached location rather than querying the op a second time.
+    shapes.push_back(builder.create<tensor::FromElementsOp>(
+        loc, builder.getIndexType(), currShape));
+  }
+  return success();
+}
+
+LogicalResult
+OpWithResultShapeAndPerDimInterfaceOp ::reifyReturnTypeShapesPerResultDim(
+ OpBuilder &builder,
+ llvm::SmallVectorImpl<llvm::SmallVector<Value>> &shapes) {
+ Location loc = getLoc();
+ shapes.reserve(getNumOperands());
+ for (Value operand : llvm::reverse(getOperands())) {
+ auto currShape = llvm::to_vector<4>(llvm::map_range(
+ llvm::seq<int64_t>(
+ 0, operand.getType().cast<RankedTensorType>().getRank()),
+ [&](int64_t dim) -> Value {
+ return builder.createOrFold<memref::DimOp>(loc, operand, dim);
+ }));
+ shapes.emplace_back(std::move(currShape));
}
- shapes.emplace_back(std::move(operand2Shape));
- shapes.emplace_back(std::move(operand1Shape));
return success();
}
let results = (outs AnyTensor);
}
-def OpWithResultShapePerDimInterfaceOp : TEST_Op<"op_with_result_shape_per_dim_interface",
+def OpWithResultShapeInterfaceOp : TEST_Op<"op_with_result_shape_interface",
[DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["reifyReturnTypeShapesPerResultDim"]>]> {
+ ["reifyReturnTypeShapes"]>]> {
+ let arguments = (ins AnyRankedTensor:$operand1, AnyRankedTensor:$operand2);
+ let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2);
+}
+
+def OpWithResultShapePerDimInterfaceOp :
+ TEST_Op<"op_with_result_shape_per_dim_interface",
+ [DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+ ["reifyReturnTypeShapesPerResultDim"]>]> {
+ let arguments = (ins AnyRankedTensor:$operand1, AnyRankedTensor:$operand2);
+ let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2);
+}
+
+def OpWithResultShapeAndPerDimInterfaceOp :
+ TEST_Op<"op_with_result_shape_and_per_dim_interface",
+ [DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+ ["reifyReturnTypeShapes", "reifyReturnTypeShapesPerResultDim"]>]> {
let arguments = (ins AnyRankedTensor:$operand1, AnyRankedTensor:$operand2);
let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2);
}