[mlir][linalg] Allow some fusion on mixed generics
authorIvan Butygin <ivan.butygin@gmail.com>
Thu, 24 Nov 2022 02:45:04 +0000 (03:45 +0100)
committerIvan Butygin <ivan.butygin@gmail.com>
Tue, 29 Nov 2022 14:35:02 +0000 (15:35 +0100)
Relax linalg elementwise fusion check to allow mixed consumers. Producer is still required to be fully tensor to avoid potential memref aliasing.

Differential Revision: https://reviews.llvm.org/D138759

mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/lib/Interfaces/DestinationStyleOpInterface.cpp
mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir

index d0efba5..a127bd2 100644 (file)
@@ -55,7 +55,8 @@ public:
 
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
-    if (genericOp.hasBufferSemantics())
+    // Mixed and buffer sematics aren't supported.
+    if (!genericOp.hasTensorSemantics())
       return failure();
 
     // Only support ops generating one output for now.
index 26636bc..f391b2c 100644 (file)
@@ -79,8 +79,11 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
   if (!producer || !consumer)
     return false;
 
-  // Producer and consumer must have tensor semantics.
-  if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
+  // Consumer can have mixed semantics, just check operand itself has tensor
+  // type. Producer must have full tensor semantics to avoid potential
+  // aliasing between producer and consumer memrefs.
+  if (!producer.hasTensorSemantics() ||
+      !fusedOperand->get().getType().isa<RankedTensorType>())
     return false;
 
   // Verify that
@@ -348,7 +351,9 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
   for (OpOperand *opOperand : consumer.getDpsInitOperands()) {
     fusedOutputOperands.push_back(opOperand->get());
     fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
-    fusedResultTypes.push_back(opOperand->get().getType());
+    Type resultType = opOperand->get().getType();
+    if (!resultType.isa<MemRefType>())
+      fusedResultTypes.push_back(resultType);
   }
 
   // Generate the fused op.
index b334eed..aff6a8f 100644 (file)
@@ -54,15 +54,6 @@ LogicalResult detail::verifyDestinationStyleOpInterface(Operation *op) {
            << ") to be equal to the number of output tensors ("
            << outputTensorOperands.size() << ")";
 
-  // Simplifying assumption: either full tensor or full buffer mode.
-  // This allows simpler verification of output operands vs result types
-  // without premature tracking of which operand is what in mixed-mode.
-  // TODO: relax when mixed-mode needs to pass verification.
-  if (!outputBufferOperands.empty() && !outputTensorOperands.empty())
-    return op->emitOpError(
-        "expected output operands to all have tensor type or "
-        "all have buffer type");
-
   for (OpOperand *opOperand : outputTensorOperands) {
     OpResult result = dstStyleOp.getTiedOpResult(opOperand);
     if (result.getType() != opOperand->get().getType())
index 2950b27..5bcae36 100644 (file)
@@ -1110,3 +1110,43 @@ module {
 //   CHECK-DAG:     %[[T3:.+]] = arith.addf %[[T2]], %[[B1]]
 //       CHECK:     linalg.yield %[[T3]] : f32
 //       CHECK:   return %[[GENERIC]]
+
+// -----
+
+// CHECK-DAG: [[$MAP0:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0, d1)>
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+
+// CHECK-LABEL: @mixed_fusion
+func.func @mixed_fusion(%arg0: tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>, %arg8 : memref<?x?xf32>)
+{
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
+  %3 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
+      ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%2 : tensor<?x?xf32>) {
+    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+      %4 = arith.addf %arg3, %arg4 : f32
+      linalg.yield %4 : f32
+  } -> tensor<?x?xf32>
+  // CHECK: linalg.generic {
+  // CHECK-SAME: indexing_maps = {{\[}}[[$MAP0]], [[$MAP0]], [[$MAP0]], [[$MAP0]]{{\]}}
+  linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
+      ins(%3, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%arg8 : memref<?x?xf32>) {
+    // CHECK: ^{{[a-zA-Z0-9_]*}}
+    // CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]
+    // CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]
+    // CHECK-SAME: [[ARG2:%[a-zA-Z0-9_]*]]
+    ^bb0(%arg5: f32, %arg6: f32, %arg7: f32):
+      // CHECK: [[T1:%[a-zA-Z0-9_]*]] = arith.addf [[ARG0]], [[ARG1]]
+      // CHECK-NOT: linalg.yield
+      // CHECK: arith.mulf [[T1]], [[ARG2]]
+      // CHECK: linalg.yield
+      %5 = arith.mulf %arg5, %arg6 : f32
+      linalg.yield %5 : f32
+    }
+  return
+}