[mlir][sparse] avoid using mutable descriptor when unnecessary (NFC)
authorPeiming Liu <peiming@google.com>
Tue, 17 Jan 2023 19:25:40 +0000 (19:25 +0000)
committerPeiming Liu <peiming@google.com>
Tue, 17 Jan 2023 20:54:27 +0000 (20:54 +0000)
Use SparseTensorDescriptor whenever not calling setters, to avoid needing to create a temporal buffer for simple query purposes.

Reviewed By: bixia, wrengr

Differential Revision: https://reviews.llvm.org/D141953

mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h

index 9466f6f..f47d304 100644 (file)
@@ -593,7 +593,5 @@ Value sparse_tensor::genToValues(OpBuilder &builder, Location loc,
 
 Value sparse_tensor::genValMemSize(OpBuilder &builder, Location loc,
                                    Value tensor) {
-  SmallVector<Value> fields;
-  auto desc = getMutDescriptorFromTensorTuple(tensor, fields);
-  return desc.getValMemSize(builder, loc);
-}
\ No newline at end of file
+  return getDescriptorFromTensorTuple(tensor).getValMemSize(builder, loc);
+}
index 975403e..4a1a0c9 100644 (file)
@@ -102,11 +102,9 @@ static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper,
 }
 
 /// Gets the dimension size for the given sparse tensor at the given
