Publishing 2019 R3 content
[platform/upstream/dldt.git] / inference-engine / src / vpu / graph_transformer / src / stages / reshape.cpp
index a9a3c45..7e35592 100644 (file)
@@ -22,53 +22,41 @@ private:
 
     void propagateScaleFactorsImpl(
             const SmallVector<float>& inputScales,
-            ScalePropagationStep step) override {
-        IE_ASSERT(_inputEdges.size() == 1);
-        IE_ASSERT(_outputEdges.size() == 1);
-
+            ScalePropagationStep step,
+            StageDataInfo<float>& scaleInfo) override {
         if (step == ScalePropagationStep::Propagate) {
-            _scaleInfo.setOutput(_outputEdges[0], inputScales[0]);
+            scaleInfo.setOutput(outputEdge(0), inputScales[0]);
         } else {
             // Reshape can only propagate scaling.
-            _scaleInfo.setInput(_inputEdges[0], 1.0f);
-            _scaleInfo.setOutput(_outputEdges[0], 1.0f);
+            scaleInfo.setInput(inputEdge(0), 1.0f);
+            scaleInfo.setOutput(outputEdge(0), 1.0f);
         }
     }
 
-    void propagateDataOrderImpl() const override {
-        IE_ASSERT(_inputEdges.size() == 1);
-        IE_ASSERT(_outputEdges.size() == 1);
-
-        auto input = _inputEdges[0]->input();
-        auto output = _outputEdges[0]->output();
+    void propagateDataOrderImpl(StageDataInfo<DimsOrder>& orderInfo) override {
+        auto input = inputEdge(0)->input();
+        auto output = outputEdge(0)->output();
 
         // Only default order is supported
-        _orderInfo.setInput(_inputEdges[0], DimsOrder::fromNumDims(input->desc().numDims()));
-        _orderInfo.setOutput(_outputEdges[0], DimsOrder::fromNumDims(output->desc().numDims()));
+        orderInfo.setInput(inputEdge(0), DimsOrder::fromNumDims(input->desc().numDims()));
+        orderInfo.setOutput(outputEdge(0), DimsOrder::fromNumDims(output->desc().numDims()));
     }
 
-    void getDataStridesRequirementsImpl() const override {
-        IE_ASSERT(_inputEdges.size() == 1);
-        IE_ASSERT(_outputEdges.size() == 1);
-
-        _stridesInfo.setInput(_inputEdges[0], StridesRequirement::compact());
-        _stridesInfo.setOutput(_outputEdges[0], StridesRequirement::compact());
+    void getDataStridesRequirementsImpl(StageDataInfo<StridesRequirement>& stridesInfo) override {
+        stridesInfo.setInput(inputEdge(0), StridesRequirement::compact());
+        stridesInfo.setOutput(outputEdge(0), StridesRequirement::compact());
     }
 
     void finalizeDataLayoutImpl() override {
     }
 
-    void getBatchSupportInfoImpl() const override {
+    void getBatchSupportInfoImpl(StageDataInfo<BatchSupport>& batchInfo) override {
     }
 
-    void finalCheckImpl() const override {
-        IE_ASSERT(_inputEdges.size() == 1);
-        IE_ASSERT(_outputEdges.size() == 1);
-
-        auto input = _inputEdges[0]->input();
-        auto output = _outputEdges[0]->output();
-
-        IE_ASSERT(input->desc().totalDimSize() == output->desc().totalDimSize());
+    void initialCheckImpl() const override {
+        const auto& firstInputPrecision = input(0)->desc().type();
+        assertInputsOutputsTypes(this, {{firstInputPrecision}}, {{firstInputPrecision}});
+        IE_ASSERT(input(0)->desc().totalDimSize() == output(0)->desc().totalDimSize());
     }
 
     void serializeParamsImpl(BlobSerializer&) const override {