python interop for script classes (#18148)
authorMichael Suo <suo@fb.com>
Fri, 22 Mar 2019 23:24:36 +0000 (16:24 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 22 Mar 2019 23:30:04 +0000 (16:30 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18148
ghimport-source-id: 40a9d745dc9aeba53d098743323fcbd50ca65137

Stack from [ghstack](https://github.com/ezyang/ghstack):
* **#18148 py interop**

Support for converting classes between the Python–TorchScript boundary. Like other TorchScript values, ScriptClasses are native Python values when used in Python and IValues when used in TorchScript.

Notably, there is a copy across this boundary, which will be surprising to users who will expect standard Python reference semantics. I have some ideas for fixing that, but it's a more involved process.

Reviewed By: jamesr66a

Differential Revision: D14526259

fbshipit-source-id: 5916e3032488a42dc7da756c1826d7c040a21ebd

aten/src/ATen/core/jit_type.h
test/test_jit.py
torch/csrc/jit/pybind_utils.h
torch/csrc/jit/script/init.cpp
torch/csrc/jit/script/parser.cpp
torch/jit/__init__.py

index 5cbe64e..9b18a81 100644 (file)
@@ -1151,6 +1151,18 @@ struct CAFFE2_API ClassType : public Type {
     return attributeTypes_[pos];
   }
 
+  TypePtr getAttribute(size_t slot) const {
+    AT_ASSERT(attributeNames_.size() == attributeTypes_.size());
+    AT_ASSERT(slot < attributeTypes_.size());
+    return attributeTypes_[slot];
+  }
+
+  const std::string& getAttributeName(size_t slot) const {
+    AT_ASSERT(attributeNames_.size() == attributeTypes_.size());
+    AT_ASSERT(slot < attributeTypes_.size());
+    return attributeNames_[slot];
+  }
+
   Method* getMethod(const std::string& name) const;
   std::vector<Method*> methods() const;
 
index dcae659..4a9e555 100644 (file)
@@ -14001,7 +14001,7 @@ class TestDataParallel(JitTestCase):
 class TestClassType(JitTestCase):
     def test_get_with_method(self):
         @torch.jit.script
-        class FooTest:
+        class FooTest(object):
             def __init__(self, x):
                 self.foo = x
 
@@ -14018,7 +14018,7 @@ class TestClassType(JitTestCase):
 
     def test_get_attr(self):
         @torch.jit.script  # noqa: B903
-        class FooTest:
+        class FooTest(object):
             def __init__(self, x):
                 self.foo = x
 
@@ -14032,7 +14032,7 @@ class TestClassType(JitTestCase):
 
     def test_set_attr_in_method(self):
         @torch.jit.script
-        class FooTest:
+        class FooTest(object):
             def __init__(self, x):
                 # type: (int) -> None
                 self.foo = x
@@ -14053,7 +14053,7 @@ class TestClassType(JitTestCase):
     def test_set_attr_type_mismatch(self):
         with self.assertRaisesRegex(RuntimeError, "Wrong type for attribute assignment"):
             @torch.jit.script
-            class FooTest:
+            class FooTest(object):
                 def __init__(self, x):
                     self.foo = x
                     self.foo = 10  # should error since int != Tensor
@@ -14061,7 +14061,7 @@ class TestClassType(JitTestCase):
     def test_get_attr_not_initialized(self):
         with self.assertRaisesRegex(RuntimeError, "Tried to access to nonexistent attribute"):
             @torch.jit.script
-            class FooTest:
+            class FooTest(object):
                 def __init__(self, x):
                     self.foo = x
 
@@ -14071,7 +14071,7 @@ class TestClassType(JitTestCase):
     def test_set_attr_non_initialized(self):
         with self.assertRaisesRegex(RuntimeError, "Tried to set nonexistent attribute"):
             @torch.jit.script
-            class FooTest:
+            class FooTest(object):
                 def __init__(self, x):
                     self.foo = x
 
@@ -14081,7 +14081,7 @@ class TestClassType(JitTestCase):
     def test_type_annotations(self):
         with self.assertRaisesRegex(RuntimeError, "expected a value of type bool"):
             @torch.jit.script  # noqa: B903
-            class FooTest:
+            class FooTest(object):
                 def __init__(self, x):
                     # type: (bool) -> None
                     self.foo = x
@@ -14095,14 +14095,14 @@ class TestClassType(JitTestCase):
     def test_conditional_set_attr(self):
         with self.assertRaisesRegex(RuntimeError, "assignment cannot be in a control-flow block"):
             @torch.jit.script
-            class FooTest:
+            class FooTest(object):
                 def __init__(self, x):
                     if True:
                         self.attr = x
 
     def test_class_type_as_param(self):
         @torch.jit.script  # noqa: B903
-        class FooTest:
+        class FooTest(object):
             def __init__(self, x):
                 self.attr = x
 
@@ -14121,7 +14121,7 @@ class TestClassType(JitTestCase):
 
     def test_out_of_order_methods(self):
         @torch.jit.script
-        class FooTest:
+        class FooTest(object):
             def __init__(self, x):
                 self.x = x
                 self.x = self.get_stuff(x)
@@ -14139,7 +14139,7 @@ class TestClassType(JitTestCase):
 
     def test_save_load_with_classes(self):
         @torch.jit.script
-        class FooTest:
+        class FooTest(object):
             def __init__(self, x):
                 self.x = x
 
@@ -14170,18 +14170,18 @@ class TestClassType(JitTestCase):
 
     def test_save_load_with_classes_nested(self):
         @torch.jit.script  # noqa: B903
-        class FooNestedTest:
+        class FooNestedTest(object):
             def __init__(self, y):
                 self.y = y
 
         @torch.jit.script
-        class FooNestedTest2:
+        class FooNestedTest2(object):
             def __init__(self, y):
                 self.y = y
                 self.nested = FooNestedTest(y)
 
         @torch.jit.script
-        class FooTest:
+        class FooTest(object):
             def __init__(self, x):
                 self.class_attr = FooNestedTest(x)
                 self.class_attr2 = FooNestedTest2(x)
@@ -14209,6 +14209,32 @@ class TestClassType(JitTestCase):
         output = m_loaded(input)
         self.assertEqual(2 * input, output)
 
+    def test_python_interop(self):
+        @torch.jit.script  # noqa: B903
+        class Foo(object):
+            def __init__(self, x, y):
+                self.x = x
+                self.y = y
+
+        @torch.jit.script
+        def use_foo(foo):
+            # type: (Foo) -> Foo
+            return foo
+
+        # create from python
+        x = torch.ones(2, 3)
+        y = torch.zeros(2, 3)
+        f = Foo(x, y)
+
+        self.assertEqual(x, f.x)
+        self.assertEqual(y, f.y)
+
+        # pass in and out of script
+        f2 = use_foo(f)
+
+        self.assertEqual(x, f2.x)
+        self.assertEqual(y, f2.y)
+
 
 for test in autograd_method_tests():
     add_autograd_test(*test)
index 028ed30..dedc639 100644 (file)
@@ -223,11 +223,27 @@ inline IValue toIValue(
       }
       return toIValue(obj, type->expect<OptionalType>()->getElementType());
     }
+    case TypeKind::ClassType: {
+      auto classType = type->expect<ClassType>();
+      // 1. create a bare ivalue
+      const auto name = Symbol::user(classType->name());
+      const size_t numAttrs = classType->numAttributes();
+      auto userObj = c10::ivalue::Object::create(name, numAttrs);
+
+      // 2. copy all the contained types
+      for (size_t slot = 0; slot < numAttrs; slot++) {
+        const auto& attrType = classType->getAttribute(slot);
+        const auto& attrName = classType->getAttributeName(slot);
+
+        const auto& contained = py::getattr(obj, attrName.c_str());
+        userObj->setSlot(slot, toIValue(contained, attrType));
+      }
+      return userObj;
+    }
     case TypeKind::NumberType:
     case TypeKind::GeneratorType:
     case TypeKind::VarType:
     case TypeKind::FutureType:
-    case TypeKind::ClassType:
       break;
   }
   AT_ERROR(
@@ -328,6 +344,22 @@ inline py::object toPyObject(IValue&& ivalue) {
       py_dict[toPyObject(IValue{pair.first})] = toPyObject(IValue{pair.second});
     }
     return std::move(py_dict);
+  } else if (ivalue.isObject()) {
+    const auto obj = ivalue.toObject();
+    const auto classType = ClassType::get(obj->name().toUnqualString());
+    AT_ASSERT(classType);
+    auto pyClass = py::module::import("torch.jit")
+                       .attr("_get_script_class")(obj->name().toUnqualString());
+    auto pyObj = pyClass.attr("__new__")(pyClass);
+
+
+    const auto numAttrs = classType->numAttributes();
+
+    for (size_t slot = 0; slot < numAttrs; slot++) {
+      const auto& attrName = classType->getAttributeName(slot);
+      py::setattr(pyObj, attrName.c_str(), toPyObject(obj->getSlot(slot)));
+    }
+    return pyObj;
   } else {
     AT_ERROR("Missing cases in 'toPyObject'! File a bug report.");
   }
index 5e87027..86ff58b 100644 (file)
@@ -526,15 +526,6 @@ std::shared_ptr<SugaredValue> toSugaredValue(
           << "Attempted to inline a Module with parameters. "
              "Stateful modules to be inlined must be submodules of the callee.";
     }
-    const auto script_class_type =
-        py::module::import("torch.jit").attr("ScriptClass");
-    const bool is_class_type = py::isinstance(obj, script_class_type);
-    if (is_class_type) {
-      const auto classname = py::cast<std::string>(py::getattr(obj, "_name"));
-      auto classType = ClassType::get(classname);
-      AT_ASSERT(classType);
-      return std::make_shared<ClassValue>(std::move(classType));
-    }
     return std::make_shared<ModuleValue>(mod);
   } else if (py::isinstance<py::module>(obj)) {
     return std::make_shared<PythonModuleValue>(obj);
index 0a3d61e..46a121e 100644 (file)
@@ -573,7 +573,12 @@ struct ParserImpl {
   TreeRef parseClass() {
     L.expect(TK_CLASS_DEF);
     const auto name = parseIdent();
-    // TODO no inheritance or () allowed right now
+    if (L.nextIf('(')) {
+      // The parser only supports py3 syntax, so classes are new-style when
+      // they don't inherit from anything.
+      L.reportError(
+          "Inheritance is not yet supported for TorchScript classes yet.");
+    }
     L.expect(':');
 
     L.expect(TK_INDENT);
index c9d9d6f..9d30e88 100644 (file)
@@ -724,17 +724,27 @@ def _try_compile_weak_script(fn):
         return entry["compiled_fn"]
 
 
+# ScriptClasses must be new-style classes because we construct them using their
+# __new__ method.
+def _is_new_style_class(cls):
+    if hasattr(cls, '__class__'):
+        return ('__dict__' in dir(cls) or hasattr(cls, '__slots__'))
+
+
 def script(obj, optimize=True, _frames_up=0, _rcb=None):
     if not _enabled:
         return obj
     if _rcb is None:
         _rcb = _jit_internal.createResolutionCallback(_frames_up + 1)
+    mod = ScriptModule()
     if inspect.isclass(obj):
-        mod = ScriptClass(obj.__name__)
+        if not _is_new_style_class(obj):
+            raise RuntimeError("TorchScript classes must be new-style classes. Please inherit from 'object'")
         ast = get_jit_class_def(obj)
         _jit_script_class_compile(mod, ast, _rcb)
+        _add_script_class(obj, obj.__name__)
+        return obj
     else:
-        mod = ScriptModule()
         ast = get_jit_def(obj)
         _jit_script_compile(mod, ast, _rcb, get_default_args(obj))
     # Forward docstrings
@@ -1292,20 +1302,11 @@ if _enabled:
                                      "weak script module once it has been "
                                      "created".format(attr))
 
-    class ScriptClass(ScriptModule):
-        def __init__(self, name):
-            super(ScriptClass, self).__init__()
-            self._name = name
-
 else:
     class ScriptModule(torch.nn.Module):
         def __init__(self, optimize=True):
             super(ScriptModule, self).__init__()
 
-    class ScriptClass(ScriptModule):
-        def __init__(self, name):
-            super(ScriptClass, self).__init__()
-
 
 def _get_weak_stubs(cls):
     """
@@ -1544,6 +1545,21 @@ def _find_builtin(fn):
 _register_builtin(len, 'aten::len')
 _register_builtin(_wait, 'aten::wait')
 
+_script_classes = {}
+
+
+def _add_script_class(cls, name):
+    global _script_classes
+    _script_classes[name] = cls
+
+
+def _get_script_class(name):
+    global _script_classes
+    if name not in _script_classes:
+        raise RuntimeError("Unknown reference to ScriptClass '{}'. "
+                           "Did you forget to import it?".format(name))
+    return _script_classes[name]
+
 # torch.jit.Error
 Error = torch._C.JITException