From 10fa513e15691681903a472d251fa8eadca1f239 Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Thu, 31 May 2018 11:43:37 -0700 Subject: [PATCH] [XLA] Make HloInstruction::backend_config() a JSON-encoded protobuf. PiperOrigin-RevId: 198754463 --- tensorflow/compiler/xla/BUILD | 31 ---- tensorflow/compiler/xla/scanner.cc | 197 --------------------- tensorflow/compiler/xla/scanner.h | 102 ----------- tensorflow/compiler/xla/scanner_test.cc | 124 ------------- tensorflow/compiler/xla/service/BUILD | 1 + tensorflow/compiler/xla/service/compiler.cc | 5 +- tensorflow/compiler/xla/service/compiler.h | 6 +- .../compiler/xla/service/hlo_graph_dumper.cc | 4 +- tensorflow/compiler/xla/service/hlo_instruction.cc | 36 +++- tensorflow/compiler/xla/service/hlo_instruction.h | 36 +++- tensorflow/compiler/xla/tools/parser/hlo_parser.cc | 2 +- .../compiler/xla/tools/parser/hlo_parser_test.cc | 2 +- tensorflow/core/BUILD | 52 ++++-- tensorflow/core/platform/default/build_config.bzl | 3 + .../core/platform/default/human_readable_json.cc | 54 ++++++ tensorflow/core/platform/human_readable_json.h | 37 ++++ 16 files changed, 202 insertions(+), 490 deletions(-) delete mode 100644 tensorflow/compiler/xla/scanner.cc delete mode 100644 tensorflow/compiler/xla/scanner.h delete mode 100644 tensorflow/compiler/xla/scanner_test.cc create mode 100644 tensorflow/core/platform/default/human_readable_json.cc create mode 100644 tensorflow/core/platform/human_readable_json.h diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index c08db7e..c6deb95 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -500,37 +500,6 @@ cc_library( ) cc_library( - name = "scanner", - srcs = ["scanner.cc"], - hdrs = ["scanner.h"], - visibility = [":internal"], - deps = [ - ":status", - ":status_macros", - ":types", - ":util", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - ], -) - -tf_cc_test( - name = "scanner_test", - srcs = ["scanner_test.cc"], - deps = [ - ":scanner", - ":status", - ":status_macros", - ":test", - ":types", - ":util", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:test_main", - ], -) - -cc_library( name = "text_literal_reader", srcs = ["text_literal_reader.cc"], hdrs = ["text_literal_reader.h"], diff --git a/tensorflow/compiler/xla/scanner.cc b/tensorflow/compiler/xla/scanner.cc deleted file mode 100644 index f23a141..0000000 --- a/tensorflow/compiler/xla/scanner.cc +++ /dev/null @@ -1,197 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/scanner.h" - -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/strcat.h" - -namespace xla { -namespace { - -// Returns true if c can be the first character in an identifier. -bool IsIdentifierFirst(int c) { return std::isalpha(c) || c == '_'; } - -// Returns true if c can be the non-first character in an identifier. -bool IsIdentifierLater(int c) { return std::isalnum(c) || c == '_'; } - -// Returns true if str is an identifier. -bool IsIdentifier(tensorflow::StringPiece str) { - if (str.empty() || !IsIdentifierFirst(str[0])) { - return false; - } - for (int64 i = 1; i < str.size(); ++i) { - if (!IsIdentifierLater(str[i])) { - return false; - } - } - return true; -} - -} // namespace - -Scanner::Scanner(tensorflow::StringPiece input) : input_(input), position_(0) {} - -bool Scanner::ok() const { return status().ok(); } - -const Status& Scanner::status() const { return status_; } - -bool Scanner::Match(tensorflow::StringPiece match) { - SkipWhitespace(); - if (ok() && position_ + match.size() <= input_.size() && - std::equal(match.begin(), match.end(), input_.begin() + position_)) { - SkipChars(match.size()); - - VLOG(10) << "Matched \"" << match << "\""; - return true; - } else { - return false; - } -} - -void Scanner::Expect(tensorflow::StringPiece expect) { - if (!Match(expect)) { - SetError(tensorflow::strings::StrCat("Expected \"", expect, "\".")); - } -} - -bool Scanner::MatchReadIdentifier(string* identifier) { - SkipWhitespace(); - if (!IsIdentifierFirst(PeekChar())) { - return false; - } - identifier->clear(); - do { - *identifier += ReadChar(); - } while (IsIdentifierLater(PeekChar())); - - VLOG(10) << "Read identifier " << identifier; - CHECK(IsIdentifier(*identifier)); - return true; -} - -string Scanner::ReadIdentifier() { - string identifier; - if (!MatchReadIdentifier(&identifier)) { - SetError("Expected identifier."); - } - return identifier; -} - -void Scanner::ExpectIdentifier(tensorflow::StringPiece expect) { - CHECK(IsIdentifier(expect)); - - string identifier; - if (!MatchReadIdentifier(&identifier)) { - SetError(tensorflow::strings::StrCat("Expected identifier ", expect, ".")); - } - if (identifier != expect) { - SetError(tensorflow::strings::StrCat("Expected identifier ", expect, - ", but got ", identifier, ".")); - } -} - -// Matches the end of the input, also known as End Of File (EOF). -bool Scanner::MatchEof() { - SkipWhitespace(); - return PeekChar() == EOF; -} - -void Scanner::ExpectEof() { - if (!MatchEof()) { - SetError("Expected end of input."); - } -} - -// Reads a vector of the format "(1, 2, 3)". -std::vector Scanner::ReadIntVector() { - std::vector ints; - Expect("("); - if (!Match(")") && ok()) { - ints.push_back(ReadInt()); - while (Match(",")) { - ints.push_back(ReadInt()); - } - Expect(")"); - } - - VLOG(10) << "Read int vector with " << ints.size() << " elements."; - return ints; -} - -int64 Scanner::ReadInt() { - bool negative = Match("-"); - if (!PeekDigit()) { - SetError("Expected integer."); - return 0; - } - - int64 integer = 0; - do { - integer = (ReadChar() - '0') + integer * 10; - } while (PeekDigit()); - integer = negative ? -integer : integer; - - VLOG(10) << "Read integer " << integer; - return integer; -} - -void Scanner::SkipWhitespace() { - while (PeekWhitespace()) { - SkipChars(1); - } -} - -int Scanner::ReadChar() { - int c = PeekChar(); - SkipChars(1); - - VLOG(20) << "Read char " << c; - return c; -} - -int Scanner::PeekChar() const { - return ok() && position_ < input_.size() ? input_[position_] : EOF; -} - -bool Scanner::PeekDigit() const { - // Do not use std::isdigit since it depends on the locale and we do not - // handle any digits beyond 0-9. - const char c = PeekChar(); - return '0' <= c && c <= '9'; -} - -bool Scanner::PeekAlnum() const { return std::isalnum(PeekChar()); } - -bool Scanner::PeekWhitespace() const { return std::isspace(PeekChar()); } - -void Scanner::SkipChars(int64 count) { - CHECK_GE(count, 0); - position_ += count; -} - -void Scanner::SetError(string error_message) { - // Only the first error is recorded since any later errors will likely be a - // consequence of the first error. - if (ok()) { - status_ = InvalidArgumentStrCat(std::move(error_message)); - position_ = input_.size(); - VLOG(10) << "Failed scanner with error " << status_.ToString(); - } else { - VLOG(10) << "Error on already failed scanner is " << error_message; - } -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/scanner.h b/tensorflow/compiler/xla/scanner.h deleted file mode 100644 index 86b04ae..0000000 --- a/tensorflow/compiler/xla/scanner.h +++ /dev/null @@ -1,102 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SCANNER_H_ -#define TENSORFLOW_COMPILER_XLA_SCANNER_H_ - -#include "tensorflow/compiler/xla/status.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" - -namespace xla { - -// Simple class for parsing data. The concepts for the interface are: -// -// Match(x): Returns true if x is next in the input and in that case skips -// past it. Otherwise returns false. -// -// Expect(x): As Match(x), but requires x to be next in the input. -// -// MatchReadX(x): Returns true if an X is next in the input and in that case -// skips past it and assigns it to x. Otherwise returns false. -// -// ReadX(): As ReadMatchX(), but requires an X to be next in the input and -// returns it. -// -// PeekX(): Returns true if an X is next in the input and does not skip -// past it either way. -// -// All of these, except those that work on individual characters, skip -// whitespace. -// -// If a requirement is not met, the error is available in status(). A Scanner -// with a failed status() will behave as though the rest of the input is EOF and -// will not record further errors after that point. -class Scanner { - public: - Scanner(tensorflow::StringPiece input); - - bool ok() const; - const Status& status() const; - - bool Match(tensorflow::StringPiece match); - void Expect(tensorflow::StringPiece expect); - - // Match-reads an identifier. An identifier starts with an alphabetic - // character or an underscore followed by any number of characters that are - // each alphanumeric or underscore. - bool MatchReadIdentifier(string* identifier); - - string ReadIdentifier(); - - void ExpectIdentifier(tensorflow::StringPiece expect); - - // Matches the end of the input, also known as End Of File (EOF). - bool MatchEof(); - void ExpectEof(); - - // Reads a vector of the format "(1, 4, 5)". - std::vector ReadIntVector(); - - // Reads an integer. Can start with a minus but not a plus. - int64 ReadInt(); - - // Keeps skipping until encountering a non-whitespace character. - void SkipWhitespace(); - - // *** Below here are character-level methods that do not skip whitespace. - - int ReadChar(); - int PeekChar() const; - bool PeekDigit() const; - bool PeekAlnum() const; - bool PeekWhitespace() const; - - // Skip past the next count characters. - void SkipChars(int64 count); - - private: - // Sets a failed status. The input is in effect replaced with EOF after - // this. Only the first error is recorded. - void SetError(string error_message); - - const tensorflow::StringPiece input_; - int64 position_; - Status status_; -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SCANNER_H_ diff --git a/tensorflow/compiler/xla/scanner_test.cc b/tensorflow/compiler/xla/scanner_test.cc deleted file mode 100644 index 10cd0c6..0000000 --- a/tensorflow/compiler/xla/scanner_test.cc +++ /dev/null @@ -1,124 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// TODO(b/80179519): Fix open source build for real. -#if 0 -#include "tensorflow/compiler/xla/scanner.h" - -#include - -#include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/env.h" - -namespace xla { -namespace { - -TEST(Scanner, Empty) { - Scanner scanner(""); - - EXPECT_EQ(scanner.PeekChar(), EOF); - EXPECT_TRUE(scanner.MatchEof()); - EXPECT_TRUE(scanner.Match("")); - EXPECT_FALSE(scanner.Match("1")); - EXPECT_TRUE(scanner.ok()); -} - -TEST(Scanner, Prefix) { - Scanner scanner("1234 5"); - EXPECT_FALSE(scanner.MatchEof()); - EXPECT_TRUE(scanner.Match("12")); - EXPECT_TRUE(scanner.Match("34 ")); - EXPECT_FALSE(scanner.MatchEof()); - EXPECT_FALSE(scanner.Match("5 ")); - EXPECT_TRUE(scanner.Match("5")); - EXPECT_TRUE(scanner.MatchEof()); -} - -TEST(Scanner, Whitespace) { - Scanner scanner(" \t\n\r 1\t2\n\n"); - - EXPECT_FALSE(scanner.Match(" ")); - EXPECT_TRUE(scanner.Match("1")); - EXPECT_TRUE(scanner.Match("2")); - EXPECT_TRUE(scanner.MatchEof()); - EXPECT_TRUE(scanner.ok()); -} - -TEST(Scanner, Fail) { - Scanner scanner("153 4q"); - - scanner.Expect("5"); - EXPECT_FALSE(scanner.ok()); - EXPECT_FALSE(scanner.status().ok()); - - EXPECT_TRUE(scanner.MatchEof()); -} - -TEST(Scanner, Identifier) { - Scanner scanner("1 q1 _1_ _1a= qqb"); - - string identifier = "foo"; - EXPECT_FALSE(scanner.MatchReadIdentifier(&identifier)); - EXPECT_EQ(identifier, "foo"); - scanner.Match("1"); - - EXPECT_TRUE(scanner.MatchReadIdentifier(&identifier)); - EXPECT_EQ(identifier, "q1"); - - scanner.ExpectIdentifier("_1_"); - EXPECT_TRUE(scanner.ok()); - - scanner.ExpectIdentifier("_1a"); - EXPECT_TRUE(scanner.ok()); - - // The = after _1a is not included in the identifier. - scanner.Expect("="); - - // The expected identifier matches a prefix but is not the full identifier in - // the input. - EXPECT_TRUE(scanner.ok()); - scanner.ExpectIdentifier("qq"); - EXPECT_FALSE(scanner.ok()); -} - -TEST(Scanner, Int) { - Scanner scanner("1_2 3% -1 124345 -363 0 -0"); - EXPECT_EQ(1, scanner.ReadInt()); - EXPECT_TRUE(scanner.Match("_")); - EXPECT_EQ(2, scanner.ReadInt()); - EXPECT_EQ(3, scanner.ReadInt()); - EXPECT_TRUE(scanner.Match("%")); - EXPECT_EQ(-1, scanner.ReadInt()); - EXPECT_EQ(124345, scanner.ReadInt()); - EXPECT_EQ(-363, scanner.ReadInt()); - EXPECT_EQ(0, scanner.ReadInt()); - EXPECT_EQ(0, scanner.ReadInt()); - EXPECT_TRUE(scanner.MatchEof()); -} - -TEST(Scanner, IntVector) { - Scanner scanner("()(0) (-1,2) ( 3 , 4 )"); - EXPECT_THAT(scanner.ReadIntVector(), testing::IsEmpty()); - EXPECT_THAT(scanner.ReadIntVector(), testing::ElementsAre(0)); - EXPECT_THAT(scanner.ReadIntVector(), testing::ElementsAre(-1, 2)); - EXPECT_THAT(scanner.ReadIntVector(), testing::ElementsAre(3, 4)); - EXPECT_TRUE(scanner.MatchEof()); - EXPECT_TRUE(scanner.ok()); -} - -} // namespace -} // namespace xla -#endif diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index b954bbd..aa41631 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -309,6 +309,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:human_readable_json", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", ], diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc index 31f84e8..6f06bba 100644 --- a/tensorflow/compiler/xla/service/compiler.cc +++ b/tensorflow/compiler/xla/service/compiler.cc @@ -28,8 +28,9 @@ namespace xla { /* static */ tensorflow::mutex Compiler::platform_compiler_mutex_( tensorflow::LINKER_INITIALIZED); -std::vector Compiler::ComputeBackendConfigs( - const HloInstruction& hlo, se::StreamExecutor* executor) const { +std::vector> +Compiler::ComputeBackendConfigs(const HloInstruction& hlo, + se::StreamExecutor* executor) const { CHECK(executor != nullptr); return {}; } diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index c39db58..6c52ffd 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -36,6 +36,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -161,8 +162,9 @@ class Compiler { // // The stream executor is passed in to provide information about the hardware // that the backend configurations would be targeting. - virtual std::vector ComputeBackendConfigs( - const HloInstruction& hlo, se::StreamExecutor* executor) const; + virtual std::vector> + ComputeBackendConfigs(const HloInstruction& hlo, + se::StreamExecutor* executor) const; // Compiles the HLO module for ahead-of-time execution. This is intended for // use in static compilation. diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 672b1c0..05adb45 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -1085,11 +1085,11 @@ string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) { string HloDotDumper::GetInstructionNodeBackendConfig( const HloInstruction* instr) { - if (!show_backend_config_ || instr->backend_config().empty()) { + if (!show_backend_config_ || instr->raw_backend_config_string().empty()) { return ""; } - return StrCat("backend_config=\"", instr->backend_config(), "\""); + return StrCat("backend_config=\"", instr->raw_backend_config_string(), "\""); } string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index c55e5cf..a68075e 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -41,6 +41,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/human_readable_json.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -110,7 +111,7 @@ StatusOr> HloInstruction::CreateFromProto( instruction->name_ = proto.name(); instruction->metadata_ = proto.metadata(); - instruction->set_backend_config(proto.backend_config()); + instruction->backend_config_ = proto.backend_config(); if (proto.has_literal()) { TF_ASSIGN_OR_RETURN(instruction->literal_, Literal::CreateFromProto(proto.literal())); @@ -1521,7 +1522,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( } SetupDerivedInstruction(clone.get()); clone->set_parent(parent_); - clone->set_backend_config(backend_config()); + clone->set_raw_backend_config_string(backend_config_); if (context != nullptr) { context->MapInstruction(this, clone.get()); clone->ReplaceCalledComputations([&](HloComputation* callee) { @@ -2182,8 +2183,8 @@ string HloInstruction::ToStringWithCanonicalNameMap( !metadata_.source_file().empty())) { StrAppend(&result, ", metadata={", xla::OpMetadataToString(metadata_), "}"); } - if (options.print_backend_config() && !backend_config().empty()) { - StrAppend(&result, ", backend_config=\"", CEscape(backend_config()), "\""); + if (options.print_backend_config() && !backend_config_.empty()) { + StrAppend(&result, ", backend_config=\"", CEscape(backend_config_), "\""); } return result; } @@ -2463,7 +2464,7 @@ HloInstructionProto HloInstruction::ToProto() const { } *proto.mutable_metadata() = metadata_; - proto.set_backend_config(backend_config()); + proto.set_backend_config(backend_config_); if (literal_ != nullptr) { *proto.mutable_literal() = literal_->ToProto(); } @@ -3526,6 +3527,31 @@ bool HloInstruction::CouldBeBitcast() const { } } +Status HloInstruction::GetBackendConfigInternal( + tensorflow::protobuf::Message* proto) const { + proto->Clear(); + + // Empty string does not parse as valid JSON, but it's a valid backend config, + // corresponding to the empty proto. + if (backend_config_.empty()) { + return Status::OK(); + } + return tensorflow::HumanReadableJsonToProto(backend_config_, proto); +} + +Status HloInstruction::set_backend_config( + const tensorflow::protobuf::Message& proto) { + TF_ASSIGN_OR_RETURN(backend_config_, BackendConfigToRawString(proto)); + return Status::OK(); +} + +/* static */ StatusOr HloInstruction::BackendConfigToRawString( + const tensorflow::protobuf::Message& proto) { + string ret; + TF_RETURN_IF_ERROR(tensorflow::ProtoToHumanReadableJson(proto, &ret)); + return ret; +} + HloModule* HloInstruction::GetModule() const { if (parent_) { return parent_->parent(); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 8119c35..72b9d54 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -52,6 +52,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/iterator_range.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -1446,12 +1447,33 @@ class HloInstruction { // this field and they cannot interpret it due to its meaning being backend // specific. // - // TODO(b/78194644): Introduce structured configuration format as per - // go/xla-heuristics. - const string& backend_config() const { return backend_config_; } - void set_backend_config(string backend_config) { - backend_config_ = std::move(backend_config); + // ConfigProto should be a protobuf Message type. + template + StatusOr backend_config() const { + ConfigProto proto; + TF_RETURN_IF_ERROR(GetBackendConfigInternal(&proto)); + return std::move(proto); } + Status set_backend_config(const tensorflow::protobuf::Message& proto); + + // Getter/setter for raw JSON-encoded backend config. Prefer the + // functions above that deal in proto Messages where possible. + const string& raw_backend_config_string() const { return backend_config_; } + void set_raw_backend_config_string(string config_str) { + backend_config_ = std::move(config_str); + } + + // Returns a string representation of a proto in the format used by + // raw_backend_config_string. + // + // This is morally equivalent to: + // + // HloInstruction instr; + // TF_RETURN_IF_ERROR(instr.set_backend_config(proto)); + // return instr.raw_backend_config_string(); + // + static StatusOr BackendConfigToRawString( + const tensorflow::protobuf::Message& proto); // Sets the debug metadata for this instruction. void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; } @@ -1573,6 +1595,10 @@ class HloInstruction { // Returns how this instruction uses elements of its `i`th operand. UseKind OperandElementUse(int64 i) const; + // Helper for implementing backend_config(). Parses backend_config_ into the + // given proto. + Status GetBackendConfigInternal(tensorflow::protobuf::Message* proto) const; + int unique_id_; // Unique to this HloInstruction within a HloModule // Opcode for this instruction. diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc index 3c1d63a..ef10ca4 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -1127,7 +1127,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, instruction->set_metadata(*metadata); } if (backend_config) { - instruction->set_backend_config(std::move(*backend_config)); + instruction->set_raw_backend_config_string(std::move(*backend_config)); } return AddInstruction(name, instruction, name_loc); } // NOLINT(readability/fn_size) diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc index f7a27cf..3c5957b 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc @@ -1025,7 +1025,7 @@ ENTRY %configuration_test() -> s32[] { EXPECT_EQ("foo bar", result.ValueOrDie() ->entry_computation() ->root_instruction() - ->backend_config()); + ->raw_backend_config_string()); } TEST_F(HloParserTest, LiteralDimensionsMismatch_1) { diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 3286f85..74f74af 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -101,42 +101,43 @@ load("//tensorflow:tensorflow.bzl", "tf_cuda_only_cc_test") # For platform specific build config load( "//tensorflow/core:platform/default/build_config.bzl", - "tf_platform_hdrs", - "tf_platform_srcs", - "tf_proto_library", - "tf_proto_library_cc", "tf_additional_all_protos", + "tf_additional_cloud_kernel_deps", + "tf_additional_cloud_op_deps", "tf_additional_core_deps", + "tf_additional_cupti_wrapper_deps", + "tf_additional_device_tracer_cuda_deps", + "tf_additional_device_tracer_deps", + "tf_additional_device_tracer_srcs", + "tf_additional_gdr_lib_defines", + "tf_additional_human_readable_json_deps", "tf_additional_lib_defines", "tf_additional_lib_deps", + "tf_additional_libdevice_data", + "tf_additional_libdevice_deps", + "tf_additional_libdevice_srcs", "tf_additional_lib_hdrs", "tf_additional_lib_srcs", "tf_additional_minimal_lib_srcs", + "tf_additional_mpi_lib_defines", "tf_additional_proto_hdrs", "tf_additional_proto_srcs", - "tf_additional_cupti_wrapper_deps", - "tf_additional_libdevice_data", - "tf_additional_libdevice_deps", - "tf_additional_libdevice_srcs", "tf_additional_test_deps", "tf_additional_test_srcs", - "tf_kernel_tests_linkstatic", - "tf_additional_cloud_op_deps", - "tf_additional_cloud_kernel_deps", - "tf_lib_proto_parsing_deps", "tf_additional_verbs_lib_defines", - "tf_additional_mpi_lib_defines", - "tf_additional_gdr_lib_defines", - "tf_additional_device_tracer_srcs", - "tf_additional_device_tracer_deps", - "tf_additional_device_tracer_cuda_deps", - "tf_pyclif_proto_library", "tf_jspb_proto_library", + "tf_kernel_tests_linkstatic", + "tf_lib_proto_parsing_deps", "tf_nano_proto_library", + "tf_platform_hdrs", + "tf_platform_srcs", + "tf_proto_library", + "tf_proto_library_cc", "tf_protos_all", "tf_protos_all_impl", "tf_protos_grappler", "tf_protos_grappler_impl", + "tf_pyclif_proto_library", ) load( "//tensorflow/core:platform/default/build_config_root.bzl", @@ -400,6 +401,7 @@ cc_library( "protobuf.cc", ]) + [ "platform/protobuf_util.cc", + "lib/core/status.h", ], hdrs = [ ":platform_protobuf_hdrs", @@ -416,6 +418,18 @@ cc_library( ], ) +cc_library( + name = "human_readable_json", + srcs = tf_platform_srcs(["human_readable_json.cc"]), + hdrs = ["platform/human_readable_json.h"], + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = [ + ":lib", + ":lib_internal", + ] + tf_additional_human_readable_json_deps(), +) + filegroup( name = "platform_env_hdrs", srcs = [ @@ -2013,6 +2027,7 @@ cc_library( "platform/**/cuda_libdevice_path.cc", "platform/**/device_tracer.cc", "platform/**/logging.cc", + "platform/**/human_readable_json.cc", "platform/abi.cc", ], ) + tf_additional_lib_srcs( @@ -2025,6 +2040,7 @@ cc_library( "platform/**/env_time.cc", "platform/**/device_tracer.cc", "platform/**/logging.cc", + "platform/**/human_readable_json.cc", "platform/abi.cc", ] + # Protobuf deps already included through the ":lib_proto_parsing" diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl index 23c594d..43fe82c 100644 --- a/tensorflow/core/platform/default/build_config.bzl +++ b/tensorflow/core/platform/default/build_config.bzl @@ -515,6 +515,9 @@ def tf_additional_proto_srcs(): "platform/default/protobuf.cc", ] +def tf_additional_human_readable_json_deps(): + return [] + def tf_additional_all_protos(): return ["//tensorflow/core:protos_all"] diff --git a/tensorflow/core/platform/default/human_readable_json.cc b/tensorflow/core/platform/default/human_readable_json.cc new file mode 100644 index 0000000..6bf2106 --- /dev/null +++ b/tensorflow/core/platform/default/human_readable_json.cc @@ -0,0 +1,54 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/platform/human_readable_json.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace tensorflow { + +Status ProtoToHumanReadableJson(const ::google::protobuf::Message& proto, + string* result) { + result->clear(); + + auto status = google::protobuf::util::MessageToJsonString(proto, result); + if (!status.ok()) { + // Convert error_msg google::protobuf::StringPiece to + // tensorflow::StringPiece. + auto error_msg = status.error_message(); + return errors::Internal( + strings::StrCat("Could not convert proto to JSON string: ", + StringPiece(error_msg.data(), error_msg.length()))); + } + return Status::OK(); +} + +Status HumanReadableJsonToProto(const string& str, + ::google::protobuf::Message* proto) { + proto->Clear(); + auto status = google::protobuf::util::JsonStringToMessage(str, proto); + if (!status.ok()) { + // Convert error_msg google::protobuf::StringPiece to + // tensorflow::StringPiece. + auto error_msg = status.error_message(); + return errors::Internal( + strings::StrCat("Could not convert JSON string to proto: ", + StringPiece(error_msg.data(), error_msg.length()))); + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/platform/human_readable_json.h b/tensorflow/core/platform/human_readable_json.h new file mode 100644 index 0000000..c759e80 --- /dev/null +++ b/tensorflow/core/platform/human_readable_json.h @@ -0,0 +1,37 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_HUMAN_READABLE_JSON_H_ +#define TENSORFLOW_CORE_PLATFORM_HUMAN_READABLE_JSON_H_ + +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { + +// Converts a proto to a JSON-like string that's meant to be human-readable +// but still machine-parseable. +// +// This string may not be strictly JSON-compliant, but it must be parseable by +// HumanReadableJSONToProto. +Status ProtoToHumanReadableJson(const protobuf::Message& proto, string* result); + +// Converts a string produced by ProtoToHumanReadableJSON to a protobuf. Not +// guaranteed to work for general JSON. +Status HumanReadableJsonToProto(const string& str, protobuf::Message* proto); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_HUMAN_READABLE_JSON_H_ -- 2.7.4