//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
}
/// Returns the index operand as a constant integer, or None when the index
/// is not a compile-time constant.
Optional<int64_t> DimOp::getConstantIndex() {
  // Delegate to the shared arith utility, which recognizes constant-producing
  // ops (e.g. arith.constant) behind the index value.
  return getConstantIntValue(getIndex());
}
Speculation::Speculatability DimOp::getSpeculatability() {
setNameFn(getResult(), "expanded");
}
+int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
+ assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
+ "invalid resultDim");
+ for (const auto &it : llvm::enumerate(getReassociationIndices()))
+ if (llvm::find(it.value(), resultDim) != it.value().end())
+ return it.index();
+ llvm_unreachable("could not find reassociation group");
+}
+
/// Returns the reassociation of this collapse_shape expressed as a list of
/// symbol-less affine maps, one per reassociation group.
SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
  return getSymbolLessAffineMaps(getReassociationExprs());
}
}
};
+struct FoldDimOfExpandShape : public OpRewritePattern<DimOp> {
+ using OpRewritePattern<DimOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(DimOp dimOp,
+ PatternRewriter &rewriter) const override {
+ auto expandShapeOp = dimOp.getSource().getDefiningOp<ExpandShapeOp>();
+ if (!expandShapeOp)
+ return failure();
+
+ // Only constant dimension values are supported.
+ Optional<int64_t> dim = dimOp.getConstantIndex();
+ if (!dim.has_value())
+ return failure();
+
+ // Skip static dims. These are folded to constant ops.
+ TensorType resultType = expandShapeOp.getResultType();
+ if (!resultType.isDynamicDim(*dim))
+ return failure();
+
+ // Find reassociation group that contains this result dimension.
+ int64_t srcDim = expandShapeOp.getCorrespondingSourceDim(*dim);
+
+ // `dim` is the only dynamic dimension in `group`. (Otherwise, the
+ // ExpandShapeOp would be ambiguous.)
+ int64_t product = 1;
+ ReassociationIndices grp = expandShapeOp.getReassociationIndices()[srcDim];
+ for (int64_t d : grp) {
+ if (d != dim) {
+ assert(!resultType.isDynamicDim(d) && "expected static dim");
+ product *= resultType.getDimSize(d);
+ }
+ }
+
+ // result dim size = src dim size / (product(other dims in reassoc group))
+ Value srcDimSz =
+ rewriter.create<DimOp>(dimOp.getLoc(), expandShapeOp.getSrc(), srcDim);
+ AffineExpr expr;
+ bindSymbols(dimOp.getContext(), expr);
+ rewriter.replaceOpWithNewOp<AffineApplyOp>(dimOp, expr.floorDiv(product),
+ srcDimSz);
+ return success();
+ }
+};
+
+struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
+ using OpRewritePattern<DimOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(DimOp dimOp,
+ PatternRewriter &rewriter) const override {
+ auto collapseShapeOp = dimOp.getSource().getDefiningOp<CollapseShapeOp>();
+ if (!collapseShapeOp)
+ return failure();
+
+ // Only constant dimension values are supported.
+ Optional<int64_t> dim = dimOp.getConstantIndex();
+ if (!dim.has_value())
+ return failure();
+
+ // Skip static dims. These are folded to constant ops.
+ TensorType resultType = collapseShapeOp.getResultType();
+ if (!resultType.isDynamicDim(*dim))
+ return failure();
+
+ // Get reassociation group of the result dimension.
+ ReassociationIndices group =
+ collapseShapeOp.getReassociationIndices()[*dim];
+
+ // result dim size = product(dims in reassoc group)
+ SmallVector<Value> srcDimSizes;
+ SmallVector<AffineExpr> syms;
+ AffineExpr product;
+ for (const auto &it : llvm::enumerate(group)) {
+ srcDimSizes.push_back(rewriter.create<DimOp>(
+ dimOp.getLoc(), collapseShapeOp.getSrc(), it.value()));
+ syms.push_back(rewriter.getAffineSymbolExpr(it.index()));
+ product = product ? product * syms.back() : syms.back();
+ }
+ rewriter.replaceOpWithNewOp<AffineApplyOp>(dimOp, product, srcDimSizes);
+ return success();
+ }
+};
} // namespace
void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<ComposeReassociativeReshapeOps<ExpandShapeOp>,
ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
FoldReshapeWithConstant<ExpandShapeOp>,
- FoldReshapeWithFromElements<ExpandShapeOp>>(context);
+ FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
+ FoldDimOfCollapseShape>(context);
}
void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
%r = tensor.extract_slice %a[%idx, 0] [1, 2] [1, 1] : tensor<?x2xf32> to tensor<2xf32>
return %r: tensor<2xf32>
}
+
+// -----
+
// Verifies the dim-of-expand_shape fold: the dynamic result dim 2 belongs to
// the reassociation group [1, 2, 3, 4, 5]; the remaining static dims in that
// group (1, 5, 1, 8) multiply to 40, so tensor.dim folds to an affine.apply
// of `s0 floordiv 40` on the source's dynamic dim 1.
+// CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 floordiv 40)>
+// CHECK-LABEL: func @dim_of_expand_shape(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>
+//       CHECK:   %[[c1:.*]] = arith.constant 1 : index
+//       CHECK:   %[[dim:.*]] = tensor.dim %[[t]], %[[c1]] : tensor<?x?xf32>
+//       CHECK:   %[[apply:.*]] = affine.apply #[[$map]]()[%[[dim]]]
+//       CHECK:   return %[[apply]]
+func.func @dim_of_expand_shape(%t: tensor<?x?xf32>) -> index {
+  %c2 = arith.constant 2 : index
+  %0 = tensor.expand_shape %t [[0], [1, 2, 3, 4, 5]]
+      : tensor<?x?xf32> into tensor<?x1x?x5x1x8xf32>
+  %1 = tensor.dim %0, %c2 : tensor<?x1x?x5x1x8xf32>
+  return %1 : index
+}
+
+// -----
+
// Verifies the dim-of-collapse_shape fold: result dim 1 collapses source dims
// [1, 2, 3, 4]; the dynamic source dims 1, 2, 4 become tensor.dim operands
// (symbols s0, s1, s2) and the static dim 3 (= 7) folds into the map as a
// constant multiplier.
+// CHECK: #[[$map:.*]] = affine_map<()[s0, s1, s2] -> (((s0 * s1) * s2) * 7)>
+// CHECK-LABEL: func @dim_of_collapse_shape(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x?x?x7x?xf32>
+//   CHECK-DAG:   %[[c1:.*]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[c2:.*]] = arith.constant 2 : index
+//   CHECK-DAG:   %[[c4:.*]] = arith.constant 4 : index
+//   CHECK-DAG:   %[[dim1:.*]] = tensor.dim %[[t]], %[[c1]]
+//   CHECK-DAG:   %[[dim2:.*]] = tensor.dim %[[t]], %[[c2]]
+//   CHECK-DAG:   %[[dim4:.*]] = tensor.dim %[[t]], %[[c4]]
+//       CHECK:   %[[apply:.*]] = affine.apply #[[$map]]()[%[[dim1]], %[[dim2]], %[[dim4]]]
+//       CHECK:   return %[[apply]]
+func.func @dim_of_collapse_shape(%t: tensor<?x?x?x7x?xf32>) -> index {
+  %c1 = arith.constant 1 : index
+  %0 = tensor.collapse_shape %t [[0], [1, 2, 3, 4]]
+      : tensor<?x?x?x7x?xf32> into tensor<?x?xf32>
+  %1 = tensor.dim %0, %c1 : tensor<?x?xf32>
+  return %1 : index
+}