[mlir][sparse] generalize sparse_tensor.convert on static/dynamic dimension sizes
authorAart Bik <ajcbik@google.com>
Fri, 15 Oct 2021 23:10:30 +0000 (16:10 -0700)
committerAart Bik <ajcbik@google.com>
Mon, 18 Oct 2021 20:54:03 +0000 (13:54 -0700)
This revison lifts the artificial restriction on having exact matches between
source and destination type shapes. A static size may become dynamic. We still
reject changing a dynamic size into a static size to avoid the need for a
runtime "assert" on the conversion. This revision also refactors some of the
conversion code to share same-content buffers.

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D111915

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
mlir/test/Dialect/SparseTensor/conversion.mlir
mlir/test/Dialect/SparseTensor/invalid.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_convert.mlir [new file with mode: 0644]

index c7e6e0a..d1724b4 100644 (file)
@@ -80,9 +80,9 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
   string summary = "Converts between different tensor types";
   string description = [{
     Converts one sparse or dense tensor type to another tensor type. The rank
-    and dimensions of the source and destination types must match exactly,
-    only the sparse encoding of these types may be different. The name `convert`
-    was preferred over `cast`, since the operation may incur a non-trivial cost.
+    and dimensions of the source and destination types must match, but the sparse
+    encoding of these types can obviously be different. The name `convert` was
+    preferred over `cast`, since the operation may incur a non-trivial cost.
 
     When converting between two different sparse tensor types, only explicitly
     stored values are moved from one underlying sparse storage format to
@@ -97,9 +97,14 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
     Examples:
 
     ```mlir
-    %0 = sparse_tensor.convert %1 : tensor<32x32xf32> to tensor<32x32xf32, #CSR>
-
-    %2 = sparse_tensor.convert %3 : tensor<8x8xi32, #CSC> to tensor<8x8xi32, #CSR>
+    %0 = sparse_tensor.convert %a : tensor<32x32xf32> to tensor<32x32xf32, #CSR>
+    %1 = sparse_tensor.convert %a : tensor<32x32xf32> to tensor<?x?xf32, #CSR>
+    %2 = sparse_tensor.convert %b : tensor<8x8xi32, #CSC> to tensor<8x8xi32, #CSR>
+    %3 = sparse_tensor.convert %c : tensor<4x8xf64, #CSR> to tensor<4x?xf64, #CSC>
+
+    // The following conversion is not allowed (since it would require a
+    // runtime assertion that the source's dimension size is actually 100).
+    %4 = sparse_tensor.convert %d : tensor<?xf64> to tensor<100xf64, #SV>
     ```
 
   }];
index 8a0e467..bb499be 100644 (file)
@@ -240,8 +240,11 @@ static LogicalResult verify(ConvertOp op) {
       assert(tp1.getRank() == tp2.getRank());
       auto shape1 = tp1.getShape();
       auto shape2 = tp2.getShape();
+      // Accept size matches between the source and the destination type
+      // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
+      // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
       for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++) {
-        if (shape1[d] != shape2[d])
+        if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamicSize)
           return op.emitError("unexpected conversion mismatch in dimension ")
                  << d;
       }
index e98b1fa..ffd852f 100644 (file)
@@ -99,7 +99,7 @@ inline static Value constantZero(ConversionPatternRewriter &rewriter,
 
 /// Generates a constant of `index` type.
 inline static Value constantIndex(ConversionPatternRewriter &rewriter,
-                                  Location loc, unsigned i) {
+                                  Location loc, int64_t i) {
   return rewriter.create<arith::ConstantIndexOp>(loc, i);
 }
 
@@ -144,6 +144,70 @@ static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
   return result;
 }
 
