Pass ScalarType separately from Type in python constructors
authorRoy Li <royboy@fb.com>
Thu, 4 Apr 2019 09:21:09 +0000 (02:21 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 4 Apr 2019 09:24:20 +0000 (02:24 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17786

Reviewed By: ezyang

Differential Revision: D14379075

fbshipit-source-id: 3abf066563b789a30cafe5b0c868a41326f5b833

tools/autograd/templates/python_torch_functions.cpp
tools/autograd/templates/python_torch_functions_dispatch.h
tools/autograd/templates/python_variable_methods.cpp
torch/csrc/autograd/python_variable.cpp
torch/csrc/autograd/python_variable_indexing.cpp
torch/csrc/tensor/python_tensor.cpp
torch/csrc/tensor/python_tensor.h
torch/csrc/utils/tensor_new.cpp
torch/csrc/utils/tensor_new.h

index 2c5bbee..062e254 100644 (file)
@@ -328,7 +328,7 @@ static PyObject * THPVariable_as_tensor(PyObject* self, PyObject* args, PyObject
 {
   HANDLE_TH_ERRORS
   jit::tracer::warn("torch.as_tensor", jit::tracer::WARN_CONSTRUCTOR);
-  return THPVariable_Wrap(torch::utils::as_tensor(default_type(), args, kwargs));
+  return THPVariable_Wrap(torch::utils::as_tensor(default_type(), default_scalar_type(), args, kwargs));
   END_HANDLE_TH_ERRORS
 }
 
@@ -361,7 +361,7 @@ static PyObject * THPVariable_sparse_coo_tensor(PyObject* self, PyObject* args,
 {
   HANDLE_TH_ERRORS
   jit::tracer::warn("torch.sparse_coo_tensor", jit::tracer::WARN_CONSTRUCTOR);
-  return THPVariable_Wrap(torch::utils::sparse_coo_tensor_ctor(default_type(), args, kwargs));
+  return THPVariable_Wrap(torch::utils::sparse_coo_tensor_ctor(default_type(), default_scalar_type(), args, kwargs));
   END_HANDLE_TH_ERRORS
 }
 
@@ -369,7 +369,7 @@ static PyObject * THPVariable_tensor(PyObject* self, PyObject* args, PyObject* k
 {
   HANDLE_TH_ERRORS
   jit::tracer::warn("torch.tensor", jit::tracer::WARN_CONSTRUCTOR);
-  return THPVariable_Wrap(torch::utils::tensor_ctor(default_type(), args, kwargs));
+  return THPVariable_Wrap(torch::utils::tensor_ctor(default_type(), default_scalar_type(), args, kwargs));
   END_HANDLE_TH_ERRORS
 }
 
index 07c5802..2e6b710 100644 (file)
@@ -17,6 +17,7 @@ namespace torch { namespace autograd {
 
 using at::Tensor;
 using at::Scalar;
+using at::ScalarType;
 using at::TensorList;
 using at::IntArrayRef;
 using at::Generator;
@@ -28,6 +29,10 @@ static at::Type& default_type() {
   return torch::tensors::get_default_tensor_type();
 }
 
+static ScalarType default_scalar_type() {
+  return torch::tensors::get_default_scalar_type();
+}
+
 static void maybe_initialize_cuda(const at::TensorOptions& options) {
   if (options.device().is_cuda()) {
     torch::utils::cuda_lazy_init();
index 34943d5..bdc2cf8 100644 (file)
@@ -496,7 +496,7 @@ static PyObject * THPVariable_new(PyObject* self, PyObject* args, PyObject* kwar
   HANDLE_TH_ERRORS
   auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
   OptionalDeviceGuard device_guard(device_of(self_));
-  return THPVariable_Wrap(torch::utils::legacy_tensor_new(self_.dispatch_type(), args, kwargs));
+  return THPVariable_Wrap(torch::utils::legacy_tensor_new(self_.dispatch_type(), self_.scalar_type(), args, kwargs));
   END_HANDLE_TH_ERRORS
 }
 
@@ -506,7 +506,7 @@ static PyObject * THPVariable_new_empty(PyObject* self, PyObject* args, PyObject
   jit::tracer::warn("new_empty", jit::tracer::LEGACY_CONSTRUCTOR);
   auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
   OptionalDeviceGuard device_guard(device_of(self_));
-  return THPVariable_Wrap(torch::utils::new_empty(self_.dispatch_type(), args, kwargs));
+  return THPVariable_Wrap(torch::utils::new_empty(self_.dispatch_type(), self_.scalar_type(), args, kwargs));
   END_HANDLE_TH_ERRORS
 }
 
@@ -516,7 +516,7 @@ static PyObject * THPVariable_new_full(PyObject* self, PyObject* args, PyObject*
   jit::tracer::warn("new_full", jit::tracer::LEGACY_CONSTRUCTOR);
   auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
   OptionalDeviceGuard device_guard(device_of(self_));
-  return THPVariable_Wrap(torch::utils::new_full(self_.dispatch_type(), args, kwargs));
+  return THPVariable_Wrap(torch::utils::new_full(self_.dispatch_type(), self_.scalar_type(), args, kwargs));
   END_HANDLE_TH_ERRORS
 }
 
@@ -526,7 +526,7 @@ static PyObject * THPVariable_new_ones(PyObject* self, PyObject* args, PyObject*
   jit::tracer::warn("new_ones", jit::tracer::LEGACY_CONSTRUCTOR);
   auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
   OptionalDeviceGuard device_guard(device_of(self_));
-  return THPVariable_Wrap(torch::utils::new_ones(self_.dispatch_type(), args, kwargs));
+  return THPVariable_Wrap(torch::utils::new_ones(self_.dispatch_type(), self_.scalar_type(), args, kwargs));
   END_HANDLE_TH_ERRORS
 }
 
@@ -536,7 +536,7 @@ static PyObject * THPVariable_new_tensor(PyObject* self, PyObject* args, PyObjec
   jit::tracer::warn("new_tensor", jit::tracer::LEGACY_CONSTRUCTOR);
   auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
   OptionalDeviceGuard device_guard(device_of(self_));
-  return THPVariable_Wrap(torch::utils::new_tensor(self_.dispatch_type(), args, kwargs));
+  return THPVariable_Wrap(torch::utils::new_tensor(self_.dispatch_type(), self_.scalar_type(), args, kwargs));
   END_HANDLE_TH_ERRORS
 }
 
@@ -546,7 +546,7 @@ static PyObject * THPVariable_new_zeros(PyObject* self, PyObject* args, PyObject
   jit::tracer::warn("new_zeros", jit::tracer::LEGACY_CONSTRUCTOR);
   auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
   OptionalDeviceGuard device_guard(device_of(self_));
-  return THPVariable_Wrap(torch::utils::new_zeros(self_.dispatch_type(), args, kwargs));
+  return THPVariable_Wrap(torch::utils::new_zeros(self_.dispatch_type(), self_.scalar_type(), args, kwargs));
   END_HANDLE_TH_ERRORS
 }
 
index fd33033..54c8134 100644 (file)
@@ -132,7 +132,8 @@ static PyObject *THPVariable_pynew(PyTypeObject *type, PyObject *args, PyObject
   HANDLE_TH_ERRORS
   jit::tracer::warn("torch.Tensor", jit::tracer::WARN_CONSTRUCTOR);
   auto& default_type = torch::tensors::get_default_tensor_type();
-  auto tensor = torch::utils::legacy_tensor_ctor(default_type, args, kwargs);
+  auto default_scalar_type = torch::tensors::get_default_scalar_type();
+  auto tensor = torch::utils::legacy_tensor_ctor(default_type, default_scalar_type, args, kwargs);
   return THPVariable_NewWithVar(type, std::move(tensor));
   END_HANDLE_TH_ERRORS
 }
index 8035aeb..80461c4 100644 (file)
@@ -107,7 +107,7 @@ static Variable applySelect(const Variable& self, int64_t dim, int64_t index, in
 
 static Variable sequenceToVariable(const at::Type& type, PyObject* seq) {
   auto& idx_type = type.toScalarType(kLong);
-  return torch::utils::indexing_tensor_from_data(idx_type, c10::nullopt, seq);
+  return torch::utils::indexing_tensor_from_data(idx_type, kLong, c10::nullopt, seq);
 }
 
 static Variable valueToTensor(const at::Type & type, PyObject* value) {
index a2a17d0..4e7957d 100644 (file)
@@ -71,7 +71,8 @@ static PyObject* Tensor_new(PyTypeObject *type, PyObject *args, PyObject *kwargs
   if (!aten_type) {
     throw unavailable_type(tensor_type);
   }
-  return THPVariable_Wrap(torch::utils::legacy_tensor_ctor(*aten_type, args, kwargs));
+  auto scalar_type = static_cast<ScalarType>(tensor_type.scalar_type);
+  return THPVariable_Wrap(torch::utils::legacy_tensor_ctor(*aten_type, scalar_type, args, kwargs));
   END_HANDLE_TH_ERRORS
 }
 
@@ -387,4 +388,8 @@ at::Type& get_default_tensor_type() {
   AT_ASSERT(default_tensor_type);
   return *default_tensor_type;
 }
+
+ScalarType get_default_scalar_type() {
+  return typeMetaToScalarType(get_default_dtype());
+}
 }} // namespace torch::tensors
index 3449359..ab97537 100644 (file)
@@ -1,6 +1,7 @@
 #pragma once
 
 #include <torch/csrc/python_headers.h>
+#include <c10/core/ScalarType.h>
 
 namespace c10 {
 struct Device;
@@ -31,4 +32,7 @@ void py_set_default_dtype(PyObject* dtype_obj);
 // returned value will be a VariableType instance.
 at::Type& get_default_tensor_type();
 
+// Gets the ScalarType for the default tensor type.
+at::ScalarType get_default_scalar_type();
+
 }} // namespace torch::tensors
index 7b04fd6..25a16c8 100644 (file)
@@ -189,6 +189,7 @@ void recursive_store(char* data, IntArrayRef sizes, IntArrayRef strides, int64_t
 
 Tensor internal_new_from_data(
     const Type& type,
+    ScalarType scalar_type,
     c10::optional<Device> device_opt,
     PyObject* data,
     bool copy_variables,
@@ -208,52 +209,54 @@ Tensor internal_new_from_data(
     }
     // infer the scalar type and device type; it's not expected to infer the layout since these constructors
     // are defined per-layout-type (e.g. tensor vs sparse_coo_tensor).
-    const auto& scalar_type = type_inference ? var.scalar_type() : type.scalarType();
+    const auto& inferred_scalar_type = type_inference ? var.scalar_type() : scalar_type;
     auto device = device_opt.has_value() ? *device_opt : (type_inference ? var.device() : at::Device(type.device_type()));
     AutoNoGIL no_gil;
     maybe_initialize_cuda(device);
-    return var.to(device, scalar_type, /*non_blocking=*/false, /*copy=*/copy_variables);
+    return var.to(device, inferred_scalar_type, /*non_blocking=*/false, /*copy=*/copy_variables);
   }
 
 #ifdef USE_NUMPY
   if (PyArray_Check(data)) {
     AT_CHECK(!pin_memory, "Can't pin tensor constructed from numpy");
     auto tensor = autograd::make_variable(tensor_from_numpy(data), /*requires_grad=*/false);
-    const auto& scalar_type = type_inference ? tensor.scalar_type() : type.scalarType();
+    const auto& inferred_scalar_type = type_inference ? tensor.scalar_type() : scalar_type;
     auto device = device_opt.has_value() ? *device_opt : at::Device(type.device_type());
     AutoNoGIL no_gil;
     maybe_initialize_cuda(device);
-    return tensor.to(device, scalar_type, /*non_blocking=*/false, /*copy=*/copy_numpy);
+    return tensor.to(device, inferred_scalar_type, /*non_blocking=*/false, /*copy=*/copy_numpy);
   }
 #endif
 
   auto sizes = compute_sizes(data);
-  ScalarType scalar_type = type_inference ? infer_scalar_type(data) : type.scalarType();
-  auto tensor = autograd::make_variable(at::empty(sizes, at::initialTensorOptions().dtype(scalar_type).pinned_memory(pin_memory)), /*requires_grad=*/false);
+  ScalarType inferred_scalar_type = type_inference ? infer_scalar_type(data) : scalar_type;
+  auto tensor = autograd::make_variable(at::empty(sizes, at::initialTensorOptions().dtype(inferred_scalar_type).pinned_memory(pin_memory)), /*requires_grad=*/false);
   recursive_store(
       (char*)tensor.data_ptr(), tensor.sizes(), tensor.strides(), 0,
-      scalar_type, tensor.element_size(), data);
+      inferred_scalar_type, tensor.dtype().itemsize(), data);
   auto device = device_opt.has_value() ? *device_opt : at::Device(type.device_type());
   AutoNoGIL no_gil;
   maybe_initialize_cuda(device);
-  return tensor.to(device, scalar_type, /*non_blocking=*/false, /*copy=*/false);
+  return tensor.to(device, inferred_scalar_type, /*non_blocking=*/false, /*copy=*/false);
 }
 
 Tensor new_from_data_copy(
     const Type& type,
+    ScalarType scalar_type,
     c10::optional<Device> device,
     PyObject* data) {
-  return internal_new_from_data(type, std::move(device), data, true, true, false);
+  return internal_new_from_data(type, scalar_type, std::move(device), data, true, true, false);
 }
 
 Tensor legacy_new_from_sequence(
     const Type& type,
+    ScalarType scalar_type,
     c10::optional<Device> device,
     PyObject* data) {
   if (!PySequence_Check(data)) {
     throw TypeError("new(): data must be a sequence (got %s)", Py_TYPE(data)->tp_name);
   }
-  return internal_new_from_data(type, std::move(device), data, false, false, false);
+  return internal_new_from_data(type, scalar_type, std::move(device), data, false, false, false);
 }
 
 void check_legacy_ctor_device(const Type& type, c10::optional<Device> device) {
@@ -265,7 +268,7 @@ void check_legacy_ctor_device(const Type& type, c10::optional<Device> device) {
   }
 }
 
-Tensor legacy_sparse_tensor_ctor(const Type& type, PyObject* args, PyObject* kwargs) {
+Tensor legacy_sparse_tensor_ctor(const Type& type, ScalarType scalar_type, PyObject* args, PyObject* kwargs) {
   static PythonArgParser parser({
     "new(*, Device? device=None)",
     "new(*, int64_t cdata)|hidden",
@@ -299,14 +302,14 @@ Tensor legacy_sparse_tensor_ctor(const Type& type, PyObject* args, PyObject* kwa
     if (!THPSize_Check(arg) && PyTuple_GET_SIZE(args) >= 1 && arg == PyTuple_GET_ITEM(args, 0)) {
       // new(sequence) binds to this signature but should be treated differently
       // unless the sequences is a torch.Size
-      return legacy_new_from_sequence(type, deviceOptional, r.pyobject(0));
+      return legacy_new_from_sequence(type, scalar_type, deviceOptional, r.pyobject(0));
     }
     return new_with_sizes(type, r.deviceOptional(1), r.intlist(0));
   }
   throw std::runtime_error("new(): invalid arguments");
 }
 
-Tensor legacy_sparse_tensor_new(const Type& type, PyObject* args, PyObject* kwargs) {
+Tensor legacy_sparse_tensor_new(const Type& type, ScalarType scalar_type, PyObject* args, PyObject* kwargs) {
   static PythonArgParser parser({
     "new(*, Device? device=None)",
     "new(*, int64_t cdata)|hidden",
@@ -345,7 +348,7 @@ Tensor legacy_sparse_tensor_new(const Type& type, PyObject* args, PyObject* kwar
     if (!THPSize_Check(arg) && PyTuple_GET_SIZE(args) >= 1 && arg == PyTuple_GET_ITEM(args, 0)) {
       // new(sequence) binds to this signature but should be treated differently
       // unless the sequences is a torch.Size
-      return legacy_new_from_sequence(type, deviceOptional, r.pyobject(0));
+      return legacy_new_from_sequence(type, scalar_type, deviceOptional, r.pyobject(0));
     }
     return new_with_sizes(type, r.deviceOptional(1), r.intlist(0));
   }
@@ -353,15 +356,15 @@ Tensor legacy_sparse_tensor_new(const Type& type, PyObject* args, PyObject* kwar
 }
 
 // NB: device_idx here is NOT a DeviceIndex, but index into PythonArgs
-const Type& typeWithDefault(PythonArgs& r, int64_t dtype_idx, int64_t device_idx, const Type& type) {
-  const auto scalartype = r.scalartypeWithDefault(dtype_idx, type.scalarType());
+const Type& typeWithDefault(PythonArgs& r, int64_t dtype_idx, int64_t device_idx, const Type& type, ScalarType scalar_type) {
+  const auto scalartype = r.scalartypeWithDefault(dtype_idx, scalar_type);
   const Device types_device_type(type.device_type());
   const auto device_type = r.isNone(device_idx) ? types_device_type : r.device(device_idx).type();
   return torch::getVariableType(scalartype, *torch::getLayout(type.backend()), device_type);
 }
 } // namespace
 
-Tensor legacy_tensor_ctor(const Type& type, PyObject* args, PyObject* kwargs) {
+Tensor legacy_tensor_ctor(const Type& type, ScalarType scalar_type, PyObject* args, PyObject* kwargs) {
   static PythonArgParser parser({
     "new(*, Device? device=None)",
     "new(Storage storage)",
@@ -372,7 +375,7 @@ Tensor legacy_tensor_ctor(const Type& type, PyObject* args, PyObject* kwargs) {
   });
 
   if (type.is_sparse()) {
-    return legacy_sparse_tensor_ctor(type, args, kwargs);
+    return legacy_sparse_tensor_ctor(type, scalar_type, args, kwargs);
   }
 
   ParsedArgs<2> parsed_args;
@@ -396,18 +399,18 @@ Tensor legacy_tensor_ctor(const Type& type, PyObject* args, PyObject* kwargs) {
     if (!THPSize_Check(arg) && PyTuple_GET_SIZE(args) >= 1 && arg == PyTuple_GET_ITEM(args, 0)) {
       // new(sequence) binds to this signature but should be treated differently
       // unless the sequences is a torch.Size
-      return legacy_new_from_sequence(type, deviceOptional, r.pyobject(0));
+      return legacy_new_from_sequence(type, scalar_type, deviceOptional, r.pyobject(0));
     }
     return new_with_sizes(type, r.deviceOptional(1), r.intlist(0));
   } else if (r.idx == 5) {
     auto deviceOptional = r.deviceOptional(1);
     check_legacy_ctor_device(type, deviceOptional);
-    return legacy_new_from_sequence(type, deviceOptional, r.pyobject(0));
+    return legacy_new_from_sequence(type, scalar_type, deviceOptional, r.pyobject(0));
   }
   throw std::runtime_error("new(): invalid arguments");
 }
 
-Tensor legacy_tensor_new(const Type& type, PyObject* args, PyObject* kwargs) {
+Tensor legacy_tensor_new(const Type& type, ScalarType scalar_type, PyObject* args, PyObject* kwargs) {
   static PythonArgParser parser({
     "new(*, Device? device=None)",
     "new(Storage storage)",
@@ -418,7 +421,7 @@ Tensor legacy_tensor_new(const Type& type, PyObject* args, PyObject* kwargs) {
   });
 
   if (type.is_sparse()) {
-    return legacy_sparse_tensor_new(type, args, kwargs);
+    return legacy_sparse_tensor_new(type, scalar_type, args, kwargs);
   }
 
   ParsedArgs<3> parsed_args;
@@ -442,33 +445,34 @@ Tensor legacy_tensor_new(const Type& type, PyObject* args, PyObject* kwargs) {
     if (!THPSize_Check(arg) && PyTuple_GET_SIZE(args) >= 1 && arg == PyTuple_GET_ITEM(args, 0)) {
       // new(sequence) binds to this signature but should be treated differently
       // unless the sequences is a torch.Size
-      return legacy_new_from_sequence(type, deviceOptional, r.pyobject(0));
+      return legacy_new_from_sequence(type, scalar_type, deviceOptional, r.pyobject(0));
     }
     return new_with_sizes(type, r.deviceOptional(1), r.intlist(0));
   } else if (r.idx == 5) {
     auto deviceOptional = r.deviceOptional(1);
     check_legacy_ctor_device(type, deviceOptional);
-    return legacy_new_from_sequence(type, r.deviceOptional(1), r.pyobject(0));
+    return legacy_new_from_sequence(type, scalar_type, r.deviceOptional(1), r.pyobject(0));
   }
   throw std::runtime_error("new(): invalid arguments");
 }
 
 Tensor indexing_tensor_from_data(
     const Type& type,
+    ScalarType scalar_type,
     c10::optional<Device> device,
     PyObject* data) {
   // Specific to tensor indexing, converts an indexing list to an
   // indexing tensor (type Byte or Long)
-  ScalarType scalar_type = infer_scalar_type(data);
-  if (scalar_type == ScalarType::Byte) {
-    auto& idx_type = type.toScalarType(scalar_type);
-    return internal_new_from_data(idx_type, std::move(device), data, false, false, false);
+  ScalarType inferred_scalar_type = infer_scalar_type(data);
+  if (inferred_scalar_type == ScalarType::Byte) {
+    auto& idx_type = type.toScalarType(inferred_scalar_type);
+    return internal_new_from_data(idx_type, inferred_scalar_type, std::move(device), data, false, false, false);
   } else {
-    return internal_new_from_data(type, std::move(device), data, false, false, false);
+    return internal_new_from_data(type, scalar_type, std::move(device), data, false, false, false);
   }
 }
 
-Tensor sparse_coo_tensor_ctor(const Type& default_type, PyObject* args, PyObject* kwargs) {
+Tensor sparse_coo_tensor_ctor(const Type& default_type, ScalarType scalar_type, PyObject* args, PyObject* kwargs) {
   static PythonArgParser parser({
     "sparse_coo_tensor(PyObject* indices, PyObject* values, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
     "sparse_coo_tensor(PyObject* indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
@@ -479,32 +483,34 @@ Tensor sparse_coo_tensor_ctor(const Type& default_type, PyObject* args, PyObject
   auto r = parser.parse(args, kwargs, parsed_args);
   if (r.idx == 0) {
     bool type_inference = r.isNone(2);
-    const auto& type = typeWithDefault(r, 2, 3, default_type);
+    const auto& type = typeWithDefault(r, 2, 3, default_type, scalar_type);
+    const auto inferred_scalar_type = r.scalartypeWithDefault(2, scalar_type);
     const auto& values_type = type.toDense();
     at::OptionalDeviceGuard device_guard(r.deviceOptional(3));
     // if no dtype provided, infer type based on value type.
-    Tensor values = internal_new_from_data(values_type, r.deviceOptional(3), r.pyobject(1), false, true, type_inference);
+    Tensor values = internal_new_from_data(values_type, inferred_scalar_type, r.deviceOptional(3), r.pyobject(1), false, true, type_inference);
     const auto& indices_type = values.dispatch_type().toScalarType(kLong);
-    Tensor indices = internal_new_from_data(indices_type, r.deviceOptional(3), r.pyobject(0), false, true, false);
+    Tensor indices = internal_new_from_data(indices_type, kLong, r.deviceOptional(3), r.pyobject(0), false, true, false);
     return at::sparse_coo_tensor(indices, values, values.options().layout(at::kSparse)).set_requires_grad(r.toBool(4));
   } else if (r.idx == 1) {
     bool type_inference = r.isNone(3);
-    const auto& type = typeWithDefault(r, 3, 4, default_type);
+    const auto& type = typeWithDefault(r, 3, 4, default_type, scalar_type);
+    const auto inferred_scalar_type = r.scalartypeWithDefault(3, scalar_type);
     const auto& values_type = type.toDense();
     at::OptionalDeviceGuard device_guard(r.deviceOptional(4));
-    Tensor values = internal_new_from_data(values_type, r.deviceOptional(4), r.pyobject(1), false, true, type_inference);
+    Tensor values = internal_new_from_data(values_type, inferred_scalar_type, r.deviceOptional(4), r.pyobject(1), false, true, type_inference);
     const auto& indices_type = values.dispatch_type().toScalarType(kLong);
-    Tensor indices = internal_new_from_data(indices_type, r.deviceOptional(4), r.pyobject(0), false, true, false);
+    Tensor indices = internal_new_from_data(indices_type, kLong, r.deviceOptional(4), r.pyobject(0), false, true, false);
     return at::sparse_coo_tensor(indices, values, r.intlist(2), values.options().layout(at::kSparse)).set_requires_grad(r.toBool(5));
   } else if (r.idx == 2) {
-    const auto& type = typeWithDefault(r, 1, 2, default_type);
+    const auto& type = typeWithDefault(r, 1, 2, default_type, scalar_type);
     at::OptionalDeviceGuard device_guard(r.deviceOptional(2));
     return at::sparse_coo_tensor(r.intlist(0), type.options().layout(at::kSparse)).set_requires_grad(r.toBool(3));
   }
   throw std::runtime_error("sparse_coo_tensor(): invalid arguments");
 }
 
-Tensor tensor_ctor(const Type& type, PyObject* args, PyObject* kwargs) {
+Tensor tensor_ctor(const Type& type, ScalarType scalar_type, PyObject* args, PyObject* kwargs) {
   static PythonArgParser parser({
     "tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)",
   });
@@ -523,7 +529,8 @@ Tensor tensor_ctor(const Type& type, PyObject* args, PyObject* kwargs) {
     bool pin_memory = r.toBool(3);
     bool args_requires_grad = r.toBool(4);
     auto new_tensor = internal_new_from_data(
-               typeWithDefault(r, 1, 2, type),
+               typeWithDefault(r, 1, 2, type, scalar_type),
+               r.scalartypeWithDefault(1, scalar_type),
                r.deviceOptional(2),
                data,
                true,
@@ -537,7 +544,7 @@ Tensor tensor_ctor(const Type& type, PyObject* args, PyObject* kwargs) {
   throw std::runtime_error("tensor(): invalid arguments");
 }
 
-Tensor as_tensor(const Type& type, PyObject* args, PyObject* kwargs) {
+Tensor as_tensor(const Type& type, ScalarType scalar_type, PyObject* args, PyObject* kwargs) {
   // TODO: add requires_grad once we decide on semantics for sharing data.
   static PythonArgParser parser({
     "as_tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None)",
@@ -548,12 +555,18 @@ Tensor as_tensor(const Type& type, PyObject* args, PyObject* kwargs) {
   if (r.idx == 0) {
     bool type_inference = r.isNone(1);
     return internal_new_from_data(
-        typeWithDefault(r, 1, 2, type), r.deviceOptional(2), r.pyobject(0), false, false, type_inference);
+        typeWithDefault(r, 1, 2, type, scalar_type),
+        r.scalartypeWithDefault(1, scalar_type),
+        r.deviceOptional(2),
+        r.pyobject(0),
+        false,
+        false,
+        type_inference);
   }
   throw std::runtime_error("tensor(): invalid arguments");
 }
 
-Tensor new_tensor(const Type& type, PyObject* args, PyObject* kwargs) {
+Tensor new_tensor(const Type& type, ScalarType scalar_type, PyObject* args, PyObject* kwargs) {
   static PythonArgParser parser({
     "new_tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
   });
@@ -570,7 +583,8 @@ Tensor new_tensor(const Type& type, PyObject* args, PyObject* kwargs) {
 
     bool args_requires_grad = r.toBool(3);
     auto new_tensor = new_from_data_copy(
-               typeWithDefault(r, 1, 2, type),
+               typeWithDefault(r, 1, 2, type, scalar_type),
+               r.scalartypeWithDefault(1, scalar_type),
                r.deviceOptional(2),
                data);
     new_tensor.detach_(); // ensure new_tensor a leaf node
@@ -580,7 +594,7 @@ Tensor new_tensor(const Type& type, PyObject* args, PyObject* kwargs) {
   throw std::runtime_error("new_tensor(): invalid arguments");
 }
 
-Tensor new_empty(const Type& type, PyObject* args, PyObject* kwargs) {
+Tensor new_empty(const Type& type, ScalarType scalar_type, PyObject* args, PyObject* kwargs) {
   static PythonArgParser parser({
     "new_empty(IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)",
   }, /*traceable=*/true);
@@ -588,13 +602,13 @@ Tensor new_empty(const Type& type, PyObject* args, PyObject* kwargs) {
   ParsedArgs<5> parsed_args;
   auto r = parser.parse(args, kwargs, parsed_args);
   if (r.idx == 0) {
-    const auto& actual_type = typeWithDefault(r, 1, 2, type);
+    const auto& actual_type = typeWithDefault(r, 1, 2, type, scalar_type);
     return new_with_sizes(actual_type, r.deviceOptional(2), r.intlist(0)).set_requires_grad(r.toBool(3));
   }
   throw std::runtime_error("new_empty(): invalid arguments");
 }
 
-Tensor new_full(const Type& type, PyObject* args, PyObject* kwargs) {
+Tensor new_full(const Type& type, ScalarType scalar_type, PyObject* args, PyObject* kwargs) {
   static PythonArgParser parser({
     "new_full(IntArrayRef size, Scalar fill_value, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
   }, /*traceable=*/true);
@@ -602,13 +616,13 @@ Tensor new_full(const Type& type, PyObject* args, PyObject* kwargs) {
   ParsedArgs<5> parsed_args;
   auto r = parser.parse(args, kwargs, parsed_args);
   if (r.idx == 0) {
-    const auto& actual_type = typeWithDefault(r, 2, 3, type);
+    const auto& actual_type = typeWithDefault(r, 2, 3, type, scalar_type);
     return dispatch_full(actual_type, r.scalar(1), r.deviceOptional(3), r.intlist(0)).set_requires_grad(r.toBool(4));
   }
   throw std::runtime_error("new_full(): invalid arguments");
 }
 
-Tensor new_ones(const Type& type, PyObject* args, PyObject* kwargs) {
+Tensor new_ones(const Type& type, ScalarType scalar_type, PyObject* args, PyObject* kwargs) {
   static PythonArgParser parser({
     "new_ones(IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
   }, /*traceable=*/true);
@@ -616,13 +630,13 @@ Tensor new_ones(const Type& type, PyObject* args, PyObject* kwargs) {
   ParsedArgs<4> parsed_args;
   auto r = parser.parse(args, kwargs, parsed_args);
   if (r.idx == 0) {
-    const auto& actual_type = typeWithDefault(r, 1, 2, type);
+    const auto& actual_type = typeWithDefault(r, 1, 2, type, scalar_type);
     return dispatch_ones(actual_type, r.deviceOptional(2), r.intlist(0)).set_requires_grad(r.toBool(3));
   }
   throw std::runtime_error("new_ones(): invalid arguments");
 }
 
-Tensor new_zeros(const Type& type, PyObject* args, PyObject* kwargs) {
+Tensor new_zeros(const Type& type, ScalarType scalar_type, PyObject* args, PyObject* kwargs) {
   static PythonArgParser parser({
     "new_zeros(IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
   }, /*traceable=*/true);
@@ -630,7 +644,7 @@ Tensor new_zeros(const Type& type, PyObject* args, PyObject* kwargs) {
   ParsedArgs<4> parsed_args;
   auto r = parser.parse(args, kwargs, parsed_args);
   if (r.idx == 0) {
-    const auto& actual_type = typeWithDefault(r, 1, 2, type);
+    const auto& actual_type = typeWithDefault(r, 1, 2, type, scalar_type);
     return dispatch_zeros(actual_type, r.deviceOptional(2), r.intlist(0)).set_requires_grad(r.toBool(3));
   }
   throw std::runtime_error("new_zeros(): invalid arguments");
index 74b1809..451b14f 100644 (file)
@@ -6,19 +6,20 @@
 
 namespace torch { namespace utils {
 
-at::Tensor legacy_tensor_ctor(const at::Type& type, PyObject* args, PyObject* kwargs);
-at::Tensor legacy_tensor_new(const at::Type& type, PyObject* args, PyObject* kwargs);
+at::Tensor legacy_tensor_ctor(const at::Type& type, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
+at::Tensor legacy_tensor_new(const at::Type& type, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
 at::Tensor indexing_tensor_from_data(
     const at::Type& type,
+    at::ScalarType scalar_type,
     c10::optional<at::Device> device,
     PyObject* data);
-at::Tensor sparse_coo_tensor_ctor(const at::Type& type, PyObject* args, PyObject* kwargs);
-at::Tensor tensor_ctor(const at::Type& type, PyObject* args, PyObject* kwargs);
-at::Tensor as_tensor(const at::Type& type, PyObject* args, PyObject* kwargs);
-at::Tensor new_tensor(const at::Type& type, PyObject* args, PyObject* kwargs);
-at::Tensor new_empty(const at::Type& type, PyObject* args, PyObject* kwargs);
-at::Tensor new_full(const at::Type& type, PyObject* args, PyObject* kwargs);
-at::Tensor new_ones(const at::Type& type, PyObject* args, PyObject* kwargs);
-at::Tensor new_zeros(const at::Type& type, PyObject* args, PyObject* kwargs);
+at::Tensor sparse_coo_tensor_ctor(const at::Type& type, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
+at::Tensor tensor_ctor(const at::Type& type, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
+at::Tensor as_tensor(const at::Type& type, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
+at::Tensor new_tensor(const at::Type& type, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
+at::Tensor new_empty(const at::Type& type, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
+at::Tensor new_full(const at::Type& type, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
+at::Tensor new_ones(const at::Type& type, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
+at::Tensor new_zeros(const at::Type& type, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
 
 }} // namespace torch::utils