2 * Copyright (c) 2023 Samsung Electronics Co., Ltd All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "mv_private.h"
19 #include "mv_image_classification_open.h"
20 #include "image_classification_adapter.h"
21 #include "machine_learning_exception.h"
22 #include "image_classification_type.h"
// Namespaces used throughout this file. Note that mediavision::common and
// MediaVision::Common are two distinct namespaces (they differ in case).
31 using namespace mediavision::inference;
32 using namespace mediavision::common;
33 using namespace mediavision::machine_learning;
34 using namespace MediaVision::Common;
35 using namespace mediavision::machine_learning::exception;
// Task interface specialised for the image classification input/result pair.
// Every entry point below casts the task stored under "image_classification"
// back to a pointer of this type.
36 using ImageClassificationTask = ITask<image_classification_input_s, image_classification_result_s>;
/**
 * Stores the model, meta and label file paths on the image classification task.
 *
 * The three C strings are copied into an image_classification_input_s, which is
 * then handed to the task registered under "image_classification" via setInput().
 *
 * @param handle      Opaque handle; cast back to a Context pointer.
 * @param model_file  Copied into input.model_file.
 * @param meta_file   Copied into input.meta_file.
 * @param label_file  Copied into input.label_file.
 * @return MEDIA_VISION_ERROR_INVALID_PARAMETER on the "Handle is NULL." path,
 *         MEDIA_VISION_ERROR_NONE at the end of the function.
 *
 * NOTE(review): the embedded line numbers skip (40-41, 44-46, 58-63), so the
 * guard condition, the try opener and the body of the catch handler are not
 * visible in this view.
 */
