[GNA] Fix a global buffer overflow in GNAModelSerial::Import (#3290) (#3327)
[platform/upstream/dldt.git] / inference-engine / src / gna_plugin / gna_model_serial.cpp
index beff0af..4ed6511 100644 (file)
@@ -20,6 +20,9 @@
 
 #include "gna_plugin.hpp"
 #include "gna_model_serial.hpp"
+#include "serial/headers/latest/gna_model_header.hpp"
+
+using namespace GNAPluginNS;
 
 inline void writeNBytes(const void *ptr, uint32_t size, std::ostream & os) {
     os.write(static_cast<const char*>(ptr), size);
@@ -69,11 +72,25 @@ bool is_little_endian() {
 
 const int gna_header_magic = is_little_endian() ?  0x4d414e47 : 0x474e414d;
 
-ModelHeader GNAModelSerial::ReadHeader(std::istream &is) {
+GNAPluginNS::HeaderLatest::ModelHeader GNAModelSerial::ReadHeader(std::istream &is) {
     is.exceptions(std::istream::failbit);
-
-    ModelHeader header;
-    readBits(header, is);
+    is.seekg(0, is.end);
+    auto stream_len = is.tellg();
+    if (stream_len == -1) {
+        THROW_GNA_EXCEPTION << "Can't open file to import";
+    }
+    is.seekg(0, is.beg);
+
+    HeaderLatest::ModelHeader header;
+    header.version.major = 0u;
+    header.version.minor = 0u;
+    auto size_of_headers_header = sizeof(HeaderLatest::ModelHeader::gnam) + sizeof(HeaderLatest::ModelHeader::headerSize)
+                                + sizeof(HeaderLatest::ModelHeader::Version);
+    if (stream_len > size_of_headers_header) {
+        readNBytes(&header, size_of_headers_header, is);
+    } else {
+        readNBytes(&header, stream_len, is);
+    }
     if (*reinterpret_cast<int*>(header.gnam) != gna_header_magic) {
         THROW_GNA_EXCEPTION << "Imported file unsupported: magic number should be GNAM(0x474e414d), but was 0x"
                            << std::setfill('0') <<
@@ -82,12 +99,28 @@ ModelHeader GNAModelSerial::ReadHeader(std::istream &is) {
                            std::hex << std::setw(2) << static_cast<short>(header.gnam[2]) <<
                            std::hex << std::setw(2) << static_cast<short>(header.gnam[3]);
     }
-    if (header.version.major != HEADER_MAJOR) {
-        THROW_GNA_EXCEPTION << "Imported file unsupported: major version should be == " << HEADER_MAJOR;
-    }
-    if (header.headerSize < sizeof(header)) {
-        THROW_GNA_EXCEPTION << "Unsupported header size minimal value is : " << sizeof (header) << ", but read: " << header.headerSize;
+
+    is.seekg(0, is.beg);
+    Header2dot1::ModelHeader tempHeader2dot1;
+    switch (header.version.major) {
+        case 2:
+            switch (header.version.minor) {
+                case 1:
+                    readBits(tempHeader2dot1, is);
+                    header = Header2dot3::ModelHeader(tempHeader2dot1);
+                    break;
+                case 2:
+                case 3:
+                    readBits(header, is);
+                    break;
+                default:
+                    THROW_GNA_EXCEPTION << "Imported file unsupported. minor version should be equal to 1 or 2 and is: " << header.version.minor;
+            }
+            break;
+        default:
+            THROW_GNA_EXCEPTION << "Imported file unsupported. Import for files with major version equal to: " << header.version.major << " is not implemented";
     }
+
     /*
      * extra data need to be added into new header and modify check as appropriate
      */
@@ -134,7 +167,30 @@ void GNAModelSerial::Import(void *basePointer,
         InferenceEngine::OutputsDataMap& outputsDataMap) {
     is.exceptions(std::istream::failbit);
 
+    if (modelHeader.version.major == 2) {
+        if (modelHeader.version.minor >= 3) {
+            for (auto inputIndex = 0; inputIndex < modelHeader.nInputs; inputIndex++) {
+                uint32_t nameSize = 0;
+                readNBits<32>(nameSize, is);
+                std::string inName(nameSize, '\0');
+                readNBytes(&inName[0], nameSize, is);
+                inputNames.push_back(inName.substr(0, nameSize - 1));
+            }
+        }
+    }
     ImportInputs(is, basePointer, inputsDesc, inputsDataMap);
+
+    if (modelHeader.version.major == 2) {
+        if (modelHeader.version.minor >= 3) {
+            for (auto inputIndex = 0; inputIndex < modelHeader.nOutputs; inputIndex++) {
+                uint32_t nameSize = 0;
+                readNBits<32>(nameSize, is);
+                std::string outName(nameSize, '\0');
+                readNBytes(&outName[0], nameSize, is);
+                outputNames.push_back(outName.substr(0, nameSize - 1));
+            }
+        }
+    }
     ImportOutputs(is, basePointer, desc, outputsDataMap);
 
     for (auto operation = gna2Model->Operations; operation != gna2Model->Operations + gna2Model->NumberOfOperations; ++operation) {
@@ -173,6 +229,7 @@ void GNAModelSerial::Import(void *basePointer,
         for (uint32_t i = 0; i < operation->NumberOfParameters; i++) {
             uint32_t paramSize = 0;
             readBits(paramSize, is);
+            IE_ASSERT(operation->Parameters != nullptr);
             if (paramSize == 0) {
                 IE_ASSERT(operation->Parameters != nullptr);
                 operation->Parameters[i] = nullptr;
@@ -248,8 +305,8 @@ void GNAModelSerial::Export(void * basePointer, size_t gnaGraphSize, std::ostrea
         return out;
     };
 
-    auto convert_to_serial = [getOffsetFromBase](const GNAModelSerial::RuntimeEndPoint& ep) {
-        RuntimeEndPoint out;
+    auto convert_to_serial = [getOffsetFromBase](const HeaderLatest::RuntimeEndPoint& ep) {
+        HeaderLatest::RuntimeEndPoint out;
         out.elements_count = ep.elements_count;
         out.descriptor_offset = offsetFromBase(ep.descriptor_ptr);
         out.scaleFactor = ep.scaleFactor;
@@ -260,14 +317,12 @@ void GNAModelSerial::Export(void * basePointer, size_t gnaGraphSize, std::ostrea
     /**
      * writing header
      */
-    ModelHeader header;
+    HeaderLatest::ModelHeader header;
     header.gnam[0] = 'G';
     header.gnam[1] = 'N';
     header.gnam[2] = 'A';
     header.gnam[3] = 'M';
-    header.headerSize = sizeof(ModelHeader);
-    header.version.major = HEADER_MAJOR;
-    header.version.minor = HEADER_MINOR;
+    header.headerSize = sizeof(HeaderLatest::ModelHeader);
     header.gnaMemSize = gnaGraphSize;
     header.layersCount = layers.size();
     header.nGroup = guessGrouping(*gna2Model);
@@ -280,9 +335,19 @@ void GNAModelSerial::Export(void * basePointer, size_t gnaGraphSize, std::ostrea
 
     writeBits(header, os);
 
+    for (auto &name : inputNames) {
+        const auto nameSize = strlen(name.c_str()) + 1;
+        writeBits(static_cast<uint32_t>(nameSize), os);
+        writeNBytes(name.c_str(), nameSize , os);
+    }
     for (const auto &input : inputs) {
         writeBits(convert_to_serial(input), os);
     }
+    for (auto &name : outputNames) {
+        const auto nameSize = strlen(name.c_str()) + 1;
+        writeBits(static_cast<uint32_t>(nameSize), os);
+        writeNBytes(name.c_str(), nameSize, os);
+    }
     for (const auto &output : outputs) {
         writeBits(convert_to_serial(output), os);
     }
@@ -495,8 +560,8 @@ void GNAModelSerial::Export(void * basePointer, size_t gnaGraphSize, std::ostrea
         }
     };
 
-    auto convert_to_serial = [getOffsetFromBase](const GNAModelSerial::RuntimeEndPoint& ep){
-        RuntimeEndPoint out;
+    auto convert_to_serial = [getOffsetFromBase](const HeaderLatest::RuntimeEndPoint& ep){
+        HeaderLatest::RuntimeEndPoint out;
         out.elements_count = ep.elements_count;
         out.element_size = ep.element_size;
         out.descriptor_offset = offsetFromBase(ep.descriptor_ptr);
@@ -507,19 +572,19 @@ void GNAModelSerial::Export(void * basePointer, size_t gnaGraphSize, std::ostrea
     /**
      * writing header
      */
-    ModelHeader header;
+    HeaderLatest::ModelHeader header;
     header.gnam[0] = 'G';
     header.gnam[1] = 'N';
     header.gnam[2] = 'A';
     header.gnam[3] = 'M';
-    header.version.major = HEADER_MAJOR;
-    header.version.minor = HEADER_MINOR;
+    header.version.major = 1u;
+    header.version.minor = 1u;
     header.gnaMemSize = gnaGraphSize;
     header.layersCount = layers.size();
     header.nGroup = ptr_nnet->nGroup;
     header.nInputs = 1;
     header.nOutputs = 1;
-    header.headerSize = sizeof(ModelHeader);
+    header.headerSize = sizeof(HeaderLatest::ModelHeader);
     header.nRotateRows = nRotateRows;
     header.nRotateColumns = nRotateColumns;
 
@@ -606,16 +671,16 @@ void GNAModelSerial::Export(void * basePointer, size_t gnaGraphSize, std::ostrea
 
 #endif
 
-std::vector<GNAModelSerial::RuntimeEndPoint> GNAModelSerial::serializeOutputs(const InferenceEngine::OutputsDataMap& outputsDataMap,
+std::vector<HeaderLatest::RuntimeEndPoint> GNAModelSerial::serializeOutputs(const InferenceEngine::OutputsDataMap& outputsDataMap,
         const std::vector<GNAPluginNS::OutputDesc>& outputsDesc) {
-    std::vector<GNAModelSerial::RuntimeEndPoint> endPoints;
+    std::vector<HeaderLatest::RuntimeEndPoint> endPoints;
     std::size_t outputIndex = 0;
     for (auto const &output : outputsDataMap) {
         auto outputName = output.first;
         auto inputDims = output.second->getTensorDesc().getDims();
         uint32_t elementsCount = static_cast<uint32_t>(InferenceEngine::details::product(inputDims.begin(), inputDims.end()));
 
-        GNAModelSerial::RuntimeEndPoint endPoint(outputsDesc[outputIndex].scale_factor,
+        HeaderLatest::RuntimeEndPoint endPoint(outputsDesc[outputIndex].scale_factor,
                                                  outputsDesc[outputIndex].ptrs[0],
                                                  outputsDesc[outputIndex].num_bytes_per_element,
                                                  elementsCount,
@@ -626,9 +691,9 @@ std::vector<GNAModelSerial::RuntimeEndPoint> GNAModelSerial::serializeOutputs(co
     return endPoints;
 }
 
-std::vector<GNAModelSerial::RuntimeEndPoint> GNAModelSerial::serializeInputs(const InferenceEngine::InputsDataMap& inputsDataMap,
+std::vector<HeaderLatest::RuntimeEndPoint> GNAModelSerial::serializeInputs(const InferenceEngine::InputsDataMap& inputsDataMap,
                                                                              std::shared_ptr<GNAPluginNS::InputDesc> inputDesc) {
-    std::vector<GNAModelSerial::RuntimeEndPoint> endPoints;
+    std::vector<HeaderLatest::RuntimeEndPoint> endPoints;
 
     std::size_t inputIndex = 0;
     for (auto const& input : inputsDataMap) {
@@ -642,7 +707,7 @@ std::vector<GNAModelSerial::RuntimeEndPoint> GNAModelSerial::serializeInputs(con
         uint32_t elementsCount = static_cast<uint32_t>(InferenceEngine::details::product(inputDims.begin(), inputDims.end()));
         intel_dnn_orientation_t orientation = inputDesc->getOrientation(inputName);
 
-        GNAModelSerial::RuntimeEndPoint endPoint(scaleFactor,
+        HeaderLatest::RuntimeEndPoint endPoint(scaleFactor,
                                                  descriptor_ptr[0],
                                                  element_size,
                                                  elementsCount,
@@ -660,11 +725,13 @@ void GNAModelSerial::ImportInputs(std::istream &is,
     dataMap.clear();
 
     for (auto inputIndex = 0; inputIndex < modelHeader.nInputs; inputIndex++) {
-        std::string name = "input" + std::to_string(inputIndex);
-        RuntimeEndPoint input;
+        const std::string& name = (modelHeader.version.major == 2 && modelHeader.version.minor >= 3)
+                ? inputNames.at(inputIndex) : std::string("input" + std::to_string(inputIndex));
+        HeaderLatest::RuntimeEndPoint input;
         is.read(reinterpret_cast<char *>(&input), sizeof(input));
         inputsDesc->getPtrInputsGlobal(name).push_back(reinterpret_cast<float*>(reinterpret_cast<uint8_t *> (basePtr) + input.descriptor_offset));
         inputsDesc->orientation_in[name] = input.orientation;
+        inputsDesc->bytes_allocated_for_input[name] = input.element_size * input.elements_count;
 
         auto inputDims = InferenceEngine::SizeVector({modelHeader.nGroup, input.elements_count / modelHeader.nGroup});
 
@@ -687,10 +754,11 @@ void GNAModelSerial::ImportOutputs(std::istream &is,
     desc.resize(modelHeader.nOutputs);
 
     for (auto outputIndex = 0; outputIndex < modelHeader.nOutputs; outputIndex++) {
-        std::string name = "output" + std::to_string(outputIndex);
-        RuntimeEndPoint output;
+        const std::string& name = (modelHeader.version.major == 2 && modelHeader.version.minor >= 3)
+                                  ? outputNames.at(outputIndex) : std::string("input" + std::to_string(outputIndex));
+        HeaderLatest::RuntimeEndPoint output;
         is.read(reinterpret_cast<char *>(&output), sizeof(output));
-        GNAPluginNS::OutputDesc description;
+        OutputDesc description;
         description.ptrs.push_back(reinterpret_cast<float*>(reinterpret_cast<uint8_t *> (basePtr) + output.descriptor_offset));
         description.orientation = kDnnInterleavedOrientation;
         description.orientation = output.orientation;
@@ -707,6 +775,6 @@ void GNAModelSerial::ImportOutputs(std::istream &is,
     }
 }
 
-void GNAModelSerial::setHeader(ModelHeader header) {
+void GNAModelSerial::setHeader(HeaderLatest::ModelHeader header) {
     modelHeader = header;
 }