Merging
authorAnna R <annarev@google.com>
Tue, 12 Jun 2018 00:21:06 +0000 (17:21 -0700)
committerAnna R <annarev@google.com>
Wed, 13 Jun 2018 00:16:57 +0000 (17:16 -0700)
tensorflow/tools/api/generator/BUILD
tensorflow/tools/api/generator/create_python_api.py
tensorflow/tools/api/generator/doc_srcs.py [new file with mode: 0644]
tensorflow/tools/api/generator/doc_srcs_test.py [new file with mode: 0644]

index f0c5877..3a28153 100644 (file)
@@ -5,12 +5,21 @@ licenses(["notice"])  # Apache 2.0
 
 exports_files(["LICENSE"])
 
+load("//tensorflow/tools/api/generator:api_gen.bzl", "TENSORFLOW_API_INIT_FILES")
+
+py_library(
+    name = "doc_srcs",
+    srcs = ["doc_srcs.py"],
+    srcs_version = "PY2AND3",
+)
+
 py_binary(
     name = "create_python_api",
     srcs = ["create_python_api.py"],
     srcs_version = "PY2AND3",
     visibility = ["//visibility:public"],
     deps = [
+        ":doc_srcs",
         "//tensorflow/python:no_contrib",
     ],
 )
@@ -24,3 +33,18 @@ py_test(
         "//tensorflow/python:client_testlib",
     ],
 )
+
+py_test(
+    name = "tensorflow_doc_srcs_test",
+    srcs = ["doc_srcs_test.py"],
+    args = [
+        "--package=tensorflow.python",
+    ] + TENSORFLOW_API_INIT_FILES,
+    main = "doc_srcs_test.py",
+    srcs_version = "PY2AND3",
+    deps = [
+        ":doc_srcs",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:no_contrib",
+    ],
+)
index 9f210ad..31f287b 100644 (file)
@@ -25,6 +25,8 @@ import os
 import sys
 
 from tensorflow.python.util import tf_decorator
+from tensorflow.python.util import tf_export
+from tensorflow.tools.api.generator import doc_srcs
 
 
 _API_CONSTANTS_ATTR = '_tf_api_constants'
@@ -36,10 +38,9 @@ _SYMBOLS_TO_SKIP_EXPLICITLY = {
     # would have side effects.
     'tensorflow.python.platform.flags.FLAGS'
 }
-_GENERATED_FILE_HEADER = """\"\"\"Imports for Python API.
-
-This file is MACHINE GENERATED! Do not edit.
-Generated by: tensorflow/tools/api/generator/create_python_api.py script.
+_GENERATED_FILE_HEADER = """# This file is MACHINE GENERATED! Do not edit.
+# Generated by: tensorflow/tools/api/generator/create_python_api.py script.
+\"\"\"%s
 \"\"\"
 
 from __future__ import print_function
@@ -254,6 +255,44 @@ def get_module(dir_path, relative_to_dir):
   return dir_path.replace('/', '.').strip('.')
 
 
