namespace detail {
+// Compile-time map from an at::ScalarType enumerator to a C++ type.
+// The primary template is only declared here; the mapping is supplied by
+// explicit specializations.
+template <at::ScalarType N>
+struct ScalarTypeToCType;
+
+template<>
+struct ScalarTypeToCType<at::ScalarType::Half> {
+ using type = at::Half;
+
+ // This is a workaround for a CUDA bug which prevents ::detail::ScalarTypeToCType<T>::type from being used directly
+ // due to an ambiguous reference which can't be resolved. For some reason it can't pick between at::detail and at::cuda::detail.
+ // For a repro example, please see: https://github.com/izdeby/playground/blob/0b0c0e9373b32830442ec26b2bc3535b21fb8d95/C%2B%2B/CudaBugRepro.cu
+ // TODO: remove `t` once the bug is fixed; until then the mapped type is named as decltype(ScalarTypeToCType<...>::t) instead of ::type.
+ static at::Half t;
+};
+
+template<>
+struct ScalarTypeToCType<at::ScalarType::Bool> {
+ using type = bool;
+
+ // This is a workaround for a CUDA bug which prevents ::detail::ScalarTypeToCType<T>::type from being used directly
+ // due to an ambiguous reference which can't be resolved. For some reason it can't pick between at::detail and at::cuda::detail.
+ // For a repro example, please see: https://github.com/izdeby/playground/blob/0b0c0e9373b32830442ec26b2bc3535b21fb8d95/C%2B%2B/CudaBugRepro.cu
+ // TODO: remove `t` once the bug is fixed; until then the mapped type is named as decltype(ScalarTypeToCType<...>::t) instead of ::type.
+ static bool t;
+};
+
inline at::ScalarType scalar_type(at::ScalarType s) {
return s;
}
} \
}()
-#define AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \
+#define AT_DISPATCH_ALL_TYPES_AND_HALF(TYPE, NAME, ...) \
[&] { \
+ detail::deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF(); \
const auto& the_type = TYPE; \
(void)the_type; \
at::ScalarType _st = ::detail::scalar_type(TYPE); \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__) \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
} \
}()
-#define AT_DISPATCH_ALL_TYPES_AND_HALF(TYPE, NAME, ...) \
+#define AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \
[&] { \
- detail::deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF(); \
const auto& the_type = TYPE; \
(void)the_type; \
at::ScalarType _st = ::detail::scalar_type(TYPE); \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__) \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
} \
} \
}()
+// Expands to an immediately-invoked [&] lambda that switches on TYPE and
+// forwards __VA_ARGS__ to AT_PRIVATE_CASE_TYPE for Byte, Char, Double, Float,
+// Int, Long and Short, plus one caller-chosen extra case SCALARTYPE.
+// The C++ type for SCALARTYPE is decltype(::detail::ScalarTypeToCType<SCALARTYPE>::t),
+// so SCALARTYPE needs a ScalarTypeToCType specialization.
+// Any other value of TYPE reaches the default branch, which calls AT_ERROR
+// with NAME and toString(TYPE).
+// NOTE(review): TYPE is expanded twice (switch operand and error message), so
+// an argument with side effects would be evaluated more than once.
+// NOTE(review): AT_PRIVATE_CASE_TYPE is defined outside this section; presumably
+// __VA_ARGS__ is the callable to run for the matched type - confirm there.
+#define AT_DISPATCH_ALL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
+ [&] { \
+ switch (TYPE) { \
+ AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE(SCALARTYPE, decltype(::detail::ScalarTypeToCType<SCALARTYPE>::t), __VA_ARGS__) \
+ default: \
+ AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
+ } \
+ }()
-template <at::ScalarType N>
-struct MyTemplate;
-
-template<>
-struct MyTemplate<at::ScalarType::Half> {
- using type = at::Half;
-};
-
-template<>
-struct MyTemplate<at::ScalarType::Bool> {
- using type = bool;
-};
-
-#define AT_DISPATCH_ALL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
- [&] { \
- switch (TYPE) { \
- AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE(SCALARTYPE, MyTemplate<SCALARTYPE>::type, __VA_ARGS__) \
- default: \
- AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
- } \
+// Expands to an immediately-invoked [&] lambda that switches on TYPE and
+// forwards __VA_ARGS__ to AT_PRIVATE_CASE_TYPE for Byte, Char, Double, Float,
+// Int, Long and Short, plus two caller-chosen extra cases SCALARTYPE1 and
+// SCALARTYPE2.
+// The C++ type for each extra case is decltype(::detail::ScalarTypeToCType<...>::t),
+// so both need a ScalarTypeToCType specialization.
+// Any other value of TYPE reaches the default branch, which calls AT_ERROR
+// with NAME and toString(TYPE).
+// NOTE(review): TYPE is expanded twice (switch operand and error message), so
+// an argument with side effects would be evaluated more than once.
+#define AT_DISPATCH_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
+ [&] { \
+ switch (TYPE) { \
+ AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE(SCALARTYPE1, decltype(::detail::ScalarTypeToCType<SCALARTYPE1>::t), __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE(SCALARTYPE2, decltype(::detail::ScalarTypeToCType<SCALARTYPE2>::t), __VA_ARGS__) \
+ default: \
+ AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
+ } \
}()
-#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
- [&] { \
- switch (TYPE) { \
- AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE(SCALARTYPE1, MyTemplate<SCALARTYPE1>::type, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE(SCALARTYPE2, MyTemplate<SCALARTYPE2>::type, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE( \
- at::ScalarType::ComplexFloat, std::complex<float>, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE( \
- at::ScalarType::ComplexDouble, std::complex<double>, __VA_ARGS__) \
- default: \
- AT_ERROR(#NAME, " not implemented for '", TYPE, "'"); \
- } \
+// Expands to an immediately-invoked [&] lambda that switches on TYPE and
+// forwards __VA_ARGS__ to AT_PRIVATE_CASE_TYPE for Byte, Char, Double, Float,
+// Int, Long and Short, two caller-chosen extra cases SCALARTYPE1 and
+// SCALARTYPE2, and ComplexFloat / ComplexDouble (std::complex<float> /
+// std::complex<double>).
+// The C++ type for each extra case is decltype(::detail::ScalarTypeToCType<...>::t),
+// so both need a ScalarTypeToCType specialization.
+// Any other value of TYPE reaches the default branch, which calls AT_ERROR
+// with NAME and toString(TYPE), the same message format as
+// AT_DISPATCH_ALL_TYPES_AND and AT_DISPATCH_ALL_TYPES_AND2.
+// NOTE(review): TYPE is expanded twice (switch operand and error message), so
+// an argument with side effects would be evaluated more than once.
+#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
+ [&] { \
+ switch (TYPE) { \
+ AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE(SCALARTYPE1, decltype(::detail::ScalarTypeToCType<SCALARTYPE1>::t), __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE(SCALARTYPE2, decltype(::detail::ScalarTypeToCType<SCALARTYPE2>::t), __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE( \
+ at::ScalarType::ComplexFloat, std::complex<float>, __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE( \
+ at::ScalarType::ComplexDouble, std::complex<double>, __VA_ARGS__) \
+ default: \
+ AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
+ } \
}()
TH_API void THTensor_(cmaxValue)(THTensor *r, THTensor *t, scalar_t value);
TH_API void THTensor_(cminValue)(THTensor *r, THTensor *t, scalar_t value);
-TH_API void THTensor_(zerosLike)(THTensor *r_, THTensor *input);
-TH_API void THTensor_(onesLike)(THTensor *r_, THTensor *input);
TH_API void THTensor_(diag)(THTensor *r_, THTensor *t, int k);
-TH_API void THTensor_(eye)(THTensor *r_, int64_t n, int64_t m);
-TH_API void THTensor_(randperm)(THTensor *r_, THGenerator *_generator, int64_t n);
TH_API void THTensor_(sort)(THTensor *rt_, THLongTensor *ri_, THTensor *t, int dimension, int descendingOrder);
TH_API void THTensor_(topk)(THTensor *rt_, THLongTensor *ri_, THTensor *t, int64_t k, int dim, int dir, int sorted);
TH_API void THTensor_(pow)(THTensor *r_, THTensor *t, scalar_t value);
TH_API void THTensor_(tpow)(THTensor *r_, scalar_t value, THTensor *t);
-TH_API scalar_t THTensor_(powOne)(scalar_t x, scalar_t y);
TH_API void THTensor_(abs)(THTensor *r_, THTensor *t);
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
TH_API void THTensor_(trunc)(THTensor *r_, THTensor *t);
TH_API void THTensor_(frac)(THTensor *r_, THTensor *t);
-TH_API void THTensor_(mean)(THTensor *r_, THTensor *t, int dimension, int keepdim);
TH_API void THTensor_(std)(THTensor *r_, THTensor *t, int dimension, int biased, int keepdim);
TH_API void THTensor_(var)(THTensor *r_, THTensor *t, int dimension, int biased, int keepdim);
TH_API void THTensor_(norm)(THTensor *r_, THTensor *t, scalar_t value, int dimension, int keepdim);
# This is a temporary test for a boolean tensors on CPU. Once the CUDA part
# will be done, these test cases will be moved down to test_tensor_factories_empty test
- def test_tensor_factories_empty_bool(self):
+ def test_tensor_factories_bool(self):
+ # Exercises the tensor factory functions with dtype=torch.bool.
expectedShape = (1, 2)
+ # empty / empty_like: only the resulting shape is asserted.
test = torch.empty(expectedShape, dtype=torch.bool)
self.assertEqual(expectedShape, test.shape)
- self.assertEqual(expectedShape, torch.empty_like(test).shape)
+
+ test2 = torch.empty_like(test, dtype=torch.bool)
+ self.assertEqual(test.shape, test2.shape)
+ # full / full_like: compared against an explicit all-True bool tensor.
test = torch.full(expectedShape, True, dtype=torch.bool)
self.assertEqual(test, torch.tensor([[True, True]], dtype=torch.bool))
+
+ test2 = torch.full_like(test, True, dtype=torch.bool)
+ self.assertEqual(test, test2)
+
+ # zeros / zeros_like: compared against an explicit all-False bool tensor.
+ test = torch.zeros(expectedShape, dtype=torch.bool)
+ self.assertEqual(test, torch.tensor([[False, False]], dtype=torch.bool))
+
+ test2 = torch.zeros_like(test, dtype=torch.bool)
+ self.assertEqual(test, test2)
+
+ # ones / ones_like: compared against an explicit all-True bool tensor.
+ test = torch.ones(expectedShape, dtype=torch.bool)
+ self.assertEqual(test, torch.tensor([[True, True]], dtype=torch.bool))
+
+ test2 = torch.ones_like(test, dtype=torch.bool)
+ self.assertEqual(test, test2)
+
+ # randint: only shape and dtype are asserted, not the generated values.
+ test = torch.randint(10, expectedShape, dtype=torch.bool)
self.assertEqual(expectedShape, test.shape)
- self.assertEqual(expectedShape, torch.full_like(test, True).shape)
+ self.assertEqual(torch.bool, test.dtype)
def test_tensor_factories_empty(self):
# ensure we can create empty tensors from each factory function