exports_files(["LICENSE"])
+load("//tensorflow/tools/api/generator:api_gen.bzl", "TENSORFLOW_API_INIT_FILES")
+
+py_library(
+ name = "doc_srcs",
+ srcs = ["doc_srcs.py"],
+ srcs_version = "PY2AND3",
+)
+
py_binary(
name = "create_python_api",
srcs = ["create_python_api.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
+ ":doc_srcs",
"//tensorflow/python:no_contrib",
],
)
"//tensorflow/python:client_testlib",
],
)
+
+py_test(
+ name = "tensorflow_doc_srcs_test",
+ srcs = ["doc_srcs_test.py"],
+ args = [
+ "--package=tensorflow.python",
+ ] + TENSORFLOW_API_INIT_FILES,
+ main = "doc_srcs_test.py",
+ srcs_version = "PY2AND3",
+ deps = [
+ ":doc_srcs",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:no_contrib",
+ ],
+)
import sys
from tensorflow.python.util import tf_decorator
+from tensorflow.python.util import tf_export
+from tensorflow.tools.api.generator import doc_srcs
_API_CONSTANTS_ATTR = '_tf_api_constants'
# would have side effects.
'tensorflow.python.platform.flags.FLAGS'
}
-_GENERATED_FILE_HEADER = """\"\"\"Imports for Python API.
-
-This file is MACHINE GENERATED! Do not edit.
-Generated by: tensorflow/tools/api/generator/create_python_api.py script.
+_GENERATED_FILE_HEADER = """# This file is MACHINE GENERATED! Do not edit.
+# Generated by: tensorflow/tools/api/generator/create_python_api.py script.
+\"\"\"%s
\"\"\"
from __future__ import print_function
return dir_path.replace('/', '.').strip('.')
+def get_module_docstring(module_name, package):
+ """Get docstring for the given module.
+
+ This method looks for docstring in the following order:
+ 1. Checks if module has a docstring specified in doc_srcs.
+ 2. Checks if module has a docstring source module specified
+ in doc_srcs. If it does, gets docstring from that module.
+ 3. Checks if module with module_name exists under base package.
+ If it does, gets docstring from that module.
+ 4. Returns a default docstring.
+
+ Args:
+ module_name: module name relative to tensorflow
+ (excluding 'tensorflow.' prefix) to get a docstring for.
+ package: Base python package containing python with target tf_export
+ decorators.
+
+ Returns:
+ One-line docstring to describe the module.
+ """
+ # Module under base package to get a docstring from.
+ docstring_module_name = module_name
+
+ if module_name in doc_srcs.TENSORFLOW_DOC_SOURCES:
+ docsrc = doc_srcs.TENSORFLOW_DOC_SOURCES[module_name]
+ if docsrc.docstring:
+ return docsrc.docstring
+ if docsrc.docstring_module_name:
+ docstring_module_name = docsrc.docstring_module_name
+
+ docstring_module_name = package + '.' + docstring_module_name
+ if (docstring_module_name in sys.modules and
+ sys.modules[docstring_module_name].__doc__):
+ return sys.modules[docstring_module_name].__doc__
+
+ return 'Public API for tf.%s namespace.' % module_name
+
+
def create_api_files(
output_files, package, root_init_template, output_dir):
"""Creates __init__.py files for the Python API.
continue
contents = ''
if module or not root_init_template:
- contents = _GENERATED_FILE_HEADER + text + _GENERATED_FILE_FOOTER
+ contents = (
+ _GENERATED_FILE_HEADER %
+ get_module_docstring(module, package) + text +
+ _GENERATED_FILE_FOOTER)
else:
# Read base init file
with open(root_init_template, 'r') as root_init_template_file:
raise ValueError(
'Missing outputs for python_api_gen genrule:\n%s.'
'Make sure all required outputs are in the '
- 'tensorflow/tools/api/generator/BUILD file.' %
+ 'tensorflow/tools/api/generator/api_gen.bzl file.' %
',\n'.join(sorted(missing_output_files)))
--- /dev/null
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Specifies sources of doc strings for API modules."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+
+# Specifies docstring source for a module.
+# Only one of docstring or docstring_module_name should be set.
+# * If docstring is set, then we will use this docstring when
+# for the module.
+# * If docstring_module_name is set, then we will copy the docstring
+# from docstring source module.
+DocSource = collections.namedtuple(
+ 'DocSource', ['docstring', 'docstring_module_name'])
+# Each attribute of DocSource is optional.
+DocSource.__new__.__defaults__ = (None,) * len(DocSource._fields)
+
+TENSORFLOW_DOC_SOURCES = {
+ 'app': DocSource(docstring_module_name='platform.app'),
+ 'compat': DocSource(docstring_module_name='util.compat'),
+ 'distributions': DocSource(
+ docstring_module_name='ops.distributions.distributions'),
+ 'bitwise': DocSource(docstring_module_name='ops.bitwise_ops'),
+ 'errors': DocSource(docstring_module_name='framework.errors'),
+ 'gfile': DocSource(docstring_module_name='platform.gfile'),
+ 'graph_util': DocSource(docstring_module_name='framework.graph_util'),
+ 'image': DocSource(docstring_module_name='ops.image_ops'),
+ 'keras.estimator': DocSource(docstring_module_name='estimator.keras'),
+ 'linalg': DocSource(docstring_module_name='ops.linalg_ops'),
+ 'logging': DocSource(docstring_module_name='ops.logging_ops'),
+ 'losses': DocSource(docstring_module_name='ops.losses.losses'),
+ 'manip': DocSource(docstring_module_name='ops.manip_ops'),
+ 'math': DocSource(docstring_module_name='ops.math_ops'),
+ 'metrics': DocSource(docstring_module_name='ops.metrics'),
+ 'nn': DocSource(docstring_module_name='ops.nn_ops'),
+ 'nn.rnn_cell': DocSource(docstring_module_name='ops.rnn_cell'),
+ 'python_io': DocSource(docstring_module_name='lib.io.python_io'),
+ 'resource_loader': DocSource(
+ docstring_module_name='platform.resource_loader'),
+ 'sets': DocSource(docstring_module_name='ops.sets'),
+ 'sparse': DocSource(docstring_module_name='ops.sparse_ops'),
+ 'spectral': DocSource(docstring_module_name='ops.spectral_ops'),
+ 'strings': DocSource(docstring_module_name='ops.string_ops'),
+ 'sysconfig': DocSource(docstring_module_name='platform.sysconfig'),
+ 'test': DocSource(docstring_module_name='platform.test'),
+ 'train': DocSource(docstring_module_name='training.training'),
+ 'train.queue_runner': DocSource(
+ docstring_module_name='training.queue_runner'),
+}
--- /dev/null
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Tests for tensorflow.tools.api.generator.doc_srcs."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import importlib
+import sys
+
+from tensorflow.python.platform import test
+from tensorflow.tools.api.generator import doc_srcs
+
+
+FLAGS = None
+
+
+class DocSrcsTest(test.TestCase):
+
+ def testModulesAreValidAPIModules(self):
+ for module_name in doc_srcs.TENSORFLOW_DOC_SOURCES:
+ # Convert module_name to corresponding __init__.py file path.
+ file_path = module_name.replace('.', '/')
+ if file_path:
+ file_path += '/'
+ file_path += '__init__.py'
+
+ if file_path not in FLAGS.outputs:
+ self.assertFalse('%s is not a valid API module' % module_name)
+
+ def testHaveDocstringOrDocstringModule(self):
+ for module_name, docsrc in doc_srcs.TENSORFLOW_DOC_SOURCES.items():
+ if docsrc.docstring and docsrc.docstring_module_name:
+ self.assertFalse(
+ '%s contains DocSource has both a docstring and a '
+ 'docstring_module_name. '
+ 'Only one of "docstring" or "docstring_module_name" should be set.'
+ % (module_name))
+
+ def testDocstringModulesAreValidModules(self):
+ for _, docsrc in doc_srcs.TENSORFLOW_DOC_SOURCES.items():
+ if docsrc.docstring_module_name:
+ doc_module_name = '.'.join([
+ FLAGS.package, docsrc.docstring_module_name])
+ if doc_module_name not in sys.modules:
+ sys.assertFalse(
+ 'docsources_module %s is not a valid module under %s.' %
+ (docsrc.docstring_module_name, FLAGS.package))
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ 'outputs', metavar='O', type=str, nargs='+',
+ help='create_python_api output files.')
+ parser.add_argument(
+ '--package', type=str,
+ help='Base package that imports modules containing the target tf_export '
+ 'decorators.')
+ FLAGS, unparsed = parser.parse_known_args()
+
+ importlib.import_module(FLAGS.package)
+
+ # Now update argv, so that unittest library does not get confused.
+ sys.argv = [sys.argv[0]] + unparsed
+ test.main()