self._optimizer = opt
self._ema = moving_averages.ExponentialMovingAverage(
average_decay, num_updates=num_updates)
- self._variable_map = None
+ self._swapped_variable_name_map = None
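+ # Maps each variable's op name to its moving average's op name, and vice
+ # versa. Populated by apply_gradients().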
self._sequential_update = sequential_update
def compute_gradients(self, *args, **kwargs):
train_op = self._optimizer.apply_gradients(
grads_and_vars, global_step=global_step, name=name)
var_list = [x[1] for x in grads_and_vars if x[0] is not None]
- self._variable_map = {}
+ self._swapped_variable_name_map = {}
if self._sequential_update:
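+      # Run the EMA update only after the gradient step, so the averages
+      # are computed from the freshly updated variables.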
with ops.control_dependencies([train_op]):
ma_op = self._ema.apply(var_list)
for v in var_list:
v_avg = self._ema.average(v)
- self._variable_map[v.op.name] = v_avg
- self._variable_map[v_avg.op.name] = v
- return control_flow_ops.group(train_op, ma_op, name="train_with_avg")
+ self._swapped_variable_name_map[v.op.name] = v_avg.op.name
+ self._swapped_variable_name_map[v_avg.op.name] = v.op.name
+ return control_flow_ops.group(train_op, ma_op, name='train_with_avg')
def swapping_saver(self, var_list=None, name='swapping_saver', **kwargs):
"""Create a saver swapping moving averages and variables.
Raises:
RuntimeError: If apply_gradients or minimize has not been called before.
+    ValueError: If var_list is provided and contains some variables but not
+      their moving average counterparts.
"""
- if self._variable_map is None:
+ if self._swapped_variable_name_map is None:
raise RuntimeError('Must call apply_gradients or minimize before '
'creating the swapping_saver')
if var_list is None:
var_list = variables.global_variables()
if not isinstance(var_list, dict):
var_list = saver.BaseSaverBuilder.OpListToDict(var_list)
+
+ # OpListToDict converts variables to tensors. We make sure we can get
+  # the unique variable name for normal and resource variables.
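+  # A resource variable is read through a 'ReadVariableOp' whose first
+  # input is the variable's handle op, which carries the variable name.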
+ def get_v_name(tensor):
+ if tensor.op.type == 'ReadVariableOp':
+ return tensor.op.inputs[0].op.name
+ else:
+ return tensor.op.name
+
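+  # Index each tensor to be saved by its canonical variable name so that
+  # the swapped counterpart can be looked up by name below.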
+ v_name_to_tensor = {}
+ for tensor in six.itervalues(var_list):
+ v_name = get_v_name(tensor)
+ v_name_to_tensor[v_name] = tensor
+
# Now swap variables and moving averages
swapped_var_list = {}
- for k, v in six.iteritems(var_list):
- v_swap = self._variable_map.get(v.op.name, None)
- if v_swap:
- swapped_var_list[k] = v_swap
- else:
- swapped_var_list[k] = v
+ for k, tensor in six.iteritems(var_list):
+ v_name = get_v_name(tensor)
+ swapped_v_name = self._swapped_variable_name_map.get(v_name, None)
+ tensor_to_save = tensor
+ if swapped_v_name is not None:
+ if swapped_v_name in v_name_to_tensor:
+ tensor_to_save = v_name_to_tensor[swapped_v_name]
+ else:
+ raise ValueError(
+ ('Variable to swap %s is not part of variables to save. '
+ 'This breaks MovingAverageOptimizer.') % swapped_v_name)
+ swapped_var_list[k] = tensor_to_save
+
# Build the swapping saver.
return saver.Saver(swapped_var_list, name=name, **kwargs)
from tensorflow.contrib.opt.python.training import moving_average_optimizer
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import gradient_descent
class MovingAverageOptimizerTest(test.TestCase):
def testRun(self):
+ self._helpTestRun(use_resource=False)
+
+ def testRunUseResource(self):
+ # Test that MovingAverageOptimizer works with resource variables.
+ self._helpTestRun(use_resource=True)
+
+ def _helpTestRun(self, use_resource=False):
for sequential_update in [True, False]:
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
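+    # Use a fresh graph for each configuration so variable names from
+    # earlier iterations do not collide.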
- with self.test_session() as sess:
+ with self.test_session(graph=ops.Graph()) as sess:
orig_val0 = [1.0, 2.0]
orig_val1 = [3.0, 4.0]
- var0 = variables.Variable(orig_val0, name='var0', dtype=dtype)
- var1 = variables.Variable(orig_val1, name='var1', dtype=dtype)
+ var0 = variable_scope.get_variable(
+ 'var0',
+ initializer=constant_op.constant(orig_val0, dtype=dtype),
+ use_resource=use_resource)
+ var1 = variable_scope.get_variable(
+ 'var1',
+ initializer=constant_op.constant(orig_val1, dtype=dtype),
+ use_resource=use_resource)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
save_path = os.path.join(save_dir, 'model')
update = opt.apply_gradients(
list(six.moves.zip([grads0, grads1], [var0, var1])))
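+      # Look up the shadow (ExponentialMovingAverage) variables that
+      # apply_gradients created for var0 and var1.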
+ global_vars = variables.global_variables()
+ ema_var0 = [
+ v for v in global_vars
+ if v.op.name == 'var0/ExponentialMovingAverage'
+ ][0]
+ ema_var1 = [
+ v for v in global_vars
+ if v.op.name == 'var1/ExponentialMovingAverage'
+ ][0]
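+      # Perturb all four variables to distinct values so that the effect of
+      # each restore below is observable.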
+ perturb = control_flow_ops.group([
+ state_ops.assign_add(var0, [1.0, 1.0]),
+ state_ops.assign_add(var1, [2.0, 2.0]),
+ state_ops.assign_add(ema_var0, [3.0, 3.0]),
+ state_ops.assign_add(ema_var1, [4.0, 4.0])
+ ])
+
+      # Test that a saver with missing EMA variables will fail.
+ with self.assertRaisesRegexp(ValueError, r'Variable to swap'):
+ opt.swapping_saver(var_list=[var0])
+
train_saver = opt.swapping_saver()
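+      # A subset saver must pair each variable with its moving average;
+      # var0 alone (above) raises, var0 together with ema_var0 works.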
+ train_saver_subset = opt.swapping_saver(var_list=[var0, ema_var0])
inference_saver = saver.Saver()
variables.global_variables_initializer().run()
# Step 1.
update.run()
- val0 = var0.eval()
- val1 = var1.eval()
self.assertAllCloseAccordingToType([0.8, 1.8], var0.eval())
self.assertAllCloseAccordingToType([2.98, 3.98], var1.eval())
+ if sequential_update:
+ self.assertAllCloseAccordingToType([0.9, 1.9], ema_var0.eval())
+ self.assertAllCloseAccordingToType([2.99, 3.99], ema_var1.eval())
# Test that the swapping saver save/restore operation is identity.
train_saver.save(sess, save_path)
train_saver.restore(sess, save_path)
- val0 = var0.eval()
- val1 = var1.eval()
self.assertAllCloseAccordingToType([0.8, 1.8], var0.eval())
self.assertAllCloseAccordingToType([2.98, 3.98], var1.eval())
+ if sequential_update:
+ self.assertAllCloseAccordingToType([0.9, 1.9], ema_var0.eval())
+ self.assertAllCloseAccordingToType([2.99, 3.99], ema_var1.eval())
+ # Test that the subset saver saves the EMA variable as well.
+ if sequential_update:
+ subset_save_path = save_path + '_subset'
+ train_saver_subset.save(sess, subset_save_path)
+ perturb.run()
+ self.assertAllCloseAccordingToType([1.8, 2.8], var0.eval())
+ self.assertAllCloseAccordingToType([3.9, 4.9], ema_var0.eval())
+ self.assertAllCloseAccordingToType([4.98, 5.98], var1.eval())
+ self.assertAllCloseAccordingToType([6.99, 7.99], ema_var1.eval())
+ # Restoring should only restore var0 and ema_var0.
+ train_saver_subset.restore(sess, subset_save_path)
+ self.assertAllCloseAccordingToType([0.8, 1.8], var0.eval())
+ self.assertAllCloseAccordingToType([0.9, 1.9], ema_var0.eval())
+ self.assertAllCloseAccordingToType([4.98, 5.98], var1.eval())
+ self.assertAllCloseAccordingToType([6.99, 7.99], ema_var1.eval())
+          # Restore back to the previous state.
+ train_saver.restore(sess, save_path)
+
# If updates are parallel, this is not always true after the 1st step.
if sequential_update:
# Test that the normal saver will have the averaged variables.