Add equality constraints for some acc opeartions for symbolic inference (#63689)
authorZeina Migeed <migeedz@fb.com>
Mon, 23 Aug 2021 21:09:10 +0000 (14:09 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Mon, 23 Aug 2021 21:11:08 +0000 (14:11 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63689

Test Plan:
buck run mode/opt-clang caffe2/torch/fb/model_transform/experimental:fx_ir_lower_inline_cvr -- \
    --action=lower_and_run \
    --filename=inline_cvr_7x_dec_2020.model \
    --print_glow_glog=True

Reviewed By: jamesr66a

Differential Revision: D30462113

fbshipit-source-id: 0b2a1ce9770561248527d47c07b80112491dc949

torch/fx/experimental/graph_gradual_typechecker.py
torch/fx/experimental/unify_refinements.py

index 5ce53a7..6e05f91 100644 (file)
@@ -436,7 +436,7 @@ def adaptiveavgpool2d_inference_rule(n: Node, module_instance):
 def flatten_check(tensor_type, start_dim, end_dim):
     l = len(tensor_type.__args__)
 
-    start_dim = l if start_dim == -1 else start_dim
+    start_dim = l if start_dim == -1 else abs(start_dim)
     end_dim = l + end_dim + 1 if end_dim < 0 else end_dim + 1
 
     if 0 <= start_dim <= (l - 1) and 0 <= end_dim <= l and start_dim < end_dim:
@@ -668,6 +668,10 @@ class Refine:
         elif isinstance(typ, TensorType):
             new_args = [self.replace_dyn_with_fresh_var(a) for a in typ.__args__]
             return TensorType(tuple(new_args))
+        elif isinstance(typ, list):
+            return [self.replace_dyn_with_fresh_var(t) for t in typ]
+        elif isinstance(typ, tuple):
+            return (self.replace_dyn_with_fresh_var(t) for t in typ)
         else:
             return typ
 
@@ -698,8 +702,10 @@ class Refine:
                 pass
 
         if n.op == 'output':
-            assert isinstance(n.args[0], Node)
-            n.type = n.args[0].type
+            def get_node_type(a):
+                return a.type
+            n.type = torch.fx.node.map_arg(n.args[0], get_node_type)
+            return n.type
 
         else:
             pass
index c856104..5074377 100644 (file)
@@ -52,6 +52,8 @@ def substitute_solution_one_type(mapping, t):
     if isinstance(t, Var):
         if t in mapping.keys():
             return mapping[t]
+        else:
+            return t
 
     elif isinstance(t, TensorType):
         new_type = []
@@ -62,6 +64,21 @@ def substitute_solution_one_type(mapping, t):
                 new_type.append(typ)
         return TensorType(tuple(new_type))
 
+    elif isinstance(t, list):
+        new_type = []
+        for typ in t:
+            new_type.append(substitute_solution_one_type(mapping, typ))
+        return new_type
+
+    elif isinstance(t, tuple):
+        new_type = []
+        for typ in t:
+            new_type.append(substitute_solution_one_type(mapping, typ))
+        return tuple(new_type)
+
+    else:
+        return t
+
 
 def substitute_all_types(graph, mapping):
     """