Summary:
Fixes https://github.com/pytorch/pytorch/issues/17112
```python
print("good", torch.randn(5,5,5).max(1))
print("terrible", torch.randn(5,5,10).max(1))
print("not as good", torch.randn(5,5,500).max(1))
print ("old behaviour = gold standard")
print(tuple(torch.randn(5,5,5).max(1)))
print(tuple(torch.randn(5,5,10).max(1)))
print(tuple(torch.randn(5,5,500).max(1)))
```
now gives
```
>>> import torch
>>> print("good", torch.randn(5,5,5).max(1))
good torch.return_types.max(
values=tensor([[ 1.2821, 1.8063, 1.8075, 1.3082, -0.1267],
[ 0.3437, 0.7353, 1.2619, 0.7557, 1.6662],
[ 0.8583, 1.8906, 1.0246, 1.7598, 1.1184],
[ 1.7821, 0.0230, 0.9452, 1.0318, 1.0823],
[ 0.4116, -0.0379, -0.1843, 1.4129, 1.8796]]),
indices=tensor([[4, 4, 3, 2, 1],
[1, 2, 4, 1, 1],
[2, 4, 0, 2, 1],
[0, 2, 0, 3, 1],
[0, 4, 4, 4, 4]]))
>>> print("terrible", torch.randn(5,5,10).max(1))
terrible torch.return_types.max(
values=tensor([[ 2.1272, 1.3664, 2.2067, 1.3974, -0.0883, 1.2505, 1.0074, 1.1217,
0.3849, 0.6936],
[ 0.6288, -0.4560, 1.2748, 1.5482, 1.2777, 1.6874, 0.7151, 0.6041,
1.3572, 1.6232],
[ 1.6703, 1.0075, 1.6480, 2.2839, 1.3390, 0.4938, 1.6449, 1.7628,
0.8141, 2.5714],
[ 0.7079, 1.8677, 3.2478, 1.5591, 2.4870, 0.8635, -0.1450, 1.6923,
1.4924, 1.6298],
[ 2.4056, 0.8002, 0.9317, 0.7455, 0.7866, 2.1191, 0.3492, 1.2095,
1.8637, 1.7470]]),
indices=tensor([[1, 1, 0, 0, 0, 0, 3, 4, 4, 4],
[4, 2, 2, 1, 2, 2, 3, 1, 1, 3],
[0, 3, 3, 0, 2, 1, 4, 1, 0, 1],
[4, 1, 3, 0, 3, 2, 0, 1, 4, 3],
[1, 0, 3, 2, 1, 0, 0, 1, 0, 1]]))
>>> print("not as good", torch.randn(5,5,500).max(1))
not as good torch.return_types.max(
values=tensor([[ 0.3877, 0.7873, 1.8701, ..., 0.5971, 1.6103, -0.3435],
[ 1.1300, 2.2418, 1.4239, ..., 1.3943, 0.3872, 1.6475],
[ 2.0656, 1.3136, 0.9896, ..., 2.3918, 0.8226, 1.0517],
[ 1.1054, 0.9945, 1.0561, ..., 2.1039, 1.1524, 3.0304],
[ 1.5041, 2.2809, 1.0883, ..., 0.8504, 2.4774, 1.1041]]),
indices=tensor([[4, 3, 1, ..., 1, 4, 0],
[4, 4, 4, ..., 3, 0, 3],
[3, 0, 1, ..., 2, 2, 4],
[0, 1, 1, ..., 4, 2, 2],
[1, 0, 4, ..., 2, 0, 2]]))
>>> print ("old behaviour = gold standard")
old behaviour = gold standard
>>> print(tuple(torch.randn(5,5,5).max(1)))
(tensor([[ 1.1908, 1.1807, 1.3151, 1.7184, 0.3556],
[ 0.3798, 0.9213, 0.3001, 1.3087, 2.2419],
[ 1.4233, 1.4814, 1.9900, 1.7744, 1.3059],
[ 1.0026, -0.0330, 1.3061, 1.8730, 2.0685],
[ 1.3041, 1.6458, 1.3449, 1.8948, 3.6206]]), tensor([[0, 4, 3, 4, 0],
[1, 1, 4, 0, 4],
[4, 1, 0, 3, 3],
[1, 2, 1, 4, 0],
[3, 3, 0, 3, 3]]))
>>> print(tuple(torch.randn(5,5,10).max(1)))
(tensor([[-0.1232, 0.8275, 0.6732, 1.1223, 0.8247, 1.2851, 1.6009, 1.9979,
1.9109, 0.7313],
[ 0.2260, 0.5922, 1.6928, 0.6024, 2.1158, 3.0619, 0.5653, 0.7426,
0.8316, 0.6346],
[ 0.4319, 0.2231, 0.5255, 1.7620, 1.1657, 0.8875, 0.5782, 0.6506,
0.5032, 1.7097],
[ 0.4137, 1.7265, 1.4260, 2.0301, 1.2244, 0.7128, 2.6345, 0.7230,
1.3553, 1.6508],
[ 1.0684, 1.7195, 1.4068, 0.7076, -0.0242, 0.8474, 0.8754, 1.7108,
0.2188, 1.1584]]), tensor([[0, 1, 3, 4, 2, 3, 4, 2, 1, 0],
[1, 4, 0, 0, 3, 2, 0, 0, 3, 3],
[2, 3, 1, 1, 4, 0, 1, 4, 4, 4],
[0, 4, 1, 3, 2, 0, 2, 0, 3, 1],
[1, 0, 0, 0, 0, 3, 3, 3, 2, 0]]))
>>> print(tuple(torch.randn(5,5,500).max(1)))
(tensor([[0.9395, 1.5572, 1.8797, ..., 2.0494, 0.8202, 0.9623],
[1.7937, 0.7225, 1.8836, ..., 0.7927, 1.4976, 1.1813],
[0.8558, 1.6943, 1.4192, ..., 0.8327, 1.9661, 0.4197],
[1.2993, 1.4995, 0.9357, ..., 0.7810, 1.3030, 2.6216],
[1.4206, 1.8315, 1.0338, ..., 1.4312, 1.3198, 1.5233]]), tensor([[0, 4, 3, ..., 3, 0, 2],
[0, 1, 0, ..., 0, 4, 3],
[3, 4, 3, ..., 3, 0, 0],
[3, 2, 3, ..., 1, 2, 1],
[1, 2, 4, ..., 3, 1, 3]]))
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17136
Differential Revision:
D14250021
Pulled By: VitalyFedyunin
fbshipit-source-id:
aae72f03b35980063b1ac1f07b8353eddb0c8b93
import pickle
import gzip
import types
+import textwrap
import re
from torch._utils_internal import get_file_path, get_file_path_2
from torch.utils.dlpack import from_dlpack, to_dlpack
self.assertEqual(tensor.std(), tensor.std(unbiased=True))
self.assertEqual(tensor.std(unbiased=False), tensor.std(0, unbiased=False))
+ def test_structseq_repr(self):
+ a = torch.arange(250).reshape(5, 5, 10)
+ expected = """
+ torch.return_types.max(
+ values=tensor([[ 40, 41, 42, 43, 44, 45, 46, 47, 48, 49],
+ [ 90, 91, 92, 93, 94, 95, 96, 97, 98, 99],
+ [140, 141, 142, 143, 144, 145, 146, 147, 148, 149],
+ [190, 191, 192, 193, 194, 195, 196, 197, 198, 199],
+ [240, 241, 242, 243, 244, 245, 246, 247, 248, 249]]),
+ indices=tensor([[4, 4, 4, 4, 4, 4, 4, 4, 4, 4],
+ [4, 4, 4, 4, 4, 4, 4, 4, 4, 4],
+ [4, 4, 4, 4, 4, 4, 4, 4, 4, 4],
+ [4, 4, 4, 4, 4, 4, 4, 4, 4, 4],
+ [4, 4, 4, 4, 4, 4, 4, 4, 4, 4]]))"""
+ self.assertEqual(repr(a.max(1)), textwrap.dedent(expected).strip())
+
def test_var_stability(self):
tensor = torch.FloatTensor([2281.5, 2281.25])
self.assertEqual(tensor.var(dim=0), 0.03125)
static bool namedtuple_type_initialized${namedtuple_type_index} = false;
if (!namedtuple_type_initialized${namedtuple_type_index}) {
PyStructSequence_InitType(&type${namedtuple_type_index}, &desc${namedtuple_type_index});
+ type${namedtuple_type_index}.tp_repr = (reprfunc)torch::utils::returned_structseq_repr;
namedtuple_type_initialized${namedtuple_type_index} = true;
}
""")
#include "torch/csrc/autograd/utils/wrap_outputs.h"
#include "torch/csrc/autograd/utils/python_arg_parsing.h"
#include "torch/csrc/utils/python_arg_parser.h"
+#include "torch/csrc/utils/structseq.h"
#include "python_nn_functions_dispatch.h"
#include "torch/csrc/utils/tensor_numpy.h"
#include "torch/csrc/jit/tracer.h"
#include "torch/csrc/autograd/generated/variable_factories.h"
+#include "torch/csrc/utils/structseq.h"
#include <ATen/ATen.h>
#include "torch/csrc/utils/tensor_new.h"
#include "torch/csrc/utils/tensor_numpy.h"
#include "torch/csrc/utils/tensor_types.h"
+#include "torch/csrc/utils/structseq.h"
#include <ATen/ATen.h>
#include "c10/util/Optional.h"
${TORCH_SRC_DIR}/csrc/utils/invalid_arguments.cpp
${TORCH_SRC_DIR}/csrc/utils/object_ptr.cpp
${TORCH_SRC_DIR}/csrc/utils/python_arg_parser.cpp
+ ${TORCH_SRC_DIR}/csrc/utils/structseq.cpp
${TORCH_SRC_DIR}/csrc/utils/tensor_apply.cpp
${TORCH_SRC_DIR}/csrc/utils/tensor_dtypes.cpp
${TORCH_SRC_DIR}/csrc/utils/tensor_flatten.cpp
#pragma once
#include <pybind11/pybind11.h>
+#include "torch/csrc/utils/structseq.h"
namespace six {
return isTuple(pybind11::handle(obj));
}
+inline PyObject *toTuple(PyStructSequence *obj) {
+ // create a new tuple object on python 2, or increase
+ // the ref count of the current object on python 3.
+#if PY_MAJOR_VERSION == 2
+ return torch::utils::structseq_slice(obj, 0, Py_SIZE(obj));
+#else
+ Py_INCREF(obj);
+ return (PyObject *)obj;
+#endif
+}
+
} // namespace six
--- /dev/null
+/* Copyright Python Software Foundation
+ *
+ * This file is copy-pasted from CPython source code with modifications:
+ * https://github.com/python/cpython/blob/master/Objects/structseq.c
+ * https://github.com/python/cpython/blob/2.7/Objects/structseq.c
+ *
+ * The purpose of this file is to overwrite the default behavior
+ * of repr of structseq to provide better printting for returned
+ * structseq objects from operators, aka torch.return_types.*
+ *
+ * For more information on copyright of CPython, see:
+ * https://github.com/python/cpython#copyright-and-license-information
+ */
+
+#include "torch/csrc/utils/structseq.h"
+#include "torch/csrc/utils/six.h"
+#include "structmember.h"
+#include <sstream>
+
+namespace torch {
+namespace utils {
+
+#if PY_MAJOR_VERSION == 2
+PyObject *structseq_slice(PyStructSequence *obj, Py_ssize_t low, Py_ssize_t high)
+{
+ PyTupleObject *np;
+ Py_ssize_t i;
+
+ if (low < 0) {
+ low = 0;
+ }
+ if (high > Py_SIZE(obj)) {
+ high = Py_SIZE(obj);
+ }
+ if (high < low) {
+ high = low;
+ }
+ np = (PyTupleObject *)PyTuple_New(high-low);
+ if (np == nullptr) {
+ return nullptr;
+ }
+ for(i = low; i < high; ++i) {
+ PyObject *v = obj->ob_item[i];
+ Py_INCREF(v);
+ PyTuple_SET_ITEM(np, i-low, v);
+ }
+ return (PyObject *) np;
+}
+
+#define PyUnicode_AsUTF8 PyString_AsString
+#endif
+
+PyObject *returned_structseq_repr(PyStructSequence *obj) {
+ PyTypeObject *typ = Py_TYPE(obj);
+ PyObject *tup = six::toTuple(obj);
+ if (tup == nullptr) {
+ return nullptr;
+ }
+
+ std::stringstream ss;
+ ss << typ->tp_name << "(\n";
+ size_t num_elements = Py_SIZE(obj);
+
+ for (int i=0; i < num_elements; i++) {
+ PyObject *val, *repr;
+ const char *cname, *crepr;
+
+ cname = typ->tp_members[i].name;
+ if (cname == nullptr) {
+ PyErr_Format(PyExc_SystemError, "In structseq_repr(), member %d name is nullptr"
+ " for type %.500s", i, typ->tp_name);
+ Py_DECREF(tup);
+ return nullptr;
+ }
+
+ val = PyTuple_GetItem(tup, i);
+ if (val == nullptr) {
+ Py_DECREF(tup);
+ return nullptr;
+ }
+
+ repr = PyObject_Repr(val);
+ if (repr == nullptr) {
+ Py_DECREF(tup);
+ return nullptr;
+ }
+
+ crepr = PyUnicode_AsUTF8(repr);
+ Py_DECREF(repr);
+ if (crepr == nullptr) {
+ Py_DECREF(tup);
+ return nullptr;
+ }
+
+ ss << cname << '=' << crepr;
+ if (i < num_elements - 1) {
+ ss << ",\n";
+ }
+ }
+ ss << ")";
+
+ Py_DECREF(tup);
+ return PyUnicode_FromString(ss.str().c_str());
+}
+
+}
+}
--- /dev/null
+#pragma once
+
+#include "torch/csrc/python_headers.h"
+
+namespace torch { namespace utils {
+
+#if PY_MAJOR_VERSION == 2
+PyObject *structseq_slice(PyStructSequence *obj, Py_ssize_t low, Py_ssize_t high);
+#endif
+
+PyObject *returned_structseq_repr(PyStructSequence *obj);
+
+}}