[Linalg] Add permutation information to tiling
authorJose Ignacio Gomez <jigomez@ucm.es>
Thu, 5 Dec 2019 23:14:22 +0000 (15:14 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 5 Dec 2019 23:14:59 +0000 (15:14 -0800)
This patch closes issue tensorflow/mlir#271.
It adds an optional permutation map to declarative tiling transformations.
The map is expressed as a list of integers.

Closes tensorflow/mlir#288

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/288 from tetuante:issue271 2df2938d6a1f01b3bc404ded08dea2dd1e10b588
PiperOrigin-RevId: 284064151

12 files changed:
mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td
mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h
mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
mlir/include/mlir/IR/AffineMap.h
mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/lib/IR/AffineMap.cpp
mlir/test/Dialect/Linalg/tile_permute_patterns.mlir [new file with mode: 0644]
mlir/test/lib/DeclarativeTransforms/CMakeLists.txt
mlir/test/lib/DeclarativeTransforms/TestLinalgTilePermutePatterns.td [new file with mode: 0644]
mlir/test/lib/Transforms/CMakeLists.txt
mlir/test/lib/Transforms/TestLinalgTilePermuteTransforms.cpp [new file with mode: 0644]

index 8bc0eaf..f558fa5 100644 (file)
@@ -57,9 +57,13 @@ class TileAndFuseLinalgOp<
 // In the future, tile sizes should be derived from op properties + machine
 // description but we do not need to wait on this to start having useful
 // patterns.
-class TileLinalgOp<list<int> sizes, string value> : NativeCodeCall<
+// `permutation` is an optional parameter to specify the ordering of the
+// tiled loops. If provided, it must be a list of integers with the same number
+// of elements as `sizes`.
+class TileLinalgOp<list<int> sizes, string value, list<int> permutation=[]> : NativeCodeCall<
   "if (failed(tileLinalgOpAndSetMarker($_builder, $0, {" #
-  StrJoinInt<sizes>.result # "}, \"" # value # "\")))" #
+  StrJoinInt<sizes>.result # "}, \"" # value # "\", {" #
+  StrJoinInt<permutation>.result # "})))" #
   "  return matchFailure();">;
 
 //===----------------------------------------------------------------------===//
index 966b8f9..89615e1 100644 (file)
@@ -58,11 +58,20 @@ bool isProducedByOpOfType(Operation *consumerOp, Value *consumedView) {
 // success.
 ////////////////////////////////////////////////////////////////////////////////
 
-// Tiles `op` by `sizes` and sets the attribute `kLinalgTransformMarker` to
-// `linalgMarker`.
+/// Tiles `op` by `sizes` permuting the looops according to `permutation`
+/// and sets the attribute `kLinalgTransformMarker` to `linalgMarker`.
+/// The permutation is expressed as a list of integers that specify
+/// the new ordering of the loop nest. The length of `permutation`
+/// must be equal to the length of `tileSizes`.
+/// E.g. the permutation `(i,j,k) -> (j,k,i)` will be expressed with
+/// `permutation = [1,2,0]`. All values in `permutation` must be
+/// integers, in the range 0..`tileSizes.size()` without duplications
+/// (i.e. `[1,1,2]` is an invalid permutation). An empty list
+/// states for the identity permutation.
 LogicalResult tileLinalgOpAndSetMarker(PatternRewriter &rewriter, Operation *op,
                                        ArrayRef<int64_t> sizes,
-                                       StringRef linalgMarker);
+                                       StringRef linalgMarker,
+                                       ArrayRef<unsigned> permutation);
 
 // Tiles `op` by `sizes`, fuses the producers of `operandIndicesToFuse` and sets
 // the attribute `kLinalgTransformMarker` to `linalgMarker`.
index 91c7082..8dc7845 100644 (file)
@@ -134,23 +134,43 @@ struct TiledLinalgOp {
 };
 
 /// Performs standalone tiling of a single LinalgOp by `tileSizes`.
-/// Returns a struct containing the tiled loops and the cloned op if successful,
-/// llvm::None otherwise.
+/// and permute the loop nest according to `permutation`
+/// The permutation is expressed as a list of integers that specify
+/// the new ordering of the loop nest. The length of `permutation`
+/// must be equal to the length of `tileSizes`.
+/// E.g. the permutation `(i,j,k) -> (j,k,i)` will be expressed with
+/// `permutation = [1,2,0]`. All values in `permutation` must be
+/// integers, in the range 0..`tileSizes.size()` without duplications
+/// (i.e. `[1,1,2]` is an invalid permutation). An empty list
+/// states for the identity permutation.
+/// Returns a struct containing the tiled loops in the specified order
+/// and the cloned op if successful, llvm::None otherwise.
 /// When non-null, the optional pointer `folder` is used to call into the
 /// `createAndFold` builder method. If `folder` is null, the regular `create`
 /// method is called.
 llvm::Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
                                            ArrayRef<Value *> tileSizes,
+                                           ArrayRef<unsigned> permutation = {},
                                            OperationFolder *folder = nullptr);
 
 /// Performs standalone tiling of a single LinalgOp by constant `tileSizes`.
