[mlir][vector] Support vector.extract distribution of >1D vectors
authorMatthias Springer <springerm@google.com>
Mon, 9 Jan 2023 15:35:29 +0000 (16:35 +0100)
committerMatthias Springer <springerm@google.com>
Mon, 9 Jan 2023 15:39:50 +0000 (16:39 +0100)
Ops such as `%1 = vector.extract %0[2] : vector<5x96xf32>`.

Distribute the source vector, then extract. In case of a 1d extract, rewrite to vector.extractelement.

Differential Revision: https://reviews.llvm.org/D137646

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
mlir/test/Dialect/Vector/vector-warp-distribute.mlir

index 08841e3..60ca036 100644 (file)
@@ -897,16 +897,81 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
       return failure();
     unsigned int operandNumber = operand->getOperandNumber();
     auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
-    if (extractOp.getVectorType().getNumElements() != 1)
-      return failure();
+    VectorType extractSrcType = extractOp.getVectorType();
     Location loc = extractOp.getLoc();
+
+    // "vector.extract %v[] : vector<f32>" is an invalid op.
+    assert(extractSrcType.getRank() > 0 &&
+           "vector.extract does not support rank 0 sources");
+
+    // "vector.extract %v[] : vector<...xf32>" can be canonicalized to %v.
+    if (extractOp.getPosition().empty())
+      return failure();
+
+    // Rewrite vector.extract with 1d source to vector.extractelement.
+    if (extractSrcType.getRank() == 1) {
+      assert(extractOp.getPosition().size() == 1 && "expected 1 index");
+      int64_t pos = extractOp.getPosition()[0].cast<IntegerAttr>().getInt();
+      rewriter.setInsertionPoint(extractOp);
+      rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(
+          extractOp, extractOp.getVector(),
+          rewriter.create<arith::ConstantIndexOp>(loc, pos));
+      return success();
+    }
+
+    // All following cases are 2d or higher dimensional source vectors.
+
+    if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
+      // There is no distribution, this is a broadcast. Simply move the extract
+      // out of the warp op.
+      // TODO: This could be optimized. E.g., in case of a scalar result, let
+      // one lane extract and shuffle the result to all other lanes (same as
+      // the 1d case).
+      SmallVector<size_t> newRetIndices;
+      WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+          rewriter, warpOp, {extractOp.getVector()},
+          {extractOp.getVectorType()}, newRetIndices);
+      rewriter.setInsertionPointAfter(newWarpOp);
+      Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
+      // Extract from distributed vector.
+      Value newExtract = rewriter.create<vector::ExtractOp>(
+          loc, distributedVec, extractOp.getPosition());
+      newWarpOp->getResult(operandNumber).replaceAllUsesWith(newExtract);
+      return success();
+    }
+
+    // Find the distributed dimension. There should be exactly one.
+    auto distributedType =
+        warpOp.getResult(operandNumber).getType().cast<VectorType>();
+    auto yieldedType = operand->get().getType().cast<VectorType>();
+    int64_t distributedDim = -1;
+    for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
+      if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) {
+        // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
+        // support distributing multiple dimensions in the future.
+        assert(distributedDim == -1 && "found multiple distributed dims");
+        distributedDim = i;
+      }
+    }
+    assert(distributedDim != -1 && "could not find distributed dimension");
+
+    // Yield source vector from warp op.
+    SmallVector<int64_t> newDistributedShape(extractSrcType.getShape().begin(),
+                                             extractSrcType.getShape().end());
+    for (int i = 0; i < distributedType.getRank(); ++i)
+      newDistributedShape[i + extractOp.getPosition().size()] =
+          distributedType.getDimSize(i);
+    auto newDistributedType =
+        VectorType::get(newDistributedShape, distributedType.getElementType());
     SmallVector<size_t> newRetIndices;
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, {extractOp.getVector()}, {extractOp.getVectorType()},
+        rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
         newRetIndices);
     rewriter.setInsertionPointAfter(newWarpOp);
+    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
+    // Extract from distributed vector.
     Value newExtract = rewriter.create<vector::ExtractOp>(
-        loc, newWarpOp->getResult(newRetIndices[0]), extractOp.getPosition());
+        loc, distributedVec, extractOp.getPosition());
     newWarpOp->getResult(operandNumber).replaceAllUsesWith(newExtract);
     return success();
   }
