*/
MessageType receive();
+ /**
+ * @return is the queue empty
+ */
+ bool isEmpty();
+
private:
typedef std::lock_guard<std::mutex> Lock;
return mess;
}
+template<typename MessageType>
+bool EventQueue<MessageType>::isEmpty()
+{
+ Lock lock(mCommunicationMutex);
+ return mMessages.empty();
+}
+
} // namespace ipc
} // namespace security_containers
EventFD::EventFD()
{
- mFD = ::eventfd(0, EFD_SEMAPHORE);
+ mFD = ::eventfd(0, EFD_SEMAPHORE | EFD_NONBLOCK);
if (mFD == -1) {
LOGE("Error in eventfd: " << std::string(strerror(errno)));
throw IPCException("Error in eventfd: " + std::string(strerror(errno)));
LOGE("Callback threw an error: " << e.what()); \
}
-
-
-
const Processor::MethodID Processor::RETURN_METHOD_ID = std::numeric_limits<MethodID>::max();
Processor::Processor(const PeerCallback& newPeerCallback,
return peerID;
}
-void Processor::removePeer(const PeerID peerID, Status status)
+void Processor::removePeer(const PeerID peerID)
{
- LOGW("Removing naughty peer. ID: " << peerID);
+ std::shared_ptr<std::condition_variable> conditionPtr(new std::condition_variable());
+
+ {
+ Lock lock(mSocketsMutex);
+ RemovePeerRequest request(peerID, conditionPtr);
+ mPeersToDelete.push(std::move(request));
+ }
+
+ mEventQueue.send(Event::DELETE_PEER);
+
+ auto isPeerDeleted = [&peerID, this] {
+ Lock lock(mSocketsMutex);
+ return mSockets.count(peerID) == 0;
+ };
+
+ std::mutex mutex;
+ std::unique_lock<std::mutex> lock(mutex);
+ conditionPtr->wait(lock, isPeerDeleted);
+}
+
+void Processor::removePeerInternal(const PeerID peerID, Status status)
+{
+ LOGW("Removing peer. ID: " << peerID);
{
Lock lock(mSocketsMutex);
mSockets.erase(peerID);
}
}
+ if (mRemovedPeerCallback) {
+ // Notify about the deletion
+ mRemovedPeerCallback(peerID);
+ }
+
resetPolling();
}
+void Processor::cleanCommunication()
+{
+ while (!mEventQueue.isEmpty()) {
+ switch (mEventQueue.receive()) {
+ case Event::FINISH: {
+ LOGD("Event FINISH after FINISH");
+ break;
+ }
+ case Event::CALL: {
+ LOGD("Event CALL after FINISH");
+ Call call = getCall();
+ IGNORE_EXCEPTIONS(call.process(Status::CLOSING, call.data));
+ break;
+ }
+
+ case Event::NEW_PEER: {
+ LOGD("Event NEW_PEER after FINISH");
+ break;
+ }
+
+ case Event::DELETE_PEER: {
+ LOGD("Event DELETE_PEER after FINISH");
+ RemovePeerRequest request;
+ {
+ Lock lock(mSocketsMutex);
+ request = std::move(mPeersToDelete.front());
+ mPeersToDelete.pop();
+ }
+ request.conditionPtr->notify_all();
+ break;
+ }
+ }
+ }
+}
+
void Processor::resetPolling()
{
LOGI("Resetting polling");
continue;
}
}
+
+ cleanCommunication();
}
peersToRemove.push_back(socketIt->first);
}
}
-
}
for (const PeerID peerID : peersToRemove) {
- removePeer(peerID, Status::PEER_DISCONNECTED);
+ removePeerInternal(peerID, Status::PEER_DISCONNECTED);
}
return !peersToRemove.empty();
mReturnCallbacks.erase(messageID);
} catch (const std::out_of_range&) {
LOGW("No return callback for messageID: " << messageID);
- removePeer(peerID, Status::NAUGHTY_PEER);
+ removePeerInternal(peerID, Status::NAUGHTY_PEER);
return true;
}
} catch (const std::exception& e) {
LOGE("Exception during parsing: " << e.what());
IGNORE_EXCEPTIONS(returnCallbacks.process(Status::PARSING_ERROR, data));
- removePeer(peerID, Status::PARSING_ERROR);
+ removePeerInternal(peerID, Status::PARSING_ERROR);
return true;
}
methodCallbacks = mMethodsCallbacks.at(methodID);
} catch (const std::out_of_range&) {
LOGW("No method callback for methodID: " << methodID);
- removePeer(peerID, Status::NAUGHTY_PEER);
+ removePeerInternal(peerID, Status::NAUGHTY_PEER);
return true;
}
data = methodCallbacks->parse(socket.getFD());
} catch (const std::exception& e) {
LOGE("Exception during parsing: " << e.what());
- removePeer(peerID, Status::PARSING_ERROR);
+ removePeerInternal(peerID, Status::PARSING_ERROR);
return true;
}
returnData = methodCallbacks->method(data);
} catch (const std::exception& e) {
LOGE("Exception in method handler: " << e.what());
- removePeer(peerID, Status::NAUGHTY_PEER);
+ removePeerInternal(peerID, Status::NAUGHTY_PEER);
return true;
}
methodCallbacks->serialize(socket.getFD(), returnData);
} catch (const std::exception& e) {
LOGE("Exception during serialization: " << e.what());
- removePeer(peerID, Status::SERIALIZATION_ERROR);
+ removePeerInternal(peerID, Status::SERIALIZATION_ERROR);
return true;
}
}
return true;
}
+
+ case Event::DELETE_PEER: {
+ LOGD("Event DELETE_PEER");
+ RemovePeerRequest request;
+ {
+ Lock lock(mSocketsMutex);
+ request = std::move(mPeersToDelete.front());
+ mPeersToDelete.pop();
+ }
+
+ removePeerInternal(request.peerID, Status::REMOVED_PEER);
+ request.conditionPtr->notify_all();
+ return true;
+ }
}
return false;
return false;
}
- MessageID messageID = getNextMessageID();
-
{
// Set what to do with the return message
Lock lock(mReturnCallbacksMutex);
- if (mReturnCallbacks.count(messageID) != 0) {
- LOGE("There already was a return callback for messageID: " << messageID);
+ if (mReturnCallbacks.count(call.messageID) != 0) {
+ LOGE("There already was a return callback for messageID: " << call.messageID);
}
// move insertion
- mReturnCallbacks[messageID] = std::move(ReturnCallbacks(call.peerID,
- std::move(call.parse),
- std::move(call.process)));
+ mReturnCallbacks[call.messageID] = std::move(ReturnCallbacks(call.peerID,
+ std::move(call.parse),
+ std::move(call.process)));
}
try {
// Send the call with the socket
Socket::Guard guard = socketPtr->getGuard();
socketPtr->write(&call.methodID, sizeof(call.methodID));
- socketPtr->write(&messageID, sizeof(messageID));
+ socketPtr->write(&call.messageID, sizeof(call.messageID));
call.serialize(socketPtr->getFD(), call.data);
} catch (const std::exception& e) {
LOGE("Error during sending a message: " << e.what());
// Inform about the error
- IGNORE_EXCEPTIONS(mReturnCallbacks[messageID].process(Status::SERIALIZATION_ERROR, call.data));
+ IGNORE_EXCEPTIONS(mReturnCallbacks[call.messageID].process(Status::SERIALIZATION_ERROR, call.data));
{
Lock lock(mReturnCallbacksMutex);
- mReturnCallbacks.erase(messageID);
+ mReturnCallbacks.erase(call.messageID);
}
- removePeer(call.peerID, Status::SERIALIZATION_ERROR);
+ removePeerInternal(call.peerID, Status::SERIALIZATION_ERROR);
return true;
}
* - Rest: The data written in a callback. One type per method.ReturnCallbacks
*
* TODO:
-* - error codes passed to async callbacks
* - remove ReturnCallbacks on peer disconnect
* - on sync timeout erase the return callback
* - don't throw timeout if the message is already processed
* - removePeer API function
* - error handling - special message type
* - some mutexes may not be needed
+* - make addPeer synchronous like removePeer
*/
class Processor {
public:
typedef std::function<void(int)> PeerCallback;
typedef unsigned int PeerID;
typedef unsigned int MethodID;
+ typedef unsigned int MessageID;
+
/**
* Method ID. Used to indicate a message with the return value.
PeerID addPeer(const std::shared_ptr<Socket>& socketPtr);
/**
+ * Request removing peer and wait
+ *
+ * @param peerID id of the peer
+ */
+ void removePeer(const PeerID peerID);
+
+ /**
* Saves the callbacks connected to the method id.
* When a message with the given method id is received,
* the data will be passed to the serialization callback through file descriptor.
* @tparam ReceivedDataType data type to receive
*/
template<typename SentDataType, typename ReceivedDataType>
- void callAsync(const MethodID methodID,
- const PeerID peerID,
- const std::shared_ptr<SentDataType>& data,
- const typename ResultHandler<ReceivedDataType>::type& process);
+ MessageID callAsync(const MethodID methodID,
+ const PeerID peerID,
+ const std::shared_ptr<SentDataType>& data,
+ const typename ResultHandler<ReceivedDataType>::type& process);
private:
typedef std::function<void(int fd, std::shared_ptr<void>& data)> SerializeCallback;
typedef std::function<std::shared_ptr<void>(int fd)> ParseCallback;
typedef std::lock_guard<std::mutex> Lock;
- typedef unsigned int MessageID;
struct Call {
Call(const Call& other) = delete;
SerializeCallback serialize;
ParseCallback parse;
ResultHandler<void>::type process;
+ MessageID messageID;
};
struct MethodHandlers {
std::shared_ptr<Socket> socketPtr;
};
+ struct RemovePeerRequest {
+ RemovePeerRequest(const RemovePeerRequest& other) = delete;
+ RemovePeerRequest& operator=(const RemovePeerRequest&) = delete;
+ RemovePeerRequest() = default;
+ RemovePeerRequest(RemovePeerRequest&&) = default;
+ RemovePeerRequest& operator=(RemovePeerRequest &&) = default;
+
+ RemovePeerRequest(const PeerID peerID,
+ const std::shared_ptr<std::condition_variable>& conditionPtr)
+ : peerID(peerID), conditionPtr(conditionPtr) {}
+
+ PeerID peerID;
+ std::shared_ptr<std::condition_variable> conditionPtr;
+ };
+
enum class Event : int {
FINISH, // Shutdown request
CALL, // New method call in the queue
- NEW_PEER // New peer in the queue
+ NEW_PEER, // New peer in the queue
+ DELETE_PEER // Delete peer
};
EventQueue<Event> mEventQueue;
std::mutex mSocketsMutex;
std::unordered_map<PeerID, std::shared_ptr<Socket> > mSockets;
std::queue<SocketInfo> mNewSockets;
+ std::queue<RemovePeerRequest> mPeersToDelete;
// Mutex for modifying the map with return callbacks
std::mutex mReturnCallbacksMutex;
MessageID getNextMessageID();
PeerID getNextPeerID();
Call getCall();
- void removePeer(const PeerID peerID, Status status);
-
+ void removePeerInternal(const PeerID peerID, Status status);
+ void cleanCommunication();
};
template<typename SentDataType, typename ReceivedDataType>
}
template<typename SentDataType, typename ReceivedDataType>
-void Processor::callAsync(const MethodID methodID,
- const PeerID peerID,
- const std::shared_ptr<SentDataType>& data,
- const typename ResultHandler<ReceivedDataType>::type& process)
+Processor::MessageID Processor::callAsync(const MethodID methodID,
+ const PeerID peerID,
+ const std::shared_ptr<SentDataType>& data,
+ const typename ResultHandler<ReceivedDataType>::type& process)
{
static_assert(config::isVisitable<SentDataType>::value,
"Use the libConfig library");
call.peerID = peerID;
call.methodID = methodID;
call.data = data;
+ call.messageID = getNextMessageID();
call.parse = [](const int fd)->std::shared_ptr<void> {
std::shared_ptr<ReceivedDataType> data(new ReceivedDataType());
}
mEventQueue.send(Event::CALL);
+
+ return call.messageID;
}
std::shared_ptr<ReceivedDataType> result;
- std::mutex mtx;
- std::unique_lock<std::mutex> lck(mtx);
+ std::mutex mutex;
std::condition_variable cv;
Status returnStatus = ipc::Status::UNDEFINED;
- auto process = [&result, &cv, &returnStatus](Status status, std::shared_ptr<ReceivedDataType> returnedData) {
+ auto process = [&result, &mutex, &cv, &returnStatus](Status status, std::shared_ptr<ReceivedDataType> returnedData) {
+ std::unique_lock<std::mutex> lock(mutex);
returnStatus = status;
result = returnedData;
- cv.notify_one();
+ cv.notify_all();
};
- callAsync<SentDataType,
- ReceivedDataType>(methodID,
- peerID,
- data,
- process);
+ MessageID messageID = callAsync<SentDataType, ReceivedDataType>(methodID,
+ peerID,
+ data,
+ process);
auto isResultInitialized = [&returnStatus]() {
return returnStatus != ipc::Status::UNDEFINED;
};
- if (!cv.wait_for(lck, std::chrono::milliseconds(timeoutMS), isResultInitialized)) {
- LOGE("Function call timeout; methodID: " << methodID);
- throw IPCTimeoutException("Function call timeout; methodID: " + std::to_string(methodID));
+ std::unique_lock<std::mutex> lock(mutex);
+ if (!cv.wait_for(lock, std::chrono::milliseconds(timeoutMS), isResultInitialized)) {
+ bool isTimeout = false;
+ {
+ Lock lock(mReturnCallbacksMutex);
+ if (1 == mReturnCallbacks.erase(messageID)) {
+ isTimeout = true;
+ }
+ }
+ if (isTimeout) {
+ removePeer(peerID);
+ LOGE("Function call timeout; methodID: " << methodID);
+ throw IPCTimeoutException("Function call timeout; methodID: " + std::to_string(methodID));
+ } else {
+ // Timeout started during the return value processing, so wait for it to finish
+ cv.wait(lock, isResultInitialized);
+ }
}
throwOnError(returnStatus);
{
int n = ::sd_listen_fds(-1 /*Block further calls to sd_listen_fds*/);
if (n < 0) {
- LOGE("sd_listen_fds fails with errno: " + n);
- throw IPCException("sd_listen_fds fails with errno: " + n);
+ LOGE("sd_listen_fds fails with errno: " << n);
+ throw IPCException("sd_listen_fds fails with errno: " + std::to_string(n));
}
for (int fd = SD_LISTEN_FDS_START;
throw IPCException("Error in connect: " + std::string(strerror(errno)));
}
+ // Nonblock socket
+ int flags = fcntl(fd, F_GETFL, 0);
+ if (-1 == fcntl(fd, F_SETFL, flags | O_NONBLOCK)) {
+ ::close(fd);
+ LOGE("Error in fcntl: " + std::string(strerror(errno)));
+ throw IPCException("Error in fcntl: " + std::string(strerror(errno)));
+ }
+
return Socket(fd);
}
#include <cerrno>
#include <cstring>
+#include <chrono>
#include <unistd.h>
-
+#include <poll.h>
#include <sys/resource.h>
#include <boost/filesystem.hpp>
namespace fs = boost::filesystem;
+namespace chr = std::chrono;
namespace security_containers {
namespace ipc {
+namespace {
+
+void waitForEvent(int fd,
+ short event,
+ const chr::high_resolution_clock::time_point deadline)
+{
+ // Wait for the rest of the data
+ struct pollfd fds[1];
+ fds[0].fd = fd;
+ fds[0].events = event | POLLHUP;
+
+ for (;;) {
+ chr::milliseconds timeoutMS = chr::duration_cast<chr::milliseconds>(deadline - chr::high_resolution_clock::now());
+ if (timeoutMS.count() < 0) {
+ LOGE("Timeout in read");
+ throw IPCException("Timeout in read");
+ }
+
+ int ret = ::poll(fds, 1 /*fds size*/, timeoutMS.count());
+
+ if (ret == -1) {
+ if (errno == EINTR) {
+ continue;
+ }
+ LOGE("Error in poll: " + std::string(strerror(errno)));
+ throw IPCException("Error in poll: " + std::string(strerror(errno)));
+ }
+
+ if (ret == 0) {
+ LOGE("Timeout in read");
+ throw IPCException("Timeout in read");
+ }
+
+ if (fds[0].revents & POLLHUP) {
+ LOGE("Peer disconnected");
+ throw IPCException("Peer disconnected");
+ }
+
+ // Here Comes the Sun
+ break;
+ }
+}
+
+} // namespace
+
void close(int fd)
{
if (fd < 0) {
}
}
-void write(int fd, const void* bufferPtr, const size_t size)
+void write(int fd, const void* bufferPtr, const size_t size, int timeoutMS)
{
+ chr::high_resolution_clock::time_point deadline = chr::high_resolution_clock::now() +
+ chr::milliseconds(timeoutMS);
+
size_t nTotal = 0;
- int n;
+ for (;;) {
+ int n = ::write(fd,
+ reinterpret_cast<const char*>(bufferPtr) + nTotal,
+ size - nTotal);
+ if (n > 0) {
+ nTotal += n;
+ } else if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) {
+ // Neglected errors
+ LOGD("Retrying write");
+ } else {
+ LOGE("Error during reading: " + std::string(strerror(errno)));
+ throw IPCException("Error during reading: " + std::string(strerror(errno)));
+ }
- do {
- n = ::write(fd,
- reinterpret_cast<const char*>(bufferPtr) + nTotal,
- size - nTotal);
- if (n < 0) {
- if (errno == EINTR) {
- LOGD("Write interrupted by a signal, retrying");
- continue;
- }
- LOGE("Error during writing: " + std::string(strerror(errno)));
- throw IPCException("Error during witting: " + std::string(strerror(errno)));
+ if (nTotal >= size) {
+ // All data is written, break loop
+ break;
+ } else {
+ waitForEvent(fd, POLLOUT, deadline);
}
- nTotal += n;
- } while (nTotal < size);
+ }
}
-void read(int fd, void* bufferPtr, const size_t size)
+void read(int fd, void* bufferPtr, const size_t size, int timeoutMS)
{
- size_t nTotal = 0;
- int n;
+ chr::high_resolution_clock::time_point deadline = chr::high_resolution_clock::now() +
+ chr::milliseconds(timeoutMS);
- do {
- n = ::read(fd,
- reinterpret_cast<char*>(bufferPtr) + nTotal,
- size - nTotal);
- if (n < 0) {
- if (errno == EINTR) {
- LOGD("Read interrupted by a signal, retrying");
- continue;
- }
+ size_t nTotal = 0;
+ for (;;) {
+ int n = ::read(fd,
+ reinterpret_cast<char*>(bufferPtr) + nTotal,
+ size - nTotal);
+ if (n > 0) {
+ nTotal += n;
+ } else if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) {
+ // Neglected errors
+ LOGD("Retrying read");
+ } else {
LOGE("Error during reading: " + std::string(strerror(errno)));
throw IPCException("Error during reading: " + std::string(strerror(errno)));
}
- nTotal += n;
- } while (nTotal < size);
+
+ if (nTotal >= size) {
+ // All data is read, break loop
+ break;
+ } else {
+ waitForEvent(fd, POLLIN, deadline);
+ }
+ }
}
unsigned int getMaxFDNumber()
* @param fd file descriptor
* @param bufferPtr pointer to the data buffer
* @param size size of data to write
+ * @param timeoutMS timeout in milliseconds
*/
-void write(int fd, const void* bufferPtr, const size_t size);
+void write(int fd, const void* bufferPtr, const size_t size, int timeoutMS = 500);
/**
* Read from a file descriptor, throw on error.
* @param fd file descriptor
* @param bufferPtr pointer to the data buffer
* @param size size of the data to read
+ * @param timeoutMS timeout in milliseconds
*/
-void read(int fd, void* bufferPtr, const size_t size);
+void read(int fd, void* bufferPtr, const size_t size, int timeoutMS = 500);
/**
* @return the max number of file descriptors for this process.
case Status::SERIALIZATION_ERROR: return "Exception during writing/serializing data to the socket";
case Status::PEER_DISCONNECTED: return "No such peer. Might got disconnected.";
case Status::NAUGHTY_PEER: return "Peer performed a forbidden action.";
+ case Status::REMOVED_PEER: return "Removing peer";
+ case Status::CLOSING: return "Closing IPC";
case Status::UNDEFINED: return "Undefined state";
default: return "Unknown status";
}
case Status::SERIALIZATION_ERROR: throw IPCSerializationException(message);
case Status::PEER_DISCONNECTED: throw IPCPeerDisconnectedException(message);
case Status::NAUGHTY_PEER: throw IPCNaughtyPeerException(message);
+ case Status::REMOVED_PEER: throw IPCException(message);
+ case Status::CLOSING: throw IPCException(message);
case Status::UNDEFINED: throw IPCException(message);
default: return throw IPCException(message);
}
SERIALIZATION_ERROR,
PEER_DISCONNECTED,
NAUGHTY_PEER,
+ REMOVED_PEER,
+ CLOSING,
UNDEFINED
};
)
};
+struct LongSendData {
+ LongSendData(int i = 0, int waitTime = 1000): mSendData(i), mWaitTime(waitTime), intVal(i) {}
+
+ template<typename Visitor>
+ void accept(Visitor visitor)
+ {
+ std::this_thread::sleep_for(std::chrono::milliseconds(mWaitTime));
+ mSendData.accept(visitor);
+ }
+ template<typename Visitor>
+ void accept(Visitor visitor) const
+ {
+ std::this_thread::sleep_for(std::chrono::milliseconds(mWaitTime));
+ mSendData.accept(visitor);
+ }
+
+ SendData mSendData;
+ int mWaitTime;
+ int intVal;
+};
+
struct EmptyData {
CONFIG_REGISTER_EMPTY
};
}
+BOOST_AUTO_TEST_CASE(ReadTimeoutTest)
+{
+ Service s(socketPath);
+ auto longEchoCallback = [](std::shared_ptr<SendData>& data) {
+ return std::shared_ptr<LongSendData>(new LongSendData(data->intVal));
+ };
+ s.addMethodHandler<LongSendData, SendData>(1, longEchoCallback);
+ s.start();
+
+ Client c(socketPath);
+ c.start();
+
+ // Test timeout on read
+ std::shared_ptr<SendData> sentData(new SendData(334));
+ BOOST_CHECK_THROW((c.callSync<SendData, SendData>(1, sentData, 100)), IPCException);
+}
+
+
+BOOST_AUTO_TEST_CASE(WriteTimeoutTest)
+{
+ Service s(socketPath);
+ s.addMethodHandler<SendData, SendData>(1, echoCallback);
+ s.start();
+
+ Client c(socketPath);
+ c.start();
+
+ // Test echo with a minimal timeout
+ std::shared_ptr<LongSendData> sentDataA(new LongSendData(34, 10 /*ms*/));
+ std::shared_ptr<SendData> recvData = c.callSync<LongSendData, SendData>(1, sentDataA, 100);
+ BOOST_CHECK_EQUAL(recvData->intVal, sentDataA->intVal);
+
+ // Test timeout on write
+ std::shared_ptr<LongSendData> sentDataB(new LongSendData(34, 1000 /*ms*/));
+ BOOST_CHECK_THROW((c.callSync<LongSendData, SendData>(1, sentDataB, 100)), IPCTimeoutException);
+}
+
+
+
// BOOST_AUTO_TEST_CASE(ConnectionLimitTest)
// {
// unsigned oldLimit = ipc::getMaxFDNumber();