[XLA] Add parsers for Window and ConvolutionDimensionNumbers.
authorJustin Lebar <jlebar@google.com>
Thu, 31 May 2018 00:59:50 +0000 (17:59 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 31 May 2018 01:01:59 +0000 (18:01 -0700)
Also modify relevant ToString functions so we can have the property
Parse(ToString(x)) == x.

PiperOrigin-RevId: 198650340

tensorflow/compiler/xla/service/hlo_instruction.cc
tensorflow/compiler/xla/service/hlo_instruction.h
tensorflow/compiler/xla/tools/parser/BUILD
tensorflow/compiler/xla/tools/parser/hlo_parser.cc
tensorflow/compiler/xla/tools/parser/hlo_parser.h
tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc

index dc351e9..c55e5cf 100644 (file)
@@ -2299,7 +2299,9 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
   }
 
   if (convolution_dimension_numbers_ != nullptr) {
-    extra.push_back(ConvolutionDimensionNumbersToString());
+    extra.push_back(StrCat(
+        "dim_labels=",
+        ConvolutionDimensionNumbersToString(*convolution_dimension_numbers_)));
   }
   if (dot_dimension_numbers_ != nullptr) {
     extra.push_back(DotDimensionNumbersToString());
@@ -3419,42 +3421,8 @@ string RandomDistributionToString(const RandomDistribution& distribution) {
   return tensorflow::str_util::Lowercase(RandomDistribution_Name(distribution));
 }
 
-StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) {
-  static std::unordered_map<string, RandomDistribution>* map = [] {
-    static auto* map = new std::unordered_map<string, RandomDistribution>;
-    for (int i = 0; i < RandomDistribution_ARRAYSIZE; i++) {
-      if (RandomDistribution_IsValid(i)) {
-        auto value = static_cast<RandomDistribution>(i);
-        (*map)[RandomDistributionToString(value)] = value;
-      }
-    }
-    return map;
-  }();
-  auto found = map->find(tensorflow::str_util::Lowercase(name));
-  if (found == map->end()) {
-    return InvalidArgument("Unknown distribution");
-  }
-  return found->second;
-}
-
-std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) {
-  return os << ToString(kind);
-}
-
-string HloInstruction::ConvolutionDimensionNumbersToString() const {
-  string result;
-  if (convolution_dimension_numbers_ == nullptr) {
-    return result;
-  }
-  const ConvolutionDimensionNumbers& dnums = *convolution_dimension_numbers_;
-  // Show the given dimension labels in order of major to minor based on the
-  // shape's layout.
-  const auto append_dims = [&](const std::vector<string>& dims,
-                               const Shape& shape) {
-    CHECK_EQ(dims.size(), ShapeUtil::Rank(shape));
-    StrAppend(&result, Join(dims, ""));
-  };
-
+string ConvolutionDimensionNumbersToString(
+    const ConvolutionDimensionNumbers& dnums) {
   // lhs_dims[i] is the symbol of the logical dimension i for the lhs
   // operand. E.g. if batch has dimension number 2, then lhs_dims[2] == "b".
   std::vector<string> lhs_dims(2 + dnums.input_spatial_dimensions().size());
@@ -3478,19 +3446,8 @@ string HloInstruction::ConvolutionDimensionNumbersToString() const {
     output_dims[dnums.output_spatial_dimensions(i)] = StrCat(i);
   }
 
-  result += "dim_labels=";
-  append_dims(lhs_dims, operand(0)->shape());
-  result += "_";
-  append_dims(rhs_dims, operand(1)->shape());
-  result += "->";
-
-  // A convolution can be represented as a kConvolution HLO or as a CustomCall
-  // that returns a tuple, the first element of which is the result of the
-  // convolution.
-  Shape this_shape =
-      ShapeUtil::IsTuple(shape()) ? shape().tuple_shapes(0) : shape();
-  append_dims(output_dims, this_shape);
-  return result;
+  return StrCat(Join(lhs_dims, ""), "_", Join(rhs_dims, ""), "->",
+                Join(output_dims, ""));
 }
 
 string HloInstruction::DotDimensionNumbersToString() const {
@@ -3516,6 +3473,28 @@ string HloInstruction::DotDimensionNumbersToString() const {
   return Join(result, ", ");
 }
 
+StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) {
+  static std::unordered_map<string, RandomDistribution>* map = [] {
+    static auto* map = new std::unordered_map<string, RandomDistribution>;
+    for (int i = 0; i < RandomDistribution_ARRAYSIZE; i++) {
+      if (RandomDistribution_IsValid(i)) {
+        auto value = static_cast<RandomDistribution>(i);
+        (*map)[RandomDistributionToString(value)] = value;
+      }
+    }
+    return map;
+  }();
+  auto found = map->find(tensorflow::str_util::Lowercase(name));
+  if (found == map->end()) {
+    return InvalidArgument("Unknown distribution");
+  }
+  return found->second;
+}
+
+std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) {
+  return os << ToString(kind);
+}
+
 string HloInstruction::GatherDimensionNumbersToString() const {
   CHECK_NE(gather_dimension_numbers_.get(), nullptr);
   string output_window_dims =
index 6df97c4..8119c35 100644 (file)
@@ -1313,9 +1313,6 @@ class HloInstruction {
     return fft_length_;
   }
 
-  // Returns the dump string of the convolution dimension numbers.
-  string ConvolutionDimensionNumbersToString() const;
-
   // Returns data on the dimension numbers used for a dot operation.
   const DotDimensionNumbers& dot_dimension_numbers() const {
     CHECK(dot_dimension_numbers_ != nullptr);
@@ -1749,6 +1746,9 @@ StatusOr<HloInstruction::FusionKind> StringToFusionKind(
 string PaddingConfigToString(const PaddingConfig& padding);
 string OpMetadataToString(const OpMetadata& metadata);
 string RandomDistributionToString(const RandomDistribution& distribution);
+string ConvolutionDimensionNumbersToString(
+    const ConvolutionDimensionNumbers& dnums);
+
 StatusOr<RandomDistribution> StringToRandomDistribution(const string& name);
 
 std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind);
index 0fa4b98..76f35af 100644 (file)
@@ -65,6 +65,7 @@ tf_cc_test(
     srcs = ["hlo_parser_test.cc"],
     deps = [
         ":hlo_parser",
+        "//tensorflow/compiler/xla:window_util",
         "//tensorflow/core:lib",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
index 134978d..3c1d63a 100644 (file)
@@ -56,10 +56,10 @@ class HloParser {
   // Returns the error information.
   string GetError() const { return Join(error_, "\n"); }
 
-  // Stand alone parsing for sharding. The parser string is supposed to
-  // contain the body of the sharding, i.e. just the rhs of the "sharding={...}"
-  // attribute string.
+  // Stand alone parsing utils for various aggregate data types.
   StatusOr<HloSharding> ParseShardingOnly();
+  StatusOr<Window> ParseWindowOnly();
+  StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbersOnly();
 
  private:
   // ParseXXX returns false if an error occurred.
@@ -169,7 +169,9 @@ class HloParser {
   bool ParseComputationName(HloComputation** value);
   // Parses a list of names and finds the corresponding hlo instructions.
   bool ParseInstructionNames(std::vector<HloInstruction*>* instructions);
-  bool ParseWindow(Window* window);
+  // Pass expect_outer_curlies == true when parsing a Window in the context of a
+  // larger computation.  Pass false when parsing a stand-alone Window string.
+  bool ParseWindow(Window* window, bool expect_outer_curlies);
   bool ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers* dnums);
   bool ParsePaddingConfig(PaddingConfig* padding);
   bool ParseMetadata(OpMetadata* metadata);
@@ -1933,7 +1935,7 @@ bool HloParser::ParseAttributeHelper(
       }
       case AttrTy::kWindow: {
         Window result;
-        if (!ParseWindow(&result)) {
+        if (!ParseWindow(&result, /*expect_outer_curlies=*/true)) {
           return false;
         }
         static_cast<optional<Window>*>(attr_out_ptr)->emplace(result);
@@ -2051,9 +2053,10 @@ bool HloParser::ParseComputationName(HloComputation** value) {
 // ::= '{' size stride? pad? lhs_dilate? rhs_dilate? '}'
 // The subattributes can appear in any order. 'size=' is required, others are
 // optional.
-bool HloParser::ParseWindow(Window* window) {
+bool HloParser::ParseWindow(Window* window, bool expect_outer_curlies) {
   LocTy loc = lexer_.GetLoc();
-  if (!ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) {
+  if (expect_outer_curlies &&
+      !ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) {
     return false;
   }
 
@@ -2063,7 +2066,9 @@ bool HloParser::ParseWindow(Window* window) {
   std::vector<int64> lhs_dilate;
   std::vector<int64> rhs_dilate;
   std::vector<int64> rhs_reversal;
-  while (lexer_.GetKind() != TokKind::kRbrace) {
+  const auto end_token =
+      expect_outer_curlies ? TokKind::kRbrace : TokKind::kEof;
+  while (lexer_.GetKind() != end_token) {
     LocTy attr_loc = lexer_.GetLoc();
     string field_name;
     if (!ParseAttributeName(&field_name)) {
@@ -2127,7 +2132,8 @@ bool HloParser::ParseWindow(Window* window) {
     window->mutable_dimensions(i)->set_window_reversal(
         rhs_reversal.empty() ? false : (rhs_reversal[i] == 1));
   }
-  return ParseToken(TokKind::kRbrace, "expected '}' to end window attribute");
+  return !expect_outer_curlies ||
+         ParseToken(TokKind::kRbrace, "expected '}' to end window attribute");
 }
 
 // This is the inverse of HloInstruction::ConvolutionDimensionNumbersToString.
@@ -2692,6 +2698,32 @@ StatusOr<HloSharding> HloParser::ParseShardingOnly() {
   return HloSharding::FromProto(op_sharding);
 }
 
+StatusOr<Window> HloParser::ParseWindowOnly() {
+  lexer_.Lex();
+  Window window;
+  if (!ParseWindow(&window, /*expect_outer_curlies=*/false)) {
+    return InvalidArgument("Syntax error:\n%s", GetError().c_str());
+  }
+  if (lexer_.GetKind() != TokKind::kEof) {
+    return InvalidArgument("Syntax error:\nExtra content after window");
+  }
+  return window;
+}
+
+StatusOr<ConvolutionDimensionNumbers>
+HloParser::ParseConvolutionDimensionNumbersOnly() {
+  lexer_.Lex();
+  ConvolutionDimensionNumbers dnums;
+  if (!ParseConvolutionDimensionNumbers(&dnums)) {
+    return InvalidArgument("Syntax error:\n%s", GetError().c_str());
+  }
+  if (lexer_.GetKind() != TokKind::kEof) {
+    return InvalidArgument(
+        "Syntax error:\nExtra content after convolution dnums");
+  }
+  return dnums;
+}
+
 }  // namespace
 
 StatusOr<std::unique_ptr<HloModule>> Parse(StringPiece str,
@@ -2714,5 +2746,18 @@ StatusOr<HloSharding> ParseSharding(tensorflow::StringPiece str) {
   return parser.ParseShardingOnly();
 }
 
+StatusOr<Window> ParseWindow(tensorflow::StringPiece str) {
+  HloModuleConfig config;
+  HloParser parser(str, config);
+  return parser.ParseWindowOnly();
+}
+
+StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbers(
+    tensorflow::StringPiece str) {
+  HloModuleConfig config;
+  HloParser parser(str, config);
+  return parser.ParseConvolutionDimensionNumbersOnly();
+}
+
 }  // namespace tools
 }  // namespace xla
index f7854f4..902c45c 100644 (file)
@@ -36,10 +36,17 @@ StatusOr<std::unique_ptr<HloModule>> Parse(tensorflow::StringPiece str,
 // format, parses the string and creates a HloModule with default config.
 StatusOr<std::unique_ptr<HloModule>> Parse(tensorflow::StringPiece str);
 
-// Parse sharding from str. str is supposed to contain the body of the
-// sharding, i.e. just the rhs of the "sharding={...}" attribute string.
+// Parses the result of HloSharding::ToString(), e.g. "{replicated}".
 StatusOr<HloSharding> ParseSharding(tensorflow::StringPiece str);
 
+// Parses the result of window_util::ToString(const Window&).
+StatusOr<Window> ParseWindow(tensorflow::StringPiece str);
+
+// Parses the result of ConvolutionDimensionNumbersToString(), e.g.
+// "b0f_0io->b0f".
+StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbers(
+    tensorflow::StringPiece str);
+
 }  // namespace tools
 }  // namespace xla
 
index 183b112..f7a27cf 100644 (file)
@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
 
 #include <string>
+#include "tensorflow/compiler/xla/window_util.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
 #include "tensorflow/core/lib/strings/str_util.h"
@@ -1349,6 +1350,26 @@ ENTRY entry {
       "was parsing 8:39: error: instruction does not exist: aparam");
 }
 
+TEST_F(HloParserTest, ParseSharding) {
+  const string original = "{maximal device=42}";
+  TF_ASSERT_OK_AND_ASSIGN(HloSharding sharding, ParseSharding(original));
+  EXPECT_EQ(sharding.ToString(), original);
+}
+
+TEST_F(HloParserTest, ParseWindow) {
+  Window original = window_util::MakeWindow({1, 2, 3});
+  TF_ASSERT_OK_AND_ASSIGN(Window parsed,
+                          ParseWindow(window_util::ToString(original)))
+  EXPECT_EQ(window_util::ToString(original), window_util::ToString(parsed));
+}
+
+TEST_F(HloParserTest, ParseConvolutionDimensionNumbers) {
+  const string original = "b0f_0io->b0f";
+  TF_ASSERT_OK_AND_ASSIGN(ConvolutionDimensionNumbers dnums,
+                          ParseConvolutionDimensionNumbers(original));
+  EXPECT_EQ(original, ConvolutionDimensionNumbersToString(dnums));
+}
+
 }  // namespace
 }  // namespace tools
 }  // namespace xla