}
def SparseTensor_PackOp : SparseTensor_Op<"pack", [Pure]>,
- Arguments<(ins 1DTensorOf<[AnyType]>:$values,
- 2DTensorOf<[AnySignlessIntegerOrIndex]>:$coordinates)>,
+ Arguments<(ins TensorOf<[AnyType]>:$values,
+ TensorOf<[AnySignlessIntegerOrIndex]>:$coordinates,
+ OptionalAttr<IndexAttr>:$batched_lvls)>,
Results<(outs AnySparseTensor: $result)> {
let summary = "Returns a sparse tensor from the given (values, coordinates) pair";
supplies the level-coords for each element in `values`.
- `values : tensor<NSE x V>`
supplies the corresponding values for each entry in `coordinates`.
+ - `batched_lvls : optional<index>`
+ supplies the number of leading levels that are batched.
This operation can be used to materialize a sparse tensor from external
sources; e.g., when passing two numpy arrays from Python.
// of 3x4 matrix |0.0, 0.0, 2.2, 3.3|
// |0.0, 0.0, 0.0, 0.0|
```
+
+ If `batched_lvls` is provided, the operation materializes a batched sparse
+ tensor: the leading `batched_lvls` dimensions of `values` and `coordinates`
+ are interpreted as batch dimensions and must agree in size.
+ Example:
+
+ ```mlir
+ %values = arith.constant dense<[[ 1.1, 2.2, 3.3 ],
+ [ 1.2, 2.3, 0.0 ]]> : tensor<2x3xf64>
+ %coordinates = arith.constant dense<[[ [0], [1], [2] ],
+ [ [1], [2], [3] ]]> : tensor<2x3x1xindex>
+ %st = sparse_tensor.pack %values, %coordinates batched_lvls=1
+ : tensor<2x3xf64>, tensor<2x3x1xindex> to tensor<2x4xf64, #BCOO>
+ // yields BCOO format |1.1, 2.2, 3.3, 0.0|
+ // of 2x4 matrix |0.0, 1.2, 2.3, 0.0|
+ ```
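+
+ The `#BCOO` type above refers to a batched COO encoding; for example, it
+ could be defined along these lines (a sketch; the exact level-type
+ spelling may differ across dialect versions):
+
+ ```mlir
+ #BCOO = #sparse_tensor.encoding<{
+ dimLevelType = [ "dense", "compressed-hi-nu", "singleton" ]
+ }>
+ ```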
+ }];
+
+ let extraClassDeclaration = [{
+ /// Returns the number of leading levels that are batched.
+ unsigned getNumBatchedLvls();
}];
let assemblyFormat =
- "$values `,` $coordinates attr-dict"
+ "$values `,` $coordinates (`batched_lvls` `=` $batched_lvls^)? attr-dict"
"`:` type($values) `,` type($coordinates) `to` type($result)";
let hasVerifier = 1;
/// Returns true iff the given sparse tensor encoding attribute has a trailing
/// COO region starting at the given level.
-static bool isCOOType(SparseTensorEncodingAttr enc, Level startLvl,
- bool isUnique) {
- if (!enc || !enc.isCompressedLvl(startLvl))
+bool mlir::sparse_tensor::isCOOType(SparseTensorEncodingAttr enc,
+ Level startLvl, bool isUnique) {
+ if (!enc ||
+ !(enc.isCompressedLvl(startLvl) || enc.isCompressedWithHiLvl(startLvl)))
return false;
const Level lvlRank = enc.getLvlRank();
for (Level l = startLvl + 1; l < lvlRank; ++l)
static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
SparseTensorType tensorTp,
RankedTensorType valuesTp,
- RankedTensorType coordinatesTp) {
+ RankedTensorType coordinatesTp,
+ IntegerAttr batchedLvls) {
+ unsigned nBatched = batchedLvls ? batchedLvls.getValue().getZExtValue() : 0;
if (requiresStaticShape && !tensorTp.hasStaticDimShape())
return op->emitError("the sparse-tensor must have static shape");
if (!tensorTp.hasEncoding())
return op->emitError("the sparse-tensor must have an encoding attribute");
if (!tensorTp.isIdentity())
return op->emitError("the sparse-tensor must have the identity mapping");
- if (!isUniqueCOOType(tensorTp))
+ if (!isCOOType(tensorTp.getEncoding(), nBatched, true))
return op->emitError("the sparse-tensor must have a COO type");
- if (coordinatesTp.getRank() != 2)
- return op->emitError("coordinates must have rank 2");
+ if (coordinatesTp.getRank() != 2 + nBatched)
+ return op->emitError("coordinates must have rank 2 + batched_lvls");
if (requiresStaticShape && !coordinatesTp.hasStaticShape())
return op->emitError("coordinates must have static shape");
if (coordinatesTp.getElementType() != tensorTp.getCrdType())
return op->emitError("input/output coordinate-types don't match");
- if (valuesTp.getRank() != 1)
- return op->emitError("values must have rank 1");
+ if (valuesTp.getRank() != 1 + nBatched)
+ return op->emitError("values must have rank 1 + batched_lvls");
if (requiresStaticShape && !valuesTp.hasStaticShape())
return op->emitError("values must have static shape");
if (valuesTp.getElementType() != tensorTp.getElementType())
return op->emitError("input/output element-types don't match");
- const auto valuesNSE = valuesTp.getShape()[0];
- const auto coordsNSE = coordinatesTp.getShape()[0];
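+ // Batched level sizes of values/coordinates must be static and agree
+ // pairwise.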
+ for (unsigned i = 0; i < nBatched; i++) {
+ const auto valBatch = valuesTp.getShape()[i];
+ const auto crdBatch = coordinatesTp.getShape()[i];
+ if (ShapedType::isDynamic(valBatch) || ShapedType::isDynamic(crdBatch) ||
+ crdBatch != valBatch) {
+ return op->emitError(
+ "values/coordinates batched level sizes don't match statically");
+ }
+ }
+
+ const auto valuesNSE = valuesTp.getShape()[nBatched];
+ const auto coordsNSE = coordinatesTp.getShape()[nBatched];
if (!ShapedType::isDynamic(valuesNSE) && !ShapedType::isDynamic(coordsNSE) &&
valuesNSE != coordsNSE)
return op->emitError("values/coordinates number-of-elements don't match");
// NOTE: We use `getLvlRank` because the `coordinatesTp` is for
// level-coordinates (cf., the op documentation).
- const DynSize coordsRank = coordinatesTp.getShape()[1];
+ const DynSize coordsRank = coordinatesTp.getShape()[1 + nBatched];
const Level tensorRank = tensorTp.getLvlRank();
// FIXME: replace the `operator!=` with our backported `safelyNE`.
if (!ShapedType::isDynamic(coordsRank) &&
- coordsRank != static_cast<DynSize>(tensorRank))
+ coordsRank != static_cast<DynSize>(tensorRank) - nBatched)
return op->emitError("input/output level-ranks don't match");
return success();
const auto valuesTp = getRankedTensorType(getValues());
const auto coordinatesTp = getRankedTensorType(getCoordinates());
const auto resTp = getSparseTensorType(getResult());
- return verifyPackUnPack(*this, true, resTp, valuesTp, coordinatesTp);
+ return verifyPackUnPack(*this, true, resTp, valuesTp, coordinatesTp,
+ getBatchedLvlsAttr());
+}
+
+unsigned PackOp::getNumBatchedLvls() {
+ return getBatchedLvls().has_value() ? getBatchedLvls()->getZExtValue() : 0;
}
LogicalResult UnpackOp::verify() {
const auto valuesTp = getRankedTensorType(getValues());
const auto coordinatesTp = getRankedTensorType(getCoordinates());
const auto srcTp = getSparseTensorType(getTensor());
- return verifyPackUnPack(*this, false, srcTp, valuesTp, coordinatesTp);
+ return verifyPackUnPack(*this, false, srcTp, valuesTp, coordinatesTp,
+ nullptr);
}
LogicalResult ConvertOp::verify() {
}
};
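+// Generates code that fills the position buffer of a compressed-with-hi level
+// for a batched COO tensor: for each (linearized) batch `b`, the stored
+// entries occupy the range [b * nse, b * nse + len), where `len` is `nse`
+// minus the number of trailing zero values in that batch (trailing zeros are
+// treated as padding). The bounds are stored at posMemRef[2 * b] and
+// posMemRef[2 * b + 1], respectively.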
+static void populateCompressedWithHiPosArray(OpBuilder &builder, Location loc,
+ ArrayRef<unsigned> batchDimSzs,
+ Value posMemRef, unsigned nse,
+ PackOp op) {
+ SmallVector<Value> lbs, ubs, steps;
+ Value c0 = constantIndex(builder, loc, 0);
+ Value c1 = constantIndex(builder, loc, 1);
+ Value c2 = constantIndex(builder, loc, 2);
+ for (unsigned dimSz : batchDimSzs) {
+ lbs.push_back(c0);
+ ubs.push_back(constantIndex(builder, loc, dimSz));
+ steps.push_back(c1);
+ }
+ auto tensorType = op.getValues().getType();
+ auto memrefType =
+ MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+ Value batV = builder.create<bufferization::ToMemrefOp>(loc, memrefType,
+ op.getValues());
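+ // Generates a loop nest that iterates over all batch dimensions.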
+ scf::buildLoopNest(
+ builder, loc, lbs, ubs, steps,
+ [&ubs, c0, c1, c2, nse, batV, posMemRef](OpBuilder &builder, Location loc,
+ ValueRange ivs) {
+ // Linearize the batch index variables into a single flat batch coordinate.
+ Value crd = constantIndex(builder, loc, 0);
+ for (unsigned i = 0, e = ivs.size(); i < e; i++) {
+ crd = builder.create<arith::AddIOp>(loc, crd, ivs[i]);
+ if (i != ivs.size() - 1)
+ crd = builder.create<arith::MulIOp>(loc, crd, ubs[i + 1]);
+ }
+ Value len = constantIndex(builder, loc, nse);
+ Value pLo = builder.create<arith::MulIOp>(loc, crd, len);
+ SmallVector<Value> indices(ivs.begin(), ivs.end());
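+ // Scans backwards over this batch to trim trailing zeros: decrements
+ // `len` while the last stored value equals zero (or until len == 0).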
+ auto whileOp = builder.create<scf::WhileOp>(
+ loc, TypeRange{builder.getIndexType()}, ValueRange{len},
+ [&indices, c0, c1, batV](OpBuilder &builder, Location loc,
+ ValueRange vs) {
+ Value curLen = vs.front();
+ Value pred = builder.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::eq, curLen, c0);
+ auto ifOp = builder.create<scf::IfOp>(
+ loc, TypeRange{builder.getI1Type()}, pred, true);
+ {
+ OpBuilder::InsertionGuard guard(builder);
+ // if len == 0.
+ builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+ builder.create<scf::YieldOp>(loc,
+ constantI1(builder, loc, false));
+ // Else branch.
+ builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+ indices.push_back(
+ builder.create<arith::SubIOp>(loc, curLen, c1));
+ Value val = builder.create<memref::LoadOp>(loc, batV, indices);
+ indices.pop_back();
+ Value cont = builder.create<arith::CmpFOp>(
+ loc, arith::CmpFPredicate::OEQ, val,
+ constantZero(builder, loc, val.getType()));
+ builder.create<scf::YieldOp>(loc, cont);
+ }
+ builder.create<scf::ConditionOp>(loc, ifOp.getResults()[0], vs);
+ },
+ [c1](OpBuilder &builder, Location loc, ValueRange vs) {
+ // len --;
+ Value nxLen = builder.create<arith::SubIOp>(loc, vs.front(), c1);
+ builder.create<scf::YieldOp>(loc, nxLen);
+ });
+ len = whileOp.getResults()[0];
+ Value pHi = builder.create<arith::AddIOp>(loc, pLo, len);
+ // Stores position lower bound.
+ Value idx = builder.create<arith::MulIOp>(loc, crd, c2);
+ genStore(builder, loc, pLo, posMemRef, idx);
+ // Stores position upper bound.
+ idx = builder.create<arith::AddIOp>(loc, idx, c1);
+ genStore(builder, loc, pHi, posMemRef, idx);
+ });
+}
+
struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(PackOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
-
+ const unsigned batchedLvls = op.getNumBatchedLvls();
+ unsigned nse = op.getValues().getType().getDimSize(batchedLvls);
const auto stt = getSparseTensorType(op.getResult());
- assert(isUniqueCOOType(stt));
+ assert(isCOOType(stt.getEncoding(), batchedLvls, true));
+
+ unsigned batchedCount = 1;
+ SmallVector<unsigned> batchDimSzs;
+ batchDimSzs.reserve(batchedLvls);
+ for (unsigned i = 0; i < batchedLvls; i++) {
+ // Should already be guaranteed by verifier.
+ assert(!ShapedType::isDynamic(stt.getDimShape()[i]));
+ batchedCount *= stt.getDimShape()[i];
+ batchDimSzs.push_back(stt.getDimShape()[i]);
+ }
SmallVector<Value> fields;
Location loc = op.getLoc();
foreachFieldAndTypeInSparseTensor(
stt,
- [&rewriter, &fields, &op, stt,
+ [&rewriter, &fields, &op, &batchDimSzs, nse, batchedCount, stt,
loc](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
- Level /*lvl*/, DimLevelType /*dlt*/) -> bool {
+ Level /*lvl*/, DimLevelType dlt) -> bool {
assert(fields.size() == fIdx);
Value field;
switch (fKind) {
// By creating a constant value for it, we avoid the complexity of
// memory management.
const auto posTp = stt.getPosType();
- auto tensorType = RankedTensorType::get({2}, posTp);
- auto memrefType = MemRefType::get(tensorType.getShape(),
- tensorType.getElementType());
- auto cstPtr = rewriter.create<arith::ConstantOp>(
- loc, tensorType,
- DenseElementsAttr::get(
- tensorType,
- ArrayRef<Attribute>{
- IntegerAttr::get(posTp, 0),
- IntegerAttr::get(
- posTp, op.getValues().getType().getShape()[0])}));
- field = rewriter.create<bufferization::ToMemrefOp>(loc, memrefType,
- cstPtr);
+ if (isCompressedDLT(dlt)) {
+ RankedTensorType tensorType;
+ SmallVector<Attribute> posAttr;
+ tensorType = RankedTensorType::get({batchedCount + 1}, posTp);
+ posAttr.push_back(IntegerAttr::get(posTp, 0));
+ for (unsigned i = 0; i < batchedCount; i++) {
+ // The position memref will hold the values
+ // [0, nse, 2 * nse, ..., batchedCount * nse].
+ posAttr.push_back(IntegerAttr::get(posTp, nse * (i + 1)));
+ }
+ MemRefType memrefType = MemRefType::get(
+ tensorType.getShape(), tensorType.getElementType());
+ auto cstPtr = rewriter.create<arith::ConstantOp>(
+ loc, tensorType, DenseElementsAttr::get(tensorType, posAttr));
+ field = rewriter.create<bufferization::ToMemrefOp>(
+ loc, memrefType, cstPtr);
+ } else {
+ assert(isCompressedWithHiDLT(dlt) && !batchDimSzs.empty());
+ MemRefType posMemTp = MemRefType::get({batchedCount * 2}, posTp);
+ field = rewriter.create<memref::AllocaOp>(loc, posMemTp);
+ populateCompressedWithHiPosArray(rewriter, loc, batchDimSzs,
+ field, nse, op);
+ }
break;
}
case SparseTensorFieldKind::CrdMemRef: {
auto tensorType = op.getCoordinates().getType();
auto memrefType = MemRefType::get(tensorType.getShape(),
tensorType.getElementType());
- auto crdMemRef = rewriter.create<bufferization::ToMemrefOp>(
+ field = rewriter.create<bufferization::ToMemrefOp>(
op->getLoc(), memrefType, op.getCoordinates());
- ReassociationIndices reassociation;
- for (int i = 0, e = tensorType.getRank(); i < e; i++)
- reassociation.push_back(i);
- // Flattened the indices buffer to rank 1.
- field = rewriter.create<memref::CollapseShapeOp>(
- loc, crdMemRef, ArrayRef<ReassociationIndices>(reassociation));
break;
}
case SparseTensorFieldKind::ValMemRef: {
}
assert(field);
+ if (auto memrefTp = field.getType().dyn_cast<MemRefType>();
+ memrefTp && memrefTp.getRank() > 1) {
+ ReassociationIndices reassociation;
+ for (int i = 0, e = memrefTp.getRank(); i < e; i++)
+ reassociation.push_back(i);
+ // Flattens the buffer to rank 1. Due to batching, the values buffer may
+ // need to be collapsed as well, not just the coordinates buffer.
+ field = rewriter.create<memref::CollapseShapeOp>(
+ loc, field, ArrayRef<ReassociationIndices>(reassociation));
+ }
+
if (fType != field.getType())
field = rewriter.create<memref::CastOp>(loc, fType, field);
fields.push_back(field);