[mlir][Linalg] Enable fusion of std.constant (producer) with
author MaheshRavishankar <ravishankarm@google.com>
Mon, 27 Jul 2020 16:48:36 +0000 (09:48 -0700)
committer MaheshRavishankar <ravishankarm@google.com>
Mon, 27 Jul 2020 16:51:20 +0000 (09:51 -0700)
linalg.indexed_generic (consumer) with tensor arguments.

The implementation of fusing a std.constant producer with a
linalg.indexed_generic consumer was already in place; this change
exposes it. Also clean up some of the patterns that implement the
fusion so that they are no longer templated, thereby avoiding a lot of
conditional checks for calling the right instantiation.

Differential Revision: https://reviews.llvm.org/D84566

mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
mlir/test/Dialect/Linalg/fusion-tensor.mlir

index 82dfa75..9080a20 100644 (file)
@@ -440,6 +440,10 @@ namespace {
 struct FuseGenericOpsOnTensors {
   static bool isFusible(LinalgOp producer, LinalgOp consumer,
                         unsigned consumerIdx) {
+    // Producer and consumer must have tensor semantics.
+    if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
+      return false;
+
     // Verify that
     // - the producer has all "parallel" iterator type.
     if (producer.getNumParallelLoops() != producer.getNumLoops())
@@ -457,9 +461,9 @@ struct FuseGenericOpsOnTensors {
     return producerResultIndexMap.isPermutation();
   }
 
-  static Operation *fuse(LinalgOp producer, LinalgOp consumer,
-                         unsigned consumerIdx, PatternRewriter &rewriter,
-                         OperationFolder *folder = nullptr) {
+  static LinalgOp fuse(LinalgOp producer, LinalgOp consumer,
+                       unsigned consumerIdx, PatternRewriter &rewriter,
+                       OperationFolder *folder = nullptr) {
     if (!isFusible(producer, consumer, consumerIdx))
       return nullptr;
 
@@ -736,24 +740,45 @@ static bool isTensorReshapeOpFusible(TensorReshapeOp reshapeOp,
   return useIndexMap.isIdentity();
 }
 
+/// Based on the type of `op` create a linalg op of the same type, i.e. if `op`
+/// is a linalg.generic operation, then create a `linalg.generic` operation with
+/// the given `args`. Expects `op` to be `linalg.generic` or
+/// `linalg.indexed_generic`.
+template <typename... Args>
+static LinalgOp createLinalgOpOfSameType(LinalgOp op, PatternRewriter &rewriter,
+                                         Args... args) {
+  if (isa<GenericOp>(op.getOperation()))
+    return cast<LinalgOp>(rewriter.create<GenericOp>(args...).getOperation());
+  if (isa<IndexedGenericOp>(op.getOperation()))
+    return cast<LinalgOp>(
+        rewriter.create<IndexedGenericOp>(args...).getOperation());
+  llvm_unreachable(
+      "expected only linalg.generic or linalg.indexed_generic ops");
+  return nullptr;
+}
+
 namespace {
+
 /// Implementation of fusion on tensor ops when producer is a TensorReshapeOp.
-template <typename LinalgOpTy> struct FuseTensorReshapeOpAsProducer {
-  static bool isFusible(TensorReshapeOp producer, LinalgOpTy consumer,
+struct FuseTensorReshapeOpAsProducer {
+  static bool isFusible(TensorReshapeOp producer, LinalgOp consumer,
                         unsigned consumerIdx) {
-    return isTensorReshapeOpFusible(
-        producer, consumer.getInputIndexingMap(consumerIdx), true);
+    return isa<GenericOp, IndexedGenericOp>(consumer.getOperation()) &&
+           consumer.hasTensorSemantics() &&
+           isTensorReshapeOpFusible(producer,
+                                    consumer.getInputIndexingMap(consumerIdx),
+                                    /*asProducer=*/true);
   }
 
-  static Operation *fuse(TensorReshapeOp producer, LinalgOpTy consumer,
-                         unsigned consumerIdx, PatternRewriter &rewriter,
-                         OperationFolder *folder = nullptr) {
+  static LinalgOp fuse(TensorReshapeOp producer, LinalgOp consumer,
+                       unsigned consumerIdx, PatternRewriter &rewriter,
+                       OperationFolder *folder = nullptr) {
     if (!isFusible(producer, consumer, consumerIdx))
       return nullptr;
 
     // Compute the fused operands list,
-    SmallVector<Value, 2> fusedOperands(consumer.operand_begin(),
-                                        consumer.operand_end());
+    Operation *consumerOp = consumer.getOperation();
+    SmallVector<Value, 2> fusedOperands(consumerOp->getOperands());
     fusedOperands[consumerIdx] = producer.src();
 
     // Compute indexing_maps for the fused operation. The indexing_maps for the
@@ -783,32 +808,35 @@ template <typename LinalgOpTy> struct FuseTensorReshapeOpAsProducer {
         llvm::map_range(fusedIndexMaps, [](AffineMap map) -> Attribute {
           return AffineMapAttr::get(map);
         }));
-    auto fusedOp = rewriter.create<LinalgOpTy>(
-        rewriter.getUnknownLoc(), consumer.getResultTypes(), fusedOperands,
+    LinalgOp fusedOp = createLinalgOpOfSameType(
+        consumer, rewriter, rewriter.getUnknownLoc(),
+        consumerOp->getResultTypes(), fusedOperands,
         rewriter.getI64IntegerAttr(fusedOperands.size()),
-        rewriter.getI64IntegerAttr(consumer.getNumResults()),
+        rewriter.getI64IntegerAttr(consumerOp->getNumResults()),
         rewriter.getArrayAttr(indexMapAttrs), consumer.iterator_types(),
         /*doc=*/nullptr,
         /*library_call=*/nullptr,
         /*symbol_source=*/nullptr);
-    auto &fusedRegion = fusedOp.region();
-    rewriter.cloneRegionBefore(consumer.region(), fusedRegion,
+    auto &fusedRegion = fusedOp.getOperation()->getRegion(0);
+    rewriter.cloneRegionBefore(consumerOp->getRegion(0), fusedRegion,
                                fusedRegion.begin());
     return fusedOp;
   }
 };
 
 /// Implementation of fusion on tensor ops when consumer is a TensorReshapeOp.
-template <typename LinalgOpTy> struct FuseTensorReshapeOpAsConsumer {
-  static bool isFusible(LinalgOpTy producer, TensorReshapeOp consumer,
+struct FuseTensorReshapeOpAsConsumer {
+  static bool isFusible(LinalgOp producer, TensorReshapeOp consumer,
                         unsigned consumerIdx) {
-    return isTensorReshapeOpFusible(consumer, producer.getOutputIndexingMap(0),
-                                    false);
+    return isa<GenericOp, IndexedGenericOp>(producer.getOperation()) &&
+           producer.hasTensorSemantics() &&
+           isTensorReshapeOpFusible(consumer, producer.getOutputIndexingMap(0),
+                                    /*asProducer=*/false);
   }
 
-  static Operation *fuse(LinalgOpTy producer, TensorReshapeOp consumer,
-                         unsigned consumerIdx, PatternRewriter &rewriter,
-                         OperationFolder *folder = nullptr) {
+  static LinalgOp fuse(LinalgOp producer, TensorReshapeOp consumer,
+                       unsigned consumerIdx, PatternRewriter &rewriter,
+                       OperationFolder *folder = nullptr) {
     if (!isFusible(producer, consumer, consumerIdx))
       return nullptr;
 
@@ -839,33 +867,36 @@ template <typename LinalgOpTy> struct FuseTensorReshapeOpAsConsumer {
           return AffineMapAttr::get(map);
         }));
 
-    auto fusedOp = rewriter.create<LinalgOpTy>(
-        rewriter.getUnknownLoc(), consumer.getResultType(),
-        producer.getOperands(),
-        rewriter.getI64IntegerAttr(producer.getNumOperands()),
+    Operation *producerOp = producer.getOperation();
+    LinalgOp fusedOp = createLinalgOpOfSameType(
+        producer, rewriter, rewriter.getUnknownLoc(), consumer.getResultType(),
+        producerOp->getOperands(),
+        rewriter.getI64IntegerAttr(producerOp->getNumOperands()),
         rewriter.getI64IntegerAttr(1), rewriter.getArrayAttr(indexMapAttrs),
         producer.iterator_types(),
         /*doc=*/nullptr,
         /*library_call=*/nullptr,
         /*symbol_source=*/nullptr);
-    auto &fusedRegion = fusedOp.region();
-    rewriter.cloneRegionBefore(producer.region(), fusedRegion,
+    auto &fusedRegion = fusedOp.getOperation()->getRegion(0);
+    rewriter.cloneRegionBefore(producerOp->getRegion(0), fusedRegion,
                                fusedRegion.begin());
     return fusedOp;
   }
 };
 
 /// Implementation of fusion on tensor ops when producer is a splat constant.
-template <typename LinalgOpTy> struct FuseConstantOpAsProducer {
-  static bool isFusible(ConstantOp producer, LinalgOpTy consumer,
+struct FuseConstantOpAsProducer {
+  static bool isFusible(ConstantOp producer, LinalgOp consumer,
                         unsigned consumerIdx) {
-    return producer.getResult().getType().isa<RankedTensorType>() &&
-           producer.value().template cast<DenseElementsAttr>().isSplat();
+    return isa<GenericOp, IndexedGenericOp>(consumer.getOperation()) &&
+           consumer.hasTensorSemantics() &&
+           producer.getResult().getType().isa<RankedTensorType>() &&
+           producer.value().cast<DenseElementsAttr>().isSplat();
   }
 
-  static Operation *fuse(ConstantOp producer, LinalgOpTy consumer,
-                         unsigned consumerIdx, PatternRewriter &rewriter,
-                         OperationFolder *folder = nullptr) {
+  static LinalgOp fuse(ConstantOp producer, LinalgOp consumer,
+                       unsigned consumerIdx, PatternRewriter &rewriter,
+                       OperationFolder *folder = nullptr) {
     if (!isFusible(producer, consumer, consumerIdx))
       return nullptr;
 
@@ -881,19 +912,20 @@ template <typename LinalgOpTy> struct FuseConstantOpAsProducer {
 
     // The operands list is same as the consumer with the argument for constant
     // index dropped.
-    SmallVector<Value, 4> fusedOperands(consumer.operand_begin(),
-                                        consumer.operand_end());
+    Operation *consumerOp = consumer.getOperation();
+    SmallVector<Value, 4> fusedOperands(consumerOp->getOperands());
     fusedOperands.erase(std::next(fusedOperands.begin(), consumerIdx));
 
     // Create a constant scalar value from the splat constant.
     Value scalarConstant = rewriter.create<ConstantOp>(
         producer.getLoc(),
-        producer.value().template cast<DenseElementsAttr>().getSplatValue());
+        producer.value().cast<DenseElementsAttr>().getSplatValue());
 
-    auto fusedOp = rewriter.create<LinalgOpTy>(
-        rewriter.getUnknownLoc(), consumer.getResultTypes(), fusedOperands,
-        rewriter.getI64IntegerAttr(consumer.getNumOperands() - 1),
-        rewriter.getI64IntegerAttr(consumer.getNumResults()),
+    LinalgOp fusedOp = createLinalgOpOfSameType(
+        consumer, rewriter, rewriter.getUnknownLoc(),
+        consumerOp->getResultTypes(), fusedOperands,
+        rewriter.getI64IntegerAttr(consumerOp->getNumOperands() - 1),
+        rewriter.getI64IntegerAttr(consumerOp->getNumResults()),
         rewriter.getAffineMapArrayAttr(fusedIndexMaps),
         consumer.iterator_types(),
         /*doc=*/nullptr,
@@ -902,19 +934,18 @@ template <typename LinalgOpTy> struct FuseConstantOpAsProducer {
 
     // Map the block argument corresponding to the replaced argument with the
     // scalar constant.
-    Region &consumerRegion = consumer.region();
+    Region &consumerRegion = consumerOp->getRegion(0);
     Block &entryBlock = *consumerRegion.begin();
-    unsigned argIndex =
-        entryBlock.getNumArguments() - consumer.getNumOperands() + consumerIdx;
+    unsigned argIndex = entryBlock.getNumArguments() -
+                        consumerOp->getNumOperands() + consumerIdx;
     BlockAndValueMapping mapping;
     mapping.map(entryBlock.getArgument(argIndex), scalarConstant);
-    Region &fusedRegion = fusedOp.region();
+    Region &fusedRegion = fusedOp.getOperation()->getRegion(0);
     rewriter.cloneRegionBefore(consumerRegion, fusedRegion, fusedRegion.begin(),
                                mapping);
     return fusedOp;
   }
 };
-
 } // namespace
 
 Operation *mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
@@ -929,48 +960,27 @@ Operation *mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
 
   // Fuse when consumer is GenericOp or IndexedGenericOp.
   if (isa<GenericOp, IndexedGenericOp>(consumer)) {
-    auto linalgOpConsumer = cast<LinalgOp>(consumer);
-    if (!linalgOpConsumer.hasTensorSemantics())
-      return nullptr;
-    if (isa<GenericOp, IndexedGenericOp>(producer)) {
-      auto linalgOpProducer = cast<LinalgOp>(producer);
-      if (linalgOpProducer.hasTensorSemantics())
-        return FuseGenericOpsOnTensors::fuse(linalgOpProducer, linalgOpConsumer,
-                                             consumerIdx, rewriter, folder);
-    } else if (auto reshapeOpProducer = dyn_cast<TensorReshapeOp>(producer)) {
-      if (auto genericOpConsumer = dyn_cast<GenericOp>(consumer)) {
-        return FuseTensorReshapeOpAsProducer<GenericOp>::fuse(
-            reshapeOpProducer, genericOpConsumer, consumerIdx, rewriter,
-            folder);
-      } else if (auto indexedGenericOpConsumer =
-                     dyn_cast<IndexedGenericOp>(consumer)) {
-        return FuseTensorReshapeOpAsProducer<IndexedGenericOp>::fuse(
-            reshapeOpProducer, indexedGenericOpConsumer, consumerIdx, rewriter,
-            folder);
-      }
-    } else if (auto constantOpProducer = dyn_cast<ConstantOp>(producer)) {
-      if (auto genericOpConsumer = dyn_cast<GenericOp>(consumer)) {
-        return FuseConstantOpAsProducer<GenericOp>::fuse(
-            constantOpProducer, genericOpConsumer, consumerIdx, rewriter,
-            folder);
-      }
-    }
+    if (isa<GenericOp, IndexedGenericOp>(producer))
+      return FuseGenericOpsOnTensors::fuse(cast<LinalgOp>(producer),
+                                           cast<LinalgOp>(consumer),
+                                           consumerIdx, rewriter, folder);
+    if (auto reshapeOpProducer = dyn_cast<TensorReshapeOp>(producer))
+      return FuseTensorReshapeOpAsProducer::fuse(reshapeOpProducer,
+                                                 cast<LinalgOp>(consumer),
+                                                 consumerIdx, rewriter, folder);
+    if (auto constantOpProducer = dyn_cast<ConstantOp>(producer))
+      return FuseConstantOpAsProducer::fuse(constantOpProducer,
+                                            cast<LinalgOp>(consumer),
+                                            consumerIdx, rewriter, folder);
     return nullptr;
   }
 
-  // Fuse when consumer is a TensorReshapeOp.
-  if (TensorReshapeOp reshapeOp = dyn_cast<TensorReshapeOp>(consumer)) {
-    if (auto genericOpProducer = dyn_cast<GenericOp>(producer)) {
-      if (genericOpProducer.hasTensorSemantics())
-        return FuseTensorReshapeOpAsConsumer<GenericOp>::fuse(
-            genericOpProducer, reshapeOp, consumerIdx, rewriter, folder);
-    } else if (auto indexedGenericOpProducer =
-                   dyn_cast<IndexedGenericOp>(producer)) {
-      if (indexedGenericOpProducer.hasTensorSemantics())
-        return FuseTensorReshapeOpAsConsumer<IndexedGenericOp>::fuse(
-            indexedGenericOpProducer, reshapeOp, consumerIdx, rewriter, folder);
+  if (isa<GenericOp, IndexedGenericOp>(producer)) {
+    // Fuse when consumer is a TensorReshapeOp.
+    if (TensorReshapeOp reshapeOp = dyn_cast<TensorReshapeOp>(consumer)) {
+      return FuseTensorReshapeOpAsConsumer::fuse(
+          cast<LinalgOp>(producer), reshapeOp, consumerIdx, rewriter, folder);
     }
-    return nullptr;
   }
 
   return nullptr;
index 5c49a59..4e7f1f6 100644 (file)
@@ -249,6 +249,38 @@ func @generic_op_constant_fusion(%arg0 : tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
 
 // -----
 
+#map0 = affine_map<(d0, d1, d2) -> (d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func @indexed_generic_op_constant_fusion(%arg0 : tensor<5x?x?xf32>)
+                                         -> tensor<5x?x?xf32>
+{
+  %0 = constant dense<42.0> : tensor<5xf32>
+  %1 = linalg.indexed_generic
+       {args_in = 2 : i64, args_out = 1 : i64,
+         indexing_maps = [#map0, #map1, #map1],
+         iterator_types = ["parallel", "parallel", "parallel"]}
+       %0, %arg0 {
+       ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: f32, %arg5 : f32):
+         %2 = mulf %arg4, %arg5 : f32
+         linalg.yield %2 : f32
+       }: tensor<5xf32>, tensor<5x?x?xf32> -> tensor<5x?x?xf32>
+  return %1 : tensor<5x?x?xf32>
+}
+//   CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @indexed_generic_op_constant_fusion
+//       CHECK:   %[[CST:.*]] = constant {{.*}} : f32
+//       CHECK:   linalg.indexed_generic
+//  CHECK-SAME:     args_in = 1 : i64
+//  CHECK-SAME:     args_out = 1 : i64
+//       CHECK:   ^{{[a-zA-Z0-9_]*}}
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]*]]: index
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]*]]: index
+//  CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]*]]: index
+//  CHECK-SAME:     %[[ARG4:.*]]: f32)
+//       CHECK:     mulf %[[CST]], %[[ARG4]]
+
+// -----
+
 #map0 = affine_map<(d0, d1, d2) -> ()>
 #map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 func @generic_op_zero_dim_constant_fusion(%arg0 : tensor<5x?x?xf32>)
@@ -277,6 +309,38 @@ func @generic_op_zero_dim_constant_fusion(%arg0 : tensor<5x?x?xf32>)
 
 // -----
 
+#map0 = affine_map<(d0, d1, d2) -> ()>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func @indexed_generic_op_zero_dim_constant_fusion
+  (%arg0 : tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
+{
+  %0 = constant dense<42.0> : tensor<f32>
+  %1 = linalg.indexed_generic
+       {args_in = 2 : i64, args_out = 1 : i64,
+         indexing_maps = [#map0, #map1, #map1],
+         iterator_types = ["parallel", "parallel", "parallel"]}
+       %0, %arg0 {
+       ^bb0(%arg1 : index, %arg2 : index, %arg3 : index, %arg4: f32, %arg5: f32):
+         %2 = mulf %arg4, %arg5 : f32
+         linalg.yield %2 : f32
+       }: tensor<f32>, tensor<5x?x?xf32> -> tensor<5x?x?xf32>
+  return %1 : tensor<5x?x?xf32>
+}
+//   CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @indexed_generic_op_zero_dim_constant_fusion
+//       CHECK:   %[[CST:.*]] = constant {{.*}} : f32
+//       CHECK:   linalg.indexed_generic
+//  CHECK-SAME:     args_in = 1 : i64
+//  CHECK-SAME:     args_out = 1 : i64
+//       CHECK:   ^{{[a-zA-Z0-9_]*}}
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]*]]: index
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]*]]: index
+//  CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]*]]: index
+//  CHECK-SAME:     %[[ARG4:.*]]: f32)
+//       CHECK:     mulf %[[CST]], %[[ARG4]]
+
+// -----
+
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
 func @generic_op_indexed_generic_op_fusion(%arg0: tensor<?x?xi32>,
                                            %arg1: tensor<?x?xi32>) {