From 791935037b0b3b211bee54fae694aeb5b7b75125 Mon Sep 17 00:00:00 2001 From: Jim Kitchen Date: Mon, 3 Oct 2022 14:34:53 -0500 Subject: [PATCH] [mlir][sparse] Implement sparse_tensor.select The region within sparse_tensor.select is used as the runtime criteria for whether to keep the existing value in the sparse tensor. While the sparse element is provided to the comparison, indices may also be used to decide on whether to keep the original value. This allows, for example, to only keep the upper triangle of a matrix. Reviewed by: aartbik Differential Revision: https://reviews.llvm.org/D134761 --- .../mlir/Dialect/SparseTensor/Utils/Merger.h | 5 +- .../SparseTensor/Transforms/Sparsification.cpp | 17 ++- mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp | 20 ++- .../Dialect/SparseTensor/CPU/sparse_select.mlir | 148 +++++++++++++++++++++ mlir/unittests/Dialect/SparseTensor/MergerTest.cpp | 1 + 5 files changed, 184 insertions(+), 7 deletions(-) create mode 100644 mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_select.mlir diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h index f9be476..a376d9a 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h @@ -76,6 +76,7 @@ enum Kind { kBitCast, kBinaryBranch, // semiring unary branch created from a binary op kUnary, // semiring unary op + kSelect, // custom selection criteria // Binary operations. kMulF, kMulC, @@ -129,8 +130,8 @@ struct TensorExp { /// this field may be used to cache "hoisted" loop invariant tensor loads. Value val; - /// Code blocks used by semirings. For the case of kUnary, kBinary, and - /// kReduce, this holds the original operation with all regions. For + /// Code blocks used by semirings. For the case of kUnary, kBinary, kReduce, + /// and kSelect, this holds the original operation with all regions. 
For /// kBinaryBranch, this holds the YieldOp for the left or right half /// to be merged into a nested scf loop. Operation *op; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp index 88bd885..7cd9f7f 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -878,6 +878,15 @@ static void genTensorStore(Merger &merger, CodeGen &codegen, OpBuilder &builder, // Only unary and binary are allowed to return uninitialized rhs // to indicate missing output. assert(merger.exp(exp).kind == kUnary || merger.exp(exp).kind == kBinary); + } else if (merger.exp(exp).kind == kSelect) { + scf::IfOp ifOp = builder.create<scf::IfOp>(loc, rhs); + builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); + // Existing value was preserved to be used here. + assert(merger.exp(exp).val); + Value v0 = merger.exp(exp).val; + genInsertionStore(codegen, builder, op, t, v0); + merger.exp(exp).val = Value(); + builder.setInsertionPointAfter(ifOp); } else { genInsertionStore(codegen, builder, op, t, rhs); } @@ -1037,9 +1046,15 @@ static Value genExp(Merger &merger, CodeGen &codegen, RewriterBase &rewriter, if (ee && (merger.exp(exp).kind == Kind::kUnary || merger.exp(exp).kind == Kind::kBinary || merger.exp(exp).kind == Kind::kBinaryBranch || - merger.exp(exp).kind == Kind::kReduce)) + merger.exp(exp).kind == Kind::kReduce || + merger.exp(exp).kind == Kind::kSelect)) ee = relinkBranch(codegen, rewriter, ee.getParentBlock(), ee, ldx); + if (merger.exp(exp).kind == kSelect) { + assert(!merger.exp(exp).val); + merger.exp(exp).val = v0; // Preserve value for later use.
+ } + if (merger.exp(exp).kind == Kind::kReduce) { assert(codegen.redCustom != -1u); codegen.redCustom = -1u; diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp index 7f132f9..3791971 100644 --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -78,6 +78,7 @@ TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o) children.e1 = y; break; case kBinaryBranch: + case kSelect: assert(x != -1u && y == -1u && !v && o); children.e0 = x; children.e1 = y; @@ -212,7 +213,7 @@ unsigned Merger::takeCombi(Kind kind, unsigned s0, unsigned s1, Operation *orig, } unsigned Merger::mapSet(Kind kind, unsigned s0, Value v, Operation *op) { - assert(kAbsF <= kind && kind <= kUnary); + assert(kAbsF <= kind && kind <= kSelect); unsigned s = addSet(); for (unsigned p : latSets[s0]) { unsigned e = addExp(kind, latPoints[p].exp, v, op); @@ -265,9 +266,8 @@ BitVector Merger::simplifyCond(unsigned s0, unsigned p0) { BitVector simple = latPoints[p0].bits; bool reset = isSingleton && hasAnySparse(simple); for (unsigned b = 0, be = simple.size(); b < be; b++) { - if (simple[b] && - (!isDimLevelType(b, DimLvlType::kCompressed) && - !isDimLevelType(b, DimLvlType::kSingleton))) { + if (simple[b] && (!isDimLevelType(b, DimLvlType::kCompressed) && + !isDimLevelType(b, DimLvlType::kSingleton))) { if (reset) simple.reset(b); reset = true; @@ -338,6 +338,7 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const { return isSingleCondition(t, tensorExps[e].children.e0); case kBinaryBranch: case kUnary: + case kSelect: return false; // Binary operations. case kDivF: // note: x / c only @@ -449,6 +450,8 @@ static const char *kindToOpSymbol(Kind kind) { return "binary_branch"; case kUnary: return "unary"; + case kSelect: + return "select"; // Binary operations. 
case kMulF: case kMulC: @@ -537,6 +540,7 @@ void Merger::dumpExp(unsigned e) const { case kBitCast: case kBinaryBranch: case kUnary: + case kSelect: llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " "; dumpExp(tensorExps[e].children.e0); break; @@ -684,6 +688,7 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) { return mapSet(kind, buildLattices(tensorExps[e].children.e0, i), tensorExps[e].val); case kBinaryBranch: + case kSelect: // The left or right half of a binary operation which has already // been split into separate operations for each region. return mapSet(kind, buildLattices(tensorExps[e].children.e0, i), Value(), @@ -978,6 +983,10 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) { isAdmissableBranch(unop, unop.getAbsentRegion())) return addExp(kUnary, e, Value(), def); } + if (auto selop = dyn_cast<sparse_tensor::SelectOp>(def)) { + if (isAdmissableBranch(selop, selop.getRegion())) + return addExp(kSelect, e, Value(), def); + } } } // Construct binary operations if subexpressions can be built.
@@ -1228,6 +1237,9 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e, *tensorExps[e].op->getBlock()->getParent(), {v0}); case kUnary: return buildUnaryPresent(rewriter, loc, tensorExps[e].op, v0); + case kSelect: + return insertYieldOp(rewriter, loc, + cast<sparse_tensor::SelectOp>(tensorExps[e].op).getRegion(), {v0}); case kBinary: return buildBinaryOverlap(rewriter, loc, tensorExps[e].op, v0, v1); case kReduce: { diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_select.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_select.mlir new file mode 100644 index 0000000..bbe94e9 --- /dev/null +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_select.mlir @@ -0,0 +1,148 @@ +// RUN: mlir-opt %s --sparse-compiler | \ +// RUN: mlir-cpu-runner \ +// RUN: -e entry -entry-point-result=void \ +// RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s + +#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}> +#CSR = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}> +#CSC = #sparse_tensor.encoding<{ + dimLevelType = [ "dense", "compressed" ], + dimOrdering = affine_map<(i,j) -> (j,i)> +}> + +// +// Traits for tensor operations.
+// +#trait_vec_select = { + indexing_maps = [ + affine_map<(i) -> (i)>, // A + affine_map<(i) -> (i)> // C (out) + ], + iterator_types = ["parallel"] +} + +#trait_mat_select = { + indexing_maps = [ + affine_map<(i,j) -> (i,j)>, // A (in) + affine_map<(i,j) -> (i,j)> // X (out) + ], + iterator_types = ["parallel", "parallel"] +} + +module { + func.func @vecSelect(%arga: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> { + %c0 = arith.constant 0 : index + %cf1 = arith.constant 1.0 : f64 + %d0 = tensor.dim %arga, %c0 : tensor<?xf64, #SparseVector> + %xv = bufferization.alloc_tensor(%d0): tensor<?xf64, #SparseVector> + %0 = linalg.generic #trait_vec_select + ins(%arga: tensor<?xf64, #SparseVector>) + outs(%xv: tensor<?xf64, #SparseVector>) { + ^bb(%a: f64, %b: f64): + %1 = sparse_tensor.select %a : f64 { + ^bb0(%x: f64): + %keep = arith.cmpf "oge", %x, %cf1 : f64 + sparse_tensor.yield %keep : i1 + } + linalg.yield %1 : f64 + } -> tensor<?xf64, #SparseVector> + return %0 : tensor<?xf64, #SparseVector> + } + + func.func @matUpperTriangle(%arga: tensor<?x?xf64, #CSR>) -> tensor<?x?xf64, #CSR> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %d0 = tensor.dim %arga, %c0 : tensor<?x?xf64, #CSR> + %d1 = tensor.dim %arga, %c1 : tensor<?x?xf64, #CSR> + %xv = bufferization.alloc_tensor(%d0, %d1): tensor<?x?xf64, #CSR> + %0 = linalg.generic #trait_mat_select + ins(%arga: tensor<?x?xf64, #CSR>) + outs(%xv: tensor<?x?xf64, #CSR>) { + ^bb(%a: f64, %b: f64): + %row = linalg.index 0 : index + %col = linalg.index 1 : index + %1 = sparse_tensor.select %a : f64 { + ^bb0(%x: f64): + %keep = arith.cmpi "ugt", %col, %row : index + sparse_tensor.yield %keep : i1 + } + linalg.yield %1 : f64 + } -> tensor<?x?xf64, #CSR> + return %0 : tensor<?x?xf64, #CSR> + } + + // Dumps a sparse vector of type f64. + func.func @dump_vec(%arg0: tensor<?xf64, #SparseVector>) { + // Dump the values array to verify only sparse contents are stored. + %c0 = arith.constant 0 : index + %d0 = arith.constant -1.0 : f64 + %0 = sparse_tensor.values %arg0 : tensor<?xf64, #SparseVector> to memref<?xf64> + %1 = vector.transfer_read %0[%c0], %d0: memref<?xf64>, vector<8xf64> + vector.print %1 : vector<8xf64> + // Dump the dense vector to verify structure is correct.
+ %dv = sparse_tensor.convert %arg0 : tensor<?xf64, #SparseVector> to tensor<?xf64> + %2 = vector.transfer_read %dv[%c0], %d0: tensor<?xf64>, vector<16xf64> + vector.print %2 : vector<16xf64> + return + } + + // Dump a sparse matrix. + func.func @dump_mat(%arg0: tensor<?x?xf64, #CSR>) { + // Dump the values array to verify only sparse contents are stored. + %c0 = arith.constant 0 : index + %d0 = arith.constant -1.0 : f64 + %0 = sparse_tensor.values %arg0 : tensor<?x?xf64, #CSR> to memref<?xf64> + %1 = vector.transfer_read %0[%c0], %d0: memref<?xf64>, vector<16xf64> + vector.print %1 : vector<16xf64> + %dm = sparse_tensor.convert %arg0 : tensor<?x?xf64, #CSR> to tensor<?x?xf64> + %2 = vector.transfer_read %dm[%c0, %c0], %d0: tensor<?x?xf64>, vector<5x5xf64> + vector.print %2 : vector<5x5xf64> + return + } + + // Driver method to call and verify vector kernels. + func.func @entry() { + %c0 = arith.constant 0 : index + + // Setup sparse matrices. + %v1 = arith.constant sparse< + [ [1], [3], [5], [7], [9] ], + [ 1.0, 2.0, -4.0, 0.0, 5.0 ] + > : tensor<10xf64> + %m1 = arith.constant sparse< + [ [0, 3], [1, 4], [2, 1], [2, 3], [3, 3], [3, 4], [4, 2] ], + [ 1., 2., 3., 4., 5., 6., 7.] + > : tensor<5x5xf64> + %sv1 = sparse_tensor.convert %v1 : tensor<10xf64> to tensor<?xf64, #SparseVector> + %sm1 = sparse_tensor.convert %m1 : tensor<5x5xf64> to tensor<?x?xf64, #CSR> + + // Call sparse matrix kernels. + %1 = call @vecSelect(%sv1) : (tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> + %2 = call @matUpperTriangle(%sm1) : (tensor<?x?xf64, #CSR>) -> tensor<?x?xf64, #CSR> + + // + // Verify the results.
+ // + // CHECK: ( 1, 2, -4, 0, 5, -1, -1, -1 ) + // CHECK-NEXT: ( 0, 1, 0, 2, 0, -4, 0, 0, 0, 5, -1, -1, -1, -1, -1, -1 ) + // CHECK-NEXT: ( 1, 2, 3, 4, 5, 6, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1 ) + // CHECK-NEXT: ( ( 0, 0, 0, 1, 0 ), ( 0, 0, 0, 0, 2 ), ( 0, 3, 0, 4, 0 ), ( 0, 0, 0, 5, 6 ), ( 0, 0, 7, 0, 0 ) ) + // CHECK-NEXT: ( 1, 2, 5, -1, -1, -1, -1, -1 ) + // CHECK-NEXT: ( 0, 1, 0, 2, 0, 0, 0, 0, 0, 5, -1, -1, -1, -1, -1, -1 ) + // CHECK-NEXT: ( 1, 2, 4, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 ) + // CHECK-NEXT: ( ( 0, 0, 0, 1, 0 ), ( 0, 0, 0, 0, 2 ), ( 0, 0, 0, 4, 0 ), ( 0, 0, 0, 0, 6 ), ( 0, 0, 0, 0, 0 ) ) + // + call @dump_vec(%sv1) : (tensor<?xf64, #SparseVector>) -> () + call @dump_mat(%sm1) : (tensor<?x?xf64, #CSR>) -> () + call @dump_vec(%1) : (tensor<?xf64, #SparseVector>) -> () + call @dump_mat(%2) : (tensor<?x?xf64, #CSR>) -> () + + // Release the resources. + bufferization.dealloc_tensor %sv1 : tensor<?xf64, #SparseVector> + bufferization.dealloc_tensor %sm1 : tensor<?x?xf64, #CSR> + bufferization.dealloc_tensor %1 : tensor<?xf64, #SparseVector> + bufferization.dealloc_tensor %2 : tensor<?x?xf64, #CSR> + return + } +} diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp index 8d41558..c0e75dc 100644 --- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp +++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp @@ -259,6 +259,7 @@ protected: case kCIm: case kCRe: case kBitCast: + case kSelect: case kBinaryBranch: case kUnary: return compareExpression(tensorExp.children.e0, pattern->e0); -- 2.7.4