Update graph printouts in JIT docs (#14914)
authorJames Reed <jamesreed@fb.com>
Fri, 7 Dec 2018 23:06:48 +0000 (15:06 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 7 Dec 2018 23:08:53 +0000 (15:08 -0800)
Summary:
Tracing records variable names and we have new types and stuff in the IR, so this updates the graph printouts in the docs
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14914

Differential Revision: D13385101

Pulled By: jamesr66a

fbshipit-source-id: 6477e4861f1ac916329853763c83ea157be77f23

docs/source/jit.rst

index 2baa500..eced843 100644 (file)
@@ -596,37 +596,35 @@ Interpreting Graphs
     graph of a method named ``bar`` on a ScriptModule by accessing ``.bar.graph``.
 
     The example script above produces the graph::
+       graph(%len : int) {
+         %15 : int = prim::Constant[value=1]()
+         %9 : bool = prim::Constant[value=1]()
+         %7 : Device = prim::Constant[value="cpu"]()
+         %6 : int = prim::Constant[value=0]()
+         %5 : int = prim::Constant[value=6]()
+         %1 : int = prim::Constant[value=3]()
+         %2 : int = prim::Constant[value=4]()
+         %11 : int = prim::Constant[value=10]()
+         %14 : float = prim::Constant[value=1]()
+         %4 : int[] = prim::ListConstruct(%1, %2)
+         %rv.1 : Tensor = aten::zeros(%4, %5, %6, %7)
+         %rv : Tensor = prim::Loop(%len, %9, %rv.1)
+           block0(%i : int, %13 : Tensor) {
+             %12 : bool = aten::lt(%i, %11)
+             %rv.4 : Tensor = prim::If(%12)
+               block0() {
+                 %rv.2 : Tensor = aten::sub(%13, %14, %15)
+                 -> (%rv.2)
+               }
+               block1() {
+                 %rv.3 : Tensor = aten::add(%13, %14, %15)
+                 -> (%rv.3)
+               }
+             -> (%9, %rv.4)
+           }
+         return (%rv);
+       }
 
-        graph(%len : int) {
-          %13 : float = prim::Constant[value=1]()
-          %10 : int = prim::Constant[value=10]()
-          %2 : int = prim::Constant[value=4]()
-          %1 : int = prim::Constant[value=3]()
-          %3 : int[] = prim::ListConstruct(%1, %2)
-          %4 : int = prim::Constant[value=6]()
-          %5 : int = prim::Constant[value=0]()
-          %6 : int[] = prim::Constant[value=[0, -1]]()
-          %rv.1 : Dynamic = aten::zeros(%3, %4, %5, %6)
-          %8 : int = prim::Constant[value=1]()
-          %rv : Dynamic = prim::Loop(%len, %8, %rv.1)
-            block0(%i : int, %12 : Dynamic) {
-              %11 : int = aten::lt(%i, %10)
-              %rv.4 : Dynamic = prim::If(%11)
-                block0() {
-                  %14 : int = prim::Constant[value=1]()
-                  %rv.2 : Dynamic = aten::sub(%12, %13, %14)
-                  -> (%rv.2)
-                }
-                block1() {
-                  %16 : int = prim::Constant[value=1]()
-                  %rv.3 : Dynamic = aten::add(%12, %13, %16)
-                  -> (%rv.3)
-                }
-              %19 : int = prim::Constant[value=1]()
-              -> (%19, %rv.4)
-            }
-          return (%rv);
-        }
 
     Take the instruction ``%rv.1 : Dynamic = aten::zeros(%3, %4, %5, %6)`` for
     example. ``%rv.1 : Dynamic`` means we assign the output to a (unique)
@@ -676,34 +674,38 @@ Automatic Trace Checking
         traced = torch.jit.trace(loop_in_traced_fn, inputs, check_inputs=check_inputs)
 
     Gives us the following diagnostic information::
-
-        ERROR: Graphs differed across invocations!
-        Graph diff:
-            graph(%0 : Dynamic) {
-                  %1 : int = prim::Constant[value=0]()
-                  %2 : int = prim::Constant[value=0]()
-                  %3 : Dynamic = aten::select(%0, %1, %2)
-                  %4 : int = prim::Constant[value=0]()
-                  %5 : int = prim::Constant[value=0]()
-                  %6 : Dynamic = aten::select(%0, %4, %5)
-                  %7 : Dynamic = aten::mul(%3, %6)
-                  %8 : int = prim::Constant[value=0]()
-                  %9 : int = prim::Constant[value=1]()
-                  %10 : Dynamic = aten::select(%0, %8, %9)
-                  %11 : Dynamic = aten::mul(%7, %10)
-                  %12 : int = prim::Constant[value=0]()
-                  %13 : int = prim::Constant[value=2]()
-                  %14 : Dynamic = aten::select(%0, %12, %13)
-                  %15 : Dynamic = aten::mul(%11, %14)
-              +   %16 : int = prim::Constant[value=0]()
-              +   %17 : int = prim::Constant[value=3]()
-              +   %18 : Dynamic = aten::select(%0, %16, %17)
-              +   %19 : Dynamic = aten::mul(%15, %18)
-              -   return (%15);
-              ?             ^
-              +   return (%19);
-              ?             ^
-            }
+       ERROR: Graphs differed across invocations!
+       Graph diff:
+                 graph(%x : Tensor) {
+                   %1 : int = prim::Constant[value=0]()
+                   %2 : int = prim::Constant[value=0]()
+                   %result.1 : Tensor = aten::select(%x, %1, %2)
+                   %4 : int = prim::Constant[value=0]()
+                   %5 : int = prim::Constant[value=0]()
+                   %6 : Tensor = aten::select(%x, %4, %5)
+                   %result.2 : Tensor = aten::mul(%result.1, %6)
+                   %8 : int = prim::Constant[value=0]()
+                   %9 : int = prim::Constant[value=1]()
+                   %10 : Tensor = aten::select(%x, %8, %9)
+               -   %result : Tensor = aten::mul(%result.2, %10)
+               +   %result.3 : Tensor = aten::mul(%result.2, %10)
+               ?          ++
+                   %12 : int = prim::Constant[value=0]()
+                   %13 : int = prim::Constant[value=2]()
+                   %14 : Tensor = aten::select(%x, %12, %13)
+               +   %result : Tensor = aten::mul(%result.3, %14)
+               +   %16 : int = prim::Constant[value=0]()
+               +   %17 : int = prim::Constant[value=3]()
+               +   %18 : Tensor = aten::select(%x, %16, %17)
+               -   %15 : Tensor = aten::mul(%result, %14)
+               ?     ^                                 ^
+               +   %19 : Tensor = aten::mul(%result, %18)
+               ?     ^                                 ^
+               -   return (%15);
+               ?             ^
+               +   return (%19);
+               ?             ^
+                 }
 
 
     This message indicates to us that the computation differed between when
@@ -733,23 +735,19 @@ Automatic Trace Checking
 
     Which produces::
 
-        graph(%x : Dynamic) {
-          %1 : int = prim::Constant[value=0]()
-          %2 : int = prim::Constant[value=0]()
-          %result.1 : Dynamic = aten::select(%x, %2, %1)
-          %4 : int = aten::size(%x, %1)
-          %5 : int = prim::Constant[value=1]()
-          %result : Dynamic = prim::Loop(%4, %5, %result.1)
-            block0(%i : int, %7 : Dynamic) {
-              %9 : int = prim::Constant[value=0]()
-              %10 : Dynamic = aten::select(%x, %9, %i)
-              %result.2 : Dynamic = aten::mul(%7, %10)
-              %12 : int = prim::Constant[value=1]()
-              -> (%12, %result.2)
-            }
-          return (%result);
-        }
-
+       graph(%x : Tensor) {
+         %5 : bool = prim::Constant[value=1]()
+         %1 : int = prim::Constant[value=0]()
+         %result.1 : Tensor = aten::select(%x, %1, %1)
+         %4 : int = aten::size(%x, %1)
+         %result : Tensor = prim::Loop(%4, %5, %result.1)
+           block0(%i : int, %7 : Tensor) {
+             %10 : Tensor = aten::select(%x, %1, %i)
+             %result.2 : Tensor = aten::mul(%7, %10)
+             -> (%5, %result.2)
+           }
+         return (%result);
+       }
 
 Tracer Warnings
     The tracer produces warnings for several problematic patterns in traced