[mlir][vector] Reorder elementwise(transpose)
authorLei Zhang <antiagainst@google.com>
Fri, 15 Apr 2022 12:57:24 +0000 (08:57 -0400)
committerLei Zhang <antiagainst@google.com>
Fri, 15 Apr 2022 13:05:35 +0000 (09:05 -0400)
Similar to the existing pattern for reodering cast(transpose),
this makes transpose following transpose and increases the chance
of embedding the transposition inside contraction op. Actually
cast ops are just special instances of elementwise ops.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D123596

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir

index 3312095..50b89c1 100644 (file)
@@ -23,6 +23,7 @@
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
@@ -1048,43 +1049,86 @@ struct ReorderCastOpsOnBroadcast
   }
 };
 
-/// Reorders cast(transpose) to transpose(cast). This makes broadcast ops and
-/// contraction ops closer, which kicks in CombineContractTranspose pattern when
-/// casting ops are around these operations.
-/// Ex:
+/// Reorders elementwise(transpose) to transpose(elementwise). This makes
+/// transpose ops and contraction ops closer, which kicks in
+/// CombineContractTranspose pattern when elementwise ops are between these
+/// operations. Ex:
 /// ```
-///   %0 = vector.transpose %arg0, [2, 0, 1]
-///     : vector<32x16x8xi8> to vector<8x32x16xi8>
-///   %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
+/// %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
+/// %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
+/// %r = arith.addf %at, %bt : vector<2x4xf32>
 /// ```
 /// Gets converted to:
 /// ```
-///   %0 = arith.extsi %0 : vector<32x16x8xi8> to vector<32x16x8xi32>
-///   %1 = vector.transpose %arg0, [2, 0, 1]
-///     : vector<32x16x8xi32> to vector<8x32x16xi32>
+/// %0 = arith.addf %a, %b : vector<4x2xf32>
+/// %r = vector.transpose %0, [1, 0] : vector<2x4xf32>
 /// ```
-struct ReorderCastOpsOnTranspose
-    : public OpInterfaceRewritePattern<CastOpInterface> {
-
-  using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;
-
-  LogicalResult matchAndRewrite(CastOpInterface op,
+struct ReorderElementwiseOpsOnTranspose final
+    : public OpTraitRewritePattern<OpTrait::Elementwise> {
+  using OpTraitRewritePattern::OpTraitRewritePattern;
+  LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
-    if (op->getNumOperands() != 1)
+    if (op->getNumResults() != 1 || op->getNumRegions() != 0)
       return failure();
-    auto transpOp = op->getOperand(0).getDefiningOp<vector::TransposeOp>();
-    if (!transpOp)
+
+    // Make sure all operands are transpose/constant ops and collect their
+    // transposition maps.
+    SmallVector<ArrayAttr, 4> transposeMaps;
+    transposeMaps.reserve(op->getNumOperands());
+    // Record the initial type before transposition. We'll use its shape later.
+    // Any type will do here as we will check all transpose maps are the same.
+    VectorType srcType;
+    for (Value operand : op->getOperands()) {
+      auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
+      if (transposeOp) {
+        transposeMaps.push_back(transposeOp.getTransp());
+        srcType = transposeOp.getVectorType();
+      } else if (!matchPattern(operand, m_Constant())) {
+        return failure();
+      }
+    }
+    if (transposeMaps.empty())
       return failure();
+    // This is an elementwise op, so all transposed operands should have the
+    // same type. We need to additionally check that all transposes uses the
+    // same map.
+    if (!llvm::is_splat(transposeMaps))
+      return rewriter.notifyMatchFailure(op, "different transpose map");
+
+    SmallVector<Value, 4> srcValues;
+    srcValues.reserve(op->getNumOperands());
+
+    // If there are constant operands, we need to insert inverse transposes for
+    // them. Calculate the inverse order first.
+    auto order = extractVector<unsigned>(transposeMaps.front());
+    SmallVector<int64_t> invOrder(order.size());
+    for (int i = 0, e = order.size(); i < e; ++i)
+      invOrder[order[i]] = i;
+
+    for (Value operand : op->getOperands()) {
+      auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
+      if (transposeOp) {
+        srcValues.push_back(transposeOp.getVector());
+      } else {
+        // This is a constant. Create a reverse transpose op for it.
+        auto vectorType = VectorType::get(
+            srcType.getShape(),
+            operand.getType().cast<VectorType>().getElementType());
+        srcValues.push_back(rewriter.create<vector::TransposeOp>(
+            operand.getLoc(), vectorType, operand,
+            rewriter.getI64ArrayAttr(invOrder)));
+      }
+    }
 
-    auto castResTy = transpOp.getVectorType();
-    castResTy = VectorType::get(castResTy.getShape(),
-                                getElementTypeOrSelf(op->getResult(0)));
-    auto *castOp =
-        rewriter.create(op->getLoc(), op->getName().getIdentifier(),
-                        transpOp.getVector(), castResTy, op->getAttrs());
+    auto vectorType = VectorType::get(
+        srcType.getShape(),
+        op->getResultTypes()[0].cast<VectorType>().getElementType());
+    Operation *elementwiseOp =
+        rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
+                        vectorType, op->getAttrs());
     rewriter.replaceOpWithNewOp<vector::TransposeOp>(
-        op, op->getResult(0).getType(), castOp->getResult(0),
-        transpOp.getTransp());
+        op, op->getResultTypes()[0], elementwiseOp->getResult(0),
+        transposeMaps.front());
     return success();
   }
 };
