/// Lower to `vector.outerproduct`.
OuterProduct = 2,
};
+/// Enum to control the lowering of `vector.transpose` operations.
+enum class VectorTransposeLowering {
+ /// Lower transpose into element-wise extract and inserts.
+ EltWise = 0,
+ /// Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix
+ /// intrinsics.
+ Flat = 1,
+};
/// Structure to control the behavior of vector transform patterns.
struct VectorTransformsOptions {
VectorContractLowering vectorContractLowering = VectorContractLowering::FMA;
+ VectorTransposeLowering vectorTransposeLowering =
+ VectorTransposeLowering::EltWise;
VectorTransformsOptions &
setVectorTransformsOptions(VectorContractLowering opt) {
vectorContractLowering = opt;
}
}];
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)";
+ let hasFolder = 1;
}
def Vector_TypeCastOp :
return success();
}
+/// Folds away shape casts that have no net effect on the vector type.
+/// Returns the value to use in place of this op's result, or a null
+/// OpFoldResult when neither case below applies.
+OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) {
+  auto input = source();
+  auto resultType = result().getType();
+
+  // Casting a value to the type it already has is a no-op: forward the input.
+  if (input.getType() == resultType)
+    return input;
+
+  // Two back-to-back casts cancel when the outer one restores the type that
+  // fed the inner one: forward the inner cast's operand.
+  auto producer = input.getDefiningOp<ShapeCastOp>();
+  if (!producer || producer.source().getType() != resultType)
+    return {};
+  return producer.source();
+}
+
//===----------------------------------------------------------------------===//
// TypeCastOp
//===----------------------------------------------------------------------===//
public:
using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
+ /// Constructs the pattern for `context` and stores a copy of
+ /// `vectorTransformsOptions` in the member of the same name, which
+ /// matchAndRewrite reads to pick the transpose lowering.
+ TransposeOpLowering(vector::VectorTransformsOptions vectorTransformsOptions,
+ MLIRContext *context)
+ : OpRewritePattern<vector::TransposeOp>(context),
+ vectorTransformsOptions(vectorTransformsOptions) {}
+
LogicalResult matchAndRewrite(vector::TransposeOp op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
for (auto attr : op.transp())
transp.push_back(attr.cast<IntegerAttr>().getInt());
+ // Handle a true 2-D matrix transpose differently when requested.
+ if (vectorTransformsOptions.vectorTransposeLowering ==
+ vector::VectorTransposeLowering::Flat &&
+ resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
+ Type flattenedType =
+ VectorType::get(resType.getNumElements(), resType.getElementType());
+ auto matrix =
+ rewriter.create<vector::ShapeCastOp>(loc, flattenedType, op.vector());
+ auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
+ auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
+ Value trans = rewriter.create<vector::FlatTransposeOp>(
+ loc, flattenedType, matrix, rows, columns);
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
+ return success();
+ }
+
// Generate fully unrolled extract/insert ops.
Value result = rewriter.create<ConstantOp>(loc, resType,
rewriter.getZeroAttr(resType));
}
return result;
}
+
+ /// Options to control the vector patterns.
+ vector::VectorTransformsOptions vectorTransformsOptions;
};
/// Progressive lowering of OuterProductOp.
ConstantMaskOpLowering,
OuterProductOpLowering,
ShapeCastOp2DDownCastRewritePattern,
- ShapeCastOp2DUpCastRewritePattern,
- TransposeOpLowering>(context);
- patterns.insert<ContractionOpLowering,
+ ShapeCastOp2DUpCastRewritePattern>(context);
+ patterns.insert<TransposeOpLowering,
+ ContractionOpLowering,
ContractionOpToMatmulOpLowering,
ContractionOpToOuterProductOpLowering>(parameters, context);
// clang-format on
return %0 : vector<3x2xf32>
}
+
+// A shape cast whose source and result types are identical is expected to
+// fold to its operand, so the function returns its argument directly.
+//
+// CHECK-LABEL: func @nop_shape_cast
+// CHECK-SAME: %[[A:.*]]: vector<16xf32>
+// CHECK: return %[[A]] : vector<16xf32>
+
+func @nop_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> {
+ %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<16xf32>
+ return %0 : vector<16xf32>
+}
+
+// A 16 -> 4x4 shape cast immediately undone by a 4x4 -> 16 shape cast is
+// expected to cancel, so the function returns its argument directly.
+//
+// CHECK-LABEL: func @cancel_shape_cast
+// CHECK-SAME: %[[A:.*]]: vector<16xf32>
+// CHECK: return %[[A]] : vector<16xf32>
+
+func @cancel_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> {
+ %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32>
+ %1 = vector.shape_cast %0 : vector<4x4xf32> to vector<16xf32>
+ return %1 : vector<16xf32>
+}
+
// Shape up and downcasts for 2-D vectors, for supporting conversion to
// llvm.matrix operations
// CHECK-LABEL: func @shape_casts
--- /dev/null
+// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-flat-transpose=1 | FileCheck %s --dump-input-on-failure
+
+// Tests for lowering 2-D vector.transpose into vector.flat_transpose.
+//
+// TODO(ajcbik,ntv): having ShapeCastOp2DDownCastRewritePattern and
+// ShapeCastOp2DUpCastRewritePattern too early in
+// the greedy rewriting patterns misses opportunities
+// to fold shape casts!
+
+// 2-D input, 2-D result: the transpose is neither fed by nor followed by a
+// shape cast in the input IR, so no shape cast folding is expected.
+//
+// CHECK-LABEL: func @transpose44_44(
+// CHECK-SAME: %[[A:.*]]: vector<4x4xf32>
+// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<4x4xf32>
+// CHECK: %[[T8:.*]] = vector.flat_transpose %{{.*}} {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32>
+// CHECK: %[[T9:.*]] = vector.extract_strided_slice %[[T8]] {offsets = [0], sizes = [4], strides = [1]} : vector<16xf32> to vector<4xf32>
+//
+func @transpose44_44(%arg0: vector<4x4xf32>) -> vector<4x4xf32> {
+ %0 = vector.transpose %arg0, [1, 0] : vector<4x4xf32> to vector<4x4xf32>
+ return %0 : vector<4x4xf32>
+}
+
+// 1-D input, 2-D result: the 16 -> 4x4 shape cast feeding the transpose
+// folds away as expected (the flat transpose consumes the argument
+// directly); no folding is expected for the shape cast that follows it.
+//
+// CHECK-LABEL: func @transpose16_44(
+// CHECK-SAME: %[[A:.*]]: vector<16xf32>
+// CHECK: %[[T0:.*]] = vector.flat_transpose %[[A]] {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32>
+// CHECK: %[[T1:.*]] = vector.extract_strided_slice %[[T0]] {offsets = [0], sizes = [4], strides = [1]} : vector<16xf32> to vector<4xf32>
+//
+func @transpose16_44(%arg0: vector<16xf32>) -> vector<4x4xf32> {
+ %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32>
+ %1 = vector.transpose %0, [1, 0] : vector<4x4xf32> to vector<4x4xf32>
+ return %1 : vector<4x4xf32>
+}
+
+// 2-D input, 1-D result: nothing precedes the transpose, so no folding is
+// expected there; the trailing 4x4 -> 16 shape cast currently FAILS to
+// fold (known limitation, see the TODO in this file's header comment).
+//
+// CHECK-LABEL: func @transpose44_16(
+// CHECK-SAME: %[[A:.*]]: vector<4x4xf32>
+// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<4x4xf32>
+// CHECK: %[[T8:.*]] = vector.flat_transpose %{{.*}} {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32>
+func @transpose44_16(%arg0: vector<4x4xf32>) -> vector<16xf32> {
+ %0 = vector.transpose %arg0, [1, 0] : vector<4x4xf32> to vector<4x4xf32>
+ %1 = vector.shape_cast %0 : vector<4x4xf32> to vector<16xf32>
+ return %1 : vector<16xf32>
+}
+
+// 1-D input, 1-D result: the 16 -> 4x4 shape cast feeding the transpose
+// folds away as expected (the flat transpose consumes the argument
+// directly), but the trailing 4x4 -> 16 shape cast currently FAILS to fold
+// (known limitation, see the TODO in this file's header comment).
+//
+// CHECK-LABEL: func @transpose16_16(
+// CHECK-SAME: %[[A:.*]]: vector<16xf32>
+// CHECK: %[[T0:.*]] = vector.flat_transpose %[[A]] {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32>
+//
+func @transpose16_16(%arg0: vector<16xf32>) -> vector<16xf32> {
+ %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32>
+ %1 = vector.transpose %0, [1, 0] : vector<4x4xf32> to vector<4x4xf32>
+ %2 = vector.shape_cast %1 : vector<4x4xf32> to vector<16xf32>
+ return %2 : vector<16xf32>
+}
TestVectorContractionConversion(const TestVectorContractionConversion &pass) {
}
- Option<bool> lowerToLLVMMatrixIntrinsics{
+ Option<bool> lowerToFlatMatrix{
*this, "vector-lower-matrix-intrinsics",
llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
llvm::cl::init(false)};
+ Option<bool> lowerToFlatTranspose{
+ *this, "vector-flat-transpose",
+ llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
+ llvm::cl::init(false)};
Option<bool> lowerToOuterProduct{
*this, "vector-outerproduct",
llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
return;
}
- VectorContractLowering lowering = VectorContractLowering::FMA;
- if (lowerToLLVMMatrixIntrinsics)
- lowering = VectorContractLowering::Matmul;
- VectorTransformsOptions options{lowering};
+ VectorContractLowering contractLowering = VectorContractLowering::FMA;
+ if (lowerToFlatMatrix)
+ contractLowering = VectorContractLowering::Matmul;
+ VectorTransposeLowering transposeLowering =
+ VectorTransposeLowering::EltWise;
+ if (lowerToFlatTranspose)
+ transposeLowering = VectorTransposeLowering::Flat;
+ VectorTransformsOptions options{contractLowering, transposeLowering};
populateVectorContractLoweringPatterns(patterns, &getContext(), options);
applyPatternsAndFoldGreedily(getFunction(), patterns);
}