[mlir][sparse] Move some member functions from SparseTensorDescriptorImpl to MutSpars...
authorbixia1 <bixia@google.com>
Wed, 4 Jan 2023 19:35:45 +0000 (11:35 -0800)
committerbixia1 <bixia@google.com>
Wed, 4 Jan 2023 21:05:43 +0000 (13:05 -0800)
This is to prepare for implementing AOS optimization.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D141002

mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h

index 7228a23..ff53640 100644 (file)
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "CodegenUtils.h"
+#include "SparseTensorStorageLayout.h"
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
@@ -551,4 +552,11 @@ Value sparse_tensor::genToValues(OpBuilder &builder, Location loc,
   Type valTp = get1DMemRefType(srcTp.getElementType(),
                                /*withLayout=*/false);
   return builder.create<ToValuesOp>(loc, valTp, tensor);
+}
+
+Value sparse_tensor::genValMemSize(OpBuilder &builder, Location loc,
+                                   Value tensor) {
+  SmallVector<Value> fields;
+  auto desc = getMutDescriptorFromTensorTuple(tensor, fields);
+  return desc.getValMemSize(builder, loc);
 }
\ No newline at end of file
index 4ec9c25..1c8cad5 100644 (file)
@@ -336,6 +336,9 @@ Value genToIndices(OpBuilder &builder, Location loc, Value tensor, uint64_t d,
 /// Infers the result type and generates ToValuesOp.
 Value genToValues(OpBuilder &builder, Location loc, Value tensor);
 
+/// Generates code to retrieve the values size for the sparse tensor.
+Value genValMemSize(OpBuilder &builder, Location loc, Value tensor);
+
 } // namespace sparse_tensor
 } // namespace mlir
 
index d138b6d..bb5128b 100644 (file)
@@ -967,9 +967,9 @@ public:
   LogicalResult
   matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    // Query memSizes for the actually stored values size.
-    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
-    rewriter.replaceOp(op, desc.getValMemSize(rewriter, op.getLoc()));
+    // Query memSizes for the actually stored values.
+    rewriter.replaceOp(
+        op, genValMemSize(rewriter, op.getLoc(), adaptor.getTensor()));
     return success();
   }
 };
index 45aac6d..6124445 100644 (file)
@@ -154,7 +154,7 @@ private:
 /// instead relies on this class to access the right value for the right field.
 template <bool mut>
 class SparseTensorDescriptorImpl {
-private:
+protected:
   // Uses ValueRange for immuatable descriptors; uses SmallVectorImpl<Value> &
   // for mutable descriptors.
   // Using SmallVector for mutable descriptor allows users to reuse it as a tmp
@@ -220,21 +220,6 @@ public:
     return getSpecifierField(builder, loc, StorageSpecifierKind::DimSize, dim);
   }
 
-  Value getPtrMemSize(OpBuilder &builder, Location loc, unsigned dim) const {
-    return getSpecifierField(builder, loc, StorageSpecifierKind::PtrMemSize,
-                             dim);
-  }
-
-  Value getIdxMemSize(OpBuilder &builder, Location loc, unsigned dim) const {
-    return getSpecifierField(builder, loc, StorageSpecifierKind::IdxMemSize,
-                             dim);
-  }
-
-  Value getValMemSize(OpBuilder &builder, Location loc) const {
-    return getSpecifierField(builder, loc, StorageSpecifierKind::ValMemSize,
-                             std::nullopt);
-  }
-
   Value getPtrMemRef(unsigned ptrDim) const {
     return getMemRefField(SparseTensorFieldKind::PtrMemRef, ptrDim);
   }
@@ -262,25 +247,70 @@ public:
     return fields[fidx];
   }
 
