# operator implementation
def _elemwise(name):
def _impl(inputs, input_types):
- # TODO: Figure out a better way to get typing to work for tensor + scalar
- type0 = input_types[0]
- if isinstance(inputs[1], _expr.Expr):
- type0 = input_types[1]
-
- type1 = input_types[1]
- if isinstance(inputs[0], _expr.Expr):
- type1 = input_types[0]
-
- data0 = _convert_elemwise_input(inputs[0], type0)
- data1 = _convert_elemwise_input(inputs[1], type1)
-
+ data0, data1 = _pytorch_promote_types(inputs[:2], input_types[:2])
return get_relay_op(name)(data0, data1)
return _impl
def _unary(name):
def _impl(inputs, input_types):
input_type = input_types[0]
- data = _convert_elemwise_input(inputs[0], input_type)
-
+ # this is just to ensure tensor input
+ data, = _pytorch_promote_types(inputs[:1], input_types[:1])
return get_relay_op(name)(data)
return _impl
def _log1p():
def _impl(inputs, input_types):
# 1_plus_log x = log(x + 1)
- one = _expr.const(1, dtype="float32")
+ dtype, = input_types
+ one = _expr.const(1, dtype=dtype)
return _op.log(inputs[0] + one)
return _impl
def _arange():
def _impl(inputs, input_types):
def _get_value(val, dtype):
+ # dtype is a tvm dtype
if isinstance(val, _expr.Expr):
- return _op.cast(val, _convert_data_type(dtype))
+ return _op.cast(val, dtype)
return _create_typed_const(val, dtype)
def _get_type(val, inp_type):
if isinstance(val, _expr.Expr):
dtype = str(_infer_type(val).checked_type)
- return dtype if dtype != "float32" else "float"
+ return dtype
return inp_type
+ # PyTorch arange uses the following type semantics:
+ # - if a dtype is given, start, stop, step are converted to that dtype
+ # - if no dtype is given and all args are integral, dtype is int64
+ # - if no dtype is given and there is a float arg, dtype is float32
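+ # For illustration (plain PyTorch behavior, independent of this converter):
+ # torch.arange(5) -> int64, torch.arange(2.5) -> float32,
+ # torch.arange(0, 5, 1, dtype=torch.float64) -> float64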
if len(inputs) == 5:
dtype0 = _get_type(inputs[0], input_types[0])
- dtype = "float" if dtype0 == "float" else _convert_dtype_value(inputs[1])
+ if inputs[1] is not None:
+ dtype = _convert_dtype_value(inputs[1])
+ elif dtype0.startswith("float"):
+ dtype = "float32"
+ else:
+ dtype = "int64"
start = _get_value(0, dtype)
stop = _get_value(inputs[0], dtype)
step = _get_value(1, dtype)
elif len(inputs) == 7:
types = [_get_type(inputs[i], input_types[i]) for i in range(3)]
- dtype = "float" if "float" in types else _convert_dtype_value(inputs[3])
+ if inputs[3] is not None:
+ dtype = _convert_dtype_value(inputs[3])
+ elif any([t.startswith("float") for t in types]):
+ dtype = "float32"
+ else:
+ dtype = "int64"
start = _get_value(inputs[0], dtype)
stop = _get_value(inputs[1], dtype)
step = _get_value(inputs[2], dtype)
return _op.transform.arange(start=start,
stop=stop,
step=step,
- dtype=_convert_data_type(dtype))
+ dtype=dtype)
return _impl
def _squeeze():
if len(inputs) == 1:
axis = None
else:
+ # TODO (t-vi): why is the cast to int needed? similarly elsewhere
axis = [int(inputs[1])]
return _op.transform.squeeze(data, axis)
return _impl
def _split_with_sizes():
- def _impl(inputs, inputs_types):
+ def _impl(inputs, input_types):
data = inputs[0]
dim = int(inputs[2])
def _reciprocal():
def _impl(inputs, input_types):
data = inputs[0]
- return _expr.const(1.0) / data
+ return _expr.const(1.0, dtype=input_types[0]) / data
return _impl
def _repeat():
def _addcdiv():
def _impl(inputs, input_types):
- data = inputs[0]
- c = _expr.const(inputs[3])
- t1 = inputs[1]
- t2 = inputs[2]
-
+ data, t1, t2, c = _pytorch_promote_types(inputs[:4], input_types[:4])
return data + (c * (t1 / t2))
return _impl
def _addcmul():
def _impl(inputs, input_types):
- data = inputs[0]
- c = _expr.const(inputs[3])
- t1 = inputs[1]
- t2 = inputs[2]
-
+ data, t1, t2, c = _pytorch_promote_types(inputs[:4], input_types[:4])
return data + (c * (t1 * t2))
return _impl
def _where():
def _impl(inputs, input_types):
cond = inputs[0]
- x = inputs[1]
- y = inputs[2]
-
+ x, y = _pytorch_promote_types(inputs[1:3], input_types[1:3])
return _op.where(cond, x, y)
return _impl
msg = "Data type %s could not be parsed in ones op" % (type(data))
raise AssertionError(msg)
- dtype = _convert_data_type(_convert_dtype_value(inputs[1]))
+ dtype = _convert_dtype_value(inputs[1])
return _op.full(_expr.const(1), shape, dtype=dtype)
return _impl
out = _op.ones_like(data)
# If the input and the output datatype is different, do a cast
- dtype = _convert_data_type(_convert_dtype_value(inputs[1]))
- if input_types[0] not in dtype:
+ dtype = _convert_dtype_value(inputs[1])
+ if input_types[0] != dtype:
out = _op.cast(out, dtype)
return out
msg = "Data type %s could not be parsed in zeros op" % (type(data))
raise AssertionError(msg)
- dtype = _convert_data_type(_convert_dtype_value(inputs[1]))
+ dtype = _convert_dtype_value(inputs[1])
return _op.full(_expr.const(0), shape, dtype=dtype)
return _impl
out = _op.zeros_like(data)
# If the input and the output datatype is different, do a cast
- dtype = _convert_data_type(_convert_dtype_value(inputs[1]))
+ dtype = _convert_dtype_value(inputs[1])
- if input_types[0] not in dtype:
+ if input_types[0] != dtype:
out = _op.cast(out, dtype)
raise AssertionError(msg)
if inputs[2] is not None: # dtype given
- dtype = _convert_data_type(_convert_dtype_value(inputs[2]))
+ dtype = _convert_dtype_value(inputs[2])
else:
dtype = data.type_annotation.dtype
out = _op.full_like(data, _expr.const(fill_value))
# If the input and the output datatype is different, do a cast
- dtype = _convert_data_type(_convert_dtype_value(inputs[2]))
+ dtype = _convert_dtype_value(inputs[2])
- if input_types[0] not in dtype:
+ if input_types[0] != dtype:
out = _op.cast(out, dtype)
else:
stop = start + step
- dtype = "float" if "float" in input_types[0:3] else _convert_dtype_value(inputs[3])
+ dtype = ("float32" if inputs[3] is not None
+ else _convert_dtype_value(inputs[3]))
start = _create_typed_const(start, dtype)
stop = _create_typed_const(stop, dtype)
step = _create_typed_const(step, dtype)
return _op.transform.arange(start=start,
stop=stop,
step=step,
- dtype=_convert_data_type(dtype))
+ dtype=dtype)
return _impl
def _elu():
def _impl(inputs, input_types):
data = inputs[0]
- alpha = _expr.const(float(inputs[1]))
- return alpha * _op.nn.relu(_expr.const(1.0) - _op.exp(data)) + _op.nn.relu(data)
+ dtype = input_types[0]
+ # negate alpha so that alpha * relu(1 - exp(x)) yields alpha_orig * (exp(x) - 1) for x < 0
+ alpha = _expr.const(-float(inputs[1]), dtype=dtype)
+ return alpha * _op.nn.relu(_expr.const(1, dtype=dtype) - _op.exp(data)) + _op.nn.relu(data)
return _impl
def _celu():
def _impl(inputs, input_types):
data = inputs[0]
- alpha = _expr.const(float(inputs[1]))
- return alpha * _op.nn.relu(_expr.const(1.0) - _op.exp(data / alpha)) + _op.nn.relu(data)
+ dtype = input_types[0]
+ alpha = _expr.const(float(inputs[1]), dtype=dtype)
+ # CELU(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1)), hence the subtraction
+ return _op.nn.relu(data) - alpha * _op.nn.relu(_expr.const(1, dtype=dtype)
+ - _op.exp(data / alpha))
return _impl
def _gelu():
def _impl(inputs, input_types):
data = inputs[0]
+ dtype = input_types[0]
# gelu is data * normcdf(data)
# normcdf expressed as erf because we don't currently have that intrinsic
# note that there is also a fastgelu variant approximating normcdf
# with tanh and third order polynomials, but this is "true" gelu
- return data * (_expr.const(0.5) +
- _op.erf(data * _expr.const(0.5**0.5)) * _expr.const(0.5))
+ return data * (_expr.const(0.5, dtype=dtype) +
+ _op.erf(data * _expr.const(0.5**0.5, dtype=dtype))
+ * _expr.const(0.5, dtype=dtype))
return _impl
def _selu():
def _impl(inputs, input_types):
data = inputs[0]
# https://pytorch.org/docs/stable/nn.html#selu
- alpha = _expr.const(-1.6732632423543772848170429916717)
- gamma = _expr.const(1.0507009873554804934193349852946)
- return gamma * (alpha * _op.nn.relu(_expr.const(1.0)
+ dtype = input_types[0]
+ alpha = _expr.const(-1.6732632423543772848170429916717, dtype=dtype)
+ gamma = _expr.const(1.0507009873554804934193349852946, dtype=dtype)
+ return gamma * (alpha * _op.nn.relu(_expr.const(1.0, dtype=dtype)
- _op.exp(data)) + _op.nn.relu(data))
return _impl
def _softplus():
def _impl(inputs, input_types):
data = inputs[0]
- beta = _expr.const(float(inputs[1]))
- return _op.log(_op.exp(inputs[0] * beta) + _expr.const(1.)) / beta
+ dtype = input_types[0]
+ beta = _expr.const(float(inputs[1]), dtype=dtype)
+ return _op.log(_op.exp(inputs[0] * beta) + _expr.const(1., dtype=dtype)) / beta
return _impl
def _avg_pool2d(prelude):
def _norm():
def _impl(inputs, input_types):
data = inputs[0]
+ dtype = input_types[0]
axis = None
keepdims = False
if len(inputs) > 3:
elif order == np.NINF:
return _op.reduce.min(_op.abs(data), axis=axis, keepdims=keepdims)
else:
- reci_order = _expr.const(1.0 / order)
+ reci_order = _expr.const(1.0 / order, dtype=dtype)
order = _expr.const(order)
return _op.power(_op.reduce.sum(_op.power(_op.abs(data), order),
axis=axis,
if unbiased:
msg = "Currently only supports standard-deviation calculated via the biased "\
- "estimator. Pytorch's Bessel's correction is not supported."
+ "estimator. PyTorch's Bessel's correction is not supported."
raise NotImplementedError(msg)
return _op.reduce.std(data, axis=axis, keepdims=keepdims)
if unbiased:
msg = "Currently only supports standard-deviation calculated via the biased "\
- "estimator. Pytorch's Bessel's correction is not supported."
+ "estimator. PyTorch's Bessel's correction is not supported."
raise NotImplementedError(msg)
return _op.reduce.variance(data, axis=axis, keepdims=keepdims)
def _impl(inputs, input_types):
assert len(inputs) == 2
assert len(input_types) == 2
- return _op.cast(inputs[0], _convert_data_type(input_types[1]))
+ return _op.cast(inputs[0], input_types[1])
return _impl
def _rsub():
def _impl(inputs, input_types):
- # TODO: Figure out a better way to get typing to work for tensor + scalar
- type0 = input_types[0]
- if isinstance(inputs[1], _expr.Expr):
- type0 = input_types[1]
+ data0, data1 = _pytorch_promote_types(inputs[:2], input_types[:2])
- type1 = input_types[1]
- if isinstance(inputs[0], _expr.Expr):
- type1 = input_types[0]
-
- data1 = _convert_elemwise_input(inputs[0], type0)
- data0 = _convert_elemwise_input(inputs[1], type1)
+ # TODO (t-vi): should this also be part of the type promotion?
alpha = _expr.const(float(inputs[2]))
- return get_relay_op("subtract")(data0, alpha * data1)
+ # note: rsub means data0 and data1 swap places
+ return get_relay_op("subtract")(data1, alpha * data0)
return _impl
return _impl
+def _pytorch_result_type(dtypes, non_tensor_inputs):
+ """This promotes TVM dtypes like PyTorch would"""
+ import torch
+ dtype_map = {
+ "float64": torch.float64,
+ "float32": torch.float32,
+ "float16": torch.float16,
+ "bfloat16": torch.bfloat16,
+ "int64": torch.int64,
+ "int32": torch.int32,
+ "int16": torch.int16,
+ "int8": torch.int8,
+ "uint8": torch.uint8,
+ "bool": torch.bool
+ }
+ if len(dtypes) > 0:
+ result_type = dtypes[0]
+ for dt in dtypes[1:]:
+ if dt != result_type: # only promote when the dtypes differ; identical
+ # dtypes (e.g. quantized ones, which cannot be promoted) pass through unchanged
+ result_type = _convert_data_type(str(torch.result_type(
+ torch.zeros((), dtype=dtype_map[result_type]),
+ torch.zeros((), dtype=dtype_map[dt]))))
+ else:
+ result_type = "bool" # this is the smallest type...
+ for inp in non_tensor_inputs:
+ result_type = _convert_data_type(
+ str(torch.result_type(torch.zeros((), dtype=dtype_map[result_type]),
+ inp)))
+ return result_type
+
+def _pytorch_promote_types(inputs, dtypes):
+ """This promotes TVM inputs with TVM dtypes passed like PyTorch would"""
+ tensor_dtypes = [dt for inp, dt in zip(inputs, dtypes) if not np.isscalar(inp)]
+ non_tensor_inputs = [inp for inp in inputs if np.isscalar(inp)]
+ result_type = _pytorch_result_type(tensor_dtypes, non_tensor_inputs)
+ results = []
+ for inp, dt in zip(inputs, dtypes):
+ if np.isscalar(inp):
+ results.append(_expr.const(inp, dtype=result_type))
+ elif dt == result_type:
+ results.append(inp)
+ else:
+ results.append(_op.cast(inp, result_type))
+ return results
+
# Helper functions for operator implementation
def _convert_dtype_value(val):
+ """converts a PyTorch the PyTorch numeric type id to a torch scalar type."""
convert_torch_dtype_map = {7:"torch.float64",
6:"torch.float32",
5:"torch.float16",
0:"torch.unit8",
None:"torch.int64"} # Default is torch.int64
if val in convert_torch_dtype_map:
- return convert_torch_dtype_map[val]
+ return _convert_data_type(convert_torch_dtype_map[val])
else:
msg = "Torch data type value %d is not handled yet." % (val)
raise NotImplementedError(msg)
-def _convert_data_type(input_type):
+def _convert_data_type(input_type, default_dtype=None):
+ """converts the PyTorch scalar type input_type to a TVM dtype.
+ optionally, default_dtype can be a TVM dtype that is used
+ if input_type is None (but not when it is unknown)"""
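+ # illustrative: _convert_data_type("torch.float64") -> "float64",
+ # _convert_data_type(None, default_dtype="float32") -> "float32"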
+ if input_type is None and default_dtype is not None:
+ return default_dtype
+
+ input_type = input_type.lower()
if input_type in ["double", "torch.float64"]:
return "float64"
elif input_type in ["float", "torch.float32"]:
return "int8"
elif input_type in ["byte", "torch.uint8"]:
return "uint8"
+ elif input_type in ["quint8", "torch.quint8"]:
+ return "quint8"
+ elif input_type in ["qint8", "torch.qint8"]:
+ return "qint8"
+ elif input_type in ["qint32", "torch.qint32"]:
+ return "qint32"
+ elif input_type in ["bool", "torch.bool"]:
+ return "bool"
else:
- raise NotImplementedError("input_type {} is not handled yet" % (input_type))
- return "float32"
+ raise NotImplementedError("input_type {} is not handled yet".format(input_type))
+ return "float32" # Never reached
-def _create_typed_const(data, data_type):
- dtype = _convert_data_type(data_type)
+def _create_typed_const(data, dtype):
+ """create a (scalar) constant of given value and dtype.
+ dtype should be a TVM dtype"""
if dtype == "float64":
typed_data = _expr.const(np.float64(data), dtype=dtype)
elif dtype == "uint8":
typed_data = _expr.const(np.uint8(data), dtype=dtype)
else:
- raise NotImplementedError("input_type {} is not handled yet" % (data_type))
+ raise NotImplementedError("input_type {} is not handled yet".format(dtype))
return typed_data
-def _convert_elemwise_input(data, input_type):
- import torch
- if isinstance(data, torch.Tensor):
- return _expr.const(data.item(), dtype=_convert_data_type(input_type))
- elif not isinstance(data, _expr.Expr):
- return _expr.const(data, dtype=_convert_data_type(input_type))
- else:
- return data
-
def _wrap_const(c):
if not isinstance(c, (_expr.Expr, list, tvm.tir.expr.Any)):
return _expr.const(c)
"aten::mean" : _mean(prelude),
"aten::chunk" : _chunk(prelude),
"aten::matmul" : _matmul(prelude),
+ "aten::bmm" : _matmul(prelude),
"aten::expand" : _expand(),
"aten::Int" : _int(),
"prim::NumToTensor" : _numtotensor(),
def _is_int_seq(seq):
+ # TODO (t-vi): handle non-int constants? (like numpy.intXX)
return len(seq) > 0 and all([isinstance(i, int) for i in seq])
def _get_tensor_and_var(torch_tensor, name):
tensor = tvm.nd.array(torch_tensor.cpu().numpy())
- var = _expr.var(name, shape=tensor.shape)
+ var = _expr.var(name, shape=tensor.shape, dtype=tensor.dtype)
return tensor, var
msg = "The following operators are not implemented: {}".format(missing)
raise NotImplementedError(msg)
-
-def _check_inputs(graph, input_shapes):
- """
- Check the graph inputs match the expected number of inputs
- and are in the correct format
- """
- ir_inputs = _get_graph_input_names(graph)
-
- if not isinstance(input_shapes, list):
- msg = "Graph inputs input_shapes should be list"
- raise RuntimeError(msg)
- missing_inputs = len(ir_inputs) - len(input_shapes)
- if missing_inputs > 0:
- msg = "Missing {} graph input(s) in input_shapes".format(missing_inputs)
- raise RuntimeError(msg)
-
- for num, inp in enumerate(input_shapes):
- if num < len(ir_inputs):
- if not isinstance(inp, tuple):
- msg = "Graph input {} is not a tuple".format(num)
- raise RuntimeError(msg)
- if (len(inp) != 2 or not isinstance(inp[0], str)):
- msg = "Graph input {} is not valid, expected ('name', shape)".format(inp)
- raise RuntimeError(msg)
- else:
- msg = "Unused graph input {} in input_shapes".format(inp)
- logging.warning(msg)
-
-
def _getattr_attr_name(node):
attribute_names = node.attributeNames()
assert len(attribute_names) == 1
def _getattr_full_name(getattrs):
return ".".join([_getattr_attr_name(node) for node in getattrs])
+def _get_pytorch_value_type(typ, default_dtype="float32"):
+ kind = typ.kind()
+ if kind == 'TensorType':
+ if typ.scalarType() is None:
+ # A tensor's dtype can be unknown if we use torch.jit.script(...)
+ # A default can be passed in; otherwise float32 is assumed
+ logging.warning("Untyped Tensor found, assume it is %s", default_dtype)
+ return default_dtype
+ else:
+ return _convert_data_type(typ.scalarType())
+
+ elif kind == 'ListType':
+ return "ListType"
+ elif kind in ['IntType', 'FloatType', 'BoolType',
+ 'StringType', 'OptionalType']:
+ pt_dtype = str(typ).lower()
+ dtype = pt_dtype if kind == 'OptionalType' else _convert_data_type(pt_dtype)
+ return dtype
+ else:
+ return 'UnsupportedType'
-def _get_input_types(op_node):
- """ Returns a torch type for each input nodes """
- input_list_types = []
- for input_node in op_node.inputs():
- in_ty = input_node.type()
- input_node_kind = in_ty.kind()
- if input_node_kind == 'TensorType':
- if in_ty.scalarType() is None:
- # Tensor's type can be unknown if we use torch.jit.script(...)
- # Defaults to float for now
- logging.warning("Untyped Tensor found, assume it is float")
- input_list_types.append("float")
- else:
- input_list_types.append(in_ty.scalarType().lower())
- elif input_node_kind == 'ListType':
- input_list_types.append("ListType")
- elif input_node_kind in ['IntType', 'FloatType', 'BoolType',
- 'StringType', 'OptionalType']:
- input_list_types.append(str(in_ty).lower())
- else:
- input_list_types.append('UnsupportedType')
+def _get_input_types(op_node, default_dtype="float32"):
+ """Returns a TVM dtype for each input nodes derived from the torch type"""
+ return [_get_pytorch_value_type(i.type(), default_dtype=default_dtype)
+ for i in op_node.inputs()]
- if op_node.kind() in ['aten::ones', 'aten::zeros']:
- node_type = op_node.output().type()
- scalar_type = node_type.scalarType()
- if scalar_type:
- input_list_types[0] = scalar_type.lower()
- return input_list_types
+def _get_output_types(op_node, default_dtype="float32"):
+ """Returns a TVM dtype for each input nodes derived from the torch type"""
+ return [_get_pytorch_value_type(i.type(), default_dtype=default_dtype)
+ for i in op_node.outputs()]
def _get_constant(node):
attr_name = attribute_names[0]
ty = node.output().type().kind()
- if ty in ["IntType", "BoolType"]:
+ if ty == "IntType":
return node.i(attr_name)
+ elif ty == "BoolType":
+ return bool(node.i(attr_name))
elif ty in ["FloatType", "LongType"]:
return node.f(attr_name)
elif ty in ["TensorType", "CompleteTensorType"]:
tensor = node.t(attr_name)
if len(tensor.shape) == 0: # tensor(0.1)
- return float(tensor)
+ # TODO(t-vi): When is this needed?
+ return tensor.item()
return _wrap_const(tensor.numpy())
elif ty == "DeviceObjType":
return node.s(attr_name)
return ops
-def _get_graph_input_names(graph):
- """ Get the graph input names (use after graph copy and run jit passes) """
- # Variable names could change the first time a copy is made and after
- # _run_jit_passes is called, expected that those functions already invoked
- ir_inputs = _get_input_names(graph)
- return ir_inputs[1:] # remove self at the 0th arg
-
-
-def _get_relay_input_vars(graph, input_shapes, prelude):
+def _get_relay_input_vars(graph, input_shapes, prelude, is_module=True, default_dtype="float32"):
"""
Return Relay vars from input shapes and create entries based on
expected graph inputs - to allow translation
"""
- def get_relay_ty(ishape):
- if _is_int_seq(ishape) or len(ishape) == 0:
- return TensorType(ishape)
- elif isinstance(ishape, tuple):
- return TupleType([get_relay_ty(elem) for elem in ishape])
- elif isinstance(ishape, list):
- assert len(ishape) > 0
- elem_tys = [get_relay_ty(s) for s in ishape]
- msg = "List elements should have identical types"
- assert all(map(lambda ty: ty == elem_tys[0], elem_tys)), msg
+
+ graph_inputs = list(graph.inputs())
+ if is_module:
+ # a module has "self" as first input, which we do not need/want
+ graph_inputs = graph_inputs[1:]
+
+ if not isinstance(input_shapes, list):
+ msg = "Graph inputs input_shapes should be a list"
+ raise RuntimeError(msg)
+
+ if len(graph_inputs) != len(input_shapes):
+ msg = "PyTorch has {} inputs and input_shapes lists {}.".format(
+ len(graph_inputs), len(input_shapes))
+ raise RuntimeError(msg)
+
+ def get_relay_ty(ishape, pt_type):
+ if pt_type.kind() == 'TensorType':
+ if not (_is_int_seq(ishape) or len(ishape) == 0):
+ msg = "Shape for Tensors must be lists of ints"
+ raise RuntimeError(msg)
+ if ((pt_type.dim() is not None and pt_type.dim() != len(ishape)) or
+ (pt_type.sizes() is not None
+ and any([s1 != s2 for s1, s2 in zip(pt_type.sizes(), ishape)]))):
+ msg = "Shapes of input list and information in the graph do not match"
+ raise RuntimeError(msg)
+ pt_dtype = pt_type.scalarType()
+ dtype = _convert_data_type(pt_dtype, default_dtype=default_dtype)
+ return TensorType(ishape, dtype)
+ elif pt_type.kind() == 'TupleType':
+ if not isinstance(ishape, tuple):
+ msg = "Shapes for tuples must be tuples"
+ raise RuntimeError(msg)
+ return TupleType([get_relay_ty(elem, pt_t)
+ for elem, pt_t in zip(ishape, pt_type.elements())])
+ elif pt_type.kind() == 'ListType':
+ if not isinstance(ishape, list):
+ msg = "Shapes for lists must be lists"
+ raise RuntimeError(msg)
+ pt_elemtype = pt_type.getElementType()
+ elem_tys = [get_relay_ty(s, pt_elemtype) for s in ishape]
+ if len(elem_tys) > 0 and not all(map(lambda ty: ty == elem_tys[0], elem_tys)):
+ msg = "List elements need have identical types"
+ raise RuntimeError(msg)
return prelude.l(elem_tys[0])
+ elif pt_type.kind() == 'OptionalType':
+ # we do not support None yet, so we fill in the type
+ return get_relay_ty(ishape, pt_type.getElementType())
+ # TODO: scalar inputs
raise NotImplementedError("unsupported input type")
- input_types = [(tup[0], get_relay_ty(tup[1])) for tup in input_shapes]
input_vars = {}
- ir_inputs = _get_graph_input_names(graph)
+
+ for num, inp in enumerate(input_shapes):
+ if not isinstance(inp, tuple):
+ msg = "Graph input {} is not a tuple".format(num)
+ raise RuntimeError(msg)
+ if (len(inp) != 2 or not isinstance(inp[0], str)):
+ msg = "Graph input {} is not valid, expected ('name', shape)".format(inp)
+ raise RuntimeError(msg)
+
+ input_types = [(name, get_relay_ty(shape, gi.type()))
+ for (name, shape), gi in zip(input_shapes, graph_inputs)]
+
+ ir_inputs = [i.debugName() for i in graph_inputs]
for ir_input, (name, itype) in zip(ir_inputs, input_types):
inp = _expr.var(name, type_annotation=itype)
# Translate from graph input to user input name
return params, param_tensors, packed_param_map
-def convert_block(block, outputs, convert_map, prelude):
+def convert_block(block, outputs, convert_map, prelude, default_dtype="float32"):
""" Translate Torch "Block", used for prim::If and prim::Loop """
ops = _get_operator_nodes(block.nodes())
ret_names = _get_input_names(block.returnNode())
- return convert_operators(ops, outputs, ret_names, convert_map, prelude)
+ return convert_operators(ops, outputs, ret_names, convert_map, prelude,
+ default_dtype=default_dtype)
-def convert_if(if_node, outputs, convert_map, prelude):
+def convert_if(if_node, outputs, convert_map, prelude, default_dtype="float32"):
""" Translate Torch prim::If to Relay If """
cond = outputs[if_node.inputsAt(0).debugName()]
blocks = list(if_node.blocks())
- true_branch = convert_block(blocks[0], outputs, convert_map, prelude)
- false_branch = convert_block(blocks[1], outputs, convert_map, prelude)
+ true_branch = convert_block(blocks[0], outputs, convert_map, prelude,
+ default_dtype=default_dtype)
+ false_branch = convert_block(blocks[1], outputs, convert_map, prelude,
+ default_dtype=default_dtype)
assert len(true_branch) == 1 and len(false_branch) == 1
return _expr.If(cond, true_branch[0], false_branch[0])
return [_expr.TupleGetItem(loop_val, i+1) for i in range(num_loop_var)]
-def convert_operators(operators, outputs, ret_names, convert_map, prelude):
+def convert_operators(operators, outputs, ret_names, convert_map, prelude, default_dtype="float32"):
""" Convert each Torch IR operators to Relay equivalent """
for node_name, op_node in operators:
operator = op_node.kind()
unpacked = _unpack_tuple(inputs[0])
outputs.update(zip(_get_output_names(op_node), unpacked))
elif operator == "prim::If":
- if_out = convert_if(op_node, outputs, convert_map, prelude)
+ if_out = convert_if(op_node, outputs, convert_map, prelude, default_dtype=default_dtype)
outputs[node_name] = if_out
elif operator == "prim::Loop":
loop_out = convert_loop(op_node, outputs, convert_map, prelude)
outputs.update(zip(unpacked_names, loop_out))
else:
relay_op = convert_map[operator]
- relay_out = relay_op(inputs, _get_input_types(op_node))
+ relay_out = relay_op(inputs, _get_input_types(op_node, default_dtype=default_dtype))
if isinstance(relay_out, tuple):
# This is for torch operators that return multiple outputs
return set(node.kind() for node in nodes)
-def from_pytorch(script_module, input_shapes, custom_convert_map=None):
+def from_pytorch(script_module, input_shapes, custom_convert_map=None, default_dtype="float32"):
""" Load PyTorch model in the form of a scripted PyTorch model and convert into relay.
The companion parameters will be handled automatically.
params : dict of str to tvm.runtime.NDArray
Dict of converted parameters stored in tvm.runtime.ndarray format
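+
+ Examples
+ --------
+ A minimal sketch (model and input names here are illustrative assumptions):
+ scripted = torch.jit.trace(model, example_input).eval()
+ mod, params = from_pytorch(scripted, [("input0", (1, 3, 224, 224))])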
"""
+ import torch
+
mod = tvm.IRModule()
prelude = Prelude(mod)
op_names = get_all_op_names(graph)
_report_missing_conversion(op_names, convert_map)
- _check_inputs(graph, input_shapes)
- params = script_module.state_dict()
- outputs = _get_relay_input_vars(graph, input_shapes, prelude)
+ is_module = isinstance(script_module, torch.jit.ScriptModule)
+ params = script_module.state_dict() if is_module else {}
+ outputs = _get_relay_input_vars(graph, input_shapes, prelude,
+ default_dtype=default_dtype,
+ is_module=is_module)
param_vars, tensors, packed_param_map = convert_params(graph, params)
tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}
convert_map.update(qnn_torch.convert_map)
ret = convert_operators(_get_operator_nodes(graph.nodes()),
- outputs, ret_name, convert_map, prelude)
+ outputs, ret_name, convert_map, prelude,
+ default_dtype=default_dtype)
mod["main"] = tvm.relay.Function(_analysis.free_vars(ret[0]), ret[0])