+def get_module_docstring(module_name, package):
+  """Get docstring for the given module.
+
+  This method looks for docstring in the following order:
+  1. Checks if module has a docstring specified in doc_srcs.
+  2. Checks if module has a docstring source module specified
+     in doc_srcs. If it does, gets docstring from that module.
+  3. Checks if module with module_name exists under base package.
+     If it does, gets docstring from that module.
+  4. Returns a default docstring.
+
+  Args:
+    module_name: module name relative to tensorflow
+      (excluding 'tensorflow.' prefix) to get a docstring for.
+    package: Base python package containing python with target tf_export
+      decorators.
+
+  Returns:
+    One-line docstring to describe the module.
+  """
+  # Module under base package to get a docstring from.
+  docstring_module_name = module_name
+
+  if module_name in doc_srcs.TENSORFLOW_DOC_SOURCES:
+    docsrc = doc_srcs.TENSORFLOW_DOC_SOURCES[module_name]
+    if docsrc.docstring:
+      return docsrc.docstring
+    if docsrc.docstring_module_name:
+      docstring_module_name = docsrc.docstring_module_name
+
+  docstring_module_name = package + '.' + docstring_module_name
+  if (docstring_module_name in sys.modules and
+      sys.modules[docstring_module_name].__doc__):
+    return sys.modules[docstring_module_name].__doc__
+
+  return 'Public API for tf.%s namespace.' % module_name
+
+
 def create_api_files(
     output_files, package, root_init_template, output_dir):
   """Creates __init__.py files for the Python API.
@@ -296,7 +335,10 @@ def create_api_files(
       continue
     contents = ''
     if module or not root_init_template:
-      contents = _GENERATED_FILE_HEADER + text + _GENERATED_FILE_FOOTER
+      contents = (
+          _GENERATED_FILE_HEADER %
+          get_module_docstring(module, package) + text +
+          _GENERATED_FILE_FOOTER)
     else:
       # Read base init file
       with open(root_init_template, 'r') as root_init_template_file:
@@ -309,7 +351,7 @@ def create_api_files(
     raise ValueError(
         'Missing outputs for python_api_gen genrule:\n%s.'
         'Make sure all required outputs are in the '
-        'tensorflow/tools/api/generator/BUILD file.' %
+        'tensorflow/tools/api/generator/api_gen.bzl file.' %
         ',\n'.join(sorted(missing_output_files)))
 
 
diff --git a/tensorflow/tools/api/generator/doc_srcs.py b/tensorflow/tools/api/generator/doc_srcs.py
new file mode 100644 (file)
index 0000000..74f6db9
--- /dev/null
@@ -0,0 +1,65 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Specifies sources of doc strings for API modules."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+
+# Specifies docstring source for a module.
+# Only one of docstring or docstring_module_name should be set.
+# * If docstring is set, then we will use this docstring when
+#   for the module.
+# * If docstring_module_name is set, then we will copy the docstring
+#   from docstring source module.
+DocSource = collections.namedtuple(
+    'DocSource', ['docstring', 'docstring_module_name'])
+# Each attribute of DocSource is optional.
+DocSource.__new__.__defaults__ = (None,) * len(DocSource._fields)
+
+TENSORFLOW_DOC_SOURCES = {
+    'app': DocSource(docstring_module_name='platform.app'),
+    'compat': DocSource(docstring_module_name='util.compat'),
+    'distributions': DocSource(
+        docstring_module_name='ops.distributions.distributions'),
+    'bitwise': DocSource(docstring_module_name='ops.bitwise_ops'),
+    'errors': DocSource(docstring_module_name='framework.errors'),
+    'gfile': DocSource(docstring_module_name='platform.gfile'),
+    'graph_util': DocSource(docstring_module_name='framework.graph_util'),
+    'image': DocSource(docstring_module_name='ops.image_ops'),
+    'keras.estimator': DocSource(docstring_module_name='estimator.keras'),
+    'linalg': DocSource(docstring_module_name='ops.linalg_ops'),
+    'logging': DocSource(docstring_module_name='ops.logging_ops'),
+    'losses': DocSource(docstring_module_name='ops.losses.losses'),
+    'manip': DocSource(docstring_module_name='ops.manip_ops'),
+    'math': DocSource(docstring_module_name='ops.math_ops'),
+    'metrics': DocSource(docstring_module_name='ops.metrics'),
+    'nn': DocSource(docstring_module_name='ops.nn_ops'),
+    'nn.rnn_cell': DocSource(docstring_module_name='ops.rnn_cell'),
+    'python_io': DocSource(docstring_module_name='lib.io.python_io'),
+    'resource_loader': DocSource(
+        docstring_module_name='platform.resource_loader'),
+    'sets': DocSource(docstring_module_name='ops.sets'),
+    'sparse': DocSource(docstring_module_name='ops.sparse_ops'),
+    'spectral': DocSource(docstring_module_name='ops.spectral_ops'),
+    'strings': DocSource(docstring_module_name='ops.string_ops'),
+    'sysconfig': DocSource(docstring_module_name='platform.sysconfig'),
+    'test': DocSource(docstring_module_name='platform.test'),
+    'train': DocSource(docstring_module_name='training.training'),
+    'train.queue_runner': DocSource(
+        docstring_module_name='training.queue_runner'),
+}
diff --git a/tensorflow/tools/api/generator/doc_srcs_test.py b/tensorflow/tools/api/generator/doc_srcs_test.py
new file mode 100644 (file)
index 0000000..9ba95a3
--- /dev/null
@@ -0,0 +1,80 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Tests for tensorflow.tools.api.generator.doc_srcs."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import importlib
+import sys
+
+from tensorflow.python.platform import test
+from tensorflow.tools.api.generator import doc_srcs
+
+
+FLAGS = None
+
+
+class DocSrcsTest(test.TestCase):
+
+  def testModulesAreValidAPIModules(self):
+    for module_name in doc_srcs.TENSORFLOW_DOC_SOURCES:
+      # Convert module_name to corresponding __init__.py file path.
+      file_path = module_name.replace('.', '/')
+      if file_path:
+        file_path += '/'
+      file_path += '__init__.py'
+
+      if file_path not in FLAGS.outputs:
+        self.assertFalse('%s is not a valid API module' % module_name)
+
+  def testHaveDocstringOrDocstringModule(self):
+    for module_name, docsrc in doc_srcs.TENSORFLOW_DOC_SOURCES.items():
+      if docsrc.docstring and docsrc.docstring_module_name:
+        self.assertFalse(
+            '%s contains DocSource has both a docstring and a '
+            'docstring_module_name. '
+            'Only one of "docstring" or "docstring_module_name" should be set.'
+            % (module_name))
+
+  def testDocstringModulesAreValidModules(self):
+    for _, docsrc in doc_srcs.TENSORFLOW_DOC_SOURCES.items():
+      if docsrc.docstring_module_name:
+        doc_module_name = '.'.join([
+            FLAGS.package, docsrc.docstring_module_name])
+        if doc_module_name not in sys.modules:
+          sys.assertFalse(
+              'docsources_module %s is not a valid module under %s.' %
+              (docsrc.docstring_module_name, FLAGS.package))
+
+
+if __name__ == '__main__':
+  parser = argparse.ArgumentParser()
+  parser.add_argument(
+      'outputs', metavar='O', type=str, nargs='+',
+      help='create_python_api output files.')
+  parser.add_argument(
+      '--package', type=str,
+      help='Base package that imports modules containing the target tf_export '
+           'decorators.')
+  FLAGS, unparsed = parser.parse_known_args()
+
+  importlib.import_module(FLAGS.package)
+
+  # Now update argv, so that unittest library does not get confused.
+  sys.argv = [sys.argv[0]] + unparsed
+  test.main()