[mlir][sparse] extend pack operation to support packing a batched COO type
authorPeiming Liu <peiming@google.com>
Tue, 18 Apr 2023 22:33:25 +0000 (22:33 +0000)
committerPeiming Liu <peiming@google.com>
Thu, 20 Apr 2023 01:35:30 +0000 (01:35 +0000)
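
The pack operation now takes an optional `batched_lvls` attribute, so that a
batched COO (BCOO) sparse tensor can be materialized directly from a pair of
dense tensors, e.g. (from the updated op documentation below):

  %st = sparse_tensor.pack %values, %coordinates batched_lvls=1
      : tensor<2x3xf64>, tensor<2x3x1xindex> to tensor<2x4xf64, #BCOO>
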
Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D148670

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/test/Dialect/SparseTensor/invalid.mlir
mlir/test/Dialect/SparseTensor/roundtrip.mlir

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index d61aaed..6aa4f34 100644 (file)
@@ -101,6 +101,10 @@ inline MemRefType getMemRefType(T t) {
 /// Returns null-attribute for any type without an encoding.
 SparseTensorEncodingAttr getSparseTensorEncoding(Type type);
 
+/// Returns true iff the given sparse tensor encoding attribute has a trailing
+/// COO region starting at the given level; when `isUnique` is true, the last
+/// level must also be unique.
+bool isCOOType(SparseTensorEncodingAttr enc, Level startLvl, bool isUnique);
+
 /// Returns true iff the given type is a COO type where the last level
 /// is unique.
 bool isUniqueCOOType(Type tp);
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 336f196..9cc3967 100644 (file)
@@ -54,8 +54,9 @@ def SparseTensor_NewOp : SparseTensor_Op<"new", [Pure]>,
 }
 
 def SparseTensor_PackOp : SparseTensor_Op<"pack", [Pure]>,
-    Arguments<(ins 1DTensorOf<[AnyType]>:$values,
-                   2DTensorOf<[AnySignlessIntegerOrIndex]>:$coordinates)>,
+    Arguments<(ins TensorOf<[AnyType]>:$values,
+                   TensorOf<[AnySignlessIntegerOrIndex]>:$coordinates,
+                   OptionalAttr<IndexAttr>:$batched_lvls)>,
     Results<(outs AnySparseTensor: $result)> {
   let summary = "Returns a sparse tensor from the given (values, coordinates) pair";
 
@@ -77,6 +78,8 @@ def SparseTensor_PackOp : SparseTensor_Op<"pack", [Pure]>,
       supplies the level-coords for each element in `values`.
     - `values : tensor<NSE x V>`
       supplies the corresponding values for each entry in `coordinates`.
+    - `batched_lvls : optional<index>`
+      supplies the number of leading levels that are batched.
 
     This operation can be used to materialize a sparse tensor from external
     sources; e.g., when passing two numpy arrays from Python.
@@ -92,10 +95,29 @@ def SparseTensor_PackOp : SparseTensor_Op<"pack", [Pure]>,
     //     of 3x4 matrix |0.0, 0.0, 2.2, 3.3|
     //                   |0.0, 0.0, 0.0, 0.0|
     ```
+
+    If `batched_lvls` is provided, the operation materializes a batched sparse tensor.
+    Example:
+
+    ```mlir
+    %values      = arith.constant dense<[[ 1.1,   2.2,   3.3 ],
+                                         [ 1.2,   2.3,   0.0 ]]> : tensor<2x3xf64>
+    %coordinates = arith.constant dense<[[ [0],   [1],   [2] ],
+                                         [ [1],   [2],   [3] ]]> : tensor<2x3x1xindex>
+    %st = sparse_tensor.pack %values, %coordinates batched_lvls=1
+        : tensor<2x3xf64>, tensor<2x3x1xindex> to tensor<2x4xf64, #BCOO>
+    // yields BCOO format |1.1, 2.2, 3.3, 0.0|
+    //      of 2x4 matrix |0.0, 1.2, 2.3, 0.0|
+    ```
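+
+    In the batched example, `#BCOO` can, for instance, be defined with a
+    trailing `compressed-hi` level:
+
+    ```mlir
+    #BCOO = #sparse_tensor.encoding<{
+      dimLevelType = [ "dense", "compressed-hi" ]
+    }>
+    ```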
+  }];
+
+  let extraClassDeclaration = [{
+    /// Returns the number of leading levels that are batched.
+    unsigned getNumBatchedLvls();
   }];
 
   let assemblyFormat =
-    "$values `,` $coordinates attr-dict"
+    "$values `,` $coordinates (`batched_lvls` `=` $batched_lvls^)? attr-dict"
     "`:` type($values) `,` type($coordinates) `to` type($result)";
 
   let hasVerifier = 1;
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 1a93b14..b235301 100644 (file)
@@ -451,9 +451,10 @@ mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
 
 /// Returns true iff the given sparse tensor encoding attribute has a trailing
 /// COO region starting at the given level.
-static bool isCOOType(SparseTensorEncodingAttr enc, Level startLvl,
-                      bool isUnique) {
-  if (!enc || !enc.isCompressedLvl(startLvl))
+bool mlir::sparse_tensor::isCOOType(SparseTensorEncodingAttr enc,
+                                    Level startLvl, bool isUnique) {
+  if (!enc ||
+      !(enc.isCompressedLvl(startLvl) || enc.isCompressedWithHiLvl(startLvl)))
     return false;
   const Level lvlRank = enc.getLvlRank();
   for (Level l = startLvl + 1; l < lvlRank; ++l)
@@ -647,43 +648,55 @@ static LogicalResult verifySparsifierGetterSetter(
 static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
                                       SparseTensorType tensorTp,
                                       RankedTensorType valuesTp,
-                                      RankedTensorType coordinatesTp) {
+                                      RankedTensorType coordinatesTp,
+                                      IntegerAttr batchedLvls) {
+  unsigned nBatched = batchedLvls ? batchedLvls.getValue().getZExtValue() : 0;
   if (requiresStaticShape && !tensorTp.hasStaticDimShape())
     return op->emitError("the sparse-tensor must have static shape");
   if (!tensorTp.hasEncoding())
     return op->emitError("the sparse-tensor must have an encoding attribute");
   if (!tensorTp.isIdentity())
     return op->emitError("the sparse-tensor must have the identity mapping");
-  if (!isUniqueCOOType(tensorTp))
+  if (!isCOOType(tensorTp.getEncoding(), nBatched, true))
     return op->emitError("the sparse-tensor must have a COO type");
 
-  if (coordinatesTp.getRank() != 2)
-    return op->emitError("coordinates must have rank 2");
+  if (coordinatesTp.getRank() != 2 + nBatched)
+    return op->emitError("coordinates must have rank 2 + batched_lvls");
   if (requiresStaticShape && !coordinatesTp.hasStaticShape())
     return op->emitError("coordinates must have static shape");
   if (coordinatesTp.getElementType() != tensorTp.getCrdType())
     return op->emitError("input/output coordinate-types don't match");
 
-  if (valuesTp.getRank() != 1)
-    return op->emitError("values must have rank 1");
+  if (valuesTp.getRank() != 1 + nBatched)
+    return op->emitError("values must have rank 1 + batched_lvls");
   if (requiresStaticShape && !valuesTp.hasStaticShape())
     return op->emitError("values must have static shape");
   if (valuesTp.getElementType() != tensorTp.getElementType())
     return op->emitError("input/output element-types don't match");
 
-  const auto valuesNSE = valuesTp.getShape()[0];
-  const auto coordsNSE = coordinatesTp.getShape()[0];
+  for (unsigned i = 0; i < nBatched; i++) {
+    const auto valBatch = valuesTp.getShape()[i];
+    const auto crdBatch = coordinatesTp.getShape()[i];
+    if (ShapedType::isDynamic(valBatch) || ShapedType::isDynamic(crdBatch) ||
+        crdBatch != valBatch) {
+      return op->emitError(
+          "values/coordinates batched level sizes don't match statically");
+    }
+  }
+
+  const auto valuesNSE = valuesTp.getShape()[nBatched];
+  const auto coordsNSE = coordinatesTp.getShape()[nBatched];
   if (!ShapedType::isDynamic(valuesNSE) && !ShapedType::isDynamic(coordsNSE) &&
       valuesNSE != coordsNSE)
     return op->emitError("values/coordinates number-of-elements don't match");
 
   // NOTE: We use `getLvlRank` because the `coordinatesTp` is for
   // level-coordinates (cf., the op documentation).
-  const DynSize coordsRank = coordinatesTp.getShape()[1];
+  const DynSize coordsRank = coordinatesTp.getShape()[1 + nBatched];
   const Level tensorRank = tensorTp.getLvlRank();
   // FIXME: replace the `operator!=` with our backported `safelyNE`.
   if (!ShapedType::isDynamic(coordsRank) &&
-      coordsRank != static_cast<DynSize>(tensorRank))
+      coordsRank != static_cast<DynSize>(tensorRank) - nBatched)
     return op->emitError("input/output level-ranks don't match");
 
   return success();
@@ -693,14 +706,20 @@ LogicalResult PackOp::verify() {
   const auto valuesTp = getRankedTensorType(getValues());
   const auto coordinatesTp = getRankedTensorType(getCoordinates());
   const auto resTp = getSparseTensorType(getResult());
-  return verifyPackUnPack(*this, true, resTp, valuesTp, coordinatesTp);
+  return verifyPackUnPack(*this, true, resTp, valuesTp, coordinatesTp,
+                          getBatchedLvlsAttr());
+}
+
+unsigned PackOp::getNumBatchedLvls() {
+  return getBatchedLvls().has_value() ? getBatchedLvls()->getZExtValue() : 0;
 }
 
 LogicalResult UnpackOp::verify() {
   const auto valuesTp = getRankedTensorType(getValues());
   const auto coordinatesTp = getRankedTensorType(getCoordinates());
   const auto srcTp = getSparseTensorType(getTensor());
-  return verifyPackUnPack(*this, false, srcTp, valuesTp, coordinatesTp);
+  return verifyPackUnPack(*this, false, srcTp, valuesTp, coordinatesTp,
+                          nullptr);
 }
 
 LogicalResult ConvertOp::verify() {
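
With `batched_lvls = n`, the verifier thus requires `values` to have rank
1 + n and `coordinates` to have rank 2 + n, with the leading n batch sizes of
the two operands matching statically. A well-formed batched pack (mirroring
the roundtrip test below):

  %0 = sparse_tensor.pack %values, %coordinates batched_lvls=1
     : tensor<2x6xf64>, tensor<2x6x1xi32> to tensor<2x100xf64, #BCOO>
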
mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 6fbdc73..8a8b2ed 100644 (file)
@@ -138,7 +138,6 @@ struct PackOpInterface
   AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand,
                                             const AnalysisState &state) const {
     assert(op->getNumResults() == 1);
-    assert(isUniqueCOOType(op->getResultTypes()[0].cast<RankedTensorType>()));
     // PackOp reuses the input tensors as values/coordinates instead of
     // creating new ones when packing into a COO format.
     return {{op->getOpResult(0), BufferRelation::Equivalent}};
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 38f7021..55f4419 100644 (file)
@@ -1231,23 +1231,110 @@ public:
   }
 };
 
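+/// Fills the position buffer of a trailing compressed-hi level: for each
+/// (linearized) batch index b, stores the pair (pLo, pHi), where
+/// pLo = b * nse and pHi = pLo + len, with len excluding trailing zero values
+/// in that batch. E.g., for the op documentation's BCOO example (nse = 3),
+/// the buffer becomes [0, 3, 3, 5].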
+static void populateCompressedWithHiPosArray(OpBuilder &builder, Location loc,
+                                             ArrayRef<unsigned> batchDimSzs,
+                                             Value posMemRef, unsigned nse,
+                                             PackOp op) {
+  SmallVector<Value> lbs, ubs, steps;
+  Value c0 = constantIndex(builder, loc, 0);
+  Value c1 = constantIndex(builder, loc, 1);
+  Value c2 = constantIndex(builder, loc, 2);
+  for (unsigned dimSz : batchDimSzs) {
+    lbs.push_back(c0);
+    ubs.push_back(constantIndex(builder, loc, dimSz));
+    steps.push_back(c1);
+  }
+  auto tensorType = op.getValues().getType();
+  auto memrefType =
+      MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+  Value batV = builder.create<bufferization::ToMemrefOp>(loc, memrefType,
+                                                         op.getValues());
+  scf::buildLoopNest(
+      builder, loc, lbs, ubs, steps,
+      [&ubs, c0, c1, c2, nse, batV, posMemRef](OpBuilder &builder, Location loc,
+                                               ValueRange ivs) {
+        // Linearize index variables
+        Value crd = constantIndex(builder, loc, 0);
+        for (unsigned i = 0, e = ivs.size(); i < e; i++) {
+          crd = builder.create<arith::AddIOp>(loc, crd, ivs[i]);
+          if (i != ivs.size() - 1)
+            crd = builder.create<arith::MulIOp>(loc, crd, ubs[i + 1]);
+        }
+        Value len = constantIndex(builder, loc, nse);
+        Value pLo = builder.create<arith::MulIOp>(loc, crd, len);
+        SmallVector<Value> indices(ivs.begin(), ivs.end());
+        auto whileOp = builder.create<scf::WhileOp>(
+            loc, TypeRange{builder.getIndexType()}, ValueRange{len},
+            [&indices, c0, c1, batV](OpBuilder &builder, Location loc,
+                                     ValueRange vs) {
+              Value curLen = vs.front();
+              Value pred = builder.create<arith::CmpIOp>(
+                  loc, arith::CmpIPredicate::eq, curLen, c0);
+              auto ifOp = builder.create<scf::IfOp>(
+                  loc, TypeRange{builder.getI1Type()}, pred, true);
+              {
+                OpBuilder::InsertionGuard guard(builder);
+                // Then branch: len == 0, stop scanning.
+                builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+                builder.create<scf::YieldOp>(loc,
+                                             constantI1(builder, loc, false));
+                // Else branch: keep scanning while the trailing value is zero.
+                builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+                indices.push_back(
+                    builder.create<arith::SubIOp>(loc, curLen, c1));
+                Value val = builder.create<memref::LoadOp>(loc, batV, indices);
+                indices.pop_back();
+                Value cont = builder.create<arith::CmpFOp>(
+                    loc, arith::CmpFPredicate::OEQ, val,
+                    constantZero(builder, loc, val.getType()));
+                builder.create<scf::YieldOp>(loc, cont);
+              }
+              builder.create<scf::ConditionOp>(loc, ifOp.getResults()[0], vs);
+            },
+            [c1](OpBuilder &builder, Location loc, ValueRange vs) {
+              // len --;
+              Value nxLen = builder.create<arith::SubIOp>(loc, vs.front(), c1);
+              builder.create<scf::YieldOp>(loc, nxLen);
+            });
+        len = whileOp.getResults()[0];
+        Value pHi = builder.create<arith::AddIOp>(loc, pLo, len);
+        // Stores position lower bound.
+        Value idx = builder.create<arith::MulIOp>(loc, crd, c2);
+        genStore(builder, loc, pLo, posMemRef, idx);
+        // Stores position upper bound.
+        idx = builder.create<arith::AddIOp>(loc, idx, c1);
+        genStore(builder, loc, pHi, posMemRef, idx);
+      });
+}
+
 struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
   matchAndRewrite(PackOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-
+    const unsigned batchedLvls = op.getNumBatchedLvls();
+    unsigned nse = op.getValues().getType().getDimSize(batchedLvls);
     const auto stt = getSparseTensorType(op.getResult());
-    assert(isUniqueCOOType(stt));
+    assert(isCOOType(stt.getEncoding(), batchedLvls, true));
+
+    unsigned batchedCount = 1;
+    SmallVector<unsigned> batchDimSzs;
+    batchDimSzs.reserve(batchedLvls);
+    for (unsigned i = 0; i < batchedLvls; i++) {
+      // Should already be guaranteed by the verifier.
+      assert(!ShapedType::isDynamic(stt.getDimShape()[i]));
+      batchedCount *= stt.getDimShape()[i];
+      batchDimSzs.push_back(stt.getDimShape()[i]);
+    }
 
     SmallVector<Value> fields;
     Location loc = op.getLoc();
 
     foreachFieldAndTypeInSparseTensor(
         stt,
-        [&rewriter, &fields, &op, stt,
+        [&rewriter, &fields, &op, &batchDimSzs, nse, batchedCount, stt,
          loc](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
-              Level /*lvl*/, DimLevelType /*dlt*/) -> bool {
+              Level /*lvl*/, DimLevelType dlt) -> bool {
           assert(fields.size() == fIdx);
           Value field;
           switch (fKind) {
@@ -1259,34 +1346,38 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
             // By creating a constant value for it, we avoid the complexity of
             // memory management.
             const auto posTp = stt.getPosType();
-            auto tensorType = RankedTensorType::get({2}, posTp);
-            auto memrefType = MemRefType::get(tensorType.getShape(),
-                                              tensorType.getElementType());
-            auto cstPtr = rewriter.create<arith::ConstantOp>(
-                loc, tensorType,
-                DenseElementsAttr::get(
-                    tensorType,
-                    ArrayRef<Attribute>{
-                        IntegerAttr::get(posTp, 0),
-                        IntegerAttr::get(
-                            posTp, op.getValues().getType().getShape()[0])}));
-            field = rewriter.create<bufferization::ToMemrefOp>(loc, memrefType,
-                                                               cstPtr);
+            if (isCompressedDLT(dlt)) {
+              RankedTensorType tensorType;
+              SmallVector<Attribute> posAttr;
+              tensorType = RankedTensorType::get({batchedCount + 1}, posTp);
+              posAttr.push_back(IntegerAttr::get(posTp, 0));
+              for (unsigned i = 0; i < batchedCount; i++) {
+                // The position memref will hold the values
+                // [0, nse, 2 * nse, ..., batchedCount * nse].
+                posAttr.push_back(IntegerAttr::get(posTp, nse * (i + 1)));
+              }
+              MemRefType memrefType = MemRefType::get(
+                  tensorType.getShape(), tensorType.getElementType());
+              auto cstPtr = rewriter.create<arith::ConstantOp>(
+                  loc, tensorType, DenseElementsAttr::get(tensorType, posAttr));
+              field = rewriter.create<bufferization::ToMemrefOp>(
+                  loc, memrefType, cstPtr);
+            } else {
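+              // For a trailing compressed-hi (BCOO) level, generate one
+              // (pLo, pHi) pair per batch; pHi trims trailing zero values.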
+              assert(isCompressedWithHiDLT(dlt) && !batchDimSzs.empty());
+              MemRefType posMemTp = MemRefType::get({batchedCount * 2}, posTp);
+              field = rewriter.create<memref::AllocaOp>(loc, posMemTp);
+              populateCompressedWithHiPosArray(rewriter, loc, batchDimSzs,
+                                               field, nse, op);
+            }
             break;
           }
           case SparseTensorFieldKind::CrdMemRef: {
             auto tensorType = op.getCoordinates().getType();
             auto memrefType = MemRefType::get(tensorType.getShape(),
                                               tensorType.getElementType());
-            auto crdMemRef = rewriter.create<bufferization::ToMemrefOp>(
+            field = rewriter.create<bufferization::ToMemrefOp>(
                 op->getLoc(), memrefType, op.getCoordinates());
-            ReassociationIndices reassociation;
-            for (int i = 0, e = tensorType.getRank(); i < e; i++)
-              reassociation.push_back(i);
 
-            // Flattened the indices buffer to rank 1.
-            field = rewriter.create<memref::CollapseShapeOp>(
-                loc, crdMemRef, ArrayRef<ReassociationIndices>(reassociation));
             break;
           }
           case SparseTensorFieldKind::ValMemRef: {
@@ -1300,6 +1391,17 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
           }
 
           assert(field);
+          if (auto memrefTp = field.getType().dyn_cast<MemRefType>();
+              memrefTp && memrefTp.getRank() > 1) {
+            ReassociationIndices reassociation;
+            for (int i = 0, e = memrefTp.getRank(); i < e; i++)
+              reassociation.push_back(i);
+            // Flattens the buffer to rank 1. The value buffer might need to
+            // be collapsed as well due to batching.
+            field = rewriter.create<memref::CollapseShapeOp>(
+                loc, field, ArrayRef<ReassociationIndices>(reassociation));
+          }
+
           if (fType != field.getType())
             field = rewriter.create<memref::CastOp>(loc, fType, field);
           fields.push_back(field);
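
For a batched pack, the coordinates/values buffers keep their multi-dimensional
shape up to this point and are only flattened here. As a sketch, for a
coordinates operand of type tensor<2x6x1xi32> (as in the roundtrip test below),
the generated IR would look roughly like:

  %m = bufferization.to_memref %coordinates : memref<2x6x1xi32>
  %f = memref.collapse_shape %m [[0, 1, 2]]
     : memref<2x6x1xi32> into memref<12xi32>
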
mlir/test/Dialect/SparseTensor/invalid.mlir
index b9f45fb..b6f43ad 100644 (file)
@@ -36,7 +36,7 @@ func.func @invalid_pack_dense(%values: tensor<6xf64>, %coordinates: tensor<6x1xi
 
 func.func @invalid_pack_data(%values: tensor<6x1xf64>, %coordinates: tensor<6x1xi32>)
                             -> tensor<100xf64, #SparseVector> {
-  // expected-error@+1 {{'sparse_tensor.pack' op operand #0 must be 1D tensor of any type values}}
+  // expected-error@+1 {{values must have rank 1 + batched_lvls}}
   %0 = sparse_tensor.pack %values, %coordinates
      : tensor<6x1xf64>, tensor<6x1xi32> to tensor<100xf64, #SparseVector>
   return %0 : tensor<100xf64, #SparseVector>
@@ -80,6 +80,18 @@ func.func @invalid_pack_type(%values: tensor<6xf64>, %coordinates: tensor<6x2xi3
 
 // -----
 
+#BCOO = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed-hi"], crdWidth=32}>
+
+func.func @invalid_pack_batched(%values: tensor<2x6xf64>, %coordinates: tensor<3x6x1xi32>)
+                              -> tensor<2x100xf64, #BCOO> {
+  // expected-error@+1 {{values/coordinates batched level sizes don't match statically}}
+  %0 = sparse_tensor.pack %values, %coordinates batched_lvls=1
+     : tensor<2x6xf64>, tensor<3x6x1xi32> to tensor<2x100xf64, #BCOO>
+  return %0 : tensor<2x100xf64, #BCOO>
+}
+
+// -----
+
 #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"], crdWidth=32}>
 
 func.func @invalid_unpack_type(%sp: tensor<100xf32, #SparseVector>)
mlir/test/Dialect/SparseTensor/roundtrip.mlir
index ff622a4..e3e548c 100644 (file)
@@ -29,6 +29,21 @@ func.func @sparse_pack(%data: tensor<6xf64>, %index: tensor<6x1xi32>)
 
 // -----
 
+#BCOO = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed-hi"], crdWidth=32}>
+// CHECK-LABEL: func @sparse_pack_batched(
+// CHECK-SAME: %[[D:.*]]: tensor<2x6xf64>,
+// CHECK-SAME: %[[I:.*]]: tensor<2x6x1xi32>)
+//       CHECK: %[[R:.*]] = sparse_tensor.pack %[[D]], %[[I]] batched_lvls = 1
+//       CHECK: return %[[R]] : tensor<2x100xf64, #{{.*}}>
+func.func @sparse_pack_batched(%values: tensor<2x6xf64>, %coordinates: tensor<2x6x1xi32>)
+                            -> tensor<2x100xf64, #BCOO> {
+  %0 = sparse_tensor.pack %values, %coordinates batched_lvls=1
+     : tensor<2x6xf64>, tensor<2x6x1xi32> to tensor<2x100xf64, #BCOO>
+  return %0 : tensor<2x100xf64, #BCOO>
+}
+
+// -----
+
 #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"], crdWidth=32}>
 
 // CHECK-LABEL: func @sparse_unpack(