msg = "The following operators are not implemented: {}".format(missing)
raise NotImplementedError(msg)
+
def _getattr_attr_name(node):
    attribute_names = node.attributeNames()
    assert len(attribute_names) == 1
    return node.s(attribute_names[0])
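+# e.g. a chain of prim::GetAttr nodes reading self.conv.weight joins the
+# per-node attribute names into the full name "conv.weight"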
def _getattr_full_name(getattrs):
return ".".join([_getattr_attr_name(node) for node in getattrs])
+
def _get_pytorch_value_type(typ, default_dtype="float32"):
    kind = typ.kind()
    if kind == "TensorType":
        # The tensor dtype can be unknown (e.g. under torch.jit.script);
        # fall back to default_dtype in that case.
        scalar_type = typ.scalarType()
        return default_dtype if scalar_type is None else _convert_data_type(scalar_type)
    return "UnsupportedType"
-def _get_input_types(op_node, default_dtype="float32"):
+def _get_input_types(op_node, outputs, default_dtype="float32"):
"""Returns a TVM dtype for each input nodes derived from the torch type"""
-    return [_get_pytorch_value_type(i.type(), default_dtype=default_dtype)
-            for i in op_node.inputs()]
-
+    in_types = []
+    for inp in op_node.inputs():
+        if inp.node().kind() == "prim::GetAttr":
+            # Calling scalarType() on a GetAttr input always returns None;
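+            # e.g. self.param appears in the graph as
+            # %param : Tensor = prim::GetAttr[name="param"](%self),
+            # so the dtype is recovered from the Relay vars recorded in outputs.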
+            name = inp.debugName()
+            assert name in outputs
+            if isinstance(outputs[name], _expr.Var):
+                in_types.append(outputs[name].type_annotation.dtype)
+            else:
+                # For quantized modules with parameters, here we would get
+                # "prim::GetAttr[name="_packed_params"]". Since the dtype corresponding to
+                # _packed_params is not needed by quantized ops, we return an arbitrary type.
+                in_types.append(default_dtype)
+        else:
+            in_types.append(_get_pytorch_value_type(inp.type(), default_dtype=default_dtype))
-def _get_output_types(op_node, default_dtype="float32"):
-    """Returns a TVM dtype for each input nodes derived from the torch type"""
-    return [_get_pytorch_value_type(i.type(), default_dtype=default_dtype)
-            for i in op_node.outputs()]
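+    # e.g. for the aten::add in x.long() + self.param with an int64 parameter,
+    # the GetAttr input contributes "int64" here instead of a float32 fallback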
+    return in_types
def _get_constant(node):
            outputs.update(zip(unpacked_names, loop_out))
        else:
            relay_op = convert_map[operator]
-            relay_out = relay_op(inputs, _get_input_types(op_node, default_dtype=default_dtype))
+            relay_out = relay_op(inputs, _get_input_types(op_node, outputs,
+                                                          default_dtype=default_dtype))
            if isinstance(relay_out, tuple):
                # This is for torch operators that return multiple outputs
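+                # (e.g. aten::topk returns a (values, indices) pair)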
        tensor2 = torch.randn(3, 4).to(dtype=dt)
        verify_model(fn, input_data=[tensor1, tensor2])
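+
+    # Integer (int64) parameter: exercises the prim::GetAttr dtype lookup
+    # added to _get_input_types above.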
+    class ModuleWithIntParameters(Module):
+        def __init__(self, arr):
+            super().__init__()
+            self.param = torch.nn.Parameter(torch.LongTensor(arr), requires_grad=False)
+
+        def forward(self, x):
+            return x.long() + self.param
+
+    shape = (10, 10)
+    param = torch.ones(shape, dtype=torch.long)
+    inp = torch.ones(shape, dtype=torch.int)
+    verify_model(ModuleWithIntParameters(param), input_data=inp)
+
def test_weight_names():
    tm = torch.jit.trace(torch.nn.Linear(3, 4), [torch.randn(2, 3)])