Add option to set more generic module name filter for API generation.
authorMichael Case <mikecase@google.com>
Wed, 9 May 2018 22:07:40 +0000 (15:07 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 9 May 2018 22:10:15 +0000 (15:10 -0700)
PiperOrigin-RevId: 196036164

tensorflow/tools/api/generator/create_python_api.py
tensorflow/tools/api/generator/create_python_api_test.py

index 65baa6e..b6171ce 100644 (file)
@@ -29,6 +29,7 @@ from tensorflow.python.util import tf_decorator
 _API_CONSTANTS_ATTR = '_tf_api_constants'
 _API_NAMES_ATTR = '_tf_api_names'
 _API_DIR = '/api/'
+_DEFAULT_MODULE_FILTER = 'tensorflow.'
 _OUTPUT_MODULE = 'tensorflow.tools.api.generator.api'
 _GENERATED_FILE_HEADER = """\"\"\"Imports for Python API.
 
@@ -145,9 +146,12 @@ __all__.extend([s for s in _names_with_underscore])
     return module_text_map
 
 
-def get_api_init_text():
+def get_api_init_text(module_filter):
   """Get a map from destination module to __init__.py code for that module.
 
+  Args:
+    module_filter: Substring used to filter module names to process.
+
   Returns:
     A dictionary where
       key: (string) destination module (for e.g. tf or tf.consts).
@@ -161,7 +165,7 @@ def get_api_init_text():
   for module in list(sys.modules.values()):
     # Only look at tensorflow modules.
     if (not module or not hasattr(module, '__name__') or
-        'tensorflow.' not in module.__name__):
+        module_filter not in module.__name__):
       continue
     # Do not generate __init__.py files for contrib modules for now.
     if '.contrib.' in module.__name__ or module.__name__.endswith('.contrib'):
@@ -214,12 +218,13 @@ def get_api_init_text():
   return module_code_builder.build()
 
 
-def create_api_files(output_files):
+def create_api_files(output_files, module_filter):
   """Creates __init__.py files for the Python API.
 
   Args:
     output_files: List of __init__.py file paths to create.
       Each file must be under api/ directory.
+    module_filter: Substring used to filter module names to process.
 
   Raises:
     ValueError: if an output file is not under api/ directory,
@@ -247,7 +252,7 @@ def create_api_files(output_files):
       os.makedirs(os.path.dirname(file_path))
     open(file_path, 'a').close()
 
-  module_text_map = get_api_init_text()
+  module_text_map = get_api_init_text(module_filter)
 
   # Add imports to output files.
   missing_output_files = []
@@ -269,10 +274,7 @@ def create_api_files(output_files):
         ',\n'.join(sorted(missing_output_files)))
 
 
-def main(output_files):
-  create_api_files(output_files)
-
-if __name__ == '__main__':
+def main():
   parser = argparse.ArgumentParser()
   parser.add_argument(
       'outputs', metavar='O', type=str, nargs='+',
@@ -280,7 +282,12 @@ if __name__ == '__main__':
       'semicolon-separated list of Python files that we expect this script to '
       'output. If multiple files are passed in, then we assume output files '
       'are listed directly as arguments.')
+  parser.add_argument(
+      '--module_filter', default=_DEFAULT_MODULE_FILTER, type=str,
+      help='Only processes modules with names containing this substring.'
+  )
   args = parser.parse_args()
+
   if len(args.outputs) == 1:
     # If we only get a single argument, then it must be a file containing
     # list of outputs.
@@ -288,4 +295,8 @@ if __name__ == '__main__':
       outputs = [line.strip() for line in output_list_file.read().split(';')]
   else:
     outputs = args.outputs
-  main(outputs)
+  create_api_files(outputs, args.module_filter)
+
+
+if __name__ == '__main__':
+  main()
index 218c812..5f10522 100644 (file)
@@ -56,7 +56,8 @@ class CreatePythonApiTest(test.TestCase):
     del sys.modules[_MODULE_NAME]
 
   def testFunctionImportIsAdded(self):
-    imports = create_python_api.get_api_init_text()
+    imports = create_python_api.get_api_init_text(
+        module_filter=create_python_api._DEFAULT_MODULE_FILTER)
     expected_import = (
         'from test.tensorflow.test_module import test_op as test_op1')
     self.assertTrue(
@@ -69,14 +70,16 @@ class CreatePythonApiTest(test.TestCase):
         msg='%s not in %s' % (expected_import, str(imports)))
 
   def testClassImportIsAdded(self):
-    imports = create_python_api.get_api_init_text()
+    imports = create_python_api.get_api_init_text(
+        module_filter=create_python_api._DEFAULT_MODULE_FILTER)
     expected_import = 'from test.tensorflow.test_module import TestClass'
     self.assertTrue(
         'TestClass' in str(imports),
         msg='%s not in %s' % (expected_import, str(imports)))
 
   def testConstantIsAdded(self):
-    imports = create_python_api.get_api_init_text()
+    imports = create_python_api.get_api_init_text(
+        module_filter=create_python_api._DEFAULT_MODULE_FILTER)
     expected = 'from test.tensorflow.test_module import _TEST_CONSTANT'
     self.assertTrue(expected in str(imports),
                     msg='%s not in %s' % (expected, str(imports)))