[TF:XLA] Scheduling test which demonstrates that we are ignoring the memory needed...
author Dimitris Vardoulakis <dimvar@google.com>
Tue, 15 May 2018 05:28:06 +0000 (22:28 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Tue, 15 May 2018 05:30:27 +0000 (22:30 -0700)
PiperOrigin-RevId: 196618347

tensorflow/compiler/xla/service/hlo_scheduling_test.cc

index 92df7c1..4e956af 100644
@@ -190,5 +190,108 @@ ENTRY root {
                                       instructions_by_name.at("e")));
 }
 
+// The current scheduler is suboptimal, in that it does not account for the
+// memory used by subcomputations when choosing a schedule.
+// This test demonstrates the current behavior.
+// We are working on improving it (b/65409243).
+TEST_F(HloSchedulingTest, SubcomputationsNotAccounted) {
+  // %WhileCond (cond_param: f32[4]) -> pred[] {
+  //   %cond_param = f32[4]{0} parameter(0)
+  //   %constant = f32[1,4]{1,0} constant(f32[1,4] { { 0, 0, 0, 0 } })
+  //   ROOT %not-equal-to = pred[] not-equal-to(
+  //     f32[4]{0} %cond_param, f32[1,4]{1,0} %constant)
+  // }
+  // %WhileBody (body_param: f32[4]) -> f32[4] {
+  //   %body_param = f32[4]{0} parameter(0)
+  //   %constant.1 = f32[1,4]{1,0} constant(f32[1,4] { { 1, 1, 1, 1 } })
+  //   ROOT %subtract = f32[4]{0} subtract(
+  //     f32[4]{0} %body_param, f32[1,4]{1,0} %constant.1)
+  // }
+  // %SubcomputationsNotAccounted () -> f32[2,4] {
+  //   %constant.3 = f32[2,4]{1,0} constant(
+  //     f32[2,4] { { 1, 2, 3, 4 }, { 1, 2, 3, 4 } })
+  //   %transpose = f32[2,4]{1,0} transpose(
+  //     f32[2,4]{1,0} %constant.3), dimensions={0,1}
+  //   %constant.2 = f32[1,4]{1,0} constant(f32[1,4] { { 1, 1, 1, 1 } })
+  //   %while = f32[4]{0} while(f32[1,4]{1,0} %constant.2),
+  //      condition=%WhileCond,
+  //      body=%WhileBody
+  //   %broadcast = f32[2,4]{1,0} broadcast(f32[4]{0} %while), dimensions={0}
+  //   ROOT %add = f32[2,4]{1,0} add(
+  //     f32[2,4]{1,0} %transpose, f32[2,4]{1,0} %broadcast)
+  // }
+
+  auto module = CreateNewModule();
+  const Shape r1f32 = ShapeUtil::MakeShape(F32, {4});
+  const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 4});
+
+  // param != 0
+  // Needs 17 bytes: 16 for the f32[1,4] constant plus 1 for the pred[] result.
+  auto cond_builder = HloComputation::Builder("WhileCond");
+  HloInstruction* cond_param = cond_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, r1f32, "cond_param"));
+  HloInstruction* zero_vector = cond_builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR2<float>({{0, 0, 0, 0}})));
+  cond_builder.AddInstruction(HloInstruction::CreateBinary(
+      ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector));
+  auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build());
+
+  // param - 1
+  // Needs 16 bytes
+  auto body_builder = HloComputation::Builder("WhileBody");
+  HloInstruction* body_param = body_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, r1f32, "body_param"));
+  HloInstruction* one_vector = body_builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR2<float>({{1, 1, 1, 1}})));
+  body_builder.AddInstruction(HloInstruction::CreateBinary(
+      r1f32, HloOpcode::kSubtract, body_param, one_vector));
+  auto body_computation = module->AddEmbeddedComputation(body_builder.Build());
+
+  // transpose(matrix) + bcast(while)
+  auto builder = HloComputation::Builder(TestName());
+  HloInstruction* while_init = builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR2<float>({{1, 1, 1, 1}})));
+  // Creates 16 bytes, ignoring subcomputations
+  HloInstruction* while_loop =
+      builder.AddInstruction(HloInstruction::CreateWhile(
+          r1f32, cond_computation, body_computation, while_init));
+
+  // Creates 32 bytes and frees 16
+  HloInstruction* bcast = builder.AddInstruction(
+      HloInstruction::CreateBroadcast(r2f32, while_loop, {0}));
+
+  HloInstruction* matrix = builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR2<float>(
+          {{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}})));
+  // Creates 32 bytes
+  HloInstruction* transpose = builder.AddInstruction(
+      HloInstruction::CreateTranspose(r2f32, matrix, {0, 1}));
+
+  // Creates 32 bytes and frees 64
+  HloInstruction* add = builder.AddInstruction(
+      HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, transpose, bcast));
+
+  module->AddEntryComputation(builder.Build());
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      SequentialHloOrdering::HloModuleSequence sequence,
+      CreateMemoryMinimizingSequence(*module, [](const BufferValue& buffer) {
+        return ShapeUtil::ByteSizeOf(buffer.shape());
+      }));
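+  // With this size function, f32[4] is 16 bytes, f32[2,4] is 32 bytes, and
+  // pred[] is 1 byte, which is the accounting assumed in the comments above.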
+  // Verify that all instructions are in the sequence.
+  EXPECT_EQ(module->entry_computation()->instruction_count(),
+            sequence.at(module->entry_computation()).size());
+  SequentialHloOrdering ordering(module.get(), sequence);
+  // TODO(b/65409243): while_loop is scheduled first by the list scheduler; it
+  // looks cheaper than transpose because the temporary memory needed by its
+  // subcomputations is ignored. If that temporary memory were counted as part
+  // of bytes_defined, transpose would be scheduled first. Incidentally,
+  // ignoring subcomputations happens to yield a better schedule here.
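+  // Rough arithmetic: the list scheduler only sees the while's f32[4] result
+  // (16 bytes) versus the transpose's f32[2,4] (32 bytes), so the while looks
+  // cheaper; the ~33 bytes of temporaries in the condition (17) and body (16)
+  // never enter the cost.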
+  EXPECT_TRUE(ordering.ExecutesBefore(while_loop, transpose));
+  EXPECT_TRUE(ordering.ExecutesBefore(bcast, transpose));
+  EXPECT_TRUE(ordering.ExecutesBefore(bcast, add));
+  EXPECT_TRUE(ordering.ExecutesBefore(transpose, add));
+}
+
 }  // namespace
 }  // namespace xla