Improve type handling in PyTorch frontend (#5834)
authorThomas Viehmann <tv.code@beamnet.de>
Mon, 22 Jun 2020 13:33:04 +0000 (15:33 +0200)
committerGitHub <noreply@github.com>
Mon, 22 Jun 2020 13:33:04 +0000 (19:03 +0530)
* Improve type handling in PyTorch frontend

- Use type information from graph for inputs if available. Check
  against shape information from graph if available.
- Allow user to set default dtype (default to float32 for sanity and
  compatibility).
- Implement type promotion to follow PyTorch mechanism. This includes
  fixing the handling of many "Scalar" overloads in PyTorch binary ops.
- Fix arange/linspace type semantics.
- Added support for traced functions. (Because it really is about the
  "self" input handling.)

Aside from adding an optional default_dtype keyword argument, this does not
change the signature/requirements of from_pytorch.

* Fix scalar detection using numpy.isscalar

and address other review comments. Thank you @siju-samuel

* refine test criteron on qnn_test::test_serialized_modules, fix bool conversion of const

python/tvm/relay/frontend/pytorch.py
tests/python/frontend/pytorch/qnn_test.py
tests/python/frontend/pytorch/test_forward.py

index d3b6510..f70a64a 100644 (file)
@@ -126,18 +126,7 @@ def _is_quantized_tensor(data, prelude):
 # operator implementation
 def _elemwise(name):
     def _impl(inputs, input_types):
-        # TODO: Figure out a better way to get typing to work for tensor + scalar
-        type0 = input_types[0]
-        if isinstance(inputs[1], _expr.Expr):
-            type0 = input_types[1]
-
-        type1 = input_types[1]
-        if isinstance(inputs[0], _expr.Expr):
-            type1 = input_types[0]
-
-        data0 = _convert_elemwise_input(inputs[0], type0)
-        data1 = _convert_elemwise_input(inputs[1], type1)
-
+        data0, data1 = _pytorch_promote_types(inputs[:2], input_types[:2])
         return get_relay_op(name)(data0, data1)
     return _impl
 
@@ -145,8 +134,8 @@ def _elemwise(name):
 def _unary(name):
     def _impl(inputs, input_types):
         input_type = input_types[0]
-        data = _convert_elemwise_input(inputs[0], input_type)
-
+        # this is just to ensure tensor input
+        data, = _pytorch_promote_types(inputs[:1], input_types[:1])
         return get_relay_op(name)(data)
     return _impl
 
@@ -154,7 +143,8 @@ def _unary(name):
 def _log1p():
     def _impl(inputs, input_types):
         # 1_plus_log x = log(x + 1)
-        one = _expr.const(1, dtype="float32")
+        dtype, = input_types
+        one = _expr.const(1, dtype=dtype)
         return _op.log(inputs[0] + one)
     return _impl
 
@@ -162,25 +152,40 @@ def _log1p():
 def _arange():
     def _impl(inputs, input_types):
         def _get_value(val, dtype):
+            # dtype is a tvm dtype
             if isinstance(val, _expr.Expr):
-                return _op.cast(val, _convert_data_type(dtype))
+                return _op.cast(val, dtype)
             return _create_typed_const(val, dtype)
 
         def _get_type(val, inp_type):
             if isinstance(val, _expr.Expr):
                 dtype = str(_infer_type(val).checked_type)
-                return dtype if dtype != "float32" else "float"
+                return dtype
             return inp_type
 
+        # PyTorch arange uses the following type semantics:
+        # - if a dtype is given, start, stop, step are converted to that dtype
+        # - if no dtype is given and all args are integral, dtype is int64
+        # - if no dtype is given and there is a float arg, dtype is float32
         if len(inputs) == 5:
             dtype0 = _get_type(inputs[0], input_types[0])
-            dtype = "float" if dtype0 == "float" else _convert_dtype_value(inputs[1])
+            if inputs[1] is not None:
+                dtype = _convert_dtype_value(inputs[1])
+            elif dtype0.startswith("float"):
+                dtype = "float32"
+            else:
+                dtype = "int64"
             start = _get_value(0, dtype)
             stop = _get_value(inputs[0], dtype)
             step = _get_value(1, dtype)
         elif len(inputs) == 7:
             types = [_get_type(inputs[i], input_types[i]) for i in range(3)]
-            dtype = "float" if "float" in types else _convert_dtype_value(inputs[3])
+            if inputs[3] is not None:
+                dtype = _convert_dtype_value(inputs[3])
+            elif any([t.startswith("float") for t in types]):
+                dtype = "float32"
+            else:
+                dtype = "int64"
             start = _get_value(inputs[0], dtype)
             stop = _get_value(inputs[1], dtype)
             step = _get_value(inputs[2], dtype)
@@ -191,7 +196,7 @@ def _arange():
         return _op.transform.arange(start=start,
                                     stop=stop,
                                     step=step,
-                                    dtype=_convert_data_type(dtype))
+                                    dtype=dtype)
     return _impl
 
 def _squeeze():
