[TF:XLA] Add tests to show that the List scheduler handles tuples correctly (in and...
authorDimitris Vardoulakis <dimvar@google.com>
Wed, 23 May 2018 22:17:03 +0000 (15:17 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 23 May 2018 22:19:10 +0000 (15:19 -0700)
PiperOrigin-RevId: 197798787

tensorflow/compiler/xla/service/hlo_scheduling.cc
tensorflow/compiler/xla/service/hlo_scheduling_test.cc

index 29c3377..68b2cde 100644 (file)
@@ -299,6 +299,8 @@ class ListScheduler {
       auto best_it = ready_queue.end();
       --best_it;
       const HloInstruction* best = best_it->second.instruction;
+      VLOG(2) << "Schedule instruction: " << best->ToShortString()
+              << " Bytes freed: " << best_it->first.first;
       ready_queue.erase(best_it);
       ready_instructions.erase(best);
       schedule.push_back(best);
index c018ba2..0bc930f 100644 (file)
@@ -289,5 +289,100 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) {
   EXPECT_TRUE(ordering.ExecutesBefore(transpose, add));
 }
 
+TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) {
+  auto builder = HloComputation::Builder(TestName());
+  const auto TUPLE_SIZE = 1;
+  const Shape r1f32 = ShapeUtil::MakeShape(xla::F32, {6});
+
+  // Wrap lit in abs because constants are considered free by
+  // IgnoreInstruction, and it skews the accounting.
+  auto lit = builder.AddInstruction(HloInstruction::CreateConstant(
+      Literal::CreateR1<float>({1, 1, 1, 1, 1, 1})));
+  auto abs_const = builder.AddInstruction(
+      HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, lit));
+
+  auto abs_abs1 = builder.AddInstruction(
+      HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, abs_const));
+  auto tuple = builder.AddInstruction(HloInstruction::CreateTuple(
+      tensorflow::gtl::ArraySlice<HloInstruction*>({abs_abs1})));
+  auto tuple_elm = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(r1f32, tuple, 0));
+
+  auto abs_abs2 = builder.AddInstruction(
+      HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, abs_const));
+
+  builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd,
+                                                      tuple_elm, abs_abs2));
+
+  auto module = CreateNewModule();
+  module->AddEntryComputation(builder.Build());
+  TF_ASSERT_OK_AND_ASSIGN(
+      SequentialHloOrdering::HloModuleSequence sequence,
+      CreateMemoryMinimizingSequence(*module,
+                                     [&TUPLE_SIZE](const BufferValue& buffer) {
+                                       return ShapeUtil::ByteSizeOf(
+                                           buffer.shape(), TUPLE_SIZE);
+                                     },
+                                     ListMemoryScheduler));
+
+  // Verify that all instructions are in the sequence.
+  EXPECT_EQ(module->entry_computation()->instruction_count(),
+            sequence.at(module->entry_computation()).size());
+  SequentialHloOrdering ordering(module.get(), sequence);
+  // tuple allocates the tuple buffer and doesn't free anything.
+  // abs_abs2 uses the same buffer for input/output, so its bytes-freed is 0.
+  // abs_abs2 should be scheduled before tuple by List.
+  EXPECT_TRUE(ordering.ExecutesBefore(abs_abs2, tuple));
+}
+
+TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) {
+  const Shape r1f32 = ShapeUtil::MakeShape(xla::F32, {5});
+  HloComputation::Builder builder(TestName());
+
+  auto c1 = builder.AddInstruction(HloInstruction::CreateConstant(
+      Literal::CreateR1<float>({1, 1, 1, 1, 1})));
+  auto c2 = builder.AddInstruction(HloInstruction::CreateConstant(
+      Literal::CreateR1<float>({1, 2, 3, 4, 5})));
+  auto c3 = builder.AddInstruction(HloInstruction::CreateConstant(
+      Literal::CreateR1<float>({0, 2, 4, 6, 8})));
+
+  auto add = builder.AddInstruction(
+      HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, c1, c2));
+  auto mul = builder.AddInstruction(
+      HloInstruction::CreateBinary(r1f32, HloOpcode::kMultiply, add, c3));
+  auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({add, mul}));
+
+  auto tuple_elm = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(r1f32, tuple, 0));
+
+  auto exp = builder.AddInstruction(
+      HloInstruction::CreateUnary(r1f32, HloOpcode::kExp, c3));
+
+  builder.AddInstruction(
+      HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, tuple_elm, exp));
+
+  auto module = CreateNewModule();
+  auto* computation = module->AddEntryComputation(builder.Build());
+
+  auto fusion = computation->CreateFusionInstruction(
+      {tuple, mul, add}, HloInstruction::FusionKind::kLoop);
+
+  TF_ASSERT_OK_AND_ASSIGN(SequentialHloOrdering::HloModuleSequence sequence,
+                          CreateMemoryMinimizingSequence(
+                              *module,
+                              [](const BufferValue& buffer) {
+                                return ShapeUtil::ByteSizeOf(buffer.shape(), 2);
+                              },
+                              ListMemoryScheduler));
+
+  // Verify that all instructions are in the sequence.
+  EXPECT_EQ(module->entry_computation()->instruction_count(),
+            sequence.at(module->entry_computation()).size());
+  SequentialHloOrdering ordering(module.get(), sequence);
+  // fusion allocates memory for the tuple elements and doesn't free anything,
+  // so it's more expensive than exp.
+  EXPECT_TRUE(ordering.ExecutesBefore(exp, fusion));
+}
+
 }  // namespace
 }  // namespace xla