slots with explicit value/setValue make more sense in future patches (#18468)
authorZachary DeVito <zdevito@fb.com>
Fri, 5 Apr 2019 20:33:14 +0000 (13:33 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 5 Apr 2019 20:41:02 +0000 (13:41 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18468
ghimport-source-id: d4b41c521f2269a695e03c8e7d05d5542731ee48

Stack from [ghstack](https://github.com/ezyang/ghstack):
* #18469 Create Object that represents a Module
* **#18468 slots with explicit value/setValue make more sense in future patches**
* #18467 Make Object hold its ClassType
* #18379 Enforce single parent for script submodules
* #18378 Unify namespace of script::Module
* #18314 Add ability to specialize class types to ArgumentSpec
* #18226 Add Slot type to abstract the raw pointers being used for slots.

Reviewed By: suo

Differential Revision: D14613509

fbshipit-source-id: 9f2208d0efd01465c78cebdc3e8365a9e0adf9ff

test/custom_operator/test_custom_ops.cpp
torch/csrc/api/src/serialize/input-archive.cpp
torch/csrc/jit/export.cpp
torch/csrc/jit/script/init.cpp
torch/csrc/jit/script/module.cpp
torch/csrc/jit/script/module.h
torch/csrc/jit/script/slot.h

index 1b64171..a943a08 100644 (file)
@@ -15,10 +15,10 @@ void check_all_parameters(
     const torch::jit::script::Module& module,
     Predicate predicate) {
   for (const auto& parameter : module.get_parameters()) {
-    AT_ASSERT(predicate(parameter.slot()->toTensor()));
+    AT_ASSERT(predicate(parameter.slot().value().toTensor()));
   }
   for (const auto& child : module.get_modules()) {
-    check_all_parameters(module, predicate);
+    check_all_parameters(*child, predicate);
   }
 }
 } // namespace helpers
index e8243cd..0355dc5 100644 (file)
@@ -32,7 +32,7 @@ void InputArchive::read(
       "'");
   // clang-format off
   auto read_param = is_buffer ? buffer : param;
-  auto read_tensor = read_param->slot()->toTensor();
+  auto read_tensor = read_param->slot().value().toTensor();
   AT_CHECK(
       bool(buffer) == is_buffer,
       "Expected deserialized tensor for key '", key,
index 4f27838..69399f9 100644 (file)
@@ -718,7 +718,7 @@ void ScriptModuleSerializer::writeAttributeTable() {
   }
   pickler.finish();
   writer_.writeRecord(
-        "attributes.pkl", pickler.stack().data(), pickler.stack().size());
+      "attributes.pkl", pickler.stack().data(), pickler.stack().size());
 }
 
 void ScriptModuleSerializer::convertModule(
@@ -739,7 +739,7 @@ void ScriptModuleSerializer::convertModule(
     attribute_def->set_name(attribute.name());
     attribute_def->set_type(attribute.type()->python_str());
 
-    attribute_table_.push_back(*attribute.slot());
+    attribute_table_.push_back(attribute.slot().value());
     attribute_def->set_id(attribute_table_.size() - 1);
   }
 
@@ -779,7 +779,7 @@ void ScriptModuleSerializer::convertParameter(
     bool is_parameter) {
   param_def->set_name(param.name());
   param_def->set_is_buffer(is_parameter);
-  param_def->set_tensor_id(addTensor(param.slot()->toTensor()));
+  param_def->set_tensor_id(addTensor(param.slot().value().toTensor()));
 }
 
 // Pretty printing for ONNX
index ea69b11..79f7537 100644 (file)
@@ -784,7 +784,8 @@ void initJitScriptBindings(PyObject* module) {
               auto& p = parameters[i];
               py::tuple r(2);
               result[i] = std::make_tuple(
-                  p.name(), autograd::as_variable_ref(p.slot()->toTensor()));
+                  p.name(),
+                  autograd::as_variable_ref(p.slot().value().toTensor()));
             }
             return result;
           })
