[mlir][sparse] Replace the folding of nop convert with a codegen rule.
authorbixia1 <bixia@google.com>
Wed, 19 Oct 2022 00:22:13 +0000 (17:22 -0700)
committerbixia1 <bixia@google.com>
Wed, 19 Oct 2022 17:20:47 +0000 (10:20 -0700)
This is to allow the use of a nop convert to express that the sparse tensor
allocated through bufferization::AllocTensorOp will be expanded to sparse
tensor storage by sparse tensor codegen.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D136214

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/test/Dialect/SparseTensor/codegen.mlir
mlir/test/Dialect/SparseTensor/fold.mlir

index 16a0565..df02c5e 100644 (file)
@@ -86,7 +86,6 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
 
   }];
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
-  let hasFolder = 1;
   let hasVerifier = 1;
 }
 
index 7822b40..e9168f7 100644 (file)
@@ -333,12 +333,6 @@ LogicalResult ConvertOp::verify() {
   return emitError("unexpected type in convert");
 }
 
-OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
-  if (getType() == getSource().getType())
-    return getSource();
-  return {};
-}
-
 LogicalResult ToPointersOp::verify() {
   auto e = getSparseTensorEncoding(getTensor().getType());
   if (failed(isInBounds(getDimension().getZExtValue(), getTensor())))
index 77bfef7..707c6c9 100644 (file)
@@ -720,6 +720,22 @@ public:
   }
 };
 
+/// Sparse codegen rule for the convert operator.
+class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (op.getType() != op.getSource().getType()) {
+      // This should be handled by rewriting before codegen.
+      return failure();
+    }
+    rewriter.replaceOp(op, adaptor.getSource());
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -744,6 +760,6 @@ void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
                SparseTensorDeallocConverter, SparseTensorLoadConverter,
                SparseExpandConverter, SparseCompressConverter,
                SparseInsertConverter, SparseToPointersConverter,
-               SparseToIndicesConverter, SparseToValuesConverter>(
-      typeConverter, patterns.getContext());
+               SparseToIndicesConverter, SparseToValuesConverter,
+               SparseConvertConverter>(typeConverter, patterns.getContext());
 }
index b469e66..6a32c72 100644 (file)
@@ -518,3 +518,15 @@ func.func @sparse_insert_typed(%arg0: tensor<128xf64, #SparseVector>, %arg1: ind
   %1 = sparse_tensor.load %0 hasInserts : tensor<128xf64, #SparseVector>
   return %1 : tensor<128xf64, #SparseVector>
 }
+
+// CHECK-LABEL:   func.func @sparse_nop_convert(
+//  CHECK-SAME:   %[[A0:.*]]: memref<1xindex>,
+//  CHECK-SAME:   %[[A1:.*]]: memref<3xindex>,
+//  CHECK-SAME:   %[[A2:.*]]: memref<?xi32>,
+//  CHECK-SAME:   %[[A3:.*]]: memref<?xi64>,
+//  CHECK-SAME:   %[[A4:.*]]: memref<?xf32>)
+//       CHECK:   return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]] : memref<1xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf32>
+func.func @sparse_nop_convert(%arg0: tensor<?xf32, #SparseVector>) -> tensor<?xf32, #SparseVector> {
+  %0 = sparse_tensor.convert %arg0 : tensor<?xf32, #SparseVector> to tensor<?xf32, #SparseVector>
+  return %0 : tensor<?xf32, #SparseVector>
+}
index 58900ad..fba2e8e 100644 (file)
@@ -2,15 +2,6 @@
 
 #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
 
-// CHECK-LABEL: func @sparse_nop_convert(
-//  CHECK-SAME: %[[A:.*]]: tensor<64xf32, #sparse_tensor.encoding<{{{.*}}}>>)
-//   CHECK-NOT: sparse_tensor.convert
-//       CHECK: return %[[A]] : tensor<64xf32, #sparse_tensor.encoding<{{{.*}}}>>
-func.func @sparse_nop_convert(%arg0: tensor<64xf32, #SparseVector>) -> tensor<64xf32, #SparseVector> {
-  %0 = sparse_tensor.convert %arg0 : tensor<64xf32, #SparseVector> to tensor<64xf32, #SparseVector>
-  return %0 : tensor<64xf32, #SparseVector>
-}
-
 // CHECK-LABEL: func @sparse_dce_convert(
 //  CHECK-SAME: %[[A:.*]]: tensor<64xf32>)
 //   CHECK-NOT: sparse_tensor.convert