From f708a549b87e163e144b41fe1887951c5639027d Mon Sep 17 00:00:00 2001 From: wren romano <2998727+wrengr@users.noreply.github.com> Date: Tue, 14 Feb 2023 18:20:45 -0800 Subject: [PATCH] [mlir][sparse] Factoring out SparseTensorType class This change adds a new `SparseTensorType` class for making the "dim" vs "lvl" distinction more overt, and for abstracting over the differences between sparse-tensors and dense-tensors. In addition, this change also adds new type aliases `Dimension`, `Level`, and `FieldIndex` to make code more self-documenting. Although the diff is very large, the majority of the changes are mechanical in nature (e.g., changing types to use the new aliases, updating variable names to match, etc). Along the way I also made many variables `const` when they could be; the majority of which required only adding the keyword. A few places had conditional definitions of these variables, requiring actual code changes; however, that was only done when the overall change was extremely local and easy to extract. All these changes are included in the current patch only because it would be too onerous to split them off into a separate patch. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D143800 --- .../mlir/Dialect/SparseTensor/IR/SparseTensor.h | 160 +++--- .../SparseTensor/IR/SparseTensorAttrDefs.td | 39 +- .../Dialect/SparseTensor/IR/SparseTensorType.h | 232 ++++++++ mlir/lib/CAPI/Dialect/SparseTensor.cpp | 6 +- .../SparseTensor/IR/SparseTensorDialect.cpp | 390 ++++++++------ .../Dialect/SparseTensor/Transforms/CodegenEnv.cpp | 5 +- .../SparseTensor/Transforms/CodegenUtils.cpp | 86 +-- .../Dialect/SparseTensor/Transforms/CodegenUtils.h | 11 +- .../SparseTensor/Transforms/LoopEmitter.cpp | 40 +- .../Transforms/SparseStorageSpecifierToLLVM.cpp | 4 +- .../Transforms/SparseTensorCodegen.cpp | 345 ++++++------ .../Transforms/SparseTensorConversion.cpp | 586 ++++++++++----------- .../Transforms/SparseTensorRewriting.cpp | 276 +++++----- .../Transforms/SparseTensorStorageLayout.cpp | 105 ++-- .../Transforms/SparseTensorStorageLayout.h | 216 ++++---- .../SparseTensor/Transforms/Sparsification.cpp | 133 +++-- mlir/test/Dialect/SparseTensor/invalid.mlir | 12 +- .../Dialect/SparseTensor/invalid_encoding.mlir | 6 +- utils/bazel/llvm-project-overlay/mlir/BUILD.bazel | 5 +- 19 files changed, 1503 insertions(+), 1154 deletions(-) create mode 100644 mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h index c010544..ad5501c 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h @@ -18,6 +18,49 @@ #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +//===----------------------------------------------------------------------===// +// +// Type aliases to help code be more self-documenting. Unfortunately +// these are not type-checked, so they only provide documentation rather +// than doing anything to prevent mixups. +// +// We must include these here (rather than in "SparseTensorType.h") +// because they are used by methods declared in the tablegen files. +// +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace sparse_tensor { + +/// The type of dimension identifiers, and dimension-ranks. 
We use the +/// same type for both identifiers and ranks because the latter are used +/// mainly for ordering-comparisons against the former (just like how the +/// one-past-the-end iterators are used). +using Dimension = uint64_t; + +/// The type of level identifiers, and level-ranks. We use the same +/// type for both identifiers and ranks because the latter are used +/// mainly for ordering-comparisons against the former (just like how +/// the one-past-the-end iterators are used). +using Level = uint64_t; + +/// The type for individual components of a compile-time shape. We avoid +/// calling this "size" because we use the term "sizes" to indicate the +/// actual run-time sizes, whereas this type also allows the value +/// `ShapedType::kDynamic`. +using DynSize = int64_t; + +/// The type for individual components of a compile-time shape which +/// are known not to be `ShapedType::kDynamic`. +using StaticSize = int64_t; + +} // namespace sparse_tensor +} // namespace mlir + +//===----------------------------------------------------------------------===// +// TableGen-defined classes +//===----------------------------------------------------------------------===// + // We must include Enums.h.inc before AttrDefs.h.inc due to dependency between // StorageSpecifierKindAttr and StorageSpeciferKind Enum. @@ -35,6 +78,10 @@ #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.h.inc" +//===----------------------------------------------------------------------===// +// Additional convenience methods. +//===----------------------------------------------------------------------===// + namespace mlir { namespace sparse_tensor { @@ -54,14 +101,14 @@ inline MemRefType getMemRefType(T t) { /// Returns null-attribute for any type without an encoding. SparseTensorEncodingAttr getSparseTensorEncoding(Type type); -/// Returns true iff the given type is a type for a COO tensor with the last -/// dimension level type being unique. +/// Returns true iff the given type is a COO type where the last level +/// is unique. bool isUniqueCOOType(TensorType tp); -/// Returns the starting dimension for a trailing COO region that spans across -/// at least two dimensions. If no such COO region is found, returns the rank -/// of the tensor. -unsigned getCOOStart(SparseTensorEncodingAttr enc); +/// Returns the starting level for a trailing COO region that spans +/// at least two levels. If no such COO region is found, then returns +/// the level-rank. +Level getCOOStart(SparseTensorEncodingAttr enc); /// Helpers to setup a COO type. RankedTensorType getCOOFromTypeWithOrdering(RankedTensorType src, @@ -70,87 +117,32 @@ RankedTensorType getCOOFromTypeWithOrdering(RankedTensorType src, RankedTensorType getCOOFromType(RankedTensorType src, bool ordered); // -// Dimension level types. -// - -// MSVC does not allow this function to be constexpr, because -// `SparseTensorEncodingAttr::operator bool` isn't declared constexpr. -// And therefore all functions calling it cannot be constexpr either. -// TODO: since Clang does allow these to be constexpr, perhaps we should -// define a macro to abstract over `inline` vs `constexpr` annotations. 
-inline DimLevelType getDimLevelType(SparseTensorEncodingAttr enc, uint64_t d) { - if (enc) { - auto types = enc.getDimLevelType(); - assert(d < types.size() && "Dimension out of bounds"); - return types[d]; - } - return DimLevelType::Dense; // unannotated tensor is dense -} - -inline DimLevelType getDimLevelType(RankedTensorType type, uint64_t d) { - return getDimLevelType(getSparseTensorEncoding(type), d); -} - -/// Convenience function to test for dense dimension (0 <= d < rank). -inline bool isDenseDim(RankedTensorType type, uint64_t d) { - return isDenseDLT(getDimLevelType(type, d)); -} - -/// Convenience function to test for compressed dimension (0 <= d < rank). -inline bool isCompressedDim(RankedTensorType type, uint64_t d) { - return isCompressedDLT(getDimLevelType(type, d)); -} - -/// Convenience function to test for singleton dimension (0 <= d < rank). -inline bool isSingletonDim(RankedTensorType type, uint64_t d) { - return isSingletonDLT(getDimLevelType(type, d)); -} - -/// Convenience function to test for dense dimension (0 <= d < rank). -inline bool isDenseDim(SparseTensorEncodingAttr enc, uint64_t d) { - return isDenseDLT(getDimLevelType(enc, d)); -} - -/// Convenience function to test for compressed dimension (0 <= d < rank). -inline bool isCompressedDim(SparseTensorEncodingAttr enc, uint64_t d) { - return isCompressedDLT(getDimLevelType(enc, d)); -} - -/// Convenience function to test for singleton dimension (0 <= d < rank). -inline bool isSingletonDim(SparseTensorEncodingAttr enc, uint64_t d) { - return isSingletonDLT(getDimLevelType(enc, d)); -} - -// -// Dimension level properties. -// - -/// Convenience function to test for ordered property in the -/// given dimension (0 <= d < rank). -inline bool isOrderedDim(RankedTensorType type, uint64_t d) { - return isOrderedDLT(getDimLevelType(type, d)); -} - -/// Convenience function to test for unique property in the -/// given dimension (0 <= d < rank). -inline bool isUniqueDim(RankedTensorType type, uint64_t d) { - return isUniqueDLT(getDimLevelType(type, d)); -} - -// // Reordering. // -uint64_t toOrigDim(SparseTensorEncodingAttr enc, uint64_t d); -uint64_t toStoredDim(SparseTensorEncodingAttr enc, uint64_t d); - -/// Convenience method to translate the given stored dimension -/// to the original dimension (0 <= d < rank). -uint64_t toOrigDim(RankedTensorType type, uint64_t d); - -/// Convenience method to translate the given original dimension -/// to the stored dimension (0 <= d < rank). -uint64_t toStoredDim(RankedTensorType type, uint64_t d); +// This CPP guard is to disable deprecation warnings for the LLVM +// build-bot, while making it easy to re-enable it for local development. +#if 0 +#define DEPRECATED \ + LLVM_DEPRECATED("The toOrigDim/toStoredDim functions are deprecated " \ + "because they only work for permutations; therefore any " \ + "code using them cannot support non-permutations.", \ + "") +#else +#define DEPRECATED +#endif + +/// [deprecated] Convenience method to translate the given level to the +/// corresponding dimension. Requires: `0 <= l < lvlRank`. +DEPRECATED Dimension toOrigDim(SparseTensorEncodingAttr enc, Level l); +DEPRECATED Dimension toOrigDim(RankedTensorType type, Level l); + +/// [deprecated] Convenience method to translate the given dimension to +/// the corresponding level. Requires: `0 <= d < dimRank`. 
+DEPRECATED Level toStoredDim(SparseTensorEncodingAttr enc, Dimension d); +DEPRECATED Level toStoredDim(RankedTensorType type, Dimension d); + +#undef DEPRECATED } // namespace sparse_tensor } // namespace mlir diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td index f6fc8fd..f25cd31 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td @@ -260,22 +260,45 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding", /// reset to the default/identity. SparseTensorEncodingAttr withoutOrdering() const; - /// Return true if every level is dense in the encoding. + /// Returns true if every level is dense. Also returns true for + /// the null encoding (since dense-tensors are always all-dense). bool isAllDense() const; - /// Return true if the encoding has an identity dimension ordering. + /// Returns true if every level is ordered. Also returns true for + /// the null encoding (since dense-tensors are always all-ordered). + bool isAllOrdered() const; + + /// Returns true if the encoding has an identity dimension ordering. + /// Also returns true for the null encoding (since dense-tensors + /// always have the identity ordering). bool hasIdDimOrdering() const; + /// Returns the number of storage levels. Asserts that the encoding + /// is non-null (since there is no fixed result that's valid for + /// every dense-tensor). + ::mlir::sparse_tensor::Level getLvlRank() const; + + /// Safely looks up the level-type for the requested level. (Returns + /// `DimLevelType::Dense` for the null encoding, since dense-tensors + /// are always all-dense.) + ::mlir::sparse_tensor::DimLevelType getLvlType(::mlir::sparse_tensor::Level l) const; + + bool isDenseLvl(::mlir::sparse_tensor::Level l) const { return isDenseDLT(getLvlType(l)); } + bool isCompressedLvl(::mlir::sparse_tensor::Level l) const { return isCompressedDLT(getLvlType(l)); } + bool isSingletonLvl(::mlir::sparse_tensor::Level l) const { return isSingletonDLT(getLvlType(l)); } + bool isOrderedLvl(::mlir::sparse_tensor::Level l) const { return isOrderedDLT(getLvlType(l)); } + bool isUniqueLvl(::mlir::sparse_tensor::Level l) const { return isUniqueDLT(getLvlType(l)); } + bool isSlice() const { return !getDimSlices().empty(); } - std::optional getStaticDimSliceOffset(unsigned dim) const; - std::optional getStaticDimSliceSize(unsigned dim) const; - std::optional getStaticDimSliceStride(unsigned dim) const; - std::optional getStaticLvlSliceOffset(unsigned lvl) const; - std::optional getStaticLvlSliceSize(unsigned lvl) const; - std::optional getStaticLvlSliceStride(unsigned lvl) const; + std::optional getStaticDimSliceOffset(::mlir::sparse_tensor::Dimension dim) const; + std::optional getStaticDimSliceSize(::mlir::sparse_tensor::Dimension dim) const; + std::optional getStaticDimSliceStride(::mlir::sparse_tensor::Dimension dim) const; + std::optional getStaticLvlSliceOffset(::mlir::sparse_tensor::Level lvl) const; + std::optional getStaticLvlSliceSize(::mlir::sparse_tensor::Level lvl) const; + std::optional getStaticLvlSliceStride(::mlir::sparse_tensor::Level lvl) const; }]; let genVerifyDecl = 1; diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h new file mode 100644 index 0000000..ba31fa7 --- /dev/null +++ 
b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h @@ -0,0 +1,232 @@ +//===- SparseTensorType.h - Wrapper around RankedTensorType -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header defines the `SparseTensorType` wrapper class. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORTYPE_H_ +#define MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORTYPE_H_ + +#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" + +namespace mlir { +namespace sparse_tensor { + +//===----------------------------------------------------------------------===// +/// A wrapper around `RankedTensorType`, which has three goals: +/// +/// (1) To provide a uniform API for querying aspects of sparse-tensor +/// types; in particular, to make the "dimension" vs "level" distinction +/// overt (i.e., explicit everywhere). Thus, throughout the sparse-compiler +/// this class should be preferred over using `RankedTensorType` or +/// `ShapedType` directly, since the methods of the latter do not make +/// the "dimension" vs "level" distinction overt. +/// +/// (2) To provide a uniform abstraction over both sparse-tensor +/// types (i.e., `RankedTensorType` with `SparseTensorEncodingAttr`) +/// and dense-tensor types (i.e., `RankedTensorType` without an encoding). +/// That is, we want to manipulate dense-tensor types using the same API +/// that we use for manipulating sparse-tensor types; both to keep the +/// "dimension" vs "level" distinction overt, and to avoid needing to +/// handle certain cases specially in the sparse-compiler. +/// +/// (3) To provide uniform handling of "defaults". In particular +/// this means that dense-tensors should always return the same answers +/// as sparse-tensors with a default encoding. But it additionally means +/// that the answers should be normalized, so that there's no way to +/// distinguish between non-provided data (which is filled in by default) +/// vs explicitly-provided data which equals the defaults. +/// +class SparseTensorType { +public: + // We memoize `lvlRank` and `dim2lvl` to avoid repeating the + // conditionals throughout the rest of the class. + SparseTensorType(RankedTensorType rtp) + : rtp(rtp), enc(getSparseTensorEncoding(rtp)), + lvlRank(enc ? enc.getLvlRank() : getDimRank()), + dim2lvl(enc.hasIdDimOrdering() ? AffineMap() : enc.getDimOrdering()) { + assert((!isIdentity() || getDimRank() == lvlRank) && "Rank mismatch"); + } + + SparseTensorType(ShapedType stp, SparseTensorEncodingAttr enc) + : SparseTensorType( + RankedTensorType::get(stp.getShape(), stp.getElementType(), enc)) {} + + /// Constructs a new `SparseTensorType` with the same dimension-shape + /// and element type, but with the encoding replaced by the given encoding. + SparseTensorType withEncoding(SparseTensorEncodingAttr newEnc) const { + return SparseTensorType(rtp, newEnc); + } + + /// Constructs a new `SparseTensorType` with the same dimension-shape + /// and element type, but with the encoding replaced by + /// `getEncoding().withoutOrdering()`. + SparseTensorType withoutOrdering() const { + return withEncoding(enc.withoutOrdering()); + } + + /// Allow implicit conversion to `RankedTensorType`, `ShapedType`, + /// and `Type`. 
These are implicit to help alleviate the impedance + /// mismatch for code that has not been converted to use `SparseTensorType` + /// directly. Once more of the sparse compiler has been converted to + /// using `SparseTensorType`, we may want to make these explicit instead. + /// + /// WARNING: This user-defined-conversion method causes overload + /// ambiguity whenever passing a `SparseTensorType` directly to a + /// function which is overloaded to accept either `Type` or `TypeRange`. + /// In particular, this includes `RewriterBase::replaceOpWithNewOp` + /// and `OpBuilder::create` whenever the `OpTy::build` is overloaded + /// thus. This happens because the `TypeRange(T&&)` ctor is implicit + /// as well, and there's no SFINAE we can add to this method that would + /// block subsequent application of that ctor. The only way to fix the + /// overload ambiguity is to avoid *implicit* conversion at the callsite: + /// e.g., by using `static_cast` to make the conversion explicit, by + /// assigning the `SparseTensorType` to a temporary variable of the + /// desired type, etc. + // + // NOTE: We implement this as a single templated user-defined-conversion + // function to avoid ambiguity problems when the desired result is `Type` + // (since both `RankedTensorType` and `ShapedType` can be implicitly + // converted to `Type`). + template >> + /*implicit*/ operator T() const { + return rtp; + } + + /// Explicitly convert to `RankedTensorType`. This method is + /// a convenience for resolving overload-ambiguity issues with + /// implicit conversion. + RankedTensorType getRankedTensorType() const { return rtp; } + + MLIRContext *getContext() const { return rtp.getContext(); } + + Type getElementType() const { return rtp.getElementType(); } + + /// Returns the encoding (or the null-attribute for dense-tensors). + SparseTensorEncodingAttr getEncoding() const { return enc; } + + /// Returns true for tensors which have an encoding, and false for + /// those which do not. Therefore tensors with an all-dense encoding + /// return true. + bool hasEncoding() const { return static_cast(enc); } + + /// Returns true for tensors where every level is dense. + /// (This is always true for dense-tensors.) + bool isAllDense() const { return enc.isAllDense(); } + + /// Returns true for tensors where every level is ordered. + /// (This is always true for dense-tensors.) + bool isAllOrdered() const { return enc.isAllOrdered(); } + + /// Returns true if the dimToLvl mapping is the identity. + /// (This is always true for dense-tensors.) + bool isIdentity() const { return !dim2lvl; } + + /// Returns the dimToLvl mapping (or the null-map for the identity). + AffineMap getDimToLvlMap() const { return dim2lvl; } + + /// Returns the dimToLvl mapping, where the identity map is expanded out + /// into a full `AffineMap`. This method is provided as a convenience, + /// but for most purposes other methods (`isIdentity`, `getDimToLvlMap`, + /// etc) will be more helpful. + AffineMap getExpandedDimToLvlMap() const { + return dim2lvl + ? dim2lvl + : AffineMap::getMultiDimIdentityMap(getDimRank(), getContext()); + } + + /// Returns the dimension-rank. + Dimension getDimRank() const { return rtp.getRank(); } + + /// Returns the level-rank. + Level getLvlRank() const { return lvlRank; } + + /// Returns the dimension-shape. + ArrayRef getDimShape() const { return rtp.getShape(); } + + /// Safely looks up the requested dimension-DynSize. 
If you intend + /// to check the result with `ShapedType::isDynamic`, then see the + /// `getStaticDimSize` method instead. + DynSize getDynamicDimSize(Dimension d) const { + assert(d < getDimRank() && "Dimension is out of bounds"); + return getDimShape()[d]; + } + + /// Safely looks up the requested dimension-size, mapping dynamic + /// sizes to `std::nullopt`. + std::optional getStaticDimSize(Dimension d) const { + const DynSize sh = getDynamicDimSize(d); + return ShapedType::isDynamic(sh) ? std::nullopt + : std::optional(sh); + } + + /// Returns true if no dimension has dynamic size. + bool hasStaticDimShape() const { return rtp.hasStaticShape(); } + + /// Returns true if any dimension has dynamic size. + bool hasDynamicDimShape() const { return !hasStaticDimShape(); } + + /// Returns true if the given dimension has dynamic size. If you + /// intend to call `getDynamicDimSize` based on the result, then see + /// the `getStaticDimSize` method instead. + bool isDynamicDim(Dimension d) const { + // We don't use `rtp.isDynamicDim(d)` because we want the + // OOB error message to be consistent with `getDynamicDimSize`. + return ShapedType::isDynamic(getDynamicDimSize(d)); + } + + /// Returns the number of dimensions which have dynamic sizes. + /// The return type is `int64_t` to maintain consistency with + /// `ShapedType::Trait::getNumDynamicDims`. + int64_t getNumDynamicDims() const { return rtp.getNumDynamicDims(); } + + DimLevelType getLvlType(Level l) const { + // This OOB check is for dense-tensors, since this class knows + // their lvlRank (whereas STEA::getLvlType will/can only check + // OOB for sparse-tensors). + assert(l < lvlRank && "Level out of bounds"); + return enc.getLvlType(l); + } + + // We can't just delegate these, since we want to use this class's + // `getLvlType` method instead of STEA's. + bool isDenseLvl(Level l) const { return isDenseDLT(getLvlType(l)); } + bool isCompressedLvl(Level l) const { return isCompressedDLT(getLvlType(l)); } + bool isSingletonLvl(Level l) const { return isSingletonDLT(getLvlType(l)); } + bool isOrderedLvl(Level l) const { return isOrderedDLT(getLvlType(l)); } + bool isUniqueLvl(Level l) const { return isUniqueDLT(getLvlType(l)); } + + /// Returns the index-overhead bitwidth, defaulting to zero. + unsigned getIndexBitWidth() const { return enc ? enc.getIndexBitWidth() : 0; } + + /// Returns the pointer-overhead bitwidth, defaulting to zero. + unsigned getPointerBitWidth() const { + return enc ? enc.getPointerBitWidth() : 0; + } + +private: + // These two must be const, to ensure coherence of the memoized fields. + const RankedTensorType rtp; + const SparseTensorEncodingAttr enc; + // Memoized to avoid frequent redundant conditionals. + const Level lvlRank; + const AffineMap dim2lvl; +}; + +/// Convenience method to abbreviate wrapping `getRankedTensorType`. 
+template +inline SparseTensorType getSparseTensorType(T t) { + return SparseTensorType(getRankedTensorType(t)); +} + +} // namespace sparse_tensor +} // namespace mlir + +#endif // MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORTYPE_H_ diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp index b667ad3..831cdd8 100644 --- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp +++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp @@ -70,13 +70,13 @@ mlirSparseTensorEncodingAttrGetHigherOrdering(MlirAttribute attr) { } intptr_t mlirSparseTensorEncodingGetNumDimLevelTypes(MlirAttribute attr) { - return unwrap(attr).cast().getDimLevelType().size(); + return unwrap(attr).cast().getLvlRank(); } MlirSparseTensorDimLevelType -mlirSparseTensorEncodingAttrGetDimLevelType(MlirAttribute attr, intptr_t pos) { +mlirSparseTensorEncodingAttrGetDimLevelType(MlirAttribute attr, intptr_t lvl) { return static_cast( - unwrap(attr).cast().getDimLevelType()[pos]); + unwrap(attr).cast().getLvlType(lvl)); } int mlirSparseTensorEncodingAttrGetPointerBitWidth(MlirAttribute attr) { diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index 506d933..32d998f 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -9,6 +9,7 @@ #include #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/Builders.h" @@ -33,8 +34,10 @@ using namespace mlir::sparse_tensor; // Additional convenience methods. //===----------------------------------------------------------------------===// +/// Gets the dimension-rank of the type of some `T`. (In particular +/// this is only used for `Value` and `TypedValue`.) 
template -static inline int64_t getTypeRank(T t) { +static inline Dimension getDimRank(T t) { return getRankedTensorType(t).getRank(); } @@ -132,40 +135,59 @@ SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutOrdering() const { } bool SparseTensorEncodingAttr::isAllDense() const { - return llvm::all_of(getDimLevelType(), isDenseDLT); + return !getImpl() || llvm::all_of(getDimLevelType(), isDenseDLT); +} + +bool SparseTensorEncodingAttr::isAllOrdered() const { + return !getImpl() || llvm::all_of(getDimLevelType(), isOrderedDLT); } bool SparseTensorEncodingAttr::hasIdDimOrdering() const { - return !getDimOrdering() || getDimOrdering().isIdentity(); + return !getImpl() || !getDimOrdering() || getDimOrdering().isIdentity(); +} + +Level SparseTensorEncodingAttr::getLvlRank() const { + assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); + return getDimLevelType().size(); +} + +DimLevelType SparseTensorEncodingAttr::getLvlType(Level l) const { + if (!getImpl()) + return DimLevelType::Dense; + assert(l < getLvlRank() && "Level is out of bounds"); + return getDimLevelType()[l]; } std::optional -SparseTensorEncodingAttr::getStaticDimSliceOffset(unsigned dim) const { +SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const { return getDimSlices()[dim].getStaticOffset(); } std::optional -SparseTensorEncodingAttr::getStaticDimSliceSize(unsigned dim) const { +SparseTensorEncodingAttr::getStaticDimSliceSize(Dimension dim) const { return getDimSlices()[dim].getStaticSize(); } std::optional -SparseTensorEncodingAttr::getStaticDimSliceStride(unsigned dim) const { +SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const { return getDimSlices()[dim].getStaticStride(); } std::optional -SparseTensorEncodingAttr::getStaticLvlSliceOffset(unsigned lvl) const { +SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const { + // FIXME: `toOrigDim` is deprecated. return getStaticDimSliceOffset(toOrigDim(*this, lvl)); } std::optional -SparseTensorEncodingAttr::getStaticLvlSliceSize(unsigned lvl) const { +SparseTensorEncodingAttr::getStaticLvlSliceSize(Level lvl) const { + // FIXME: `toOrigDim` is deprecated. return getStaticDimSliceSize(toOrigDim(*this, lvl)); } std::optional -SparseTensorEncodingAttr::getStaticLvlSliceStride(unsigned lvl) const { +SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const { + // FIXME: `toOrigDim` is deprecated. return getStaticDimSliceStride(toOrigDim(*this, lvl)); } @@ -296,11 +318,9 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) { void SparseTensorEncodingAttr::print(AsmPrinter &printer) const { // Print the struct-like storage in dictionary fashion. printer << "<{ dimLevelType = [ "; - for (unsigned i = 0, e = getDimLevelType().size(); i < e; i++) { - printer << "\"" << toMLIRString(getDimLevelType()[i]) << "\""; - if (i != e - 1) - printer << ", "; - } + llvm::interleaveComma(getDimLevelType(), printer, [&](DimLevelType dlt) { + printer << "\"" << toMLIRString(dlt) << "\""; + }); printer << " ]"; // Print remaining members only for non-default values. if (!hasIdDimOrdering()) @@ -334,11 +354,19 @@ LogicalResult SparseTensorEncodingAttr::verify( return emitError() << "unexpected pointer bitwidth: " << pointerBitWidth; if (!acceptBitWidth(indexBitWidth)) return emitError() << "unexpected index bitwidth: " << indexBitWidth; + // Before we can check that the level-rank is consistent/coherent + // across all fields, we need to define it. 
The source-of-truth for + // the `getLvlRank` method is the length of the level-types array, + // since it must always be provided and have full rank; therefore we + // use that same source-of-truth here. + const Level lvlRank = dimLevelType.size(); + if (lvlRank == 0) + return emitError() << "expected a non-empty array for level types"; if (dimOrdering) { if (!dimOrdering.isPermutation()) return emitError() << "expected a permutation affine map for dimension ordering"; - if (dimOrdering.getNumResults() != dimLevelType.size()) + if (dimOrdering.getNumResults() != lvlRank) return emitError() << "unexpected mismatch in ordering and dimension " "level types size"; } @@ -347,11 +375,11 @@ LogicalResult SparseTensorEncodingAttr::verify( return emitError() << "unexpected higher ordering mapping from " << higherOrdering.getNumDims() << " to " << higherOrdering.getNumResults(); - if (higherOrdering.getNumResults() != dimLevelType.size()) + if (higherOrdering.getNumResults() != lvlRank) return emitError() << "unexpected mismatch in higher ordering and " "dimension level types size"; } - if (!dimSlices.empty() && dimSlices.size() != dimLevelType.size()) { + if (!dimSlices.empty() && dimSlices.size() != lvlRank) { return emitError() << "unexpected mismatch in dimension slices and " "dimension level type size"; } @@ -364,32 +392,28 @@ LogicalResult SparseTensorEncodingAttr::verify( } LogicalResult SparseTensorEncodingAttr::verifyEncoding( - ArrayRef shape, Type elementType, + ArrayRef dimShape, Type elementType, function_ref emitError) const { - // Check structural integrity. + // Check structural integrity. In particular, this ensures that the + // level-rank is coherent across all the fields. RETURN_FAILURE_IF_FAILED(verify( emitError, getDimLevelType(), getDimOrdering(), getHigherOrdering(), getPointerBitWidth(), getIndexBitWidth(), getDimSlices())) - // Check integrity with tensor type specifics. Dimension ordering is optional, - // but we always should have dimension level types for the full rank. - unsigned size = shape.size(); - if (size == 0) + // Check integrity with tensor type specifics. In particular, we + // need only check that the dimension-rank of the tensor agrees with + // the dimension-rank of the encoding. + const Dimension dimRank = dimShape.size(); + if (dimRank == 0) return emitError() << "expected non-scalar sparse tensor"; - if (getHigherOrdering()) { - if (getHigherOrdering().getNumDims() != size) - return emitError() << "expected an affine map of size " << size - << " for higher ordering"; - + if (const auto higherOrdering = getHigherOrdering()) { + if (higherOrdering.getNumDims() != dimRank) + return emitError() << "expected an affine map with " << dimRank + << " dimensions for higher ordering"; // TODO: verification of higher ordering contents - - size = getHigherOrdering().getNumResults(); // higher-order size! - } - if (getDimOrdering() && getDimOrdering().getNumResults() != size) - return emitError() << "expected an affine map of size " << size - << " for dimension ordering"; - if (getDimLevelType().size() != size) - return emitError() << "expected an array of size " << size + } else if (dimRank != getLvlRank()) { + return emitError() << "expected an array of size " << dimRank << " for dimension level types"; + } return success(); } @@ -407,69 +431,72 @@ mlir::sparse_tensor::getSparseTensorEncoding(Type type) { } /// Returns true iff the given sparse tensor encoding attribute has a trailing -/// COO region starting at the given dimension. 
-static bool isCOOType(SparseTensorEncodingAttr enc, uint64_t s, bool isUnique) { - uint64_t rank = enc.getDimLevelType().size(); - assert(s < rank && "Dimension out of bounds"); - if (!isCompressedDim(enc, s)) +/// COO region starting at the given level. +static bool isCOOType(SparseTensorEncodingAttr enc, Level startLvl, + bool isUnique) { + if (!enc || !enc.isCompressedLvl(startLvl)) return false; - - for (uint64_t i = s + 1; i < rank; ++i) - if (!isSingletonDim(enc, i)) + const Level lvlRank = enc.getLvlRank(); + for (Level l = startLvl + 1; l < lvlRank; ++l) + if (!enc.isSingletonLvl(l)) return false; - - // If isUnique is true, then make sure that the last dimension level is - // unique, that is, rank == 1 (unique the only compressed) and rank > 1 + // If isUnique is true, then make sure that the last level is unique, + // that is, lvlRank == 1 (unique the only compressed) and lvlRank > 1 // (unique on the last singleton). - return !isUnique || isUniqueDLT(getDimLevelType(enc, rank - 1)); + return !isUnique || enc.isUniqueLvl(lvlRank - 1); } bool mlir::sparse_tensor::isUniqueCOOType(TensorType tp) { - SparseTensorEncodingAttr enc = getSparseTensorEncoding(tp); - return enc && isCOOType(enc, 0, /*isUnique=*/true); + return isCOOType(getSparseTensorEncoding(tp), 0, /*isUnique=*/true); } -unsigned mlir::sparse_tensor::getCOOStart(SparseTensorEncodingAttr enc) { - const unsigned rank = enc.getDimLevelType().size(); - // We only consider COO region with at least two dimensions for the purpose +Level mlir::sparse_tensor::getCOOStart(SparseTensorEncodingAttr enc) { + // We only consider COO region with at least two levels for the purpose // of AOS storage optimization. - if (rank > 1) - for (unsigned r = 0; r < rank - 1; r++) - if (isCOOType(enc, r, /*isUnique=*/false)) - return r; - - return rank; + const Level lvlRank = enc.getLvlRank(); + if (lvlRank > 1) + for (Level l = 0; l < lvlRank - 1; l++) + if (isCOOType(enc, l, /*isUnique=*/false)) + return l; + return lvlRank; } // Helpers to setup a COO type. -RankedTensorType sparse_tensor::getCOOFromTypeWithOrdering(RankedTensorType src, - AffineMap ordering, +RankedTensorType sparse_tensor::getCOOFromTypeWithOrdering(RankedTensorType rtt, + AffineMap lvlPerm, bool ordered) { - auto *ctx = src.getContext(); - auto rank = src.getRank(); - SmallVector dims; - - // An unordered and non-unique compressed dim at beginning. - // If this is also the last dimension, then it is unique. - dims.push_back(*getDimLevelType(LevelFormat::Compressed, ordered, rank == 1)); - if (rank > 1) { + const SparseTensorType src(rtt); + // The dim-rank of the source `RankedTensorType` is used as the lvl-rank + // of the result `RankedTensorType`. This follows from the fact that the + // result's encoding has the default higher-ordering (hence the result's + // lvl-rank equals its dim-rank). We don't need to assert that `lvlRank` + // agrees with the size of `lvlPerm` because that will be verified by + // `STEA::get`. + const Level lvlRank = src.getDimRank(); + SmallVector lvlTypes; + + // An unordered and non-unique compressed level at beginning. + // If this is also the last level, then it is unique. + lvlTypes.push_back( + *getDimLevelType(LevelFormat::Compressed, ordered, lvlRank == 1)); + if (lvlRank > 1) { // TODO: it is actually ordered at the level for ordered input. // Followed by unordered non-unique n-2 singleton levels. 
- std::fill_n(std::back_inserter(dims), rank - 2, + std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2, *getDimLevelType(LevelFormat::Singleton, ordered, false)); - // Ends by a unique singleton level unless the tensor rank is 1. - dims.push_back(*getDimLevelType(LevelFormat::Singleton, ordered, true)); + // Ends by a unique singleton level unless the lvlRank is 1. + lvlTypes.push_back(*getDimLevelType(LevelFormat::Singleton, ordered, true)); } - SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(src); // TODO: Maybe pick the bitwidth based on input/output tensors (probably the // largest one among them) in the original operation instead of using the // default value. - unsigned pointerBitWidth = encSrc ? encSrc.getPointerBitWidth() : 0; - unsigned indexBitWidth = encSrc ? encSrc.getIndexBitWidth() : 0; - auto enc = SparseTensorEncodingAttr::get(ctx, dims, ordering, AffineMap(), - pointerBitWidth, indexBitWidth); - return RankedTensorType::get(src.getShape(), src.getElementType(), enc); + unsigned pointerBitWidth = src.getPointerBitWidth(); + unsigned indexBitWidth = src.getIndexBitWidth(); + auto enc = SparseTensorEncodingAttr::get(src.getContext(), lvlTypes, lvlPerm, + AffineMap(), pointerBitWidth, + indexBitWidth); + return RankedTensorType::get(src.getDimShape(), src.getElementType(), enc); } RankedTensorType sparse_tensor::getCOOFromType(RankedTensorType src, @@ -479,20 +506,24 @@ RankedTensorType sparse_tensor::getCOOFromType(RankedTensorType src, ordered); } -uint64_t mlir::sparse_tensor::toOrigDim(SparseTensorEncodingAttr enc, - uint64_t d) { +// TODO: Remove this definition once all use-sites have been fixed to +// properly handle non-permutations. +Dimension mlir::sparse_tensor::toOrigDim(SparseTensorEncodingAttr enc, + Level l) { if (enc) { auto order = enc.getDimOrdering(); if (order) { assert(order.isPermutation()); - return order.getDimPosition(d); + return order.getDimPosition(l); } } - return d; + return l; } -uint64_t mlir::sparse_tensor::toStoredDim(SparseTensorEncodingAttr enc, - uint64_t d) { +// TODO: Remove this definition once all use-sites have been fixed to +// properly handle non-permutations. +Level mlir::sparse_tensor::toStoredDim(SparseTensorEncodingAttr enc, + Dimension d) { if (enc) { auto order = enc.getDimOrdering(); if (order) { @@ -506,13 +537,18 @@ uint64_t mlir::sparse_tensor::toStoredDim(SparseTensorEncodingAttr enc, return d; } -uint64_t mlir::sparse_tensor::toOrigDim(RankedTensorType type, uint64_t d) { - assert(d < static_cast(type.getRank())); - return toOrigDim(getSparseTensorEncoding(type), d); +// TODO: Remove this definition once all use-sites have been fixed to +// properly handle non-permutations. +Dimension mlir::sparse_tensor::toOrigDim(RankedTensorType type, Level l) { + const auto enc = getSparseTensorEncoding(type); + assert(l < enc.getLvlRank()); + return toOrigDim(enc, l); } -uint64_t mlir::sparse_tensor::toStoredDim(RankedTensorType type, uint64_t d) { - assert(d < static_cast(type.getRank())); +// TODO: Remove this definition once all use-sites have been fixed to +// properly handle non-permutations. 
+Level mlir::sparse_tensor::toStoredDim(RankedTensorType type, Dimension d) { + assert(d < static_cast(type.getRank())); return toStoredDim(getSparseTensorEncoding(type), d); } @@ -554,6 +590,8 @@ IntegerType StorageSpecifierType::getSizesType() const { return IntegerType::get(getContext(), std::max(idxBitWidth, ptrBitWidth)); } +// FIXME: see note [CLARIFY_DIM_LVL] in +// "lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h" Type StorageSpecifierType::getFieldType(StorageSpecifierKind kind, std::optional dim) const { if (kind != StorageSpecifierKind::ValMemSize) @@ -565,6 +603,8 @@ Type StorageSpecifierType::getFieldType(StorageSpecifierKind kind, return getSizesType(); } +// FIXME: see note [CLARIFY_DIM_LVL] in +// "lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h" Type StorageSpecifierType::getFieldType(StorageSpecifierKind kind, std::optional dim) const { return getFieldType(kind, dim ? std::optional(dim.value().getZExtValue()) @@ -575,8 +615,8 @@ Type StorageSpecifierType::getFieldType(StorageSpecifierKind kind, // SparseTensorDialect Operations. //===----------------------------------------------------------------------===// -static LogicalResult isInBounds(uint64_t dim, Value tensor) { - return success(dim < static_cast(getTypeRank(tensor))); +static LogicalResult dimIsInBounds(Dimension dim, Value tensor) { + return success(dim < getDimRank(tensor)); } static LogicalResult isMatchingWidth(Value result, unsigned width) { @@ -585,26 +625,25 @@ static LogicalResult isMatchingWidth(Value result, unsigned width) { } static LogicalResult verifySparsifierGetterSetter( - StorageSpecifierKind mdKind, std::optional dim, + StorageSpecifierKind mdKind, std::optional lvl, TypedValue md, Operation *op) { - if (mdKind == StorageSpecifierKind::ValMemSize && dim) { + if (mdKind == StorageSpecifierKind::ValMemSize && lvl) { return op->emitError( - "redundant dimension argument for querying value memory size"); + "redundant level argument for querying value memory size"); } - auto enc = md.getType().getEncoding(); - ArrayRef dlts = enc.getDimLevelType(); - unsigned rank = dlts.size(); + const auto enc = md.getType().getEncoding(); + const Level lvlRank = enc.getLvlRank(); if (mdKind != StorageSpecifierKind::ValMemSize) { - if (!dim) - return op->emitError("missing dimension argument"); + if (!lvl) + return op->emitError("missing level argument"); - unsigned d = dim.value().getZExtValue(); - if (d >= rank) - return op->emitError("requested dimension out of bound"); + const Level l = lvl.value().getZExtValue(); + if (l >= lvlRank) + return op->emitError("requested level out of bound"); - if (mdKind == StorageSpecifierKind::PtrMemSize && isSingletonDLT(dlts[d])) + if (mdKind == StorageSpecifierKind::PtrMemSize && enc.isSingletonLvl(l)) return op->emitError( "requested pointer memory size on a singleton level"); } @@ -612,7 +651,7 @@ static LogicalResult verifySparsifierGetterSetter( } LogicalResult NewOp::verify() { - if (getExpandSymmetry() && getTypeRank(getResult()) != 2) + if (getExpandSymmetry() && getDimRank(getResult()) != 2) return emitOpError("expand_symmetry can only be used for 2D tensors"); return success(); } @@ -670,7 +709,7 @@ LogicalResult ConvertOp::verify() { // Accept size matches between the source and the destination type // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10). 
- for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++) + for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++) if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic) return emitError("unexpected conversion mismatch in dimension ") << d; return success(); @@ -692,7 +731,8 @@ OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) { LogicalResult ToPointersOp::verify() { auto e = getSparseTensorEncoding(getTensor().getType()); - if (failed(isInBounds(getDimension().getZExtValue(), getTensor()))) + // FIXME: there seems to be some dim/lvl confusion here. + if (failed(dimIsInBounds(getDimension().getZExtValue(), getTensor()))) return emitError("requested pointers dimension out of bounds"); if (failed(isMatchingWidth(getResult(), e.getPointerBitWidth()))) return emitError("unexpected type for pointers"); @@ -701,7 +741,8 @@ LogicalResult ToPointersOp::verify() { LogicalResult ToIndicesOp::verify() { auto e = getSparseTensorEncoding(getTensor().getType()); - if (failed(isInBounds(getDimension().getZExtValue(), getTensor()))) + // FIXME: there seems to be some dim/lvl confusion here. + if (failed(dimIsInBounds(getDimension().getZExtValue(), getTensor()))) return emitError("requested indices dimension out of bounds"); if (failed(isMatchingWidth(getResult(), e.getIndexBitWidth()))) return emitError("unexpected type for indices"); @@ -710,7 +751,7 @@ LogicalResult ToIndicesOp::verify() { LogicalResult ToIndicesBufferOp::verify() { auto e = getSparseTensorEncoding(getTensor().getType()); - if (getCOOStart(e) >= e.getDimLevelType().size()) + if (getCOOStart(e) >= e.getLvlRank()) return emitError("expected sparse tensor with a COO region"); return success(); } @@ -846,58 +887,56 @@ LogicalResult UnaryOp::verify() { } LogicalResult ConcatenateOp::verify() { - auto dstTp = getRankedTensorType(*this); - uint64_t concatDim = getDimension().getZExtValue(); - unsigned rank = dstTp.getRank(); + const auto dstTp = getSparseTensorType(*this); + const Dimension concatDim = getDimension().getZExtValue(); + const Dimension dimRank = dstTp.getDimRank(); if (getInputs().size() <= 1) return emitError("Need at least two tensors to concatenate."); - for (auto type : getInputs().getTypes()) { - auto shape = type.cast().getShape(); - for (auto dim : shape) { - if (ShapedType::isDynamic(dim)) - return emitError("Only statically-sized input tensors are supported."); - } - } - - if (concatDim >= rank) + if (concatDim >= dimRank) return emitError(llvm::formatv( - "Failed to concatentate tensors with rank={0} on dimension={1}.", rank, - concatDim)); - - for (size_t i = 0, e = getInputs().size(); i < e; i++) { - const auto inputRank = getTypeRank(getInputs()[i]); - if (inputRank != rank) + "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})", + concatDim, dimRank)); + + for (const auto &it : llvm::enumerate(getInputs())) { + const auto i = it.index(); + const auto srcTp = getSparseTensorType(it.value()); + if (srcTp.hasDynamicDimShape()) + return emitError(llvm::formatv("Input tensor ${0} has dynamic shape", i)); + const Dimension srcDimRank = srcTp.getDimRank(); + if (srcDimRank != dimRank) return emitError( - llvm::formatv("The input tensor ${0} has a different rank (rank={1}) " + llvm::formatv("Input tensor ${0} has a different rank (rank={1}) " "from the output tensor (rank={2}).", - i, inputRank, rank)); + i, srcDimRank, dimRank)); } - for (unsigned i = 0; i < rank; i++) { - const auto dstDim = dstTp.getShape()[i]; - if (i == concatDim) { - if (!ShapedType::isDynamic(dstDim)) { - // If 
we reach here, all inputs should have static shapes. - unsigned sumDim = 0; - for (auto src : getInputs()) - sumDim += getRankedTensorType(src).getShape()[i]; + for (Dimension d = 0; d < dimRank; d++) { + const DynSize dstSh = dstTp.getDimShape()[d]; + if (d == concatDim) { + if (!ShapedType::isDynamic(dstSh)) { + // If we reach here, then all inputs have static shapes. So we + // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)` + // to avoid redundant assertions in the loop. + StaticSize sumSz = 0; + for (const auto src : getInputs()) + sumSz += getSparseTensorType(src).getDimShape()[d]; // If all dimension are statically known, the sum of all the input // dimensions should be equal to the output dimension. - if (sumDim != dstDim) + if (sumSz != dstSh) return emitError( "The concatenation dimension of the output tensor should be the " "sum of all the concatenation dimensions of the input tensors."); } } else { - int64_t prev = dstDim; - for (auto src : getInputs()) { - const auto d = getRankedTensorType(src).getShape()[i]; - if (!ShapedType::isDynamic(prev) && d != prev) + DynSize prev = dstSh; + for (const auto src : getInputs()) { + const auto sh = getSparseTensorType(src).getDimShape()[d]; + if (!ShapedType::isDynamic(prev) && sh != prev) return emitError("All dimensions (expect for the concatenating one) " "should be equal."); - prev = d; + prev = sh; } } } @@ -906,7 +945,7 @@ LogicalResult ConcatenateOp::verify() { } LogicalResult InsertOp::verify() { - if (getTypeRank(getTensor()) != static_cast(getIndices().size())) + if (getDimRank(getTensor()) != static_cast(getIndices().size())) return emitOpError("incorrect number of indices"); return success(); } @@ -926,7 +965,8 @@ LogicalResult PushBackOp::verify() { } LogicalResult CompressOp::verify() { - if (getTypeRank(getTensor()) != 1 + static_cast(getIndices().size())) + if (getDimRank(getTensor()) != + 1 + static_cast(getIndices().size())) return emitOpError("incorrect number of indices"); return success(); } @@ -947,37 +987,34 @@ void ForeachOp::build( // Builds foreach body. if (!bodyBuilder) return; - auto rtp = getRankedTensorType(tensor); - int64_t rank = rtp.getRank(); + const auto stt = getSparseTensorType(tensor); + const Dimension dimRank = stt.getDimRank(); - SmallVector blockArgTypes; - // Starts with n index. - std::fill_n(std::back_inserter(blockArgTypes), rank, builder.getIndexType()); + // Starts with `dimRank`-many indices. + SmallVector blockArgTypes(dimRank, builder.getIndexType()); // Followed by one value. - blockArgTypes.push_back(rtp.getElementType()); - // Followed by reduction variable. + blockArgTypes.push_back(stt.getElementType()); + // Followed by the reduction variables. 
blockArgTypes.append(initArgs.getTypes().begin(), initArgs.getTypes().end()); - SmallVector blockArgLocs; - std::fill_n(std::back_inserter(blockArgLocs), blockArgTypes.size(), - tensor.getLoc()); + SmallVector blockArgLocs(blockArgTypes.size(), tensor.getLoc()); OpBuilder::InsertionGuard guard(builder); auto ®ion = *result.regions.front(); Block *bodyBlock = builder.createBlock(®ion, region.end(), blockArgTypes, blockArgLocs); bodyBuilder(builder, result.location, - bodyBlock->getArguments().slice(0, rank), - bodyBlock->getArguments()[rank], - bodyBlock->getArguments().drop_front(rank + 1)); + bodyBlock->getArguments().slice(0, dimRank), + bodyBlock->getArguments()[dimRank], + bodyBlock->getArguments().drop_front(dimRank + 1)); } LogicalResult ForeachOp::verify() { - auto t = getRankedTensorType(getTensor()); - auto args = getBody()->getArguments(); + const auto t = getSparseTensorType(getTensor()); + const Dimension dimRank = t.getDimRank(); + const auto args = getBody()->getArguments(); - if (static_cast(t.getRank()) + 1 + getInitArgs().size() != - args.size()) + if (static_cast(dimRank) + 1 + getInitArgs().size() != args.size()) return emitError("Unmatched number of arguments in the block"); if (getNumResults() != getInitArgs().size()) @@ -986,18 +1023,20 @@ LogicalResult ForeachOp::verify() { if (getResultTypes() != getInitArgs().getTypes()) return emitError("Mismatch in types of init arguments and results"); + // Cannot mark this const, because the getters aren't. auto yield = cast(getBody()->getTerminator()); if (yield.getNumOperands() != getNumResults() || yield.getOperands().getTypes() != getResultTypes()) return emitError("Mismatch in types of yield values and results"); - for (int64_t i = 0, e = t.getRank(); i < e; i++) - if (args[i].getType() != IndexType::get(getContext())) + const auto iTp = IndexType::get(getContext()); + for (Dimension d = 0; d < dimRank; d++) + if (args[d].getType() != iTp) emitError( - llvm::formatv("Expecting Index type for argument at index {0}", i)); + llvm::formatv("Expecting Index type for argument at index {0}", d)); - auto elemTp = t.getElementType(); - auto valueTp = args[t.getRank()].getType(); + const auto elemTp = t.getElementType(); + const auto valueTp = args[dimRank].getType(); if (elemTp != valueTp) emitError(llvm::formatv("Unmatched element type between input tensor and " "block argument, expected:{0}, got: {1}", @@ -1036,13 +1075,13 @@ LogicalResult SortOp::verify() { bool checkEleType = true) -> LogicalResult { for (Value opnd : operands) { auto mtp = getMemRefType(opnd); - int64_t dim = mtp.getShape()[0]; + const DynSize sh = mtp.getShape()[0]; // We can't check the size of dynamic dimension at compile-time, but all // xs and ys should have a dimension not less than n at runtime. 
- if (n && !ShapedType::isDynamic(dim) && dim < n.value()) + if (n && !ShapedType::isDynamic(sh) && sh < n.value()) return emitError(llvm::formatv("xs and ys need to have a dimension >= n" ": {0} < {1}", - dim, n.value())); + sh, n.value())); if (checkEleType && xtp != mtp.getElementType()) return emitError("mismatch xs element types"); @@ -1072,12 +1111,13 @@ LogicalResult SortCooOp::verify() { ny = nyAttr.getInt(); } - auto checkDim = [&](Value v, uint64_t min, const char *message) { - auto tp = getMemRefType(v); - int64_t dim = tp.getShape()[0]; - if (!ShapedType::isDynamic(dim) && dim < (int64_t)min) { - emitError(llvm::formatv("{0} got {1} < {2}", message, dim, min)); - } + // FIXME: update the types of variables used in expressions bassed as + // the `minSize` argument, to avoid implicit casting at the callsites + // of this lambda. + const auto checkDim = [&](Value v, StaticSize minSize, const char *message) { + const DynSize sh = getMemRefType(v).getShape()[0]; + if (!ShapedType::isDynamic(sh) && sh < minSize) + emitError(llvm::formatv("{0} got {1} < {2}", message, sh, minSize)); }; checkDim(getXy(), n * (nx + ny), "Expected dimension(xy) >= n * (nx + ny)"); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp index a8ebad3..62acac2 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp @@ -10,7 +10,9 @@ #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" + #include using namespace mlir; @@ -114,10 +116,9 @@ bool CodegenEnv::isAdmissibleTensorExp(unsigned exp) { OpOperand *lhs = linalgOp.getDpsInitOperand(0); unsigned tensor = lhs->getOperandNumber(); - auto enc = getSparseTensorEncoding(lhs->get().getType()); // An non-annotated output tensor is assumed dense, and becomes a random // access n-dim memref. Admissible since insertions cannot occur. - if (!enc || enc.isAllDense()) + if (getSparseTensorType(lhs->get()).isAllDense()) return true; // A tensor expression with a sparse output tensor that changes its values diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp index cf2f127..97ec7e2 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp @@ -236,7 +236,7 @@ Value mlir::sparse_tensor::genIsNonzero(OpBuilder &builder, mlir::Location loc, void mlir::sparse_tensor::genReshapeDstShape( Location loc, PatternRewriter &rewriter, SmallVectorImpl &dstShape, - ArrayRef srcShape, ArrayRef staticDstShape, + ArrayRef srcShape, ArrayRef staticDstShape, ArrayRef reassociation) { // Collapse shape. if (reassociation.size() < srcShape.size()) { @@ -269,7 +269,7 @@ void mlir::sparse_tensor::genReshapeDstShape( if (staticDstShape[j] == ShapedType::kDynamic) { // The expanded dimension has dynamic size. We compute the dimension // by dividing srcDim by the product of the static dimensions. 
- int64_t product = 1; + StaticSize product = 1; for (unsigned k = start; k < start + map.size(); k++) { if (staticDstShape[k] != ShapedType::kDynamic) { product *= staticDstShape[k]; @@ -483,9 +483,9 @@ void mlir::sparse_tensor::genDenseTensorOrSparseConstantIterLoop( void mlir::sparse_tensor::sizesFromSrc(OpBuilder &builder, SmallVectorImpl &sizes, Location loc, Value src) { - unsigned rank = src.getType().cast().getRank(); - for (unsigned i = 0; i < rank; i++) - sizes.push_back(linalg::createOrFoldDimOp(builder, loc, src, i)); + const Dimension dimRank = getSparseTensorType(src).getDimRank(); + for (Dimension d = 0; d < dimRank; d++) + sizes.push_back(linalg::createOrFoldDimOp(builder, loc, src, d)); } Operation *mlir::sparse_tensor::getTop(Operation *op) { @@ -532,9 +532,20 @@ void sparse_tensor::foreachInSparseConstant( } void sparse_tensor::storeIndices(OpBuilder &builder, Location loc, - unsigned rank, Value ind, ValueRange ivs, + unsigned size, Value ind, ValueRange ivs, unsigned offsetDim, Value offset) { - for (unsigned i = 0; i < rank; i++) { +#ifndef NDEBUG + const auto memTp = ind.getType().cast(); + (void)memTp; + assert(memTp.getRank() == 1); + const DynSize memSh = memTp.getDimSize(0); + (void)memSh; + assert(ShapedType::isDynamic(memSh) || memSh == static_cast(size)); + assert(ivs.size() == static_cast(size)); + assert(offsetDim < size); +#endif // NDEBUG + + for (unsigned i = 0; i < size; i++) { Value idx = ivs[i]; if (offsetDim == i && offset) idx = builder.create(loc, idx, offset); @@ -543,44 +554,47 @@ void sparse_tensor::storeIndices(OpBuilder &builder, Location loc, } } -Value sparse_tensor::reshapeValuesToLevels( - OpBuilder &builder, Location loc, SparseTensorEncodingAttr enc, - const SmallVectorImpl &dimSizes, Value valuesBuffer, - Value idxBuffer) { - // Use the dstIdx to store the level sizes. - unsigned rank = enc.getDimLevelType().size(); +Value sparse_tensor::reshapeValuesToLevels(OpBuilder &builder, Location loc, + SparseTensorEncodingAttr enc, + ValueRange dimSizes, + Value valuesBuffer, + Value idxBuffer) { + // Use the `idxBuffer` to store the level sizes. + const Level lvlRank = enc.getLvlRank(); SmallVector lvlSizes; - for (unsigned i = 0; i < dimSizes.size(); i++) - lvlSizes.push_back(dimSizes[toOrigDim(enc, i)]); - storeIndices(builder, loc, rank, idxBuffer, lvlSizes); + lvlSizes.reserve(lvlRank); + for (Level l = 0; l < lvlRank; l++) + // FIXME: `toOrigDim` is deprecated. + lvlSizes.push_back(dimSizes[toOrigDim(enc, l)]); + storeIndices(builder, loc, lvlRank, idxBuffer, lvlSizes); // The memref ReshapeOp requires the sizes buffer to have a static // shape. 
- idxBuffer = builder.create( - loc, MemRefType::get({rank}, builder.getIndexType()), idxBuffer); - SmallVector shape(rank, ShapedType::kDynamic); - Type elemTp = getMemRefType(valuesBuffer).getElementType(); - return builder.create(loc, MemRefType::get(shape, elemTp), - valuesBuffer, idxBuffer); + const auto iTp = builder.getIndexType(); + const SmallVector idxBufferShape{static_cast(lvlRank)}; + const auto idxBufferTp = MemRefType::get(idxBufferShape, iTp); + idxBuffer = builder.create(loc, idxBufferTp, idxBuffer); + const SmallVector resShape(lvlRank, ShapedType::kDynamic); + const Type elemTp = getMemRefType(valuesBuffer).getElementType(); + const auto resTp = MemRefType::get(resShape, elemTp); + return builder.create(loc, resTp, valuesBuffer, idxBuffer); } Value sparse_tensor::genToPointers(OpBuilder &builder, Location loc, - Value tensor, uint64_t d) { - RankedTensorType srcTp = getRankedTensorType(tensor); - SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(srcTp); - Type ptrTp = get1DMemRefType(getPointerOverheadType(builder, encSrc), - /*withLayout=*/false); - return builder.create(loc, ptrTp, tensor, - builder.getIndexAttr(d)); + Value tensor, Level lvl) { + const auto srcTp = getSparseTensorType(tensor); + const Type ptrTp = getPointerOverheadType(builder, srcTp.getEncoding()); + const Type memTp = get1DMemRefType(ptrTp, /*withLayout=*/false); + return builder.create(loc, memTp, tensor, + builder.getIndexAttr(lvl)); } Value sparse_tensor::genToIndices(OpBuilder &builder, Location loc, - Value tensor, uint64_t d, uint64_t cooStart) { - RankedTensorType srcTp = getRankedTensorType(tensor); - SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(srcTp); - Type indTp = get1DMemRefType(getIndexOverheadType(builder, encSrc), - /*withLayout=*/d >= cooStart); - return builder.create(loc, indTp, tensor, - builder.getIndexAttr(d)); + Value tensor, Level lvl, Level cooStart) { + const auto srcTp = getSparseTensorType(tensor); + const Type idxTp = getIndexOverheadType(builder, srcTp.getEncoding()); + const Type memTp = get1DMemRefType(idxTp, /*withLayout=*/lvl >= cooStart); + return builder.create(loc, memTp, tensor, + builder.getIndexAttr(lvl)); } Value sparse_tensor::genToValues(OpBuilder &builder, Location loc, diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h index 12cfd3b..9e98c9d 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h @@ -228,8 +228,7 @@ void storeIndices(OpBuilder &builder, Location loc, unsigned rank, Value ind, /// to match the shape of the corresponding dense tensor to support direct /// access of the buffer through indices. Value reshapeValuesToLevels(OpBuilder &builder, Location loc, - SparseTensorEncodingAttr enc, - const SmallVectorImpl &dimSizes, + SparseTensorEncodingAttr enc, ValueRange dimSizes, Value valuesBuffer, Value idxBuffer); //===----------------------------------------------------------------------===// @@ -345,13 +344,13 @@ inline bool isZeroRankedTensorOrScalar(Type type) { } /// Infers the result type and generates ToPointersOp. -Value genToPointers(OpBuilder &builder, Location loc, Value tensor, uint64_t d); +Value genToPointers(OpBuilder &builder, Location loc, Value tensor, Level lvl); -/// Infers the result type and generates ToIndicesOp. If the dim is within a COO +/// Infers the result type and generates ToIndicesOp. 
If the lvl is within a COO /// region, the result type is a memref with unknown stride and offset. /// Otherwise, the result type is a memref without any specified layout. -Value genToIndices(OpBuilder &builder, Location loc, Value tensor, uint64_t d, - uint64_t cooStart); +Value genToIndices(OpBuilder &builder, Location loc, Value tensor, Level lvl, + Level cooStart); /// Infers the result type and generates ToValuesOp. Value genToValues(OpBuilder &builder, Location loc, Value tensor); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp index e9724bf..7f94b3a 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp @@ -166,37 +166,41 @@ void LoopEmitter::initializeLoopEmit(OpBuilder &builder, Location loc, // For every tensor, find lower and upper bound on dimensions, set the // same bounds on loop indices, and obtain dense or sparse buffer(s). for (size_t t = 0, e = tensors.size(); t < e; t++) { - auto tensor = tensors[t]; - auto rtp = tensor.getType().dyn_cast(); + const auto tensor = tensors[t]; + const auto rtp = tensor.getType().dyn_cast(); if (!rtp) // Skips only scalar, zero ranked tensor still need to be bufferized and // (probably) filled with zeros by users. continue; - auto rank = rtp.getRank(); - auto shape = rtp.getShape(); - auto enc = getSparseTensorEncoding(rtp); - uint64_t cooStart = enc ? getCOOStart(enc) : rank; - // Scan all dimensions of current tensor. - for (int64_t d = 0; d < rank; d++) { + // FIXME: the definition of `lvlRank` looks more like a dim-rank; + // but the variable is used as a level everywhere below, which + // suggests there may be some dim/lvl confusion going on here. + const Level lvlRank = rtp.getRank(); + const auto shape = rtp.getShape(); + const auto enc = getSparseTensorEncoding(rtp); + const Level cooStart = enc ? getCOOStart(enc) : lvlRank; + // Scan all levels of current tensor. + for (Level l = 0; l < lvlRank; l++) { // This should be called only once at beginning. - assert(!ptrBuffer[t][d] && !idxBuffer[t][d] && !highs[t][d]); + assert(!ptrBuffer[t][l] && !idxBuffer[t][l] && !highs[t][l]); + const auto dlt = dimTypes[t][l]; // Handle sparse storage schemes. - if (isCompressedDLT(dimTypes[t][d])) { + if (isCompressedDLT(dlt)) { // Generate sparse primitives to obtains pointer and indices. - ptrBuffer[t][d] = genToPointers(builder, loc, tensor, d); - idxBuffer[t][d] = genToIndices(builder, loc, tensor, d, cooStart); - } else if (isSingletonDLT(dimTypes[t][d])) { + ptrBuffer[t][l] = genToPointers(builder, loc, tensor, l); + idxBuffer[t][l] = genToIndices(builder, loc, tensor, l, cooStart); + } else if (isSingletonDLT(dlt)) { // Singleton dimension, fetch indices. - idxBuffer[t][d] = genToIndices(builder, loc, tensor, d, cooStart); + idxBuffer[t][l] = genToIndices(builder, loc, tensor, l, cooStart); } else { // Dense dimension, nothing to fetch. - assert(isDenseDLT(dimTypes[t][d])); + assert(isDenseDLT(dlt)); } // Find upper bound in current dimension. - unsigned p = toOrigDim(enc, d); - Value up = mlir::linalg::createOrFoldDimOp(builder, loc, tensor, p); - highs[t][d] = up; + // FIXME: `toOrigDim` is deprecated + const Dimension d = toOrigDim(enc, l); + highs[t][l] = mlir::linalg::createOrFoldDimOp(builder, loc, tensor, d); } // Perform the required bufferization. 
Dense inputs materialize diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp index 791f7f2..d634a1d 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp @@ -24,11 +24,11 @@ namespace { static SmallVector getSpecifierFields(StorageSpecifierType tp) { MLIRContext *ctx = tp.getContext(); auto enc = tp.getEncoding(); - unsigned rank = enc.getDimLevelType().size(); + const Level lvlRank = enc.getLvlRank(); SmallVector result; auto indexType = tp.getSizesType(); - auto dimSizes = LLVM::LLVMArrayType::get(ctx, indexType, rank); + auto dimSizes = LLVM::LLVMArrayType::get(ctx, indexType, lvlRank); auto memSizes = LLVM::LLVMArrayType::get(ctx, indexType, getNumDataFieldsFromEncoding(enc)); result.push_back(dimSizes); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp index 797a318..ceee541 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -24,9 +24,12 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SparseTensor/IR/Enums.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/Support/FormatVariadic.h" + #include using namespace mlir; @@ -104,74 +107,77 @@ static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper, /// Gets the dimension size for the given sparse tensor at the given /// original dimension 'dim'. static Value sizeFromTensorAtDim(OpBuilder &builder, Location loc, - SparseTensorDescriptor desc, unsigned dim) { - RankedTensorType rtp = desc.getTensorType(); + SparseTensorDescriptor desc, Dimension dim) { + const SparseTensorType stt(desc.getRankedTensorType()); // Access into static dimension can query original type directly. // Note that this is typically already done by DimOp's folding. - auto shape = rtp.getShape(); - if (!ShapedType::isDynamic(shape[dim])) - return constantIndex(builder, loc, shape[dim]); + if (auto sz = stt.getStaticDimSize(dim)) + return constantIndex(builder, loc, *sz); // Any other query can consult the dimSizes array at field DimSizesIdx, // accounting for the reordering applied to the sparse storage. - return desc.getDimSize(builder, loc, toStoredDim(rtp, dim)); + // FIXME: `toStoredDim` is deprecated. + const Level lvl = toStoredDim(stt, dim); + // FIXME: this method seems to get *level* sizes, but the name is confusing + return desc.getDimSize(builder, loc, lvl); } // Gets the dimension size at the given stored level 'lvl', either as a // constant for a static size, or otherwise dynamically through memSizes. static Value sizeFromTensorAtLvl(OpBuilder &builder, Location loc, - SparseTensorDescriptor desc, unsigned lvl) { + SparseTensorDescriptor desc, Level lvl) { + // FIXME: `toOrigDim` is deprecated. 
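// A simplified standalone model of the dim/lvl translation referenced by the
// `toOrigDim`/`toStoredDim` FIXMEs above, valid only for permutation encodings:
// `lvlToDim[l]` names the dimension stored at level l (what toOrigDim yields),
// and the inverse lookup recovers the level storing a given dimension (what
// toStoredDim yields). Names are illustrative; the real helpers operate on the
// encoding attribute.
#include <cassert>
#include <cstdint>
#include <vector>

uint64_t toOrigDimModel(const std::vector<uint64_t> &lvlToDim, uint64_t lvl) {
  return lvlToDim[lvl]; // dimension stored at this level
}

uint64_t toStoredDimModel(const std::vector<uint64_t> &lvlToDim, uint64_t dim) {
  for (uint64_t l = 0, e = lvlToDim.size(); l < e; ++l)
    if (lvlToDim[l] == dim)
      return l; // dimension `dim` is stored at level `l`
  assert(false && "not a permutation");
  return lvlToDim.size();
}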
return sizeFromTensorAtDim(builder, loc, desc, - toOrigDim(desc.getTensorType(), lvl)); + toOrigDim(desc.getRankedTensorType(), lvl)); } static void createPushback(OpBuilder &builder, Location loc, MutSparseTensorDescriptor desc, - SparseTensorFieldKind kind, - std::optional dim, Value value, - Value repeat = Value()) { - Type etp = desc.getMemRefElementType(kind, dim); - Value field = desc.getMemRefField(kind, dim); + SparseTensorFieldKind kind, std::optional lvl, + Value value, Value repeat = Value()) { + Type etp = desc.getMemRefElementType(kind, lvl); + Value field = desc.getMemRefField(kind, lvl); StorageSpecifierKind specFieldKind = toSpecifierKind(kind); auto pushBackOp = builder.create( - loc, desc.getSpecifierField(builder, loc, specFieldKind, dim), field, + loc, desc.getSpecifierField(builder, loc, specFieldKind, lvl), field, toType(builder, loc, value, etp), repeat); - desc.setMemRefField(kind, dim, pushBackOp.getOutBuffer()); - desc.setSpecifierField(builder, loc, specFieldKind, dim, + desc.setMemRefField(kind, lvl, pushBackOp.getOutBuffer()); + desc.setSpecifierField(builder, loc, specFieldKind, lvl, pushBackOp.getNewSize()); } /// Generates code that allocates a sparse storage scheme for given rank. static void allocSchemeForRank(OpBuilder &builder, Location loc, - MutSparseTensorDescriptor desc, unsigned r0) { - RankedTensorType rtp = desc.getTensorType(); - unsigned rank = rtp.getShape().size(); + MutSparseTensorDescriptor desc, Level startLvl) { + const SparseTensorType stt(desc.getRankedTensorType()); Value linear = constantIndex(builder, loc, 1); - for (unsigned r = r0; r < rank; r++) { - if (isCompressedDim(rtp, r)) { + const Level lvlRank = stt.getLvlRank(); + for (Level l = startLvl; l < lvlRank; l++) { + const auto dlt = stt.getLvlType(l); + if (isCompressedDLT(dlt)) { // Append linear x pointers, initialized to zero. Since each compressed // dimension initially already has a single zero entry, this maintains // the desired "linear + 1" length property at all times. - Type ptrType = getSparseTensorEncoding(rtp).getPointerType(); + Type ptrType = stt.getEncoding().getPointerType(); Value ptrZero = constantZero(builder, loc, ptrType); - createPushback(builder, loc, desc, SparseTensorFieldKind::PtrMemRef, r, + createPushback(builder, loc, desc, SparseTensorFieldKind::PtrMemRef, l, ptrZero, linear); return; } - if (isSingletonDim(rtp, r)) { + if (isSingletonDLT(dlt)) { return; // nothing to do } // Keep compounding the size, but nothing needs to be initialized // at this level. We will eventually reach a compressed level or // otherwise the values array for the from-here "all-dense" case. - assert(isDenseDim(rtp, r)); - Value size = sizeFromTensorAtLvl(builder, loc, desc, r); + assert(isDenseDLT(dlt)); + Value size = sizeFromTensorAtLvl(builder, loc, desc, l); linear = builder.create(loc, linear, size); } // Reached values array so prepare for an insertion. - Value valZero = constantZero(builder, loc, rtp.getElementType()); + Value valZero = constantZero(builder, loc, stt.getElementType()); createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef, std::nullopt, valZero, linear); } @@ -193,43 +199,40 @@ static Value createAllocation(OpBuilder &builder, Location loc, /// for all dynamic memrefs, the memory size is really the capacity of /// the "vector", while the actual size resides in the sizes array. 
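// A small standalone model of the "linear + 1" property maintained by
// allocSchemeForRank above: a compressed level reached under `linear` dense
// positions keeps a pointers array of length linear + 1 (one leading zero plus
// one zero appended per position), so pointers[p] and pointers[p+1] always
// bracket the stored indices of position p. Plain vectors; illustration only.
#include <cstdint>
#include <vector>

std::vector<uint64_t> initCompressedPointers(uint64_t linear) {
  std::vector<uint64_t> ptrs{0};           // the single leading zero entry
  ptrs.resize(ptrs.size() + linear, 0);    // append `linear` more zeros
  return ptrs;                             // size() == linear + 1
}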
/// -/// TODO: for efficiency, we will need heuristis to make educated guesses +/// TODO: for efficiency, we will need heuristics to make educated guesses /// on the required capacities (see heuristic variable). /// -static void createAllocFields(OpBuilder &builder, Location loc, Type type, - ValueRange dynSizes, bool enableInit, - SmallVectorImpl &fields, Value sizeHint) { - RankedTensorType rtp = type.cast(); - SparseTensorEncodingAttr enc = getSparseTensorEncoding(rtp); - +static void createAllocFields(OpBuilder &builder, Location loc, + SparseTensorType stt, ValueRange dynSizes, + bool enableInit, SmallVectorImpl &fields, + Value sizeHint) { // Build original sizes. - SmallVector sizes; - auto shape = rtp.getShape(); - unsigned rank = shape.size(); - for (unsigned r = 0, o = 0; r < rank; r++) { - if (ShapedType::isDynamic(shape[r])) - sizes.push_back(dynSizes[o++]); - else - sizes.push_back(constantIndex(builder, loc, shape[r])); - } + assert((dynSizes.size() == static_cast(stt.getNumDynamicDims())) && + "Got wrong number of dynamic sizes"); + const Dimension dimRank = stt.getDimRank(); + SmallVector dimSizes; + dimSizes.reserve(dimRank); + unsigned i = 0; // cumulative index into `dynSizes`. + for (const DynSize sh : stt.getDimShape()) + dimSizes.push_back(ShapedType::isDynamic(sh) + ? dynSizes[i++] + : constantIndex(builder, loc, sh)); // Set up some heuristic sizes. We try to set the initial // size based on available information. Otherwise we just // initialize a few elements to start the reallocation chain. // TODO: refine this Value ptrHeuristic, idxHeuristic, valHeuristic; - if (enc.isAllDense()) { - Value linear = sizes[0]; - for (unsigned r = 1; r < rank; r++) { - linear = builder.create(loc, linear, sizes[r]); - } - valHeuristic = linear; + if (stt.isAllDense()) { + valHeuristic = dimSizes[0]; + for (const Value sz : ArrayRef{dimSizes}.drop_front()) + valHeuristic = builder.create(loc, valHeuristic, sz); } else if (sizeHint) { - if (getCOOStart(enc) == 0) { + if (getCOOStart(stt.getEncoding()) == 0) { ptrHeuristic = constantIndex(builder, loc, 2); idxHeuristic = builder.create( - loc, constantIndex(builder, loc, rank), sizeHint); // AOS - } else if (rank == 2 && isDenseDim(rtp, 0) && isCompressedDim(rtp, 1)) { + loc, constantIndex(builder, loc, dimRank), sizeHint); // AOS + } else if (dimRank == 2 && stt.isDenseLvl(0) && stt.isCompressedLvl(1)) { ptrHeuristic = builder.create( loc, sizeHint, constantIndex(builder, loc, 1)); idxHeuristic = sizeHint; @@ -243,15 +246,15 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type, } foreachFieldAndTypeInSparseTensor( - rtp, - [&builder, &fields, rtp, loc, ptrHeuristic, idxHeuristic, valHeuristic, - enableInit](Type fType, unsigned fIdx, SparseTensorFieldKind fKind, - unsigned /*dim*/, DimLevelType /*dlt*/) -> bool { + stt, + [&builder, &fields, stt, loc, ptrHeuristic, idxHeuristic, valHeuristic, + enableInit](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind, + Level /*lvl*/, DimLevelType /*dlt*/) -> bool { assert(fields.size() == fIdx); Value field; switch (fKind) { case SparseTensorFieldKind::StorageSpec: - field = SparseTensorSpecifier::getInitValue(builder, loc, rtp); + field = SparseTensorSpecifier::getInitValue(builder, loc, stt); break; case SparseTensorFieldKind::PtrMemRef: case SparseTensorFieldKind::IdxMemRef: @@ -270,65 +273,66 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type, return true; }); - MutSparseTensorDescriptor desc(rtp, fields); + MutSparseTensorDescriptor 
desc(stt, fields); // Initialize the storage scheme to an empty tensor. Initialized memSizes // to all zeros, sets the dimSizes to known values and gives all pointer // fields an initial zero entry, so that it is easier to maintain the // "linear + 1" length property. Value ptrZero = - constantZero(builder, loc, getSparseTensorEncoding(rtp).getPointerType()); - for (unsigned r = 0; r < rank; r++) { - unsigned ro = toOrigDim(rtp, r); + constantZero(builder, loc, stt.getEncoding().getPointerType()); + for (Level lvlRank = stt.getLvlRank(), l = 0; l < lvlRank; l++) { // Fills dim sizes array. - desc.setDimSize(builder, loc, r, sizes[ro]); - + // FIXME: this method seems to set *level* sizes, but the name is confusing + // FIXME: `toOrigDim` is deprecated. + desc.setDimSize(builder, loc, l, dimSizes[toOrigDim(stt, l)]); // Pushes a leading zero to pointers memref. - if (isCompressedDim(rtp, r)) { - createPushback(builder, loc, desc, SparseTensorFieldKind::PtrMemRef, r, + if (stt.isCompressedLvl(l)) + createPushback(builder, loc, desc, SparseTensorFieldKind::PtrMemRef, l, ptrZero); - } } allocSchemeForRank(builder, loc, desc, /*rank=*/0); } /// Helper method that generates block specific to compressed case: /// -/// plo = pointers[d][pos[d-1]] -/// phi = pointers[d][pos[d-1]+1] -/// msz = indices[d].size() +/// plo = pointers[l][pos[l-1]] +/// phi = pointers[l][pos[l-1]+1] +/// msz = indices[l].size() /// if (plo < phi) { -/// present = indices[d][phi-1] == i[d] +/// present = indices[l][phi-1] == i[l] /// } else { // first insertion /// present = false -/// pointers[d][pos[d-1]] = msz +/// pointers[l][pos[l-1]] = msz /// } /// if (present) { // index already present /// next = phi-1 /// } else { -/// indices[d].push_back(i[d]) -/// pointers[d][pos[d-1]+1] = msz+1 +/// indices[l].push_back(i[l]) +/// pointers[l][pos[l-1]+1] = msz+1 /// next = msz -/// +/// /// } -/// pos[d] = next +/// pos[l] = next static Value genCompressed(OpBuilder &builder, Location loc, - MutSparseTensorDescriptor desc, - SmallVectorImpl &indices, Value value, - Value pos, unsigned d) { - RankedTensorType rtp = desc.getTensorType(); - unsigned rank = rtp.getShape().size(); + MutSparseTensorDescriptor desc, ValueRange indices, + Value value, Value pos, Level lvl) { + const SparseTensorType stt(desc.getRankedTensorType()); + const Level lvlRank = stt.getLvlRank(); + assert(lvl < lvlRank && "Level is out of bounds"); + assert(indices.size() == static_cast(lvlRank) && + "Level-rank mismatch"); SmallVector types; Type indexType = builder.getIndexType(); Type boolType = builder.getIntegerType(1); unsigned idxIndex; unsigned idxStride; - std::tie(idxIndex, idxStride) = desc.getIdxMemRefIndexAndStride(d); + std::tie(idxIndex, idxStride) = desc.getIdxMemRefIndexAndStride(lvl); Value one = constantIndex(builder, loc, 1); Value pp1 = builder.create(loc, pos, one); - Value plo = genLoad(builder, loc, desc.getPtrMemRef(d), pos); - Value phi = genLoad(builder, loc, desc.getPtrMemRef(d), pp1); - Value msz = desc.getIdxMemSize(builder, loc, d); + Value plo = genLoad(builder, loc, desc.getPtrMemRef(lvl), pos); + Value phi = genLoad(builder, loc, desc.getPtrMemRef(lvl), pp1); + Value msz = desc.getIdxMemSize(builder, loc, lvl); Value idxStrideC; if (idxStride > 1) { idxStrideC = constantIndex(builder, loc, idxStride); @@ -349,14 +353,13 @@ static Value genCompressed(OpBuilder &builder, Location loc, : phim1); Value eq = builder.create(loc, arith::CmpIPredicate::eq, toType(builder, loc, crd, indexType), - indices[d]); + indices[lvl]); 
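// A standalone model of the compressed-level insertion step sketched in the
// plo/phi/msz pseudocode above, using std::vector<uint64_t> for the pointers
// and indices arrays of one level. It assumes coordinates arrive in order, as
// the insertion pass guarantees; names are illustrative only.
#include <cstdint>
#include <vector>

uint64_t insertCompressed(std::vector<uint64_t> &ptrs,
                          std::vector<uint64_t> &idxs,
                          uint64_t parentPos, uint64_t coord) {
  const uint64_t plo = ptrs[parentPos];
  const uint64_t phi = ptrs[parentPos + 1];
  const uint64_t msz = idxs.size();
  bool present;
  if (plo < phi) {
    present = (idxs[phi - 1] == coord);
  } else { // first insertion under this parent position
    present = false;
    ptrs[parentPos] = msz;
  }
  if (present)
    return phi - 1; // coordinate already stored; reuse its position
  idxs.push_back(coord);
  ptrs[parentPos + 1] = msz + 1;
  return msz; // newly appended position
}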
builder.create(loc, eq); builder.setInsertionPointToStart(&ifOp1.getElseRegion().front()); - if (d > 0) - genStore(builder, loc, msz, desc.getPtrMemRef(d), pos); + if (lvl > 0) + genStore(builder, loc, msz, desc.getPtrMemRef(lvl), pos); builder.create(loc, constantI1(builder, loc, false)); builder.setInsertionPointAfter(ifOp1); - Value p = ifOp1.getResult(0); // If present construct. Note that for a non-unique dimension level, we // simply set the condition to false and rely on CSE/DCE to clean up the IR. // @@ -365,8 +368,8 @@ static Value genCompressed(OpBuilder &builder, Location loc, for (unsigned i = 0, e = desc.getNumFields(); i < e; i++) types.push_back(desc.getField(i).getType()); types.push_back(indexType); - if (!isUniqueDim(rtp, d)) - p = constantI1(builder, loc, false); + const Value p = stt.isUniqueLvl(lvl) ? ifOp1.getResult(0) + : constantI1(builder, loc, false); scf::IfOp ifOp2 = builder.create(loc, types, p, /*else*/ true); // If present (fields unaffected, update next to phim1). builder.setInsertionPointToStart(&ifOp2.getThenRegion().front()); @@ -380,12 +383,12 @@ static Value genCompressed(OpBuilder &builder, Location loc, // If !present (changes fields, update next). builder.setInsertionPointToStart(&ifOp2.getElseRegion().front()); Value mszp1 = builder.create(loc, msz, one); - genStore(builder, loc, mszp1, desc.getPtrMemRef(d), pp1); - createPushback(builder, loc, desc, SparseTensorFieldKind::IdxMemRef, d, - indices[d]); + genStore(builder, loc, mszp1, desc.getPtrMemRef(lvl), pp1); + createPushback(builder, loc, desc, SparseTensorFieldKind::IdxMemRef, lvl, + indices[lvl]); // Prepare the next dimension "as needed". - if ((d + 1) < rank) - allocSchemeForRank(builder, loc, desc, d + 1); + if ((lvl + 1) < lvlRank) + allocSchemeForRank(builder, loc, desc, lvl + 1); desc.getFields().push_back(msz); builder.create(loc, desc.getFields()); @@ -412,52 +415,52 @@ static Value genCompressed(OpBuilder &builder, Location loc, /// static void genInsertBody(OpBuilder &builder, ModuleOp module, func::FuncOp func, RankedTensorType rtp) { - OpBuilder::InsertionGuard insertionGuard(builder); - Block *entryBlock = func.addEntryBlock(); + const OpBuilder::InsertionGuard insertionGuard(builder); + Block *const entryBlock = func.addEntryBlock(); builder.setInsertionPointToStart(entryBlock); - - Location loc = func.getLoc(); - ValueRange args = entryBlock->getArguments(); - unsigned rank = rtp.getShape().size(); + const ValueRange args = entryBlock->getArguments(); + const Location loc = func.getLoc(); + const SparseTensorType stt(rtp); + const Level lvlRank = stt.getLvlRank(); // Construct fields and indices arrays from parameters. - ValueRange tmp = args.drop_back(rank + 1); - SmallVector fields(tmp.begin(), tmp.end()); + SmallVector fields = llvm::to_vector(args.drop_back(lvlRank + 1)); MutSparseTensorDescriptor desc(rtp, fields); - tmp = args.take_back(rank + 1).drop_back(); - SmallVector indices(tmp.begin(), tmp.end()); + const SmallVector indices = + llvm::to_vector(args.take_back(lvlRank + 1).drop_back()); Value value = args.back(); Value pos = constantZero(builder, loc, builder.getIndexType()); - // Generate code for every dimension. - for (unsigned d = 0; d < rank; d++) { - if (isCompressedDim(rtp, d)) { + // Generate code for every level. 
+ for (Level l = 0; l < lvlRank; l++) { + const auto dlt = stt.getLvlType(l); + if (isCompressedDLT(dlt)) { // Create: // if (!present) { - // indices[d].push_back(i[d]) - // + // indices[l].push_back(i[l]) + // // } - // pos[d] = indices.size() - 1 - // - pos = genCompressed(builder, loc, desc, indices, value, pos, d); - } else if (isSingletonDim(rtp, d)) { + // pos[l] = indices.size() - 1 + // + pos = genCompressed(builder, loc, desc, indices, value, pos, l); + } else if (isSingletonDLT(dlt)) { // Create: - // indices[d].push_back(i[d]) - // pos[d] = pos[d-1] - // - createPushback(builder, loc, desc, SparseTensorFieldKind::IdxMemRef, d, - indices[d]); + // indices[l].push_back(i[l]) + // pos[l] = pos[l-1] + // + createPushback(builder, loc, desc, SparseTensorFieldKind::IdxMemRef, l, + indices[l]); } else { - assert(isDenseDim(rtp, d)); + assert(isDenseDLT(dlt)); // Construct the new position as: - // pos[d] = size * pos[d-1] + i[d] - // - Value size = sizeFromTensorAtLvl(builder, loc, desc, d); + // pos[l] = size * pos[l-1] + i[l] + // + Value size = sizeFromTensorAtLvl(builder, loc, desc, l); Value mult = builder.create(loc, size, pos); - pos = builder.create(loc, mult, indices[d]); + pos = builder.create(loc, mult, indices[l]); } } // Reached the actual value append/insert. - if (!isDenseDim(rtp, rank - 1)) + if (!stt.isDenseLvl(lvlRank - 1)) createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef, std::nullopt, value); else @@ -476,26 +479,24 @@ static void genInsertionCallHelper(OpBuilder &builder, // The mangled name of the function has this format: // ____ // __ - RankedTensorType rtp = desc.getTensorType(); + const SparseTensorType stt(desc.getRankedTensorType()); SmallString<32> nameBuffer; llvm::raw_svector_ostream nameOstream(nameBuffer); nameOstream << namePrefix; - unsigned rank = rtp.getShape().size(); - assert(rank == indices.size()); - for (unsigned d = 0; d < rank; d++) { - nameOstream << toMLIRString(getDimLevelType(rtp, d)) << "_"; - } + assert(static_cast(stt.getLvlRank()) == indices.size()); + const Level lvlRank = stt.getLvlRank(); + for (Level l = 0; l < lvlRank; l++) + nameOstream << toMLIRString(stt.getLvlType(l)) << "_"; // Static dim sizes are used in the generated code while dynamic sizes are // loaded from the dimSizes buffer. This is the reason for adding the shape // to the function name. - for (auto d : rtp.getShape()) - nameOstream << d << "_"; - SparseTensorEncodingAttr enc = getSparseTensorEncoding(rtp); + for (const auto sh : stt.getDimShape()) + nameOstream << sh << "_"; // Permutation information is also used in generating insertion. - if (enc.getDimOrdering() && !enc.getDimOrdering().isIdentity()) - nameOstream << enc.getDimOrdering() << "_"; - nameOstream << rtp.getElementType() << "_"; - nameOstream << enc.getIndexBitWidth() << "_" << enc.getPointerBitWidth(); + if (!stt.isIdentity()) + nameOstream << stt.getDimToLvlMap() << "_"; + nameOstream << stt.getElementType() << "_"; + nameOstream << stt.getIndexBitWidth() << "_" << stt.getPointerBitWidth(); // Look up the function. ModuleOp module = insertPoint->getParentOfType(); @@ -504,8 +505,8 @@ static void genInsertionCallHelper(OpBuilder &builder, auto func = module.lookupSymbol(result.getAttr()); // Construct parameters for fields and indices. 
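// A standalone sketch of the name-mangling scheme described above: the prefix
// is followed by the level types, the static shape, an optional ordering, the
// element type, and the index/pointer bit widths, all '_'-separated. The
// function name and the example output are illustrative, not the exact strings
// emitted by the pass.
#include <cstdint>
#include <sstream>
#include <string>
#include <vector>

std::string mangleInsertFunc(const std::string &prefix,
                             const std::vector<std::string> &lvlTypes,
                             const std::vector<int64_t> &dimShape,
                             const std::string &ordering, // "" if identity
                             const std::string &elemType,
                             unsigned idxWidth, unsigned ptrWidth) {
  std::ostringstream os;
  os << prefix;
  for (const auto &lt : lvlTypes)
    os << lt << "_";
  for (int64_t sh : dimShape)
    os << sh << "_";
  if (!ordering.empty())
    os << ordering << "_";
  os << elemType << "_" << idxWidth << "_" << ptrWidth;
  return os.str(); // e.g. "_insert_dense_compressed_4_8_f64_0_0"
}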
- SmallVector operands(desc.getFields().begin(), desc.getFields().end()); - operands.append(indices.begin(), indices.end()); + SmallVector operands = llvm::to_vector(desc.getFields()); + operands.append(indices); operands.push_back(value); Location loc = insertPoint.getLoc(); @@ -519,7 +520,7 @@ static void genInsertionCallHelper(OpBuilder &builder, FunctionType::get(context, ValueRange(operands).getTypes(), ValueRange(desc.getFields()).getTypes())); func.setPrivate(); - createFunc(builder, module, func, rtp); + createFunc(builder, module, func, stt); } // Generate a call to perform the insertion and update `fields` with values @@ -533,20 +534,21 @@ static void genInsertionCallHelper(OpBuilder &builder, /// Generations insertion finalization code. static void genEndInsert(OpBuilder &builder, Location loc, SparseTensorDescriptor desc) { - RankedTensorType rtp = desc.getTensorType(); - unsigned rank = rtp.getShape().size(); - for (unsigned d = 0; d < rank; d++) { - if (isCompressedDim(rtp, d)) { + const SparseTensorType stt(desc.getRankedTensorType()); + const Level lvlRank = stt.getLvlRank(); + for (Level l = 0; l < lvlRank; l++) { + const auto dlt = stt.getLvlType(l); + if (isCompressedDLT(dlt)) { // Compressed dimensions need a pointer cleanup for all entries // that were not visited during the insertion pass. // // TODO: avoid cleanup and keep compressed scheme consistent at all // times? // - if (d > 0) { - Type ptrType = getSparseTensorEncoding(rtp).getPointerType(); - Value ptrMemRef = desc.getPtrMemRef(d); - Value hi = desc.getPtrMemSize(builder, loc, d); + if (l > 0) { + Type ptrType = stt.getEncoding().getPointerType(); + Value ptrMemRef = desc.getPtrMemRef(l); + Value hi = desc.getPtrMemSize(builder, loc, l); Value zero = constantIndex(builder, loc, 0); Value one = constantIndex(builder, loc, 1); // Vector of only one, but needed by createFor's prototype. @@ -570,7 +572,7 @@ static void genEndInsert(OpBuilder &builder, Location loc, builder.setInsertionPointAfter(loop); } } else { - assert(isDenseDim(rtp, d) || isSingletonDim(rtp, d)); + assert(isDenseDLT(dlt) || isSingletonDLT(dlt)); } } } @@ -732,18 +734,25 @@ public: LogicalResult matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - RankedTensorType resType = op.getType(); - auto enc = getSparseTensorEncoding(resType); - if (!enc) + const auto resType = getSparseTensorType(op); + if (!resType.hasEncoding()) return failure(); if (op.getCopy()) return rewriter.notifyMatchFailure(op, "tensor copy not implemented"); // Construct allocation for each field. - Location loc = op.getLoc(); - Value sizeHint = op.getSizeHint(); + const Location loc = op.getLoc(); + const Value sizeHint = op.getSizeHint(); + const ValueRange dynSizes = adaptor.getDynamicSizes(); + const size_t found = dynSizes.size(); + const int64_t expected = resType.getNumDynamicDims(); + if (found != static_cast(expected)) + return rewriter.notifyMatchFailure( + op, llvm::formatv( + "Got wrong number of dynamic sizes: Found={0}, Expected={1}", + found, expected)); SmallVector fields; - createAllocFields(rewriter, loc, resType, adaptor.getOperands(), + createAllocFields(rewriter, loc, resType, dynSizes, enableBufferInitialization, fields, sizeHint); // Replace operation with resulting memrefs. 
rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields)); @@ -807,7 +816,7 @@ public: return failure(); Location loc = op->getLoc(); auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); - auto srcType = getRankedTensorType(op.getTensor()); + const auto srcType = getSparseTensorType(op.getTensor()); Type eltType = srcType.getElementType(); Type boolType = rewriter.getIntegerType(1); Type idxType = rewriter.getIndexType(); @@ -816,11 +825,12 @@ public: // Determine the size for access expansion (always the innermost stored // dimension size, translated back to original dimension). Note that we // recursively rewrite the new DimOp on the **original** tensor. - unsigned innerDim = toOrigDim(srcType, srcType.getRank() - 1); - auto sz = sizeFromTensorAtDim(rewriter, loc, desc, innerDim); + // FIXME: `toOrigDim` is deprecated. + const Dimension innerDim = toOrigDim(srcType, srcType.getLvlRank() - 1); + const auto sz = sizeFromTensorAtDim(rewriter, loc, desc, innerDim); // Generate a memref for `sz` elements of type `t`. - auto genAlloc = [&](Type t) { - auto memTp = MemRefType::get({ShapedType::kDynamic}, t); + const auto genAlloc = [&](Type t) { + const auto memTp = MemRefType::get({ShapedType::kDynamic}, t); return rewriter.create(loc, memTp, ValueRange{sz}); }; // Allocate temporary buffers for values/filled-switch and added. @@ -862,14 +872,13 @@ public: Value filled = adaptor.getFilled(); Value added = adaptor.getAdded(); Value count = adaptor.getCount(); - RankedTensorType dstType = desc.getTensorType(); + const SparseTensorType dstType(desc.getRankedTensorType()); Type eltType = dstType.getElementType(); // Prepare indices. SmallVector indices(adaptor.getIndices()); - // If the innermost dimension is ordered, we need to sort the indices + // If the innermost level is ordered, we need to sort the indices // in the "added" array prior to applying the compression. 
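// A minimal standalone model of the expand/compress "workspace" pattern used
// by the ExpandOp/CompressOp lowering above, with plain vectors standing in
// for the values/filled/added buffers; the generated IR manipulates memrefs
// instead, and kernels may accumulate rather than overwrite. Illustration only.
#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

struct Workspace {
  std::vector<double> values;  // dense scratch, one slot per column
  std::vector<bool> filled;    // which slots currently hold a value
  std::vector<uint64_t> added; // slots written, in insertion order
};

void expandInsert(Workspace &w, uint64_t i, double v) {
  if (!w.filled[i]) {
    w.filled[i] = true;
    w.added.push_back(i);
  }
  w.values[i] = v; // kernels may also accumulate into the scratch value
}

void compress(Workspace &w,
              std::vector<std::pair<uint64_t, double>> &out) {
  std::sort(w.added.begin(), w.added.end()); // needed for an ordered level
  for (uint64_t i : w.added) {
    out.emplace_back(i, w.values[i]); // insert into the sparse output
    w.values[i] = 0;                  // reset scratch for the next row
    w.filled[i] = false;
  }
  w.added.clear();
}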
- unsigned rank = dstType.getShape().size(); - if (isOrderedDim(dstType, rank - 1)) + if (dstType.isOrderedLvl(dstType.getLvlRank() - 1)) rewriter.create(loc, count, ValueRange{added}, ValueRange{}, SparseTensorSortKind::HybridQuickSort); // While performing the insertions, we also need to reset the elements @@ -1093,7 +1102,7 @@ struct SparsePackOpConverter : public OpConversionPattern { matchAndRewrite(PackOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto rtp = op.getResult().getType().cast(); + const auto rtp = getRankedTensorType(op.getResult()); assert(isUniqueCOOType(rtp)); SmallVector fields; @@ -1102,8 +1111,8 @@ struct SparsePackOpConverter : public OpConversionPattern { foreachFieldAndTypeInSparseTensor( rtp, [&rewriter, &fields, &op, rtp, - loc](Type fType, unsigned fIdx, SparseTensorFieldKind fKind, - unsigned /*dim*/, DimLevelType /*dlt*/) -> bool { + loc](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind, + Level /*lvl*/, DimLevelType /*dlt*/) -> bool { assert(fields.size() == fIdx); auto enc = getSparseTensorEncoding(rtp); Value field; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp index 074a8de..cfd7bca 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -25,6 +25,7 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SparseTensor/IR/Enums.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Transforms/DialectConversion.h" @@ -84,21 +85,20 @@ static Value genDimSizeCall(OpBuilder &builder, Location loc, Value tensor, /// Looks up a level-size by returning a statically-computed constant /// (when possible), or by calling `genLvlSizeCall` (when dynamic). static Value createOrFoldLvlCall(OpBuilder &builder, Location loc, - SparseTensorEncodingAttr enc, ShapedType stp, - Value tensor, unsigned lvl) { + SparseTensorType stt, Value tensor, + Level lvl) { // Only sparse tensors have "levels" to query. - assert(enc); - auto dimOrder = enc.getDimOrdering(); + assert(stt.hasEncoding()); // TODO: The following implementation only handles permutations; // we'll need to generalize this to handle arbitrary AffineExpr. // // There's no need to assert `isPermutation` here: because // `getDimPosition` checks that the expr isa `AffineDimExpr`, // which is all we care about (for supporting permutations). - unsigned dim = dimOrder ? dimOrder.getDimPosition(lvl) : lvl; - auto s = stp.getShape()[dim]; - if (s != ShapedType::kDynamic) - return constantIndex(builder, loc, s); + const Dimension dim = + stt.isIdentity() ? lvl : stt.getDimToLvlMap().getDimPosition(lvl); + if (const auto sz = stt.getStaticDimSize(dim)) + return constantIndex(builder, loc, *sz); // If we cannot statically compute the size from the shape, then we // must dynamically query it. (In principle we could also dynamically // compute it, but since we already did so to construct the `tensor` @@ -111,89 +111,82 @@ static Value createOrFoldLvlCall(OpBuilder &builder, Location loc, /// of sparse tensors) or `linalg::createOrFoldDimOp` (for dynamic sizes /// of dense tensors). 
static Value createOrFoldDimCall(OpBuilder &builder, Location loc, - SparseTensorEncodingAttr enc, ShapedType stp, - Value tensor, unsigned dim) { - auto s = stp.getShape()[dim]; - if (s != ShapedType::kDynamic) - return constantIndex(builder, loc, s); - if (enc) + SparseTensorType stt, Value tensor, + Dimension dim) { + if (const auto sz = stt.getStaticDimSize(dim)) + return constantIndex(builder, loc, *sz); + if (stt.hasEncoding()) return genDimSizeCall(builder, loc, tensor, dim); return linalg::createOrFoldDimOp(builder, loc, tensor, dim); } /// Populates the array with the dimension-sizes of the given tensor. -static void fillDimSizes(OpBuilder &builder, Location loc, - SparseTensorEncodingAttr enc, ShapedType stp, +static void fillDimSizes(OpBuilder &builder, Location loc, SparseTensorType stt, Value tensor, SmallVectorImpl &out) { - unsigned dimRank = stp.getRank(); + const Dimension dimRank = stt.getDimRank(); + out.clear(); out.reserve(dimRank); - for (unsigned d = 0; d < dimRank; d++) - out.push_back(createOrFoldDimCall(builder, loc, enc, stp, tensor, d)); + for (Dimension d = 0; d < dimRank; d++) + out.push_back(createOrFoldDimCall(builder, loc, stt, tensor, d)); } /// Returns an array with the dimension-sizes of the given tensor. static SmallVector getDimSizes(OpBuilder &builder, Location loc, - SparseTensorEncodingAttr enc, - ShapedType stp, Value tensor) { + SparseTensorType stt, Value tensor) { SmallVector out; - fillDimSizes(builder, loc, enc, stp, tensor, out); + fillDimSizes(builder, loc, stt, tensor, out); return out; } -/// Populates the array with the dimension-shape of the given `ShapedType`, -/// where dynamic sizes are represented by zero. -static void fillDimShape(OpBuilder &builder, Location loc, ShapedType stp, +/// Populates the array with the dimension-shape of the given +/// `SparseTensorType`, where dynamic sizes are represented by zero. +static void fillDimShape(OpBuilder &builder, Location loc, SparseTensorType stt, SmallVectorImpl &out) { - auto shape = stp.getShape(); - unsigned dimRank = stp.getRank(); - out.reserve(dimRank); - for (unsigned d = 0; d < dimRank; d++) { - auto s = shape[d] == ShapedType::kDynamic ? 0 : shape[d]; + out.clear(); + out.reserve(stt.getDimRank()); + for (const DynSize sh : stt.getDimShape()) { + const auto s = ShapedType::isDynamic(sh) ? 0 : sh; out.push_back(constantIndex(builder, loc, s)); } } -/// Returns an array with the dimension-shape of the given `ShapedType`, +/// Returns an array with the dimension-shape of the given `SparseTensorType`, /// where dynamic sizes are represented by zero. static SmallVector getDimShape(OpBuilder &builder, Location loc, - ShapedType stp) { + SparseTensorType stt) { SmallVector out; - fillDimShape(builder, loc, stp, out); + fillDimShape(builder, loc, stt, out); return out; } /// Populates the given sizes array for concatenation from type (for static /// sizes) and from an already-converted opaque pointer source (for dynamic /// sizes). 
-static void concatSizesFromInputs(OpBuilder &builder, - SmallVectorImpl &sizes, Location loc, - ShapedType dstTp, ValueRange srcs, - unsigned dim) { - auto dstShape = dstTp.getShape(); - - auto srcTp = srcs[0].getType().cast(); - auto srcEnc = getSparseTensorEncoding(srcTp); +static void concatDimSizesFromInputs(OpBuilder &builder, Location loc, + SparseTensorType dstTp, ValueRange srcs, + Dimension dim, + SmallVectorImpl &dimSizes) { + assert(dim < dstTp.getDimRank() && "Dimension is out of bounds"); + dimSizes.clear(); + // We first fills the sizes from an input tensor, and then // compute the size of the concatenation dimension if necessary. - if (srcEnc) + const auto srcTp = getSparseTensorType(srcs[0]); + if (srcTp.hasEncoding()) // Reuses sizes from an arbitrary input tensor is fine. - fillDimSizes(builder, loc, srcEnc, srcTp, srcs[0], sizes); + fillDimSizes(builder, loc, srcTp, srcs[0], dimSizes); else - sizesFromSrc(builder, sizes, loc, srcs[0]); + sizesFromSrc(builder, dimSizes, loc, srcs[0]); - // Sum up on the `dim` if the dimension is dynamic. - if (dstShape[dim] != ShapedType::kDynamic) { + if (const auto sz = dstTp.getStaticDimSize(dim)) { // Faithfully take the static size. - sizes[dim] = constantIndex(builder, loc, dstShape[dim]); + dimSizes[dim] = constantIndex(builder, loc, *sz); } else { - // Else, compute the shape dynamically. - for (size_t i = 1, sz = srcs.size(); i < sz; i++) { - auto srcTp = srcs[i].getType().cast(); - auto encSrc = getSparseTensorEncoding(srcTp); - Value srcSz = - createOrFoldDimCall(builder, loc, encSrc, srcTp, srcs[i], dim); - // Sum up all the sizes. - sizes[dim] = builder.create(loc, sizes[dim], srcSz); + // Else, dynamically compute the size. + for (const auto src : srcs.drop_front()) { + const auto srcTp = getSparseTensorType(src); + Value srcSz = createOrFoldDimCall(builder, loc, srcTp, src, dim); + dimSizes[dim] = builder.create(loc, dimSizes[dim], srcSz); } } } @@ -209,11 +202,10 @@ static Value genAlloc(RewriterBase &rewriter, Location loc, Value sz, Type tp) { /// Generates a temporary buffer for the level-types of the given encoding. static Value genLvlTypesBuffer(OpBuilder &builder, Location loc, - SparseTensorEncodingAttr enc) { + SparseTensorType stt) { SmallVector lvlTypes; - auto dlts = enc.getDimLevelType(); - lvlTypes.reserve(dlts.size()); - for (auto dlt : dlts) + lvlTypes.reserve(stt.getLvlRank()); + for (const auto dlt : stt.getEncoding().getDimLevelType()) lvlTypes.push_back(constantDimLevelTypeEncoding(builder, loc, dlt)); return allocaBuffer(builder, loc, lvlTypes); } @@ -235,8 +227,7 @@ public: /// MLIR buffers as needed, and returning `this` for method chaining. /// This method does not set the action and pointer arguments, since /// those are handled by `genNewCall` instead. - NewCallParams &genBuffers(SparseTensorEncodingAttr enc, ValueRange sizes, - ShapedType stp); + NewCallParams &genBuffers(SparseTensorType stt, ValueRange dimSizes); /// (Re)sets the C++ template type parameters, and returns `this` /// for method chaining. This is already done as part of `genBuffers`, @@ -246,12 +237,12 @@ public: // // TODO: This is only ever used by sparse2sparse-viaCOO `ConvertOp`; // is there a better way to handle that than this one-off setter method? 
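// A standalone sketch of the concatenation size computation described above:
// non-concatenated dimensions reuse the sizes of the first input, while the
// concatenated dimension sums the sizes of all inputs (assuming the result
// size is dynamic; a static result size is taken as-is in the real code).
// Plain integers; illustration only.
#include <cstdint>
#include <vector>

std::vector<int64_t>
concatDimSizes(const std::vector<std::vector<int64_t>> &srcSizes,
               unsigned concatDim) {
  std::vector<int64_t> out = srcSizes.front();
  out[concatDim] = 0;
  for (const auto &sizes : srcSizes)
    out[concatDim] += sizes[concatDim]; // sum along the concatenated dim
  return out;
}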
- NewCallParams &setTemplateTypes(SparseTensorEncodingAttr enc, - ShapedType stp) { + NewCallParams &setTemplateTypes(SparseTensorType stt) { + const auto enc = stt.getEncoding(); params[kParamPtrTp] = constantPointerTypeEncoding(builder, loc, enc); params[kParamIndTp] = constantIndexTypeEncoding(builder, loc, enc); params[kParamValTp] = - constantPrimaryTypeEncoding(builder, loc, stp.getElementType()); + constantPrimaryTypeEncoding(builder, loc, stt.getElementType()); return *this; } @@ -308,15 +299,16 @@ private: // TODO: see the note at `_mlir_ciface_newSparseTensor` about how // the meaning of the various arguments (e.g., "sizes" vs "shapes") // is inconsistent between the different actions. -NewCallParams &NewCallParams::genBuffers(SparseTensorEncodingAttr enc, - ValueRange dimSizes, ShapedType stp) { - const unsigned lvlRank = enc.getDimLevelType().size(); - const unsigned dimRank = stp.getRank(); +NewCallParams &NewCallParams::genBuffers(SparseTensorType stt, + ValueRange dimSizes) { + const Level lvlRank = stt.getLvlRank(); + const Dimension dimRank = stt.getDimRank(); // Sparsity annotations. - params[kParamLvlTypes] = genLvlTypesBuffer(builder, loc, enc); + params[kParamLvlTypes] = genLvlTypesBuffer(builder, loc, stt); // Dimension-sizes array of the enveloping tensor. Useful for either // verification of external data, or for construction of internal data. - assert(dimSizes.size() == dimRank && "Dimension-rank mismatch"); + assert(dimSizes.size() == static_cast(dimRank) && + "Dimension-rank mismatch"); params[kParamDimSizes] = allocaBuffer(builder, loc, dimSizes); // The level-sizes array must be passed as well, since for arbitrary // dim2lvl mappings it cannot be trivially reconstructed at runtime. @@ -329,29 +321,31 @@ NewCallParams &NewCallParams::genBuffers(SparseTensorEncodingAttr enc, // `operator[]` assignment. We preinitialize `lvl2dim` for code symmetry. SmallVector dim2lvl(dimRank); SmallVector lvl2dim(lvlRank); - auto dimOrder = enc.getDimOrdering(); - if (dimOrder) { + if (!stt.isIdentity()) { + const auto dimOrder = stt.getDimToLvlMap(); assert(dimOrder.isPermutation()); - for (unsigned l = 0; l < lvlRank; l++) { + for (Level l = 0; l < lvlRank; l++) { // The `d`th source variable occurs in the `l`th result position. - uint64_t d = dimOrder.getDimPosition(l); + const Dimension d = dimOrder.getDimPosition(l); dim2lvl[d] = constantIndex(builder, loc, l); lvl2dim[l] = constantIndex(builder, loc, d); lvlSizes[l] = dimSizes[d]; } } else { - assert(dimRank == lvlRank && "Rank mismatch"); - for (unsigned i = 0; i < lvlRank; i++) { - dim2lvl[i] = lvl2dim[i] = constantIndex(builder, loc, i); - lvlSizes[i] = dimSizes[i]; + // The `SparseTensorType` ctor already ensures `dimRank == lvlRank` + // when `isIdentity`; so no need to re-assert it here. + for (Level l = 0; l < lvlRank; l++) { + dim2lvl[l] = lvl2dim[l] = constantIndex(builder, loc, l); + lvlSizes[l] = dimSizes[l]; } } params[kParamLvlSizes] = allocaBuffer(builder, loc, lvlSizes); params[kParamLvl2Dim] = allocaBuffer(builder, loc, lvl2dim); - params[kParamDim2Lvl] = - dimOrder ? allocaBuffer(builder, loc, dim2lvl) : params[kParamLvl2Dim]; + params[kParamDim2Lvl] = stt.isIdentity() + ? params[kParamLvl2Dim] + : allocaBuffer(builder, loc, dim2lvl); // Secondary and primary types encoding. - setTemplateTypes(enc, stp); + setTemplateTypes(stt); // Finally, make note that initialization is complete. assert(isInitialized() && "Initialization failed"); // And return `this` for method chaining. 
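// A simplified model of how genBuffers derives the level-oriented data from a
// permutation, as described above: for each level l with dimension d = perm[l],
// set lvl2dim[l] = d, dim2lvl[d] = l, and lvlSizes[l] = dimSizes[d]. Plain
// vectors; names are illustrative only.
#include <cstdint>
#include <vector>

struct LvlMaps {
  std::vector<uint64_t> dim2lvl, lvl2dim, lvlSizes;
};

LvlMaps buildLvlMaps(const std::vector<uint64_t> &perm, // dim of each level
                     const std::vector<uint64_t> &dimSizes) {
  const uint64_t rank = perm.size();
  LvlMaps m{std::vector<uint64_t>(rank), std::vector<uint64_t>(rank),
            std::vector<uint64_t>(rank)};
  for (uint64_t l = 0; l < rank; ++l) {
    const uint64_t d = perm[l]; // the d-th dim appears at level l
    m.dim2lvl[d] = l;
    m.lvl2dim[l] = d;
    m.lvlSizes[l] = dimSizes[d];
  }
  return m;
}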
@@ -441,8 +435,7 @@ static void insertScalarIntoDenseTensor(OpBuilder &builder, Location loc, /// given target `dimTypes`. static bool canUseDirectConversion(ArrayRef dimTypes) { bool alreadyCompressed = false; - for (uint64_t rank = dimTypes.size(), r = 0; r < rank; r++) { - const DimLevelType dlt = dimTypes[r]; + for (const auto dlt : dimTypes) { if (isCompressedDLT(dlt)) { if (alreadyCompressed) return false; // Multiple compressed dimensions not yet supported. @@ -467,13 +460,14 @@ static void translateIndices(Location loc, ConversionPatternRewriter &rewriter, TensorType dstTp, TensorType srcTp, Value dstIdx, Value srcIdx, ArrayRef dstShape, ArrayRef srcShape) { - unsigned dstRank = dstTp.getRank(); - unsigned srcRank = srcTp.getRank(); + const Dimension dstRank = dstTp.getRank(); + const Dimension srcRank = srcTp.getRank(); SmallVector srcIndices; - for (unsigned i = 0; i < srcRank; i++) { + srcIndices.reserve(srcRank); + for (Dimension d = 0; d < srcRank; d++) { Value idx = rewriter.create( - loc, srcIdx, constantIndex(rewriter, loc, i)); + loc, srcIdx, constantIndex(rewriter, loc, d)); srcIndices.push_back(idx); } @@ -481,9 +475,9 @@ static void translateIndices(Location loc, ConversionPatternRewriter &rewriter, translateIndicesArray(rewriter, loc, reassociation, srcIndices, srcShape, dstShape, dstIndices); - for (unsigned i = 0; i < dstRank; i++) - rewriter.create(loc, dstIndices[i], dstIdx, - constantIndex(rewriter, loc, i)); + for (Dimension d = 0; d < dstRank; d++) + rewriter.create(loc, dstIndices[d], dstIdx, + constantIndex(rewriter, loc, d)); } /// Generate code for a general sparse to sparse reshaping operation. @@ -505,37 +499,34 @@ static LogicalResult genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) { Location loc = op.getLoc(); - auto srcTp = getRankedTensorType(op.getSrc()); - auto dstTp = getRankedTensorType(op.getResult()); - auto encSrc = getSparseTensorEncoding(srcTp); - auto encDst = getSparseTensorEncoding(dstTp); - if (!encDst || !encSrc) + const auto srcTp = getSparseTensorType(op.getSrc()); + const auto dstTp = getSparseTensorType(op.getResult()); + if (!srcTp.hasEncoding() || !dstTp.hasEncoding()) return failure(); Type elemTp = srcTp.getElementType(); assert(elemTp == dstTp.getElementType() && "reshape should not change element type"); // Start an iterator over the source tensor (in original index order). - const auto noPerm = encSrc.withoutOrdering(); SmallVector srcDimSizes = - getDimSizes(rewriter, loc, encSrc, srcTp, adaptor.getSrc()); + getDimSizes(rewriter, loc, srcTp, adaptor.getSrc()); NewCallParams params(rewriter, loc); - Value iter = params.genBuffers(noPerm, srcDimSizes, srcTp) + Value iter = params.genBuffers(srcTp.withoutOrdering(), srcDimSizes) .genNewCall(Action::kToIterator, adaptor.getSrc()); // Start a new COO for the destination tensor. SmallVector dstDimSizes; - if (dstTp.hasStaticShape()) + if (dstTp.hasStaticDimShape()) // Static "shapes" are in fact "sizes". fillDimShape(rewriter, loc, dstTp, dstDimSizes); else genReshapeDstShape(loc, rewriter, dstDimSizes, srcDimSizes, - dstTp.getShape(), op.getReassociationIndices()); - Value coo = params.genBuffers(encDst, dstDimSizes, dstTp) - .genNewCall(Action::kEmptyCOO); + dstTp.getDimShape(), op.getReassociationIndices()); + Value coo = + params.genBuffers(dstTp, dstDimSizes).genNewCall(Action::kEmptyCOO); Value dstPerm = params.getDim2LvlMap(); // Construct a while loop over the iterator. 
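// A standalone sketch of translating coordinates across a reshape, under the
// simplifying assumption that the reshape is a full row-major linearize and
// delinearize (the real translateIndices works group-by-group over the
// reassociation indices). Illustration only.
#include <cstdint>
#include <vector>

std::vector<uint64_t> translateCoords(const std::vector<uint64_t> &srcCoords,
                                      const std::vector<uint64_t> &srcShape,
                                      const std::vector<uint64_t> &dstShape) {
  // Linearize in the source shape.
  uint64_t linear = 0;
  for (size_t d = 0; d < srcShape.size(); ++d)
    linear = linear * srcShape[d] + srcCoords[d];
  // Delinearize in the destination shape.
  std::vector<uint64_t> dstCoords(dstShape.size());
  for (size_t d = dstShape.size(); d-- > 0;) {
    dstCoords[d] = linear % dstShape[d];
    linear /= dstShape[d];
  }
  return dstCoords;
}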
Type iTp = rewriter.getIndexType(); - Value srcIdx = genAlloca(rewriter, loc, srcTp.getRank(), iTp); - Value dstIdx = genAlloca(rewriter, loc, dstTp.getRank(), iTp); + Value srcIdx = genAlloca(rewriter, loc, srcTp.getDimRank(), iTp); + Value dstIdx = genAlloca(rewriter, loc, dstTp.getDimRank(), iTp); Value elemPtr = genAllocaScalar(rewriter, loc, elemTp); SmallVector noArgs; SmallVector noTypes; @@ -571,23 +562,22 @@ genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor, // TODO: rename to `genSparseIterationLoop`? static void genSparseCOOIterationLoop( ConversionPatternRewriter &rewriter, Location loc, Value t, - RankedTensorType tensorTp, + SparseTensorType stt, function_ref bodyBuilder) { - auto enc = getSparseTensorEncoding(tensorTp); - assert(enc && "Generating Sparse Tensor COO Loop on a Dense Tensor!"); - - unsigned rank = tensorTp.getRank(); - Type elemTp = tensorTp.getElementType(); + assert(stt.hasEncoding() && + "Generating Sparse Tensor COO Loop on a Dense Tensor!"); + const Dimension dimRank = stt.getDimRank(); + const Type elemTp = stt.getElementType(); // Start an iterator over the tensor (in original index order). - const auto noPerm = enc.withoutOrdering(); - SmallVector dimSizes = getDimSizes(rewriter, loc, noPerm, tensorTp, t); + const auto noPerm = stt.withoutOrdering(); + SmallVector dimSizes = getDimSizes(rewriter, loc, noPerm, t); Value iter = NewCallParams(rewriter, loc) - .genBuffers(noPerm, dimSizes, tensorTp) + .genBuffers(noPerm, dimSizes) .genNewCall(Action::kToIterator, t); // Construct a while loop over the iterator. - Value srcIdx = genAlloca(rewriter, loc, rank, rewriter.getIndexType()); + Value srcIdx = genAlloca(rewriter, loc, dimRank, rewriter.getIndexType()); Value elemPtr = genAllocaScalar(rewriter, loc, elemTp); SmallVector noArgs; SmallVector noTypes; @@ -599,8 +589,8 @@ static void genSparseCOOIterationLoop( Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes); rewriter.setInsertionPointToStart(after); - bool hasDenseDim = llvm::any_of( - enc.getDimLevelType(), [](DimLevelType dlt) { return isDenseDLT(dlt); }); + const bool hasDenseDim = + llvm::any_of(stt.getEncoding().getDimLevelType(), isDenseDLT); if (hasDenseDim) { Value elemV = rewriter.create(loc, elemPtr); Value isZero = genIsNonzero(rewriter, loc, elemV); @@ -633,12 +623,12 @@ static void genSparseCOOIterationLoop( // reduce code repetition! static void genDenseTensorIterationLoop( ConversionPatternRewriter &rewriter, Location loc, Value t, - RankedTensorType tensorTp, + SparseTensorType stt, function_ref bodyBuilder) { - assert(!getSparseTensorEncoding(tensorTp) && + assert(!stt.hasEncoding() && "Generating Dense Tensor Loop on a Sparse Tensor!"); - unsigned rank = tensorTp.getRank(); + const Dimension dimRank = stt.getDimRank(); Value zero = constantIndex(rewriter, loc, 0); Value one = constantIndex(rewriter, loc, 1); @@ -647,9 +637,9 @@ static void genDenseTensorIterationLoop( SmallVector st; // Fill out loop iteration information. 
- for (unsigned i = 0; i < rank; i++) { + for (Dimension d = 0; d < dimRank; d++) { lo.push_back(zero); - hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, t, i)); + hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, t, d)); st.push_back(one); } @@ -686,10 +676,9 @@ public: LogicalResult matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto stp = op.getSource().getType().cast(); + const auto stt = getSparseTensorType(op.getSource()); // Only rewrite sparse DimOp. - auto enc = getSparseTensorEncoding(stp); - if (!enc) + if (!stt.hasEncoding()) return failure(); // Only rewrite DimOp with constant index. std::optional dim = op.getConstantIndex(); @@ -698,7 +687,7 @@ public: // Generate the call. Value src = adaptor.getOperands()[0]; rewriter.replaceOp( - op, createOrFoldDimCall(rewriter, op->getLoc(), enc, stp, src, *dim)); + op, createOrFoldDimCall(rewriter, op->getLoc(), stt, src, *dim)); return success(); } }; @@ -741,21 +730,19 @@ public: matchAndRewrite(NewOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); - auto stp = op.getType().cast(); - auto enc = getSparseTensorEncoding(stp); - if (!enc) + const auto stt = getSparseTensorType(op); + if (!stt.hasEncoding()) return failure(); - const unsigned dimRank = stp.getRank(); - const unsigned lvlRank = enc.getDimLevelType().size(); + const Dimension dimRank = stt.getDimRank(); + const Level lvlRank = stt.getLvlRank(); // Construct the dimShape. - const auto dimShape = stp.getShape(); - SmallVector dimShapeValues = getDimShape(rewriter, loc, stp); + SmallVector dimShapeValues = getDimShape(rewriter, loc, stt); Value dimShapeBuffer = allocaBuffer(rewriter, loc, dimShapeValues); // Allocate `SparseTensorReader` and perform all initial setup that // does not depend on lvlSizes (nor dim2lvl, lvl2dim, etc). Type opaqueTp = getOpaquePointerType(rewriter); Value valTp = - constantPrimaryTypeEncoding(rewriter, loc, stp.getElementType()); + constantPrimaryTypeEncoding(rewriter, loc, stt.getElementType()); Value reader = createFuncCall(rewriter, loc, "createCheckedSparseTensorReader", opaqueTp, @@ -773,7 +760,7 @@ public: // // FIXME: reduce redundancy vs `NewCallParams::genBuffers`. Value dimSizesBuffer; - if (!stp.hasStaticShape()) { + if (stt.hasDynamicDimShape()) { Type indexTp = rewriter.getIndexType(); auto memTp = MemRefType::get({ShapedType::kDynamic}, indexTp); dimSizesBuffer = @@ -784,22 +771,23 @@ public: Value lvlSizesBuffer; Value lvl2dimBuffer; Value dim2lvlBuffer; - if (auto dimOrder = enc.getDimOrdering()) { + if (!stt.isIdentity()) { + const auto dimOrder = stt.getDimToLvlMap(); assert(dimOrder.isPermutation() && "Got non-permutation"); // We preinitialize `dim2lvlValues` since we need random-access writing. // And we preinitialize the others for stylistic consistency. SmallVector lvlSizeValues(lvlRank); SmallVector lvl2dimValues(lvlRank); SmallVector dim2lvlValues(dimRank); - for (unsigned l = 0; l < lvlRank; l++) { + for (Level l = 0; l < lvlRank; l++) { // The `d`th source variable occurs in the `l`th result position. - uint64_t d = dimOrder.getDimPosition(l); + Dimension d = dimOrder.getDimPosition(l); Value lvl = constantIndex(rewriter, loc, l); Value dim = constantIndex(rewriter, loc, d); dim2lvlValues[d] = lvl; lvl2dimValues[l] = dim; lvlSizeValues[l] = - (dimShape[d] == ShapedType::kDynamic) + stt.isDynamicDim(d) ? 
rewriter.create(loc, dimSizesBuffer, dim) : dimShapeValues[d]; } @@ -807,11 +795,12 @@ public: lvl2dimBuffer = allocaBuffer(rewriter, loc, lvl2dimValues); dim2lvlBuffer = allocaBuffer(rewriter, loc, dim2lvlValues); } else { - assert(dimRank == lvlRank && "Rank mismatch"); + // The `SparseTensorType` ctor already ensures `dimRank == lvlRank` + // when `isIdentity`; so no need to re-assert it here. SmallVector iotaValues; iotaValues.reserve(lvlRank); - for (unsigned i = 0; i < lvlRank; i++) - iotaValues.push_back(constantIndex(rewriter, loc, i)); + for (Level l = 0; l < lvlRank; l++) + iotaValues.push_back(constantIndex(rewriter, loc, l)); lvlSizesBuffer = dimSizesBuffer ? dimSizesBuffer : dimShapeBuffer; dim2lvlBuffer = lvl2dimBuffer = allocaBuffer(rewriter, loc, iotaValues); } @@ -819,11 +808,11 @@ public: SmallVector params{ reader, lvlSizesBuffer, - genLvlTypesBuffer(rewriter, loc, enc), + genLvlTypesBuffer(rewriter, loc, stt), lvl2dimBuffer, dim2lvlBuffer, - constantPointerTypeEncoding(rewriter, loc, enc), - constantIndexTypeEncoding(rewriter, loc, enc), + constantPointerTypeEncoding(rewriter, loc, stt.getEncoding()), + constantIndexTypeEncoding(rewriter, loc, stt.getEncoding()), valTp}; Value tensor = createFuncCall(rewriter, loc, "newSparseTensorFromReader", opaqueTp, params, EmitCInterface::On) @@ -848,27 +837,25 @@ public: return rewriter.notifyMatchFailure(op, "sparse tensor copy not implemented"); Location loc = op.getLoc(); - RankedTensorType resType = op.getType(); - auto enc = getSparseTensorEncoding(resType); - if (!enc) + const auto stt = getSparseTensorType(op); + if (!stt.hasEncoding()) return failure(); // Gather all dimension sizes as SSA values. - SmallVector sizes; - unsigned int operandCtr = 0; - for (int64_t i = 0; i < resType.getRank(); ++i) { - if (resType.isDynamicDim(i)) { - sizes.push_back(adaptor.getOperands()[operandCtr++]); - } else { - sizes.push_back( - rewriter.create(loc, op.getStaticSize(i))); - } + const Dimension dimRank = stt.getDimRank(); + SmallVector dimSizes; + dimSizes.reserve(dimRank); + unsigned operandCtr = 0; + for (Dimension d = 0; d < dimRank; ++d) { + dimSizes.push_back( + stt.isDynamicDim(d) + ? adaptor.getOperands()[operandCtr++] + : constantIndex(rewriter, loc, op.getStaticSize(d))); } // Generate the call to construct empty tensor. The sizes are // explicitly defined by the arguments to the alloc operator. 
- rewriter.replaceOp(op, - NewCallParams(rewriter, loc) - .genBuffers(enc, sizes, resType.cast()) - .genNewCall(Action::kEmpty)); + rewriter.replaceOp(op, NewCallParams(rewriter, loc) + .genBuffers(stt, dimSizes) + .genNewCall(Action::kEmpty)); return success(); } }; @@ -887,27 +874,30 @@ public: LogicalResult matchAndRewrite(ConvertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Location loc = op->getLoc(); - auto resType = getRankedTensorType(op); - auto srcType = getRankedTensorType(op.getSource()); - auto encDst = getSparseTensorEncoding(resType); - auto encSrc = getSparseTensorEncoding(srcType); - Value src = adaptor.getOperands()[0]; - if (encDst && encSrc) { + const Location loc = op->getLoc(); + const auto srcTp = getSparseTensorType(op.getSource()); + const auto dstTp = getSparseTensorType(op); + if (!srcTp.hasEncoding() && !dstTp.hasEncoding()) + return failure(); + + const Dimension dimRank = srcTp.getDimRank(); + const Type elemTp = srcTp.getElementType(); + const Value src = adaptor.getOperands()[0]; + if (srcTp.hasEncoding() && dstTp.hasEncoding()) { + const auto srcEnc = srcTp.getEncoding(); + const auto dstEnc = dstTp.getEncoding(); // This is a sparse => sparse conversion, which is handled as follows: // t = src->toCOO(); ; src to COO in dst order // dst = newSparseTensor(t) // Using the coordinate scheme as an intermediate does not always // yield the fastest conversion but avoids the need for a full // O(N^2) conversion matrix. - if (encDst == encSrc) { + if (dstEnc == srcEnc) { rewriter.replaceOp(op, adaptor.getOperands()); // hidden nop cast return success(); } NewCallParams params(rewriter, loc); - ShapedType stp = srcType.cast(); - SmallVector dimSizes = - getDimSizes(rewriter, loc, encSrc, stp, src); + SmallVector dimSizes = getDimSizes(rewriter, loc, srcTp, src); bool useDirectConversion; switch (options.sparseToSparseStrategy) { case SparseToSparseConversionStrategy::kViaCOO: @@ -915,37 +905,39 @@ public: break; case SparseToSparseConversionStrategy::kDirect: useDirectConversion = true; - assert(canUseDirectConversion(encDst.getDimLevelType()) && + assert(canUseDirectConversion(dstEnc.getDimLevelType()) && "Unsupported target for direct sparse-to-sparse conversion"); break; case SparseToSparseConversionStrategy::kAuto: - useDirectConversion = canUseDirectConversion(encDst.getDimLevelType()); + useDirectConversion = canUseDirectConversion(dstEnc.getDimLevelType()); break; } if (useDirectConversion) { - rewriter.replaceOp(op, params.genBuffers(encDst, dimSizes, stp) - .genNewCall(Action::kSparseToSparse, src)); + rewriter.replaceOp( + op, params.genBuffers(srcTp.withEncoding(dstEnc), dimSizes) + .genNewCall(Action::kSparseToSparse, src)); } else { // use via-COO conversion. // Set up encoding with right mix of src and dst so that the two // method calls can share most parameters, while still providing // the correct sparsity information to either of them. - auto enc = SparseTensorEncodingAttr::get( - op->getContext(), encDst.getDimLevelType(), encDst.getDimOrdering(), - encDst.getHigherOrdering(), encSrc.getPointerBitWidth(), - encSrc.getIndexBitWidth()); + const auto mixedEnc = SparseTensorEncodingAttr::get( + op->getContext(), dstEnc.getDimLevelType(), dstEnc.getDimOrdering(), + dstEnc.getHigherOrdering(), srcEnc.getPointerBitWidth(), + srcEnc.getIndexBitWidth()); // TODO: This is the only place where `kToCOO` (or `kToIterator`) // is called with a non-identity permutation. 
Is there any clean // way to push the permutation over to the `kFromCOO` side instead? - Value coo = params.genBuffers(enc, dimSizes, stp) + Value coo = params.genBuffers(srcTp.withEncoding(mixedEnc), dimSizes) .genNewCall(Action::kToCOO, src); - Value dst = params.setTemplateTypes(encDst, stp) + Value dst = params.setTemplateTypes(srcTp.withEncoding(dstEnc)) .genNewCall(Action::kFromCOO, coo); - genDelCOOCall(rewriter, loc, stp.getElementType(), coo); + genDelCOOCall(rewriter, loc, elemTp, coo); rewriter.replaceOp(op, dst); } return success(); } - if (!encDst && encSrc) { + if (srcTp.hasEncoding() && !dstTp.hasEncoding()) { + const auto srcEnc = srcTp.getEncoding(); // This is sparse => dense conversion, which is handled as follows: // dst = new Tensor(0); // iter = new SparseTensorIterator(src); @@ -953,26 +945,24 @@ public: // dst[elem.indices] = elem.value; // } // delete iter; - const unsigned rank = resType.getRank(); - const Type elemTp = resType.getElementType(); + // // Fabricate a no-permutation encoding for NewCallParams // The pointer/index types must be those of `src`. // The dimLevelTypes aren't actually used by Action::kToIterator. - encDst = SparseTensorEncodingAttr::get( + const auto dstEnc = SparseTensorEncodingAttr::get( op->getContext(), - SmallVector(rank, DimLevelType::Dense), AffineMap(), - AffineMap(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth()); - SmallVector dimSizes = - getDimSizes(rewriter, loc, encSrc, srcType, src); + SmallVector(dimRank, DimLevelType::Dense), AffineMap(), + AffineMap(), srcEnc.getPointerBitWidth(), srcEnc.getIndexBitWidth()); + SmallVector dimSizes = getDimSizes(rewriter, loc, srcTp, src); Value iter = NewCallParams(rewriter, loc) - .genBuffers(encDst, dimSizes, resType) + .genBuffers(dstTp.withEncoding(dstEnc), dimSizes) .genNewCall(Action::kToIterator, src); - Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType()); + Value ind = genAlloca(rewriter, loc, dimRank, rewriter.getIndexType()); Value elemPtr = genAllocaScalar(rewriter, loc, elemTp); Block *insertionBlock = rewriter.getInsertionBlock(); // TODO: Dense buffers should be allocated/deallocated via the callback // in BufferizationOptions. - Value dst = allocDenseTensor(rewriter, loc, resType, dimSizes); + Value dst = allocDenseTensor(rewriter, loc, dstTp, dimSizes); SmallVector noArgs; SmallVector noTypes; auto whileOp = rewriter.create(loc, noTypes, noArgs); @@ -982,12 +972,13 @@ public: rewriter.create(loc, cond, before->getArguments()); Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes); rewriter.setInsertionPointToStart(after); - SmallVector ivs = loadIndices(rewriter, loc, rank, ind); + SmallVector ivs = loadIndices(rewriter, loc, dimRank, ind); insertScalarIntoDenseTensor(rewriter, loc, elemPtr, dst, ivs); rewriter.create(loc); rewriter.setInsertionPointAfter(whileOp); genDelIteratorCall(rewriter, loc, elemTp, iter); - rewriter.replaceOpWithNewOp(op, resType, dst); + rewriter.replaceOpWithNewOp( + op, dstTp.getRankedTensorType(), dst); // Deallocate the buffer. 
if (bufferization::allocationDoesNotEscape(op->getOpResult(0))) { rewriter.setInsertionPoint(insertionBlock->getTerminator()); @@ -995,10 +986,7 @@ public: } return success(); } - if (!encDst && !encSrc) { - // dense => dense - return failure(); - } + assert(!srcTp.hasEncoding() && dstTp.hasEncoding()); // This is a dense => sparse conversion or a sparse constant in COO => // sparse conversion, which is handled as follows: // t = newSparseCOO() @@ -1025,30 +1013,27 @@ public: // Also note that the code below only generates the "new" ops and // the loop-nest per se; whereas the entire body of the innermost // loop is generated by genAddElt(). - ShapedType stp = resType.cast(); - unsigned rank = stp.getRank(); - SmallVector sizes; - sizesFromSrc(rewriter, sizes, loc, src); + SmallVector dimSizes; + sizesFromSrc(rewriter, dimSizes, loc, src); NewCallParams params(rewriter, loc); Value coo = - params.genBuffers(encDst, sizes, stp).genNewCall(Action::kEmptyCOO); - Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType()); + params.genBuffers(dstTp, dimSizes).genNewCall(Action::kEmptyCOO); + Value ind = genAlloca(rewriter, loc, dimRank, rewriter.getIndexType()); Value perm = params.getDim2LvlMap(); - Type eltType = stp.getElementType(); - Value elemPtr = genAllocaScalar(rewriter, loc, eltType); + Value elemPtr = genAllocaScalar(rewriter, loc, elemTp); genDenseTensorOrSparseConstantIterLoop( - rewriter, loc, src, rank, - [&](OpBuilder &builder, Location loc, Value val, ValueRange indices) { - for (unsigned i = 0; i < rank; i++) { - Value idx = constantIndex(builder, loc, i); - builder.create(loc, indices[i], ind, idx); + rewriter, loc, src, dimRank, + [&](OpBuilder &builder, Location loc, Value val, ValueRange ivs) { + for (Dimension d = 0; d < dimRank; d++) { + Value dim = constantIndex(builder, loc, d); + builder.create(loc, ivs[d], ind, dim); } builder.create(loc, val, elemPtr); - genAddEltCall(builder, loc, eltType, coo, elemPtr, ind, perm); + genAddEltCall(builder, loc, elemTp, coo, elemPtr, ind, perm); }); // Final call to construct sparse tensor storage. Value dst = params.genNewCall(Action::kFromCOO, coo); - genDelCOOCall(rewriter, loc, eltType, coo); + genDelCOOCall(rewriter, loc, elemTp, coo); rewriter.replaceOp(op, dst); return success(); } @@ -1066,8 +1051,7 @@ public: LogicalResult matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto enc = getSparseTensorEncoding(op.getTensor().getType()); - if (!enc) + if (!getSparseTensorType(op.getTensor()).hasEncoding()) return failure(); StringRef name = "delSparseTensor"; createFuncCall(rewriter, op->getLoc(), name, {}, adaptor.getOperands(), @@ -1190,14 +1174,14 @@ public: // index order. All values are passed by reference through stack // allocated memrefs. 
Location loc = op->getLoc(); - auto tp = getRankedTensorType(op.getTensor()); - auto elemTp = tp.getElementType(); - unsigned rank = tp.getRank(); - auto mref = genAlloca(rewriter, loc, rank, rewriter.getIndexType()); + const auto stt = getSparseTensorType(op.getTensor()); + const auto elemTp = stt.getElementType(); + const Dimension dimRank = stt.getDimRank(); + auto mref = genAlloca(rewriter, loc, dimRank, rewriter.getIndexType()); auto vref = genAllocaScalar(rewriter, loc, elemTp); - for (unsigned i = 0; i < rank; i++) - rewriter.create(loc, adaptor.getIndices()[i], mref, - constantIndex(rewriter, loc, i)); + for (Dimension d = 0; d < dimRank; d++) + rewriter.create(loc, adaptor.getIndices()[d], mref, + constantIndex(rewriter, loc, d)); rewriter.create(loc, adaptor.getValue(), vref); SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)}; createFuncCall(rewriter, loc, name, {}, {adaptor.getTensor(), mref, vref}, @@ -1215,18 +1199,15 @@ public: matchAndRewrite(ExpandOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); - auto srcType = getRankedTensorType(op.getTensor()); - Type eltType = srcType.getElementType(); + const auto srcTp = getSparseTensorType(op.getTensor()); + Type eltType = srcTp.getElementType(); Type boolType = rewriter.getIntegerType(1); Type idxType = rewriter.getIndexType(); // All initialization should be done on entry of the loop nest. rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp()); // Get the cardinality of valid coordinates for the innermost level. - auto srcEnc = getSparseTensorEncoding(srcType); - unsigned lvlRank = - srcEnc ? srcEnc.getDimLevelType().size() : srcType.getRank(); - Value sz = createOrFoldLvlCall(rewriter, loc, srcEnc, srcType, - adaptor.getTensor(), lvlRank - 1); + Value sz = createOrFoldLvlCall(rewriter, loc, srcTp, adaptor.getTensor(), + srcTp.getLvlRank() - 1); // Allocate temporary buffers for values, filled-switch, and indices. // We do not use stack buffers for this, since the expanded size may // be rather large (as it envelops a single expanded dense dimension). 
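For orientation, the lowerings above all follow the same access pattern on the new wrapper class: obtain a `SparseTensorType`, then query the dimension-rank or the level-rank as appropriate. A minimal sketch of that pattern (the value `tensor` and the loop body are illustrative assumptions, not code from this patch):

  // Sketch: SparseTensorType distinguishes the user-visible dimensions from
  // the stored levels; as noted earlier in this patch, its ctor guarantees
  // dimRank == lvlRank whenever the dim-to-lvl map is the identity.
  const auto stt = getSparseTensorType(tensor);
  const Dimension dimRank = stt.getDimRank(); // user-visible dimensions
  const Level lvlRank = stt.getLvlRank();     // stored levels
  for (Dimension d = 0; d < dimRank; d++) {
    // per-dimension work, e.g. storing the d-th coordinate into a memref
  }
  const Level innermostLvl = lvlRank - 1; // what the ExpandOp lowering queries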
@@ -1269,13 +1250,13 @@ public: Value added = adaptor.getAdded(); Value count = adaptor.getCount(); Value tensor = adaptor.getTensor(); - auto tp = getRankedTensorType(op.getTensor()); - Type elemTp = tp.getElementType(); - unsigned rank = tp.getRank(); - auto mref = genAlloca(rewriter, loc, rank, rewriter.getIndexType()); - for (unsigned i = 0; i < rank - 1; i++) - rewriter.create(loc, adaptor.getIndices()[i], mref, - constantIndex(rewriter, loc, i)); + const auto stt = getSparseTensorType(op.getTensor()); + const Type elemTp = stt.getElementType(); + const Dimension dimRank = stt.getDimRank(); + auto mref = genAlloca(rewriter, loc, dimRank, rewriter.getIndexType()); + for (Dimension d = 0; d < dimRank - 1; d++) + rewriter.create(loc, adaptor.getIndices()[d], mref, + constantIndex(rewriter, loc, d)); SmallString<12> name{"expInsert", primaryTypeFunctionSuffix(elemTp)}; createFuncCall(rewriter, loc, name, {}, {tensor, mref, values, filled, added, count}, @@ -1323,34 +1304,33 @@ public: // a[ adjustForOffset(elem.indices) ] = elem.value // return a Location loc = op.getLoc(); - auto dstTp = getRankedTensorType(op); - auto encDst = getSparseTensorEncoding(dstTp); - Type elemTp = dstTp.getElementType(); - uint64_t concatDim = op.getDimension().getZExtValue(); - unsigned rank = dstTp.getRank(); + const auto dstTp = getSparseTensorType(op); + const auto dstEnc = dstTp.getEncoding(); + const Type elemTp = dstTp.getElementType(); + const Dimension concatDim = op.getDimension().getZExtValue(); + const Dimension dimRank = dstTp.getDimRank(); Value dst; // destination tensor Value dstPerm; // destination tensor permutation (if sparse out) // A pointer to the value being inserted (if dense => sparse) Value elemPtr; - // Memory that holds the COO for destination tensor (if sparse out) - Value dstIdx; + // Memory that holds the dim-indices for destination tensor (if sparse out) + Value dstInd; // The offset applied to the dimenstion to be concated (starting from 0) Value offset = constantIndex(rewriter, loc, 0); - SmallVector sizes; - NewCallParams params(rewriter, loc); - concatSizesFromInputs(rewriter, sizes, loc, dstTp, op.getInputs(), - concatDim); + SmallVector dimSizes; + concatDimSizesFromInputs(rewriter, loc, dstTp, op.getInputs(), concatDim, + dimSizes); - bool allDense = false; + NewCallParams params(rewriter, loc); + const bool allDense = dstTp.hasEncoding() && dstTp.isAllDense(); Value dstTensor; - if (encDst) { - allDense = encDst.isAllDense(); + if (dstTp.hasEncoding()) { // Start a new COO or an initialized annotated all dense sparse tensor. - dst = params.genBuffers(encDst, sizes, dstTp) + dst = params.genBuffers(dstTp, dimSizes) .genNewCall(allDense ? Action::kEmpty : Action::kEmptyCOO); - dstIdx = genAlloca(rewriter, loc, rank, rewriter.getIndexType()); + dstInd = genAlloca(rewriter, loc, dimRank, rewriter.getIndexType()); if (allDense) { dstTensor = dst; // Get the values buffer for the sparse tensor and reshape it to the @@ -1358,8 +1338,9 @@ public: dst = genValuesCall(rewriter, loc, MemRefType::get({ShapedType::kDynamic}, elemTp), {dst}); - // Use the dstIdx to store the level sizes. - dst = reshapeValuesToLevels(rewriter, loc, encDst, sizes, dst, dstIdx); + // Use the dstInd to store the level sizes. 
+ dst = + reshapeValuesToLevels(rewriter, loc, dstEnc, dimSizes, dst, dstInd); } else { dstPerm = params.getDim2LvlMap(); elemPtr = genAllocaScalar(rewriter, loc, elemTp); @@ -1367,74 +1348,73 @@ public: } else { // TODO: Dense buffers should be allocated/deallocated via the callback // in BufferizationOptions. - dst = allocDenseTensor(rewriter, loc, dstTp, sizes); + dst = allocDenseTensor(rewriter, loc, dstTp, dimSizes); } - auto dimIdx2LvlIdx = [&](ValueRange dIdx) -> SmallVector { - SmallVector lIdx; - for (unsigned i = 0; i < dIdx.size(); i++) - lIdx.push_back(dIdx[toOrigDim(encDst, i)]); - return lIdx; + const Level lvlRank = dstTp.getLvlRank(); + const auto dimIvs2LvlIvs = [&](ValueRange dimIvs) -> SmallVector { + SmallVector lvlIvs; + lvlIvs.reserve(lvlRank); + for (Level l = 0; l < lvlRank; l++) + // FIXME: `toOrigDim` is deprecated + lvlIvs.push_back(dimIvs[toOrigDim(dstEnc, l)]); + return lvlIvs; }; - for (auto it : llvm::zip(op.getInputs(), adaptor.getInputs())) { + for (const auto &it : llvm::zip(op.getInputs(), adaptor.getInputs())) { Value orignalOp = std::get<0>(it); // Input (with encoding) from Op Value adaptedOp = std::get<1>(it); // Input (type converted) from adaptor - auto srcTp = getRankedTensorType(orignalOp); - auto encSrc = getSparseTensorEncoding(srcTp); - if (encSrc) { + const auto srcTp = getSparseTensorType(orignalOp); + if (srcTp.hasEncoding()) { genSparseCOOIterationLoop( rewriter, loc, adaptedOp, srcTp, [&](OpBuilder &builder, Location loc, Value idx, Value elemPtr) -> void { - SmallVector dimInd = - loadIndices(builder, loc, rank, idx, concatDim, offset); - if (encDst && !allDense) { + SmallVector dimIvs = + loadIndices(builder, loc, dimRank, idx, concatDim, offset); + if (dstTp.hasEncoding() && !allDense) { // Case: sparse => sparse, except for annotated all dense. - storeIndices(builder, loc, rank, dstIdx, dimInd); - genAddEltCall(builder, loc, elemTp, dst, elemPtr, dstIdx, + storeIndices(builder, loc, dimRank, dstInd, dimIvs); + genAddEltCall(builder, loc, elemTp, dst, elemPtr, dstInd, dstPerm); } else { // Case: sparse => dense, or annotated all dense. - SmallVector lvlInd; - if (allDense) - lvlInd = dimIdx2LvlIdx(dimInd); - else - lvlInd = dimInd; - insertScalarIntoDenseTensor(builder, loc, elemPtr, dst, lvlInd); + const auto lvlIvs = allDense ? dimIvs2LvlIvs(dimIvs) : dimIvs; + insertScalarIntoDenseTensor(builder, loc, elemPtr, dst, lvlIvs); } }); } else { genDenseTensorIterationLoop( rewriter, loc, adaptedOp, srcTp, - [&](OpBuilder &builder, Location loc, ValueRange idx) -> void { - if (encDst && !allDense) { + [&](OpBuilder &builder, Location loc, ValueRange dimIvs) -> void { + if (dstTp.hasEncoding() && !allDense) { // Case: dense => sparse, except for annotated all dense. - storeIndices(builder, loc, rank, dstIdx, idx, concatDim, + storeIndices(builder, loc, dimRank, dstInd, dimIvs, concatDim, offset); - Value val = genValueForDense(builder, loc, adaptedOp, idx); + Value val = genValueForDense(builder, loc, adaptedOp, dimIvs); builder.create(loc, val, elemPtr); - genAddEltCall(builder, loc, elemTp, dst, elemPtr, dstIdx, + genAddEltCall(builder, loc, elemTp, dst, elemPtr, dstInd, dstPerm); } else { // Case: dense => dense, or annotated all dense. - Value val = genValueForDense(builder, loc, adaptedOp, idx); - SmallVector lvlInd(idx); + Value val = genValueForDense(builder, loc, adaptedOp, dimIvs); + // Despite the name, this isn't actually level-ivs until + // after the `dimIvs2LvlIvs` call. + SmallVector lvlIvs(dimIvs); // Apply offset. 
- lvlInd[concatDim] = builder.create( - loc, lvlInd[concatDim], offset); + lvlIvs[concatDim] = builder.create( + loc, lvlIvs[concatDim], offset); if (allDense) - lvlInd = dimIdx2LvlIdx(lvlInd); - builder.create(loc, val, dst, lvlInd); + lvlIvs = dimIvs2LvlIvs(lvlIvs); + builder.create(loc, val, dst, lvlIvs); } }); } // Accumulate offset. // TODO: avoid calling sparseDimSize multiple times by caching the result! - Value curDim = createOrFoldDimCall(rewriter, loc, encSrc, srcTp, - adaptedOp, concatDim); - + Value curDim = + createOrFoldDimCall(rewriter, loc, srcTp, adaptedOp, concatDim); offset = rewriter.create(loc, offset, curDim); } - if (encDst) { + if (dstTp.hasEncoding()) { if (!allDense) { // In sparse output case, the destination holds the COO. Value coo = dst; @@ -1446,7 +1426,8 @@ public: } rewriter.replaceOp(op, dst); } else { - rewriter.replaceOpWithNewOp(op, dstTp, dst); + rewriter.replaceOpWithNewOp( + op, dstTp.getRankedTensorType(), dst); } return success(); } @@ -1459,30 +1440,25 @@ public: LogicalResult matchAndRewrite(OutOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Location loc = op->getLoc(); - ShapedType srcType = op.getTensor().getType().cast(); + const Location loc = op->getLoc(); + const auto srcTp = getSparseTensorType(op.getTensor()); // Convert to default permuted COO. Value src = adaptor.getOperands()[0]; - auto encSrc = getSparseTensorEncoding(srcType); - SmallVector dimSizes = - getDimSizes(rewriter, loc, encSrc, srcType, src); - const auto enc = encSrc.withoutOrdering(); + SmallVector dimSizes = getDimSizes(rewriter, loc, srcTp, src); Value coo = NewCallParams(rewriter, loc) - .genBuffers(enc, dimSizes, srcType) + .genBuffers(srcTp.withoutOrdering(), dimSizes) .genNewCall(Action::kToCOO, src); // Then output the tensor to external file with indices in the externally // visible lexicographic index order. A sort is required if the source was // not in that order yet (note that the sort can be dropped altogether if // external format does not care about the order at all, but here we assume // it does). 
- Value sort = constantI1(rewriter, loc, - encSrc.getDimOrdering() && - !encSrc.getDimOrdering().isIdentity()); + const Value sort = constantI1(rewriter, loc, !srcTp.isIdentity()); SmallVector outParams{coo, adaptor.getOperands()[1], sort}; - Type eltType = srcType.getElementType(); - SmallString<18> name{"outSparseTensor", primaryTypeFunctionSuffix(eltType)}; + const Type elemTp = srcTp.getElementType(); + SmallString<18> name{"outSparseTensor", primaryTypeFunctionSuffix(elemTp)}; createFuncCall(rewriter, loc, name, {}, outParams, EmitCInterface::Off); - genDelCOOCall(rewriter, loc, eltType, coo); + genDelCOOCall(rewriter, loc, elemTp, coo); rewriter.eraseOp(op); return success(); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp index 76745a5..f6ef4cf 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -20,6 +20,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/AffineMap.h" @@ -42,32 +43,15 @@ static bool isZeroValue(Value val) { // Helper to detect a sparse tensor type operand. static bool isSparseTensor(OpOperand *op) { - if (auto enc = getSparseTensorEncoding(op->get().getType())) { - if (llvm::is_contained(enc.getDimLevelType(), DimLevelType::Compressed)) - return true; - } - return false; -} - -static bool isAllDimOrdered(RankedTensorType rtp) { - if (auto enc = getSparseTensorEncoding(rtp)) - return llvm::all_of(enc.getDimLevelType(), isOrderedDLT); - - return true; + auto enc = getSparseTensorEncoding(op->get().getType()); + return enc && + llvm::is_contained(enc.getDimLevelType(), DimLevelType::Compressed); } static bool hasSameDimOrdering(RankedTensorType rtp1, RankedTensorType rtp2) { assert(rtp1.getRank() == rtp2.getRank()); - AffineMap idMap = - AffineMap::getMultiDimIdentityMap(rtp1.getRank(), rtp1.getContext()); - - auto enc1 = getSparseTensorEncoding(rtp1); - auto enc2 = getSparseTensorEncoding(rtp2); - - auto order1 = (enc1 && enc1.getDimOrdering()) ? enc1.getDimOrdering() : idMap; - auto order2 = (enc2 && enc2.getDimOrdering()) ? enc2.getDimOrdering() : idMap; - - return order1 == order2; + return SparseTensorType(rtp1).getDimToLvlMap() == + SparseTensorType(rtp2).getDimToLvlMap(); } // Helper method to find zero/uninitialized allocation. 
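The simplified `hasSameDimOrdering` above relies on `SparseTensorType` supplying a dim-to-lvl map even for tensors without an encoding (falling back to the identity permutation, which matches the old `idMap` default). A hedged sketch of that comparison in isolation, where `rtp1` and `rtp2` stand for any two ranked tensor types of equal rank:

  // Sketch: both calls return an AffineMap, so dense and sparse operands
  // can be compared uniformly without special-casing a missing encoding.
  const AffineMap map1 = SparseTensorType(rtp1).getDimToLvlMap();
  const AffineMap map2 = SparseTensorType(rtp2).getDimToLvlMap();
  const bool sameOrdering = (map1 == map2);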
@@ -424,9 +408,10 @@ public: ValueRange reduc) { SmallVector srcIndices; SmallVector dstIndices; - for (int64_t i = 0, e = srcTp.getRank(); i < e; i++) { - uint64_t dim = toStoredDim(encSrc, i); - srcIndices.push_back(args[dim]); + for (Dimension d = 0, dimRank = srcTp.getRank(); d < dimRank; d++) { + // FIXME: `toStoredDim` is deprecated + Level lvl = toStoredDim(encSrc, d); + srcIndices.push_back(args[lvl]); } translateIndicesArray(builder, loc, op.getReassociationIndices(), srcIndices, srcSizes, dstSizes, dstIndices); @@ -486,9 +471,10 @@ struct ConcatenateRewriter : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ConcatenateOp op, PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - auto dstTp = getRankedTensorType(op); - uint64_t conDim = op.getDimension().getZExtValue(); + const Location loc = op.getLoc(); + const auto dstTp = getSparseTensorType(op); + const Dimension dimRank = dstTp.getDimRank(); + const Dimension conDim = op.getDimension().getZExtValue(); SmallVector sizes; concatSizesFromInputs(rewriter, sizes, loc, dstTp, op.getInputs(), conDim); @@ -505,14 +491,16 @@ struct ConcatenateRewriter : public OpRewritePattern { // foreach in %s2 : insert d0, d1 + size(s1), %tmp // foreach in %s3 : insert d0, d1 + size(s1) + size(s2), %tmp // %t = convert_to_dest_tensor(%tmp) - SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp); + // + // NOTE: this cannot be `const` because it will be changed when + // `needTmpCOO`, but that's buried in the conditional below and + // thus not easily extracted. + auto encDst = dstTp.getEncoding(); Value dst; // Destination tensor for inserting source tensor values. bool needTmpCOO = true; - bool allDense = false; + const bool allDense = dstTp.hasEncoding() && dstTp.isAllDense(); Value annotatedDenseDst; - int64_t rank = dstTp.getRank(); - if (encDst) { - allDense = encDst.isAllDense(); + if (dstTp.hasEncoding()) { bool allOrdered = false; // When concatenating on dimension 0, and all inputs are sorted and have // an identity dimOrdering, the concatenate will generate coords in @@ -521,16 +509,12 @@ struct ConcatenateRewriter : public OpRewritePattern { // in all input/output buffers, and all input/output buffers have the same // dimOrdering, the tmp COO buffer is still unnecessary (e.g, concatenate // CSC matrices along column). - if (!allDense && conDim == 0 && encDst.hasIdDimOrdering()) { + if (!allDense && conDim == 0 && dstTp.isIdentity()) { for (auto i : op.getInputs()) { - auto rtp = getRankedTensorType(i); - auto srcEnc = getSparseTensorEncoding(rtp); - if (isAllDimOrdered(rtp) && (!srcEnc || srcEnc.hasIdDimOrdering())) { - allOrdered = true; - continue; - } - allOrdered = false; - break; + const auto stt = getSparseTensorType(i); + allOrdered = stt.isAllOrdered() && stt.isIdentity(); + if (!allOrdered) + break; } } @@ -547,8 +531,9 @@ struct ConcatenateRewriter : public OpRewritePattern { // Create a view of the values buffer to match the unannotated dense // tensor. 
Value valuesBuffer = genToValues(rewriter, loc, dst); - Value idxBuffer = genAlloca( - rewriter, loc, rank, rewriter.getIndexType(), /*staticShape=*/true); + Value idxBuffer = + genAlloca(rewriter, loc, dimRank, rewriter.getIndexType(), + /*staticShape=*/true); annotatedDenseDst = dst; dst = reshapeValuesToLevels(rewriter, loc, encDst, sizes, valuesBuffer, idxBuffer); @@ -571,13 +556,14 @@ struct ConcatenateRewriter : public OpRewritePattern { loc, input, initArgs, [&](OpBuilder &builder, Location loc, ValueRange args, Value v, ValueRange reduc) { - SmallVector indices(rank, Value()); - for (int64_t i = 0; i < rank; i++) { - Value idx = args[i]; - if (i == static_cast(conDim)) + SmallVector indices(dstTp.getLvlRank()); + for (Dimension d = 0; d < dimRank; d++) { + Value idx = args[d]; + if (d == conDim) // Transform coordinates for the concatenating dim. idx = builder.create(loc, idx, offset); - indices[toStoredDim(encDst, i)] = idx; + // FIXME: `toStoredDim` is deprecated + indices[toStoredDim(encDst, d)] = idx; } if (encDst && !allDense) { Value cond = genIsNonzero(rewriter, loc, v); @@ -599,31 +585,34 @@ struct ConcatenateRewriter : public OpRewritePattern { // Accumulates the offset. Note that only static-shaped inputs are allowed // by concatenate op verifier, which saves us from computing the offset // dynamically. - int64_t d = getRankedTensorType(input).getShape()[conDim]; - assert(!ShapedType::isDynamic(d)); - offset = rewriter.create(loc, offset, - constantIndex(rewriter, loc, d)); + const auto sh = getSparseTensorType(input).getStaticDimSize(conDim); + assert(sh.has_value()); + offset = rewriter.create( + loc, offset, constantIndex(rewriter, loc, *sh)); if (encDst && !allDense) { dst = foreachOp.getResult(0); initArgs[0] = dst; } } + // Temp variable to avoid needing to call `getRankedTensorType` + // in the three use-sites below. + const RankedTensorType dstRTT = dstTp; if (encDst) { if (!allDense) { dst = rewriter.create(loc, dst, true); if (needTmpCOO) { Value tmpCoo = dst; - dst = rewriter.create(loc, dstTp, tmpCoo).getResult(); + dst = rewriter.create(loc, dstRTT, tmpCoo).getResult(); rewriter.create(loc, tmpCoo); } } else { - dst = rewriter.create(loc, dstTp, annotatedDenseDst) + dst = rewriter.create(loc, dstRTT, annotatedDenseDst) .getResult(); } rewriter.replaceOp(op, dst); } else { - rewriter.replaceOpWithNewOp(op, dstTp, dst); + rewriter.replaceOpWithNewOp(op, dstRTT, dst); } return success(); } @@ -675,7 +664,7 @@ private: PatternRewriter &rewriter) const { Location loc = op.getLoc(); Value src = op.getSource(); - auto dstTp = getRankedTensorType(op); + const auto dstTp = getSparseTensorType(op); SmallVector sizes; sizesFromSrc(rewriter, sizes, loc, src); SmallVector dynSizes; @@ -688,16 +677,16 @@ private: } } - SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp); + const auto encDst = dstTp.getEncoding(); // We don't need a temporary COO tensor if the destination has an identity // ordering. Otherwise, we use the destination ordering for the temporary // COO tensor. // TODO: enhance foreachOp to take ordering to remove the need of a // temporary COO tensor here. - RankedTensorType bufferTp = encDst.hasIdDimOrdering() - ? dstTp - : getUnorderedCOOFromTypeWithOrdering( - dstTp, encDst.getDimOrdering()); + const RankedTensorType bufferTp = dstTp.isIdentity() + ? 
dstTp.getRankedTensorType() + : getUnorderedCOOFromTypeWithOrdering( + dstTp, dstTp.getDimToLvlMap()); auto buffer = rewriter.create(loc, bufferTp, dynSizes).getResult(); auto foreachOp = rewriter.create( @@ -705,10 +694,11 @@ private: [&](OpBuilder &builder, Location loc, ValueRange indices, Value v, ValueRange reduc) { Value input = reduc.front(); - uint64_t rank = dstTp.getRank(); - SmallVector indicesArray(rank, Value()); - for (uint64_t i = 0; i < rank; i++) - indicesArray[toStoredDim(encDst, i)] = indices[i]; + const Dimension dimRank = dstTp.getDimRank(); + SmallVector indicesArray(dimRank); + for (Dimension d = 0; d < dimRank; d++) + // FIXME: `toStoredDim` is deprecated + indicesArray[toStoredDim(encDst, d)] = indices[d]; if (fromSparseConst) { input = builder.create(loc, v, input, indicesArray); } else { @@ -729,7 +719,8 @@ private: rewriter.setInsertionPointAfter(op); src = rewriter.create(loc, foreachOp.getResult(0), true); if (bufferTp != dstTp) { - rewriter.replaceOpWithNewOp(op, dstTp, src); + rewriter.replaceOpWithNewOp(op, dstTp.getRankedTensorType(), + src); rewriter.create(loc, src); } else { rewriter.replaceOp(op, src); @@ -782,15 +773,22 @@ private: // insert element to dst LogicalResult sparse2SparseRewrite(ConvertOp op, PatternRewriter &rewriter) const { - Location loc = op->getLoc(); + const Location loc = op->getLoc(); + // These two variables cannot be `const` because they're conditionally + // changed below. Ideally we'd use `SparseTensorType` for `srcRTT`; + // however that class's copy-ctor is implicitly deleted. Value src = op.getSource(); - RankedTensorType srcTp = getRankedTensorType(src); - RankedTensorType dstTp = getRankedTensorType(op); - SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp); - int64_t rank = dstTp.getRank(); + auto srcRTT = getRankedTensorType(src); + const auto dstTp = getSparseTensorType(op); + const auto encDst = dstTp.getEncoding(); + const Level dstLvlRank = dstTp.getLvlRank(); + const Dimension dimRank = dstTp.getDimRank(); + // This assertion should be guaranteed by validity of the op, + // but just for paranoia's sake. + assert(srcRTT.getRank() == dimRank); SmallVector srcSizes; - sizesForTensor(rewriter, srcSizes, loc, srcTp, src); + sizesForTensor(rewriter, srcSizes, loc, srcRTT, src); Value tmpCoo = Value(); Value nnz = rewriter.create(loc, src); // We need a tmp COO buffer if and only if @@ -798,28 +796,31 @@ private: // 2. the src tensor is not ordered in the same way as the target // tensor (e.g., src tensor is not ordered or src tensor haves a different // dimOrdering). - if (!isUniqueCOOType(srcTp) && - !(isAllDimOrdered(srcTp) && hasSameDimOrdering(srcTp, dstTp))) { + if (!isUniqueCOOType(srcRTT) && !(SparseTensorType(srcRTT).isAllOrdered() && + hasSameDimOrdering(srcRTT, dstTp))) { // Construct a COO tensor from the src tensor. // TODO: there may be cases for which more efficiently without // going through an intermediate COO, such as cases that only change // the overhead types. SmallVector dynSrcSizes; - getDynamicSizes(srcTp, srcSizes, dynSrcSizes); - srcTp = - getUnorderedCOOFromTypeWithOrdering(srcTp, encDst.getDimOrdering()); + getDynamicSizes(srcRTT, srcSizes, dynSrcSizes); + srcRTT = + getUnorderedCOOFromTypeWithOrdering(srcRTT, dstTp.getDimToLvlMap()); + // Ensure that mutating `srcRTT` didn't invalidate `dimRank`. 
+ assert(srcRTT.getRank() == dimRank); tmpCoo = rewriter - .create(loc, srcTp, dynSrcSizes, Value(), + .create(loc, srcRTT, dynSrcSizes, Value(), /*sizeHint=*/nnz, Attribute()) .getResult(); auto foreachOp = rewriter.create( loc, src, tmpCoo, [&](OpBuilder &builder, Location loc, ValueRange args, Value v, ValueRange reduc) { - SmallVector dstIndices(srcTp.getRank(), Value()); - for (int64_t i = 0; i < rank; i++) { - uint64_t dim = toStoredDim(encDst, i); - dstIndices[dim] = args[i]; + SmallVector dstIndices(dstLvlRank); + for (Dimension d = 0; d < dimRank; d++) { + // FIXME: `toStoredDim` is deprecated + Level l = toStoredDim(encDst, d); + dstIndices[l] = args[d]; } auto t = builder.create(loc, v, reduc.front(), dstIndices); @@ -828,30 +829,36 @@ private: src = rewriter.create(loc, foreachOp.getResult(0), true); } + // Now that the conditional is done, we can use `SparseTensorType`. + const SparseTensorType srcTp(srcRTT); + // Only need to sort if the srcTp is not already sorted (we faithfully take // the guarantee from the sparse tensor encoding). - if (!isAllDimOrdered(srcTp)) { + if (!srcTp.isAllOrdered()) { // Retrieve the values-array. Value y = genToValues(rewriter, loc, src); - SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(srcTp); + const auto encSrc = srcTp.getEncoding(); // Sort the COO tensor so that its elements are ordered via increasing // indices for the storage ordering of the dst tensor. Use SortCoo if the // COO tensor has the same dim ordering as the dst tensor. - if (rank > 1 && hasSameDimOrdering(srcTp, dstTp)) { + if (dimRank > 1 && hasSameDimOrdering(srcTp, dstTp)) { MemRefType indTp = get1DMemRefType(getIndexOverheadType(rewriter, encSrc), /*withLayout=*/false); Value xs = rewriter.create(loc, indTp, src); rewriter.create( - loc, nnz, xs, ValueRange{y}, rewriter.getIndexAttr(rank), + loc, nnz, xs, ValueRange{y}, rewriter.getIndexAttr(dimRank), rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort); } else { // Gather the indices-arrays in the dst tensor storage order. 
- SmallVector xs(rank, Value()); - for (int64_t i = 0; i < rank; i++) { - uint64_t orgDim = toOrigDim(encSrc, i); - xs[toStoredDim(encDst, orgDim)] = - genToIndices(rewriter, loc, src, i, /*cooStart=*/0); + SmallVector xs(dstLvlRank); + const Level srcLvlRank = srcTp.getLvlRank(); + for (Level srcLvl = 0; srcLvl < srcLvlRank; srcLvl++) { + // FIXME: `toOrigDim` is deprecated + Dimension dim = toOrigDim(encSrc, srcLvl); + // FIXME: `toStoredDim` is deprecated + Level dstLvl = toStoredDim(encDst, dim); + xs[dstLvl] = genToIndices(rewriter, loc, src, srcLvl, /*cooStart=*/0); } rewriter.create(loc, nnz, xs, ValueRange{y}, SparseTensorSortKind::HybridQuickSort); @@ -862,17 +869,19 @@ private: SmallVector dynDstSizes; getDynamicSizes(dstTp, srcSizes, dynDstSizes); Value dst = rewriter - .create(loc, dstTp, dynDstSizes, Value(), + .create(loc, dstTp.getRankedTensorType(), + dynDstSizes, Value(), /*sizeHint=*/nnz, Attribute()) .getResult(); - SmallVector indices(srcTp.getRank(), Value()); + SmallVector indices(dstLvlRank); auto foreachOp = rewriter.create( loc, src, dst, [&](OpBuilder &builder, Location loc, ValueRange args, Value v, ValueRange reduc) { - for (int64_t i = 0, e = srcTp.getRank(); i < e; i++) { - uint64_t dim = toStoredDim(encDst, i); - indices[dim] = args[i]; + for (Dimension d = 0; d < dimRank; d++) { + // FIXME: `toStoredDim` is deprecated + Level l = toStoredDim(encDst, d); + indices[l] = args[d]; } auto t = builder.create(loc, v, reduc.front(), indices); builder.create(loc, t); @@ -889,7 +898,7 @@ private: // codegen. rewriter.setInsertionPointAfter(op); auto t = rewriter.create(loc, foreachOp.getResult(0), true); - rewriter.replaceOpWithNewOp(op, dstTp, t); + rewriter.replaceOpWithNewOp(op, dstTp.getRankedTensorType(), t); return success(); } }; @@ -905,8 +914,8 @@ public: auto loc = op.getLoc(); Value input = op.getTensor(); SmallVector reduc = op.getInitArgs(); - auto rtp = getRankedTensorType(input); - int64_t rank = rtp.getRank(); + const auto stt = getSparseTensorType(input); + const Dimension dimRank = stt.getDimRank(); // Special-case: for each over a sparse constant uses its own rewriting // rule. @@ -917,24 +926,24 @@ public: } // Otherwise, use loop emitter to generate loops. - auto enc = getSparseTensorEncoding(rtp); + const auto enc = stt.getEncoding(); // 1. Generates loop for the sparse input. LoopEmitter loopEmitter( ValueRange{input}, StringAttr::get(getContext(), ForeachOp::getOperationName())); loopEmitter.initializeLoopEmit(rewriter, loc); - for (int64_t i = 0; i < rank; i++) { + for (Dimension d = 0; d < dimRank; d++) { // TODO: provide utility function for loop sequences that only contains // one for loop? - loopEmitter.enterNewLoopSeq(rewriter, loc, 0, static_cast(i)); + loopEmitter.enterNewLoopSeq(rewriter, loc, 0, static_cast(d)); // Note that reduc will be taken care of by loop emitter and get updated // in place. - loopEmitter.enterLoopOverTensorAtDim(rewriter, loc, 0, i, reduc); + loopEmitter.enterLoopOverTensorAtDim(rewriter, loc, 0, d, reduc); } SmallVector coords; - coords.reserve(rank); + coords.reserve(dimRank); loopEmitter.getCoordinateArray(coords); Value vals = loopEmitter.getValBuffer()[0]; @@ -949,8 +958,9 @@ public: // Remap coordinates. SmallVector args; - for (int64_t i = 0; i < rank; i++) { - Value actual = coords[toStoredDim(enc, i)]; + for (Dimension d = 0; d < dimRank; d++) { + // FIXME: `toStoredDim` is deprecated + Value actual = coords[toStoredDim(enc, d)]; args.push_back(actual); } // Remap value. 
@@ -972,7 +982,7 @@ public: rewriter.mergeBlockBefore(srcBlock, &*rewriter.getInsertionPoint(), args); } - for (int64_t i = 0; i < rank; i++) { + for (Dimension d = 0; d < dimRank; d++) { // Link the reduction chain. Note that loop emitter update the reducValue // in place. loopEmitter.exitCurrentLoop(rewriter, loc, reducValue); @@ -992,9 +1002,9 @@ struct NewRewriter : public OpRewritePattern { LogicalResult matchAndRewrite(NewOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - auto dstTp = getRankedTensorType(op.getResult()); - SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp); - if (!encDst) + const auto dstTp = getSparseTensorType(op.getResult()); + const auto encDst = dstTp.getEncoding(); + if (!dstTp.hasEncoding()) return failure(); // Create a sparse tensor reader. @@ -1006,17 +1016,17 @@ struct NewRewriter : public OpRewritePattern { // Allocate a temporary buffer for storing dimension sizes and indices. Type indexTp = rewriter.getIndexType(); - uint64_t rank = dstTp.getRank(); - Value dimSizes = genAlloca(rewriter, loc, rank, indexTp); + const Dimension dimRank = dstTp.getDimRank(); + Value dimSizes = genAlloca(rewriter, loc, dimRank, indexTp); // If the result tensor has dynamic dimensions, get the dynamic sizes from // the sparse tensor reader. SmallVector dynSizesArray; - if (!dstTp.hasStaticShape()) { + if (dstTp.hasDynamicDimShape()) { createFuncCall(rewriter, loc, "copySparseTensorReaderDimSizes", {}, {reader, dimSizes}, EmitCInterface::On) .getResult(0); - ArrayRef dstShape = dstTp.getShape(); + ArrayRef dstShape = dstTp.getRankedTensorType().getShape(); for (auto &d : llvm::enumerate(dstShape)) { if (d.value() == ShapedType::kDynamic) { dynSizesArray.push_back(rewriter.create( @@ -1038,7 +1048,7 @@ struct NewRewriter : public OpRewritePattern { {indexTp}, {reader}, EmitCInterface::Off) .getResult(0); RankedTensorType cooTp = - getUnorderedCOOFromTypeWithOrdering(dstTp, encDst.getDimOrdering()); + getUnorderedCOOFromTypeWithOrdering(dstTp, dstTp.getDimToLvlMap()); Value cooBuffer = rewriter .create(loc, cooTp, dynSizesArray, Value(), @@ -1047,7 +1057,7 @@ struct NewRewriter : public OpRewritePattern { // The verifier ensures only 2D tensors can have the expandSymmetry flag. Value symmetric; - if (rank == 2 && op.getExpandSymmetry()) { + if (dimRank == 2 && op.getExpandSymmetry()) { symmetric = createFuncCall(rewriter, loc, "getSparseTensorReaderIsSymmetric", {rewriter.getI1Type()}, {reader}, EmitCInterface::Off) @@ -1066,10 +1076,11 @@ struct NewRewriter : public OpRewritePattern { Value indices = dimSizes; // Reuse the indices memref to store indices. 
createFuncCall(rewriter, loc, getNextFuncName, {}, {reader, indices, value}, EmitCInterface::On); - SmallVector indicesArray(rank, Value()); - for (uint64_t i = 0; i < rank; i++) { - indicesArray[toStoredDim(encDst, i)] = rewriter.create( - loc, indices, constantIndex(rewriter, loc, i)); + SmallVector indicesArray(dimRank); + for (Dimension d = 0; d < dimRank; d++) { + // FIXME: `toStoredDim` is deprecated + indicesArray[toStoredDim(encDst, d)] = rewriter.create( + loc, indices, constantIndex(rewriter, loc, d)); } Value v = rewriter.create(loc, value); Value t = rewriter.create(loc, v, forOp.getRegionIterArg(0), @@ -1098,7 +1109,8 @@ struct NewRewriter : public OpRewritePattern { createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader}, EmitCInterface::Off); cooBuffer = rewriter.create(loc, cooBuffer, true); - Value newOp = rewriter.replaceOpWithNewOp(op, dstTp, cooBuffer); + Value newOp = rewriter.replaceOpWithNewOp( + op, dstTp.getRankedTensorType(), cooBuffer); // Release the unordered COO tensor buffer. rewriter.setInsertionPointAfterValue(newOp); @@ -1118,18 +1130,18 @@ struct OutRewriter : public OpRewritePattern { Value nnz = rewriter.create(loc, src); // Allocate a temporary buffer for storing dimension sizes and indices. - auto srcTp = getRankedTensorType(src); - uint64_t rank = srcTp.getRank(); + const auto srcTp = getSparseTensorType(src); + const Dimension dimRank = srcTp.getDimRank(); Type indexTp = rewriter.getIndexType(); - Value dimSizes = genAlloca(rewriter, loc, rank, indexTp); + Value dimSizes = genAlloca(rewriter, loc, dimRank, indexTp); // Generate code to calculate dimension size values and store the values to // the buffer. SmallVector dims; sizesForTensor(rewriter, dims, loc, srcTp, src); - for (uint64_t i = 0; i < rank; i++) { - rewriter.create(loc, dims[i], dimSizes, - constantIndex(rewriter, loc, i)); + for (Dimension d = 0; d < dimRank; d++) { + rewriter.create(loc, dims[d], dimSizes, + constantIndex(rewriter, loc, d)); } // Create a sparse tensor writer and output meta data. 
@@ -1138,7 +1150,7 @@ struct OutRewriter : public OpRewritePattern { createFuncCall(rewriter, loc, "createSparseTensorWriter", {opaqueTp}, {op.getDest()}, EmitCInterface::Off) .getResult(0); - Value rankValue = constantIndex(rewriter, loc, rank); + Value rankValue = constantIndex(rewriter, loc, dimRank); createFuncCall(rewriter, loc, "outSparseTensorWriterMetaData", {}, {writer, rankValue, nnz, dimSizes}, EmitCInterface::On); @@ -1153,9 +1165,9 @@ struct OutRewriter : public OpRewritePattern { loc, src, std::nullopt, [&](OpBuilder &builder, Location loc, ValueRange args, Value v, ValueRange reduc) { - for (uint64_t i = 0; i < rank; i++) { - rewriter.create(loc, args[i], indices, - constantIndex(builder, loc, i)); + for (Dimension d = 0; d < dimRank; d++) { + rewriter.create(loc, args[d], indices, + constantIndex(builder, loc, d)); } rewriter.create(loc, v, value); SmallVector operands{writer, rankValue, indices, value}; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp index e24a38d..8440630 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp @@ -36,16 +36,19 @@ static IntegerAttr fromOptionalInt(MLIRContext *ctx, return IntegerAttr::get(IndexType::get(ctx), dim.value()); } +// This is only ever called from `SparseTensorTypeToBufferConverter`, +// which is why the first argument is `RankedTensorType` rather than +// `SparseTensorType`. static std::optional convertSparseTensorType(RankedTensorType rtp, SmallVectorImpl &fields) { - auto enc = getSparseTensorEncoding(rtp); - if (!enc) + const SparseTensorType stt(rtp); + if (!stt.hasEncoding()) return std::nullopt; foreachFieldAndTypeInSparseTensor( - rtp, - [&fields](Type fieldType, unsigned fieldIdx, - SparseTensorFieldKind /*fieldKind*/, unsigned /*dim*/, + stt, + [&fields](Type fieldType, FieldIndex fieldIdx, + SparseTensorFieldKind /*fieldKind*/, Level /*lvl*/, DimLevelType /*dlt*/) -> bool { assert(fieldIdx == fields.size()); fields.push_back(fieldType); @@ -60,9 +63,7 @@ convertSparseTensorType(RankedTensorType rtp, SmallVectorImpl &fields) { SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() { addConversion([](Type type) { return type; }); - addConversion([&](RankedTensorType rtp, SmallVectorImpl &fields) { - return convertSparseTensorType(rtp, fields); - }); + addConversion(convertSparseTensorType); // Required by scf.for 1:N type conversion. 
addSourceMaterialization([](OpBuilder &builder, RankedTensorType tp, @@ -81,9 +82,9 @@ SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() { //===----------------------------------------------------------------------===// Value SparseTensorSpecifier::getInitValue(OpBuilder &builder, Location loc, - RankedTensorType rtp) { + SparseTensorType stt) { return builder.create( - loc, StorageSpecifierType::get(getSparseTensorEncoding(rtp))); + loc, StorageSpecifierType::get(stt.getEncoding())); } Value SparseTensorSpecifier::getSpecifierField(OpBuilder &builder, Location loc, @@ -110,34 +111,30 @@ void SparseTensorSpecifier::setSpecifierField(OpBuilder &builder, Location loc, //===----------------------------------------------------------------------===// Value sparse_tensor::SparseTensorDescriptor::getIdxMemRefOrView( - OpBuilder &builder, Location loc, unsigned idxDim) const { - auto enc = getSparseTensorEncoding(rType); - unsigned cooStart = getCOOStart(enc); - unsigned idx = idxDim >= cooStart ? cooStart : idxDim; - Value buffer = getMemRefField(SparseTensorFieldKind::IdxMemRef, idx); - if (idxDim >= cooStart) { - unsigned rank = enc.getDimLevelType().size(); - Value stride = constantIndex(builder, loc, rank - cooStart); - Value size = getIdxMemSize(builder, loc, cooStart); - size = builder.create(loc, size, stride); - buffer = builder.create( - loc, buffer, - /*offset=*/ValueRange{constantIndex(builder, loc, idxDim - cooStart)}, - /*size=*/ValueRange{size}, - /*step=*/ValueRange{stride}); - } - return buffer; + OpBuilder &builder, Location loc, Level idxLvl) const { + const Level cooStart = getCOOStart(rType.getEncoding()); + if (idxLvl < cooStart) + return getMemRefField(SparseTensorFieldKind::IdxMemRef, idxLvl); + + Value stride = constantIndex(builder, loc, rType.getLvlRank() - cooStart); + Value size = getIdxMemSize(builder, loc, cooStart); + size = builder.create(loc, size, stride); + return builder.create( + loc, getMemRefField(SparseTensorFieldKind::IdxMemRef, cooStart), + /*offset=*/ValueRange{constantIndex(builder, loc, idxLvl - cooStart)}, + /*size=*/ValueRange{size}, + /*step=*/ValueRange{stride}); } //===----------------------------------------------------------------------===// // Public methods. //===----------------------------------------------------------------------===// -constexpr uint64_t kDataFieldStartingIdx = 0; +constexpr FieldIndex kDataFieldStartingIdx = 0; void sparse_tensor::foreachFieldInSparseTensor( const SparseTensorEncodingAttr enc, - llvm::function_ref callback) { assert(enc); @@ -146,23 +143,22 @@ void sparse_tensor::foreachFieldInSparseTensor( if (!(callback(idx, kind, dim, dlt))) \ return; - unsigned rank = enc.getDimLevelType().size(); - unsigned end = getCOOStart(enc); - if (end != rank) - end += 1; - static_assert(kDataFieldStartingIdx == 0); - unsigned fieldIdx = kDataFieldStartingIdx; + const auto lvlTypes = enc.getDimLevelType(); + const Level lvlRank = enc.getLvlRank(); + const Level cooStart = getCOOStart(enc); + const Level end = cooStart == lvlRank ? cooStart : cooStart + 1; + FieldIndex fieldIdx = kDataFieldStartingIdx; // Per-dimension storage. - for (unsigned r = 0; r < end; r++) { + for (Level l = 0; l < end; l++) { // Dimension level types apply in order to the reordered dimension. // As a result, the compound type can be constructed directly in the given // order. 
- auto dlt = getDimLevelType(enc, r); + const auto dlt = lvlTypes[l]; if (isCompressedDLT(dlt)) { - RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::PtrMemRef, r, dlt); - RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt); + RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::PtrMemRef, l, dlt); + RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, l, dlt); } else if (isSingletonDLT(dlt)) { - RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt); + RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, l, dlt); } else { assert(isDenseDLT(dlt)); // no fields } @@ -179,16 +175,16 @@ void sparse_tensor::foreachFieldInSparseTensor( } void sparse_tensor::foreachFieldAndTypeInSparseTensor( - RankedTensorType rType, - llvm::function_ref callback) { - auto enc = getSparseTensorEncoding(rType); + const auto enc = stt.getEncoding(); assert(enc); // Construct the basic types. Type idxType = enc.getIndexType(); Type ptrType = enc.getPointerType(); - Type eltType = rType.getElementType(); + Type eltType = stt.getElementType(); Type metaDataType = StorageSpecifierType::get(enc); // memref pointers @@ -201,17 +197,17 @@ void sparse_tensor::foreachFieldAndTypeInSparseTensor( foreachFieldInSparseTensor( enc, [metaDataType, ptrMemType, idxMemType, valMemType, - callback](unsigned fieldIdx, SparseTensorFieldKind fieldKind, - unsigned dim, DimLevelType dlt) -> bool { + callback](FieldIndex fieldIdx, SparseTensorFieldKind fieldKind, + Level lvl, DimLevelType dlt) -> bool { switch (fieldKind) { case SparseTensorFieldKind::StorageSpec: - return callback(metaDataType, fieldIdx, fieldKind, dim, dlt); + return callback(metaDataType, fieldIdx, fieldKind, lvl, dlt); case SparseTensorFieldKind::PtrMemRef: - return callback(ptrMemType, fieldIdx, fieldKind, dim, dlt); + return callback(ptrMemType, fieldIdx, fieldKind, lvl, dlt); case SparseTensorFieldKind::IdxMemRef: - return callback(idxMemType, fieldIdx, fieldKind, dim, dlt); + return callback(idxMemType, fieldIdx, fieldKind, lvl, dlt); case SparseTensorFieldKind::ValMemRef: - return callback(valMemType, fieldIdx, fieldKind, dim, dlt); + return callback(valMemType, fieldIdx, fieldKind, lvl, dlt); }; llvm_unreachable("unrecognized field kind"); }); @@ -220,8 +216,8 @@ void sparse_tensor::foreachFieldAndTypeInSparseTensor( unsigned sparse_tensor::getNumFieldsFromEncoding(SparseTensorEncodingAttr enc) { unsigned numFields = 0; foreachFieldInSparseTensor(enc, - [&numFields](unsigned, SparseTensorFieldKind, - unsigned, DimLevelType) -> bool { + [&numFields](FieldIndex, SparseTensorFieldKind, + Level, DimLevelType) -> bool { numFields++; return true; }); @@ -232,8 +228,9 @@ unsigned sparse_tensor::getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc) { unsigned numFields = 0; // one value memref foreachFieldInSparseTensor(enc, - [&numFields](unsigned fidx, SparseTensorFieldKind, - unsigned, DimLevelType) -> bool { + [&numFields](FieldIndex fidx, + SparseTensorFieldKind, Level, + DimLevelType) -> bool { if (fidx >= kDataFieldStartingIdx) numFields++; return true; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h index 193c227..40207d7 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h @@ -15,6 +15,7 @@ #include "mlir/Conversion/LLVMCommon/StructBuilder.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" #include "mlir/Transforms/DialectConversion.h" @@ -33,14 +34,14 @@ namespace sparse_tensor { // capacity and the used size resides in the storage_specifier struct. // // struct { -// ; per-dimension d: +// ; per-level l: // ; if dense: // // ; if compresed: -// memref pointers-d ; pointers for sparse dim d -// memref indices-d ; indices for sparse dim d +// memref pointers-l ; pointers for sparse level l +// memref indices-l ; indices for sparse level l // ; if singleton: -// memref indices-d ; indices for singleton dim d +// memref indices-l ; indices for singleton level l // // memref values ; values // @@ -50,13 +51,13 @@ namespace sparse_tensor { // } // }; // -// In addition, for a "trailing COO region", defined as a compressed -// dimension followed by one ore more singleton dimensions, the default -// SOA storage that is inherent to the TACO format is optimized into an -// AOS storage where all indices of a stored element appear consecutively. -// In such cases, a special operation (sparse_tensor.indices_buffer) must -// be used to access the AOS index array. In the code below, the method -// `getCOOStart` is used to find the start of the "trailing COO region". +// In addition, for a "trailing COO region", defined as a compressed level +// followed by one ore more singleton levels, the default SOA storage that +// is inherent to the TACO format is optimized into an AOS storage where +// all indices of a stored element appear consecutively. In such cases, +// a special operation (sparse_tensor.indices_buffer) must be used to +// access the AOS index array. In the code below, the method `getCOOStart` +// is used to find the start of the "trailing COO region". // // Examples. // @@ -64,13 +65,13 @@ namespace sparse_tensor { // memref ; pointers-1 // memref ; indices-1 // memref ; values -// struct<(array<2 x i64>, array<3 x i64>)>) ; dim0, dim1, 3xsizes +// struct<(array<2 x i64>, array<3 x i64>)>) ; lvl0, lvl1, 3xsizes // // #COO storage of 2-dim matrix yields // memref, ; pointers-0, essentially [0,sz] // memref ; AOS index storage // memref ; values -// struct<(array<2 x i64>, array<3 x i64>)>) ; dim0, dim1, 3xsizes +// struct<(array<2 x i64>, array<3 x i64>)>) ; lvl0, lvl1, 3xsizes // //===----------------------------------------------------------------------===// @@ -88,6 +89,15 @@ static_assert(static_cast(SparseTensorFieldKind::IdxMemRef) == static_assert(static_cast(SparseTensorFieldKind::ValMemRef) == static_cast(StorageSpecifierKind::ValMemSize)); +/// The type of field indices. This alias is to help code be more +/// self-documenting; unfortunately it is not type-checked, so it only +/// provides documentation rather than doing anything to prevent mixups. +using FieldIndex = unsigned; + +// TODO: Functions/methods marked with [NUMFIELDS] might should use +// `FieldIndex` for their return type, via the same reasoning for why +// `Dimension`/`Level` are used both for identifiers and ranks. + /// For each field that will be allocated for the given sparse tensor encoding, /// calls the callback with the corresponding field index, field kind, dimension /// (for sparse tensor level memrefs) and dimlevelType. @@ -97,25 +107,26 @@ static_assert(static_cast(SparseTensorFieldKind::ValMemRef) == /// tensor fields instead of relying on ad-hoc index computation. 
void foreachFieldInSparseTensor( SparseTensorEncodingAttr, - llvm::function_ref); + llvm::function_ref); /// Same as above, except that it also builds the Type for the corresponding /// field. void foreachFieldAndTypeInSparseTensor( - RankedTensorType, - llvm::function_ref); /// Gets the total number of fields for the given sparse tensor encoding. +// TODO: See note [NUMFIELDS]. unsigned getNumFieldsFromEncoding(SparseTensorEncodingAttr enc); /// Gets the total number of data fields (index arrays, pointer arrays, and a /// value array) for the given sparse tensor encoding. +// TODO: See note [NUMFIELDS]. unsigned getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc); inline StorageSpecifierKind toSpecifierKind(SparseTensorFieldKind kind) { @@ -138,47 +149,48 @@ public: /// Getters: get the field index for required field. /// - unsigned getMemRefFieldIndex(SparseTensorFieldKind kind, - std::optional dim) const { - return getFieldIndexAndStride(kind, dim).first; + FieldIndex getMemRefFieldIndex(SparseTensorFieldKind kind, + std::optional lvl) const { + return getFieldIndexAndStride(kind, lvl).first; } - unsigned getMemRefFieldIndex(StorageSpecifierKind kind, - std::optional dim) const { - return getMemRefFieldIndex(toFieldKind(kind), dim); + FieldIndex getMemRefFieldIndex(StorageSpecifierKind kind, + std::optional lvl) const { + return getMemRefFieldIndex(toFieldKind(kind), lvl); } + // TODO: See note [NUMFIELDS]. static unsigned getNumFieldsFromEncoding(SparseTensorEncodingAttr enc) { return sparse_tensor::getNumFieldsFromEncoding(enc); } static void foreachFieldInSparseTensor( const SparseTensorEncodingAttr enc, - llvm::function_ref callback) { return sparse_tensor::foreachFieldInSparseTensor(enc, callback); } - std::pair + std::pair getFieldIndexAndStride(SparseTensorFieldKind kind, - std::optional dim) const { - unsigned fieldIdx = -1u; + std::optional lvl) const { + FieldIndex fieldIdx = -1u; unsigned stride = 1; if (kind == SparseTensorFieldKind::IdxMemRef) { - assert(dim.has_value()); - unsigned cooStart = getCOOStart(enc); - unsigned rank = enc.getDimLevelType().size(); - if (dim.value() >= cooStart && dim.value() < rank) { - dim = cooStart; - stride = rank - cooStart; + assert(lvl.has_value()); + const Level cooStart = getCOOStart(enc); + const Level lvlRank = enc.getLvlRank(); + if (lvl.value() >= cooStart && lvl.value() < lvlRank) { + lvl = cooStart; + stride = lvlRank - cooStart; } } foreachFieldInSparseTensor( enc, - [dim, kind, &fieldIdx](unsigned fIdx, SparseTensorFieldKind fKind, - unsigned fDim, DimLevelType dlt) -> bool { - if ((dim && fDim == dim.value() && kind == fKind) || + [lvl, kind, &fieldIdx](FieldIndex fIdx, SparseTensorFieldKind fKind, + Level fLvl, DimLevelType dlt) -> bool { + if ((lvl && fLvl == lvl.value() && kind == fKind) || (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) { fieldIdx = fIdx; // Returns false to break the iteration. @@ -187,13 +199,27 @@ public: return true; }); assert(fieldIdx != -1u); - return std::pair(fieldIdx, stride); + return std::pair(fieldIdx, stride); } private: SparseTensorEncodingAttr enc; }; +// FIXME: Functions/methods marked with [CLARIFY_DIM_LVL] require +// clarification on whether their "dim" argument should actually +// be `Level` or `Dimension`. In particular, it's unclear whether +// `StorageSpecifierKind::DimSize` actually means to refer to dimension-sizes +// vs level-sizes. If it's the latter (which seems unlikely), then all the +// noted functions should use the `Level` type alias. 
If it's the former, +// then the functions which specifically use `DimSize` should be changed +// to use the `Dimension` type alias; however, the functions which take +// an unknown `StorageSpecifierKind` must be adjusted to ensure that they +// correctly interpret the "dim" argument since the interpretation depends +// on the `StorageSpecifierKind` value. Since wrengr couldn't figure this +// out from context, Peiming or Bixia should review these functions and +// update them as appropriate. + class SparseTensorSpecifier { public: explicit SparseTensorSpecifier(Value specifier) @@ -201,18 +227,21 @@ public: // Undef value for dimension sizes, all zero value for memory sizes. static Value getInitValue(OpBuilder &builder, Location loc, - RankedTensorType rtp); + SparseTensorType stt); /*implicit*/ operator Value() { return specifier; } + // FIXME: see note [CLARIFY_DIM_LVL]. Value getSpecifierField(OpBuilder &builder, Location loc, StorageSpecifierKind kind, std::optional dim); + // FIXME: see note [CLARIFY_DIM_LVL]. void setSpecifierField(OpBuilder &builder, Location loc, Value v, StorageSpecifierKind kind, std::optional dim); + // FIXME: see note [CLARIFY_DIM_LVL]. Type getFieldType(StorageSpecifierKind kind, std::optional dim) { return specifier.getType().getFieldType(kind, dim); } @@ -229,11 +258,10 @@ private: template class SparseTensorDescriptorImpl { protected: - SparseTensorDescriptorImpl(Type tp, ValueArrayRef fields) - : rType(tp.cast()), fields(fields) { - assert(getSparseTensorEncoding(tp) && - getNumFieldsFromEncoding(getSparseTensorEncoding(tp)) == - fields.size()); + SparseTensorDescriptorImpl(SparseTensorType stt, ValueArrayRef fields) + : rType(stt), fields(fields) { + assert(stt.hasEncoding() && + getNumFieldsFromEncoding(stt.getEncoding()) == getNumFields()); // We should make sure the class is trivially copyable (and should be small // enough) such that we can pass it by value. static_assert(std::is_trivially_copyable_v< @@ -241,19 +269,21 @@ protected: } public: - unsigned getMemRefFieldIndex(SparseTensorFieldKind kind, - std::optional dim) const { + FieldIndex getMemRefFieldIndex(SparseTensorFieldKind kind, + std::optional lvl) const { // Delegates to storage layout. - StorageLayout layout(getSparseTensorEncoding(rType)); - return layout.getMemRefFieldIndex(kind, dim); + StorageLayout layout(rType.getEncoding()); + return layout.getMemRefFieldIndex(kind, lvl); } + // TODO: See note [NUMFIELDS]. unsigned getNumFields() const { return fields.size(); } /// /// Getters: get the value for required field. /// + // FIXME: see note [CLARIFY_DIM_LVL]. Value getSpecifierField(OpBuilder &builder, Location loc, StorageSpecifierKind kind, std::optional dim) const { @@ -261,12 +291,13 @@ public: return md.getSpecifierField(builder, loc, kind, dim); } + // FIXME: see note [CLARIFY_DIM_LVL]. 
Value getDimSize(OpBuilder &builder, Location loc, unsigned dim) const { return getSpecifierField(builder, loc, StorageSpecifierKind::DimSize, dim); } - Value getPtrMemRef(unsigned ptrDim) const { - return getMemRefField(SparseTensorFieldKind::PtrMemRef, ptrDim); + Value getPtrMemRef(Level lvl) const { + return getMemRefField(SparseTensorFieldKind::PtrMemRef, lvl); } Value getValMemRef() const { @@ -274,23 +305,23 @@ public: } Value getMemRefField(SparseTensorFieldKind kind, - std::optional dim) const { - return getField(getMemRefFieldIndex(kind, dim)); + std::optional lvl) const { + return getField(getMemRefFieldIndex(kind, lvl)); } - Value getMemRefField(unsigned fidx) const { + Value getMemRefField(FieldIndex fidx) const { assert(fidx < fields.size() - 1); return getField(fidx); } - Value getPtrMemSize(OpBuilder &builder, Location loc, unsigned dim) const { + Value getPtrMemSize(OpBuilder &builder, Location loc, Level lvl) const { return getSpecifierField(builder, loc, StorageSpecifierKind::PtrMemSize, - dim); + lvl); } - Value getIdxMemSize(OpBuilder &builder, Location loc, unsigned dim) const { + Value getIdxMemSize(OpBuilder &builder, Location loc, Level lvl) const { return getSpecifierField(builder, loc, StorageSpecifierKind::IdxMemSize, - dim); + lvl); } Value getValMemSize(OpBuilder &builder, Location loc) const { @@ -299,54 +330,46 @@ public: } Type getMemRefElementType(SparseTensorFieldKind kind, - std::optional dim) const { - return getMemRefField(kind, dim) - .getType() - .template cast() - .getElementType(); + std::optional lvl) const { + return getMemRefType(getMemRefField(kind, lvl)).getElementType(); } - Value getField(unsigned fidx) const { + Value getField(FieldIndex fidx) const { assert(fidx < fields.size()); return fields[fidx]; } ValueRange getMemRefFields() const { - ValueRange ret = fields; // Drop the last metadata fields. - return ret.slice(0, fields.size() - 1); + return fields.drop_back(); } - std::pair - getIdxMemRefIndexAndStride(unsigned idxDim) const { - StorageLayout layout(getSparseTensorEncoding(rType)); - return layout.getFieldIndexAndStride(SparseTensorFieldKind::IdxMemRef, - idxDim); + std::pair getIdxMemRefIndexAndStride(Level lvl) const { + StorageLayout layout(rType.getEncoding()); + return layout.getFieldIndexAndStride(SparseTensorFieldKind::IdxMemRef, lvl); } Value getAOSMemRef() const { - auto enc = getSparseTensorEncoding(rType); - unsigned cooStart = getCOOStart(enc); - assert(cooStart < enc.getDimLevelType().size()); + const Level cooStart = getCOOStart(rType.getEncoding()); + assert(cooStart < rType.getLvlRank()); return getMemRefField(SparseTensorFieldKind::IdxMemRef, cooStart); } - RankedTensorType getTensorType() const { return rType; } + RankedTensorType getRankedTensorType() const { return rType; } ValueArrayRef getFields() const { return fields; } protected: - RankedTensorType rType; + SparseTensorType rType; ValueArrayRef fields; }; /// Uses ValueRange for immutable descriptors. class SparseTensorDescriptor : public SparseTensorDescriptorImpl { public: - SparseTensorDescriptor(Type tp, ValueRange buffers) - : SparseTensorDescriptorImpl(tp, buffers) {} + SparseTensorDescriptor(SparseTensorType stt, ValueRange buffers) + : SparseTensorDescriptorImpl(stt, buffers) {} - Value getIdxMemRefOrView(OpBuilder &builder, Location loc, - unsigned idxDim) const; + Value getIdxMemRefOrView(OpBuilder &builder, Location loc, Level lvl) const; }; /// Uses SmallVectorImpl & for mutable descriptors. 
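For illustration only, here is a minimal sketch of how the descriptor API above is typically exercised; it is not part of the patch, and it assumes `builder`, `loc`, and a lowered tuple value `tensor` are in scope. Only accessors declared in this header are used:

    SparseTensorDescriptor desc = getDescriptorFromTensorTuple(tensor);
    Value vals = desc.getValMemRef();                  // the values array
    Value ptr0 = desc.getPtrMemRef(0);                 // pointer memref of level 0
    Value isz0 = desc.getIdxMemSize(builder, loc, 0);  // used size of the level-0 index memref
    auto [fidx, stride] = desc.getIdxMemRefIndexAndStride(0);  // AOS-aware lookup for COO levels
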
@@ -359,8 +382,9 @@ public: class MutSparseTensorDescriptor : public SparseTensorDescriptorImpl &> { public: - MutSparseTensorDescriptor(Type tp, SmallVectorImpl &buffers) - : SparseTensorDescriptorImpl &>(tp, buffers) {} + MutSparseTensorDescriptor(SparseTensorType stt, + SmallVectorImpl &buffers) + : SparseTensorDescriptorImpl &>(stt, buffers) {} // Allow implicit type conversion from mutable descriptors to immutable ones // (but not vice versa). @@ -373,21 +397,22 @@ public: /// required field. /// - void setMemRefField(SparseTensorFieldKind kind, std::optional dim, + void setMemRefField(SparseTensorFieldKind kind, std::optional lvl, Value v) { - fields[getMemRefFieldIndex(kind, dim)] = v; + fields[getMemRefFieldIndex(kind, lvl)] = v; } - void setMemRefField(unsigned fidx, Value v) { + void setMemRefField(FieldIndex fidx, Value v) { assert(fidx < fields.size() - 1); fields[fidx] = v; } - void setField(unsigned fidx, Value v) { + void setField(FieldIndex fidx, Value v) { assert(fidx < fields.size()); fields[fidx] = v; } + // FIXME: see note [CLARIFY_DIM_LVL]. void setSpecifierField(OpBuilder &builder, Location loc, StorageSpecifierKind kind, std::optional dim, Value v) { @@ -401,14 +426,15 @@ public: std::nullopt, v); } - void setIdxMemSize(OpBuilder &builder, Location loc, unsigned dim, Value v) { - setSpecifierField(builder, loc, StorageSpecifierKind::IdxMemSize, dim, v); + void setIdxMemSize(OpBuilder &builder, Location loc, Level lvl, Value v) { + setSpecifierField(builder, loc, StorageSpecifierKind::IdxMemSize, lvl, v); } - void setPtrMemSize(OpBuilder &builder, Location loc, unsigned dim, Value v) { - setSpecifierField(builder, loc, StorageSpecifierKind::PtrMemSize, dim, v); + void setPtrMemSize(OpBuilder &builder, Location loc, Level lvl, Value v) { + setSpecifierField(builder, loc, StorageSpecifierKind::PtrMemSize, lvl, v); } + // FIXME: see note [CLARIFY_DIM_LVL]. 
void setDimSize(OpBuilder &builder, Location loc, unsigned dim, Value v) { setSpecifierField(builder, loc, StorageSpecifierKind::DimSize, dim, v); } @@ -428,19 +454,21 @@ inline Value genTuple(OpBuilder &builder, Location loc, Type tp, inline Value genTuple(OpBuilder &builder, Location loc, SparseTensorDescriptor desc) { - return genTuple(builder, loc, desc.getTensorType(), desc.getFields()); + return genTuple(builder, loc, desc.getRankedTensorType(), desc.getFields()); } inline SparseTensorDescriptor getDescriptorFromTensorTuple(Value tensor) { auto tuple = getTuple(tensor); - return SparseTensorDescriptor(tuple.getResultTypes()[0], tuple.getInputs()); + SparseTensorType stt(tuple.getResultTypes()[0].cast()); + return SparseTensorDescriptor(stt, tuple.getInputs()); } inline MutSparseTensorDescriptor getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl &fields) { auto tuple = getTuple(tensor); fields.assign(tuple.getInputs().begin(), tuple.getInputs().end()); - return MutSparseTensorDescriptor(tuple.getResultTypes()[0], fields); + SparseTensorType stt(tuple.getResultTypes()[0].cast()); + return MutSparseTensorDescriptor(stt, fields); } } // namespace sparse_tensor diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp index f3cf2ce..3555873 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -26,6 +26,7 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Transforms/Transforms.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" #include "mlir/Dialect/SparseTensor/Utils/Merger.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -225,18 +226,20 @@ static bool findAffine(Merger &merger, unsigned tensor, unsigned dim, static unsigned getNumCompoundAffineOnSparseDims(AffineMap affineMap, Value tensor) { unsigned num = 0; - auto enc = getSparseTensorEncoding(tensor.getType()); + const auto enc = getSparseTensorEncoding(tensor.getType()); if (enc) { - ArrayRef exps = affineMap.getResults(); - for (unsigned rank = 0; rank < exps.size(); rank++) { - auto aidx = toOrigDim(enc, rank); - auto affine = exps[aidx]; - if (!affine.isa()) - if (!isDenseDLT(getDimLevelType(enc, rank))) - num++; + const ArrayRef exps = affineMap.getResults(); + const Level lvlRank = enc.getLvlRank(); + assert(static_cast(exps.size()) == lvlRank); + for (Level l = 0; l < lvlRank; l++) { + // FIXME: `toOrigDim` is deprecated. + const Dimension d = toOrigDim(enc, l); + // FIXME: there's some dim/lvl confusion here; since `d` isn't + // guaranteed to be in bounds (for non-permutations). 
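+      // For a permutation encoding, `toOrigDim(enc, l)` recovers the dimension
+      // stored at level `l`; e.g., with dimOrdering affine_map<(i,j) -> (j,i)>,
+      // level 0 stores dimension 1 and level 1 stores dimension 0.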
+ if (!exps[d].isa() && !enc.isDenseLvl(l)) + num++; } } - return num; } @@ -252,10 +255,8 @@ static unsigned getNumCompoundAffineOnSparseDims(linalg::GenericOp op) { static bool hasCompoundAffineOnSparseOut(linalg::GenericOp op) { OpOperand *out = op.getDpsInitOperand(0); - auto enc = getSparseTensorEncoding(out->get().getType()); - if (!enc || enc.isAllDense()) + if (getSparseTensorType(out->get()).isAllDense()) return false; - return getNumCompoundAffineOnSparseDims(op.getMatchingIndexingMap(out), out->get()); } @@ -269,16 +270,18 @@ static bool findSparseAnnotations(CodegenEnv &env) { bool annotated = false; unsigned filterLdx = env.merger().getFilterLoopStartingIdx(); for (OpOperand &t : env.op()->getOpOperands()) { - auto map = env.op().getMatchingIndexingMap(&t); - auto enc = getSparseTensorEncoding(t.get().getType()); + const auto map = env.op().getMatchingIndexingMap(&t); + const auto enc = getSparseTensorEncoding(t.get().getType()); if (enc) annotated = true; - assert(map.getNumResults() == env.op().getRank(&t)); - for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { - unsigned tensor = t.getOperandNumber(); - AffineExpr a = map.getResult(toOrigDim(enc, d)); - if (!findAffine(env.merger(), tensor, d, a, getDimLevelType(enc, d), - filterLdx)) + const Level lvlRank = map.getNumResults(); + assert(!enc || lvlRank == enc.getLvlRank()); + assert(env.op().getRank(&t) == lvlRank); + for (Level l = 0; l < lvlRank; l++) { + const unsigned tensor = t.getOperandNumber(); + // FIXME: `toOrigDim` is deprecated. + const AffineExpr a = map.getResult(toOrigDim(enc, l)); + if (!findAffine(env.merger(), tensor, l, a, enc.getLvlType(l), filterLdx)) return false; // inadmissible affine expression } } @@ -440,15 +443,15 @@ static bool computeIterationGraph(CodegenEnv &env, unsigned mask, OpOperand *skip = nullptr) { // Set up an n x n from/to adjacency matrix of the iteration graph // for the implicit loop indices i_0 .. i_n-1. - unsigned n = env.merger().getNumLoops(); + const unsigned n = env.merger().getNumLoops(); std::vector> adjM(n, std::vector(n, false)); std::vector inDegree(n, 0); // in-degree of each node. - auto iteratorTypes = env.op().getIteratorTypesArray(); + const auto iteratorTypes = env.op().getIteratorTypesArray(); // Iterate over the indexing maps of every tensor in the tensor expression. for (OpOperand &t : env.op()->getOpOperands()) { // Get map and encoding. - auto map = env.op().getMatchingIndexingMap(&t); - auto enc = getSparseTensorEncoding(t.get().getType()); + const auto map = env.op().getMatchingIndexingMap(&t); + const auto enc = getSparseTensorEncoding(t.get().getType()); assert(map.getNumDims() + getNumCompoundAffineOnSparseDims(env.op()) == n); // Skip dense tensor constraints when not requested. if (!(mask & SortMask::kIncludeDense) && !enc) @@ -457,16 +460,18 @@ static bool computeIterationGraph(CodegenEnv &env, unsigned mask, // by default) puts an ordering constraint on the loop indices. For // example, the tensor expresion A_ijk forces the ordering i < j < k // on the loop indices if no explicit dimension ordering is given. - for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { - AffineExpr ta = map.getResult(toOrigDim(enc, d)); + const Level lvlRank = map.getNumResults(); + assert(!enc || lvlRank == enc.getLvlRank()); + for (Level l = 0; l < lvlRank; l++) { + // FIXME: `toOrigDim` is deprecated. 
+ AffineExpr ta = map.getResult(toOrigDim(enc, l)); std::optional tldx = - env.merger().getLoopIdx(t.getOperandNumber(), d); + env.merger().getLoopIdx(t.getOperandNumber(), l); // Filter loops should be constructed after all the dependent loops, // i.e., d0 + d1 < filter_loop(d0 + d1) if (tldx && env.merger().isFilterLoop(*tldx)) { - assert(!ta.isa() && - !isDenseDLT(getDimLevelType(enc, d))); + assert(!ta.isa() && !isDenseDLT(enc.getLvlType(l))); addAffineOrderings(adjM, inDegree, ta, AffineExpr(), std::nullopt, tldx); // Now that the ordering of affine expression is captured by filter @@ -481,10 +486,11 @@ static bool computeIterationGraph(CodegenEnv &env, unsigned mask, if (&t == skip) continue; - if (d > 0) { - AffineExpr fa = map.getResult(toOrigDim(enc, d - 1)); + if (l > 0) { + // FIXME: `toOrigDim` is deprecated. + AffineExpr fa = map.getResult(toOrigDim(enc, l - 1)); std::optional fldx = - env.merger().getLoopIdx(t.getOperandNumber(), d - 1); + env.merger().getLoopIdx(t.getOperandNumber(), l - 1); // Applying order constraints on every pair of dimExpr between two // compound affine expressions can sometime too strict: @@ -576,8 +582,11 @@ static void genBuffers(CodegenEnv &env, OpBuilder &builder) { /// Generates index for load/store on sparse tensor. static Value genIndex(CodegenEnv &env, OpOperand *t) { auto map = env.op().getMatchingIndexingMap(t); - auto enc = getSparseTensorEncoding(t->get().getType()); - AffineExpr a = map.getResult(toOrigDim(enc, map.getNumResults() - 1)); + const auto stt = getSparseTensorType(t->get()); + const Level lvlRank = stt.getLvlRank(); + assert(static_cast(map.getNumResults()) == lvlRank); + // FIXME: `toOrigDim` is deprecated. + AffineExpr a = map.getResult(toOrigDim(stt.getEncoding(), lvlRank - 1)); assert(a.getKind() == AffineExprKind::DimId); unsigned idx = a.cast().getPosition(); return env.getLoopIdxValue(idx); @@ -589,15 +598,16 @@ static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t, linalg::GenericOp op = env.op(); unsigned tensor = t->getOperandNumber(); auto map = op.getMatchingIndexingMap(t); - auto enc = getSparseTensorEncoding(t->get().getType()); - unsigned rank = map.getNumResults(); - if (enc) { + const auto stt = getSparseTensorType(t->get()); + if (stt.hasEncoding()) { Value pidx = env.emitter().getPidxs()[tensor].back(); assert(pidx); args.push_back(pidx); // position index } else { - for (unsigned d = 0; d < rank; d++) { - AffineExpr a = map.getResult(d); + const Level lvlRank = stt.getLvlRank(); + assert(static_cast(map.getNumResults()) == lvlRank); + for (Level l = 0; l < lvlRank; l++) { + AffineExpr a = map.getResult(l); args.push_back(env.emitter().genAffine(builder, a, op.getLoc())); } } @@ -861,11 +871,14 @@ static void genInvariants(CodegenEnv &env, OpBuilder &builder, unsigned exp, linalg::GenericOp op = env.op(); OpOperand &t = op->getOpOperand(env.exp(exp).tensor); auto map = op.getMatchingIndexingMap(&t); - auto enc = getSparseTensorEncoding(t.get().getType()); - for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { - AffineExpr a = map.getResult(toOrigDim(enc, d)); + const auto stt = getSparseTensorType(t.get()); + const Level lvlRank = stt.getLvlRank(); + assert(static_cast(map.getNumResults()) == lvlRank); + for (Level l = 0; l < lvlRank; l++) { + // FIXME: `toOrigDim` is deprecated. 
+ AffineExpr a = map.getResult(toOrigDim(stt.getEncoding(), l)); std::optional sldx = - env.merger().getLoopIdx(t.getOperandNumber(), d); + env.merger().getLoopIdx(t.getOperandNumber(), l); if (sldx && env.merger().isFilterLoop(*sldx)) { if (!env.getLoopIdxValue(*sldx)) // The filter loops has not been constructed. @@ -1002,6 +1015,7 @@ static Operation *genFor(CodegenEnv &env, OpBuilder &builder, bool isOuter, OpOperand *t = &op->getOpOperand(tid); auto enc = getSparseTensorEncoding(t->get().getType()); // Retrieves the affine expression for the filter loop. + // FIXME: `toOrigDim` is deprecated. AffineExpr a = op.getMatchingIndexingMap(t).getResult(toOrigDim(enc, dim)); return env.emitter().enterFilterLoopOverTensorAtDim(builder, loc, tid, @@ -1192,20 +1206,22 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp, static void genConstantDenseAddressFromLevel(CodegenEnv &env, OpBuilder &builder, unsigned tid, - unsigned lvl) { + Level lvl) { // TODO: Handle affine expression on output tensor. linalg::GenericOp op = env.op(); assert(tid < op.getNumDpsInputs()); OpOperand *input = op.getDpsInputOperands()[tid]; ArrayRef affines = op.getMatchingIndexingMap(input).getResults(); - auto enc = getSparseTensorEncoding(input->get().getType()); + const auto enc = getSparseTensorEncoding(input->get().getType()); if (enc) { - for (unsigned i = lvl, e = affines.size(); i < e; i++) { - AffineExpr affine = affines[toOrigDim(enc, i)]; - if (isDenseDLT(getDimLevelType(enc, i)) && - affine.isa()) + const Level lvlRank = enc.getLvlRank(); + assert(affines.size() == static_cast(lvlRank)); + for (Level l = lvl; l < lvlRank; l++) { + // FIXME: `toOrigDim` is deprecated. + AffineExpr affine = affines[toOrigDim(enc, l)]; + if (enc.isDenseLvl(l) && affine.isa()) env.emitter().genDenseAffineAddressAtCurLevel( - builder, op.getLoc(), input->getOperandNumber(), i, affine); + builder, op.getLoc(), input->getOperandNumber(), l, affine); else return; // break on first non-dense non-constant level } @@ -1262,20 +1278,21 @@ static bool translateBitsToTidDimPairs( // We only handle affine expression on input tensors (for now). return; OpOperand *operand = &op->getOpOperand(tid); - auto enc = getSparseTensorEncoding(operand->get().getType()); + const auto stt = getSparseTensorType(operand->get()); // Non-annotated dense tensors requires no special handling. - if (!enc) + if (!stt.hasEncoding()) return; ArrayRef affines = op.getMatchingIndexingMap(operand).getResults(); - assert(affines.size() == enc.getDimLevelType().size()); - for (unsigned i = 0, e = affines.size(); i < e; i++) { - AffineExpr exp = affines[toOrigDim(enc, i)]; + const Level lvlRank = stt.getLvlRank(); + assert(affines.size() == static_cast(lvlRank)); + for (Level l = 0; l < lvlRank; l++) { + // FIXME: `toOrigDim` is deprecated. + AffineExpr exp = affines[toOrigDim(stt.getEncoding(), l)]; // Skip simple affine expression and non dense dimensions (which has // it own filter loop). - if (exp.isa() || - !isDenseDLT(getDimLevelType(enc, i))) + if (exp.isa() || !stt.isDenseLvl(l)) continue; // Constant affine expression are handled in genLoop @@ -1292,7 +1309,7 @@ static bool translateBitsToTidDimPairs( // might be accepting out-of-order access between consecutive // dense levels. 
affineTids.push_back(tid); - affineDims.push_back(i); + affineDims.push_back(l); exps.push_back(exp); } } diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir index 69f45ec..c65245e 100644 --- a/mlir/test/Dialect/SparseTensor/invalid.mlir +++ b/mlir/test/Dialect/SparseTensor/invalid.mlir @@ -227,7 +227,7 @@ func.func @mismatch_values_types(%arg0: tensor) -> memref< #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}> func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> i64 { - // expected-error@+1 {{redundant dimension argument for querying value memory size}} + // expected-error@+1 {{redundant level argument for querying value memory size}} %0 = sparse_tensor.storage_specifier.get %arg0 val_mem_sz at 0 : !sparse_tensor.storage_specifier<#SparseVector> to i64 return %0 : i64 @@ -238,7 +238,7 @@ func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}> func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> i64 { - // expected-error@+1 {{missing dimension argument}} + // expected-error@+1 {{missing level argument}} %0 = sparse_tensor.storage_specifier.get %arg0 idx_mem_sz : !sparse_tensor.storage_specifier<#SparseVector> to i64 return %0 : i64 @@ -249,7 +249,7 @@ func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}> func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> i64 { - // expected-error@+1 {{requested dimension out of bound}} + // expected-error@+1 {{requested level out of bound}} %0 = sparse_tensor.storage_specifier.get %arg0 dim_sz at 1 : !sparse_tensor.storage_specifier<#SparseVector> to i64 return %0 : i64 @@ -654,7 +654,7 @@ func.func @invalid_concat_less_inputs(%arg: tensor<9x4xf64, #DC>) -> tensor<9x4x func.func @invalid_concat_dim(%arg0: tensor<2x4xf64, #DC>, %arg1: tensor<3x4xf64, #DC>, %arg2: tensor<4x4xf64, #DC>) -> tensor<9x4xf64, #DC> { - // expected-error@+1 {{Failed to concatentate tensors with rank=2 on dimension=4}} + // expected-error@+1 {{Concat-dimension is out of bounds for dimension-rank (4 >= 2)}} %0 = sparse_tensor.concatenate %arg0, %arg1, %arg2 {dimension = 4 : index} : tensor<2x4xf64, #DC>, tensor<3x4xf64, #DC>, @@ -670,7 +670,7 @@ func.func @invalid_concat_dim(%arg0: tensor<2x4xf64, #DC>, func.func @invalid_concat_rank_mismatch(%arg0: tensor<2xf64, #C>, %arg1: tensor<3x4xf64, #DC>, %arg2: tensor<4x4x4xf64, #DCC>) -> tensor<9x4xf64, #DC> { - // expected-error@+1 {{The input tensor $0 has a different rank (rank=1) from the output tensor (rank=2)}} + // expected-error@+1 {{Input tensor $0 has a different rank (rank=1) from the output tensor (rank=2)}} %0 = sparse_tensor.concatenate %arg0, %arg1, %arg2 {dimension = 0 : index} : tensor<2xf64, #C>, tensor<3x4xf64, #DC>, @@ -684,7 +684,7 @@ func.func @invalid_concat_rank_mismatch(%arg0: tensor<2xf64, #C>, func.func @invalid_concat_size_mismatch_dyn(%arg0: tensor, %arg1: tensor<5x4xf64, #DC>, %arg2: tensor<4x4xf64, #DC>) -> tensor<9x4xf64, #DC> { - // expected-error@+1 {{Only statically-sized input tensors are supported.}} + // expected-error@+1 {{Input tensor $0 has dynamic shape}} %0 = sparse_tensor.concatenate %arg0, %arg1, %arg2 {dimension = 0 : index} : tensor, tensor<5x4xf64, #DC>, diff --git a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir 
b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir index cfbf14a..fa64660 100644 --- a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir +++ b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir @@ -1,7 +1,8 @@ // RUN: mlir-opt %s -split-input-file -verify-diagnostics +// expected-error@+1 {{expected a non-empty array for level types}} #a = #sparse_tensor.encoding<{dimLevelType = []}> -func.func private @scalar(%arg0: tensor) -> () // expected-error {{expected non-scalar sparse tensor}} +func.func private @scalar(%arg0: tensor) -> () // ----- @@ -35,7 +36,8 @@ func.func private @tensor_highorder_mismatch(%arg0: tensor<8xi32, #a>) -> () // ----- -#a = #sparse_tensor.encoding<{dimOrdering = affine_map<(i,j) -> (i,i)>}> // expected-error {{expected a permutation affine map for dimension ordering}} +// expected-error@+1 {{expected a permutation affine map for dimension ordering}} +#a = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"], dimOrdering = affine_map<(i,j) -> (i,i)>}> func.func private @tensor_no_permutation(%arg0: tensor<16x32xf32, #a>) -> () // ----- diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 2a1b8df..b55c247 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -2175,7 +2175,10 @@ cc_library( cc_library( name = "SparseTensorDialect", srcs = ["lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp"], - hdrs = ["include/mlir/Dialect/SparseTensor/IR/SparseTensor.h"], + hdrs = [ + "include/mlir/Dialect/SparseTensor/IR/SparseTensor.h", + "include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h", + ], includes = ["include"], deps = [ ":ArithDialect", -- 2.7.4
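
For readers new to the dim/lvl split, the following is a minimal sketch (an editor's illustration, not part of the patch) of how the helpers touched by this change are typically queried; it assumes `v` is a `Value` whose tensor type may or may not carry a sparse encoding, and that the code lives inside namespace mlir::sparse_tensor:

    const SparseTensorType stt = getSparseTensorType(v);
    if (stt.hasEncoding() && !stt.isAllDense()) {
      const Level lvlRank = stt.getLvlRank();  // rank of the storage-level space
      for (Level l = 0; l < lvlRank; l++) {
        if (stt.isDenseLvl(l))
          continue;  // dense levels store no pointer/index memrefs
        // compressed and singleton levels contribute pointer and/or index memrefs
      }
    }

Since `SparseTensorType` also wraps dense tensors, the `hasEncoding()` guard keeps the sketch valid for both the annotated and the unannotated case.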