static ArrayAttr collapseReassociationMaps(ArrayRef<AffineMap> mapsProducer,
ArrayRef<AffineMap> mapsConsumer,
MLIRContext *context) {
+ // Handle the corner case of the result being a rank 0 shaped type. Return an
+ // empty ArrayAttr.
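+ // For example, folding tensor<1x1x1xf32> -> tensor<1xf32> -> tensor<f32>
+ // produces a single reshape whose reassociation is empty.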
+ if (mapsConsumer.empty() && !mapsProducer.empty())
+ return ArrayAttr::get(ArrayRef<Attribute>(), context);
if (mapsProducer.empty() || mapsConsumer.empty() ||
mapsProducer[0].getNumDims() < mapsConsumer[0].getNumDims() ||
mapsProducer.size() != mapsConsumer[0].getNumDims())
ShapedType intermediateType,
ShapedType smallerType) -> bool {
return largerType.getRank() > intermediateType.getRank() &&
- intermediateType.getRank() > smallerType.getRank() &&
- smallerType.getRank() > 0;
+ intermediateType.getRank() > smallerType.getRank();
};
// Check if producer and consumer are both expanding dims.
if (areReshapeOpsFoldable(reshapeOp.getResultType(), reshapeOp.getSrcType(),
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
+#include <set>
+
#define DEBUG_TYPE "linalg-drop-unit-dims"
using namespace mlir;
context);
}
+/// Modify the region of an indexed_generic op to drop the index arguments
+/// corresponding to loops that have unit trip count.
+template <typename OpTy>
+static LogicalResult
+replaceBlockArgForUnitDimLoops(OpTy op, const DenseSet<unsigned> &unitDims,
+ PatternRewriter &rewriter) {
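+ // Ops without index block arguments (e.g. linalg.generic) need no region
+ // rewrite, so the default implementation trivially succeeds.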
+ return success();
+}
+
+template <>
+LogicalResult replaceBlockArgForUnitDimLoops<IndexedGenericOp>(
+ IndexedGenericOp op, const DenseSet<unsigned> &unitDims,
+ PatternRewriter &rewriter) {
+ OpBuilder::InsertionGuard guard(rewriter);
+ Block *entryBlock = &op.getOperation()->getRegion(0).front();
+ rewriter.setInsertionPointToStart(entryBlock);
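+ // A unit trip count loop only ever executes with induction value 0, so each
+ // corresponding index argument can be replaced by a constant zero.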
+ Value zero = rewriter.create<ConstantIndexOp>(op.getLoc(), 0);
+ for (unsigned unitDimLoop : unitDims)
+ entryBlock->getArgument(unitDimLoop).replaceAllUsesWith(zero);
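+ // Erase the arguments from the highest position to the lowest so that the
+ // indices of the arguments still to be erased remain valid.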
+ std::set<unsigned> orderedUnitDims(unitDims.begin(), unitDims.end());
+ for (unsigned i : llvm::reverse(orderedUnitDims))
+ entryBlock->eraseArgument(i);
+ return success();
+}
+
namespace {
-/// Pattern to fold unit-trip count loops in GenericOps.
-// TODO: Generalize this to indexed-generic as well by modifying the region args
-// as well.
+/// Pattern to fold unit-trip count loops in GenericOps and IndexedGenericOps.
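+// For example, an op over tensor<1x5xf32> with indexing map
+// (d0, d1) -> (d0, d1) has a single iteration along d0; that loop is dropped
+// and the map is rewritten to (d0) -> (0, d0).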
-struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
- using OpRewritePattern<GenericOp>::OpRewritePattern;
- LogicalResult matchAndRewrite(GenericOp genericOp,
+template <typename GenericOpTy>
+struct FoldUnitDimLoops : public OpRewritePattern<GenericOpTy> {
+ using OpRewritePattern<GenericOpTy>::OpRewritePattern;
+ LogicalResult matchAndRewrite(GenericOpTy op,
PatternRewriter &rewriter) const override {
- SmallVector<AffineMap, 4> indexingMaps = genericOp.getIndexingMaps();
+ SmallVector<AffineMap, 4> indexingMaps = op.getIndexingMaps();
if (indexingMaps.empty())
return failure();
if (!invertedMap)
return failure();
SmallVector<int64_t, 4> dims;
- for (ShapedType shapedType : genericOp.getInputOutputShapedTypes())
+ for (ShapedType shapedType : op.getInputOutputShapedTypes())
dims.append(shapedType.getShape().begin(), shapedType.getShape().end());
DenseSet<unsigned> unitDims;
- ArrayAttr iteratorTypes = genericOp.iterator_types();
+ ArrayAttr iteratorTypes = op.iterator_types();
for (auto expr : enumerate(invertedMap.getResults())) {
if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>())
if (dims[dimExpr.getPosition()] == 1 &&
ArrayAttr newIndexingMapAttr =
replaceUnitDims(unitDims, indexingMaps, context);
if (!newIndexingMapAttr)
- return genericOp.emitError("unable to compute modified indexing_maps");
+ return op.emitError("unable to compute modified indexing_maps");
// Compute the iterator types of the modified op by dropping the one-trip
// count loops.
newIteratorTypes.push_back(attr.value());
}
- rewriter.startRootUpdate(genericOp);
- genericOp.indexing_mapsAttr(newIndexingMapAttr);
- genericOp.iterator_typesAttr(ArrayAttr::get(newIteratorTypes, context));
- rewriter.finalizeRootUpdate(genericOp);
+ rewriter.startRootUpdate(op);
+ op.indexing_mapsAttr(newIndexingMapAttr);
+ op.iterator_typesAttr(ArrayAttr::get(newIteratorTypes, context));
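+ // For indexed_generic ops, also rewrite the region so that the index
+ // arguments of the dropped loops are replaced by the constant zero.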
+ replaceBlockArgForUnitDimLoops(op, unitDims, rewriter);
+ rewriter.finalizeRootUpdate(op);
return success();
}
};
namespace {
-/// Pattern to replace tensors operands/results that are unit extents.
+/// Pattern to replace tensor operands/results that have unit extents.
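+// Unit-extent dimensions are dropped from the operand/result types and the
+// indexing maps, with linalg.tensor_reshape ops inserted at the boundaries to
+// convert to and from the original types.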
-struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
- using OpRewritePattern<GenericOp>::OpRewritePattern;
- LogicalResult matchAndRewrite(GenericOp genericOp,
+template <typename GenericOpTy>
+struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOpTy> {
+ using OpRewritePattern<GenericOpTy>::OpRewritePattern;
+ LogicalResult matchAndRewrite(GenericOpTy op,
PatternRewriter &rewriter) const override {
// TODO: support init_tensors and reductions.
- if (!genericOp.hasTensorSemantics() || !genericOp.init_tensors().empty())
+ if (!op.hasTensorSemantics() || !op.init_tensors().empty())
return failure();
MLIRContext *context = rewriter.getContext();
- Location loc = genericOp.getLoc();
+ Location loc = op.getLoc();
SmallVector<AffineMap, 4> newIndexingMaps;
SmallVector<ArrayAttr, 4> reassociationMaps;
SmallVector<ShapedType, 4> newInputOutputTypes;
bool doCanonicalization = false;
- for (auto it : llvm::zip(genericOp.getIndexingMaps(),
- genericOp.getInputOutputShapedTypes())) {
+ for (auto it :
+ llvm::zip(op.getIndexingMaps(), op.getInputOutputShapedTypes())) {
auto replacementInfo = replaceUnitExtents(
- std::get<0>(it), std::get<1>(it).cast<RankedTensorType>(), context);
+ std::get<0>(it), std::get<1>(it).template cast<RankedTensorType>(),
+ context);
reassociationMaps.push_back(replacementInfo.reassociation);
newIndexingMaps.push_back(replacementInfo.indexMap);
newInputOutputTypes.push_back(replacementInfo.type);
return res;
};
- SmallVector<Value, 4> newInputs = insertReshapes(genericOp.inputs());
+ SmallVector<Value, 4> newInputs = insertReshapes(op.inputs());
SmallVector<Value, 4> newOutputBuffers =
- insertReshapes(genericOp.output_buffers());
- SmallVector<Value, 4> newInitTensors =
- insertReshapes(genericOp.init_tensors());
+ insertReshapes(op.output_buffers());
+ SmallVector<Value, 4> newInitTensors = insertReshapes(op.init_tensors());
-// If any result type change, insert a reshape to convert from the original
+// If any result type changes, insert a reshape to convert from the original
// type to the new type.
SmallVector<Type, 4> resultTypes;
- resultTypes.reserve(genericOp.getNumResults());
- for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
- resultTypes.push_back(newInputOutputTypes[i + genericOp.getNumInputs()]);
- GenericOp replacementOp = rewriter.create<GenericOp>(
+ resultTypes.reserve(op.getNumResults());
+ for (unsigned i : llvm::seq<unsigned>(0, op.getNumResults()))
+ resultTypes.push_back(newInputOutputTypes[i + op.getNumInputs()]);
+ GenericOpTy replacementOp = rewriter.create<GenericOpTy>(
loc, resultTypes, newInputs, newOutputBuffers, newInitTensors,
newIndexingMaps,
llvm::to_vector<4>(
- genericOp.iterator_types().getAsValueRange<StringAttr>()));
- rewriter.inlineRegionBefore(genericOp.region(), replacementOp.region(),
+ op.iterator_types().template getAsValueRange<StringAttr>()));
+ rewriter.inlineRegionBefore(op.region(), replacementOp.region(),
replacementOp.region().begin());
// If any result tensor has a modified shape, then add reshape to recover
SmallVector<Value, 4> resultReplacements;
for (auto result : llvm::enumerate(replacementOp.getResults())) {
unsigned index = result.index() + replacementOp.getNumOperands();
- RankedTensorType origResultType = genericOp.getResult(result.index())
+ RankedTensorType origResultType = op.getResult(result.index())
.getType()
- .cast<RankedTensorType>();
+ .template cast<RankedTensorType>();
if (origResultType != result.value().getType())
resultReplacements.push_back(rewriter.create<linalg::TensorReshapeOp>(
loc, origResultType, result.value(), reassociationMaps[index]));
else
resultReplacements.push_back(result.value());
}
- rewriter.replaceOp(genericOp, resultReplacements);
+ rewriter.replaceOp(op, resultReplacements);
return success();
}
};
/// broadcasting.
void mlir::populateLinalgFoldUnitExtentDimsPatterns(
MLIRContext *context, OwningRewritePatternList &patterns) {
- patterns.insert<FoldUnitDimLoops, ReplaceUnitExtentTensors>(context);
+ patterns
+ .insert<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>,
+ ReplaceUnitExtentTensors<GenericOp>,
+ ReplaceUnitExtentTensors<IndexedGenericOp>>(context);
TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
patterns.insert<FoldReshapeOpWithUnitExtent>(context);
}
FuncOp funcOp = getFunction();
MLIRContext *context = funcOp.getContext();
if (foldOneTripLoopsOnly)
- patterns.insert<FoldUnitDimLoops>(context);
+ patterns.insert<FoldUnitDimLoops<GenericOp>,
+ FoldUnitDimLoops<IndexedGenericOp>>(context);
else
populateLinalgFoldUnitExtentDimsPatterns(context, patterns);
applyPatternsAndFoldGreedily(funcOp.getBody(), patterns);
// consumer's operand.
-// If both `numProducerIndices` and `numConsumerIndices` are zero, this is a
-// generic op. In this case, there are no indices in block arguments.
+// If both `numProducerIndices` and `numConsumerIndices` are zero, neither op
+// is an indexed_generic, so there are no index block arguments.
- unsigned numProducerIndices =
- isa<IndexedGenericOp>(producer.getOperation()) ? nloops : 0;
- unsigned numConsumerIndices =
- isa<IndexedGenericOp>(consumer.getOperation()) ? nloops : 0;
+ unsigned numProducerIndices = isa<IndexedGenericOp>(producer.getOperation())
+ ? producer.getNumLoops()
+ : 0;
+ unsigned numConsumerIndices = isa<IndexedGenericOp>(consumer.getOperation())
+ ? consumer.getNumLoops()
+ : 0;
+ unsigned numFusedOpIndices =
+ (isa<IndexedGenericOp>(producer.getOperation()) ||
+ isa<IndexedGenericOp>(consumer.getOperation()))
+ ? std::max(producer.getNumLoops(), consumer.getNumLoops())
+ : 0;
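+ // The fused op carries one index block argument per loop whenever either of
+ // the two ops is an indexed_generic.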
// Firstly, add all the indices to the block arguments.
- for (unsigned i = 0, e = std::max(numProducerIndices, numConsumerIndices);
- i < e; ++i)
+ for (unsigned i = 0, e = numFusedOpIndices; i < e; ++i)
fusedBlock->addArgument(rewriter.getIndexType());
// Map the arguments for the unmodified args from the consumer.
for (auto consumerArg : llvm::enumerate(consumerBlock.getArguments())) {
auto newIndex = rewriter.create<mlir::AffineApplyOp>(
producer.getLoc(),
consumerToProducerLoopsMap.getSubMap(producerArg.index()),
- fusedBlock->getArguments().take_front(nloops));
+ fusedBlock->getArguments().take_front(numFusedOpIndices));
mapper.map(producerArg.value(), newIndex);
} else {
mapper.map(producerArg.value(),
// -----
+// -----
+
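+// Reshapes that collapse down to a rank-0 tensor should fold into a single
+// reshape with an empty reassociation.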
+func @collapsing_tensor_reshapes_to_zero_dim(%arg0 : tensor<1x1x1xf32>)
+ -> tensor<f32> {
+ %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>] :
+ tensor<1x1x1xf32> into tensor<1xf32>
+ %1 = linalg.tensor_reshape %0 [] : tensor<1xf32> into tensor<f32>
+ return %1 : tensor<f32>
+}
+// CHECK-LABEL: collapsing_tensor_reshapes_to_zero
+// CHECK: linalg.tensor_reshape %{{.*}} []
+// CHECK-SAME: tensor<1x1x1xf32> into tensor<f32>
+
+// -----
+
+func @collapsing_memref_reshapes_to_zero_dim(%arg0 : memref<1x1x1xf32>)
+ -> memref<f32> {
+ %0 = linalg.reshape %arg0 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>] :
+ memref<1x1x1xf32> into memref<1xf32>
+ %1 = linalg.reshape %0 [] : memref<1xf32> into memref<f32>
+ return %1 : memref<f32>
+}
+// CHECK-LABEL: collapsing_memref_reshapes_to_zero
+// CHECK: linalg.reshape %{{.*}} []
+// CHECK-SAME: memref<1x1x1xf32> into memref<f32>
+
+// -----
+
func @expanding_tensor_reshapes(%arg0 : tensor<?x?xf32>) -> tensor<?x?x?x?x?xf32>
{
%0 = linalg.tensor_reshape %arg0
// -----
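+// The expanding mirror of the collapsing case: reshapes that start from a
+// rank-0 tensor should also fold into a single reshape.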
+func @expanding_tensor_reshapes_to_zero_dim(%arg0 : tensor<f32>)
+ -> tensor<1x1x1xf32> {
+ %0 = linalg.tensor_reshape %arg0 [] : tensor<f32> into tensor<1xf32>
+ %1 = linalg.tensor_reshape %0 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>] :
+ tensor<1xf32> into tensor<1x1x1xf32>
+ return %1 : tensor<1x1x1xf32>
+}
+// CHECK-LABEL: expanding_tensor_reshapes_to_zero
+// CHECK: linalg.tensor_reshape %{{.*}} []
+// CHECK-SAME: tensor<f32> into tensor<1x1x1xf32>
+
+// -----
+
+func @expanding_memref_reshapes_to_zero_dim(%arg0 : memref<f32>)
+ -> memref<1x1x1xf32> {
+ %0 = linalg.reshape %arg0 [] : memref<f32> into memref<1xf32>
+ %1 = linalg.reshape %0
+ [affine_map<(d0, d1, d2) -> (d0, d1, d2)>] :
+ memref<1xf32> into memref<1x1x1xf32>
+ return %1 : memref<1x1x1xf32>
+}
+// CHECK-LABEL: expanding_memref_reshapes_to_zero
+// CHECK: linalg.reshape %{{.*}} []
+// CHECK-SAME: memref<f32> into memref<1x1x1xf32>
+
+// -----
+
func @fold_tensor_reshape(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32>
{
%0 = linalg.tensor_reshape %arg0
// -----
+#accesses = [
+ affine_map<(i, j, k, l, m) -> (i, k, m)>,
+ affine_map<(i, j, k, l, m) -> (i, k, j, l, m)>
+]
+
+#trait = {
+ iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
+ indexing_maps = #accesses,
+ library_call = "some_external_func"
+}
+
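+// Loops k and l iterate over unit extents; folding should drop them and
+// replace their index arguments in the region with zero.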
+func @drop_one_trip_loops_indexed_generic
+ (%arg0 : tensor<?x1x?xi32>) -> tensor<?x1x?x1x?xi32>
+{
+ %0 = linalg.indexed_generic #trait
+ ins(%arg0 : tensor<?x1x?xi32>) {
+ ^bb0(%arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index,
+ %arg5 : index, %arg6 : i32) :
+ %1 = addi %arg1, %arg2 : index
+ %2 = addi %1, %arg3 : index
+ %3 = addi %2, %arg4 : index
+ %4 = addi %3, %arg5 : index
+ %5 = index_cast %4 : index to i32
+ %6 = addi %5, %arg6 : i32
+ linalg.yield %6 : i32
+ } -> tensor<?x1x?x1x?xi32>
+ return %0 : tensor<?x1x?x1x?xi32>
+}
+// CHECK-LABEL: func @drop_one_trip_loops_indexed_generic
+// CHECK: linalg.indexed_generic
+// CHECK: ^{{.+}}(
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index, %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index, %[[ARG4:[a-zA-Z0-9]+]]: i32)
+// CHECK: %[[T3:.+]] = addi %[[ARG1]], %[[ARG2]]
+// CHECK: %[[T4:.+]] = addi %[[T3]], %[[ARG3]]
+// CHECK: %[[T5:.+]] = index_cast %[[T4]] : index to i32
+// CHECK: %[[T6:.+]] = addi %[[T5]], %[[ARG4]] : i32
+// CHECK: linalg.yield %[[T6]] : i32
+
+// -----
+
#map0 = affine_map<(i, j) -> (i, j)>
#access = [#map0, #map0]
#trait = {
// -----
+#map0 = affine_map<(i, j) -> (i, j)>
+#access = [#map0, #map0]
+#trait = {
+ iterator_types = ["parallel", "parallel"],
+ indexing_maps = #access,
+ library_call = "some_external_func"
+}
+
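+// Every loop has unit trip count, so the folded op keeps no index arguments
+// in its region.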
+func @drop_all_loops_indexed_generic
+ (%arg0 : tensor<1x1xi32>) -> tensor<1x1xi32>
+{
+ %0 = linalg.indexed_generic #trait
+ ins(%arg0 : tensor<1x1xi32>) {
+ ^bb0(%arg1 : index, %arg2 : index, %arg3: i32) :
+ %1 = addi %arg1, %arg2 : index
+ %2 = index_cast %1 : index to i32
+ %3 = addi %2, %arg3 : i32
+ linalg.yield %3 : i32
+ } -> tensor<1x1xi32>
+ return %0 : tensor<1x1xi32>
+}
+
+// CHECK-LABEL: func @drop_all_loops_indexed_generic
+// CHECK: linalg.indexed_generic
+// CHECK: ^{{.+}}(%[[ARG1:.+]]: i32)
+// CHECK: linalg.yield %[[ARG1]] : i32
+
+// -----
+
#accesses = [
affine_map<(d0) -> (0, d0)>,
affine_map<(d0) -> (d0)>
// CHECK: %[[VAL4:.+]] = subi %[[VAL3]], %[[SUB_OPERAND2]] : i32
// CHECK: linalg.yield %[[VAL4]] : i32
// CHECK-NOT: linalg.indexed_generic
+
+// -----
+
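+// Fusing a rank-0 indexed_generic producer into a 1-D generic consumer: the
+// fused op should take the consumer's loop and keep the producer's body.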
+func @scalar_indexed_generic_fusion
+ (%arg0: tensor<5x1x1xf32>, %arg1 : tensor<i32>) -> tensor<10xf32>
+{
+ %c0 = constant 0 : index
+ %cst = constant dense<1.000000e+00> : tensor<10xf32>
+ %0 = linalg.indexed_generic
+ {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>],
+ iterator_types = []}
+ ins(%arg1 : tensor<i32>) {
+ ^bb0(%arg2: i32): // no predecessors
+ %3 = index_cast %arg2 : i32 to index
+ %4 = extract_element %arg0[%3, %c0, %c0] : tensor<5x1x1xf32>
+ linalg.yield %4 : f32
+ } -> tensor<f32>
+ %1 = linalg.generic
+ {indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]}
+ ins(%0, %cst : tensor<f32>, tensor<10xf32>) {
+ ^bb0(%arg2: f32, %arg3: f32): // no predecessors
+ %3 = mulf %arg2, %arg3 : f32
+ linalg.yield %3 : f32
+ } -> tensor<10xf32>
+ return %1 : tensor<10xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> ()>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0)>
+// CHECK: func @scalar_indexed_generic_fusion
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<5x1x1xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<i32>
+// CHECK: %[[T0:.+]] = linalg.indexed_generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel"]
+// CHECK-SAME: ins(%[[ARG1]] : tensor<i32>)
+// CHECK: extract_element %[[ARG0]]
+// CHECK: linalg.yield
+// CHECK: return %[[T0]]