From: wren romano <2998727+wrengr@users.noreply.github.com> Date: Tue, 30 May 2023 21:16:17 +0000 (-0700) Subject: [mlir][sparse] Adding new STEA::{with,without}DimSlices factories X-Git-Tag: upstream/17.0.6~6743 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=af2bec7c4a967c9e2e009cdbc4470eb5ba8332f6;p=platform%2Fupstream%2Fllvm.git [mlir][sparse] Adding new STEA::{with,without}DimSlices factories (These factories are used in downstream code, despite not being used within the MLIR codebase.) Depends On D151513 Reviewed By: Peiming Differential Revision: https://reviews.llvm.org/D151518 --- diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td index f0a502e..9fe425a 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td @@ -304,6 +304,14 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding", /// reset to the default, and all other fields inherited from `this`. SparseTensorEncodingAttr withoutBitWidths() const; + /// Constructs a new encoding with the given dimSlices, and all + /// other fields inherited from `this`. + SparseTensorEncodingAttr withDimSlices(ArrayRef<::mlir::sparse_tensor::SparseTensorDimSliceAttr> dimSlices) const; + + /// Constructs a new encoding with the dimSlices reset to the default, + /// and all other fields inherited from `this`. + SparseTensorEncodingAttr withoutDimSlices() const; + // // Rank methods. // diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h index 6cae09d..cfc3374 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h @@ -111,6 +111,15 @@ public: return withEncoding(enc.withoutBitWidths()); } + SparseTensorType + withDimSlices(ArrayRef dimSlices) const { + return withEncoding(enc.withDimSlices(dimSlices)); + } + + SparseTensorType withoutDimSlices() const { + return withEncoding(enc.withoutDimSlices()); + } + // // Other methods. // diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index 962e0ac21..a1eda89 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -291,6 +291,17 @@ SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const { return withBitWidths(0, 0); } +SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices( + ArrayRef dimSlices) const { + return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(), + getDimToLvl(), getPosWidth(), + getCrdWidth(), dimSlices); +} + +SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const { + return withDimSlices(ArrayRef{}); +} + bool SparseTensorEncodingAttr::isAllDense() const { return !getImpl() || llvm::all_of(getLvlTypes(), isDenseDLT); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp index f84009c..a7f37e81 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -1138,10 +1138,7 @@ public: // TODO: We should check these in ExtractSliceOp::verify. if (!srcEnc || !dstEnc || !dstEnc.isSlice()) return failure(); - assert(srcEnc.getLvlTypes() == dstEnc.getLvlTypes()); - assert(srcEnc.getDimToLvl() == dstEnc.getDimToLvl()); - assert(srcEnc.getPosWidth() == dstEnc.getPosWidth()); - assert(srcEnc.getCrdWidth() == dstEnc.getCrdWidth()); + assert(srcEnc.withoutDimSlices() == dstEnc.withoutDimSlices()); SmallVector fields; auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields);