[mlir][memref] Add ValueBoundsOpInterface impls
authorMatthias Springer <springerm@google.com>
Thu, 6 Apr 2023 01:35:12 +0000 (10:35 +0900)
committerMatthias Springer <springerm@google.com>
Thu, 6 Apr 2023 01:35:52 +0000 (10:35 +0900)
Differential Revision: https://reviews.llvm.org/D145695

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/include/mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h [new file with mode: 0644]
mlir/include/mlir/InitAllDialects.h
mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp [new file with mode: 0644]
mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir [new file with mode: 0644]
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

index b37054b..82f5ed9 100644 (file)
@@ -100,6 +100,20 @@ class AllocLikeOp<string mnemonic,
     static StringRef getAlignmentAttrStrName() { return "alignment"; }
 
     MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
+
+    SmallVector<OpFoldResult> getMixedSizes() {
+      SmallVector<OpFoldResult> result;
+      unsigned ctr = 0;
+      OpBuilder b(getContext());
+      for (int64_t i = 0, e = getType().getRank(); i < e; ++i) {
+        if (getType().isDynamicDim(i)) {
+          result.push_back(getDynamicSizes()[ctr++]);
+        } else {
+          result.push_back(b.getIndexAttr(getType().getShape()[i]));
+        }
+      }
+      return result;
+    }
   }];
 
   let assemblyFormat = [{
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h b/mlir/include/mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h
new file mode 100644 (file)
index 0000000..eec43b7
--- /dev/null
@@ -0,0 +1,20 @@
+//===- ValueBoundsOpInterfaceImpl.h - Impl. of ValueBoundsOpInterface -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MEMREF_IR_VALUEBOUNDSOPINTERFACEIMPL_H
+#define MLIR_DIALECT_MEMREF_IR_VALUEBOUNDSOPINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace memref {
+void registerValueBoundsOpInterfaceExternalModels(DialectRegistry &registry);
+} // namespace memref
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MEMREF_IR_VALUEBOUNDSOPINTERFACEIMPL_H
index 1a4e0f9..f947655 100644 (file)
@@ -46,6 +46,7 @@
 #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
 #include "mlir/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
@@ -139,6 +140,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
   linalg::registerTilingInterfaceExternalModels(registry);
   memref::registerBufferizableOpInterfaceExternalModels(registry);
   memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
+  memref::registerValueBoundsOpInterfaceExternalModels(registry);
   scf::registerBufferizableOpInterfaceExternalModels(registry);
   shape::registerBufferizableOpInterfaceExternalModels(registry);
   sparse_tensor::registerBufferizableOpInterfaceExternalModels(registry);
index f922838..3aedd37 100644 (file)
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRMemRefDialect
   MemRefDialect.cpp
   MemRefOps.cpp
+  ValueBoundsOpInterfaceImpl.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${PROJECT_SOURCE_DIR}/inlude/mlir/Dialect/MemRefDialect
@@ -21,5 +22,6 @@ add_mlir_dialect_library(MLIRMemRefDialect
   MLIRIR
   MLIRShapedOpInterfaces
   MLIRSideEffectInterfaces
+  MLIRValueBoundsOpInterface
   MLIRViewLikeInterface
 )
diff --git a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
new file mode 100644 (file)
index 0000000..ca63fb3
--- /dev/null
@@ -0,0 +1,129 @@
+//===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h"
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
+
+using namespace mlir;
+
+namespace mlir {
+namespace memref {
+namespace {
+
+template <typename OpTy>
+struct AllocOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<AllocOpInterface<OpTy>,
+                                                   OpTy> {
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    auto allocOp = cast<OpTy>(op);
+    assert(value == allocOp.getResult() && "invalid value");
+
+    cstr.bound(value)[dim] == allocOp.getMixedSizes()[dim];
+  }
+};
+
+struct CastOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<CastOpInterface, CastOp> {
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    auto castOp = cast<CastOp>(op);
+    assert(value == castOp.getResult() && "invalid value");
+
+    if (castOp.getResult().getType().isa<MemRefType>() &&
+        castOp.getSource().getType().isa<MemRefType>()) {
+      cstr.bound(value)[dim] == cstr.getExpr(castOp.getSource(), dim);
+    }
+  }
+};
+
+struct DimOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<DimOpInterface, DimOp> {
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    auto dimOp = cast<DimOp>(op);
+    assert(value == dimOp.getResult() && "invalid value");
+
+    auto constIndex = dimOp.getConstantIndex();
+    if (!constIndex.has_value())
+      return;
+    cstr.bound(value) == cstr.getExpr(dimOp.getSource(), *constIndex);
+  }
+};
+
+struct GetGlobalOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<GetGlobalOpInterface,
+                                                   GetGlobalOp> {
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    auto getGlobalOp = cast<GetGlobalOp>(op);
+    assert(value == getGlobalOp.getResult() && "invalid value");
+
+    auto type = getGlobalOp.getType();
+    assert(!type.isDynamicDim(dim) && "expected static dim");
+    cstr.bound(value)[dim] == type.getDimSize(dim);
+  }
+};
+
+struct RankOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<RankOpInterface, RankOp> {
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    auto rankOp = cast<RankOp>(op);
+    assert(value == rankOp.getResult() && "invalid value");
+
+    auto memrefType = rankOp.getMemref().getType().dyn_cast<MemRefType>();
+    if (!memrefType)
+      return;
+    cstr.bound(value) == memrefType.getRank();
+  }
+};
+
+struct SubViewOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<SubViewOpInterface,
+                                                   SubViewOp> {
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    auto subViewOp = cast<SubViewOp>(op);
+    assert(value == subViewOp.getResult() && "invalid value");
+
+    llvm::SmallBitVector dropped = subViewOp.getDroppedDims();
+    int64_t ctr = -1;
+    for (int64_t i = 0, e = subViewOp.getMixedSizes().size(); i < e; ++i) {
+      // Skip over rank-reduced dimensions.
+      if (!dropped.test(i))
+        ++ctr;
+      if (ctr == dim) {
+        cstr.bound(value)[dim] == subViewOp.getMixedSizes()[i];
+        return;
+      }
+    }
+    llvm_unreachable("could not find non-rank-reduced dim");
+  }
+};
+
+} // namespace
+} // namespace memref
+} // namespace mlir
+
+void mlir::memref::registerValueBoundsOpInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
+    memref::AllocOp::attachInterface<memref::AllocOpInterface<memref::AllocOp>>(
+        *ctx);
+    memref::AllocaOp::attachInterface<
+        memref::AllocOpInterface<memref::AllocaOp>>(*ctx);
+    memref::CastOp::attachInterface<memref::CastOpInterface>(*ctx);
+    memref::DimOp::attachInterface<memref::DimOpInterface>(*ctx);
+    memref::GetGlobalOp::attachInterface<memref::GetGlobalOpInterface>(*ctx);
+    memref::RankOp::attachInterface<memref::RankOpInterface>(*ctx);
+    memref::SubViewOp::attachInterface<memref::SubViewOpInterface>(*ctx);
+  });
+}
diff --git a/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir
new file mode 100644 (file)
index 0000000..0e0f216
--- /dev/null
@@ -0,0 +1,86 @@
+// RUN: mlir-opt %s -test-affine-reify-value-bounds -verify-diagnostics \
+// RUN:     -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func @memref_alloc(
+//  CHECK-SAME:     %[[sz:.*]]: index
+//       CHECK:   %[[c6:.*]] = arith.constant 6 : index
+//       CHECK:   return %[[c6]], %[[sz]]
+func.func @memref_alloc(%sz: index) -> (index, index) {
+  %0 = memref.alloc(%sz) : memref<6x?xf32>
+  %1 = "test.reify_bound"(%0) {dim = 0} : (memref<6x?xf32>) -> (index)
+  %2 = "test.reify_bound"(%0) {dim = 1} : (memref<6x?xf32>) -> (index)
+  return %1, %2 : index, index
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_alloca(
+//  CHECK-SAME:     %[[sz:.*]]: index
+//       CHECK:   %[[c6:.*]] = arith.constant 6 : index
+//       CHECK:   return %[[c6]], %[[sz]]
+func.func @memref_alloca(%sz: index) -> (index, index) {
+  %0 = memref.alloca(%sz) : memref<6x?xf32>
+  %1 = "test.reify_bound"(%0) {dim = 0} : (memref<6x?xf32>) -> (index)
+  %2 = "test.reify_bound"(%0) {dim = 1} : (memref<6x?xf32>) -> (index)
+  return %1, %2 : index, index
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_cast(
+//       CHECK:   %[[c10:.*]] = arith.constant 10 : index
+//       CHECK:   return %[[c10]]
+func.func @memref_cast(%m: memref<10xf32>) -> index {
+  %0 = memref.cast %m : memref<10xf32> to memref<?xf32>
+  %1 = "test.reify_bound"(%0) {dim = 0} : (memref<?xf32>) -> (index)
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_dim(
+//  CHECK-SAME:     %[[m:.*]]: memref<?xf32>
+//       CHECK:   %[[dim:.*]] = memref.dim %[[m]]
+//       CHECK:   %[[dim:.*]] = memref.dim %[[m]]
+//       CHECK:   return %[[dim]]
+func.func @memref_dim(%m: memref<?xf32>) -> index {
+  %c0 = arith.constant 0 : index
+  %0 = memref.dim %m, %c0 : memref<?xf32>
+  %1 = "test.reify_bound"(%0) : (index) -> (index)
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_get_global(
+//       CHECK:   %[[c4:.*]] = arith.constant 4 : index
+//       CHECK:   return %[[c4]]
+memref.global "private" @gv0 : memref<4xf32> = dense<[0.0, 1.0, 2.0, 3.0]>
+func.func @memref_get_global() -> index {
+  %0 = memref.get_global @gv0 : memref<4xf32>
+  %1 = "test.reify_bound"(%0) {dim = 0} : (memref<4xf32>) -> (index)
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_rank(
+//  CHECK-SAME:     %[[t:.*]]: memref<5xf32>
+//       CHECK:   %[[c1:.*]] = arith.constant 1 : index
+//       CHECK:   return %[[c1]]
+func.func @memref_rank(%m: memref<5xf32>) -> index {
+  %0 = memref.rank %m : memref<5xf32>
+  %1 = "test.reify_bound"(%0) : (index) -> (index)
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_subview(
+//  CHECK-SAME:     %[[m:.*]]: memref<?xf32>, %[[sz:.*]]: index
+//       CHECK:   return %[[sz]]
+func.func @memref_subview(%m: memref<?xf32>, %sz: index) -> index {
+  %0 = memref.subview %m[2][%sz][1] : memref<?xf32> to memref<?xf32, strided<[1], offset: 2>>
+  %1 = "test.reify_bound"(%0) {dim = 0} : (memref<?xf32, strided<[1], offset: 2>>) -> (index)
+  return %1 : index
+}
index 2ee2ecd..19e98a9 100644 (file)
@@ -10257,6 +10257,7 @@ cc_library(
     ),
     hdrs = [
         "include/mlir/Dialect/MemRef/IR/MemRef.h",
+        "include/mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h",
         "include/mlir/Dialect/MemRef/Utils/MemRefUtils.h",
     ],
     includes = ["include"],
@@ -10271,6 +10272,7 @@ cc_library(
         ":MemRefBaseIncGen",
         ":MemRefOpsIncGen",
         ":ShapedOpInterfaces",
+        ":ValueBoundsOpInterface",
         ":ViewLikeInterface",
         "//llvm:Support",
         "//llvm:TargetParser",