return success();
}
+namespace {
+
+// Canonicalization pattern that merges a TransposeOp with the TransposeOp
+// producing its operand, so that the pair is replaced by one TransposeOp
+// applied directly to the producer's input vector.
+class TransposeFolder final : public OpRewritePattern<TransposeOp> {
+public:
+  using OpRewritePattern<TransposeOp>::OpRewritePattern;
+
+  // Matches 'transposeOp' when its operand is the result of another
+  // TransposeOp and rewrites it in place; returns failure() otherwise.
+  LogicalResult matchAndRewrite(TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    // Bail out unless the operand is produced by a transpose. The _or_null
+    // cast tolerates getDefiningOp() handing back a null operation.
+    auto producerOp =
+        dyn_cast_or_null<TransposeOp>(transposeOp.vector().getDefiningOp());
+    if (!producerOp)
+      return failure();
+
+    // Read the permutation of the producer (applied first) and of the op
+    // being rewritten (applied second).
+    SmallVector<int64_t, 4> innerPerm;
+    producerOp.getTransp(innerPerm);
+    SmallVector<int64_t, 4> outerPerm;
+    transposeOp.getTransp(outerPerm);
+
+    // Merge the two permutations: mergedPerm[i] = innerPerm[outerPerm[i]].
+    SmallVector<int64_t, 4> mergedPerm;
+    for (int64_t dim : outerPerm)
+      mergedPerm.push_back(innerPerm[dim]);
+
+    // Swap 'transposeOp' for a transpose of the producer's input that keeps
+    // the original result type and uses the merged permutation.
+    rewriter.replaceOpWithNewOp<TransposeOp>(
+        transposeOp, transposeOp.getResult().getType(), producerOp.vector(),
+        vector::getVectorSubscriptAttr(rewriter, mergedPerm));
+    return success();
+  }
+};
+
+} // end anonymous namespace
+
+// Canonicalization hook for TransposeOp: adds TransposeFolder to 'results'.
+void TransposeOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<TransposeFolder>(context);
+}
+
// Fills 'results' from this op's 'transp' attribute by delegating to
// populateFromInt64AttrArray. TransposeFolder reads each op's permutation
// through this accessor.
void TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
populateFromInt64AttrArray(transp(), results);
}
// Adds the vector-to-vector canonicalization patterns to 'patterns':
// CreateMaskFolder, StridedSliceConstantMaskFolder and, with this change,
// TransposeFolder.
void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context) {
- patterns.insert<CreateMaskFolder, StridedSliceConstantMaskFolder>(context);
+ patterns.insert<CreateMaskFolder, StridedSliceConstantMaskFolder,
+ TransposeFolder>(context);
}
namespace mlir {
// Chain of four 2-D transposes ([1, 0], [0, 1], [1, 0], [0, 1]) that ends up
// back in the input layout. The expected output contains no transpose at all,
// and the add consumes the function argument for both operands.
// CHECK-LABEL: transpose_2D_sequence
// CHECK-SAME: ([[ARG:%.*]]: vector<4x3xf32>)
-func @transpose_2D_sequence(%arg : vector<4x3xf32>) -> vector<3x4xf32> {
+func @transpose_2D_sequence(%arg : vector<4x3xf32>) -> vector<4x3xf32> {
// CHECK-NOT: transpose
- %0 = vector.transpose %arg, [0, 1] : vector<4x3xf32> to vector<4x3xf32>
- // CHECK: [[T1:%.*]] = vector.transpose [[ARG]], [1, 0]
- %1 = vector.transpose %0, [1, 0] : vector<4x3xf32> to vector<3x4xf32>
- // CHECK-NOT: transpose
- %2 = vector.transpose %1, [0, 1] : vector<3x4xf32> to vector<3x4xf32>
- // CHECK: [[ADD:%.*]] = addf [[T1]], [[T1]]
- %4 = addf %1, %2 : vector<3x4xf32>
+ %0 = vector.transpose %arg, [1, 0] : vector<4x3xf32> to vector<3x4xf32>
+ %1 = vector.transpose %0, [0, 1] : vector<3x4xf32> to vector<3x4xf32>
+ %2 = vector.transpose %1, [1, 0] : vector<3x4xf32> to vector<4x3xf32>
+ %3 = vector.transpose %2, [0, 1] : vector<4x3xf32> to vector<4x3xf32>
+ // CHECK: [[ADD:%.*]] = addf [[ARG]], [[ARG]]
+ %4 = addf %2, %3 : vector<4x3xf32>
// CHECK-NEXT: return [[ADD]]
- return %4 : vector<3x4xf32>
+ return %4 : vector<4x3xf32>
}
// -----
// 3-D chain. The first two transposes ([1, 2, 0] then [1, 0, 2]) are expected
// to merge into a single [2, 1, 0] transpose of the argument, and that same
// value is expected for both operands of the multiply (%1 and %3). The
// transpose of the product has no transpose producer and must remain, while
// %6 is expected to reduce to the function argument itself.
// CHECK-LABEL: transpose_3D_sequence
// CHECK-SAME: ([[ARG:%.*]]: vector<4x3x2xf32>)
-func @transpose_3D_sequence(%arg : vector<4x3x2xf32>) -> vector<2x3x4xf32> {
- // CHECK: [[T0:%.*]] = vector.transpose [[ARG]], [1, 2, 0]
+func @transpose_3D_sequence(%arg : vector<4x3x2xf32>) -> vector<4x3x2xf32> {
+ // CHECK: [[T0:%.*]] = vector.transpose [[ARG]], [2, 1, 0]
%0 = vector.transpose %arg, [1, 2, 0] : vector<4x3x2xf32> to vector<3x2x4xf32>
+ %1 = vector.transpose %0, [1, 0, 2] : vector<3x2x4xf32> to vector<2x3x4xf32>
// CHECK-NOT: transpose
- %1 = vector.transpose %0, [0, 1, 2] : vector<3x2x4xf32> to vector<3x2x4xf32>
- // CHECK: [[T2:%.*]] = vector.transpose [[T0]], [1, 0, 2]
- %2 = vector.transpose %1, [1, 0, 2] : vector<3x2x4xf32> to vector<2x3x4xf32>
- // CHECK: [[ADD:%.*]] = addf [[T2]], [[T2]]
- %3 = addf %2, %2 : vector<2x3x4xf32>
+ %2 = vector.transpose %1, [2, 1, 0] : vector<2x3x4xf32> to vector<4x3x2xf32>
+ %3 = vector.transpose %2, [2, 1, 0] : vector<4x3x2xf32> to vector<2x3x4xf32>
+ // CHECK: [[MUL:%.*]] = mulf [[T0]], [[T0]]
+ %4 = mulf %1, %3 : vector<2x3x4xf32>
+ // CHECK: [[T5:%.*]] = vector.transpose [[MUL]], [2, 1, 0]
+ %5 = vector.transpose %4, [2, 1, 0] : vector<2x3x4xf32> to vector<4x3x2xf32>
// CHECK-NOT: transpose
- %4 = vector.transpose %3, [0, 1, 2] : vector<2x3x4xf32> to vector<2x3x4xf32>
+ %6 = vector.transpose %3, [2, 1, 0] : vector<2x3x4xf32> to vector<4x3x2xf32>
+ // CHECK: [[ADD:%.*]] = addf [[T5]], [[ARG]]
+ %7 = addf %5, %6 : vector<4x3x2xf32>
// CHECK-NEXT: return [[ADD]]
- return %4 : vector<2x3x4xf32>
+ return %7 : vector<4x3x2xf32>
}