From 1757164eed244b221c6c078baa7c836e4809e133 Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Fri, 14 Oct 2022 23:04:24 +0000 Subject: [PATCH] [mlir][vector] Add distribution for extract from 0d vector Differential Revision: https://reviews.llvm.org/D135994 --- .../Dialect/Vector/Transforms/VectorDistribute.cpp | 33 ++++++++++++++++++++-- .../Dialect/Vector/vector-warp-distribute.mlir | 18 ++++++++++++ 2 files changed, 49 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index 3c4f20f..f730044 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -895,6 +895,34 @@ struct WarpOpExtract : public OpRewritePattern { } }; +/// Pattern to move out vector.extractelement of 0-D tensors. Those don't +/// need to be distributed and can just be propagated outside of the region. +struct WarpOpExtractElement : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, + PatternRewriter &rewriter) const override { + OpOperand *operand = getWarpResult(warpOp, [](Operation *op) { + return isa(op); + }); + if (!operand) + return failure(); + unsigned int operandNumber = operand->getOperandNumber(); + auto extractOp = operand->get().getDefiningOp(); + if (extractOp.getVectorType().getRank() != 0) + return failure(); + Location loc = extractOp.getLoc(); + SmallVector newRetIndices; + WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( + rewriter, warpOp, {extractOp.getVector()}, {extractOp.getVectorType()}, + newRetIndices); + rewriter.setInsertionPointAfter(newWarpOp); + Value newExtract = rewriter.create( + loc, newWarpOp->getResult(newRetIndices[0])); + newWarpOp->getResult(operandNumber).replaceAllUsesWith(newExtract); + return success(); + } +}; + /// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if /// the scf.ForOp is the last operation in the region so that it doesn't change /// the order of execution. This creates a new scf.for region after the @@ -1093,8 +1121,9 @@ void mlir::vector::populateDistributeTransferWriteOpPatterns( void mlir::vector::populatePropagateWarpVectorDistributionPatterns( RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add(patterns.getContext(), benefit); + WarpOpBroadcast, WarpOpExtract, WarpOpExtractElement, + WarpOpForwardOperand, WarpOpScfForOp, WarpOpConstant>( + patterns.getContext(), benefit); } void mlir::vector::populateDistributeReduction( diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir index 3978d94..49c36fe 100644 --- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir +++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir @@ -632,6 +632,24 @@ func.func @vector_extract_simple(%laneid: index) -> (f32) { // ----- +// CHECK-PROP-LABEL: func.func @vector_extractelement_simple( +// CHECK-PROP: %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector) { +// CHECK-PROP: %[[V:.*]] = "some_def"() : () -> vector +// CHECK-PROP: vector.yield %[[V]] : vector +// CHECK-PROP: } +// CHECK-PROP: %[[E:.*]] = vector.extractelement %[[R]][] : vector +// CHECK-PROP: return %[[E]] : f32 +func.func @vector_extractelement_simple(%laneid: index) -> (f32) { + %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) { + %0 = "some_def"() : () -> (vector) + %1 = vector.extractelement %0[] : vector + vector.yield %1 : f32 + } + return %r : f32 +} + +// ----- + // CHECK-PROP: func @lane_dependent_warp_propagate_read // CHECK-PROP-SAME: %[[ID:.*]]: index func.func @lane_dependent_warp_propagate_read( -- 2.7.4