[mlir][vector] Support vector.extractelement distribution of 1D vectors
authorMatthias Springer <springerm@google.com>
Thu, 10 Nov 2022 14:04:23 +0000 (15:04 +0100)
committerMatthias Springer <springerm@google.com>
Thu, 10 Nov 2022 14:07:56 +0000 (15:07 +0100)
Ops such as `%1 = vector.extractelement %0[%pos : index] : vector<96xf32>`.

In case of an extract from a 1D vector, the source vector is distributed. The lane into which the requested position falls, extracts the element and shuffles it to all other lanes.

Differential Revision: https://reviews.llvm.org/D137336

mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
mlir/test/Dialect/Vector/vector-warp-distribute.mlir
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

index 49e3427..a76a58e 100644 (file)
@@ -67,11 +67,17 @@ void populateDistributeTransferWriteOpPatterns(
 /// region.
 void moveScalarUniformCode(WarpExecuteOnLane0Op op);
 
+/// Lambda signature to compute a warp shuffle of a given value of a given lane
+/// within a given warp size.
+using WarpShuffleFromIdxFn =
+    std::function<Value(Location, OpBuilder &b, Value, Value, int64_t)>;
+
 /// Collect patterns to propagate warp distribution. `distributionMapFn` is used
 /// to decide how a value should be distributed when this cannot be inferred
 /// from its uses.
 void populatePropagateWarpVectorDistributionPatterns(
     RewritePatternSet &pattern, const DistributionMapFn &distributionMapFn,
+    const WarpShuffleFromIdxFn &warpShuffleFromIdxFn,
     PatternBenefit benefit = 1);
 
 /// Lambda signature to compute a reduction of a distributed value for the given
index a2916a5..c56af3a 100644 (file)
@@ -915,7 +915,10 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
 /// Pattern to move out vector.extractelement of 0-D tensors. Those don't
 /// need to be distributed and can just be propagated outside of the region.
 struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
-  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+  WarpOpExtractElement(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
+                       PatternBenefit b = 1)
+      : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
+        warpShuffleFromIdxFn(std::move(fn)) {}
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
@@ -925,19 +928,60 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
       return failure();
     unsigned int operandNumber = operand->getOperandNumber();
     auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
-    if (extractOp.getVectorType().getRank() != 0)
-      return failure();
+    VectorType extractSrcType = extractOp.getVectorType();
+    bool is0dExtract = extractSrcType.getRank() == 0;
+    Type elType = extractSrcType.getElementType();
+    VectorType distributedVecType;
+    if (!is0dExtract) {
+      assert(extractSrcType.getRank() == 1 &&
+             "expected that extractelement src rank is 0 or 1");
+      int64_t elementsPerLane =
+          extractSrcType.getShape()[0] / warpOp.getWarpSize();
+      distributedVecType = VectorType::get({elementsPerLane}, elType);
+    } else {
+      distributedVecType = extractSrcType;
+    }
+
+    // Yield source vector from warp op.
     Location loc = extractOp.getLoc();
     SmallVector<size_t> newRetIndices;
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, {extractOp.getVector()}, {extractOp.getVectorType()},
+        rewriter, warpOp, {extractOp.getVector()}, {distributedVecType},
         newRetIndices);
     rewriter.setInsertionPointAfter(newWarpOp);
-    Value newExtract = rewriter.create<vector::ExtractElementOp>(
-        loc, newWarpOp->getResult(newRetIndices[0]));
-    newWarpOp->getResult(operandNumber).replaceAllUsesWith(newExtract);
+    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
+
+    // 0d extract: The new warp op broadcasts the source vector to all lanes.
+    // All lanes extract the scalar.
+    if (is0dExtract) {
+      Value newExtract =
+          rewriter.create<vector::ExtractElementOp>(loc, distributedVec);
+      newWarpOp->getResult(operandNumber).replaceAllUsesWith(newExtract);
+      return success();
+    }
+
+    // 1d extract: Distribute the source vector. One lane extracts and shuffles
+    // the value to all other lanes.
+    int64_t elementsPerLane = distributedVecType.getShape()[0];
+    AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
+    // tid of extracting thread: pos / elementsPerLane
+    Value broadcastFromTid = rewriter.create<AffineApplyOp>(
+        loc, sym0.ceilDiv(elementsPerLane), extractOp.getPosition());
+    // Extract at position: pos % elementsPerLane
+    Value pos = rewriter.create<AffineApplyOp>(loc, sym0 % elementsPerLane,
+                                               extractOp.getPosition());
+    Value extracted =
+        rewriter.create<vector::ExtractElementOp>(loc, distributedVec, pos);
+
+    // Shuffle the extracted value to all lanes.
+    Value shuffled = warpShuffleFromIdxFn(
+        loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());
+    newWarpOp->getResult(operandNumber).replaceAllUsesWith(shuffled);
     return success();
   }
+
+private:
+  WarpShuffleFromIdxFn warpShuffleFromIdxFn;
 };
 
 /// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
@@ -1194,11 +1238,12 @@ void mlir::vector::populateDistributeTransferWriteOpPatterns(
 
 void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
     RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
-    PatternBenefit benefit) {
+    const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit) {
   patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
-               WarpOpBroadcast, WarpOpExtract, WarpOpExtractElement,
-               WarpOpForwardOperand, WarpOpConstant>(patterns.getContext(),
-                                                     benefit);
+               WarpOpBroadcast, WarpOpExtract, WarpOpForwardOperand,
+               WarpOpConstant>(patterns.getContext(), benefit);
+  patterns.add<WarpOpExtractElement>(patterns.getContext(),
+                                     warpShuffleFromIdxFn, benefit);
   patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
                                benefit);
 }
