From 875ee0ed1c5af58cb4909f239093e25a35d7a21a Mon Sep 17 00:00:00 2001 From: wren romano <2998727+wrengr@users.noreply.github.com> Date: Fri, 1 Jul 2022 12:41:01 -0700 Subject: [PATCH] [mlir][sparse] Reducing computational complexity This is a followup to D128847. The `AffineMap::getPermutedPosition` method performs a linear scan of the map, thus the previous implementation had asymptotic complexity of `O(|topSort| * |m|)`. This change reduces that to `O(|topSort| + |m|)`. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D129011 --- mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp index 614700a..1671550 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -115,9 +115,16 @@ struct CodeGen { static AffineMap permute(MLIRContext *context, AffineMap m, std::vector &topSort) { unsigned sz = topSort.size(); + assert(m.getNumResults() == sz && "TopoSort/AffineMap size mismatch"); + // Construct the inverse of `m`; to avoid the asymptotic complexity + // of calling `m.getPermutedPosition` repeatedly. + SmallVector inv(sz); + for (unsigned i = 0; i < sz; i++) + inv[i] = m.getDimPosition(i); + // Construct the permutation. SmallVector perm(sz); for (unsigned i = 0; i < sz; i++) - perm[i] = m.getPermutedPosition(topSort[i]); + perm[i] = inv[topSort[i]]; return AffineMap::getPermutationMap(perm, context); } -- 2.7.4