#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
}
};
-/// Reorders cast(transpose) to transpose(cast). This makes broadcast ops and
-/// contraction ops closer, which kicks in CombineContractTranspose pattern when
-/// casting ops are around these operations.
-/// Ex:
+/// Reorders elementwise(transpose) to transpose(elementwise). This makes
+/// transpose ops and contraction ops closer, which kicks in
+/// CombineContractTranspose pattern when elementwise ops are between these
+/// operations. Ex:
/// ```
-/// %0 = vector.transpose %arg0, [2, 0, 1]
-/// : vector<32x16x8xi8> to vector<8x32x16xi8>
-/// %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
+/// %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
+/// %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
+/// %r = arith.addf %at, %bt : vector<2x4xf32>
/// ```
/// Gets converted to:
/// ```
-/// %0 = arith.extsi %0 : vector<32x16x8xi8> to vector<32x16x8xi32>
-/// %1 = vector.transpose %arg0, [2, 0, 1]
-/// : vector<32x16x8xi32> to vector<8x32x16xi32>
+/// %0 = arith.addf %a, %b : vector<4x2xf32>
+/// %r = vector.transpose %0, [1, 0] : vector<2x4xf32>
/// ```
-struct ReorderCastOpsOnTranspose
- : public OpInterfaceRewritePattern<CastOpInterface> {
-
- using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;
-
- LogicalResult matchAndRewrite(CastOpInterface op,
+struct ReorderElementwiseOpsOnTranspose final
+ : public OpTraitRewritePattern<OpTrait::Elementwise> {
+ using OpTraitRewritePattern::OpTraitRewritePattern;
+ LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
- if (op->getNumOperands() != 1)
+ if (op->getNumResults() != 1 || op->getNumRegions() != 0)
return failure();
- auto transpOp = op->getOperand(0).getDefiningOp<vector::TransposeOp>();
- if (!transpOp)
+
+ // Make sure all operands are transpose/constant ops and collect their
+ // transposition maps.
+ SmallVector<ArrayAttr, 4> transposeMaps;
+ transposeMaps.reserve(op->getNumOperands());
+ // Record the initial type before transposition. We'll use its shape later.
+ // Any type will do here as we will check all transpose maps are the same.
+ VectorType srcType;
+ for (Value operand : op->getOperands()) {
+ auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
+ if (transposeOp) {
+ transposeMaps.push_back(transposeOp.getTransp());
+ srcType = transposeOp.getVectorType();
+ } else if (!matchPattern(operand, m_Constant())) {
+ return failure();
+ }
+ }
+ if (transposeMaps.empty())
return failure();
+ // This is an elementwise op, so all transposed operands should have the
+ // same type. We need to additionally check that all transposes uses the
+ // same map.
+ if (!llvm::is_splat(transposeMaps))
+ return rewriter.notifyMatchFailure(op, "different transpose map");
+
+ SmallVector<Value, 4> srcValues;
+ srcValues.reserve(op->getNumOperands());
+
+ // If there are constant operands, we need to insert inverse transposes for
+ // them. Calculate the inverse order first.
+ auto order = extractVector<unsigned>(transposeMaps.front());
+ SmallVector<int64_t> invOrder(order.size());
+ for (int i = 0, e = order.size(); i < e; ++i)
+ invOrder[order[i]] = i;
+
+ for (Value operand : op->getOperands()) {
+ auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
+ if (transposeOp) {
+ srcValues.push_back(transposeOp.getVector());
+ } else {
+ // This is a constant. Create a reverse transpose op for it.
+ auto vectorType = VectorType::get(
+ srcType.getShape(),
+ operand.getType().cast<VectorType>().getElementType());
+ srcValues.push_back(rewriter.create<vector::TransposeOp>(
+ operand.getLoc(), vectorType, operand,
+ rewriter.getI64ArrayAttr(invOrder)));
+ }
+ }
- auto castResTy = transpOp.getVectorType();
- castResTy = VectorType::get(castResTy.getShape(),
- getElementTypeOrSelf(op->getResult(0)));
- auto *castOp =
- rewriter.create(op->getLoc(), op->getName().getIdentifier(),
- transpOp.getVector(), castResTy, op->getAttrs());
+ auto vectorType = VectorType::get(
+ srcType.getShape(),
+ op->getResultTypes()[0].cast<VectorType>().getElementType());
+ Operation *elementwiseOp =
+ rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
+ vectorType, op->getAttrs());
rewriter.replaceOpWithNewOp<vector::TransposeOp>(
- op, op->getResult(0).getType(), castOp->getResult(0),
- transpOp.getTransp());
+ op, op->getResultTypes()[0], elementwiseOp->getResult(0),
+ transposeMaps.front());
return success();
}
};
RewritePatternSet &patterns) {
patterns.add<MultiReduceToContract, CombineContractBroadcast,
CombineContractTranspose, ReorderCastOpsOnBroadcast,
- ReorderCastOpsOnTranspose>(patterns.getContext());
+ ReorderElementwiseOpsOnTranspose>(patterns.getContext());
}
void mlir::vector::
%r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
return %r : vector<2x4xi32>
}
+
+//===----------------------------------------------------------------------===//
+// Reorder elementwise ops and vector ops.
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: func @transpose_elementwise_same_type
+// CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
+// CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x2xf32>
+// CHECK: %[[T:.+]] = vector.transpose %[[ADD]], [1, 0]
+// CHECK: return %[[T]]
+
+func @transpose_elementwise_same_type(%a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xf32> {
+ %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
+ %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
+ %r = arith.addf %at, %bt : vector<2x4xf32>
+ return %r : vector<2x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transpose_elementwise_diff_operand_types
+// CHECK-SAME: (%[[COND:.+]]: vector<4x2xi1>, %[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
+// CHECK: %[[S:.+]] = arith.select %[[COND]], %[[A]], %[[B]] : vector<4x2xi1>, vector<4x2xf32>
+// CHECK: %[[T:.+]] = vector.transpose %[[S]], [1, 0] : vector<4x2xf32> to vector<2x4xf32>
+// CHECK: return %[[T]]
+func @transpose_elementwise_diff_operand_types(%cond: vector<4x2xi1>, %a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xf32> {
+ %condt = vector.transpose %cond, [1, 0]: vector<4x2xi1> to vector<2x4xi1>
+ %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
+ %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
+ %r = arith.select %condt, %at, %bt : vector<2x4xi1>, vector<2x4xf32>
+ return %r : vector<2x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transpose_elementwise_diff_operand_result_type
+// CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
+// CHECK: %[[CMP:.+]] = arith.cmpf olt, %[[A]], %[[B]] : vector<4x2xf32>
+// CHECK: %[[T:.+]] = vector.transpose %[[CMP]], [1, 0] : vector<4x2xi1> to vector<2x4xi1>
+// CHECK: return %[[T]]
+func @transpose_elementwise_diff_operand_result_type(%a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xi1> {
+ %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
+ %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
+ %r = arith.cmpf olt, %at, %bt : vector<2x4xf32>
+ return %r : vector<2x4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: func @transpose_elementwise_splat_constant
+// CHECK-SAME: (%[[A:.+]]: vector<4x6x3x2xf32>)
+// CHECK: %[[B:.+]] = arith.constant dense<5.000000e+00> : vector<4x6x3x2xf32>
+// CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x6x3x2xf32>
+// CHECK: %[[T:.+]] = vector.transpose %[[ADD]], [1, 0, 3, 2] : vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
+// CHECK: return %[[T:.+]] : vector<6x4x2x3xf32>
+
+func @transpose_elementwise_splat_constant(%a : vector<4x6x3x2xf32>) -> vector<6x4x2x3xf32> {
+ %b = arith.constant dense<5.0> : vector<6x4x2x3xf32>
+ %at = vector.transpose %a, [1, 0, 3, 2]: vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
+ %r = arith.addf %at, %b : vector<6x4x2x3xf32>
+ return %r : vector<6x4x2x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transpose_elementwise_diff_map
+// CHECK: vector.transpose
+// CHECK: vector.transpose
+// CHECK: arith.addf
+func @transpose_elementwise_diff_map(%a : vector<4x6x3x2xf32>, %b: vector<6x2x4x3xf32>) -> vector<6x4x2x3xf32> {
+ %at = vector.transpose %a, [1, 0, 3, 2]: vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
+ %bt = vector.transpose %b, [0, 2, 1, 3]: vector<6x2x4x3xf32> to vector<6x4x2x3xf32>
+ %r = arith.addf %at, %bt : vector<6x4x2x3xf32>
+ return %r : vector<6x4x2x3xf32>
+}