CAFFE2_API Allocator* getCPUAllocator();
-static inline TypeExtendedInterface& CPU(ScalarType s) {
- return getNonVariableType(Backend::CPU, s);
+static inline DeprecatedTypeProperties& CPU(ScalarType s) {
+ return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
+ Backend::CPU, s, /*is_variable*/false);
}
-static inline TypeExtendedInterface& CUDA(ScalarType s) {
- return getNonVariableType(Backend::CUDA, s);
+static inline DeprecatedTypeProperties& CUDA(ScalarType s) {
+ return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
+ Backend::CUDA, s, /*is_variable*/false);
}
-static inline TypeExtendedInterface& HIP(ScalarType s) {
- return getNonVariableType(Backend::HIP, s);
+static inline DeprecatedTypeProperties& HIP(ScalarType s) {
+ return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
+ Backend::HIP, s, /*is_variable*/false);
}
CAFFE2_API LegacyTHDispatcher& getLegacyTHDispatcher(TensorOptions options);
#include <ATen/Type.h>
#include <c10/util/Half.h>
#include <c10/util/Exception.h>
+#include <ATen/core/DeprecatedTypeProperties.h>
#define AT_PRIVATE_CASE_TYPE(enum_type, type, ...) \
case enum_type: { \
--- /dev/null
+#include <ATen/core/DeprecatedTypeProperties.h>
+
+#include <ATen/core/LegacyTypeDispatch.h>
+#include <ATen/core/Type.h>
+
+namespace at {
+
+Tensor DeprecatedTypeProperties::unsafeTensorFromTH(void * th_pointer, bool retain) const {
+ return getDispatchType().unsafeTensorFromTH(th_pointer, retain);
+}
+
+Tensor DeprecatedTypeProperties::copy(const Tensor & src, bool non_blocking, c10::optional<Device> to_device) const {
+ return getDispatchType().copy(src, non_blocking, to_device);
+}
+
+std::unique_ptr<Generator> DeprecatedTypeProperties::generator() const {
+ return getDispatchType().generator();
+}
+
+Type & DeprecatedTypeProperties::getDispatchType() const {
+ return globalLegacyTypeDispatch().getType(backend_, scalar_type_, is_variable_);
+}
+
+} // namespace at
#include <c10/core/Backend.h>
#include <c10/core/ScalarType.h>
#include <c10/core/Layout.h>
-
+#include <c10/core/TensorOptions.h>
+#include <ATen/core/DeprecatedTypePropertiesRegistry.h>
+#include <ATen/core/Generator.h>
namespace at {
+class Tensor;
+struct Type;
+
// This class specifies a Backend and a ScalarType. Currently, it primarily
// serves as a replacement return value for Tensor::type(). Previously,
// Tensor::type() returned Type&, but we are changing Type to not be
// dtype-specific.
-class DeprecatedTypeProperties {
+class CAFFE2_API DeprecatedTypeProperties {
public:
- DeprecatedTypeProperties(Backend backend, ScalarType scalar_type)
- : backend_(backend), scalar_type_(scalar_type) {}
+ DeprecatedTypeProperties(Backend backend, ScalarType scalar_type, bool is_variable)
+ : backend_(backend), scalar_type_(scalar_type), is_variable_(is_variable) {}
Backend backend() const {
return backend_;
}
+ Layout layout() const {
+ return layout_from_backend(backend_);
+ }
+
bool is_sparse() const {
return layout_from_backend(backend()) == kSparse;
}
return scalarTypeToTypeMeta(scalar_type_);
}
- bool is_defined() const {
- return backend_ != Backend::Undefined && scalar_type_ != ScalarType::Undefined;
+ bool is_variable() const {
+ return is_variable_;
}
bool operator==(const DeprecatedTypeProperties& other) const {
return ss.str();
}
+ DeprecatedTypeProperties & toBackend(Backend b) const {
+ return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
+ b, scalar_type_, is_variable_);
+ }
+
+ DeprecatedTypeProperties & toScalarType(ScalarType s) const {
+ return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
+ backend_, s, is_variable_);
+ }
+
+ DeprecatedTypeProperties & cpu() const {
+ return toBackend(Backend::CPU);
+ }
+
+ DeprecatedTypeProperties & cuda() const {
+ return toBackend(Backend::CUDA);
+ }
+
+ DeprecatedTypeProperties & hip() const {
+ return toBackend(Backend::HIP);
+ }
+
+ /// Constructs the `TensorOptions` from a type and a `device_index`.
+ TensorOptions options(int16_t device_index = -1) const {
+ return TensorOptions().dtype(typeMeta())
+ .device(device_type(), device_index)
+ .layout(layout())
+ .is_variable(is_variable());
+ }
+
+ /// Constructs the `TensorOptions` from a type and a Device. Asserts that
+ /// the device type matches the device type of the type.
+ TensorOptions options(c10::optional<Device> device_opt) const {
+ if (!device_opt.has_value()) {
+ return options(-1);
+ } else {
+ Device device = device_opt.value();
+ AT_ASSERT(device.type() == device_type());
+ return options(device.index());
+ }
+ }
+
+ operator TensorOptions() const {
+ return options();
+ }
+
+ Tensor unsafeTensorFromTH(void * th_pointer, bool retain) const;
+ Tensor copy(const Tensor & src, bool non_blocking=false, c10::optional<Device> to_device={}) const;
+ std::unique_ptr<Generator> generator() const;
+
private:
+ Type & getDispatchType() const;
+
Backend backend_;
ScalarType scalar_type_;
+ bool is_variable_;
};
} // namespace at
#include <ATen/core/DeprecatedTypePropertiesRegistry.h>
+#include <ATen/core/DeprecatedTypeProperties.h>
+
namespace at {
+void DeprecatedTypePropertiesDeleter::operator()(DeprecatedTypeProperties * ptr) {
+ delete ptr;
+}
+
+DeprecatedTypePropertiesRegistry::DeprecatedTypePropertiesRegistry() {
+ for (int b = 0; b < static_cast<int>(Backend::NumOptions); ++b) {
+ for (int s = 0; s < static_cast<int>(ScalarType::NumOptions); ++s) {
+ for (int v = 0; v < 2; ++ v) {
+ registry[b][s][v] = c10::guts::make_unique<DeprecatedTypeProperties>(
+ static_cast<Backend>(b),
+ static_cast<ScalarType>(s),
+ v);
+ }
+ }
+ }
+}
+
+DeprecatedTypeProperties& DeprecatedTypePropertiesRegistry::getDeprecatedTypeProperties(
+ Backend p, ScalarType s, bool is_variable) const {
+ return *registry[static_cast<int>(p)][static_cast<int>(s)][is_variable];
+}
+
// TODO: This could be bad juju if someone calls globalContext() in the
// destructor of an object with static lifetime.
DeprecatedTypePropertiesRegistry & globalDeprecatedTypePropertiesRegistry() {
#include <c10/core/Backend.h>
#include <c10/core/ScalarType.h>
-#include <ATen/core/DeprecatedTypeProperties.h>
namespace at {
+class DeprecatedTypeProperties;
+
struct CAFFE2_API DeprecatedTypePropertiesDeleter {
- void operator()(DeprecatedTypeProperties * ptr) {
- delete ptr;
- }
+ void operator()(DeprecatedTypeProperties * ptr);
};
class CAFFE2_API DeprecatedTypePropertiesRegistry {
public:
- using DeprecatedTypePropertiesUniquePtr =
- std::unique_ptr<DeprecatedTypeProperties, DeprecatedTypePropertiesDeleter>;
-
- DeprecatedTypePropertiesRegistry() {
- for (int b = 0; b < static_cast<int>(Backend::NumOptions); ++b) {
- for (int s = 0; s < static_cast<int>(ScalarType::NumOptions); ++s) {
- registry[b][s] = DeprecatedTypePropertiesUniquePtr{
- new DeprecatedTypeProperties(static_cast<Backend>(b), static_cast<ScalarType>(s)),
- DeprecatedTypePropertiesDeleter()
- };
- }
- }
- }
-
- DeprecatedTypeProperties& getDeprecatedTypeProperties(Backend p, ScalarType s) {
- return *registry[static_cast<int>(p)][static_cast<int>(s)];
- }
+ DeprecatedTypePropertiesRegistry();
+
+ DeprecatedTypeProperties& getDeprecatedTypeProperties(Backend p, ScalarType s, bool is_variable) const;
private:
- DeprecatedTypePropertiesUniquePtr registry
+ std::unique_ptr<DeprecatedTypeProperties> registry
[static_cast<int>(Backend::NumOptions)]
- [static_cast<int>(ScalarType::NumOptions)];
+ [static_cast<int>(ScalarType::NumOptions)]
+ [2]; // is_variable
};
CAFFE2_API DeprecatedTypePropertiesRegistry& globalDeprecatedTypePropertiesRegistry();
namespace at {
struct Generator;
struct Type;
+class DeprecatedTypeProperties;
class Tensor;
} // namespace at
DeprecatedTypeProperties & type() const {
return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
- tensorTypeIdToBackend(type_id()), scalar_type());
+ tensorTypeIdToBackend(type_id()),
+ scalar_type(),
+ is_variable() && !at::NonVariableTypeMode::is_enabled());
}
Type & dispatch_type() const {
return legacyTensorType(*impl_);
bool is_alias_of(const at::Tensor& other) const{
return impl_->storage().is_alias_of(other.storage());
}
- Tensor toType(const Type & t, bool non_blocking=false) const;
Tensor & copy_(const Tensor & src, bool non_blocking=false);
+ Tensor toType(const DeprecatedTypeProperties & t, bool non_blocking=false) const;
Tensor toType(ScalarType t) const;
Tensor toBackend(Backend b) const;
#include <ATen/core/SparseTensorRef.h>
#include <ATen/core/Type.h>
#include <c10/core/TensorOptions.h>
+#include <ATen/core/DeprecatedTypeProperties.h>
namespace at {
-inline Tensor Tensor::toType(const Type & t, bool non_blocking) const {
- if(dispatch_type() == t)
+inline Tensor Tensor::toType(const DeprecatedTypeProperties & t, bool non_blocking) const {
+ if(type() == t)
return *this;
- return t.copy(*this, non_blocking);
+ return to(
+ at::device(t.device_type()).layout(t.layout()).dtype(t.scalarType()),
+ non_blocking,
+ /*copy*/ true);
}
inline Tensor Tensor::cpu() const {
- return toType(dispatch_type().cpu());
+ return toType(type().cpu());
}
inline Tensor Tensor::cuda() const {
- return toType(dispatch_type().cuda());
+ return toType(type().cuda());
}
inline Tensor Tensor::hip() const {
- return toType(dispatch_type().hip());
+ return toType(type().hip());
}
inline Tensor & Tensor::copy_(const Tensor & src, bool non_blocking) {
}
inline Tensor Tensor::toType(ScalarType t) const {
- return toType(dispatch_type().toScalarType(t));
+ return toType(type().toScalarType(t));
}
inline Tensor Tensor::toBackend(Backend b) const {
- return toType(dispatch_type().toBackend(b));
+ return toType(type().toBackend(b));
}
inline TensorOptions Tensor::options() const {
}
Tensor type_as(const Tensor& self, const Tensor& other) {
- return self.toType(other.dispatch_type());
+ return self.toType(other.type());
}
}} // namespace at::native
namespace at {
struct Generator;
struct Type;
+class DeprecatedTypeProperties;
class Tensor;
} // namespace at
DeprecatedTypeProperties & type() const {
return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
- tensorTypeIdToBackend(type_id()), scalar_type());
+ tensorTypeIdToBackend(type_id()),
+ scalar_type(),
+ is_variable() && !at::NonVariableTypeMode::is_enabled());
}
Type & dispatch_type() const {
return legacyTensorType(*impl_);
bool is_alias_of(const at::Tensor& other) const{
return impl_->storage().is_alias_of(other.storage());
}
- Tensor toType(const Type & t, bool non_blocking=false) const;
Tensor & copy_(const Tensor & src, bool non_blocking=false);
+ Tensor toType(const DeprecatedTypeProperties & t, bool non_blocking=false) const;
Tensor toType(ScalarType t) const;
Tensor toBackend(Backend b) const;
#include <ATen/core/SparseTensorRef.h>
#include <ATen/core/Type.h>
#include <c10/core/TensorOptions.h>
+#include <ATen/core/DeprecatedTypeProperties.h>
namespace at {
-inline Tensor Tensor::toType(const Type & t, bool non_blocking) const {
- if(dispatch_type() == t)
+inline Tensor Tensor::toType(const DeprecatedTypeProperties & t, bool non_blocking) const {
+ if(type() == t)
return *this;
- return t.copy(*this, non_blocking);
+ return to(
+ at::device(t.device_type()).layout(t.layout()).dtype(t.scalarType()),
+ non_blocking,
+ /*copy*/ true);
}
inline Tensor Tensor::cpu() const {
- return toType(dispatch_type().cpu());
+ return toType(type().cpu());
}
inline Tensor Tensor::cuda() const {
- return toType(dispatch_type().cuda());
+ return toType(type().cuda());
}
inline Tensor Tensor::hip() const {
- return toType(dispatch_type().hip());
+ return toType(type().hip());
}
inline Tensor & Tensor::copy_(const Tensor & src, bool non_blocking) {
}
inline Tensor Tensor::toType(ScalarType t) const {
- return toType(dispatch_type().toScalarType(t));
+ return toType(type().toScalarType(t));
}
inline Tensor Tensor::toBackend(Backend b) const {
- return toType(dispatch_type().toBackend(b));
+ return toType(type().toBackend(b));
}
inline TensorOptions Tensor::options() const {
// write the same type as we read (using a0, ..., aX-1) and we once write to
// double (using a4 as a target). We also exercise on a zero_dim and empty
// tensor.
-void test(Type& type, IntArrayRef shape, int64_t a = 0, int64_t b = 1) {
+void test(DeprecatedTypeProperties& type, IntArrayRef shape, int64_t a = 0, int64_t b = 1) {
auto zero_dim = at::empty({}, type);
zero_dim.fill_(2);
zero_dim.exp_();
using namespace at;
-void TestResize(Type& type) {
+void TestResize(DeprecatedTypeProperties& type) {
auto a = at::empty({0}, type.options());
a.resize_({3, 4});
ASSERT_EQ_RESOLVED(a.numel(), 12);
ASSERT_EQ_RESOLVED(a.numel(), 35);
}
-void TestOnesAndDot(Type& type) {
+void TestOnesAndDot(DeprecatedTypeProperties& type) {
Tensor b0 = ones({1, 1}, type);
ASSERT_EQ_RESOLVED((b0 + b0).sum().item<double>(), 2);
ASSERT_EQ_RESOLVED(b.view(-1).dot(b.view(-1)).item<double>(), 12);
}
-void TestSort(Type& type) {
+void TestSort(DeprecatedTypeProperties& type) {
Tensor b = rand({3, 4}, type);
auto z = b.sort(1);
ASSERT_TRUE(isLT);
}
-void TestRandperm(Type& type) {
+void TestRandperm(DeprecatedTypeProperties& type) {
if (type.backend() != Backend::CUDA) {
Tensor b = randperm(15, type);
Tensor rv, ri;
ss << "context: " << std::hex << (int64_t)&globalContext() << std::endl;
}
-void TestAdd(Type& type) {
+void TestAdd(DeprecatedTypeProperties& type) {
Tensor a = rand({3, 4}, type);
Tensor b = rand({3, 4}, type);
Tensor c = add(a, add(a, b));
ASSERT_TRUE(add(c, d).allclose(a + a + b + d));
}
-void TestLoadsOfAdds(Type& type) {
+void TestLoadsOfAdds(DeprecatedTypeProperties& type) {
auto begin = std::chrono::high_resolution_clock::now();
Tensor d = ones({3, 4}, type);
Tensor r = zeros({3, 4}, type);
ASSERT_EQ_RESOLVED(norm(100000 * d).item<double>(), norm(r).item<double>());
}
-void TestLoadOfAddsWithCopy(Type& type) {
+void TestLoadOfAddsWithCopy(DeprecatedTypeProperties& type) {
auto begin = std::chrono::high_resolution_clock::now();
Tensor d = ones({3, 4}, type);
Tensor r = zeros({3, 4}, type);
ASSERT_EQ_RESOLVED(norm(100000 * d).item<double>(), norm(r).item<double>());
}
-void TestIsContiguous(Type& type) {
+void TestIsContiguous(DeprecatedTypeProperties& type) {
Tensor a = rand({3, 4}, type);
ASSERT_TRUE(a.is_contiguous());
a = a.transpose(0, 1);
ASSERT_FALSE(a.is_contiguous());
}
-void TestPermute(Type& type) {
+void TestPermute(DeprecatedTypeProperties& type) {
Tensor a = rand({3, 4, 5}, type);
Tensor b = a.permute({1, 2, 0});
ASSERT_TRUE(b.sizes().equals({4, 5, 3}));
ASSERT_TRUE(b.strides().equals({5, 1, 20}));
}
-void TestMm(Type& type) {
+void TestMm(DeprecatedTypeProperties& type) {
Tensor a = rand({3, 4}, type);
Tensor b = rand({4}, type);
Tensor c = mv(a, b);
ASSERT_TRUE(c.equal(addmv(zeros({3}, type), a, b, 0, 1)));
}
-void TestSqueeze(Type& type) {
+void TestSqueeze(DeprecatedTypeProperties& type) {
Tensor a = rand({2, 1}, type);
Tensor b = squeeze(a);
ASSERT_EQ_RESOLVED(b.dim(), 1);
ASSERT_TRUE(a[0].equal(b));
}
-void TestCopy(Type& type) {
+void TestCopy(DeprecatedTypeProperties& type) {
Tensor a = zeros({4, 3}, type);
Tensor e = rand({4, 3}, type);
a.copy_(e);
ASSERT_TRUE(a.equal(e));
}
-void TestCopyBroadcasting(Type& type) {
+void TestCopyBroadcasting(DeprecatedTypeProperties& type) {
Tensor a = zeros({4, 3}, type);
Tensor e = rand({3}, type);
a.copy_(e);
ASSERT_TRUE(a[i].equal(e));
}
}
-void TestAbsValue(Type& type) {
+void TestAbsValue(DeprecatedTypeProperties& type) {
Tensor r = at::abs(at::scalar_tensor(-3, type.options()));
ASSERT_EQ_RESOLVED(r.item<int32_t>(), 3);
}
#endif
*/
-void TestAddingAValueWithScalar(Type& type) {
+void TestAddingAValueWithScalar(DeprecatedTypeProperties& type) {
Tensor a = rand({4, 3}, type);
ASSERT_TRUE((ones({4, 3}, type) + a).equal(add(a, 1)));
}
-void TestSelect(Type& type) {
+void TestSelect(DeprecatedTypeProperties& type) {
Tensor a = rand({3, 7}, type);
auto a_13 = select(a, 1, 3);
auto a_13_02 = select(select(a, 1, 3), 0, 2);
ASSERT_TRUE(a[2][3].equal(a_13_02));
}
-void TestZeroDim(Type& type) {
+void TestZeroDim(DeprecatedTypeProperties& type) {
Tensor a = at::scalar_tensor(4, type.options()); // rand(type, {1});
Tensor b = rand({3, 4}, type);
// Throw StartsWith("Can only index with tensors that are scalars (zero-dim)")
ASSERT_ANY_THROW(tensor[ones({2, 3, 4}, kInt)].equal(one));
}
-void TestIndexingMixedDevice(Type& type) {
+void TestIndexingMixedDevice(DeprecatedTypeProperties& type) {
Tensor tensor = randn({20, 20}, type);
Tensor index = arange(10, kLong).cpu();
Tensor result = tensor.index({index});
ASSERT_TRUE(result.allclose(mse_loss(relu(tensor), other)));
}
-void TestNegativeDim(Type& type) {
+void TestNegativeDim(DeprecatedTypeProperties& type) {
ASSERT_ANY_THROW(empty({5, -5, 5}, type.options()));
ASSERT_ANY_THROW(empty({5, -5, -5}, type.options()));
Tensor tensor = empty({5, 5}, type.options());
ASSERT_ANY_THROW(tensor.reshape({-5, -5}));
}
-void test(Type& type) {
+void test(DeprecatedTypeProperties& type) {
TestResize(type);
TestOnesAndDot(type);
using namespace at;
// can't expand empty tensor
-void TestEmptyTensor(Type& T) {
+void TestEmptyTensor(DeprecatedTypeProperties& T) {
auto empty = randn({0}, T);
ASSERT_ANY_THROW(empty.expand({3}));
}
// out-place function with 2 args
-void TestOut2Basic(Type& T) {
+void TestOut2Basic(DeprecatedTypeProperties& T) {
auto a = randn({3, 1}, T);
auto b = randn({5}, T);
std::vector<int64_t> expanded_sizes = {3, 5};
}
// with scalar
-void TestOut2WithScalar(Type& T) {
+void TestOut2WithScalar(DeprecatedTypeProperties& T) {
auto aScalar = ones({1}, T);
aScalar.unsafeGetTensorImpl()->maybe_zero_dim(true);
auto b = randn({3, 5}, T);
}
// old fallback behavior yields error
-void TestOut2OldFallback(Type& T) {
+void TestOut2OldFallback(DeprecatedTypeProperties& T) {
auto a = randn({3, 5}, T);
auto b = randn({5, 3}, T);
ASSERT_ANY_THROW(a + b);
}
// with mismatched sizes
-void TestOut2MismatchedSizes(Type& T) {
+void TestOut2MismatchedSizes(DeprecatedTypeProperties& T) {
auto a = randn({3, 5}, T);
auto b = randn({7, 5}, T);
ASSERT_ANY_THROW(a + b);
}
// out-place function with 3 args
-void TestOut3Basic(Type& T) {
+void TestOut3Basic(DeprecatedTypeProperties& T) {
auto a = randn({3, 1, 1}, T);
auto b = randn({1, 2, 1}, T);
auto c = randn({1, 1, 5}, T);
}
// with scalar
-void TestOut3WithScalar(Type& T) {
+void TestOut3WithScalar(DeprecatedTypeProperties& T) {
auto aTensorScalar = ones({1}, T);
aTensorScalar.unsafeGetTensorImpl()->maybe_zero_dim(true);
auto b = randn({3, 2, 1}, T);
}
// old fallback behavior yields error
-void TestOut3OldFallback(Type& T) {
+void TestOut3OldFallback(DeprecatedTypeProperties& T) {
auto a = randn({3, 2, 5}, T);
auto b = randn({2, 3, 5}, T);
auto c = randn({5, 3, 2}, T);
}
// with mismatched sizes
-void TestOut3MismatchedSizes(Type& T) {
+void TestOut3MismatchedSizes(DeprecatedTypeProperties& T) {
auto a = randn({3, 2, 5}, T);
auto b = randn({2, 3, 5}, T);
auto c = randn({5, 5, 5}, T);
}
// in-place function with 2 args
-void TestIn2Basic(Type& T) {
+void TestIn2Basic(DeprecatedTypeProperties& T) {
auto a = randn({3, 5}, T);
auto b = randn({3, 1}, T);
ASSERT_TRUE((a + b).equal(a + b.expand({3, 5})));
}
// with scalar
-void TestIn2WithScalar(Type& T) {
+void TestIn2WithScalar(DeprecatedTypeProperties& T) {
auto a = randn({3, 5}, T);
auto bScalar = ones({1}, T);
bScalar.unsafeGetTensorImpl()->maybe_zero_dim(true);
}
// error: would have to expand inplace arg
-void TestIn2ExpandError(Type& T) {
+void TestIn2ExpandError(DeprecatedTypeProperties& T) {
auto a = randn({1, 5}, T);
auto b = randn({3, 1}, T);
ASSERT_ANY_THROW(a.add_(b));
}
// in-place function with 3 args
-void TestIn3Basic(Type& T) {
+void TestIn3Basic(DeprecatedTypeProperties& T) {
auto a = randn({3, 5, 2}, T);
auto b = randn({3, 1, 2}, T);
auto c = randn({1, 5, 1}, T);
}
// with scalar
-void TestIn3WithScalar(Type& T) {
+void TestIn3WithScalar(DeprecatedTypeProperties& T) {
auto a = randn({3, 5, 2}, T);
auto b = randn({3, 1, 2}, T);
auto c = randn({1, 5, 1}, T);
}
// error: would have to expand inplace arg
-void TestIn3ExpandError(Type& T) {
+void TestIn3ExpandError(DeprecatedTypeProperties& T) {
auto a = randn({1, 3, 5}, T);
auto b = randn({4, 1, 1}, T);
auto c = randn({1, 3, 1}, T);
}
// explicit dim specification
-void TestExplicitDimBasic(Type& T) {
+void TestExplicitDimBasic(DeprecatedTypeProperties& T) {
auto a = randn({1}, T);
auto b = randn({5, 3}, T);
auto c = randn({3, 7}, T);
}
// with scalar
-void TestExplicitDimWithScalar(Type& T) {
+void TestExplicitDimWithScalar(DeprecatedTypeProperties& T) {
auto a = randn({1}, T);
auto b = randn({5, 3}, T);
auto c = randn({3, 7}, T);
}
// with mismatched sizes
-void TestExplicitDimWithMismatchedSizes(Type& T) {
+void TestExplicitDimWithMismatchedSizes(DeprecatedTypeProperties& T) {
auto b = randn({5, 3}, T);
auto c = randn({3, 7}, T);
auto a = randn({3, 3}, T);
TEST(BroadcastTest, Broadcast) {
manual_seed(123);
- Type& T = CPU(kFloat);
+ DeprecatedTypeProperties& T = CPU(kFloat);
TestEmptyTensor(T);
}
}
-// split: test method, type, namespace give same result
-void TestSplit(Type& T, Tensor& t) {
+// split: test method, namespace give same result
+void TestSplit(DeprecatedTypeProperties& T, Tensor& t) {
auto splitMethod = t.split(1, 0);
- auto splitType = T.split(t, 1, 0);
auto splitNs = at::split(t, 1, 0);
- requireEqualTensorList(splitMethod, splitType);
requireEqualTensorList(splitMethod, splitNs);
// test rebuilding with cat
ASSERT_EQUAL(at::cat(splitMethod, 0), t);
}
-// chunk: test method, type, namespace give same result
-void TestChunk(Type& T, Tensor& t) {
+// chunk: test method, namespace give same result
+void TestChunk(DeprecatedTypeProperties& T, Tensor& t) {
// test method, type, namespace give same result
auto chunkMethod = t.chunk(3, 0);
- auto chunkType = T.chunk(t, 3, 0);
auto chunkNs = at::chunk(t, 3, 0);
- requireEqualTensorList(chunkMethod, chunkType);
requireEqualTensorList(chunkMethod, chunkNs);
// test rebuilding with cat
ASSERT_EQUAL(at::cat(chunkMethod, 0), t);
}
-void TestStack(Type& T, Tensor& t) {
+void TestStack(DeprecatedTypeProperties& T, Tensor& t) {
auto x = rand({2, 3, 4});
auto y = rand({2, 3, 4});
auto z = rand({2, 3, 4});
}
// size / stride
-void TestSize(Type& T, Tensor& t) {
+void TestSize(DeprecatedTypeProperties& T, Tensor& t) {
auto scalar = randn({}, T);
// Throw StartsWith("dimension specified as 0 but tensor has no dimensions")
ASSERT_ANY_THROW(scalar.size(0));
ASSERT_EQ(empty.stride(-1), 1);
}
-void TestMatmul(Type& T, Tensor& t, Type& AccT) {
+void TestMatmul(DeprecatedTypeProperties& T, Tensor& t, DeprecatedTypeProperties& AccT) {
auto scalar = randn({}, T);
auto d1 = randn({3}, T);
auto d2 = randn({2, 3}, T);
ASSERT_ANY_THROW(d5.matmul(d5wrong));
}
-void TestStandardGammaGrad(Type& T, Tensor& t) {
+void TestStandardGammaGrad(DeprecatedTypeProperties& T, Tensor& t) {
// check empty
auto empty = ones({0}, T);
ASSERT_EQUAL(empty, at::_standard_gamma_grad(empty, empty));
ASSERT_ANY_THROW(at::_standard_gamma_grad(t1, t2));
}
-void TestWhere(Type& T, Tensor& t) {
+void TestWhere(DeprecatedTypeProperties& T, Tensor& t) {
// empty
auto empty = ones({0}, T);
auto& bT = T.toScalarType(ScalarType::Byte);
at::where(cond_1d, x_1d, y_1d));
}
-void test(Type& T, Type& AccT) {
+void test(DeprecatedTypeProperties& T, DeprecatedTypeProperties& AccT) {
auto t = randn({3, 3}, T);
TestSplit(T, t);
TestChunk(T, t);
return true;
}
-void test(Type &T) {
+void test(DeprecatedTypeProperties &T) {
std::vector<std::vector<int64_t>> sizes = {{}, {0}, {1}, {1, 1}, {2}};
// single-tensor/size tests
ASSERT_ANY_THROW(und.add(5));
ASSERT_ANY_THROW(und.mm(und));
- und.toType(und.dispatch_type());
- ASSERT_ANY_THROW(und.toType(ft.dispatch_type()));
- ASSERT_ANY_THROW(ft.toType(und.dispatch_type()));
+ und.toType(und.type());
+ ASSERT_ANY_THROW(und.toType(ft.type()));
+ ASSERT_ANY_THROW(ft.toType(und.type()));
und.toType(ScalarType::Undefined);
ASSERT_ANY_THROW(und.toType(ScalarType::Float));
ASSERT_ANY_THROW(ft.toType(ScalarType::Undefined));
#include <ATen/ATen.h>
using namespace at;
-void TestSimpleCase(Type& T) {
+void TestSimpleCase(DeprecatedTypeProperties& T) {
auto a = randn({2, 3, 4, 5}, T);
ASSERT_TRUE(a.prod(-4).equal(a.prod(0)));
ASSERT_TRUE(a.prod(3).equal(a.prod(-1)));
}
-void TestExpressionSpecification(Type& T) {
+void TestExpressionSpecification(DeprecatedTypeProperties& T) {
auto a = randn({2, 3, 4, 5}, T);
ASSERT_TRUE(a.unsqueeze(-5).equal(a.unsqueeze(0)));
ASSERT_TRUE(a.unsqueeze(4).equal(a.unsqueeze(-1)));
ASSERT_TRUE(b.unsqueeze(0).equal(b.unsqueeze(-1)));
}
-void TestEmptyTensor(Type& T) {
+void TestEmptyTensor(DeprecatedTypeProperties& T) {
auto a = randn(0, T);
ASSERT_TRUE(a.prod(0).equal(at::ones({}, T)));
}
-void TestScalarVs1Dim1Size(Type& T) {
+void TestScalarVs1Dim1Size(DeprecatedTypeProperties& T) {
auto a = randn(1, T);
ASSERT_TRUE(a.prod(0).equal(a.prod(-1)));
a.unsafeGetTensorImpl()->maybe_zero_dim(true);
TEST(TestWrapdim, TestWrapdim) {
manual_seed(123);
- Type& T = CPU(kFloat);
+ DeprecatedTypeProperties& T = CPU(kFloat);
TestSimpleCase(T);
TestEmptyTensor(T);
isEqual(ti.strides(), v.strides());
}
-autograd::Variable var(at::Type& t, at::IntArrayRef sizes, bool requires_grad) {
+autograd::Variable var(at::DeprecatedTypeProperties& t, at::IntArrayRef sizes, bool requires_grad) {
return autograd::make_variable(at::rand(sizes, t.options()), requires_grad);
}
autograd::Variable undef() {
throw TypeError("expected a torch.ByteTensor, but got %s", Py_TYPE(_new_state)->tp_name);
}
auto& tensor = ((THPVariable*)_new_state)->cdata.data();
- auto& tensor_type = at::globalContext().getNonVariableType(tensor.type().backend(), tensor.scalar_type());
- if (tensor_type != CPU(kByte)) {
+ if (tensor.layout() != kStrided || tensor.device().type() != kCPU || tensor.scalar_type() != kByte) {
auto type_name = torch::utils::type_to_string(tensor.dispatch_type());
throw TypeError("expected a torch.ByteTensor, but got %s", type_name.c_str());
}
if (requires_grad) {
grad_fn = std::make_shared<CopyBackwards>();
grad_fn->set_next_edges(collect_next_edges(self, src));
- grad_fn->src_type = &src.dispatch_type();
+ grad_fn->src_type = &src.type();
grad_fn->src_device = src.device();
}
{
#include <torch/csrc/autograd/variable.h>
#include <ATen/TensorGeometry.h>
-#include <ATen/Type.h>
+#include <ATen/core/DeprecatedTypeProperties.h>
#include <c10/util/Optional.h>
#include <cstdint>
struct CopyBackwards : public Function {
variable_list apply(variable_list&& grads) override;
- at::Type *src_type = nullptr; // initialized for safety.
+ at::DeprecatedTypeProperties *src_type = nullptr; // initialized for safety.
at::Device src_device = at::kCPU;
};