From: aartbik Date: Wed, 3 Jun 2020 21:13:22 +0000 (-0700) Subject: [mlir] [VectorOps] Use 'vector.flat_transpose' for 2-D 'vector.tranpose' X-Git-Tag: llvmorg-12-init~4222 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=6391da98f43a995fe3dfb96a5376b2d9c652ed87;p=platform%2Fupstream%2Fllvm.git [mlir] [VectorOps] Use 'vector.flat_transpose' for 2-D 'vector.tranpose' Summary: Progressive lowering of vector.transpose into an operation that is closer to an intrinsic, and thus the hardware ISA. Currently under the common vector transform testing flag, as we prepare deploying this transformation in the LLVM lowering pipeline. Reviewers: nicolasvasilache, reidtatge, andydavis1, ftynse Reviewed By: nicolasvasilache, ftynse Subscribers: mehdi_amini, rriddle, jpienaar, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, stephenneuendorffer, Joonsoo, grosul1, frgossen, Kayjukh, jurahul, llvm-commits Tags: #llvm, #mlir Differential Revision: https://reviews.llvm.org/D80772 --- diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h index 8c8424e8ef8f..def0d24adcf5 100644 --- a/mlir/include/mlir/Dialect/Vector/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h @@ -53,9 +53,19 @@ enum class VectorContractLowering { /// Lower to `vector.outerproduct`. OuterProduct = 2, }; +/// Enum to control the lowering of `vector.transpose` operations. +enum class VectorTransposeLowering { + // Lower transpose into element-wise extract and inserts. + EltWise = 0, + /// Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix + /// intrinsics. + Flat = 1, +}; /// Structure to control the behavior of vector transform patterns. struct VectorTransformsOptions { VectorContractLowering vectorContractLowering = VectorContractLowering::FMA; + VectorTransposeLowering vectorTransposeLowering = + VectorTransposeLowering::EltWise; VectorTransformsOptions & setVectorTransformsOptions(VectorContractLowering opt) { vectorContractLowering = opt; diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td index 4065d19b6c8a..365795fb9cab 100644 --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -1206,6 +1206,7 @@ def Vector_ShapeCastOp : } }]; let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)"; + let hasFolder = 1; } def Vector_TypeCastOp : diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp index 63891d1004d4..21b62ceaa689 100644 --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -1667,6 +1667,19 @@ static LogicalResult verify(ShapeCastOp op) { return success(); } +OpFoldResult ShapeCastOp::fold(ArrayRef operands) { + // Nop shape cast. + if (source().getType() == result().getType()) + return source(); + + // Canceling shape casts. + if (auto otherOp = source().getDefiningOp()) + if (result().getType() == otherOp.source().getType()) + return otherOp.source(); + + return {}; +} + //===----------------------------------------------------------------------===// // TypeCastOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp index 491ad62affcb..82c27387bd6e 100644 --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -1186,6 +1186,11 @@ class TransposeOpLowering : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; + TransposeOpLowering(vector::VectorTransformsOptions vectorTransformsOptions, + MLIRContext *context) + : OpRewritePattern(context), + vectorTransformsOptions(vectorTransformsOptions) {} + LogicalResult matchAndRewrite(vector::TransposeOp op, PatternRewriter &rewriter) const override { auto loc = op.getLoc(); @@ -1197,6 +1202,22 @@ public: for (auto attr : op.transp()) transp.push_back(attr.cast().getInt()); + // Handle a true 2-D matrix transpose differently when requested. + if (vectorTransformsOptions.vectorTransposeLowering == + vector::VectorTransposeLowering::Flat && + resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) { + Type flattenedType = + VectorType::get(resType.getNumElements(), resType.getElementType()); + auto matrix = + rewriter.create(loc, flattenedType, op.vector()); + auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]); + auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]); + Value trans = rewriter.create( + loc, flattenedType, matrix, rows, columns); + rewriter.replaceOpWithNewOp(op, resType, trans); + return success(); + } + // Generate fully unrolled extract/insert ops. Value result = rewriter.create(loc, resType, rewriter.getZeroAttr(resType)); @@ -1230,6 +1251,9 @@ private: } return result; } + + /// Options to control the vector patterns. + vector::VectorTransformsOptions vectorTransformsOptions; }; /// Progressive lowering of OuterProductOp. @@ -1829,9 +1853,9 @@ void mlir::vector::populateVectorContractLoweringPatterns( ConstantMaskOpLowering, OuterProductOpLowering, ShapeCastOp2DDownCastRewritePattern, - ShapeCastOp2DUpCastRewritePattern, - TransposeOpLowering>(context); - patterns.insert(context); + patterns.insert(parameters, context); // clang-format on diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir index 1dd2f377a29c..491f18fdf5c9 100644 --- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir @@ -319,6 +319,26 @@ func @transpose23(%arg0: vector<2x3xf32>) -> vector<3x2xf32> { return %0 : vector<3x2xf32> } + +// CHECK-LABEL: func @nop_shape_cast +// CHECK-SAME: %[[A:.*]]: vector<16xf32> +// CHECK: return %[[A]] : vector<16xf32> + +func @nop_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> { + %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<16xf32> + return %0 : vector<16xf32> +} + +// CHECK-LABEL: func @cancel_shape_cast +// CHECK-SAME: %[[A:.*]]: vector<16xf32> +// CHECK: return %[[A]] : vector<16xf32> + +func @cancel_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> { + %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32> + %1 = vector.shape_cast %0 : vector<4x4xf32> to vector<16xf32> + return %1 : vector<16xf32> +} + // Shape up and downcasts for 2-D vectors, for supporting conversion to // llvm.matrix operations // CHECK-LABEL: func @shape_casts diff --git a/mlir/test/Dialect/Vector/vector-flat-transforms.mlir b/mlir/test/Dialect/Vector/vector-flat-transforms.mlir new file mode 100644 index 000000000000..e715755738de --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-flat-transforms.mlir @@ -0,0 +1,62 @@ +// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-flat-transpose=1 | FileCheck %s --dump-input-on-failure + +// Tests for lowering 2-D vector.transpose into vector.flat_transpose. +// +// TODO(ajcbik,ntv): having ShapeCastOp2DDownCastRewritePattern and +// ShapeCastOp2DUpCastRewritePattern too early in +// the greedy rewriting patterns misses opportunities +// to fold shape casts! + +// No shape cast folding expected. +// +// CHECK-LABEL: func @transpose44_44( +// CHECK-SAME: %[[A:.*]]: vector<4x4xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<4x4xf32> +// CHECK: %[[T8:.*]] = vector.flat_transpose %{{.*}} {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32> +// CHECK: %[[T9:.*]] = vector.extract_strided_slice %[[T8]] {offsets = [0], sizes = [4], strides = [1]} : vector<16xf32> to vector<4xf32> +// +func @transpose44_44(%arg0: vector<4x4xf32>) -> vector<4x4xf32> { + %0 = vector.transpose %arg0, [1, 0] : vector<4x4xf32> to vector<4x4xf32> + return %0 : vector<4x4xf32> +} + +// Folds preceding shape cast as expected, +// no following shape cast folding expected. +// +// CHECK-LABEL: func @transpose16_44( +// CHECK-SAME: %[[A:.*]]: vector<16xf32> +// CHECK: %[[T0:.*]] = vector.flat_transpose %[[A]] {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32> +// CHECK: %[[T1:.*]] = vector.extract_strided_slice %[[T0]] {offsets = [0], sizes = [4], strides = [1]} : vector<16xf32> to vector<4xf32> +// +func @transpose16_44(%arg0: vector<16xf32>) -> vector<4x4xf32> { + %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32> + %1 = vector.transpose %0, [1, 0] : vector<4x4xf32> to vector<4x4xf32> + return %1 : vector<4x4xf32> +} + +// No preceding shape cast folding expected, +// but FAILS to fold following cast. +// +// CHECK-LABEL: func @transpose44_16( +// CHECK-SAME: %[[A:.*]]: vector<4x4xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<4x4xf32> +// CHECK: %[[T8:.*]] = vector.flat_transpose %{{.*}} {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32> +func @transpose44_16(%arg0: vector<4x4xf32>) -> vector<16xf32> { + %0 = vector.transpose %arg0, [1, 0] : vector<4x4xf32> to vector<4x4xf32> + %1 = vector.shape_cast %0 : vector<4x4xf32> to vector<16xf32> + return %1 : vector<16xf32> +} + +// Folds preceding shape cast as expected, +// but FAILS to fold following cast. +// +// CHECK-LABEL: func @transpose16_16( +// CHECK-SAME: %[[A:.*]]: vector<16xf32> +// CHECK: %[[T0:.*]] = vector.flat_transpose %[[A]] {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32> +// +func @transpose16_16(%arg0: vector<16xf32>) -> vector<16xf32> { + %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32> + %1 = vector.transpose %0, [1, 0] : vector<4x4xf32> to vector<4x4xf32> + %2 = vector.shape_cast %1 : vector<4x4xf32> to vector<16xf32> + return %2 : vector<16xf32> +} diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp index 65024dbe3acd..22585fde4ff7 100644 --- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp +++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp @@ -47,10 +47,14 @@ struct TestVectorContractionConversion TestVectorContractionConversion(const TestVectorContractionConversion &pass) { } - Option lowerToLLVMMatrixIntrinsics{ + Option lowerToFlatMatrix{ *this, "vector-lower-matrix-intrinsics", llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"), llvm::cl::init(false)}; + Option lowerToFlatTranspose{ + *this, "vector-flat-transpose", + llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"), + llvm::cl::init(false)}; Option lowerToOuterProduct{ *this, "vector-outerproduct", llvm::cl::desc("Lower vector.contract to vector.outerproduct"), @@ -67,10 +71,14 @@ struct TestVectorContractionConversion return; } - VectorContractLowering lowering = VectorContractLowering::FMA; - if (lowerToLLVMMatrixIntrinsics) - lowering = VectorContractLowering::Matmul; - VectorTransformsOptions options{lowering}; + VectorContractLowering contractLowering = VectorContractLowering::FMA; + if (lowerToFlatMatrix) + contractLowering = VectorContractLowering::Matmul; + VectorTransposeLowering transposeLowering = + VectorTransposeLowering::EltWise; + if (lowerToFlatTranspose) + transposeLowering = VectorTransposeLowering::Flat; + VectorTransformsOptions options{contractLowering, transposeLowering}; populateVectorContractLoweringPatterns(patterns, &getContext(), options); applyPatternsAndFoldGreedily(getFunction(), patterns); }