let extraClassDeclaration = [{
/// Special value for dynamic offset/size/stride.
static constexpr int64_t kDynamic = -1;
-
- static bool isDynamic(int64_t v) {
- return v == kDynamic;
- }
-
- std::optional<uint64_t> getStaticOffset() const {
- if (isDynamic(getOffset()))
- return std::nullopt;
- return static_cast<uint64_t>(getOffset());
- };
-
- std::optional<uint64_t> getStaticStride() const {
- if (isDynamic(getStride()))
- return std::nullopt;
- return static_cast<uint64_t>(getStride());
- }
-
- std::optional<uint64_t> getStaticSize() const {
- if (isDynamic(getSize()))
- return std::nullopt;
- return static_cast<uint64_t>(getSize());
- }
-
- bool isCompletelyDynamic() const {
- return isDynamic(getOffset()) && isDynamic(getStride()) && isDynamic(getSize());
- };
+ static constexpr bool isDynamic(int64_t v) { return v == kDynamic; }
+ static std::optional<uint64_t> getStatic(int64_t v);
+ static std::string getStaticString(int64_t v);
+
+ std::optional<uint64_t> getStaticOffset() const;
+ std::optional<uint64_t> getStaticStride() const;
+ std::optional<uint64_t> getStaticSize() const;
+ bool isCompletelyDynamic() const;
}];
let genVerifyDecl = 1;
using namespace mlir::sparse_tensor;
//===----------------------------------------------------------------------===//
+// Additional convenience methods.
+//===----------------------------------------------------------------------===//
+
+static constexpr bool acceptBitWidth(unsigned bitWidth) {
+ switch (bitWidth) {
+ case 0:
+ case 8:
+ case 16:
+ case 32:
+ case 64:
+ return true;
+ default:
+ return false;
+ }
+}
+
+//===----------------------------------------------------------------------===//
// StorageLayout
//===----------------------------------------------------------------------===//
// TensorDialect Attribute Methods.
//===----------------------------------------------------------------------===//
-static bool acceptBitWidth(unsigned bitWidth) {
- switch (bitWidth) {
- case 0:
- case 8:
- case 16:
- case 32:
- case 64:
- return true;
- default:
- return false;
- }
+std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) {
+ return isDynamic(v) ? std::nullopt
+ : std::make_optional(static_cast<uint64_t>(v));
+}
+
+std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const {
+ return getStatic(getOffset());
+}
+
+std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const {
+ return getStatic(getStride());
+}
+
+std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const {
+ return getStatic(getSize());
+}
+
+bool SparseTensorDimSliceAttr::isCompletelyDynamic() const {
+ return isDynamic(getOffset()) && isDynamic(getStride()) &&
+ isDynamic(getSize());
+}
+
+std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) {
+ return isDynamic(v) ? "?" : std::to_string(v);
}
void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const {
printer << "(";
- printer << (getStaticOffset() ? std::to_string(*getStaticOffset()) : "?");
+ printer << getStaticString(getOffset());
printer << ", ";
- printer << (getStaticSize() ? std::to_string(*getStaticSize()) : "?");
+ printer << getStaticString(getSize());
printer << ", ";
- printer << (getStaticStride() ? std::to_string(*getStaticStride()) : "?");
+ printer << getStaticString(getStride());
printer << ")";
}
}
Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) {
- int64_t offset = -1, size = -1, stride = -1;
+ int64_t offset = kDynamic, size = kDynamic, stride = kDynamic;
if (failed(parser.parseLParen()) ||
failed(parseOptionalStaticSlice(offset, parser)) ||
LogicalResult
SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
int64_t offset, int64_t size, int64_t stride) {
- if ((offset == SparseTensorDimSliceAttr::kDynamic || offset >= 0) &&
- (size == SparseTensorDimSliceAttr::kDynamic || size > 0) &&
- (stride == SparseTensorDimSliceAttr::kDynamic || stride > 0)) {
- return success();
- }
- return emitError()
- << "expect positive value or ? for slice offset/size/stride";
+ if (!isDynamic(offset) && offset < 0)
+ return emitError() << "expect non-negative value or ? for slice offset";
+ if (!isDynamic(size) && size <= 0)
+ return emitError() << "expect positive value or ? for slice size";
+ if (!isDynamic(stride) && stride <= 0)
+ return emitError() << "expect positive value or ? for slice stride";
+ return success();
}
Type mlir::sparse_tensor::detail::getIntegerOrIndexType(MLIRContext *ctx,