[mlir] Add a pattern to fold tensor.cast into scf.forall.
authorAlexander Belyaev <pifon@google.com>
Wed, 22 Mar 2023 11:50:24 +0000 (12:50 +0100)
committerAlexander Belyaev <pifon@google.com>
Wed, 22 Mar 2023 12:26:22 +0000 (13:26 +0100)
Differential revision: https://reviews.llvm.org/D146558

mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/test/Dialect/SCF/canonicalize.mlir

index e212159..3eda0d6 100644 (file)
@@ -21,6 +21,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
@@ -1594,11 +1595,91 @@ struct ForallOpSingleOrZeroIterationDimsFolder
   }
 };
 
+struct FoldTensorCastOfOutputIntoForallOp
+    : public OpRewritePattern<scf::ForallOp> {
+  using OpRewritePattern<scf::ForallOp>::OpRewritePattern;
+
+  struct TypeCast {
+    Type srcType;
+    Type dstType;
+  };
+
+  LogicalResult matchAndRewrite(scf::ForallOp forallOp,
+                                PatternRewriter &rewriter) const final {
+    llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
+    llvm::SmallVector<Value> newOutputTensors = forallOp.getOutputs();
+    for (auto en : llvm::enumerate(newOutputTensors)) {
+      auto castOp = en.value().getDefiningOp<tensor::CastOp>();
+      if (!castOp)
+        continue;
+
+      // Only casts that that preserve static information, i.e. will make the
+      // loop result type "more" static than before, will be folded.
+      if (!tensor::preservesStaticInformation(castOp.getDest().getType(),
+                                              castOp.getSource().getType())) {
+        continue;
+      }
+
+      tensorCastProducers[en.index()] =
+          TypeCast{castOp.getSource().getType(), castOp.getType()};
+      newOutputTensors[en.index()] = castOp.getSource();
+    }
+
+    if (tensorCastProducers.empty())
+      return failure();
+
+    // Create new loop.
+    Location loc = forallOp.getLoc();
+    auto newForallOp = rewriter.create<ForallOp>(
+        loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+        forallOp.getMixedStep(), newOutputTensors, forallOp.getMapping(),
+        [&](OpBuilder nestedBuilder, Location nestedLoc, ValueRange bbArgs) {
+          auto castBlockArgs =
+              llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
+          for (auto [index, cast] : tensorCastProducers) {
+            Value &oldTypeBBArg = castBlockArgs[index];
+            oldTypeBBArg = nestedBuilder.create<tensor::CastOp>(
+                nestedLoc, cast.dstType, oldTypeBBArg);
+          }
+
+          // Move old body into new parallel loop.
+          SmallVector<Value> ivsBlockArgs =
+              llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
+          ivsBlockArgs.append(castBlockArgs);
+          rewriter.mergeBlocks(forallOp.getBody(),
+                               bbArgs.front().getParentBlock(), ivsBlockArgs);
+        });
+
+    // After `mergeBlocks` happened, the destinations in the terminator were
+    // mapped to the tensor.cast old-typed results of the output bbArgs. The
+    // destination have to be updated to point to the output bbArgs directly.
+    auto terminator = newForallOp.getTerminator();
+    for (auto [yieldingOp, outputBlockArg] :
+         llvm::zip(terminator.getYieldingOps(),
+                   newForallOp.getOutputBlockArguments())) {
+      auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
+      insertSliceOp.getDestMutable().assign(outputBlockArg);
+    }
+
+    // Cast results back to the original types.
+    rewriter.setInsertionPointAfter(newForallOp);
+    SmallVector<Value> castResults = newForallOp.getResults();
+    for (auto &item : tensorCastProducers) {
+      Value &oldTypeResult = castResults[item.first];
+      oldTypeResult = rewriter.create<tensor::CastOp>(loc, item.second.dstType,
+                                                      oldTypeResult);
+    }
+    rewriter.replaceOp(forallOp, castResults);
+    return success();
+  }
+};
+
 } // namespace
 
 void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
-  results.add<DimOfForallOp, ForallOpControlOperandsFolder,
+  results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
+              ForallOpControlOperandsFolder,
               ForallOpSingleOrZeroIterationDimsFolder>(context);
 }
 
index f69cf19..ec6e35a 100644 (file)
@@ -1651,3 +1651,52 @@ func.func @remove_empty_forall(%in: tensor<8x8xf32>) -> tensor<8x8xf32> {
 // CHECK:       %[[EMPTY:.*]] = tensor.empty
 // CHECK:       return %[[EMPTY]]
 
+// -----
+
+func.func @fold_tensor_cast_into_forall(
+    %in: tensor<2xi32>, %out: tensor<2xi32>) -> tensor<2xi32> {
+  %cst = arith.constant dense<[100500]> : tensor<1xi32>
+
+
+  %out_cast = tensor.cast %out : tensor<2xi32> to tensor<?xi32>
+  %result = scf.forall (%i) = (0) to (2) step (1)
+      shared_outs (%out_ = %out_cast) -> tensor<?xi32> {
+
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %cst into %out_[%i] [1] [1]
+        : tensor<1xi32> into tensor<?xi32>
+    }
+  }
+  %result_cast = tensor.cast %result : tensor<?xi32> to tensor<2xi32>
+  func.return %result_cast : tensor<2xi32>
+}
+// CHECK-LABEL: @fold_tensor_cast_into_forall
+// CHECK-NOT:     tensor.cast
+// CHECK:         parallel_insert_slice
+// CHECK-SAME:      : tensor<1xi32> into tensor<2xi32>
+// CHECK-NOT:     tensor.cast
+
+// -----
+
+func.func @do_not_fold_tensor_cast_from_dynamic_to_static_type_into_forall(
+    %in: tensor<?xi32>, %out: tensor<?xi32>) -> tensor<?xi32> {
+  %cst = arith.constant dense<[100500]> : tensor<1xi32>
+
+
+  %out_cast = tensor.cast %out : tensor<?xi32> to tensor<2xi32>
+  %result = scf.forall (%i) = (0) to (2) step (1)
+      shared_outs (%out_ = %out_cast) -> tensor<2xi32> {
+
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %cst into %out_[%i] [1] [1]
+        : tensor<1xi32> into tensor<2xi32>
+    }
+  }
+  %result_cast = tensor.cast %result : tensor<2xi32> to tensor<?xi32>
+  func.return %result_cast : tensor<?xi32>
+}
+// CHECK-LABEL: @do_not_fold_tensor_cast_
+// CHECK:         tensor.cast
+// CHECK:         parallel_insert_slice
+// CHECK-SAME:      : tensor<1xi32> into tensor<2xi32>
+// CHECK:         tensor.cast