From cc7aec12fd76678d5e702cbf49b5339b7125486f Mon Sep 17 00:00:00 2001 From: Roy Li Date: Thu, 7 Mar 2019 16:16:43 -0800 Subject: [PATCH] Clean up some old ScalarType stuff Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17755 Differential Revision: D14377135 Pulled By: li-roy fbshipit-source-id: 35305760a1621340ba66c61a193ff61cfedfa7e8 --- aten/src/ATen/Utils.h | 6 +++--- aten/src/ATen/function_wrapper.py | 2 +- aten/src/THC/THCStorage.hpp | 9 --------- c10/core/ScalarType.h | 13 ------------- c10/core/ScalarTypeUtils.h | 19 ------------------- c10/core/StorageImpl.h | 11 ++++------- c10/core/TensorOptions.h | 1 - 7 files changed, 8 insertions(+), 53 deletions(-) delete mode 100644 c10/core/ScalarTypeUtils.h diff --git a/aten/src/ATen/Utils.h b/aten/src/ATen/Utils.h index ec7e009..009c442 100644 --- a/aten/src/ATen/Utils.h +++ b/aten/src/ATen/Utils.h @@ -31,7 +31,7 @@ static inline const Storage& checked_storage( const char* name, int pos, DeviceType device_type, - DataType data_type) { + caffe2::TypeMeta dtype) { if (expr.device_type() != device_type) { AT_ERROR( "Expected object of device type ", @@ -44,10 +44,10 @@ static inline const Storage& checked_storage( name, "'"); } - if (expr.dtype().id() != data_type) { + if (expr.dtype() != dtype) { AT_ERROR( "Expected object of data type ", - data_type, + dtype, " but got data type ", expr.dtype().id(), " for argument #", diff --git a/aten/src/ATen/function_wrapper.py b/aten/src/ATen/function_wrapper.py index 737bf8a..7c31123 100644 --- a/aten/src/ATen/function_wrapper.py +++ b/aten/src/ATen/function_wrapper.py @@ -290,7 +290,7 @@ CHECKED_CAST = { '${arg_name},"${arg_name}",${arg_pos}, ' # We're punning here (Backend and DeviceType constructors coincide) # but DeviceType is the correct way to classify storages - 'DeviceType::${Backend}, at::scalarTypeToDataType(ScalarType::${ScalarName}))'), + 'DeviceType::${Backend}, at::scalarTypeToTypeMeta(ScalarType::${ScalarName}))'), 'THGenerator*': CodeTemplate( 'check_generator<${Backend}Generator>(${arg_name}, &globalContext().defaultGenerator(device_type()))'), diff --git a/aten/src/THC/THCStorage.hpp b/aten/src/THC/THCStorage.hpp index 85dab8c..62a1d95 100644 --- a/aten/src/THC/THCStorage.hpp +++ b/aten/src/THC/THCStorage.hpp @@ -13,15 +13,6 @@ #include #include -namespace c10 { - -#if defined(__CUDACC__) || defined(__HIP_PLATFORM_HCC__) -template <> -struct CTypeToScalarType<__half> : public CTypeToScalarType {}; -#endif - -} - THC_API THCStorage* THCStorage_new(THCState* state, caffe2::TypeMeta); THC_API void THCStorage_retain(THCState *state, THCStorage *storage); diff --git a/c10/core/ScalarType.h b/c10/core/ScalarType.h index e959e2e..c5195ed 100644 --- a/c10/core/ScalarType.h +++ b/c10/core/ScalarType.h @@ -70,19 +70,6 @@ enum class ScalarType : int8_t { NumOptions }; -static inline at::DataType scalarTypeToDataType(ScalarType scalar_type) { -#define DEFINE_CASE(ctype, name, _) \ - case ScalarType::name: \ - return caffe2::TypeIdentifier::Get(); - - switch(scalar_type) { - AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE) - case ScalarType::Undefined: return at::DataType::uninitialized(); - default: AT_ERROR("Unrecognized Scalartype ", scalar_type, " (please report this error)"); - } -#undef DEFINE_CASE -} - static inline caffe2::TypeMeta scalarTypeToTypeMeta(ScalarType scalar_type) { #define DEFINE_CASE(ctype,name,_) \ case ScalarType:: name : return caffe2::TypeMeta::Make(); diff --git a/c10/core/ScalarTypeUtils.h b/c10/core/ScalarTypeUtils.h deleted file mode 100644 index 1d2a7c4..0000000 --- a/c10/core/ScalarTypeUtils.h +++ /dev/null @@ -1,19 +0,0 @@ -#pragma once - -#include - -namespace c10 { - -template -struct CTypeToScalarType { -}; - -#define DEFINE_TO_SCALAR_TYPE(ct, st, _2) \ -template <> \ -struct CTypeToScalarType { \ - static inline at::ScalarType to() { return at::ScalarType::st; } \ -}; -AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_TO_SCALAR_TYPE) -#undef DEFINE_TO_SCALAR_TYPE - -} // namespace c10 diff --git a/c10/core/StorageImpl.h b/c10/core/StorageImpl.h index 3f92721..122fd08 100644 --- a/c10/core/StorageImpl.h +++ b/c10/core/StorageImpl.h @@ -2,7 +2,6 @@ #include #include -#include #include @@ -64,15 +63,13 @@ struct C10_API StorageImpl final : public c10::intrusive_ptr_target { template inline T* data() const { - // TODO: This is bad: it means storage.data() calls only work on - // T that are valid ScalarType. FIXME! - auto data_type_T = at::scalarTypeToDataType(c10::CTypeToScalarType::to()); - if (dtype().id() != data_type_T) { + auto data_type = caffe2::TypeMeta::Make(); + if (dtype() != data_type) { AT_ERROR( "Attempt to access StorageImpl having data type ", - dtype().id(), + dtype(), " as data type ", - data_type_T); + data_type); } return unsafe_data(); } diff --git a/c10/core/TensorOptions.h b/c10/core/TensorOptions.h index a91712a..9c90137 100644 --- a/c10/core/TensorOptions.h +++ b/c10/core/TensorOptions.h @@ -4,7 +4,6 @@ #include #include #include -#include #include #include -- 2.7.4