[mlir][vector] Distribute vector.insertelement op
authorMatthias Springer <springerm@google.com>
Mon, 9 Jan 2023 15:40:32 +0000 (16:40 +0100)
committerMatthias Springer <springerm@google.com>
Mon, 9 Jan 2023 15:41:08 +0000 (16:41 +0100)
In case of a distribution, only one lane inserts the scalar value. In case of a broadcast, every lane inserts the scalar.

Differential Revision: https://reviews.llvm.org/D137929

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
mlir/test/Dialect/Vector/vector-warp-distribute.mlir

index 60ca036..df7b240 100644 (file)
@@ -1033,8 +1033,13 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
     Value broadcastFromTid = rewriter.create<AffineApplyOp>(
         loc, sym0.ceilDiv(elementsPerLane), extractOp.getPosition());
     // Extract at position: pos % elementsPerLane
-    Value pos = rewriter.create<AffineApplyOp>(loc, sym0 % elementsPerLane,
-                                               extractOp.getPosition());
+    Value pos =
+        elementsPerLane == 1
+            ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
+            : rewriter
+                  .create<AffineApplyOp>(loc, sym0 % elementsPerLane,
+                                         extractOp.getPosition())
+                  .getResult();
     Value extracted =
         rewriter.create<vector::ExtractElementOp>(loc, distributedVec, pos);
 
@@ -1049,6 +1054,85 @@ private:
   WarpShuffleFromIdxFn warpShuffleFromIdxFn;
 };
 
+struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
+  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand = getWarpResult(
+        warpOp, [](Operation *op) { return isa<vector::InsertElementOp>(op); });
+    if (!operand)
+      return failure();
+    unsigned int operandNumber = operand->getOperandNumber();
+    auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>();
+    VectorType vecType = insertOp.getDestVectorType();
+    VectorType distrType =
+        warpOp.getResult(operandNumber).getType().cast<VectorType>();
+    bool hasPos = static_cast<bool>(insertOp.getPosition());
+
+    // Yield destination vector, source scalar and position from warp op.
+    SmallVector<Value> additionalResults{insertOp.getDest(),
+                                         insertOp.getSource()};
+    SmallVector<Type> additionalResultTypes{distrType,
+                                            insertOp.getSource().getType()};
+    if (hasPos) {
+      additionalResults.push_back(insertOp.getPosition());
+      additionalResultTypes.push_back(insertOp.getPosition().getType());
+    }
+    Location loc = insertOp.getLoc();
+    SmallVector<size_t> newRetIndices;
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, additionalResults, additionalResultTypes,
+        newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
+    Value newSource = newWarpOp->getResult(newRetIndices[1]);
+    Value newPos = hasPos ? newWarpOp->getResult(newRetIndices[2]) : Value();
+    rewriter.setInsertionPointAfter(newWarpOp);
+
+    if (vecType == distrType) {
+      // Broadcast: Simply move the vector.inserelement op out.
+      Value newInsert = rewriter.create<vector::InsertElementOp>(
+          loc, newSource, distributedVec, newPos);
+      newWarpOp->getResult(operandNumber).replaceAllUsesWith(newInsert);
+      return success();
+    }
+
+    // This is a distribution. Only one lane should insert.
+    int64_t elementsPerLane = distrType.getShape()[0];
+    AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
+    // tid of extracting thread: pos / elementsPerLane
+    Value insertingLane = rewriter.create<AffineApplyOp>(
+        loc, sym0.ceilDiv(elementsPerLane), newPos);
+    // Insert position: pos % elementsPerLane
+    Value pos =
+        elementsPerLane == 1
+            ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
+            : rewriter
+                  .create<AffineApplyOp>(loc, sym0 % elementsPerLane, newPos)
+                  .getResult();
+    Value isInsertingLane = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
+    Value newResult =
+        rewriter
+            .create<scf::IfOp>(
+                loc, distrType, isInsertingLane,
+                /*thenBuilder=*/
+                [&](OpBuilder &builder, Location loc) {
+                  Value newInsert = builder.create<vector::InsertElementOp>(
+                      loc, newSource, distributedVec, pos);
+                  builder.create<scf::YieldOp>(loc, newInsert);
+                },
+                /*elseBuilder=*/
+                [&](OpBuilder &builder, Location loc) {
+                  builder.create<scf::YieldOp>(loc, distributedVec);
+                })
+            .getResult(0);
+    newWarpOp->getResult(operandNumber).replaceAllUsesWith(newResult);
+    return success();
+  }
+};
+
 /// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
 /// the scf.ForOp is the last operation in the region so that it doesn't change
 /// the order of execution. This creates a new scf.for region after the
