Show warning in eager mode for empty containers (#62978)
authorKarol Sputo <35783431+karolsputo@users.noreply.github.com>
Thu, 12 Aug 2021 22:36:29 +0000 (15:36 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 12 Aug 2021 23:11:27 +0000 (16:11 -0700)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/54873

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62978

Reviewed By: navahgar

Differential Revision: D30278343

Pulled By: ansley

fbshipit-source-id: ebb19f7b8a10720f2612b99a2668d1ebbc1f2d16

test/jit/test_isinstance.py
torch/_jit_internal.py

index 17b60a8..93b2605 100644 (file)
@@ -2,6 +2,7 @@ import os
 import sys
 
 import torch
+import warnings
 from typing import List, Any, Dict, Tuple, Optional
 
 # Make the helper files in test/ importable
@@ -295,3 +296,17 @@ class TestIsinstance(JitTestCase):
 
         with self.assertRaisesRegex(RuntimeError, err_highlight):
             fn2(2)
+
+    def test_empty_container_throws_warning_in_eager(self):
+        def fn(x: Any):
+            torch.jit.isinstance(x, List[int])
+
+        with warnings.catch_warnings(record=True) as w:
+            x: List[int] = []
+            fn(x)
+            self.assertEqual(len(w), 1)
+
+        with warnings.catch_warnings(record=True) as w:
+            x: int = 2
+            fn(x)
+            self.assertEqual(len(w), 0)
index 37b1baf..bd7b616 100644 (file)
@@ -1132,12 +1132,22 @@ def check_args_exist(target_type) -> None:
         raise_error_container_parameter_missing("Optional")
 
 
+def check_empty_containers(obj) -> None:
+    if not obj:
+        warnings.warn("The inner type of a container is lost when "
+                      "calling torch.jit.isinstance in eager mode. For "
+                      "example, List[int] would become list and "
+                      "therefore falsely return True for List[float] or"
+                      " List[str].")
+
+
 # supports List/Dict/Tuple and Optional types
 # TODO support future
 def container_checker(obj, target_type) -> bool:
     origin_type = get_origin(target_type)
     check_args_exist(target_type)
     if origin_type is list or origin_type is List:
+        check_empty_containers(obj)
         if not isinstance(obj, list):
             return False
         arg_type = get_args(target_type)[0]
@@ -1151,6 +1161,7 @@ def container_checker(obj, target_type) -> bool:
                 return False
         return True
     elif origin_type is Dict or origin_type is dict:
+        check_empty_containers(obj)
         if not isinstance(obj, dict):
             return False
         key_type = get_args(target_type)[0]
@@ -1167,6 +1178,7 @@ def container_checker(obj, target_type) -> bool:
                 return False
         return True
     elif origin_type is Tuple or origin_type is tuple:
+        check_empty_containers(obj)
         if not isinstance(obj, tuple):
             return False
         arg_types = get_args(target_type)