[mlir][sparse] fix incorrect coordinates ordering computed by the foreach operation.
authorPeiming Liu <peiming@google.com>
Sat, 12 Nov 2022 01:00:44 +0000 (01:00 +0000)
committerPeiming Liu <peiming@google.com>
Sat, 12 Nov 2022 04:08:50 +0000 (04:08 +0000)
Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D137877

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_foreach.mlir

index 6da7b8a..095d236 100644 (file)
@@ -791,7 +791,7 @@ public:
     SmallVector<Value, 4> args;
     // Remap coordinates.
     for (int64_t i = 0; i < rank; i++) {
-      Value actual = coords[toOrigDim(enc, i)];
+      Value actual = coords[toStoredDim(enc, i)];
       args.push_back(actual);
     }
     // Remap value.
index 8bab24d..706a92f 100644 (file)
   dimOrdering = affine_map<(i,j) -> (j,i)>
 }>
 
+#CCCPerm = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed", "compressed", "compressed"],
+  dimOrdering = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+}>
+
 module {
   /// uses foreach operator to print coords and values.
   func.func @foreach_print_const() {
@@ -97,6 +102,18 @@ module {
      return
   }
 
+  func.func @foreach_print_3d(%arg0: tensor<7x8x9xf64, #CCCPerm>) {
+    sparse_tensor.foreach in %arg0 : tensor<7x8x9xf64, #CCCPerm> do {
+      ^bb0(%1: index, %2: index, %3: index, %v: f64) :
+        vector.print %1: index
+        vector.print %2: index
+        vector.print %3: index
+        vector.print %v: f64
+     }
+     return
+  }
+
+
   func.func @foreach_print_dense(%arg0: tensor<2x2xf64>) {
     sparse_tensor.foreach in %arg0 : tensor<2x2xf64> do {
     ^bb0(%1: index, %2: index, %v: f64) :
@@ -118,7 +135,11 @@ module {
        [[  1.0,  2.0],
         [  5.0,  6.0]]
     > : tensor<2x2xf64>
-
+    
+    %src3d = arith.constant sparse<
+       [[1, 2, 3], [4, 5, 6]], [1.0, 2.0] 
+    > : tensor<7x8x9xf64>
+    
     //
     // Convert dense tensor directly to various sparse tensors.
     //
@@ -127,6 +148,7 @@ module {
     %s3 = sparse_tensor.convert %src : tensor<2x2xf64> to tensor<2x2xf64, #DCSC>
     %s4 = sparse_tensor.convert %src : tensor<2x2xf64> to tensor<2x2xf64, #SortedCOO>
     %s5 = sparse_tensor.convert %src : tensor<2x2xf64> to tensor<2x2xf64, #SortedCOOPerm>
+    %s6 = sparse_tensor.convert %src3d : tensor<7x8x9xf64>  to tensor<7x8x9xf64, #CCCPerm>
     // CHECK: 0
     // CHECK-NEXT: 0
     // CHECK-NEXT: 1
@@ -212,6 +234,16 @@ module {
     // CHECK-NEXT: 1
     // CHECK-NEXT: 6
     call @foreach_print_5(%s5) : (tensor<2x2xf64, #SortedCOOPerm>) -> ()
+
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 2
+    // CHECK-NEXT: 3
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 4
+    // CHECK-NEXT: 5
+    // CHECK-NEXT: 6
+    // CHECK-NEXT: 2
+    call @foreach_print_3d(%s6): (tensor<7x8x9xf64, #CCCPerm>) -> ()
     
     bufferization.dealloc_tensor %s1 : tensor<2x2xf64, #Row>
     bufferization.dealloc_tensor %s2 : tensor<2x2xf64, #CSR>