From fe9b2637cfe39cf11eb3d0494948a733b7fc1d7d Mon Sep 17 00:00:00 2001 From: Karl Lessard Date: Thu, 29 Mar 2018 05:28:16 +0800 Subject: [PATCH] Parse op definition and generate a Java Op class. --- tensorflow/java/BUILD | 4 + tensorflow/java/src/gen/cc/java_defs.h | 76 ++-- tensorflow/java/src/gen/cc/op_gen_main.cc | 22 +- tensorflow/java/src/gen/cc/op_generator.cc | 406 ++++++++++++++++++-- tensorflow/java/src/gen/cc/op_generator.h | 42 ++- tensorflow/java/src/gen/cc/op_parser.cc | 417 +++++++++++++++++++++ tensorflow/java/src/gen/cc/op_parser.h | 137 +++++++ tensorflow/java/src/gen/cc/source_writer.cc | 127 ++++--- tensorflow/java/src/gen/cc/source_writer.h | 55 ++- tensorflow/java/src/gen/cc/source_writer_test.cc | 82 ++-- tensorflow/java/src/gen/gen_ops.bzl | 29 +- .../java/src/gen/resources/license.snippet.java | 14 + 12 files changed, 1201 insertions(+), 210 deletions(-) create mode 100644 tensorflow/java/src/gen/cc/op_parser.cc create mode 100644 tensorflow/java/src/gen/cc/op_parser.h create mode 100644 tensorflow/java/src/gen/resources/license.snippet.java diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD index ab7d698..635a4e8 100644 --- a/tensorflow/java/BUILD +++ b/tensorflow/java/BUILD @@ -70,6 +70,7 @@ filegroup( tf_java_op_gen_srcjar( name = "java_op_gen_sources", + api_def_srcs = ["//tensorflow/core/api_def:base_api_def"], gen_base_package = "org.tensorflow.op", gen_tool = "java_op_gen_tool", ops_libs = [ @@ -111,11 +112,13 @@ cc_library( name = "java_op_gen_lib", srcs = [ "src/gen/cc/op_generator.cc", + "src/gen/cc/op_parser.cc", "src/gen/cc/source_writer.cc", ], hdrs = [ "src/gen/cc/java_defs.h", "src/gen/cc/op_generator.h", + "src/gen/cc/op_parser.h", "src/gen/cc/source_writer.h", ], copts = tf_copts(), @@ -124,6 +127,7 @@ cc_library( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core:op_gen_lib", ], ) diff --git a/tensorflow/java/src/gen/cc/java_defs.h b/tensorflow/java/src/gen/cc/java_defs.h index 59f8bea..2065477 100644 --- a/tensorflow/java/src/gen/cc/java_defs.h +++ b/tensorflow/java/src/gen/cc/java_defs.h @@ -18,12 +18,15 @@ limitations under the License. #include #include +#include +#include namespace tensorflow { namespace java { // An enumeration of different modifiers commonly used in Java enum Modifier { + PACKAGE = 0, PUBLIC = (1 << 0), PROTECTED = (1 << 1), PRIVATE = (1 << 2), @@ -72,6 +75,12 @@ class Type { // Reflection API does return Type(Type::PRIMITIVE, "void"); } + static Type Generic(const string& name) { + return Type(Type::GENERIC, name); + } + static Type Wildcard() { + return Type(Type::GENERIC, ""); + } static Type Class(const string& name, const string& package = "") { return Type(Type::CLASS, name, package); } @@ -81,9 +90,6 @@ class Type { static Type Enum(const string& name, const string& package = "") { return Type(Type::ENUM, name, package); } - static Type Generic(const string& name = "") { - return Type(Type::GENERIC, name); - } static Type ClassOf(const Type& type) { return Class("Class").add_parameter(type); } @@ -96,11 +102,10 @@ class Type { const Kind& kind() const { return kind_; } const string& name() const { return name_; } const string& package() const { return package_; } - const string& description() const { return description_; } - Type& description(const string& description) { - description_ = description; - return *this; + const string full_name() const { + return package_.empty() ? name_ : package_ + "." + name_; } + bool unknown() const { return name_.empty(); } // only wildcards has no name const std::list& parameters() const { return parameters_; } Type& add_parameter(const Type& parameter) { parameters_.push_back(parameter); @@ -120,14 +125,6 @@ class Type { } return *this; } - // Returns true if "type" is of a known collection type (only a few for now) - bool IsCollection() const { - return name_ == "List" || name_ == "Iterable"; - } - // Returns true if this instance is a wildcard () - bool IsWildcard() const { - return kind_ == GENERIC && name_.empty(); - } protected: Type(Kind kind, const string& name, const string& package = "") @@ -137,7 +134,6 @@ class Type { Kind kind_; string name_; string package_; - string description_; std::list parameters_; std::list annotations_; std::list supertypes_; @@ -180,16 +176,11 @@ class Variable { const string& name() const { return name_; } const Type& type() const { return type_; } bool variadic() const { return variadic_; } - const string& description() const { return description_; } - Variable& description(const string& description) { - description_ = description; - return *this; - } + private: string name_; Type type_; bool variadic_; - string description_; Variable(const string& name, const Type& type, bool variadic) : name_(name), type_(type), variadic_(variadic) {} @@ -210,16 +201,6 @@ class Method { bool constructor() const { return constructor_; } const string& name() const { return name_; } const Type& return_type() const { return return_type_; } - const string& description() const { return description_; } - Method& description(const string& description) { - description_ = description; - return *this; - } - const string& return_description() const { return return_description_; } - Method& return_description(const string& description) { - return_description_ = description; - return *this; - } const std::list& arguments() const { return arguments_; } Method& add_argument(const Variable& var) { arguments_.push_back(var); @@ -235,8 +216,6 @@ class Method { string name_; Type return_type_; bool constructor_; - string description_; - string return_description_; std::list arguments_; std::list annotations_; @@ -244,6 +223,35 @@ class Method { : name_(name), return_type_(return_type), constructor_(constructor) {} }; +// A definition of a documentation bloc for a Java element (JavaDoc) +class Javadoc { + public: + static Javadoc Create(const string& brief = "") { + return Javadoc(brief); + } + const string& brief() const { return brief_; } + const string& details() const { return description_; } + Javadoc& details(const string description) { + description_ = description; + return *this; + } + const std::list> tags() const { return tags_; } + Javadoc& add_tag(const string& tag, const string& text) { + tags_.push_back(std::make_pair(tag, text)); + return *this; + } + Javadoc& add_param_tag(const string& name, const string& text) { + return add_tag("param", name + " " + text); + } + + private: + string brief_; + string description_; + std::list> tags_; + + explicit Javadoc(const string& brief) : brief_(brief) {} +}; + } // namespace java } // namespace tensorflow diff --git a/tensorflow/java/src/gen/cc/op_gen_main.cc b/tensorflow/java/src/gen/cc/op_gen_main.cc index bea99f3..0152000 100644 --- a/tensorflow/java/src/gen/cc/op_gen_main.cc +++ b/tensorflow/java/src/gen/cc/op_gen_main.cc @@ -48,8 +48,11 @@ const char kUsageHeader[] = "through\n" "the 'org.tensorflow.op.Ops' API as a group until the generated classes " "are compiled using an appropriate annotation processor.\n\n" - "Finally, the '--base_package' overrides the default parent package " - "under which the generated subpackage and classes are to be located.\n\n"; + "The '--base_package' overrides the default parent package under which " + "the generated subpackage and classes are to be located.\n\n" + "Finally, a list of directories of API proto definitions can be provided " + "to override default values found in the ops definitions, ordered by\n" + "priority (the last having precedence over the first).\n\n"; } // namespace java } // namespace tensorflow @@ -60,7 +63,7 @@ int main(int argc, char* argv[]) { tensorflow::string base_package = "org.tensorflow.op"; std::vector flag_list = { tensorflow::Flag("output_dir", &output_dir, - "Root directory into which output files are generated"), + "Root directory into which output files are generated"), tensorflow::Flag( "lib_name", &lib_name, "A name, in snake_case, used to classify this set of operations"), @@ -72,12 +75,15 @@ int main(int argc, char* argv[]) { bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); tensorflow::port::InitMain(usage.c_str(), &argc, &argv); QCHECK(parsed_flags_ok && !lib_name.empty() && !output_dir.empty()) << usage; - - tensorflow::java::OpGenerator generator; + std::vector api_dirs; + if (argc > 1) { + api_dirs = tensorflow::str_util::Split(argv[1], ",", + tensorflow::str_util::SkipEmpty()); + } + tensorflow::java::OpGenerator generator(base_package, output_dir, api_dirs); tensorflow::OpList ops; - tensorflow::OpRegistry::Global()->Export(true, &ops); - tensorflow::Status status = - generator.Run(ops, lib_name, base_package, output_dir); + tensorflow::OpRegistry::Global()->Export(false, &ops); + tensorflow::Status status = generator.Run(ops, lib_name); TF_QCHECK_OK(status); return 0; diff --git a/tensorflow/java/src/gen/cc/op_generator.cc b/tensorflow/java/src/gen/cc/op_generator.cc index def06ba..c9b57f5 100644 --- a/tensorflow/java/src/gen/cc/op_generator.cc +++ b/tensorflow/java/src/gen/cc/op_generator.cc @@ -14,53 +14,409 @@ limitations under the License. ==============================================================================*/ #include +#include +#include +#include +#include +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/framework/op_gen_lib.h" +#include "tensorflow/java/src/gen/cc/java_defs.h" +#include "tensorflow/java/src/gen/cc/source_writer.h" +#include "tensorflow/java/src/gen/cc/op_parser.h" #include "tensorflow/java/src/gen/cc/op_generator.h" namespace tensorflow { namespace java { namespace { -string CamelCase(const string& str, char delimiter, bool upper) { - string result; - bool cap = upper; - for (string::const_iterator it = str.begin(); it != str.end(); ++it) { - const char c = *it; - if (c == delimiter) { - cap = true; - } else if (cap) { - result += toupper(c); - cap = false; +const char* kLicenseSnippet = + "tensorflow/java/src/gen/resources/license.snippet.java"; + +const std::map kPrimitiveAttrTypes = { + { "Boolean", Type::Boolean() }, + { "Byte", Type::Byte() }, + { "Character", Type::Byte() }, + { "Float", Type::Float() }, + { "Integer", Type::Long() }, + { "Long", Type::Long() }, + { "Short", Type::Long() }, + { "Double", Type::Float() }, +}; + +enum RenderMode { + DEFAULT, + SINGLE_OUTPUT, + SINGLE_LIST_OUTPUT +}; + +void CollectOpDependencies(const OpSpec& op, RenderMode mode, + std::list* out) { + out->push_back(Type::Class("Operation", "org.tensorflow")); + out->push_back(Type::Class("OperationBuilder", "org.tensorflow")); + out->push_back(Type::Class("Scope", "org.tensorflow.op")); + if (mode == SINGLE_OUTPUT) { + out->push_back(Type::Class("Output", "org.tensorflow")); + } else if (mode == SINGLE_LIST_OUTPUT) { + out->push_back(Type::Interface("Iterator", "java.util")); + } + // Don't pay attention to duplicate types in the dependency list, they will + // be filtered out by the SourceWriter. + for (const OpSpec::Operand& input : op.inputs()) { + out->push_back(input.var().type()); + if (input.iterable()) { + out->push_back(Type::Class("Operands", "org.tensorflow.op")); + } + } + for (const OpSpec::Operand& output : op.outputs()) { + out->push_back(output.var().type()); + if (output.iterable()) { + out->push_back(Type::Class("Arrays", "java.util")); + } + } + for (const OpSpec::Operand& attribute : op.attributes()) { + out->push_back(attribute.var().type()); + if (attribute.var().type().name() == "Class") { + out->push_back(Type::Enum("DataType", "org.tensorflow")); + } + } + for (const OpSpec::Operand& option : op.options()) { + out->push_back(option.var().type()); + } +} + +void WriteSetAttrDirective(const OpSpec::Operand& attr, bool optional, + SourceWriter* writer) { + string var = optional ? "opts." + attr.var().name() : attr.var().name(); + if (attr.iterable()) { + const Type& type = attr.data_type(); + std::map::const_iterator it = + kPrimitiveAttrTypes.find(type.name()); + if (it != kPrimitiveAttrTypes.end()) { + string array = attr.var().name() + "Array"; + writer->AppendType(it->second) + .Append("[] " + array + " = new ") + .AppendType(it->second) + .Append("[" + var + ".size()];") + .EndLine(); + writer->BeginBlock("for (int i = 0; i < " + array + ".length; ++i)") + .Append(array + "[i] = " + var + ".get(i);") + .EndLine() + .EndBlock() + .Append("opBuilder.setAttr(\"" + attr.graph_name() + "\", " + array) + .Append(");") + .EndLine(); } else { - result += c; + writer->Append("opBuilder.setAttr(\"" + attr.graph_name() + "\", " + var) + .Append(".toArray(new ") + .AppendType(type) + .Append("[" + var + ".size()]));") + .EndLine(); } + } else { + Type type = attr.var().type(); + writer->Append("opBuilder.setAttr(\"" + attr.graph_name() + "\", "); + if (type.name() == "Class") { + writer->Append("DataType.fromClass(" + attr.var().name() + "));"); + } else { + writer->Append(var + ");"); + } + writer->EndLine(); } - return result; } -} // namespace +void RenderFactoryMethod(const OpSpec& op, const Type& op_class, + SourceWriter* writer) { + Method factory = Method::Create("create", op_class); + Javadoc factory_doc = Javadoc::Create( + "Factory method to create a class to wrap a new " + op_class.name() + + " operation to the graph."); + Variable scope = + Variable::Create("scope", Type::Class("Scope", "org.tensorflow.op")); + factory.add_argument(scope); + factory_doc.add_param_tag(scope.name(), "Current graph scope"); + for (const OpSpec::Operand& input : op.inputs()) { + factory.add_argument(input.var()); + factory_doc.add_param_tag(input.var().name(), input.description()); + } + for (const OpSpec::Operand& attribute : op.attributes()) { + factory.add_argument(attribute.var()); + factory_doc.add_param_tag(attribute.var().name(), attribute.description()); + } + if (!op.options().empty()) { + factory.add_argument(Variable::Varargs("options", Type::Class("Options"))); + factory_doc.add_param_tag("options", "carries optional attributes values"); + } + factory_doc.add_tag("return", "a new instance of " + op_class.name()); + writer->BeginMethod(factory, PUBLIC|STATIC, &factory_doc); + writer->Append("OperationBuilder opBuilder = scope.graph().opBuilder(\"" + + op.graph_name() + "\", scope.makeOpName(\"" + + op_class.name() + "\"));"); + writer->EndLine(); -OpGenerator::OpGenerator() : env(Env::Default()) {} + for (const OpSpec::Operand& input : op.inputs()) { + if (input.iterable()) { + writer->Append("opBuilder.addInputList(Operands.asOutputs(" + + input.var().name() + "));"); + writer->EndLine(); + } else { + writer->Append("opBuilder.addInput(" + input.var().name() + + ".asOutput());"); + writer->EndLine(); + } + } + for (const OpSpec::Operand& attribute : op.attributes()) { + WriteSetAttrDirective(attribute, false, writer); + } + if (!op.options().empty()) { + writer->BeginBlock("if (options != null)") + .BeginBlock("for (Options opts : options)"); + for (const OpSpec::Operand& option : op.options()) { + writer->BeginBlock("if (opts." + option.var().name() + " != null)"); + WriteSetAttrDirective(option, true, writer); + writer->EndBlock(); + } + writer->EndBlock().EndBlock(); + } + writer->Append("return new ") + .AppendType(op_class) + .Append("(opBuilder.build());") + .EndLine(); + writer->EndMethod(); +} -OpGenerator::~OpGenerator() {} +void RenderConstructor(const OpSpec& op, const Type& op_class, + SourceWriter* writer) { + Method constructor = Method::ConstructorFor(op_class) + .add_argument( + Variable::Create("operation", + Type::Class("Operation", "org.tensorflow"))); + for (const OpSpec::Operand& output : op.outputs()) { + if (output.iterable() && !output.data_type().unknown()) { + constructor.add_annotation( + Annotation::Create("SuppressWarnings").attributes("\"unchecked\"")); + break; + } + } + writer->BeginMethod(constructor, PRIVATE) + .Append("super(operation);") + .EndLine(); + if (op.outputs().size() > 0) { + writer->Append("int outputIdx = 0;") + .EndLine(); + for (const OpSpec::Operand& output : op.outputs()) { + if (output.iterable()) { + string var_length = output.var().name() + "Length"; + writer->Append("int " + var_length) + .Append(" = operation.outputListLength(\"" + output.graph_name() + + "\");") + .EndLine() + .Append(output.var().name() + " = Arrays.asList("); + if (!output.data_type().unknown()) { + writer->Append("(") + .AppendType(output.var().type().parameters().front()) + .Append("[])"); + } + writer->Append("operation.outputList(outputIdx, " + var_length + "));") + .EndLine() + .Append("outputIdx += " + var_length + ";") + .EndLine(); + } else { + writer->Append(output.var().name() + + " = operation.output(outputIdx++);") + .EndLine(); + } + } + } + writer->EndMethod(); +} -Status OpGenerator::Run(const OpList& ops, const string& lib_name, - const string& base_package, const string& output_dir) { - const string package = - base_package + '.' + str_util::StringReplace(lib_name, "_", "", true); - const string package_path = - output_dir + '/' + str_util::StringReplace(package, ".", "/", true); - const string group = CamelCase(lib_name, '_', false); +void RenderGettersAndSetters(const OpSpec& op, SourceWriter* writer) { + for (const OpSpec::Operand& option : op.options()) { + Method setter = Method::Create(option.var().name(), Type::Class("Options")) + .add_argument(option.var()); + Javadoc setter_doc = Javadoc::Create() + .add_param_tag(option.var().name(), option.description()); + writer->BeginMethod(setter, PUBLIC|STATIC, &setter_doc) + .Append("return new Options()." + option.var().name() + "(" + + option.var().name() + ");") + .EndLine() + .EndMethod(); + } + for (const OpSpec::Operand& output : op.outputs()) { + Method getter = Method::Create(output.var().name(), output.var().type()); + Javadoc getter_doc = Javadoc::Create(output.description()); + writer->BeginMethod(getter, PUBLIC, &getter_doc) + .Append("return " + output.var().name() + ";") + .EndLine() + .EndMethod(); + } +} + +void RenderInterfaceImpl(const OpSpec& op, RenderMode mode, + SourceWriter* writer) { + OpSpec::Operand output = op.outputs().front(); + + if (mode == SINGLE_OUTPUT) { + bool cast2obj = output.data_type().unknown(); + Type return_type = Type::Class("Output", "org.tensorflow") + .add_parameter(cast2obj ? Type::Class("Object") : output.data_type()); + Method as_output = Method::Create("asOutput", return_type) + .add_annotation(Annotation::Create("Override")); + if (cast2obj) { + as_output.add_annotation( + Annotation::Create("SuppressWarnings").attributes("\"unchecked\"")); + } + writer->BeginMethod(as_output, PUBLIC); + if (cast2obj) { + writer->Append("return (").AppendType(return_type).Append(") "); + } else { + writer->Append("return "); + } + writer->Append(output.var().name() + ";") + .EndLine() + .EndMethod(); + + } else if (mode == SINGLE_LIST_OUTPUT) { + Type operand = Type::Interface("Operand", "org.tensorflow"); + if (output.data_type().unknown()) { + operand.add_parameter(Type::Class("Object")); + } else { + operand.add_parameter(output.data_type()); + } + Type return_type = Type::Interface("Iterator", "java.util") + .add_parameter(operand); + Method iterator = Method::Create("iterator", return_type) + .add_annotation(Annotation::Create("Override")) + .add_annotation(Annotation::Create("SuppressWarnings") + .attributes("{\"rawtypes\", \"unchecked\"}")); + // cast the output list using a raw List + writer->BeginMethod(iterator, PUBLIC) + .Append("return (" + return_type.name() + ") ") + .Append(output.var().name() + ".iterator();") + .EndLine() + .EndMethod(); + } +} + +void RenderOptionsClass(const OpSpec& op, SourceWriter* writer) { + Type options_class = Type::Class("Options"); + Javadoc options_doc = Javadoc::Create( + "Class holding optional attributes of this operation"); + writer->BeginInnerType(options_class, PUBLIC | STATIC, &options_doc); + for (const OpSpec::Operand& option : op.options()) { + Method setter = Method::Create(option.var().name(), options_class) + .add_argument(option.var()); + Javadoc setter_doc = Javadoc::Create() + .add_param_tag(option.var().name(), option.description()); + writer->BeginMethod(setter, PUBLIC, &setter_doc) + .Append("this." + option.var().name() + " = " + option.var().name() + + ";") + .EndLine() + .Append("return this;") + .EndLine() + .EndMethod(); + } + writer->EndLine(); + for (const OpSpec::Operand& option : op.options()) { + writer->WriteField(option.var(), PRIVATE); + } + Method constructor = Method::ConstructorFor(options_class); + writer->BeginMethod(constructor, PRIVATE).EndMethod(); + writer->EndType(); +} - if (!env->FileExists(package_path).ok()) { - TF_CHECK_OK(env->RecursivelyCreateDir(package_path)); +void RenderEndpoint(const OpSpec& op, const OpSpec::Endpoint& endpoint, + SourceWriter* writer) { + RenderMode mode = DEFAULT; + if (op.outputs().size() == 1) { + mode = op.outputs().front().iterable() ? SINGLE_LIST_OUTPUT : SINGLE_OUTPUT; + } + std::list dependencies; + CollectOpDependencies(op, mode, &dependencies); + const Type& op_class = endpoint.type(); + writer->WriteFromFile(kLicenseSnippet) + .EndLine() + .Append("// This file is machine generated, DO NOT EDIT!") + .EndLine() + .EndLine() + .BeginType(op_class, PUBLIC|FINAL, &dependencies, &endpoint.javadoc()); + if (!op.options().empty()) { + RenderOptionsClass(op, writer); } + RenderFactoryMethod(op, op_class, writer); + RenderGettersAndSetters(op, writer); + if (mode != DEFAULT) { + RenderInterfaceImpl(op, mode, writer); + } + writer->EndLine(); + for (const OpSpec::Operand& output : op.outputs()) { + writer->WriteField(output.var(), PRIVATE); + } + RenderConstructor(op, op_class, writer); + writer->EndType(); +} + +} // namespace + +OpGenerator::OpGenerator(const string& base_package, const string& output_dir, + const std::vector& api_dirs, Env* env) + : base_package_(base_package), output_dir_(output_dir), api_dirs_(api_dirs), + env_(env) { +} +Status OpGenerator::Run(const OpList& op_list, const string& lib_name) { LOG(INFO) << "Generating Java wrappers for '" << lib_name << "' operations"; - // TODO(karllessard) generate wrappers from list of ops + ApiDefMap api_map(op_list); + if (!api_dirs_.empty()) { + // Only load api files that correspond to the requested "op_list" + for (const auto& op : op_list.op()) { + for (const auto& api_def_dir : api_dirs_) { + const std::string api_def_file_pattern = + io::JoinPath(api_def_dir, "api_def_" + op.name() + ".pbtxt"); + if (env_->FileExists(api_def_file_pattern).ok()) { + TF_CHECK_OK(api_map.LoadFile(env_, api_def_file_pattern)); + } + } + } + } + api_map.UpdateDocs(); + for (const auto& op_def : op_list.op()) { + const ApiDef* api_def = api_map.GetApiDef(op_def.name()); + if (api_def->visibility() != ApiDef::SKIP) { + Status status = GenerateOp(op_def, *api_def, lib_name); + if (status != Status::OK()) { + LOG(ERROR) << "Fail to generate Java wrapper for operation \"" + << op_def.name() << "\""; + } + } + } + return Status::OK(); +} + +Status OpGenerator::GenerateOp(const OpDef& op_def, const ApiDef& api_def, + const string& lib_name) { + std::unique_ptr op; + OpParser op_parser(op_def, api_def, lib_name, base_package_); + op_parser.Parse(&op); + for (const OpSpec::Endpoint& endpoint : op->endpoints()) { + string package_path = io::JoinPath(output_dir_, + str_util::StringReplace(endpoint.type().package(), ".", "/", true)); + if (!env_->FileExists(package_path).ok()) { + TF_CHECK_OK(Env::Default()->RecursivelyCreateDir(package_path)); + } + string file_path = + io::JoinPath(package_path, endpoint.type().name() + ".java"); + std::unique_ptr file; + TF_CHECK_OK(env_->NewWritableFile(file_path, &file)); + SourceFileWriter writer(file.get()); + RenderEndpoint(*op, endpoint, &writer); + } return Status::OK(); } diff --git a/tensorflow/java/src/gen/cc/op_generator.h b/tensorflow/java/src/gen/cc/op_generator.h index 4b55ed3..19d8db9 100644 --- a/tensorflow/java/src/gen/cc/op_generator.h +++ b/tensorflow/java/src/gen/cc/op_generator.h @@ -17,34 +17,42 @@ limitations under the License. #define TENSORFLOW_JAVA_SRC_GEN_CC_OP_GENERATOR_H_ #include +#include -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/api_def.pb.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/lib/core/status.h" namespace tensorflow { namespace java { -/// \brief A generator of Java operation wrappers. -/// -/// Such generator is normally ran only once per executable, outputting -/// wrappers for the all registered operations it has been compiled with. -/// Nonetheless, it is designed to support multiple runs, giving a different -/// list of operations on each cycle. +// A generator of Java operation wrappers. +// +// Such generator is normally ran only once per executable, outputting +// wrappers for the all registered operations it has been compiled with. +// Nonetheless, it is designed to support multiple runs, giving a different +// list of operations on each cycle. class OpGenerator { public: - OpGenerator(); - virtual ~OpGenerator(); + OpGenerator(const string& base_package, const string& output_dir, + const std::vector& api_dirs, Env* env = Env::Default()); + virtual ~OpGenerator() = default; - /// \brief Generates wrappers for the given list of 'ops'. - /// - /// Output files are generated in //, - /// where 'lib_package' is derived from 'lib_name'. - Status Run(const OpList& ops, const string& lib_name, - const string& base_package, const string& output_dir); + // Generates wrappers for the given list of 'ops'. + // + // Output files are generated in //, + // where 'lib_package' is derived from 'lib_name'. + Status Run(const OpList& op_list, const string& lib_name); private: - Env* env; + string base_package_; + string output_dir_; + std::vector api_dirs_; + Env* env_; + + Status GenerateOp(const OpDef& op_def, const ApiDef& api_def, + const string& lib_name); }; } // namespace java diff --git a/tensorflow/java/src/gen/cc/op_parser.cc b/tensorflow/java/src/gen/cc/op_parser.cc new file mode 100644 index 0000000..0541e34 --- /dev/null +++ b/tensorflow/java/src/gen/cc/op_parser.cc @@ -0,0 +1,417 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/java/src/gen/cc/op_parser.h" + +namespace tensorflow { +namespace java { +namespace { + +string SnakeToCamelCase(const string& str, bool upper = false) { + string result; + bool cap = upper; + for (string::const_iterator it = str.begin(); it != str.end(); ++it) { + const char c = *it; + if (c == '_') { + cap = true; + } else if (cap) { + result += toupper(c); + cap = false; + } else { + result += c; + } + } + return result; +} + +bool IsRealNumber(DataType type) { + for (DataType dt : RealNumberTypes()) { + if (type == dt) { + return true; + } + } + return false; +} + +bool IsRealNumbers(const AttrValue& values) { + if (values.has_list()) { + for (int i = 0; i < values.list().type_size(); ++i) { + if (!IsRealNumber(values.list().type(i))) { + return false; + } + } + return true; + } + return IsRealNumber(values.type()); +} + +string ParseDocumentation(const string& text) { + std::stringstream javadoc_text; + string::const_iterator c_iter = text.cbegin(); + bool code = false; + bool emphasis = false; + bool list = false; + while (c_iter != text.cend()) { + char c = *c_iter++; + int count = 1; + switch (c) { + case '\n': + if (!code) { + // consumes all subsequent newlines, if there are more than one, + // then there are two choices: + // - if the next line starts with an asterisk, we are enumerating + // a list of items + // - otherwise, we are starting a new paragraph + for (; c_iter != text.cend() && *c_iter == '\n'; ++count, ++c_iter) {} + if (c_iter != text.cend()) { + if (count > 1) { + if (*c_iter != '*' && list) { + javadoc_text << "\n\n"; + list = false; + } else if (*c_iter == '*' && !list) { + javadoc_text << "\n
    \n
  • "; + list = true; + c_iter++; + } else { + javadoc_text << "\n

    \n"; + } + } else if (list && *c_iter == '*') { + javadoc_text << "

  • \n
  • "; + c_iter++; + } else { + javadoc_text << '\n'; + } + } + } + break; + case '`': + // consumes all subsequent backquotes, those are use enclose code. + // if there are more than 3, we are dealing with a pre-formatted block, + // otherwise it is a single-line code snippet + for (; c_iter != text.cend() && *c_iter == '`'; ++count, ++c_iter) {} + if (count >= 3) { + javadoc_text << (code ? "\n}" : "
    {@code\n");
    +      } else {
    +        javadoc_text << (code ? "}" : "{@code ");
    +      }
    +      code = !code;
    +      break;
    +    case '*':
    +      if (!code) {
    +        // consumes all subsequent asterisks, if there are more than one, then
    +        // we put the text in bold, otherwise in italic
    +        for (; c_iter != text.cend() && *c_iter == '*'; ++count, ++c_iter) {}
    +        if (count > 1) {
    +          javadoc_text << (emphasis ? "" : "");
    +        } else {
    +          javadoc_text << (emphasis ? "" : "");
    +        }
    +        emphasis = !emphasis;
    +      } else {
    +        javadoc_text << '*';
    +      }
    +      break;
    +    default:
    +      javadoc_text << c;
    +      break;
    +    }
    +  }
    +  return javadoc_text.str();
    +}
    +
    +}  // namespace
    +
    +OpParser::OpParser(const OpDef& op_def, const ApiDef& api_def,
    +    const string& lib_name, const string& base_package)
    +  : op_def_(op_def), op_api_(api_def), lib_name_(lib_name),
    +    base_package_(base_package) {
    +}
    +
    +void OpParser::Parse(std::unique_ptr* op_ptr) {
    +  visited_attrs_.clear();
    +  next_generic_ = 'T';
    +  op_ptr->reset(new OpSpec(op_api_.graph_op_name()));
    +  for (const string& next_input_name : op_api_.arg_order()) {
    +    for (int i = 0; i < op_def_.input_arg().size(); ++i) {
    +      if (op_def_.input_arg(i).name() == next_input_name) {
    +        ParseInput(op_def_.input_arg(i), op_api_.in_arg(i), op_ptr->get());
    +        break;
    +      }
    +    }
    +  }
    +  for (int i = 0; i < op_def_.attr().size(); ++i) {
    +    ParseAttribute(op_def_.attr(i), op_api_.attr(i), op_ptr->get());
    +  }
    +  for (int i = 0; i < op_def_.output_arg().size(); ++i) {
    +    ParseOutput(op_def_.output_arg(i), op_api_.out_arg(i), op_ptr->get());
    +  }
    +  BuildEndpoints(op_ptr->get());
    +}
    +
    +void OpParser::BuildEndpoints(OpSpec* op) {
    +  Javadoc op_doc = Javadoc::Create(ParseDocumentation(op_api_.summary()))
    +    .details(ParseDocumentation(op_api_.description()));
    +  std::vector op_supertypes;
    +  op_supertypes.push_back(Type::Class("PrimitiveOp", "org.tensorflow.op"));
    +  std::map op_generics;
    +  for (const OpSpec::Operand& output : op->outputs()) {
    +    // declare generic output parameters at the Op class level
    +    const Type& data_type = output.data_type();
    +    if (data_type.kind() == Type::GENERIC && !data_type.unknown()
    +        && op_generics.find(data_type.name()) == op_generics.end()) {
    +      op_generics.insert(std::make_pair(data_type.name(), &data_type));
    +      op_doc.add_param_tag("<" + data_type.name() + ">",
    +          "data type of output '" + output.var().name() + "'");
    +    }
    +    // implement the Op as an (iteration of) Operand if it has only one output
    +    if (op->outputs().size() == 1) {
    +      Type operand_inf(Type::Interface("Operand", "org.tensorflow"));
    +      operand_inf.add_parameter(data_type.unknown() ?
    +          Type::Class("Object") : data_type);
    +      op_supertypes.push_back(output.iterable() ?
    +          Type::IterableOf(operand_inf) : operand_inf);
    +    }
    +  }
    +  for (const auto& endpoint_def : op_api_.endpoint()) {
    +    std::vector name_tokens = str_util::Split(endpoint_def.name(), ".");
    +    // if the endpoint specifies a package, use it, otherwise derive it from the
    +    // op library name.
    +    string name;
    +    string package;
    +    if (name_tokens.size() > 1) {
    +      package = str_util::Lowercase(name_tokens.at(0));
    +      name = name_tokens.at(1);
    +    } else {
    +      package = str_util::StringReplace(lib_name_, "_", "", true);
    +      name = name_tokens.at(0);
    +    }
    +    Type endpoint(Type::Class(name, base_package_ + "." + package));
    +    Javadoc endpoint_doc(op_doc);
    +    for (const auto& parameter : op_generics) {
    +      endpoint.add_parameter(*parameter.second);
    +    }
    +    for (const Type& supertype : op_supertypes) {
    +      endpoint.add_supertype(supertype);
    +    }
    +    if (endpoint_def.deprecation_version() > 0) {
    +      string explanation;
    +      if (op_api_.endpoint(0).deprecation_version() == 0) {
    +        explanation = ", use {@link "
    +            + op->endpoints().at(0).type().full_name()
    +            + "} instead";
    +      } else {
    +        explanation = op_def_.deprecation().explanation();
    +      }
    +      endpoint_doc.add_tag("deprecated", explanation);
    +      endpoint.add_annotation(Annotation::Create("Deprecated"));
    +    }
    +    // only visible ops should be annotated for exposure in the Ops Graph API
    +    if (op_api_.visibility() != ApiDef::HIDDEN) {
    +      string group_name = SnakeToCamelCase(lib_name_);
    +      endpoint.add_annotation(
    +          Annotation::Create("Operator", "org.tensorflow.op.annotation")
    +            .attributes("group = \"" + group_name + "\""));
    +    }
    +    op->add_endpoint(endpoint, endpoint_doc);
    +  }
    +}
    +
    +void OpParser::ParseInput(const OpDef_ArgDef& input_def,
    +    const ApiDef::Arg& input_api, OpSpec* op) {
    +  bool iterable = false;
    +  Type data_type = DataTypeOf(input_def, &iterable);
    +  Type type = Type::Interface("Operand", "org.tensorflow")
    +    .add_parameter(data_type);
    +  if (iterable) {
    +    type = Type::IterableOf(type);
    +  }
    +  op->add_input(OpSpec::Operand(input_api.name(),
    +      Variable::Create(SnakeToCamelCase(input_api.rename_to()), type),
    +      data_type,
    +      ParseDocumentation(input_api.description()),
    +      iterable));
    +}
    +
    +void OpParser::ParseOutput(const OpDef_ArgDef& output_def,
    +    const ApiDef::Arg& output_api, OpSpec* op) {
    +  bool iterable = false;
    +  Type data_type = DataTypeOf(output_def, &iterable);
    +  Type type = Type::Class("Output", "org.tensorflow")
    +    .add_parameter(data_type);
    +  if (iterable) {
    +    type = Type::ListOf(type);
    +  }
    +  op->add_output(OpSpec::Operand(output_api.name(),
    +      Variable::Create(SnakeToCamelCase(output_api.rename_to()), type),
    +      data_type,
    +      ParseDocumentation(output_api.description()),
    +      iterable));
    +}
    +
    +void OpParser::ParseAttribute(const OpDef_AttrDef& attr_def,
    +    const ApiDef::Attr& attr_api, OpSpec* op) {
    +  // do not parse attributes already visited, they have probably been inferred
    +  // before as an input argument type
    +  if (visited_attrs_.find(attr_def.name()) != visited_attrs_.cend()) {
    +    return;
    +  }
    +  bool iterable = false;
    +  Type data_type = DataTypeOf(attr_def, &iterable);
    +  // generic attributes should be passed as an explicit type
    +  bool explicit_type = data_type.kind() == Type::GENERIC && !iterable;
    +  Type type = explicit_type ?
    +      Type::Class("Class").add_parameter(data_type) : data_type;
    +  if (iterable) {
    +    type = Type::ListOf(data_type);
    +  }
    +  OpSpec::Operand attr(attr_api.name(),
    +      Variable::Create(SnakeToCamelCase(attr_api.rename_to()), type),
    +      data_type,
    +      ParseDocumentation(attr_api.description()),
    +      iterable);
    +  // attributes with a default value are optional
    +  if (attr_api.has_default_value() && !explicit_type) {
    +    op->add_option(attr);
    +  } else {
    +    op->add_attribute(attr);
    +  }
    +  visited_attrs_.insert(std::make_pair(attr_api.name(), data_type));
    +}
    +
    +Type OpParser::DataTypeOf(const OpDef_ArgDef& arg, bool* iterable_out) {
    +  if (!arg.number_attr().empty()) {
    +    visited_attrs_.insert(std::make_pair(arg.number_attr(), Type::Int()));
    +    *iterable_out = true;
    +  }
    +  if (arg.type() != DataType::DT_INVALID) {
    +    // resolve type from DataType
    +    switch (arg.type()) {
    +      case DataType::DT_BOOL:
    +        return Type::Class("Boolean");
    +
    +      case DataType::DT_STRING:
    +        return Type::Class("String");
    +
    +      case DataType::DT_FLOAT:
    +        return Type::Class("Float");
    +
    +      case DataType::DT_DOUBLE:
    +        return Type::Class("Double");
    +
    +      case DataType::DT_UINT8:
    +        return Type::Class("UInt8", "org.tensorflow.types");
    +
    +      case DataType::DT_INT32:
    +        return Type::Class("Integer");
    +
    +      case DataType::DT_INT64:
    +        return Type::Class("Long");
    +
    +      case DataType::DT_RESOURCE:
    +        // TODO(karllessard) create a Resource utility class that could be
    +        // used to store a resource and its type (passed in a second argument).
    +        // For now, we need to force a wildcard and we will unfortunately lose
    +        // track of the resource type.
    +        return Type::Wildcard();
    +
    +      default:
    +        break;
    +    }
    +  } else {
    +    // resolve type from type attribute
    +    string attr_name = arg.type_attr();
    +    if (attr_name.empty()) {
    +      attr_name = arg.type_list_attr();
    +      if (!attr_name.empty()) {
    +        *iterable_out = true;
    +        Type type = Type::Wildcard();
    +        visited_attrs_.insert(std::make_pair(attr_name, type));
    +        return type;
    +      }
    +    }
    +    for (const auto& attr : op_def_.attr()) {
    +      if (attr.name() == attr_name) {
    +        Type type = DataTypeOf(attr, iterable_out);
    +        visited_attrs_.insert(std::make_pair(attr_name, type));
    +        return type;
    +      }
    +    }
    +  }
    +  LOG(WARNING) << "Data type for arg \"" << arg.name() << "\" is unknown";
    +  return Type::Wildcard();
    +}
    +
    +Type OpParser::DataTypeOf(const OpDef_AttrDef& attr, bool* iterable_out) {
    +  std::map::const_iterator it = visited_attrs_.find(attr.name());
    +  if (it != visited_attrs_.cend()) {
    +    return it->second;
    +  }
    +  string attr_type = attr.type();
    +  if (attr.type().compare(0, 5, "list(") == 0) {
    +    attr_type = attr_type.substr(5, attr.type().find_last_of(')') - 5);
    +    *iterable_out = true;
    +  }
    +  if (attr_type == "type") {
    +    if (*iterable_out) {
    +      return Type::Enum("DataType", "org.tensorflow");
    +    }
    +    return GetNextGenericTensorType(attr.allowed_values());
    +  }
    +  if (attr_type == "string") {
    +    return Type::Class("String");
    +  }
    +  if (attr_type == "int") {
    +    return Type::Class("Integer");
    +  }
    +  if (attr_type == "float") {
    +    return Type::Class("Float");
    +  }
    +  if (attr_type == "bool") {
    +    return Type::Class("Boolean");
    +  }
    +  if (attr_type == "shape") {
    +    return Type::Class("Shape", "org.tensorflow");
    +  }
    +  if (attr_type == "tensor") {
    +    return Type::Class("Tensor", "org.tensorflow")
    +      .add_parameter(Type::Wildcard());
    +  }
    +  LOG(WARNING) << "Data type for attribute \"" << attr_type << "\" is unknown";
    +  return *iterable_out ? Type::Wildcard() : Type::Class("Object");
    +}
    +
    +Type OpParser::GetNextGenericTensorType(const AttrValue& allowed_values)  {
    +  Type generic = Type::Generic(string(1, next_generic_));
    +  next_generic_ = (next_generic_ == 'Z') ? 'A' : next_generic_ + 1;
    +
    +  // when only real numbers are allowed, enforce that restriction in the Java by
    +  // extending the generic from java.lang.Number
    +  if (IsRealNumbers(allowed_values)) {
    +    generic.add_supertype(Type::Class("Number"));
    +  }
    +  return generic;
    +}
    +
    +}  // namespace java
    +}  // namespace tensorflow
    diff --git a/tensorflow/java/src/gen/cc/op_parser.h b/tensorflow/java/src/gen/cc/op_parser.h
    new file mode 100644
    index 0000000..4285512
    --- /dev/null
    +++ b/tensorflow/java/src/gen/cc/op_parser.h
    @@ -0,0 +1,137 @@
    +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
    +
    +Licensed under the Apache License, Version 2.0 (the "License");
    +you may not use this file except in compliance with the License.
    +You may obtain a copy of the License at
    +
    +    http://www.apache.org/licenses/LICENSE-2.0
    +
    +Unless required by applicable law or agreed to in writing, software
    +distributed under the License is distributed on an "AS IS" BASIS,
    +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    +See the License for the specific language governing permissions and
    +limitations under the License.
    +==============================================================================*/
    +
    +#ifndef TENSORFLOW_JAVA_SRC_GEN_CC_OP_PARSER_H_
    +#define TENSORFLOW_JAVA_SRC_GEN_CC_OP_PARSER_H_
    +
    +#include 
    +#include 
    +#include 
    +#include 
    +
    +#include "tensorflow/core/framework/op_def.pb.h"
    +#include "tensorflow/core/framework/api_def.pb.h"
    +#include "tensorflow/java/src/gen/cc/java_defs.h"
    +
    +namespace tensorflow {
    +namespace java {
    +
    +// Specification of a TensorFlow operation to generate.
    +//
    +// This is the result of an operation definition parsing, see OpParser::Parse().
    +class OpSpec {
    + public:
    +  class Endpoint {
    +   public:
    +    Endpoint(const Type& type, const Javadoc& javadoc)
    +      : type_(type), javadoc_(javadoc) {}
    +    const Type& type() const { return type_; }
    +    const Javadoc& javadoc() const { return javadoc_; }
    +
    +   private:
    +    Type type_;
    +    Javadoc javadoc_;
    +  };
    +
    +  class Operand {
    +   public:
    +    Operand(const string& graph_name, const Variable& var,
    +        const Type& data_type, const string& description, bool iterable)
    +     : graph_name_(graph_name), var_(var), data_type_(data_type),
    +       description_(description), iterable_(iterable) {}
    +    const string& graph_name() const { return graph_name_; }
    +    const Variable& var() const { return var_; }
    +    Variable* var_ptr() { return &var_; }
    +    const Type& data_type() const { return data_type_; }
    +    const string& description() const { return description_; }
    +    bool iterable() const { return iterable_; }
    +
    +   private:
    +    string graph_name_;
    +    Variable var_;
    +    Type data_type_;
    +    string description_;
    +    bool iterable_;
    +  };
    +
    +  explicit OpSpec(const string& graph_name) : graph_name_(graph_name) {}
    +  const string& graph_name() const { return graph_name_; }
    +  const std::vector endpoints() const { return endpoints_; }
    +  void add_endpoint(const Type& type, const Javadoc& javadoc) {
    +    endpoints_.push_back(Endpoint(type, javadoc));
    +  }
    +  const std::vector& inputs() const { return inputs_; }
    +  void add_input(const Operand& input) {
    +    inputs_.push_back(input);
    +  }
    +  const std::vector& outputs() const { return outputs_; }
    +  void add_output(const Operand& output) {
    +    outputs_.push_back(output);
    +  }
    +  const std::vector& attributes() const { return attributes_; }
    +  void add_attribute(const Operand& attribute) {
    +    attributes_.push_back(attribute);
    +  }
    +  const std::vector& options() const { return options_; }
    +  void add_option(const Operand& option) {
    +    options_.push_back(option);
    +  }
    +
    + private:
    +  string graph_name_;
    +  std::vector endpoints_;
    +  std::vector inputs_;
    +  std::vector outputs_;
    +  std::vector attributes_;
    +  std::vector options_;
    +};
    +
    +// A parser of ops proto definitions.
    +//
    +// This object parses the definition and the api of an TensorFlow operation to
    +// produce a specification that can be used for Java source code rendering.
    +class OpParser {
    + public:
    +  OpParser(const OpDef& op_def, const ApiDef& api_def, const string& lib_name,
    +      const string& base_package);
    +  virtual ~OpParser() = default;
    +
    +  // Produces an operation specification from its proto definitions.
    +  void Parse(std::unique_ptr* op_ptr);
    +
    + private:
    +  OpDef op_def_;
    +  ApiDef op_api_;
    +  string lib_name_;
    +  string base_package_;
    +  std::map visited_attrs_;
    +  char next_generic_ = 0;
    +
    +  void BuildEndpoints(OpSpec* op);
    +  void ParseInput(const OpDef_ArgDef& input_def,
    +      const ApiDef::Arg& input_api, OpSpec* op);
    +  void ParseOutput(const OpDef_ArgDef& output_def,
    +      const ApiDef::Arg& output_api, OpSpec* op);
    +  void ParseAttribute(const OpDef_AttrDef& attr_def,
    +      const ApiDef::Attr& attr_api, OpSpec* op);
    +  Type DataTypeOf(const OpDef_ArgDef& arg_def, bool *iterable_out);
    +  Type DataTypeOf(const OpDef_AttrDef& attr_def, bool *iterable_out);
    +  Type GetNextGenericTensorType(const AttrValue& allowed_values);
    +};
    +
    +}  // namespace java
    +}  // namespace tensorflow
    +
    +#endif  // TENSORFLOW_JAVA_SRC_GEN_CC_OP_PARSER_H_
    diff --git a/tensorflow/java/src/gen/cc/source_writer.cc b/tensorflow/java/src/gen/cc/source_writer.cc
    index a02f75a..b1de5af 100644
    --- a/tensorflow/java/src/gen/cc/source_writer.cc
    +++ b/tensorflow/java/src/gen/cc/source_writer.cc
    @@ -15,7 +15,7 @@ limitations under the License.
     
     #include 
     #include 
    -#include 
    +#include 
     
     #include "tensorflow/java/src/gen/cc/source_writer.h"
     
    @@ -83,20 +83,20 @@ SourceWriter& SourceWriter::Append(const StringPiece& str) {
     }
     
     SourceWriter& SourceWriter::AppendType(const Type& type) {
    -  if (type.kind() == Type::Kind::GENERIC && type.name().empty()) {
    +  if (type.unknown()) {
         Append("?");
       } else {
         Append(type.name());
    -  }
    -  if (!type.parameters().empty()) {
    -    Append("<");
    -    for (const Type& t : type.parameters()) {
    -      if (&t != &type.parameters().front()) {
    -        Append(", ");
    +    if (!type.parameters().empty()) {
    +      Append("<");
    +      for (const Type& t : type.parameters()) {
    +        if (&t != &type.parameters().front()) {
    +          Append(", ");
    +        }
    +        AppendType(t);
           }
    -      AppendType(t);
    +      Append(">");
         }
    -    Append(">");
       }
       return *this;
     }
    @@ -107,7 +107,21 @@ SourceWriter& SourceWriter::EndLine() {
       return *this;
     }
     
    -SourceWriter& SourceWriter::BeginMethod(const Method& method, int modifiers) {
    +SourceWriter& SourceWriter::BeginBlock(const string& expression) {
    +  if (!expression.empty()) {
    +    Append(expression + " {");
    +  } else {
    +    Append(newline_ ? "{" : " {");
    +  }
    +  return EndLine().Indent(2);
    +}
    +
    +SourceWriter& SourceWriter::EndBlock() {
    +  return Indent(-2).Append("}").EndLine();
    +}
    +
    +SourceWriter& SourceWriter::BeginMethod(const Method& method, int modifiers,
    +    const Javadoc* javadoc) {
       GenericNamespace* generic_namespace = PushGenericNamespace(modifiers);
       if (!method.constructor()) {
         generic_namespace->Visit(method.return_type());
    @@ -116,8 +130,9 @@ SourceWriter& SourceWriter::BeginMethod(const Method& method, int modifiers) {
         generic_namespace->Visit(v.type());
       }
       EndLine();
    -  WriteDoc(method.description(), method.return_description(),
    -      &method.arguments());
    +  if (javadoc != nullptr) {
    +    WriteJavadoc(*javadoc);
    +  }
       if (!method.annotations().empty()) {
         WriteAnnotations(method.annotations());
       }
    @@ -145,29 +160,35 @@ SourceWriter& SourceWriter::EndMethod() {
       return *this;
     }
     
    -SourceWriter& SourceWriter::BeginType(const Type& type,
    -    const std::list* dependencies, int modifiers) {
    +SourceWriter& SourceWriter::BeginType(const Type& type, int modifiers,
    +    const std::list* extra_dependencies, const Javadoc* javadoc) {
       if (!type.package().empty()) {
         Append("package ").Append(type.package()).Append(";").EndLine();
       }
    -  if (dependencies != nullptr && !dependencies->empty()) {
    -    TypeImporter type_importer(type.package());
    -    for (const Type& t : *dependencies) {
    +  TypeImporter type_importer(type.package());
    +  type_importer.Visit(type);
    +  if (extra_dependencies != nullptr) {
    +    for (const Type& t : *extra_dependencies) {
           type_importer.Visit(t);
         }
    +  }
    +  if (!type_importer.imports().empty()) {
         EndLine();
         for (const string& s : type_importer.imports()) {
           Append("import ").Append(s).Append(";").EndLine();
         }
       }
    -  return BeginInnerType(type, modifiers);
    +  return BeginInnerType(type, modifiers, javadoc);
     }
     
    -SourceWriter& SourceWriter::BeginInnerType(const Type& type, int modifiers) {
    +SourceWriter& SourceWriter::BeginInnerType(const Type& type, int modifiers,
    +    const Javadoc* javadoc) {
       GenericNamespace* generic_namespace = PushGenericNamespace(modifiers);
       generic_namespace->Visit(type);
       EndLine();
    -  WriteDoc(type.description());
    +  if (javadoc != nullptr) {
    +    WriteJavadoc(*javadoc);
    +  }
       if (!type.annotations().empty()) {
         WriteAnnotations(type.annotations());
       }
    @@ -200,14 +221,15 @@ SourceWriter& SourceWriter::EndType() {
       return *this;
     }
     
    -SourceWriter& SourceWriter::WriteFields(const std::list& fields,
    -    int modifiers) {
    -  EndLine();
    -  for (const Variable& v : fields) {
    -    WriteModifiers(modifiers);
    -    AppendType(v.type()).Append(" ").Append(v.name()).Append(";");
    -    EndLine();
    +SourceWriter& SourceWriter::WriteField(const Variable& field, int modifiers,
    +    const Javadoc* javadoc) {
    +  // If present, write field javadoc only as one brief line
    +  if (javadoc != nullptr && !javadoc->brief().empty()) {
    +    Append("/** ").Append(javadoc->brief()).Append(" */").EndLine();
       }
    +  WriteModifiers(modifiers);
    +  AppendType(field.type()).Append(" ").Append(field.name()).Append(";");
    +  EndLine();
       return *this;
     }
     
    @@ -228,39 +250,33 @@ SourceWriter& SourceWriter::WriteModifiers(int modifiers) {
       return *this;
     }
     
    -SourceWriter& SourceWriter::WriteDoc(const string& description,
    -    const string& return_description, const std::list* parameters) {
    -  if (description.empty() && return_description.empty()
    -      && (parameters == nullptr || parameters->empty())) {
    -    return *this;  // no doc to write
    -  }
    +SourceWriter& SourceWriter::WriteJavadoc(const Javadoc& javadoc) {
    +  Append("/**").Prefix(" * ").EndLine();
       bool do_line_break = false;
    -  Append("/**").EndLine().Prefix(" * ");
    -  if (!description.empty()) {
    -    Write(description).EndLine();
    +  if (!javadoc.brief().empty()) {
    +    Write(javadoc.brief()).EndLine();
         do_line_break = true;
       }
    -  if (parameters != nullptr && !parameters->empty()) {
    +  if (!javadoc.details().empty()) {
         if (do_line_break) {
    -      EndLine();
    -      do_line_break = false;
    -    }
    -    for (const Variable& v : *parameters) {
    -      Append("@param ").Append(v.name());
    -      if (!v.description().empty()) {
    -        Append(" ").Write(v.description());
    -      }
    -      EndLine();
    +      Append("

    ").EndLine(); } + Write(javadoc.details()).EndLine(); + do_line_break = true; } - if (!return_description.empty()) { + if (!javadoc.tags().empty()) { if (do_line_break) { EndLine(); - do_line_break = false; } - Append("@return ").Write(return_description).EndLine(); + for (const auto& p : javadoc.tags()) { + Append("@" + p.first); + if (!p.second.empty()) { + Append(" ").Write(p.second); + } + EndLine(); + } } - return Prefix("").Append(" **/").EndLine(); + return Prefix("").Append(" */").EndLine(); } SourceWriter& SourceWriter::WriteAnnotations( @@ -311,20 +327,19 @@ void SourceWriter::PopGenericNamespace() { void SourceWriter::TypeVisitor::Visit(const Type& type) { DoVisit(type); for (const Type& t : type.parameters()) { - DoVisit(t); + Visit(t); } for (const Annotation& t : type.annotations()) { DoVisit(t); } for (const Type& t : type.supertypes()) { - DoVisit(t); + Visit(t); } } void SourceWriter::GenericNamespace::DoVisit(const Type& type) { // ignore non-generic parameters, wildcards and generics already declared - if (type.kind() == Type::GENERIC - && !type.IsWildcard() + if (type.kind() == Type::GENERIC && !type.unknown() && generic_names_.find(type.name()) == generic_names_.end()) { declared_types_.push_back(&type); generic_names_.insert(type.name()); @@ -333,7 +348,7 @@ void SourceWriter::GenericNamespace::DoVisit(const Type& type) { void SourceWriter::TypeImporter::DoVisit(const Type& type) { if (!type.package().empty() && type.package() != current_package_) { - imports_.insert(type.package() + '.' + type.name()); + imports_.insert(type.full_name()); } } diff --git a/tensorflow/java/src/gen/cc/source_writer.h b/tensorflow/java/src/gen/cc/source_writer.h index f011acd..1f0febe 100644 --- a/tensorflow/java/src/gen/cc/source_writer.h +++ b/tensorflow/java/src/gen/cc/source_writer.h @@ -93,25 +93,22 @@ class SourceWriter { // This method appends a new opening brace to the current data and indent the // next lines according to Google Java Style Guide. The block can optionally // be preceded by an expression (e.g. Append("if(true)").BeginBlock();) - SourceWriter& BeginBlock() { - return Append(newline_ ? "{" : " {").EndLine().Indent(2); - } + SourceWriter& BeginBlock(const string& expr = ""); // Ends the current block of source code. // // This method appends a new closing brace to the current data and outdent the // next lines back to the margin used before BeginBlock() was invoked. - SourceWriter& EndBlock() { - return Indent(-2).Append("}").EndLine(); - } + SourceWriter& EndBlock(); // Begins to write a method. // // This method outputs the signature of the Java method from the data passed - // in the 'method' parameter and starts a new block. Additionnal modifiers can - // also be passed in parameter to define the accesses and the scope of this - // method. - SourceWriter& BeginMethod(const Method& method, int modifiers = 0); + // in the 'method' parameter and starts a new block. Modifiers are also passed + // in parameter to define the access scope of this method and, optionally, + // a Javadoc. + SourceWriter& BeginMethod(const Method& method, int modifiers, + const Javadoc* javadoc = nullptr); // Ends the current method. // @@ -122,22 +119,24 @@ class SourceWriter { // Begins to write the main type of a source file. // // This method outputs the declaration of the Java type from the data passed - // in the 'type' parameter and starts a new block. Additionnal modifiers can - // also be passed in parameter to define the accesses and the scope of this - // type. + // in the 'type' parameter and starts a new block. Modifiers are also passed + // in parameter to define the access scope of this type and, optionally, + // a Javadoc. // - // If not null, all types found in the 'dependencies' list will be imported - // before declaring the new type. - SourceWriter& BeginType(const Type& clazz, - const std::list* dependencies, int modifiers = 0); + // If not null, all types found in the 'extra_dependencies' list will be + // imported before declaring the new type. + SourceWriter& BeginType(const Type& clazz, int modifiers, + const std::list* extra_dependencies = nullptr, + const Javadoc* javadoc = nullptr); // Begins to write a new inner type. // // This method outputs the declaration of the Java type from the data passed - // in the 'type' parameter and starts a new block. Additionnal modifiers can - // also be passed in parameter to define the accesses and the scope of this - // type. - SourceWriter& BeginInnerType(const Type& type, int modifiers = 0); + // in the 'type' parameter and starts a new block. Modifiers are also passed + // in parameter to define the accesses and the scope of this type and, + // optionally, a Javadoc. + SourceWriter& BeginInnerType(const Type& type, int modifiers, + const Javadoc* javadoc = nullptr); // Ends the current type. // @@ -145,13 +144,13 @@ class SourceWriter { // BeginType() or BeginInnerType() prior to this. SourceWriter& EndType(); - // Writes a list of variables as fields of a type. + // Writes a variable as fields of a type. // // This method must be called within the definition of a type (see BeginType() - // or BeginInnerType()). Additional modifiers can also be passed in parameter - // to define the accesses and the scope of those fields. - SourceWriter& WriteFields(const std::list& fields, - int modifiers = 0); + // or BeginInnerType()). Modifiers are also be passed in parameter to define + // the accesses and the scope of this field and, optionally, a Javadoc. + SourceWriter& WriteField(const Variable& field, int modifiers, + const Javadoc* javadoc = nullptr); protected: virtual void DoAppend(const StringPiece& str) = 0; @@ -207,9 +206,7 @@ class SourceWriter { std::stack generic_namespaces_; SourceWriter& WriteModifiers(int modifiers); - SourceWriter& WriteDoc(const string& description, - const string& return_description = "", - const std::list* parameters = nullptr); + SourceWriter& WriteJavadoc(const Javadoc& javadoc); SourceWriter& WriteAnnotations(const std::list& annotations); SourceWriter& WriteGenerics(const std::list& generics); GenericNamespace* PushGenericNamespace(int modifiers); diff --git a/tensorflow/java/src/gen/cc/source_writer_test.cc b/tensorflow/java/src/gen/cc/source_writer_test.cc index 4bce2fe..8bd42d9 100644 --- a/tensorflow/java/src/gen/cc/source_writer_test.cc +++ b/tensorflow/java/src/gen/cc/source_writer_test.cc @@ -250,7 +250,7 @@ TEST(StreamTest, Types) { .AppendType(generic).Append(", ") .AppendType(Type::ListOf(generic)).Append(", ") .AppendType(Type::ListOf(Type::IterableOf(generic))).Append(", ") - .AppendType(Type::ListOf(Type::Generic())); + .AppendType(Type::ListOf(Type::Wildcard())); const char* expected = "int, String, T, List, List>, List"; @@ -282,7 +282,7 @@ TEST(WriteType, SimpleClass) { SourceBufferWriter writer; Type clazz = Type::Class("Test", "org.tensorflow"); - writer.BeginType(clazz, nullptr, PUBLIC).EndType(); + writer.BeginType(clazz, PUBLIC).EndType(); const char* expected = "package org.tensorflow;\n\n" @@ -300,7 +300,7 @@ TEST(WriteType, SimpleClassWithDependencies) { deps.push_back(Type::Class("SamePackageType", "org.tensorflow")); deps.push_back(Type::Class("NoPackageType")); - writer.BeginType(clazz, &deps, PUBLIC).EndType(); + writer.BeginType(clazz, PUBLIC, &deps).EndType(); const char* expected = "package org.tensorflow;\n\n" @@ -313,18 +313,21 @@ TEST(WriteType, SimpleClassWithDependencies) { TEST(WriteType, AnnotatedAndDocumentedClass) { SourceBufferWriter writer; Type clazz = Type::Class("Test", "org.tensorflow"); - clazz.description("This class has a\n

    \nmultiline description."); + Javadoc clazz_doc; + clazz_doc.brief("Javadoc test") + .details("This is a\nmultiline description."); clazz.add_annotation(Annotation::Create("Bean")); clazz.add_annotation(Annotation::Create("SuppressWarnings") .attributes("\"rawtypes\"")); - writer.BeginType(clazz, nullptr, PUBLIC).EndType(); + writer.BeginType(clazz, PUBLIC, nullptr, &clazz_doc).EndType(); const char* expected = "package org.tensorflow;\n\n" "/**\n" - " * This class has a\n" + " * Javadoc test\n" " *

    \n" + " * This is a\n" " * multiline description.\n" " **/\n" "@Bean\n" @@ -339,7 +342,7 @@ TEST(WriteType, ParameterizedClass) { clazz.add_parameter(Type::Generic("T")); clazz.add_parameter(Type::Generic("U").add_supertype(Type::Class("Number"))); - writer.BeginType(clazz, nullptr, PUBLIC).EndType(); + writer.BeginType(clazz, PUBLIC).EndType(); const char* expected = "package org.tensorflow;\n\n" @@ -358,7 +361,7 @@ TEST(WriteType, ParameterizedClassAndSupertypes) { clazz.add_supertype(Type::Interface("Runnable")); clazz.add_supertype(Type::Class("SuperTest").add_parameter(type_t)); - writer.BeginType(clazz, nullptr, PUBLIC).EndType(); + writer.BeginType(clazz, PUBLIC).EndType(); const char* expected = "package org.tensorflow;\n\n" @@ -372,24 +375,24 @@ TEST(WriteType, ParameterizedClassFields) { Type clazz = Type::Class("Test", "org.tensorflow"); Type type_t = Type::Generic("T").add_supertype(Type::Class("Number")); clazz.add_parameter(type_t); - std::list static_fields; - static_fields.push_back(Variable::Create("field1", Type::Class("String"))); - std::list member_fields; - member_fields.push_back(Variable::Create("field2", Type::Class("String"))); - member_fields.push_back(Variable::Create("field3", type_t)); - - writer.BeginType(clazz, nullptr, PUBLIC) - .WriteFields(static_fields, STATIC | PUBLIC | FINAL) - .WriteFields(member_fields, PRIVATE) + Variable field1 = Variable::Create("field1", Type::Class("String")); + Variable field2 = Variable::Create("field2", Type::Class("String")); + Variable field3 = Variable::Create("field3", type_t); + Javadoc field3_doc; + field3_doc.brief("This variable is documented"); + + writer.BeginType(clazz, PUBLIC) + .WriteField(field1, STATIC | PUBLIC | FINAL) + .WriteField(field2, PRIVATE) + .WriteField(field3, PRIVATE, &field3_doc) .EndType(); const char* expected = "package org.tensorflow;\n\n" "public class Test {\n" - " \n" " public static final String field1;\n" - " \n" " private String field2;\n" + " /** This variable is documented */\n" " private T field3;\n" "}\n"; ASSERT_STREQ(expected, writer.str().data()); @@ -400,7 +403,7 @@ TEST(WriteType, SimpleInnerClass) { Type clazz = Type::Class("Test", "org.tensorflow"); Type inner_class = Type::Class("InnerTest"); - writer.BeginType(clazz, nullptr, PUBLIC) + writer.BeginType(clazz, PUBLIC) .BeginInnerType(inner_class, PUBLIC) .EndType() .EndType(); @@ -423,7 +426,7 @@ TEST(WriteType, StaticParameterizedInnerClass) { Type inner_class = Type::Class("InnerTest"); inner_class.add_parameter(type_t); - writer.BeginType(clazz, nullptr, PUBLIC) + writer.BeginType(clazz, PUBLIC) .BeginInnerType(inner_class, PUBLIC | STATIC) .EndType() .EndType(); @@ -443,7 +446,7 @@ TEST(WriteMethod, SimpleMethod) { Type clazz = Type::Class("Test", "org.tensorflow"); Method method = Method::Create("doNothing", Type::Void()); - writer.BeginType(clazz, nullptr, PUBLIC) + writer.BeginType(clazz, PUBLIC) .BeginMethod(method, PUBLIC).EndMethod() .EndType(); @@ -461,13 +464,15 @@ TEST(WriteMethod, AnnotatedAndDocumentedMethod) { SourceBufferWriter writer; Type clazz = Type::Class("Test", "org.tensorflow"); Method method = Method::Create("doNothing", Type::Void()); - method.description("This method has a\n

    \nmultiline description."); + Javadoc method_doc; + method_doc.brief("Javadoc test") + .details("This method has a\nmultiline description."); method.add_annotation(Annotation::Create("Override")); method.add_annotation(Annotation::Create("SuppressWarnings") .attributes("\"rawtypes\"")); - writer.BeginType(clazz, nullptr, PUBLIC) - .BeginMethod(method, PUBLIC).EndMethod() + writer.BeginType(clazz, PUBLIC) + .BeginMethod(method, PUBLIC, &method_doc).EndMethod() .EndType(); const char* expected = @@ -475,8 +480,9 @@ TEST(WriteMethod, AnnotatedAndDocumentedMethod) { "public class Test {\n" " \n" " /**\n" - " * This method has a\n" + " * Javadoc test\n" " *

    \n" + " * This method has a\n" " * multiline description.\n" " **/\n" " @Override\n" @@ -490,16 +496,18 @@ TEST(WriteMethod, AnnotatedAndDocumentedMethod) { TEST(WriteMethod, DocumentedMethodWithArguments) { SourceBufferWriter writer; Type clazz = Type::Class("Test", "org.tensorflow"); + Variable reverse = Variable::Create("reverse", Type::Boolean()); Method method = Method::Create("boolToInt", Type::Int()); - method.description("Converts a boolean to an int"); - method.return_description("int value for this boolean"); method.add_argument(Variable::Create("b", Type::Boolean())); - Variable reverse = Variable::Create("reverse", Type::Boolean()); - reverse.description("if true, value is reversed"); method.add_argument(reverse); - - writer.BeginType(clazz, nullptr, PUBLIC) - .BeginMethod(method, PUBLIC) + Javadoc method_doc; + method_doc.brief("Converts a boolean to an int") + .details("This method will convert\na boolean to an int") + .add_param_tag(reverse.name(), "if true, value is reversed") + .add_tag("return", "int value for this boolean"); + + writer.BeginType(clazz, PUBLIC) + .BeginMethod(method, PUBLIC, &method_doc) .Append("if (b && !reverse)") .BeginBlock() .Append("return 1;").EndLine() @@ -514,8 +522,10 @@ TEST(WriteMethod, DocumentedMethodWithArguments) { " \n" " /**\n" " * Converts a boolean to an int\n" + " *

    \n" + " * This method will convert\n" + " * a boolean to an int\n" " * \n" - " * @param b\n" " * @param reverse if true, value is reversed\n" " * @return int value for this boolean\n" " **/\n" @@ -536,7 +546,7 @@ TEST(WriteMethod, ParameterizedMethod) { clazz.add_parameter(type_t); Method method = Method::Create("doNothing", type_t); - writer.BeginType(clazz, nullptr, PUBLIC) + writer.BeginType(clazz, PUBLIC) .BeginMethod(method, PUBLIC) .Append("return null;").EndLine() .EndMethod() @@ -560,7 +570,7 @@ TEST(WriteMethod, StaticParameterizedMethod) { clazz.add_parameter(type_t); Method method = Method::Create("doNothing", type_t); - writer.BeginType(clazz, nullptr, PUBLIC) + writer.BeginType(clazz, PUBLIC) .BeginMethod(method, PUBLIC | STATIC) .Append("return null;").EndLine() .EndMethod() diff --git a/tensorflow/java/src/gen/gen_ops.bzl b/tensorflow/java/src/gen/gen_ops.bzl index a6650fc..1e7899c 100644 --- a/tensorflow/java/src/gen/gen_ops.bzl +++ b/tensorflow/java/src/gen/gen_ops.bzl @@ -1,9 +1,11 @@ # -*- Python -*- -load("//tensorflow:tensorflow.bzl", - "tf_binary_additional_srcs", - "tf_cc_binary", - "tf_copts") +load( + "//tensorflow:tensorflow.bzl", + "tf_binary_additional_srcs", + "tf_cc_binary", + "tf_copts", +) # Given a list of "ops_libs" (a list of files in the core/ops directory # without their .cc extensions), generate Java wrapper code for all operations @@ -27,16 +29,31 @@ def tf_java_op_gen_srcjar(name, ops_libs_pkg="//tensorflow/core", out_dir="ops/", out_src_dir="src/main/java/", + api_def_srcs=[], visibility=["//tensorflow/java:__pkg__"]): gen_tools = [] gen_cmds = ["rm -rf $(@D)"] # Always start from fresh when generating source files + srcs = api_def_srcs[:] # Construct an op generator binary for each ops library. for ops_lib in ops_libs: gen_lib = ops_lib[:ops_lib.rfind("_")] out_gen_tool = out_dir + ops_lib + "_gen_tool" + if not api_def_srcs: + api_def_args_str = "," + else: + api_def_args = [] + for api_def_src in api_def_srcs: + # Add directory of the first ApiDef source to args. + # We are assuming all ApiDefs in a single api_def_src are in the + # same directory. + api_def_args.append( + " $$(dirname $$(echo $(locations " + api_def_src + + ") | cut -d\" \" -f1))") + api_def_args_str = ",".join(api_def_args) + tf_cc_binary( name=out_gen_tool, copts=tf_copts(), @@ -48,7 +65,8 @@ def tf_java_op_gen_srcjar(name, gen_cmds += ["$(location :" + out_gen_tool + ")" + " --output_dir=$(@D)/" + out_src_dir + " --lib_name=" + gen_lib + - " --base_package=" + gen_base_package] + " --base_package=" + gen_base_package + + " " + api_def_args_str] # Generate a source archive containing generated code for these ops. gen_srcjar = out_dir + name + ".srcjar" @@ -57,6 +75,7 @@ def tf_java_op_gen_srcjar(name, gen_tools += tf_binary_additional_srcs() native.genrule( name=name, + srcs=srcs, outs=[gen_srcjar], tools=gen_tools, cmd="&&".join(gen_cmds)) diff --git a/tensorflow/java/src/gen/resources/license.snippet.java b/tensorflow/java/src/gen/resources/license.snippet.java new file mode 100644 index 0000000..90285ec --- /dev/null +++ b/tensorflow/java/src/gen/resources/license.snippet.java @@ -0,0 +1,14 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ -- 2.7.4