[mlir][sparse] fixed bug with unary op, dense output
authorAart Bik <ajcbik@google.com>
Fri, 2 Jun 2023 23:41:49 +0000 (16:41 -0700)
committerAart Bik <ajcbik@google.com>
Sat, 3 Jun 2023 01:15:33 +0000 (18:15 -0700)
Note that by sparse compiler convention, dense output
is zerod out when not set, so complement results in
zeros where elements were present.

Reviewed By: wrengr

Differential Revision: https://reviews.llvm.org/D152046

mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_unary.mlir

index 7ebb602..d9f363a 100644 (file)
@@ -1049,50 +1049,52 @@ static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {
 /// Generates a store on a dense or sparse tensor.
 static void genTensorStore(CodegenEnv &env, OpBuilder &builder, ExprId exp,
                            Value rhs) {
-  linalg::GenericOp op = env.op();
-  Location loc = op.getLoc();
+  // Only unary and binary are allowed to return uninitialized rhs
+  // to indicate missing output.
+  if (!rhs) {
+    assert(env.exp(exp).kind == TensorExp::Kind::kUnary ||
+           env.exp(exp).kind == TensorExp::Kind::kBinary);
+    return;
+  }
   // Test if this is a scalarized reduction.
   if (env.isReduc()) {
     env.updateReduc(rhs);
     return;
   }
-  // Store during insertion.
+  // Regular store.
+  linalg::GenericOp op = env.op();
+  Location loc = op.getLoc();
   OpOperand *t = op.getDpsInitOperand(0);
-  if (env.isSparseOutput(t)) {
-    if (!rhs) {
-      // Only unary and binary are allowed to return uninitialized rhs
-      // to indicate missing output.
-      assert(env.exp(exp).kind == TensorExp::Kind::kUnary ||
-             env.exp(exp).kind == TensorExp::Kind::kBinary);
-    } else if (env.exp(exp).kind == TensorExp::Kind::kSelect) {
-      // Select operation insertion.
-      Value chain = env.getInsertionChain();
-      scf::IfOp ifOp =
-          builder.create<scf::IfOp>(loc, chain.getType(), rhs, /*else=*/true);
-      builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-      // Existing value was preserved to be used here.
-      assert(env.exp(exp).val);
-      Value v0 = env.exp(exp).val;
-      genInsertionStore(env, builder, t, v0);
-      env.merger().clearExprValue(exp);
-      // Yield modified insertion chain along true branch.
-      Value mchain = env.getInsertionChain();
-      builder.create<scf::YieldOp>(op.getLoc(), mchain);
-      // Yield original insertion chain along false branch.
-      builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
-      builder.create<scf::YieldOp>(loc, chain);
-      // Done with if statement.
-      env.updateInsertionChain(ifOp->getResult(0));
-      builder.setInsertionPointAfter(ifOp);
-    } else {
-      genInsertionStore(env, builder, t, rhs);
-    }
+  if (!env.isSparseOutput(t)) {
+    SmallVector<Value> args;
+    Value ptr = genSubscript(env, builder, t, args);
+    builder.create<memref::StoreOp>(loc, rhs, ptr, args);
     return;
   }
-  // Actual store.
-  SmallVector<Value> args;
-  Value ptr = genSubscript(env, builder, t, args);
-  builder.create<memref::StoreOp>(loc, rhs, ptr, args);
+  // Store during sparse insertion.
+  if (env.exp(exp).kind != TensorExp::Kind::kSelect) {
+    genInsertionStore(env, builder, t, rhs);
+    return;
+  }
+  // Select operation insertion.
+  Value chain = env.getInsertionChain();
+  scf::IfOp ifOp =
+      builder.create<scf::IfOp>(loc, chain.getType(), rhs, /*else=*/true);
+  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+  // Existing value was preserved to be used here.
+  assert(env.exp(exp).val);
+  Value v0 = env.exp(exp).val;
+  genInsertionStore(env, builder, t, v0);
+  env.merger().clearExprValue(exp);
+  // Yield modified insertion chain along true branch.
+  Value mchain = env.getInsertionChain();
+  builder.create<scf::YieldOp>(op.getLoc(), mchain);
+  // Yield original insertion chain along false branch.
+  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+  builder.create<scf::YieldOp>(loc, chain);
+  // Done with if statement.
+  env.updateInsertionChain(ifOp->getResult(0));
+  builder.setInsertionPointAfter(ifOp);
 }
 
 /// Generates an invariant value.
index 63c6d0e..462addf 100644 (file)
 //
 // Traits for tensor operations.
 //
-#trait_vec_scale = {
+#trait_vec = {
   indexing_maps = [
     affine_map<(i) -> (i)>,  // a (in)
     affine_map<(i) -> (i)>   // x (out)
   ],
   iterator_types = ["parallel"]
 }
-#trait_mat_scale = {
+#trait_mat = {
   indexing_maps = [
     affine_map<(i,j) -> (i,j)>,  // A (in)
     affine_map<(i,j) -> (i,j)>   // X (out)
 
 module {
   // Invert the structure of a sparse vector. Present values become missing.
-  // Missing values are filled with 1 (i32).
-  func.func @vector_complement(%arga: tensor<?xf64, #SparseVector>) -> tensor<?xi32, #SparseVector> {
+  // Missing values are filled with 1 (i32). Output is sparse.
+  func.func @vector_complement_sparse(%arga: tensor<?xf64, #SparseVector>) -> tensor<?xi32, #SparseVector> {
     %c = arith.constant 0 : index
     %ci1 = arith.constant 1 : i32
     %d = tensor.dim %arga, %c : tensor<?xf64, #SparseVector>
     %xv = bufferization.alloc_tensor(%d) : tensor<?xi32, #SparseVector>
-    %0 = linalg.generic #trait_vec_scale
+    %0 = linalg.generic #trait_vec
        ins(%arga: tensor<?xf64, #SparseVector>)
         outs(%xv: tensor<?xi32, #SparseVector>) {
         ^bb(%a: f64, %x: i32):
@@ -69,13 +69,35 @@ module {
     return %0 : tensor<?xi32, #SparseVector>
   }
 
+  // Invert the structure of a sparse vector, where missing values are
+  // filled with 1. For a dense output, the sparse compiler initializes
+  // the buffer to all zero at all other places.
+  func.func @vector_complement_dense(%arga: tensor<?xf64, #SparseVector>) -> tensor<?xi32> {
+    %c = arith.constant 0 : index
+    %d = tensor.dim %arga, %c : tensor<?xf64, #SparseVector>
+    %xv = bufferization.alloc_tensor(%d) : tensor<?xi32>
+    %0 = linalg.generic #trait_vec
+       ins(%arga: tensor<?xf64, #SparseVector>)
+        outs(%xv: tensor<?xi32>) {
+        ^bb(%a: f64, %x: i32):
+          %1 = sparse_tensor.unary %a : f64 to i32
+            present={}
+            absent={
+              %ci1 = arith.constant 1 : i32
+              sparse_tensor.yield %ci1 : i32
+            }
+          linalg.yield %1 : i32
+    } -> tensor<?xi32>
+    return %0 : tensor<?xi32>
+  }
+
   // Negate existing values. Fill missing ones with +1.
   func.func @vector_negation(%arga: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
     %c = arith.constant 0 : index
     %cf1 = arith.constant 1.0 : f64
     %d = tensor.dim %arga, %c : tensor<?xf64, #SparseVector>
     %xv = bufferization.alloc_tensor(%d) : tensor<?xf64, #SparseVector>
-    %0 = linalg.generic #trait_vec_scale
+    %0 = linalg.generic #trait_vec
        ins(%arga: tensor<?xf64, #SparseVector>)
         outs(%xv: tensor<?xf64, #SparseVector>) {
         ^bb(%a: f64, %x: f64):
@@ -98,7 +120,7 @@ module {
     %c = arith.constant 0 : index
     %d = tensor.dim %arga, %c : tensor<?xf64, #SparseVector>
     %xv = bufferization.alloc_tensor(%d) : tensor<?xf64, #SparseVector>
-    %0 = linalg.generic #trait_vec_scale
+    %0 = linalg.generic #trait_vec
        ins(%arga: tensor<?xf64, #SparseVector>)
         outs(%xv: tensor<?xf64, #SparseVector>) {
         ^bb(%a: f64, %x: f64):
@@ -126,7 +148,7 @@ module {
     %d0 = tensor.dim %argx, %c0 : tensor<?x?xf64, #DCSR>
     %d1 = tensor.dim %argx, %c1 : tensor<?x?xf64, #DCSR>
     %xv = bufferization.alloc_tensor(%d0, %d1) : tensor<?x?xf64, #DCSR>
-    %0 = linalg.generic #trait_mat_scale
+    %0 = linalg.generic #trait_mat
        ins(%argx: tensor<?x?xf64, #DCSR>)
         outs(%xv: tensor<?x?xf64, #DCSR>) {
         ^bb(%a: f64, %x: f64):
@@ -153,7 +175,7 @@ module {
     %d0 = tensor.dim %argx, %c0 : tensor<?x?xf64, #DCSR>
     %d1 = tensor.dim %argx, %c1 : tensor<?x?xf64, #DCSR>
     %xv = bufferization.alloc_tensor(%d0, %d1) : tensor<?x?xf64, #DCSR>
-    %0 = linalg.generic #trait_mat_scale
+    %0 = linalg.generic #trait_mat
        ins(%argx: tensor<?x?xf64, #DCSR>)
         outs(%xv: tensor<?x?xf64, #DCSR>) {
         ^bb(%a: f64, %x: f64):
@@ -223,6 +245,7 @@ module {
 
   // Driver method to call and verify vector kernels.
   func.func @entry() {
+    %cmu = arith.constant -99 : i32
     %c0 = arith.constant 0 : index
 
     // Setup sparse vectors.
@@ -240,7 +263,7 @@ module {
     %sm1 = sparse_tensor.convert %m1 : tensor<4x8xf64> to tensor<?x?xf64, #DCSR>
 
     // Call sparse vector kernels.
-    %0 = call @vector_complement(%sv1)
+    %0 = call @vector_complement_sparse(%sv1)
        : (tensor<?xf64, #SparseVector>) -> tensor<?xi32, #SparseVector>
     %1 = call @vector_negation(%sv1)
        : (tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector>
@@ -253,6 +276,9 @@ module {
     %4 = call @matrix_slice(%sm1)
       : (tensor<?x?xf64, #DCSR>) -> tensor<?x?xf64, #DCSR>
 
+    // Call kernel with dense output.
+    %5 = call @vector_complement_dense(%sv1) : (tensor<?xf64, #SparseVector>) -> tensor<?xi32>
+
     //
     // Verify the results.
     //
@@ -268,6 +294,7 @@ module {
     // CHECK-NEXT: ( ( 3, 3, 0, 0, 0, 0, 0, 0 ), ( 0, 0, 0, 0, 0, 0, 0, 3 ), ( 0, 0, 4, 0, 5, 0, 0, 6 ), ( 7, 0, 7, 7, 0, 0, 0, 0 ) )
     // CHECK-NEXT: ( 99, 99, 99, 99, 5, 6, 99, 99, 99, 0, 0, 0, 0, 0, 0, 0 )
     // CHECK-NEXT: ( ( 99, 99, 0, 0, 0, 0, 0, 0 ), ( 0, 0, 0, 0, 0, 0, 0, 99 ), ( 0, 0, 99, 0, 5, 0, 0, 6 ), ( 99, 0, 99, 99, 0, 0, 0, 0 ) )
+    // CHECK-NEXT: ( 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0 )
     //
     call @dump_vec_f64(%sv1) : (tensor<?xf64, #SparseVector>) -> ()
     call @dump_vec_i32(%0) : (tensor<?xi32, #SparseVector>) -> ()
@@ -275,6 +302,8 @@ module {
     call @dump_vec_f64(%2) : (tensor<?xf64, #SparseVector>) -> ()
     call @dump_mat(%3) : (tensor<?x?xf64, #DCSR>) -> ()
     call @dump_mat(%4) : (tensor<?x?xf64, #DCSR>) -> ()
+    %v = vector.transfer_read %5[%c0], %cmu: tensor<?xi32>, vector<32xi32>
+    vector.print %v : vector<32xi32>
 
     // Release the resources.
     bufferization.dealloc_tensor %sv1 : tensor<?xf64, #SparseVector>
@@ -284,6 +313,7 @@ module {
     bufferization.dealloc_tensor %2 : tensor<?xf64, #SparseVector>
     bufferization.dealloc_tensor %3 : tensor<?x?xf64, #DCSR>
     bufferization.dealloc_tensor %4 : tensor<?x?xf64, #DCSR>
+    bufferization.dealloc_tensor %5 : tensor<?xi32>
     return
   }
 }