acc type inference (#63119)
authorZeina Migeed <migeedz@fb.com>
Thu, 19 Aug 2021 22:22:52 +0000 (15:22 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 19 Aug 2021 22:23:56 +0000 (15:23 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63119

Test Plan:
buck run mode/opt-clang caffe2/torch/fb/model_transform/experimental:fx_ir_lower_inline_cvr -- \
    --action=lower_and_run \
    --filename=inline_cvr_7x_dec_2020.model \
    --print_glow_glog=True

Reviewed By: jamesr66a, jfix71, ansley

Differential Revision: D30235895

fbshipit-source-id: dab7f96e1799b99eeae0ee519cf0ddd636fddf2e

torch/fx/experimental/graph_gradual_typechecker.py

index e3c1ce8..5ce53a7 100644 (file)
@@ -63,8 +63,6 @@ def broadcast_types(t1, t2):
 
         (t1, t2) = TensorType(tuple(new_t1)), TensorType(tuple(new_t2))
 
-        if not is_consistent(t1, t2):
-            raise TypeError
 
         return (t1, t2)
     else:
@@ -521,7 +519,7 @@ class GraphTypeChecker:
             return n.type
 
         elif n.op == 'get_attr':
-            t = self.traced.get_parameter(n.target)
+            t = get_parameter(self.traced, n.target)  # type: ignore[arg-type]
             if isinstance(t.data, torch.Tensor):
                 n.type = TensorType(t.data.shape)
             return n.type
@@ -705,3 +703,37 @@ class Refine:
 
         else:
             pass
+
+
+def get_parameter(traced, target: str):
+    """
+    Returns the parameter given by ``target`` if it exists,
+    otherwise throws an error.
+
+    See the docstring for ``get_submodule`` for a more detailed
+    explanation of this method's functionality as well as how to
+    correctly specify ``target``.
+
+    Args:
+        target: The fully-qualified string name of the Parameter
+            to look for. (See ``get_submodule`` for how to specify a
+            fully-qualified string.)
+
+    Returns:
+        torch.nn.Parameter: The Parameter referenced by ``target``
+
+    Raises:
+        AttributeError: If the target string references an invalid
+            path or resolves to something that is not an
+            ``nn.Parameter``
+    """
+    module_path, _, param_name = target.rpartition(".")
+
+    mod: torch.nn.Module = traced.get_submodule(module_path)
+
+    if not hasattr(mod, param_name):
+        raise AttributeError(mod._get_name() + " has no attribute `" + param_name + "`")
+
+    param: torch.nn.Parameter = getattr(mod, param_name)
+
+    return param