From 0d9761d50e738163c87d84a4328bc0a827ac8f34 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Tue, 22 Nov 2022 11:20:41 +0100 Subject: [PATCH] [mlir][SCF] Add tensor.dim(scf.foreach_thread) folding Dim sizes of `scf.foreach_thread` op results match the dim sizes of their respective tied shared_outs operands. Differential Revision: https://reviews.llvm.org/D138484 --- mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 10 ++++++++++ mlir/lib/Dialect/SCF/IR/SCF.cpp | 25 +++++++++++++++++++++++++ mlir/test/Dialect/SCF/canonicalize.mlir | 22 ++++++++++++++++++++++ 3 files changed, 57 insertions(+) diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td index 2d880ac..af4db68 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -487,6 +487,7 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [ let results = (outs Variadic:$results); let regions = (region SizedRegion<1>:$region); + let hasCanonicalizer = 1; let hasCustomAssemblyFormat = 1; let hasVerifier = 1; @@ -510,11 +511,20 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [ opOperand->getOperandNumber() - getRank()); } + /// Return the shared_outs operand that is tied to the given output + /// block argument. OpOperand *getTiedOpOperand(BlockArgument bbArg) { assert(bbArg.getArgNumber() >= getRank() && "invalid bbArg"); return &getOperation()->getOpOperand(bbArg.getArgNumber()); } + /// Return the shared_outs operand that is tied to the given OpResult. 
+ OpOperand *getTiedOpOperand(OpResult opResult) { + assert(opResult.getDefiningOp() == getOperation() && "invalid OpResult"); + return &getOperation()->getOpOperand( + opResult.getResultNumber() + getRank()); + } + BlockArgument getTiedBlockArgument(OpOperand *opOperand) { assert(opOperand->getOperandNumber() >= getRank() && "invalid operand"); return getBody()->getArgument(opOperand->getOperandNumber()); diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 118452a..6924107 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -1323,6 +1323,31 @@ ForeachThreadOp mlir::scf::getForeachThreadOpThreadIndexOwner(Value val) { return dyn_cast<ForeachThreadOp>(containingOp); } +namespace { +/// Fold tensor.dim(foreach_thread shared_outs(... = %t)) to tensor.dim(%t). +struct DimOfForeachThreadOp : public OpRewritePattern<tensor::DimOp> { + using OpRewritePattern<tensor::DimOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::DimOp dimOp, + PatternRewriter &rewriter) const final { + auto foreachThreadOp = dimOp.getSource().getDefiningOp<ForeachThreadOp>(); + if (!foreachThreadOp) + return failure(); + Value sharedOut = + foreachThreadOp.getTiedOpOperand(dimOp.getSource().cast<OpResult>()) + ->get(); + rewriter.updateRootInPlace( + dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); }); + return success(); + } +}; +} // namespace + +void ForeachThreadOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add<DimOfForeachThreadOp>(context); +} + //===----------------------------------------------------------------------===// // PerformConcurrentlyOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir index b6ac362..e5e2afc 100644 --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -1478,3 +1478,25 @@ func.func @func_execute_region_elim_multi_yield() { // CHECK: ^[[bb3]](%[[z:.+]]: 
i64): // CHECK: "test.bar"(%[[z]]) // CHECK: return + +// ----- + +// CHECK-LABEL: func @canonicalize_parallel_insert_slice_indices( +// CHECK-SAME: %[[arg0:.*]]: tensor<1x5xf32>, %[[arg1:.*]]: tensor<?x?xf32> +func.func @canonicalize_parallel_insert_slice_indices( + %arg0 : tensor<1x5xf32>, %arg1: tensor<?x?xf32>, %num_threads : index) -> index +{ + // CHECK: %[[c1:.*]] = arith.constant 1 : index + %c1 = arith.constant 1 : index + + %2 = scf.foreach_thread (%tidx) in (%num_threads) shared_outs(%o = %arg1) -> (tensor<?x?xf32>) { + scf.foreach_thread.perform_concurrently { + tensor.parallel_insert_slice %arg0 into %o[%tidx, 0] [1, 5] [1, 1] : tensor<1x5xf32> into tensor<?x?xf32> + } + } + + // CHECK: %[[dim:.*]] = tensor.dim %[[arg1]], %[[c1]] + %dim = tensor.dim %2, %c1 : tensor<?x?xf32> + // CHECK: return %[[dim]] + return %dim : index +} -- 2.7.4