Internal Change.
authorMichael Case <mikecase@google.com>
Wed, 16 May 2018 18:51:48 +0000 (11:51 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 16 May 2018 18:54:27 +0000 (11:54 -0700)
PiperOrigin-RevId: 196864489

tensorflow/tools/api/generator/create_python_api.py
tensorflow/tools/api/generator/create_python_api_test.py

index d72cb3b..9cb137d 100644 (file)
@@ -20,21 +20,17 @@ from __future__ import print_function
 
 import argparse
 import collections
+import importlib
 import os
 import sys
 
-# Populate `sys.modules` which will be traversed to find TensorFlow modules.
-# Make sure your module gets imported in tensorflow/python/__init__.py for it
-# to be seen by this script.
-import tensorflow.python  # pylint: disable=unused-import
-
 from tensorflow.python.util import tf_decorator
 
 
 _API_CONSTANTS_ATTR = '_tf_api_constants'
 _API_NAMES_ATTR = '_tf_api_names'
 _API_DIR = '/api/'
-_DEFAULT_MODULE_FILTER = 'tensorflow.'
+_DEFAULT_PACKAGE = 'tensorflow.python'
 _OUTPUT_MODULE = 'tensorflow.tools.api.generator.api'
 _GENERATED_FILE_HEADER = """\"\"\"Imports for Python API.
 
@@ -142,7 +138,10 @@ class _ModuleInitCodeBuilder(object):
     # since we import from it using * import.
     underscore_names_str = ', '.join(
         '\'%s\'' % name for name in self._underscore_names_in_root)
-    module_text_map[''] += '''
+    # We will always generate a root __init__.py file to let us handle *
+    # imports consistently. Be sure to have a root __init__.py file listed in
+    # the script outputs.
+    module_text_map[''] = module_text_map.get('', '') + '''
 _names_with_underscore = [%s]
 __all__ = [s for s in dir() if not s.startswith('_')]
 __all__.extend([s for s in _names_with_underscore])
@@ -151,11 +150,12 @@ __all__.extend([s for s in _names_with_underscore])
     return module_text_map
 
 
-def get_api_init_text(module_filter):
+def get_api_init_text(package):
   """Get a map from destination module to __init__.py code for that module.
 
   Args:
-    module_filter: Substring used to filter module names to process.
+    package: Base python package containing python with target tf_export
+      decorators.
 
   Returns:
     A dictionary where
@@ -170,7 +170,7 @@ def get_api_init_text(module_filter):
   for module in list(sys.modules.values()):
     # Only look at tensorflow modules.
     if (not module or not hasattr(module, '__name__') or
-        module_filter not in module.__name__):
+        package not in module.__name__):
       continue
     # Do not generate __init__.py files for contrib modules for now.
     if '.contrib.' in module.__name__ or module.__name__.endswith('.contrib'):
@@ -223,13 +223,14 @@ def get_api_init_text(module_filter):
   return module_code_builder.build()
 
 
-def create_api_files(output_files, module_filter):
+def create_api_files(output_files, package):
   """Creates __init__.py files for the Python API.
 
   Args:
     output_files: List of __init__.py file paths to create.
       Each file must be under api/ directory.
-    module_filter: Substring used to filter module names to process.
+    package: Base python package containing python with target tf_export
+      decorators.
 
   Raises:
     ValueError: if an output file is not under api/ directory,
@@ -257,7 +258,7 @@ def create_api_files(output_files, module_filter):
       os.makedirs(os.path.dirname(file_path))
     open(file_path, 'a').close()
 
-  module_text_map = get_api_init_text(module_filter)
+  module_text_map = get_api_init_text(package)
 
   # Add imports to output files.
   missing_output_files = []
@@ -288,9 +289,9 @@ def main():
       'output. If multiple files are passed in, then we assume output files '
       'are listed directly as arguments.')
   parser.add_argument(
-      '--module_filter', default=_DEFAULT_MODULE_FILTER, type=str,
-      help='Only processes modules with names containing this substring.'
-  )
+      '--package', default=_DEFAULT_PACKAGE, type=str,
+      help='Base package that imports modules containing the target tf_export '
+           'decorators.')
   args = parser.parse_args()
 
   if len(args.outputs) == 1:
@@ -300,7 +301,10 @@ def main():
       outputs = [line.strip() for line in output_list_file.read().split(';')]
   else:
     outputs = args.outputs
-  create_api_files(outputs, args.module_filter)
+
+  # Populate `sys.modules` with modules containing tf_export().
+  importlib.import_module(args.package)
+  create_api_files(outputs, args.package)
 
 
 if __name__ == '__main__':
index 5f10522..986340c 100644 (file)
@@ -37,7 +37,7 @@ class TestClass(object):
 
 
 _TEST_CONSTANT = 5
-_MODULE_NAME = 'test.tensorflow.test_module'
+_MODULE_NAME = 'tensorflow.python.test_module'
 
 
 class CreatePythonApiTest(test.TestCase):
@@ -57,30 +57,34 @@ class CreatePythonApiTest(test.TestCase):
 
   def testFunctionImportIsAdded(self):
     imports = create_python_api.get_api_init_text(
-        module_filter=create_python_api._DEFAULT_MODULE_FILTER)
+        package=create_python_api._DEFAULT_PACKAGE)
     expected_import = (
-        'from test.tensorflow.test_module import test_op as test_op1')
+        'from tensorflow.python.test_module '
+        'import test_op as test_op1')
     self.assertTrue(
         expected_import in str(imports),
         msg='%s not in %s' % (expected_import, str(imports)))
 
-    expected_import = 'from test.tensorflow.test_module import test_op'
+    expected_import = ('from tensorflow.python.test_module '
+                       'import test_op')
     self.assertTrue(
         expected_import in str(imports),
         msg='%s not in %s' % (expected_import, str(imports)))
 
   def testClassImportIsAdded(self):
     imports = create_python_api.get_api_init_text(
-        module_filter=create_python_api._DEFAULT_MODULE_FILTER)
-    expected_import = 'from test.tensorflow.test_module import TestClass'
+        package=create_python_api._DEFAULT_PACKAGE)
+    expected_import = ('from tensorflow.python.test_module '
+                       'import TestClass')
     self.assertTrue(
         'TestClass' in str(imports),
         msg='%s not in %s' % (expected_import, str(imports)))
 
   def testConstantIsAdded(self):
     imports = create_python_api.get_api_init_text(
-        module_filter=create_python_api._DEFAULT_MODULE_FILTER)
-    expected = 'from test.tensorflow.test_module import _TEST_CONSTANT'
+        package=create_python_api._DEFAULT_PACKAGE)
+    expected = ('from tensorflow.python.test_module '
+                'import _TEST_CONSTANT')
     self.assertTrue(expected in str(imports),
                     msg='%s not in %s' % (expected, str(imports)))