return attributeTypes_[pos];
}
+ TypePtr getAttribute(size_t slot) const {
+ AT_ASSERT(attributeNames_.size() == attributeTypes_.size());
+ AT_ASSERT(slot < attributeTypes_.size());
+ return attributeTypes_[slot];
+ }
+
+ const std::string& getAttributeName(size_t slot) const {
+ AT_ASSERT(attributeNames_.size() == attributeTypes_.size());
+ AT_ASSERT(slot < attributeTypes_.size());
+ return attributeNames_[slot];
+ }
+
Method* getMethod(const std::string& name) const;
std::vector<Method*> methods() const;
class TestClassType(JitTestCase):
def test_get_with_method(self):
@torch.jit.script
- class FooTest:
+ class FooTest(object):
def __init__(self, x):
self.foo = x
def test_get_attr(self):
@torch.jit.script # noqa: B903
- class FooTest:
+ class FooTest(object):
def __init__(self, x):
self.foo = x
def test_set_attr_in_method(self):
@torch.jit.script
- class FooTest:
+ class FooTest(object):
def __init__(self, x):
# type: (int) -> None
self.foo = x
def test_set_attr_type_mismatch(self):
with self.assertRaisesRegex(RuntimeError, "Wrong type for attribute assignment"):
@torch.jit.script
- class FooTest:
+ class FooTest(object):
def __init__(self, x):
self.foo = x
self.foo = 10 # should error since int != Tensor
def test_get_attr_not_initialized(self):
with self.assertRaisesRegex(RuntimeError, "Tried to access to nonexistent attribute"):
@torch.jit.script
- class FooTest:
+ class FooTest(object):
def __init__(self, x):
self.foo = x
def test_set_attr_non_initialized(self):
with self.assertRaisesRegex(RuntimeError, "Tried to set nonexistent attribute"):
@torch.jit.script
- class FooTest:
+ class FooTest(object):
def __init__(self, x):
self.foo = x
def test_type_annotations(self):
with self.assertRaisesRegex(RuntimeError, "expected a value of type bool"):
@torch.jit.script # noqa: B903
- class FooTest:
+ class FooTest(object):
def __init__(self, x):
# type: (bool) -> None
self.foo = x
def test_conditional_set_attr(self):
with self.assertRaisesRegex(RuntimeError, "assignment cannot be in a control-flow block"):
@torch.jit.script
- class FooTest:
+ class FooTest(object):
def __init__(self, x):
if True:
self.attr = x
def test_class_type_as_param(self):
@torch.jit.script # noqa: B903
- class FooTest:
+ class FooTest(object):
def __init__(self, x):
self.attr = x
def test_out_of_order_methods(self):
@torch.jit.script
- class FooTest:
+ class FooTest(object):
def __init__(self, x):
self.x = x
self.x = self.get_stuff(x)
def test_save_load_with_classes(self):
@torch.jit.script
- class FooTest:
+ class FooTest(object):
def __init__(self, x):
self.x = x
def test_save_load_with_classes_nested(self):
@torch.jit.script # noqa: B903
- class FooNestedTest:
+ class FooNestedTest(object):
def __init__(self, y):
self.y = y
@torch.jit.script
- class FooNestedTest2:
+ class FooNestedTest2(object):
def __init__(self, y):
self.y = y
self.nested = FooNestedTest(y)
@torch.jit.script
- class FooTest:
+ class FooTest(object):
def __init__(self, x):
self.class_attr = FooNestedTest(x)
self.class_attr2 = FooNestedTest2(x)
output = m_loaded(input)
self.assertEqual(2 * input, output)
+ def test_python_interop(self):
+ @torch.jit.script # noqa: B903
+ class Foo(object):
+ def __init__(self, x, y):
+ self.x = x
+ self.y = y
+
+ @torch.jit.script
+ def use_foo(foo):
+ # type: (Foo) -> Foo
+ return foo
+
+ # create from python
+ x = torch.ones(2, 3)
+ y = torch.zeros(2, 3)
+ f = Foo(x, y)
+
+ self.assertEqual(x, f.x)
+ self.assertEqual(y, f.y)
+
+ # pass in and out of script
+ f2 = use_foo(f)
+
+ self.assertEqual(x, f2.x)
+ self.assertEqual(y, f2.y)
+
for test in autograd_method_tests():
add_autograd_test(*test)
}
return toIValue(obj, type->expect<OptionalType>()->getElementType());
}
+ case TypeKind::ClassType: {
+ auto classType = type->expect<ClassType>();
+ // 1. create a bare ivalue
+ const auto name = Symbol::user(classType->name());
+ const size_t numAttrs = classType->numAttributes();
+ auto userObj = c10::ivalue::Object::create(name, numAttrs);
+
+ // 2. copy all the contained types
+ for (size_t slot = 0; slot < numAttrs; slot++) {
+ const auto& attrType = classType->getAttribute(slot);
+ const auto& attrName = classType->getAttributeName(slot);
+
+ const auto& contained = py::getattr(obj, attrName.c_str());
+ userObj->setSlot(slot, toIValue(contained, attrType));
+ }
+ return userObj;
+ }
case TypeKind::NumberType:
case TypeKind::GeneratorType:
case TypeKind::VarType:
case TypeKind::FutureType:
- case TypeKind::ClassType:
break;
}
AT_ERROR(
py_dict[toPyObject(IValue{pair.first})] = toPyObject(IValue{pair.second});
}
return std::move(py_dict);
+ } else if (ivalue.isObject()) {
+ const auto obj = ivalue.toObject();
+ const auto classType = ClassType::get(obj->name().toUnqualString());
+ AT_ASSERT(classType);
+ auto pyClass = py::module::import("torch.jit")
+ .attr("_get_script_class")(obj->name().toUnqualString());
+ auto pyObj = pyClass.attr("__new__")(pyClass);
+
+
+ const auto numAttrs = classType->numAttributes();
+
+ for (size_t slot = 0; slot < numAttrs; slot++) {
+ const auto& attrName = classType->getAttributeName(slot);
+ py::setattr(pyObj, attrName.c_str(), toPyObject(obj->getSlot(slot)));
+ }
+ return pyObj;
} else {
AT_ERROR("Missing cases in 'toPyObject'! File a bug report.");
}
<< "Attempted to inline a Module with parameters. "
"Stateful modules to be inlined must be submodules of the callee.";
}
- const auto script_class_type =
- py::module::import("torch.jit").attr("ScriptClass");
- const bool is_class_type = py::isinstance(obj, script_class_type);
- if (is_class_type) {
- const auto classname = py::cast<std::string>(py::getattr(obj, "_name"));
- auto classType = ClassType::get(classname);
- AT_ASSERT(classType);
- return std::make_shared<ClassValue>(std::move(classType));
- }
return std::make_shared<ModuleValue>(mod);
} else if (py::isinstance<py::module>(obj)) {
return std::make_shared<PythonModuleValue>(obj);
TreeRef parseClass() {
L.expect(TK_CLASS_DEF);
const auto name = parseIdent();
- // TODO no inheritance or () allowed right now
+ if (L.nextIf('(')) {
+ // The parser only supports py3 syntax, so classes are new-style when
+ // they don't inherit from anything.
+ L.reportError(
+ "Inheritance is not yet supported for TorchScript classes yet.");
+ }
L.expect(':');
L.expect(TK_INDENT);
return entry["compiled_fn"]
+# ScriptClasses must be new-style classes because we construct them using their
+# __new__ method.
+def _is_new_style_class(cls):
+ if hasattr(cls, '__class__'):
+ return ('__dict__' in dir(cls) or hasattr(cls, '__slots__'))
+
+
def script(obj, optimize=True, _frames_up=0, _rcb=None):
if not _enabled:
return obj
if _rcb is None:
_rcb = _jit_internal.createResolutionCallback(_frames_up + 1)
+ mod = ScriptModule()
if inspect.isclass(obj):
- mod = ScriptClass(obj.__name__)
+ if not _is_new_style_class(obj):
+ raise RuntimeError("TorchScript classes must be new-style classes. Please inherit from 'object'")
ast = get_jit_class_def(obj)
_jit_script_class_compile(mod, ast, _rcb)
+ _add_script_class(obj, obj.__name__)
+ return obj
else:
- mod = ScriptModule()
ast = get_jit_def(obj)
_jit_script_compile(mod, ast, _rcb, get_default_args(obj))
# Forward docstrings
"weak script module once it has been "
"created".format(attr))
- class ScriptClass(ScriptModule):
- def __init__(self, name):
- super(ScriptClass, self).__init__()
- self._name = name
-
else:
class ScriptModule(torch.nn.Module):
def __init__(self, optimize=True):
super(ScriptModule, self).__init__()
- class ScriptClass(ScriptModule):
- def __init__(self, name):
- super(ScriptClass, self).__init__()
-
def _get_weak_stubs(cls):
"""
_register_builtin(len, 'aten::len')
_register_builtin(_wait, 'aten::wait')
+_script_classes = {}
+
+
+def _add_script_class(cls, name):
+ global _script_classes
+ _script_classes[name] = cls
+
+
+def _get_script_class(name):
+ global _script_classes
+ if name not in _script_classes:
+ raise RuntimeError("Unknown reference to ScriptClass '{}'. "
+ "Did you forget to import it?".format(name))
+ return _script_classes[name]
+
# torch.jit.Error
Error = torch._C.JITException