+/// Generates dimension size call.
+static Value genDimSizeCall(ConversionPatternRewriter &rewriter, Operation *op,
+                            SparseTensorEncodingAttr &enc, Value src,
+                            int64_t idx) {
+  // Permute the index according to an optional dimension ordering.
+  if (AffineMap p = enc.getDimOrdering())
+    idx = p.getPermutedPosition(idx);
+  // Generate the call.
+  Location loc = op->getLoc();
+  StringRef name = "sparseDimSize";
+  SmallVector<Value, 2> params;
+  params.push_back(src);
+  params.push_back(constantIndex(rewriter, loc, idx));
+  Type iTp = rewriter.getIndexType();
+  auto fn = getFunc(op, name, iTp, params);
+  return rewriter.create<CallOp>(loc, iTp, fn, params).getResult(0);
+}
+
+/// Generates a call into the "swiss army knife" method of the sparse runtime
+/// support library for materializing sparse tensors into the computation.
+static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
+                        ArrayRef<Value> params) {
+  Location loc = op->getLoc();
+  StringRef name = "newSparseTensor";
+  Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
+  auto fn = getFunc(op, name, pTp, params, /*emitCInterface=*/true);
+  auto call = rewriter.create<CallOp>(loc, pTp, fn, params);
+  return call.getResult(0);
+}
+
+/// Populates given sizes array from type.
+static void sizesFromType(ConversionPatternRewriter &rewriter,
+                          SmallVector<Value, 4> &sizes, Location loc,
+                          ShapedType stp) {
+  auto shape = stp.getShape();
+  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++) {
+    uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i];
+    sizes.push_back(constantIndex(rewriter, loc, s));
+  }
+}
+
+/// Populates given sizes array from source.
+static void sizesFromSrc(ConversionPatternRewriter &rewriter,
+                         SmallVector<Value, 4> &sizes, Location loc,
+                         Value src) {
+  ShapedType stp = src.getType().cast<ShapedType>();
+  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
+    sizes.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
+}
+
+/// Populates given sizes array from type (for static sizes) and from
+/// an already converted into opague pointer source (for dynamic sizes).
+static void sizesFromPtr(ConversionPatternRewriter &rewriter,
+                         SmallVector<Value, 4> &sizes, Operation *op,
+                         SparseTensorEncodingAttr &enc, ShapedType stp,
+                         Value src) {
+  auto shape = stp.getShape();
+  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
+    if (shape[i] == ShapedType::kDynamicSize)
+      sizes.push_back(genDimSizeCall(rewriter, op, enc, src, i));
+    else
+      sizes.push_back(constantIndex(rewriter, op->getLoc(), shape[i]));
+}
+
 /// Generates a temporary buffer of the given size and type.
 static Value genAlloca(ConversionPatternRewriter &rewriter, Location loc,
                        unsigned sz, Type tp) {
@@ -152,7 +216,7 @@ static Value genAlloca(ConversionPatternRewriter &rewriter, Location loc,
   return rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{a});
 }
 
-/// Fills a temporary buffer of the given type with arguments.
+/// Generates a temporary buffer of the given type and given contents.
 static Value genBuffer(ConversionPatternRewriter &rewriter, Location loc,
                        ArrayRef<Value> values) {
   unsigned sz = values.size();
@@ -165,36 +229,28 @@ static Value genBuffer(ConversionPatternRewriter &rewriter, Location loc,
   return buffer;
 }
 
-/// Generates a call into the "swiss army knife" method of the sparse runtime
-/// support library for materializing sparse tensors into the computation. The
-/// method returns the call value and assigns the permutation to 'perm'.
-static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
-                        SparseTensorEncodingAttr &enc, uint32_t action,
-                        Value &perm, ValueRange szs, Value ptr = Value()) {
+/// Populates parameters required to call the "swiss army knife" method of the
+/// sparse runtime support library for materializing sparse tensors into the
+/// computation.
+static void newParams(ConversionPatternRewriter &rewriter,
+                      SmallVector<Value, 8> &params, Operation *op,
+                      SparseTensorEncodingAttr &enc, uint32_t action,
+                      ValueRange szs, Value ptr = Value()) {
   Location loc = op->getLoc();
-  ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
-  SmallVector<Value, 8> params;
-  // Sparsity annotations in tensor constant form.
-  SmallVector<Value, 4> attrs;
   ArrayRef<SparseTensorEncodingAttr::DimLevelType> dlt = enc.getDimLevelType();
   unsigned sz = dlt.size();
+  // Sparsity annotations.
+  SmallVector<Value, 4> attrs;
   for (unsigned i = 0; i < sz; i++)
     attrs.push_back(constantI8(rewriter, loc, getDimLevelTypeEncoding(dlt[i])));
   params.push_back(genBuffer(rewriter, loc, attrs));
-  // Dimension sizes array of the enveloping *dense* tensor. Useful for either
+  // Dimension sizes array of the enveloping tensor. Useful for either
   // verification of external data, or for construction of internal data.
-  auto shape = resType.getShape();
+  // The index type is casted to I64 for API consistency.
+  Type iTp = rewriter.getI64Type();
   SmallVector<Value, 4> sizes;
-  if (szs.size() > 0) {
-    for (Value s : szs)
-      sizes.push_back(
-          rewriter.create<arith::IndexCastOp>(loc, s, rewriter.getI64Type()));
-  } else {
-    for (unsigned i = 0; i < sz; i++) {
-      uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i];
-      sizes.push_back(constantI64(rewriter, loc, s));
-    }
-  }
+  for (Value s : szs)
+    sizes.push_back(rewriter.create<arith::IndexCastOp>(loc, s, iTp));
   params.push_back(genBuffer(rewriter, loc, sizes));
   // Dimension order permutation array. This is the "identity" permutation by
   // default, or otherwise the "reverse" permutation of a given ordering, so
@@ -207,9 +263,9 @@ static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
     for (unsigned i = 0; i < sz; i++)
       rev[i] = constantI64(rewriter, loc, i);
   }
