From 0a06a20fcc09b6af8bdd69a52f7015c926113c47 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Wed, 23 May 2018 07:16:01 +0800 Subject: [PATCH] BUG: keras.callbacks.TensorBoard raises an exception for non_trainale_weights (#19148) * TST: write_grads for non_trainable_weights * BUG: bypass non_trainable_weights for write_grad * CLN: factor out write_grad loop --- tensorflow/python/keras/callbacks.py | 21 ++++++++++++--------- tensorflow/python/keras/callbacks_test.py | 2 ++ 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index 3678272..a6dbe2b 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -720,15 +720,6 @@ class TensorBoard(Callback): for weight in layer.weights: mapped_weight_name = weight.name.replace(':', '_') tf_summary.histogram(mapped_weight_name, weight) - if self.write_grads: - grads = model.optimizer.get_gradients(model.total_loss, weight) - - def is_indexed_slices(grad): - return type(grad).__name__ == 'IndexedSlices' - - grads = [grad.values if is_indexed_slices(grad) else grad - for grad in grads] - tf_summary.histogram('{}_grad'.format(mapped_weight_name), grads) if self.write_images: w_img = array_ops.squeeze(weight) shape = K.int_shape(w_img) @@ -755,6 +746,18 @@ class TensorBoard(Callback): assert len(shape) == 4 and shape[-1] in [1, 3, 4] tf_summary.image(mapped_weight_name, w_img) + if self.write_grads: + for weight in layer.trainable_weights: + mapped_weight_name = weight.name.replace(':', '_') + grads = model.optimizer.get_gradients(model.total_loss, weight) + + def is_indexed_slices(grad): + return type(grad).__name__ == 'IndexedSlices' + + grads = [grad.values if is_indexed_slices(grad) else grad + for grad in grads] + tf_summary.histogram('{}_grad'.format(mapped_weight_name), grads) + if hasattr(layer, 'output'): tf_summary.histogram('{}_out'.format(layer.name), layer.output) self.merged = tf_summary.merge_all() diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py index ad5f416..eb40fb4 100644 --- a/tensorflow/python/keras/callbacks_test.py +++ b/tensorflow/python/keras/callbacks_test.py @@ -635,6 +635,8 @@ class KerasCallbacksTest(test.TestCase): model.add( keras.layers.Dense( NUM_HIDDEN, input_dim=INPUT_DIM, activation='relu')) + # non_trainable_weights: moving_variance, moving_mean + model.add(keras.layers.BatchNormalization()) model.add(keras.layers.Dense(NUM_CLASSES, activation='softmax')) model.compile( loss='categorical_crossentropy', -- 2.7.4