index daebccd..b698745 100644 (file)
@@ -666,14 +666,14 @@ func.func @vector_extract_simple(%laneid: index) -> (f32) {
 
 // -----
 
-// CHECK-PROP-LABEL: func.func @vector_extractelement_simple(
+// CHECK-PROP-LABEL: func.func @vector_extractelement_0d(
 //       CHECK-PROP:   %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<f32>) {
 //       CHECK-PROP:     %[[V:.*]] = "some_def"() : () -> vector<f32>
 //       CHECK-PROP:     vector.yield %[[V]] : vector<f32>
 //       CHECK-PROP:   }
 //       CHECK-PROP:   %[[E:.*]] = vector.extractelement %[[R]][] : vector<f32>
 //       CHECK-PROP:   return %[[E]] : f32
-func.func @vector_extractelement_simple(%laneid: index) -> (f32) {
+func.func @vector_extractelement_0d(%laneid: index) -> (f32) {
   %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
     %0 = "some_def"() : () -> (vector<f32>)
     %1 = vector.extractelement %0[] : vector<f32>
@@ -684,6 +684,32 @@ func.func @vector_extractelement_simple(%laneid: index) -> (f32) {
 
 // -----
 
+//       CHECK-PROP: #[[$map:.*]] = affine_map<()[s0] -> (s0 ceildiv 3)>
+//       CHECK-PROP: #[[$map1:.*]] = affine_map<()[s0] -> (s0 mod 3)>
+// CHECK-PROP-LABEL: func.func @vector_extractelement_1d(
+//  CHECK-PROP-SAME:     %[[LANEID:.*]]: index, %[[POS:.*]]: index
+//   CHECK-PROP-DAG:   %[[C32:.*]] = arith.constant 32 : i32
+//       CHECK-PROP:   %[[W:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<3xf32>) {
+//       CHECK-PROP:     %[[V:.*]] = "some_def"
+//       CHECK-PROP:     vector.yield %[[V]] : vector<96xf32>
+//       CHECK-PROP:   }
+//       CHECK-PROP:   %[[FROM_LANE:.*]] = affine.apply #[[$map]]()[%[[POS]]]
+//       CHECK-PROP:   %[[DISTR_POS:.*]] = affine.apply #[[$map1]]()[%[[POS]]]
+//       CHECK-PROP:   %[[EXTRACTED:.*]] = vector.extractelement %[[W]][%[[DISTR_POS]] : index] : vector<3xf32>
+//       CHECK-PROP:   %[[FROM_LANE_I32:.*]] = arith.index_cast %[[FROM_LANE]] : index to i32
+//       CHECK-PROP:   %[[SHUFFLED:.*]], %{{.*}} = gpu.shuffle  idx %[[EXTRACTED]], %[[FROM_LANE_I32]], %[[C32]] : f32
+//       CHECK-PROP:   return %[[SHUFFLED]]
+func.func @vector_extractelement_1d(%laneid: index, %pos: index) -> (f32) {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
+    %0 = "some_def"() : () -> (vector<96xf32>)
+    %1 = vector.extractelement %0[%pos : index] : vector<96xf32>
+    vector.yield %1 : f32
+  }
+  return %r : f32
+}
+
+// -----
+
 // CHECK-PROP:   func @lane_dependent_warp_propagate_read
 //  CHECK-PROP-SAME:   %[[ID:.*]]: index
 func.func @lane_dependent_warp_propagate_read(
index 4f44d43..6b9afe3 100644 (file)
@@ -759,6 +759,21 @@ struct TestVectorDistribution
         return AffineMap::get(val.getContext());
       return AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1));
     };
+    auto shuffleFn = [](Location loc, OpBuilder &builder, Value val,
+                        Value srcIdx, int64_t warpSz) {
+      assert((val.getType().isF32() || val.getType().isInteger(32)) &&
+             "unsupported shuffle type");
+      Type i32Type = builder.getIntegerType(32);
+      Value srcIdxI32 =
+          builder.create<arith::IndexCastOp>(loc, i32Type, srcIdx);
+      Value warpSzI32 = builder.create<arith::ConstantOp>(
+          loc, builder.getIntegerAttr(i32Type, warpSz));
+      Value result = builder
+                         .create<gpu::ShuffleOp>(loc, val, srcIdxI32, warpSzI32,
+                                                 gpu::ShuffleMode::IDX)
+                         .getResult(0);
+      return result;
+    };
     if (distributeTransferWriteOps) {
       RewritePatternSet patterns(ctx);
       populateDistributeTransferWriteOpPatterns(patterns, distributionFn);
@@ -766,8 +781,8 @@ struct TestVectorDistribution
     }
     if (propagateDistribution) {
       RewritePatternSet patterns(ctx);
-      vector::populatePropagateWarpVectorDistributionPatterns(patterns,
-                                                              distributionFn);
+      vector::populatePropagateWarpVectorDistributionPatterns(
+          patterns, distributionFn, shuffleFn);
       vector::populateDistributeReduction(patterns, warpReduction);
       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
     }