*/
#include "auto_tuner.h"
-#include "auto_tuner_offline.h"
#include <iostream>
#include <sstream>
#include <fstream>
+#include <iomanip>
+#include "istreamwrapper.h"
+#include "stringbuffer.h"
+#include "prettywriter.h"
-
-namespace kernel_selector
+
+namespace kernel_selector
{
- std::tuple<std::string, int> AutoTuner::LoadKernelOnline(const TuningMode tuningMode, const std::string& tuningFilePath, const std::string& deviceID, const std::string& driverVersion, const std::string& hostVersion, const std::string& hash)
+ std::tuple<std::string, int> AutoTuner::LoadKernelOnline(const TuningMode tuningMode, const std::string& cacheFilePath, const uint32_t computeUnitsCount, const std::string& hash) // Re-reads the JSON cache at cacheFilePath and returns the (implementation name, tune index) stored for hash under the computeUnitsCount key, or ("", 0) if there is none; throws std::runtime_error in TUNING_USE_CACHE mode when the file cannot be opened.
{
std::lock_guard<std::mutex> lock(mutex);
-
- //First, check if the tuning file has been already loaded to cache
- auto const& tuningFileCache = onlineCache.find(tuningFilePath);
- if (tuningFileCache == onlineCache.end())
+ rapidjson::Document cacheData; // Left untouched when the file cannot be opened; IsNull() is tested further down
+ std::ifstream tuningFile(cacheFilePath); // The cache file is opened again on every call
+ if (tuningFile && tuningFile.good()) // Parse only when the stream opened successfully
{
- // Load tuning file to cache
- onlineCache[tuningFilePath] = {};
-
- std::ifstream tuningFile(tuningFilePath);
- std::string cachedDeviceId;
- std::string cachedDriverVersion;
- std::string cachedHostVersion;
- std::string cachedhash;
- std::string cachedkernelName;
- int cachedIndex;
- std::string line;
-
- if (tuningFile) // Tuning file exists
+ rapidjson::IStreamWrapper isw{ tuningFile };
+ cacheData.ParseStream(isw); // NOTE(review): the parse result is not inspected here (no HasParseError() check)
+ }
+ else // Tuning file doesn't exist or could not be opened
+ {
+ if (tuningMode == TuningMode::TUNING_USE_CACHE) // This mode requires an existing cache file, so a missing one is an error
{
- // Read device ID
- tuningFile >> cachedDeviceId;
- if (!tuningFile.good() || (cachedDeviceId.compare(deviceID) != 0))
- {
- throw std::runtime_error("Tuning file bad structure or wrong device ID. Re-generate cache in TUNE_AND_CACHE mode.");
- }
-
- // Read driver version
- tuningFile >> cachedDriverVersion;
- if (!tuningFile.good() || (cachedDriverVersion.compare(driverVersion) != 0))
- {
- throw std::runtime_error("Tuning file bad structure or wrong driver version. Re-generate cache in TUNE_AND_CACHE mode.");
- }
+ throw std::runtime_error("Tuning file: " + cacheFilePath + " could not be read! Must provide a valid cache file in USE_CACHE mode.");
+ }
- // Read host version
- tuningFile >> cachedHostVersion;
- if (!tuningFile.good() || (cachedHostVersion.compare(hostVersion) != 0))
- {
- throw std::runtime_error("Tuning file bad structure or wrong host version. Re-generate cache in TUNE_AND_CACHE mode.");
- }
+ // Create a new, empty tuning file (nothing is written to it here)
+ std::ofstream newTuningFile(cacheFilePath, std::ofstream::out); // Closed when newTuningFile goes out of scope at the end of this else-branch
- // Read optimal kernel/config data
- while (std::getline(tuningFile, line))
- {
- if (line.empty())
- {
- continue;
- }
- std::istringstream iss(line);
- iss >> cachedhash >> cachedkernelName >> cachedIndex;
- if (iss.fail())
- {
- throw std::runtime_error("Tuning file bad structure. Re-generate cache in TUNE_AND_CACHE mode.");
- }
+ }
+ tuningFile.close();
- // Update tuning cache
- onlineCache[tuningFilePath].td[cachedhash] = std::make_tuple(cachedkernelName, cachedIndex);
- }
+ onlineCache = std::make_shared<rapidjson::Document>(std::move(cacheData)); // Replaces the member cache with the document built in this call
- tuningFile.close();
- }
- else // Tuning file doesn't exist
+ // Tuning file is loaded
+ auto computeUnitsStr = std::to_string(computeUnitsCount); // Top-level JSON keys are compute-unit counts rendered as decimal strings
+ if (!onlineCache->IsNull()) // Presumably still null when nothing was parsed (missing file) - TODO confirm for malformed input
+ {
+ auto cacheObject = onlineCache->GetObject(); // NOTE(review): presumes the document root is a JSON object; not validated here
+ if (onlineCache->HasMember(computeUnitsStr.c_str()))
{
- if (tuningMode == TuningMode::TUNING_USE_CACHE)
+ if (cacheObject[computeUnitsStr.c_str()].HasMember(hash.c_str())) // Second level is keyed by the kernel hash
{
- throw std::runtime_error("Tuning file: " + tuningFilePath + " could not be read! Must provide a valid cache file in USE_CACHE mode.");
+ const rapidjson::Value& prog = cacheObject[computeUnitsStr.c_str()][hash.c_str()]; // Entry is read as an array: [implementation name, tune index]
+ return std::make_tuple(prog[0].GetString(), prog[1].GetInt()); // NOTE(review): element count and types are not checked before access
}
-
- // Create a new tuning file and write the versions
- std::ofstream newTuningFile(tuningFilePath, std::ofstream::out);
-
- newTuningFile << deviceID << "\n";
- newTuningFile << driverVersion << "\n";
- newTuningFile << hostVersion << "\n";
}
}
+ return std::make_pair("", 0); // No cached entry for this compute-unit count / hash
+
+ }
- // Tuning file is loaded
- auto const& tuningFileData = onlineCache[tuningFilePath];
- auto const& hashData = tuningFileData.td.find(hash);
- if (hashData != tuningFileData.td.end())
+ void AutoTuner::StoreKernel(const std::string& cacheFilePath, const std::string& hash, std::string implementationName, const int tuneIndex, const uint32_t computeUnitsCount) // Adds hash -> [implementationName, tuneIndex] under the computeUnitsCount key of onlineCache, then rewrites cacheFilePath with the whole document as pretty-printed JSON.
+ {
+ std::lock_guard<std::mutex> lock(mutex); // Same mutex that LoadKernelOnline locks
+ auto computeUnitsStr = std::to_string(computeUnitsCount); // Top-level key: the compute-unit count as a decimal string
+ rapidjson::Document::AllocatorType& allocator = onlineCache->GetAllocator(); // NOTE(review): dereferences onlineCache without a null check - assumes LoadKernelOnline has already assigned it; confirm
+ rapidjson::Value dataArray(rapidjson::kArrayType); // Value stored for the hash: [implementation name, tune index]
+ rapidjson::Value hashStr(rapidjson::kStringType);
+ hashStr.Set(hash.c_str(), allocator); // Copies the hash text using the document's allocator
+ dataArray.PushBack(rapidjson::Value().Set(implementationName.c_str(),allocator) , allocator);
+ dataArray.PushBack(rapidjson::Value().SetInt(tuneIndex), allocator);
+
+ rapidjson::Value newVal(rapidjson::kObjectType); // Empty object used if this compute-unit count has no entry yet
+ newVal.SetObject(); // Redundant: newVal was already constructed as kObjectType
+ if (onlineCache->IsNull()) // Nothing loaded so far: turn the document into an empty root object
{
- // Tuning data exists for this hash.
- return hashData->second;
+ onlineCache->Parse("{}");
}
- else
+ if (!onlineCache->HasMember(computeUnitsStr.c_str())) // Create the per-compute-unit object on first use
{
- // Tuning data doesn't exists for this hash - on-line tuning is needed.
- return std::make_pair("", 0);
+ onlineCache->AddMember(rapidjson::Value(computeUnitsStr.c_str(), allocator), newVal, allocator);
}
- }
-
- void AutoTuner::StoreKernel(const std::string& tuningFilePath, const std::string& hash, const std::string& implementationName, const int tuneIndex)
- {
- std::lock_guard<std::mutex> lock(mutex);
- // Add the new tuning data to cache
- onlineCache[tuningFilePath].td[hash] = std::make_tuple(implementationName, tuneIndex);
+ auto cache = onlineCache->GetObject();
+ cache[computeUnitsStr.c_str()].AddMember(hashStr, dataArray, allocator); // NOTE(review): no check for an existing member with this hash - storing the same hash twice presumably appends a duplicate key; confirm against RapidJSON docs
- // Add the new tuning data to tuning file
- std::ofstream cachedKernelsFile(tuningFilePath, std::ofstream::out | std::ofstream::app);
- if (!cachedKernelsFile.good())
- {
- throw std::runtime_error("Tuning file: " + tuningFilePath + " could not be written!");
- }
- cachedKernelsFile << hash << " ";
- cachedKernelsFile << implementationName << " ";
- cachedKernelsFile << tuneIndex << "\n";
+ std::ofstream cachedKernelsFile(cacheFilePath); // Opened without the append flag, so the file is rewritten in full; NOTE(review): open failure is not checked here
+ rapidjson::StringBuffer buffer(0, 1024);
+ rapidjson::PrettyWriter<rapidjson::StringBuffer> writer(buffer);
+ onlineCache->Accept(writer); // Serialises the entire document, not just the new entry
+ auto temp = buffer.GetString();
+ cachedKernelsFile << temp;
cachedKernelsFile.close();
}
- std::tuple<std::string, int> AutoTuner::LoadKernelOffline(const std::string& deviceID, const std::string& hash)
+
+ std::tuple<std::string, int> AutoTuner::LoadKernelOffline(std::shared_ptr<rapidjson::Document> deviceCache, const std::string& hash) // Looks up hash at the top level of deviceCache and returns the stored (string, int) pair, or ("", 0) when the document is null or has no such member.
{
- auto const& deviceCache = auto_tuner_offline::get_instance(deviceID)->get_tuning_data();
- if (deviceCache.td.empty())
- {
- return std::make_pair("", 0);
- }
- auto const& deviceCacheData = deviceCache.td;
- auto const& hashData = deviceCacheData.find(hash);
- if (hashData == deviceCacheData.end())
+ if (!deviceCache->IsNull()) // NOTE(review): deviceCache itself is dereferenced without a null-pointer check
{
- return std::make_pair("", 0);
- }
- else
- {
- return hashData->second;
+ auto cache = deviceCache->GetObject(); // NOTE(review): presumes the document root is a JSON object; not validated here
+ if (deviceCache->HasMember(hash.c_str())) // The hash is looked up directly on the root; no compute-unit level is consulted in this function
+ {
+ const rapidjson::Value& prog = cache[hash.c_str()]; // Entry is read as an array: [string, int]
+ return std::make_tuple(prog[0].GetString(), prog[1].GetInt()); // NOTE(review): element count and types are not checked before access
+ }
}
+ return std::make_tuple("", 0); // No entry for this hash
}
}