From: Michael Case Date: Wed, 16 May 2018 18:51:48 +0000 (-0700) Subject: Internal Change. X-Git-Tag: upstream/v1.9.0_rc1~106^2^2~43 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=01ed446a17f4b6cb3a9e4ffd2c39641b9e96ed96;p=platform%2Fupstream%2Ftensorflow.git Internal Change. PiperOrigin-RevId: 196864489 --- diff --git a/tensorflow/tools/api/generator/create_python_api.py b/tensorflow/tools/api/generator/create_python_api.py index d72cb3b..9cb137d 100644 --- a/tensorflow/tools/api/generator/create_python_api.py +++ b/tensorflow/tools/api/generator/create_python_api.py @@ -20,21 +20,17 @@ from __future__ import print_function import argparse import collections +import importlib import os import sys -# Populate `sys.modules` which will be traversed to find TensorFlow modules. -# Make sure your module gets imported in tensorflow/python/__init__.py for it -# to be seen by this script. -import tensorflow.python # pylint: disable=unused-import - from tensorflow.python.util import tf_decorator _API_CONSTANTS_ATTR = '_tf_api_constants' _API_NAMES_ATTR = '_tf_api_names' _API_DIR = '/api/' -_DEFAULT_MODULE_FILTER = 'tensorflow.' +_DEFAULT_PACKAGE = 'tensorflow.python' _OUTPUT_MODULE = 'tensorflow.tools.api.generator.api' _GENERATED_FILE_HEADER = """\"\"\"Imports for Python API. @@ -142,7 +138,10 @@ class _ModuleInitCodeBuilder(object): # since we import from it using * import. underscore_names_str = ', '.join( '\'%s\'' % name for name in self._underscore_names_in_root) - module_text_map[''] += ''' + # We will always generate a root __init__.py file to let us handle * + # imports consistently. Be sure to have a root __init__.py file listed in + # the script outputs. + module_text_map[''] = module_text_map.get('', '') + ''' _names_with_underscore = [%s] __all__ = [s for s in dir() if not s.startswith('_')] __all__.extend([s for s in _names_with_underscore]) @@ -151,11 +150,12 @@ __all__.extend([s for s in _names_with_underscore]) return module_text_map -def get_api_init_text(module_filter): +def get_api_init_text(package): """Get a map from destination module to __init__.py code for that module. Args: - module_filter: Substring used to filter module names to process. + package: Base python package containing python with target tf_export + decorators. Returns: A dictionary where @@ -170,7 +170,7 @@ def get_api_init_text(module_filter): for module in list(sys.modules.values()): # Only look at tensorflow modules. if (not module or not hasattr(module, '__name__') or - module_filter not in module.__name__): + package not in module.__name__): continue # Do not generate __init__.py files for contrib modules for now. if '.contrib.' in module.__name__ or module.__name__.endswith('.contrib'): @@ -223,13 +223,14 @@ def get_api_init_text(module_filter): return module_code_builder.build() -def create_api_files(output_files, module_filter): +def create_api_files(output_files, package): """Creates __init__.py files for the Python API. Args: output_files: List of __init__.py file paths to create. Each file must be under api/ directory. - module_filter: Substring used to filter module names to process. + package: Base python package containing python with target tf_export + decorators. Raises: ValueError: if an output file is not under api/ directory, @@ -257,7 +258,7 @@ def create_api_files(output_files, module_filter): os.makedirs(os.path.dirname(file_path)) open(file_path, 'a').close() - module_text_map = get_api_init_text(module_filter) + module_text_map = get_api_init_text(package) # Add imports to output files. missing_output_files = [] @@ -288,9 +289,9 @@ def main(): 'output. If multiple files are passed in, then we assume output files ' 'are listed directly as arguments.') parser.add_argument( - '--module_filter', default=_DEFAULT_MODULE_FILTER, type=str, - help='Only processes modules with names containing this substring.' - ) + '--package', default=_DEFAULT_PACKAGE, type=str, + help='Base package that imports modules containing the target tf_export ' + 'decorators.') args = parser.parse_args() if len(args.outputs) == 1: @@ -300,7 +301,10 @@ def main(): outputs = [line.strip() for line in output_list_file.read().split(';')] else: outputs = args.outputs - create_api_files(outputs, args.module_filter) + + # Populate `sys.modules` with modules containing tf_export(). + importlib.import_module(args.package) + create_api_files(outputs, args.package) if __name__ == '__main__': diff --git a/tensorflow/tools/api/generator/create_python_api_test.py b/tensorflow/tools/api/generator/create_python_api_test.py index 5f10522..986340c 100644 --- a/tensorflow/tools/api/generator/create_python_api_test.py +++ b/tensorflow/tools/api/generator/create_python_api_test.py @@ -37,7 +37,7 @@ class TestClass(object): _TEST_CONSTANT = 5 -_MODULE_NAME = 'test.tensorflow.test_module' +_MODULE_NAME = 'tensorflow.python.test_module' class CreatePythonApiTest(test.TestCase): @@ -57,30 +57,34 @@ class CreatePythonApiTest(test.TestCase): def testFunctionImportIsAdded(self): imports = create_python_api.get_api_init_text( - module_filter=create_python_api._DEFAULT_MODULE_FILTER) + package=create_python_api._DEFAULT_PACKAGE) expected_import = ( - 'from test.tensorflow.test_module import test_op as test_op1') + 'from tensorflow.python.test_module ' + 'import test_op as test_op1') self.assertTrue( expected_import in str(imports), msg='%s not in %s' % (expected_import, str(imports))) - expected_import = 'from test.tensorflow.test_module import test_op' + expected_import = ('from tensorflow.python.test_module ' + 'import test_op') self.assertTrue( expected_import in str(imports), msg='%s not in %s' % (expected_import, str(imports))) def testClassImportIsAdded(self): imports = create_python_api.get_api_init_text( - module_filter=create_python_api._DEFAULT_MODULE_FILTER) - expected_import = 'from test.tensorflow.test_module import TestClass' + package=create_python_api._DEFAULT_PACKAGE) + expected_import = ('from tensorflow.python.test_module ' + 'import TestClass') self.assertTrue( 'TestClass' in str(imports), msg='%s not in %s' % (expected_import, str(imports))) def testConstantIsAdded(self): imports = create_python_api.get_api_init_text( - module_filter=create_python_api._DEFAULT_MODULE_FILTER) - expected = 'from test.tensorflow.test_module import _TEST_CONSTANT' + package=create_python_api._DEFAULT_PACKAGE) + expected = ('from tensorflow.python.test_module ' + 'import _TEST_CONSTANT') self.assertTrue(expected in str(imports), msg='%s not in %s' % (expected, str(imports)))