-  perm = genBuffer(rewriter, loc, rev);
-  params.push_back(perm);
+  params.push_back(genBuffer(rewriter, loc, rev));
   // Secondary and primary types encoding.
+  ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
   unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
   unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
   unsigned primary = getPrimaryTypeEncoding(resType.getElementType());
@@ -223,12 +279,6 @@ static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
     ptr = rewriter.create<LLVM::NullOp>(loc, pTp);
   params.push_back(constantI32(rewriter, loc, action));
   params.push_back(ptr);
-  // Generate the call to create new tensor.
-  StringRef name = "newSparseTensor";
-  auto call = rewriter.create<CallOp>(
-      loc, pTp, getFunc(op, name, pTp, params, /*emitCInterface=*/true),
-      params);
-  return call.getResult(0);
 }
 
 /// Generates the comparison `v != 0` where `v` is of numeric type `t`.
@@ -299,9 +349,8 @@ static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
   params.push_back(ind);
   params.push_back(perm);
   Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
-  rewriter.create<CallOp>(
-      loc, pTp, getFunc(op, name, pTp, params, /*emitCInterface=*/true),
-      params);
+  auto fn = getFunc(op, name, pTp, params, /*emitCInterface=*/true);
+  rewriter.create<CallOp>(loc, pTp, fn, params);
 }
 
 /// If the tensor is a sparse constant, generates and returns the pair of
@@ -362,24 +411,17 @@ public:
   LogicalResult
   matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Type resType = op.getType();
+    // Only rewrite annotated DimOp with constant index.
     auto enc = getSparseTensorEncoding(op.source().getType());
     if (!enc)
       return failure();
-    // Permute the dim index.
     Optional<int64_t> index = op.getConstantIndex();
     if (!index.hasValue())
       return failure();
-    int64_t idx = index.getValue();
-    if (AffineMap p = enc.getDimOrdering())
-      idx = p.getPermutedPosition(idx);
     // Generate the call.
-    StringRef name = "sparseDimSize";
-    SmallVector<Value, 2> params;
-    params.push_back(adaptor.getOperands()[0]);
-    params.push_back(constantIndex(rewriter, op.getLoc(), idx));
-    rewriter.replaceOpWithNewOp<CallOp>(
-        op, resType, getFunc(op, name, resType, params), params);
+    Value src = adaptor.getOperands()[0];
+    int64_t idx = index.getValue();
+    rewriter.replaceOp(op, genDimSizeCall(rewriter, op, enc, src, idx));
     return success();
   }
 };
@@ -394,9 +436,14 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
     auto enc = getSparseTensorEncoding(resType);
     if (!enc)
       return failure();
-    Value perm;
-    rewriter.replaceOp(op, genNewCall(rewriter, op, enc, kFromFile, perm, {},
-                                      adaptor.getOperands()[0]));
+    // Generate the call to construct tensor from ptr. The sizes are
+    // inferred from the result type of the new operator.
+    SmallVector<Value, 4> sizes;
+    SmallVector<Value, 8> params;
+    sizesFromType(rewriter, sizes, op.getLoc(), resType.cast<ShapedType>());
+    Value ptr = adaptor.getOperands()[0];
+    newParams(rewriter, params, op, enc, kFromFile, sizes, ptr);
+    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
     return success();
   }
 };
@@ -411,9 +458,11 @@ class SparseTensorInitConverter : public OpConversionPattern<InitOp> {
     auto enc = getSparseTensorEncoding(resType);
     if (!enc)
       return failure();
-    Value perm;
-    rewriter.replaceOp(
-        op, genNewCall(rewriter, op, enc, kEmpty, perm, adaptor.getOperands()));
+    // Generate the call to construct empty tensor. The sizes are
+    // explicitly defined by the arguments to the init operator.
+    SmallVector<Value, 8> params;
+    newParams(rewriter, params, op, enc, kEmpty, adaptor.getOperands());
+    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
     return success();
   }
 };
@@ -424,10 +473,12 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
   LogicalResult
   matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
     Type resType = op.getType();
+    Type srcType = op.source().getType();
     auto encDst = getSparseTensorEncoding(resType);
-    auto encSrc = getSparseTensorEncoding(op.source().getType());
-    auto src = adaptor.getOperands()[0];
+    auto encSrc = getSparseTensorEncoding(srcType);
+    Value src = adaptor.getOperands()[0];
     if (encDst && encSrc) {
       // This is a sparse => sparse conversion, which is handled as follows:
       //   t = src->toCOO();         ; src to COO in dst order
@@ -435,10 +486,15 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
       // Using the coordinate scheme as an intermediate does not always
       // yield the fastest conversion but avoids the need for a full
       // O(N^2) conversion matrix.
-      Value perm;
-      Value coo = genNewCall(rewriter, op, encDst, kToCOO, perm, {}, src);
-      rewriter.replaceOp(
-          op, genNewCall(rewriter, op, encDst, kFromCOO, perm, {}, coo));
+      SmallVector<Value, 4> sizes;
+      SmallVector<Value, 8> params;
+      sizesFromPtr(rewriter, sizes, op, encSrc, srcType.cast<ShapedType>(),
+                   src);
+      newParams(rewriter, params, op, encDst, kToCOO, sizes, src);
+      Value coo = genNewCall(rewriter, op, params);
+      params[6] = constantI32(rewriter, loc, kFromCOO);
+      params[7] = coo;
+      rewriter.replaceOp(op, genNewCall(rewriter, op, params));
       return success();
     }
     if (!encDst || encSrc) {
@@ -471,12 +527,15 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
     // Also note that the code below only generates the "new" ops and
     // the loop-nest per se; whereas the entire body of the innermost
     // loop is generated by genAddElt().
-    Location loc = op->getLoc();
-    ShapedType shape = resType.cast<ShapedType>();
-    Value perm;
-    Value ptr = genNewCall(rewriter, op, encDst, kEmptyCOO, perm, {});
-    Value ind =
-        genAlloca(rewriter, loc, shape.getRank(), rewriter.getIndexType());
+    ShapedType stp = resType.cast<ShapedType>();
+    unsigned rank = stp.getRank();
+    SmallVector<Value, 4> sizes;
+    SmallVector<Value, 8> params;
+    sizesFromSrc(rewriter, sizes, loc, src);
+    newParams(rewriter, params, op, encDst, kEmptyCOO, sizes);
+    Value ptr = genNewCall(rewriter, op, params);
+    Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
+    Value perm = params[2];
     SmallVector<Value> lo;
     SmallVector<Value> hi;
     SmallVector<Value> st;
@@ -493,14 +552,13 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
       hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, values, 0));
       st.push_back(one);
     } else {
-      for (unsigned i = 0, rank = shape.getRank(); i < rank; i++) {
+      for (unsigned i = 0; i < rank; i++) {
         lo.push_back(zero);
         hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
         st.push_back(one);
       }
     }
-    Type eltType = shape.getElementType();
-    unsigned rank = shape.getRank();
+    Type eltType = stp.getElementType();
     scf::buildLoopNest(
         rewriter, op.getLoc(), lo, hi, st, {},
         [&](OpBuilder &builder, Location loc, ValueRange ivs,
@@ -514,8 +572,10 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
           genAddEltCall(rewriter, op, eltType, ptr, val, ind, perm);
           return {};
         });
-    rewriter.replaceOp(
-        op, genNewCall(rewriter, op, encDst, kFromCOO, perm, {}, ptr));
+    // Final call to construct sparse tensor storage.
+    params[6] = constantI32(rewriter, loc, kFromCOO);
+    params[7] = ptr;
+    rewriter.replaceOp(op, genNewCall(rewriter, op, params));
     return success();
   }
 };
@@ -529,9 +589,8 @@ public:
                   ConversionPatternRewriter &rewriter) const override {
     StringRef name = "delSparseTensor";
     TypeRange none;
-    rewriter.create<CallOp>(op.getLoc(), none,
-                            getFunc(op, name, none, adaptor.getOperands()),
-                            adaptor.getOperands());
+    auto fn = getFunc(op, name, none, adaptor.getOperands());
+    rewriter.create<CallOp>(op.getLoc(), none, fn, adaptor.getOperands());
     rewriter.eraseOp(op);
     return success();
   }
@@ -560,11 +619,9 @@ public:
       name = "sparsePointers8";
     else
       return failure();
-    rewriter.replaceOpWithNewOp<CallOp>(op, resType,
-                                        getFunc(op, name, resType,
-                                                adaptor.getOperands(),
-                                                /*emitCInterface=*/true),
-                                        adaptor.getOperands());
+    auto fn = getFunc(op, name, resType, adaptor.getOperands(),
+                      /*emitCInterface=*/true);
+    rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
     return success();
   }
 };