@@ -796,7 +797,7 @@ void initJitScriptBindings(PyObject* module) {
             for (size_t i = 0; i < attributes.size(); ++i) {
               auto& buffer = attributes[i];
               py::tuple r(3);
-              IValue v = *buffer.slot();
+              IValue v = buffer.slot().value();
               result[i] = std::make_tuple(
                   buffer.name(), buffer.type(), toPyObject(std::move(v)));
             }
@@ -856,7 +857,7 @@ void initJitScriptBindings(PyObject* module) {
             gatherParametersAndBuffers(parameters, *self);
             Stack inputs = toStack(input_tuple);
             for (const Slot& param : parameters) {
-              inputs.emplace_back(*param);
+              inputs.emplace_back(param.value());
             }
             auto graph = tracer::createGraphByTracing(
                 func,
@@ -980,7 +981,7 @@ void initJitScriptBindings(PyObject* module) {
           [](Method& m) {
             std::vector<at::Tensor> tensors;
             for (auto& t : m.initial_ivalues()) {
-              tensors.push_back(t->toTensor());
+              tensors.push_back(t.value().toTensor());
             }
             return tensors;
           })
index cb9cdc7..1e80a0b 100644 (file)
@@ -56,7 +56,7 @@ Value* try_emit_call_to(
              " from a raw graph. File a bug report";
     }
     // TODO: preserve the type information so we don't have to infer it here
-    auto type = incompleteInferTypeFrom(*member);
+    auto type = incompleteInferTypeFrom(member.value());
     matched_schema->inputs.push_back(
         caller->get_or_add_attribute(type, member));
   }
@@ -128,7 +128,7 @@ void Module::to_impl(
   // Then convert every of our parameters.
   for (auto& parameter : get_parameters()) {
     // Need to access the `at::Tensor` as a `Variable` here.
-    autograd::Variable variable = parameter.slot()->toTensor();
+    autograd::Variable variable = parameter.slot().value().toTensor();
     at::Tensor data = variable.data();
     // Use the data's original device or dtype if not supplied here.
     auto new_data = data.to(
index e7211e0..e674b02 100644 (file)
@@ -76,7 +76,7 @@ struct Method {
 
   void run(Stack& stack) {
     for (auto input : initial_ivalues_) {
-      push(stack, *input);
+      push(stack, input.value());
     }
     get_executor().run(stack);
   }
@@ -93,7 +93,7 @@ struct Method {
 
   std::shared_ptr<Graph> graph_for(Stack inputs) {
     for (auto tp : initial_ivalues_) {
-      inputs.emplace_back(*tp);
+      inputs.emplace_back(tp.value());
     }
     return get_executor().graphFor(inputs);
   }
@@ -122,7 +122,7 @@ struct Method {
     return graph()->inputs().size() - initial_ivalues_.size();
   }
   TORCH_API Value* get_or_add_parameter(Slot slot) {
-    AT_ASSERT(slot->isTensor());
+    AT_ASSERT(slot.value().isTensor());
     return get_or_add_attribute(TensorType::get(), slot);
   }
 
@@ -154,7 +154,7 @@ struct Method {
       stack.emplace_back(std::move(i));
     }
     for (const Slot& inp : initial_ivalues_) {
-      stack.push_back(*inp);
+      stack.push_back(inp.value());
     }
     setInputTensorTypes(*retval, stack);
     PropagateInputShapes(retval);
@@ -168,8 +168,8 @@ struct Method {
       bool propagate = true) {
     auto retval = graph_->copy();
     for (auto inp : initial_ivalues_) {
-      if (inp->isTensor()) {
-        inputs.push_back(inp->toTensor());
+      if (inp.value().isTensor()) {
+        inputs.push_back(inp.value().toTensor());
       }
     }
     if (propagate) {
@@ -406,7 +406,7 @@ struct Module {
   void register_buffer(const std::string& name, autograd::Variable v) {
     if (auto b = find_attribute(name)) {
       AT_ASSERT(b->type()->isSubtypeOf(TensorType::get()));
-      *b->slot() = v;
+      b->slot().setValue(v);
       return;
     }
     insert(
@@ -424,7 +424,7 @@ struct Module {
       return;
     }
     if (auto p = find_parameter(name)) {
-      *p->slot() = v;
+      p->slot().setValue(v);
       return;
     }
     insert(
@@ -497,15 +497,15 @@ struct Module {
   }
 
   void set_parameter(const std::string& name, at::Tensor v) {
-    *parameter_slot(name) = std::move(v);
+    parameter_slot(name).setValue(std::move(v));
   }
 
   autograd::Variable get_parameter(const std::string& name) const {
-    return autograd::as_variable_ref(parameter_slot(name)->toTensor());
+    return autograd::as_variable_ref(parameter_slot(name).value().toTensor());
   }
 
   IValue get_attribute(const std::string& name) const {
-    return *attributes_[get_offset(name, EntityType::ATTRIBUTE)].slot();
+    return attributes_[get_offset(name, EntityType::ATTRIBUTE)].slot().value();
   }
 
   autograd::Variable get_buffer(const std::string& name) const {
@@ -579,7 +579,7 @@ struct Module {
   /// True if the module is in training mode.
   bool is_training() {
     if (auto p = find_buffer("training")) {
-      return p->slot()->toTensor().item<int64_t>() == 1;
+      return p->slot().value().toTensor().item<int64_t>() == 1;
     }
     // We are in training mode by default
     return true;
@@ -648,7 +648,7 @@ struct Module {
     for (auto& param : get_parameters()) {
       curr->register_parameter(
           param.name(),
-          param.slot()->toTensor(),
+          param.slot().value().toTensor(),
           /*is_buffer=*/false);
       parameter_remap[param.slot()] = curr->parameter_slot(param.name());
     }
@@ -656,7 +656,7 @@ struct Module {
       if (!attr.type()->isSubtypeOf(TensorType::get())) {
         continue;
       }
-      curr->register_buffer(attr.name(), attr.slot()->toTensor());
+      curr->register_buffer(attr.name(), attr.slot().value().toTensor());
       parameter_remap[attr.slot()] = curr->find_buffer(attr.name())->slot();
     }
     for (auto& mod : get_modules()) {
index dc70e7b..67fd170 100644 (file)
@@ -7,36 +7,39 @@ namespace script {
 
 // a stable location that can hold an IValue.
 // Currently this is internally implemented as a pointer, but when
-// modules become first-class this will be a pair of  <module_ivalue, slot_number>
+// modules become first-class this will be a pair of  <module_ivalue,
+// slot_number>
 struct Slot {
   friend struct NamedIValue;
-  Slot()
-  : slot_(nullptr) {}
-  Slot(at::IValue* slot)
-  : slot_(slot) {}
-  at::IValue& operator*() const {
-    return *slot_;
-  }
-  at::IValue* operator->() const {
-      return slot_;
-  }
+  Slot() : slot_(nullptr) {}
+  Slot(at::IValue* slot) : slot_(slot) {}
+
   bool operator==(const Slot& rhs) const {
     return slot_ == rhs.slot_;
   }
-private:
+  void setValue(at::IValue v) {
+    *slot_ = std::move(v);
+  }
+  const at::IValue& value() const {
+    return *slot_;
+  }
+
+ private:
   at::IValue* slot_;
   friend struct std::hash<Slot>;
 };
 
-}}}
+} // namespace script
+} // namespace jit
+} // namespace torch
 
 // slots are hashable, because they are often used as keys in maps
 // for remapping uses of a slot from one model to another
 namespace std {
-  template <>
-  struct hash<torch::jit::script::Slot> {
-    size_t operator()(const torch::jit::script::Slot& s) const noexcept {
-      return std::hash<at::IValue*>{}(s.slot_);
-    }
-  };
+template <>
+struct hash<torch::jit::script::Slot> {
+  size_t operator()(const torch::jit::script::Slot& s) const noexcept {
+    return std::hash<at::IValue*>{}(s.slot_);
+  }
+};
 } // namespace std