-/// Returns a struct containing the tiled loops and the cloned op if successful,
-/// llvm::None otherwise.
+/// and permute the loop nest according to `permutation`
+/// The permutation is expressed as a list of integers that specify
+/// the new ordering of the loop nest. The length of `permutation`
+/// must be equal to the length of `tileSizes`.
+/// E.g. the permutation `(i,j,k) -> (j,k,i)` will be expressed with
+/// `permutation = [1,2,0]`. All values in `permutation` must be
+/// integers, in the range 0..`tileSizes.size()` without duplications
+/// (i.e. `[1,1,2]` is an invalid permutation). An empty list
+/// states for the identity permutation.
+/// Returns a struct containing the tiled loops in the specified order
+/// and the cloned op if successful, llvm::None otherwise.
 /// When non-null, the optional pointer `folder` is used to call into the
 /// `createAndFold` builder method. If `folder` is null, the regular `create`
 /// method is called.
 llvm::Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
                                            ArrayRef<int64_t> tileSizes,
+                                           ArrayRef<unsigned> permutation = {},
                                            OperationFolder *folder = nullptr);
 
 template <typename... Args>
index 9b30f15..e42173d 100644 (file)
@@ -65,6 +65,15 @@ public:
   static AffineMap getMultiDimIdentityMap(unsigned numDims,
                                           MLIRContext *context);
 
+  /// Returns an AffineMap representing a permutation.
+  /// The permutation is expressed as a non-empty vector of integers.
+  /// E.g. the permutation `(i,j,k) -> (j,k,i)` will be expressed with
+  /// `permutation = [1,2,0]`. All values in `permutation` must be
+  /// integers, in the range 0..`permutation.size()-1` without duplications
+  /// (i.e. `[1,1,2]` is an invalid permutation).
+  static AffineMap getPermutationMap(ArrayRef<unsigned> permutation,
+                                     MLIRContext *context);
+
   MLIRContext *getContext() const;
 
   explicit operator bool() { return map != nullptr; }
index 0e4aaa7..1b4509f 100644 (file)
@@ -33,11 +33,11 @@ using namespace mlir::linalg;
 const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
     "__internal_linalg_transform__";
 
