From 5ce68f4284c694392238f1c8c5308d08d9a56251 Mon Sep 17 00:00:00 2001 From: Guray Ozen Date: Wed, 16 Nov 2022 17:23:43 +0100 Subject: [PATCH] [mlir] Introduce `replaceUsesOfWith` to `RewriterBase` Finding uses of a value and replacing them with a new one is a common method. I have not seen an safe and easy shortcut that does that. This revision attempts to address that by intoroducing `replaceUsesOfWith` to `RewriterBase`. Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D138110 --- mlir/include/mlir/IR/PatternMatch.h | 5 +++++ .../Dialect/GPU/TransformOps/GPUTransformOps.cpp | 24 ++++++---------------- mlir/lib/IR/PatternMatch.cpp | 8 ++++++++ 3 files changed, 19 insertions(+), 18 deletions(-) diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index e257b67..7b05a2d 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -502,6 +502,11 @@ public: finalizeRootUpdate(root); } + /// Find uses of `from` and replace it with `to`. It also marks every modified + /// uses and notifies the rewriter that an in-place operation modification is + /// about to happen. + void replaceAllUsesWith(Value from, Value to); + /// Used to notify the rewriter that the IR failed to be rewritten because of /// a match failure, and provide a callback to populate a diagnostic with the /// reason why the failure occurred. This method allows for derived rewriters diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp index ccac412..ec493bf 100644 --- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp +++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp @@ -246,15 +246,9 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapForeachToBlocksImpl( sourceBlock.getOperations()); // Step 5. RAUW thread indices to thread ops. - for (Value blockIdx : foreachThreadOp.getThreadIndices()) { - Value val = bvm.lookup(blockIdx); - SmallVector uses; - for (OpOperand &use : blockIdx.getUses()) - uses.push_back(&use); - for (OpOperand *operand : uses) { - Operation *op = operand->getOwner(); - rewriter.updateRootInPlace(op, [&]() { operand->set(val); }); - } + for (Value loopIndex : foreachThreadOp.getThreadIndices()) { + Value blockIdx = bvm.lookup(loopIndex); + rewriter.replaceAllUsesWith(loopIndex, blockIdx); } // Step 6. Erase old op. @@ -492,15 +486,9 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads( sourceBlock.getOperations()); // Step 6. RAUW thread indices to thread ops. - for (Value threadIdx : foreachThreadOp.getThreadIndices()) { - Value val = bvm.lookup(threadIdx); - SmallVector uses; - for (OpOperand &use : threadIdx.getUses()) - uses.push_back(&use); - for (OpOperand *operand : uses) { - Operation *op = operand->getOwner(); - rewriter.updateRootInPlace(op, [&]() { operand->set(val); }); - } + for (Value loopIndex : foreachThreadOp.getThreadIndices()) { + Value threadIdx = bvm.lookup(loopIndex); + rewriter.replaceAllUsesWith(loopIndex, threadIdx); } // Step 7. syncthreads. diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp index d2de65e..d3072b5 100644 --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -309,6 +309,14 @@ void RewriterBase::mergeBlocks(Block *source, Block *dest, source->erase(); } +/// Find uses of `from` and replace it with `to` +void RewriterBase::replaceAllUsesWith(Value from, Value to) { + for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) { + Operation *op = operand.getOwner(); + updateRootInPlace(op, [&]() { operand.set(to); }); + } +} + // Merge the operations of block 'source' before the operation 'op'. Source // block should not have existing predecessors or successors. void RewriterBase::mergeBlockBefore(Block *source, Operation *op, -- 2.7.4