[mlir][vector] Add pattern to distribute splat constant
authorThomas Raoux <thomasraoux@google.com>
Mon, 11 Jul 2022 07:01:13 +0000 (07:01 +0000)
committerThomas Raoux <thomasraoux@google.com>
Mon, 11 Jul 2022 15:50:26 +0000 (15:50 +0000)
Distribute splat constant out of WarpExecuteOnLane0Op region.

Differential Revision: https://reviews.llvm.org/D129467

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
mlir/test/Dialect/Vector/vector-warp-distribute.mlir

index 57fa863..1fb7a21 100644 (file)
@@ -524,6 +524,44 @@ struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
   }
 };
 
+/// Sink out splat constant op feeding into a warp op yield.
+/// ```
+/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
+///   ...
+///   %cst = arith.constant dense<2.0> : vector<32xf32>
+///   vector.yield %cst : vector<32xf32>
+/// }
+/// ```
+/// To
+/// ```
+/// vector.warp_execute_on_lane_0(%arg0 {
+///   ...
+/// }
+/// %0 = arith.constant dense<2.0> : vector<1xf32>
+struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
+  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *yieldOperand = getWarpResult(
+        warpOp, [](Operation *op) { return isa<arith::ConstantOp>(op); });
+    if (!yieldOperand)
+      return failure();
+    auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>();
+    auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>();
+    if (!dense)
+      return failure();
+    unsigned operandIndex = yieldOperand->getOperandNumber();
+    Attribute scalarAttr = dense.getSplatValue<Attribute>();
+    Attribute newAttr = DenseElementsAttr::get(
+        warpOp.getResult(operandIndex).getType(), scalarAttr);
+    Location loc = warpOp.getLoc();
+    rewriter.setInsertionPointAfter(warpOp);
+    Value distConstant = rewriter.create<arith::ConstantOp>(loc, newAttr);
+    warpOp.getResult(operandIndex).replaceAllUsesWith(distConstant);
+    return success();
+  }
+};
+
 /// Sink out transfer_read op feeding into a warp op yield.
 /// ```
 /// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
@@ -868,8 +906,8 @@ void mlir::vector::populateDistributeTransferWriteOpPatterns(
 void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
     RewritePatternSet &patterns) {
   patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
-               WarpOpBroadcast, WarpOpForwardOperand, WarpOpScfForOp>(
-      patterns.getContext());
+               WarpOpBroadcast, WarpOpForwardOperand, WarpOpScfForOp,
+               WarpOpConstant>(patterns.getContext());
 }
 
 void mlir::vector::populateDistributeReduction(
index 4a04f98..55a8490 100644 (file)
@@ -562,3 +562,16 @@ func.func @warp_duplicate_yield(%laneid: index) -> (vector<1xf32>, vector<1xf32>
   }
   return %r#0, %r#1 : vector<1xf32>, vector<1xf32>
 }
+
+// -----
+
+// CHECK-PROP-LABEL: func @warp_constant(
+//       CHECK-PROP:   %[[C:.*]] = arith.constant dense<2.000000e+00> : vector<1xf32>
+//       CHECK-PROP:   return %[[C]] : vector<1xf32>
+func.func @warp_constant(%laneid: index) -> (vector<1xf32>) {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
+    %cst = arith.constant dense<2.0> : vector<32xf32>
+    vector.yield %cst : vector<32xf32>
+  }
+  return %r : vector<1xf32>
+}