Fix crash in HloGraphDumper where it crashes on tuple shaped constants
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 1 May 2018 16:07:57 +0000 (09:07 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 1 May 2018 16:10:07 +0000 (09:10 -0700)
The problem is that it tries to use a special logic for 0 element constants
but the logic used to check the number of elements only supports array shapes.

PiperOrigin-RevId: 194945246

tensorflow/compiler/xla/service/hlo_graph_dumper.cc
tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc

index 516e14b..bb4db89 100644 (file)
@@ -804,7 +804,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands(
     // "{} (f32[42, 0, 10])".  The alternative, calling Literal::ToString(),
     // enumerates all of its empty dimensions (e.g.  "{ { {}, {} }, ..."), which
     // is just noise.
-    if (ShapeUtil::HasZeroElements(shape)) {
+    if (!ShapeUtil::IsTuple(shape) && ShapeUtil::HasZeroElements(shape)) {
       return Printf("{} (%s)", ShapeUtil::HumanString(constant->shape()));
     }
 
index 4843963..8e52d92 100644 (file)
@@ -131,5 +131,23 @@ TEST(HloGraphDumperTest, Constant) {
   EXPECT_THAT(graph, Not(HasSubstr("i_am_a_constant_root_instruction")));
 }
 
+TEST(HloGraphDumperTest, TupleConstant) {
+  Shape tuple_shape = ShapeUtil::MakeTupleShape(
+      {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(S32, {4, 5})});
+  HloComputation::Builder b("b");
+  auto constant = b.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateFromShape(tuple_shape)));
+  auto gte = b.AddInstruction(HloInstruction::CreateGetTupleElement(
+      ShapeUtil::MakeShape(F32, {3, 2}), constant, 0));
+
+  HloModuleConfig config;
+  HloModule m(TestName(), config);
+  HloComputation* root_computation = m.AddEntryComputation(b.Build(gte));
+  string graph = hlo_graph_dumper::DumpGraph(
+      *root_computation, /*label=*/"tuple_constant", DebugOptions());
+  EXPECT_THAT(graph, HasSubstr("tuple_constant"));
+  EXPECT_THAT(graph, HasSubstr("constant (f32[3,2], s32[4,5])"));
+}
+
 }  // anonymous namespace
 }  // namespace xla