Add Slot type to abstract the raw pointers being used for slots. (#18226)
authorZachary DeVito <zdevito@fb.com>
Thu, 28 Mar 2019 17:31:45 +0000 (10:31 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 28 Mar 2019 17:35:36 +0000 (10:35 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18226
ghimport-source-id: b9ec8651212875b30971cc6859d2ddec6559ae3a

If modules become first-class IValues, then the slots will no longer be raw pointers but (IValue, index) pairs. This commit inserts the Slot abstraction so that this change can be made in later patches.

Stack from [ghstack](https://github.com/ezyang/ghstack):
* **#18226 Add Slot type to abstract the raw pointers being used for slots.**

Differential Revision: D14542022

fbshipit-source-id: b81d7f4334c983d663e7551bda82df43680d7c5f

torch/csrc/jit/export.cpp
torch/csrc/jit/import_source.cpp
torch/csrc/jit/passes/python_print.cpp
torch/csrc/jit/script/init.cpp
torch/csrc/jit/script/module.h
torch/csrc/jit/script/slot.h [new file with mode: 0644]

index effc760..fe96902 100644 (file)
@@ -734,8 +734,8 @@ void ScriptModuleSerializer::convertModule(
     auto& attribute = item.value();
     // Add attribute to ModuleDef
     torch::AttributeDef* attribute_def = module_def->add_attributes();
-    attribute_def->set_name(attribute.name_);
-    attribute_def->set_type(attribute.type->python_str());
+    attribute_def->set_name(attribute.name());
+    attribute_def->set_type(attribute.type()->python_str());
 
     attribute_table_.push_back(*attribute.slot());
     attribute_def->set_id(attribute_table_.size() - 1);
@@ -775,7 +775,7 @@ void ScriptModuleSerializer::convertParameter(
     const script::NamedIValue& param,
     torch::ParameterDef* param_def,
     bool is_parameter) {
-  param_def->set_name(param.name_);
+  param_def->set_name(param.name());
   param_def->set_is_buffer(is_parameter);
   param_def->set_tensor_id(addTensor(param.slot()->toTensor()));
 }
index c2b633d..6364947 100644 (file)
@@ -27,7 +27,7 @@ struct ModuleAccessorValue : public SugaredValue {
       return std::make_shared<SimpleValue>(m.get_or_add_parameter(v->slot()));
     } else if (script::NamedIValue* v = module->find_attribute(field)) {
       return std::make_shared<script::SimpleValue>(
-          m.get_or_add_attribute(v->type, v->slot()));
+          m.get_or_add_attribute(v->type(), v->slot()));
     } else if (Method* m = module->find_method(field)) {
       return std::make_shared<MethodValue>(shared_from_this(), *m);
     } else {
index bc4fee3..3e37054 100644 (file)
@@ -130,14 +130,14 @@ struct QualifiedName : c10::intrusive_ptr_target {
 void createTensorToParameterNameMap(
     const script::Module& module,
     const QualifiedNamePtr& prefix,
-    std::unordered_map<IValue*, QualifiedNamePtr>& result) {
+    std::unordered_map<script::Slot, QualifiedNamePtr>& result) {
   for (const auto& elem : module.get_parameters()) {
     const script::NamedIValue& param = elem.value();
-    result[param.slot()] = QualifiedName::create(prefix, param.name_);
+    result[param.slot()] = QualifiedName::create(prefix, param.name());
   }
   for (const auto& elem : module.get_attributes()) {
     const script::NamedIValue& param = elem.value();
-    result[param.slot()] = QualifiedName::create(prefix, param.name_);
+    result[param.slot()] = QualifiedName::create(prefix, param.name());
   }
   for (const auto& elem : module.get_modules()) {
     createTensorToParameterNameMap(
@@ -1114,7 +1114,7 @@ struct PythonPrintPass {
     }
   }
   void printMethod(script::Method& method) {
-    std::unordered_map<IValue*, QualifiedNamePtr> extra_ivalue_names;
+    std::unordered_map<script::Slot, QualifiedNamePtr> extra_ivalue_names;
     createTensorToParameterNameMap(
         method.owner(), QualifiedName::create("self"), extra_ivalue_names);
     printMethod(method, /*is_class=*/false, extra_ivalue_names);
@@ -1122,10 +1122,10 @@ struct PythonPrintPass {
   void printMethod(
       script::Method& method,
       bool is_class,
-      const std::unordered_map<IValue*, QualifiedNamePtr>& extra_ivalue_names) {
+      const std::unordered_map<script::Slot, QualifiedNamePtr>& extra_ivalue_names) {
     std::vector<std::string> ivalue_names = fmap(
         method.initial_ivalues(),
-        [&](IValue* slot) { return extra_ivalue_names.at(slot)->str(); });
+        [&](const script::Slot& slot) { return extra_ivalue_names.at(slot)->str(); });
     const std::string& name = method.name();
     Graph& graph = *method.graph();
     auto defaults = fmap(
@@ -1134,7 +1134,7 @@ struct PythonPrintPass {
     printFunction(graph, name, is_class, defaults, ivalue_names);
   }
   void printModule(script::Module& module) {
-    std::unordered_map<IValue*, QualifiedNamePtr> extra_ivalue_names;
+    std::unordered_map<script::Slot, QualifiedNamePtr> extra_ivalue_names;
     createTensorToParameterNameMap(
         module, QualifiedName::create("self"), extra_ivalue_names);
     for (auto& method : module.get_methods()) {
@@ -1153,7 +1153,7 @@ struct PythonPrintPass {
     out << "class " << classType->name() << ":\n";
     {
       const auto guard = WithIndented();
-      std::unordered_map<IValue*, QualifiedNamePtr> extra_ivalue_names;
+      std::unordered_map<script::Slot, QualifiedNamePtr> extra_ivalue_names;
       for (auto& method : classType->methods()) {
         printMethod(*method, /*is_class=*/true, extra_ivalue_names);
       }
index 0692763..d3efa4d 100644 (file)
@@ -329,7 +329,7 @@ struct ModuleValue : public SugaredValue {
       return std::make_shared<SimpleValue>(m.get_or_add_parameter(v->slot()));
     } else if (NamedIValue* v = module->find_attribute(field)) {
       return std::make_shared<SimpleValue>(
-          m.get_or_add_attribute(v->type, v->slot()));
+          m.get_or_add_attribute(v->type(), v->slot()));
     }
 
     // This can also be a call to a non-script module, or a plain
@@ -600,13 +600,13 @@ py::object unpackVariableTensorList(std::vector<at::Tensor> outputs) {
 }
 
 static void gatherParametersAndBuffers(
-    std::vector<IValue*>& values,
+    std::vector<Slot>& values,
     const Module& m) {
   for (auto& param : m.get_parameters()) {
     values.push_back(param->slot());
   }
   for (auto& param : m.get_attributes()) {
-    if (param->type->isSubtypeOf(TensorType::get())) {
+    if (param->type()->isSubtypeOf(TensorType::get())) {
       values.push_back(param->slot());
     }
   }
@@ -794,7 +794,7 @@ void initJitScriptBindings(PyObject* module) {
               py::tuple r(3);
               IValue v = *buffer->slot();
               result[i] = std::make_tuple(
-                  buffer.key(), buffer->type, toPyObject(std::move(v)));
+                  buffer.key(), buffer->type(), toPyObject(std::move(v)));
             }
             return result;
           })
@@ -844,10 +844,10 @@ void initJitScriptBindings(PyObject* module) {
              bool force_outplace) {
             // prereq: Module's buffers and parameters are unique
             // this was ensured in python before calling this function
-            std::vector<IValue*> parameters;
+            std::vector<Slot> parameters;
             gatherParametersAndBuffers(parameters, *self);
             Stack inputs = toStack(input_tuple);
-            for (IValue* param : parameters) {
+            for (const Slot& param : parameters) {
               inputs.emplace_back(*param);
             }
             auto graph = tracer::createGraphByTracing(
@@ -938,7 +938,7 @@ void initJitScriptBindings(PyObject* module) {
              std::vector<std::tuple<std::shared_ptr<Module>, std::string>>
                  params,
              std::shared_ptr<Module> orig) {
-            std::vector<IValue*> member_inputs;
+            std::vector<Slot> member_inputs;
             for (auto& p : params) {
               NamedIValue* np = std::get<0>(p)->find_parameter(std::get<1>(p));
               if (np == nullptr) {
@@ -967,7 +967,13 @@ void initJitScriptBindings(PyObject* module) {
       .def(
           "propagate_and_assign_input_and_output_shapes",
           &Method::propagate_and_assign_input_and_output_shapes)
-      .def("initial_ivalues", &Method::initial_ivalues)
+      .def("initial_ivalues",[](Method& m) {
+        std::vector<at::Tensor> tensors;
+        for (auto& t : m.initial_ivalues()) {
+          tensors.push_back(t->toTensor());
+        }
+        return tensors;
+      })
       .def(
           "graph_for",
           [](py::args args, py::kwargs kwargs) {
index 17147e5..6d4d746 100644 (file)
@@ -8,6 +8,8 @@
 #include <torch/csrc/jit/named_value.h>
 #include <torch/csrc/jit/passes/shape_analysis.h>
 #include <torch/csrc/jit/source_range.h>
+#include <torch/csrc/jit/script/slot.h>
+
 
 #include <torch/csrc/WindowsTorchApiMacro.h>
 #include <torch/csrc/api/include/torch/ordered_dict.h>
@@ -58,7 +60,7 @@ struct Method {
       std::string name,
       bool optimize,
       std::shared_ptr<Graph> graph,
-      std::vector<IValue*> initial_members,
+      std::vector<Slot> initial_members,
       std::function<void(Method&)> method_creator)
       : owner_(owner),
         name_(std::move(name)),
@@ -120,12 +122,12 @@ struct Method {
   size_t num_inputs() const {
     return graph()->inputs().size() - initial_ivalues_.size();
   }
-  TORCH_API Value* get_or_add_parameter(IValue* slot) {
+  TORCH_API Value* get_or_add_parameter(Slot slot) {
     AT_ASSERT(slot->isTensor());
     return get_or_add_attribute(TensorType::get(), slot);
   }
 
-  TORCH_API Value* get_or_add_attribute(TypePtr type, IValue* slot) {
+  TORCH_API Value* get_or_add_attribute(TypePtr type, Slot slot) {
     auto it = initial_ivalue_index.find(slot);
     if (it != initial_ivalue_index.end()) {
       return graph()->inputs().at(it->second);
@@ -144,7 +146,7 @@ struct Method {
     for (at::Tensor& i : inputs) {
       stack.emplace_back(std::move(i));
     }
-    for (IValue* inp : initial_ivalues_) {
+    for (const Slot& inp : initial_ivalues_) {
       stack.push_back(*inp);
     }
     const auto size = stack.size();
@@ -195,7 +197,7 @@ struct Method {
     return retval;
   }
 
-  const std::vector<IValue*>& initial_ivalues() const {
+  const std::vector<Slot>& initial_ivalues() const {
     return initial_ivalues_;
   }
 
@@ -324,11 +326,11 @@ struct Method {
   // each is a pointer to a slot in the module that owns this parameter
   // parameters and submodules can only be _added_ to script Modules to ensure
   // these pointers always stay valid
-  std::vector<IValue*> initial_ivalues_;
+  std::vector<Slot> initial_ivalues_;
 
   // map from a IValue* in initial_ivalues to the offset it appears at
   // in graph. used to accelerate get_or_add_parameter
-  std::unordered_map<IValue*, size_t> initial_ivalue_index;
+  std::unordered_map<Slot, size_t> initial_ivalue_index;
 
   // TODO: support that case where we allow _writes_ to parameters from
   // compiled functions.
@@ -361,15 +363,22 @@ struct NamedModule {
 struct NamedIValue {
   NamedIValue(std::string name, TypePtr type, IValue ivalue)
       : name_(name),
-        type(type),
-        ivalue(torch::make_unique<IValue>(std::move(ivalue))) {}
+        type_(type),
+        ivalue_(torch::make_unique<IValue>(std::move(ivalue))) {}
 
-  IValue* slot() const {
-    return ivalue.get();
+  Slot slot() const {
+    return Slot(ivalue_.get());
+  }
+  const std::string& name() const {
+    return name_;
+  }
+  const TypePtr& type() const {
+    return type_;
   }
+private:
   const std::string name_;
-  const TypePtr type;
-  std::unique_ptr<IValue> ivalue;
+  const TypePtr type_;
+  std::unique_ptr<IValue> ivalue_;
 };
 
 struct Module {
@@ -397,7 +406,7 @@ struct Module {
 
   void register_buffer(const std::string& name, autograd::Variable v) {
     if (auto b = attributes.find(name)) {
-      AT_ASSERT(b->type->isSubtypeOf(TensorType::get()));
+      AT_ASSERT(b->type()->isSubtypeOf(TensorType::get()));
       *b->slot() = v;
       return;
     }
@@ -432,7 +441,7 @@ struct Module {
   Method& create_method(
       const std::string& name,
       std::shared_ptr<Graph> graph,
-      std::vector<IValue*> member_inputs) {
+      std::vector<Slot> member_inputs) {
     AT_ASSERT(graph);
     std::unique_ptr<Method> method(new Method(
         this,
@@ -457,7 +466,7 @@ struct Module {
     return *methods.insert(name, std::move(method));
   }
 
-  IValue* parameter_slot(const std::string& name) const {
+  Slot parameter_slot(const std::string& name) const {
     return parameters[name].slot();
   }
 
@@ -506,7 +515,7 @@ struct Module {
   }
   NamedIValue* find_buffer(const std::string& name) {
     auto b = attributes.find(name);
-    if (b && b->type->isSubtypeOf(TensorType::get())) {
+    if (b && b->type()->isSubtypeOf(TensorType::get())) {
       return b;
     }
     return nullptr;
@@ -604,7 +613,7 @@ struct Module {
       ModuleLookup module_lookup,
       // parameter_remap is needed when a parent module uses a parameter of a
       // submodule
-      std::unordered_map<IValue*, IValue*>& parameter_remap,
+      std::unordered_map<Slot, Slot>& parameter_remap,
       std::vector<std::string> names = {}) const {
     auto curr = module_lookup(names);
     for (auto& kv : parameters) {
@@ -615,7 +624,7 @@ struct Module {
       parameter_remap[kv.value().slot()] = curr->parameter_slot(kv.key());
     }
     for (auto& kv : attributes) {
-      if (!kv.value().type->isSubtypeOf(TensorType::get())) {
+      if (!kv.value().type()->isSubtypeOf(TensorType::get())) {
         continue;
       }
       curr->register_buffer(
@@ -631,7 +640,7 @@ struct Module {
       names.pop_back();
     }
     for (auto& kv : methods) {
-      std::vector<IValue*> initial_ivalues;
+      std::vector<Slot> initial_ivalues;
       for (auto& p : kv.value()->initial_ivalues()) {
         initial_ivalues.push_back(parameter_remap.at(p));
       }
diff --git a/torch/csrc/jit/script/slot.h b/torch/csrc/jit/script/slot.h
new file mode 100644 (file)
index 0000000..dc70e7b
--- /dev/null
@@ -0,0 +1,42 @@
+#pragma once
+#include <ATen/core/ivalue.h>
+
+namespace torch {
+namespace jit {
+namespace script {
+
+// a stable location that can hold an IValue.
+// Currently this is internally implemented as a pointer, but when
+// modules become first-class this will be a pair of  <module_ivalue, slot_number>
+struct Slot {
+  friend struct NamedIValue;
+  Slot()
+  : slot_(nullptr) {}
+  Slot(at::IValue* slot)
+  : slot_(slot) {}
+  at::IValue& operator*() const {
+    return *slot_;
+  }
+  at::IValue* operator->() const {
+      return slot_;
+  }
+  bool operator==(const Slot& rhs) const {
+    return slot_ == rhs.slot_;
+  }
+private:
+  at::IValue* slot_;
+  friend struct std::hash<Slot>;
+};
+
+}}}
+
+// slots are hashable, because they are often used as keys in maps
+// for remapping uses of a slot from one model to another
+namespace std {
+  template <>
+  struct hash<torch::jit::script::Slot> {
+    size_t operator()(const torch::jit::script::Slot& s) const noexcept {
+      return std::hash<at::IValue*>{}(s.slot_);
+    }
+  };
+} // namespace std