Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / tests / tools / onert_train / src / rawformatter.cc
1 /*
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "rawformatter.h"
18 #include "nnfw.h"
19 #include "nnfw_util.h"
20
21 #include <iostream>
22 #include <fstream>
23 #include <stdexcept>
24
25 namespace onert_train
26 {
27 void RawFormatter::loadInputs(const std::string &filename, std::vector<Allocation> &inputs)
28 {
29   uint32_t num_inputs;
30   NNPR_ENSURE_STATUS(nnfw_input_size(session_, &num_inputs));
31
32   // Support multiple inputs
33   // Option 1: Get comman-separated input file list like --load:raw a,b,c
34   // Option 2: Get prefix --load:raw in
35   //           Internally access in.0, in.1, in.2, ... in.{N-1} where N is determined by nnfw info
36   //           query api.
37   //
38   // Currently Option 2 is implemented.
39   try
40   {
41     for (uint32_t i = 0; i < num_inputs; ++i)
42     {
43       nnfw_tensorinfo ti;
44       NNPR_ENSURE_STATUS(nnfw_input_tensorinfo(session_, i, &ti));
45
46       // allocate memory for data
47       auto bufsz = bufsize_for(&ti);
48       inputs[i].alloc(bufsz);
49
50       std::ifstream file(filename + "." + std::to_string(i), std::ios::ate | std::ios::binary);
51       auto filesz = file.tellg();
52       if (bufsz != filesz)
53       {
54         throw std::runtime_error("Input " + std::to_string(i) +
55                                  " size does not match: " + std::to_string(bufsz) +
56                                  " expected, but " + std::to_string(filesz) + " provided.");
57       }
58       file.seekg(0, std::ios::beg);
59       file.read(reinterpret_cast<char *>(inputs[i].data()), filesz);
60       file.close();
61
62       NNPR_ENSURE_STATUS(nnfw_set_input(session_, i, ti.dtype, inputs[i].data(), bufsz));
63       NNPR_ENSURE_STATUS(nnfw_set_input_layout(session_, i, NNFW_LAYOUT_CHANNELS_LAST));
64     }
65   }
66   catch (const std::exception &e)
67   {
68     std::cerr << e.what() << std::endl;
69     std::exit(-1);
70   }
71 };
72
73 void RawFormatter::dumpOutputs(const std::string &filename, std::vector<Allocation> &outputs)
74 {
75   uint32_t num_outputs;
76   NNPR_ENSURE_STATUS(nnfw_output_size(session_, &num_outputs));
77   try
78   {
79     for (uint32_t i = 0; i < num_outputs; i++)
80     {
81       nnfw_tensorinfo ti;
82       NNPR_ENSURE_STATUS(nnfw_output_tensorinfo(session_, i, &ti));
83       auto bufsz = bufsize_for(&ti);
84
85       std::ofstream file(filename + "." + std::to_string(i), std::ios::out | std::ios::binary);
86       file.write(reinterpret_cast<const char *>(outputs[i].data()), bufsz);
87       file.close();
88       std::cerr << filename + "." + std::to_string(i) + " is generated.\n";
89     }
90   }
91   catch (const std::runtime_error &e)
92   {
93     std::cerr << "Error during dumpOutputs on onert_run : " << e.what() << std::endl;
94     std::exit(-1);
95   }
96 }
97 } // end of namespace onert_train