[COREML]multiple output support, reshape, split ops added (#6296)
authorSiju Samuel <siju.samuel@huawei.com>
Fri, 21 Aug 2020 04:07:04 +0000 (09:37 +0530)
committerGitHub <noreply@github.com>
Fri, 21 Aug 2020 04:07:04 +0000 (21:07 -0700)
* [COREML]multiple output support, reshape, split ops added

* Review comments addressed

python/tvm/relay/frontend/coreml.py
tests/python/frontend/coreml/test_forward.py

index b8ef1f2..65f1c2a 100644 (file)
@@ -411,6 +411,14 @@ def _ReduceLayerParams(op, inexpr, etab):
         raise tvm.error.OpAttributeUnImplemented(msg.format(mode))
 
 
+def _ReshapeLayerParams(op, inexpr, etab):
+    return _op.reshape(inexpr, op.targetShape)
+
+
+def _SplitLayerParams(op, inexpr, etab):
+    return _op.split(inexpr, op.nOutputs, axis=-3)
+
+
 _convert_map = {
     'NeuralNetworkMeanImage': _NeuralNetworkMeanImage,
     'NeuralNetworkImageScaler': _NeuralNetworkImageScaler,
@@ -435,6 +443,8 @@ _convert_map = {
     'MinLayerParams': _MinLayerParams,
     'UnaryFunctionLayerParams': _UnaryFunctionLayerParams,
     'ReduceLayerParams': _ReduceLayerParams,
+    'ReshapeLayerParams': _ReshapeLayerParams,
+    'SplitLayerParams': _SplitLayerParams,
 }
 
 # SAME padding: https://www.tensorflow.org/api_guides/python/nn
@@ -464,7 +474,7 @@ def get_pad_value(data, kernel, stride):
     return pad_before, pad_after
 
 
-def coreml_op_to_relay(op, inname, outname, etab):
+def coreml_op_to_relay(op, inname, outnames, etab):
     """Convert coreml layer to a Relay expression and update the expression table.
 
     Parameters
@@ -474,7 +484,7 @@ def coreml_op_to_relay(op, inname, outname, etab):
     inname : str or list of str
         Name of the input Relay expression.
 
-    outname : str
+    outnames : str or list of str
         Name of the output Relay expression.
 
     etab : relay.frontend.common.ExprTable
@@ -488,9 +498,17 @@ def coreml_op_to_relay(op, inname, outname, etab):
         insym = etab.get_expr(inname)
     else:
         insym = [etab.get_expr(i) for i in inname]
-    ret = _convert_map[classname](op, insym, etab)
-    if outname:
-        etab.set_expr(outname, ret, force_override=True)
+    outs = _convert_map[classname](op, insym, etab)
+
+    if outnames:
+        if isinstance(outnames, _base.string_types) or len(outnames) == 1:
+            outname = outnames if isinstance(outnames, _base.string_types) else outnames[0]
+            etab.set_expr(outname, outs, force_override=True)
+        else:
+            # the number of ouputs from model op and tvm relay must be same
+            assert len(outnames) == len(outs)
+            for outname, out in zip(outnames, outs):
+                etab.set_expr(outname, out, force_override=True)
 
 
 def from_coreml(model, shape=None):
@@ -550,16 +568,18 @@ def from_coreml(model, shape=None):
     for l in cc.layers:
         layertype = l.WhichOneof('layer')
         layerop = getattr(l, layertype)
-        assert len(l.output) == 1
         if len(l.input) == 1:
-            coreml_op_to_relay(layerop, l.input[0], l.output[0], etab)
+            coreml_op_to_relay(layerop, l.input[0], l.output, etab)
         else:
-            coreml_op_to_relay(layerop, list(l.input), l.output[0], etab)
+            coreml_op_to_relay(layerop, list(l.input), l.output, etab)
 
     outexpr = [etab.get_expr(o.name) if o.name in etab.exprs else _expr.var(o.name)
                for o in spec.description.output]
-    # for now return first output
-    outexpr = outexpr[0]
+
+    # check there are multiple outputs in the model and all are there in etab
+    multi_out = all([bool(o.name in etab.exprs) for o in spec.description.output])
+    outexpr = _expr.Tuple(outexpr) if multi_out else outexpr[0]
+
     func = _function.Function(analysis.free_vars(outexpr), outexpr)
     params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()}
     return IRModule.from_expr(func), params
index cbae5f3..5ae7a6c 100644 (file)
@@ -586,6 +586,68 @@ def test_forward_reduce():
                 _verify_reduce(dshape, "argmax", axis, np.argmax, dtype='int32')
 
 
+def verify_reshape(input_dim, target_shape, mode):
+    dtype = 'float32'
+
+    a_np = np.random.uniform(-100.0, 100.0, size=input_dim).astype(dtype)
+    ref_val = np.reshape(a_np, target_shape)
+
+    inputs = [('input', datatypes.Array(*input_dim))]
+    output = [('output', datatypes.Array(*ref_val.shape))]
+    builder = NeuralNetworkBuilder(inputs, output)
+    builder.add_reshape(name="reshape",
+                       input_name='input',
+                       output_name='output',
+                       target_shape=target_shape,
+                       mode=mode)
+
+    model = cm.models.MLModel(builder.spec)
+    for target, ctx in ctx_list():
+        out = run_tvm_graph(model, target, ctx, [a_np],
+                            ['input'], ref_val.shape, dtype)
+        tvm.testing.assert_allclose(out, ref_val, rtol=1e-5)
+
+
+def test_forward_reshape():
+    for mode in [0, 1]:
+        verify_reshape((20,), (1, 2, 2, 5), mode)
+        verify_reshape((1, 3, 20, 20), (1, 12, 10, 10), mode)
+
+
+def verify_split(input_dim, nOutputs):
+    dtype = 'float32'
+
+    a_np = np.random.uniform(-100.0, 100.0, size=input_dim).astype(dtype)
+    ref_val = np.split(a_np, nOutputs, axis=-3)
+
+    inputs = [('input', datatypes.Array(*input_dim))]
+
+    output_names = []
+    outputs = []
+    output_shapes = []
+    for i, out in enumerate(ref_val):
+        output_name = "output" + str(i)
+        output_names = output_names + [output_name]
+        outputs = outputs + [(output_name, datatypes.Array(*out.shape))]
+        output_shapes = output_shapes + [out.shape]
+
+    builder = NeuralNetworkBuilder(inputs, outputs)
+    builder.add_split(name="split",
+                      input_name='input',
+                      output_names=output_names)
+
+    model = cm.models.MLModel(builder.spec)
+    for target, ctx in ctx_list():
+        out = run_tvm_graph(model, target, ctx, [a_np],
+                            ['input'], output_shapes, [dtype] * len(output_shapes))
+        tvm.testing.assert_allclose(out, ref_val, rtol=1e-5)
+
+
+def test_forward_split():
+    verify_split((1, 4, 4, 4,), 2)
+    verify_split((1, 3, 30, 20,), 3)
+
+
 def verify_image_scaler(input_dim, blue_bias=0.0, green_bias=0.0, red_bias=0.0, image_scale=1.0):
     dtype = 'float32'
     a_np = np.random.uniform(size=input_dim).astype(dtype)
@@ -664,6 +726,8 @@ if __name__ == '__main__':
     test_forward_min()
     test_forward_unary()
     test_forward_reduce()
+    test_forward_reshape()
+    test_forward_split()
     test_mobilenet_checkonly()
     test_resnet50_checkonly()
     test_forward_image_scaler()