user defined types (#17314)
authorMichael Suo <suo@fb.com>
Tue, 26 Feb 2019 09:24:05 +0000 (01:24 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 26 Feb 2019 09:34:07 +0000 (01:34 -0800)
Summary:
First pass at user defined types. The following is contained in this PR:
- `UserType` type, which contains a reference to a module with all methods for the type, and a separate namespace for data attributes (map of name -> TypePtr).
- `UserTypeRegistry`, similar to the operator registry
- `UserObject` which is the runtime representation of the user type (just a map of names -> IValues)
- `UserTypeValue` SugaredValue, to manage getattr and setattr while generating IR, plus compiler.cpp changes to make that work.
- Frontend changes to get `torch.jit.script` to work as a class decorator
- `ClassDef` node in our AST.
- primitive ops for object creation, setattr, and getattr, plus alias analysis changes to make mutation safe.

Things that definitely need to get done:
- Import/export, python_print support
- String frontend doesn't understand class definitions yet
- Python interop (using a user-defined type outside TorchScript) is completely broken
- Static methods (without `self`) don't work

Things that are nice but not essential:
- Method definition shouldn't matter (right now you can only reference a method that's already been defined)
- Class definitions can only contain defs, no other expressions are supported.

Things I definitely won't do initially:
- Polymorphism/inheritance
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17314

Differential Revision: D14194065

Pulled By: suo

fbshipit-source-id: c5434afdb9b39f84b7c85a9fdc2891f8250b5025

28 files changed:
aten/src/ATen/core/interned_strings.h
aten/src/ATen/core/ivalue.cpp
aten/src/ATen/core/ivalue.h
aten/src/ATen/core/jit_type.h
aten/src/ATen/core/type.cpp
test/test_jit.py
tools/build_variables.py
torch/CMakeLists.txt
torch/csrc/jit/import_method.cpp
torch/csrc/jit/ir.cpp
torch/csrc/jit/ir.h
torch/csrc/jit/passes/alias_analysis.cpp
torch/csrc/jit/passes/alias_analysis.h
torch/csrc/jit/passes/dead_code_elimination.cpp
torch/csrc/jit/passes/python_print.cpp
torch/csrc/jit/pybind_utils.h
torch/csrc/jit/register_prim_ops.cpp
torch/csrc/jit/script/compiler.cpp
torch/csrc/jit/script/init.cpp
torch/csrc/jit/script/lexer.h
torch/csrc/jit/script/python_tree_views.cpp
torch/csrc/jit/script/script_type_parser.cpp
torch/csrc/jit/script/sugared_value.cpp
torch/csrc/jit/script/sugared_value.h
torch/csrc/jit/script/tree_views.h
torch/csrc/jit/script/user_type.cpp [new file with mode: 0644]
torch/jit/__init__.py
torch/jit/frontend.py

index 527a604..ed4b768 100644 (file)
@@ -17,6 +17,7 @@ namespace c10 {
   _(namespaces, onnx)              \
   _(namespaces, attr)              \
   _(namespaces, scope)             \
+  _(namespaces, user)              \
   _(namespaces, namespaces)        \
   _(prim, Assign)                  \
   _(prim, BroadcastingChunk)       \
@@ -82,6 +83,9 @@ namespace c10 {
   _(prim, fork)                    \
   _(prim, RaiseException)          \
   _(prim, Function)                \
+  _(prim, CreateUserObject)        \
+  _(prim, SetAttr)                 \
+  _(prim, GetAttr)                 \
   _(aten, append)                  \
   _(aten, format)                  \
   _(aten, __not__)                 \
@@ -151,7 +155,8 @@ namespace c10 {
   _(attr, b)                       \
   _(attr, beg)                     \
   _(attr, idx)                     \
-  _(attr, split)
+  _(attr, split)                   \
+  _(attr, slot)
 #else
 #define FORALL_NS_SYMBOLS(_) \
   _(namespaces, prim)              \
@@ -159,6 +164,7 @@ namespace c10 {
   _(namespaces, onnx)              \
   _(namespaces, attr)              \
   _(namespaces, scope)             \
+  _(namespaces, user)              \
   _(namespaces, namespaces)
 #endif
 
@@ -225,6 +231,7 @@ struct CAFFE2_API Symbol {
   static Symbol aten(const std::string & s);
   static Symbol onnx(const std::string & s);
   static Symbol prim(const std::string & s);
+  static Symbol user(const std::string & s);
   // TODO: eliminate me
   static Symbol scope(const std::string & s);
 
@@ -232,6 +239,7 @@ struct CAFFE2_API Symbol {
   bool is_aten() const;
   bool is_prim() const;
   bool is_onnx() const;
+  bool is_user() const;
 
   // So we can switch on this
   constexpr operator unique_t() const {
@@ -290,10 +298,12 @@ inline Symbol Symbol::aten(const std::string & s)  { return Symbol::fromQualStri
 inline Symbol Symbol::onnx(const std::string & s)  { return Symbol::fromQualString("onnx::" + s); }
 inline Symbol Symbol::prim(const std::string & s)  { return Symbol::fromQualString("prim::" + s); }
 inline Symbol Symbol::scope(const std::string & s) { return Symbol::fromQualString("scope::" + s); }
+inline Symbol Symbol::user(const std::string & s) { return Symbol::fromQualString("user::" + s); }
 inline bool Symbol::is_attr() const { return ns() == namespaces::attr; }
 inline bool Symbol::is_aten() const { return ns() == namespaces::aten; }
 inline bool Symbol::is_prim() const { return ns() == namespaces::prim; }
 inline bool Symbol::is_onnx() const { return ns() == namespaces::onnx; }
+inline bool Symbol::is_user() const { return ns() == namespaces::user; }
 
 } // namespace c10
 
index 86ac4b8..1b61d16 100644 (file)
@@ -93,6 +93,10 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) {
       return out << v.toDevice();
     case IValue::Tag::GenericDict:
       return printDict(out, v.toGenericDict());
+    case IValue::Tag::UserObject:
+      // TODO we should print the object contents
+      return out << "UserObject<" << v.toUserObject()->name().toUnqualString()
+                 << ">";
   }
   AT_ERROR("Tag not found\n");
 }
index 08b2298..533edfa 100644 (file)
@@ -91,7 +91,7 @@ using DoubleList = List<double>;
 using BoolList = List<bool>;
 using GenericList = List<IValue>;
 
-
+struct UserObject;
 }
 
 // IValue is the generic tagged union used by the interpreter to hold
@@ -117,7 +117,8 @@ using GenericList = List<IValue>;
   _(GenericList) \
   _(GenericDict) \
   _(Future) \
-  _(Device)
+  _(Device) \
+  _(UserObject)
 
 struct CAFFE2_API IValue final {
   IValue()
@@ -393,6 +394,18 @@ struct CAFFE2_API IValue final {
     return toIntrusivePtr<ivalue::GenericDict>();
   }
 
+  // UserType
+  IValue(c10::intrusive_ptr<ivalue::UserObject> v);
+  bool isUserObject() const { return tag == Tag::UserObject; }
+  c10::intrusive_ptr<ivalue::UserObject> toUserObject() && {
+    AT_ASSERT(isUserObject());
+    return toIntrusivePtr<ivalue::UserObject>();
+  }
+  c10::intrusive_ptr<ivalue::UserObject> toUserObject() const & {
+    AT_ASSERT(isUserObject());
+    return toIntrusivePtr<ivalue::UserObject>();
+  }
+
   // None
   bool isNone() const {
     return Tag::None == tag;
@@ -665,6 +678,36 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target {
   FutureError error;
 };
 
+// User-defined object.
+struct C10_EXPORT ivalue::UserObject final : c10::intrusive_ptr_target {
+ public:
+  UserObject(Symbol name, size_t numSlots) : typename_(std::move(name)) {
+    slots_.resize(numSlots);
+  }
+
+  static c10::intrusive_ptr<UserObject> create(
+      Symbol name,
+      size_t numSlots) {
+    return c10::make_intrusive<UserObject>(std::move(name), numSlots);
+  }
+
+  void setSlot(size_t slot, IValue v) {
+    slots_[slot] = v;
+  }
+
+  IValue getSlot(size_t slot) const {
+    return slots_.at(slot);
+  }
+
+  Symbol name() const {
+    return typename_;
+  }
+
+ private:
+  const Symbol typename_;
+  std::vector<IValue> slots_;
+};
+
 struct C10_EXPORT ivalue::GenericDict : c10::intrusive_ptr_target {
  private:
   UnorderedMap elements_;
@@ -737,6 +780,7 @@ DEFINE_TO(c10::intrusive_ptr<ivalue::TensorList>, toTensorList)
 DEFINE_TO(c10::intrusive_ptr<ivalue::GenericList>, toGenericList)
 DEFINE_TO(c10::intrusive_ptr<ivalue::GenericDict>, toGenericDict)
 DEFINE_TO(c10::intrusive_ptr<ivalue::ConstantString>, toString)
+DEFINE_TO(c10::intrusive_ptr<ivalue::UserObject>, toUserObject)
 DEFINE_TO(at::Scalar, toScalar)
 DEFINE_TO(std::vector<int64_t>, toIntListRef)
 DEFINE_TO(std::vector<double>, toDoubleListRef)
@@ -808,6 +852,10 @@ inline IValue::IValue(c10::intrusive_ptr<ivalue::GenericDict> v)
 inline IValue::IValue(ivalue::UnorderedMap v)
 : IValue(ivalue::GenericDict::create(std::move(v))) {}
 
+inline IValue::IValue(c10::intrusive_ptr<ivalue::UserObject> v)
+: tag(Tag::UserObject), is_intrusive_ptr(true) {
+  payload.as_intrusive_ptr = v.release();
+}
 inline IValue::IValue(c10::intrusive_ptr<ivalue::Future> v)
 : tag(Tag::Future), is_intrusive_ptr(true) {
   payload.as_intrusive_ptr = v.release();
index 4db6f06..3e2cd3a 100644 (file)
 #include <iostream>
 #include <type_traits>
 
+namespace torch {
+namespace jit {
+namespace script {
+struct Module;
+struct Method;
+}
+} // namespace jit
+} // namespace torch
+
 namespace c10 {
 
 #define C10_FORALL_TYPES(_) \
@@ -35,6 +44,7 @@ _(BoolType) \
 _(OptionalType) \
 _(VarType) \
 _(DeviceObjType) \
+_(UserType) \
 
 enum class TypeKind {
 #define DEFINE_TYPE(T) T,
@@ -1068,4 +1078,111 @@ matchTypeVariables(TypePtr formal, TypePtr actual, TypeEnv& type_env);
 
 CAFFE2_API TypePtr evalTypeVariables(TypePtr type, TypeEnv & type_env);
 
+/**
+ * User Defined Types
+ */
+
+struct UserType;
+using UserTypePtr = std::shared_ptr<UserType>;
+using ::torch::jit::script::Module;
+using ::torch::jit::script::Method;
+
+// This represents a user-defined type in TorchScript.
+struct CAFFE2_API UserType : public Type {
+  // Create a user type and register it globally.
+  static UserTypePtr create(const std::string& name, std::shared_ptr<Module> module);
+  // returns nullptr if there is no type with that name
+  static UserTypePtr get(const std::string& name);
+
+  DEFINE_IS_SUBCLASS(UserType);
+  bool operator==(const Type& rhs) const override {
+    if (auto user_rhs = rhs.cast<UserType>()) {
+      return typename_ == user_rhs->typename_;
+    }
+    return false;
+  }
+
+  bool isSubtypeOf(const TypePtr rhs) const override {
+    // XXX: We do not have inheritance implemented, only types that are the
+    // same can subtype from each other.
+    return *this == *rhs;
+  }
+  std::string str() const override {
+    return std::string("UserType<") + typename_ + ">";
+  }
+
+  TypePtr getAttribute(const std::string& name) const {
+    const auto it = std::find_if(
+        attributes_.cbegin(), attributes_.cend(), [&](const Attribute& attr) {
+          return attr.name == name;
+        });
+    if (it == attributes_.cend()) {
+      return nullptr;
+    }
+
+    return it->type;
+  }
+
+  Method* getMethod(const std::string& name) const;
+
+  std::string name() const {
+    return typename_;
+  }
+
+  size_t numAttributes() const {
+    return attributes_.size();
+  }
+
+  // Attributes are stored in a specific slot at runtime for effiency.
+  // When emitting instructions we specify the slot so that attribute access is
+  // a constant lookup
+  size_t getAttributeSlot(const std::string& name) const {
+    size_t slot = 0;
+    for (const auto& attr : attributes_) {
+      if (name == attr.name) {
+        return slot;
+      }
+      slot++;
+    }
+    throw std::runtime_error("Couldn't find attribute: " + name);
+  }
+
+  bool hasAttribute(const std::string& name) const {
+    return std::find_if(
+               attributes_.cbegin(),
+               attributes_.cend(),
+               [&](const Attribute& attr) { return attr.name == name; }) !=
+        attributes_.cend();
+  }
+
+  void addAttribute(const std::string& name, TypePtr type) {
+    attributes_.emplace_back(name, type);
+  }
+
+  static const TypeKind Kind = TypeKind::UserType;
+
+ private:
+  UserType(std::string name, std::shared_ptr<Module> module)
+      : Type(TypeKind::UserType),
+        typename_(std::move(name)),
+        module_(std::move(module)) {}
+
+  // Name of type (note that this has to be globally unique).
+  std::string typename_;
+
+  // Mapping of attribute names -> their type.
+  // NOTE: this does not contain methods, which are stored in the module
+  // TODO: once modules support arbitrary ivalue attributes, we don't need this
+  // anymore.
+  struct Attribute {
+    Attribute(std::string n, TypePtr t)
+        : name(std::move(n)), type(std::move(t)) {}
+    std::string name;
+    TypePtr type;
+  };
+  std::vector<Attribute> attributes_;
+  // Holds method attributes
+  std::shared_ptr<Module> module_;
+
+};
 } // namespace c10
index e29345c..eec94a2 100644 (file)
@@ -424,4 +424,43 @@ bool Type::isSubtypeOf(const TypePtr rhs) const {
   return *this == *rhs;
 }
 
+namespace {
+class UserTypeRegistry {
+ public:
+  void registerType(std::string name, UserTypePtr type) {
+    std::lock_guard<std::mutex> g(mutex_);
+    // TODO: new type registrations will override the old ones. Is this safe?
+    reg_[name] = type;
+  }
+
+  UserTypePtr getType(const std::string& name) {
+    std::lock_guard<std::mutex> g(mutex_);
+    if (reg_.count(name)) {
+      return reg_.at(name);
+    }
+    return nullptr;
+  }
+
+ private:
+  std::mutex mutex_;
+  std::unordered_map<std::string, UserTypePtr> reg_;
+};
+
+UserTypeRegistry& getRegistry() {
+  static UserTypeRegistry r;
+  return r;
+}
+} // namespace
+
+UserTypePtr UserType::create(
+    const std::string& name,
+    std::shared_ptr<Module> module) {
+  auto ptr = UserTypePtr(new UserType(name, std::move(module)));
+  getRegistry().registerType(name, ptr);
+  return ptr;
+}
+
+UserTypePtr UserType::get(const std::string& name) {
+  return getRegistry().getType(name);
+}
 } // namespace c10
index 0a51256..b63636a 100644 (file)
@@ -4346,21 +4346,21 @@ a")
             assert 1 == 1, "hello"
             return x
 
-        ast = torch.jit.frontend.get_jit_ast(fn, is_method=False)
+        ast = torch.jit.frontend.get_jit_def(fn)
         self.assertExpected(str(ast))
 
     @unittest.skipIf(not PY2, "Requires python 2")
     def test_python_frontend_py2(self):
         def fn():
             raise Exception("hello")
-        ast = torch.jit.frontend.get_jit_ast(fn, is_method=False)
+        ast = torch.jit.frontend.get_jit_def(fn)
         self.assertExpected(str(ast))
 
     @unittest.skipIf(PY2, "Requires python 3")
     def test_python_frontend_py3(self):
         def fn():
             raise Exception("hello")
-        ast = torch.jit.frontend.get_jit_ast(fn, is_method=False)
+        ast = torch.jit.frontend.get_jit_def(fn)
         self.assertExpected(str(ast))
 
     def _make_scalar_vars(self, arr, dtype):
@@ -13434,6 +13434,146 @@ class TestDataParallel(JitTestCase):
         self.assertEqual(first_forward, r1_forward)
 
 
+class TestUserType(JitTestCase):
+    def test_get_with_method(self):
+        # Remove this when import/export is implemented for classes
+        with self.disableModuleHook():
+            @torch.jit.script
+            class Foo:
+                def __init__(self, x):
+                    self.foo = x
+
+                def getFoo(self):
+                    return self.foo
+
+            @torch.jit.script
+            def fn(x):
+                foo = Foo(x)
+                return foo.getFoo()
+
+            input = torch.ones(2, 3)
+            self.assertEqual(fn(input), input)
+
+    def test_get_attr(self):
+        # Remove this when import/export is implemented for classes
+        with self.disableModuleHook():
+            @torch.jit.script
+            class Foo:
+                def __init__(self, x):
+                    self.foo = x
+
+            @torch.jit.script
+            def fn(x):
+                foo = Foo(x)
+                return foo.foo
+
+            input = torch.ones(2, 3)
+            self.assertEqual(fn(input), input)
+
+    def test_set_attr_in_method(self):
+        # Remove this when import/export is implemented for classes
+        with self.disableModuleHook():
+            @torch.jit.script
+            class Foo:
+                def __init__(self, x):
+                    # type: (int)
+                    self.foo = x
+
+                def incFoo(self, y):
+                    # type: (int)
+                    self.foo = self.foo + y
+
+            @torch.jit.script
+            def fn(x):
+                # type: (int)
+                foo = Foo(x)
+                foo.incFoo(2)
+                return foo.foo
+
+            self.assertEqual(fn(1), 3)
+
+    def test_set_attr_type_mismatch(self):
+        # Remove this when import/export is implemented for classes
+        with self.disableModuleHook():
+            with self.assertRaisesRegex(RuntimeError, "Wrong type for attribute assignment"):
+                @torch.jit.script
+                class Foo:
+                    def __init__(self, x):
+                        self.foo = x
+                        self.foo = 10  # should error since int != Tensor
+
+    def test_get_attr_not_initialized(self):
+        # Remove this when import/export is implemented for classes
+        with self.disableModuleHook():
+            with self.assertRaisesRegex(RuntimeError, "Tried to access to nonexistent attribute"):
+                @torch.jit.script
+                class Foo:
+                    def __init__(self, x):
+                        self.foo = x
+
+                    def get_non_initialized(self):
+                        return self.asdf  # asdf isn't an attr
+
+    def test_set_attr_non_initialized(self):
+        # Remove this when import/export is implemented for classes
+        with self.disableModuleHook():
+            with self.assertRaisesRegex(RuntimeError, "Tried to set nonexistent attribute"):
+                @torch.jit.script
+                class Foo:
+                    def __init__(self, x):
+                        self.foo = x
+
+                    def set_non_initialized(self, y):
+                        self.bar = y  # can't assign to non-initialized attr
+
+    def test_type_annotations(self):
+        # Remove this when import/export is implemented for classes
+        with self.disableModuleHook():
+            with self.assertRaisesRegex(RuntimeError, "expected a value of type bool"):
+                @torch.jit.script
+                class Foo:
+                    def __init__(self, x):
+                        # type: (bool)
+                        self.foo = x
+
+                @torch.jit.script
+                def fn(x):
+                    Foo(x)
+
+                fn(2)
+
+    def test_conditional_set_attr(self):
+        # Remove this when import/export is implemented for classes
+        with self.disableModuleHook():
+            with self.assertRaisesRegex(RuntimeError, "assignment cannot be in a control-flow block"):
+                @torch.jit.script
+                class Foo:
+                    def __init__(self, x):
+                        if True:
+                            self.attr = x
+
+    def test_user_type_as_param(self):
+        # Remove this when import/export is implemented for classes
+        with self.disableModuleHook():
+            @torch.jit.script
+            class Foo:
+                def __init__(self, x):
+                    self.attr = x
+
+            @torch.jit.script
+            def fn(foo):
+                # type: (Foo)
+                return foo.attr
+
+            @torch.jit.script
+            def fn2(x):
+                foo = Foo(x)
+                return fn(foo)
+
+            input = torch.ones(1)
+            self.assertEqual(fn2(input), input)
+
+
 for test in autograd_method_tests():
     add_autograd_test(*test)
 
index bebc87b..b73fac2 100644 (file)
@@ -98,6 +98,7 @@ libtorch_sources = [
     "torch/csrc/jit/script/script_type_parser.cpp",
     "torch/csrc/jit/script/sugared_value.cpp",
     "torch/csrc/jit/script/schema_matching.cpp",
+    "torch/csrc/jit/script/user_type.cpp",
     "torch/csrc/jit/script/parser.cpp",
     "torch/csrc/jit/testing/file_check.cpp",
     "torch/csrc/jit/import_method.cpp",
index a39cf45..62be18c 100644 (file)
@@ -177,6 +177,7 @@ set(TORCH_SRCS
   ${TORCH_SRC_DIR}/csrc/jit/script/schema_type_parser.cpp
   ${TORCH_SRC_DIR}/csrc/jit/script/script_type_parser.cpp
   ${TORCH_SRC_DIR}/csrc/jit/script/sugared_value.cpp
+  ${TORCH_SRC_DIR}/csrc/jit/script/user_type.cpp
   ${TORCH_SRC_DIR}/csrc/jit/script/parser.cpp
   ${TORCH_SRC_DIR}/csrc/jit/script/builtin_functions.cpp
   ${TORCH_SRC_DIR}/csrc/jit/script/edit_distance.cpp
index a5a3bc6..029988b 100644 (file)
@@ -23,7 +23,7 @@ struct ModuleAccessorValue : public script::SugaredValue {
       return std::make_shared<script::SimpleValue>(
           m.get_or_add_parameter(v->slot()));
     } else if (script::Method* m = module->find_method(field)) {
-      return std::make_shared<script::MethodValue>(module, *m);
+      return std::make_shared<script::MethodValue>(shared_from_this(), *m);
     } else {
       throw script::ErrorReport(loc) << "unknown attr: " << field;
     }
index 54f31fd..9d636ca 100644 (file)
@@ -836,6 +836,7 @@ bool Node::hasSideEffects() const {
     case prim::IgnoredPythonOp:
     case prim::Print:
     case prim::RaiseException:
+    case prim::SetAttr:
     case aten::warn:
       return true;
   }
@@ -1308,6 +1309,34 @@ Node* Graph::createImplicitTensorToNum(const TypePtr& type, Value* value) {
   return result;
 }
 
+Node* Graph::createUserObject(const UserTypePtr& type) {
+  auto result = create(prim::CreateUserObject);
+  result->output()->setType(type);
+  return result;
+}
+
+Node* Graph::createSetAttr(
+    Value* obj,
+    const std::string& field,
+    Value* newValue) {
+  const auto userType = obj->type()->expect<UserType>();
+
+  auto n = create(prim::SetAttr, {obj, newValue}, /*num_outputs=*/0);
+  n->s_(attr::name, field);
+  return n;
+}
+
+Node* Graph::createGetAttr(Value* obj, const std::string& field) {
+  const auto userType = obj->type()->expect<UserType>();
+
+  auto n = create(prim::GetAttr, {obj}, /*num_outputs=*/1);
+  n->s_(attr::name, field);
+
+  const auto outputType = userType->getAttribute(field);
+  n->output()->setType(outputType);
+  return n;
+}
+
 Node* Graph::createClone(
     Node* n,
     const std::function<Value*(Value*)>& value_map,
index ea5a5d9..6aea973 100644 (file)
@@ -1062,6 +1062,12 @@ struct Graph {
   TORCH_API Node* createDictIndex(Value* dict, Value* index);
   TORCH_API Node* createNumToTensor(Value* value);
   TORCH_API Node* createImplicitTensorToNum(const TypePtr& type, Value* value);
+  TORCH_API Node* createUserObject(const UserTypePtr& type);
+  TORCH_API Node* createSetAttr(
+      Value* obj,
+      const std::string& field,
+      Value* newValue);
+  TORCH_API Node* createGetAttr(Value* obj, const std::string& field);
   Node* createPythonOp(
       THPObjectPtr&& pyobj,
       const std::string& cconv,
index 07375c3..88680d2 100644 (file)
@@ -12,6 +12,7 @@ bool shouldAnnotate(const TypePtr& type) {
       type->kind() == TypeKind::TupleType ||
       type->kind() == TypeKind::DictType || type->kind() == TypeKind::VarType ||
       type->kind() == TypeKind::FutureType ||
+      type->kind() == TypeKind::UserType ||
       (type->kind() == TypeKind::OptionalType &&
        shouldAnnotate(type->cast<OptionalType>()->getElementType()));
 }
@@ -210,6 +211,7 @@ void AliasDb::analyze(const std::shared_ptr<Graph>& graph) {
   std::map<TypeKind, std::vector<Value*>> listTypes;
   std::unordered_map<TupleTypePtr, std::vector<Value*>> tupleTypes;
   std::unordered_map<DictTypePtr, std::vector<Value*>> dictTypes;
+  std::unordered_map<UserTypePtr, std::vector<Value*>> userTypes;
   std::vector<Value*> tensors;
 
   for (auto input : graph->inputs()) {
@@ -235,6 +237,9 @@ void AliasDb::analyze(const std::shared_ptr<Graph>& graph) {
     } else if (inputType->kind() == TypeKind::DictType) {
       auto dictType = inputType->cast<DictType>();
       dictTypes[dictType].push_back(input);
+    } else if (inputType->kind() == TypeKind::UserType) {
+      auto userType = inputType->cast<UserType>();
+      userTypes[userType].push_back(input);
     } else {
       AT_ASSERT(!shouldAnnotate(input));
     }
@@ -250,6 +255,9 @@ void AliasDb::analyze(const std::shared_ptr<Graph>& graph) {
   for (const auto& pr : dictTypes) {
     makeAllAlias(pr.second, *aliasTracker_);
   }
+  for (const auto& pr : userTypes) {
+    makeAllAlias(pr.second, *aliasTracker_);
+  }
   makeAllAlias(tensors, *aliasTracker_);
 
   analyze(graph->block());
@@ -300,6 +308,7 @@ void AliasDb::analyzeImpl(Node* node) {
     case prim::BroadcastSizes:
     case prim::ChunkSizes:
     case prim::Function:
+    case prim::CreateUserObject:
       return analyzeCreator(node);
     case prim::TupleUnpack:
     case prim::TupleIndex:
@@ -307,11 +316,14 @@ void AliasDb::analyzeImpl(Node* node) {
     case prim::TupleSlice:
     case prim::ListUnpack:
     case prim::PythonOp:
+    case prim::GetAttr:
       return analyzeExtractor(node);
     case prim::ConstantChunk:
       return analyzeChunk(node);
     case prim::BroadcastingChunk:
       return analyzeBroadcastingChunk(node);
+    case prim::SetAttr:
+      return analyzeSetAttr(node);
     case aten::add:
     case aten::sub:
     case aten::mul:
@@ -504,7 +516,9 @@ void AliasDb::analyzeCreator(Node* node) {
 // gives up and creates wildcards for everything.
 void AliasDb::analyzeExtractor(Node* node) {
   for (const auto output : node->outputs()) {
-    aliasTracker_->setWildcard(output);
+    if (shouldAnnotate(output)) {
+      aliasTracker_->setWildcard(output);
+    }
   }
 }
 
@@ -581,6 +595,13 @@ void AliasDb::analyzeWait(Node* node) {
   }
 }
 
+// SetAttr: writes to the `self` field
+void AliasDb::analyzeSetAttr(Node* node) {
+  const auto self = node->inputs().at(0);
+  AT_ASSERT(self->type()->kind() == TypeKind::UserType);
+  aliasTracker_->registerWrite(self, node);
+}
+
 // BroadcastingChunk: all inputs are broadcasted, and then individually chunked.
 // This is an intermediate node used only in the graph fuser.
 void AliasDb::analyzeBroadcastingChunk(Node* node) {
@@ -1008,6 +1029,9 @@ TORCH_API bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) {
       prim::ConstantChunk,
       prim::BroadcastingChunk,
       prim::fork,
+      prim::CreateUserObject,
+      prim::GetAttr,
+      prim::SetAttr,
       aten::wait,
       aten::add,
       aten::sub,
index 5228472..508de79 100644 (file)
@@ -119,6 +119,7 @@ class AliasDb {
   void analyzeBroadcastingChunk(Node* node);
   void analyzeFork(Node* node);
   void analyzeWait(Node* node);
+  void analyzeSetAttr(Node* node);
 
   void makeAliasOf(const Value* value, const Value* to);
   void mapAliases(at::ArrayRef<Value*> to, at::ArrayRef<Value*> from);
index e559b84..0673fb6 100644 (file)
@@ -193,7 +193,7 @@ class DeadCodeEliminator {
     if (!aliasDb_) {
       // If we don't have alias information, all mutable ops have unknown
       // effects and can't be considered for elimination.
-      if (!node->kind().is_aten()) {
+      if (!node->kind().is_aten() && !node->kind().is_prim()) {
         return false;
       }
       // onnx export calls EliminateDeadCode but sometimes passes invalid
index 92a9d70..dbab76b 100644 (file)
@@ -839,6 +839,10 @@ struct PythonPrintPass {
             [graph, name, this] { printFunctionDefinition(*graph, name); });
         stmt << "self." << name;
       } break;
+      case prim::CreateUserObject:
+      case prim::SetAttr:
+      case prim::GetAttr:
+        throw std::runtime_error("NYI");
       default: {
         Symbol kind = node->kind();
         if (kind.is_aten()) {
@@ -1082,6 +1086,9 @@ TORCH_API bool printerHasSpecialCaseFor(Symbol sym) {
       prim::TupleSlice,
       prim::TupleUnpack,
       prim::Undefined,
+      prim::CreateUserObject,
+      prim::GetAttr,
+      prim::SetAttr,
   };
 
   // WARNING: by adding a value to this set, you are asserting that your
index 49b48fe..9c29ad2 100644 (file)
@@ -222,6 +222,7 @@ inline IValue toIValue(
     case TypeKind::GeneratorType:
     case TypeKind::VarType:
     case TypeKind::FutureType:
+    case TypeKind::UserType:
       break;
   }
   AT_ERROR(
index 4b091b2..acc94c4 100644 (file)
@@ -816,7 +816,43 @@ RegisterOperators reg({
           }
           return 0;
         }),
-});
+     Operator(
+         prim::CreateUserObject,
+         [](const Node* node) {
+           const auto type = node->output()->type()->expect<UserType>();
+           const auto name = Symbol::user(type->name());
+           const size_t numAttrs = type->numAttributes();
+           return [name, numAttrs](Stack& stack) {
+             auto userObj =
+                 c10::ivalue::UserObject::create(name, numAttrs);
+             push(stack, std::move(userObj));
+             return 0;
+           };
+         }),
+     Operator(
+         prim::GetAttr,
+         [](const Node* node) {
+           const auto type = node->input()->type()->expect<UserType>();
+           const auto& field = node->s(attr::name);
+           const auto slot = type->getAttributeSlot(field);
+           return [slot](Stack& stack) {
+             auto userObj = pop(stack).toUserObject();
+             auto value = userObj->getSlot(slot);
+             push(stack, std::move(value));
+             return 0;
+           };
+         }),
+     Operator(prim::SetAttr, [](const Node* node) {
+       const auto type = node->inputs().at(0)->type()->expect<UserType>();
+       const auto& field = node->s(attr::name);
+       const auto slot = type->getAttributeSlot(field);
+       return [slot](Stack& stack) {
+         auto v = pop(stack);
+         auto userObj = pop(stack).toUserObject();
+         userObj->setSlot(slot, std::move(v));
+         return 0;
+       };
+     })});
 
 // define implementations for primitive number ops
 #define DEFINE_GENERIC_OP(aten_op, int_op, float_op, int_result, float_result) \
index 79e23f5..6c56d6a 100644 (file)
@@ -1,4 +1,5 @@
 #include <torch/csrc/jit/script/compiler.h>
+
 #include <c10/util/Exception.h>
 #include <torch/csrc/jit/hooks_for_testing.h>
 #include <torch/csrc/jit/interpreter.h>
@@ -566,6 +567,7 @@ struct to_ir {
       const SugaredValuePtr& self,
       Block* block) {
     auto schema = extractSchemaFromDef(def, self);
+    // TODO need guards on init returning none
     if (schema.returns().size() == 1) {
       def_stack_.back().declared_return_type_ = schema.returns().at(0).type();
     }
@@ -620,8 +622,9 @@ struct to_ir {
       const SugaredValuePtr& self) {
     auto params_begin = decl.params().begin();
     auto params_end = decl.params().end();
-    if (self)
+    if (self) {
       ++params_begin;
+    }
     std::vector<Argument> retval;
 
     std::vector<Expr> default_types;
@@ -690,7 +693,7 @@ struct to_ir {
   FunctionSchema extractSchemaFromDef(
       const Def& def,
       const SugaredValuePtr& self) {
-    auto name = def.name().name();
+    const auto name = def.name().name();
     std::vector<Argument> args = parseArgsFromDecl(def.decl(), self);
     std::vector<Argument> returns = parseReturnFromDecl(def.decl());
     return FunctionSchema(
@@ -706,8 +709,10 @@ struct to_ir {
     // inputs
     auto it = def.decl().params().begin();
     auto end = def.decl().params().end();
-    auto expected_annotation_size =
-        self ? def.decl().params().size() - 1 : def.decl().params().size();
+    auto expected_annotation_size = def.decl().params().size();
+    if (self) {
+      expected_annotation_size--;
+    }
     if (schema.arguments().size() != expected_annotation_size) {
       throw ErrorReport(def.decl().params().range())
           << "Number of type annotations for"
@@ -715,9 +720,19 @@ struct to_ir {
           << " does not match the number of parameters on the function ("
           << expected_annotation_size << ")!";
     }
+
     if (self) {
       AT_ASSERT(it != end);
-      environment_stack->setSugaredVar(def.range(), (*it).ident().name(), self);
+      const auto& name = (*it).ident().name();
+      if (auto userType = dynamic_cast<UserTypeValue*>(self.get())) {
+        const auto type = userType->type_;
+        Value* new_input =
+            block->addInput()->setUniqueName(name)->setType(type);
+        environment_stack->setVar((*it).ident().range(), name, new_input);
+        arguments.emplace_back(name, type);
+      } else {
+        environment_stack->setSugaredVar(def.range(), name, self);
+      }
       ++it;
     }
     size_t arg_annotation_idx = 0;
@@ -1796,6 +1811,9 @@ struct to_ir {
       case TK_TUPLE_LITERAL:
         emitTupleAssign(TupleLiteral(stmt.lhs()), stmt.rhs());
         break;
+      case '.':
+        emitSelectAssign(stmt);
+        break;
       case TK_SUBSCRIPT:
         emitSubscriptAssign(stmt.range(), Subscript(stmt.lhs()), stmt.rhs());
         break;
@@ -1805,6 +1823,18 @@ struct to_ir {
     }
   }
 
+  void emitSelectAssign(const Assign& stmt) {
+    const auto lhs = Select(stmt.lhs());
+    const auto basename = Var(lhs.value()).name();
+    const auto rhsValue =
+        emitSugaredExpr(stmt.rhs(), 1)->asValue(stmt.rhs().range(), method);
+    auto userObject = environment_stack->getSugaredVar(basename);
+    const bool shouldDefine =
+        method.name() == "__init__" && basename.name() == "self";
+    userObject->setAttr(
+        stmt.range(), method, lhs.selector().name(), rhsValue, shouldDefine);
+  }
+
   NodeKind getNodeKind(int kind, int ninputs) {
     switch (kind) {
       case '+':
index e04477f..8d59d34 100644 (file)
@@ -311,7 +311,7 @@ struct ModuleValue : public SugaredValue {
     if (NamedModule* v = module->find_module(field)) {
       return std::make_shared<ModuleValue>(v->module);
     } else if (Method* v = module->find_method(field)) {
-      return std::make_shared<MethodValue>(module, *v);
+      return std::make_shared<MethodValue>(shared_from_this(), *v);
     } else if (NamedParameter* v = module->find_parameter(field)) {
       return std::make_shared<SimpleValue>(m.get_or_add_parameter(v->slot()));
     }
@@ -524,6 +524,15 @@ std::shared_ptr<SugaredValue> toSugaredValue(
           << "Attempted to inline a Module with parameters. "
              "Stateful modules to be inlined must be submodules of the callee.";
     }
+    const auto script_class_type =
+        py::module::import("torch.jit").attr("ScriptClass");
+    const bool is_user_type = py::isinstance(obj, script_class_type);
+    if (is_user_type) {
+      const auto classname = py::cast<std::string>(py::getattr(obj, "_name"));
+      auto userType = UserType::get(classname);
+      AT_ASSERT(userType);
+      return std::make_shared<UserTypeValue>(std::move(userType));
+    }
     return std::make_shared<ModuleValue>(mod);
   } else if (py::isinstance<py::module>(obj)) {
     return std::make_shared<PythonModuleValue>(obj);
@@ -958,6 +967,26 @@ void initJitScriptBindings(PyObject* module) {
         return mod;
       });
 
+  m.def(
+      "_jit_script_class_compile",
+      [](std::shared_ptr<Module> module,
+         const ClassDef& classDef,
+         ResolutionCallback rcb) {
+        auto userType = UserType::create(classDef.name().name(), module);
+        std::vector<Resolver> rcbs;
+        std::vector<Def> methodDefs;
+        for (const auto& def : classDef.defs()) {
+          methodDefs.push_back(def);
+          rcbs.push_back(pythonResolver(rcb));
+        }
+        defineMethodsInModule(
+            module,
+            methodDefs,
+            rcbs,
+            std::make_shared<UserTypeValue>(userType));
+        return module;
+      });
+
   m.def("parse_type_comment", [](const std::string& comment) {
     Parser p(comment);
     return Decl(p.parseTypeComment());
index fa7a265..aaaa2e8 100644 (file)
@@ -95,7 +95,8 @@ namespace script {
   _(TK_RAISE, "raise", "raise")                  \
   _(TK_ASSERT, "assert", "assert")               \
   _(TK_DOTS, "dots", "...")                      \
-  _(TK_PASS, "pass", "pass")
+  _(TK_PASS, "pass", "pass")                     \
+  _(TK_CLASS_DEF, "class", "class")
 
 static const char* valid_single_char_tokens = "+-*/%@()[]:,={}><.?!&^|";
 
index 2035956..5769916 100644 (file)
@@ -120,6 +120,11 @@ void initTreeViewBindings(PyObject* module) {
         const auto& r = name.range();
         return Def::create(r, name, decl, wrap_list(r, std::move(body)));
       }));
+  py::class_<ClassDef, TreeView>(m, "ClassDef")
+      .def(py::init([](const Ident& name, std::vector<Def> body) {
+        const auto& r = name.range();
+        return ClassDef::create(r, name, wrap_list(r, std::move(body)));
+      }));
   py::class_<Decl, TreeView>(m, "Decl").def(py::init(
       [](const SourceRange& r, std::vector<Param> params, Expr* return_type) {
         return Decl::create(
index aadc0a7..8792178 100644 (file)
@@ -178,6 +178,9 @@ TypePtr parseTypeFromExpr(const Expr& expr) {
     if (itr != ident_to_type_lut().end()) {
       return itr->second;
     }
+    if (auto typePtr = UserType::get(*name)) {
+      return typePtr;
+    }
     throw ErrorReport(expr) << "Unknown type name " << *name;
   }
   throw ErrorReport(expr.range())
index 1ebb98e..b5e4d80 100644 (file)
@@ -73,11 +73,11 @@ std::shared_ptr<SugaredValue> SimpleValue::attr(
     Method& m,
     const std::string& field) {
   // Allow method-style casts on Tensor types. e.g. x.int()
-  if (value->type()->isSubtypeOf(TensorType::get())) {
+  if (value_->type()->isSubtypeOf(TensorType::get())) {
     if (builtin_cast_methods().count(field)) {
       return std::make_shared<BuiltinFunction>(
           Symbol::aten(builtin_cast_methods().at(field)),
-          NamedValue(loc, "self", value));
+          NamedValue(loc, "self", value_));
     }
     // functions that are just direct property lookups on tensor
     // must be registered as prim::<name>(Tensor t) -> <return_type>
@@ -90,31 +90,47 @@ std::shared_ptr<SugaredValue> SimpleValue::attr(
     };
     if (fields.count(field)) {
       auto r =
-          m.graph()->insert(Symbol::fromQualString("prim::" + field), {value});
+          m.graph()->insert(Symbol::fromQualString("prim::" + field), {value_});
       return std::make_shared<SimpleValue>(r);
     }
   }
-  if (getValue()->type()->isSubtypeOf(NumberType::get())) {
+  if (value_->type()->isSubtypeOf(NumberType::get())) {
     throw ErrorReport(loc) << "Cannot call methods on numbers";
   }
-  if (getValue()->type()->kind() == TypeKind::TupleType) {
-    auto tuple_type = getValue()->type()->expect<TupleType>();
+  if (auto tuple_type = value_->type()->cast<TupleType>()) {
     if (!tuple_type->hasNames()) {
       throw ErrorReport(loc) << "Getting attributes of tuples is not supported";
     }
     auto names = tuple_type->names();
-    for (int i = 0; i < names.size(); i++) {
+    for (size_t i = 0; i < names.size(); i++) {
       if (names[i] == field) {
         auto r = m.graph()
-                     ->insertNode(m.graph()->createTupleIndex(getValue(), i))
+                     ->insertNode(m.graph()->createTupleIndex(value_, i))
                      ->output();
         return std::make_shared<SimpleValue>(r);
       }
     }
     throw ErrorReport(loc) << "Unknown attribute to named tuple";
   }
+
+  if (auto userType = value_->type()->cast<UserType>()) {
+    // This is a user-defined type, emit the proper attribute lookup
+    if (auto method = userType->getMethod(field)) {
+      return std::make_shared<MethodValue>(shared_from_this(), *method);
+    }
+
+    if (!userType->hasAttribute(field)) {
+      throw ErrorReport(loc)
+          << "Tried to access to nonexistent attribute " << field
+          << ". Did you forget to initialize it in __init__()?";
+    }
+    auto& g = *m.graph();
+    auto n = g.insertNode(g.createGetAttr(value_, field));
+    return std::make_shared<SimpleValue>(n->output());
+  }
+
   return std::make_shared<BuiltinFunction>(
-      Symbol::aten(field), NamedValue(loc, "self", value));
+      Symbol::aten(field), NamedValue(loc, "self", value_));
 }
 
 std::vector<std::shared_ptr<SugaredValue>> SimpleValue::asTuple(
@@ -125,22 +141,91 @@ std::vector<std::shared_ptr<SugaredValue>> SimpleValue::asTuple(
       [](Value* v) -> std::shared_ptr<SugaredValue> {
     return std::make_shared<SimpleValue>(v);
   };
-  if (value->type()->kind() == TypeKind::TupleType) {
-    auto outputs = createTupleUnpack(value);
+  if (value_->type()->kind() == TypeKind::TupleType) {
+    auto outputs = createTupleUnpack(value_);
     return fmap(outputs, make_simple_value);
-  } else if (value->type()->kind() == TypeKind::ListType) {
+  } else if (value_->type()->kind() == TypeKind::ListType) {
     if (!size_hint) {
       throw ErrorReport(loc)
           << "cannot statically infer the expected size of a list in this context";
     }
-    auto graph = value->owningGraph();
+    auto graph = value_->owningGraph();
     Node* unpack =
-        graph->insertNode(graph->createListUnpack(value, *size_hint));
+        graph->insertNode(graph->createListUnpack(value_, *size_hint));
     return fmap(unpack->outputs(), make_simple_value);
   }
-  throw ErrorReport(loc) << value->type()->str()
+  throw ErrorReport(loc) << value_->type()->str()
                          << " cannot be used as a tuple";
 }
+
+void SimpleValue::setAttr(
+    const SourceRange& loc,
+    Method& m,
+    const std::string& field,
+    Value* newValue,
+    bool shouldDefine) {
+  const auto userType = value_->type()->cast<UserType>();
+  if (!userType) {
+    throw ErrorReport(loc) << "Tried to set an attribute: " << field
+                           << " on a non-user-defined type: "
+                           << value_->type()->str();
+  }
+
+  auto expectedType = userType->getAttribute(field);
+  if (!expectedType) {
+    // We don't have an attribute with this name, either add it to the type
+    // definition or throw an error
+    if (shouldDefine) {
+      userType->addAttribute(field, newValue->type());
+      expectedType = newValue->type();
+      const auto insertPoint = m.graph()->insertPoint();
+      const auto topLevelBlock = m.graph()->block();
+      if (insertPoint->owningBlock() != topLevelBlock) {
+        throw ErrorReport(loc)
+            << "First assignment cannot be in a control-flow block. "
+            << "Initialize the field at the top level first.";
+      }
+    } else {
+      throw ErrorReport(loc)
+          << "Tried to set nonexistent attribute: " << field
+          << ". Did you forget to initialize it in __init__()?";
+    }
+  }
+
+  // Check type correctness
+  const auto newType = newValue->type();
+  if (!newType->isSubtypeOf(expectedType)) {
+    throw ErrorReport(loc) << "Wrong type for attribute assignment. Expected "
+                           << expectedType->str() << " but got "
+                           << newType->str();
+  }
+
+  auto& g = *m.graph();
+  g.insertNode(g.createSetAttr(value_, field, newValue));
+}
+
+std::shared_ptr<SugaredValue> UserTypeValue::call(
+    const SourceRange& loc,
+    Method& m,
+    // note: names for args will be 'argument 0', 'argument 1', etc..
+    at::ArrayRef<NamedValue> inputs,
+    at::ArrayRef<NamedValue> attributes,
+    size_t n_binders) {
+  AT_ASSERT(n_binders <= 1);
+
+  // Generate a new object of the right type, then call `__init__` on it
+  auto& g = *m.graph();
+  auto createNode = g.insertNode(g.createUserObject(type_));
+  auto self = std::make_shared<SimpleValue>(createNode->output());
+
+  auto initMethod = type_->getMethod("__init__");
+  AT_ASSERT(initMethod);
+
+  // Call the init function
+  MethodValue(self, *initMethod).call(loc, m, inputs, attributes, n_binders);
+
+  return self;
+}
 } // namespace script
 } // namespace jit
 } // namespace torch
index 5f998b1..f80b183 100644 (file)
@@ -39,6 +39,17 @@ struct SugaredValue : public std::enable_shared_from_this<SugaredValue> {
       const std::string& field) {
     throw ErrorReport(loc) << "attribute lookup is not defined on " << kind();
   }
+
+  // assign an attribute on it, e.g. `this.field = newValue`
+  virtual void setAttr(
+      const SourceRange& loc,
+      Method& m,
+      const std::string& field,
+      Value* newValue,
+      bool shouldDefine) {
+    throw ErrorReport(loc) << "attribute assignment is not defined on "
+                           << kind();
+  }
   virtual NoneStatus isNone() {
     return NEVER;
   }
@@ -83,17 +94,17 @@ struct SugaredValue : public std::enable_shared_from_this<SugaredValue> {
 // most things in the environment are just simple value types
 // and not special python syntax sugar types
 struct TORCH_API SimpleValue : public SugaredValue {
-  SimpleValue(Value* value) : value(value) {}
+  SimpleValue(Value* value) : value_(value) {}
   std::string kind() const override {
     return "value";
   }
   Value* asValue(const SourceRange& range, Method& m) override {
-    return value;
+    return value_;
   }
   NoneStatus isNone() override {
-    if (value->mustBeNone())
+    if (value_->mustBeNone())
       return ALWAYS;
-    else if (value->type()->cast<OptionalType>())
+    else if (value_->type()->cast<OptionalType>())
       return MAYBE;
     else
       return NEVER;
@@ -106,12 +117,20 @@ struct TORCH_API SimpleValue : public SugaredValue {
       const SourceRange& loc,
       Method& m,
       const std::string& field) override;
+
+  void setAttr(
+      const SourceRange& loc,
+      Method& m,
+      const std::string& field,
+      Value* newValue,
+      bool shouldDefine) override;
+
   Value* getValue() const {
-    return value;
+    return value_;
   }
 
  private:
-  Value* value;
+  Value* value_;
 };
 
 struct TORCH_API BuiltinFunction : public SugaredValue {
@@ -157,12 +176,30 @@ struct TORCH_API BuiltinModule : public SugaredValue {
   c10::optional<int64_t> version;
 };
 
+// Represents a user type, analagous to `int` or `dict`
+struct TORCH_API UserTypeValue : public SugaredValue {
+  UserTypeValue(UserTypePtr type) : type_(std::move(type)) {}
+
+  // Call the type's constructor, as in:
+  //    n = Foo(constructor_arg)
+  std::shared_ptr<SugaredValue> call(
+      const SourceRange& loc,
+      Method& m,
+      at::ArrayRef<NamedValue> inputs,
+      at::ArrayRef<NamedValue> attributes,
+      size_t n_binders) override;
+
+  std::string kind() const override {
+    return type_->str();
+  }
+
+  UserTypePtr type_;
+};
+
 // defines how a method obtained from a module behaves in script
 struct MethodValue : public SugaredValue {
-  MethodValue(std::shared_ptr<Module> module, Method& method)
-      : module(std::move(module)) // insurance that method stays alive
-        ,
-        method(method) {}
+  MethodValue(std::shared_ptr<SugaredValue> self, Method& method)
+      : self_(std::move(self)), method(method) {}
   std::string kind() const override {
     return "method";
   }
@@ -172,12 +209,22 @@ struct MethodValue : public SugaredValue {
       at::ArrayRef<NamedValue> inputs,
       at::ArrayRef<NamedValue> attributes,
       size_t n_binders) override {
+    if (auto userType = dynamic_cast<SimpleValue*>(self_.get())) {
+      // If self_ is a user-defined type, then it will be expected as part of
+      // the schema. Add it to the front of the inputs.
+      std::vector<NamedValue> inputsWithSelf;
+      inputsWithSelf.emplace_back(loc, userType->getValue());
+      inputsWithSelf.insert(inputsWithSelf.end(), inputs.begin(), inputs.end());
+      return std::make_shared<SimpleValue>(
+          caller.emit_call_to(loc, method, inputsWithSelf, attributes));
+    }
+
     return std::make_shared<SimpleValue>(
         caller.emit_call_to(loc, method, inputs, attributes));
   }
 
  private:
-  std::shared_ptr<Module> module;
+  std::shared_ptr<SugaredValue> self_;
   Method& method;
 };
 
index 853a32f..a560a9e 100644 (file)
@@ -22,6 +22,7 @@ namespace script {
 //
 // Decl  = Decl(List<Param> params, Maybe<Expr> return_type)            TK_DECL
 // Def   = Def(Ident name, Decl decl, List<Stmt> body)                  TK_DEF
+// ClassDef = ClassDef(Ident name, List<Def> body)                      TK_CLASS_DEF
 //
 // Stmt  = If(Expr cond, List<Stmt> true_body, List<Stmt> false_body)   TK_IF
 //       | For(List<Expr> targets, List<Expr> iters, List<Stmt> body)   TK_FOR
@@ -400,6 +401,28 @@ struct Def : public TreeView {
   }
 };
 
+struct ClassDef : public TreeView {
+  explicit ClassDef(const TreeRef& tree) : TreeView(tree) {
+    tree->match(TK_CLASS_DEF);
+  }
+  ClassDef withName(std::string new_name) const {
+    auto new_ident = Ident::create(name().range(), std::move(new_name));
+    return create(range(), new_ident, defs());
+  }
+  Ident name() const {
+    return Ident(subtree(0));
+  }
+  List<Def> defs() const {
+    return List<Def>(subtree(1));
+  }
+  static ClassDef create(
+      const SourceRange& range,
+      const Ident& name,
+      const List<Def>& defs) {
+    return ClassDef(Compound::create(TK_CLASS_DEF, range, {name, defs}));
+  }
+};
+
 ////////////////////////////////////////////////////////////////////////////////
 // Statements
 ////////////////////////////////////////////////////////////////////////////////
diff --git a/torch/csrc/jit/script/user_type.cpp b/torch/csrc/jit/script/user_type.cpp
new file mode 100644 (file)
index 0000000..29452eb
--- /dev/null
@@ -0,0 +1,12 @@
+#include <ATen/core/jit_type.h>
+#include <torch/csrc/jit/script/module.h>
+
+namespace c10 {
+
+// This file exists because we need to reference module.h, which we can't from
+// c10. Sigh...
+Method* UserType::getMethod(const std::string& name) const {
+  return module_->find_method(name);
+}
+
+} // namespace c10
index f7494b0..3eace9b 100644 (file)
@@ -3,7 +3,7 @@ from torch import Tensor
 from torch.autograd import Variable, function
 from torch.serialization import validate_cuda_device
 from torch.nn import Module, ModuleList, ParameterList, Parameter, Sequential
-from torch.jit.frontend import get_jit_ast, get_default_args
+from torch.jit.frontend import get_jit_class_def, get_jit_def, get_default_args
 import torch.backends.cudnn as cudnn
 import torch.jit.annotations
 import torch._jit_internal as _jit_internal
@@ -56,6 +56,7 @@ _enabled = _parse_env('PYTORCH_JIT', True, "> Using PyTorch JIT", "> PyTorch JIT
 _flatten = torch._C._jit_flatten
 _unflatten = torch._C._jit_unflatten
 _jit_script_compile = torch._C._jit_script_compile
+_jit_script_class_compile = torch._C._jit_script_class_compile
 BatchTensor = torch._C._jit.BatchTensor
 
 Future = torch._C.Future
@@ -712,18 +713,24 @@ def _try_compile_weak_script(fn):
         return entry["compiled_fn"]
 
 
-def script(fn, optimize=True, _frames_up=0, _rcb=None):
+def script(obj, optimize=True, _frames_up=0, _rcb=None):
     if not _enabled:
-        return fn
+        return obj
     if _rcb is None:
         _rcb = _jit_internal.createResolutionCallback(_frames_up + 1)
-    ast = get_jit_ast(fn, is_method=False)
-    mod = ScriptModule()
-    _jit_script_compile(mod, ast, _rcb, get_default_args(fn))
+    if inspect.isclass(obj):
+        mod = ScriptClass(obj.__name__)
+        ast = get_jit_class_def(obj)
+        _jit_script_class_compile(mod, ast, _rcb)
+    else:
+        mod = ScriptModule()
+        ast = get_jit_def(obj)
+        _jit_script_compile(mod, ast, _rcb, get_default_args(obj))
     # Forward docstrings
-    mod.__doc__ = fn.__doc__
+    mod.__doc__ = obj.__doc__
     return mod
 
+
 ScriptMethodStub = namedtuple('ScriptMethodStub', ('resolution_callback', 'def_', 'original_method'))
 
 
@@ -744,7 +751,7 @@ def script_method(fn, _rcb=None):
     # function (the calling function). Adding 2 gets us to the proper surrounding scope.
     if _rcb is None:
         _rcb = _jit_internal.createResolutionCallback(frames_up=2)
-    ast = get_jit_ast(fn, is_method=True)
+    ast = get_jit_def(fn, self_name="ScriptModule")
     return ScriptMethodStub(_rcb, ast, fn)
 
 
@@ -1268,11 +1275,20 @@ if _enabled:
                                      "weak script module once it has been "
                                      "created".format(attr))
 
+    class ScriptClass(ScriptModule):
+        def __init__(self, name):
+            super(ScriptClass, self).__init__()
+            self._name = name
+
 else:
     class ScriptModule(torch.nn.Module):
         def __init__(self, optimize=True):
             super(ScriptModule, self).__init__()
 
+    class ScriptClass(ScriptModule):
+        def __init__(self, name):
+            super(ScriptClass, self).__init__()
+
 
 def _get_weak_stubs(cls):
     """
index eb727ba..70089dd 100644 (file)
@@ -137,14 +137,27 @@ def _uses_true_division(fn):
             '_uses_true_division: expected function or method, got {}'.format(type(fn)))
 
 
-def get_jit_ast(fn, is_method):
+def get_jit_class_def(cls, self_name=None):
+    # Get defs for each method independently
+    methods = inspect.getmembers(
+        cls, predicate=lambda m: inspect.ismethod(m) or inspect.isfunction(m))
+    method_defs = [get_jit_def(method[1],
+                   self_name=cls.__name__) for method in methods]
+
+    source = dedent(inspect.getsource(cls))
+    py_ast = ast.parse(source)
+    ctx = SourceContext(source, False)
+    return build_class_def(ctx, py_ast.body[0], method_defs)
+
+
+def get_jit_def(fn, self_name=None):
     source = dedent(inspect.getsource(fn))
     py_ast = ast.parse(source)
     if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef):
         raise RuntimeError("expected a single top-level function")
     type_line = torch.jit.annotations.get_type_line(source)
     ctx = SourceContext(source, _uses_true_division(fn))
-    return build_def(ctx, py_ast.body[0], type_line, is_method)
+    return build_def(ctx, py_ast.body[0], type_line, self_name)
 
 
 # Thin wrapper around SourceRangeFactory to store extra metadata
@@ -163,17 +176,22 @@ class Builder(object):
         return method(ctx, node)
 
 
-def build_def(ctx, py_def, type_line, is_method):
-    returns = []
-    ret_body = []
+def build_class_def(ctx, py_def, methods):
+    r = ctx.make_range(py_def.lineno, py_def.col_offset,
+                       py_def.col_offset + len("class"))
+    return ClassDef(Ident(r, py_def.name), methods)
+
+
+def build_def(ctx, py_def, type_line, self_name=None):
     body = py_def.body
     r = ctx.make_range(py_def.lineno, py_def.col_offset,
                        py_def.col_offset + len("def"))
-    param_list = build_param_list(ctx, py_def.args)
+    param_list = build_param_list(ctx, py_def.args, self_name)
     return_type = None
     if getattr(py_def, 'returns', None) is not None:
         return_type = build_expr(ctx, py_def.returns)
     decl = Decl(r, param_list, return_type)
+    is_method = self_name is not None
     if type_line is not None:
         type_comment_decl = torch._C.parse_type_comment(type_line)
         decl = torch._C.merge_type_from_type_comment(decl, type_comment_decl, is_method)
@@ -186,24 +204,26 @@ _vararg_kwarg_err = ("Compiled functions can't take variable number of arguments
                      "or use keyword-only arguments with defaults")
 
 
-def build_param_list(ctx, py_args):
+def build_param_list(ctx, py_args, self_name):
     if py_args.vararg is not None or py_args.kwarg is not None:
         raise ValueError(_vararg_kwarg_err)
     if not PY2 and py_args.kw_defaults:
         raise ValueError(_vararg_kwarg_err)
-    result = [build_param(ctx, arg, False) for arg in py_args.args]
+    result = [build_param(ctx, arg, self_name, False) for arg in py_args.args]
     if not PY2:
-        result += [build_params(ctx, arg, True) for arg in py_args.kwonlyargs]
+        result += [build_params(ctx, arg, self_name, True) for arg in py_args.kwonlyargs]
     return result
 
 
-def build_param(ctx, py_arg, kwarg_only):
+def build_param(ctx, py_arg, self_name, kwarg_only):
     # NB: In Python3 py_arg is a pair of (str arg, expr? annotation)
     #     In Python2 py_arg is a Name (Expr subclass)
     name = py_arg.id if PY2 else py_arg.arg
     r = ctx.make_range(py_arg.lineno, py_arg.col_offset, py_arg.col_offset + len(name))
     if getattr(py_arg, 'annotation', None) is not None:
         annotation_expr = build_expr(ctx, py_arg.annotation)
+    elif self_name is not None and name == 'self':
+        annotation_expr = Var(Ident(r, self_name))
     else:
         annotation_expr = Var(Ident(r, 'Tensor'))
     return Param(annotation_expr, Ident(r, name), kwarg_only)