/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 #ifndef __SOUSCHEF_TENSOR_FILLER_H__
18 #define __SOUSCHEF_TENSOR_FILLER_H__
29 virtual ~TensorFiller() = default;
32 * @brief This will record the tensor by index, if it needs filler option,
33 * such as kernel, bias.
35 void set_tensor_filler(uint32_t tensor_index) { _tensor_filler[tensor_index] = true; }
38 * @brief This will store int32 filler values such as reshape information for the tensor
40 void set_tensor_filler(uint32_t tensor_index, std::vector<int32_t> &expvalues)
42 _tensor_filler_vint32[tensor_index] = expvalues;
45 void set_tensor_filler(uint32_t tensor_index, std::vector<float> &expvalues)
47 _tensor_filler_vfloat[tensor_index] = expvalues;
51 * @brief This will return true if the tensor by index, needs a filler option.
53 bool get_tensor_filler(uint32_t tensor_index)
55 auto it = _tensor_filler.find(tensor_index);
56 if (it != _tensor_filler.end())
64 * @brief This will return true if the tensor by index, needs a int array filler option.
66 bool get_tensor_filler(uint32_t tensor_index, std::vector<int32_t> &expvalues)
68 auto it = _tensor_filler_vint32.find(tensor_index);
69 if (it != _tensor_filler_vint32.end())
71 expvalues = it->second;
77 bool get_tensor_filler(uint32_t tensor_index, std::vector<float> &expvalues)
79 auto it = _tensor_filler_vfloat.find(tensor_index);
80 if (it != _tensor_filler_vfloat.end())
82 expvalues = it->second;
// NOTE(review): no access specifier is visible before these members; the leading
// underscore suggests they are meant to be private — confirm.

// Per-tensor flag, keyed by tensor index: true once set_tensor_filler(index) was called.
std::map<uint32_t, bool> _tensor_filler{};
// Per-tensor int32 filler values (e.g. reshape information), keyed by tensor index.
std::map<uint32_t, std::vector<int32_t>> _tensor_filler_vint32{};
// Per-tensor float filler values, keyed by tensor index.
std::map<uint32_t, std::vector<float>> _tensor_filler_vfloat{};
94 } // namespace souschef
96 #endif // __SOUSCHEF_TENSOR_FILLER_H__