[mlir][sparse] Sparse reduction in lex order no longer produces dense output
authorJim Kitchen <jim22k@gmail.com>
Fri, 10 Feb 2023 19:08:49 +0000 (13:08 -0600)
committerJim Kitchen <jim22k@gmail.com>
Fri, 10 Feb 2023 19:09:28 +0000 (13:09 -0600)
Previously, when performing a reduction on a sparse tensor, the result
would be different depending on iteration order. For expanded access pattern,
an empty row would contribute no entry in the output. For lex ordering, the
identity would end up in the output.

This code changes that behavior and keeps track of whether any entries were
actually reduced in lex ordering, making the output consistent between the
two iteration styles.

Differential Revision: https://reviews.llvm.org/D142050

mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/test/Dialect/SparseTensor/sparse_affine.mlir
mlir/test/Dialect/SparseTensor/sparse_out.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reduce_custom.mlir

index 4a8fbf7..a8ebad3 100644 (file)
@@ -37,7 +37,7 @@ CodegenEnv::CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts,
       latticeMerger(numTensors, numLoops, numFilterLoops), loopEmitter(),
       topSort(), sparseOut(nullptr), outerParNest(-1u), insChain(), expValues(),
       expFilled(), expAdded(), expCount(), redVal(), redExp(-1u),
-      redCustom(-1u) {}
+      redCustom(-1u), redValidLexInsert() {}
 
 LogicalResult CodegenEnv::initTensorExp() {
   // Builds the tensor expression for the Linalg operation in SSA form.
@@ -70,16 +70,24 @@ std::optional<Operation *> CodegenEnv::genLoopBoundary(
     function_ref<std::optional<Operation *>(MutableArrayRef<Value> parameters)>
         callback) {
   SmallVector<Value> params;
-  if (isReduc())
+  if (isReduc()) {
     params.push_back(redVal);
+    if (redValidLexInsert)
+      params.push_back(redValidLexInsert);
+  } else {
+    assert(!redValidLexInsert);
+  }
   if (isExpand())
     params.push_back(expCount);
   if (insChain != nullptr)
     params.push_back(insChain);
   auto r = callback(params); // may update parameters
   unsigned i = 0;
-  if (isReduc())
+  if (isReduc()) {
     updateReduc(params[i++]);
+    if (redValidLexInsert)
+      setValidLexInsert(params[i++]);
+  }
   if (isExpand())
     updateExpandCount(params[i++]);
   if (insChain != nullptr)
@@ -225,6 +233,16 @@ Value CodegenEnv::endReduc() {
   return val;
 }
 
+void CodegenEnv::setValidLexInsert(Value val) {
+  assert(isReduc() && val);
+  redValidLexInsert = val;
+}
+
+void CodegenEnv::clearValidLexInsert() {
+  assert(!isReduc());
+  redValidLexInsert = Value();
+}
+
 void CodegenEnv::startCustomReduc(unsigned exp) {
   assert(redCustom == -1u && exp != -1u);
   redCustom = exp;
index 6f16d8f..b210ca8 100644 (file)
@@ -132,6 +132,9 @@ public:
   void updateReduc(Value val);
   Value getReduc() const { return redVal; }
   Value endReduc();
+  void setValidLexInsert(Value val);
+  void clearValidLexInsert();
+  Value getValidLexInsert() const { return redValidLexInsert; }
 
   void startCustomReduc(unsigned exp);
   bool isCustomReduc() const { return redCustom != -1u; }
@@ -172,6 +175,11 @@ private:
   unsigned redExp;
   unsigned redCustom;
 
+  // Bookkeeping for lex insertion during reductions. Holds the runtime boolean
+  // value of whether any reduction occurred. This is only set during a
+  // reduction and cleared once the reduction is finished.
+  Value redValidLexInsert;
+
   // The root tensor expression of the kernel.
   unsigned tensorExp;
 };
index 72063a1..f3cf2ce 100644 (file)
@@ -651,8 +651,31 @@ static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
       indices.push_back(env.emitter().getLoopIV(i));
     }
     Value chain = env.getInsertionChain();
-    env.updateInsertionChain(
-        builder.create<InsertOp>(loc, rhs, chain, indices));
+    if (!env.getValidLexInsert()) {
+      env.updateInsertionChain(
+          builder.create<InsertOp>(loc, rhs, chain, indices));
+    } else {
+      // Generates runtime check for a valid lex during reduction,
+      // to avoid inserting the identity value for empty reductions.
+      //   if (validLexInsert) then
+      //     insert(rhs) into chain
+      //     return updated chain
+      //   else
+      //     return unmodified chain
+      scf::IfOp ifValidLexInsert = builder.create<scf::IfOp>(
+          loc, chain.getType(), env.getValidLexInsert(),
+          /*else=*/true);
+      // True branch.
+      builder.setInsertionPointToStart(ifValidLexInsert.thenBlock());
+      Value res = builder.create<InsertOp>(loc, rhs, chain, indices);
+      builder.create<scf::YieldOp>(loc, res);
+      // False branch.
+      builder.setInsertionPointToStart(ifValidLexInsert.elseBlock());
+      builder.create<scf::YieldOp>(loc, chain);
+      // Value assignment.
+      builder.setInsertionPointAfter(ifValidLexInsert);
+      env.updateInsertionChain(ifValidLexInsert.getResult(0));
+    }
     return;
   }
   // Generates insertion code along expanded access pattern.
