refine unit test case coding style and move _should_add_regularizer function into...
authorwangsiyu <siyu.wsy@gmail.com>
Tue, 8 May 2018 02:54:04 +0000 (10:54 +0800)
committerwangsiyu <siyu.wsy@gmail.com>
Tue, 8 May 2018 02:54:04 +0000 (10:54 +0800)
tensorflow/python/layers/base.py
tensorflow/python/layers/base_test.py

index f7b2e47..78db476 100644 (file)
@@ -191,6 +191,18 @@ class Layer(base_layer.Layer):
       RuntimeError: If called with partioned variable regularization and
         eager execution is enabled.
     """
+    
+    def _should_add_regularizer(variable, existing_variable_set):
+      result = True
+      if isinstance(variable, tf_variables.PartitionedVariable):
+        for var in variable:
+          if var in existing_variable_set:
+            result = False
+            break
+      else:
+        result = variable not in existing_variable_set
+      return result
+
     init_graph = None
     if not context.executing_eagerly():
       default_graph = ops.get_default_graph()
@@ -354,14 +366,3 @@ def _add_elements_to_collection(elements, collection_list):
     for element in elements:
       if element not in collection_set:
         collection.append(element)
-
-def _should_add_regularizer(variable, existing_variable_set):
-  result = True
-  if isinstance(variable, tf_variables.PartitionedVariable):
-    for var in variable:
-      if var in existing_variable_set:
-        result = False
-        break
-  else:
-    result = variable not in existing_variable_set
-  return result
index 361e3de..7158fd4 100644 (file)
@@ -99,10 +99,10 @@ class BaseLayerTest(test.TestCase):
   def testReusePartitionedVaraiblesAndRegularizers(self):
     regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3
     partitioner = partitioned_variables.fixed_size_partitioner(3)
-    for i in xrange(2):
+    for reuse in [False, True]:
       with variable_scope.variable_scope(variable_scope.get_variable_scope(),
                                          partitioner=partitioner,
-                                         reuse=False if i == 0 else True):
+                                         reuse=reuse):
         layer = base_layers.Layer(name='my_layer')
         variable = layer.add_variable(
             'reg_part_var', [4, 4],