[mlir] [VectorOps] Use 'vector.flat_transpose' for 2-D 'vector.tranpose'
authoraartbik <ajcbik@google.com>
Wed, 3 Jun 2020 21:13:22 +0000 (14:13 -0700)
committeraartbik <ajcbik@google.com>
Wed, 3 Jun 2020 21:55:50 +0000 (14:55 -0700)
Summary:
Progressive lowering of vector.transpose into an operation that
is closer to an intrinsic, and thus the hardware ISA. Currently
under the common vector transform testing flag, as we prepare
deploying this transformation in the LLVM lowering pipeline.

Reviewers: nicolasvasilache, reidtatge, andydavis1, ftynse

Reviewed By: nicolasvasilache, ftynse

Subscribers: mehdi_amini, rriddle, jpienaar, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, stephenneuendorffer, Joonsoo, grosul1, frgossen, Kayjukh, jurahul, llvm-commits

Tags: #llvm, #mlir

Differential Revision: https://reviews.llvm.org/D80772

mlir/include/mlir/Dialect/Vector/VectorOps.h
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Dialect/Vector/vector-contract-transforms.mlir
mlir/test/Dialect/Vector/vector-flat-transforms.mlir [new file with mode: 0644]
mlir/test/lib/Transforms/TestVectorTransforms.cpp

index 8c8424e..def0d24 100644 (file)
@@ -53,9 +53,19 @@ enum class VectorContractLowering {
   /// Lower to `vector.outerproduct`.
   OuterProduct = 2,
 };
+/// Enum to control the lowering of `vector.transpose` operations.
+enum class VectorTransposeLowering {
+  // Lower transpose into element-wise extract and inserts.
+  EltWise = 0,
+  /// Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix
+  /// intrinsics.
+  Flat = 1,
+};
 /// Structure to control the behavior of vector transform patterns.
 struct VectorTransformsOptions {
   VectorContractLowering vectorContractLowering = VectorContractLowering::FMA;
+  VectorTransposeLowering vectorTransposeLowering =
+      VectorTransposeLowering::EltWise;
   VectorTransformsOptions &
   setVectorTransformsOptions(VectorContractLowering opt) {
     vectorContractLowering = opt;
index 4065d19..365795f 100644 (file)
@@ -1206,6 +1206,7 @@ def Vector_ShapeCastOp :
     }
   }];
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)";
+  let hasFolder = 1;
 }
 
 def Vector_TypeCastOp :
index 63891d1..21b62ce 100644 (file)
@@ -1667,6 +1667,19 @@ static LogicalResult verify(ShapeCastOp op) {
   return success();
 }
 
+OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) {
+  // Nop shape cast.
+  if (source().getType() == result().getType())
+    return source();
+
+  // Canceling shape casts.
+  if (auto otherOp = source().getDefiningOp<ShapeCastOp>())
+    if (result().getType() == otherOp.source().getType())
+      return otherOp.source();
+
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // TypeCastOp
 //===----------------------------------------------------------------------===//
index 491ad62..82c2738 100644 (file)
@@ -1186,6 +1186,11 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
 public:
   using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
 
+  TransposeOpLowering(vector::VectorTransformsOptions vectorTransformsOptions,
+                      MLIRContext *context)
+      : OpRewritePattern<vector::TransposeOp>(context),
+        vectorTransformsOptions(vectorTransformsOptions) {}
+
   LogicalResult matchAndRewrite(vector::TransposeOp op,
                                 PatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
@@ -1197,6 +1202,22 @@ public:
     for (auto attr : op.transp())
       transp.push_back(attr.cast<IntegerAttr>().getInt());
 
+    // Handle a true 2-D matrix transpose differently when requested.
+    if (vectorTransformsOptions.vectorTransposeLowering ==
+            vector::VectorTransposeLowering::Flat &&
+        resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
+      Type flattenedType =
+          VectorType::get(resType.getNumElements(), resType.getElementType());
+      auto matrix =
+          rewriter.create<vector::ShapeCastOp>(loc, flattenedType, op.vector());
+      auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
+      auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
+      Value trans = rewriter.create<vector::FlatTransposeOp>(
+          loc, flattenedType, matrix, rows, columns);
+      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
+      return success();
+    }
+
     // Generate fully unrolled extract/insert ops.
     Value result = rewriter.create<ConstantOp>(loc, resType,
                                                rewriter.getZeroAttr(resType));
@@ -1230,6 +1251,9 @@ private:
     }
     return result;
   }
