Add more cases to keras _convert_reshape (#3846)
authorAlexander Pivovarov <pivovaa@amazon.com>
Sat, 31 Aug 2019 04:30:59 +0000 (21:30 -0700)
committerZhi <5145158+zhiics@users.noreply.github.com>
Sat, 31 Aug 2019 04:30:59 +0000 (21:30 -0700)
python/tvm/relay/frontend/keras.py
tests/python/frontend/keras/test_forward.py

index 635a600..e5b70a4 100644 (file)
@@ -490,11 +490,26 @@ def _convert_concat(inexpr, keras_layer, _):
 
 def _convert_reshape(inexpr, keras_layer, _):
     _check_data_format(keras_layer)
-    ch = keras_layer.input_shape[-1]
-    assert ch == keras_layer.target_shape[-1], \
-        "Only supports last dimension in target shape being equal to " \
-        "the channel number of input tensor."
-    shape = (-1, ch) + keras_layer.target_shape[:-1]
+    inshape = keras_layer.input_shape # includes batch
+    tshape = keras_layer.target_shape # no batch
+    if len(inshape) == 3 and len(tshape) == 1:
+        # (?, a, b) -> (-1, ab)
+        shape = (-1, tshape[0])
+    elif len(inshape) in [2, 3] and len(tshape) == 2:
+        # (?, cc) -> (-1, c, c)
+        # (?, a, b) -> (-1, c, c)
+        assert tshape[0] == tshape[1], \
+            "Only supports square target shapes, but got {}".format(tshape)
+        shape = (-1, ) + tshape
+    else:
+        # (?, h, w, c) -> (-1, c, H, W)
+        # (?, h, w, c) -> (-1, c, hw)
+        # (?, hw, c) -> (-1, c, h, w)
+        ch = inshape[-1]
+        assert ch == tshape[-1], \
+            "Only supports last dimension in target shape being equal to " \
+            "the channel number of input tensor."
+        shape = (-1, ch) + tshape[:-1]
     return _op.reshape(inexpr, newshape=shape)
 
 
index 4b71cb6..3cc4ac5 100644 (file)
@@ -193,10 +193,36 @@ def test_forward_upsample(interpolation='nearest'):
 
 
 def test_forward_reshape():
+    # input_shape len is 3, target_shape len is 3
     data = keras.layers.Input(shape=(32, 32, 3))
-    x = keras.layers.Reshape(target_shape=(32, 32, 3))(data)
+    x = keras.layers.Reshape(target_shape=(16, 64, 3))(data)
     keras_model = keras.models.Model(data, x)
     verify_keras_frontend(keras_model)
+    # input_shape len is 3, target_shape len is 2
+    data = keras.layers.Input(shape=(32, 8, 3))
+    x = keras.layers.Reshape(target_shape=(256, 3))(data)
+    keras_model = keras.models.Model(data, x)
+    verify_keras_frontend(keras_model)
+    # input_shape len is 2, target_shape len is 3
+    data = keras.layers.Input(shape=(256, 3))
+    x = keras.layers.Reshape(target_shape=(8, 32, 3))(data)
+    keras_model = keras.models.Model(data, x)
+    verify_keras_frontend(keras_model)
+    # input_shape len is 2, target_shape len is 1
+    data = keras.layers.Input(shape=(2, 8))
+    x = keras.layers.Reshape(target_shape=(16,))(data)
+    keras_model = keras.models.Model(data, x)
+    verify_keras_frontend(keras_model, need_transpose=False)
+    # input_shape len is 1, target_shape len is 2
+    data = keras.layers.Input(shape=(16,))
+    x = keras.layers.Reshape(target_shape=(4, 4))(data)
+    keras_model = keras.models.Model(data, x)
+    verify_keras_frontend(keras_model, need_transpose=False)
+    # input_shape len is 2, target_shape len is 2
+    data = keras.layers.Input(shape=(2, 8))
+    x = keras.layers.Reshape(target_shape=(4, 4))(data)
+    keras_model = keras.models.Model(data, x)
+    verify_keras_frontend(keras_model, need_transpose=False)
 
 
 def test_forward_crop():