[mlir][vector] Add support for unrolling vector.fma
authorLei Zhang <antiagainst@google.com>
Tue, 16 Feb 2021 14:56:01 +0000 (09:56 -0500)
committerLei Zhang <antiagainst@google.com>
Tue, 16 Feb 2021 14:56:25 +0000 (09:56 -0500)
Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D96781

mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/test/Dialect/Vector/vector-unroll-options.mlir
mlir/test/lib/Transforms/TestVectorTransforms.cpp

index a7c1223..cf18cd8 100644 (file)
@@ -583,8 +583,10 @@ def Vector_ExtractMapOp :
 }
 
 def Vector_FMAOp :
-  Op<Vector_Dialect, "fma", [NoSideEffect,
-                             AllTypesMatch<["lhs", "rhs", "acc", "result"]>]>,
+  Op<Vector_Dialect, "fma", [
+       NoSideEffect, AllTypesMatch<["lhs", "rhs", "acc", "result"]>,
+       DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
+     ]>,
     Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyVector:$acc)>,
     Results<(outs AnyVector:$result)> {
   let summary = "vector fused multiply-add";
index 671cd86..af884f9 100644 (file)
@@ -1259,6 +1259,14 @@ AffineMap calculateImplicitMap(MapOp op) {
 AffineMap ExtractMapOp::map() { return calculateImplicitMap(*this); }
 
 //===----------------------------------------------------------------------===//
+// FmaOp
+//===----------------------------------------------------------------------===//
+
+Optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
+  return llvm::to_vector<4>(getVectorType().getShape());
+}
+
+//===----------------------------------------------------------------------===//
 // BroadcastOp
 //===----------------------------------------------------------------------===//
 
@@ -2456,8 +2464,7 @@ OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) {
 }
 
 Optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
-  auto s = getVectorType().getShape();
-  return SmallVector<int64_t, 4>{s.begin(), s.end()};
+  return llvm::to_vector<4>(getVectorType().getShape());
 }
 
 void TransferReadOp::getEffects(
index 705d4ab..581039c 100644 (file)
@@ -73,3 +73,10 @@ func @vector_contract_f16(%lhs : vector<8x8xf16>, %rhs : vector<8x8xf16>,
 //       CHECK:   vector.contract {
 //  CHECK-SAME:     vector<4x4xf16>, vector<4x4xf16> into vector<4x4xf16>
 //       CHECK:   return
+
+func @vector_fma(%a: vector<4x4xf32>, %b: vector<4x4xf32>, %c: vector<4x4xf32>) -> vector<4x4xf32> {
+  %0 = vector.fma %a, %b, %c: vector<4x4xf32>
+  return %0 : vector<4x4xf32>
+}
+//   CHECK-LABEL: func @vector_fma
+// CHECK-COUNT-4: vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<2x2xf32>
index 61b1717..8ec970f 100644 (file)
@@ -151,8 +151,9 @@ struct TestVectorUnrollingPatterns
     patterns.insert<UnrollVectorPattern>(
         ctx, UnrollVectorOptions()
                  .setNativeShape(ArrayRef<int64_t>{2, 2})
-                 .setFilterConstraint(
-                     [](Operation *op) { return success(isa<AddFOp>(op)); }));
+                 .setFilterConstraint([](Operation *op) {
+                   return success(isa<AddFOp, vector::FMAOp>(op));
+                 }));
 
     if (unrollBasedOnType) {
       UnrollVectorOptions::NativeShapeFnType nativeShapeFn =