From: Sherry Moore Date: Sun, 29 Apr 2018 16:56:16 +0000 (-0700) Subject: Added del_hparam(), the counter part of add_hparam. X-Git-Tag: upstream/v1.9.0_rc1~190^2^2~3 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=45529aaac3f5c1d290c285a4e86c434600ec2d92;p=platform%2Fupstream%2Ftensorflow.git Added del_hparam(), the counter part of add_hparam. PiperOrigin-RevId: 194711291 --- diff --git a/tensorflow/contrib/training/python/training/hparam.py b/tensorflow/contrib/training/python/training/hparam.py index 6c59b68..f0418f0 100644 --- a/tensorflow/contrib/training/python/training/hparam.py +++ b/tensorflow/contrib/training/python/training/hparam.py @@ -502,6 +502,16 @@ class HParams(object): 'Must pass a list for multi-valued parameter: %s.' % name) setattr(self, name, _cast_to_type_if_compatible(name, param_type, value)) + def del_hparam(self, name): + """Removes the hyperparameter with key 'name'. + + Args: + name: Name of the hyperparameter. + """ + if hasattr(self, name): + delattr(self, name) + del self._hparam_types[name] + def parse(self, values): """Override hyperparameter values, parsing new values from a string. diff --git a/tensorflow/contrib/training/python/training/hparam_test.py b/tensorflow/contrib/training/python/training/hparam_test.py index 96eff86..11fd15b 100644 --- a/tensorflow/contrib/training/python/training/hparam_test.py +++ b/tensorflow/contrib/training/python/training/hparam_test.py @@ -439,6 +439,22 @@ class HParamsTest(test.TestCase): self.assertEqual(123, hparams.get('unknown', 123)) self.assertEqual([1, 2, 3], hparams.get('unknown', [1, 2, 3])) + def testDel(self): + hparams = hparam.HParams(aaa=1, b=2.0) + + with self.assertRaises(ValueError): + hparams.set_hparam('aaa', 'will fail') + + with self.assertRaises(ValueError): + hparams.add_hparam('aaa', 'will fail') + + hparams.del_hparam('aaa') + hparams.add_hparam('aaa', 'will work') + self.assertEqual('will work', hparams.get('aaa')) + + hparams.set_hparam('aaa', 'still works') + self.assertEqual('still works', hparams.get('aaa')) + if __name__ == '__main__': test.main()