torch._C._jit_set_emit_module_hook(self.emitModuleHook)
def emitModuleHook(self, module):
+
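+ # rebuild the module's parameter slots and submodule structure (but not its
+ # methods) so the printed methods can be re-imported against the same slots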
+ def copy_structure_and_params(m):
+ c = torch.jit.ScriptModule()
+ for name, v, is_buffer in m._get_parameters():
+ c._register_parameter(name, v, is_buffer)
+ for name, s in m._get_modules():
+ c._register_module(name, copy_structure_and_params(s))
+ return c
+
# disable the hook while we parse code, otherwise we will re-enter the hook
with self.disableModuleHook():
- for name in module._method_names():
- method = module._get_method(name)
- try:
- pp, constant_table = method.python_print()
- except RuntimeError as e:
- if "could not export python function" not in str(e):
- raise
- else:
- continue
-
- ppv = "op_version_set = 0\n{}".format(pp)
- sm = torch.jit.ScriptModule()
- torch._C._jit_import_method(sm, ppv, constant_table)
- method2 = sm._get_method(name)
- pp2, _ = method2.python_print()
- if pp != pp2:
- print(method.graph)
- print(pp)
- print(method2.graph)
- print(pp2)
- self.assertMultiLineEqual(pp, pp2)
+ try:
+ pp, constant_table = module._python_print()
+ except RuntimeError as e:
+ if "could not export python function" not in str(e):
+ raise
+ else:
+ return
+
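+ # prepend the op_version_set header the import parser expects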
+ ppv = "op_version_set = 0\n{}".format(pp)
+ sm = copy_structure_and_params(module)
+ torch._C._jit_import_methods(sm, ppv, constant_table)
+
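+ # printing the re-imported module must reproduce the source exactly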
+ pp2, _ = sm._python_print()
+ self.assertMultiLineEqual(pp, pp2)
def getExportImportCopy(self, m):
# Ideally we would like to not have to manually delete the file, but NamedTemporaryFile
def foo(x, y):
return 2 * x + y
- r = foo.graph.pretty_print()
+ r, _ = foo._python_print()
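+ # round-trip the printed source through a fresh ScriptModule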
mod = torch.jit.ScriptModule()
- torch._C._jit_import_method(mod, "op_version_set = 0\n{}".format(r), [])
- self.assertExpected(mod.graph.graph.pretty_print())
+ torch._C._jit_import_methods(mod, "op_version_set = 0\n{}".format(r), [])
+ self.assertExpected(mod.graph.pretty_print())
def test_function_default_values(self):
outer_var = torch.tensor(20)
pp, table = foo._get_method('forward').python_print()
ppv = "op_version_set = 0\n{}".format(pp)
sm = torch.jit.ScriptModule()
- torch._C._jit_import_method(sm, ppv, table)
+ torch._C._jit_import_methods(sm, ppv, table)
r = foo()
r2 = sm()
# use precise assert, we are checking floating point details
return std::make_shared<script::SimpleValue>(m.get_or_add_parameter(v->slot()));
} else if(script::Method* m = module->find_method(field)) {
return std::make_shared<script::MethodValue>(module, *m);
+ } else {
+ throw script::ErrorReport(loc) << "unknown attr: " << field;
}
- return script::SugaredValue::attr(loc, m, field);
}
private:
std::shared_ptr<script::Module> module;
return size_t(version.asIntegral());
}
-void import_method(const std::shared_ptr<script::Module>& mod, const std::string& src, const std::vector<at::Tensor>& constant_table) {
+void import_methods(const std::shared_ptr<script::Module>& mod, const std::string& src, const std::vector<at::Tensor>& constant_table) {
script::Parser p(src);
size_t version = parseVersionNumber(p.lexer());
namespace torch {
namespace jit {
-TORCH_API void import_method(const std::shared_ptr<script::Module>& mod, const std::string& src, const std::vector<at::Tensor>& constant_table);
+TORCH_API void import_methods(const std::shared_ptr<script::Module>& mod, const std::string& src, const std::vector<at::Tensor>& constant_table);
} // namespace jit
#include "torch/csrc/jit/generic_if.h"
#include "torch/csrc/jit/ir.h"
#include "torch/csrc/jit/ir_views.h"
+#include "torch/csrc/jit/export.h"
#include "torch/csrc/jit/resource_guard.h"
#include "torch/csrc/jit/script/error_report.h"
#include "torch/csrc/jit/script/module.h"
namespace torch {
namespace jit {
+// unix isprint but insensitive to locale
+static bool isPrint(char s) {
+ return s > 0x1f && s < 0x7f;
+}
+
+void printQuotedString(std::ostream& stmt, const std::string& str) {
+ stmt << "\"";
+ for(auto s : str) {
+ switch (s) {
+ case '\\':
+ stmt << "\\\\";
+ break;
+ case '\'':
+ stmt << "\\'";
+ break;
+ case '\"':
+ stmt << "\\\"";
+ break;
+ case '\a':
+ stmt << "\\a";
+ break;
+ case '\b':
+ stmt << "\\b";
+ break;
+ case '\f':
+ stmt << "\\f";
+ break;
+ case '\n':
+ stmt << "\\n";
+ break;
+ case '\r':
+ stmt << "\\r";
+ break;
+ case '\t':
+ stmt << "\\t";
+ break;
+ case '\v':
+ stmt << "\\v";
+ break;
+ default:
+ if (isPrint(s)) {
+ stmt << s;
+ } else {
+ // C++ io has stateful formatting settings. Messing with
+ // them is probably worse than doing this manually.
+ char buf[4] = "000";
+ buf[2] += s % 8; s /= 8;
+ buf[1] += s % 8; s /= 8;
+ buf[0] += s;
+ stmt << "\\" << buf;
+ }
+ break;
+ }
+ }
+ stmt << "\"";
+}
+
+static bool isValidIdentifierChar(char c, size_t pos) {
+ return islower(c) || isupper(c) || c == '_' || (pos > 0 && isdigit(c));
+}
+
+static bool isValidIdentifier(const std::string & name) {
+ if (name.size() == 0)
+ return false;
+ for (size_t i = 0; i < name.size(); ++i) {
+ if (!isValidIdentifierChar(name[i], i))
+ return false;
+ }
+ return true;
+}
+
+// handles dotted names of the form self.a.b; if a field is not a valid
+// identifier it is printed as, e.g., getattr(self, "0").b instead
+struct QualifiedName;
+using QualifiedNamePtr = c10::intrusive_ptr<QualifiedName>;
+struct QualifiedName : c10::intrusive_ptr_target {
+ QualifiedName(QualifiedNamePtr prefix, std::string name)
+ : prefix_(std::move(prefix)), name_(std::move(name)) {}
+ QualifiedNamePtr prefix_;
+ std::string name_;
+ static QualifiedNamePtr create(QualifiedNamePtr prefix, std::string name) {
+ return c10::make_intrusive<QualifiedName>(std::move(prefix), std::move(name));
+ }
+ static QualifiedNamePtr create(std::string name) {
+ return c10::make_intrusive<QualifiedName>(QualifiedNamePtr(), std::move(name));
+ }
+ std::string str() const {
+ std::stringstream ss;
+ emit(ss);
+ return ss.str();
+ }
+private:
+ void emit(std::ostream& out) const {
+ if (isValidIdentifier(name_)) {
+ if (prefix_) {
+ prefix_->emit(out);
+ out << ".";
+ }
+ out << name_;
+ } else {
+ JIT_ASSERT(prefix_);
+ out << "getattr(";
+ prefix_->emit(out);
+ out << ", ";
+ printQuotedString(out, name_);
+ out << ")";
+ }
+ }
+};
+
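+// recursively assign every parameter slot a dotted name rooted at the given
+// prefix, e.g. self.layer1.weight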
+void createTensorToParameterNameMap(
+ const script::Module& module,
+ QualifiedNamePtr prefix,
+ std::unordered_map<at::Tensor*, QualifiedNamePtr>& result) {
+
+ for (const auto& elem : module.get_parameters()) {
+ const script::NamedParameter& param = elem.value();
+ result[param.slot()] = QualifiedName::create(prefix, param.name);
+ }
+ for (const auto& elem : module.get_modules()) {
+ createTensorToParameterNameMap(
+ *elem->module, QualifiedName::create(prefix, elem.key()), result);
+ }
+}
+
// some names are valid identifiers but off limits because
// they are keywords or namespaces used in the output
const static std::unordered_set<std::string> reserved_names = {
"CONSTANTS",
"fork",
"attribute",
+ "getattr",
"_", // avoid the confusing unnamed _
"inf",
"nan",
}
}
- // unix isprint but insensitive to locale
- static bool isPrint(char s) {
- return s > 0x1f && s < 0x7f;
- }
-
- void printQuotedString(std::ostream& stmt, const std::string& str) {
- stmt << "\"";
- for(auto s : str) {
- switch (s) {
- case '\\':
- stmt << "\\\\";
- break;
- case '\'':
- stmt << "\\'";
- break;
- case '\"':
- stmt << "\\\"";
- break;
- case '\a':
- stmt << "\\a";
- break;
- case '\b':
- stmt << "\\b";
- break;
- case '\f':
- stmt << "\\f";
- break;
- case '\n':
- stmt << "\\n";
- break;
- case '\r':
- stmt << "\\r";
- break;
- case '\t':
- stmt << "\\t";
- break;
- case '\v':
- stmt << "\\v";
- break;
- default:
- if (isPrint(s)) {
- stmt << s;
- } else {
- // C++ io has stateful formatting settings. Messing with
- // them is probably worse than doing this manually.
- char buf[4] = "000";
- buf[2] += s % 8; s /= 8;
- buf[1] += s % 8; s /= 8;
- buf[0] += s;
- stmt << "\\" << buf;
- }
- break;
- }
- }
- stmt << "\"";
- }
-
void printConstant(std::ostream& stmt, IValue v) {
if(v.isTensor()) {
stmt << "CONSTANTS.c" << getOrAddTensorConstant(v.toTensor());
auto name = genMethodName("__forked_function");
std::shared_ptr<Graph> graph = node->g(attr::Subgraph);
worklist.emplace_back([graph, name, this] {
- printOneFunction(*graph, name);
+ printFunctionDefinition(*graph, name);
});
// and we put a call to fork which invokes that function.
stmt << "fork(self." << name;
}
printConstant(stmt, value);
}
- void printOneFunction(Graph& graph, const std::string& name, const std::vector<c10::optional<IValue>> defaults = {}) {
+ void printFunctionDefinition(
+ Graph& graph,
+ const std::string& name,
+ const std::vector<c10::optional<IValue>> defaults = {},
+ const std::vector<std::string>& param_names = {}) {
+
used_names_.clear(); // each graph can reuse local names
// we always print constants at the top of the function, in the order
// current graph is used to de-dup names within a single graph
scanBlock(graph.block());
- assignValuesToTheirUniqueNames(graph.inputs());
+
+ // the last param_names.size() inputs to the graph are module parameters,
+ // not actual call arguments; we print them as, e.g., self.foo.bar,
+ // while the true_inputs become the function's formal parameters
+ auto true_inputs = graph.inputs().slice(0, graph.inputs().size() - param_names.size());
+ auto param_names_it = param_names.begin();
+ for(auto param : graph.inputs().slice(true_inputs.size())) {
+ assignValue(param, *param_names_it++);
+ }
+ assignValuesToTheirUniqueNames(true_inputs);
out << "def " << name << "(self";
auto defaults_offset = defaults.begin();
- for (auto input : graph.inputs()) {
+ for (auto input : true_inputs) {
out << ",\n " << useOf(input) << ": " << input->type()->python_str();
if (defaults_offset != defaults.end()) {
const c10::optional<IValue>& def = *defaults_offset++;
}
}
}
+
+ // have we used all the provided defaults?
+ JIT_ASSERT(defaults_offset == defaults.end());
+
out << ") -> " << resultType(graph)->python_str() << ":\n";
{
auto guard = WithIndented();
}
}
- void printFunction(Graph& graph, const std::string& name, const std::vector<c10::optional<IValue>>& defaults = {}) {
- printOneFunction(graph, name, defaults);
+ void printFunction(
+ Graph& graph,
+ const std::string& name,
+ const std::vector<c10::optional<IValue>>& defaults = {},
+ const std::vector<std::string>& param_names = {}) {
+ printFunctionDefinition(graph, name, defaults, param_names);
while(!worklist.empty()) {
out << "\n\n";
auto work = worklist.back();
}
}
void printMethod(script::Method& method) {
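+ // compute qualified names for every parameter slot in the owning module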
+ std::unordered_map<at::Tensor*, QualifiedNamePtr> parameter_names;
+ createTensorToParameterNameMap(method.owner(), QualifiedName::create("self"), parameter_names);
+ printMethod(method, parameter_names);
+ }
+ void printMethod(
+ script::Method& method,
+ const std::unordered_map<at::Tensor*, QualifiedNamePtr>&
+ parameter_names) {
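+ // render each member slot by its qualified name, e.g. self.foo.weight,
+ // so the printed body can reference it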
+ std::vector<std::string> param_names = fmap(
+ method.params(),
+ [&](at::Tensor* slot) { return parameter_names.at(slot)->str(); });
const std::string& name = method.name();
Graph& graph = *method.graph();
auto defaults = fmap(method.getSchema().arguments(), [](const Argument& arg) {
return arg.default_value();
});
- printFunction(graph, name, defaults);
+ printFunction(graph, name, defaults, param_names);
+ }
+ void printModule(script::Module& module) {
+ std::unordered_map<at::Tensor*, QualifiedNamePtr> parameter_names;
+ createTensorToParameterNameMap(module, QualifiedName::create("self"), parameter_names);
+ for(auto& method : module.get_methods()) {
+ const std::string& name = method.value()->name();
+ // skip __forked_function methods: they are inlined into their callers,
+ // so exporting them again would generate more code on each export
+ if (name.find("__forked_function") == 0) {
+ continue;
+ }
+ printMethod(*method.value(), parameter_names);
+ }
}
};
return pp.tensor_constants;
}
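+// print all of a module's methods, sharing one tensor constant table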
+TORCH_API std::vector<at::Tensor> PythonPrint(std::ostream& out, script::Module& module, bool enforce_importable) {
+ PythonPrintPass pp(out, enforce_importable);
+ pp.printModule(module);
+ return pp.tensor_constants;
+}
+
TORCH_API bool printerHasSpecialCaseFor(Symbol sym) {
// WARNING: by adding a value to this set, you are asserting
// that you have also added special handling of this symbol to
namespace script {
struct Method;
+ struct Module;
}
TORCH_API std::vector<at::Tensor> PythonPrint(std::ostream& out, Graph& graph, bool enforce_importable=false);
TORCH_API std::vector<at::Tensor> PythonPrint(std::ostream& out, script::Method& graph, bool enforce_importable=false);
+TORCH_API std::vector<at::Tensor> PythonPrint(std::ostream& out, script::Module& module, bool enforce_importable=false);
+
TORCH_API bool printerHasSpecialCaseFor(c10::Symbol sym);
}}
}
};
+// matched against for special handling of getattr expressions
+struct GetAttrValue : SugaredValue {
+ std::string kind() const override {
+ return "getattr";
+ }
+};
+
struct PrintValue : public SugaredValue {
std::string kind() const override {
return "print";
{"float", std::make_shared<CastValue>(FloatType::get())},
{"int", std::make_shared<CastValue>(IntType::get())},
{"bool", std::make_shared<CastValue>(BoolType::get())},
+ {"getattr", std::make_shared<GetAttrValue>()},
// todo(zach): remove when we can correctly export torch.full via ONNX
// or we have implicit conversion that can convert numbers to tensors
{"_to_tensor", std::make_shared<CastValue>(DynamicType::get()) },
<< " but found " << expr->type()->python_str();
}
return std::make_shared<SimpleValue>(expr);
+ } else if(auto getattr = dynamic_cast<GetAttrValue*>(sv.get())) {
+ if (apply.attributes().size() > 0) {
+ throw ErrorReport(loc) << "getattr takes no keyword arguments";
+ }
+ if (apply.inputs().size() != 2) {
+ throw ErrorReport(loc) << "getattr expects 2 inputs";
+ }
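+ // evaluate the object, then statically resolve the attribute named by the
+ // string literal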
+ auto obj = emitSugaredExpr(apply.inputs()[0], 1);
+ auto selector = apply.inputs()[1];
+ if (selector.kind() != TK_STRINGLITERAL) {
+ throw ErrorReport(loc) << "getattr's second argument must be a string literal";
+ }
+ const std::string& name = StringLiteral(selector).text();
+ return obj->attr(apply.range(), method, name);
} else {
auto inputs = getNamedValues(apply.inputs(), true);
auto attributes = emitAttributes(apply.attributes());
py::tuple result(modules.size());
for(size_t i = 0; i < modules.size(); ++i) {
auto & item = modules[i];
- result[i] = std::make_pair(item.key(), item.value());
+ result[i] = std::make_pair(item.key(), item.value().module);
}
return result;
})
// see: [pybind11 varargs]
Module& self = py::cast<Module&>(args[0]);
return invokeScriptMethodFromPython(self.get_method("forward"), tuple_slice(std::move(args), 1), std::move(kwargs));
+ })
+ .def("_python_print", [](Module& self) {
+ std::ostringstream ss;
+ std::vector<at::Tensor> tensors = PythonPrint(ss, self, true);
+ return std::make_pair(ss.str(), tensors);
});
py::class_<Method>(m, "ScriptMethod", py::dynamic_attr())
std::istringstream in(buffer);
import_ir_module(module_lookup, in);
});
- m.def("_jit_import_method", import_method);
+ m.def("_jit_import_methods", import_methods);
m.def("_jit_set_emit_module_hook", setEmitModuleHook);
}
// Note: because Method/Module are exposed to python these
// classes use python method naming conventions
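+// Module is forward-declared so Method can hold a back-pointer to its owner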
+struct Module;
+
struct Method {
- Method(std::string name, bool optimize,
+ Method(Module* owner, std::string name, bool optimize,
std::shared_ptr<Graph> graph,
std::vector<at::Tensor*> initial_members,
std::function<void(Method&)> method_creator)
- : name_(std::move(name))
+ : owner_(owner)
+ , name_(std::move(name))
, graph_(std::move(graph))
, optimize(optimize)
, member_inputs(std::move(initial_members))
return optimize;
}
+ // the module that contains this method.
+ Module& owner() const {
+ return *owner_;
+ }
+
private:
static FunctionSchema defaultSchemaFor(const Method& method) {
return { method.name(), std::move(args), std::move(returns) };
}
- std::string name_;
- std::shared_ptr<Graph> graph_; // for debugging and for inlining
- bool optimize;
-
GraphExecutor& get_executor() {
std::call_once(executor_init, [&]{
executor = GraphExecutor(graph(), optimize);
}
}
+
+ // Methods are uniquely owned by a single module; this raw pointer lets a
+ // method look up the module that owns it.
+ Module* owner_;
+
+ std::string name_;
+ std::shared_ptr<Graph> graph_; // for debugging and for inlining
+ bool optimize;
+
GraphExecutor executor; // for execution
// member_inputs are a list of additional arguments appended to graph that are
// inputs that come from the members of the Module or its submodules.
Method& create_method(const std::string & name, std::shared_ptr<Graph> graph, std::vector<at::Tensor*> member_inputs) {
JIT_ASSERT(graph);
- std::unique_ptr<Method> method(new Method(name, optimize, std::move(graph), std::move(member_inputs), nullptr));
+ std::unique_ptr<Method> method(new Method(this, name, optimize, std::move(graph), std::move(member_inputs), nullptr));
return *methods.insert(name, std::move(method));
}
Method& create_method(const std::string & name, std::function<void(Method&)> creator) {
- std::unique_ptr<Method> method(new Method(name, optimize, std::make_shared<Graph>(), {}, creator));
+ std::unique_ptr<Method> method(new Method(this, name, optimize, std::make_shared<Graph>(), {}, creator));
return *methods.insert(name, std::move(method));
}