[mlir][sparse] codegen for trivial tensor cast
authorAart Bik <ajcbik@google.com>
Fri, 2 Sep 2022 01:44:48 +0000 (18:44 -0700)
committerAart Bik <ajcbik@google.com>
Fri, 2 Sep 2022 04:55:18 +0000 (21:55 -0700)
Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D133176

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/test/Dialect/SparseTensor/codegen.mlir

index ac71062..65c1027 100644 (file)
@@ -33,16 +33,6 @@ namespace {
 // Helper methods.
 //===----------------------------------------------------------------------===//
 
-/// Reorders stored dimension to original dimension.
-static unsigned toOrig(const SparseTensorEncodingAttr &enc, unsigned i) {
-  auto order = enc.getDimOrdering();
-  if (order) {
-    assert(order.isPermutation());
-    return order.getDimPosition(i);
-  }
-  return i;
-}
-
 /// Reorders original dimension to stored dimension.
 static unsigned toStored(const SparseTensorEncodingAttr &enc, unsigned i) {
   auto order = enc.getDimOrdering();
@@ -67,7 +57,6 @@ static Optional<Type> convertSparseTensorType(Type type) {
   Type idxType = idxWidth ? IntegerType::get(context, idxWidth) : indexType;
   Type ptrType = ptrWidth ? IntegerType::get(context, ptrWidth) : indexType;
   Type eltType = rType.getElementType();
-  ArrayRef<int64_t> shape = rType.getShape();
   //
   // Sparse tensor storage for rank-dimensional tensor is organized as a
   // single compound type with the following fields:
@@ -85,27 +74,18 @@ static Optional<Type> convertSparseTensorType(Type type) {
   //   memref<? x eltType> values        ; values
   // };
   //
-  int64_t linear = 1;
-  bool allDense = true;
   unsigned rank = rType.getShape().size();
   SmallVector<Type, 8> fields;
   // The dimSizes array.
   fields.push_back(MemRefType::get({rank}, indexType));
   // Per-dimension storage.
   for (unsigned r = 0; r < rank; r++) {
-    // Get the original dimension (ro) for the current stored dimension (r).
-    unsigned ro = toOrig(enc, r);
     // Dimension level types apply in order to the reordered dimension.
     // As a result, the compound type can be constructed directly in the given
     // order. Clients of this type know what field is what from the sparse
     // tensor type.
     switch (enc.getDimLevelType()[r]) {
     case SparseTensorEncodingAttr::DimLevelType::Dense:
-      // Linearize the size of consecutive dense dimensions.
-      if (ShapedType::isDynamic(shape[ro]) || ShapedType::isDynamic(linear))
-        linear = ShapedType::kDynamicSize;
-      else
-        linear *= shape[ro];
       break;
     case SparseTensorEncodingAttr::DimLevelType::Compressed:
     case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
@@ -113,23 +93,17 @@ static Optional<Type> convertSparseTensorType(Type type) {
     case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
       fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, ptrType));
       fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, idxType));
-      allDense = false;
-      linear = 1;
       break;
     case SparseTensorEncodingAttr::DimLevelType::Singleton:
     case SparseTensorEncodingAttr::DimLevelType::SingletonNu:
     case SparseTensorEncodingAttr::DimLevelType::SingletonNo:
     case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
       fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, idxType));
-      allDense = false;
-      linear = 1;
       break;
     }
   }
   // The values array.
-  int64_t nnz =
-      (rType.hasStaticShape() && allDense) ? linear : ShapedType::kDynamicSize;
-  fields.push_back(MemRefType::get({nnz}, eltType));
+  fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, eltType));
   // Sparse tensor storage (temporarily) lives in a tuple. This allows a
   // simple 1:1 type conversion during codegen. A subsequent pass uses
   // a 1:N type conversion to expand the tuple into its fields.
@@ -241,6 +215,23 @@ public:
   }
 };
 
