[IE][VPU]: Enable new tests for adjust_data_batch pass (#1219)
authorDaria Mityagina <daria.mityagina@intel.com>
Thu, 30 Jul 2020 10:19:33 +0000 (13:19 +0300)
committerGitHub <noreply@github.com>
Thu, 30 Jul 2020 10:19:33 +0000 (13:19 +0300)
* New tests for adjust_data_batch pass

inference-engine/tests/unit/vpu/middleend_tests/passes_tests/adjust_data_batch_tests.cpp [new file with mode: 0644]

diff --git a/inference-engine/tests/unit/vpu/middleend_tests/passes_tests/adjust_data_batch_tests.cpp b/inference-engine/tests/unit/vpu/middleend_tests/passes_tests/adjust_data_batch_tests.cpp
new file mode 100644 (file)
index 0000000..67e6217
--- /dev/null
@@ -0,0 +1,352 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <vpu/stages/mx_stage.hpp>
+#include <vpu/utils/numeric.hpp>
+
+#include "graph_transformer_tests.hpp"
+
+using namespace vpu;
+
+class VPU_AdjustDataBatchTest : public GraphTransformerTest {
+protected:
+    const int batchSize = 4;
+    TestModel testModel;
+
+public:
+    void SetUp() override {
+        ASSERT_NO_FATAL_FAILURE(GraphTransformerTest::SetUp());
+
+        ASSERT_NO_FATAL_FAILURE(InitCompileEnv());
+
+        testModel = CreateTestModel();
+    }
+
+    void RunPass() {
+        PassSet pipeline;
+        pipeline.addPass(passManager->dumpModel("initial"));
+        pipeline.addPass(passManager->adjustDataBatch());
+        pipeline.addPass(passManager->dumpModel("adjustDataBatch"));
+        pipeline.run(testModel.getBaseModel());
+    }
+
+    DataVector checkSingleLoopStart(const Data& data) {
+        EXPECT_EQ(data->desc().dim(Dim::N), 4);
+        EXPECT_EQ(data->numConsumers(), 2);
+
+        DataVector outputs;
+        for (const auto& consumer : data->consumers()) {
+            EXPECT_TRUE(consumer->type() == StageType::LoopStart || consumer->type() == StageType::LoopEnd);
+            if (consumer->type() == StageType::LoopStart) {
+                for (const auto& output : consumer->outputs()) {
+                    EXPECT_EQ(output->desc().dim(Dim::N), 1);
+                    outputs.push_back(output);
+                }
+            }
+        }
+
+        return outputs;
+    }
+
+    DataVector checkBranches(const Data& root, const std::vector<StageType>& consumersTypes) {
+        auto successors = DataVector{};
+
+        const auto& consumers = root->consumers() | asVector();
+        EXPECT_EQ(consumers.size(), consumersTypes.size());
+        for (std::size_t i = 0; i < consumers.size(); ++i) {
+            const auto& consumer = consumers[i];
+            const auto& expected = consumersTypes[i];
+            EXPECT_EQ(consumer->type(), expected);
+
+            EXPECT_EQ(consumer->numOutputs(), 1);
+            const auto& output = consumer->output(0);
+            successors.push_back(output);
+
+            if (expected == StageType::LoopStart) {
+                EXPECT_EQ(consumer->numOutputs(), 1);
+                EXPECT_EQ(output->desc().dim(Dim::N), 1);
+            } else if (expected == StageType::LoopEnd) {
+                EXPECT_EQ(output->desc().dim(Dim::N), 4);
+            }
+        }
+
+        return successors;
+    }
+
+    DataVector checkSingleLoopEnd(const Data& data) {
+        EXPECT_EQ(data->numConsumers(), 1);
+
+        const auto& consumer = data->singleConsumer();
+        EXPECT_EQ(consumer->type(), StageType::LoopEnd);
+        DataVector outputs;
+        for (const auto& output : consumer->outputs()) {
+            EXPECT_EQ(output->desc().dim(Dim::N), 4);
+            outputs.push_back(output);
+        }
+
+        return outputs;
+    }
+
+    static Data CheckSingleConnection(const Data& data, int testInd, int batch = 1) {
+        EXPECT_EQ(data->numConsumers(), 1);
+
+        const auto& consumer = data->singleConsumer();
+        EXPECT_EQ(consumer->type(), StageType::None);
+        EXPECT_EQ(consumer->attrs().get<int>("test_ind"), testInd);
+        EXPECT_EQ(consumer->numOutputs(), 1);
+        const auto& output = consumer->output(0);
+        EXPECT_EQ(output->desc().dim(Dim::N), batch);
+        return output;
+    }
+
+    static Data singleElement(const DataVector& dataObjects) {
+        EXPECT_EQ(dataObjects.size(), 1);
+        return dataObjects.front();
+    }
+};
+
+TEST_F(VPU_AdjustDataBatchTest, LinearWithBatchedInTheEnd) {
+    //
+    // [Input] -> (Split) -> (Split) -> (Split) -> (Split) -> (Split) -> (Split) -> (Batched) -> [Output]
+    //
+    const DataDesc desc{16, 16, 3, batchSize};
+
+    testModel.createInputs({desc});
+    testModel.createOutputs({desc});
+
+    for (int i = 0; i < 6; i++) {
+        if (i > 0)
+            testModel.addStage({InputInfo::fromPrevStage(i - 1)}, {OutputInfo::intermediate(desc)});
+        else
+            testModel.addStage({InputInfo::fromNetwork()}, {OutputInfo::intermediate(desc)});
+        testModel.setStageBatchInfo(i, {{0, BatchSupport::Split}});
+    }
+    testModel.addStage({InputInfo::fromPrevStage(5)}, {OutputInfo::fromNetwork(0)});
+
+    RunPass();
+
+    const auto& data0 = singleElement(checkSingleLoopStart(testModel.getInputs().at(0)));
+    const auto& data1 = CheckSingleConnection(data0, 0);
+    const auto& data2 = CheckSingleConnection(data1, 1);
+    const auto& data3 = CheckSingleConnection(data2, 2);
+    const auto& data4 = CheckSingleConnection(data3, 3);
+    const auto& data5 = CheckSingleConnection(data4, 4);
+    const auto& data6 = CheckSingleConnection(data5, 5);
+    const auto& data7 = singleElement(checkSingleLoopEnd(data6));
+
+    const auto& data8 = CheckSingleConnection(data7, 6, batchSize);
+
+    ASSERT_EQ(data8, testModel.getOutputs().at(0));
+}
+
+TEST_F(VPU_AdjustDataBatchTest, BranchedWithBatchSplitItems) {
+    //                                                                                      -> (Batched) -> [Output]
+    // [Input] -> (Split) -> (Split) -> (Split) -> (Split) -> (Split) -> (Split) -> (Split)
+    //                                                                                      -> (Batched) -> [Output]
+    const DataDesc desc{16, 16, 3, batchSize};
+
+    testModel.createInputs({desc});
+    testModel.createOutputs({desc, desc});
+
+    for (int i = 0; i < 7; i++) {
+        if (i > 0)
+            testModel.addStage({InputInfo::fromPrevStage(i - 1)}, {OutputInfo::intermediate(desc)});
+        else
+            testModel.addStage({InputInfo::fromNetwork()}, {OutputInfo::intermediate(desc)});
+        testModel.setStageBatchInfo(i, {{0, BatchSupport::Split}});
+    }
+
+    testModel.addStage({InputInfo::fromPrevStage(6)}, {OutputInfo::fromNetwork(0)});
+    testModel.addStage({InputInfo::fromPrevStage(6)}, {OutputInfo::fromNetwork(1)});
+
+    RunPass();
+
+    const auto& data0 = singleElement(checkSingleLoopStart(testModel.getInputs().at(0)));
+    const auto& data1 = CheckSingleConnection(data0, 0);
+    const auto& data2 = CheckSingleConnection(data1, 1);
+    const auto& data3 = CheckSingleConnection(data2, 2);
+    const auto& data4 = CheckSingleConnection(data3, 3);
+    const auto& data5 = CheckSingleConnection(data4, 4);
+    const auto& data6 = CheckSingleConnection(data5, 5);
+    const auto& data7 = CheckSingleConnection(data6, 6);
+    const auto& data8 = singleElement(checkSingleLoopEnd(data7));
+
+    const auto& branches = checkBranches(data8, {StageType::None, StageType::None});
+    const auto& withBatch = branches[0];
+    const auto& withBatch_1 = branches[1];
+
+    ASSERT_EQ(withBatch->producer()->attrs().get<int>("test_ind"), 7);
+    ASSERT_EQ(withBatch->desc().dim(Dim::N), batchSize);
+    ASSERT_EQ(withBatch, testModel.getOutputs().at(0));
+
+    ASSERT_EQ(withBatch_1->producer()->attrs().get<int>("test_ind"), 8);
+    ASSERT_EQ(withBatch_1->desc().dim(Dim::N), batchSize);
+    ASSERT_EQ(withBatch_1, testModel.getOutputs().at(1));
+}
+
+TEST_F(VPU_AdjustDataBatchTest, LinearWithBatchedInTheBeginning) {
+    //
+    // [Input] -> (Batched) -> (Split) -> (Split) -> (Split) -> (Split) -> (Split) -> (Split) -> [Output]
+    //
+    const DataDesc desc{16, 16, 3, batchSize};
+
+    testModel.createInputs({desc});
+    testModel.createOutputs({desc});
+
+    for (int i = 0; i < 6; i++) {
+        if (i > 0)
+            testModel.addStage({InputInfo::fromPrevStage(i - 1)}, {OutputInfo::intermediate(desc)});
+        else
+            testModel.addStage({InputInfo::fromNetwork()}, {OutputInfo::intermediate(desc)});
+        if (i > 0)
+            testModel.setStageBatchInfo(i, {{0, BatchSupport::Split}});
+    }
+
+    testModel.addStage({InputInfo::fromPrevStage(5)}, {OutputInfo::fromNetwork()});
+    testModel.setStageBatchInfo(6, {{0, BatchSupport::Split}});
+
+    RunPass();
+
+    const auto& data0 = CheckSingleConnection(testModel.getInputs().at(0), 0, batchSize);
+    const auto& data7 = singleElement(checkSingleLoopStart(data0));
+    const auto& data3 = CheckSingleConnection(data7, 1);
+    const auto& data4 = CheckSingleConnection(data3, 2);
+    const auto& data5 = CheckSingleConnection(data4, 3);
+    const auto& data6 = CheckSingleConnection(data5, 4);
+    const auto& data8 = CheckSingleConnection(data6, 5);
+    const auto& data10 = CheckSingleConnection(data8, 6);
+    const auto& data11 = checkSingleLoopEnd(data10);
+
+    ASSERT_EQ(data11, testModel.getOutputs());
+}
+
+TEST_F(VPU_AdjustDataBatchTest, BranchedWithBatchItemsInTheEnd) {
+    //                                                                                      -> (Batched) -> [Output]
+    // [Input] -> (Split) -> (Split) -> (Split) -> (Split) -> (Split) -> (Split) -> (Batch)
+    //                                                                                      -> (Batched) -> [Output]
+    const DataDesc desc{16, 16, 3, batchSize};
+
+    testModel.createInputs({desc});
+    testModel.createOutputs({desc, desc});
+
+    for (int i = 0; i < 6; i++) {
+        if (i > 0)
+            testModel.addStage({InputInfo::fromPrevStage(i - 1)}, {OutputInfo::intermediate(desc)});
+        else
+            testModel.addStage({InputInfo::fromNetwork()}, {OutputInfo::intermediate(desc)});
+        testModel.setStageBatchInfo(i, {{0, BatchSupport::Split}});
+    }
+
+    testModel.addStage({InputInfo::fromPrevStage(5)}, {OutputInfo::intermediate(desc)});
+    testModel.addStage({InputInfo::fromPrevStage(6)}, {OutputInfo::fromNetwork(0)});
+    testModel.addStage({InputInfo::fromPrevStage(6)}, {OutputInfo::fromNetwork(1)});
+
+    RunPass();
+
+    const auto& data0 = singleElement(checkSingleLoopStart(testModel.getInputs().at(0)));
+    const auto& data1 = CheckSingleConnection(data0, 0);
+    const auto& data2 = CheckSingleConnection(data1, 1);
+    const auto& data3 = CheckSingleConnection(data2, 2);
+    const auto& data4 = CheckSingleConnection(data3, 3);
+    const auto& data5 = CheckSingleConnection(data4, 4);
+    const auto& data6 = CheckSingleConnection(data5, 5);
+    const auto& data7 = singleElement(checkSingleLoopEnd(data6));
+
+    const auto& data8 = CheckSingleConnection(data7, 6, batchSize);
+
+    const auto& branches = checkBranches(data8, {StageType::None, StageType::None});
+    const auto& withBatch = branches[0];
+    const auto& withBatch_1 = branches[1];
+
+    ASSERT_EQ(withBatch->producer()->attrs().get<int>("test_ind"), 7);
+    ASSERT_EQ(withBatch->desc().dim(Dim::N), batchSize);
+    ASSERT_EQ(withBatch, testModel.getOutputs().at(0));
+
+    ASSERT_EQ(withBatch_1->producer()->attrs().get<int>("test_ind"), 8);
+    ASSERT_EQ(withBatch_1->desc().dim(Dim::N), batchSize);
+    ASSERT_EQ(withBatch_1, testModel.getOutputs().at(1));
+}
+
+TEST_F(VPU_AdjustDataBatchTest, DISABLED_BranchedWithSplitAndBatchItemsInTheEnd) {
+    //
+    //                                         -> (Split) -> (Batched) -> [Output]
+    // [Input] -> (Split) -> (Split) -> (Split)
+    //                                         -> (Split) -> [Output]
+    //
+    const DataDesc desc{16, 16, 3, batchSize};
+
+    testModel.createInputs({desc});
+    testModel.createOutputs({desc, desc});
+
+    for (int i = 0; i < 5; i++) {
+        if (i > 0)
+            testModel.addStage({InputInfo::fromPrevStage(i - 1)}, {OutputInfo::intermediate(desc)});
+        else
+            testModel.addStage({InputInfo::fromNetwork()}, {OutputInfo::intermediate(desc)});
+        if (i != 3)
+            testModel.setStageBatchInfo(i, {{0, BatchSupport::Split}});
+    }
+
+    testModel.addStage({InputInfo::fromPrevStage(2)}, {OutputInfo::fromNetwork(1)});
+    testModel.setStageBatchInfo(5, {{0, BatchSupport::Split}});
+
+    RunPass();
+
+    const auto& data0 = singleElement(checkSingleLoopStart(testModel.getInputs().at(0)));
+    const auto& data1 = CheckSingleConnection(data0, 0);
+    const auto& data2 = CheckSingleConnection(data1, 1);
+    const auto& data3 = CheckSingleConnection(data2, 2);
+
+    const auto& branches = checkBranches(data3, {StageType::None, StageType::LoopEnd});
+    const auto& branch1 = branches[0];
+    const auto& branch2 = branches[1];
+
+    const auto& data4 = CheckSingleConnection(branch1, 3);
+    const auto& data7 = singleElement(checkSingleLoopEnd(data4));
+    const auto& data5 = CheckSingleConnection(data7, 4, batchSize);
+    ASSERT_EQ(data5, testModel.getOutputs().at(0));
+    const auto& data6 = CheckSingleConnection(branch2, 5);
+    ASSERT_EQ(data6, testModel.getOutputs().at(1));
+}
+
+TEST_F(VPU_AdjustDataBatchTest, DISABLED_BranchedWithBatchAndSplitItemsInTheEnd) {
+    //
+    //                                         -> (Split) -> [Output]
+    // [Input] -> (Split) -> (Split) -> (Split)
+    //                                         -> (Split) -> [Output]
+    //
+    const DataDesc desc{16, 16, 3, batchSize};
+
+    testModel.createInputs({desc});
+    testModel.createOutputs({desc, desc});
+
+    for (int i = 0; i < 3; i++) {
+        if (i > 0)
+            testModel.addStage({InputInfo::fromPrevStage(i - 1)}, {OutputInfo::intermediate(desc)});
+        else
+            testModel.addStage({InputInfo::fromNetwork()}, {OutputInfo::intermediate(desc)});
+        testModel.setStageBatchInfo(i, {{0, BatchSupport::Split}});
+    }
+    for (int i = 0; i < 2; i++) {
+        testModel.addStage({InputInfo::fromNetwork(2)}, {OutputInfo::intermediate(desc)});
+        testModel.setStageBatchInfo(3 + i, {{0, BatchSupport::Split}});
+    }
+
+    testModel.addStage({InputInfo::fromPrevStage(2)}, {OutputInfo::fromNetwork(0)});
+    testModel.setStageBatchInfo(3, {{0, BatchSupport::Split}});
+    testModel.addStage({InputInfo::fromPrevStage(2)}, {OutputInfo::fromNetwork(1)});
+    testModel.setStageBatchInfo(4, {{0, BatchSupport::Split}});
+
+    RunPass();
+
+    const auto& data1 = singleElement(checkSingleLoopStart(testModel.getInputs().at(0)));
+    const auto& data2 = CheckSingleConnection(data1, 1);
+    const auto& data3 = CheckSingleConnection(data2, 2);
+    const auto& branches = checkBranches(data3, {StageType::None, StageType::LoopEnd});
+    const auto& branch1 = branches[0];
+    const auto& branch2 = branches[1];
+    const auto& data4 = CheckSingleConnection(branch1, 3);
+    const auto& data5 = CheckSingleConnection(branch2, 4);
+    const auto& data6 = checkSingleLoopEnd(data5);
+}