From 45529aaac3f5c1d290c285a4e86c434600ec2d92 Mon Sep 17 00:00:00 2001 From: Sherry Moore Date: Sun, 29 Apr 2018 09:56:16 -0700 Subject: [PATCH] Added del_hparam(), the counter part of add_hparam. PiperOrigin-RevId: 194711291 --- tensorflow/contrib/training/python/training/hparam.py | 10 ++++++++++ .../contrib/training/python/training/hparam_test.py | 16 ++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/tensorflow/contrib/training/python/training/hparam.py b/tensorflow/contrib/training/python/training/hparam.py index 6c59b68..f0418f0 100644 --- a/tensorflow/contrib/training/python/training/hparam.py +++ b/tensorflow/contrib/training/python/training/hparam.py @@ -502,6 +502,16 @@ class HParams(object): 'Must pass a list for multi-valued parameter: %s.' % name) setattr(self, name, _cast_to_type_if_compatible(name, param_type, value)) + def del_hparam(self, name): + """Removes the hyperparameter with key 'name'. + + Args: + name: Name of the hyperparameter. + """ + if hasattr(self, name): + delattr(self, name) + del self._hparam_types[name] + def parse(self, values): """Override hyperparameter values, parsing new values from a string. diff --git a/tensorflow/contrib/training/python/training/hparam_test.py b/tensorflow/contrib/training/python/training/hparam_test.py index 96eff86..11fd15b 100644 --- a/tensorflow/contrib/training/python/training/hparam_test.py +++ b/tensorflow/contrib/training/python/training/hparam_test.py @@ -439,6 +439,22 @@ class HParamsTest(test.TestCase): self.assertEqual(123, hparams.get('unknown', 123)) self.assertEqual([1, 2, 3], hparams.get('unknown', [1, 2, 3])) + def testDel(self): + hparams = hparam.HParams(aaa=1, b=2.0) + + with self.assertRaises(ValueError): + hparams.set_hparam('aaa', 'will fail') + + with self.assertRaises(ValueError): + hparams.add_hparam('aaa', 'will fail') + + hparams.del_hparam('aaa') + hparams.add_hparam('aaa', 'will work') + self.assertEqual('will work', hparams.get('aaa')) + + hparams.set_hparam('aaa', 'still works') + self.assertEqual('still works', hparams.get('aaa')) + if __name__ == '__main__': test.main() -- 2.7.4