@@ -591,11 +648,9 @@ public:
       name = "sparseIndices8";
     else
       return failure();
-    rewriter.replaceOpWithNewOp<CallOp>(op, resType,
-                                        getFunc(op, name, resType,
-                                                adaptor.getOperands(),
-                                                /*emitCInterface=*/true),
-                                        adaptor.getOperands());
+    auto fn = getFunc(op, name, resType, adaptor.getOperands(),
+                      /*emitCInterface=*/true);
+    rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
     return success();
   }
 };
@@ -624,11 +679,9 @@ public:
       name = "sparseValuesI8";
     else
       return failure();
-    rewriter.replaceOpWithNewOp<CallOp>(op, resType,
-                                        getFunc(op, name, resType,
-                                                adaptor.getOperands(),
-                                                /*emitCInterface=*/true),
-                                        adaptor.getOperands());
+    auto fn = getFunc(op, name, resType, adaptor.getOperands(),
+                      /*emitCInterface=*/true);
+    rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
     return success();
   }
 };
index d6e4307..577b79c 100644 (file)
@@ -127,8 +127,8 @@ func @sparse_new3d(%arg0: !llvm.ptr<i8>) -> tensor<?x?x?xf32, #SparseTensor> {
 //   CHECK-DAG: %[[JJ:.*]] = arith.index_cast %[[J]] : index to i64
 //   CHECK-DAG: memref.store %[[II]], %[[Q]][%[[C0]]] : memref<2xi64>
 //   CHECK-DAG: memref.store %[[JJ]], %[[Q]][%[[C1]]] : memref<2xi64>
-//       CHECK: %[[A:.*]] = llvm.mlir.null : !llvm.ptr<i8>
-//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[A]])
+//       CHECK: %[[NP:.*]] = llvm.mlir.null : !llvm.ptr<i8>
+//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[NP]])
 //       CHECK: return %[[T]] : !llvm.ptr<i8>
 func @sparse_init(%arg0: index, %arg1: index) -> tensor<?x?xf64, #SparseMatrix> {
   %0 = sparse_tensor.init [%arg0, %arg1] : tensor<?x?xf64, #SparseMatrix>