+  ValueRange getMemRefFields() const {
+    ValueRange ret = fields;
+    // Drop the last metadata fields.
+    return ret.slice(0, fields.size() - 1);
+  }
+
+  Type getMemRefElementType(SparseTensorFieldKind kind,
+                            Optional<unsigned> dim) const {
+    return getMemRefField(kind, dim)
+        .getType()
+        .template cast<MemRefType>()
+        .getElementType();
+  }
+
+  RankedTensorType getTensorType() const { return rType; }
+  ValueArrayRef getFields() const { return fields; }
+
+protected:
+  RankedTensorType rType;
+  ValueArrayRef fields;
+};
+
+class MutSparseTensorDescriptor : public SparseTensorDescriptorImpl<true> {
+public:
+  MutSparseTensorDescriptor(Type tp, ValueArrayRef buffers)
+      : SparseTensorDescriptorImpl<true>(tp, buffers) {}
+
+  ///
+  /// Getters: get the value for required field.
+  ///
+
+  Value getPtrMemSize(OpBuilder &builder, Location loc, unsigned dim) const {
+    return getSpecifierField(builder, loc, StorageSpecifierKind::PtrMemSize,
+                             dim);
+  }
+
+  Value getIdxMemSize(OpBuilder &builder, Location loc, unsigned dim) const {
+    return getSpecifierField(builder, loc, StorageSpecifierKind::IdxMemSize,
+                             dim);
+  }
+
+  Value getValMemSize(OpBuilder &builder, Location loc) const {
+    return getSpecifierField(builder, loc, StorageSpecifierKind::ValMemSize,
+                             std::nullopt);
+  }
+
   ///
   /// Setters: update the value for required field (only enabled for
   /// MutSparseTensorDescriptor).
   ///
 
   template <typename T = Value>
-  void setMemRefField(SparseTensorFieldKind kind, Optional<unsigned> dim,
-                      std::enable_if_t<mut, T> v) {
+  void setMemRefField(SparseTensorFieldKind kind, Optional<unsigned> dim, T v) {
     fields[getMemRefFieldIndex(kind, dim)] = v;
   }
 
   template <typename T = Value>
-  void setMemRefField(unsigned fidx, std::enable_if_t<mut, T> v) {
+  void setMemRefField(unsigned fidx, T v) {
     assert(fidx < fields.size() - 1);
     fields[fidx] = v;
   }
 
   template <typename T = Value>
-  void setField(unsigned fidx, std::enable_if_t<mut, T> v) {
+  void setField(unsigned fidx, T v) {
     assert(fidx < fields.size());
     fields[fidx] = v;
   }
@@ -288,42 +318,19 @@ public:
   template <typename T = Value>
   void setSpecifierField(OpBuilder &builder, Location loc,
                          StorageSpecifierKind kind, Optional<unsigned> dim,
-                         std::enable_if_t<mut, T> v) {
+                         T v) {
     SparseTensorSpecifier md(fields.back());
     md.setSpecifierField(builder, loc, v, kind, dim);
     fields.back() = md;
   }
 
   template <typename T = Value>
-  void setDimSize(OpBuilder &builder, Location loc, unsigned dim,
-                  std::enable_if_t<mut, T> v) {
+  void setDimSize(OpBuilder &builder, Location loc, unsigned dim, T v) {
     setSpecifierField(builder, loc, StorageSpecifierKind::DimSize, dim, v);
   }
-
-  ValueRange getMemRefFields() const {
-    ValueRange ret = fields;
-    // drop the last metadata fields
-    return ret.slice(0, fields.size() - 1);
-  }
-
-  Type getMemRefElementType(SparseTensorFieldKind kind,
-                            Optional<unsigned> dim) const {
-    return getMemRefField(kind, dim)
-        .getType()
-        .template cast<MemRefType>()
-        .getElementType();
-  }
-
-  RankedTensorType getTensorType() const { return rType; }
-  ValueArrayRef getFields() const { return fields; }
-
-private:
-  RankedTensorType rType;
-  ValueArrayRef fields;
 };
 
 using SparseTensorDescriptor = SparseTensorDescriptorImpl<false>;
-using MutSparseTensorDescriptor = SparseTensorDescriptorImpl<true>;
 
 /// Returns the "tuple" value of the adapted tensor.
 inline UnrealizedConversionCastOp getTuple(Value tensor) {