[mlir][Vector] Introduce UnrollVectorOptions to control vector unrolling.
authorMaheshRavishankar <ravishankarm@google.com>
Fri, 23 Oct 2020 19:56:12 +0000 (12:56 -0700)
committerMaheshRavishankar <ravishankarm@google.com>
Fri, 23 Oct 2020 20:52:26 +0000 (13:52 -0700)
The current pattern for vector unrolling takes the native shape to
unroll to at pattern instantiation time, but the native shape might
defer based on the types of the operand. Introduce a
UnrollVectorOptions struct which allows for using a function that will
return the native shape based on the operation. Move other options of
unrolling like `filterConstraints` into this struct.

Differential Revision: https://reviews.llvm.org/D89744

mlir/include/mlir/Dialect/Vector/VectorTransforms.h
mlir/test/Dialect/Vector/vector-unroll-options.mlir [new file with mode: 0644]
mlir/test/lib/Transforms/TestVectorTransforms.cpp

index 157084a..a1cf90c 100644 (file)
@@ -85,21 +85,51 @@ SmallVector<Value, 1> unrollSingleResultVectorOp(OpBuilder &builder,
 LogicalResult unrollTransferWriteOp(OpBuilder &builder, Operation *op,
                                     ArrayRef<int64_t> targetShape);
 
+/// Options that control the vector unrolling.
+struct UnrollVectorOptions {
+  using FilterConstraintFnType = std::function<LogicalResult(Operation *op)>;
+  /// Callback function that indicates whether vector unrolling should be
+  /// attempted on the operation.
+  FilterConstraintFnType filterConstraint = nullptr;
+  UnrollVectorOptions &setFilterContraint(FilterConstraintFnType constraint) {
+    filterConstraint = constraint;
+    return *this;
+  }
+
+  using NativeShapeFnType =
+      std::function<Optional<SmallVector<int64_t, 4>>(Operation *op)>;
+  /// Function that returns the shape of the vector to unroll to for a given
+  /// operation. The unrolling is aborted if the function returns `llvm::None`.
+  NativeShapeFnType nativeShape = nullptr;
+  UnrollVectorOptions &setNativeShapeFn(NativeShapeFnType fn) {
+    nativeShape = fn;
+    return *this;
+  }
+
+  /// Set the native shape to use for unrolling.
+  UnrollVectorOptions &setNativeShape(ArrayRef<int64_t> shape) {
+    SmallVector<int64_t, 4> tsShape(shape.begin(), shape.end());
+    nativeShape = [=](Operation *) -> Optional<SmallVector<int64_t, 4>> {
+      return tsShape;
+    };
+    return *this;
+  }
+};
 /// Pattern to apply `unrollSingleResultVectorOp` to a `targetShape`
 /// declaratively.
 template <typename OpTy>
 struct UnrollVectorPattern : public OpRewritePattern<OpTy> {
   using FilterConstraintType = std::function<LogicalResult(OpTy op)>;
-  UnrollVectorPattern(
-      ArrayRef<int64_t> targetShape, MLIRContext *context,
-      FilterConstraintType constraint = [](OpTy op) { return success(); })
-      : OpRewritePattern<OpTy>(context),
-        targetShape(targetShape.begin(), targetShape.end()),
-        filter(constraint) {}
+  UnrollVectorPattern(MLIRContext *context, UnrollVectorOptions options)
+      : OpRewritePattern<OpTy>(context), options(options) {}
   LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
-    if (failed(filter(op)))
+    if (options.filterConstraint && failed(options.filterConstraint(op)))
       return failure();
+    if (!options.nativeShape) {
+      return op.emitError("vector unrolling expects the native shape or native"
+                          "shape call back function to be set");
+    }
     auto unrollableVectorOp =
         dyn_cast<VectorUnrollOpInterface>(op.getOperation());
     if (!unrollableVectorOp)
@@ -107,19 +137,22 @@ struct UnrollVectorPattern : public OpRewritePattern<OpTy> {
     auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
     if (!maybeUnrollShape)
       return failure();
-    auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, targetShape);
+    Optional<SmallVector<int64_t, 4>> targetShape = options.nativeShape(op);
+    if (!targetShape)
+      return op.emitError("failed to get target shape for vector unroll");
+    auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, *targetShape);
     if (!maybeShapeRatio ||
         llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
       return failure();
     if (std::is_same<OpTy, TransferWriteOp>::value) {
-      if (failed(unrollTransferWriteOp(rewriter, op, targetShape)))
+      if (failed(unrollTransferWriteOp(rewriter, op, *targetShape)))
         return failure();
       rewriter.eraseOp(op);
       return success();
     }
     if (op.getOperation()->getNumResults() != 1)
       return failure();
-    auto resultVector = unrollSingleResultVectorOp(rewriter, op, targetShape);
+    auto resultVector = unrollSingleResultVectorOp(rewriter, op, *targetShape);
     if (resultVector.size() != 1)
       return failure();
     rewriter.replaceOp(op, resultVector.front());
@@ -127,8 +160,7 @@ struct UnrollVectorPattern : public OpRewritePattern<OpTy> {
   }
 
 private:
-  SmallVector<int64_t, 4> targetShape;
-  FilterConstraintType filter;
+  UnrollVectorOptions options;
 };
 
 /// Split a vector.transfer operation into an unmasked fastpath and a slowpath.
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
new file mode 100644 (file)
index 0000000..705d4ab
--- /dev/null
@@ -0,0 +1,75 @@
+// RUN: mlir-opt %s -test-vector-unrolling-patterns=unroll-based-on-type | FileCheck %s
+
+func @vector_contract_f32(%lhs : vector<8x8xf32>, %rhs : vector<8x8xf32>,
+                          %init : vector<8x8xf32>) -> vector<8x8xf32> {
+  %0 = vector.contract
+         {indexing_maps = [affine_map<(i, j, k) -> (i, k)>,
+                           affine_map<(i, j, k) -> (j, k)>,
+                           affine_map<(i, j, k) -> (i, j)>],
+          iterator_types = ["parallel", "parallel", "reduction"]}
+       %lhs, %rhs, %init : vector<8x8xf32>, vector<8x8xf32> into vector<8x8xf32>
+  return %0 : vector<8x8xf32>
+}
+// CHECK-LABEL: func @vector_contract_f32
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+//       CHECK:   return
+
+func @vector_contract_f16(%lhs : vector<8x8xf16>, %rhs : vector<8x8xf16>,
+                          %init : vector<8x8xf16>) -> vector<8x8xf16> {
+  %0 = vector.contract
+         {indexing_maps = [affine_map<(i, j, k) -> (i, k)>,
+                           affine_map<(i, j, k) -> (j, k)>,
+                           affine_map<(i, j, k) -> (i, j)>],
+          iterator_types = ["parallel", "parallel", "reduction"]}
+       %lhs, %rhs, %init : vector<8x8xf16>, vector<8x8xf16> into vector<8x8xf16>
+  return %0 : vector<8x8xf16>
+}
+// CHECK-LABEL: func @vector_contract_f16
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
+//       CHECK:   vector.contract {
+//  CHECK-SAME:     vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
+//       CHECK:   return
index 52d0f7b..5369ab5 100644 (file)
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/Dialect/Vector/VectorTransforms.h"
+#include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 
@@ -26,9 +27,10 @@ struct TestVectorToVectorConversion
   void runOnFunction() override {
     OwningRewritePatternList patterns;
     auto *ctx = &getContext();
-    patterns.insert<UnrollVectorPattern<AddFOp>>(ArrayRef<int64_t>{2, 2}, ctx);
+    patterns.insert<UnrollVectorPattern<AddFOp>>(
+        ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2}));
     patterns.insert<UnrollVectorPattern<vector::ContractionOp>>(
-        ArrayRef<int64_t>{2, 2, 2}, ctx);
+        ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2, 2}));
     populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
     populateVectorToVectorTransformationPatterns(patterns, ctx);
     applyPatternsAndFoldGreedily(getFunction(), patterns);
