From 0af268059636647798b00bd85dc4faecf537ce52 Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Mon, 11 Jul 2022 07:01:13 +0000 Subject: [PATCH] [mlir][vector] Add pattern to distribute splat constant Distribute splat constant out of WarpExecuteOnLane0Op region. Differential Revision: https://reviews.llvm.org/D129467 --- .../Dialect/Vector/Transforms/VectorDistribute.cpp | 42 ++++++++++++++++++++-- .../Dialect/Vector/vector-warp-distribute.mlir | 13 +++++++ 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index 57fa863..1fb7a21 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -524,6 +524,44 @@ struct WarpOpElementwise : public OpRewritePattern { } }; +/// Sink out splat constant op feeding into a warp op yield. +/// ``` +/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) { +/// ... +/// %cst = arith.constant dense<2.0> : vector<32xf32> +/// vector.yield %cst : vector<32xf32> +/// } +/// ``` +/// To +/// ``` +/// vector.warp_execute_on_lane_0(%arg0 { +/// ... +/// } +/// %0 = arith.constant dense<2.0> : vector<1xf32> +struct WarpOpConstant : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, + PatternRewriter &rewriter) const override { + OpOperand *yieldOperand = getWarpResult( + warpOp, [](Operation *op) { return isa(op); }); + if (!yieldOperand) + return failure(); + auto constantOp = yieldOperand->get().getDefiningOp(); + auto dense = constantOp.getValue().dyn_cast(); + if (!dense) + return failure(); + unsigned operandIndex = yieldOperand->getOperandNumber(); + Attribute scalarAttr = dense.getSplatValue(); + Attribute newAttr = DenseElementsAttr::get( + warpOp.getResult(operandIndex).getType(), scalarAttr); + Location loc = warpOp.getLoc(); + rewriter.setInsertionPointAfter(warpOp); + Value distConstant = rewriter.create(loc, newAttr); + warpOp.getResult(operandIndex).replaceAllUsesWith(distConstant); + return success(); + } +}; + /// Sink out transfer_read op feeding into a warp op yield. /// ``` /// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) { @@ -868,8 +906,8 @@ void mlir::vector::populateDistributeTransferWriteOpPatterns( void mlir::vector::populatePropagateWarpVectorDistributionPatterns( RewritePatternSet &patterns) { patterns.add( - patterns.getContext()); + WarpOpBroadcast, WarpOpForwardOperand, WarpOpScfForOp, + WarpOpConstant>(patterns.getContext()); } void mlir::vector::populateDistributeReduction( diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir index 4a04f98..55a8490 100644 --- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir +++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir @@ -562,3 +562,16 @@ func.func @warp_duplicate_yield(%laneid: index) -> (vector<1xf32>, vector<1xf32> } return %r#0, %r#1 : vector<1xf32>, vector<1xf32> } + +// ----- + +// CHECK-PROP-LABEL: func @warp_constant( +// CHECK-PROP: %[[C:.*]] = arith.constant dense<2.000000e+00> : vector<1xf32> +// CHECK-PROP: return %[[C]] : vector<1xf32> +func.func @warp_constant(%laneid: index) -> (vector<1xf32>) { + %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) { + %cst = arith.constant dense<2.0> : vector<32xf32> + vector.yield %cst : vector<32xf32> + } + return %r : vector<1xf32> +} -- 2.7.4