[mlir][sparse] Fix getUnorderedCOOFromType for rank 1 tensor.
authorbixia1 <bixia@google.com>
Thu, 20 Oct 2022 22:43:03 +0000 (15:43 -0700)
committerbixia1 <bixia@google.com>
Thu, 20 Oct 2022 23:33:01 +0000 (16:33 -0700)
Previously, it used DimLevelType::SingletonNo to represent an unorder COO
tensor of rank 1 while it should use DimLevelType::CompressedNuNo.

Reviewed By: Peiming, wrengr

Differential Revision: https://reviews.llvm.org/D136387

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

index a852a15..4707115 100644 (file)
@@ -132,18 +132,19 @@ static RankedTensorType getUnorderedCOOFromType(RankedTensorType src) {
   auto rank = src.getRank();
   SmallVector<DimLevelType, 4> dims;
 
-  // An unordered and non-unique compressed dim at beginning unless the tensor
-  // is a 1D tensor.
-  if (rank > 1)
-    dims.push_back(DimLevelType::CompressedNuNo);
-
-  // TODO: it is actually ordered at the level for ordered input.
-  // Followed by unordered non-unique n-2 singleton levels.
-  std::fill_n(std::back_inserter(dims), rank - 2, DimLevelType::SingletonNuNo);
-  // TODO: only if all the inputs (for concatentate) are unique at the last
-  // level should the COO has a unique level at the end. Ends by a unordered
-  // unique singleton level.
-  dims.push_back(DimLevelType::SingletonNo);
+  // An unordered and non-unique compressed dim at beginning.
+  dims.push_back(DimLevelType::CompressedNuNo);
+
+  if (rank > 1) {
+    // TODO: it is actually ordered at the level for ordered input.
+    // Followed by unordered non-unique n-2 singleton levels.
+    std::fill_n(std::back_inserter(dims), rank - 2,
+                DimLevelType::SingletonNuNo);
+    // TODO: only if all the inputs (for concatentate) are unique at the last
+    // level should the COO has a unique level at the end. Ends by a unordered
+    // unique singleton level unless the tensor rank is 1.
+    dims.push_back(DimLevelType::SingletonNo);
+  }
   SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(src);
   // TODO: Maybe pick the bitwidth based on input/output tensors (probably the
   // largest one among them) in the original operation instead of using the