From a82e0e7922d6dc657b42ef2b3a7a1a52194454c8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 1 May 2018 09:07:57 -0700 Subject: [PATCH] Fix crash in HloGraphDumper where it crashes on tuple shaped constants The problem is that it tries to use a special logic for 0 element constants but the logic used to check the number of elements only supports array shapes. PiperOrigin-RevId: 194945246 --- tensorflow/compiler/xla/service/hlo_graph_dumper.cc | 2 +- .../compiler/xla/service/hlo_graph_dumper_test.cc | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 516e14b..bb4db89 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -804,7 +804,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( // "{} (f32[42, 0, 10])". The alternative, calling Literal::ToString(), // enumerates all of its empty dimensions (e.g. "{ { {}, {} }, ..."), which // is just noise. - if (ShapeUtil::HasZeroElements(shape)) { + if (!ShapeUtil::IsTuple(shape) && ShapeUtil::HasZeroElements(shape)) { return Printf("{} (%s)", ShapeUtil::HumanString(constant->shape())); } diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc index 4843963..8e52d92 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc @@ -131,5 +131,23 @@ TEST(HloGraphDumperTest, Constant) { EXPECT_THAT(graph, Not(HasSubstr("i_am_a_constant_root_instruction"))); } +TEST(HloGraphDumperTest, TupleConstant) { + Shape tuple_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(S32, {4, 5})}); + HloComputation::Builder b("b"); + auto constant = b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateFromShape(tuple_shape))); + auto gte = b.AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::MakeShape(F32, {3, 2}), constant, 0)); + + HloModuleConfig config; + HloModule m(TestName(), config); + HloComputation* root_computation = m.AddEntryComputation(b.Build(gte)); + string graph = hlo_graph_dumper::DumpGraph( + *root_computation, /*label=*/"tuple_constant", DebugOptions()); + EXPECT_THAT(graph, HasSubstr("tuple_constant")); + EXPECT_THAT(graph, HasSubstr("constant (f32[3,2], s32[4,5])")); +} + } // anonymous namespace } // namespace xla -- 2.7.4