[Relay][Frontend][Keras] batch_norm op params not handling well (#4310)
author Xingyu Zhou <zhoxingy@amazon.com>
Tue, 12 Nov 2019 18:30:04 +0000 (10:30 -0800)
committer Yuwei Hu <huyuwei1995@gmail.com>
Tue, 12 Nov 2019 18:30:04 +0000 (10:30 -0800)
* Relay Keras frontend batch_norm op params not handled well

* add unit test for Relay Frontend Keras batch_norm
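
Context for the fix (editorial note, not part of the commit): a Keras BatchNormalization layer built with center=False or scale=False does not include beta or gamma in get_weights(), so the converter previously forwarded an incomplete params dict to _op.nn.batch_norm and conversion failed. A minimal repro sketch of the affected path, assuming TVM's public relay.frontend.from_keras entry point; the input name 'input_1' and the NCHW shape dict are illustrative assumptions based on Keras/TVM defaults:

    import keras
    from tvm import relay

    # scale=False means get_weights() yields no gamma; before this patch the
    # Keras frontend passed a params dict without 'gamma' to _op.nn.batch_norm.
    data = keras.layers.Input(shape=(32, 32, 3))
    x = keras.layers.BatchNormalization(center=True, scale=False)(data)
    model = keras.models.Model(data, x)

    # With the patch, identity defaults are substituted and conversion succeeds.
    # 'input_1' is Keras's default input name (an assumption, adjust as needed).
    mod, params = relay.frontend.from_keras(model, shape={'input_1': (1, 3, 32, 32)})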

python/tvm/relay/frontend/keras.py
tests/python/frontend/keras/test_forward.py

index 15f7440..57ee227 100644
@@ -460,6 +460,11 @@ def _convert_batchnorm(inexpr, keras_layer, etab):
     moving_var = keras_layer.get_weights()[idx + 1]
     params['moving_mean'] = etab.new_const(moving_mean)
     params['moving_var'] = etab.new_const(moving_var)
+    # get_weights() omits beta (center=False) and gamma (scale=False); use identity defaults
+    if 'beta' not in params:
+        params['beta'] = etab.new_const(np.zeros(moving_mean.shape))
+    if 'gamma' not in params:
+        params['gamma'] = etab.new_const(np.ones(moving_mean.shape))
     result, moving_mean, moving_var = _op.nn.batch_norm(inexpr, **params)
     return result
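
Why zeros and ones are the right defaults (editorial note, not part of the patch): at inference time batch norm computes gamma * (x - moving_mean) / sqrt(moving_var + epsilon) + beta, so beta=0 and gamma=1 leave the normalized value unchanged. A standalone numpy sketch of that identity, with illustrative values only:

    import numpy as np

    x = np.random.randn(4, 3).astype('float32')
    mean, var, eps = x.mean(axis=0), x.var(axis=0), 1e-3

    # Defaults substituted by the patch: gamma=ones, beta=zeros.
    gamma = np.ones(3, dtype='float32')
    beta = np.zeros(3, dtype='float32')

    out = gamma * (x - mean) / np.sqrt(var + eps) + beta
    ref = (x - mean) / np.sqrt(var + eps)  # plain normalization, no affine terms
    assert np.allclose(out, ref)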
 
index 1b253fd..2af1615 100644
@@ -190,6 +190,36 @@ def test_forward_conv():
         keras_model = keras.models.Model(data, x)
         verify_keras_frontend(keras_model)
 
+def test_forward_batch_norm():
+    data = keras.layers.Input(shape=(32, 32, 3))
+    batch_norm_funcs = [keras.layers.BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001,
+                                                        center=True, scale=False,
+                                                        beta_initializer='zeros',
+                                                        gamma_initializer='ones',
+                                                        moving_mean_initializer='zeros',
+                                                        moving_variance_initializer='ones'),
+                        keras.layers.BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001,
+                                                        center=True, scale=True,
+                                                        beta_initializer='zeros',
+                                                        gamma_initializer='ones',
+                                                        moving_mean_initializer='zeros',
+                                                        moving_variance_initializer='ones'),
+                        keras.layers.BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001,
+                                                        center=False, scale=True,
+                                                        beta_initializer='zeros',
+                                                        gamma_initializer='ones',
+                                                        moving_mean_initializer='zeros',
+                                                        moving_variance_initializer='ones'),
+                        keras.layers.BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001,
+                                                        center=False, scale=False,
+                                                        beta_initializer='zeros',
+                                                        gamma_initializer='ones',
+                                                        moving_mean_initializer='zeros',
+                                                        moving_variance_initializer='ones')]
+    for batch_norm_func in batch_norm_funcs:
+        x = batch_norm_func(data)
+        keras_model = keras.models.Model(data, x)
+        verify_keras_frontend(keras_model)
 
 def test_forward_upsample(interpolation='nearest'):
     data = keras.layers.Input(shape=(32, 32, 3))
@@ -333,6 +363,7 @@ if __name__ == '__main__':
     test_forward_sequential()
     test_forward_pool()
     test_forward_conv()
+    test_forward_batch_norm()
     test_forward_upsample(interpolation='nearest')
     test_forward_upsample(interpolation='bilinear')
     test_forward_reshape()