Handling of pretty-printing methods (#14378)
authorZachary DeVito <zdevito@fb.com>
Wed, 28 Nov 2018 01:08:09 +0000 (17:08 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 28 Nov 2018 01:10:23 +0000 (17:10 -0800)
Summary:
Stacked on #14176, review only the last commit.
* Print parameters to methods as self.weight rather than as extra inputs.
* Print entire set of methods out as a single string
* Update test code to test the module-at-a-time export/import
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14378

Differential Revision: D13198463

Pulled By: zdevito

fbshipit-source-id: 3fab02e8239cfd6f40d6ab6399047bd02cf0a8c8

test/test_jit.py
torch/csrc/jit/import_method.cpp
torch/csrc/jit/import_method.h
torch/csrc/jit/passes/python_print.cpp
torch/csrc/jit/passes/python_print.h
torch/csrc/jit/script/compiler.cpp
torch/csrc/jit/script/init.cpp
torch/csrc/jit/script/module.h

index d33fa11..db2f9e1 100644 (file)
@@ -238,29 +238,32 @@ class JitTestCase(TestCase):
         torch._C._jit_set_emit_module_hook(self.emitModuleHook)
 
     def emitModuleHook(self, module):
+
+        def copy_structure_and_params(m):
+            c = torch.jit.ScriptModule()
+            for name, v, buffer in m._get_parameters():
+                c._register_parameter(name, v, buffer)
+            for name, s in m._get_modules():
+                c._register_module(name, copy_structure_and_params(s))
+            return c
+
         # disable the hook while we parse code, otherwise we will re-enter the hook
         with self.disableModuleHook():
-            for name in module._method_names():
-                method = module._get_method(name)
-                try:
-                    pp, constant_table = method.python_print()
-                except RuntimeError as e:
-                    if "could not export python function" not in str(e):
-                        raise
-                    else:
-                        continue
-
-                ppv = "op_version_set = 0\n{}".format(pp)
-                sm = torch.jit.ScriptModule()
-                torch._C._jit_import_method(sm, ppv, constant_table)
-                method2 = sm._get_method(name)
-                pp2, _ = method2.python_print()
-                if pp != pp2:
-                    print(method.graph)
-                    print(pp)
-                    print(method2.graph)
-                    print(pp2)
-                    self.assertMultiLineEqual(pp, pp2)
+            try:
+                pp, constant_table = module._python_print()
+            except RuntimeError as e:
+                if "could not export python function" not in str(e):
+                    raise
+                else:
+                    return
+
+            ppv = "op_version_set = 0\n{}".format(pp)
+            sm = copy_structure_and_params(module)
+            torch._C._jit_import_methods(sm, ppv, constant_table)
+
+            pp2, _ = sm._python_print()
+            if pp != pp2:
+                self.assertMultiLineEqual(pp, pp2)
 
     def getExportImportCopy(self, m):
         # Ideally we would like to not have to manually delete the file, but NamedTemporaryFile
@@ -2185,10 +2188,10 @@ class TestJit(JitTestCase):
         def foo(x, y):
             return 2 * x + y
 
-        r = foo.graph.pretty_print()
+        r, _ = foo._python_print()
         mod = torch.jit.ScriptModule()
-        torch._C._jit_import_method(mod, "op_version_set = 0\n{}".format(r), [])
-        self.assertExpected(mod.graph.graph.pretty_print())
+        torch._C._jit_import_methods(mod, "op_version_set = 0\n{}".format(r), [])
+        self.assertExpected(mod.graph.pretty_print())
 
     def test_function_default_values(self):
         outer_var = torch.tensor(20)
@@ -2942,7 +2945,7 @@ class TestScript(JitTestCase):
             pp, table = foo._get_method('forward').python_print()
             ppv = "op_version_set = 0\n{}".format(pp)
             sm = torch.jit.ScriptModule()
-            torch._C._jit_import_method(sm, ppv, table)
+            torch._C._jit_import_methods(sm, ppv, table)
             r = foo()
             r2 = sm()
             # use precise assert, we are checking floating point details
index 3dccb98..9abb9ab 100644 (file)
@@ -20,8 +20,9 @@ struct ModuleAccessorValue : public script::SugaredValue {
       return std::make_shared<script::SimpleValue>(m.get_or_add_parameter(v->slot()));
     } else if(script::Method* m = module->find_method(field)) {
       return std::make_shared<script::MethodValue>(module, *m);
+    } else {
+      throw script::ErrorReport(loc) << "unknown attr: " << field;
     }
-    return script::SugaredValue::attr(loc, m, field);
   }
 private:
   std::shared_ptr<script::Module> module;
@@ -80,7 +81,7 @@ static size_t parseVersionNumber(script::Lexer& L) {
    return size_t(version.asIntegral());
 }
 
-void import_method(const std::shared_ptr<script::Module>& mod, const std::string& src, const std::vector<at::Tensor>& constant_table) {
+void import_methods(const std::shared_ptr<script::Module>& mod, const std::string& src, const std::vector<at::Tensor>& constant_table) {
   script::Parser p(src);
 
   size_t version = parseVersionNumber(p.lexer());
index 16deb8d..6ee08d4 100644 (file)
@@ -7,7 +7,7 @@
 namespace torch {
 namespace jit {
 
-TORCH_API void import_method(const std::shared_ptr<script::Module>& mod, const std::string& src, const std::vector<at::Tensor>& constant_table);
+TORCH_API void import_methods(const std::shared_ptr<script::Module>& mod, const std::string& src, const std::vector<at::Tensor>& constant_table);
 
 
 } // namespace jit
index 9553892..d727b78 100644 (file)
@@ -3,6 +3,7 @@
 #include "torch/csrc/jit/generic_if.h"
 #include "torch/csrc/jit/ir.h"
 #include "torch/csrc/jit/ir_views.h"
+#include "torch/csrc/jit/export.h"
 #include "torch/csrc/jit/resource_guard.h"
 #include "torch/csrc/jit/script/error_report.h"
 #include "torch/csrc/jit/script/module.h"
 namespace torch {
 namespace jit {
 
+// unix isprint but insensitive to locale
+static bool isPrint(char s) {
+  return s > 0x1f && s < 0x7f;
+}
+
+void printQuotedString(std::ostream& stmt, const std::string& str) {
+  stmt << "\"";
+  for(auto s : str) {
+    switch (s) {
+      case '\\':
+        stmt << "\\\\";
+        break;
+      case '\'':
+        stmt << "\\'";
+        break;
+      case '\"':
+        stmt << "\\\"";
+        break;
+      case '\a':
+        stmt << "\\a";
+        break;
+      case '\b':
+        stmt << "\\b";
+        break;
+      case '\f':
+        stmt << "\\f";
+        break;
+      case '\n':
+        stmt << "\\n";
+        break;
+      case '\r':
+        stmt << "\\r";
+        break;
+      case '\t':
+        stmt << "\\t";
+        break;
+      case '\v':
+        stmt << "\\v";
+        break;
+      default:
+        if (isPrint(s)) {
+          stmt << s;
+        } else {
+          // C++ io has stateful formatting settings. Messing with
+          // them is probably worse than doing this manually.
+          char buf[4] = "000";
+          buf[2] += s % 8; s /= 8;
+          buf[1] += s % 8; s /= 8;
+          buf[0] += s;
+          stmt << "\\" << buf;
+        }
+        break;
+    }
+  }
+  stmt << "\"";
+}
+
+static bool isValidIdentifierChar(char c, size_t pos) {
+  return islower(c) || isupper(c) || c == '_' ||  (pos > 0 && isdigit(c));
+}
+
+static bool isValidIdentifier(const std::string & name) {
+  if (name.size() == 0)
+    return false;
+  for (size_t i = 0; i < name.size(); ++i) {
+    if (!isValidIdentifierChar(name[i], i))
+      return false;
+  }
+  return true;
+}
+
+// handles names of the form, e.g., self.a.b
+// if a field is not a valid identifier, then it will print as, e.g.
+// getattr(self, "0").b
+struct QualifiedName;
+using QualifiedNamePtr = c10::intrusive_ptr<QualifiedName>;
+struct QualifiedName : c10::intrusive_ptr_target {
+  QualifiedName(QualifiedNamePtr prefix, std::string name)
+  : prefix_(std::move(prefix)), name_(std::move(name)) {}
+  QualifiedNamePtr prefix_;
+  std::string name_;
+  static QualifiedNamePtr create(QualifiedNamePtr prefix, std::string name) {
+    return c10::make_intrusive<QualifiedName>(std::move(prefix), std::move(name));
+  }
+  static QualifiedNamePtr create(std::string name) {
+    return c10::make_intrusive<QualifiedName>(QualifiedNamePtr(), std::move(name));
+  }
+  std::string str() const {
+    std::stringstream ss;
+    emit(ss);
+    return ss.str();
+  }
+private:
+  void emit(std::ostream& out) const {
+    if (isValidIdentifier(name_)) {
+      if (prefix_) {
+        prefix_->emit(out);
+        out << ".";
+      }
+      out << name_;
+    } else {
+      JIT_ASSERT(prefix_);
+      out << "getattr(";
+      prefix_->emit(out);
+      out << ", ";
+      printQuotedString(out, name_);
+      out << ")";
+    }
+  }
+};
+
+void createTensorToParameterNameMap(
+    const script::Module& module,
+    QualifiedNamePtr prefix,
+    std::unordered_map<at::Tensor*, QualifiedNamePtr>& result) {
+
+  for (const auto& elem : module.get_parameters()) {
+    const script::NamedParameter& param = elem.value();
+    result[param.slot()] = QualifiedName::create(prefix, param.name);
+  }
+  for (const auto& elem : module.get_modules()) {
+    createTensorToParameterNameMap(
+        *elem->module, QualifiedName::create(prefix, elem.key()), result);
+  }
+}
+
   // some names are valid identifiers but off limits because
   // they are keywords or namespaces used in the output
   const static std::unordered_set<std::string> reserved_names = {
@@ -20,6 +147,7 @@ namespace jit {
     "CONSTANTS",
     "fork",
     "attribute",
+    "getattr",
     "_", // avoid the confusing unnamed _
     "inf",
     "nan",
@@ -495,63 +623,6 @@ struct PythonPrintPass {
     }
   }
 
-  // unix isprint but insensitive to locale
-  static bool isPrint(char s) {
-    return s > 0x1f && s < 0x7f;
-  }
-
-  void printQuotedString(std::ostream& stmt, const std::string& str) {
-    stmt << "\"";
-    for(auto s : str) {
-      switch (s) {
-        case '\\':
-          stmt << "\\\\";
-          break;
-        case '\'':
-          stmt << "\\'";
-          break;
-        case '\"':
-          stmt << "\\\"";
-          break;
-        case '\a':
-          stmt << "\\a";
-          break;
-        case '\b':
-          stmt << "\\b";
-          break;
-        case '\f':
-          stmt << "\\f";
-          break;
-        case '\n':
-          stmt << "\\n";
-          break;
-        case '\r':
-          stmt << "\\r";
-          break;
-        case '\t':
-          stmt << "\\t";
-          break;
-        case '\v':
-          stmt << "\\v";
-          break;
-        default:
-          if (isPrint(s)) {
-            stmt << s;
-          } else {
-            // C++ io has stateful formatting settings. Messing with
-            // them is probably worse than doing this manually.
-            char buf[4] = "000";
-            buf[2] += s % 8; s /= 8;
-            buf[1] += s % 8; s /= 8;
-            buf[0] += s;
-            stmt << "\\" << buf;
-          }
-          break;
-      }
-    }
-    stmt << "\"";
-  }
-
   void printConstant(std::ostream& stmt, IValue v) {
     if(v.isTensor()) {
       stmt << "CONSTANTS.c" << getOrAddTensorConstant(v.toTensor());
@@ -687,7 +758,7 @@ struct PythonPrintPass {
         auto name = genMethodName("__forked_function");
         std::shared_ptr<Graph> graph = node->g(attr::Subgraph);
         worklist.emplace_back([graph, name, this] {
-          printOneFunction(*graph, name);
+          printFunctionDefinition(*graph, name);
         });
         // and we put a call to fork which invokes that function.
         stmt << "fork(self." << name;
@@ -743,7 +814,12 @@ struct PythonPrintPass {
     }
     printConstant(stmt, value);
   }
-  void printOneFunction(Graph& graph, const std::string& name, const std::vector<c10::optional<IValue>> defaults = {}) {
+  void printFunctionDefinition(
+      Graph& graph,
+      const std::string& name,
+      const std::vector<c10::optional<IValue>> defaults = {},
+      const std::vector<std::string>& param_names = {}) {
+
     used_names_.clear(); // each graph can reuse local names
 
     // we always print constants at the top of the function, in the order
@@ -753,10 +829,19 @@ struct PythonPrintPass {
 
     // current graph is used to de-dup names within a single graph
     scanBlock(graph.block());
-    assignValuesToTheirUniqueNames(graph.inputs());
+
+    // last param_names.size() arguments to the graph are parameters and not
+    // actual inputs, we will print these as, e.g. self.foo.bar
+    // while we print the true_inputs out as parameters
+    auto true_inputs = graph.inputs().slice(0, graph.inputs().size() - param_names.size());
+    auto param_names_it = param_names.begin();
+    for(auto param : graph.inputs().slice(true_inputs.size())) {
+      assignValue(param, *param_names_it++);
+    }
+    assignValuesToTheirUniqueNames(true_inputs);
     out << "def " << name << "(self";
     auto defaults_offset = defaults.begin();
-    for (auto input : graph.inputs()) {
+    for (auto input : true_inputs) {
       out << ",\n    " << useOf(input) << ": " << input->type()->python_str();
       if (defaults_offset != defaults.end()) {
         const c10::optional<IValue>& def = *defaults_offset++;
@@ -766,6 +851,10 @@ struct PythonPrintPass {
         }
       }
     }
+
+    // have we use all the provided defaults?
+    JIT_ASSERT(defaults_offset == defaults.end());
+
     out << ") -> " << resultType(graph)->python_str() << ":\n";
     {
       auto guard = WithIndented();
@@ -798,8 +887,12 @@ struct PythonPrintPass {
     }
   }
 
-  void printFunction(Graph& graph, const std::string& name, const std::vector<c10::optional<IValue>>& defaults = {}) {
-    printOneFunction(graph, name, defaults);
+  void printFunction(
+      Graph& graph,
+      const std::string& name,
+      const std::vector<c10::optional<IValue>>& defaults = {},
+      const std::vector<std::string>& param_names = {}) {
+    printFunctionDefinition(graph, name, defaults, param_names);
     while(!worklist.empty()) {
       out << "\n\n";
       auto work = worklist.back();
@@ -808,12 +901,36 @@ struct PythonPrintPass {
     }
   }
   void printMethod(script::Method& method) {
+    std::unordered_map<at::Tensor*, QualifiedNamePtr> parameter_names;;
+    createTensorToParameterNameMap(method.owner(), QualifiedName::create("self"),  parameter_names);
+    printMethod(method, parameter_names);
+  }
+  void printMethod(
+      script::Method& method,
+      const std::unordered_map<at::Tensor*, QualifiedNamePtr>&
+          parameter_names) {
+    std::vector<std::string> param_names = fmap(
+        method.params(),
+        [&](at::Tensor* slot) { return parameter_names.at(slot)->str(); });
     const std::string& name = method.name();
     Graph& graph = *method.graph();
     auto defaults = fmap(method.getSchema().arguments(), [](const Argument& arg) {
       return arg.default_value();
     });
-    printFunction(graph, name, defaults);
+    printFunction(graph, name, defaults, param_names);
+  }
+  void printModule(script::Module& module) {
+    std::unordered_map<at::Tensor*, QualifiedNamePtr> parameter_names;;
+    createTensorToParameterNameMap(module, QualifiedName::create("self"),  parameter_names);
+    for(auto& method : module.get_methods()) {
+      const std::string& name = method.value()->name();
+      // we skip __forked_functions because they actually get inlined into their
+      // callers, exporting them again will lead to more code generated on each export
+      if (name.find("__forked_function") == 0) {
+        continue;
+      }
+      printMethod(*method.value(), parameter_names);
+    }
   }
 };
 
@@ -828,6 +945,12 @@ TORCH_API std::vector<at::Tensor> PythonPrint(std::ostream& out, script::Method&
   return pp.tensor_constants;
 }
 
+TORCH_API std::vector<at::Tensor> PythonPrint(std::ostream& out, script::Module& module, bool enforce_importable) {
+  PythonPrintPass pp(out, enforce_importable);
+  pp.printModule(module);
+  return pp.tensor_constants;
+}
+
 TORCH_API bool printerHasSpecialCaseFor(Symbol sym) {
   // WARNING: by adding a value to this set, you are asserting
   // that you have also added special handling of this symbol to
index 54463d8..e1c7ef6 100644 (file)
@@ -9,9 +9,12 @@ namespace torch { namespace jit {
 
 namespace script {
   struct Method;
+  struct Module;
 }
 
 TORCH_API std::vector<at::Tensor> PythonPrint(std::ostream& out, Graph& graph, bool enforce_importable=false);
 TORCH_API std::vector<at::Tensor> PythonPrint(std::ostream& out, script::Method& graph, bool enforce_importable=false);
+TORCH_API std::vector<at::Tensor> PythonPrint(std::ostream& out, script::Module& module, bool enforce_importable=false);
+
 TORCH_API bool printerHasSpecialCaseFor(c10::Symbol sym);
 }}
index 7a1d418..ce3d49a 100644 (file)
@@ -35,6 +35,13 @@ struct NoneValue : SugaredValue {
   }
 };
 
+// matched against for special handling of getattr expressions
+struct GetAttrValue : SugaredValue {
+  std::string kind() const override {
+    return "getattr";
+  }
+};
+
 struct PrintValue : public SugaredValue {
   std::string kind() const override {
     return "print";
@@ -348,6 +355,7 @@ struct Environment {
         {"float", std::make_shared<CastValue>(FloatType::get())},
         {"int", std::make_shared<CastValue>(IntType::get())},
         {"bool", std::make_shared<CastValue>(BoolType::get())},
+        {"getattr", std::make_shared<GetAttrValue>()},
         // todo(zach): remove when we can correctly export torch.full via ONNX
         // or we have implicit conversion that can convert numbers to tensors
         {"_to_tensor", std::make_shared<CastValue>(DynamicType::get()) },
@@ -2010,6 +2018,20 @@ private:
             << " but found " << expr->type()->python_str();
       }
       return std::make_shared<SimpleValue>(expr);
+    } else if(auto getattr = dynamic_cast<GetAttrValue*>(sv.get())) {
+      if (apply.attributes().size() > 0) {
+        throw ErrorReport(loc) << "getattr takes no keyword arguments";
+      }
+      if (apply.inputs().size() != 2) {
+        throw ErrorReport(loc) << "getattr expects 2 inputs";
+      }
+      auto obj = emitSugaredExpr(apply.inputs()[0], 1);
+      auto selector = apply.inputs()[1];
+      if (selector.kind() != TK_STRINGLITERAL) {
+        throw ErrorReport(loc) << "getattr's second argument must be a string literal";
+      }
+      const std::string& name = StringLiteral(selector).text();
+      return obj->attr(apply.range(), method, name);
     } else {
       auto inputs = getNamedValues(apply.inputs(), true);
       auto attributes = emitAttributes(apply.attributes());
index 459d5c7..1e05cb0 100644 (file)
@@ -478,7 +478,7 @@ void initJitScriptBindings(PyObject* module) {
         py::tuple result(modules.size());
         for(size_t i = 0; i < modules.size(); ++i) {
           auto & item = modules[i];
-          result[i] = std::make_pair(item.key(), item.value());
+          result[i] = std::make_pair(item.key(), item.value().module);
         }
         return result;
       })
@@ -584,6 +584,11 @@ void initJitScriptBindings(PyObject* module) {
         // see: [pybind11 varargs]
         Module& self = py::cast<Module&>(args[0]);
         return invokeScriptMethodFromPython(self.get_method("forward"), tuple_slice(std::move(args), 1), std::move(kwargs));
+      })
+      .def("_python_print", [](Module& self) {
+        std::ostringstream ss;
+        std::vector<at::Tensor> tensors = PythonPrint(ss, self, true);
+        return std::make_pair(ss.str(), tensors);
       });
 
   py::class_<Method>(m, "ScriptMethod", py::dynamic_attr())
@@ -637,7 +642,7 @@ void initJitScriptBindings(PyObject* module) {
     std::istringstream in(buffer);
     import_ir_module(module_lookup, in);
   });
-  m.def("_jit_import_method", import_method);
+  m.def("_jit_import_methods", import_methods);
   m.def("_jit_set_emit_module_hook", setEmitModuleHook);
 }
 
index 264d57d..37d6c60 100644 (file)
@@ -39,12 +39,15 @@ namespace torch { namespace jit { namespace script {
 // Note: because Method/Module are exposed to python these
 // classes use python method naming conventions
 
+struct Module;
+
 struct Method {
-  Method(std::string name, bool optimize,
+  Method(Module* owner, std::string name, bool optimize,
          std::shared_ptr<Graph> graph,
          std::vector<at::Tensor*> initial_members,
          std::function<void(Method&)> method_creator)
-  : name_(std::move(name))
+  : owner_(owner)
+  , name_(std::move(name))
   , graph_(std::move(graph))
   , optimize(optimize)
   , member_inputs(std::move(initial_members))
@@ -186,6 +189,11 @@ struct Method {
     return optimize;
   }
 
+  // the module that contains this method.
+  Module& owner() const {
+    return *owner_;
+  }
+
 private:
 
   static FunctionSchema defaultSchemaFor(const Method& method) {
@@ -204,10 +212,6 @@ private:
     return { method.name(), std::move(args), std::move(returns) };
   }
 
-  std::string name_;
-  std::shared_ptr<Graph> graph_; // for debugging and for inlining
-  bool optimize;
-
   GraphExecutor& get_executor() {
     std::call_once(executor_init, [&]{
       executor = GraphExecutor(graph(), optimize);
@@ -243,6 +247,15 @@ private:
     }
   }
 
+
+  // Methods are uniqued onwed by a single module. This raw pointer allows
+  // looking up the module.
+  Module* owner_;
+
+  std::string name_;
+  std::shared_ptr<Graph> graph_; // for debugging and for inlining
+  bool optimize;
+
   GraphExecutor executor; // for execution
   // member_inputs are a list of additional arguments appended to graph that are
   // inputs that come from the members of the Module or its submodules.
@@ -338,12 +351,12 @@ struct Module {
 
   Method& create_method(const std::string & name, std::shared_ptr<Graph> graph, std::vector<at::Tensor*> member_inputs) {
     JIT_ASSERT(graph);
-    std::unique_ptr<Method> method(new Method(name, optimize, std::move(graph), std::move(member_inputs), nullptr));
+    std::unique_ptr<Method> method(new Method(this, name, optimize, std::move(graph), std::move(member_inputs), nullptr));
     return *methods.insert(name, std::move(method));
   }
 
   Method& create_method(const std::string & name, std::function<void(Method&)> creator) {
-    std::unique_ptr<Method> method(new Method(name, optimize, std::make_shared<Graph>(), {}, creator));
+    std::unique_ptr<Method> method(new Method(this, name, optimize, std::make_shared<Graph>(), {}, creator));
     return *methods.insert(name, std::move(method));
   }