From ddda563f226391095ff44158afe319701f3f9e09 Mon Sep 17 00:00:00 2001 From: Zachary DeVito Date: Sat, 13 Apr 2019 08:28:11 -0700 Subject: [PATCH] Cleanup ScriptModule bindings (#19138) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19138 ghimport-source-id: 10f810f5e7551c1cb65fc4799744083bd7ffd1ee Reviewed By: jamesr66a Differential Revision: D14886945 Pulled By: zdevito fbshipit-source-id: a5e5bb08694d03166a7516ec038656c2a02e7896 --- test/test_jit.py | 34 +++++++++++++++++----------------- torch/csrc/jit/init.cpp | 7 ++++++- torch/csrc/jit/script/init.cpp | 42 +++++++++++++++--------------------------- 3 files changed, 38 insertions(+), 45 deletions(-) diff --git a/test/test_jit.py b/test/test_jit.py index 56f0fc2..49908a8 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -42,7 +42,7 @@ from common_methods_invocations import method_tests as autograd_method_tests from common_methods_invocations import create_input, unpack_variables, \ exclude_tensor_method, non_differentiable, EXCLUDE_GRADCHECK, EXCLUDE_FUNCTIONAL from torch.testing import FileCheck -from torch._C import TensorType, parse_ir, _propagate_shapes +from torch._C import TensorType, parse_ir, _propagate_shapes, _jit_python_print from copy import deepcopy import random from typing import List, Dict, Optional, Tuple @@ -2642,7 +2642,7 @@ graph(%x : Tensor, def foo(x, y): return 2 * x + y - r, _ = foo._python_print() + r, _ = _jit_python_print(foo) mod = torch.jit.ScriptModule() torch._C._jit_import_methods(mod, "op_version_set = 0\n{}".format(r), []) self.assertExpected(mod.graph.pretty_print()) @@ -2931,7 +2931,7 @@ class TestScript(JitTestCase): def bar(x): return foo(x, y=x) ''') - self.assertTrue('*' in cu.module._get_method('foo').pretty_print_schema()) + self.assertTrue('*' in str(cu.module._get_method('foo').schema)) with self.assertRaisesRegex(RuntimeError, "not provided"): torch.jit.CompilationUnit(''' def foo(x, *, y) -> Tuple[Tensor, Tensor]: @@ -2951,7 +2951,7 @@ class TestScript(JitTestCase): def foo(): return math.pi, 0.1, mod.inf, mod.ninf, 2.225073858507201e-308, mod.nan - pp, table = foo._get_method('forward').python_print() + pp, table = _jit_python_print(foo) ppv = "op_version_set = 0\n{}".format(pp) sm = torch.jit.ScriptModule() torch._C._jit_import_methods(sm, ppv, table) @@ -8350,7 +8350,7 @@ a") return x, x ''') - self.assertExpected(cu.__getattr__('foo').pretty_print_schema()) + self.assertExpected(str(cu.__getattr__('foo').schema)) def test_parser_type_annotations_comment(self): cu = torch.jit.CompilationUnit(''' @@ -8359,7 +8359,7 @@ a") return x, x ''') - self.assertExpected(cu.__getattr__('foo').pretty_print_schema()) + self.assertExpected(str(cu.__getattr__('foo').schema)) def test_parser_type_annotations_unknown_type(self): with self.assertRaisesRegex(RuntimeError, r'Unknown type name Foo'): @@ -9578,7 +9578,7 @@ a") # type: (Tensor, Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tuple[Tensor, Tensor]]) -> Tensor return x - self.assertExpected(foo.__getattr__('forward').pretty_print_schema()) + self.assertExpected(str(foo.__getattr__('forward').schema)) def test_annotated_script_method(self): class SM(torch.jit.ScriptModule): @@ -9589,7 +9589,7 @@ a") sm = SM() - self.assertExpected(sm.__getattr__('forward').pretty_print_schema()) + self.assertExpected(str(sm.__getattr__('forward').schema)) def test_annotated_script_fn_return_mismatch(self): with self.assertRaisesRegex(RuntimeError, "but is actually of type"): @@ -9683,7 +9683,7 @@ a") test_str = [] for pair in self.type_input_return_pairs(): cu = torch.jit.CompilationUnit(self.format_code(code, pair)) - test_str.append(cu.__getattr__('foo').pretty_print_schema()) + test_str.append(str(cu.__getattr__('foo').schema)) self.assertExpected("\n".join(test_str)) # String frontend , Python 3-style type annotations , Script method @@ -9700,7 +9700,7 @@ a") for pair in self.type_input_return_pairs(): tm = TestModule() tm.define(self.format_code(code, pair)) - test_str.append(tm.__getattr__('foo').pretty_print_schema()) + test_str.append(str(tm.__getattr__('foo').schema)) self.assertExpected("\n".join(test_str)) # String frontend , MyPy-style type comments , Script function @@ -9713,7 +9713,7 @@ a") test_str = [] for pair in self.type_input_return_pairs(): cu = torch.jit.CompilationUnit(self.format_code(code, pair)) - test_str.append(cu.__getattr__('foo').pretty_print_schema()) + test_str.append(str(cu.__getattr__('foo').schema)) self.assertExpected("\n".join(test_str)) # String frontend , MyPy-style type comments , Script method @@ -9732,7 +9732,7 @@ a") for pair in self.type_input_return_pairs(): tm = TestModule() tm.define(self.format_code(code, pair)) - test_str.append(tm.__getattr__('foo').pretty_print_schema()) + test_str.append(str(tm.__getattr__('foo').schema)) self.assertExpected("\n".join(test_str)) # Helper function to eval Python3 code without causing a syntax error for @@ -9764,7 +9764,7 @@ a") test_str = [] for pair in self.type_input_return_pairs(): fn = self._get_py3_code(self.format_code(code, pair), 'foo') - test_str.append(fn.__getattr__('forward').pretty_print_schema()) + test_str.append(str(fn.__getattr__('forward').schema)) self.assertExpected("\n".join(test_str)) # Python AST Frontend , Python 3-style type annotations , Script method @@ -9786,7 +9786,7 @@ a") test_str = [] for pair in self.type_input_return_pairs(): fn = self._get_py3_code(self.format_code(code, pair), 'instance') - test_str.append(fn.__getattr__('foo').pretty_print_schema()) + test_str.append(str(fn.__getattr__('foo').schema)) self.assertExpected("\n".join(test_str)) # Python AST Frontend , MyPy-style type comments , Script function @@ -9803,7 +9803,7 @@ a") test_str = [] for pair in self.type_input_return_pairs(): fn = self._get_py3_code(self.format_code(code, pair), 'foo') - test_str.append(fn.__getattr__('forward').pretty_print_schema()) + test_str.append(str(fn.__getattr__('forward').schema)) self.assertExpected("\n".join(test_str)) # Python AST Frontend , MyPy-style type comments , Script method @@ -9822,7 +9822,7 @@ a") test_str = [] for pair in self.type_input_return_pairs(): fn = self._get_py3_code(self.format_code(code, pair), 'instance') - test_str.append(fn.__getattr__('foo').pretty_print_schema()) + test_str.append(str(fn.__getattr__('foo').schema)) self.assertExpected("\n".join(test_str)) def test_method_casts_script(self): @@ -11015,7 +11015,7 @@ a") self.assertEqual(m.some_state, torch.zeros(1) + 100) # Export and ensure ignored code not present - pp, constants = m._python_print() + pp, constants = _jit_python_print(m._get_method('forward')) printed = torch.jit.ScriptModule() ppv = "op_version_set = 0\n{}".format(pp) torch._C._jit_import_methods(printed, ppv, constants) diff --git a/torch/csrc/jit/init.cpp b/torch/csrc/jit/init.cpp index 9bdb9e9..b67cc90 100644 --- a/torch/csrc/jit/init.cpp +++ b/torch/csrc/jit/init.cpp @@ -385,7 +385,12 @@ void initJITBindings(PyObject* module) { .def_property_readonly( "arguments", [](FunctionSchema& self) { return self.arguments(); }) .def_property_readonly( - "returns", [](FunctionSchema& self) { return self.returns(); }); + "returns", [](FunctionSchema& self) { return self.returns(); }) + .def("__str__", [](FunctionSchema& self) { + std::stringstream ss; + ss << self; + return ss.str(); + }); py::class_(m, "Argument") .def_property_readonly("name", [](Argument& self) { return self.name(); }) .def_property_readonly("type", [](Argument& self) { return self.type(); }) diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp index caeb618..cd7ddaa 100644 --- a/torch/csrc/jit/script/init.cpp +++ b/torch/csrc/jit/script/init.cpp @@ -995,15 +995,6 @@ void initJitScriptBindings(PyObject* module) { tuple_slice(std::move(args), 1), std::move(kwargs)); }) - .def( - "_python_print", - [](Module& self) { - std::ostringstream ss; - std::vector tensors; - std::vector classes; - PythonPrint(ss, self, tensors, classes, true); - return std::make_pair(ss.str(), tensors); - }) .def_property_readonly( "code", [](Module& self) { @@ -1064,24 +1055,7 @@ void initJitScriptBindings(PyObject* module) { return self.graph_for(createStackForSchema( self.getSchema(), tuple_slice(std::move(args), 1), kwargs)); }) - .def("schema", &Method::getSchema) - .def( - "pretty_print_schema", - [](Method& m) { - const FunctionSchema& schema = m.getSchema(); - std::stringstream ss; - ss << schema; - return ss.str(); - }) - .def( - "python_print", - [](Method& m) { - std::ostringstream oss; - std::vector constants; - std::vector classes; - PythonPrint(oss, m, constants, classes, true); - return std::make_pair(oss.str(), std::move(constants)); - }) + .def_property_readonly("schema", &Method::getSchema) .def_property_readonly("code", [](Method& self) { std::ostringstream ss; std::vector tensors; @@ -1167,6 +1141,20 @@ void initJitScriptBindings(PyObject* module) { m.def( "_propagate_and_assign_input_and_output_shapes", _propagate_and_assign_input_and_output_shapes); + m.def("_jit_python_print", [](py::object obj) { + std::ostringstream ss; + std::vector constants; + std::vector classes; + if (py::isinstance(obj)) { + auto& self = py::cast(obj); + PythonPrint(ss, self, constants, classes, true); + } else { + auto& m = py::cast(obj); + PythonPrint(ss, m, constants, classes, true); + } + return std::make_pair(ss.str(), std::move(constants)); + }); + py::class_(m, "FileCheck") .def(py::init<>()) .def("check", &testing::FileCheck::check) -- 2.7.4