[mlir][sparse] support foreach on dense tensor.
authorPeiming Liu <peiming@google.com>
Thu, 20 Oct 2022 22:05:28 +0000 (22:05 +0000)
committerPeiming Liu <peiming@google.com>
Fri, 21 Oct 2022 00:12:37 +0000 (00:12 +0000)
Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D136384

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_foreach.mlir

index edd19d2..e8e5a3e 100644 (file)
@@ -834,16 +834,16 @@ def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator]>,
 
 def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
     [SingleBlockImplicitTerminator<"YieldOp">]>,
-    Arguments<(ins AnySparseTensor:$tensor)>{
-  let summary = "Iterates over non-zero elements in a sparse tensor";
+    Arguments<(ins AnyTensor:$tensor)>{
+  let summary = "Iterates over elements in a tensor";
   let description = [{
-     Iterates over every non-zero element in the given sparse tensor and executes
-     the block.
+     Iterates over stored elements in a tensor (which are typically, but not always,
+     non-zero for sparse tensors) and executes the block.
 
-     For a input sparse tensor with rank n, the block must take n + 1 arguments. The
+     For an input tensor with rank n, the block must take n + 1 arguments. The
      first n arguments must be Index type, together indicating the current coordinates
      of the element being visited. The last argument must have the same type as the
-     sparse tensor's element type, representing the actual value loaded from the input
+     tensor's element type, representing the actual value loaded from the input
      tensor at the given coordinates.
 
      Example:
index 4707115..fc5fb76 100644 (file)
@@ -480,14 +480,17 @@ public:
     for (int64_t i = 0; i < rank; i++)
       loopEmitter.enterLoopOverTensorAtDim(rewriter, loc, 0, i);
 
-    Value vals = loopEmitter.getTensorValueBuffer(0);
-    Value idx = loopEmitter.getLastLevelTensorPointerIndex(0);
-    Value val = rewriter.create<memref::LoadOp>(op.getLoc(), vals, idx);
-
     SmallVector<Value, 4> coords;
     coords.reserve(rank);
     loopEmitter.getCoordinateArray(coords);
 
+    Value vals = loopEmitter.getTensorValueBuffer(0);
+    Value pidx = loopEmitter.getLastLevelTensorPointerIndex(0);
+    // Loads the value from sparse tensor using pointer index;
+    // loads the value from dense tensor using coordinate array.
+    Value val = enc ? rewriter.create<memref::LoadOp>(loc, vals, pidx)
+                    : rewriter.create<memref::LoadOp>(loc, vals, coords);
+
     for (int64_t i = 0; i < rank; i++)
       loopEmitter.exitCurrentLoop();
 
index 685b34b..aeb63a0 100644 (file)
@@ -78,6 +78,16 @@ module {
      return
   }
 
+  func.func @foreach_print_dense(%arg0: tensor<2x2xf64>) {
+    sparse_tensor.foreach in %arg0 : tensor<2x2xf64> do {
+    ^bb0(%1: index, %2: index, %v: f64) :
+      vector.print %1: index
+      vector.print %2: index
+      vector.print %v: f64
+   }
+   return
+  }
+  
   //
   // Main driver.
   //
@@ -109,6 +119,19 @@ module {
     // CHECK-NEXT: 5
     // CHECK-NEXT: 1
     // CHECK-NEXT: 1
+    // CHECK-NEXT: 6    
+    call @foreach_print_dense(%src) : (tensor<2x2xf64>) -> ()
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 2
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: 5
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 1
     // CHECK-NEXT: 6
     call @foreach_print_1(%s1) : (tensor<2x2xf64, #Row>) -> ()
     // CHECK-NEXT: 0