[mlir][sparse] support sparse tensor element type conversion in codegen path
authorPeiming Liu <peiming@google.com>
Wed, 22 Feb 2023 19:04:02 +0000 (19:04 +0000)
committerPeiming Liu <peiming@google.com>
Thu, 23 Feb 2023 17:49:50 +0000 (17:49 +0000)
Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D144578

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
mlir/test/Dialect/SparseTensor/convert_sparse2sparse_element.mlir [new file with mode: 0644]
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_element.mlir [new file with mode: 0644]

index b272d9a..44a2b44 100644 (file)
@@ -260,6 +260,10 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     /// reset to the default/identity.
     SparseTensorEncodingAttr withoutOrdering() const;
 
+    /// Constructs a new encoding with the pointer and index bitwidth
+    /// reset to the default.
+    SparseTensorEncodingAttr withoutBitWidths() const;
+
     /// Returns true if every level is dense.  Also returns true for
     /// the null encoding (since dense-tensors are always all-dense).
     bool isAllDense() const;
index 4d06cb0..9d0b124 100644 (file)
@@ -135,7 +135,7 @@ def SparseTensor_UnpackOp : SparseTensor_Op<"unpack", [Pure]>,
 }
 
 def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
-  [Pure, SameOperandsAndResultElementType]>,
+  [Pure]>,
     Arguments<(ins AnyTensor:$source)>,
     Results<(outs AnyTensor:$dest)> {
   string summary = "Converts between different tensor types";
index 9279ec7..740023a 100644 (file)
@@ -135,6 +135,12 @@ SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutOrdering() const {
       getPointerBitWidth(), getIndexBitWidth());
 }
 
+SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
+  return SparseTensorEncodingAttr::get(getContext(), getDimLevelType(),
+                                       getDimOrdering(), getHigherOrdering(), 0,
+                                       0);
+}
+
 bool SparseTensorEncodingAttr::isAllDense() const {
   return !getImpl() || llvm::all_of(getDimLevelType(), isDenseDLT);
 }
index fc1ea38..f3f6784 100644 (file)
@@ -1030,11 +1030,73 @@ public:
     SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType());
     SparseTensorEncodingAttr encSrc =
         getSparseTensorEncoding(op.getSource().getType());
-    if (encDst != encSrc) {
-      // This should be handled by rewriting before codegen.
+    // Different encoding (except for different bitwidth) should be handled by
+    // rewriting.
+    if (encDst.withoutBitWidths() != encSrc.withoutBitWidths()) {
       return failure();
     }
-    rewriter.replaceOp(op, adaptor.getSource());
+
+    Type retElemTp = op.getResult().getType().getElementType();
+    Type srcElemTp = op.getSource().getType().getElementType();
+    // Fold the trivial cases.
+    if (retElemTp == srcElemTp && encDst == encSrc) {
+      rewriter.replaceOp(op, adaptor.getSource());
+      return success();
+    }
+    //
+    // Do element-wise type conversion without using InsertOp.
+    //
+    // for each memref in srcTensor:
+    //   dst = memref.alloc
+    //   if srcMemRefType != dstMemRefType:
+    //     for every dst[i] = cast(src[i])
+    //   else:
+    //     dst = memref.copy(src)
+    Location loc = op.getLoc();
+    auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource());
+    SmallVector<Value> fields;
+    foreachFieldAndTypeInSparseTensor(
+        SparseTensorType(op.getResult().getType().cast<RankedTensorType>()),
+        [&rewriter, &fields, srcDesc,
+         loc](Type fTp, FieldIndex fIdx, SparseTensorFieldKind fKind, Level lvl,
+              DimLevelType /*dlt*/) -> bool {
+          // Simply reuses the storage specifier as it is an SSA value.
+          if (fKind == SparseTensorFieldKind::StorageSpec) {
+            fields.push_back(srcDesc.getSpecifier());
+          } else {
+            // Allocates new memrefs
+            Value srcMem = srcDesc.getMemRefField(fIdx);
+            // TODO: We can instead use the actual memSize in specifier, that
+            // would require a subViewOp to avoid overflow when copying
+            // values.
+            Value sz = linalg::createOrFoldDimOp(rewriter, loc, srcMem, 0);
+            auto dstMem = rewriter.create<memref::AllocOp>(
+                loc, fTp.cast<MemRefType>(), sz);
+            if (fTp != srcMem.getType()) {
+              // Converts elements type.
+              scf::buildLoopNest(
+                  rewriter, loc, constantIndex(rewriter, loc, 0), sz,
+                  constantIndex(rewriter, loc, 1),
+                  [srcMem, &dstMem](OpBuilder &builder, Location loc,
+                                    ValueRange ivs) {
+                    Value v = builder.create<memref::LoadOp>(loc, srcMem, ivs);
+                    Value casted = genCast(builder, loc, v,
+                                           dstMem.getType().getElementType());
+                    builder.create<memref::StoreOp>(loc, casted, dstMem, ivs);
+                  });
+            } else {
+              // TODO: We can even reuse the same memref for the new tensor,
+              // but that requires a `ref-counting` based memory management
+              // for shared memrefs between multiple sparse tensors.
+              rewriter.create<memref::CopyOp>(loc, srcMem, dstMem);
+            }
+            fields.push_back(dstMem);
+          }
+          return true;
+        });
+
+    rewriter.replaceOp(
+        op, genTuple(rewriter, loc, op.getResult().getType(), fields));
     return success();
   }
 };
