Fixes MovingAverageOptimizer when dealing with resource variables.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 14 Feb 2018 21:51:23 +0000 (13:51 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 14 Feb 2018 21:59:56 +0000 (13:59 -0800)
Also make it an error if user tries to use swapping_saver with some variables but do not include their EMA counterparts.

PiperOrigin-RevId: 185739214

tensorflow/contrib/opt/python/training/moving_average_optimizer.py
tensorflow/contrib/opt/python/training/moving_average_optimizer_test.py

index d68ad23..9ce50bf 100644 (file)
@@ -83,7 +83,7 @@ class MovingAverageOptimizer(optimizer.Optimizer):
     self._optimizer = opt
     self._ema = moving_averages.ExponentialMovingAverage(
         average_decay, num_updates=num_updates)
-    self._variable_map = None
+    self._swapped_variable_name_map = None
     self._sequential_update = sequential_update
 
   def compute_gradients(self, *args, **kwargs):
@@ -93,7 +93,7 @@ class MovingAverageOptimizer(optimizer.Optimizer):
     train_op = self._optimizer.apply_gradients(
         grads_and_vars, global_step=global_step, name=name)
     var_list = [x[1] for x in grads_and_vars if x[0] is not None]
-    self._variable_map = {}
+    self._swapped_variable_name_map = {}
     if self._sequential_update:
       with ops.control_dependencies([train_op]):
         ma_op = self._ema.apply(var_list)
@@ -102,9 +102,9 @@ class MovingAverageOptimizer(optimizer.Optimizer):
 
     for v in var_list:
       v_avg = self._ema.average(v)
-      self._variable_map[v.op.name] = v_avg
-      self._variable_map[v_avg.op.name] = v
-    return control_flow_ops.group(train_op, ma_op, name="train_with_avg")
+      self._swapped_variable_name_map[v.op.name] = v_avg.op.name
+      self._swapped_variable_name_map[v_avg.op.name] = v.op.name
+    return control_flow_ops.group(train_op, ma_op, name='train_with_avg')
 
   def swapping_saver(self, var_list=None, name='swapping_saver', **kwargs):
     """Create a saver swapping moving averages and variables.
@@ -129,22 +129,45 @@ class MovingAverageOptimizer(optimizer.Optimizer):
 
     Raises:
       RuntimeError: If apply_gradients or minimize has not been called before.
+      ValueError: If var_list is provided and contains some variables but not
+        their moving average counterpart.
     """
 
-    if self._variable_map is None:
+    if self._swapped_variable_name_map is None:
       raise RuntimeError('Must call apply_gradients or minimize before '
                          'creating the swapping_saver')
     if var_list is None:
       var_list = variables.global_variables()
     if not isinstance(var_list, dict):
       var_list = saver.BaseSaverBuilder.OpListToDict(var_list)
+
+    # OpListToDict converts variables to tensors. We make sure we can get
+    # the unique variable name for normal and resource vaiables.
+    def get_v_name(tensor):
+      if tensor.op.type == 'ReadVariableOp':
+        return tensor.op.inputs[0].op.name
+      else:
+        return tensor.op.name
+
+    v_name_to_tensor = {}
+    for tensor in six.itervalues(var_list):
+      v_name = get_v_name(tensor)
+      v_name_to_tensor[v_name] = tensor
+
     # Now swap variables and moving averages
     swapped_var_list = {}
-    for k, v in six.iteritems(var_list):
-      v_swap = self._variable_map.get(v.op.name, None)
-      if v_swap:
-        swapped_var_list[k] = v_swap
-      else:
-        swapped_var_list[k] = v
+    for k, tensor in six.iteritems(var_list):
+      v_name = get_v_name(tensor)
+      swapped_v_name = self._swapped_variable_name_map.get(v_name, None)
+      tensor_to_save = tensor
+      if swapped_v_name is not None:
+        if swapped_v_name in v_name_to_tensor:
+          tensor_to_save = v_name_to_tensor[swapped_v_name]
+        else:
+          raise ValueError(
+              ('Variable to swap %s is not part of variables to save. '
+               'This breaks MovingAverageOptimizer.') % swapped_v_name)
+      swapped_var_list[k] = tensor_to_save
+
     # Build the swapping saver.
     return saver.Saver(swapped_var_list, name=name, **kwargs)
index 60929ad..85e3e8d 100644 (file)
@@ -24,6 +24,10 @@ import six
 from tensorflow.contrib.opt.python.training import moving_average_optimizer
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 from tensorflow.python.training import gradient_descent
@@ -33,13 +37,26 @@ from tensorflow.python.training import saver
 class MovingAverageOptimizerTest(test.TestCase):
 
   def testRun(self):
+    self._helpTestRun(use_resource=False)
+
+  def testRunUseResource(self):
+    # Test that MovingAverageOptimizer works with resource variables.
+    self._helpTestRun(use_resource=True)
+
+  def _helpTestRun(self, use_resource=False):
     for sequential_update in [True, False]:
       for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-        with self.test_session() as sess:
+        with self.test_session(graph=ops.Graph()) as sess:
           orig_val0 = [1.0, 2.0]
           orig_val1 = [3.0, 4.0]
-          var0 = variables.Variable(orig_val0, name='var0', dtype=dtype)
-          var1 = variables.Variable(orig_val1, name='var1', dtype=dtype)
+          var0 = variable_scope.get_variable(
+              'var0',
+              initializer=constant_op.constant(orig_val0, dtype=dtype),
+              use_resource=use_resource)
+          var1 = variable_scope.get_variable(
+              'var1',
+              initializer=constant_op.constant(orig_val1, dtype=dtype),
+              use_resource=use_resource)
           grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
           grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
 
@@ -52,22 +69,63 @@ class MovingAverageOptimizerTest(test.TestCase):
           save_path = os.path.join(save_dir, 'model')
           update = opt.apply_gradients(
               list(six.moves.zip([grads0, grads1], [var0, var1])))
+          global_vars = variables.global_variables()
+          ema_var0 = [
+              v for v in global_vars
+              if v.op.name == 'var0/ExponentialMovingAverage'
+          ][0]
+          ema_var1 = [
+              v for v in global_vars
+              if v.op.name == 'var1/ExponentialMovingAverage'
+          ][0]
+          perturb = control_flow_ops.group([
+              state_ops.assign_add(var0, [1.0, 1.0]),
+              state_ops.assign_add(var1, [2.0, 2.0]),
+              state_ops.assign_add(ema_var0, [3.0, 3.0]),
+              state_ops.assign_add(ema_var1, [4.0, 4.0])
+          ])
+
+          # Test taht saver with missing ema variables will fail.
+          with self.assertRaisesRegexp(ValueError, r'Variable to swap'):
+            opt.swapping_saver(var_list=[var0])
+
           train_saver = opt.swapping_saver()
+          train_saver_subset = opt.swapping_saver(var_list=[var0, ema_var0])
           inference_saver = saver.Saver()
           variables.global_variables_initializer().run()
           # Step 1.
           update.run()
-          val0 = var0.eval()
-          val1 = var1.eval()
           self.assertAllCloseAccordingToType([0.8, 1.8], var0.eval())
           self.assertAllCloseAccordingToType([2.98, 3.98], var1.eval())
+          if sequential_update:
+            self.assertAllCloseAccordingToType([0.9, 1.9], ema_var0.eval())
+            self.assertAllCloseAccordingToType([2.99, 3.99], ema_var1.eval())
           # Test that the swapping saver save/restore operation is identity.
           train_saver.save(sess, save_path)
           train_saver.restore(sess, save_path)
-          val0 = var0.eval()
-          val1 = var1.eval()
           self.assertAllCloseAccordingToType([0.8, 1.8], var0.eval())
           self.assertAllCloseAccordingToType([2.98, 3.98], var1.eval())
+          if sequential_update:
+            self.assertAllCloseAccordingToType([0.9, 1.9], ema_var0.eval())
+            self.assertAllCloseAccordingToType([2.99, 3.99], ema_var1.eval())
+          # Test that the subset saver saves the EMA variable as well.
+          if sequential_update:
+            subset_save_path = save_path + '_subset'
+            train_saver_subset.save(sess, subset_save_path)
+            perturb.run()
+            self.assertAllCloseAccordingToType([1.8, 2.8], var0.eval())
+            self.assertAllCloseAccordingToType([3.9, 4.9], ema_var0.eval())
+            self.assertAllCloseAccordingToType([4.98, 5.98], var1.eval())
+            self.assertAllCloseAccordingToType([6.99, 7.99], ema_var1.eval())
+            # Restoring should only restore var0 and ema_var0.
+            train_saver_subset.restore(sess, subset_save_path)
+            self.assertAllCloseAccordingToType([0.8, 1.8], var0.eval())
+            self.assertAllCloseAccordingToType([0.9, 1.9], ema_var0.eval())
+            self.assertAllCloseAccordingToType([4.98, 5.98], var1.eval())
+            self.assertAllCloseAccordingToType([6.99, 7.99], ema_var1.eval())
+            # Restore back to previou state.
+            train_saver.restore(sess, save_path)
+
           # If updates are parallel, this is not always true after the 1st step.
           if sequential_update:
             # Test that the normal saver will have the averaged variables.