Expose the ExponentialMovingAverage name as a public property.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 31 May 2018 21:42:07 +0000 (14:42 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 31 May 2018 21:45:11 +0000 (14:45 -0700)
PiperOrigin-RevId: 198782348

tensorflow/python/training/moving_averages.py
tensorflow/python/training/moving_averages_test.py
tensorflow/tools/api/golden/tensorflow.train.-exponential-moving-average.pbtxt

index 61fc828..60cc54c 100644 (file)
@@ -344,6 +344,11 @@ class ExponentialMovingAverage(object):
     self._name = name
     self._averages = {}
 
+  @property
+  def name(self):
+    """The name of this ExponentialMovingAverage object."""
+    return self._name
+
   def apply(self, var_list=None):
     """Maintains moving averages of variables.
 
@@ -394,7 +399,7 @@ class ExponentialMovingAverage(object):
         if isinstance(var, variables.Variable):
           avg = slot_creator.create_slot(var,
                                          var.initialized_value(),
-                                         self._name,
+                                         self.name,
                                          colocate_with_primary=True)
           # NOTE(mrry): We only add `tf.Variable` objects to the
           # `MOVING_AVERAGE_VARIABLES` collection.
@@ -402,7 +407,7 @@ class ExponentialMovingAverage(object):
         else:
           avg = slot_creator.create_zeros_slot(
               var,
-              self._name,
+              self.name,
               colocate_with_primary=(var.op.type in ["Variable",
                                                      "VariableV2",
                                                      "VarHandleOp"]))
@@ -410,7 +415,7 @@ class ExponentialMovingAverage(object):
             zero_debias_true.add(avg)
       self._averages[var] = avg
 
-    with ops.name_scope(self._name) as scope:
+    with ops.name_scope(self.name) as scope:
       decay = ops.convert_to_tensor(self._decay, name="decay")
       if self._num_updates is not None:
         num_updates = math_ops.cast(self._num_updates,
@@ -462,7 +467,7 @@ class ExponentialMovingAverage(object):
     if var in self._averages:
       return self._averages[var].op.name
     return ops.get_default_graph().unique_name(
-        var.op.name + "/" + self._name, mark_as_used=False)
+        var.op.name + "/" + self.name, mark_as_used=False)
 
   def variables_to_restore(self, moving_avg_variables=None):
     """Returns a map of names to `Variables` to restore.
index 6717811..3e85e6b 100644 (file)
@@ -263,6 +263,7 @@ class ExponentialMovingAverageTest(test.TestCase):
       tensor2 = v0 + v1
       ema = moving_averages.ExponentialMovingAverage(
           0.25, zero_debias=zero_debias, name="foo")
+      self.assertEqual("foo", ema.name)
       self.assertEqual("v0/foo", ema.average_name(v0))
       self.assertEqual("v1/foo", ema.average_name(v1))
       self.assertEqual("add/foo", ema.average_name(tensor2))
index 737acbe..c9fe136 100644 (file)
@@ -2,6 +2,10 @@ path: "tensorflow.train.ExponentialMovingAverage"
 tf_class {
   is_instance: "<class \'tensorflow.python.training.moving_averages.ExponentialMovingAverage\'>"
   is_instance: "<type \'object\'>"
+  member {
+    name: "name"
+    mtype: "<type \'property\'>"
+  }
   member_method {
     name: "__init__"
     argspec: "args=[\'self\', \'decay\', \'num_updates\', \'zero_debias\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'ExponentialMovingAverage\'], "