Customize the printing of namedtuple return (#17136)
authorXiang Gao <qasdfgtyuiop@gmail.com>
Thu, 28 Feb 2019 20:59:34 +0000 (12:59 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 28 Feb 2019 21:07:26 +0000 (13:07 -0800)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/17112
```python
print("good", torch.randn(5,5,5).max(1))
print("terrible", torch.randn(5,5,10).max(1))
print("not as good", torch.randn(5,5,500).max(1))
print ("old behaviour = gold standard")
print(tuple(torch.randn(5,5,5).max(1)))
print(tuple(torch.randn(5,5,10).max(1)))
print(tuple(torch.randn(5,5,500).max(1)))
```
now gives
```
>>> import torch
>>> print("good", torch.randn(5,5,5).max(1))
good torch.return_types.max(
values=tensor([[ 1.2821,  1.8063,  1.8075,  1.3082, -0.1267],
        [ 0.3437,  0.7353,  1.2619,  0.7557,  1.6662],
        [ 0.8583,  1.8906,  1.0246,  1.7598,  1.1184],
        [ 1.7821,  0.0230,  0.9452,  1.0318,  1.0823],
        [ 0.4116, -0.0379, -0.1843,  1.4129,  1.8796]]),
indices=tensor([[4, 4, 3, 2, 1],
        [1, 2, 4, 1, 1],
        [2, 4, 0, 2, 1],
        [0, 2, 0, 3, 1],
        [0, 4, 4, 4, 4]]))
>>> print("terrible", torch.randn(5,5,10).max(1))
terrible torch.return_types.max(
values=tensor([[ 2.1272,  1.3664,  2.2067,  1.3974, -0.0883,  1.2505,  1.0074,  1.1217,
          0.3849,  0.6936],
        [ 0.6288, -0.4560,  1.2748,  1.5482,  1.2777,  1.6874,  0.7151,  0.6041,
          1.3572,  1.6232],
        [ 1.6703,  1.0075,  1.6480,  2.2839,  1.3390,  0.4938,  1.6449,  1.7628,
          0.8141,  2.5714],
        [ 0.7079,  1.8677,  3.2478,  1.5591,  2.4870,  0.8635, -0.1450,  1.6923,
          1.4924,  1.6298],
        [ 2.4056,  0.8002,  0.9317,  0.7455,  0.7866,  2.1191,  0.3492,  1.2095,
          1.8637,  1.7470]]),
indices=tensor([[1, 1, 0, 0, 0, 0, 3, 4, 4, 4],
        [4, 2, 2, 1, 2, 2, 3, 1, 1, 3],
        [0, 3, 3, 0, 2, 1, 4, 1, 0, 1],
        [4, 1, 3, 0, 3, 2, 0, 1, 4, 3],
        [1, 0, 3, 2, 1, 0, 0, 1, 0, 1]]))
>>> print("not as good", torch.randn(5,5,500).max(1))
not as good torch.return_types.max(
values=tensor([[ 0.3877,  0.7873,  1.8701,  ...,  0.5971,  1.6103, -0.3435],
        [ 1.1300,  2.2418,  1.4239,  ...,  1.3943,  0.3872,  1.6475],
        [ 2.0656,  1.3136,  0.9896,  ...,  2.3918,  0.8226,  1.0517],
        [ 1.1054,  0.9945,  1.0561,  ...,  2.1039,  1.1524,  3.0304],
        [ 1.5041,  2.2809,  1.0883,  ...,  0.8504,  2.4774,  1.1041]]),
indices=tensor([[4, 3, 1,  ..., 1, 4, 0],
        [4, 4, 4,  ..., 3, 0, 3],
        [3, 0, 1,  ..., 2, 2, 4],
        [0, 1, 1,  ..., 4, 2, 2],
        [1, 0, 4,  ..., 2, 0, 2]]))
>>> print ("old behaviour = gold standard")
old behaviour = gold standard
>>> print(tuple(torch.randn(5,5,5).max(1)))
(tensor([[ 1.1908,  1.1807,  1.3151,  1.7184,  0.3556],
        [ 0.3798,  0.9213,  0.3001,  1.3087,  2.2419],
        [ 1.4233,  1.4814,  1.9900,  1.7744,  1.3059],
        [ 1.0026, -0.0330,  1.3061,  1.8730,  2.0685],
        [ 1.3041,  1.6458,  1.3449,  1.8948,  3.6206]]), tensor([[0, 4, 3, 4, 0],
        [1, 1, 4, 0, 4],
        [4, 1, 0, 3, 3],
        [1, 2, 1, 4, 0],
        [3, 3, 0, 3, 3]]))
>>> print(tuple(torch.randn(5,5,10).max(1)))
(tensor([[-0.1232,  0.8275,  0.6732,  1.1223,  0.8247,  1.2851,  1.6009,  1.9979,
          1.9109,  0.7313],
        [ 0.2260,  0.5922,  1.6928,  0.6024,  2.1158,  3.0619,  0.5653,  0.7426,
          0.8316,  0.6346],
        [ 0.4319,  0.2231,  0.5255,  1.7620,  1.1657,  0.8875,  0.5782,  0.6506,
          0.5032,  1.7097],
        [ 0.4137,  1.7265,  1.4260,  2.0301,  1.2244,  0.7128,  2.6345,  0.7230,
          1.3553,  1.6508],
        [ 1.0684,  1.7195,  1.4068,  0.7076, -0.0242,  0.8474,  0.8754,  1.7108,
          0.2188,  1.1584]]), tensor([[0, 1, 3, 4, 2, 3, 4, 2, 1, 0],
        [1, 4, 0, 0, 3, 2, 0, 0, 3, 3],
        [2, 3, 1, 1, 4, 0, 1, 4, 4, 4],
        [0, 4, 1, 3, 2, 0, 2, 0, 3, 1],
        [1, 0, 0, 0, 0, 3, 3, 3, 2, 0]]))
>>> print(tuple(torch.randn(5,5,500).max(1)))
(tensor([[0.9395, 1.5572, 1.8797,  ..., 2.0494, 0.8202, 0.9623],
        [1.7937, 0.7225, 1.8836,  ..., 0.7927, 1.4976, 1.1813],
        [0.8558, 1.6943, 1.4192,  ..., 0.8327, 1.9661, 0.4197],
        [1.2993, 1.4995, 0.9357,  ..., 0.7810, 1.3030, 2.6216],
        [1.4206, 1.8315, 1.0338,  ..., 1.4312, 1.3198, 1.5233]]), tensor([[0, 4, 3,  ..., 3, 0, 2],
        [0, 1, 0,  ..., 0, 4, 3],
        [3, 4, 3,  ..., 3, 0, 0],
        [3, 2, 3,  ..., 1, 2, 1],
        [1, 2, 4,  ..., 3, 1, 3]]))
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17136

Differential Revision: D14250021

Pulled By: VitalyFedyunin

fbshipit-source-id: aae72f03b35980063b1ac1f07b8353eddb0c8b93

test/test_torch.py
tools/autograd/gen_python_functions.py
tools/autograd/templates/python_nn_functions.cpp
tools/autograd/templates/python_torch_functions.cpp
tools/autograd/templates/python_variable_methods.cpp
torch/CMakeLists.txt
torch/csrc/utils/six.h
torch/csrc/utils/structseq.cpp [new file with mode: 0644]
torch/csrc/utils/structseq.h [new file with mode: 0644]

index da77809..de1bd82 100644 (file)
@@ -14,6 +14,7 @@ import warnings
 import pickle
 import gzip
 import types
+import textwrap
 import re
 from torch._utils_internal import get_file_path, get_file_path_2
 from torch.utils.dlpack import from_dlpack, to_dlpack
@@ -7228,6 +7229,22 @@ class _TestTorchMixin(object):
         self.assertEqual(tensor.std(), tensor.std(unbiased=True))
         self.assertEqual(tensor.std(unbiased=False), tensor.std(0, unbiased=False))
 
+    def test_structseq_repr(self):
+        a = torch.arange(250).reshape(5, 5, 10)
+        expected = """
+        torch.return_types.max(
+        values=tensor([[ 40,  41,  42,  43,  44,  45,  46,  47,  48,  49],
+                [ 90,  91,  92,  93,  94,  95,  96,  97,  98,  99],
+                [140, 141, 142, 143, 144, 145, 146, 147, 148, 149],
+                [190, 191, 192, 193, 194, 195, 196, 197, 198, 199],
+                [240, 241, 242, 243, 244, 245, 246, 247, 248, 249]]),
+        indices=tensor([[4, 4, 4, 4, 4, 4, 4, 4, 4, 4],
+                [4, 4, 4, 4, 4, 4, 4, 4, 4, 4],
+                [4, 4, 4, 4, 4, 4, 4, 4, 4, 4],
+                [4, 4, 4, 4, 4, 4, 4, 4, 4, 4],
+                [4, 4, 4, 4, 4, 4, 4, 4, 4, 4]]))"""
+        self.assertEqual(repr(a.max(1)), textwrap.dedent(expected).strip())
+
     def test_var_stability(self):
         tensor = torch.FloatTensor([2281.5, 2281.25])
         self.assertEqual(tensor.var(dim=0), 0.03125)
index 7f9ebad..bb48b6a 100644 (file)
@@ -127,6 +127,7 @@ static PyTypeObject type${namedtuple_type_index};
 static bool namedtuple_type_initialized${namedtuple_type_index} = false;
 if (!namedtuple_type_initialized${namedtuple_type_index}) {
   PyStructSequence_InitType(&type${namedtuple_type_index}, &desc${namedtuple_type_index});
+  type${namedtuple_type_index}.tp_repr = (reprfunc)torch::utils::returned_structseq_repr;
   namedtuple_type_initialized${namedtuple_type_index} = true;
 }
 """)
index b964a36..96de550 100644 (file)
@@ -9,6 +9,7 @@
 #include "torch/csrc/autograd/utils/wrap_outputs.h"
 #include "torch/csrc/autograd/utils/python_arg_parsing.h"
 #include "torch/csrc/utils/python_arg_parser.h"
+#include "torch/csrc/utils/structseq.h"
 
 #include "python_nn_functions_dispatch.h"
 
index c0d6c20..f84f4ab 100644 (file)
@@ -21,6 +21,7 @@
 #include "torch/csrc/utils/tensor_numpy.h"
 #include "torch/csrc/jit/tracer.h"
 #include "torch/csrc/autograd/generated/variable_factories.h"
+#include "torch/csrc/utils/structseq.h"
 
 #include <ATen/ATen.h>
 
index a0a8ca2..6957202 100644 (file)
@@ -25,6 +25,7 @@
 #include "torch/csrc/utils/tensor_new.h"
 #include "torch/csrc/utils/tensor_numpy.h"
 #include "torch/csrc/utils/tensor_types.h"
+#include "torch/csrc/utils/structseq.h"
 
 #include <ATen/ATen.h>
 #include "c10/util/Optional.h"
index 62be18c..f123363 100644 (file)
@@ -546,6 +546,7 @@ if (BUILD_PYTHON)
     ${TORCH_SRC_DIR}/csrc/utils/invalid_arguments.cpp
     ${TORCH_SRC_DIR}/csrc/utils/object_ptr.cpp
     ${TORCH_SRC_DIR}/csrc/utils/python_arg_parser.cpp
+    ${TORCH_SRC_DIR}/csrc/utils/structseq.cpp
     ${TORCH_SRC_DIR}/csrc/utils/tensor_apply.cpp
     ${TORCH_SRC_DIR}/csrc/utils/tensor_dtypes.cpp
     ${TORCH_SRC_DIR}/csrc/utils/tensor_flatten.cpp
index 54899d4..957cc20 100644 (file)
@@ -1,6 +1,7 @@
 #pragma once
 
 #include <pybind11/pybind11.h>
+#include "torch/csrc/utils/structseq.h"
 
 namespace six {
 
@@ -24,4 +25,15 @@ inline bool isTuple(PyObject* obj) {
   return isTuple(pybind11::handle(obj));
 }
 
+inline PyObject *toTuple(PyStructSequence *obj) {
+  // create a new tuple object on python 2, or increase
+  // the ref count of the current object on python 3.
+#if PY_MAJOR_VERSION == 2
+  return torch::utils::structseq_slice(obj, 0, Py_SIZE(obj));
+#else
+  Py_INCREF(obj);
+  return (PyObject *)obj;
+#endif
+}
+
 }  // namespace six
diff --git a/torch/csrc/utils/structseq.cpp b/torch/csrc/utils/structseq.cpp
new file mode 100644 (file)
index 0000000..0bf4adb
--- /dev/null
@@ -0,0 +1,107 @@
+/* Copyright Python Software Foundation
+ *
+ * This file is copy-pasted from CPython source code with modifications:
+ * https://github.com/python/cpython/blob/master/Objects/structseq.c
+ * https://github.com/python/cpython/blob/2.7/Objects/structseq.c
+ *
+ * The purpose of this file is to overwrite the default behavior
+ * of repr of structseq to provide better printting for returned
+ * structseq objects from operators, aka torch.return_types.*
+ *
+ * For more information on copyright of CPython, see:
+ * https://github.com/python/cpython#copyright-and-license-information
+ */
+
+#include "torch/csrc/utils/structseq.h"
+#include "torch/csrc/utils/six.h"
+#include "structmember.h"
+#include <sstream>
+
+namespace torch {
+namespace utils {
+
+#if PY_MAJOR_VERSION == 2
+PyObject *structseq_slice(PyStructSequence *obj, Py_ssize_t low, Py_ssize_t high)
+{
+    PyTupleObject *np;
+    Py_ssize_t i;
+
+    if (low < 0) {
+        low = 0;
+    }
+    if (high > Py_SIZE(obj)) {
+        high = Py_SIZE(obj);
+    }
+    if (high < low) {
+        high = low;
+    }
+    np = (PyTupleObject *)PyTuple_New(high-low);
+    if (np == nullptr) {
+        return nullptr;
+    }
+    for(i = low; i < high; ++i) {
+        PyObject *v = obj->ob_item[i];
+        Py_INCREF(v);
+        PyTuple_SET_ITEM(np, i-low, v);
+    }
+    return (PyObject *) np;
+}
+
+#define PyUnicode_AsUTF8 PyString_AsString
+#endif
+
+PyObject *returned_structseq_repr(PyStructSequence *obj) {
+    PyTypeObject *typ = Py_TYPE(obj);
+    PyObject *tup = six::toTuple(obj);
+    if (tup == nullptr) {
+        return nullptr;
+    }
+
+    std::stringstream ss;
+    ss << typ->tp_name << "(\n";
+    size_t num_elements = Py_SIZE(obj);
+
+    for (int i=0; i < num_elements; i++) {
+        PyObject *val, *repr;
+        const char *cname, *crepr;
+
+        cname = typ->tp_members[i].name;
+        if (cname == nullptr) {
+            PyErr_Format(PyExc_SystemError, "In structseq_repr(), member %d name is nullptr"
+                         " for type %.500s", i, typ->tp_name);
+            Py_DECREF(tup);
+            return nullptr;
+        }
+
+        val = PyTuple_GetItem(tup, i);
+        if (val == nullptr) {
+            Py_DECREF(tup);
+            return nullptr;
+        }
+
+        repr = PyObject_Repr(val);
+        if (repr == nullptr) {
+            Py_DECREF(tup);
+            return nullptr;
+        }
+
+        crepr = PyUnicode_AsUTF8(repr);
+        Py_DECREF(repr);
+        if (crepr == nullptr) {
+            Py_DECREF(tup);
+            return nullptr;
+        }
+
+        ss << cname << '=' << crepr;
+        if (i < num_elements - 1) {
+            ss << ",\n";
+        }
+    }
+    ss << ")";
+
+    Py_DECREF(tup);
+    return PyUnicode_FromString(ss.str().c_str());
+}
+
+}
+}
diff --git a/torch/csrc/utils/structseq.h b/torch/csrc/utils/structseq.h
new file mode 100644 (file)
index 0000000..7573242
--- /dev/null
@@ -0,0 +1,13 @@
+#pragma once
+
+#include "torch/csrc/python_headers.h"
+
+namespace torch { namespace utils {
+
+#if PY_MAJOR_VERSION == 2
+PyObject *structseq_slice(PyStructSequence *obj, Py_ssize_t low, Py_ssize_t high);
+#endif
+
+PyObject *returned_structseq_repr(PyStructSequence *obj);
+
+}}