Revert "[mlir][sparse] Refactoring: abstract sparse tensor memory scheme into a Spars...
authorStella Stamenova <stilis@microsoft.com>
Tue, 6 Dec 2022 01:20:01 +0000 (17:20 -0800)
committerStella Stamenova <stilis@microsoft.com>
Tue, 6 Dec 2022 01:20:01 +0000 (17:20 -0800)
This reverts commit 8a7e69d145ff72e7e4fc10ce6b81c3aa4794201c.

This broke the windows mlir buildbot: https://lab.llvm.org/buildbot/#/builders/13/builds/29257

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp

index e9b63a6..52f9fef 100644 (file)
@@ -75,21 +75,6 @@ inline bool isSingletonDim(RankedTensorType type, uint64_t d) {
   return isSingletonDLT(getDimLevelType(type, d));
 }
 
-/// Convenience function to test for dense dimension (0 <= d < rank).
-inline bool isDenseDim(SparseTensorEncodingAttr enc, uint64_t d) {
-  return isDenseDLT(getDimLevelType(enc, d));
-}
-
-/// Convenience function to test for compressed dimension (0 <= d < rank).
-inline bool isCompressedDim(SparseTensorEncodingAttr enc, uint64_t d) {
-  return isCompressedDLT(getDimLevelType(enc, d));
-}
-
-/// Convenience function to test for singleton dimension (0 <= d < rank).
-inline bool isSingletonDim(SparseTensorEncodingAttr enc, uint64_t d) {
-  return isSingletonDLT(getDimLevelType(enc, d));
-}
-
 //
 // Dimension level properties.
 //
index cd2fb58..2ac3f3b 100644 (file)
@@ -90,115 +90,6 @@ static Value genIndexAndValueForDense(OpBuilder &builder, Location loc,
   return val;
 }
 
-void sparse_tensor::foreachFieldInSparseTensor(
-    const SparseTensorEncodingAttr enc,
-    llvm::function_ref<bool(unsigned, SparseTensorFieldKind, unsigned,
-                            DimLevelType)>
-        callback) {
-  assert(enc);
-
-#define RETURN_ON_FALSE(idx, kind, dim, dlt)                                   \
-  if (!(callback(idx, kind, dim, dlt)))                                        \
-    return;
-
-  RETURN_ON_FALSE(dimSizesIdx, SparseTensorFieldKind::DimSizes, -1u,
-                  DimLevelType::Undef);
-  RETURN_ON_FALSE(memSizesIdx, SparseTensorFieldKind::MemSizes, -1u,
-                  DimLevelType::Undef);
-
-  static_assert(dataFieldIdx == memSizesIdx + 1);
-  unsigned fieldIdx = dataFieldIdx;
-  // Per-dimension storage.
-  for (unsigned r = 0, rank = enc.getDimLevelType().size(); r < rank; r++) {
-    // Dimension level types apply in order to the reordered dimension.
-    // As a result, the compound type can be constructed directly in the given
-    // order.
-    auto dlt = getDimLevelType(enc, r);
-    if (isCompressedDLT(dlt)) {
-      RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::PtrMemRef, r, dlt);
-      RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt);
-    } else if (isSingletonDLT(dlt)) {
-      RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt);
-    } else {
-      assert(isDenseDLT(dlt)); // no fields
-    }
-  }
-  // The values array.
-  RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::ValMemRef, -1u,
-                  DimLevelType::Undef);
-
-#undef RETURN_ON_FALSE
-}
-
-void sparse_tensor::foreachFieldAndTypeInSparseTensor(
-    RankedTensorType rType,
-    llvm::function_ref<bool(Type, unsigned, SparseTensorFieldKind, unsigned,
-                            DimLevelType)>
-        callback) {
-  auto enc = getSparseTensorEncoding(rType);
-  assert(enc);
-  // Construct the basic types.
-  Type indexType = IndexType::get(enc.getContext());
-  Type idxType = enc.getIndexType();
-  Type ptrType = enc.getPointerType();
-  Type eltType = rType.getElementType();
-  unsigned rank = rType.getShape().size();
-  // memref<rank x index> dimSizes
-  Type dimSizeType = MemRefType::get({rank}, indexType);
-  // memref<n x index> memSizes
-  Type memSizeType =
-      MemRefType::get({getNumDataFieldsFromEncoding(enc)}, indexType);
-  // memref<? x ptr>  pointers
-  Type ptrMemType = MemRefType::get({ShapedType::kDynamic}, ptrType);
-  // memref<? x idx>  indices
-  Type idxMemType = MemRefType::get({ShapedType::kDynamic}, idxType);
-  // memref<? x eltType> values
-  Type valMemType = MemRefType::get({ShapedType::kDynamic}, eltType);
-
-  foreachFieldInSparseTensor(
-      enc,
-      [dimSizeType, memSizeType, ptrMemType, idxMemType, valMemType,
-       callback](unsigned fieldIdx, SparseTensorFieldKind fieldKind,
-                 unsigned dim, DimLevelType dlt) -> bool {
-        switch (fieldKind) {
-        case SparseTensorFieldKind::DimSizes:
-          return callback(dimSizeType, fieldIdx, fieldKind, dim, dlt);
-        case SparseTensorFieldKind::MemSizes:
-          return callback(memSizeType, fieldIdx, fieldKind, dim, dlt);
-        case SparseTensorFieldKind::PtrMemRef:
-          return callback(ptrMemType, fieldIdx, fieldKind, dim, dlt);
-        case SparseTensorFieldKind::IdxMemRef:
-          return callback(idxMemType, fieldIdx, fieldKind, dim, dlt);
-        case SparseTensorFieldKind::ValMemRef:
-          return callback(valMemType, fieldIdx, fieldKind, dim, dlt);
-        };
-      });
-}
-
-unsigned sparse_tensor::getNumFieldsFromEncoding(SparseTensorEncodingAttr enc) {
-  unsigned numFields = 0;
-  foreachFieldInSparseTensor(enc,
-                             [&numFields](unsigned, SparseTensorFieldKind,
-                                          unsigned, DimLevelType) -> bool {
-                               numFields++;
-                               return true;
-                             });
-  return numFields;
-}
-
-unsigned
-sparse_tensor::getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc) {
-  unsigned numFields = 0; // one value memref
-  foreachFieldInSparseTensor(enc,
-                             [&numFields](unsigned fidx, SparseTensorFieldKind,
-                                          unsigned, DimLevelType) -> bool {
-                               if (fidx >= dataFieldIdx)
-                                 numFields++;
-                               return true;
-                             });
-  assert(numFields == getNumFieldsFromEncoding(enc) - dataFieldIdx);
-  return numFields;
-}
 //===----------------------------------------------------------------------===//
 // Sparse tensor loop emitter class implementations
 //===----------------------------------------------------------------------===//
index 4c86135..bafe752 100644 (file)
@@ -311,222 +311,8 @@ inline bool isZeroRankedTensorOrScalar(Type type) {
 }
 
 //===----------------------------------------------------------------------===//