index 3489054..5a238c5 100644 (file)
@@ -648,17 +648,58 @@ func.func @warp_constant(%laneid: index) -> (vector<1xf32>) {
 
 // -----
 
-// CHECK-PROP-LABEL: func.func @vector_extract_simple(
-//       CHECK-PROP:   %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>) {
-//       CHECK-PROP:     %[[V:.*]] = "some_def"() : () -> vector<1xf32>
-//       CHECK-PROP:     vector.yield %[[V]] : vector<1xf32>
+// TODO: We could use warp shuffles instead of broadcasting the entire vector.
+
+// CHECK-PROP-LABEL: func.func @vector_extract_1d(
+//   CHECK-PROP-DAG:   %[[C5_I32:.*]] = arith.constant 5 : i32
+//   CHECK-PROP-DAG:   %[[C1:.*]] = arith.constant 1 : index
+//       CHECK-PROP:   %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<2xf32>) {
+//       CHECK-PROP:     %[[V:.*]] = "some_def"() : () -> vector<64xf32>
+//       CHECK-PROP:     vector.yield %[[V]] : vector<64xf32>
 //       CHECK-PROP:   }
-//       CHECK-PROP:   %[[E:.*]] = vector.extract %[[R]][0] : vector<1xf32>
-//       CHECK-PROP:   return %[[E]] : f32
-func.func @vector_extract_simple(%laneid: index) -> (f32) {
+//       CHECK-PROP:   %[[E:.*]] = vector.extractelement %[[R]][%[[C1]] : index] : vector<2xf32>
+//       CHECK-PROP:   %[[SHUFFLED:.*]], %{{.*}} = gpu.shuffle  idx %[[E]], %[[C5_I32]]
+//       CHECK-PROP:   return %[[SHUFFLED]] : f32
+func.func @vector_extract_1d(%laneid: index) -> (f32) {
   %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
-    %0 = "some_def"() : () -> (vector<1xf32>)
-    %1 = vector.extract %0[0] : vector<1xf32>
+    %0 = "some_def"() : () -> (vector<64xf32>)
+    %1 = vector.extract %0[9] : vector<64xf32>
+    vector.yield %1 : f32
+  }
+  return %r : f32
+}
+
+// -----
+
+// CHECK-PROP-LABEL: func.func @vector_extract_2d(
+//       CHECK-PROP:   %[[W:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<5x3xf32>) {
+//       CHECK-PROP:     %[[V:.*]] = "some_def"
+//       CHECK-PROP:     vector.yield %[[V]] : vector<5x96xf32>
+//       CHECK-PROP:   }
+//       CHECK-PROP:   %[[E:.*]] = vector.extract %[[W]][2] : vector<5x3xf32>
+//       CHECK-PROP:   return %[[E]]
+func.func @vector_extract_2d(%laneid: index) -> (vector<3xf32>) {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<3xf32>) {
+    %0 = "some_def"() : () -> (vector<5x96xf32>)
+    %1 = vector.extract %0[2] : vector<5x96xf32>
+    vector.yield %1 : vector<96xf32>
+  }
+  return %r : vector<3xf32>
+}
+
+// -----
+
+// CHECK-PROP-LABEL: func.func @vector_extract_2d_broadcast_scalar(
+//       CHECK-PROP:   %[[W:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<5x96xf32>) {
+//       CHECK-PROP:     %[[V:.*]] = "some_def"
+//       CHECK-PROP:     vector.yield %[[V]] : vector<5x96xf32>
+//       CHECK-PROP:   }
+//       CHECK-PROP:   %[[E:.*]] = vector.extract %[[W]][1, 2] : vector<5x96xf32>
+//       CHECK-PROP:   return %[[E]]
+func.func @vector_extract_2d_broadcast_scalar(%laneid: index) -> (f32) {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
+    %0 = "some_def"() : () -> (vector<5x96xf32>)
+    %1 = vector.extract %0[1, 2] : vector<5x96xf32>
     vector.yield %1 : f32
   }
   return %r : f32
@@ -666,6 +707,42 @@ func.func @vector_extract_simple(%laneid: index) -> (f32) {
 
 // -----
 
+// CHECK-PROP-LABEL: func.func @vector_extract_2d_broadcast(
+//       CHECK-PROP:   %[[W:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<5x96xf32>) {
+//       CHECK-PROP:     %[[V:.*]] = "some_def"
+//       CHECK-PROP:     vector.yield %[[V]] : vector<5x96xf32>
+//       CHECK-PROP:   }
+//       CHECK-PROP:   %[[E:.*]] = vector.extract %[[W]][2] : vector<5x96xf32>
+//       CHECK-PROP:   return %[[E]]
+func.func @vector_extract_2d_broadcast(%laneid: index) -> (vector<96xf32>) {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<96xf32>) {
+    %0 = "some_def"() : () -> (vector<5x96xf32>)
+    %1 = vector.extract %0[2] : vector<5x96xf32>
+    vector.yield %1 : vector<96xf32>
+  }
+  return %r : vector<96xf32>
+}
+
+// -----
+
+// CHECK-PROP-LABEL: func.func @vector_extract_3d(
+//       CHECK-PROP:   %[[W:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<8x4x96xf32>) {
+//       CHECK-PROP:     %[[V:.*]] = "some_def"
+//       CHECK-PROP:     vector.yield %[[V]] : vector<8x128x96xf32>
+//       CHECK-PROP:   }
+//       CHECK-PROP:   %[[E:.*]] = vector.extract %[[W]][2] : vector<8x4x96xf32>
+//       CHECK-PROP:   return %[[E]]
+func.func @vector_extract_3d(%laneid: index) -> (vector<4x96xf32>) {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<4x96xf32>) {
+    %0 = "some_def"() : () -> (vector<8x128x96xf32>)
+    %1 = vector.extract %0[2] : vector<8x128x96xf32>
+    vector.yield %1 : vector<128x96xf32>
+  }
+  return %r : vector<4x96xf32>
+}
+
+// -----
+
 // CHECK-PROP-LABEL: func.func @vector_extractelement_0d(
 //       CHECK-PROP:   %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<f32>) {
 //       CHECK-PROP:     %[[V:.*]] = "some_def"() : () -> vector<f32>