Fix cases where we export incorrect symbol with tf_export. This can happen when
authorAnna R <annarev@google.com>
Thu, 22 Mar 2018 20:55:48 +0000 (13:55 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 22 Mar 2018 20:58:12 +0000 (13:58 -0700)
both generated op and its python wrapper have tf_export decorator.
create_python_api.py now checks that we don't export different symbols with same name. Also, simplified some logic.

PiperOrigin-RevId: 190120505

18 files changed:
tensorflow/core/api_def/python_api/api_def_ArgMax.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/python_api/api_def_ArgMin.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/python_api/api_def_CountUpTo.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/python_api/api_def_Div.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/python_api/api_def_Erf.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/python_api/api_def_Identity.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/python_api/api_def_Mod.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/python_api/api_def_Rank.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/python_api/api_def_Round.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/python_api/api_def_ScatterNdUpdate.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/python_api/api_def_ScatterUpdate.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/python_api/api_def_ShapeN.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/python_api/api_def_Sign.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/python_api/api_def_Sqrt.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/python_api/api_def_Square.pbtxt [new file with mode: 0644]
tensorflow/python/framework/python_op_gen.cc
tensorflow/python/ops/math_ops.py
tensorflow/tools/api/generator/create_python_api.py

diff --git a/tensorflow/core/api_def/python_api/api_def_ArgMax.pbtxt b/tensorflow/core/api_def/python_api/api_def_ArgMax.pbtxt
new file mode 100644 (file)
index 0000000..4c23a43
--- /dev/null
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "ArgMax"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_ArgMin.pbtxt b/tensorflow/core/api_def/python_api/api_def_ArgMin.pbtxt
new file mode 100644 (file)
index 0000000..daa14f6
--- /dev/null
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "ArgMin"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_CountUpTo.pbtxt b/tensorflow/core/api_def/python_api/api_def_CountUpTo.pbtxt
new file mode 100644 (file)
index 0000000..f41be2f
--- /dev/null
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "CountUpTo"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Div.pbtxt b/tensorflow/core/api_def/python_api/api_def_Div.pbtxt
new file mode 100644 (file)
index 0000000..8e5537c
--- /dev/null
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "Div"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Erf.pbtxt b/tensorflow/core/api_def/python_api/api_def_Erf.pbtxt
new file mode 100644 (file)
index 0000000..3911672
--- /dev/null
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "Erf"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Identity.pbtxt b/tensorflow/core/api_def/python_api/api_def_Identity.pbtxt
new file mode 100644 (file)
index 0000000..00f2afd
--- /dev/null
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "Identity"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Mod.pbtxt b/tensorflow/core/api_def/python_api/api_def_Mod.pbtxt
new file mode 100644 (file)
index 0000000..48d828c
--- /dev/null
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "Mod"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Rank.pbtxt b/tensorflow/core/api_def/python_api/api_def_Rank.pbtxt
new file mode 100644 (file)
index 0000000..05aa12f
--- /dev/null
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "Rank"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Round.pbtxt b/tensorflow/core/api_def/python_api/api_def_Round.pbtxt
new file mode 100644 (file)
index 0000000..74428e2
--- /dev/null
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "Round"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_ScatterNdUpdate.pbtxt b/tensorflow/core/api_def/python_api/api_def_ScatterNdUpdate.pbtxt
new file mode 100644 (file)
index 0000000..ccf4a9c
--- /dev/null
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "ScatterNdUpdate"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_ScatterUpdate.pbtxt b/tensorflow/core/api_def/python_api/api_def_ScatterUpdate.pbtxt
new file mode 100644 (file)
index 0000000..e4c41c1
--- /dev/null
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "ScatterUpdate"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_ShapeN.pbtxt b/tensorflow/core/api_def/python_api/api_def_ShapeN.pbtxt
new file mode 100644 (file)
index 0000000..b2dbe74
--- /dev/null
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "ShapeN"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Sign.pbtxt b/tensorflow/core/api_def/python_api/api_def_Sign.pbtxt
new file mode 100644 (file)
index 0000000..c2ee91d
--- /dev/null
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "Sign"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Sqrt.pbtxt b/tensorflow/core/api_def/python_api/api_def_Sqrt.pbtxt
new file mode 100644 (file)
index 0000000..59e2dfe
--- /dev/null
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "Sqrt"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Square.pbtxt b/tensorflow/core/api_def/python_api/api_def_Square.pbtxt
new file mode 100644 (file)
index 0000000..7b39ae2
--- /dev/null
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "Square"
+  visibility: HIDDEN
+}
index 03721c9..9850f0b 100644 (file)
@@ -78,7 +78,7 @@ bool IsPythonReserved(const string& s) {
 bool IsOpWithUnderscorePrefix(const string& s) {
   static const std::set<string>* const kUnderscoreOps = new std::set<string>(
       {// Lowercase built-in functions and types in Python, from:
-       // [x for x in dir(__builtins__) if x[0].islower()]
+       // [x for x in dir(__builtins__) if x[0].islower()] except "round".
        // These need to be excluded so they don't conflict with actual built-in
        // functions since we use '*' imports.
        "abs", "all", "any", "apply", "bin", "bool", "buffer", "bytearray",
@@ -90,9 +90,9 @@ bool IsOpWithUnderscorePrefix(const string& s) {
        "iter", "len", "license", "list", "locals", "long", "map", "max",
        "memoryview", "min", "next", "object", "oct", "open", "ord", "pow",
        "print", "property", "quit", "range", "raw_input", "reduce", "reload",
-       "repr", "reversed", "round", "set", "setattr", "slice", "sorted",
-       "staticmethod", "str", "sum", "super", "tuple", "type", "unichr",
-       "unicode", "vars", "xrange", "zip",
+       "repr", "reversed", "set", "setattr", "slice", "sorted", "staticmethod",
+       "str", "sum", "super", "tuple", "type", "unichr", "unicode", "vars",
+       "xrange", "zip",
        // These have the same name as ops defined in Python and might be used
        // incorrectly depending on order of '*' imports.
        // TODO(annarev): reduce usage of '*' imports and remove these from the
index c893bf9..4699e05 100644 (file)
@@ -180,6 +180,8 @@ linspace = gen_math_ops.lin_space
 
 arg_max = deprecation.deprecated(None, "Use `argmax` instead")(arg_max)  # pylint: disable=used-before-assignment
 arg_min = deprecation.deprecated(None, "Use `argmin` instead")(arg_min)  # pylint: disable=used-before-assignment
+tf_export("arg_max")(arg_max)
+tf_export("arg_min")(arg_min)
 
 
 # This is set by resource_variable_ops.py. It is included in this way since
@@ -1196,7 +1198,7 @@ tf_export("floor_div")(floor_div)
 truncatemod = gen_math_ops.truncate_mod
 tf_export("truncatemod")(truncatemod)
 floormod = gen_math_ops.floor_mod
-tf_export("floormod")(floormod)
+tf_export("floormod", "mod")(floormod)
 
 
 def _mul_dispatch(x, y, name=None):
index bb7c3e7..183c473 100644 (file)
@@ -23,7 +23,6 @@ import collections
 import os
 import sys
 
-from tensorflow import python as tf
 from tensorflow.python.util import tf_decorator
 
 
@@ -39,6 +38,11 @@ Generated by: tensorflow/tools/api/generator/create_python_api.py script.
 """
 
 
+class SymbolExposedTwiceError(Exception):
+  """Raised when different symbols are exported with the same name."""
+  pass
+
+
 def format_import(source_module_name, source_name, dest_name):
   """Formats import statement.
 
@@ -63,6 +67,44 @@ def format_import(source_module_name, source_name, dest_name):
       return 'import %s as %s' % (source_name, dest_name)
 
 
+class _ModuleImportsBuilder(object):
+  """Builds a map from module name to imports included in that module."""
+
+  def __init__(self):
+    self.module_imports = collections.defaultdict(list)
+    self._seen_api_names = set()
+
+  def add_import(
+      self, dest_module_name, source_module_name, source_name, dest_name):
+    """Adds this import to module_imports.
+
+    Args:
+      dest_module_name: (string) Module name to add import to.
+      source_module_name: (string) Module to import from.
+      source_name: (string) Name of the symbol to import.
+      dest_name: (string) Import the symbol using this name.
+
+    Raises:
+      SymbolExposedTwiceError: Raised when an import with the same
+        dest_name has already been added to dest_module_name.
+    """
+    import_str = format_import(source_module_name, source_name, dest_name)
+    if import_str in self.module_imports[dest_module_name]:
+      return
+
+    # Check if we are trying to expose two different symbols with same name.
+    full_api_name = dest_name
+    if dest_module_name:
+      full_api_name = dest_module_name + '.' + full_api_name
+    if full_api_name in self._seen_api_names:
+      raise SymbolExposedTwiceError(
+          'Trying to export multiple symbols with same name: %s.' %
+          full_api_name)
+    self._seen_api_names.add(full_api_name)
+
+    self.module_imports[dest_module_name].append(import_str)
+
+
 def get_api_imports():
   """Get a map from destination module to formatted imports.
 
@@ -73,7 +115,9 @@ def get_api_imports():
           (for e.g. 'from foo import bar') and constant
           assignments (for e.g. 'FOO = 123').
   """
-  module_imports = collections.defaultdict(list)
+  module_imports_builder = _ModuleImportsBuilder()
+  visited_symbols = set()
+
   # Traverse over everything imported above. Specifically,
   # we want to traverse over TensorFlow Python modules.
   for module in sys.modules.values():
@@ -86,6 +130,8 @@ def get_api_imports():
 
     for module_contents_name in dir(module):
       attr = getattr(module, module_contents_name)
+      if id(attr) in visited_symbols:
+        continue
 
       # If attr is _tf_api_constants attribute, then add the constants.
       if module_contents_name == _API_CONSTANTS_ATTR:
@@ -93,36 +139,30 @@ def get_api_imports():
           for export in exports:
             names = export.split('.')
             dest_module = '.'.join(names[:-1])
-            import_str = format_import(module.__name__, value, names[-1])
-            module_imports[dest_module].append(import_str)
+            module_imports_builder.add_import(
+                dest_module, module.__name__, value, names[-1])
         continue
 
       _, attr = tf_decorator.unwrap(attr)
       # If attr is a symbol with _tf_api_names attribute, then
       # add import for it.
       if hasattr(attr, '__dict__') and _API_NAMES_ATTR in attr.__dict__:
-        # The same op might be accessible from multiple modules.
-        # We only want to consider location where function was defined.
-        # Here we check if the op is defined in another TensorFlow module in
-        # sys.modules.
-        if (hasattr(attr, '__module__') and
-            attr.__module__.startswith(tf.__name__) and
-            attr.__module__ != module.__name__ and
-            attr.__module__ in sys.modules and
-            module_contents_name in dir(sys.modules[attr.__module__])):
+        # If the same symbol is available using multiple names, only create
+        # imports for it once.
+        if id(attr) in visited_symbols:
           continue
+        visited_symbols.add(id(attr))
 
         for export in attr._tf_api_names:  # pylint: disable=protected-access
           names = export.split('.')
           dest_module = '.'.join(names[:-1])
-          import_str = format_import(
-              module.__name__, module_contents_name, names[-1])
-          module_imports[dest_module].append(import_str)
+          module_imports_builder.add_import(
+              dest_module, module.__name__, module_contents_name, names[-1])
 
   # Import all required modules in their parent modules.
   # For e.g. if we import 'foo.bar.Value'. Then, we also
   # import 'bar' in 'foo'.
-  imported_modules = set(module_imports.keys())
+  imported_modules = set(module_imports_builder.module_imports.keys())
   for module in imported_modules:
     if not module:
       continue
@@ -135,13 +175,11 @@ def get_api_imports():
         parent_module += ('.' + module_split[submodule_index-1] if parent_module
                           else module_split[submodule_index-1])
         import_from += '.' + parent_module
-      submodule_import = format_import(
-          import_from, module_split[submodule_index],
+      module_imports_builder.add_import(
+          parent_module, import_from, module_split[submodule_index],
           module_split[submodule_index])
-      if submodule_import not in module_imports[parent_module]:
-        module_imports[parent_module].append(submodule_import)
 
-  return module_imports
+  return module_imports_builder.module_imports
 
 
 def create_api_files(output_files):