@@ -113,16 +115,44 @@ struct TestVectorContractionConversion
 
 struct TestVectorUnrollingPatterns
     : public PassWrapper<TestVectorUnrollingPatterns, FunctionPass> {
+  TestVectorUnrollingPatterns() = default;
+  TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) {}
   void runOnFunction() override {
     MLIRContext *ctx = &getContext();
     OwningRewritePatternList patterns;
-    patterns.insert<UnrollVectorPattern<AddFOp>>(ArrayRef<int64_t>{2, 2}, ctx);
-    patterns.insert<UnrollVectorPattern<vector::ContractionOp>>(
-        ArrayRef<int64_t>{2, 2, 2}, ctx);
+    patterns.insert<UnrollVectorPattern<AddFOp>>(
+        ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2}));
+
+    if (unrollBasedOnType) {
+      UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
+          [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
+        vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
+        SmallVector<int64_t, 4> nativeShape = {4, 4, 2};
+        if (auto floatType = contractOp.getLhsType()
+                                 .getElementType()
+                                 .dyn_cast<FloatType>()) {
+          if (floatType.getWidth() == 16) {
+            nativeShape[2] = 4;
+          }
+        }
+        return nativeShape;
+      };
+      patterns.insert<UnrollVectorPattern<vector::ContractionOp>>(
+          ctx, UnrollVectorOptions().setNativeShapeFn(nativeShapeFn));
+    } else {
+      patterns.insert<UnrollVectorPattern<vector::ContractionOp>>(
+          ctx,
+          UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2, 2}));
+    }
     populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
     populateVectorToVectorTransformationPatterns(patterns, ctx);
     applyPatternsAndFoldGreedily(getFunction(), patterns);
   }
+
+  Option<bool> unrollBasedOnType{
+      *this, "unroll-based-on-type",
+      llvm::cl::desc("Set the unroll factor based on type of the operation"),
+      llvm::cl::init(false)};
 };
 
 struct TestVectorDistributePatterns
@@ -165,9 +195,9 @@ struct TestVectorTransferUnrollingPatterns
     MLIRContext *ctx = &getContext();
     OwningRewritePatternList patterns;
     patterns.insert<UnrollVectorPattern<vector::TransferReadOp>>(
-        ArrayRef<int64_t>{2, 2}, ctx);
+        ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2}));
     patterns.insert<UnrollVectorPattern<vector::TransferWriteOp>>(
-        ArrayRef<int64_t>{2, 2}, ctx);
+        ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2}));
     populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
     populateVectorToVectorTransformationPatterns(patterns, ctx);
     applyPatternsAndFoldGreedily(getFunction(), patterns);