LogicalResult unrollTransferWriteOp(OpBuilder &builder, Operation *op,
ArrayRef<int64_t> targetShape);
+/// Options that control the vector unrolling.
+struct UnrollVectorOptions {
+ using FilterConstraintFnType = std::function<LogicalResult(Operation *op)>;
+ /// Callback function that indicates whether vector unrolling should be
+ /// attempted on the operation.
+ FilterConstraintFnType filterConstraint = nullptr;
+ UnrollVectorOptions &setFilterContraint(FilterConstraintFnType constraint) {
+ filterConstraint = constraint;
+ return *this;
+ }
+
+ using NativeShapeFnType =
+ std::function<Optional<SmallVector<int64_t, 4>>(Operation *op)>;
+ /// Function that returns the shape of the vector to unroll to for a given
+ /// operation. The unrolling is aborted if the function returns `llvm::None`.
+ NativeShapeFnType nativeShape = nullptr;
+ UnrollVectorOptions &setNativeShapeFn(NativeShapeFnType fn) {
+ nativeShape = fn;
+ return *this;
+ }
+
+ /// Set the native shape to use for unrolling.
+ UnrollVectorOptions &setNativeShape(ArrayRef<int64_t> shape) {
+ SmallVector<int64_t, 4> tsShape(shape.begin(), shape.end());
+ nativeShape = [=](Operation *) -> Optional<SmallVector<int64_t, 4>> {
+ return tsShape;
+ };
+ return *this;
+ }
+};
/// Pattern to apply `unrollSingleResultVectorOp` to a `targetShape`
/// declaratively.
template <typename OpTy>
struct UnrollVectorPattern : public OpRewritePattern<OpTy> {
using FilterConstraintType = std::function<LogicalResult(OpTy op)>;
- UnrollVectorPattern(
- ArrayRef<int64_t> targetShape, MLIRContext *context,
- FilterConstraintType constraint = [](OpTy op) { return success(); })
- : OpRewritePattern<OpTy>(context),
- targetShape(targetShape.begin(), targetShape.end()),
- filter(constraint) {}
+ UnrollVectorPattern(MLIRContext *context, UnrollVectorOptions options)
+ : OpRewritePattern<OpTy>(context), options(options) {}
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
- if (failed(filter(op)))
+ if (options.filterConstraint && failed(options.filterConstraint(op)))
return failure();
+ if (!options.nativeShape) {
+ return op.emitError("vector unrolling expects the native shape or native"
+ "shape call back function to be set");
+ }
auto unrollableVectorOp =
dyn_cast<VectorUnrollOpInterface>(op.getOperation());
if (!unrollableVectorOp)
auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
if (!maybeUnrollShape)
return failure();
- auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, targetShape);
+ Optional<SmallVector<int64_t, 4>> targetShape = options.nativeShape(op);
+ if (!targetShape)
+ return op.emitError("failed to get target shape for vector unroll");
+ auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, *targetShape);
if (!maybeShapeRatio ||
llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
return failure();
if (std::is_same<OpTy, TransferWriteOp>::value) {
- if (failed(unrollTransferWriteOp(rewriter, op, targetShape)))
+ if (failed(unrollTransferWriteOp(rewriter, op, *targetShape)))
return failure();
rewriter.eraseOp(op);
return success();
}
if (op.getOperation()->getNumResults() != 1)
return failure();
- auto resultVector = unrollSingleResultVectorOp(rewriter, op, targetShape);
+ auto resultVector = unrollSingleResultVectorOp(rewriter, op, *targetShape);
if (resultVector.size() != 1)
return failure();
rewriter.replaceOp(op, resultVector.front());
}
private:
- SmallVector<int64_t, 4> targetShape;
- FilterConstraintType filter;
+ UnrollVectorOptions options;
};
/// Split a vector.transfer operation into an unmasked fastpath and a slowpath.
--- /dev/null
+// RUN: mlir-opt %s -test-vector-unrolling-patterns=unroll-based-on-type | FileCheck %s
+
+func @vector_contract_f32(%lhs : vector<8x8xf32>, %rhs : vector<8x8xf32>,
+ %init : vector<8x8xf32>) -> vector<8x8xf32> {
+ %0 = vector.contract
+ {indexing_maps = [affine_map<(i, j, k) -> (i, k)>,
+ affine_map<(i, j, k) -> (j, k)>,
+ affine_map<(i, j, k) -> (i, j)>],
+ iterator_types = ["parallel", "parallel", "reduction"]}
+ %lhs, %rhs, %init : vector<8x8xf32>, vector<8x8xf32> into vector<8x8xf32>
+ return %0 : vector<8x8xf32>
+}
+// CHECK-LABEL: func @vector_contract_f32
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x2xf32>, vector<4x2xf32> into vector<4x4xf32>
+// CHECK: return
+
+func @vector_contract_f16(%lhs : vector<8x8xf16>, %rhs : vector<8x8xf16>,
+ %init : vector<8x8xf16>) -> vector<8x8xf16> {
+ %0 = vector.contract
+ {indexing_maps = [affine_map<(i, j, k) -> (i, k)>,
+ affine_map<(i, j, k) -> (j, k)>,
+ affine_map<(i, j, k) -> (i, j)>],
+ iterator_types = ["parallel", "parallel", "reduction"]}
+ %lhs, %rhs, %init : vector<8x8xf16>, vector<8x8xf16> into vector<8x8xf16>
+ return %0 : vector<8x8xf16>
+}
+// CHECK-LABEL: func @vector_contract_f16
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
+// CHECK: vector.contract {
+// CHECK-SAME: vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
+// CHECK: return
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
+#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
void runOnFunction() override {
OwningRewritePatternList patterns;
auto *ctx = &getContext();
- patterns.insert<UnrollVectorPattern<AddFOp>>(ArrayRef<int64_t>{2, 2}, ctx);
+ patterns.insert<UnrollVectorPattern<AddFOp>>(
+ ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2}));
patterns.insert<UnrollVectorPattern<vector::ContractionOp>>(
- ArrayRef<int64_t>{2, 2, 2}, ctx);
+ ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2, 2}));
populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
populateVectorToVectorTransformationPatterns(patterns, ctx);
applyPatternsAndFoldGreedily(getFunction(), patterns);
struct TestVectorUnrollingPatterns
: public PassWrapper<TestVectorUnrollingPatterns, FunctionPass> {
+ TestVectorUnrollingPatterns() = default;
+ TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) {}
void runOnFunction() override {
MLIRContext *ctx = &getContext();
OwningRewritePatternList patterns;
- patterns.insert<UnrollVectorPattern<AddFOp>>(ArrayRef<int64_t>{2, 2}, ctx);
- patterns.insert<UnrollVectorPattern<vector::ContractionOp>>(
- ArrayRef<int64_t>{2, 2, 2}, ctx);
+ patterns.insert<UnrollVectorPattern<AddFOp>>(
+ ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2}));
+
+ if (unrollBasedOnType) {
+ UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
+ [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
+ vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
+ SmallVector<int64_t, 4> nativeShape = {4, 4, 2};
+ if (auto floatType = contractOp.getLhsType()
+ .getElementType()
+ .dyn_cast<FloatType>()) {
+ if (floatType.getWidth() == 16) {
+ nativeShape[2] = 4;
+ }
+ }
+ return nativeShape;
+ };
+ patterns.insert<UnrollVectorPattern<vector::ContractionOp>>(
+ ctx, UnrollVectorOptions().setNativeShapeFn(nativeShapeFn));
+ } else {
+ patterns.insert<UnrollVectorPattern<vector::ContractionOp>>(
+ ctx,
+ UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2, 2}));
+ }
populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
populateVectorToVectorTransformationPatterns(patterns, ctx);
applyPatternsAndFoldGreedily(getFunction(), patterns);
}
+
+ Option<bool> unrollBasedOnType{
+ *this, "unroll-based-on-type",
+ llvm::cl::desc("Set the unroll factor based on type of the operation"),
+ llvm::cl::init(false)};
};
struct TestVectorDistributePatterns
MLIRContext *ctx = &getContext();
OwningRewritePatternList patterns;
patterns.insert<UnrollVectorPattern<vector::TransferReadOp>>(
- ArrayRef<int64_t>{2, 2}, ctx);
+ ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2}));
patterns.insert<UnrollVectorPattern<vector::TransferWriteOp>>(
- ArrayRef<int64_t>{2, 2}, ctx);
+ ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2}));
populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
populateVectorToVectorTransformationPatterns(patterns, ctx);
applyPatternsAndFoldGreedily(getFunction(), patterns);