Publishing 2019 R3 content
[platform/upstream/dldt.git] / inference-engine / src / vpu / graph_transformer / src / stages / custom.cpp
index 89ac460..dd4490f 100644 (file)
@@ -41,66 +41,67 @@ private:
         return std::make_shared<CustomStage>(*this);
     }
 
-    void propagateDataOrderImpl() const override {
+    void propagateDataOrderImpl(StageDataInfo<DimsOrder>& orderInfo) override {
         const auto& inputOrders = attrs().get<std::map<int, DimsOrder>>("inputOrders");
         const auto& outputOrders = attrs().get<std::map<int, DimsOrder>>("outputOrders");
 
-        for (const auto& inEdge : _inputEdges) {
+        for (const auto& inEdge : inputEdges()) {
             // last input is always OpenCL binary, so use it as is.
-            if (inEdge->portInd() == _inputEdges.size() - 1) {
+            if (inEdge->portInd() == numInputs() - 1) {
                 break;
             }
 
             auto it = inputOrders.find(inEdge->portInd());
             if (it != inputOrders.end()) {
                 auto requiredOrder = it->second;
-                _orderInfo.setInput(inEdge, requiredOrder);
+                orderInfo.setInput(inEdge, requiredOrder);
             }
         }
 
-        for (const auto& outEdge : _outputEdges) {
+        for (const auto& outEdge : outputEdges()) {
             auto it = outputOrders.find(outEdge->portInd());
             if (it != outputOrders.end()) {
                 auto requiredOrder = it->second;
-                _orderInfo.setOutput(outEdge, requiredOrder);
+                orderInfo.setOutput(outEdge, requiredOrder);
             }
         }
     }
 
-    void getDataStridesRequirementsImpl() const override {
-        for (const auto& inEdge : _inputEdges) {
+    void getDataStridesRequirementsImpl(StageDataInfo<StridesRequirement>& stridesInfo) override {
+        for (const auto& inEdge : inputEdges()) {
             // last input is always OpenCL binary, so use it as is.
-            if (inEdge->portInd() == _inputEdges.size() - 1) {
+            if (inEdge->portInd() == numInputs() - 1) {
                 break;
             }
 
-            _stridesInfo.setInput(inEdge, StridesRequirement::compact());
+            stridesInfo.setInput(inEdge, StridesRequirement::compact());
         }
-        for (const auto& outEdge : _outputEdges) {
-            _stridesInfo.setOutput(outEdge, StridesRequirement::compact());
+        for (const auto& outEdge : outputEdges()) {
+            stridesInfo.setOutput(outEdge, StridesRequirement::compact());
         }
     }
 
     void finalizeDataLayoutImpl() override {
     }
 
-    void getBatchSupportInfoImpl() const override {
-        for (const auto& inEdge : _inputEdges) {
+    void getBatchSupportInfoImpl(StageDataInfo<BatchSupport>& batchInfo) override {
+        std::vector<CustomDataFormat> formats = attrs().get<std::vector<CustomDataFormat>>("formats");
+
+        for (const auto& inEdge : inputEdges()) {
+            IE_ASSERT(inEdge->portInd() < formats.size());
+
             // last input is always OpenCL binary, so use it as is.
-            if (inEdge->portInd() == _inputEdges.size() - 1) {
+            if ((inEdge->portInd() == numInputs() - 1) || (formats[inEdge->portInd()] == CustomDataFormat::Any)) {
                 break;
             }
 
-            _batchInfo.setInput(inEdge, BatchSupport::Split);
+            batchInfo.setInput(inEdge, BatchSupport::Split);
         }
-        for (const auto& outEdge : _outputEdges) {
-            _batchInfo.setOutput(outEdge, BatchSupport::Split);
+        for (const auto& outEdge : outputEdges()) {
+            batchInfo.setOutput(outEdge, BatchSupport::Split);
         }
     }
 
-    void finalCheckImpl() const override {
-    }
-
     void serializeParamsImpl(BlobSerializer& serializer) const override {
         const auto& customLayer = attrs().get<CustomLayer::Ptr>("customLayer");
         const auto& gws = attrs().get<SmallVector<int, 3>>("gws");
@@ -136,7 +137,7 @@ private:
         // Total number of blobs
         //
 
-        serializer.append(static_cast<int32_t>(_inputEdges.size() + _outputEdges.size()));
+        serializer.append(static_cast<int32_t>(numInputs() + numOutputs()));
 
         //
         // Number of kernel parameters
@@ -200,15 +201,26 @@ private:
                             auto blob = parameter.irSource.substr(0, pos);
                             auto dim = parameter.irSource.substr(pos + 1, std::string::npos);
 
+                            IE_ASSERT(dim.length() == 1)
+                                    << "Unable to deduce parameter " << parameter.argName << " for "
+                                    << _origLayer->type <<" layer. Name is: " << _origLayer->name;
+                            char dimLetter = dim[0];
+
                             ie::DataPtr origData;
                             if (blob == "I") {
                                 origData = _origLayer->insData[parameter.portIndex].lock();
                             } else {
-                                origData = _origLayer->outData[0];
+                                origData = _origLayer->outData[parameter.portIndex];
                             }
                             IE_ASSERT(origData != nullptr);
 
                             auto dims = origData->getDims();
+                            int ndims = dims.size();
+
+                            if (ndims > 4)
+                                VPU_THROW_EXCEPTION
+                                    << "Unable to deduce parameter " << parameter.argName << " for "
+                                    << _origLayer->type <<" layer. Name is: " << _origLayer->name;
 
                             const std::map<char, int> vars = {
                                 { 'b', 0 }, { 'B', 0 },
@@ -217,8 +229,9 @@ private:
                                 { 'x', 3 }, { 'X', 3 },
                             };
 
-                            if (vars.find(dim[0]) != vars.end()) {
-                                auto res = dims.at(vars.at(dim[0]));
+                            auto var = vars.find(dimLetter);
+                            if (var != vars.end()) {
+                                auto res = dims.at(var->second-4+ndims);
 
                                 serializer.append(static_cast<uint32_t>(res));
                                 serializer.append(static_cast<int32_t>(-1));
@@ -258,15 +271,19 @@ private:
     }
 
     void serializeDataImpl(BlobSerializer& serializer) const override {
-        IE_ASSERT(_tempBufferEdges.empty());
+        IE_ASSERT(numTempBuffers() == 1);
 
-        for (const auto& inEdge : _inputEdges) {
+        for (const auto& inEdge : inputEdges()) {
             inEdge->input()->serializeOldBuffer(handle_from_this(), serializer);
         }
 
-        for (const auto& outEdge : _outputEdges) {
+        for (const auto& outEdge : outputEdges()) {
             outEdge->output()->serializeOldBuffer(handle_from_this(), serializer);
         }
+
+        for (const auto& tempEdge : tempBufferEdges()) {
+            tempEdge->tempBuffer()->serializeOldBuffer(handle_from_this(), serializer);
+        }
     }
 };
 
@@ -362,15 +379,18 @@ void FrontEnd::parseCustom(
         auto customLayer = customLayersForType[stage_num];
 
         std::map<std::string, int> ports;
+        std::vector<CustomDataFormat> formats;
 
         // Gather inputs
         DataVector stageInputs;
         for (auto& param : customLayer->bindings()) {
             if (param.type == CustomParamType::Input) {
                 ports[param.argName] = stageInputs.size();
+                formats.emplace_back(param.format);
                 stageInputs.emplace_back(inputs[param.portIndex]);
             } else if (param.type == CustomParamType::InputBuffer) {
                 ports[param.argName] = stageInputs.size();
+                formats.emplace_back(CustomDataFormat::BFYX);
                 stageInputs.emplace_back(tempBuffsMap[param.portIndex]);
             }
         }
@@ -386,12 +406,14 @@ void FrontEnd::parseCustom(
                         DataDesc({origBlob->size()}),
                         ieBlobContent(origBlob));
                     ports[param.argName] = stageInputs.size();
+                    formats.emplace_back(param.format);
                     stageInputs.emplace_back(std::move(customBlob));
                 }
             }
         }
 
         customLayer->setStageNumInputs(stageInputs.size());
+        formats.emplace_back(CustomDataFormat::Any);
 
         // Get kernel binary
         auto kernelNode = kernelNodes.find(customLayer->kernelBinary());
@@ -429,6 +451,7 @@ void FrontEnd::parseCustom(
 
         stage->attrs().set("customLayer", customLayer);
         stage->attrs().set("ports", ports);
+        stage->attrs().set("formats", formats);
 
         SmallVector<int, 3> gws;
         SmallVector<int, 3> lws;
@@ -447,25 +470,27 @@ void FrontEnd::parseCustom(
             b2b[kp.argName] = kp;
         }
 
-        const std::map<CustomDataFormat, DimsOrder> formats = {
+        const std::map<CustomDataFormat, DimsOrder> formatsMap = {
             { CustomDataFormat::BYXF, DimsOrder::NHWC },
-            { CustomDataFormat::BFYX, DimsOrder::NCHW }
+            { CustomDataFormat::BFYX, DimsOrder::NCHW },
+            { CustomDataFormat::YXF, DimsOrder::HWC },
+            { CustomDataFormat::FYX, DimsOrder::CHW }
         };
 
         for (const auto& kp : customLayer->parameters()) {
             const auto& parameter = b2b[kp];
 
             if (parameter.type == CustomParamType::Input) {
-                auto it = formats.find(parameter.format);
-                if (it != formats.end()) {
+                auto it = formatsMap.find(parameter.format);
+                if (it != formatsMap.end()) {
                     auto requiredOrder = it->second;
                     inputOrders[parameter.portIndex] = requiredOrder;
                 }
             }
 
             if (parameter.type == CustomParamType::Output) {
-                auto it = formats.find(parameter.format);
-                if (it != formats.end()) {
+                auto it = formatsMap.find(parameter.format);
+                if (it != formatsMap.end()) {
                     auto requiredOrder = it->second;
                     outputOrders[parameter.portIndex] = requiredOrder;
                 }
@@ -474,6 +499,11 @@ void FrontEnd::parseCustom(
 
         stage->attrs().set("inputOrders", std::move(inputOrders));
         stage->attrs().set("outputOrders", std::move(outputOrders));
+
+        int buffer_size = customLayer->kernelBinary().length() + 1024;
+        model->addTempBuffer(
+            stage,
+            DataDesc({buffer_size}));
     }
 }