[mlir][sparse] fix a bug in UnpackOp converter.
authorPeiming Liu <peiming@google.com>
Wed, 15 Feb 2023 02:18:54 +0000 (02:18 +0000)
committerPeiming Liu <peiming@google.com>
Wed, 15 Feb 2023 02:36:00 +0000 (02:36 +0000)
UnpackOp Converter used to create reallocOp unconditionally, but it might cause issue when the requested memory size is smaller than the actually storage.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D144065

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/test/Dialect/SparseTensor/sparse_pack.mlir

index 17319b7..797a318 100644 (file)
@@ -575,6 +575,34 @@ static void genEndInsert(OpBuilder &builder, Location loc,
   }
 }
 
+/// Returns a memref that fits the requested length (reallocates if requested
+/// length is larger, or creates a subview if it is smaller).
+static Value reallocOrSubView(OpBuilder &builder, Location loc, int64_t len,
+                              Value buffer) {
+  MemRefType memTp = getMemRefType(buffer);
+  auto retTp = MemRefType::get(ArrayRef{len}, memTp.getElementType());
+
+  Value targetLen = constantIndex(builder, loc, len);
+  Value bufferLen = linalg::createOrFoldDimOp(builder, loc, buffer, 0);
+  Value reallocP = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
+                                                 targetLen, bufferLen);
+  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, retTp, reallocP, true);
+  // If targetLen > bufferLen, reallocate to get enough sparse to return.
+  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+  Value reallocBuf = builder.create<memref::ReallocOp>(loc, retTp, buffer);
+  builder.create<scf::YieldOp>(loc, reallocBuf);
+  // Else, return a subview to fit the size.
+  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+  Value subViewBuf = builder.create<memref::SubViewOp>(
+      loc, retTp, buffer, /*offset=*/ArrayRef<int64_t>{0},
+      /*size=*/ArrayRef<int64_t>{len},
+      /*stride=*/ArrayRef<int64_t>{1});
+  builder.create<scf::YieldOp>(loc, subViewBuf);
+  // Resets insertion point.
+  builder.setInsertionPointAfter(ifOp);
+  return ifOp.getResult(0);
+}
+
 //===----------------------------------------------------------------------===//
 // Codegen rules.
 //===----------------------------------------------------------------------===//
