From edc8b602f996a1fc68c28054c81e8fb671bb874b Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Fri, 7 Apr 2023 11:05:14 +0900 Subject: [PATCH] [mlir][linalg] ValueBoundsOpInterface: Add LinalgOps Also add a few more complex test cases. Differential Revision: https://reviews.llvm.org/D145806 --- .../Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h | 20 ++++++ mlir/include/mlir/InitAllDialects.h | 2 + mlir/lib/Dialect/Linalg/IR/CMakeLists.txt | 2 + .../Linalg/IR/ValueBoundsOpInterfaceImpl.cpp | 43 ++++++++++++ .../Dialect/Affine/value-bounds-reification.mlir | 79 ++++++++++++++++++++++ .../Linalg/value-bounds-op-interface-impl.mlir | 13 ++++ utils/bazel/llvm-project-overlay/mlir/BUILD.bazel | 1 + 7 files changed, 160 insertions(+) create mode 100644 mlir/include/mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h create mode 100644 mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp create mode 100644 mlir/test/Dialect/Linalg/value-bounds-op-interface-impl.mlir diff --git a/mlir/include/mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h new file mode 100644 index 0000000..005ed4b --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h @@ -0,0 +1,20 @@ +//===- ValueBoundsOpInterfaceImpl.h - Impl. of ValueBoundsOpInterface -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LINALG_IR_VALUEBOUNDSOPINTERFACEIMPL_H +#define MLIR_DIALECT_LINALG_IR_VALUEBOUNDSOPINTERFACEIMPL_H + +namespace mlir { +class DialectRegistry; + +namespace linalg { +void registerValueBoundsOpInterfaceExternalModels(DialectRegistry ®istry); +} // namespace linalg +} // namespace mlir + +#endif // MLIR_DIALECT_LINALG_IR_VALUEBOUNDSOPINTERFACEIMPL_H diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h index 36d5017..4df4593 100644 --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -41,6 +41,7 @@ #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h" #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" #include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" @@ -141,6 +142,7 @@ inline void registerAllDialects(DialectRegistry ®istry) { registry); linalg::registerBufferizableOpInterfaceExternalModels(registry); linalg::registerTilingInterfaceExternalModels(registry); + linalg::registerValueBoundsOpInterfaceExternalModels(registry); memref::registerBufferizableOpInterfaceExternalModels(registry); memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry); memref::registerValueBoundsOpInterfaceExternalModels(registry); diff --git a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt index 85412db..7e4c8f5 100644 --- a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt @@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRLinalgDialect LinalgInterfaces.cpp LinalgOps.cpp LinalgDialect.cpp + ValueBoundsOpInterfaceImpl.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg @@ -30,5 +31,6 @@ add_mlir_dialect_library(MLIRLinalgDialect MLIRMemRefDialect MLIRTensorDialect MLIRTilingInterface + MLIRValueBoundsOpInterface MLIRViewLikeInterface ) diff --git a/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp new file mode 100644 index 0000000..389cac4 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.cpp @@ -0,0 +1,43 @@ +//===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h" + +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Interfaces/ValueBoundsOpInterface.h" + +using namespace mlir; + +namespace mlir { +namespace linalg { +namespace { + +/// Helper structure that iterates over all LinalgOps in `OpTys` and registers +/// the `ValueBoundsOpInterface` with each of them. +template struct LinalgValueBoundsOpInterfaceHelper { + static void registerOpInterface(MLIRContext *ctx) { + (Ops::template attachInterface>( + *ctx), + ...); + } +}; + +} // namespace +} // namespace linalg +} // namespace mlir + +void mlir::linalg::registerValueBoundsOpInterfaceExternalModels( + DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) { + // Register all Linalg structured ops. + LinalgValueBoundsOpInterfaceHelper< +#define GET_OP_LIST +#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" + >::registerOpInterface(ctx); + }); +} diff --git a/mlir/test/Dialect/Affine/value-bounds-reification.mlir b/mlir/test/Dialect/Affine/value-bounds-reification.mlir index c376af1..e5ee497 100644 --- a/mlir/test/Dialect/Affine/value-bounds-reification.mlir +++ b/mlir/test/Dialect/Affine/value-bounds-reification.mlir @@ -20,3 +20,82 @@ func.func @reify_through_chain(%sz0: index, %sz2: index) -> (index, index, index return %4, %5, %6 : index, index, index } + +// ----- + +// CHECK-LABEL: func @reify_slice_bound( +// CHECK: %[[c5:.*]] = arith.constant 5 : index +// CHECK: "test.some_use"(%[[c5]]) +func.func @reify_slice_bound(%t: tensor, %idx: index, %ub: index, %f: f32) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + scf.for %iv = %c0 to %ub step %c4 { + %sz = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 4)>(%iv)[%ub] + %slice = tensor.extract_slice %t[%idx, %iv] [1, %sz] [1, 1] : tensor to tensor<1x?xi32> + %filled = linalg.fill ins(%f : f32) outs(%slice : tensor<1x?xi32>) -> tensor<1x?xi32> + %bound = "test.reify_bound"(%filled) {dim = 1, type = "UB"} : (tensor<1x?xi32>) -> (index) + "test.some_use"(%bound) : (index) -> () + } + return +} + +// ----- + +// CHECK: #[[$map:.*]] = affine_map<()[s0, s1] -> (s0 - s1 + 1)> +// CHECK-LABEL: func @scf_for( +// CHECK-SAME: %[[lb:.*]]: index, %[[ub:.*]]: index, %[[step:.*]]: index +// CHECK: %[[bound:.*]] = affine.apply #[[$map]]()[%[[ub]], %[[lb]]] +// CHECK: "test.some_use"(%[[bound]]) +func.func @scf_for(%lb: index, %ub: index, %step: index) { + scf.for %iv = %lb to %ub step %step { + %0 = affine.apply affine_map<(d0)[s0] -> (-d0 + s0)>(%iv)[%ub] + %bound = "test.reify_bound"(%0) {type = "UB"} : (index) -> (index) + "test.some_use"(%bound) : (index) -> () + } + return +} + +// ----- + +// CHECK-LABEL: func @reify_slice_bound2( +func.func @reify_slice_bound2(%lb0: index, %ub0: index, %step0: index, + %ub2: index, %t1: tensor<1x?xi8>, + %t2: tensor, %t3: tensor<1x?xi32>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + scf.for %iv0 = %lb0 to %ub0 step %step0 { + // CHECK: %[[c129:.*]] = arith.constant 129 : index + // CHECK: "test.some_use"(%[[c129]]) + %ub1 = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 128)>(%iv0)[%ub0] + %ub1_ub = "test.reify_bound"(%ub1) {type = "UB"} : (index) -> (index) + "test.some_use"(%ub1_ub) : (index) -> () + + // CHECK: %[[c129:.*]] = arith.constant 129 : index + // CHECK: "test.some_use"(%[[c129]]) + %lb1 = affine.apply affine_map<()[s0] -> ((s0 floordiv 32) * 32)>()[%ub1] + %lb1_ub = "test.reify_bound"(%lb1) {type = "UB"} : (index) -> (index) + "test.some_use"(%lb1_ub) : (index) -> () + + scf.for %iv1 = %lb1 to %ub1 step %c32 { + // CHECK: %[[c32:.*]] = arith.constant 32 : index + // CHECK: "test.some_use"(%[[c32]]) + %sz = affine.apply affine_map<(d0)[s0] -> (-d0 + s0)>(%iv1)[%ub1] + %sz_ub = "test.reify_bound"(%sz) {type = "UB"} : (index) -> (index) + "test.some_use"(%sz_ub) : (index) -> () + + scf.for %iv2 = %c0 to %ub2 step %c1 { + %slice1 = tensor.extract_slice %t1[0, %iv2] [1, 1] [1, 1] : tensor<1x?xi8> to tensor<1x1xi8> + %slice2 = tensor.extract_slice %t2[%iv2, 0] [1, %sz] [1, 1] : tensor to tensor<1x?xi8> + %slice3 = tensor.extract_slice %t3[0, 0] [1, %sz] [1, 1] : tensor<1x?xi32> to tensor<1x?xi32> + %matmul = linalg.matmul ins(%slice1, %slice2 : tensor<1x1xi8>, tensor<1x?xi8>) outs(%slice3 : tensor<1x?xi32>) -> tensor<1x?xi32> + + // CHECK: %[[c32:.*]] = arith.constant 32 : index + // CHECK: "test.some_use"(%[[c32]]) + %matmul_ub = "test.reify_bound"(%matmul) {dim = 1, type = "UB"} : (tensor<1x?xi32>) -> (index) + "test.some_use"(%matmul_ub) : (index) -> () + } + } + } + return +} diff --git a/mlir/test/Dialect/Linalg/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Linalg/value-bounds-op-interface-impl.mlir new file mode 100644 index 0000000..537bc98 --- /dev/null +++ b/mlir/test/Dialect/Linalg/value-bounds-op-interface-impl.mlir @@ -0,0 +1,13 @@ +// RUN: mlir-opt %s -test-affine-reify-value-bounds -verify-diagnostics \ +// RUN: -split-input-file | FileCheck %s + +// CHECK-LABEL: func @linalg_fill( +// CHECK-SAME: %[[t:.*]]: tensor +// CHECK: %[[c0:.*]] = arith.constant 0 : index +// CHECK: %[[dim:.*]] = tensor.dim %[[t]], %[[c0]] +// CHECK: return %[[dim]] +func.func @linalg_fill(%t: tensor, %f: f32) -> index { + %0 = linalg.fill ins(%f : f32) outs(%t : tensor) -> tensor + %1 = "test.reify_bound"(%0) {dim = 0} : (tensor) -> (index) + return %1 : index +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 8a8997f..5124602 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -8654,6 +8654,7 @@ cc_library( ":Support", ":TensorDialect", ":TilingInterface", + ":ValueBoundsOpInterface", ":ViewLikeInterface", "//llvm:Support", ], -- 2.7.4