[mlir][linalg] adding pass to run the interchange pattern.
authorTobias Gysi <gysit@google.com>
Mon, 19 Apr 2021 11:53:18 +0000 (11:53 +0000)
committerTobias Gysi <gysit@google.com>
Mon, 19 Apr 2021 12:19:15 +0000 (12:19 +0000)
Instead of interchanging loops during the loop lowering this pass performs the interchange by permuting the indexing maps. It also updates the iterator types and the index accesses in the body of the operation.

Differential Revision: https://reviews.llvm.org/D100627

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/Dialect/Linalg/interchange.mlir [new file with mode: 0644]
mlir/test/lib/Transforms/TestLinalgTransforms.cpp

index 4d45642..48b1eb8 100644 (file)
@@ -194,16 +194,17 @@ tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
                      const LinalgDependenceGraph &dependenceGraph,
                      const LinalgTilingOptions &tilingOptions);
 
-/// Interchanges the `iterator_types` and `iterator_maps` dimensions of `op`.
-/// This is an in-place transformation controlled by `interchangeVector`.
-/// An empty vector is interpreted as the identity permutation and the
-/// transformation returns early.
+/// Interchanges the `iterator_types` and `iterator_maps` dimensions and adapts
+/// the index accesses of `op`. This is an in-place transformation controlled by
+/// `interchangeVector`. An empty vector is interpreted as the identity
+/// permutation and the transformation returns early.
 ///
 /// E.g. the permutation `(i,j,k) -> (j,k,i)` is expressed with
 /// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be
 /// integers, in the range 0..`op.rank` without duplications
 /// (i.e. `[1,1,2]` is an invalid permutation).
-LinalgOp interchange(LinalgOp op, ArrayRef<unsigned> interchangeVector);
+void interchange(PatternRewriter &rewriter, LinalgOp op,
+                 ArrayRef<unsigned> interchangeVector);
 
 /// Callback function type used to perform the allocation for the promoted
 /// `subView`. In `boundingSubViewsize` a best attempt is made to find the
