2 // Copyright (c) 2016 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 #include "auto_tuner.h"
22 #include "istreamwrapper.h"
23 #include "stringbuffer.h"
24 #include "prettywriter.h"
27 namespace kernel_selector
29 std::tuple<std::string, int> AutoTuner::LoadKernelOnline(const TuningMode tuningMode, const std::string& cacheFilePath, const uint32_t computeUnitsCount, const std::string& hash)
31 std::lock_guard<std::mutex> lock(mutex);
32 rapidjson::Document cacheData;
33 std::ifstream tuningFile(cacheFilePath);
34 if (tuningFile && tuningFile.good())
36 rapidjson::IStreamWrapper isw{ tuningFile };
37 cacheData.ParseStream(isw);
39 else // Tuning file doesn't exist
41 if (tuningMode == TuningMode::TUNING_USE_CACHE)
43 throw std::runtime_error("Tuning file: " + cacheFilePath + " could not be read! Must provide a valid cache file in USE_CACHE mode.");
46 // Create a new tuning file and write the versions
47 std::ofstream newTuningFile(cacheFilePath, std::ofstream::out);
52 onlineCache = std::make_shared<rapidjson::Document>(std::move(cacheData));
54 // Tuning file is loaded
55 auto computeUnitsStr = std::to_string(computeUnitsCount);
56 if (!onlineCache->IsNull())
58 auto cacheObject = onlineCache->GetObject();
59 if (onlineCache->HasMember(computeUnitsStr.c_str()))
61 if (cacheObject[computeUnitsStr.c_str()].HasMember(hash.c_str()))
63 const rapidjson::Value& prog = cacheObject[computeUnitsStr.c_str()][hash.c_str()];
64 return std::make_tuple(prog[0].GetString(), prog[1].GetInt());
68 return std::make_pair("", 0);
72 void AutoTuner::StoreKernel(const std::string& cacheFilePath, const std::string& hash, std::string implementationName, const int tuneIndex, const uint32_t computeUnitsCount)
74 std::lock_guard<std::mutex> lock(mutex);
75 auto computeUnitsStr = std::to_string(computeUnitsCount);
76 rapidjson::Document::AllocatorType& allocator = onlineCache->GetAllocator();
77 rapidjson::Value dataArray(rapidjson::kArrayType);
78 rapidjson::Value hashStr(rapidjson::kStringType);
79 hashStr.Set(hash.c_str(), allocator);
80 dataArray.PushBack(rapidjson::Value().Set(implementationName.c_str(),allocator) , allocator);
81 dataArray.PushBack(rapidjson::Value().SetInt(tuneIndex), allocator);
83 rapidjson::Value newVal(rapidjson::kObjectType);
85 if (onlineCache->IsNull())
87 onlineCache->Parse("{}");
89 if (!onlineCache->HasMember(computeUnitsStr.c_str()))
91 onlineCache->AddMember(rapidjson::Value(computeUnitsStr.c_str(), allocator), newVal, allocator);
94 auto cache = onlineCache->GetObject();
95 cache[computeUnitsStr.c_str()].AddMember(hashStr, dataArray, allocator);
97 std::ofstream cachedKernelsFile(cacheFilePath);
98 rapidjson::StringBuffer buffer(0, 1024);
99 rapidjson::PrettyWriter<rapidjson::StringBuffer> writer(buffer);
100 onlineCache->Accept(writer);
101 auto temp = buffer.GetString();
102 cachedKernelsFile << temp;
103 cachedKernelsFile.close();
107 std::tuple<std::string, int> AutoTuner::LoadKernelOffline(std::shared_ptr<rapidjson::Document> deviceCache, const std::string& hash)
109 if (!deviceCache->IsNull())
111 auto cache = deviceCache->GetObject();
112 if (deviceCache->HasMember(hash.c_str()))
114 const rapidjson::Value& prog = cache[hash.c_str()];
115 return std::make_tuple(prog[0].GetString(), prog[1].GetInt());
118 return std::make_tuple("", 0);