From 9e176fe5fe85dda2fa45700cbbeeb2d517f00b3a Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Sun, 24 Mar 2019 21:26:45 -0700 Subject: [PATCH] Revert "Specialize optional tensor inputs to graphs in the JIT (#18360)" (#18411) Summary: This reverts commit 7cc7ed1322405ba3c627b9c5661a330f92c4183d. I think it's better to sort out the issues raised in #18407 first. I'm sorry for not stopping it earlier. Pull Request resolved: https://github.com/pytorch/pytorch/pull/18411 Differential Revision: D14594937 Pulled By: soumith fbshipit-source-id: 3c90b7fa7694e2f59e55607acecde4a47af801ea --- test/test_jit.py | 17 ----------------- torch/csrc/jit/argument_spec.h | 3 +-- torch/csrc/jit/passes/shape_analysis.cpp | 17 ----------------- 3 files changed, 1 insertion(+), 36 deletions(-) diff --git a/test/test_jit.py b/test/test_jit.py index d5b644e..8606a6e 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -5012,23 +5012,6 @@ a") if y is not None and x_none: print(x + y) # noqa: T484 - def test_optional_tensor(self): - @torch.jit.script - def fn(x): - # type: (Optional[Tensor]) -> int - if x is None: - return 1 - else: - return 0 - - fn(None) - g = fn.graph_for(None) - self.assertEqual(list(g.inputs())[0].type().str(), 'UndefinedTensor') - t = torch.ones(1) - fn(t) - g = fn.graph_for(t) - self.assertEqual(list(g.inputs())[0].type().kind(), 'DimensionedTensorType') - def test_while_write_outer_then_read(self): def func(a, b): while bool(a < 10): diff --git a/torch/csrc/jit/argument_spec.h b/torch/csrc/jit/argument_spec.h index 74e9f07..a345c6b 100644 --- a/torch/csrc/jit/argument_spec.h +++ b/torch/csrc/jit/argument_spec.h @@ -153,8 +153,7 @@ struct ArgumentSpec { private: TypePtr fillType(TypePtr original, size_t& offset) const { - if (original->isSubtypeOf(TensorType::get()) - || original->isSubtypeOf(OptionalType::ofTensor())) { + if (original->isSubtypeOf(TensorType::get())) { auto& arg = args.at(offset++); if (!arg.defined()) return AutogradZeroTensorType::get(); diff
--git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp index 2ada059..8ac159b 100644 --- a/torch/csrc/jit/passes/shape_analysis.cpp +++ b/torch/csrc/jit/passes/shape_analysis.cpp @@ -513,16 +513,6 @@ class ShapePropagator { } return; } - case prim::unchecked_unwrap_optional: { - // we know we cannot have None as input, so we can always pass - // on the type. - if(auto ot = node->input()->type()->cast()) { - node->output()->setType(ot->getElementType()); - } else { - node->output()->setType(node->input()->type()); - } - return; - } case prim::ConstantChunk: { Value* tensor = node->input(); if (auto type = tensor->type()->cast()) { @@ -539,17 +529,10 @@ class ShapePropagator { return; } case aten::_unwrap_optional: { - // if we have None as input, we need to leave the output alone auto input_ivalue = toIValue(node->input()); if (input_ivalue && input_ivalue->isNone()) { return; } - if(auto ot = node->input()->type()->cast()) { - node->output()->setType(ot->getElementType()); - } else { - node->output()->setType(node->input()->type()); - } - return; } default: break; // fall-through -- 2.7.4