Add basic support for declarative Linalg transformations
authorNicolas Vasilache <ntv@google.com>
Wed, 30 Oct 2019 14:12:07 +0000 (07:12 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 30 Oct 2019 14:12:33 +0000 (07:12 -0700)
Linalg ops provide a good anchor for pattern matching/rewriting transformations.
This CL adds a simple example of how multi-level tiling may be specified by attaching a simple StringAttr to ops as they are transformed so we can easily specify partial lowering to control transformation application.

This is a first stab at taking advantage of higher-level information contained in Linalg ops and will evolve in the future.

PiperOrigin-RevId: 277497958

mlir/include/mlir/Dialect/Linalg/CMakeLists.txt
mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
mlir/include/mlir/Dialect/Linalg/Passes.h
mlir/include/mlir/Dialect/Linalg/Transforms/CMakeLists.txt [new file with mode: 0644]
mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td [new file with mode: 0644]
mlir/lib/Dialect/Linalg/CMakeLists.txt
mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp [new file with mode: 0644]
mlir/test/Dialect/Linalg/transform-patterns.mlir [new file with mode: 0644]

index 1e6384c..a4a76cc 100644 (file)
 #else
 #define LINALG_LIBRARY_OPS
 
+#ifdef AFFINE_OPS_BASE
+#else
 include "mlir/Dialect/AffineOps/AffineOpsBase.td"
+#endif // AFFINE_OPS_BASE
+
+#ifdef LINALG_BASE
+#else
 include "mlir/Dialect/Linalg/IR/LinalgBase.td"
+#endif // LINALG_BASE
 
 class LinalgParametricNativeOpTrait<string prop, string parameters> :
   NativeOpTrait<"linalg::" # prop # parameters>
index 38e0cb6..b865b5f 100644 (file)
 #else
 #define LINALG_OPS
 
+#ifdef AFFINE_OPS_BASE
+#else
 include "mlir/Dialect/AffineOps/AffineOpsBase.td"
+#endif // AFFINE_OPS_BASE
+
+#ifdef LINALG_BASE
+#else
 include "mlir/Dialect/Linalg/IR/LinalgBase.td"
+#endif // LINALG_BASE
 
 // Base class for Linalg dialect ops that do not correspond to library calls.
 class Linalg_Op<string mnemonic, list<OpTrait> traits = []> :
index 1efb7d3..f4eb797 100644 (file)
@@ -41,6 +41,8 @@ std::unique_ptr<OpPassBase<FuncOp>> createLinalgPromotionPass();
 std::unique_ptr<OpPassBase<FuncOp>> createLowerLinalgToLoopsPass();
 
 std::unique_ptr<OpPassBase<ModuleOp>> createLowerLinalgToLLVMPass();
+
+std::unique_ptr<OpPassBase<FuncOp>> createLinalgTransformsPass();
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Linalg/Transforms/CMakeLists.txt
new file mode 100644 (file)
index 0000000..f87938c
--- /dev/null
@@ -0,0 +1,3 @@
+set(LLVM_TARGET_DEFINITIONS LinalgTransformPatterns.td)
+mlir_tablegen(LinalgTransformPatterns.h.inc -gen-rewriters)
+add_public_tablegen_target(MLIRLinalgTransformPatternsIncGen)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td
new file mode 100644 (file)
index 0000000..fdb4677
--- /dev/null
@@ -0,0 +1,76 @@
+//===- LinalgPatterns.td - Linalg transformation patterns --*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This is the pattern definition file for declarative Linalg transformation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifdef LINALG_TRANSFORMS
+#else
+#define LINALG_TRANSFORMS
+
+include "mlir/Dialect/Linalg/IR/LinalgOps.td"
+include "mlir/Dialect/Linalg/IR/LinalgLibraryOps.td"
+
+def HasNoLinalgTransformMarker : CPred<[{
+  !$0.getAttrOfType<StringAttr>(kLinalgTransformMarker)
+}]>;
+
+class HasLinalgTransformMarker<string value> : CPred<[{
+  $0.getAttrOfType<StringAttr>(kLinalgTransformMarker).getValue() == "}] #
+  value # [{"}]>;
+
+//===----------------------------------------------------------------------===//
+// Linalg transformation patterns.
+//===----------------------------------------------------------------------===//
+class TileLinalgOp<list<int> sizes, string value> : NativeCodeCall<
+  "auto res = tileLinalgOperation($_builder, $0, ArrayRef<int64_t>{" #
+    StrJoinInt<sizes>.result # "});" # [{
+    if (!res)
+      return matchFailure();
+    res->op.setAttr(kLinalgTransformMarker, StringAttr::get("}] # value #
+    [{", $0.getContext()));}]>;
+
+def : Pat<(MatmulOp:$op $A, $B, $C),
+          (TileLinalgOp<[2000, 3000, 4000], "L3"> $op),
+          [(Constraint<Or<[HasNoLinalgTransformMarker,
+                           HasLinalgTransformMarker<"MEM">]>> $op)]>;
+def : Pat<(MatmulOp:$op $A, $B, $C),
+          (TileLinalgOp<[200, 300, 400], "L2"> $op),
+          [(Constraint<HasLinalgTransformMarker<"L3">> $op)]>;
+def : Pat<(MatmulOp:$op $A, $B, $C),
+          (TileLinalgOp<[20, 30, 40], "L1"> $op),
+          [(Constraint<HasLinalgTransformMarker<"L2">> $op)]>;
+def : Pat<(MatmulOp:$op $A, $B, $C),
+          (TileLinalgOp<[2, 3, 4], "REG"> $op),
+          [(Constraint<HasLinalgTransformMarker<"L1">> $op)]>;
+
+def : Pattern<(MatvecOp:$op $A, $b, $c),
+              [(TileLinalgOp<[5, 6], "L1"> $op)],
+              [(Constraint<HasNoLinalgTransformMarker> $op)]>;
+
+def : Pattern<(DotOp:$op $a, $b, $c),
+              [(TileLinalgOp<[8000], "L1"> $op)],
+              [(Constraint<Or<[HasNoLinalgTransformMarker,
+                               HasLinalgTransformMarker<"MEM">,
+                               HasLinalgTransformMarker<"L3">,
+                               HasLinalgTransformMarker<"L2">]>> $op)]>;
+def : Pattern<(DotOp:$op $a, $b, $c),
+              [(TileLinalgOp<[8], "REG"> $op)],
+              [(Constraint<HasLinalgTransformMarker<"L1">> $op)]>;
+
+#endif // LINALG_TRANSFORMS
index d647c17..884e4d2 100644 (file)
@@ -4,6 +4,7 @@ add_llvm_library(MLIRLinalg
   IR/LinalgOps.cpp
   IR/LinalgTypes.cpp
   Transforms/Fusion.cpp
+  Transforms/LinalgTransforms.cpp
   Transforms/LowerToLLVMDialect.cpp
   Transforms/LowerToLoops.cpp
   Transforms/Promotion.cpp
@@ -22,6 +23,7 @@ add_dependencies(MLIRLinalg
   MLIRAnalysis
   MLIRLinalgOpsIncGen
   MLIRLinalgLibraryOpsIncGen
+  MLIRLinalgTransformPatternsIncGen
   MLIRStandardOps
   MLIRStandardToLLVM
   )
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
new file mode 100644 (file)
index 0000000..118018b
--- /dev/null
@@ -0,0 +1,67 @@
+//===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements logic for transforming Linalg operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+using mlir::linalg::LinalgOp;
+
+// Marker used as attribute name in generated Linalg rewriting transformations.
+static constexpr auto kLinalgTransformMarker = "__internal_linalg_transform__";
+
+namespace mlir {
+namespace linalg {
+namespace {
+#include "mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.h.inc"
+} // end namespace
+} // end namespace linalg
+} // end namespace mlir
+
+namespace {
+struct LinalgTransforms : public FunctionPass<LinalgTransforms> {
+  void runOnFunction() override;
+};
+} // end anonymous namespace
+
+/// Apply transformations specified as patterns.
+void LinalgTransforms::runOnFunction() {
+  OwningRewritePatternList patterns;
+  auto funcOp = getFunction();
+
+  // Add the generated patterns to the list.
+  linalg::populateWithGenerated(&getContext(), &patterns);
+  applyPatternsGreedily(funcOp, patterns);
+
+  // Drop the marker.
+  funcOp.walk([](LinalgOp op) { op.removeAttr(kLinalgTransformMarker); });
+}
+
+std::unique_ptr<OpPassBase<FuncOp>> mlir::linalg::createLinalgTransformsPass() {
+  return std::make_unique<LinalgTransforms>();
+}
+
+static PassRegistration<LinalgTransforms>
+    pass("test-linalg-transform-patterns",
+         "Test Linalg transformation patterns by applying them greedily.");
diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
new file mode 100644 (file)
index 0000000..b7c0924
--- /dev/null
@@ -0,0 +1,72 @@
+// RUN: mlir-opt %s -test-linalg-transform-patterns | FileCheck %s
+
+// CHECK-DAG: #[[STRIDED_1D:.*]] = (d0)[s0] -> (d0 + s0)
+// CHECK-DAG: #[[STRIDED_2D:.*]] = (d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)
+
+func @dot(%x: memref<?xf32, offset: ?, strides: [1]>,
+          %y: memref<?xf32, offset: ?, strides: [1]>,
+          %v: memref<f32>) {
+  linalg.dot(%x, %y, %v) : memref<?xf32, offset: ?, strides: [1]>,
+                           memref<?xf32, offset: ?, strides: [1]>,
+                           memref<f32>
+  return
+}
+// CHECK-LABEL: func @dot
+// CHECK-DAG  :   %[[c0:.*]] = constant 0 : index
+// CHECK-DAG  :   %[[c8:.*]] = constant 8 : index
+// CHECK-DAG  :   %[[c8000:.*]] = constant 8000 : index
+// CHECK      :   loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c8000]] {
+// CHECK      :     loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c8]] {
+// CHECK      :       linalg.dot({{.*}}, {{.*}}, {{.*}}) : memref<?xf32, #[[STRIDED_1D]]>, memref<?xf32, #[[STRIDED_1D]]>, memref<f32>
+
+func @matvec(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
+             %x: memref<?xf32, offset: ?, strides: [1]>,
+             %y: memref<?xf32, offset: ?, strides: [1]>) {
+  linalg.matvec(%A, %x, %y) : memref<?x?xf32, offset: ?, strides: [?, 1]>,
+                              memref<?xf32, offset: ?, strides: [1]>,
+                              memref<?xf32, offset: ?, strides: [1]>
+  return
+}
+// CHECK-LABEL: func @matvec
+// CHECK-DAG  :   %[[c0:.*]] = constant 0 : index
+// CHECK-DAG  :   %[[c5:.*]] = constant 5 : index
+// CHECK-DAG  :   %[[c6:.*]] = constant 6 : index
+// CHECK      :   loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c5]]
+// CHECK      :     loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c6]]
+// CHECK      :       linalg.matvec({{.*}}, {{.*}}, {{.*}}) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?xf32, #[[STRIDED_1D]]>, memref<?xf32, #[[STRIDED_1D]]>
+
+func @matmul(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
+             %B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
+             %C: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
+  linalg.matmul(%A, %B, %C) : memref<?x?xf32, offset: ?, strides: [?, 1]>,
+                              memref<?x?xf32, offset: ?, strides: [?, 1]>,
+                              memref<?x?xf32, offset: ?, strides: [?, 1]>
+  return
+}
+// CHECK-LABEL: func @matmul
+// CHECK-DAG  :   %[[c0:.*]] = constant 0 : index
+// CHECK-DAG  :   %[[c2:.*]] = constant 2 : index
+// CHECK-DAG  :   %[[c3:.*]] = constant 3 : index
+// CHECK-DAG  :   %[[c4:.*]] = constant 4 : index
+// CHECK-DAG  :   %[[c20:.*]] = constant 20 : index
+// CHECK-DAG  :   %[[c30:.*]] = constant 30 : index
+// CHECK-DAG  :   %[[c40:.*]] = constant 40 : index
+// CHECK-DAG  :   %[[c200:.*]] = constant 200 : index
+// CHECK-DAG  :   %[[c300:.*]] = constant 300 : index
+// CHECK-DAG  :   %[[c400:.*]] = constant 400 : index
+// CHECK-DAG  :   %[[c2000:.*]] = constant 2000 : index
+// CHECK-DAG  :   %[[c3000:.*]] = constant 3000 : index
+// CHECK-DAG  :   %[[c4000:.*]] = constant 4000 : index
+// CHECK      :   loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c2000]] {
+// CHECK      :     loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c3000]] {
+// CHECK      :       loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c4000]] {
+// CHECK      :         loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c200]] {
+// CHECK      :           loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c300]] {
+// CHECK      :             loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c400]] {
+// CHECK      :               loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c20]] {
+// CHECK      :                 loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c30]] {
+// CHECK      :                   loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c40]] {
+// CHECK      :                     loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c2]] {
+// CHECK      :                       loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c3]] {
+// CHECK      :                         loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c4]] {
+// CHECK      :                           linalg.matmul({{.*}}, {{.*}}, {{.*}}) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>