def get_target_type_str(node: Node, gm: GraphModule) -> str:
"""
Returns a string representation of the type of the function or module
- pointed to by this node, or '' for other op types.
+ pointed to by this node, or '' for other node types.
"""
target_type = ""
if node.op in ("call_function", "call_method"):
- target_type = str(node.target)
+ target_type = torch.typename(node.target)
elif node.op == "call_module":
assert isinstance(node.target, str)
target_mod = getattr_from_fqn(gm, node.target)
- target_type = str(type(target_mod))
+ target_type = torch.typename(target_mod)
return target_type
op_to_type_to_weight_extraction_fn = get_op_to_type_to_weight_extraction_fn()
ref_node_type = get_target_type_str(node, gm)
+ # for extracting weights, these are always the same
+ prev_node_type = ref_node_type
if node.op == 'call_function':
function_mapping = op_to_type_to_weight_extraction_fn['call_function']
'type': res_type,
'values': [weight],
'prev_node_name': node.name,
- 'prev_node_target_type': str(node.target),
+ 'prev_node_target_type': prev_node_type,
'ref_node_name': node.name,
'ref_node_target_type': ref_node_type,
'index_within_arg': 0,
'type': res_type,
'values': [weight],
'prev_node_name': node.name,
- 'prev_node_target_type': str(type(mod)),
+ 'prev_node_target_type': prev_node_type,
'ref_node_name': node.name,
'ref_node_target_type': ref_node_type,
'index_within_arg': 0,