moving_var = keras_layer.get_weights()[idx + 1]
params['moving_mean'] = etab.new_const(moving_mean)
params['moving_var'] = etab.new_const(moving_var)
+ # in case beta or gamma is not defined
+ params['beta'] = etab.new_const(np.zeros(moving_mean.shape)) if \
+ 'beta' not in params else params['beta']
+ params['gamma'] = etab.new_const(np.ones(moving_mean.shape)) if \
+ 'gamma' not in params else params['gamma']
result, moving_mean, moving_var = _op.nn.batch_norm(inexpr, **params)
return result
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model)
+def test_forward_batch_norm():
+ data = keras.layers.Input(shape=(32, 32, 3))
+ batch_norm_funcs = [keras.layers.BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001,
+ center=True, scale=False,
+ beta_initializer='zeros',
+ gamma_initializer='ones',
+ moving_mean_initializer='zeros',
+ moving_variance_initializer='ones'),
+ keras.layers.BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001,
+ center=True, scale=True,
+ beta_initializer='zeros',
+ gamma_initializer='ones',
+ moving_mean_initializer='zeros',
+ moving_variance_initializer='ones'),
+ keras.layers.BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001,
+ center=False, scale=True,
+ beta_initializer='zeros',
+ gamma_initializer='ones',
+ moving_mean_initializer='zeros',
+ moving_variance_initializer='ones'),
+ keras.layers.BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001,
+ center=False, scale=False,
+ beta_initializer='zeros',
+ gamma_initializer='ones',
+ moving_mean_initializer='zeros',
+ moving_variance_initializer='ones')]
+ for batch_norm_func in batch_norm_funcs:
+ x = batch_norm_func(data)
+ keras_model = keras.models.Model(data, x)
+ verify_keras_frontend(keras_model)
def test_forward_upsample(interpolation='nearest'):
data = keras.layers.Input(shape=(32, 32, 3))
test_forward_sequential()
test_forward_pool()
test_forward_conv()
+ test_forward_batch_norm()
test_forward_upsample(interpolation='nearest')
test_forward_upsample(interpolation='bilinear')
test_forward_reshape()