#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SetVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
}
};
+/// Pattern to add init operands to ins when all the loops are parallel and
+/// blockArgument corresponding to init is used in the region. This is a fix-up
+/// when unit reduction dimensions are all folded away. In this context, it
+/// becomes a elementwise generic op. E.g., it converts
+///
+/// %0 = tensor.empty() : tensor<1x1xf32>
+/// %1 = linalg.fill
+/// ins(%cst : f32)
+/// outs(%0 : tensor<1x1xf32>) -> tensor<1x1xf32>
+/// %2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, d0, 0, 0)>,
+/// affine_map<(d0) -> (0, d0)>],
+/// iterator_types = ["parallel"]}
+/// ins(%arg0 : tensor<1x?x1x1xf32>)
+/// outs(%1 : tensor<1x1xf32>) {
+/// ^bb0(%in: f32, %out: f32):
+/// %3 = arith.addf %in, %out : f32
+/// linalg.yield %3 : f32
+/// } -> tensor<1x1xf32>
+///
+/// into
+///
+/// %0 = tensor.empty() : tensor<1x1xf32>
+/// %1 = linalg.fill
+/// ins(%cst : f32)
+/// outs(%0 : tensor<1x1xf32>) -> tensor<1x1xf32>
+/// %2 = tensor.empty() : tensor<1x1xf32>
+/// %3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, d0, 0, 0)>,
+/// affine_map<(d0) -> (0, d0)>,
+/// affine_map<(d0) -> (0, d0)>],
+/// iterator_types = ["parallel"]}
+/// ins(%arg0, %1 : tensor<1x?x1x1xf32>, tensor<1x1xf32>)
+/// outs(%2 : tensor<1x1xf32>) {
+/// ^bb0(%in: f32, %in_0: f32, %out: f32):
+/// %4 = arith.addf %in, %in_0 : f32
+/// linalg.yield %4 : f32
+/// } -> tensor<1x1xf32>
+struct AddInitOperandsToInput : public OpRewritePattern<GenericOp> {
+ using OpRewritePattern<GenericOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(GenericOp genericOp,
+ PatternRewriter &rewriter) const override {
+ if (!genericOp.hasTensorSemantics())
+ return failure();
+ if (genericOp.getNumParallelLoops() != genericOp.getNumLoops())
+ return failure();
+
+ auto outputOperands = genericOp.getDpsInitOperands();
+ SetVector<OpOperand *> candidates;
+ for (OpOperand *op : outputOperands) {
+ if (genericOp.getMatchingBlockArgument(op).use_empty())
+ continue;
+ candidates.insert(op);
+ }
+
+ if (candidates.empty())
+ return failure();
+
+ // Compute the modified indexing maps.
+ int64_t origNumInput = genericOp.getNumDpsInputs();
+ SmallVector<Value> newInputOperands = genericOp.getDpsInputOperands();
+ SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
+ SmallVector<AffineMap> newIndexingMaps;
+ newIndexingMaps.append(indexingMaps.begin(),
+ std::next(indexingMaps.begin(), origNumInput));
+ for (OpOperand *op : candidates) {
+ newInputOperands.push_back(op->get());
+ newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(op));
+ }
+ newIndexingMaps.append(std::next(indexingMaps.begin(), origNumInput),
+ indexingMaps.end());
+
+ Location loc = genericOp.getLoc();
+ SmallVector<Value> newOutputOperands = outputOperands;
+ for (OpOperand *op : candidates) {
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointAfterValue(op->get());
+ auto elemType = op->get().getType().cast<ShapedType>().getElementType();
+ auto empty = rewriter.create<tensor::EmptyOp>(
+ loc, tensor::createDimValues(rewriter, loc, op->get()), elemType);
+
+ auto [start, end] = genericOp.getDpsInitsPositionRange();
+ newOutputOperands[op->getOperandNumber() - start] = empty.getResult();
+ }
+
+ auto newOp = rewriter.create<GenericOp>(
+ loc, genericOp.getResultTypes(), newInputOperands, newOutputOperands,
+ newIndexingMaps, genericOp.getIteratorTypesArray(),
+ /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
+
+ Region ®ion = newOp.getRegion();
+ Block *block = new Block();
+ region.push_back(block);
+ BlockAndValueMapping mapper;
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointToStart(block);
+ for (auto bbarg : genericOp.getRegionInputArgs())
+ mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
+
+ for (OpOperand *op : candidates) {
+ BlockArgument bbarg = genericOp.getMatchingBlockArgument(op);
+ mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
+ }
+
+ for (OpOperand *op : outputOperands) {
+ BlockArgument bbarg = genericOp.getMatchingBlockArgument(op);
+ if (candidates.count(op))
+ block->addArgument(bbarg.getType(), loc);
+ else
+ mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
+ }
+
+ for (auto &op : genericOp.getBody()->getOperations()) {
+ rewriter.clone(op, mapper);
+ }
+ rewriter.replaceOp(genericOp, newOp.getResults());
+
+ return success();
+ }
+};
+
struct UnitExtentReplacementInfo {
Type type;
AffineMap indexMap;
void mlir::linalg::populateFoldUnitExtentDimsPatterns(
RewritePatternSet &patterns) {
auto *context = patterns.getContext();
- patterns.add<FoldUnitDimLoops, ReplaceUnitExtents, RankReducedExtractSliceOp,
+ patterns.add<FoldUnitDimLoops, AddInitOperandsToInput, ReplaceUnitExtents,
+ RankReducedExtractSliceOp,
RankReducedInsertSliceOp<tensor::InsertSliceOp>,
RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(
context);
tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
+ memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
+ memref::populateResolveShapedTypeResultDimsPatterns(patterns);
}
namespace {
MLIRContext *context = op->getContext();
RewritePatternSet patterns(context);
if (foldOneTripLoopsOnly)
- patterns.add<FoldUnitDimLoops>(context);
+ patterns.add<FoldUnitDimLoops, AddInitOperandsToInput>(context);
else
populateFoldUnitExtentDimsPatterns(patterns);
(void)applyPatternsAndFoldGreedily(op, std::move(patterns));