// from the main method.
optional string name = 6;
+
+ // whether to apply optimizations to this module; only applicable to
+ // script modules
+ optional bool optimize = 7;
}
enum ProtoVersion {
model_proto.set_doc_string("THIS PROTO IS NOT STANDARD ONNX");
auto* node_proto = model_proto.mutable_graph()->add_node();
node_proto->set_name(prefix + method.name());
- if (method.is_optimized()) {
- // mark that this method was optimized
- node_proto->set_domain("optimized");
- }
// We store the schema string in the docstring.
node_proto->set_doc_string(getExportableSchemaStringForMethod(method));
const std::string& name,
torch::ModuleDef* module_def) {
module_def->set_name(name);
+ module_def->set_optimize(module.is_optimized());
for (const auto& elem : module.get_parameters()) {
torch::ParameterDef* param_def = module_def->add_parameters();
convertParameter(elem.value(), param_def);
member_inputs.push_back(it->second);
}
auto graph = buildGraph(node_proto.attribute(0).g());
- // has_domain field has a string iff the method was optimized
- parent_module->set_optimized(node_proto.has_domain());
parent_module->create_method(name, graph, member_inputs);
// We store the schema in the docstring so we can parse the schema and
// assign it to the method.
void convertModule(
const torch::ModuleDef& module_def,
script::Module* module) {
+ module->set_optimized(module_def.optimize());
for (int i = 0; i < module_def.methods_size(); ++i) {
const torch::MethodDef& method_def = module_def.methods(i);
// TODO read unhacked torch script, right now it's serialized onnx proto
optimize = o;
}
+ // Read accessor for the `optimize` flag (set via the setter above).
+ // True when optimization passes should be applied to this module;
+ // this value is what gets serialized into the proto's `optimize`
+ // field during export and restored on import.
+ bool is_optimized() const {
+ return optimize;
+ }
+
// Convenience entry point: looks up this module's "forward" method and
// invokes it with the given inputs, returning its result.
IValue forward(std::vector<IValue> inputs) {
  auto&& forward_method = get_method("forward");
  return forward_method(inputs);
}