/*
 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "randomgen.h"

#include <cstdlib>
#include <iostream>

#include "nnfw_util.h"
#include "misc/RandomGenerator.h"
27 template <class T> void randomData(nnfw::misc::RandomGenerator &randgen, void *data, uint64_t size)
29 for (uint64_t i = 0; i < size; i++)
30 reinterpret_cast<T *>(data)[i] = randgen.generate<T>();
33 void RandomGenerator::generate(std::vector<Allocation> &inputs)
35 // generate random data
37 nnfw::misc::RandomGenerator randgen{seed, 0.0f, 2.0f};
38 for (uint32_t i = 0; i < inputs.size(); ++i)
41 NNPR_ENSURE_STATUS(nnfw_input_tensorinfo(session_, i, &ti));
42 auto input_size_in_bytes = bufsize_for(&ti);
43 inputs[i].alloc(input_size_in_bytes);
46 case NNFW_TYPE_TENSOR_FLOAT32:
47 randomData<float>(randgen, inputs[i].data(), num_elems(&ti));
49 case NNFW_TYPE_TENSOR_QUANT8_ASYMM:
50 randomData<uint8_t>(randgen, inputs[i].data(), num_elems(&ti));
52 case NNFW_TYPE_TENSOR_BOOL:
53 randomData<bool>(randgen, inputs[i].data(), num_elems(&ti));
55 case NNFW_TYPE_TENSOR_UINT8:
56 randomData<uint8_t>(randgen, inputs[i].data(), num_elems(&ti));
58 case NNFW_TYPE_TENSOR_INT32:
59 randomData<int32_t>(randgen, inputs[i].data(), num_elems(&ti));
61 case NNFW_TYPE_TENSOR_INT64:
62 randomData<int64_t>(randgen, inputs[i].data(), num_elems(&ti));
64 case NNFW_TYPE_TENSOR_QUANT16_SYMM_SIGNED:
65 randomData<int16_t>(randgen, inputs[i].data(), num_elems(&ti));
68 std::cerr << "Not supported input type" << std::endl;
72 nnfw_set_input(session_, i, ti.dtype, inputs[i].data(), input_size_in_bytes));
73 NNPR_ENSURE_STATUS(nnfw_set_input_layout(session_, i, NNFW_LAYOUT_CHANNELS_LAST));
77 } // end of namespace onert_train