Publishing 2019 R3 content
[platform/upstream/dldt.git] / inference-engine / src / vpu / graph_transformer / src / stages / region_yolo.cpp
index cda718c..9563497 100644 (file)
@@ -19,39 +19,26 @@ private:
         return std::make_shared<RegionYoloStage>(*this);
     }
 
-    void propagateDataOrderImpl() const override {
-        IE_ASSERT(_inputEdges.size() == 1);
-        IE_ASSERT(_outputEdges.size() == 1);
-
-        auto output = _outputEdges[0]->output();
-
-        if (!attrs().get<bool>("doSoftMax")) {
-            _orderInfo.setOutput(_outputEdges[0], output->desc().dimsOrder().createMovedDim(Dim::C, 2));  // CHW
-        }
+    void propagateDataOrderImpl(StageDataInfo<DimsOrder>& orderInfo) override {
     }
 
-    void getDataStridesRequirementsImpl() const override {
-        IE_ASSERT(_inputEdges.size() == 1);
-        IE_ASSERT(_outputEdges.size() == 1);
-
+    void getDataStridesRequirementsImpl(StageDataInfo<StridesRequirement>& stridesInfo) override {
         if (attrs().get<bool>("doSoftMax")) {
             // Major dimension must be compact.
-            _stridesInfo.setInput(_inputEdges[0], StridesRequirement().add(2, DimStride::Compact));
+            stridesInfo.setInput(inputEdge(0), StridesRequirement().add(2, DimStride::Compact));
         }
     }
 
     void finalizeDataLayoutImpl() override {
     }
 
-    void getBatchSupportInfoImpl() const override {
-        IE_ASSERT(_inputEdges.size() == 1);
-        IE_ASSERT(_outputEdges.size() == 1);
-
-        _batchInfo.setInput(_inputEdges[0], BatchSupport::Split);
-        _batchInfo.setOutput(_outputEdges[0], BatchSupport::Split);
+    void getBatchSupportInfoImpl(StageDataInfo<BatchSupport>& batchInfo) override {
+        batchInfo.setInput(inputEdge(0), BatchSupport::Split);
+        batchInfo.setOutput(outputEdge(0), BatchSupport::Split);
     }
 
-    void finalCheckImpl() const override {
+    void initialCheckImpl() const override {
+        assertInputsOutputsTypes(this, {{DataType::FP16}}, {{DataType::FP16}});
     }
 
     void serializeParamsImpl(BlobSerializer& serializer) const override {
@@ -69,12 +56,8 @@ private:
     }
 
     void serializeDataImpl(BlobSerializer& serializer) const override {
-        IE_ASSERT(_inputEdges.size() == 1);
-        IE_ASSERT(_outputEdges.size() == 1);
-        IE_ASSERT(_tempBufferEdges.empty());
-
-        auto input = _inputEdges[0]->input();
-        auto output = _outputEdges[0]->output();
+        auto input = inputEdge(0)->input();
+        auto output = outputEdge(0)->output();
 
         input->serializeOldBuffer(handle_from_this(), serializer);
         output->serializeOldBuffer(handle_from_this(), serializer);