+/// Sparse codegen rule for trivial tensor casts.
+class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Only rewrite identically annotated source/dest.
+    auto encDst = getSparseTensorEncoding(op.getType());
+    auto encSrc = getSparseTensorEncoding(op.getSource().getType());
+    if (!encDst || encDst != encSrc)
+      return failure();
+    rewriter.replaceOp(op, adaptor.getOperands());
+    return success();
+  }
+};
+
 /// Sparse conversion rule for pointer accesses.
 class SparseToPointersConverter : public OpConversionPattern<ToPointersOp> {
 public:
@@ -314,7 +305,7 @@ mlir::SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
 /// the sparsification of linear algebra operations.
 void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
                                                RewritePatternSet &patterns) {
-  patterns.add<SparseReturnConverter, SparseDimOpConverter,
+  patterns.add<SparseReturnConverter, SparseDimOpConverter, SparseCastConverter,
                SparseToPointersConverter, SparseToIndicesConverter,
                SparseToValuesConverter>(typeConverter, patterns.getContext());
 }
index 6c93092..905278c 100644 (file)
 }>
 
 // CHECK-LABEL: func @sparse_nop(
-//  CHECK-SAME: %[[A:.*]]: tuple<memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>>) -> tuple<memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>>
+//  CHECK-SAME: %[[A:.*]]: tuple<memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>>)
 //       CHECK: return %[[A]] : tuple<memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>>
 func.func @sparse_nop(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
   return %arg0 : tensor<?xf64, #SparseVector>
 }
 
+// CHECK-LABEL: func @sparse_nop_cast(
+//  CHECK-SAME: %[[A:.*]]: tuple<memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf32>>)
+//       CHECK: return %[[A]] : tuple<memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf32>>
+func.func @sparse_nop_cast(%arg0: tensor<64xf32, #SparseVector>) -> tensor<?xf32, #SparseVector> {
+  %0 = tensor.cast %arg0 : tensor<64xf32, #SparseVector> to tensor<?xf32, #SparseVector>
+  return %0 : tensor<?xf32, #SparseVector>
+}
+
+// CHECK-LABEL: func @sparse_nop_cast_3d(
+//  CHECK-SAME: %[[A:.*]]: tuple<memref<3xindex>, memref<?xf32>>)
+//       CHECK: return %[[A]] : tuple<memref<3xindex>, memref<?xf32>>
+func.func @sparse_nop_cast_3d(%arg0: tensor<10x20x30xf32, #Dense3D>) -> tensor<?x?x?xf32, #Dense3D> {
+  %0 = tensor.cast %arg0 : tensor<10x20x30xf32, #Dense3D> to tensor<?x?x?xf32, #Dense3D>
+  return %0 : tensor<?x?x?xf32, #Dense3D>
+}
+
 // CHECK-LABEL: func @sparse_dense_2d(
 //  CHECK-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xf64>>)
 func.func @sparse_dense_2d(%arg0: tensor<?x?xf64, #Dense2D>) {
@@ -71,7 +87,7 @@ func.func @sparse_dcsr(%arg0: tensor<?x?xf64, #DCSR>) {
 // fold using the original static dimension sizes.
 //
 // CHECK-LABEL: func @sparse_dense_3d(
-//  CHECK-SAME: %[[A:.*]]: tuple<memref<3xindex>, memref<6000xf64>>) -> index {
+//  CHECK-SAME: %[[A:.*]]: tuple<memref<3xindex>, memref<?xf64>>)
 //       CHECK: %[[C:.*]] = arith.constant 20 : index
 //       CHECK: return %[[C]] : index
 func.func @sparse_dense_3d(%arg0: tensor<10x20x30xf64, #Dense3D>) -> index {
@@ -86,7 +102,7 @@ func.func @sparse_dense_3d(%arg0: tensor<10x20x30xf64, #Dense3D>) -> index {
 // since the latter honors the dimOrdering.
 //
 // CHECK-LABEL: func @sparse_dense_3d_dyn(
-//  CHECK-SAME: %[[A:.*]]: tuple<memref<3xindex>, memref<?xf64>>) -> index {
+//  CHECK-SAME: %[[A:.*]]: tuple<memref<3xindex>, memref<?xf64>>)
 //       CHECK: %[[C:.*]] = arith.constant 2 : index
 //       CHECK: %[[F:.*]] = sparse_tensor.storage_get %[[A]][0] : tuple<memref<3xindex>, memref<?xf64>> to memref<3xindex>
 //       CHECK: %[[L:.*]] = memref.load %[[F]][%[[C]]] : memref<3xindex>