From 16e82d855a1bbedee8035a55680b9c9744479071 Mon Sep 17 00:00:00 2001 From: Ahmed Taei Date: Fri, 24 Jan 2020 16:23:16 -0800 Subject: [PATCH] [mlir] Add primitive transform pattern to rewrite linalg.fill into vector.broadcast form. Summary: This diff adds a transformation pattern to rewrite linalg.fill as broadcasting a scalar into a vector. It uses the same preconditioning as matmul (memory is contiguous). Reviewers: nicolasvasilache Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D73391 --- .../Dialect/Linalg/Transforms/LinalgTransforms.cpp | 51 ++++++++++++++-------- mlir/test/Dialect/Linalg/transform-patterns.mlir | 7 +++ .../TestLinalgTransformPatterns.td | 7 +++ 3 files changed, 47 insertions(+), 18 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp index aed392c..618fcfe 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp @@ -16,10 +16,12 @@ #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/VectorOps/VectorOps.h" #include "mlir/EDSC/Helpers.h" +#include "mlir/EDSC/Intrinsics.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include @@ -156,8 +158,8 @@ static bool isMatmul(linalg::GenericOp genericOp) { genericOp.indexing_maps() == maps && hasMultiplyAddBody(genericOp); } -// TODO(ntv): This is in fact much more general than just vectorization for -// matmul ops. +// TODO(ntv, ataei): This is in fact much more general than just vectorization +// for matmul and fill ops. 
LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) { auto linalgOp = cast(op); // All types must be static shape to go to vector. @@ -167,7 +169,7 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) { for (Type outputTensorType : linalgOp.getOutputTensorTypes()) if (!outputTensorType.cast().hasStaticShape()) return failure(); - if (isa(op)) + if (isa(op) || isa(op)) return success(); auto genericOp = dyn_cast(op); @@ -189,28 +191,41 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) { SmallVector mlir::linalg::vectorizeLinalgOp(PatternRewriter &rewriter, Operation *op) { - LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE - "]: Rewrite linalg op as vector.contract: " - << *op << ":\n"); + using edsc::intrinsics::std_load; + using edsc::intrinsics::std_store; + using vector_contract = edsc::intrinsics::ValueBuilder; + using vector_broadcast = edsc::intrinsics::ValueBuilder; + using vector_type_cast = edsc::intrinsics::ValueBuilder; assert(succeeded(vectorizeLinalgOpPrecondition(op)) && "DRR failure case must be a precondition"); - auto linalgOp = cast(op); assert(linalgOp.hasBufferSemantics() && "expected linalg op with buffer semantics"); edsc::ScopedContext scope(rewriter, op->getLoc()); - using edsc::intrinsics::std_load; - using edsc::intrinsics::std_store; - using vector_contract = edsc::intrinsics::ValueBuilder; - using vector_type_cast = edsc::intrinsics::ValueBuilder; - auto vA = std_load(vector_type_cast(linalgOp.getInput(0))); - auto vB = std_load(vector_type_cast(linalgOp.getInput(1))); - auto vectorMemRefC = vector_type_cast(linalgOp.getOutputBuffer(0)); - auto vC = std_load(vectorMemRefC); - auto vRes = vector_contract(vA, vB, vC, linalgOp.indexing_maps(), - linalgOp.iterator_types()); - std_store(vRes, vectorMemRefC); + + if (auto fillOp = dyn_cast(op)) { + // Vectorize fill as a vector.broadcast. 
+ LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE + "]: Rewrite linalg.fill as vector.broadcast: " + << *op << ":\n"); + auto dstMemrefVec = vector_type_cast(fillOp.getOutputBuffer(0)); + auto dstVec = std_load(dstMemrefVec); + auto resVec = vector_broadcast(dstVec, fillOp.value()); + std_store(resVec, dstMemrefVec); + } else { + // Vectorize other ops as vector contraction (currently only matmul). + LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE + "]: Rewrite linalg op as vector.contract: " + << *op << ":\n"); + auto vA = std_load(vector_type_cast(linalgOp.getInput(0))); + auto vB = std_load(vector_type_cast(linalgOp.getInput(1))); + auto vectorMemRefC = vector_type_cast(linalgOp.getOutputBuffer(0)); + auto vC = std_load(vectorMemRefC); + auto vRes = vector_contract(vA, vB, vC, linalgOp.indexing_maps(), + linalgOp.iterator_types()); + std_store(vRes, vectorMemRefC); + } return {}; } diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir index 5440729..520be05 100644 --- a/mlir/test/Dialect/Linalg/transform-patterns.mlir +++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir @@ -205,6 +205,13 @@ func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>, // CHECK: vector.contract {{.*}} : // vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32> +func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) { + linalg.fill(%A, %arg0) { __internal_linalg_transform__ = "VECTORIZE"} : memref<8x16xf32>, f32 + return +} +// CHECK-LABEL: func @test_vectorize_fill +// CHECK: vector.broadcast {{.*}} : f32 to vector<8x16xf32> + func @fma(%a: f32, %b: f32, %c: f32) -> f32 { %d = mulf %a, %b: f32 %e = addf %c, %d: f32 diff --git a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td index 0d43fdf5..795247a 100644 --- a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td +++ 
b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td @@ -105,6 +105,12 @@ def : Pattern<(MatmulOp:$op $_, $_, $_), HasLinalgTransformMarker<"VECTORIZE">, PreconditionVectorizeLinalgOp ]>>)]>; +def : Pattern<(FillOp:$op $_, $_), + [(VectorizeLinalgOp)], + [(Constraint, + PreconditionVectorizeLinalgOp + ]>>)]>; def : Pattern<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_), [(VectorizeLinalgOp)], [(Constraint>)]>; + //===----------------------------------------------------------------------===// // Linalg generic permutation patterns. //===----------------------------------------------------------------------===// -- 2.7.4