struct FuseGenericOpsOnTensors {
static bool isFusible(LinalgOp producer, LinalgOp consumer,
unsigned consumerIdx) {
+ // Producer and consumer must have tensor semantics.
+ if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
+ return false;
+
// Verify that
    // - the producer has all "parallel" iterator types,
    if (producer.getNumParallelLoops() != producer.getNumLoops())
      return false;

    // - the number of results of the consumer's indexing map for the fused
    //   operand matches the number of loops of the producer,
    AffineMap consumerIndexMap = consumer.getIndexingMap(consumerIdx);
    if (consumerIndexMap.getNumResults() != producer.getNumLoops())
      return false;

    // - the indexing map of the producer result is invertible; for now just
    //   verify it is a permutation.
    AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
    return producerResultIndexMap.isPermutation();
}
- static Operation *fuse(LinalgOp producer, LinalgOp consumer,
- unsigned consumerIdx, PatternRewriter &rewriter,
- OperationFolder *folder = nullptr) {
+ static LinalgOp fuse(LinalgOp producer, LinalgOp consumer,
+ unsigned consumerIdx, PatternRewriter &rewriter,
+ OperationFolder *folder = nullptr) {
if (!isFusible(producer, consumer, consumerIdx))
return nullptr;
    ...
  }
};

/// Checks whether `reshapeOp` can be fused with its consumer (if `asProducer`
/// is true) or its producer (if `asProducer` is false), given the indexing map
/// at its use.
static bool isTensorReshapeOpFusible(TensorReshapeOp reshapeOp,
                                     AffineMap useIndexMap, bool asProducer) {
  ...
  return useIndexMap.isIdentity();
}

+/// Based on the type of `op`, create a linalg op of the same type, i.e. if `op`
+/// is a linalg.generic operation, then create a `linalg.generic` operation with
+/// the given `args`. Expects `op` to be `linalg.generic` or
+/// `linalg.indexed_generic`.
+template <typename... Args>
+static LinalgOp createLinalgOpOfSameType(LinalgOp op, PatternRewriter &rewriter,
+ Args... args) {
+ if (isa<GenericOp>(op.getOperation()))
+ return cast<LinalgOp>(rewriter.create<GenericOp>(args...).getOperation());
+ if (isa<IndexedGenericOp>(op.getOperation()))
+ return cast<LinalgOp>(
+ rewriter.create<IndexedGenericOp>(args...).getOperation());
+ llvm_unreachable(
+ "expected only linalg.generic or linalg.indexed_generic ops");
+ return nullptr;
+}
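
For reference, a minimal usage sketch of this helper (illustrative only, not part of the patch; `loc`, `resultTypes`, and the other argument names are placeholders): the variadic `args` are forwarded verbatim to the GenericOp or IndexedGenericOp builder, so a call mirrors a plain rewriter.create<GenericOp>(...) with the op to mimic and the rewriter prepended.

  // Sketch only; all values below are assumed to be computed by the caller.
  LinalgOp fused = createLinalgOpOfSameType(
      consumerLinalgOp, rewriter, loc, resultTypes, fusedOperands,
      rewriter.getI64IntegerAttr(numInputs),
      rewriter.getI64IntegerAttr(numOutputs),
      rewriter.getArrayAttr(indexingMapAttrs), iteratorTypesAttr,
      /*doc=*/nullptr, /*library_call=*/nullptr, /*symbol_source=*/nullptr);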
+
namespace {
+
/// Implementation of fusion on tensor ops when producer is a TensorReshapeOp.
-template <typename LinalgOpTy> struct FuseTensorReshapeOpAsProducer {
- static bool isFusible(TensorReshapeOp producer, LinalgOpTy consumer,
+struct FuseTensorReshapeOpAsProducer {
+ static bool isFusible(TensorReshapeOp producer, LinalgOp consumer,
unsigned consumerIdx) {
- return isTensorReshapeOpFusible(
- producer, consumer.getInputIndexingMap(consumerIdx), true);
+ return isa<GenericOp, IndexedGenericOp>(consumer.getOperation()) &&
+ consumer.hasTensorSemantics() &&
+ isTensorReshapeOpFusible(producer,
+ consumer.getInputIndexingMap(consumerIdx),
+ /*asProducer=*/true);
}
- static Operation *fuse(TensorReshapeOp producer, LinalgOpTy consumer,
- unsigned consumerIdx, PatternRewriter &rewriter,
- OperationFolder *folder = nullptr) {
+ static LinalgOp fuse(TensorReshapeOp producer, LinalgOp consumer,
+ unsigned consumerIdx, PatternRewriter &rewriter,
+ OperationFolder *folder = nullptr) {
if (!isFusible(producer, consumer, consumerIdx))
return nullptr;
// Compute the fused operands list,
- SmallVector<Value, 2> fusedOperands(consumer.operand_begin(),
- consumer.operand_end());
+ Operation *consumerOp = consumer.getOperation();
+ SmallVector<Value, 2> fusedOperands(consumerOp->getOperands());
fusedOperands[consumerIdx] = producer.src();
// Compute indexing_maps for the fused operation. The indexing_maps for the
llvm::map_range(fusedIndexMaps, [](AffineMap map) -> Attribute {
return AffineMapAttr::get(map);
}));
- auto fusedOp = rewriter.create<LinalgOpTy>(
- rewriter.getUnknownLoc(), consumer.getResultTypes(), fusedOperands,
+ LinalgOp fusedOp = createLinalgOpOfSameType(
+ consumer, rewriter, rewriter.getUnknownLoc(),
+ consumerOp->getResultTypes(), fusedOperands,
rewriter.getI64IntegerAttr(fusedOperands.size()),
- rewriter.getI64IntegerAttr(consumer.getNumResults()),
+ rewriter.getI64IntegerAttr(consumerOp->getNumResults()),
rewriter.getArrayAttr(indexMapAttrs), consumer.iterator_types(),
/*doc=*/nullptr,
/*library_call=*/nullptr,
/*symbol_source=*/nullptr);
- auto &fusedRegion = fusedOp.region();
- rewriter.cloneRegionBefore(consumer.region(), fusedRegion,
+ auto &fusedRegion = fusedOp.getOperation()->getRegion(0);
+ rewriter.cloneRegionBefore(consumerOp->getRegion(0), fusedRegion,
fusedRegion.begin());
return fusedOp;
}
};
/// Implementation of fusion on tensor ops when consumer is a TensorReshapeOp.
-template <typename LinalgOpTy> struct FuseTensorReshapeOpAsConsumer {
- static bool isFusible(LinalgOpTy producer, TensorReshapeOp consumer,
+struct FuseTensorReshapeOpAsConsumer {
+ static bool isFusible(LinalgOp producer, TensorReshapeOp consumer,
unsigned consumerIdx) {
- return isTensorReshapeOpFusible(consumer, producer.getOutputIndexingMap(0),
- false);
+ return isa<GenericOp, IndexedGenericOp>(producer.getOperation()) &&
+ producer.hasTensorSemantics() &&
+ isTensorReshapeOpFusible(consumer, producer.getOutputIndexingMap(0),
+ /*asProducer=*/false);
}
- static Operation *fuse(LinalgOpTy producer, TensorReshapeOp consumer,
- unsigned consumerIdx, PatternRewriter &rewriter,
- OperationFolder *folder = nullptr) {
+ static LinalgOp fuse(LinalgOp producer, TensorReshapeOp consumer,
+ unsigned consumerIdx, PatternRewriter &rewriter,
+ OperationFolder *folder = nullptr) {
if (!isFusible(producer, consumer, consumerIdx))
return nullptr;
return AffineMapAttr::get(map);
}));
- auto fusedOp = rewriter.create<LinalgOpTy>(
- rewriter.getUnknownLoc(), consumer.getResultType(),
- producer.getOperands(),
- rewriter.getI64IntegerAttr(producer.getNumOperands()),
+ Operation *producerOp = producer.getOperation();
+ LinalgOp fusedOp = createLinalgOpOfSameType(
+ producer, rewriter, rewriter.getUnknownLoc(), consumer.getResultType(),
+ producerOp->getOperands(),
+ rewriter.getI64IntegerAttr(producerOp->getNumOperands()),
rewriter.getI64IntegerAttr(1), rewriter.getArrayAttr(indexMapAttrs),
producer.iterator_types(),
/*doc=*/nullptr,
/*library_call=*/nullptr,
/*symbol_source=*/nullptr);
- auto &fusedRegion = fusedOp.region();
- rewriter.cloneRegionBefore(producer.region(), fusedRegion,
+ auto &fusedRegion = fusedOp.getOperation()->getRegion(0);
+ rewriter.cloneRegionBefore(producerOp->getRegion(0), fusedRegion,
fusedRegion.begin());
return fusedOp;
}
};
/// Implementation of fusion on tensor ops when producer is a splat constant.
-template <typename LinalgOpTy> struct FuseConstantOpAsProducer {
- static bool isFusible(ConstantOp producer, LinalgOpTy consumer,
+struct FuseConstantOpAsProducer {
+ static bool isFusible(ConstantOp producer, LinalgOp consumer,
unsigned consumerIdx) {
- return producer.getResult().getType().isa<RankedTensorType>() &&
- producer.value().template cast<DenseElementsAttr>().isSplat();
+ return isa<GenericOp, IndexedGenericOp>(consumer.getOperation()) &&
+ consumer.hasTensorSemantics() &&
+ producer.getResult().getType().isa<RankedTensorType>() &&
+ producer.value().cast<DenseElementsAttr>().isSplat();
}
- static Operation *fuse(ConstantOp producer, LinalgOpTy consumer,
- unsigned consumerIdx, PatternRewriter &rewriter,
- OperationFolder *folder = nullptr) {
+ static LinalgOp fuse(ConstantOp producer, LinalgOp consumer,
+ unsigned consumerIdx, PatternRewriter &rewriter,
+ OperationFolder *folder = nullptr) {
if (!isFusible(producer, consumer, consumerIdx))
return nullptr;
    // The operands list is the same as the consumer's, with the operand at the
    // constant's index (consumerIdx) dropped.
- SmallVector<Value, 4> fusedOperands(consumer.operand_begin(),
- consumer.operand_end());
+ Operation *consumerOp = consumer.getOperation();
+ SmallVector<Value, 4> fusedOperands(consumerOp->getOperands());
fusedOperands.erase(std::next(fusedOperands.begin(), consumerIdx));
// Create a constant scalar value from the splat constant.
Value scalarConstant = rewriter.create<ConstantOp>(
producer.getLoc(),
- producer.value().template cast<DenseElementsAttr>().getSplatValue());
+ producer.value().cast<DenseElementsAttr>().getSplatValue());
- auto fusedOp = rewriter.create<LinalgOpTy>(
- rewriter.getUnknownLoc(), consumer.getResultTypes(), fusedOperands,
- rewriter.getI64IntegerAttr(consumer.getNumOperands() - 1),
- rewriter.getI64IntegerAttr(consumer.getNumResults()),
+ LinalgOp fusedOp = createLinalgOpOfSameType(
+ consumer, rewriter, rewriter.getUnknownLoc(),
+ consumerOp->getResultTypes(), fusedOperands,
+ rewriter.getI64IntegerAttr(consumerOp->getNumOperands() - 1),
+ rewriter.getI64IntegerAttr(consumerOp->getNumResults()),
rewriter.getAffineMapArrayAttr(fusedIndexMaps),
consumer.iterator_types(),
/*doc=*/nullptr,
    // Map the block argument corresponding to the replaced argument to the
    // scalar constant.
- Region &consumerRegion = consumer.region();
+ Region &consumerRegion = consumerOp->getRegion(0);
Block &entryBlock = *consumerRegion.begin();
- unsigned argIndex =
- entryBlock.getNumArguments() - consumer.getNumOperands() + consumerIdx;
+ unsigned argIndex = entryBlock.getNumArguments() -
+ consumerOp->getNumOperands() + consumerIdx;
BlockAndValueMapping mapping;
mapping.map(entryBlock.getArgument(argIndex), scalarConstant);
- Region &fusedRegion = fusedOp.region();
+ Region &fusedRegion = fusedOp.getOperation()->getRegion(0);
rewriter.cloneRegionBefore(consumerRegion, fusedRegion, fusedRegion.begin(),
mapping);
return fusedOp;
}
};
-
} // namespace
Operation *mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
                                       Operation *consumer,
                                       unsigned consumerIdx,
                                       OperationFolder *folder) {
// Fuse when consumer is GenericOp or IndexedGenericOp.
if (isa<GenericOp, IndexedGenericOp>(consumer)) {
- auto linalgOpConsumer = cast<LinalgOp>(consumer);
- if (!linalgOpConsumer.hasTensorSemantics())
- return nullptr;
- if (isa<GenericOp, IndexedGenericOp>(producer)) {
- auto linalgOpProducer = cast<LinalgOp>(producer);
- if (linalgOpProducer.hasTensorSemantics())
- return FuseGenericOpsOnTensors::fuse(linalgOpProducer, linalgOpConsumer,
- consumerIdx, rewriter, folder);
- } else if (auto reshapeOpProducer = dyn_cast<TensorReshapeOp>(producer)) {
- if (auto genericOpConsumer = dyn_cast<GenericOp>(consumer)) {
- return FuseTensorReshapeOpAsProducer<GenericOp>::fuse(
- reshapeOpProducer, genericOpConsumer, consumerIdx, rewriter,
- folder);
- } else if (auto indexedGenericOpConsumer =
- dyn_cast<IndexedGenericOp>(consumer)) {
- return FuseTensorReshapeOpAsProducer<IndexedGenericOp>::fuse(
- reshapeOpProducer, indexedGenericOpConsumer, consumerIdx, rewriter,
- folder);
- }
- } else if (auto constantOpProducer = dyn_cast<ConstantOp>(producer)) {
- if (auto genericOpConsumer = dyn_cast<GenericOp>(consumer)) {
- return FuseConstantOpAsProducer<GenericOp>::fuse(
- constantOpProducer, genericOpConsumer, consumerIdx, rewriter,
- folder);
- }
- }
+ if (isa<GenericOp, IndexedGenericOp>(producer))
+ return FuseGenericOpsOnTensors::fuse(cast<LinalgOp>(producer),
+ cast<LinalgOp>(consumer),
+ consumerIdx, rewriter, folder);
+ if (auto reshapeOpProducer = dyn_cast<TensorReshapeOp>(producer))
+ return FuseTensorReshapeOpAsProducer::fuse(reshapeOpProducer,
+ cast<LinalgOp>(consumer),
+ consumerIdx, rewriter, folder);
+ if (auto constantOpProducer = dyn_cast<ConstantOp>(producer))
+ return FuseConstantOpAsProducer::fuse(constantOpProducer,
+ cast<LinalgOp>(consumer),
+ consumerIdx, rewriter, folder);
return nullptr;
}
- // Fuse when consumer is a TensorReshapeOp.
- if (TensorReshapeOp reshapeOp = dyn_cast<TensorReshapeOp>(consumer)) {
- if (auto genericOpProducer = dyn_cast<GenericOp>(producer)) {
- if (genericOpProducer.hasTensorSemantics())
- return FuseTensorReshapeOpAsConsumer<GenericOp>::fuse(
- genericOpProducer, reshapeOp, consumerIdx, rewriter, folder);
- } else if (auto indexedGenericOpProducer =
- dyn_cast<IndexedGenericOp>(producer)) {
- if (indexedGenericOpProducer.hasTensorSemantics())
- return FuseTensorReshapeOpAsConsumer<IndexedGenericOp>::fuse(
- indexedGenericOpProducer, reshapeOp, consumerIdx, rewriter, folder);
+ if (isa<GenericOp, IndexedGenericOp>(producer)) {
+ // Fuse when consumer is a TensorReshapeOp.
+ if (TensorReshapeOp reshapeOp = dyn_cast<TensorReshapeOp>(consumer)) {
+ return FuseTensorReshapeOpAsConsumer::fuse(
+ cast<LinalgOp>(producer), reshapeOp, consumerIdx, rewriter, folder);
}
- return nullptr;
}
return nullptr;
// -----
+#map0 = affine_map<(d0, d1, d2) -> (d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func @indexed_generic_op_constant_fusion(%arg0 : tensor<5x?x?xf32>)
+ -> tensor<5x?x?xf32>
+{
+ %0 = constant dense<42.0> : tensor<5xf32>
+ %1 = linalg.indexed_generic
+ {args_in = 2 : i64, args_out = 1 : i64,
+ indexing_maps = [#map0, #map1, #map1],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ %0, %arg0 {
+ ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: f32, %arg5 : f32):
+ %2 = mulf %arg4, %arg5 : f32
+ linalg.yield %2 : f32
+ }: tensor<5xf32>, tensor<5x?x?xf32> -> tensor<5x?x?xf32>
+ return %1 : tensor<5x?x?xf32>
+}
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @indexed_generic_op_constant_fusion
+// CHECK: %[[CST:.*]] = constant {{.*}} : f32
+// CHECK: linalg.indexed_generic
+// CHECK-SAME: args_in = 1 : i64
+// CHECK-SAME: args_out = 1 : i64
+// CHECK: ^{{[a-zA-Z0-9_]*}}
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]*]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]*]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]*]]: index
+// CHECK-SAME: %[[ARG4:.*]]: f32)
+// CHECK: mulf %[[CST]], %[[ARG4]]
+
+// -----
+
#map0 = affine_map<(d0, d1, d2) -> ()>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
func @generic_op_zero_dim_constant_fusion(%arg0 : tensor<5x?x?xf32>)
// -----
+#map0 = affine_map<(d0, d1, d2) -> ()>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func @indexed_generic_op_zero_dim_constant_fusion
+ (%arg0 : tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
+{
+ %0 = constant dense<42.0> : tensor<f32>
+ %1 = linalg.indexed_generic
+ {args_in = 2 : i64, args_out = 1 : i64,
+ indexing_maps = [#map0, #map1, #map1],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ %0, %arg0 {
+ ^bb0(%arg1 : index, %arg2 : index, %arg3 : index, %arg4: f32, %arg5: f32):
+ %2 = mulf %arg4, %arg5 : f32
+ linalg.yield %2 : f32
+ }: tensor<f32>, tensor<5x?x?xf32> -> tensor<5x?x?xf32>
+ return %1 : tensor<5x?x?xf32>
+}
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @indexed_generic_op_zero_dim_constant_fusion
+// CHECK: %[[CST:.*]] = constant {{.*}} : f32
+// CHECK: linalg.indexed_generic
+// CHECK-SAME: args_in = 1 : i64
+// CHECK-SAME: args_out = 1 : i64
+// CHECK: ^{{[a-zA-Z0-9_]*}}
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]*]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]*]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]*]]: index
+// CHECK-SAME: %[[ARG4:.*]]: f32)
+// CHECK: mulf %[[CST]], %[[ARG4]]
+
+// -----
+
#map0 = affine_map<(d0, d1) -> (d0, d1)>
func @generic_op_indexed_generic_op_fusion(%arg0: tensor<?x?xi32>,
%arg1: tensor<?x?xi32>) {