[RELAY][FRONTEND][CAFFE2] add Mul and ConvTranspose operator (#5302)
authorHuacong Yang <will.yang@rock-chips.com>
Fri, 10 Apr 2020 21:46:03 +0000 (05:46 +0800)
committerGitHub <noreply@github.com>
Fri, 10 Apr 2020 21:46:03 +0000 (14:46 -0700)
python/tvm/relay/frontend/caffe2.py

index f4fcd92..8a5803f 100644 (file)
@@ -172,6 +172,12 @@ class Add(Elemwise):
     name = 'add'
 
 
+class Mul(Elemwise):
+    """ Operator converter for Mul.
+    """
+    name = 'multiply'
+
+
 class Pool(Caffe2OpConverter):
     """ A helper class for pool op converters.
     """
@@ -233,6 +239,33 @@ class Conv(Caffe2OpConverter):
         return out
 
 
+class ConvTranspose(Caffe2OpConverter):
+    """ Operator converter for ConvTranspose.
+    """
+
+    @classmethod
+    def _impl(cls, inputs, args, params):
+        # get number of channels
+        channels = infer_channels(inputs[1], True)
+        args['channels'] = channels
+        _clean_up_pool_args(args)
+        out = AttrCvt(
+            op_name=dimension_picker('conv', '_transpose'),
+            transforms={
+                'kernel_shape': 'kernel_size',
+                'pads': ('padding', (0, 0), revert_caffe2_pad),
+                'dilations': ('dilation', (1, 1)),
+                'order': ('data_layout', ("NCHW"), lambda x: x if isinstance(x, str) else x.decode('UTF-8')),
+            },
+            excludes=[],
+            ignores=_caffe2_internal_args,
+            custom_check=dimension_constraint())(inputs[:2], args, params)
+        use_bias = len(inputs) == 3
+        if use_bias:
+            out = _op.nn.bias_add(out, inputs[2])
+        return out
+
+
 class Concat(Caffe2OpConverter):
     """ Operator converter for Concat.
     """
@@ -353,12 +386,14 @@ def _get_convert_map():
         # caffe2 common operators
         'Add': Add.get_converter(),
         'Sum': Sum.get_converter(),
+        'Mul': Mul.get_converter(),
         'Softmax': Softmax.get_converter(),
 
         # nn
         'AveragePool': AveragePool.get_converter(),
         'MaxPool': MaxPool.get_converter(),
         'Conv': Conv.get_converter(),
+        'ConvTranspose': ConvTranspose.get_converter(),
         'Concat': Concat.get_converter(),
         'FC': FC.get_converter(),
         'SpatialBN': SpatialBN.get_converter(),