[mlir][Linalg] Make contraction vectorization use vector transfers
authorNicolas Vasilache <ntv@google.com>
Fri, 29 May 2020 17:09:55 +0000 (13:09 -0400)
committerNicolas Vasilache <ntv@google.com>
Fri, 29 May 2020 19:03:46 +0000 (15:03 -0400)
This revision replaces the load + vector.type_cast by appropriate vector transfer
operations. These play more nicely with other vector abstractions and canonicalization
patterns and lower to load/store with or without masks when appropriate.

Differential Revision: https://reviews.llvm.org/D80809

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/transform-patterns.mlir

index 8fa0aa3..7639613 100644 (file)
@@ -120,14 +120,30 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
   // Vectorize other ops as vector contraction (currently only matmul).
   LLVM_DEBUG(dbgs() << dbgPref
                     << "Rewrite linalg op as vector.contract: " << *op);
+  auto extractVectorTypeFromScalarView = [](Value v) {
+    MemRefType mt = v.getType().cast<MemRefType>();
+    return VectorType::get(mt.getShape(), mt.getElementType());
+  };
   auto linalgOp = cast<linalg::LinalgOp>(op);
-  Value a = std_load(vector_type_cast(linalgOp.getInput(0)));
-  Value b = std_load(vector_type_cast(linalgOp.getInput(1)));
-  Value memref = vector_type_cast(linalgOp.getOutputBuffer(0));
-  Value c = std_load(memref);
+  Value viewA = linalgOp.getInput(0);
+  Value viewB = linalgOp.getInput(1);
+  Value viewC = linalgOp.getOutputBuffer(0);
+  Value zero = std_constant_index(0);
+  SmallVector<Value, 4> indicesA(linalgOp.getInputShapedType(0).getRank(),
+                                 zero);
+  SmallVector<Value, 4> indicesB(linalgOp.getInputShapedType(1).getRank(),
+                                 zero);
+  SmallVector<Value, 4> indicesC(linalgOp.getOutputShapedType(0).getRank(),
+                                 zero);
+  Value a = vector_transfer_read(extractVectorTypeFromScalarView(viewA), viewA,
+                                 indicesA);
+  Value b = vector_transfer_read(extractVectorTypeFromScalarView(viewB), viewB,
+                                 indicesB);
+  Value c = vector_transfer_read(extractVectorTypeFromScalarView(viewC), viewC,
+                                 indicesC);
   Value res = vector_contract(a, b, c, linalgOp.indexing_maps(),
                               linalgOp.iterator_types());
-  std_store(res, memref);
+  vector_transfer_write(res, viewC, indicesC);
 }
 
 /// Check whether there is any interleaved use of any `values` between `firstOp`
index 4c46c74..41fa3fd 100644 (file)
@@ -106,14 +106,11 @@ func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
   return
 }
 // CHECK-LABEL: func @vectorization_test
-//       CHECK: vector.type_cast %{{.*}} : memref<8x16xf32> to memref<vector<8x16xf32>>
-//       CHECK: load %{{.*}}[] : memref<vector<8x16xf32>>
-//       CHECK: vector.type_cast %{{.*}} : memref<16x32xf32> to memref<vector<16x32xf32>>
-//       CHECK: load %{{.*}}[] : memref<vector<16x32xf32>>
-//       CHECK: vector.type_cast %{{.*}} : memref<8x32xf32> to memref<vector<8x32xf32>>
-//       CHECK: load %{{.*}}[] : memref<vector<8x32xf32>>
+//       CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x16xf32>
+//       CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<16x32xf32>
+//       CHECK: vector.transfer_read %{{.*}} : memref<8x32xf32>, vector<8x32xf32>
 //       CHECK: vector.contract {indexing_maps = [#[[mk]], #[[kn]], #[[mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
-//       CHECK: store %{{.*}}, %{{.*}}[] : memref<vector<8x32xf32>>
+//       CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<8x32xf32>
 
 func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
                          %C: memref<8x32xf32>) {