Make Object hold its ClassType (#18467)
authorZachary DeVito <zdevito@fb.com>
Fri, 5 Apr 2019 20:33:14 +0000 (13:33 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 5 Apr 2019 20:40:59 +0000 (13:40 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18467
ghimport-source-id: d51bdd64d2529d08c634c58df1a0870b54ad49fb

Stack from [ghstack](https://github.com/ezyang/ghstack):
* #18469 Create Object that represents a Module
* #18468 slots with explicit value/setValue make more sense in future patches
* **#18467 Make Object hold its ClassType**
* #18379 Enforce single parent for script submodules
* #18378 Unify namespace of script::Module
* #18314 Add ability to specialize class types to ArgumentSpec
* #18226 Add Slot type to abstract the raw pointers being used for slots.

Currently it holds a symbol whose unqualified name is the name of the
class. This will get confusing when there are multiple possible registries,
and it makes getting the class type from the object difficult.
The pointer to the class is only 4 more bytes so this patch just puts
it in the object.

Reviewed By: suo

Differential Revision: D14613510

fbshipit-source-id: b35175ba4be83d2522deaa6dad5070d6ec691fed

aten/src/ATen/core/ivalue.cpp
aten/src/ATen/core/ivalue.h
aten/src/ATen/core/jit_type.h
aten/src/ATen/core/type.cpp
torch/csrc/jit/pybind_utils.h
torch/csrc/jit/register_prim_ops.cpp

index 0a2d2a0..f89ff4b 100644 (file)
@@ -1,4 +1,5 @@
 #include <ATen/core/ivalue.h>
+#include <ATen/core/jit_type.h>
 #include <ATen/core/Formatting.h>
 #include <cmath>
 
@@ -95,7 +96,7 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) {
       return printDict(out, v.toGenericDict());
     case IValue::Tag::Object:
       // TODO we should print the object contents
-      return out << "Object<" << v.toObject()->name().toUnqualString()
+      return out << "Object<" << v.toObject()->name()
                  << ">";
   }
   AT_ERROR("Tag not found\n");
@@ -107,4 +108,9 @@ void IValue::dump() const {
   std::cout << *this << "\n";
 }
 
+
+const std::string& ivalue::Object::name() const {
+  return this->type_->name();
+}
+
 } // namespace c10
index dd4ed54..2151527 100644 (file)
@@ -14,6 +14,7 @@
 
 namespace c10 {
 struct IValue;
+struct ClassType;
 
 namespace ivalue {
 
@@ -683,14 +684,14 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target {
 // User-defined object.
 struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target {
  public:
-  Object(Symbol name, size_t numSlots) : typename_(std::move(name)) {
+  Object(std::shared_ptr<ClassType> type, size_t numSlots) : type_(std::move(type)) {
     slots_.resize(numSlots);
   }
 
   static c10::intrusive_ptr<Object> create(
-      Symbol name,
+      std::shared_ptr<ClassType> type,
       size_t numSlots) {
-    return c10::make_intrusive<Object>(std::move(name), numSlots);
+    return c10::make_intrusive<Object>(std::move(type), numSlots);
   }
 
   void setSlot(size_t slot, IValue v) {
@@ -701,14 +702,13 @@ struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target {
     return slots_.at(slot);
   }
 
-  Symbol name() const {
-    return typename_;
-  }
+  const std::string& name() const;
+
   const std::vector<IValue>& slots() const {
     return slots_;
   }
  private:
-  const Symbol typename_;
+  std::shared_ptr<ClassType> type_;
   std::vector<IValue> slots_;
 };
 
index be99f93..f98d3b7 100644 (file)
@@ -1166,7 +1166,7 @@ struct CAFFE2_API ClassType : public Type {
   Method* getMethod(const std::string& name) const;
   std::vector<Method*> methods() const;
 
-  std::string name() const {
+  const std::string& name() const {
     return typename_;
   }
 
index fe93a2c..c1885c7 100644 (file)
@@ -492,6 +492,7 @@ ClassTypePtr ClassType::get(const std::string& name) {
   return getRegistry().getType(name);
 }
 
+
 void ClassType::clearRegistry() {
   getRegistry().clear();
 }
index dedc639..64a79ca 100644 (file)
@@ -226,9 +226,8 @@ inline IValue toIValue(
     case TypeKind::ClassType: {
       auto classType = type->expect<ClassType>();
       // 1. create a bare ivalue
-      const auto name = Symbol::user(classType->name());
       const size_t numAttrs = classType->numAttributes();
-      auto userObj = c10::ivalue::Object::create(name, numAttrs);
+      auto userObj = c10::ivalue::Object::create(classType, numAttrs);
 
       // 2. copy all the contained types
       for (size_t slot = 0; slot < numAttrs; slot++) {
@@ -346,10 +345,10 @@ inline py::object toPyObject(IValue&& ivalue) {
     return std::move(py_dict);
   } else if (ivalue.isObject()) {
     const auto obj = ivalue.toObject();
-    const auto classType = ClassType::get(obj->name().toUnqualString());
+    const auto classType = ClassType::get(obj->name());
     AT_ASSERT(classType);
     auto pyClass = py::module::import("torch.jit")
-                       .attr("_get_script_class")(obj->name().toUnqualString());
+                       .attr("_get_script_class")(obj->name());
     auto pyObj = pyClass.attr("__new__")(pyClass);
 
 
index 688ca0e..27393d2 100644 (file)
@@ -858,10 +858,9 @@ RegisterOperators reg(
          prim::CreateObject,
          [](const Node* node) {
            const auto type = node->output()->type()->expect<ClassType>();
-           const auto name = Symbol::user(type->name());
            const size_t numAttrs = type->numAttributes();
-           return [name, numAttrs](Stack& stack) {
-             auto userObj = c10::ivalue::Object::create(name, numAttrs);
+           return [type, numAttrs](Stack& stack) {
+             auto userObj = c10::ivalue::Object::create(type, numAttrs);
              push(stack, std::move(userObj));
              return 0;
            };