@@ -156,22 +156,23 @@ func @sparse_nop_convert(%arg0: tensor<64xf32, #SparseVector>) -> tensor<64xf32,
 //  CHECK-SAME: %[[A:.*]]: tensor<?xi32>) -> !llvm.ptr<i8>
 //   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 //   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+//   CHECK-DAG: %[[U:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?xi32>
 //   CHECK-DAG: %[[P:.*]] = memref.alloca() : memref<1xi8>
 //   CHECK-DAG: %[[Q:.*]] = memref.alloca() : memref<1xi64>
 //   CHECK-DAG: %[[R:.*]] = memref.alloca() : memref<1xi64>
 //   CHECK-DAG: %[[X:.*]] = memref.cast %[[P]] : memref<1xi8> to memref<?xi8>
 //   CHECK-DAG: %[[Y:.*]] = memref.cast %[[Q]] : memref<1xi64> to memref<?xi64>
 //   CHECK-DAG: %[[Z:.*]] = memref.cast %[[R]] : memref<1xi64> to memref<?xi64>
-//       CHECK: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.}})
+//       CHECK: %[[NP:.*]] = llvm.mlir.null : !llvm.ptr<i8>
+//       CHECK: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[NP]])
 //       CHECK: %[[M:.*]] = memref.alloca() : memref<1xindex>
 //       CHECK: %[[T:.*]] = memref.cast %[[M]] : memref<1xindex> to memref<?xindex>
-//       CHECK: %[[U:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?xi32>
 //       CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[U]] step %[[C1]] {
 //       CHECK:   %[[E:.*]] = tensor.extract %[[A]][%[[I]]] : tensor<?xi32>
 //       CHECK:   memref.store %[[I]], %[[M]][%[[C0]]] : memref<1xindex>
 //       CHECK:   call @addEltI32(%[[C]], %[[E]], %[[T]], %[[Z]])
 //       CHECK: }
-//       CHECK: %[[T:.*]] = call @newSparseTensor(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C]])
+//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C]])
 //       CHECK: return %[[T]] : !llvm.ptr<i8>
 func @sparse_convert_1d(%arg0: tensor<?xi32>) -> tensor<?xi32, #SparseVector> {
   %0 = sparse_tensor.convert %arg0 : tensor<?xi32> to tensor<?xi32, #SparseVector>
@@ -180,8 +181,14 @@ func @sparse_convert_1d(%arg0: tensor<?xi32>) -> tensor<?xi32, #SparseVector> {
 
 // CHECK-LABEL: func @sparse_convert_1d_ss(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
-//       CHECK: %[[C:.*]] = call @newSparseTensor(%{{.}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[A]])
-//       CHECK: %[[T:.*]] = call @newSparseTensor(%{{.}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C]])
+//   CHECK-DAG: %[[P:.*]] = memref.alloca() : memref<1xi8>
+//   CHECK-DAG: %[[Q:.*]] = memref.alloca() : memref<1xi64>
+//   CHECK-DAG: %[[R:.*]] = memref.alloca() : memref<1xi64>
+//   CHECK-DAG: %[[X:.*]] = memref.cast %[[P]] : memref<1xi8> to memref<?xi8>
+//   CHECK-DAG: %[[Y:.*]] = memref.cast %[[Q]] : memref<1xi64> to memref<?xi64>
+//   CHECK-DAG: %[[Z:.*]] = memref.cast %[[R]] : memref<1xi64> to memref<?xi64>
+//       CHECK: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[A]])
+//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C]])
 //       CHECK: return %[[T]] : !llvm.ptr<i8>
 func @sparse_convert_1d_ss(%arg0: tensor<?xf32, #SparseVector64>) -> tensor<?xf32, #SparseVector32> {
   %0 = sparse_tensor.convert %arg0 : tensor<?xf32, #SparseVector64> to tensor<?xf32, #SparseVector32>
@@ -198,7 +205,8 @@ func @sparse_convert_1d_ss(%arg0: tensor<?xf32, #SparseVector64>) -> tensor<?xf3
 //   CHECK-DAG: %[[X:.*]] = memref.cast %[[P]] : memref<2xi8> to memref<?xi8>
 //   CHECK-DAG: %[[Y:.*]] = memref.cast %[[Q]] : memref<2xi64> to memref<?xi64>
 //   CHECK-DAG: %[[Z:.*]] = memref.cast %[[R]] : memref<2xi64> to memref<?xi64>
-//       CHECK: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.}})
+//       CHECK: %[[NP:.*]] = llvm.mlir.null : !llvm.ptr<i8>
+//       CHECK: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[NP]])
 //       CHECK: %[[M:.*]] = memref.alloca() : memref<2xindex>
 //       CHECK: %[[T:.*]] = memref.cast %[[M]] : memref<2xindex> to memref<?xindex>
 //       CHECK: scf.for %[[I:.*]] = %[[C0]] to %{{.*}} step %[[C1]] {
@@ -209,7 +217,7 @@ func @sparse_convert_1d_ss(%arg0: tensor<?xf32, #SparseVector64>) -> tensor<?xf3
 //       CHECK:     call @addEltF64(%[[C]], %[[E]], %[[T]], %[[Z]])
 //       CHECK:   }
 //       CHECK: }
-//       CHECK: %[[T:.*]] = call @newSparseTensor(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C]])
+//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C]])
 //       CHECK: return %[[T]] : !llvm.ptr<i8>
 func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #SparseMatrix> {
   %0 = sparse_tensor.convert %arg0 : tensor<2x4xf64> to tensor<2x4xf64, #SparseMatrix>
@@ -226,7 +234,8 @@ func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #SparseMatrix
 //   CHECK-DAG: %[[X:.*]] = memref.cast %[[P]] : memref<2xi8> to memref<?xi8>
 //   CHECK-DAG: %[[Y:.*]] = memref.cast %[[Q]] : memref<2xi64> to memref<?xi64>
 //   CHECK-DAG: %[[Z:.*]] = memref.cast %[[R]] : memref<2xi64> to memref<?xi64>
-//       CHECK: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.}})
+//       CHECK: %[[NP:.*]] = llvm.mlir.null : !llvm.ptr<i8>
+//       CHECK: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[NP]])
 //       CHECK: %[[M:.*]] = memref.alloca() : memref<2xindex>
 //       CHECK: %[[N:.*]] = memref.cast %[[M]] : memref<2xindex> to memref<?xindex>
 //       CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C2]] step %[[C1]] {
@@ -235,7 +244,7 @@ func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #SparseMatrix
 //       CHECK:   %[[V:.*]] = tensor.extract %{{.*}}[%[[I]]] : tensor<2xf32>
 //       CHECK:   call @addEltF32(%{{.*}}, %[[V]], %[[N]], %{{.*}})
 //       CHECK: }
-//       CHECK: %[[T:.*]] = call @newSparseTensor(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C]])
+//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C]])
 //       CHECK: return %[[T]] : !llvm.ptr<i8>
 func @sparse_constant() -> tensor<8x7xf32, #SparseMatrix>{
   // Initialize a tensor.
@@ -250,18 +259,19 @@ func @sparse_constant() -> tensor<8x7xf32, #SparseMatrix>{
 //   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 //   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
 //   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+//   CHECK-DAG: %[[U1:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?x?x?xf64>
+//   CHECK-DAG: %[[U2:.*]] = tensor.dim %[[A]], %[[C1]] : tensor<?x?x?xf64>
+//   CHECK-DAG: %[[U3:.*]] = tensor.dim %[[A]], %[[C2]] : tensor<?x?x?xf64>
 //   CHECK-DAG: %[[P:.*]] = memref.alloca() : memref<3xi8>
 //   CHECK-DAG: %[[Q:.*]] = memref.alloca() : memref<3xi64>
 //   CHECK-DAG: %[[R:.*]] = memref.alloca() : memref<3xi64>
 //   CHECK-DAG: %[[X:.*]] = memref.cast %[[P]] : memref<3xi8> to memref<?xi8>
 //   CHECK-DAG: %[[Y:.*]] = memref.cast %[[Q]] : memref<3xi64> to memref<?xi64>
 //   CHECK-DAG: %[[Z:.*]] = memref.cast %[[R]] : memref<3xi64> to memref<?xi64>
-//       CHECK: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.}})
+//       CHECK: %[[NP:.*]] = llvm.mlir.null : !llvm.ptr<i8>
+//       CHECK: %[[C:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[NP]])
 //       CHECK: %[[M:.*]] = memref.alloca() : memref<3xindex>
 //       CHECK: %[[N:.*]] = memref.cast %[[M]] : memref<3xindex> to memref<?xindex>
-//       CHECK: %[[U1:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?x?x?xf64>
-//       CHECK: %[[U2:.*]] = tensor.dim %[[A]], %[[C1]] : tensor<?x?x?xf64>
-//       CHECK: %[[U3:.*]] = tensor.dim %[[A]], %[[C2]] : tensor<?x?x?xf64>
 //       CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[U1]] step %[[C1]] {
 //       CHECK:   scf.for %[[J:.*]] = %[[C0]] to %[[U2]] step %[[C1]] {
 //       CHECK:     scf.for %[[K:.*]] = %[[C0]] to %[[U3]] step %[[C1]] {
@@ -273,7 +283,7 @@ func @sparse_constant() -> tensor<8x7xf32, #SparseMatrix>{
 //       CHECK:     }
 //       CHECK:   }
 //       CHECK: }
-//       CHECK: %[[T:.*]] = call @newSparseTensor(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C]])
+//       CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C]])
 //       CHECK: return %[[T]] : !llvm.ptr<i8>
 func @sparse_convert_3d(%arg0: tensor<?x?x?xf64>) -> tensor<?x?x?xf64, #SparseTensor> {
   %0 = sparse_tensor.convert %arg0 : tensor<?x?x?xf64> to tensor<?x?x?xf64, #SparseTensor>
index 03abcb5..8955359 100644 (file)
@@ -162,8 +162,8 @@ func @sparse_convert_unranked(%arg0: tensor<*xf32>) -> tensor<10xf32> {
 
 #CSR = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>
 
-func @sparse_convert_mismatch(%arg0: tensor<10x10xf32>) -> tensor<10x?xf32, #CSR> {
+func @sparse_convert_mismatch(%arg0: tensor<10x?xf32>) -> tensor<10x10xf32, #CSR> {
   // expected-error@+1 {{unexpected conversion mismatch in dimension 1}}
-  %0 = sparse_tensor.convert %arg0 : tensor<10x10xf32> to tensor<10x?xf32, #CSR>
-  return %0 : tensor<10x?xf32, #CSR>
+  %0 = sparse_tensor.convert %arg0 : tensor<10x?xf32> to tensor<10x10xf32, #CSR>
+  return %0 : tensor<10x10xf32, #CSR>
 }
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_convert.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_convert.mlir
new file mode 100644 (file)
index 0000000..3bfd0df
--- /dev/null
@@ -0,0 +1,91 @@
+// RUN: mlir-opt %s \
+// RUN:   --sparsification --sparse-tensor-conversion \
+// RUN:   --linalg-bufferize --convert-linalg-to-loops \
+// RUN:   --convert-vector-to-scf --convert-scf-to-std \
+// RUN:   --func-bufferize --tensor-constant-bufferize --tensor-bufferize \
+// RUN:   --std-bufferize --finalizing-bufferize --lower-affine \
+// RUN:   --convert-vector-to-llvm --convert-memref-to-llvm --convert-math-to-llvm \
+// RUN:   --convert-std-to-llvm --reconcile-unrealized-casts | \
+// RUN: mlir-cpu-runner \
+// RUN:  -e entry -entry-point-result=void  \
+// RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+#DCSR  = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed", "compressed" ]
+}>
+
+#DCSC  = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed", "compressed" ],
+  dimOrdering = affine_map<(i,j) -> (j,i)>
+}>
+
+//
+// Integration test that tests conversions between sparse tensors,
+// where the dynamic sizes of the shape of the enveloping tensor
+// may change (the actual underlying sizes obviously never change).
+//
+module {
+
+  //
+  // Helper method to print values array. The transfer actually
+  // reads more than required to verify size of buffer as well.
+  //
+  func @dump(%arg0: memref<?xf64>) {
+    %c = arith.constant 0 : index
+    %d = arith.constant -1.0 : f64
+    %0 = vector.transfer_read %arg0[%c], %d: memref<?xf64>, vector<8xf64>
+    vector.print %0 : vector<8xf64>
+    return
+  }
+
+  func @entry() {
+    %t1 = arith.constant sparse<
+      [ [0,0], [0,1], [0,63], [1,0], [1,1], [31,0], [31,63] ],
+        [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0 ]> : tensor<32x64xf64>
+    %t2 = tensor.cast %t1 : tensor<32x64xf64> to tensor<?x?xf64>
+
+    // Four dense to sparse conversions.
+    %1 = sparse_tensor.convert %t1 : tensor<32x64xf64> to tensor<?x?xf64, #DCSR>
+    %2 = sparse_tensor.convert %t1 : tensor<32x64xf64> to tensor<?x?xf64, #DCSC>
+    %3 = sparse_tensor.convert %t2 : tensor<?x?xf64> to tensor<?x?xf64, #DCSR>
+    %4 = sparse_tensor.convert %t2 : tensor<?x?xf64> to tensor<?x?xf64, #DCSC>
+
+    // Two cross conversions.
+    %5 = sparse_tensor.convert %3 : tensor<?x?xf64, #DCSR> to tensor<?x?xf64, #DCSC>
+    %6 = sparse_tensor.convert %4 : tensor<?x?xf64, #DCSC> to tensor<?x?xf64, #DCSR>
+
+    //
+    // All proper row-/column-wise?
+    //
+    // CHECK: ( 1, 2, 3, 4, 5, 6, 7, -1 )
+    // CHECK: ( 1, 4, 6, 2, 5, 3, 7, -1 )
+    // CHECK: ( 1, 2, 3, 4, 5, 6, 7, -1 )
+    // CHECK: ( 1, 4, 6, 2, 5, 3, 7, -1 )
+    // CHECK: ( 1, 4, 6, 2, 5, 3, 7, -1 )
+    // CHECK: ( 1, 2, 3, 4, 5, 6, 7, -1 )
+    //
+    %m1 = sparse_tensor.values %1 : tensor<?x?xf64, #DCSR> to memref<?xf64>
+    %m2 = sparse_tensor.values %2 : tensor<?x?xf64, #DCSC> to memref<?xf64>
+    %m3 = sparse_tensor.values %3 : tensor<?x?xf64, #DCSR> to memref<?xf64>
+    %m4 = sparse_tensor.values %4 : tensor<?x?xf64, #DCSC> to memref<?xf64>
+    %m5 = sparse_tensor.values %5 : tensor<?x?xf64, #DCSC> to memref<?xf64>
+    %m6 = sparse_tensor.values %6 : tensor<?x?xf64, #DCSR> to memref<?xf64>
+    call @dump(%m1) : (memref<?xf64>) -> ()
+    call @dump(%m2) : (memref<?xf64>) -> ()
+    call @dump(%m3) : (memref<?xf64>) -> ()
+    call @dump(%m4) : (memref<?xf64>) -> ()
+    call @dump(%m5) : (memref<?xf64>) -> ()
+    call @dump(%m6) : (memref<?xf64>) -> ()
+
+    // Release the resources.
+    sparse_tensor.release %1 : tensor<?x?xf64, #DCSR>
+    sparse_tensor.release %2 : tensor<?x?xf64, #DCSC>
+    sparse_tensor.release %3 : tensor<?x?xf64, #DCSR>
+    sparse_tensor.release %4 : tensor<?x?xf64, #DCSC>
+    sparse_tensor.release %5 : tensor<?x?xf64, #DCSC>
+    sparse_tensor.release %6 : tensor<?x?xf64, #DCSR>
+
+    return
+  }
+}