From 9534192c3bfd861f8082843c57dfee0a7881d266 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Fri, 29 May 2020 13:09:55 -0400 Subject: [PATCH] [mlir][Linalg] Make contraction vectorization use vector transfers This revision replaces the load + vector.type_cast by appropriate vector transfer operations. These play more nicely with other vector abstractions and canonicalization patterns and lower to load/store with or without masks when appropriate. Differential Revision: https://reviews.llvm.org/D80809 --- .../Dialect/Linalg/Transforms/Vectorization.cpp | 26 +++++++++++++++++----- mlir/test/Dialect/Linalg/transform-patterns.mlir | 11 ++++----- 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index 8fa0aa3..7639613 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -120,14 +120,30 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) { // Vectorize other ops as vector contraction (currently only matmul). 
LLVM_DEBUG(dbgs() << dbgPref << "Rewrite linalg op as vector.contract: " << *op); + auto extractVectorTypeFromScalarView = [](Value v) { + MemRefType mt = v.getType().cast<MemRefType>(); + return VectorType::get(mt.getShape(), mt.getElementType()); + }; auto linalgOp = cast<linalg::LinalgOp>(op); - Value a = std_load(vector_type_cast(linalgOp.getInput(0))); - Value b = std_load(vector_type_cast(linalgOp.getInput(1))); - Value memref = vector_type_cast(linalgOp.getOutputBuffer(0)); - Value c = std_load(memref); + Value viewA = linalgOp.getInput(0); + Value viewB = linalgOp.getInput(1); + Value viewC = linalgOp.getOutputBuffer(0); + Value zero = std_constant_index(0); + SmallVector<Value, 4> indicesA(linalgOp.getInputShapedType(0).getRank(), + zero); + SmallVector<Value, 4> indicesB(linalgOp.getInputShapedType(1).getRank(), + zero); + SmallVector<Value, 4> indicesC(linalgOp.getOutputShapedType(0).getRank(), + zero); + Value a = vector_transfer_read(extractVectorTypeFromScalarView(viewA), viewA, + indicesA); + Value b = vector_transfer_read(extractVectorTypeFromScalarView(viewB), viewB, + indicesB); + Value c = vector_transfer_read(extractVectorTypeFromScalarView(viewC), viewC, + indicesC); Value res = vector_contract(a, b, c, linalgOp.indexing_maps(), linalgOp.iterator_types()); - std_store(res, memref); + vector_transfer_write(res, viewC, indicesC); } /// Check whether there is any interleaved use of any `values` between `firstOp` diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir index 4c46c74..41fa3fd 100644 --- a/mlir/test/Dialect/Linalg/transform-patterns.mlir +++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir @@ -106,14 +106,11 @@ func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>, return } // CHECK-LABEL: func @vectorization_test -// CHECK: vector.type_cast %{{.*}} : memref<8x16xf32> to memref<vector<8x16xf32>> -// CHECK: load %{{.*}}[] : memref<vector<8x16xf32>> -// CHECK: vector.type_cast %{{.*}} : memref<16x32xf32> to memref<vector<16x32xf32>> -// CHECK: load %{{.*}}[] : memref<vector<16x32xf32>> -// 
CHECK: vector.type_cast %{{.*}} : memref<8x32xf32> to memref<vector<8x32xf32>> -// CHECK: load %{{.*}}[] : memref<vector<8x32xf32>> +// CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x16xf32> +// CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<16x32xf32> +// CHECK: vector.transfer_read %{{.*}} : memref<8x32xf32>, vector<8x32xf32> // CHECK: vector.contract {indexing_maps = [#[[mk]], #[[kn]], #[[mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32> -// CHECK: store %{{.*}}, %{{.*}}[] : memref<vector<8x32xf32>> +// CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<8x32xf32> func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>, %C: memref<8x32xf32>) { -- 2.7.4