[mlir][vector] Add distribution for extract from 0d vector
authorThomas Raoux <thomasraoux@google.com>
Fri, 14 Oct 2022 23:04:24 +0000 (23:04 +0000)
committerThomas Raoux <thomasraoux@google.com>
Fri, 14 Oct 2022 23:06:42 +0000 (23:06 +0000)
Differential Revision: https://reviews.llvm.org/D135994

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
mlir/test/Dialect/Vector/vector-warp-distribute.mlir

index 3c4f20f..f730044 100644 (file)
@@ -895,6 +895,34 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
   }
 };
 
+/// Pattern to move out vector.extractelement of 0-D tensors. Those don't
+/// need to be distributed and can just be propagated outside of the region.
+struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
+  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
+      return isa<vector::ExtractElementOp>(op);
+    });
+    if (!operand)
+      return failure();
+    unsigned int operandNumber = operand->getOperandNumber();
+    auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
+    if (extractOp.getVectorType().getRank() != 0)
+      return failure();
+    Location loc = extractOp.getLoc();
+    SmallVector<size_t> newRetIndices;
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, {extractOp.getVector()}, {extractOp.getVectorType()},
+        newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    Value newExtract = rewriter.create<vector::ExtractElementOp>(
+        loc, newWarpOp->getResult(newRetIndices[0]));
+    newWarpOp->getResult(operandNumber).replaceAllUsesWith(newExtract);
+    return success();
+  }
+};
+
 /// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
 /// the scf.ForOp is the last operation in the region so that it doesn't change
 /// the order of execution. This creates a new scf.for region after the
@@ -1093,8 +1121,9 @@ void mlir::vector::populateDistributeTransferWriteOpPatterns(
 void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
-               WarpOpBroadcast, WarpOpExtract, WarpOpForwardOperand,
-               WarpOpScfForOp, WarpOpConstant>(patterns.getContext(), benefit);
+               WarpOpBroadcast, WarpOpExtract, WarpOpExtractElement,
+               WarpOpForwardOperand, WarpOpScfForOp, WarpOpConstant>(
+      patterns.getContext(), benefit);
 }
 
 void mlir::vector::populateDistributeReduction(
index 3978d94..49c36fe 100644 (file)
@@ -632,6 +632,24 @@ func.func @vector_extract_simple(%laneid: index) -> (f32) {
 
 // -----
 
+// CHECK-PROP-LABEL: func.func @vector_extractelement_simple(
+//       CHECK-PROP:   %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<f32>) {
+//       CHECK-PROP:     %[[V:.*]] = "some_def"() : () -> vector<f32>
+//       CHECK-PROP:     vector.yield %[[V]] : vector<f32>
+//       CHECK-PROP:   }
+//       CHECK-PROP:   %[[E:.*]] = vector.extractelement %[[R]][] : vector<f32>
+//       CHECK-PROP:   return %[[E]] : f32
+func.func @vector_extractelement_simple(%laneid: index) -> (f32) {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
+    %0 = "some_def"() : () -> (vector<f32>)
+    %1 = vector.extractelement %0[] : vector<f32>
+    vector.yield %1 : f32
+  }
+  return %r : f32
+}
+
+// -----
+
 // CHECK-PROP:   func @lane_dependent_warp_propagate_read
 //  CHECK-PROP-SAME:   %[[ID:.*]]: index
 func.func @lane_dependent_warp_propagate_read(