IVGCVSW-5483 'Implement Loading and Saving to File'
[platform/upstream/armnn.git] / src / backends / cl / ClContextDeserializer.cpp
1 //
2 // Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "ClContextDeserializer.hpp"
7 #include "ClContextSchema_generated.h"
8
9 #include <armnn/Exceptions.hpp>
10 #include <armnn/utility/NumericCast.hpp>
11
12 #include <flatbuffers/flexbuffers.h>
13
14 #include <fmt/format.h>
15
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <iterator>
#include <vector>
20
21 namespace armnn
22 {
23
24 void ClContextDeserializer::Deserialize(arm_compute::CLCompileContext& clCompileContext,
25                                         cl::Context& context,
26                                         cl::Device& device,
27                                         const std::string& filePath)
28 {
29     std::ifstream inputFileStream(filePath, std::ios::binary);
30     std::vector<std::uint8_t> binaryContent;
31     while (inputFileStream)
32     {
33         char input;
34         inputFileStream.get(input);
35         if (inputFileStream)
36         {
37             binaryContent.push_back(static_cast<std::uint8_t>(input));
38         }
39     }
40     inputFileStream.close();
41     DeserializeFromBinary(clCompileContext, context, device, binaryContent);
42 }
43
44 void ClContextDeserializer::DeserializeFromBinary(arm_compute::CLCompileContext& clCompileContext,
45                                                   cl::Context& context,
46                                                   cl::Device& device,
47                                                   const std::vector<uint8_t>& binaryContent)
48 {
49     if (binaryContent.data() == nullptr)
50     {
51         throw InvalidArgumentException(fmt::format("Invalid (null) binary content {}",
52                                                    CHECK_LOCATION().AsString()));
53     }
54
55     size_t binaryContentSize = binaryContent.size();
56     flatbuffers::Verifier verifier(binaryContent.data(), binaryContentSize);
57     if (verifier.VerifyBuffer<ClContext>() == false)
58     {
59         throw ParseException(fmt::format("Buffer doesn't conform to the expected Armnn "
60                                          "flatbuffers format. size:{0} {1}",
61                                          binaryContentSize,
62                                          CHECK_LOCATION().AsString()));
63     }
64     auto clContext = GetClContext(binaryContent.data());
65
66     for (Program const* program : *clContext->programs())
67     {
68         auto programName = program->name()->c_str();
69         auto programBinary = program->binary();
70         std::vector<uint8_t> binary(programBinary->begin(), programBinary->begin() + programBinary->size());
71
72         cl::Program::Binaries   binaries{ binary };
73         std::vector<cl::Device> devices {device};
74         cl::Program             theProgram(context, devices, binaries);
75         theProgram.build();
76         clCompileContext.add_built_program(programName, theProgram);
77     }
78 }
79
80 } // namespace armnn