From: Jim Kitchen Date: Fri, 10 Feb 2023 19:08:49 +0000 (-0600) Subject: [mlir][sparse] Sparse reduction in lex order no longer produces dense output X-Git-Tag: upstream/17.0.6~17895 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=81d0d2b2a068eae9692b9317bceaaea252c1bbf8;p=platform%2Fupstream%2Fllvm.git [mlir][sparse] Sparse reduction in lex order no longer produces dense output Previously, when performing a reduction on a sparse tensor, the result would be different depending on iteration order. For expanded access pattern, an empty row would contribute no entry in the output. For lex ordering, the identity would end up in the output. This code changes that behavior and keeps track of whether any entries were actually reduced in lex ordering, making the output consistent between the two iteration styles. Differential Revision: https://reviews.llvm.org/D142050 --- diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp index 4a8fbf7..a8ebad3 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp @@ -37,7 +37,7 @@ CodegenEnv::CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts, latticeMerger(numTensors, numLoops, numFilterLoops), loopEmitter(), topSort(), sparseOut(nullptr), outerParNest(-1u), insChain(), expValues(), expFilled(), expAdded(), expCount(), redVal(), redExp(-1u), - redCustom(-1u) {} + redCustom(-1u), redValidLexInsert() {} LogicalResult CodegenEnv::initTensorExp() { // Builds the tensor expression for the Linalg operation in SSA form. @@ -70,16 +70,24 @@ std::optional CodegenEnv::genLoopBoundary( function_ref(MutableArrayRef parameters)> callback) { SmallVector params; - if (isReduc()) + if (isReduc()) { params.push_back(redVal); + if (redValidLexInsert) + params.push_back(redValidLexInsert); + } else { + assert(!redValidLexInsert); + } if (isExpand()) params.push_back(expCount); if (insChain != nullptr) params.push_back(insChain); auto r = callback(params); // may update parameters unsigned i = 0; - if (isReduc()) + if (isReduc()) { updateReduc(params[i++]); + if (redValidLexInsert) + setValidLexInsert(params[i++]); + } if (isExpand()) updateExpandCount(params[i++]); if (insChain != nullptr) @@ -225,6 +233,16 @@ Value CodegenEnv::endReduc() { return val; } +void CodegenEnv::setValidLexInsert(Value val) { + assert(isReduc() && val); + redValidLexInsert = val; +} + +void CodegenEnv::clearValidLexInsert() { + assert(!isReduc()); + redValidLexInsert = Value(); +} + void CodegenEnv::startCustomReduc(unsigned exp) { assert(redCustom == -1u && exp != -1u); redCustom = exp; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h index 6f16d8f..b210ca8 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h @@ -132,6 +132,9 @@ public: void updateReduc(Value val); Value getReduc() const { return redVal; } Value endReduc(); + void setValidLexInsert(Value val); + void clearValidLexInsert(); + Value getValidLexInsert() const { return redValidLexInsert; } void startCustomReduc(unsigned exp); bool isCustomReduc() const { return redCustom != -1u; } @@ -172,6 +175,11 @@ private: unsigned redExp; unsigned redCustom; + // Bookkeeping for lex insertion during reductions. Holds the runtime boolean + // value of whether any reduction occurred. This is only set during a + // reduction and cleared once the reduction is finished. + Value redValidLexInsert; + // The root tensor expression of the kernel. unsigned tensorExp; }; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp index 72063a1..f3cf2ce 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -651,8 +651,31 @@ static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t, indices.push_back(env.emitter().getLoopIV(i)); } Value chain = env.getInsertionChain(); - env.updateInsertionChain( - builder.create(loc, rhs, chain, indices)); + if (!env.getValidLexInsert()) { + env.updateInsertionChain( + builder.create(loc, rhs, chain, indices)); + } else { + // Generates runtime check for a valid lex during reduction, + // to avoid inserting the identity value for empty reductions. + // if (validLexInsert) then + // insert(rhs) into chain + // return updated chain + // else + // return unmodified chain + scf::IfOp ifValidLexInsert = builder.create( + loc, chain.getType(), env.getValidLexInsert(), + /*else=*/true); + // True branch. + builder.setInsertionPointToStart(ifValidLexInsert.thenBlock()); + Value res = builder.create(loc, rhs, chain, indices); + builder.create(loc, res); + // False branch. + builder.setInsertionPointToStart(ifValidLexInsert.elseBlock()); + builder.create(loc, chain); + // Value assignment. + builder.setInsertionPointAfter(ifValidLexInsert); + env.updateInsertionChain(ifValidLexInsert.getResult(0)); + } return; } // Generates insertion code along expanded access pattern. @@ -857,13 +880,16 @@ static void genInvariants(CodegenEnv &env, OpBuilder &builder, unsigned exp, return; OpOperand *lhs = op.getDpsInitOperand(0); if (lhs == &t) { - // Start or end a scalarized reduction + // Start or end a scalarized reduction. if (atStart) { Value load = env.isCustomReduc() ? env.getCustomRedId() : genTensorLoad(env, builder, exp); env.startReduc(exp, load); + if (env.hasSparseOutput()) + env.setValidLexInsert(constantI1(builder, env.op().getLoc(), false)); } else { genTensorStore(env, builder, exp, env.endReduc()); + env.clearValidLexInsert(); } } else { // Start or end loop invariant hoisting of a tensor load. @@ -1031,6 +1057,10 @@ static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, unsigned idx, if (env.isReduc()) { yields.push_back(env.getReduc()); env.updateReduc(ifOp.getResult(y++)); + if (env.getValidLexInsert()) { + yields.push_back(env.getValidLexInsert()); + env.setValidLexInsert(ifOp.getResult(y++)); + } } if (env.isExpand()) { yields.push_back(env.getExpandCount()); @@ -1073,8 +1103,11 @@ static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, unsigned idx, } cond = cond ? builder.create(loc, cond, clause) : clause; } - if (env.isReduc()) + if (env.isReduc()) { types.push_back(env.getReduc().getType()); + if (env.getValidLexInsert()) + types.push_back(env.getValidLexInsert().getType()); + } if (env.isExpand()) types.push_back(builder.getIndexType()); if (env.getInsertionChain()) @@ -1092,6 +1125,9 @@ static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp, if (env.isReduc()) { operands.push_back(env.getReduc()); env.updateReduc(redInput); + if (env.getValidLexInsert()) + // Any overlapping indices during a reduction creates a valid lex insert. + operands.push_back(constantI1(builder, env.op().getLoc(), true)); } if (env.isExpand()) { operands.push_back(env.getExpandCount()); @@ -1318,6 +1354,10 @@ static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop, // End a while-loop. if (auto whileOp = dyn_cast(loop)) { finalizeWhileOp(env, rewriter, idx, needsUniv, env.lat(li).bits, whileOp); + } else if (auto forOp = dyn_cast(loop)) { + // Any iteration of a reduction for-loop creates a valid lex insert. + if (env.isReduc() && env.getValidLexInsert()) + env.setValidLexInsert(constantI1(rewriter, env.op().getLoc(), true)); } else { needsUniv = false; } diff --git a/mlir/test/Dialect/SparseTensor/sparse_affine.mlir b/mlir/test/Dialect/SparseTensor/sparse_affine.mlir index 8af3144..3d2f706 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_affine.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_affine.mlir @@ -306,12 +306,14 @@ func.func @mul_affine_dense2d(%arga: tensor<32x16xf64, #CSR>, // CHECK-LABEL: func.func @mul_affine_sparse2d( // CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<34x19xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>> { -// CHECK: %[[VAL_2:.*]] = arith.constant 32 : index -// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index -// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index -// CHECK: %[[VAL_5:.*]] = arith.constant 2 : index -// CHECK: %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f64 -// CHECK: %[[VAL_7:.*]] = arith.constant 3 : index +// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 32 : index +// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f64 +// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 3 : index +// CHECK-DAG: %[[VAL_TRUE:.*]] = arith.constant true +// CHECK-DAG: %[[VAL_FALSE:.*]] = arith.constant false // CHECK: %[[VAL_8:.*]] = bufferization.alloc_tensor() : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>> // CHECK: %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 1 : index} : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref // CHECK: %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 1 : index} : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref @@ -330,22 +332,27 @@ func.func @mul_affine_dense2d(%arga: tensor<32x16xf64, #CSR>, // CHECK: %[[VAL_27:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_18]]] : memref // CHECK: %[[VAL_28:.*]] = arith.addi %[[VAL_18]], %[[VAL_4]] : index // CHECK: %[[VAL_29:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_28]]] : memref -// CHECK: %[[VAL_30:.*]]:2 = scf.for %[[VAL_31:.*]] = %[[VAL_27]] to %[[VAL_29]] step %[[VAL_4]] iter_args(%[[VAL_32:.*]] = %[[VAL_6]], %[[VAL_33:.*]] = %[[VAL_24]]) -> (f64, tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>) { +// CHECK: %[[VAL_30:.*]]:3 = scf.for %[[VAL_31:.*]] = %[[VAL_27]] to %[[VAL_29]] step %[[VAL_4]] iter_args(%[[VAL_32:.*]] = %[[VAL_6]], %[[VAL_200:.*]] = %[[VAL_FALSE]], %[[VAL_33:.*]] = %[[VAL_24]]) -> (f64, i1, tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>) { // CHECK: %[[VAL_34:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_31]]] : memref // CHECK: %[[VAL_35:.*]] = arith.addi %[[VAL_25]], %[[VAL_7]] : index // CHECK: %[[VAL_36:.*]] = arith.cmpi eq, %[[VAL_34]], %[[VAL_35]] : index -// CHECK: %[[VAL_37:.*]]:2 = scf.if %[[VAL_36]] -> (f64, tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>) { +// CHECK: %[[VAL_37:.*]]:3 = scf.if %[[VAL_36]] -> (f64, i1, tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>) { // CHECK: %[[VAL_38:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_31]]] : memref // CHECK: %[[VAL_39:.*]] = arith.mulf %[[VAL_26]], %[[VAL_38]] : f64 // CHECK: %[[VAL_40:.*]] = arith.addf %[[VAL_32]], %[[VAL_39]] : f64 -// CHECK: scf.yield %[[VAL_40]], %[[VAL_33]] : f64, tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>> +// CHECK: scf.yield %[[VAL_40]], %[[VAL_TRUE]], %[[VAL_33]] : f64, i1, tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>> // CHECK: } else { -// CHECK: scf.yield %[[VAL_32]], %[[VAL_33]] : f64, tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>> +// CHECK: scf.yield %[[VAL_32]], %[[VAL_200]], %[[VAL_33]] : f64, i1, tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>> // CHECK: } -// CHECK: scf.yield %[[VAL_41:.*]]#0, %[[VAL_41]]#1 : f64, tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>> +// CHECK: scf.yield %[[VAL_41:.*]]#0, %[[VAL_41]]#1, %[[VAL_41]]#2 : f64, i1, tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>> +// CHECK: } +// CHECK: %[[VAL_201:.*]] = scf.if %[[VAL_30]]#1 -> (tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>>) { +// CHECK: %[[VAL_42:.*]] = sparse_tensor.insert %[[VAL_30]]#0 into %[[VAL_30]]#2{{\[}}%[[VAL_16]], %[[VAL_25]]] : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>> +// CHECK: scf.yield %[[VAL_42]] : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>> +// CHECK: } else { +// CHECK: scf.yield %[[VAL_30]]#2 : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>> // CHECK: } -// CHECK: %[[VAL_42:.*]] = sparse_tensor.insert %[[VAL_43:.*]]#0 into %[[VAL_43]]#1{{\[}}%[[VAL_16]], %[[VAL_25]]] : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>> -// CHECK: scf.yield %[[VAL_42]] : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>> +// CHECK: scf.yield %[[VAL_201]] : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>> // CHECK: } // CHECK: scf.yield %[[VAL_44:.*]] : tensor<32x16xf64, #sparse_tensor.encoding<{{{.*}}}>> // CHECK: } diff --git a/mlir/test/Dialect/SparseTensor/sparse_out.mlir b/mlir/test/Dialect/SparseTensor/sparse_out.mlir index 99b9dff..e6d6c23 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_out.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_out.mlir @@ -153,6 +153,8 @@ func.func @sparse_truly_dynamic(%arga: tensor<10x20xf32, #CSR>) -> tensor<10x20x // CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : index // CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : index // CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[VAL_FALSE:.*]] = arith.constant false +// CHECK-DAG: %[[VAL_TRUE:.*]] = arith.constant true // CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor> // CHECK: %[[VAL_6:.*]] = tensor.dim %[[VAL_0]], %[[VAL_3]] : tensor> // CHECK: %[[VAL_7:.*]] = bufferization.alloc_tensor(%[[VAL_5]], %[[VAL_6]]) : tensor> @@ -216,13 +218,13 @@ func.func @sparse_truly_dynamic(%arga: tensor<10x20xf32, #CSR>) -> tensor<10x20x // CHECK: %[[VAL_71:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_58]]] : memref // CHECK: %[[VAL_72:.*]] = arith.addi %[[VAL_58]], %[[VAL_3]] : index // CHECK: %[[VAL_73:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_72]]] : memref -// CHECK: %[[VAL_74:.*]]:4 = scf.while (%[[VAL_75:.*]] = %[[VAL_68]], %[[VAL_76:.*]] = %[[VAL_71]], %[[VAL_77:.*]] = %[[VAL_4]], %[[VAL_78:.*]] = %[[VAL_59]]) : (index, index, i32, tensor>) -> (index, index, i32, tensor>) { +// CHECK: %[[VAL_74:.*]]:5 = scf.while (%[[VAL_75:.*]] = %[[VAL_68]], %[[VAL_76:.*]] = %[[VAL_71]], %[[VAL_77:.*]] = %[[VAL_4]], %[[VAL_200:.*]] = %[[VAL_FALSE]], %[[VAL_78:.*]] = %[[VAL_59]]) : (index, index, i32, i1, tensor>) -> (index, index, i32, i1, tensor>) { // CHECK: %[[VAL_79:.*]] = arith.cmpi ult, %[[VAL_75]], %[[VAL_70]] : index // CHECK: %[[VAL_80:.*]] = arith.cmpi ult, %[[VAL_76]], %[[VAL_73]] : index // CHECK: %[[VAL_81:.*]] = arith.andi %[[VAL_79]], %[[VAL_80]] : i1 -// CHECK: scf.condition(%[[VAL_81]]) %[[VAL_75]], %[[VAL_76]], %[[VAL_77]], %[[VAL_78]] : index, index, i32, tensor> +// CHECK: scf.condition(%[[VAL_81]]) %[[VAL_75]], %[[VAL_76]], %[[VAL_77]], %[[VAL_200]], %[[VAL_78]] : index, index, i32, i1, tensor> // CHECK: } do { -// CHECK: ^bb0(%[[VAL_82:.*]]: index, %[[VAL_83:.*]]: index, %[[VAL_84:.*]]: i32, %[[VAL_85:.*]]: tensor>): +// CHECK: ^bb0(%[[VAL_82:.*]]: index, %[[VAL_83:.*]]: index, %[[VAL_84:.*]]: i32, %[[VAL_201:.*]]: i1, %[[VAL_85:.*]]: tensor>): // CHECK: %[[VAL_86:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_82]]] : memref // CHECK: %[[VAL_87:.*]] = memref.load %[[VAL_20]]{{\[}}%[[VAL_83]]] : memref // CHECK: %[[VAL_88:.*]] = arith.cmpi ult, %[[VAL_87]], %[[VAL_86]] : index @@ -230,14 +232,14 @@ func.func @sparse_truly_dynamic(%arga: tensor<10x20xf32, #CSR>) -> tensor<10x20x // CHECK: %[[VAL_90:.*]] = arith.cmpi eq, %[[VAL_86]], %[[VAL_89]] : index // CHECK: %[[VAL_91:.*]] = arith.cmpi eq, %[[VAL_87]], %[[VAL_89]] : index // CHECK: %[[VAL_92:.*]] = arith.andi %[[VAL_90]], %[[VAL_91]] : i1 -// CHECK: %[[VAL_93:.*]]:2 = scf.if %[[VAL_92]] -> (i32, tensor>) { +// CHECK: %[[VAL_93:.*]]:3 = scf.if %[[VAL_92]] -> (i32, i1, tensor>) { // CHECK: %[[VAL_94:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_82]]] : memref // CHECK: %[[VAL_95:.*]] = memref.load %[[VAL_21]]{{\[}}%[[VAL_83]]] : memref // CHECK: %[[VAL_96:.*]] = arith.muli %[[VAL_94]], %[[VAL_95]] : i32 // CHECK: %[[VAL_97:.*]] = arith.addi %[[VAL_84]], %[[VAL_96]] : i32 -// CHECK: scf.yield %[[VAL_97]], %[[VAL_85]] : i32, tensor> +// CHECK: scf.yield %[[VAL_97]], %[[VAL_TRUE]], %[[VAL_85]] : i32, i1, tensor> // CHECK: } else { -// CHECK: scf.yield %[[VAL_84]], %[[VAL_85]] : i32, tensor> +// CHECK: scf.yield %[[VAL_84]], %[[VAL_201]], %[[VAL_85]] : i32, i1, tensor> // CHECK: } // CHECK: %[[VAL_98:.*]] = arith.cmpi eq, %[[VAL_86]], %[[VAL_89]] : index // CHECK: %[[VAL_99:.*]] = arith.addi %[[VAL_82]], %[[VAL_3]] : index @@ -245,10 +247,15 @@ func.func @sparse_truly_dynamic(%arga: tensor<10x20xf32, #CSR>) -> tensor<10x20x // CHECK: %[[VAL_101:.*]] = arith.cmpi eq, %[[VAL_87]], %[[VAL_89]] : index // CHECK: %[[VAL_102:.*]] = arith.addi %[[VAL_83]], %[[VAL_3]] : index // CHECK: %[[VAL_103:.*]] = arith.select %[[VAL_101]], %[[VAL_102]], %[[VAL_83]] : index -// CHECK: scf.yield %[[VAL_100]], %[[VAL_103]], %[[VAL_104:.*]]#0, %[[VAL_104]]#1 : index, index, i32, tensor> +// CHECK: scf.yield %[[VAL_100]], %[[VAL_103]], %[[VAL_104:.*]]#0, %[[VAL_104]]#1, %[[VAL_104]]#2 : index, index, i32, i1, tensor> // CHECK: } -// CHECK: %[[VAL_105:.*]] = sparse_tensor.insert %[[VAL_106:.*]]#2 into %[[VAL_106]]#3{{\[}}%[[VAL_39]], %[[VAL_63]]] : tensor> -// CHECK: scf.yield %[[VAL_105]] : tensor> +// CHECK: %[[VAL_202:.*]] = scf.if %[[VAL_74]]#3 -> (tensor>) { +// CHECK: %[[VAL_105:.*]] = sparse_tensor.insert %[[VAL_74]]#2 into %[[VAL_74]]#4{{\[}}%[[VAL_39]], %[[VAL_63]]] : tensor> +// CHECK: scf.yield %[[VAL_105]] : tensor> +// CHECK: } else { +// CHECK: scf.yield %[[VAL_74]]#4 : tensor> +// CHECK: } +// CHECK: scf.yield %[[VAL_202]] : tensor> // CHECK: } else { // CHECK: scf.yield %[[VAL_59]] : tensor> // CHECK: } diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reduce_custom.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reduce_custom.mlir index f743954..697233b 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reduce_custom.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reduce_custom.mlir @@ -172,9 +172,8 @@ module { // CHECK-NEXT: ( ( 6, 0, 0, 0, 0 ), ( 0, 0, 0, 5, 0 ), ( 4, 0, 0, 3, 0 ), ( 0, 2, 0, 0, 0 ), ( 0, 11, 0, 0, 0 ) ) // CHECK-NEXT: ( 7, 7, 9, 8, 7, 7, 12, 11, 11, 0, 0, 0, 0, 0, 0, 0 ) // CHECK-NEXT: ( ( 7, 0, 0, 7, 0 ), ( 9, 0, 0, 0, 0 ), ( 8, 7, 0, 7, 0 ), ( 12, 11, 0, 11, 0 ), ( 0, 0, 0, 0, 0 ) ) - // TODO: Update once identity values are no longer inserted for non-overlapping dot product - // CHECK-NEXT: ( 7, inf, inf, 7, 9, inf, inf, inf, 8, 7, inf, 7, 12, 11, inf, 11 ) - // CHECK-NEXT: ( ( 7, inf, inf, 7, 0 ), ( 9, inf, inf, inf, 0 ), ( 8, 7, inf, 7, 0 ), ( 12, 11, inf, 11, 0 ), ( 0, 0, 0, 0, 0 ) ) + // CHECK-NEXT: ( 7, 7, 9, 8, 7, 7, 12, 11, 11, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( ( 7, 0, 0, 7, 0 ), ( 9, 0, 0, 0, 0 ), ( 8, 7, 0, 7, 0 ), ( 12, 11, 0, 11, 0 ), ( 0, 0, 0, 0, 0 ) ) // call @dump_mat(%sm1) : (tensor) -> () call @dump_mat(%sm2r) : (tensor) -> ()