[mlir][linalg] Add vectorization transform for CopyOp
authorThomas Raoux <thomasraoux@google.com>
Wed, 22 Jul 2020 19:16:29 +0000 (12:16 -0700)
committerThomas Raoux <thomasraoux@google.com>
Wed, 22 Jul 2020 19:40:42 +0000 (12:40 -0700)
CopyOp get vectorized to vector.transfer_read followed by vector.transfer_write

Differential Revision: https://reviews.llvm.org/D83739

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/transform-patterns.mlir
mlir/test/lib/Transforms/TestLinalgTransforms.cpp

index 8e5da6a..23d89c2 100644 (file)
@@ -96,7 +96,7 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
     if (!outputTensorType.cast<ShapedType>().hasStaticShape())
       return failure();
 
-  if (isa<linalg::FillOp>(op))
+  if (isa<linalg::FillOp, linalg::CopyOp>(op))
     return success();
 
   return isContraction(op);
@@ -119,12 +119,6 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
     return;
   }
 
-  assert(succeeded(isContraction(op)) && "Expected contraction");
-
-  // Vectorize other ops as vector contraction.
-  // TODO: interface.
-  LLVM_DEBUG(dbgs() << dbgPref
-                    << "Rewrite linalg op as vector.contract: " << *op);
   // In the case of 0-D memrefs, return null and special case to scalar load or
   // store later.
   auto extractVectorTypeFromScalarView = [](Value v) {
@@ -133,6 +127,49 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
                ? VectorType()
                : VectorType::get(mt.getShape(), mt.getElementType());
   };
+
+  if (auto copyOp = dyn_cast<linalg::CopyOp>(op)) {
+    // Vectorize copy as a vector.transfer_read+vector.transfer_write.
+    LLVM_DEBUG(dbgs() << dbgPref
+                      << "Rewrite linalg.copy as vector.transfer_read + "
+                         "vector.transfer_write: "
+                      << *op);
+    Value zero = std_constant_index(0);
+    Value viewInput = copyOp.input();
+    Value viewOutput = copyOp.output();
+    Value vector;
+    if (VectorType inputType = extractVectorTypeFromScalarView(viewInput)) {
+      SmallVector<Value, 4> indicesInput(inputType.getRank(), zero);
+      if (copyOp.inputPermutation())
+        vector = vector_transfer_read(
+            extractVectorTypeFromScalarView(viewInput), viewInput, indicesInput,
+            copyOp.inputPermutation().getValue());
+      else
+        vector =
+            vector_transfer_read(extractVectorTypeFromScalarView(viewInput),
+                                 viewInput, indicesInput);
+    } else {
+      vector = std_load(viewInput).value;
+    }
+    if (VectorType outputType = extractVectorTypeFromScalarView(viewOutput)) {
+      SmallVector<Value, 4> indicesOutput(outputType.getRank(), zero);
+      if (copyOp.outputPermutation())
+        vector_transfer_write(vector, viewOutput, indicesOutput,
+                              copyOp.outputPermutation().getValue());
+      else
+        vector_transfer_write(vector, viewOutput, indicesOutput);
+    } else {
+      std_store(vector, viewOutput);
+    }
+    return;
+  }
+
+  assert(succeeded(isContraction(op)) && "Expected contraction");
+
+  // Vectorize other ops as vector contraction.
+  // TODO: interface.
+  LLVM_DEBUG(dbgs() << dbgPref
+                    << "Rewrite linalg op as vector.contract: " << *op);
   auto linalgOp = cast<linalg::LinalgOp>(op);
   Value viewA = linalgOp.getInput(0);
   Value viewB = linalgOp.getInput(1);
index 819b3b7..3f7d164 100644 (file)
@@ -152,6 +152,23 @@ func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) {
 // CHECK-LABEL: func @test_vectorize_fill
 //       CHECK: vector.broadcast {{.*}} : f32 to vector<8x16xf32>
 
+func @test_vectorize_copy(%A : memref<8x16xf32>, %B : memref<8x16xf32>) {
+  linalg.copy(%A, %B) { __internal_linalg_transform__ = "VECTORIZE"} :  memref<8x16xf32>, memref<8x16xf32>
+  return
+}
+// CHECK-LABEL: func @test_vectorize_copy
+//       CHECK: %[[V:.*]] = vector.transfer_read {{.*}} : memref<8x16xf32>, vector<8x16xf32>
+//       CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32>
+
+func @test_vectorize_copy_scalar(%A : memref<f32>, %B : memref<f32>) {
+  linalg.copy(%A, %B) { __internal_linalg_transform__ = "VECTORIZE"} :  memref<f32>, memref<f32>
+  return
+}
+// CHECK-LABEL: func @test_vectorize_copy_scalar
+//       CHECK: %[[V:.*]] = load {{.*}} : memref<f32>
+//       CHECK: store %[[V]], {{.*}} : memref<f32>
+
+
 #matmul_accesses = [
   affine_map<(m, n, k) -> (m, k)>,
   affine_map<(m, n, k) -> (k, n)>,
index 4fb378c..e356eb7 100644 (file)
@@ -144,6 +144,7 @@ static void applyPatterns(FuncOp funcOp) {
   //===--------------------------------------------------------------------===//
   patterns.insert<LinalgVectorizationPattern<MatmulOp>,
                   LinalgVectorizationPattern<FillOp>,
+                  LinalgVectorizationPattern<CopyOp>,
                   LinalgVectorizationPattern<GenericOp>>(
       ctx, LinalgMarker(Identifier::get("VECTORIZE", ctx)));