[mlir][sparse] keep runtime support library signature consistent
author Aart Bik <ajcbik@google.com>
Tue, 11 May 2021 23:14:00 +0000 (16:14 -0700)
committer Aart Bik <ajcbik@google.com>
Wed, 12 May 2021 16:59:46 +0000 (09:59 -0700)
Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D102285
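
Before this change, the conversion passed the per-dimension sparsity annotations to the runtime as a statically shaped i1 tensor constant, so the type of the emitted @newSparseTensor call (and its declaration) varied with the rank of the sparse tensor. Casting the constant to tensor<?xi1> lets every rank share one declaration. A sketch of the resulting IR, pieced together from the updated conversion tests below (SSA names are illustrative):

  func private @newSparseTensor(!llvm.ptr<i8>, tensor<?xi1>, i64, i64, i64) -> !llvm.ptr<i8>

  // 2-D case: the dense/compressed annotations form a tensor<2xi1> constant,
  // which is cast to the size-agnostic tensor<?xi1> before the call.
  %d = constant dense<[false, true]> : tensor<2xi1>
  %c = tensor.cast %d : tensor<2xi1> to tensor<?xi1>
  %t = call @newSparseTensor(%ptr, %c, %secPtr, %secInd, %primary)
      : (!llvm.ptr<i8>, tensor<?xi1>, i64, i64, i64) -> !llvm.ptr<i8>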

mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
mlir/test/Dialect/SparseTensor/conversion.mlir

index 336e834..68adb6f 100644 (file)
@@ -19,6 +19,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
   MLIRSCF
   MLIRStandard
   MLIRSparseTensor
+  MLIRTensor
   MLIRTransforms
   MLIRVector
 )
index 71515fe..a2c7b85 100644 (file)
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 using namespace mlir;
@@ -103,15 +104,21 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
       return failure();
     // User pointer.
     params.push_back(operands[0]);
-    // Sparsity annotations.
+    // Sparsity annotations in tensor constant form. Note that we cast
+    // the static shape into a dynamic shape to ensure that the method
+    // signature remains uniform across different tensor dimensions.
     SmallVector<bool, 4> attrs;
     unsigned sz = enc.getDimLevelType().size();
     for (unsigned i = 0; i < sz; i++)
       attrs.push_back(enc.getDimLevelType()[i] ==
                       SparseTensorEncodingAttr::DimLevelType::Compressed);
-    auto elts = DenseElementsAttr::get(
-        RankedTensorType::get({sz}, rewriter.getIntegerType(1)), attrs);
-    params.push_back(rewriter.create<ConstantOp>(loc, elts));
+    Type etp = rewriter.getIntegerType(1);
+    RankedTensorType tt1 = RankedTensorType::get({sz}, etp);
+    RankedTensorType tt2 =
+        RankedTensorType::get({ShapedType::kDynamicSize}, etp);
+    auto elts =
+        rewriter.create<ConstantOp>(loc, DenseElementsAttr::get(tt1, attrs));
+    params.push_back(rewriter.create<tensor::CastOp>(loc, tt2, elts));
     // Secondary and primary types encoding.
     unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
     unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
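
For the 1-D vector case, the rewritten pattern above now emits, roughly (a sketch matching the updated sparse_new1d test further down; value names are illustrative):

  // The single dimension is compressed, so the annotation constant is
  // tensor<1xi1>; the cast drops the static size so the call matches the
  // shared runtime declaration.
  %d = constant dense<true> : tensor<1xi1>
  %c = tensor.cast %d : tensor<1xi1> to tensor<?xi1>
  %t = call @newSparseTensor(%arg0, %c, %secPtr, %secInd, %primary)
      : (!llvm.ptr<i8>, tensor<?xi1>, i64, i64, i64) -> !llvm.ptr<i8>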
index 641ba4a..05fd753 100644 (file)
@@ -120,6 +120,7 @@ struct SparseTensorConversionPass
     target.addDynamicallyLegalOp<ReturnOp>(
         [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); });
     target.addLegalOp<ConstantOp>();
+    target.addLegalOp<tensor::CastOp>();
     populateFuncOpTypeConversionPattern(patterns, converter);
     populateCallOpTypeConversionPattern(patterns, converter);
     populateSparseTensorConversionPatterns(converter, patterns);
index 54bfa74..c749665 100644 (file)
   indexBitWidth = 32
 }>
 
+#SparseMatrix = #sparse_tensor.encoding<{
+  dimLevelType = ["dense", "compressed"]
+}>
+
 // CHECK-LABEL: func @sparse_dim(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
 //       CHECK: %[[C:.*]] = constant 0 : index
@@ -27,15 +31,28 @@ func @sparse_dim(%arg0: tensor<?xf64, #SparseVector>) -> index {
   return %0 : index
 }
 
-// CHECK-LABEL: func @sparse_new(
+// CHECK-LABEL: func @sparse_new1d(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
-//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[A]]
+//       CHECK: %[[D:.*]] = constant dense<true> : tensor<1xi1>
+//       CHECK: %[[C:.*]] = tensor.cast %[[D]] : tensor<1xi1> to tensor<?xi1>
+//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[A]], %[[C]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, tensor<?xi1>, i64, i64, i64) -> !llvm.ptr<i8>
 //       CHECK: return %[[T]] : !llvm.ptr<i8>
-func @sparse_new(%arg0: !llvm.ptr<i8>) -> tensor<128xf64, #SparseVector> {
+func @sparse_new1d(%arg0: !llvm.ptr<i8>) -> tensor<128xf64, #SparseVector> {
   %0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<128xf64, #SparseVector>
   return %0 : tensor<128xf64, #SparseVector>
 }
 
+// CHECK-LABEL: func @sparse_new2d(
+//  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
+//       CHECK: %[[D:.*]] = constant dense<[false, true]> : tensor<2xi1>
+//       CHECK: %[[C:.*]] = tensor.cast %[[D]] : tensor<2xi1> to tensor<?xi1>
+//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[A]], %[[C]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, tensor<?xi1>, i64, i64, i64) -> !llvm.ptr<i8>
+//       CHECK: return %[[T]] : !llvm.ptr<i8>
+func @sparse_new2d(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf32, #SparseMatrix> {
+  %0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<?x?xf32, #SparseMatrix>
+  return %0 : tensor<?x?xf32, #SparseMatrix>
+}
+
 // CHECK-LABEL: func @sparse_pointers(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
 //       CHECK: %[[C:.*]] = constant 0 : index