Clean up some old ScalarType stuff
authorRoy Li <royboy@fb.com>
Fri, 8 Mar 2019 00:16:43 +0000 (16:16 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 8 Mar 2019 00:21:52 +0000 (16:21 -0800)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17755

Differential Revision: D14377135

Pulled By: li-roy

fbshipit-source-id: 35305760a1621340ba66c61a193ff61cfedfa7e8

aten/src/ATen/Utils.h
aten/src/ATen/function_wrapper.py
aten/src/THC/THCStorage.hpp
c10/core/ScalarType.h
c10/core/ScalarTypeUtils.h [deleted file]
c10/core/StorageImpl.h
c10/core/TensorOptions.h

index ec7e009..009c442 100644 (file)
@@ -31,7 +31,7 @@ static inline const Storage& checked_storage(
     const char* name,
     int pos,
     DeviceType device_type,
-    DataType data_type) {
+    caffe2::TypeMeta dtype) {
   if (expr.device_type() != device_type) {
     AT_ERROR(
         "Expected object of device type ",
@@ -44,10 +44,10 @@ static inline const Storage& checked_storage(
         name,
         "'");
   }
-  if (expr.dtype().id() != data_type) {
+  if (expr.dtype() != dtype) {
     AT_ERROR(
         "Expected object of data type ",
-        data_type,
+        dtype,
         " but got data type ",
         expr.dtype().id(),
         " for argument #",
index 737bf8a..7c31123 100644 (file)
@@ -290,7 +290,7 @@ CHECKED_CAST = {
             '${arg_name},"${arg_name}",${arg_pos}, '
             # We're punning here (Backend and DeviceType constructors coincide)
             # but DeviceType is the correct way to classify storages
-            'DeviceType::${Backend}, at::scalarTypeToDataType(ScalarType::${ScalarName}))'),
+            'DeviceType::${Backend}, at::scalarTypeToTypeMeta(ScalarType::${ScalarName}))'),
     'THGenerator*':
         CodeTemplate(
             'check_generator<${Backend}Generator>(${arg_name}, &globalContext().defaultGenerator(device_type()))'),
index 85dab8c..62a1d95 100644 (file)
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
 
-namespace c10 {
-
-#if defined(__CUDACC__) || defined(__HIP_PLATFORM_HCC__)
-template <>
-struct CTypeToScalarType<__half> : public CTypeToScalarType<Half> {};
-#endif
-
-}
-
 THC_API THCStorage* THCStorage_new(THCState* state, caffe2::TypeMeta);
 
 THC_API void THCStorage_retain(THCState *state, THCStorage *storage);
index e959e2e..c5195ed 100644 (file)
@@ -70,19 +70,6 @@ enum class ScalarType : int8_t {
   NumOptions
 };
 
-static inline at::DataType scalarTypeToDataType(ScalarType scalar_type) {
-#define DEFINE_CASE(ctype, name, _) \
-  case ScalarType::name:            \
-    return caffe2::TypeIdentifier::Get<ctype>();
-
-  switch(scalar_type) {
-    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE)
-    case ScalarType::Undefined: return at::DataType::uninitialized();
-    default: AT_ERROR("Unrecognized Scalartype ", scalar_type, " (please report this error)");
-  }
-#undef DEFINE_CASE
-}
-
 static inline caffe2::TypeMeta scalarTypeToTypeMeta(ScalarType scalar_type) {
 #define DEFINE_CASE(ctype,name,_) \
   case ScalarType:: name : return caffe2::TypeMeta::Make<ctype>();
diff --git a/c10/core/ScalarTypeUtils.h b/c10/core/ScalarTypeUtils.h
deleted file mode 100644 (file)
index 1d2a7c4..0000000
+++ /dev/null
@@ -1,19 +0,0 @@
-#pragma once
-
-#include <c10/core/ScalarType.h>
-
-namespace c10 {
-
-template <typename T>
-struct CTypeToScalarType {
-};
-
-#define DEFINE_TO_SCALAR_TYPE(ct, st, _2)                          \
-template <>                                                        \
-struct CTypeToScalarType<ct> {                                     \
-  static inline at::ScalarType to() { return at::ScalarType::st; } \
-};
-AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_TO_SCALAR_TYPE)
-#undef DEFINE_TO_SCALAR_TYPE
-
-} // namespace c10
index 3f92721..122fd08 100644 (file)
@@ -2,7 +2,6 @@
 
 #include <c10/core/Allocator.h>
 #include <c10/core/ScalarType.h>
-#include <c10/core/ScalarTypeUtils.h>
 
 #include <c10/util/intrusive_ptr.h>
 
@@ -64,15 +63,13 @@ struct C10_API StorageImpl final : public c10::intrusive_ptr_target {
 
   template <typename T>
   inline T* data() const {
-    // TODO: This is bad: it means storage.data<T>() calls only work on
-    // T that are valid ScalarType.  FIXME!
-    auto data_type_T = at::scalarTypeToDataType(c10::CTypeToScalarType<T>::to());
-    if (dtype().id() != data_type_T) {
+    auto data_type = caffe2::TypeMeta::Make<T>();
+    if (dtype() != data_type) {
       AT_ERROR(
           "Attempt to access StorageImpl having data type ",
-          dtype().id(),
+          dtype(),
           " as data type ",
-          data_type_T);
+          data_type);
     }
     return unsafe_data<T>();
   }
index a91712a..9c90137 100644 (file)
@@ -4,7 +4,6 @@
 #include <c10/core/Backend.h>
 #include <c10/core/Layout.h>
 #include <c10/core/ScalarType.h>
-#include <c10/core/ScalarTypeUtils.h>
 #include <c10/core/Device.h>
 
 #include <c10/util/Optional.h>