@@ -200,6 +205,7 @@ def _squeeze():
         if len(inputs) == 1:
             axis = None
         else:
+            # TODO (t-vi): why is the cast to int needed? similarly elsewhere
             axis = [int(inputs[1])]
 
         return _op.transform.squeeze(data, axis)
@@ -295,7 +301,7 @@ def _split():
     return _impl
 
 def _split_with_sizes():
-    def _impl(inputs, inputs_types):
+    def _impl(inputs, input_types):
         data = inputs[0]
         dim = int(inputs[2])
 
@@ -345,7 +351,7 @@ def _topk():
 def _reciprocal():
     def _impl(inputs, input_types):
         data = inputs[0]
-        return _expr.const(1.0) / data
+        return _expr.const(1.0, dtype=input_types[0]) / data
     return _impl
 
 def _repeat():
@@ -373,22 +379,14 @@ def _repeat_interleave():
 
 def _addcdiv():
     def _impl(inputs, input_types):
-        data = inputs[0]
-        c = _expr.const(inputs[3])
-        t1 = inputs[1]
-        t2 = inputs[2]
-
+        data, t1, t2, c = _pytorch_promote_types(inputs[:4], input_types[:4])
         return data + (c * (t1 / t2))
     return _impl
 
 
 def _addcmul():
     def _impl(inputs, input_types):
-        data = inputs[0]
-        c = _expr.const(inputs[3])
-        t1 = inputs[1]
-        t2 = inputs[2]
-
+        data, t1, t2, c = _pytorch_promote_types(inputs[:4], input_types[:4])
         return data + (c * (t1 * t2))
     return _impl
 
@@ -396,9 +394,7 @@ def _addcmul():
 def _where():
     def _impl(inputs, input_types):
         cond = inputs[0]
-        x = inputs[1]
-        y = inputs[2]
-
+        x, y = _pytorch_promote_types(inputs[1:3], input_types[1:3])
         return _op.where(cond, x, y)
 
     return _impl
@@ -419,7 +415,7 @@ def _ones():
             msg = "Data type %s could not be parsed in ones op" % (type(data))
             raise AssertionError(msg)
 
-        dtype = _convert_data_type(_convert_dtype_value(inputs[1]))
+        dtype = _convert_dtype_value(inputs[1])
 
         return _op.full(_expr.const(1), shape, dtype=dtype)
     return _impl
@@ -430,8 +426,8 @@ def _ones_like():
         out = _op.ones_like(data)
 
         # If the input and the output datatype is different, do a cast
-        dtype = _convert_data_type(_convert_dtype_value(inputs[1]))
-        if input_types[0] not in dtype:
+        dtype = _convert_dtype_value(inputs[1])
+        if input_types[0] != dtype:
             out = _op.cast(out, dtype)
 
         return out
@@ -453,7 +449,7 @@ def _zeros():
             msg = "Data type %s could not be parsed in zeros op" % (type(data))
             raise AssertionError(msg)
 
-        dtype = _convert_data_type(_convert_dtype_value(inputs[1]))
+        dtype = _convert_dtype_value(inputs[1])
 
         return _op.full(_expr.const(0), shape, dtype=dtype)
     return _impl
@@ -465,7 +461,7 @@ def _zeros_like():
         out = _op.zeros_like(data)
 
         # If the input and the output datatype is different, do a cast
-        dtype = _convert_data_type(_convert_dtype_value(inputs[1]))
+        dtype = _convert_dtype_value(inputs[1])
         if input_types[0] not in dtype:
             out = _op.cast(out, dtype)
 
@@ -490,7 +486,7 @@ def _full():
             raise AssertionError(msg)
 
         if inputs[2] is not None: # dtype given
-            dtype = _convert_data_type(_convert_dtype_value(inputs[2]))
+            dtype = _convert_dtype_value(inputs[2])
         else:
             dtype = data.type_annotation.dtype
 
@@ -505,7 +501,7 @@ def _full_like():
         out = _op.full_like(data, _expr.const(fill_value))
 
         # If the input and the output datatype is different, do a cast
-        dtype = _convert_data_type(_convert_dtype_value(inputs[2]))
+        dtype = _convert_dtype_value(inputs[2])
         if input_types[0] not in dtype:
             out = _op.cast(out, dtype)
 
@@ -526,7 +522,8 @@ def _linspace():
         else:
             stop = start + step
 
-        dtype = "float" if "float" in input_types[0:3] else _convert_dtype_value(inputs[3])
+        dtype = ("float32" if inputs[3] is not None
+                 else _convert_dtype_value(inputs[3]))
         start = _create_typed_const(start, dtype)
         stop = _create_typed_const(stop, dtype)
         step = _create_typed_const(step, dtype)
@@ -534,7 +531,7 @@ def _linspace():
         return _op.transform.arange(start=start,
                                     stop=stop,
                                     step=step,
-                                    dtype=_convert_data_type(dtype))
+                                    dtype=dtype)
     return _impl
 
 
