Publishing 2019 R3 content
[platform/upstream/dldt.git] / inference-engine / src / vpu / graph_transformer / src / passes / adjust_data_layout.cpp
index b280994..d9bfd2e 100644 (file)
@@ -25,20 +25,21 @@ private:
 
     void propagateScaleFactorsImpl(
             const SmallVector<float>&,
-            ScalePropagationStep) override {
+            ScalePropagationStep,
+            StageDataInfo<float>&) override {
         VPU_THROW_EXCEPTION << "Must never be called";
     }
 
-    void propagateDataOrderImpl() const override {
+    void propagateDataOrderImpl(StageDataInfo<DimsOrder>& orderInfo) override {
     }
 
-    void getDataStridesRequirementsImpl() const override {
+    void getDataStridesRequirementsImpl(StageDataInfo<StridesRequirement>& stridesInfo) override {
     }
 
     void finalizeDataLayoutImpl() override {
     }
 
-    void getBatchSupportInfoImpl() const override {
+    void getBatchSupportInfoImpl(StageDataInfo<BatchSupport>& batchInfo) override {
     }
 
     StageSHAVEsRequirements getSHAVEsRequirementsImpl() const override {
@@ -46,11 +47,8 @@ private:
     }
 
     void finalCheckImpl() const override {
-        IE_ASSERT(_inputEdges.size() == 1);
-        IE_ASSERT(_outputEdges.size() == 1);
-
-        auto input = _inputEdges[0]->input();
-        auto output = _outputEdges[0]->output();
+        auto input = inputEdge(0)->input();
+        auto output = outputEdge(0)->output();
 
         auto inDimsOrder = input->desc().dimsOrder();
         auto outDimsOrder = output->desc().dimsOrder();
@@ -63,12 +61,8 @@ private:
     }
 
     void serializeParamsImpl(BlobSerializer& serializer) const override {
-        IE_ASSERT(_inputEdges.size() == 1);
-        IE_ASSERT(_outputEdges.size() == 1);
-        IE_ASSERT(_tempBufferEdges.empty());
-
-        auto input = _inputEdges[0]->input();
-        auto output = _outputEdges[0]->output();
+        auto input = inputEdge(0)->input();
+        auto output = outputEdge(0)->output();
 
         auto inDimsOrder = input->desc().dimsOrder();
         auto outDimsOrder = output->desc().dimsOrder();
@@ -94,12 +88,8 @@ private:
     }
 
     void serializeDataImpl(BlobSerializer& serializer) const override {
-        IE_ASSERT(_inputEdges.size() == 1);
-        IE_ASSERT(_outputEdges.size() == 1);
-        IE_ASSERT(_tempBufferEdges.empty());
-
-        auto input = _inputEdges[0]->input();
-        auto output = _outputEdges[0]->output();
+        auto input = inputEdge(0)->input();
+        auto output = outputEdge(0)->output();
 
         input->serializeNewBuffer(serializer);
         output->serializeNewBuffer(serializer);
@@ -145,6 +135,12 @@ void PassImpl::run(const Model::Ptr& model) {
         if (data->usage() == DataUsage::Intermediate)
             continue;
 
+        if (data->usage() == DataUsage::Input || data->usage() == DataUsage::Output) {
+            if (!data->requiredStrides().fixedStrides().empty()) {
+                continue;
+            }
+        }
+
         data->updateRequiredStrides(StridesRequirement::compact());
     }
 
@@ -201,6 +197,10 @@ void PassImpl::run(const Model::Ptr& model) {
                 auto output = outEdge->output();
                 auto portInd = outEdge->portInd();
 
+                if (output->usage() == DataUsage::Fake) {
+                    continue;
+                }
+
                 auto requiredOrder = output->desc().dimsOrder();
 
                 if (curStageInfo.hasOutput(outEdge)) {
@@ -310,6 +310,10 @@ void PassImpl::run(const Model::Ptr& model) {
                 auto output = outEdge->output();
                 auto portInd = outEdge->portInd();
 
+                if (output->usage() == DataUsage::Fake) {
+                    continue;
+                }
+
                 auto requiredStrides = StridesRequirement();
 
                 if (curStageInfo.hasOutput(outEdge)) {