From 2e5a8cee82b21b917d3d5cf6b77577f798b596a5 Mon Sep 17 00:00:00 2001
From: Xiang Gao
Date: Thu, 28 Feb 2019 12:59:34 -0800
Subject: [PATCH] Customize the printing of namedtuple return (#17136)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/17112

```python
print("good", torch.randn(5,5,5).max(1))
print("terrible", torch.randn(5,5,10).max(1))
print("not as good", torch.randn(5,5,500).max(1))
print ("old behaviour = gold standard")
print(tuple(torch.randn(5,5,5).max(1)))
print(tuple(torch.randn(5,5,10).max(1)))
print(tuple(torch.randn(5,5,500).max(1)))
```

now gives

```
>>> import torch
>>> print("good", torch.randn(5,5,5).max(1))
good torch.return_types.max(
values=tensor([[ 1.2821,  1.8063,  1.8075,  1.3082, -0.1267],
        [ 0.3437,  0.7353,  1.2619,  0.7557,  1.6662],
        [ 0.8583,  1.8906,  1.0246,  1.7598,  1.1184],
        [ 1.7821,  0.0230,  0.9452,  1.0318,  1.0823],
        [ 0.4116, -0.0379, -0.1843,  1.4129,  1.8796]]),
indices=tensor([[4, 4, 3, 2, 1],
        [1, 2, 4, 1, 1],
        [2, 4, 0, 2, 1],
        [0, 2, 0, 3, 1],
        [0, 4, 4, 4, 4]]))
>>> print("terrible", torch.randn(5,5,10).max(1))
terrible torch.return_types.max(
values=tensor([[ 2.1272,  1.3664,  2.2067,  1.3974, -0.0883,  1.2505,  1.0074,  1.1217,  0.3849,  0.6936],
        [ 0.6288, -0.4560,  1.2748,  1.5482,  1.2777,  1.6874,  0.7151,  0.6041,  1.3572,  1.6232],
        [ 1.6703,  1.0075,  1.6480,  2.2839,  1.3390,  0.4938,  1.6449,  1.7628,  0.8141,  2.5714],
        [ 0.7079,  1.8677,  3.2478,  1.5591,  2.4870,  0.8635, -0.1450,  1.6923,  1.4924,  1.6298],
        [ 2.4056,  0.8002,  0.9317,  0.7455,  0.7866,  2.1191,  0.3492,  1.2095,  1.8637,  1.7470]]),
indices=tensor([[1, 1, 0, 0, 0, 0, 3, 4, 4, 4],
        [4, 2, 2, 1, 2, 2, 3, 1, 1, 3],
        [0, 3, 3, 0, 2, 1, 4, 1, 0, 1],
        [4, 1, 3, 0, 3, 2, 0, 1, 4, 3],
        [1, 0, 3, 2, 1, 0, 0, 1, 0, 1]]))
>>> print("not as good", torch.randn(5,5,500).max(1))
not as good torch.return_types.max(
values=tensor([[ 0.3877,  0.7873,  1.8701,  ...,  0.5971,  1.6103, -0.3435],
        [ 1.1300,  2.2418,  1.4239,  ...,  1.3943,  0.3872,  1.6475],
        [ 2.0656,  1.3136,  0.9896,  ...,  2.3918,  0.8226,  1.0517],
        [ 1.1054,  0.9945,  1.0561,  ...,  2.1039,  1.1524,  3.0304],
        [ 1.5041,  2.2809,  1.0883,  ...,  0.8504,  2.4774,  1.1041]]),
indices=tensor([[4, 3, 1,  ..., 1, 4, 0],
        [4, 4, 4,  ..., 3, 0, 3],
        [3, 0, 1,  ..., 2, 2, 4],
        [0, 1, 1,  ..., 4, 2, 2],
        [1, 0, 4,  ..., 2, 0, 2]]))
>>> print ("old behaviour = gold standard")
old behaviour = gold standard
>>> print(tuple(torch.randn(5,5,5).max(1)))
(tensor([[ 1.1908,  1.1807,  1.3151,  1.7184,  0.3556],
        [ 0.3798,  0.9213,  0.3001,  1.3087,  2.2419],
        [ 1.4233,  1.4814,  1.9900,  1.7744,  1.3059],
        [ 1.0026, -0.0330,  1.3061,  1.8730,  2.0685],
        [ 1.3041,  1.6458,  1.3449,  1.8948,  3.6206]]), tensor([[0, 4, 3, 4, 0],
        [1, 1, 4, 0, 4],
        [4, 1, 0, 3, 3],
        [1, 2, 1, 4, 0],
        [3, 3, 0, 3, 3]]))
>>> print(tuple(torch.randn(5,5,10).max(1)))
(tensor([[-0.1232,  0.8275,  0.6732,  1.1223,  0.8247,  1.2851,  1.6009,  1.9979,  1.9109,  0.7313],
        [ 0.2260,  0.5922,  1.6928,  0.6024,  2.1158,  3.0619,  0.5653,  0.7426,  0.8316,  0.6346],
        [ 0.4319,  0.2231,  0.5255,  1.7620,  1.1657,  0.8875,  0.5782,  0.6506,  0.5032,  1.7097],
        [ 0.4137,  1.7265,  1.4260,  2.0301,  1.2244,  0.7128,  2.6345,  0.7230,  1.3553,  1.6508],
        [ 1.0684,  1.7195,  1.4068,  0.7076, -0.0242,  0.8474,  0.8754,  1.7108,  0.2188,  1.1584]]), tensor([[0, 1, 3, 4, 2, 3, 4, 2, 1, 0],
        [1, 4, 0, 0, 3, 2, 0, 0, 3, 3],
        [2, 3, 1, 1, 4, 0, 1, 4, 4, 4],
        [0, 4, 1, 3, 2, 0, 2, 0, 3, 1],
        [1, 0, 0, 0, 0, 3, 3, 3, 2, 0]]))
>>> print(tuple(torch.randn(5,5,500).max(1)))
(tensor([[0.9395, 1.5572, 1.8797,  ..., 2.0494, 0.8202, 0.9623],
        [1.7937, 0.7225, 1.8836,  ..., 0.7927, 1.4976, 1.1813],
        [0.8558, 1.6943, 1.4192,  ..., 0.8327, 1.9661, 0.4197],
        [1.2993, 1.4995, 0.9357,  ..., 0.7810, 1.3030, 2.6216],
        [1.4206, 1.8315, 1.0338,  ..., 1.4312, 1.3198, 1.5233]]), tensor([[0, 4, 3,  ..., 3, 0, 2],
        [0, 1, 0,  ..., 0, 4, 3],
        [3, 4, 3,  ..., 3, 0, 0],
        [3, 2, 3,  ..., 1, 2, 1],
        [1, 2, 4,  ..., 3, 1, 3]]))
```
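The named return type is still a drop-in replacement for the old plain tuple, so existing call sites keep working unchanged. A minimal sketch (not part of this patch) of that compatibility:

```python
import torch

# Unpacking and positional indexing behave exactly as with the
# old plain tuple; the named fields are additional, not replacing.
result = torch.randn(5, 5, 10).max(1)
values, indices = result                      # tuple-style unpacking
assert torch.equal(values, result.values)     # named field access
assert torch.equal(indices, result.indices)
assert torch.equal(values, result[0])         # positional access
```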
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17136

Differential Revision: D14250021

Pulled By: VitalyFedyunin

fbshipit-source-id: aae72f03b35980063b1ac1f07b8353eddb0c8b93
---
 test/test_torch.py                                 |  17 ++++
 tools/autograd/gen_python_functions.py             |   1 +
 tools/autograd/templates/python_nn_functions.cpp   |   1 +
 .../autograd/templates/python_torch_functions.cpp  |   1 +
 .../autograd/templates/python_variable_methods.cpp |   1 +
 torch/CMakeLists.txt                               |   1 +
 torch/csrc/utils/six.h                             |  12 +++
 torch/csrc/utils/structseq.cpp                     | 107 +++++++++++++++++++++
 torch/csrc/utils/structseq.h                       |  13 +++
 9 files changed, 154 insertions(+)
 create mode 100644 torch/csrc/utils/structseq.cpp
 create mode 100644 torch/csrc/utils/structseq.h

diff --git a/test/test_torch.py b/test/test_torch.py
index da77809..de1bd82 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -14,6 +14,7 @@ import warnings
 import pickle
 import gzip
 import types
+import textwrap
 import re
 from torch._utils_internal import get_file_path, get_file_path_2
 from torch.utils.dlpack import from_dlpack, to_dlpack
@@ -7228,6 +7229,22 @@ class _TestTorchMixin(object):
         self.assertEqual(tensor.std(), tensor.std(unbiased=True))
         self.assertEqual(tensor.std(unbiased=False), tensor.std(0, unbiased=False))
 
+    def test_structseq_repr(self):
+        a = torch.arange(250).reshape(5, 5, 10)
+        expected = """
+        torch.return_types.max(
+        values=tensor([[ 40,  41,  42,  43,  44,  45,  46,  47,  48,  49],
+                [ 90,  91,  92,  93,  94,  95,  96,  97,  98,  99],
+                [140, 141, 142, 143, 144, 145, 146, 147, 148, 149],
+                [190, 191, 192, 193, 194, 195, 196, 197, 198, 199],
+                [240, 241, 242, 243, 244, 245, 246, 247, 248, 249]]),
+        indices=tensor([[4, 4, 4, 4, 4, 4, 4, 4, 4, 4],
+                [4, 4, 4, 4, 4, 4, 4, 4, 4, 4],
+                [4, 4, 4, 4, 4, 4, 4, 4, 4, 4],
+                [4, 4, 4, 4, 4, 4, 4, 4, 4, 4],
+                [4, 4, 4, 4, 4, 4, 4, 4, 4, 4]]))"""
+        self.assertEqual(repr(a.max(1)), textwrap.dedent(expected).strip())
+
     def test_var_stability(self):
         tensor = torch.FloatTensor([2281.5, 2281.25])
         self.assertEqual(tensor.var(dim=0), 0.03125)
diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py
index 7f9ebad..bb48b6a 100644
--- a/tools/autograd/gen_python_functions.py
+++ b/tools/autograd/gen_python_functions.py
@@ -127,6 +127,7 @@ static PyTypeObject type${namedtuple_type_index};
 static bool namedtuple_type_initialized${namedtuple_type_index} = false;
 if (!namedtuple_type_initialized${namedtuple_type_index}) {
   PyStructSequence_InitType(&type${namedtuple_type_index}, &desc${namedtuple_type_index});
+  type${namedtuple_type_index}.tp_repr = (reprfunc)torch::utils::returned_structseq_repr;
   namedtuple_type_initialized${namedtuple_type_index} = true;
 }
 """)
diff --git a/tools/autograd/templates/python_nn_functions.cpp b/tools/autograd/templates/python_nn_functions.cpp
index b964a36..96de550 100644
--- a/tools/autograd/templates/python_nn_functions.cpp
+++ b/tools/autograd/templates/python_nn_functions.cpp
@@ -9,6 +9,7 @@
 #include "torch/csrc/autograd/utils/wrap_outputs.h"
 #include "torch/csrc/autograd/utils/python_arg_parsing.h"
 #include "torch/csrc/utils/python_arg_parser.h"
+#include "torch/csrc/utils/structseq.h"
 
 #include "python_nn_functions_dispatch.h"
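A note on the pattern used by test_structseq_repr above: `textwrap.dedent(expected).strip()` lets the multi-line expected repr be written indented inside the test body. A tiny standalone sketch, using a made-up `point` string purely for illustration:

```python
import textwrap

expected = """
    point(
    x=1,
    y=2)"""
# dedent removes the common leading whitespace; strip() drops the
# newline left behind by the opening triple quote.
assert textwrap.dedent(expected).strip() == "point(\nx=1,\ny=2)"
```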
diff --git a/tools/autograd/templates/python_torch_functions.cpp b/tools/autograd/templates/python_torch_functions.cpp
index c0d6c20..f84f4ab 100644
--- a/tools/autograd/templates/python_torch_functions.cpp
+++ b/tools/autograd/templates/python_torch_functions.cpp
@@ -21,6 +21,7 @@
 #include "torch/csrc/utils/tensor_numpy.h"
 #include "torch/csrc/jit/tracer.h"
 #include "torch/csrc/autograd/generated/variable_factories.h"
+#include "torch/csrc/utils/structseq.h"
 
 #include
diff --git a/tools/autograd/templates/python_variable_methods.cpp b/tools/autograd/templates/python_variable_methods.cpp
index a0a8ca2..6957202 100644
--- a/tools/autograd/templates/python_variable_methods.cpp
+++ b/tools/autograd/templates/python_variable_methods.cpp
@@ -25,6 +25,7 @@
 #include "torch/csrc/utils/tensor_new.h"
 #include "torch/csrc/utils/tensor_numpy.h"
 #include "torch/csrc/utils/tensor_types.h"
+#include "torch/csrc/utils/structseq.h"
 
 #include
 #include "c10/util/Optional.h"
diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt
index 62be18c..f1233635 100644
--- a/torch/CMakeLists.txt
+++ b/torch/CMakeLists.txt
@@ -546,6 +546,7 @@ if (BUILD_PYTHON)
     ${TORCH_SRC_DIR}/csrc/utils/invalid_arguments.cpp
     ${TORCH_SRC_DIR}/csrc/utils/object_ptr.cpp
     ${TORCH_SRC_DIR}/csrc/utils/python_arg_parser.cpp
+    ${TORCH_SRC_DIR}/csrc/utils/structseq.cpp
     ${TORCH_SRC_DIR}/csrc/utils/tensor_apply.cpp
     ${TORCH_SRC_DIR}/csrc/utils/tensor_dtypes.cpp
     ${TORCH_SRC_DIR}/csrc/utils/tensor_flatten.cpp
diff --git a/torch/csrc/utils/six.h b/torch/csrc/utils/six.h
index 54899d4..957cc20 100644
--- a/torch/csrc/utils/six.h
+++ b/torch/csrc/utils/six.h
@@ -1,6 +1,7 @@
 #pragma once
 
 #include <pybind11/pybind11.h>
+#include "torch/csrc/utils/structseq.h"
 
 namespace six {
 
@@ -24,4 +25,15 @@ inline bool isTuple(PyObject* obj) {
   return isTuple(pybind11::handle(obj));
 }
 
+inline PyObject *toTuple(PyStructSequence *obj) {
+  // create a new tuple object on python 2, or increase
+  // the ref count of the current object on python 3.
+#if PY_MAJOR_VERSION == 2
+  return torch::utils::structseq_slice(obj, 0, Py_SIZE(obj));
+#else
+  Py_INCREF(obj);
+  return (PyObject *)obj;
+#endif
+}
+
 } // namespace six
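The version split in six::toTuple above exists because struct sequences subclass tuple only on Python 3; on Python 2 they must be copied into a real tuple before tuple C-API calls can be used on them. A short Python 3 sketch using the built-in struct sequence os.stat_result (not a torch type) to illustrate:

```python
import os

# os.stat_result is a struct sequence, the same CPython mechanism
# behind torch.return_types. On Python 3 it is a tuple subclass,
# so no copy is needed; on Python 2 it is not, hence structseq_slice.
st = os.stat(".")
assert isinstance(st, tuple)   # True on Python 3, False on Python 2
assert tuple(st) == st[:]      # conversion preserves the sequence fields
```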
diff --git a/torch/csrc/utils/structseq.cpp b/torch/csrc/utils/structseq.cpp
new file mode 100644
index 0000000..0bf4adb
--- /dev/null
+++ b/torch/csrc/utils/structseq.cpp
@@ -0,0 +1,107 @@
+/* Copyright Python Software Foundation
+ *
+ * This file is copy-pasted from CPython source code with modifications:
+ * https://github.com/python/cpython/blob/master/Objects/structseq.c
+ * https://github.com/python/cpython/blob/2.7/Objects/structseq.c
+ *
+ * The purpose of this file is to overwrite the default behavior
+ * of repr of structseq to provide better printing for returned
+ * structseq objects from operators, aka torch.return_types.*
+ *
+ * For more information on copyright of CPython, see:
+ * https://github.com/python/cpython#copyright-and-license-information
+ */
+
+#include "torch/csrc/utils/structseq.h"
+#include "torch/csrc/utils/six.h"
+#include "structmember.h"
+#include <sstream>
+
+namespace torch {
+namespace utils {
+
+#if PY_MAJOR_VERSION == 2
+PyObject *structseq_slice(PyStructSequence *obj, Py_ssize_t low, Py_ssize_t high)
+{
+    PyTupleObject *np;
+    Py_ssize_t i;
+
+    if (low < 0) {
+        low = 0;
+    }
+    if (high > Py_SIZE(obj)) {
+        high = Py_SIZE(obj);
+    }
+    if (high < low) {
+        high = low;
+    }
+    np = (PyTupleObject *)PyTuple_New(high-low);
+    if (np == nullptr) {
+        return nullptr;
+    }
+    for(i = low; i < high; ++i) {
+        PyObject *v = obj->ob_item[i];
+        Py_INCREF(v);
+        PyTuple_SET_ITEM(np, i-low, v);
+    }
+    return (PyObject *) np;
+}
+
+#define PyUnicode_AsUTF8 PyString_AsString
+#endif
+
+PyObject *returned_structseq_repr(PyStructSequence *obj) {
+    PyTypeObject *typ = Py_TYPE(obj);
+    PyObject *tup = six::toTuple(obj);
+    if (tup == nullptr) {
+        return nullptr;
+    }
+
+    std::stringstream ss;
+    ss << typ->tp_name << "(\n";
+    size_t num_elements = Py_SIZE(obj);
+
+    for (int i = 0; i < num_elements; i++) {
+        PyObject *val, *repr;
+        const char *cname, *crepr;
+
+        cname = typ->tp_members[i].name;
+        if (cname == nullptr) {
+            PyErr_Format(PyExc_SystemError, "In structseq_repr(), member %d name is nullptr"
+                         " for type %.500s", i, typ->tp_name);
+            Py_DECREF(tup);
+            return nullptr;
+        }
+
+        val = PyTuple_GetItem(tup, i);
+        if (val == nullptr) {
+            Py_DECREF(tup);
+            return nullptr;
+        }
+
+        repr = PyObject_Repr(val);
+        if (repr == nullptr) {
+            Py_DECREF(tup);
+            return nullptr;
+        }
+
+        crepr = PyUnicode_AsUTF8(repr);
+        Py_DECREF(repr);
+        if (crepr == nullptr) {
+            Py_DECREF(tup);
+            return nullptr;
+        }
+
+        ss << cname << '=' << crepr;
+        if (i < num_elements - 1) {
+            ss << ",\n";
+        }
+    }
+    ss << ")";
+
+    Py_DECREF(tup);
+    return PyUnicode_FromString(ss.str().c_str());
+}
+
+}
+}
diff --git a/torch/csrc/utils/structseq.h b/torch/csrc/utils/structseq.h
new file mode 100644
index 0000000..7573242
--- /dev/null
+++ b/torch/csrc/utils/structseq.h
@@ -0,0 +1,13 @@
+#pragma once
+
+#include "torch/csrc/python_headers.h"
+
+namespace torch { namespace utils {
+
+#if PY_MAJOR_VERSION == 2
+PyObject *structseq_slice(PyStructSequence *obj, Py_ssize_t low, Py_ssize_t high);
+#endif
+
+PyObject *returned_structseq_repr(PyStructSequence *obj);
+
+}}
-- 
2.7.4
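For reference, a hedged pure-Python sketch of the repr logic that returned_structseq_repr implements in C: print the type name, then one "name=repr(value)" pair per field, joined by ",\n". The max_result class below is illustrative and not part of this patch; the real types are created in C via PyStructSequence_InitType.

```python
from collections import namedtuple

class max_result(namedtuple("max_result", ["values", "indices"])):
    def __repr__(self):
        # Mirrors returned_structseq_repr: "typename(" followed by
        # one "name=repr(value)" pair per field, separated by ",\n".
        body = ",\n".join(
            "{}={!r}".format(name, value)
            for name, value in zip(self._fields, self)
        )
        return "torch.return_types.max(\n{})".format(body)

print(max_result(values=[1.0, 2.0], indices=[0, 1]))
# torch.return_types.max(
# values=[1.0, 2.0],
# indices=[0, 1])
```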