@@ -565,35 +562,41 @@ def _leaky_relu():
 def _elu():
     def _impl(inputs, input_types):
         data = inputs[0]
-        alpha = _expr.const(float(inputs[1]))
-        return alpha * _op.nn.relu(_expr.const(1.0) - _op.exp(data)) + _op.nn.relu(data)
+        dtype = input_types[0]
+        alpha = _expr.const(float(inputs[1]), dtype=dtype)
+        return alpha * _op.nn.relu(_expr.const(1, dtype=dtype) - _op.exp(data)) + _op.nn.relu(data)
     return _impl
 
 def _celu():
     def _impl(inputs, input_types):
         data = inputs[0]
-        alpha = _expr.const(float(inputs[1]))
-        return alpha * _op.nn.relu(_expr.const(1.0) - _op.exp(data / alpha)) + _op.nn.relu(data)
+        dtype = input_types[0]
+        alpha = _expr.const(float(inputs[1]), dtype=dtype)
+        return alpha * _op.nn.relu(_expr.const(1, dtype=dtype)
+                                   - _op.exp(data / alpha)) + _op.nn.relu(data)
     return _impl
 
 def _gelu():
     def _impl(inputs, input_types):
         data = inputs[0]
+        dtype = input_types[0]
         # gelu is data  * normcdf(data)
         # normcdf expressed as erf because we don't currently have that intrinsic
         # note that there is also a fastgelu variant approximating normcdf
         # with tanh and third order polynomials, but this is "true" gelu
-        return data * (_expr.const(0.5) +
-                       _op.erf(data * _expr.const(0.5**0.5)) * _expr.const(0.5))
+        return data * (_expr.const(0.5, dtype=dtype) +
+                       _op.erf(data * _expr.const(0.5**0.5, dtype=dtype))
+                       * _expr.const(0.5, dtype=dtype))
     return _impl
 
 def _selu():
     def _impl(inputs, input_types):
         data = inputs[0]
         # https://pytorch.org/docs/stable/nn.html#selu
