Improve Javadoc and include first code review
authorkarl@kubx.ca <karl@kubx.ca>
Fri, 6 Apr 2018 12:56:54 +0000 (08:56 -0400)
committerkarl@kubx.ca <karl@kubx.ca>
Fri, 4 May 2018 02:40:42 +0000 (22:40 -0400)
14 files changed:
tensorflow/java/BUILD
tensorflow/java/src/gen/cc/java_defs.h
tensorflow/java/src/gen/cc/op_gen_main.cc
tensorflow/java/src/gen/cc/op_generator.cc
tensorflow/java/src/gen/cc/op_generator.h
tensorflow/java/src/gen/cc/op_parser.cc [deleted file]
tensorflow/java/src/gen/cc/op_parser.h [deleted file]
tensorflow/java/src/gen/cc/op_specs.cc [new file with mode: 0644]
tensorflow/java/src/gen/cc/op_specs.h [new file with mode: 0644]
tensorflow/java/src/gen/cc/source_writer.cc
tensorflow/java/src/gen/cc/source_writer.h
tensorflow/java/src/gen/cc/source_writer_test.cc
tensorflow/java/src/gen/gen_ops.bzl
tensorflow/java/src/gen/resources/license.java.snippet [moved from tensorflow/java/src/gen/resources/license.snippet.java with 100% similarity]

index 635a4e8..17566e1 100644 (file)
@@ -68,9 +68,13 @@ filegroup(
     ],
 )
 