index b893f2b..29acd62 100644 (file)
@@ -34,17 +34,13 @@ using namespace mlir::linalg;
 
 LogicalResult mlir::linalg::interchangeGenericLinalgOpPrecondition(
     Operation *op, ArrayRef<unsigned> interchangeVector) {
-  if (interchangeVector.empty())
-    return failure();
   // Transformation applies to generic ops only.
   if (!isa<GenericOp, IndexedGenericOp>(op))
     return failure();
-  LinalgOp linOp = cast<LinalgOp>(op);
-  // Transformation applies to buffers only.
-  if (!linOp.hasBufferSemantics())
-    return failure();
-  // Permutation must be applicable.
-  if (linOp.getIndexingMap(0).getNumInputs() != interchangeVector.size())
+  LinalgOp linalgOp = cast<LinalgOp>(op);
+  // Interchange vector must be non-empty and match the number of loops.
+  if (interchangeVector.empty() ||
+      linalgOp.getNumLoops() != interchangeVector.size())
     return failure();
   // Permutation map must be invertible.
   if (!inversePermutation(
@@ -53,33 +49,56 @@ LogicalResult mlir::linalg::interchangeGenericLinalgOpPrecondition(
   return success();
 }
 
-LinalgOp mlir::linalg::interchange(LinalgOp op,
-                                   ArrayRef<unsigned> interchangeVector) {
-  if (interchangeVector.empty())
-    return op;
-
+void mlir::linalg::interchange(PatternRewriter &rewriter, LinalgOp op,
+                               ArrayRef<unsigned> interchangeVector) {
+  // 1. Compute the inverse permutation map.
   MLIRContext *context = op.getContext();
-  auto permutationMap = inversePermutation(
+  AffineMap permutationMap = inversePermutation(
       AffineMap::getPermutationMap(interchangeVector, context));
   assert(permutationMap && "expected permutation to be invertible");
+  assert(interchangeVector.size() == op.getNumLoops() &&
+         "expected interchange vector to have entry for every loop");
+
+  // 2. Compute the interchanged indexing maps.
   SmallVector<Attribute, 4> newIndexingMaps;
-  auto indexingMaps = op.indexing_maps().getValue();
+  ArrayRef<Attribute> indexingMaps = op.indexing_maps().getValue();
   for (unsigned i = 0, e = op.getNumShapedOperands(); i != e; ++i) {
     AffineMap m = indexingMaps[i].cast<AffineMapAttr>().getValue();
     if (!permutationMap.isEmpty())
       m = m.compose(permutationMap);
     newIndexingMaps.push_back(AffineMapAttr::get(m));
   }
-  auto itTypes = op.iterator_types().getValue();
-  SmallVector<Attribute, 4> itTypesVector;
-  for (unsigned i = 0, e = itTypes.size(); i != e; ++i)
-    itTypesVector.push_back(itTypes[i]);
-  applyPermutationToVector(itTypesVector, interchangeVector);
-
   op->setAttr(getIndexingMapsAttrName(),
               ArrayAttr::get(context, newIndexingMaps));
+
+  // 3. Compute the interchanged iterator types.
+  ArrayRef<Attribute> itTypes = op.iterator_types().getValue();
+  SmallVector<Attribute, 4> itTypesVector;
+  llvm::append_range(itTypesVector, itTypes);
+  applyPermutationToVector(itTypesVector, interchangeVector);
   op->setAttr(getIteratorTypesAttrName(),
               ArrayAttr::get(context, itTypesVector));
 
-  return op;
+  // 4. Transform the index operations by applying the permutation map.
+  if (op.hasIndexSemantics()) {
+    // TODO: Remove the assertion and add a getBody() method to LinalgOp
+    // interface once every LinalgOp has a body.
+    assert(op->getNumRegions() == 1 &&
+           op->getRegion(0).getBlocks().size() == 1 &&
+           "expected generic operation to have one block.");
+    Block &block = op->getRegion(0).front();
+    OpBuilder::InsertionGuard guard(rewriter);
+    for (IndexOp indexOp :
+         llvm::make_early_inc_range(block.getOps<IndexOp>())) {
+      rewriter.setInsertionPoint(indexOp);
+      SmallVector<Value> allIndices;
+      allIndices.reserve(op.getNumLoops());
+      llvm::transform(llvm::seq<int64_t>(0, op.getNumLoops()),
+                      std::back_inserter(allIndices), [&](int64_t dim) {
+                        return rewriter.create<IndexOp>(indexOp->getLoc(), dim);
+                      });
+      rewriter.replaceOpWithNewOp<AffineApplyOp>(
+          indexOp, permutationMap.getSubMap(indexOp.dim()), allIndices);
+    }
+  }
 }
index c51c929..55402a7 100644 (file)
@@ -404,8 +404,7 @@ mlir::linalg::LinalgBaseInterchangePattern::LinalgBaseInterchangePattern(
 LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite(
     Operation *op, PatternRewriter &rewriter) const {
   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
-  // TODO: remove hasIndexSemantics check once index ops are supported.
-  if (!linalgOp || linalgOp.hasIndexSemantics())
+  if (!linalgOp)
     return failure();
   if (failed(filter.checkAndNotify(rewriter, linalgOp)))
     return failure();
@@ -415,7 +414,7 @@ LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite(
   // TODO: figure out how this interplays with named ops. In particular this
   // should break the named op property.
   rewriter.updateRootInPlace(op, [&]() {
-    interchange(linalgOp, interchangeVector);
+    interchange(rewriter, linalgOp, interchangeVector);
     // New filter if specified.
     filter.replaceLinalgTransformationFilter(rewriter, op);
   });
diff --git a/mlir/test/Dialect/Linalg/interchange.mlir b/mlir/test/Dialect/Linalg/interchange.mlir
new file mode 100644 (file)
index 0000000..bc1d10b
--- /dev/null
@@ -0,0 +1,51 @@
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-interchange-pattern=4,0,3,1,2 | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-interchange-pattern=4,0,3,1,2 -test-linalg-transform-patterns=test-interchange-pattern=1,3,4,2,0 | FileCheck --check-prefix=CANCEL-OUT %s
+
+#map0 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
+
+func @interchange_generic_op(%arg0 : memref<1x2x3x4x5xindex>, %arg1 : memref<1x2x4xindex>) {
+  linalg.generic {
+    indexing_maps = [#map0, #map1],
+    iterator_types = ["parallel", "parallel", "reduction", "parallel", "reduction"]}
+  ins(%arg0 : memref<1x2x3x4x5xindex>)
+  outs(%arg1 : memref<1x2x4xindex>) {
+      ^bb0(%arg2 : index, %arg3 : index) :
+        %0 = linalg.index 0 : index
+        %1 = linalg.index 1 : index
+        %2 = linalg.index 4 : index
+        %3 = subi %0, %1 : index
+        %4 = addi %3, %2 : index
+        %5 = addi %4, %arg2 : index
+        linalg.yield %5 : index
+      }
+  return
+}
+
+//    CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4, d2, d0)>
+//    CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d2)>
+//        CHECK: func @interchange_generic_op
+//        CHECK:   linalg.generic
+//   CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]]]
+//   CHECK-SAME:     iterator_types = ["reduction", "parallel", "parallel", "parallel", "reduction"]
+//    CHECK-DAG:     %[[IDX0:.+]] = linalg.index 1 : index
+//    CHECK-DAG:     %[[IDX1:.+]] = linalg.index 3 : index
+//    CHECK-DAG:     %[[IDX4:.+]] = linalg.index 0 : index
+//        CHECK:     %[[T0:.+]] = subi %[[IDX0]], %[[IDX1]] : index
+//        CHECK:     %[[T1:.+]] = addi %[[T0]], %[[IDX4]] : index
+//        CHECK:     %[[T2:.+]] = addi %[[T1]], %{{.*}} : index
+
+//  CANCEL-OUT-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+//  CANCEL-OUT-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
+//      CANCEL-OUT: func @interchange_generic_op
+//      CANCEL-OUT:   linalg.generic
+// CANCEL-OUT-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CANCEL-OUT-SAME:     iterator_types = ["parallel", "parallel", "reduction", "parallel", "reduction"]
+//  CANCEL-OUT-DAG:     %[[IDX0:.+]] = linalg.index 0 : index
+//  CANCEL-OUT-DAG:     %[[IDX1:.+]] = linalg.index 1 : index
+//  CANCEL-OUT-DAG:     %[[IDX4:.+]] = linalg.index 4 : index
+//      CANCEL-OUT:     %[[T0:.+]] = subi %[[IDX0]], %[[IDX1]] : index
+//      CANCEL-OUT:     %[[T1:.+]] = addi %[[T0]], %[[IDX4]] : index
+//      CANCEL-OUT:     %[[T2:.+]] = addi %[[T1]], %{{.*}} : index
+
+
index a6fe895..178de38 100644 (file)
@@ -91,6 +91,9 @@ struct TestLinalgTransforms
       *this, "tile-sizes-for-padding",
       llvm::cl::desc("Linalg tile sizes when tile+pad"), llvm::cl::ZeroOrMore,
       llvm::cl::MiscFlags::CommaSeparated};
+  ListOption<unsigned> testInterchangePattern{
+      *this, "test-interchange-pattern", llvm::cl::MiscFlags::CommaSeparated,
+      llvm::cl::desc("Test the interchange pattern.")};
 };
 } // end anonymous namespace
 
@@ -540,6 +543,17 @@ static void applyTileAndPadPattern(FuncOp funcOp, ArrayRef<int64_t> tileSizes) {
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
 }
 
+static void applyInterchangePattern(FuncOp funcOp,
+                                    ArrayRef<unsigned> interchangeVector) {
+  MLIRContext *context = funcOp.getContext();
+  RewritePatternSet interchangePattern(context);
+  interchangePattern.add<LinalgInterchangePattern<GenericOp>>(
+      context, interchangeVector,
+      LinalgTransformationFilter(ArrayRef<Identifier>{},
+                                 Identifier::get("interchange", context)));
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(interchangePattern));
+}
+
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnFunction() {
   auto lambda = [&](void *) {
@@ -580,6 +594,8 @@ void TestLinalgTransforms::runOnFunction() {
       (void)linalg::hoistPaddingOnTensors(padTensorOp, testHoistPadding);
     });
   }
+  if (testInterchangePattern.hasValue())
+    return applyInterchangePattern(getFunction(), testInterchangePattern);
 }
 
 namespace mlir {