[mlir][sparse] lower number of entries op to actual code
authorAart Bik <ajcbik@google.com>
Thu, 20 Oct 2022 23:01:37 +0000 (16:01 -0700)
committerAart Bik <ajcbik@google.com>
Fri, 21 Oct 2022 17:48:37 +0000 (10:48 -0700)
Works along both the runtime path and the pure codegen path.

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D136389

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
mlir/test/Dialect/SparseTensor/codegen.mlir
mlir/test/Dialect/SparseTensor/conversion.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_dot.mlir

index 1beb1271103b4ac44ca61eb16a44c5a352631da9..bf2f77d95e66537be83337f2108229d72edc6906 100644 (file)
@@ -277,6 +277,12 @@ static scf::ForOp createFor(OpBuilder &builder, Location loc, Value count,
   return forOp;
 }
 
+/// Translates field index to memSizes index.
+static unsigned getMemSizesIndex(unsigned field) {
+  assert(2 <= field);
+  return field - 2;
+}
+
 /// Creates a pushback op for given field and updates the fields array
 /// accordingly.
 static void createPushback(OpBuilder &builder, Location loc,
@@ -286,9 +292,9 @@ static void createPushback(OpBuilder &builder, Location loc,
   Type etp = fields[field].getType().cast<ShapedType>().getElementType();
   if (value.getType() != etp)
     value = builder.create<arith::IndexCastOp>(loc, etp, value);
-  fields[field] =
-      builder.create<PushBackOp>(loc, fields[field].getType(), fields[1],
-                                 fields[field], value, APInt(64, field - 2));
+  fields[field] = builder.create<PushBackOp>(
+      loc, fields[field].getType(), fields[1], fields[field], value,
+      APInt(64, getMemSizesIndex(field)));
 }
 
 /// Generates insertion code.
@@ -739,6 +745,25 @@ public:
   }
 };
 
+/// Sparse codegen rule for number of entries operator.
+class SparseNumberOfEntriesConverter
+    : public OpConversionPattern<NumberOfEntriesOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Query memSizes for the actually stored values size.
+    auto tuple = getTuple(adaptor.getTensor());
+    auto fields = tuple.getInputs();
+    unsigned lastField = fields.size() - 1;
+    Value field =
+        constantIndex(rewriter, op.getLoc(), getMemSizesIndex(lastField));
+    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, fields[1], field);
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -775,5 +800,6 @@ void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
                SparseExpandConverter, SparseCompressConverter,
                SparseInsertConverter, SparseToPointersConverter,
                SparseToIndicesConverter, SparseToValuesConverter,
-               SparseConvertConverter>(typeConverter, patterns.getContext());
+               SparseConvertConverter, SparseNumberOfEntriesConverter>(
+      typeConverter, patterns.getContext());
 }
index 40112078572bb8c73a75fa3ac2f1ce130ab78c88..c7c81767a40412e779541d2f00778aa49d568844 100644 (file)
@@ -205,6 +205,15 @@ static void newParams(OpBuilder &builder, SmallVector<Value, 8> &params,
   params.push_back(ptr);
 }
 
+/// Generates a call to obtain the values array.
+static Value genValuesCall(OpBuilder &builder, Location loc, ShapedType tp,
+                           ValueRange ptr) {
+  SmallString<15> name{"sparseValues",
+                       primaryTypeFunctionSuffix(tp.getElementType())};
+  return createFuncCall(builder, loc, name, tp, ptr, EmitCInterface::On)
+      .getResult(0);
+}
+
 /// Generates a call to release/delete a `SparseTensorCOO`.
 static void genDelCOOCall(OpBuilder &builder, Location loc, Type elemTp,
                           Value coo) {
@@ -903,11 +912,28 @@ public:
   LogicalResult
   matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Type resType = op.getType();
-    Type eltType = resType.cast<ShapedType>().getElementType();
-    SmallString<15> name{"sparseValues", primaryTypeFunctionSuffix(eltType)};
-    replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
-                          EmitCInterface::On);
+    auto resType = op.getType().cast<ShapedType>();
+    rewriter.replaceOp(op, genValuesCall(rewriter, op.getLoc(), resType,
+                                         adaptor.getOperands()));
+    return success();
+  }
+};
+
+/// Sparse conversion rule for number of entries operator.
+class SparseNumberOfEntriesConverter
+    : public OpConversionPattern<NumberOfEntriesOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    // Query values array size for the actually stored values size.
+    Type eltType = op.getTensor().getType().cast<ShapedType>().getElementType();
+    auto resTp = MemRefType::get({ShapedType::kDynamicSize}, eltType);
+    Value values = genValuesCall(rewriter, loc, resTp, adaptor.getOperands());
+    rewriter.replaceOpWithNewOp<memref::DimOp>(op, values,
+                                               constantIndex(rewriter, loc, 0));
     return success();
   }
 };
