From c751cf8b368eb5f63dc2f5cdd6ffc4cfee2c8a04 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 5 Feb 2019 08:52:55 -0800 Subject: [PATCH] Don't throw in operator== for TypeMeta and ScalarType (#16736) Differential Revision: D13957847 Pulled By: ezyang fbshipit-source-id: 3cc01538aab1bbb396c29ce61e0e95118f8d011f --- c10/core/ScalarType.h | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/c10/core/ScalarType.h b/c10/core/ScalarType.h index 84c79c6..2b6aba4 100644 --- a/c10/core/ScalarType.h +++ b/c10/core/ScalarType.h @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -93,25 +94,35 @@ static inline caffe2::TypeMeta scalarTypeToTypeMeta(ScalarType scalar_type) { #undef DEFINE_CASE } -static inline ScalarType typeMetaToScalarType(caffe2::TypeMeta dtype) { +static inline c10::optional tryTypeMetaToScalarType(caffe2::TypeMeta dtype) { #define DEFINE_IF(ctype, name, _) \ if (dtype == caffe2::TypeMeta::Make()) { \ - return ScalarType::name; \ + return {ScalarType::name}; \ } AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_IF) #undef DEFINE_IF if (dtype == caffe2::TypeMeta()) { - return ScalarType::Undefined; + return {ScalarType::Undefined}; + } + return c10::nullopt; +} + +static inline ScalarType typeMetaToScalarType(caffe2::TypeMeta dtype) { + if (auto scalar_type = tryTypeMetaToScalarType(dtype)) { + return *scalar_type; } AT_ERROR("Unsupported TypeMeta in ATen: ", dtype, " (please report this error)"); } static inline bool operator==(ScalarType t, caffe2::TypeMeta m) { - return typeMetaToScalarType(m) == t; + if (auto mt = tryTypeMetaToScalarType(m)) { + return (*mt) == t; + } + return false; } static inline bool operator==(caffe2::TypeMeta m, ScalarType t) { - return typeMetaToScalarType(m) == t; + return t == m; } #define DEFINE_CONSTANT(_,name,_2) \ -- 2.7.4