Expand a type to the desired tensor dimension if possible
Raise an error otherwise.
- t is the given type
- - n is a number to expand to
+ - n is a number of dimensions to expand to
"""
if t == Dyn:
dims = [Dyn] * n
def broadcast_types(t1, t2):
+ """
+ Applies broadcasting to both given types such that they
+ become consistent with each other and returns two new
+ resulting types
+ """
+
+ # if either type is Dyn, do nothing since the types are already consistent
if t1 == Dyn or t2 == Dyn or isinstance(t1, Var) or isinstance(t2, Var):
return t1, t2
new_t1 = list(t1.__args__)
new_t2 = list(t2.__args__)
- # here, we make our tensors the same length
+ # We make the types the same length which is the first requirement
+ # for consistency
if s1 > s2:
for i in range(s1 - s2):
new_t2.insert(0, 1)
for i in range(s2 - s1):
new_t1.insert(0, 1)
+ # we replace occurrences of "1" in each tensor with
+ # the corresponding type from the other tensor
for i, (x, y) in enumerate(zip(new_t1, new_t2)):
if x == 1:
new_t1[i] = y
elif y == 1:
new_t2[i] = x
+ # at this point our tensors should be consistent with each other,
+ # so the caller can apply the element-wise operation and find the right
+ # dimensions for the output of the operation
(t1, t2) = TensorType(tuple(new_t1)), TensorType(tuple(new_t2))
-
-
return (t1, t2)
else:
raise TypeError(f'Cannot broadcast types {t1} and {t2}')
def register_inference_rule(call_target):
def register(fn):
if call_target in _INFERENCE_RULES:
- raise RuntimeError('Inference rule already registered for {call_target}!')
+ raise RuntimeError(f'Inference rule already registered for {call_target}!')
_INFERENCE_RULES[call_target] = fn
return fn
return register
def register_refinement_rule(call_target):
def register(fn):
if call_target in _REFINEMENT_RULES:
- raise RuntimeError('Refinement rule already registered for {call_target}!')
+ raise RuntimeError(f'Refinement rule already registered for {call_target}!')
_REFINEMENT_RULES[call_target] = fn
return fn
return register
def register_algebraic_expressions_inference_rule(call_target):
def register(fn):
if call_target in _RULES:
- raise RuntimeError('Rule already registered for {call_target}!')
+ raise RuntimeError(f'Rule already registered for {call_target}!')
_RULES[call_target] = fn
return fn
return register
@register_inference_rule(torch.add)
@register_inference_rule(operator.add)
def add_inference_rule(n: Node):
+ """
+ Apply the addition inference rule. This includes:
+ - scalar addition
+ - broadcasting semantics
+
+ Note that we always return the least precise type between
+ the operands (after applying broadcasting) as the final type of the operation
+
+ Note that we do not modify the operand types themselves after applying broadcasting;
+ we only use the broadcast types to calculate the final type
+ """
assert isinstance(n.args[0], Node)
assert isinstance(n.args[1], Node)
t1 = n.args[0].type
n.type = t2
return n.type
+ # handle scalar addition
elif t2 == int and isinstance(t1, TensorType):
n.type = t1
return n.type
+ # we bring the new types to the point where
+ # we can check for consistency;
+ # any remaining inconsistency at this point
+ # could not have been caused by broadcasting
(new_t1, new_t2) = broadcast_types(t1, t2)
if new_t1 != t1 or new_t2 != t2:
n.meta[str(n.args[0])] = new_t1
n.meta[str(n.args[1])] = new_t2
- # Todo: maybe figure out that broadcasting definitely did not happen?
else:
n.meta['broadcast'] = False
new_t1 = t1 if not n.meta['broadcast'] else new_t1
new_t2 = t2 if not n.meta['broadcast'] else new_t2
+ # we check for consistency between the new types
if is_consistent(new_t1, new_t2):
# we return the less precise type because
# broadcasting may have happened
@register_inference_rule(getattr)
def get_attr_inference_rule(n: Node, traced):
+ """
+ The current getattr rule only handles the shape attribute.
+ It can be extended to other attributes.
+ The most representative type we have is "Dyn", but the system
+ can be extended with more types, such as a type to represent shapes.
+ """
attr_node = n.args[0]
attr_name = n.args[1]
@register_inference_rule(torch.transpose)
def transpose_inference_rule(n: Node):
+ """
+ We check that the dimensions for the transpose operation
+ are within the range of the tensor type of the node
+ """
if n.target == torch.transpose:
assert isinstance(n.args[0], Node)
t = n.args[0].type
return n.type
elif isinstance(t, TensorType):
-
if 0 <= dim1 < len(t.__args__) and 0 <= dim2 < len(t.__args__):
new_type = list(t.__args__)
new_type[dim1], new_type[dim2] = new_type[dim2], new_type[dim1]
final = TensorType(new_type)
- n.type = final
+ n.type = get_greatest_upper_bound(n.type, final)
return n.type
else:
raise TypeError(f'Cannot transpose {dim1} and {dim2} in type {t} for node {n}')
@register_inference_rule(torch.reshape)
def reshape_inference_rule(n: Node):
+ """
+ Without dynamism, the rule checks that the
+ product of the elements of the argument tensor
+ type is equal to the product of the elements
+ of the required shape. We gradualize this rule
+ by adding a case to handle fully dynamic input
+ as well as input where some of the tensor dimensions
+ are unknown. In this case we check for divisibility
+ """
assert isinstance(n.args[0], Node)
t1 = n.args[0].type
# if any of the dimensions are unknown,
# we check for divisibility
- elif isinstance(t1, TensorType) and Dyn in t1.__args__ or -1 in t2:
+ elif isinstance(t1, TensorType):
assert isinstance(t1, TensorType)
a = [e if e != Dyn else 1 for e in t1.__args__]
p1 = reduce(lambda x, y: x * y, a)
return t2_type
else:
raise TypeError(f'Cannot reshape in node {n} from {t1} to {t2_type}')
-
- # if all dimensions are known we check the products
- elif isinstance(t1, TensorType):
- p1 = reduce(lambda x, y: x * y, t1.__args__)
- p2 = reduce(lambda x, y: x * y, t2)
- if p1 == p2:
- n.type = t2_type
- return t2_type
- else:
- raise TypeError(f'Cannot reshape in node {n} from {t1} to {t2_type}')
-
else:
raise TypeError(f'Cannot reshape in node {n} from {t1} to {t2_type}')
def calculate_out_dimension(d_in, module_instance, index):
"""
- For calculating h_in and w_out.
+ For calculating h_out and w_out according to the conv2D documentation
"""
padding = (module_instance.padding, module_instance.padding) \
if isinstance(module_instance.padding, int) else module_instance.padding
def maxpool2d_check(typ, module_instance):
+ """
+ Applies the maxpool2d shape information to the input;
+ this affects the last two dimensions
+ """
new_type_list = list(typ.__args__)
if len(new_type_list) == 4 or len(new_type_list) == 3:
w_in = new_type_list[-1]
"""
if len(tensor_type.__args__) >= 2:
if is_consistent(module_instance.in_features, tensor_type.__args__[-1]):
- # Todo backwards propagation
new_type_args = list(tensor_type.__args__)
new_type_args[-1] = module_instance.out_features
return TensorType(tuple(new_type_args))
@register_inference_rule(torch.nn.Linear)
def linear_inference_rule(n: Node, module_instance):
+ """
+ Applies the shape information to the input, then gets the greatest upper bound
+ of the resulting type and the existing type
+ """
assert isinstance(n.args[0], Node)
if n.args[0].type == Dyn and isinstance(n.type, TensorType):
n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__))
@register_inference_rule(torch.flatten)
def flatten_inference_rule(n: Node):
+ """
+ Applies the flatten shape information to the input, then gets the
+ greatest upper bound of the resulting type and the existing type
+ """
assert isinstance(n.args[0], Node)
# set the default start and end dims
@register_refinement_rule(Conv2d)
def conv_refinement_rule(n: Node):
+ """
+ The equality constraints are between the first dimension of
+ the input and output
+ """
res = []
assert isinstance(n.args[0], Node)
arg_type = n.args[0].type
@register_refinement_rule(torch.nn.Linear)
def linear_refinement_rule(n: Node):
+ """
+ The equality constraints are between the first dimension of
+ the input and output
+ """
res = []
assert isinstance(n.args[0], Node)
arg_type = n.args[0].type
res = [Equality(arg_type.__args__[0], n.type.__args__[0])]
return res
-# todo needs review for addition. Is this constraint correct?
@register_refinement_rule(BatchNorm2d)
@register_refinement_rule(torch.nn.ReLU)
def all_eq(n: Node):
+ """
+ For operations where the input shape is equal to the output shape
+ """
res = []
assert isinstance(n.args[0], Node)
arg_type = n.args[0].type
@register_refinement_rule(torch.nn.AdaptiveAvgPool2d)
-def first_two__eq(n: Node):
+@register_refinement_rule(torch.nn.MaxPool2d)
+def first_two_eq(n: Node):
+ """
+ For operations where the first two dimensions of the input and output shape
+ are equal
+ """
res = []
assert isinstance(n.args[0], Node)
arg_type = n.args[0].type
res = [Equality(args1[0], args2[0]), Equality(args1[1], args2[1])]
return res
+
@register_refinement_rule(torch.add)
@register_refinement_rule(operator.add)
-def add_eq(n: Node):
+def element_wise_eq(n: Node):
+ """
+ For element-wise operations; handles broadcasting.
+ Note that after applying broadcasting to the arguments
+ we are able to determine that certain dimensions have not been
+ broadcast if they are symbolically equal.
+
+ In this case, we can establish equality between those dimensions and the
+ corresponding output dimensions.
+
+ Note that it takes two iterations to arrive at this result. One iteration establishes
+ equality between certain dimensions of the operands (requiring the whole solver,
+ including unification) and another iteration establishes equality between the operands
+ and the resulting type, requiring another round of constraint generation and unification.
+ """
res = []
if isinstance(n.args[0], Node) and isinstance(n.args[1], Node):
arg_type1 = n.args[0].type
arg_type2 = n.args[1].type
if isinstance(arg_type1, TensorType) and isinstance(arg_type2, TensorType) and isinstance(n.type, TensorType):
args1, args2 = broadcast_types(arg_type1, arg_type2)
- # by this point, we know for sure that args1 and args2 are the same size.
+ # by this point, we know that args1 and args2 are the same size.
a1 = args1.__args__
a2 = args2.__args__
a3 = n.type.__args__
+
+ # we would be here in the second iteration where we establish equality
+ # between operand type dimensions and the resulting type dimensions
r = []
for x, y, z in zip(a1, a2, a3):
if x == y:
res = r
return res
-@register_refinement_rule(torch.nn.MaxPool2d)
-def first_two(n: Node):
- res = []
- assert isinstance(n.args[0], Node)
- arg_type = n.args[0].type
- if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType):
- args1 = arg_type.__args__
- args2 = n.type.__args__
- res = [Equality(args1[0], args2[0]), Equality(args1[1], args2[1])]
- return res
@register_refinement_rule(torch.flatten)
def flatten_refinement_rule(n: Node):
+ """
+ Generates equality constraints between the dimensions of the input and output
+ that will not be involved in the flatten operation
+ """
assert isinstance(n.args[0], Node)
eq_const = []
@register_algebraic_expressions_inference_rule(Conv2d)
def conv_rule(n: Node, module_instance):
+ """
+ Represents the output in terms of an algebraic expression w.r.t.
+ the input when possible
+ """
assert isinstance(n.args[0], Node)
arg_type = n.args[0].type
if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType):