allow lists to contain any tensor type (#17321)
authorMichael Suo <suo@fb.com>
Thu, 21 Feb 2019 08:15:59 +0000 (00:15 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 21 Feb 2019 08:18:50 +0000 (00:18 -0800)
Summary:
If something is a TensorList, it should be a list of `TensorType`, not a list of some specialized type.
Fixes #17140, #15642
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17321

Differential Revision: D14158192

Pulled By: suo

fbshipit-source-id: ba8fe6ae8d618c73b23cd00cbcb3111c390c5514

test/test_jit.py
torch/csrc/jit/script/compiler.cpp

index f8c4b45..76ba1d9 100644 (file)
@@ -10313,6 +10313,19 @@ a")
         a_dict = {'a': torch.ones(1), 'b': torch.ones(1) + 1, 'c': torch.ones(1) + 2}
         self.checkScript(fn, (a_dict, ('a', 'c')))
 
+    def test_tensor_import_export(self):
+        @torch.jit.script
+        def foo(x):
+            a = torch.tensor(1)
+            b = torch.tensor([1, 2])
+            c = [a, b]
+            return c
+
+        self.run_pass('constant_propagation', foo.graph)
+        m = torch.jit.ScriptModule()
+        m._create_method_from_graph("forward", foo.graph)
+        self.getExportImportCopy(m)
+
 
 class MnistNet(nn.Module):
     def __init__(self):
index 35f0c23..40cfc41 100644 (file)
@@ -2243,6 +2243,15 @@ struct to_ir {
         } else if (!values.empty()) {
           elem_type = values.at(0)->type();
         }
+
+        // Tensors are special because they have dymnamic properties. So any
+        // list containing tensors should be typed with the unified typeof all
+        // the elements.
+        if (elem_type->isSubtypeOf(TensorType::get())) {
+          for (const auto& value : values) {
+            elem_type = unifyTypes(elem_type, value->type()).value();
+          }
+        }
         for (auto v : values) {
           if (!v->type()->isSubtypeOf(elem_type)) {
             throw ErrorReport(tree)