: print_large_constants_(false),
print_metadata_(true),
compact_operands_(false),
+ print_operand_shape_(true),
+ print_program_shape_(true),
indent_amount_(0) {}
+ static HloPrintOptions ShortParsable() {
+ return HloPrintOptions()
+ .set_print_large_constants(true)
+ .set_print_metadata(false)
+ .set_print_operand_shape(false)
+ .set_print_program_shape(false);
+ }
+
// If true, large constants will be printed out.
HloPrintOptions& set_print_large_constants(bool value) {
print_large_constants_ = value;
return *this;
}
+ // If true, operands' shapes will be printed.
+ HloPrintOptions& set_print_operand_shape(bool value) {
+ print_operand_shape_ = value;
+ return *this;
+ }
+
+ // If true, program shape of hlo computations will be printed.
+ HloPrintOptions& set_print_program_shape(bool value) {
+ print_program_shape_ = value;
+ return *this;
+ }
+
// If true, only a part of operands will be printed out, and their names will
// be omitted (note that in this case the text will not be parsable).
HloPrintOptions& set_compact_operands(bool value) {
bool print_large_constants() const { return print_large_constants_; }
bool print_metadata() const { return print_metadata_; }
bool compact_operands() const { return compact_operands_; }
+ bool print_operand_shape() const { return print_operand_shape_; }
+ bool print_program_shape() const { return print_program_shape_; }
int indent_amount() const { return indent_amount_; }
private:
bool print_large_constants_;
bool print_metadata_;
bool compact_operands_;
+ bool print_operand_shape_;
+ bool print_program_shape_;
int indent_amount_;
};
bool ParseInt64List(const TokKind start, const TokKind end,
const TokKind delim, std::vector<int64>* result);
+ bool ParseParamListToShape(Shape* shape, LocTy* shape_loc);
bool ParseParamList();
bool ParseName(string* result);
bool ParseAttributeName(string* result);
bool ParseBool(bool* result);
bool ParseToken(TokKind kind, const string& msg);
+ // Returns true if the current token is the beginning of a shape.
+ bool CanBeShape();
+ // Returns true if the current token is the beginning of a
+ // param_list_to_shape.
+ bool CanBeParamListToShape();
+
// Logs the current parsing line and the given message. Always returns false.
bool TokenError(StringPiece msg);
bool Error(LocTy loc, StringPiece msg);
return true;
}
-// computation ::= ('ENTRY')? name param_list '->' shape instruction_list
+// computation ::= ('ENTRY')? name (param_list_to_shape)? instruction_list
bool HloParser::ParseComputation() {
const bool is_entry_computation = EatIfPresent(TokKind::kw_ENTRY);
string name;
}
auto builder = MakeUnique<HloComputation::Builder>(name);
+ LocTy shape_loc = nullptr;
Shape shape;
- string root_name;
- if (!ParseParamList() || !ParseToken(TokKind::kArrow, "expects '->'")) {
+ if (CanBeParamListToShape() && !ParseParamListToShape(&shape, &shape_loc)) {
return false;
}
- LocTy shape_ty = lexer_.GetLoc();
- if (!ParseShape(&shape) || !ParseInstructionList(builder.get(), &root_name)) {
+ string root_name;
+ if (!ParseInstructionList(builder.get(), &root_name)) {
return false;
}
CHECK_EQ(root, computation->root_instruction());
}
- if (!ShapeUtil::Compatible(root->shape(), shape)) {
+ // If param_list_to_shape was present, check compatibility.
+ if (shape_loc != nullptr && !ShapeUtil::Compatible(root->shape(), shape)) {
return Error(
- shape_ty,
+ shape_loc,
StrCat("Shape of computation ", name, ", ",
ShapeUtil::HumanString(shape),
", is not compatible with that of its root instruction ",
// operands1
// ::= /*empty*/
// ::= operand (, operand)*
-// operand ::= shape name
+// operand ::= (shape)? name
bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands) {
if (!ParseToken(TokKind::kLparen,
"expects '(' at the beginning of operands")) {
} else {
do {
LocTy loc = lexer_.GetLoc();
- Shape shape;
string name;
- if (!ParseShape(&shape) || !ParseName(&name)) {
+ if (CanBeShape()) {
+ Shape shape;
+ if (!ParseShape(&shape)) {
+ return false;
+ }
+ }
+ if (!ParseName(&name)) {
return false;
}
HloInstruction* instruction =
end, StrCat("expects an int64 list to end with ", TokKindToString(end)));
}
+// param_list_to_shape ::= param_list '->' shape
+bool HloParser::ParseParamListToShape(Shape* shape, LocTy* shape_loc) {
+ if (!ParseParamList() || !ParseToken(TokKind::kArrow, "expects '->'")) {
+ return false;
+ }
+ *shape_loc = lexer_.GetLoc();
+ return ParseShape(shape);
+}
+
+bool HloParser::CanBeParamListToShape() {
+ return lexer_.GetKind() == TokKind::kLparen;
+}
+
// param_list ::= '(' param_list1 ')'
// param_list1
// ::= /*empty*/
return true;
}
+bool HloParser::CanBeShape() {
+ // A non-tuple shape starts with a kShape token; a tuple shape starts with
+ // '('.
+ return lexer_.GetKind() == TokKind::kShape ||
+ lexer_.GetKind() == TokKind::kLparen;
+}
+
bool HloParser::ParseName(string* result) {
VLOG(1) << "ParseName";
if (lexer_.GetKind() != TokKind::kName) {
ROOT %concatenate = f32[2,8]{1,0} concatenate(f32[2,3]{1,0} %constant, f32[2,5]{1,0} %constant.1), dimensions={1}
}
-)"
-},
-// map
-{
-"Map",
-R"(HloModule MapBinaryAdder_module:
-
-%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
- %lhs = f32[] parameter(0)
- %rhs = f32[] parameter(1)
- ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
-}
-
-ENTRY %MapBinaryAdder.v3 (param0: f32[4], param1: f32[4]) -> f32[4] {
- %param0 = f32[4]{0} parameter(0)
- %param1 = f32[4]{0} parameter(1)
- ROOT %map = f32[4]{0} map(f32[4]{0} %param0, f32[4]{0} %param1), to_apply=%add_F32.v3
-}
-
-)"
-},
-// reduce
-{
-"Reduce",
-R"(HloModule ReduceR3ToR2_module:
-
-%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
- %lhs = f32[] parameter(0)
- %rhs = f32[] parameter(1)
- ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
-}
-
-ENTRY %ReduceR3ToR2.v3 (input: f32[8,16,256]) -> f32[8,16] {
- %input = f32[8,16,256]{2,1,0} parameter(0)
- %constant = f32[] constant(0)
- ROOT %reduce = f32[8,16]{1,0} reduce(f32[8,16,256]{2,1,0} %input, f32[] %constant), dimensions={2}, to_apply=%add_F32.v3
-}
-
)"
},
// select and scatter
ROOT %fusion = f32[3,2,1,1]{3,2,1,0} fusion(f32[3,2,1,1]{3,2,1,0} %constant, f32[2]{0} %constant.1), kind=kLoop, calls=%fused_computation
}
+)"
+}
+ });
+ // clang-format on
+}
+
+std::vector<TestData> CreateShortTestCases() {
+ // clang-format off
+ return std::vector<TestData>({
+// map
+{
+"Map",
+R"(HloModule MapBinaryAdder_module:
+
+%add_F32.v3 {
+ %lhs = f32[] parameter(0)
+ %rhs = f32[] parameter(1)
+ ROOT %add = f32[] add(%lhs, %rhs)
+}
+
+ENTRY %MapBinaryAdder.v3 {
+ %param0 = f32[4]{0} parameter(0)
+ %param1 = f32[4]{0} parameter(1)
+ ROOT %map = f32[4]{0} map(%param0, %param1), to_apply=%add_F32.v3
+}
+
+)"
+},
+// reduce
+{
+"Reduce",
+R"(HloModule ReduceR3ToR2_module:
+
+%add_F32.v3 {
+ %lhs = f32[] parameter(0)
+ %rhs = f32[] parameter(1)
+ ROOT %add = f32[] add(%lhs, %rhs)
+}
+
+ENTRY %ReduceR3ToR2.v3 {
+ %input = f32[8,16,256]{2,1,0} parameter(0)
+ %constant = f32[] constant(0)
+ ROOT %reduce = f32[8,16]{1,0} reduce(%input, %constant), dimensions={2}, to_apply=%add_F32.v3
+}
+
)"
},
// infeed/outfeed
"InfeedOutfeed",
R"(HloModule outfeed_module:
-ENTRY %InfeedToOutfeed () -> (u32[3], pred[]) {
+ENTRY %InfeedToOutfeed {
%infeed = (u32[3]{0}, pred[]) infeed()
- %outfeed = () outfeed((u32[3]{0}, pred[]) %infeed)
+ %outfeed = () outfeed(%infeed)
ROOT %infeed.1 = (u32[3]{0}, pred[]) infeed()
- %outfeed.1 = () outfeed((u32[3]{0}, pred[]) %infeed.1)
+ %outfeed.1 = () outfeed(%infeed.1)
}
)"
"Rng",
R"(HloModule rng_module:
-ENTRY %Rng () -> f32[8] {
+ENTRY %Rng {
%constant = f32[] constant(0)
%constant.1 = f32[] constant(1)
- ROOT %rng = f32[8]{0} rng(f32[] %constant, f32[] %constant.1), distribution=rng_uniform
+ ROOT %rng = f32[8]{0} rng(%constant, %constant.1), distribution=rng_uniform
}
)"
"ReducePrevison",
R"(HloModule reduce_precision:
-ENTRY %ReducePrecision () -> f32[1] {
+ENTRY %ReducePrecision {
%constant = f32[1]{0} constant({3.14159})
- ROOT %reduce-precision = f32[1]{0} reduce-precision(f32[1]{0} %constant), exponent_bits=8, mantissa_bits=10
+ ROOT %reduce-precision = f32[1]{0} reduce-precision(%constant), exponent_bits=8, mantissa_bits=10
}
)"
"Conditional",
R"(HloModule conditional:
-%Negate (x: f32[]) -> f32[] {
+%Negate {
%x = f32[] parameter(0)
- ROOT %negate = f32[] negate(f32[] %x)
+ ROOT %negate = f32[] negate(%x)
}
-%Identity (y: f32[]) -> f32[] {
+%Identity {
%y = f32[] parameter(0)
- ROOT %copy = f32[] copy(f32[] %y)
+ ROOT %copy = f32[] copy(%y)
}
-ENTRY %Parameters1.v4 () -> f32[] {
+ENTRY %Parameters1.v4 {
%constant = pred[] constant(true)
%constant.1 = f32[] constant(56)
%constant.2 = f32[] constant(12)
- ROOT %conditional = f32[] conditional(pred[] %constant, f32[] %constant.1, f32[] %constant.2), true_computation=%Negate, false_computation=%Identity
+ ROOT %conditional = f32[] conditional(%constant, %constant.1, %constant.2), true_computation=%Negate, false_computation=%Identity
}
)"
},
-
// CustomCall
{
"CustomCall",
R"(HloModule custom_call:
-ENTRY %CustomCall () -> f32[1,2,3] {
+ENTRY %CustomCall {
%constant = f32[1]{0} constant({12345})
- ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar"
+ ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(%constant), custom_call_target="foo\"bar"
}
)"
"NonDefaultNames",
R"(HloModule add_constants_module:
-ENTRY %add_constants () -> f32[] {
+ENTRY %add_constants {
%foo = f32[] constant(3.14)
- ROOT %bar = f32[] add(f32[] %foo, f32[] %foo)
+ ROOT %bar = f32[] add(%foo, %foo)
}
)"
}
};
+class HloParserShortTest : public HloParserTest {
+ protected:
+ void ExpectEqualShort() {
+ const string& original = GetParam().module_string;
+ auto result = Parse(original);
+ TF_ASSERT_OK(result.status());
+ EXPECT_EQ(original,
+ result.ValueOrDie()->ToString(HloPrintOptions::ShortParsable()));
+ }
+};
+
TEST_P(HloParserTest, Run) { ExpectEqual(); }
+TEST_P(HloParserShortTest, Run) { ExpectEqualShort(); }
+
INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTest,
::testing::ValuesIn(CreateTestCases()),
TestDataToString);
+INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserShortTest,
+ ::testing::ValuesIn(CreateShortTestCases()),
+ TestDataToString);
+
TEST_F(HloParserTest, Empty) {
const string original = "";
auto result = Parse(original);