[Test] enable NHWC of `relay.testing.mobilenet` (#3886)
author黎明灰烬 <sakltian@yeah.net>
Thu, 5 Sep 2019 18:32:21 +0000 (02:32 +0800)
committerThierry Moreau <moreau@uw.edu>
Thu, 5 Sep 2019 18:32:21 +0000 (11:32 -0700)
* [Relay] enable NHWC of `relay.testing.mobilenet`

In this way, we can play around NHWC inside TVM regardless of
the frontends.

* [Test] test for NHWC of relay.testing.mobilenet

python/tvm/relay/testing/layers.py
python/tvm/relay/testing/mobilenet.py
tests/python/relay/benchmarking/benchmark_vm.py

index 4c263a1..e153a4e 100644 (file)
@@ -152,3 +152,33 @@ def dense_add_bias(data, weight=None, bias=None, units=None, **kwargs):
     data = relay.nn.dense(data, weight, units, **kwargs)
     data = relay.nn.bias_add(data, bias, axis=-1)
     return data
+
+def conv_kernel_layout(data_layout, is_depthwise=False):
+    """Map the data layout to corresponding kernel layout.
+
+    Arbitrary layout is not fully supported in TOPI yet.
+
+    Parameters
+    ----------
+    data_layout : str
+        The data_layout, can be 'NCHW', 'NHWC'.
+
+    is_depthwise : bool, optional
+        Whether the conv is a depthwise convolution.
+
+    Returns
+    -------
+    result : str
+        The corresponding kernel layout.
+    """
+    conv_layout_map = {
+        'NCHW': 'OIHW',
+        'NHWC': 'HWIO',
+    }
+    depthwise_conv_layout_map = {
+        'NCHW': 'OIHW',
+        'NHWC': 'HWOI',
+    }
+    mapping = depthwise_conv_layout_map if is_depthwise else conv_layout_map
+    assert data_layout in mapping, "Unknown data layout %s" % data_layout
+    return mapping[data_layout]
index 3b068c0..f76b0c2 100644 (file)
@@ -23,8 +23,9 @@ from tvm import relay
 from . import layers
 from .init import create_workload
 
+
 def conv_block(data, name, channels, kernel_size=(3, 3), strides=(1, 1),
-               padding=(1, 1), epsilon=1e-5):
+               padding=(1, 1), epsilon=1e-5, layout='NCHW'):
     """Helper function to construct conv_bn-relu"""
     # convolution + bn + relu
     conv = layers.conv2d(
@@ -33,7 +34,8 @@ def conv_block(data, name, channels, kernel_size=(3, 3), strides=(1, 1),
         kernel_size=kernel_size,
         strides=strides,
         padding=padding,
-        data_layout='NCHW',
+        data_layout=layout,
+        kernel_layout=layers.conv_kernel_layout(layout),
         name=name+'_conv')
     bn = layers.batch_norm_infer(data=conv, epsilon=epsilon, name=name + '_bn')
     act = relay.nn.relu(data=bn)
@@ -42,7 +44,7 @@ def conv_block(data, name, channels, kernel_size=(3, 3), strides=(1, 1),
 
 def separable_conv_block(data, name, depthwise_channels, pointwise_channels,
                          kernel_size=(3, 3), downsample=False, padding=(1, 1),
-                         epsilon=1e-5):
+                         epsilon=1e-5, layout='NCHW'):
     """Helper function to get a separable conv block"""
     if downsample:
         strides = (2, 2)
@@ -56,6 +58,8 @@ def separable_conv_block(data, name, depthwise_channels, pointwise_channels,
         kernel_size=kernel_size,
         strides=strides,
         padding=padding,
+        data_layout=layout,
+        kernel_layout=layers.conv_kernel_layout(layout, True),
         name=name+'_depthwise_conv1')
     bn1 = layers.batch_norm_infer(data=conv1, epsilon=epsilon, name=name+'_bn1')
     act1 = relay.nn.relu(data=bn1)
@@ -66,7 +70,8 @@ def separable_conv_block(data, name, depthwise_channels, pointwise_channels,
         kernel_size=(1, 1),
         strides=(1, 1),
         padding=(0, 0),
-        data_layout='NCHW',
+        data_layout=layout,
+        kernel_layout=layers.conv_kernel_layout(layout),
         name=name + '_conv2')
     bn2 = layers.batch_norm_infer(data=conv2, epsilon=epsilon, name=name+'_bn2')
     act2 = relay.nn.relu(data=bn2)
@@ -74,36 +79,45 @@ def separable_conv_block(data, name, depthwise_channels, pointwise_channels,
 
 
 def mobile_net(num_classes=1000, data_shape=(1, 3, 224, 224),
-               dtype='float32', alpha=1.0, is_shallow=False):
+               dtype='float32', alpha=1.0, is_shallow=False, layout='NCHW'):
     """Function to construct a MobileNet"""
     data = relay.var("data", shape=data_shape, dtype=dtype)
-    body = conv_block(data, 'conv_block_1', int(32*alpha), strides=(2, 2))
+    body = conv_block(data, 'conv_block_1', int(32*alpha), strides=(2, 2),
+                      layout=layout)
     body = separable_conv_block(body, 'separable_conv_block_1',
-                                int(32*alpha), int(64*alpha))
+                                int(32*alpha), int(64*alpha), layout=layout)
     body = separable_conv_block(body, 'separable_conv_block_2',
-                                int(64*alpha), int(128*alpha), downsample=True)
+                                int(64*alpha), int(128*alpha), downsample=True,
+                                layout=layout)
     body = separable_conv_block(body, 'separable_conv_block_3',
-                                int(128*alpha), int(128*alpha))
+                                int(128*alpha), int(128*alpha), layout=layout)
     body = separable_conv_block(body, 'separable_conv_block_4',
-                                int(128*alpha), int(256*alpha), downsample=True)
+                                int(128*alpha), int(256*alpha), downsample=True,
+                                layout=layout)
     body = separable_conv_block(body, 'separable_conv_block_5',
-                                int(256*alpha), int(256*alpha))
+                                int(256*alpha), int(256*alpha), layout=layout)
     body = separable_conv_block(body, 'separable_conv_block_6',
-                                int(256*alpha), int(512*alpha), downsample=True)
+                                int(256*alpha), int(512*alpha), downsample=True,
+                                layout=layout)
     if is_shallow:
         body = separable_conv_block(body, 'separable_conv_block_7',
-                                    int(512*alpha), int(1024*alpha), downsample=True)
+                                    int(512*alpha), int(1024*alpha),
+                                    downsample=True, layout=layout)
         body = separable_conv_block(body, 'separable_conv_block_8',
-                                    int(1024*alpha), int(1024*alpha), downsample=True)
+                                    int(1024*alpha), int(1024*alpha),
+                                    downsample=True, layout=layout)
     else:
         for i in range(7, 12):
             body = separable_conv_block(body, 'separable_conv_block_%d' % i,
-                                        int(512*alpha), int(512*alpha))
+                                        int(512*alpha), int(512*alpha),
+                                        layout=layout)
         body = separable_conv_block(body, 'separable_conv_block_12',
-                                    int(512*alpha), int(1024*alpha), downsample=True)
+                                    int(512*alpha), int(1024*alpha),
+                                    downsample=True, layout=layout)
         body = separable_conv_block(body, 'separable_conv_block_13',
-                                    int(1024*alpha), int(1024*alpha))
-    pool = relay.nn.global_avg_pool2d(data=body)
+                                    int(1024*alpha), int(1024*alpha),
+                                    layout=layout)
+    pool = relay.nn.global_avg_pool2d(data=body, layout=layout)
     flatten = relay.nn.batch_flatten(data=pool)
     weight = relay.var('fc_weight')
     fc = relay.nn.dense(data=flatten, weight=weight, units=num_classes)
