From 71ab48ed3b99fa43b715cb2efa3e4b8da40e92cf Mon Sep 17 00:00:00 2001 From: Zeina Migeed Date: Thu, 19 Aug 2021 15:22:52 -0700 Subject: [PATCH] acc type inference (#63119) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63119 Test Plan: buck run mode/opt-clang caffe2/torch/fb/model_transform/experimental:fx_ir_lower_inline_cvr -- \ --action=lower_and_run \ --filename=inline_cvr_7x_dec_2020.model \ --print_glow_glog=True Reviewed By: jamesr66a, jfix71, ansley Differential Revision: D30235895 fbshipit-source-id: dab7f96e1799b99eeae0ee519cf0ddd636fddf2e --- torch/fx/experimental/graph_gradual_typechecker.py | 38 ++++++++++++++++++++-- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/torch/fx/experimental/graph_gradual_typechecker.py b/torch/fx/experimental/graph_gradual_typechecker.py index e3c1ce8..5ce53a7 100644 --- a/torch/fx/experimental/graph_gradual_typechecker.py +++ b/torch/fx/experimental/graph_gradual_typechecker.py @@ -63,8 +63,6 @@ def broadcast_types(t1, t2): (t1, t2) = TensorType(tuple(new_t1)), TensorType(tuple(new_t2)) - if not is_consistent(t1, t2): - raise TypeError return (t1, t2) else: @@ -521,7 +519,7 @@ class GraphTypeChecker: return n.type elif n.op == 'get_attr': - t = self.traced.get_parameter(n.target) + t = get_parameter(self.traced, n.target) # type: ignore[arg-type] if isinstance(t.data, torch.Tensor): n.type = TensorType(t.data.shape) return n.type @@ -705,3 +703,37 @@ class Refine: else: pass + + +def get_parameter(traced, target: str): + """ + Returns the parameter given by ``target`` if it exists, + otherwise throws an error. + + See the docstring for ``get_submodule`` for a more detailed + explanation of this method's functionality as well as how to + correctly specify ``target``. + + Args: + target: The fully-qualified string name of the Parameter + to look for. (See ``get_submodule`` for how to specify a + fully-qualified string.) + + Returns: + torch.nn.Parameter: The Parameter referenced by ``target`` + + Raises: + AttributeError: If the target string references an invalid + path or resolves to something that is not an + ``nn.Parameter`` + """ + module_path, _, param_name = target.rpartition(".") + + mod: torch.nn.Module = traced.get_submodule(module_path) + + if not hasattr(mod, param_name): + raise AttributeError(mod._get_name() + " has no attribute `" + param_name + "`") + + param: torch.nn.Parameter = getattr(mod, param_name) + + return param -- 2.7.4