From 05a5a4141648218db2440b4e3a355398ef822111 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Wed, 30 Oct 2019 07:12:07 -0700 Subject: [PATCH] Add basic support for declarative Linalg transformations Linalg ops provide a good anchor for pattern matching/rewriting transformations. This CL adds a simple example of how multi-level tiling may be specified by attaching a simple StringAttr to ops as they are transformed so we can easily specify partial lowering to control transformation application. This is a first stab at taking advantage of higher-level information contained in Linalg ops and will evolve in the future. PiperOrigin-RevId: 277497958 --- mlir/include/mlir/Dialect/Linalg/CMakeLists.txt | 1 + .../mlir/Dialect/Linalg/IR/LinalgLibraryOps.td | 7 ++ mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td | 7 ++ mlir/include/mlir/Dialect/Linalg/Passes.h | 2 + .../mlir/Dialect/Linalg/Transforms/CMakeLists.txt | 3 + .../Linalg/Transforms/LinalgTransformPatterns.td | 76 ++++++++++++++++++++++ mlir/lib/Dialect/Linalg/CMakeLists.txt | 2 + .../Dialect/Linalg/Transforms/LinalgTransforms.cpp | 67 +++++++++++++++++++ mlir/test/Dialect/Linalg/transform-patterns.mlir | 72 ++++++++++++++++++++ 9 files changed, 237 insertions(+) create mode 100644 mlir/include/mlir/Dialect/Linalg/Transforms/CMakeLists.txt create mode 100644 mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td create mode 100644 mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp create mode 100644 mlir/test/Dialect/Linalg/transform-patterns.mlir diff --git a/mlir/include/mlir/Dialect/Linalg/CMakeLists.txt b/mlir/include/mlir/Dialect/Linalg/CMakeLists.txt index f33061b..9f57627 100644 --- a/mlir/include/mlir/Dialect/Linalg/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Linalg/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td index 1e6384c..a4a76cc 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td @@ -24,8 +24,15 @@ #else #define LINALG_LIBRARY_OPS +#ifdef AFFINE_OPS_BASE +#else include "mlir/Dialect/AffineOps/AffineOpsBase.td" +#endif // AFFINE_OPS_BASE + +#ifdef LINALG_BASE +#else include "mlir/Dialect/Linalg/IR/LinalgBase.td" +#endif // LINALG_BASE class LinalgParametricNativeOpTrait : NativeOpTrait<"linalg::" # prop # parameters> diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td index 38e0cb6..b865b5f 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -23,8 +23,15 @@ #else #define LINALG_OPS +#ifdef AFFINE_OPS_BASE +#else include "mlir/Dialect/AffineOps/AffineOpsBase.td" +#endif // AFFINE_OPS_BASE + +#ifdef LINALG_BASE +#else include "mlir/Dialect/Linalg/IR/LinalgBase.td" +#endif // LINALG_BASE // Base class for Linalg dialect ops that do not correspond to library calls. class Linalg_Op traits = []> : diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h index 1efb7d34..f4eb797 100644 --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -41,6 +41,8 @@ std::unique_ptr> createLinalgPromotionPass(); std::unique_ptr> createLowerLinalgToLoopsPass(); std::unique_ptr> createLowerLinalgToLLVMPass(); + +std::unique_ptr> createLinalgTransformsPass(); } // namespace linalg } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Linalg/Transforms/CMakeLists.txt new file mode 100644 index 0000000..f87938c --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS LinalgTransformPatterns.td) +mlir_tablegen(LinalgTransformPatterns.h.inc -gen-rewriters) +add_public_tablegen_target(MLIRLinalgTransformPatternsIncGen) diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td new file mode 100644 index 0000000..fdb4677 --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td @@ -0,0 +1,76 @@ +//===- LinalgPatterns.td - Linalg transformation patterns --*- tablegen -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This is the pattern definition file for declarative Linalg transformation. +// +//===----------------------------------------------------------------------===// + +#ifdef LINALG_TRANSFORMS +#else +#define LINALG_TRANSFORMS + +include "mlir/Dialect/Linalg/IR/LinalgOps.td" +include "mlir/Dialect/Linalg/IR/LinalgLibraryOps.td" + +def HasNoLinalgTransformMarker : CPred<[{ + !$0.getAttrOfType(kLinalgTransformMarker) +}]>; + +class HasLinalgTransformMarker : CPred<[{ + $0.getAttrOfType(kLinalgTransformMarker).getValue() == "}] # + value # [{"}]>; + +//===----------------------------------------------------------------------===// +// Linalg transformation patterns. +//===----------------------------------------------------------------------===// +class TileLinalgOp sizes, string value> : NativeCodeCall< + "auto res = tileLinalgOperation($_builder, $0, ArrayRef{" # + StrJoinInt.result # "});" # [{ + if (!res) + return matchFailure(); + res->op.setAttr(kLinalgTransformMarker, StringAttr::get("}] # value # + [{", $0.getContext()));}]>; + +def : Pat<(MatmulOp:$op $A, $B, $C), + (TileLinalgOp<[2000, 3000, 4000], "L3"> $op), + [(Constraint]>> $op)]>; +def : Pat<(MatmulOp:$op $A, $B, $C), + (TileLinalgOp<[200, 300, 400], "L2"> $op), + [(Constraint> $op)]>; +def : Pat<(MatmulOp:$op $A, $B, $C), + (TileLinalgOp<[20, 30, 40], "L1"> $op), + [(Constraint> $op)]>; +def : Pat<(MatmulOp:$op $A, $B, $C), + (TileLinalgOp<[2, 3, 4], "REG"> $op), + [(Constraint> $op)]>; + +def : Pattern<(MatvecOp:$op $A, $b, $c), + [(TileLinalgOp<[5, 6], "L1"> $op)], + [(Constraint $op)]>; + +def : Pattern<(DotOp:$op $a, $b, $c), + [(TileLinalgOp<[8000], "L1"> $op)], + [(Constraint, + HasLinalgTransformMarker<"L3">, + HasLinalgTransformMarker<"L2">]>> $op)]>; +def : Pattern<(DotOp:$op $a, $b, $c), + [(TileLinalgOp<[8], "REG"> $op)], + [(Constraint> $op)]>; + +#endif // LINALG_TRANSFORMS diff --git a/mlir/lib/Dialect/Linalg/CMakeLists.txt b/mlir/lib/Dialect/Linalg/CMakeLists.txt index d647c17..884e4d2 100644 --- a/mlir/lib/Dialect/Linalg/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/CMakeLists.txt @@ -4,6 +4,7 @@ add_llvm_library(MLIRLinalg IR/LinalgOps.cpp IR/LinalgTypes.cpp Transforms/Fusion.cpp + Transforms/LinalgTransforms.cpp Transforms/LowerToLLVMDialect.cpp Transforms/LowerToLoops.cpp Transforms/Promotion.cpp @@ -22,6 +23,7 @@ add_dependencies(MLIRLinalg MLIRAnalysis MLIRLinalgOpsIncGen MLIRLinalgLibraryOpsIncGen + MLIRLinalgTransformPatternsIncGen MLIRStandardOps MLIRStandardToLLVM ) diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp new file mode 100644 index 0000000..118018b --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp @@ -0,0 +1,67 @@ +//===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements logic for transforming Linalg operations. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; +using mlir::linalg::LinalgOp; + +// Marker used as attribute name in generated Linalg rewriting transformations. +static constexpr auto kLinalgTransformMarker = "__internal_linalg_transform__"; + +namespace mlir { +namespace linalg { +namespace { +#include "mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.h.inc" +} // end namespace +} // end namespace linalg +} // end namespace mlir + +namespace { +struct LinalgTransforms : public FunctionPass { + void runOnFunction() override; +}; +} // end anonymous namespace + +/// Apply transformations specified as patterns. +void LinalgTransforms::runOnFunction() { + OwningRewritePatternList patterns; + auto funcOp = getFunction(); + + // Add the generated patterns to the list. + linalg::populateWithGenerated(&getContext(), &patterns); + applyPatternsGreedily(funcOp, patterns); + + // Drop the marker. + funcOp.walk([](LinalgOp op) { op.removeAttr(kLinalgTransformMarker); }); +} + +std::unique_ptr> mlir::linalg::createLinalgTransformsPass() { + return std::make_unique(); +} + +static PassRegistration + pass("test-linalg-transform-patterns", + "Test Linalg transformation patterns by applying them greedily."); diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir new file mode 100644 index 0000000..b7c0924 --- /dev/null +++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir @@ -0,0 +1,72 @@ +// RUN: mlir-opt %s -test-linalg-transform-patterns | FileCheck %s + +// CHECK-DAG: #[[STRIDED_1D:.*]] = (d0)[s0] -> (d0 + s0) +// CHECK-DAG: #[[STRIDED_2D:.*]] = (d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1) + +func @dot(%x: memref, + %y: memref, + %v: memref) { + linalg.dot(%x, %y, %v) : memref, + memref, + memref + return +} +// CHECK-LABEL: func @dot +// CHECK-DAG : %[[c0:.*]] = constant 0 : index +// CHECK-DAG : %[[c8:.*]] = constant 8 : index +// CHECK-DAG : %[[c8000:.*]] = constant 8000 : index +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c8000]] { +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c8]] { +// CHECK : linalg.dot({{.*}}, {{.*}}, {{.*}}) : memref, memref, memref + +func @matvec(%A: memref, + %x: memref, + %y: memref) { + linalg.matvec(%A, %x, %y) : memref, + memref, + memref + return +} +// CHECK-LABEL: func @matvec +// CHECK-DAG : %[[c0:.*]] = constant 0 : index +// CHECK-DAG : %[[c5:.*]] = constant 5 : index +// CHECK-DAG : %[[c6:.*]] = constant 6 : index +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c5]] +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c6]] +// CHECK : linalg.matvec({{.*}}, {{.*}}, {{.*}}) : memref, memref, memref + +func @matmul(%A: memref, + %B: memref, + %C: memref) { + linalg.matmul(%A, %B, %C) : memref, + memref, + memref + return +} +// CHECK-LABEL: func @matmul +// CHECK-DAG : %[[c0:.*]] = constant 0 : index +// CHECK-DAG : %[[c2:.*]] = constant 2 : index +// CHECK-DAG : %[[c3:.*]] = constant 3 : index +// CHECK-DAG : %[[c4:.*]] = constant 4 : index +// CHECK-DAG : %[[c20:.*]] = constant 20 : index +// CHECK-DAG : %[[c30:.*]] = constant 30 : index +// CHECK-DAG : %[[c40:.*]] = constant 40 : index +// CHECK-DAG : %[[c200:.*]] = constant 200 : index +// CHECK-DAG : %[[c300:.*]] = constant 300 : index +// CHECK-DAG : %[[c400:.*]] = constant 400 : index +// CHECK-DAG : %[[c2000:.*]] = constant 2000 : index +// CHECK-DAG : %[[c3000:.*]] = constant 3000 : index +// CHECK-DAG : %[[c4000:.*]] = constant 4000 : index +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c2000]] { +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c3000]] { +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c4000]] { +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c200]] { +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c300]] { +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c400]] { +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c20]] { +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c30]] { +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c40]] { +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c2]] { +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c3]] { +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c4]] { +// CHECK : linalg.matmul({{.*}}, {{.*}}, {{.*}}) : memref, memref, memref -- 2.7.4