(These factories are used in downstream code, despite not being used within the MLIR codebase.)
Depends On D151513
Reviewed By: Peiming
Differential Revision: https://reviews.llvm.org/D151518
/// reset to the default, and all other fields inherited from `this`.
SparseTensorEncodingAttr withoutBitWidths() const;
+ /// Constructs a new encoding with the given dimSlices, and all
+ /// other fields inherited from `this`.
+ SparseTensorEncodingAttr withDimSlices(ArrayRef<::mlir::sparse_tensor::SparseTensorDimSliceAttr> dimSlices) const;
+
+ /// Constructs a new encoding with the dimSlices reset to the default,
+ /// and all other fields inherited from `this`.
+ SparseTensorEncodingAttr withoutDimSlices() const;
+
//
// Rank methods.
//
return withEncoding(enc.withoutBitWidths());
}
+ SparseTensorType
+ withDimSlices(ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
+ return withEncoding(enc.withDimSlices(dimSlices));
+ }
+
+ SparseTensorType withoutDimSlices() const {
+ return withEncoding(enc.withoutDimSlices());
+ }
+
//
// Other methods.
//
return withBitWidths(0, 0);
}
+SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
+ ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
+ return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(),
+ getDimToLvl(), getPosWidth(),
+ getCrdWidth(), dimSlices);
+}
+
+SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
+ return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{});
+}
+
bool SparseTensorEncodingAttr::isAllDense() const {
return !getImpl() || llvm::all_of(getLvlTypes(), isDenseDLT);
}
// TODO: We should check these in ExtractSliceOp::verify.
if (!srcEnc || !dstEnc || !dstEnc.isSlice())
return failure();
- assert(srcEnc.getLvlTypes() == dstEnc.getLvlTypes());
- assert(srcEnc.getDimToLvl() == dstEnc.getDimToLvl());
- assert(srcEnc.getPosWidth() == dstEnc.getPosWidth());
- assert(srcEnc.getCrdWidth() == dstEnc.getCrdWidth());
+ assert(srcEnc.withoutDimSlices() == dstEnc.withoutDimSlices());
SmallVector<Value> fields;
auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields);