#include <iostream>
#include <sstream>
-
using namespace std;
namespace at {
return dtype;
}
-
static DLContext getDLContext(const Type& type, const int64_t& device_id) {
DLContext ctx;
ctx.device_id = device_id;
return ctx;
}
-
static DeviceType getATenDeviceType(const DLContext& ctx) {
switch (ctx.device_type) {
case DLDeviceType::kDLCPU:
case DLDeviceType::kDLROCM:
return DeviceType::HIP;
default:
- throw std::logic_error("Unsupported device_type: " + std::to_string(ctx.device_type));
+ throw std::logic_error(
+ "Unsupported device_type: " + std::to_string(ctx.device_type));
}
return DeviceType::CPU; // impossible
}
-
ScalarType toScalarType(const DLDataType& dtype) {
ScalarType stype;
- if (dtype.lanes != 1) throw std::logic_error("ATen does not support lanes != 1");
+ if (dtype.lanes != 1)
+ throw std::logic_error("ATen does not support lanes != 1");
switch (dtype.code) {
case DLDataTypeCode::kDLUInt:
switch (dtype.bits) {
stype = ScalarType::Byte;
break;
default:
- throw std::logic_error("Unsupported kUInt bits " + std::to_string(dtype.bits));
+ throw std::logic_error(
+ "Unsupported kUInt bits " + std::to_string(dtype.bits));
}
break;
case DLDataTypeCode::kDLInt:
stype = ScalarType::Long;
break;
default:
- throw std::logic_error("Unsupported kInt bits " + std::to_string(dtype.bits));
+ throw std::logic_error(
+ "Unsupported kInt bits " + std::to_string(dtype.bits));
}
break;
case DLDataTypeCode::kDLFloat:
stype = ScalarType::Double;
break;
default:
- throw std::logic_error("Unsupported kFloat bits " + std::to_string(dtype.bits));
+ throw std::logic_error(
+ "Unsupported kFloat bits " + std::to_string(dtype.bits));
}
break;
default:
DLManagedTensor tensor;
};
-void deleter(DLManagedTensor * arg) {
+void deleter(DLManagedTensor* arg) {
delete static_cast<ATenDLMTensor*>(arg->manager_ctx);
}
-
-// This function returns a shared_ptr to memory managed DLpack tensor constructed
-// out of ATen tensor
+// This function returns a shared_ptr to memory managed DLpack tensor
+// constructed out of ATen tensor
DLManagedTensor* toDLPack(const Tensor& src) {
- ATenDLMTensor * atDLMTensor(new ATenDLMTensor);
+ ATenDLMTensor* atDLMTensor(new ATenDLMTensor);
atDLMTensor->handle = src;
atDLMTensor->tensor.manager_ctx = atDLMTensor;
atDLMTensor->tensor.deleter = &deleter;
atDLMTensor->tensor.dl_tensor.ctx = getDLContext(src.type(), device_id);
atDLMTensor->tensor.dl_tensor.ndim = src.dim();
atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src);
- atDLMTensor->tensor.dl_tensor.shape = const_cast<int64_t*>(src.sizes().data());
- atDLMTensor->tensor.dl_tensor.strides = const_cast<int64_t*>(src.strides().data());
+ atDLMTensor->tensor.dl_tensor.shape =
+ const_cast<int64_t*>(src.sizes().data());
+ atDLMTensor->tensor.dl_tensor.strides =
+ const_cast<int64_t*>(src.strides().data());
atDLMTensor->tensor.dl_tensor.byte_offset = 0;
return &(atDLMTensor->tensor);
}
-
Tensor fromDLPack(const DLManagedTensor* src) {
DeviceType device_type = getATenDeviceType(src->dl_tensor.ctx);
ScalarType stype = toScalarType(src->dl_tensor.dtype);
- auto deleter = [src](void * self) {
+ auto deleter = [src](void* self) {
src->deleter(const_cast<DLManagedTensor*>(src));
};
- return at::from_blob(src->dl_tensor.data,
+ return at::from_blob(
+ src->dl_tensor.data,
IntArrayRef(src->dl_tensor.shape, src->dl_tensor.ndim),
IntArrayRef(src->dl_tensor.strides, src->dl_tensor.ndim),
deleter,
at::device(device_type).dtype(stype));
}
-} //namespace at
+} // namespace at
#include <string>
#include <utility>
-#include <c10/macros/Macros.h>
#include <c10/core/ScalarType.h>
+#include <c10/macros/Macros.h>
#include <c10/util/Half.h>
namespace c10 {
/**
* Scalar represents a 0-dimensional tensor which contains a single element.
- * Unlike a tensor, numeric literals (in C++) are implicitly convertible to Scalar
- * (which is why, for example, we provide both add(Tensor) and add(Scalar) overloads
- * for many operations). It may also be used in circumstances where you statically
- * know a tensor is 0-dim and single size, but don't know it's type.
+ * Unlike a tensor, numeric literals (in C++) are implicitly convertible to
+ * Scalar (which is why, for example, we provide both add(Tensor) and
+ * add(Scalar) overloads for many operations). It may also be used in
+ * circumstances where you statically know a tensor is 0-dim and single size,
+ * but don't know it's type.
*/
class C10_API Scalar {
public:
Scalar() : Scalar(int64_t(0)) {}
-#define DEFINE_IMPLICIT_CTOR(type,name,member) \
- Scalar(type vv) \
- : tag(Tag::HAS_##member) { \
- v . member = convert<decltype(v.member),type>(vv); \
+#define DEFINE_IMPLICIT_CTOR(type, name, member) \
+ Scalar(type vv) : tag(Tag::HAS_##member) { \
+ v.member = convert<decltype(v.member), type>(vv); \
}
// We can't set v in the initializer list using the
// syntax v{ .member = ... } because it doesn't work on MSVC
#undef DEFINE_IMPLICIT_CTOR
-// Value* is both implicitly convertible to SymbolicVariable and bool which
-// causes ambiguosity error. Specialized constructor for bool resolves this problem.
-template <typename T,
- typename std::enable_if<std::is_same<T, bool>::value, bool>::type* = nullptr>
-Scalar(T vv)
-: tag(Tag::HAS_i) {
- v.i = convert<decltype(v.i), bool>(vv);
-}
+ // Value* is both implicitly convertible to SymbolicVariable and bool which
+ // causes ambiguosity error. Specialized constructor for bool resolves this
+ // problem.
+ template <
+ typename T,
+ typename std::enable_if<std::is_same<T, bool>::value, bool>::type* =
+ nullptr>
+ Scalar(T vv) : tag(Tag::HAS_i) {
+ v.i = convert<decltype(v.i), bool>(vv);
+ }
#define DEFINE_IMPLICIT_COMPLEX_CTOR(type, name, member) \
Scalar(type vv) : tag(Tag::HAS_##member) { \
v.member[1] = c10::convert<double>(vv.imag()); \
}
- DEFINE_IMPLICIT_COMPLEX_CTOR(at::ComplexHalf,ComplexHalf,z)
- DEFINE_IMPLICIT_COMPLEX_CTOR(std::complex<float>,ComplexFloat,z)
- DEFINE_IMPLICIT_COMPLEX_CTOR(std::complex<double>,ComplexDouble,z)
+ DEFINE_IMPLICIT_COMPLEX_CTOR(at::ComplexHalf, ComplexHalf, z)
+ DEFINE_IMPLICIT_COMPLEX_CTOR(std::complex<float>, ComplexFloat, z)
+ DEFINE_IMPLICIT_COMPLEX_CTOR(std::complex<double>, ComplexDouble, z)
#undef DEFINE_IMPLICIT_COMPLEX_CTOR
-#define DEFINE_ACCESSOR(type,name,member) \
- type to##name () const { \
- if (Tag::HAS_d == tag) { \
- return checked_convert<type, double>(v.d, #type); \
- } else if (Tag::HAS_z == tag) { \
- return checked_convert<type, std::complex<double>>({v.z[0], v.z[1]}, #type); \
- } else { \
- return checked_convert<type, int64_t>(v.i, #type); \
- } \
+#define DEFINE_ACCESSOR(type, name, member) \
+ type to##name() const { \
+ if (Tag::HAS_d == tag) { \
+ return checked_convert<type, double>(v.d, #type); \
+ } else if (Tag::HAS_z == tag) { \
+ return checked_convert<type, std::complex<double>>( \
+ {v.z[0], v.z[1]}, #type); \
+ } else { \
+ return checked_convert<type, int64_t>(v.i, #type); \
+ } \
}
// TODO: Support ComplexHalf accessor
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_ACCESSOR)
- //also support scalar.to<int64_t>();
- template<typename T>
+ // also support scalar.to<int64_t>();
+ template <typename T>
T to();
#undef DEFINE_ACCESSOR
Scalar operator-() const;
-private:
+ private:
enum class Tag { HAS_d, HAS_i, HAS_z };
Tag tag;
union {
};
// define the scalar.to<int64_t>() specializations
-template<typename T>
+template <typename T>
inline T Scalar::to() {
throw std::runtime_error("to() cast to unexpected type.");
}
-#define DEFINE_TO(T,name,_) \
-template<> \
-inline T Scalar::to<T>() { \
- return to##name(); \
-}
+#define DEFINE_TO(T, name, _) \
+ template <> \
+ inline T Scalar::to<T>() { \
+ return to##name(); \
+ }
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_TO)
#undef DEFINE_TO
-}
+} // namespace c10
#include <c10/util/Optional.h>
#include <c10/util/typeid.h>
+#include <complex>
#include <cstdint>
#include <iostream>
-#include <complex>
namespace c10 {
// NB: Order matters for this macro; it is relied upon in
// _promoteTypesLookup and the serialization format.
-#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(_) \
-_(uint8_t,Byte,i) /* 0 */ \
-_(int8_t,Char,i) /* 1 */ \
-_(int16_t,Short,i) /* 2 */ \
-_(int,Int,i) /* 3 */ \
-_(int64_t,Long,i) /* 4 */ \
-_(at::Half,Half,d) /* 5 */ \
-_(float,Float,d) /* 6 */ \
-_(double,Double,d) /* 7 */ \
-_(at::ComplexHalf,ComplexHalf,z) /* 8 */ \
-_(std::complex<float>,ComplexFloat,z) /* 9 */ \
-_(std::complex<double>,ComplexDouble,z) /* 10 */ \
-_(bool,Bool,i) /* 11 */
+#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(_) \
+ _(uint8_t, Byte, i) /* 0 */ \
+ _(int8_t, Char, i) /* 1 */ \
+ _(int16_t, Short, i) /* 2 */ \
+ _(int, Int, i) /* 3 */ \
+ _(int64_t, Long, i) /* 4 */ \
+ _(at::Half, Half, d) /* 5 */ \
+ _(float, Float, d) /* 6 */ \
+ _(double, Double, d) /* 7 */ \
+ _(at::ComplexHalf, ComplexHalf, z) /* 8 */ \
+ _(std::complex<float>, ComplexFloat, z) /* 9 */ \
+ _(std::complex<double>, ComplexDouble, z) /* 10 */ \
+ _(bool, Bool, i) /* 11 */
// If you want to support ComplexHalf for real, replace occurrences
// of this macro with AT_FORALL_SCALAR_TYPES_WITH_COMPLEX. But
// beware: convert() doesn't work for all the conversions you need...
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(_) \
-_(uint8_t,Byte,i) \
-_(int8_t,Char,i) \
-_(int16_t,Short,i) \
-_(int,Int,i) \
-_(int64_t,Long,i) \
-_(at::Half,Half,d) \
-_(float,Float,d) \
-_(double,Double,d) \
-_(std::complex<float>,ComplexFloat,z) \
-_(std::complex<double>,ComplexDouble,z) \
-_(bool,Bool,i)
+ _(uint8_t, Byte, i) \
+ _(int8_t, Char, i) \
+ _(int16_t, Short, i) \
+ _(int, Int, i) \
+ _(int64_t, Long, i) \
+ _(at::Half, Half, d) \
+ _(float, Float, d) \
+ _(double, Double, d) \
+ _(std::complex<float>, ComplexFloat, z) \
+ _(std::complex<double>, ComplexDouble, z) \
+ _(bool, Bool, i)
#define AT_FORALL_SCALAR_TYPES(_) \
-_(uint8_t,Byte,i) \
-_(int8_t,Char,i) \
-_(int16_t,Short,i) \
-_(int,Int,i) \
-_(int64_t,Long,i) \
-_(at::Half,Half,d) \
-_(float,Float,d) \
-_(double,Double,d)
+ _(uint8_t, Byte, i) \
+ _(int8_t, Char, i) \
+ _(int16_t, Short, i) \
+ _(int, Int, i) \
+ _(int64_t, Long, i) \
+ _(at::Half, Half, d) \
+ _(float, Float, d) \
+ _(double, Double, d)
#define AT_FORALL_SCALAR_TYPES_AND_BOOL(_) \
-_(uint8_t,Byte,i) \
-_(int8_t,Char,i) \
-_(int16_t,Short,i) \
-_(int,Int,i) \
-_(int64_t,Long,i) \
-_(at::Half,Half,d) \
-_(float,Float,d) \
-_(double,Double,d) \
-_(bool,Bool,i)
+ _(uint8_t, Byte, i) \
+ _(int8_t, Char, i) \
+ _(int16_t, Short, i) \
+ _(int, Int, i) \
+ _(int64_t, Long, i) \
+ _(at::Half, Half, d) \
+ _(float, Float, d) \
+ _(double, Double, d) \
+ _(bool, Bool, i)
#define AT_FORALL_SCALAR_TYPES_EXCEPT_HALF(_) \
-_(uint8_t,Byte,i) \
-_(int8_t,Char,i) \
-_(int16_t,Short,i) \
-_(int,Int,i) \
-_(int64_t,Long,i) \
-_(float,Float,d) \
-_(double,Double,d)
+ _(uint8_t, Byte, i) \
+ _(int8_t, Char, i) \
+ _(int16_t, Short, i) \
+ _(int, Int, i) \
+ _(int64_t, Long, i) \
+ _(float, Float, d) \
+ _(double, Double, d)
enum class ScalarType : int8_t {
-#define DEFINE_ENUM(_1,n,_2) \
- n,
+#define DEFINE_ENUM(_1, n, _2) n,
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_ENUM)
#undef DEFINE_ENUM
- Undefined,
+ Undefined,
NumOptions
};
static inline caffe2::TypeMeta scalarTypeToTypeMeta(ScalarType scalar_type) {
-#define DEFINE_CASE(ctype,name,_) \
- case ScalarType:: name : return caffe2::TypeMeta::Make<ctype>();
+#define DEFINE_CASE(ctype, name, _) \
+ case ScalarType::name: \
+ return caffe2::TypeMeta::Make<ctype>();
- switch(scalar_type) {
+ switch (scalar_type) {
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE)
- case ScalarType::Undefined: return caffe2::TypeMeta();
- default: AT_ERROR("Unrecognized Scalartype ", scalar_type, " (please report this error)");
+ case ScalarType::Undefined:
+ return caffe2::TypeMeta();
+ default:
+ AT_ERROR(
+ "Unrecognized Scalartype ",
+ scalar_type,
+ " (please report this error)");
}
#undef DEFINE_CASE
}
-static inline c10::optional<ScalarType> tryTypeMetaToScalarType(caffe2::TypeMeta dtype) {
-#define DEFINE_IF(ctype, name, _) \
+static inline c10::optional<ScalarType> tryTypeMetaToScalarType(
+ caffe2::TypeMeta dtype) {
+#define DEFINE_IF(ctype, name, _) \
if (dtype == caffe2::TypeMeta::Make<ctype>()) { \
- return {ScalarType::name}; \
+ return {ScalarType::name}; \
}
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_IF)
#undef DEFINE_IF
if (auto scalar_type = tryTypeMetaToScalarType(dtype)) {
return *scalar_type;
}
- AT_ERROR("Unsupported TypeMeta in ATen: ", dtype, " (please report this error)");
+ AT_ERROR(
+ "Unsupported TypeMeta in ATen: ", dtype, " (please report this error)");
}
static inline bool operator==(ScalarType t, caffe2::TypeMeta m) {
return t == m;
}
-#define DEFINE_CONSTANT(_,name,_2) \
-constexpr ScalarType k##name = ScalarType::name;
+#define DEFINE_CONSTANT(_, name, _2) \
+ constexpr ScalarType k##name = ScalarType::name;
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CONSTANT)
#undef DEFINE_CONSTANT
-static inline const char * toString(ScalarType t) {
-#define DEFINE_CASE(_,name,_2) \
- case ScalarType:: name : return #name;
+static inline const char* toString(ScalarType t) {
+#define DEFINE_CASE(_, name, _2) \
+ case ScalarType::name: \
+ return #name;
- switch(t) {
+ switch (t) {
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE)
default:
return "UNKNOWN_SCALAR";
}
static inline size_t elementSize(ScalarType t) {
-#define CASE_ELEMENTSIZE_CASE(ctype,name,_2) \
- case ScalarType:: name : return sizeof(ctype);
+#define CASE_ELEMENTSIZE_CASE(ctype, name, _2) \
+ case ScalarType::name: \
+ return sizeof(ctype);
- switch(t) {
+ switch (t) {
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(CASE_ELEMENTSIZE_CASE)
default:
AT_ERROR("Unknown ScalarType");
}
static inline bool isIntegralType(ScalarType t) {
- return (t == ScalarType::Byte ||
- t == ScalarType::Char ||
- t == ScalarType::Int ||
- t == ScalarType::Long ||
- t == ScalarType::Short);
+ return (
+ t == ScalarType::Byte || t == ScalarType::Char || t == ScalarType::Int ||
+ t == ScalarType::Long || t == ScalarType::Short);
}
static inline bool isFloatingType(ScalarType t) {
- return (t == ScalarType::Double ||
- t == ScalarType::Float ||
- t == ScalarType::Half);
+ return (
+ t == ScalarType::Double || t == ScalarType::Float ||
+ t == ScalarType::Half);
}
static inline bool isComplexType(ScalarType t) {
- return (t == ScalarType::ComplexHalf ||
- t == ScalarType::ComplexFloat ||
- t == ScalarType::ComplexDouble);
+ return (
+ t == ScalarType::ComplexHalf || t == ScalarType::ComplexFloat ||
+ t == ScalarType::ComplexDouble);
}
static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
return ScalarType::Undefined;
}
if (isComplexType(a) || isComplexType(b)) {
- AT_ERROR("promoteTypes with complex numbers is not handled yet; figure out what the correct rules should be");
+ AT_ERROR(
+ "promoteTypes with complex numbers is not handled yet; figure out what the correct rules should be");
}
- // this matrix has to be consistent with AT_FORALL_SCALAR_TYPES_WITH_COMPLEX so that's why we have to add
- // undefined as we are not sure what is the corrent values for the type promotions in complex type cases.
- static constexpr ScalarType _promoteTypesLookup
- [static_cast<int>(ScalarType::NumOptions)]
- [static_cast<int>(ScalarType::NumOptions)] = {
- /* u1 i1 i2 i4 i8 f2 f4 f8 c2 c4 c8 b1 */
- /* u1 */ { u1, i2, i2, i4, i8, f2, f4, f8, ud, ud, ud, u1 },
- /* i1 */ { i2, i1, i2, i4, i8, f2, f4, f8, ud, ud, ud, i1 },
- /* i2 */ { i2, i2, i2, i4, i8, f2, f4, f8, ud, ud, ud, i2 },
- /* i4 */ { i4, i4, i4, i4, i8, f2, f4, f8, ud, ud, ud, i4 },
- /* i8 */ { i8, i8, i8, i8, i8, f2, f4, f8, ud, ud, ud, i8 },
- /* f2 */ { f2, f2, f2, f2, f2, f2, f4, f8, ud, ud, ud, f2 },
- /* f4 */ { f4, f4, f4, f4, f4, f4, f4, f8, ud, ud, ud, f4 },
- /* f8 */ { f8, f8, f8, f8, f8, f8, f8, f8, ud, ud, ud, f8 },
- /* c2 */ { ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud },
- /* c4 */ { ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud },
- /* c8 */ { ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud },
- /* b1 */ { u1, i1, i2, i4, i8, f2, f4, f8, ud, ud, ud, b1 },
+ // this matrix has to be consistent with AT_FORALL_SCALAR_TYPES_WITH_COMPLEX
+ // so that's why we have to add undefined as we are not sure what is the
+ // corrent values for the type promotions in complex type cases.
+ static constexpr ScalarType _promoteTypesLookup[static_cast<int>(
+ ScalarType::NumOptions)][static_cast<int>(ScalarType::NumOptions)] = {
+ /* u1 i1 i2 i4 i8 f2 f4 f8 c2 c4 c8 b1 */
+ /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, ud, ud, ud, u1},
+ /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, ud, ud, ud, i1},
+ /* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, ud, ud, ud, i2},
+ /* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, ud, ud, ud, i4},
+ /* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, ud, ud, ud, i8},
+ /* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, ud, ud, ud, f2},
+ /* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, ud, ud, ud, f4},
+ /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, ud, ud, ud, f8},
+ /* c2 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
+ /* c4 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
+ /* c8 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
+ /* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, ud, ud, ud, b1},
};
return _promoteTypesLookup[static_cast<int>(a)][static_cast<int>(b)];
}
AT_ERROR(msg);
}
-const TypeMetaData _typeMetaDataInstance_uninitialized_ = detail::TypeMetaData(0, nullptr, nullptr, nullptr, nullptr, nullptr, TypeIdentifier::uninitialized(), "nullptr (uninitialized)");
+const TypeMetaData _typeMetaDataInstance_uninitialized_ = detail::TypeMetaData(
+ 0,
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ TypeIdentifier::uninitialized(),
+ "nullptr (uninitialized)");
} // namespace detail
// TODO Inlineable on non-MSVC like other preallocated ids?
-template<>
-C10_EXPORT const detail::TypeMetaData* TypeMeta::_typeMetaDataInstance<detail::_Uninitialized>() noexcept {
+template <>
+C10_EXPORT const detail::TypeMetaData* TypeMeta::_typeMetaDataInstance<
+ detail::_Uninitialized>() noexcept {
return &detail::_typeMetaDataInstance_uninitialized_;
}
-
TypeIdentifier TypeIdentifier::createTypeId() {
static std::atomic<TypeIdentifier::underlying_type> counter(
TypeMeta::Id<_CaffeHighestPreallocatedTypeId>().underlyingId());
#include <atomic>
#include <cassert>
+#include <complex>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>
-#include <complex>
#ifdef __GXX_RTTI
#include <typeinfo>
#endif
#include <exception>
-#include "c10/util/Backtrace.h"
-#include "c10/util/Half.h"
#include "c10/macros/Macros.h"
+#include "c10/util/Backtrace.h"
#include "c10/util/C++17.h"
#include "c10/util/Exception.h"
+#include "c10/util/Half.h"
#include "c10/util/IdWrapper.h"
#include "c10/util/Type.h"
// Make at::Half a fundamental type.
namespace std {
-template<>
-struct is_fundamental<at::Half> : std::true_type {
-};
-} // namespace std
+template <>
+struct is_fundamental<at::Half> : std::true_type {};
+} // namespace std
namespace caffe2 {
public:
static TypeIdentifier createTypeId();
- friend std::ostream& operator<<(
- std::ostream& stream,
- TypeIdentifier typeId);
+ friend std::ostream& operator<<(std::ostream& stream, TypeIdentifier typeId);
friend bool operator<(TypeIdentifier lhs, TypeIdentifier rhs);
// 0 is uint8_t (due to ScalarType BC constraint)
TypeMetaData() = delete;
constexpr TypeMetaData(
- size_t itemsize,
- New* newFn,
- PlacementNew* placementNew,
- Copy* copy,
- PlacementDelete* placementDelete,
- Delete* deleteFn,
- TypeIdentifier id,
- const char* name) noexcept
- : itemsize_(itemsize), new_(newFn), placementNew_(placementNew), copy_(copy), placementDelete_(placementDelete), delete_(deleteFn), id_(id), name_(name) {}
+ size_t itemsize,
+ New* newFn,
+ PlacementNew* placementNew,
+ Copy* copy,
+ PlacementDelete* placementDelete,
+ Delete* deleteFn,
+ TypeIdentifier id,
+ const char* name) noexcept
+ : itemsize_(itemsize),
+ new_(newFn),
+ placementNew_(placementNew),
+ copy_(copy),
+ placementDelete_(placementDelete),
+ delete_(deleteFn),
+ id_(id),
+ name_(name) {}
size_t itemsize_;
New* new_;
" is not default-constructible.");
}
-template<
+template <
typename T,
c10::guts::enable_if_t<std::is_default_constructible<T>::value>* = nullptr>
inline constexpr TypeMetaData::PlacementNew* _PickPlacementNew() {
- return
- (std::is_fundamental<T>::value || std::is_pointer<T>::value)
- ? nullptr
- : &_PlacementNew<T>;
+ return (std::is_fundamental<T>::value || std::is_pointer<T>::value)
+ ? nullptr
+ : &_PlacementNew<T>;
}
-template<
+template <
typename T,
c10::guts::enable_if_t<!std::is_default_constructible<T>::value>* = nullptr>
inline constexpr TypeMetaData::PlacementNew* _PickPlacementNew() {
- static_assert(!std::is_fundamental<T>::value && !std::is_pointer<T>::value, "this should have picked the other SFINAE case");
+ static_assert(
+ !std::is_fundamental<T>::value && !std::is_pointer<T>::value,
+ "this should have picked the other SFINAE case");
return &_PlacementNewNotDefault<T>;
}
" is not default-constructible.");
}
-template<
+template <
typename T,
c10::guts::enable_if_t<std::is_default_constructible<T>::value>* = nullptr>
inline constexpr TypeMetaData::New* _PickNew() {
" does not allow assignment.");
}
-template<
+template <
typename T,
- c10::guts::enable_if_t<std::is_copy_assignable<T>::value>* = nullptr
- >
+ c10::guts::enable_if_t<std::is_copy_assignable<T>::value>* = nullptr>
inline constexpr TypeMetaData::Copy* _PickCopy() {
- return
- (std::is_fundamental<T>::value || std::is_pointer<T>::value)
- ? nullptr
- : &_Copy<T>;
+ return (std::is_fundamental<T>::value || std::is_pointer<T>::value)
+ ? nullptr
+ : &_Copy<T>;
}
-template<
+template <
typename T,
- c10::guts::enable_if_t<!std::is_copy_assignable<T>::value>* = nullptr
- >
+ c10::guts::enable_if_t<!std::is_copy_assignable<T>::value>* = nullptr>
inline constexpr TypeMetaData::Copy* _PickCopy() {
- static_assert(!std::is_fundamental<T>::value && !std::is_pointer<T>::value, "this should have picked the other SFINAE case");
+ static_assert(
+ !std::is_fundamental<T>::value && !std::is_pointer<T>::value,
+ "this should have picked the other SFINAE case");
return &_CopyNotAllowed<T>;
}
template <typename T>
inline constexpr TypeMetaData::PlacementDelete* _PickPlacementDelete() {
- return
- (std::is_fundamental<T>::value || std::is_pointer<T>::value)
- ? nullptr
- : &_PlacementDelete<T>;
+ return (std::is_fundamental<T>::value || std::is_pointer<T>::value)
+ ? nullptr
+ : &_PlacementDelete<T>;
}
template <typename T>
delete typed_ptr;
}
-template<class T>
+template <class T>
inline constexpr TypeMetaData::Delete* _PickDelete() noexcept {
return &_Delete<T>;
}
}
#endif
-template<class T>
+template <class T>
inline TypeMetaData _makeTypeMetaDataInstance(const char* typeName) {
- return {
- sizeof(T),
- _PickNew<T>(),
- _PickPlacementNew<T>(),
- _PickCopy<T>(),
- _PickPlacementDelete<T>(),
- _PickDelete<T>(),
- TypeIdentifier::Get<T>(),
- typeName
- };
+ return {sizeof(T),
+ _PickNew<T>(),
+ _PickPlacementNew<T>(),
+ _PickCopy<T>(),
+ _PickPlacementDelete<T>(),
+ _PickDelete<T>(),
+ TypeIdentifier::Get<T>(),
+ typeName};
}
class _Uninitialized final {};
private:
// TypeMeta can only be created by Make, making sure that we do not
// create incorrectly mixed up TypeMeta objects.
- explicit constexpr TypeMeta(const detail::TypeMetaData* data) noexcept : data_(data) {}
+ explicit constexpr TypeMeta(const detail::TypeMetaData* data) noexcept
+ : data_(data) {}
public:
/**
// disabled for compilers that don't know '-Wundefined-var-template' and
// would error at our attempt to disable it.
#ifndef _MSC_VER
-# pragma GCC diagnostic push
-# pragma GCC diagnostic ignored "-Wpragmas"
-# pragma GCC diagnostic ignored "-Wunknown-warning-option"
-# pragma GCC diagnostic ignored "-Wundefined-var-template"
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wpragmas"
+#pragma GCC diagnostic ignored "-Wunknown-warning-option"
+#pragma GCC diagnostic ignored "-Wundefined-var-template"
#endif
return TypeMeta(_typeMetaDataInstance<T>());
#ifndef _MSC_VER
-# pragma GCC diagnostic pop
+#pragma GCC diagnostic pop
#endif
}
private:
const detail::TypeMetaData* data_;
- template<class T>
+ template <class T>
C10_API static const detail::TypeMetaData* _typeMetaDataInstance() noexcept;
};
-template<>
-C10_EXPORT const detail::TypeMetaData* TypeMeta::_typeMetaDataInstance<detail::_Uninitialized>() noexcept;
+template <>
+C10_EXPORT const detail::TypeMetaData* TypeMeta::_typeMetaDataInstance<
+ detail::_Uninitialized>() noexcept;
-inline TypeMeta::TypeMeta() noexcept : data_(_typeMetaDataInstance<detail::_Uninitialized>()) {}
+inline TypeMeta::TypeMeta() noexcept
+ : data_(_typeMetaDataInstance<detail::_Uninitialized>()) {}
inline bool operator==(const TypeMeta& lhs, const TypeMeta& rhs) noexcept {
return (lhs.data_ == rhs.data_);
#define EXPORT_IF_NOT_GCC
#endif
-#define _CAFFE_KNOWN_TYPE_DEFINE_TYPEMETADATA_INSTANCE(T, Counter) \
- namespace detail { \
- const TypeMetaData MACRO_CONCAT(_typeMetaDataInstance_, Counter) = \
- _makeTypeMetaDataInstance<T>(_typeName<T>(#T)); \
- } \
- template<> \
- EXPORT_IF_NOT_GCC const detail::TypeMetaData* TypeMeta::_typeMetaDataInstance<T>() noexcept { \
- return &MACRO_CONCAT(detail::_typeMetaDataInstance_, Counter); \
+#define _CAFFE_KNOWN_TYPE_DEFINE_TYPEMETADATA_INSTANCE(T, Counter) \
+ namespace detail { \
+ const TypeMetaData MACRO_CONCAT(_typeMetaDataInstance_, Counter) = \
+ _makeTypeMetaDataInstance<T>(_typeName<T>(#T)); \
+ } \
+ template <> \
+ EXPORT_IF_NOT_GCC const detail::TypeMetaData* \
+ TypeMeta::_typeMetaDataInstance<T>() noexcept { \
+ return &MACRO_CONCAT(detail::_typeMetaDataInstance_, Counter); \
}
#define CAFFE_KNOWN_TYPE(T) \
template <> \
* for your own types to allocate dynamic ids for them.
*/
#ifdef _MSC_VER
-#define CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(PreallocatedId, T) \
- template <> \
- inline C10_EXPORT TypeIdentifier TypeIdentifier::Get<T>() { \
- return TypeIdentifier(PreallocatedId); \
- } \
- namespace detail { \
- C10_API extern const TypeMetaData \
- MACRO_CONCAT(_typeMetaDataInstance_preallocated_, PreallocatedId); \
+#define CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(PreallocatedId, T) \
+ template <> \
+ inline C10_EXPORT TypeIdentifier TypeIdentifier::Get<T>() { \
+ return TypeIdentifier(PreallocatedId); \
+ } \
+ namespace detail { \
+ C10_API extern const TypeMetaData MACRO_CONCAT( \
+ _typeMetaDataInstance_preallocated_, \
+ PreallocatedId); \
}
-#define CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(PreallocatedId, T) \
- namespace detail { \
- C10_EXPORT const TypeMetaData \
- MACRO_CONCAT(_typeMetaDataInstance_preallocated_, PreallocatedId) \
- = _makeTypeMetaDataInstance<T>(_typeName<T>(#T)); \
- } \
- template<> \
- C10_EXPORT const detail::TypeMetaData* TypeMeta::_typeMetaDataInstance<T>() noexcept { \
- return &MACRO_CONCAT(detail::_typeMetaDataInstance_preallocated_, PreallocatedId); \
+#define CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(PreallocatedId, T) \
+ namespace detail { \
+ C10_EXPORT const TypeMetaData MACRO_CONCAT( \
+ _typeMetaDataInstance_preallocated_, \
+ PreallocatedId) = _makeTypeMetaDataInstance<T>(_typeName<T>(#T)); \
+ } \
+ template <> \
+ C10_EXPORT const detail::TypeMetaData* \
+ TypeMeta::_typeMetaDataInstance<T>() noexcept { \
+ return &MACRO_CONCAT( \
+ detail::_typeMetaDataInstance_preallocated_, PreallocatedId); \
}
#else // _MSC_VER
-#define CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(PreallocatedId, T) \
- template <> \
- inline C10_EXPORT TypeIdentifier TypeIdentifier::Get<T>() { \
- return TypeIdentifier(PreallocatedId); \
- } \
- namespace detail { \
- C10_EXPORT extern const TypeMetaData \
- MACRO_CONCAT(_typeMetaDataInstance_preallocated_, PreallocatedId); \
- } \
- template<> \
- inline const detail::TypeMetaData* TypeMeta::_typeMetaDataInstance<T>() noexcept { \
- return &MACRO_CONCAT(detail::_typeMetaDataInstance_preallocated_, PreallocatedId); \
+#define CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(PreallocatedId, T) \
+ template <> \
+ inline C10_EXPORT TypeIdentifier TypeIdentifier::Get<T>() { \
+ return TypeIdentifier(PreallocatedId); \
+ } \
+ namespace detail { \
+ C10_EXPORT extern const TypeMetaData MACRO_CONCAT( \
+ _typeMetaDataInstance_preallocated_, \
+ PreallocatedId); \
+ } \
+ template <> \
+ inline const detail::TypeMetaData* \
+ TypeMeta::_typeMetaDataInstance<T>() noexcept { \
+ return &MACRO_CONCAT( \
+ detail::_typeMetaDataInstance_preallocated_, PreallocatedId); \
}
-#define CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(PreallocatedId, T) \
- namespace detail { \
- const TypeMetaData MACRO_CONCAT(_typeMetaDataInstance_preallocated_, PreallocatedId) \
- = _makeTypeMetaDataInstance<T>(_typeName<T>(#T)); \
+#define CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(PreallocatedId, T) \
+ namespace detail { \
+ const TypeMetaData MACRO_CONCAT( \
+ _typeMetaDataInstance_preallocated_, \
+ PreallocatedId) = _makeTypeMetaDataInstance<T>(_typeName<T>(#T)); \
}
#endif
-#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/tensor_dtypes.h>
#include <torch/csrc/Dtype.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/autograd/generated/VariableType.h>
+#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/tensor_types.h>
-namespace torch { namespace utils {
+namespace torch {
+namespace utils {
-static std::pair<std::string, std::string> getDtypeNames(at::ScalarType scalarType) {
- switch(scalarType) {
+static std::pair<std::string, std::string> getDtypeNames(
+ at::ScalarType scalarType) {
+ switch (scalarType) {
case at::ScalarType::Byte:
// no "byte" because byte is signed in numpy and we overload
// byte to mean bool often
void initializeDtypes() {
auto torch_module = THPObjectPtr(PyImport_ImportModule("torch"));
- if (!torch_module) throw python_error();
+ if (!torch_module)
+ throw python_error();
-#define DEFINE_SCALAR_TYPE(_1,n,_2) at::ScalarType::n,
+#define DEFINE_SCALAR_TYPE(_1, n, _2) at::ScalarType::n,
at::ScalarType all_scalar_types[] = {
- AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_SCALAR_TYPE)
- };
+ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_SCALAR_TYPE)};
- for (at::ScalarType scalarType: all_scalar_types) {
+ for (at::ScalarType scalarType : all_scalar_types) {
std::string primary_name, legacy_name;
std::tie(primary_name, legacy_name) = getDtypeNames(scalarType);
- std::string name = std::string(PyModule_GetName(torch_module.get())) + '.' + primary_name;
- PyObject *dtype = THPDtype_New(scalarType, name);
+ std::string name =
+ std::string(PyModule_GetName(torch_module.get())) + '.' + primary_name;
+ PyObject* dtype = THPDtype_New(scalarType, name);
torch::registerDtypeObject((THPDtype*)dtype, scalarType);
Py_INCREF(dtype);
- if (PyModule_AddObject(torch_module.get(), primary_name.c_str(), dtype) != 0) {
+ if (PyModule_AddObject(torch_module.get(), primary_name.c_str(), dtype) !=
+ 0) {
throw python_error();
}
if (legacy_name != "") {
Py_INCREF(dtype);
- if (PyModule_AddObject(torch_module.get(), legacy_name.c_str(), dtype) != 0) {
+ if (PyModule_AddObject(torch_module.get(), legacy_name.c_str(), dtype) !=
+ 0) {
throw python_error();
}
}
}
}
-}} // namespace torch::utils
+} // namespace utils
+} // namespace torch