import argparse
import collections
+import importlib
import os
import sys
-# Populate `sys.modules` which will be traversed to find TensorFlow modules.
-# Make sure your module gets imported in tensorflow/python/__init__.py for it
-# to be seen by this script.
-import tensorflow.python # pylint: disable=unused-import
-
from tensorflow.python.util import tf_decorator
_API_CONSTANTS_ATTR = '_tf_api_constants'
_API_NAMES_ATTR = '_tf_api_names'
_API_DIR = '/api/'
-_DEFAULT_MODULE_FILTER = 'tensorflow.'
+_DEFAULT_PACKAGE = 'tensorflow.python'
_OUTPUT_MODULE = 'tensorflow.tools.api.generator.api'
_GENERATED_FILE_HEADER = """\"\"\"Imports for Python API.
# since we import from it using * import.
underscore_names_str = ', '.join(
'\'%s\'' % name for name in self._underscore_names_in_root)
- module_text_map[''] += '''
+ # We will always generate a root __init__.py file to let us handle *
+ # imports consistently. Be sure to have a root __init__.py file listed in
+ # the script outputs.
+ module_text_map[''] = module_text_map.get('', '') + '''
_names_with_underscore = [%s]
__all__ = [s for s in dir() if not s.startswith('_')]
__all__.extend([s for s in _names_with_underscore])
return module_text_map
-def get_api_init_text(module_filter):
+def get_api_init_text(package):
"""Get a map from destination module to __init__.py code for that module.
Args:
- module_filter: Substring used to filter module names to process.
+ package: Base python package containing python with target tf_export
+ decorators.
Returns:
A dictionary where
for module in list(sys.modules.values()):
# Only look at tensorflow modules.
if (not module or not hasattr(module, '__name__') or
- module_filter not in module.__name__):
+ package not in module.__name__):
continue
# Do not generate __init__.py files for contrib modules for now.
if '.contrib.' in module.__name__ or module.__name__.endswith('.contrib'):
return module_code_builder.build()
-def create_api_files(output_files, module_filter):
+def create_api_files(output_files, package):
"""Creates __init__.py files for the Python API.
Args:
output_files: List of __init__.py file paths to create.
Each file must be under api/ directory.
- module_filter: Substring used to filter module names to process.
+ package: Base python package containing python with target tf_export
+ decorators.
Raises:
ValueError: if an output file is not under api/ directory,
os.makedirs(os.path.dirname(file_path))
open(file_path, 'a').close()
- module_text_map = get_api_init_text(module_filter)
+ module_text_map = get_api_init_text(package)
# Add imports to output files.
missing_output_files = []
'output. If multiple files are passed in, then we assume output files '
'are listed directly as arguments.')
parser.add_argument(
- '--module_filter', default=_DEFAULT_MODULE_FILTER, type=str,
- help='Only processes modules with names containing this substring.'
- )
+ '--package', default=_DEFAULT_PACKAGE, type=str,
+ help='Base package that imports modules containing the target tf_export '
+ 'decorators.')
args = parser.parse_args()
if len(args.outputs) == 1:
outputs = [line.strip() for line in output_list_file.read().split(';')]
else:
outputs = args.outputs
- create_api_files(outputs, args.module_filter)
+
+ # Populate `sys.modules` with modules containing tf_export().
+ importlib.import_module(args.package)
+ create_api_files(outputs, args.package)
if __name__ == '__main__':
_TEST_CONSTANT = 5
-_MODULE_NAME = 'test.tensorflow.test_module'
+_MODULE_NAME = 'tensorflow.python.test_module'
class CreatePythonApiTest(test.TestCase):
def testFunctionImportIsAdded(self):
imports = create_python_api.get_api_init_text(
- module_filter=create_python_api._DEFAULT_MODULE_FILTER)
+ package=create_python_api._DEFAULT_PACKAGE)
expected_import = (
- 'from test.tensorflow.test_module import test_op as test_op1')
+ 'from tensorflow.python.test_module '
+ 'import test_op as test_op1')
self.assertTrue(
expected_import in str(imports),
msg='%s not in %s' % (expected_import, str(imports)))
- expected_import = 'from test.tensorflow.test_module import test_op'
+ expected_import = ('from tensorflow.python.test_module '
+ 'import test_op')
self.assertTrue(
expected_import in str(imports),
msg='%s not in %s' % (expected_import, str(imports)))
def testClassImportIsAdded(self):
imports = create_python_api.get_api_init_text(
- module_filter=create_python_api._DEFAULT_MODULE_FILTER)
- expected_import = 'from test.tensorflow.test_module import TestClass'
+ package=create_python_api._DEFAULT_PACKAGE)
+ expected_import = ('from tensorflow.python.test_module '
+ 'import TestClass')
self.assertTrue(
'TestClass' in str(imports),
msg='%s not in %s' % (expected_import, str(imports)))
def testConstantIsAdded(self):
imports = create_python_api.get_api_init_text(
- module_filter=create_python_api._DEFAULT_MODULE_FILTER)
- expected = 'from test.tensorflow.test_module import _TEST_CONSTANT'
+ package=create_python_api._DEFAULT_PACKAGE)
+ expected = ('from tensorflow.python.test_module '
+ 'import _TEST_CONSTANT')
self.assertTrue(expected in str(imports),
msg='%s not in %s' % (expected, str(imports)))