Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / kernel_selector / core / auto_tuner.cpp
index d9ccd15..307390d 100644 (file)
 */
 
 #include "auto_tuner.h"
-#include "auto_tuner_offline.h"
 #include <iostream>
 #include <sstream>
 #include <fstream>
+#include <iomanip>
+#include "istreamwrapper.h"
+#include "stringbuffer.h"
+#include "prettywriter.h"
 
-namespace kernel_selector 
+
+namespace kernel_selector
 {
-    std::tuple<std::string, int> AutoTuner::LoadKernelOnline(const TuningMode tuningMode, const std::string& tuningFilePath, const std::string& deviceID, const std::string& driverVersion, const std::string& hostVersion, const std::string& hash)
+    std::tuple<std::string, int> AutoTuner::LoadKernelOnline(const TuningMode tuningMode, const std::string& cacheFilePath, const uint32_t computeUnitsCount,  const std::string& hash)
     {
         std::lock_guard<std::mutex> lock(mutex);
-
-        //First, check if the tuning file has been already loaded to cache
-        auto const& tuningFileCache = onlineCache.find(tuningFilePath);
-        if (tuningFileCache == onlineCache.end())
+        rapidjson::Document cacheData;
+        std::ifstream tuningFile(cacheFilePath);
+        if (tuningFile && tuningFile.good())
         {
-            // Load tuning file to cache
-            onlineCache[tuningFilePath] = {};
-
-            std::ifstream tuningFile(tuningFilePath);
-            std::string cachedDeviceId;
-            std::string cachedDriverVersion;
-            std::string cachedHostVersion;
-            std::string cachedhash;
-            std::string cachedkernelName;
-            int cachedIndex;
-            std::string line;
-
-            if (tuningFile) // Tuning file exists
+            rapidjson::IStreamWrapper isw{ tuningFile };
+            cacheData.ParseStream(isw);
+        }
+        else // Tuning file doesn't exist
+        {
+            if (tuningMode == TuningMode::TUNING_USE_CACHE)
             {
-                // Read device ID
-                tuningFile >> cachedDeviceId;
-                if (!tuningFile.good() || (cachedDeviceId.compare(deviceID) != 0))
-                {
-                    throw std::runtime_error("Tuning file bad structure or wrong device ID. Re-generate cache in TUNE_AND_CACHE mode.");
-                }
-
-                // Read driver version
-                tuningFile >> cachedDriverVersion;
-                if (!tuningFile.good() || (cachedDriverVersion.compare(driverVersion) != 0))
-                {
-                    throw std::runtime_error("Tuning file bad structure or wrong driver version. Re-generate cache in TUNE_AND_CACHE mode.");
-                }
+                throw std::runtime_error("Tuning file: " + cacheFilePath + " could not be read! Must provide a valid cache file in USE_CACHE mode.");
+            }
 
-                // Read host version
-                tuningFile >> cachedHostVersion;
-                if (!tuningFile.good() || (cachedHostVersion.compare(hostVersion) != 0))
-                {
-                    throw std::runtime_error("Tuning file bad structure or wrong host version. Re-generate cache in TUNE_AND_CACHE mode.");
-                }
+            // Create a new tuning file and write the versions
+            std::ofstream newTuningFile(cacheFilePath, std::ofstream::out);
 
-                // Read optimal kernel/config data 
-                while (std::getline(tuningFile, line))
-                {
-                    if (line.empty())
-                    {
-                        continue;
-                    }
-                    std::istringstream iss(line);
-                    iss >> cachedhash >> cachedkernelName >> cachedIndex;
-                    if (iss.fail())
-                    {
-                        throw std::runtime_error("Tuning file bad structure. Re-generate cache in TUNE_AND_CACHE mode.");
-                    }
+        }
+        tuningFile.close();
 
-                    // Update tuning cache 
-                    onlineCache[tuningFilePath].td[cachedhash] = std::make_tuple(cachedkernelName, cachedIndex);
-                }
+        onlineCache = std::make_shared<rapidjson::Document>(std::move(cacheData));
 
-                tuningFile.close();
-            }
-            else // Tuning file doesn't exist
+        // Tuning file is loaded
+        auto computeUnitsStr = std::to_string(computeUnitsCount);
+        if (!onlineCache->IsNull())
+        {
+            auto cacheObject = onlineCache->GetObject();
+            if (onlineCache->HasMember(computeUnitsStr.c_str()))
             {
-                if (tuningMode == TuningMode::TUNING_USE_CACHE)
+                if (cacheObject[computeUnitsStr.c_str()].HasMember(hash.c_str()))
                 {
-                    throw std::runtime_error("Tuning file: " + tuningFilePath + " could not be read! Must provide a valid cache file in USE_CACHE mode.");
+                    const rapidjson::Value& prog = cacheObject[computeUnitsStr.c_str()][hash.c_str()];
+                    return std::make_tuple(prog[0].GetString(), prog[1].GetInt());
                 }
-
-                // Create a new tuning file and write the versions
-                std::ofstream newTuningFile(tuningFilePath, std::ofstream::out);
-
-                newTuningFile << deviceID << "\n";
-                newTuningFile << driverVersion << "\n";
-                newTuningFile << hostVersion << "\n";
             }
         }
+        return std::make_pair("", 0);
+        
+    }
 
