Store the optimize flag in module (#14166)
authorLu Fang <lufang@fb.com>
Mon, 19 Nov 2018 22:29:31 +0000 (14:29 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 19 Nov 2018 22:34:05 +0000 (14:34 -0800)
Summary:
When the save/load of script module, we store optimize flag in module instead of encoding it in method.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/14166

Reviewed By: ezyang

Differential Revision: D13117577

Pulled By: dzhulgakov

fbshipit-source-id: dc322948bda0ac5809d8ef9a345497ebb8f33a61

caffe2/proto/torch.proto
torch/csrc/jit/export.cpp
torch/csrc/jit/import.cpp
torch/csrc/jit/script/module.h

index 3e70cc6..f7e2e8c 100644 (file)
@@ -56,6 +56,10 @@ message ModuleDef {
   // from the main method.
 
   optional string name = 6;
+
+  // whether apply the optimizations to this module, only applicable to
+  // script modules
+  optional bool optimize = 7;
 }
 
 enum ProtoVersion {
index aced0bc..ac28640 100644 (file)
@@ -618,10 +618,6 @@ void MethodEncoder::EncodeMethod(
   model_proto.set_doc_string("THIS PROTO IS NOT STANDARD ONNX");
   auto* node_proto = model_proto.mutable_graph()->add_node();
   node_proto->set_name(prefix + method.name());
-  if (method.is_optimized()) {
-    // mark that this method was optimized
-    node_proto->set_domain("optimized");
-  }
 
   // We store the schema string in the docstring.
   node_proto->set_doc_string(getExportableSchemaStringForMethod(method));
@@ -920,6 +916,7 @@ class ScriptModuleSerializer final {
       const std::string& name,
       torch::ModuleDef* module_def) {
     module_def->set_name(name);
+    module_def->set_optimize(module.is_optimized());
     for (const auto& elem : module.get_parameters()) {
       torch::ParameterDef* param_def = module_def->add_parameters();
       convertParameter(elem.value(), param_def);
index 759d992..8e62893 100644 (file)
@@ -351,8 +351,6 @@ MethodDecoder::MethodDecoder(
       member_inputs.push_back(it->second);
     }
     auto graph = buildGraph(node_proto.attribute(0).g());
-    // has_domain field has a string iff the method was optimized
-    parent_module->set_optimized(node_proto.has_domain());
     parent_module->create_method(name, graph, member_inputs);
     // We store the schema in the docstring so we can parse the schema and
     // assign it to the method.
@@ -437,6 +435,7 @@ class ScriptModuleDeserializer final {
   void convertModule(
       const torch::ModuleDef& module_def,
       script::Module* module) {
+    module->set_optimized(module_def.optimize());
     for (int i = 0; i < module_def.methods_size(); ++i) {
       const torch::MethodDef& method_def = module_def.methods(i);
       // TODO read unhacked torch script, right now it's serialized onnx proto
index 774ab8d..e9da1e5 100644 (file)
@@ -316,6 +316,10 @@ struct Module {
     optimize = o;
   }
 
+  bool is_optimized() const {
+    return optimize; 
+  }
+
   IValue forward(std::vector<IValue> inputs) {
     return get_method("forward")(inputs);
   }