From b1d3afc93e0e6bdfbc1105b48bc4caed0c880be2 Mon Sep 17 00:00:00 2001 From: Hanhan Wang Date: Thu, 1 Dec 2022 17:04:09 -0800 Subject: [PATCH] [mlir] Factor more common utils to IndexingUtils Reviewed By: mravishankar Differential Revision: https://reviews.llvm.org/D139159 --- mlir/include/mlir/Dialect/Linalg/IR/Linalg.h | 4 --- mlir/include/mlir/Dialect/Utils/IndexingUtils.h | 6 ++++ mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 16 ++------- .../lib/Dialect/Linalg/Transforms/HoistPadding.cpp | 2 +- mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 3 +- .../Dialect/SCF/Transforms/TileUsingInterface.cpp | 41 ++++------------------ mlir/lib/Dialect/Utils/IndexingUtils.cpp | 19 ++++++++++ 7 files changed, 36 insertions(+), 55 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h index e231bdd..28c75fc 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h @@ -70,10 +70,6 @@ AffineMap extractOrIdentityMap(Optional maybeMap, unsigned rank, SmallVector concat(ArrayRef a, ArrayRef b); -/// Check if `permutation` is a permutation of the range -/// `[0, permutation.size())`. -bool isPermutation(ArrayRef permutation); - } // namespace linalg } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h index ee1d455..bc58c12 100644 --- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h +++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h @@ -75,6 +75,12 @@ void applyPermutationToVector(SmallVector &inVec, inVec = auxVec; } +/// Helper method to apply to inverse a permutation. +SmallVector invertPermutationVector(ArrayRef permutation); + +/// Method to check if an interchange vector is a permutation. +bool isPermutationVector(ArrayRef interchange); + /// Helper that returns a subset of `arrayAttr` as a vector of int64_t. SmallVector getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0, unsigned dropBack = 0); diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 1f6a8fd..a6f42c9 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -22,6 +22,7 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/AffineExprVisitor.h" @@ -1392,7 +1393,7 @@ void TransposeOp::print(OpAsmPrinter &p) { LogicalResult TransposeOp::verify() { ArrayRef permutationRef = getPermutation(); - if (!isPermutation(permutationRef)) + if (!isPermutationVector(permutationRef)) return emitOpError("permutation is not valid"); auto inputType = getInput().getType(); @@ -1683,19 +1684,6 @@ SmallVector mlir::linalg::concat(ArrayRef a, return llvm::to_vector<4>(concatRanges); } -bool mlir::linalg::isPermutation(ArrayRef permutation) { - // Count the number of appearances for all indices. - SmallVector indexCounts(permutation.size(), 0); - for (auto index : permutation) { - // Exit if the index is out-of-range. - if (index < 0 || index >= static_cast(permutation.size())) - return false; - ++indexCounts[index]; - } - // Return true if all indices appear once. - return count(indexCounts, 1) == static_cast(permutation.size()); -} - static void appendMangledType(llvm::raw_string_ostream &ss, Type t) { if (auto memref = t.dyn_cast()) { ss << "view"; diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp index bff2d54..a83305d0 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp @@ -151,7 +151,7 @@ computeTransposedType(RankedTensorType rankedTensorType, ArrayRef transposeVector) { if (transposeVector.empty()) return rankedTensorType; - if (!isPermutation(transposeVector) || + if (!isPermutationVector(transposeVector) || transposeVector.size() != static_cast(rankedTensorType.getRank())) return failure(); diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index 712d4c2..3267b2a 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -25,6 +25,7 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Utils/Utils.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineExprVisitor.h" @@ -409,7 +410,7 @@ GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor, auto resultTensorType = outputTensor.getType().cast(); Type elementType = resultTensorType.getElementType(); - assert(isPermutation(transposeVector) && + assert(isPermutationVector(transposeVector) && "expect transpose vector to be a permutation"); assert(transposeVector.size() == static_cast(resultTensorType.getRank()) && diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index aea0f07..092f853 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -18,6 +18,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/Utils/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/DestinationStyleOpInterface.h" @@ -59,34 +60,6 @@ fillInterchangeVector(ArrayRef interchangeVector, return filledVector; } -/// Helper method to apply permutation to a vector -template -static SmallVector applyPermutationToVector(const SmallVector &vector, - ArrayRef interchange) { - assert(interchange.size() == vector.size()); - return llvm::to_vector( - llvm::map_range(interchange, [&](int64_t val) { return vector[val]; })); -} -/// Helper method to apply to invert a permutation. -static SmallVector -invertPermutationVector(ArrayRef interchange) { - SmallVector inversion(interchange.size()); - for (const auto &pos : llvm::enumerate(interchange)) { - inversion[pos.value()] = pos.index(); - } - return inversion; -} -/// Method to check if an interchange vector is a permutation. -static bool isPermutation(ArrayRef interchange) { - llvm::SmallDenseSet seenVals; - for (auto val : interchange) { - if (seenVals.count(val)) - return false; - seenVals.insert(val); - } - return seenVals.size() == interchange.size(); -} - //===----------------------------------------------------------------------===// // tileUsingSCFForOp implementation. //===----------------------------------------------------------------------===// @@ -321,16 +294,14 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op, iterationDomain.size()); } if (!interchangeVector.empty()) { - if (!isPermutation(interchangeVector)) { + if (!isPermutationVector(interchangeVector)) { return rewriter.notifyMatchFailure( op, "invalid intechange vector, not a permutation of the entire " "iteration space"); } - iterationDomain = - applyPermutationToVector(iterationDomain, interchangeVector); - tileSizeVector = - applyPermutationToVector(tileSizeVector, interchangeVector); + applyPermutationToVector(iterationDomain, interchangeVector); + applyPermutationToVector(tileSizeVector, interchangeVector); } // 3. Materialize an empty loop nest that iterates over the tiles. These @@ -341,8 +312,8 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op, if (!interchangeVector.empty()) { auto inversePermutation = invertPermutationVector(interchangeVector); - offsets = applyPermutationToVector(offsets, inversePermutation); - sizes = applyPermutationToVector(sizes, inversePermutation); + applyPermutationToVector(offsets, inversePermutation); + applyPermutationToVector(sizes, inversePermutation); } } diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp index 1c7a89c..ed92ade 100644 --- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp +++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp @@ -86,6 +86,25 @@ int64_t mlir::computeMaxLinearIndex(ArrayRef basis) { std::multiplies()); } +llvm::SmallVector +mlir::invertPermutationVector(ArrayRef permutation) { + SmallVector inversion(permutation.size()); + for (const auto &pos : llvm::enumerate(permutation)) { + inversion[pos.value()] = pos.index(); + } + return inversion; +} + +bool mlir::isPermutationVector(ArrayRef interchange) { + llvm::SmallDenseSet seenVals; + for (auto val : interchange) { + if (seenVals.count(val)) + return false; + seenVals.insert(val); + } + return seenVals.size() == interchange.size(); +} + llvm::SmallVector mlir::getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront, unsigned dropBack) { -- 2.7.4