[mlir][sparse] Adding new STEA::{with,without}DimSlices factories
authorwren romano <2998727+wrengr@users.noreply.github.com>
Tue, 30 May 2023 21:16:17 +0000 (14:16 -0700)
committerwren romano <2998727+wrengr@users.noreply.github.com>
Tue, 30 May 2023 22:53:30 +0000 (15:53 -0700)
(These factories are used in downstream code, despite not being used within the MLIR codebase.)

Depends On D151513

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D151518

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp

index f0a502e..9fe425a 100644 (file)
@@ -304,6 +304,14 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     /// reset to the default, and all other fields inherited from `this`.
     SparseTensorEncodingAttr withoutBitWidths() const;
 
+    /// Constructs a new encoding with the given dimSlices, and all
+    /// other fields inherited from `this`.
+    SparseTensorEncodingAttr withDimSlices(ArrayRef<::mlir::sparse_tensor::SparseTensorDimSliceAttr> dimSlices) const;
+
+    /// Constructs a new encoding with the dimSlices reset to the default,
+    /// and all other fields inherited from `this`.
+    SparseTensorEncodingAttr withoutDimSlices() const;
+
     //
     // Rank methods.
     //
index 6cae09d..cfc3374 100644 (file)
@@ -111,6 +111,15 @@ public:
     return withEncoding(enc.withoutBitWidths());
   }
 
+  SparseTensorType
+  withDimSlices(ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
+    return withEncoding(enc.withDimSlices(dimSlices));
+  }
+
+  SparseTensorType withoutDimSlices() const {
+    return withEncoding(enc.withoutDimSlices());
+  }
+
   //
   // Other methods.
   //
index 962e0ac..a1eda89 100644 (file)
@@ -291,6 +291,17 @@ SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
   return withBitWidths(0, 0);
 }
 
+SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
+    ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
+  return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(),
+                                       getDimToLvl(), getPosWidth(),
+                                       getCrdWidth(), dimSlices);
+}
+
+SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
+  return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{});
+}
+
 bool SparseTensorEncodingAttr::isAllDense() const {
   return !getImpl() || llvm::all_of(getLvlTypes(), isDenseDLT);
 }
index f84009c..a7f37e8 100644 (file)
@@ -1138,10 +1138,7 @@ public:
     // TODO: We should check these in ExtractSliceOp::verify.
     if (!srcEnc || !dstEnc || !dstEnc.isSlice())
       return failure();
-    assert(srcEnc.getLvlTypes() == dstEnc.getLvlTypes());
-    assert(srcEnc.getDimToLvl() == dstEnc.getDimToLvl());
-    assert(srcEnc.getPosWidth() == dstEnc.getPosWidth());
-    assert(srcEnc.getCrdWidth() == dstEnc.getCrdWidth());
+    assert(srcEnc.withoutDimSlices() == dstEnc.withoutDimSlices());
 
     SmallVector<Value> fields;
     auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields);