using namespace mlir::linalg;
namespace {
-enum class RankReductionStrategy { ReassociativeReshape, ExtractInsertSlice };
-} // namespace
-
-/// Implements a pass that canonicalizes the uses of unit-extent dimensions for
-/// broadcasting. For example,
-///
-/// ```mlir
-/// #accesses = [
-/// affine_map<(d0, d1) -> (0, d1)>,
-/// affine_map<(d0, d1) -> (d0, 0)>,
-/// affine_map<(d0, d1) -> (d0, d1)>
-/// ]
-///
-/// #trait = {
-/// args_in = 2,
-/// args_out = 1,
-/// indexing_maps = #accesses,
-/// iterator_types = ["parallel", "parallel"],
-/// library_call = "some_external_fn"
-/// }
-///
-/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
-/// tensor<5x5xf32>
-/// {
-/// %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
-/// tensor<5xf32> into tensor<1x5xf32>
-/// %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
-/// tensor<5xf32> into tensor<5x1xf32>
-/// %2 = linalg.generic #trait %0, %1 {
-/// ^bb0(%arg2: f32, %arg3: f32):
-/// %3 = arith.addf %arg2, %arg3 : f32
-/// linalg.yield %3 : f32
-/// } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
-/// return %2 : tensor<5x5xf32>
-/// }
-///
-/// would canonicalize to
-///
-/// ```mlir
-/// #accesses = [
-/// affine_map<(d0, d1) -> (d1)>,
-/// affine_map<(d0, d1) -> (d0)>,
-/// affine_map<(d0, d1) -> (d0, d1)>
-/// ]
-///
-/// #trait = {
-/// args_in = 2,
-/// args_out = 1,
-/// indexing_maps = #accesses,
-/// iterator_types = ["parallel", "parallel"],
-/// library_call = "some_external_fn"
-/// }
-///
-/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
-/// tensor<5x5xf32>
-/// {
-/// %0 = linalg.generic #trait %arg0, %arg1 {
-/// ^bb0(%arg2: f32, %arg3: f32):
-/// %3 = arith.addf %arg2, %arg3 : f32
-/// linalg.yield %3 : f32
-/// } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
-/// return %0 : tensor<5x5xf32>
-/// }
-
-/// Given dims of the iteration space of a structured op that are known to be
-/// single trip count (`unitDims`), return the indexing maps to use in the
-/// canonicalized op with these dims removed, given the original `indexingMaps`.
-static ArrayAttr replaceUnitDims(DenseSet<unsigned> &unitDims,
- ArrayRef<AffineMap> indexingMaps,
- MLIRContext *context) {
- if (indexingMaps.empty())
- return nullptr;
- unsigned numIterationDims = indexingMaps.front().getNumDims();
- unsigned numSymbols = indexingMaps.front().getNumSymbols();
-
- // Compute the replacement for each dim expr.
- SmallVector<AffineExpr, 4> dimReplacements;
- dimReplacements.reserve(numIterationDims);
- unsigned numKeptDims = 0;
- for (unsigned dim : llvm::seq<unsigned>(0, numIterationDims)) {
- if (unitDims.count(dim))
- dimReplacements.push_back(getAffineConstantExpr(0, context));
- else
- dimReplacements.push_back(getAffineDimExpr(numKeptDims++, context));
- }
-
- // Symbols remain the same.
- SmallVector<AffineExpr, 4> symReplacements;
- symReplacements.reserve(numSymbols);
- for (unsigned symbol : llvm::seq<unsigned>(0, numSymbols))
- symReplacements.push_back(getAffineSymbolExpr(symbol, context));
-
- SmallVector<AffineMap, 4> newIndexingMaps;
- newIndexingMaps.reserve(indexingMaps.size());
- for (AffineMap operandMap : indexingMaps) {
- // Expected indexing maps to have no symbols.
- if (operandMap.getNumSymbols())
- return nullptr;
- newIndexingMaps.push_back(simplifyAffineMap(
- operandMap.replaceDimsAndSymbols(dimReplacements, symReplacements,
- numIterationDims - unitDims.size(),
- numSymbols)));
- }
-
- // Check that the new index maps are invertible. If not, something went
- // wrong, so abort.
- if (!inversePermutation(concatAffineMaps(newIndexingMaps)))
- return nullptr;
- return ArrayAttr::get(context,
- llvm::to_vector<4>(llvm::map_range(
- newIndexingMaps, [](AffineMap map) -> Attribute {
- return AffineMapAttr::get(map);
- })));
-}
-
-/// Update the index accesses of linalg operations having index semantics.
-static void replaceUnitDimIndexOps(GenericOp genericOp,
- const DenseSet<unsigned> &unitDims,
- PatternRewriter &rewriter) {
- for (IndexOp indexOp :
- llvm::make_early_inc_range(genericOp.getBody()->getOps<IndexOp>())) {
- OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPoint(indexOp);
- if (unitDims.count(indexOp.getDim()) != 0) {
- rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(indexOp, 0);
- } else {
- // Update the dimension of the index operation if needed.
- unsigned droppedDims = llvm::count_if(
- unitDims, [&](unsigned dim) { return dim < indexOp.getDim(); });
- if (droppedDims != 0)
- rewriter.replaceOpWithNewOp<IndexOp>(indexOp,
- indexOp.getDim() - droppedDims);
- }
- }
-}
-
-namespace {
-/// Pattern to fold unit-trip count loops in GenericOps.
-struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
- using OpRewritePattern<GenericOp>::OpRewritePattern;
- LogicalResult matchAndRewrite(GenericOp genericOp,
- PatternRewriter &rewriter) const override {
- SmallVector<AffineMap, 4> indexingMaps = genericOp.getIndexingMapsArray();
- if (indexingMaps.empty())
- return failure();
-
- // Check if any of the iteration dimensions are unit-trip count. They will
- // end up being unit-trip count if they are used to index into a unit-dim
- // tensor/memref.
- AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
- if (!invertedMap)
- return failure();
- SmallVector<int64_t> dims = genericOp.getStaticShape();
-
- DenseSet<unsigned> unitDims;
- SmallVector<unsigned, 4> unitDimsReductionLoops;
- ArrayAttr iteratorTypes = genericOp.getIteratorTypes();
- for (const auto &expr : enumerate(invertedMap.getResults())) {
- if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>())
- if (dims[dimExpr.getPosition()] == 1)
- unitDims.insert(expr.index());
- }
-
- if (unitDims.empty())
- return failure();
-
- // Compute the modified indexing maps.
- MLIRContext *context = rewriter.getContext();
- ArrayAttr newIndexingMapAttr =
- replaceUnitDims(unitDims, indexingMaps, context);
- if (!newIndexingMapAttr)
- return genericOp.emitError("unable to compute modified indexing_maps");
-
- // Compute the iterator types of the modified op by dropping the one-trip
- // count loops.
- SmallVector<Attribute, 4> newIteratorTypes;
- for (const auto &attr : llvm::enumerate(iteratorTypes)) {
- if (!unitDims.count(attr.index()))
- newIteratorTypes.push_back(attr.value());
- }
-
- rewriter.startRootUpdate(genericOp);
- genericOp.setIndexingMapsAttr(newIndexingMapAttr);
- genericOp.setIteratorTypesAttr(ArrayAttr::get(context, newIteratorTypes));
- replaceUnitDimIndexOps(genericOp, unitDims, rewriter);
- rewriter.finalizeRootUpdate(genericOp);
- return success();
- }
-};
-
/// Pattern to move init operands to ins when all the loops are parallel and
/// blockArgument corresponding to init is used in the region. This is a fix-up
/// when unit reduction dimensions are all folded away. In this context, it
return success();
}
};
+} // namespace
+
+//===---------------------------------------------------------------------===//
+// Drop loops that are unit-extent within Linalg operations.
+//===---------------------------------------------------------------------===//
+
+/// Implements a pass that canonicalizes the uses of unit-extent dimensions for
+/// broadcasting. For example,
+///
+/// ```mlir
+/// #accesses = [
+/// affine_map<(d0, d1) -> (0, d1)>,
+/// affine_map<(d0, d1) -> (d0, 0)>,
+/// affine_map<(d0, d1) -> (d0, d1)>
+/// ]
+///
+/// #trait = {
+/// args_in = 2,
+/// args_out = 1,
+/// indexing_maps = #accesses,
+/// iterator_types = ["parallel", "parallel"],
+/// library_call = "some_external_fn"
+/// }
+///
+/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
+/// tensor<5x5xf32>
+/// {
+/// %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
+/// tensor<5xf32> into tensor<1x5xf32>
+/// %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
+/// tensor<5xf32> into tensor<5x1xf32>
+/// %2 = linalg.generic #trait %0, %1 {
+/// ^bb0(%arg2: f32, %arg3: f32):
+/// %3 = arith.addf %arg2, %arg3 : f32
+/// linalg.yield %3 : f32
+/// } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
+/// return %2 : tensor<5x5xf32>
+/// }
+/// ```
+///
+/// would canonicalize to
+///
+/// ```mlir
+/// #accesses = [
+/// affine_map<(d0, d1) -> (d1)>,
+/// affine_map<(d0, d1) -> (d0)>,
+/// affine_map<(d0, d1) -> (d0, d1)>
+/// ]
+///
+/// #trait = {
+/// args_in = 2,
+/// args_out = 1,
+/// indexing_maps = #accesses,
+/// iterator_types = ["parallel", "parallel"],
+/// library_call = "some_external_fn"
+/// }
+///
+/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
+/// tensor<5x5xf32>
+/// {
+/// %0 = linalg.generic #trait %arg0, %arg1 {
+/// ^bb0(%arg2: f32, %arg3: f32):
+/// %3 = arith.addf %arg2, %arg3 : f32
+/// linalg.yield %3 : f32
+/// } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
+/// return %0 : tensor<5x5xf32>
+/// }
+/// ```
+
+/// Update the index accesses of linalg operations having index semantics.
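+///
+/// For example (illustrative), if loop dimension 1 is a dropped unit dim,
+/// `linalg.index 1` is replaced by `arith.constant 0 : index`, and
+/// `linalg.index 2` is renumbered to `linalg.index 1`.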
+static void
+replaceUnitDimIndexOps(GenericOp genericOp,
+ const llvm::SmallDenseSet<unsigned> &unitDims,
+ RewriterBase &rewriter) {
+ for (IndexOp indexOp :
+ llvm::make_early_inc_range(genericOp.getBody()->getOps<IndexOp>())) {
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(indexOp);
+ if (unitDims.count(indexOp.getDim()) != 0) {
+ rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(indexOp, 0);
+ } else {
+ // Update the dimension of the index operation if needed.
+ unsigned droppedDims = llvm::count_if(
+ unitDims, [&](unsigned dim) { return dim < indexOp.getDim(); });
+ if (droppedDims != 0)
+ rewriter.replaceOpWithNewOp<IndexOp>(indexOp,
+ indexOp.getDim() - droppedDims);
+ }
+ }
+}
+
+/// Expand the given `result` so that its type matches the type of `origDest`.
+/// The `reassociation` is used when `rankReductionStrategy` is set to
+/// `RankReductionStrategy::ReassociativeReshape`.
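+///
+/// A sketch of the two strategies, assuming `origDest : tensor<1x5xf32>` and
+/// `result : tensor<5xf32>`:
+///
+/// ```mlir
+///   // ReassociativeReshape:
+///   %e = tensor.expand_shape %result [[0, 1]]
+///       : tensor<5xf32> into tensor<1x5xf32>
+///   // ExtractInsertSlice:
+///   %i = tensor.insert_slice %result into %origDest[0, 0] [1, 5] [1, 1]
+///       : tensor<5xf32> into tensor<1x5xf32>
+/// ```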
+static Value
+expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest,
+ ArrayRef<ReassociationIndices> reassociation,
+ ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) {
+ // There are no results for memref outputs, so `origDest` must be a tensor.
+ auto origResultType = cast<RankedTensorType>(origDest.getType());
+ if (rankReductionStrategy ==
+ ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
+ unsigned rank = origResultType.getRank();
+ SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
+ SmallVector<OpFoldResult> sizes =
+ tensor::getMixedSizes(rewriter, loc, origDest);
+ SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
+ return rewriter.createOrFold<tensor::InsertSliceOp>(
+ loc, result, origDest, offsets, sizes, strides);
+ }
+
+ assert(rankReductionStrategy ==
+ ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
+ "unknown rank reduction strategy");
+ return rewriter.create<tensor::ExpandShapeOp>(loc, origResultType, result,
+ reassociation);
+}
+
+/// Collapse the given `operand` so that its type matches `targetShape`. The
+/// `reassociation` is used when `rankReductionStrategy` is set to
+/// `RankReductionStrategy::ReassociativeReshape`.
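+///
+/// A sketch of the two strategies, assuming `operand : tensor<1x5xf32>` and
+/// `targetShape = [5]`:
+///
+/// ```mlir
+///   // ReassociativeReshape:
+///   %c = tensor.collapse_shape %operand [[0, 1]]
+///       : tensor<1x5xf32> into tensor<5xf32>
+///   // ExtractInsertSlice (rank-reducing):
+///   %s = tensor.extract_slice %operand[0, 0] [1, 5] [1, 1]
+///       : tensor<1x5xf32> to tensor<5xf32>
+/// ```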
+static Value collapseValue(
+ RewriterBase &rewriter, Location loc, Value operand,
+ ArrayRef<int64_t> targetShape, ArrayRef<ReassociationIndices> reassociation,
+ ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) {
+ if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
+ if (rankReductionStrategy ==
+ ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
+ FailureOr<Value> rankReducingExtract =
+ memref::SubViewOp::rankReduceIfNeeded(rewriter, loc, operand,
+ targetShape);
+ assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
+ return *rankReducingExtract;
+ }
+
+ assert(
+ rankReductionStrategy ==
+ ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
+ "unknown rank reduction strategy");
+ MemRefLayoutAttrInterface layout;
+ auto targetType = MemRefType::get(targetShape, memrefType.getElementType(),
+ layout, memrefType.getMemorySpace());
+ return rewriter.create<memref::CollapseShapeOp>(loc, targetType, operand,
+ reassociation);
+ }
+ if (auto tensorType = dyn_cast<RankedTensorType>(operand.getType())) {
+ if (rankReductionStrategy ==
+ ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
+ FailureOr<Value> rankReducingExtract =
+ tensor::ExtractSliceOp::rankReduceIfNeeded(rewriter, loc, operand,
+ targetShape);
+ assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
+ return *rankReducingExtract;
+ }
+
+ assert(
+ rankReductionStrategy ==
+ ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
+ "unknown rank reduction strategy");
+ auto targetType =
+ RankedTensorType::get(targetShape, tensorType.getElementType());
+ return rewriter.create<tensor::CollapseShapeOp>(loc, targetType, operand,
+ reassociation);
+ }
+ llvm_unreachable("unsupported operand type");
+}
+
+/// Compute the modified metadata for an operand of an operation whose unit
+/// dims are being dropped. Return the new indexing map to use, the shape of
+/// the operand in the replacement op, and the `reassociation` to use to go
+/// from the original operand shape to the modified operand shape.
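+///
+/// For example (illustrative), given an operand of type `tensor<1x5xf32>`
+/// indexed by `(d0, d1) -> (d0, d1)` where `d0` is a dropped unit dim, this
+/// computes the index map `(d0) -> (d0)`, target shape `[5]`, and
+/// reassociation `[[0, 1]]`.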
struct UnitExtentReplacementInfo {
AffineMap indexMap;
SmallVector<ReassociationIndices> reassociation;
SmallVector<int64_t> targetShape;
};
-} // namespace
-
-/// Utility function for replacing operands/results to a linalg generic
-/// operation with unit-extent dimensions. These can be replaced with
-/// an operand/result with the unit-extent dimension removed. This is only done
-/// if the indexing map used to access that dimension has a
-/// AffineConstantExpr of value 0. Given the `type` of an result/operand of a
-/// Linalg op, and its `indexMap` the utility function returns:
-/// - the new type with dimensions of size 1 removed.
-/// - modified index map that can be used to access the replaced result/operand
-/// - the reassociation that converts from the original tensor type to the
-/// modified tensor type.
-static std::optional<UnitExtentReplacementInfo>
-replaceUnitExtents(GenericOp genericOp, OpOperand *opOperand,
- MLIRContext *context) {
+static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(
+ MLIRContext *context, GenericOp genericOp, OpOperand *opOperand,
+ llvm::SmallDenseMap<unsigned, unsigned> &oldDimsToNewDimsMap,
+ ArrayRef<AffineExpr> dimReplacements) {
+ UnitExtentReplacementInfo info;
+ ReassociationIndices reassociationGroup;
+ SmallVector<AffineExpr> newIndexExprs;
AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
- ArrayRef<int64_t> shape = genericOp.getShape(opOperand);
+ ArrayRef<int64_t> operandShape = genericOp.getShape(opOperand);
ArrayRef<AffineExpr> exprs = indexingMap.getResults();
- SmallVector<AffineExpr> newIndexExprs;
- SmallVector<int64_t> newShape;
- int64_t origRank = genericOp.getRank(opOperand);
- AffineExpr zeroExpr = getAffineConstantExpr(0, context);
- auto isUnitExtent = [&](int64_t dim) -> bool {
- return shape[dim] == 1 && exprs[dim] == zeroExpr;
+ auto isUnitDim = [&](unsigned dim) {
+ if (auto dimExpr = exprs[dim].dyn_cast<AffineDimExpr>()) {
+ unsigned oldPosition = dimExpr.getPosition();
+ return !oldDimsToNewDimsMap.count(oldPosition);
+ }
+ // Handle the other case where the dimension has extent 1 and is accessed
+ // with the constant 0.
+ if (operandShape[dim] == 1) {
+ auto constAffineExpr = exprs[dim].dyn_cast<AffineConstantExpr>();
+ return constAffineExpr && constAffineExpr.getValue() == 0;
+ }
+ return false;
};
- // Early return for memrefs with affine maps to represent that we will always
- // leave them unchanged.
- Type actualType = opOperand->get().getType();
- if (auto memref = dyn_cast<MemRefType>(actualType)) {
- if (!memref.getLayout().isIdentity())
- return std::nullopt;
- }
-
int64_t dim = 0;
- SmallVector<ReassociationIndices> reassociation;
- ReassociationIndices reassociationGroup;
- // Fold dimensions that are unit-extent at the beginning of the tensor.
- while (dim < origRank && isUnitExtent(dim))
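+ // Fold dimensions that are unit-extent at the beginning of the operand.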
+ while (dim < operandShape.size() && isUnitDim(dim))
reassociationGroup.push_back(dim++);
- while (dim < origRank) {
- assert(!isUnitExtent(dim) && "expected non unit-extent");
+ while (dim < operandShape.size()) {
+ assert(!isUnitDim(dim) && "expected non unit-extent");
reassociationGroup.push_back(dim);
- newIndexExprs.push_back(exprs[dim]);
- newShape.push_back(shape[dim]);
+ AffineExpr newExpr = exprs[dim].replaceDims(dimReplacements);
+ newIndexExprs.push_back(newExpr);
+ info.targetShape.push_back(operandShape[dim]);
++dim;
// Fold all following dimensions that are unit-extent.
- while (dim < origRank && isUnitExtent(dim))
+ while (dim < operandShape.size() && isUnitDim(dim)) {
reassociationGroup.push_back(dim++);
- reassociation.push_back(reassociationGroup);
+ }
+ info.reassociation.push_back(reassociationGroup);
reassociationGroup.clear();
}
-
- // Return if the rank was not reduced.
- if (origRank == static_cast<int64_t>(newShape.size()))
- return std::nullopt;
-
- UnitExtentReplacementInfo info = {
- /*indexMap=*/AffineMap::get(indexingMap.getNumDims(),
- indexingMap.getNumSymbols(), newIndexExprs,
- context),
- /*reassociation=*/reassociation, /*targetShape=*/newShape};
+ info.indexMap =
+ AffineMap::get(oldDimsToNewDimsMap.size(), indexingMap.getNumSymbols(),
+ newIndexExprs, context);
return info;
}
-namespace {
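+
+/// Drop the unit-extent loop dimensions of `genericOp` that are allowed by
+/// `options.controlFn`, collapsing operands and expanding results back using
+/// the rank-reduction strategy from `options`.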
+LogicalResult linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
+ const ControlDropUnitDims &options) {
+ SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
+ if (indexingMaps.empty())
+ return failure();
+
+ // 1. Check if any of the iteration dimensions are unit-trip count. They will
+ // end up being unit-trip count if they are used to index into a unit-dim
+ // tensor/memref.
+ AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
+ if (!invertedMap) {
+ return rewriter.notifyMatchFailure(genericOp,
+ "invalid indexing maps for operation");
+ }
+ SmallVector<int64_t> dims = genericOp.getStaticShape();
-/// Pattern to replace tensor/buffer operands/results that are unit extents.
-struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
- ReplaceUnitExtents(MLIRContext *ctx,
- RankReductionStrategy rankReductionStrategy)
- : OpRewritePattern<GenericOp>(ctx),
- rankReductionStrategy(rankReductionStrategy) {}
-
- // Expand the given value.
- Value expandValue(Value result, Value origOutput,
- ArrayRef<ReassociationIndices> reassociation, Location loc,
- PatternRewriter &rewriter) const {
- // There are no results for memref outputs.
- auto origResultType = cast<RankedTensorType>(origOutput.getType());
- if (rankReductionStrategy == RankReductionStrategy::ExtractInsertSlice) {
- unsigned rank = origResultType.getRank();
- SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
- SmallVector<OpFoldResult> sizes =
- tensor::getMixedSizes(rewriter, loc, origOutput);
- SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
- return rewriter.createOrFold<tensor::InsertSliceOp>(
- loc, result, origOutput, offsets, sizes, strides);
+ // 1a. Get the allowed list of dimensions to drop from the `options`.
+ SmallVector<unsigned> allowedUnitDims = options.controlFn(genericOp);
+ if (allowedUnitDims.empty()) {
+ return rewriter.notifyMatchFailure(
+ genericOp, "control function returns no allowed unit dims to prune");
+ }
+ llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
+ allowedUnitDims.end());
+ llvm::SmallDenseSet<unsigned> unitDims;
+ for (const auto &expr : enumerate(invertedMap.getResults())) {
+ if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>()) {
+ if (dims[dimExpr.getPosition()] == 1 &&
+ unitDimsFilter.count(expr.index()))
+ unitDims.insert(expr.index());
}
+ }
- assert(rankReductionStrategy ==
- RankReductionStrategy::ReassociativeReshape &&
- "unknown rank reduction strategy");
- return rewriter.create<tensor::ExpandShapeOp>(loc, origResultType, result,
- reassociation);
+ // 2. Compute the iterator types of the modified op by dropping the one-trip
+ // count loops.
+ SmallVector<utils::IteratorType> newIteratorTypes;
+ llvm::SmallDenseMap<unsigned, unsigned> oldDimToNewDimMap;
+ SmallVector<AffineExpr> dimReplacements;
+ unsigned newDims = 0;
+ for (auto [index, attr] :
+ llvm::enumerate(genericOp.getIteratorTypesArray())) {
+ if (unitDims.count(index)) {
+ dimReplacements.push_back(
+ getAffineConstantExpr(0, rewriter.getContext()));
+ } else {
+ newIteratorTypes.push_back(attr);
+ oldDimToNewDimMap[index] = newDims;
+ dimReplacements.push_back(
+ getAffineDimExpr(newDims, rewriter.getContext()));
+ newDims++;
+ }
}
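+ // For example (illustrative), with three parallel loops where loop 1 is a
+ // unit dim: newIteratorTypes = ["parallel", "parallel"], oldDimToNewDimMap
+ // = {0 -> 0, 2 -> 1}, and dimReplacements = [d0, 0, d1].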
- // Collapse the given value.
- Value collapseValue(Value operand, ArrayRef<int64_t> targetShape,
- ArrayRef<ReassociationIndices> reassociation,
- Location loc, PatternRewriter &rewriter) const {
- if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
- if (rankReductionStrategy == RankReductionStrategy::ExtractInsertSlice) {
- FailureOr<Value> rankReducingExtract =
- memref::SubViewOp::rankReduceIfNeeded(rewriter, loc, operand,
- targetShape);
- assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
- return *rankReducingExtract;
- }
-
- assert(rankReductionStrategy ==
- RankReductionStrategy::ReassociativeReshape &&
- "unknown rank reduction strategy");
- MemRefLayoutAttrInterface layout;
- auto targetType =
- MemRefType::get(targetShape, memrefType.getElementType(), layout,
- memrefType.getMemorySpace());
- return rewriter.create<memref::CollapseShapeOp>(loc, targetType, operand,
- reassociation);
+ // 3. For each of the operands, find:
+ // - the modified affine map to use,
+ // - the shape of the operand after the unit-dims are dropped, and
+ // - the reassociation indices used to convert from the original
+ // operand type to the modified operand type (needed only when using
+ // the reshape rank-reduction strategy).
+ // Note that the indexing maps might need changing even if no unit
+ // dimensions are dropped, to handle cases where `0` is used to access a
+ // unit-extent tensor. Consider moving this out of this specific
+ // transformation as a stand-alone transformation. Kept here for now for
+ // legacy reasons.
+ SmallVector<AffineMap> newIndexingMaps;
+ SmallVector<SmallVector<ReassociationIndices>> reassociations;
+ SmallVector<SmallVector<int64_t>> targetShapes;
+ SmallVector<bool> collapsed;
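+ // Operands with a non-identity memref layout or a non-null tensor encoding
+ // are not collapsed; they keep their shape, and only their indexing map is
+ // updated to the renumbered dims.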
+ auto hasCollapsibleType = [](OpOperand &operand) {
+ Type operandType = operand.get().getType();
+ if (auto memrefOperandType = dyn_cast_or_null<MemRefType>(operandType)) {
+ return memrefOperandType.getLayout().isIdentity();
+ } else if (auto tensorOperandType =
+ dyn_cast<RankedTensorType>(operandType)) {
+ return tensorOperandType.getEncoding() == nullptr;
}
- if (auto tensorType = dyn_cast<RankedTensorType>(operand.getType())) {
- if (rankReductionStrategy == RankReductionStrategy::ExtractInsertSlice) {
- FailureOr<Value> rankReducingExtract =
- tensor::ExtractSliceOp::rankReduceIfNeeded(rewriter, loc, operand,
- targetShape);
- assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
- return *rankReducingExtract;
- }
-
- assert(rankReductionStrategy ==
- RankReductionStrategy::ReassociativeReshape &&
- "unknown rank reduction strategy");
- auto targetType =
- RankedTensorType::get(targetShape, tensorType.getElementType());
- return rewriter.create<tensor::CollapseShapeOp>(loc, targetType, operand,
- reassociation);
+ return false;
+ };
+ for (OpOperand &opOperand : genericOp->getOpOperands()) {
+ auto indexingMap = genericOp.getMatchingIndexingMap(&opOperand);
+ ArrayRef<int64_t> shape = genericOp.getShape(&opOperand);
+ if (!hasCollapsibleType(opOperand)) {
+ AffineMap newIndexingMap = indexingMap.replaceDimsAndSymbols(
+ dimReplacements, ArrayRef<AffineExpr>{}, oldDimToNewDimMap.size(), 0);
+ newIndexingMaps.push_back(newIndexingMap);
+ targetShapes.push_back(llvm::to_vector(shape));
+ collapsed.push_back(false);
+ reassociations.push_back({});
+ continue;
}
- llvm_unreachable("unsupported operand type");
+ auto replacementInfo = dropUnitExtentFromOperandMetadata(
+ rewriter.getContext(), genericOp, &opOperand, oldDimToNewDimMap,
+ dimReplacements);
+ reassociations.push_back(replacementInfo.reassociation);
+ newIndexingMaps.push_back(replacementInfo.indexMap);
+ targetShapes.push_back(replacementInfo.targetShape);
+ collapsed.push_back(replacementInfo.indexMap.getNumResults() !=
+ indexingMap.getNumResults());
}
- LogicalResult matchAndRewrite(GenericOp genericOp,
- PatternRewriter &rewriter) const override {
- // Skip the pattern if the op has any tensor with special encoding.
- if (llvm::any_of(genericOp->getOperandTypes(), [](Type type) {
- auto tensorType = dyn_cast<RankedTensorType>(type);
- return tensorType && tensorType.getEncoding() != nullptr;
- }))
- return failure();
- MLIRContext *context = rewriter.getContext();
- Location loc = genericOp.getLoc();
- SmallVector<Value> oldOutputs(genericOp.getOutputs().begin(),
- genericOp.getOutputs().end());
-
- SmallVector<AffineMap> newIndexingMaps;
- SmallVector<SmallVector<ReassociationIndices>> reassociations;
- SmallVector<SmallVector<int64_t>> targetShapes;
- SmallVector<bool> collapsed;
- for (OpOperand &opOperand : genericOp->getOpOperands()) {
- auto replacementInfo = replaceUnitExtents(genericOp, &opOperand, context);
- if (replacementInfo) {
- reassociations.push_back(replacementInfo->reassociation);
- newIndexingMaps.push_back(replacementInfo->indexMap);
- targetShapes.push_back(replacementInfo->targetShape);
- collapsed.push_back(true);
- } else {
- // If replaceUnitExtents cannot handle this case (or no unit dim was
- // removed), maintain the same type, indexing map, and create a set of
- // mappings representing an identity matrix.
- newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(&opOperand));
- reassociations.emplace_back();
- targetShapes.emplace_back();
- collapsed.push_back(false);
- }
+ // Abort if the indexing maps of the result operation are not invertible
+ // (i.e. not legal) or if no dimension was reduced.
+ if (newIndexingMaps == indexingMaps ||
+ !inversePermutation(concatAffineMaps(newIndexingMaps)))
+ return failure();
+
+ Location loc = genericOp.getLoc();
+ // 4. For each operand, collapse it from the original shape to the shape
+ // in the modified operation if needed, using either reshapes or
+ // rank-reducing slices as specified in `options`.
+ SmallVector<Value> newOperands;
+ for (OpOperand &opOperand : genericOp->getOpOperands()) {
+ int64_t idx = opOperand.getOperandNumber();
+ if (!collapsed[idx]) {
+ newOperands.push_back(opOperand.get());
+ continue;
}
+ newOperands.push_back(collapseValue(rewriter, loc, opOperand.get(),
+ targetShapes[idx], reassociations[idx],
+ options.rankReductionStrategy));
+ }
- // Abort if the indexing maps of the result operation are not invertible
- // (i.e. not legal) or if no dimension was reduced.
- if (!llvm::any_of(collapsed, [](bool c) { return c; }) ||
- !inversePermutation(concatAffineMaps(newIndexingMaps)))
- return failure();
-
- // Insert rank reductions.
- SmallVector<Value> newOperands;
- for (OpOperand &opOperand : genericOp->getOpOperands()) {
- int64_t idx = opOperand.getOperandNumber();
- if (!collapsed[idx]) {
- newOperands.push_back(opOperand.get());
- continue;
- }
- newOperands.push_back(collapseValue(opOperand.get(), targetShapes[idx],
- reassociations[idx], loc, rewriter));
+ // 5. Create the `linalg.generic` operation with the new operands,
+ // indexing maps, iterator types and result types.
+ ArrayRef<Value> newInputs =
+ ArrayRef<Value>(newOperands).take_front(genericOp.getNumDpsInputs());
+ ArrayRef<Value> newOutputs =
+ ArrayRef<Value>(newOperands).take_back(genericOp.getNumDpsInits());
+ SmallVector<Type> resultTypes;
+ resultTypes.reserve(genericOp.getNumResults());
+ for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
+ resultTypes.push_back(newOutputs[i].getType());
+ GenericOp replacementOp =
+ rewriter.create<GenericOp>(loc, resultTypes, newInputs, newOutputs,
+ newIndexingMaps, newIteratorTypes);
+ rewriter.inlineRegionBefore(genericOp.getRegion(), replacementOp.getRegion(),
+ replacementOp.getRegion().begin());
+ // 5a. Replace `linalg.index` operations that refer to the dropped unit
+ // dimensions.
+ replaceUnitDimIndexOps(replacementOp, unitDims, rewriter);
+
+ // 6. If any result type changed, insert a reshape/slice to convert from
+ // the original type to the new type.
+ SmallVector<Value> resultReplacements;
+ for (auto [index, result] : llvm::enumerate(replacementOp.getResults())) {
+ unsigned opOperandIndex = index + replacementOp.getNumDpsInputs();
+ Value origDest = genericOp.getDpsInitOperand(index)->get();
+ if (!collapsed[opOperandIndex]) {
+ resultReplacements.push_back(result);
+ continue;
}
+ resultReplacements.push_back(expandValue(rewriter, loc, result, origDest,
+ reassociations[opOperandIndex],
+ options.rankReductionStrategy));
+ }
- // If any result type changes, insert a reshape to convert from the original
- // type to the new type.
- ArrayRef<Value> newInputs =
- ArrayRef<Value>(newOperands).take_front(genericOp.getNumDpsInputs());
- ArrayRef<Value> newOutputs =
- ArrayRef<Value>(newOperands).take_back(genericOp.getNumDpsInits());
- SmallVector<Type> resultTypes;
- resultTypes.reserve(genericOp.getNumResults());
- for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
- resultTypes.push_back(newOutputs[i].getType());
- GenericOp replacementOp = rewriter.create<GenericOp>(
- loc, resultTypes, newInputs, newOutputs, newIndexingMaps,
- genericOp.getIteratorTypesArray());
- rewriter.inlineRegionBefore(genericOp.getRegion(),
- replacementOp.getRegion(),
- replacementOp.getRegion().begin());
-
- // If any result tensor has a modified shape, then add reshape to recover
- // the original shape.
- SmallVector<Value> resultReplacements;
- for (const auto &result : llvm::enumerate(replacementOp.getResults())) {
- unsigned index = result.index() + replacementOp.getNumDpsInputs();
- Value origOutput = oldOutputs[result.index()];
- if (!collapsed[result.index() + genericOp.getNumDpsInputs()]) {
- resultReplacements.push_back(result.value());
- continue;
- }
- resultReplacements.push_back(expandValue(
- result.value(), origOutput, reassociations[index], loc, rewriter));
- }
+ rewriter.replaceOp(genericOp, resultReplacements);
+ return success();
+}
- rewriter.replaceOp(genericOp, resultReplacements);
- return success();
+namespace {
+struct DropUnitDims : public OpRewritePattern<GenericOp> {
+ DropUnitDims(MLIRContext *context, ControlDropUnitDims options = {},
+ PatternBenefit benefit = 1)
+ : OpRewritePattern(context, benefit), options(std::move(options)) {}
+
+ LogicalResult matchAndRewrite(GenericOp genericOp,
+ PatternRewriter &rewriter) const override {
+ return dropUnitDims(rewriter, genericOp, options);
}
private:
- RankReductionStrategy rankReductionStrategy;
+ ControlDropUnitDims options;
};
} // namespace
tensor::CollapseShapeOp reshapedSource;
{
OpBuilder::InsertionGuard g(rewriter);
- // The only difference between InsertSliceOp and ParallelInsertSliceOp is
- // the insertion point is just before the ParallelCombiningOp in the
+ // The only difference between InsertSliceOp and ParallelInsertSliceOp is
+ // that the insertion point is just before the ParallelCombiningOp in the
// parallel case.
if (std::is_same<InsertOpTy, tensor::ParallelInsertSliceOp>::value)
rewriter.setInsertionPoint(insertSliceOp->getParentOp());
/// Patterns that are used to canonicalize the use of unit-extent dims for
/// broadcasting.
-void mlir::linalg::populateFoldUnitExtentDimsViaReshapesPatterns(
- RewritePatternSet &patterns) {
+static void
+populateFoldUnitExtentDimsViaReshapesPatterns(RewritePatternSet &patterns,
+ ControlDropUnitDims &options) {
auto *context = patterns.getContext();
- patterns.add<ReplaceUnitExtents>(context,
- RankReductionStrategy::ReassociativeReshape);
+ patterns.add<DropUnitDims>(context, options);
// TODO: Patterns unrelated to unit dim folding should be factored out.
- patterns.add<FoldUnitDimLoops, RankReducedExtractSliceOp,
+ patterns.add<RankReducedExtractSliceOp,
RankReducedInsertSliceOp<tensor::InsertSliceOp>,
RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(
context);
memref::populateResolveShapedTypeResultDimsPatterns(patterns);
}
-void mlir::linalg::populateFoldUnitExtentDimsViaSlicesPatterns(
- RewritePatternSet &patterns) {
+static void
+populateFoldUnitExtentDimsViaSlicesPatterns(RewritePatternSet &patterns,
+ ControlDropUnitDims &options) {
auto *context = patterns.getContext();
- patterns.add<ReplaceUnitExtents>(context,
- RankReductionStrategy::ExtractInsertSlice);
- patterns.add<FoldUnitDimLoops>(context);
+ options.rankReductionStrategy =
+ ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice;
+ patterns.add<DropUnitDims>(context, options);
// TODO: Patterns unrelated to unit dim folding should be factored out.
linalg::FillOp::getCanonicalizationPatterns(patterns, context);
tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
memref::populateResolveShapedTypeResultDimsPatterns(patterns);
}
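+
+/// A minimal usage sketch (mirroring the pass driver below):
+///
+/// ```c++
+///   RewritePatternSet patterns(context);
+///   linalg::ControlDropUnitDims options;
+///   linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
+///   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
+/// ```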
+void mlir::linalg::populateFoldUnitExtentDimsPatterns(
+ RewritePatternSet &patterns, linalg::ControlDropUnitDims &options) {
+ if (options.rankReductionStrategy ==
+ linalg::ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
+ populateFoldUnitExtentDimsViaSlicesPatterns(patterns, options);
+ } else if (options.rankReductionStrategy ==
+ linalg::ControlDropUnitDims::RankReductionStrategy::
+ ReassociativeReshape) {
+ populateFoldUnitExtentDimsViaReshapesPatterns(patterns, options);
+ }
+}
+
void mlir::linalg::populateMoveInitOperandsToInputPattern(
RewritePatternSet &patterns) {
patterns.add<MoveInitOperandsToInput>(patterns.getContext());
Operation *op = getOperation();
MLIRContext *context = op->getContext();
RewritePatternSet patterns(context);
- if (foldOneTripLoopsOnly) {
- patterns.add<FoldUnitDimLoops, MoveInitOperandsToInput>(context);
- } else if (useRankReducingSlices) {
- populateFoldUnitExtentDimsViaSlicesPatterns(patterns);
- populateMoveInitOperandsToInputPattern(patterns);
- } else {
- populateFoldUnitExtentDimsViaReshapesPatterns(patterns);
- populateMoveInitOperandsToInputPattern(patterns);
+ ControlDropUnitDims options;
+ if (useRankReducingSlices) {
+ options.rankReductionStrategy = linalg::ControlDropUnitDims::
+ RankReductionStrategy::ExtractInsertSlice;
}
+ linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
+ populateMoveInitOperandsToInputPattern(patterns);
(void)applyPatternsAndFoldGreedily(op, std::move(patterns));
}
};