[Torch] Fix dtype handling for modules with integer parameters (#6311)
author    masahi <masahi129@gmail.com>
Fri, 21 Aug 2020 02:29:45 +0000 (11:29 +0900)
committer GitHub <noreply@github.com>
Fri, 21 Aug 2020 02:29:45 +0000 (07:59 +0530)
* return the correct type for GetAttr node

* keep _get_pytorch_value_type intact

* add test and handle quantized param
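To see the bug this commit fixes, consider a module whose parameter is an integer tensor. The sketch below (standalone, not part of the diff) shows how such a module reaches the frontend; it assumes the from_pytorch signature of this era, which takes a list of (input name, shape) pairs. Before this fix, the prim::GetAttr input feeding the add was assigned the default dtype (float32) instead of int64, so the converted ops saw the wrong types.

    import torch
    from tvm import relay

    class IntParamModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # an int64 parameter; requires_grad=False permits a non-float dtype
            self.param = torch.nn.Parameter(torch.ones(10, 10, dtype=torch.long),
                                            requires_grad=False)

        def forward(self, x):
            return x.long() + self.param

    inp = torch.ones(10, 10, dtype=torch.int)
    script_module = torch.jit.trace(IntParamModule(), inp).eval()
    # The parameter appears in the traced graph as a prim::GetAttr input,
    # whose scalarType() is unavailable; the fix recovers int64 elsewhere.
    mod, params = relay.frontend.from_pytorch(script_module, [("x", (10, 10))])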

python/tvm/relay/frontend/pytorch.py
tests/python/frontend/pytorch/test_forward.py

index 85dd5f4..8725a64 100644 (file)
@@ -2130,6 +2130,7 @@ def _report_missing_conversion(op_names, convert_map):
         msg = "The following operators are not implemented: {}".format(missing)
         raise NotImplementedError(msg)
 
+
 def _getattr_attr_name(node):
     attribute_names = node.attributeNames()
     assert len(attribute_names) == 1
@@ -2140,6 +2141,7 @@ def _getattr_attr_name(node):
 def _getattr_full_name(getattrs):
     return ".".join([_getattr_attr_name(node) for node in getattrs])
 
+
 def _get_pytorch_value_type(typ, default_dtype="float32"):
     kind = typ.kind()
     if kind == 'TensorType':
@@ -2162,16 +2164,25 @@ def _get_pytorch_value_type(typ, default_dtype="float32"):
         return 'UnsupportedType'
 
 
-def _get_input_types(op_node, default_dtype="float32"):
+def _get_input_types(op_node, outputs, default_dtype="float32"):
     """Returns a TVM dtype for each input nodes derived from the torch type"""
-    return [_get_pytorch_value_type(i.type(), default_dtype=default_dtype)
-            for i in op_node.inputs()]
-
+    in_types = []
+    for inp in op_node.inputs():
+        if inp.node().kind() == "prim::GetAttr":
+            # A GetAttr node always returns None when we call scalarType() on it
+            name = inp.debugName()
+            assert name in outputs
+            if isinstance(outputs[name], _expr.Var):
+                in_types.append(outputs[name].type_annotation.dtype)
+            else:
+                # For quantized modules with parameters, we get
+                # prim::GetAttr[name="_packed_params"] here. Since quantized ops do not
+                # need the dtype of _packed_params, we return an arbitrary type.
+                in_types.append(default_dtype)
+        else:
+            in_types.append(_get_pytorch_value_type(inp.type(), default_dtype=default_dtype))
 
-def _get_output_types(op_node, default_dtype="float32"):
-    """Returns a TVM dtype for each input nodes derived from the torch type"""
-    return [_get_pytorch_value_type(i.type(), default_dtype=default_dtype)
-            for i in op_node.outputs()]
+    return in_types
 
 
 def _get_constant(node):
@@ -2575,7 +2586,8 @@ def convert_operators(operators, outputs, ret_names, convert_map, prelude, defau
             outputs.update(zip(unpacked_names, loop_out))
         else:
             relay_op = convert_map[operator]
-            relay_out = relay_op(inputs, _get_input_types(op_node, default_dtype=default_dtype))
+            relay_out = relay_op(inputs, _get_input_types(op_node, outputs,
+                                                          default_dtype=default_dtype))
 
             if isinstance(relay_out, tuple):
                 # This is for torch operators that return multiple outputs
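The new branch reads the dtype off the Relay variable that was already created for the parameter. A minimal standalone sketch of that lookup (illustrative, not part of the diff):

    from tvm import relay

    v = relay.var("param", shape=(10, 10), dtype="int64")
    # _get_input_types appends exactly this for a GetAttr input bound to a Var
    assert v.type_annotation.dtype == "int64"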
index d5b4ed2..e5c9634 100644 (file)
@@ -2550,6 +2550,19 @@ def test_forward_dtypes():
         tensor2 = torch.randn(3, 4).to(dtype=dt)
         verify_model(fn, input_data=[tensor1, tensor2])
 
+    class ModuleWithIntParameters(Module):
+        def __init__(self, arr):
+            super().__init__()
+            self.param = torch.nn.Parameter(torch.LongTensor(arr), requires_grad=False)
+
+        def forward(self, x):
+            return x.long() + self.param
+
+    shape = (10, 10)
+    param = torch.ones(shape, dtype=torch.long)
+    inp = torch.ones(shape, dtype=torch.int)
+    verify_model(ModuleWithIntParameters(param), input_data=inp)
+
 
 def test_weight_names():
     tm = torch.jit.trace(torch.nn.Linear(3, 4), [torch.randn(2, 3)])
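The new integer-parameter case can be exercised on its own with a standard pytest invocation (assuming a local TVM and PyTorch setup):

    pytest tests/python/frontend/pytorch/test_forward.py::test_forward_dtypes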