@@ -111,7 +125,8 @@ def mobile_net(num_classes=1000, data_shape=(1, 3, 224, 224),
     return relay.Function(relay.analysis.free_vars(softmax), softmax)
 
 
-def get_workload(batch_size=1, num_classes=1000, image_shape=(3, 224, 224), dtype='float32'):
+def get_workload(batch_size=1, num_classes=1000, image_shape=(3, 224, 224),
+                 dtype='float32', layout='NCHW'):
     """Get benchmark workload for mobilenet
 
     Parameters
@@ -123,11 +138,15 @@ def get_workload(batch_size=1, num_classes=1000, image_shape=(3, 224, 224), dtyp
         Number of classes
 
     image_shape : tuple, optional
-        The input image shape
+        The input image shape, cooperate with layout
 
     dtype : str, optional
         The data type
 
+    layout : str, optional
+        The data layout of image_shape and the operators
+        cooperate with image_shape
+
     Returns
     -------
     mod : tvm.relay.Module
@@ -138,5 +157,6 @@ def get_workload(batch_size=1, num_classes=1000, image_shape=(3, 224, 224), dtyp
     """
     data_shape = tuple([batch_size] + list(image_shape))
     net = mobile_net(num_classes=num_classes, data_shape=data_shape,
-                     dtype=dtype, alpha=1.0, is_shallow=False)
+                     dtype=dtype, alpha=1.0, is_shallow=False,
+                     layout=layout)
     return create_workload(net)
index 0b1a5f5..b1d8b9c 100644 (file)
@@ -114,6 +114,13 @@ def test_mobilenet():
     mod, params = testing.mobilenet.get_workload(batch_size=1)
     benchmark_execution(mod, params)
 
+# TODO: enable when the low building performance (several minutes) fixed.
+def test_mobilenet_nhwc():
+    image_shape = (1, 224, 224, 3)
+    mod, params = testing.mobilenet.get_workload(batch_size=1,
+                                                 image_shape=image_shape[1:],
+                                                 layout='NHWC')
+    benchmark_execution(mod, params, measure=False, data_shape=image_shape)
 
 def test_densenet():
     mod, params = testing.densenet.get_workload(batch_size=1)