From c86a47448534b135cdba106b59aee2492889ff75 Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Wed, 30 May 2018 17:59:50 -0700 Subject: [PATCH] [XLA] Add parsers for Window and ConvolutionDimensionNumbers. Also modify relevant ToString functions so we can have the property Parse(ToString(x)) == x. PiperOrigin-RevId: 198650340 --- tensorflow/compiler/xla/service/hlo_instruction.cc | 79 ++++++++-------------- tensorflow/compiler/xla/service/hlo_instruction.h | 6 +- tensorflow/compiler/xla/tools/parser/BUILD | 1 + tensorflow/compiler/xla/tools/parser/hlo_parser.cc | 63 ++++++++++++++--- tensorflow/compiler/xla/tools/parser/hlo_parser.h | 11 ++- .../compiler/xla/tools/parser/hlo_parser_test.cc | 21 ++++++ 6 files changed, 117 insertions(+), 64 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index dc351e9..c55e5cf 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -2299,7 +2299,9 @@ std::vector HloInstruction::ExtraAttributesToString( } if (convolution_dimension_numbers_ != nullptr) { - extra.push_back(ConvolutionDimensionNumbersToString()); + extra.push_back(StrCat( + "dim_labels=", + ConvolutionDimensionNumbersToString(*convolution_dimension_numbers_))); } if (dot_dimension_numbers_ != nullptr) { extra.push_back(DotDimensionNumbersToString()); @@ -3419,42 +3421,8 @@ string RandomDistributionToString(const RandomDistribution& distribution) { return tensorflow::str_util::Lowercase(RandomDistribution_Name(distribution)); } -StatusOr StringToRandomDistribution(const string& name) { - static std::unordered_map* map = [] { - static auto* map = new std::unordered_map; - for (int i = 0; i < RandomDistribution_ARRAYSIZE; i++) { - if (RandomDistribution_IsValid(i)) { - auto value = static_cast(i); - (*map)[RandomDistributionToString(value)] = value; - } - } - return map; - }(); - auto found = map->find(tensorflow::str_util::Lowercase(name)); - if (found == map->end()) { - return InvalidArgument("Unknown distribution"); - } - return found->second; -} - -std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) { - return os << ToString(kind); -} - -string HloInstruction::ConvolutionDimensionNumbersToString() const { - string result; - if (convolution_dimension_numbers_ == nullptr) { - return result; - } - const ConvolutionDimensionNumbers& dnums = *convolution_dimension_numbers_; - // Show the given dimension labels in order of major to minor based on the - // shape's layout. - const auto append_dims = [&](const std::vector& dims, - const Shape& shape) { - CHECK_EQ(dims.size(), ShapeUtil::Rank(shape)); - StrAppend(&result, Join(dims, "")); - }; - +string ConvolutionDimensionNumbersToString( + const ConvolutionDimensionNumbers& dnums) { // lhs_dims[i] is the symbol of the logical dimension i for the lhs // operand. E.g. if batch has dimension number 2, then lhs_dims[2] == "b". std::vector lhs_dims(2 + dnums.input_spatial_dimensions().size()); @@ -3478,19 +3446,8 @@ string HloInstruction::ConvolutionDimensionNumbersToString() const { output_dims[dnums.output_spatial_dimensions(i)] = StrCat(i); } - result += "dim_labels="; - append_dims(lhs_dims, operand(0)->shape()); - result += "_"; - append_dims(rhs_dims, operand(1)->shape()); - result += "->"; - - // A convolution can be represented as a kConvolution HLO or as a CustomCall - // that returns a tuple, the first element of which is the result of the - // convolution. - Shape this_shape = - ShapeUtil::IsTuple(shape()) ? shape().tuple_shapes(0) : shape(); - append_dims(output_dims, this_shape); - return result; + return StrCat(Join(lhs_dims, ""), "_", Join(rhs_dims, ""), "->", + Join(output_dims, "")); } string HloInstruction::DotDimensionNumbersToString() const { @@ -3516,6 +3473,28 @@ string HloInstruction::DotDimensionNumbersToString() const { return Join(result, ", "); } +StatusOr StringToRandomDistribution(const string& name) { + static std::unordered_map* map = [] { + static auto* map = new std::unordered_map; + for (int i = 0; i < RandomDistribution_ARRAYSIZE; i++) { + if (RandomDistribution_IsValid(i)) { + auto value = static_cast(i); + (*map)[RandomDistributionToString(value)] = value; + } + } + return map; + }(); + auto found = map->find(tensorflow::str_util::Lowercase(name)); + if (found == map->end()) { + return InvalidArgument("Unknown distribution"); + } + return found->second; +} + +std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) { + return os << ToString(kind); +} + string HloInstruction::GatherDimensionNumbersToString() const { CHECK_NE(gather_dimension_numbers_.get(), nullptr); string output_window_dims = diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 6df97c4..8119c35 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -1313,9 +1313,6 @@ class HloInstruction { return fft_length_; } - // Returns the dump string of the convolution dimension numbers. - string ConvolutionDimensionNumbersToString() const; - // Returns data on the dimension numbers used for a dot operation. const DotDimensionNumbers& dot_dimension_numbers() const { CHECK(dot_dimension_numbers_ != nullptr); @@ -1749,6 +1746,9 @@ StatusOr StringToFusionKind( string PaddingConfigToString(const PaddingConfig& padding); string OpMetadataToString(const OpMetadata& metadata); string RandomDistributionToString(const RandomDistribution& distribution); +string ConvolutionDimensionNumbersToString( + const ConvolutionDimensionNumbers& dnums); + StatusOr StringToRandomDistribution(const string& name); std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind); diff --git a/tensorflow/compiler/xla/tools/parser/BUILD b/tensorflow/compiler/xla/tools/parser/BUILD index 0fa4b98..76f35af 100644 --- a/tensorflow/compiler/xla/tools/parser/BUILD +++ b/tensorflow/compiler/xla/tools/parser/BUILD @@ -65,6 +65,7 @@ tf_cc_test( srcs = ["hlo_parser_test.cc"], deps = [ ":hlo_parser", + "//tensorflow/compiler/xla:window_util", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc index 134978d..3c1d63a 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -56,10 +56,10 @@ class HloParser { // Returns the error information. string GetError() const { return Join(error_, "\n"); } - // Stand alone parsing for sharding. The parser string is supposed to - // contain the body of the sharding, i.e. just the rhs of the "sharding={...}" - // attribute string. + // Stand alone parsing utils for various aggregate data types. StatusOr ParseShardingOnly(); + StatusOr ParseWindowOnly(); + StatusOr ParseConvolutionDimensionNumbersOnly(); private: // ParseXXX returns false if an error occurred. @@ -169,7 +169,9 @@ class HloParser { bool ParseComputationName(HloComputation** value); // Parses a list of names and finds the corresponding hlo instructions. bool ParseInstructionNames(std::vector* instructions); - bool ParseWindow(Window* window); + // Pass expect_outer_curlies == true when parsing a Window in the context of a + // larger computation. Pass false when parsing a stand-alone Window string. + bool ParseWindow(Window* window, bool expect_outer_curlies); bool ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers* dnums); bool ParsePaddingConfig(PaddingConfig* padding); bool ParseMetadata(OpMetadata* metadata); @@ -1933,7 +1935,7 @@ bool HloParser::ParseAttributeHelper( } case AttrTy::kWindow: { Window result; - if (!ParseWindow(&result)) { + if (!ParseWindow(&result, /*expect_outer_curlies=*/true)) { return false; } static_cast*>(attr_out_ptr)->emplace(result); @@ -2051,9 +2053,10 @@ bool HloParser::ParseComputationName(HloComputation** value) { // ::= '{' size stride? pad? lhs_dilate? rhs_dilate? '}' // The subattributes can appear in any order. 'size=' is required, others are // optional. -bool HloParser::ParseWindow(Window* window) { +bool HloParser::ParseWindow(Window* window, bool expect_outer_curlies) { LocTy loc = lexer_.GetLoc(); - if (!ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) { + if (expect_outer_curlies && + !ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) { return false; } @@ -2063,7 +2066,9 @@ bool HloParser::ParseWindow(Window* window) { std::vector lhs_dilate; std::vector rhs_dilate; std::vector rhs_reversal; - while (lexer_.GetKind() != TokKind::kRbrace) { + const auto end_token = + expect_outer_curlies ? TokKind::kRbrace : TokKind::kEof; + while (lexer_.GetKind() != end_token) { LocTy attr_loc = lexer_.GetLoc(); string field_name; if (!ParseAttributeName(&field_name)) { @@ -2127,7 +2132,8 @@ bool HloParser::ParseWindow(Window* window) { window->mutable_dimensions(i)->set_window_reversal( rhs_reversal.empty() ? false : (rhs_reversal[i] == 1)); } - return ParseToken(TokKind::kRbrace, "expected '}' to end window attribute"); + return !expect_outer_curlies || + ParseToken(TokKind::kRbrace, "expected '}' to end window attribute"); } // This is the inverse of HloInstruction::ConvolutionDimensionNumbersToString. @@ -2692,6 +2698,32 @@ StatusOr HloParser::ParseShardingOnly() { return HloSharding::FromProto(op_sharding); } +StatusOr HloParser::ParseWindowOnly() { + lexer_.Lex(); + Window window; + if (!ParseWindow(&window, /*expect_outer_curlies=*/false)) { + return InvalidArgument("Syntax error:\n%s", GetError().c_str()); + } + if (lexer_.GetKind() != TokKind::kEof) { + return InvalidArgument("Syntax error:\nExtra content after window"); + } + return window; +} + +StatusOr +HloParser::ParseConvolutionDimensionNumbersOnly() { + lexer_.Lex(); + ConvolutionDimensionNumbers dnums; + if (!ParseConvolutionDimensionNumbers(&dnums)) { + return InvalidArgument("Syntax error:\n%s", GetError().c_str()); + } + if (lexer_.GetKind() != TokKind::kEof) { + return InvalidArgument( + "Syntax error:\nExtra content after convolution dnums"); + } + return dnums; +} + } // namespace StatusOr> Parse(StringPiece str, @@ -2714,5 +2746,18 @@ StatusOr ParseSharding(tensorflow::StringPiece str) { return parser.ParseShardingOnly(); } +StatusOr ParseWindow(tensorflow::StringPiece str) { + HloModuleConfig config; + HloParser parser(str, config); + return parser.ParseWindowOnly(); +} + +StatusOr ParseConvolutionDimensionNumbers( + tensorflow::StringPiece str) { + HloModuleConfig config; + HloParser parser(str, config); + return parser.ParseConvolutionDimensionNumbersOnly(); +} + } // namespace tools } // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.h b/tensorflow/compiler/xla/tools/parser/hlo_parser.h index f7854f4..902c45c 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.h +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.h @@ -36,10 +36,17 @@ StatusOr> Parse(tensorflow::StringPiece str, // format, parses the string and creates a HloModule with default config. StatusOr> Parse(tensorflow::StringPiece str); -// Parse sharding from str. str is supposed to contain the body of the -// sharding, i.e. just the rhs of the "sharding={...}" attribute string. +// Parses the result of HloSharding::ToString(), e.g. "{replicated}". StatusOr ParseSharding(tensorflow::StringPiece str); +// Parses the result of window_util::ToString(const Window&). +StatusOr ParseWindow(tensorflow::StringPiece str); + +// Parses the result of ConvolutionDimensionNumbersToString(), e.g. +// "b0f_0io->b0f". +StatusOr ParseConvolutionDimensionNumbers( + tensorflow::StringPiece str); + } // namespace tools } // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc index 183b112..f7a27cf 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include +#include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -1349,6 +1350,26 @@ ENTRY entry { "was parsing 8:39: error: instruction does not exist: aparam"); } +TEST_F(HloParserTest, ParseSharding) { + const string original = "{maximal device=42}"; + TF_ASSERT_OK_AND_ASSIGN(HloSharding sharding, ParseSharding(original)); + EXPECT_EQ(sharding.ToString(), original); +} + +TEST_F(HloParserTest, ParseWindow) { + Window original = window_util::MakeWindow({1, 2, 3}); + TF_ASSERT_OK_AND_ASSIGN(Window parsed, + ParseWindow(window_util::ToString(original))) + EXPECT_EQ(window_util::ToString(original), window_util::ToString(parsed)); +} + +TEST_F(HloParserTest, ParseConvolutionDimensionNumbers) { + const string original = "b0f_0io->b0f"; + TF_ASSERT_OK_AND_ASSIGN(ConvolutionDimensionNumbers dnums, + ParseConvolutionDimensionNumbers(original)); + EXPECT_EQ(original, ConvolutionDimensionNumbersToString(dnums)); +} + } // namespace } // namespace tools } // namespace xla -- 2.7.4