Unify SparseTensorImpl::size_ and TensorImpl::sizes_
authorVitaly Fedyunin <vitalyf@fb.com>
Thu, 13 Dec 2018 16:53:16 +0000 (08:53 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 13 Dec 2018 16:55:35 +0000 (08:55 -0800)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/15130

Differential Revision: D13434981

Pulled By: VitalyFedyunin

fbshipit-source-id: 98bd4d66834a3c3d2ea577adb0c8413852da095d

aten/src/ATen/SparseTensorImpl.cpp
aten/src/ATen/SparseTensorImpl.h

index 47665c4..0cc5456 100644 (file)
@@ -32,14 +32,13 @@ namespace {
 // values tensor for such an empty tensor.
 SparseTensorImpl::SparseTensorImpl(at::TensorTypeId type_id, const caffe2::TypeMeta& data_type)
     : TensorImpl(type_id, data_type, nullptr, false)
-    , size_{0}
     , sparse_dim_(1)
     , dense_dim_(0)
     , indices_(at::empty({1, 0}, at::initialTensorOptions().device(sparseTensorIdToDeviceType(type_id)).dtype(ScalarType::Long)))
     , values_(at::empty({0}, at::initialTensorOptions().device(sparseTensorIdToDeviceType(type_id)).dtype(data_type))) {}
 
 IntList SparseTensorImpl::sizes() const {
-  return size_;
+  return sizes_;
 }
 IntList SparseTensorImpl::strides() const {
   AT_ERROR("sparse tensors do not have strides");
@@ -47,10 +46,6 @@ IntList SparseTensorImpl::strides() const {
 bool SparseTensorImpl::is_contiguous() const {
   AT_ERROR("sparse tensors do not have is_contiguous");
 }
-int64_t SparseTensorImpl::size(int64_t d) const {
-  d = at::maybe_wrap_dim(d, dim(), false);
-  return size_[d];
-}
 int64_t SparseTensorImpl::stride(int64_t d) const {
   AT_ERROR("sparse tensors do not have strides");
 }
index 7d8baf5..3ec436a 100644 (file)
@@ -14,11 +14,6 @@ struct CAFFE2_API SparseTensorImpl : public TensorImpl {
   // _indices.shape: dimensionality: 2,  shape: (sparse_dim, nnz)
   // _values.shape:  dimensionality: 1 + dense_dim.  shape: (nnz, shape[sparse_dim:])
 
-  // The true size of the sparse tensor (e.g., if you called to_dense()
-  // on it).  When THTensor merges into TensorImpl, this field
-  // should move to the parent class.
-  std::vector<int64_t> size_;
-
   int64_t sparse_dim_ = 0; // number of sparse dimensions
   int64_t dense_dim_ = 0; // number of dense dimensions
 
@@ -48,7 +43,6 @@ public:
   IntList sizes() const override;
   IntList strides() const override;
   bool is_contiguous() const override;
-  int64_t size(int64_t d) const override;
   int64_t stride(int64_t d) const override;
   void resize_dim(int64_t ndim) override;
   void set_size(int64_t dim, int64_t new_size) override;
@@ -63,7 +57,7 @@ public:
   // WARNING: This function does NOT preserve invariants of sparse_dim/dense_dim with
   // respect to indices and values
   void raw_resize_(int64_t sparse_dim, int64_t dense_dim, IntList size) {
-    size_ = size.vec();
+    sizes_ = size.vec();
     sparse_dim_ = sparse_dim;
     dense_dim_ = dense_dim;
     refresh_numel();
@@ -132,7 +126,7 @@ public:
         "shrinking the size of dense dimensions (from ", dense_size_original, " to ", dense_size_new, ") on a non-empty sparse tensor is not supported.\n", alt_options_msg);
     }
 
-    if ((!size.equals(size_)) || (sparse_dim != sparse_dim_) || (dense_dim != dense_dim_)) {
+    if ((!size.equals(sizes_)) || (sparse_dim != sparse_dim_) || (dense_dim != dense_dim_)) {
       auto nnz = values().size(0);
       std::vector<int64_t> values_size = {nnz};
       auto dense_size = size.slice(sparse_dim);
@@ -141,7 +135,7 @@ public:
       indices_.resize_({sparse_dim, nnz});
     }
 
-    size_ = size.vec();
+    sizes_ = size.vec();
     sparse_dim_ = sparse_dim;
     dense_dim_ = dense_dim;
     refresh_numel();
@@ -151,7 +145,7 @@ public:
   void resize_and_clear_(int64_t sparse_dim, int64_t dense_dim, IntList size) {
     AT_CHECK(sparse_dim + dense_dim == size.size(), "number of dimensions must be sparse_dim (", sparse_dim, ") + dense_dim (", dense_dim, "), but got ", size.size());
 
-    size_ = size.vec();
+    sizes_ = size.vec();
     sparse_dim_ = sparse_dim;
     dense_dim_ = dense_dim;