--- /dev/null
+#include <ATen/core/typeid.h>
+#include <ATen/core/DefaultDtype.h>
+
+namespace at {
+static auto default_dtype = caffe2::TypeMeta::Make<float>();
+
+void set_default_dtype(caffe2::TypeMeta dtype) {
+ default_dtype = std::move(dtype);
+}
+
+const caffe2::TypeMeta& get_default_dtype() {
+ return default_dtype;
+}
+} // namespace at
--- /dev/null
+#pragma once
+
+#include <c10/macros/Macros.h>
+
+namespace caffe2 {
+class TypeMeta;
+} // namespace caffe2
+
+namespace at {
+CAFFE2_API void set_default_dtype(caffe2::TypeMeta dtype);
+CAFFE2_API const caffe2::TypeMeta& get_default_dtype();
+} // namespace at
+++ /dev/null
-#pragma once
-
-#include <c10/core/Backend.h>
-#include <c10/Device.h>
-#include <c10/core/Layout.h>
-#include <c10/core/ScalarType.h>
-
-namespace at {
-
-struct TensorOptions;
-
-/// Like TensorOptions, but all fields are guaranteed to be filled.
-struct DefaultTensorOptions {
- DefaultTensorOptions() = default;
-
- caffe2::TypeMeta dtype() const noexcept { return dtype_; }
- Device device() const noexcept { return device_; }
- Layout layout() const noexcept { return layout_; }
- bool requires_grad() const noexcept { return requires_grad_; }
- bool is_variable() const noexcept { return is_variable_; }
-
- // Defined in TensorOptions.h
- inline DefaultTensorOptions& merge(const TensorOptions& options);
-
- private:
- caffe2::TypeMeta dtype_ = caffe2::TypeMeta::Make<float>(); // 64-bit
- Device device_ = at::kCPU; // 32-bit
- Layout layout_ = at::kStrided; // 8-bit
- bool requires_grad_ = false; // 8-bit
- bool is_variable_ = false; // 8-bit
-};
-
-inline const DefaultTensorOptions& getDefaultTensorOptions() {
- static const auto options = DefaultTensorOptions();
- return options;
-}
-} // namespace at
#pragma once
+#include <ATen/core/DefaultDtype.h>
#include <c10/core/Backend.h>
-#include <ATen/core/DefaultTensorOptions.h>
-#include <c10/Device.h>
#include <c10/core/Layout.h>
#include <c10/core/ScalarType.h>
#include <c10/core/ScalarTypeUtils.h>
+#include <c10/Device.h>
#include <c10/util/Optional.h>
#include <c10/util/C++17.h>
/// Returns the device of the `TensorOptions`.
Device device() const noexcept {
- return has_device_ ? device_ : getDefaultTensorOptions().device();
+ return has_device_ ? device_ : Device(kCPU);
}
/// Returns whether the device is specified.
/// Returns the dtype of the `TensorOptions`.
caffe2::TypeMeta dtype() const noexcept {
- return has_dtype_ ? dtype_ : getDefaultTensorOptions().dtype();
+ return has_dtype_ ? dtype_ : get_default_dtype();
}
/// Returns whether the dtype is specified.
/// Returns the layout of the `TensorOptions`.
Layout layout() const noexcept {
- return has_layout_ ? layout_ : getDefaultTensorOptions().layout();
+ return has_layout_ ? layout_ : kStrided;
}
/// Returns whether the layout is specified.
/// Returns the `requires_grad` property of the `TensorOptions`.
bool requires_grad() const noexcept {
- return has_requires_grad_ ? requires_grad_ : getDefaultTensorOptions().requires_grad();
+ return has_requires_grad_ ? requires_grad_ : false;
}
/// Returns whether the `requires_grad` is specified.
/// Returns the `is_variable` property of the `TensorOptions`.
bool is_variable() const noexcept {
- return has_is_variable_ ? is_variable_ : getDefaultTensorOptions().is_variable();
+ return has_is_variable_ ? is_variable_ : false;
}
/// Returns whether the `is_variable` is specified.
std::ostream& stream,
const TensorOptions& options);
-
-DefaultTensorOptions& DefaultTensorOptions::merge(const TensorOptions& options) {
- if (options.dtype_opt().has_value()) {
- dtype_ = options.dtype();
- }
- if (options.device_opt().has_value()) {
- device_ = options.device();
- }
- if (options.layout_opt().has_value()) {
- layout_ = options.layout();
- }
- if (options.requires_grad_opt().has_value()) {
- requires_grad_ = options.requires_grad();
- }
- if (options.is_variable_opt().has_value()) {
- is_variable_ = options.is_variable();
- }
- return *this;
-}
-
template <typename T>
inline TensorOptions dtype() {
return dtype(caffe2::TypeMeta::Make<T>());
using namespace at;
// A macro so we don't lose location information when an assertion fails.
-#define REQUIRE_OPTIONS(device_, index_, type_, layout_) \
+#define REQUIRE_OPTIONS(device_, index_, type_, layout_) \
ASSERT_EQ(options.device().type(), Device((device_), (index_)).type()); \
- ASSERT_TRUE( \
- options.device().index() == Device((device_), (index_)).index()); \
+ ASSERT_TRUE( \
+ options.device().index() == Device((device_), (index_)).index()); \
ASSERT_EQ(options.dtype(), (type_)); \
ASSERT_TRUE(options.layout() == (layout_))
-#define REQUIRE_TENSOR_OPTIONS(device_, index_, type_, layout_) \
+#define REQUIRE_TENSOR_OPTIONS(device_, index_, type_, layout_) \
ASSERT_EQ(tensor.device().type(), Device((device_), (index_)).type()); \
ASSERT_EQ(tensor.device().index(), Device((device_), (index_)).index()); \
ASSERT_EQ(tensor.type().scalarType(), (type_)); \
ASSERT_ANY_THROW({ Device d(badness); });
}
}
+
+struct DefaultDtypeTest : ::testing::Test {
+ DefaultDtypeTest() {
+ set_default_dtype(caffe2::TypeMeta::Make<float>());
+ }
+ ~DefaultDtypeTest() {
+ set_default_dtype(caffe2::TypeMeta::Make<float>());
+ }
+};
+
+TEST_F(DefaultDtypeTest, CanSetAndGetDefaultDtype) {
+ ASSERT_EQ(at::get_default_dtype(), kFloat);
+ set_default_dtype(caffe2::TypeMeta::Make<int>());
+ ASSERT_EQ(at::get_default_dtype(), kInt);
+}
+
+TEST_F(DefaultDtypeTest, NewTensorOptionsHasCorrectDefault) {
+ set_default_dtype(caffe2::TypeMeta::Make<int>());
+ ASSERT_EQ(at::get_default_dtype(), kInt);
+ TensorOptions options;
+ ASSERT_EQ(options.dtype(), kInt);
+}
+
+TEST_F(DefaultDtypeTest, NewTensorsHaveCorrectDefaultDtype) {
+ set_default_dtype(caffe2::TypeMeta::Make<int>());
+ {
+ auto tensor = torch::ones(5);
+ ASSERT_EQ(tensor.dtype(), kInt);
+ }
+ set_default_dtype(caffe2::TypeMeta::Make<double>());
+ {
+ auto tensor = torch::ones(5);
+ ASSERT_EQ(tensor.dtype(), kDouble);
+ }
+ {
+ auto tensor = torch::ones(5, kFloat);
+ ASSERT_EQ(tensor.dtype(), kFloat);
+ }
+}
self.assertTrue(net.training)
net.eval()
- input = torch.randn(2, 3, dtype=torch.float32)
+ input = torch.randn(2, 3)
output = net.forward(input)
self.assertEqual(output, net.forward(input))
self.assertEqual(list(output.shape), [2, 5])
self.assertEqual(len(matches), 1, str(matches))
self.assertEqual(matches[0], "no_python_abi_suffix_test.so", str(matches))
+ def test_set_default_type_also_changes_aten_default_type(self):
+ module = torch.utils.cpp_extension.load_inline(
+ name="test_set_default_type",
+ cpp_sources="torch::Tensor get() { return torch::empty({}); }",
+ functions="get",
+ verbose=True)
+
+ initial_default = torch.get_default_dtype()
+ try:
+ self.assertEqual(module.get().dtype, initial_default)
+ torch.set_default_dtype(torch.float64)
+ self.assertEqual(module.get().dtype, torch.float64)
+ torch.set_default_dtype(torch.float32)
+ self.assertEqual(module.get().dtype, torch.float32)
+ torch.set_default_dtype(torch.float16)
+ self.assertEqual(module.get().dtype, torch.float16)
+ finally:
+ torch.set_default_dtype(initial_default)
+
if __name__ == '__main__':
common.run_tests()
// get the storage first, so if it doesn't exist we don't change the default tensor type
THPObjectPtr storage = get_storage_obj(type);
default_tensor_type = const_cast<Type*>(&type);
+ at::set_default_dtype(default_tensor_type->typeMeta());
auto torch_module = THPObjectPtr(PyImport_ImportModule("torch"));
if (!torch_module) throw python_error();