-LogicalResult mlir::linalg::tileLinalgOpAndSetMarker(PatternRewriter &rewriter,
-                                                     Operation *op,
-                                                     ArrayRef<int64_t> sizes,
-                                                     StringRef linalgMarker) {
-  auto tileRes = tileLinalgOperation(rewriter, op, sizes);
+LogicalResult mlir::linalg::tileLinalgOpAndSetMarker(
+    PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes,
+    StringRef linalgMarker, ArrayRef<unsigned> permutation) {
+  assert(permutation.empty() || permutation.size() == sizes.size());
+  auto tileRes = tileLinalgOperation(rewriter, op, sizes, permutation);
   if (!tileRes)
     return failure();
   tileRes->op.setAttr(LinalgTransforms::kLinalgTransformMarker,
index 09a1ba6..2c84eee 100644 (file)
@@ -215,10 +215,17 @@ makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp,
   return res;
 }
 
-llvm::Optional<TiledLinalgOp>
-mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op,
-                           ArrayRef<Value *> tileSizes,
-                           OperationFolder *folder) {
+void applyPermutationToLoopRanges(SmallVector<SubViewOp::Range, 4> &loopRanges,
+                                  ArrayRef<unsigned> permutation) {
+  SmallVector<SubViewOp::Range, 4> auxVec(loopRanges.size());
+  for (unsigned i = 0; i < permutation.size(); ++i)
+    auxVec[i] = loopRanges[permutation[i]];
+  loopRanges = auxVec;
+}
+
+llvm::Optional<TiledLinalgOp> mlir::linalg::tileLinalgOp(
+    OpBuilder &b, LinalgOp op, ArrayRef<Value *> tileSizes,
+    ArrayRef<unsigned> permutation, OperationFolder *folder) {
   // 1. Enforce the convention that "tiling by zero" skips tiling a particular
   // dimension. This convention is significantly simpler to handle instead of
   // adjusting affine maps to account for missing dimensions.
@@ -226,6 +233,15 @@ mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op,
                  op.getNumWindowLoops() ==
              tileSizes.size() &&
          "expected matching number of tile sizes and loops");
+
+  // If permutation is empty, use the identity. Build the permutation map
+  // otherwise.
+  auto invPermutationMap = AffineMap::getMultiDimIdentityMap(
+      tileSizes.size(), ScopedContext::getContext());
+  if (!permutation.empty())
+    invPermutationMap = inversePermutation(
+        AffineMap::getPermutationMap(permutation, ScopedContext::getContext()));
+
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(op);
   ScopedContext scope(b, op.getLoc());
@@ -239,6 +255,8 @@ mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op,
   auto loopRanges =
       makeTiledLoopRanges(b, scope.getLocation(), viewSizesToLoopsMap,
                           viewSizes, tileSizes, folder);
+  if (!permutation.empty())
+    applyPermutationToLoopRanges(loopRanges, permutation);
 
   // 3. Create the tiled loops.
   LinalgOp res = op;
@@ -248,6 +266,15 @@ mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op,
     auto b = ScopedContext::getBuilder();
     auto loc = ScopedContext::getLocation();
     SmallVector<Value *, 4> ivValues(ivs.begin(), ivs.end());
+
+    // If we have to apply a permutation to the tiled loop nest, we have to
+    // reorder the induction variables This permutation is the right one
+    // assuming that loopRanges have previously been permuted by
+    // (i,j,k)->(k,i,j) So this permutation should be the inversePermutation of
+    // that one: (d0,d1,d2)->(d2,d0,d1)
+    if (!permutation.empty())
+      ivValues = applyMapToValues(b, loc, invPermutationMap, ivValues, folder);
+
     auto views =
         makeTiledViews(b, loc, op, ivValues, tileSizes, viewSizes, folder);
     auto operands = getAssumedNonViewOperands(op);
@@ -264,10 +291,9 @@ mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op,
   return TiledLinalgOp{res, loops};
 }
 
-llvm::Optional<TiledLinalgOp>
-mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op,
-                           ArrayRef<int64_t> tileSizes,
-                           OperationFolder *folder) {
+llvm::Optional<TiledLinalgOp> mlir::linalg::tileLinalgOp(
+    OpBuilder &b, LinalgOp op, ArrayRef<int64_t> tileSizes,
+    ArrayRef<unsigned> permutation, OperationFolder *folder) {
   if (tileSizes.empty())
     return llvm::None;
 
@@ -297,14 +323,15 @@ mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op,
       tileSizeValues.push_back(constant_index(folder, 0));
   }
 
-  return tileLinalgOp(b, op, tileSizeValues, folder);
+  return tileLinalgOp(b, op, tileSizeValues, permutation, folder);
 }
 
 static void tileLinalgOps(FuncOp f, ArrayRef<int64_t> tileSizes) {
   OpBuilder b(f);
   OperationFolder folder(f.getContext());
   f.walk([tileSizes, &b, &folder](LinalgOp op) {
-    auto opLoopsPair = tileLinalgOp(b, op, tileSizes, &folder);
+    auto opLoopsPair =
+        tileLinalgOp(b, op, tileSizes, /*permutation=*/{}, &folder);
     // If tiling occurred successfully, erase old op.
     if (opLoopsPair)
       op.erase();
index e56d0e8..98357b1 100644 (file)
@@ -106,6 +106,20 @@ AffineMap AffineMap::getConstantMap(int64_t val, MLIRContext *context) {
              {getAffineConstantExpr(val, context)});
 }
 
+/// Returns an AffineMap representing a permutation.
+AffineMap AffineMap::getPermutationMap(ArrayRef<unsigned> permutation,
+                                       MLIRContext *context) {
+  assert(!permutation.empty() &&
+         "Cannot create permutation map from empty permutation vector");
+  SmallVector<AffineExpr, 4> affExprs;
+  for (auto index : permutation)
+    affExprs.push_back(getAffineDimExpr(index, context));
+  auto m = std::max_element(permutation.begin(), permutation.end());
+  auto permutationMap = AffineMap::get(*m + 1, 0, affExprs);
+  assert(permutationMap.isPermutation() && "Invalid permutation vector");
+  return permutationMap;
+}
+
 AffineMap AffineMap::getMultiDimIdentityMap(unsigned numDims,
                                             MLIRContext *context) {
   SmallVector<AffineExpr, 4> dimExprs;
diff --git a/mlir/test/Dialect/Linalg/tile_permute_patterns.mlir b/mlir/test/Dialect/Linalg/tile_permute_patterns.mlir
new file mode 100644 (file)
index 0000000..4844f20
--- /dev/null
@@ -0,0 +1,70 @@
+// RUN: mlir-opt %s -test-linalg-tile-and-permute-patterns | FileCheck %s
+
+// CHECK-DAG: #[[STRIDED_1D:.*]] = (d0)[s0] -> (d0 + s0)
+// CHECK-DAG: #[[STRIDED_2D:.*]] = (d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)
+
+func @dot(%x: memref<?xf32, offset: ?, strides: [1]>,
+          %y: memref<?xf32, offset: ?, strides: [1]>,
+          %v: memref<f32>) {
+  linalg.dot(%x, %y, %v) : memref<?xf32, offset: ?, strides: [1]>,
+                           memref<?xf32, offset: ?, strides: [1]>,
+                           memref<f32>
+  return
+}
+// CHECK-LABEL: func @dot
+// CHECK-DAG  :   %[[c0:.*]] = constant 0 : index
+// CHECK-DAG  :   %[[c8:.*]] = constant 8 : index
+// CHECK-DAG  :   %[[c8000:.*]] = constant 8000 : index
+// CHECK      :   loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c8000]] {
+// CHECK      :     loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c8]] {
+// CHECK      :       linalg.dot({{.*}}, {{.*}}, {{.*}}) : memref<?xf32, #[[STRIDED_1D]]>, memref<?xf32, #[[STRIDED_1D]]>, memref<f32>
+
+func @matvec(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
+             %x: memref<?xf32, offset: ?, strides: [1]>,
+             %y: memref<?xf32, offset: ?, strides: [1]>) {
+  linalg.matvec(%A, %x, %y) : memref<?x?xf32, offset: ?, strides: [?, 1]>,
+                              memref<?xf32, offset: ?, strides: [1]>,
+                              memref<?xf32, offset: ?, strides: [1]>
+  return
+}
+// CHECK-LABEL: func @matvec
+// CHECK-DAG  :   %[[c0:.*]] = constant 0 : index
+// CHECK-DAG  :   %[[c5:.*]] = constant 5 : index
+// CHECK-DAG  :   %[[c6:.*]] = constant 6 : index
+// CHECK      :   loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c6]]
+// CHECK      :     loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c5]]
+// CHECK      :       linalg.matvec({{.*}}, {{.*}}, {{.*}}) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?xf32, #[[STRIDED_1D]]>, memref<?xf32, #[[STRIDED_1D]]>
+
+func @matmul(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
+             %B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
+             %C: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
+  linalg.matmul(%A, %B, %C) : memref<?x?xf32, offset: ?, strides: [?, 1]>,
+                              memref<?x?xf32, offset: ?, strides: [?, 1]>,
+                              memref<?x?xf32, offset: ?, strides: [?, 1]>
+  return
+}
+// CHECK-LABEL: func @matmul
+// CHECK-DAG  :   %[[c0:.*]] = constant 0 : index
+// CHECK-DAG  :   %[[c2:.*]] = constant 2 : index
+// CHECK-DAG  :   %[[c3:.*]] = constant 3 : index
+// CHECK-DAG  :   %[[c4:.*]] = constant 4 : index
+// CHECK-DAG  :   %[[c20:.*]] = constant 20 : index
+// CHECK-DAG  :   %[[c30:.*]] = constant 30 : index
+// CHECK-DAG  :   %[[c40:.*]] = constant 40 : index
+// CHECK-DAG  :   %[[c200:.*]] = constant 200 : index
+// CHECK-DAG  :   %[[c300:.*]] = constant 300 : index
+// CHECK-DAG  :   %[[c400:.*]] = constant 400 : index
+// CHECK-DAG  :   %[[c2000:.*]] = constant 2000 : index
+// CHECK-DAG  :   %[[c3000:.*]] = constant 3000 : index
+// CHECK-DAG  :   %[[c4000:.*]] = constant 4000 : index
+// CHECK      :   loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c3000]] {
+// CHECK      :     loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c4000]] {
+// CHECK      :       loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c2000]] {
+// CHECK      :         loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c300]] {
+// CHECK      :           loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c200]] {
+// CHECK      :             loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c400]] {
+// CHECK      :               loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c20]] {
+// CHECK      :                 loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c30]] {
+// CHECK      :                   loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c40]] {
+// CHECK      :                           linalg.matmul({{.*}}, {{.*}}, {{.*}}) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
+
index 06e81a0..1ee62d8 100644 (file)
@@ -1,3 +1,7 @@
 set(LLVM_TARGET_DEFINITIONS TestLinalgTransformPatterns.td)
 mlir_tablegen(TestLinalgTransformPatterns.h.inc -gen-rewriters)
 add_public_tablegen_target(MLIRTestLinalgTransformPatternsIncGen)
+
+set(LLVM_TARGET_DEFINITIONS TestLinalgTilePermutePatterns.td)
+mlir_tablegen(TestLinalgTilePermutePatterns.h.inc -gen-rewriters)
+add_public_tablegen_target(MLIRTestLinalgTilePermutePatternsIncGen)
diff --git a/mlir/test/lib/DeclarativeTransforms/TestLinalgTilePermutePatterns.td b/mlir/test/lib/DeclarativeTransforms/TestLinalgTilePermutePatterns.td
new file mode 100644 (file)
index 0000000..6d7bfff
--- /dev/null
@@ -0,0 +1,57 @@
+//===- TestLinalgTilePermutePatterns.td - Test patterns --*- tablegen ----*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This is the pattern definition file for declarative Linalg transformations
+// tests.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TEST_LINALG_TILEPERMUTE_PATTERNS
+#define TEST_LINALG_TILEPERMUTE_PATTERNS
+
+include "mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td"
+
+//===----------------------------------------------------------------------===//
+// Linalg tiling and permutation patterns.
+//===----------------------------------------------------------------------===//
+def : Pat<(MatmulOp:$op $A, $B, $C),
+          (TileLinalgOp<[2000, 3000, 4000], "L2", [1,2,0]> $op),
+          [(Constraint<Or<[HasNoLinalgTransformMarker,
+                           HasLinalgTransformMarker<"MEM">]>> $op)]>;
+def : Pat<(MatmulOp:$op $A, $B, $C),
+          (TileLinalgOp<[200, 300, 400], "L1", [1,0,2]> $op),
+          [(Constraint<HasLinalgTransformMarker<"L2">> $op)]>;
+def : Pat<(MatmulOp:$op $A, $B, $C),
+          (TileLinalgOp<[20, 30, 40], "REG"> $op),
+          [(Constraint<HasLinalgTransformMarker<"L1">> $op)]>;
+
+
+def : Pattern<(MatvecOp:$op $A, $b, $c),
+              [(TileLinalgOp<[5, 6], "L1", [1,0]> $op)],
+              [(Constraint<HasNoLinalgTransformMarker> $op)]>;
+
+def : Pattern<(DotOp:$op $a, $b, $c),
+              [(TileLinalgOp<[8000], "L1"> $op)],
+              [(Constraint<Or<[HasNoLinalgTransformMarker,
+                               HasLinalgTransformMarker<"MEM">,
+                               HasLinalgTransformMarker<"L3">,
+                               HasLinalgTransformMarker<"L2">]>> $op)]>;
+def : Pattern<(DotOp:$op $a, $b, $c),
+              [(TileLinalgOp<[8], "REG"> $op)],
+              [(Constraint<HasLinalgTransformMarker<"L1">> $op)]>;
+
+#endif // TEST_LINALG_TILEPERMUTE_PATTERNS
index 8bc9c73..8a79334 100644 (file)
@@ -4,6 +4,7 @@ add_llvm_library(MLIRTestTransforms
   TestLoopFusion.cpp
   TestInlining.cpp
   TestLinalgTransforms.cpp
+  TestLinalgTilePermuteTransforms.cpp
   TestLoopMapping.cpp
   TestLoopParametricTiling.cpp
   TestOpaqueLoc.cpp
@@ -21,6 +22,7 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../DeclarativeTransforms)
 include_directories(${CMAKE_CURRENT_BINARY_DIR}/../DeclarativeTransforms)
 add_dependencies(MLIRTestTransforms MLIRStandardOpsIncGen)
 add_dependencies(MLIRTestTransforms MLIRTestLinalgTransformPatternsIncGen)
+add_dependencies(MLIRTestTransforms MLIRTestLinalgTilePermutePatternsIncGen)
 target_link_libraries(MLIRTestTransforms
   MLIRAffineOps
   MLIRAnalysis
diff --git a/mlir/test/lib/Transforms/TestLinalgTilePermuteTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTilePermuteTransforms.cpp
new file mode 100644 (file)
index 0000000..ec7fa4e
--- /dev/null
@@ -0,0 +1,64 @@
+//===- TestLinalgTilePermuteTransforms.cpp - Test Linalg tile + permute ---===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements logic for testing Linalg transformations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Transforms/LinalgTransforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace mlir {
+namespace linalg {
+namespace {
+#include "TestLinalgTilePermutePatterns.h.inc"
+} // end namespace
+} // end namespace linalg
+} // end namespace mlir
+
+namespace {
+struct TestLinalgTilePermuteTransforms
+    : public FunctionPass<TestLinalgTilePermuteTransforms> {
+  void runOnFunction() override;
+};
+} // end anonymous namespace
+
+/// Apply transformations specified as patterns.
+void TestLinalgTilePermuteTransforms::runOnFunction() {
+  OwningRewritePatternList patterns;
+  auto funcOp = getFunction();
+
+  // Add the generated patterns to the list.
+  linalg::populateWithGenerated(&getContext(), &patterns);
+  applyPatternsGreedily(funcOp, patterns);
+
+  // Drop the marker.
+  funcOp.walk([](LinalgOp op) {
+    op.removeAttr(LinalgTransforms::kLinalgTransformMarker);
+  });
+}
+
+static PassRegistration<TestLinalgTilePermuteTransforms>
+    pass("test-linalg-tile-and-permute-patterns",
+         "Test Linalg transformation with permutation patterns by applying "
+         "them greedily.");