index ef78977..2071360 100644 (file)
@@ -616,12 +616,17 @@ struct ConvertRewriter : public OpRewritePattern<ConvertOp> {
                                 PatternRewriter &rewriter) const override {
     auto encDst = getSparseTensorEncoding(op.getType());
     auto encSrc = getSparseTensorEncoding(op.getSource().getType());
-    if (encDst && encSrc) {
-      // Trivial tensor conversion is handled in codegen.
-      if (encSrc == encDst)
-        return failure();
-      return sparse2SparseRewrite(op, rewriter);
+    if (encDst && encSrc &&
+        encSrc.withoutBitWidths() == encDst.withoutBitWidths()) {
+      // Trivial tensor conversion and simple element type conversion is handled
+      // in codegen.
+      return failure();
     }
+    // TODO: Add a cast before generating InsertOp.
+    assert(op.getSource().getType().getElementType() ==
+           op.getDest().getType().getElementType());
+    if (encSrc && encDst)
+      return sparse2SparseRewrite(op, rewriter);
     if (encSrc && !encDst)
       return sparse2DenseRewrite(op, rewriter);
     if (!encSrc && encDst)
index 1e05f88..e618ee8 100644 (file)
@@ -100,19 +100,6 @@ func.func @sparse_convert_1d_ss(%arg0: tensor<?xf32, #SparseVector64>) -> tensor
 //       CHECK-AUTO: %[[T:.*]] = call @newSparseTensor(%[[DimSizesP]], %[[LvlSizesP]], %[[LvlTypesP]], %[[IotaP]], %[[IotaP]], %{{.*}}, %{{.*}}, %{{.*}}, %[[SparseToSparse]], %[[A]])
 //       CHECK-AUTO: return %[[T]] : !llvm.ptr<i8>
 
-// CHECK-RWT-LABEL: func.func @sparse_convert(
-//  CHECK-RWT-SAME: %[[A:.*]]: tensor<?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 64, indexBitWidth = 64 }>>)
-//  CHECK-RWT-DAG:  %[[C0:.*]] = arith.constant 0 : index
-//      CHECK-RWT:  %[[D:.*]] = tensor.dim %[[A]], %[[C0]]
-//      CHECK-RWT:  %[[DST:.*]] = bufferization.alloc_tensor(%[[D]])
-//      CHECK-RWT:  %[[RET:.*]] = sparse_tensor.foreach in %[[A]] init(%[[DST]])
-//      CHECK-RWT:  ^bb0(%[[FI2:.*]]: index, %[[FV2:.*]]: f32, %[[T:.*]]: tensor<?xf32,
-//      CHECK-RWT:    %[[I:.*]] = sparse_tensor.insert %[[FV2]] into %[[T]]{{\[}}%[[FI2]]]
-//      CHECK-RWT:    sparse_tensor.yield %[[I]]
-//      CHECK-RWT:  }
-//      CHECK-RWT:  %[[T:.*]] = sparse_tensor.load %[[RET]] hasInserts
-//      CHECK-RWT:  %[[R:.*]] = sparse_tensor.convert %[[T]]
-//      CHECK-RWT:  return %[[R]] : tensor<?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 32, indexBitWidth = 32 }>>
 func.func @sparse_convert(%arg0: tensor<?xf32, #SparseVector64>) -> tensor<?xf32, #SparseVector32> {
   %0 = sparse_tensor.convert %arg0 : tensor<?xf32, #SparseVector64> to tensor<?xf32, #SparseVector32>
   return %0 : tensor<?xf32, #SparseVector32>
diff --git a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse_element.mlir b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse_element.mlir
new file mode 100644 (file)
index 0000000..e45d794
--- /dev/null
@@ -0,0 +1,72 @@
+// RUN: mlir-opt %s --sparse-tensor-codegen --canonicalize --cse | FileCheck %s
+
+#SparseVector64 = #sparse_tensor.encoding<{
+  dimLevelType = ["compressed"],
+  pointerBitWidth = 64,
+  indexBitWidth = 64
+}>
+
+#SparseVector32 = #sparse_tensor.encoding<{
+  dimLevelType = ["compressed"],
+  pointerBitWidth = 32,
+  indexBitWidth = 32
+}>
+
+
+// CHECK-LABEL:   func.func @sparse_convert(
+// CHECK-SAME:      %[[VAL_0:.*0]]: memref<?xi64>,
+// CHECK-SAME:      %[[VAL_1:.*1]]: memref<?xi64>,
+// CHECK-SAME:      %[[VAL_2:.*2]]: memref<?xf32>,
+// CHECK-SAME:      %[[VAL_3:.*3]]: !sparse_tensor.storage_specifier
+// CHECK:           %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_5:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_6:.*]] = memref.dim %[[VAL_0]], %[[VAL_5]] : memref<?xi64>
+// CHECK:           %[[VAL_7:.*]] = memref.alloc(%[[VAL_6]]) : memref<?xi32>
+// CHECK:           scf.for %[[VAL_8:.*]] = %[[VAL_5]] to %[[VAL_6]] step %[[VAL_4]] {
+// CHECK:             %[[VAL_9:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_8]]] : memref<?xi64>
+// CHECK:             %[[VAL_10:.*]] = arith.trunci %[[VAL_9]] : i64 to i32
+// CHECK:             memref.store %[[VAL_10]], %[[VAL_7]]{{\[}}%[[VAL_8]]] : memref<?xi32>
+// CHECK:           }
+// CHECK:           %[[VAL_11:.*]] = memref.dim %[[VAL_1]], %[[VAL_5]] : memref<?xi64>
+// CHECK:           %[[VAL_12:.*]] = memref.alloc(%[[VAL_11]]) : memref<?xi32>
+// CHECK:           scf.for %[[VAL_13:.*]] = %[[VAL_5]] to %[[VAL_11]] step %[[VAL_4]] {
+// CHECK:             %[[VAL_14:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_13]]] : memref<?xi64>
+// CHECK:             %[[VAL_15:.*]] = arith.trunci %[[VAL_14]] : i64 to i32
+// CHECK:             memref.store %[[VAL_15]], %[[VAL_12]]{{\[}}%[[VAL_13]]] : memref<?xi32>
+// CHECK:           }
+// CHECK:           %[[VAL_16:.*]] = memref.dim %[[VAL_2]], %[[VAL_5]] : memref<?xf32>
+// CHECK:           %[[VAL_17:.*]] = memref.alloc(%[[VAL_16]]) : memref<?xf32>
+// CHECK:           memref.copy %[[VAL_2]], %[[VAL_17]] : memref<?xf32> to memref<?xf32>
+// CHECK:           return %[[VAL_7]], %[[VAL_12]], %[[VAL_17]], %[[VAL_3]] : memref<?xi32>, memref<?xi32>, memref<?xf32>, !sparse_tensor.storage_specifier
+// CHECK:         }
+func.func @sparse_convert(%arg0: tensor<?xf32, #SparseVector64>) -> tensor<?xf32, #SparseVector32> {
+  %0 = sparse_tensor.convert %arg0 : tensor<?xf32, #SparseVector64> to tensor<?xf32, #SparseVector32>
+  return %0 : tensor<?xf32, #SparseVector32>
+}
+
+// CHECK-LABEL:   func.func @sparse_convert_value(
+// CHECK-SAME:      %[[VAL_0:.*0]]: memref<?xi32>,
+// CHECK-SAME:      %[[VAL_1:.*1]]: memref<?xi32>,
+// CHECK-SAME:      %[[VAL_2:.*2]]: memref<?xf32>,
+// CHECK-SAME:      %[[VAL_3:.*]]: !sparse_tensor.storage_specifier
+// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_6:.*]] = memref.dim %[[VAL_0]], %[[VAL_5]] : memref<?xi32>
+// CHECK:           %[[VAL_7:.*]] = memref.alloc(%[[VAL_6]]) : memref<?xi32>
+// CHECK:           memref.copy %[[VAL_0]], %[[VAL_7]] : memref<?xi32> to memref<?xi32>
+// CHECK:           %[[VAL_8:.*]] = memref.dim %[[VAL_1]], %[[VAL_5]] : memref<?xi32>
+// CHECK:           %[[VAL_9:.*]] = memref.alloc(%[[VAL_8]]) : memref<?xi32>
+// CHECK:           memref.copy %[[VAL_1]], %[[VAL_9]] : memref<?xi32> to memref<?xi32>
+// CHECK:           %[[VAL_10:.*]] = memref.dim %[[VAL_2]], %[[VAL_5]] : memref<?xf32>
+// CHECK:           %[[VAL_11:.*]] = memref.alloc(%[[VAL_10]]) : memref<?xf64>
+// CHECK:           scf.for %[[VAL_12:.*]] = %[[VAL_5]] to %[[VAL_10]] step %[[VAL_4]] {
+// CHECK:             %[[VAL_13:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_12]]] : memref<?xf32>
+// CHECK:             %[[VAL_14:.*]] = arith.extf %[[VAL_13]] : f32 to f64
+// CHECK:             memref.store %[[VAL_14]], %[[VAL_11]]{{\[}}%[[VAL_12]]] : memref<?xf64>
+// CHECK:           }
+// CHECK:           return %[[VAL_7]], %[[VAL_9]], %[[VAL_11]], %[[VAL_3]] : memref<?xi32>, memref<?xi32>, memref<?xf64>, !sparse_tensor.storage_specifier
+// CHECK:         }
+func.func @sparse_convert_value(%arg0: tensor<?xf32, #SparseVector32>) -> tensor<?xf64, #SparseVector32> {
+  %0 = sparse_tensor.convert %arg0 : tensor<?xf32, #SparseVector32> to tensor<?xf64, #SparseVector32>
+  return %0 : tensor<?xf64, #SparseVector32>
+}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_element.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_element.mlir
new file mode 100644 (file)
index 0000000..3ffa323
--- /dev/null
@@ -0,0 +1,107 @@
+// DEFINE: %{option} = "enable-runtime-library=false s2s-strategy=2"
+// DEFINE: %{compile} = mlir-opt %s --sparse-compiler=%{option}
+// DEFINE: %{run} = mlir-cpu-runner \
+// DEFINE:  -e entry -entry-point-result=void  \
+// DEFINE:  -shared-libs=%mlir_c_runner_utils | \
+// DEFINE: FileCheck %s
+//
+// RUN: %{compile} | %{run}
+//
+// Do the same run, but now with direct IR generation and vectorization.
+// REDEFINE: %{option} = "enable-runtime-library=false s2s-strategy=2 vl=2 reassociate-fp-reductions=true enable-index-optimizations=true"
+// RUN: %{compile} | %{run}
+//
+// Do the same run, but now with direct IR generation and, if available, VLA
+// vectorization.
+// REDEFINE: %{option} = "enable-runtime-library=false vl=4 enable-arm-sve=%ENABLE_VLA"
+// REDEFINE: %{run} = %lli \
+// REDEFINE:   --entry-function=entry_lli \
+// REDEFINE:   --extra-module=%S/Inputs/main_for_lli.ll \
+// REDEFINE:   %VLA_ARCH_ATTR_OPTIONS \
+// REDEFINE:   --dlopen=%mlir_native_utils_lib_dir/libmlir_c_runner_utils%shlibext | \
+// REDEFINE: FileCheck %s
+// RUN: %{compile} | mlir-translate -mlir-to-llvmir | %{run}
+
+#Tensor1 = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed-nu", "singleton-nu", "singleton" ]
+}>
+
+#Tensor2 = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed", "dense" ]
+}>
+
+#Tensor3 = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "dense", "compressed" ],
+  dimOrdering = affine_map<(i,j,k) -> (i,k,j)>
+}>
+
+module {
+  //
+  // Utility for output.
+  //
+  func.func @dump(%arg0: tensor<2x3x4xf32>) {
+    %c0 = arith.constant 0 : index
+    %d0 = arith.constant -1.0 : f32
+    %0 = vector.transfer_read %arg0[%c0, %c0, %c0], %d0: tensor<2x3x4xf32>, vector<2x3x4xf32>
+    vector.print %0 : vector<2x3x4xf32>
+    return
+  }
+
+  //
+  // The first test suite (for non-singleton DimLevelTypes).
+  //
+  func.func @entry() {
+    //
+    // Initialize a 3-dim dense tensor.
+    //
+    %src = arith.constant dense<[
+       [  [  1.0,  2.0,  3.0,  4.0 ],
+          [  5.0,  6.0,  7.0,  8.0 ],
+          [  9.0, 10.0, 11.0, 12.0 ] ],
+       [  [ 13.0, 14.0, 15.0, 16.0 ],
+          [ 17.0, 18.0, 19.0, 20.0 ],
+          [ 21.0, 22.0, 23.0, 24.0 ] ]
+    ]> : tensor<2x3x4xf64>
+
+    //
+    // Convert dense tensor directly to various sparse tensors.
+    //
+    %s1 = sparse_tensor.convert %src : tensor<2x3x4xf64> to tensor<2x3x4xf64, #Tensor1>
+    %s2 = sparse_tensor.convert %src : tensor<2x3x4xf64> to tensor<2x3x4xf64, #Tensor2>
+    %s3 = sparse_tensor.convert %src : tensor<2x3x4xf64> to tensor<2x3x4xf64, #Tensor3>
+
+    //
+    // Convert sparse tensor directly to another sparse format.
+    //
+    %t1 = sparse_tensor.convert %s1 : tensor<2x3x4xf64, #Tensor1> to tensor<2x3x4xf32, #Tensor1>
+    %t2 = sparse_tensor.convert %s2 : tensor<2x3x4xf64, #Tensor2> to tensor<2x3x4xf32, #Tensor2>
+    %t3 = sparse_tensor.convert %s3 : tensor<2x3x4xf64, #Tensor3> to tensor<2x3x4xf32, #Tensor3>
+
+    //
+    // Convert sparse tensor back to dense.
+    //
+    %d1 = sparse_tensor.convert %t1 : tensor<2x3x4xf32, #Tensor1> to tensor<2x3x4xf32>
+    %d2 = sparse_tensor.convert %t2 : tensor<2x3x4xf32, #Tensor2> to tensor<2x3x4xf32>
+    %d3 = sparse_tensor.convert %t3 : tensor<2x3x4xf32, #Tensor3> to tensor<2x3x4xf32>
+
+    //
+    // Check round-trip equality.  And release dense tensors.
+    //
+    // CHECK-COUNT-3: ( ( ( 1, 2, 3, 4 ), ( 5, 6, 7, 8 ), ( 9, 10, 11, 12 ) ), ( ( 13, 14, 15, 16 ), ( 17, 18, 19, 20 ), ( 21, 22, 23, 24 ) ) )
+    call @dump(%d1) : (tensor<2x3x4xf32>) -> ()
+    call @dump(%d2) : (tensor<2x3x4xf32>) -> ()
+    call @dump(%d3) : (tensor<2x3x4xf32>) -> ()
+
+    //
+    // Release sparse tensors.
+    //
+    bufferization.dealloc_tensor %t1 : tensor<2x3x4xf32, #Tensor1>
+    bufferization.dealloc_tensor %t2 : tensor<2x3x4xf32, #Tensor2>
+    bufferization.dealloc_tensor %t3 : tensor<2x3x4xf32, #Tensor3>
+    bufferization.dealloc_tensor %s1 : tensor<2x3x4xf64, #Tensor1>
+    bufferization.dealloc_tensor %s2 : tensor<2x3x4xf64, #Tensor2>
+    bufferization.dealloc_tensor %s3 : tensor<2x3x4xf64, #Tensor3>
+
+    return
+  }
+}