Replace more usages of Type with DeprecatedTypeProperties (#19093)
authorRoy Li <royboy@fb.com>
Thu, 11 Apr 2019 23:55:39 +0000 (16:55 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 12 Apr 2019 00:02:05 +0000 (17:02 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19093
ghimport-source-id: a82e3dce912a173b42a6a7e35eb1302d9f334e03

Differential Revision: D14865520

Pulled By: li-roy

fbshipit-source-id: b1a8bf32f87920ce8d82f990d670477bc79d0ca7

22 files changed:
aten/src/ATen/Context.h
aten/src/ATen/Dispatch.h
aten/src/ATen/core/DeprecatedTypeProperties.cpp [new file with mode: 0644]
aten/src/ATen/core/DeprecatedTypeProperties.h
aten/src/ATen/core/DeprecatedTypePropertiesRegistry.cpp
aten/src/ATen/core/DeprecatedTypePropertiesRegistry.h
aten/src/ATen/core/Tensor.h
aten/src/ATen/core/TensorMethods.h
aten/src/ATen/native/TypeProperties.cpp
aten/src/ATen/templates/Tensor.h
aten/src/ATen/templates/TensorMethods.h
aten/src/ATen/test/apply_utils_test.cpp
aten/src/ATen/test/basic.cpp
aten/src/ATen/test/broadcast_test.cpp
aten/src/ATen/test/native_test.cpp
aten/src/ATen/test/scalar_tensor_test.cpp
aten/src/ATen/test/undefined_tensor_test.cpp
aten/src/ATen/test/wrapdim_test.cpp
test/cpp/jit/test_argument_spec.h
torch/csrc/Generator.cpp
torch/csrc/autograd/VariableTypeManual.cpp
torch/csrc/autograd/functions/tensor.h

index 76d1a90..9033494 100644 (file)
@@ -182,16 +182,19 @@ CAFFE2_API TypeExtendedInterface& getType(const Tensor&);
 
 CAFFE2_API Allocator* getCPUAllocator();
 
-static inline TypeExtendedInterface& CPU(ScalarType s) {
-  return getNonVariableType(Backend::CPU, s);
+static inline DeprecatedTypeProperties& CPU(ScalarType s) {
+  return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
+      Backend::CPU, s, /*is_variable*/false);
 }
 
-static inline TypeExtendedInterface& CUDA(ScalarType s) {
-  return getNonVariableType(Backend::CUDA, s);
+static inline DeprecatedTypeProperties& CUDA(ScalarType s) {
+  return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
+      Backend::CUDA, s, /*is_variable*/false);
 }
 
-static inline TypeExtendedInterface& HIP(ScalarType s) {
-  return getNonVariableType(Backend::HIP, s);
+static inline DeprecatedTypeProperties& HIP(ScalarType s) {
+  return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
+      Backend::HIP, s, /*is_variable*/false);
 }
 
 CAFFE2_API LegacyTHDispatcher& getLegacyTHDispatcher(TensorOptions options);
index 2e83529..7ec10ef 100644 (file)
@@ -3,6 +3,7 @@
 #include <ATen/Type.h>
 #include <c10/util/Half.h>
 #include <c10/util/Exception.h>
+#include <ATen/core/DeprecatedTypeProperties.h>
 
 #define AT_PRIVATE_CASE_TYPE(enum_type, type, ...) \
   case enum_type: {                                \
diff --git a/aten/src/ATen/core/DeprecatedTypeProperties.cpp b/aten/src/ATen/core/DeprecatedTypeProperties.cpp
new file mode 100644 (file)
index 0000000..b634a50
--- /dev/null
@@ -0,0 +1,24 @@
+#include <ATen/core/DeprecatedTypeProperties.h>
+
+#include <ATen/core/LegacyTypeDispatch.h>
+#include <ATen/core/Type.h>
+
+namespace at {
+
+Tensor DeprecatedTypeProperties::unsafeTensorFromTH(void * th_pointer, bool retain) const {
+  return getDispatchType().unsafeTensorFromTH(th_pointer, retain);
+}
+
+Tensor DeprecatedTypeProperties::copy(const Tensor & src, bool non_blocking, c10::optional<Device> to_device) const {
+  return getDispatchType().copy(src, non_blocking, to_device);
+}
+
+std::unique_ptr<Generator> DeprecatedTypeProperties::generator() const {
+  return getDispatchType().generator();
+}
+
+Type & DeprecatedTypeProperties::getDispatchType() const {
+  return globalLegacyTypeDispatch().getType(backend_, scalar_type_, is_variable_);
+}
+
+} // namespace at
index 88f53f6..116af9d 100644 (file)
@@ -3,24 +3,33 @@
 #include <c10/core/Backend.h>
 #include <c10/core/ScalarType.h>
 #include <c10/core/Layout.h>
-
+#include <c10/core/TensorOptions.h>
+#include <ATen/core/DeprecatedTypePropertiesRegistry.h>
+#include <ATen/core/Generator.h>
 
 
 namespace at {
 
+class Tensor;
+struct Type;
+
 // This class specifies a Backend and a ScalarType. Currently, it primarily
 // serves as a replacement return value for Tensor::type(). Previously,
 // Tensor::type() returned Type&, but we are changing Type to not be
 // dtype-specific.
-class DeprecatedTypeProperties {
+class CAFFE2_API DeprecatedTypeProperties {
  public:
-  DeprecatedTypeProperties(Backend backend, ScalarType scalar_type)
-    : backend_(backend), scalar_type_(scalar_type) {}
+  DeprecatedTypeProperties(Backend backend, ScalarType scalar_type, bool is_variable)
+    : backend_(backend), scalar_type_(scalar_type), is_variable_(is_variable) {}
 
   Backend backend() const {
     return backend_;
   }
 
+  Layout layout() const {
+    return layout_from_backend(backend_);
+  }
+
   bool is_sparse() const {
     return layout_from_backend(backend()) == kSparse;
   }
@@ -41,8 +50,8 @@ class DeprecatedTypeProperties {
     return scalarTypeToTypeMeta(scalar_type_);
   }
 
-  bool is_defined() const {
-    return backend_ != Backend::Undefined && scalar_type_ != ScalarType::Undefined;
+  bool is_variable() const {
+    return is_variable_;
   }
 
   bool operator==(const DeprecatedTypeProperties& other) const {
@@ -59,9 +68,62 @@ class DeprecatedTypeProperties {
     return ss.str();
   }
 
+  DeprecatedTypeProperties & toBackend(Backend b) const {
+    return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
+        b, scalar_type_, is_variable_);
+  }
+
+  DeprecatedTypeProperties & toScalarType(ScalarType s) const {
+    return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
+        backend_, s, is_variable_);
+  }
+
+  DeprecatedTypeProperties & cpu() const {
+    return toBackend(Backend::CPU);
+  }
+
+  DeprecatedTypeProperties & cuda() const {
+    return toBackend(Backend::CUDA);
+  }
+
+  DeprecatedTypeProperties & hip() const {
+    return toBackend(Backend::HIP);
+  }
+
+  /// Constructs the `TensorOptions` from a type and a `device_index`.
+  TensorOptions options(int16_t device_index = -1) const {
+    return TensorOptions().dtype(typeMeta())
+                          .device(device_type(), device_index)
+                          .layout(layout())
+                          .is_variable(is_variable());
+  }
+
+  /// Constructs the `TensorOptions` from a type and a Device.  Asserts that
+  /// the device type matches the device type of the type.
+  TensorOptions options(c10::optional<Device> device_opt) const {
+    if (!device_opt.has_value()) {
+      return options(-1);
+    } else {
+      Device device = device_opt.value();
+      AT_ASSERT(device.type() == device_type());
+      return options(device.index());
+    }
+  }
+
+  operator TensorOptions() const {
+    return options();
+  }
+
+  Tensor unsafeTensorFromTH(void * th_pointer, bool retain) const;
+  Tensor copy(const Tensor & src, bool non_blocking=false, c10::optional<Device> to_device={}) const;
+  std::unique_ptr<Generator> generator() const;
+
  private:
+  Type & getDispatchType() const;
+
   Backend backend_;
   ScalarType scalar_type_;
+  bool is_variable_;
 };
 
 }  // namespace at
index 154f04d..e9188bf 100644 (file)
@@ -1,7 +1,31 @@
 #include <ATen/core/DeprecatedTypePropertiesRegistry.h>
 
+#include <ATen/core/DeprecatedTypeProperties.h>
+
 namespace at {
 
+void DeprecatedTypePropertiesDeleter::operator()(DeprecatedTypeProperties * ptr) {
+  delete ptr;
+}
+
+DeprecatedTypePropertiesRegistry::DeprecatedTypePropertiesRegistry() {
+  for (int b = 0; b < static_cast<int>(Backend::NumOptions); ++b) {
+    for (int s = 0; s < static_cast<int>(ScalarType::NumOptions); ++s) {
+      for (int v = 0; v < 2; ++ v) {
+        registry[b][s][v] = c10::guts::make_unique<DeprecatedTypeProperties>(
+                static_cast<Backend>(b),
+                static_cast<ScalarType>(s),
+                v);
+      }
+    }
+  }
+}
+
+DeprecatedTypeProperties& DeprecatedTypePropertiesRegistry::getDeprecatedTypeProperties(
+    Backend p, ScalarType s, bool is_variable) const {
+  return *registry[static_cast<int>(p)][static_cast<int>(s)][is_variable];
+}
+
 // TODO: This could be bad juju if someone calls globalContext() in the
 // destructor of an object with static lifetime.
 DeprecatedTypePropertiesRegistry & globalDeprecatedTypePropertiesRegistry() {
index 0ab57bf..543db04 100644 (file)
@@ -5,40 +5,26 @@
 
 #include <c10/core/Backend.h>
 #include <c10/core/ScalarType.h>
-#include <ATen/core/DeprecatedTypeProperties.h>
 
 namespace at {
 
+class DeprecatedTypeProperties;
+
 struct CAFFE2_API DeprecatedTypePropertiesDeleter {
-  void operator()(DeprecatedTypeProperties * ptr) {
-      delete ptr;
-  }
+  void operator()(DeprecatedTypeProperties * ptr);
 };
 
 class CAFFE2_API DeprecatedTypePropertiesRegistry {
  public:
-  using DeprecatedTypePropertiesUniquePtr =
-      std::unique_ptr<DeprecatedTypeProperties, DeprecatedTypePropertiesDeleter>;
-
-  DeprecatedTypePropertiesRegistry() {
-    for (int b = 0; b < static_cast<int>(Backend::NumOptions); ++b) {
-      for (int s = 0; s < static_cast<int>(ScalarType::NumOptions); ++s) {
-        registry[b][s] = DeprecatedTypePropertiesUniquePtr{
-            new DeprecatedTypeProperties(static_cast<Backend>(b), static_cast<ScalarType>(s)),
-            DeprecatedTypePropertiesDeleter()
-        };
-      }
-    }
-  }
-
-  DeprecatedTypeProperties& getDeprecatedTypeProperties(Backend p, ScalarType s) {
-    return *registry[static_cast<int>(p)][static_cast<int>(s)];
-  }
+  DeprecatedTypePropertiesRegistry();
+
+  DeprecatedTypeProperties& getDeprecatedTypeProperties(Backend p, ScalarType s, bool is_variable) const;
 
 private:
-  DeprecatedTypePropertiesUniquePtr registry
+  std::unique_ptr<DeprecatedTypeProperties> registry
     [static_cast<int>(Backend::NumOptions)]
-    [static_cast<int>(ScalarType::NumOptions)];
+    [static_cast<int>(ScalarType::NumOptions)]
+    [2];  // is_variable
 };
 
 CAFFE2_API DeprecatedTypePropertiesRegistry& globalDeprecatedTypePropertiesRegistry();
index 6f58a4a..3ba8a57 100644 (file)
@@ -21,6 +21,7 @@ struct TensorOptions;
 namespace at {
 struct Generator;
 struct Type;
+class DeprecatedTypeProperties;
 class Tensor;
 } // namespace at
 
@@ -199,7 +200,9 @@ class CAFFE2_API Tensor {
 
   DeprecatedTypeProperties & type() const {
     return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
-        tensorTypeIdToBackend(type_id()), scalar_type());
+        tensorTypeIdToBackend(type_id()),
+        scalar_type(),
+        is_variable() && !at::NonVariableTypeMode::is_enabled());
   }
   Type & dispatch_type() const {
     return legacyTensorType(*impl_);
@@ -219,8 +222,8 @@ class CAFFE2_API Tensor {
   bool is_alias_of(const at::Tensor& other) const{
     return impl_->storage().is_alias_of(other.storage());
   }
-  Tensor toType(const Type & t, bool non_blocking=false) const;
   Tensor & copy_(const Tensor & src, bool non_blocking=false);
+  Tensor toType(const DeprecatedTypeProperties & t, bool non_blocking=false) const;
   Tensor toType(ScalarType t) const;
   Tensor toBackend(Backend b) const;
 
index 26f3807..367839d 100644 (file)
@@ -6,25 +6,29 @@
 #include <ATen/core/SparseTensorRef.h>
 #include <ATen/core/Type.h>
 #include <c10/core/TensorOptions.h>
+#include <ATen/core/DeprecatedTypeProperties.h>
 
 namespace at {
 
-inline Tensor Tensor::toType(const Type & t, bool non_blocking) const {
-  if(dispatch_type() == t)
+inline Tensor Tensor::toType(const DeprecatedTypeProperties & t, bool non_blocking) const {
+  if(type() == t)
     return *this;
-  return t.copy(*this, non_blocking);
+  return to(
+      at::device(t.device_type()).layout(t.layout()).dtype(t.scalarType()),
+      non_blocking,
+      /*copy*/ true);
 }
 
 inline Tensor Tensor::cpu() const {
-  return toType(dispatch_type().cpu());
+  return toType(type().cpu());
 }
 
 inline Tensor Tensor::cuda() const {
-  return toType(dispatch_type().cuda());
+  return toType(type().cuda());
 }
 
 inline Tensor Tensor::hip() const {
-  return toType(dispatch_type().hip());
+  return toType(type().hip());
 }
 
 inline Tensor & Tensor::copy_(const Tensor & src, bool non_blocking) {
@@ -32,11 +36,11 @@ inline Tensor & Tensor::copy_(const Tensor & src, bool non_blocking) {
 }
 
 inline Tensor Tensor::toType(ScalarType t) const {
-  return toType(dispatch_type().toScalarType(t));
+  return toType(type().toScalarType(t));
 }
 
 inline Tensor Tensor::toBackend(Backend b) const {
-  return toType(dispatch_type().toBackend(b));
+  return toType(type().toBackend(b));
 }
 
 inline TensorOptions Tensor::options() const {
index c2cae17..ae1910f 100644 (file)
@@ -35,7 +35,7 @@ bool is_sparse(const Tensor& self) {
 }
 
 Tensor type_as(const Tensor& self, const Tensor& other) {
-  return self.toType(other.dispatch_type());
+  return self.toType(other.type());
 }
 
 }} // namespace at::native
index b1e917a..ab9917e 100644 (file)
@@ -21,6 +21,7 @@ struct TensorOptions;
 namespace at {
 struct Generator;
 struct Type;
+class DeprecatedTypeProperties;
 class Tensor;
 } // namespace at
 
@@ -199,7 +200,9 @@ class CAFFE2_API Tensor {
 
   DeprecatedTypeProperties & type() const {
     return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
-        tensorTypeIdToBackend(type_id()), scalar_type());
+        tensorTypeIdToBackend(type_id()),
+        scalar_type(),
+        is_variable() && !at::NonVariableTypeMode::is_enabled());
   }
   Type & dispatch_type() const {
     return legacyTensorType(*impl_);
@@ -219,8 +222,8 @@ class CAFFE2_API Tensor {
   bool is_alias_of(const at::Tensor& other) const{
     return impl_->storage().is_alias_of(other.storage());
   }
-  Tensor toType(const Type & t, bool non_blocking=false) const;
   Tensor & copy_(const Tensor & src, bool non_blocking=false);
+  Tensor toType(const DeprecatedTypeProperties & t, bool non_blocking=false) const;
   Tensor toType(ScalarType t) const;
   Tensor toBackend(Backend b) const;
 
index 5928907..18b5e53 100644 (file)
@@ -6,25 +6,29 @@
 #include <ATen/core/SparseTensorRef.h>
 #include <ATen/core/Type.h>
 #include <c10/core/TensorOptions.h>
+#include <ATen/core/DeprecatedTypeProperties.h>
 
 namespace at {
 
-inline Tensor Tensor::toType(const Type & t, bool non_blocking) const {
-  if(dispatch_type() == t)
+inline Tensor Tensor::toType(const DeprecatedTypeProperties & t, bool non_blocking) const {
+  if(type() == t)
     return *this;
-  return t.copy(*this, non_blocking);
+  return to(
+      at::device(t.device_type()).layout(t.layout()).dtype(t.scalarType()),
+      non_blocking,
+      /*copy*/ true);
 }
 
 inline Tensor Tensor::cpu() const {
-  return toType(dispatch_type().cpu());
+  return toType(type().cpu());
 }
 
 inline Tensor Tensor::cuda() const {
-  return toType(dispatch_type().cuda());
+  return toType(type().cuda());
 }
 
 inline Tensor Tensor::hip() const {
-  return toType(dispatch_type().hip());
+  return toType(type().hip());
 }
 
 inline Tensor & Tensor::copy_(const Tensor & src, bool non_blocking) {
@@ -32,11 +36,11 @@ inline Tensor & Tensor::copy_(const Tensor & src, bool non_blocking) {
 }
 
 inline Tensor Tensor::toType(ScalarType t) const {
-  return toType(dispatch_type().toScalarType(t));
+  return toType(type().toScalarType(t));
 }
 
 inline Tensor Tensor::toBackend(Backend b) const {
-  return toType(dispatch_type().toBackend(b));
+  return toType(type().toBackend(b));
 }
 
 inline TensorOptions Tensor::options() const {
index cc97c03..2f5e1b6 100644 (file)
@@ -23,7 +23,7 @@ void fill_tensor(int64_t scalar, Tensor& t_) {
 // write the same type as we read (using a0, ..., aX-1) and we once write to
 // double (using a4 as a target). We also exercise on a zero_dim and empty
 // tensor.
-void test(Type& type, IntArrayRef shape, int64_t a = 0, int64_t b = 1) {
+void test(DeprecatedTypeProperties& type, IntArrayRef shape, int64_t a = 0, int64_t b = 1) {
   auto zero_dim = at::empty({}, type);
   zero_dim.fill_(2);
   zero_dim.exp_();
index ebd569a..7f2e5ce 100644 (file)
@@ -21,7 +21,7 @@ extern "C" void THFloatTensor_fill(THFloatTensor *, float v);
 
 using namespace at;
 
-void TestResize(Type& type) {
+void TestResize(DeprecatedTypeProperties& type) {
   auto a = at::empty({0}, type.options());
   a.resize_({3, 4});
   ASSERT_EQ_RESOLVED(a.numel(), 12);
@@ -29,7 +29,7 @@ void TestResize(Type& type) {
   ASSERT_EQ_RESOLVED(a.numel(), 35);
 }
 
-void TestOnesAndDot(Type& type) {
+void TestOnesAndDot(DeprecatedTypeProperties& type) {
   Tensor b0 = ones({1, 1}, type);
   ASSERT_EQ_RESOLVED((b0 + b0).sum().item<double>(), 2);
 
@@ -42,7 +42,7 @@ void TestOnesAndDot(Type& type) {
   ASSERT_EQ_RESOLVED(b.view(-1).dot(b.view(-1)).item<double>(), 12);
 }
 
-void TestSort(Type& type) {
+void TestSort(DeprecatedTypeProperties& type) {
   Tensor b = rand({3, 4}, type);
 
   auto z = b.sort(1);
@@ -52,7 +52,7 @@ void TestSort(Type& type) {
   ASSERT_TRUE(isLT);
 }
 
-void TestRandperm(Type& type) {
+void TestRandperm(DeprecatedTypeProperties& type) {
   if (type.backend() != Backend::CUDA) {
     Tensor b = randperm(15, type);
     Tensor rv, ri;
@@ -67,7 +67,7 @@ void SendContext() {
   ss << "context: " << std::hex << (int64_t)&globalContext() << std::endl;
 }
 
-void TestAdd(Type& type) {
+void TestAdd(DeprecatedTypeProperties& type) {
   Tensor a = rand({3, 4}, type);
   Tensor b = rand({3, 4}, type);
   Tensor c = add(a, add(a, b));
@@ -76,7 +76,7 @@ void TestAdd(Type& type) {
   ASSERT_TRUE(add(c, d).allclose(a + a + b + d));
 }
 
-void TestLoadsOfAdds(Type& type) {
+void TestLoadsOfAdds(DeprecatedTypeProperties& type) {
   auto begin = std::chrono::high_resolution_clock::now();
   Tensor d = ones({3, 4}, type);
   Tensor r = zeros({3, 4}, type);
@@ -93,7 +93,7 @@ void TestLoadsOfAdds(Type& type) {
   ASSERT_EQ_RESOLVED(norm(100000 * d).item<double>(), norm(r).item<double>());
 }
 
-void TestLoadOfAddsWithCopy(Type& type) {
+void TestLoadOfAddsWithCopy(DeprecatedTypeProperties& type) {
   auto begin = std::chrono::high_resolution_clock::now();
   Tensor d = ones({3, 4}, type);
   Tensor r = zeros({3, 4}, type);
@@ -110,28 +110,28 @@ void TestLoadOfAddsWithCopy(Type& type) {
   ASSERT_EQ_RESOLVED(norm(100000 * d).item<double>(), norm(r).item<double>());
 }
 
-void TestIsContiguous(Type& type) {
+void TestIsContiguous(DeprecatedTypeProperties& type) {
   Tensor a = rand({3, 4}, type);
   ASSERT_TRUE(a.is_contiguous());
   a = a.transpose(0, 1);
   ASSERT_FALSE(a.is_contiguous());
 }
 
-void TestPermute(Type& type) {
+void TestPermute(DeprecatedTypeProperties& type) {
   Tensor a = rand({3, 4, 5}, type);
   Tensor b = a.permute({1, 2, 0});
   ASSERT_TRUE(b.sizes().equals({4, 5, 3}));
   ASSERT_TRUE(b.strides().equals({5, 1, 20}));
 }
 
-void TestMm(Type& type) {
+void TestMm(DeprecatedTypeProperties& type) {
   Tensor a = rand({3, 4}, type);
   Tensor b = rand({4}, type);
   Tensor c = mv(a, b);
   ASSERT_TRUE(c.equal(addmv(zeros({3}, type), a, b, 0, 1)));
 }
 
-void TestSqueeze(Type& type) {
+void TestSqueeze(DeprecatedTypeProperties& type) {
   Tensor a = rand({2, 1}, type);
   Tensor b = squeeze(a);
   ASSERT_EQ_RESOLVED(b.dim(), 1);
@@ -141,14 +141,14 @@ void TestSqueeze(Type& type) {
   ASSERT_TRUE(a[0].equal(b));
 }
 
-void TestCopy(Type& type) {
+void TestCopy(DeprecatedTypeProperties& type) {
   Tensor a = zeros({4, 3}, type);
   Tensor e = rand({4, 3}, type);
   a.copy_(e);
   ASSERT_TRUE(a.equal(e));
 }
 
-void TestCopyBroadcasting(Type& type) {
+void TestCopyBroadcasting(DeprecatedTypeProperties& type) {
   Tensor a = zeros({4, 3}, type);
   Tensor e = rand({3}, type);
   a.copy_(e);
@@ -156,7 +156,7 @@ void TestCopyBroadcasting(Type& type) {
     ASSERT_TRUE(a[i].equal(e));
   }
 }
-void TestAbsValue(Type& type) {
+void TestAbsValue(DeprecatedTypeProperties& type) {
   Tensor r = at::abs(at::scalar_tensor(-3, type.options()));
   ASSERT_EQ_RESOLVED(r.item<int32_t>(), 3);
 }
@@ -173,12 +173,12 @@ std::cout << (a == 10.) << " -- should be 1" << std::endl;
 #endif
 */
 
-void TestAddingAValueWithScalar(Type& type) {
+void TestAddingAValueWithScalar(DeprecatedTypeProperties& type) {
   Tensor a = rand({4, 3}, type);
   ASSERT_TRUE((ones({4, 3}, type) + a).equal(add(a, 1)));
 }
 
-void TestSelect(Type& type) {
+void TestSelect(DeprecatedTypeProperties& type) {
   Tensor a = rand({3, 7}, type);
   auto a_13 = select(a, 1, 3);
   auto a_13_02 = select(select(a, 1, 3), 0, 2);
@@ -186,7 +186,7 @@ void TestSelect(Type& type) {
   ASSERT_TRUE(a[2][3].equal(a_13_02));
 }
 
-void TestZeroDim(Type& type) {
+void TestZeroDim(DeprecatedTypeProperties& type) {
   Tensor a = at::scalar_tensor(4, type.options()); // rand(type, {1});
 
   Tensor b = rand({3, 4}, type);
@@ -263,7 +263,7 @@ void TestIndexingByZerodimTensor() {
   // Throw StartsWith("Can only index with tensors that are scalars (zero-dim)")
   ASSERT_ANY_THROW(tensor[ones({2, 3, 4}, kInt)].equal(one));
 }
-void TestIndexingMixedDevice(Type& type) {
+void TestIndexingMixedDevice(DeprecatedTypeProperties& type) {
   Tensor tensor = randn({20, 20}, type);
   Tensor index = arange(10, kLong).cpu();
   Tensor result = tensor.index({index});
@@ -276,14 +276,14 @@ void TestDispatch() {
   ASSERT_TRUE(result.allclose(mse_loss(relu(tensor), other)));
 }
 
-void TestNegativeDim(Type& type) {
+void TestNegativeDim(DeprecatedTypeProperties& type) {
   ASSERT_ANY_THROW(empty({5, -5, 5}, type.options()));
   ASSERT_ANY_THROW(empty({5, -5, -5}, type.options()));
   Tensor tensor = empty({5, 5}, type.options());
   ASSERT_ANY_THROW(tensor.reshape({-5, -5}));
 }
 
-void test(Type& type) {
+void test(DeprecatedTypeProperties& type) {
   TestResize(type);
   TestOnesAndDot(type);
 
index 6463115..42a6544 100644 (file)
@@ -6,13 +6,13 @@
 using namespace at;
 
 // can't expand empty tensor
-void TestEmptyTensor(Type& T) {
+void TestEmptyTensor(DeprecatedTypeProperties& T) {
   auto empty = randn({0}, T);
   ASSERT_ANY_THROW(empty.expand({3}));
 }
 
 // out-place function with 2 args
-void TestOut2Basic(Type& T) {
+void TestOut2Basic(DeprecatedTypeProperties& T) {
   auto a = randn({3, 1}, T);
   auto b = randn({5}, T);
   std::vector<int64_t> expanded_sizes = {3, 5};
@@ -21,7 +21,7 @@ void TestOut2Basic(Type& T) {
 }
 
 // with scalar
-void TestOut2WithScalar(Type& T) {
+void TestOut2WithScalar(DeprecatedTypeProperties& T) {
   auto aScalar = ones({1}, T);
   aScalar.unsafeGetTensorImpl()->maybe_zero_dim(true);
   auto b = randn({3, 5}, T);
@@ -30,21 +30,21 @@ void TestOut2WithScalar(Type& T) {
 }
 
 // old fallback behavior yields error
-void TestOut2OldFallback(Type& T) {
+void TestOut2OldFallback(DeprecatedTypeProperties& T) {
   auto a = randn({3, 5}, T);
   auto b = randn({5, 3}, T);
   ASSERT_ANY_THROW(a + b);
 }
 
 // with mismatched sizes
-void TestOut2MismatchedSizes(Type& T) {
+void TestOut2MismatchedSizes(DeprecatedTypeProperties& T) {
   auto a = randn({3, 5}, T);
   auto b = randn({7, 5}, T);
   ASSERT_ANY_THROW(a + b);
 }
 
 // out-place function with 3 args
-void TestOut3Basic(Type& T) {
+void TestOut3Basic(DeprecatedTypeProperties& T) {
   auto a = randn({3, 1, 1}, T);
   auto b = randn({1, 2, 1}, T);
   auto c = randn({1, 1, 5}, T);
@@ -55,7 +55,7 @@ void TestOut3Basic(Type& T) {
 }
 
 // with scalar
-void TestOut3WithScalar(Type& T) {
+void TestOut3WithScalar(DeprecatedTypeProperties& T) {
   auto aTensorScalar = ones({1}, T);
   aTensorScalar.unsafeGetTensorImpl()->maybe_zero_dim(true);
   auto b = randn({3, 2, 1}, T);
@@ -67,7 +67,7 @@ void TestOut3WithScalar(Type& T) {
 }
 
 // old fallback behavior yields error
-void TestOut3OldFallback(Type& T) {
+void TestOut3OldFallback(DeprecatedTypeProperties& T) {
   auto a = randn({3, 2, 5}, T);
   auto b = randn({2, 3, 5}, T);
   auto c = randn({5, 3, 2}, T);
@@ -75,7 +75,7 @@ void TestOut3OldFallback(Type& T) {
 }
 
 // with mismatched sizes
-void TestOut3MismatchedSizes(Type& T) {
+void TestOut3MismatchedSizes(DeprecatedTypeProperties& T) {
   auto a = randn({3, 2, 5}, T);
   auto b = randn({2, 3, 5}, T);
   auto c = randn({5, 5, 5}, T);
@@ -83,14 +83,14 @@ void TestOut3MismatchedSizes(Type& T) {
 }
 
 // in-place function with 2 args
-void TestIn2Basic(Type& T) {
+void TestIn2Basic(DeprecatedTypeProperties& T) {
   auto a = randn({3, 5}, T);
   auto b = randn({3, 1}, T);
   ASSERT_TRUE((a + b).equal(a + b.expand({3, 5})));
 }
 
 // with scalar
-void TestIn2WithScalar(Type& T) {
+void TestIn2WithScalar(DeprecatedTypeProperties& T) {
   auto a = randn({3, 5}, T);
   auto bScalar = ones({1}, T);
   bScalar.unsafeGetTensorImpl()->maybe_zero_dim(true);
@@ -98,14 +98,14 @@ void TestIn2WithScalar(Type& T) {
 }
 
 // error: would have to expand inplace arg
-void TestIn2ExpandError(Type& T) {
+void TestIn2ExpandError(DeprecatedTypeProperties& T) {
   auto a = randn({1, 5}, T);
   auto b = randn({3, 1}, T);
   ASSERT_ANY_THROW(a.add_(b));
 }
 
 // in-place function with 3 args
-void TestIn3Basic(Type& T) {
+void TestIn3Basic(DeprecatedTypeProperties& T) {
   auto a = randn({3, 5, 2}, T);
   auto b = randn({3, 1, 2}, T);
   auto c = randn({1, 5, 1}, T);
@@ -115,7 +115,7 @@ void TestIn3Basic(Type& T) {
 }
 
 // with scalar
-void TestIn3WithScalar(Type& T) {
+void TestIn3WithScalar(DeprecatedTypeProperties& T) {
   auto a = randn({3, 5, 2}, T);
   auto b = randn({3, 1, 2}, T);
   auto c = randn({1, 5, 1}, T);
@@ -128,7 +128,7 @@ void TestIn3WithScalar(Type& T) {
 }
 
 // error: would have to expand inplace arg
-void TestIn3ExpandError(Type& T) {
+void TestIn3ExpandError(DeprecatedTypeProperties& T) {
   auto a = randn({1, 3, 5}, T);
   auto b = randn({4, 1, 1}, T);
   auto c = randn({1, 3, 1}, T);
@@ -136,7 +136,7 @@ void TestIn3ExpandError(Type& T) {
 }
 
 // explicit dim specification
-void TestExplicitDimBasic(Type& T) {
+void TestExplicitDimBasic(DeprecatedTypeProperties& T) {
   auto a = randn({1}, T);
   auto b = randn({5, 3}, T);
   auto c = randn({3, 7}, T);
@@ -144,7 +144,7 @@ void TestExplicitDimBasic(Type& T) {
 }
 
 // with scalar
-void TestExplicitDimWithScalar(Type& T) {
+void TestExplicitDimWithScalar(DeprecatedTypeProperties& T) {
   auto a = randn({1}, T);
   auto b = randn({5, 3}, T);
   auto c = randn({3, 7}, T);
@@ -154,7 +154,7 @@ void TestExplicitDimWithScalar(Type& T) {
 }
 
 // with mismatched sizes
-void TestExplicitDimWithMismatchedSizes(Type& T) {
+void TestExplicitDimWithMismatchedSizes(DeprecatedTypeProperties& T) {
   auto b = randn({5, 3}, T);
   auto c = randn({3, 7}, T);
   auto a = randn({3, 3}, T);
@@ -163,7 +163,7 @@ void TestExplicitDimWithMismatchedSizes(Type& T) {
 
 TEST(BroadcastTest, Broadcast) {
   manual_seed(123);
-  Type& T = CPU(kFloat);
+  DeprecatedTypeProperties& T = CPU(kFloat);
 
   TestEmptyTensor(T);
 
index 0d0ca1b..cd017cc 100644 (file)
@@ -21,32 +21,28 @@ void requireEqualTensorList(TensorList t1, TensorList t2) {
   }
 }
 
-// split: test method, type, namespace give same result
-void TestSplit(Type& T, Tensor& t) {
+// split: test method, namespace give same result
+void TestSplit(DeprecatedTypeProperties& T, Tensor& t) {
   auto splitMethod = t.split(1, 0);
-  auto splitType = T.split(t, 1, 0);
   auto splitNs = at::split(t, 1, 0);
-  requireEqualTensorList(splitMethod, splitType);
   requireEqualTensorList(splitMethod, splitNs);
 
   // test rebuilding with cat
   ASSERT_EQUAL(at::cat(splitMethod, 0), t);
 }
 
-// chunk: test method, type, namespace give same result
-void TestChunk(Type& T, Tensor& t) {
+// chunk: test method, namespace give same result
+void TestChunk(DeprecatedTypeProperties& T, Tensor& t) {
   // test method, type, namespace give same result
   auto chunkMethod = t.chunk(3, 0);
-  auto chunkType = T.chunk(t, 3, 0);
   auto chunkNs = at::chunk(t, 3, 0);
-  requireEqualTensorList(chunkMethod, chunkType);
   requireEqualTensorList(chunkMethod, chunkNs);
 
   // test rebuilding with cat
   ASSERT_EQUAL(at::cat(chunkMethod, 0), t);
 }
 
-void TestStack(Type& T, Tensor& t) {
+void TestStack(DeprecatedTypeProperties& T, Tensor& t) {
   auto x = rand({2, 3, 4});
   auto y = rand({2, 3, 4});
   auto z = rand({2, 3, 4});
@@ -69,7 +65,7 @@ void TestStack(Type& T, Tensor& t) {
 }
 
 // size / stride
-void TestSize(Type& T, Tensor& t) {
+void TestSize(DeprecatedTypeProperties& T, Tensor& t) {
   auto scalar = randn({}, T);
   // Throw StartsWith("dimension specified as 0 but tensor has no dimensions")
   ASSERT_ANY_THROW(scalar.size(0));
@@ -87,7 +83,7 @@ void TestSize(Type& T, Tensor& t) {
   ASSERT_EQ(empty.stride(-1), 1);
 }
 
-void TestMatmul(Type& T, Tensor& t, Type& AccT) {
+void TestMatmul(DeprecatedTypeProperties& T, Tensor& t, DeprecatedTypeProperties& AccT) {
   auto scalar = randn({}, T);
   auto d1 = randn({3}, T);
   auto d2 = randn({2, 3}, T);
@@ -160,7 +156,7 @@ void TestMatmul(Type& T, Tensor& t, Type& AccT) {
   ASSERT_ANY_THROW(d5.matmul(d5wrong));
 }
 
-void TestStandardGammaGrad(Type& T, Tensor& t) {
+void TestStandardGammaGrad(DeprecatedTypeProperties& T, Tensor& t) {
   // check empty
   auto empty = ones({0}, T);
   ASSERT_EQUAL(empty, at::_standard_gamma_grad(empty, empty));
@@ -179,7 +175,7 @@ void TestStandardGammaGrad(Type& T, Tensor& t) {
   ASSERT_ANY_THROW(at::_standard_gamma_grad(t1, t2));
 }
 
-void TestWhere(Type& T, Tensor& t) {
+void TestWhere(DeprecatedTypeProperties& T, Tensor& t) {
   // empty
   auto empty = ones({0}, T);
   auto& bT = T.toScalarType(ScalarType::Byte);
@@ -198,7 +194,7 @@ void TestWhere(Type& T, Tensor& t) {
       at::where(cond_1d, x_1d, y_1d));
 }
 
-void test(Type& T, Type& AccT) {
+void test(DeprecatedTypeProperties& T, DeprecatedTypeProperties& AccT) {
   auto t = randn({3, 3}, T);
   TestSplit(T, t);
   TestChunk(T, t);
index 6872344..97e598d 100644 (file)
@@ -42,7 +42,7 @@ bool should_expand(const IntArrayRef &from_size, const IntArrayRef &to_size) {
   return true;
 }
 
-void test(Type &T) {
+void test(DeprecatedTypeProperties &T) {
   std::vector<std::vector<int64_t>> sizes = {{}, {0}, {1}, {1, 1}, {2}};
 
   // single-tensor/size tests
index 5a3c926..9c9c42c 100644 (file)
@@ -27,9 +27,9 @@ TEST(TestUndefined, UndefinedTest) {
   ASSERT_ANY_THROW(und.add(5));
   ASSERT_ANY_THROW(und.mm(und));
 
-  und.toType(und.dispatch_type());
-  ASSERT_ANY_THROW(und.toType(ft.dispatch_type()));
-  ASSERT_ANY_THROW(ft.toType(und.dispatch_type()));
+  und.toType(und.type());
+  ASSERT_ANY_THROW(und.toType(ft.type()));
+  ASSERT_ANY_THROW(ft.toType(und.type()));
   und.toType(ScalarType::Undefined);
   ASSERT_ANY_THROW(und.toType(ScalarType::Float));
   ASSERT_ANY_THROW(ft.toType(ScalarType::Undefined));
index b7088cf..9c3b18a 100644 (file)
@@ -3,13 +3,13 @@
 #include <ATen/ATen.h>
 
 using namespace at;
-void TestSimpleCase(Type& T) {
+void TestSimpleCase(DeprecatedTypeProperties& T) {
   auto a = randn({2, 3, 4, 5}, T);
   ASSERT_TRUE(a.prod(-4).equal(a.prod(0)));
   ASSERT_TRUE(a.prod(3).equal(a.prod(-1)));
 }
 
-void TestExpressionSpecification(Type& T) {
+void TestExpressionSpecification(DeprecatedTypeProperties& T) {
   auto a = randn({2, 3, 4, 5}, T);
   ASSERT_TRUE(a.unsqueeze(-5).equal(a.unsqueeze(0)));
   ASSERT_TRUE(a.unsqueeze(4).equal(a.unsqueeze(-1)));
@@ -20,12 +20,12 @@ void TestExpressionSpecification(Type& T) {
   ASSERT_TRUE(b.unsqueeze(0).equal(b.unsqueeze(-1)));
 }
 
-void TestEmptyTensor(Type& T) {
+void TestEmptyTensor(DeprecatedTypeProperties& T) {
   auto a = randn(0, T);
   ASSERT_TRUE(a.prod(0).equal(at::ones({}, T)));
 }
 
-void TestScalarVs1Dim1Size(Type& T) {
+void TestScalarVs1Dim1Size(DeprecatedTypeProperties& T) {
   auto a = randn(1, T);
   ASSERT_TRUE(a.prod(0).equal(a.prod(-1)));
   a.unsafeGetTensorImpl()->maybe_zero_dim(true);
@@ -35,7 +35,7 @@ void TestScalarVs1Dim1Size(Type& T) {
 
 TEST(TestWrapdim, TestWrapdim) {
   manual_seed(123);
-  Type& T = CPU(kFloat);
+  DeprecatedTypeProperties& T = CPU(kFloat);
 
   TestSimpleCase(T);
   TestEmptyTensor(T);
index 315f32e..0bd73c9 100644 (file)
@@ -24,7 +24,7 @@ bool isEqual(const CompleteArgumentInfo& ti, const autograd::Variable& v) {
       isEqual(ti.strides(), v.strides());
 }
 
-autograd::Variable var(at::Type& t, at::IntArrayRef sizes, bool requires_grad) {
+autograd::Variable var(at::DeprecatedTypeProperties& t, at::IntArrayRef sizes, bool requires_grad) {
   return autograd::make_variable(at::rand(sizes, t.options()), requires_grad);
 }
 autograd::Variable undef() {
index 0c166ac..984105d 100644 (file)
@@ -82,8 +82,7 @@ static PyObject * THPGenerator_setState(THPGenerator *self, PyObject *_new_state
     throw TypeError("expected a torch.ByteTensor, but got %s", Py_TYPE(_new_state)->tp_name);
   }
   auto& tensor = ((THPVariable*)_new_state)->cdata.data();
-  auto& tensor_type = at::globalContext().getNonVariableType(tensor.type().backend(), tensor.scalar_type());
-  if (tensor_type != CPU(kByte)) {
+  if (tensor.layout() != kStrided || tensor.device().type() != kCPU || tensor.scalar_type() != kByte) {
     auto type_name = torch::utils::type_to_string(tensor.dispatch_type());
     throw TypeError("expected a torch.ByteTensor, but got %s", type_name.c_str());
   }
index c2eed5c..be5f576 100644 (file)
@@ -260,7 +260,7 @@ Tensor & VariableType::s_copy_(Tensor & self, const Tensor & src, bool non_block
   if (requires_grad) {
     grad_fn = std::make_shared<CopyBackwards>();
     grad_fn->set_next_edges(collect_next_edges(self, src));
-    grad_fn->src_type = &src.dispatch_type();
+    grad_fn->src_type = &src.type();
     grad_fn->src_device = src.device();
   }
   {
index ecbf711..e4e4b7e 100644 (file)
@@ -4,7 +4,7 @@
 #include <torch/csrc/autograd/variable.h>
 
 #include <ATen/TensorGeometry.h>
-#include <ATen/Type.h>
+#include <ATen/core/DeprecatedTypeProperties.h>
 #include <c10/util/Optional.h>
 
 #include <cstdint>
@@ -15,7 +15,7 @@ namespace torch { namespace autograd {
 struct CopyBackwards : public Function {
   variable_list apply(variable_list&& grads) override;
 
-  at::Type *src_type = nullptr; // initialized for safety.
+  at::DeprecatedTypeProperties *src_type = nullptr; // initialized for safety.
   at::Device src_device = at::kCPU;
 };