Replace Slot on script::Method with NamedIValue (#18252)
authorDavid Riazati <davidriazati@fb.com>
Fri, 5 Apr 2019 06:27:05 +0000 (23:27 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 5 Apr 2019 06:35:56 +0000 (23:35 -0700)
Summary:
This refactor lets us track the types of initial values added onto a `Method`. The main motivation for this is the change in `module.cpp`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18252

Differential Revision: D14673459

Pulled By: driazati

fbshipit-source-id: 21200180c47f25bb70898771adfb569856e6c34a

torch/csrc/jit/import_source.cpp
torch/csrc/jit/passes/python_print.cpp
torch/csrc/jit/script/init.cpp
torch/csrc/jit/script/module.cpp
torch/csrc/jit/script/module.h

index 6364947..756b962 100644 (file)
@@ -1,5 +1,4 @@
-#include "import_source.h"
-
+#include <torch/csrc/jit/import_source.h>
 #include <torch/csrc/jit/script/parser.h>
 
 namespace torch {
@@ -22,12 +21,12 @@ struct ModuleAccessorValue : public SugaredValue {
     if (NamedModule* v = module->find_module(field)) {
       return std::make_shared<ModuleAccessorValue>(v->module);
     } else if (NamedIValue* v = module->find_parameter(field)) {
-      return std::make_shared<SimpleValue>(m.get_or_add_parameter(v->slot()));
+      return std::make_shared<SimpleValue>(m.get_or_add_initial_ivalue(v));
     } else if (NamedIValue* v = module->find_buffer(field)) {
-      return std::make_shared<SimpleValue>(m.get_or_add_parameter(v->slot()));
+      return std::make_shared<SimpleValue>(m.get_or_add_initial_ivalue(v));
     } else if (script::NamedIValue* v = module->find_attribute(field)) {
       return std::make_shared<script::SimpleValue>(
-          m.get_or_add_attribute(v->type(), v->slot()));
+          m.get_or_add_initial_ivalue(v));
     } else if (Method* m = module->find_method(field)) {
       return std::make_shared<MethodValue>(shared_from_this(), *m);
     } else {
index 0844a96..8f5e510 100644 (file)
@@ -1123,8 +1123,14 @@ struct PythonPrintPass {
       const std::unordered_map<script::Slot, QualifiedNamePtr>&
           extra_ivalue_names) {
     std::vector<std::string> ivalue_names =
-        fmap(method.initial_ivalues(), [&](const script::Slot& slot) {
-          return extra_ivalue_names.at(slot)->str();
+        fmap(method.initial_ivalues(), [&](const script::NamedIValue* value) {
+          auto entry = extra_ivalue_names.find(value->slot());
+          AT_CHECK(
+              entry != extra_ivalue_names.end(),
+              "Could not find named IValue '",
+              value->name(),
+              "' while pretty printing");
+          return entry->second->str();
         });
     const std::string& name = method.name();
     Graph& graph = *method.graph();
index 54dbb98..55c916c 100644 (file)
@@ -276,7 +276,7 @@ struct VISIBILITY_HIDDEN ConstantParameterList : public SugaredValue {
     const auto& param_list = module_->get_parameters();
     for (auto it = param_list.rbegin(); it != param_list.rend(); ++it) {
       auto& param = *it;
-      params.push_back(caller.get_or_add_parameter(param.slot()));
+      params.push_back(caller.get_or_add_initial_ivalue(&param));
     }
     auto list = caller.graph()->createList(TensorType::get(), params);
     caller.graph()->insertNode(list);
@@ -319,7 +319,7 @@ struct ModuleValue : public SugaredValue {
         module->register_buffer("training", std::move(t));
         v = module->find_buffer(field);
       }
-      Value* the_tensor = m.get_or_add_parameter(v->slot());
+      Value* the_tensor = m.get_or_add_initial_ivalue(v);
       Value* the_bool = m.graph()->insert(prim::Bool, {the_tensor});
       return std::make_shared<SimpleValue>(the_bool);
     }
@@ -329,10 +329,10 @@ struct ModuleValue : public SugaredValue {
     } else if (Method* v = module->find_method(field)) {
       return std::make_shared<MethodValue>(shared_from_this(), *v);
     } else if (NamedIValue* v = module->find_parameter(field)) {
-      return std::make_shared<SimpleValue>(m.get_or_add_parameter(v->slot()));
+      return std::make_shared<SimpleValue>(m.get_or_add_initial_ivalue(v));
     } else if (NamedIValue* v = module->find_attribute(field)) {
       return std::make_shared<SimpleValue>(
-          m.get_or_add_attribute(v->type(), v->slot()));
+          m.get_or_add_initial_ivalue(v));
     }
 
     // This can also be a call to a non-script module, or a plain
@@ -603,14 +603,14 @@ py::object unpackVariableTensorList(std::vector<at::Tensor> outputs) {
 }
 
 static void gatherParametersAndBuffers(
-    std::vector<Slot>& values,
+    std::vector<const NamedIValue*>& values,
     const Module& m) {
   for (auto& param : m.get_parameters()) {
-    values.push_back(param.slot());
+    values.push_back(&param);
   }
   for (auto& param : m.get_attributes()) {
     if (param.type()->isSubtypeOf(TensorType::get())) {
-      values.push_back(param.slot());
+      values.push_back(&param);
     }
   }
   for (const auto& sub : m.get_modules()) {
@@ -852,11 +852,11 @@ void initJitScriptBindings(PyObject* module) {
              bool force_outplace) {
             // prereq: Module's buffers and parameters are unique
             // this was ensured in python before calling this function
-            std::vector<Slot> parameters;
+            std::vector<const NamedIValue*> parameters;
             gatherParametersAndBuffers(parameters, *self);
             Stack inputs = toStack(input_tuple);
-            for (const Slot& param : parameters) {
-              inputs.emplace_back(*param);
+            for (const NamedIValue* param : parameters) {
+              inputs.emplace_back(*param->slot());
             }
             auto graph = tracer::createGraphByTracing(
                 func,
@@ -946,14 +946,14 @@ void initJitScriptBindings(PyObject* module) {
              std::vector<std::tuple<std::shared_ptr<Module>, std::string>>
                  params,
              std::shared_ptr<Module> orig) {
-            std::vector<Slot> member_inputs;
+            std::vector<const NamedIValue*> member_inputs;
             for (auto& p : params) {
-              NamedIValue* np = std::get<0>(p)->find_parameter(std::get<1>(p));
-              if (np == nullptr) {
-                np = std::get<0>(p)->find_buffer(std::get<1>(p));
+              auto named_param = std::get<0>(p)->find_parameter(std::get<1>(p));
+              if (named_param == nullptr) {
+                named_param = std::get<0>(p)->find_buffer(std::get<1>(p));
               }
-              AT_ASSERT(np != nullptr);
-              member_inputs.push_back(np->slot());
+              AT_ASSERT(named_param != nullptr);
+              member_inputs.push_back(named_param);
             }
 
             Method* orig_method = orig->find_method(name);
@@ -975,15 +975,21 @@ void initJitScriptBindings(PyObject* module) {
       .def(
           "propagate_and_assign_input_and_output_shapes",
           &Method::propagate_and_assign_input_and_output_shapes)
-      .def(
-          "initial_ivalues",
-          [](Method& m) {
-            std::vector<at::Tensor> tensors;
-            for (auto& t : m.initial_ivalues()) {
-              tensors.push_back(t->toTensor());
-            }
-            return tensors;
-          })
+      .def("initial_ivalues", [](Method& m) {
+          std::vector<at::Tensor> result;
+          result.reserve(m.initial_ivalues().size());
+
+          for (auto named_ivalue : m.initial_ivalues()) {
+            AT_CHECK(
+                named_ivalue->slot()->isTensor(),
+                "Cannot get initial"
+                " IValues if any are not Tensors (found ",
+                named_ivalue->type()->python_str(),
+                ")");
+            result.push_back(named_ivalue->slot()->toTensor());
+          }
+          return result;
+      })
       .def(
           "graph_for",
           [](py::args args, py::kwargs kwargs) {
index 933cf40..249c753 100644 (file)
@@ -55,10 +55,7 @@ Value* try_emit_call_to(
           << " attempting to call a method with parameters/attributes"
              " from a raw graph. File a bug report";
     }
-    // TODO: preserve the type information so we don't have to infer it here
-    auto type = incompleteInferTypeFrom(*member);
-    matched_schema->inputs.push_back(
-        caller->get_or_add_attribute(type, member));
+    matched_schema->inputs.push_back(caller->get_or_add_initial_ivalue(member));
   }
   callee.check_single_output();
   return inlineCallTo(graph, *callee.graph(), matched_schema->inputs).at(0);
index 304857d..2986810 100644 (file)
@@ -53,13 +53,35 @@ struct Module;
 using ModuleLookup =
     std::function<std::shared_ptr<Module>(const std::vector<std::string>&)>;
 
+struct NamedIValue {
+  NamedIValue(std::string name, TypePtr type, IValue ivalue)
+      : name_(name),
+        type_(type),
+        ivalue_(torch::make_unique<IValue>(std::move(ivalue))) {}
+
+  Slot slot() const {
+    return Slot(ivalue_.get());
+  }
+  const std::string& name() const {
+    return name_;
+  }
+  const TypePtr& type() const {
+    return type_;
+  }
+
+ private:
+  const std::string name_;
+  const TypePtr type_;
+  std::unique_ptr<IValue> ivalue_;
+};
+
 struct Method {
   Method(
       Module* owner,
       std::string name,
       bool optimize,
       std::shared_ptr<Graph> graph,
-      std::vector<Slot> initial_members,
+      std::vector<const NamedIValue*> initial_members,
       std::function<void(Method&)> method_creator)
       : owner_(owner),
         name_(std::move(name)),
@@ -76,7 +98,7 @@ struct Method {
 
   void run(Stack& stack) {
     for (auto input : initial_ivalues_) {
-      push(stack, *input);
+      push(stack, *input->slot());
     }
     get_executor().run(stack);
   }
@@ -93,7 +115,7 @@ struct Method {
 
   std::shared_ptr<Graph> graph_for(Stack inputs) {
     for (auto tp : initial_ivalues_) {
-      inputs.emplace_back(*tp);
+      inputs.emplace_back(*tp->slot());
     }
     return get_executor().graphFor(inputs);
   }
@@ -121,19 +143,15 @@ struct Method {
   size_t num_inputs() const {
     return graph()->inputs().size() - initial_ivalues_.size();
   }
-  TORCH_API Value* get_or_add_parameter(Slot slot) {
-    AT_ASSERT(slot->isTensor());
-    return get_or_add_attribute(TensorType::get(), slot);
-  }
 
-  TORCH_API Value* get_or_add_attribute(TypePtr type, Slot slot) {
-    auto it = initial_ivalue_index.find(slot);
+  TORCH_API Value* get_or_add_initial_ivalue(const NamedIValue* value) {
+    auto it = initial_ivalue_index.find(value);
     if (it != initial_ivalue_index.end()) {
       return graph()->inputs().at(it->second);
     }
-    initial_ivalues_.push_back(slot);
-    initial_ivalue_index[slot] = graph()->inputs().size();
-    return graph()->addInput()->setType(type);
+    initial_ivalues_.push_back(value);
+    initial_ivalue_index[value] = graph()->inputs().size();
+    return graph()->addInput()->setType(value->type());
   }
 
   static void setInputTensorTypes(Graph& g, const Stack& stack) {
@@ -153,8 +171,8 @@ struct Method {
     for (at::Tensor& i : inputs) {
       stack.emplace_back(std::move(i));
     }
-    for (const Slot& inp : initial_ivalues_) {
-      stack.push_back(*inp);
+    for (const NamedIValue* inp : initial_ivalues_) {
+      stack.push_back(*inp->slot());
     }
     setInputTensorTypes(*retval, stack);
     PropagateInputShapes(retval);
@@ -168,8 +186,8 @@ struct Method {
       bool propagate = true) {
     auto retval = graph_->copy();
     for (auto inp : initial_ivalues_) {
-      if (inp->isTensor()) {
-        inputs.push_back(inp->toTensor());
+      if (inp->slot()->isTensor()) {
+        inputs.push_back(inp->slot()->toTensor());
       }
     }
     if (propagate) {
@@ -201,7 +219,7 @@ struct Method {
     return retval;
   }
 
-  const std::vector<Slot>& initial_ivalues() const {
+  const std::vector<const NamedIValue*>& initial_ivalues() const {
     return initial_ivalues_;
   }
 
@@ -330,11 +348,11 @@ struct Method {
   // each is a pointer to a slot in the module that owns this parameter
   // parameters and submodules can only be _added_ to script Modules to ensure
   // these pointers always stay valid
-  std::vector<Slot> initial_ivalues_;
+  std::vector<const NamedIValue*> initial_ivalues_;
 
-  // map from a IValue* in initial_ivalues to the offset it appears at
+  // map from a const NamedIValue* in initial_ivalues to the offset it appears at
   // in graph. used to accelerate get_or_add_parameter
-  std::unordered_map<Slot, size_t> initial_ivalue_index;
+  std::unordered_map<const NamedIValue*, size_t> initial_ivalue_index;
 
   // TODO: support that case where we allow _writes_ to parameters from
   // compiled functions.
@@ -364,28 +382,6 @@ struct NamedModule {
   std::shared_ptr<Module> module;
 };
 
-struct NamedIValue {
-  NamedIValue(std::string name, TypePtr type, IValue ivalue)
-      : name_(name),
-        type_(type),
-        ivalue_(torch::make_unique<IValue>(std::move(ivalue))) {}
-
-  Slot slot() const {
-    return Slot(ivalue_.get());
-  }
-  const std::string& name() const {
-    return name_;
-  }
-  const TypePtr& type() const {
-    return type_;
-  }
-
- private:
-  const std::string name_;
-  const TypePtr type_;
-  std::unique_ptr<IValue> ivalue_;
-};
-
 struct Module {
   TH_DISALLOW_COPY_AND_ASSIGN(Module);
   Module() : optimize(true) {}
@@ -453,7 +449,7 @@ struct Module {
   Method& create_method(
       const std::string& name,
       std::shared_ptr<Graph> graph,
-      std::vector<Slot> member_inputs) {
+      std::vector<const NamedIValue*> member_inputs) {
     AT_ASSERT(graph);
     std::unique_ptr<Method> method(new Method(
         this,
@@ -628,22 +624,25 @@ struct Module {
       ModuleLookup module_lookup,
       // parameter_remap is needed when a parent module uses a parameter of a
       // submodule
-      std::unordered_map<Slot, Slot>& parameter_remap,
+      std::unordered_map<const NamedIValue*, const NamedIValue*>&
+          parameter_remap,
       std::vector<std::string> names = {}) const {
     auto curr = module_lookup(names);
+    curr->parameters_.reserve(get_parameters().size() + get_attributes().size());
+
     for (auto& param : get_parameters()) {
       curr->register_parameter(
           param.name(),
           param.slot()->toTensor(),
           /*is_buffer=*/false);
-      parameter_remap[param.slot()] = curr->parameter_slot(param.name());
+      parameter_remap[&param] = curr->find_parameter(param.name());
     }
     for (auto& attr : get_attributes()) {
       if (!attr.type()->isSubtypeOf(TensorType::get())) {
         continue;
       }
       curr->register_buffer(attr.name(), attr.slot()->toTensor());
-      parameter_remap[attr.slot()] = curr->find_buffer(attr.name())->slot();
+      parameter_remap[&attr] = curr->find_buffer(attr.name());
     }
     for (auto& mod : get_modules()) {
       names.push_back(mod.name);
@@ -652,8 +651,9 @@ struct Module {
       mod.module->copy_into(module_lookup, parameter_remap, names);
       names.pop_back();
     }
+
     for (auto& method : get_methods()) {
-      std::vector<Slot> initial_ivalues;
+      std::vector<const NamedIValue*> initial_ivalues;
       for (auto& p : method->initial_ivalues()) {
         initial_ivalues.push_back(parameter_remap.at(p));
       }