+
+  /// Options to control the vector patterns.
+  vector::VectorTransformsOptions vectorTransformsOptions;
 };
 
 /// Progressive lowering of OuterProductOp.
@@ -1829,9 +1853,9 @@ void mlir::vector::populateVectorContractLoweringPatterns(
                   ConstantMaskOpLowering,
                   OuterProductOpLowering,
                   ShapeCastOp2DDownCastRewritePattern,
-                  ShapeCastOp2DUpCastRewritePattern,
-                  TransposeOpLowering>(context);
-  patterns.insert<ContractionOpLowering,
+                  ShapeCastOp2DUpCastRewritePattern>(context);
+  patterns.insert<TransposeOpLowering,
+                  ContractionOpLowering,
                   ContractionOpToMatmulOpLowering,
                   ContractionOpToOuterProductOpLowering>(parameters, context);
   // clang-format on
index 1dd2f37..491f18f 100644 (file)
@@ -319,6 +319,26 @@ func @transpose23(%arg0: vector<2x3xf32>) -> vector<3x2xf32> {
   return %0 : vector<3x2xf32>
 }
 
+
+// CHECK-LABEL: func @nop_shape_cast
+// CHECK-SAME: %[[A:.*]]: vector<16xf32>
+// CHECK:      return %[[A]] : vector<16xf32>
+
+func @nop_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> {
+  %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<16xf32>
+  return %0 : vector<16xf32>
+}
+
+// CHECK-LABEL: func @cancel_shape_cast
+// CHECK-SAME: %[[A:.*]]: vector<16xf32>
+// CHECK:      return %[[A]] : vector<16xf32>
+
+func @cancel_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> {
+  %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32>
+  %1 = vector.shape_cast %0 : vector<4x4xf32> to vector<16xf32>
+  return %1 : vector<16xf32>
+}
+
 // Shape up and downcasts for 2-D vectors, for supporting conversion to
 // llvm.matrix operations
 // CHECK-LABEL: func @shape_casts
diff --git a/mlir/test/Dialect/Vector/vector-flat-transforms.mlir b/mlir/test/Dialect/Vector/vector-flat-transforms.mlir
new file mode 100644 (file)
index 0000000..e715755
--- /dev/null
@@ -0,0 +1,62 @@
+// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-flat-transpose=1 | FileCheck %s --dump-input-on-failure
+
+// Tests for lowering 2-D vector.transpose into vector.flat_transpose.
+//
+// TODO(ajcbik,ntv): having ShapeCastOp2DDownCastRewritePattern and
+//                   ShapeCastOp2DUpCastRewritePattern too early in
+//                   the greedy rewriting patterns misses opportunities
+//                   to fold shape casts!
+
+// No shape cast folding expected.
+//
+// CHECK-LABEL: func @transpose44_44(
+// CHECK-SAME:  %[[A:.*]]: vector<4x4xf32>
+// CHECK:       %[[T0:.*]] = vector.extract %[[A]][0] : vector<4x4xf32>
+// CHECK:       %[[T8:.*]] = vector.flat_transpose %{{.*}} {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32>
+// CHECK:       %[[T9:.*]] = vector.extract_strided_slice %[[T8]] {offsets = [0], sizes = [4], strides = [1]} : vector<16xf32> to vector<4xf32>
+//
+func @transpose44_44(%arg0: vector<4x4xf32>) -> vector<4x4xf32> {
+  %0 = vector.transpose %arg0, [1, 0] : vector<4x4xf32> to vector<4x4xf32>
+  return %0 : vector<4x4xf32>
+}
+
+// Folds preceding shape cast as expected,
+// no following shape cast folding expected.
+//
+// CHECK-LABEL: func @transpose16_44(
+// CHECK-SAME:  %[[A:.*]]: vector<16xf32>
+// CHECK:       %[[T0:.*]] = vector.flat_transpose %[[A]] {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32>
+// CHECK:       %[[T1:.*]] = vector.extract_strided_slice %[[T0]] {offsets = [0], sizes = [4], strides = [1]} : vector<16xf32> to vector<4xf32>
+//
+func @transpose16_44(%arg0: vector<16xf32>) -> vector<4x4xf32> {
+  %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32>
+  %1 = vector.transpose %0, [1, 0] : vector<4x4xf32> to vector<4x4xf32>
+  return %1 : vector<4x4xf32>
+}
+
+// No preceding shape cast folding expected,
+// but FAILS to fold following cast.
+//
+// CHECK-LABEL: func @transpose44_16(
+// CHECK-SAME:  %[[A:.*]]: vector<4x4xf32>
+// CHECK:       %[[T0:.*]] = vector.extract %[[A]][0] : vector<4x4xf32>
+// CHECK:       %[[T8:.*]] = vector.flat_transpose %{{.*}} {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32>
+func @transpose44_16(%arg0: vector<4x4xf32>) -> vector<16xf32> {
+  %0 = vector.transpose %arg0, [1, 0] : vector<4x4xf32> to vector<4x4xf32>
+  %1 = vector.shape_cast %0 : vector<4x4xf32> to vector<16xf32>
+  return %1 : vector<16xf32>
+}
+
+// Folds preceding shape cast as expected,
+// but FAILS to fold following cast.
+//
+// CHECK-LABEL: func @transpose16_16(
+// CHECK-SAME:  %[[A:.*]]: vector<16xf32>
+// CHECK:       %[[T0:.*]] = vector.flat_transpose %[[A]] {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32>
+//
+func @transpose16_16(%arg0: vector<16xf32>) -> vector<16xf32> {
+  %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32>
+  %1 = vector.transpose %0, [1, 0] : vector<4x4xf32> to vector<4x4xf32>
+  %2 = vector.shape_cast %1 : vector<4x4xf32> to vector<16xf32>
+  return %2 : vector<16xf32>
+}
index 65024db..22585fd 100644 (file)
@@ -47,10 +47,14 @@ struct TestVectorContractionConversion
   TestVectorContractionConversion(const TestVectorContractionConversion &pass) {
   }
 
-  Option<bool> lowerToLLVMMatrixIntrinsics{
+  Option<bool> lowerToFlatMatrix{
       *this, "vector-lower-matrix-intrinsics",
       llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
       llvm::cl::init(false)};
+  Option<bool> lowerToFlatTranspose{
+      *this, "vector-flat-transpose",
+      llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
+      llvm::cl::init(false)};
   Option<bool> lowerToOuterProduct{
       *this, "vector-outerproduct",
       llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
@@ -67,10 +71,14 @@ struct TestVectorContractionConversion
       return;
     }
 
-    VectorContractLowering lowering = VectorContractLowering::FMA;
-    if (lowerToLLVMMatrixIntrinsics)
-      lowering = VectorContractLowering::Matmul;
-    VectorTransformsOptions options{lowering};
+    VectorContractLowering contractLowering = VectorContractLowering::FMA;
+    if (lowerToFlatMatrix)
+      contractLowering = VectorContractLowering::Matmul;
+    VectorTransposeLowering transposeLowering =
+        VectorTransposeLowering::EltWise;
+    if (lowerToFlatTranspose)
+      transposeLowering = VectorTransposeLowering::Flat;
+    VectorTransformsOptions options{contractLowering, transposeLowering};
     populateVectorContractLoweringPatterns(patterns, &getContext(), options);
     applyPatternsAndFoldGreedily(getFunction(), patterns);
   }