For some files that are touched by the QTensor diff (#18765)
authorJerry Zhang <jerryzh@fb.com>
Wed, 3 Apr 2019 19:01:57 +0000 (12:01 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 3 Apr 2019 19:47:31 +0000 (12:47 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18765

att

Reviewed By: ZolotukhinM

Differential Revision: D14733442

fbshipit-source-id: 525002034e6dccc2045da645e1193671fd0474b3

aten/src/ATen/DLConvertor.cpp
c10/core/Scalar.h
c10/core/ScalarType.h
c10/util/typeid.cpp
c10/util/typeid.h
torch/csrc/utils/tensor_dtypes.cpp

index ce9da56..de0b1b0 100644 (file)
@@ -4,7 +4,6 @@
 #include <iostream>
 #include <sstream>
 
-
 using namespace std;
 namespace at {
 
@@ -54,7 +53,6 @@ static DLDataType getDLDataType(const Tensor& t) {
   return dtype;
 }
 
-
 static DLContext getDLContext(const Type& type, const int64_t& device_id) {
   DLContext ctx;
   ctx.device_id = device_id;
@@ -66,7 +64,6 @@ static DLContext getDLContext(const Type& type, const int64_t& device_id) {
   return ctx;
 }
 
-
 static DeviceType getATenDeviceType(const DLContext& ctx) {
   switch (ctx.device_type) {
     case DLDeviceType::kDLCPU:
@@ -78,15 +75,16 @@ static DeviceType getATenDeviceType(const DLContext& ctx) {
     case DLDeviceType::kDLROCM:
       return DeviceType::HIP;
     default:
-      throw std::logic_error("Unsupported device_type: " + std::to_string(ctx.device_type));
+      throw std::logic_error(
+          "Unsupported device_type: " + std::to_string(ctx.device_type));
   }
   return DeviceType::CPU; // impossible
 }
 
-
 ScalarType toScalarType(const DLDataType& dtype) {
   ScalarType stype;
-  if (dtype.lanes != 1) throw std::logic_error("ATen does not support lanes != 1");
+  if (dtype.lanes != 1)
+    throw std::logic_error("ATen does not support lanes != 1");
   switch (dtype.code) {
     case DLDataTypeCode::kDLUInt:
       switch (dtype.bits) {
@@ -94,7 +92,8 @@ ScalarType toScalarType(const DLDataType& dtype) {
           stype = ScalarType::Byte;
           break;
         default:
-          throw std::logic_error("Unsupported kUInt bits " + std::to_string(dtype.bits));
+          throw std::logic_error(
+              "Unsupported kUInt bits " + std::to_string(dtype.bits));
       }
       break;
     case DLDataTypeCode::kDLInt:
@@ -112,7 +111,8 @@ ScalarType toScalarType(const DLDataType& dtype) {
           stype = ScalarType::Long;
           break;
         default:
-          throw std::logic_error("Unsupported kInt bits " + std::to_string(dtype.bits));
+          throw std::logic_error(
+              "Unsupported kInt bits " + std::to_string(dtype.bits));
       }
       break;
     case DLDataTypeCode::kDLFloat:
@@ -127,7 +127,8 @@ ScalarType toScalarType(const DLDataType& dtype) {
           stype = ScalarType::Double;
           break;
         default:
-          throw std::logic_error("Unsupported kFloat bits " + std::to_string(dtype.bits));
+          throw std::logic_error(
+              "Unsupported kFloat bits " + std::to_string(dtype.bits));
       }
       break;
     default:
@@ -141,15 +142,14 @@ struct ATenDLMTensor {
   DLManagedTensor tensor;
 };
 
-void deleter(DLManagedTensor * arg) {
+void deleter(DLManagedTensor* arg) {
   delete static_cast<ATenDLMTensor*>(arg->manager_ctx);
 }
 
-
-// This function returns a shared_ptr to memory managed DLpack tensor constructed
-// out of ATen tensor
+// This function returns a shared_ptr to memory managed DLpack tensor
+// constructed out of ATen tensor
 DLManagedTensor* toDLPack(const Tensor& src) {
-  ATenDLMTensor * atDLMTensor(new ATenDLMTensor);
+  ATenDLMTensor* atDLMTensor(new ATenDLMTensor);
   atDLMTensor->handle = src;
   atDLMTensor->tensor.manager_ctx = atDLMTensor;
   atDLMTensor->tensor.deleter = &deleter;
@@ -161,23 +161,25 @@ DLManagedTensor* toDLPack(const Tensor& src) {
   atDLMTensor->tensor.dl_tensor.ctx = getDLContext(src.type(), device_id);
   atDLMTensor->tensor.dl_tensor.ndim = src.dim();
   atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src);
-  atDLMTensor->tensor.dl_tensor.shape = const_cast<int64_t*>(src.sizes().data());
-  atDLMTensor->tensor.dl_tensor.strides = const_cast<int64_t*>(src.strides().data());
+  atDLMTensor->tensor.dl_tensor.shape =
+      const_cast<int64_t*>(src.sizes().data());
+  atDLMTensor->tensor.dl_tensor.strides =
+      const_cast<int64_t*>(src.strides().data());
   atDLMTensor->tensor.dl_tensor.byte_offset = 0;
   return &(atDLMTensor->tensor);
 }
 
-
 Tensor fromDLPack(const DLManagedTensor* src) {
   DeviceType device_type = getATenDeviceType(src->dl_tensor.ctx);
   ScalarType stype = toScalarType(src->dl_tensor.dtype);
-  auto deleter = [src](void * self) {
+  auto deleter = [src](void* self) {
     src->deleter(const_cast<DLManagedTensor*>(src));
   };
-  return at::from_blob(src->dl_tensor.data,
+  return at::from_blob(
+      src->dl_tensor.data,
       IntArrayRef(src->dl_tensor.shape, src->dl_tensor.ndim),
       IntArrayRef(src->dl_tensor.strides, src->dl_tensor.ndim),
       deleter,
       at::device(device_type).dtype(stype));
 }
-} //namespace at
+} // namespace at
index 94fe216..7952e86 100644 (file)
@@ -6,27 +6,27 @@
 #include <string>
 #include <utility>
 
-#include <c10/macros/Macros.h>
 #include <c10/core/ScalarType.h>
+#include <c10/macros/Macros.h>
 #include <c10/util/Half.h>
 
 namespace c10 {
 
 /**
  * Scalar represents a 0-dimensional tensor which contains a single element.
- * Unlike a tensor, numeric literals (in C++) are implicitly convertible to Scalar
- * (which is why, for example, we provide both add(Tensor) and add(Scalar) overloads
- * for many operations). It may also be used in circumstances where you statically
- * know a tensor is 0-dim and single size, but don't know it's type.
+ * Unlike a tensor, numeric literals (in C++) are implicitly convertible to
+ * Scalar (which is why, for example, we provide both add(Tensor) and
+ * add(Scalar) overloads for many operations). It may also be used in
+ * circumstances where you statically know a tensor is 0-dim and single size,
+ * but don't know it's type.
  */
 class C10_API Scalar {
  public:
   Scalar() : Scalar(int64_t(0)) {}
 
-#define DEFINE_IMPLICIT_CTOR(type,name,member) \
-  Scalar(type vv) \
-  : tag(Tag::HAS_##member) { \
-    v . member = convert<decltype(v.member),type>(vv); \
+#define DEFINE_IMPLICIT_CTOR(type, name, member)      \
+  Scalar(type vv) : tag(Tag::HAS_##member) {          \
+    v.member = convert<decltype(v.member), type>(vv); \
   }
   // We can't set v in the initializer list using the
   // syntax v{ .member = ... } because it doesn't work on MSVC
@@ -35,14 +35,16 @@ class C10_API Scalar {
 
 #undef DEFINE_IMPLICIT_CTOR
 
-// Value* is both implicitly convertible to SymbolicVariable and bool which
-// causes ambiguosity error. Specialized constructor for bool resolves this problem.
-template <typename T,
-       typename std::enable_if<std::is_same<T, bool>::value, bool>::type* = nullptr>
-Scalar(T vv)
-: tag(Tag::HAS_i) {
-  v.i = convert<decltype(v.i), bool>(vv);
-}
+  // Value* is both implicitly convertible to SymbolicVariable and bool which
+  // causes ambiguosity error. Specialized constructor for bool resolves this
+  // problem.
+  template <
+      typename T,
+      typename std::enable_if<std::is_same<T, bool>::value, bool>::type* =
+          nullptr>
+  Scalar(T vv) : tag(Tag::HAS_i) {
+    v.i = convert<decltype(v.i), bool>(vv);
+  }
 
 #define DEFINE_IMPLICIT_COMPLEX_CTOR(type, name, member) \
   Scalar(type vv) : tag(Tag::HAS_##member) {             \
@@ -50,28 +52,29 @@ Scalar(T vv)
     v.member[1] = c10::convert<double>(vv.imag());       \
   }
 
-  DEFINE_IMPLICIT_COMPLEX_CTOR(at::ComplexHalf,ComplexHalf,z)
-  DEFINE_IMPLICIT_COMPLEX_CTOR(std::complex<float>,ComplexFloat,z)
-  DEFINE_IMPLICIT_COMPLEX_CTOR(std::complex<double>,ComplexDouble,z)
+  DEFINE_IMPLICIT_COMPLEX_CTOR(at::ComplexHalf, ComplexHalf, z)
+  DEFINE_IMPLICIT_COMPLEX_CTOR(std::complex<float>, ComplexFloat, z)
+  DEFINE_IMPLICIT_COMPLEX_CTOR(std::complex<double>, ComplexDouble, z)
 
 #undef DEFINE_IMPLICIT_COMPLEX_CTOR
 
-#define DEFINE_ACCESSOR(type,name,member) \
-  type to##name () const { \
-    if (Tag::HAS_d == tag) { \
-      return checked_convert<type, double>(v.d, #type); \
-    } else if (Tag::HAS_z == tag) { \
-      return checked_convert<type, std::complex<double>>({v.z[0], v.z[1]}, #type); \
-    } else { \
-      return checked_convert<type, int64_t>(v.i, #type); \
-    } \
+#define DEFINE_ACCESSOR(type, name, member)               \
+  type to##name() const {                                 \
+    if (Tag::HAS_d == tag) {                              \
+      return checked_convert<type, double>(v.d, #type);   \
+    } else if (Tag::HAS_z == tag) {                       \
+      return checked_convert<type, std::complex<double>>( \
+          {v.z[0], v.z[1]}, #type);                       \
+    } else {                                              \
+      return checked_convert<type, int64_t>(v.i, #type);  \
+    }                                                     \
   }
 
   // TODO: Support ComplexHalf accessor
   AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_ACCESSOR)
 
-  //also support scalar.to<int64_t>();
-  template<typename T>
+  // also support scalar.to<int64_t>();
+  template <typename T>
   T to();
 
 #undef DEFINE_ACCESSOR
@@ -87,7 +90,7 @@ Scalar(T vv)
 
   Scalar operator-() const;
 
-private:
+ private:
   enum class Tag { HAS_d, HAS_i, HAS_z };
   Tag tag;
   union {
@@ -101,16 +104,16 @@ private:
 };
 
 // define the scalar.to<int64_t>() specializations
-template<typename T>
+template <typename T>
 inline T Scalar::to() {
   throw std::runtime_error("to() cast to unexpected type.");
 }
 
-#define DEFINE_TO(T,name,_) \
-template<> \
-inline T Scalar::to<T>() { \
-  return to##name(); \
-}
+#define DEFINE_TO(T, name, _) \
+  template <>                 \
+  inline T Scalar::to<T>() {  \
+    return to##name();        \
+  }
 AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_TO)
 #undef DEFINE_TO
-}
+} // namespace c10
index 2d852d1..11549e1 100644 (file)
 #include <c10/util/Optional.h>
 #include <c10/util/typeid.h>
 
+#include <complex>
 #include <cstdint>
 #include <iostream>
-#include <complex>
 
 namespace c10 {
 
 // NB: Order matters for this macro; it is relied upon in
 // _promoteTypesLookup and the serialization format.
-#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(_) \
-_(uint8_t,Byte,i)  /* 0 */ \
-_(int8_t,Char,i)   /* 1 */ \
-_(int16_t,Short,i) /* 2 */ \
-_(int,Int,i)       /* 3 */ \
-_(int64_t,Long,i)  /* 4 */ \
-_(at::Half,Half,d) /* 5 */ \
-_(float,Float,d)   /* 6 */ \
-_(double,Double,d) /* 7 */ \
-_(at::ComplexHalf,ComplexHalf,z)        /* 8 */ \
-_(std::complex<float>,ComplexFloat,z)   /* 9 */ \
-_(std::complex<double>,ComplexDouble,z) /* 10 */ \
-_(bool,Bool,i)     /* 11 */
+#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(_)       \
+  _(uint8_t, Byte, i) /* 0 */                        \
+  _(int8_t, Char, i) /* 1 */                         \
+  _(int16_t, Short, i) /* 2 */                       \
+  _(int, Int, i) /* 3 */                             \
+  _(int64_t, Long, i) /* 4 */                        \
+  _(at::Half, Half, d) /* 5 */                       \
+  _(float, Float, d) /* 6 */                         \
+  _(double, Double, d) /* 7 */                       \
+  _(at::ComplexHalf, ComplexHalf, z) /* 8 */         \
+  _(std::complex<float>, ComplexFloat, z) /* 9 */    \
+  _(std::complex<double>, ComplexDouble, z) /* 10 */ \
+  _(bool, Bool, i) /* 11 */
 
 // If you want to support ComplexHalf for real, replace occurrences
 // of this macro with AT_FORALL_SCALAR_TYPES_WITH_COMPLEX.  But
 // beware: convert() doesn't work for all the conversions you need...
 #define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(_) \
-_(uint8_t,Byte,i)  \
-_(int8_t,Char,i)   \
-_(int16_t,Short,i) \
-_(int,Int,i)       \
-_(int64_t,Long,i)  \
-_(at::Half,Half,d) \
-_(float,Float,d)   \
-_(double,Double,d) \
-_(std::complex<float>,ComplexFloat,z) \
-_(std::complex<double>,ComplexDouble,z) \
-_(bool,Bool,i)
+  _(uint8_t, Byte, i)                                              \
+  _(int8_t, Char, i)                                               \
+  _(int16_t, Short, i)                                             \
+  _(int, Int, i)                                                   \
+  _(int64_t, Long, i)                                              \
+  _(at::Half, Half, d)                                             \
+  _(float, Float, d)                                               \
+  _(double, Double, d)                                             \
+  _(std::complex<float>, ComplexFloat, z)                          \
+  _(std::complex<double>, ComplexDouble, z)                        \
+  _(bool, Bool, i)
 
 #define AT_FORALL_SCALAR_TYPES(_) \
-_(uint8_t,Byte,i)  \
-_(int8_t,Char,i)   \
-_(int16_t,Short,i) \
-_(int,Int,i)       \
-_(int64_t,Long,i)  \
-_(at::Half,Half,d) \
-_(float,Float,d)   \
-_(double,Double,d)
+  _(uint8_t, Byte, i)             \
+  _(int8_t, Char, i)              \
+  _(int16_t, Short, i)            \
+  _(int, Int, i)                  \
+  _(int64_t, Long, i)             \
+  _(at::Half, Half, d)            \
+  _(float, Float, d)              \
+  _(double, Double, d)
 
 #define AT_FORALL_SCALAR_TYPES_AND_BOOL(_) \
-_(uint8_t,Byte,i)  \
-_(int8_t,Char,i)   \
-_(int16_t,Short,i) \
-_(int,Int,i)       \
-_(int64_t,Long,i)  \
-_(at::Half,Half,d) \
-_(float,Float,d)   \
-_(double,Double,d) \
-_(bool,Bool,i)
+  _(uint8_t, Byte, i)                      \
+  _(int8_t, Char, i)                       \
+  _(int16_t, Short, i)                     \
+  _(int, Int, i)                           \
+  _(int64_t, Long, i)                      \
+  _(at::Half, Half, d)                     \
+  _(float, Float, d)                       \
+  _(double, Double, d)                     \
+  _(bool, Bool, i)
 
 #define AT_FORALL_SCALAR_TYPES_EXCEPT_HALF(_) \
-_(uint8_t,Byte,i) \
-_(int8_t,Char,i) \
-_(int16_t,Short,i) \
-_(int,Int,i) \
-_(int64_t,Long,i) \
-_(float,Float,d) \
-_(double,Double,d)
+  _(uint8_t, Byte, i)                         \
+  _(int8_t, Char, i)                          \
+  _(int16_t, Short, i)                        \
+  _(int, Int, i)                              \
+  _(int64_t, Long, i)                         \
+  _(float, Float, d)                          \
+  _(double, Double, d)
 
 enum class ScalarType : int8_t {
-#define DEFINE_ENUM(_1,n,_2) \
-  n,
+#define DEFINE_ENUM(_1, n, _2) n,
   AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_ENUM)
 #undef DEFINE_ENUM
-  Undefined,
+      Undefined,
   NumOptions
 };
 
 static inline caffe2::TypeMeta scalarTypeToTypeMeta(ScalarType scalar_type) {
-#define DEFINE_CASE(ctype,name,_) \
-  case ScalarType:: name : return caffe2::TypeMeta::Make<ctype>();
+#define DEFINE_CASE(ctype, name, _) \
+  case ScalarType::name:            \
+    return caffe2::TypeMeta::Make<ctype>();
 
-  switch(scalar_type) {
+  switch (scalar_type) {
     AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE)
-    case ScalarType::Undefined: return caffe2::TypeMeta();
-    default: AT_ERROR("Unrecognized Scalartype ", scalar_type, " (please report this error)");
+    case ScalarType::Undefined:
+      return caffe2::TypeMeta();
+    default:
+      AT_ERROR(
+          "Unrecognized Scalartype ",
+          scalar_type,
+          " (please report this error)");
   }
 #undef DEFINE_CASE
 }
 
-static inline c10::optional<ScalarType> tryTypeMetaToScalarType(caffe2::TypeMeta dtype) {
-#define DEFINE_IF(ctype, name, _)                      \
+static inline c10::optional<ScalarType> tryTypeMetaToScalarType(
+    caffe2::TypeMeta dtype) {
+#define DEFINE_IF(ctype, name, _)                 \
   if (dtype == caffe2::TypeMeta::Make<ctype>()) { \
-    return {ScalarType::name};                         \
+    return {ScalarType::name};                    \
   }
   AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_IF)
 #undef DEFINE_IF
@@ -111,7 +117,8 @@ static inline ScalarType typeMetaToScalarType(caffe2::TypeMeta dtype) {
   if (auto scalar_type = tryTypeMetaToScalarType(dtype)) {
     return *scalar_type;
   }
-  AT_ERROR("Unsupported TypeMeta in ATen: ", dtype, " (please report this error)");
+  AT_ERROR(
+      "Unsupported TypeMeta in ATen: ", dtype, " (please report this error)");
 }
 
 static inline bool operator==(ScalarType t, caffe2::TypeMeta m) {
@@ -125,17 +132,18 @@ static inline bool operator==(caffe2::TypeMeta m, ScalarType t) {
   return t == m;
 }
 
-#define DEFINE_CONSTANT(_,name,_2) \
-constexpr ScalarType k##name = ScalarType::name;
+#define DEFINE_CONSTANT(_, name, _2) \
+  constexpr ScalarType k##name = ScalarType::name;
 
 AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CONSTANT)
 #undef DEFINE_CONSTANT
 
-static inline const char * toString(ScalarType t) {
-#define DEFINE_CASE(_,name,_2) \
-  case ScalarType:: name : return #name;
+static inline const char* toString(ScalarType t) {
+#define DEFINE_CASE(_, name, _2) \
+  case ScalarType::name:         \
+    return #name;
 
-  switch(t) {
+  switch (t) {
     AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE)
     default:
       return "UNKNOWN_SCALAR";
@@ -144,10 +152,11 @@ static inline const char * toString(ScalarType t) {
 }
 
 static inline size_t elementSize(ScalarType t) {
-#define CASE_ELEMENTSIZE_CASE(ctype,name,_2) \
-  case ScalarType:: name : return sizeof(ctype);
+#define CASE_ELEMENTSIZE_CASE(ctype, name, _2) \
+  case ScalarType::name:                       \
+    return sizeof(ctype);
 
-  switch(t) {
+  switch (t) {
     AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(CASE_ELEMENTSIZE_CASE)
     default:
       AT_ERROR("Unknown ScalarType");
@@ -156,23 +165,21 @@ static inline size_t elementSize(ScalarType t) {
 }
 
 static inline bool isIntegralType(ScalarType t) {
-  return (t == ScalarType::Byte ||
-          t == ScalarType::Char ||
-          t == ScalarType::Int ||
-          t == ScalarType::Long ||
-          t == ScalarType::Short);
+  return (
+      t == ScalarType::Byte || t == ScalarType::Char || t == ScalarType::Int ||
+      t == ScalarType::Long || t == ScalarType::Short);
 }
 
 static inline bool isFloatingType(ScalarType t) {
-  return (t == ScalarType::Double ||
-          t == ScalarType::Float ||
-          t == ScalarType::Half);
+  return (
+      t == ScalarType::Double || t == ScalarType::Float ||
+      t == ScalarType::Half);
 }
 
 static inline bool isComplexType(ScalarType t) {
-  return (t == ScalarType::ComplexHalf ||
-          t == ScalarType::ComplexFloat ||
-          t == ScalarType::ComplexDouble);
+  return (
+      t == ScalarType::ComplexHalf || t == ScalarType::ComplexFloat ||
+      t == ScalarType::ComplexDouble);
 }
 
 static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
@@ -191,27 +198,28 @@ static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
     return ScalarType::Undefined;
   }
   if (isComplexType(a) || isComplexType(b)) {
-    AT_ERROR("promoteTypes with complex numbers is not handled yet; figure out what the correct rules should be");
+    AT_ERROR(
+        "promoteTypes with complex numbers is not handled yet; figure out what the correct rules should be");
   }
 
-  // this matrix has to be consistent with AT_FORALL_SCALAR_TYPES_WITH_COMPLEX so that's why we have to add
-  // undefined as we are not sure what is the corrent values for the type promotions in complex type cases.
-  static constexpr ScalarType _promoteTypesLookup
-      [static_cast<int>(ScalarType::NumOptions)]
-      [static_cast<int>(ScalarType::NumOptions)] = {
-            /* u1  i1  i2  i4  i8  f2  f4  f8  c2  c4  c8  b1 */
-    /* u1 */ { u1, i2, i2, i4, i8, f2, f4, f8, ud, ud, ud, u1 },
-    /* i1 */ { i2, i1, i2, i4, i8, f2, f4, f8, ud, ud, ud, i1 },
-    /* i2 */ { i2, i2, i2, i4, i8, f2, f4, f8, ud, ud, ud, i2 },
-    /* i4 */ { i4, i4, i4, i4, i8, f2, f4, f8, ud, ud, ud, i4 },
-    /* i8 */ { i8, i8, i8, i8, i8, f2, f4, f8, ud, ud, ud, i8 },
-    /* f2 */ { f2, f2, f2, f2, f2, f2, f4, f8, ud, ud, ud, f2 },
-    /* f4 */ { f4, f4, f4, f4, f4, f4, f4, f8, ud, ud, ud, f4 },
-    /* f8 */ { f8, f8, f8, f8, f8, f8, f8, f8, ud, ud, ud, f8 },
-    /* c2 */ { ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud },
-    /* c4 */ { ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud },
-    /* c8 */ { ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud },
-    /* b1 */ { u1, i1, i2, i4, i8, f2, f4, f8, ud, ud, ud, b1 },
+  // this matrix has to be consistent with AT_FORALL_SCALAR_TYPES_WITH_COMPLEX
+  // so that's why we have to add undefined as we are not sure what is the
+  // corrent values for the type promotions in complex type cases.
+  static constexpr ScalarType _promoteTypesLookup[static_cast<int>(
+      ScalarType::NumOptions)][static_cast<int>(ScalarType::NumOptions)] = {
+      /* u1  i1  i2  i4  i8  f2  f4  f8  c2  c4  c8  b1 */
+      /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, ud, ud, ud, u1},
+      /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, ud, ud, ud, i1},
+      /* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, ud, ud, ud, i2},
+      /* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, ud, ud, ud, i4},
+      /* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, ud, ud, ud, i8},
+      /* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, ud, ud, ud, f2},
+      /* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, ud, ud, ud, f4},
+      /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, ud, ud, ud, f8},
+      /* c2 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
+      /* c4 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
+      /* c8 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
+      /* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, ud, ud, ud, b1},
   };
   return _promoteTypesLookup[static_cast<int>(a)][static_cast<int>(b)];
 }
index 4d65208..6253c83 100644 (file)
@@ -17,17 +17,25 @@ C10_EXPORT void _ThrowRuntimeTypeLogicError(const string& msg) {
   AT_ERROR(msg);
 }
 
-const TypeMetaData _typeMetaDataInstance_uninitialized_ = detail::TypeMetaData(0, nullptr, nullptr, nullptr, nullptr, nullptr, TypeIdentifier::uninitialized(), "nullptr (uninitialized)");
+const TypeMetaData _typeMetaDataInstance_uninitialized_ = detail::TypeMetaData(
+    0,
+    nullptr,
+    nullptr,
+    nullptr,
+    nullptr,
+    nullptr,
+    TypeIdentifier::uninitialized(),
+    "nullptr (uninitialized)");
 
 } // namespace detail
 
 // TODO Inlineable on non-MSVC like other preallocated ids?
-template<>
-C10_EXPORT const detail::TypeMetaData* TypeMeta::_typeMetaDataInstance<detail::_Uninitialized>() noexcept {
+template <>
+C10_EXPORT const detail::TypeMetaData* TypeMeta::_typeMetaDataInstance<
+    detail::_Uninitialized>() noexcept {
   return &detail::_typeMetaDataInstance_uninitialized_;
 }
 
-
 TypeIdentifier TypeIdentifier::createTypeId() {
   static std::atomic<TypeIdentifier::underlying_type> counter(
       TypeMeta::Id<_CaffeHighestPreallocatedTypeId>().underlyingId());
index 448f44e..9cf9d92 100644 (file)
@@ -2,6 +2,7 @@
 
 #include <atomic>
 #include <cassert>
+#include <complex>
 #include <cstdlib>
 #include <iostream>
 #include <memory>
 #include <unordered_map>
 #include <unordered_set>
 #include <vector>
-#include <complex>
 #ifdef __GXX_RTTI
 #include <typeinfo>
 #endif
 
 #include <exception>
 
-#include "c10/util/Backtrace.h"
-#include "c10/util/Half.h"
 #include "c10/macros/Macros.h"
+#include "c10/util/Backtrace.h"
 #include "c10/util/C++17.h"
 #include "c10/util/Exception.h"
+#include "c10/util/Half.h"
 #include "c10/util/IdWrapper.h"
 
 #include "c10/util/Type.h"
 
 // Make at::Half a fundamental type.
 namespace std {
-template<>
-struct is_fundamental<at::Half> : std::true_type {
-};
-}  // namespace std
+template <>
+struct is_fundamental<at::Half> : std::true_type {};
+} // namespace std
 
 namespace caffe2 {
 
@@ -62,9 +61,7 @@ class C10_API TypeIdentifier final
  public:
   static TypeIdentifier createTypeId();
 
-  friend std::ostream& operator<<(
-      std::ostream& stream,
-      TypeIdentifier typeId);
+  friend std::ostream& operator<<(std::ostream& stream, TypeIdentifier typeId);
   friend bool operator<(TypeIdentifier lhs, TypeIdentifier rhs);
 
   // 0 is uint8_t (due to ScalarType BC constraint)
@@ -123,15 +120,22 @@ struct TypeMetaData final {
 
   TypeMetaData() = delete;
   constexpr TypeMetaData(
-    size_t itemsize,
-    New* newFn,
-    PlacementNew* placementNew,
-    Copy* copy,
-    PlacementDelete* placementDelete,
-    Delete* deleteFn,
-    TypeIdentifier id,
-    const char* name) noexcept
-  : itemsize_(itemsize), new_(newFn), placementNew_(placementNew), copy_(copy), placementDelete_(placementDelete), delete_(deleteFn), id_(id), name_(name) {}
+      size_t itemsize,
+      New* newFn,
+      PlacementNew* placementNew,
+      Copy* copy,
+      PlacementDelete* placementDelete,
+      Delete* deleteFn,
+      TypeIdentifier id,
+      const char* name) noexcept
+      : itemsize_(itemsize),
+        new_(newFn),
+        placementNew_(placementNew),
+        copy_(copy),
+        placementDelete_(placementDelete),
+        delete_(deleteFn),
+        id_(id),
+        name_(name) {}
 
   size_t itemsize_;
   New* new_;
@@ -167,21 +171,22 @@ inline void _PlacementNewNotDefault(void* /*ptr*/, size_t /*n*/) {
       " is not default-constructible.");
 }
 
-template<
+template <
     typename T,
     c10::guts::enable_if_t<std::is_default_constructible<T>::value>* = nullptr>
 inline constexpr TypeMetaData::PlacementNew* _PickPlacementNew() {
-  return
-    (std::is_fundamental<T>::value || std::is_pointer<T>::value)
-    ? nullptr
-    : &_PlacementNew<T>;
+  return (std::is_fundamental<T>::value || std::is_pointer<T>::value)
+      ? nullptr
+      : &_PlacementNew<T>;
 }
 
-template<
+template <
     typename T,
     c10::guts::enable_if_t<!std::is_default_constructible<T>::value>* = nullptr>
 inline constexpr TypeMetaData::PlacementNew* _PickPlacementNew() {
-  static_assert(!std::is_fundamental<T>::value && !std::is_pointer<T>::value, "this should have picked the other SFINAE case");
+  static_assert(
+      !std::is_fundamental<T>::value && !std::is_pointer<T>::value,
+      "this should have picked the other SFINAE case");
   return &_PlacementNewNotDefault<T>;
 }
 
@@ -197,7 +202,7 @@ inline void* _NewNotDefault() {
       " is not default-constructible.");
 }
 
-template<
+template <
     typename T,
     c10::guts::enable_if_t<std::is_default_constructible<T>::value>* = nullptr>
 inline constexpr TypeMetaData::New* _PickNew() {
@@ -233,23 +238,22 @@ inline void _CopyNotAllowed(const void* /*src*/, void* /*dst*/, size_t /*n*/) {
       " does not allow assignment.");
 }
 
-template<
+template <
     typename T,
-    c10::guts::enable_if_t<std::is_copy_assignable<T>::value>* = nullptr
-    >
+    c10::guts::enable_if_t<std::is_copy_assignable<T>::value>* = nullptr>
 inline constexpr TypeMetaData::Copy* _PickCopy() {
-  return
-    (std::is_fundamental<T>::value || std::is_pointer<T>::value)
-    ? nullptr
-    : &_Copy<T>;
+  return (std::is_fundamental<T>::value || std::is_pointer<T>::value)
+      ? nullptr
+      : &_Copy<T>;
 }
 
-template<
+template <
     typename T,
-    c10::guts::enable_if_t<!std::is_copy_assignable<T>::value>* = nullptr
-    >
+    c10::guts::enable_if_t<!std::is_copy_assignable<T>::value>* = nullptr>
 inline constexpr TypeMetaData::Copy* _PickCopy() {
-  static_assert(!std::is_fundamental<T>::value && !std::is_pointer<T>::value, "this should have picked the other SFINAE case");
+  static_assert(
+      !std::is_fundamental<T>::value && !std::is_pointer<T>::value,
+      "this should have picked the other SFINAE case");
   return &_CopyNotAllowed<T>;
 }
 
@@ -266,10 +270,9 @@ inline void _PlacementDelete(void* ptr, size_t n) {
 
 template <typename T>
 inline constexpr TypeMetaData::PlacementDelete* _PickPlacementDelete() {
-  return
-    (std::is_fundamental<T>::value || std::is_pointer<T>::value)
-    ? nullptr
-    : &_PlacementDelete<T>;
+  return (std::is_fundamental<T>::value || std::is_pointer<T>::value)
+      ? nullptr
+      : &_PlacementDelete<T>;
 }
 
 template <typename T>
@@ -278,7 +281,7 @@ inline void _Delete(void* ptr) {
   delete typed_ptr;
 }
 
-template<class T>
+template <class T>
 inline constexpr TypeMetaData::Delete* _PickDelete() noexcept {
   return &_Delete<T>;
 }
@@ -297,18 +300,16 @@ constexpr const char* _typeName(const char* literalName) noexcept {
 }
 #endif
 
-template<class T>
+template <class T>
 inline TypeMetaData _makeTypeMetaDataInstance(const char* typeName) {
-  return {
-    sizeof(T),
-    _PickNew<T>(),
-    _PickPlacementNew<T>(),
-    _PickCopy<T>(),
-    _PickPlacementDelete<T>(),
-    _PickDelete<T>(),
-    TypeIdentifier::Get<T>(),
-    typeName
-  };
+  return {sizeof(T),
+          _PickNew<T>(),
+          _PickPlacementNew<T>(),
+          _PickCopy<T>(),
+          _PickPlacementDelete<T>(),
+          _PickDelete<T>(),
+          TypeIdentifier::Get<T>(),
+          typeName};
 }
 
 class _Uninitialized final {};
@@ -350,7 +351,8 @@ class C10_API TypeMeta {
  private:
   // TypeMeta can only be created by Make, making sure that we do not
   // create incorrectly mixed up TypeMeta objects.
-  explicit constexpr TypeMeta(const detail::TypeMetaData* data) noexcept : data_(data) {}
+  explicit constexpr TypeMeta(const detail::TypeMetaData* data) noexcept
+      : data_(data) {}
 
  public:
   /**
@@ -431,28 +433,30 @@ class C10_API TypeMeta {
     // disabled for compilers that don't know '-Wundefined-var-template' and
     // would error at our attempt to disable it.
 #ifndef _MSC_VER
-#  pragma GCC diagnostic push
-#  pragma GCC diagnostic ignored "-Wpragmas"
-#  pragma GCC diagnostic ignored "-Wunknown-warning-option"
-#  pragma GCC diagnostic ignored "-Wundefined-var-template"
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wpragmas"
+#pragma GCC diagnostic ignored "-Wunknown-warning-option"
+#pragma GCC diagnostic ignored "-Wundefined-var-template"
 #endif
     return TypeMeta(_typeMetaDataInstance<T>());
 #ifndef _MSC_VER
-#  pragma GCC diagnostic pop
+#pragma GCC diagnostic pop
 #endif
   }
 
  private:
   const detail::TypeMetaData* data_;
 
-  template<class T>
+  template <class T>
   C10_API static const detail::TypeMetaData* _typeMetaDataInstance() noexcept;
 };
 
-template<>
-C10_EXPORT const detail::TypeMetaData* TypeMeta::_typeMetaDataInstance<detail::_Uninitialized>() noexcept;
+template <>
+C10_EXPORT const detail::TypeMetaData* TypeMeta::_typeMetaDataInstance<
+    detail::_Uninitialized>() noexcept;
 
-inline TypeMeta::TypeMeta() noexcept : data_(_typeMetaDataInstance<detail::_Uninitialized>()) {}
+inline TypeMeta::TypeMeta() noexcept
+    : data_(_typeMetaDataInstance<detail::_Uninitialized>()) {}
 
 inline bool operator==(const TypeMeta& lhs, const TypeMeta& rhs) noexcept {
   return (lhs.data_ == rhs.data_);
@@ -492,14 +496,15 @@ inline std::ostream& operator<<(
 #define EXPORT_IF_NOT_GCC
 #endif
 
-#define _CAFFE_KNOWN_TYPE_DEFINE_TYPEMETADATA_INSTANCE(T, Counter)        \
-  namespace detail {                                                      \
-  const TypeMetaData MACRO_CONCAT(_typeMetaDataInstance_, Counter) =      \
-      _makeTypeMetaDataInstance<T>(_typeName<T>(#T));                     \
-  }                                                                       \
-  template<>                                                              \
-  EXPORT_IF_NOT_GCC const detail::TypeMetaData* TypeMeta::_typeMetaDataInstance<T>() noexcept {     \
-    return &MACRO_CONCAT(detail::_typeMetaDataInstance_, Counter);        \
+#define _CAFFE_KNOWN_TYPE_DEFINE_TYPEMETADATA_INSTANCE(T, Counter)   \
+  namespace detail {                                                 \
+  const TypeMetaData MACRO_CONCAT(_typeMetaDataInstance_, Counter) = \
+      _makeTypeMetaDataInstance<T>(_typeName<T>(#T));                \
+  }                                                                  \
+  template <>                                                        \
+  EXPORT_IF_NOT_GCC const detail::TypeMetaData*                      \
+  TypeMeta::_typeMetaDataInstance<T>() noexcept {                    \
+    return &MACRO_CONCAT(detail::_typeMetaDataInstance_, Counter);   \
   }
 #define CAFFE_KNOWN_TYPE(T)                                               \
   template <>                                                             \
@@ -516,43 +521,50 @@ inline std::ostream& operator<<(
  * for your own types to allocate dynamic ids for them.
  */
 #ifdef _MSC_VER
-#define CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(PreallocatedId, T)              \
-  template <>                                                                 \
-  inline C10_EXPORT TypeIdentifier TypeIdentifier::Get<T>() {                 \
-    return TypeIdentifier(PreallocatedId);                                    \
-  }                                                                           \
-  namespace detail {                                                          \
-  C10_API extern const TypeMetaData                                           \
-      MACRO_CONCAT(_typeMetaDataInstance_preallocated_, PreallocatedId);      \
+#define CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(PreallocatedId, T) \
+  template <>                                                    \
+  inline C10_EXPORT TypeIdentifier TypeIdentifier::Get<T>() {    \
+    return TypeIdentifier(PreallocatedId);                       \
+  }                                                              \
+  namespace detail {                                             \
+  C10_API extern const TypeMetaData MACRO_CONCAT(                \
+      _typeMetaDataInstance_preallocated_,                       \
+      PreallocatedId);                                           \
   }
-#define CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(PreallocatedId, T)               \
-  namespace detail {                                                          \
-  C10_EXPORT const TypeMetaData                                               \
-    MACRO_CONCAT(_typeMetaDataInstance_preallocated_, PreallocatedId)         \
-      = _makeTypeMetaDataInstance<T>(_typeName<T>(#T));                       \
-  }                                                                           \
-  template<>                                                                  \
-  C10_EXPORT const detail::TypeMetaData* TypeMeta::_typeMetaDataInstance<T>() noexcept { \
-    return &MACRO_CONCAT(detail::_typeMetaDataInstance_preallocated_, PreallocatedId);   \
+#define CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(PreallocatedId, T)         \
+  namespace detail {                                                    \
+  C10_EXPORT const TypeMetaData MACRO_CONCAT(                           \
+      _typeMetaDataInstance_preallocated_,                              \
+      PreallocatedId) = _makeTypeMetaDataInstance<T>(_typeName<T>(#T)); \
+  }                                                                     \
+  template <>                                                           \
+  C10_EXPORT const detail::TypeMetaData*                                \
+  TypeMeta::_typeMetaDataInstance<T>() noexcept {                       \
+    return &MACRO_CONCAT(                                               \
+        detail::_typeMetaDataInstance_preallocated_, PreallocatedId);   \
   }
 #else // _MSC_VER
-#define CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(PreallocatedId, T)              \
-  template <>                                                                 \
-  inline C10_EXPORT TypeIdentifier TypeIdentifier::Get<T>() {                 \
-    return TypeIdentifier(PreallocatedId);                                    \
-  }                                                                           \
-  namespace detail {                                                          \
-  C10_EXPORT extern const TypeMetaData                                        \
-      MACRO_CONCAT(_typeMetaDataInstance_preallocated_, PreallocatedId);      \
-  }                                                                           \
-  template<>                                                                  \
-  inline const detail::TypeMetaData* TypeMeta::_typeMetaDataInstance<T>() noexcept {    \
-    return &MACRO_CONCAT(detail::_typeMetaDataInstance_preallocated_, PreallocatedId);  \
+#define CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(PreallocatedId, T)      \
+  template <>                                                         \
+  inline C10_EXPORT TypeIdentifier TypeIdentifier::Get<T>() {         \
+    return TypeIdentifier(PreallocatedId);                            \
+  }                                                                   \
+  namespace detail {                                                  \
+  C10_EXPORT extern const TypeMetaData MACRO_CONCAT(                  \
+      _typeMetaDataInstance_preallocated_,                            \
+      PreallocatedId);                                                \
+  }                                                                   \
+  template <>                                                         \
+  inline const detail::TypeMetaData*                                  \
+  TypeMeta::_typeMetaDataInstance<T>() noexcept {                     \
+    return &MACRO_CONCAT(                                             \
+        detail::_typeMetaDataInstance_preallocated_, PreallocatedId); \
   }
-#define CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(PreallocatedId, T)               \
-  namespace detail {                                                          \
-  const TypeMetaData MACRO_CONCAT(_typeMetaDataInstance_preallocated_, PreallocatedId)  \
-      = _makeTypeMetaDataInstance<T>(_typeName<T>(#T));                       \
+#define CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(PreallocatedId, T)         \
+  namespace detail {                                                    \
+  const TypeMetaData MACRO_CONCAT(                                      \
+      _typeMetaDataInstance_preallocated_,                              \
+      PreallocatedId) = _makeTypeMetaDataInstance<T>(_typeName<T>(#T)); \
   }
 #endif
 
index de199d8..b165551 100644 (file)
@@ -1,15 +1,17 @@
-#include <torch/csrc/python_headers.h>
 #include <torch/csrc/utils/tensor_dtypes.h>
 #include <torch/csrc/Dtype.h>
 #include <torch/csrc/DynamicTypes.h>
 #include <torch/csrc/Exceptions.h>
 #include <torch/csrc/autograd/generated/VariableType.h>
+#include <torch/csrc/python_headers.h>
 #include <torch/csrc/utils/tensor_types.h>
 
-namespace torch { namespace utils {
+namespace torch {
+namespace utils {
 
-static std::pair<std::string, std::string> getDtypeNames(at::ScalarType scalarType) {
-  switch(scalarType) {
+static std::pair<std::string, std::string> getDtypeNames(
+    at::ScalarType scalarType) {
+  switch (scalarType) {
     case at::ScalarType::Byte:
       // no "byte" because byte is signed in numpy and we overload
       // byte to mean bool often
@@ -45,31 +47,35 @@ static std::pair<std::string, std::string> getDtypeNames(at::ScalarType scalarTy
 
 void initializeDtypes() {
   auto torch_module = THPObjectPtr(PyImport_ImportModule("torch"));
-  if (!torch_module) throw python_error();
+  if (!torch_module)
+    throw python_error();
 
-#define DEFINE_SCALAR_TYPE(_1,n,_2) at::ScalarType::n,
+#define DEFINE_SCALAR_TYPE(_1, n, _2) at::ScalarType::n,
 
   at::ScalarType all_scalar_types[] = {
-    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_SCALAR_TYPE)
-  };
+      AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_SCALAR_TYPE)};
 
-  for (at::ScalarType scalarType: all_scalar_types) {
+  for (at::ScalarType scalarType : all_scalar_types) {
     std::string primary_name, legacy_name;
     std::tie(primary_name, legacy_name) = getDtypeNames(scalarType);
-    std::string name = std::string(PyModule_GetName(torch_module.get())) + '.' + primary_name;
-    PyObject *dtype = THPDtype_New(scalarType, name);
+    std::string name =
+        std::string(PyModule_GetName(torch_module.get())) + '.' + primary_name;
+    PyObject* dtype = THPDtype_New(scalarType, name);
     torch::registerDtypeObject((THPDtype*)dtype, scalarType);
     Py_INCREF(dtype);
-    if (PyModule_AddObject(torch_module.get(), primary_name.c_str(), dtype) != 0) {
+    if (PyModule_AddObject(torch_module.get(), primary_name.c_str(), dtype) !=
+        0) {
       throw python_error();
     }
     if (legacy_name != "") {
       Py_INCREF(dtype);
-      if (PyModule_AddObject(torch_module.get(), legacy_name.c_str(), dtype) != 0) {
+      if (PyModule_AddObject(torch_module.get(), legacy_name.c_str(), dtype) !=
+          0) {
         throw python_error();
       }
     }
   }
 }
 
-}} // namespace torch::utils
+} // namespace utils
+} // namespace torch