[KERAS]RepeatVector, Conv3DTranspose op support added (#5833)
authorSiju Samuel <siju.samuel@huawei.com>
Thu, 18 Jun 2020 00:50:33 +0000 (06:20 +0530)
committerGitHub <noreply@github.com>
Thu, 18 Jun 2020 00:50:33 +0000 (09:50 +0900)
python/tvm/relay/frontend/keras.py
tests/python/frontend/keras/test_forward.py

index ef76eb6..6972fb7 100644 (file)
@@ -336,25 +336,28 @@ def _convert_convolution3d(inexpr, keras_layer, etab):
               'in frontend Keras.'
         raise tvm.error.OpAttributeUnImplemented(msg.format(etab.data_layout))
 
+    is_deconv = type(keras_layer).__name__ == 'Conv3DTranspose'
+
+    if is_deconv:
+        kernel_d, kernel_h, kernel_w, n_filters, _ = weight.shape
+        if kernel_layout == 'OIDHW':
+            weight = weight.transpose([4, 3, 2, 0, 1])
+    else:
+        kernel_d, kernel_h, kernel_w, _, n_filters = weight.shape
+
     dilation_rate = keras_layer.dilation_rate
     if isinstance(dilation_rate, (list, tuple)):
         dilation = [dilation_rate[0], dilation_rate[1], dilation_rate[2]]
     else:
         dilation = [dilation_rate, dilation_rate, dilation_rate]
 
-    kernel_d1 = weight.shape[0]
-    kernel_d2 = weight.shape[1]
-    kernel_d3 = weight.shape[2]
-    # in_channels = weight.shape[3]
-    n_filters = weight.shape[4]
-
-    dilated_kernel_d1 = (kernel_d1 - 1) * dilation[0] + 1
-    dilated_kernel_d2 = (kernel_d2 - 1) * dilation[1] + 1
-    dilated_kernel_d3 = (kernel_d3 - 1) * dilation[2] + 1
-    stride_d1, stride_d2, stride_d3 = keras_layer.strides
+    dilated_kernel_d = (kernel_d - 1) * dilation[0] + 1
+    dilated_kernel_h = (kernel_h - 1) * dilation[1] + 1
+    dilated_kernel_w = (kernel_w - 1) * dilation[2] + 1
+    stride_d, stride_h, stride_w = keras_layer.strides
     params = {'weight': etab.new_const(weight),
-              'kernel_size': [kernel_d1, kernel_d2, kernel_d3],
-              'strides': [stride_d1, stride_d2, stride_d3],
+              'kernel_size': [kernel_d, kernel_h, kernel_w],
+              'strides': [stride_d, stride_h, stride_w],
               'dilation': dilation,
               'padding': [0, 0, 0],
               'data_layout': etab.data_layout,
@@ -365,18 +368,21 @@ def _convert_convolution3d(inexpr, keras_layer, etab):
         pass
     # calculate the padding values
     elif keras_layer.padding == 'same':
-        in_d1 = keras_layer.input_shape[1]
-        in_d2 = keras_layer.input_shape[2]
-        in_d3 = keras_layer.input_shape[3]
-        pad_d1 = _get_pad_pair(in_d1, dilated_kernel_d1, stride_d1)
-        pad_d2 = _get_pad_pair(in_d2, dilated_kernel_d2, stride_d2)
-        pad_d3 = _get_pad_pair(in_d3, dilated_kernel_d3, stride_d3)
-        params['padding'] = [pad_d1[0], pad_d2[0], pad_d3[0], pad_d1[1], pad_d2[1], pad_d3[1]]
+        in_d = keras_layer.input_shape[1]
+        in_h = keras_layer.input_shape[2]
+        in_w = keras_layer.input_shape[3]
+        pad_d = _get_pad_pair(in_d, dilated_kernel_d, stride_d)
+        pad_h = _get_pad_pair(in_h, dilated_kernel_h, stride_h)
+        pad_w = _get_pad_pair(in_w, dilated_kernel_w, stride_w)
+        params['padding'] = [pad_d[0], pad_h[0], pad_w[0], pad_d[1], pad_h[1], pad_w[1]]
     else:
         msg = 'Padding with {} is not supported for operator Convolution3D ' \
               'in frontend Keras.'
         raise tvm.error.OpAttributeUnImplemented(msg.format(keras_layer.padding))
-    out = _op.nn.conv3d(data=inexpr, **params)
+    if is_deconv:
+        out = _op.nn.conv3d_transpose(data=inexpr, **params)
+    else:
+        out = _op.nn.conv3d(data=inexpr, **params)
 
     channel_axis = -1 if etab.data_layout == "NDHWC" else 1
     if keras_layer.use_bias:
@@ -849,6 +855,16 @@ def _convert_gru(inexpr, keras_layer, etab):
     return [output, output]
 
 
+def _convert_repeat_vector(inexpr, keras_layer, _):
+    input_shape = list(keras_layer.input_shape)
+    repeats = keras_layer.n
+    out_shape = [-1, repeats] + input_shape[1:]
+    out = _op.repeat(inexpr, repeats=repeats, axis=0)
+    out = _op.reshape(out, out_shape)
+
+    return out
+
+
 def _default_skip(inexpr, keras_layer, _): # pylint: disable=unused-argument
     """Layers that can be skipped because they are train time only."""
     return inexpr
@@ -898,7 +914,7 @@ _convert_map = {
     # 'Conv1D'                 : _convert_convolution1d,
 
     'Conv3D'                   : _convert_convolution3d,
-    # 'Conv3DTranspose'        : _convert_convolution3d,
+    'Conv3DTranspose'          : _convert_convolution3d,
     # 'SeparableConv3D'        : _convert_convolution3d,
     'MaxPooling3D'             : _convert_pooling3d,
     'AveragePooling3D'         : _convert_pooling3d,
@@ -919,7 +935,7 @@ _convert_map = {
     'Dot'                      : _convert_merge,
     'Permute'                  : _convert_permute,
     'Embedding'                : _convert_embedding,
-    # 'RepeatVector'           : _convert_repeat_vector,
+    'RepeatVector'             : _convert_repeat_vector,
 
     'InputLayer'               : _default_skip,
     'Dropout'                  : _default_skip,
index 9b963c3..8ddae96 100644 (file)
@@ -422,6 +422,31 @@ class TestKeras:
             keras_model = keras.models.Model(data, x)
             verify_keras_frontend(keras_model, layout='NDHWC')
 
+
+    def test_forward_conv3d_transpose(self, keras):
+        data = keras.layers.Input(shape=(32, 32, 32, 3))
+        conv_funcs = [keras.layers.Conv3DTranspose(filters=10,
+                                          kernel_size=(3, 3, 3),
+                                          strides=(2, 2, 2),
+                                          padding='same'),
+                      keras.layers.Conv3DTranspose(filters=10,
+                                          kernel_size=(1, 1, 1),
+                                          dilation_rate=(1, 1, 1),
+                                          padding='same'),
+                      keras.layers.Conv3DTranspose(filters=1,
+                                          kernel_size=(3, 3, 3),
+                                          padding='valid',
+                                          use_bias=False),
+                      keras.layers.Conv3DTranspose(filters=10,
+                                          kernel_size=(2, 2, 2),
+                                          padding='valid'),
+                    ]
+        for conv_func in conv_funcs:
+            x = conv_func(data)
+            keras_model = keras.models.Model(data, x)
+            verify_keras_frontend(keras_model, layout='NDHWC')
+
+
     def test_forward_pool3d(self, keras):
         data = keras.layers.Input(shape=(32, 32, 32, 1))
         pool_funcs = [# maxpool
@@ -483,6 +508,26 @@ class TestKeras:
         keras_model = keras.models.Model(data, x)
         verify_keras_frontend(keras_model, need_transpose=False)
 
+
+    def test_forward_repeat_vector(self, keras):
+        data = keras.layers.Input(shape=(5,), dtype="float32")
+        x = keras.layers.Dense(6)(data)
+        x = keras.layers.RepeatVector(2)(x)
+
+        keras_model = keras.models.Model(data, x)
+        verify_keras_frontend(keras_model, need_transpose=False)
+
+        data = keras.layers.Input(shape=(10,), dtype="float32")
+        x = keras.layers.RepeatVector(3)(data)
+        keras_model = keras.models.Model(data, x)
+        verify_keras_frontend(keras_model, need_transpose=False)
+
+        data = keras.layers.Input(shape=(4,), dtype="float32")
+        x = keras.layers.RepeatVector(1)(data)
+        keras_model = keras.models.Model(data, x)
+        verify_keras_frontend(keras_model, need_transpose=False)
+
+
     def test_forward_global_pool3d(self, keras):
         data = keras.layers.Input(shape=(32, 32, 32, 1))
         pool_funcs = [# global maxpool
@@ -523,8 +568,10 @@ if __name__ == '__main__':
         sut.test_forward_mobilenet(keras=k)
         sut.test_forward_mobilenet(keras=k, layout='NHWC')
         sut.test_forward_conv3d(keras=k)
+        sut.test_forward_conv3d_transpose(keras=k)
         sut.test_forward_pool3d(keras=k)
         sut.test_forward_global_pool3d(keras=k)
         sut.test_forward_upsample3d(keras=k)
         sut.test_forward_zero_padding3d(keras=k)
         sut.test_forward_embedding(keras=k)
+        sut.test_forward_repeat_vector(keras=k)