-        alpha = _expr.const(-1.6732632423543772848170429916717)
-        gamma = _expr.const(1.0507009873554804934193349852946)
-        return gamma * (alpha * _op.nn.relu(_expr.const(1.0)
+        dtype = input_types[0]
+        alpha = _expr.const(-1.6732632423543772848170429916717, dtype=dtype)
+        gamma = _expr.const(1.0507009873554804934193349852946, dtype=dtype)
+        return gamma * (alpha * _op.nn.relu(_expr.const(1.0, dtype=dtype)
                                             - _op.exp(data)) + _op.nn.relu(data))
     return _impl
 
@@ -1112,8 +1115,9 @@ def _sigmoid():
 def _softplus():
     def _impl(inputs, input_types):
         data = inputs[0]
-        beta = _expr.const(float(inputs[1]))
-        return _op.log(_op.exp(inputs[0] * beta) + _expr.const(1.)) / beta
+        dtype = input_types[0]
+        beta = _expr.const(float(inputs[1]), dtype=dtype)
+        return _op.log(_op.exp(inputs[0] * beta) + _expr.const(1., dtype=dtype)) / beta
     return _impl
 
 def _avg_pool2d(prelude):
@@ -1195,6 +1199,7 @@ def _reduce(name):
 def _norm():
     def _impl(inputs, input_types):
         data = inputs[0]
+        dtype = input_types[0]
         axis = None
         keepdims = False
         if len(inputs) > 3:
@@ -1207,7 +1212,7 @@ def _norm():
         elif order == np.NINF:
             return _op.reduce.min(_op.abs(data), axis=axis, keepdims=keepdims)
         else:
-            reci_order = _expr.const(1.0 / order)
+            reci_order = _expr.const(1.0 / order, dtype=dtype)
             order = _expr.const(order)
             return _op.power(_op.reduce.sum(_op.power(_op.abs(data), order),
                                             axis=axis,
@@ -1239,7 +1244,7 @@ def _std():
 
         if unbiased:
             msg = "Currently only supports standard-deviation calculated via the biased "\
-                  "estimator. Pytorch's Bessel's correction is not supported."
+                  "estimator. PyTorch's Bessel's correction is not supported."
             raise NotImplementedError(msg)
 
         return _op.reduce.std(data, axis=axis, keepdims=keepdims)
@@ -1255,7 +1260,7 @@ def _variance():
 
         if unbiased:
             msg = "Currently only supports standard-deviation calculated via the biased "\
-                  "estimator. Pytorch's Bessel's correction is not supported."
+                  "estimator. PyTorch's Bessel's correction is not supported."
             raise NotImplementedError(msg)
 
         return _op.reduce.variance(data, axis=axis, keepdims=keepdims)
@@ -1657,7 +1662,7 @@ def _type_as():
     def _impl(inputs, input_types):
         assert len(inputs) == 2
         assert len(input_types) == 2
-        return _op.cast(inputs[0], _convert_data_type(input_types[1]))
+        return _op.cast(inputs[0], input_types[1])
     return _impl
 
 
@@ -1687,20 +1692,13 @@ def _tensor_array_stack(prelude):
 
 def _rsub():
     def _impl(inputs, input_types):
-        # TODO: Figure out a better way to get typing to work for tensor + scalar
-        type0 = input_types[0]
-        if isinstance(inputs[1], _expr.Expr):
-            type0 = input_types[1]
+        data0, data1 = _pytorch_promote_types(inputs[:2], input_types[:2])
 
-        type1 = input_types[1]
-        if isinstance(inputs[0], _expr.Expr):
-            type1 = input_types[0]
-
-        data1 = _convert_elemwise_input(inputs[0], type0)
-        data0 = _convert_elemwise_input(inputs[1], type1)
+        # TODO (t-vi): should this also be part of the type promotion?
         alpha = _expr.const(float(inputs[2]))
 
-        return get_relay_op("subtract")(data0, alpha * data1)
+        # note: rsub means data0 and data1 swap places
+        return get_relay_op("subtract")(data1, alpha * data0)
     return _impl
 
 
@@ -1729,8 +1727,55 @@ def _one_hot():
     return _impl
 
 
+def _pytorch_result_type(dtypes, non_tensor_inputs):
+    """This promotes TVM dtypes like PyTorch would"""
+    import torch
+    dtype_map = {
+        "float64": torch.float64,
+        "float32": torch.float32,
+        "float16": torch.float16,
+        "bfloat16": torch.bfloat16,
+        "int64": torch.int64,
+        "int32": torch.int32,
+        "int16": torch.int16,
+        "int8": torch.int8,
+        "uint8": torch.uint8,
+        "bool": torch.bool
+        }
+    if len(dtypes) > 0:
+        result_type = dtypes[0]
+        for dt in dtypes[1:]:
+            if dt != result_type: # we don't want to work with same types as we
+                                  # don't do quantized here (which cannot be promoted?)
+                result_type = _convert_data_type(str(torch.result_type(
+                    torch.zeros((), dtype=dtype_map[result_type]),
+                    torch.zeros((), dtype=dtype_map[dt]))))
+    else:
+        result_type = "bool"  # this is the smallest type...
+    for inp in non_tensor_inputs:
+        result_type = _convert_data_type(
+            str(torch.result_type(torch.zeros((), dtype=dtype_map[result_type]),
+                                  inp)))
+    return result_type
+
+def _pytorch_promote_types(inputs, dtypes):
+    """This promotes TVM inputs with TVM dtypes passed like PyTorch would"""
+    tensor_dtypes = [dt for inp, dt in zip(inputs, dtypes) if not np.isscalar(inp)]
+    non_tensor_inputs = [inp for inp in inputs if np.isscalar(inp)]
+    result_type = _pytorch_result_type(tensor_dtypes, non_tensor_inputs)
+    results = []
+    for inp, dt in zip(inputs, dtypes):
+        if np.isscalar(inp):
+            results.append(_expr.const(inp, dtype=result_type))
+        elif dt == result_type:
+            results.append(inp)
+        else:
+            results.append(_op.cast(inp, result_type))
+    return results
+
 # Helper functions for operator implementation
 def _convert_dtype_value(val):
+    """converts a PyTorch the PyTorch numeric type id to a torch scalar type."""
     convert_torch_dtype_map = {7:"torch.float64",
                                6:"torch.float32",
                                5:"torch.float16",
@@ -1741,12 +1786,19 @@ def _convert_dtype_value(val):
                                0:"torch.unit8",
                                None:"torch.int64"} # Default is torch.int64
     if val in convert_torch_dtype_map:
-        return convert_torch_dtype_map[val]
+        return _convert_data_type(convert_torch_dtype_map[val])
     else:
         msg = "Torch data type value %d is not handled yet." % (val)
         raise NotImplementedError(msg)
 
-def _convert_data_type(input_type):
+def _convert_data_type(input_type, default_dtype=None):
+    """converts the PyTorch scalar type input_type to a TVM dtype.
+       optionally, default_dtype can be a TVM dtype that is used
+       if input_type is None (but not when it is unknown)"""
+    if input_type is None and default_dtype is not None:
+        return default_dtype
+
+    input_type = input_type.lower()
     if input_type in ["double", "torch.float64"]:
         return "float64"
     elif input_type in ["float", "torch.float32"]:
@@ -1763,12 +1815,21 @@ def _convert_data_type(input_type):
         return "int8"
     elif input_type in ["byte", "torch.uint8"]:
         return "uint8"
+    elif input_type in ["quint8", "torch.quint8"]:
+        return "quint8"
+    elif input_type in ["qint8", "torch.qint8"]:
+        return "qint8"
+    elif input_type in ["qint32", "torch.qint32"]:
+        return "qint32"
+    elif input_type in ["bool", "torch.bool"]:
+        return "bool"
     else:
-        raise NotImplementedError("input_type {} is not handled yet" % (input_type))
-    return "float32"
+        raise NotImplementedError("input_type {} is not handled yet".format(input_type))
+    return "float32"  # Never reached
 
-def _create_typed_const(data, data_type):
-    dtype = _convert_data_type(data_type)
+def _create_typed_const(data, dtype):
+    """create a (scalar) constant of given value and dtype.
+       dtype should be a TVM dtype"""
 
     if dtype == "float64":
         typed_data = _expr.const(np.float64(data), dtype=dtype)
@@ -1787,18 +1848,9 @@ def _create_typed_const(data, data_type):
     elif dtype == "uint8":
         typed_data = _expr.const(np.uint8(data), dtype=dtype)
     else:
-        raise NotImplementedError("input_type {} is not handled yet" % (data_type))
+        raise NotImplementedError("input_type {} is not handled yet".format(dtype))
     return typed_data
 
-def _convert_elemwise_input(data, input_type):
-    import torch
-    if isinstance(data, torch.Tensor):
-        return _expr.const(data.item(), dtype=_convert_data_type(input_type))
-    elif not isinstance(data, _expr.Expr):
-        return _expr.const(data, dtype=_convert_data_type(input_type))
-    else:
-        return data
-
 def _wrap_const(c):
     if not isinstance(c, (_expr.Expr, list, tvm.tir.expr.Any)):
         return _expr.const(c)
@@ -1891,6 +1943,7 @@ def _get_convert_map(prelude):
         "aten::mean"                            : _mean(prelude),
         "aten::chunk"                           : _chunk(prelude),
         "aten::matmul"                          : _matmul(prelude),
+        "aten::bmm"                             : _matmul(prelude),
         "aten::expand"                          : _expand(),
         "aten::Int"                             : _int(),
         "prim::NumToTensor"                     : _numtotensor(),
@@ -1981,12 +2034,13 @@ def _run_jit_passes(graph):
 
 
 def _is_int_seq(seq):
+    # TODO (t-vi): handle non-int constants? (like numpy.intXX)
     return len(seq) > 0 and all([isinstance(i, int) for i in seq])
 
 
 def _get_tensor_and_var(torch_tensor, name):
     tensor = tvm.nd.array(torch_tensor.cpu().numpy())
-    var = _expr.var(name, shape=tensor.shape)
+    var = _expr.var(name, shape=tensor.shape, dtype=tensor.dtype)
     return tensor, var
 
 
@@ -2039,35 +2093,6 @@ def _report_missing_conversion(op_names, convert_map):
         msg = "The following operators are not implemented: {}".format(missing)
         raise NotImplementedError(msg)
 
-
-def _check_inputs(graph, input_shapes):
-    """
-    Check the graph inputs match the expected number of inputs
-    and are in the correct format
-    """
-    ir_inputs = _get_graph_input_names(graph)
-
-    if not isinstance(input_shapes, list):
-        msg = "Graph inputs input_shapes should be list"
-        raise RuntimeError(msg)
-    missing_inputs = len(ir_inputs) - len(input_shapes)
-    if missing_inputs > 0:
-        msg = "Missing {} graph input(s) in input_shapes".format(missing_inputs)
-        raise RuntimeError(msg)
-
-    for num, inp in enumerate(input_shapes):
-        if num < len(ir_inputs):
-            if not isinstance(inp, tuple):
-                msg = "Graph input {} is not a tuple".format(num)
-                raise RuntimeError(msg)
-            if (len(inp) != 2 or not isinstance(inp[0], str)):
-                msg = "Graph input {} is not valid, expected ('name', shape)".format(inp)
-                raise RuntimeError(msg)
-        else:
-            msg = "Unused graph input {} in input_shapes".format(inp)
-            logging.warning(msg)
-
-
 def _getattr_attr_name(node):
     attribute_names = node.attributeNames()
     assert len(attribute_names) == 1
@@ -2078,37 +2103,38 @@ def _getattr_attr_name(node):
 def _getattr_full_name(getattrs):
     return ".".join([_getattr_attr_name(node) for node in getattrs])
 
+def _get_pytorch_value_type(typ, default_dtype="float32"):
+    kind = typ.kind()
+    if kind == 'TensorType':
+        if typ.scalarType() is None:
+            # Tensor's type can be unknown if we use torch.jit.script(...)
+            # Defaults can be passed in, if not it is float32
+            logging.warning("Untyped Tensor found, assume it is %s", default_dtype)
+            return default_dtype
+        else:
+            return _convert_data_type(typ.scalarType())
+
+    elif kind == 'ListType':
+        return "ListType"
+    elif kind in ['IntType', 'FloatType', 'BoolType',
+                  'StringType', 'OptionalType']:
+        pt_dtype = str(typ).lower()
+        dtype = pt_dtype if pt_dtype == 'OptionalType' else _convert_data_type(pt_dtype)
+        return dtype
+    else:
+        return 'UnsupportedType'
 
-def _get_input_types(op_node):
-    """ Returns a torch type for each input nodes """
-    input_list_types = []
-    for input_node in op_node.inputs():
-        in_ty = input_node.type()
-        input_node_kind = in_ty.kind()
-        if input_node_kind == 'TensorType':
-            if in_ty.scalarType() is None:
-                # Tensor's type can be unknown if we use torch.jit.script(...)
-                # Defaults to float for now
-                logging.warning("Untyped Tensor found, assume it is float")
-                input_list_types.append("float")
-            else:
-                input_list_types.append(in_ty.scalarType().lower())
 
-        elif input_node_kind == 'ListType':
-            input_list_types.append("ListType")
-        elif input_node_kind in ['IntType', 'FloatType', 'BoolType',
-                                 'StringType', 'OptionalType']:
-            input_list_types.append(str(in_ty).lower())
-        else:
-            input_list_types.append('UnsupportedType')
+def _get_input_types(op_node, default_dtype="float32"):
+    """Returns a TVM dtype for each input nodes derived from the torch type"""
+    return [_get_pytorch_value_type(i.type(), default_dtype=default_dtype)
+            for i in op_node.inputs()]
 
-    if op_node.kind() in ['aten::ones', 'aten::zeros']:
-        node_type = op_node.output().type()
-        scalar_type = node_type.scalarType()
-        if scalar_type:
-            input_list_types[0] = scalar_type.lower()
 
-    return input_list_types
+def _get_output_types(op_node, default_dtype="float32"):
+    """Returns a TVM dtype for each input nodes derived from the torch type"""
+    return [_get_pytorch_value_type(i.type(), default_dtype=default_dtype)
+            for i in op_node.outputs()]
 
 
 def _get_constant(node):
@@ -2120,14 +2146,17 @@ def _get_constant(node):
         attr_name = attribute_names[0]
         ty = node.output().type().kind()
 
-        if ty in ["IntType", "BoolType"]:
+        if ty == "IntType":
             return node.i(attr_name)
+        elif ty == "BoolType":
+            return bool(node.i(attr_name))
         elif ty in ["FloatType", "LongType"]:
             return node.f(attr_name)
         elif ty in ["TensorType", "CompleteTensorType"]:
             tensor = node.t(attr_name)
             if len(tensor.shape) == 0:  # tensor(0.1)
-                return float(tensor)
+                # TODO(t-vi): When is this needed?
+                return tensor.item()
             return _wrap_const(tensor.numpy())
         elif ty == "DeviceObjType":
             return node.s(attr_name)
@@ -2156,35 +2185,75 @@ def _get_operator_nodes(nodes):
     return ops
 
 
-def _get_graph_input_names(graph):
-    """ Get the graph input names (use after graph copy and run jit passes) """
-    # Variable names could change the first time a copy is made and after
-    # _run_jit_passes is called, expected that those functions already invoked
-    ir_inputs = _get_input_names(graph)
-    return ir_inputs[1:]  # remove self at the 0th arg
-
-
-def _get_relay_input_vars(graph, input_shapes, prelude):
+def _get_relay_input_vars(graph, input_shapes, prelude, is_module=True, default_dtype="float32"):
     """
     Return Relay vars from input shapes and create entries based on
     expected graph inputs - to allow translation
     """
-    def get_relay_ty(ishape):
-        if _is_int_seq(ishape) or len(ishape) == 0:
-            return TensorType(ishape)
-        elif isinstance(ishape, tuple):
-            return TupleType([get_relay_ty(elem) for elem in ishape])
-        elif isinstance(ishape, list):
-            assert len(ishape) > 0
-            elem_tys = [get_relay_ty(s) for s in ishape]
-            msg = "List elements should have identical types"
-            assert all(map(lambda ty: ty == elem_tys[0], elem_tys)), msg
+
+    graph_inputs = list(graph.inputs())
+    if is_module:
+        # a module has "self" as first input, which we do not need/want
+        graph_inputs = graph_inputs[1:]
+
+    if not isinstance(input_shapes, list):
+        msg = "Graph inputs input_shapes should be a list"
+        raise RuntimeError(msg)
+
+    if len(graph_inputs) != len(input_shapes):
+        msg = "PyTorch has {} inputs and input_shapes lists {}.".format(
+            len(graph_inputs), len(input_shapes))
+        raise RuntimeError(msg)
+
+    def get_relay_ty(ishape, pt_type):
+        if pt_type.kind() == 'TensorType':
+            if not (_is_int_seq(ishape) or len(ishape) == 0):
+                msg = "Shape for Tensors must be lists of ints"
+                raise RuntimeError(msg)
+            if ((pt_type.dim() is not None and pt_type.dim() != len(ishape)) or
+                    (pt_type.sizes() is not None
+                     and any([s1 != s2 for s1, s2 in zip(pt_type.sizes(), ishape)]))):
+                msg = "Shapes of input list and information in the graph do not match"
+                raise RuntimeError(msg)
+            pt_dtype = pt_type.scalarType()
+            dtype = _convert_data_type(pt_dtype, default_dtype=default_dtype)
+            return TensorType(ishape, dtype)
+        elif pt_type.kind() == 'TupleType':
+            if not isinstance(ishape, tuple):
+                msg = "Shapes for tuples must be tuples"
+                raise RuntimeError(msg)
+            return TupleType([get_relay_ty(elem, pt_t)
+                              for elem, pt_t in zip(ishape, pt_type.elements())])
+        elif pt_type.kind() == 'ListType':
+            if not isinstance(ishape, list):
+                msg = "Shapes for lists must be lists"
+                raise RuntimeError(msg)
+            pt_elemtype = pt_type.getElementType()
+            elem_tys = [get_relay_ty(s, pt_elemtype) for s in ishape]
+            if len(elem_tys) > 0 and not all(map(lambda ty: ty == elem_tys[0], elem_tys)):
+                msg = "List elements need have identical types"
+                raise RuntimeError(msg)
             return prelude.l(elem_tys[0])
+        elif pt_type.kind() == 'OptionalType':
+            # we do not support None yet, so we fill in the type
+            return get_relay_ty(ishape, pt_type.getElementType())
+        # TODO: scalar inputs
         raise NotImplementedError("unsupported input type")
 
-    input_types = [(tup[0], get_relay_ty(tup[1])) for tup in input_shapes]
     input_vars = {}
-    ir_inputs = _get_graph_input_names(graph)
+
+    for num, inp in enumerate(input_shapes):
+        if not isinstance(inp, tuple):
+            msg = "Graph input {} is not a tuple".format(num)
+            raise RuntimeError(msg)
+        if (len(inp) != 2 or not isinstance(inp[0], str)):
+            msg = "Graph input {} is not valid, expected ('name', shape)".format(inp)
+            raise RuntimeError(msg)
+
+    input_types = [(name, get_relay_ty(shape, gi.type()))
+                   for (name, shape), gi in zip(input_shapes, graph_inputs)]
+
+    ir_inputs = [i.debugName() for i in graph_inputs]
     for ir_input, (name, itype) in zip(ir_inputs, input_types):
         inp = _expr.var(name, type_annotation=itype)
         # Translate from graph input to user input name
@@ -2292,19 +2361,22 @@ def convert_params(graph, state_dict):
     return params, param_tensors, packed_param_map
 
 
-def convert_block(block, outputs, convert_map, prelude):
+def convert_block(block, outputs, convert_map, prelude, default_dtype="float32"):
     """ Translate Torch "Block", used for prim::If and prim::Loop """
     ops = _get_operator_nodes(block.nodes())
     ret_names = _get_input_names(block.returnNode())
-    return convert_operators(ops, outputs, ret_names, convert_map, prelude)
+    return convert_operators(ops, outputs, ret_names, convert_map, prelude,
+                             default_dtype=default_dtype)
 
 
-def convert_if(if_node, outputs, convert_map, prelude):
+def convert_if(if_node, outputs, convert_map, prelude, default_dtype="float32"):
     """ Translate Torch prim::If to Relay If """
     cond = outputs[if_node.inputsAt(0).debugName()]
     blocks = list(if_node.blocks())
-    true_branch = convert_block(blocks[0], outputs, convert_map, prelude)
-    false_branch = convert_block(blocks[1], outputs, convert_map, prelude)
+    true_branch = convert_block(blocks[0], outputs, convert_map, prelude,
+                                default_dtype=default_dtype)
+    false_branch = convert_block(blocks[1], outputs, convert_map, prelude,
+                                 default_dtype=default_dtype)
     assert len(true_branch) == 1 and len(false_branch) == 1
     return _expr.If(cond, true_branch[0], false_branch[0])
 
@@ -2424,7 +2496,7 @@ def convert_loop(loop_node, outputs, convert_map, prelude):
     return [_expr.TupleGetItem(loop_val, i+1) for i in range(num_loop_var)]
 
 
-def convert_operators(operators, outputs, ret_names, convert_map, prelude):
+def convert_operators(operators, outputs, ret_names, convert_map, prelude, default_dtype="float32"):
     """ Convert each Torch IR operators to Relay equivalent """
     for node_name, op_node in operators:
         operator = op_node.kind()
@@ -2450,7 +2522,7 @@ def convert_operators(operators, outputs, ret_names, convert_map, prelude):
                 unpacked = _unpack_tuple(inputs[0])
             outputs.update(zip(_get_output_names(op_node), unpacked))
         elif operator == "prim::If":
-            if_out = convert_if(op_node, outputs, convert_map, prelude)
+            if_out = convert_if(op_node, outputs, convert_map, prelude, default_dtype=default_dtype)
             outputs[node_name] = if_out
         elif operator == "prim::Loop":
             loop_out = convert_loop(op_node, outputs, convert_map, prelude)
@@ -2459,7 +2531,7 @@ def convert_operators(operators, outputs, ret_names, convert_map, prelude):
             outputs.update(zip(unpacked_names, loop_out))
         else:
             relay_op = convert_map[operator]
-            relay_out = relay_op(inputs, _get_input_types(op_node))
+            relay_out = relay_op(inputs, _get_input_types(op_node, default_dtype=default_dtype))
 
             if isinstance(relay_out, tuple):
                 # This is for torch operators that return multiple outputs
@@ -2486,7 +2558,7 @@ def get_all_op_names(graph):
     return set(node.kind() for node in nodes)
 
 
-def from_pytorch(script_module, input_shapes, custom_convert_map=None):
+def from_pytorch(script_module, input_shapes, custom_convert_map=None, default_dtype="float32"):
     """ Load PyTorch model in the form of a scripted PyTorch model and convert into relay.
     The companion parameters will be handled automatically.
 
@@ -2512,6 +2584,8 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None):
     params : dict of str to tvm.runtime.NDArray
         Dict of converted parameters stored in tvm.runtime.ndarray format
     """
+    import torch
+
     mod = tvm.IRModule()
     prelude = Prelude(mod)
 
@@ -2525,10 +2599,12 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None):
 
     op_names = get_all_op_names(graph)
     _report_missing_conversion(op_names, convert_map)
-    _check_inputs(graph, input_shapes)
 
-    params = script_module.state_dict()
-    outputs = _get_relay_input_vars(graph, input_shapes, prelude)
+    is_module = isinstance(script_module, torch.jit.ScriptModule)
+    params = script_module.state_dict() if is_module else {}
+    outputs = _get_relay_input_vars(graph, input_shapes, prelude,
+                                    default_dtype=default_dtype,
+                                    is_module=is_module)
     param_vars, tensors, packed_param_map = convert_params(graph, params)
     tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}
 
@@ -2546,7 +2622,8 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None):
         convert_map.update(qnn_torch.convert_map)
 
     ret = convert_operators(_get_operator_nodes(graph.nodes()),
-                            outputs, ret_name, convert_map, prelude)
+                            outputs, ret_name, convert_map, prelude,
+                            default_dtype=default_dtype)
 
     mod["main"] = tvm.relay.Function(_analysis.free_vars(ret[0]), ret[0])
 
