From 6052b17aabec2db8ad255eca5632cb128363c604 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm@google.com>
Date: Tue, 22 Nov 2022 17:26:19 +0100
Subject: [PATCH] [mlir][tensor] Add dim(expand_shape/collapse_shape) folding

Differential Revision: https://reviews.llvm.org/D138487
---
 mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td |  5 +-
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp         |  4 +-
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp         | 98 +++++++++++++++++++++++-
 mlir/test/Dialect/Tensor/canonicalize.mlir       | 38 +++++++++
 4 files changed, 137 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 7af19a7..1406007 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1051,7 +1051,10 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
     }]>
   ];
 
-  let extraClassDeclaration = commonExtraClassDeclaration;
+  let extraClassDeclaration = commonExtraClassDeclaration # [{
+    int64_t getCorrespondingSourceDim(int64_t resultDim);
+  }];
+
   let hasVerifier = 1;
 }
 
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index bf54d46..e53879b 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -908,9 +908,7 @@ void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
 }
 
 Optional<int64_t> DimOp::getConstantIndex() {
-  if (auto constantOp = getIndex().getDefiningOp<arith::ConstantOp>())
-    return constantOp.getValue().cast<IntegerAttr>().getInt();
-  return {};
+  return getConstantIntValue(getIndex());
 }
 
 Speculation::Speculatability DimOp::getSpeculatability() {
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index c5d7e42..826c69e 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
@@ -379,9 +380,7 @@ void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
 }
 
 Optional<int64_t> DimOp::getConstantIndex() {
-  if (auto constantOp = getIndex().getDefiningOp<arith::ConstantOp>())
-    return constantOp.getValue().cast<IntegerAttr>().getInt();
-  return {};
+  return getConstantIntValue(getIndex());
 }
 
 Speculation::Speculatability DimOp::getSpeculatability() {
@@ -1302,6 +1301,15 @@ void ExpandShapeOp::getAsmResultNames(
   setNameFn(getResult(), "expanded");
 }
 
+int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
+  assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
+         "invalid resultDim");
+  for (const auto &it : llvm::enumerate(getReassociationIndices()))
+    if (llvm::find(it.value(), resultDim) != it.value().end())
+      return it.index();
+  llvm_unreachable("could not find reassociation group");
+}
+
 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
   return getSymbolLessAffineMaps(getReassociationExprs());
 }
@@ -1470,6 +1478,87 @@ struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
   }
 };
 
+struct FoldDimOfExpandShape : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dimOp,
+                                PatternRewriter &rewriter) const override {
+    auto expandShapeOp = dimOp.getSource().getDefiningOp<ExpandShapeOp>();
+    if (!expandShapeOp)
+      return failure();
+
+    // Only constant dimension values are supported.
+    Optional<int64_t> dim = dimOp.getConstantIndex();
+    if (!dim.has_value())
+      return failure();
+
+    // Skip static dims. These are folded to constant ops.
+    TensorType resultType = expandShapeOp.getResultType();
+    if (!resultType.isDynamicDim(*dim))
+      return failure();
+
+    // Find reassociation group that contains this result dimension.
+    int64_t srcDim = expandShapeOp.getCorrespondingSourceDim(*dim);
+
+    // `dim` is the only dynamic dimension in `group`. (Otherwise, the
+    // ExpandShapeOp would be ambiguous.)
+    int64_t product = 1;
+    ReassociationIndices grp = expandShapeOp.getReassociationIndices()[srcDim];
+    for (int64_t d : grp) {
+      if (d != dim) {
+        assert(!resultType.isDynamicDim(d) && "expected static dim");
+        product *= resultType.getDimSize(d);
+      }
+    }
+
+    // result dim size = src dim size / (product(other dims in reassoc group))
+    Value srcDimSz =
+        rewriter.create<DimOp>(dimOp.getLoc(), expandShapeOp.getSrc(), srcDim);
+    AffineExpr expr;
+    bindSymbols(dimOp.getContext(), expr);
+    rewriter.replaceOpWithNewOp<AffineApplyOp>(dimOp, expr.floorDiv(product),
+                                               srcDimSz);
+    return success();
+  }
+};
+
+struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dimOp,
+                                PatternRewriter &rewriter) const override {
+    auto collapseShapeOp = dimOp.getSource().getDefiningOp<CollapseShapeOp>();
+    if (!collapseShapeOp)
+      return failure();
+
+    // Only constant dimension values are supported.
+    Optional<int64_t> dim = dimOp.getConstantIndex();
+    if (!dim.has_value())
+      return failure();
+
+    // Skip static dims. These are folded to constant ops.
+    TensorType resultType = collapseShapeOp.getResultType();
+    if (!resultType.isDynamicDim(*dim))
+      return failure();
+
+    // Get reassociation group of the result dimension.
+    ReassociationIndices group =
+        collapseShapeOp.getReassociationIndices()[*dim];
+
+    // result dim size = product(dims in reassoc group)
+    SmallVector<Value> srcDimSizes;
+    SmallVector<AffineExpr> syms;
+    AffineExpr product;
+    for (const auto &it : llvm::enumerate(group)) {
+      srcDimSizes.push_back(rewriter.create<DimOp>(
+          dimOp.getLoc(), collapseShapeOp.getSrc(), it.value()));
+      syms.push_back(rewriter.getAffineSymbolExpr(it.index()));
+      product = product ? product * syms.back() : syms.back();
+    }
+    rewriter.replaceOpWithNewOp<AffineApplyOp>(dimOp, product, srcDimSizes);
+    return success();
+  }
+};
 } // namespace
 
 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -1477,7 +1566,8 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
   results.add<ComposeReassociativeReshapeOps<ExpandShapeOp>,
               ComposeExpandOfCollapseOp, FoldReshapeWithConstant,
-              FoldReshapeWithFromElements>(context);
+              FoldReshapeWithFromElements, FoldDimOfExpandShape,
+              FoldDimOfCollapseShape>(context);
 }
 
 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 99e31c7..c9e662f 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1628,3 +1628,41 @@ func.func @rank_reducing_empty_tensor_extract(%sz : index, %idx : index) -> tensor<2xf32> {
   %r = tensor.extract_slice %a[%idx, 0] [1, 2] [1, 1] : tensor<?x2xf32> to tensor<2xf32>
   return %r: tensor<2xf32>
 }
+
+// -----
+
+// CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 floordiv 40)>
+// CHECK-LABEL: func @dim_of_expand_shape(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>
+//       CHECK:   %[[c1:.*]] = arith.constant 1 : index
+//       CHECK:   %[[dim:.*]] = tensor.dim %[[t]], %[[c1]] : tensor<?x?xf32>
+//       CHECK:   %[[apply:.*]] = affine.apply #[[$map]]()[%[[dim]]]
+//       CHECK:   return %[[apply]]
+func.func @dim_of_expand_shape(%t: tensor<?x?xf32>) -> index {
+  %c2 = arith.constant 2 : index
+  %0 = tensor.expand_shape %t [[0], [1, 2, 3, 4, 5]]
+      : tensor<?x?xf32> into tensor<?x1x?x2x4x5xf32>
+  %1 = tensor.dim %0, %c2 : tensor<?x1x?x2x4x5xf32>
+  return %1 : index
+}
+
+// -----
+
+// CHECK: #[[$map:.*]] = affine_map<()[s0, s1, s2] -> (((s0 * s1) * s2) * 7)>
+// CHECK-LABEL: func @dim_of_collapse_shape(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x?x?x7x?xf32>
+//   CHECK-DAG:   %[[c1:.*]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[c2:.*]] = arith.constant 2 : index
+//   CHECK-DAG:   %[[c4:.*]] = arith.constant 4 : index
+//   CHECK-DAG:   %[[dim1:.*]] = tensor.dim %[[t]], %[[c1]]
+//   CHECK-DAG:   %[[dim2:.*]] = tensor.dim %[[t]], %[[c2]]
+//   CHECK-DAG:   %[[dim4:.*]] = tensor.dim %[[t]], %[[c4]]
+//       CHECK:   %[[apply:.*]] = affine.apply #[[$map]]()[%[[dim1]], %[[dim2]], %[[dim4]]]
+//       CHECK:   return %[[apply]]
+func.func @dim_of_collapse_shape(%t: tensor<?x?x?x7x?xf32>) -> index {
+  %c1 = arith.constant 1 : index
+  %0 = tensor.collapse_shape %t [[0], [1, 2, 3, 4]] : tensor<?x?x?x7x?xf32> into tensor<?x?xf32>
+  %1 = tensor.dim %0, %c1 : tensor<?x?xf32>
+  return %1 : index
+}
-- 
2.7.4