def _convert_reshape(inexpr, keras_layer, _):
_check_data_format(keras_layer)
- ch = keras_layer.input_shape[-1]
- assert ch == keras_layer.target_shape[-1], \
- "Only supports last dimension in target shape being equal to " \
- "the channel number of input tensor."
- shape = (-1, ch) + keras_layer.target_shape[:-1]
+ inshape = keras_layer.input_shape # includes batch
+ tshape = keras_layer.target_shape # no batch
+ if len(inshape) == 3 and len(tshape) == 1:
+ # (?, a, b) -> (-1, ab)
+ shape = (-1, tshape[0])
+ elif len(inshape) in [2, 3] and len(tshape) == 2:
+ # (?, cc) -> (-1, c, c)
+ # (?, a, b) -> (-1, c, c)
+ assert tshape[0] == tshape[1], \
+ "Only supports square target shapes, but got {}".format(tshape)
+ shape = (-1, ) + tshape
+ else:
+ # (?, h, w, c) -> (-1, c, H, W)
+ # (?, h, w, c) -> (-1, c, hw)
+ # (?, hw, c) -> (-1, c, h, w)
+ ch = inshape[-1]
+ assert ch == tshape[-1], \
+ "Only supports last dimension in target shape being equal to " \
+ "the channel number of input tensor."
+ shape = (-1, ch) + tshape[:-1]
return _op.reshape(inexpr, newshape=shape)
def test_forward_reshape():
+ # input_shape len is 3, target_shape len is 3
data = keras.layers.Input(shape=(32, 32, 3))
- x = keras.layers.Reshape(target_shape=(32, 32, 3))(data)
+ x = keras.layers.Reshape(target_shape=(16, 64, 3))(data)
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model)
+ # input_shape len is 3, target_shape len is 2
+ data = keras.layers.Input(shape=(32, 8, 3))
+ x = keras.layers.Reshape(target_shape=(256, 3))(data)
+ keras_model = keras.models.Model(data, x)
+ verify_keras_frontend(keras_model)
+ # input_shape len is 2, target_shape len is 3
+ data = keras.layers.Input(shape=(256, 3))
+ x = keras.layers.Reshape(target_shape=(8, 32, 3))(data)
+ keras_model = keras.models.Model(data, x)
+ verify_keras_frontend(keras_model)
+ # input_shape len is 2, target_shape len is 1
+ data = keras.layers.Input(shape=(2, 8))
+ x = keras.layers.Reshape(target_shape=(16,))(data)
+ keras_model = keras.models.Model(data, x)
+ verify_keras_frontend(keras_model, need_transpose=False)
+ # input_shape len is 1, target_shape len is 2
+ data = keras.layers.Input(shape=(16,))
+ x = keras.layers.Reshape(target_shape=(4, 4))(data)
+ keras_model = keras.models.Model(data, x)
+ verify_keras_frontend(keras_model, need_transpose=False)
+ # input_shape len is 2, target_shape len is 2
+ data = keras.layers.Input(shape=(2, 8))
+ x = keras.layers.Reshape(target_shape=(4, 4))(data)
+ keras_model = keras.models.Model(data, x)
+ verify_keras_frontend(keras_model, need_transpose=False)
def test_forward_crop():