[mlir][sparse] added codegen for dimop, pointers, indices, values
authorAart Bik <ajcbik@google.com>
Thu, 1 Sep 2022 19:34:58 +0000 (12:34 -0700)
committerAart Bik <ajcbik@google.com>
Thu, 1 Sep 2022 23:36:10 +0000 (16:36 -0700)
Demonstrates how sparse tensor type -> tuple -> getter
will eventually yield actual code on the memrefs directly

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D133143

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
mlir/test/Dialect/SparseTensor/codegen.mlir

index d765a10..8f96280 100644 (file)
@@ -142,6 +142,10 @@ def SparseTensorCodegen : Pass<"sparse-tensor-codegen", "ModuleOp"> {
   }];
   let constructor = "mlir::createSparseTensorCodegenPass()";
   let dependentDialects = [
+    "arith::ArithmeticDialect",
+    "bufferization::BufferizationDialect",
+    "memref::MemRefDialect",
+    "scf::SCFDialect",
     "sparse_tensor::SparseTensorDialect",
   ];
 }
index d82ebea..ac71062 100644 (file)
@@ -33,14 +33,24 @@ namespace {
 // Helper methods.
 //===----------------------------------------------------------------------===//
 
-/// Reorders stored dimension to logical dimension.
-static unsigned reorder(const SparseTensorEncodingAttr &enc, unsigned d) {
+/// Reorders stored dimension to original dimension.
+static unsigned toOrig(const SparseTensorEncodingAttr &enc, unsigned i) {
   auto order = enc.getDimOrdering();
   if (order) {
     assert(order.isPermutation());
-    return order.getDimPosition(d);
+    return order.getDimPosition(i);
   }
-  return d;
+  return i;
+}
+
+/// Reorders original dimension to stored dimension.
+static unsigned toStored(const SparseTensorEncodingAttr &enc, unsigned i) {
+  auto order = enc.getDimOrdering();
+  if (order) {
+    assert(order.isPermutation());
+    return order.getPermutedPosition(i);
+  }
+  return i;
 }
 
 /// Maps a sparse tensor type to the appropriate compounded buffers.
@@ -63,14 +73,13 @@ static Optional<Type> convertSparseTensorType(Type type) {
   // single compound type with the following fields:
   //
   // struct {
-  //   ; if dynamic shape:
-  //     memref<rank x index> dimSize    ; size in each dimension
+  //   memref<rank x index> dimSizes     ; size in each dimension
   //   ; per-dimension d:
   //   ;  if dense:
   //        <nothing>
   //   ;  if compresed:
-  //        memref<? x idx>  indices-d   ; indices for sparse dim d
   //        memref<? x ptr>  pointers-d  ; pointers for sparse dim d
+  //        memref<? x idx>  indices-d   ; indices for sparse dim d
   //   ;  if singleton:
   //        memref<? x idx>  indices-d   ; indices for singleton dim d
   //   memref<? x eltType> values        ; values
@@ -81,12 +90,11 @@ static Optional<Type> convertSparseTensorType(Type type) {
   unsigned rank = rType.getShape().size();
   SmallVector<Type, 8> fields;
   // The dimSizes array.
-  if (!rType.hasStaticShape())
-    fields.push_back(MemRefType::get({rank}, indexType));
+  fields.push_back(MemRefType::get({rank}, indexType));
   // Per-dimension storage.
   for (unsigned r = 0; r < rank; r++) {
     // Get the original dimension (ro) for the current stored dimension (r).
-    unsigned ro = reorder(enc, r);
+    unsigned ro = toOrig(enc, r);
     // Dimension level types apply in order to the reordered dimension.
     // As a result, the compound type can be constructed directly in the given
     // order. Clients of this type know what field is what from the sparse
@@ -103,8 +111,8 @@ static Optional<Type> convertSparseTensorType(Type type) {
     case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
     case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
     case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
-      fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, idxType));
       fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, ptrType));
+      fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, idxType));
       allDense = false;
       linear = 1;
       break;
@@ -128,6 +136,63 @@ static Optional<Type> convertSparseTensorType(Type type) {
   return TupleType::get(context, fields);
 }
 
