From 59c6ceb6a8338c5de3f3aee7b7790b1d0daefb0a Mon Sep 17 00:00:00 2001
From: Zeina Migeed
Date: Wed, 1 Sep 2021 18:04:19 -0700
Subject: [PATCH] add documentation to shape inference algorithm (#64312)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64312

Test Plan: Imported from OSS

Reviewed By: zou3519

Differential Revision: D30709254

Pulled By: migeed-z

fbshipit-source-id: 3297d26fe6727c5b9ca176625b1683d787f59659
---
 torch/fx/experimental/graph_gradual_typechecker.py | 152 +++++++++++++++------
 1 file changed, 114 insertions(+), 38 deletions(-)

diff --git a/torch/fx/experimental/graph_gradual_typechecker.py b/torch/fx/experimental/graph_gradual_typechecker.py
index a54e521..6094952 100644
--- a/torch/fx/experimental/graph_gradual_typechecker.py
+++ b/torch/fx/experimental/graph_gradual_typechecker.py
@@ -28,7 +28,7 @@ def expand_to_tensor_dim(t, n):
     Expand a type to the desired tensor dimension if possible
     Raise an error otherwise.
     - t is the given type
-    - n is a number to expand to
+    - n is the number of dimensions to expand to
     """
     if t == Dyn:
         dims = [Dyn] * n
@@ -42,6 +42,13 @@ def expand_to_tensor_dim(t, n):
 
 
 def broadcast_types(t1, t2):
+    """
+    Applies broadcasting to both given types such that they
+    become consistent with each other and returns two new
+    resulting types
+    """
+
+    # if either type is Dyn, do nothing since the types are already consistent
     if t1 == Dyn or t2 == Dyn or isinstance(t1, Var) or isinstance(t2, Var):
         return t1, t2
 
@@ -52,7 +59,8 @@ def broadcast_types(t1, t2):
         new_t1 = list(t1.__args__)
         new_t2 = list(t2.__args__)
 
-        # here, we make our tensors the same length
+        # We make the types the same length, which is the first requirement
+        # for consistency
         if s1 > s2:
             for i in range(s1 - s2):
                 new_t2.insert(0, 1)
@@ -61,15 +69,18 @@ def broadcast_types(t1, t2):
             for i in range(s2 - s1):
                 new_t1.insert(0, 1)
 
+        # we replace occurrences of "1" in each tensor with
+        # the corresponding type from the other tensor
         for i, (x, y) in enumerate(zip(new_t1, new_t2)):
             if x == 1:
                 new_t1[i] = y
             elif y == 1:
                 new_t2[i] = x
 
+        # at this point our tensors should be consistent,
+        # and we can apply the element-wise operation and find the right dimension
+        # for the output of the operation
         (t1, t2) = TensorType(tuple(new_t1)), TensorType(tuple(new_t2))
-
-        return (t1, t2)
     else:
         raise TypeError(f'Cannot broadcast types {t1} and {t2}')
 
@@ -77,7 +88,7 @@ def broadcast_types(t1, t2):
 def register_inference_rule(call_target):
     def register(fn):
         if call_target in _INFERENCE_RULES:
-            raise RuntimeError('Inference rule already registered for {call_target}!')
+            raise RuntimeError(f'Inference rule already registered for {call_target}!')
         _INFERENCE_RULES[call_target] = fn
         return fn
     return register
@@ -85,7 +96,7 @@ def register_inference_rule(call_target):
 def register_refinement_rule(call_target):
     def register(fn):
         if call_target in _REFINEMENT_RULES:
-            raise RuntimeError('Refinement rule already registered for {call_target}!')
+            raise RuntimeError(f'Refinement rule already registered for {call_target}!')
         _REFINEMENT_RULES[call_target] = fn
         return fn
     return register
@@ -93,7 +104,7 @@ def register_refinement_rule(call_target):
 def register_algebraic_expressions_inference_rule(call_target):
     def register(fn):
         if call_target in _RULES:
-            raise RuntimeError('Rule already registered for {call_target}!')
+            raise RuntimeError(f'Rule already registered for {call_target}!')
         _RULES[call_target] = fn
         return fn
     return register
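For readers of the broadcast_types docstring above, the following is a minimal, self-contained sketch (not part of the diff) of the padding-and-substitution step it describes, written against plain int tuples instead of TensorType; the helper name broadcast_shapes_sketch is hypothetical and illustrative only.

def broadcast_shapes_sketch(s1, s2):
    a, b = list(s1), list(s2)
    # step 1: left-pad the shorter shape with 1s so both have the same rank
    while len(a) < len(b):
        a.insert(0, 1)
    while len(b) < len(a):
        b.insert(0, 1)
    # step 2: replace each occurrence of 1 with the corresponding dimension
    # from the other shape
    for i, (x, y) in enumerate(zip(a, b)):
        if x == 1:
            a[i] = y
        elif y == 1:
            b[i] = x
    return tuple(a), tuple(b)

assert broadcast_shapes_sketch((2, 3, 4), (3, 1)) == ((2, 3, 4), (2, 3, 4))

The same two steps (pad, then substitute) are what the patched comments in broadcast_types walk through before the consistency check.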
@@ -101,6 +112,17 @@ def register_algebraic_expressions_inference_rule(call_target):
 @register_inference_rule(torch.add)
 @register_inference_rule(operator.add)
 def add_inference_rule(n: Node):
+    """
+    Apply the addition inference rule. This includes:
+    - scalar addition
+    - broadcasting semantics
+
+    Note that we always return the least precise type between
+    the operands (after applying broadcasting) to be the final type of the operation.
+
+    Note that we do not modify the operand types themselves after applying broadcasting
+    to them. We only use them to calculate the final type.
+    """
     assert isinstance(n.args[0], Node)
     assert isinstance(n.args[1], Node)
     t1 = n.args[0].type
@@ -111,10 +133,15 @@ def add_inference_rule(n: Node):
         n.type = t2
         return n.type
 
+    # handle scalar addition
     elif t2 == int and isinstance(t1, TensorType):
         n.type = t1
         return n.type
 
+    # we bring the new types to the point where
+    # we can check for consistency
+    # any inconsistency would not have been caused
+    # by broadcasting at this point
     (new_t1, new_t2) = broadcast_types(t1, t2)
 
     if new_t1 != t1 or new_t2 != t2:
@@ -122,13 +149,13 @@ def add_inference_rule(n: Node):
         n.meta[str(n.args[0])] = new_t1
         n.meta[str(n.args[1])] = new_t2
 
-    # Todo: maybe figure out that broadcasting definitely did not happen?
     else:
         n.meta['broadcast'] = False
 
     new_t1 = t1 if not n.meta['broadcast'] else new_t1
     new_t2 = t2 if not n.meta['broadcast'] else new_t2
 
+    # we check for consistency between the new types
     if is_consistent(new_t1, new_t2):
         # we return the less precise type because
         # broadcasting may have happened
@@ -145,6 +172,12 @@ def add_inference_rule(n: Node):
 
 @register_inference_rule(getattr)
 def get_attr_inference_rule(n: Node, traced):
+    """
+    The current getattr rule only handles the shape attribute;
+    it can be extended to other attributes.
+    The most representative type we have is "Dyn", but the system
+    can be extended with more types, such as a type to represent shapes.
+    """
     attr_node = n.args[0]
     attr_name = n.args[1]
 
@@ -158,6 +191,10 @@ def get_attr_inference_rule(n: Node, traced):
 
 @register_inference_rule(torch.transpose)
 def transpose_inference_rule(n: Node):
+    """
+    We check that the dimensions for the transpose operation
+    are within the range of the tensor type of the node
+    """
     if n.target == torch.transpose:
         assert isinstance(n.args[0], Node)
         t = n.args[0].type
@@ -171,12 +208,11 @@ def transpose_inference_rule(n: Node):
             return n.type
 
         elif isinstance(t, TensorType):
-
             if 0 <= dim1 < len(t.__args__) and 0 <= dim2 < len(t.__args__):
                 new_type = list(t.__args__)
                 new_type[dim1], new_type[dim2] = new_type[dim2], new_type[dim1]
                 final = TensorType(new_type)
-                n.type = final
+                n.type = get_greatest_upper_bound(n.type, final)
                 return n.type
             else:
                 raise TypeError(f'Cannot transpose {dim1} and {dim2} in type {t} for node {n}')
@@ -186,6 +222,15 @@ def transpose_inference_rule(n: Node):
 
 @register_inference_rule(torch.reshape)
 def reshape_inference_rule(n: Node):
+    """
+    Without dynamism, the rule checks that the
+    product of the elements of the argument tensor
+    type is equal to the product of the elements
+    of the required shape. We gradualize this rule
+    by adding a case to handle fully dynamic input
+    as well as input where some of the tensor dimensions
+    are unknown. In this case, we check for divisibility.
+    """
     assert isinstance(n.args[0], Node)
     t1 = n.args[0].type
 
@@ -201,7 +246,7 @@ def reshape_inference_rule(n: Node):
 
     # if any of the dimensions are unknown,
     # we check for divisibility
-    elif isinstance(t1, TensorType) and Dyn in t1.__args__ or -1 in t2:
+    elif isinstance(t1, TensorType):
         assert isinstance(t1, TensorType)
         a = [e if e != Dyn else 1 for e in t1.__args__]
         p1 = reduce(lambda x, y: x * y, a)
@@ -211,17 +256,6 @@ def reshape_inference_rule(n: Node):
             return t2_type
         else:
             raise TypeError(f'Cannot reshape in node {n} from {t1} to {t2_type}')
-
-    # if all dimensions are known we check the products
-    elif isinstance(t1, TensorType):
-        p1 = reduce(lambda x, y: x * y, t1.__args__)
-        p2 = reduce(lambda x, y: x * y, t2)
-        if p1 == p2:
-            n.type = t2_type
-            return t2_type
-        else:
-            raise TypeError(f'Cannot reshape in node {n} from {t1} to {t2_type}')
-
     else:
         raise TypeError(f'Cannot reshape in node {n} from {t1} to {t2_type}')
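The reshape rule documented above degrades gracefully when some dimensions are unknown. A small illustrative sketch of that divisibility check follows (not the module's implementation): None stands in for an unknown (Dyn) dimension, -1 for an inferred target dimension, and the helper name is hypothetical.

from functools import reduce

def reshape_is_plausible(in_shape, target_shape):
    # treat unknown / inferred dimensions as 1, exactly as the documented rule does
    known_in = [d if d is not None else 1 for d in in_shape]
    known_out = [d if d != -1 else 1 for d in target_shape]
    p1 = reduce(lambda x, y: x * y, known_in, 1)
    p2 = reduce(lambda x, y: x * y, known_out, 1)
    # with unknown dimensions, the best we can check is that one product divides the other
    return p1 % p2 == 0 or p2 % p1 == 0

assert reshape_is_plausible((None, 4, 3), (6, -1))        # 12 vs 6: divisible, plausible
assert not reshape_is_plausible((None, 4, 3), (5, -1))    # 12 vs 5: rejected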
@@ -260,7 +294,7 @@ def bn2d_inference_rule(n: Node, module_instance):
 
 def calculate_out_dimension(d_in, module_instance, index):
     """
-    For calculating h_in and w_out.
+    For calculating h_out and w_out according to the Conv2d documentation
     """
     padding = (module_instance.padding, module_instance.padding) \
         if isinstance(module_instance.padding, int) else module_instance.padding
@@ -346,6 +380,10 @@ def relu_inference_rule(n: Node, module_instance):
 
 
 def maxpool2d_check(typ, module_instance):
+    """
+    Applies the maxpool2d shape information to the input;
+    this affects the last two dimensions
+    """
     new_type_list = list(typ.__args__)
     if len(new_type_list) == 4 or len(new_type_list) == 3:
         w_in = new_type_list[-1]
@@ -391,7 +429,6 @@ def linear_check(tensor_type, module_instance):
     """
     if len(tensor_type.__args__) >= 2:
         if is_consistent(module_instance.in_features, tensor_type.__args__[-1]):
-            # Todo backwards propagation
            new_type_args = list(tensor_type.__args__)
            new_type_args[-1] = module_instance.out_features
            return TensorType(tuple(new_type_args))
@@ -403,6 +440,10 @@ def linear_check(tensor_type, module_instance):
 
 @register_inference_rule(torch.nn.Linear)
 def linear_inference_rule(n: Node, module_instance):
+    """
+    Applies the shape information to the input, then takes the greatest upper bound
+    of the resulting type and the existing type
+    """
     assert isinstance(n.args[0], Node)
     if n.args[0].type == Dyn and isinstance(n.type, TensorType):
         n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__))
@@ -470,6 +511,10 @@ def flatten_check(tensor_type, start_dim, end_dim):
 
 @register_inference_rule(torch.flatten)
 def flatten_inference_rule(n: Node):
+    """
+    Applies the flatten shape information to the input, then takes the
+    greatest upper bound of the resulting type and the existing type
+    """
     assert isinstance(n.args[0], Node)
 
     # set the default start and end dims
@@ -568,6 +613,10 @@ class GraphTypeChecker:
 
 @register_refinement_rule(Conv2d)
 def conv_refinement_rule(n: Node):
+    """
+    The equality constraints are between the first dimension of
+    the input and output
+    """
     res = []
     assert isinstance(n.args[0], Node)
     arg_type = n.args[0].type
@@ -578,6 +627,10 @@ def conv_refinement_rule(n: Node):
 
 @register_refinement_rule(torch.nn.Linear)
 def linear_refinement_rule(n: Node):
+    """
+    The equality constraints are between the first dimension of
+    the input and output
+    """
     res = []
     assert isinstance(n.args[0], Node)
     arg_type = n.args[0].type
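The calculate_out_dimension docstring above refers to the output-size formula from the Conv2d documentation. For reference, here is that formula for a single spatial dimension as a small illustrative helper (hypothetical name; assumes all quantities are known ints, unlike the gradual version in the module):

def conv_out_dim(d_in, kernel_size, stride=1, padding=0, dilation=1):
    # floor((d_in + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1
    return (d_in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

# a 3x3 kernel with stride 1 and padding 1 preserves the spatial size
assert conv_out_dim(32, kernel_size=3, stride=1, padding=1) == 32
# stride 2 roughly halves it
assert conv_out_dim(32, kernel_size=3, stride=2, padding=1) == 16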
@@ -585,10 +638,12 @@ def linear_refinement_rule(n: Node):
         res = [Equality(arg_type.__args__[0], n.type.__args__[0])]
     return res
 
-# todo needs review for addition. Is this constraint correct?
 @register_refinement_rule(BatchNorm2d)
 @register_refinement_rule(torch.nn.ReLU)
 def all_eq(n: Node):
+    """
+    For operations where the input shape is equal to the output shape
+    """
     res = []
     assert isinstance(n.args[0], Node)
     arg_type = n.args[0].type
@@ -600,7 +655,12 @@ def all_eq(n: Node):
 
 
 @register_refinement_rule(torch.nn.AdaptiveAvgPool2d)
-def first_two__eq(n: Node):
+@register_refinement_rule(torch.nn.MaxPool2d)
+def first_two_eq(n: Node):
+    """
+    For operations where the first two dimensions of the input and output shapes
+    are equal
+    """
     res = []
     assert isinstance(n.args[0], Node)
     arg_type = n.args[0].type
@@ -610,19 +670,37 @@ def first_two_eq(n: Node):
         res = [Equality(args1[0], args2[0]), Equality(args1[1], args2[1])]
     return res
 
+
 @register_refinement_rule(torch.add)
 @register_refinement_rule(operator.add)
-def add_eq(n: Node):
+def element_wise_eq(n: Node):
+    """
+    For element-wise operations; handles broadcasting.
+    Note that after applying broadcasting to the arguments
+    we are able to determine if certain dimensions have not been broadcast
+    if they are symbolically equal.
+
+    In this case, we can establish equality between those dimensions and the
+    corresponding output dimensions.
+
+    Note that it takes two iterations for this result. One iteration to establish
+    equality between certain dimensions of the operands (requiring the whole solver
+    including unification) and another iteration to establish equality between the operands
+    and the resulting type, requiring another round of constraint generation and unification.
+    """
     res = []
     if isinstance(n.args[0], Node) and isinstance(n.args[1], Node):
         arg_type1 = n.args[0].type
         arg_type2 = n.args[1].type
         if isinstance(arg_type1, TensorType) and isinstance(arg_type2, TensorType) and isinstance(n.type, TensorType):
             args1, args2 = broadcast_types(arg_type1, arg_type2)
-            # by this point, we know for sure that args1 and args2 are the same size.
+            # by this point, we know that args1 and args2 are the same size.
             a1 = args1.__args__
             a2 = args2.__args__
             a3 = n.type.__args__
+
+            # we would be here in the second iteration where we establish equality
+            # between operand type dimensions and the resulting type dimensions
             r = []
             for x, y, z in zip(a1, a2, a3):
                 if x == y:
@@ -630,19 +708,13 @@ def element_wise_eq(n: Node):
         res = r
     return res
 
-@register_refinement_rule(torch.nn.MaxPool2d)
-def first_two(n: Node):
-    res = []
-    assert isinstance(n.args[0], Node)
-    arg_type = n.args[0].type
-    if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType):
-        args1 = arg_type.__args__
-        args2 = n.type.__args__
-        res = [Equality(args1[0], args2[0]), Equality(args1[1], args2[1])]
-    return res
 
 @register_refinement_rule(torch.flatten)
 def flatten_refinement_rule(n: Node):
+    """
+    Generates equality constraints between the dimensions of the input and output
+    that will not be involved in the flatten operation
+    """
     assert isinstance(n.args[0], Node)
 
     eq_const = []
@@ -674,6 +746,10 @@ def flatten_refinement_rule(n: Node):
 
 @register_algebraic_expressions_inference_rule(Conv2d)
 def conv_rule(n: Node, module_instance):
+    """
+    Represents the output in terms of an algebraic expression w.r.t.
+    the input when possible
+    """
     assert isinstance(n.args[0], Node)
     arg_type = n.args[0].type
     if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType):
-- 
2.7.4
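A hedged end-to-end sketch of how the documented rules are typically exercised. The GraphTypeChecker constructor and entry point below are assumed from context (the class appears in the hunks above); verify the exact signatures against the module and its tests before relying on them.

import torch
from torch.fx import symbolic_trace
from torch.fx.experimental.graph_gradual_typechecker import GraphTypeChecker
from torch.fx.tensor_type import Dyn, TensorType


class AddModule(torch.nn.Module):
    # input annotations become the placeholder node types read by the checker
    def forward(self, x: TensorType((2, Dyn)), y: TensorType((2, 3))):
        return torch.add(x, y)


traced = symbolic_trace(AddModule())
checker = GraphTypeChecker({}, traced)   # assumed (env, traced_module) signature
checker.type_check()                     # applies the registered inference rules
for node in traced.graph.nodes:
    print(node.name, node.type)          # nodes now carry the inferred types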