Implement additional options to control the string output of HloInstruction and HloCo...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Sat, 12 May 2018 00:53:06 +0000 (17:53 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sat, 12 May 2018 00:55:58 +0000 (17:55 -0700)
PiperOrigin-RevId: 196334340

tensorflow/compiler/xla/service/hlo_computation.cc
tensorflow/compiler/xla/service/hlo_computation_test.cc
tensorflow/compiler/xla/service/hlo_graph_dumper.cc
tensorflow/compiler/xla/service/hlo_instruction.cc
tensorflow/compiler/xla/service/hlo_instruction.h
tensorflow/compiler/xla/service/hlo_instruction_test.cc

index 05dceb1..63c3dc4 100644 (file)
@@ -365,25 +365,38 @@ std::list<HloComputation*> HloComputation::MakeEmbeddedComputationsList()
 string HloComputation::ToString(const HloPrintOptions& options) const {
   std::ostringstream s;
   for (int i = 0; i < options.indent_amount(); i++) {
-    s << "    ";
+    s << "  ";
   }
-  if (options.print_percent()) {
-    s << "%";
+
+  if (!options.is_in_nested_computation()) {
+    if (options.print_percent()) {
+      s << "%";
+    }
+    s << name() << " ";
   }
-  s << name();
+
   if (options.print_program_shape()) {
-    s << " " << ShapeUtil::HumanString(ComputeProgramShape());
-  }
-  s << " {\n";
-  for (const HloInstruction* instruction : MakeInstructionPostOrder()) {
-    for (int i = 0; i < options.indent_amount(); i++) {
-      s << "    ";
+    s << ShapeUtil::HumanString(ComputeProgramShape()) << " ";
+  }
+  s << "{\n";
+  {
+    // Print the instructions in this computation.
+    HloPrintOptions new_options = options;
+    new_options.set_indent_amount(options.indent_amount() + 1)
+        .set_is_in_nested_computation(true);
+    CanonicalNameMap name_map;
+    for (const HloInstruction* instruction : MakeInstructionPostOrder()) {
+      for (int i = 0; i < new_options.indent_amount(); i++) {
+        s << "  ";
+      }
+      s << (instruction == root_instruction_ ? "ROOT " : "")
+        << instruction->ToStringWithCanonicalNameMap(new_options, &name_map)
+        << "\n";
     }
-    s << "  " << (instruction == root_instruction_ ? "ROOT " : "")
-      << instruction->ToString(options) << "\n";
   }
+
   for (int i = 0; i < options.indent_amount(); i++) {
-    s << "    ";
+    s << "  ";
   }
   s << "}";
   return s.str();
index 7b7588f..25469a5 100644 (file)
@@ -550,6 +550,108 @@ TEST_F(HloComputationTest, Reachability) {
   EXPECT_FALSE(reachability->IsReachable(constant2, copy));
 }
 
+TEST_F(HloComputationTest, Stringification) {
+  const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10});
+  const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10});
+  const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20});
+  const Shape sout = ShapeUtil::MakeShape(F32, {5, 20});
+
+  HloComputation::Builder builder("TransposeDot");
+  HloInstruction* x =
+      builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x"));
+  HloInstruction* y =
+      builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y"));
+  HloInstruction* reshape =
+      builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0}));
+  DotDimensionNumbers dot_dnums;
+  dot_dnums.add_lhs_contracting_dimensions(1);
+  dot_dnums.add_rhs_contracting_dimensions(0);
+  builder.AddInstruction(
+      HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+  auto module = CreateNewModule();
+  auto* computation = module->AddEntryComputation(builder.Build());
+
+  auto options = HloPrintOptions().set_print_metadata(false);
+  EXPECT_EQ(computation->ToString(options),
+            R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] {
+  %x = f32[5,10]{1,0} parameter(0)
+  %y = f32[20,10]{1,0} parameter(1)
+  %transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0}
+  ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})");
+}
+
+TEST_F(HloComputationTest, StringificationIndent) {
+  const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10});
+  const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10});
+  const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20});
+  const Shape sout = ShapeUtil::MakeShape(F32, {5, 20});
+
+  HloComputation::Builder builder("TransposeDot");
+  HloInstruction* x =
+      builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x"));
+  HloInstruction* y =
+      builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y"));
+  HloInstruction* reshape =
+      builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0}));
+  DotDimensionNumbers dot_dnums;
+  dot_dnums.add_lhs_contracting_dimensions(1);
+  dot_dnums.add_rhs_contracting_dimensions(0);
+  builder.AddInstruction(
+      HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+  auto module = CreateNewModule();
+  auto* computation = module->AddEntryComputation(builder.Build());
+
+  auto options =
+      HloPrintOptions().set_print_metadata(false).set_indent_amount(2);
+  EXPECT_EQ(computation->ToString(options),
+            R"(    %TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] {
+      %x = f32[5,10]{1,0} parameter(0)
+      %y = f32[20,10]{1,0} parameter(1)
+      %transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0}
+      ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+    })");
+}
+
+TEST_F(HloComputationTest, StringificationCanonical) {
+  const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10});
+  const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10});
+  const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20});
+  const Shape sout = ShapeUtil::MakeShape(F32, {5, 20});
+
+  HloComputation::Builder builder("TransposeDot");
+  HloInstruction* x =
+      builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x"));
+  HloInstruction* y =
+      builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y"));
+  HloInstruction* reshape =
+      builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0}));
+  DotDimensionNumbers dot_dnums;
+  dot_dnums.add_lhs_contracting_dimensions(1);
+  dot_dnums.add_rhs_contracting_dimensions(0);
+  builder.AddInstruction(
+      HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+  auto module = CreateNewModule();
+  auto* computation = module->AddEntryComputation(builder.Build());
+
+  auto options = HloPrintOptions().set_print_metadata(false);
+  EXPECT_EQ(computation->ToString(options),
+            R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] {
+  %x = f32[5,10]{1,0} parameter(0)
+  %y = f32[20,10]{1,0} parameter(1)
+  %transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0}
+  ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})");
+
+  options = HloPrintOptions().Canonical();
+  EXPECT_EQ(computation->ToString(options), R"(TransposeDot {
+  tmp_0 = f32[5,10]{1,0} parameter(0)
+  tmp_1 = f32[20,10]{1,0} parameter(1)
+  tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
+  ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})");
+}
+
 }  // namespace
 
 }  // namespace xla
