}
};
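+/// Sparse codegen rule for the pack operator.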
+struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(PackOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    auto rtp = op.getResult().getType().cast<RankedTensorType>();
+    assert(isUniqueCOOType(rtp));
+
+    SmallVector<Value> fields;
+    Location loc = op.getLoc();
+
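+    // Construct a storage field for each piece of the sparse tensor layout,
+    // reusing the operands of the pack operation where possible.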
+    foreachFieldAndTypeInSparseTensor(
+        rtp,
+        [&rewriter, &fields, &op, rtp,
+         loc](Type fType, unsigned fIdx, SparseTensorFieldKind fKind,
+              unsigned /*dim*/, DimLevelType /*dlt*/) -> bool {
+          assert(fields.size() == fIdx);
+          auto enc = getSparseTensorEncoding(rtp);
+          Value field;
+          switch (fKind) {
+          case SparseTensorFieldKind::StorageSpec:
+            field = SparseTensorSpecifier::getInitValue(rewriter, loc, rtp);
+            break;
+          case SparseTensorFieldKind::PtrMemRef: {
+            // A TACO-style COO starts with a pointer buffer of [0, nnz].
+            // By creating a constant value for it, we avoid the complexity of
+            // memory management.
+            auto tensorType = RankedTensorType::get({2}, enc.getPointerType());
+            auto memrefType = MemRefType::get(tensorType.getShape(),
+                                              tensorType.getElementType());
+            auto cstPtr = rewriter.create<arith::ConstantOp>(
+                loc, tensorType,
+                DenseElementsAttr::get(
+                    tensorType,
+                    {APInt(64, 0),
+                     APInt(64, op.getData().getType().getShape()[0])}));
+            field = rewriter.create<bufferization::ToMemrefOp>(loc, memrefType,
+                                                               cstPtr);
+            break;
+          }
+          case SparseTensorFieldKind::IdxMemRef: {
+            auto tensorType = op.getIndices().getType();
+            auto memrefType = MemRefType::get(tensorType.getShape(),
+                                              tensorType.getElementType());
+            auto idxMemRef = rewriter.create<bufferization::ToMemrefOp>(
+                op->getLoc(), memrefType, op.getIndices());
+            ReassociationIndices reassociation;
+            for (int i = 0, e = tensorType.getRank(); i < e; i++)
+              reassociation.push_back(i);
+
+            // Flatten the indices buffer to rank 1.
+            field = rewriter.create<memref::CollapseShapeOp>(
+                loc, idxMemRef, ArrayRef<ReassociationIndices>(reassociation));
+            break;
+          }
+          case SparseTensorFieldKind::ValMemRef: {
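+            // The values buffer is taken directly from the data operand,
+            // converted to a memref.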
+            auto tensorType = op.getData().getType();
+            auto memrefType = MemRefType::get(tensorType.getShape(),
+                                              tensorType.getElementType());
+            field = rewriter.create<bufferization::ToMemrefOp>(
+                op->getLoc(), memrefType, op.getData());
+            break;
+          }
+          }
+
+          assert(field);
+          if (fType != field.getType())
+            field = rewriter.create<memref::CastOp>(loc, fType, field);
+          fields.push_back(field);
+          // Return true to continue the iteration.
+          return true;
+        });
+
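+    // With all fields in place, populate the storage specifier: the static
+    // dimension sizes plus the memory sizes of the ptr/idx/val buffers, where
+    // the number of entries equals the size of the data tensor.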
+    MutSparseTensorDescriptor desc(rtp, fields);
+    auto noe = linalg::createOrFoldDimOp(rewriter, loc, op.getData(), 0);
+    for (unsigned i = 0, e = rtp.getRank(); i < e; i++) {
+      int64_t dim = rtp.getShape()[i];
+      assert(!ShapedType::isDynamic(dim));
+      desc.setDimSize(rewriter, loc, i, constantIndex(rewriter, loc, dim));
+      if (i == 0)
+        desc.setPtrMemSize(rewriter, loc, i, constantIndex(rewriter, loc, 2));
+
+      desc.setIdxMemSize(rewriter, loc, i, noe);
+    }
+    desc.setValMemSize(rewriter, loc, noe);
+
+    rewriter.replaceOp(op, genTuple(rewriter, loc, desc));
+    return success();
+  }
+};
+
} // namespace
//===----------------------------------------------------------------------===//
void mlir::populateSparseTensorCodegenPatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    bool enableBufferInitialization) {
-  patterns.add<SparseReturnConverter, SparseCallConverter, SparseDimOpConverter,
-               SparseCastConverter, SparseTensorDeallocConverter,
-               SparseTensorLoadConverter, SparseExpandConverter,
-               SparseCompressConverter, SparseInsertConverter,
-               SparseToPointersConverter, SparseToIndicesConverter,
-               SparseToIndicesBufferConverter, SparseToValuesConverter,
-               SparseConvertConverter, SparseNumberOfEntriesConverter>(
-      typeConverter, patterns.getContext());
+  patterns.add<SparsePackOpConverter, SparseReturnConverter,
+               SparseCallConverter, SparseDimOpConverter, SparseCastConverter,
+               SparseTensorDeallocConverter, SparseTensorLoadConverter,
+               SparseExpandConverter, SparseCompressConverter,
+               SparseInsertConverter, SparseToPointersConverter,
+               SparseToIndicesConverter, SparseToIndicesBufferConverter,
+               SparseToValuesConverter, SparseConvertConverter,
+               SparseNumberOfEntriesConverter>(typeConverter,
+                                               patterns.getContext());
  patterns.add<SparseTensorAllocConverter>(typeConverter, patterns.getContext(),
                                           enableBufferInitialization);
}
--- /dev/null
+// RUN: mlir-opt %s --canonicalize --post-sparsification-rewrite="enable-runtime-library=false" --sparse-tensor-codegen -cse | FileCheck %s
+
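+// A unique COO encoding: a compressed(non-unique) outer level followed by a
+// singleton level.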
+#COO = #sparse_tensor.encoding<{
+  dimLevelType = ["compressed-nu", "singleton"],
+  indexBitWidth = 32
+}>
+
+// CHECK-LABEL: func.func @sparse_pack(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<6xf64>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<6x2xi32>) -> (memref<?xindex>, memref<?xi32>, memref<?xf64>,
+// CHECK: %[[VAL_2:.*]] = arith.constant dense<[0, 6]> : tensor<2xindex>
+// CHECK: %[[VAL_3:.*]] = bufferization.to_memref %[[VAL_2]] : memref<2xindex>
+// CHECK: %[[VAL_4:.*]] = memref.cast %[[VAL_3]] : memref<2xindex> to memref<?xindex>
+// CHECK: %[[VAL_5:.*]] = bufferization.to_memref %[[VAL_1]] : memref<6x2xi32>
+// CHECK: %[[VAL_6:.*]] = memref.collapse_shape %[[VAL_5]] {{\[\[}}0, 1]] : memref<6x2xi32> into memref<12xi32>
+// CHECK: %[[VAL_7:.*]] = memref.cast %[[VAL_6]] : memref<12xi32> to memref<?xi32>
+// CHECK: %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_0]] : memref<6xf64>
+// CHECK: %[[VAL_9:.*]] = memref.cast %[[VAL_8]] : memref<6xf64> to memref<?xf64>
+// CHECK: %[[VAL_10:.*]] = sparse_tensor.storage_specifier.init :
+// CHECK: %[[VAL_11:.*]] = arith.constant 6 : index
+// CHECK: %[[VAL_12:.*]] = arith.constant 100 : index
+// CHECK: %[[VAL_13:.*]] = arith.index_cast %[[VAL_12]] : index to i32
+// CHECK: %[[VAL_14:.*]] = sparse_tensor.storage_specifier.set %[[VAL_10]] dim_sz at 0 with %[[VAL_13]] : i32,
+// CHECK: %[[VAL_15:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_16:.*]] = arith.index_cast %[[VAL_15]] : index to i32
+// CHECK: %[[VAL_17:.*]] = sparse_tensor.storage_specifier.set %[[VAL_14]] ptr_mem_sz at 0 with %[[VAL_16]] : i32,
+// CHECK: %[[VAL_18:.*]] = arith.index_cast %[[VAL_11]] : index to i32
+// CHECK: %[[VAL_19:.*]] = sparse_tensor.storage_specifier.set %[[VAL_17]] idx_mem_sz at 0 with %[[VAL_18]] : i32,
+// CHECK: %[[VAL_20:.*]] = sparse_tensor.storage_specifier.set %[[VAL_19]] dim_sz at 1 with %[[VAL_13]] : i32,
+// CHECK: %[[VAL_21:.*]] = sparse_tensor.storage_specifier.set %[[VAL_20]] idx_mem_sz at 1 with %[[VAL_18]] : i32,
+// CHECK: %[[VAL_22:.*]] = sparse_tensor.storage_specifier.set %[[VAL_21]] val_mem_sz with %[[VAL_18]] : i32,
+// CHECK: return %[[VAL_4]], %[[VAL_7]], %[[VAL_9]], %[[VAL_22]] : memref<?xindex>, memref<?xi32>, memref<?xf64>,
+// CHECK: }
+func.func @sparse_pack(%data: tensor<6xf64>, %index: tensor<6x2xi32>)
+    -> tensor<100x100xf64, #COO> {
+  %0 = sparse_tensor.pack %data, %index : tensor<6xf64>, tensor<6x2xi32>
+     to tensor<100x100xf64, #COO>
+  return %0 : tensor<100x100xf64, #COO>
+}