From: aartbik Date: Wed, 3 Jun 2020 21:13:22 +0000 (-0700) Subject: [mlir] [VectorOps] Use 'vector.flat_transpose' for 2-D 'vector.tranpose' X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=6391da98f43a995fe3dfb96a5376b2d9c652ed87;p=platform%2Fupstream%2Fllvm.git [mlir] [VectorOps] Use 'vector.flat_transpose' for 2-D 'vector.tranpose' Summary: Progressive lowering of vector.transpose into an operation that is closer to an intrinsic, and thus the hardware ISA. Currently under the common vector transform testing flag, as we prepare deploying this transformation in the LLVM lowering pipeline. Reviewers: nicolasvasilache, reidtatge, andydavis1, ftynse Reviewed By: nicolasvasilache, ftynse Subscribers: mehdi_amini, rriddle, jpienaar, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, stephenneuendorffer, Joonsoo, grosul1, frgossen, Kayjukh, jurahul, llvm-commits Tags: #llvm, #mlir Differential Revision: https://reviews.llvm.org/D80772 --- diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h index 8c8424e8..def0d24 100644 --- a/mlir/include/mlir/Dialect/Vector/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h @@ -53,9 +53,19 @@ enum class VectorContractLowering { /// Lower to `vector.outerproduct`. OuterProduct = 2, }; +/// Enum to control the lowering of `vector.transpose` operations. +enum class VectorTransposeLowering { + // Lower transpose into element-wise extract and inserts. + EltWise = 0, + /// Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix + /// intrinsics. + Flat = 1, +}; /// Structure to control the behavior of vector transform patterns. struct VectorTransformsOptions { VectorContractLowering vectorContractLowering = VectorContractLowering::FMA; + VectorTransposeLowering vectorTransposeLowering = + VectorTransposeLowering::EltWise; VectorTransformsOptions & setVectorTransformsOptions(VectorContractLowering opt) { vectorContractLowering = opt; diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td index 4065d19..365795f 100644 --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -1206,6 +1206,7 @@ def Vector_ShapeCastOp : } }]; let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)"; + let hasFolder = 1; } def Vector_TypeCastOp : diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp index 63891d1..21b62ce 100644 --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -1667,6 +1667,19 @@ static LogicalResult verify(ShapeCastOp op) { return success(); } +OpFoldResult ShapeCastOp::fold(ArrayRef operands) { + // Nop shape cast. + if (source().getType() == result().getType()) + return source(); + + // Canceling shape casts. + if (auto otherOp = source().getDefiningOp()) + if (result().getType() == otherOp.source().getType()) + return otherOp.source(); + + return {}; +} + //===----------------------------------------------------------------------===// // TypeCastOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp index 491ad62..82c2738 100644 --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -1186,6 +1186,11 @@ class TransposeOpLowering : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; + TransposeOpLowering(vector::VectorTransformsOptions vectorTransformsOptions, + MLIRContext *context) + : OpRewritePattern(context), + vectorTransformsOptions(vectorTransformsOptions) {} + LogicalResult matchAndRewrite(vector::TransposeOp op, PatternRewriter &rewriter) const override { auto loc = op.getLoc(); @@ -1197,6 +1202,22 @@ public: for (auto attr : op.transp()) transp.push_back(attr.cast().getInt()); + // Handle a true 2-D matrix transpose differently when requested. + if (vectorTransformsOptions.vectorTransposeLowering == + vector::VectorTransposeLowering::Flat && + resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) { + Type flattenedType = + VectorType::get(resType.getNumElements(), resType.getElementType()); + auto matrix = + rewriter.create(loc, flattenedType, op.vector()); + auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]); + auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]); + Value trans = rewriter.create( + loc, flattenedType, matrix, rows, columns); + rewriter.replaceOpWithNewOp(op, resType, trans); + return success(); + } + // Generate fully unrolled extract/insert ops. Value result = rewriter.create(loc, resType, rewriter.getZeroAttr(resType)); @@ -1230,6 +1251,9 @@ private: } return result; } + + /// Options to control the vector patterns. + vector::VectorTransformsOptions vectorTransformsOptions; }; /// Progressive lowering of OuterProductOp. @@ -1829,9 +1853,9 @@ void mlir::vector::populateVectorContractLoweringPatterns( ConstantMaskOpLowering, OuterProductOpLowering, ShapeCastOp2DDownCastRewritePattern, - ShapeCastOp2DUpCastRewritePattern, - TransposeOpLowering>(context); - patterns.insert(context); + patterns.insert(parameters, context); // clang-format on diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir index 1dd2f37..491f18f 100644 --- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir @@ -319,6 +319,26 @@ func @transpose23(%arg0: vector<2x3xf32>) -> vector<3x2xf32> { return %0 : vector<3x2xf32> } + +// CHECK-LABEL: func @nop_shape_cast +// CHECK-SAME: %[[A:.*]]: vector<16xf32> +// CHECK: return %[[A]] : vector<16xf32> + +func @nop_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> { + %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<16xf32> + return %0 : vector<16xf32> +} + +// CHECK-LABEL: func @cancel_shape_cast +// CHECK-SAME: %[[A:.*]]: vector<16xf32> +// CHECK: return %[[A]] : vector<16xf32> + +func @cancel_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> { + %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32> + %1 = vector.shape_cast %0 : vector<4x4xf32> to vector<16xf32> + return %1 : vector<16xf32> +} + // Shape up and downcasts for 2-D vectors, for supporting conversion to // llvm.matrix operations // CHECK-LABEL: func @shape_casts diff --git a/mlir/test/Dialect/Vector/vector-flat-transforms.mlir b/mlir/test/Dialect/Vector/vector-flat-transforms.mlir new file mode 100644 index 0000000..e7157557 --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-flat-transforms.mlir @@ -0,0 +1,62 @@ +// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-flat-transpose=1 | FileCheck %s --dump-input-on-failure + +// Tests for lowering 2-D vector.transpose into vector.flat_transpose. +// +// TODO(ajcbik,ntv): having ShapeCastOp2DDownCastRewritePattern and +// ShapeCastOp2DUpCastRewritePattern too early in +// the greedy rewriting patterns misses opportunities +// to fold shape casts! + +// No shape cast folding expected. +// +// CHECK-LABEL: func @transpose44_44( +// CHECK-SAME: %[[A:.*]]: vector<4x4xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<4x4xf32> +// CHECK: %[[T8:.*]] = vector.flat_transpose %{{.*}} {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32> +// CHECK: %[[T9:.*]] = vector.extract_strided_slice %[[T8]] {offsets = [0], sizes = [4], strides = [1]} : vector<16xf32> to vector<4xf32> +// +func @transpose44_44(%arg0: vector<4x4xf32>) -> vector<4x4xf32> { + %0 = vector.transpose %arg0, [1, 0] : vector<4x4xf32> to vector<4x4xf32> + return %0 : vector<4x4xf32> +} + +// Folds preceding shape cast as expected, +// no following shape cast folding expected. +// +// CHECK-LABEL: func @transpose16_44( +// CHECK-SAME: %[[A:.*]]: vector<16xf32> +// CHECK: %[[T0:.*]] = vector.flat_transpose %[[A]] {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32> +// CHECK: %[[T1:.*]] = vector.extract_strided_slice %[[T0]] {offsets = [0], sizes = [4], strides = [1]} : vector<16xf32> to vector<4xf32> +// +func @transpose16_44(%arg0: vector<16xf32>) -> vector<4x4xf32> { + %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32> + %1 = vector.transpose %0, [1, 0] : vector<4x4xf32> to vector<4x4xf32> + return %1 : vector<4x4xf32> +} + +// No preceding shape cast folding expected, +// but FAILS to fold following cast. +// +// CHECK-LABEL: func @transpose44_16( +// CHECK-SAME: %[[A:.*]]: vector<4x4xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<4x4xf32> +// CHECK: %[[T8:.*]] = vector.flat_transpose %{{.*}} {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32> +func @transpose44_16(%arg0: vector<4x4xf32>) -> vector<16xf32> { + %0 = vector.transpose %arg0, [1, 0] : vector<4x4xf32> to vector<4x4xf32> + %1 = vector.shape_cast %0 : vector<4x4xf32> to vector<16xf32> + return %1 : vector<16xf32> +} + +// Folds preceding shape cast as expected, +// but FAILS to fold following cast. +// +// CHECK-LABEL: func @transpose16_16( +// CHECK-SAME: %[[A:.*]]: vector<16xf32> +// CHECK: %[[T0:.*]] = vector.flat_transpose %[[A]] {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32> +// +func @transpose16_16(%arg0: vector<16xf32>) -> vector<16xf32> { + %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32> + %1 = vector.transpose %0, [1, 0] : vector<4x4xf32> to vector<4x4xf32> + %2 = vector.shape_cast %1 : vector<4x4xf32> to vector<16xf32> + return %2 : vector<16xf32> +} diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp index 65024db..22585fd 100644 --- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp +++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp @@ -47,10 +47,14 @@ struct TestVectorContractionConversion TestVectorContractionConversion(const TestVectorContractionConversion &pass) { } - Option lowerToLLVMMatrixIntrinsics{ + Option lowerToFlatMatrix{ *this, "vector-lower-matrix-intrinsics", llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"), llvm::cl::init(false)}; + Option lowerToFlatTranspose{ + *this, "vector-flat-transpose", + llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"), + llvm::cl::init(false)}; Option lowerToOuterProduct{ *this, "vector-outerproduct", llvm::cl::desc("Lower vector.contract to vector.outerproduct"), @@ -67,10 +71,14 @@ struct TestVectorContractionConversion return; } - VectorContractLowering lowering = VectorContractLowering::FMA; - if (lowerToLLVMMatrixIntrinsics) - lowering = VectorContractLowering::Matmul; - VectorTransformsOptions options{lowering}; + VectorContractLowering contractLowering = VectorContractLowering::FMA; + if (lowerToFlatMatrix) + contractLowering = VectorContractLowering::Matmul; + VectorTransposeLowering transposeLowering = + VectorTransposeLowering::EltWise; + if (lowerToFlatTranspose) + transposeLowering = VectorTransposeLowering::Flat; + VectorTransformsOptions options{contractLowering, transposeLowering}; populateVectorContractLoweringPatterns(patterns, &getContext(), options); applyPatternsAndFoldGreedily(getFunction(), patterns); }