[MLIR][Tensor] Use the existing helper function `applyPermutationToVector` (NFC)
authorLorenzo Chelini <l.chelini@icloud.com>
Tue, 22 Nov 2022 09:19:34 +0000 (10:19 +0100)
committerLorenzo Chelini <l.chelini@icloud.com>
Tue, 22 Nov 2022 10:34:44 +0000 (11:34 +0100)
Avoid duplicate code by using an existing helper function to interchange
a vector based on a permutation. Address comments emerged after landing
D138119.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D138480

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

index 95dbd77..c5d7e42 100644 (file)
@@ -10,6 +10,7 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/BlockAndValueMapping.h"
@@ -3200,19 +3201,6 @@ LogicalResult PackOp::verify() {
   return success();
 }
 
-/// Returns a vector that interchanges `elements` starting at offset `offset`
-/// based on the indexes in `interchangeVector`.
-template <typename T>
-SmallVector<T> interchange(ArrayRef<T> elements,
-                           ArrayRef<int64_t> interchangeVector,
-                           int offset = 0) {
-  SmallVector<T> vec = llvm::to_vector(elements);
-  for (auto en : llvm::enumerate(interchangeVector))
-    vec[en.index() + offset] = elements[en.value() + offset];
-
-  return vec;
-}
-
 /// Get the expected packed type based on source type, tile factors, position of
 /// the inner tiles and permutation of the outer tiled loop.
 ShapedType PackOp::inferPackedType(ShapedType sourceType,
@@ -3231,7 +3219,8 @@ ShapedType PackOp::inferPackedType(ShapedType sourceType,
                                             innerTileSizes[tiledDim.index()]);
   }
 
-  resultShape = interchange<int64_t>(resultShape, outerDimsPerm);
+  if (!outerDimsPerm.empty())
+    applyPermutationToVector(resultShape, outerDimsPerm);
 
   // Append the inner tile dimensions.
   resultShape.append(innerTileSizes.begin(), innerTileSizes.end());