#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
}
};
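+/// Folds tensor.cast ops on the shared outputs of an scf.forall into the
+/// loop, if the cast makes the corresponding loop result type "more static",
+/// i.e. its source type is at least as static as its result type. The loop is
+/// rebuilt with the uncast outputs; tensor.cast ops are inserted inside the
+/// body to recover the types the old body expects and after the loop to
+/// recover the original result types.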
+struct FoldTensorCastOfOutputIntoForallOp
+ : public OpRewritePattern<scf::ForallOp> {
+ using OpRewritePattern<scf::ForallOp>::OpRewritePattern;
+
+ struct TypeCast {
+ Type srcType;
+ Type dstType;
+ };
+
+ LogicalResult matchAndRewrite(scf::ForallOp forallOp,
+ PatternRewriter &rewriter) const final {
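+    // Maps the index of each foldable output to the source and result types
+    // of its defining tensor.cast.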
+ llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
+ llvm::SmallVector<Value> newOutputTensors = forallOp.getOutputs();
+ for (auto en : llvm::enumerate(newOutputTensors)) {
+ auto castOp = en.value().getDefiningOp<tensor::CastOp>();
+ if (!castOp)
+ continue;
+
+      // Only casts that preserve static information, i.e. that will make the
+      // loop result type "more" static than before, are folded.
+ if (!tensor::preservesStaticInformation(castOp.getDest().getType(),
+ castOp.getSource().getType())) {
+ continue;
+ }
+
+ tensorCastProducers[en.index()] =
+ TypeCast{castOp.getSource().getType(), castOp.getType()};
+ newOutputTensors[en.index()] = castOp.getSource();
+ }
+
+ if (tensorCastProducers.empty())
+ return failure();
+
+ // Create new loop.
+ Location loc = forallOp.getLoc();
+ auto newForallOp = rewriter.create<ForallOp>(
+ loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+ forallOp.getMixedStep(), newOutputTensors, forallOp.getMapping(),
+ [&](OpBuilder nestedBuilder, Location nestedLoc, ValueRange bbArgs) {
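+          // Cast the more static output bbArgs back to the old types so the
+          // merged body below keeps seeing the types it was built with.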
+ auto castBlockArgs =
+ llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
+ for (auto [index, cast] : tensorCastProducers) {
+ Value &oldTypeBBArg = castBlockArgs[index];
+ oldTypeBBArg = nestedBuilder.create<tensor::CastOp>(
+ nestedLoc, cast.dstType, oldTypeBBArg);
+ }
+
+ // Move old body into new parallel loop.
+ SmallVector<Value> ivsBlockArgs =
+ llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
+ ivsBlockArgs.append(castBlockArgs);
+ rewriter.mergeBlocks(forallOp.getBody(),
+ bbArgs.front().getParentBlock(), ivsBlockArgs);
+ });
+
+    // After `mergeBlocks`, the destinations in the terminator are mapped to
+    // the old-typed tensor.cast results of the output bbArgs. The
+    // destinations have to be updated to point to the output bbArgs directly.
+ auto terminator = newForallOp.getTerminator();
+ for (auto [yieldingOp, outputBlockArg] :
+ llvm::zip(terminator.getYieldingOps(),
+ newForallOp.getOutputBlockArguments())) {
+ auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
+ insertSliceOp.getDestMutable().assign(outputBlockArg);
+ }
+
+ // Cast results back to the original types.
+ rewriter.setInsertionPointAfter(newForallOp);
+ SmallVector<Value> castResults = newForallOp.getResults();
+ for (auto &item : tensorCastProducers) {
+ Value &oldTypeResult = castResults[item.first];
+ oldTypeResult = rewriter.create<tensor::CastOp>(loc, item.second.dstType,
+ oldTypeResult);
+ }
+ rewriter.replaceOp(forallOp, castResults);
+ return success();
+ }
+};
+
} // namespace
void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<DimOfForallOp, ForallOpControlOperandsFolder,
+ results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
+ ForallOpControlOperandsFolder,
ForallOpSingleOrZeroIterationDimsFolder>(context);
}
// CHECK: %[[EMPTY:.*]] = tensor.empty
// CHECK: return %[[EMPTY]]
+// -----
+
+func.func @fold_tensor_cast_into_forall(
+ %in: tensor<2xi32>, %out: tensor<2xi32>) -> tensor<2xi32> {
+ %cst = arith.constant dense<[100500]> : tensor<1xi32>
+
+ %out_cast = tensor.cast %out : tensor<2xi32> to tensor<?xi32>
+ %result = scf.forall (%i) = (0) to (2) step (1)
+ shared_outs (%out_ = %out_cast) -> tensor<?xi32> {
+
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %cst into %out_[%i] [1] [1]
+ : tensor<1xi32> into tensor<?xi32>
+ }
+ }
+ %result_cast = tensor.cast %result : tensor<?xi32> to tensor<2xi32>
+ func.return %result_cast : tensor<2xi32>
+}
+// CHECK-LABEL: @fold_tensor_cast_into_forall
+// CHECK-NOT: tensor.cast
+// CHECK: parallel_insert_slice
+// CHECK-SAME: : tensor<1xi32> into tensor<2xi32>
+// CHECK-NOT: tensor.cast
+
+// -----
+
+func.func @do_not_fold_tensor_cast_from_dynamic_to_static_type_into_forall(
+ %in: tensor<?xi32>, %out: tensor<?xi32>) -> tensor<?xi32> {
+ %cst = arith.constant dense<[100500]> : tensor<1xi32>
+
+ %out_cast = tensor.cast %out : tensor<?xi32> to tensor<2xi32>
+ %result = scf.forall (%i) = (0) to (2) step (1)
+ shared_outs (%out_ = %out_cast) -> tensor<2xi32> {
+
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %cst into %out_[%i] [1] [1]
+ : tensor<1xi32> into tensor<2xi32>
+ }
+ }
+ %result_cast = tensor.cast %result : tensor<2xi32> to tensor<?xi32>
+ func.return %result_cast : tensor<?xi32>
+}
+// CHECK-LABEL: @do_not_fold_tensor_cast_
+// CHECK: tensor.cast
+// CHECK: parallel_insert_slice
+// CHECK-SAME: : tensor<1xi32> into tensor<2xi32>
+// CHECK: tensor.cast