Publishing 2019 R3 content
diff --git a/inference-engine/src/vpu/graph_transformer/src/passes/sw_conv_adaptation.cpp b/inference-engine/src/vpu/graph_transformer/src/passes/sw_conv_adaptation.cpp
index 3588814..ba2a0b8 100644
@@ -66,17 +66,15 @@ private:
 
     void propagateScaleFactorsImpl(
             const SmallVector<float>&,
-            ScalePropagationStep) override {
+            ScalePropagationStep,
+            StageDataInfo<float>&) override {
         VPU_THROW_EXCEPTION << "Must never be called";
     }
 
-    void propagateDataOrderImpl() const override {
-        IE_ASSERT(_inputEdges.size() == 3);
-        IE_ASSERT(_outputEdges.size() == 1);
-
-        auto input = _inputEdges[0]->input();
-        auto weights = _inputEdges[1]->input();
-        auto output = _outputEdges[0]->output();
+    void propagateDataOrderImpl(StageDataInfo<DimsOrder>& orderInfo) override {
+        auto input = inputEdge(0)->input();
+        auto weights = inputEdge(1)->input();
+        auto output = outputEdge(0)->output();
 
         auto finalOrder = input->desc().dimsOrder();
         if (finalOrder.dimInd(Dim::C) == 1) {
@@ -87,37 +85,31 @@ private:
         if (_type == StageType::Conv ||
             _type == StageType::Im2ColConvolution) {
             if (finalOrder != input->desc().dimsOrder()) {
-                _orderInfo.setInput(_inputEdges[0], finalOrder);
+                orderInfo.setInput(inputEdge(0), finalOrder);
             }
-            _orderInfo.setOutput(_outputEdges[0], finalOrder);
+            orderInfo.setOutput(outputEdge(0), finalOrder);
         } else if (_type == StageType::DepthConv) {
             if (finalOrder != input->desc().dimsOrder()) {
-                _orderInfo.setInput(_inputEdges[0], finalOrder);
+                orderInfo.setInput(inputEdge(0), finalOrder);
             }
-            _orderInfo.setOutput(_outputEdges[0], finalOrder);
+            orderInfo.setOutput(outputEdge(0), finalOrder);
         } else {
-            _orderInfo.setInput(_inputEdges[0], finalOrder.createMovedDim(Dim::C, 0));
-            _orderInfo.setOutput(_outputEdges[0], finalOrder.createMovedDim(Dim::C, 0));
+            orderInfo.setInput(inputEdge(0), finalOrder.createMovedDim(Dim::C, 0));
+            orderInfo.setOutput(outputEdge(0), finalOrder.createMovedDim(Dim::C, 0));
         }
     }
 
-    void getDataStridesRequirementsImpl() const override {
-        IE_ASSERT(_inputEdges.size() == 3);
-        IE_ASSERT(_outputEdges.size() == 1);
-
+    void getDataStridesRequirementsImpl(StageDataInfo<StridesRequirement>& stridesInfo) override {
         if (_type != StageType::DepthConv) {
-            _stridesInfo.setInput(_inputEdges[0], StridesRequirement::compact());
-            _stridesInfo.setOutput(_outputEdges[0], StridesRequirement::compact());
+            stridesInfo.setInput(inputEdge(0), StridesRequirement::compact());
+            stridesInfo.setOutput(outputEdge(0), StridesRequirement::compact());
         }
     }
 
     void finalizeDataLayoutImpl() override {
-        IE_ASSERT(_inputEdges.size() == 3);
-        IE_ASSERT(_outputEdges.size() == 1);
-
-        auto input = _inputEdges[0]->input();
-        auto weights = _inputEdges[1]->input();
-        auto output = _outputEdges[0]->output();
+        auto input = inputEdge(0)->input();
+        auto weights = inputEdge(1)->input();
+        auto output = outputEdge(0)->output();
 
         auto kernelSizeX = attrs().get<int>("kernelSizeX");
         auto kernelSizeY = attrs().get<int>("kernelSizeY");
@@ -223,18 +215,18 @@ private:
 
         IE_ASSERT(swWeights != nullptr);
 
-        _model->replaceStageInput(_inputEdges[1], swWeights);
+        _model->replaceStageInput(inputEdge(1), swWeights);
     }
 
-    void getBatchSupportInfoImpl() const  override {
-        IE_ASSERT(_inputEdges.size() == 3);
-        IE_ASSERT(_outputEdges.size() == 1);
-
-        _batchInfo.setInput(_inputEdges[0], BatchSupport::Split);
-        _batchInfo.setOutput(_outputEdges[0], BatchSupport::Split);
+    void getBatchSupportInfoImpl(StageDataInfo<BatchSupport>& batchInfo) override {
+        batchInfo.setInput(inputEdge(0), BatchSupport::Split);
+        batchInfo.setOutput(outputEdge(0), BatchSupport::Split);
     }
 
     void finalCheckImpl() const override {
+        assertInputsOutputsTypes(this,
+              {{DataType::FP16}, {DataType::FP16}, {DataType::FP16}},
+              {{DataType::FP16}});
     }
 
     void serializeParamsImpl(BlobSerializer& serializer) const override {
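For context on the finalCheckImpl change above: the stage now delegates precision validation to assertInputsOutputsTypes, requiring FP16 for the input, weights, and biases and FP16 for the output. Below is a minimal, self-contained sketch of what such a checker could look like. The FakeStage type and the checkTypes helper are simplified stand-ins invented for illustration; they are not the actual graph_transformer API.

// Hypothetical, simplified sketch -- not the real vpu::assertInputsOutputsTypes.
#include <cstddef>
#include <set>
#include <stdexcept>
#include <string>
#include <vector>

enum class DataType { FP16, FP32, S32 };

// Stand-in for a stage: just the data types of its inputs and outputs.
struct FakeStage {
    std::vector<DataType> inputTypes;
    std::vector<DataType> outputTypes;
};

// Each entry lists the types allowed for the corresponding input/output.
using TypeSets = std::vector<std::set<DataType>>;

void checkTypes(const FakeStage& stage,
                const TypeSets& allowedInputs,
                const TypeSets& allowedOutputs) {
    if (stage.inputTypes.size() != allowedInputs.size() ||
        stage.outputTypes.size() != allowedOutputs.size()) {
        throw std::runtime_error("unexpected number of inputs or outputs");
    }
    for (std::size_t i = 0; i < allowedInputs.size(); ++i) {
        if (allowedInputs[i].count(stage.inputTypes[i]) == 0) {
            throw std::runtime_error("input " + std::to_string(i) + " has unsupported type");
        }
    }
    for (std::size_t i = 0; i < allowedOutputs.size(); ++i) {
        if (allowedOutputs[i].count(stage.outputTypes[i]) == 0) {
            throw std::runtime_error("output " + std::to_string(i) + " has unsupported type");
        }
    }
}

int main() {
    // Mirrors the convolution stage above: input, weights, biases in FP16, output in FP16.
    FakeStage conv{{DataType::FP16, DataType::FP16, DataType::FP16}, {DataType::FP16}};
    checkTypes(conv,
               {{DataType::FP16}, {DataType::FP16}, {DataType::FP16}},
               {{DataType::FP16}});
    return 0;
}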
@@ -258,20 +250,17 @@ private:
     }
 
     void serializeDataImpl(BlobSerializer& serializer) const override {
-        IE_ASSERT(_inputEdges.size() == 3);
-        IE_ASSERT(_outputEdges.size() == 1);
-
-        auto input = _inputEdges[0]->input();
-        auto weights = _inputEdges[1]->input();
-        auto biases = _inputEdges[2]->input();
-        auto output = _outputEdges[0]->output();
+        auto input = inputEdge(0)->input();
+        auto weights = inputEdge(1)->input();
+        auto biases = inputEdge(2)->input();
+        auto output = outputEdge(0)->output();
 
         input->serializeOldBuffer(handle_from_this(), serializer);
         output->serializeOldBuffer(handle_from_this(), serializer);
         weights->serializeOldBuffer(handle_from_this(), serializer);
 
-        if (!_tempBufferEdges.empty()) {
-            _tempBufferEdges[0]->tempBuffer()->serializeOldBuffer(handle_from_this(), serializer);
+        if (numTempBuffers() == 1) {
+            tempBuffer(0)->serializeOldBuffer(handle_from_this(), serializer);
         }
 
         // TODO: remove this
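The recurring change throughout this commit is one pattern: the *Impl callbacks stop writing into stage members such as _orderInfo, _stridesInfo, and _batchInfo and instead receive a StageDataInfo<T>& to fill, while raw _inputEdges/_outputEdges indexing is replaced by the inputEdge(i)/outputEdge(i) accessors. The sketch below illustrates that calling convention with simplified stand-in types invented for illustration; it is not the real StageNode/StageDataInfo API from graph_transformer.

// Simplified stand-ins for illustration only; the real classes carry far more state.
#include <iostream>
#include <map>

enum class DimsOrder { NCHW, NHWC };

// Collects per-edge results produced by a single pass callback.
template <typename Val>
class StageInfo {
public:
    void setInput(int edgeIndex, const Val& val)  { _inputs[edgeIndex] = val; }
    void setOutput(int edgeIndex, const Val& val) { _outputs[edgeIndex] = val; }

    const std::map<int, Val>& inputs()  const { return _inputs; }
    const std::map<int, Val>& outputs() const { return _outputs; }

private:
    std::map<int, Val> _inputs;
    std::map<int, Val> _outputs;
};

class StageBase {
public:
    virtual ~StageBase() = default;

    // New-style callback: results go into the caller-provided info object,
    // instead of being stored in mutable stage members by a const method.
    virtual void propagateDataOrder(StageInfo<DimsOrder>& orderInfo) = 0;

protected:
    int inputEdge(int i) const  { return i; }   // accessor instead of _inputEdges[i]
    int outputEdge(int i) const { return i; }   // accessor instead of _outputEdges[i]
};

class ConvStage final : public StageBase {
public:
    void propagateDataOrder(StageInfo<DimsOrder>& orderInfo) override {
        orderInfo.setInput(inputEdge(0), DimsOrder::NCHW);
        orderInfo.setOutput(outputEdge(0), DimsOrder::NCHW);
    }
};

int main() {
    ConvStage stage;
    StageInfo<DimsOrder> orderInfo;
    stage.propagateDataOrder(orderInfo);   // the calling pass owns the result container
    std::cout << "inputs recorded: " << orderInfo.inputs().size()
              << ", outputs recorded: " << orderInfo.outputs().size() << "\n";
    return 0;
}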