@@ -2647,7 +2691,7 @@ void mlir::vector::populateVectorReductionToContractPatterns(
     RewritePatternSet &patterns) {
   patterns.add<MultiReduceToContract, CombineContractBroadcast,
                CombineContractTranspose, ReorderCastOpsOnBroadcast,
-               ReorderCastOpsOnTranspose>(patterns.getContext());
+               ReorderElementwiseOpsOnTranspose>(patterns.getContext());
 }
 
 void mlir::vector::
index 1167f1e..b17771a 100644 (file)
@@ -120,3 +120,80 @@ func @transpose_extsi(%a : vector<4x2xi8>) -> vector<2x4xi32> {
   %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
   return %r : vector<2x4xi32>
 }
+
+//===----------------------------------------------------------------------===//
+// Reorder elementwise ops and vector ops.
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: func @transpose_elementwise_same_type
+//  CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
+//       CHECK:   %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x2xf32>
+//       CHECK:   %[[T:.+]] = vector.transpose %[[ADD]], [1, 0]
+//       CHECK:   return %[[T]]
+
+func @transpose_elementwise_same_type(%a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xf32> {
+  %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
+  %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
+  %r = arith.addf %at, %bt : vector<2x4xf32>
+  return %r : vector<2x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transpose_elementwise_diff_operand_types
+//  CHECK-SAME: (%[[COND:.+]]: vector<4x2xi1>, %[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
+//       CHECK:   %[[S:.+]] = arith.select %[[COND]], %[[A]], %[[B]] : vector<4x2xi1>, vector<4x2xf32>
+//       CHECK:   %[[T:.+]] = vector.transpose %[[S]], [1, 0] : vector<4x2xf32> to vector<2x4xf32>
+//       CHECK:   return %[[T]]
+func @transpose_elementwise_diff_operand_types(%cond: vector<4x2xi1>, %a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xf32> {
+  %condt = vector.transpose %cond, [1, 0]: vector<4x2xi1> to vector<2x4xi1>
+  %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
+  %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
+  %r = arith.select %condt, %at, %bt : vector<2x4xi1>, vector<2x4xf32>
+  return %r : vector<2x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transpose_elementwise_diff_operand_result_type
+//  CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
+//       CHECK:   %[[CMP:.+]] = arith.cmpf olt, %[[A]], %[[B]] : vector<4x2xf32>
+//       CHECK:   %[[T:.+]] = vector.transpose %[[CMP]], [1, 0] : vector<4x2xi1> to vector<2x4xi1>
+//       CHECK:   return %[[T]]
+func @transpose_elementwise_diff_operand_result_type(%a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xi1> {
+  %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
+  %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
+  %r = arith.cmpf olt, %at, %bt : vector<2x4xf32>
+  return %r : vector<2x4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: func @transpose_elementwise_splat_constant
+//  CHECK-SAME: (%[[A:.+]]: vector<4x6x3x2xf32>)
+//       CHECK:   %[[B:.+]] = arith.constant dense<5.000000e+00> : vector<4x6x3x2xf32>
+//       CHECK:   %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x6x3x2xf32>
+//       CHECK:   %[[T:.+]] = vector.transpose %[[ADD]], [1, 0, 3, 2] : vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
+//       CHECK:   return %[[T:.+]] : vector<6x4x2x3xf32>
+
+func @transpose_elementwise_splat_constant(%a : vector<4x6x3x2xf32>) -> vector<6x4x2x3xf32> {
+  %b = arith.constant dense<5.0> : vector<6x4x2x3xf32>
+  %at = vector.transpose %a, [1, 0, 3, 2]: vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
+  %r = arith.addf %at, %b : vector<6x4x2x3xf32>
+  return %r : vector<6x4x2x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transpose_elementwise_diff_map
+//       CHECK:   vector.transpose
+//       CHECK:   vector.transpose
+//       CHECK:   arith.addf
+func @transpose_elementwise_diff_map(%a : vector<4x6x3x2xf32>, %b: vector<6x2x4x3xf32>) -> vector<6x4x2x3xf32> {
+  %at = vector.transpose %a, [1, 0, 3, 2]: vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
+  %bt = vector.transpose %b, [0, 2, 1, 3]: vector<6x2x4x3xf32> to vector<6x4x2x3xf32>
+  %r = arith.addf %at, %bt : vector<6x4x2x3xf32>
+  return %r : vector<6x4x2x3xf32>
+}