package torch;
-message ParameterDef {
+message RecordRef {
+ optional string key = 1;
+ // size here refers to the uncompressed size, in bytes of the record
+ // this information also exists in the PyTorch container format data
+ // but is repeated here to make it possible to know size information without
+ // seeking to another record in the file.
+ optional int64 size = 2;
+}
+
+message TensorDef {
+ repeated int64 dims = 1;
+ optional int64 offset = 2;
+ repeated int64 strides = 3;
// whether we compute the gradient for the parameter
- optional bool require_gradient = 1;
- // whether this parameter is registered as buffer or not
- optional bool is_buffer = 2;
+ optional bool requires_grad = 4;
+ optional caffe2.TensorProto.DataType data_type = 5;
- // do not store tensor in parameter anymore, and retire field 3
- // optional caffe2.TensorProto tensor = 3;
- // the id in the tensor table, defined in TensorProto.name
- optional string tensor_id = 5;
- // objects other than tensors will be added here
+ optional RecordRef data = 6;
- optional string name = 4;
+ // future: device options
}
-message MethodDef {
- // method name
- // by default, we follow the naming convention below:
- // 1) forward --> main method
- // 2) init --> init method
- optional string name = 1; // method name
-
- // one of graph and torch_script must exist,
- // if both exist, we reconstruct the graph from torch_script
- optional caffe2.NetDef graph = 2;
- optional string torch_script = 3;
- // temporary place to store the methods of jit script modules
- optional bytes onnx_proto = 101;
-
- // inputs and outputs are inferred from graph or script
+message ParameterDef {
+ // whether this parameter is registered as buffer or not
+ optional bool is_buffer = 1;
+
+ // the offset into the tensor table where this parameter is stored
+ optional int64 tensor_id = 2;
+
+ optional string name = 3;
}
message ModuleDef {
repeated ModuleDef submodules = 1;
- // We suppose to store the modules in one of the following format:
- // - methods (static graph or torch script)
- // - pickle
- // - cpp_arena
- repeated MethodDef methods = 2;
+ optional RecordRef torchscript_arena = 2;
+
+ repeated caffe2.NetDef caffe2_nets = 3;
+
// because the old pickle modules may not be supported by torch_script,
// have to stored as pickle_arena at this moment.
- optional bytes pickle_arena = 3;
+ optional RecordRef pickle_arena = 4;
// should be exposed by the Class Archive, so user can save
// module specific data which cannot be store in the graph or torch_script
- optional bytes cpp_arena = 4;
+ optional RecordRef cpp_arena = 5;
// the parameters of this module
- repeated ParameterDef parameters = 5;
+ repeated ParameterDef parameters = 6;
// the names of inputs and outputs of the module are inferred
// from the main method.
- optional string name = 6;
+ optional string name = 7;
// whether apply the optimizations to this module, only applicable to
// script modules
- optional bool optimize = 7;
+ optional bool optimize = 8;
}
enum ProtoVersion {
// put build version here
optional string producer_version = 4;
- optional string name = 5;
-
- // metadata
- // - exporter - string (either "CAFFE2" or "PYTORCH"),
- // to help the runtime understand who exports the model
- // - debug_info - string
- // for MetaNetDef:
- // - project - string
- // - model_class - string
- // - internal_version - string
- // - predictor_type - string
- // - predictor_id - string
- // - execute_plan - string
- // - applicationSpecificInfo
- // - publish_time - string
- repeated caffe2.Argument annotations = 6;
-
// the table contains all the tensor information
// the tensor id is defined as TensorProto.name
- repeated caffe2.TensorProto tensors = 7;
+ repeated TensorDef tensors = 5;
+
+ // future: add a way to provide additional meta-data
}
-graph(%0 : Float(*, *)
- %1 : Float(*)
- %2 : Float(*)) {
- %3 : Float(*, *) = prim::FusionGroup_0(%2, %0, %1)
+graph(%x : Float(*, *)
+ %scale : Float(*)
+ %shift : Float(*)) {
+ %3 : Float(*, *) = prim::FusionGroup_0(%shift, %x, %scale)
return (%3);
}
with prim::FusionGroup_0 = graph(%0 : Float(*)
-graph(%0 : Float(*, *)
- %1 : Float(*, *)) {
- %2 : Float(*, *) = prim::FusionGroup_0(%0, %1)
+graph(%hx : Float(*, *)
+ %cx : Float(*, *)) {
+ %2 : Float(*, *) = prim::FusionGroup_0(%hx, %cx)
return (%2);
}
with prim::FusionGroup_0 = graph(%0 : Float(*, *)
-graph(%0 : Float(*, *)
- %1 : Float(*, *)
- %2 : Float(*, *)) {
+graph(%x : Float(*, *)
+ %y : Float(*, *)
+ %z : Float(*, *)) {
%3 : int = prim::Constant[value=1]()
- %4 : Float(*, *) = prim::FusionGroup_0(%0, %1)
- %5 : Float(*, *) = aten::add(%4, %2, %3)
+ %w : Float(*, *) = prim::FusionGroup_0(%x, %y)
+ %5 : Float(*, *) = aten::add(%w, %z, %3)
return (%5);
}
with prim::FusionGroup_0 = graph(%0 : Float(*, *)
%1 : Float(*, *)) {
%2 : int = prim::Constant[value=1]()
- %3 : Float(*, *) = aten::add(%0, %1, %2)
+ %x1 : Float(*, *) = aten::add(%0, %1, %2)
%4 : int = prim::Constant[value=1]()
- %5 : Float(*, *) = aten::sub(%0, %1, %4)
- %6 : Float(*, *) = prim::FusedConcat[dim=0](%3, %5)
- return (%6);
+ %y1 : Float(*, *) = aten::sub(%0, %1, %4)
+ %w : Float(*, *) = prim::FusedConcat[dim=0](%x1, %y1)
+ return (%w);
}
%1 : bool = prim::Constant[value=1]()
%b.1 : int = prim::Constant[value=0]()
%3 : int = prim::Constant[value=9223372036854775807]()
- %b.2 : int = prim::Constant[value=1]()
- %b.4 : int = prim::Constant[value=2]()
- %b.3 : int = prim::Loop(%3, %1, %b.1)
+ %4 : int = prim::Constant[value=1]()
+ %5 : int = prim::Constant[value=2]()
+ %b.2 : int = prim::Loop(%3, %1, %b.1)
block0(%7 : int, %8 : int) {
- -> (%1, %b.2)
+ -> (%1, %4)
}
- %b : int = prim::Loop(%3, %0, %b.3)
+ %b : int = prim::Loop(%3, %0, %b.2)
block0(%10 : int, %11 : int) {
- -> (%0, %b.4)
+ -> (%0, %5)
}
return (%b);
}
-graph(%0 : Float(*)
- %1 : Float(*)) {
- %2 : Float(*) = prim::FusionGroup_0(%0, %1)
+graph(%x : Float(*)
+ %y : Float(*)) {
+ %2 : Float(*) = prim::FusionGroup_0(%x, %y)
return (%2);
}
with prim::FusionGroup_0 = graph(%0 : Float(*)
-graph(%0 : Float(*, *)
- %1 : Float(*, *)) {
- %2 : Dynamic[] = prim::ListConstruct(%0, %1)
+graph(%x : Float(*, *)
+ %y : Float(*, *)) {
+ %2 : Dynamic[] = prim::ListConstruct(%x, %y)
%3 : Dynamic[] = aten::broadcast_tensors(%2)
%4 : Dynamic, %5 : Dynamic = prim::ListUnpack(%3)
%6 : Float(*, *) = prim::FusionGroup_0(%5, %4)
def graph(self,
x: Tensor,
y: Tensor) -> Tensor:
- return aten.add(aten.mul(x, 2), y, alpha=1)
+ _0 = torch.add(torch.mul(x, 2), y, alpha=1)
+ return _0
-graph(%0 : Float(*, *)
- %1 : Float(*, *)
- %2 : Float(*, *)
- %3 : Float(*, *)
- %4 : Float(*, *)
- %5 : Float(*)
- %6 : Float(*)) {
- %7 : Float(*, *) = aten::t(%3)
- %8 : Float(*, *) = aten::mm(%0, %7)
- %9 : Float(*, *) = aten::t(%4)
- %10 : Float(*, *) = aten::mm(%1, %9)
- %11 : Dynamic[] = prim::ListConstruct(%5, %8, %6, %10)
+graph(%input_1 : Float(*, *)
+ %input : Float(*, *)
+ %cx : Float(*, *)
+ %weight_1 : Float(*, *)
+ %weight : Float(*, *)
+ %bias_1 : Float(*)
+ %bias : Float(*)) {
+ %7 : Float(*, *) = aten::t(%weight_1)
+ %8 : Float(*, *) = aten::mm(%input_1, %7)
+ %9 : Float(*, *) = aten::t(%weight)
+ %10 : Float(*, *) = aten::mm(%input, %9)
+ %11 : Dynamic[] = prim::ListConstruct(%bias_1, %8, %bias, %10)
%12 : Dynamic[] = aten::broadcast_tensors(%11)
%13 : Dynamic, %14 : Dynamic, %15 : Dynamic, %16 : Dynamic = prim::ListUnpack(%12)
- %17 : Float(*, *) = prim::FusionGroup_0(%2, %16, %15, %14, %13)
+ %17 : Float(*, *) = prim::FusionGroup_0(%cx, %16, %15, %14, %13)
return (%17);
}
with prim::FusionGroup_0 = graph(%0 : Float(*, *)
%42 : Float(*, *) = aten::add(%34, %26, %41)
%43 : int = prim::Constant[value=1]()
%44 : Float(*, *) = aten::add(%36, %28, %43)
- %45 : Float(*, *) = aten::sigmoid(%38)
- %46 : Float(*, *) = aten::sigmoid(%40)
- %47 : Float(*, *) = aten::tanh(%42)
- %48 : Float(*, *) = aten::sigmoid(%44)
- %49 : Float(*, *) = aten::mul(%46, %0)
- %50 : Float(*, *) = aten::mul(%45, %47)
+ %ingate : Float(*, *) = aten::sigmoid(%38)
+ %forgetgate : Float(*, *) = aten::sigmoid(%40)
+ %cellgate : Float(*, *) = aten::tanh(%42)
+ %outgate : Float(*, *) = aten::sigmoid(%44)
+ %49 : Float(*, *) = aten::mul(%forgetgate, %0)
+ %50 : Float(*, *) = aten::mul(%ingate, %cellgate)
%51 : int = prim::Constant[value=1]()
- %52 : Float(*, *) = aten::add(%49, %50, %51)
- %53 : Float(*, *) = aten::tanh(%52)
- %54 : Float(*, *) = aten::mul(%48, %53)
- %55 : Float(*, *) = prim::FusedConcat[dim=0](%54, %52)
+ %cy : Float(*, *) = aten::add(%49, %50, %51)
+ %53 : Float(*, *) = aten::tanh(%cy)
+ %hy : Float(*, *) = aten::mul(%outgate, %53)
+ %55 : Float(*, *) = prim::FusedConcat[dim=0](%hy, %cy)
return (%55);
}
-graph(%0 : Float(*, *)
- %1 : Float(*, *)
- %2 : Float(*, *)
- %3 : Float(*, *)
- %4 : Float(*, *)
- %5 : Float(*)
- %6 : Float(*)) {
- %7 : Float(*, *) = aten::t(%3)
- %8 : Float(*, *) = aten::mm(%0, %7)
- %9 : Float(*, *) = aten::t(%4)
- %10 : Float(*, *) = aten::mm(%1, %9)
- %11 : Dynamic[] = prim::ListConstruct(%5, %8, %6, %10)
+graph(%input_1 : Float(*, *)
+ %input : Float(*, *)
+ %cx : Float(*, *)
+ %weight_1 : Float(*, *)
+ %weight : Float(*, *)
+ %bias_1 : Float(*)
+ %bias : Float(*)) {
+ %7 : Float(*, *) = aten::t(%weight_1)
+ %8 : Float(*, *) = aten::mm(%input_1, %7)
+ %9 : Float(*, *) = aten::t(%weight)
+ %10 : Float(*, *) = aten::mm(%input, %9)
+ %11 : Dynamic[] = prim::ListConstruct(%bias_1, %8, %bias, %10)
%12 : Dynamic[] = aten::broadcast_tensors(%11)
%13 : Dynamic, %14 : Dynamic, %15 : Dynamic, %16 : Dynamic = prim::ListUnpack(%12)
- %17 : Float(*, *), %18 : Float(*, *) = prim::FusionGroup_0(%2, %16, %15, %14, %13)
- return (%17, %18);
+ %17 : Float(*, *), %cy : Float(*, *) = prim::FusionGroup_0(%cx, %16, %15, %14, %13)
+ return (%17, %cy);
}
with prim::FusionGroup_0 = graph(%0 : Float(*, *)
%1 : Dynamic
%42 : Float(*, *) = aten::add(%34, %26, %41)
%43 : int = prim::Constant[value=1]()
%44 : Float(*, *) = aten::add(%36, %28, %43)
- %45 : Float(*, *) = aten::sigmoid(%38)
- %46 : Float(*, *) = aten::sigmoid(%40)
- %47 : Float(*, *) = aten::tanh(%42)
- %48 : Float(*, *) = aten::sigmoid(%44)
- %49 : Float(*, *) = aten::mul(%46, %0)
- %50 : Float(*, *) = aten::mul(%45, %47)
+ %ingate : Float(*, *) = aten::sigmoid(%38)
+ %forgetgate : Float(*, *) = aten::sigmoid(%40)
+ %cellgate : Float(*, *) = aten::tanh(%42)
+ %outgate : Float(*, *) = aten::sigmoid(%44)
+ %49 : Float(*, *) = aten::mul(%forgetgate, %0)
+ %50 : Float(*, *) = aten::mul(%ingate, %cellgate)
%51 : int = prim::Constant[value=1]()
- %52 : Float(*, *) = aten::add(%49, %50, %51)
- %53 : Float(*, *) = aten::tanh(%52)
- %54 : Float(*, *) = aten::mul(%48, %53)
- return (%54, %52);
+ %cy : Float(*, *) = aten::add(%49, %50, %51)
+ %53 : Float(*, *) = aten::tanh(%cy)
+ %54 : Float(*, *) = aten::mul(%outgate, %53)
+ return (%54, %cy);
}
def graph(self,
y: Tensor) -> int:
x = annotate(List[int], [])
- return aten.select(x, 0)
+ return torch.select(x, 0)
def graph(self,
a: Tensor,
b: Tensor) -> Tensor:
- if bool(aten.lt(a, b)):
+ if bool(torch.lt(a, b)):
c = a
else:
c = b
def graph(self,
a: Tensor,
b: Tensor) -> Tensor:
- if bool(aten.lt(a, b)):
+ if bool(torch.lt(a, b)):
c = b
else:
c = a
def graph(self,
y_1: Tensor) -> Tuple[Tensor, Tensor]:
- x = aten.add(y_1, 1, 1)
- z_1 = aten.add(x, 5, 1)
+ x = torch.add(y_1, 1, 1)
+ z_1 = torch.add(x, 5, 1)
y, z = y_1, z_1
- _0 = bool(aten.lt(y_1, 8))
+ _0 = bool(torch.lt(y_1, 8))
while _0:
- y_2 = aten.add_(y, 1, 1)
- _0, y, z = bool(aten.lt(y_2, 8)), y_2, x
+ y_2 = torch.add_(y, 1, 1)
+ _0, y, z = bool(torch.lt(y_2, 8)), y_2, x
return x, z
a_1: Tensor,
b_1: Tensor) -> Tensor:
a, b, c = a_1, b_1, 0
- _0 = bool(aten.lt(a_1, 10))
+ _0 = bool(torch.lt(a_1, 10))
while _0:
- a_2 = aten.add(a, 1, 1)
- b_2 = aten.add(b, 1, 1)
- if bool(aten.gt(a_2, b_2)):
- c_4 = 2
+ a_2 = torch.add(a, 1, 1)
+ b_2 = torch.add(b, 1, 1)
+ if bool(torch.gt(a_2, b_2)):
+ c_2 = 2
else:
- c_4 = 3
- _0, a, b, c = bool(aten.lt(a_2, 10)), a_2, b_2, c_4
- return aten.add(aten.add(a, 1, 1), c, 1)
+ c_2 = 3
+ _0, a, b, c = bool(torch.lt(a_2, 10)), a_2, b_2, c_2
+ return torch.add(torch.add(a, 1, 1), c, 1)
a_1: Tensor,
i_1: Tensor) -> Tensor:
a, i = a_1, i_1
- _0 = bool(aten.lt(i_1, 3))
+ _0 = bool(torch.lt(i_1, 3))
while _0:
- a_2 = aten.mul_(a, a)
- i_2 = aten.add_(i, 1, 1)
- _0, a, i = bool(aten.lt(i_2, 3)), a_2, i_2
+ a_2 = torch.mul_(a, a)
+ i_2 = torch.add_(i, 1, 1)
+ _0, a, i = bool(torch.lt(i_2, 3)), a_2, i_2
return a
-graph(%0 : Double(2, 2)
- %1 : Double(2, 2)) {
+graph(%a : Dynamic
+ %b : Dynamic) {
%2 : int = prim::Constant[value=1]()
- %3 : Double(2, 2) = aten::add(%0, %1, %2)
+ %3 : Dynamic = aten::add(%a, %b, %2)
return (%3);
}
-graph(%0 : Double(2, 2)
- %1 : Double(2, 2)) {
+graph(%a : Dynamic
+ %b : Dynamic) {
%2 : int = prim::Constant[value=1]()
- %3 : Double(2, 2) = aten::add(%0, %1, %2)
+ %3 : Dynamic = aten::add(%a, %b, %2)
return (%3, %3);
}
%1 : bool = prim::Constant[value=1]()
%2 : bool = prim::Constant[value=0]()
%c1.1 : int = prim::Constant[value=1]()
- %c1.2 : int = prim::Constant[value=0]()
+ %4 : int = prim::Constant[value=0]()
%5 : bool = prim::If(%2)
block0() {
- %6 : Dynamic = aten::select(%t, %c1.2, %c1.1)
+ %6 : Dynamic = aten::select(%t, %4, %c1.1)
%7 : bool = prim::TensorToBool(%6)
-> (%7)
}
-> (%1)
}
block1() {
- %10 : Dynamic = aten::select(%t, %c1.2, %c1.1)
+ %10 : Dynamic = aten::select(%t, %4, %c1.1)
%11 : bool = prim::TensorToBool(%10)
-> (%11)
}
}
%c1 : int = prim::If(%8)
block0() {
- -> (%c1.2)
+ -> (%4)
}
block1() {
-> (%c1.1)
raise
else:
return
-
ppv = "op_version_set = 0\n{}".format(pp)
sm = copy_structure_and_params(module)
torch._C._jit_import_methods(sm, ppv, constant_table)
return ge
+ def test_jitter_bug(self):
+ @torch.jit.script
+ def fn2(input, kernel_size):
+ # type: (Tensor, List[int]) -> Tensor
+ if kernel_size[0] > 1:
+ _stride = [2]
+ else:
+ _stride = kernel_size
+ print(_stride, kernel_size)
+ return input
+
+ @torch.jit.script
+ def fn(input):
+ # type: (Tensor) -> Tensor
+ return fn2(input, [1])
+
def test_annoying_doubles(self):
mod = types.ModuleType("temp")
mod.inf = float("inf")
#include "torch/csrc/utils/functional.h"
#include <torch/csrc/jit/assertions.h>
#include "torch/csrc/jit/passes/dead_code_elimination.h"
+#include "torch/csrc/jit/passes/python_print.h"
+
#include "caffe2/core/types.h"
#include "caffe2/proto/caffe2_pb.h"
class ScriptModuleSerializer;
-std::string getExportableSchemaStringForMethod(const script::Method& method) {
- const auto& schema = method.getSchema();
- for (const auto& argument : schema.arguments()) {
- AT_CHECK(
- !argument.default_value(),
- "Default arguments in script graphs may currently not be exported.");
- }
- std::ostringstream stream;
- stream << schema;
- return stream.str();
-}
-
std::string getNodeStackTraceString(const Node* n) {
std::stringstream ss;
if (n->getSourceLocation()) {
}
}
-class MethodEncoder : public EncoderBase {
- public:
- MethodEncoder(
- const script::Method& method,
- const ScriptModuleSerializer& serializer);
-
- std::string EncodeMethod(
- const script::Method& method,
- const std::string& prefix);
-
- private:
- void EncodeTensor(
- onnx::TensorProto* tensor_proto,
- const at::Tensor& tensor,
- const c10::optional<std::string> external_ref = {}) override;
-
- void EncodeIntermediateValueInfo(onnx::GraphProto *graph_proto,
- const Value* n) override;
-
- void EncodeValueInfo(onnx::GraphProto *graph_proto,
- onnx::ValueInfoProto* v,
- const Value* n) override;
-
- void EncodeTypeInfo(onnx::GraphProto *graph_proto,
- onnx::ValueInfoProto* v,
- const TypePtr& type,
- const std::string& name);
-
- // serializer already serialized all the tensors, and stores
- // the tensor and parameter tables
- const ScriptModuleSerializer* serializer_;
-
- // Used to create sequential dummy names for node types
- size_t type_counter_ = 0;
-};
-
// this is a serializer class which saves script modules to pt files. the
// content of the file is written using PyTorchStreamWriter, for details please
// check caffe2/serialize/inline_container.h. all the records except the last
void serialize(const script::Module& module);
- uint64_t lookupTensorId(const at::Tensor* tensor) const;
-
- const std::string& lookupParamName(const at::Tensor* tensor) const;
-
private:
- void convertToModel(const script::Module& module, torch::ModelDef* model_def);
+ void convertModel(const script::Module& module, torch::ModelDef* model_def);
// add a tensor to the tensorTable
- void addTensor(const at::Tensor* tensor);
-
- // recursively collect the tensors in a block and add them to the tensorTable
- void findTensorInBlock(const Block& block);
-
- // recursively iterate over the whole module to collect the information of
- // tensors and parameters
- void collectInfo(const script::Module& module, const std::string& prefix);
+ // returns the offset into the tensor table
+ size_t addTensor(const at::Tensor& tensor);
// write the content of the tensor to the file/stream, and save the
// offset in the storageMap_
void convertAndWriteTensor(
const at::Tensor& tensor,
- caffe2::TensorProto* tensor_proto);
+ torch::TensorDef* tensor_proto,
+ std::unordered_map<const void*, uint64_t>& storageMap);
// dump all the tensors in the tensorTable_ to a ModelDef (metadata) and
// the file/stream (the content), assuming all the information of the
const script::NamedParameter& param,
torch::ParameterDef* param_def);
- void convertMethod(
- const script::Method& method,
- torch::MethodDef* method_def);
-
std::ofstream ofs_;
PyTorchStreamWriter writer_;
- // storage_ptr => record_offset
- std::unordered_map<const void*, uint64_t> storageMap_;
- // tensor => param name
- std::unordered_map<const at::Tensor*, std::string> paramMap_;
- // tensor => tensor_id
- std::unordered_map<const at::Tensor*, uint64_t> tensorTable_;
- // used for generating table id for tensors
- uint64_t nextTensorId_ = 0;
-};
-
-// MethodEncoder's methods
-MethodEncoder::MethodEncoder(
- const script::Method& method,
- const ScriptModuleSerializer& serializer)
- : EncoderBase(onnx_torch::OperatorExportTypes::RAW, false) {
- serializer_ = &serializer;
-}
-
-std::string MethodEncoder::EncodeMethod(
- const script::Method& method,
- const std::string& prefix) {
- onnx::ModelProto model_proto;
- model_proto.set_doc_string("THIS PROTO IS NOT STANDARD ONNX");
- auto* node_proto = model_proto.mutable_graph()->add_node();
- node_proto->set_name(prefix + method.name());
-
- // We store the schema string in the docstring.
- node_proto->set_doc_string(getExportableSchemaStringForMethod(method));
-
- // Store member_inputs of Method in input
- for (auto& member_input : method.params()) {
- const auto& param_name = serializer_->lookupParamName(member_input);
- node_proto->add_input(param_name);
- }
-
- auto attr_proto = node_proto->add_attribute();
- attr_proto->set_type(onnx::AttributeProto_AttributeType_GRAPH);
-
- for (auto node : method.graph()->nodes()) {
- if (node->kind() == prim::PythonOp) {
- auto py_node = static_cast<torch::jit::PythonOp*>(node);
- throw std::runtime_error(
- "Couldn't export Python operator " + py_node->name() +
- "\n\nDefined at:\n" + getNodeStackTraceString(node));
- }
- }
- EncodeBlock(attr_proto->mutable_g(), method.graph()->block(), {});
- std::string torch_script;
- AT_ASSERT(model_proto.SerializeToString(&torch_script));
- return torch_script;
-}
-
-void MethodEncoder::EncodeTensor(
- onnx::TensorProto* tensor_proto,
- const at::Tensor& tensor,
- const c10::optional<std::string> external_ref) {
- uint64_t tensor_id = serializer_->lookupTensorId(&tensor);
- tensor_proto->set_name(c10::to_string(tensor_id));
- // No need to store the content of the tensor to the file/stream
- // any more, since it is already saved at the beginning of the
- // serialization in writeTensorTable
-}
-
-void MethodEncoder::EncodeIntermediateValueInfo(
- onnx::GraphProto* graph_proto,
- const Value* n) {
- auto v = graph_proto->add_value_info();
- EncodeTypeInfo(graph_proto, v, n->type(), n->uniqueName());
-}
-
-c10::optional<std::string> getBaseTypeDenotation(TypeKind& kind) {
- if (kind == TypeKind::NumberType) {
- return "NumberType";
- } else if (kind == TypeKind::FloatType) {
- return "FloatType";
- } else if (kind == TypeKind::IntType) {
- return "IntType";
- } else if (kind == TypeKind::BoolType) {
- return "BoolType";
- } else if (kind == TypeKind::NoneType) {
- return "NoneType";
- } else if (kind == TypeKind::GeneratorType) {
- return "GeneratorType";
- } else if (kind == TypeKind::StringType) {
- return "StringType";
- }
- return c10::nullopt;
-}
-
-void MethodEncoder::EncodeTypeInfo(
- onnx::GraphProto* graph_proto,
- onnx::ValueInfoProto* v,
- const TypePtr& type,
- const std::string& name) {
- v->set_name(name);
- onnx::TypeProto* type_proto = v->mutable_type();
- onnx::TypeProto_Tensor* tensortype_proto = type_proto->mutable_tensor_type();
- onnx::TensorShapeProto* shape_proto = tensortype_proto->mutable_shape();
-
- // Use TypeProto fields to encode types.
- // denotation stores the type as a string
- auto kind = type->kind();
- if (kind == TypeKind::DynamicType) {
- type_proto->set_denotation("DynamicType");
- tensortype_proto->set_elem_type(onnx::TensorProto_DataType_UNDEFINED);
- } else if (kind == TypeKind::TensorType) {
- type_proto->set_denotation("TensorType");
- // encode the number of dimensions by pushing that number of ones into the shape proto
- auto tensor_type = type->expect<TensorType>();
- for (int i = 0; i < tensor_type->dim(); i++) {
- shape_proto->add_dim();
- shape_proto->mutable_dim(i)->set_dim_value(1);
- }
- tensortype_proto->set_elem_type(ATenTypeToOnnxType(tensor_type->scalarType()));
- } else if (kind == TypeKind::CompleteTensorType) {
- type_proto->set_denotation("CompleteTensorType");
- CompleteTensorTypePtr node_type = type->cast<CompleteTensorType>();
-
- // store the sizes and strides in the dims field of TensorShapeProto
- size_t i = 0;
- for (auto &size : node_type->sizes()) {
- shape_proto->add_dim();
- shape_proto->mutable_dim(i)->set_dim_value(size);
- i++;
- }
- for (auto &stride : node_type->strides()) {
- shape_proto->add_dim();
- shape_proto->mutable_dim(i)->set_dim_value(stride);
- i++;
- }
- tensortype_proto->set_elem_type(ATenTypeToOnnxType(node_type->scalarType()));
- } else if (kind == TypeKind::TupleType) {
- type_proto->set_denotation("TupleType");
- TupleTypePtr node_type = type->cast<TupleType>();
- auto elements = node_type->elements();
-
- // Generate a name for and encode each subtype in the value_info field of the GraphProto.
- for (size_t i = 0; i < elements.size(); i++) {
- std::string name = "#" + std::to_string(type_counter_++);
- shape_proto->add_dim();
- shape_proto->mutable_dim(i)->set_dim_param(name);
- onnx::ValueInfoProto* subtype_proto = graph_proto->add_value_info();
- EncodeTypeInfo(graph_proto, subtype_proto, elements[i], name);
- }
- } else if (kind == TypeKind::ListType) {
- type_proto->set_denotation("ListType");
- ListTypePtr node_type = type->cast<ListType>();
-
- // Generate a name for and encode the subtype in the value_info field of the GraphProto.
- std::string name = "#" + std::to_string(type_counter_++);
- shape_proto->add_dim();
- shape_proto->mutable_dim(0)->set_dim_param(name);
- onnx::ValueInfoProto* subtype_proto = graph_proto->add_value_info();
- EncodeTypeInfo(graph_proto, subtype_proto, node_type->getElementType(), name);
- } else if (kind == TypeKind::VarType) {
- type_proto->set_denotation("TypeVar:" + type->expect<VarType>()->name());
- } else if (kind == TypeKind::OptionalType) {
- type_proto->set_denotation("OptionalType");
- OptionalTypePtr node_type = type->cast<OptionalType>();
-
- // Generate a name for and encode each subtype in the value_info field of the GraphProto.
- std::string name = "#" + std::to_string(type_counter_++);
- shape_proto->add_dim();
- shape_proto->mutable_dim(0)->set_dim_param(name);
- onnx::ValueInfoProto* subtype_proto = graph_proto->add_value_info();
- EncodeTypeInfo(graph_proto, subtype_proto, node_type->getElementType(), name);
- } else {
- auto denotation = getBaseTypeDenotation(kind);
- if (!denotation) {
- throw std::runtime_error("unexpected type kind");
- }
- type_proto->set_denotation(*denotation);
- }
-}
-void MethodEncoder::EncodeValueInfo(
- onnx::GraphProto* graph_proto,
- onnx::ValueInfoProto* v,
- const Value* n) {
- EncodeTypeInfo(graph_proto, v, n->type(), n->uniqueName());
-}
+ // all tensors that will be stored
+ std::vector<at::Tensor> tensor_table_;
+};
// ScriptModuleSerializer's methods
ScriptModuleSerializer::ScriptModuleSerializer(const std::string& filename)
void ScriptModuleSerializer::serialize(const script::Module& module) {
torch::ModelDef model_def;
- convertToModel(module, &model_def);
+ convertModel(module, &model_def);
std::string output;
// NB: cannot use MessageToJsonString, since fbcode's protobuf is too old
// be consistent with MessageToJsonString
writer_.writeEndOfFile();
}
-uint64_t ScriptModuleSerializer::lookupTensorId(
- const at::Tensor* tensor) const {
- auto it = tensorTable_.find(tensor);
- AT_ASSERT(it != tensorTable_.end());
- return it->second;
-}
-
-const std::string& ScriptModuleSerializer::lookupParamName(
- const at::Tensor* tensor) const {
- auto it = paramMap_.find(tensor);
- AT_ASSERT(it != paramMap_.end());
- return it->second;
-}
-
-void ScriptModuleSerializer::convertToModel(
+void ScriptModuleSerializer::convertModel(
const script::Module& module,
torch::ModelDef* model_def) {
- model_def->set_name("script-model");
model_def->set_producer_name("pytorch");
model_def->set_producer_version("1.0"); // TODO: set the producer version
// using appropriate function call
model_def->set_proto_version(torch::ProtoVersion::PROTO_VERSION_NEWEST);
std::string main_module_name = "";
- nextTensorId_ = 0;
- collectInfo(module, main_module_name);
- writeTensorTable(model_def);
convertModule(module, main_module_name, model_def->mutable_main_module());
+ writeTensorTable(model_def);
}
-void ScriptModuleSerializer::addTensor(const at::Tensor* tensor) {
- if (tensorTable_.find(tensor) == tensorTable_.end()) {
- tensorTable_[tensor] = nextTensorId_;
- ++nextTensorId_;
- }
-}
-
-void ScriptModuleSerializer::findTensorInBlock(const Block& block) {
- for (auto node : block.nodes()) {
- for (auto attr_name : node->attributeNames()) {
- AT_ASSERT(attr_name.is_attr());
- switch (node->kindOf(attr_name)) {
- case AttributeKind::f:
- case AttributeKind::fs:
- case AttributeKind::i:
- case AttributeKind::is:
- case AttributeKind::s:
- case AttributeKind::ss:
- break;
- case AttributeKind::t: {
- const at::Tensor* tensor = &node->t(attr_name);
- addTensor(tensor);
- } break;
- case AttributeKind::ts: {
- for (auto& v : node->ts(attr_name)) {
- const at::Tensor* tensor = &v;
- addTensor(tensor);
- }
- } break;
- case AttributeKind::g: {
- findTensorInBlock(*node->g(attr_name)->block());
- } break;
- case AttributeKind::gs: {
- for (auto& v : node->gs(attr_name)) {
- findTensorInBlock(*v->block());
- }
- } break;
- default:
- AT_ERROR("unexpected attribute kind");
- }
- }
- for (auto b : node->blocks()) {
- findTensorInBlock(*b);
- }
- }
-}
-
-void ScriptModuleSerializer::collectInfo(
- const script::Module& module,
- const std::string& prefix) {
- for (const auto& elem : module.get_parameters()) {
- const script::NamedParameter& param = elem.value();
- paramMap_[param.slot()] = prefix + param.name;
- addTensor(param.slot());
- }
- for (const auto& elem : module.get_methods()) {
- findTensorInBlock(*elem.value()->graph()->block());
- }
- for (const auto& elem : module.get_modules()) {
- collectInfo(*elem->module, prefix + elem.key() + ".");
- }
+size_t ScriptModuleSerializer::addTensor(const at::Tensor& tensor) {
+ tensor_table_.push_back(tensor);
+ return tensor_table_.size() - 1;
}
void ScriptModuleSerializer::convertAndWriteTensor(
const at::Tensor& tensor,
- caffe2::TensorProto* tensor_proto) {
- auto tensor_it = tensorTable_.find(&tensor);
- AT_ASSERT(tensor_it != tensorTable_.end());
- tensor_proto->set_name(c10::to_string(tensor_it->second));
+ torch::TensorDef* tensor_proto,
+ std::unordered_map<const void*, uint64_t>& storageMap) {
for (auto d : tensor.sizes()) {
tensor_proto->add_dims(d);
}
- tensor_proto->set_data_type(caffe2::TypeMetaToDataType(
- at::scalarTypeToTypeMeta(tensor.type().scalarType())));
- tensor_proto->set_storage_type(caffe2::TensorProto_StorageType_EXTERNAL);
- caffe2::ExternalDataProto* external_data =
- tensor_proto->mutable_external_data();
for (auto s : tensor.strides()) {
- external_data->add_strides(s);
+ tensor_proto->add_strides(s);
}
- external_data->set_offset(tensor.storage_offset());
+ tensor_proto->set_data_type(caffe2::TypeMetaToDataType(
+ at::scalarTypeToTypeMeta(tensor.type().scalarType())));
+ tensor_proto->set_offset(tensor.storage_offset());
+
+ tensor_proto->set_requires_grad(tensor.requires_grad());
+
uint64_t record_size =
tensor.type().elementSizeInBytes() * tensor.storage().size();
- external_data->set_record_size(record_size);
auto* key = tensor.storage().unsafeGetStorageImpl();
- auto storage_it = storageMap_.find(key);
- if (storage_it == storageMap_.end()) {
+
+ auto storage_it = storageMap.find(key);
+ if (storage_it == storageMap.end()) {
+ at::Tensor storage_tensor = tensor;
// TODO HIP support
- uint64_t record_id;
if (tensor.storage().device_type() == at::DeviceType::CUDA) {
// NB: This new tensor is created to support cuda tensors.
// Storages can be mutated when converting tensors from cuda to cpu,
// and we need a cpu tensor to copy data from.
- at::Tensor t = at::getType(tensor)
- ._th_tensor(
- tensor.storage(),
- /* storageOffset = */ 0,
- /* size = */
- {static_cast<int64_t>(tensor.storage().size())},
- /* stride = */ {1})
- .cpu();
+ storage_tensor = at::getType(tensor)
+ ._th_tensor(
+ tensor.storage(),
+ /* storageOffset = */ 0,
+ /* size = */
+ {static_cast<int64_t>(tensor.storage().size())},
+ /* stride = */ {1})
+ .cpu();
AT_ASSERT(
- t.type().elementSizeInBytes() * t.storage().size() == record_size);
- record_id = writer_.writeRecord(
- t.storage().data(),
- t.type().elementSizeInBytes() * t.storage().size());
- } else {
- record_id = writer_.writeRecord(tensor.storage().data(), record_size);
+ storage_tensor.type().elementSizeInBytes() * storage_tensor.storage().size() ==
+ record_size);
}
- external_data->set_record_id(c10::to_string(record_id));
- storageMap_[key] = record_id;
- } else {
- external_data->set_record_id(c10::to_string(storage_it->second));
+ uint64_t record_id = writer_.writeRecord(storage_tensor.storage().data(), record_size);
+ storage_it = storageMap.insert({key, record_id}).first;
}
+
+ auto* data = tensor_proto->mutable_data();
+ data->set_key(std::to_string(storage_it->second));
+ data->set_size(record_size);
+
// TODO handle device case, set the device_detail and load to CUDA device
}
void ScriptModuleSerializer::writeTensorTable(torch::ModelDef* model_def) {
- // NB: we don't reserve any order for tensors in the tensorTable_
- for (const auto& kv : tensorTable_) {
+ std::unordered_map<const void*, uint64_t> storageMap;
+ for (const at::Tensor& t : tensor_table_) {
auto* tensor_proto = model_def->add_tensors();
- convertAndWriteTensor(*kv.first, tensor_proto);
+ convertAndWriteTensor(t, tensor_proto, storageMap);
}
}
torch::ParameterDef* param_def = module_def->add_parameters();
convertParameter(elem.value(), param_def);
}
- for (auto& elem : module.get_methods()) {
- torch::MethodDef* method_def = module_def->add_methods();
- convertMethod(*elem.value(), method_def);
- }
+
+ std::ostringstream ss;
+ ss << "op_version_set = 0\n";
+ PythonPrint(ss, module, tensor_table_, /*enforce_importable=*/true);
+ torch::RecordRef* record = module_def->mutable_torchscript_arena();
+ std::string str = ss.str();
+ auto key = writer_.writeRecord(str.c_str(), str.size());
+ record->set_key(std::to_string(key));
+ record->set_size(str.size());
+
for (const auto& elem : module.get_modules()) {
torch::ModuleDef* sub_def = module_def->add_submodules();
convertModule(*elem->module, elem.key(), sub_def);
torch::ParameterDef* param_def) {
param_def->set_name(param.name);
param_def->set_is_buffer(param.is_buffer);
- param_def->set_require_gradient(param.slot()->requires_grad());
- auto it = tensorTable_.find(param.slot());
- AT_ASSERT(it != tensorTable_.end());
- param_def->set_tensor_id(c10::to_string(it->second));
-}
-
-void ScriptModuleSerializer::convertMethod(
- const script::Method& method,
- torch::MethodDef* method_def) {
- // TODO encode the real torch script instead of ModelProto
- MethodEncoder encoder(method, *this);
- // we already keep the tree structure in the top level module,
- // so pass "" as prefix
- std::string torch_script = encoder.EncodeMethod(method, "");
- method_def->set_onnx_proto(torch_script);
+ param_def->set_tensor_id(addTensor(*param.slot()));
}
-// Pretty printing
+// Pretty printing for ONNX
constexpr char indent_char = ' ';
constexpr size_t indent_multiplier = 2;
} // namespace
-std::string PrettyPrintExportedGraph(
+std::string pretty_print_onnx(
const std::shared_ptr<Graph> &graph,
const std::vector<at::Tensor> &initializers,
int64_t onnx_opset_version,
// conform to the ONNX op specification. Thus, the output will not
// be interpretable by a ONNX-compatible framework. However, PyTorch or
// libtorch will be able to import the IR and play it back.
-std::tuple<std::string, RawDataExportMap> ExportGraph(
+std::tuple<std::string, RawDataExportMap> export_onnx(
const std::shared_ptr<Graph> &graph,
const std::vector<at::Tensor> &initializers,
int64_t onnx_opset_version,
// file contents being the raw tensor data.
using RawDataExportMap = std::unordered_map<std::string, at::Tensor>;
-TORCH_API std::tuple<std::string, RawDataExportMap> ExportGraph(
+TORCH_API std::tuple<std::string, RawDataExportMap> export_onnx(
const std::shared_ptr<Graph>& graph,
const std::vector<at::Tensor>& initializers,
int64_t onnx_opset_version,
= ::torch::onnx::OperatorExportTypes::ONNX);
// For testing purposes
-TORCH_API std::string PrettyPrintExportedGraph(
+TORCH_API std::string pretty_print_onnx(
const std::shared_ptr<Graph>& graph,
const std::vector<at::Tensor> & initializers,
int64_t onnx_opset_version,
#include "torch/csrc/utils/functional.h"
#include "torch/csrc/jit/assertions.h"
#include "torch/csrc/jit/operator.h"
+#include "torch/csrc/jit/import_method.h"
+
#include "caffe2/core/types.h"
#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/proto/torch_pb.h"
#include "caffe2/serialize/inline_container.h"
-#include "onnx/onnx_pb.h"
#include <ATen/ATen.h>
namespace {
-namespace onnx = ::ONNX_NAMESPACE;
-
-// IR graph construction
-
-class ScriptModuleDeserializer;
-
-class MethodDecoder {
- public:
- MethodDecoder(
- const onnx::ModelProto& model_proto,
- script::Module* parent_module,
- ScriptModuleDeserializer* deserializer);
-
- private:
- std::shared_ptr<Graph> buildGraph(const onnx::GraphProto& graph_proto);
-
- void buildBlock(const onnx::GraphProto& graph_proto, Block* block,
- std::unordered_map<std::string, Value*>& value_map);
-
- void buildBlocks(const std::vector<onnx::GraphProto>& graphs_, Node* node,
- std::unordered_map<std::string, Value*>& value_map);
-
- void buildValue(Value* value, const onnx::ValueInfoProto& valueinfo_proto);
-
- void buildIntermediateValue(Value* value, const std::string& name);
-
- at::ScalarType onnxTypeToATenType(int32_t tensor_proto);
-
- at::Tensor buildTensor(const onnx::TensorProto& tensor_proto);
-
- TypePtr buildType(const onnx::TypeProto& type_proto);
-
- std::pair<std::shared_ptr<script::Module>, std::string> parseFullName(
- ModuleLookup module_lookup,
- const std::string fullname);
-
- // deserializer already loads the metadata of tensors, and it is used to
- // load tensors
- ScriptModuleDeserializer* deserializer_;
- std::unordered_map<std::string, const onnx::TypeProto*> value_type_map_;
-};
-
// this is a deserializer class which loads script modules from pt files. the
// content of the file is written using PyTorchStreamWriter, for details please
// check caffe2/serialize/inline_container.h. all the records except the last
void deserialize(ModuleLookup module_lookup);
- // given the tensor id, load the data of the tensor from file/stream,
- // and return a new tensor which contains the loaded data
- at::Tensor loadTensor(uint64_t tensor_id);
-
- at::Tensor* lookupTensor(const std::string& param_name) const;
-
- private:
- // recursively load all the parameters of a module, and construct a
- // parameter map (i.e., name => tensor). call loadTensor to load and
- // create a new tensor
- void loadParams(
- const torch::ModuleDef& module_def,
- const std::string& prefix);
-
- void convertModule(
- const torch::ModuleDef& module_def,
- script::Module* module);
-
- std::ifstream ifs_;
- PyTorchStreamReader reader_;
- // this is a hack to make sure the script module created in C++ is the
- // same as created in Python
- ModuleLookup moduleLookup_;
- std::vector<std::string> moduleStack_;
- // record_id => storage
- std::unordered_map<uint64_t, std::shared_ptr<at::Storage>> storageMap_;
- // tensor_id => TensorProto
- std::unordered_map<uint64_t, const caffe2::TensorProto*> metaMap_;
- // parameter_name => at::Tensor
- std::unordered_map<std::string, at::Tensor*> paramMap_;
-};
-
-at::ScalarType MethodDecoder::onnxTypeToATenType(int32_t onnx_type) {
- switch(onnx_type) {
- case onnx::TensorProto_DataType_UINT8:
- return at::kByte;
- case onnx::TensorProto_DataType_INT8:
- return at::kChar;
- case onnx::TensorProto_DataType_INT16:
- return at::kShort;
- case onnx::TensorProto_DataType_INT32:
- return at::kInt;
- case onnx::TensorProto_DataType_INT64:
- return at::kLong;
- case onnx::TensorProto_DataType_FLOAT16:
- return at::kHalf;
- case onnx::TensorProto_DataType_FLOAT:
- return at::kFloat;
- case onnx::TensorProto_DataType_DOUBLE:
- return at::kDouble;
- default:
- throw std::runtime_error("Unsupported data type");
- }
-}
-
-MethodDecoder::MethodDecoder(
- const onnx::ModelProto& model_proto,
- script::Module* parent_module,
- ScriptModuleDeserializer* deserializer) {
- deserializer_ = deserializer;
- const auto& graph_proto = model_proto.graph();
- for (const auto& node_proto : graph_proto.node()) {
- std::vector<at::Tensor*> member_inputs;
- const std::string& name = node_proto.name();
- for (const auto& param_name : node_proto.input()) {
- at::Tensor* tensor = deserializer_->lookupTensor(param_name);
- member_inputs.push_back(tensor);
- }
- auto graph = buildGraph(node_proto.attribute(0).g());
- parent_module->create_method(name, graph, member_inputs);
- // We store the schema in the docstring so we can parse the schema and
- // assign it to the method.
- auto schema = parseSchema(node_proto.doc_string());
- parent_module->get_method(name).setSchema(std::move(schema));
- }
-}
-
-void MethodDecoder::buildBlocks(
- const std::vector<onnx::GraphProto>& graphs_,
- Node* node,
- std::unordered_map<std::string, Value*>& value_map) {
- for (auto g_ : graphs_) {
- auto block = node->addBlock();
- buildBlock(g_, block, value_map);
- }
-}
-
-std::shared_ptr<Graph> MethodDecoder::buildGraph(
- const onnx::GraphProto& graph_proto) {
- auto graph = std::make_shared<Graph>();
- std::unordered_map<std::string, Value*> value_map;
-
- buildBlock(graph_proto, graph->block(), value_map);
-
- return graph;
-}
-
-void MethodDecoder::buildBlock(
- const onnx::GraphProto& graph_proto,
- Block* block,
- std::unordered_map<std::string, Value*>& value_map) {
- for (auto &subtype : graph_proto.value_info()) {
- value_type_map_[subtype.name()] = &subtype.type();
- }
+private:
+ at::Tensor loadTensor(
+ const torch::TensorDef& tensor_proto,
+ std::unordered_map<uint64_t, at::Storage>& storageMap);
- for (auto & input : graph_proto.input()) {
- auto value = block->addInput();
- value_map[input.name()] = value;
- buildValue(value, input);
- }
+ void convertModule(const torch::ModuleDef& module_def);
- for (auto & node_ : graph_proto.node()) {
- JIT_ASSERT(node_.op_type() != "PythonOp");
-
- auto node = block->owningGraph()->create(Symbol::fromDomainAndUnqualString(node_.domain(), node_.op_type()),
- node_.output().size());
-
- for (auto & attr : node_.attribute()) {
- Symbol name = Symbol::attr(attr.name());
-
- switch(attr.type()) {
- case onnx::AttributeProto_AttributeType_UNDEFINED:
- throw std::runtime_error("UNDEFINED attribute unsupported");
- break;
- case onnx::AttributeProto_AttributeType_FLOAT:
- node->f_(name, attr.f());
- break;
- case onnx::AttributeProto_AttributeType_INT:
- node->i_(name, attr.i());
- break;
- case onnx::AttributeProto_AttributeType_STRING:
- node->s_(name, std::move(attr.s()));
- break;
- case onnx::AttributeProto_AttributeType_TENSOR:
- node->t_(name, buildTensor(attr.t()));
- break;
- case onnx::AttributeProto_AttributeType_GRAPH:
- node->g_(name, buildGraph(attr.g()));
- break;
- case onnx::AttributeProto_AttributeType_FLOATS:
- node->fs_(name, {attr.floats().begin(), attr.floats().end()});
- break;
- case onnx::AttributeProto_AttributeType_INTS:
- node->is_(name, {attr.ints().begin(), attr.ints().end()});
- break;
- case onnx::AttributeProto_AttributeType_STRINGS:
- node->ss_(name, {attr.strings().begin(), attr.strings().end()});
- break;
- case onnx::AttributeProto_AttributeType_TENSORS:
- node->ts_(name, fmap(attr.tensors(), [this](const onnx::TensorProto& t) {
- return buildTensor(t);
- }));
- break;
- case onnx::AttributeProto_AttributeType_GRAPHS:
- if (attr.name() == "_blocks") {
- buildBlocks({attr.graphs().begin(), attr.graphs().end()}, node, value_map);
- }
- else {
- node->gs_(name, fmap(attr.graphs(), [this](const onnx::GraphProto& g_) {
- return buildGraph(g_);
- }));
- }
- break;
- }
- }
+ void loadTensorTable(torch::ModelDef* model_def);
- for (auto & input : node_.input()) {
- auto v = value_map[input];
- node->addInput(v);
- }
+ std::ifstream ifs_;
+ PyTorchStreamReader reader_;
+ // this is a hack to make sure the script module created in C++ is the
+ // same as created in Python
+ ModuleLookup moduleLookup_;
+ std::vector<std::string> moduleStack_;
- for (int i=0; i<node_.output().size(); i++) {
- value_map[node_.output(i)] = node->outputs()[i];
- buildIntermediateValue(node->outputs()[i], node_.output(i));
- }
-
- block->appendNode(node);
- }
-
- for (auto & output : graph_proto.output()) {
- Value* v = value_map.at(output.name());
- buildValue(v, output);
- block->registerOutput(v);
- }
-}
-
-TypePtr MethodDecoder::buildType(const onnx::TypeProto& type_proto) {
- auto tensortype_proto = type_proto.tensor_type();
- auto shape_proto = tensortype_proto.shape();
- auto kind = type_proto.denotation();
- if (kind == "DynamicType") {
- return DynamicType::get();
- } else if (kind == "TensorType") {
- auto dims = shape_proto.dim_size();
- return TensorType::create(onnxTypeToATenType(tensortype_proto.elem_type()), at::kCPU, dims);
- } else if (kind == "CompleteTensorType") {
- // first half of the dims are sizes and the second half are strides
- auto total = shape_proto.dim_size();
- std::vector<int64_t> sizes, strides;
- for (int i = 0; i < total / 2; i++) {
- sizes.push_back(shape_proto.dim(i).dim_value());
- }
- for (int i = total / 2; i < total; i++) {
- strides.push_back(shape_proto.dim(i).dim_value());
- }
- return CompleteTensorType::create(onnxTypeToATenType(tensortype_proto.elem_type()), at::kCPU, sizes, strides);
- } else if (kind == "TupleType") {
- std::vector<TypePtr> elems;
- for (auto &subkind : shape_proto.dim()) {
- auto it = value_type_map_.find(subkind.dim_param());
- JIT_ASSERT(it != value_type_map_.end());
- elems.push_back(buildType(*it->second));
- }
- return TupleType::create(elems);
- } else if (kind == "ListType") {
- auto subkind = shape_proto.dim(0);
- auto it = value_type_map_.find(subkind.dim_param());
- JIT_ASSERT(it != value_type_map_.end());
- return ListType::create(buildType(*it->second));
- } else if (kind == "NumberType") {
- return NumberType::get();
- } else if (kind == "FloatType") {
- return FloatType::get();
- } else if (kind == "IntType") {
- return IntType::get();
- } else if (kind == "BoolType") {
- return BoolType::get();
- } else if (kind == "NoneType") {
- return NoneType::get();
- } else if (kind == "GeneratorType") {
- return GeneratorType::get();
- } else if (kind == "StringType") {
- return StringType::get();
- } else if (kind == "OptionalType") {
- auto subkind = shape_proto.dim(0);
- auto it = value_type_map_.find(subkind.dim_param());
- JIT_ASSERT(it != value_type_map_.end());
- return OptionalType::create(buildType(*it->second));
- } else if (kind.find("TypeVar:") == 0) {
- return VarType::create(kind.substr(strlen("TypeVar:")));
- } else {
- throw std::runtime_error("unexpected string for type kind: " + kind);
- }
-}
-
-void MethodDecoder::buildValue(
- Value* value,
- const onnx::ValueInfoProto& valueinfo_proto) {
- value->setType(buildType(valueinfo_proto.type()));
-}
-
-void MethodDecoder::buildIntermediateValue(
- Value* value,
- const std::string& name) {
- auto it = value_type_map_.find(name);
- JIT_ASSERT(it != value_type_map_.end());
- value->setType(buildType(*it->second));
-}
-
-// Given a full name of a parameter or method,
-// return the parent submodule and local name
-std::pair<std::shared_ptr<script::Module>, std::string> MethodDecoder::
- parseFullName(ModuleLookup module_lookup, const std::string fullname) {
- AT_ASSERT(!fullname.empty());
- std::vector<std::string> vec;
- std::stringstream ss(fullname);
- std::string name;
- while (std::getline(ss, name, '.')) {
- vec.push_back(name);
- }
-
- std::string last = vec.back();
- vec.pop_back();
- return std::make_pair(module_lookup(vec), std::move(last));
-}
-
-at::Tensor MethodDecoder::buildTensor(const onnx::TensorProto& tensor_proto) {
- uint64_t tensor_id = caffe2::stoull(tensor_proto.name());
- return deserializer_->loadTensor(tensor_id);
-}
+ std::vector<at::Tensor> tensor_table_;
+};
ScriptModuleDeserializer::ScriptModuleDeserializer(const std::string& filename)
: ifs_(filename, std::ifstream::in | std::ifstream::binary),
"JSON transcoder produced invalid protobuf output.");
moduleLookup_ = module_lookup;
- metaMap_.clear();
- for (int i = 0; i < model_def.tensors_size(); ++i) {
- const auto& tensor_proto = model_def.tensors(i);
- uint64_t tensor_id = caffe2::stoull(tensor_proto.name());
- metaMap_[tensor_id] = &tensor_proto;
- }
-
const auto& module_def = model_def.main_module();
-
- loadParams(module_def, module_def.name());
+ loadTensorTable(&model_def);
// TODO: this can be simplified when C++/Python interop lands,
// and the submodules would be created as the same in either C++ or Python
- std::shared_ptr<script::Module> module = moduleLookup_(moduleStack_);
- convertModule(module_def, module.get());
+ convertModule(module_def);
}
-at::Tensor ScriptModuleDeserializer::loadTensor(uint64_t tensor_id) {
- auto it = metaMap_.find(tensor_id);
- AT_ASSERT(it != metaMap_.end());
- const caffe2::TensorProto& tensor_proto = *it->second;
- std::vector<int64_t> dims;
- for (int i = 0; i < tensor_proto.dims_size(); ++i) {
- dims.push_back(tensor_proto.dims(i));
- }
- AT_ASSERT(
- tensor_proto.storage_type() == caffe2::TensorProto_StorageType_EXTERNAL);
- const caffe2::ExternalDataProto& external_data = tensor_proto.external_data();
- std::vector<int64_t> strides;
- for (int i = 0; i < external_data.strides_size(); ++i) {
- strides.push_back(external_data.strides(i));
+void ScriptModuleDeserializer::loadTensorTable(torch::ModelDef* model_def) {
+ std::unordered_map<uint64_t, at::Storage> storageMap;
+ for(const torch::TensorDef& tensor : model_def->tensors()) {
+ tensor_table_.emplace_back(loadTensor(tensor, storageMap));
}
+}
+
+at::Tensor ScriptModuleDeserializer::loadTensor(const torch::TensorDef& tensor_proto,
+ std::unordered_map<uint64_t, at::Storage>& storageMap) {
+ std::vector<int64_t> dims(tensor_proto.dims().begin(), tensor_proto.dims().end());
+ std::vector<int64_t> strides(tensor_proto.strides().begin(), tensor_proto.strides().end());
auto type = at::typeMetaToScalarType(
caffe2::DataTypeToTypeMeta(tensor_proto.data_type()));
- uint64_t record_id = caffe2::stoull(external_data.record_id());
- AT_ASSERT(record_id != 0);
- auto storage_it = storageMap_.find(record_id);
- if (storage_it == storageMap_.end()) {
+
+ uint64_t record_id = caffe2::stoull(tensor_proto.data().key());
+ auto storage_it = storageMap.find(record_id);
+ if (storage_it == storageMap.end()) {
at::DataPtr storage_ptr;
uint64_t record_size;
std::tie(storage_ptr, record_size) = reader_.getRecordWithKey(record_id);
- AT_ASSERT(record_size == external_data.record_size());
- auto storage = std::make_shared<at::Storage>(
+ AT_ASSERT(record_size == tensor_proto.data().size());
+ auto storage = at::Storage(
at::CPU(type).typeMeta(),
std::move(storage_ptr),
record_size / at::CPU(type).typeMeta().itemsize(),
nullptr); // NB: we didn't set any allocator for the tensor
- storageMap_.insert(std::make_pair(record_id, storage));
- return at::CPU(type)._th_tensor(
- *storage, external_data.offset(), dims, strides);
- }
- return at::CPU(type)._th_tensor(
- *(storage_it->second.get()), external_data.offset(), dims, strides);
-}
-
-at::Tensor* ScriptModuleDeserializer::lookupTensor(
- const std::string& param_name) const {
- auto it = paramMap_.find(param_name);
- AT_ASSERTM(it != paramMap_.end(), "cannot find parameter ", param_name);
- return it->second;
-}
-
-void ScriptModuleDeserializer::loadParams(
- const torch::ModuleDef& module_def,
- const std::string& prefix) {
- std::shared_ptr<script::Module> module = moduleLookup_(moduleStack_);
- for (int i = 0; i < module_def.parameters_size(); ++i) {
- const torch::ParameterDef& param_def = module_def.parameters(i);
- uint64_t tensor_id = caffe2::stoull(param_def.tensor_id());
- at::Tensor tensor = loadTensor(tensor_id);
- autograd::Variable variable =
- autograd::make_variable(tensor, param_def.require_gradient());
- module->register_parameter(
- param_def.name(), variable, param_def.is_buffer());
- paramMap_[prefix + param_def.name()] =
- module->parameter_slot(param_def.name());
- }
- for (int i = 0; i < module_def.submodules_size(); ++i) {
- const torch::ModuleDef& sub_def = module_def.submodules(i);
- moduleStack_.push_back(sub_def.name());
- loadParams(sub_def, prefix + sub_def.name() + ".");
- moduleStack_.pop_back();
+ storage_it = storageMap.insert(std::make_pair(record_id, storage)).first;
}
+ auto t = at::CPU(type)._th_tensor(
+ storage_it->second, tensor_proto.offset(), dims, strides);
+ return autograd::make_variable(t, tensor_proto.requires_grad());
}
void ScriptModuleDeserializer::convertModule(
- const torch::ModuleDef& module_def,
- script::Module* module) {
+ const torch::ModuleDef& module_def) {
+ std::shared_ptr<script::Module> module = moduleLookup_(moduleStack_);
module->set_optimized(module_def.optimize());
- for (int i = 0; i < module_def.methods_size(); ++i) {
- const torch::MethodDef& method_def = module_def.methods(i);
- // TODO read unhacked torch script, right now it's serialized onnx proto
- ::ONNX_NAMESPACE::ModelProto method_proto;
- AT_ASSERTM(
- method_proto.ParseFromString(method_def.onnx_proto()),
- "cannot parse method proto (i.e., hacked onnx proto)");
- MethodDecoder decoder(method_proto, module, this);
- (void)decoder;
- }
for (int i = 0; i < module_def.submodules_size(); ++i) {
const torch::ModuleDef& sub_def = module_def.submodules(i);
- moduleStack_.push_back(sub_def.name());
- std::shared_ptr<script::Module> sub = moduleLookup_(moduleStack_);
- convertModule(sub_def, sub.get());
+ moduleStack_.emplace_back(sub_def.name());
+ convertModule(sub_def);
moduleStack_.pop_back();
}
+ for (int i = 0; i < module_def.parameters_size(); ++i) {
+ const torch::ParameterDef& param_def = module_def.parameters(i);
+ at::Tensor tensor = tensor_table_.at(param_def.tensor_id());
+ module->register_parameter(
+ param_def.name(), tensor, param_def.is_buffer());
+ }
+ at::DataPtr data;
+ size_t size;
+ std::tie(data, size) = reader_.getRecordWithKey(caffe2::stoull(module_def.torchscript_arena().key()));
+ JIT_ASSERT(size == module_def.torchscript_arena().size());
+ std::string data_str(static_cast<const char*>(data.get()), size);
+ import_methods(module, data_str, tensor_table_);
}
} // namespace
std::shared_ptr<script::Module> module;
};
+struct OpsValue : public script::SugaredValue {
+ OpsValue(size_t version)
+ : version_(version) {}
+ std::string kind() const override {
+ return "ops";
+ }
+ std::shared_ptr<SugaredValue> attr(SourceRange loc, script::Method & m, const std::string& field) override {
+ return std::make_shared<script::BuiltinModule>(field, version_);
+ }
+ size_t version_;
+};
+
struct ConstantValue : public script::SugaredValue {
ConstantValue(IValue value)
: value_(std::move(value)) {}
size_t version = parseVersionNumber(p.lexer());
std::unordered_map<std::string, std::shared_ptr<script::SugaredValue>> env = {
- {"aten", std::make_shared<script::BuiltinModule>("aten", version)},
- {"prim", std::make_shared<script::BuiltinModule>("prim", version)},
+ {"torch", std::make_shared<script::BuiltinModule>("aten", version)},
+ {"ops", std::make_shared<OpsValue>(version)},
{"CONSTANTS", std::make_shared<ConstantTableValue>(constant_table)},
{"fork", std::make_shared<script::ForkValue>()},
{"annotate", std::make_shared<script::AnnotateValue>()},
}
std::ostream& Graph::prettyPrint(std::ostream & out) {
- PythonPrint(out, *this);
+ std::vector<at::Tensor> tensor_table;
+ PythonPrint(out, *this, tensor_table);
return out;
}
void Graph::dumpPretty() {
- PythonPrint(std::cout, *this);
+ std::vector<at::Tensor> tensor_table;
+ PythonPrint(std::cout, *this, tensor_table);
}
static void checkSameDevice(const Node* node) {
const static std::unordered_set<std::string> reserved_names = {
// identifiers in the environment while parsing
"aten",
- "prim",
+ "ops",
"CONSTANTS",
"fork",
"attribute",
// constants are written to this table, and given then named CONSTANTS.cN
// where N is the index into this table.
- std::vector<at::Tensor> tensor_constants;
+ std::vector<at::Tensor>& tensor_table_;
// When printing this node, is it safe to write it inline (i.e. without
// assigning a temporary variable
std::unordered_set<Node*> output_inline_;
// ConstantPool, which is also N^2 in the size of the constants,
// because it doesn't hash any information about the tensors.
// We will probably need to optimize this at some point using hashing.
- for(size_t i = 0; i < tensor_constants.size(); ++i) {
- if (t.type() == tensor_constants[i].type() && t.equal(tensor_constants[i])) {
+ for(size_t i = 0; i < tensor_table_.size(); ++i) {
+ if (t.type() == tensor_table_[i].type() && t.equal(tensor_table_[i])) {
return i;
}
}
JIT_ASSERT(t.is_variable());
- tensor_constants.emplace_back(std::move(t));
- return tensor_constants.size() - 1;
+ tensor_table_.emplace_back(std::move(t));
+ return tensor_table_.size() - 1;
}
std::unordered_set<Node*> seen_constants;
} break;
default: {
Symbol kind = node->kind();
- stmt << kind.ns().toUnqualString() << "." << kind.toUnqualString() << "(";
+ if (kind.is_aten()) {
+ // special case aten -> torch because we want to rename
+ // the aten namespace, but that change will take more time;
+ // doing it here ensures we do not have to fix up archives later
+ stmt << "torch." << kind.toUnqualString() << "(";
+ } else {
+ stmt << "ops." << kind.ns().toUnqualString() << "." << kind.toUnqualString() << "(";
+ }
const FunctionSchema& schema = node->schema();
- for (size_t i = 0; i < schema.arguments().size(); ++i) {
- auto v = useOf(node->inputs().at(i));
- auto arg = schema.arguments().at(i);
+ for (size_t i = 0; i < node->inputs().size(); ++i) {
if (i > 0) {
stmt << ", ";
}
- if (arg.kwarg_only()) {
- stmt << arg.name() << "=";
+ auto v = useOf(node->inputs().at(i));
+ // print the kwarg name if it is a kwarg only argument.
+ if (i < schema.arguments().size()) {
+ auto arg = schema.arguments().at(i);
+ if (arg.kwarg_only()) {
+ stmt << arg.name() << "=";
+ }
+ } else {
+ // vararg functions like format can have extra arguments
+ JIT_ASSERT(schema.is_vararg());
}
stmt << v;
}
public:
PythonPrintPass(
std::ostream& out_,
+ std::vector<at::Tensor>& tensor_table,
bool enforce_importable)
- : out(out_), enforce_importable_(enforce_importable) {}
+ : out(out_), tensor_table_(tensor_table), enforce_importable_(enforce_importable) {}
// TODO: we should consider forcing functions to return a single value
// instead of handling this tuple logic both in the compiler and the printer
}
};
-TORCH_API std::vector<at::Tensor> PythonPrint(std::ostream& out, Graph& graph, bool enforce_importable) {
- PythonPrintPass pp(out, enforce_importable);
- pp.printFunction(graph, "graph");
- return pp.tensor_constants;
+TORCH_API void PythonPrint(std::ostream& out, const Graph& graph, std::vector<at::Tensor>& tensor_table, bool enforce_importable) {
+ PythonPrintPass pp(out, tensor_table, enforce_importable);
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
+ pp.printFunction(const_cast<Graph&>(graph), "graph");
}
-TORCH_API std::vector<at::Tensor> PythonPrint(std::ostream& out, script::Method& method, bool enforce_importable) {
- PythonPrintPass pp(out, enforce_importable);
- pp.printMethod(method);
- return pp.tensor_constants;
+
+TORCH_API void PythonPrint(std::ostream& out, const script::Method& method, std::vector<at::Tensor>& tensor_table, bool enforce_importable) {
+ PythonPrintPass pp(out, tensor_table, enforce_importable);
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
+ pp.printMethod(const_cast<script::Method&>(method));
}
-TORCH_API std::vector<at::Tensor> PythonPrint(std::ostream& out, script::Module& module, bool enforce_importable) {
- PythonPrintPass pp(out, enforce_importable);
- pp.printModule(module);
- return pp.tensor_constants;
+TORCH_API void PythonPrint(std::ostream& out, const script::Module& module, std::vector<at::Tensor>& tensor_table, bool enforce_importable) {
+ PythonPrintPass pp(out, tensor_table, enforce_importable);
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
+ pp.printModule(const_cast<script::Module&>(module));
}
TORCH_API bool printerHasSpecialCaseFor(Symbol sym) {
struct Module;
}
-TORCH_API std::vector<at::Tensor> PythonPrint(std::ostream& out, Graph& graph, bool enforce_importable=false);
-TORCH_API std::vector<at::Tensor> PythonPrint(std::ostream& out, script::Method& graph, bool enforce_importable=false);
-TORCH_API std::vector<at::Tensor> PythonPrint(std::ostream& out, script::Module& module, bool enforce_importable=false);
+TORCH_API void PythonPrint(std::ostream& out, const Graph& graph, std::vector<at::Tensor>& tensor_table, bool enforce_importable=false);
+TORCH_API void PythonPrint(std::ostream& out, const script::Method& graph, std::vector<at::Tensor>& tensor_table, bool enforce_importable=false);
+TORCH_API void PythonPrint(std::ostream& out, const script::Module& module, std::vector<at::Tensor>& tensor_table, bool enforce_importable=false);
TORCH_API bool printerHasSpecialCaseFor(c10::Symbol sym);
}}
std::string getPythonName(const PyObject* obj_) {
AutoGIL gil;
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
PyObject* obj = const_cast<PyObject*>(obj_);
auto v = py::getattr(obj, "__name__", py::str("<python_value>"));
// if this was a autograd.Function recover the name of the class
std::ostream& printPyObject(std::ostream & out, const THPObjectPtr& obj) {
AutoGIL gil;
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
auto pyobj = py::handle(const_cast<PyObject*>(obj.get()));
if (py::isinstance<py::tuple>(pyobj)) {
// This special-case for printing tuples handles a problem where
struct ConcretePythonOp : public PythonOp {
ConcretePythonOp(Graph * graph)
: PythonOp(graph) {}
- virtual std::string name() const override {
+ std::string name() const override {
AutoGIL gil;
if(auto autograd = autogradFunction()) {
return getPythonName(autograd->get());
this->scalar_args.emplace_back(sa.get());
}
}
- virtual Node * allocNewInstance(Graph * g) override {
+ Node * allocNewInstance(Graph * g) override {
return new ConcretePythonOp(g);
}
// recover the autograd.Function instance, if this PythonOp's function
// was originally SomeFunction.apply
// used in ONNX for discovering symbolics
- virtual c10::optional<THPObjectPtr> autogradFunction() const override {
+ c10::optional<THPObjectPtr> autogradFunction() const override {
AutoGIL gil;
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
py::handle obj = const_cast<PyObject*>(pyobj.get());
auto r = py::getattr(obj, "__self__", py::none());
return THPObjectPtr(r.release().ptr());
}
- virtual void writeScalars(std::ostream& out) const override {
+ void writeScalars(std::ostream& out) const override {
out << "(";
int i = 0;
for (auto& scalar : scalar_args) {
setInputTypes(*g, ArgumentSpec(with_grad, fmap<IValue>(inputs), inputs.size()));
PropagateInputShapes(g);
})
- .def("export", [](const std::shared_ptr<Graph> g, const std::vector<at::Tensor>& initializers,
+ .def("_export_onnx", [](const std::shared_ptr<Graph> g, const std::vector<at::Tensor>& initializers,
int64_t onnx_opset_version, bool defer_weight_export,
::torch::onnx::OperatorExportTypes operator_export_type) {
std::string graph;
RawDataExportMap export_map;
- std::tie(graph, export_map) = ExportGraph(
+ std::tie(graph, export_map) = export_onnx(
g, initializers, onnx_opset_version, defer_weight_export, operator_export_type);
std::unordered_map<std::string, py::bytes> python_serialized_export_map;
for (auto& kv : export_map) {
py::arg("onnx_opset_version")=0,
py::arg("defer_weight_export")=false,
py::arg("operator_export_type")=::torch::onnx::OperatorExportTypes::ONNX)
- .def("prettyPrintExport", [](const std::shared_ptr<Graph> g,
+ .def("_pretty_print_onnx", [](const std::shared_ptr<Graph> g,
const std::vector<at::Tensor>& initializers,
int64_t onnx_opset_version, bool defer_weight_export,
::torch::onnx::OperatorExportTypes operator_export_type,
bool google_printer) {
- return PrettyPrintExportedGraph(
+ return pretty_print_onnx(
g, initializers, onnx_opset_version, defer_weight_export, operator_export_type,
google_printer);
}, py::arg("initializers"),
return n.t(Symbol::attr(name));
})
.def("zs_",[](Node & n, const char * name, TensorsAttr::ValueType v) {
+ // NOLINTNEXTLINE(modernize-loop-convert)
for (size_t i = 0; i < v.size(); ++ i) {
v[i] = autograd::Variable(v[i].view({})).data();
}
void setSugaredVar(const SourceRange& loc, const std::string& name, SugaredValuePtr value) {
Value* as_simple_value = asSimple(value);
- if (as_simple_value && !as_simple_value->hasUniqueName() && meaningfulName(name)) {
+ if (as_simple_value && !as_simple_value->hasUniqueName() &&
+ meaningfulName(name) &&
+ // note: if the value wasn't defined in this block, we might be giving a name
+ // that is only used inside this block to a value defined outside of it. this
+ // is not normally helpful for debugging and causes import/export jitter.
+ as_simple_value->node()->owningBlock() == block()) {
as_simple_value->setUniqueName(name);
}
// prevent re-assignment involving any sugared values
})
.def("_python_print", [](Module& self) {
std::ostringstream ss;
- std::vector<at::Tensor> tensors = PythonPrint(ss, self, true);
+ std::vector<at::Tensor> tensors;
+ PythonPrint(ss, self, tensors, true);
return std::make_pair(ss.str(), tensors);
});
.def("pretty_print_schema", &Method::pretty_print_schema)
.def("python_print", [](Method &m) {
std::ostringstream oss;
- std::vector<at::Tensor> constants = PythonPrint(oss, m, true);
+ std::vector<at::Tensor> constants;
+ PythonPrint(oss, m, constants, true);
return std::make_pair(oss.str(), std::move(constants));
});
example_outputs, propagate)
from torch.onnx.symbolic import _onnx_opset_version
- return graph.prettyPrintExport(params, _onnx_opset_version, False, operator_export_type, google_printer)
+ return graph._pretty_print_onnx(params, _onnx_opset_version, False, operator_export_type, google_printer)
# NOTE: the output `torch_out` will contain the output tensors resulting from
from torch.onnx.symbolic import _onnx_opset_version
defer_weight_export = export_type is not ExportTypes.PROTOBUF_FILE
if export_params:
- proto, export_map = graph.export(params, _onnx_opset_version, defer_weight_export, operator_export_type)
+ proto, export_map = graph._export_onnx(params, _onnx_opset_version, defer_weight_export, operator_export_type)
else:
- proto, export_map = graph.export([], _onnx_opset_version, False, operator_export_type)
+ proto, export_map = graph._export_onnx([], _onnx_opset_version, False, operator_export_type)
if export_type == ExportTypes.PROTOBUF_FILE:
assert(len(export_map) == 0)