#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
namespace mlir {
/// Reifies the values that express the shapes of the results in terms of the
/// shapes of the input operands, where possible.
- LogicalResult reifyReturnTypeShapesPerResultDim(OpBuilder &b,
- SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes);
+ LogicalResult reifyResultShapes(OpBuilder &b,
+ ReifiedRankedShapedTypeDims &reifiedReturnShapes);
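+
+// For illustration (hypothetical caller; `linalgOp` is assumed to be a
+// `LinalgOp` with ranked tensor results, and `b` an `OpBuilder` in scope):
+//
+//   ReifiedRankedShapedTypeDims shapes;
+//   if (succeeded(linalgOp.reifyResultShapes(b, shapes)))
+//     Value dim0OfResult0 = shapes[0][0];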
//========================================================================//
// Helper functions to mutate the `operand_segment_sizes` attribute.
def Linalg_InitTensorOp : Linalg_Op<"init_tensor",
[NoSideEffect,
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["reifyReturnTypeShapesPerResultDim"]>]> {
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
let summary = "operation to define a tensor of particular value";
let description = [{
}
def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
- [AttrSizedOperandSegments,
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["reifyReturnTypeShapesPerResultDim"]>,
- NoSideEffect]> {
+ [AttrSizedOperandSegments, NoSideEffect,
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
let summary = "tensor pad operation";
let description = [{
`linalg.pad_tensor` is an operation that pads the `source` tensor
class Linalg_TensorReshapeOp<string mnemonic> : Linalg_ReshapeLikeOp<
mnemonic,
- [DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["reifyReturnTypeShapesPerResultDim"]>]>,
+ [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]>,
Arguments<(ins AnyTensor:$src,
IndexListArrayAttr:$reassociation)>,
Results<(outs AnyTensor:$result)> {
// depending on the specific Linalg op.
class LinalgStructuredBase_Op<string mnemonic, list<OpTrait> props>
: Op<Linalg_Dialect, mnemonic, !listconcat(props, [
- LinalgStructuredInterface, InferShapedTypeOpInterface])> {
+ LinalgStructuredInterface, ReifyRankedShapedTypeOpInterface])> {
code structuredOpsBaseDecls = [{
// Return whether the op accesses the iteration indices.
bool hasIndexSemantics() {
return !op->getRegion(0).front().getOps<IndexOp>().empty();
}
- LogicalResult reifyReturnTypeShapesPerResultDim(OpBuilder &b,
- SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
- return cast<LinalgOp>(getOperation()).reifyReturnTypeShapesPerResultDim(b,
+ LogicalResult reifyResultShapes(OpBuilder &b,
+ ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+ return cast<LinalgOp>(getOperation()).reifyResultShapes(b,
reifiedReturnShapes);
}
}];
void populateFoldSubViewOpPatterns(RewritePatternSet &patterns);
/// Appends patterns that resolve `memref.dim` operations with values that are
+/// defined by operations that implement the
+/// `ReifyRankedShapedTypeOpInterface`, in terms of the shapes of their input
+/// operands.
+void populateResolveRankedShapeTypeResultDimsPatterns(
+ RewritePatternSet &patterns);
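+
+// For illustration (hypothetical IR): with these patterns, a query such as
+//
+//   %d = tensor.dim %result, %c0 : tensor<?x42xf32>
+//
+// where `%result` is produced by an op implementing
+// `ReifyRankedShapedTypeOpInterface`, is rewritten so that `%d` is replaced
+// by the reified `Value` for dimension 0 of that result.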
+
+/// Appends patterns that resolve `memref.dim` operations with values that are
/// defined by operations that implement the `InferShapedTypeOpInterface`, in
/// terms of the shapes of their input operands.
void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns);
/// Creates an operation pass to resolve `memref.dim` operations with values
/// that are defined by operations that implement the
-/// `InferShapedTypeOpInterface`, in terms of shapes of its input operands.
+/// `ReifyRankedShapedTypeOpInterface`, in terms of the shapes of their input
+/// operands.
+std::unique_ptr<Pass> createResolveRankedShapeTypeResultDimsPass();
+
+/// Creates an operation pass to resolve `memref.dim` operations with values
+/// that are defined by operations that implement the
+/// `InferShapedTypeOpInterface` or the `ReifyRankedShapedTypeOpInterface`,
+/// in terms of the shapes of their input operands.
std::unique_ptr<Pass> createResolveShapedTypeResultDimsPass();
//===----------------------------------------------------------------------===//
];
}
+def ResolveRankedShapeTypeResultDims :
+ Pass<"resolve-ranked-shaped-type-result-dims"> {
+  let summary = "Resolve memref.dim of result values of ranked shaped type";
+ let description = [{
+    The pass resolves `memref.dim` of results of operations that
+    implement the `ReifyRankedShapedTypeOpInterface` in terms of the
+    shapes of their operands.
+ }];
+ let constructor =
+ "mlir::memref::createResolveRankedShapeTypeResultDimsPass()";
+ let dependentDialects = [
+ "memref::MemRefDialect", "tensor::TensorDialect"
+ ];
+}
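+
+// For example (assuming the pass is registered with `mlir-opt`), it can be
+// invoked as:
+//
+//   mlir-opt --resolve-ranked-shaped-type-result-dims input.mlir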
+
def ResolveShapedTypeResultDims : Pass<"resolve-shaped-type-result-dims"> {
let summary = "Resolve memref.dim of result values";
let description = [{
The pass resolves memref.dim of results of operations that
- implement the `InferShapedTypeOpInterface` in terms of shapes of
- its operands.
+    implement the `InferShapedTypeOpInterface` or the
+    `ReifyRankedShapedTypeOpInterface` in terms of the shapes of their
+    operands.
}];
let constructor = "mlir::memref::createResolveShapedTypeResultDimsPass()";
let dependentDialects = [
def Tensor_InsertSliceOp : BaseOpWithOffsetSizesAndStrides<
Tensor_Dialect, "insert_slice",
[NoSideEffect, AttrSizedOperandSegments, OffsetSizeAndStrideOpInterface,
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["reifyReturnTypeShapesPerResultDim"]>,
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
TypesMatchWith<"expected result type to match dest type",
"dest", "result", "$_self">]> {
let summary = "insert_slice operation";
namespace mlir {
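+/// One vector of `Value`s per op result, with one entry per dimension of that
+/// result; a vector is left empty when the shape of the corresponding result
+/// cannot be computed.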
+using ReifiedRankedShapedTypeDims = SmallVector<SmallVector<Value>>;
+
/// ShapedTypeComponents that represents the components of a ShapedType.
/// The components consist of
/// - A ranked or unranked shape with the dimension specification match those
/*desc=*/[{Reify the shape computation for the operation.
Insert operations using the given OpBuilder that computes the
- result shape. Only one of this method or
- `reifyReturnTypeShapesPerResultDim` needs to be overriden by the
- operation. This interface is supposed to be workable during dialect
+ result shape. This interface is supposed to be workable during dialect
conversion (e.g. convert from tensor world to buffer world),
where `getOperand` may be invalid. For example, some ops (e.g.
dynamic_reshape(input, target_shape)) may depend on their operands
"::mlir::SmallVectorImpl<::mlir::Value> &":$reifiedReturnShapes),
/*methodBody=*/[{}],
/*defaultImplementation=*/[{ return ::mlir::failure(); }]
- >,
- InterfaceMethod<
- /*desc=*/[{Reify the shape computation for the operation.
-
- Insert operations using the given OpBuilder that computes the
- result shape. The `reifiedReturnShapes` is expected to be
- populated with as many vectors as the number of results of the
- op (empty if the shape of a result value cannot be computed). If
- the returned shape for a result is not empty, its size must
- match the rank of the shaped type returned. Consequently, this
- interface can only be overridden if the return types are ranked.
-
- If both this method and `reifyReturnTypeShapes` are overridden
- by the operation, `reifyReturnTypeShapes` takes precedence. This
- method is intended to be used when the shape of each result, dim
- pair can be computed independently. Using this method avoids
- adding additional instructions to aggregate individual dimension
- of a result shape into an single `Value` (and consequently
- avoids the need to extract the value from the shape on the
- client side).
- }],
- /*retTy=*/"::mlir::LogicalResult",
- /*methodName=*/"reifyReturnTypeShapesPerResultDim",
- /*args=*/(ins "::mlir::OpBuilder&":$builder,
- "::mlir::SmallVectorImpl<::mlir::SmallVector<::mlir::Value>>&"
- :$reifiedReturnShapes),
- /*methodBody=*/[{}],
- /*defaultImplementation=*/[{ return ::mlir::failure(); }]
>
];
}
defvar InferTensorTypeWithReify = InferTensorType<[
"inferReturnTypeComponents", "reifyReturnTypeShapes"]>;
+
+def ReifyRankedShapedTypeOpInterface :
+ OpInterface<"ReifyRankedShapedTypeOpInterface"> {
+ let description = [{
+ Interface to compute the shape of the result of an operation when
+    the result is a ranked shaped type, i.e. `RankedTensorType` or
+ `MemRefType`.
+ }];
+ let cppNamespace = "::mlir";
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+        Reify the shape of the result of an operation (typically in
+        terms of the shapes of its operands).
+
+        Insert operations using the given `OpBuilder` that compute the
+        result shape. The `reifiedReturnShapes` vector is expected to
+        be populated with as many vectors as the number of results of
+        the op. Each of these vectors is expected to be of size equal
+        to the rank of the corresponding result. If the shape of a
+        particular result cannot be computed, its vector must be left
+        empty.
+ }],
+ /*retTy=*/"LogicalResult",
+ /*methodName=*/"reifyResultShapes",
+ /*args=*/(ins "::mlir::OpBuilder &":$builder,
+ "ReifiedRankedShapedTypeDims &":$reifiedReturnShapes)
+ >
+ ];
+}
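+
+// A minimal sketch of an implementation (hypothetical op `MyOp` with a single
+// ranked tensor result; the `source()` operand accessor is assumed):
+//
+//   LogicalResult MyOp::reifyResultShapes(
+//       OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+//     SmallVector<Value> shape;
+//     for (int64_t dim = 0, rank = getType().getRank(); dim < rank; ++dim)
+//       shape.push_back(b.createOrFold<tensor::DimOp>(getLoc(), source(), dim));
+//     reifiedReturnShapes.emplace_back(std::move(shape));
+//     return success();
+//   }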
+
#endif // MLIR_INFERTYPEOPINTERFACE
llvm::SmallSet<unsigned, 4> positions;
};
-LogicalResult LinalgOp::reifyReturnTypeShapesPerResultDim(
- OpBuilder &b, SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
+LogicalResult
+LinalgOp::reifyResultShapes(OpBuilder &b,
+ ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
// An example that helps understand the logic below.
// Consider the following expression O(i+j, j) += A(i,k) * B(k, j)
// We want to express the shape of dim 0 of O in terms of shape of the inputs.
if (!reshapeOp.src().template getDefiningOp<InitTensorOp>())
return failure();
Location loc = reshapeOp.getLoc();
- SmallVector<SmallVector<Value>, 4> resultShapes;
- if (failed(reshapeOp.reifyReturnTypeShapesPerResultDim(rewriter,
- resultShapes)) ||
+ ReifiedRankedShapedTypeDims resultShapes;
+ if (failed(reshapeOp.reifyResultShapes(rewriter, resultShapes)) ||
!llvm::hasSingleElement(resultShapes))
return failure();
Value initTensor = rewriter.create<InitTensorOp>(
ReplaceStaticShapeDims>(context);
}
-LogicalResult InitTensorOp::reifyReturnTypeShapesPerResultDim(
- OpBuilder &builder,
- SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
+LogicalResult InitTensorOp::reifyResultShapes(
+ OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
auto shapes = llvm::to_vector<4>(llvm::map_range(
llvm::seq<int64_t>(0, getType().getRank()), [&](int64_t dim) -> Value {
if (isDynamicSize(dim))
builder);
}
-LogicalResult PadTensorOp::reifyReturnTypeShapesPerResultDim(
- OpBuilder &b, SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
+LogicalResult PadTensorOp::reifyResultShapes(
+ OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
Location loc = getLoc();
auto lowPad = getMixedLowPad();
auto highPad = getMixedHighPad();
FoldReshapeWithConstant<TensorCollapseShapeOp>>(context);
}
-LogicalResult TensorExpandShapeOp::reifyReturnTypeShapesPerResultDim(
- OpBuilder &b, SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
+LogicalResult TensorExpandShapeOp::reifyResultShapes(
+ OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
auto resultShape =
getAsValues(b, getLoc(),
getReshapeOutputShapeFromInputShape(
return success();
}
-LogicalResult TensorCollapseShapeOp::reifyReturnTypeShapesPerResultDim(
- OpBuilder &b, SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
+LogicalResult TensorCollapseShapeOp::reifyResultShapes(
+ OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
auto resultShape =
getAsValues(b, getLoc(),
getReshapeOutputShapeFromInputShape(
-//===- ResolveShapedTypeResultDims.cpp - Resolve memref.dim ops of result values
-//-------===//
+//===- ResolveShapedTypeResultDims.cpp - Resolve dim ops of result values -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
using namespace mlir;
-/// Helper method to get the `Value` that is the shape of the `resultIdx`-th
-/// result at dimension `dimIndex` from the `ShapedTypeOpInterface`.
-/// TODO(ravishankarm): This is better put as a interface utility method
-/// somewhere, but that would imply the interface will depend on the `tensor`
-/// dialect. Ideally maybe a utility method in the `tensor` dialect.
-static Value getResultDimFromShapeInterface(OpBuilder &builder, OpResult result,
- int64_t dimIndex) {
- unsigned resultNumber = result.getResultNumber();
- auto shapedTypeOp = dyn_cast<InferShapedTypeOpInterface>(result.getOwner());
- Location loc = result.getOwner()->getLoc();
- if (!shapedTypeOp)
- return nullptr;
-
- // The interface exposes two methods, one that returns the shape of all the
- // results as `Value` and other that returns the shape as a list of
- // `SmallVector<Value>`. The former takes precedence over the latter. So first
- // check if the op implements the first interface method or the second, and
- // get the value to use appropriately.
- SmallVector<Value> reifiedResultShapes;
- if (succeeded(shapedTypeOp.reifyReturnTypeShapes(
- builder, result.getOwner()->getOperands(), reifiedResultShapes))) {
- if (reifiedResultShapes.size() <= resultNumber)
- return nullptr;
- Value resultShape = reifiedResultShapes[resultNumber];
- auto resultShapeType = resultShape.getType().dyn_cast<RankedTensorType>();
- if (!resultShapeType || !resultShapeType.getElementType().isa<IndexType>())
- return nullptr;
- return builder.create<tensor::ExtractOp>(
- loc, resultShape, builder.createOrFold<ConstantIndexOp>(loc, dimIndex));
- }
-
- SmallVector<SmallVector<Value>> reifiedResultShapesPerDim;
- if (failed(shapedTypeOp.reifyReturnTypeShapesPerResultDim(
- builder, reifiedResultShapesPerDim)))
- return nullptr;
- if (reifiedResultShapesPerDim.size() <= resultNumber ||
- reifiedResultShapesPerDim[resultNumber].size() !=
- static_cast<size_t>(result.getType().cast<ShapedType>().getRank()))
- return nullptr;
- OpFoldResult valueOrAttr = reifiedResultShapesPerDim[resultNumber][dimIndex];
- if (auto attr = valueOrAttr.dyn_cast<Attribute>())
- return builder.createOrFold<ConstantIndexOp>(
- loc, attr.cast<IntegerAttr>().getInt());
- return valueOrAttr.get<Value>();
-}
-
namespace {
/// Fold dim of an operation that implements the InferShapedTypeOpInterface
template <typename OpTy>
Optional<int64_t> dimIndex = dimOp.getConstantIndex();
if (!dimIndex)
return failure();
- Value replacement =
- getResultDimFromShapeInterface(rewriter, dimValue, *dimIndex);
- if (!replacement)
+
+ SmallVector<Value> reifiedResultShapes;
+ if (failed(shapedTypeOp.reifyReturnTypeShapes(
+ rewriter, shapedTypeOp->getOperands(), reifiedResultShapes)))
+ return failure();
+
+ if (reifiedResultShapes.size() != shapedTypeOp->getNumResults())
return failure();
- rewriter.replaceOp(dimOp, replacement);
+
+ Value resultShape = reifiedResultShapes[dimValue.getResultNumber()];
+ auto resultShapeType = resultShape.getType().dyn_cast<RankedTensorType>();
+ if (!resultShapeType || !resultShapeType.getElementType().isa<IndexType>())
+ return failure();
+
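+    // `reifyReturnTypeShapes` reifies the full shape of each result as a
+    // tensor of `index` values; extract the queried dimension from it.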
+ Location loc = dimOp->getLoc();
+ rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
+ dimOp, resultShape,
+ rewriter.createOrFold<ConstantIndexOp>(loc, *dimIndex));
+ return success();
+ }
+};
+
+/// Fold dim of an operation that implements the
+/// `ReifyRankedShapedTypeOpInterface`.
+template <typename OpTy>
+struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
+ using OpRewritePattern<OpTy>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(OpTy dimOp,
+ PatternRewriter &rewriter) const override {
+ OpResult dimValue = dimOp.source().template dyn_cast<OpResult>();
+ if (!dimValue)
+ return failure();
+ auto rankedShapeTypeOp =
+ dyn_cast<ReifyRankedShapedTypeOpInterface>(dimValue.getOwner());
+ if (!rankedShapeTypeOp)
+ return failure();
+
+ Optional<int64_t> dimIndex = dimOp.getConstantIndex();
+ if (!dimIndex)
+ return failure();
+
+    ReifiedRankedShapedTypeDims reifiedResultShapes;
+ if (failed(
+ rankedShapeTypeOp.reifyResultShapes(rewriter, reifiedResultShapes)))
+ return failure();
+
+ if (reifiedResultShapes.size() != rankedShapeTypeOp->getNumResults())
+ return failure();
+
+ unsigned resultNumber = dimValue.getResultNumber();
+    // Bail out if the source is not a ranked shaped type or if the number of
+    // reified dimensions does not match its rank.
+    auto sourceType = dimValue.getType().dyn_cast<ShapedType>();
+    if (!sourceType || !sourceType.hasRank() ||
+        reifiedResultShapes[resultNumber].size() !=
+            static_cast<size_t>(sourceType.getRank()))
+      return failure();
+
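+    // The reified shapes are already per-dimension values; replace the dim op
+    // with the value for the queried dimension directly.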
+ rewriter.replaceOp(dimOp, reifiedResultShapes[resultNumber][*dimIndex]);
return success();
}
};
#define GEN_PASS_CLASSES
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
+struct ResolveRankedShapeTypeResultDimsPass final
+ : public ResolveRankedShapeTypeResultDimsBase<
+ ResolveRankedShapeTypeResultDimsPass> {
+ void runOnOperation() override;
+};
+
struct ResolveShapedTypeResultDimsPass final
: public ResolveShapedTypeResultDimsBase<ResolveShapedTypeResultDimsPass> {
void runOnOperation() override;
};
+
} // namespace
+void memref::populateResolveRankedShapeTypeResultDimsPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<DimOfReifyRankedShapedTypeOpInterface<memref::DimOp>,
+ DimOfReifyRankedShapedTypeOpInterface<tensor::DimOp>>(
+ patterns.getContext());
+}
+
void memref::populateResolveShapedTypeResultDimsPatterns(
RewritePatternSet &patterns) {
// TODO: Move tensor::DimOp pattern to the Tensor dialect.
patterns.getContext());
}
+void ResolveRankedShapeTypeResultDimsPass::runOnOperation() {
+ RewritePatternSet patterns(&getContext());
+ memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
+ if (failed(applyPatternsAndFoldGreedily(getOperation()->getRegions(),
+ std::move(patterns))))
+ return signalPassFailure();
+}
+
void ResolveShapedTypeResultDimsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
+ memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
memref::populateResolveShapedTypeResultDimsPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(getOperation()->getRegions(),
std::move(patterns))))
std::unique_ptr<Pass> memref::createResolveShapedTypeResultDimsPass() {
return std::make_unique<ResolveShapedTypeResultDimsPass>();
}
+
+std::unique_ptr<Pass> memref::createResolveRankedShapeTypeResultDimsPass() {
+ return std::make_unique<ResolveRankedShapeTypeResultDimsPass>();
+}
return OpFoldResult();
}
-LogicalResult InsertSliceOp::reifyReturnTypeShapesPerResultDim(
- OpBuilder &builder,
- SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
+LogicalResult InsertSliceOp::reifyResultShapes(
+ OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
reifiedReturnShapes.resize(1, SmallVector<Value>(getType().getRank()));
for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
reifiedReturnShapes[0][dim] =
// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG_1]], %[[C0]]
// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG_0]], %[[C2]]
// CHECK: return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]]
-
-// -----
-
-func @result_shape_and_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
- -> (index, index, index, index, index) {
- %c0 = constant 0 : index
- %c1 = constant 1 : index
- %c2 = constant 2 : index
- %0:2 = "test.op_with_result_shape_and_per_dim_interface"(%arg0, %arg1)
- : (tensor<2x3x?xf32>, tensor<?x5xf32>) -> (tensor<?x5xf32>, tensor<2x3x?xf32>)
- %1 = tensor.dim %0#0, %c0 : tensor<?x5xf32>
- %2 = tensor.dim %0#0, %c1 : tensor<?x5xf32>
- %3 = tensor.dim %0#1, %c0 : tensor<2x3x?xf32>
- %4 = tensor.dim %0#1, %c1 : tensor<2x3x?xf32>
- %5 = tensor.dim %0#1, %c2 : tensor<2x3x?xf32>
- return %1, %2, %3, %4, %5 : index, index, index, index, index
-}
-// CHECK-LABEL: func @result_shape_and_per_dim(
-// CHECK-SAME: %[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32>
-// CHECK-SAME: %[[ARG_1:[a-z0-9]*]]: tensor<?x5xf32>)
-// CHECK-DAG: %[[C0:.+]] = constant 0 : index
-// CHECK-DAG: %[[C2:.+]] = constant 2 : index
-// CHECK-DAG: %[[C3:.+]] = constant 3 : index
-// CHECK-DAG: %[[C5:.+]] = constant 5 : index
-// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG_1]], %[[C0]]
-// CHECK-DAG: %[[S0:.+]] = tensor.from_elements %[[D0]], %[[C5]]
-// CHECK-DAG: %[[D0_OUT:.+]] = tensor.extract %[[S0]][%[[C0]]]
-// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG_0]], %[[C2]]
-// CHECK-DAG: %[[S1:.+]] = tensor.from_elements %[[C2]], %[[C3]], %[[D1]]
-// CHECK-DAG: %[[D1_OUT:.+]] = tensor.extract %[[S1]][%[[C2]]]
-// CHECK: return %[[D0_OUT]], %[[C5]], %[[C2]], %[[C3]], %[[D1_OUT]]
return success();
}
-LogicalResult
-OpWithResultShapePerDimInterfaceOp ::reifyReturnTypeShapesPerResultDim(
- OpBuilder &builder,
- llvm::SmallVectorImpl<llvm::SmallVector<Value>> &shapes) {
- Location loc = getLoc();
- shapes.reserve(getNumOperands());
- for (Value operand : llvm::reverse(getOperands())) {
- auto currShape = llvm::to_vector<4>(llvm::map_range(
- llvm::seq<int64_t>(
- 0, operand.getType().cast<RankedTensorType>().getRank()),
- [&](int64_t dim) -> Value {
- return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
- }));
- shapes.emplace_back(std::move(currShape));
- }
- return success();
-}
-
-LogicalResult OpWithResultShapeAndPerDimInterfaceOp::reifyReturnTypeShapes(
- OpBuilder &builder, ValueRange operands,
- llvm::SmallVectorImpl<Value> &shapes) {
- Location loc = getLoc();
- shapes.reserve(operands.size());
- for (Value operand : llvm::reverse(operands)) {
- auto currShape = llvm::to_vector<4>(llvm::map_range(
- llvm::seq<int64_t>(
- 0, operand.getType().cast<RankedTensorType>().getRank()),
- [&](int64_t dim) -> Value {
- return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
- }));
- shapes.push_back(builder.create<tensor::FromElementsOp>(
- getLoc(), builder.getIndexType(), currShape));
- }
- return success();
-}
-
-LogicalResult
-OpWithResultShapeAndPerDimInterfaceOp ::reifyReturnTypeShapesPerResultDim(
- OpBuilder &builder,
- llvm::SmallVectorImpl<llvm::SmallVector<Value>> &shapes) {
+LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
+ OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
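+  // Report result shapes as the shapes of the operands, in reverse order
+  // (result i gets the shape of operand getNumOperands() - 1 - i).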
Location loc = getLoc();
shapes.reserve(getNumOperands());
for (Value operand : llvm::reverse(getOperands())) {
def OpWithResultShapePerDimInterfaceOp :
TEST_Op<"op_with_result_shape_per_dim_interface",
- [DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["reifyReturnTypeShapesPerResultDim"]>]> {
- let arguments = (ins AnyRankedTensor:$operand1, AnyRankedTensor:$operand2);
- let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2);
-}
-
-def OpWithResultShapeAndPerDimInterfaceOp :
- TEST_Op<"op_with_result_shape_and_per_dim_interface",
- [DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["reifyReturnTypeShapes", "reifyReturnTypeShapesPerResultDim"]>]> {
+ [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
let arguments = (ins AnyRankedTensor:$operand1, AnyRankedTensor:$operand2);
let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2);
}
":Affine",
":DialectUtils",
":IR",
+ ":InferTypeOpInterface",
":LinalgInterfacesIncGen",
":LinalgStructuredOpsIncGen",
":MemRefDialect",