[mlir][sparse] Improving SparseTensorDimSliceAttr methods
authorwren romano <2998727+wrengr@users.noreply.github.com>
Tue, 30 May 2023 20:31:49 +0000 (13:31 -0700)
committerwren romano <2998727+wrengr@users.noreply.github.com>
Wed, 31 May 2023 00:30:55 +0000 (17:30 -0700)
This patch makes the following changes to `SparseTensorDimSliceAttr` methods:
* Mark `isDynamic` constexpr.
* Add new helpers `getStatic` and `getStaticString` to avoid repetition.
* Moved the definitions for `getStatic{Offset,Stride,Size}` and `isCompletelyDynamic` out of the class declaration; because there's no benefit to inlining them.
* Changed `parse` to use `kDynamic` rather than literals.
* Changed `verify` to use the `isDynamic` helper.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D150919

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

index 9fe425a..d6c971b 100644 (file)
@@ -76,32 +76,14 @@ def SparseTensorDimSliceAttr : SparseTensor_Attr<"SparseTensorDimSlice", []> {
   let extraClassDeclaration = [{
     /// Special value for dynamic offset/size/stride.
     static constexpr int64_t kDynamic = -1;
-
-    static bool isDynamic(int64_t v) {
-      return v == kDynamic;
-    }
-
-    std::optional<uint64_t> getStaticOffset() const {
-      if (isDynamic(getOffset()))
-        return std::nullopt;
-      return static_cast<uint64_t>(getOffset());
-    };
-
-    std::optional<uint64_t> getStaticStride() const {
-      if (isDynamic(getStride()))
-        return std::nullopt;
-      return static_cast<uint64_t>(getStride());
-    }
-
-    std::optional<uint64_t> getStaticSize() const {
-      if (isDynamic(getSize()))
-        return std::nullopt;
-      return static_cast<uint64_t>(getSize());
-    }
-
-    bool isCompletelyDynamic() const {
-      return isDynamic(getOffset()) && isDynamic(getStride()) && isDynamic(getSize());
-    };
+    static constexpr bool isDynamic(int64_t v) { return v == kDynamic; }
+    static std::optional<uint64_t> getStatic(int64_t v);
+    static std::string getStaticString(int64_t v);
+
+    std::optional<uint64_t> getStaticOffset() const;
+    std::optional<uint64_t> getStaticStride() const;
+    std::optional<uint64_t> getStaticSize() const;
+    bool isCompletelyDynamic() const;
   }];
 
   let genVerifyDecl = 1;
index 7f8dcba..490e35d 100644 (file)
@@ -32,6 +32,23 @@ using namespace mlir;
 using namespace mlir::sparse_tensor;
 
 //===----------------------------------------------------------------------===//
+// Additional convenience methods.
+//===----------------------------------------------------------------------===//
+
+static constexpr bool acceptBitWidth(unsigned bitWidth) {
+  switch (bitWidth) {
+  case 0:
+  case 8:
+  case 16:
+  case 32:
+  case 64:
+    return true;
+  default:
+    return false;
+  }
+}
+
+//===----------------------------------------------------------------------===//
 // StorageLayout
 //===----------------------------------------------------------------------===//
 
@@ -166,26 +183,39 @@ StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
 // TensorDialect Attribute Methods.
 //===----------------------------------------------------------------------===//
 
-static bool acceptBitWidth(unsigned bitWidth) {
-  switch (bitWidth) {
-  case 0:
-  case 8:
-  case 16:
-  case 32:
-  case 64:
-    return true;
-  default:
-    return false;
-  }
+std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) {
+  return isDynamic(v) ? std::nullopt
+                      : std::make_optional(static_cast<uint64_t>(v));
+}
+
+std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const {
+  return getStatic(getOffset());
+}
+
+std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const {
+  return getStatic(getStride());
+}
+
+std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const {
+  return getStatic(getSize());
+}
+
+bool SparseTensorDimSliceAttr::isCompletelyDynamic() const {
+  return isDynamic(getOffset()) && isDynamic(getStride()) &&
+         isDynamic(getSize());
+}
+
+std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) {
+  return isDynamic(v) ? "?" : std::to_string(v);
 }
 
 void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const {
   printer << "(";
-  printer << (getStaticOffset() ? std::to_string(*getStaticOffset()) : "?");
+  printer << getStaticString(getOffset());
   printer << ", ";
-  printer << (getStaticSize() ? std::to_string(*getStaticSize()) : "?");
+  printer << getStaticString(getSize());
   printer << ", ";
-  printer << (getStaticStride() ? std::to_string(*getStaticStride()) : "?");
+  printer << getStaticString(getStride());
   printer << ")";
 }
 
@@ -208,7 +238,7 @@ static ParseResult parseOptionalStaticSlice(int64_t &result,
 }
 
 Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) {
-  int64_t offset = -1, size = -1, stride = -1;
+  int64_t offset = kDynamic, size = kDynamic, stride = kDynamic;
 
   if (failed(parser.parseLParen()) ||
       failed(parseOptionalStaticSlice(offset, parser)) ||
@@ -226,13 +256,13 @@ Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) {
 LogicalResult
 SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                                  int64_t offset, int64_t size, int64_t stride) {
-  if ((offset == SparseTensorDimSliceAttr::kDynamic || offset >= 0) &&
-      (size == SparseTensorDimSliceAttr::kDynamic || size > 0) &&
-      (stride == SparseTensorDimSliceAttr::kDynamic || stride > 0)) {
-    return success();
-  }
-  return emitError()
-         << "expect positive value or ? for slice offset/size/stride";
+  if (!isDynamic(offset) && offset < 0)
+    return emitError() << "expect non-negative value or ? for slice offset";
+  if (!isDynamic(size) && size <= 0)
+    return emitError() << "expect positive value or ? for slice size";
+  if (!isDynamic(stride) && stride <= 0)
+    return emitError() << "expect positive value or ? for slice stride";
+  return success();
 }
 
 Type mlir::sparse_tensor::detail::getIntegerOrIndexType(MLIRContext *ctx,