[mlir][tensor] Add dim(expand_shape/collapse_shape) folding
author     Matthias Springer <springerm@google.com>
           Tue, 22 Nov 2022 16:26:19 +0000 (17:26 +0100)
committer  Matthias Springer <springerm@google.com>
           Tue, 22 Nov 2022 16:34:49 +0000 (17:34 +0100)
Differential Revision: https://reviews.llvm.org/D138487
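
In summary: this patch adds canonicalization patterns that fold `tensor.dim` of
a `tensor.expand_shape`/`tensor.collapse_shape` result into arithmetic on the
dims of the reshape source, adds a small `getCorrespondingSourceDim` helper to
ExpandShapeOp, and cleans up `DimOp::getConstantIndex` (NFC). A minimal
before/after sketch of the expand_shape case, mirroring the test added below
(SSA names are illustrative):

    // Before: querying a dynamic dim of the expanded tensor.
    %c2 = arith.constant 2 : index
    %e = tensor.expand_shape %t [[0], [1, 2, 3, 4, 5]]
        : tensor<?x?xf32> into tensor<?x1x?x5x1x8xf32>
    %d = tensor.dim %e, %c2 : tensor<?x1x?x5x1x8xf32>

    // After: recover the dynamic result dim from the source dim by dividing
    // out the static sizes of its reassociation group (1 * 5 * 1 * 8 = 40).
    %c1 = arith.constant 1 : index
    %srcSz = tensor.dim %t, %c1 : tensor<?x?xf32>
    %sz = affine.apply affine_map<()[s0] -> (s0 floordiv 40)>()[%srcSz]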

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/canonicalize.mlir

diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 7af19a7..1406007 100644
@@ -1051,7 +1051,10 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
     }]>
   ];
 
-  let extraClassDeclaration = commonExtraClassDeclaration;
+  let extraClassDeclaration = commonExtraClassDeclaration # [{
+    int64_t getCorrespondingSourceDim(int64_t resultDim);
+  }];
+
   let hasVerifier = 1;
 }
 
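The helper declared above maps a result dimension of the expanded type back to
the source dimension whose reassociation group contains it. A hypothetical
example of the mapping (not part of this patch):

    %e = tensor.expand_shape %t [[0], [1, 2]]
        : tensor<10x?xf32> into tensor<10x?x4xf32>
    // getCorrespondingSourceDim(0) == 0  (result dim 0 is in group [0])
    // getCorrespondingSourceDim(1) == 1  (result dim 1 is in group [1, 2])
    // getCorrespondingSourceDim(2) == 1  (result dim 2 is in group [1, 2])
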
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index bf54d46..e53879b 100644
@@ -908,9 +908,7 @@ void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
 }
 
 Optional<int64_t> DimOp::getConstantIndex() {
-  if (auto constantOp = getIndex().getDefiningOp<arith::ConstantOp>())
-    return constantOp.getValue().cast<IntegerAttr>().getInt();
-  return {};
+  return getConstantIntValue(getIndex());
 }
 
 Speculation::Speculatability DimOp::getSpeculatability() {
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index c5d7e42..826c69e 100644
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
@@ -379,9 +380,7 @@ void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
 }
 
 Optional<int64_t> DimOp::getConstantIndex() {
-  if (auto constantOp = getIndex().getDefiningOp<arith::ConstantOp>())
-    return constantOp.getValue().cast<IntegerAttr>().getInt();
-  return {};
+  return getConstantIntValue(getIndex());
 }
 
 Speculation::Speculatability DimOp::getSpeculatability() {
@@ -1302,6 +1301,15 @@ void ExpandShapeOp::getAsmResultNames(
   setNameFn(getResult(), "expanded");
 }
 
+int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
+  assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
+         "invalid resultDim");
+  for (const auto &it : llvm::enumerate(getReassociationIndices()))
+    if (llvm::find(it.value(), resultDim) != it.value().end())
+      return it.index();
+  llvm_unreachable("could not find reassociation group");
+}
+
 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
   return getSymbolLessAffineMaps(getReassociationExprs());
 }
@@ -1470,6 +1478,87 @@ struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
   }
 };
 
+struct FoldDimOfExpandShape : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dimOp,
+                                PatternRewriter &rewriter) const override {
+    auto expandShapeOp = dimOp.getSource().getDefiningOp<ExpandShapeOp>();
+    if (!expandShapeOp)
+      return failure();
+
+    // Only constant dimension values are supported.
+    Optional<int64_t> dim = dimOp.getConstantIndex();
+    if (!dim.has_value())
+      return failure();
+
+    // Skip static dims. These are folded to constant ops.
+    TensorType resultType = expandShapeOp.getResultType();
+    if (!resultType.isDynamicDim(*dim))
+      return failure();
+
+    // Find reassociation group that contains this result dimension.
+    int64_t srcDim = expandShapeOp.getCorrespondingSourceDim(*dim);
+
+    // `dim` is the only dynamic dimension in `grp`. (Otherwise, the
+    // ExpandShapeOp would be ambiguous.)
+    int64_t product = 1;
+    ReassociationIndices grp = expandShapeOp.getReassociationIndices()[srcDim];
+    for (int64_t d : grp) {
+      if (d != dim) {
+        assert(!resultType.isDynamicDim(d) && "expected static dim");
+        product *= resultType.getDimSize(d);
+      }
+    }
+
+    // result dim size = src dim size / (product(other dims in reassoc group))
+    Value srcDimSz =
+        rewriter.create<DimOp>(dimOp.getLoc(), expandShapeOp.getSrc(), srcDim);
+    AffineExpr expr;
+    bindSymbols(dimOp.getContext(), expr);
+    rewriter.replaceOpWithNewOp<AffineApplyOp>(dimOp, expr.floorDiv(product),
+                                               srcDimSz);
+    return success();
+  }
+};
+
+struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dimOp,
+                                PatternRewriter &rewriter) const override {
+    auto collapseShapeOp = dimOp.getSource().getDefiningOp<CollapseShapeOp>();
+    if (!collapseShapeOp)
+      return failure();
+
+    // Only constant dimension values are supported.
+    Optional<int64_t> dim = dimOp.getConstantIndex();
+    if (!dim.has_value())
+      return failure();
+
+    // Skip static dims. These are folded to constant ops.
+    TensorType resultType = collapseShapeOp.getResultType();
+    if (!resultType.isDynamicDim(*dim))
+      return failure();
+
+    // Get reassociation group of the result dimension.
+    ReassociationIndices group =
+        collapseShapeOp.getReassociationIndices()[*dim];
+
+    // result dim size = product(dims in reassoc group)
+    SmallVector<Value> srcDimSizes;
+    SmallVector<AffineExpr> syms;
+    AffineExpr product;
+    for (const auto &it : llvm::enumerate(group)) {
+      srcDimSizes.push_back(rewriter.create<DimOp>(
+          dimOp.getLoc(), collapseShapeOp.getSrc(), it.value()));
+      syms.push_back(rewriter.getAffineSymbolExpr(it.index()));
+      product = product ? product * syms.back() : syms.back();
+    }
+    rewriter.replaceOpWithNewOp<AffineApplyOp>(dimOp, product, srcDimSizes);
+    return success();
+  }
+};
 } // namespace
 
 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
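
FoldDimOfExpandShape above asserts that all other dims in the reassociation
group are static. This is backed by a tensor.expand_shape invariant: a group
may contain at most one dynamic result dimension, because a second one would
make the expansion ambiguous. A hypothetical snippet the verifier rejects for
exactly that reason:

    // Invalid: two dynamic sizes in one reassociation group; there is no
    // way to split the source size between them.
    %e = tensor.expand_shape %t [[0, 1]] : tensor<?xf32> into tensor<?x?xf32>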
@@ -1477,7 +1566,8 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ComposeReassociativeReshapeOps<ExpandShapeOp>,
               ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
               FoldReshapeWithConstant<ExpandShapeOp>,
-              FoldReshapeWithFromElements<ExpandShapeOp>>(context);
+              FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
+              FoldDimOfCollapseShape>(context);
 }
 
 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
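
Note that FoldDimOfCollapseShape emits a `tensor.dim` for every source dim in
the reassociation group, including static ones; those subsequently fold to
constants that get composed into the affine map, which is why the static size
7 appears as a multiplier in the map of the test below. A sketch of the
immediate rewrite for group [1, 2, 3, 4] of `tensor<?x?x?x7x?xf32>`, before
that folding (names are illustrative):

    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    %c3 = arith.constant 3 : index
    %c4 = arith.constant 4 : index
    %d1 = tensor.dim %t, %c1 : tensor<?x?x?x7x?xf32>
    %d2 = tensor.dim %t, %c2 : tensor<?x?x?x7x?xf32>
    %d3 = tensor.dim %t, %c3 : tensor<?x?x?x7x?xf32>  // static size 7
    %d4 = tensor.dim %t, %c4 : tensor<?x?x?x7x?xf32>
    %sz = affine.apply affine_map<()[s0, s1, s2, s3] -> (s0 * s1 * s2 * s3)>
            ()[%d1, %d2, %d3, %d4]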
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 99e31c7..c9e662f 100644
@@ -1628,3 +1628,41 @@ func.func @rank_reducing_empty_tensor_extract(%sz : index, %idx : index) -> tens
   %r = tensor.extract_slice %a[%idx, 0] [1, 2] [1, 1] : tensor<?x2xf32> to tensor<2xf32>
   return %r: tensor<2xf32>
 }
+
+// -----
+
+//       CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 floordiv 40)>
+// CHECK-LABEL: func @dim_of_expand_shape(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>
+//       CHECK:   %[[c1:.*]] = arith.constant 1 : index
+//       CHECK:   %[[dim:.*]] = tensor.dim %[[t]], %[[c1]] : tensor<?x?xf32>
+//       CHECK:   %[[apply:.*]] = affine.apply #[[$map]]()[%[[dim]]]
+//       CHECK:   return %[[apply]]
+func.func @dim_of_expand_shape(%t: tensor<?x?xf32>) -> index {
+  %c2 = arith.constant 2 : index
+  %0 = tensor.expand_shape %t [[0], [1, 2, 3, 4, 5]]
+      : tensor<?x?xf32> into tensor<?x1x?x5x1x8xf32>
+  %1 = tensor.dim %0, %c2 : tensor<?x1x?x5x1x8xf32>
+  return %1 : index
+}
+
+// -----
+
+//       CHECK: #[[$map:.*]] = affine_map<()[s0, s1, s2] -> (((s0 * s1) * s2) * 7)>
+// CHECK-LABEL: func @dim_of_collapse_shape(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x?x?x7x?xf32>
+//   CHECK-DAG:   %[[c1:.*]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[c2:.*]] = arith.constant 2 : index
+//   CHECK-DAG:   %[[c4:.*]] = arith.constant 4 : index
+//   CHECK-DAG:   %[[dim1:.*]] = tensor.dim %[[t]], %[[c1]]
+//   CHECK-DAG:   %[[dim2:.*]] = tensor.dim %[[t]], %[[c2]]
+//   CHECK-DAG:   %[[dim4:.*]] = tensor.dim %[[t]], %[[c4]]
+//       CHECK:   %[[apply:.*]] = affine.apply #[[$map]]()[%[[dim1]], %[[dim2]], %[[dim4]]]
+//       CHECK:   return %[[apply]]
+func.func @dim_of_collapse_shape(%t: tensor<?x?x?x7x?xf32>) -> index {
+  %c1 = arith.constant 1 : index
+  %0 = tensor.collapse_shape %t [[0], [1, 2, 3, 4]]
+      : tensor<?x?x?x7x?xf32> into tensor<?x?xf32>
+  %1 = tensor.dim %0, %c1 : tensor<?x?xf32>
+  return %1 : index
+}
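
Both patterns require a constant `tensor.dim` index (via `getConstantIndex`).
A hypothetical case, not part of this patch's tests, that is left untouched:

    // Not folded: the dim index %i is not a constant.
    func.func @dim_of_expand_shape_dynamic_index(%t: tensor<?x?xf32>,
                                                 %i: index) -> index {
      %0 = tensor.expand_shape %t [[0], [1, 2]]
          : tensor<?x?xf32> into tensor<?x?x4xf32>
      %1 = tensor.dim %0, %i : tensor<?x?x4xf32>
      return %1 : index
    }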