[XLA] Allow omitting operands shapes and program shapes.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 15 Dec 2017 03:07:06 +0000 (19:07 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 15 Dec 2017 03:10:31 +0000 (19:10 -0800)
PiperOrigin-RevId: 179132435

tensorflow/compiler/xla/service/hlo_computation.cc
tensorflow/compiler/xla/service/hlo_instruction.cc
tensorflow/compiler/xla/service/hlo_instruction.h
tensorflow/compiler/xla/tools/parser/README.md
tensorflow/compiler/xla/tools/parser/hlo_parser.cc
tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc

index 4f6feefb4359ee9fbc97f20e5c9125c0db39b31b..4202c083367792d595d8c22edb7dd154881030ec 100644 (file)
@@ -369,8 +369,11 @@ string HloComputation::ToString(const HloPrintOptions& options) const {
   for (int i = 0; i < options.indent_amount(); i++) {
     s << "    ";
   }
-  s << "%" << name() << " " << ShapeUtil::HumanString(ComputeProgramShape())
-    << " {\n";
+  s << "%" << name();
+  if (options.print_program_shape()) {
+    s << " " << ShapeUtil::HumanString(ComputeProgramShape());
+  }
+  s << " {\n";
   for (const HloInstruction* instruction : MakeInstructionPostOrder()) {
     for (int i = 0; i < options.indent_amount(); i++) {
       s << "    ";
index 9e37ab64a06d557bd5abc0da7bfb6d172110d934..58883101a5a91575b570afc9399aa6a9a5734a4b 100644 (file)
@@ -1964,10 +1964,14 @@ string HloInstruction::OperandsToString(const HloPrintOptions& options) const {
       slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact);
     }
     operands = Join(slice, ", ", [&](string* out, HloInstruction* operand) {
-      *out += ShapeUtil::HumanStringWithLayout(operand->shape());
+      std::vector<string> str;
+      if (options.print_operand_shape()) {
+        str.push_back(ShapeUtil::HumanStringWithLayout(operand->shape()));
+      }
       if (!options.compact_operands()) {
-        StrAppend(out, " %", operand->name());
+        str.push_back(StrCat("%", operand->name()));
       }
+      StrAppend(out, Join(str, " "));
     });
     const int64 remaining = operands_.size() - slice.size();
     if (slice.size() != operands_.size()) {
index 753b7dc0bf0332984d6436aaf54fdfb8758fb301..6d6068c66adc997ab2e33ee747dc3175212d4fc8 100644 (file)
@@ -65,8 +65,18 @@ class HloPrintOptions {
       : print_large_constants_(false),
         print_metadata_(true),
         compact_operands_(false),
+        print_operand_shape_(true),
+        print_program_shape_(true),
         indent_amount_(0) {}
 
+  static HloPrintOptions ShortParsable() {
+    return HloPrintOptions()
+        .set_print_large_constants(true)
+        .set_print_metadata(false)
+        .set_print_operand_shape(false)
+        .set_print_program_shape(false);
+  }
+
   // If true, large constants will be printed out.
   HloPrintOptions& set_print_large_constants(bool value) {
     print_large_constants_ = value;
@@ -79,6 +89,18 @@ class HloPrintOptions {
     return *this;
   }
 
+  // If true, operands' shapes will be printed.
+  HloPrintOptions& set_print_operand_shape(bool value) {
+    print_operand_shape_ = value;
+    return *this;
+  }
+
+  // If true, program shape of hlo computations will be printed.
+  HloPrintOptions& set_print_program_shape(bool value) {
+    print_program_shape_ = value;
+    return *this;
+  }
+
   // If true, only a part of operands will be printed out, and their names will
   // be omitted (note that in this case the text will not be parsable).
   HloPrintOptions& set_compact_operands(bool value) {
@@ -95,12 +117,16 @@ class HloPrintOptions {
   bool print_large_constants() const { return print_large_constants_; }
   bool print_metadata() const { return print_metadata_; }
   bool compact_operands() const { return compact_operands_; }
+  bool print_operand_shape() const { return print_operand_shape_; }
+  bool print_program_shape() const { return print_program_shape_; }
   int indent_amount() const { return indent_amount_; }
 
  private:
   bool print_large_constants_;
   bool print_metadata_;
   bool compact_operands_;
+  bool print_operand_shape_;
+  bool print_program_shape_;
   int indent_amount_;
 };
 
index 6232967f5f04cbf316d985357ae84c28335531e2..45e005581e2b51ce6d33f487b34b7d4336bb9832 100644 (file)
@@ -15,8 +15,10 @@ computations
   ;
 
 computation
-  : 'ENTRY' name param_list '->' shape instruction_list
-  | name param_list '->' shape instruction_list
+  : 'ENTRY' name param_list_to_shape instruction_list
+  | name param_list_to_shape instruction_list
+  | 'ENTRY' name instruction_list
+  | name instruction_list
   ;
 
 instruction_list
@@ -41,6 +43,7 @@ operands1
   ;
 operand
   : shape name
+  | name
   ;
 
 attributes
@@ -60,6 +63,10 @@ attribute_value
   | '{' sub_attributes '}'
   ;
 
+param_list_to_shape
+  : param_list '->' shape
+  ;
+
 param_list
   : '(' param_list1 ')'
   ;
index 710e76f53dc61066fe979a7b24726a83ee447eb4..e47c3b03ed7e872ce10bb37a5fa9bb08f36636ac 100644 (file)
@@ -171,6 +171,7 @@ class HloParser {
   bool ParseInt64List(const TokKind start, const TokKind end,
                       const TokKind delim, std::vector<int64>* result);
 
+  bool ParseParamListToShape(Shape* shape, LocTy* shape_loc);
   bool ParseParamList();
   bool ParseName(string* result);
   bool ParseAttributeName(string* result);
@@ -184,6 +185,12 @@ class HloParser {
   bool ParseBool(bool* result);
   bool ParseToken(TokKind kind, const string& msg);
 
+  // Returns true if the current token is the beginning of a shape.
+  bool CanBeShape();
+  // Returns true if the current token is the beginning of a
+  // param_list_to_shape.
+  bool CanBeParamListToShape();
+
   // Logs the current parsing line and the given message. Always returns false.
   bool TokenError(StringPiece msg);
   bool Error(LocTy loc, StringPiece msg);
@@ -267,7 +274,7 @@ bool HloParser::ParseComputations() {
   return true;
 }
 
-// computation ::= ('ENTRY')? name param_list '->' shape instruction_list
+// computation ::= ('ENTRY')? name (param_list_to_shape)? instruction_list
 bool HloParser::ParseComputation() {
   const bool is_entry_computation = EatIfPresent(TokKind::kw_ENTRY);
   string name;
@@ -277,14 +284,14 @@ bool HloParser::ParseComputation() {
   }
   auto builder = MakeUnique<HloComputation::Builder>(name);
 
+  LocTy shape_loc = nullptr;
   Shape shape;
-  string root_name;
-  if (!ParseParamList() || !ParseToken(TokKind::kArrow, "expects '->'")) {
+  if (CanBeParamListToShape() && !ParseParamListToShape(&shape, &shape_loc)) {
     return false;
   }
 
-  LocTy shape_ty = lexer_.GetLoc();
-  if (!ParseShape(&shape) || !ParseInstructionList(builder.get(), &root_name)) {
+  string root_name;
+  if (!ParseInstructionList(builder.get(), &root_name)) {
     return false;
   }
 
@@ -311,9 +318,10 @@ bool HloParser::ParseComputation() {
     CHECK_EQ(root, computation->root_instruction());
   }
 
-  if (!ShapeUtil::Compatible(root->shape(), shape)) {
+  // If param_list_to_shape was present, check compatibility.
+  if (shape_loc != nullptr && !ShapeUtil::Compatible(root->shape(), shape)) {
     return Error(
-        shape_ty,
+        shape_loc,
         StrCat("Shape of computation ", name, ", ",
                ShapeUtil::HumanString(shape),
                ", is not compatible with that of its root instruction ",
@@ -1438,7 +1446,7 @@ bool HloParser::ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
 // operands1
 //   ::= /*empty*/
 //   ::= operand (, operand)*
-// operand ::= shape name
+// operand ::= (shape)? name
 bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands) {
   if (!ParseToken(TokKind::kLparen,
                   "expects '(' at the beginning of operands")) {
@@ -1449,9 +1457,14 @@ bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands) {
   } else {
     do {
       LocTy loc = lexer_.GetLoc();
-      Shape shape;
       string name;
-      if (!ParseShape(&shape) || !ParseName(&name)) {
+      if (CanBeShape()) {
+        Shape shape;
+        if (!ParseShape(&shape)) {
+          return false;
+        }
+      }
+      if (!ParseName(&name)) {
         return false;
       }
       HloInstruction* instruction =
@@ -1976,6 +1989,19 @@ bool HloParser::ParseInt64List(const TokKind start, const TokKind end,
       end, StrCat("expects an int64 list to end with ", TokKindToString(end)));
 }
 
+// param_list_to_shape ::= param_list '->' shape
+bool HloParser::ParseParamListToShape(Shape* shape, LocTy* shape_loc) {
+  if (!ParseParamList() || !ParseToken(TokKind::kArrow, "expects '->'")) {
+    return false;
+  }
+  *shape_loc = lexer_.GetLoc();
+  return ParseShape(shape);
+}
+
+bool HloParser::CanBeParamListToShape() {
+  return lexer_.GetKind() == TokKind::kLparen;
+}
+
 // param_list ::= '(' param_list1 ')'
 // param_list1
 //   ::= /*empty*/
@@ -2032,6 +2058,13 @@ bool HloParser::ParseShape(Shape* result) {
   return true;
 }
 
+bool HloParser::CanBeShape() {
+  // A non-tuple shape starts with a kShape token; a tuple shape starts with
+  // '('.
+  return lexer_.GetKind() == TokKind::kShape ||
+         lexer_.GetKind() == TokKind::kLparen;
+}
+
 bool HloParser::ParseName(string* result) {
   VLOG(1) << "ParseName";
   if (lexer_.GetKind() != TokKind::kName) {
index 5c12a991cca1644bcf4e4316e42963a5d84a1a8b..29b3cc83e7ff8276f3c1db7d2b5d02c70d2c199f 100644 (file)
@@ -405,44 +405,6 @@ ENTRY %Concat2x3With2x5.v3 () -> f32[2,8] {
   ROOT %concatenate = f32[2,8]{1,0} concatenate(f32[2,3]{1,0} %constant, f32[2,5]{1,0} %constant.1), dimensions={1}
 }
 
-)"
-},
-// map
-{
-"Map",
-R"(HloModule MapBinaryAdder_module:
-
-%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
-  %lhs = f32[] parameter(0)
-  %rhs = f32[] parameter(1)
-  ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
-}
-
-ENTRY %MapBinaryAdder.v3 (param0: f32[4], param1: f32[4]) -> f32[4] {
-  %param0 = f32[4]{0} parameter(0)
-  %param1 = f32[4]{0} parameter(1)
-  ROOT %map = f32[4]{0} map(f32[4]{0} %param0, f32[4]{0} %param1), to_apply=%add_F32.v3
-}
-
-)"
-},
-// reduce
-{
-"Reduce",
-R"(HloModule ReduceR3ToR2_module:
-
-%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
-  %lhs = f32[] parameter(0)
-  %rhs = f32[] parameter(1)
-  ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
-}
-
-ENTRY %ReduceR3ToR2.v3 (input: f32[8,16,256]) -> f32[8,16] {
-  %input = f32[8,16,256]{2,1,0} parameter(0)
-  %constant = f32[] constant(0)
-  ROOT %reduce = f32[8,16]{1,0} reduce(f32[8,16,256]{2,1,0} %input, f32[] %constant), dimensions={2}, to_apply=%add_F32.v3
-}
-
 )"
 },
 // select and scatter
@@ -664,6 +626,51 @@ ENTRY %fusion.v3 () -> f32[3,2,1,1] {
   ROOT %fusion = f32[3,2,1,1]{3,2,1,0} fusion(f32[3,2,1,1]{3,2,1,0} %constant, f32[2]{0} %constant.1), kind=kLoop, calls=%fused_computation
 }
 
+)"
+}
+  });
+  // clang-format on
+}
+
+std::vector<TestData> CreateShortTestCases() {
+  // clang-format off
+  return std::vector<TestData>({
+// map
+{
+"Map",
+R"(HloModule MapBinaryAdder_module:
+
+%add_F32.v3 {
+  %lhs = f32[] parameter(0)
+  %rhs = f32[] parameter(1)
+  ROOT %add = f32[] add(%lhs, %rhs)
+}
+
+ENTRY %MapBinaryAdder.v3 {
+  %param0 = f32[4]{0} parameter(0)
+  %param1 = f32[4]{0} parameter(1)
+  ROOT %map = f32[4]{0} map(%param0, %param1), to_apply=%add_F32.v3
+}
+
+)"
+},
+// reduce
+{
+"Reduce",
+R"(HloModule ReduceR3ToR2_module:
+
+%add_F32.v3 {
+  %lhs = f32[] parameter(0)
+  %rhs = f32[] parameter(1)
+  ROOT %add = f32[] add(%lhs, %rhs)
+}
+
+ENTRY %ReduceR3ToR2.v3 {
+  %input = f32[8,16,256]{2,1,0} parameter(0)
+  %constant = f32[] constant(0)
+  ROOT %reduce = f32[8,16]{1,0} reduce(%input, %constant), dimensions={2}, to_apply=%add_F32.v3
+}
+
 )"
 },
 // infeed/outfeed
@@ -671,11 +678,11 @@ ENTRY %fusion.v3 () -> f32[3,2,1,1] {
 "InfeedOutfeed",
 R"(HloModule outfeed_module:
 
-ENTRY %InfeedToOutfeed () -> (u32[3], pred[]) {
+ENTRY %InfeedToOutfeed {
   %infeed = (u32[3]{0}, pred[]) infeed()
-  %outfeed = () outfeed((u32[3]{0}, pred[]) %infeed)
+  %outfeed = () outfeed(%infeed)
   ROOT %infeed.1 = (u32[3]{0}, pred[]) infeed()
-  %outfeed.1 = () outfeed((u32[3]{0}, pred[]) %infeed.1)
+  %outfeed.1 = () outfeed(%infeed.1)
 }
 
 )"
@@ -685,10 +692,10 @@ ENTRY %InfeedToOutfeed () -> (u32[3], pred[]) {
 "Rng",
 R"(HloModule rng_module:
 
-ENTRY %Rng () -> f32[8] {
+ENTRY %Rng {
   %constant = f32[] constant(0)
   %constant.1 = f32[] constant(1)
-  ROOT %rng = f32[8]{0} rng(f32[] %constant, f32[] %constant.1), distribution=rng_uniform
+  ROOT %rng = f32[8]{0} rng(%constant, %constant.1), distribution=rng_uniform
 }
 
 )"
@@ -698,9 +705,9 @@ ENTRY %Rng () -> f32[8] {
 "ReducePrevison",
 R"(HloModule reduce_precision:
 
-ENTRY %ReducePrecision () -> f32[1] {
+ENTRY %ReducePrecision {
   %constant = f32[1]{0} constant({3.14159})
-  ROOT %reduce-precision = f32[1]{0} reduce-precision(f32[1]{0} %constant), exponent_bits=8, mantissa_bits=10
+  ROOT %reduce-precision = f32[1]{0} reduce-precision(%constant), exponent_bits=8, mantissa_bits=10
 }
 
 )"
@@ -710,34 +717,33 @@ ENTRY %ReducePrecision () -> f32[1] {
 "Conditional",
 R"(HloModule conditional:
 
-%Negate (x: f32[]) -> f32[] {
+%Negate {
   %x = f32[] parameter(0)
-  ROOT %negate = f32[] negate(f32[] %x)
+  ROOT %negate = f32[] negate(%x)
 }
 
-%Identity (y: f32[]) -> f32[] {
+%Identity {
   %y = f32[] parameter(0)
-  ROOT %copy = f32[] copy(f32[] %y)
+  ROOT %copy = f32[] copy(%y)
 }
 
-ENTRY %Parameters1.v4 () -> f32[] {
+ENTRY %Parameters1.v4 {
   %constant = pred[] constant(true)
   %constant.1 = f32[] constant(56)
   %constant.2 = f32[] constant(12)
-  ROOT %conditional = f32[] conditional(pred[] %constant, f32[] %constant.1, f32[] %constant.2), true_computation=%Negate, false_computation=%Identity
+  ROOT %conditional = f32[] conditional(%constant, %constant.1, %constant.2), true_computation=%Negate, false_computation=%Identity
 }
 
 )"
 },
-
 // CustomCall
 {
 "CustomCall",
 R"(HloModule custom_call:
 
-ENTRY %CustomCall () -> f32[1,2,3] {
+ENTRY %CustomCall {
   %constant = f32[1]{0} constant({12345})
-  ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar"
+  ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(%constant), custom_call_target="foo\"bar"
 }
 
 )"
@@ -747,9 +753,9 @@ ENTRY %CustomCall () -> f32[1,2,3] {
 "NonDefaultNames",
 R"(HloModule add_constants_module:
 
-ENTRY %add_constants () -> f32[] {
+ENTRY %add_constants {
   %foo = f32[] constant(3.14)
-  ROOT %bar = f32[] add(f32[] %foo, f32[] %foo)
+  ROOT %bar = f32[] add(%foo, %foo)
 }
 
 )"
@@ -778,12 +784,29 @@ class HloParserTest : public ::testing::Test,
   }
 };
 
+class HloParserShortTest : public HloParserTest {
+ protected:
+  void ExpectEqualShort() {
+    const string& original = GetParam().module_string;
+    auto result = Parse(original);
+    TF_ASSERT_OK(result.status());
+    EXPECT_EQ(original,
+              result.ValueOrDie()->ToString(HloPrintOptions::ShortParsable()));
+  }
+};
+
 TEST_P(HloParserTest, Run) { ExpectEqual(); }
 
+TEST_P(HloParserShortTest, Run) { ExpectEqualShort(); }
+
 INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTest,
                         ::testing::ValuesIn(CreateTestCases()),
                         TestDataToString);
 
+INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserShortTest,
+                        ::testing::ValuesIn(CreateShortTestCases()),
+                        TestDataToString);
+
 TEST_F(HloParserTest, Empty) {
   const string original = "";
   auto result = Parse(original);