+// Returns field index for pointers (d), indices (d) for set field.
+static unsigned getFieldIndex(Type type, unsigned ptrDim, unsigned idxDim) {
+  auto enc = getSparseTensorEncoding(type);
+  assert(enc);
+  RankedTensorType rType = type.cast<RankedTensorType>();
+  unsigned field = 1; // start at DimSizes;
+  unsigned ptr = 0;
+  unsigned idx = 0;
+  for (unsigned r = 0, rank = rType.getShape().size(); r < rank; r++) {
+    switch (enc.getDimLevelType()[r]) {
+    case SparseTensorEncodingAttr::DimLevelType::Dense:
+      break; // no fields
+    case SparseTensorEncodingAttr::DimLevelType::Compressed:
+    case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
+    case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
+    case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
+      if (ptr++ == ptrDim)
+        return field;
+      field++;
+      if (idx++ == idxDim)
+        return field;
+      field++;
+      break;
+    case SparseTensorEncodingAttr::DimLevelType::Singleton:
+    case SparseTensorEncodingAttr::DimLevelType::SingletonNu:
+    case SparseTensorEncodingAttr::DimLevelType::SingletonNo:
+    case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
+      if (idx++ == idxDim)
+        return field;
+      field++;
+      break;
+    }
+  }
+  llvm_unreachable("failed to find ptr/idx field index");
+  return -1;
+}
+
+/// Returns field type in tuple at given index.
+static Type getFieldType(Value tuple, unsigned field) {
+  return tuple.getType().cast<TupleType>().getType(field);
+}
+
+/// Creates tuple get operation at given index.
+static Value createTupleGet(OpBuilder &builder, Location loc, Value tuple,
+                            unsigned field) {
+  Type indexType = builder.getIndexType();
+  return builder.create<StorageGetOp>(loc, getFieldType(tuple, field), tuple,
+                                      builder.getIntegerAttr(indexType, field));
+}
+
+/// Returns integral constant, if defined.
+static Optional<int64_t> getConstantInt(Value val) {
+  if (auto constantOp = val.getDefiningOp<arith::ConstantOp>())
+    return constantOp.getValue().cast<IntegerAttr>().getInt();
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // Codegen rules.
 //===----------------------------------------------------------------------===//
@@ -151,26 +216,82 @@ public:
   LogicalResult
   matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Location loc = op->getLoc();
-    Type type = op.getSource().getType();
     // Only rewrite annotated DimOp with constant index.
-    auto enc = getSparseTensorEncoding(type);
+    auto enc = getSparseTensorEncoding(op.getSource().getType());
     if (!enc)
       return failure();
-    Optional<int64_t> index = op.getConstantIndex();
+    Optional<int64_t> index = getConstantInt(adaptor.getIndex());
     if (!index)
       return failure();
-    // Access into static shape can query original type directly.
+    // Access into static dimension can query original type directly.
     // Note that this is typically already done by DimOp's folding.
-    RankedTensorType rType = type.cast<RankedTensorType>();
-    if (rType.hasStaticShape()) {
-      rewriter.replaceOp(
-          op, constantIndex(rewriter, loc, rType.getShape()[*index]));
+    Location loc = op->getLoc();
+    auto shape = op.getSource().getType().cast<RankedTensorType>().getShape();
+    if (!ShapedType::isDynamic(shape[*index])) {
+      rewriter.replaceOp(op, constantIndex(rewriter, loc, shape[*index]));
       return success();
     }
-    // Any other query can consult the dimSize array.
-    // TODO: this needs tuple access
-    return failure();
+    // Any other query can consult the dimSizes array at field 0 using,
+    // accounting for the reordering applied to the sparse storage.
+    Value tuple = adaptor.getSource();
+    Value dimSizes = createTupleGet(rewriter, loc, tuple, 0);
+    rewriter.replaceOpWithNewOp<memref::LoadOp>(
+        op, dimSizes, constantIndex(rewriter, loc, toStored(enc, *index)));
+    return success();
+  }
+};
+
+/// Sparse conversion rule for pointer accesses.
+class SparseToPointersConverter : public OpConversionPattern<ToPointersOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ToPointersOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Optional<int64_t> index = getConstantInt(adaptor.getOperands()[1]);
+    if (!index)
+      return failure();
+    // Replace the requested pointer access with corresponding field.
+    Location loc = op->getLoc();
+    Value tuple = adaptor.getTensor();
+    unsigned i = getFieldIndex(op.getTensor().getType(), /*ptrDim=*/*index, -1);
+    rewriter.replaceOp(op, createTupleGet(rewriter, loc, tuple, i));
+    return success();
+  }
+};
+
+/// Sparse conversion rule for index accesses.
+class SparseToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Optional<int64_t> index = getConstantInt(adaptor.getOperands()[1]);
+    if (!index)
+      return failure();
+    // Replace the requested indices access with corresponding field.
+    Location loc = op->getLoc();
+    Value tuple = adaptor.getTensor();
+    unsigned i = getFieldIndex(op.getTensor().getType(), -1, /*idxDim=*/*index);
+    rewriter.replaceOp(op, createTupleGet(rewriter, loc, tuple, i));
+    return success();
+  }
+};
+
+/// Sparse conversion rule for value accesses.
+class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Replace the requested values access with corresponding field.
+    Location loc = op->getLoc();
+    Value tuple = adaptor.getTensor();
+    unsigned i = tuple.getType().cast<TupleType>().size() - 1;  // last
+    rewriter.replaceOp(op, createTupleGet(rewriter, loc, tuple, i));
+    return success();
   }
 };
 