-// SparseTensorDescriptor and helpers, manage the sparse tensor memory layout
-// scheme.
-//
-// Sparse tensor storage scheme for rank-dimensional tensor is organized
-// as a single compound type with the following fields. Note that every
-// memref with ? size actually behaves as a "vector", i.e. the stored
-// size is the capacity and the used size resides in the memSizes array.
-//
-// struct {
-//   memref<rank x index> dimSizes     ; size in each dimension
-//   memref<n x index> memSizes        ; sizes of ptrs/inds/values
-//   ; per-dimension d:
-//   ;  if dense:
-//        <nothing>
-//   ;  if compresed:
-//        memref<? x ptr>  pointers-d  ; pointers for sparse dim d
-//        memref<? x idx>  indices-d   ; indices for sparse dim d
-//   ;  if singleton:
-//        memref<? x idx>  indices-d   ; indices for singleton dim d
-//   memref<? x eltType> values        ; values
-// };
-//
-//===----------------------------------------------------------------------===//
-enum class SparseTensorFieldKind {
-  DimSizes,
-  MemSizes,
-  PtrMemRef,
-  IdxMemRef,
-  ValMemRef
-};
-
-constexpr uint64_t dimSizesIdx = 0;
-constexpr uint64_t memSizesIdx = dimSizesIdx + 1;
-constexpr uint64_t dataFieldIdx = memSizesIdx + 1;
-
-/// For each field that will be allocated for the given sparse tensor encoding,
-/// calls the callback with the corresponding field index, field kind, dimension
-/// (for sparse tensor level memrefs) and dimlevelType.
-/// The field index always starts with zero and increments by one between two
-/// callback invocations.
-/// Ideally, all other methods should rely on this function to query a sparse
-/// tensor fields instead of relying on ad-hoc index computation.
-void foreachFieldInSparseTensor(
-    SparseTensorEncodingAttr,
-    llvm::function_ref<bool(unsigned /*fieldIdx*/,
-                            SparseTensorFieldKind /*fieldKind*/,
-                            unsigned /*dim (if applicable)*/,
-                            DimLevelType /*DLT (if applicable)*/)>);
-
-/// Same as above, except that it also builds the Type for the corresponding
-/// field.
-void foreachFieldAndTypeInSparseTensor(
-    RankedTensorType,
-    llvm::function_ref<bool(Type /*fieldType*/, unsigned /*fieldIdx*/,
-                            SparseTensorFieldKind /*fieldKind*/,
-                            unsigned /*dim (if applicable)*/,
-                            DimLevelType /*DLT (if applicable)*/)>);
-
-/// Gets the total number of fields for the given sparse tensor encoding.
-unsigned getNumFieldsFromEncoding(SparseTensorEncodingAttr enc);
-
-/// Gets the total number of data fields (index arrays, pointer arrays, and a
-/// value array) for the given sparse tensor encoding.
-unsigned getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc);
-
-/// Get the index of the field in memSizes (only valid for data fields).
-inline unsigned getFieldMemSizesIndex(unsigned fid) {
-  assert(fid >= dataFieldIdx);
-  return fid - dataFieldIdx;
-}
-
-/// A helper class around an array of values that corresponding to a sparse
-/// tensor, provides a set of meaningful APIs to query and update a particular
-/// field in a consistent way.
-/// Users should not make assumption on how a sparse tensor is laid out but
-/// instead relies on this class to access the right value for the right field.
-template <bool mut>
-class SparseTensorDescriptorImpl {
-private:
-  template <bool>
-  struct ArrayStorage;
-
-  template <>
-  struct ArrayStorage<false> {
-    using ValueArray = ValueRange;
-  };
-
-  template <>
-  struct ArrayStorage<true> {
-    using ValueArray = SmallVectorImpl<Value> &;
-  };
-
-  // Uses ValueRange for immuatable descriptors; uses SmallVectorImpl<Value> &
-  // for mutable descriptors.
-  // Using SmallVector for mutable descriptor allows users to reuse it as a tmp
-  // buffers to append value for some special cases, though users should be
-  // responsible to restore the buffer to legal states after their use. It is
-  // probably not a clean way, but it is the most efficient way to avoid copying
-  // the fields into another SmallVector. If a more clear way is wanted, we
-  // should change it to MutableArrayRef instead.
-  using Storage = typename ArrayStorage<mut>::ValueArray;
-
-public:
-  SparseTensorDescriptorImpl(Type tp, Storage fields)
-      : rType(tp.cast<RankedTensorType>()), fields(fields) {
-    assert(getSparseTensorEncoding(tp) &&
-           getNumFieldsFromEncoding(getSparseTensorEncoding(tp)) ==
-               fields.size());
-    // We should make sure the class is trivially copyable (and should be small
-    // enough) such that we can pass it by value.
-    static_assert(
-        std::is_trivially_copyable_v<SparseTensorDescriptorImpl<mut>>);
-  }
-
-  // Implicit (and cheap) type conversion from MutSparseTensorDescriptor to
-  // SparseTensorDescriptor.
-  template <typename T = SparseTensorDescriptorImpl<true>>
-  /*implicit*/ SparseTensorDescriptorImpl(std::enable_if_t<!mut, T> &mDesc)
-      : rType(mDesc.getTensorType()), fields(mDesc.getFields()) {}
-
-  ///
-  /// Getters: get the field index for required field.
-  ///
-
-  unsigned getPtrMemRefIndex(unsigned ptrDim) const {
-    return getFieldIndex(ptrDim, SparseTensorFieldKind::PtrMemRef);
-  }
-
-  unsigned getIdxMemRefIndex(unsigned idxDim) const {
-    return getFieldIndex(idxDim, SparseTensorFieldKind::IdxMemRef);
-  }
-
-  unsigned getValMemRefIndex() const { return fields.size() - 1; }
-
-  unsigned getPtrMemSizesIndex(unsigned dim) const {
-    return getPtrMemRefIndex(dim) - dataFieldIdx;
-  }
-
-  unsigned getIdxMemSizesIndex(unsigned dim) const {
-    return getIdxMemRefIndex(dim) - dataFieldIdx;
-  }
-
-  unsigned getValMemSizesIndex() const {
-    return getValMemRefIndex() - dataFieldIdx;
-  }
-
-  unsigned getNumFields() const { return fields.size(); }
-
-  ///
-  /// Getters: get the value for required field.
-  ///
-
-  Value getDimSizesMemRef() const { return fields[dimSizesIdx]; }
-  Value getMemSizesMemRef() const { return fields[memSizesIdx]; }
-
-  Value getPtrMemRef(unsigned ptrDim) const {
-    return fields[getPtrMemRefIndex(ptrDim)];
-  }
-
-  Value getIdxMemRef(unsigned idxDim) const {
-    return fields[getIdxMemRefIndex(idxDim)];
-  }
-
-  Value getValMemRef() const { return fields[getValMemRefIndex()]; }
-
-  Value getField(unsigned fid) const {
-    assert(fid < fields.size());
-    return fields[fid];
-  }
-
-  ///
-  /// Setters: update the value for required field (only enabled for
-  /// MutSparseTensorDescriptor).
-  ///
-
-  template <typename T = Value>
-  void setField(unsigned fid, std::enable_if_t<mut, T> v) {
-    assert(fid < fields.size());
-    fields[fid] = v;
-  }
-
-  RankedTensorType getTensorType() const { return rType; }
-  Storage getFields() const { return fields; }
-
-  Type getElementType(unsigned fidx) const {
-    return fields[fidx].getType().template cast<MemRefType>().getElementType();
-  }
-
-private:
-  unsigned getFieldIndex(unsigned dim, SparseTensorFieldKind kind) const {
-    unsigned fieldIdx = -1u;
-    foreachFieldInSparseTensor(
-        getSparseTensorEncoding(rType),
-        [dim, kind, &fieldIdx](unsigned fIdx, SparseTensorFieldKind fKind,
-                               unsigned fDim, DimLevelType dlt) -> bool {
-          if (fDim == dim && kind == fKind) {
-            fieldIdx = fIdx;
-            // Returns false to break the iteration.
-            return false;
-          }
-          return true;
-        });
-    assert(fieldIdx != -1u);
-    return fieldIdx;
-  }
-
-  RankedTensorType rType;
-  Storage fields;
-};
-
-using SparseTensorDescriptor = SparseTensorDescriptorImpl<false>;
-using MutSparseTensorDescriptor = SparseTensorDescriptorImpl<true>;
-
-//===----------------------------------------------------------------------===//
-// SparseTensorLoopEmiter class, manages sparse tensors and helps to
-// generate loop structure to (co)-iterate sparse tensors.
+// SparseTensorLoopEmiter class, manages sparse tensors and helps to generate
+// loop structure to (co)-iterate sparse tensors.
 //
 // An example usage:
 // To generate the following loops over T1<?x?> and T2<?x?>
