2 * Copyright (c) 2021 Samsung Electronics Co., Ltd All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "ml_single_manager.h"
18 #include "common/tools.h"
20 using common::PlatformResult;
21 using common::ErrorCode;
26 SingleManager::SingleManager(TensorsInfoManager* tim) : nextId_{0}, tim_{tim} {
30 SingleManager::~SingleManager() {
34 // MachineLearningSingle::openModel()
35 PlatformResult SingleManager::OpenModel(const std::string& modelPath, TensorsInfo* inTensorsInfo,
36 TensorsInfo* outTensorsInfo, ml_nnfw_type_e nnfw_e,
37 ml_nnfw_hw_e hw_e, bool isDynamicMode, int* res_id) {
40 ml_single_h handle = nullptr;
42 ml_tensors_info_h in_info = inTensorsInfo ? inTensorsInfo->Handle() : nullptr;
43 ml_tensors_info_h out_info = outTensorsInfo ? outTensorsInfo->Handle() : nullptr;
45 int ret = ml_single_open(&handle, modelPath.c_str(), in_info, out_info, nnfw_e, hw_e);
46 if (ML_ERROR_NONE != ret) {
47 LoggerE("ml_single_open failed: %d (%s)", ret, get_error_message(ret));
48 return util::ToPlatformResult(ret, "Failed to open model");
51 std::lock_guard<std::mutex> singles_lock(singles_mutex_);
53 singles_[id] = std::make_unique<SingleShot>(id, handle, isDynamicMode);
55 return PlatformResult{};
59 SingleShot* SingleManager::GetSingleShot(int id) {
60 ScopeLogger("id: %d", id);
62 std::lock_guard<std::mutex> singles_lock(singles_mutex_);
63 if (singles_.end() != singles_.find(id)) {
64 return singles_[id].get();
70 PlatformResult SingleManager::GetNativeTensorsInfo(int id, bool get_input_mode, int* res_id) {
73 SingleShot* single = GetSingleShot(id);
75 LoggerE("Could not find singleShot handle");
76 return PlatformResult(ErrorCode::ABORT_ERR);
79 ml_tensors_info_h in_info = nullptr;
80 PlatformResult ret = single->GetTensorsInfo(get_input_mode, &in_info);
85 auto tensor_info = tim_->CreateTensorsInfo(in_info);
86 *res_id = tensor_info->Id();
87 return PlatformResult{};
90 PlatformResult SingleManager::SetNativeInputInfo(int id, TensorsInfo* inTensorsInfo) {
93 SingleShot* single = GetSingleShot(id);
95 LoggerE("Could not find singleShot handle");
96 return PlatformResult(ErrorCode::ABORT_ERR);
99 ml_tensors_info_h in_info = inTensorsInfo ? inTensorsInfo->Handle() : nullptr;
101 PlatformResult ret = single->SetInputInfo(in_info);
106 return PlatformResult{};
109 PlatformResult SingleManager::Invoke(int id, TensorsData* in_tensors_data,
110 TensorsData** out_tensors_data) {
113 SingleShot* single = GetSingleShot(id);
115 LoggerE("Could not find SingleShot handle");
116 return PlatformResult(ErrorCode::ABORT_ERR, "Internal SingleShot error");
119 ml_tensors_info_h out_tensors_info_h = nullptr;
120 ml_tensors_data_h out_tensors_data_h = nullptr;
121 bool should_copy_data = false;
122 PlatformResult result =
123 single->Invoke(in_tensors_data->Handle(), in_tensors_data->GetTensorsInfo()->Handle(),
124 &out_tensors_data_h, &out_tensors_info_h, &should_copy_data);
128 if (should_copy_data) {
129 *out_tensors_data = tim_->CloneNativeTensorWithData(out_tensors_info_h, out_tensors_data_h);
131 *out_tensors_data = tim_->CreateTensorsData(out_tensors_info_h, out_tensors_data_h);
134 if (*out_tensors_data == nullptr) {
135 LoggerE("out_tensors_data creation failed");
136 result = single->CleanUpAfterInvoke();
138 LoggerE("CleanUpAfterInvoke failed");
140 return PlatformResult(ErrorCode::ABORT_ERR, "Internal SingleShot error");
143 return PlatformResult{};
146 PlatformResult SingleManager::GetValue(int id, const std::string& name, std::string& value) {
149 SingleShot* single = GetSingleShot(id);
151 LoggerE("Could not find SingleShot handle");
152 return PlatformResult(ErrorCode::ABORT_ERR, "Internal SingleShot error");
155 return single->GetValue(name, value);
158 PlatformResult SingleManager::SetValue(int id, const std::string& name, const std::string& value) {
161 SingleShot* single = GetSingleShot(id);
163 LoggerE("Could not find SingleShot handle");
164 return PlatformResult(ErrorCode::ABORT_ERR, "Internal SingleShot error");
167 return single->SetValue(name, value);
170 // SingleShot::setTimeout()
171 // SingleShot::close()
174 } // namespace extension