}
}
+/// Returns a memref that fits the requested length: reallocates the buffer
+/// when the requested length is larger than the current buffer, or creates
+/// a subview when it is smaller. The realloc on growth is required because
+/// a subview longer than the underlying buffer would be out of bounds.
+static Value reallocOrSubView(OpBuilder &builder, Location loc, int64_t len,
+                              Value buffer) {
+  MemRefType memTp = getMemRefType(buffer);
+  auto retTp = MemRefType::get(ArrayRef{len}, memTp.getElementType());
+
+  Value targetLen = constantIndex(builder, loc, len);
+  Value bufferLen = linalg::createOrFoldDimOp(builder, loc, buffer, 0);
+  // Reallocate only when the target length exceeds the buffer length
+  // (`ugt`, not `ult`: a subview would be out of bounds in that case).
+  Value reallocP = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
+                                                 targetLen, bufferLen);
+  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, retTp, reallocP, true);
+  // If targetLen > bufferLen, reallocate to get enough space to return.
+  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+  Value reallocBuf = builder.create<memref::ReallocOp>(loc, retTp, buffer);
+  builder.create<scf::YieldOp>(loc, reallocBuf);
+  // Else, return a subview to fit the size.
+  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+  Value subViewBuf = builder.create<memref::SubViewOp>(
+      loc, retTp, buffer, /*offset=*/ArrayRef<int64_t>{0},
+      /*size=*/ArrayRef<int64_t>{len},
+      /*stride=*/ArrayRef<int64_t>{1});
+  builder.create<scf::YieldOp>(loc, subViewBuf);
+  // Resets insertion point.
+  builder.setInsertionPointAfter(ifOp);
+  return ifOp.getResult(0);
+}
+
//===----------------------------------------------------------------------===//
// Codegen rules.
//===----------------------------------------------------------------------===//
// to ensure that we meet their need.
TensorType dataTp = op.getData().getType();
if (dataTp.hasStaticShape()) {
- dataBuf = rewriter.create<memref::ReallocOp>(
- loc, MemRefType::get(dataTp.getShape(), dataTp.getElementType()),
- dataBuf);
+ dataBuf = reallocOrSubView(rewriter, loc, dataTp.getShape()[0], dataBuf);
}
TensorType indicesTp = op.getIndices().getType();
if (indicesTp.hasStaticShape()) {
auto len = indicesTp.getShape()[0] * indicesTp.getShape()[1];
- flatBuf = rewriter.create<memref::ReallocOp>(
- loc, MemRefType::get({len}, indicesTp.getElementType()), flatBuf);
+ flatBuf = reallocOrSubView(rewriter, loc, len, flatBuf);
}
Value idxBuf = rewriter.create<memref::ExpandShapeOp>(
// CHECK-SAME: %[[VAL_1:.*]]: memref<?xi32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<?xf64>,
// CHECK-SAME: %[[VAL_3:.*]]: !sparse_tensor.storage_specifier
-// CHECK: %[[VAL_4:.*]] = memref.realloc %[[VAL_2]] : memref<?xf64> to memref<6xf64>
-// CHECK: %[[VAL_5:.*]] = memref.realloc %[[VAL_1]] : memref<?xi32> to memref<12xi32>
-// CHECK: %[[VAL_6:.*]] = memref.expand_shape %[[VAL_5]] {{\[\[}}0, 1]] : memref<12xi32> into memref<6x2xi32>
-// CHECK: %[[VAL_7:.*]] = bufferization.to_tensor %[[VAL_4]] : memref<6xf64>
-// CHECK: %[[VAL_8:.*]] = bufferization.to_tensor %[[VAL_6]] : memref<6x2xi32>
-// CHECK: %[[VAL_9:.*]] = sparse_tensor.storage_specifier.get %[[VAL_3]] val_mem_sz
-// CHECK: %[[VAL_10:.*]] = arith.index_cast %[[VAL_9]] : i32 to index
-// CHECK: return %[[VAL_7]], %[[VAL_8]], %[[VAL_10]] : tensor<6xf64>, tensor<6x2xi32>, index
+// CHECK: %[[VAL_4:.*]] = arith.constant 6 : index
+// CHECK: %[[VAL_5:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_6:.*]] = memref.dim %[[VAL_2]], %[[VAL_5]] : memref<?xf64>
+// CHECK: %[[VAL_7:.*]] = arith.cmpi ugt, %[[VAL_4]], %[[VAL_6]] : index
+// CHECK: %[[VAL_8:.*]] = scf.if %[[VAL_7]] -> (memref<6xf64>) {
+// CHECK: %[[VAL_9:.*]] = memref.realloc %[[VAL_2]] : memref<?xf64> to memref<6xf64>
+// CHECK: scf.yield %[[VAL_9]] : memref<6xf64>
+// CHECK: } else {
+// CHECK: %[[VAL_10:.*]] = memref.subview %[[VAL_2]][0] [6] [1] : memref<?xf64> to memref<6xf64>
+// CHECK: scf.yield %[[VAL_10]] : memref<6xf64>
+// CHECK: }
+// CHECK: %[[VAL_11:.*]] = arith.constant 12 : index
+// CHECK: %[[VAL_12:.*]] = memref.dim %[[VAL_1]], %[[VAL_5]] : memref<?xi32>
+// CHECK: %[[VAL_13:.*]] = arith.cmpi ugt, %[[VAL_11]], %[[VAL_12]] : index
+// CHECK: %[[VAL_14:.*]] = scf.if %[[VAL_13]] -> (memref<12xi32>) {
+// CHECK: %[[VAL_15:.*]] = memref.realloc %[[VAL_1]] : memref<?xi32> to memref<12xi32>
+// CHECK: scf.yield %[[VAL_15]] : memref<12xi32>
+// CHECK: } else {
+// CHECK: %[[VAL_16:.*]] = memref.subview %[[VAL_1]][0] [12] [1] : memref<?xi32> to memref<12xi32>
+// CHECK: scf.yield %[[VAL_16]] : memref<12xi32>
+// CHECK: }
+// CHECK: %[[VAL_17:.*]] = memref.expand_shape %[[VAL_18:.*]] {{\[\[}}0, 1]] : memref<12xi32> into memref<6x2xi32>
+// CHECK: %[[VAL_19:.*]] = bufferization.to_tensor %[[VAL_20:.*]] : memref<6xf64>
+// CHECK: %[[VAL_21:.*]] = bufferization.to_tensor %[[VAL_17]] : memref<6x2xi32>
+// CHECK: %[[VAL_22:.*]] = sparse_tensor.storage_specifier
+// CHECK: %[[VAL_23:.*]] = arith.index_cast %[[VAL_22]] : i32 to index
+// CHECK: return %[[VAL_19]], %[[VAL_21]], %[[VAL_23]] : tensor<6xf64>, tensor<6x2xi32>, index
// CHECK: }
func.func @sparse_unpack(%sp: tensor<100x100xf64, #COO>) -> (tensor<6xf64>, tensor<6x2xi32>, index) {
%d, %i, %nnz = sparse_tensor.unpack %sp : tensor<100x100xf64, #COO>