From f3ece29b4658d60a1e7656bb9e67853376d094b7 Mon Sep 17 00:00:00 2001 From: Laszlo Kindrat Date: Mon, 15 May 2023 15:04:19 -0400 Subject: [PATCH] [ADT] Allow specifying the size of resulting `SmallVector` in `map_to_vector` This patch adds an overload for the `map_to_vector` helper template, exposing a parameter to control the size of the resulting `SmallVector`. A few call sites in mlir are updated to illustrate and test the change. Differential Revision: https://reviews.llvm.org/D150601 --- llvm/include/llvm/ADT/SmallVectorExtras.h | 5 ++++ mlir/include/mlir/IR/AffineMap.h | 17 ++++++------- mlir/lib/IR/Builders.cpp | 40 +++++++++++++++---------------- mlir/lib/IR/TypeUtilities.cpp | 13 +++++----- 4 files changed, 40 insertions(+), 35 deletions(-) diff --git a/llvm/include/llvm/ADT/SmallVectorExtras.h b/llvm/include/llvm/ADT/SmallVectorExtras.h index 8d52280..d5159aa 100644 --- a/llvm/include/llvm/ADT/SmallVectorExtras.h +++ b/llvm/include/llvm/ADT/SmallVectorExtras.h @@ -20,6 +20,11 @@ namespace llvm { /// Map a range to a SmallVector with element types deduced from the mapping. +template +auto map_to_vector(ContainerTy &&C, FuncTy &&F) { + return to_vector( + map_range(std::forward(C), std::forward(F))); +} template auto map_to_vector(ContainerTy &&C, FuncTy &&F) { return to_vector( diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h index e21dc9c..01cd718 100644 --- a/mlir/include/mlir/IR/AffineMap.h +++ b/mlir/include/mlir/IR/AffineMap.h @@ -19,6 +19,7 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMapInfo.h" #include "llvm/ADT/SmallBitVector.h" +#include "llvm/ADT/SmallVectorExtras.h" #include namespace llvm { @@ -226,11 +227,11 @@ public: AffineMap shiftDims(unsigned shift, unsigned offset = 0) const { assert(offset <= getNumDims()); return AffineMap::get(getNumDims() + shift, getNumSymbols(), - llvm::to_vector<4>(llvm::map_range( + llvm::map_to_vector<4>( getResults(), [&](AffineExpr e) { return e.shiftDims(getNumDims(), shift, offset); - })), + }), getContext()); } @@ -238,12 +239,12 @@ public: /// by symbols[offset + shift ... shift + numSymbols). AffineMap shiftSymbols(unsigned shift, unsigned offset = 0) const { return AffineMap::get(getNumDims(), getNumSymbols() + shift, - llvm::to_vector<4>(llvm::map_range( - getResults(), - [&](AffineExpr e) { - return e.shiftSymbols(getNumSymbols(), shift, - offset); - })), + llvm::map_to_vector<4>(getResults(), + [&](AffineExpr e) { + return e.shiftSymbols( + getNumSymbols(), shift, + offset); + }), getContext()); } diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 6cbba06..c4fad9c 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -15,6 +15,7 @@ #include "mlir/IR/IntegerSet.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/SymbolTable.h" +#include "llvm/ADT/SmallVectorExtras.h" #include "llvm/Support/raw_ostream.h" using namespace mlir; @@ -261,57 +262,56 @@ ArrayAttr Builder::getArrayAttr(ArrayRef value) { } ArrayAttr Builder::getBoolArrayAttr(ArrayRef values) { - auto attrs = llvm::to_vector<8>(llvm::map_range( - values, [this](bool v) -> Attribute { return getBoolAttr(v); })); + auto attrs = llvm::map_to_vector<8>( + values, [this](bool v) -> Attribute { return getBoolAttr(v); }); return getArrayAttr(attrs); } ArrayAttr Builder::getI32ArrayAttr(ArrayRef values) { - auto attrs = llvm::to_vector<8>(llvm::map_range( - values, [this](int32_t v) -> Attribute { return getI32IntegerAttr(v); })); + auto attrs = llvm::map_to_vector<8>( + values, [this](int32_t v) -> Attribute { return getI32IntegerAttr(v); }); return getArrayAttr(attrs); } ArrayAttr Builder::getI64ArrayAttr(ArrayRef values) { - auto attrs = llvm::to_vector<8>(llvm::map_range( - values, [this](int64_t v) -> Attribute { return getI64IntegerAttr(v); })); + auto attrs = llvm::map_to_vector<8>( + values, [this](int64_t v) -> Attribute { return getI64IntegerAttr(v); }); return getArrayAttr(attrs); } ArrayAttr Builder::getIndexArrayAttr(ArrayRef values) { - auto attrs = llvm::to_vector<8>( - llvm::map_range(values, [this](int64_t v) -> Attribute { - return getIntegerAttr(IndexType::get(getContext()), v); - })); + auto attrs = llvm::map_to_vector<8>(values, [this](int64_t v) -> Attribute { + return getIntegerAttr(IndexType::get(getContext()), v); + }); return getArrayAttr(attrs); } ArrayAttr Builder::getF32ArrayAttr(ArrayRef values) { - auto attrs = llvm::to_vector<8>(llvm::map_range( - values, [this](float v) -> Attribute { return getF32FloatAttr(v); })); + auto attrs = llvm::map_to_vector<8>( + values, [this](float v) -> Attribute { return getF32FloatAttr(v); }); return getArrayAttr(attrs); } ArrayAttr Builder::getF64ArrayAttr(ArrayRef values) { - auto attrs = llvm::to_vector<8>(llvm::map_range( - values, [this](double v) -> Attribute { return getF64FloatAttr(v); })); + auto attrs = llvm::map_to_vector<8>( + values, [this](double v) -> Attribute { return getF64FloatAttr(v); }); return getArrayAttr(attrs); } ArrayAttr Builder::getStrArrayAttr(ArrayRef values) { - auto attrs = llvm::to_vector<8>(llvm::map_range( - values, [this](StringRef v) -> Attribute { return getStringAttr(v); })); + auto attrs = llvm::map_to_vector<8>( + values, [this](StringRef v) -> Attribute { return getStringAttr(v); }); return getArrayAttr(attrs); } ArrayAttr Builder::getTypeArrayAttr(TypeRange values) { - auto attrs = llvm::to_vector<8>(llvm::map_range( - values, [](Type v) -> Attribute { return TypeAttr::get(v); })); + auto attrs = llvm::map_to_vector<8>( + values, [](Type v) -> Attribute { return TypeAttr::get(v); }); return getArrayAttr(attrs); } ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef values) { - auto attrs = llvm::to_vector<8>(llvm::map_range( - values, [](AffineMap v) -> Attribute { return AffineMapAttr::get(v); })); + auto attrs = llvm::map_to_vector<8>( + values, [](AffineMap v) -> Attribute { return AffineMapAttr::get(v); }); return getArrayAttr(attrs); } diff --git a/mlir/lib/IR/TypeUtilities.cpp b/mlir/lib/IR/TypeUtilities.cpp index 7aa37cb..6926beb 100644 --- a/mlir/lib/IR/TypeUtilities.cpp +++ b/mlir/lib/IR/TypeUtilities.cpp @@ -11,13 +11,12 @@ //===----------------------------------------------------------------------===// #include "mlir/IR/TypeUtilities.h" - -#include - #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" +#include "llvm/ADT/SmallVectorExtras.h" +#include using namespace mlir; @@ -119,8 +118,8 @@ LogicalResult mlir::verifyCompatibleDims(ArrayRef dims) { /// have compatible dimensions. Dimensions are compatible if all non-dynamic /// dims are equal. The element type does not matter. LogicalResult mlir::verifyCompatibleShapes(TypeRange types) { - auto shapedTypes = llvm::to_vector<8>(llvm::map_range( - types, [](auto type) { return llvm::dyn_cast(type); })); + auto shapedTypes = llvm::map_to_vector<8>( + types, [](auto type) { return llvm::dyn_cast(type); }); // Return failure if some, but not all are not shaped. Return early if none // are shaped also. if (llvm::none_of(shapedTypes, [](auto t) { return t; })) @@ -155,10 +154,10 @@ LogicalResult mlir::verifyCompatibleShapes(TypeRange types) { for (unsigned i = 0; i < firstRank; ++i) { // Retrieve all ranked dimensions - auto dims = llvm::to_vector<8>(llvm::map_range( + auto dims = llvm::map_to_vector<8>( llvm::make_filter_range( shapes, [&](auto shape) { return shape.getRank() >= i; }), - [&](auto shape) { return shape.getDimSize(i); })); + [&](auto shape) { return shape.getDimSize(i); }); if (verifyCompatibleDims(dims).failed()) return failure(); } -- 2.7.4