@@ -193,6 +314,7 @@ mlir::SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
 /// the sparsification of linear algebra operations.
 void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
                                                RewritePatternSet &patterns) {
-  patterns.add<SparseReturnConverter, SparseDimOpConverter>(
-      typeConverter, patterns.getContext());
+  patterns.add<SparseReturnConverter, SparseDimOpConverter,
+               SparseToPointersConverter, SparseToIndicesConverter,
+               SparseToValuesConverter>(typeConverter, patterns.getContext());
 }
index d5e2b96..c1a6b7a 100644 (file)
@@ -157,8 +157,9 @@ struct SparseTensorCodegenPass
     RewritePatternSet patterns(ctx);
     SparseTensorTypeToBufferConverter converter;
     ConversionTarget target(*ctx);
-    // Everything in the sparse dialect must go!
+    // Almost everything in the sparse dialect must go!
     target.addIllegalDialect<SparseTensorDialect>();
+    target.addLegalOp<StorageGetOp, StorageSetOp>();
     // All dynamic rules below accept new function, call, return.
     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
       return converter.isSignatureLegal(op.getFunctionType());
index 6662616..6c93092 100644 (file)
 
 #Dense3D = #sparse_tensor.encoding<{
   dimLevelType = [ "dense", "dense", "dense" ],
-  indexBitWidth = 64,
-  pointerBitWidth = 32,
-  dimOrdering = affine_map<(i,j,k) -> (k, i,j)>
+  dimOrdering = affine_map<(i, j, k) -> (k, i, j)>
 }>
 
 // CHECK-LABEL: func @sparse_nop(
-//  CHECK-SAME: %[[A:.*]]: tuple<memref<1xindex>, memref<?xi64>, memref<?xi32>, memref<?xf64>>) -> tuple<memref<1xindex>, memref<?xi64>, memref<?xi32>, memref<?xf64>>
-//       CHECK: return %[[A]] : tuple<memref<1xindex>, memref<?xi64>, memref<?xi32>, memref<?xf64>>
+//  CHECK-SAME: %[[A:.*]]: tuple<memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>>) -> tuple<memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>>
+//       CHECK: return %[[A]] : tuple<memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>>
 func.func @sparse_nop(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
   return %arg0 : tensor<?xf64, #SparseVector>
 }
@@ -51,28 +49,29 @@ func.func @sparse_dense_2d(%arg0: tensor<?x?xf64, #Dense2D>) {
 }
 
 // CHECK-LABEL: func @sparse_row(
-//  CHECK-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xi64>, memref<?xi32>, memref<?xf64>>)
+//  CHECK-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>>)
 func.func @sparse_row(%arg0: tensor<?x?xf64, #Row>) {
   return
 }
 
 // CHECK-LABEL: func @sparse_csr(
-//  CHECK-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xi64>, memref<?xi32>, memref<?xf64>>)
+//  CHECK-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>>)
 func.func @sparse_csr(%arg0: tensor<?x?xf64, #CSR>) {
   return
 }
 
 // CHECK-LABEL: func @sparse_dcsr(
-//  CHECK-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xi64>, memref<?xi32>, memref<?xi64>, memref<?xi32>, memref<?xf64>>)
+//  CHECK-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xi32>, memref<?xi64>, memref<?xf64>>)
 func.func @sparse_dcsr(%arg0: tensor<?x?xf64, #DCSR>) {
   return
 }
 
 //