-/// original dimension 'dim'. Returns std::nullopt if no sparse encoding is
-/// attached to the given tensor type.
-static std::optional<Value>
-sizeFromTensorAtDim(OpBuilder &builder, Location loc,
-                    const SparseTensorDescriptor &desc, unsigned dim) {
+/// original dimension 'dim'.
+static Value sizeFromTensorAtDim(OpBuilder &builder, Location loc,
+                                 SparseTensorDescriptor desc, unsigned dim) {
   RankedTensorType rtp = desc.getTensorType();
   // Access into static dimension can query original type directly.
   // Note that this is typically already done by DimOp's folding.
@@ -119,17 +117,12 @@ sizeFromTensorAtDim(OpBuilder &builder, Location loc,
   return desc.getDimSize(builder, loc, toStoredDim(rtp, dim));
 }
 
-// Gets the dimension size at the given stored dimension 'd', either as a
+// Gets the dimension size at the given stored level 'lvl', either as a
 // constant for a static size, or otherwise dynamically through memSizes.
-Value sizeAtStoredDim(OpBuilder &builder, Location loc,
-                      MutSparseTensorDescriptor desc, unsigned d) {
-  RankedTensorType rtp = desc.getTensorType();
-  unsigned dim = toOrigDim(rtp, d);
-  auto shape = rtp.getShape();
-  if (!ShapedType::isDynamic(shape[dim]))
-    return constantIndex(builder, loc, shape[dim]);
-
-  return desc.getDimSize(builder, loc, d);
+static Value sizeFromTensorAtLvl(OpBuilder &builder, Location loc,
+                                 SparseTensorDescriptor desc, unsigned lvl) {
+  return sizeFromTensorAtDim(builder, loc, desc,
+                             toOrigDim(desc.getTensorType(), lvl));
 }
 
 static void createPushback(OpBuilder &builder, Location loc,
@@ -174,7 +167,7 @@ static void allocSchemeForRank(OpBuilder &builder, Location loc,
     // at this level. We will eventually reach a compressed level or
     // otherwise the values array for the from-here "all-dense" case.
     assert(isDenseDim(rtp, r));
-    Value size = sizeAtStoredDim(builder, loc, desc, r);
+    Value size = sizeFromTensorAtLvl(builder, loc, desc, r);
     linear = builder.create<arith::MulIOp>(loc, linear, size);
   }
   // Reached values array so prepare for an insertion.
@@ -436,7 +429,7 @@ static void genInsertBody(OpBuilder &builder, ModuleOp module,
       // Construct the new position as:
       //   pos[d] = size * pos[d-1] + i[d]
       //   <insert @ pos[d] at next dimension d + 1>
-      Value size = sizeAtStoredDim(builder, loc, desc, d);
+      Value size = sizeFromTensorAtLvl(builder, loc, desc, d);
       Value mult = builder.create<arith::MulIOp>(loc, size, pos);
       pos = builder.create<arith::AddIOp>(loc, mult, indices[d]);
     }
@@ -517,7 +510,7 @@ static void genInsertionCallHelper(OpBuilder &builder,
 
 /// Generations insertion finalization code.
 static void genEndInsert(OpBuilder &builder, Location loc,
-                         MutSparseTensorDescriptor desc) {
+                         SparseTensorDescriptor desc) {
   RankedTensorType rtp = desc.getTensorType();
   unsigned rank = rtp.getShape().size();
   for (unsigned d = 0; d < rank; d++) {
@@ -654,10 +647,7 @@ public:
     auto desc = getDescriptorFromTensorTuple(adaptor.getSource());
     auto sz = sizeFromTensorAtDim(rewriter, op.getLoc(), desc, *index);
 
-    if (!sz)
-      return failure();
-
-    rewriter.replaceOp(op, *sz);
+    rewriter.replaceOp(op, sz);
     return success();
   }
 };
@@ -727,8 +717,7 @@ public:
 
     // Replace the sparse tensor deallocation with field deallocations.
     Location loc = op.getLoc();
-    SmallVector<Value> fields;
-    auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
     for (auto input : desc.getMemRefFields())
       // Deallocate every buffer used to store the sparse tensor handler.
       rewriter.create<memref::DeallocOp>(loc, input);
@@ -746,8 +735,7 @@ public:
   matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Prepare descriptor.
-    SmallVector<Value> fields;
-    auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
     // Generate optional insertion finalization code.
     if (op.getHasInserts())
       genEndInsert(rewriter, op.getLoc(), desc);
@@ -780,11 +768,10 @@ public:
     // recursively rewrite the new DimOp on the **original** tensor.
     unsigned innerDim = toOrigDim(srcType, srcType.getRank() - 1);
     auto sz = sizeFromTensorAtDim(rewriter, loc, desc, innerDim);
-    assert(sz); // This for sure is a sparse tensor
     // Generate a memref for `sz` elements of type `t`.
     auto genAlloc = [&](Type t) {
       auto memTp = MemRefType::get({ShapedType::kDynamic}, t);
-      return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{*sz});
+      return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
     };
     // Allocate temporary buffers for values/filled-switch and added.
     // We do not use stack buffers for this, since the expanded size may
@@ -957,8 +944,7 @@ public:
     // Replace the requested pointer access with corresponding field.
     // The cast_op is inserted by type converter to intermix 1:N type
     // conversion.
-    SmallVector<Value> fields;
-    auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
     rewriter.replaceOp(op, desc.getAOSMemRef());
 
     return success();
index 9ca7149..7ffb2c8 100644 (file)
@@ -202,20 +202,9 @@ private:
 /// field in a consistent way.
 /// Users should not make assumption on how a sparse tensor is laid out but
 /// instead relies on this class to access the right value for the right field.
-template <bool mut>
+template <typename ValueArrayRef>
 class SparseTensorDescriptorImpl {
 protected:
-  // Uses ValueRange for immuatable descriptors; uses SmallVectorImpl<Value> &
-  // for mutable descriptors.
-  // Using SmallVector for mutable descriptor allows users to reuse it as a tmp
-  // buffers to append value for some special cases, though users should be
-  // responsible to restore the buffer to legal states after their use. It is
-  // probably not a clean way, but it is the most efficient way to avoid copying
-  // the fields into another SmallVector. If a more clear way is wanted, we
-  // should change it to MutableArrayRef instead.
-  using ValueArrayRef = typename std::conditional<mut, SmallVectorImpl<Value> &,
-                                                  ValueRange>::type;
-
   SparseTensorDescriptorImpl(Type tp, ValueArrayRef fields)
       : rType(tp.cast<RankedTensorType>()), fields(fields) {
     assert(getSparseTensorEncoding(tp) &&
@@ -223,8 +212,8 @@ protected:
                fields.size());
     // We should make sure the class is trivially copyable (and should be small
     // enough) such that we can pass it by value.
-    static_assert(
-        std::is_trivially_copyable_v<SparseTensorDescriptorImpl<mut>>);
+    static_assert(std::is_trivially_copyable_v<
+                  SparseTensorDescriptorImpl<ValueArrayRef>>);
   }
 
 public:
@@ -262,12 +251,12 @@ public:
 
   Value getMemRefField(SparseTensorFieldKind kind,
                        std::optional<unsigned> dim) const {
-    return fields[getMemRefFieldIndex(kind, dim)];
+    return getField(getMemRefFieldIndex(kind, dim));
   }
 
   Value getMemRefField(unsigned fidx) const {
     assert(fidx < fields.size() - 1);
-    return fields[fidx];
+    return getField(fidx);
   }
 
   Value getPtrMemSize(OpBuilder &builder, Location loc, unsigned dim) const {
@@ -293,6 +282,31 @@ public:
         .getElementType();
   }
 
+  Value getField(unsigned fidx) const {
+    assert(fidx < fields.size());
+    return fields[fidx];
+  }
+
+  ValueRange getMemRefFields() const {
+    ValueRange ret = fields;
+    // Drop the last metadata fields.
+    return ret.slice(0, fields.size() - 1);
+  }
+
+  std::pair<unsigned, unsigned>
+  getIdxMemRefIndexAndStride(unsigned idxDim) const {
+    StorageLayout layout(getSparseTensorEncoding(rType));
+    return layout.getFieldIndexAndStride(SparseTensorFieldKind::IdxMemRef,
+                                         idxDim);
+  }
+
+  Value getAOSMemRef() const {
+    auto enc = getSparseTensorEncoding(rType);
+    unsigned cooStart = getCOOStart(enc);
+    assert(cooStart < enc.getDimLevelType().size());
+    return getMemRefField(SparseTensorFieldKind::IdxMemRef, cooStart);
+  }
+
   RankedTensorType getTensorType() const { return rType; }
   ValueArrayRef getFields() const { return fields; }
 
@@ -301,25 +315,38 @@ protected:
   ValueArrayRef fields;
 };
 
-class MutSparseTensorDescriptor : public SparseTensorDescriptorImpl<true> {
+/// Uses ValueRange for immuatable descriptors;
+class SparseTensorDescriptor : public SparseTensorDescriptorImpl<ValueRange> {
 public:
-  MutSparseTensorDescriptor(Type tp, ValueArrayRef buffers)
-      : SparseTensorDescriptorImpl<true>(tp, buffers) {}
+  SparseTensorDescriptor(Type tp, ValueRange buffers)
+      : SparseTensorDescriptorImpl<ValueRange>(tp, buffers) {}
 
-  Value getField(unsigned fidx) const {
-    assert(fidx < fields.size());
-    return fields[fidx];
-  }
+  Value getIdxMemRefOrView(OpBuilder &builder, Location loc,
+                           unsigned idxDim) const;
+};
 
-  ValueRange getMemRefFields() const {
-    ValueRange ret = fields;
-    // Drop the last metadata fields.
-    return ret.slice(0, fields.size() - 1);
+/// Uses SmallVectorImpl<Value> & for mutable descriptors.
+/// Using SmallVector for mutable descriptor allows users to reuse it as a
+/// tmp buffers to append value for some special cases, though users should
+/// be responsible to restore the buffer to legal states after their use. It
+/// is probably not a clean way, but it is the most efficient way to avoid
+/// copying the fields into another SmallVector. If a more clear way is
+/// wanted, we should change it to MutableArrayRef instead.
+class MutSparseTensorDescriptor
+    : public SparseTensorDescriptorImpl<SmallVectorImpl<Value> &> {
+public:
+  MutSparseTensorDescriptor(Type tp, SmallVectorImpl<Value> &buffers)
+      : SparseTensorDescriptorImpl<SmallVectorImpl<Value> &>(tp, buffers) {}
+
+  // Allow implicit type conversion from mutable descriptors to immutable ones
+  // (but not vice versa).
+  /*implicit*/ operator SparseTensorDescriptor() const {
+    return SparseTensorDescriptor(rType, fields);
   }
 
   ///
-  /// Setters: update the value for required field (only enabled for
-  /// MutSparseTensorDescriptor).
+  /// Adds additional setters for mutable descriptor, update the value for
+  /// required field.
   ///
 
   void setMemRefField(SparseTensorFieldKind kind, std::optional<unsigned> dim,
@@ -348,29 +375,6 @@ public:
   void setDimSize(OpBuilder &builder, Location loc, unsigned dim, Value v) {
     setSpecifierField(builder, loc, StorageSpecifierKind::DimSize, dim, v);
   }
-
-  std::pair<unsigned, unsigned>
-  getIdxMemRefIndexAndStride(unsigned idxDim) const {
-    StorageLayout layout(getSparseTensorEncoding(rType));
-    return layout.getFieldIndexAndStride(SparseTensorFieldKind::IdxMemRef,
-                                         idxDim);
-  }
-
-  Value getAOSMemRef() const {
-    auto enc = getSparseTensorEncoding(rType);
-    unsigned cooStart = getCOOStart(enc);
-    assert(cooStart < enc.getDimLevelType().size());
-    return getMemRefField(SparseTensorFieldKind::IdxMemRef, cooStart);
-  }
-};
-
-class SparseTensorDescriptor : public SparseTensorDescriptorImpl<false> {
-public:
-  SparseTensorDescriptor(Type tp, ValueArrayRef buffers)
-      : SparseTensorDescriptorImpl<false>(tp, buffers) {}
-
-  Value getIdxMemRefOrView(OpBuilder &builder, Location loc,
-                           unsigned idxDim) const;
 };
 
 /// Returns the "tuple" value of the adapted tensor.
@@ -386,7 +390,7 @@ inline Value genTuple(OpBuilder &builder, Location loc, Type tp,
 }
 
 inline Value genTuple(OpBuilder &builder, Location loc,
-                      MutSparseTensorDescriptor desc) {
+                      SparseTensorDescriptor desc) {
   return genTuple(builder, loc, desc.getTensorType(), desc.getFields());
 }