From: Thomas Viehmann
Date: Sun, 24 Mar 2019 05:54:36 +0000 (-0700)
Subject: Specialize optional tensor inputs to graphs in the JIT (#18360)
X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~660
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=7cc7ed1322405ba3c627b9c5661a330f92c4183d;p=platform%2Fupstream%2Fpytorch.git

Specialize optional tensor inputs to graphs in the JIT (#18360)

Summary:
This specializes optional tensor inputs to either a DimensionedTensorType or, when None is passed, UndefinedTensor (aka AutogradZeroTensorType). This works because we already have different specs, and thus separate execution plans, for the two cases.

It also enhances shape analysis: unwrapped optional tensors now carry a DimensionedTensorType with the appropriate shape, requires_grad flag, etc.

Combined with "if-pruning" (which I understand #18259 works towards), this gives much nicer concrete graphs, too.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/18360

Differential Revision: D14590577

Pulled By: soumith

fbshipit-source-id: cac204a506d1d38b15703cbcc67a6b75fd4979f4
---

diff --git a/test/test_jit.py b/test/test_jit.py
index 8145c88..17965a2 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -4986,6 +4986,23 @@ a")
             if y is not None and x_none:
                 print(x + y)  # noqa: T484
 
+    def test_optional_tensor(self):
+        @torch.jit.script
+        def fn(x):
+            # type: (Optional[Tensor]) -> int
+            if x is None:
+                return 1
+            else:
+                return 0
+
+        fn(None)
+        g = fn.graph_for(None)
+        self.assertEqual(list(g.inputs())[0].type().str(), 'UndefinedTensor')
+        t = torch.ones(1)
+        fn(t)
+        g = fn.graph_for(t)
+        self.assertEqual(list(g.inputs())[0].type().kind(), 'DimensionedTensorType')
+
     def test_while_write_outer_then_read(self):
         def func(a, b):
             while bool(a < 10):
diff --git a/torch/csrc/jit/argument_spec.h b/torch/csrc/jit/argument_spec.h
index a345c6b..74e9f07 100644
--- a/torch/csrc/jit/argument_spec.h
+++ b/torch/csrc/jit/argument_spec.h
@@ -153,7 +153,8 @@ struct ArgumentSpec {
 
  private:
   TypePtr fillType(TypePtr original, size_t& offset) const {
-    if (original->isSubtypeOf(TensorType::get())) {
+    if (original->isSubtypeOf(TensorType::get())
+        || original->isSubtypeOf(OptionalType::ofTensor())) {
       auto& arg = args.at(offset++);
       if (!arg.defined())
         return AutogradZeroTensorType::get();
diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp
index 8ac159b..2ada059 100644
--- a/torch/csrc/jit/passes/shape_analysis.cpp
+++ b/torch/csrc/jit/passes/shape_analysis.cpp
@@ -513,6 +513,16 @@ class ShapePropagator {
         }
         return;
       }
+      case prim::unchecked_unwrap_optional: {
+        // we know we cannot have None as input, so we can always pass
+        // on the type.
+        if(auto ot = node->input()->type()->cast<OptionalType>()) {
+          node->output()->setType(ot->getElementType());
+        } else {
+          node->output()->setType(node->input()->type());
+        }
+        return;
+      }
       case prim::ConstantChunk: {
         Value* tensor = node->input();
         if (auto type = tensor->type()->cast<DimensionedTensorType>()) {
@@ -529,10 +539,17 @@ class ShapePropagator {
         return;
       }
       case aten::_unwrap_optional: {
+        // if we have None as input, we need to leave the output alone
         auto input_ivalue = toIValue(node->input());
         if (input_ivalue && input_ivalue->isNone()) {
           return;
         }
+        if(auto ot = node->input()->type()->cast<OptionalType>()) {
+          node->output()->setType(ot->getElementType());
+        } else {
+          node->output()->setType(node->input()->type());
+        }
+        return;
       }
       default:
         break; // fall-through
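
For reference, the behavior the summary describes (separate execution plans for the None and non-None cases of an Optional[Tensor] argument) can be reproduced outside the test suite with a short script. This is a minimal sketch assuming a PyTorch build from around this commit, where the pre-1.1 JIT exposes graph_for, DimensionedTensorType, and UndefinedTensor as the new test above does; later releases reworked these internals.

    import torch
    from typing import Optional

    @torch.jit.script
    def fn(x):
        # type: (Optional[Tensor]) -> int
        # Returns 1 when no tensor is passed, 0 otherwise.
        if x is None:
            return 1
        else:
            return 0

    # Calling with None selects a plan whose input is specialized to
    # UndefinedTensor (aka AutogradZeroTensorType).
    fn(None)
    print(list(fn.graph_for(None).inputs())[0].type().str())   # UndefinedTensor

    # Calling with a real tensor selects a separate plan whose input is a
    # DimensionedTensorType carrying shape and requires_grad information.
    t = torch.ones(1)
    fn(t)
    print(list(fn.graph_for(t).inputs())[0].type().kind())     # DimensionedTensorType

Because ArgumentSpec now specializes Optional[Tensor] the same way it specializes plain Tensor inputs, the two calls hash to different specs and each cached graph sees a concrete input type instead of Tensor?.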