if y is not None and x_none:
print(x + y) # noqa: T484
+ def test_optional_tensor(self):
+ # Check that a scripted function with an Optional[Tensor] argument
+ # specializes its graph input type to match the actual call argument:
+ # None vs. a real tensor produce different specialized input types.
+ @torch.jit.script
+ def fn(x):
+ # type: (Optional[Tensor]) -> int
+ if x is None:
+ return 1
+ else:
+ return 0
+
+ fn(None)
+ # Calling with None: the specialized graph's input should be typed
+ # as UndefinedTensor.
+ g = fn.graph_for(None)
+ self.assertEqual(list(g.inputs())[0].type().str(), 'UndefinedTensor')
+ t = torch.ones(1)
+ fn(t)
+ # Calling with a concrete tensor: the specialized graph's input
+ # should carry a DimensionedTensorType kind instead.
+ g = fn.graph_for(t)
+ self.assertEqual(list(g.inputs())[0].type().kind(), 'DimensionedTensorType')
+
def test_while_write_outer_then_read(self):
def func(a, b):
while bool(a < 10):
private:
TypePtr fillType(TypePtr original, size_t& offset) const {
- if (original->isSubtypeOf(TensorType::get())) {
+ if (original->isSubtypeOf(TensorType::get())
+ || original->isSubtypeOf(OptionalType::ofTensor())) {
auto& arg = args.at(offset++);
if (!arg.defined())
return AutogradZeroTensorType::get();
}
return;
}
+ case prim::unchecked_unwrap_optional: {
+ // we know we cannot have None as input, so we can always pass
+ // on the type.
+ if(auto ot = node->input()->type()->cast<OptionalType>()) {
+ node->output()->setType(ot->getElementType());
+ } else {
+ node->output()->setType(node->input()->type());
+ }
+ return;
+ }
case prim::ConstantChunk: {
Value* tensor = node->input();
if (auto type = tensor->type()->cast<DimensionedTensorType>()) {
return;
}
case aten::_unwrap_optional: {
+ // if we have None as input, we need to leave the output alone
auto input_ivalue = toIValue(node->input());
if (input_ivalue && input_ivalue->isNone()) {
return;
}
+ if(auto ot = node->input()->type()->cast<OptionalType>()) {
+ node->output()->setType(ot->getElementType());
+ } else {
+ node->output()->setType(node->input()->type());
+ }
+ return;
}
default:
break; // fall-through