inference for algebraic expressions (#63822)
authorZeina Migeed <migeedz@fb.com>
Thu, 26 Aug 2021 03:42:14 +0000 (20:42 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 26 Aug 2021 03:47:23 +0000 (20:47 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63822

Infer algebraic expressions and add them to our symbolic inferencer. This works for Conv2d and can be extended to other operations.
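
For example, the new pass gives a Conv2d applied to a partially dynamic input a
closed-form sympy output shape. A minimal sketch, distilled from the new tests
(the module M below is illustrative only):

    import torch
    from torch.fx import symbolic_trace
    from torch.fx.tensor_type import TensorType, Dyn
    from torch.fx.experimental.graph_gradual_typechecker import GraphTypeChecker
    from torch.fx.experimental.unify_refinements import infer_symbolic_types

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 6, 5)

        def forward(self, x: TensorType((4, 3, Dyn, Dyn))):
            return self.conv(x)

    traced = symbolic_trace(M())
    tc = GraphTypeChecker({}, traced)
    tc.type_check()
    infer_symbolic_types(traced)
    # The conv node is now typed TensorType((4, 6, floor(~0 - 4), floor(~1 - 4))),
    # where ~0 and ~1 are sympy symbols standing for the dynamic input dims.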

Test Plan: Imported from OSS

Reviewed By: jamesr66a

Differential Revision: D30518469

Pulled By: migeed-z

fbshipit-source-id: b92dfa40b2d834a535177da42b851701b8f7178c

test/fx/test_gradual_type.py
torch/fx/experimental/graph_gradual_typechecker.py
torch/fx/experimental/unify_refinements.py

diff --git a/test/fx/test_gradual_type.py b/test/fx/test_gradual_type.py
index 203cf6b..37e8db1 100644
@@ -9,7 +9,14 @@ from torch.fx.experimental.graph_gradual_typechecker import GraphTypeChecker, br
 from torch.fx.experimental.rewriter import RewritingTracer
 from torch.fx import GraphModule
 from torch.fx.passes.shape_prop import ShapeProp
-from torch.fx.experimental.unification import Var
+
+try:
+    import sympy
+    HAS_SYMPY = True
+except ImportError:
+    HAS_SYMPY = False
+skipIfNoSympy = unittest.skipIf(not HAS_SYMPY, "no sympy")
+
 
 try:
     from torchvision.models import resnet50
@@ -19,13 +26,6 @@ except ImportError:
     HAS_TORCHVISION = False
 skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
 
-# try:
-#     from unification import Var
-#     HAS_UNIFICATION = True
-# except ImportError:
-#     HAS_UNIFICATION = False
-# skipIfNoUnification = unittest.skipIf(not HAS_UNIFICATION, "no unification")
-
 def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
     """3x3 convolution with padding"""
     return torch.nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
@@ -270,10 +270,9 @@ class TypeCheckerTest(unittest.TestCase):
     def test_type_check_batch_norm_2D(self):
         class BasicBlock(torch.nn.Module):
 
-            def __init__(self, inplanes, planes, norm_layer=None):
+            def __init__(self, inplanes, planes):
                 super(BasicBlock, self).__init__()
-                if norm_layer is None:
-                    norm_layer = torch.nn.BatchNorm2d
+                norm_layer = torch.nn.BatchNorm2d
                 self.bn1 = norm_layer(planes)
 
             def forward(self, x: TensorType((2, 2, 5, 4))):
@@ -302,10 +301,9 @@ class TypeCheckerTest(unittest.TestCase):
     def test_type_check_batch_norm_2D_false(self):
         class BasicBlock(torch.nn.Module):
 
-            def __init__(self, inplanes, planes, norm_layer=None):
+            def __init__(self, inplanes, planes):
                 super(BasicBlock, self).__init__()
-                if norm_layer is None:
-                    norm_layer = torch.nn.BatchNorm2d
+                norm_layer = torch.nn.BatchNorm2d
                 self.bn1 = norm_layer(planes)
 
             def forward(self, x: TensorType((2, 2, 5))):
@@ -325,10 +323,9 @@ class TypeCheckerTest(unittest.TestCase):
     def test_type_check_batch_norm_2D_broadcast(self):
         class BasicBlock(torch.nn.Module):
 
-            def __init__(self, inplanes, planes, norm_layer=None):
+            def __init__(self, inplanes, planes):
                 super(BasicBlock, self).__init__()
-                if norm_layer is None:
-                    norm_layer = torch.nn.BatchNorm2d
+                norm_layer = torch.nn.BatchNorm2d
                 self.bn1 = norm_layer(planes)
 
             def forward(self, x: Dyn):
@@ -363,10 +360,9 @@ class TypeCheckerTest(unittest.TestCase):
 
     def test_type_check_conv2D(self):
         class BasicBlock(torch.nn.Module):
-            def __init__(self, inplanes, planes, stride=1, norm_layer=None):
+            def __init__(self, inplanes, planes, stride=1):
                 super(BasicBlock, self).__init__()
-                if norm_layer is None:
-                    norm_layer = torch.nn.BatchNorm2d
+                norm_layer = torch.nn.BatchNorm2d
                 self.conv1 = conv3x3(inplanes, planes, stride)
                 self.bn1 = norm_layer(planes)
 
@@ -394,10 +390,9 @@ class TypeCheckerTest(unittest.TestCase):
 
     def test_type_check_conv2D_2(self):
         class BasicBlock(torch.nn.Module):
-            def __init__(self, inplanes, planes, stride=1, norm_layer=None):
+            def __init__(self, inplanes, planes, stride=1):
                 super(BasicBlock, self).__init__()
-                if norm_layer is None:
-                    norm_layer = torch.nn.BatchNorm2d
+                norm_layer = torch.nn.BatchNorm2d
                 self.conv1 = conv3x3(inplanes, planes, stride)
                 self.bn1 = norm_layer(planes)
 
@@ -434,7 +429,6 @@ class TypeCheckerTest(unittest.TestCase):
         with self.assertRaises(TypeError):
             tc.type_check()
 
-
     def test_type_check_conv2D_2_fully_static(self):
         annotation_list = [(1, 2, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14),
                            (10, Dyn, 13, 14), (Dyn, Dyn, Dyn, 3)]
@@ -522,16 +516,14 @@ class TypeCheckerTest(unittest.TestCase):
                     assert n.type == TensorType(output_types[i])
                     assert is_consistent(n.type, TensorType(b.size()))
 
-
     def test_typecheck_basicblock(self):
         class BasicBlock(torch.nn.Module):
             expansion = 1
 
             def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
-                         base_width=64, dilation=1, norm_layer=None):
+                         base_width=64, dilation=1):
                 super(BasicBlock, self).__init__()
-                if norm_layer is None:
-                    norm_layer = torch.nn.BatchNorm2d
+                norm_layer = torch.nn.BatchNorm2d
                 if groups != 1 or base_width != 64:
                     raise ValueError('BasicBlock only supports groups=1 and base_width=64')
                 if dilation > 1:
@@ -643,7 +635,6 @@ class TypeCheckerTest(unittest.TestCase):
             if n.op == 'output':
                 assert n.type == TensorType((1, Dyn, 5, Dyn))
 
-
     def test_type_check_flatten3(self):
         class M(torch.nn.Module):
             def forward(self, x: TensorType((2, 3, 4, 5))):
@@ -661,7 +652,6 @@ class TypeCheckerTest(unittest.TestCase):
         c = r.constraints
         assert c == [Equality(2, 2)]
 
-
     def test_type_typecheck_maxpool2d_3dinput(self):
 
         class BasicBlock(torch.nn.Module):
@@ -770,7 +760,6 @@ class TypeCheckerTest(unittest.TestCase):
                     assert n.type == TensorType(output_types[i])
                     assert is_consistent(n.type, TensorType(b.size()))
 
-
     def test_flatten_fully_static(self):
         annotation_list = [Dyn, TensorType((2, 5, 6, 9)), TensorType((10, 15, 13, 14)),
                            TensorType((10, Dyn, 13, 14)), TensorType((Dyn, Dyn, Dyn, 10))]
@@ -816,6 +805,7 @@ class TypeCheckerTest(unittest.TestCase):
                 if n.op == 'output':
                     assert is_consistent(n.type, TensorType(b.size()))
 
+    @skipIfNoSympy
     @skipIfNoTorchVision
     def test_resnet50(self):
         gm_run = symbolic_trace(resnet50())
@@ -859,14 +849,13 @@ class TypeCheckerTest(unittest.TestCase):
             batch_sizes.add(n.type.__args__[0])
         assert (len(batch_sizes) == 1)
 
-
+    @skipIfNoSympy
     def test_type_check_batch_norm_symbolic(self):
         class BasicBlock(torch.nn.Module):
 
-            def __init__(self, inplanes, planes, norm_layer=None):
+            def __init__(self, inplanes, planes):
                 super(BasicBlock, self).__init__()
-                if norm_layer is None:
-                    norm_layer = torch.nn.BatchNorm2d
+                norm_layer = torch.nn.BatchNorm2d
                 self.bn1 = norm_layer(planes)
 
             def forward(self, x: Dyn):
@@ -884,15 +873,15 @@ class TypeCheckerTest(unittest.TestCase):
 
         infer_symbolic_types(traced)
 
-
-        my_types = iter([TensorType[(2, 2, Var(7), 4)],
-                         TensorType[(2, 2, Var(7), 4)],
-                         TensorType[(2, 2, Var(7), 4)],
-                         TensorType[(2, 2, Var(7), 4)]])
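+        # '~7' is the string form of the unification variable Var(7) generated
+        # during refinement; the sympy symbol simply reuses that name.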
+        my_types = iter([TensorType[(2, 2, sympy.symbols('~7'), 4)],
+                         TensorType[(2, 2, sympy.symbols('~7'), 4)],
+                         TensorType[(2, 2, sympy.symbols('~7'), 4)],
+                         TensorType[(2, 2, sympy.symbols('~7'), 4)]])
 
         for n in graph.nodes:
             assert n.type == next(my_types)
 
+    @skipIfNoSympy
     def test_symbolic_add_with_broadcast(self):
         class M(torch.nn.Module):
             def forward(self, x: TensorType((1, 2, 3, Dyn)), y: TensorType((2, 3, 4))):
@@ -911,16 +900,17 @@ class TypeCheckerTest(unittest.TestCase):
 
         infer_symbolic_types(symbolic_traced)
 
-        expected_ph_types = [TensorType((1, 2, 3, Var(0))),
+        expected_ph_types = [TensorType((1, 2, 3, sympy.symbols('~0'))),
                              TensorType((2, 3, 4)),
-                             TensorType((1, 2, 3, Var(1))),
-                             TensorType((1, 2, 3, Var(1)))]
+                             TensorType((1, 2, 3, sympy.symbols('~1'))),
+                             TensorType((1, 2, 3, sympy.symbols('~1')))]
         expected_iter = iter(expected_ph_types)
 
         for n in symbolic_traced.graph.nodes:
             assert n.type == next(expected_iter)
 
-
+    @skipIfNoSympy
     def test_symbolic_add_with_broadcast_2(self):
         class M(torch.nn.Module):
             def forward(self, x: TensorType((1, 2)), y: TensorType((Dyn, 2))):
@@ -934,13 +924,80 @@ class TypeCheckerTest(unittest.TestCase):
         r.refine()
 
         expected_ph_types = [TensorType((1, 2)),
-                             TensorType((Var(1), 2)),
-                             TensorType((Var(1), 2)),
-                             TensorType((Var(1), 2))]
+                             TensorType((sympy.symbols('~1'), 2)),
+                             TensorType((sympy.symbols('~1'), 2)),
+                             TensorType((sympy.symbols('~1'), 2))]
         expected_iter = iter(expected_ph_types)
 
         for n in symbolic_traced.graph.nodes:
             assert n.type == next(expected_iter)
 
+    @skipIfNoSympy
+    def test_type_check_conv2D_types(self):
+        class BasicBlock(torch.nn.Module):
+            def __init__(self, inplanes, planes, stride=1):
+                super(BasicBlock, self).__init__()
+                norm_layer = torch.nn.BatchNorm2d
+                self.conv1 = conv3x3(inplanes, planes, stride)
+                self.bn1 = norm_layer(planes)
+
+            def forward(self, x: Dyn):
+                identity = x
+                out: TensorType((2, 2, Dyn, 4)) = self.conv1(x)
+                out += identity
+                return out
+
+        B = BasicBlock(2, 2)
+        ast_rewriter = RewritingTracer()
+        graph = ast_rewriter.trace(B)
+        traced = GraphModule(ast_rewriter.root, graph, "gm")
+        tc = GraphTypeChecker({}, traced)
+        tc.type_check()
+        infer_symbolic_types(traced)
+
+        for n in traced.graph.nodes:
+            if n.op == 'call_module':
+                assert isinstance(n.type.__args__[2], sympy.floor)
+                assert isinstance(n.type.__args__[3], sympy.floor)
+
+    @skipIfNoSympy
+    def test_type_check_symbolic_inference_conv2D_maxpool2d_flatten(self):
+
+        class BasicBlock(torch.nn.Module):
+            def __init__(self):
+                super(BasicBlock, self).__init__()
+
+                self.conv1 = torch.nn.Conv2d(3, 6, 5)
+                self.pool = torch.nn.MaxPool2d(2, 2)
+                self.conv2 = torch.nn.Conv2d(6, 16, 5)
+                self.fc1 = torch.nn.Linear(5, 120)
+                self.pool2 = torch.nn.AdaptiveAvgPool2d((6, 7))
+
+            def forward(self, x: TensorType((4, 3, Dyn, Dyn))):
+                out = self.conv1(x)
+                out = self.pool(out)
+                out = self.conv2(out)
+                out = self.pool(out)
+                out = self.fc1(out)
+                out = self.pool2(out)
+                out = torch.flatten(out, 1)
+                return out
+
+        B = BasicBlock()
+        traced = symbolic_trace(B)
+        tc = GraphTypeChecker({}, traced)
+        tc.type_check()
+        infer_symbolic_types(traced)
+
+        for n in traced.graph.nodes:
+            if n.target == 'conv1':
+                assert n.type == TensorType((4, 6, sympy.floor((sympy.symbols('~0') - 4)),
+                                             sympy.floor((sympy.symbols('~1') - 4))))
+
+            elif n.target == 'conv2':
+                assert n.type == TensorType((4, 16, sympy.floor((sympy.symbols('~4') - 4)),
+                                             sympy.floor((sympy.symbols('~5') - 4))))
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/torch/fx/experimental/graph_gradual_typechecker.py b/torch/fx/experimental/graph_gradual_typechecker.py
index 6e05f91..a54e521 100644
@@ -9,12 +9,18 @@ from torch.nn.modules.conv import Conv2d
 from torch.fx.experimental.refinement_types import Equality
 import itertools
 
-
 from torch.fx.experimental.unification import Var  # type: ignore[attr-defined]
 
 
+try:
+    import sympy  # type: ignore[import]
+    HAS_SYMPY = True
+except ImportError:
+    HAS_SYMPY = False
+
 _INFERENCE_RULES: Dict[Target, Callable] = {}
 _REFINEMENT_RULES: Dict[Target, Callable] = {}
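+# Rules that rewrite a node's type into a closed-form algebraic (sympy) expression.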
+_RULES: Dict[Target, Callable] = {}
 
 
 def expand_to_tensor_dim(t, n):
@@ -84,6 +90,13 @@ def register_refinement_rule(call_target):
         return fn
     return register
 
+def register_algebraic_expressions_inference_rule(call_target):
+    def register(fn):
+        if call_target in _RULES:
+            raise RuntimeError(f'Rule already registered for {call_target}!')
+        _RULES[call_target] = fn
+        return fn
+    return register
 
 @register_inference_rule(torch.add)
 @register_inference_rule(operator.add)
@@ -258,10 +271,12 @@ def calculate_out_dimension(d_in, module_instance, index):
     dilation = (module_instance.dilation, module_instance.dilation) \
         if isinstance(module_instance.dilation, int) else module_instance.dilation
 
+    DIMENSION_TYPES = (int, sympy.Symbol) if HAS_SYMPY else (int,)
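+    # Accept sympy symbols wherever ints are allowed, so the output-size
+    # arithmetic below can be carried out symbolically.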
+
     if d_in == Dyn:
         return Dyn
 
-    elif isinstance(d_in, int):
+    elif isinstance(d_in, DIMENSION_TYPES):
         n = d_in + 2 * padding[index] - \
             dilation[index] * \
             (kernel_size[index] - 1) - 1
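+        # When d_in is a sympy.Symbol, the floor division below yields a
+        # sympy.floor expression rather than an int.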
@@ -269,7 +284,7 @@ def calculate_out_dimension(d_in, module_instance, index):
         return (n // stride[index]) + 1
 
     else:
-        raise TypeError(f'{d_in} in {module_instance} must be a number or Dyn')
+        raise TypeError(f'{d_in} in {module_instance} must be a number or Dyn. Received {type(d_in)}')
 
 
 def get_greatest_upper_bound(type1, type2):
@@ -552,8 +567,17 @@ class GraphTypeChecker:
 
 
 @register_refinement_rule(Conv2d)
+def conv_refinement_rule(n: Node):
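+    # Conv2d preserves the batch dimension, so equate dim 0 of the input and
+    # output types.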
+    res = []
+    assert isinstance(n.args[0], Node)
+    arg_type = n.args[0].type
+    if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType):
+        res = [Equality(arg_type.__args__[0], n.type.__args__[0])]
+    return res
+
+
 @register_refinement_rule(torch.nn.Linear)
-def first_one(n: Node):
+def linear_refinement_rule(n: Node):
     res = []
     assert isinstance(n.args[0], Node)
     arg_type = n.args[0].type
@@ -564,7 +588,6 @@ def first_one(n: Node):
 # todo needs review for addition. Is this constraint correct?
 @register_refinement_rule(BatchNorm2d)
 @register_refinement_rule(torch.nn.ReLU)
-@register_refinement_rule(torch.nn.AdaptiveAvgPool2d)
 def all_eq(n: Node):
     res = []
     assert isinstance(n.args[0], Node)
@@ -575,6 +598,18 @@ def all_eq(n: Node):
         res = [Equality(args1[i], args2[i]) for i in range(len(args1))]
     return res
 
+
+@register_refinement_rule(torch.nn.AdaptiveAvgPool2d)
+def first_two_eq(n: Node):
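+    # AdaptiveAvgPool2d fixes the output H and W itself, so only the batch and
+    # channel dimensions (0 and 1) carry over from the input.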
+    res = []
+    assert isinstance(n.args[0], Node)
+    arg_type = n.args[0].type
+    if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType):
+        args1 = arg_type.__args__
+        args2 = n.type.__args__
+        res = [Equality(args1[0], args2[0]), Equality(args1[1], args2[1])]
+    return res
+
 @register_refinement_rule(torch.add)
 @register_refinement_rule(operator.add)
 def add_eq(n: Node):
@@ -636,6 +671,20 @@ def flatten_refinement_rule(n: Node):
             eq_const.append(Equality(t1, t2))
     return eq_const
 
+
+@register_algebraic_expressions_inference_rule(Conv2d)
+def conv_rule(n: Node, module_instance):
+    assert isinstance(n.args[0], Node)
+    arg_type = n.args[0].type
+    if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType):
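+        # For an NCHW tensor type, dimensions 2 and 3 are height and width.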
+        w_in = arg_type.__args__[3]
+        h_in = arg_type.__args__[2]
+        h_out = calculate_out_dimension(h_in, module_instance, 0)
+        w_out = calculate_out_dimension(w_in, module_instance, 1)
+        new_type = TensorType((n.type.__args__[0], n.type.__args__[1], h_out, w_out))
+        n.type = new_type
+        return new_type
+
 class Refine:
     """
     Symbolic shape inference.
@@ -658,6 +707,15 @@ class Refine:
             self.refine_node(n)
         return True
 
+    def symbolic_relations(self):
+        """
+        Infer algebraic relations among symbolic tensor dimensions.
+        """
+        graph = self.traced.graph
+        for n in graph.nodes:
+            self.infer_symbolic_relations(n)
+        return True
+
     def replace_dyn_with_fresh_var(self, typ):
         """
         Replace all unknown types with fresh type variables.
@@ -675,6 +733,26 @@ class Refine:
         else:
             return typ
 
+
+    def convert_to_sympy_symbols(self, typ):
+        """
+        Replace all unification type variables with sympy symbols.
+        """
+        if HAS_SYMPY:
+            if isinstance(typ, Var):
+                return sympy.symbols(str(typ))
+            elif isinstance(typ, TensorType):
+                new_args = [self.convert_to_sympy_symbols(a) for a in typ.__args__]
+                return TensorType(tuple(new_args))
+            elif isinstance(typ, list):
+                return [self.convert_to_sympy_symbols(t) for t in typ]
+            elif isinstance(typ, tuple):
+                return tuple(self.convert_to_sympy_symbols(t) for t in typ)
+            else:
+                return typ
+        else:
+            return typ
+
     def refine_node(self, n: Node):
         """
         Returns a list of equality constraints for
@@ -710,6 +788,32 @@ class Refine:
         else:
             pass
 
+    def infer_symbolic_relations(self, n: Node):
+        if not HAS_SYMPY:
+            return None
+
+        n.type = self.convert_to_sympy_symbols(n.type)
+
+        if n.op == 'call_function':
+            if n.target in _RULES:
+                return _RULES[n.target](n)
+
+        elif n.op == 'call_module':
+            module_instance = self.traced.get_submodule(n.target)
+            if type(module_instance) in _RULES:
+                return _RULES[type(module_instance)](n, module_instance)
+
+        elif n.op == 'output':
+            def get_node_type(a):
+                return a.type
+            n.type = torch.fx.node.map_arg(n.args[0], get_node_type)
+            return n.type
 
 def get_parameter(traced, target: str):
     """
diff --git a/torch/fx/experimental/unify_refinements.py b/torch/fx/experimental/unify_refinements.py
index 5074377..532d278 100644
@@ -2,11 +2,10 @@ from torch.fx.experimental.graph_gradual_typechecker import Refine
 from torch.fx.tensor_type import TensorType
 from torch.fx.experimental.unification import Var, unify  # type: ignore[attr-defined]
 
+
 def infer_symbolic_types_single_pass(traced):
     """
-    Generate constraints over types,
-    solve constraints with unification,
-    apply solution back to the types
+    Run one pass of the symbolic inferencer: generate constraints, unify, and substitute.
     """
     r = Refine(traced)
     r.refine()
@@ -20,8 +19,17 @@ def infer_symbolic_types(traced):
     to infer all the information such as the case
 for broadcasting.
     """
     infer_symbolic_types_single_pass(traced)
     infer_symbolic_types_single_pass(traced)
+
+    r = Refine(traced)
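+    # Final pass: apply the registered algebraic-expression rules (e.g. the
+    # Conv2d output-size formula) on top of the unified symbolic types.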
+    r.symbolic_relations()
 
 def convert_eq(list_of_eq):
     """