self.assertEqual((1, 2, 3), a[True, True].shape)
self.assertEqual((1, 2, 3), a[true, true].shape)
+ def test_boolean_list_indexing(self):
+ # Indexing a 2-dimensional array with
+ # boolean lists
+ a = tensor([[1, 2, 3],
+ [4, 5, 6],
+ [7, 8, 9]])
+ b = [True, False, False]
+ c = [True, True, False]
+ self.assertEqual(a[b], tensor([[1, 2, 3]]))
+ self.assertEqual(a[b, b], tensor([1]))
+ self.assertEqual(a[c], tensor([[1, 2, 3], [4, 5, 6]]))
+ self.assertEqual(a[c, c], tensor([1, 5]))
+
def test_everything_returns_views(self):
# Before `...` would return a itself.
a = tensor([5])
static Variable sequenceToVariable(const at::Type& type, PyObject* seq) {
auto& idx_type = type.toScalarType(kLong);
- return torch::utils::legacy_new_from_data(idx_type, c10::nullopt, seq);
+ return torch::utils::indexing_tensor_from_data(idx_type, c10::nullopt, seq);
}
static Variable valueToTensor(const at::Type & type, PyObject* value) {
if (!PySequence_Check(data)) {
throw TypeError("new(): data must be a sequence (got %s)", Py_TYPE(data)->tp_name);
}
- return legacy_new_from_data(type, std::move(device), data);
+ return internal_new_from_data(type, std::move(device), data, false, false, false);
}
void check_legacy_ctor_device(const Type& type, c10::optional<Device> device) {
throw std::runtime_error("new(): invalid arguments");
}
-Tensor legacy_new_from_data(
+Tensor indexing_tensor_from_data(
const Type& type,
c10::optional<Device> device,
PyObject* data) {
- return internal_new_from_data(type, std::move(device), data, false, false, false);
+ // Specific to tensor indexing, converts an indexing list to an
+ // indexing tensor (type Byte or Long)
+ ScalarType scalar_type = infer_scalar_type(data);
+ if (scalar_type == ScalarType::Byte) {
+ auto& idx_type = type.toScalarType(scalar_type);
+ return internal_new_from_data(idx_type, std::move(device), data, false, false, false);
+ } else {
+ return internal_new_from_data(type, std::move(device), data, false, false, false);
+ }
}
Tensor sparse_coo_tensor_ctor(const Type& default_type, PyObject* args, PyObject* kwargs) {
at::Tensor legacy_tensor_ctor(const at::Type& type, PyObject* args, PyObject* kwargs);
at::Tensor legacy_tensor_new(const at::Type& type, PyObject* args, PyObject* kwargs);
-at::Tensor legacy_new_from_data(
+at::Tensor indexing_tensor_from_data(
const at::Type& type,
c10::optional<at::Device> device,
PyObject* data);