From bbc8fe70603c21f2a2a7086530035364b6f5b207 Mon Sep 17 00:00:00 2001 From: Michael Case Date: Mon, 21 May 2018 19:45:21 -0700 Subject: [PATCH] Internal Change PiperOrigin-RevId: 197501805 --- tensorflow/python/estimator/estimator.py | 1 + tensorflow/python/training/warm_starting_util.py | 2 +- tensorflow/python/util/tf_export.py | 35 ++++++++++++++---------- tensorflow/python/util/tf_export_test.py | 7 +++++ 4 files changed, 29 insertions(+), 16 deletions(-) diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index 10f4de3..a2e84c8 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -1616,6 +1616,7 @@ class _DatasetInitializerHook(training.SessionRunHook): session.run(self._initializer) VocabInfo = warm_starting_util.VocabInfo # pylint: disable=invalid-name +tf_export('estimator.VocabInfo', allow_multiple_exports=True)(VocabInfo) @tf_export('estimator.WarmStartSettings') diff --git a/tensorflow/python/training/warm_starting_util.py b/tensorflow/python/training/warm_starting_util.py index 4d4fb39..b0f37f8 100644 --- a/tensorflow/python/training/warm_starting_util.py +++ b/tensorflow/python/training/warm_starting_util.py @@ -33,7 +33,7 @@ from tensorflow.python.training import saver from tensorflow.python.util.tf_export import tf_export -@tf_export("train.VocabInfo", "estimator.VocabInfo") +@tf_export("train.VocabInfo", allow_multiple_exports=True) class VocabInfo( collections.namedtuple("VocabInfo", [ "new_vocab", diff --git a/tensorflow/python/util/tf_export.py b/tensorflow/python/util/tf_export.py index a30b8b1..bf3961c 100644 --- a/tensorflow/python/util/tf_export.py +++ b/tensorflow/python/util/tf_export.py @@ -59,13 +59,19 @@ class tf_export(object): # pylint: disable=invalid-name Args: *args: API names in dot delimited format. - **kwargs: Optional keyed arguments. Currently only supports 'overrides' - argument. overrides: List of symbols that this is overriding - (those overrided api exports will be removed). Note: passing overrides - has no effect on exporting a constant. + **kwargs: Optional keyed arguments. + overrides: List of symbols that this is overriding + (those overrided api exports will be removed). Note: passing overrides + has no effect on exporting a constant. + allow_multiple_exports: Allows exporting the same symbol multiple + times with multiple `tf_export` usages. Prefer however, to list all + of the exported names in a single `tf_export` usage when possible. + """ self._names = args self._overrides = kwargs.get('overrides', []) + self._allow_multiple_exports = kwargs.get( + 'allow_multiple_exports', False) def __call__(self, func): """Calls this decorator. @@ -77,7 +83,8 @@ class tf_export(object): # pylint: disable=invalid-name The input function with _tf_api_names attribute set. Raises: - SymbolAlreadyExposedError: Raised when a symbol already has API names. + SymbolAlreadyExposedError: Raised when a symbol already has API names + and kwarg `allow_multiple_exports` not set. """ # Undecorate overridden names for f in self._overrides: @@ -90,16 +97,14 @@ class tf_export(object): # pylint: disable=invalid-name # __dict__ instead of using hasattr to verify that subclasses have # their own _tf_api_names as opposed to just inheriting it. if '_tf_api_names' in undecorated_func.__dict__: - # pylint: disable=protected-access - raise SymbolAlreadyExposedError( - 'Symbol %s is already exposed as %s.' % - (undecorated_func.__name__, undecorated_func._tf_api_names)) - # pylint: enable=protected-access - - # Complete the export by creating/overriding attribute - # pylint: disable=protected-access - undecorated_func._tf_api_names = self._names - # pylint: enable=protected-access + if self._allow_multiple_exports: + undecorated_func._tf_api_names += self._names # pylint: disable=protected-access + else: + raise SymbolAlreadyExposedError( + 'Symbol %s is already exposed as %s.' % + (undecorated_func.__name__, undecorated_func._tf_api_names)) # pylint: disable=protected-access + else: + undecorated_func._tf_api_names = self._names # pylint: disable=protected-access return func def export_constant(self, module_name, name): diff --git a/tensorflow/python/util/tf_export_test.py b/tensorflow/python/util/tf_export_test.py index b9e26ec..ace3f05 100644 --- a/tensorflow/python/util/tf_export_test.py +++ b/tensorflow/python/util/tf_export_test.py @@ -128,6 +128,13 @@ class ValidateExportTest(test.TestCase): with self.assertRaises(tf_export.SymbolAlreadyExposedError): export_decorator(_test_function) + def testEAllowMultipleExports(self): + _test_function._tf_api_names = ['name1', 'name2'] + tf_export.tf_export('nameRed', 'nameBlue', allow_multiple_exports=True)( + _test_function) + self.assertEquals(['name1', 'name2', 'nameRed', 'nameBlue'], + _test_function._tf_api_names) + def testOverridesFunction(self): _test_function2._tf_api_names = ['abc'] -- 2.7.4