Specialize optional tensor inputs to graphs in the JIT (#18360)
author    Thomas Viehmann <tv.code@beamnet.de>
Sun, 24 Mar 2019 05:54:36 +0000 (22:54 -0700)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sun, 24 Mar 2019 06:00:37 +0000 (23:00 -0700)
Summary:
This specializes optional tensor inputs to either a DimensionedTensorType or, when None is passed,
UndefinedTensor (aka AutogradZeroTensorType).
This works because we already have different specs, and thus separate execution plans, for the two cases.
It also enhances shape analysis: unwrapped optional tensors now get a DimensionedTensorType with the appropriate shape, requires_grad, etc.
Combined with "if-pruning" (which, as I understand it, #18259 works towards), this also produces much nicer concrete graphs.
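As an illustration of the user-visible effect (this mirrors the test added below), calling a scripted function that takes an Optional[Tensor] with None versus with a real tensor now yields two separately specialized graphs:

    import torch

    @torch.jit.script
    def fn(x):
        # type: (Optional[Tensor]) -> int
        if x is None:
            return 1
        else:
            return 0

    fn(None)
    print(fn.graph_for(None))   # input specialized to UndefinedTensor
    t = torch.ones(1)
    fn(t)
    print(fn.graph_for(t))      # input specialized to a DimensionedTensorType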
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18360

Differential Revision: D14590577

Pulled By: soumith

fbshipit-source-id: cac204a506d1d38b15703cbcc67a6b75fd4979f4

test/test_jit.py
torch/csrc/jit/argument_spec.h
torch/csrc/jit/passes/shape_analysis.cpp

index 8145c88..17965a2 100644 (file)
@@ -4986,6 +4986,23 @@ a")
                 if y is not None and x_none:
                     print(x + y)  # noqa: T484
 
+    def test_optional_tensor(self):
+        @torch.jit.script
+        def fn(x):
+            # type: (Optional[Tensor]) -> int
+            if x is None:
+                return 1
+            else:
+                return 0
+
+        fn(None)
+        g = fn.graph_for(None)
+        self.assertEqual(list(g.inputs())[0].type().str(), 'UndefinedTensor')
+        t = torch.ones(1)
+        fn(t)
+        g = fn.graph_for(t)
+        self.assertEqual(list(g.inputs())[0].type().kind(), 'DimensionedTensorType')
+
     def test_while_write_outer_then_read(self):
         def func(a, b):
             while bool(a < 10):
index a345c6b..74e9f07 100644 (file)
@@ -153,7 +153,8 @@ struct ArgumentSpec {
 
  private:
   TypePtr fillType(TypePtr original, size_t& offset) const {
-    if (original->isSubtypeOf(TensorType::get())) {
+    if (original->isSubtypeOf(TensorType::get())
+       || original->isSubtypeOf(OptionalType::ofTensor())) {
       auto& arg = args.at(offset++);
       if (!arg.defined())
         return AutogradZeroTensorType::get();
index 8ac159b..2ada059 100644 (file)
@@ -513,6 +513,16 @@ class ShapePropagator {
         }
         return;
       }
+      case prim::unchecked_unwrap_optional: {
+       // we know we cannot have None as input, so we can always pass
+       // on the type.
+       if(auto ot = node->input()->type()->cast<OptionalType>()) {
+         node->output()->setType(ot->getElementType());
+       } else {
+         node->output()->setType(node->input()->type());
+       }
+       return;
+      }
       case prim::ConstantChunk: {
         Value* tensor = node->input();
         if (auto type = tensor->type()->cast<DimensionedTensorType>()) {
@@ -529,10 +539,17 @@ class ShapePropagator {
         return;
       }
       case aten::_unwrap_optional: {
+       // if we have None as input, we need to leave the output alone
         auto input_ivalue = toIValue(node->input());
         if (input_ivalue && input_ivalue->isNone()) {
           return;
         }
+       if(auto ot = node->input()->type()->cast<OptionalType>()) {
+         node->output()->setType(ot->getElementType());
+       } else {
+         node->output()->setType(node->input()->type());
+       }
+       return;
       }
       default:
         break; // fall-through