#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
return failure();
// User pointer.
params.push_back(operands[0]);
- // Sparsity annotations.
+ // Sparsity annotations in tensor constant form. Note that we cast
+ // the static shape into a dynamic shape to ensure that the method
+ // signature remains uniform across different tensor dimensions.
SmallVector<bool, 4> attrs;
unsigned sz = enc.getDimLevelType().size();
for (unsigned i = 0; i < sz; i++)
attrs.push_back(enc.getDimLevelType()[i] ==
SparseTensorEncodingAttr::DimLevelType::Compressed);
- auto elts = DenseElementsAttr::get(
- RankedTensorType::get({sz}, rewriter.getIntegerType(1)), attrs);
- params.push_back(rewriter.create<ConstantOp>(loc, elts));
+ Type etp = rewriter.getIntegerType(1);
+ RankedTensorType tt1 = RankedTensorType::get({sz}, etp);
+ RankedTensorType tt2 =
+ RankedTensorType::get({ShapedType::kDynamicSize}, etp);
+ auto elts =
+ rewriter.create<ConstantOp>(loc, DenseElementsAttr::get(tt1, attrs));
+ params.push_back(rewriter.create<tensor::CastOp>(loc, tt2, elts));
// Secondary and primary types encoding.
unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
indexBitWidth = 32
}>
+#SparseMatrix = #sparse_tensor.encoding<{
+ dimLevelType = ["dense", "compressed"]
+}>
+
// CHECK-LABEL: func @sparse_dim(
// CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
// CHECK: %[[C:.*]] = constant 0 : index
return %0 : index
}
-// CHECK-LABEL: func @sparse_new(
+// CHECK-LABEL: func @sparse_new1d(
// CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
-// CHECK: %[[T:.*]] = call @newSparseTensor(%[[A]]
+// CHECK: %[[D:.*]] = constant dense<true> : tensor<1xi1>
+// CHECK: %[[C:.*]] = tensor.cast %[[D]] : tensor<1xi1> to tensor<?xi1>
+// CHECK: %[[T:.*]] = call @newSparseTensor(%[[A]], %[[C]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, tensor<?xi1>, i64, i64, i64) -> !llvm.ptr<i8>
// CHECK: return %[[T]] : !llvm.ptr<i8>
-func @sparse_new(%arg0: !llvm.ptr<i8>) -> tensor<128xf64, #SparseVector> {
+func @sparse_new1d(%arg0: !llvm.ptr<i8>) -> tensor<128xf64, #SparseVector> {
%0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<128xf64, #SparseVector>
return %0 : tensor<128xf64, #SparseVector>
}
+// CHECK-LABEL: func @sparse_new2d(
+// CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
+// CHECK: %[[D:.*]] = constant dense<[false, true]> : tensor<2xi1>
+// CHECK: %[[C:.*]] = tensor.cast %[[D]] : tensor<2xi1> to tensor<?xi1>
+// CHECK: %[[T:.*]] = call @newSparseTensor(%[[A]], %[[C]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, tensor<?xi1>, i64, i64, i64) -> !llvm.ptr<i8>
+// CHECK: return %[[T]] : !llvm.ptr<i8>
+func @sparse_new2d(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf32, #SparseMatrix> {
+ %0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<?x?xf32, #SparseMatrix>
+ return %0 : tensor<?x?xf32, #SparseMatrix>
+}
+
// CHECK-LABEL: func @sparse_pointers(
// CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
// CHECK: %[[C:.*]] = constant 0 : index