38 int mv_image_classification_set_model_open(mv_image_classification_h handle, const char *model_file,
39 const char *meta_file, const char *label_file)
42 LOGE("Handle is NULL.");
43 return MEDIA_VISION_ERROR_INVALID_PARAMETER;
// Recover the Context behind the opaque handle, then the task stored in it.
47 auto context = static_cast<Context *>(handle);
48 auto task = static_cast<ImageClassificationTask *>(context->__tasks.at("image_classification"));
50 image_classification_input_s input;
// NOTE(review): no null check on the three path pointers is visible here;
// confirm the caller validates them before they are converted to strings.
52 input.model_file = string(model_file);
53 input.meta_file = string(meta_file);
54 input.label_file = string(label_file);
56 task->setInput(input);
57 } catch (const BaseException &e) {
64 return MEDIA_VISION_ERROR_NONE;
/**
 * Creates an image classification handle.
 *
 * Allocates a Context and an ImageClassificationAdapter task, stores the task
 * in context->__tasks under the key "image_classification", and writes the
 * Context pointer to *out_handle as an mv_image_classification_h.
 *
 * @param out_handle  Receives the newly created handle.
 * @return MEDIA_VISION_ERROR_INVALID_PARAMETER on the path that logs
 *         "Handle can't be created because handle pointer is NULL",
 *         MEDIA_VISION_ERROR_NONE at the end of the function.
 *
 * NOTE(review): the embedded line numbers skip (68-69, 72-73, 76-77, 84-89), so
 * the guard condition, the try opener and the body of the catch handler
 * (including any cleanup of context/task) are not visible in this view.
 */
67 int mv_image_classification_create_open(mv_image_classification_h *out_handle)
70 LOGE("Handle can't be created because handle pointer is NULL");
71 return MEDIA_VISION_ERROR_INVALID_PARAMETER;
// Declared before the allocations and initialised to nullptr.
74 Context *context = nullptr;
75 ImageClassificationTask *task = nullptr;
78 context = new Context();
79 task = new ImageClassificationAdapter<image_classification_input_s, image_classification_result_s>();
// Register the task under the key used by the lookups elsewhere in this file.
81 context->__tasks.insert(make_pair("image_classification", task));
82 *out_handle = static_cast<mv_image_classification_h>(context);
83 } catch (const BaseException &e) {
// The log text names this module so traces are not confused with other tasks.
90 LOGD("image classification handle [%p] has been created", *out_handle);
92 return MEDIA_VISION_ERROR_NONE;
/**
 * Destroys an image classification handle.
 *
 * Casts the handle back to a Context and deletes every task stored in
 * context->__tasks, treating each stored pointer as an ImageClassificationTask.
 *
 * @param handle  Opaque handle; cast back to a Context pointer.
 * @return MEDIA_VISION_ERROR_INVALID_PARAMETER on the "Handle is NULL." path,
 *         MEDIA_VISION_ERROR_NONE at the end of the function.
 *
 * NOTE(review): the embedded line numbers skip (96-97, 100-101, 106-108), so the
 * guard condition and whatever follows the loop (presumably the release of
 * the Context itself -- confirm) are not visible in this view.
 */
95 int mv_image_classification_destroy_open(mv_image_classification_h handle)
98 LOGE("Handle is NULL.");
99 return MEDIA_VISION_ERROR_INVALID_PARAMETER;
102 auto context = static_cast<Context *>(handle);
// Cast each stored pointer back to the task type before deleting it.
104 for (auto &m : context->__tasks)
105 delete static_cast<ImageClassificationTask *>(m.second);
// The log text names this module so traces are not confused with other tasks.
109 LOGD("Image classification handle has been destroyed.");
111 return MEDIA_VISION_ERROR_NONE;
/**
 * Configure step for the image classification task behind the handle.
 *
 * Looks up the task stored under "image_classification" in the Context behind
 * the handle. A BaseException is caught and its what() text is logged.
 *
 * @param handle  Opaque handle; cast back to a Context pointer.
 * @return MEDIA_VISION_ERROR_INVALID_PARAMETER on the "Handle is NULL." path,
 *         MEDIA_VISION_ERROR_INVALID_OPERATION on a path whose condition is not
 *         visible here, MEDIA_VISION_ERROR_NONE at the end of the function.
 *
 * NOTE(review): the embedded line numbers skip (115-118, 121-123, 126,
 * 128-130, 133-137), so the guard condition, the call made on the task, the
 * condition guarding the INVALID_OPERATION return and the rest of the catch
 * handler are not visible in this view.
 */
114 int mv_image_classification_configure_open(mv_image_classification_h handle)
119 LOGE("Handle is NULL.");
120 return MEDIA_VISION_ERROR_INVALID_PARAMETER;
// Recover the Context behind the opaque handle, then the task stored in it.
124 auto context = static_cast<Context *>(handle);
125 auto task = static_cast<ImageClassificationTask *>(context->__tasks.at("image_classification"));
// NOTE(review): the condition for this return is on a line missing from this view.
127 return MEDIA_VISION_ERROR_INVALID_OPERATION;
131 } catch (const BaseException &e) {
132 LOGE("%s", e.what());
138 return MEDIA_VISION_ERROR_NONE;
/**
 * Prepare step for the image classification task behind the handle.
 *
 * Looks up the task stored under "image_classification" in the Context behind
 * the handle. A BaseException is caught and its what() text is logged.
 *
 * @param handle  Opaque handle; cast back to a Context pointer.
 * @return MEDIA_VISION_ERROR_INVALID_PARAMETER on the "Handle is NULL." path,
 *         MEDIA_VISION_ERROR_NONE at the end of the function.
 *
 * NOTE(review): the embedded line numbers skip (142-145, 148-150, 153-154,
 * 157-161), so the guard condition, the call made on the task and the rest of
 * the catch handler are not visible in this view.
 */
141 int mv_image_classification_prepare_open(mv_image_classification_h handle)
146 LOGE("Handle is NULL.");
147 return MEDIA_VISION_ERROR_INVALID_PARAMETER;
// Recover the Context behind the opaque handle, then the task stored in it.
151 auto context = static_cast<Context *>(handle);
152 auto task = static_cast<ImageClassificationTask *>(context->__tasks.at("image_classification"));
155 } catch (const BaseException &e) {
156 LOGE("%s", e.what());
162 return MEDIA_VISION_ERROR_NONE;
/**
 * Feeds a media source to the image classification task behind the handle.
 *
 * Wraps the source in an image_classification_input_s and passes it to the
 * task stored under "image_classification" via setInput(). A BaseException is
 * caught and its what() text is logged.
 *
 * @param handle  Opaque handle; cast back to a Context pointer.
 * @param source  Media source placed into the input structure.
 * @return MEDIA_VISION_ERROR_INVALID_PARAMETER on the "Handle is NULL." path,
 *         MEDIA_VISION_ERROR_NONE at the end of the function.
 *
 * NOTE(review): the embedded line numbers skip (166-169, 172-174, 181,
 * 184-188), so the guard condition, whatever follows setInput() (presumably
 * the call that runs the inference -- confirm) and the rest of the catch
 * handler are not visible in this view.
 */
165 int mv_image_classification_inference_open(mv_image_classification_h handle, mv_source_h source)
170 LOGE("Handle is NULL.");
171 return MEDIA_VISION_ERROR_INVALID_PARAMETER;
// Recover the Context behind the opaque handle, then the task stored in it.
175 auto context = static_cast<Context *>(handle);
176 auto task = static_cast<ImageClassificationTask *>(context->__tasks.at("image_classification"));
// Brace-initialisation: source goes into the structure's first member
// (assuming image_classification_input_s is an aggregate -- TODO confirm).
178 image_classification_input_s input = { source };
180 task->setInput(input);
182 } catch (const BaseException &e) {
183 LOGE("%s", e.what());
189 return MEDIA_VISION_ERROR_NONE;
/**
 * Returns the label held in the task's current output.
 *
 * Fetches the result object from the task stored under "image_classification"
 * via getOutput() and writes result.label.c_str() to *out_label. A
 * BaseException is caught and its what() text is logged.
 *
 * @param handle     Opaque handle; cast back to a Context pointer.
 * @param out_label  Receives a pointer to the label's character data.
 * @return MEDIA_VISION_ERROR_INVALID_PARAMETER on the "Handle is NULL." path,
 *         MEDIA_VISION_ERROR_NONE at the end of the function.
 *
 * NOTE(review): the embedded line numbers skip (193-196, 199-201, 210-214), so
 * the guard condition (and whether it also covers out_label) and the rest of
 * the catch handler are not visible in this view.
 */
192 int mv_image_classification_get_label_open(mv_image_classification_h handle, const char **out_label)
197 LOGE("Handle is NULL.");
198 return MEDIA_VISION_ERROR_INVALID_PARAMETER;
// Recover the Context behind the opaque handle, then the task stored in it.
202 auto context = static_cast<Context *>(handle);
203 auto task = static_cast<ImageClassificationTask *>(context->__tasks.at("image_classification"));
// The result is bound by reference, not copied.
205 image_classification_result_s &result = task->getOutput();
// NOTE(review): the pointer handed out refers to storage inside result.label,
// so it presumably stays valid only while the task keeps that result
// unchanged -- confirm the intended lifetime with callers.
207 *out_label = result.label.c_str();
208 } catch (const BaseException &e) {
209 LOGE("%s", e.what());
215 return MEDIA_VISION_ERROR_NONE;