Allow tracing with fork/wait (#15184)
authorJames Sun <jamessun@fb.com>
Tue, 18 Dec 2018 04:28:00 +0000 (20:28 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 18 Dec 2018 04:34:26 +0000 (20:34 -0800)
Summary:
There is still limitation on this: if a script module is somewhere
in the trace, the inputs/outputs can only be tensors or tuples of
tensors.

resolves #15052
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15184

Differential Revision: D13457691

Pulled By: highker

fbshipit-source-id: 8fe46afc41357a0eb8eadd83f687b31d074deb0e

aten/src/ATen/core/jit_type.h
test/test_jit.py
torch/csrc/jit/graph_executor.cpp
torch/csrc/jit/tracer.cpp
torch/csrc/jit/tracer.h

index 9971551..c057963 100644 (file)
@@ -532,6 +532,9 @@ struct CAFFE2_API FutureType : public SingleElementType<TypeKind::FutureType, Fu
     ss << "Future[" << getElementType()->python_str() << "]";
     return ss.str();
   }
+  TypePtr createWithContained(std::vector<TypePtr> contained_types) const override {
+    return create(contained_types.at(0));
+  }
 private:
   FutureType(TypePtr elem) : SingleElementType(elem) {}
 };
index 3e072de..86d3410 100644 (file)
@@ -11222,6 +11222,59 @@ class TestAsync(JitTestCase):
         self.assertEqual(y2, foo2(x1, x2))
         self.assertEqual(y3, foo3(x1, x2, x3))
 
+    def test_async_script_trace(self):
+        class Traced(nn.Module):
+            def __init__(self):
+                super(Traced, self).__init__()
+
+            def forward(self, x):
+                return tuple([torch.neg(x), x])
+
+        class Module(torch.jit.ScriptModule):
+            def __init__(self):
+                super(Module, self).__init__(False)
+                x = torch.rand(3, 3)
+                self.traced = torch.jit.trace(Traced(), (x), _force_outplace=True)
+
+            @torch.jit.script_method
+            def forward(self, x):
+                # type: (Tensor) -> Tuple[List[Tensor], Tuple[Tensor, Tensor], Tensor]
+                future1 = torch.jit._fork(self.traced, x)
+                future2 = torch.jit._fork(torch.neg, x)
+
+                tensor_tuple = torch.jit._wait(future1)
+                tensor_single = torch.jit._wait(future2)
+
+                tensor_list = []
+                tensor_list.append(tensor_tuple[0])
+                tensor_list.append(tensor_single)
+
+                # return a nested structure of tensors
+                return (tensor_list, tensor_tuple, tensor_tuple[1])
+
+        class Tuple(nn.Module):
+            def __init__(self):
+                super(Tuple, self).__init__()
+                self.module = Module()
+
+            def forward(self, x):
+                z = torch.neg(x)
+                y = self.module(x)
+                list = [z, y[0][0], y[0][1], y[1][0], y[1][1], y[2]]
+                return tuple(list)
+
+        x = torch.rand(3, 3)
+        module = torch.jit.trace(Tuple(), (x), _force_outplace=True)
+
+        # Make sure we have forks
+        self.assertGraphContainsExactly(module.graph, kind='prim::fork', num_kind_nodes=2)
+        # Make sure 1 ::neg is in the root graph and 2 ::negs are in the subgraphs
+        self.assertGraphContainsExactly(module.graph, kind='aten::neg', num_kind_nodes=1)
+        self.assertGraphContainsExactly(module.graph, kind='aten::neg', num_kind_nodes=3, consider_subgraphs=True)
+
+        y = torch.neg(x)
+        self.assertEqual(module(x), tuple([y, y, y, y, x, x]))
+
 
 for test in autograd_method_tests:
     add_autograd_test(*test)
index 07b3019..cbdf893 100644 (file)
@@ -513,7 +513,11 @@ private:
     // NB: we could just run the fallback in here and call it a day, but that would loose all
     // the control flow information we have in the graph. Thus, we run the fallback to
     // get the correct output values, but we will override the tracing states later.
-    getOrCompileFallback().run(stack);
+    {
+      // No need to trace a script module.
+      ResourceGuard guard(tracer::pauseTracing());
+      getOrCompileFallback().run(stack);
+    }
 
     // Traces always have types propagated through them, so we make sure to
     // also propagate types through the graph we are inserting here.
@@ -527,10 +531,7 @@ private:
 
     auto outputs = last(stack, num_outputs);
     for (size_t i = 0; i < outputs.size(); ++i) {
-      // We can't attach tracing states to scalars, so we have to skip them here
-      // TODO: Should we reinterpret them as scalar tensors instead?
-      if (!outputs[i].isTensor()) continue;
-      tracer::setValueTrace(outputs[i].toTensor(), output_values[i]);
+      tracer::setValueTrace(outputs[i], output_values[i]);
     }
   }
 
index 91b333c..f86ae6d 100644 (file)
@@ -37,6 +37,33 @@ thread_local std::shared_ptr<TracingState> tracing_state;
 
 } // namespace detail
 
+void setValueTrace(const IValue &v, Value *value) {
+  if (v.isTensor()) {
+    auto var = v.toTensor();
+    JIT_ASSERT(var.defined());
+    getTracingState()->value_map[var] = value;
+  } else if (v.isTensorList()) {
+    auto& outputs = v.toTensorList()->elements();
+    auto graph = getTracingState()->graph;
+    Node * unpack_node = graph->appendNode(graph->create(prim::ListUnpack, {value}, outputs.size()));
+    for (size_t i = 0; i < outputs.size(); ++i) {
+      setValueTrace(outputs[i], unpack_node->outputs()[i]);
+    }
+  } else if (v.isTuple()) {
+    auto& outputs = v.toTuple()->elements();
+    auto graph = getTracingState()->graph;
+    Node * unpack_node = graph->appendNode(graph->create(prim::TupleUnpack, {value}, outputs.size()));
+    for (size_t i = 0; i < outputs.size(); ++i) {
+      setValueTrace(outputs[i], unpack_node->outputs()[i]);
+    }
+  } else {
+    std::ostringstream os;
+    os << "Tracer cannot set value trace for type " << v.tagKind() << ". "
+       << "Supported types are tensor, tensor list, and tuple of tensors.";
+    throw std::runtime_error(os.str());
+  }
+}
+
 void addInputs(Node *n, const char * name, int64_t value) {
   using ArgumentStash = jit::tracer::ArgumentStash;
   if (ArgumentStash::hasValue(name)) {
index 34b285f..691a1d9 100644 (file)
@@ -32,16 +32,22 @@ TORCH_API void setRecordSourceLocation(void (*v)(Node*));
 // Having finished adding a new 'node' to the graph IR 'setValueTrace' associates
 // this node with an output variable, so that further operations involving this
 // variable know which node in the IR to reference.
-inline void setValueTrace(const Variable& var, Value *value) {
-  JIT_ASSERT(var.defined());
-  getTracingState()->value_map[var] = value;
-}
+TORCH_API void setValueTrace(const IValue& v, Value* value);
 
 inline void delValueTrace(const Variable& var) {
   JIT_ASSERT(var.defined());
   getTracingState()->value_map.erase(var);
 }
 
+inline std::function<void()> pauseTracing() {
+  std::shared_ptr<tracer::TracingState> state = getTracingState();
+  tracer::setTracingState(nullptr);
+
+  return [state]() {
+    tracer::setTracingState(state);
+  };
+}
+
 // Given a variable 'var', return the 'node' which represents the instruction
 // which computes the value of this variable in the IR.
 // Here, we interpret untraced variables as constants that are just embedded