[Relay][Keras] Dot (#3668)
authorYong Wu <ywu118@alumni.jh.edu>
Wed, 28 Aug 2019 06:16:48 +0000 (23:16 -0700)
committerMORITA Kazutaka <morita.kazutaka@gmail.com>
Wed, 28 Aug 2019 06:16:48 +0000 (15:16 +0900)
* [Relay][Keras] Dot

* fix reshape

* fix comments

python/tvm/relay/frontend/keras.py
tests/python/frontend/keras/test_forward.py

index 8be3d22..635a600 100644 (file)
@@ -156,7 +156,26 @@ def _convert_advanced_activation(inexpr, keras_layer, etab):
 def _convert_merge(inexpr, keras_layer, _):
     merge_type = type(keras_layer).__name__
     ret = inexpr[0]
-    if merge_type == 'Subtract':
+    if merge_type == 'Dot':
+        axes = keras_layer.axes
+        if isinstance(keras_layer.axes, int):
+            axes = [keras_layer.axes, keras_layer.axes]
+        if isinstance(axes, list):
+            if len(axes) != 2:
+                raise tvm.error.OpAttributeUnimplemented(
+                    'Dot with axes {} is not supported.'.format(keras_layer.axes))
+            for i, axis in enumerate(axes):
+                if axis not in [1, 2]:
+                    raise tvm.error.OpAttributeUnimplemented(
+                        'Dot with axes {} is not supported.'.format(keras_layer.axes))
+                if axes[i] == 2:
+                    inexpr[i] = _op.transpose(inexpr[i], axes=[0, 2, 1])
+        else:
+            raise tvm.error.OpAttributeUnImplemented(
+                'Dot with axes {} is not supported.'.format(keras_layer.axes))
+        ret_dot = _op.nn.batch_matmul(inexpr[0], inexpr[1])
+        ret = _op.transpose(ret_dot, axes=[0, 2, 1])
+    elif merge_type == 'Subtract':
         assert len(inexpr) == 2, "Subtract merge takes 2 inputs."
         ret = _op.subtract(ret, inexpr[1])
     elif merge_type in ['Add', 'Multiply', 'Maximum']:
@@ -635,7 +654,7 @@ _convert_map = {
 
     'Average'                : _convert_merge,
     'Maximum'                : _convert_merge,
-    'Dot'                    : _convert_merge,
+    'Dot'                    : _convert_merge,
     'Permute'                : _convert_permute,
     # 'Embedding'              : _convert_embedding,
     # 'RepeatVector'           : _convert_repeat_vector,
index 9996bb7..4b71cb6 100644 (file)
@@ -84,13 +84,26 @@ def test_forward_merge():
                    keras.layers.Average(),
                    keras.layers.Concatenate()]
     for merge_func in merge_funcs:
-        if isinstance(merge_func, keras.layers.merge.Subtract):
+        if isinstance(merge_func, (keras.layers.merge.Subtract, keras.layers.merge.Dot)):
             out = merge_func([x, y])
         else:
             out = merge_func([x, y, z])
         keras_model = keras.models.Model(data, out)
         verify_keras_frontend(keras_model)
 
+def test_forward_merge_dot():
+    data1 = keras.layers.Input(shape=(2, 2))
+    data2 = keras.layers.Input(shape=(2, 2))
+    merge_funcs = [keras.layers.Dot(axes=[1, 2]),
+                   keras.layers.Dot(axes=[2, 1]),
+                   keras.layers.Dot(axes=[1, 1]),
+                   keras.layers.Dot(axes=[2, 2]),
+                   keras.layers.Dot(axes=1),
+                   keras.layers.Dot(axes=2)]
+    for merge_func in merge_funcs:
+        out = merge_func([data1, data2])
+        keras_model = keras.models.Model([data1, data2], out)
+        verify_keras_frontend(keras_model)
 
 def test_forward_activations():
     data = keras.layers.Input(shape=(32, 32, 3))
@@ -281,6 +294,7 @@ def test_forward_mobilenet():
 
 if __name__ == '__main__':
     test_forward_merge()
+    test_forward_merge_dot()
     test_forward_activations()
     test_forward_dense()
     test_forward_permute()