[mlir][arith] Add ValueBoundsOpInterface impls
authorMatthias Springer <springerm@google.com>
Thu, 6 Apr 2023 01:46:48 +0000 (10:46 +0900)
committerMatthias Springer <springerm@google.com>
Thu, 6 Apr 2023 01:53:08 +0000 (10:53 +0900)
These ops are useful for unit testing. (They do not fold/canonicalize with affine.apply etc.)

Differential Revision: https://reviews.llvm.org/D145696

mlir/include/mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h [new file with mode: 0644]
mlir/include/mlir/InitAllDialects.h
mlir/lib/Dialect/Arith/IR/CMakeLists.txt
mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp [new file with mode: 0644]
mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir [new file with mode: 0644]
mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

diff --git a/mlir/include/mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h
new file mode 100644 (file)
index 0000000..d922786
--- /dev/null
@@ -0,0 +1,20 @@
+//===- ValueBoundsOpInterfaceImpl.h - Impl. of ValueBoundsOpInterface -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARITH_IR_VALUEBOUNDSOPINTERFACEIMPL_H
+#define MLIR_DIALECT_ARITH_IR_VALUEBOUNDSOPINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace arith {
+void registerValueBoundsOpInterfaceExternalModels(DialectRegistry &registry);
+} // namespace arith
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ARITH_IR_VALUEBOUNDSOPINTERFACEIMPL_H
index f947655..956b1e4 100644 (file)
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
 #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
@@ -134,6 +135,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
   // Register all external models.
   affine::registerValueBoundsOpInterfaceExternalModels(registry);
   arith::registerBufferizableOpInterfaceExternalModels(registry);
+  arith::registerValueBoundsOpInterfaceExternalModels(registry);
   bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
       registry);
   linalg::registerBufferizableOpInterfaceExternalModels(registry);
index ffbe801..2b753ca 100644 (file)
@@ -1,3 +1,10 @@
+set(LLVM_OPTIONAL_SOURCES
+  ArithOps.cpp
+  ArithDialect.cpp
+  InferIntRangeInterfaceImpls.cpp
+  ValueBoundsOpInterfaceImpl.cpp
+  )
+
 set(LLVM_TARGET_DEFINITIONS ArithCanonicalization.td)
 mlir_tablegen(ArithCanonicalization.inc -gen-rewriters)
 add_public_tablegen_target(MLIRArithCanonicalizationIncGen)
