[mlir] add permutation utility
authorAart Bik <ajcbik@google.com>
Tue, 24 Aug 2021 02:00:38 +0000 (19:00 -0700)
committerAart Bik <ajcbik@google.com>
Tue, 24 Aug 2021 15:07:40 +0000 (08:07 -0700)
I found myself typing this code several times at different places
by now, so time to make this a general utility instead. Given
a permutation, it returns the permuted position of the input,
for example (i,j,k) -> (k,i,j) yields position 1 for input 0.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D108347

mlir/include/mlir/IR/AffineMap.h
mlir/lib/IR/AffineMap.cpp

index f687253..906c53d 100644 (file)
@@ -162,6 +162,10 @@ public:
   /// when the caller knows it is safe to do so.
   unsigned getDimPosition(unsigned idx) const;
 
+  /// Extracts the permuted position where given input index resides.
+  /// Fails when called on a non-permutation.
+  unsigned getPermutedPosition(unsigned input) const;
+
   /// Return true if any affine expression involves AffineDimExpr `position`.
   bool isFunctionOfDim(unsigned position) const {
     return llvm::any_of(getResults(), [&](AffineExpr e) {
index 2257617..9c6f25d 100644 (file)
@@ -336,6 +336,14 @@ unsigned AffineMap::getDimPosition(unsigned idx) const {
   return getResult(idx).cast<AffineDimExpr>().getPosition();
 }
 
+unsigned AffineMap::getPermutedPosition(unsigned input) const {
+  assert(isPermutation() && "invalid permutation request");
+  for (unsigned i = 0, numResults = getNumResults(); i < numResults; i++)
+    if (getDimPosition(i) == input)
+      return i;
+  llvm_unreachable("incorrect permutation request");
+}
+
 /// Folds the results of the application of an affine map on the provided
 /// operands to a constant if possible. Returns false if the folding happens,
 /// true otherwise.