// In the future, tile sizes should be derived from op properties + machine
// description but we do not need to wait on this to start having useful
// patterns.
-class TileLinalgOp<list<int> sizes, string value> : NativeCodeCall<
+// `permutation` is an optional parameter to specify the ordering of the
+// tiled loops. If provided, it must be a list of integers with the same number
+// of elements as `sizes`.
+class TileLinalgOp<list<int> sizes, string value, list<int> permutation=[]> : NativeCodeCall<
"if (failed(tileLinalgOpAndSetMarker($_builder, $0, {" #
- StrJoinInt<sizes>.result # "}, \"" # value # "\")))" #
+ StrJoinInt<sizes>.result # "}, \"" # value # "\", {" #
+ StrJoinInt<permutation>.result # "})))" #
" return matchFailure();">;
//===----------------------------------------------------------------------===//
// success.
////////////////////////////////////////////////////////////////////////////////
-// Tiles `op` by `sizes` and sets the attribute `kLinalgTransformMarker` to
-// `linalgMarker`.
+/// Tiles `op` by `sizes` permuting the looops according to `permutation`
+/// and sets the attribute `kLinalgTransformMarker` to `linalgMarker`.
+/// The permutation is expressed as a list of integers that specify
+/// the new ordering of the loop nest. The length of `permutation`
+/// must be equal to the length of `tileSizes`.
+/// E.g. the permutation `(i,j,k) -> (j,k,i)` will be expressed with
+/// `permutation = [1,2,0]`. All values in `permutation` must be
+/// integers, in the range 0..`tileSizes.size()` without duplications
+/// (i.e. `[1,1,2]` is an invalid permutation). An empty list
+/// states for the identity permutation.
LogicalResult tileLinalgOpAndSetMarker(PatternRewriter &rewriter, Operation *op,
ArrayRef<int64_t> sizes,
- StringRef linalgMarker);
+ StringRef linalgMarker,
+ ArrayRef<unsigned> permutation);
// Tiles `op` by `sizes`, fuses the producers of `operandIndicesToFuse` and sets
// the attribute `kLinalgTransformMarker` to `linalgMarker`.
};
/// Performs standalone tiling of a single LinalgOp by `tileSizes`.
-/// Returns a struct containing the tiled loops and the cloned op if successful,
-/// llvm::None otherwise.
+/// and permute the loop nest according to `permutation`
+/// The permutation is expressed as a list of integers that specify
+/// the new ordering of the loop nest. The length of `permutation`
+/// must be equal to the length of `tileSizes`.
+/// E.g. the permutation `(i,j,k) -> (j,k,i)` will be expressed with
+/// `permutation = [1,2,0]`. All values in `permutation` must be
+/// integers, in the range 0..`tileSizes.size()` without duplications
+/// (i.e. `[1,1,2]` is an invalid permutation). An empty list
+/// states for the identity permutation.
+/// Returns a struct containing the tiled loops in the specified order
+/// and the cloned op if successful, llvm::None otherwise.
/// When non-null, the optional pointer `folder` is used to call into the
/// `createAndFold` builder method. If `folder` is null, the regular `create`
/// method is called.
llvm::Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
ArrayRef<Value *> tileSizes,
+ ArrayRef<unsigned> permutation = {},
OperationFolder *folder = nullptr);
/// Performs standalone tiling of a single LinalgOp by constant `tileSizes`.
-/// Returns a struct containing the tiled loops and the cloned op if successful,
-/// llvm::None otherwise.
+/// and permute the loop nest according to `permutation`
+/// The permutation is expressed as a list of integers that specify
+/// the new ordering of the loop nest. The length of `permutation`
+/// must be equal to the length of `tileSizes`.
+/// E.g. the permutation `(i,j,k) -> (j,k,i)` will be expressed with
+/// `permutation = [1,2,0]`. All values in `permutation` must be
+/// integers, in the range 0..`tileSizes.size()` without duplications
+/// (i.e. `[1,1,2]` is an invalid permutation). An empty list
+/// states for the identity permutation.
+/// Returns a struct containing the tiled loops in the specified order
+/// and the cloned op if successful, llvm::None otherwise.
/// When non-null, the optional pointer `folder` is used to call into the
/// `createAndFold` builder method. If `folder` is null, the regular `create`
/// method is called.
llvm::Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
ArrayRef<int64_t> tileSizes,
+ ArrayRef<unsigned> permutation = {},
OperationFolder *folder = nullptr);
template <typename... Args>
static AffineMap getMultiDimIdentityMap(unsigned numDims,
MLIRContext *context);
+ /// Returns an AffineMap representing a permutation.
+ /// The permutation is expressed as a non-empty vector of integers.
+ /// E.g. the permutation `(i,j,k) -> (j,k,i)` will be expressed with
+ /// `permutation = [1,2,0]`. All values in `permutation` must be
+ /// integers, in the range 0..`permutation.size()-1` without duplications
+ /// (i.e. `[1,1,2]` is an invalid permutation).
+ static AffineMap getPermutationMap(ArrayRef<unsigned> permutation,
+ MLIRContext *context);
+
MLIRContext *getContext() const;
explicit operator bool() { return map != nullptr; }
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
"__internal_linalg_transform__";
-LogicalResult mlir::linalg::tileLinalgOpAndSetMarker(PatternRewriter &rewriter,
- Operation *op,
- ArrayRef<int64_t> sizes,
- StringRef linalgMarker) {
- auto tileRes = tileLinalgOperation(rewriter, op, sizes);
+LogicalResult mlir::linalg::tileLinalgOpAndSetMarker(
+ PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes,
+ StringRef linalgMarker, ArrayRef<unsigned> permutation) {
+ assert(permutation.empty() || permutation.size() == sizes.size());
+ auto tileRes = tileLinalgOperation(rewriter, op, sizes, permutation);
if (!tileRes)
return failure();
tileRes->op.setAttr(LinalgTransforms::kLinalgTransformMarker,
return res;
}
-llvm::Optional<TiledLinalgOp>
-mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op,
- ArrayRef<Value *> tileSizes,
- OperationFolder *folder) {
+void applyPermutationToLoopRanges(SmallVector<SubViewOp::Range, 4> &loopRanges,
+ ArrayRef<unsigned> permutation) {
+ SmallVector<SubViewOp::Range, 4> auxVec(loopRanges.size());
+ for (unsigned i = 0; i < permutation.size(); ++i)
+ auxVec[i] = loopRanges[permutation[i]];
+ loopRanges = auxVec;
+}
+
+llvm::Optional<TiledLinalgOp> mlir::linalg::tileLinalgOp(
+ OpBuilder &b, LinalgOp op, ArrayRef<Value *> tileSizes,
+ ArrayRef<unsigned> permutation, OperationFolder *folder) {
// 1. Enforce the convention that "tiling by zero" skips tiling a particular
// dimension. This convention is significantly simpler to handle instead of
// adjusting affine maps to account for missing dimensions.
op.getNumWindowLoops() ==
tileSizes.size() &&
"expected matching number of tile sizes and loops");
+
+ // If permutation is empty, use the identity. Build the permutation map
+ // otherwise.
+ auto invPermutationMap = AffineMap::getMultiDimIdentityMap(
+ tileSizes.size(), ScopedContext::getContext());
+ if (!permutation.empty())
+ invPermutationMap = inversePermutation(
+ AffineMap::getPermutationMap(permutation, ScopedContext::getContext()));
+
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(op);
ScopedContext scope(b, op.getLoc());
auto loopRanges =
makeTiledLoopRanges(b, scope.getLocation(), viewSizesToLoopsMap,
viewSizes, tileSizes, folder);
+ if (!permutation.empty())
+ applyPermutationToLoopRanges(loopRanges, permutation);
// 3. Create the tiled loops.
LinalgOp res = op;
auto b = ScopedContext::getBuilder();
auto loc = ScopedContext::getLocation();
SmallVector<Value *, 4> ivValues(ivs.begin(), ivs.end());
+
+ // If we have to apply a permutation to the tiled loop nest, we have to
+ // reorder the induction variables This permutation is the right one
+ // assuming that loopRanges have previously been permuted by
+ // (i,j,k)->(k,i,j) So this permutation should be the inversePermutation of
+ // that one: (d0,d1,d2)->(d2,d0,d1)
+ if (!permutation.empty())
+ ivValues = applyMapToValues(b, loc, invPermutationMap, ivValues, folder);
+
auto views =
makeTiledViews(b, loc, op, ivValues, tileSizes, viewSizes, folder);
auto operands = getAssumedNonViewOperands(op);
return TiledLinalgOp{res, loops};
}
-llvm::Optional<TiledLinalgOp>
-mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op,
- ArrayRef<int64_t> tileSizes,
- OperationFolder *folder) {
+llvm::Optional<TiledLinalgOp> mlir::linalg::tileLinalgOp(
+ OpBuilder &b, LinalgOp op, ArrayRef<int64_t> tileSizes,
+ ArrayRef<unsigned> permutation, OperationFolder *folder) {
if (tileSizes.empty())
return llvm::None;
tileSizeValues.push_back(constant_index(folder, 0));
}
- return tileLinalgOp(b, op, tileSizeValues, folder);
+ return tileLinalgOp(b, op, tileSizeValues, permutation, folder);
}
static void tileLinalgOps(FuncOp f, ArrayRef<int64_t> tileSizes) {
OpBuilder b(f);
OperationFolder folder(f.getContext());
f.walk([tileSizes, &b, &folder](LinalgOp op) {
- auto opLoopsPair = tileLinalgOp(b, op, tileSizes, &folder);
+ auto opLoopsPair =
+ tileLinalgOp(b, op, tileSizes, /*permutation=*/{}, &folder);
// If tiling occurred successfully, erase old op.
if (opLoopsPair)
op.erase();
{getAffineConstantExpr(val, context)});
}
+/// Returns an AffineMap representing a permutation.
+AffineMap AffineMap::getPermutationMap(ArrayRef<unsigned> permutation,
+ MLIRContext *context) {
+ assert(!permutation.empty() &&
+ "Cannot create permutation map from empty permutation vector");
+ SmallVector<AffineExpr, 4> affExprs;
+ for (auto index : permutation)
+ affExprs.push_back(getAffineDimExpr(index, context));
+ auto m = std::max_element(permutation.begin(), permutation.end());
+ auto permutationMap = AffineMap::get(*m + 1, 0, affExprs);
+ assert(permutationMap.isPermutation() && "Invalid permutation vector");
+ return permutationMap;
+}
+
AffineMap AffineMap::getMultiDimIdentityMap(unsigned numDims,
MLIRContext *context) {
SmallVector<AffineExpr, 4> dimExprs;
--- /dev/null
+// RUN: mlir-opt %s -test-linalg-tile-and-permute-patterns | FileCheck %s
+
+// CHECK-DAG: #[[STRIDED_1D:.*]] = (d0)[s0] -> (d0 + s0)
+// CHECK-DAG: #[[STRIDED_2D:.*]] = (d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)
+
+func @dot(%x: memref<?xf32, offset: ?, strides: [1]>,
+ %y: memref<?xf32, offset: ?, strides: [1]>,
+ %v: memref<f32>) {
+ linalg.dot(%x, %y, %v) : memref<?xf32, offset: ?, strides: [1]>,
+ memref<?xf32, offset: ?, strides: [1]>,
+ memref<f32>
+ return
+}
+// CHECK-LABEL: func @dot
+// CHECK-DAG : %[[c0:.*]] = constant 0 : index
+// CHECK-DAG : %[[c8:.*]] = constant 8 : index
+// CHECK-DAG : %[[c8000:.*]] = constant 8000 : index
+// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c8000]] {
+// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c8]] {
+// CHECK : linalg.dot({{.*}}, {{.*}}, {{.*}}) : memref<?xf32, #[[STRIDED_1D]]>, memref<?xf32, #[[STRIDED_1D]]>, memref<f32>
+
+func @matvec(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
+ %x: memref<?xf32, offset: ?, strides: [1]>,
+ %y: memref<?xf32, offset: ?, strides: [1]>) {
+ linalg.matvec(%A, %x, %y) : memref<?x?xf32, offset: ?, strides: [?, 1]>,
+ memref<?xf32, offset: ?, strides: [1]>,
+ memref<?xf32, offset: ?, strides: [1]>
+ return
+}
+// CHECK-LABEL: func @matvec
+// CHECK-DAG : %[[c0:.*]] = constant 0 : index
+// CHECK-DAG : %[[c5:.*]] = constant 5 : index
+// CHECK-DAG : %[[c6:.*]] = constant 6 : index
+// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c6]]
+// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c5]]
+// CHECK : linalg.matvec({{.*}}, {{.*}}, {{.*}}) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?xf32, #[[STRIDED_1D]]>, memref<?xf32, #[[STRIDED_1D]]>
+
+func @matmul(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
+ %B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
+ %C: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
+ linalg.matmul(%A, %B, %C) : memref<?x?xf32, offset: ?, strides: [?, 1]>,
+ memref<?x?xf32, offset: ?, strides: [?, 1]>,
+ memref<?x?xf32, offset: ?, strides: [?, 1]>
+ return
+}
+// CHECK-LABEL: func @matmul
+// CHECK-DAG : %[[c0:.*]] = constant 0 : index
+// CHECK-DAG : %[[c2:.*]] = constant 2 : index
+// CHECK-DAG : %[[c3:.*]] = constant 3 : index
+// CHECK-DAG : %[[c4:.*]] = constant 4 : index
+// CHECK-DAG : %[[c20:.*]] = constant 20 : index
+// CHECK-DAG : %[[c30:.*]] = constant 30 : index
+// CHECK-DAG : %[[c40:.*]] = constant 40 : index
+// CHECK-DAG : %[[c200:.*]] = constant 200 : index
+// CHECK-DAG : %[[c300:.*]] = constant 300 : index
+// CHECK-DAG : %[[c400:.*]] = constant 400 : index
+// CHECK-DAG : %[[c2000:.*]] = constant 2000 : index
+// CHECK-DAG : %[[c3000:.*]] = constant 3000 : index
+// CHECK-DAG : %[[c4000:.*]] = constant 4000 : index
+// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c3000]] {
+// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c4000]] {
+// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c2000]] {
+// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c300]] {
+// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c200]] {
+// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c400]] {
+// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c20]] {
+// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c30]] {
+// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c40]] {
+// CHECK : linalg.matmul({{.*}}, {{.*}}, {{.*}}) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
+
set(LLVM_TARGET_DEFINITIONS TestLinalgTransformPatterns.td)
mlir_tablegen(TestLinalgTransformPatterns.h.inc -gen-rewriters)
add_public_tablegen_target(MLIRTestLinalgTransformPatternsIncGen)
+
+set(LLVM_TARGET_DEFINITIONS TestLinalgTilePermutePatterns.td)
+mlir_tablegen(TestLinalgTilePermutePatterns.h.inc -gen-rewriters)
+add_public_tablegen_target(MLIRTestLinalgTilePermutePatternsIncGen)
--- /dev/null
+//===- TestLinalgTilePermutePatterns.td - Test patterns --*- tablegen ----*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This is the pattern definition file for declarative Linalg transformations
+// tests.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TEST_LINALG_TILEPERMUTE_PATTERNS
+#define TEST_LINALG_TILEPERMUTE_PATTERNS
+
+include "mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td"
+
+//===----------------------------------------------------------------------===//
+// Linalg tiling and permutation patterns.
+//===----------------------------------------------------------------------===//
+def : Pat<(MatmulOp:$op $A, $B, $C),
+ (TileLinalgOp<[2000, 3000, 4000], "L2", [1,2,0]> $op),
+ [(Constraint<Or<[HasNoLinalgTransformMarker,
+ HasLinalgTransformMarker<"MEM">]>> $op)]>;
+def : Pat<(MatmulOp:$op $A, $B, $C),
+ (TileLinalgOp<[200, 300, 400], "L1", [1,0,2]> $op),
+ [(Constraint<HasLinalgTransformMarker<"L2">> $op)]>;
+def : Pat<(MatmulOp:$op $A, $B, $C),
+ (TileLinalgOp<[20, 30, 40], "REG"> $op),
+ [(Constraint<HasLinalgTransformMarker<"L1">> $op)]>;
+
+
+def : Pattern<(MatvecOp:$op $A, $b, $c),
+ [(TileLinalgOp<[5, 6], "L1", [1,0]> $op)],
+ [(Constraint<HasNoLinalgTransformMarker> $op)]>;
+
+def : Pattern<(DotOp:$op $a, $b, $c),
+ [(TileLinalgOp<[8000], "L1"> $op)],
+ [(Constraint<Or<[HasNoLinalgTransformMarker,
+ HasLinalgTransformMarker<"MEM">,
+ HasLinalgTransformMarker<"L3">,
+ HasLinalgTransformMarker<"L2">]>> $op)]>;
+def : Pattern<(DotOp:$op $a, $b, $c),
+ [(TileLinalgOp<[8], "REG"> $op)],
+ [(Constraint<HasLinalgTransformMarker<"L1">> $op)]>;
+
+#endif // TEST_LINALG_TILEPERMUTE_PATTERNS
TestLoopFusion.cpp
TestInlining.cpp
TestLinalgTransforms.cpp
+ TestLinalgTilePermuteTransforms.cpp
TestLoopMapping.cpp
TestLoopParametricTiling.cpp
TestOpaqueLoc.cpp
include_directories(${CMAKE_CURRENT_BINARY_DIR}/../DeclarativeTransforms)
add_dependencies(MLIRTestTransforms MLIRStandardOpsIncGen)
add_dependencies(MLIRTestTransforms MLIRTestLinalgTransformPatternsIncGen)
+add_dependencies(MLIRTestTransforms MLIRTestLinalgTilePermutePatternsIncGen)
target_link_libraries(MLIRTestTransforms
MLIRAffineOps
MLIRAnalysis
--- /dev/null
+//===- TestLinalgTilePermuteTransforms.cpp - Test Linalg tile + permute ---===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements logic for testing Linalg transformations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Transforms/LinalgTransforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace mlir {
+namespace linalg {
+namespace {
+#include "TestLinalgTilePermutePatterns.h.inc"
+} // end namespace
+} // end namespace linalg
+} // end namespace mlir
+
+namespace {
+struct TestLinalgTilePermuteTransforms
+ : public FunctionPass<TestLinalgTilePermuteTransforms> {
+ void runOnFunction() override;
+};
+} // end anonymous namespace
+
+/// Apply transformations specified as patterns.
+void TestLinalgTilePermuteTransforms::runOnFunction() {
+ OwningRewritePatternList patterns;
+ auto funcOp = getFunction();
+
+ // Add the generated patterns to the list.
+ linalg::populateWithGenerated(&getContext(), &patterns);
+ applyPatternsGreedily(funcOp, patterns);
+
+ // Drop the marker.
+ funcOp.walk([](LinalgOp op) {
+ op.removeAttr(LinalgTransforms::kLinalgTransformMarker);
+ });
+}
+
+static PassRegistration<TestLinalgTilePermuteTransforms>
+ pass("test-linalg-tile-and-permute-patterns",
+ "Test Linalg transformation with permutation patterns by applying "
+ "them greedily.");