Set and get default dtype (#13748)
authorPeter Goldsborough <psag@fb.com>
Wed, 5 Dec 2018 18:18:20 +0000 (10:18 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 5 Dec 2018 18:28:41 +0000 (10:28 -0800)
Summary:
Replaces the `DefaultTensorOptions` with just a global default dtype that you can set and get like in Python.

Also, calls `set_default_dtype` in the implementation of `torch.set_default_dtype`. Right now these two default values are separate but will always be the same. Should we just bind `set_default_dtype`  into Python? I think that might be good to do in a separate PR though.

ezyang gchanan

Also CC colesbury who wanted to do this for ATen for a while? What do you think about it?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13748

Differential Revision: D13340207

Pulled By: goldsborough

fbshipit-source-id: 2689b09eb137fabb3a92d1ad1635782bee9398e8

aten/src/ATen/core/DefaultDtype.cpp [new file with mode: 0644]
aten/src/ATen/core/DefaultDtype.h [new file with mode: 0644]
aten/src/ATen/core/DefaultTensorOptions.h [deleted file]
aten/src/ATen/core/TensorOptions.h
test/cpp/api/tensor_options.cpp
test/test_cpp_extensions.py
torch/csrc/tensor/python_tensor.cpp

diff --git a/aten/src/ATen/core/DefaultDtype.cpp b/aten/src/ATen/core/DefaultDtype.cpp
new file mode 100644 (file)
index 0000000..6e49ecb
--- /dev/null
@@ -0,0 +1,14 @@
+#include <ATen/core/typeid.h>
+#include <ATen/core/DefaultDtype.h>
+
+namespace at {
+static auto default_dtype = caffe2::TypeMeta::Make<float>();
+
+void set_default_dtype(caffe2::TypeMeta dtype) {
+  default_dtype = std::move(dtype);
+}
+
+const caffe2::TypeMeta& get_default_dtype() {
+  return default_dtype;
+}
+} // namespace at
diff --git a/aten/src/ATen/core/DefaultDtype.h b/aten/src/ATen/core/DefaultDtype.h
new file mode 100644 (file)
index 0000000..6c18c84
--- /dev/null
@@ -0,0 +1,12 @@
+#pragma once
+
+#include <c10/macros/Macros.h>
+
+namespace caffe2 {
+class TypeMeta;
+} // namespace caffe2
+
+namespace at {
+CAFFE2_API void set_default_dtype(caffe2::TypeMeta dtype);
+CAFFE2_API const caffe2::TypeMeta& get_default_dtype();
+} // namespace at
diff --git a/aten/src/ATen/core/DefaultTensorOptions.h b/aten/src/ATen/core/DefaultTensorOptions.h
deleted file mode 100644 (file)
index b4714ba..0000000
+++ /dev/null
@@ -1,37 +0,0 @@
-#pragma once
-
-#include <c10/core/Backend.h>
-#include <c10/Device.h>
-#include <c10/core/Layout.h>
-#include <c10/core/ScalarType.h>
-
-namespace at {
-
-struct TensorOptions;
-
-/// Like TensorOptions, but all fields are guaranteed to be filled.
-struct DefaultTensorOptions {
-  DefaultTensorOptions() = default;
-
-  caffe2::TypeMeta dtype() const noexcept { return dtype_; }
-  Device device()          const noexcept { return device_; }
-  Layout layout()          const noexcept { return layout_; }
-  bool requires_grad()     const noexcept { return requires_grad_; }
-  bool is_variable()       const noexcept { return is_variable_; }
-
-  // Defined in TensorOptions.h
-  inline DefaultTensorOptions& merge(const TensorOptions& options);
-
- private:
-  caffe2::TypeMeta dtype_ = caffe2::TypeMeta::Make<float>(); // 64-bit
-  Device device_          = at::kCPU;                        // 32-bit
-  Layout layout_          = at::kStrided;                    // 8-bit
-  bool requires_grad_     = false;                           // 8-bit
-  bool is_variable_       = false;                           // 8-bit
-};
-
-inline const DefaultTensorOptions& getDefaultTensorOptions() {
-  static const auto options = DefaultTensorOptions();
-  return options;
-}
-} // namespace at
index b5f1fd9..682d8ed 100644 (file)
@@ -1,11 +1,11 @@
 #pragma once
 
+#include <ATen/core/DefaultDtype.h>
 #include <c10/core/Backend.h>
-#include <ATen/core/DefaultTensorOptions.h>
-#include <c10/Device.h>
 #include <c10/core/Layout.h>
 #include <c10/core/ScalarType.h>
 #include <c10/core/ScalarTypeUtils.h>
+#include <c10/Device.h>
 
 #include <c10/util/Optional.h>
 #include <c10/util/C++17.h>
@@ -240,7 +240,7 @@ struct CAFFE2_API TensorOptions {
 
   /// Returns the device of the `TensorOptions`.
   Device device() const noexcept {
-    return has_device_ ? device_ : getDefaultTensorOptions().device();
+    return has_device_ ? device_ : Device(kCPU);
   }
 
   /// Returns whether the device is specified.
@@ -261,7 +261,7 @@ struct CAFFE2_API TensorOptions {
 
   /// Returns the dtype of the `TensorOptions`.
   caffe2::TypeMeta dtype() const noexcept {
-    return has_dtype_ ? dtype_ : getDefaultTensorOptions().dtype();
+    return has_dtype_ ? dtype_ : get_default_dtype();
   }
 
   /// Returns whether the dtype is specified.
@@ -277,7 +277,7 @@ struct CAFFE2_API TensorOptions {
 
   /// Returns the layout of the `TensorOptions`.
   Layout layout() const noexcept {
-    return has_layout_ ? layout_ : getDefaultTensorOptions().layout();
+    return has_layout_ ? layout_ : kStrided;
   }
 
   /// Returns whether the layout is specified.
@@ -293,7 +293,7 @@ struct CAFFE2_API TensorOptions {
 
   /// Returns the `requires_grad` property of the `TensorOptions`.
   bool requires_grad() const noexcept {
-    return has_requires_grad_ ? requires_grad_ : getDefaultTensorOptions().requires_grad();
+    return has_requires_grad_ ? requires_grad_ : false;
   }
 
   /// Returns whether the `requires_grad` is specified.
@@ -310,7 +310,7 @@ struct CAFFE2_API TensorOptions {
 
   /// Returns the `is_variable` property of the `TensorOptions`.
   bool is_variable() const noexcept {
-    return has_is_variable_ ? is_variable_ : getDefaultTensorOptions().is_variable();
+    return has_is_variable_ ? is_variable_ : false;
   }
 
   /// Returns whether the `is_variable` is specified.
@@ -477,26 +477,6 @@ CAFFE2_API std::ostream& operator<<(
     std::ostream& stream,
     const TensorOptions& options);
 
-
-DefaultTensorOptions& DefaultTensorOptions::merge(const TensorOptions& options) {
-  if (options.dtype_opt().has_value()) {
-    dtype_ = options.dtype();
-  }
-  if (options.device_opt().has_value()) {
-    device_ = options.device();
-  }
-  if (options.layout_opt().has_value()) {
-    layout_ = options.layout();
-  }
-  if (options.requires_grad_opt().has_value()) {
-    requires_grad_ = options.requires_grad();
-  }
-  if (options.is_variable_opt().has_value()) {
-    is_variable_ = options.is_variable();
-  }
-  return *this;
-}
-
 template <typename T>
 inline TensorOptions dtype() {
   return dtype(caffe2::TypeMeta::Make<T>());
index 047a620..bc8d571 100644 (file)
 using namespace at;
 
 // A macro so we don't lose location information when an assertion fails.
-#define REQUIRE_OPTIONS(device_, index_, type_, layout_)                      \
+#define REQUIRE_OPTIONS(device_, index_, type_, layout_)                  \
   ASSERT_EQ(options.device().type(), Device((device_), (index_)).type()); \
-  ASSERT_TRUE(                                                                \
-      options.device().index() == Device((device_), (index_)).index());       \
+  ASSERT_TRUE(                                                            \
+      options.device().index() == Device((device_), (index_)).index());   \
   ASSERT_EQ(options.dtype(), (type_));                                    \
   ASSERT_TRUE(options.layout() == (layout_))
 
-#define REQUIRE_TENSOR_OPTIONS(device_, index_, type_, layout_)                \
+#define REQUIRE_TENSOR_OPTIONS(device_, index_, type_, layout_)            \
   ASSERT_EQ(tensor.device().type(), Device((device_), (index_)).type());   \
   ASSERT_EQ(tensor.device().index(), Device((device_), (index_)).index()); \
   ASSERT_EQ(tensor.type().scalarType(), (type_));                          \
@@ -128,3 +128,42 @@ TEST(DeviceTest, ParsesCorrectlyFromString) {
     ASSERT_ANY_THROW({ Device d(badness); });
   }
 }
+
+struct DefaultDtypeTest : ::testing::Test {
+  DefaultDtypeTest() {
+    set_default_dtype(caffe2::TypeMeta::Make<float>());
+  }
+  ~DefaultDtypeTest() {
+    set_default_dtype(caffe2::TypeMeta::Make<float>());
+  }
+};
+
+TEST_F(DefaultDtypeTest, CanSetAndGetDefaultDtype) {
+  ASSERT_EQ(at::get_default_dtype(), kFloat);
+  set_default_dtype(caffe2::TypeMeta::Make<int>());
+  ASSERT_EQ(at::get_default_dtype(), kInt);
+}
+
+TEST_F(DefaultDtypeTest, NewTensorOptionsHasCorrectDefault) {
+  set_default_dtype(caffe2::TypeMeta::Make<int>());
+  ASSERT_EQ(at::get_default_dtype(), kInt);
+  TensorOptions options;
+  ASSERT_EQ(options.dtype(), kInt);
+}
+
+TEST_F(DefaultDtypeTest, NewTensorsHaveCorrectDefaultDtype) {
+  set_default_dtype(caffe2::TypeMeta::Make<int>());
+  {
+    auto tensor = torch::ones(5);
+    ASSERT_EQ(tensor.dtype(), kInt);
+  }
+  set_default_dtype(caffe2::TypeMeta::Make<double>());
+  {
+    auto tensor = torch::ones(5);
+    ASSERT_EQ(tensor.dtype(), kDouble);
+  }
+  {
+    auto tensor = torch::ones(5, kFloat);
+    ASSERT_EQ(tensor.dtype(), kFloat);
+  }
+}
index a032b4d..4c42356 100755 (executable)
@@ -365,7 +365,7 @@ class TestCppExtension(common.TestCase):
         self.assertTrue(net.training)
         net.eval()
 
-        input = torch.randn(2, 3, dtype=torch.float32)
+        input = torch.randn(2, 3)
         output = net.forward(input)
         self.assertEqual(output, net.forward(input))
         self.assertEqual(list(output.shape), [2, 5])
@@ -416,6 +416,25 @@ class TestCppExtension(common.TestCase):
         self.assertEqual(len(matches), 1, str(matches))
         self.assertEqual(matches[0], "no_python_abi_suffix_test.so", str(matches))
 
+    def test_set_default_type_also_changes_aten_default_type(self):
+        module = torch.utils.cpp_extension.load_inline(
+            name="test_set_default_type",
+            cpp_sources="torch::Tensor get() { return torch::empty({}); }",
+            functions="get",
+            verbose=True)
+
+        initial_default = torch.get_default_dtype()
+        try:
+            self.assertEqual(module.get().dtype, initial_default)
+            torch.set_default_dtype(torch.float64)
+            self.assertEqual(module.get().dtype, torch.float64)
+            torch.set_default_dtype(torch.float32)
+            self.assertEqual(module.get().dtype, torch.float32)
+            torch.set_default_dtype(torch.float16)
+            self.assertEqual(module.get().dtype, torch.float16)
+        finally:
+            torch.set_default_dtype(initial_default)
+
 
 if __name__ == '__main__':
     common.run_tests()
index fc6ad74..00936c3 100644 (file)
@@ -372,6 +372,7 @@ void set_default_tensor_type(const at::Type& type) {
   // get the storage first, so if it doesn't exist we don't change the default tensor type
   THPObjectPtr storage = get_storage_obj(type);
   default_tensor_type = const_cast<Type*>(&type);
+  at::set_default_dtype(default_tensor_type->typeMeta());
 
   auto torch_module = THPObjectPtr(PyImport_ImportModule("torch"));
   if (!torch_module) throw python_error();