@@ -1174,16 +1202,13 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
     // to ensure that we meet their need.
     TensorType dataTp = op.getData().getType();
     if (dataTp.hasStaticShape()) {
-      dataBuf = rewriter.create<memref::ReallocOp>(
-          loc, MemRefType::get(dataTp.getShape(), dataTp.getElementType()),
-          dataBuf);
+      dataBuf = reallocOrSubView(rewriter, loc, dataTp.getShape()[0], dataBuf);
     }
 
     TensorType indicesTp = op.getIndices().getType();
     if (indicesTp.hasStaticShape()) {
       auto len = indicesTp.getShape()[0] * indicesTp.getShape()[1];
-      flatBuf = rewriter.create<memref::ReallocOp>(
-          loc, MemRefType::get({len}, indicesTp.getElementType()), flatBuf);
+      flatBuf = reallocOrSubView(rewriter, loc, len, flatBuf);
     }
 
     Value idxBuf = rewriter.create<memref::ExpandShapeOp>(
index eeb41fe..057153a 100644 (file)
@@ -43,14 +43,33 @@ func.func @sparse_pack(%data: tensor<6xf64>, %index: tensor<6x2xi32>)
 // CHECK-SAME:      %[[VAL_1:.*]]: memref<?xi32>,
 // CHECK-SAME:      %[[VAL_2:.*]]: memref<?xf64>,
 // CHECK-SAME:      %[[VAL_3:.*]]: !sparse_tensor.storage_specifier
-// CHECK:           %[[VAL_4:.*]] = memref.realloc %[[VAL_2]] : memref<?xf64> to memref<6xf64>
-// CHECK:           %[[VAL_5:.*]] = memref.realloc %[[VAL_1]] : memref<?xi32> to memref<12xi32>
-// CHECK:           %[[VAL_6:.*]] = memref.expand_shape %[[VAL_5]] {{\[\[}}0, 1]] : memref<12xi32> into memref<6x2xi32>
-// CHECK:           %[[VAL_7:.*]] = bufferization.to_tensor %[[VAL_4]] : memref<6xf64>
-// CHECK:           %[[VAL_8:.*]] = bufferization.to_tensor %[[VAL_6]] : memref<6x2xi32>
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.storage_specifier.get %[[VAL_3]]  val_mem_sz
-// CHECK:           %[[VAL_10:.*]] = arith.index_cast %[[VAL_9]] : i32 to index
-// CHECK:           return %[[VAL_7]], %[[VAL_8]], %[[VAL_10]] : tensor<6xf64>, tensor<6x2xi32>, index
+// CHECK:           %[[VAL_4:.*]] = arith.constant 6 : index
+// CHECK:           %[[VAL_5:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_6:.*]] = memref.dim %[[VAL_2]], %[[VAL_5]] : memref<?xf64>
+// CHECK:           %[[VAL_7:.*]] = arith.cmpi ult, %[[VAL_4]], %[[VAL_6]] : index
+// CHECK:           %[[VAL_8:.*]] = scf.if %[[VAL_7]] -> (memref<6xf64>) {
+// CHECK:             %[[VAL_9:.*]] = memref.realloc %[[VAL_2]] : memref<?xf64> to memref<6xf64>
+// CHECK:             scf.yield %[[VAL_9]] : memref<6xf64>
+// CHECK:           } else {
+// CHECK:             %[[VAL_10:.*]] = memref.subview %[[VAL_2]][0] [6] [1] : memref<?xf64> to memref<6xf64>
+// CHECK:             scf.yield %[[VAL_10]] : memref<6xf64>
+// CHECK:           }
+// CHECK:           %[[VAL_11:.*]] = arith.constant 12 : index
+// CHECK:           %[[VAL_12:.*]] = memref.dim %[[VAL_1]], %[[VAL_5]] : memref<?xi32>
+// CHECK:           %[[VAL_13:.*]] = arith.cmpi ult, %[[VAL_11]], %[[VAL_12]] : index
+// CHECK:           %[[VAL_14:.*]] = scf.if %[[VAL_13]] -> (memref<12xi32>) {
+// CHECK:             %[[VAL_15:.*]] = memref.realloc %[[VAL_1]] : memref<?xi32> to memref<12xi32>
+// CHECK:             scf.yield %[[VAL_15]] : memref<12xi32>
+// CHECK:           } else {
+// CHECK:             %[[VAL_16:.*]] = memref.subview %[[VAL_1]][0] [12] [1] : memref<?xi32> to memref<12xi32>
+// CHECK:             scf.yield %[[VAL_16]] : memref<12xi32>
+// CHECK:           }
+// CHECK:           %[[VAL_17:.*]] = memref.expand_shape %[[VAL_18:.*]] {{\[\[}}0, 1]] : memref<12xi32> into memref<6x2xi32>
+// CHECK:           %[[VAL_19:.*]] = bufferization.to_tensor %[[VAL_20:.*]] : memref<6xf64>
+// CHECK:           %[[VAL_21:.*]] = bufferization.to_tensor %[[VAL_17]] : memref<6x2xi32>
+// CHECK:           %[[VAL_22:.*]] = sparse_tensor.storage_specifier
+// CHECK:           %[[VAL_23:.*]] = arith.index_cast %[[VAL_22]] : i32 to index
+// CHECK:           return %[[VAL_19]], %[[VAL_21]], %[[VAL_23]] : tensor<6xf64>, tensor<6x2xi32>, index
 // CHECK:         }
 func.func @sparse_unpack(%sp: tensor<100x100xf64, #COO>) -> (tensor<6xf64>, tensor<6x2xi32>, index) {
   %d, %i, %nnz = sparse_tensor.unpack %sp : tensor<100x100xf64, #COO>