Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / tests / tools / onert_train / src / randomgen.cc
1 /*
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "randomgen.h"
18 #include "nnfw.h"
19 #include "nnfw_util.h"
20 #include "misc/RandomGenerator.h"
21
22 #include <iostream>
23
24 namespace onert_train
25 {
26
27 template <class T> void randomData(nnfw::misc::RandomGenerator &randgen, void *data, uint64_t size)
28 {
29   for (uint64_t i = 0; i < size; i++)
30     reinterpret_cast<T *>(data)[i] = randgen.generate<T>();
31 }
32
33 void RandomGenerator::generate(std::vector<Allocation> &inputs)
34 {
35   // generate random data
36   const int seed = 1;
37   nnfw::misc::RandomGenerator randgen{seed, 0.0f, 2.0f};
38   for (uint32_t i = 0; i < inputs.size(); ++i)
39   {
40     nnfw_tensorinfo ti;
41     NNPR_ENSURE_STATUS(nnfw_input_tensorinfo(session_, i, &ti));
42     auto input_size_in_bytes = bufsize_for(&ti);
43     inputs[i].alloc(input_size_in_bytes);
44     switch (ti.dtype)
45     {
46       case NNFW_TYPE_TENSOR_FLOAT32:
47         randomData<float>(randgen, inputs[i].data(), num_elems(&ti));
48         break;
49       case NNFW_TYPE_TENSOR_QUANT8_ASYMM:
50         randomData<uint8_t>(randgen, inputs[i].data(), num_elems(&ti));
51         break;
52       case NNFW_TYPE_TENSOR_BOOL:
53         randomData<bool>(randgen, inputs[i].data(), num_elems(&ti));
54         break;
55       case NNFW_TYPE_TENSOR_UINT8:
56         randomData<uint8_t>(randgen, inputs[i].data(), num_elems(&ti));
57         break;
58       case NNFW_TYPE_TENSOR_INT32:
59         randomData<int32_t>(randgen, inputs[i].data(), num_elems(&ti));
60         break;
61       case NNFW_TYPE_TENSOR_INT64:
62         randomData<int64_t>(randgen, inputs[i].data(), num_elems(&ti));
63         break;
64       case NNFW_TYPE_TENSOR_QUANT16_SYMM_SIGNED:
65         randomData<int16_t>(randgen, inputs[i].data(), num_elems(&ti));
66         break;
67       default:
68         std::cerr << "Not supported input type" << std::endl;
69         std::exit(-1);
70     }
71     NNPR_ENSURE_STATUS(
72       nnfw_set_input(session_, i, ti.dtype, inputs[i].data(), input_size_in_bytes));
73     NNPR_ENSURE_STATUS(nnfw_set_input_layout(session_, i, NNFW_LAYOUT_CHANNELS_LAST));
74   }
75 };
76
77 } // end of namespace onert_train