From 422b01e78889a9fd283d8fb86bb5527c5cae3cbf Mon Sep 17 00:00:00 2001 From: Roy Li Date: Thu, 11 Apr 2019 16:55:39 -0700 Subject: [PATCH] Replace more usages of Type with DeprecatedTypeProperties (#19093) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19093 ghimport-source-id: a82e3dce912a173b42a6a7e35eb1302d9f334e03 Differential Revision: D14865520 Pulled By: li-roy fbshipit-source-id: b1a8bf32f87920ce8d82f990d670477bc79d0ca7 --- aten/src/ATen/Context.h | 15 +++-- aten/src/ATen/Dispatch.h | 1 + aten/src/ATen/core/DeprecatedTypeProperties.cpp | 24 +++++++ aten/src/ATen/core/DeprecatedTypeProperties.h | 74 ++++++++++++++++++++-- .../ATen/core/DeprecatedTypePropertiesRegistry.cpp | 24 +++++++ .../ATen/core/DeprecatedTypePropertiesRegistry.h | 32 +++------- aten/src/ATen/core/Tensor.h | 7 +- aten/src/ATen/core/TensorMethods.h | 20 +++--- aten/src/ATen/native/TypeProperties.cpp | 2 +- aten/src/ATen/templates/Tensor.h | 7 +- aten/src/ATen/templates/TensorMethods.h | 20 +++--- aten/src/ATen/test/apply_utils_test.cpp | 2 +- aten/src/ATen/test/basic.cpp | 40 ++++++------ aten/src/ATen/test/broadcast_test.cpp | 38 +++++------ aten/src/ATen/test/native_test.cpp | 24 +++---- aten/src/ATen/test/scalar_tensor_test.cpp | 2 +- aten/src/ATen/test/undefined_tensor_test.cpp | 6 +- aten/src/ATen/test/wrapdim_test.cpp | 10 +-- test/cpp/jit/test_argument_spec.h | 2 +- torch/csrc/Generator.cpp | 3 +- torch/csrc/autograd/VariableTypeManual.cpp | 2 +- torch/csrc/autograd/functions/tensor.h | 4 +- 22 files changed, 234 insertions(+), 125 deletions(-) create mode 100644 aten/src/ATen/core/DeprecatedTypeProperties.cpp diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index 76d1a90..9033494 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -182,16 +182,19 @@ CAFFE2_API TypeExtendedInterface& getType(const Tensor&); CAFFE2_API Allocator* getCPUAllocator(); -static inline TypeExtendedInterface& CPU(ScalarType s) { - return getNonVariableType(Backend::CPU, s); +static inline DeprecatedTypeProperties& CPU(ScalarType s) { + return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties( + Backend::CPU, s, /*is_variable*/false); } -static inline TypeExtendedInterface& CUDA(ScalarType s) { - return getNonVariableType(Backend::CUDA, s); +static inline DeprecatedTypeProperties& CUDA(ScalarType s) { + return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties( + Backend::CUDA, s, /*is_variable*/false); } -static inline TypeExtendedInterface& HIP(ScalarType s) { - return getNonVariableType(Backend::HIP, s); +static inline DeprecatedTypeProperties& HIP(ScalarType s) { + return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties( + Backend::HIP, s, /*is_variable*/false); } CAFFE2_API LegacyTHDispatcher& getLegacyTHDispatcher(TensorOptions options); diff --git a/aten/src/ATen/Dispatch.h b/aten/src/ATen/Dispatch.h index 2e83529..7ec10ef 100644 --- a/aten/src/ATen/Dispatch.h +++ b/aten/src/ATen/Dispatch.h @@ -3,6 +3,7 @@ #include #include #include +#include #define AT_PRIVATE_CASE_TYPE(enum_type, type, ...) \ case enum_type: { \ diff --git a/aten/src/ATen/core/DeprecatedTypeProperties.cpp b/aten/src/ATen/core/DeprecatedTypeProperties.cpp new file mode 100644 index 0000000..b634a50 --- /dev/null +++ b/aten/src/ATen/core/DeprecatedTypeProperties.cpp @@ -0,0 +1,24 @@ +#include + +#include +#include + +namespace at { + +Tensor DeprecatedTypeProperties::unsafeTensorFromTH(void * th_pointer, bool retain) const { + return getDispatchType().unsafeTensorFromTH(th_pointer, retain); +} + +Tensor DeprecatedTypeProperties::copy(const Tensor & src, bool non_blocking, c10::optional to_device) const { + return getDispatchType().copy(src, non_blocking, to_device); +} + +std::unique_ptr DeprecatedTypeProperties::generator() const { + return getDispatchType().generator(); +} + +Type & DeprecatedTypeProperties::getDispatchType() const { + return globalLegacyTypeDispatch().getType(backend_, scalar_type_, is_variable_); +} + +} // namespace at diff --git a/aten/src/ATen/core/DeprecatedTypeProperties.h b/aten/src/ATen/core/DeprecatedTypeProperties.h index 88f53f6..116af9d 100644 --- a/aten/src/ATen/core/DeprecatedTypeProperties.h +++ b/aten/src/ATen/core/DeprecatedTypeProperties.h @@ -3,24 +3,33 @@ #include #include #include - +#include +#include +#include namespace at { +class Tensor; +struct Type; + // This class specifies a Backend and a ScalarType. Currently, it primarily // serves as a replacement return value for Tensor::type(). Previously, // Tensor::type() returned Type&, but we are changing Type to not be // dtype-specific. -class DeprecatedTypeProperties { +class CAFFE2_API DeprecatedTypeProperties { public: - DeprecatedTypeProperties(Backend backend, ScalarType scalar_type) - : backend_(backend), scalar_type_(scalar_type) {} + DeprecatedTypeProperties(Backend backend, ScalarType scalar_type, bool is_variable) + : backend_(backend), scalar_type_(scalar_type), is_variable_(is_variable) {} Backend backend() const { return backend_; } + Layout layout() const { + return layout_from_backend(backend_); + } + bool is_sparse() const { return layout_from_backend(backend()) == kSparse; } @@ -41,8 +50,8 @@ class DeprecatedTypeProperties { return scalarTypeToTypeMeta(scalar_type_); } - bool is_defined() const { - return backend_ != Backend::Undefined && scalar_type_ != ScalarType::Undefined; + bool is_variable() const { + return is_variable_; } bool operator==(const DeprecatedTypeProperties& other) const { @@ -59,9 +68,62 @@ class DeprecatedTypeProperties { return ss.str(); } + DeprecatedTypeProperties & toBackend(Backend b) const { + return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties( + b, scalar_type_, is_variable_); + } + + DeprecatedTypeProperties & toScalarType(ScalarType s) const { + return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties( + backend_, s, is_variable_); + } + + DeprecatedTypeProperties & cpu() const { + return toBackend(Backend::CPU); + } + + DeprecatedTypeProperties & cuda() const { + return toBackend(Backend::CUDA); + } + + DeprecatedTypeProperties & hip() const { + return toBackend(Backend::HIP); + } + + /// Constructs the `TensorOptions` from a type and a `device_index`. + TensorOptions options(int16_t device_index = -1) const { + return TensorOptions().dtype(typeMeta()) + .device(device_type(), device_index) + .layout(layout()) + .is_variable(is_variable()); + } + + /// Constructs the `TensorOptions` from a type and a Device. Asserts that + /// the device type matches the device type of the type. + TensorOptions options(c10::optional device_opt) const { + if (!device_opt.has_value()) { + return options(-1); + } else { + Device device = device_opt.value(); + AT_ASSERT(device.type() == device_type()); + return options(device.index()); + } + } + + operator TensorOptions() const { + return options(); + } + + Tensor unsafeTensorFromTH(void * th_pointer, bool retain) const; + Tensor copy(const Tensor & src, bool non_blocking=false, c10::optional to_device={}) const; + std::unique_ptr generator() const; + private: + Type & getDispatchType() const; + Backend backend_; ScalarType scalar_type_; + bool is_variable_; }; } // namespace at diff --git a/aten/src/ATen/core/DeprecatedTypePropertiesRegistry.cpp b/aten/src/ATen/core/DeprecatedTypePropertiesRegistry.cpp index 154f04d..e9188bf 100644 --- a/aten/src/ATen/core/DeprecatedTypePropertiesRegistry.cpp +++ b/aten/src/ATen/core/DeprecatedTypePropertiesRegistry.cpp @@ -1,7 +1,31 @@ #include +#include + namespace at { +void DeprecatedTypePropertiesDeleter::operator()(DeprecatedTypeProperties * ptr) { + delete ptr; +} + +DeprecatedTypePropertiesRegistry::DeprecatedTypePropertiesRegistry() { + for (int b = 0; b < static_cast(Backend::NumOptions); ++b) { + for (int s = 0; s < static_cast(ScalarType::NumOptions); ++s) { + for (int v = 0; v < 2; ++ v) { + registry[b][s][v] = c10::guts::make_unique( + static_cast(b), + static_cast(s), + v); + } + } + } +} + +DeprecatedTypeProperties& DeprecatedTypePropertiesRegistry::getDeprecatedTypeProperties( + Backend p, ScalarType s, bool is_variable) const { + return *registry[static_cast(p)][static_cast(s)][is_variable]; +} + // TODO: This could be bad juju if someone calls globalContext() in the // destructor of an object with static lifetime. DeprecatedTypePropertiesRegistry & globalDeprecatedTypePropertiesRegistry() { diff --git a/aten/src/ATen/core/DeprecatedTypePropertiesRegistry.h b/aten/src/ATen/core/DeprecatedTypePropertiesRegistry.h index 0ab57bf..543db04 100644 --- a/aten/src/ATen/core/DeprecatedTypePropertiesRegistry.h +++ b/aten/src/ATen/core/DeprecatedTypePropertiesRegistry.h @@ -5,40 +5,26 @@ #include #include -#include namespace at { +class DeprecatedTypeProperties; + struct CAFFE2_API DeprecatedTypePropertiesDeleter { - void operator()(DeprecatedTypeProperties * ptr) { - delete ptr; - } + void operator()(DeprecatedTypeProperties * ptr); }; class CAFFE2_API DeprecatedTypePropertiesRegistry { public: - using DeprecatedTypePropertiesUniquePtr = - std::unique_ptr; - - DeprecatedTypePropertiesRegistry() { - for (int b = 0; b < static_cast(Backend::NumOptions); ++b) { - for (int s = 0; s < static_cast(ScalarType::NumOptions); ++s) { - registry[b][s] = DeprecatedTypePropertiesUniquePtr{ - new DeprecatedTypeProperties(static_cast(b), static_cast(s)), - DeprecatedTypePropertiesDeleter() - }; - } - } - } - - DeprecatedTypeProperties& getDeprecatedTypeProperties(Backend p, ScalarType s) { - return *registry[static_cast(p)][static_cast(s)]; - } + DeprecatedTypePropertiesRegistry(); + + DeprecatedTypeProperties& getDeprecatedTypeProperties(Backend p, ScalarType s, bool is_variable) const; private: - DeprecatedTypePropertiesUniquePtr registry + std::unique_ptr registry [static_cast(Backend::NumOptions)] - [static_cast(ScalarType::NumOptions)]; + [static_cast(ScalarType::NumOptions)] + [2]; // is_variable }; CAFFE2_API DeprecatedTypePropertiesRegistry& globalDeprecatedTypePropertiesRegistry(); diff --git a/aten/src/ATen/core/Tensor.h b/aten/src/ATen/core/Tensor.h index 6f58a4a..3ba8a57 100644 --- a/aten/src/ATen/core/Tensor.h +++ b/aten/src/ATen/core/Tensor.h @@ -21,6 +21,7 @@ struct TensorOptions; namespace at { struct Generator; struct Type; +class DeprecatedTypeProperties; class Tensor; } // namespace at @@ -199,7 +200,9 @@ class CAFFE2_API Tensor { DeprecatedTypeProperties & type() const { return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties( - tensorTypeIdToBackend(type_id()), scalar_type()); + tensorTypeIdToBackend(type_id()), + scalar_type(), + is_variable() && !at::NonVariableTypeMode::is_enabled()); } Type & dispatch_type() const { return legacyTensorType(*impl_); @@ -219,8 +222,8 @@ class CAFFE2_API Tensor { bool is_alias_of(const at::Tensor& other) const{ return impl_->storage().is_alias_of(other.storage()); } - Tensor toType(const Type & t, bool non_blocking=false) const; Tensor & copy_(const Tensor & src, bool non_blocking=false); + Tensor toType(const DeprecatedTypeProperties & t, bool non_blocking=false) const; Tensor toType(ScalarType t) const; Tensor toBackend(Backend b) const; diff --git a/aten/src/ATen/core/TensorMethods.h b/aten/src/ATen/core/TensorMethods.h index 26f3807..367839d 100644 --- a/aten/src/ATen/core/TensorMethods.h +++ b/aten/src/ATen/core/TensorMethods.h @@ -6,25 +6,29 @@ #include #include #include +#include namespace at { -inline Tensor Tensor::toType(const Type & t, bool non_blocking) const { - if(dispatch_type() == t) +inline Tensor Tensor::toType(const DeprecatedTypeProperties & t, bool non_blocking) const { + if(type() == t) return *this; - return t.copy(*this, non_blocking); + return to( + at::device(t.device_type()).layout(t.layout()).dtype(t.scalarType()), + non_blocking, + /*copy*/ true); } inline Tensor Tensor::cpu() const { - return toType(dispatch_type().cpu()); + return toType(type().cpu()); } inline Tensor Tensor::cuda() const { - return toType(dispatch_type().cuda()); + return toType(type().cuda()); } inline Tensor Tensor::hip() const { - return toType(dispatch_type().hip()); + return toType(type().hip()); } inline Tensor & Tensor::copy_(const Tensor & src, bool non_blocking) { @@ -32,11 +36,11 @@ inline Tensor & Tensor::copy_(const Tensor & src, bool non_blocking) { } inline Tensor Tensor::toType(ScalarType t) const { - return toType(dispatch_type().toScalarType(t)); + return toType(type().toScalarType(t)); } inline Tensor Tensor::toBackend(Backend b) const { - return toType(dispatch_type().toBackend(b)); + return toType(type().toBackend(b)); } inline TensorOptions Tensor::options() const { diff --git a/aten/src/ATen/native/TypeProperties.cpp b/aten/src/ATen/native/TypeProperties.cpp index c2cae17..ae1910f 100644 --- a/aten/src/ATen/native/TypeProperties.cpp +++ b/aten/src/ATen/native/TypeProperties.cpp @@ -35,7 +35,7 @@ bool is_sparse(const Tensor& self) { } Tensor type_as(const Tensor& self, const Tensor& other) { - return self.toType(other.dispatch_type()); + return self.toType(other.type()); } }} // namespace at::native diff --git a/aten/src/ATen/templates/Tensor.h b/aten/src/ATen/templates/Tensor.h index b1e917a..ab9917e 100644 --- a/aten/src/ATen/templates/Tensor.h +++ b/aten/src/ATen/templates/Tensor.h @@ -21,6 +21,7 @@ struct TensorOptions; namespace at { struct Generator; struct Type; +class DeprecatedTypeProperties; class Tensor; } // namespace at @@ -199,7 +200,9 @@ class CAFFE2_API Tensor { DeprecatedTypeProperties & type() const { return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties( - tensorTypeIdToBackend(type_id()), scalar_type()); + tensorTypeIdToBackend(type_id()), + scalar_type(), + is_variable() && !at::NonVariableTypeMode::is_enabled()); } Type & dispatch_type() const { return legacyTensorType(*impl_); @@ -219,8 +222,8 @@ class CAFFE2_API Tensor { bool is_alias_of(const at::Tensor& other) const{ return impl_->storage().is_alias_of(other.storage()); } - Tensor toType(const Type & t, bool non_blocking=false) const; Tensor & copy_(const Tensor & src, bool non_blocking=false); + Tensor toType(const DeprecatedTypeProperties & t, bool non_blocking=false) const; Tensor toType(ScalarType t) const; Tensor toBackend(Backend b) const; diff --git a/aten/src/ATen/templates/TensorMethods.h b/aten/src/ATen/templates/TensorMethods.h index 5928907..18b5e53 100644 --- a/aten/src/ATen/templates/TensorMethods.h +++ b/aten/src/ATen/templates/TensorMethods.h @@ -6,25 +6,29 @@ #include #include #include +#include namespace at { -inline Tensor Tensor::toType(const Type & t, bool non_blocking) const { - if(dispatch_type() == t) +inline Tensor Tensor::toType(const DeprecatedTypeProperties & t, bool non_blocking) const { + if(type() == t) return *this; - return t.copy(*this, non_blocking); + return to( + at::device(t.device_type()).layout(t.layout()).dtype(t.scalarType()), + non_blocking, + /*copy*/ true); } inline Tensor Tensor::cpu() const { - return toType(dispatch_type().cpu()); + return toType(type().cpu()); } inline Tensor Tensor::cuda() const { - return toType(dispatch_type().cuda()); + return toType(type().cuda()); } inline Tensor Tensor::hip() const { - return toType(dispatch_type().hip()); + return toType(type().hip()); } inline Tensor & Tensor::copy_(const Tensor & src, bool non_blocking) { @@ -32,11 +36,11 @@ inline Tensor & Tensor::copy_(const Tensor & src, bool non_blocking) { } inline Tensor Tensor::toType(ScalarType t) const { - return toType(dispatch_type().toScalarType(t)); + return toType(type().toScalarType(t)); } inline Tensor Tensor::toBackend(Backend b) const { - return toType(dispatch_type().toBackend(b)); + return toType(type().toBackend(b)); } inline TensorOptions Tensor::options() const { diff --git a/aten/src/ATen/test/apply_utils_test.cpp b/aten/src/ATen/test/apply_utils_test.cpp index cc97c03..2f5e1b6 100644 --- a/aten/src/ATen/test/apply_utils_test.cpp +++ b/aten/src/ATen/test/apply_utils_test.cpp @@ -23,7 +23,7 @@ void fill_tensor(int64_t scalar, Tensor& t_) { // write the same type as we read (using a0, ..., aX-1) and we once write to // double (using a4 as a target). We also exercise on a zero_dim and empty // tensor. -void test(Type& type, IntArrayRef shape, int64_t a = 0, int64_t b = 1) { +void test(DeprecatedTypeProperties& type, IntArrayRef shape, int64_t a = 0, int64_t b = 1) { auto zero_dim = at::empty({}, type); zero_dim.fill_(2); zero_dim.exp_(); diff --git a/aten/src/ATen/test/basic.cpp b/aten/src/ATen/test/basic.cpp index ebd569a..7f2e5ce 100644 --- a/aten/src/ATen/test/basic.cpp +++ b/aten/src/ATen/test/basic.cpp @@ -21,7 +21,7 @@ extern "C" void THFloatTensor_fill(THFloatTensor *, float v); using namespace at; -void TestResize(Type& type) { +void TestResize(DeprecatedTypeProperties& type) { auto a = at::empty({0}, type.options()); a.resize_({3, 4}); ASSERT_EQ_RESOLVED(a.numel(), 12); @@ -29,7 +29,7 @@ void TestResize(Type& type) { ASSERT_EQ_RESOLVED(a.numel(), 35); } -void TestOnesAndDot(Type& type) { +void TestOnesAndDot(DeprecatedTypeProperties& type) { Tensor b0 = ones({1, 1}, type); ASSERT_EQ_RESOLVED((b0 + b0).sum().item(), 2); @@ -42,7 +42,7 @@ void TestOnesAndDot(Type& type) { ASSERT_EQ_RESOLVED(b.view(-1).dot(b.view(-1)).item(), 12); } -void TestSort(Type& type) { +void TestSort(DeprecatedTypeProperties& type) { Tensor b = rand({3, 4}, type); auto z = b.sort(1); @@ -52,7 +52,7 @@ void TestSort(Type& type) { ASSERT_TRUE(isLT); } -void TestRandperm(Type& type) { +void TestRandperm(DeprecatedTypeProperties& type) { if (type.backend() != Backend::CUDA) { Tensor b = randperm(15, type); Tensor rv, ri; @@ -67,7 +67,7 @@ void SendContext() { ss << "context: " << std::hex << (int64_t)&globalContext() << std::endl; } -void TestAdd(Type& type) { +void TestAdd(DeprecatedTypeProperties& type) { Tensor a = rand({3, 4}, type); Tensor b = rand({3, 4}, type); Tensor c = add(a, add(a, b)); @@ -76,7 +76,7 @@ void TestAdd(Type& type) { ASSERT_TRUE(add(c, d).allclose(a + a + b + d)); } -void TestLoadsOfAdds(Type& type) { +void TestLoadsOfAdds(DeprecatedTypeProperties& type) { auto begin = std::chrono::high_resolution_clock::now(); Tensor d = ones({3, 4}, type); Tensor r = zeros({3, 4}, type); @@ -93,7 +93,7 @@ void TestLoadsOfAdds(Type& type) { ASSERT_EQ_RESOLVED(norm(100000 * d).item(), norm(r).item()); } -void TestLoadOfAddsWithCopy(Type& type) { +void TestLoadOfAddsWithCopy(DeprecatedTypeProperties& type) { auto begin = std::chrono::high_resolution_clock::now(); Tensor d = ones({3, 4}, type); Tensor r = zeros({3, 4}, type); @@ -110,28 +110,28 @@ void TestLoadOfAddsWithCopy(Type& type) { ASSERT_EQ_RESOLVED(norm(100000 * d).item(), norm(r).item()); } -void TestIsContiguous(Type& type) { +void TestIsContiguous(DeprecatedTypeProperties& type) { Tensor a = rand({3, 4}, type); ASSERT_TRUE(a.is_contiguous()); a = a.transpose(0, 1); ASSERT_FALSE(a.is_contiguous()); } -void TestPermute(Type& type) { +void TestPermute(DeprecatedTypeProperties& type) { Tensor a = rand({3, 4, 5}, type); Tensor b = a.permute({1, 2, 0}); ASSERT_TRUE(b.sizes().equals({4, 5, 3})); ASSERT_TRUE(b.strides().equals({5, 1, 20})); } -void TestMm(Type& type) { +void TestMm(DeprecatedTypeProperties& type) { Tensor a = rand({3, 4}, type); Tensor b = rand({4}, type); Tensor c = mv(a, b); ASSERT_TRUE(c.equal(addmv(zeros({3}, type), a, b, 0, 1))); } -void TestSqueeze(Type& type) { +void TestSqueeze(DeprecatedTypeProperties& type) { Tensor a = rand({2, 1}, type); Tensor b = squeeze(a); ASSERT_EQ_RESOLVED(b.dim(), 1); @@ -141,14 +141,14 @@ void TestSqueeze(Type& type) { ASSERT_TRUE(a[0].equal(b)); } -void TestCopy(Type& type) { +void TestCopy(DeprecatedTypeProperties& type) { Tensor a = zeros({4, 3}, type); Tensor e = rand({4, 3}, type); a.copy_(e); ASSERT_TRUE(a.equal(e)); } -void TestCopyBroadcasting(Type& type) { +void TestCopyBroadcasting(DeprecatedTypeProperties& type) { Tensor a = zeros({4, 3}, type); Tensor e = rand({3}, type); a.copy_(e); @@ -156,7 +156,7 @@ void TestCopyBroadcasting(Type& type) { ASSERT_TRUE(a[i].equal(e)); } } -void TestAbsValue(Type& type) { +void TestAbsValue(DeprecatedTypeProperties& type) { Tensor r = at::abs(at::scalar_tensor(-3, type.options())); ASSERT_EQ_RESOLVED(r.item(), 3); } @@ -173,12 +173,12 @@ std::cout << (a == 10.) << " -- should be 1" << std::endl; #endif */ -void TestAddingAValueWithScalar(Type& type) { +void TestAddingAValueWithScalar(DeprecatedTypeProperties& type) { Tensor a = rand({4, 3}, type); ASSERT_TRUE((ones({4, 3}, type) + a).equal(add(a, 1))); } -void TestSelect(Type& type) { +void TestSelect(DeprecatedTypeProperties& type) { Tensor a = rand({3, 7}, type); auto a_13 = select(a, 1, 3); auto a_13_02 = select(select(a, 1, 3), 0, 2); @@ -186,7 +186,7 @@ void TestSelect(Type& type) { ASSERT_TRUE(a[2][3].equal(a_13_02)); } -void TestZeroDim(Type& type) { +void TestZeroDim(DeprecatedTypeProperties& type) { Tensor a = at::scalar_tensor(4, type.options()); // rand(type, {1}); Tensor b = rand({3, 4}, type); @@ -263,7 +263,7 @@ void TestIndexingByZerodimTensor() { // Throw StartsWith("Can only index with tensors that are scalars (zero-dim)") ASSERT_ANY_THROW(tensor[ones({2, 3, 4}, kInt)].equal(one)); } -void TestIndexingMixedDevice(Type& type) { +void TestIndexingMixedDevice(DeprecatedTypeProperties& type) { Tensor tensor = randn({20, 20}, type); Tensor index = arange(10, kLong).cpu(); Tensor result = tensor.index({index}); @@ -276,14 +276,14 @@ void TestDispatch() { ASSERT_TRUE(result.allclose(mse_loss(relu(tensor), other))); } -void TestNegativeDim(Type& type) { +void TestNegativeDim(DeprecatedTypeProperties& type) { ASSERT_ANY_THROW(empty({5, -5, 5}, type.options())); ASSERT_ANY_THROW(empty({5, -5, -5}, type.options())); Tensor tensor = empty({5, 5}, type.options()); ASSERT_ANY_THROW(tensor.reshape({-5, -5})); } -void test(Type& type) { +void test(DeprecatedTypeProperties& type) { TestResize(type); TestOnesAndDot(type); diff --git a/aten/src/ATen/test/broadcast_test.cpp b/aten/src/ATen/test/broadcast_test.cpp index 6463115..42a6544 100644 --- a/aten/src/ATen/test/broadcast_test.cpp +++ b/aten/src/ATen/test/broadcast_test.cpp @@ -6,13 +6,13 @@ using namespace at; // can't expand empty tensor -void TestEmptyTensor(Type& T) { +void TestEmptyTensor(DeprecatedTypeProperties& T) { auto empty = randn({0}, T); ASSERT_ANY_THROW(empty.expand({3})); } // out-place function with 2 args -void TestOut2Basic(Type& T) { +void TestOut2Basic(DeprecatedTypeProperties& T) { auto a = randn({3, 1}, T); auto b = randn({5}, T); std::vector expanded_sizes = {3, 5}; @@ -21,7 +21,7 @@ void TestOut2Basic(Type& T) { } // with scalar -void TestOut2WithScalar(Type& T) { +void TestOut2WithScalar(DeprecatedTypeProperties& T) { auto aScalar = ones({1}, T); aScalar.unsafeGetTensorImpl()->maybe_zero_dim(true); auto b = randn({3, 5}, T); @@ -30,21 +30,21 @@ void TestOut2WithScalar(Type& T) { } // old fallback behavior yields error -void TestOut2OldFallback(Type& T) { +void TestOut2OldFallback(DeprecatedTypeProperties& T) { auto a = randn({3, 5}, T); auto b = randn({5, 3}, T); ASSERT_ANY_THROW(a + b); } // with mismatched sizes -void TestOut2MismatchedSizes(Type& T) { +void TestOut2MismatchedSizes(DeprecatedTypeProperties& T) { auto a = randn({3, 5}, T); auto b = randn({7, 5}, T); ASSERT_ANY_THROW(a + b); } // out-place function with 3 args -void TestOut3Basic(Type& T) { +void TestOut3Basic(DeprecatedTypeProperties& T) { auto a = randn({3, 1, 1}, T); auto b = randn({1, 2, 1}, T); auto c = randn({1, 1, 5}, T); @@ -55,7 +55,7 @@ void TestOut3Basic(Type& T) { } // with scalar -void TestOut3WithScalar(Type& T) { +void TestOut3WithScalar(DeprecatedTypeProperties& T) { auto aTensorScalar = ones({1}, T); aTensorScalar.unsafeGetTensorImpl()->maybe_zero_dim(true); auto b = randn({3, 2, 1}, T); @@ -67,7 +67,7 @@ void TestOut3WithScalar(Type& T) { } // old fallback behavior yields error -void TestOut3OldFallback(Type& T) { +void TestOut3OldFallback(DeprecatedTypeProperties& T) { auto a = randn({3, 2, 5}, T); auto b = randn({2, 3, 5}, T); auto c = randn({5, 3, 2}, T); @@ -75,7 +75,7 @@ void TestOut3OldFallback(Type& T) { } // with mismatched sizes -void TestOut3MismatchedSizes(Type& T) { +void TestOut3MismatchedSizes(DeprecatedTypeProperties& T) { auto a = randn({3, 2, 5}, T); auto b = randn({2, 3, 5}, T); auto c = randn({5, 5, 5}, T); @@ -83,14 +83,14 @@ void TestOut3MismatchedSizes(Type& T) { } // in-place function with 2 args -void TestIn2Basic(Type& T) { +void TestIn2Basic(DeprecatedTypeProperties& T) { auto a = randn({3, 5}, T); auto b = randn({3, 1}, T); ASSERT_TRUE((a + b).equal(a + b.expand({3, 5}))); } // with scalar -void TestIn2WithScalar(Type& T) { +void TestIn2WithScalar(DeprecatedTypeProperties& T) { auto a = randn({3, 5}, T); auto bScalar = ones({1}, T); bScalar.unsafeGetTensorImpl()->maybe_zero_dim(true); @@ -98,14 +98,14 @@ void TestIn2WithScalar(Type& T) { } // error: would have to expand inplace arg -void TestIn2ExpandError(Type& T) { +void TestIn2ExpandError(DeprecatedTypeProperties& T) { auto a = randn({1, 5}, T); auto b = randn({3, 1}, T); ASSERT_ANY_THROW(a.add_(b)); } // in-place function with 3 args -void TestIn3Basic(Type& T) { +void TestIn3Basic(DeprecatedTypeProperties& T) { auto a = randn({3, 5, 2}, T); auto b = randn({3, 1, 2}, T); auto c = randn({1, 5, 1}, T); @@ -115,7 +115,7 @@ void TestIn3Basic(Type& T) { } // with scalar -void TestIn3WithScalar(Type& T) { +void TestIn3WithScalar(DeprecatedTypeProperties& T) { auto a = randn({3, 5, 2}, T); auto b = randn({3, 1, 2}, T); auto c = randn({1, 5, 1}, T); @@ -128,7 +128,7 @@ void TestIn3WithScalar(Type& T) { } // error: would have to expand inplace arg -void TestIn3ExpandError(Type& T) { +void TestIn3ExpandError(DeprecatedTypeProperties& T) { auto a = randn({1, 3, 5}, T); auto b = randn({4, 1, 1}, T); auto c = randn({1, 3, 1}, T); @@ -136,7 +136,7 @@ void TestIn3ExpandError(Type& T) { } // explicit dim specification -void TestExplicitDimBasic(Type& T) { +void TestExplicitDimBasic(DeprecatedTypeProperties& T) { auto a = randn({1}, T); auto b = randn({5, 3}, T); auto c = randn({3, 7}, T); @@ -144,7 +144,7 @@ void TestExplicitDimBasic(Type& T) { } // with scalar -void TestExplicitDimWithScalar(Type& T) { +void TestExplicitDimWithScalar(DeprecatedTypeProperties& T) { auto a = randn({1}, T); auto b = randn({5, 3}, T); auto c = randn({3, 7}, T); @@ -154,7 +154,7 @@ void TestExplicitDimWithScalar(Type& T) { } // with mismatched sizes -void TestExplicitDimWithMismatchedSizes(Type& T) { +void TestExplicitDimWithMismatchedSizes(DeprecatedTypeProperties& T) { auto b = randn({5, 3}, T); auto c = randn({3, 7}, T); auto a = randn({3, 3}, T); @@ -163,7 +163,7 @@ void TestExplicitDimWithMismatchedSizes(Type& T) { TEST(BroadcastTest, Broadcast) { manual_seed(123); - Type& T = CPU(kFloat); + DeprecatedTypeProperties& T = CPU(kFloat); TestEmptyTensor(T); diff --git a/aten/src/ATen/test/native_test.cpp b/aten/src/ATen/test/native_test.cpp index 0d0ca1b..cd017cc 100644 --- a/aten/src/ATen/test/native_test.cpp +++ b/aten/src/ATen/test/native_test.cpp @@ -21,32 +21,28 @@ void requireEqualTensorList(TensorList t1, TensorList t2) { } } -// split: test method, type, namespace give same result -void TestSplit(Type& T, Tensor& t) { +// split: test method, namespace give same result +void TestSplit(DeprecatedTypeProperties& T, Tensor& t) { auto splitMethod = t.split(1, 0); - auto splitType = T.split(t, 1, 0); auto splitNs = at::split(t, 1, 0); - requireEqualTensorList(splitMethod, splitType); requireEqualTensorList(splitMethod, splitNs); // test rebuilding with cat ASSERT_EQUAL(at::cat(splitMethod, 0), t); } -// chunk: test method, type, namespace give same result -void TestChunk(Type& T, Tensor& t) { +// chunk: test method, namespace give same result +void TestChunk(DeprecatedTypeProperties& T, Tensor& t) { // test method, type, namespace give same result auto chunkMethod = t.chunk(3, 0); - auto chunkType = T.chunk(t, 3, 0); auto chunkNs = at::chunk(t, 3, 0); - requireEqualTensorList(chunkMethod, chunkType); requireEqualTensorList(chunkMethod, chunkNs); // test rebuilding with cat ASSERT_EQUAL(at::cat(chunkMethod, 0), t); } -void TestStack(Type& T, Tensor& t) { +void TestStack(DeprecatedTypeProperties& T, Tensor& t) { auto x = rand({2, 3, 4}); auto y = rand({2, 3, 4}); auto z = rand({2, 3, 4}); @@ -69,7 +65,7 @@ void TestStack(Type& T, Tensor& t) { } // size / stride -void TestSize(Type& T, Tensor& t) { +void TestSize(DeprecatedTypeProperties& T, Tensor& t) { auto scalar = randn({}, T); // Throw StartsWith("dimension specified as 0 but tensor has no dimensions") ASSERT_ANY_THROW(scalar.size(0)); @@ -87,7 +83,7 @@ void TestSize(Type& T, Tensor& t) { ASSERT_EQ(empty.stride(-1), 1); } -void TestMatmul(Type& T, Tensor& t, Type& AccT) { +void TestMatmul(DeprecatedTypeProperties& T, Tensor& t, DeprecatedTypeProperties& AccT) { auto scalar = randn({}, T); auto d1 = randn({3}, T); auto d2 = randn({2, 3}, T); @@ -160,7 +156,7 @@ void TestMatmul(Type& T, Tensor& t, Type& AccT) { ASSERT_ANY_THROW(d5.matmul(d5wrong)); } -void TestStandardGammaGrad(Type& T, Tensor& t) { +void TestStandardGammaGrad(DeprecatedTypeProperties& T, Tensor& t) { // check empty auto empty = ones({0}, T); ASSERT_EQUAL(empty, at::_standard_gamma_grad(empty, empty)); @@ -179,7 +175,7 @@ void TestStandardGammaGrad(Type& T, Tensor& t) { ASSERT_ANY_THROW(at::_standard_gamma_grad(t1, t2)); } -void TestWhere(Type& T, Tensor& t) { +void TestWhere(DeprecatedTypeProperties& T, Tensor& t) { // empty auto empty = ones({0}, T); auto& bT = T.toScalarType(ScalarType::Byte); @@ -198,7 +194,7 @@ void TestWhere(Type& T, Tensor& t) { at::where(cond_1d, x_1d, y_1d)); } -void test(Type& T, Type& AccT) { +void test(DeprecatedTypeProperties& T, DeprecatedTypeProperties& AccT) { auto t = randn({3, 3}, T); TestSplit(T, t); TestChunk(T, t); diff --git a/aten/src/ATen/test/scalar_tensor_test.cpp b/aten/src/ATen/test/scalar_tensor_test.cpp index 6872344..97e598d 100644 --- a/aten/src/ATen/test/scalar_tensor_test.cpp +++ b/aten/src/ATen/test/scalar_tensor_test.cpp @@ -42,7 +42,7 @@ bool should_expand(const IntArrayRef &from_size, const IntArrayRef &to_size) { return true; } -void test(Type &T) { +void test(DeprecatedTypeProperties &T) { std::vector> sizes = {{}, {0}, {1}, {1, 1}, {2}}; // single-tensor/size tests diff --git a/aten/src/ATen/test/undefined_tensor_test.cpp b/aten/src/ATen/test/undefined_tensor_test.cpp index 5a3c926..9c9c42c 100644 --- a/aten/src/ATen/test/undefined_tensor_test.cpp +++ b/aten/src/ATen/test/undefined_tensor_test.cpp @@ -27,9 +27,9 @@ TEST(TestUndefined, UndefinedTest) { ASSERT_ANY_THROW(und.add(5)); ASSERT_ANY_THROW(und.mm(und)); - und.toType(und.dispatch_type()); - ASSERT_ANY_THROW(und.toType(ft.dispatch_type())); - ASSERT_ANY_THROW(ft.toType(und.dispatch_type())); + und.toType(und.type()); + ASSERT_ANY_THROW(und.toType(ft.type())); + ASSERT_ANY_THROW(ft.toType(und.type())); und.toType(ScalarType::Undefined); ASSERT_ANY_THROW(und.toType(ScalarType::Float)); ASSERT_ANY_THROW(ft.toType(ScalarType::Undefined)); diff --git a/aten/src/ATen/test/wrapdim_test.cpp b/aten/src/ATen/test/wrapdim_test.cpp index b7088cf..9c3b18a 100644 --- a/aten/src/ATen/test/wrapdim_test.cpp +++ b/aten/src/ATen/test/wrapdim_test.cpp @@ -3,13 +3,13 @@ #include using namespace at; -void TestSimpleCase(Type& T) { +void TestSimpleCase(DeprecatedTypeProperties& T) { auto a = randn({2, 3, 4, 5}, T); ASSERT_TRUE(a.prod(-4).equal(a.prod(0))); ASSERT_TRUE(a.prod(3).equal(a.prod(-1))); } -void TestExpressionSpecification(Type& T) { +void TestExpressionSpecification(DeprecatedTypeProperties& T) { auto a = randn({2, 3, 4, 5}, T); ASSERT_TRUE(a.unsqueeze(-5).equal(a.unsqueeze(0))); ASSERT_TRUE(a.unsqueeze(4).equal(a.unsqueeze(-1))); @@ -20,12 +20,12 @@ void TestExpressionSpecification(Type& T) { ASSERT_TRUE(b.unsqueeze(0).equal(b.unsqueeze(-1))); } -void TestEmptyTensor(Type& T) { +void TestEmptyTensor(DeprecatedTypeProperties& T) { auto a = randn(0, T); ASSERT_TRUE(a.prod(0).equal(at::ones({}, T))); } -void TestScalarVs1Dim1Size(Type& T) { +void TestScalarVs1Dim1Size(DeprecatedTypeProperties& T) { auto a = randn(1, T); ASSERT_TRUE(a.prod(0).equal(a.prod(-1))); a.unsafeGetTensorImpl()->maybe_zero_dim(true); @@ -35,7 +35,7 @@ void TestScalarVs1Dim1Size(Type& T) { TEST(TestWrapdim, TestWrapdim) { manual_seed(123); - Type& T = CPU(kFloat); + DeprecatedTypeProperties& T = CPU(kFloat); TestSimpleCase(T); TestEmptyTensor(T); diff --git a/test/cpp/jit/test_argument_spec.h b/test/cpp/jit/test_argument_spec.h index 315f32e..0bd73c9 100644 --- a/test/cpp/jit/test_argument_spec.h +++ b/test/cpp/jit/test_argument_spec.h @@ -24,7 +24,7 @@ bool isEqual(const CompleteArgumentInfo& ti, const autograd::Variable& v) { isEqual(ti.strides(), v.strides()); } -autograd::Variable var(at::Type& t, at::IntArrayRef sizes, bool requires_grad) { +autograd::Variable var(at::DeprecatedTypeProperties& t, at::IntArrayRef sizes, bool requires_grad) { return autograd::make_variable(at::rand(sizes, t.options()), requires_grad); } autograd::Variable undef() { diff --git a/torch/csrc/Generator.cpp b/torch/csrc/Generator.cpp index 0c166ac..984105d 100644 --- a/torch/csrc/Generator.cpp +++ b/torch/csrc/Generator.cpp @@ -82,8 +82,7 @@ static PyObject * THPGenerator_setState(THPGenerator *self, PyObject *_new_state throw TypeError("expected a torch.ByteTensor, but got %s", Py_TYPE(_new_state)->tp_name); } auto& tensor = ((THPVariable*)_new_state)->cdata.data(); - auto& tensor_type = at::globalContext().getNonVariableType(tensor.type().backend(), tensor.scalar_type()); - if (tensor_type != CPU(kByte)) { + if (tensor.layout() != kStrided || tensor.device().type() != kCPU || tensor.scalar_type() != kByte) { auto type_name = torch::utils::type_to_string(tensor.dispatch_type()); throw TypeError("expected a torch.ByteTensor, but got %s", type_name.c_str()); } diff --git a/torch/csrc/autograd/VariableTypeManual.cpp b/torch/csrc/autograd/VariableTypeManual.cpp index c2eed5c..be5f576 100644 --- a/torch/csrc/autograd/VariableTypeManual.cpp +++ b/torch/csrc/autograd/VariableTypeManual.cpp @@ -260,7 +260,7 @@ Tensor & VariableType::s_copy_(Tensor & self, const Tensor & src, bool non_block if (requires_grad) { grad_fn = std::make_shared(); grad_fn->set_next_edges(collect_next_edges(self, src)); - grad_fn->src_type = &src.dispatch_type(); + grad_fn->src_type = &src.type(); grad_fn->src_device = src.device(); } { diff --git a/torch/csrc/autograd/functions/tensor.h b/torch/csrc/autograd/functions/tensor.h index ecbf711..e4e4b7e 100644 --- a/torch/csrc/autograd/functions/tensor.h +++ b/torch/csrc/autograd/functions/tensor.h @@ -4,7 +4,7 @@ #include #include -#include +#include #include #include @@ -15,7 +15,7 @@ namespace torch { namespace autograd { struct CopyBackwards : public Function { variable_list apply(variable_list&& grads) override; - at::Type *src_type = nullptr; // initialized for safety. + at::DeprecatedTypeProperties *src_type = nullptr; // initialized for safety. at::Device src_device = at::kCPU; }; -- 2.7.4