[MXNET]MaxPool3d and AvgPool3d Ops support added (#5614)
authorSamuel <siju.samuel@huawei.com>
Wed, 20 May 2020 01:09:44 +0000 (06:39 +0530)
committerGitHub <noreply@github.com>
Wed, 20 May 2020 01:09:44 +0000 (10:09 +0900)
python/tvm/relay/frontend/mxnet.py
tests/python/frontend/mxnet/test_forward.py

index e6384f7..edf6680 100644 (file)
@@ -318,6 +318,34 @@ def _mx_pooling(inputs, attrs):
             new_attrs["count_include_pad"] = attrs.get_bool("count_include_pad", True)
         return new_op(inputs[0], **new_attrs)
 
+    def _pool3d(new_op, is_avg):
+        kernel_size = attrs.get_int_tuple("kernel")
+        if len(kernel_size) != 3:
+            raise tvm.error.OpAttributeInvalid(
+                'Only 3D kernels are supported for operator Pool3D.')
+        new_attrs = {}
+        new_attrs["pool_size"] = kernel_size
+        new_attrs["strides"] = attrs.get_int_tuple("stride", (1, 1, 1))
+        new_attrs["padding"] = attrs.get_int_tuple("pad", (0, 0, 0))
+        new_attrs["ceil_mode"] = (attrs.get_str("pooling_convention", "valid") == "full")
+        if is_avg:
+            new_attrs["count_include_pad"] = attrs.get_bool("count_include_pad", True)
+        return new_op(inputs[0], **new_attrs)
+
+    #3D pooling
+    if len(_infer_shape(inputs[0])) == 5:
+        if pool_type == "max":
+            if global_pool:
+                return _op.nn.global_max_pool3d(inputs[0])
+            return _pool3d(_op.nn.max_pool3d, False)
+        if pool_type == "avg":
+            if global_pool:
+                return _op.nn.global_avg_pool3d(inputs[0])
+            return _pool3d(_op.nn.avg_pool3d, True)
+        raise tvm.error.OpNotImplemented(
+            'Operator {} Pooling is not supported for frontend MXNet.' \
+                .format(pool_type.capitalize()))
+    #2D Pooling
     if pool_type == "max":
         if global_pool:
             return _op.nn.global_max_pool2d(inputs[0])
@@ -327,7 +355,8 @@ def _mx_pooling(inputs, attrs):
             return _op.nn.global_avg_pool2d(inputs[0])
         return _pool2d(_op.nn.avg_pool2d, True)
     raise tvm.error.OpNotImplemented(
-        'Operator {} Pooling is not supported for frontend MXNet.'.format(pool_type.capitalize()))
+        'Operator {} Pooling is not supported for frontend MXNet.' \
+            .format(pool_type.capitalize()))
 
 
 def _mx_adaptive_avg_pooling(inputs, attrs):
index 8113271..6e8acde 100644 (file)
@@ -179,6 +179,14 @@ def test_forward_pooling():
     mx_sym = mx.sym.Pooling(data, kernel=(3, 3), pad=(1, 1), pool_type='max')
     verify_mxnet_frontend_impl(mx_sym, (1, 20, 8, 8), (1, 20, 8, 8))
 
+def test_forward_pooling3d():
+    data = mx.sym.var('data')
+    mx_sym = mx.sym.Pooling(data, kernel=(3, 3, 3), pad=(1, 1, 1), pool_type='avg')
+    verify_mxnet_frontend_impl(mx_sym, (1, 20, 8, 8, 8), (1, 20, 8, 8, 8))
+
+    mx_sym = mx.sym.Pooling(data, kernel=(3, 3, 3), pad=(1, 1, 1), pool_type='max')
+    verify_mxnet_frontend_impl(mx_sym, (1, 20, 8, 8, 8), (1, 20, 8, 8, 8))
+
 def test_forward_adaptive_pooling():
     data = mx.sym.var('data')
     mx_sym = mx.sym.contrib.AdaptiveAvgPooling2D(data, output_size=(1,))
@@ -1123,6 +1131,7 @@ if __name__ == '__main__':
     test_forward_pad()
     test_forward_slice()
     test_forward_pooling()
+    test_forward_pooling3d()
     test_forward_adaptive_pooling()
     test_forward_lrn()
     test_forward_ones()