[XLA] Make HloInstruction::backend_config() a JSON-encoded protobuf.
authorJustin Lebar <jlebar@google.com>
Thu, 31 May 2018 18:43:37 +0000 (11:43 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 31 May 2018 18:46:25 +0000 (11:46 -0700)
PiperOrigin-RevId: 198754463

16 files changed:
tensorflow/compiler/xla/BUILD
tensorflow/compiler/xla/scanner.cc [deleted file]
tensorflow/compiler/xla/scanner.h [deleted file]
tensorflow/compiler/xla/scanner_test.cc [deleted file]
tensorflow/compiler/xla/service/BUILD
tensorflow/compiler/xla/service/compiler.cc
tensorflow/compiler/xla/service/compiler.h
tensorflow/compiler/xla/service/hlo_graph_dumper.cc
tensorflow/compiler/xla/service/hlo_instruction.cc
tensorflow/compiler/xla/service/hlo_instruction.h
tensorflow/compiler/xla/tools/parser/hlo_parser.cc
tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
tensorflow/core/BUILD
tensorflow/core/platform/default/build_config.bzl
tensorflow/core/platform/default/human_readable_json.cc [new file with mode: 0644]
tensorflow/core/platform/human_readable_json.h [new file with mode: 0644]

index c08db7e..c6deb95 100644 (file)
@@ -500,37 +500,6 @@ cc_library(
 )
 
 cc_library(
-    name = "scanner",
-    srcs = ["scanner.cc"],
-    hdrs = ["scanner.h"],
-    visibility = [":internal"],
-    deps = [
-        ":status",
-        ":status_macros",
-        ":types",
-        ":util",
-        "//tensorflow/core:lib",
-        "//tensorflow/core:lib_internal",
-    ],
-)
-
-tf_cc_test(
-    name = "scanner_test",
-    srcs = ["scanner_test.cc"],
-    deps = [
-        ":scanner",
-        ":status",
-        ":status_macros",
-        ":test",
-        ":types",
-        ":util",
-        "//tensorflow/core:lib",
-        "//tensorflow/core:lib_internal",
-        "//tensorflow/core:test_main",
-    ],
-)
-
-cc_library(
     name = "text_literal_reader",
     srcs = ["text_literal_reader.cc"],
     hdrs = ["text_literal_reader.h"],
diff --git a/tensorflow/compiler/xla/scanner.cc b/tensorflow/compiler/xla/scanner.cc
deleted file mode 100644 (file)
index f23a141..0000000
+++ /dev/null
@@ -1,197 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/scanner.h"
-
-#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-
-namespace xla {
-namespace {
-
-// Returns true if c can be the first character in an identifier.
-bool IsIdentifierFirst(int c) { return std::isalpha(c) || c == '_'; }
-
-// Returns true if c can be the non-first character in an identifier.
-bool IsIdentifierLater(int c) { return std::isalnum(c) || c == '_'; }
-
-// Returns true if str is an identifier.
-bool IsIdentifier(tensorflow::StringPiece str) {
-  if (str.empty() || !IsIdentifierFirst(str[0])) {
-    return false;
-  }
-  for (int64 i = 1; i < str.size(); ++i) {
-    if (!IsIdentifierLater(str[i])) {
-      return false;
-    }
-  }
-  return true;
-}
-
-}  // namespace
-
-Scanner::Scanner(tensorflow::StringPiece input) : input_(input), position_(0) {}
-
-bool Scanner::ok() const { return status().ok(); }
-
-const Status& Scanner::status() const { return status_; }
-
-bool Scanner::Match(tensorflow::StringPiece match) {
-  SkipWhitespace();
-  if (ok() && position_ + match.size() <= input_.size() &&
-      std::equal(match.begin(), match.end(), input_.begin() + position_)) {
-    SkipChars(match.size());
-
-    VLOG(10) << "Matched \"" << match << "\"";
-    return true;
-  } else {
-    return false;
-  }
-}
-
-void Scanner::Expect(tensorflow::StringPiece expect) {
-  if (!Match(expect)) {
-    SetError(tensorflow::strings::StrCat("Expected \"", expect, "\"."));
-  }
-}
-
-bool Scanner::MatchReadIdentifier(string* identifier) {
-  SkipWhitespace();
-  if (!IsIdentifierFirst(PeekChar())) {
-    return false;
-  }
-  identifier->clear();
-  do {
-    *identifier += ReadChar();
-  } while (IsIdentifierLater(PeekChar()));
-
-  VLOG(10) << "Read identifier " << identifier;
-  CHECK(IsIdentifier(*identifier));
-  return true;
-}
-
-string Scanner::ReadIdentifier() {
-  string identifier;
-  if (!MatchReadIdentifier(&identifier)) {
-    SetError("Expected identifier.");
-  }
-  return identifier;
-}
-
-void Scanner::ExpectIdentifier(tensorflow::StringPiece expect) {
-  CHECK(IsIdentifier(expect));
-
-  string identifier;
-  if (!MatchReadIdentifier(&identifier)) {
-    SetError(tensorflow::strings::StrCat("Expected identifier ", expect, "."));
-  }
-  if (identifier != expect) {
-    SetError(tensorflow::strings::StrCat("Expected identifier ", expect,
-                                         ", but got ", identifier, "."));
-  }
-}
-
-// Matches the end of the input, also known as End Of File (EOF).
-bool Scanner::MatchEof() {
-  SkipWhitespace();
-  return PeekChar() == EOF;
-}
-
-void Scanner::ExpectEof() {
-  if (!MatchEof()) {
-    SetError("Expected end of input.");
-  }
-}
-
-// Reads a vector of the format "(1, 2, 3)".
-std::vector<int64> Scanner::ReadIntVector() {
-  std::vector<int64> ints;
-  Expect("(");
-  if (!Match(")") && ok()) {
-    ints.push_back(ReadInt());
-    while (Match(",")) {
-      ints.push_back(ReadInt());
-    }
-    Expect(")");
-  }
-
-  VLOG(10) << "Read int vector with " << ints.size() << " elements.";
-  return ints;
-}
-
-int64 Scanner::ReadInt() {
-  bool negative = Match("-");
-  if (!PeekDigit()) {
-    SetError("Expected integer.");
-    return 0;
-  }
-
-  int64 integer = 0;
-  do {
-    integer = (ReadChar() - '0') + integer * 10;
-  } while (PeekDigit());
-  integer = negative ? -integer : integer;
-
-  VLOG(10) << "Read integer " << integer;
-  return integer;
-}
-
-void Scanner::SkipWhitespace() {
-  while (PeekWhitespace()) {
-    SkipChars(1);
-  }
-}
-
-int Scanner::ReadChar() {
-  int c = PeekChar();
-  SkipChars(1);
-
-  VLOG(20) << "Read char " << c;
-  return c;
-}
-
-int Scanner::PeekChar() const {
-  return ok() && position_ < input_.size() ? input_[position_] : EOF;
-}
-
-bool Scanner::PeekDigit() const {
-  // Do not use std::isdigit since it depends on the locale and we do not
-  // handle any digits beyond 0-9.
-  const char c = PeekChar();
-  return '0' <= c && c <= '9';
-}
-
-bool Scanner::PeekAlnum() const { return std::isalnum(PeekChar()); }
-
-bool Scanner::PeekWhitespace() const { return std::isspace(PeekChar()); }
-
-void Scanner::SkipChars(int64 count) {
-  CHECK_GE(count, 0);
-  position_ += count;
-}
-
-void Scanner::SetError(string error_message) {
-  // Only the first error is recorded since any later errors will likely be a
-  // consequence of the first error.
-  if (ok()) {
-    status_ = InvalidArgumentStrCat(std::move(error_message));
-    position_ = input_.size();
-    VLOG(10) << "Failed scanner with error " << status_.ToString();
-  } else {
-    VLOG(10) << "Error on already failed scanner is " << error_message;
-  }
-}
-
-}  // namespace xla
diff --git a/tensorflow/compiler/xla/scanner.h b/tensorflow/compiler/xla/scanner.h
deleted file mode 100644 (file)
index 86b04ae..0000000
+++ /dev/null
@@ -1,102 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_XLA_SCANNER_H_
-#define TENSORFLOW_COMPILER_XLA_SCANNER_H_
-
-#include "tensorflow/compiler/xla/status.h"
-#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
-
-namespace xla {
-
-// Simple class for parsing data. The concepts for the interface are:
-//
-//   Match(x): Returns true if x is next in the input and in that case skips
-//     past it. Otherwise returns false.
-//
-//   Expect(x): As Match(x), but requires x to be next in the input.
-//
-//   MatchReadX(x): Returns true if an X is next in the input and in that case
-//     skips past it and assigns it to x. Otherwise returns false.
-//
-//   ReadX(): As ReadMatchX(), but requires an X to be next in the input and
-//     returns it.
-//
-//   PeekX(): Returns true if an X is next in the input and does not skip
-//     past it either way.
-//
-// All of these, except those that work on individual characters, skip
-// whitespace.
-//
-// If a requirement is not met, the error is available in status(). A Scanner
-// with a failed status() will behave as though the rest of the input is EOF and
-// will not record further errors after that point.
-class Scanner {
- public:
-  Scanner(tensorflow::StringPiece input);
-
-  bool ok() const;
-  const Status& status() const;
-
-  bool Match(tensorflow::StringPiece match);
-  void Expect(tensorflow::StringPiece expect);
-
-  // Match-reads an identifier. An identifier starts with an alphabetic
-  // character or an underscore followed by any number of characters that are
-  // each alphanumeric or underscore.
-  bool MatchReadIdentifier(string* identifier);
-
-  string ReadIdentifier();
-
-  void ExpectIdentifier(tensorflow::StringPiece expect);
-
-  // Matches the end of the input, also known as End Of File (EOF).
-  bool MatchEof();
-  void ExpectEof();
-
-  // Reads a vector of the format "(1, 4, 5)".
-  std::vector<int64> ReadIntVector();
-
-  // Reads an integer. Can start with a minus but not a plus.
-  int64 ReadInt();
-
-  // Keeps skipping until encountering a non-whitespace character.
-  void SkipWhitespace();
-
-  // *** Below here are character-level methods that do not skip whitespace.
-
-  int ReadChar();
-  int PeekChar() const;
-  bool PeekDigit() const;
-  bool PeekAlnum() const;
-  bool PeekWhitespace() const;
-
-  // Skip past the next count characters.
-  void SkipChars(int64 count);
-
- private:
-  // Sets a failed status. The input is in effect replaced with EOF after
-  // this. Only the first error is recorded.
-  void SetError(string error_message);
-
-  const tensorflow::StringPiece input_;
-  int64 position_;
-  Status status_;
-};
-
-}  // namespace xla
-
-#endif  // TENSORFLOW_COMPILER_XLA_SCANNER_H_
diff --git a/tensorflow/compiler/xla/scanner_test.cc b/tensorflow/compiler/xla/scanner_test.cc
deleted file mode 100644 (file)
index 10cd0c6..0000000
+++ /dev/null
@@ -1,124 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// TODO(b/80179519): Fix open source build for real.
-#if 0
-#include "tensorflow/compiler/xla/scanner.h"
-
-#include <string>
-
-#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/platform/env.h"
-
-namespace xla {
-namespace {
-
-TEST(Scanner, Empty) {
-  Scanner scanner("");
-
-  EXPECT_EQ(scanner.PeekChar(), EOF);
-  EXPECT_TRUE(scanner.MatchEof());
-  EXPECT_TRUE(scanner.Match(""));
-  EXPECT_FALSE(scanner.Match("1"));
-  EXPECT_TRUE(scanner.ok());
-}
-
-TEST(Scanner, Prefix) {
-  Scanner scanner("1234 5");
-  EXPECT_FALSE(scanner.MatchEof());
-  EXPECT_TRUE(scanner.Match("12"));
-  EXPECT_TRUE(scanner.Match("34 "));
-  EXPECT_FALSE(scanner.MatchEof());
-  EXPECT_FALSE(scanner.Match("5 "));
-  EXPECT_TRUE(scanner.Match("5"));
-  EXPECT_TRUE(scanner.MatchEof());
-}
-
-TEST(Scanner, Whitespace) {
-  Scanner scanner(" \t\n\r 1\t2\n\n");
-
-  EXPECT_FALSE(scanner.Match(" "));
-  EXPECT_TRUE(scanner.Match("1"));
-  EXPECT_TRUE(scanner.Match("2"));
-  EXPECT_TRUE(scanner.MatchEof());
-  EXPECT_TRUE(scanner.ok());
-}
-
-TEST(Scanner, Fail) {
-  Scanner scanner("153 4q");
-
-  scanner.Expect("5");
-  EXPECT_FALSE(scanner.ok());
-  EXPECT_FALSE(scanner.status().ok());
-
-  EXPECT_TRUE(scanner.MatchEof());
-}
-
-TEST(Scanner, Identifier) {
-  Scanner scanner("1 q1  _1_ _1a= qqb");
-
-  string identifier = "foo";
-  EXPECT_FALSE(scanner.MatchReadIdentifier(&identifier));
-  EXPECT_EQ(identifier, "foo");
-  scanner.Match("1");
-
-  EXPECT_TRUE(scanner.MatchReadIdentifier(&identifier));
-  EXPECT_EQ(identifier, "q1");
-
-  scanner.ExpectIdentifier("_1_");
-  EXPECT_TRUE(scanner.ok());
-
-  scanner.ExpectIdentifier("_1a");
-  EXPECT_TRUE(scanner.ok());
-
-  // The = after _1a is not included in the identifier.
-  scanner.Expect("=");
-
-  // The expected identifier matches a prefix but is not the full identifier in
-  // the input.
-  EXPECT_TRUE(scanner.ok());
-  scanner.ExpectIdentifier("qq");
-  EXPECT_FALSE(scanner.ok());
-}
-
-TEST(Scanner, Int) {
-  Scanner scanner("1_2 3% -1 124345 -363 0 -0");
-  EXPECT_EQ(1, scanner.ReadInt());
-  EXPECT_TRUE(scanner.Match("_"));
-  EXPECT_EQ(2, scanner.ReadInt());
-  EXPECT_EQ(3, scanner.ReadInt());
-  EXPECT_TRUE(scanner.Match("%"));
-  EXPECT_EQ(-1, scanner.ReadInt());
-  EXPECT_EQ(124345, scanner.ReadInt());
-  EXPECT_EQ(-363, scanner.ReadInt());
-  EXPECT_EQ(0, scanner.ReadInt());
-  EXPECT_EQ(0, scanner.ReadInt());
-  EXPECT_TRUE(scanner.MatchEof());
-}
-
-TEST(Scanner, IntVector) {
-  Scanner scanner("()(0) (-1,2) ( 3 , 4 )");
-  EXPECT_THAT(scanner.ReadIntVector(), testing::IsEmpty());
-  EXPECT_THAT(scanner.ReadIntVector(), testing::ElementsAre(0));
-  EXPECT_THAT(scanner.ReadIntVector(), testing::ElementsAre(-1, 2));
-  EXPECT_THAT(scanner.ReadIntVector(), testing::ElementsAre(3, 4));
-  EXPECT_TRUE(scanner.MatchEof());
-  EXPECT_TRUE(scanner.ok());
-}
-
-}  // namespace
-}  // namespace xla
-#endif
index b954bbd..aa41631 100644 (file)
@@ -309,6 +309,7 @@ cc_library(
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:window_util",
         "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/core:human_readable_json",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
     ],
index 31f84e8..6f06bba 100644 (file)
@@ -28,8 +28,9 @@ namespace xla {
 /* static */ tensorflow::mutex Compiler::platform_compiler_mutex_(
     tensorflow::LINKER_INITIALIZED);
 
-std::vector<string> Compiler::ComputeBackendConfigs(
-    const HloInstruction& hlo, se::StreamExecutor* executor) const {
+std::vector<std::unique_ptr<tensorflow::protobuf::Message>>
+Compiler::ComputeBackendConfigs(const HloInstruction& hlo,
+                                se::StreamExecutor* executor) const {
   CHECK(executor != nullptr);
   return {};
 }
index c39db58..6c52ffd 100644 (file)
@@ -36,6 +36,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
 #include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
 #include "tensorflow/core/platform/thread_annotations.h"
 
@@ -161,8 +162,9 @@ class Compiler {
   //
   // The stream executor is passed in to provide information about the hardware
   // that the backend configurations would be targeting.
-  virtual std::vector<string> ComputeBackendConfigs(
-      const HloInstruction& hlo, se::StreamExecutor* executor) const;
+  virtual std::vector<std::unique_ptr<tensorflow::protobuf::Message>>
+  ComputeBackendConfigs(const HloInstruction& hlo,
+                        se::StreamExecutor* executor) const;
 
   // Compiles the HLO module for ahead-of-time execution.  This is intended for
   // use in static compilation.
index 672b1c0..05adb45 100644 (file)
@@ -1085,11 +1085,11 @@ string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) {
 
 string HloDotDumper::GetInstructionNodeBackendConfig(
     const HloInstruction* instr) {
-  if (!show_backend_config_ || instr->backend_config().empty()) {
+  if (!show_backend_config_ || instr->raw_backend_config_string().empty()) {
     return "";
   }
 
-  return StrCat("backend_config=\"", instr->backend_config(), "\"");
+  return StrCat("backend_config=\"", instr->raw_backend_config_string(), "\"");
 }
 
 string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
index c55e5cf..a68075e 100644 (file)
@@ -41,6 +41,7 @@ limitations under the License.
 #include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/human_readable_json.h"
 #include "tensorflow/core/platform/logging.h"
 
 namespace xla {
@@ -110,7 +111,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
   instruction->name_ = proto.name();
 
   instruction->metadata_ = proto.metadata();
-  instruction->set_backend_config(proto.backend_config());
+  instruction->backend_config_ = proto.backend_config();
   if (proto.has_literal()) {
     TF_ASSIGN_OR_RETURN(instruction->literal_,
                         Literal::CreateFromProto(proto.literal()));
@@ -1521,7 +1522,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
   }
   SetupDerivedInstruction(clone.get());
   clone->set_parent(parent_);
-  clone->set_backend_config(backend_config());
+  clone->set_raw_backend_config_string(backend_config_);
   if (context != nullptr) {
     context->MapInstruction(this, clone.get());
     clone->ReplaceCalledComputations([&](HloComputation* callee) {
@@ -2182,8 +2183,8 @@ string HloInstruction::ToStringWithCanonicalNameMap(
        !metadata_.source_file().empty())) {
     StrAppend(&result, ", metadata={", xla::OpMetadataToString(metadata_), "}");
   }
-  if (options.print_backend_config() && !backend_config().empty()) {
-    StrAppend(&result, ", backend_config=\"", CEscape(backend_config()), "\"");
+  if (options.print_backend_config() && !backend_config_.empty()) {
+    StrAppend(&result, ", backend_config=\"", CEscape(backend_config_), "\"");
   }
   return result;
 }
@@ -2463,7 +2464,7 @@ HloInstructionProto HloInstruction::ToProto() const {
   }
 
   *proto.mutable_metadata() = metadata_;
-  proto.set_backend_config(backend_config());
+  proto.set_backend_config(backend_config_);
   if (literal_ != nullptr) {
     *proto.mutable_literal() = literal_->ToProto();
   }
@@ -3526,6 +3527,31 @@ bool HloInstruction::CouldBeBitcast() const {
   }
 }
 
+Status HloInstruction::GetBackendConfigInternal(
+    tensorflow::protobuf::Message* proto) const {
+  proto->Clear();
+
+  // Empty string does not parse as valid JSON, but it's a valid backend config,
+  // corresponding to the empty proto.
+  if (backend_config_.empty()) {
+    return Status::OK();
+  }
+  return tensorflow::HumanReadableJsonToProto(backend_config_, proto);
+}
+
+Status HloInstruction::set_backend_config(
+    const tensorflow::protobuf::Message& proto) {
+  TF_ASSIGN_OR_RETURN(backend_config_, BackendConfigToRawString(proto));
+  return Status::OK();
+}
+
+/* static */ StatusOr<string> HloInstruction::BackendConfigToRawString(
+    const tensorflow::protobuf::Message& proto) {
+  string ret;
+  TF_RETURN_IF_ERROR(tensorflow::ProtoToHumanReadableJson(proto, &ret));
+  return ret;
+}
+
 HloModule* HloInstruction::GetModule() const {
   if (parent_) {
     return parent_->parent();
index 8119c35..72b9d54 100644 (file)
@@ -52,6 +52,7 @@ limitations under the License.
 #include "tensorflow/core/lib/gtl/iterator_range.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/platform/types.h"
 
 namespace xla {
@@ -1446,12 +1447,33 @@ class HloInstruction {
   // this field and they cannot interpret it due to its meaning being backend
   // specific.
   //
-  // TODO(b/78194644): Introduce structured configuration format as per
-  // go/xla-heuristics.
-  const string& backend_config() const { return backend_config_; }
-  void set_backend_config(string backend_config) {
-    backend_config_ = std::move(backend_config);
+  // ConfigProto should be a protobuf Message type.
+  template <typename ConfigProto>
+  StatusOr<ConfigProto> backend_config() const {
+    ConfigProto proto;
+    TF_RETURN_IF_ERROR(GetBackendConfigInternal(&proto));
+    return std::move(proto);
   }
+  Status set_backend_config(const tensorflow::protobuf::Message& proto);
+
+  // Getter/setter for raw JSON-encoded backend config.  Prefer the
+  // functions above that deal in proto Messages where possible.
+  const string& raw_backend_config_string() const { return backend_config_; }
+  void set_raw_backend_config_string(string config_str) {
+    backend_config_ = std::move(config_str);
+  }
+
+  // Returns a string representation of a proto in the format used by
+  // raw_backend_config_string.
+  //
+  // This is morally equivalent to:
+  //
+  //   HloInstruction instr;
+  //   TF_RETURN_IF_ERROR(instr.set_backend_config(proto));
+  //   return instr.raw_backend_config_string();
+  //
+  static StatusOr<string> BackendConfigToRawString(
+      const tensorflow::protobuf::Message& proto);
 
   // Sets the debug metadata for this instruction.
   void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; }
@@ -1573,6 +1595,10 @@ class HloInstruction {
   // Returns how this instruction uses elements of its `i`th operand.
   UseKind OperandElementUse(int64 i) const;
 
+  // Helper for implementing backend_config().  Parses backend_config_ into the
+  // given proto.
+  Status GetBackendConfigInternal(tensorflow::protobuf::Message* proto) const;
+
   int unique_id_;  // Unique to this HloInstruction within a HloModule
 
   // Opcode for this instruction.
index 3c1d63a..ef10ca4 100644 (file)
@@ -1127,7 +1127,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
     instruction->set_metadata(*metadata);
   }
   if (backend_config) {
-    instruction->set_backend_config(std::move(*backend_config));
+    instruction->set_raw_backend_config_string(std::move(*backend_config));
   }
   return AddInstruction(name, instruction, name_loc);
 }  // NOLINT(readability/fn_size)
index f7a27cf..3c5957b 100644 (file)
@@ -1025,7 +1025,7 @@ ENTRY %configuration_test() -> s32[] {
   EXPECT_EQ("foo bar", result.ValueOrDie()
                            ->entry_computation()
                            ->root_instruction()
-                           ->backend_config());
+                           ->raw_backend_config_string());
 }
 
 TEST_F(HloParserTest, LiteralDimensionsMismatch_1) {
index 3286f85..74f74af 100644 (file)
@@ -101,42 +101,43 @@ load("//tensorflow:tensorflow.bzl", "tf_cuda_only_cc_test")
 # For platform specific build config
 load(
     "//tensorflow/core:platform/default/build_config.bzl",
-    "tf_platform_hdrs",
-    "tf_platform_srcs",
-    "tf_proto_library",
-    "tf_proto_library_cc",
     "tf_additional_all_protos",
+    "tf_additional_cloud_kernel_deps",
+    "tf_additional_cloud_op_deps",
     "tf_additional_core_deps",
+    "tf_additional_cupti_wrapper_deps",
+    "tf_additional_device_tracer_cuda_deps",
+    "tf_additional_device_tracer_deps",
+    "tf_additional_device_tracer_srcs",
+    "tf_additional_gdr_lib_defines",
+    "tf_additional_human_readable_json_deps",
     "tf_additional_lib_defines",
     "tf_additional_lib_deps",
+    "tf_additional_libdevice_data",
+    "tf_additional_libdevice_deps",
+    "tf_additional_libdevice_srcs",
     "tf_additional_lib_hdrs",
     "tf_additional_lib_srcs",
     "tf_additional_minimal_lib_srcs",
+    "tf_additional_mpi_lib_defines",
     "tf_additional_proto_hdrs",
     "tf_additional_proto_srcs",
-    "tf_additional_cupti_wrapper_deps",
-    "tf_additional_libdevice_data",
-    "tf_additional_libdevice_deps",
-    "tf_additional_libdevice_srcs",
     "tf_additional_test_deps",
     "tf_additional_test_srcs",
-    "tf_kernel_tests_linkstatic",
-    "tf_additional_cloud_op_deps",
-    "tf_additional_cloud_kernel_deps",
-    "tf_lib_proto_parsing_deps",
     "tf_additional_verbs_lib_defines",
-    "tf_additional_mpi_lib_defines",
-    "tf_additional_gdr_lib_defines",
-    "tf_additional_device_tracer_srcs",
-    "tf_additional_device_tracer_deps",
-    "tf_additional_device_tracer_cuda_deps",
-    "tf_pyclif_proto_library",
     "tf_jspb_proto_library",
+    "tf_kernel_tests_linkstatic",
+    "tf_lib_proto_parsing_deps",
     "tf_nano_proto_library",
+    "tf_platform_hdrs",
+    "tf_platform_srcs",
+    "tf_proto_library",
+    "tf_proto_library_cc",
     "tf_protos_all",
     "tf_protos_all_impl",
     "tf_protos_grappler",
     "tf_protos_grappler_impl",
+    "tf_pyclif_proto_library",
 )
 load(
     "//tensorflow/core:platform/default/build_config_root.bzl",
@@ -400,6 +401,7 @@ cc_library(
         "protobuf.cc",
     ]) + [
         "platform/protobuf_util.cc",
+        "lib/core/status.h",
     ],
     hdrs = [
         ":platform_protobuf_hdrs",
@@ -416,6 +418,18 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "human_readable_json",
+    srcs = tf_platform_srcs(["human_readable_json.cc"]),
+    hdrs = ["platform/human_readable_json.h"],
+    copts = tf_copts(),
+    visibility = ["//visibility:public"],
+    deps = [
+        ":lib",
+        ":lib_internal",
+    ] + tf_additional_human_readable_json_deps(),
+)
+
 filegroup(
     name = "platform_env_hdrs",
     srcs = [
@@ -2013,6 +2027,7 @@ cc_library(
             "platform/**/cuda_libdevice_path.cc",
             "platform/**/device_tracer.cc",
             "platform/**/logging.cc",
+            "platform/**/human_readable_json.cc",
             "platform/abi.cc",
         ],
     ) + tf_additional_lib_srcs(
@@ -2025,6 +2040,7 @@ cc_library(
             "platform/**/env_time.cc",
             "platform/**/device_tracer.cc",
             "platform/**/logging.cc",
+            "platform/**/human_readable_json.cc",
             "platform/abi.cc",
         ] +
         # Protobuf deps already included through the ":lib_proto_parsing"
index 23c594d..43fe82c 100644 (file)
@@ -515,6 +515,9 @@ def tf_additional_proto_srcs():
       "platform/default/protobuf.cc",
   ]
 
+def tf_additional_human_readable_json_deps():
+  return []
+
 def tf_additional_all_protos():
   return ["//tensorflow/core:protos_all"]
 
diff --git a/tensorflow/core/platform/default/human_readable_json.cc b/tensorflow/core/platform/default/human_readable_json.cc
new file mode 100644 (file)
index 0000000..6bf2106
--- /dev/null
@@ -0,0 +1,54 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/platform/human_readable_json.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace tensorflow {
+
+Status ProtoToHumanReadableJson(const ::google::protobuf::Message& proto,
+                                string* result) {
+  result->clear();
+
+  auto status = google::protobuf::util::MessageToJsonString(proto, result);
+  if (!status.ok()) {
+    // Convert error_msg google::protobuf::StringPiece to
+    // tensorflow::StringPiece.
+    auto error_msg = status.error_message();
+    return errors::Internal(
+        strings::StrCat("Could not convert proto to JSON string: ",
+                        StringPiece(error_msg.data(), error_msg.length())));
+  }
+  return Status::OK();
+}
+
+Status HumanReadableJsonToProto(const string& str,
+                                ::google::protobuf::Message* proto) {
+  proto->Clear();
+  auto status = google::protobuf::util::JsonStringToMessage(str, proto);
+  if (!status.ok()) {
+    // Convert error_msg google::protobuf::StringPiece to
+    // tensorflow::StringPiece.
+    auto error_msg = status.error_message();
+    return errors::Internal(
+        strings::StrCat("Could not convert JSON string to proto: ",
+                        StringPiece(error_msg.data(), error_msg.length())));
+  }
+  return Status::OK();
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/platform/human_readable_json.h b/tensorflow/core/platform/human_readable_json.h
new file mode 100644 (file)
index 0000000..c759e80
--- /dev/null
@@ -0,0 +1,37 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_HUMAN_READABLE_JSON_H_
+#define TENSORFLOW_CORE_PLATFORM_HUMAN_READABLE_JSON_H_
+
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+
+// Converts a proto to a JSON-like string that's meant to be human-readable
+// but still machine-parseable.
+//
+// This string may not be strictly JSON-compliant, but it must be parseable by
+// HumanReadableJSONToProto.
+Status ProtoToHumanReadableJson(const protobuf::Message& proto, string* result);
+
+// Converts a string produced by ProtoToHumanReadableJSON to a protobuf.  Not
+// guaranteed to work for general JSON.
+Status HumanReadableJsonToProto(const string& str, protobuf::Message* proto);
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_PLATFORM_HUMAN_READABLE_JSON_H_