[mlir] Add primitive transform pattern to rewrite linalg.fill into vector.broadcast...
authorAhmed Taei <ataei@google.com>
Sat, 25 Jan 2020 00:23:16 +0000 (16:23 -0800)
committerAhmed Taei <ataei@google.com>
Tue, 28 Jan 2020 19:21:56 +0000 (11:21 -0800)
Summary:
This diff adds a transformation pattern to rewrite linalg.fill as broadcasting a scalar into a vector.
It uses the same preconditioning as matmul (memory is contiguous).

Reviewers: nicolasvasilache

Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D73391

mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
mlir/test/Dialect/Linalg/transform-patterns.mlir
mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td

index aed392c..618fcfe 100644 (file)
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/VectorOps/VectorOps.h"
 #include "mlir/EDSC/Helpers.h"
+#include "mlir/EDSC/Intrinsics.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 #include <type_traits>
@@ -156,8 +158,8 @@ static bool isMatmul(linalg::GenericOp genericOp) {
          genericOp.indexing_maps() == maps && hasMultiplyAddBody(genericOp);
 }
 
-// TODO(ntv): This is in fact much more general than just vectorization for
-// matmul ops.
+// TODO(ntv, ataei): This is in fact much more general than just vectorization
+// for matmul and fill ops.
 LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
   auto linalgOp = cast<linalg::LinalgOp>(op);
   // All types must be static shape to go to vector.
@@ -167,7 +169,7 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
   for (Type outputTensorType : linalgOp.getOutputTensorTypes())
     if (!outputTensorType.cast<ShapedType>().hasStaticShape())
       return failure();
-  if (isa<linalg::MatmulOp>(op))
+  if (isa<linalg::MatmulOp>(op) || isa<linalg::FillOp>(op))
     return success();
 
   auto genericOp = dyn_cast<linalg::GenericOp>(op);
@@ -189,28 +191,41 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
 
 SmallVector<Value, 0> mlir::linalg::vectorizeLinalgOp(PatternRewriter &rewriter,
                                                       Operation *op) {
-  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
-                       "]: Rewrite linalg op as vector.contract: "
-                    << *op << ":\n");
+  using edsc::intrinsics::std_load;
+  using edsc::intrinsics::std_store;
+  using vector_contract = edsc::intrinsics::ValueBuilder<vector::ContractionOp>;
+  using vector_broadcast = edsc::intrinsics::ValueBuilder<vector::BroadcastOp>;
+  using vector_type_cast = edsc::intrinsics::ValueBuilder<vector::TypeCastOp>;
 
   assert(succeeded(vectorizeLinalgOpPrecondition(op)) &&
          "DRR failure case must be a precondition");
-
   auto linalgOp = cast<linalg::LinalgOp>(op);
   assert(linalgOp.hasBufferSemantics() &&
          "expected linalg op with buffer semantics");
   edsc::ScopedContext scope(rewriter, op->getLoc());
-  using edsc::intrinsics::std_load;
-  using edsc::intrinsics::std_store;
-  using vector_contract = edsc::intrinsics::ValueBuilder<vector::ContractionOp>;
-  using vector_type_cast = edsc::intrinsics::ValueBuilder<vector::TypeCastOp>;
-  auto vA = std_load(vector_type_cast(linalgOp.getInput(0)));
-  auto vB = std_load(vector_type_cast(linalgOp.getInput(1)));
-  auto vectorMemRefC = vector_type_cast(linalgOp.getOutputBuffer(0));
-  auto vC = std_load(vectorMemRefC);
-  auto vRes = vector_contract(vA, vB, vC, linalgOp.indexing_maps(),
-                              linalgOp.iterator_types());
-  std_store(vRes, vectorMemRefC);
+
+  if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
+    // Vectorize fill as a vector.broadcast.
+    LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
+                         "]: Rewrite linalg.fill as vector.broadcast: "
+                      << *op << ":\n");
+    auto dstMemrefVec = vector_type_cast(fillOp.getOutputBuffer(0));
+    auto dstVec = std_load(dstMemrefVec);
+    auto resVec = vector_broadcast(dstVec, fillOp.value());
+    std_store(resVec, dstMemrefVec);
+  } else {
+    // Vectorize other ops as vector contraction (currently only matmul).
+    LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
+                         "]: Rewrite linalg op as vector.contract: "
+                      << *op << ":\n");
+    auto vA = std_load(vector_type_cast(linalgOp.getInput(0)));
+    auto vB = std_load(vector_type_cast(linalgOp.getInput(1)));
+    auto vectorMemRefC = vector_type_cast(linalgOp.getOutputBuffer(0));
+    auto vC = std_load(vectorMemRefC);
+    auto vRes = vector_contract(vA, vB, vC, linalgOp.indexing_maps(),
+                                linalgOp.iterator_types());
+    std_store(vRes, vectorMemRefC);
+  }
   return {};
 }
 
index 5440729..520be05 100644 (file)
@@ -205,6 +205,13 @@ func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
 //       CHECK: vector.contract {{.*}} :
 //                vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
 
+func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) {
+  linalg.fill(%A, %arg0) { __internal_linalg_transform__ = "VECTORIZE"} :  memref<8x16xf32>, f32
+  return
+}
+// CHECK-LABEL: func @test_vectorize_fill
+//       CHECK: vector.broadcast {{.*}} : f32 to vector<8x16xf32>
+
 func @fma(%a: f32, %b: f32, %c: f32) -> f32 {
           %d = mulf %a, %b: f32
           %e = addf %c, %d: f32
index 0d43fdf..795247a 100644 (file)
@@ -105,6 +105,12 @@ def : Pattern<(MatmulOp:$op $_, $_, $_),
                 HasLinalgTransformMarker<"VECTORIZE">,
                 PreconditionVectorizeLinalgOp
                ]>>)]>;
+def : Pattern<(FillOp:$op $_, $_),
+              [(VectorizeLinalgOp)],
+              [(Constraint<And<[
+                HasLinalgTransformMarker<"VECTORIZE">,
+                PreconditionVectorizeLinalgOp
+               ]>>)]>;
 def : Pattern<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_),
               [(VectorizeLinalgOp)],
               [(Constraint<And<[
@@ -112,6 +118,7 @@ def : Pattern<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_),
                 PreconditionVectorizeLinalgOp
                ]>>)]>;
 
+
 //===----------------------------------------------------------------------===//
 // Linalg generic permutation patterns.
 //===----------------------------------------------------------------------===//