_API_CONSTANTS_ATTR = '_tf_api_constants'
_API_NAMES_ATTR = '_tf_api_names'
_API_DIR = '/api/'
+_DEFAULT_MODULE_FILTER = 'tensorflow.'
_OUTPUT_MODULE = 'tensorflow.tools.api.generator.api'
_GENERATED_FILE_HEADER = """\"\"\"Imports for Python API.
return module_text_map
-def get_api_init_text():
+def get_api_init_text(module_filter):
"""Get a map from destination module to __init__.py code for that module.
+ Args:
+ module_filter: Substring used to filter module names to process.
+
Returns:
A dictionary where
key: (string) destination module (for e.g. tf or tf.consts).
for module in list(sys.modules.values()):
# Only look at tensorflow modules.
if (not module or not hasattr(module, '__name__') or
- 'tensorflow.' not in module.__name__):
+ module_filter not in module.__name__):
continue
# Do not generate __init__.py files for contrib modules for now.
if '.contrib.' in module.__name__ or module.__name__.endswith('.contrib'):
return module_code_builder.build()
-def create_api_files(output_files):
+def create_api_files(output_files, module_filter):
"""Creates __init__.py files for the Python API.
Args:
output_files: List of __init__.py file paths to create.
Each file must be under api/ directory.
+ module_filter: Substring used to filter module names to process.
Raises:
ValueError: if an output file is not under api/ directory,
os.makedirs(os.path.dirname(file_path))
open(file_path, 'a').close()
- module_text_map = get_api_init_text()
+ module_text_map = get_api_init_text(module_filter)
# Add imports to output files.
missing_output_files = []
',\n'.join(sorted(missing_output_files)))
-def main(output_files):
- create_api_files(output_files)
-
-if __name__ == '__main__':
+def main():
parser = argparse.ArgumentParser()
parser.add_argument(
'outputs', metavar='O', type=str, nargs='+',
'semicolon-separated list of Python files that we expect this script to '
'output. If multiple files are passed in, then we assume output files '
'are listed directly as arguments.')
+ parser.add_argument(
+ '--module_filter', default=_DEFAULT_MODULE_FILTER, type=str,
+ help='Only processes modules with names containing this substring.'
+ )
args = parser.parse_args()
+
if len(args.outputs) == 1:
# If we only get a single argument, then it must be a file containing
# list of outputs.
outputs = [line.strip() for line in output_list_file.read().split(';')]
else:
outputs = args.outputs
- main(outputs)
+ create_api_files(outputs, args.module_filter)
+
+
+if __name__ == '__main__':
+ main()
del sys.modules[_MODULE_NAME]
def testFunctionImportIsAdded(self):
- imports = create_python_api.get_api_init_text()
+ imports = create_python_api.get_api_init_text(
+ module_filter=create_python_api._DEFAULT_MODULE_FILTER)
expected_import = (
'from test.tensorflow.test_module import test_op as test_op1')
self.assertTrue(
msg='%s not in %s' % (expected_import, str(imports)))
def testClassImportIsAdded(self):
- imports = create_python_api.get_api_init_text()
+ imports = create_python_api.get_api_init_text(
+ module_filter=create_python_api._DEFAULT_MODULE_FILTER)
expected_import = 'from test.tensorflow.test_module import TestClass'
self.assertTrue(
'TestClass' in str(imports),
msg='%s not in %s' % (expected_import, str(imports)))
def testConstantIsAdded(self):
- imports = create_python_api.get_api_init_text()
+ imports = create_python_api.get_api_init_text(
+ module_filter=create_python_api._DEFAULT_MODULE_FILTER)
expected = 'from test.tensorflow.test_module import _TEST_CONSTANT'
self.assertTrue(expected in str(imports),
msg='%s not in %s' % (expected, str(imports)))