@@ -1303,7 +1387,8 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
     const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit) {
   patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
                WarpOpBroadcast, WarpOpExtract, WarpOpForwardOperand,
-               WarpOpConstant>(patterns.getContext(), benefit);
+               WarpOpConstant, WarpOpInsertElement>(patterns.getContext(),
+                                                    benefit);
   patterns.add<WarpOpExtractElement>(patterns.getContext(),
                                      warpShuffleFromIdxFn, benefit);
   patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
index 5a238c5..b19c3cd 100644 (file)
@@ -930,3 +930,67 @@ func.func @warp_execute_nd_distribute(%laneid: index, %v0: vector<1x64x1xf32>, %
   //     CHECK-SCF-IF: return %[[R0]], %[[R1]] : vector<1x64x1xf32>, vector<1x2x128xf32>
   return %r#0, %r#1 : vector<1x64x1xf32>, vector<1x2x128xf32>
 }
+
+// -----
+
+//       CHECK-PROP:   #[[$MAP:.*]] = affine_map<()[s0] -> (s0 ceildiv 3)>
+//       CHECK-PROP:   #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 3)>
+// CHECK-PROP-LABEL: func @vector_insertelement_1d(
+//  CHECK-PROP-SAME:     %[[LANEID:.*]]: index, %[[POS:.*]]: index
+//       CHECK-PROP:   %[[W:.*]]:2 = vector.warp_execute_on_lane_0{{.*}} -> (vector<3xf32>, f32)
+//       CHECK-PROP:   %[[INSERTING_LANE:.*]] = affine.apply #[[$MAP]]()[%[[POS]]]
+//       CHECK-PROP:   %[[INSERTING_POS:.*]] = affine.apply #[[$MAP1]]()[%[[POS]]]
+//       CHECK-PROP:   %[[SHOULD_INSERT:.*]] = arith.cmpi eq, %[[LANEID]], %[[INSERTING_LANE]] : index
+//       CHECK-PROP:   %[[R:.*]] = scf.if %[[SHOULD_INSERT]] -> (vector<3xf32>) {
+//       CHECK-PROP:     %[[INSERT:.*]] = vector.insertelement %[[W]]#1, %[[W]]#0[%[[INSERTING_POS]] : index]
+//       CHECK-PROP:     scf.yield %[[INSERT]]
+//       CHECK-PROP:   } else {
+//       CHECK-PROP:     scf.yield %[[W]]#0
+//       CHECK-PROP:   }
+//       CHECK-PROP:   return %[[R]]
+func.func @vector_insertelement_1d(%laneid: index, %pos: index) -> (vector<3xf32>) {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<3xf32>) {
+    %0 = "some_def"() : () -> (vector<96xf32>)
+    %f = "another_def"() : () -> (f32)
+    %1 = vector.insertelement %f, %0[%pos : index] : vector<96xf32>
+    vector.yield %1 : vector<96xf32>
+  }
+  return %r : vector<3xf32>
+}
+
+// -----
+
+// CHECK-PROP-LABEL: func @vector_insertelement_1d_broadcast(
+//  CHECK-PROP-SAME:     %[[LANEID:.*]]: index, %[[POS:.*]]: index
+//       CHECK-PROP:   %[[W:.*]]:2 = vector.warp_execute_on_lane_0{{.*}} -> (vector<96xf32>, f32)
+//       CHECK-PROP:     %[[VEC:.*]] = "some_def"
+//       CHECK-PROP:     %[[VAL:.*]] = "another_def"
+//       CHECK-PROP:     vector.yield %[[VEC]], %[[VAL]]
+//       CHECK-PROP:   vector.insertelement %[[W]]#1, %[[W]]#0[%[[POS]] : index] : vector<96xf32>
+func.func @vector_insertelement_1d_broadcast(%laneid: index, %pos: index) -> (vector<96xf32>) {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<96xf32>) {
+    %0 = "some_def"() : () -> (vector<96xf32>)
+    %f = "another_def"() : () -> (f32)
+    %1 = vector.insertelement %f, %0[%pos : index] : vector<96xf32>
+    vector.yield %1 : vector<96xf32>
+  }
+  return %r : vector<96xf32>
+}
+
+// -----
+
+// CHECK-PROP-LABEL: func @vector_insertelement_0d(
+//       CHECK-PROP:   %[[W:.*]]:2 = vector.warp_execute_on_lane_0{{.*}} -> (vector<f32>, f32)
+//       CHECK-PROP:     %[[VEC:.*]] = "some_def"
+//       CHECK-PROP:     %[[VAL:.*]] = "another_def"
+//       CHECK-PROP:     vector.yield %[[VEC]], %[[VAL]]
+//       CHECK-PROP:   vector.insertelement %[[W]]#1, %[[W]]#0[] : vector<f32>
+func.func @vector_insertelement_0d(%laneid: index) -> (vector<f32>) {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<f32>) {
+    %0 = "some_def"() : () -> (vector<f32>)
+    %f = "another_def"() : () -> (f32)
+    %1 = vector.insertelement %f, %0[] : vector<f32>
+    vector.yield %1 : vector<f32>
+  }
+  return %r : vector<f32>
+}