From: Michael Suo Date: Fri, 22 Mar 2019 23:24:36 +0000 (-0700) Subject: python interop for script classes (#18148) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~670 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=10751d5fb44f339ee96873cb8043d5393fd73f8f;p=platform%2Fupstream%2Fpytorch.git python interop for script classes (#18148) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18148 ghimport-source-id: 40a9d745dc9aeba53d098743323fcbd50ca65137 Stack from [ghstack](https://github.com/ezyang/ghstack): * **#18148 py interop** Support for converting classes across the Python–TorchScript boundary. Like other TorchScript values, ScriptClasses are native Python values when used in Python and IValues when used in TorchScript. Notably, there is a copy across this boundary, which will be surprising to users who will expect standard Python reference semantics. I have some ideas for fixing that, but it's a more involved process. Reviewed By: jamesr66a Differential Revision: D14526259 fbshipit-source-id: 5916e3032488a42dc7da756c1826d7c040a21ebd --- diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h index 5cbe64e..9b18a81 100644 --- a/aten/src/ATen/core/jit_type.h +++ b/aten/src/ATen/core/jit_type.h @@ -1151,6 +1151,18 @@ struct CAFFE2_API ClassType : public Type { return attributeTypes_[pos]; } + TypePtr getAttribute(size_t slot) const { + AT_ASSERT(attributeNames_.size() == attributeTypes_.size()); + AT_ASSERT(slot < attributeTypes_.size()); + return attributeTypes_[slot]; + } + + const std::string& getAttributeName(size_t slot) const { + AT_ASSERT(attributeNames_.size() == attributeTypes_.size()); + AT_ASSERT(slot < attributeTypes_.size()); + return attributeNames_[slot]; + } + Method* getMethod(const std::string& name) const; std::vector methods() const; diff --git a/test/test_jit.py b/test/test_jit.py index dcae659..4a9e555 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -14001,7 +14001,7 @@ 
class TestDataParallel(JitTestCase): class TestClassType(JitTestCase): def test_get_with_method(self): @torch.jit.script - class FooTest: + class FooTest(object): def __init__(self, x): self.foo = x @@ -14018,7 +14018,7 @@ class TestClassType(JitTestCase): def test_get_attr(self): @torch.jit.script # noqa: B903 - class FooTest: + class FooTest(object): def __init__(self, x): self.foo = x @@ -14032,7 +14032,7 @@ class TestClassType(JitTestCase): def test_set_attr_in_method(self): @torch.jit.script - class FooTest: + class FooTest(object): def __init__(self, x): # type: (int) -> None self.foo = x @@ -14053,7 +14053,7 @@ class TestClassType(JitTestCase): def test_set_attr_type_mismatch(self): with self.assertRaisesRegex(RuntimeError, "Wrong type for attribute assignment"): @torch.jit.script - class FooTest: + class FooTest(object): def __init__(self, x): self.foo = x self.foo = 10 # should error since int != Tensor @@ -14061,7 +14061,7 @@ class TestClassType(JitTestCase): def test_get_attr_not_initialized(self): with self.assertRaisesRegex(RuntimeError, "Tried to access to nonexistent attribute"): @torch.jit.script - class FooTest: + class FooTest(object): def __init__(self, x): self.foo = x @@ -14071,7 +14071,7 @@ class TestClassType(JitTestCase): def test_set_attr_non_initialized(self): with self.assertRaisesRegex(RuntimeError, "Tried to set nonexistent attribute"): @torch.jit.script - class FooTest: + class FooTest(object): def __init__(self, x): self.foo = x @@ -14081,7 +14081,7 @@ class TestClassType(JitTestCase): def test_type_annotations(self): with self.assertRaisesRegex(RuntimeError, "expected a value of type bool"): @torch.jit.script # noqa: B903 - class FooTest: + class FooTest(object): def __init__(self, x): # type: (bool) -> None self.foo = x @@ -14095,14 +14095,14 @@ class TestClassType(JitTestCase): def test_conditional_set_attr(self): with self.assertRaisesRegex(RuntimeError, "assignment cannot be in a control-flow block"): @torch.jit.script - class 
FooTest: + class FooTest(object): def __init__(self, x): if True: self.attr = x def test_class_type_as_param(self): @torch.jit.script # noqa: B903 - class FooTest: + class FooTest(object): def __init__(self, x): self.attr = x @@ -14121,7 +14121,7 @@ class TestClassType(JitTestCase): def test_out_of_order_methods(self): @torch.jit.script - class FooTest: + class FooTest(object): def __init__(self, x): self.x = x self.x = self.get_stuff(x) @@ -14139,7 +14139,7 @@ class TestClassType(JitTestCase): def test_save_load_with_classes(self): @torch.jit.script - class FooTest: + class FooTest(object): def __init__(self, x): self.x = x @@ -14170,18 +14170,18 @@ class TestClassType(JitTestCase): def test_save_load_with_classes_nested(self): @torch.jit.script # noqa: B903 - class FooNestedTest: + class FooNestedTest(object): def __init__(self, y): self.y = y @torch.jit.script - class FooNestedTest2: + class FooNestedTest2(object): def __init__(self, y): self.y = y self.nested = FooNestedTest(y) @torch.jit.script - class FooTest: + class FooTest(object): def __init__(self, x): self.class_attr = FooNestedTest(x) self.class_attr2 = FooNestedTest2(x) @@ -14209,6 +14209,32 @@ class TestClassType(JitTestCase): output = m_loaded(input) self.assertEqual(2 * input, output) + def test_python_interop(self): + @torch.jit.script # noqa: B903 + class Foo(object): + def __init__(self, x, y): + self.x = x + self.y = y + + @torch.jit.script + def use_foo(foo): + # type: (Foo) -> Foo + return foo + + # create from python + x = torch.ones(2, 3) + y = torch.zeros(2, 3) + f = Foo(x, y) + + self.assertEqual(x, f.x) + self.assertEqual(y, f.y) + + # pass in and out of script + f2 = use_foo(f) + + self.assertEqual(x, f2.x) + self.assertEqual(y, f2.y) + for test in autograd_method_tests(): add_autograd_test(*test) diff --git a/torch/csrc/jit/pybind_utils.h b/torch/csrc/jit/pybind_utils.h index 028ed30..dedc639 100644 --- a/torch/csrc/jit/pybind_utils.h +++ b/torch/csrc/jit/pybind_utils.h @@ -223,11 
+223,27 @@ inline IValue toIValue( } return toIValue(obj, type->expect()->getElementType()); } + case TypeKind::ClassType: { + auto classType = type->expect(); + // 1. create a bare ivalue + const auto name = Symbol::user(classType->name()); + const size_t numAttrs = classType->numAttributes(); + auto userObj = c10::ivalue::Object::create(name, numAttrs); + + // 2. copy all the contained types + for (size_t slot = 0; slot < numAttrs; slot++) { + const auto& attrType = classType->getAttribute(slot); + const auto& attrName = classType->getAttributeName(slot); + + const auto& contained = py::getattr(obj, attrName.c_str()); + userObj->setSlot(slot, toIValue(contained, attrType)); + } + return userObj; + } case TypeKind::NumberType: case TypeKind::GeneratorType: case TypeKind::VarType: case TypeKind::FutureType: - case TypeKind::ClassType: break; } AT_ERROR( @@ -328,6 +344,22 @@ inline py::object toPyObject(IValue&& ivalue) { py_dict[toPyObject(IValue{pair.first})] = toPyObject(IValue{pair.second}); } return std::move(py_dict); + } else if (ivalue.isObject()) { + const auto obj = ivalue.toObject(); + const auto classType = ClassType::get(obj->name().toUnqualString()); + AT_ASSERT(classType); + auto pyClass = py::module::import("torch.jit") + .attr("_get_script_class")(obj->name().toUnqualString()); + auto pyObj = pyClass.attr("__new__")(pyClass); + + + const auto numAttrs = classType->numAttributes(); + + for (size_t slot = 0; slot < numAttrs; slot++) { + const auto& attrName = classType->getAttributeName(slot); + py::setattr(pyObj, attrName.c_str(), toPyObject(obj->getSlot(slot))); + } + return pyObj; } else { AT_ERROR("Missing cases in 'toPyObject'! File a bug report."); } diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp index 5e87027..86ff58b 100644 --- a/torch/csrc/jit/script/init.cpp +++ b/torch/csrc/jit/script/init.cpp @@ -526,15 +526,6 @@ std::shared_ptr toSugaredValue( << "Attempted to inline a Module with parameters. 
" "Stateful modules to be inlined must be submodules of the callee."; } - const auto script_class_type = - py::module::import("torch.jit").attr("ScriptClass"); - const bool is_class_type = py::isinstance(obj, script_class_type); - if (is_class_type) { - const auto classname = py::cast(py::getattr(obj, "_name")); - auto classType = ClassType::get(classname); - AT_ASSERT(classType); - return std::make_shared(std::move(classType)); - } return std::make_shared(mod); } else if (py::isinstance(obj)) { return std::make_shared(obj); diff --git a/torch/csrc/jit/script/parser.cpp b/torch/csrc/jit/script/parser.cpp index 0a3d61e..46a121e 100644 --- a/torch/csrc/jit/script/parser.cpp +++ b/torch/csrc/jit/script/parser.cpp @@ -573,7 +573,12 @@ struct ParserImpl { TreeRef parseClass() { L.expect(TK_CLASS_DEF); const auto name = parseIdent(); - // TODO no inheritance or () allowed right now + if (L.nextIf('(')) { + // The parser only supports py3 syntax, so classes are new-style when + // they don't inherit from anything. + L.reportError( + "Inheritance is not yet supported for TorchScript classes."); + } L.expect(':'); L.expect(TK_INDENT); diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py index c9d9d6f..9d30e88 100644 --- a/torch/jit/__init__.py +++ b/torch/jit/__init__.py @@ -724,17 +724,27 @@ def _try_compile_weak_script(fn): return entry["compiled_fn"] +# ScriptClasses must be new-style classes because we construct them using their +# __new__ method. +def _is_new_style_class(cls): + if hasattr(cls, '__class__'): + return ('__dict__' in dir(cls) or hasattr(cls, '__slots__')) + + def script(obj, optimize=True, _frames_up=0, _rcb=None): if not _enabled: return obj if _rcb is None: _rcb = _jit_internal.createResolutionCallback(_frames_up + 1) + mod = ScriptModule() if inspect.isclass(obj): - mod = ScriptClass(obj.__name__) + if not _is_new_style_class(obj): + raise RuntimeError("TorchScript classes must be new-style classes. 
Please inherit from 'object'") ast = get_jit_class_def(obj) _jit_script_class_compile(mod, ast, _rcb) + _add_script_class(obj, obj.__name__) + return obj else: - mod = ScriptModule() ast = get_jit_def(obj) _jit_script_compile(mod, ast, _rcb, get_default_args(obj)) # Forward docstrings @@ -1292,20 +1302,11 @@ if _enabled: "weak script module once it has been " "created".format(attr)) - class ScriptClass(ScriptModule): - def __init__(self, name): - super(ScriptClass, self).__init__() - self._name = name - else: class ScriptModule(torch.nn.Module): def __init__(self, optimize=True): super(ScriptModule, self).__init__() - class ScriptClass(ScriptModule): - def __init__(self, name): - super(ScriptClass, self).__init__() - def _get_weak_stubs(cls): """ @@ -1544,6 +1545,21 @@ def _find_builtin(fn): _register_builtin(len, 'aten::len') _register_builtin(_wait, 'aten::wait') +_script_classes = {} + + +def _add_script_class(cls, name): + global _script_classes + _script_classes[name] = cls + + +def _get_script_class(name): + global _script_classes + if name not in _script_classes: + raise RuntimeError("Unknown reference to ScriptClass '{}'. " + "Did you forget to import it?".format(name)) + return _script_classes[name] + # torch.jit.Error Error = torch._C.JITException