@@ -6,6 +13,7 @@ add_mlir_dialect_library(MLIRArithDialect
   ArithOps.cpp
   ArithDialect.cpp
   InferIntRangeInterfaceImpls.cpp
+  ValueBoundsOpInterfaceImpl.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Arith
@@ -20,4 +28,17 @@ add_mlir_dialect_library(MLIRArithDialect
   MLIRInferIntRangeInterface
   MLIRInferTypeOpInterface
   MLIRIR
+  MLIRValueBoundsOpInterface
+  )
+
+add_mlir_dialect_library(MLIRArithValueBoundsOpInterfaceImpl
+  ValueBoundsOpInterfaceImpl.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Arith
+
+  LINK_LIBS PUBLIC
+  MLIRArithDialect
+  MLIRIR
+  MLIRValueBoundsOpInterface
   )
diff --git a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
new file mode 100644 (file)
index 0000000..d24f187
--- /dev/null
@@ -0,0 +1,81 @@
+//===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
+
+using namespace mlir;
+
+namespace mlir {
+namespace arith {
+namespace {
+
+struct AddIOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<AddIOpInterface, AddIOp> {
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    auto addIOp = cast<AddIOp>(op);
+    assert(value == addIOp.getResult() && "invalid value");
+
+    cstr.bound(value) ==
+        cstr.getExpr(addIOp.getLhs()) + cstr.getExpr(addIOp.getRhs());
+  }
+};
+
+struct ConstantOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<ConstantOpInterface,
+                                                   ConstantOp> {
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    auto constantOp = cast<ConstantOp>(op);
+    assert(value == constantOp.getResult() && "invalid value");
+
+    if (auto attr = constantOp.getValue().dyn_cast<IntegerAttr>())
+      cstr.bound(value) == attr.getInt();
+  }
+};
+
+struct SubIOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<SubIOpInterface, SubIOp> {
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    auto subIOp = cast<SubIOp>(op);
+    assert(value == subIOp.getResult() && "invalid value");
+
+    cstr.bound(value) ==
+        cstr.getExpr(subIOp.getLhs()) - cstr.getExpr(subIOp.getRhs());
+  }
+};
+
+struct MulIOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<MulIOpInterface, MulIOp> {
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    auto mulIOp = cast<MulIOp>(op);
+    assert(value == mulIOp.getResult() && "invalid value");
+
+    cstr.bound(value) ==
+        cstr.getExpr(mulIOp.getLhs()) * cstr.getExpr(mulIOp.getRhs());
+  }
+};
+
+} // namespace
+} // namespace arith
+} // namespace mlir
+
+void mlir::arith::registerValueBoundsOpInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) {
+    arith::AddIOp::attachInterface<arith::AddIOpInterface>(*ctx);
+    arith::ConstantOp::attachInterface<arith::ConstantOpInterface>(*ctx);
+    arith::SubIOp::attachInterface<arith::SubIOpInterface>(*ctx);
+    arith::MulIOp::attachInterface<arith::MulIOpInterface>(*ctx);
+  });
+}
diff --git a/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir
new file mode 100644 (file)
index 0000000..ea44966
--- /dev/null
@@ -0,0 +1,54 @@
+// RUN: mlir-opt %s -test-affine-reify-value-bounds -verify-diagnostics \
+// RUN:     -split-input-file | FileCheck %s
+
+// CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 + 5)>
+// CHECK-LABEL: func @arith_addi(
+//  CHECK-SAME:     %[[a:.*]]: index
+//       CHECK:   %[[apply:.*]] = affine.apply #[[$map]]()[%[[a]]]
+//       CHECK:   return %[[apply]]
+func.func @arith_addi(%a: index) -> index {
+  %0 = arith.constant 5 : index
+  %1 = arith.addi %0, %a : index
+  %2 = "test.reify_bound"(%1) : (index) -> (index)
+  return %2 : index
+}
+
+// -----
+
+// CHECK: #[[$map:.*]] = affine_map<()[s0] -> (-s0 + 5)>
+// CHECK-LABEL: func @arith_subi(
+//  CHECK-SAME:     %[[a:.*]]: index
+//       CHECK:   %[[apply:.*]] = affine.apply #[[$map]]()[%[[a]]]
+//       CHECK:   return %[[apply]]
+func.func @arith_subi(%a: index) -> index {
+  %0 = arith.constant 5 : index
+  %1 = arith.subi %0, %a : index
+  %2 = "test.reify_bound"(%1) : (index) -> (index)
+  return %2 : index
+}
+
+// -----
+
+// CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 * 5)>
+// CHECK-LABEL: func @arith_muli(
+//  CHECK-SAME:     %[[a:.*]]: index
+//       CHECK:   %[[apply:.*]] = affine.apply #[[$map]]()[%[[a]]]
+//       CHECK:   return %[[apply]]
+func.func @arith_muli(%a: index) -> index {
+  %0 = arith.constant 5 : index
+  %1 = arith.muli %0, %a : index
+  %2 = "test.reify_bound"(%1) : (index) -> (index)
+  return %2 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @arith_const()
+//       CHECK:   %[[c5:.*]] = arith.constant 5 : index
+//       CHECK:   %[[c5:.*]] = arith.constant 5 : index
+//       CHECK:   return %[[c5]]
+func.func @arith_const() -> index {
+  %c5 = arith.constant 5 : index
+  %0 = "test.reify_bound"(%c5) : (index) -> (index)
+  return %0 : index
+}
index eccf551..5828315 100644 (file)
@@ -31,6 +31,10 @@ struct TestReifyValueBounds
   TestReifyValueBounds() = default;
   TestReifyValueBounds(const TestReifyValueBounds &pass) : PassWrapper(pass){};
 
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<AffineDialect>();
+  }
+
   void runOnOperation() override;
 
 private:
index 19e98a9..c2feffd 100644 (file)
@@ -7196,6 +7196,7 @@ cc_library(
         ":ArithToLLVM",
         ":ArithToSPIRV",
         ":ArithTransforms",
+        ":ArithValueBoundsOpInterfaceImpl",
         ":ArmNeonDialect",
         ":ArmSVEDialect",
         ":ArmSVETransforms",
@@ -8863,6 +8864,18 @@ cc_library(
 )
 
 cc_library(
+    name = "ArithValueBoundsOpInterfaceImpl",
+    srcs = ["lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp"],
+    hdrs = ["include/mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h"],
+    includes = ["include"],
+    deps = [
+        ":ArithDialect",
+        ":IR",
+        ":ValueBoundsOpInterface",
+    ],
+)
+
+cc_library(
     name = "TilingInterface",
     srcs = ["lib/Interfaces/TilingInterface.cpp"],
     hdrs = ["include/mlir/Interfaces/TilingInterface.h"],
@@ -9960,12 +9973,11 @@ gentbl_cc_library(
 
 cc_library(
     name = "ArithDialect",
-    srcs = glob(
-        [
-            "lib/Dialect/Arith/IR/*.cpp",
-            "lib/Dialect/Arith/IR/*.h",
-        ],
-    ),
+    srcs = [
+        "lib/Dialect/Arith/IR/ArithDialect.cpp",
+        "lib/Dialect/Arith/IR/ArithOps.cpp",
+        "lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp",
+    ],
     hdrs = [
         "include/mlir/Dialect/Arith/IR/Arith.h",
         "include/mlir/Transforms/InliningUtils.h",