fix typo from clear_session to clean_session
[platform/core/ml/aitt.git] / src / AITTImpl.cc
1 /*
2  * Copyright (c) 2021-2022 Samsung Electronics Co., Ltd All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "AITTImpl.h"
17
18 #include <algorithm>
19 #include <cerrno>
20 #include <cstring>
21 #include <functional>
22 #include <memory>
23 #include <stdexcept>
24
25 #include "MosquittoMQ.h"
26 #include "aitt_internal.h"
27
28 namespace aitt {
29
30 AITT::Impl::Impl(AITT &parent, const std::string &id, const std::string &my_ip,
31       const AittOption &option)
32       : public_api(parent),
33         discovery(id),
34         modules(my_ip, discovery),
35         id_(id),
36         mqtt_broker_port_(0),
37         reply_id(0)
38 {
39     if (option.GetUseCustomMqttBroker()) {
40         mq = modules.NewCustomMQ(id, option);
41         AittOption discovery_option = option;
42         discovery_option.SetCleanSession(false);
43         discovery.SetMQ(modules.NewCustomMQ(id + 'd', option));
44     } else {
45         mq = std::unique_ptr<MQ>(new MosquittoMQ(id, option.GetCleanSession()));
46         discovery.SetMQ(std::unique_ptr<MQ>(new MosquittoMQ(id + 'd', false)));
47     }
48     aittThread = std::thread(&AITT::Impl::ThreadMain, this);
49 }
50
51 AITT::Impl::~Impl(void)
52 {
53     if (mqtt_broker_ip_.empty() == false) {
54         try {
55             Disconnect();
56         } catch (std::exception &e) {
57             ERR("Disconnect() Fail(%s)", e.what());
58         }
59     }
60     while (main_loop.Quit() == false) {
61         // wait when called before the thread has completely created.
62         usleep(1000);  // 1millisecond
63     }
64
65     if (aittThread.joinable())
66         aittThread.join();
67
68     discovery.SetMQ(nullptr);
69     mq = nullptr;
70 }
71
72 void AITT::Impl::ThreadMain(void)
73 {
74     pthread_setname_np(pthread_self(), "AITTWorkerLoop");
75     main_loop.Run();
76 }
77
78 void AITT::Impl::SetWillInfo(const std::string &topic, const void *data, const int datalen,
79       AittQoS qos, bool retain)
80 {
81     mq->SetWillInfo(topic, data, datalen, qos, retain);
82 }
83
84 void AITT::Impl::SetConnectionCallback(ConnectionCallback cb, void *user_data)
85 {
86     if (cb) {
87         mq->SetConnectionCallback([&, cb, user_data](int status) {
88             auto idler_cb = std::bind(&Impl::ConnectionCB, this, cb, user_data, status,
89                   std::placeholders::_1, std::placeholders::_2, std::placeholders::_3);
90             MainLoopHandler::AddIdle(&main_loop, idler_cb, nullptr);
91         });
92     } else {
93         mq->SetConnectionCallback(nullptr);
94     }
95 }
96
97 int AITT::Impl::ConnectionCB(ConnectionCallback cb, void *user_data, int status,
98       MainLoopHandler::MainLoopResult result, int fd, MainLoopHandler::MainLoopData *loop_data)
99 {
100     RETV_IF(cb == nullptr, AITT_LOOP_EVENT_REMOVE);
101
102     cb(public_api, status, user_data);
103
104     return AITT_LOOP_EVENT_REMOVE;
105 }
106
107 void AITT::Impl::Connect(const std::string &host, int port, const std::string &username,
108       const std::string &password)
109 {
110     discovery.Start(host, port, username, password);
111     mq->Connect(host, port, username, password);
112
113     mqtt_broker_ip_ = host;
114     mqtt_broker_port_ = port;
115 }
116
117 void AITT::Impl::Disconnect(void)
118 {
119     UnsubscribeAll();
120
121     for (auto stream : in_use_streams)
122         delete stream;
123     in_use_streams.clear();
124
125     mqtt_broker_ip_.clear();
126     mqtt_broker_port_ = -1;
127
128     discovery.Stop();
129     mq->Disconnect();
130 }
131
132 void AITT::Impl::UnsubscribeAll()
133 {
134     std::unique_lock<std::mutex> lock(subscribed_list_mutex_);
135
136     DBG("Subscribed list %zu", subscribed_list.size());
137
138     for (auto subscribe_info : subscribed_list) {
139         switch (subscribe_info->first) {
140         case AITT_TYPE_MQTT:
141             mq->Unsubscribe(subscribe_info->second);
142             break;
143         case AITT_TYPE_TCP:
144         case AITT_TYPE_TCP_SECURE:
145             modules.Get(subscribe_info->first).Unsubscribe(subscribe_info->second);
146             break;
147
148         default:
149             ERR("Unknown AittProtocol(%d)", subscribe_info->first);
150             break;
151         }
152
153         delete subscribe_info;
154     }
155     subscribed_list.clear();
156 }
157
158 void AITT::Impl::ConfigureTransportModule(const std::string &key, const std::string &value,
159       AittProtocol protocols)
160 {
161 }
162
163 void AITT::Impl::Publish(const std::string &topic, const void *data, const int datalen,
164       AittProtocol protocols, AittQoS qos, bool retain)
165 {
166     if ((protocols & AITT_TYPE_MQTT) == AITT_TYPE_MQTT)
167         mq->Publish(topic, data, datalen, qos, retain);
168
169     if ((protocols & AITT_TYPE_TCP) == AITT_TYPE_TCP)
170         modules.Get(AITT_TYPE_TCP).Publish(topic, data, datalen, qos, retain);
171
172     if ((protocols & AITT_TYPE_TCP_SECURE) == AITT_TYPE_TCP_SECURE)
173         modules.Get(AITT_TYPE_TCP_SECURE).Publish(topic, data, datalen, qos, retain);
174 }
175
176 AittSubscribeID AITT::Impl::Subscribe(const std::string &topic, const AITT::SubscribeCallback &cb,
177       void *user_data, AittProtocol protocol, AittQoS qos)
178 {
179     SubscribeInfo *info = new SubscribeInfo();
180     info->first = protocol;
181
182     void *subscribe_handle;
183     switch (protocol) {
184     case AITT_TYPE_MQTT:
185         subscribe_handle = SubscribeMQ(info, &main_loop, topic, cb, user_data, qos);
186         break;
187     case AITT_TYPE_TCP:
188     case AITT_TYPE_TCP_SECURE:
189         subscribe_handle = SubscribeTCP(info, topic, cb, user_data, qos);
190         break;
191     default:
192         ERR("Unknown AittProtocol(%d)", protocol);
193         delete info;
194         throw std::runtime_error("Unknown AittProtocol");
195     }
196     info->second = subscribe_handle;
197     {
198         std::unique_lock<std::mutex> lock(subscribed_list_mutex_);
199         subscribed_list.push_back(info);
200     }
201
202     INFO("Subscribe topic(%s) : %p", topic.c_str(), info);
203     return reinterpret_cast<AittSubscribeID>(info);
204 }
205
206 AittSubscribeID AITT::Impl::SubscribeMQ(SubscribeInfo *handle, MainLoopHandler *loop_handle,
207       const std::string &topic, const SubscribeCallback &cb, void *user_data, AittQoS qos)
208 {
209     return mq->Subscribe(
210           topic,
211           [this, handle, loop_handle, cb](MSG *msg, const std::string &topic, const void *data,
212                 const int datalen, void *mq_user_data) {
213               void *delivery = malloc(datalen);
214               if (delivery)
215                   memcpy(delivery, data, datalen);
216
217               msg->SetID(handle);
218               auto idler_cb =
219                     std::bind(&Impl::DetachedCB, this, cb, *msg, delivery, datalen, mq_user_data,
220                           std::placeholders::_1, std::placeholders::_2, std::placeholders::_3);
221               MainLoopHandler::AddIdle(loop_handle, idler_cb, nullptr);
222           },
223           user_data, qos);
224 }
225
226 int AITT::Impl::DetachedCB(SubscribeCallback cb, MSG msg, void *data, const int datalen,
227       void *user_data, MainLoopHandler::MainLoopResult result, int fd,
228       MainLoopHandler::MainLoopData *loop_data)
229 {
230     RETV_IF(cb == nullptr, AITT_LOOP_EVENT_REMOVE);
231
232     cb(&msg, data, datalen, user_data);
233
234     free(data);
235     return AITT_LOOP_EVENT_REMOVE;
236 }
237
238 void *AITT::Impl::Unsubscribe(AittSubscribeID subscribe_id)
239 {
240     INFO("subscribe_id : %p", subscribe_id);
241     SubscribeInfo *info = reinterpret_cast<SubscribeInfo *>(subscribe_id);
242
243     std::unique_lock<std::mutex> lock(subscribed_list_mutex_);
244
245     auto it = std::find(subscribed_list.begin(), subscribed_list.end(), info);
246     if (it == subscribed_list.end()) {
247         ERR("Unknown subscribe_id(%p)", subscribe_id);
248         throw AittException(AittException::NO_DATA_ERR);
249     }
250
251     void *user_data = nullptr;
252     SubscribeInfo *found_info = *it;
253     switch (found_info->first) {
254     case AITT_TYPE_MQTT:
255         user_data = mq->Unsubscribe(found_info->second);
256         break;
257     case AITT_TYPE_TCP:
258     case AITT_TYPE_TCP_SECURE:
259         user_data = modules.Get(found_info->first).Unsubscribe(found_info->second);
260         break;
261
262     default:
263         ERR("Unknown AittProtocol(%d)", found_info->first);
264         break;
265     }
266
267     subscribed_list.erase(it);
268     delete info;
269
270     return user_data;
271 }
272
273 int AITT::Impl::PublishWithReply(const std::string &topic, const void *data, const int datalen,
274       AittProtocol protocol, AittQoS qos, bool retain, const SubscribeCallback &cb, void *user_data,
275       const std::string &correlation)
276 {
277     std::string replyTopic = topic + RESPONSE_POSTFIX + std::to_string(reply_id++);
278
279     if (protocol != AITT_TYPE_MQTT)
280         return -1;  // not yet support
281
282     Subscribe(
283           replyTopic,
284           [this, cb](MSG *sub_msg, const void *sub_data, const int sub_datalen, void *sub_cbdata) {
285               if (sub_msg->IsEndSequence()) {
286                   try {
287                       Unsubscribe(sub_msg->GetID());
288                   } catch (AittException &e) {
289                       ERR("Unsubscribe() Fail(%s)", e.what());
290                   }
291               }
292               cb(sub_msg, sub_data, sub_datalen, sub_cbdata);
293           },
294           user_data, protocol, qos);
295
296     mq->PublishWithReply(topic, data, datalen, qos, false, replyTopic, correlation);
297     return 0;
298 }
299
300 int AITT::Impl::PublishWithReplySync(const std::string &topic, const void *data, const int datalen,
301       AittProtocol protocol, AittQoS qos, bool retain, const SubscribeCallback &cb, void *user_data,
302       const std::string &correlation, int timeout_ms)
303 {
304     std::string replyTopic = topic + RESPONSE_POSTFIX + std::to_string(reply_id++);
305
306     if (protocol != AITT_TYPE_MQTT)
307         return -1;  // not yet support
308
309     SubscribeInfo *info = new SubscribeInfo();
310     info->first = protocol;
311
312     void *subscribe_handle;
313     MainLoopHandler sync_loop;
314     unsigned int timeout_id = 0;
315     bool is_timeout = false;
316
317     subscribe_handle = SubscribeMQ(
318           info, &sync_loop, replyTopic,
319           [&](MSG *sub_msg, const void *sub_data, const int sub_datalen, void *sub_cbdata) {
320               if (sub_msg->IsEndSequence()) {
321                   try {
322                       Unsubscribe(sub_msg->GetID());
323                   } catch (AittException &e) {
324                       ERR("Unsubscribe() Fail(%s)", e.what());
325                   }
326                   sync_loop.Quit();
327               } else {
328                   if (timeout_id) {
329                       sync_loop.RemoveTimeout(timeout_id);
330                       HandleTimeout(timeout_ms, timeout_id, sync_loop, is_timeout);
331                   }
332               }
333               cb(sub_msg, sub_data, sub_datalen, sub_cbdata);
334           },
335           user_data, qos);
336     info->second = subscribe_handle;
337     {
338         std::unique_lock<std::mutex> lock(subscribed_list_mutex_);
339         subscribed_list.push_back(info);
340     }
341
342     mq->PublishWithReply(topic, data, datalen, qos, false, replyTopic, correlation);
343     if (timeout_ms)
344         HandleTimeout(timeout_ms, timeout_id, sync_loop, is_timeout);
345
346     sync_loop.Run();
347
348     if (is_timeout)
349         return AITT_ERROR_TIMED_OUT;
350     return 0;
351 }
352
353 void AITT::Impl::HandleTimeout(int timeout_ms, unsigned int &timeout_id,
354       aitt::MainLoopHandler &sync_loop, bool &is_timeout)
355 {
356     timeout_id = sync_loop.AddTimeout(
357           timeout_ms,
358           [&, timeout_ms](MainLoopHandler::MainLoopResult result, int fd,
359                 MainLoopHandler::MainLoopData *data) -> int {
360               ERR("PublishWithReplySync() timeout(%d)", timeout_ms);
361               sync_loop.Quit();
362               is_timeout = true;
363               return AITT_LOOP_EVENT_REMOVE;
364           },
365           nullptr);
366 }
367
368 void AITT::Impl::SendReply(MSG *msg, const void *data, const int datalen, bool end)
369 {
370     RET_IF(msg == nullptr);
371
372     if ((msg->GetProtocols() & AITT_TYPE_MQTT) != AITT_TYPE_MQTT)
373         return;  // not yet support
374
375     if (end == false || msg->GetSequence())
376         msg->IncreaseSequence();
377     msg->SetEndSequence(end);
378
379     mq->SendReply(msg, data, datalen, AITT_QOS_AT_MOST_ONCE, false);
380 }
381
382 void *AITT::Impl::SubscribeTCP(SubscribeInfo *handle, const std::string &topic,
383       const SubscribeCallback &cb, void *user_data, AittQoS qos)
384 {
385     return modules.Get(handle->first)
386           .Subscribe(
387                 topic,
388                 [handle, cb](const std::string &topic, const void *data, const int datalen,
389                       void *user_data, const std::string &correlation) -> void {
390                     MSG msg;
391                     msg.SetID(handle);
392                     msg.SetTopic(topic);
393                     msg.SetCorrelation(correlation);
394                     msg.SetProtocols(handle->first);
395
396                     return cb(&msg, data, datalen, user_data);
397                 },
398                 user_data, qos);
399 }
400
401 AittStream *AITT::Impl::CreateStream(AittStreamProtocol type, const std::string &topic,
402       AittStreamRole role)
403 {
404     AittStreamModule *stream = nullptr;
405     try {
406         stream = modules.NewStreamModule(type, topic, role);
407         in_use_streams.push_back(stream);
408     } catch (std::exception &e) {
409         ERR("StreamHandler() Fail(%s)", e.what());
410     }
411     discovery.Restart();
412
413     return stream;
414 }
415
416 void AITT::Impl::DestroyStream(AittStream *aitt_stream)
417 {
418     auto it = std::find(in_use_streams.begin(), in_use_streams.end(), aitt_stream);
419     if (it == in_use_streams.end()) {
420         ERR("Unknown Stream(%p)", aitt_stream);
421         return;
422     }
423     in_use_streams.erase(it);
424     delete aitt_stream;
425 }
426
427 }  // namespace aitt