@@ -1250,9 +1276,10 @@ void mlir::populateSparseTensorConversionPatterns(
                SparseTensorConcatConverter, SparseTensorAllocConverter,
                SparseTensorDeallocConverter, SparseTensorToPointersConverter,
                SparseTensorToIndicesConverter, SparseTensorToValuesConverter,
-               SparseTensorLoadConverter, SparseTensorInsertConverter,
-               SparseTensorExpandConverter, SparseTensorCompressConverter,
-               SparseTensorOutConverter>(typeConverter, patterns.getContext());
+               SparseNumberOfEntriesConverter, SparseTensorLoadConverter,
+               SparseTensorInsertConverter, SparseTensorExpandConverter,
+               SparseTensorCompressConverter, SparseTensorOutConverter>(
+      typeConverter, patterns.getContext());
 
   patterns.add<SparseTensorConvertConverter>(typeConverter,
                                              patterns.getContext(), options);
index 6b5c6c4ce3808031c348cbc8fc7ea8af533887cd..71f736d7263de824059fc6ca1c30f65e26b1e0b8 100644 (file)
@@ -239,6 +239,20 @@ func.func @sparse_values_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xf64> {
   return %0 : memref<?xf64>
 }
 
+// CHECK-LABEL: func @sparse_noe(
+//  CHECK-SAME: %[[A0:.*0]]: memref<1xindex>,
+//  CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+//  CHECK-SAME: %[[A2:.*2]]: memref<?xi32>,
+//  CHECK-SAME: %[[A3:.*3]]: memref<?xi64>,
+//  CHECK-SAME: %[[A4:.*4]]: memref<?xf64>)
+//       CHECK: %[[C2:.*]] = arith.constant 2 : index
+//       CHECK: %[[NOE:.*]] = memref.load %[[A1]][%[[C2]]] : memref<3xindex>
+//       CHECK: return %[[NOE]] : index
+func.func @sparse_noe(%arg0: tensor<128xf64, #SparseVector>) -> index {
+  %0 = sparse_tensor.number_of_entries %arg0 : tensor<128xf64, #SparseVector>
+  return %0 : index
+}
+
 // CHECK-LABEL: func @sparse_dealloc_csr(
 //  CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
 //  CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
index 44fcd4219ec08c3e5944151e2c776bfb249cac3b..33b7d133fe849547b90bf01530da9cb08f56d14d 100644 (file)
@@ -268,6 +268,17 @@ func.func @sparse_valuesi8(%arg0: tensor<128xi8, #SparseVector>) -> memref<?xi8>
   return %0 : memref<?xi8>
 }
 
+// CHECK-LABEL: func @sparse_noe(
+//  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
+//   CHECK-DAG: %[[C:.*]] = arith.constant 0 : index
+//   CHECK-DAG: %[[T:.*]] = call @sparseValuesF64(%[[A]]) : (!llvm.ptr<i8>) -> memref<?xf64>
+//       CHECK: %[[NOE:.*]] = memref.dim %[[T]], %[[C]] : memref<?xf64>
+//       CHECK: return %[[NOE]] : index
+func.func @sparse_noe(%arg0: tensor<128xf64, #SparseVector>) -> index {
+  %0 = sparse_tensor.number_of_entries %arg0 : tensor<128xf64, #SparseVector>
+  return %0 : index
+}
+
 // CHECK-LABEL: func @sparse_reconstruct(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>
 //       CHECK: return %[[A]] : !llvm.ptr<i8>
index 6fc68ed700fdbb554a3e7afa526ad30bea7f0834..b3a5bbdb8f54a381c73124a2bef0884d2fc458bf 100644 (file)
@@ -46,6 +46,16 @@ module {
     %1 = tensor.extract %0[] : tensor<f32>
     vector.print %1 : f32
 
+    // Print number of entries in the sparse vectors.
+    //
+    // CHECK: 5
+    // CHECK: 3
+    //
+    %noe1 = sparse_tensor.number_of_entries %s1 : tensor<1024xf32, #SparseVector>
+    %noe2 = sparse_tensor.number_of_entries %s2 : tensor<1024xf32, #SparseVector>
+    vector.print %noe1 : index
+    vector.print %noe2 : index
+
     // Release the resources.
     bufferization.dealloc_tensor %s1 : tensor<1024xf32, #SparseVector>
     bufferization.dealloc_tensor %s2 : tensor<1024xf32, #SparseVector>