        getter=vs.get_variable)
    if regularizer:
-      if context.executing_eagerly() or variable not in existing_variables:
+      if context.executing_eagerly() or _should_add_regularizer(
+          variable, existing_variables):
        self._handle_weight_regularization(name, variable, regularizer)
    if init_graph is not None:
      if element not in collection_set:
        collection.append(element)
+def _should_add_regularizer(variable, existing_variable_set):
+  """Returns True if `variable` has no regularization loss registered yet.
+
+  A `PartitionedVariable` is regularized as a whole, so the regularizer is
+  skipped if any of its slices already appears in `existing_variable_set`.
+  """
+  if isinstance(variable, tf_variables.PartitionedVariable):
+    for var in variable._get_variable_list():
+      if var in existing_variable_set:
+        return False
+    return True
+  return variable not in existing_variable_set
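For reference, a minimal pure-Python sketch of the dedup rule above; `FakePartitionedVariable` and `should_add_regularizer` are hypothetical stand-ins (not TF code) that mirror the logic without depending on TensorFlow:

# Sketch only: mirrors the skip-if-already-regularized decision above.
class FakePartitionedVariable(object):
  """Hypothetical stand-in for tf_variables.PartitionedVariable."""

  def __init__(self, slices):
    self._slices = slices

  def _get_variable_list(self):
    return self._slices


def should_add_regularizer(variable, existing_variable_set):
  if isinstance(variable, FakePartitionedVariable):
    # Skip the regularizer if any slice was already regularized.
    return not any(v in existing_variable_set
                   for v in variable._get_variable_list())
  return variable not in existing_variable_set


v0, v1, v2 = object(), object(), object()
pvar = FakePartitionedVariable([v0, v1, v2])
assert should_add_regularizer(pvar, set())     # first creation: regularize
assert not should_add_regularizer(pvar, {v1})  # one slice already seen: skip
w = object()
assert should_add_regularizer(w, {v0})         # unseen plain variable
assert not should_add_regularizer(w, {w})      # plain variable already seen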
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
        regularizer=regularizer)
    self.assertEqual(len(layer.losses), 1)
+  def testReusePartitionedVariablesAndRegularizers(self):
+    regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3
+    partitioner = partitioned_variables.fixed_size_partitioner(3)
+    for i in xrange(2):
+      with variable_scope.variable_scope(variable_scope.get_variable_scope(),
+                                         partitioner=partitioner,
+                                         reuse=(i > 0)):
+        layer = base_layers.Layer(name='my_layer')
+        variable = layer.add_variable(
+            'reg_part_var', [4, 4],
+            initializer=init_ops.zeros_initializer(),
+            regularizer=regularizer)
+    # fixed_size_partitioner(3) yields three slices, each with its own
+    # regularization loss; the reused second iteration must not add more.
+    self.assertEqual(
+        len(ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)), 3)
+
  def testNoEagerActivityRegularizer(self):
    with context.eager_mode():
      with self.assertRaisesRegexp(ValueError, 'activity_regularizer'):