Don't throw in operator== for TypeMeta and ScalarType (#16736)
authorAdam Paszke <adam.paszke@gmail.com>
Tue, 5 Feb 2019 16:52:55 +0000 (08:52 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 5 Feb 2019 16:56:22 +0000 (08:56 -0800)
Differential Revision: D13957847

Pulled By: ezyang

fbshipit-source-id: 3cc01538aab1bbb396c29ce61e0e95118f8d011f

c10/core/ScalarType.h

index 84c79c6..2b6aba4 100644 (file)
@@ -2,6 +2,7 @@
 
 #include <c10/util/ArrayRef.h>
 #include <c10/util/Half.h>
+#include <c10/util/Optional.h>
 #include <c10/util/typeid.h>
 
 #include <cstdint>
@@ -93,25 +94,35 @@ static inline caffe2::TypeMeta scalarTypeToTypeMeta(ScalarType scalar_type) {
 #undef DEFINE_CASE
 }
 
-static inline ScalarType typeMetaToScalarType(caffe2::TypeMeta dtype) {
+static inline c10::optional<ScalarType> tryTypeMetaToScalarType(caffe2::TypeMeta dtype) {
 #define DEFINE_IF(ctype, name, _)                      \
   if (dtype == caffe2::TypeMeta::Make<ctype>()) { \
-    return ScalarType::name;                           \
+    return {ScalarType::name};                         \
   }
   AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_IF)
 #undef DEFINE_IF
   if (dtype == caffe2::TypeMeta()) {
-    return ScalarType::Undefined;
+    return {ScalarType::Undefined};
+  }
+  return c10::nullopt;
+}
+
+static inline ScalarType typeMetaToScalarType(caffe2::TypeMeta dtype) {
+  if (auto scalar_type = tryTypeMetaToScalarType(dtype)) {
+    return *scalar_type;
   }
   AT_ERROR("Unsupported TypeMeta in ATen: ", dtype, " (please report this error)");
 }
 
 static inline bool operator==(ScalarType t, caffe2::TypeMeta m) {
-  return typeMetaToScalarType(m) == t;
+  if (auto mt = tryTypeMetaToScalarType(m)) {
+    return (*mt) == t;
+  }
+  return false;
 }
 
 static inline bool operator==(caffe2::TypeMeta m, ScalarType t) {
-  return typeMetaToScalarType(m) == t;
+  return t == m;
 }
 
 #define DEFINE_CONSTANT(_,name,_2) \