From 27224fe7272a791bcc9f28c997ce322f7d3856cd Mon Sep 17 00:00:00 2001 From: Lorenzo Chelini Date: Mon, 26 Sep 2022 10:39:46 +0200 Subject: [PATCH] [MLIR] Expose `getAsValues` in `StaticValueUtils.h` (NFC) The utility function should live in `StaticValueUtils.h` as it provides a convenient way to convert a vector of OpFoldResults into a vector of Values. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D134451 --- mlir/include/mlir/Dialect/Utils/StaticValueUtils.h | 7 +++++++ .../Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp | 12 +----------- mlir/lib/Dialect/Utils/CMakeLists.txt | 1 + mlir/lib/Dialect/Utils/StaticValueUtils.cpp | 13 +++++++++++++ 4 files changed, 22 insertions(+), 11 deletions(-) diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h index f290b1e..f09cf88 100644 --- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h @@ -80,6 +80,13 @@ bool isConstantIntValue(OpFoldResult ofr, int64_t value); /// no IndexAttr and that IndexType have no bitwidth. bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2); +/// Helper function to convert a vector of `OpFoldResult`s into a vector of +/// `Value`s. For each `OpFoldResult` in `valueOrAttrVec` return the fold result +/// if it casts to a `Value` or create an index-type constant if it casts to +/// `IntegerAttr`. No other attribute types are supported. +SmallVector getAsValues(OpBuilder &b, Location loc, + ArrayRef valueOrAttrVec); + } // namespace mlir #endif // MLIR_DIALECT_UTILS_STATICVALUEUTILS_H diff --git a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp index 3a6be9f..df65eee 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp @@ -8,8 +8,8 @@ #include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Arithmetic/Utils/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Interfaces/InferTypeOpInterface.h" using namespace mlir; @@ -134,16 +134,6 @@ getReshapeOutputShapeFromInputShape(OpBuilder &builder, Location loc, Value src, builder, loc, src, dstStaticShape, reassocation); } -/// Helper function to convert a vector of `OpFoldResult`s into a vector of -/// `Value`s. -static SmallVector getAsValues(OpBuilder &b, Location loc, - ArrayRef valueOrAttrVec) { - return llvm::to_vector<4>( - llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value { - return getValueOrCreateConstantIndexOp(b, loc, value); - })); -} - template struct ReifyExpandOrCollapseShapeOp : public ReifyRankedShapedTypeOpInterface::ExternalModel< diff --git a/mlir/lib/Dialect/Utils/CMakeLists.txt b/mlir/lib/Dialect/Utils/CMakeLists.txt index f329afa..b93a30d 100644 --- a/mlir/lib/Dialect/Utils/CMakeLists.txt +++ b/mlir/lib/Dialect/Utils/CMakeLists.txt @@ -5,5 +5,6 @@ add_mlir_library(MLIRDialectUtils StaticValueUtils.cpp LINK_LIBS PUBLIC + MLIRArithmeticUtils MLIRIR ) diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp index 6212df9..2392b1d 100644 --- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp +++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/Dialect/Arithmetic/Utils/Utils.h" #include "mlir/IR/Matchers.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/APSInt.h" @@ -124,4 +125,16 @@ bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2) { auto v1 = ofr1.dyn_cast(), v2 = ofr2.dyn_cast(); return v1 && v1 == v2; } + +/// Helper function to convert a vector of `OpFoldResult`s into a vector of +/// `Value`s. For each `OpFoldResult` in `valueOrAttrVec` return the fold result +/// if it casts to a `Value` or create an index-type constant if it casts to +/// `IntegerAttr`. No other attribute types are supported. +SmallVector getAsValues(OpBuilder &b, Location loc, + ArrayRef valueOrAttrVec) { + return llvm::to_vector<4>( + llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value { + return getValueOrCreateConstantIndexOp(b, loc, value); + })); +} } // namespace mlir -- 2.7.4