From 495e1d7e8a68e4343756b58b0dd7b4bd047bd847 Mon Sep 17 00:00:00 2001 From: Tobias Gysi Date: Mon, 19 Apr 2021 11:53:18 +0000 Subject: [PATCH] [mlir][linalg] adding pass to run the interchange pattern. Instead of interchanging loops during the loop lowering this pass performs the interchange by permuting the indexing maps. It also updates the iterator types and the index accesses in the body of the operation. Differential Revision: https://reviews.llvm.org/D100627 --- .../mlir/Dialect/Linalg/Transforms/Transforms.h | 11 ++-- mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp | 63 ++++++++++++++-------- mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp | 5 +- mlir/test/Dialect/Linalg/interchange.mlir | 51 ++++++++++++++++++ mlir/test/lib/Transforms/TestLinalgTransforms.cpp | 16 ++++++ 5 files changed, 116 insertions(+), 30 deletions(-) create mode 100644 mlir/test/Dialect/Linalg/interchange.mlir diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 4d45642..48b1eb8 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -194,16 +194,17 @@ tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef ops, const LinalgDependenceGraph &dependenceGraph, const LinalgTilingOptions &tilingOptions); -/// Interchanges the `iterator_types` and `iterator_maps` dimensions of `op`. -/// This is an in-place transformation controlled by `interchangeVector`. -/// An empty vector is interpreted as the identity permutation and the -/// transformation returns early. +/// Interchanges the `iterator_types` and `iterator_maps` dimensions and adapts +/// the index accesses of `op`. This is an in-place transformation controlled by +/// `interchangeVector`. An empty vector is interpreted as the identity +/// permutation and the transformation returns early. /// /// E.g. the permutation `(i,j,k) -> (j,k,i)` is expressed with /// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be /// integers, in the range 0..`op.rank` without duplications /// (i.e. `[1,1,2]` is an invalid permutation). -LinalgOp interchange(LinalgOp op, ArrayRef interchangeVector); +void interchange(PatternRewriter &rewriter, LinalgOp op, + ArrayRef interchangeVector); /// Callback function type used to perform the allocation for the promoted /// `subView`. In `boundingSubViewsize` a best attempt is made to find the diff --git a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp index b893f2b..29acd62 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp @@ -34,17 +34,13 @@ using namespace mlir::linalg; LogicalResult mlir::linalg::interchangeGenericLinalgOpPrecondition( Operation *op, ArrayRef interchangeVector) { - if (interchangeVector.empty()) - return failure(); // Transformation applies to generic ops only. if (!isa(op)) return failure(); - LinalgOp linOp = cast(op); - // Transformation applies to buffers only. - if (!linOp.hasBufferSemantics()) - return failure(); - // Permutation must be applicable. - if (linOp.getIndexingMap(0).getNumInputs() != interchangeVector.size()) + LinalgOp linalgOp = cast(op); + // Interchange vector must be non-empty and match the number of loops. + if (interchangeVector.empty() || + linalgOp.getNumLoops() != interchangeVector.size()) return failure(); // Permutation map must be invertible. if (!inversePermutation( @@ -53,33 +49,56 @@ LogicalResult mlir::linalg::interchangeGenericLinalgOpPrecondition( return success(); } -LinalgOp mlir::linalg::interchange(LinalgOp op, - ArrayRef interchangeVector) { - if (interchangeVector.empty()) - return op; - +void mlir::linalg::interchange(PatternRewriter &rewriter, LinalgOp op, + ArrayRef interchangeVector) { + // 1. Compute the inverse permutation map. MLIRContext *context = op.getContext(); - auto permutationMap = inversePermutation( + AffineMap permutationMap = inversePermutation( AffineMap::getPermutationMap(interchangeVector, context)); assert(permutationMap && "expected permutation to be invertible"); + assert(interchangeVector.size() == op.getNumLoops() && + "expected interchange vector to have entry for every loop"); + + // 2. Compute the interchanged indexing maps. SmallVector newIndexingMaps; - auto indexingMaps = op.indexing_maps().getValue(); + ArrayRef indexingMaps = op.indexing_maps().getValue(); for (unsigned i = 0, e = op.getNumShapedOperands(); i != e; ++i) { AffineMap m = indexingMaps[i].cast().getValue(); if (!permutationMap.isEmpty()) m = m.compose(permutationMap); newIndexingMaps.push_back(AffineMapAttr::get(m)); } - auto itTypes = op.iterator_types().getValue(); - SmallVector itTypesVector; - for (unsigned i = 0, e = itTypes.size(); i != e; ++i) - itTypesVector.push_back(itTypes[i]); - applyPermutationToVector(itTypesVector, interchangeVector); - op->setAttr(getIndexingMapsAttrName(), ArrayAttr::get(context, newIndexingMaps)); + + // 3. Compute the interchanged iterator types. + ArrayRef itTypes = op.iterator_types().getValue(); + SmallVector itTypesVector; + llvm::append_range(itTypesVector, itTypes); + applyPermutationToVector(itTypesVector, interchangeVector); op->setAttr(getIteratorTypesAttrName(), ArrayAttr::get(context, itTypesVector)); - return op; + // 4. Transform the index operations by applying the permutation map. + if (op.hasIndexSemantics()) { + // TODO: Remove the assertion and add a getBody() method to LinalgOp + // interface once every LinalgOp has a body. + assert(op->getNumRegions() == 1 && + op->getRegion(0).getBlocks().size() == 1 && + "expected generic operation to have one block."); + Block &block = op->getRegion(0).front(); + OpBuilder::InsertionGuard guard(rewriter); + for (IndexOp indexOp : + llvm::make_early_inc_range(block.getOps())) { + rewriter.setInsertionPoint(indexOp); + SmallVector allIndices; + allIndices.reserve(op.getNumLoops()); + llvm::transform(llvm::seq(0, op.getNumLoops()), + std::back_inserter(allIndices), [&](int64_t dim) { + return rewriter.create(indexOp->getLoc(), dim); + }); + rewriter.replaceOpWithNewOp( + indexOp, permutationMap.getSubMap(indexOp.dim()), allIndices); + } + } } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index c51c929..55402a7 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -404,8 +404,7 @@ mlir::linalg::LinalgBaseInterchangePattern::LinalgBaseInterchangePattern( LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite( Operation *op, PatternRewriter &rewriter) const { LinalgOp linalgOp = dyn_cast(op); - // TODO: remove hasIndexSemantics check once index ops are supported. - if (!linalgOp || linalgOp.hasIndexSemantics()) + if (!linalgOp) return failure(); if (failed(filter.checkAndNotify(rewriter, linalgOp))) return failure(); @@ -415,7 +414,7 @@ LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite( // TODO: figure out how this interplays with named ops. In particular this // should break the named op property. rewriter.updateRootInPlace(op, [&]() { - interchange(linalgOp, interchangeVector); + interchange(rewriter, linalgOp, interchangeVector); // New filter if specified. filter.replaceLinalgTransformationFilter(rewriter, op); }); diff --git a/mlir/test/Dialect/Linalg/interchange.mlir b/mlir/test/Dialect/Linalg/interchange.mlir new file mode 100644 index 0000000..bc1d10b --- /dev/null +++ b/mlir/test/Dialect/Linalg/interchange.mlir @@ -0,0 +1,51 @@ +// RUN: mlir-opt %s -test-linalg-transform-patterns=test-interchange-pattern=4,0,3,1,2 | FileCheck %s +// RUN: mlir-opt %s -test-linalg-transform-patterns=test-interchange-pattern=4,0,3,1,2 -test-linalg-transform-patterns=test-interchange-pattern=1,3,4,2,0 | FileCheck --check-prefix=CANCEL-OUT %s + +#map0 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)> +#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)> + +func @interchange_generic_op(%arg0 : memref<1x2x3x4x5xindex>, %arg1 : memref<1x2x4xindex>) { + linalg.generic { + indexing_maps = [#map0, #map1], + iterator_types = ["parallel", "parallel", "reduction", "parallel", "reduction"]} + ins(%arg0 : memref<1x2x3x4x5xindex>) + outs(%arg1 : memref<1x2x4xindex>) { + ^bb0(%arg2 : index, %arg3 : index) : + %0 = linalg.index 0 : index + %1 = linalg.index 1 : index + %2 = linalg.index 4 : index + %3 = subi %0, %1 : index + %4 = addi %3, %2 : index + %5 = addi %4, %arg2 : index + linalg.yield %5 : index + } + return +} + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4, d2, d0)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d2)> +// CHECK: func @interchange_generic_op +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] +// CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel", "parallel", "reduction"] +// CHECK-DAG: %[[IDX0:.+]] = linalg.index 1 : index +// CHECK-DAG: %[[IDX1:.+]] = linalg.index 3 : index +// CHECK-DAG: %[[IDX4:.+]] = linalg.index 0 : index +// CHECK: %[[T0:.+]] = subi %[[IDX0]], %[[IDX1]] : index +// CHECK: %[[T1:.+]] = addi %[[T0]], %[[IDX4]] : index +// CHECK: %[[T2:.+]] = addi %[[T1]], %{{.*}} : index + +// CANCEL-OUT-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)> +// CANCEL-OUT-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)> +// CANCEL-OUT: func @interchange_generic_op +// CANCEL-OUT: linalg.generic +// CANCEL-OUT-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] +// CANCEL-OUT-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "reduction"] +// CANCEL-OUT-DAG: %[[IDX0:.+]] = linalg.index 0 : index +// CANCEL-OUT-DAG: %[[IDX1:.+]] = linalg.index 1 : index +// CANCEL-OUT-DAG: %[[IDX4:.+]] = linalg.index 4 : index +// CANCEL-OUT: %[[T0:.+]] = subi %[[IDX0]], %[[IDX1]] : index +// CANCEL-OUT: %[[T1:.+]] = addi %[[T0]], %[[IDX4]] : index +// CANCEL-OUT: %[[T2:.+]] = addi %[[T1]], %{{.*}} : index + + diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp index a6fe895..178de38 100644 --- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp @@ -91,6 +91,9 @@ struct TestLinalgTransforms *this, "tile-sizes-for-padding", llvm::cl::desc("Linalg tile sizes when tile+pad"), llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; + ListOption testInterchangePattern{ + *this, "test-interchange-pattern", llvm::cl::MiscFlags::CommaSeparated, + llvm::cl::desc("Test the interchange pattern.")}; }; } // end anonymous namespace @@ -540,6 +543,17 @@ static void applyTileAndPadPattern(FuncOp funcOp, ArrayRef tileSizes) { (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern)); } +static void applyInterchangePattern(FuncOp funcOp, + ArrayRef interchangeVector) { + MLIRContext *context = funcOp.getContext(); + RewritePatternSet interchangePattern(context); + interchangePattern.add>( + context, interchangeVector, + LinalgTransformationFilter(ArrayRef{}, + Identifier::get("interchange", context))); + (void)applyPatternsAndFoldGreedily(funcOp, std::move(interchangePattern)); +} + /// Apply transformations specified as patterns. void TestLinalgTransforms::runOnFunction() { auto lambda = [&](void *) { @@ -580,6 +594,8 @@ void TestLinalgTransforms::runOnFunction() { (void)linalg::hoistPaddingOnTensors(padTensorOp, testHoistPadding); }); } + if (testInterchangePattern.hasValue()) + return applyInterchangePattern(getFunction(), testInterchangePattern); } namespace mlir { -- 2.7.4