//===----------------------------------------------------------------------===//
#include "PassDetail.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
namespace {
-/// Implementation of fusion of generic ops.
+/// Implementation of fusion of generic ops and indexed_generic ops.
struct FuseGenericOpsOnTensors {
- static bool isFusible(GenericOp producer, GenericOp consumer,
+ static bool isFusible(LinalgOp producer, LinalgOp consumer,
unsigned consumerIdx) {
// Verify that
// - the producer has all "parallel" iterator type.
return producerResultIndexMap.isPermutation();
}
- static Operation *fuse(GenericOp producer, GenericOp consumer,
+ static Operation *fuse(LinalgOp producer, LinalgOp consumer,
unsigned consumerIdx, PatternRewriter &rewriter,
OperationFolder *folder = nullptr) {
if (!isFusible(producer, consumer, consumerIdx))
// indexing_map of the operand at consumerIdx in the consumer.
SmallVector<Attribute, 4> fusedIndexMaps;
auto consumerIndexMaps = consumer.indexing_maps();
- fusedIndexMaps.reserve(fusedOperands.size() + consumer.getNumResults());
+ fusedIndexMaps.reserve(fusedOperands.size() +
+ consumer.getOperation()->getNumResults());
fusedIndexMaps.assign(consumerIndexMaps.begin(),
std::next(consumerIndexMaps.begin(), consumerIdx));
// Compute indexing maps for the producer args in the fused operation.
consumerIndexMaps.end());
// Generate the fused op.
- auto fusedOp = rewriter.create<GenericOp>(
- rewriter.getUnknownLoc(), consumer.getResultTypes(), fusedOperands,
- rewriter.getI64IntegerAttr(fusedOperands.size()),
- rewriter.getI64IntegerAttr(consumer.getNumResults()),
- rewriter.getArrayAttr(fusedIndexMaps), consumer.iterator_types(),
- /*doc=*/nullptr,
- /*library_call=*/nullptr);
- generateFusedRegion(rewriter, fusedOp.region(), producer.region(),
- consumer.region(), consumerIdx);
+ LinalgOp fusedOp;
+ if (isa<GenericOp>(producer.getOperation()) &&
+ isa<GenericOp>(consumer.getOperation())) {
+ fusedOp =
+ rewriter
+ .create<GenericOp>(
+ rewriter.getUnknownLoc(),
+ consumer.getOperation()->getResultTypes(), fusedOperands,
+ rewriter.getI64IntegerAttr(fusedOperands.size()),
+ rewriter.getI64IntegerAttr(
+ consumer.getOperation()->getNumResults()),
+ rewriter.getArrayAttr(fusedIndexMaps),
+ consumer.iterator_types(),
+ /*doc=*/nullptr,
+ /*library_call=*/nullptr)
+ .getOperation();
+ } else {
+ fusedOp =
+ rewriter
+ .create<IndexedGenericOp>(
+ rewriter.getUnknownLoc(),
+ consumer.getOperation()->getResultTypes(), fusedOperands,
+ rewriter.getI64IntegerAttr(fusedOperands.size()),
+ rewriter.getI64IntegerAttr(
+ consumer.getOperation()->getNumResults()),
+ rewriter.getArrayAttr(fusedIndexMaps),
+ consumer.iterator_types(),
+ /*doc=*/nullptr,
+ /*library_call=*/nullptr)
+ .getOperation();
+ }
+
+ // Construct an AffineMap from consumer loops to producer loops.
+ // consumer loop -> tensor index
+ AffineMap consumerResultIndexMap =
+ consumer.getInputIndexingMap(consumerIdx);
+ // producer loop -> tensor index
+ AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
+ // tensor index -> producer loop
+ AffineMap invProducerResultIndexMap =
+ inversePermutation(producerResultIndexMap);
+ assert(invProducerResultIndexMap &&
+ "expected producer result indexig map to be invertible");
+ // consumer loop -> producer loop
+ AffineMap consumerToProducerLoopsMap =
+ invProducerResultIndexMap.compose(consumerResultIndexMap);
+
+ generateFusedRegion(rewriter, fusedOp, producer, consumer,
+ consumerToProducerLoopsMap, consumerIdx,
+ consumer.getNumLoops());
return fusedOp;
}
/// the `producer` to use in the fused operation given the indexing map of the
/// result of the producer in the consumer.
static void computeProducerOperandIndex(
- GenericOp producer, AffineMap fusedConsumerArgIndexMap,
+ LinalgOp producer, AffineMap fusedConsumerArgIndexMap,
SmallVectorImpl<Attribute> &fusedOpIndexingMapAttrs) {
// The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
// from consumer loop -> consumer arg tensor index/producer result tensor
/// Generate the region of the fused operation. The region of the fused op
/// must be empty.
- static void generateFusedRegion(PatternRewriter &rewriter,
- Region &fusedRegion, Region &producerRegion,
- Region &consumerRegion,
- unsigned consumerIdx) {
+ static void generateFusedRegion(PatternRewriter &rewriter, Operation *fusedOp,
+ LinalgOp producer, LinalgOp consumer,
+ AffineMap consumerToProducerLoopsMap,
+ unsigned consumerIdx, unsigned nloops) {
// Build the region of the fused op.
- Block &producerBlock = producerRegion.front();
- Block &consumerBlock = consumerRegion.front();
+ Block &producerBlock = producer.getOperation()->getRegion(0).front();
+ Block &consumerBlock = consumer.getOperation()->getRegion(0).front();
Block *fusedBlock = new Block();
- fusedRegion.push_back(fusedBlock);
+ fusedOp->getRegion(0).push_back(fusedBlock);
BlockAndValueMapping mapper;
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(fusedBlock);
+
+ // The block arguments are
+ // [index_0, index_1, ... ,
+ // consumer_operand_0, ... , consumer_operand_(`consumerIdx`-1),
+    //    producer_operand_0, ... , producer_operand_(n-1),
+    //    consumer_operand_(`consumerIdx`), ... , consumer_operand_(m-1)],
+    // where n is the number of the producer's operands and m is the number of
+    // the consumer's operands.
+    // If both `numProducerIndices` and `numConsumerIndices` are zero, then both
+    // the producer and the consumer are generic ops, and there are no indices
+    // in the block arguments.
+ unsigned numProducerIndices =
+ isa<IndexedGenericOp>(producer.getOperation()) ? nloops : 0;
+ unsigned numConsumerIndices =
+ isa<IndexedGenericOp>(consumer.getOperation()) ? nloops : 0;
+ // Firstly, add all the indices to the block arguments.
+ for (unsigned i = 0, e = std::max(numProducerIndices, numConsumerIndices);
+ i < e; ++i)
+ fusedBlock->addArgument(rewriter.getIndexType());
// Map the arguments for the unmodified args from the consumer.
for (auto consumerArg : llvm::enumerate(consumerBlock.getArguments())) {
- if (consumerArg.index() == consumerIdx) {
+ if (consumerArg.index() == consumerIdx + numConsumerIndices) {
// Map the arguments for the args from the producer.
- for (auto producerArg : producerBlock.getArguments())
- mapper.map(producerArg,
- fusedBlock->addArgument(producerArg.getType()));
+ for (auto producerArg : llvm::enumerate(producerBlock.getArguments())) {
+ // If producer is an indexed_generic op, map the indices from consumer
+ // loop to producer loop (because the fusedOp is built based on
+ // consumer's perspective).
+ if (producerArg.index() < numProducerIndices) {
+ auto newIndex = rewriter.create<mlir::AffineApplyOp>(
+ producer.getLoc(),
+ consumerToProducerLoopsMap.getSubMap(producerArg.index()),
+ fusedBlock->getArguments().take_front(nloops));
+ mapper.map(producerArg.value(), newIndex);
+ } else {
+ mapper.map(producerArg.value(),
+ fusedBlock->addArgument(producerArg.value().getType()));
+ }
+ }
continue;
}
- mapper.map(consumerArg.value(),
- fusedBlock->addArgument(consumerArg.value().getType()));
+
+      // If the consumer is an indexed_generic op, map its indices directly to
+      // the block arguments. Otherwise, add an argument of the same type and
+      // map to it.
+ if (consumerArg.index() < numConsumerIndices) {
+ mapper.map(consumerArg.value(),
+ fusedBlock->getArgument(consumerArg.index()));
+ } else {
+ mapper.map(consumerArg.value(),
+ fusedBlock->addArgument(consumerArg.value().getType()));
+ }
}
// Add operations from producer (except the yield operation) to the fused
// Lookup the value the yield operation is mapped to.
Value yieldVal = yieldOp.getOperand(0);
if (Value clonedVal = mapper.lookupOrNull(yieldVal))
- mapper.map(consumerBlock.getArgument(consumerIdx), clonedVal);
+ mapper.map(
+ consumerBlock.getArgument(consumerIdx + numConsumerIndices),
+ clonedVal);
continue;
}
rewriter.clone(op, mapper);
if (!producer || producer->getNumResults() != 1)
return nullptr;
- // Fuse when consumer is GenericOp.
- if (GenericOp genericOp = dyn_cast<GenericOp>(consumer)) {
- if (!genericOp.hasTensorSemantics())
+ // Fuse when consumer is GenericOp or IndexedGenericOp.
+ if (isa<GenericOp>(consumer) || isa<IndexedGenericOp>(consumer)) {
+ auto linalgOpConsumer = cast<LinalgOp>(consumer);
+ if (!linalgOpConsumer.hasTensorSemantics())
return nullptr;
- if (auto genericOpProducer = dyn_cast<GenericOp>(producer)) {
- if (genericOpProducer.hasTensorSemantics())
- return FuseGenericOpsOnTensors::fuse(genericOpProducer, genericOp,
+ if (isa<GenericOp>(producer) || isa<IndexedGenericOp>(producer)) {
+ auto linalgOpProducer = cast<LinalgOp>(producer);
+ if (linalgOpProducer.hasTensorSemantics())
+ return FuseGenericOpsOnTensors::fuse(linalgOpProducer, linalgOpConsumer,
consumerIdx, rewriter, folder);
} else if (auto reshapeOpProducer = dyn_cast<TensorReshapeOp>(producer)) {
- return FuseTensorReshapeOpAsProducer<GenericOp>::fuse(
- reshapeOpProducer, genericOp, consumerIdx, rewriter, folder);
+ if (auto genericOpConsumer = dyn_cast<GenericOp>(consumer)) {
+ return FuseTensorReshapeOpAsProducer<GenericOp>::fuse(
+ reshapeOpProducer, genericOpConsumer, consumerIdx, rewriter,
+ folder);
+ }
} else if (auto constantOpProducer = dyn_cast<ConstantOp>(producer)) {
- return FuseConstantOpAsProducer<GenericOp>::fuse(
- constantOpProducer, genericOp, consumerIdx, rewriter, folder);
+ if (auto genericOpConsumer = dyn_cast<GenericOp>(consumer)) {
+ return FuseConstantOpAsProducer<GenericOp>::fuse(
+ constantOpProducer, genericOpConsumer, consumerIdx, rewriter,
+ folder);
+ }
}
return nullptr;
}
}
return nullptr;
}
+
return nullptr;
}
void mlir::populateLinalgTensorOpsFusionPatterns(
MLIRContext *context, OwningRewritePatternList &patterns) {
- patterns.insert<FuseTensorOps<GenericOp>, FuseTensorOps<TensorReshapeOp>>(
- context);
+ patterns.insert<FuseTensorOps<GenericOp>, FuseTensorOps<IndexedGenericOp>,
+ FuseTensorOps<TensorReshapeOp>>(context);
}
std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgFusionPass() {
// CHECK-SAME: args_out = 1 : i64
// CHECK: ^{{.*}}(%[[ARG1:.*]]: f32)
// CHECK: mulf %[[CST]], %[[ARG1]]
+
+// -----
+
+// Fusion of a generic producer into an indexed_generic consumer: the fused op
+// must be an indexed_generic op, with the producer's body (the addi) inlined
+// ahead of the consumer's index-dependent computation.
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func @generic_op_indexed_generic_op_fusion(%arg0: tensor<?x?xi32>,
+                                           %arg1: tensor<?x?xi32>) {
+  %0 = linalg.generic {
+    args_in = 2 : i64,
+    args_out = 1 : i64,
+    indexing_maps = [#map0, #map0, #map0],
+    iterator_types = ["parallel", "parallel"] } %arg0, %arg1 {
+  ^bb0(%arg2: i32, %arg3: i32): // no predecessors
+    %10 = addi %arg2, %arg3 : i32
+    linalg.yield %10 : i32
+  } : tensor<?x?xi32>, tensor<?x?xi32> -> tensor<?x?xi32>
+  %1 = linalg.indexed_generic {
+    args_in = 1 : i64,
+    args_out = 1 : i64,
+    indexing_maps = [#map0, #map0],
+    iterator_types = ["parallel", "parallel"] } %0 {
+  ^bb0(%arg2: index, %arg3: index, %arg4: i32): // no predecessors
+    %2 = index_cast %arg2 : index to i32
+    %3 = index_cast %arg3 : index to i32
+    %4 = addi %arg4, %2 : i32
+    %5 = subi %4, %3 : i32
+    linalg.yield %5 : i32
+  }: tensor<?x?xi32> -> tensor<?x?xi32>
+  return
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @generic_op_indexed_generic_op_fusion
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.indexed_generic
+// CHECK-SAME: args_in = 2
+// CHECK-SAME: args_out = 1
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP0]]]
+// CHECK: ^{{[a-zA-Z0-9_]*}}
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: index
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: i32
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]*]]: i32
+// CHECK: %[[VAL1:.+]] = addi %[[ARG2]], %[[ARG3]] : i32
+// CHECK: %[[ADD_OPERAND:.+]] = index_cast %[[ARG0]] : index to i32
+// CHECK: %[[SUB_OPERAND:.+]] = index_cast %[[ARG1]] : index to i32
+// CHECK: %[[VAL2:.+]] = addi %[[VAL1]], %[[ADD_OPERAND]] : i32
+// CHECK: %[[VAL3:.+]] = subi %[[VAL2]], %[[SUB_OPERAND]] : i32
+// CHECK: linalg.yield %[[VAL3]] : i32
+
+// -----
+
+// Fusion of an indexed_generic producer into a generic consumer: the fused op
+// must be an indexed_generic op, with the producer's index-dependent body
+// inlined ahead of the consumer's computation.
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func @indexed_generic_op_generic_op_fusion(%arg0: tensor<?x?xi32>,
+                                           %arg1: tensor<?x?xi32>) {
+  %0 = linalg.indexed_generic {
+    args_in = 1 : i64,
+    args_out = 1 : i64,
+    indexing_maps = [#map0, #map0],
+    iterator_types = ["parallel", "parallel"] } %arg0 {
+  ^bb0(%arg2: index, %arg3: index, %arg4: i32): // no predecessors
+    %2 = index_cast %arg2 : index to i32
+    %3 = index_cast %arg3 : index to i32
+    %4 = addi %arg4, %2 : i32
+    %5 = subi %4, %3 : i32
+    linalg.yield %5 : i32
+  }: tensor<?x?xi32> -> tensor<?x?xi32>
+  %1 = linalg.generic {
+    args_in = 2 : i64,
+    args_out = 1 : i64,
+    indexing_maps = [#map0, #map0, #map0],
+    iterator_types = ["parallel", "parallel"] } %0, %arg1 {
+  ^bb0(%arg2: i32, %arg3: i32): // no predecessors
+    %10 = addi %arg2, %arg3 : i32
+    linalg.yield %10 : i32
+  } : tensor<?x?xi32>, tensor<?x?xi32> -> tensor<?x?xi32>
+  return
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @indexed_generic_op_generic_op_fusion
+// CHECK: linalg.indexed_generic
+// CHECK-SAME: args_in = 2
+// CHECK-SAME: args_out = 1
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP0]]]
+// CHECK: ^{{[a-zA-Z0-9_]*}}
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: index
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: i32
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]*]]: i32
+// CHECK: %[[ADD_OPERAND:.+]] = index_cast %[[ARG0]] : index to i32
+// CHECK: %[[SUB_OPERAND:.+]] = index_cast %[[ARG1]] : index to i32
+// CHECK: %[[VAL1:.+]] = addi %[[ARG2]], %[[ADD_OPERAND]] : i32
+// CHECK: %[[VAL2:.+]] = subi %[[VAL1]], %[[SUB_OPERAND]] : i32
+// CHECK: %[[VAL3:.+]] = addi %[[VAL2]], %[[ARG3]] : i32
+// CHECK: linalg.yield %[[VAL3]] : i32
+// CHECK-NOT: linalg.generic
+
+// -----
+
+// Fusion of two indexed_generic ops. The producer's indexing map #map0
+// transposes (d0, d1), so the producer's indices must be remapped from
+// consumer loops to producer loops: the indices of the first indexed_generic
+// op are swapped after fusion.
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1) -> (d0, d1)>
+func @indexed_generic_op_fusion(%arg0: tensor<?x?xi32>) {
+  %0 = linalg.indexed_generic {
+    args_in = 1 : i64,
+    args_out = 1 : i64,
+    indexing_maps = [#map0, #map0],
+    iterator_types = ["parallel", "parallel"] } %arg0 {
+  ^bb0(%arg2: index, %arg3: index, %arg4: i32): // no predecessors
+    %2 = index_cast %arg2 : index to i32
+    %3 = index_cast %arg3 : index to i32
+    %4 = addi %arg4, %2 : i32
+    %5 = subi %4, %3 : i32
+    linalg.yield %5 : i32
+  }: tensor<?x?xi32> -> tensor<?x?xi32>
+  %1 = linalg.indexed_generic {
+    args_in = 1 : i64,
+    args_out = 1 : i64,
+    indexing_maps = [#map1, #map1],
+    iterator_types = ["parallel", "parallel"] } %0 {
+  ^bb0(%arg2: index, %arg3: index, %arg4: i32): // no predecessors
+    %2 = index_cast %arg2 : index to i32
+    %3 = index_cast %arg3 : index to i32
+    %4 = addi %arg4, %2 : i32
+    %5 = subi %4, %3 : i32
+    linalg.yield %5 : i32
+  }: tensor<?x?xi32> -> tensor<?x?xi32>
+  return
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @indexed_generic_op_fusion
+// CHECK: linalg.indexed_generic
+// CHECK-SAME: args_in = 1
+// CHECK-SAME: args_out = 1
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]]]
+// CHECK: ^{{[a-zA-Z0-9_]*}}
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: index
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: i32
+// CHECK: %[[ADD_OPERAND1:.+]] = index_cast %[[ARG1]] : index to i32
+// CHECK: %[[SUB_OPERAND1:.+]] = index_cast %[[ARG0]] : index to i32
+// CHECK: %[[VAL1:.+]] = addi %[[ARG2]], %[[ADD_OPERAND1]] : i32
+// CHECK: %[[VAL2:.+]] = subi %[[VAL1]], %[[SUB_OPERAND1]] : i32
+// CHECK: %[[ADD_OPERAND2:.+]] = index_cast %[[ARG0]] : index to i32
+// CHECK: %[[SUB_OPERAND2:.+]] = index_cast %[[ARG1]] : index to i32
+// CHECK: %[[VAL3:.+]] = addi %[[VAL2]], %[[ADD_OPERAND2]] : i32
+// CHECK: %[[VAL4:.+]] = subi %[[VAL3]], %[[SUB_OPERAND2]] : i32
+// CHECK: linalg.yield %[[VAL4]] : i32
+// CHECK-NOT: linalg.indexed_generic