bool IsOpWithUnderscorePrefix(const string& s) {
static const std::set<string>* const kUnderscoreOps = new std::set<string>(
{// Lowercase built-in functions and types in Python, from:
- // [x for x in dir(__builtins__) if x[0].islower()]
+ // [x for x in dir(__builtins__) if x[0].islower()] except "round".
// These need to be excluded so they don't conflict with actual built-in
// functions since we use '*' imports.
"abs", "all", "any", "apply", "bin", "bool", "buffer", "bytearray",
"iter", "len", "license", "list", "locals", "long", "map", "max",
"memoryview", "min", "next", "object", "oct", "open", "ord", "pow",
"print", "property", "quit", "range", "raw_input", "reduce", "reload",
- "repr", "reversed", "round", "set", "setattr", "slice", "sorted",
- "staticmethod", "str", "sum", "super", "tuple", "type", "unichr",
- "unicode", "vars", "xrange", "zip",
+ "repr", "reversed", "set", "setattr", "slice", "sorted", "staticmethod",
+ "str", "sum", "super", "tuple", "type", "unichr", "unicode", "vars",
+ "xrange", "zip",
// These have the same name as ops defined in Python and might be used
// incorrectly depending on order of '*' imports.
// TODO(annarev): reduce usage of '*' imports and remove these from the
import os
import sys
-from tensorflow import python as tf
from tensorflow.python.util import tf_decorator
"""
+class SymbolExposedTwiceError(Exception):
+ """Raised when different symbols are exported with the same name."""
+ pass
+
+
def format_import(source_module_name, source_name, dest_name):
"""Formats import statement.
return 'import %s as %s' % (source_name, dest_name)
+class _ModuleImportsBuilder(object):
+ """Builds a map from module name to imports included in that module."""
+
+ def __init__(self):
+ self.module_imports = collections.defaultdict(list)
+ self._seen_api_names = set()
+
+ def add_import(
+ self, dest_module_name, source_module_name, source_name, dest_name):
+ """Adds this import to module_imports.
+
+ Args:
+ dest_module_name: (string) Module name to add import to.
+ source_module_name: (string) Module to import from.
+ source_name: (string) Name of the symbol to import.
+ dest_name: (string) Import the symbol using this name.
+
+ Raises:
+ SymbolExposedTwiceError: Raised when an import with the same
+ dest_name has already been added to dest_module_name.
+ """
+ import_str = format_import(source_module_name, source_name, dest_name)
+ if import_str in self.module_imports[dest_module_name]:
+ return
+
+ # Check if we are trying to expose two different symbols with same name.
+ full_api_name = dest_name
+ if dest_module_name:
+ full_api_name = dest_module_name + '.' + full_api_name
+ if full_api_name in self._seen_api_names:
+ raise SymbolExposedTwiceError(
+ 'Trying to export multiple symbols with same name: %s.' %
+ full_api_name)
+ self._seen_api_names.add(full_api_name)
+
+ self.module_imports[dest_module_name].append(import_str)
+
+
def get_api_imports():
"""Get a map from destination module to formatted imports.
(for e.g. 'from foo import bar') and constant
assignments (for e.g. 'FOO = 123').
"""
- module_imports = collections.defaultdict(list)
+ module_imports_builder = _ModuleImportsBuilder()
+ visited_symbols = set()
+
# Traverse over everything imported above. Specifically,
# we want to traverse over TensorFlow Python modules.
for module in sys.modules.values():
for module_contents_name in dir(module):
attr = getattr(module, module_contents_name)
+ if id(attr) in visited_symbols:
+ continue
# If attr is _tf_api_constants attribute, then add the constants.
if module_contents_name == _API_CONSTANTS_ATTR:
for export in exports:
names = export.split('.')
dest_module = '.'.join(names[:-1])
- import_str = format_import(module.__name__, value, names[-1])
- module_imports[dest_module].append(import_str)
+ module_imports_builder.add_import(
+ dest_module, module.__name__, value, names[-1])
continue
_, attr = tf_decorator.unwrap(attr)
# If attr is a symbol with _tf_api_names attribute, then
# add import for it.
if hasattr(attr, '__dict__') and _API_NAMES_ATTR in attr.__dict__:
- # The same op might be accessible from multiple modules.
- # We only want to consider location where function was defined.
- # Here we check if the op is defined in another TensorFlow module in
- # sys.modules.
- if (hasattr(attr, '__module__') and
- attr.__module__.startswith(tf.__name__) and
- attr.__module__ != module.__name__ and
- attr.__module__ in sys.modules and
- module_contents_name in dir(sys.modules[attr.__module__])):
+ # If the same symbol is available using multiple names, only create
+ # imports for it once.
+ if id(attr) in visited_symbols:
continue
+ visited_symbols.add(id(attr))
for export in attr._tf_api_names: # pylint: disable=protected-access
names = export.split('.')
dest_module = '.'.join(names[:-1])
- import_str = format_import(
- module.__name__, module_contents_name, names[-1])
- module_imports[dest_module].append(import_str)
+ module_imports_builder.add_import(
+ dest_module, module.__name__, module_contents_name, names[-1])
# Import all required modules in their parent modules.
# For e.g. if we import 'foo.bar.Value'. Then, we also
# import 'bar' in 'foo'.
- imported_modules = set(module_imports.keys())
+ imported_modules = set(module_imports_builder.module_imports.keys())
for module in imported_modules:
if not module:
continue
parent_module += ('.' + module_split[submodule_index-1] if parent_module
else module_split[submodule_index-1])
import_from += '.' + parent_module
- submodule_import = format_import(
- import_from, module_split[submodule_index],
+ module_imports_builder.add_import(
+ parent_module, import_from, module_split[submodule_index],
module_split[submodule_index])
- if submodule_import not in module_imports[parent_module]:
- module_imports[parent_module].append(submodule_import)
- return module_imports
+ return module_imports_builder.module_imports
def create_api_files(output_files):