def flatten_check(tensor_type, start_dim, end_dim):
l = len(tensor_type.__args__)
- start_dim = l if start_dim == -1 else start_dim
+ start_dim = l if start_dim == -1 else abs(start_dim)
end_dim = l + end_dim + 1 if end_dim < 0 else end_dim + 1
if 0 <= start_dim <= (l - 1) and 0 <= end_dim <= l and start_dim < end_dim:
elif isinstance(typ, TensorType):
new_args = [self.replace_dyn_with_fresh_var(a) for a in typ.__args__]
return TensorType(tuple(new_args))
+ elif isinstance(typ, list):
+ return [self.replace_dyn_with_fresh_var(t) for t in typ]
+ elif isinstance(typ, tuple):
+ return (self.replace_dyn_with_fresh_var(t) for t in typ)
else:
return typ
pass
if n.op == 'output':
- assert isinstance(n.args[0], Node)
- n.type = n.args[0].type
+ def get_node_type(a):
+ return a.type
+ n.type = torch.fx.node.map_arg(n.args[0], get_node_type)
+ return n.type
else:
pass
if isinstance(t, Var):
if t in mapping.keys():
return mapping[t]
+ else:
+ return t
elif isinstance(t, TensorType):
new_type = []
new_type.append(typ)
return TensorType(tuple(new_type))
+ elif isinstance(t, list):
+ new_type = []
+ for typ in t:
+ new_type.append(substitute_solution_one_type(mapping, typ))
+ return new_type
+
+ elif isinstance(t, tuple):
+ new_type = []
+ for typ in t:
+ new_type.append(substitute_solution_one_type(mapping, typ))
+ return tuple(new_type)
+
+ else:
+ return t
+
def substitute_all_types(graph, mapping):
"""