From: Nicolas Vasilache Date: Fri, 1 Nov 2019 15:29:42 +0000 (-0700) Subject: Add Linalg pattern for producer-consumer fusion X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=bd94a10c02a641e59c5ccfec143f728e13b516c2;p=platform%2Fupstream%2Fllvm.git Add Linalg pattern for producer-consumer fusion This CL adds a simple pattern for specifying producer-consumer fusion on Linalg operations. Implementing such an extension reveals some interesting properties. Since Linalg operates on a buffer abstraction, the output buffers are specified as in/out parameters to the ops. As a consequence, there are no SSA use-def chains and one cannot specify complex dag input patterns with the current infrastructure. Instead this CL uses constraints based on the existing linalg dependence analysis to focus the pattern and refine patterns based on the type of op that last wrote in a buffer. This is a very local property and is less powerful than the generic dag specification based on SSA use-def chains. This will be generalized in the future. PiperOrigin-RevId: 277931503 --- diff --git a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h index 65cb2e6..354c94c 100644 --- a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h +++ b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h @@ -22,6 +22,8 @@ #include "mlir/IR/OpDefinition.h" namespace mlir { +class FuncOp; + namespace linalg { class LinalgOp; @@ -71,6 +73,8 @@ public: enum DependenceType { RAR = 0, RAW, WAR, WAW, NumTypes }; + // Builds a linalg dependence graph for the ops of type LinalgOp under `f`. + static LinalgDependenceGraph buildDependenceGraph(Aliases &aliases, FuncOp f); LinalgDependenceGraph(Aliases &aliases, ArrayRef ops); /// Returns the X such that op -> X is a dependence of type dt. 
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h index f4eb797..fb68c0a 100644 --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -42,7 +42,6 @@ std::unique_ptr> createLowerLinalgToLoopsPass(); std::unique_ptr> createLowerLinalgToLLVMPass(); -std::unique_ptr> createLinalgTransformsPass(); } // namespace linalg } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td index aef7e76..9cc4ea3 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td @@ -33,16 +33,44 @@ class HasLinalgTransformMarker : CPred<[{ $0.getAttrOfType(kLinalgTransformMarker).getValue() == "}] # value # [{"}]>; +class IsProducedByOpOfType : + CPred<"isProducedByOpOfType<" # value # ">($0, $1)">; + //===----------------------------------------------------------------------===// -// Linalg transformation patterns. +// Linalg fusion patterns. //===----------------------------------------------------------------------===// +// +// In the future, tile sizes should be derived from op properties + machine +// model but we do not need to wait on this to start having useful patterns. +class TileAndFuseLinalgOp sizes, string value> : NativeCodeCall< + "if (failed(tileAndFuseLinalgOpAndSetMarker($_builder, $0, {" # + StrJoinInt.result # "}, \"" # value # "\")))" # + " return matchFailure();">; + +def : Pat<(MatmulOp:$consumer $A, $B, $C), + (TileAndFuseLinalgOp<[100, 150], "L1"> $consumer), + [ + (Constraint $consumer), + (Constraint> $consumer, $A), + ], + // In the buffer world there are no use-def chains or dags so benefits + // cannot be computed automatically from the length of the matched + // pattern. Instead we specify the benefit ourselves for now. 
+ // This is not expected to be a big challenge long-term because + // pattern benefits are akin to feature engineering: features should + // be learned. + (addBenefit 1)>; + +//===----------------------------------------------------------------------===// +// Linalg tiling patterns. +//===----------------------------------------------------------------------===// +// +// In the future, tile sizes should be derived from op properties + machine +// model but we do not need to wait on this to start having useful patterns. class TileLinalgOp sizes, string value> : NativeCodeCall< - "auto res = tileLinalgOperation($_builder, $0, ArrayRef{" # - StrJoinInt.result # "});" # [{ - if (!res) - return matchFailure(); - res->op.setAttr(kLinalgTransformMarker, StringAttr::get("}] # value # - [{", $0.getContext()));}]>; + "if (failed(tileLinalgOpAndSetMarker($_builder, $0, {" # + StrJoinInt.result # "}, \"" # value # "\")))" # + " return matchFailure();">; def : Pat<(MatmulOp:$op $A, $B, $C), (TileLinalgOp<[2000, 3000, 4000], "L3"> $op), diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index 0bfcdea..0401d69 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -81,8 +81,21 @@ struct FusionInfo { LinalgOp fusedProducer; }; -// Fuses producer into consumer if the producer is structurally feasible and the -// fusion would not violate dependencies. +/// Checks whether the specific `producer` is the last write to exactly the +/// whole `consumedView`. This checks structural dominance, that the dependence +/// is a RAW without any interleaved write to any piece of `consumedView`. +bool isProducerLastWriteOfView(const LinalgDependenceGraph &graph, + LinalgOp consumer, Value *consumedView, + LinalgOp producer); + +/// Checks whether fusing the specific `producer` of the `consumedView` is +/// feasible. 
This checks `producer` is the last write of `consumedView` and +/// that no interleaved dependence would be violated (RAW, WAR or WAW). +bool isFusableInto(const LinalgDependenceGraph &graph, LinalgOp consumer, + Value *consumedView, LinalgOp producer); + +/// Fuses producer into consumer if the producer is structurally feasible and +/// the fusion would not violate dependencies. /// When non-null, the optional pointer `folder` is used to call into the /// `createAndFold` builder method. If `folder` is null, the regular `create` /// method is called. diff --git a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp index 3a90e61..9e57b7b 100644 --- a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp +++ b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp @@ -86,6 +86,13 @@ Value *Aliases::find(Value *v) { } } +LinalgDependenceGraph +LinalgDependenceGraph::buildDependenceGraph(Aliases &aliases, FuncOp f) { + SmallVector linalgOps; + f.walk([&](LinalgOp op) { linalgOps.push_back(op); }); + return LinalgDependenceGraph(aliases, linalgOps); +} + LinalgDependenceGraph::LinalgDependenceGraph(Aliases &aliases, ArrayRef ops) : aliases(aliases), linalgOps(ops.begin(), ops.end()) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp index 8e7370a..8269954 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -201,19 +201,13 @@ static LinalgOp fuse(Value *producedView, LinalgOp producer, LinalgOp consumer, // Encode structural fusion safety preconditions. // Some of these will be lifted in the future with better analysis. 
-static bool isStructurallyFusableProducer(LinalgOp producer, Value *readView, +static bool isStructurallyFusableProducer(LinalgOp producer, + Value *consumedView, LinalgOp consumer) { if (producer.getNumOutputs() != 1) { LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)"); return false; } - // Must be a subview or a slice to guarantee there are loops we can fuse into. - auto subView = dyn_cast_or_null(readView->getDefiningOp()); - auto slice = dyn_cast_or_null(readView->getDefiningOp()); - if (!subView && !slice) { - LLVM_DEBUG(dbgs() << "\nNot structurally fusable (not a subview or slice)"); - return false; - } // Only fuse when the producer block dominates. DominanceInfo dom(producer.getOperation()); if (!dom.dominates(producer.getOperation()->getBlock(), @@ -226,6 +220,41 @@ static bool isStructurallyFusableProducer(LinalgOp producer, Value *readView, return true; } +bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph, + LinalgOp consumer, + Value *consumedView, + LinalgOp producer) { + // Make some simple structural checks that alleviate the need for more + // complex analyses. + if (!isStructurallyFusableProducer(producer, consumedView, consumer)) { + LLVM_DEBUG(dbgs() << "\n***Not static last write due to structure:\t" + << *producer.getOperation()); + return false; + } + // Check for any interleaved write to consumedView. + if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) { + LLVM_DEBUG(dbgs() << "\n***Not fusable due to interleaved write:\t" + << *producer.getOperation()); + return false; + } + return true; +} + +bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph, + LinalgOp consumer, Value *consumedView, + LinalgOp producer) { + if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer)) + return false; + // Check for any fusion-preventing dependence to any view read/written that + // would violate dependences. 
+ if (!graph.findCoveringDependences(producer, consumer).empty()) { + LLVM_DEBUG(dbgs() << "\n***Not fusable due to an interleaved dependence:\t" + << *producer.getOperation()); + return false; + } + return true; +} + // Only consider RAW atm. Optional mlir::linalg::fuseProducerOf( OpBuilder &b, LinalgOp consumer, unsigned consumerIdx, @@ -239,8 +268,8 @@ Optional mlir::linalg::fuseProducerOf( auto producer = cast(dependence.dependentOpView.op); // Check that the dependence is indeed on the input `consumerIdx` view. - auto *readView = dependence.indexingView; - if (consumer.getInput(consumerIdx) != readView) + auto *consumedView = dependence.indexingView; + if (consumer.getInput(consumerIdx) != consumedView) continue; // Consumer consumes this view, `isStructurallyFusableProducer` also checks @@ -252,16 +281,17 @@ Optional mlir::linalg::fuseProducerOf( << " view: " << *producedView << " output index: " << producerIdx); - // Make some simple structural checks that alleviate the need for more - // complex analyses. - if (!isStructurallyFusableProducer(producer, readView, consumer)) { - LLVM_DEBUG(dbgs() << "\n***Not fusable:\t" << *producer.getOperation()); + // Must be a subview or a slice to guarantee there are loops we can fuse + // into. + auto subView = dyn_cast_or_null(consumedView->getDefiningOp()); + auto slice = dyn_cast_or_null(consumedView->getDefiningOp()); + if (!subView && !slice) { + LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)"); continue; } - // Check for fusion-preventing write that would violate dependences. - // `view` is a producer write that cannot bypass any other write or read. - if (!graph.findCoveringDependences(producer, consumer).empty()) + // Simple fusability checks. + if (!isFusableInto(graph, consumer, consumedView, producer)) continue; // Fuse `producer` just before `consumer`. 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp index 118018b..aaa7d9d 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp @@ -19,6 +19,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" @@ -26,11 +27,81 @@ #include "mlir/Pass/Pass.h" using namespace mlir; -using mlir::linalg::LinalgOp; +using namespace mlir::linalg; // Marker used as attribute name in generated Linalg rewriting transformations. static constexpr auto kLinalgTransformMarker = "__internal_linalg_transform__"; +static LogicalResult tileLinalgOpAndSetMarker(PatternRewriter &rewriter, + Operation *op, + ArrayRef sizes, + StringRef linalgMarker) { + auto tileRes = tileLinalgOperation(rewriter, op, sizes); + if (!tileRes) + return failure(); + tileRes->op.setAttr(kLinalgTransformMarker, + rewriter.getStringAttr(linalgMarker)); + tileRes->op.getParentOfType().dump(); + return success(); +} + +static LogicalResult tileAndFuseLinalgOpAndSetMarker(PatternRewriter &rewriter, + Operation *op, + ArrayRef sizes, + StringRef linalgMarker) { + auto tileRes = tileLinalgOperation(rewriter, op, sizes); + if (!tileRes) + return failure(); + tileRes->op.setAttr(kLinalgTransformMarker, + rewriter.getStringAttr(linalgMarker)); + Aliases aliases; + auto G = LinalgDependenceGraph::buildDependenceGraph( + aliases, op->getParentOfType()); + auto fusionRes = fuseProducerOf(rewriter, tileRes->op, 0, G); + if (!fusionRes) { + // Linalg fusion requires tiled loops to even determine whether it is + // possible to fuse. As a consequence, the pattern may fail even though a + // tiled version of op has already been introduced. 
+ // So we need to remove the tiled version ourselves in case of failure. + // Another possibility is to ensure the constraints on the pattern guarantee + // that fusion will occur and just assert here. + // As we develop more complex patterns we can choose what is best. + rewriter.eraseOp(tileRes->loops[0]); + return failure(); + } + fusionRes->fusedProducer.setAttr(kLinalgTransformMarker, + rewriter.getStringAttr(linalgMarker)); + // The originalProducer can now be safely erased. This is similar to SSA-value + // use-def but in the world of buffer + structured ops. + rewriter.eraseOp(fusionRes->originalProducer); + fusionRes->fusedProducer.getParentOfType().dump(); + return success(); +} + +template +bool isProducedByOpOfType(Operation *consumerOp, Value *consumedView) { + LinalgOp consumer = dyn_cast(consumerOp); + if (!consumer) + return false; + + auto maybeConsumerIndex = consumer.getIndexOfInput(consumedView); + if (!maybeConsumerIndex) + return false; + + Aliases aliases; + auto G = LinalgDependenceGraph::buildDependenceGraph( + aliases, consumer.getParentOfType()); + for (auto dependence : G.getDependencesInto( + consumer, LinalgDependenceGraph::DependenceType::RAW)) { + auto producer = cast(dependence.dependentOpView.op); + if (!isProducerLastWriteOfView(G, consumer, consumedView, producer)) + continue; + if (isa(dependence.dependentOpView.op)) + return true; + } + return false; +} + namespace mlir { namespace linalg { namespace { @@ -58,10 +129,6 @@ void LinalgTransforms::runOnFunction() { funcOp.walk([](LinalgOp op) { op.removeAttr(kLinalgTransformMarker); }); } -std::unique_ptr> mlir::linalg::createLinalgTransformsPass() { - return std::make_unique(); -} - static PassRegistration pass("test-linalg-transform-patterns", "Test Linalg transformation patterns by applying them greedily."); diff --git a/mlir/test/Dialect/Linalg/foo.mlir b/mlir/test/Dialect/Linalg/foo.mlir new file mode 100644 index 0000000..1466217 --- /dev/null +++ 
b/mlir/test/Dialect/Linalg/foo.mlir @@ -0,0 +1,281 @@ + +```mlir {.mlir} +func @matmul(%A: memref, + %B: memref, + %C: memref) { + "s.matmul"(%A, %B, %C) : memref, + memref, + memref + return +} +``` + +```mlir {.mlir} +func @matmul(%A: memref (d0 * s1 + s0 + d1)>, %B: memref (d0 * s1 + s0 + d1)>, %C: memref (d0 * s1 + s0 + d1)>) { + ... + loop.for %i = %c0 to %6 step %c2000 { + loop.for %k = %c0_0 to %7 step %c3000 { + loop.for %k = %c0_1 to %8 step %c4000 { + %9 = affine.apply (d0) -> (d0 + 2000)(%i) + ... + %16 = "s.subview" %A[%12, %13, %c1, %14, %15, %c1] : memref (d0 * s1 + s0 + d1)> + ... + %21 = "s.subview" %B[%17, %18, %c1, %19, %20, %c1] : memref (d0 * s1 + s0 + d1)> + ... + %26 = "s.subview" %C[%22, %23, %c1, %24, %25, %c1] : memref (d0 * s1 + s0 + d1)> + "s.matmul"(%16, %21, %26) {__internal_linalg_transform__ = "L3"} : memref (d0 * s1 + s0 + d1)>, memref (d0 * s1 + s0 + d1)>, memref (d0 * s1 + s0 + d1)> + } + } + } +} +``` + +```mlir {.mlir} +func @matmul(%A: memref (d0 * s1 + s0 + d1)>, %B: memref (d0 * s1 + s0 + d1)>, %C: memref (d0 * s1 + s0 + d1)>) { + ... + loop.for %i = %c0 to %0 step %c2000 { + loop.for %j = %c0 to %2 step %c3000 { + loop.for %k = %c0 to %1 step %c4000 { + %3 = affine.apply (d0) -> (d0)(%i) + ... + %7 = "s.subview" %A[%3, %4, %c1, %5, %6, %c1] : memref (d0 * s1 + s0 + d1)> + ... + %12 = "s.subview" %B[%8, %9, %c1, %10, %11, %c1] : memref (d0 * s1 + s0 + d1)> + ... + %17 = "s.subview" %C[%13, %14, %c1, %15, %16, %c1] : memref (d0 * s1 + s0 + d1)> + ... + loop.for %ii = %c0_0 to %24 step %c200 { + loop.for %jj = %c0_1 to %25 step %c300 { + loop.for %kk = %c0_2 to %26 step %c400 { + %27 = affine.apply (d0) -> (d0 + 200)(%ii) + ... + %34 = "s.subview" %7[%30, %31, %c1, %32, %33, %c1] : memref (d0 * s1 + s0 + d1)> + ... + %39 = "s.subview" %12[%35, %36, %c1, %37, %38, %c1] : memref (d0 * s1 + s0 + d1)> + ... 
+ %44 = "s.subview" %17[%40, %41, %c1, %42, %43, %c1] : memref (d0 * s1 + s0 + d1)> + "s.matmul"(%34, %39, %44) {__internal_linalg_transform__ = "L2"} : memref (d0 * s1 + s0 + d1)>, memref (d0 * s1 + s0 + d1)>, memref (d0 * s1 + s0 + d1)> + } + } + } + } + } + } + return +} +``` + +```mlir {.mlir} +func @matmul(%A: memref (d0 * s1 + s0 + d1)>, %B: memref (d0 * s1 + s0 + d1)>, %C: memref (d0 * s1 + s0 + d1)>) { + ... + loop.for %i = %c0 to %0 step %c2000 { + loop.for %j = %c0 to %2 step %c3000 { + loop.for %k = %c0 to %1 step %c4000 { + %3 = affine.apply (d0) -> (d0 + 2000)(%i) + ... + %5 = "s.subview" %A[%D, %3, %c1, %arg5, %4, %c1] : memref (d0 * s1 + s0 + d1)> + ... + %8 = "s.subview" %B[%arg5, %6, %c1, %E, %7, %c1] : memref (d0 * s1 + s0 + d1)> + ... + %11 = "s.subview" %C[%D, %9, %c1, %E, %10, %c1] : memref (d0 * s1 + s0 + d1)> + ... + loop.for %arg6 = %c0 to %12 step %c200 { + loop.for %arg7 = %c0 to %14 step %c300 { + loop.for %arg8 = %c0 to %13 step %c400 { + %15 = affine.apply (d0) -> (d0)(%arg6) + ... + %19 = "s.subview" %5[%15, %16, %c1, %17, %18, %c1] : memref (d0 * s1 + s0 + d1)> + ... + %24 = "s.subview" %8[%20, %21, %c1, %22, %23, %c1] : memref (d0 * s1 + s0 + d1)> + ... + %29 = "s.subview" %11[%25, %26, %c1, %27, %28, %c1] : memref (d0 * s1 + s0 + d1)> + ... + loop.for %arg9 = %c0_0 to %36 step %c20 { + loop.for %B0 = %c0_1 to %37 step %c30 { + loop.for %B1 = %c0_2 to %38 step %c40 { + %39 = affine.apply (d0) -> (d0 + 20)(%arg9) + ... + %46 = "s.subview" %19[%42, %43, %c1, %44, %45, %c1] : memref (d0 * s1 + s0 + d1)> + ... + %51 = "s.subview" %24[%47, %48, %c1, %49, %50, %c1] : memref (d0 * s1 + s0 + d1)> + ... 
+ %56 = "s.subview" %29[%52, %53, %c1, %54, %55, %c1] : memref (d0 * s1 + s0 + d1)> + "s.matmul"(%46, %51, %56) {__internal_linalg_transform__ = "L1"} : memref (d0 * s1 + s0 + d1)>, memref (d0 * s1 + s0 + d1)>, memref (d0 * s1 + s0 + d1)> + } + } + } + } + } + } + } + } + } + return +} +``` + +```mlir {.mlir} +func @fusion_test(%A: memref, + %B: memref, + %C: memref, + %D: memref, + %E: memref) { + // This will not be fused as it would violate dependencies. It will get + // tiled for all levels of the memory hierarchy. + "s.matmul"(%A, %A, %C) : memref, + memref, + memref + + // This will be fused. + "s.matmul"(%A, %B, %C) : memref, + memref, + memref + + // This will not be fused or transformed at all since there are no patterns + // on it. However it will be reordered because there are no dependencies. + "s.generic" #some_generic_trait %A, %D { + ^bb(%a: f32, %b: f32) : + "s.yield" %a : f32 + } : memref, + memref + + "s.matmul"(%C, %D, %E) : memref, + memref, + memref + + return +} +``` + +```mlir {.mlir} +func @fusion_test(%A: memref (d0 * s1 + s0 + d1)>, %B: memref (d0 * s1 + s0 + d1)>, %C: memref (d0 * s1 + s0 + d1)>, %D: memref (d0 * s1 + s0 + d1)>, %E: memref (d0 * s1 + s0 + d1)>) { + ... + "s.matmul"(%A, %A, %C) : memref (d0 * s1 + s0 + d1)>, memref (d0 * s1 + s0 + d1)>, memref (d0 * s1 + s0 + d1)> + "s.generic" #some_generic_trait %A, %D { + ^bb0(%arg5: f32, %arg6: f32): // no predecessors + "s.yield" %arg5 : f32 + }: memref (d0 * s1 + s0 + d1)>, memref (d0 * s1 + s0 + d1)> + ... + loop.for %arg5 = %c0_0 to %6 step %c100 { + loop.for %arg6 = %c0_1 to %7 step %c150 { + %9 = affine.apply (d0) -> (d0 + 100)(%arg5) + ... + %14 = "s.subview" %C[%11, %12, %c1, %c0_3, %13, %c1] : memref (d0 * s1 + s0 + d1)> + ... + %18 = "s.subview" %D[%c0_5, %15, %c1, %16, %17, %c1] : memref (d0 * s1 + s0 + d1)> + ... + %23 = "s.subview" %E[%19, %20, %c1, %21, %22, %c1] : memref (d0 * s1 + s0 + d1)> + ... 
+ %25 = "s.subview" %A[%11, %12, %c1, %c0_10, %24, %c11] : memref (d0 * s1 + s0 + d1)> + %26 = "s.subview" %B[%c0_10, %24, %c11, %c0_3, %13, %c1] : memref (d0 * s1 + s0 + d1)> + %27 = "s.subview" %C[%11, %12, %c1, %c0_3, %13, %c1] : memref (d0 * s1 + s0 + d1)> + "s.matmul"(%25, %26, %27) {__internal_linalg_transform__ = "L1"} : memref (d0 * s1 + s0 + d1)>, memref (d0 * s1 + s0 + d1)>, memref (d0 * s1 + s0 + d1)> + "s.matmul"(%14, %18, %23) {__internal_linalg_transform__ = "L1"} : memref (d0 * s1 + s0 + d1)>, memref (d0 * s1 + s0 + d1)>, memref (d0 * s1 + s0 + d1)> + } + } + return +} +``` + + +```mlir {.mlir} +func @fusion_test(%A: memref (d0 * s1 + s0 + d1)>, %B: memref (d0 * s1 + s0 + d1)>, %C: memref (d0 * s1 + s0 + d1)>, %D: memref (d0 * s1 + s0 + d1)>, %E: memref (d0 * s1 + s0 + d1)>) { + ... + loop.for %arg5 = %c0_0 to %6 step %c2000 { + loop.for %arg6 = %c0_1 to %7 step %c3000 { + loop.for %arg7 = %c0_2 to %8 step %c4000 { + %11 = affine.apply (d0) -> (d0 + 2000)(%arg5) + ... + %18 = "s.subview" %A[%14, %15, %c1, %16, %17, %c1] : memref (d0 * s1 + s0 + d1)> + ... + %23 = "s.subview" %A[%19, %20, %c1, %21, %22, %c1] : memref (d0 * s1 + s0 + d1)> + ... + %28 = "s.subview" %C[%24, %25, %c1, %26, %27, %c1] : memref (d0 * s1 + s0 + d1)> + "s.matmul"(%18, %23, %28) {__internal_linalg_transform__ = "L3"} : memref (d0 * s1 + s0 + d1)>, memref (d0 * s1 + s0 + d1)>, memref (d0 * s1 + s0 + d1)> + } + } + } + ... + "s.generic" #some_generic_trait %A, %D { + ^bb0(%arg5: f32, %arg6: f32): // no predecessors + "s.yield" %arg5 : f32 + }: memref (d0 * s1 + s0 + d1)>, memref (d0 * s1 + s0 + d1)> + ... + loop.for %arg5 = %c0 to %9 step %c100 { + loop.for %arg6 = %c0 to %10 step %c150 { + %11 = affine.apply (d0) -> (d0)(%arg5) + ... + %14 = "s.subview" %C[%11, %12, %c1, %c0, %13, %c1] : memref (d0 * s1 + s0 + d1)> + ... + %18 = "s.subview" %D[%c0, %15, %c1, %16, %17, %c1] : memref (d0 * s1 + s0 + d1)> + ... 
+ %23 = "s.subview" %E[%19, %20, %c1, %21, %22, %c1] : memref (d0 * s1 + s0 + d1)> + ... + %25 = "s.subview" %A[%11, %12, %c1, %c0, %24, %c1] : memref (d0 * s1 + s0 + d1)> + %26 = "s.subview" %B[%c0, %24, %c1, %c0, %13, %c1] : memref (d0 * s1 + s0 + d1)> + %27 = "s.subview" %C[%11, %12, %c1, %c0, %13, %c1] : memref (d0 * s1 + s0 + d1)> + "s.matmul"(%25, %26, %27) {__internal_linalg_transform__ = "L1"} : memref (d0 * s1 + s0 + d1)>, memref (d0 * s1 + s0 + d1)>, memref (d0 * s1 + s0 + d1)> + "s.matmul"(%14, %18, %23) {__internal_linalg_transform__ = "L1"} : memref (d0 * s1 + s0 + d1)>, memref (d0 * s1 + s0 + d1)>, memref (d0 * s1 + s0 + d1)> + } + } + return +} +``` + + +```mlir {.mlir} +func @fusion_test(%A: memref (d0 * s1 + s0 + d1)>, %B: memref (d0 * s1 + s0 + d1)>, %C: memref (d0 * s1 + s0 + d1)>, %D: memref (d0 * s1 + s0 + d1)>, %E: memref (d0 * s1 + s0 + d1)>) { + ... + loop.for %arg5 = %c0 to %0 step %c2000 { + loop.for %arg6 = %c0 to %2 step %c3000 { + loop.for %arg7 = %c0 to %1 step %c4000 { + %5 = affine.apply (d0) -> (d0)(%arg5) + ... + %9 = "s.subview" %A[%5, %6, %c1, %7, %8, %c1] : memref (d0 * s1 + s0 + d1)> + ... + %14 = "s.subview" %A[%10, %11, %c1, %12, %13, %c1] : memref (d0 * s1 + s0 + d1)> + ... + %19 = "s.subview" %C[%15, %16, %c1, %17, %18, %c1] : memref (d0 * s1 + s0 + d1)> + "s.matmul"(%9, %14, %19) {__internal_linalg_transform__ = "L3"} : memref (d0 * s1 + s0 + d1)>, memref (d0 * s1 + s0 + d1)>, memref (d0 * s1 + s0 + d1)> + } + } + } + ... + "s.generic" #some_generic_trait %A, %D { + ^bb0(%arg5: f32, %arg6: f32): // no predecessors + "s.yield" %arg5 : f32 + }: memref (d0 * s1 + s0 + d1)>, memref (d0 * s1 + s0 + d1)> + ... + loop.for %arg5 = %c0 to %3 step %c100 { + loop.for %arg6 = %c0 to %4 step %c150 { + %5 = affine.apply (d0) -> (d0)(%arg5) + ... + %8 = "s.subview" %C[%5, %6, %c1, %c0, %7, %c1] : memref (d0 * s1 + s0 + d1)> + ... + %12 = "s.subview" %D[%c0, %9, %c1, %10, %11, %c1] : memref (d0 * s1 + s0 + d1)> + ... 
+ %17 = "s.subview" %E[%13, %14, %c1, %15, %16, %c1] : memref (d0 * s1 + s0 + d1)> + %19 = "s.subview" %A[%5, %6, %c1, %c0, %18, %c1] : memref (d0 * s1 + s0 + d1)> + %20 = "s.subview" %B[%c0, %18, %c1, %c0, %7, %c1] : memref (d0 * s1 + s0 + d1)> + %21 = "s.subview" %C[%5, %6, %c1, %c0, %7, %c1] : memref (d0 * s1 + s0 + d1)> + "s.matmul"(%19, %20, %21) {__internal_linalg_transform__ = "L1"} : memref (d0 * s1 + s0 + d1)>, memref (d0 * s1 + s0 + d1)>, memref (d0 * s1 + s0 + d1)> + ... + loop.for %arg7 = %c0_0 to %28 step %c2 { + loop.for %arg8 = %c0_1 to %29 step %c3 { + loop.for %arg9 = %c0_2 to %30 step %c4 { + %31 = affine.apply (d0) -> (d0 + 2)(%arg7) + ... + %38 = "s.subview" %8[%34, %35, %c1, %36, %37, %c1] : memref (d0 * s1 + s0 + d1)> + ... + %43 = "s.subview" %12[%39, %40, %c1, %41, %42, %c1] : memref (d0 * s1 + s0 + d1)> + ... + %48 = "s.subview" %17[%44, %45, %c1, %46, %47, %c1] : memref (d0 * s1 + s0 + d1)> + "s.matmul"(%38, %43, %48) {__internal_linalg_transform__ = "REG"} : memref (d0 * s1 + s0 + d1)>, memref (d0 * s1 + s0 + d1)>, memref (d0 * s1 + s0 + d1)> + } + } + } + } + } + return +} +``` diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir index b7c0924..2561cea 100644 --- a/mlir/test/Dialect/Linalg/transform-patterns.mlir +++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir @@ -70,3 +70,85 @@ func @matmul(%A: memref, // CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c3]] { // CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c4]] { // CHECK : linalg.matmul({{.*}}, {{.*}}, {{.*}}) : memref, memref, memref + +#some_generic_trait = { + indexing_maps = [ + (i, j) -> (i, j), + (i, j) -> (i, j) + ], + n_views = [1, 1], + n_loop_types = [2, 0, 0] +} +func @fusion_test(%A: memref, + %B: memref, + %C: memref, + %D: memref, + %E: memref) { + // This should not be fused as it would violate dependencies. It will get + // tiled for all levels of the memory hierarchy. 
+ linalg.matmul(%A, %A, %C) : memref, + memref, + memref + + // This should be fused. + linalg.matmul(%A, %B, %C) : memref, + memref, + memref + + // This should not be fused or transformed at all since there are no patterns + // on it. However it will be reordered because there are no dependencies. + linalg.generic #some_generic_trait %A, %D { + ^bb(%a: f32, %b: f32) : + linalg.yield %a : f32 + } : memref, + memref + + linalg.matmul(%C, %D, %E) : memref, + memref, + memref + + return +} +// CHECK-LABEL: func @fusion_test +// CHECK-DAG : %[[c0:.*]] = constant 0 : index +// CHECK-DAG : %[[c2:.*]] = constant 2 : index +// CHECK-DAG : %[[c3:.*]] = constant 3 : index +// CHECK-DAG : %[[c4:.*]] = constant 4 : index +// CHECK-DAG : %[[c20:.*]] = constant 20 : index +// CHECK-DAG : %[[c30:.*]] = constant 30 : index +// CHECK-DAG : %[[c40:.*]] = constant 40 : index +// CHECK-DAG : %[[c100:.*]] = constant 100 : index +// CHECK-DAG : %[[c150:.*]] = constant 150 : index +// CHECK-DAG : %[[c200:.*]] = constant 200 : index +// CHECK-DAG : %[[c300:.*]] = constant 300 : index +// CHECK-DAG : %[[c400:.*]] = constant 400 : index +// CHECK-DAG : %[[c2000:.*]] = constant 2000 : index +// CHECK-DAG : %[[c3000:.*]] = constant 3000 : index +// CHECK-DAG : %[[c4000:.*]] = constant 4000 : index +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c2000]] { +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c3000]] { +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c4000]] { +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c200]] { +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c300]] { +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c400]] { +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c20]] { +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c30]] { +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c40]] { +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c2]] { +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step 
%[[c3]] { +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c4]] { +// CHECK : linalg.matmul({{.*}}, {{.*}}, {{.*}}) : memref, memref, memref +// +// CHECK : linalg.generic +// +// CHECK : loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c100]] { +// CHECK : loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c150]] { +// CHECK : loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c2]] { +// CHECK : loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c3]] { +// CHECK : loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c4]] { +// CHECK : linalg.matmul(%{{.*}}, %{{.*}}, %{{.*}}) : memref, memref, memref +// CHECK : loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c2]] { +// CHECK : loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c3]] { +// CHECK : loop.for %{{.*}} = %[[c0]] to %{{.*}} step %[[c4]] { +// CHECK : linalg.matmul(%{{.*}}, %{{.*}}, %{{.*}}) : memref, memref, memref +