From 2c302b6ea68eb9024f5f1b9cdc41e3bf39556fa2 Mon Sep 17 00:00:00 2001 From: Michael Suo Date: Thu, 21 Feb 2019 00:15:59 -0800 Subject: [PATCH] allow lists to contain any tensor type (#17321) Summary: If something is a TensorList, it should be a list of `TensorType`, not a list of some specialized type. Fixes #17140, #15642 Pull Request resolved: https://github.com/pytorch/pytorch/pull/17321 Differential Revision: D14158192 Pulled By: suo fbshipit-source-id: ba8fe6ae8d618c73b23cd00cbcb3111c390c5514 --- test/test_jit.py | 13 +++++++++++++ torch/csrc/jit/script/compiler.cpp | 9 +++++++++ 2 files changed, 22 insertions(+) diff --git a/test/test_jit.py b/test/test_jit.py index f8c4b45..76ba1d9 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -10313,6 +10313,19 @@ a") a_dict = {'a': torch.ones(1), 'b': torch.ones(1) + 1, 'c': torch.ones(1) + 2} self.checkScript(fn, (a_dict, ('a', 'c'))) + def test_tensor_import_export(self): + @torch.jit.script + def foo(x): + a = torch.tensor(1) + b = torch.tensor([1, 2]) + c = [a, b] + return c + + self.run_pass('constant_propagation', foo.graph) + m = torch.jit.ScriptModule() + m._create_method_from_graph("forward", foo.graph) + self.getExportImportCopy(m) + class MnistNet(nn.Module): def __init__(self): diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp index 35f0c23..40cfc41 100644 --- a/torch/csrc/jit/script/compiler.cpp +++ b/torch/csrc/jit/script/compiler.cpp @@ -2243,6 +2243,15 @@ struct to_ir { } else if (!values.empty()) { elem_type = values.at(0)->type(); } + + // Tensors are special because they have dymnamic properties. So any + // list containing tensors should be typed with the unified typeof all + // the elements. + if (elem_type->isSubtypeOf(TensorType::get())) { + for (const auto& value : values) { + elem_type = unifyTypes(elem_type, value->type()).value(); + } + } for (auto v : values) { if (!v->type()->isSubtypeOf(elem_type)) { throw ErrorReport(tree) -- 2.7.4