return name
-def infer_type(node):
+def infer_type(node, mod=None):
"""A method to infer the type of an intermediate node in the relay graph."""
- mod = node if isinstance(node, _module.Module) else _module.Module.from_expr(node)
- mod = _transform.InferType()(mod)
- entry = mod["main"]
+ new_mod = _module.Module.from_expr(node)
+ if mod is not None:
+ new_mod.update(mod)
+ new_mod = _transform.InferType()(new_mod)
+ entry = new_mod["main"]
return entry if isinstance(node, _expr.Function) else entry.body
-
-def infer_shape(inputs):
- """A method to get the output shape of an intermediate node in the graph."""
- out_type = infer_type(inputs)
- out_shapes = get_const_tuple(out_type.checked_type.shape)
- return out_shapes
-
+def infer_shape(inputs, mod=None):
+ """A method to get the output type of an intermediate node in the graph."""
+ out_type = infer_type(inputs, mod=mod)
+ checked_type = out_type.checked_type
+ if hasattr(checked_type, 'shape'):
+ # Regular operator that outputs tensors
+ return get_const_tuple(out_type.checked_type.shape)
+ # The return type is not a tensor, for example List
+ return checked_type
def infer_channels(inputs, transpose=False):
"""A hack for getting 'channels' or 'units' since caffe2 does not provide
def _get_tuple_param(params, input_node):
return tuple(_get_param(params, input_node))
+def _need_module_for_shape_inference(op):
+ return op in ['StridedSlice']
+
+def _need_prelude_for_shape_inference(op):
+ return "TensorArray" in op
+
def _rsqrt():
def _impl(inputs, attr, params):
inputs.append(tvm.relay.const(-0.5, attr['T'].name))
return _impl
def _stridedSlice():
- def _impl(inputs, attr, params):
+ def _impl(inputs, attr, params, mod):
"""Strided Slice.
Operator description: https://www.tensorflow.org/api_docs/python/tf/strided_slice
Tensorflow mask validation: https://github.com/tensorflow/tensorflow/blob/master/
if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask:
begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask)
out = _op.strided_slice(inputs[0], begin=begin, end=end, strides=stride)
- out_shape = _infer_shape(out)
+ out_shape = _infer_shape(out, mod=mod)
if not fshape_indices:
fshape_indices = range(len(out_shape))
# Infer shapes even without specifying "add_shapes=True"
if output_shapes == [None]:
- out_shapes = [_infer_shape(node_item) for node_item in self._nodes[node.name]]
+ out_shapes = [_infer_shape(node_item, self._mod)
+ for node_item in self._nodes[node.name]]
self._output_shapes[node.name] = out_shapes
if self._output_shapes[node.name] and shape and node.name in shape:
node_output = self._nodes[node.name]
if shape and (not self._output_shapes[node.name][0]
or -1 in self._output_shapes[node.name][0]):
- out_shapes = [_infer_shape(node_item) for node_item in node_output]
+ out_shapes = [_infer_shape(node_item, self._mod) for node_item in node_output]
self._output_shapes[node.name] = out_shapes
out = []
if op_name in identity_list:
sym = get_relay_op(op_name)(*inputs, **attrs)
elif op_name in convert_map:
- if 'TensorArray' in op_name:
+ if _need_prelude_for_shape_inference(op_name):
sym = convert_map[op_name](inputs, attrs, self._params, self._prelude)
+ elif _need_module_for_shape_inference(op_name):
+ sym = convert_map[op_name](inputs, attrs, self._params, self._mod)
else:
sym = convert_map[op_name](inputs, attrs, self._params)