-// Just a linearized array in the end. Dim op is statically known.
+// Querying for dimension 1 in the tensor type can immediately
+// fold using the original static dimension sizes.
 //
 // CHECK-LABEL: func @sparse_dense_3d(
-//  CHECK-SAME: %[[A:.*]]: tuple<memref<6000xf64>>) -> index
+//  CHECK-SAME: %[[A:.*]]: tuple<memref<3xindex>, memref<6000xf64>>) -> index {
 //       CHECK: %[[C:.*]] = arith.constant 20 : index
 //       CHECK: return %[[C]] : index
 func.func @sparse_dense_3d(%arg0: tensor<10x20x30xf64, #Dense3D>) -> index {
@@ -80,3 +79,49 @@ func.func @sparse_dense_3d(%arg0: tensor<10x20x30xf64, #Dense3D>) -> index {
   %0 = tensor.dim %arg0, %c : tensor<10x20x30xf64, #Dense3D>
   return %0 : index
 }
+
+//
+// Querying for dimension 1 in the tensor type needs to be permuted
+// into querying for dimension 2 in the stored sparse tensor scheme,
+// since the latter honors the dimOrdering.
+//
+// CHECK-LABEL: func @sparse_dense_3d_dyn(
+//  CHECK-SAME: %[[A:.*]]: tuple<memref<3xindex>, memref<?xf64>>) -> index {
+//       CHECK: %[[C:.*]] = arith.constant 2 : index
+//       CHECK: %[[F:.*]] = sparse_tensor.storage_get %[[A]][0] : tuple<memref<3xindex>, memref<?xf64>> to memref<3xindex>
+//       CHECK: %[[L:.*]] = memref.load %[[F]][%[[C]]] : memref<3xindex>
+//       CHECK: return %[[L]] : index
+func.func @sparse_dense_3d_dyn(%arg0: tensor<?x?x?xf64, #Dense3D>) -> index {
+  %c = arith.constant 1 : index
+  %0 = tensor.dim %arg0, %c : tensor<?x?x?xf64, #Dense3D>
+  return %0 : index
+}
+
+// CHECK-LABEL: func @sparse_pointers_dcsr(
+//  CHECK-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xi32>, memref<?xi64>, memref<?xf64>>)
+//       CHECK: %[[F:.*]] = sparse_tensor.storage_get %[[A]][3] : tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xi32>, memref<?xi64>, memref<?xf64>> to memref<?xi32>
+//       CHECK: return %[[F]] : memref<?xi32>
+func.func @sparse_pointers_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xi32> {
+  %c = arith.constant 1 : index
+  %0 = sparse_tensor.pointers %arg0, %c : tensor<?x?xf64, #DCSR> to memref<?xi32>
+  return %0 : memref<?xi32>
+}
+
+// CHECK-LABEL: func @sparse_indices_dcsr(
+//  CHECK-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xi32>, memref<?xi64>, memref<?xf64>>)
+//       CHECK: %[[F:.*]] = sparse_tensor.storage_get %[[A]][4] : tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xi32>, memref<?xi64>, memref<?xf64>> to memref<?xi64>
+//       CHECK: return %[[F]] : memref<?xi64>
+func.func @sparse_indices_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xi64> {
+  %c = arith.constant 1 : index
+  %0 = sparse_tensor.indices %arg0, %c : tensor<?x?xf64, #DCSR> to memref<?xi64>
+  return %0 : memref<?xi64>
+}
+
+// CHECK-LABEL: func @sparse_values_dcsr(
+//  CHECK-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xi32>, memref<?xi64>, memref<?xf64>>)
+//       CHECK: %[[F:.*]] = sparse_tensor.storage_get %[[A]][5] : tuple<memref<2xindex>, memref<?xi32>, memref<?xi64>, memref<?xi32>, memref<?xi64>, memref<?xf64>> to memref<?xf64>
+//       CHECK: return %[[F]] : memref<?xf64>
+func.func @sparse_values_dcsr(%arg0: tensor<?x?xf64, #DCSR>) -> memref<?xf64> {
+  %0 = sparse_tensor.values %arg0 : tensor<?x?xf64, #DCSR> to memref<?xf64>
+  return %0 : memref<?xf64>
+}