-        // Tuning file is loaded
-        auto const& tuningFileData = onlineCache[tuningFilePath];
-        auto const& hashData = tuningFileData.td.find(hash);
-        if (hashData != tuningFileData.td.end())
+    void AutoTuner::StoreKernel(const std::string& cacheFilePath, const std::string& hash, std::string implementationName, const int tuneIndex, const uint32_t computeUnitsCount)
+    {
+        std::lock_guard<std::mutex> lock(mutex);
+        auto computeUnitsStr = std::to_string(computeUnitsCount);
+        rapidjson::Document::AllocatorType& allocator = onlineCache->GetAllocator();
+        rapidjson::Value dataArray(rapidjson::kArrayType);
+        rapidjson::Value hashStr(rapidjson::kStringType);
+        hashStr.Set(hash.c_str(), allocator);
+        dataArray.PushBack(rapidjson::Value().Set(implementationName.c_str(),allocator) , allocator);
+        dataArray.PushBack(rapidjson::Value().SetInt(tuneIndex), allocator);
+
+        rapidjson::Value newVal(rapidjson::kObjectType);
+        newVal.SetObject();
+        if (onlineCache->IsNull())
         {
-            // Tuning data exists for this hash.
-            return hashData->second;
+            onlineCache->Parse("{}");
         }
-        else
+        if (!onlineCache->HasMember(computeUnitsStr.c_str()))
         {
-            // Tuning data doesn't exists for this hash - on-line tuning is needed.
-            return std::make_pair("", 0);
+            onlineCache->AddMember(rapidjson::Value(computeUnitsStr.c_str(), allocator), newVal, allocator);
         }
-    }
-
-    void AutoTuner::StoreKernel(const std::string& tuningFilePath, const std::string& hash, const std::string& implementationName, const int tuneIndex)
-    {
-        std::lock_guard<std::mutex> lock(mutex);
 
-        // Add the new tuning data to cache
-        onlineCache[tuningFilePath].td[hash] = std::make_tuple(implementationName, tuneIndex);
+        auto cache = onlineCache->GetObject();
+        cache[computeUnitsStr.c_str()].AddMember(hashStr, dataArray, allocator);
 
-        // Add the new tuning data to tuning file
-        std::ofstream cachedKernelsFile(tuningFilePath, std::ofstream::out | std::ofstream::app);
-        if (!cachedKernelsFile.good())
-        {
-            throw std::runtime_error("Tuning file: " + tuningFilePath + " could not be written!");
-        }
-        cachedKernelsFile << hash << " ";
-        cachedKernelsFile << implementationName << " ";
-        cachedKernelsFile << tuneIndex << "\n";
+        std::ofstream cachedKernelsFile(cacheFilePath);
+        rapidjson::StringBuffer buffer(0, 1024);
+        rapidjson::PrettyWriter<rapidjson::StringBuffer> writer(buffer);
+        onlineCache->Accept(writer);
+        auto temp = buffer.GetString();
+        cachedKernelsFile << temp;
         cachedKernelsFile.close();
     }
 
-    std::tuple<std::string, int> AutoTuner::LoadKernelOffline(const std::string& deviceID, const std::string& hash)
+
+    std::tuple<std::string, int> AutoTuner::LoadKernelOffline(std::shared_ptr<rapidjson::Document> deviceCache, const std::string& hash)
     {
-        auto const& deviceCache = auto_tuner_offline::get_instance(deviceID)->get_tuning_data();
-        if (deviceCache.td.empty())
-        {
-            return std::make_pair("", 0);
-        }
-        auto const& deviceCacheData = deviceCache.td;
-        auto const& hashData = deviceCacheData.find(hash);
-        if (hashData == deviceCacheData.end())
+        if (!deviceCache->IsNull())
         {
-            return std::make_pair("", 0);
-        }
-        else
-        {
-            return hashData->second;
+            auto cache = deviceCache->GetObject();
+            if (deviceCache->HasMember(hash.c_str()))
+            {
+                const rapidjson::Value& prog = cache[hash.c_str()];
+                return std::make_tuple(prog[0].GetString(), prog[1].GetInt());
+            }
         }
+        return std::make_tuple("", 0);
     }
 }