}
if (convolution_dimension_numbers_ != nullptr) {
- extra.push_back(ConvolutionDimensionNumbersToString());
+ extra.push_back(StrCat(
+ "dim_labels=",
+ ConvolutionDimensionNumbersToString(*convolution_dimension_numbers_)));
}
if (dot_dimension_numbers_ != nullptr) {
extra.push_back(DotDimensionNumbersToString());
return tensorflow::str_util::Lowercase(RandomDistribution_Name(distribution));
}
-StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) {
- static std::unordered_map<string, RandomDistribution>* map = [] {
- static auto* map = new std::unordered_map<string, RandomDistribution>;
- for (int i = 0; i < RandomDistribution_ARRAYSIZE; i++) {
- if (RandomDistribution_IsValid(i)) {
- auto value = static_cast<RandomDistribution>(i);
- (*map)[RandomDistributionToString(value)] = value;
- }
- }
- return map;
- }();
- auto found = map->find(tensorflow::str_util::Lowercase(name));
- if (found == map->end()) {
- return InvalidArgument("Unknown distribution");
- }
- return found->second;
-}
-
-std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) {
- return os << ToString(kind);
-}
-
-string HloInstruction::ConvolutionDimensionNumbersToString() const {
- string result;
- if (convolution_dimension_numbers_ == nullptr) {
- return result;
- }
- const ConvolutionDimensionNumbers& dnums = *convolution_dimension_numbers_;
- // Show the given dimension labels in order of major to minor based on the
- // shape's layout.
- const auto append_dims = [&](const std::vector<string>& dims,
- const Shape& shape) {
- CHECK_EQ(dims.size(), ShapeUtil::Rank(shape));
- StrAppend(&result, Join(dims, ""));
- };
-
+string ConvolutionDimensionNumbersToString(
+ const ConvolutionDimensionNumbers& dnums) {
// lhs_dims[i] is the symbol of the logical dimension i for the lhs
// operand. E.g. if batch has dimension number 2, then lhs_dims[2] == "b".
std::vector<string> lhs_dims(2 + dnums.input_spatial_dimensions().size());
output_dims[dnums.output_spatial_dimensions(i)] = StrCat(i);
}
- result += "dim_labels=";
- append_dims(lhs_dims, operand(0)->shape());
- result += "_";
- append_dims(rhs_dims, operand(1)->shape());
- result += "->";
-
- // A convolution can be represented as a kConvolution HLO or as a CustomCall
- // that returns a tuple, the first element of which is the result of the
- // convolution.
- Shape this_shape =
- ShapeUtil::IsTuple(shape()) ? shape().tuple_shapes(0) : shape();
- append_dims(output_dims, this_shape);
- return result;
+ return StrCat(Join(lhs_dims, ""), "_", Join(rhs_dims, ""), "->",
+ Join(output_dims, ""));
}
string HloInstruction::DotDimensionNumbersToString() const {
return Join(result, ", ");
}
+StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) {
+ static std::unordered_map<string, RandomDistribution>* map = [] {
+ static auto* map = new std::unordered_map<string, RandomDistribution>;
+ for (int i = 0; i < RandomDistribution_ARRAYSIZE; i++) {
+ if (RandomDistribution_IsValid(i)) {
+ auto value = static_cast<RandomDistribution>(i);
+ (*map)[RandomDistributionToString(value)] = value;
+ }
+ }
+ return map;
+ }();
+ auto found = map->find(tensorflow::str_util::Lowercase(name));
+ if (found == map->end()) {
+ return InvalidArgument("Unknown distribution");
+ }
+ return found->second;
+}
+
+std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) {
+ return os << ToString(kind);
+}
+
string HloInstruction::GatherDimensionNumbersToString() const {
CHECK_NE(gather_dimension_numbers_.get(), nullptr);
string output_window_dims =
return fft_length_;
}
- // Returns the dump string of the convolution dimension numbers.
- string ConvolutionDimensionNumbersToString() const;
-
// Returns data on the dimension numbers used for a dot operation.
const DotDimensionNumbers& dot_dimension_numbers() const {
CHECK(dot_dimension_numbers_ != nullptr);
string PaddingConfigToString(const PaddingConfig& padding);
string OpMetadataToString(const OpMetadata& metadata);
string RandomDistributionToString(const RandomDistribution& distribution);
+string ConvolutionDimensionNumbersToString(
+ const ConvolutionDimensionNumbers& dnums);
+
StatusOr<RandomDistribution> StringToRandomDistribution(const string& name);
std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind);
srcs = ["hlo_parser_test.cc"],
deps = [
":hlo_parser",
+ "//tensorflow/compiler/xla:window_util",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
// Returns the error information.
string GetError() const { return Join(error_, "\n"); }
- // Stand alone parsing for sharding. The parser string is supposed to
- // contain the body of the sharding, i.e. just the rhs of the "sharding={...}"
- // attribute string.
+ // Stand alone parsing utils for various aggregate data types.
StatusOr<HloSharding> ParseShardingOnly();
+ StatusOr<Window> ParseWindowOnly();
+ StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbersOnly();
private:
// ParseXXX returns false if an error occurred.
bool ParseComputationName(HloComputation** value);
// Parses a list of names and finds the corresponding hlo instructions.
bool ParseInstructionNames(std::vector<HloInstruction*>* instructions);
- bool ParseWindow(Window* window);
+ // Pass expect_outer_curlies == true when parsing a Window in the context of a
+ // larger computation. Pass false when parsing a stand-alone Window string.
+ bool ParseWindow(Window* window, bool expect_outer_curlies);
bool ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers* dnums);
bool ParsePaddingConfig(PaddingConfig* padding);
bool ParseMetadata(OpMetadata* metadata);
}
case AttrTy::kWindow: {
Window result;
- if (!ParseWindow(&result)) {
+ if (!ParseWindow(&result, /*expect_outer_curlies=*/true)) {
return false;
}
static_cast<optional<Window>*>(attr_out_ptr)->emplace(result);
// ::= '{' size stride? pad? lhs_dilate? rhs_dilate? '}'
// The subattributes can appear in any order. 'size=' is required, others are
// optional.
-bool HloParser::ParseWindow(Window* window) {
+bool HloParser::ParseWindow(Window* window, bool expect_outer_curlies) {
LocTy loc = lexer_.GetLoc();
- if (!ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) {
+ if (expect_outer_curlies &&
+ !ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) {
return false;
}
std::vector<int64> lhs_dilate;
std::vector<int64> rhs_dilate;
std::vector<int64> rhs_reversal;
- while (lexer_.GetKind() != TokKind::kRbrace) {
+ const auto end_token =
+ expect_outer_curlies ? TokKind::kRbrace : TokKind::kEof;
+ while (lexer_.GetKind() != end_token) {
LocTy attr_loc = lexer_.GetLoc();
string field_name;
if (!ParseAttributeName(&field_name)) {
window->mutable_dimensions(i)->set_window_reversal(
rhs_reversal.empty() ? false : (rhs_reversal[i] == 1));
}
- return ParseToken(TokKind::kRbrace, "expected '}' to end window attribute");
+ return !expect_outer_curlies ||
+ ParseToken(TokKind::kRbrace, "expected '}' to end window attribute");
}
// This is the inverse of HloInstruction::ConvolutionDimensionNumbersToString.
return HloSharding::FromProto(op_sharding);
}
+StatusOr<Window> HloParser::ParseWindowOnly() {
+ lexer_.Lex();
+ Window window;
+ if (!ParseWindow(&window, /*expect_outer_curlies=*/false)) {
+ return InvalidArgument("Syntax error:\n%s", GetError().c_str());
+ }
+ if (lexer_.GetKind() != TokKind::kEof) {
+ return InvalidArgument("Syntax error:\nExtra content after window");
+ }
+ return window;
+}
+
+StatusOr<ConvolutionDimensionNumbers>
+HloParser::ParseConvolutionDimensionNumbersOnly() {
+ lexer_.Lex();
+ ConvolutionDimensionNumbers dnums;
+ if (!ParseConvolutionDimensionNumbers(&dnums)) {
+ return InvalidArgument("Syntax error:\n%s", GetError().c_str());
+ }
+ if (lexer_.GetKind() != TokKind::kEof) {
+ return InvalidArgument(
+ "Syntax error:\nExtra content after convolution dnums");
+ }
+ return dnums;
+}
+
} // namespace
StatusOr<std::unique_ptr<HloModule>> Parse(StringPiece str,
return parser.ParseShardingOnly();
}
+StatusOr<Window> ParseWindow(tensorflow::StringPiece str) {
+ HloModuleConfig config;
+ HloParser parser(str, config);
+ return parser.ParseWindowOnly();
+}
+
+StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbers(
+ tensorflow::StringPiece str) {
+ HloModuleConfig config;
+ HloParser parser(str, config);
+ return parser.ParseConvolutionDimensionNumbersOnly();
+}
+
} // namespace tools
} // namespace xla
// format, parses the string and creates a HloModule with default config.
StatusOr<std::unique_ptr<HloModule>> Parse(tensorflow::StringPiece str);
-// Parse sharding from str. str is supposed to contain the body of the
-// sharding, i.e. just the rhs of the "sharding={...}" attribute string.
+// Parses the result of HloSharding::ToString(), e.g. "{replicated}".
StatusOr<HloSharding> ParseSharding(tensorflow::StringPiece str);
+// Parses the result of window_util::ToString(const Window&).
+StatusOr<Window> ParseWindow(tensorflow::StringPiece str);
+
+// Parses the result of ConvolutionDimensionNumbersToString(), e.g.
+// "b0f_0io->b0f".
+StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbers(
+ tensorflow::StringPiece str);
+
} // namespace tools
} // namespace xla
#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include <string>
+#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/str_util.h"
"was parsing 8:39: error: instruction does not exist: aparam");
}
+TEST_F(HloParserTest, ParseSharding) {
+ const string original = "{maximal device=42}";
+ TF_ASSERT_OK_AND_ASSIGN(HloSharding sharding, ParseSharding(original));
+ EXPECT_EQ(sharding.ToString(), original);
+}
+
+TEST_F(HloParserTest, ParseWindow) {
+ Window original = window_util::MakeWindow({1, 2, 3});
+ TF_ASSERT_OK_AND_ASSIGN(Window parsed,
+ ParseWindow(window_util::ToString(original)))
+ EXPECT_EQ(window_util::ToString(original), window_util::ToString(parsed));
+}
+
+TEST_F(HloParserTest, ParseConvolutionDimensionNumbers) {
+ const string original = "b0f_0io->b0f";
+ TF_ASSERT_OK_AND_ASSIGN(ConvolutionDimensionNumbers dnums,
+ ParseConvolutionDimensionNumbers(original));
+ EXPECT_EQ(original, ConvolutionDimensionNumbersToString(dnums));
+}
+
} // namespace
} // namespace tools
} // namespace xla