protected:
typedef std::vector<int64> Vec;
- void RunR2Test(const Shape& shape, const int64 expected_max_partition_count) {
+ void RunR2Test(const Shape& shape, int64 max_target_partition_count,
+ const std::vector<int64>* expected_partitions) {
ShapePartitionAssigner assigner(shape);
- // Check all partitions of outer dimension.
- for (int64 i = 1; i <= expected_max_partition_count; ++i) {
- EXPECT_TRUE(ContainersEqual(Vec({i}),
- assigner.Run(/*target_partition_count=*/i)));
+ // Iterate through 1..max_target_partition_count.
+ for (int64 i = 1; i <= max_target_partition_count; ++i) {
+ std::vector<int64> actual_partitions =
+ assigner.Run(/*target_partition_count=*/i);
+ EXPECT_THAT(actual_partitions, expected_partitions[i - 1]);
}
- // Check target_partition_count > outer dimension size.
- EXPECT_TRUE(ContainersEqual(
- Vec({expected_max_partition_count}),
- assigner.Run(
- /*target_partition_count=*/expected_max_partition_count + 1)));
}
};
TEST_F(ShapePartitionAssignerTest, Shape13WithLayout10) {
- RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {1, 3}, {1, 0}), 1);
+ std::vector<int64> expected_partitions[] = {{1} /* 1 */, {1, 2} /* 2 */};
+ RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {1, 3}, {1, 0}), 2,
+ expected_partitions);
}
TEST_F(ShapePartitionAssignerTest, Shape31WithLayout01) {
- RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {3, 1}, {0, 1}), 1);
+ std::vector<int64> expected_partitions[] = {
+ {1} /* 1 */, {1, 2} /* 2 */
+ };
+ RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {3, 1}, {0, 1}), 2,
+ expected_partitions);
}
TEST_F(ShapePartitionAssignerTest, Shape53WithLayout10) {
- RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {5, 3}, {1, 0}), 5);
+ std::vector<int64> expected_partitions[] = {{1} /* 1 */, {2} /* 2 */,
+ {3} /* 3 */, {4} /* 4 */,
+ {5} /* 5 */, {3, 2} /* 6 */};
+ RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {5, 3}, {1, 0}), 6,
+ expected_partitions);
}
TEST_F(ShapePartitionAssignerTest, Shape53WithLayout01) {
- RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {5, 3}, {0, 1}), 3);
+ std::vector<int64> expected_partitions[] = {
+ {1} /* 1 */, {2} /* 2 */, {3} /* 3 */, {2, 2} /* 4 */};
+ RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {5, 3}, {0, 1}), 4,
+ expected_partitions);
}
TEST_F(ShapePartitionAssignerTest, Shape532WithLayout210) {
- Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {5, 3, 2}, {2, 1, 0});
- ShapePartitionAssigner assigner(shape);
-
- for (int64 i = 1; i <= 5; ++i) {
- EXPECT_TRUE(ContainersEqual(Vec({i}), assigner.Run(
- /*target_partition_count=*/i)));
- }
-
- EXPECT_TRUE(
- ContainersEqual(Vec({3, 2}), assigner.Run(/*target_partition_count=*/6)));
- EXPECT_TRUE(
- ContainersEqual(Vec({3, 2}), assigner.Run(/*target_partition_count=*/7)));
- EXPECT_TRUE(
- ContainersEqual(Vec({4, 2}), assigner.Run(/*target_partition_count=*/8)));
- EXPECT_TRUE(
- ContainersEqual(Vec({3, 3}), assigner.Run(/*target_partition_count=*/9)));
- EXPECT_TRUE(ContainersEqual(Vec({3, 3}),
- assigner.Run(/*target_partition_count=*/10)));
- EXPECT_TRUE(ContainersEqual(Vec({3, 3}),
- assigner.Run(/*target_partition_count=*/11)));
- EXPECT_TRUE(ContainersEqual(Vec({4, 3}),
- assigner.Run(/*target_partition_count=*/12)));
- EXPECT_TRUE(ContainersEqual(Vec({4, 3}),
- assigner.Run(/*target_partition_count=*/13)));
- EXPECT_TRUE(ContainersEqual(Vec({4, 3}),
- assigner.Run(/*target_partition_count=*/14)));
- EXPECT_TRUE(ContainersEqual(Vec({5, 3}),
- assigner.Run(/*target_partition_count=*/15)));
- EXPECT_TRUE(ContainersEqual(Vec({5, 3}),
- assigner.Run(/*target_partition_count=*/16)));
+ std::vector<int64> expected_partitions[] = {
+ {1} /* 1 */, {2} /* 2 */, {3} /* 3 */, {4} /* 4 */,
+ {5} /* 5 */, {3, 2} /* 6 */, {3, 2} /* 7 */, {4, 2} /* 8 */,
+ {3, 3} /* 9 */, {3, 3} /* 10 */, {3, 3} /* 11 */, {4, 3} /* 12 */,
+ {4, 3} /* 13 */, {4, 3} /* 14 */, {5, 3} /* 15 */, {4, 2, 2} /* 16 */};
+ RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {5, 3, 2}, {2, 1, 0}), 16,
+ expected_partitions);
}
TEST_F(ShapePartitionAssignerTest, Shape532WithLayout201) {
- Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {5, 3, 2}, {2, 0, 1});
- ShapePartitionAssigner assigner(shape);
-
- for (int64 i = 1; i <= 3; ++i) {
- EXPECT_TRUE(ContainersEqual(Vec({i}), assigner.Run(
- /*target_partition_count=*/i)));
- }
-
- EXPECT_TRUE(
- ContainersEqual(Vec({2, 2}), assigner.Run(/*target_partition_count=*/4)));
- EXPECT_TRUE(
- ContainersEqual(Vec({2, 2}), assigner.Run(/*target_partition_count=*/5)));
- EXPECT_TRUE(
- ContainersEqual(Vec({3, 2}), assigner.Run(/*target_partition_count=*/6)));
- EXPECT_TRUE(
- ContainersEqual(Vec({3, 2}), assigner.Run(/*target_partition_count=*/7)));
- EXPECT_TRUE(
- ContainersEqual(Vec({3, 2}), assigner.Run(/*target_partition_count=*/8)));
- EXPECT_TRUE(
- ContainersEqual(Vec({3, 3}), assigner.Run(/*target_partition_count=*/9)));
- EXPECT_TRUE(ContainersEqual(Vec({3, 3}),
- assigner.Run(/*target_partition_count=*/10)));
- EXPECT_TRUE(ContainersEqual(Vec({3, 3}),
- assigner.Run(/*target_partition_count=*/11)));
- EXPECT_TRUE(ContainersEqual(Vec({3, 4}),
- assigner.Run(/*target_partition_count=*/12)));
- EXPECT_TRUE(ContainersEqual(Vec({3, 4}),
- assigner.Run(/*target_partition_count=*/13)));
- EXPECT_TRUE(ContainersEqual(Vec({3, 4}),
- assigner.Run(/*target_partition_count=*/14)));
- EXPECT_TRUE(ContainersEqual(Vec({3, 5}),
- assigner.Run(/*target_partition_count=*/15)));
- EXPECT_TRUE(ContainersEqual(Vec({3, 5}),
- assigner.Run(/*target_partition_count=*/16)));
+ std::vector<int64> expected_partitions[] = {
+ {1} /* 1 */, {2} /* 2 */, {3} /* 3 */, {2, 2} /* 4 */,
+ {2, 2} /* 5 */, {3, 2} /* 6 */, {3, 2} /* 7 */, {3, 2} /* 8 */,
+ {3, 3} /* 9 */, {3, 3} /* 10 */, {3, 3} /* 11 */, {3, 4} /* 12 */,
+ {3, 4} /* 13 */, {3, 4} /* 14 */, {3, 5} /* 15 */, {3, 2, 2} /* 16 */};
+ RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {5, 3, 2}, {2, 0, 1}), 16,
+ expected_partitions);
}
class ShapePartitionIteratorTest : public HloTestBase {