Re-apply #18192.
authorAsim Shankar <ashankar@google.com>
Tue, 29 May 2018 07:50:34 +0000 (00:50 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 29 May 2018 07:53:00 +0000 (00:53 -0700)
PiperOrigin-RevId: 198358055

12 files changed:
tensorflow/java/BUILD
tensorflow/java/build_defs.bzl
tensorflow/java/src/gen/cc/java_defs.h
tensorflow/java/src/gen/cc/op_gen_main.cc
tensorflow/java/src/gen/cc/op_generator.cc
tensorflow/java/src/gen/cc/op_generator.h
tensorflow/java/src/gen/cc/op_specs.cc [new file with mode: 0644]
tensorflow/java/src/gen/cc/op_specs.h [new file with mode: 0644]
tensorflow/java/src/gen/cc/source_writer.cc
tensorflow/java/src/gen/cc/source_writer.h
tensorflow/java/src/gen/cc/source_writer_test.cc
tensorflow/java/src/gen/gen_ops.bzl

index 78596d1..19d2133 100644 (file)
@@ -60,9 +60,7 @@ java_library(
 
 filegroup(
     name = "java_op_sources",
-    srcs = glob(["src/main/java/org/tensorflow/op/**/*.java"]) + [
-        ":java_op_gen_sources",
-    ],
+    srcs = glob(["src/main/java/org/tensorflow/op/**/*.java"]) + [":java_op_gen_sources"],
     visibility = [
         "//tensorflow/java:__pkg__",
     ],
@@ -70,43 +68,27 @@ filegroup(
 
 tf_java_op_gen_srcjar(
     name = "java_op_gen_sources",
-    gen_base_package = "org.tensorflow.op",
-    gen_tool = "java_op_gen_tool",
-    ops_libs = [
-        "array_ops",
-        "candidate_sampling_ops",
-        "control_flow_ops",
-        "data_flow_ops",
-        "image_ops",
-        "io_ops",
-        "linalg_ops",
-        "logging_ops",
-        "math_ops",
-        "nn_ops",
-        "no_op",
-        "parsing_ops",
-        "random_ops",
-        "sparse_ops",
-        "state_ops",
-        "string_ops",
-        "training_ops",
-        "user_ops",
+    api_def_srcs = [
+        "//tensorflow/core/api_def:base_api_def",
     ],
+    base_package = "org.tensorflow.op",
+    gen_tool = ":java_op_gen_tool",
 )
 
-# Build the gen tool as a library, as it will be linked to a core/ops binary
-# file before making it an executable. See tf_java_op_gen_srcjar().
-cc_library(
+tf_cc_binary(
     name = "java_op_gen_tool",
     srcs = [
         "src/gen/cc/op_gen_main.cc",
     ],
     copts = tf_copts(),
+    linkopts = ["-lm"],
+    linkstatic = 1,
     deps = [
         ":java_op_gen_lib",
         "//tensorflow/core:framework",
         "//tensorflow/core:framework_internal",
         "//tensorflow/core:lib",
+        "//tensorflow/core:ops",
     ],
 )
 
@@ -114,11 +96,13 @@ cc_library(
     name = "java_op_gen_lib",
     srcs = [
         "src/gen/cc/op_generator.cc",
+        "src/gen/cc/op_specs.cc",
         "src/gen/cc/source_writer.cc",
     ],
     hdrs = [
         "src/gen/cc/java_defs.h",
         "src/gen/cc/op_generator.h",
+        "src/gen/cc/op_specs.h",
         "src/gen/cc/source_writer.h",
     ],
     copts = tf_copts(),
@@ -127,6 +111,9 @@ cc_library(
         "//tensorflow/core:framework_internal",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
+        "//tensorflow/core:op_gen_lib",
+        "//tensorflow/core:protos_all_cc",
+        "@com_googlesource_code_re2//:re2",
     ],
 )
 
index ab7f60d..e1916ca 100644 (file)
@@ -15,6 +15,7 @@ JAVA_VERSION_OPTS = [
 XLINT_OPTS = [
     "-Werror",
     "-Xlint:all",
+    "-Xlint:-processing",
     "-Xlint:-serial",
     "-Xlint:-try",
     "-Xlint:-classfile", # see b/32750402, go/javac-warnings#classfile
index 59f8bea..d9d6f8a 100644 (file)
@@ -16,19 +16,22 @@ limitations under the License.
 #ifndef TENSORFLOW_JAVA_SRC_GEN_CC_JAVA_DEFS_H_
 #define TENSORFLOW_JAVA_SRC_GEN_CC_JAVA_DEFS_H_
 
-#include <string>
 #include <list>
+#include <map>
+#include <string>
+#include <utility>
 
 namespace tensorflow {
 namespace java {
 
 // An enumeration of different modifiers commonly used in Java
 enum Modifier {
-  PUBLIC    = (1 << 0),
+  PACKAGE = 0,
+  PUBLIC = (1 << 0),
   PROTECTED = (1 << 1),
-  PRIVATE   = (1 << 2),
-  STATIC    = (1 << 3),
-  FINAL     = (1 << 4),
+  PRIVATE = (1 << 2),
+  STATIC = (1 << 3),
+  FINAL = (1 << 4),
 };
 
 class Annotation;
@@ -72,6 +75,8 @@ class Type {
     // Reflection API does
     return Type(Type::PRIMITIVE, "void");
   }
+  static Type Generic(const string& name) { return Type(Type::GENERIC, name); }
+  static Type Wildcard() { return Type(Type::GENERIC, ""); }
   static Type Class(const string& name, const string& package = "") {
     return Type(Type::CLASS, name, package);
   }
@@ -81,9 +86,6 @@ class Type {
   static Type Enum(const string& name, const string& package = "") {
     return Type(Type::ENUM, name, package);
   }
-  static Type Generic(const string& name = "") {
-    return Type(Type::GENERIC, name);
-  }
   static Type ClassOf(const Type& type) {
     return Class("Class").add_parameter(type);
   }
@@ -96,11 +98,10 @@ class Type {
   const Kind& kind() const { return kind_; }
   const string& name() const { return name_; }
   const string& package() const { return package_; }
-  const string& description() const { return description_; }
-  Type& description(const string& description) {
-    description_ = description;
-    return *this;
+  const string canonical_name() const {
+    return package_.empty() ? name_ : package_ + "." + name_;
   }
+  bool wildcard() const { return name_.empty(); }  // only wildcards has no name
   const std::list<Type>& parameters() const { return parameters_; }
   Type& add_parameter(const Type& parameter) {
     parameters_.push_back(parameter);
@@ -120,14 +121,6 @@ class Type {
     }
     return *this;
   }
-  // Returns true if "type" is of a known collection type (only a few for now)
-  bool IsCollection() const {
-    return name_ == "List" || name_ == "Iterable";
-  }
-  // Returns true if this instance is a wildcard (<?>)
-  bool IsWildcard() const {
-    return kind_ == GENERIC && name_.empty();
-  }
 
  protected:
   Type(Kind kind, const string& name, const string& package = "")
@@ -137,7 +130,6 @@ class Type {
   Kind kind_;
   string name_;
   string package_;
-  string description_;
   std::list<Type> parameters_;
   std::list<Annotation> annotations_;
   std::list<Type> supertypes_;
@@ -180,16 +172,11 @@ class Variable {
   const string& name() const { return name_; }
   const Type& type() const { return type_; }
   bool variadic() const { return variadic_; }
-  const string& description() const { return description_; }
-  Variable& description(const string& description) {
-    description_ = description;
-    return *this;
-  }
+
  private:
   string name_;
   Type type_;
   bool variadic_;
-  string description_;
 
   Variable(const string& name, const Type& type, bool variadic)
     : name_(name), type_(type), variadic_(variadic) {}
@@ -210,16 +197,6 @@ class Method {
   bool constructor() const { return constructor_; }
   const string& name() const { return name_; }
   const Type& return_type() const { return return_type_; }
-  const string& description() const { return description_; }
-  Method& description(const string& description) {
-    description_ = description;
-    return *this;
-  }
-  const string& return_description() const { return return_description_; }
-  Method& return_description(const string& description) {
-    return_description_ = description;
-    return *this;
-  }
   const std::list<Variable>& arguments() const { return arguments_; }
   Method& add_argument(const Variable& var) {
     arguments_.push_back(var);
@@ -235,8 +212,6 @@ class Method {
   string name_;
   Type return_type_;
   bool constructor_;
-  string description_;
-  string return_description_;
   std::list<Variable> arguments_;
   std::list<Annotation> annotations_;
 
@@ -244,6 +219,33 @@ class Method {
     : name_(name), return_type_(return_type), constructor_(constructor) {}
 };
 
+// A definition of a documentation bloc for a Java element (JavaDoc)
+class Javadoc {
+ public:
+  static Javadoc Create(const string& brief = "") { return Javadoc(brief); }
+  const string& brief() const { return brief_; }
+  const string& details() const { return details_; }
+  Javadoc& details(const string& details) {
+    details_ = details;
+    return *this;
+  }
+  const std::list<std::pair<string, string>>& tags() const { return tags_; }
+  Javadoc& add_tag(const string& tag, const string& text) {
+    tags_.push_back(std::make_pair(tag, text));
+    return *this;
+  }
+  Javadoc& add_param_tag(const string& name, const string& text) {
+    return add_tag("param", name + " " + text);
+  }
+
+ private:
+  string brief_;
+  string details_;
+  std::list<std::pair<string, string>> tags_;
+
+  explicit Javadoc(const string& brief) : brief_(brief) {}
+};
+
 }  // namespace java
 }  // namespace tensorflow
 
index bea99f3..0d9e088 100644 (file)
@@ -36,49 +36,43 @@ const char kUsageHeader[] =
     "Operation wrappers are generated under the path specified by the "
     "'--output_dir' argument. This path can be absolute or relative to the\n"
     "current working directory and will be created if it does not exists.\n\n"
-    "The '--lib_name' argument is used to classify the set of operations. If "
-    "the chosen name contains more than one word, it must be provided in \n"
-    "snake_case. This value is declined into other meaningful names, such as "
-    "the group and package of the generated operations. For example,\n"
-    "'--lib_name=my_lib' generates the operations under the "
-    "'org.tensorflow.op.mylib' package and add them to the 'myLib()' operator\n"
-    "group.\n\n"
-    "Note that the operator group assigned to the generated wrappers is just "
-    "an annotation tag at this stage. Operations will not be available "
-    "through\n"
-    "the 'org.tensorflow.op.Ops' API as a group until the generated classes "
-    "are compiled using an appropriate annotation processor.\n\n"
-    "Finally, the '--base_package' overrides the default parent package "
-    "under which the generated subpackage and classes are to be located.\n\n";
+    "Note that the operations will not be available through the "
+    "'org.tensorflow.op.Ops' API until the generated classes are compiled\n"
+    "using an appropriate annotation processor.\n\n"
+    "The '--base_package' overrides the default parent package under which "
+    "the generated subpackage and classes are to be located.\n\n"
+    "Finally, the `--api_dirs` argument takes a list of comma-separated "
+    "directories of API definitions can be provided to override default\n"
+    "values found in the ops definitions. Directories are ordered by priority "
+    "(the last having precedence over the first).\n\n";
 
 }  // namespace java
 }  // namespace tensorflow
 
 int main(int argc, char* argv[]) {
-  tensorflow::string lib_name;
   tensorflow::string output_dir;
   tensorflow::string base_package = "org.tensorflow.op";
+  tensorflow::string api_dirs_str;
   std::vector<tensorflow::Flag> flag_list = {
       tensorflow::Flag("output_dir", &output_dir,
                        "Root directory into which output files are generated"),
       tensorflow::Flag(
-          "lib_name", &lib_name,
-          "A name, in snake_case, used to classify this set of operations"),
-      tensorflow::Flag(
           "base_package", &base_package,
-          "Package parent to the generated subpackage and classes")};
+          "Package parent to the generated subpackage and classes"),
+      tensorflow::Flag(
+          "api_dirs", &api_dirs_str,
+          "List of directories that contains the ops api definitions")};
   tensorflow::string usage = tensorflow::java::kUsageHeader;
   usage += tensorflow::Flags::Usage(argv[0], flag_list);
   bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
   tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
-  QCHECK(parsed_flags_ok && !lib_name.empty() && !output_dir.empty()) << usage;
-
-  tensorflow::java::OpGenerator generator;
+  QCHECK(parsed_flags_ok && !output_dir.empty()) << usage;
+  std::vector<tensorflow::string> api_dirs = tensorflow::str_util::Split(
+      api_dirs_str, ",", tensorflow::str_util::SkipEmpty());
+  tensorflow::java::OpGenerator generator(api_dirs);
   tensorflow::OpList ops;
-  tensorflow::OpRegistry::Global()->Export(true, &ops);
-  tensorflow::Status status =
-      generator.Run(ops, lib_name, base_package, output_dir);
-  TF_QCHECK_OK(status);
+  tensorflow::OpRegistry::Global()->Export(false, &ops);
+  TF_CHECK_OK(generator.Run(ops, base_package, output_dir));
 
   return 0;
 }
index def06ba..debd95f 100644 (file)
@@ -13,54 +13,466 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include <list>
+#include <map>
+#include <memory>
+#include <set>
 #include <string>
+#include <vector>
 
+#include "tensorflow/core/framework/op_gen_lib.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/java/src/gen/cc/java_defs.h"
 #include "tensorflow/java/src/gen/cc/op_generator.h"
+#include "tensorflow/java/src/gen/cc/op_specs.h"
+#include "tensorflow/java/src/gen/cc/source_writer.h"
 
 namespace tensorflow {
 namespace java {
 namespace {
 
-string CamelCase(const string& str, char delimiter, bool upper) {
-  string result;
-  bool cap = upper;
-  for (string::const_iterator it = str.begin(); it != str.end(); ++it) {
-    const char c = *it;
-    if (c == delimiter) {
-      cap = true;
-    } else if (cap) {
-      result += toupper(c);
-      cap = false;
+const char* kLicense =
+    "/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n"
+    "\n"
+    "Licensed under the Apache License, Version 2.0 (the \"License\");\n"
+    "you may not use this file except in compliance with the License.\n"
+    "You may obtain a copy of the License at\n"
+    "\n"
+    "    http://www.apache.org/licenses/LICENSE-2.0\n"
+    "\n"
+    "Unless required by applicable law or agreed to in writing, software\n"
+    "distributed under the License is distributed on an \"AS IS\" BASIS,\n"
+    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n"
+    "See the License for the specific language governing permissions and\n"
+    "limitations under the License.\n"
+    "=======================================================================*/"
+    "\n";
+
+// There is three different modes to render an op class, depending on the
+// number and type of outputs it has:
+//
+// DEFAULT: This mode does not provide any specialization for the op class, it
+//          is applied when the operation does not comply with any other mode
+//
+// OPERAND: The op class implements the Operand<T> interface, allowing an
+//          instance to be passed directly in input to another operation
+//
+// LIST_OPERAND: The op class implements the Iterable<Operand<T>> interface,
+//          allowing an instance to be passed directly as a list input to
+//          another operation
+//
+enum RenderMode { DEFAULT, OPERAND, LIST_OPERAND };
+
+void AddArgument(const Variable& var, const string& description,
+                 Method* method_out, Javadoc* javadoc_out) {
+  method_out->add_argument(var);
+  javadoc_out->add_param_tag(var.name(), description);
+}
+
+void CollectOpDependencies(const OpSpec& op, RenderMode mode,
+                           std::list<Type>* out) {
+  out->push_back(Type::Class("Operation", "org.tensorflow"));
+  out->push_back(Type::Class("OperationBuilder", "org.tensorflow"));
+  out->push_back(Type::Class("Scope", "org.tensorflow.op"));
+  if (mode == OPERAND) {
+    out->push_back(Type::Class("Output", "org.tensorflow"));
+  } else if (mode == LIST_OPERAND) {
+    out->push_back(Type::Interface("Iterator", "java.util"));
+  }
+  // Don't pay attention to duplicate types in the dependency list, they will
+  // be filtered out by the SourceWriter.
+  for (const ArgumentSpec& input : op.inputs()) {
+    out->push_back(input.var().type());
+    if (input.iterable()) {
+      out->push_back(Type::Class("Operands", "org.tensorflow.op"));
+    }
+  }
+  for (const ArgumentSpec& output : op.outputs()) {
+    out->push_back(output.var().type());
+    if (output.iterable()) {
+      out->push_back(Type::Class("Arrays", "java.util"));
+    }
+  }
+  for (const AttributeSpec& attribute : op.attributes()) {
+    out->push_back(attribute.var().type());
+    out->push_back(attribute.jni_type());
+  }
+  for (const AttributeSpec& optional_attribute : op.optional_attributes()) {
+    out->push_back(optional_attribute.var().type());
+  }
+}
+
+void WriteSetAttrDirective(const AttributeSpec& attr, bool optional,
+                           SourceWriter* writer) {
+  string var_name = optional ? "opts." + attr.var().name() : attr.var().name();
+  if (attr.iterable()) {
+    string array_name = attr.var().name() + "Array";
+    writer->AppendType(attr.jni_type())
+        .Append("[] " + array_name + " = new ")
+        .AppendType(attr.jni_type())
+        .Append("[" + var_name + ".size()];")
+        .EndLine()
+        .BeginBlock("for (int i = 0; i < " + array_name + ".length; ++i)")
+        .Append(array_name + "[i] = ");
+    if (attr.type().kind() == Type::GENERIC) {
+      writer->Append("DataType.fromClass(" + var_name + ".get(i));");
+    } else {
+      writer->Append(var_name + ".get(i);");
+    }
+    writer->EndLine()
+        .EndBlock()
+        .Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", ")
+        .Append(array_name + ");")
+        .EndLine();
+  } else {
+    writer->Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", ");
+    if (attr.var().type().name() == "Class") {
+      writer->Append("DataType.fromClass(" + var_name + "));");
     } else {
-      result += c;
+      writer->Append(var_name + ");");
     }
+    writer->EndLine();
   }
-  return result;
 }
 
-}  // namespace
+void RenderFactoryMethods(const OpSpec& op, const Type& op_class,
+                          SourceWriter* writer) {
+  Method factory = Method::Create("create", op_class);
+  Javadoc factory_doc =
+      Javadoc::Create("Factory method to create a class to wrap a new " +
+                      op_class.name() + " operation to the graph.");
+  Variable scope =
+      Variable::Create("scope", Type::Class("Scope", "org.tensorflow.op"));
+  AddArgument(scope, "current graph scope", &factory, &factory_doc);
+  for (const ArgumentSpec& input : op.inputs()) {
+    AddArgument(input.var(), input.description(), &factory, &factory_doc);
+  }
+  for (const AttributeSpec& attr : op.attributes()) {
+    AddArgument(attr.var(), attr.description(), &factory, &factory_doc);
+  }
+  if (!op.optional_attributes().empty()) {
+    AddArgument(Variable::Varargs("options", Type::Class("Options")),
+                "carries optional attributes values", &factory, &factory_doc);
+  }
+  factory_doc.add_tag("return", "a new instance of " + op_class.name());
 
-OpGenerator::OpGenerator() : env(Env::Default()) {}
+  writer->BeginMethod(factory, PUBLIC | STATIC, &factory_doc);
+  writer->Append("OperationBuilder opBuilder = scope.graph().opBuilder(\"" +
+                 op.graph_op_name() + "\", scope.makeOpName(\"" +
+                 op_class.name() + "\"));");
+  writer->EndLine();
+  for (const ArgumentSpec& input : op.inputs()) {
+    if (input.iterable()) {
+      writer->Append("opBuilder.addInputList(Operands.asOutputs(" +
+                     input.var().name() + "));");
+      writer->EndLine();
+    } else {
+      writer->Append("opBuilder.addInput(" + input.var().name() +
+                     ".asOutput());");
+      writer->EndLine();
+    }
+  }
+  for (const AttributeSpec& attribute : op.attributes()) {
+    WriteSetAttrDirective(attribute, false, writer);
+  }
+  if (!op.optional_attributes().empty()) {
+    writer->BeginBlock("if (options != null)")
+        .BeginBlock("for (Options opts : options)");
+    for (const AttributeSpec& attribute : op.optional_attributes()) {
+      writer->BeginBlock("if (opts." + attribute.var().name() + " != null)");
+      WriteSetAttrDirective(attribute, true, writer);
+      writer->EndBlock();
+    }
+    writer->EndBlock().EndBlock();
+  }
+  writer->Append("return new ")
+      .AppendType(op_class)
+      .Append("(opBuilder.build());")
+      .EndLine();
+  writer->EndMethod();
+}
 
-OpGenerator::~OpGenerator() {}
+void RenderConstructor(const OpSpec& op, const Type& op_class,
+                       SourceWriter* writer) {
+  Variable operation =
+      Variable::Create("operation", Type::Class("Operation", "org.tensorflow"));
+  Method constructor = Method::ConstructorFor(op_class).add_argument(operation);
+  for (const ArgumentSpec& output : op.outputs()) {
+    if (output.iterable() && !output.type().wildcard()) {
+      constructor.add_annotation(
+          Annotation::Create("SuppressWarnings").attributes("\"unchecked\""));
+      break;
+    }
+  }
+  writer->BeginMethod(constructor, PRIVATE)
+      .Append("super(operation);")
+      .EndLine();
+  if (!op.outputs().empty()) {
+    writer->Append("int outputIdx = 0;").EndLine();
+    for (const ArgumentSpec& output : op.outputs()) {
+      if (output.iterable()) {
+        string var_length = output.var().name() + "Length";
+        writer->Append("int " + var_length)
+            .Append(" = operation.outputListLength(\"" + output.op_def_name() +
+                    "\");")
+            .EndLine()
+            .Append(output.var().name() + " = Arrays.asList(");
+        if (!output.type().wildcard()) {
+          writer->Append("(")
+              .AppendType(output.var().type().parameters().front())
+              .Append("[])");
+        }
+        writer->Append("operation.outputList(outputIdx, " + var_length + "));")
+            .EndLine()
+            .Append("outputIdx += " + var_length + ";")
+            .EndLine();
+      } else {
+        writer
+            ->Append(output.var().name() + " = operation.output(outputIdx++);")
+            .EndLine();
+      }
+    }
+  }
+  writer->EndMethod();
+}
 
-Status OpGenerator::Run(const OpList& ops, const string& lib_name,
-                        const string& base_package, const string& output_dir) {
-  const string package =
-      base_package + '.' + str_util::StringReplace(lib_name, "_", "", true);
-  const string package_path =
-      output_dir + '/' + str_util::StringReplace(package, ".", "/", true);
-  const string group = CamelCase(lib_name, '_', false);
+void RenderGettersAndSetters(const OpSpec& op, SourceWriter* writer) {
+  for (const AttributeSpec& attr : op.optional_attributes()) {
+    Method setter = Method::Create(attr.var().name(), Type::Class("Options"));
+    Javadoc setter_doc = Javadoc::Create();
+    AddArgument(attr.var(), attr.description(), &setter, &setter_doc);
+    writer->BeginMethod(setter, PUBLIC | STATIC, &setter_doc)
+        .Append("return new Options()." + attr.var().name() + "(" +
+                attr.var().name() + ");")
+        .EndLine()
+        .EndMethod();
+  }
+  for (const ArgumentSpec& output : op.outputs()) {
+    Method getter = Method::Create(output.var().name(), output.var().type());
+    Javadoc getter_doc = Javadoc::Create(output.description());
+    writer->BeginMethod(getter, PUBLIC, &getter_doc)
+        .Append("return " + output.var().name() + ";")
+        .EndLine()
+        .EndMethod();
+  }
+}
 
-  if (!env->FileExists(package_path).ok()) {
-    TF_CHECK_OK(env->RecursivelyCreateDir(package_path));
+void RenderInterfaceImpl(const OpSpec& op, RenderMode mode,
+                         SourceWriter* writer) {
+  ArgumentSpec output = op.outputs().front();
+
+  if (mode == OPERAND) {
+    bool cast2obj = output.type().wildcard();
+    Type return_type =
+        Type::Class("Output", "org.tensorflow")
+            .add_parameter(cast2obj ? Type::Class("Object") : output.type());
+    Method as_output = Method::Create("asOutput", return_type)
+                           .add_annotation(Annotation::Create("Override"));
+    if (cast2obj) {
+      as_output.add_annotation(
+          Annotation::Create("SuppressWarnings").attributes("\"unchecked\""));
+    }
+    writer->BeginMethod(as_output, PUBLIC);
+    if (cast2obj) {
+      writer->Append("return (").AppendType(return_type).Append(") ");
+    } else {
+      writer->Append("return ");
+    }
+    writer->Append(output.var().name() + ";").EndLine().EndMethod();
+
+  } else if (mode == LIST_OPERAND) {
+    Type operand = Type::Interface("Operand", "org.tensorflow");
+    if (output.type().wildcard()) {
+      operand.add_parameter(Type::Class("Object"));
+    } else {
+      operand.add_parameter(output.type());
+    }
+    Type return_type =
+        Type::Interface("Iterator", "java.util").add_parameter(operand);
+    Method iterator =
+        Method::Create("iterator", return_type)
+            .add_annotation(Annotation::Create("Override"))
+            .add_annotation(Annotation::Create("SuppressWarnings")
+                                .attributes("{\"rawtypes\", \"unchecked\"}"));
+    // cast the output list using a raw List
+    writer->BeginMethod(iterator, PUBLIC)
+        .Append("return (" + return_type.name() + ") ")
+        .Append(output.var().name() + ".iterator();")
+        .EndLine()
+        .EndMethod();
+  }
+}
+
+void RenderOptionsClass(const OpSpec& op, const Type& op_class,
+                        SourceWriter* writer) {
+  Type options_class = Type::Class("Options");
+  Javadoc options_doc = Javadoc::Create("Optional attributes for {@link " +
+                                        op_class.canonical_name() + "}");
+  writer->BeginInnerType(options_class, PUBLIC | STATIC, &options_doc);
+  for (const AttributeSpec& attr : op.optional_attributes()) {
+    Method setter = Method::Create(attr.var().name(), options_class);
+    Javadoc setter_doc = Javadoc::Create();
+    AddArgument(attr.var(), attr.description(), &setter, &setter_doc);
+    writer->BeginMethod(setter, PUBLIC, &setter_doc)
+        .Append("this." + attr.var().name() + " = " + attr.var().name() + ";")
+        .EndLine()
+        .Append("return this;")
+        .EndLine()
+        .EndMethod();
+  }
+  writer->EndLine();
+  for (const AttributeSpec& optional_attribute : op.optional_attributes()) {
+    writer->WriteField(optional_attribute.var(), PRIVATE);
   }
+  Method constructor = Method::ConstructorFor(options_class);
+  writer->BeginMethod(constructor, PRIVATE).EndMethod();
+  writer->EndType();
+}
+
+inline Type ClassOf(const EndpointSpec& endpoint, const string& base_package) {
+  return Type::Class(
+      endpoint.name(),
+      base_package + "." + str_util::Lowercase(endpoint.package()));
+}
 
-  LOG(INFO) << "Generating Java wrappers for '" << lib_name << "' operations";
-  // TODO(karllessard) generate wrappers from list of ops
+void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint,
+                const string& base_package, const string& output_dir,
+                Env* env) {
+  Type op_class(
+      ClassOf(endpoint, base_package)
+          .add_supertype(Type::Class("PrimitiveOp", "org.tensorflow.op")));
+  Javadoc op_javadoc(endpoint.javadoc());
 
+  // op interfaces
+  RenderMode mode = DEFAULT;
+  if (op.outputs().size() == 1) {
+    const ArgumentSpec& output = op.outputs().front();
+    Type operand_type(output.type().wildcard() ? Type::Class("Object")
+                                               : output.type());
+    Type operand_inf(Type::Interface("Operand", "org.tensorflow")
+                         .add_parameter(operand_type));
+    if (output.iterable()) {
+      mode = LIST_OPERAND;
+      op_class.add_supertype(Type::IterableOf(operand_inf));
+    } else {
+      mode = OPERAND;
+      op_class.add_supertype(operand_inf);
+    }
+  }
+  // op generic parameters
+  std::set<string> generics;
+  for (const ArgumentSpec& output : op.outputs()) {
+    if (output.type().kind() == Type::GENERIC && !output.type().wildcard() &&
+        generics.find(output.type().name()) == generics.end()) {
+      op_class.add_parameter(output.type());
+      op_javadoc.add_param_tag(
+          "<" + output.type().name() + ">",
+          "data type for {@code " + output.var().name() + "()} output");
+      generics.insert(output.type().name());
+    }
+  }
+  // op annotations
+  op_class.add_annotation(
+      Annotation::Create("Generated", "javax.annotation")
+          .attributes("value = \"TensorFlow Java Op Generator\""));
+  if (endpoint.deprecated()) {
+    op_class.add_annotation(Annotation::Create("Deprecated"));
+    string explanation;
+    if (!op.endpoints().front().deprecated()) {
+      explanation =
+          "use {@link " +
+          ClassOf(op.endpoints().front(), base_package).canonical_name() +
+          "} instead";
+    } else {
+      explanation = op.deprecation_explanation();
+    }
+    op_javadoc.add_tag("deprecated", explanation);
+  }
+  if (!op.hidden()) {
+    // expose the op in the Ops Graph API only if it is visible
+    op_class.add_annotation(
+        Annotation::Create("Operator", "org.tensorflow.op.annotation")
+            .attributes("group = \"" + endpoint.package() + "\""));
+  }
+  // create op class file
+  const string op_dir_name = io::JoinPath(
+      output_dir, str_util::StringReplace(op_class.package(), ".", "/", true));
+  if (!env->FileExists(op_dir_name).ok()) {
+    TF_CHECK_OK(Env::Default()->RecursivelyCreateDir(op_dir_name))
+        << op_dir_name;
+  }
+  const string op_file_name = op_class.name() + ".java";
+  std::unique_ptr<tensorflow::WritableFile> op_file;
+  TF_CHECK_OK(
+      env->NewWritableFile(io::JoinPath(op_dir_name, op_file_name), &op_file))
+      << op_file_name;
+
+  // render endpoint source code
+  SourceFileWriter writer(op_file.get());
+  std::list<Type> dependencies;
+  CollectOpDependencies(op, mode, &dependencies);
+  writer.Write(kLicense).EndLine().BeginType(op_class, PUBLIC | FINAL,
+                                             &dependencies, &op_javadoc);
+  if (!op.optional_attributes().empty()) {
+    RenderOptionsClass(op, op_class, &writer);
+  }
+  RenderFactoryMethods(op, op_class, &writer);
+  RenderGettersAndSetters(op, &writer);
+  if (mode != DEFAULT) {
+    RenderInterfaceImpl(op, mode, &writer);
+  }
+  writer.EndLine();
+  for (const ArgumentSpec& output : op.outputs()) {
+    writer.WriteField(output.var(), PRIVATE);
+  }
+  RenderConstructor(op, op_class, &writer);
+  writer.EndType();
+}
+
+bool CanGenerateOp(const OpDef& op_def, const ApiDef& api_def) {
+  if (api_def.visibility() == ApiDef::SKIP) {
+    return false;
+  }
+  for (const auto& attr : op_def.attr()) {
+    if (attr.type() == "func") {
+      return false;  // TODO(karllessard) add support for function attributes
+    }
+  }
+  return true;
+}
+
+}  // namespace
+
+Status OpGenerator::Run(const OpList& op_list, const string& base_package,
+                        const string& output_dir) {
+  ApiDefMap api_map(op_list);
+  if (!api_dirs_.empty()) {
+    // Only load api files that correspond to the requested "op_list"
+    for (const auto& op : op_list.op()) {
+      for (const auto& api_def_dir : api_dirs_) {
+        const std::string api_def_file_pattern =
+            io::JoinPath(api_def_dir, "api_def_" + op.name() + ".pbtxt");
+        if (env_->FileExists(api_def_file_pattern).ok()) {
+          TF_CHECK_OK(api_map.LoadFile(env_, api_def_file_pattern))
+              << api_def_file_pattern;
+        }
+      }
+    }
+  }
+  api_map.UpdateDocs();
+  for (const auto& op_def : op_list.op()) {
+    const ApiDef* api_def = api_map.GetApiDef(op_def.name());
+    if (CanGenerateOp(op_def, *api_def)) {
+      OpSpec op(OpSpec::Create(op_def, *api_def));
+      for (const EndpointSpec& endpoint : op.endpoints()) {
+        GenerateOp(op, endpoint, base_package, output_dir, env_);
+      }
+    }
+  }
   return Status::OK();
 }
 
index 4b55ed3..05decd6 100644 (file)
@@ -17,34 +17,39 @@ limitations under the License.
 #define TENSORFLOW_JAVA_SRC_GEN_CC_OP_GENERATOR_H_
 
 #include <string>
+#include <vector>
 
-#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/api_def.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/env.h"
+#include "tensorflow/java/src/gen/cc/op_specs.h"
 
 namespace tensorflow {
 namespace java {
 
-/// \brief A generator of Java operation wrappers.
-///
-/// Such generator is normally ran only once per executable, outputting
-/// wrappers for the all registered operations it has been compiled with.
-/// Nonetheless, it is designed to support multiple runs, giving a different
-/// list of operations on each cycle.
+// A generator of Java operation wrappers.
+//
+// This generator takes a list of ops definitions in input and outputs
+// a Java Op wrapper for each of them in the provided directory. The same
+// generator instance can be invoked multiple times with a different list of
+// ops definitions.
 class OpGenerator {
  public:
-  OpGenerator();
-  virtual ~OpGenerator();
+  explicit OpGenerator(const std::vector<string>& api_dirs,
+                       Env* env = Env::Default())
+      : api_dirs_(api_dirs), env_(env) {}
 
-  /// \brief Generates wrappers for the given list of 'ops'.
-  ///
-  /// Output files are generated in <output_dir>/<base_package>/<lib_package>,
-  /// where 'lib_package' is derived from 'lib_name'.
-  Status Run(const OpList& ops, const string& lib_name,
-             const string& base_package, const string& output_dir);
+  // Generates wrappers for the given list of 'ops'.
+  //
+  // Output files are generated in <output_dir>/<base_package>/<op_package>,
+  // where 'op_package' is derived from ops endpoints.
+  Status Run(const OpList& op_list, const string& base_package,
+             const string& output_dir);
 
  private:
-  Env* env;
+  const std::vector<string> api_dirs_;
+  Env* env_;
 };
 
 }  // namespace java
diff --git a/tensorflow/java/src/gen/cc/op_specs.cc b/tensorflow/java/src/gen/cc/op_specs.cc
new file mode 100644 (file)
index 0000000..181fd4c
--- /dev/null
@@ -0,0 +1,423 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <map>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "re2/re2.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/java/src/gen/cc/op_specs.h"
+
+namespace tensorflow {
+namespace java {
+namespace {
+
+inline bool IsRealNumbers(const AttrValue& values) {
+  if (!values.has_list()) {
+    return RealNumberTypes().Contains(values.type());
+  }
+  for (int i = 0; i < values.list().type_size(); ++i) {
+    if (!RealNumberTypes().Contains(values.list().type(i))) {
+      return false;
+    }
+  }
+  return true;
+}
+
+class TypeResolver {
+ public:
+  explicit TypeResolver(const OpDef& op_def) : op_def_(op_def) {}
+
+  // Returns the class type of an input/output argument
+  //
+  // For example, if the argument's datatype is DT_STRING, this method will
+  // return "java.lang.String", so the argument can become "Operand<String>"
+  // in the Ops API
+  Type TypeOf(const OpDef_ArgDef& arg_def, bool* iterable_out);
+
+  // Returns types of an input attribute
+  //
+  // The first element of the pair is the class type of this attribute while
+  // the second is its JNI/primitive type equivalent, required for explicit
+  // unboxing.
+  //
+  // For example, if the attribute is of type "float", this method will return
+  // <java.lang.Float, float>, so the attribute can be used as a "Float" object
+  // in the Ops API and casted to a "float" when passing through the JNI layer.
+  std::pair<Type, Type> TypesOf(const OpDef_AttrDef& attr_def,
+                                bool* iterable_out);
+
+  // Returns true if the type of this attribute has already been resolved
+  bool IsAttributeVisited(const string& attr_name) {
+    return visited_attrs_.find(attr_name) != visited_attrs_.cend();
+  }
+
+ private:
+  const OpDef op_def_;
+  std::map<std::string, Type> visited_attrs_;
+  char next_generic_letter_ = 'T';
+
+  std::pair<Type, Type> MakeTypePair(const Type& type, const Type& jni_type) {
+    return std::make_pair(type, jni_type);
+  }
+  std::pair<Type, Type> MakeTypePair(const Type& type) {
+    return std::make_pair(type, type);
+  }
+  Type NextGeneric() {
+    char generic_letter = next_generic_letter_++;
+    if (next_generic_letter_ > 'Z') {
+      next_generic_letter_ = 'A';
+    }
+    return Type::Generic(string(1, generic_letter));
+  }
+};
+
+Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def, bool* iterable_out) {
+  *iterable_out = false;
+  if (!arg_def.number_attr().empty()) {
+    // when number_attr is set, argument has to be a list of tensors
+    *iterable_out = true;
+    visited_attrs_.insert(std::make_pair(arg_def.number_attr(), Type::Int()));
+  }
+  Type type = Type::Wildcard();
+  if (arg_def.type() != DataType::DT_INVALID) {
+    // resolve type from DataType
+    switch (arg_def.type()) {
+      case DataType::DT_BOOL:
+        type = Type::Class("Boolean");
+        break;
+      case DataType::DT_STRING:
+        type = Type::Class("String");
+        break;
+      case DataType::DT_FLOAT:
+        type = Type::Class("Float");
+        break;
+      case DataType::DT_DOUBLE:
+        type = Type::Class("Double");
+        break;
+      case DataType::DT_UINT8:
+        type = Type::Class("UInt8", "org.tensorflow.types");
+        break;
+      case DataType::DT_INT32:
+        type = Type::Class("Integer");
+        break;
+      case DataType::DT_INT64:
+        type = Type::Class("Long");
+        break;
+      case DataType::DT_RESOURCE:
+        // TODO(karllessard) create a Resource utility class that could be
+        // used to store a resource and its type (passed in a second argument).
+        // For now, we need to force a wildcard and we will unfortunately lose
+        // track of the resource type.
+        break;
+      default:
+        // Any other datatypes does not have a equivalent in Java and must
+        // remain a wildcard (e.g. DT_COMPLEX64, DT_QINT8, ...)
+        break;
+    }
+  } else if (!arg_def.type_attr().empty()) {
+    // resolve type from attribute (if already visited, retrieve its type)
+    if (IsAttributeVisited(arg_def.type_attr())) {
+      type = visited_attrs_.at(arg_def.type_attr());
+    } else {
+      for (const auto& attr_def : op_def_.attr()) {
+        if (attr_def.name() == arg_def.type_attr()) {
+          type = TypesOf(attr_def, iterable_out).first;
+          break;
+        }
+      }
+    }
+  } else if (!arg_def.type_list_attr().empty()) {
+    // type is a list of tensors that can be of different data types, so leave
+    // it as a list of wildcards
+    *iterable_out = true;
+    visited_attrs_.insert(std::make_pair(arg_def.type_list_attr(), type));
+
+  } else {
+    LOG(FATAL) << "Cannot resolve data type of argument \"" << arg_def.name()
+               << "\" in operation \"" << op_def_.name() << "\"";
+  }
+  return type;
+}
+
+std::pair<Type, Type> TypeResolver::TypesOf(const OpDef_AttrDef& attr_def,
+                                            bool* iterable_out) {
+  std::pair<Type, Type> types = MakeTypePair(Type::Wildcard());
+  *iterable_out = false;
+  StringPiece attr_type = attr_def.type();
+  if (str_util::ConsumePrefix(&attr_type, "list(")) {
+    attr_type.remove_suffix(1);  // remove closing brace
+    *iterable_out = true;
+  }
+  if (attr_type == "string") {
+    types = MakeTypePair(Type::Class("String"));
+
+  } else if (attr_type == "int") {
+    types = MakeTypePair(Type::Class("Long"), Type::Long());
+
+  } else if (attr_type == "float") {
+    types = MakeTypePair(Type::Class("Float"), Type::Float());
+
+  } else if (attr_type == "bool") {
+    types = MakeTypePair(Type::Class("Boolean"), Type::Boolean());
+
+  } else if (attr_type == "shape") {
+    types = MakeTypePair(Type::Class("Shape", "org.tensorflow"));
+
+  } else if (attr_type == "tensor") {
+    types = MakeTypePair(Type::Class("Tensor", "org.tensorflow")
+                             .add_parameter(Type::Wildcard()));
+
+  } else if (attr_type == "type") {
+    Type type = *iterable_out ? Type::Wildcard() : NextGeneric();
+    if (IsRealNumbers(attr_def.allowed_values())) {
+      type.add_supertype(Type::Class("Number"));
+    }
+    types = MakeTypePair(type, Type::Enum("DataType", "org.tensorflow"));
+
+  } else {
+    LOG(FATAL) << "Cannot resolve data type for attribute \"" << attr_type
+               << "\" in operation \"" << op_def_.name() << "\"";
+  }
+  visited_attrs_.insert(std::make_pair(attr_def.name(), types.first));
+  return types;
+}
+
+string SnakeToCamelCase(const string& str, bool upper = false) {
+  string result;
+  bool cap = upper;
+  for (string::const_iterator it = str.begin(); it != str.end(); ++it) {
+    const char c = *it;
+    if (c == '_') {
+      cap = true;
+    } else if (cap) {
+      result += toupper(c);
+      cap = false;
+    } else {
+      result += c;
+    }
+  }
+  return result;
+}
+
+bool FindAndCut(string* input, const RE2& expr, string* before_match,
+                string* ret_match = nullptr) {
+  string match;
+  if (!RE2::PartialMatch(*input, expr, &match)) return false;
+  *before_match = input->substr(0, input->find(match));
+  *input = input->substr(before_match->size() + match.size());
+  if (ret_match != nullptr) *ret_match = match;
+  return true;
+}
+
+string ParseDocumentation(const string& inp) {
+  std::stringstream javadoc_text;
+
+  // TODO(karllessard) This is a very minimalist utility method for converting
+  // markdown syntax, as found in ops descriptions, to Javadoc/html tags. Check
+  // for alternatives to increase the level of support for markups.
+  std::vector<string> markups_subexpr;
+  markups_subexpr.push_back("\n+\\*\\s+");                // lists
+  markups_subexpr.push_back("\n{2,}");                    // paragraphs
+  markups_subexpr.push_back("`{3,}\\s*[^\\s\n]*\\s*\n");  // code blocks
+  markups_subexpr.push_back("`+");           // inlined code and code blocks
+  markups_subexpr.push_back("\\*{1,2}\\b");  // text emphasis
+  markups_subexpr.push_back("\\[");          // hyperlinks
+  const RE2 markup_expr("(" + str_util::Join(markups_subexpr, "|") + ")");
+
+  bool in_list = false;
+  string input = inp;
+  while (true) {
+    string text, markup;
+    if (!FindAndCut(&input, markup_expr, &text, &markup)) {
+      javadoc_text << input;
+      break;  // end of loop
+    }
+    javadoc_text << text;
+    if (str_util::StartsWith(markup, "\n")) {
+      javadoc_text << "\n";
+      if (str_util::StrContains(markup, "*")) {
+        // new list item
+        javadoc_text << (in_list ? "</li>\n" : "<ul>\n") << "<li>\n";
+        in_list = true;
+      } else if (in_list) {
+        // end of list
+        javadoc_text << "</li>\n</ul>\n";
+        in_list = false;
+      } else if (!str_util::StartsWith(input, "```")) {
+        // new paragraph (not required if a <pre> block follows)
+        javadoc_text << "<p>\n";
+      }
+    } else if (str_util::StartsWith(markup, "```")) {
+      // code blocks
+      if (FindAndCut(&input, "(```\\s*\n*)", &text)) {
+        javadoc_text << "<pre>{@code\n" << text << "}</pre>\n";
+      } else {
+        javadoc_text << markup;
+      }
+    } else if (str_util::StartsWith("(" + markup + ")", "`")) {
+      // inlined code
+      if (FindAndCut(&input, markup, &text)) {
+        javadoc_text << "{@code " << text << "}";
+      } else {
+        javadoc_text << markup;
+      }
+    } else if (markup == "**") {
+      // text emphasis (strong)
+      if (FindAndCut(&input, "(\\b\\*{2})", &text)) {
+        javadoc_text << "<b>" << ParseDocumentation(text) << "</b>";
+      } else {
+        javadoc_text << markup;
+      }
+    } else if (markup == "*") {
+      // text emphasis (normal)
+      if (FindAndCut(&input, "(\\b\\*{1})", &text)) {
+        javadoc_text << "<i>" << ParseDocumentation(text) << "</i>";
+      } else {
+        javadoc_text << markup;
+      }
+    } else if (str_util::StartsWith(markup, "[")) {
+      // hyperlinks
+      string label;
+      string link;
+      if (RE2::PartialMatch(input, "([^\\[]+)\\]\\((http.+)\\)", &label,
+                            &link) &&
+          str_util::StartsWith(input, label + link)) {
+        input = input.substr(label.size() + link.size());
+        javadoc_text << "<a href=\"" << link << "\">"
+                     << ParseDocumentation(label) << "</a>";
+      } else {
+        javadoc_text << markup;
+      }
+    } else {
+      // safe fallback
+      javadoc_text << markup;
+    }
+  }
+  return javadoc_text.str();
+}
+
+ArgumentSpec CreateInput(const OpDef_ArgDef& input_def,
+                         const ApiDef::Arg& input_api_def,
+                         TypeResolver* type_resolver) {
+  bool iterable = false;
+  Type type = type_resolver->TypeOf(input_def, &iterable);
+  Type var_type =
+      Type::Interface("Operand", "org.tensorflow").add_parameter(type);
+  if (iterable) {
+    var_type = Type::IterableOf(var_type);
+  }
+  return ArgumentSpec(
+      input_api_def.name(),
+      Variable::Create(SnakeToCamelCase(input_api_def.rename_to()), var_type),
+      type, ParseDocumentation(input_api_def.description()), iterable);
+}
+
+AttributeSpec CreateAttribute(const OpDef_AttrDef& attr_def,
+                              const ApiDef::Attr& attr_api_def,
+                              TypeResolver* type_resolver) {
+  bool iterable = false;
+  std::pair<Type, Type> types = type_resolver->TypesOf(attr_def, &iterable);
+  Type var_type = types.first.kind() == Type::GENERIC
+                      ? Type::Class("Class").add_parameter(types.first)
+                      : types.first;
+  if (iterable) {
+    var_type = Type::ListOf(var_type);
+  }
+  return AttributeSpec(
+      attr_api_def.name(),
+      Variable::Create(SnakeToCamelCase(attr_api_def.rename_to()), var_type),
+      types.first, types.second, ParseDocumentation(attr_api_def.description()),
+      iterable, attr_api_def.has_default_value());
+}
+
+ArgumentSpec CreateOutput(const OpDef_ArgDef& output_def,
+                          const ApiDef::Arg& output_api,
+                          TypeResolver* type_resolver) {
+  bool iterable = false;
+  Type type = type_resolver->TypeOf(output_def, &iterable);
+  Type var_type = Type::Class("Output", "org.tensorflow").add_parameter(type);
+  if (iterable) {
+    var_type = Type::ListOf(var_type);
+  }
+  return ArgumentSpec(
+      output_api.name(),
+      Variable::Create(SnakeToCamelCase(output_api.rename_to()), var_type),
+      type, ParseDocumentation(output_api.description()), iterable);
+}
+
+EndpointSpec CreateEndpoint(const OpDef& op_def, const ApiDef& api_def,
+                            const ApiDef_Endpoint& endpoint_def) {
+  std::vector<string> name_tokens = str_util::Split(endpoint_def.name(), ".");
+  string package;
+  string name;
+  if (name_tokens.size() > 1) {
+    package = name_tokens.at(0);
+    name = name_tokens.at(1);
+  } else {
+    package = "core";  // generate unclassified ops in the 'core' package
+    name = name_tokens.at(0);
+  }
+  return EndpointSpec(package, name,
+                      Javadoc::Create(ParseDocumentation(api_def.summary()))
+                          .details(ParseDocumentation(api_def.description())));
+}
+
+}  // namespace
+
+OpSpec OpSpec::Create(const OpDef& op_def, const ApiDef& api_def) {
+  OpSpec op(api_def.graph_op_name(), api_def.visibility() == ApiDef::HIDDEN,
+            op_def.deprecation().explanation());
+  TypeResolver type_resolver(op_def);
+  for (const string& next_input_name : api_def.arg_order()) {
+    for (int i = 0; i < op_def.input_arg().size(); ++i) {
+      if (op_def.input_arg(i).name() == next_input_name) {
+        op.inputs_.push_back(CreateInput(op_def.input_arg(i), api_def.in_arg(i),
+                                         &type_resolver));
+        break;
+      }
+    }
+  }
+  for (int i = 0; i < op_def.attr().size(); ++i) {
+    // do not parse attributes already visited, they have probably been inferred
+    // before as an input argument type
+    if (!type_resolver.IsAttributeVisited(op_def.attr(i).name())) {
+      AttributeSpec attr =
+          CreateAttribute(op_def.attr(i), api_def.attr(i), &type_resolver);
+      // attributes with a default value are optional
+      if (attr.has_default_value() && attr.type().kind() != Type::GENERIC) {
+        op.optional_attributes_.push_back(attr);
+      } else {
+        op.attributes_.push_back(attr);
+      }
+    }
+  }
+  for (int i = 0; i < op_def.output_arg().size(); ++i) {
+    op.outputs_.push_back(
+        CreateOutput(op_def.output_arg(i), api_def.out_arg(i), &type_resolver));
+  }
+  for (const auto& endpoint_def : api_def.endpoint()) {
+    op.endpoints_.push_back(CreateEndpoint(op_def, api_def, endpoint_def));
+  }
+  return op;
+}
+
+}  // namespace java
+}  // namespace tensorflow
diff --git a/tensorflow/java/src/gen/cc/op_specs.h b/tensorflow/java/src/gen/cc/op_specs.h
new file mode 100644 (file)
index 0000000..ca0ba16
--- /dev/null
@@ -0,0 +1,173 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_JAVA_SRC_GEN_CC_OP_SPECS_H_
+#define TENSORFLOW_JAVA_SRC_GEN_CC_OP_SPECS_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/api_def.pb.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/java/src/gen/cc/java_defs.h"
+
+namespace tensorflow {
+namespace java {
+
+class EndpointSpec {
+ public:
+  // A specification for an operation endpoint
+  //
+  // package: package of this endpoint (from which also derives its package)
+  // name: name of this endpoint class
+  // javadoc: the endpoint class documentation
+  // TODO(annarev): hardcode depcreated to false until deprecated is possible
+  EndpointSpec(const string& package, const string& name,
+               const Javadoc& javadoc)
+      : package_(package), name_(name), javadoc_(javadoc), deprecated_(false) {}
+
+  const string& package() const { return package_; }
+  const string& name() const { return name_; }
+  const Javadoc& javadoc() const { return javadoc_; }
+  bool deprecated() const { return deprecated_; }
+
+ private:
+  const string package_;
+  const string name_;
+  const Javadoc javadoc_;
+  const bool deprecated_;
+};
+
+class ArgumentSpec {
+ public:
+  // A specification for an operation argument
+  //
+  // op_def_name: argument name, as known by TensorFlow core
+  // var: a variable to represent this argument in Java
+  // type: the tensor type of this argument
+  // description: a description of this argument, in javadoc
+  // iterable: true if this argument is a list
+  ArgumentSpec(const string& op_def_name, const Variable& var, const Type& type,
+               const string& description, bool iterable)
+      : op_def_name_(op_def_name),
+        var_(var),
+        type_(type),
+        description_(description),
+        iterable_(iterable) {}
+
+  const string& op_def_name() const { return op_def_name_; }
+  const Variable& var() const { return var_; }
+  const Type& type() const { return type_; }
+  const string& description() const { return description_; }
+  bool iterable() const { return iterable_; }
+
+ private:
+  const string op_def_name_;
+  const Variable var_;
+  const Type type_;
+  const string description_;
+  const bool iterable_;
+};
+
+class AttributeSpec {
+ public:
+  // A specification for an operation attribute
+  //
+  // op_def_name: attribute name, as known by TensorFlow core
+  // var: a variable to represent this attribute in Java
+  // type: the type of this attribute
+  // jni_type: the type of this attribute in JNI layer (see OperationBuilder)
+  // description: a description of this attribute, in javadoc
+  // iterable: true if this attribute is a list
+  // has_default_value: true if this attribute has a default value if not set
+  AttributeSpec(const string& op_def_name, const Variable& var,
+                const Type& type, const Type& jni_type,
+                const string& description, bool iterable,
+                bool has_default_value)
+      : op_def_name_(op_def_name),
+        var_(var),
+        type_(type),
+        description_(description),
+        iterable_(iterable),
+        jni_type_(jni_type),
+        has_default_value_(has_default_value) {}
+
+  const string& op_def_name() const { return op_def_name_; }
+  const Variable& var() const { return var_; }
+  const Type& type() const { return type_; }
+  const string& description() const { return description_; }
+  bool iterable() const { return iterable_; }
+  const Type& jni_type() const { return jni_type_; }
+  bool has_default_value() const { return has_default_value_; }
+
+ private:
+  const string op_def_name_;
+  const Variable var_;
+  const Type type_;
+  const string description_;
+  const bool iterable_;
+  const Type jni_type_;
+  const bool has_default_value_;
+};
+
+class OpSpec {
+ public:
+  // Parses an op definition and its API to produce a specification used for
+  // rendering its Java wrapper
+  //
+  // op_def: Op definition
+  // api_def: Op API definition
+  static OpSpec Create(const OpDef& op_def, const ApiDef& api_def);
+
+  const string& graph_op_name() const { return graph_op_name_; }
+  bool hidden() const { return hidden_; }
+  const string& deprecation_explanation() const {
+    return deprecation_explanation_;
+  }
+  const std::vector<EndpointSpec> endpoints() const { return endpoints_; }
+  const std::vector<ArgumentSpec>& inputs() const { return inputs_; }
+  const std::vector<ArgumentSpec>& outputs() const { return outputs_; }
+  const std::vector<AttributeSpec>& attributes() const { return attributes_; }
+  const std::vector<AttributeSpec>& optional_attributes() const {
+    return optional_attributes_;
+  }
+
+ private:
+  // A specification for an operation
+  //
+  // graph_op_name: name of this op, as known by TensorFlow core engine
+  // hidden: true if this op should not be visible through the Graph Ops API
+  // deprecation_explanation: message to show if all endpoints are deprecated
+  explicit OpSpec(const string& graph_op_name, bool hidden,
+                  const string& deprecation_explanation)
+      : graph_op_name_(graph_op_name),
+        hidden_(hidden),
+        deprecation_explanation_(deprecation_explanation) {}
+
+  const string graph_op_name_;
+  const bool hidden_;
+  const string deprecation_explanation_;
+  std::vector<EndpointSpec> endpoints_;
+  std::vector<ArgumentSpec> inputs_;
+  std::vector<ArgumentSpec> outputs_;
+  std::vector<AttributeSpec> attributes_;
+  std::vector<AttributeSpec> optional_attributes_;
+};
+
+}  // namespace java
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_JAVA_SRC_GEN_CC_OP_SPECS_H_
index a02f75a..66401bd 100644 (file)
@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include <string>
 #include <algorithm>
-#include <deque>
+#include <list>
+#include <string>
 
 #include "tensorflow/java/src/gen/cc/source_writer.h"
 
@@ -83,20 +83,22 @@ SourceWriter& SourceWriter::Append(const StringPiece& str) {
 }
 
 SourceWriter& SourceWriter::AppendType(const Type& type) {
-  if (type.kind() == Type::Kind::GENERIC && type.name().empty()) {
+  if (type.wildcard()) {
     Append("?");
   } else {
     Append(type.name());
-  }
-  if (!type.parameters().empty()) {
-    Append("<");
-    for (const Type& t : type.parameters()) {
-      if (&t != &type.parameters().front()) {
-        Append(", ");
+    if (!type.parameters().empty()) {
+      Append("<");
+      bool first = true;
+      for (const Type& t : type.parameters()) {
+        if (!first) {
+          Append(", ");
+        }
+        AppendType(t);
+        first = false;
       }
-      AppendType(t);
+      Append(">");
     }
-    Append(">");
   }
   return *this;
 }
@@ -107,7 +109,21 @@ SourceWriter& SourceWriter::EndLine() {
   return *this;
 }
 
-SourceWriter& SourceWriter::BeginMethod(const Method& method, int modifiers) {
+SourceWriter& SourceWriter::BeginBlock(const string& expression) {
+  if (!expression.empty()) {
+    Append(expression + " {");
+  } else {
+    Append(newline_ ? "{" : " {");
+  }
+  return EndLine().Indent(2);
+}
+
+SourceWriter& SourceWriter::EndBlock() {
+  return Indent(-2).Append("}").EndLine();
+}
+
+SourceWriter& SourceWriter::BeginMethod(const Method& method, int modifiers,
+                                        const Javadoc* javadoc) {
   GenericNamespace* generic_namespace = PushGenericNamespace(modifiers);
   if (!method.constructor()) {
     generic_namespace->Visit(method.return_type());
@@ -116,8 +132,9 @@ SourceWriter& SourceWriter::BeginMethod(const Method& method, int modifiers) {
     generic_namespace->Visit(v.type());
   }
   EndLine();
-  WriteDoc(method.description(), method.return_description(),
-      &method.arguments());
+  if (javadoc != nullptr) {
+    WriteJavadoc(*javadoc);
+  }
   if (!method.annotations().empty()) {
     WriteAnnotations(method.annotations());
   }
@@ -130,11 +147,13 @@ SourceWriter& SourceWriter::BeginMethod(const Method& method, int modifiers) {
     AppendType(method.return_type()).Append(" ");
   }
   Append(method.name()).Append("(");
+  bool first = true;
   for (const Variable& v : method.arguments()) {
-    if (&v != &method.arguments().front()) {
+    if (!first) {
       Append(", ");
     }
     AppendType(v.type()).Append(v.variadic() ? "... " : " ").Append(v.name());
+    first = false;
   }
   return Append(")").BeginBlock();
 }
@@ -145,29 +164,36 @@ SourceWriter& SourceWriter::EndMethod() {
   return *this;
 }
 
-SourceWriter& SourceWriter::BeginType(const Type& type,
-    const std::list<Type>* dependencies, int modifiers) {
+SourceWriter& SourceWriter::BeginType(const Type& type, int modifiers,
+                                      const std::list<Type>* extra_dependencies,
+                                      const Javadoc* javadoc) {
   if (!type.package().empty()) {
     Append("package ").Append(type.package()).Append(";").EndLine();
   }
-  if (dependencies != nullptr && !dependencies->empty()) {
-    TypeImporter type_importer(type.package());
-    for (const Type& t : *dependencies) {
+  TypeImporter type_importer(type.package());
+  type_importer.Visit(type);
+  if (extra_dependencies != nullptr) {
+    for (const Type& t : *extra_dependencies) {
       type_importer.Visit(t);
     }
+  }
+  if (!type_importer.imports().empty()) {
     EndLine();
     for (const string& s : type_importer.imports()) {
       Append("import ").Append(s).Append(";").EndLine();
     }
   }
-  return BeginInnerType(type, modifiers);
+  return BeginInnerType(type, modifiers, javadoc);
 }
 
-SourceWriter& SourceWriter::BeginInnerType(const Type& type, int modifiers) {
+SourceWriter& SourceWriter::BeginInnerType(const Type& type, int modifiers,
+                                           const Javadoc* javadoc) {
   GenericNamespace* generic_namespace = PushGenericNamespace(modifiers);
   generic_namespace->Visit(type);
   EndLine();
-  WriteDoc(type.description());
+  if (javadoc != nullptr) {
+    WriteJavadoc(*javadoc);
+  }
   if (!type.annotations().empty()) {
     WriteAnnotations(type.annotations());
   }
@@ -200,14 +226,15 @@ SourceWriter& SourceWriter::EndType() {
   return *this;
 }
 
-SourceWriter& SourceWriter::WriteFields(const std::list<Variable>& fields,
-    int modifiers) {
-  EndLine();
-  for (const Variable& v : fields) {
-    WriteModifiers(modifiers);
-    AppendType(v.type()).Append(" ").Append(v.name()).Append(";");
-    EndLine();
+SourceWriter& SourceWriter::WriteField(const Variable& field, int modifiers,
+                                       const Javadoc* javadoc) {
+  // If present, write field javadoc only as one brief line
+  if (javadoc != nullptr && !javadoc->brief().empty()) {
+    Append("/** ").Append(javadoc->brief()).Append(" */").EndLine();
   }
+  WriteModifiers(modifiers);
+  AppendType(field.type()).Append(" ").Append(field.name()).Append(";");
+  EndLine();
   return *this;
 }
 
@@ -228,39 +255,33 @@ SourceWriter& SourceWriter::WriteModifiers(int modifiers) {
   return *this;
 }
 
-SourceWriter& SourceWriter::WriteDoc(const string& description,
-    const string& return_description, const std::list<Variable>* parameters) {
-  if (description.empty() && return_description.empty()
-      && (parameters == nullptr || parameters->empty())) {
-    return *this;  // no doc to write
-  }
+SourceWriter& SourceWriter::WriteJavadoc(const Javadoc& javadoc) {
+  Append("/**").Prefix(" * ").EndLine();
   bool do_line_break = false;
-  Append("/**").EndLine().Prefix(" * ");
-  if (!description.empty()) {
-    Write(description).EndLine();
+  if (!javadoc.brief().empty()) {
+    Write(javadoc.brief()).EndLine();
     do_line_break = true;
   }
-  if (parameters != nullptr && !parameters->empty()) {
+  if (!javadoc.details().empty()) {
     if (do_line_break) {
-      EndLine();
-      do_line_break = false;
-    }
-    for (const Variable& v : *parameters) {
-      Append("@param ").Append(v.name());
-      if (!v.description().empty()) {
-        Append(" ").Write(v.description());
-      }
-      EndLine();
+      Append("<p>").EndLine();
     }
+    Write(javadoc.details()).EndLine();
+    do_line_break = true;
   }
-  if (!return_description.empty()) {
+  if (!javadoc.tags().empty()) {
     if (do_line_break) {
       EndLine();
-      do_line_break = false;
     }
-    Append("@return ").Write(return_description).EndLine();
+    for (const auto& p : javadoc.tags()) {
+      Append("@" + p.first);
+      if (!p.second.empty()) {
+        Append(" ").Write(p.second);
+      }
+      EndLine();
+    }
   }
-  return Prefix("").Append(" **/").EndLine();
+  return Prefix("").Append(" */").EndLine();
 }
 
 SourceWriter& SourceWriter::WriteAnnotations(
@@ -278,14 +299,16 @@ SourceWriter& SourceWriter::WriteAnnotations(
 SourceWriter& SourceWriter::WriteGenerics(
     const std::list<const Type*>& generics) {
   Append("<");
+  bool first = true;
   for (const Type* pt : generics) {
-    if (pt != generics.front()) {
+    if (!first) {
       Append(", ");
     }
     Append(pt->name());
     if (!pt->supertypes().empty()) {
       Append(" extends ").AppendType(pt->supertypes().front());
     }
+    first = false;
   }
   return Append(">");
 }
@@ -311,21 +334,20 @@ void SourceWriter::PopGenericNamespace() {
 void SourceWriter::TypeVisitor::Visit(const Type& type) {
   DoVisit(type);
   for (const Type& t : type.parameters()) {
-    DoVisit(t);
+    Visit(t);
   }
   for (const Annotation& t : type.annotations()) {
     DoVisit(t);
   }
   for (const Type& t : type.supertypes()) {
-    DoVisit(t);
+    Visit(t);
   }
 }
 
 void SourceWriter::GenericNamespace::DoVisit(const Type& type) {
   // ignore non-generic parameters, wildcards and generics already declared
-  if (type.kind() == Type::GENERIC
-      && !type.IsWildcard()
-      && generic_names_.find(type.name()) == generic_names_.end()) {
+  if (type.kind() == Type::GENERIC && !type.wildcard() &&
+      generic_names_.find(type.name()) == generic_names_.end()) {
     declared_types_.push_back(&type);
     generic_names_.insert(type.name());
   }
@@ -333,7 +355,7 @@ void SourceWriter::GenericNamespace::DoVisit(const Type& type) {
 
 void SourceWriter::TypeImporter::DoVisit(const Type& type) {
   if (!type.package().empty() && type.package() != current_package_) {
-    imports_.insert(type.package() + '.' + type.name());
+    imports_.insert(type.canonical_name());
   }
 }
 
index f011acd..de0113b 100644 (file)
@@ -93,25 +93,22 @@ class SourceWriter {
   // This method appends a new opening brace to the current data and indent the
   // next lines according to Google Java Style Guide. The block can optionally
   // be preceded by an expression (e.g. Append("if(true)").BeginBlock();)
-  SourceWriter& BeginBlock() {
-    return Append(newline_ ? "{" : " {").EndLine().Indent(2);
-  }
+  SourceWriter& BeginBlock(const string& expression = "");
 
   // Ends the current block of source code.
   //
   // This method appends a new closing brace to the current data and outdent the
   // next lines back to the margin used before BeginBlock() was invoked.
-  SourceWriter& EndBlock() {
-    return Indent(-2).Append("}").EndLine();
-  }
+  SourceWriter& EndBlock();
 
   // Begins to write a method.
   //
   // This method outputs the signature of the Java method from the data passed
-  // in the 'method' parameter and starts a new block. Additionnal modifiers can
-  // also be passed in parameter to define the accesses and the scope of this
-  // method.
-  SourceWriter& BeginMethod(const Method& method, int modifiers = 0);
+  // in the 'method' parameter and starts a new block. Modifiers are also passed
+  // in parameter to define the access scope of this method and, optionally,
+  // a Javadoc.
+  SourceWriter& BeginMethod(const Method& method, int modifiers,
+                            const Javadoc* javadoc = nullptr);
 
   // Ends the current method.
   //
@@ -122,22 +119,24 @@ class SourceWriter {
   // Begins to write the main type of a source file.
   //
   // This method outputs the declaration of the Java type from the data passed
-  // in the 'type' parameter and starts a new block. Additionnal modifiers can
-  // also be passed in parameter to define the accesses and the scope of this
-  // type.
+  // in the 'type' parameter and starts a new block. Modifiers are also passed
+  // in parameter to define the access scope of this type and, optionally,
+  // a Javadoc.
   //
-  // If not null, all types found in the 'dependencies' list will be imported
-  // before declaring the new type.
-  SourceWriter& BeginType(const Type& clazz,
-      const std::list<Type>* dependencies, int modifiers = 0);
+  // If not null, all types found in the 'extra_dependencies' list will be
+  // imported before declaring the new type.
+  SourceWriter& BeginType(const Type& type, int modifiers,
+                          const std::list<Type>* extra_dependencies = nullptr,
+                          const Javadoc* javadoc = nullptr);
 
   // Begins to write a new inner type.
   //
   // This method outputs the declaration of the Java type from the data passed
-  // in the 'type' parameter and starts a new block. Additionnal modifiers can
-  // also be passed in parameter to define the accesses and the scope of this
-  // type.
-  SourceWriter& BeginInnerType(const Type& type, int modifiers = 0);
+  // in the 'type' parameter and starts a new block. Modifiers are also passed
+  // in parameter to define the accesses and the scope of this type and,
+  // optionally, a Javadoc.
+  SourceWriter& BeginInnerType(const Type& type, int modifiers,
+                               const Javadoc* javadoc = nullptr);
 
   // Ends the current type.
   //
@@ -145,13 +144,13 @@ class SourceWriter {
   // BeginType() or BeginInnerType() prior to this.
   SourceWriter& EndType();
 
-  // Writes a list of variables as fields of a type.
+  // Writes a variable as fields of a type.
   //
   // This method must be called within the definition of a type (see BeginType()
-  // or BeginInnerType()). Additional modifiers can also be passed in parameter
-  // to define the accesses and the scope of those fields.
-  SourceWriter& WriteFields(const std::list<Variable>& fields,
-      int modifiers = 0);
+  // or BeginInnerType()). Modifiers are also be passed in parameter to define
+  // the accesses and the scope of this field and, optionally, a Javadoc.
+  SourceWriter& WriteField(const Variable& field, int modifiers,
+                           const Javadoc* javadoc = nullptr);
 
  protected:
   virtual void DoAppend(const StringPiece& str) = 0;
@@ -207,9 +206,7 @@ class SourceWriter {
   std::stack<GenericNamespace*> generic_namespaces_;
 
   SourceWriter& WriteModifiers(int modifiers);
-  SourceWriter& WriteDoc(const string& description,
-    const string& return_description = "",
-    const std::list<Variable>* parameters = nullptr);
+  SourceWriter& WriteJavadoc(const Javadoc& javadoc);
   SourceWriter& WriteAnnotations(const std::list<Annotation>& annotations);
   SourceWriter& WriteGenerics(const std::list<const Type*>& generics);
   GenericNamespace* PushGenericNamespace(int modifiers);
index 4bce2fe..fb8fc64 100644 (file)
@@ -245,12 +245,17 @@ TEST(StreamTest, Types) {
   SourceBufferWriter writer;
   Type generic = Type::Generic("T").add_supertype(Type::Class("Number"));
 
-  writer.AppendType(Type::Int()).Append(", ")
-        .AppendType(Type::Class("String")).Append(", ")
-        .AppendType(generic).Append(", ")
-        .AppendType(Type::ListOf(generic)).Append(", ")
-        .AppendType(Type::ListOf(Type::IterableOf(generic))).Append(", ")
-        .AppendType(Type::ListOf(Type::Generic()));
+  writer.AppendType(Type::Int())
+      .Append(", ")
+      .AppendType(Type::Class("String"))
+      .Append(", ")
+      .AppendType(generic)
+      .Append(", ")
+      .AppendType(Type::ListOf(generic))
+      .Append(", ")
+      .AppendType(Type::ListOf(Type::IterableOf(generic)))
+      .Append(", ")
+      .AppendType(Type::ListOf(Type::Wildcard()));
 
   const char* expected =
       "int, String, T, List<T>, List<Iterable<T>>, List<?>";
@@ -282,7 +287,7 @@ TEST(WriteType, SimpleClass) {
   SourceBufferWriter writer;
   Type clazz = Type::Class("Test", "org.tensorflow");
 
-  writer.BeginType(clazz, nullptr, PUBLIC).EndType();
+  writer.BeginType(clazz, PUBLIC).EndType();
 
   const char* expected =
       "package org.tensorflow;\n\n"
@@ -300,7 +305,7 @@ TEST(WriteType, SimpleClassWithDependencies) {
   deps.push_back(Type::Class("SamePackageType", "org.tensorflow"));
   deps.push_back(Type::Class("NoPackageType"));
 
-  writer.BeginType(clazz, &deps, PUBLIC).EndType();
+  writer.BeginType(clazz, PUBLIC, &deps).EndType();
 
   const char* expected =
       "package org.tensorflow;\n\n"
@@ -313,20 +318,22 @@ TEST(WriteType, SimpleClassWithDependencies) {
 TEST(WriteType, AnnotatedAndDocumentedClass) {
   SourceBufferWriter writer;
   Type clazz = Type::Class("Test", "org.tensorflow");
-  clazz.description("This class has a\n<p>\nmultiline description.");
+  Javadoc clazz_doc = Javadoc::Create("Javadoc test")
+                          .details("This is a\nmultiline description.");
   clazz.add_annotation(Annotation::Create("Bean"));
   clazz.add_annotation(Annotation::Create("SuppressWarnings")
       .attributes("\"rawtypes\""));
 
-  writer.BeginType(clazz, nullptr, PUBLIC).EndType();
+  writer.BeginType(clazz, PUBLIC, nullptr, &clazz_doc).EndType();
 
   const char* expected =
       "package org.tensorflow;\n\n"
       "/**\n"
-      " * This class has a\n"
+      " * Javadoc test\n"
       " * <p>\n"
+      " * This is a\n"
       " * multiline description.\n"
-      " **/\n"
+      " */\n"
       "@Bean\n"
       "@SuppressWarnings(\"rawtypes\")\n"
       "public class Test {\n}\n";
@@ -339,7 +346,7 @@ TEST(WriteType, ParameterizedClass) {
   clazz.add_parameter(Type::Generic("T"));
   clazz.add_parameter(Type::Generic("U").add_supertype(Type::Class("Number")));
 
-  writer.BeginType(clazz, nullptr, PUBLIC).EndType();
+  writer.BeginType(clazz, PUBLIC).EndType();
 
   const char* expected =
       "package org.tensorflow;\n\n"
@@ -358,7 +365,7 @@ TEST(WriteType, ParameterizedClassAndSupertypes) {
   clazz.add_supertype(Type::Interface("Runnable"));
   clazz.add_supertype(Type::Class("SuperTest").add_parameter(type_t));
 
-  writer.BeginType(clazz, nullptr, PUBLIC).EndType();
+  writer.BeginType(clazz, PUBLIC).EndType();
 
   const char* expected =
       "package org.tensorflow;\n\n"
@@ -372,24 +379,23 @@ TEST(WriteType, ParameterizedClassFields) {
   Type clazz = Type::Class("Test", "org.tensorflow");
   Type type_t = Type::Generic("T").add_supertype(Type::Class("Number"));
   clazz.add_parameter(type_t);
-  std::list<Variable> static_fields;
-  static_fields.push_back(Variable::Create("field1", Type::Class("String")));
-  std::list<Variable> member_fields;
-  member_fields.push_back(Variable::Create("field2", Type::Class("String")));
-  member_fields.push_back(Variable::Create("field3", type_t));
+  Variable field1 = Variable::Create("field1", Type::Class("String"));
+  Variable field2 = Variable::Create("field2", Type::Class("String"));
+  Variable field3 = Variable::Create("field3", type_t);
+  Javadoc field3_doc = Javadoc::Create("This variable is documented");
 
-  writer.BeginType(clazz, nullptr, PUBLIC)
-          .WriteFields(static_fields, STATIC | PUBLIC | FINAL)
-          .WriteFields(member_fields, PRIVATE)
-        .EndType();
+  writer.BeginType(clazz, PUBLIC)
+      .WriteField(field1, STATIC | PUBLIC | FINAL)
+      .WriteField(field2, PRIVATE)
+      .WriteField(field3, PRIVATE, &field3_doc)
+      .EndType();
 
   const char* expected =
       "package org.tensorflow;\n\n"
       "public class Test<T extends Number> {\n"
-      "  \n"
       "  public static final String field1;\n"
-      "  \n"
       "  private String field2;\n"
+      "  /** This variable is documented */\n"
       "  private T field3;\n"
       "}\n";
   ASSERT_STREQ(expected, writer.str().data());
@@ -400,10 +406,10 @@ TEST(WriteType, SimpleInnerClass) {
   Type clazz = Type::Class("Test", "org.tensorflow");
   Type inner_class = Type::Class("InnerTest");
 
-  writer.BeginType(clazz, nullptr, PUBLIC)
-          .BeginInnerType(inner_class, PUBLIC)
-          .EndType()
-        .EndType();
+  writer.BeginType(clazz, PUBLIC)
+      .BeginInnerType(inner_class, PUBLIC)
+      .EndType()
+      .EndType();
 
   const char* expected =
       "package org.tensorflow;\n\n"
@@ -423,10 +429,10 @@ TEST(WriteType, StaticParameterizedInnerClass) {
   Type inner_class = Type::Class("InnerTest");
   inner_class.add_parameter(type_t);
 
-  writer.BeginType(clazz, nullptr, PUBLIC)
-          .BeginInnerType(inner_class, PUBLIC | STATIC)
-          .EndType()
-        .EndType();
+  writer.BeginType(clazz, PUBLIC)
+      .BeginInnerType(inner_class, PUBLIC | STATIC)
+      .EndType()
+      .EndType();
 
   const char* expected =
       "package org.tensorflow;\n\n"
@@ -443,9 +449,10 @@ TEST(WriteMethod, SimpleMethod) {
   Type clazz = Type::Class("Test", "org.tensorflow");
   Method method = Method::Create("doNothing", Type::Void());
 
-  writer.BeginType(clazz, nullptr, PUBLIC)
-          .BeginMethod(method, PUBLIC).EndMethod()
-        .EndType();
+  writer.BeginType(clazz, PUBLIC)
+      .BeginMethod(method, PUBLIC)
+      .EndMethod()
+      .EndType();
 
   const char* expected =
       "package org.tensorflow;\n\n"
@@ -461,24 +468,28 @@ TEST(WriteMethod, AnnotatedAndDocumentedMethod) {
   SourceBufferWriter writer;
   Type clazz = Type::Class("Test", "org.tensorflow");
   Method method = Method::Create("doNothing", Type::Void());
-  method.description("This method has a\n<p>\nmultiline description.");
+  Javadoc method_doc =
+      Javadoc::Create("Javadoc test")
+          .details("This method has a\nmultiline description.");
   method.add_annotation(Annotation::Create("Override"));
   method.add_annotation(Annotation::Create("SuppressWarnings")
       .attributes("\"rawtypes\""));
 
-  writer.BeginType(clazz, nullptr, PUBLIC)
-          .BeginMethod(method, PUBLIC).EndMethod()
-        .EndType();
+  writer.BeginType(clazz, PUBLIC)
+      .BeginMethod(method, PUBLIC, &method_doc)
+      .EndMethod()
+      .EndType();
 
   const char* expected =
       "package org.tensorflow;\n\n"
       "public class Test {\n"
       "  \n"
       "  /**\n"
-      "   * This method has a\n"
+      "   * Javadoc test\n"
       "   * <p>\n"
+      "   * This method has a\n"
       "   * multiline description.\n"
-      "   **/\n"
+      "   */\n"
       "  @Override\n"
       "  @SuppressWarnings(\"rawtypes\")\n"
       "  public void doNothing() {\n"
@@ -490,23 +501,27 @@ TEST(WriteMethod, AnnotatedAndDocumentedMethod) {
 TEST(WriteMethod, DocumentedMethodWithArguments) {
   SourceBufferWriter writer;
   Type clazz = Type::Class("Test", "org.tensorflow");
+  Variable reverse = Variable::Create("reverse", Type::Boolean());
   Method method = Method::Create("boolToInt", Type::Int());
-  method.description("Converts a boolean to an int");
-  method.return_description("int value for this boolean");
   method.add_argument(Variable::Create("b", Type::Boolean()));
-  Variable reverse = Variable::Create("reverse", Type::Boolean());
-  reverse.description("if true, value is reversed");
   method.add_argument(reverse);
-
-  writer.BeginType(clazz, nullptr, PUBLIC)
-          .BeginMethod(method, PUBLIC)
-            .Append("if (b && !reverse)")
-            .BeginBlock()
-              .Append("return 1;").EndLine()
-            .EndBlock()
-          .Append("return 0;").EndLine()
-          .EndMethod()
-        .EndType();
+  Javadoc method_doc =
+      Javadoc::Create("Converts a boolean to an int")
+          .details("This method will convert\na boolean to an int")
+          .add_param_tag(reverse.name(), "if true, value is reversed")
+          .add_tag("return", "int value for this boolean");
+
+  writer.BeginType(clazz, PUBLIC)
+      .BeginMethod(method, PUBLIC, &method_doc)
+      .Append("if (b && !reverse)")
+      .BeginBlock()
+      .Append("return 1;")
+      .EndLine()
+      .EndBlock()
+      .Append("return 0;")
+      .EndLine()
+      .EndMethod()
+      .EndType();
 
   const char* expected =
       "package org.tensorflow;\n\n"
@@ -514,11 +529,13 @@ TEST(WriteMethod, DocumentedMethodWithArguments) {
       "  \n"
       "  /**\n"
       "   * Converts a boolean to an int\n"
+      "   * <p>\n"
+      "   * This method will convert\n"
+      "   * a boolean to an int\n"
       "   * \n"
-      "   * @param b\n"
       "   * @param reverse if true, value is reversed\n"
       "   * @return int value for this boolean\n"
-      "   **/\n"
+      "   */\n"
       "  public int boolToInt(boolean b, boolean reverse) {\n"
       "    if (b && !reverse) {\n"
       "      return 1;\n"
@@ -536,11 +553,12 @@ TEST(WriteMethod, ParameterizedMethod) {
   clazz.add_parameter(type_t);
   Method method = Method::Create("doNothing", type_t);
 
-  writer.BeginType(clazz, nullptr, PUBLIC)
-          .BeginMethod(method, PUBLIC)
-            .Append("return null;").EndLine()
-          .EndMethod()
-        .EndType();
+  writer.BeginType(clazz, PUBLIC)
+      .BeginMethod(method, PUBLIC)
+      .Append("return null;")
+      .EndLine()
+      .EndMethod()
+      .EndType();
 
   const char* expected =
       "package org.tensorflow;\n\n"
@@ -560,11 +578,12 @@ TEST(WriteMethod, StaticParameterizedMethod) {
   clazz.add_parameter(type_t);
   Method method = Method::Create("doNothing", type_t);
 
-  writer.BeginType(clazz, nullptr, PUBLIC)
-          .BeginMethod(method, PUBLIC | STATIC)
-            .Append("return null;").EndLine()
-          .EndMethod()
-        .EndType();
+  writer.BeginType(clazz, PUBLIC)
+      .BeginMethod(method, PUBLIC | STATIC)
+      .Append("return null;")
+      .EndLine()
+      .EndMethod()
+      .EndType();
 
   const char* expected =
       "package org.tensorflow;\n\n"
index a6650fc..f4ff34e 100644 (file)
@@ -1,62 +1,62 @@
 # -*- Python -*-
 
-load("//tensorflow:tensorflow.bzl",
-     "tf_binary_additional_srcs",
-     "tf_cc_binary",
-     "tf_copts")
+load(
+    "//tensorflow:tensorflow.bzl",
+    "tf_binary_additional_srcs",
+)
 
-# Given a list of "ops_libs" (a list of files in the core/ops directory
-# without their .cc extensions), generate Java wrapper code for all operations
-# found in the ops files.
-# Then, combine all those source files into a single archive (.srcjar).
+# Generate Java wrapper classes for all registered core operations and package
+# them into a single source archive (.srcjar).
 #
 # For example:
-#  tf_java_op_gen_srcjar("gen_sources", "gen_tool", "my.package", [ "array_ops", "math_ops" ])
+#  tf_java_op_gen_srcjar("gen_sources", ":gen_tool", "my.package")
 #
-# will create a genrule named "gen_sources" that first generate source files:
-#     ops/src/main/java/my/package/array/*.java
-#     ops/src/main/java/my/package/math/*.java
+# will create a genrule named "gen_sources" that generates source files under
+#     ops/src/main/java/my/package/**/*.java
 #
-# and then archive those source files in:
+# and then archive those source files into
 #     ops/gen_sources.srcjar
 #
 def tf_java_op_gen_srcjar(name,
                           gen_tool,
-                          gen_base_package,
-                          ops_libs=[],
-                          ops_libs_pkg="//tensorflow/core",
+                          base_package,
+                          api_def_srcs=[],
                           out_dir="ops/",
                           out_src_dir="src/main/java/",
                           visibility=["//tensorflow/java:__pkg__"]):
 
-  gen_tools = []
   gen_cmds = ["rm -rf $(@D)"]  # Always start from fresh when generating source files
+  srcs = api_def_srcs[:]
 
-  # Construct an op generator binary for each ops library.
-  for ops_lib in ops_libs:
-    gen_lib = ops_lib[:ops_lib.rfind("_")]
-    out_gen_tool = out_dir + ops_lib + "_gen_tool"
+  if not api_def_srcs:
+    api_def_args_str = ","
+  else:
+    api_def_args = []
+    for api_def_src in api_def_srcs:
+      # Add directory of the first ApiDef source to args.
+      # We are assuming all ApiDefs in a single api_def_src are in the
+      # same directory.
+      api_def_args.append(
+          "$$(dirname $$(echo $(locations " + api_def_src +
+          ") | cut -d\" \" -f1))")
+    api_def_args_str = ",".join(api_def_args)
 
-    tf_cc_binary(
-        name=out_gen_tool,
-        copts=tf_copts(),
-        linkopts=["-lm"],
-        linkstatic=1,  # Faster to link this one-time-use binary dynamically
-        deps=[gen_tool, ops_libs_pkg + ":" + ops_lib + "_op_lib"])
-
-    gen_tools += [":" + out_gen_tool]
-    gen_cmds += ["$(location :" + out_gen_tool + ")" +
-                 " --output_dir=$(@D)/" + out_src_dir +
-                 " --lib_name=" + gen_lib +
-                 " --base_package=" + gen_base_package]
+  gen_cmds += ["$(location " + gen_tool + ")" +
+               " --output_dir=$(@D)/" + out_src_dir +
+               " --base_package=" + base_package +
+               " --api_dirs=" + api_def_args_str]
 
   # Generate a source archive containing generated code for these ops.
   gen_srcjar = out_dir + name + ".srcjar"
   gen_cmds += ["$(location @local_jdk//:jar) cMf $(location :" + gen_srcjar + ") -C $(@D) src"]
-  gen_tools += ["@local_jdk//:jar"] + ["@local_jdk//:jdk"]
-  gen_tools += tf_binary_additional_srcs()
+
   native.genrule(
       name=name,
+      srcs=srcs,
       outs=[gen_srcjar],
-      tools=gen_tools,
-      cmd="&&".join(gen_cmds))
+      tools=[
+          "@local_jdk//:jar",
+          "@local_jdk//:jdk",
+          gen_tool
+      ] + tf_binary_additional_srcs(),
+      cmd=" && ".join(gen_cmds))