[mlir][Vector] Add patterns to reorder elementwise ops and broadcast/transpose ops.
authorHanhan Wang <hanchung@google.com>
Mon, 7 Mar 2022 20:52:03 +0000 (12:52 -0800)
committerHanhan Wang <hanchung@google.com>
Mon, 7 Mar 2022 20:52:12 +0000 (12:52 -0800)
In quantized comutation, there are casting ops around computation ops.
Reorder the ops to make reduce-to-contract actually work.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D120760

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir

index 4421b55..ee6429d 100644 (file)
@@ -1009,6 +1009,84 @@ struct CombineContractBroadcast
   }
 };
 
+/// Reorders cast(broadcast) to broadcast(cast). This makes broadcast ops and
+/// contraction ops closer, which kicks in CombineContractBroadcast pattern when
+/// casting ops are around these operations.
+/// Ex:
+/// ```
+///   %0 = vector.broadcast %arg0 : vector<32x16xi8> to vector<8x32x16xi8>
+///   %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
+/// ```
+/// Gets converted to:
+/// ```
+///   %0 = arith.extsi %0 : vector<32x16xi8> to vector<32x16xi32>
+///   %1 = vector.broadcast %arg0 : vector<32x16xi32> to vector<8x32x16xi32>
+/// ```
+struct ReorderCastOpsOnBroadcast
+    : public OpInterfaceRewritePattern<CastOpInterface> {
+  using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;
+
+  LogicalResult matchAndRewrite(CastOpInterface op,
+                                PatternRewriter &rewriter) const override {
+    if (op->getNumOperands() != 1)
+      return failure();
+    auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
+    if (!bcastOp)
+      return failure();
+
+    Type castResTy = getElementTypeOrSelf(op->getResult(0));
+    if (auto vecTy = bcastOp.getSourceType().dyn_cast<VectorType>())
+      castResTy = VectorType::get(vecTy.getShape(), castResTy);
+    OperationState state(op->getLoc(), op->getName(), bcastOp.source(),
+                         castResTy, op->getAttrs());
+    auto castOp = rewriter.createOperation(state);
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+        op, op->getResult(0).getType(), castOp->getResult(0));
+    return success();
+  }
+};
+
+/// Reorders cast(transpose) to transpose(cast). This makes broadcast ops and
+/// contraction ops closer, which kicks in CombineContractTranspose pattern when
+/// casting ops are around these operations.
+/// Ex:
+/// ```
+///   %0 = vector.transpose %arg0, [2, 0, 1]
+///     : vector<32x16x8xi8> to vector<8x32x16xi8>
+///   %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
+/// ```
+/// Gets converted to:
+/// ```
+///   %0 = arith.extsi %0 : vector<32x16x8xi8> to vector<32x16x8xi32>
+///   %1 = vector.transpose %arg0, [2, 0, 1]
+///     : vector<32x16x8xi32> to vector<8x32x16xi32>
+/// ```
+struct ReorderCastOpsOnTranspose
+    : public OpInterfaceRewritePattern<CastOpInterface> {
+
+  using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;
+
+  LogicalResult matchAndRewrite(CastOpInterface op,
+                                PatternRewriter &rewriter) const override {
+    if (op->getNumOperands() != 1)
+      return failure();
+    auto transpOp = op->getOperand(0).getDefiningOp<vector::TransposeOp>();
+    if (!transpOp)
+      return failure();
+
+    auto castResTy = transpOp.getVectorType();
+    castResTy = VectorType::get(castResTy.getShape(),
+                                getElementTypeOrSelf(op->getResult(0)));
+    OperationState state(op->getLoc(), op->getName(), transpOp.vector(),
+                         castResTy, op->getAttrs());
+    auto castOp = rewriter.createOperation(state);
+    rewriter.replaceOpWithNewOp<vector::TransposeOp>(
+        op, op->getResult(0).getType(), castOp->getResult(0),
+        transpOp.getTransp());
+    return success();
+  }
+};
+
 } // namespace
 
 /// Creates an AddIOp if `isInt` is true otherwise create an arith::AddFOp using
@@ -2585,7 +2663,8 @@ void mlir::vector::populateVectorTransposeLoweringPatterns(
 void mlir::vector::populateVectorReductionToContractPatterns(
     RewritePatternSet &patterns) {
   patterns.add<MultiReduceToContract, CombineContractBroadcast,
-               CombineContractTranspose>(patterns.getContext());
+               CombineContractTranspose, ReorderCastOpsOnBroadcast,
+               ReorderCastOpsOnTranspose>(patterns.getContext());
 }
 
 void mlir::vector::
index 85389b2..1167f1e 100644 (file)
@@ -85,3 +85,38 @@ func @contract_broadcast(
     kind = #vector.kind<add>} %0, %arg1, %cst : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
   return %1 : vector<8x32xf32>
 }
+
+//===----------------------------------------------------------------------===//
+// Reorder casting ops and vector ops. The casting ops have almost identical
+// pattern, so only arith.extsi op is tested.
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func @broadcast_vector_extsi(%a : vector<4xi8>) -> vector<2x4xi32> {
+  // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4xi8> to vector<4xi32>
+  // CHECK: vector.broadcast %[[EXT:.+]] : vector<4xi32> to vector<2x4xi32>
+  %b = vector.broadcast %a : vector<4xi8> to vector<2x4xi8>
+  %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
+  return %r : vector<2x4xi32>
+}
+
+// -----
+
+func @broadcast_scalar_extsi(%a : i8) -> vector<2x4xi32> {
+  // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : i8 to i32
+  // CHECK: vector.broadcast %[[EXT]] : i32 to vector<2x4xi32>
+  %b = vector.broadcast %a : i8 to vector<2x4xi8>
+  %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
+  return %r : vector<2x4xi32>
+}
+
+// -----
+
+func @transpose_extsi(%a : vector<4x2xi8>) -> vector<2x4xi32> {
+  // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4x2xi8> to vector<4x2xi32>
+  // CHECK: vector.transpose %[[EXT]], [1, 0] : vector<4x2xi32> to vector<2x4xi32>
+  %b = vector.transpose %a, [1, 0]: vector<4x2xi8> to vector<2x4xi8>
+  %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
+  return %r : vector<2x4xi32>
+}