import json
import tvm
+from topi.util import get_const_tuple
from .. import analysis
from .. import expr as _expr
from .. import op as _op
from .. import module as _module
+from .. import scope_builder as _scope_builder
from ... import nd as _nd
from .common import StrAttrsDict
new_attrs['axis'] = attrs.get_int('axis')
return _op.nn.fifo_buffer(*inputs, **new_attrs)
+def _mx_cond(inputs, attrs, subgraphs):
+ assert len(subgraphs) == 3
+ cond_input_locs = json.loads(attrs.get_str("cond_input_locs"))
+ then_input_locs = json.loads(attrs.get_str("then_input_locs"))
+ else_input_locs = json.loads(attrs.get_str("else_input_locs"))
+ num_outputs = attrs.get_int("num_outputs")
+
+ input_args = []
+ for i, arg in enumerate(inputs):
+ var = _expr.var("arg%s" % i, _infer_type(arg).checked_type)
+ input_args.append(var)
+ cond_args = [input_args[i] for i in cond_input_locs]
+ then_args = [input_args[i] for i in then_input_locs]
+ else_args = [input_args[i] for i in else_input_locs]
+
+ cond_arg_shapes = [arg.type_annotation.shape for arg in cond_args]
+ cond_arg_dtype_info = [arg.type_annotation.dtype for arg in cond_args]
+ cond_func = _from_mxnet_impl(subgraphs[0], cond_arg_shapes, cond_arg_dtype_info)
+ cond = _expr.Call(cond_func, cond_args).astype("bool")
+ cond_shape = get_const_tuple(_infer_type(cond).checked_type.shape)
+ if len(cond_shape) > 0:
+ assert len(cond_shape) == 1 and cond_shape[0] == 1, "Condition is not scalar"
+ cond = _op.take(cond, _expr.const(1, "int"))
+
+ sb = _scope_builder.ScopeBuilder()
+ with sb.if_scope(cond):
+ then_arg_shapes = [arg.type_annotation.shape for arg in then_args]
+ then_arg_dtype_info = [arg.type_annotation.dtype for arg in then_args]
+ then_func = _from_mxnet_impl(subgraphs[1], then_arg_shapes, then_arg_dtype_info)
+ sb.ret(_expr.Call(then_func, then_args))
+ with sb.else_scope():
+ else_arg_shapes = [arg.type_annotation.shape for arg in else_args]
+ else_arg_dtype_info = [arg.type_annotation.dtype for arg in else_args]
+ else_func = _from_mxnet_impl(subgraphs[2], else_arg_shapes, else_arg_dtype_info)
+ sb.ret(_expr.Call(else_func, else_args))
+ func = _expr.Function(input_args, sb.get())
+ ret = _expr.Call(func, inputs)
+ if num_outputs > 1:
+ ret = _expr.TupleWrapper(ret, num_outputs)
+ return ret
+
# Note: due to attribute conversion constraint
# ops in the identity set must be attribute free
# NLP
"RNN" : _mx_rnn_layer,
"_rnn_param_concat" : _mx_rnn_param_concat,
+ # control flow
+ "_cond" : _mx_cond,
# Depricated:
"Crop" : _mx_crop_like,
# List of missing operators that are present in NNVMv1
Converted relay Function
"""
assert symbol is not None
- jgraph = json.loads(symbol.tojson())
+ if isinstance(symbol, dict):
+ jgraph = symbol
+ else:
+ jgraph = json.loads(symbol.tojson())
jnodes = jgraph["nodes"]
node_map = {}
+ shape_idx = 0
for nid, node in enumerate(jnodes):
children = [node_map[e[0]][e[1]] for e in node["inputs"]]
node_name = node["name"]
op_name = node["op"]
if op_name == "null":
- shape = shape_dict[node_name] if node_name in shape_dict else None
+ if isinstance(shape_dict, dict):
+ shape = shape_dict[node_name] if node_name in shape_dict else None
+ elif isinstance(shape_dict, (list, tuple)):
+ shape = shape_dict[shape_idx]
+ else:
+ raise ValueError("Unknown type of shape_dict: %s" + type(shape_dict))
if isinstance(dtype_info, dict):
dtype = dtype_info[node_name] if node_name in dtype_info else "float32"
+ elif isinstance(dtype_info, (list, tuple)):
+ dtype = dtype_info[shape_idx]
else:
dtype = dtype_info
+ if isinstance(shape_dict, (list, tuple)):
+ shape_idx += 1
node_map[nid] = [_expr.var(node_name, shape=shape, dtype=dtype)]
elif op_name in _convert_map:
- res = _convert_map[op_name](children, attrs)
+ if op_name in ['_cond', '_foreach', '_while_loop']:
+ subgraphs = node['subgraphs']
+ res = _convert_map[op_name](children, attrs, subgraphs)
+ else:
+ res = _convert_map[op_name](children, attrs)
if res is None:
# defer conversion, used in RNN state initialization
res = [node]
verify(data_shape=(1, 8, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2)
verify(data_shape=(20, 8, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2)
+def test_forward_cond():
+ def verify(a_np, b_np):
+ a_nd, b_nd = mx.nd.array(a_np), mx.nd.array(b_np)
+ pred = a_nd * b_nd < 5
+ then_func = lambda: (a_nd + 5) * (b_nd + 5)
+ else_func = lambda: (a_nd - 5) * (b_nd - 5)
+ ref_res = mx.nd.contrib.cond(pred, then_func, else_func)
+
+ a_sym, b_sym = mx.sym.var("a"), mx.sym.var("b")
+ pred = a_sym * b_sym < 5
+ then_func = lambda: (a_sym + 5) * (b_sym + 5)
+ else_func = lambda: (a_sym - 5) * (b_sym - 5)
+ mx_sym = mx.sym.contrib.cond(pred, then_func, else_func)
+
+ shape_dict = {"a": a_np.shape, "b": b_np.shape}
+ mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
+ for target, ctx in ctx_list():
+ for kind in ["debug", "vm"]:
+ intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+ op_res = intrp.evaluate()(a_np, b_np)
+ tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3)
+
+ verify(np.asarray([1.0], 'float32'), np.asarray([2.0],'float32'))
+ verify(np.asarray([4.0], 'float32'), np.asarray([3.0],'float32'))
+
if __name__ == '__main__':
test_forward_mlp()
test_forward_one_hot()
test_forward_convolution()
test_forward_deconvolution()
+ test_forward_cond()