Switch import/export to python printing (#14400)
authorZachary DeVito <zdevito@fb.com>
Fri, 30 Nov 2018 01:51:45 +0000 (17:51 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 30 Nov 2018 01:53:49 +0000 (17:53 -0800)
Summary:
Stacked on https://github.com/pytorch/pytorch/pull/14378, only look at the last commit.

This changes the way methods are defined in TorchScript archives to use
PythonPrint rather than ONNX protobufs.

It also updates torch.proto to directly document the tensor data
structure actually being serialized.

Notes:
* because PythonPrint prints all the methods at once per module, this
  removes MethodDef in favor of a single torchscript_area and a separate
  caffe2_graphs entry. Note that NetDef's already have method names,
  so there is no need or a separate method name entry.
* This switches cpp/pickle area to RecordRef (references to a file in
  the container format) since it is possible the data in these arenas
  may be large and not suited to json ouput.
* Removes 'annotations' -- annotations should be re-added on the first
  commit that actually has a practical use for them. In the current state
  it is unlikely they are representing the right information.
* Some expect files have changed because PythonPrint is preserving more
  debug name information for parameter names.
* MethodEncoder (the ONNX output format) has been deleted. There is still
  some cleanup possible combining EncoderBase and GraphEncode now that there
  is only a single pathway using EncoderBase.
* This incorporates the changes from #14397
  to define TensorDef
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14400

Reviewed By: suo

Differential Revision: D13231800

Pulled By: zdevito

fbshipit-source-id: af5c1152d0bd6bca8b06c4703f59b161bb19f571

31 files changed:
caffe2/proto/torch.proto
test/expect/TestJit.test_broadcast_fusion_cuda.expect
test/expect/TestJit.test_concat_fusion_cuda.expect
test/expect/TestJit.test_concat_fusion_invariant_cuda.expect
test/expect/TestJit.test_constant_prop_loop_constant.expect
test/expect/TestJit.test_fuse_last_device_cuda.expect
test/expect/TestJit.test_fusion_distribute_cuda.expect
test/expect/TestJit.test_import_method.expect
test/expect/TestJit.test_lstm_fusion_concat_cuda.expect
test/expect/TestJit.test_lstm_fusion_cuda.expect
test/expect/TestJit.test_pretty_printer-empty_int_list_test.expect
test/expect/TestJit.test_pretty_printer-if_one.expect
test/expect/TestJit.test_pretty_printer-if_test.expect
test/expect/TestJit.test_pretty_printer-loop_use_test.expect
test/expect/TestJit.test_pretty_printer-while_if_test.expect
test/expect/TestJit.test_pretty_printer-while_test.expect
test/expect/TestJit.test_repeated_input.expect
test/expect/TestJit.test_repeated_output.expect
test/expect/TestScript.test_logical_short_circuit.expect
test/test_jit.py
torch/csrc/jit/export.cpp
torch/csrc/jit/export.h
torch/csrc/jit/import.cpp
torch/csrc/jit/import_method.cpp
torch/csrc/jit/ir.cpp
torch/csrc/jit/passes/python_print.cpp
torch/csrc/jit/passes/python_print.h
torch/csrc/jit/python_ir.cpp
torch/csrc/jit/script/compiler.cpp
torch/csrc/jit/script/init.cpp
torch/onnx/utils.py

index 6105f7e..a4363a1 100644 (file)
@@ -4,64 +4,63 @@ import "caffe2/proto/caffe2.proto";
 
 package torch;
 
-message ParameterDef {
+message RecordRef {
+  optional string key = 1;
+  // size here refers to the uncompressed size, in bytes of the record
+  // this information also exists in the PyTorch container format data
+  // but is repeated here to make it possible to know size information without
+  // seeking to another record in the file.
+  optional int64 size = 2;
+}
+
+message TensorDef {
+  repeated int64 dims = 1;
+  optional int64 offset = 2;
+  repeated int64 strides = 3;
   // whether we compute the gradient for the parameter
-  optional bool require_gradient = 1;
-  // whether this parameter is registered as buffer or not
-  optional bool is_buffer = 2;
+  optional bool requires_grad = 4;
+  optional caffe2.TensorProto.DataType data_type = 5;
 
-  // do not store tensor in parameter anymore, and retire field 3
-  // optional caffe2.TensorProto tensor = 3;
-  // the id in the tensor table, defined in TensorProto.name
-  optional string tensor_id = 5;
-  // objects other than tensors will be added here
+  optional RecordRef data = 6;
 
-  optional string name = 4;
+  // future: device options
 }
 
-message MethodDef {
-  // method name
-  // by default, we follow the naming convention below:
-  //   1) forward --> main method
-  //   2) init --> init method
-  optional string name = 1; // method name
-
-  // one of graph and torch_script must exist,
-  // if both exist, we reconstruct the graph from torch_script
-  optional caffe2.NetDef graph = 2;
-  optional string torch_script = 3;
-  // temporary place to store the methods of jit script modules
-  optional bytes onnx_proto = 101;
-
-  // inputs and outputs are inferred from graph or script
+message ParameterDef {
+  // whether this parameter is registered as buffer or not
+  optional bool is_buffer = 1;
+
+  // the offset into the tensor table where this parameter is stored
+  optional int64 tensor_id = 2;
+
+  optional string name = 3;
 }
 
 message ModuleDef {
   repeated ModuleDef submodules = 1;
 
-  // We suppose to store the modules in one of the following format:
-  //   - methods (static graph or torch script)
-  //   - pickle
-  //   - cpp_arena
-  repeated MethodDef methods = 2;
+  optional RecordRef torchscript_arena = 2;
+
+  repeated caffe2.NetDef caffe2_nets = 3;
+
   // because the old pickle modules may not be supported by torch_script,
   // have to stored as pickle_arena at this moment.
-  optional bytes pickle_arena = 3;
+  optional RecordRef pickle_arena = 4;
   // should be exposed by the Class Archive, so user can save
   // module specific data which cannot be store in the graph or torch_script
-  optional bytes cpp_arena = 4;
+  optional RecordRef cpp_arena = 5;
 
   // the parameters of this module
-  repeated ParameterDef parameters = 5;
+  repeated ParameterDef parameters = 6;
 
   // the names of inputs and outputs of the module are inferred
   // from the main method.
 
-  optional string name = 6;
+  optional string name = 7;
 
   // whether apply the optimizations to this module, only applicable to
   // script modules
-  optional bool optimize = 7;
+  optional bool optimize = 8;
 }
 
 enum ProtoVersion {
@@ -84,24 +83,9 @@ message ModelDef {
   // put build version here
   optional string producer_version = 4;
 
-  optional string name = 5;
-
-  // metadata
-  //   - exporter - string (either "CAFFE2" or "PYTORCH"),
-  //     to help the runtime understand who exports the model
-  //   - debug_info - string
-  // for MetaNetDef:
-  //   - project - string
-  //   - model_class - string
-  //   - internal_version - string
-  //   - predictor_type - string
-  //   - predictor_id - string
-  //   - execute_plan - string
-  //   - applicationSpecificInfo
-  //   - publish_time - string
-  repeated caffe2.Argument annotations = 6;
-
   // the table contains all the tensor information
   // the tensor id is defined as TensorProto.name
-  repeated caffe2.TensorProto tensors = 7;
+  repeated TensorDef tensors = 5;
+
+  // future: add a way to provide additional meta-data
 }
index eaeb885..c3ec90c 100644 (file)
@@ -1,7 +1,7 @@
-graph(%0 : Float(*, *)
-      %1 : Float(*)
-      %2 : Float(*)) {
-  %3 : Float(*, *) = prim::FusionGroup_0(%2, %0, %1)
+graph(%x : Float(*, *)
+      %scale : Float(*)
+      %shift : Float(*)) {
+  %3 : Float(*, *) = prim::FusionGroup_0(%shift, %x, %scale)
   return (%3);
 }
 with prim::FusionGroup_0 = graph(%0 : Float(*)
index 60258b2..96bafae 100644 (file)
@@ -1,6 +1,6 @@
-graph(%0 : Float(*, *)
-      %1 : Float(*, *)) {
-  %2 : Float(*, *) = prim::FusionGroup_0(%0, %1)
+graph(%hx : Float(*, *)
+      %cx : Float(*, *)) {
+  %2 : Float(*, *) = prim::FusionGroup_0(%hx, %cx)
   return (%2);
 }
 with prim::FusionGroup_0 = graph(%0 : Float(*, *)
index 5d10184..a2e2eb8 100644 (file)
@@ -1,17 +1,17 @@
-graph(%0 : Float(*, *)
-      %1 : Float(*, *)
-      %2 : Float(*, *)) {
+graph(%x : Float(*, *)
+      %y : Float(*, *)
+      %z : Float(*, *)) {
   %3 : int = prim::Constant[value=1]()
-  %4 : Float(*, *) = prim::FusionGroup_0(%0, %1)
-  %5 : Float(*, *) = aten::add(%4, %2, %3)
+  %w : Float(*, *) = prim::FusionGroup_0(%x, %y)
+  %5 : Float(*, *) = aten::add(%w, %z, %3)
   return (%5);
 }
 with prim::FusionGroup_0 = graph(%0 : Float(*, *)
       %1 : Float(*, *)) {
   %2 : int = prim::Constant[value=1]()
-  %3 : Float(*, *) = aten::add(%0, %1, %2)
+  %x1 : Float(*, *) = aten::add(%0, %1, %2)
   %4 : int = prim::Constant[value=1]()
-  %5 : Float(*, *) = aten::sub(%0, %1, %4)
-  %6 : Float(*, *) = prim::FusedConcat[dim=0](%3, %5)
-  return (%6);
+  %y1 : Float(*, *) = aten::sub(%0, %1, %4)
+  %w : Float(*, *) = prim::FusedConcat[dim=0](%x1, %y1)
+  return (%w);
 }
index 63a42c6..ff94569 100644 (file)
@@ -3,15 +3,15 @@ graph() {
   %1 : bool = prim::Constant[value=1]()
   %b.1 : int = prim::Constant[value=0]()
   %3 : int = prim::Constant[value=9223372036854775807]()
-  %b.2 : int = prim::Constant[value=1]()
-  %b.4 : int = prim::Constant[value=2]()
-  %b.3 : int = prim::Loop(%3, %1, %b.1)
+  %4 : int = prim::Constant[value=1]()
+  %5 : int = prim::Constant[value=2]()
+  %b.2 : int = prim::Loop(%3, %1, %b.1)
     block0(%7 : int, %8 : int) {
-      -> (%1, %b.2)
+      -> (%1, %4)
     }
-  %b : int = prim::Loop(%3, %0, %b.3)
+  %b : int = prim::Loop(%3, %0, %b.2)
     block0(%10 : int, %11 : int) {
-      -> (%0, %b.4)
+      -> (%0, %5)
     }
   return (%b);
 }
index 4c7a854..b2ef06b 100644 (file)
@@ -1,6 +1,6 @@
-graph(%0 : Float(*)
-      %1 : Float(*)) {
-  %2 : Float(*) = prim::FusionGroup_0(%0, %1)
+graph(%x : Float(*)
+      %y : Float(*)) {
+  %2 : Float(*) = prim::FusionGroup_0(%x, %y)
   return (%2);
 }
 with prim::FusionGroup_0 = graph(%0 : Float(*)
index 7ae8356..ce36dac 100644 (file)
@@ -1,6 +1,6 @@
-graph(%0 : Float(*, *)
-      %1 : Float(*, *)) {
-  %2 : Dynamic[] = prim::ListConstruct(%0, %1)
+graph(%x : Float(*, *)
+      %y : Float(*, *)) {
+  %2 : Dynamic[] = prim::ListConstruct(%x, %y)
   %3 : Dynamic[] = aten::broadcast_tensors(%2)
   %4 : Dynamic, %5 : Dynamic = prim::ListUnpack(%3)
   %6 : Float(*, *) = prim::FusionGroup_0(%5, %4)
index d039949..00f86c4 100644 (file)
@@ -1,4 +1,5 @@
 def graph(self,
     x: Tensor,
     y: Tensor) -> Tensor:
-  return aten.add(aten.mul(x, 2), y, alpha=1)
+  _0 = torch.add(torch.mul(x, 2), y, alpha=1)
+  return _0
index 35add35..1e13658 100644 (file)
@@ -1,18 +1,18 @@
-graph(%0 : Float(*, *)
-      %1 : Float(*, *)
-      %2 : Float(*, *)
-      %3 : Float(*, *)
-      %4 : Float(*, *)
-      %5 : Float(*)
-      %6 : Float(*)) {
-  %7 : Float(*, *) = aten::t(%3)
-  %8 : Float(*, *) = aten::mm(%0, %7)
-  %9 : Float(*, *) = aten::t(%4)
-  %10 : Float(*, *) = aten::mm(%1, %9)
-  %11 : Dynamic[] = prim::ListConstruct(%5, %8, %6, %10)
+graph(%input_1 : Float(*, *)
+      %input : Float(*, *)
+      %cx : Float(*, *)
+      %weight_1 : Float(*, *)
+      %weight : Float(*, *)
+      %bias_1 : Float(*)
+      %bias : Float(*)) {
+  %7 : Float(*, *) = aten::t(%weight_1)
+  %8 : Float(*, *) = aten::mm(%input_1, %7)
+  %9 : Float(*, *) = aten::t(%weight)
+  %10 : Float(*, *) = aten::mm(%input, %9)
+  %11 : Dynamic[] = prim::ListConstruct(%bias_1, %8, %bias, %10)
   %12 : Dynamic[] = aten::broadcast_tensors(%11)
   %13 : Dynamic, %14 : Dynamic, %15 : Dynamic, %16 : Dynamic = prim::ListUnpack(%12)
-  %17 : Float(*, *) = prim::FusionGroup_0(%2, %16, %15, %14, %13)
+  %17 : Float(*, *) = prim::FusionGroup_0(%cx, %16, %15, %14, %13)
   return (%17);
 }
 with prim::FusionGroup_0 = graph(%0 : Float(*, *)
@@ -48,16 +48,16 @@ with prim::FusionGroup_0 = graph(%0 : Float(*, *)
   %42 : Float(*, *) = aten::add(%34, %26, %41)
   %43 : int = prim::Constant[value=1]()
   %44 : Float(*, *) = aten::add(%36, %28, %43)
-  %45 : Float(*, *) = aten::sigmoid(%38)
-  %46 : Float(*, *) = aten::sigmoid(%40)
-  %47 : Float(*, *) = aten::tanh(%42)
-  %48 : Float(*, *) = aten::sigmoid(%44)
-  %49 : Float(*, *) = aten::mul(%46, %0)
-  %50 : Float(*, *) = aten::mul(%45, %47)
+  %ingate : Float(*, *) = aten::sigmoid(%38)
+  %forgetgate : Float(*, *) = aten::sigmoid(%40)
+  %cellgate : Float(*, *) = aten::tanh(%42)
+  %outgate : Float(*, *) = aten::sigmoid(%44)
+  %49 : Float(*, *) = aten::mul(%forgetgate, %0)
+  %50 : Float(*, *) = aten::mul(%ingate, %cellgate)
   %51 : int = prim::Constant[value=1]()
-  %52 : Float(*, *) = aten::add(%49, %50, %51)
-  %53 : Float(*, *) = aten::tanh(%52)
-  %54 : Float(*, *) = aten::mul(%48, %53)
-  %55 : Float(*, *) = prim::FusedConcat[dim=0](%54, %52)
+  %cy : Float(*, *) = aten::add(%49, %50, %51)
+  %53 : Float(*, *) = aten::tanh(%cy)
+  %hy : Float(*, *) = aten::mul(%outgate, %53)
+  %55 : Float(*, *) = prim::FusedConcat[dim=0](%hy, %cy)
   return (%55);
 }
index 06eebdf..0675d79 100644 (file)
@@ -1,19 +1,19 @@
-graph(%0 : Float(*, *)
-      %1 : Float(*, *)
-      %2 : Float(*, *)
-      %3 : Float(*, *)
-      %4 : Float(*, *)
-      %5 : Float(*)
-      %6 : Float(*)) {
-  %7 : Float(*, *) = aten::t(%3)
-  %8 : Float(*, *) = aten::mm(%0, %7)
-  %9 : Float(*, *) = aten::t(%4)
-  %10 : Float(*, *) = aten::mm(%1, %9)
-  %11 : Dynamic[] = prim::ListConstruct(%5, %8, %6, %10)
+graph(%input_1 : Float(*, *)
+      %input : Float(*, *)
+      %cx : Float(*, *)
+      %weight_1 : Float(*, *)
+      %weight : Float(*, *)
+      %bias_1 : Float(*)
+      %bias : Float(*)) {
+  %7 : Float(*, *) = aten::t(%weight_1)
+  %8 : Float(*, *) = aten::mm(%input_1, %7)
+  %9 : Float(*, *) = aten::t(%weight)
+  %10 : Float(*, *) = aten::mm(%input, %9)
+  %11 : Dynamic[] = prim::ListConstruct(%bias_1, %8, %bias, %10)
   %12 : Dynamic[] = aten::broadcast_tensors(%11)
   %13 : Dynamic, %14 : Dynamic, %15 : Dynamic, %16 : Dynamic = prim::ListUnpack(%12)
-  %17 : Float(*, *), %18 : Float(*, *) = prim::FusionGroup_0(%2, %16, %15, %14, %13)
-  return (%17, %18);
+  %17 : Float(*, *), %cy : Float(*, *) = prim::FusionGroup_0(%cx, %16, %15, %14, %13)
+  return (%17, %cy);
 }
 with prim::FusionGroup_0 = graph(%0 : Float(*, *)
       %1 : Dynamic
@@ -48,15 +48,15 @@ with prim::FusionGroup_0 = graph(%0 : Float(*, *)
   %42 : Float(*, *) = aten::add(%34, %26, %41)
   %43 : int = prim::Constant[value=1]()
   %44 : Float(*, *) = aten::add(%36, %28, %43)
-  %45 : Float(*, *) = aten::sigmoid(%38)
-  %46 : Float(*, *) = aten::sigmoid(%40)
-  %47 : Float(*, *) = aten::tanh(%42)
-  %48 : Float(*, *) = aten::sigmoid(%44)
-  %49 : Float(*, *) = aten::mul(%46, %0)
-  %50 : Float(*, *) = aten::mul(%45, %47)
+  %ingate : Float(*, *) = aten::sigmoid(%38)
+  %forgetgate : Float(*, *) = aten::sigmoid(%40)
+  %cellgate : Float(*, *) = aten::tanh(%42)
+  %outgate : Float(*, *) = aten::sigmoid(%44)
+  %49 : Float(*, *) = aten::mul(%forgetgate, %0)
+  %50 : Float(*, *) = aten::mul(%ingate, %cellgate)
   %51 : int = prim::Constant[value=1]()
-  %52 : Float(*, *) = aten::add(%49, %50, %51)
-  %53 : Float(*, *) = aten::tanh(%52)
-  %54 : Float(*, *) = aten::mul(%48, %53)
-  return (%54, %52);
+  %cy : Float(*, *) = aten::add(%49, %50, %51)
+  %53 : Float(*, *) = aten::tanh(%cy)
+  %54 : Float(*, *) = aten::mul(%outgate, %53)
+  return (%54, %cy);
 }
index c2867f4..3c8959e 100644 (file)
@@ -1,4 +1,4 @@
 def graph(self,
     y: Tensor) -> int:
   x = annotate(List[int], [])
-  return aten.select(x, 0)
+  return torch.select(x, 0)
index c0864c9..cc9181e 100644 (file)
@@ -1,7 +1,7 @@
 def graph(self,
     a: Tensor,
     b: Tensor) -> Tensor:
-  if bool(aten.lt(a, b)):
+  if bool(torch.lt(a, b)):
     c = a
   else:
     c = b
index 026853c..f503839 100644 (file)
@@ -1,7 +1,7 @@
 def graph(self,
     a: Tensor,
     b: Tensor) -> Tensor:
-  if bool(aten.lt(a, b)):
+  if bool(torch.lt(a, b)):
     c = b
   else:
     c = a
index b49e66c..7442d34 100644 (file)
@@ -1,10 +1,10 @@
 def graph(self,
     y_1: Tensor) -> Tuple[Tensor, Tensor]:
-  x = aten.add(y_1, 1, 1)
-  z_1 = aten.add(x, 5, 1)
+  x = torch.add(y_1, 1, 1)
+  z_1 = torch.add(x, 5, 1)
   y, z = y_1, z_1
-  _0 = bool(aten.lt(y_1, 8))
+  _0 = bool(torch.lt(y_1, 8))
   while _0:
-    y_2 = aten.add_(y, 1, 1)
-    _0, y, z = bool(aten.lt(y_2, 8)), y_2, x
+    y_2 = torch.add_(y, 1, 1)
+    _0, y, z = bool(torch.lt(y_2, 8)), y_2, x
   return x, z
index c6df96c..8894034 100644 (file)
@@ -2,13 +2,13 @@ def graph(self,
     a_1: Tensor,
     b_1: Tensor) -> Tensor:
   a, b, c = a_1, b_1, 0
-  _0 = bool(aten.lt(a_1, 10))
+  _0 = bool(torch.lt(a_1, 10))
   while _0:
-    a_2 = aten.add(a, 1, 1)
-    b_2 = aten.add(b, 1, 1)
-    if bool(aten.gt(a_2, b_2)):
-      c_4 = 2
+    a_2 = torch.add(a, 1, 1)
+    b_2 = torch.add(b, 1, 1)
+    if bool(torch.gt(a_2, b_2)):
+      c_2 = 2
     else:
-      c_4 = 3
-    _0, a, b, c = bool(aten.lt(a_2, 10)), a_2, b_2, c_4
-  return aten.add(aten.add(a, 1, 1), c, 1)
+      c_2 = 3
+    _0, a, b, c = bool(torch.lt(a_2, 10)), a_2, b_2, c_2
+  return torch.add(torch.add(a, 1, 1), c, 1)
index 84ee52b..79818ab 100644 (file)
@@ -2,9 +2,9 @@ def graph(self,
     a_1: Tensor,
     i_1: Tensor) -> Tensor:
   a, i = a_1, i_1
-  _0 = bool(aten.lt(i_1, 3))
+  _0 = bool(torch.lt(i_1, 3))
   while _0:
-    a_2 = aten.mul_(a, a)
-    i_2 = aten.add_(i, 1, 1)
-    _0, a, i = bool(aten.lt(i_2, 3)), a_2, i_2
+    a_2 = torch.mul_(a, a)
+    i_2 = torch.add_(i, 1, 1)
+    _0, a, i = bool(torch.lt(i_2, 3)), a_2, i_2
   return a
index ac67a6c..082caed 100644 (file)
@@ -1,6 +1,6 @@
-graph(%0 : Double(2, 2)
-      %1 : Double(2, 2)) {
+graph(%a : Dynamic
+      %b : Dynamic) {
   %2 : int = prim::Constant[value=1]()
-  %3 : Double(2, 2) = aten::add(%0, %1, %2)
+  %3 : Dynamic = aten::add(%a, %b, %2)
   return (%3);
 }
index 64a937a..5444190 100644 (file)
@@ -1,6 +1,6 @@
-graph(%0 : Double(2, 2)
-      %1 : Double(2, 2)) {
+graph(%a : Dynamic
+      %b : Dynamic) {
   %2 : int = prim::Constant[value=1]()
-  %3 : Double(2, 2) = aten::add(%0, %1, %2)
+  %3 : Dynamic = aten::add(%a, %b, %2)
   return (%3, %3);
 }
index 37abfdf..14f829b 100644 (file)
@@ -2,10 +2,10 @@ graph(%t : Dynamic) {
   %1 : bool = prim::Constant[value=1]()
   %2 : bool = prim::Constant[value=0]()
   %c1.1 : int = prim::Constant[value=1]()
-  %c1.2 : int = prim::Constant[value=0]()
+  %4 : int = prim::Constant[value=0]()
   %5 : bool = prim::If(%2)
     block0() {
-      %6 : Dynamic = aten::select(%t, %c1.2, %c1.1)
+      %6 : Dynamic = aten::select(%t, %4, %c1.1)
       %7 : bool = prim::TensorToBool(%6)
       -> (%7)
     }
@@ -22,7 +22,7 @@ graph(%t : Dynamic) {
           -> (%1)
         }
         block1() {
-          %10 : Dynamic = aten::select(%t, %c1.2, %c1.1)
+          %10 : Dynamic = aten::select(%t, %4, %c1.1)
           %11 : bool = prim::TensorToBool(%10)
           -> (%11)
         }
@@ -30,7 +30,7 @@ graph(%t : Dynamic) {
     }
   %c1 : int = prim::If(%8)
     block0() {
-      -> (%c1.2)
+      -> (%4)
     }
     block1() {
       -> (%c1.1)
index f247b67..3e1e1fe 100644 (file)
@@ -258,7 +258,6 @@ class JitTestCase(TestCase):
                     raise
                 else:
                     return
-
             ppv = "op_version_set = 0\n{}".format(pp)
             sm = copy_structure_and_params(module)
             torch._C._jit_import_methods(sm, ppv, constant_table)
@@ -2951,6 +2950,22 @@ class TestScript(JitTestCase):
 
         return ge
 
+    def test_jitter_bug(self):
+        @torch.jit.script
+        def fn2(input, kernel_size):
+            # type: (Tensor, List[int]) -> Tensor
+            if kernel_size[0] > 1:
+                _stride = [2]
+            else:
+                _stride = kernel_size
+            print(_stride, kernel_size)
+            return input
+
+        @torch.jit.script
+        def fn(input):
+            # type: (Tensor) -> Tensor
+            return fn2(input, [1])
+
     def test_annoying_doubles(self):
         mod = types.ModuleType("temp")
         mod.inf = float("inf")
index 4d870a5..1ae3045 100644 (file)
@@ -8,6 +8,8 @@
 #include "torch/csrc/utils/functional.h"
 #include <torch/csrc/jit/assertions.h>
 #include "torch/csrc/jit/passes/dead_code_elimination.h"
+#include "torch/csrc/jit/passes/python_print.h"
+
 
 #include "caffe2/core/types.h"
 #include "caffe2/proto/caffe2_pb.h"
@@ -33,18 +35,6 @@ namespace onnx = ::ONNX_NAMESPACE;
 
 class ScriptModuleSerializer;
 
-std::string getExportableSchemaStringForMethod(const script::Method& method) {
-  const auto& schema = method.getSchema();
-  for (const auto& argument : schema.arguments()) {
-    AT_CHECK(
-        !argument.default_value(),
-        "Default arguments in script graphs may currently not be exported.");
-  }
-  std::ostringstream stream;
-  stream << schema;
-  return stream.str();
-}
-
 std::string getNodeStackTraceString(const Node* n) {
   std::stringstream ss;
   if (n->getSourceLocation()) {
@@ -438,42 +428,6 @@ void GraphEncoder::EncodeTensor(
   }
 }
 
-class MethodEncoder : public EncoderBase {
- public:
-  MethodEncoder(
-      const script::Method& method,
-      const ScriptModuleSerializer& serializer);
-
-  std::string EncodeMethod(
-      const script::Method& method,
-      const std::string& prefix);
-
- private:
-  void EncodeTensor(
-      onnx::TensorProto* tensor_proto,
-      const at::Tensor& tensor,
-      const c10::optional<std::string> external_ref = {}) override;
-
-  void EncodeIntermediateValueInfo(onnx::GraphProto *graph_proto,
-                                           const Value* n) override;
-
-  void EncodeValueInfo(onnx::GraphProto *graph_proto,
-                               onnx::ValueInfoProto* v,
-                               const Value* n) override;
-
-  void EncodeTypeInfo(onnx::GraphProto *graph_proto,
-                      onnx::ValueInfoProto* v,
-                      const TypePtr& type,
-                      const std::string& name);
-
-  // serializer already serialized all the tensors, and stores
-  // the tensor and parameter tables
-  const ScriptModuleSerializer* serializer_;
-
-  // Used to create sequential dummy names for node types
-  size_t type_counter_ = 0;
-};
-
 // this is a serializer class which saves script modules to pt files. the
 // content of the file is written using PyTorchStreamWriter, for details please
 // check caffe2/serialize/inline_container.h. all the records except the last
@@ -488,28 +442,19 @@ class ScriptModuleSerializer final {
 
   void serialize(const script::Module& module);
 
-  uint64_t lookupTensorId(const at::Tensor* tensor) const;
-
-  const std::string& lookupParamName(const at::Tensor* tensor) const;
-
  private:
-  void convertToModel(const script::Module& module, torch::ModelDef* model_def);
+  void convertModel(const script::Module& module, torch::ModelDef* model_def);
 
   // add a tensor to the tensorTable
-  void addTensor(const at::Tensor* tensor);
-
-  // recursively collect the tensors in a block and add them to the tensorTable
-  void findTensorInBlock(const Block& block);
-
-  // recursively iterate over the whole module to collect the information of
-  // tensors and parameters
-  void collectInfo(const script::Module& module, const std::string& prefix);
+  // returns the offset into the tensor table
+  size_t addTensor(const at::Tensor& tensor);
 
   // write the content of the tensor to the file/stream, and save the
   // offset in the storageMap_
   void convertAndWriteTensor(
       const at::Tensor& tensor,
-      caffe2::TensorProto* tensor_proto);
+      torch::TensorDef* tensor_proto,
+      std::unordered_map<const void*, uint64_t>& storageMap);
 
   // dump all the tensors in the tensorTable_ to a ModelDef (metadata) and
   // the file/stream (the content), assuming all the information of the
@@ -526,193 +471,12 @@ class ScriptModuleSerializer final {
       const script::NamedParameter& param,
       torch::ParameterDef* param_def);
 
-  void convertMethod(
-      const script::Method& method,
-      torch::MethodDef* method_def);
-
   std::ofstream ofs_;
   PyTorchStreamWriter writer_;
-  // storage_ptr => record_offset
-  std::unordered_map<const void*, uint64_t> storageMap_;
-  // tensor => param name
-  std::unordered_map<const at::Tensor*, std::string> paramMap_;
-  // tensor => tensor_id
-  std::unordered_map<const at::Tensor*, uint64_t> tensorTable_;
-  // used for generating table id for tensors
-  uint64_t nextTensorId_ = 0;
-};
-
-// MethodEncoder's methods
-MethodEncoder::MethodEncoder(
-    const script::Method& method,
-    const ScriptModuleSerializer& serializer)
-    : EncoderBase(onnx_torch::OperatorExportTypes::RAW, false) {
-  serializer_ = &serializer;
-}
-
-std::string MethodEncoder::EncodeMethod(
-    const script::Method& method,
-    const std::string& prefix) {
-  onnx::ModelProto model_proto;
-  model_proto.set_doc_string("THIS PROTO IS NOT STANDARD ONNX");
-  auto* node_proto = model_proto.mutable_graph()->add_node();
-  node_proto->set_name(prefix + method.name());
-
-  // We store the schema string in the docstring.
-  node_proto->set_doc_string(getExportableSchemaStringForMethod(method));
-
-  // Store member_inputs of Method in input
-  for (auto& member_input : method.params()) {
-    const auto& param_name = serializer_->lookupParamName(member_input);
-    node_proto->add_input(param_name);
-  }
-
-  auto attr_proto = node_proto->add_attribute();
-  attr_proto->set_type(onnx::AttributeProto_AttributeType_GRAPH);
-
-  for (auto node : method.graph()->nodes()) {
-    if (node->kind() == prim::PythonOp) {
-      auto py_node = static_cast<torch::jit::PythonOp*>(node);
-      throw std::runtime_error(
-          "Couldn't export Python operator " + py_node->name() +
-          "\n\nDefined at:\n" + getNodeStackTraceString(node));
-    }
-  }
-  EncodeBlock(attr_proto->mutable_g(), method.graph()->block(), {});
-  std::string torch_script;
-  AT_ASSERT(model_proto.SerializeToString(&torch_script));
-  return torch_script;
-}
-
-void MethodEncoder::EncodeTensor(
-    onnx::TensorProto* tensor_proto,
-    const at::Tensor& tensor,
-    const c10::optional<std::string> external_ref) {
-  uint64_t tensor_id = serializer_->lookupTensorId(&tensor);
-  tensor_proto->set_name(c10::to_string(tensor_id));
-  // No need to store the content of the tensor to the file/stream
-  // any more, since it is already saved at the beginning of the
-  // serialization in writeTensorTable
-}
-
-void MethodEncoder::EncodeIntermediateValueInfo(
-    onnx::GraphProto* graph_proto,
-    const Value* n) {
-  auto v = graph_proto->add_value_info();
-  EncodeTypeInfo(graph_proto, v, n->type(), n->uniqueName());
-}
-
-c10::optional<std::string> getBaseTypeDenotation(TypeKind& kind) {
-  if (kind == TypeKind::NumberType) {
-    return "NumberType";
-  } else if (kind == TypeKind::FloatType) {
-    return "FloatType";
-  } else if (kind == TypeKind::IntType) {
-    return "IntType";
-  } else if (kind == TypeKind::BoolType) {
-    return "BoolType";
-  } else if (kind == TypeKind::NoneType) {
-    return "NoneType";
-  } else if (kind == TypeKind::GeneratorType) {
-    return "GeneratorType";
-  } else if (kind == TypeKind::StringType) {
-    return "StringType";
-  }
-  return c10::nullopt;
-}
-
-void MethodEncoder::EncodeTypeInfo(
-    onnx::GraphProto* graph_proto,
-    onnx::ValueInfoProto* v,
-    const TypePtr& type,
-    const std::string& name) {
-  v->set_name(name);
-  onnx::TypeProto* type_proto = v->mutable_type();
-  onnx::TypeProto_Tensor* tensortype_proto = type_proto->mutable_tensor_type();
-  onnx::TensorShapeProto* shape_proto = tensortype_proto->mutable_shape();
-
-  // Use TypeProto fields to encode types.
-  // denotation stores the type as a string
-  auto kind = type->kind();
-  if (kind == TypeKind::DynamicType) {
-    type_proto->set_denotation("DynamicType");
-    tensortype_proto->set_elem_type(onnx::TensorProto_DataType_UNDEFINED);
-  } else if (kind == TypeKind::TensorType) {
-    type_proto->set_denotation("TensorType");
-    // encode the number of dimensions by pushing that number of ones into the shape proto
-    auto tensor_type = type->expect<TensorType>();
-    for (int i = 0; i < tensor_type->dim(); i++) {
-      shape_proto->add_dim();
-      shape_proto->mutable_dim(i)->set_dim_value(1);
-    }
-    tensortype_proto->set_elem_type(ATenTypeToOnnxType(tensor_type->scalarType()));
-  } else if (kind == TypeKind::CompleteTensorType) {
-    type_proto->set_denotation("CompleteTensorType");
-    CompleteTensorTypePtr node_type = type->cast<CompleteTensorType>();
-
-    // store the sizes and strides in the dims field of TensorShapeProto
-    size_t i = 0;
-    for (auto &size : node_type->sizes()) {
-      shape_proto->add_dim();
-      shape_proto->mutable_dim(i)->set_dim_value(size);
-      i++;
-    }
-    for (auto &stride : node_type->strides()) {
-      shape_proto->add_dim();
-      shape_proto->mutable_dim(i)->set_dim_value(stride);
-      i++;
-    }
-    tensortype_proto->set_elem_type(ATenTypeToOnnxType(node_type->scalarType()));
-  } else if (kind == TypeKind::TupleType) {
-    type_proto->set_denotation("TupleType");
-    TupleTypePtr node_type = type->cast<TupleType>();
-    auto elements = node_type->elements();
-
-    // Generate a name for and encode each subtype in the value_info field of the GraphProto.
-    for (size_t i = 0; i < elements.size(); i++) {
-      std::string name = "#" + std::to_string(type_counter_++);
-      shape_proto->add_dim();
-      shape_proto->mutable_dim(i)->set_dim_param(name);
-      onnx::ValueInfoProto* subtype_proto = graph_proto->add_value_info();
-      EncodeTypeInfo(graph_proto, subtype_proto, elements[i], name);
-    }
-  } else if (kind == TypeKind::ListType) {
-    type_proto->set_denotation("ListType");
-    ListTypePtr node_type = type->cast<ListType>();
-
-    // Generate a name for and encode the subtype in the value_info field of the GraphProto.
-    std::string name = "#" + std::to_string(type_counter_++);
-    shape_proto->add_dim();
-    shape_proto->mutable_dim(0)->set_dim_param(name);
-    onnx::ValueInfoProto* subtype_proto = graph_proto->add_value_info();
-    EncodeTypeInfo(graph_proto, subtype_proto, node_type->getElementType(), name);
-  } else if (kind == TypeKind::VarType) {
-    type_proto->set_denotation("TypeVar:" + type->expect<VarType>()->name());
-  } else if (kind == TypeKind::OptionalType) {
-    type_proto->set_denotation("OptionalType");
-    OptionalTypePtr node_type = type->cast<OptionalType>();
-
-    // Generate a name for and encode each subtype in the value_info field of the GraphProto.
-    std::string name = "#" + std::to_string(type_counter_++);
-    shape_proto->add_dim();
-    shape_proto->mutable_dim(0)->set_dim_param(name);
-    onnx::ValueInfoProto* subtype_proto = graph_proto->add_value_info();
-    EncodeTypeInfo(graph_proto, subtype_proto, node_type->getElementType(), name);
-  } else {
-    auto denotation = getBaseTypeDenotation(kind);
-    if (!denotation) {
-      throw std::runtime_error("unexpected type kind");
-    }
-    type_proto->set_denotation(*denotation);
-  }
-}
 
-void MethodEncoder::EncodeValueInfo(
-    onnx::GraphProto* graph_proto,
-    onnx::ValueInfoProto* v,
-    const Value* n) {
-  EncodeTypeInfo(graph_proto, v, n->type(), n->uniqueName());
-}
+  // all tensors that will be stored
+  std::vector<at::Tensor> tensor_table_;
+};
 
 // ScriptModuleSerializer's methods
 ScriptModuleSerializer::ScriptModuleSerializer(const std::string& filename)
@@ -728,7 +492,7 @@ ScriptModuleSerializer::ScriptModuleSerializer(std::ostream* ofs)
 
 void ScriptModuleSerializer::serialize(const script::Module& module) {
   torch::ModelDef model_def;
-  convertToModel(module, &model_def);
+  convertModel(module, &model_def);
   std::string output;
   // NB: cannot use MessageToJsonString, since fbcode's protobuf is too old
   // be consistent with MessageToJsonString
@@ -752,157 +516,79 @@ void ScriptModuleSerializer::serialize(const script::Module& module) {
   writer_.writeEndOfFile();
 }
 
-uint64_t ScriptModuleSerializer::lookupTensorId(
-    const at::Tensor* tensor) const {
-  auto it = tensorTable_.find(tensor);
-  AT_ASSERT(it != tensorTable_.end());
-  return it->second;
-}
-
-const std::string& ScriptModuleSerializer::lookupParamName(
-    const at::Tensor* tensor) const {
-  auto it = paramMap_.find(tensor);
-  AT_ASSERT(it != paramMap_.end());
-  return it->second;
-}
-
-void ScriptModuleSerializer::convertToModel(
+void ScriptModuleSerializer::convertModel(
     const script::Module& module,
     torch::ModelDef* model_def) {
-  model_def->set_name("script-model");
   model_def->set_producer_name("pytorch");
   model_def->set_producer_version("1.0"); // TODO: set the producer version
                                           // using appropriate function call
   model_def->set_proto_version(torch::ProtoVersion::PROTO_VERSION_NEWEST);
   std::string main_module_name = "";
-  nextTensorId_ = 0;
-  collectInfo(module, main_module_name);
-  writeTensorTable(model_def);
   convertModule(module, main_module_name, model_def->mutable_main_module());
+  writeTensorTable(model_def);
 }
 
-void ScriptModuleSerializer::addTensor(const at::Tensor* tensor) {
-  if (tensorTable_.find(tensor) == tensorTable_.end()) {
-    tensorTable_[tensor] = nextTensorId_;
-    ++nextTensorId_;
-  }
-}
-
-void ScriptModuleSerializer::findTensorInBlock(const Block& block) {
-  for (auto node : block.nodes()) {
-    for (auto attr_name : node->attributeNames()) {
-      AT_ASSERT(attr_name.is_attr());
-      switch (node->kindOf(attr_name)) {
-        case AttributeKind::f:
-        case AttributeKind::fs:
-        case AttributeKind::i:
-        case AttributeKind::is:
-        case AttributeKind::s:
-        case AttributeKind::ss:
-          break;
-        case AttributeKind::t: {
-          const at::Tensor* tensor = &node->t(attr_name);
-          addTensor(tensor);
-        } break;
-        case AttributeKind::ts: {
-          for (auto& v : node->ts(attr_name)) {
-            const at::Tensor* tensor = &v;
-            addTensor(tensor);
-          }
-        } break;
-        case AttributeKind::g: {
-          findTensorInBlock(*node->g(attr_name)->block());
-        } break;
-        case AttributeKind::gs: {
-          for (auto& v : node->gs(attr_name)) {
-            findTensorInBlock(*v->block());
-          }
-        } break;
-        default:
-          AT_ERROR("unexpected attribute kind");
-      }
-    }
-    for (auto b : node->blocks()) {
-      findTensorInBlock(*b);
-    }
-  }
-}
-
-void ScriptModuleSerializer::collectInfo(
-    const script::Module& module,
-    const std::string& prefix) {
-  for (const auto& elem : module.get_parameters()) {
-    const script::NamedParameter& param = elem.value();
-    paramMap_[param.slot()] = prefix + param.name;
-    addTensor(param.slot());
-  }
-  for (const auto& elem : module.get_methods()) {
-    findTensorInBlock(*elem.value()->graph()->block());
-  }
-  for (const auto& elem : module.get_modules()) {
-    collectInfo(*elem->module, prefix + elem.key() + ".");
-  }
+size_t ScriptModuleSerializer::addTensor(const at::Tensor& tensor) {
+  tensor_table_.push_back(tensor);
+  return tensor_table_.size() - 1;
 }
 
 void ScriptModuleSerializer::convertAndWriteTensor(
     const at::Tensor& tensor,
-    caffe2::TensorProto* tensor_proto) {
-  auto tensor_it = tensorTable_.find(&tensor);
-  AT_ASSERT(tensor_it != tensorTable_.end());
-  tensor_proto->set_name(c10::to_string(tensor_it->second));
+    torch::TensorDef* tensor_proto,
+    std::unordered_map<const void*, uint64_t>& storageMap) {
   for (auto d : tensor.sizes()) {
     tensor_proto->add_dims(d);
   }
-  tensor_proto->set_data_type(caffe2::TypeMetaToDataType(
-      at::scalarTypeToTypeMeta(tensor.type().scalarType())));
-  tensor_proto->set_storage_type(caffe2::TensorProto_StorageType_EXTERNAL);
-  caffe2::ExternalDataProto* external_data =
-      tensor_proto->mutable_external_data();
   for (auto s : tensor.strides()) {
-    external_data->add_strides(s);
+    tensor_proto->add_strides(s);
   }
-  external_data->set_offset(tensor.storage_offset());
+  tensor_proto->set_data_type(caffe2::TypeMetaToDataType(
+      at::scalarTypeToTypeMeta(tensor.type().scalarType())));
+  tensor_proto->set_offset(tensor.storage_offset());
+
+  tensor_proto->set_requires_grad(tensor.requires_grad());
+
   uint64_t record_size =
       tensor.type().elementSizeInBytes() * tensor.storage().size();
-  external_data->set_record_size(record_size);
   auto* key = tensor.storage().unsafeGetStorageImpl();
-  auto storage_it = storageMap_.find(key);
-  if (storage_it == storageMap_.end()) {
+
+  auto storage_it = storageMap.find(key);
+  if (storage_it == storageMap.end()) {
+    at::Tensor storage_tensor = tensor;
     // TODO HIP support
-    uint64_t record_id;
     if (tensor.storage().device_type() == at::DeviceType::CUDA) {
       // NB: This new tensor is created to support cuda tensors.
       // Storages can be mutated when converting tensors from cuda to cpu,
       // and we need a cpu tensor to copy data from.
-      at::Tensor t = at::getType(tensor)
-                         ._th_tensor(
-                             tensor.storage(),
-                             /* storageOffset = */ 0,
-                             /* size = */
-                             {static_cast<int64_t>(tensor.storage().size())},
-                             /* stride = */ {1})
-                         .cpu();
+      storage_tensor = at::getType(tensor)
+                           ._th_tensor(
+                               tensor.storage(),
+                               /* storageOffset = */ 0,
+                               /* size = */
+                               {static_cast<int64_t>(tensor.storage().size())},
+                               /* stride = */ {1})
+                           .cpu();
       AT_ASSERT(
-          t.type().elementSizeInBytes() * t.storage().size() == record_size);
-      record_id = writer_.writeRecord(
-          t.storage().data(),
-          t.type().elementSizeInBytes() * t.storage().size());
-    } else {
-      record_id = writer_.writeRecord(tensor.storage().data(), record_size);
+          storage_tensor.type().elementSizeInBytes() * storage_tensor.storage().size() ==
+          record_size);
     }
-    external_data->set_record_id(c10::to_string(record_id));
-    storageMap_[key] = record_id;
-  } else {
-    external_data->set_record_id(c10::to_string(storage_it->second));
+    uint64_t record_id = writer_.writeRecord(storage_tensor.storage().data(), record_size);
+    storage_it = storageMap.insert({key, record_id}).first;
   }
+
+  auto* data = tensor_proto->mutable_data();
+  data->set_key(std::to_string(storage_it->second));
+  data->set_size(record_size);
+
   // TODO handle device case, set the device_detail and load to CUDA device
 }
 
 void ScriptModuleSerializer::writeTensorTable(torch::ModelDef* model_def) {
-  // NB: we don't reserve any order for tensors in the tensorTable_
-  for (const auto& kv : tensorTable_) {
+  std::unordered_map<const void*, uint64_t> storageMap;
+  for (const at::Tensor& t : tensor_table_) {
     auto* tensor_proto = model_def->add_tensors();
-    convertAndWriteTensor(*kv.first, tensor_proto);
+    convertAndWriteTensor(t, tensor_proto, storageMap);
   }
 }
 
@@ -916,10 +602,16 @@ void ScriptModuleSerializer::convertModule(
     torch::ParameterDef* param_def = module_def->add_parameters();
     convertParameter(elem.value(), param_def);
   }
-  for (auto& elem : module.get_methods()) {
-    torch::MethodDef* method_def = module_def->add_methods();
-    convertMethod(*elem.value(), method_def);
-  }
+
+  std::ostringstream ss;
+  ss << "op_version_set = 0\n";
+  PythonPrint(ss, module, tensor_table_, /*enforce_importable=*/true);
+  torch::RecordRef* record = module_def->mutable_torchscript_arena();
+  std::string str = ss.str();
+  auto key = writer_.writeRecord(str.c_str(), str.size());
+  record->set_key(std::to_string(key));
+  record->set_size(str.size());
+
   for (const auto& elem : module.get_modules()) {
     torch::ModuleDef* sub_def = module_def->add_submodules();
     convertModule(*elem->module, elem.key(), sub_def);
@@ -931,24 +623,10 @@ void ScriptModuleSerializer::convertParameter(
     torch::ParameterDef* param_def) {
   param_def->set_name(param.name);
   param_def->set_is_buffer(param.is_buffer);
-  param_def->set_require_gradient(param.slot()->requires_grad());
-  auto it = tensorTable_.find(param.slot());
-  AT_ASSERT(it != tensorTable_.end());
-  param_def->set_tensor_id(c10::to_string(it->second));
-}
-
-void ScriptModuleSerializer::convertMethod(
-    const script::Method& method,
-    torch::MethodDef* method_def) {
-  // TODO encode the real torch script instead of ModelProto
-  MethodEncoder encoder(method, *this);
-  // we already keep the tree structure in the top level module,
-  // so pass "" as prefix
-  std::string torch_script = encoder.EncodeMethod(method, "");
-  method_def->set_onnx_proto(torch_script);
+  param_def->set_tensor_id(addTensor(*param.slot()));
 }
 
-// Pretty printing
+// Pretty printing for ONNX
 constexpr char indent_char = ' ';
 constexpr size_t indent_multiplier = 2;
 
@@ -1124,7 +802,7 @@ std::string prettyPrint(const onnx::ModelProto& model) {
 
 } // namespace
 
-std::string PrettyPrintExportedGraph(
+std::string pretty_print_onnx(
                         const std::shared_ptr<Graph> &graph,
                         const std::vector<at::Tensor> &initializers,
                         int64_t onnx_opset_version,
@@ -1144,7 +822,7 @@ std::string PrettyPrintExportedGraph(
 // conform to the ONNX op specification. Thus, the output will not
 // be interpretable by a ONNX-compatible framework. However, PyTorch or
 // libtorch will be able to import the IR and play it back.
-std::tuple<std::string, RawDataExportMap> ExportGraph(
+std::tuple<std::string, RawDataExportMap> export_onnx(
                         const std::shared_ptr<Graph> &graph,
                         const std::vector<at::Tensor> &initializers,
                         int64_t onnx_opset_version,
index 363de0b..6865a1a 100644 (file)
@@ -18,7 +18,7 @@ namespace torch { namespace jit {
 // file contents being the raw tensor data.
 using RawDataExportMap = std::unordered_map<std::string, at::Tensor>;
 
-TORCH_API std::tuple<std::string, RawDataExportMap> ExportGraph(
+TORCH_API std::tuple<std::string, RawDataExportMap> export_onnx(
     const std::shared_ptr<Graph>& graph,
     const std::vector<at::Tensor>& initializers,
     int64_t onnx_opset_version,
@@ -27,7 +27,7 @@ TORCH_API std::tuple<std::string, RawDataExportMap> ExportGraph(
       = ::torch::onnx::OperatorExportTypes::ONNX);
 
 // For testing purposes
-TORCH_API std::string PrettyPrintExportedGraph(
+TORCH_API std::string pretty_print_onnx(
     const std::shared_ptr<Graph>& graph,
     const std::vector<at::Tensor> & initializers,
     int64_t onnx_opset_version,
index 131931f..d7cb726 100644 (file)
@@ -6,12 +6,13 @@
 #include "torch/csrc/utils/functional.h"
 #include "torch/csrc/jit/assertions.h"
 #include "torch/csrc/jit/operator.h"
+#include "torch/csrc/jit/import_method.h"
+
 
 #include "caffe2/core/types.h"
 #include "caffe2/proto/caffe2_pb.h"
 #include "caffe2/proto/torch_pb.h"
 #include "caffe2/serialize/inline_container.h"
-#include "onnx/onnx_pb.h"
 
 #include <ATen/ATen.h>
 
@@ -24,48 +25,6 @@ namespace torch { namespace jit {
 
 namespace {
 
-namespace onnx = ::ONNX_NAMESPACE;
-
-// IR graph construction
-
-class ScriptModuleDeserializer;
-
-class MethodDecoder {
- public:
-  MethodDecoder(
-      const onnx::ModelProto& model_proto,
-      script::Module* parent_module,
-      ScriptModuleDeserializer* deserializer);
-
- private:
-  std::shared_ptr<Graph> buildGraph(const onnx::GraphProto& graph_proto);
-
-  void buildBlock(const onnx::GraphProto& graph_proto, Block* block,
-                  std::unordered_map<std::string, Value*>& value_map);
-
-  void buildBlocks(const std::vector<onnx::GraphProto>& graphs_, Node* node,
-                   std::unordered_map<std::string, Value*>& value_map);
-
-  void buildValue(Value* value, const onnx::ValueInfoProto& valueinfo_proto);
-
-  void buildIntermediateValue(Value* value, const std::string& name);
-
-  at::ScalarType onnxTypeToATenType(int32_t tensor_proto);
-
-  at::Tensor buildTensor(const onnx::TensorProto& tensor_proto);
-
-  TypePtr buildType(const onnx::TypeProto& type_proto);
-
-  std::pair<std::shared_ptr<script::Module>, std::string> parseFullName(
-      ModuleLookup module_lookup,
-      const std::string fullname);
-
-  // deserializer already loads the metadata of tensors, and it is used to
-  // load tensors
-  ScriptModuleDeserializer* deserializer_;
-  std::unordered_map<std::string, const onnx::TypeProto*> value_type_map_;
-};
-
 // this is a deserializer class which loads script modules from pt files. the
 // content of the file is written using PyTorchStreamWriter, for details please
 // check caffe2/serialize/inline_container.h. all the records except the last
@@ -80,286 +39,24 @@ class ScriptModuleDeserializer final {
 
   void deserialize(ModuleLookup module_lookup);
 
-  // given the tensor id, load the data of the tensor from file/stream,
-  // and return a new tensor which contains the loaded data
-  at::Tensor loadTensor(uint64_t tensor_id);
-
-  at::Tensor* lookupTensor(const std::string& param_name) const;
-
- private:
-  // recursively load all the parameters of a module, and construct a
-  // parameter map (i.e., name => tensor). call loadTensor to load and
-  // create a new tensor
-  void loadParams(
-      const torch::ModuleDef& module_def,
-      const std::string& prefix);
-
-  void convertModule(
-      const torch::ModuleDef& module_def,
-      script::Module* module);
-
-  std::ifstream ifs_;
-  PyTorchStreamReader reader_;
-  // this is a hack to make sure the script module created in C++ is the
-  // same as created in Python
-  ModuleLookup moduleLookup_;
-  std::vector<std::string> moduleStack_;
-  // record_id => storage
-  std::unordered_map<uint64_t, std::shared_ptr<at::Storage>> storageMap_;
-  // tensor_id => TensorProto
-  std::unordered_map<uint64_t, const caffe2::TensorProto*> metaMap_;
-  // parameter_name => at::Tensor
-  std::unordered_map<std::string, at::Tensor*> paramMap_;
-};
-
-at::ScalarType MethodDecoder::onnxTypeToATenType(int32_t onnx_type) {
-  switch(onnx_type) {
-    case onnx::TensorProto_DataType_UINT8:
-      return at::kByte;
-    case onnx::TensorProto_DataType_INT8:
-      return at::kChar;
-    case onnx::TensorProto_DataType_INT16:
-      return at::kShort;
-    case onnx::TensorProto_DataType_INT32:
-      return at::kInt;
-    case onnx::TensorProto_DataType_INT64:
-      return at::kLong;
-    case onnx::TensorProto_DataType_FLOAT16:
-      return at::kHalf;
-    case onnx::TensorProto_DataType_FLOAT:
-      return at::kFloat;
-    case onnx::TensorProto_DataType_DOUBLE:
-      return at::kDouble;
-    default:
-      throw std::runtime_error("Unsupported data type");
-  }
-}
-
-MethodDecoder::MethodDecoder(
-    const onnx::ModelProto& model_proto,
-    script::Module* parent_module,
-    ScriptModuleDeserializer* deserializer) {
-  deserializer_ = deserializer;
-  const auto& graph_proto = model_proto.graph();
-  for (const auto& node_proto : graph_proto.node()) {
-    std::vector<at::Tensor*> member_inputs;
-    const std::string& name = node_proto.name();
-    for (const auto& param_name : node_proto.input()) {
-      at::Tensor* tensor = deserializer_->lookupTensor(param_name);
-      member_inputs.push_back(tensor);
-    }
-    auto graph = buildGraph(node_proto.attribute(0).g());
-    parent_module->create_method(name, graph, member_inputs);
-    // We store the schema in the docstring so we can parse the schema and
-    // assign it to the method.
-    auto schema = parseSchema(node_proto.doc_string());
-    parent_module->get_method(name).setSchema(std::move(schema));
-  }
-}
-
-void MethodDecoder::buildBlocks(
-    const std::vector<onnx::GraphProto>& graphs_,
-    Node* node,
-    std::unordered_map<std::string, Value*>& value_map) {
-  for (auto g_ : graphs_) {
-    auto block = node->addBlock();
-    buildBlock(g_, block, value_map);
-  }
-}
-
-std::shared_ptr<Graph> MethodDecoder::buildGraph(
-    const onnx::GraphProto& graph_proto) {
-  auto graph = std::make_shared<Graph>();
-  std::unordered_map<std::string, Value*> value_map;
-
-  buildBlock(graph_proto, graph->block(), value_map);
-
-  return graph;
-}
-
-void MethodDecoder::buildBlock(
-    const onnx::GraphProto& graph_proto,
-    Block* block,
-    std::unordered_map<std::string, Value*>& value_map) {
-  for (auto &subtype : graph_proto.value_info()) {
-    value_type_map_[subtype.name()] = &subtype.type();
-  }
+private:
+ at::Tensor loadTensor(
+     const torch::TensorDef& tensor_proto,
+     std::unordered_map<uint64_t, at::Storage>& storageMap);
 
-  for (auto & input : graph_proto.input()) {
-    auto value = block->addInput();
-    value_map[input.name()] = value;
-    buildValue(value, input);
-  }
+ void convertModule(const torch::ModuleDef& module_def);
 
-  for (auto & node_ : graph_proto.node()) {
-    JIT_ASSERT(node_.op_type() != "PythonOp");
-
-    auto node = block->owningGraph()->create(Symbol::fromDomainAndUnqualString(node_.domain(), node_.op_type()),
-                                             node_.output().size());
-
-    for (auto & attr : node_.attribute()) {
-      Symbol name = Symbol::attr(attr.name());
-
-      switch(attr.type()) {
-        case onnx::AttributeProto_AttributeType_UNDEFINED:
-          throw std::runtime_error("UNDEFINED attribute unsupported");
-          break;
-        case onnx::AttributeProto_AttributeType_FLOAT:
-          node->f_(name, attr.f());
-          break;
-        case onnx::AttributeProto_AttributeType_INT:
-          node->i_(name, attr.i());
-          break;
-        case onnx::AttributeProto_AttributeType_STRING:
-          node->s_(name, std::move(attr.s()));
-          break;
-        case onnx::AttributeProto_AttributeType_TENSOR:
-          node->t_(name, buildTensor(attr.t()));
-          break;
-        case onnx::AttributeProto_AttributeType_GRAPH:
-          node->g_(name, buildGraph(attr.g()));
-          break;
-        case onnx::AttributeProto_AttributeType_FLOATS:
-          node->fs_(name, {attr.floats().begin(), attr.floats().end()});
-          break;
-        case onnx::AttributeProto_AttributeType_INTS:
-          node->is_(name, {attr.ints().begin(), attr.ints().end()});
-          break;
-        case onnx::AttributeProto_AttributeType_STRINGS:
-          node->ss_(name, {attr.strings().begin(), attr.strings().end()});
-          break;
-        case onnx::AttributeProto_AttributeType_TENSORS:
-          node->ts_(name, fmap(attr.tensors(), [this](const onnx::TensorProto& t) {
-                                                 return buildTensor(t);
-                                               }));
-          break;
-        case onnx::AttributeProto_AttributeType_GRAPHS:
-          if (attr.name() == "_blocks") {
-            buildBlocks({attr.graphs().begin(), attr.graphs().end()}, node, value_map);
-          }
-          else {
-            node->gs_(name, fmap(attr.graphs(), [this](const onnx::GraphProto& g_) {
-                                                  return buildGraph(g_);
-                                                }));
-          }
-          break;
-      }
-    }
+ void loadTensorTable(torch::ModelDef* model_def);
 
-    for (auto & input : node_.input()) {
-      auto v = value_map[input];
-      node->addInput(v);
-    }
+ std::ifstream ifs_;
+ PyTorchStreamReader reader_;
+ // this is a hack to make sure the script module created in C++ is the
+ // same as created in Python
+ ModuleLookup moduleLookup_;
+ std::vector<std::string> moduleStack_;
 
-    for (int i=0; i<node_.output().size(); i++) {
-      value_map[node_.output(i)] = node->outputs()[i];
-      buildIntermediateValue(node->outputs()[i], node_.output(i));
-    }
-
-    block->appendNode(node);
-  }
-
-  for (auto & output : graph_proto.output()) {
-    Value* v = value_map.at(output.name());
-    buildValue(v, output);
-    block->registerOutput(v);
-  }
-}
-
-TypePtr MethodDecoder::buildType(const onnx::TypeProto& type_proto) {
-  auto tensortype_proto = type_proto.tensor_type();
-  auto shape_proto = tensortype_proto.shape();
-  auto kind = type_proto.denotation();
-  if (kind == "DynamicType") {
-    return DynamicType::get();
-  } else if (kind == "TensorType") {
-    auto dims = shape_proto.dim_size();
-    return TensorType::create(onnxTypeToATenType(tensortype_proto.elem_type()), at::kCPU, dims);
-  } else if (kind == "CompleteTensorType") {
-    // first half of the dims are sizes and the second half are strides
-    auto total = shape_proto.dim_size();
-    std::vector<int64_t> sizes, strides;
-    for (int i = 0; i < total / 2; i++) {
-      sizes.push_back(shape_proto.dim(i).dim_value());
-    }
-    for (int i = total / 2; i < total; i++) {
-      strides.push_back(shape_proto.dim(i).dim_value());
-    }
-    return CompleteTensorType::create(onnxTypeToATenType(tensortype_proto.elem_type()), at::kCPU, sizes, strides);
-  } else if (kind == "TupleType") {
-    std::vector<TypePtr> elems;
-    for (auto &subkind : shape_proto.dim()) {
-      auto it = value_type_map_.find(subkind.dim_param());
-      JIT_ASSERT(it != value_type_map_.end());
-      elems.push_back(buildType(*it->second));
-    }
-    return TupleType::create(elems);
-  } else if (kind == "ListType") {
-    auto subkind = shape_proto.dim(0);
-    auto it = value_type_map_.find(subkind.dim_param());
-    JIT_ASSERT(it != value_type_map_.end());
-    return ListType::create(buildType(*it->second));
-  } else if (kind == "NumberType") {
-    return NumberType::get();
-  } else if (kind == "FloatType") {
-    return FloatType::get();
-  } else if (kind == "IntType") {
-    return IntType::get();
-  } else if (kind == "BoolType") {
-    return BoolType::get();
-  } else if (kind == "NoneType") {
-    return NoneType::get();
-  } else if (kind == "GeneratorType") {
-    return GeneratorType::get();
-  } else if (kind == "StringType") {
-    return StringType::get();
-  } else if (kind == "OptionalType") {
-    auto subkind = shape_proto.dim(0);
-    auto it = value_type_map_.find(subkind.dim_param());
-    JIT_ASSERT(it != value_type_map_.end());
-    return OptionalType::create(buildType(*it->second));
-  } else if (kind.find("TypeVar:") == 0) {
-    return VarType::create(kind.substr(strlen("TypeVar:")));
-  } else {
-    throw std::runtime_error("unexpected string for type kind: " + kind);
-  }
-}
-
-void MethodDecoder::buildValue(
-    Value* value,
-    const onnx::ValueInfoProto& valueinfo_proto) {
-  value->setType(buildType(valueinfo_proto.type()));
-}
-
-void MethodDecoder::buildIntermediateValue(
-    Value* value,
-    const std::string& name) {
-  auto it = value_type_map_.find(name);
-  JIT_ASSERT(it != value_type_map_.end());
-  value->setType(buildType(*it->second));
-}
-
-// Given a full name of a parameter or method,
-// return the parent submodule and local name
-std::pair<std::shared_ptr<script::Module>, std::string> MethodDecoder::
-    parseFullName(ModuleLookup module_lookup, const std::string fullname) {
-  AT_ASSERT(!fullname.empty());
-  std::vector<std::string> vec;
-  std::stringstream ss(fullname);
-  std::string name;
-  while (std::getline(ss, name, '.')) {
-    vec.push_back(name);
-  }
-
-  std::string last = vec.back();
-  vec.pop_back();
-  return std::make_pair(module_lookup(vec), std::move(last));
-}
-
-at::Tensor MethodDecoder::buildTensor(const onnx::TensorProto& tensor_proto) {
-  uint64_t tensor_id = caffe2::stoull(tensor_proto.name());
-  return deserializer_->loadTensor(tensor_id);
-}
+ std::vector<at::Tensor> tensor_table_;
+};
 
 ScriptModuleDeserializer::ScriptModuleDeserializer(const std::string& filename)
     : ifs_(filename, std::ifstream::in | std::ifstream::binary),
@@ -400,111 +97,68 @@ void ScriptModuleDeserializer::deserialize(ModuleLookup module_lookup) {
       "JSON transcoder produced invalid protobuf output.");
   moduleLookup_ = module_lookup;
 
-  metaMap_.clear();
-  for (int i = 0; i < model_def.tensors_size(); ++i) {
-    const auto& tensor_proto = model_def.tensors(i);
-    uint64_t tensor_id = caffe2::stoull(tensor_proto.name());
-    metaMap_[tensor_id] = &tensor_proto;
-  }
-
   const auto& module_def = model_def.main_module();
-
-  loadParams(module_def, module_def.name());
+  loadTensorTable(&model_def);
   // TODO: this can be simplified when C++/Python interop lands,
   // and the submodules would be created as the same in either C++ or Python
-  std::shared_ptr<script::Module> module = moduleLookup_(moduleStack_);
-  convertModule(module_def, module.get());
+  convertModule(module_def);
 }
 
-at::Tensor ScriptModuleDeserializer::loadTensor(uint64_t tensor_id) {
-  auto it = metaMap_.find(tensor_id);
-  AT_ASSERT(it != metaMap_.end());
-  const caffe2::TensorProto& tensor_proto = *it->second;
-  std::vector<int64_t> dims;
-  for (int i = 0; i < tensor_proto.dims_size(); ++i) {
-    dims.push_back(tensor_proto.dims(i));
-  }
-  AT_ASSERT(
-      tensor_proto.storage_type() == caffe2::TensorProto_StorageType_EXTERNAL);
-  const caffe2::ExternalDataProto& external_data = tensor_proto.external_data();
-  std::vector<int64_t> strides;
-  for (int i = 0; i < external_data.strides_size(); ++i) {
-    strides.push_back(external_data.strides(i));
+void ScriptModuleDeserializer::loadTensorTable(torch::ModelDef* model_def) {
+  std::unordered_map<uint64_t, at::Storage> storageMap;
+  for(const torch::TensorDef& tensor : model_def->tensors()) {
+    tensor_table_.emplace_back(loadTensor(tensor, storageMap));
   }
+}
+
+at::Tensor ScriptModuleDeserializer::loadTensor(const torch::TensorDef& tensor_proto,
+                std::unordered_map<uint64_t, at::Storage>& storageMap) {
+  std::vector<int64_t> dims(tensor_proto.dims().begin(), tensor_proto.dims().end());
+  std::vector<int64_t> strides(tensor_proto.strides().begin(), tensor_proto.strides().end());
   auto type = at::typeMetaToScalarType(
       caffe2::DataTypeToTypeMeta(tensor_proto.data_type()));
-  uint64_t record_id = caffe2::stoull(external_data.record_id());
-  AT_ASSERT(record_id != 0);
-  auto storage_it = storageMap_.find(record_id);
-  if (storage_it == storageMap_.end()) {
+
+  uint64_t record_id = caffe2::stoull(tensor_proto.data().key());
+  auto storage_it = storageMap.find(record_id);
+  if (storage_it == storageMap.end()) {
     at::DataPtr storage_ptr;
     uint64_t record_size;
     std::tie(storage_ptr, record_size) = reader_.getRecordWithKey(record_id);
-    AT_ASSERT(record_size == external_data.record_size());
-    auto storage = std::make_shared<at::Storage>(
+    AT_ASSERT(record_size == tensor_proto.data().size());
+    auto storage = at::Storage(
         at::CPU(type).typeMeta(),
         std::move(storage_ptr),
         record_size / at::CPU(type).typeMeta().itemsize(),
         nullptr); // NB: we didn't set any allocator for the tensor
-    storageMap_.insert(std::make_pair(record_id, storage));
-    return at::CPU(type)._th_tensor(
-        *storage, external_data.offset(), dims, strides);
-  }
-  return at::CPU(type)._th_tensor(
-      *(storage_it->second.get()), external_data.offset(), dims, strides);
-}
-
-at::Tensor* ScriptModuleDeserializer::lookupTensor(
-    const std::string& param_name) const {
-  auto it = paramMap_.find(param_name);
-  AT_ASSERTM(it != paramMap_.end(), "cannot find parameter ", param_name);
-  return it->second;
-}
-
-void ScriptModuleDeserializer::loadParams(
-    const torch::ModuleDef& module_def,
-    const std::string& prefix) {
-  std::shared_ptr<script::Module> module = moduleLookup_(moduleStack_);
-  for (int i = 0; i < module_def.parameters_size(); ++i) {
-    const torch::ParameterDef& param_def = module_def.parameters(i);
-    uint64_t tensor_id = caffe2::stoull(param_def.tensor_id());
-    at::Tensor tensor = loadTensor(tensor_id);
-    autograd::Variable variable =
-        autograd::make_variable(tensor, param_def.require_gradient());
-    module->register_parameter(
-        param_def.name(), variable, param_def.is_buffer());
-    paramMap_[prefix + param_def.name()] =
-        module->parameter_slot(param_def.name());
-  }
-  for (int i = 0; i < module_def.submodules_size(); ++i) {
-    const torch::ModuleDef& sub_def = module_def.submodules(i);
-    moduleStack_.push_back(sub_def.name());
-    loadParams(sub_def, prefix + sub_def.name() + ".");
-    moduleStack_.pop_back();
+    storage_it = storageMap.insert(std::make_pair(record_id, storage)).first;
   }
+  auto t = at::CPU(type)._th_tensor(
+      storage_it->second, tensor_proto.offset(), dims, strides);
+  return autograd::make_variable(t, tensor_proto.requires_grad());
 }
 
 void ScriptModuleDeserializer::convertModule(
-    const torch::ModuleDef& module_def,
-    script::Module* module) {
+    const torch::ModuleDef& module_def) {
+  std::shared_ptr<script::Module> module = moduleLookup_(moduleStack_);
   module->set_optimized(module_def.optimize());
-  for (int i = 0; i < module_def.methods_size(); ++i) {
-    const torch::MethodDef& method_def = module_def.methods(i);
-    // TODO read unhacked torch script, right now it's serialized onnx proto
-    ::ONNX_NAMESPACE::ModelProto method_proto;
-    AT_ASSERTM(
-        method_proto.ParseFromString(method_def.onnx_proto()),
-        "cannot parse method proto (i.e., hacked onnx proto)");
-    MethodDecoder decoder(method_proto, module, this);
-    (void)decoder;
-  }
   for (int i = 0; i < module_def.submodules_size(); ++i) {
     const torch::ModuleDef& sub_def = module_def.submodules(i);
-    moduleStack_.push_back(sub_def.name());
-    std::shared_ptr<script::Module> sub = moduleLookup_(moduleStack_);
-    convertModule(sub_def, sub.get());
+    moduleStack_.emplace_back(sub_def.name());
+    convertModule(sub_def);
     moduleStack_.pop_back();
   }
+  for (int i = 0; i < module_def.parameters_size(); ++i) {
+    const torch::ParameterDef& param_def = module_def.parameters(i);
+    at::Tensor tensor = tensor_table_.at(param_def.tensor_id());
+    module->register_parameter(
+        param_def.name(), tensor, param_def.is_buffer());
+  }
+  at::DataPtr data;
+  size_t size;
+  std::tie(data, size) = reader_.getRecordWithKey(caffe2::stoull(module_def.torchscript_arena().key()));
+  JIT_ASSERT(size == module_def.torchscript_arena().size());
+  std::string data_str(static_cast<const char*>(data.get()), size);
+  import_methods(module, data_str, tensor_table_);
 }
 
 }  // namespace
index 9abb9ab..124200c 100644 (file)
@@ -28,6 +28,18 @@ private:
   std::shared_ptr<script::Module> module;
 };
 
+struct OpsValue : public script::SugaredValue {
+  OpsValue(size_t version)
+  : version_(version) {}
+  std::string kind() const override {
+    return "ops";
+  }
+  std::shared_ptr<SugaredValue> attr(SourceRange loc, script::Method & m, const std::string& field) override {
+    return std::make_shared<script::BuiltinModule>(field, version_);
+  }
+  size_t version_;
+};
+
 struct ConstantValue : public script::SugaredValue {
   ConstantValue(IValue value)
   : value_(std::move(value)) {}
@@ -87,8 +99,8 @@ void import_methods(const std::shared_ptr<script::Module>& mod, const std::strin
   size_t version = parseVersionNumber(p.lexer());
 
   std::unordered_map<std::string, std::shared_ptr<script::SugaredValue>> env = {
-    {"aten", std::make_shared<script::BuiltinModule>("aten", version)},
-    {"prim", std::make_shared<script::BuiltinModule>("prim", version)},
+    {"torch", std::make_shared<script::BuiltinModule>("aten", version)},
+    {"ops", std::make_shared<OpsValue>(version)},
     {"CONSTANTS", std::make_shared<ConstantTableValue>(constant_table)},
     {"fork", std::make_shared<script::ForkValue>()},
     {"annotate", std::make_shared<script::AnnotateValue>()},
index c7078eb..3a674c0 100644 (file)
@@ -189,12 +189,14 @@ std::ostream& operator<<(std::ostream & out, const Graph & g) {
 }
 
 std::ostream& Graph::prettyPrint(std::ostream & out) {
-  PythonPrint(out, *this);
+  std::vector<at::Tensor> tensor_table;
+  PythonPrint(out, *this, tensor_table);
   return out;
 }
 
 void Graph::dumpPretty() {
-  PythonPrint(std::cout, *this);
+  std::vector<at::Tensor> tensor_table;
+  PythonPrint(std::cout, *this, tensor_table);
 }
 
 static void checkSameDevice(const Node* node) {
index d727b78..f642dd0 100644 (file)
@@ -143,7 +143,7 @@ void createTensorToParameterNameMap(
   const static std::unordered_set<std::string> reserved_names = {
     // identifiers in the environment while parsing
     "aten",
-    "prim",
+    "ops",
     "CONSTANTS",
     "fork",
     "attribute",
@@ -193,7 +193,7 @@ struct PythonPrintPass {
   // constants are written to this table, and given then named CONSTANTS.cN
   // where N is the index into this table.
 
-  std::vector<at::Tensor> tensor_constants;
+  std::vector<at::Tensor>& tensor_table_;
   // When printing this node, is it safe to write it inline (i.e. without
   // assigning a temporary variable
   std::unordered_set<Node*> output_inline_;
@@ -319,14 +319,14 @@ struct PythonPrintPass {
     // ConstantPool, which is also N^2 in the size of the constants,
     // because it doesn't hash any information about the tensors.
     // We will probably need to optimize this at some point using hashing.
-    for(size_t i = 0; i < tensor_constants.size(); ++i) {
-      if (t.type() == tensor_constants[i].type() && t.equal(tensor_constants[i])) {
+    for(size_t i = 0; i < tensor_table_.size(); ++i) {
+      if (t.type() == tensor_table_[i].type() && t.equal(tensor_table_[i])) {
         return i;
       }
     }
     JIT_ASSERT(t.is_variable());
-    tensor_constants.emplace_back(std::move(t));
-    return tensor_constants.size() - 1;
+    tensor_table_.emplace_back(std::move(t));
+    return tensor_table_.size() - 1;
   }
 
   std::unordered_set<Node*> seen_constants;
@@ -769,16 +769,29 @@ struct PythonPrintPass {
       } break;
       default: {
         Symbol kind = node->kind();
-        stmt << kind.ns().toUnqualString() << "." << kind.toUnqualString() << "(";
+        if (kind.is_aten()) {
+          // special case aten -> torch because we want to rename
+          // the aten namespace, but this change will take more time
+          // doing it here ensures we do not have fix up archives later
+          stmt << "torch." << kind.toUnqualString() << "(";
+        } else {
+          stmt << "ops." << kind.ns().toUnqualString() << "." << kind.toUnqualString() << "(";
+        }
         const FunctionSchema& schema = node->schema();
-        for (size_t i = 0; i < schema.arguments().size(); ++i) {
-            auto v = useOf(node->inputs().at(i));
-            auto arg = schema.arguments().at(i);
+        for (size_t i = 0; i < node->inputs().size(); ++i) {
             if (i > 0) {
               stmt << ", ";
             }
-            if (arg.kwarg_only()) {
-              stmt << arg.name() << "=";
+            auto v = useOf(node->inputs().at(i));
+            // print the kwarg name if it is a kwarg only argument.
+            if (i < schema.arguments().size()) {
+              auto arg = schema.arguments().at(i);
+              if (arg.kwarg_only()) {
+                stmt << arg.name() << "=";
+              }
+            } else {
+              // vararg functions like format can have extra arguments
+              JIT_ASSERT(schema.is_vararg());
             }
             stmt << v;
         }
@@ -873,8 +886,9 @@ struct PythonPrintPass {
  public:
   PythonPrintPass(
       std::ostream& out_,
+      std::vector<at::Tensor>& tensor_table,
       bool enforce_importable)
-      : out(out_), enforce_importable_(enforce_importable) {}
+      : out(out_), tensor_table_(tensor_table), enforce_importable_(enforce_importable) {}
 
   // TODO: we should consider forcing functions to return a single value
   // instead of handling this tuple logic both in the compiler and the printer
@@ -934,21 +948,22 @@ struct PythonPrintPass {
   }
 };
 
-TORCH_API std::vector<at::Tensor> PythonPrint(std::ostream& out, Graph& graph, bool enforce_importable) {
-  PythonPrintPass pp(out, enforce_importable);
-  pp.printFunction(graph, "graph");
-  return pp.tensor_constants;
+TORCH_API void PythonPrint(std::ostream& out, const Graph& graph, std::vector<at::Tensor>& tensor_table, bool enforce_importable) {
+  PythonPrintPass pp(out, tensor_table, enforce_importable);
+  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
+  pp.printFunction(const_cast<Graph&>(graph), "graph");
 }
-TORCH_API std::vector<at::Tensor> PythonPrint(std::ostream& out, script::Method& method, bool enforce_importable) {
-  PythonPrintPass pp(out, enforce_importable);
-  pp.printMethod(method);
-  return pp.tensor_constants;
+
+TORCH_API void PythonPrint(std::ostream& out, const script::Method& method, std::vector<at::Tensor>& tensor_table, bool enforce_importable) {
+  PythonPrintPass pp(out, tensor_table, enforce_importable);
+  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
+  pp.printMethod(const_cast<script::Method&>(method));
 }
 
-TORCH_API std::vector<at::Tensor> PythonPrint(std::ostream& out, script::Module& module, bool enforce_importable) {
-  PythonPrintPass pp(out, enforce_importable);
-  pp.printModule(module);
-  return pp.tensor_constants;
+TORCH_API void PythonPrint(std::ostream& out, const script::Module& module, std::vector<at::Tensor>& tensor_table, bool enforce_importable) {
+  PythonPrintPass pp(out, tensor_table, enforce_importable);
+  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
+  pp.printModule(const_cast<script::Module&>(module));
 }
 
 TORCH_API bool printerHasSpecialCaseFor(Symbol sym) {
index e1c7ef6..e815d90 100644 (file)
@@ -12,9 +12,9 @@ namespace script {
   struct Module;
 }
 
-TORCH_API std::vector<at::Tensor> PythonPrint(std::ostream& out, Graph& graph, bool enforce_importable=false);
-TORCH_API std::vector<at::Tensor> PythonPrint(std::ostream& out, script::Method& graph, bool enforce_importable=false);
-TORCH_API std::vector<at::Tensor> PythonPrint(std::ostream& out, script::Module& module, bool enforce_importable=false);
+TORCH_API void PythonPrint(std::ostream& out, const Graph& graph, std::vector<at::Tensor>& tensor_table, bool enforce_importable=false);
+TORCH_API void PythonPrint(std::ostream& out, const script::Method& graph, std::vector<at::Tensor>& tensor_table, bool enforce_importable=false);
+TORCH_API void PythonPrint(std::ostream& out, const script::Module& module, std::vector<at::Tensor>& tensor_table, bool enforce_importable=false);
 
 TORCH_API bool printerHasSpecialCaseFor(c10::Symbol sym);
 }}
index edded9c..05c1f5f 100644 (file)
@@ -21,6 +21,7 @@ using c10::Type;
 
 std::string getPythonName(const PyObject* obj_) {
   AutoGIL gil;
+  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
   PyObject* obj = const_cast<PyObject*>(obj_);
   auto v = py::getattr(obj, "__name__", py::str("<python_value>"));
   // if this was a autograd.Function recover the name of the class
@@ -29,6 +30,7 @@ std::string getPythonName(const PyObject* obj_) {
 
 std::ostream& printPyObject(std::ostream & out, const THPObjectPtr& obj) {
   AutoGIL gil;
+  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
   auto pyobj = py::handle(const_cast<PyObject*>(obj.get()));
   if (py::isinstance<py::tuple>(pyobj)) {
     // This special-case for printing tuples handles a problem where
@@ -70,7 +72,7 @@ std::ostream& printPyObject(std::ostream & out, const THPObjectPtr& obj) {
 struct ConcretePythonOp : public PythonOp {
  ConcretePythonOp(Graph * graph)
  : PythonOp(graph) {}
virtual std::string name() const override {
+ std::string name() const override {
    AutoGIL gil;
    if(auto autograd = autogradFunction()) {
      return getPythonName(autograd->get());
@@ -89,14 +91,15 @@ struct ConcretePythonOp : public PythonOp {
      this->scalar_args.emplace_back(sa.get());
    }
  }
virtual Node * allocNewInstance(Graph * g) override {
+ Node * allocNewInstance(Graph * g) override {
    return new ConcretePythonOp(g);
  }
  // recover the autograd.Function instance, if this PythonOp's function
  // was originally SomeFunction.apply
  // used in ONNX for discovering symbolics
virtual c10::optional<THPObjectPtr> autogradFunction() const override {
+ c10::optional<THPObjectPtr> autogradFunction() const override {
    AutoGIL gil;
+   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
    py::handle obj = const_cast<PyObject*>(pyobj.get());
 
    auto r = py::getattr(obj, "__self__", py::none());
@@ -116,7 +119,7 @@ struct ConcretePythonOp : public PythonOp {
    return THPObjectPtr(r.release().ptr());
  }
 
- virtual void writeScalars(std::ostream& out) const override {
+ void writeScalars(std::ostream& out) const override {
    out << "(";
    int i = 0;
    for (auto& scalar : scalar_args) {
@@ -150,12 +153,12 @@ void initPythonIRBindings(PyObject * module_) {
       setInputTypes(*g, ArgumentSpec(with_grad, fmap<IValue>(inputs), inputs.size()));
       PropagateInputShapes(g);
     })
-    .def("export", [](const std::shared_ptr<Graph> g, const std::vector<at::Tensor>& initializers,
+    .def("_export_onnx", [](const std::shared_ptr<Graph> g, const std::vector<at::Tensor>& initializers,
                       int64_t onnx_opset_version, bool defer_weight_export,
                       ::torch::onnx::OperatorExportTypes operator_export_type) {
       std::string graph;
       RawDataExportMap export_map;
-      std::tie(graph, export_map) = ExportGraph(
+      std::tie(graph, export_map) = export_onnx(
         g, initializers, onnx_opset_version, defer_weight_export, operator_export_type);
       std::unordered_map<std::string, py::bytes> python_serialized_export_map;
       for (auto& kv : export_map) {
@@ -171,12 +174,12 @@ void initPythonIRBindings(PyObject * module_) {
        py::arg("onnx_opset_version")=0,
        py::arg("defer_weight_export")=false,
        py::arg("operator_export_type")=::torch::onnx::OperatorExportTypes::ONNX)
-    .def("prettyPrintExport", [](const std::shared_ptr<Graph> g,
+    .def("_pretty_print_onnx", [](const std::shared_ptr<Graph> g,
           const std::vector<at::Tensor>& initializers,
           int64_t onnx_opset_version, bool defer_weight_export,
           ::torch::onnx::OperatorExportTypes operator_export_type,
           bool google_printer) {
-      return PrettyPrintExportedGraph(
+      return pretty_print_onnx(
         g, initializers, onnx_opset_version, defer_weight_export, operator_export_type,
         google_printer);
     }, py::arg("initializers"),
@@ -388,6 +391,7 @@ void initPythonIRBindings(PyObject * module_) {
         return n.t(Symbol::attr(name));
     })
     .def("zs_",[](Node & n, const char * name, TensorsAttr::ValueType v) {
+        // NOLINTNEXTLINE(modernize-loop-convert)
         for (size_t i = 0; i < v.size(); ++ i) {
             v[i] = autograd::Variable(v[i].view({})).data();
         }
index 870cb5e..0713a07 100644 (file)
@@ -296,7 +296,12 @@ struct Environment {
 
   void setSugaredVar(const SourceRange& loc, const std::string& name, SugaredValuePtr value) {
     Value* as_simple_value = asSimple(value);
-    if (as_simple_value && !as_simple_value->hasUniqueName() && meaningfulName(name)) {
+    if (as_simple_value && !as_simple_value->hasUniqueName() &&
+        meaningfulName(name) &&
+        // note: if the value wasn't defined in this block, we might be giving a name
+        // only used inside this block to a value outside of this. this is not
+        // normally helpful for debugging and causes import/export jitter.
+        as_simple_value->node()->owningBlock() == block()) {
       as_simple_value->setUniqueName(name);
     }
     // prevent re-assignment involving any sugared values
index 8306103..12331f2 100644 (file)
@@ -662,7 +662,8 @@ void initJitScriptBindings(PyObject* module) {
       })
       .def("_python_print", [](Module& self) {
         std::ostringstream ss;
-        std::vector<at::Tensor> tensors = PythonPrint(ss, self, true);
+        std::vector<at::Tensor> tensors;
+        PythonPrint(ss, self, tensors, true);
         return std::make_pair(ss.str(), tensors);
       });
 
@@ -690,7 +691,8 @@ void initJitScriptBindings(PyObject* module) {
     .def("pretty_print_schema", &Method::pretty_print_schema)
     .def("python_print", [](Method &m) {
       std::ostringstream oss;
-      std::vector<at::Tensor> constants = PythonPrint(oss, m, true);
+      std::vector<at::Tensor> constants;
+      PythonPrint(oss, m, constants, true);
       return std::make_pair(oss.str(), std::move(constants));
     });
 
index 3034e32..d18a743 100644 (file)
@@ -265,7 +265,7 @@ def _export_to_pretty_string(model, args, f, export_params=True, verbose=False,
                                                example_outputs, propagate)
 
     from torch.onnx.symbolic import _onnx_opset_version
-    return graph.prettyPrintExport(params, _onnx_opset_version, False, operator_export_type, google_printer)
+    return graph._pretty_print_onnx(params, _onnx_opset_version, False, operator_export_type, google_printer)
 
 
 # NOTE: the output `torch_out` will contain the output tensors resulting from
@@ -284,9 +284,9 @@ def _export(model, args, f, export_params=True, verbose=False, training=False,
     from torch.onnx.symbolic import _onnx_opset_version
     defer_weight_export = export_type is not ExportTypes.PROTOBUF_FILE
     if export_params:
-        proto, export_map = graph.export(params, _onnx_opset_version, defer_weight_export, operator_export_type)
+        proto, export_map = graph._export_onnx(params, _onnx_opset_version, defer_weight_export, operator_export_type)
     else:
-        proto, export_map = graph.export([], _onnx_opset_version, False, operator_export_type)
+        proto, export_map = graph._export_onnx([], _onnx_opset_version, False, operator_export_type)
 
     if export_type == ExportTypes.PROTOBUF_FILE:
         assert(len(export_map) == 0)