index 8c6c248..f6c7280 100644 (file)
@@ -501,6 +501,10 @@ def test_serialized_modules():
     runtime.run()
     tvm_result = runtime.get_output(0).asnumpy()
 
-    num_identical = np.sum(tvm_result == pt_result)
+    # with 0.5ish results, 1e-2 is relative accuracy close to 2**-6.
+    # for simple layers like here this should be achievable
+    # with 8 bit quantization
+    # we only require 90% match just to be sure
+    num_identical = np.sum(np.abs(tvm_result - pt_result) < 1e-2)
     match_ratio = num_identical / float(np.prod(tvm_result.shape))
-    assert match_ratio > 0.2
+    assert match_ratio > 0.90
index 96e9144..6ec3110 100644 (file)
@@ -152,7 +152,8 @@ def verify_model(model_name, input_data=[],
         assert False, "Unexpected input format"
 
     if torch.cuda.is_available():
-        baseline_model = baseline_model.cuda()
+        if isinstance(baseline_model, torch.nn.Module):
+            baseline_model = baseline_model.cuda()
         baseline_input = [inp.cuda() for inp in baseline_input]
 
     with torch.no_grad():
@@ -163,12 +164,14 @@ def verify_model(model_name, input_data=[],
     else:
         baseline_outputs = (baseline_outputs.cpu().numpy(),)
 
-    trace = torch.jit.trace(baseline_model, baseline_input).float().eval()
+    trace = torch.jit.trace(baseline_model, baseline_input)
+    if isinstance(baseline_model, torch.nn.Module):
+        trace = trace.float().eval()
 
-    if torch.cuda.is_available():
-        trace = trace.cuda()
-    else:
-        trace = trace.cpu()
+        if torch.cuda.is_available():
+            trace = trace.cuda()
+        else:
+            trace = trace.cpu()
 
     input_names = ["input{}".format(idx) for idx, inp in enumerate(baseline_input)]
     input_shapes = list(zip(input_names,
@@ -2363,6 +2366,23 @@ def test_forward_addcmul():
     t2 = torch.rand([1, 3]).float()
     verify_model(Addcmul2().float().eval(), input_data=[input_data, t1, t2])
 
+def test_forward_traced_function():
+    def fn(t1, t2):
+        return t1 + t2
+
+    tensor1 = torch.randn(3, 4)
+    tensor2 = torch.randn(3, 4)
+    verify_model(fn, input_data=[tensor1, tensor2])
+
+def test_forward_dtypes():
+    def fn(t1, t2):
+        return 2.5 * t1 + t2
+
+    for dt in [torch.int32, torch.int64, torch.double]:
+        tensor1 = torch.randn(3, 4).to(dtype=dt)
+        tensor2 = torch.randn(3, 4).to(dtype=dt)
+        verify_model(fn, input_data=[tensor1, tensor2])
+
 
 def test_forward_matmul():
     torch.set_grad_enabled(False)
@@ -2526,6 +2546,8 @@ def test_forward_pretrained_bert_base_uncased():
 
 
 if __name__ == "__main__":
+    test_forward_traced_function()
+    test_forward_dtypes()
     # Single operator tests
     test_forward_add()
     test_forward_subtract()