index 8dc3b83..17e3c40 100644 (file)
@@ -1104,7 +1104,8 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
   // Get the instruction's extra attributes excluding the names of its
   // subcomputations, since those are drawn explicitly in the graph.
   for (const auto& line : instr->ExtraAttributesToString(
-           HloPrintOptions().set_print_subcomputation_references(false))) {
+           HloPrintOptions().set_print_subcomputation_mode(
+               HloPrintOptions::PrintSubcomputationMode::kOff))) {
     lines.push_back(HtmlLikeStringSanitize(line));
   }
 
index 8d0fd65..a269034 100644 (file)
@@ -2106,13 +2106,40 @@ string PrintName(const string& name, const HloPrintOptions& options) {
 }  // namespace
 
 string HloInstruction::ToString(const HloPrintOptions& options) const {
-  string result =
-      StrCat(PrintName(name(), options), " = ",
-             ShapeUtil::HumanStringWithLayout(shape()), " ",
-             HloOpcodeString(opcode()), "(", OperandsToString(options), ")");
+  CanonicalNameMap new_map;
+  return ToStringWithCanonicalNameMap(options, &new_map);
+}
+
+string HloInstruction::ToStringWithCanonicalNameMap(
+    const HloPrintOptions& options,
+    CanonicalNameMap* canonical_name_map) const {
+  string result = "";
+
+  // Logic to print the instruction name (e.g. "%foo = ").
+  if (options.canonicalize_instruction_names()) {
+    if (options.is_in_nested_computation()) {
+      // If we are canonicalizing instruction names and this is a top-level
+      // HloInstruction::ToString() call, don't print an instruction name.
+      StrAppend(&result,
+                PrintName(canonical_name_map->LookupOrInsert(name()), options),
+                " = ");
+    }
+  } else {
+    StrAppend(&result, PrintName(name(), options), " = ");
+  }
+
+  // Print opcode, operand(s) and shape.
+  StrAppend(&result, ShapeUtil::HumanStringWithLayout(shape()), " ",
+            HloOpcodeString(opcode()), "(",
+            OperandsToStringWithCanonicalNameMap(options, canonical_name_map),
+            ")");
+
+  // Print additional attributes. If an instruction contains a subcomputation,
+  // the subcomputation is also printed here.
   for (const string& extra : ExtraAttributesToString(options)) {
     StrAppend(&result, ", ", extra);
   }
+
   if (options.print_metadata() &&
       (!metadata_.op_type().empty() || !metadata_.op_name().empty() ||
        !metadata_.source_file().empty())) {
@@ -2125,6 +2152,13 @@ string HloInstruction::ToString(const HloPrintOptions& options) const {
 }
 
 string HloInstruction::OperandsToString(const HloPrintOptions& options) const {
+  CanonicalNameMap new_map;
+  return OperandsToStringWithCanonicalNameMap(options, &new_map);
+}
+
+string HloInstruction::OperandsToStringWithCanonicalNameMap(
+    const HloPrintOptions& options,
+    CanonicalNameMap* canonical_name_map) const {
   string operands;
   if (opcode() == HloOpcode::kConstant) {
     // For constants, show the actual value in place of an empty operand list.
@@ -2164,7 +2198,14 @@ string HloInstruction::OperandsToString(const HloPrintOptions& options) const {
       if (options.print_operand_shape()) {
         str.push_back(ShapeUtil::HumanStringWithLayout(operand->shape()));
       }
-      if (!options.compact_operands()) {
+
+      // In a top-level HloInstruction::ToString() call, the operand name is not
+      // part of the canonical string.
+      if (options.canonicalize_instruction_names() &&
+          options.is_in_nested_computation()) {
+        str.push_back(PrintName(
+            canonical_name_map->LookupOrInsert(operand->name()), options));
+      } else if (!options.compact_operands()) {
         str.push_back(PrintName(operand->name(), options));
       }
       StrAppend(out, Join(str, " "));
@@ -2233,7 +2274,8 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
     extra.push_back(StrCat("fft_length={", Join(fft_length(), ","), "}"));
   }
 
-  if (options.print_subcomputation_references()) {
+  if (options.print_subcomputation_mode() ==
+      HloPrintOptions::PrintSubcomputationMode::kNameOnly) {
     if (opcode() == HloOpcode::kWhile) {
       extra.push_back(
           StrCat("condition=", PrintName(while_condition()->name(), options)));
@@ -2261,8 +2303,45 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
                                      PrintName(computation->name(), options));
                          })));
     }
+  } else if (options.print_subcomputation_mode() ==
+             HloPrintOptions::PrintSubcomputationMode::kFullBodies) {
+    HloPrintOptions new_options = options;
+    new_options.set_is_in_nested_computation(true);
+    switch (opcode()) {
+      case HloOpcode::kWhile:
+        extra.push_back(
+            StrCat("condition=\n", while_condition()->ToString(new_options)));
+        extra.push_back(StrCat("body=\n", while_body()->ToString(new_options)));
+        break;
+      case HloOpcode::kSelectAndScatter:
+        extra.push_back(StrCat("select=\n", select()->ToString(new_options)));
+        extra.push_back(StrCat("scatter=\n", scatter()->ToString(new_options)));
+        break;
+      case HloOpcode::kConditional:
+        extra.push_back(StrCat("true_computation=\n",
+                               true_computation()->ToString(new_options)));
+        extra.push_back(StrCat("false_computation=\n",
+                               false_computation()->ToString(new_options)));
+        break;
+      case HloOpcode::kCall:
+      case HloOpcode::kMap:
+      case HloOpcode::kReduceWindow:
+      case HloOpcode::kReduce:
+        extra.push_back(
+            StrCat("to_apply=\n", to_apply()->ToString(new_options)));
+        break;
+      default:
+        if (!called_computations().empty()) {
+          extra.push_back(
+              StrCat("calls=\n",
+                     Join(called_computations(), ", ",
+                          [&](string* out, const HloComputation* computation) {
+                            StrAppend(out, computation->ToString(new_options));
+                          })));
+        }
+        break;
+    }
   }
-
   if (opcode() == HloOpcode::kSend || opcode() == HloOpcode::kRecv ||
       opcode() == HloOpcode::kSendDone || opcode() == HloOpcode::kRecvDone) {
     extra.push_back(StrCat("channel_id=", channel_id_));
@@ -2300,7 +2379,7 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
   }
 
   // By contract, we print the custom call target even if
-  // !options.print_subcomputation_references(), because the call target is not
+  // options.print_subcomputation_mode() == kOff, because the call target is not
   // an HloComputation.
   if (opcode() == HloOpcode::kCustomCall) {
     extra.push_back(
index 2e5895e..0089cae 100644 (file)
@@ -60,23 +60,31 @@ class HloModule;
 // A bunch of switches that control how the hlo text should be printed.
 class HloPrintOptions {
  public:
+  enum class PrintSubcomputationMode {
+    kOff,         // Do not print anything about subcomputations.
+    kNameOnly,    // Only print the name of subcomputations.
+    kFullBodies,  // Print the full bodies of subcomputations.
+  };
+
   // Constructs the default print options: don't print large constants, don't
   // compact operands, no indentation.
   HloPrintOptions()
       : print_large_constants_(false),
-        print_subcomputation_references_(true),
+        print_subcomputation_mode_(PrintSubcomputationMode::kNameOnly),
         print_metadata_(true),
         print_backend_config_(true),
         compact_operands_(false),
         print_operand_shape_(true),
         print_program_shape_(true),
         print_percent_(true),
-        indent_amount_(0) {}
+        canonicalize_instruction_names_(false),
+        indent_amount_(0),
+        is_in_nested_computation_(false) {}
 
   static HloPrintOptions ShortParsable() {
     return HloPrintOptions()
         .set_print_large_constants(true)
-        .set_print_subcomputation_references(true)
+        .set_print_subcomputation_mode(PrintSubcomputationMode::kNameOnly)
         .set_print_metadata(false)
         .set_print_backend_config(false)
         .set_print_operand_shape(false)
@@ -84,20 +92,28 @@ class HloPrintOptions {
         .set_print_percent(false);
   }
 
+  // Options to produce the canonical string representing an isomorphic
+  // computation graph.
+  static HloPrintOptions Canonical() {
+    return HloPrintOptions()
+        .set_print_subcomputation_mode(PrintSubcomputationMode::kFullBodies)
+        .set_print_metadata(false)
+        .set_compact_operands(true)
+        .set_print_operand_shape(true)
+        .set_print_program_shape(false)
+        .set_print_percent(false)
+        .set_canonicalize_instruction_names(true);
+  }
+
   // If true, large constants will be printed out.
   HloPrintOptions& set_print_large_constants(bool value) {
     print_large_constants_ = value;
     return *this;
   }
 
-  // If true, the names of subcomputations (e.g. a fusion node's fused
-  // computation) won't be printed.  This makes the resulting text not parsable.
-  //
-  // A CustomCall's call target is printed even if
-  // print_subcomputation_references is false, because the call target isn't an
-  // HloComputation.
-  HloPrintOptions& set_print_subcomputation_references(bool value) {
-    print_subcomputation_references_ = value;
+  HloPrintOptions& set_print_subcomputation_mode(
+      PrintSubcomputationMode value) {
+    print_subcomputation_mode_ = value;
     return *this;
   }
 
@@ -138,15 +154,29 @@ class HloPrintOptions {
     return *this;
   }
 
+  // If true, canonicalizes instructions' name. Instead of using "%foo.1" as
+  // the name of an instruction, we use "%tmp_1", "%tmp_2" etc.
+  HloPrintOptions& set_canonicalize_instruction_names(bool value) {
+    canonicalize_instruction_names_ = value;
+    return *this;
+  }
+
   // The indent of the hlo text block.
   HloPrintOptions& set_indent_amount(int value) {
     indent_amount_ = value;
     return *this;
   }
 
+  // If true, indicates the instruction being printed is inside a nested
+  // computation.
+  HloPrintOptions& set_is_in_nested_computation(bool value) {
+    is_in_nested_computation_ = value;
+    return *this;
+  }
+
   bool print_large_constants() const { return print_large_constants_; }
-  bool print_subcomputation_references() const {
-    return print_subcomputation_references_;
+  PrintSubcomputationMode print_subcomputation_mode() const {
+    return print_subcomputation_mode_;
   }
   bool print_metadata() const { return print_metadata_; }
   bool print_backend_config() const { return print_metadata_; }
@@ -154,18 +184,51 @@ class HloPrintOptions {
   bool print_operand_shape() const { return print_operand_shape_; }
   bool print_program_shape() const { return print_program_shape_; }
   bool print_percent() const { return print_percent_; }
+  bool canonicalize_instruction_names() const {
+    return canonicalize_instruction_names_;
+  }
   int indent_amount() const { return indent_amount_; }
+  int is_in_nested_computation() const { return is_in_nested_computation_; }
 
  private:
   bool print_large_constants_;
-  bool print_subcomputation_references_;
+  PrintSubcomputationMode print_subcomputation_mode_;
   bool print_metadata_;
   bool print_backend_config_;
   bool compact_operands_;
   bool print_operand_shape_;
   bool print_program_shape_;
   bool print_percent_;
+  bool canonicalize_instruction_names_;
   int indent_amount_;
+  bool is_in_nested_computation_;
+};
+
+// For canonical string output, we need to have a canonical way to rename
+// each instruction and its operands. Each operand is renamed as "tmp_<xxx>",
+// where <xxx> is an index starting from 0.
+class CanonicalNameMap {
+ public:
+  CanonicalNameMap() : index(0) {}
+
+  string LookupOrInsert(const string& old_name) {
+    auto iter = canonical_name_map.find(old_name);
+    if (iter != canonical_name_map.end()) {
+      return iter->second;
+    }
+
+    string new_name = tensorflow::strings::StrCat("tmp_", index++);
+    canonical_name_map[old_name] = new_name;
+    return new_name;
+  }
+  void Clear() {
+    canonical_name_map.clear();
+    index = 0;
+  }
+
+ private:
+  int64 index;
+  tensorflow::gtl::FlatMap<string, string> canonical_name_map;
 };
 
 // HLO instructions are the IR used by the high-level compiler.
@@ -1331,6 +1394,24 @@ class HloInstruction {
                         const ShapeIndex& shape_index = {});
 
  private:
+  // Prints an instruction to a string.
+  //
+  // The canonical string representation needs to name operands and instruction
+  // names in a consistent way. This is implemented through the
+  // canonical_name_map.
+  string ToStringWithCanonicalNameMap(
+      const HloPrintOptions& options,
+      CanonicalNameMap* canonical_name_map) const;
+
+  // Prints an operand to a string.
+  string OperandsToStringWithCanonicalNameMap(
+      const HloPrintOptions& options,
+      CanonicalNameMap* canonical_name_map) const;
+
+  // Allow HloInstruction to access the ToStringWithCanonicalNameMap() and
+  // OperandsToStringWithCanonicalNameMap() functions.
+  friend class HloComputation;
+
   enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse };
 
   // Helper class for computing OperandElementUse for kFusion.
index 909cdc0..a61c472 100644 (file)
@@ -1336,5 +1336,163 @@ TEST_F(HloInstructionTest, StringifyGather_1) {
             "index_vector_dim=2, window_bounds={30,29,28,27,26}");
 }
 
+TEST_F(HloInstructionTest, CanonnicalStringificationFusion) {
+  // Tests stringification of a simple op, fusion, while, and conditional.
+  const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10});
+  const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10});
+  const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20});
+  const Shape sout = ShapeUtil::MakeShape(F32, {5, 20});
+
+  HloComputation::Builder builder("TransposeDot");
+  HloInstruction* x =
+      builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x"));
+  HloInstruction* y =
+      builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y"));
+  HloInstruction* reshape =
+      builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0}));
+  DotDimensionNumbers dot_dnums;
+  dot_dnums.add_lhs_contracting_dimensions(1);
+  dot_dnums.add_rhs_contracting_dimensions(0);
+  HloInstruction* dot = builder.AddInstruction(
+      HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+
+  auto options = HloPrintOptions().Canonical();
+
+  EXPECT_EQ(dot->ToString(options),
+            "f32[5,20]{1,0} dot(f32[5,10]{1,0}, f32[10,20]{1,0}), "
+            "lhs_contracting_dims={1}, rhs_contracting_dims={0}");
+
+  auto module = CreateNewModule();
+  auto* computation = module->AddEntryComputation(builder.Build());
+  HloInstruction* fusion = computation->CreateFusionInstruction(
+      {dot, reshape}, HloInstruction::FusionKind::kLoop);
+
+  EXPECT_EQ(
+      fusion->ToString(options),
+      R"(f32[5,20]{1,0} fusion(f32[5,10]{1,0}, f32[20,10]{1,0}), kind=kLoop, calls=
+{
+  tmp_0 = f32[5,10]{1,0} parameter(0)
+  tmp_1 = f32[20,10]{1,0} parameter(1)
+  tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
+  ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+})");
+}
+
+TEST_F(HloInstructionTest, CanonnicalStringificationWhile) {
+  // Tests stringification of a simple op, fusion, while, and conditional.
+  const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10});
+  const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10});
+  const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20});
+  const Shape sout = ShapeUtil::MakeShape(F32, {5, 20});
+
+  HloComputation::Builder builder("TransposeDot");
+  HloInstruction* x =
+      builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x"));
+  HloInstruction* y =
+      builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y"));
+  HloInstruction* reshape =
+      builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0}));
+  DotDimensionNumbers dot_dnums;
+  dot_dnums.add_lhs_contracting_dimensions(1);
+  dot_dnums.add_rhs_contracting_dimensions(0);
+  HloInstruction* dot = builder.AddInstruction(
+      HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+
+  auto module = CreateNewModule();
+  auto* computation = module->AddEntryComputation(builder.Build());
+  computation->CreateFusionInstruction({dot, reshape},
+                                       HloInstruction::FusionKind::kLoop);
+
+  HloInstruction* loop = builder.AddInstruction(
+      HloInstruction::CreateWhile(sout, computation, computation, x));
+
+  auto options = HloPrintOptions().Canonical();
+  EXPECT_EQ(loop->ToString(options),
+            R"(f32[5,20]{1,0} while(f32[5,10]{1,0}), condition=
+{
+  tmp_0 = f32[5,10]{1,0} parameter(0)
+  tmp_1 = f32[20,10]{1,0} parameter(1)
+  ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
+  {
+    tmp_0 = f32[5,10]{1,0} parameter(0)
+    tmp_1 = f32[20,10]{1,0} parameter(1)
+    tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
+    ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  }
+}, body=
+{
+  tmp_0 = f32[5,10]{1,0} parameter(0)
+  tmp_1 = f32[20,10]{1,0} parameter(1)
+  ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
+  {
+    tmp_0 = f32[5,10]{1,0} parameter(0)
+    tmp_1 = f32[20,10]{1,0} parameter(1)
+    tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
+    ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  }
+})");
+}
+
+TEST_F(HloInstructionTest, CanonnicalStringificationConditional) {
+  // Tests stringification of a simple op, fusion, while, and conditional.
+  const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10});
+  const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10});
+  const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20});
+  const Shape sout = ShapeUtil::MakeShape(F32, {5, 20});
+
+  HloComputation::Builder builder("TransposeDot");
+  HloInstruction* x =
+      builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x"));
+  HloInstruction* y =
+      builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y"));
+  HloInstruction* reshape =
+      builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0}));
+  DotDimensionNumbers dot_dnums;
+  dot_dnums.add_lhs_contracting_dimensions(1);
+  dot_dnums.add_rhs_contracting_dimensions(0);
+  HloInstruction* dot = builder.AddInstruction(
+      HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+
+  auto module = CreateNewModule();
+  auto* computation = module->AddEntryComputation(builder.Build());
+  computation->CreateFusionInstruction({dot, reshape},
+                                       HloInstruction::FusionKind::kLoop);
+
+  builder.AddInstruction(
+      HloInstruction::CreateWhile(sout, computation, computation, x));
+
+  auto pred = builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
+  HloInstruction* conditional =
+      builder.AddInstruction(HloInstruction::CreateConditional(
+          sout, pred, x, computation, x, computation));
+  auto options = HloPrintOptions().Canonical();
+  EXPECT_EQ(
+      conditional->ToString(options),
+      R"(f32[5,20]{1,0} conditional(pred[], f32[5,10]{1,0}, f32[5,10]{1,0}), true_computation=
+{
+  tmp_0 = f32[5,10]{1,0} parameter(0)
+  tmp_1 = f32[20,10]{1,0} parameter(1)
+  ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
+  {
+    tmp_0 = f32[5,10]{1,0} parameter(0)
+    tmp_1 = f32[20,10]{1,0} parameter(1)
+    tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
+    ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  }
+}, false_computation=
+{
+  tmp_0 = f32[5,10]{1,0} parameter(0)
+  tmp_1 = f32[20,10]{1,0} parameter(1)
+  ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
+  {
+    tmp_0 = f32[5,10]{1,0} parameter(0)
+    tmp_1 = f32[20,10]{1,0} parameter(1)
+    tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
+    ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  }
+})");
+}
+
 }  // namespace
 }  // namespace xla