instructions_by_name.at("e")));
}
+// The current scheduler is suboptimal, in that it does not account for the
+// memory used by subcomputations when choosing a schedule.
+// This test demonstrates the current behavior.
+// We are working on improving it (b/65409243).
+//
+// The test builds a while loop (with small condition/body subcomputations)
+// next to an independent transpose, schedules the module with a byte-size
+// cost function, and pins down the relative order the scheduler picks today.
+TEST_F(HloSchedulingTest, SubcomputationsNotAccounted) {
+  // %WhileCond (cond_param: f32[4]) -> pred[] {
+  //   %cond_param = f32[4]{0} parameter(0)
+  //   %constant = f32[4]{0} constant({0, 0, 0, 0})
+  //   ROOT %not-equal-to = pred[] not-equal-to(
+  //     f32[4]{0} %cond_param, f32[4]{0} %constant)
+  // }
+  // %WhileBody (body_param: f32[4]) -> f32[4] {
+  //   %body_param = f32[4]{0} parameter(0)
+  //   %constant.1 = f32[4]{0} constant({1, 1, 1, 1})
+  //   ROOT %subtract = f32[4]{0} subtract(
+  //     f32[4]{0} %body_param, f32[4]{0} %constant.1)
+  // }
+  // %SubcomputationsNotAccounted () -> f32[2,4] {
+  //   %constant.3 = f32[2,4]{1,0} constant(
+  //     f32[2,4] { { 1, 2, 3, 4 }, { 1, 2, 3, 4 } })
+  //   %transpose = f32[2,4]{1,0} transpose(
+  //     f32[2,4]{1,0} %constant.3), dimensions={0,1}
+  //   %constant.2 = f32[4]{0} constant({1, 1, 1, 1})
+  //   %while = f32[4]{0} while(f32[4]{0} %constant.2),
+  //     condition=%WhileCond,
+  //     body=%WhileBody
+  //   %broadcast = f32[2,4]{1,0} broadcast(f32[4]{0} %while), dimensions={0}
+  //   ROOT %add = f32[2,4]{1,0} add(
+  //     f32[2,4]{1,0} %transpose, f32[2,4]{1,0} %broadcast)
+  // }
+
+  auto module = CreateNewModule();
+  const Shape r1f32 = ShapeUtil::MakeShape(F32, {4});
+  const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 4});
+
+  // param != 0
+  // Needs 17 bytes
+  auto cond_builder = HloComputation::Builder("WhileCond");
+  HloInstruction* cond_param = cond_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, r1f32, "cond_param"));
+  // Rank-1 constant so that both operands of the comparison are f32[4].
+  HloInstruction* zero_vector = cond_builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR1<float>({0, 0, 0, 0})));
+  cond_builder.AddInstruction(HloInstruction::CreateBinary(
+      ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector));
+  auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build());
+
+  // param - 1
+  // Needs 16 bytes
+  auto body_builder = HloComputation::Builder("WhileBody");
+  HloInstruction* body_param = body_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, r1f32, "body_param"));
+  // Rank-1 constant so that the subtraction operands and result are all f32[4].
+  HloInstruction* one_vector = body_builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR1<float>({1, 1, 1, 1})));
+  body_builder.AddInstruction(HloInstruction::CreateBinary(
+      r1f32, HloOpcode::kSubtract, body_param, one_vector));
+  auto body_computation = module->AddEmbeddedComputation(body_builder.Build());
+
+  // transpose(matrix) + bcast(while)
+  auto builder = HloComputation::Builder(TestName());
+  // The loop state is f32[4]; the initial value must have the same shape as
+  // the while instruction and the condition/body parameters.
+  HloInstruction* while_init = builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR1<float>({1, 1, 1, 1})));
+  // Creates 16 bytes, ignoring subcomputations
+  HloInstruction* while_loop =
+      builder.AddInstruction(HloInstruction::CreateWhile(
+          r1f32, cond_computation, body_computation, while_init));
+
+  // Creates 32 bytes and frees 16
+  HloInstruction* bcast = builder.AddInstruction(
+      HloInstruction::CreateBroadcast(r2f32, while_loop, {0}));
+
+  HloInstruction* matrix = builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR2<float>(
+          {{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}})));
+  // Creates 32 bytes
+  HloInstruction* transpose = builder.AddInstruction(
+      HloInstruction::CreateTranspose(r2f32, matrix, {0, 1}));
+
+  // Creates 32 bytes and frees 64
+  HloInstruction* add = builder.AddInstruction(
+      HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, transpose, bcast));
+
+  module->AddEntryComputation(builder.Build());
+
+  // Schedule using the plain byte size of each buffer's shape as its cost.
+  TF_ASSERT_OK_AND_ASSIGN(
+      SequentialHloOrdering::HloModuleSequence sequence,
+      CreateMemoryMinimizingSequence(*module, [](const BufferValue& buffer) {
+        return ShapeUtil::ByteSizeOf(buffer.shape());
+      }));
+  // Verify that all instructions are in the sequence.
+  EXPECT_EQ(module->entry_computation()->instruction_count(),
+            sequence.at(module->entry_computation()).size());
+  SequentialHloOrdering ordering(module.get(), sequence);
+  // TODO(b/65409243): while_loop is scheduled first by List; it's thought to be
+  // cheaper than transpose because the temporary memory needed for
+  // subcomputations is ignored. If we count the temporary memory as part of
+  // bytes_defined, then transpose would be scheduled first. Incidentally,
+  // ignoring subcomputations results in a better schedule here.
+  EXPECT_TRUE(ordering.ExecutesBefore(while_loop, transpose));
+  EXPECT_TRUE(ordering.ExecutesBefore(bcast, transpose));
+  EXPECT_TRUE(ordering.ExecutesBefore(bcast, add));
+  EXPECT_TRUE(ordering.ExecutesBefore(transpose, add));
+}
+
} // namespace
} // namespace xla