fix bug of declaring regularization loss multiple times when reusing partitioned variables
author wangsiyu <siyu.wsy@gmail.com>
Thu, 3 May 2018 10:31:29 +0000 (18:31 +0800)
committer wangsiyu <siyu.wsy@gmail.com>
Thu, 3 May 2018 10:31:29 +0000 (18:31 +0800)
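
When a layer variable is created under a partitioner, reusing the
variable scope used to add the regularization loss again: on reuse,
vs.get_variable returns a PartitionedVariable, and a PartitionedVariable
is never a member of existing_variables (a set containing the underlying
Variable shards), so the old `variable not in existing_variables` check
always passed. The new _should_add_regularizer helper checks the shards
of a PartitionedVariable against that set instead.

Roughly, the duplicate losses can be reproduced with a sketch along
these lines (TF 1.x graph mode; tf.layers.dense is used only as a
convenient public caller of Layer.add_variable and is not part of this
change):

    import tensorflow as tf

    regularizer = lambda x: tf.reduce_sum(x) * 1e-3
    partitioner = tf.fixed_size_partitioner(3)
    x = tf.ones([1, 4])
    for reuse in (False, True):
      with tf.variable_scope('scope', partitioner=partitioner,
                             reuse=reuse):
        tf.layers.dense(x, 4, name='dense',
                        kernel_regularizer=regularizer)

    losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    print(len(losses))  # 6 before this fix (3 shards twice); 3 after
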
tensorflow/python/layers/base.py
tensorflow/python/layers/base_test.py

diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index 64db49c..c050e6b 100644
@@ -233,7 +233,8 @@ class Layer(base_layer.Layer):
             getter=vs.get_variable)
 
         if regularizer:
-          if context.executing_eagerly() or variable not in existing_variables:
+          if context.executing_eagerly() or _should_add_regularizer(
+              variable, existing_variables):
             self._handle_weight_regularization(name, variable, regularizer)
 
         if init_graph is not None:
@@ -354,3 +355,12 @@ def _add_elements_to_collection(elements, collection_list):
       if element not in collection_set:
         collection.append(element)
 
+
+def _should_add_regularizer(variable, existing_variable_set):
+  """Returns True if a regularization loss should be added for `variable`."""
+  if isinstance(variable, tf_variables.PartitionedVariable):
+    # A partitioned variable is reused when its shards already exist, so
+    # only add the regularizer if every shard is new.
+    return all(v not in existing_variable_set
+               for v in variable._get_variable_list())
+  return variable not in existing_variable_set
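
In a minimal stand-in sketch (hypothetical classes, not TF code), the
membership behavior the helper corrects looks like this: only the shard
Variables are recorded in existing_variable_set, so the wrapper itself
never tests as a member.

    class Shard(object):
      pass

    class Partitioned(object):
      def __init__(self, shards):
        self._shards = shards
      def _get_variable_list(self):
        return self._shards

    existing = {Shard(), Shard()}         # shards recorded at creation
    reused = Partitioned(list(existing))  # wrapper returned on reuse

    print(reused not in existing)         # True  -> old check re-added loss
    print(all(s not in existing           # False -> new check skips it
              for s in reused._get_variable_list()))
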
diff --git a/tensorflow/python/layers/base_test.py b/tensorflow/python/layers/base_test.py
index f08b552..361e3de 100644
@@ -30,6 +30,7 @@ from tensorflow.python.layers import core as core_layers
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import partitioned_variables
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variable_scope
@@ -95,6 +96,21 @@ class BaseLayerTest(test.TestCase):
           regularizer=regularizer)
       self.assertEqual(len(layer.losses), 1)
 
+  def testReusePartitionedVariablesAndRegularizers(self):
+    regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3
+    partitioner = partitioned_variables.fixed_size_partitioner(3)
+    for i in range(2):
+      with variable_scope.variable_scope(variable_scope.get_variable_scope(),
+                                         partitioner=partitioner,
+                                         reuse=(i > 0)):
+        layer = base_layers.Layer(name='my_layer')
+        variable = layer.add_variable(
+            'reg_part_var', [4, 4],
+            initializer=init_ops.zeros_initializer(),
+            regularizer=regularizer)
+    self.assertEqual(
+        len(ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)), 3)
+
   def testNoEagerActivityRegularizer(self):
     with context.eager_mode():
       with self.assertRaisesRegexp(ValueError, 'activity_regularizer'):