Remove Type::elementSizeInBytes
authorRoy Li <royboy@fb.com>
Fri, 15 Mar 2019 19:52:57 +0000 (12:52 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 15 Mar 2019 19:56:02 +0000 (12:56 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17785

Reviewed By: ezyang

Differential Revision: D14379074

fbshipit-source-id: 60727f187d61eb571b144bd6eed4dd4908da0b51

30 files changed:
aten/src/ATen/DLConvertor.cpp
aten/src/ATen/UndefinedType.cpp
aten/src/ATen/UndefinedType.h
aten/src/ATen/core/Type.h
aten/src/ATen/native/Indexing.cpp
aten/src/ATen/native/TensorIterator.cpp
aten/src/ATen/native/TensorIterator.h
aten/src/ATen/native/cuda/Copy.cu
aten/src/ATen/native/cuda/SpectralOps.cu
aten/src/ATen/templates/SparseTypeDerived.cpp
aten/src/ATen/templates/Type.h
aten/src/ATen/templates/TypeDerived.cpp
aten/src/ATen/templates/TypeDerived.h
aten/src/ATen/templates/TypeExtension.cpp
aten/src/ATen/templates/TypeExtension.h
test/cpp_extensions/complex_registration_extension.cpp
tools/autograd/templates/VariableType.h
torch/csrc/autograd/VariableTypeManual.cpp
torch/csrc/jit/export.cpp
torch/csrc/jit/python_ir.cpp
torch/csrc/jit/register_special_ops.cpp
torch/csrc/utils/tensor_apply.cpp
torch/csrc/utils/tensor_flatten.cpp
torch/csrc/utils/tensor_new.cpp
torch/csrc/utils/tensor_numpy.cpp
torch/lib/THD/base/data_channels/DataChannelGloo.cpp
torch/lib/THD/base/data_channels/DataChannelTCP.cpp
torch/lib/THD/base/data_channels/DataChannelUtils.hpp
torch/lib/THD/base/data_channels/GlooCache.hpp
torch/lib/c10d/ProcessGroupGloo.cpp

index 428ae9e..ce9da56 100644 (file)
@@ -11,7 +11,7 @@ namespace at {
 static DLDataType getDLDataType(const Tensor& t) {
   DLDataType dtype;
   dtype.lanes = 1;
-  dtype.bits = t.dtype().itemsize() * 8;
+  dtype.bits = t.element_size() * 8;
   switch (t.scalar_type()) {
     case ScalarType::Byte:
       dtype.code = DLDataTypeCode::kDLUInt;
index 6c2b452..c608c43 100644 (file)
@@ -47,10 +47,6 @@ TypeID UndefinedType::ID() const {
   return TypeID::Undefined;
 }
 
-size_t UndefinedType::elementSizeInBytes() const {
-  AT_ERROR("elementSizeInBytes not defined for UndefinedType");
-}
-
 Type & UndefinedType::toBackend(Backend b) const {
   if (b == Backend::Undefined) {
     return TypeDefault::toBackend(b);
index 44c3af4..095b6fe 100644 (file)
@@ -22,7 +22,6 @@ struct UndefinedType final : public TypeDefault {
   virtual Storage storageWithAllocator(int64_t size, Allocator* allocator) const override;
   virtual std::unique_ptr<Generator> generator() const override;
   virtual const char * toString() const override;
-  virtual size_t elementSizeInBytes() const override;
   virtual Type & toBackend(Backend b) const override;
   virtual Type & toScalarType(ScalarType s) const override;
   virtual TypeID ID() const override;
index b66c631..e58d8e9 100644 (file)
@@ -127,7 +127,6 @@ struct CAFFE2_API Type {
   virtual Tensor unsafeTensorFromTH(void * th_pointer, bool retain) const = 0;
   virtual Storage unsafeStorageFromTH(void * th_pointer, bool retain) const = 0;
   virtual const char * toString() const = 0;
-  virtual size_t elementSizeInBytes() const = 0;
   virtual Type & toBackend(Backend b) const = 0;
   virtual Type & toScalarType(ScalarType s) const = 0;
   Type & toSparse() const {
index 6ea80ba..c7d478c 100644 (file)
@@ -327,7 +327,7 @@ static Tensor reshape_indexer(const Tensor& index, int64_t dims_before, int64_t
 
 AdvancedIndex::AdvancedIndex(const Tensor& src, TensorList indices_list)
 {
-  int64_t element_size_bytes = src.type().elementSizeInBytes();
+  int64_t element_size_bytes = src.element_size();
   int64_t dims_before = 0, dims_after = 0, dims_indexed = 0;
   IntArrayRef replacement_shape;
   for (size_t dim = 0; dim < indices_list.size(); dim++) {
index 579b8e2..c588614 100644 (file)
@@ -183,7 +183,7 @@ void TensorIterator::allocate_outputs() {
     auto& op = operands_[i];
     if (!op.tensor.defined()) {
       AT_ASSERTM(op.type, "no type for operand", i);
-      int element_size = op.type->elementSizeInBytes();
+      int element_size = op.type->typeMeta().itemsize();
       op.stride_bytes = compatible_stride(element_size);
 
       auto tensor_shape = invert_perm(shape_);
@@ -548,7 +548,7 @@ static DimVector compute_stride(const Tensor& tensor, IntArrayRef shape) {
   int ndim = shape.size();
   auto original_shape = tensor.sizes();
   auto original_stride = tensor.strides();
-  auto element_size_in_bytes = tensor.type().elementSizeInBytes();
+  auto element_size_in_bytes = tensor.element_size();
 
   auto stride = DimVector(ndim, 0);
   auto offset = ndim - original_shape.size();
index 274b5e9..c4a1a0b 100644 (file)
@@ -153,7 +153,7 @@ struct CAFFE2_API TensorIterator {
   }
   ScalarType dtype(int arg=0) const { return type(arg).scalarType(); }
   DeviceType device_type(int arg=0) const { return type(arg).device_type(); }
-  int64_t element_size(int arg) const { return type(arg).elementSizeInBytes(); }
+  int64_t element_size(int arg) const { return type(arg).typeMeta().itemsize(); }
   bool is_scalar(int arg) const;
   bool is_cpu_scalar(int arg) const;
 
index 01c0782..35dfb9b 100644 (file)
@@ -165,7 +165,7 @@ void copy_from_cpu(Tensor& dst, const Tensor& src) {
   AT_CUDA_CHECK(cudaMemcpyAsync(
       dst_contig.data_ptr(),
       src_contig.data_ptr(),
-      src.numel() * src.dtype().itemsize(),
+      src.numel() * src.element_size(),
       cudaMemcpyHostToDevice,
       stream));
   AT_CUDA_CHECK(cudaStreamSynchronize(stream));
@@ -184,7 +184,7 @@ void copy_to_cpu(Tensor& dst, const Tensor& src) {
   AT_CUDA_CHECK(cudaMemcpyAsync(
       dst_contig.data_ptr(),
       src_contig.data_ptr(),
-      src.numel() * src.dtype().itemsize(),
+      src.numel() * src.element_size(),
       cudaMemcpyDeviceToHost,
       stream));
   AT_CUDA_CHECK(cudaStreamSynchronize(stream));
index 19c28ff..c62048f 100644 (file)
@@ -312,7 +312,7 @@ Tensor _fft_cufft(const Tensor& self, int64_t signal_ndim,
   // (see kRoundSmall and kRoundLarge in THCCachingAllocator.cpp), but we do
   // need to check input tensor to make sure that it is not unaligned, e.g.,
   // from a slicing.
-  auto complex_size_bytes = 2 * input.type().elementSizeInBytes();
+  auto complex_size_bytes = 2 * input.element_size();
   if (reinterpret_cast<std::uintptr_t>(input.data_ptr()) % complex_size_bytes != 0) {
     input = input.clone();
     input_was_cloned = true;
index 1aaffe8..43d9b86 100644 (file)
@@ -47,10 +47,6 @@ TypeID ${Type}::ID() const {
   return ${TypeID};
 }
 
-size_t ${Type}::elementSizeInBytes() const {
-  return sizeof(${ScalarType});
-}
-
 ${type_derived_method_definitions}
 
 }
index dfccbbe..52f9eb2 100644 (file)
@@ -76,7 +76,6 @@ struct CAFFE2_API Type {
   virtual Tensor unsafeTensorFromTH(void * th_pointer, bool retain) const = 0;
   virtual Storage unsafeStorageFromTH(void * th_pointer, bool retain) const = 0;
   virtual const char * toString() const = 0;
-  virtual size_t elementSizeInBytes() const = 0;
   virtual Type & toBackend(Backend b) const = 0;
   virtual Type & toScalarType(ScalarType s) const = 0;
   Type & toSparse() const {
index 42cc995..cd30ba1 100644 (file)
@@ -53,10 +53,6 @@ TypeID ${Type}::ID() const {
   return ${TypeID};
 }
 
-size_t ${Type}::elementSizeInBytes() const {
-  return sizeof(${ScalarType});
-}
-
 /* example
 Tensor * ${Type}::add(Tensor & a, Tensor & b) {
   std::cout << "add Tensor with backend ${Backend}\n";
index 50b3541..f0b8ed6 100644 (file)
@@ -22,7 +22,6 @@ struct ${Type} final : public ${DenseBackend}TypeDefault {
   virtual caffe2::TypeMeta typeMeta() const override;
   virtual Backend backend() const override;
   virtual const char * toString() const override;
-  virtual size_t elementSizeInBytes() const override;
   virtual TypeID ID() const override;
 
   // example
index a50fb2f..78c6ad6 100644 (file)
@@ -26,10 +26,6 @@ Backend ${Type}::backend() const {
   return Backend::${Backend};
 }
 
-size_t ${Type}::elementSizeInBytes() const {
-  AT_ERROR("elementSizeInBytes is not implemented for ${Type}");
-}
-
 ${type_method_definitions}
 
 } // namespace at
index a622efb..324d92e 100644 (file)
@@ -36,7 +36,6 @@ struct CAFFE2_API ${Type} : public TypeDefault {
   Device getDeviceFromPtr(void * data) const override;
   std::unique_ptr<Generator> generator() const override;
   virtual Backend backend() const override;
-  virtual size_t elementSizeInBytes() const override;
 
   ${type_method_declarations}
 };
index 423f89c..e4bc446 100644 (file)
@@ -36,7 +36,6 @@ struct CPUComplexFloatType : public at::CPUTypeDefault {
   caffe2::TypeMeta typeMeta() const override;
   Backend backend() const override;
   const char* toString() const override;
-  size_t elementSizeInBytes() const override;
   TypeID ID() const override;
 
   Tensor empty(IntArrayRef size, const TensorOptions & options) const override {
@@ -74,10 +73,6 @@ TypeID CPUComplexFloatType::ID() const {
   return TypeID::CPUComplexFloat;
 }
 
-size_t CPUComplexFloatType::elementSizeInBytes() const {
-  return sizeof(float);
-}
-
 REGISTER_COMPLEX_HOOKS(ComplexHooks);
 
 } // namespace at
index d4b54e8..1b1ded3 100644 (file)
@@ -43,7 +43,6 @@ struct TORCH_API VariableType final : public at::TypeDefault {
   std::unique_ptr<at::Generator> generator() const override;
   const char * toString() const override;
   at::TypeID ID() const override;
-  size_t elementSizeInBytes() const override;
   at::Type & toBackend(at::Backend b) const override;
   at::Type & toScalarType(at::ScalarType s) const override;
   Storage unsafeStorageFromTH(void * th_pointer, bool retain) const override;
index 2786b35..640a8d5 100644 (file)
@@ -50,9 +50,6 @@ std::unique_ptr<Generator> VariableType::generator() const {
 const char * VariableType::toString() const {
   return str.c_str();
 }
-size_t VariableType::elementSizeInBytes() const {
-  return baseType->elementSizeInBytes();
-}
 Type & VariableType::toBackend(Backend b) const {
   return *getVariableTypeFromBaseType(baseType->toBackend(b));
 }
index 9511852..2607041 100644 (file)
@@ -448,7 +448,7 @@ void GraphEncoder::EncodeTensor(
     AT_ASSERT(t.is_contiguous());
     tensor_proto->set_raw_data(std::string(
         static_cast<char*>(t.data_ptr()),
-        t.type().elementSizeInBytes() * t.numel()));
+        t.element_size() * t.numel()));
   }
 }
 
@@ -650,7 +650,7 @@ void ScriptModuleSerializer::convertAndWriteTensor(
   tensor_proto->set_requires_grad(tensor.requires_grad());
 
   uint64_t record_size =
-      tensor.type().elementSizeInBytes() * tensor.storage().size();
+      tensor.element_size() * tensor.storage().size();
   auto* key = tensor.storage().unsafeGetStorageImpl();
 
   auto storage_it = storageMap.find(key);
@@ -670,7 +670,7 @@ void ScriptModuleSerializer::convertAndWriteTensor(
                                /* stride = */ {1})
                            .cpu();
       AT_ASSERT(
-          storage_tensor.type().elementSizeInBytes() *
+          storage_tensor.element_size() *
               storage_tensor.storage().size() ==
           record_size);
     }
index 955db63..cf885bb 100644 (file)
@@ -237,7 +237,7 @@ void initPythonIRBindings(PyObject* module_) {
                 python_serialized_export_map;
             for (auto& kv : export_map) {
               auto t = kv.second;
-              size_t copy_bytes = t.type().elementSizeInBytes() * t.numel();
+              size_t copy_bytes = t.element_size() * t.numel();
               // TODO: this is an unecessary copy. In theory we can directly
               // return the map from identifier to Tensor, but we need some API
               // in Python to get raw `bytes` containing the raw tensor data.
index 7c7c2ed..1979a03 100644 (file)
@@ -318,7 +318,7 @@ DEFINE_TORCH_TENSOR_OP(bool, bool, at::empty({}, at::CPU(at::kByte).options()).f
             at::empty(sizes, at::initialTensorOptions().dtype(initial_scalar_type)));
 
           recursiveStore((char*)tensor.data_ptr(), sizes, tensor.strides(), 0,
-              tensor.type().elementSizeInBytes(), data);
+              tensor.element_size(), data);
 
           at::ScalarType scalar_type = dtype.isNone() ? tensor.scalar_type() : dtype.toScalarType();
           c10::Device dev = device.isNone() ? tensor.device() : device.toDevice();
index 6bf25bd..1eae084 100644 (file)
@@ -15,7 +15,7 @@ struct StridedData {
   StridedData(const Tensor & tensor)
     : data(tensor.data_ptr())
     , strides(tensor.strides())
-    , elementSize(tensor.type().elementSizeInBytes()) {}
+    , elementSize(tensor.element_size()) {}
 
   void* data;
   IntArrayRef strides;
index cac98e7..37c98f3 100644 (file)
@@ -23,10 +23,10 @@ std::vector<TensorGroup> take_tensors(
     if (type.is_sparse()) {
       const auto& indices = tensor._indices();
       const auto& values = tensor._values();
-      tensor_size = indices.numel() * indices.type().elementSizeInBytes() +
-                    values.numel() * indices.type().elementSizeInBytes();
+      tensor_size = indices.numel() * indices.element_size() +
+                    values.numel() * indices.element_size();
     } else {
-      tensor_size = tensor.numel() * type.elementSizeInBytes();
+      tensor_size = tensor.numel() * tensor.element_size();
     }
 
     auto& type_group = groups[type.ID()];
index c1c1eca..6339cc2 100644 (file)
@@ -228,7 +228,7 @@ Tensor internal_new_from_data(
   auto tensor = autograd::make_variable(at::empty(sizes, at::initialTensorOptions().dtype(scalar_type)), /*requires_grad=*/false);
   recursive_store(
       (char*)tensor.data_ptr(), tensor.sizes(), tensor.strides(), 0,
-      scalar_type, tensor.type().elementSizeInBytes(), data);
+      scalar_type, tensor.element_size(), data);
   auto device = device_opt.has_value() ? *device_opt : at::Device(torch::getDeviceType(type));
   AutoNoGIL no_gil;
   maybe_initialize_cuda(device);
index 4549971..fa0cb54 100644 (file)
@@ -69,7 +69,7 @@ PyObject* tensor_to_numpy(const at::Tensor& tensor) {
   auto sizes = to_numpy_shape(tensor.sizes());
   auto strides = to_numpy_shape(tensor.strides());
   // NumPy strides use bytes. Torch strides use element counts.
-  auto element_size_in_bytes = tensor.type().elementSizeInBytes();
+  auto element_size_in_bytes = tensor.element_size();
   for (auto& stride : strides) {
     stride *= element_size_in_bytes;
   }
index 50595b8..8ba170c 100644 (file)
@@ -172,7 +172,7 @@ void DataChannelGloo::allGatherT(
           "allGather got input and output on different devices");
     }
   }
-  uint64_t tensor_bytes = input.type().elementSizeInBytes() * input.numel();
+  uint64_t tensor_bytes = input.element_size() * input.numel();
   uint64_t all_tensor_bytes = tensor_bytes * output.size();
   auto ret = _cache->getAlgorithm<CollectiveType::ALL_GATHER, T>(
       group_id,
@@ -236,7 +236,7 @@ void DataChannelGloo::allReduceT(
     at::Tensor& t,
     THDReduceOp operation,
     THDGroup group_id) {
-  uint64_t tensor_bytes = t.type().elementSizeInBytes() * t.numel();
+  uint64_t tensor_bytes = t.element_size() * t.numel();
   auto ret = _cache->getAlgorithm<CollectiveType::ALL_REDUCE, T>(
       group_id,
       _groups.at(group_id),
@@ -276,7 +276,7 @@ void DataChannelGloo::broadcastT(
     at::Tensor& data,
     rank_type src_rank,
     THDGroup group_id) {
-  uint64_t tensor_bytes = data.type().elementSizeInBytes() * data.numel();
+  uint64_t tensor_bytes = data.element_size() * data.numel();
   auto ret = _cache->getAlgorithm<CollectiveType::BROADCAST, T>(
       group_id,
       _groups.at(group_id),
index da571a5..646de04 100644 (file)
@@ -268,7 +268,7 @@ void DataChannelTCP::allGather(
   memcpy(
       output[group_rank].data_ptr(),
       input.data_ptr(),
-      input.type().elementSizeInBytes() * input.numel());
+      input.element_size() * input.numel());
 
   auto j = group_rank, jnext = left;
   for (rank_type i = 0; i < group.size(); ++i) {
@@ -315,7 +315,7 @@ void DataChannelTCP::gather(
         memcpy(
             output.at(i).data_ptr(),
             input.data_ptr(),
-            input.numel() * input.type().elementSizeInBytes());
+            input.numel() * input.element_size());
       }
     }
   }
@@ -355,7 +355,7 @@ void DataChannelTCP::scatter(
         memcpy(
             output.data_ptr(),
             input.at(i).data_ptr(),
-            output.numel() * output.type().elementSizeInBytes());
+            output.numel() * output.element_size());
       }
     }
   }
@@ -389,7 +389,7 @@ void DataChannelTCP::allReduce(
   if (!exists)
     return;
 
-  uint64_t tensor_bytes = data.type().elementSizeInBytes() * data.numel();
+  uint64_t tensor_bytes = data.element_size() * data.numel();
   auto tmp_tensor = data.clone();
 
   auto pof2 = pow2(group.size());
@@ -489,7 +489,7 @@ void DataChannelTCP::reduce(
     std::memcpy(
         data.data_ptr(),
         result_tensor.data_ptr(),
-        data.type().elementSizeInBytes() * data.numel());
+        data.element_size() * data.numel());
 }
 
 void DataChannelTCP::broadcast(
@@ -703,7 +703,7 @@ void DataChannelTCP::_send(const at::Tensor& data, rank_type dst_rank) {
     throw std::logic_error("tensor to send is not contiguous");
 
   // send size of tensor data in bytes
-  uint64_t tensor_bytes = data.type().elementSizeInBytes() * data.numel();
+  uint64_t tensor_bytes = data.element_size() * data.numel();
   send_bytes<uint64_t>(process_dst.socket, &tensor_bytes, 1, true);
 
   // send data (bytes)
@@ -759,7 +759,7 @@ void DataChannelTCP::_receive(const at::Tensor& data, rank_type src_rank) {
   recv_bytes<uint64_t>(process_src.socket, &tensor_bytes, 1);
 
   uint64_t actual_tensor_bytes =
-      data.type().elementSizeInBytes() * data.numel();
+      data.element_size() * data.numel();
   if (actual_tensor_bytes == tensor_bytes) {
     recv_bytes<std::uint8_t>(
         process_src.socket,
index 058ad1e..6d1b4fe 100644 (file)
@@ -17,8 +17,8 @@ inline void assertSameSizeAndType(
     const at::Tensor& tensor1,
     const at::Tensor& tensor2,
     std::string prefix = std::string()) {
-  bool equal = tensor1.type().elementSizeInBytes() ==
-          tensor2.type().elementSizeInBytes() &&
+  bool equal = tensor1.element_size() ==
+          tensor2.element_size() &&
       tensor1.numel() == tensor2.numel() && tensor1.type() == tensor2.type();
 
   if (!prefix.empty())
index e00505b..4cf16e6 100644 (file)
@@ -184,7 +184,7 @@ struct GlooCache {
   }
 
   static void memcpy_input(value_type& info, at::Tensor& t) {
-    uint64_t tensor_bytes = t.type().elementSizeInBytes() * t.numel();
+    uint64_t tensor_bytes = t.element_size() * t.numel();
     auto t_dev = getDeviceType(t);
     auto input_buffer = GlooCache::input_buffer(info).get();
 
@@ -206,7 +206,7 @@ struct GlooCache {
   }
 
   static void memcpy_output(value_type& info, at::Tensor& t) {
-    uint64_t tensor_bytes = t.type().elementSizeInBytes() * t.numel();
+    uint64_t tensor_bytes = t.element_size() * t.numel();
     auto t_dev = getDeviceType(t);
     auto output_buffer = GlooCache::output_buffer(info).get();
 
index 7c8cf4c..140c85e 100644 (file)
@@ -1389,7 +1389,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupGloo::send(
   auto& tensor = checkSingleTensor(tensors);
   auto utag = checkTag(tag);
   auto ptr = tensor.data_ptr();
-  auto size = tensor.numel() * tensor.type().elementSizeInBytes();
+  auto size = tensor.numel() * tensor.element_size();
 
   // Construct unbound buffer.
   auto& context = contexts_[0];
@@ -1408,7 +1408,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupGloo::recv(
   auto& tensor = checkSingleTensor(tensors);
   auto utag = checkTag(tag);
   auto ptr = tensor.data_ptr();
-  auto size = tensor.numel() * tensor.type().elementSizeInBytes();
+  auto size = tensor.numel() * tensor.element_size();
 
   // Construct unbound buffer.
   auto& context = contexts_[0];
@@ -1426,7 +1426,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupGloo::recvAnysource(
   auto& tensor = checkSingleTensor(tensors);
   auto utag = checkTag(tag);
   auto ptr = tensor.data_ptr();
-  auto size = tensor.numel() * tensor.type().elementSizeInBytes();
+  auto size = tensor.numel() * tensor.element_size();
 
   // Construct unbound buffer.
   auto& context = contexts_[0];