@@ -559,15 +345,15 @@ public:
   using OutputUpdater = function_ref<Value(OpBuilder &builder, Location loc,
                                            Value memref, Value tensor)>;
 
-  /// Constructor: take an array of tensors inputs, on which the generated
-  /// loops will iterate on. The index of the tensor in the array is also the
+  /// Constructor: take an array of tensors inputs, on which the generated loops
+  /// will iterate on. The index of the tensor in the array is also the
   /// tensor id (tid) used in related functions.
   /// If isSparseOut is set, loop emitter assume that the sparse output tensor
   /// is empty, and will always generate loops on it based on the dim sizes.
   /// An optional array could be provided (by sparsification) to indicate the
   /// loop id sequence that will be generated. It is used to establish the
-  /// mapping between affineDimExpr to the corresponding loop index in the
-  /// loop stack that are maintained by the loop emitter.
+  /// mapping between affineDimExpr to the corresponding loop index in the loop
+  /// stack that are maintained by the loop emitter.
   explicit SparseTensorLoopEmitter(ValueRange tensors,
                                    StringAttr loopTag = nullptr,
                                    bool hasOutput = false,
@@ -582,8 +368,8 @@ public:
   /// Generates a list of operations to compute the affine expression.
   Value genAffine(OpBuilder &builder, AffineExpr a, Location loc);
 
-  /// Enters a new loop sequence, the loops within the same sequence starts
-  /// from the break points of previous loop instead of starting over from 0.
+  /// Enters a new loop sequence, the loops within the same sequence starts from
+  /// the break points of previous loop instead of starting over from 0.
   /// e.g.,
   /// {
   ///   // loop sequence start.
@@ -738,10 +524,10 @@ private:
   ///     scf.reduce.return %val
   ///   }
   /// }
-  /// NOTE: only one instruction will be moved into reduce block,
-  /// transformation will fail if multiple instructions are used to compute
-  /// the reduction value. Return %ret to user, while %val is provided by
-  /// users (`reduc`).
+  /// NOTE: only one instruction will be moved into reduce block, transformation
+  /// will fail if multiple instructions are used to compute the reduction
+  /// value.
+  /// Return %ret to user, while %val is provided by users (`reduc`).
   void exitForLoop(RewriterBase &rewriter, Location loc,
                    MutableArrayRef<Value> reduc);
 
@@ -749,9 +535,9 @@ private:
   void exitCoIterationLoop(OpBuilder &builder, Location loc,
                            MutableArrayRef<Value> reduc);
 
-  /// A optional string attribute that should be attached to the loop
-  /// generated by loop emitter, it might help following passes to identify
-  /// loops that operates on sparse tensors more easily.
+  /// A optional string attribute that should be attached to the loop generated
+  /// by loop emitter, it might help following passes to identify loops that
+  /// operates on sparse tensors more easily.
   StringAttr loopTag;
   /// Whether the loop emitter needs to treat the last tensor as the output
   /// tensor.
@@ -770,8 +556,7 @@ private:
   std::vector<std::vector<Value>> idxBuffer; // to_indices
   std::vector<Value> valBuffer;              // to_value
 
-  // Loop Stack, stores the information of all the nested loops that are
-  // alive.
+  // Loop Stack, stores the information of all the nested loops that are alive.
   std::vector<LoopLevelInfo> loopStack;
 
   // Loop Sequence Stack, stores the unversial index for the current loop
index e059bd3..113347c 100644 (file)
@@ -36,6 +36,10 @@ using FuncGeneratorType =
 
 static constexpr const char kInsertFuncNamePrefix[] = "_insert_";
 
+static constexpr uint64_t dimSizesIdx = 0;
+static constexpr uint64_t memSizesIdx = 1;
+static constexpr uint64_t fieldsIdx = 2;
+
 //===----------------------------------------------------------------------===//
 // Helper methods.
 //===----------------------------------------------------------------------===//
@@ -45,18 +49,6 @@ static UnrealizedConversionCastOp getTuple(Value tensor) {
   return llvm::cast<UnrealizedConversionCastOp>(tensor.getDefiningOp());
 }
 
-static SparseTensorDescriptor getDescriptorFromTensorTuple(Value tensor) {
-  auto tuple = getTuple(tensor);
-  return SparseTensorDescriptor(tuple.getResultTypes()[0], tuple.getInputs());
-}
-
-static MutSparseTensorDescriptor
-getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl<Value> &fields) {
-  auto tuple = getTuple(tensor);
-  fields.assign(tuple.getInputs().begin(), tuple.getInputs().end());
-  return MutSparseTensorDescriptor(tuple.getResultTypes()[0], fields);
-}
-
 /// Packs the given values as a "tuple" value.
 static Value genTuple(OpBuilder &builder, Location loc, Type tp,
                       ValueRange values) {
@@ -64,14 +56,6 @@ static Value genTuple(OpBuilder &builder, Location loc, Type tp,
       .getResult(0);
 }
 
-static Value genTuple(OpBuilder &builder, Location loc,
-                      SparseTensorDescriptor desc) {
-  return builder
-      .create<UnrealizedConversionCastOp>(loc, desc.getTensorType(),
-                                          desc.getFields())
-      .getResult(0);
-}
-
 /// Flatten a list of operands that may contain sparse tensors.
 static void flattenOperands(ValueRange operands,
                             SmallVectorImpl<Value> &flattened) {
@@ -117,7 +101,7 @@ static void genStore(OpBuilder &builder, Location loc, Value val, Value mem,
 
 /// Creates a straightforward counting for-loop.
 static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper,
-                            MutableArrayRef<Value> fields,
+                            SmallVectorImpl<Value> &fields,
                             Value lower = Value()) {
   Type indexType = builder.getIndexType();
   if (!lower)
@@ -134,46 +118,81 @@ static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper,
 /// original dimension 'dim'. Returns std::nullopt if no sparse encoding is
 /// attached to the given tensor type.
 static Optional<Value> sizeFromTensorAtDim(OpBuilder &builder, Location loc,
-                                           SparseTensorDescriptor desc,
-                                           unsigned dim) {
-  RankedTensorType rtp = desc.getTensorType();
+                                           RankedTensorType tensorTp,
+                                           Value adaptedValue, unsigned dim) {
+  auto enc = getSparseTensorEncoding(tensorTp);
+  if (!enc)
+    return std::nullopt;
+
   // Access into static dimension can query original type directly.
   // Note that this is typically already done by DimOp's folding.
-  auto shape = rtp.getShape();
+  auto shape = tensorTp.getShape();
   if (!ShapedType::isDynamic(shape[dim]))
     return constantIndex(builder, loc, shape[dim]);
 
   // Any other query can consult the dimSizes array at field DimSizesIdx,
   // accounting for the reordering applied to the sparse storage.
-  Value idx = constantIndex(builder, loc, toStoredDim(rtp, dim));
-  return builder.create<memref::LoadOp>(loc, desc.getDimSizesMemRef(), idx)
+  auto tuple = getTuple(adaptedValue);
+  Value idx = constantIndex(builder, loc, toStoredDim(tensorTp, dim));
+  return builder
+      .create<memref::LoadOp>(loc, tuple.getInputs()[dimSizesIdx], idx)
       .getResult();
 }
 
 // Gets the dimension size at the given stored dimension 'd', either as a
 // constant for a static size, or otherwise dynamically through memSizes.
-Value sizeAtStoredDim(OpBuilder &builder, Location loc,
-                      SparseTensorDescriptor desc, unsigned d) {
-  RankedTensorType rtp = desc.getTensorType();
+Value sizeAtStoredDim(OpBuilder &builder, Location loc, RankedTensorType rtp,
+                      SmallVectorImpl<Value> &fields, unsigned d) {
   unsigned dim = toOrigDim(rtp, d);
   auto shape = rtp.getShape();
   if (!ShapedType::isDynamic(shape[dim]))
     return constantIndex(builder, loc, shape[dim]);
-
-  return genLoad(builder, loc, desc.getDimSizesMemRef(),
+  return genLoad(builder, loc, fields[dimSizesIdx],
                  constantIndex(builder, loc, d));
 }
 
+/// Translates field index to memSizes index.
+static unsigned getMemSizesIndex(unsigned field) {
+  assert(fieldsIdx <= field);
+  return field - fieldsIdx;
+}
+
+/// Creates a pushback op for given field and updates the fields array
+/// accordingly. This operation also updates the memSizes contents.
 static void createPushback(OpBuilder &builder, Location loc,
-                           MutSparseTensorDescriptor desc, unsigned fidx,
+                           SmallVectorImpl<Value> &fields, unsigned field,
                            Value value, Value repeat = Value()) {
-  Type etp = desc.getElementType(fidx);
-  Value field = desc.getField(fidx);
-  Value newField = builder.create<PushBackOp>(
-      loc, field.getType(), desc.getMemSizesMemRef(), field,
-      toType(builder, loc, value, etp), APInt(64, getFieldMemSizesIndex(fidx)),
+  assert(fieldsIdx <= field && field < fields.size());
+  Type etp = fields[field].getType().cast<ShapedType>().getElementType();
+  fields[field] = builder.create<PushBackOp>(
+      loc, fields[field].getType(), fields[memSizesIdx], fields[field],
+      toType(builder, loc, value, etp), APInt(64, getMemSizesIndex(field)),
       repeat);
-  desc.setField(fidx, newField);
+}
+
+/// Returns field index of sparse tensor type for pointers/indices, when set.
+static unsigned getFieldIndex(Type type, unsigned ptrDim, unsigned idxDim) {
+  assert(getSparseTensorEncoding(type));
+  RankedTensorType rType = type.cast<RankedTensorType>();
+  unsigned field = fieldsIdx; // start past header
+  for (unsigned r = 0, rank = rType.getShape().size(); r < rank; r++) {
+    if (isCompressedDim(rType, r)) {
+      if (r == ptrDim)
+        return field;
+      field++;
+      if (r == idxDim)
+        return field;
+      field++;
+    } else if (isSingletonDim(rType, r)) {
+      if (r == idxDim)
+        return field;
+      field++;
+    } else {
+      assert(isDenseDim(rType, r)); // no fields
+    }
+  }
+  assert(ptrDim == -1u && idxDim == -1u);
+  return field + 1; // return values field index
 }
 
 /// Maps a sparse tensor type to the appropriate compounded buffers.
@@ -182,24 +201,64 @@ convertSparseTensorType(Type type, SmallVectorImpl<Type> &fields) {
   auto enc = getSparseTensorEncoding(type);
   if (!enc)
     return std::nullopt;
-
+  // Construct the basic types.
+  auto *context = type.getContext();
   RankedTensorType rType = type.cast<RankedTensorType>();
-  foreachFieldAndTypeInSparseTensor(
-      rType,
-      [&fields](Type fieldType, unsigned fieldIdx,
-                SparseTensorFieldKind /*fieldKind*/, unsigned /*dim*/,
-                DimLevelType /*dlt*/) -> bool {
-        assert(fieldIdx == fields.size());
-        fields.push_back(fieldType);
-        return true;
-      });
+  Type indexType = IndexType::get(context);
+  Type idxType = enc.getIndexType();
+  Type ptrType = enc.getPointerType();
+  Type eltType = rType.getElementType();
+  //
+  // Sparse tensor storage scheme for rank-dimensional tensor is organized
+  // as a single compound type with the following fields. Note that every
+  // memref with ? size actually behaves as a "vector", i.e. the stored
+  // size is the capacity and the used size resides in the memSizes array.
+  //
+  // struct {
+  //   memref<rank x index> dimSizes     ; size in each dimension
+  //   memref<n x index> memSizes        ; sizes of ptrs/inds/values
+  //   ; per-dimension d:
+  //   ;  if dense:
+  //        <nothing>
+  //   ;  if compresed:
+  //        memref<? x ptr>  pointers-d  ; pointers for sparse dim d
+  //        memref<? x idx>  indices-d   ; indices for sparse dim d
+  //   ;  if singleton:
+  //        memref<? x idx>  indices-d   ; indices for singleton dim d
+  //   memref<? x eltType> values        ; values
+  // };
+  //
+  unsigned rank = rType.getShape().size();
+  unsigned lastField = getFieldIndex(type, -1u, -1u);
+  // The dimSizes array and memSizes array.
+  fields.push_back(MemRefType::get({rank}, indexType));
+  fields.push_back(MemRefType::get({getMemSizesIndex(lastField)}, indexType));
+  // Per-dimension storage.
+  for (unsigned r = 0; r < rank; r++) {
+    // Dimension level types apply in order to the reordered dimension.
+    // As a result, the compound type can be constructed directly in the given
+    // order. Clients of this type know what field is what from the sparse
+    // tensor type.
+    if (isCompressedDim(rType, r)) {
+      fields.push_back(MemRefType::get({ShapedType::kDynamic}, ptrType));
+      fields.push_back(MemRefType::get({ShapedType::kDynamic}, idxType));
+    } else if (isSingletonDim(rType, r)) {
+      fields.push_back(MemRefType::get({ShapedType::kDynamic}, idxType));
+    } else {
+      assert(isDenseDim(rType, r)); // no fields
+    }
+  }
+  // The values array.
+  fields.push_back(MemRefType::get({ShapedType::kDynamic}, eltType));
+  assert(fields.size() == lastField);
   return success();
 }
 
 /// Generates code that allocates a sparse storage scheme for given rank.
 static void allocSchemeForRank(OpBuilder &builder, Location loc,
-                               MutSparseTensorDescriptor desc, unsigned r0) {
-  RankedTensorType rtp = desc.getTensorType();
+                               RankedTensorType rtp,
+                               SmallVectorImpl<Value> &fields, unsigned field,
+                               unsigned r0) {
   unsigned rank = rtp.getShape().size();
   Value linear = constantIndex(builder, loc, 1);
   for (unsigned r = r0; r < rank; r++) {
@@ -209,8 +268,7 @@ static void allocSchemeForRank(OpBuilder &builder, Location loc,
       // the desired "linear + 1" length property at all times.
       Type ptrType = getSparseTensorEncoding(rtp).getPointerType();
       Value ptrZero = constantZero(builder, loc, ptrType);
-      createPushback(builder, loc, desc, desc.getPtrMemRefIndex(r), ptrZero,
-                     linear);
+      createPushback(builder, loc, fields, field, ptrZero, linear);
       return;
     }
     if (isSingletonDim(rtp, r)) {
@@ -220,23 +278,23 @@ static void allocSchemeForRank(OpBuilder &builder, Location loc,
     // at this level. We will eventually reach a compressed level or
     // otherwise the values array for the from-here "all-dense" case.
     assert(isDenseDim(rtp, r));
-    Value size = sizeAtStoredDim(builder, loc, desc, r);
+    Value size = sizeAtStoredDim(builder, loc, rtp, fields, r);
     linear = builder.create<arith::MulIOp>(loc, linear, size);
   }
   // Reached values array so prepare for an insertion.
   Value valZero = constantZero(builder, loc, rtp.getElementType());
-  createPushback(builder, loc, desc, desc.getValMemRefIndex(), valZero, linear);
+  createPushback(builder, loc, fields, field, valZero, linear);
+  assert(fields.size() == ++field);
 }
 
 /// Creates allocation operation.
-static Value createAllocation(OpBuilder &builder, Location loc,
-                              MemRefType memRefType, Value sz,
-                              bool enableInit) {
-  Value buffer = builder.create<memref::AllocOp>(loc, memRefType, sz);
-  Type elemType = memRefType.getElementType();
+static Value createAllocation(OpBuilder &builder, Location loc, Type type,
+                              Value sz, bool enableInit) {
+  auto memType = MemRefType::get({ShapedType::kDynamic}, type);
+  Value buffer = builder.create<memref::AllocOp>(loc, memType, sz);
   if (enableInit) {
-    Value fillValue = builder.create<arith::ConstantOp>(
-        loc, elemType, builder.getZeroAttr(elemType));
+    Value fillValue =
+        builder.create<arith::ConstantOp>(loc, type, builder.getZeroAttr(type));
     builder.create<linalg::FillOp>(loc, fillValue, buffer);
   }
   return buffer;
@@ -252,68 +310,69 @@ static Value createAllocation(OpBuilder &builder, Location loc,
 static void createAllocFields(OpBuilder &builder, Location loc, Type type,
                               ValueRange dynSizes, bool enableInit,
                               SmallVectorImpl<Value> &fields) {
+  auto enc = getSparseTensorEncoding(type);
+  assert(enc);
   RankedTensorType rtp = type.cast<RankedTensorType>();
+  Type indexType = builder.getIndexType();
+  Type idxType = enc.getIndexType();
+  Type ptrType = enc.getPointerType();
+  Type eltType = rtp.getElementType();
+  auto shape = rtp.getShape();
+  unsigned rank = shape.size();
   Value heuristic = constantIndex(builder, loc, 16);
-
-  foreachFieldAndTypeInSparseTensor(
-      rtp,
-      [&builder, &fields, loc, heuristic,
-       enableInit](Type fType, unsigned fIdx, SparseTensorFieldKind fKind,
-                   unsigned /*dim*/, DimLevelType /*dlt*/) -> bool {
-        assert(fields.size() == fIdx);
-        auto memRefTp = fType.cast<MemRefType>();
-        Value field;
-        switch (fKind) {
-        case SparseTensorFieldKind::DimSizes:
-        case SparseTensorFieldKind::MemSizes:
-          field = builder.create<memref::AllocOp>(loc, memRefTp);
-          break;
-        case SparseTensorFieldKind::PtrMemRef:
-        case SparseTensorFieldKind::IdxMemRef:
-        case SparseTensorFieldKind::ValMemRef:
-          field =
-              createAllocation(builder, loc, memRefTp, heuristic, enableInit);
-          break;
-        }
-        assert(field);
-        fields.push_back(field);
-        // Returns true to continue the iteration.
-        return true;
-      });
-
-  MutSparseTensorDescriptor desc(rtp, fields);
-
   // Build original sizes.
   SmallVector<Value> sizes;
-  auto shape = rtp.getShape();
-  unsigned rank = shape.size();
   for (unsigned r = 0, o = 0; r < rank; r++) {
     if (ShapedType::isDynamic(shape[r]))
       sizes.push_back(dynSizes[o++]);
     else
       sizes.push_back(constantIndex(builder, loc, shape[r]));
   }
+  // The dimSizes array and memSizes array.
+  unsigned lastField = getFieldIndex(type, -1u, -1u);
+  Value dimSizes =
+      builder.create<memref::AllocOp>(loc, MemRefType::get({rank}, indexType));
+  Value memSizes = builder.create<memref::AllocOp>(
+      loc, MemRefType::get({getMemSizesIndex(lastField)}, indexType));
+  fields.push_back(dimSizes);
+  fields.push_back(memSizes);
+  // Per-dimension storage.
+  for (unsigned r = 0; r < rank; r++) {
+    if (isCompressedDim(rtp, r)) {
+      fields.push_back(
+          createAllocation(builder, loc, ptrType, heuristic, enableInit));
+      fields.push_back(
+          createAllocation(builder, loc, idxType, heuristic, enableInit));
+    } else if (isSingletonDim(rtp, r)) {
+      fields.push_back(
+          createAllocation(builder, loc, idxType, heuristic, enableInit));
+    } else {
+      assert(isDenseDim(rtp, r)); // no fields
+    }
+  }
+  // The values array.
+  fields.push_back(
+      createAllocation(builder, loc, eltType, heuristic, enableInit));
+  assert(fields.size() == lastField);
   // Initialize the storage scheme to an empty tensor. Initialized memSizes
   // to all zeros, sets the dimSizes to known values and gives all pointer
   // fields an initial zero entry, so that it is easier to maintain the
   // "linear + 1" length property.
   builder.create<linalg::FillOp>(
-      loc, constantZero(builder, loc, builder.getIndexType()),
-      desc.getMemSizesMemRef()); // zero memSizes
-
-  Value ptrZero =
-      constantZero(builder, loc, getSparseTensorEncoding(rtp).getPointerType());
-  for (unsigned r = 0; r < rank; r++) {
+      loc, ValueRange{constantZero(builder, loc, indexType)},
+      ValueRange{memSizes}); // zero memSizes
+  Value ptrZero = constantZero(builder, loc, ptrType);
+  for (unsigned r = 0, field = fieldsIdx; r < rank; r++) {
     unsigned ro = toOrigDim(rtp, r);
-    // Fills dim sizes array.
-    genStore(builder, loc, sizes[ro], desc.getDimSizesMemRef(),
-             constantIndex(builder, loc, r));
-
-    // Pushes a leading zero to pointers memref.
-    if (isCompressedDim(rtp, r))
-      createPushback(builder, loc, desc, desc.getPtrMemRefIndex(r), ptrZero);
+    genStore(builder, loc, sizes[ro], dimSizes, constantIndex(builder, loc, r));
+    if (isCompressedDim(rtp, r)) {
+      createPushback(builder, loc, fields, field, ptrZero);
+      field += 2;
+    } else if (isSingletonDim(rtp, r)) {
+      field += 1;
+    }
   }
-  allocSchemeForRank(builder, loc, desc, /*rank=*/0);
+  allocSchemeForRank(builder, loc, rtp, fields, fieldsIdx, /*rank=*/0);
 }
 
 /// Helper method that generates block specific to compressed case:
@@ -337,22 +396,19 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type,
 ///  }
 ///  pos[d] = next
 static Value genCompressed(OpBuilder &builder, Location loc,
-                           MutSparseTensorDescriptor desc,
+                           RankedTensorType rtp, SmallVectorImpl<Value> &fields,
                            SmallVectorImpl<Value> &indices, Value value,
-                           Value pos, unsigned d) {
-  RankedTensorType rtp = desc.getTensorType();
+                           Value pos, unsigned field, unsigned d) {
   unsigned rank = rtp.getShape().size();
   SmallVector<Type> types;
   Type indexType = builder.getIndexType();
   Type boolType = builder.getIntegerType(1);
-  unsigned idxIndex = desc.getIdxMemRefIndex(d);
-  unsigned ptrIndex = desc.getPtrMemRefIndex(d);
   Value one = constantIndex(builder, loc, 1);
   Value pp1 = builder.create<arith::AddIOp>(loc, pos, one);
-  Value plo = genLoad(builder, loc, desc.getField(ptrIndex), pos);
-  Value phi = genLoad(builder, loc, desc.getField(ptrIndex), pp1);
-  Value psz = constantIndex(builder, loc, getFieldMemSizesIndex(idxIndex));
-  Value msz = genLoad(builder, loc, desc.getMemSizesMemRef(), psz);
+  Value plo = genLoad(builder, loc, fields[field], pos);
+  Value phi = genLoad(builder, loc, fields[field], pp1);
+  Value psz = constantIndex(builder, loc, getMemSizesIndex(field + 1));
+  Value msz = genLoad(builder, loc, fields[memSizesIdx], psz);
   Value phim1 = builder.create<arith::SubIOp>(
       loc, toType(builder, loc, phi, indexType), one);
   // Conditional expression.
@@ -362,55 +418,49 @@ static Value genCompressed(OpBuilder &builder, Location loc,
   scf::IfOp ifOp1 = builder.create<scf::IfOp>(loc, types, lt, /*else*/ true);
   types.pop_back();
   builder.setInsertionPointToStart(&ifOp1.getThenRegion().front());
-  Value crd = genLoad(builder, loc, desc.getField(idxIndex), phim1);
+  Value crd = genLoad(builder, loc, fields[field + 1], phim1);
   Value eq = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                            toType(builder, loc, crd, indexType),
                                            indices[d]);
   builder.create<scf::YieldOp>(loc, eq);
   builder.setInsertionPointToStart(&ifOp1.getElseRegion().front());
   if (d > 0)
-    genStore(builder, loc, msz, desc.getField(ptrIndex), pos);
+    genStore(builder, loc, msz, fields[field], pos);
   builder.create<scf::YieldOp>(loc, constantI1(builder, loc, false));
   builder.setInsertionPointAfter(ifOp1);
   Value p = ifOp1.getResult(0);
-  // If present construct. Note that for a non-unique dimension level, we
-  // simply set the condition to false and rely on CSE/DCE to clean up the IR.
+  // If present construct. Note that for a non-unique dimension level, we simply
+  // set the condition to false and rely on CSE/DCE to clean up the IR.
   //
   // TODO: generate less temporary IR?
   //
-  for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
-    types.push_back(desc.getField(i).getType());
+  for (unsigned i = 0, e = fields.size(); i < e; i++)
+    types.push_back(fields[i].getType());
   types.push_back(indexType);
   if (!isUniqueDim(rtp, d))
     p = constantI1(builder, loc, false);
   scf::IfOp ifOp2 = builder.create<scf::IfOp>(loc, types, p, /*else*/ true);
   // If present (fields unaffected, update next to phim1).
   builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
-
-  // FIXME: This does not looks like a clean way, but probably the most
-  // efficient way.
-  desc.getFields().push_back(phim1);
-  builder.create<scf::YieldOp>(loc, desc.getFields());
-  desc.getFields().pop_back();
-
+  fields.push_back(phim1);
+  builder.create<scf::YieldOp>(loc, fields);
+  fields.pop_back();
   // If !present (changes fields, update next).
   builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
   Value mszp1 = builder.create<arith::AddIOp>(loc, msz, one);
-  genStore(builder, loc, mszp1, desc.getField(ptrIndex), pp1);
-  createPushback(builder, loc, desc, idxIndex, indices[d]);
+  genStore(builder, loc, mszp1, fields[field], pp1);
+  createPushback(builder, loc, fields, field + 1, indices[d]);
   // Prepare the next dimension "as needed".
   if ((d + 1) < rank)
-    allocSchemeForRank(builder, loc, desc, d + 1);
-
-  desc.getFields().push_back(msz);
-  builder.create<scf::YieldOp>(loc, desc.getFields());
-  desc.getFields().pop_back();
-
+    allocSchemeForRank(builder, loc, rtp, fields, field + 2, d + 1);
+  fields.push_back(msz);
+  builder.create<scf::YieldOp>(loc, fields);
+  fields.pop_back();
   // Update fields and return next pos.
   builder.setInsertionPointAfter(ifOp2);
   unsigned o = 0;
-  for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
-    desc.setField(i, ifOp2.getResult(o++));
+  for (unsigned i = 0, e = fields.size(); i < e; i++)
+    fields[i] = ifOp2.getResult(o++);
   return ifOp2.getResult(o);
 }
 
@@ -438,10 +488,11 @@ static void genInsertBody(OpBuilder &builder, ModuleOp module,
   // Construct fields and indices arrays from parameters.
   ValueRange tmp = args.drop_back(rank + 1);
   SmallVector<Value> fields(tmp.begin(), tmp.end());
-  MutSparseTensorDescriptor desc(rtp, fields);
   tmp = args.take_back(rank + 1).drop_back();
   SmallVector<Value> indices(tmp.begin(), tmp.end());
   Value value = args.back();
+
+  unsigned field = fieldsIdx; // Start past header.
   Value pos = constantZero(builder, loc, builder.getIndexType());
   // Generate code for every dimension.
   for (unsigned d = 0; d < rank; d++) {
@@ -453,35 +504,39 @@ static void genInsertBody(OpBuilder &builder, ModuleOp module,
       //   }
       //   pos[d] = indices.size() - 1
       //   <insert @ pos[d] at next dimension d + 1>
-      pos = genCompressed(builder, loc, desc, indices, value, pos, d);
+      pos = genCompressed(builder, loc, rtp, fields, indices, value, pos, field,
+                          d);
+      field += 2;
     } else if (isSingletonDim(rtp, d)) {
       // Create:
       //   indices[d].push_back(i[d])
       //   pos[d] = pos[d-1]
       //   <insert @ pos[d] at next dimension d + 1>
-      createPushback(builder, loc, desc, desc.getIdxMemRefIndex(d), indices[d]);
+      createPushback(builder, loc, fields, field, indices[d]);
+      field += 1;
     } else {
       assert(isDenseDim(rtp, d));
       // Construct the new position as:
       //   pos[d] = size * pos[d-1] + i[d]
       //   <insert @ pos[d] at next dimension d + 1>
-      Value size = sizeAtStoredDim(builder, loc, desc, d);
+      Value size = sizeAtStoredDim(builder, loc, rtp, fields, d);
       Value mult = builder.create<arith::MulIOp>(loc, size, pos);
       pos = builder.create<arith::AddIOp>(loc, mult, indices[d]);
     }
   }
   // Reached the actual value append/insert.
   if (!isDenseDim(rtp, rank - 1))
-    createPushback(builder, loc, desc, desc.getValMemRefIndex(), value);
+    createPushback(builder, loc, fields, field++, value);
   else
-    genStore(builder, loc, value, desc.getValMemRef(), pos);
+    genStore(builder, loc, value, fields[field++], pos);
+  assert(fields.size() == field);
   builder.create<func::ReturnOp>(loc, fields);
 }
 
 /// Generates a call to a function to perform an insertion operation. If the
 /// function doesn't exist yet, call `createFunc` to generate the function.
-static void genInsertionCallHelper(OpBuilder &builder,
-                                   MutSparseTensorDescriptor desc,
+static void genInsertionCallHelper(OpBuilder &builder, RankedTensorType rtp,
+                                   SmallVectorImpl<Value> &fields,
                                    SmallVectorImpl<Value> &indices, Value value,
                                    func::FuncOp insertPoint,
                                    StringRef namePrefix,
@@ -489,7 +544,6 @@ static void genInsertionCallHelper(OpBuilder &builder,
   // The mangled name of the function has this format:
   //   <namePrefix>_[C|S|D]_<shape>_<ordering>_<eltType>
   //     _<indexBitWidth>_<pointerBitWidth>
-  RankedTensorType rtp = desc.getTensorType();
   SmallString<32> nameBuffer;
   llvm::raw_svector_ostream nameOstream(nameBuffer);
   nameOstream << namePrefix;
@@ -523,7 +577,7 @@ static void genInsertionCallHelper(OpBuilder &builder,
   auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
 
   // Construct parameters for fields and indices.
-  SmallVector<Value> operands(desc.getFields().begin(), desc.getFields().end());
+  SmallVector<Value> operands(fields.begin(), fields.end());
   operands.append(indices.begin(), indices.end());
   operands.push_back(value);
   Location loc = insertPoint.getLoc();
@@ -536,7 +590,7 @@ static void genInsertionCallHelper(OpBuilder &builder,
     func = builder.create<func::FuncOp>(
         loc, nameOstream.str(),
         FunctionType::get(context, ValueRange(operands).getTypes(),
-                          ValueRange(desc.getFields()).getTypes()));
+                          ValueRange(fields).getTypes()));
     func.setPrivate();
     createFunc(builder, module, func, rtp);
   }
@@ -544,44 +598,42 @@ static void genInsertionCallHelper(OpBuilder &builder,
   // Generate a call to perform the insertion and update `fields` with values
   // returned from the call.
   func::CallOp call = builder.create<func::CallOp>(loc, func, operands);
-  for (size_t i = 0, e = desc.getNumFields(); i < e; i++) {
-    desc.getFields()[i] = call.getResult(i);
+  for (size_t i = 0; i < fields.size(); i++) {
+    fields[i] = call.getResult(i);
   }
 }
 
 /// Generations insertion finalization code.
-static void genEndInsert(OpBuilder &builder, Location loc,
-                         MutSparseTensorDescriptor desc) {
-  RankedTensorType rtp = desc.getTensorType();
+static void genEndInsert(OpBuilder &builder, Location loc, RankedTensorType rtp,
+                         SmallVectorImpl<Value> &fields) {
   unsigned rank = rtp.getShape().size();
+  unsigned field = fieldsIdx; // start past header
   for (unsigned d = 0; d < rank; d++) {
     if (isCompressedDim(rtp, d)) {
       // Compressed dimensions need a pointer cleanup for all entries
       // that were not visited during the insertion pass.
       //
-      // TODO: avoid cleanup and keep compressed scheme consistent at all
-      // times?
+      // TODO: avoid cleanup and keep compressed scheme consistent at all times?
       //
       if (d > 0) {
         Type ptrType = getSparseTensorEncoding(rtp).getPointerType();
-        Value ptrMemRef = desc.getPtrMemRef(d);
-        Value mz = constantIndex(builder, loc, desc.getPtrMemSizesIndex(d));
-        Value hi = genLoad(builder, loc, desc.getMemSizesMemRef(), mz);
+        Value mz = constantIndex(builder, loc, getMemSizesIndex(field));
+        Value hi = genLoad(builder, loc, fields[memSizesIdx], mz);
         Value zero = constantIndex(builder, loc, 0);
         Value one = constantIndex(builder, loc, 1);
         // Vector of only one, but needed by createFor's prototype.
-        SmallVector<Value, 1> inits{genLoad(builder, loc, ptrMemRef, zero)};
+        SmallVector<Value, 1> inits{genLoad(builder, loc, fields[field], zero)};
         scf::ForOp loop = createFor(builder, loc, hi, inits, one);
         Value i = loop.getInductionVar();
         Value oldv = loop.getRegionIterArg(0);
-        Value newv = genLoad(builder, loc, ptrMemRef, i);
+        Value newv = genLoad(builder, loc, fields[field], i);
         Value ptrZero = constantZero(builder, loc, ptrType);
         Value cond = builder.create<arith::CmpIOp>(
             loc, arith::CmpIPredicate::eq, newv, ptrZero);
         scf::IfOp ifOp = builder.create<scf::IfOp>(loc, TypeRange(ptrType),
                                                    cond, /*else*/ true);
         builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-        genStore(builder, loc, oldv, ptrMemRef, i);
+        genStore(builder, loc, oldv, fields[field], i);
         builder.create<scf::YieldOp>(loc, oldv);
         builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
         builder.create<scf::YieldOp>(loc, newv);
@@ -589,10 +641,14 @@ static void genEndInsert(OpBuilder &builder, Location loc,
         builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
         builder.setInsertionPointAfter(loop);
       }
+      field += 2;
+    } else if (isSingletonDim(rtp, d)) {
+      field++;
     } else {
-      assert(isDenseDim(rtp, d) || isSingletonDim(rtp, d));
+      assert(isDenseDim(rtp, d));
     }
   }
+  assert(fields.size() == ++field);
 }
 
 //===----------------------------------------------------------------------===//
@@ -683,12 +739,12 @@ public:
   matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Optional<int64_t> index = op.getConstantIndex();
-    if (!index || !getSparseTensorEncoding(adaptor.getSource().getType()))
+    if (!index)
       return failure();
-
-    auto desc = getDescriptorFromTensorTuple(adaptor.getSource());
-    auto sz = sizeFromTensorAtDim(rewriter, op.getLoc(), desc, *index);
-
+    auto sz =
+        sizeFromTensorAtDim(rewriter, op.getLoc(),
+                            op.getSource().getType().cast<RankedTensorType>(),
+                            adaptor.getSource(), *index);
     if (!sz)
       return failure();
 
@@ -778,14 +834,16 @@ public:
   LogicalResult
   matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    // Prepare descriptor.
-    SmallVector<Value> fields;
-    auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
+    RankedTensorType srcType =
+        op.getTensor().getType().cast<RankedTensorType>();
+    auto tuple = getTuple(adaptor.getTensor());
+    // Prepare fields.
+    SmallVector<Value> fields(tuple.getInputs());
     // Generate optional insertion finalization code.
     if (op.getHasInserts())
-      genEndInsert(rewriter, op.getLoc(), desc);
+      genEndInsert(rewriter, op.getLoc(), srcType, fields);
     // Replace operation with resulting memrefs.
-    rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), desc));
+    rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), srcType, fields));
     return success();
   }
 };
@@ -797,10 +855,7 @@ public:
   LogicalResult
   matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    if (!getSparseTensorEncoding(op.getTensor().getType()))
-      return failure();
     Location loc = op->getLoc();
-    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
     RankedTensorType srcType =
         op.getTensor().getType().cast<RankedTensorType>();
     Type eltType = srcType.getElementType();
@@ -812,7 +867,8 @@ public:
     // dimension size, translated back to original dimension). Note that we
     // recursively rewrite the new DimOp on the **original** tensor.
     unsigned innerDim = toOrigDim(srcType, srcType.getRank() - 1);
-    auto sz = sizeFromTensorAtDim(rewriter, loc, desc, innerDim);
+    auto sz = sizeFromTensorAtDim(rewriter, loc, srcType, adaptor.getTensor(),
+                                  innerDim);
     assert(sz); // This for sure is a sparse tensor
     // Generate a memref for `sz` elements of type `t`.
     auto genAlloc = [&](Type t) {
@@ -852,15 +908,16 @@ public:
   matchAndRewrite(CompressOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
-    SmallVector<Value> fields;
-    auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
+    RankedTensorType dstType =
+        op.getTensor().getType().cast<RankedTensorType>();
+    Type eltType = dstType.getElementType();
+    auto tuple = getTuple(adaptor.getTensor());
     Value values = adaptor.getValues();
     Value filled = adaptor.getFilled();
     Value added = adaptor.getAdded();
     Value count = adaptor.getCount();
-    RankedTensorType dstType = desc.getTensorType();
-    Type eltType = dstType.getElementType();
-    // Prepare indices.
+    // Prepare fields and indices.
+    SmallVector<Value> fields(tuple.getInputs());
     SmallVector<Value> indices(adaptor.getIndices());
     // If the innermost dimension is ordered, we need to sort the indices
     // in the "added" array prior to applying the compression.
@@ -882,19 +939,19 @@ public:
     //      filled[index] = false;
     //      yield new_memrefs
     //    }
-    scf::ForOp loop = createFor(rewriter, loc, count, desc.getFields());
+    scf::ForOp loop = createFor(rewriter, loc, count, fields);
     Value i = loop.getInductionVar();
     Value index = genLoad(rewriter, loc, added, i);
     Value value = genLoad(rewriter, loc, values, index);
     indices.push_back(index);
     // TODO: faster for subsequent insertions?
     auto insertPoint = op->template getParentOfType<func::FuncOp>();
-    genInsertionCallHelper(rewriter, desc, indices, value, insertPoint,
-                           kInsertFuncNamePrefix, genInsertBody);
+    genInsertionCallHelper(rewriter, dstType, fields, indices, value,
+                           insertPoint, kInsertFuncNamePrefix, genInsertBody);
     genStore(rewriter, loc, constantZero(rewriter, loc, eltType), values,
              index);
     genStore(rewriter, loc, constantI1(rewriter, loc, false), filled, index);
-    rewriter.create<scf::YieldOp>(loc, desc.getFields());
+    rewriter.create<scf::YieldOp>(loc, fields);
     rewriter.setInsertionPointAfter(loop);
     Value result = genTuple(rewriter, loc, dstType, loop->getResults());
     // Deallocate the buffers on exit of the full loop nest.
@@ -916,18 +973,20 @@ public:
   LogicalResult
   matchAndRewrite(InsertOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    SmallVector<Value> fields;
-    auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
-    // Prepare and indices.
+    RankedTensorType dstType =
+        op.getTensor().getType().cast<RankedTensorType>();
+    auto tuple = getTuple(adaptor.getTensor());
+    // Prepare fields and indices.
+    SmallVector<Value> fields(tuple.getInputs());
     SmallVector<Value> indices(adaptor.getIndices());
     // Generate insertion.
     Value value = adaptor.getValue();
     auto insertPoint = op->template getParentOfType<func::FuncOp>();
-    genInsertionCallHelper(rewriter, desc, indices, value, insertPoint,
-                           kInsertFuncNamePrefix, genInsertBody);
+    genInsertionCallHelper(rewriter, dstType, fields, indices, value,
+                           insertPoint, kInsertFuncNamePrefix, genInsertBody);
 
     // Replace operation with resulting memrefs.
-    rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), desc));
+    rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), dstType, fields));
     return success();
   }
 };
@@ -944,9 +1003,11 @@ public:
     // Replace the requested pointer access with corresponding field.
     // The cast_op is inserted by type converter to intermix 1:N type
     // conversion.
-    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
-    Value field = Base::getFieldForOp(desc, op);
-    rewriter.replaceOp(op, field);
+    auto tuple = getTuple(adaptor.getTensor());
+    unsigned idx = Base::getIndexForOp(tuple, op);
+    auto fields = tuple.getInputs();
+    assert(idx < fields.size());
+    rewriter.replaceOp(op, fields[idx]);
     return success();
   }
 };
@@ -957,10 +1018,10 @@ class SparseToPointersConverter
 public:
   using SparseGetterOpConverter::SparseGetterOpConverter;
   // Callback for SparseGetterOpConverter.
-  static Value getFieldForOp(const SparseTensorDescriptor &desc,
-                             ToPointersOp op) {
+  static unsigned getIndexForOp(UnrealizedConversionCastOp /*tuple*/,
+                                ToPointersOp op) {
     uint64_t dim = op.getDimension().getZExtValue();
-    return desc.getPtrMemRef(dim);
+    return getFieldIndex(op.getTensor().getType(), /*ptrDim=*/dim, -1u);
   }
 };
 
@@ -970,10 +1031,10 @@ class SparseToIndicesConverter
 public:
   using SparseGetterOpConverter::SparseGetterOpConverter;
   // Callback for SparseGetterOpConverter.
-  static Value getFieldForOp(const SparseTensorDescriptor &desc,
-                             ToIndicesOp op) {
+  static unsigned getIndexForOp(UnrealizedConversionCastOp /*tuple*/,
+                                ToIndicesOp op) {
     uint64_t dim = op.getDimension().getZExtValue();
-    return desc.getIdxMemRef(dim);
+    return getFieldIndex(op.getTensor().getType(), -1u, /*idxDim=*/dim);
   }
 };
 
@@ -983,9 +1044,10 @@ class SparseToValuesConverter
 public:
   using SparseGetterOpConverter::SparseGetterOpConverter;
   // Callback for SparseGetterOpConverter.
-  static Value getFieldForOp(const SparseTensorDescriptor &desc,
-                             ToValuesOp /*op*/) {
-    return desc.getValMemRef();
+  static unsigned getIndexForOp(UnrealizedConversionCastOp tuple,
+                                ToValuesOp /*op*/) {
+    // The last field holds the value buffer.
+    return tuple.getInputs().size() - 1;
   }
 };
 
@@ -1017,11 +1079,12 @@ public:
   matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Query memSizes for the actually stored values size.
-    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+    auto tuple = getTuple(adaptor.getTensor());
+    auto fields = tuple.getInputs();
+    unsigned lastField = fields.size() - 1;
     Value field =
-        constantIndex(rewriter, op.getLoc(), desc.getValMemSizesIndex());
-    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, desc.getMemSizesMemRef(),
-                                                field);
+        constantIndex(rewriter, op.getLoc(), getMemSizesIndex(lastField));
+    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, fields[memSizesIdx], field);
     return success();
   }
 };