d3501ad45c1f83810adea88f26ab17b41ff435fb
[platform/core/api/mediavision.git] / mv_machine_learning / image_classification / src / mv_image_classification_open.cpp
1 /**
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "mv_private.h"
18 #include "itask.h"
19 #include "mv_image_classification_open.h"
20 #include "image_classification_adapter.h"
21 #include "machine_learning_exception.h"
22 #include "image_classification_type.h"
23 #include "context.h"
24
25 #include <new>
26 #include <unistd.h>
27 #include <string>
28 #include <algorithm>
29
30 using namespace std;
31 using namespace mediavision::inference;
32 using namespace mediavision::common;
33 using namespace mediavision::machine_learning;
34 using namespace MediaVision::Common;
35 using namespace mediavision::machine_learning::exception;
36 using ImageClassificationTask = ITask<image_classification_input_s, image_classification_result_s>;
37
38 int mv_image_classification_set_model_open(mv_image_classification_h handle, const char *model_file,
39                                                                                    const char *meta_file, const char *label_file)
40 {
41         if (!handle) {
42                 LOGE("Handle is NULL.");
43                 return MEDIA_VISION_ERROR_INVALID_PARAMETER;
44         }
45
46         try {
47                 auto context = static_cast<Context *>(handle);
48                 auto task = static_cast<ImageClassificationTask *>(context->__tasks.at("image_classification"));
49
50                 image_classification_input_s input;
51
52                 input.model_file = string(model_file);
53                 input.meta_file = string(meta_file);
54                 input.label_file = string(label_file);
55
56                 task->setInput(input);
57         } catch (const BaseException &e) {
58                 LOGE("%s", e.what());
59                 return e.getError();
60         }
61
62         LOGD("LEAVE");
63
64         return MEDIA_VISION_ERROR_NONE;
65 }
66
67 int mv_image_classification_create_open(mv_image_classification_h *out_handle)
68 {
69         if (!out_handle) {
70                 LOGE("Handle can't be created because handle pointer is NULL");
71                 return MEDIA_VISION_ERROR_INVALID_PARAMETER;
72         }
73
74         Context *context = nullptr;
75         ImageClassificationTask *task = nullptr;
76
77         try {
78                 context = new Context();
79                 task = new ImageClassificationAdapter<image_classification_input_s, image_classification_result_s>();
80
81                 context->__tasks.insert(make_pair("image_classification", task));
82                 *out_handle = static_cast<mv_image_classification_h>(context);
83         } catch (const BaseException &e) {
84                 LOGE("%s", e.what());
85                 delete task;
86                 delete context;
87                 return e.getError();
88         }
89
90         LOGD("object detection 3d handle [%p] has been created", *out_handle);
91
92         return MEDIA_VISION_ERROR_NONE;
93 }
94
95 int mv_image_classification_destroy_open(mv_image_classification_h handle)
96 {
97         if (!handle) {
98                 LOGE("Handle is NULL.");
99                 return MEDIA_VISION_ERROR_INVALID_PARAMETER;
100         }
101
102         auto context = static_cast<Context *>(handle);
103
104         for (auto &m : context->__tasks)
105                 delete static_cast<ImageClassificationTask *>(m.second);
106
107         delete context;
108
109         LOGD("Object detection 3d handle has been destroyed.");
110
111         return MEDIA_VISION_ERROR_NONE;
112 }
113
114 int mv_image_classification_configure_open(mv_image_classification_h handle)
115 {
116         LOGD("ENTER");
117
118         if (!handle) {
119                 LOGE("Handle is NULL.");
120                 return MEDIA_VISION_ERROR_INVALID_PARAMETER;
121         }
122
123         try {
124                 auto context = static_cast<Context *>(handle);
125                 auto task = static_cast<ImageClassificationTask *>(context->__tasks.at("image_classification"));
126                 if (!task) {
127                         return MEDIA_VISION_ERROR_INVALID_OPERATION;
128                 }
129
130                 task->configure();
131         } catch (const BaseException &e) {
132                 LOGE("%s", e.what());
133                 return e.getError();
134         }
135
136         LOGD("LEAVE");
137
138         return MEDIA_VISION_ERROR_NONE;
139 }
140
141 int mv_image_classification_prepare_open(mv_image_classification_h handle)
142 {
143         LOGD("ENTER");
144
145         if (!handle) {
146                 LOGE("Handle is NULL.");
147                 return MEDIA_VISION_ERROR_INVALID_PARAMETER;
148         }
149
150         try {
151                 auto context = static_cast<Context *>(handle);
152                 auto task = static_cast<ImageClassificationTask *>(context->__tasks.at("image_classification"));
153
154                 task->prepare();
155         } catch (const BaseException &e) {
156                 LOGE("%s", e.what());
157                 return e.getError();
158         }
159
160         LOGD("LEAVE");
161
162         return MEDIA_VISION_ERROR_NONE;
163 }
164
165 int mv_image_classification_inference_open(mv_image_classification_h handle, mv_source_h source)
166 {
167         LOGD("ENTER");
168
169         if (!handle) {
170                 LOGE("Handle is NULL.");
171                 return MEDIA_VISION_ERROR_INVALID_PARAMETER;
172         }
173
174         try {
175                 auto context = static_cast<Context *>(handle);
176                 auto task = static_cast<ImageClassificationTask *>(context->__tasks.at("image_classification"));
177
178                 image_classification_input_s input = { source };
179
180                 task->setInput(input);
181                 task->perform();
182         } catch (const BaseException &e) {
183                 LOGE("%s", e.what());
184                 return e.getError();
185         }
186
187         LOGD("LEAVE");
188
189         return MEDIA_VISION_ERROR_NONE;
190 }
191
192 int mv_image_classification_get_label_open(mv_image_classification_h handle, const char **out_label)
193 {
194         LOGD("ENTER");
195
196         if (!handle) {
197                 LOGE("Handle is NULL.");
198                 return MEDIA_VISION_ERROR_INVALID_PARAMETER;
199         }
200
201         try {
202                 auto context = static_cast<Context *>(handle);
203                 auto task = static_cast<ImageClassificationTask *>(context->__tasks.at("image_classification"));
204
205                 image_classification_result_s &result = task->getOutput();
206
207                 *out_label = result.label.c_str();
208         } catch (const BaseException &e) {
209                 LOGE("%s", e.what());
210                 return e.getError();
211         }
212
213         LOGD("LEAVE");
214
215         return MEDIA_VISION_ERROR_NONE;
216 }