add documentation to shape inference algorithm (#64312)
authorZeina Migeed <migeedz@fb.com>
Thu, 2 Sep 2021 01:04:19 +0000 (18:04 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 2 Sep 2021 01:34:17 +0000 (18:34 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64312

Test Plan: Imported from OSS

Reviewed By: zou3519

Differential Revision: D30709254

Pulled By: migeed-z

fbshipit-source-id: 3297d26fe6727c5b9ca176625b1683d787f59659

torch/fx/experimental/graph_gradual_typechecker.py

index a54e521..6094952 100644 (file)
@@ -28,7 +28,7 @@ def expand_to_tensor_dim(t, n):
     Expand a type to the desired tensor dimension if possible
     Raise an error otherwise.
     - t is the given type
-    - n is a number to expand to
+    - n is a number of dimensions to expand to
     """
     if t == Dyn:
         dims = [Dyn] * n
@@ -42,6 +42,13 @@ def expand_to_tensor_dim(t, n):
 
 
 def broadcast_types(t1, t2):
+    """
+    Applies broadcasting to both given types such that they
+    become consistent with eachother and returns two new
+    resulting types
+    """
+
+    # if either type is Dyn, do nothing since the types are already consistent
     if t1 == Dyn or t2 == Dyn or isinstance(t1, Var) or isinstance(t2, Var):
         return t1, t2
 
@@ -52,7 +59,8 @@ def broadcast_types(t1, t2):
         new_t1 = list(t1.__args__)
         new_t2 = list(t2.__args__)
 
-        # here, we make our tensors the same length
+        # We make the types the same length which is the first requirement
+        # for consistency
         if s1 > s2:
             for i in range(s1 - s2):
                 new_t2.insert(0, 1)
@@ -61,15 +69,18 @@ def broadcast_types(t1, t2):
             for i in range(s2 - s1):
                 new_t1.insert(0, 1)
 
+        # we replace occurrences of "1" with each tensor with
+        # the corresponding type from the other tensor
         for i, (x, y) in enumerate(zip(new_t1, new_t2)):
             if x == 1:
                 new_t1[i] = y
             elif y == 1:
                 new_t2[i] = x
 
+        # at this point our tensors should be consistent
+        # and we can apply the element-wise operation and find the right dimension
+        # for the output of the operation
         (t1, t2) = TensorType(tuple(new_t1)), TensorType(tuple(new_t2))
-
-
         return (t1, t2)
     else:
         raise TypeError(f'Cannot broadcast types {t1} and {t2}')
@@ -77,7 +88,7 @@ def broadcast_types(t1, t2):
 def register_inference_rule(call_target):
     def register(fn):
         if call_target in _INFERENCE_RULES:
-            raise RuntimeError('Inference rule already registered for {call_target}!')
+            raise RuntimeError(f'Inference rule already registered for {call_target}!')
         _INFERENCE_RULES[call_target] = fn
         return fn
     return register
@@ -85,7 +96,7 @@ def register_inference_rule(call_target):
 def register_refinement_rule(call_target):
     def register(fn):
         if call_target in _REFINEMENT_RULES:
-            raise RuntimeError('Refinement rule already registered for {call_target}!')
+            raise RuntimeError(f'Refinement rule already registered for {call_target}!')
         _REFINEMENT_RULES[call_target] = fn
         return fn
     return register
@@ -93,7 +104,7 @@ def register_refinement_rule(call_target):
 def register_algebraic_expressions_inference_rule(call_target):
     def register(fn):
         if call_target in _RULES:
-            raise RuntimeError('Rule already registered for {call_target}!')
+            raise RuntimeError(f'Rule already registered for {call_target}!')
         _RULES[call_target] = fn
         return fn
     return register
@@ -101,6 +112,17 @@ def register_algebraic_expressions_inference_rule(call_target):
 @register_inference_rule(torch.add)
 @register_inference_rule(operator.add)
 def add_inference_rule(n: Node):
+    """
+    Apply the addition inference rule. This includes:
+    - scalar addition
+    - broadcasting semantics
+
+    Note that we always return the least precise type between
+    the operands (after applying broadcasting) to be the final type of the operation
+
+    Note that we do not modify the operand types themselves after applying broadcasting
+    to them. We only use them to calculate the final type
+    """
     assert isinstance(n.args[0], Node)
     assert isinstance(n.args[1], Node)
     t1 = n.args[0].type
@@ -111,10 +133,15 @@ def add_inference_rule(n: Node):
         n.type = t2
         return n.type
 
+    # handle scalar addition
     elif t2 == int and isinstance(t1, TensorType):
         n.type = t1
         return n.type
 
+    # we bring the new types to the point where
+    # we can check for consistency
+    # any inconsistency would not have been caused
+    # by broadcasting at this point
     (new_t1, new_t2) = broadcast_types(t1, t2)
 
     if new_t1 != t1 or new_t2 != t2:
@@ -122,13 +149,13 @@ def add_inference_rule(n: Node):
         n.meta[str(n.args[0])] = new_t1
         n.meta[str(n.args[1])] = new_t2
 
-    # Todo: maybe figure out that broadcasting definitely did not happen?
     else:
         n.meta['broadcast'] = False
 
     new_t1 = t1 if not n.meta['broadcast'] else new_t1
     new_t2 = t2 if not n.meta['broadcast'] else new_t2
 
+    # we check for consistency between the new types
     if is_consistent(new_t1, new_t2):
         # we return the less precise type because
         # broadcasting may have happened
@@ -145,6 +172,12 @@ def add_inference_rule(n: Node):
 
 @register_inference_rule(getattr)
 def get_attr_inference_rule(n: Node, traced):
+    """
+    The current getattr rule only handles the shape attribute
+    Can be extended to other attributes
+    The most representitive type we have is "Dyn" but the system
+    can be extended with more types, such as a type to represent shapes
+    """
     attr_node = n.args[0]
     attr_name = n.args[1]
 
@@ -158,6 +191,10 @@ def get_attr_inference_rule(n: Node, traced):
 
 @register_inference_rule(torch.transpose)
 def transpose_inference_rule(n: Node):
+    """
+    We check that dimentions for the transpose operations
+    are within range of the tensor type of the node
+    """
     if n.target == torch.transpose:
         assert isinstance(n.args[0], Node)
         t = n.args[0].type
@@ -171,12 +208,11 @@ def transpose_inference_rule(n: Node):
             return n.type
 
         elif isinstance(t, TensorType):
-
             if 0 <= dim1 < len(t.__args__) and 0 <= dim2 < len(t.__args__):
                 new_type = list(t.__args__)
                 new_type[dim1], new_type[dim2] = new_type[dim2], new_type[dim1]
                 final = TensorType(new_type)
-                n.type = final
+                n.type = get_greatest_upper_bound(n.type, final)
                 return n.type
             else:
                 raise TypeError(f'Cannot transpose {dim1} and {dim2} in type {t} for node {n}')
@@ -186,6 +222,15 @@ def transpose_inference_rule(n: Node):
 
 @register_inference_rule(torch.reshape)
 def reshape_inference_rule(n: Node):
+    """
+    Without dynamism, the rule checks that the
+    product of the elements of the argument tensor
+    type is equal to the product of the elements
+    of the required shape. We gradualize this rule
+    by adding a case to handle fully dynamic input
+    as well as input where some of the tensor dimensions
+    are unknown. In this case we check for divisibility
+    """
     assert isinstance(n.args[0], Node)
     t1 = n.args[0].type
 
@@ -201,7 +246,7 @@ def reshape_inference_rule(n: Node):
 
     # if any of the dimensions are unknown,
     # we check for divisibility
-    elif isinstance(t1, TensorType) and Dyn in t1.__args__ or -1 in t2:
+    elif isinstance(t1, TensorType):
         assert isinstance(t1, TensorType)
         a = [e if e != Dyn else 1 for e in t1.__args__]
         p1 = reduce(lambda x, y: x * y, a)
@@ -211,17 +256,6 @@ def reshape_inference_rule(n: Node):
             return t2_type
         else:
             raise TypeError(f'Cannot reshape in node {n} from {t1} to {t2_type}')
-
-    # if all dimensions are known we check the products
-    elif isinstance(t1, TensorType):
-        p1 = reduce(lambda x, y: x * y, t1.__args__)
-        p2 = reduce(lambda x, y: x * y, t2)
-        if p1 == p2:
-            n.type = t2_type
-            return t2_type
-        else:
-            raise TypeError(f'Cannot reshape in node {n} from {t1} to {t2_type}')
-
     else:
         raise TypeError(f'Cannot reshape in node {n} from {t1} to {t2_type}')
 
@@ -260,7 +294,7 @@ def bn2d_inference_rule(n: Node, module_instance):
 
 def calculate_out_dimension(d_in, module_instance, index):
     """
-    For calculating h_in and w_out.
+    For calculating h_in and w_out according to the conv2D documentation
     """
     padding = (module_instance.padding, module_instance.padding) \
         if isinstance(module_instance.padding, int) else module_instance.padding
@@ -346,6 +380,10 @@ def relu_inference_rule(n: Node, module_instance):
 
 
 def maxpool2d_check(typ, module_instance):
+    """
+    Applies the maxpool2d shape information to the input
+    this affects the last two dimensions
+    """
     new_type_list = list(typ.__args__)
     if len(new_type_list) == 4 or len(new_type_list) == 3:
         w_in = new_type_list[-1]
@@ -391,7 +429,6 @@ def linear_check(tensor_type, module_instance):
     """
     if len(tensor_type.__args__) >= 2:
         if is_consistent(module_instance.in_features, tensor_type.__args__[-1]):
-            # Todo backwards propagation
             new_type_args = list(tensor_type.__args__)
             new_type_args[-1] = module_instance.out_features
             return TensorType(tuple(new_type_args))
@@ -403,6 +440,10 @@ def linear_check(tensor_type, module_instance):
 
 @register_inference_rule(torch.nn.Linear)
 def linear_inference_rule(n: Node, module_instance):
+    """
+    Applies the shape information to the input then gets the greatest upper bound
+    of the resulting type and the existing type
+    """
     assert isinstance(n.args[0], Node)
     if n.args[0].type == Dyn and isinstance(n.type, TensorType):
         n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__))
@@ -470,6 +511,10 @@ def flatten_check(tensor_type, start_dim, end_dim):
 
 @register_inference_rule(torch.flatten)
 def flatten_inference_rule(n: Node):
+    """
+    Applies the flatten shape information to the input then gets the
+    greatest upper bound of the resulting type and the existing type
+    """
     assert isinstance(n.args[0], Node)
 
     # set the default start and end dims
@@ -568,6 +613,10 @@ class GraphTypeChecker:
 
 @register_refinement_rule(Conv2d)
 def conv_refinement_rule(n: Node):
+    """
+    The equality constraints are between the first dimension of
+    the input and output
+    """
     res = []
     assert isinstance(n.args[0], Node)
     arg_type = n.args[0].type
@@ -578,6 +627,10 @@ def conv_refinement_rule(n: Node):
 
 @register_refinement_rule(torch.nn.Linear)
 def linear_refinement_rule(n: Node):
+    """
+    The equality constraints are between the first dimension of
+    the input and output
+    """
     res = []
     assert isinstance(n.args[0], Node)
     arg_type = n.args[0].type
@@ -585,10 +638,12 @@ def linear_refinement_rule(n: Node):
         res = [Equality(arg_type.__args__[0], n.type.__args__[0])]
     return res
 
-# todo needs review for addition. Is this constraint correct?
 @register_refinement_rule(BatchNorm2d)
 @register_refinement_rule(torch.nn.ReLU)
 def all_eq(n: Node):
+    """
+    For operations where the input shape is equal to the output shape
+    """
     res = []
     assert isinstance(n.args[0], Node)
     arg_type = n.args[0].type
@@ -600,7 +655,12 @@ def all_eq(n: Node):
 
 
 @register_refinement_rule(torch.nn.AdaptiveAvgPool2d)
-def first_two__eq(n: Node):
+@register_refinement_rule(torch.nn.MaxPool2d)
+def first_two_eq(n: Node):
+    """
+    For operations where the first two dimensions of the input and output shape
+    are equal
+    """
     res = []
     assert isinstance(n.args[0], Node)
     arg_type = n.args[0].type
@@ -610,19 +670,37 @@ def first_two__eq(n: Node):
         res = [Equality(args1[0], args2[0]), Equality(args1[1], args2[1])]
     return res
 
+
 @register_refinement_rule(torch.add)
 @register_refinement_rule(operator.add)
-def add_eq(n: Node):
+def element_wise_eq(n: Node):
+    """
+    For element-wise operations and handles broadcasting.
+    Note that after applying broadcasting to the arguments
+    we are able to determine if certain dimensions have not been broadcast
+    if they are symbolicallu equal.
+
+    in this case, we can establish equality between those dimensions and the
+    corresponding output dimensions.
+
+    Note that it takes two iterations for this result. One iteration to establish
+    equality between certain dimensions of the operands (requiring the whole solver
+    including unification) and another iteration to establish equality between the operands
+    and the resulting type, requiring another round of constraint generation and unificaiton.
+    """
     res = []
     if isinstance(n.args[0], Node) and isinstance(n.args[1], Node):
         arg_type1 = n.args[0].type
         arg_type2 = n.args[1].type
         if isinstance(arg_type1, TensorType) and isinstance(arg_type2, TensorType) and isinstance(n.type, TensorType):
             args1, args2 = broadcast_types(arg_type1, arg_type2)
-            # by this point, we know for sure that args1 and args2 are the same size.
+            # by this point, we know that args1 and args2 are the same size.
             a1 = args1.__args__
             a2 = args2.__args__
             a3 = n.type.__args__
+
+            # we would be here in the second iteration where we establish equality
+            # between operand type dimensions and the resulting type dimensions
             r = []
             for x, y, z in zip(a1, a2, a3):
                 if x == y:
@@ -630,19 +708,13 @@ def add_eq(n: Node):
             res = r
     return res
 
-@register_refinement_rule(torch.nn.MaxPool2d)
-def first_two(n: Node):
-    res = []
-    assert isinstance(n.args[0], Node)
-    arg_type = n.args[0].type
-    if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType):
-        args1 = arg_type.__args__
-        args2 = n.type.__args__
-        res = [Equality(args1[0], args2[0]), Equality(args1[1], args2[1])]
-    return res
 
 @register_refinement_rule(torch.flatten)
 def flatten_refinement_rule(n: Node):
+    """
+    Generates equality constraints between the dimensions of the input and output
+    that will not be involved in the flatten operation
+    """
     assert isinstance(n.args[0], Node)
 
     eq_const = []
@@ -674,6 +746,10 @@ def flatten_refinement_rule(n: Node):
 
 @register_algebraic_expressions_inference_rule(Conv2d)
 def conv_rule(n: Node, module_instance):
+    """
+    Represents the outout in terms of an algrbraic expression w.r.t
+    the input when possible
+    """
     assert isinstance(n.args[0], Node)
     arg_type = n.args[0].type
     if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType):