From 2ebf1bd14ea8eec2d69bffdaf2805d8539026550 Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Fri, 30 Aug 2019 21:30:59 -0700 Subject: [PATCH] Add more cases to keras _convert_reshape (#3846) --- python/tvm/relay/frontend/keras.py | 25 ++++++++++++++++++++----- tests/python/frontend/keras/test_forward.py | 28 +++++++++++++++++++++++++++- 2 files changed, 47 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py index 635a600..e5b70a4 100644 --- a/python/tvm/relay/frontend/keras.py +++ b/python/tvm/relay/frontend/keras.py @@ -490,11 +490,26 @@ def _convert_concat(inexpr, keras_layer, _): def _convert_reshape(inexpr, keras_layer, _): _check_data_format(keras_layer) - ch = keras_layer.input_shape[-1] - assert ch == keras_layer.target_shape[-1], \ - "Only supports last dimension in target shape being equal to " \ - "the channel number of input tensor." - shape = (-1, ch) + keras_layer.target_shape[:-1] + inshape = keras_layer.input_shape # includes batch + tshape = keras_layer.target_shape # no batch + if len(inshape) == 3 and len(tshape) == 1: + # (?, a, b) -> (-1, ab) + shape = (-1, tshape[0]) + elif len(inshape) in [2, 3] and len(tshape) == 2: + # (?, cc) -> (-1, c, c) + # (?, a, b) -> (-1, c, c) + assert tshape[0] == tshape[1], \ + "Only supports square target shapes, but got {}".format(tshape) + shape = (-1, ) + tshape + else: + # (?, h, w, c) -> (-1, c, H, W) + # (?, h, w, c) -> (-1, c, hw) + # (?, hw, c) -> (-1, c, h, w) + ch = inshape[-1] + assert ch == tshape[-1], \ + "Only supports last dimension in target shape being equal to " \ + "the channel number of input tensor." + shape = (-1, ch) + tshape[:-1] return _op.reshape(inexpr, newshape=shape) diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py index 4b71cb6..3cc4ac5 100644 --- a/tests/python/frontend/keras/test_forward.py +++ b/tests/python/frontend/keras/test_forward.py @@ -193,10 +193,36 @@ def test_forward_upsample(interpolation='nearest'): def test_forward_reshape(): + # input_shape len is 3, target_shape len is 3 data = keras.layers.Input(shape=(32, 32, 3)) - x = keras.layers.Reshape(target_shape=(32, 32, 3))(data) + x = keras.layers.Reshape(target_shape=(16, 64, 3))(data) keras_model = keras.models.Model(data, x) verify_keras_frontend(keras_model) + # input_shape len is 3, target_shape len is 2 + data = keras.layers.Input(shape=(32, 8, 3)) + x = keras.layers.Reshape(target_shape=(256, 3))(data) + keras_model = keras.models.Model(data, x) + verify_keras_frontend(keras_model) + # input_shape len is 2, target_shape len is 3 + data = keras.layers.Input(shape=(256, 3)) + x = keras.layers.Reshape(target_shape=(8, 32, 3))(data) + keras_model = keras.models.Model(data, x) + verify_keras_frontend(keras_model) + # input_shape len is 2, target_shape len is 1 + data = keras.layers.Input(shape=(2, 8)) + x = keras.layers.Reshape(target_shape=(16,))(data) + keras_model = keras.models.Model(data, x) + verify_keras_frontend(keras_model, need_transpose=False) + # input_shape len is 1, target_shape len is 2 + data = keras.layers.Input(shape=(16,)) + x = keras.layers.Reshape(target_shape=(4, 4))(data) + keras_model = keras.models.Model(data, x) + verify_keras_frontend(keras_model, need_transpose=False) + # input_shape len is 2, target_shape len is 2 + data = keras.layers.Input(shape=(2, 8)) + x = keras.layers.Reshape(target_shape=(4, 4))(data) + keras_model = keras.models.Model(data, x) + verify_keras_frontend(keras_model, need_transpose=False) def test_forward_crop(): -- 2.7.4