From 3ae721d3501c56cdf6819fe40a54dc6b77900fdd Mon Sep 17 00:00:00 2001
From: Peter Goldsborough
Date: Wed, 5 Dec 2018 10:18:20 -0800
Subject: [PATCH] Set and get default dtype (#13748)

Summary:
Replaces the `DefaultTensorOptions` with just a global default dtype that you can set and get like in Python.

Also, calls `set_default_dtype` in the implementation of `torch.set_default_dtype`. Right now these two default values are separate but will always be the same. Should we just bind `set_default_dtype` into Python? I think that might be good to do in a separate PR though.

ezyang gchanan

Also CC colesbury who wanted to do this for ATen for a while? What do you think about it?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13748

Differential Revision: D13340207

Pulled By: goldsborough

fbshipit-source-id: 2689b09eb137fabb3a92d1ad1635782bee9398e8
---
 aten/src/ATen/core/DefaultDtype.cpp       | 14 +++++++++
 aten/src/ATen/core/DefaultDtype.h         | 12 ++++++++
 aten/src/ATen/core/DefaultTensorOptions.h | 37 ------------------------
 aten/src/ATen/core/TensorOptions.h        | 34 +++++-----------------
 test/cpp/api/tensor_options.cpp           | 47 ++++++++++++++++++++++++++++---
 test/test_cpp_extensions.py               | 21 +++++++++++++-
 torch/csrc/tensor/python_tensor.cpp       |  1 +
 7 files changed, 97 insertions(+), 69 deletions(-)
 create mode 100644 aten/src/ATen/core/DefaultDtype.cpp
 create mode 100644 aten/src/ATen/core/DefaultDtype.h
 delete mode 100644 aten/src/ATen/core/DefaultTensorOptions.h

diff --git a/aten/src/ATen/core/DefaultDtype.cpp b/aten/src/ATen/core/DefaultDtype.cpp
new file mode 100644
index 0000000..6e49ecb
--- /dev/null
+++ b/aten/src/ATen/core/DefaultDtype.cpp
@@ -0,0 +1,14 @@
+#include
+#include
+
+namespace at {
+static auto default_dtype = caffe2::TypeMeta::Make<float>();
+
+void set_default_dtype(caffe2::TypeMeta dtype) {
+  default_dtype = std::move(dtype);
+}
+
+const caffe2::TypeMeta& get_default_dtype() {
+  return default_dtype;
+}
+} // namespace at
diff --git a/aten/src/ATen/core/DefaultDtype.h b/aten/src/ATen/core/DefaultDtype.h
new file mode 100644
index 0000000..6c18c84
--- /dev/null
+++ b/aten/src/ATen/core/DefaultDtype.h
@@ -0,0 +1,12 @@
+#pragma once
+
+#include
+
+namespace caffe2 {
+class TypeMeta;
+} // namespace caffe2
+
+namespace at {
+CAFFE2_API void set_default_dtype(caffe2::TypeMeta dtype);
+CAFFE2_API const caffe2::TypeMeta& get_default_dtype();
+} // namespace at
diff --git a/aten/src/ATen/core/DefaultTensorOptions.h b/aten/src/ATen/core/DefaultTensorOptions.h
deleted file mode 100644
index b4714ba..0000000
--- a/aten/src/ATen/core/DefaultTensorOptions.h
+++ /dev/null
@@ -1,37 +0,0 @@
-#pragma once
-
-#include
-#include
-#include
-#include
-
-namespace at {
-
-struct TensorOptions;
-
-/// Like TensorOptions, but all fields are guaranteed to be filled.
-struct DefaultTensorOptions {
-  DefaultTensorOptions() = default;
-
-  caffe2::TypeMeta dtype() const noexcept { return dtype_; }
-  Device device() const noexcept { return device_; }
-  Layout layout() const noexcept { return layout_; }
-  bool requires_grad() const noexcept { return requires_grad_; }
-  bool is_variable() const noexcept { return is_variable_; }
-
-  // Defined in TensorOptions.h
-  inline DefaultTensorOptions& merge(const TensorOptions& options);
-
- private:
-  caffe2::TypeMeta dtype_ = caffe2::TypeMeta::Make<float>(); // 64-bit
-  Device device_ = at::kCPU; // 32-bit
-  Layout layout_ = at::kStrided; // 8-bit
-  bool requires_grad_ = false; // 8-bit
-  bool is_variable_ = false; // 8-bit
-};
-
-inline const DefaultTensorOptions& getDefaultTensorOptions() {
-  static const auto options = DefaultTensorOptions();
-  return options;
-}
-} // namespace at
diff --git a/aten/src/ATen/core/TensorOptions.h b/aten/src/ATen/core/TensorOptions.h
index b5f1fd9..682d8ed 100644
--- a/aten/src/ATen/core/TensorOptions.h
+++ b/aten/src/ATen/core/TensorOptions.h
@@ -1,11 +1,11 @@
 #pragma once
 
+#include
 #include
-#include
-#include
 #include
 #include
 #include
+#include
 
 #include
 #include
@@ -240,7 +240,7 @@ struct CAFFE2_API TensorOptions {
 
   /// Returns the device of the `TensorOptions`.
   Device device() const noexcept {
-    return has_device_ ? device_ : getDefaultTensorOptions().device();
+    return has_device_ ? device_ : Device(kCPU);
   }
 
   /// Returns whether the device is specified.
@@ -261,7 +261,7 @@
 
   /// Returns the dtype of the `TensorOptions`.
   caffe2::TypeMeta dtype() const noexcept {
-    return has_dtype_ ? dtype_ : getDefaultTensorOptions().dtype();
+    return has_dtype_ ? dtype_ : get_default_dtype();
   }
 
   /// Returns whether the dtype is specified.
@@ -277,7 +277,7 @@
 
   /// Returns the layout of the `TensorOptions`.
   Layout layout() const noexcept {
-    return has_layout_ ? layout_ : getDefaultTensorOptions().layout();
+    return has_layout_ ? layout_ : kStrided;
   }
 
   /// Returns whether the layout is specified.
@@ -293,7 +293,7 @@
 
   /// Returns the `requires_grad` property of the `TensorOptions`.
   bool requires_grad() const noexcept {
-    return has_requires_grad_ ? requires_grad_ : getDefaultTensorOptions().requires_grad();
+    return has_requires_grad_ ? requires_grad_ : false;
   }
 
   /// Returns whether the `requires_grad` is specified.
@@ -310,7 +310,7 @@
 
   /// Returns the `is_variable` property of the `TensorOptions`.
   bool is_variable() const noexcept {
-    return has_is_variable_ ? is_variable_ : getDefaultTensorOptions().is_variable();
+    return has_is_variable_ ? is_variable_ : false;
  }
 
   /// Returns whether the `is_variable` is specified.
@@ -477,26 +477,6 @@ CAFFE2_API std::ostream& operator<<(
     std::ostream& stream,
     const TensorOptions& options);
 
-
-DefaultTensorOptions& DefaultTensorOptions::merge(const TensorOptions& options) {
-  if (options.dtype_opt().has_value()) {
-    dtype_ = options.dtype();
-  }
-  if (options.device_opt().has_value()) {
-    device_ = options.device();
-  }
-  if (options.layout_opt().has_value()) {
-    layout_ = options.layout();
-  }
-  if (options.requires_grad_opt().has_value()) {
-    requires_grad_ = options.requires_grad();
-  }
-  if (options.is_variable_opt().has_value()) {
-    is_variable_ = options.is_variable();
-  }
-  return *this;
-}
-
 template <typename T>
 inline TensorOptions dtype() {
   return dtype(caffe2::TypeMeta::Make<T>());
diff --git a/test/cpp/api/tensor_options.cpp b/test/cpp/api/tensor_options.cpp
index 047a620..bc8d571 100644
--- a/test/cpp/api/tensor_options.cpp
+++ b/test/cpp/api/tensor_options.cpp
@@ -12,14 +12,14 @@ using namespace at;
 
 // A macro so we don't lose location information when an assertion fails.
-#define REQUIRE_OPTIONS(device_, index_, type_, layout_)                    \
+#define REQUIRE_OPTIONS(device_, index_, type_, layout_)                   \
   ASSERT_EQ(options.device().type(), Device((device_), (index_)).type());  \
-  ASSERT_TRUE(                                                              \
-      options.device().index() == Device((device_), (index_)).index());    \
+  ASSERT_TRUE(                                                             \
+      options.device().index() == Device((device_), (index_)).index());   \
   ASSERT_EQ(options.dtype(), (type_));                                     \
   ASSERT_TRUE(options.layout() == (layout_))
 
-#define REQUIRE_TENSOR_OPTIONS(device_, index_, type_, layout_)             \
+#define REQUIRE_TENSOR_OPTIONS(device_, index_, type_, layout_)            \
   ASSERT_EQ(tensor.device().type(), Device((device_), (index_)).type());   \
   ASSERT_EQ(tensor.device().index(), Device((device_), (index_)).index()); \
   ASSERT_EQ(tensor.type().scalarType(), (type_));                          \
   ASSERT_TRUE(tensor.layout() == (layout_))
@@ -128,3 +128,42 @@ TEST(DeviceTest, ParsesCorrectlyFromString) {
     ASSERT_ANY_THROW({ Device d(badness); });
   }
 }
+
+struct DefaultDtypeTest : ::testing::Test {
+  DefaultDtypeTest() {
+    set_default_dtype(caffe2::TypeMeta::Make<float>());
+  }
+  ~DefaultDtypeTest() {
+    set_default_dtype(caffe2::TypeMeta::Make<float>());
+  }
+};
+
+TEST_F(DefaultDtypeTest, CanSetAndGetDefaultDtype) {
+  ASSERT_EQ(at::get_default_dtype(), kFloat);
+  set_default_dtype(caffe2::TypeMeta::Make<int>());
+  ASSERT_EQ(at::get_default_dtype(), kInt);
+}
+
+TEST_F(DefaultDtypeTest, NewTensorOptionsHasCorrectDefault) {
+  set_default_dtype(caffe2::TypeMeta::Make<int>());
+  ASSERT_EQ(at::get_default_dtype(), kInt);
+  TensorOptions options;
+  ASSERT_EQ(options.dtype(), kInt);
+}
+
+TEST_F(DefaultDtypeTest, NewTensorsHaveCorrectDefaultDtype) {
+  set_default_dtype(caffe2::TypeMeta::Make<int>());
+  {
+    auto tensor = torch::ones(5);
+    ASSERT_EQ(tensor.dtype(), kInt);
+  }
+  set_default_dtype(caffe2::TypeMeta::Make<double>());
+  {
+    auto tensor = torch::ones(5);
+    ASSERT_EQ(tensor.dtype(), kDouble);
+  }
+  {
+    auto tensor = torch::ones(5, kFloat);
+    ASSERT_EQ(tensor.dtype(), kFloat);
+  }
+}
diff --git a/test/test_cpp_extensions.py b/test/test_cpp_extensions.py
index a032b4d..4c42356 100755
--- a/test/test_cpp_extensions.py
+++ b/test/test_cpp_extensions.py
@@ -365,7 +365,7 @@ class TestCppExtension(common.TestCase):
         self.assertTrue(net.training)
         net.eval()
 
-        input = torch.randn(2, 3, dtype=torch.float32)
+        input = torch.randn(2, 3)
         output = net.forward(input)
         self.assertEqual(output, net.forward(input))
         self.assertEqual(list(output.shape), [2, 5])
@@ -416,6 +416,25 @@ class TestCppExtension(common.TestCase):
         self.assertEqual(len(matches), 1, str(matches))
         self.assertEqual(matches[0],
"no_python_abi_suffix_test.so", str(matches)) + def test_set_default_type_also_changes_aten_default_type(self): + module = torch.utils.cpp_extension.load_inline( + name="test_set_default_type", + cpp_sources="torch::Tensor get() { return torch::empty({}); }", + functions="get", + verbose=True) + + initial_default = torch.get_default_dtype() + try: + self.assertEqual(module.get().dtype, initial_default) + torch.set_default_dtype(torch.float64) + self.assertEqual(module.get().dtype, torch.float64) + torch.set_default_dtype(torch.float32) + self.assertEqual(module.get().dtype, torch.float32) + torch.set_default_dtype(torch.float16) + self.assertEqual(module.get().dtype, torch.float16) + finally: + torch.set_default_dtype(initial_default) + if __name__ == '__main__': common.run_tests() diff --git a/torch/csrc/tensor/python_tensor.cpp b/torch/csrc/tensor/python_tensor.cpp index fc6ad74..00936c3 100644 --- a/torch/csrc/tensor/python_tensor.cpp +++ b/torch/csrc/tensor/python_tensor.cpp @@ -372,6 +372,7 @@ void set_default_tensor_type(const at::Type& type) { // get the storage first, so if it doesn't exist we don't change the default tensor type THPObjectPtr storage = get_storage_obj(type); default_tensor_type = const_cast(&type); + at::set_default_dtype(default_tensor_type->typeMeta()); auto torch_module = THPObjectPtr(PyImport_ImportModule("torch")); if (!torch_module) throw python_error(); -- 2.7.4