From b6204b995eaa2ec771f947a2109bd2ef338e688c Mon Sep 17 00:00:00 2001 From: MaheshRavishankar Date: Fri, 23 Oct 2020 12:56:12 -0700 Subject: [PATCH] [mlir][Vector] Introduce UnrollVectorOptions to control vector unrolling. The current pattern for vector unrolling takes the native shape to unroll to at pattern instantiation time, but the native shape might defer based on the types of the operand. Introduce a UnrollVectorOptions struct which allows for using a function that will return the native shape based on the operation. Move other options of unrolling like `filterConstraints` into this struct. Differential Revision: https://reviews.llvm.org/D89744 --- .../include/mlir/Dialect/Vector/VectorTransforms.h | 56 ++++++++++++---- .../test/Dialect/Vector/vector-unroll-options.mlir | 75 ++++++++++++++++++++++ mlir/test/lib/Transforms/TestVectorTransforms.cpp | 44 +++++++++++-- 3 files changed, 156 insertions(+), 19 deletions(-) create mode 100644 mlir/test/Dialect/Vector/vector-unroll-options.mlir diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h index 157084a..a1cf90c 100644 --- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h +++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h @@ -85,21 +85,51 @@ SmallVector unrollSingleResultVectorOp(OpBuilder &builder, LogicalResult unrollTransferWriteOp(OpBuilder &builder, Operation *op, ArrayRef targetShape); +/// Options that control the vector unrolling. +struct UnrollVectorOptions { + using FilterConstraintFnType = std::function; + /// Callback function that indicates whether vector unrolling should be + /// attempted on the operation. + FilterConstraintFnType filterConstraint = nullptr; + UnrollVectorOptions &setFilterContraint(FilterConstraintFnType constraint) { + filterConstraint = constraint; + return *this; + } + + using NativeShapeFnType = + std::function>(Operation *op)>; + /// Function that returns the shape of the vector to unroll to for a given + /// operation. The unrolling is aborted if the function returns `llvm::None`. + NativeShapeFnType nativeShape = nullptr; + UnrollVectorOptions &setNativeShapeFn(NativeShapeFnType fn) { + nativeShape = fn; + return *this; + } + + /// Set the native shape to use for unrolling. + UnrollVectorOptions &setNativeShape(ArrayRef shape) { + SmallVector tsShape(shape.begin(), shape.end()); + nativeShape = [=](Operation *) -> Optional> { + return tsShape; + }; + return *this; + } +}; /// Pattern to apply `unrollSingleResultVectorOp` to a `targetShape` /// declaratively. template struct UnrollVectorPattern : public OpRewritePattern { using FilterConstraintType = std::function; - UnrollVectorPattern( - ArrayRef targetShape, MLIRContext *context, - FilterConstraintType constraint = [](OpTy op) { return success(); }) - : OpRewritePattern(context), - targetShape(targetShape.begin(), targetShape.end()), - filter(constraint) {} + UnrollVectorPattern(MLIRContext *context, UnrollVectorOptions options) + : OpRewritePattern(context), options(options) {} LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { - if (failed(filter(op))) + if (options.filterConstraint && failed(options.filterConstraint(op))) return failure(); + if (!options.nativeShape) { + return op.emitError("vector unrolling expects the native shape or native" + "shape call back function to be set"); + } auto unrollableVectorOp = dyn_cast(op.getOperation()); if (!unrollableVectorOp) @@ -107,19 +137,22 @@ struct UnrollVectorPattern : public OpRewritePattern { auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll(); if (!maybeUnrollShape) return failure(); - auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, targetShape); + Optional> targetShape = options.nativeShape(op); + if (!targetShape) + return op.emitError("failed to get target shape for vector unroll"); + auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, *targetShape); if (!maybeShapeRatio || llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) return failure(); if (std::is_same::value) { - if (failed(unrollTransferWriteOp(rewriter, op, targetShape))) + if (failed(unrollTransferWriteOp(rewriter, op, *targetShape))) return failure(); rewriter.eraseOp(op); return success(); } if (op.getOperation()->getNumResults() != 1) return failure(); - auto resultVector = unrollSingleResultVectorOp(rewriter, op, targetShape); + auto resultVector = unrollSingleResultVectorOp(rewriter, op, *targetShape); if (resultVector.size() != 1) return failure(); rewriter.replaceOp(op, resultVector.front()); @@ -127,8 +160,7 @@ struct UnrollVectorPattern : public OpRewritePattern { } private: - SmallVector targetShape; - FilterConstraintType filter; + UnrollVectorOptions options; }; /// Split a vector.transfer operation into an unmasked fastpath and a slowpath. diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir new file mode 100644 index 0000000..705d4ab --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir @@ -0,0 +1,75 @@ +// RUN: mlir-opt %s -test-vector-unrolling-patterns=unroll-based-on-type | FileCheck %s + +func @vector_contract_f32(%lhs : vector<8x8xf32>, %rhs : vector<8x8xf32>, + %init : vector<8x8xf32>) -> vector<8x8xf32> { + %0 = vector.contract + {indexing_maps = [affine_map<(i, j, k) -> (i, k)>, + affine_map<(i, j, k) -> (j, k)>, + affine_map<(i, j, k) -> (i, j)>], + iterator_types = ["parallel", "parallel", "reduction"]} + %lhs, %rhs, %init : vector<8x8xf32>, vector<8x8xf32> into vector<8x8xf32> + return %0 : vector<8x8xf32> +} +// CHECK-LABEL: func @vector_contract_f32 +// CHECK: vector.contract { +// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32> +// CHECK: vector.contract { +// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32> +// CHECK: vector.contract { +// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32> +// CHECK: vector.contract { +// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32> +// CHECK: vector.contract { +// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32> +// CHECK: vector.contract { +// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32> +// CHECK: vector.contract { +// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32> +// CHECK: vector.contract { +// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32> +// CHECK: vector.contract { +// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32> +// CHECK: vector.contract { +// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32> +// CHECK: vector.contract { +// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32> +// CHECK: vector.contract { +// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32> +// CHECK: vector.contract { +// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32> +// CHECK: vector.contract { +// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32> +// CHECK: vector.contract { +// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32> +// CHECK: vector.contract { +// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32> +// CHECK: return + +func @vector_contract_f16(%lhs : vector<8x8xf16>, %rhs : vector<8x8xf16>, + %init : vector<8x8xf16>) -> vector<8x8xf16> { + %0 = vector.contract + {indexing_maps = [affine_map<(i, j, k) -> (i, k)>, + affine_map<(i, j, k) -> (j, k)>, + affine_map<(i, j, k) -> (i, j)>], + iterator_types = ["parallel", "parallel", "reduction"]} + %lhs, %rhs, %init : vector<8x8xf16>, vector<8x8xf16> into vector<8x8xf16> + return %0 : vector<8x8xf16> +} +// CHECK-LABEL: func @vector_contract_f16 +// CHECK: vector.contract { +// CHECK-SAME: vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16> +// CHECK: vector.contract { +// CHECK-SAME: vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16> +// CHECK: vector.contract { +// CHECK-SAME: vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16> +// CHECK: vector.contract { +// CHECK-SAME: vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16> +// CHECK: vector.contract { +// CHECK-SAME: vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16> +// CHECK: vector.contract { +// CHECK-SAME: vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16> +// CHECK: vector.contract { +// CHECK-SAME: vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16> +// CHECK: vector.contract { +// CHECK-SAME: vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16> +// CHECK: return diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp index 52d0f7b..5369ab5 100644 --- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp +++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/Dialect/Vector/VectorTransforms.h" +#include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" @@ -26,9 +27,10 @@ struct TestVectorToVectorConversion void runOnFunction() override { OwningRewritePatternList patterns; auto *ctx = &getContext(); - patterns.insert>(ArrayRef{2, 2}, ctx); + patterns.insert>( + ctx, UnrollVectorOptions().setNativeShape(ArrayRef{2, 2})); patterns.insert>( - ArrayRef{2, 2, 2}, ctx); + ctx, UnrollVectorOptions().setNativeShape(ArrayRef{2, 2, 2})); populateVectorToVectorCanonicalizationPatterns(patterns, ctx); populateVectorToVectorTransformationPatterns(patterns, ctx); applyPatternsAndFoldGreedily(getFunction(), patterns); @@ -113,16 +115,44 @@ struct TestVectorContractionConversion struct TestVectorUnrollingPatterns : public PassWrapper { + TestVectorUnrollingPatterns() = default; + TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) {} void runOnFunction() override { MLIRContext *ctx = &getContext(); OwningRewritePatternList patterns; - patterns.insert>(ArrayRef{2, 2}, ctx); - patterns.insert>( - ArrayRef{2, 2, 2}, ctx); + patterns.insert>( + ctx, UnrollVectorOptions().setNativeShape(ArrayRef{2, 2})); + + if (unrollBasedOnType) { + UnrollVectorOptions::NativeShapeFnType nativeShapeFn = + [](Operation *op) -> Optional> { + vector::ContractionOp contractOp = cast(op); + SmallVector nativeShape = {4, 4, 2}; + if (auto floatType = contractOp.getLhsType() + .getElementType() + .dyn_cast()) { + if (floatType.getWidth() == 16) { + nativeShape[2] = 4; + } + } + return nativeShape; + }; + patterns.insert>( + ctx, UnrollVectorOptions().setNativeShapeFn(nativeShapeFn)); + } else { + patterns.insert>( + ctx, + UnrollVectorOptions().setNativeShape(ArrayRef{2, 2, 2})); + } populateVectorToVectorCanonicalizationPatterns(patterns, ctx); populateVectorToVectorTransformationPatterns(patterns, ctx); applyPatternsAndFoldGreedily(getFunction(), patterns); } + + Option unrollBasedOnType{ + *this, "unroll-based-on-type", + llvm::cl::desc("Set the unroll factor based on type of the operation"), + llvm::cl::init(false)}; }; struct TestVectorDistributePatterns @@ -165,9 +195,9 @@ struct TestVectorTransferUnrollingPatterns MLIRContext *ctx = &getContext(); OwningRewritePatternList patterns; patterns.insert>( - ArrayRef{2, 2}, ctx); + ctx, UnrollVectorOptions().setNativeShape(ArrayRef{2, 2})); patterns.insert>( - ArrayRef{2, 2}, ctx); + ctx, UnrollVectorOptions().setNativeShape(ArrayRef{2, 2})); populateVectorToVectorCanonicalizationPatterns(patterns, ctx); populateVectorToVectorTransformationPatterns(patterns, ctx); applyPatternsAndFoldGreedily(getFunction(), patterns); -- 2.7.4