SmallVector<AffineExpr, 4> concat(ArrayRef<AffineExpr> a,
ArrayRef<AffineExpr> b);
-/// Check if `permutation` is a permutation of the range
-/// `[0, permutation.size())`.
-bool isPermutation(ArrayRef<int64_t> permutation);
-
} // namespace linalg
} // namespace mlir
inVec = auxVec;
}
+/// Helper method to apply to inverse a permutation.
+SmallVector<int64_t> invertPermutationVector(ArrayRef<int64_t> permutation);
+
+/// Method to check if an interchange vector is a permutation.
+bool isPermutationVector(ArrayRef<int64_t> interchange);
+
/// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
SmallVector<int64_t> getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0,
unsigned dropBack = 0);
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExprVisitor.h"
LogicalResult TransposeOp::verify() {
ArrayRef<int64_t> permutationRef = getPermutation();
- if (!isPermutation(permutationRef))
+ if (!isPermutationVector(permutationRef))
return emitOpError("permutation is not valid");
auto inputType = getInput().getType();
return llvm::to_vector<4>(concatRanges);
}
-bool mlir::linalg::isPermutation(ArrayRef<int64_t> permutation) {
- // Count the number of appearances for all indices.
- SmallVector<int64_t> indexCounts(permutation.size(), 0);
- for (auto index : permutation) {
- // Exit if the index is out-of-range.
- if (index < 0 || index >= static_cast<int64_t>(permutation.size()))
- return false;
- ++indexCounts[index];
- }
- // Return true if all indices appear once.
- return count(indexCounts, 1) == static_cast<int64_t>(permutation.size());
-}
-
static void appendMangledType(llvm::raw_string_ostream &ss, Type t) {
if (auto memref = t.dyn_cast<MemRefType>()) {
ss << "view";
ArrayRef<int64_t> transposeVector) {
if (transposeVector.empty())
return rankedTensorType;
- if (!isPermutation(transposeVector) ||
+ if (!isPermutationVector(transposeVector) ||
transposeVector.size() != static_cast<size_t>(rankedTensorType.getRank()))
return failure();
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineExprVisitor.h"
auto resultTensorType = outputTensor.getType().cast<RankedTensorType>();
Type elementType = resultTensorType.getElementType();
- assert(isPermutation(transposeVector) &&
+ assert(isPermutationVector(transposeVector) &&
"expect transpose vector to be a permutation");
assert(transposeVector.size() ==
static_cast<size_t>(resultTensorType.getRank()) &&
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
return filledVector;
}
-/// Helper method to apply permutation to a vector
-template <typename T>
-static SmallVector<T> applyPermutationToVector(const SmallVector<T> &vector,
- ArrayRef<int64_t> interchange) {
- assert(interchange.size() == vector.size());
- return llvm::to_vector(
- llvm::map_range(interchange, [&](int64_t val) { return vector[val]; }));
-}
-/// Helper method to apply to invert a permutation.
-static SmallVector<int64_t>
-invertPermutationVector(ArrayRef<int64_t> interchange) {
- SmallVector<int64_t> inversion(interchange.size());
- for (const auto &pos : llvm::enumerate(interchange)) {
- inversion[pos.value()] = pos.index();
- }
- return inversion;
-}
-/// Method to check if an interchange vector is a permutation.
-static bool isPermutation(ArrayRef<int64_t> interchange) {
- llvm::SmallDenseSet<int64_t, 4> seenVals;
- for (auto val : interchange) {
- if (seenVals.count(val))
- return false;
- seenVals.insert(val);
- }
- return seenVals.size() == interchange.size();
-}
-
//===----------------------------------------------------------------------===//
// tileUsingSCFForOp implementation.
//===----------------------------------------------------------------------===//
iterationDomain.size());
}
if (!interchangeVector.empty()) {
- if (!isPermutation(interchangeVector)) {
+ if (!isPermutationVector(interchangeVector)) {
return rewriter.notifyMatchFailure(
op, "invalid intechange vector, not a permutation of the entire "
"iteration space");
}
- iterationDomain =
- applyPermutationToVector(iterationDomain, interchangeVector);
- tileSizeVector =
- applyPermutationToVector(tileSizeVector, interchangeVector);
+ applyPermutationToVector(iterationDomain, interchangeVector);
+ applyPermutationToVector(tileSizeVector, interchangeVector);
}
// 3. Materialize an empty loop nest that iterates over the tiles. These
if (!interchangeVector.empty()) {
auto inversePermutation = invertPermutationVector(interchangeVector);
- offsets = applyPermutationToVector(offsets, inversePermutation);
- sizes = applyPermutationToVector(sizes, inversePermutation);
+ applyPermutationToVector(offsets, inversePermutation);
+ applyPermutationToVector(sizes, inversePermutation);
}
}
std::multiplies<int64_t>());
}
+llvm::SmallVector<int64_t>
+mlir::invertPermutationVector(ArrayRef<int64_t> permutation) {
+ SmallVector<int64_t> inversion(permutation.size());
+ for (const auto &pos : llvm::enumerate(permutation)) {
+ inversion[pos.value()] = pos.index();
+ }
+ return inversion;
+}
+
+bool mlir::isPermutationVector(ArrayRef<int64_t> interchange) {
+ llvm::SmallDenseSet<int64_t, 4> seenVals;
+ for (auto val : interchange) {
+ if (seenVals.count(val))
+ return false;
+ seenVals.insert(val);
+ }
+ return seenVals.size() == interchange.size();
+}
+
llvm::SmallVector<int64_t> mlir::getI64SubArray(ArrayAttr arrayAttr,
unsigned dropFront,
unsigned dropBack) {