@@ -857,13 +880,16 @@ static void genInvariants(CodegenEnv &env, OpBuilder &builder, unsigned exp,
       return;
     OpOperand *lhs = op.getDpsInitOperand(0);
     if (lhs == &t) {
-      // Start or end a scalarized reduction
+      // Start or end a scalarized reduction.
       if (atStart) {
         Value load = env.isCustomReduc() ? env.getCustomRedId()
                                          : genTensorLoad(env, builder, exp);
         env.startReduc(exp, load);
+        if (env.hasSparseOutput())
+          env.setValidLexInsert(constantI1(builder, env.op().getLoc(), false));
       } else {
         genTensorStore(env, builder, exp, env.endReduc());
+        env.clearValidLexInsert();
       }
     } else {
       // Start or end loop invariant hoisting of a tensor load.
@@ -1031,6 +1057,10 @@ static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, unsigned idx,
       if (env.isReduc()) {
         yields.push_back(env.getReduc());
         env.updateReduc(ifOp.getResult(y++));
+        if (env.getValidLexInsert()) {
+          yields.push_back(env.getValidLexInsert());
+          env.setValidLexInsert(ifOp.getResult(y++));
+        }
       }
       if (env.isExpand()) {
         yields.push_back(env.getExpandCount());
@@ -1073,8 +1103,11 @@ static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, unsigned idx,
     }
     cond = cond ? builder.create<arith::AndIOp>(loc, cond, clause) : clause;
   }
-  if (env.isReduc())
+  if (env.isReduc()) {
     types.push_back(env.getReduc().getType());
+    if (env.getValidLexInsert())
+      types.push_back(env.getValidLexInsert().getType());
+  }
   if (env.isExpand())
     types.push_back(builder.getIndexType());
   if (env.getInsertionChain())
@@ -1092,6 +1125,9 @@ static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
   if (env.isReduc()) {
     operands.push_back(env.getReduc());
     env.updateReduc(redInput);
+    if (env.getValidLexInsert())
+      // Any overlapping indices during a reduction creates a valid lex insert.
+      operands.push_back(constantI1(builder, env.op().getLoc(), true));
   }
   if (env.isExpand()) {
     operands.push_back(env.getExpandCount());
@@ -1318,6 +1354,10 @@ static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
   // End a while-loop.
   if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
     finalizeWhileOp(env, rewriter, idx, needsUniv, env.lat(li).bits, whileOp);
+  } else if (auto forOp = dyn_cast<scf::ForOp>(loop)) {
+    // Any iteration of a reduction for-loop creates a valid lex insert.
+    if (env.isReduc() && env.getValidLexInsert())
+      env.setValidLexInsert(constantI1(rewriter, env.op().getLoc(), true));
   } else {
     needsUniv = false;
   }
index 8af3144..3d2f706 100644 (file)
@@ -306,12 +306,14 @@ func.func @mul_affine_dense2d(%arga: tensor<32x16xf64, #CSR>,
 // CHECK-LABEL:   func.func @mul_affine_sparse2d(
 // CHECK-SAME:                                   %[[VAL_0:.*]]: tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>,
 // CHECK-SAME:                                   %[[VAL_1:.*]]: tensor<34x19xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>> {
-// CHECK:           %[[VAL_2:.*]] = arith.constant 32 : index
-// CHECK:           %[[VAL_3:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_5:.*]] = arith.constant 2 : index
-// CHECK:           %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f64
-// CHECK:           %[[VAL_7:.*]] = arith.constant 3 : index
+// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 32 : index
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 2 : index
+// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 3 : index
+// CHECK-DAG:       %[[VAL_TRUE:.*]] = arith.constant true
+// CHECK-DAG:       %[[VAL_FALSE:.*]] = arith.constant false
 // CHECK:           %[[VAL_8:.*]] = bufferization.alloc_tensor() : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK:           %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 1 : index} : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
 // CHECK:           %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 1 : index} : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
@@ -330,22 +332,27 @@ func.func @mul_affine_dense2d(%arga: tensor<32x16xf64, #CSR>,
 // CHECK:               %[[VAL_27:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_18]]] : memref<?xindex>
 // CHECK:               %[[VAL_28:.*]] = arith.addi %[[VAL_18]], %[[VAL_4]] : index
 // CHECK:               %[[VAL_29:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_28]]] : memref<?xindex>
-// CHECK:               %[[VAL_30:.*]]:2 = scf.for %[[VAL_31:.*]] = %[[VAL_27]] to %[[VAL_29]] step %[[VAL_4]] iter_args(%[[VAL_32:.*]] = %[[VAL_6]], %[[VAL_33:.*]] = %[[VAL_24]]) -> (f64, tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>) {
+// CHECK:               %[[VAL_30:.*]]:3 = scf.for %[[VAL_31:.*]] = %[[VAL_27]] to %[[VAL_29]] step %[[VAL_4]] iter_args(%[[VAL_32:.*]] = %[[VAL_6]], %[[VAL_200:.*]] = %[[VAL_FALSE]], %[[VAL_33:.*]] = %[[VAL_24]]) -> (f64, i1, tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>) {
 // CHECK:                 %[[VAL_34:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_31]]] : memref<?xindex>
 // CHECK:                 %[[VAL_35:.*]] = arith.addi %[[VAL_25]], %[[VAL_7]] : index
 // CHECK:                 %[[VAL_36:.*]] = arith.cmpi eq, %[[VAL_34]], %[[VAL_35]] : index
-// CHECK:                 %[[VAL_37:.*]]:2 = scf.if %[[VAL_36]] -> (f64, tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>) {
+// CHECK:                 %[[VAL_37:.*]]:3 = scf.if %[[VAL_36]] -> (f64, i1, tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>) {
 // CHECK:                   %[[VAL_38:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_31]]] : memref<?xf64>
 // CHECK:                   %[[VAL_39:.*]] = arith.mulf %[[VAL_26]], %[[VAL_38]] : f64
 // CHECK:                   %[[VAL_40:.*]] = arith.addf %[[VAL_32]], %[[VAL_39]] : f64
-// CHECK:                   scf.yield %[[VAL_40]], %[[VAL_33]] : f64, tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:                   scf.yield %[[VAL_40]], %[[VAL_TRUE]], %[[VAL_33]] : f64, i1, tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK:                 } else {
-// CHECK:                   scf.yield %[[VAL_32]], %[[VAL_33]] : f64, tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:                   scf.yield %[[VAL_32]], %[[VAL_200]], %[[VAL_33]] : f64, i1, tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK:                 }
-// CHECK:                 scf.yield %[[VAL_41:.*]]#0, %[[VAL_41]]#1 : f64, tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:                 scf.yield %[[VAL_41:.*]]#0, %[[VAL_41]]#1, %[[VAL_41]]#2 : f64, i1, tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:               }
+// CHECK:               %[[VAL_201:.*]] = scf.if %[[VAL_30]]#1 -> (tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>) {
+// CHECK:                 %[[VAL_42:.*]] = sparse_tensor.insert %[[VAL_30]]#0 into %[[VAL_30]]#2{{\[}}%[[VAL_16]], %[[VAL_25]]] : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:                 scf.yield %[[VAL_42]] : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:               } else {
+// CHECK:                 scf.yield %[[VAL_30]]#2 : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK:               }
-// CHECK:               %[[VAL_42:.*]] = sparse_tensor.insert %[[VAL_43:.*]]#0 into %[[VAL_43]]#1{{\[}}%[[VAL_16]], %[[VAL_25]]] : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK:               scf.yield %[[VAL_42]] : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:               scf.yield %[[VAL_201]] : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK:             }
 // CHECK:             scf.yield %[[VAL_44:.*]] : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK:           }
index 99b9dff..e6d6c23 100644 (file)
@@ -153,6 +153,8 @@ func.func @sparse_truly_dynamic(%arga: tensor<10x20xf32, #CSR>) -> tensor<10x20x
 // CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 1 : index
 // CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 0 : i32
+// CHECK-DAG:       %[[VAL_FALSE:.*]] = arith.constant false
+// CHECK-DAG:       %[[VAL_TRUE:.*]] = arith.constant true
 // CHECK:           %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ] }>>
 // CHECK:           %[[VAL_6:.*]] = tensor.dim %[[VAL_0]], %[[VAL_3]] : tensor<?x?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ] }>>
 // CHECK:           %[[VAL_7:.*]] = bufferization.alloc_tensor(%[[VAL_5]], %[[VAL_6]]) : tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
@@ -216,13 +218,13 @@ func.func @sparse_truly_dynamic(%arga: tensor<10x20xf32, #CSR>) -> tensor<10x20x
 // CHECK:                   %[[VAL_71:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_58]]] : memref<?xindex>
 // CHECK:                   %[[VAL_72:.*]] = arith.addi %[[VAL_58]], %[[VAL_3]] : index
 // CHECK:                   %[[VAL_73:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_72]]] : memref<?xindex>
-// CHECK:                   %[[VAL_74:.*]]:4 = scf.while (%[[VAL_75:.*]] = %[[VAL_68]], %[[VAL_76:.*]] = %[[VAL_71]], %[[VAL_77:.*]] = %[[VAL_4]], %[[VAL_78:.*]] = %[[VAL_59]]) : (index, index, i32, tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>) -> (index, index, i32, tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>) {
+// CHECK:                   %[[VAL_74:.*]]:5 = scf.while (%[[VAL_75:.*]] = %[[VAL_68]], %[[VAL_76:.*]] = %[[VAL_71]], %[[VAL_77:.*]] = %[[VAL_4]], %[[VAL_200:.*]] = %[[VAL_FALSE]], %[[VAL_78:.*]] = %[[VAL_59]]) : (index, index, i32, i1, tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>) -> (index, index, i32, i1, tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>) {
 // CHECK:                     %[[VAL_79:.*]] = arith.cmpi ult, %[[VAL_75]], %[[VAL_70]] : index
 // CHECK:                     %[[VAL_80:.*]] = arith.cmpi ult, %[[VAL_76]], %[[VAL_73]] : index
 // CHECK:                     %[[VAL_81:.*]] = arith.andi %[[VAL_79]], %[[VAL_80]] : i1
-// CHECK:                     scf.condition(%[[VAL_81]]) %[[VAL_75]], %[[VAL_76]], %[[VAL_77]], %[[VAL_78]] : index, index, i32, tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
+// CHECK:                     scf.condition(%[[VAL_81]]) %[[VAL_75]], %[[VAL_76]], %[[VAL_77]], %[[VAL_200]], %[[VAL_78]] : index, index, i32, i1, tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
 // CHECK:                   } do {
-// CHECK:                   ^bb0(%[[VAL_82:.*]]: index, %[[VAL_83:.*]]: index, %[[VAL_84:.*]]: i32, %[[VAL_85:.*]]: tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>):
+// CHECK:                   ^bb0(%[[VAL_82:.*]]: index, %[[VAL_83:.*]]: index, %[[VAL_84:.*]]: i32, %[[VAL_201:.*]]: i1, %[[VAL_85:.*]]: tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>):
 // CHECK:                     %[[VAL_86:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_82]]] : memref<?xindex>
 // CHECK:                     %[[VAL_87:.*]] = memref.load %[[VAL_20]]{{\[}}%[[VAL_83]]] : memref<?xindex>
 // CHECK:                     %[[VAL_88:.*]] = arith.cmpi ult, %[[VAL_87]], %[[VAL_86]] : index
@@ -230,14 +232,14 @@ func.func @sparse_truly_dynamic(%arga: tensor<10x20xf32, #CSR>) -> tensor<10x20x
 // CHECK:                     %[[VAL_90:.*]] = arith.cmpi eq, %[[VAL_86]], %[[VAL_89]] : index
 // CHECK:                     %[[VAL_91:.*]] = arith.cmpi eq, %[[VAL_87]], %[[VAL_89]] : index
 // CHECK:                     %[[VAL_92:.*]] = arith.andi %[[VAL_90]], %[[VAL_91]] : i1
-// CHECK:                     %[[VAL_93:.*]]:2 = scf.if %[[VAL_92]] -> (i32, tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>) {
+// CHECK:                     %[[VAL_93:.*]]:3 = scf.if %[[VAL_92]] -> (i32, i1, tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>) {
 // CHECK:                       %[[VAL_94:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_82]]] : memref<?xi32>
 // CHECK:                       %[[VAL_95:.*]] = memref.load %[[VAL_21]]{{\[}}%[[VAL_83]]] : memref<?xi32>
 // CHECK:                       %[[VAL_96:.*]] = arith.muli %[[VAL_94]], %[[VAL_95]] : i32
 // CHECK:                       %[[VAL_97:.*]] = arith.addi %[[VAL_84]], %[[VAL_96]] : i32
-// CHECK:                       scf.yield %[[VAL_97]], %[[VAL_85]] : i32, tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
+// CHECK:                       scf.yield %[[VAL_97]], %[[VAL_TRUE]], %[[VAL_85]] : i32, i1, tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
 // CHECK:                     } else {
-// CHECK:                       scf.yield %[[VAL_84]], %[[VAL_85]] : i32, tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
+// CHECK:                       scf.yield %[[VAL_84]], %[[VAL_201]], %[[VAL_85]] : i32, i1, tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
 // CHECK:                     }
 // CHECK:                     %[[VAL_98:.*]] = arith.cmpi eq, %[[VAL_86]], %[[VAL_89]] : index
 // CHECK:                     %[[VAL_99:.*]] = arith.addi %[[VAL_82]], %[[VAL_3]] : index
@@ -245,10 +247,15 @@ func.func @sparse_truly_dynamic(%arga: tensor<10x20xf32, #CSR>) -> tensor<10x20x
 // CHECK:                     %[[VAL_101:.*]] = arith.cmpi eq, %[[VAL_87]], %[[VAL_89]] : index
 // CHECK:                     %[[VAL_102:.*]] = arith.addi %[[VAL_83]], %[[VAL_3]] : index
 // CHECK:                     %[[VAL_103:.*]] = arith.select %[[VAL_101]], %[[VAL_102]], %[[VAL_83]] : index
-// CHECK:                     scf.yield %[[VAL_100]], %[[VAL_103]], %[[VAL_104:.*]]#0, %[[VAL_104]]#1 : index, index, i32, tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
+// CHECK:                     scf.yield %[[VAL_100]], %[[VAL_103]], %[[VAL_104:.*]]#0, %[[VAL_104]]#1, %[[VAL_104]]#2 : index, index, i32, i1, tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
 // CHECK:                   }
-// CHECK:                   %[[VAL_105:.*]] = sparse_tensor.insert %[[VAL_106:.*]]#2 into %[[VAL_106]]#3{{\[}}%[[VAL_39]], %[[VAL_63]]] : tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
-// CHECK:                   scf.yield %[[VAL_105]] : tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
+// CHECK:                   %[[VAL_202:.*]] = scf.if %[[VAL_74]]#3 -> (tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>) {
+// CHECK:                     %[[VAL_105:.*]] = sparse_tensor.insert %[[VAL_74]]#2 into %[[VAL_74]]#4{{\[}}%[[VAL_39]], %[[VAL_63]]] : tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
+// CHECK:                     scf.yield %[[VAL_105]] : tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
+// CHECK:                   } else {
+// CHECK:                     scf.yield %[[VAL_74]]#4 : tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
+// CHECK:                   }
+// CHECK:                   scf.yield %[[VAL_202]] : tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
 // CHECK:                 } else {
 // CHECK:                   scf.yield %[[VAL_59]] : tensor<?x?xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
 // CHECK:                 }
index f743954..697233b 100644 (file)
@@ -172,9 +172,8 @@ module {
     // CHECK-NEXT: ( ( 6, 0, 0, 0, 0 ), ( 0, 0, 0, 5, 0 ), ( 4, 0, 0, 3, 0 ), ( 0, 2, 0, 0, 0 ), ( 0, 11, 0, 0, 0 ) )
     // CHECK-NEXT: ( 7, 7, 9, 8, 7, 7, 12, 11, 11, 0, 0, 0, 0, 0, 0, 0 )
     // CHECK-NEXT: ( ( 7, 0, 0, 7, 0 ), ( 9, 0, 0, 0, 0 ), ( 8, 7, 0, 7, 0 ), ( 12, 11, 0, 11, 0 ), ( 0, 0, 0, 0, 0 ) )
-    // TODO: Update once identity values are no longer inserted for non-overlapping dot product
-    // CHECK-NEXT: ( 7, inf, inf, 7, 9, inf, inf, inf, 8, 7, inf, 7, 12, 11, inf, 11 )
-    // CHECK-NEXT: ( ( 7, inf, inf, 7, 0 ), ( 9, inf, inf, inf, 0 ), ( 8, 7, inf, 7, 0 ), ( 12, 11, inf, 11, 0 ), ( 0, 0, 0, 0, 0 ) )
+    // CHECK-NEXT: ( 7, 7, 9, 8, 7, 7, 12, 11, 11, 0, 0, 0, 0, 0, 0, 0 )
+    // CHECK-NEXT: ( ( 7, 0, 0, 7, 0 ), ( 9, 0, 0, 0, 0 ), ( 8, 7, 0, 7, 0 ), ( 12, 11, 0, 11, 0 ), ( 0, 0, 0, 0, 0 ) )
     //
     call @dump_mat(%sm1) : (tensor<?x?xf64, #CSR>) -> ()
     call @dump_mat(%sm2r) : (tensor<?x?xf64, #CSR>) -> ()