Added del_hparam(), the counter part of add_hparam.
authorSherry Moore <sherrym@google.com>
Sun, 29 Apr 2018 16:56:16 +0000 (09:56 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sun, 29 Apr 2018 17:01:27 +0000 (10:01 -0700)
PiperOrigin-RevId: 194711291

tensorflow/contrib/training/python/training/hparam.py
tensorflow/contrib/training/python/training/hparam_test.py

index 6c59b68..f0418f0 100644 (file)
@@ -502,6 +502,16 @@ class HParams(object):
             'Must pass a list for multi-valued parameter: %s.' % name)
       setattr(self, name, _cast_to_type_if_compatible(name, param_type, value))
 
+  def del_hparam(self, name):
+    """Removes the hyperparameter with key 'name'.
+
+    Args:
+      name: Name of the hyperparameter.
+    """
+    if hasattr(self, name):
+      delattr(self, name)
+      del self._hparam_types[name]
+
   def parse(self, values):
     """Override hyperparameter values, parsing new values from a string.
 
index 96eff86..11fd15b 100644 (file)
@@ -439,6 +439,22 @@ class HParamsTest(test.TestCase):
     self.assertEqual(123, hparams.get('unknown', 123))
     self.assertEqual([1, 2, 3], hparams.get('unknown', [1, 2, 3]))
 
+  def testDel(self):
+    hparams = hparam.HParams(aaa=1, b=2.0)
+
+    with self.assertRaises(ValueError):
+      hparams.set_hparam('aaa', 'will fail')
+
+    with self.assertRaises(ValueError):
+      hparams.add_hparam('aaa', 'will fail')
+
+    hparams.del_hparam('aaa')
+    hparams.add_hparam('aaa', 'will work')
+    self.assertEqual('will work', hparams.get('aaa'))
+
+    hparams.set_hparam('aaa', 'still works')
+    self.assertEqual('still works', hparams.get('aaa'))
+
 
 if __name__ == '__main__':
   test.main()