Instead of interchanging loops during the loop lowering, this pass performs the interchange by permuting the indexing maps. It also updates the iterator types and the index accesses in the body of the operation.
Differential Revision: https://reviews.llvm.org/D100627
const LinalgDependenceGraph &dependenceGraph,
const LinalgTilingOptions &tilingOptions);
-/// Interchanges the `iterator_types` and `iterator_maps` dimensions of `op`.
-/// This is an in-place transformation controlled by `interchangeVector`.
-/// An empty vector is interpreted as the identity permutation and the
-/// transformation returns early.
+/// Interchanges the `iterator_types` and `indexing_maps` dimensions and adapts
+/// the index accesses of `op`. This is an in-place transformation controlled by
+/// `interchangeVector`. An empty vector is interpreted as the identity
+/// permutation and the transformation returns early.
///
/// E.g. the permutation `(i,j,k) -> (j,k,i)` is expressed with
/// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be
/// integers in the range [0, `op.rank`), without duplications
/// (i.e. `[1,1,2]` is an invalid permutation).
-LinalgOp interchange(LinalgOp op, ArrayRef<unsigned> interchangeVector);
+void interchange(PatternRewriter &rewriter, LinalgOp op,
+ ArrayRef<unsigned> interchangeVector);
/// Callback function type used to perform the allocation for the promoted
/// `subView`. In `boundingSubViewsize` a best attempt is made to find the
LogicalResult mlir::linalg::interchangeGenericLinalgOpPrecondition(
Operation *op, ArrayRef<unsigned> interchangeVector) {
- if (interchangeVector.empty())
- return failure();
// Transformation applies to generic ops only.
if (!isa<GenericOp, IndexedGenericOp>(op))
return failure();
- LinalgOp linOp = cast<LinalgOp>(op);
- // Transformation applies to buffers only.
- if (!linOp.hasBufferSemantics())
- return failure();
- // Permutation must be applicable.
- if (linOp.getIndexingMap(0).getNumInputs() != interchangeVector.size())
+ LinalgOp linalgOp = cast<LinalgOp>(op);
+ // Interchange vector must be non-empty and match the number of loops.
+ if (interchangeVector.empty() ||
+ linalgOp.getNumLoops() != interchangeVector.size())
return failure();
// Permutation map must be invertible.
if (!inversePermutation(
return success();
}
-LinalgOp mlir::linalg::interchange(LinalgOp op,
- ArrayRef<unsigned> interchangeVector) {
- if (interchangeVector.empty())
- return op;
-
+void mlir::linalg::interchange(PatternRewriter &rewriter, LinalgOp op,
+ ArrayRef<unsigned> interchangeVector) {
+ // 1. Compute the inverse permutation map.
MLIRContext *context = op.getContext();
- auto permutationMap = inversePermutation(
+ AffineMap permutationMap = inversePermutation(
AffineMap::getPermutationMap(interchangeVector, context));
assert(permutationMap && "expected permutation to be invertible");
+ assert(interchangeVector.size() == op.getNumLoops() &&
+ "expected interchange vector to have entry for every loop");
+
+ // 2. Compute the interchanged indexing maps.
SmallVector<Attribute, 4> newIndexingMaps;
- auto indexingMaps = op.indexing_maps().getValue();
+ ArrayRef<Attribute> indexingMaps = op.indexing_maps().getValue();
for (unsigned i = 0, e = op.getNumShapedOperands(); i != e; ++i) {
AffineMap m = indexingMaps[i].cast<AffineMapAttr>().getValue();
if (!permutationMap.isEmpty())
m = m.compose(permutationMap);
newIndexingMaps.push_back(AffineMapAttr::get(m));
}
- auto itTypes = op.iterator_types().getValue();
- SmallVector<Attribute, 4> itTypesVector;
- for (unsigned i = 0, e = itTypes.size(); i != e; ++i)
- itTypesVector.push_back(itTypes[i]);
- applyPermutationToVector(itTypesVector, interchangeVector);
-
op->setAttr(getIndexingMapsAttrName(),
ArrayAttr::get(context, newIndexingMaps));
+
+ // 3. Compute the interchanged iterator types.
+ ArrayRef<Attribute> itTypes = op.iterator_types().getValue();
+ SmallVector<Attribute, 4> itTypesVector;
+ llvm::append_range(itTypesVector, itTypes);
+ applyPermutationToVector(itTypesVector, interchangeVector);
op->setAttr(getIteratorTypesAttrName(),
ArrayAttr::get(context, itTypesVector));
- return op;
+ // 4. Transform the index operations by applying the permutation map.
+ if (op.hasIndexSemantics()) {
+ // TODO: Remove the assertion and add a getBody() method to LinalgOp
+ // interface once every LinalgOp has a body.
+ assert(op->getNumRegions() == 1 &&
+ op->getRegion(0).getBlocks().size() == 1 &&
+ "expected generic operation to have one block.");
+ Block &block = op->getRegion(0).front();
+ OpBuilder::InsertionGuard guard(rewriter);
+ for (IndexOp indexOp :
+ llvm::make_early_inc_range(block.getOps<IndexOp>())) {
+ rewriter.setInsertionPoint(indexOp);
+ SmallVector<Value> allIndices;
+ allIndices.reserve(op.getNumLoops());
+ llvm::transform(llvm::seq<int64_t>(0, op.getNumLoops()),
+ std::back_inserter(allIndices), [&](int64_t dim) {
+ return rewriter.create<IndexOp>(indexOp->getLoc(), dim);
+ });
+ rewriter.replaceOpWithNewOp<AffineApplyOp>(
+ indexOp, permutationMap.getSubMap(indexOp.dim()), allIndices);
+ }
+ }
}
LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite(
Operation *op, PatternRewriter &rewriter) const {
LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
- // TODO: remove hasIndexSemantics check once index ops are supported.
- if (!linalgOp || linalgOp.hasIndexSemantics())
+ if (!linalgOp)
return failure();
if (failed(filter.checkAndNotify(rewriter, linalgOp)))
return failure();
// TODO: figure out how this interplays with named ops. In particular this
// should break the named op property.
rewriter.updateRootInPlace(op, [&]() {
- interchange(linalgOp, interchangeVector);
+ interchange(rewriter, linalgOp, interchangeVector);
// New filter if specified.
filter.replaceLinalgTransformationFilter(rewriter, op);
});
--- /dev/null
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-interchange-pattern=4,0,3,1,2 | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-interchange-pattern=4,0,3,1,2 -test-linalg-transform-patterns=test-interchange-pattern=1,3,4,2,0 | FileCheck --check-prefix=CANCEL-OUT %s
+
+#map0 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
+
+func @interchange_generic_op(%arg0 : memref<1x2x3x4x5xindex>, %arg1 : memref<1x2x4xindex>) {
+ linalg.generic {
+ indexing_maps = [#map0, #map1],
+ iterator_types = ["parallel", "parallel", "reduction", "parallel", "reduction"]}
+ ins(%arg0 : memref<1x2x3x4x5xindex>)
+ outs(%arg1 : memref<1x2x4xindex>) {
+ ^bb0(%arg2 : index, %arg3 : index) :
+ %0 = linalg.index 0 : index
+ %1 = linalg.index 1 : index
+ %2 = linalg.index 4 : index
+ %3 = subi %0, %1 : index
+ %4 = addi %3, %2 : index
+ %5 = addi %4, %arg2 : index
+ linalg.yield %5 : index
+ }
+ return
+}
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4, d2, d0)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d2)>
+// CHECK: func @interchange_generic_op
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel", "parallel", "reduction"]
+// CHECK-DAG: %[[IDX0:.+]] = linalg.index 1 : index
+// CHECK-DAG: %[[IDX1:.+]] = linalg.index 3 : index
+// CHECK-DAG: %[[IDX4:.+]] = linalg.index 0 : index
+// CHECK: %[[T0:.+]] = subi %[[IDX0]], %[[IDX1]] : index
+// CHECK: %[[T1:.+]] = addi %[[T0]], %[[IDX4]] : index
+// CHECK: %[[T2:.+]] = addi %[[T1]], %{{.*}} : index
+
+// CANCEL-OUT-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CANCEL-OUT-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
+// CANCEL-OUT: func @interchange_generic_op
+// CANCEL-OUT: linalg.generic
+// CANCEL-OUT-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CANCEL-OUT-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "reduction"]
+// CANCEL-OUT-DAG: %[[IDX0:.+]] = linalg.index 0 : index
+// CANCEL-OUT-DAG: %[[IDX1:.+]] = linalg.index 1 : index
+// CANCEL-OUT-DAG: %[[IDX4:.+]] = linalg.index 4 : index
+// CANCEL-OUT: %[[T0:.+]] = subi %[[IDX0]], %[[IDX1]] : index
+// CANCEL-OUT: %[[T1:.+]] = addi %[[T0]], %[[IDX4]] : index
+// CANCEL-OUT: %[[T2:.+]] = addi %[[T1]], %{{.*}} : index
+
+
*this, "tile-sizes-for-padding",
llvm::cl::desc("Linalg tile sizes when tile+pad"), llvm::cl::ZeroOrMore,
llvm::cl::MiscFlags::CommaSeparated};
+ ListOption<unsigned> testInterchangePattern{
+ *this, "test-interchange-pattern", llvm::cl::MiscFlags::CommaSeparated,
+ llvm::cl::desc("Test the interchange pattern.")};
};
} // end anonymous namespace
(void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
}
+static void applyInterchangePattern(FuncOp funcOp,
+ ArrayRef<unsigned> interchangeVector) {
+ MLIRContext *context = funcOp.getContext();
+ RewritePatternSet interchangePattern(context);
+ interchangePattern.add<LinalgInterchangePattern<GenericOp>>(
+ context, interchangeVector,
+ LinalgTransformationFilter(ArrayRef<Identifier>{},
+ Identifier::get("interchange", context)));
+ (void)applyPatternsAndFoldGreedily(funcOp, std::move(interchangePattern));
+}
+
/// Apply transformations specified as patterns.
void TestLinalgTransforms::runOnFunction() {
auto lambda = [&](void *) {
(void)linalg::hoistPaddingOnTensors(padTensorOp, testHoistPadding);
});
}
+ if (testInterchangePattern.hasValue())
+ return applyInterchangePattern(getFunction(), testInterchangePattern);
}
namespace mlir {