From 289cfe9ccdcb04604580ae866533d7b17654ab93 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Wed, 19 Apr 2023 15:49:20 +0900 Subject: [PATCH] [mlir][linalg] ValueBoundsOpInterface: Add support for linalg.index Differential Revision: https://reviews.llvm.org/D148598 --- .../Linalg/IR/ValueBoundsOpInterfaceImpl.cpp | 32 ++++++++++++++ .../Linalg/value-bounds-op-interface-impl.mlir | 51 ++++++++++++++++++++++ 2 files changed, 83 insertions(+) diff --git a/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp index 389cac4..55d09c4 100644 --- a/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp @@ -17,6 +17,36 @@ namespace mlir { namespace linalg { namespace { +struct IndexOpInterface + : public ValueBoundsOpInterface::ExternalModel { + void populateBoundsForIndexValue(Operation *op, Value value, + ValueBoundsConstraintSet &cstr) const { + auto indexOp = cast(op); + auto linalgOp = indexOp->getParentOfType(); + assert(value == indexOp.getResult() && "invalid value"); + + // index >= 0 + cstr.bound(value) >= 0; + + // index < dim size + int64_t flatDimPos = linalgOp.getShapesToLoopsMap() + .getResult(indexOp.getDim()) + .cast() + .getPosition(); + // Find the `flatDimPos`-th operand dimension. + int64_t flatDimCtr = 0; + for (Value operand : linalgOp->getOperands()) { + assert(flatDimPos >= flatDimCtr && "invalid pos"); + auto shapedType = operand.getType().cast(); + if (flatDimPos < flatDimCtr + shapedType.getRank()) { + cstr.bound(value) < cstr.getExpr(operand, flatDimPos - flatDimCtr); + break; + } + flatDimCtr += shapedType.getRank(); + } + } +}; + /// Helper structure that iterates over all LinalgOps in `OpTys` and registers /// the `ValueBoundsOpInterface` with each of them. template struct LinalgValueBoundsOpInterfaceHelper { @@ -34,6 +64,8 @@ template struct LinalgValueBoundsOpInterfaceHelper { void mlir::linalg::registerValueBoundsOpInterfaceExternalModels( DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) { + IndexOp::attachInterface(*ctx); + // Register all Linalg structured ops. LinalgValueBoundsOpInterfaceHelper< #define GET_OP_LIST diff --git a/mlir/test/Dialect/Linalg/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Linalg/value-bounds-op-interface-impl.mlir index 537bc98..189c8e64 100644 --- a/mlir/test/Dialect/Linalg/value-bounds-op-interface-impl.mlir +++ b/mlir/test/Dialect/Linalg/value-bounds-op-interface-impl.mlir @@ -11,3 +11,54 @@ func.func @linalg_fill(%t: tensor, %f: f32) -> index { %1 = "test.reify_bound"(%0) {dim = 0} : (tensor) -> (index) return %1 : index } + +// ----- + +#accesses = [ + affine_map<(i, j, k) -> (j, i)>, + affine_map<(i, j, k) -> (i, k, i + j)> +] + +#trait = { + indexing_maps = #accesses, + iterator_types = ["parallel", "parallel", "parallel"] +} + +// CHECK-LABEL: func @linalg_index( +// CHECK-SAME: %[[arg0:.*]]: memref +func.func @linalg_index(%arg0: memref, + %arg1: memref) { + linalg.generic #trait + ins(%arg0 : memref) + outs(%arg1 : memref) + { + ^bb(%a: f32, %b: f32): + // CHECK: %[[c1:.*]] = arith.constant 1 : index + // CHECK: %[[ub_0:.*]] = memref.dim %[[arg0]], %[[c1]] + // CHECK: "test.some_use"(%[[ub_0]]) + %0 = linalg.index 0 : index + %ub_0 = "test.reify_bound"(%0) {type = "UB"} : (index) -> (index) + "test.some_use"(%ub_0) : (index) -> () + + // CHECK: %[[c0:.*]] = arith.constant 0 : index + // CHECK: "test.some_use"(%[[c0]]) + %lb_0 = "test.reify_bound"(%0) {type = "LB"} : (index) -> (index) + "test.some_use"(%lb_0) : (index) -> () + + // CHECK: %[[c0:.*]] = arith.constant 0 : index + // CHECK: %[[ub_1:.*]] = memref.dim %[[arg0]], %[[c0]] + // CHECK: "test.some_use"(%[[ub_1]]) + %1 = linalg.index 1 : index + %ub_1 = "test.reify_bound"(%1) {type = "UB"} : (index) -> (index) + "test.some_use"(%ub_1) : (index) -> () + + // CHECK: %[[ub_2:.*]] = arith.constant 5 : index + // CHECK: "test.some_use"(%[[ub_2]]) + %2 = linalg.index 2 : index + %ub_2 = "test.reify_bound"(%2) {type = "UB"} : (index) -> (index) + "test.some_use"(%ub_2) : (index) -> () + + linalg.yield %b : f32 + } + return +} -- 2.7.4