+# Build the gen tool as a library, as it will be linked to a core/ops binary
+# files before making it an executable.
 tf_java_op_gen_srcjar(
     name = "java_op_gen_sources",
-    api_def_srcs = ["//tensorflow/core/api_def:base_api_def"],
+    api_def_srcs = [
+        "//tensorflow/core/api_def:base_api_def",
+    ],
     gen_base_package = "org.tensorflow.op",
     gen_tool = "java_op_gen_tool",
     ops_libs = [
@@ -95,30 +99,17 @@ tf_java_op_gen_srcjar(
     ],
 )
 
-# Build the gen tool as a library, as it will be linked to a core/ops binary
-# file before making it an executable. See tf_java_op_gen_srcjar().
-cc_library(
-    name = "java_op_gen_tool",
-    srcs = [
-        "src/gen/cc/op_gen_main.cc",
-    ],
-    copts = tf_copts(),
-    deps = [
-        ":java_op_gen_lib",
-    ],
-)
-
 cc_library(
     name = "java_op_gen_lib",
     srcs = [
         "src/gen/cc/op_generator.cc",
-        "src/gen/cc/op_parser.cc",
+        "src/gen/cc/op_specs.cc",
         "src/gen/cc/source_writer.cc",
     ],
     hdrs = [
         "src/gen/cc/java_defs.h",
         "src/gen/cc/op_generator.h",
-        "src/gen/cc/op_parser.h",
+        "src/gen/cc/op_specs.h",
         "src/gen/cc/source_writer.h",
     ],
     copts = tf_copts(),
index 2065477..81ac67e 100644 (file)
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -230,12 +230,12 @@ class Javadoc {
     return Javadoc(brief);
   }
   const string& brief() const { return brief_; }
-  const string& details() const { return description_; }
-  Javadoc& details(const string description) {
-    description_ = description;
+  const string& details() const { return details_; }
+  Javadoc& details(const string& details) {
+    details_ = details;
     return *this;
   }
-  const std::list<std::pair<string, string>> tags() const { return tags_; }
+  const std::list<std::pair<string, string>>& tags() const { return tags_; }
   Javadoc& add_tag(const string& tag, const string& text) {
     tags_.push_back(std::make_pair(tag, text));
     return *this;
@@ -246,7 +246,7 @@ class Javadoc {
 
  private:
   string brief_;
-  string description_;
+  string details_;
   std::list<std::pair<string, string>> tags_;
 
   explicit Javadoc(const string& brief) : brief_(brief) {}
index 0152000..458141b 100644 (file)
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -36,55 +36,41 @@ const char kUsageHeader[] =
     "Operation wrappers are generated under the path specified by the "
     "'--output_dir' argument. This path can be absolute or relative to the\n"
     "current working directory and will be created if it does not exists.\n\n"
-    "The '--lib_name' argument is used to classify the set of operations. If "
-    "the chosen name contains more than one word, it must be provided in \n"
-    "snake_case. This value is declined into other meaningful names, such as "
-    "the group and package of the generated operations. For example,\n"
-    "'--lib_name=my_lib' generates the operations under the "
-    "'org.tensorflow.op.mylib' package and add them to the 'myLib()' operator\n"
-    "group.\n\n"
-    "Note that the operator group assigned to the generated wrappers is just "
-    "an annotation tag at this stage. Operations will not be available "
-    "through\n"
-    "the 'org.tensorflow.op.Ops' API as a group until the generated classes "
-    "are compiled using an appropriate annotation processor.\n\n"
+    "Note that the operations will not be available through the "
+    "'org.tensorflow.op.Ops' API until the generated classes are compiled\n"
+    "using an appropriate annotation processor.\n\n"
     "The '--base_package' overrides the default parent package under which "
     "the generated subpackage and classes are to be located.\n\n"
-    "Finally, a list of directories of API proto definitions can be provided "
-    "to override default values found in the ops definitions, ordered by\n"
-    "priority (the last having precedence over the first).\n\n";
+    "Finally, the `--api_dirs` argument takes a list of comma-seperated "
+    "directories of API definitions can be provided to override default\n"
+    "values found in the ops definitions. Directories are ordered by priority "
+    "(the last having precedence over the first).\n\n";
 
 }  // namespace java
 }  // namespace tensorflow
 
 int main(int argc, char* argv[]) {
-  tensorflow::string lib_name;
   tensorflow::string output_dir;
   tensorflow::string base_package = "org.tensorflow.op";
+  tensorflow::string api_dirs_str;
   std::vector<tensorflow::Flag> flag_list = {
       tensorflow::Flag("output_dir", &output_dir,
           "Root directory into which output files are generated"),
-      tensorflow::Flag(
-          "lib_name", &lib_name,
-          "A name, in snake_case, used to classify this set of operations"),
-      tensorflow::Flag(
-          "base_package", &base_package,
-          "Package parent to the generated subpackage and classes")};
+      tensorflow::Flag("base_package", &base_package,
+          "Package parent to the generated subpackage and classes"),
+      tensorflow::Flag("api_dirs", &api_dirs_str,
+          "List of directories that contains the ops api definitions")};
   tensorflow::string usage = tensorflow::java::kUsageHeader;
   usage += tensorflow::Flags::Usage(argv[0], flag_list);
   bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
   tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
-  QCHECK(parsed_flags_ok && !lib_name.empty() && !output_dir.empty()) << usage;
-  std::vector<tensorflow::string> api_dirs;
-  if (argc > 1) {
-    api_dirs = tensorflow::str_util::Split(argv[1], ",",
-        tensorflow::str_util::SkipEmpty());
-  }
+  QCHECK(parsed_flags_ok && !output_dir.empty()) << usage;
+  std::vector<tensorflow::string> api_dirs = tensorflow::str_util::Split(
+      api_dirs_str, ",", tensorflow::str_util::SkipEmpty());
   tensorflow::java::OpGenerator generator(base_package, output_dir, api_dirs);
   tensorflow::OpList ops;
   tensorflow::OpRegistry::Global()->Export(false, &ops);
-  tensorflow::Status status = generator.Run(ops, lib_name);
-  TF_QCHECK_OK(status);
+  TF_CHECK_OK(generator.Run(ops));
 
   return 0;
 }
index c9b57f5..c32ad3b 100644 (file)
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -18,6 +18,7 @@ limitations under the License.
 #include <vector>
 #include <list>
 #include <memory>
+#include <set>
 
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/strings/str_util.h"
@@ -27,15 +28,15 @@ limitations under the License.
 #include "tensorflow/core/framework/op_gen_lib.h"
 #include "tensorflow/java/src/gen/cc/java_defs.h"
 #include "tensorflow/java/src/gen/cc/source_writer.h"
-#include "tensorflow/java/src/gen/cc/op_parser.h"
 #include "tensorflow/java/src/gen/cc/op_generator.h"
+#include "tensorflow/java/src/gen/cc/op_specs.h"
 
 namespace tensorflow {
 namespace java {
 namespace {
 
 const char* kLicenseSnippet =
-    "tensorflow/java/src/gen/resources/license.snippet.java";
+    "tensorflow/java/src/gen/resources/license.java.snippet";
 
 const std::map<string, Type> kPrimitiveAttrTypes = {
   { "Boolean", Type::Boolean() },
@@ -66,34 +67,34 @@ void CollectOpDependencies(const OpSpec& op, RenderMode mode,
   }
   // Don't pay attention to duplicate types in the dependency list, they will
   // be filtered out by the SourceWriter.
-  for (const OpSpec::Operand& input : op.inputs()) {
+  for (const ArgumentSpec& input : op.inputs()) {
     out->push_back(input.var().type());
     if (input.iterable()) {
       out->push_back(Type::Class("Operands", "org.tensorflow.op"));
     }
   }
-  for (const OpSpec::Operand& output : op.outputs()) {
+  for (const ArgumentSpec& output : op.outputs()) {
     out->push_back(output.var().type());
     if (output.iterable()) {
       out->push_back(Type::Class("Arrays", "java.util"));
     }
   }
-  for (const OpSpec::Operand& attribute : op.attributes()) {
+  for (const AttributeSpec& attribute : op.attributes()) {
     out->push_back(attribute.var().type());
     if (attribute.var().type().name() == "Class") {
       out->push_back(Type::Enum("DataType", "org.tensorflow"));
     }
   }
-  for (const OpSpec::Operand& option : op.options()) {
-    out->push_back(option.var().type());
+  for (const AttributeSpec& optional_attribute : op.optional_attributes()) {
+    out->push_back(optional_attribute.var().type());
   }
 }
 
-void WriteSetAttrDirective(const OpSpec::Operand& attr, bool optional,
+void WriteSetAttrDirective(const AttributeSpec& attr, bool optional,
     SourceWriter* writer) {
   string var = optional ? "opts." + attr.var().name() : attr.var().name();
   if (attr.iterable()) {
-    const Type& type = attr.data_type();
+    const Type& type = attr.type();
     std::map<string, Type>::const_iterator it =
       kPrimitiveAttrTypes.find(type.name());
     if (it != kPrimitiveAttrTypes.end()) {
@@ -107,11 +108,11 @@ void WriteSetAttrDirective(const OpSpec::Operand& attr, bool optional,
           .Append(array + "[i] = " + var + ".get(i);")
           .EndLine()
           .EndBlock()
-          .Append("opBuilder.setAttr(\"" + attr.graph_name() + "\", " + array)
+          .Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", " + array)
           .Append(");")
           .EndLine();
     } else {
-      writer->Append("opBuilder.setAttr(\"" + attr.graph_name() + "\", " + var)
+      writer->Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", " + var)
           .Append(".toArray(new ")
           .AppendType(type)
           .Append("[" + var + ".size()]));")
@@ -119,7 +120,7 @@ void WriteSetAttrDirective(const OpSpec::Operand& attr, bool optional,
     }
   } else {
     Type type = attr.var().type();
-    writer->Append("opBuilder.setAttr(\"" + attr.graph_name() + "\", ");
+    writer->Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", ");
     if (type.name() == "Class") {
       writer->Append("DataType.fromClass(" + attr.var().name() + "));");
     } else {
@@ -139,26 +140,26 @@ void RenderFactoryMethod(const OpSpec& op, const Type& op_class,
       Variable::Create("scope", Type::Class("Scope", "org.tensorflow.op"));
   factory.add_argument(scope);
   factory_doc.add_param_tag(scope.name(), "Current graph scope");
-  for (const OpSpec::Operand& input : op.inputs()) {
+  for (const ArgumentSpec& input : op.inputs()) {
     factory.add_argument(input.var());
     factory_doc.add_param_tag(input.var().name(), input.description());
   }
-  for (const OpSpec::Operand& attribute : op.attributes()) {
+  for (const AttributeSpec& attribute : op.attributes()) {
     factory.add_argument(attribute.var());
     factory_doc.add_param_tag(attribute.var().name(), attribute.description());
   }
-  if (!op.options().empty()) {
+  if (!op.optional_attributes().empty()) {
     factory.add_argument(Variable::Varargs("options", Type::Class("Options")));
     factory_doc.add_param_tag("options", "carries optional attributes values");
   }
   factory_doc.add_tag("return", "a new instance of " + op_class.name());
   writer->BeginMethod(factory, PUBLIC|STATIC, &factory_doc);
   writer->Append("OperationBuilder opBuilder = scope.graph().opBuilder(\""
-      + op.graph_name() + "\", scope.makeOpName(\""
+      + op.graph_op_name() + "\", scope.makeOpName(\""
       + op_class.name() + "\"));");
   writer->EndLine();
 
-  for (const OpSpec::Operand& input : op.inputs()) {
+  for (const ArgumentSpec& input : op.inputs()) {
     if (input.iterable()) {
       writer->Append("opBuilder.addInputList(Operands.asOutputs("
           + input.var().name() + "));");
@@ -169,15 +170,15 @@ void RenderFactoryMethod(const OpSpec& op, const Type& op_class,
       writer->EndLine();
     }
   }
-  for (const OpSpec::Operand& attribute : op.attributes()) {
+  for (const AttributeSpec& attribute : op.attributes()) {
     WriteSetAttrDirective(attribute, false, writer);
   }
-  if (!op.options().empty()) {
+  if (!op.optional_attributes().empty()) {
     writer->BeginBlock("if (options != null)")
         .BeginBlock("for (Options opts : options)");
-    for (const OpSpec::Operand& option : op.options()) {
-      writer->BeginBlock("if (opts." + option.var().name() + " != null)");
-      WriteSetAttrDirective(option, true, writer);
+    for (const AttributeSpec& attribute : op.optional_attributes()) {
+      writer->BeginBlock("if (opts." + attribute.var().name() + " != null)");
+      WriteSetAttrDirective(attribute, true, writer);
       writer->EndBlock();
     }
     writer->EndBlock().EndBlock();
@@ -195,8 +196,8 @@ void RenderConstructor(const OpSpec& op, const Type& op_class,
     .add_argument(
         Variable::Create("operation",
             Type::Class("Operation", "org.tensorflow")));
-  for (const OpSpec::Operand& output : op.outputs()) {
-    if (output.iterable() && !output.data_type().unknown()) {
+  for (const ArgumentSpec& output : op.outputs()) {
+    if (output.iterable() && !output.type().unknown()) {
       constructor.add_annotation(
           Annotation::Create("SuppressWarnings").attributes("\"unchecked\""));
       break;
@@ -208,15 +209,15 @@ void RenderConstructor(const OpSpec& op, const Type& op_class,
   if (op.outputs().size() > 0) {
     writer->Append("int outputIdx = 0;")
         .EndLine();
-    for (const OpSpec::Operand& output : op.outputs()) {
+    for (const ArgumentSpec& output : op.outputs()) {
       if (output.iterable()) {
         string var_length = output.var().name() + "Length";
         writer->Append("int " + var_length)
-            .Append(" = operation.outputListLength(\"" + output.graph_name()
+            .Append(" = operation.outputListLength(\"" + output.op_def_name()
                 + "\");")
             .EndLine()
             .Append(output.var().name() + " = Arrays.asList(");
-        if (!output.data_type().unknown()) {
+        if (!output.type().unknown()) {
           writer->Append("(")
               .AppendType(output.var().type().parameters().front())
               .Append("[])");
@@ -236,18 +237,19 @@ void RenderConstructor(const OpSpec& op, const Type& op_class,
 }
 
 void RenderGettersAndSetters(const OpSpec& op, SourceWriter* writer) {
-  for (const OpSpec::Operand& option : op.options()) {
-    Method setter = Method::Create(option.var().name(), Type::Class("Options"))
-        .add_argument(option.var());
+  for (const AttributeSpec& attribute : op.optional_attributes()) {
+    Method setter =
+        Method::Create(attribute.var().name(), Type::Class("Options"))
+            .add_argument(attribute.var());
     Javadoc setter_doc = Javadoc::Create()
-        .add_param_tag(option.var().name(), option.description());
+        .add_param_tag(attribute.var().name(), attribute.description());
     writer->BeginMethod(setter, PUBLIC|STATIC, &setter_doc)
-        .Append("return new Options()." + option.var().name() + "("
-            + option.var().name() + ");")
+        .Append("return new Options()." + attribute.var().name() + "("
+            + attribute.var().name() + ");")
         .EndLine()
         .EndMethod();
   }
-  for (const OpSpec::Operand& output : op.outputs()) {
+  for (const ArgumentSpec& output : op.outputs()) {
     Method getter = Method::Create(output.var().name(), output.var().type());
     Javadoc getter_doc = Javadoc::Create(output.description());
     writer->BeginMethod(getter, PUBLIC, &getter_doc)
@@ -259,12 +261,12 @@ void RenderGettersAndSetters(const OpSpec& op, SourceWriter* writer) {
 
 void RenderInterfaceImpl(const OpSpec& op, RenderMode mode,
     SourceWriter* writer) {
-  OpSpec::Operand output = op.outputs().front();
+  ArgumentSpec output = op.outputs().front();
 
   if (mode == SINGLE_OUTPUT) {
-    bool cast2obj = output.data_type().unknown();
+    bool cast2obj = output.type().unknown();
     Type return_type = Type::Class("Output", "org.tensorflow")
-        .add_parameter(cast2obj ? Type::Class("Object") : output.data_type());
+        .add_parameter(cast2obj ? Type::Class("Object") : output.type());
     Method as_output = Method::Create("asOutput", return_type)
         .add_annotation(Annotation::Create("Override"));
     if (cast2obj) {
@@ -283,10 +285,10 @@ void RenderInterfaceImpl(const OpSpec& op, RenderMode mode,
 
   } else if (mode == SINGLE_LIST_OUTPUT) {
     Type operand = Type::Interface("Operand", "org.tensorflow");
-    if (output.data_type().unknown()) {
+    if (output.type().unknown()) {
       operand.add_parameter(Type::Class("Object"));
     } else {
-      operand.add_parameter(output.data_type());
+      operand.add_parameter(output.type());
     }
     Type return_type = Type::Interface("Iterator", "java.util")
         .add_parameter(operand);
@@ -308,57 +310,119 @@ void RenderOptionsClass(const OpSpec& op, SourceWriter* writer) {
   Javadoc options_doc = Javadoc::Create(
       "Class holding optional attributes of this operation");
   writer->BeginInnerType(options_class, PUBLIC | STATIC, &options_doc);
-  for (const OpSpec::Operand& option : op.options()) {
-    Method setter = Method::Create(option.var().name(), options_class)
-        .add_argument(option.var());
+  for (const AttributeSpec& attribute : op.optional_attributes()) {
+    Method setter = Method::Create(attribute.var().name(), options_class)
+        .add_argument(attribute.var());
     Javadoc setter_doc = Javadoc::Create()
-        .add_param_tag(option.var().name(), option.description());
+        .add_param_tag(attribute.var().name(), attribute.description());
     writer->BeginMethod(setter, PUBLIC, &setter_doc)
-        .Append("this." + option.var().name() + " = " + option.var().name()
-            + ";")
+        .Append("this." + attribute.var().name() + " = "
+            + attribute.var().name() + ";")
         .EndLine()
         .Append("return this;")
         .EndLine()
         .EndMethod();
   }
   writer->EndLine();
-  for (const OpSpec::Operand& option : op.options()) {
-    writer->WriteField(option.var(), PRIVATE);
+  for (const AttributeSpec& optional_attribute : op.optional_attributes()) {
+    writer->WriteField(optional_attribute.var(), PRIVATE);
   }
   Method constructor = Method::ConstructorFor(options_class);
   writer->BeginMethod(constructor, PRIVATE).EndMethod();
   writer->EndType();
 }
 
-void RenderEndpoint(const OpSpec& op, const OpSpec::Endpoint& endpoint,
-    SourceWriter* writer) {
+inline Type ClassOf(const EndpointSpec& endpoint, const string& base_package) {
+  return Type::Class(endpoint.name(),
+      base_package + "." + str_util::Lowercase(endpoint.package()));
+}
+
+void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint,
+    const string& base_package, const string& output_dir, Env* env) {
+  Type op_class(ClassOf(endpoint, base_package)
+      .add_supertype(Type::Class("PrimitiveOp", "org.tensorflow.op")));
+  Javadoc op_javadoc(endpoint.javadoc());
+
+  // implement Operand (or Iterable<Operand>) if the op has only one output
   RenderMode mode = DEFAULT;
   if (op.outputs().size() == 1) {
-    mode = op.outputs().front().iterable() ? SINGLE_LIST_OUTPUT : SINGLE_OUTPUT;
+    const ArgumentSpec& output = op.outputs().front();
+    Type operand_type(output.type().unknown() ?
+        Type::Class("Object") : output.type());
+    Type operand_inf(Type::Interface("Operand", "org.tensorflow")
+        .add_parameter(operand_type));
+    if (output.iterable()) {
+      mode = SINGLE_LIST_OUTPUT;
+      op_class.add_supertype(Type::IterableOf(operand_inf));
+    } else {
+      mode = SINGLE_OUTPUT;
+      op_class.add_supertype(operand_inf);
+    }
+  }
+  // declare all outputs generics at the op class level
+  std::set<string> generics;
+  for (const ArgumentSpec& output : op.outputs()) {
+    if (output.type().kind() == Type::GENERIC && !output.type().unknown()
+        && generics.find(output.type().name()) == generics.end()) {
+      op_class.add_parameter(output.type());
+      op_javadoc.add_param_tag("<" + output.type().name() + ">",
+          "data type of output {@code " + output.var().name() + "}");
+      generics.insert(output.type().name());
+    }
+  }
+  // handle endpoint deprecation
+  if (endpoint.deprecated()) {
+    op_class.add_annotation(Annotation::Create("Deprecated"));
+    string explanation;
+    if (!op.endpoints().front().deprecated()) {
+      explanation = "use {@link " +
+          ClassOf(op.endpoints().front(), base_package).full_name()
+          + "} instead";
+    } else {
+      explanation = op.deprecation_explanation();
+    }
+    op_javadoc.add_tag("deprecated", explanation);
   }
+  // expose the op in the Ops Graph API only if it is visible
+  if (!op.hidden()) {
+    op_class.add_annotation(
+        Annotation::Create("Operator", "org.tensorflow.op.annotation")
+          .attributes("group = \"" + endpoint.package() + "\""));
+  }
+  // create op class file
+  string op_dir = io::JoinPath(output_dir,
+      str_util::StringReplace(op_class.package(), ".", "/", true));
+  if (!env->FileExists(op_dir).ok()) {
+    TF_CHECK_OK(Env::Default()->RecursivelyCreateDir(op_dir));
+  }
+  std::unique_ptr<tensorflow::WritableFile> op_file;
+  TF_CHECK_OK(env->NewWritableFile(
+      io::JoinPath(op_dir, op_class.name() + ".java"), &op_file));
+
+  // render endpoint source code
+  SourceFileWriter writer(op_file.get());
   std::list<Type> dependencies;
   CollectOpDependencies(op, mode, &dependencies);
-  const Type& op_class = endpoint.type();
-  writer->WriteFromFile(kLicenseSnippet)
+  writer.WriteFromFile(kLicenseSnippet)
       .EndLine()
       .Append("// This file is machine generated, DO NOT EDIT!")
       .EndLine()
       .EndLine()
-      .BeginType(op_class, PUBLIC|FINAL, &dependencies, &endpoint.javadoc());
-  if (!op.options().empty()) {
-    RenderOptionsClass(op, writer);
+      .BeginType(op_class, PUBLIC|FINAL, &dependencies, &op_javadoc);
+  if (!op.optional_attributes().empty()) {
+    RenderOptionsClass(op, &writer);
   }
-  RenderFactoryMethod(op, op_class, writer);
-  RenderGettersAndSetters(op, writer);
+  RenderFactoryMethod(op, op_class, &writer);
+  RenderGettersAndSetters(op, &writer);
   if (mode != DEFAULT) {
-    RenderInterfaceImpl(op, mode, writer);
+    RenderInterfaceImpl(op, mode, &writer);
   }
-  writer->EndLine();
-  for (const OpSpec::Operand& output : op.outputs()) {
-    writer->WriteField(output.var(), PRIVATE);
+  writer.EndLine();
+  for (const ArgumentSpec& output : op.outputs()) {
+    writer.WriteField(output.var(), PRIVATE);
   }
-  RenderConstructor(op, op_class, writer);
-  writer->EndType();
+  RenderConstructor(op, op_class, &writer);
+  writer.EndType();
 }
 
 }  // namespace
@@ -369,8 +433,7 @@ OpGenerator::OpGenerator(const string& base_package, const string& output_dir,
     env_(env) {
 }
 
-Status OpGenerator::Run(const OpList& op_list, const string& lib_name) {
-  LOG(INFO) << "Generating Java wrappers for '" << lib_name << "' operations";
+Status OpGenerator::Run(const OpList& op_list) {
   ApiDefMap api_map(op_list);
   if (!api_dirs_.empty()) {
     // Only load api files that correspond to the requested "op_list"
@@ -388,37 +451,14 @@ Status OpGenerator::Run(const OpList& op_list, const string& lib_name) {
   for (const auto& op_def : op_list.op()) {
     const ApiDef* api_def = api_map.GetApiDef(op_def.name());
     if (api_def->visibility() != ApiDef::SKIP) {
-      Status status = GenerateOp(op_def, *api_def, lib_name);
-      if (status != Status::OK()) {
-        LOG(ERROR) << "Fail to generate Java wrapper for operation \""
-            << op_def.name() << "\"";
+      OpSpec op(OpSpec::Create(op_def, *api_def));
+      for (const EndpointSpec& endpoint : op.endpoints()) {
+        GenerateOp(op, endpoint, base_package_, output_dir_, env_);
       }
     }
   }
   return Status::OK();
 }
 
-Status OpGenerator::GenerateOp(const OpDef& op_def, const ApiDef& api_def,
-    const string& lib_name) {
-  std::unique_ptr<OpSpec> op;
-  OpParser op_parser(op_def, api_def, lib_name, base_package_);
-  op_parser.Parse(&op);
-  for (const OpSpec::Endpoint& endpoint : op->endpoints()) {
-    string package_path = io::JoinPath(output_dir_,
-        str_util::StringReplace(endpoint.type().package(), ".", "/", true));
-    if (!env_->FileExists(package_path).ok()) {
-      TF_CHECK_OK(Env::Default()->RecursivelyCreateDir(package_path));
-    }
-    string file_path =
-        io::JoinPath(package_path, endpoint.type().name() + ".java");
-    std::unique_ptr<tensorflow::WritableFile> file;
-    TF_CHECK_OK(env_->NewWritableFile(file_path, &file));
-
-    SourceFileWriter writer(file.get());
-    RenderEndpoint(*op, endpoint, &writer);
-  }
-  return Status::OK();
-}
-
 }  // namespace java
 }  // namespace tensorflow
index 19d8db9..06b08e8 100644 (file)
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -23,36 +23,33 @@ limitations under the License.
 #include "tensorflow/core/framework/api_def.pb.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/java/src/gen/cc/op_specs.h"
 
 namespace tensorflow {
 namespace java {
 
 // A generator of Java operation wrappers.
 //
-// Such generator is normally ran only once per executable, outputting
-// wrappers for the all registered operations it has been compiled with.
-// Nonetheless, it is designed to support multiple runs, giving a different
-// list of operations on each cycle.
+// This generator takes a list of ops definitions in input and outputs
+// a Java Op wrapper for each of them in the provided directory. The same
+// generator instance can be invoked multiple times with a different list of
+// ops definitions.
 class OpGenerator {
  public:
   OpGenerator(const string& base_package, const string& output_dir,
       const std::vector<string>& api_dirs, Env* env = Env::Default());
-  virtual ~OpGenerator() = default;
 
   // Generates wrappers for the given list of 'ops'.
   //
   // Output files are generated in <output_dir>/<base_package>/<lib_package>,
-  // where 'lib_package' is derived from 'lib_name'.
-  Status Run(const OpList& op_list, const string& lib_name);
+  // where 'lib_package' is derived from ops endpoints.
+  Status Run(const OpList& op_list);
 
  private:
-  string base_package_;
-  string output_dir_;
-  std::vector<string> api_dirs_;
+  const string base_package_;
+  const string output_dir_;
+  const std::vector<string> api_dirs_;
   Env* env_;
-
-  Status GenerateOp(const OpDef& op_def, const ApiDef& api_def,
-    const string& lib_name);
 };
 
 }  // namespace java
diff --git a/tensorflow/java/src/gen/cc/op_parser.cc b/tensorflow/java/src/gen/cc/op_parser.cc
deleted file mode 100644 (file)
index 0541e34..0000000
+++ /dev/null
@@ -1,417 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <map>
-#include <vector>
-#include <utility>
-#include <string>
-
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/java/src/gen/cc/op_parser.h"
-
-namespace tensorflow {
-namespace java {
-namespace {
-
-string SnakeToCamelCase(const string& str, bool upper = false) {
-  string result;
-  bool cap = upper;
-  for (string::const_iterator it = str.begin(); it != str.end(); ++it) {
-    const char c = *it;
-    if (c == '_') {
-      cap = true;
-    } else if (cap) {
-      result += toupper(c);
-      cap = false;
-    } else {
-      result += c;
-    }
-  }
-  return result;
-}
-
-bool IsRealNumber(DataType type) {
-  for (DataType dt : RealNumberTypes()) {
-    if (type == dt) {
-      return true;
-    }
-  }
-  return false;
-}
-
-bool IsRealNumbers(const AttrValue& values) {
-  if (values.has_list()) {
-    for (int i = 0; i < values.list().type_size(); ++i) {
-      if (!IsRealNumber(values.list().type(i))) {
-        return false;
-      }
-    }
-    return true;
-  }
-  return IsRealNumber(values.type());
-}
-
-string ParseDocumentation(const string& text) {
-  std::stringstream javadoc_text;
-  string::const_iterator c_iter = text.cbegin();
-  bool code = false;
-  bool emphasis = false;
-  bool list = false;
-  while (c_iter != text.cend()) {
-    char c = *c_iter++;
-    int count = 1;
-    switch (c) {
-    case '\n':
-      if (!code) {
-        // consumes all subsequent newlines, if there are more than one,
-        // then there are two choices:
-        // - if the next line starts with an asterisk, we are enumerating
-        //   a list of items
-        // - otherwise, we are starting a new paragraph
-        for (; c_iter != text.cend() && *c_iter == '\n'; ++count, ++c_iter) {}
-        if (c_iter != text.cend()) {
-          if (count > 1) {
-            if (*c_iter != '*' && list) {
-              javadoc_text << "</li>\n</ul>\n";
-              list = false;
-            } else if (*c_iter == '*' && !list) {
-              javadoc_text << "\n<ul>\n<li>";
-              list = true;
-              c_iter++;
-            } else {
-              javadoc_text << "\n<p>\n";
-            }
-          } else if (list && *c_iter == '*') {
-            javadoc_text << "</li>\n<li>";
-            c_iter++;
-          } else {
-            javadoc_text << '\n';
-          }
-        }
-      }
-      break;
-    case '`':
-      // consumes all subsequent backquotes, those are use enclose code.
-      // if there are more than 3, we are dealing with a pre-formatted block,
-      // otherwise it is a single-line code snippet
-      for (; c_iter != text.cend() && *c_iter == '`'; ++count, ++c_iter) {}
-      if (count >= 3) {
-        javadoc_text << (code ? "\n}</pre>" : "<pre>{@code\n");
-      } else {
-        javadoc_text << (code ? "}" : "{@code ");
-      }
-      code = !code;
-      break;
-    case '*':
-      if (!code) {
-        // consumes all subsequent asterisks, if there are more than one, then
-        // we put the text in bold, otherwise in italic
-        for (; c_iter != text.cend() && *c_iter == '*'; ++count, ++c_iter) {}
-        if (count > 1) {
-          javadoc_text << (emphasis ? "</b>" : "<b>");
-        } else {
-          javadoc_text << (emphasis ? "</i>" : "<i>");
-        }
-        emphasis = !emphasis;
-      } else {
-        javadoc_text << '*';
-      }
-      break;
-    default:
-      javadoc_text << c;
-      break;
-    }
-  }
-  return javadoc_text.str();
-}
-
-}  // namespace
-
-OpParser::OpParser(const OpDef& op_def, const ApiDef& api_def,
-    const string& lib_name, const string& base_package)
-  : op_def_(op_def), op_api_(api_def), lib_name_(lib_name),
-    base_package_(base_package) {
-}
-
-void OpParser::Parse(std::unique_ptr<OpSpec>* op_ptr) {
-  visited_attrs_.clear();
-  next_generic_ = 'T';
-  op_ptr->reset(new OpSpec(op_api_.graph_op_name()));
-  for (const string& next_input_name : op_api_.arg_order()) {
-    for (int i = 0; i < op_def_.input_arg().size(); ++i) {
-      if (op_def_.input_arg(i).name() == next_input_name) {
-        ParseInput(op_def_.input_arg(i), op_api_.in_arg(i), op_ptr->get());
-        break;
-      }
-    }
-  }
-  for (int i = 0; i < op_def_.attr().size(); ++i) {
-    ParseAttribute(op_def_.attr(i), op_api_.attr(i), op_ptr->get());
-  }
-  for (int i = 0; i < op_def_.output_arg().size(); ++i) {
-    ParseOutput(op_def_.output_arg(i), op_api_.out_arg(i), op_ptr->get());
-  }
-  BuildEndpoints(op_ptr->get());
-}
-
-void OpParser::BuildEndpoints(OpSpec* op) {
-  Javadoc op_doc = Javadoc::Create(ParseDocumentation(op_api_.summary()))
-    .details(ParseDocumentation(op_api_.description()));
-  std::vector<Type> op_supertypes;
-  op_supertypes.push_back(Type::Class("PrimitiveOp", "org.tensorflow.op"));
-  std::map<string, const Type*> op_generics;
-  for (const OpSpec::Operand& output : op->outputs()) {
-    // declare generic output parameters at the Op class level
-    const Type& data_type = output.data_type();
-    if (data_type.kind() == Type::GENERIC && !data_type.unknown()
-        && op_generics.find(data_type.name()) == op_generics.end()) {
-      op_generics.insert(std::make_pair(data_type.name(), &data_type));
-      op_doc.add_param_tag("<" + data_type.name() + ">",
-          "data type of output '" + output.var().name() + "'");
-    }
-    // implement the Op as an (iteration of) Operand if it has only one output
-    if (op->outputs().size() == 1) {
-      Type operand_inf(Type::Interface("Operand", "org.tensorflow"));
-      operand_inf.add_parameter(data_type.unknown() ?
-          Type::Class("Object") : data_type);
-      op_supertypes.push_back(output.iterable() ?
-          Type::IterableOf(operand_inf) : operand_inf);
-    }
-  }
-  for (const auto& endpoint_def : op_api_.endpoint()) {
-    std::vector<string> name_tokens = str_util::Split(endpoint_def.name(), ".");
-    // if the endpoint specifies a package, use it, otherwise derive it from the
-    // op library name.
-    string name;
-    string package;
-    if (name_tokens.size() > 1) {
-      package = str_util::Lowercase(name_tokens.at(0));
-      name = name_tokens.at(1);
-    } else {
-      package = str_util::StringReplace(lib_name_, "_", "", true);
-      name = name_tokens.at(0);
-    }
-    Type endpoint(Type::Class(name, base_package_ + "." + package));
-    Javadoc endpoint_doc(op_doc);
-    for (const auto& parameter : op_generics) {
-      endpoint.add_parameter(*parameter.second);
-    }
-    for (const Type& supertype : op_supertypes) {
-      endpoint.add_supertype(supertype);
-    }
-    if (endpoint_def.deprecation_version() > 0) {
-      string explanation;
-      if (op_api_.endpoint(0).deprecation_version() == 0) {
-        explanation = ", use {@link "
-            + op->endpoints().at(0).type().full_name()
-            + "} instead";
-      } else {
-        explanation = op_def_.deprecation().explanation();
-      }
-      endpoint_doc.add_tag("deprecated", explanation);
-      endpoint.add_annotation(Annotation::Create("Deprecated"));
-    }
-    // only visible ops should be annotated for exposure in the Ops Graph API
-    if (op_api_.visibility() != ApiDef::HIDDEN) {
-      string group_name = SnakeToCamelCase(lib_name_);
-      endpoint.add_annotation(
-          Annotation::Create("Operator", "org.tensorflow.op.annotation")
-            .attributes("group = \"" + group_name + "\""));
-    }
-    op->add_endpoint(endpoint, endpoint_doc);
-  }
-}
-
-void OpParser::ParseInput(const OpDef_ArgDef& input_def,
-    const ApiDef::Arg& input_api, OpSpec* op) {
-  bool iterable = false;
-  Type data_type = DataTypeOf(input_def, &iterable);
-  Type type = Type::Interface("Operand", "org.tensorflow")
-    .add_parameter(data_type);
-  if (iterable) {
-    type = Type::IterableOf(type);
-  }
-  op->add_input(OpSpec::Operand(input_api.name(),
-      Variable::Create(SnakeToCamelCase(input_api.rename_to()), type),
-      data_type,
-      ParseDocumentation(input_api.description()),
-      iterable));
-}
-
-void OpParser::ParseOutput(const OpDef_ArgDef& output_def,
-    const ApiDef::Arg& output_api, OpSpec* op) {
-  bool iterable = false;
-  Type data_type = DataTypeOf(output_def, &iterable);
-  Type type = Type::Class("Output", "org.tensorflow")
-    .add_parameter(data_type);
-  if (iterable) {
-    type = Type::ListOf(type);
-  }
-  op->add_output(OpSpec::Operand(output_api.name(),
-      Variable::Create(SnakeToCamelCase(output_api.rename_to()), type),
-      data_type,
-      ParseDocumentation(output_api.description()),
-      iterable));
-}
-
-void OpParser::ParseAttribute(const OpDef_AttrDef& attr_def,
-    const ApiDef::Attr& attr_api, OpSpec* op) {
-  // do not parse attributes already visited, they have probably been inferred
-  // before as an input argument type
-  if (visited_attrs_.find(attr_def.name()) != visited_attrs_.cend()) {
-    return;
-  }
-  bool iterable = false;
-  Type data_type = DataTypeOf(attr_def, &iterable);
-  // generic attributes should be passed as an explicit type
-  bool explicit_type = data_type.kind() == Type::GENERIC && !iterable;
-  Type type = explicit_type ?
-      Type::Class("Class").add_parameter(data_type) : data_type;
-  if (iterable) {
-    type = Type::ListOf(data_type);
-  }
-  OpSpec::Operand attr(attr_api.name(),
-      Variable::Create(SnakeToCamelCase(attr_api.rename_to()), type),
-      data_type,
-      ParseDocumentation(attr_api.description()),
-      iterable);
-  // attributes with a default value are optional
-  if (attr_api.has_default_value() && !explicit_type) {
-    op->add_option(attr);
-  } else {
-    op->add_attribute(attr);
-  }
-  visited_attrs_.insert(std::make_pair(attr_api.name(), data_type));
-}
-
-Type OpParser::DataTypeOf(const OpDef_ArgDef& arg, bool* iterable_out) {
-  if (!arg.number_attr().empty()) {
-    visited_attrs_.insert(std::make_pair(arg.number_attr(), Type::Int()));
-    *iterable_out = true;
-  }
-  if (arg.type() != DataType::DT_INVALID) {
-    // resolve type from DataType
-    switch (arg.type()) {
-      case DataType::DT_BOOL:
-        return Type::Class("Boolean");
-
-      case DataType::DT_STRING:
-        return Type::Class("String");
-
-      case DataType::DT_FLOAT:
-        return Type::Class("Float");
-
-      case DataType::DT_DOUBLE:
-        return Type::Class("Double");
-
-      case DataType::DT_UINT8:
-        return Type::Class("UInt8", "org.tensorflow.types");
-
-      case DataType::DT_INT32:
-        return Type::Class("Integer");
-
-      case DataType::DT_INT64:
-        return Type::Class("Long");
-
-      case DataType::DT_RESOURCE:
-        // TODO(karllessard) create a Resource utility class that could be
-        // used to store a resource and its type (passed in a second argument).
-        // For now, we need to force a wildcard and we will unfortunately lose
-        // track of the resource type.
-        return Type::Wildcard();
-
-      default:
-        break;
-    }
-  } else {
-    // resolve type from type attribute
-    string attr_name = arg.type_attr();
-    if (attr_name.empty()) {
-      attr_name = arg.type_list_attr();
-      if (!attr_name.empty()) {
-        *iterable_out = true;
-        Type type = Type::Wildcard();
-        visited_attrs_.insert(std::make_pair(attr_name, type));
-        return type;
-      }
-    }
-    for (const auto& attr : op_def_.attr()) {
-      if (attr.name() == attr_name) {
-        Type type = DataTypeOf(attr, iterable_out);
-        visited_attrs_.insert(std::make_pair(attr_name, type));
-        return type;
-      }
-    }
-  }
-  LOG(WARNING) << "Data type for arg \"" << arg.name() << "\" is unknown";
-  return Type::Wildcard();
-}
-
-Type OpParser::DataTypeOf(const OpDef_AttrDef& attr, bool* iterable_out) {
-  std::map<string, Type>::const_iterator it = visited_attrs_.find(attr.name());
-  if (it != visited_attrs_.cend()) {
-    return it->second;
-  }
-  string attr_type = attr.type();
-  if (attr.type().compare(0, 5, "list(") == 0) {
-    attr_type = attr_type.substr(5, attr.type().find_last_of(')') - 5);
-    *iterable_out = true;
-  }
-  if (attr_type == "type") {
-    if (*iterable_out) {
-      return Type::Enum("DataType", "org.tensorflow");
-    }
-    return GetNextGenericTensorType(attr.allowed_values());
-  }
-  if (attr_type == "string") {
-    return Type::Class("String");
-  }
-  if (attr_type == "int") {
-    return Type::Class("Integer");
-  }
-  if (attr_type == "float") {
-    return Type::Class("Float");
-  }
-  if (attr_type == "bool") {
-    return Type::Class("Boolean");
-  }
-  if (attr_type == "shape") {
-    return Type::Class("Shape", "org.tensorflow");
-  }
-  if (attr_type == "tensor") {
-    return Type::Class("Tensor", "org.tensorflow")
-      .add_parameter(Type::Wildcard());
-  }
-  LOG(WARNING) << "Data type for attribute \"" << attr_type << "\" is unknown";
-  return *iterable_out ? Type::Wildcard() : Type::Class("Object");
-}
-
-Type OpParser::GetNextGenericTensorType(const AttrValue& allowed_values)  {
-  Type generic = Type::Generic(string(1, next_generic_));
-  next_generic_ = (next_generic_ == 'Z') ? 'A' : next_generic_ + 1;
-
-  // when only real numbers are allowed, enforce that restriction in the Java by
-  // extending the generic from java.lang.Number
-  if (IsRealNumbers(allowed_values)) {
-    generic.add_supertype(Type::Class("Number"));
-  }
-  return generic;
-}
-
-}  // namespace java
-}  // namespace tensorflow
diff --git a/tensorflow/java/src/gen/cc/op_parser.h b/tensorflow/java/src/gen/cc/op_parser.h
deleted file mode 100644 (file)
index 4285512..0000000
+++ /dev/null
@@ -1,137 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_JAVA_SRC_GEN_CC_OP_PARSER_H_
-#define TENSORFLOW_JAVA_SRC_GEN_CC_OP_PARSER_H_
-
-#include <map>
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "tensorflow/core/framework/op_def.pb.h"
-#include "tensorflow/core/framework/api_def.pb.h"
-#include "tensorflow/java/src/gen/cc/java_defs.h"
-
-namespace tensorflow {
-namespace java {
-
-// Specification of a TensorFlow operation to generate.
-//
-// This is the result of an operation definition parsing, see OpParser::Parse().
-class OpSpec {
- public:
-  class Endpoint {
-   public:
-    Endpoint(const Type& type, const Javadoc& javadoc)
-      : type_(type), javadoc_(javadoc) {}
-    const Type& type() const { return type_; }
-    const Javadoc& javadoc() const { return javadoc_; }
-
-   private:
-    Type type_;
-    Javadoc javadoc_;
-  };
-
-  class Operand {
-   public:
-    Operand(const string& graph_name, const Variable& var,
-        const Type& data_type, const string& description, bool iterable)
-     : graph_name_(graph_name), var_(var), data_type_(data_type),
-       description_(description), iterable_(iterable) {}
-    const string& graph_name() const { return graph_name_; }
-    const Variable& var() const { return var_; }
-    Variable* var_ptr() { return &var_; }
-    const Type& data_type() const { return data_type_; }
-    const string& description() const { return description_; }
-    bool iterable() const { return iterable_; }
-
-   private:
-    string graph_name_;
-    Variable var_;
-    Type data_type_;
-    string description_;
-    bool iterable_;
-  };
-
-  explicit OpSpec(const string& graph_name) : graph_name_(graph_name) {}
-  const string& graph_name() const { return graph_name_; }
-  const std::vector<Endpoint> endpoints() const { return endpoints_; }
-  void add_endpoint(const Type& type, const Javadoc& javadoc) {
-    endpoints_.push_back(Endpoint(type, javadoc));
-  }
-  const std::vector<Operand>& inputs() const { return inputs_; }
-  void add_input(const Operand& input) {
-    inputs_.push_back(input);
-  }
-  const std::vector<Operand>& outputs() const { return outputs_; }
-  void add_output(const Operand& output) {
-    outputs_.push_back(output);
-  }
-  const std::vector<Operand>& attributes() const { return attributes_; }
-  void add_attribute(const Operand& attribute) {
-    attributes_.push_back(attribute);
-  }
-  const std::vector<Operand>& options() const { return options_; }
-  void add_option(const Operand& option) {
-    options_.push_back(option);
-  }
-
- private:
-  string graph_name_;
-  std::vector<Endpoint> endpoints_;
-  std::vector<Operand> inputs_;
-  std::vector<Operand> outputs_;
-  std::vector<Operand> attributes_;
-  std::vector<Operand> options_;
-};
-
-// A parser of ops proto definitions.
-//
-// This object parses the definition and the api of an TensorFlow operation to
-// produce a specification that can be used for Java source code rendering.
-class OpParser {
- public:
-  OpParser(const OpDef& op_def, const ApiDef& api_def, const string& lib_name,
-      const string& base_package);
-  virtual ~OpParser() = default;
-
-  // Produces an operation specification from its proto definitions.
-  void Parse(std::unique_ptr<OpSpec>* op_ptr);
-
- private:
-  OpDef op_def_;
-  ApiDef op_api_;
-  string lib_name_;
-  string base_package_;
-  std::map<string, Type> visited_attrs_;
-  char next_generic_ = 0;
-
-  void BuildEndpoints(OpSpec* op);
-  void ParseInput(const OpDef_ArgDef& input_def,
-      const ApiDef::Arg& input_api, OpSpec* op);
-  void ParseOutput(const OpDef_ArgDef& output_def,
-      const ApiDef::Arg& output_api, OpSpec* op);
-  void ParseAttribute(const OpDef_AttrDef& attr_def,
-      const ApiDef::Attr& attr_api, OpSpec* op);
-  Type DataTypeOf(const OpDef_ArgDef& arg_def, bool *iterable_out);
-  Type DataTypeOf(const OpDef_AttrDef& attr_def, bool *iterable_out);
-  Type GetNextGenericTensorType(const AttrValue& allowed_values);
-};
-
-}  // namespace java
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_JAVA_SRC_GEN_CC_OP_PARSER_H_
diff --git a/tensorflow/java/src/gen/cc/op_specs.cc b/tensorflow/java/src/gen/cc/op_specs.cc
new file mode 100644 (file)
index 0000000..a727f7a
--- /dev/null
@@ -0,0 +1,390 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <map>
+#include <vector>
+#include <string>
+#include <utility>
+
+#include "re2/re2.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/java/src/gen/cc/op_specs.h"
+
+namespace tensorflow {
+namespace java {
+namespace {
+
+inline bool IsRealNumbers(const AttrValue& values) {
+  if (!values.has_list()) {
+    return RealNumberTypes().Contains(values.type());
+  }
+  for (int i = 0; i < values.list().type_size(); ++i) {
+    if (!RealNumberTypes().Contains(values.list().type(i))) {
+      return false;
+    }
+  }
+  return true;
+}
+
+class TypeResolver {
+ public:
+  explicit TypeResolver(const OpDef& op_def) : op_def_(op_def) {}
+
+  Type TypeOf(const OpDef_ArgDef& arg_def, bool *iterable_out);
+  Type TypeOf(const OpDef_AttrDef& attr_def, bool *iterable_out);
+  bool IsAttributeVisited(const string& attr_name) {
+    return visited_attrs_.find(attr_name) != visited_attrs_.cend();
+  }
+ private:
+  const OpDef op_def_;
+  std::map<std::string, Type> visited_attrs_;
+  char next_generic_ = 'T';
+};
+
+Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def,
+    bool* iterable_out) {
+  *iterable_out = false;
+  if (!arg_def.number_attr().empty()) {
+    // when number_attr is set, argument has to be a list of tensors
+    *iterable_out = true;
+    visited_attrs_.insert(std::make_pair(arg_def.number_attr(), Type::Int()));
+  }
+  Type type = Type::Wildcard();
+  if (arg_def.type() != DataType::DT_INVALID) {
+    // resolve type from DataType
+    switch (arg_def.type()) {
+      case DataType::DT_BOOL:
+        type = Type::Class("Boolean");
+        break;
+      case DataType::DT_STRING:
+        type = Type::Class("String");
+        break;
+      case DataType::DT_FLOAT:
+        type = Type::Class("Float");
+        break;
+      case DataType::DT_DOUBLE:
+        type = Type::Class("Double");
+        break;
+      case DataType::DT_UINT8:
+        type = Type::Class("UInt8", "org.tensorflow.types");
+        break;
+      case DataType::DT_INT32:
+        type = Type::Class("Integer");
+        break;
+      case DataType::DT_INT64:
+        type = Type::Class("Long");
+        break;
+      case DataType::DT_RESOURCE:
+        // TODO(karllessard) create a Resource utility class that could be
+        // used to store a resource and its type (passed in a second argument).
+        // For now, we need to force a wildcard and we will unfortunately lose
+        // track of the resource type.
+        break;
+      default:
+        // Any other datatypes does not have a equivalent in Java and must
+        // remain a wildcard (e.g. DT_COMPLEX64, DT_QINT8, ...)
+        break;
+    }
+  } else if (!arg_def.type_attr().empty()) {
+    // resolve type from attribute (if already visited, retrieve its type)
+    if (IsAttributeVisited(arg_def.type_attr())) {
+      type = visited_attrs_.at(arg_def.type_attr());
+    } else {
+      for (const auto& attr_def : op_def_.attr()) {
+        if (attr_def.name() == arg_def.type_attr()) {
+          type = TypeOf(attr_def, iterable_out);
+          break;
+        }
+      }
+    }
+  } else if (!arg_def.type_list_attr().empty()) {
+    // type is a list of tensors that can be of different data types, so leave
+    // it as a list of wildcards
+    *iterable_out = true;
+    visited_attrs_.insert(std::make_pair(arg_def.type_list_attr(), type));
+
+  } else {
+    LOG(FATAL) << "Cannot resolve data type of argument \"" << arg_def.name()
+        << "\" in operation \"" << op_def_.name() << "\"";
+  }
+  return type;
+}
+
+Type TypeResolver::TypeOf(const OpDef_AttrDef& attr_def,
+    bool* iterable_out) {
+  *iterable_out = false;
+  StringPiece attr_type = attr_def.type();
+  if (str_util::ConsumePrefix(&attr_type, "list(")) {
+    attr_type.remove_suffix(1);  // remove closing brace
+    *iterable_out = true;
+  }
+  Type type = *iterable_out ? Type::Wildcard() : Type::Class("Object");
+  if (attr_type == "type") {
+    if (*iterable_out) {
+      type = Type::Enum("DataType", "org.tensorflow");
+    } else {
+      type = Type::Generic(string(1, next_generic_));
+      next_generic_ = (next_generic_ == 'Z') ? 'A' : next_generic_ + 1;
+      if (IsRealNumbers(attr_def.allowed_values())) {
+        // enforce real numbers datasets by extending java.lang.Number
+        type.add_supertype(Type::Class("Number"));
+      }
+    }
+  } else if (attr_type == "string") {
+    type = Type::Class("String");
+
+  } else if (attr_type == "int") {
+    type = Type::Class("Integer");
+
+  } else if (attr_type == "float") {
+    type = Type::Class("Float");
+
+  } else if (attr_type == "bool") {
+    type = Type::Class("Boolean");
+
+  } else if (attr_type == "shape") {
+    type = Type::Class("Shape", "org.tensorflow");
+
+  } else if (attr_type == "tensor") {
+    type = Type::Class("Tensor", "org.tensorflow")
+        .add_parameter(Type::Wildcard());
+
+  } else {
+    LOG(FATAL) << "Cannot resolve data type for attribute \"" << attr_type
+        << "\" in operation \"" << op_def_.name() << "\"";
+  }
+  visited_attrs_.insert(std::make_pair(attr_def.name(), type));
+  return type;
+}
+
+string SnakeToCamelCase(const string& str, bool upper = false) {
+  string result;
+  bool cap = upper;
+  for (string::const_iterator it = str.begin(); it != str.end(); ++it) {
+    const char c = *it;
+    if (c == '_') {
+      cap = true;
+    } else if (cap) {
+      result += toupper(c);
+      cap = false;
+    } else {
+      result += c;
+    }
+  }
+  return result;
+}
+
+bool FindAndCut(re2::StringPiece* input, const RE2& expr,
+    re2::StringPiece* before_match, re2::StringPiece* ret_match = nullptr) {
+  re2::StringPiece match;
+  bool matches =
+      expr.Match(*input, 0, input->size(), RE2::UNANCHORED, &match, 1);
+  if (matches) {
+    before_match->set(input->data(), match.begin() - input->begin());
+    input->remove_prefix(match.end() - before_match->begin());
+    if (ret_match != nullptr) {
+      *ret_match = match;
+    }
+  } else {
+    *before_match = *input;
+    if (ret_match != nullptr) {
+      ret_match->set(nullptr, 0);
+    }
+  }
+  return matches;
+}
+
+string ParseDocumentation(const string& mdtext) {
+  std::stringstream javadoc_text;
+  re2::StringPiece input(mdtext);
+  re2::StringPiece text;
+  bool in_list = false;
+  do {
+    re2::StringPiece markup;
+    FindAndCut(&input,
+        "\n+\\*[[:blank:]]+|\n{2,}|`{3,}|`{1,2}|\\*{1,2}\\b|\\[",
+        &text, &markup);
+    javadoc_text << text;
+    if (markup.empty()) {
+      break;  // we are done parsing
+    }
+    if (markup.starts_with("\n")) {
+      javadoc_text << "\n";
+      if (markup.contains("* ")) {
+        javadoc_text << (in_list ? "</li>\n" : "<ul>\n") << "<li>\n";
+        in_list = true;
+      } else if (markup.starts_with("\n\n")) {
+        if (in_list) {
+          javadoc_text << "</li>\n</ul>\n";
+          in_list = false;
+        } else if (!input.starts_with("```")) {
+          javadoc_text << "<p>\n";
+        }
+      }
+    } else if (markup.starts_with("```") && text.empty()) {
+      re2::StringPiece language;
+      RE2::Consume(&input, "[\\w\\+]+", &language);
+      if (FindAndCut(&input, markup.ToString() + "\n*", &text)) {
+        javadoc_text << "<pre>\n{@code" << text << "}\n</pre>\n";
+      } else {
+        javadoc_text << markup << language;
+      }
+    } else if (markup.starts_with("`")) {
+      if (FindAndCut(&input, markup, &text)) {
+        javadoc_text << "{@code " << text << "}";
+      } else {
+        javadoc_text << markup;
+      }
+    } else if (markup == "**") {
+      if (FindAndCut(&input, "\\b\\*{2}", &text)) {
+        javadoc_text << "<b>" << text << "</b>";
+      } else {
+        javadoc_text << markup;
+      }
+    } else if (markup == "*") {
+      if (FindAndCut(&input, "\\b\\*{1}", &text)) {
+        javadoc_text << "<i>" << text << "</i>";
+      } else {
+        javadoc_text << markup;
+      }
+    } else if (markup == "[") {
+      string label;
+      string link;
+      if (RE2::Consume(&input, "([^\\[]+)\\]\\((http.+)\\)", &label, &link)) {
+        javadoc_text << "<a href=\"" << link << "\">" << label << "</a>";
+      } else {
+        javadoc_text << markup;
+      }
+    }
+  } while (!input.empty());
+
+  return javadoc_text.str();
+}
+
+ArgumentSpec CreateInput(const OpDef_ArgDef& input_def,
+    const ApiDef::Arg& input_api_def, TypeResolver* type_resolver) {
+  bool iterable = false;
+  Type type = type_resolver->TypeOf(input_def, &iterable);
+  Type var_type = Type::Interface("Operand", "org.tensorflow")
+    .add_parameter(type);
+  if (iterable) {
+    var_type = Type::IterableOf(var_type);
+  }
+  return ArgumentSpec(input_api_def.name(),
+      Variable::Create(SnakeToCamelCase(input_api_def.rename_to()), var_type),
+      type,
+      ParseDocumentation(input_api_def.description()),
+      iterable);
+}
+
+AttributeSpec CreateAttribute(const OpDef_AttrDef& attr_def,
+    const ApiDef::Attr& attr_api_def, TypeResolver* type_resolver) {
+  bool iterable = false;
+  Type type = type_resolver->TypeOf(attr_def, &iterable);
+  // type attributes must be passed explicitly in methods as a Class<> parameter
+  bool is_explicit = type.kind() == Type::GENERIC && !iterable;
+  Type var_type = is_explicit ? Type::Class("Class").add_parameter(type) : type;
+  if (iterable) {
+    var_type = Type::ListOf(type);
+  }
+  return AttributeSpec(attr_api_def.name(),
+      Variable::Create(SnakeToCamelCase(attr_api_def.rename_to()), var_type),
+      type,
+      ParseDocumentation(attr_api_def.description()),
+      iterable,
+      attr_api_def.has_default_value() && !is_explicit);
+}
+
+ArgumentSpec CreateOutput(const OpDef_ArgDef& output_def,
+    const ApiDef::Arg& output_api, TypeResolver* type_resolver) {
+  bool iterable = false;
+  Type type = type_resolver->TypeOf(output_def, &iterable);
+  Type var_type = Type::Class("Output", "org.tensorflow")
+    .add_parameter(type);
+  if (iterable) {
+    var_type = Type::ListOf(var_type);
+  }
+  return ArgumentSpec(output_api.name(),
+      Variable::Create(SnakeToCamelCase(output_api.rename_to()), var_type),
+      type,
+      ParseDocumentation(output_api.description()),
+      iterable);
+}
+
+EndpointSpec CreateEndpoint(const OpDef& op_def, const ApiDef& api_def,
+    const ApiDef_Endpoint& endpoint_def) {
+
+  std::vector<string> name_tokens = str_util::Split(endpoint_def.name(), ".");
+  string package;
+  string name;
+  if (name_tokens.size() > 1) {
+    package = name_tokens.at(0);
+    name = name_tokens.at(1);
+  } else {
+    package = "core";  // generate unclassified ops in the 'core' package
+    name = name_tokens.at(0);
+  }
+  return EndpointSpec(package,
+      name,
+      Javadoc::Create(ParseDocumentation(api_def.summary()))
+          .details(ParseDocumentation(api_def.description())),
+      endpoint_def.deprecation_version() > 0);
+}
+
+}  // namespace
+
+OpSpec OpSpec::Create(const OpDef& op_def, const ApiDef& api_def) {
+  OpSpec op(api_def.graph_op_name(),
+      api_def.visibility() == ApiDef::HIDDEN,
+      op_def.deprecation().explanation());
+  TypeResolver type_resolver(op_def);
+  for (const string& next_input_name : api_def.arg_order()) {
+    for (int i = 0; i < op_def.input_arg().size(); ++i) {
+      if (op_def.input_arg(i).name() == next_input_name) {
+        op.inputs_.push_back(CreateInput(op_def.input_arg(i), api_def.in_arg(i),
+            &type_resolver));
+        break;
+      }
+    }
+  }
+  for (int i = 0; i < op_def.attr().size(); ++i) {
+    // do not parse attributes already visited, they have probably been inferred
+    // before as an input argument type
+    if (!type_resolver.IsAttributeVisited(op_def.attr(i).name())) {
+      AttributeSpec attr = CreateAttribute(op_def.attr(i), api_def.attr(i),
+          &type_resolver);
+      // attributes with a default value are optional
+      if (attr.optional()) {
+        op.optional_attributes_.push_back(attr);
+      } else {
+        op.attributes_.push_back(attr);
+      }
+    }
+  }
+  for (int i = 0; i < op_def.output_arg().size(); ++i) {
+    op.outputs_.push_back(CreateOutput(op_def.output_arg(i), api_def.out_arg(i),
+        &type_resolver));
+  }
+  for (const auto& endpoint_def : api_def.endpoint()) {
+    op.endpoints_.push_back(CreateEndpoint(op_def, api_def, endpoint_def));
+  }
+  return op;
+}
+
+}  // namespace java
+}  // namespace tensorflow
diff --git a/tensorflow/java/src/gen/cc/op_specs.h b/tensorflow/java/src/gen/cc/op_specs.h
new file mode 100644 (file)
index 0000000..55c2c3f
--- /dev/null
@@ -0,0 +1,152 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_JAVA_SRC_GEN_CC_OP_SPECS_H_
+#define TENSORFLOW_JAVA_SRC_GEN_CC_OP_SPECS_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/api_def.pb.h"
+#include "tensorflow/java/src/gen/cc/java_defs.h"
+
+namespace tensorflow {
+namespace java {
+
+class EndpointSpec {
+ public:
+  // A specification for an operation endpoint
+  //
+  // package: package of this endpoint (from which also derives its package)
+  // name: name of this endpoint class
+  // javadoc: the endpoint class documentation
+  // deprecated: true if this endpoint is now deprecated
+  EndpointSpec(const string& package, const string& name,
+      const Javadoc& javadoc, bool deprecated)
+    : package_(package), name_(name), javadoc_(javadoc),
+      deprecated_(deprecated) {}
+
+  const string& package() const { return package_; }
+  const string& name() const { return name_; }
+  const Javadoc& javadoc() const { return javadoc_; }
+  bool deprecated() const { return deprecated_; }
+
+ private:
+  const string package_;
+  const string name_;
+  const Javadoc javadoc_;
+  const bool deprecated_;
+};
+
+class ArgumentSpec {
+ public:
+  // A specification for an operation argument
+  //
+  // op_def_name: argument name, as known by TensorFlow core
+  // var: a variable to represent this argument in Java
+  // type: the tensor type of this argument
+  // description: a description of this argument, in javadoc
+  // iterable: true if this argument is a list
+  ArgumentSpec(const string& op_def_name, const Variable& var,
+      const Type& type, const string& description, bool iterable)
+    : op_def_name_(op_def_name), var_(var), type_(type),
+      description_(description), iterable_(iterable) {}
+  virtual ~ArgumentSpec() = default;
+
+  const string& op_def_name() const { return op_def_name_; }
+  const Variable& var() const { return var_; }
+  const Type& type() const { return type_; }
+  const string& description() const { return description_; }
+  bool iterable() const { return iterable_; }
+
+ private:
+  const string op_def_name_;
+  const Variable var_;
+  const Type type_;
+  const string description_;
+  const bool iterable_;
+};
+
+class AttributeSpec : public ArgumentSpec {
+ public:
+  // A specification for an operation attribute
+  //
+  // op_def_name: attribute name, as known by TensorFlow core
+  // var: a variable to represent this attribute in Java
+  // type: the type of this attribute
+  // description: a description of this attribute, in javadoc
+  // iterable: true if this attribute is a list
+  // optional: true if this attribute does not require to be set explicitly
+  AttributeSpec(const string& op_def_name, const Variable& var,
+      const Type& type, const string& description, bool iterable,
+      bool optional)
+    : ArgumentSpec(op_def_name, var, type, description, iterable),
+      optional_(optional) {}
+  virtual ~AttributeSpec() = default;
+
+  bool optional() const { return optional_; }
+
+ private:
+  const bool optional_;
+};
+
+class OpSpec {
+ public:
+  // Parses an op definition and its API to produce a specification used for
+  // rendering its Java wrapper
+  //
+  // op_def: Op definition
+  // api_def: Op API definition
+  static OpSpec Create(const OpDef& op_def, const ApiDef& api_def);
+
+  const string& graph_op_name() const { return graph_op_name_; }
+  bool hidden() const { return hidden_; }
+  const string& deprecation_explanation() const {
+    return deprecation_explanation_;
+  }
+  const std::vector<EndpointSpec> endpoints() const { return endpoints_; }
+  const std::vector<ArgumentSpec>& inputs() const { return inputs_; }
+  const std::vector<ArgumentSpec>& outputs() const { return outputs_; }
+  const std::vector<AttributeSpec>& attributes() const { return attributes_; }
+  const std::vector<AttributeSpec>& optional_attributes() const {
+    return optional_attributes_;
+  }
+
+ private:
+  // A specification for an operation
+  //
+  // graph_op_name: name of this op, as known by TensorFlow core engine
+  // hidden: true if this op should not be visible through the Graph Ops API
+  // deprecation_explanation: message to show if all endpoints are deprecated
+  explicit OpSpec(const string& graph_op_name, bool hidden,
+      const string& deprecation_explanation)
+    : graph_op_name_(graph_op_name), hidden_(hidden),
+      deprecation_explanation_(deprecation_explanation) {}
+
+  const string graph_op_name_;
+  const bool hidden_;
+  const string deprecation_explanation_;
+  std::vector<EndpointSpec> endpoints_;
+  std::vector<ArgumentSpec> inputs_;
+  std::vector<ArgumentSpec> outputs_;
+  std::vector<AttributeSpec> attributes_;
+  std::vector<AttributeSpec> optional_attributes_;
+};
+
+}  // namespace java
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_JAVA_SRC_GEN_CC_OP_SPECS_H_
index b1de5af..7e42778 100644 (file)
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
index 1f0febe..bcae33c 100644 (file)
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
index 8bd42d9..875ad99 100644 (file)
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -313,8 +313,7 @@ TEST(WriteType, SimpleClassWithDependencies) {
 TEST(WriteType, AnnotatedAndDocumentedClass) {
   SourceBufferWriter writer;
   Type clazz = Type::Class("Test", "org.tensorflow");
-  Javadoc clazz_doc;
-  clazz_doc.brief("Javadoc test")
+  Javadoc clazz_doc = Javadoc::Create("Javadoc test")
       .details("This is a\nmultiline description.");
   clazz.add_annotation(Annotation::Create("Bean"));
   clazz.add_annotation(Annotation::Create("SuppressWarnings")
@@ -329,7 +328,7 @@ TEST(WriteType, AnnotatedAndDocumentedClass) {
       " * <p>\n"
       " * This is a\n"
       " * multiline description.\n"
-      " **/\n"
+      " */\n"
       "@Bean\n"
       "@SuppressWarnings(\"rawtypes\")\n"
       "public class Test {\n}\n";
@@ -378,8 +377,7 @@ TEST(WriteType, ParameterizedClassFields) {
   Variable field1 = Variable::Create("field1", Type::Class("String"));
   Variable field2 = Variable::Create("field2", Type::Class("String"));
   Variable field3 = Variable::Create("field3", type_t);
-  Javadoc field3_doc;
-  field3_doc.brief("This variable is documented");
+  Javadoc field3_doc = Javadoc::Create("This variable is documented");
 
   writer.BeginType(clazz, PUBLIC)
           .WriteField(field1, STATIC | PUBLIC | FINAL)
@@ -464,8 +462,7 @@ TEST(WriteMethod, AnnotatedAndDocumentedMethod) {
   SourceBufferWriter writer;
   Type clazz = Type::Class("Test", "org.tensorflow");
   Method method = Method::Create("doNothing", Type::Void());
-  Javadoc method_doc;
-  method_doc.brief("Javadoc test")
+  Javadoc method_doc = Javadoc::Create("Javadoc test")
       .details("This method has a\nmultiline description.");
   method.add_annotation(Annotation::Create("Override"));
   method.add_annotation(Annotation::Create("SuppressWarnings")
@@ -484,7 +481,7 @@ TEST(WriteMethod, AnnotatedAndDocumentedMethod) {
       "   * <p>\n"
       "   * This method has a\n"
       "   * multiline description.\n"
-      "   **/\n"
+      "   */\n"
       "  @Override\n"
       "  @SuppressWarnings(\"rawtypes\")\n"
       "  public void doNothing() {\n"
@@ -500,8 +497,7 @@ TEST(WriteMethod, DocumentedMethodWithArguments) {
   Method method = Method::Create("boolToInt", Type::Int());
   method.add_argument(Variable::Create("b", Type::Boolean()));
   method.add_argument(reverse);
-  Javadoc method_doc;
-  method_doc.brief("Converts a boolean to an int")
+  Javadoc method_doc = Javadoc::Create("Converts a boolean to an int")
       .details("This method will convert\na boolean to an int")
       .add_param_tag(reverse.name(), "if true, value is reversed")
       .add_tag("return", "int value for this boolean");
@@ -528,7 +524,7 @@ TEST(WriteMethod, DocumentedMethodWithArguments) {
       "   * \n"
       "   * @param reverse if true, value is reversed\n"
       "   * @return int value for this boolean\n"
-      "   **/\n"
+      "   */\n"
       "  public int boolToInt(boolean b, boolean reverse) {\n"
       "    if (b && !reverse) {\n"
       "      return 1;\n"
index 1e7899c..7017b52 100644 (file)
@@ -32,50 +32,52 @@ def tf_java_op_gen_srcjar(name,
                           api_def_srcs=[],
                           visibility=["//tensorflow/java:__pkg__"]):
 
-  gen_tools = []
   gen_cmds = ["rm -rf $(@D)"]  # Always start from fresh when generating source files
   srcs = api_def_srcs[:]
 
-  # Construct an op generator binary for each ops library.
-  for ops_lib in ops_libs:
-    gen_lib = ops_lib[:ops_lib.rfind("_")]
-    out_gen_tool = out_dir + ops_lib + "_gen_tool"
+  if not api_def_srcs:
+    api_def_args_str = ","
+  else:
+    api_def_args = []
+    for api_def_src in api_def_srcs:
+      # Add directory of the first ApiDef source to args.
+      # We are assuming all ApiDefs in a single api_def_src are in the
+      # same directory.
+      api_def_args.append(
+          "$$(dirname $$(echo $(locations " + api_def_src +
+          ") | cut -d\" \" -f1))")
+    api_def_args_str = ",".join(api_def_args)
 
-    if not api_def_srcs:
-      api_def_args_str = ","
-    else:
-      api_def_args = []
-      for api_def_src in api_def_srcs:
-        # Add directory of the first ApiDef source to args.
-        # We are assuming all ApiDefs in a single api_def_src are in the
-        # same directory.
-        api_def_args.append(
-            " $$(dirname $$(echo $(locations " + api_def_src +
-            ") | cut -d\" \" -f1))")
-      api_def_args_str = ",".join(api_def_args)
+  gen_tool_deps = [":java_op_gen_lib"]
+  for ops_lib in ops_libs:
+    gen_tool_deps.append(ops_libs_pkg + ":" + ops_lib + "_op_lib")
 
-    tf_cc_binary(
-        name=out_gen_tool,
-        copts=tf_copts(),
-        linkopts=["-lm"],
-        linkstatic=1,  # Faster to link this one-time-use binary dynamically
-        deps=[gen_tool, ops_libs_pkg + ":" + ops_lib + "_op_lib"])
+  tf_cc_binary(
+      name=gen_tool,
+      srcs=[
+          "src/gen/cc/op_gen_main.cc",
+      ],
+      copts=tf_copts(),
+      linkopts=["-lm"],
+      linkstatic=1,  # Faster to link this one-time-use binary dynamically
+      deps = gen_tool_deps)
 
-    gen_tools += [":" + out_gen_tool]
-    gen_cmds += ["$(location :" + out_gen_tool + ")" +
-                 " --output_dir=$(@D)/" + out_src_dir +
-                 " --lib_name=" + gen_lib +
-                 " --base_package=" + gen_base_package +
-                 " " + api_def_args_str]
+  gen_cmds += ["$(location :" + gen_tool + ")" +
+               " --output_dir=$(@D)/" + out_src_dir +
+               " --base_package=" + gen_base_package +
+               " --api_dirs=" + api_def_args_str]
 
   # Generate a source archive containing generated code for these ops.
   gen_srcjar = out_dir + name + ".srcjar"
   gen_cmds += ["$(location @local_jdk//:jar) cMf $(location :" + gen_srcjar + ") -C $(@D) src"]
-  gen_tools += ["@local_jdk//:jar"] + ["@local_jdk//:jdk"]
-  gen_tools += tf_binary_additional_srcs()
+
   native.genrule(
       name=name,
       srcs=srcs,
       outs=[gen_srcjar],
-      tools=gen_tools,
-      cmd="&&".join(gen_cmds))
+      tools=[
+          "@local_jdk//:jar",
+          "@local_jdk//:jdk",
+          gen_tool
+      ] + tf_binary_additional_srcs(),
+      cmd=" && ".join(gen_cmds))