From: Konrad Lipinski Date: Tue, 12 Jul 2022 09:01:37 +0000 (+0200) Subject: Refactor MessageBuffer and dependencies X-Git-Tag: submit/tizen/20220803.102654~10 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=refs%2Fchanges%2F03%2F277803%2F9;p=platform%2Fcore%2Fsecurity%2Fsecurity-manager.git Refactor MessageBuffer and dependencies Security manager's protocol assumes there's at most one message in flight per connection at any given time. The MessageBuffer class can hold one such message in various stages of completion, assembled via either input or serialization and disposed of via either output or deserialization. This conceptual interface can be satisfied in a much simpler way than what's currently present. All that is require for a MessageBuffer is a single contiguous memory block and a little management on the side (the block's size, the message size, offset into the block). Since the protocol has the payload size stored as a size_t header prior to a message's payload, there's no need to even store it separately - it can be stored before the payload, just as in the protocol. Implications: * less memory copying/shuffling * read the full message directly into a buffer in binary form * deserialize directly from that buffer (no Pop(), no copies) * reuse the buffer space for serialization of the return message * output the return message into the socket without copying * socket manager now assembles full messages before handing them to the service, at no performance hit * one MessageEvent per socket instead of Accept/Close/Read/Write events * no need for the service to maintain connection state - it now operates on a per-message basis Change-Id: I45f6009ce09ae2f852cfee86a32426389bcf7a30 --- diff --git a/src/client/include/client-request.h b/src/client/include/client-request.h index 738fa907..a83e9cd7 100644 --- a/src/client/include/client-request.h +++ b/src/client/include/client-request.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016-2020 Samsung Electronics Co., Ltd. All rights reserved. + * Copyright (c) 2016-2022 Samsung Electronics Co., Ltd. All rights reserved. * * This file is licensed under the terms of MIT License or the Apache License * Version 2.0 of your choice. See the LICENSE.MIT file for MIT license details. @@ -42,7 +42,8 @@ class ClientRequest { public: ClientRequest(SecurityModuleCall action) { - Serialization::Serialize(m_send, static_cast(action)); + m_buffer.InitForStreaming(); + Serialization::Serialize(m_buffer, static_cast(action)); } int getStatus() @@ -62,9 +63,10 @@ public: "Only one call to ClientRequest::send() is allowed"); m_sent = true; - m_status = sendToServer(SERVICE_SOCKET, m_send.Pop(), m_recv); + + m_status = sendToServer(SERVICE_SOCKET, m_buffer); if (!failed()) - Deserialization::Deserialize(m_recv, m_status); + Deserialization::Deserialize(m_buffer, m_status); else LogError("Error in sendToServer. Error code: " << m_status); @@ -73,7 +75,7 @@ public: template ClientRequest &send(const T&... args) { - Serialization::Serialize(m_send, args...); + Serialization::Serialize(m_buffer, args...); return send(); } @@ -87,7 +89,7 @@ public: throw std::logic_error( "ClientRequest::recv() not allowed if the request failed"); - Deserialization::Deserialize(m_recv, args...); + Deserialization::Deserialize(m_buffer, args...); return *this; } @@ -95,7 +97,7 @@ public: private: bool m_sent = false; int m_status = SECURITY_MANAGER_SUCCESS; - MessageBuffer m_send, m_recv; + MessageBuffer m_buffer; }; } // namespace SecurityManager diff --git a/src/common/CMakeLists.txt b/src/common/CMakeLists.txt index ac1b2ff6..e7986c5d 100644 --- a/src/common/CMakeLists.txt +++ b/src/common/CMakeLists.txt @@ -63,7 +63,6 @@ SET(COMMON_SOURCES ${DPL_PATH}/log/src/log.cpp ${DPL_PATH}/log/src/old_style_log_provider.cpp ${DPL_PATH}/core/src/assert.cpp - ${DPL_PATH}/core/src/binary_queue.cpp ${DPL_PATH}/core/src/colors.cpp ${DPL_PATH}/core/src/exception.cpp ${DPL_PATH}/core/src/noncopyable.cpp @@ -82,7 +81,6 @@ SET(COMMON_SOURCES ${COMMON_PATH}/file-lock.cpp ${COMMON_PATH}/permissible-set.cpp ${COMMON_PATH}/protocols.cpp - ${COMMON_PATH}/message-buffer.cpp ${COMMON_PATH}/nsmount-logic.cpp ${COMMON_PATH}/privilege_db.cpp ${COMMON_PATH}/smack-labels.cpp diff --git a/src/common/channel.cpp b/src/common/channel.cpp index 48f76887..79664c83 100644 --- a/src/common/channel.cpp +++ b/src/common/channel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Samsung Electronics Co., Ltd. All rights reserved. + * Copyright (c) 2017-2022 Samsung Electronics Co., Ltd. All rights reserved. * * This file is licensed under the terms of MIT License or the Apache License * Version 2.0 of your choice. See the LICENSE.MIT file for MIT license details. @@ -76,18 +76,14 @@ bool Channel::write(MessageBuffer &buffer) if (m_out == -1) return false; - RawBuffer data = buffer.Pop(); - unsigned char *buff = data.data(); - std::size_t done = 0; - - while (done < data.size()) { - int s = TEMP_FAILURE_RETRY(::write(m_out, buff+done, data.size()-done)); - if (s == -1) + buffer.ModeOutput(); + for (;;) { + const auto s = TEMP_FAILURE_RETRY(::write(m_out, buffer.Ptr(), buffer.OutputSize())); + if (s < 0) return false; - done += s; + if (buffer.OutputDone(s)) + return true; } - - return true; } bool Channel::read(MessageBuffer &buffer) @@ -95,16 +91,21 @@ bool Channel::read(MessageBuffer &buffer) if (m_in == -1) return false; - char buff[BUFFER_SIZE]; - - do { - int s = TEMP_FAILURE_RETRY(::read(m_in, buff, BUFFER_SIZE)); - if (s == 0 || s == -1) { + buffer.ModeInput(); + for (;;) { + const auto s = TEMP_FAILURE_RETRY(::read(m_in, buffer.Ptr(), buffer.InputSize())); + if (s <= 0) return false; + switch (buffer.InputDone(s)) { + case MessageBuffer::InputResult::ProtocolBroken: + return false; + case MessageBuffer::InputResult::Pending: + break; + case MessageBuffer::InputResult::Done: + buffer.ModeStreaming(); + return true; } - buffer.Push(RawBuffer(buff, buff+s)); - } while (!buffer.Ready()); - return true; + } } void Channel::closeAll() diff --git a/src/common/connection.cpp b/src/common/connection.cpp index d0b05184..469f1a5c 100644 --- a/src/common/connection.cpp +++ b/src/common/connection.cpp @@ -43,7 +43,6 @@ #include #include - #include namespace { @@ -170,58 +169,59 @@ private: namespace SecurityManager { -int sendToServer(char const * const interface, const RawBuffer &send, MessageBuffer &recv) { +int sendToServer(char const * const interface, MessageBuffer &buffer) { int ret; SockRAII sock; - ssize_t done = 0; - char buffer[2048]; if (SECURITY_MANAGER_SUCCESS != (ret = sock.Connect(interface))) { LogError("Error in SockRAII"); return ret; } - while ((send.size() - done) > 0) { + buffer.ModeOutput(); + + for (;;) { if (0 >= waitForSocket(sock.Get(), POLLOUT, POLL_TIMEOUT)) { LogError("Error in poll(POLLOUT)"); return SECURITY_MANAGER_ERROR_SOCKET; } - ssize_t temp = TEMP_FAILURE_RETRY(::send(sock.Get(), - &send[done], - send.size() - done, - MSG_NOSIGNAL)); - if (-1 == temp) { + const auto temp = TEMP_FAILURE_RETRY(::send(sock.Get(), buffer.Ptr(), buffer.OutputSize(), MSG_NOSIGNAL)); + if (temp < 0) { int err = errno; LogError("Error in write: " << GetErrnoString(err)); return SECURITY_MANAGER_ERROR_SOCKET; } - done += temp; + if (buffer.OutputDone(temp)) + break; } - do { + buffer.ModeInput(); + + for (;;) { if (0 >= waitForSocket(sock.Get(), POLLIN, POLL_TIMEOUT)) { LogError("Error in poll(POLLIN)"); return SECURITY_MANAGER_ERROR_SOCKET; } - ssize_t temp = TEMP_FAILURE_RETRY(::recv(sock.Get(), - buffer, - 2048, - 0)); - if (-1 == temp) { - int err = errno; - LogError("Error in read: " << GetErrnoString(err)); + const auto temp = TEMP_FAILURE_RETRY(::recv(sock.Get(), buffer.Ptr(), buffer.InputSize(), 0)); + if (temp <= 0) { + if (!temp) + LogError("Read return 0/Connection closed by server(?)"); + else { + int err = errno; + LogError("Error in read: " << GetErrnoString(err)); + } return SECURITY_MANAGER_ERROR_SOCKET; } - - if (0 == temp) { - LogError("Read return 0/Connection closed by server(?)"); - return SECURITY_MANAGER_ERROR_SOCKET; + switch (buffer.InputDone(temp)) { + case MessageBuffer::InputResult::ProtocolBroken: + return SECURITY_MANAGER_ERROR_SOCKET; + case MessageBuffer::InputResult::Pending: + break; + case MessageBuffer::InputResult::Done: + buffer.ModeStreaming(); + return SECURITY_MANAGER_SUCCESS; } - - RawBuffer raw(buffer, buffer+temp); - recv.Push(raw); - } while(!recv.Ready()); - return SECURITY_MANAGER_SUCCESS; + } } } // namespace SecurityManager diff --git a/src/common/filesystem.cpp b/src/common/filesystem.cpp index 85e79e38..97775249 100644 --- a/src/common/filesystem.cpp +++ b/src/common/filesystem.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016-2020 Samsung Electronics Co., Ltd. All rights reserved. + * Copyright (c) 2016-2022 Samsung Electronics Co., Ltd. All rights reserved. * * This file is licensed under the terms of MIT License or the Apache License * Version 2.0 of your choice. See the LICENSE.MIT file for MIT license details. @@ -33,7 +33,6 @@ #include #include -#include #include #include #include diff --git a/src/common/include/connection-info.h b/src/common/include/connection-info.h deleted file mode 100644 index d604c81d..00000000 --- a/src/common/include/connection-info.h +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Copyright (c) 2014-2022 Samsung Electronics Co., Ltd. All rights reserved. - * - * This file is licensed under the terms of MIT License or the Apache License - * Version 2.0 of your choice. See the LICENSE.MIT file for MIT license details. - * See the LICENSE file or the notice below for Apache License Version 2.0 - * details. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * @file connection-info.h - * @author Lukasz Kostyra (l.kostyra@partner.samsung.com) - * @version 1.0 - * @brief Definition of ConnectionInfo structure and ConnectionInfoMap type. - */ - -#pragma once - -#include -#include -#include -#include - -namespace SecurityManager -{ - struct ConnectionInfo { - ConnectionInfo(Credentials crd) - : creds(std::move(crd)) - {} - - Credentials creds; - MessageBuffer buffer; - }; - - typedef std::map ConnectionInfoMap; -} //namespace SecurityManager diff --git a/src/common/include/connection.h b/src/common/include/connection.h index 759c1e3b..85b6f6f0 100644 --- a/src/common/include/connection.h +++ b/src/common/include/connection.h @@ -30,15 +30,10 @@ #pragma once -#include -#include - #include namespace SecurityManager { -typedef std::vector RawBuffer; - -int sendToServer(char const * const interface, const RawBuffer &send, MessageBuffer &recv); +int sendToServer(char const * const interface, MessageBuffer &buffer); } // namespace SecurityManager diff --git a/src/common/include/message-buffer.h b/src/common/include/message-buffer.h index 761c42a7..bfb0ead2 100644 --- a/src/common/include/message-buffer.h +++ b/src/common/include/message-buffer.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2014-2020 Samsung Electronics Co., Ltd. All rights reserved. + * Copyright (c) 2014-2022 Samsung Electronics Co., Ltd. All rights reserved. * * This file is licensed under the terms of MIT License or the Apache License * Version 2.0 of your choice. See the LICENSE.MIT file for MIT license details. @@ -28,17 +28,67 @@ #pragma once -#include +#include -#include #include +#include #include namespace SecurityManager { -typedef std::vector RawBuffer; +/** + * Contiguous memory buffer for holding protocol messages (possibly partially + * constructed) and iterating over or assembling their contents. + * + * There are four modes of operation: + * Default (default-constructed, moved-out or cleared). + * Input (for assembling via ::read()). + * Output (for outputting via ::write()). + * Streaming (for de/serialization via IStream). + * + * The following state transitions are possible: + * Default -> Input: \ref InitForInput() + * Default -> Streaming: \ref InitForStreaming() + * * -> Default: std::move(), \ref Clear() + * Input | Output | Streaming + * -> Input: \ref ModeInput() + * -> Output: \ref ModeOutput() + * -> Streaming: \ref ModeStreaming() + * + * Mode-specific public functions: + * Input | Output: \ref Ptr() + * Input: \ref InputSize(), \ref InputDone(n) + * Output: \ref OutputSize(), \ref OutputDone(n) + * Streaming: \ref DeserializationDone(), \ref Read(), \ref Write() + */ +class MessageBuffer final : public SecurityManager::IStream { + /** + * Null in default mode. + * + * Otherwise size_t-aligned (as guaranteed by malloc()/realloc()) and big + * enough to hold a size_t. + * + * (size_t*)m_buffer holds the message header, aka payload size. Subsequent + * m_buffer bytes hold the unpadded payload. For instance, a message + * containing an unsigned char and an int is laid out in m_buffer like so: + * size_t payloadSize = sizeof(char) + sizeof(int); + * unsigned char; + * int; + * and spans sizeof(size_t) + payloadSize bytes. + */ + unsigned char *m_buffer = nullptr; -class MessageBuffer : public SecurityManager::IStream { + /** + * Current offset for ::read(), ::write() or de/serialization. + * + * Undefined in default mode, otherwise <= \ref m_bufferSize. + */ + size_t m_offset; + + /** + * Undefined if \ref m_buffer is null, otherwise its allocation size. + */ + size_t m_bufferSize; public: class Exception { @@ -47,36 +97,278 @@ public: DECLARE_EXCEPTION_TYPE(Base, OutOfData) }; - MessageBuffer() - : m_bytesLeft(0) - {} + MessageBuffer() = default; + + MessageBuffer(MessageBuffer&& other) { + m_buffer = other.m_buffer; + m_offset = other.m_offset; + m_bufferSize = other.m_bufferSize; + other.m_buffer = nullptr; + } + + MessageBuffer &operator=(MessageBuffer&& other) { + this->~MessageBuffer(); + return *new (this) MessageBuffer(std::move(other)); + } + + ~MessageBuffer() { + free(m_buffer); + } + + /** + * Default mode only. Allocate the buffer. + */ + void InitBuffer() { + assert(!m_buffer); + + // Common PAGE_SIZE lower bound, a conservative choice for IO. + constexpr size_t INITIAL_BUFFER_SIZE = 4096; + + if (!(m_buffer = static_cast(malloc(INITIAL_BUFFER_SIZE)))) + throw std::bad_alloc(); + m_bufferSize = INITIAL_BUFFER_SIZE; + } + + /** + * Switch from Default mode to Input mode. + */ + void InitForInput() { + InitBuffer(); + ModeInput(); + } - void Push(const RawBuffer &data); + /** + * Switch from Default mode to Streaming mode. + */ + void InitForStreaming() { + InitBuffer(); + ModeStreaming(); + } - size_t SerializedSize(); + /** + * Switch to Default mode. + */ + void Clear() { + free(m_buffer); + m_buffer = nullptr; + } - RawBuffer Pop(); + /** + * Switch from a non-Default mode to Streaming mode for de/serialization. + */ + void ModeStreaming() { + assert(m_buffer); + assert(m_bufferSize >= sizeof(size_t)); - bool Ready(); + m_offset = sizeof(size_t); + } - virtual void Read(size_t num, void *bytes); + /** + * Switch from a non-Default mode to Input mode for ::read()ing. + */ + void ModeInput() { + assert(m_buffer); + assert(m_bufferSize >= sizeof(size_t)); - virtual void Write(size_t num, const void *bytes); + m_offset = 0; + } -protected: + /** + * Switch from a non-Default mode to Output mode for ::write()ing. + * + * Assumes the buffer contains a fully formed payload that's just been + * assembled or otherwise iterated over (sets the payload size to + * \ref m_offset - sizeof(size_t)). Also see \ref m_buffer. + */ + void ModeOutput() { + assert(m_buffer); + assert(m_offset >= sizeof(size_t)); + assert(m_offset <= m_bufferSize); - inline void CountBytesLeft() { - if (m_bytesLeft > 0) - return; // we already counted m_bytesLeft nothing to do + *reinterpret_cast(m_buffer) = m_offset - sizeof(size_t); + m_offset = 0; + } - if (m_buffer.Size() < sizeof(size_t)) - return; // we cannot count m_bytesLeft because buffer is too small + /** + * Input/Output mode only. Pointer to a ::read()/::write() buffer. + */ + unsigned char *Ptr() { + assert(m_buffer); + assert(m_offset <= m_bufferSize); - m_buffer.FlattenConsume(&m_bytesLeft, sizeof(size_t)); + return &m_buffer[m_offset]; } - size_t m_bytesLeft; - SecurityManager::BinaryQueue m_buffer; + /** + * Input mode only. Return ::read() buffer size. + * + * Security manager's protocol assumes that, for a single connection (pipe, + * socket, whatever), at most one message is in flight at any given time. + * Thus, ::read()ing into a buffer larger than that message's size is + * assumed to result in a short read that never goes past the payload. + * As a consequence, no data loss ever occurs and the buffer only ever + * contains bytes from one given message at a time. + * + * This implementation leverages that by simply using all the remaining + * buffer space for ::read()ing. This is doubly important for the first + * ::read() when payload size is not yet known. + * + * See also \ref MessageBuffer and \ref m_buffer. + */ + size_t InputSize() const { + assert(m_buffer); + assert(m_offset <= m_bufferSize); + + return m_bufferSize - m_offset; + } + + enum class InputResult : unsigned char { + Done, + Pending, + ProtocolBroken, + }; + + /** + * Input mode only. Acknowledge bytes ::read(). + * + * @param n number of bytes read into \ref Ptr() + * + * @return \ref InputResult::Pending if message not yet complete + * @return \ref InputResult::ProtocolBroken if excess trailing bytes + * @return \ref InputResult::Done if message complete + */ + InputResult InputDone(size_t n) { + assert(n <= InputSize()); + + m_offset += n; + + // message size not yet available + if (m_offset < sizeof(size_t)) + return InputResult::Pending; + + const auto messageSize = MessageSize(); + if (m_offset == messageSize) + return InputResult::Done; + if (m_offset > messageSize) { + LogError("Protocol broken. Excess bytes up to offset " << m_offset << + " beyond " << messageSize); + return InputResult::ProtocolBroken; + } + + if (messageSize > m_bufferSize) + ReserveMessageSize(messageSize); + + return InputResult::Pending; + } + + /** + * Output mode only. Return ::write() buffer size. + */ + size_t OutputSize() const { + const auto messageSize = MessageSize(); + + assert(messageSize >= m_offset); + assert(messageSize <= m_bufferSize); + + return messageSize - m_offset; + } + + /** + * Output mode only. Acknowledge bytes ::write()n. + * + * @param n number of bytes written into \ref Ptr() + * + * @return true if the message has now been fully written + */ + bool OutputDone(size_t n) { + assert(n <= OutputSize()); + return (m_offset += n) == MessageSize(); + } + + /** + * Streaming mode, deserialization only. Return whether the message has + * been fully deserialized. + */ + bool DeserializationDone() const { + const auto messageSize = MessageSize(); + + assert(messageSize >= m_offset); + assert(messageSize <= m_bufferSize); + + return m_offset == messageSize; + } + + /** + * Streaming mode, deserialization only. Retrieve next payload bytes. + * + * @param num number of bytes to retrieve + * @param bytes output buffer + */ + void Read(size_t num, void *bytes) override { + const auto messageSize = MessageSize(); + assert(messageSize <= m_bufferSize); + + const auto newOffset = m_offset + num; + + if (newOffset > messageSize) { + LogError("Protocol broken. OutOfData. Offset " << newOffset << + " beyond " << messageSize); + Throw(Exception::OutOfData); + } + memcpy(bytes, &m_buffer[m_offset], num); + m_offset = newOffset; + } + + /** + * Streaming mode, serialization only. Receive next payload bytes. + * + * @param num number of bytes to receive + * @param bytes input buffer + */ + void Write(size_t num, const void *bytes) override { + assert(m_buffer); + assert(m_offset >= sizeof(size_t)); + assert(m_offset <= m_bufferSize); + + const auto newOffset = m_offset + num; + if (newOffset > m_bufferSize) + ReserveMessageSize(newOffset); + + memcpy(&m_buffer[m_offset], bytes, num); + m_offset = newOffset; + } + +private: + /** + * Non-Default mode only. Reallocate the buffer. + * + * @param newBufferSize new buffer size lower bound > \ref m_bufferSize + */ + void ReserveMessageSize(size_t newBufferSize) { + assert(m_buffer); + assert(m_bufferSize >= sizeof(size_t)); + assert(newBufferSize > m_bufferSize); + + newBufferSize = std::max(newBufferSize, 2 * m_bufferSize); + + const auto newBuffer = static_cast(realloc(m_buffer, newBufferSize)); + if (!newBuffer) + throw std::bad_alloc(); + + m_buffer = newBuffer; + m_bufferSize = newBufferSize; + } + + /** + * Non-Default mode only. Total message size (header and payload). + * Assumes payload size validity (see \ref m_buffer). + */ + size_t MessageSize() const { + assert(m_buffer); + assert(m_offset <= m_bufferSize); + + return sizeof(size_t) + *reinterpret_cast(m_buffer); + } }; } // namespace SecurityManager diff --git a/src/common/message-buffer.cpp b/src/common/message-buffer.cpp deleted file mode 100644 index 6b1a8b4c..00000000 --- a/src/common/message-buffer.cpp +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Copyright (c) 2014-2020 Samsung Electronics Co., Ltd. All rights reserved. - * - * This file is licensed under the terms of MIT License or the Apache License - * Version 2.0 of your choice. See the LICENSE.MIT file for MIT license details. - * See the LICENSE file or the notice below for Apache License Version 2.0 - * details. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * @file message-buffer.cpp - * @author Bartlomiej Grzelewski (b.grzelewski@samsung.com) - * @version 1.0 - * @brief Implementation of MessageBuffer. - */ - -#include - -#include - -namespace SecurityManager { - -void MessageBuffer::Push(const RawBuffer &data) { - m_buffer.AppendCopy(&data[0], data.size()); -} - -size_t MessageBuffer::SerializedSize() { - return m_buffer.Size() + sizeof(size_t); -} - -RawBuffer MessageBuffer::Pop() { - size_t size = m_buffer.Size(); - RawBuffer buffer; - buffer.resize(size + sizeof(size_t)); - memcpy(&buffer[0], &size, sizeof(size_t)); - m_buffer.FlattenConsume(&buffer[sizeof(size_t)], size); - return buffer; -} - -bool MessageBuffer::Ready() { - CountBytesLeft(); - if (m_bytesLeft == 0) - return false; - if (m_bytesLeft > m_buffer.Size()) - return false; - return true; -} - -void MessageBuffer::Read(size_t num, void *bytes) { - CountBytesLeft(); - if (num > m_bytesLeft) { - LogError("Protocol broken. OutOfData. Asked for: " << num << " Ready: " << m_bytesLeft << " Buffer.size(): " << m_buffer.Size()); - Throw(Exception::OutOfData); - } - - m_buffer.FlattenConsume(bytes, num); - m_bytesLeft -= num; -} - -void MessageBuffer::Write(size_t num, const void *bytes) { - m_buffer.AppendCopy(bytes, num); -} - -} // namespace SecurityManager - diff --git a/src/common/nsmount-logic.cpp b/src/common/nsmount-logic.cpp index 0257120e..3a3d0ad8 100644 --- a/src/common/nsmount-logic.cpp +++ b/src/common/nsmount-logic.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2020 Samsung Electronics Co., Ltd. All rights reserved. + * Copyright (c) 2017-2022 Samsung Electronics Co., Ltd. All rights reserved. * * This file is licensed under the terms of MIT License or the Apache License * Version 2.0 of your choice. See the LICENSE.MIT file for MIT license details. @@ -118,17 +118,18 @@ void NSMountLogic::cynaraCheck(EntryVector &entryVector) bool NSMountLogic::sendJobs(EntryVector &entryVector) { int status; - MessageBuffer send, recv; - Serialization::Serialize(send, entryVector); - if (!m_channel.write(send)) { + MessageBuffer buffer; + buffer.InitForStreaming(); + Serialization::Serialize(buffer, entryVector); + if (!m_channel.write(buffer)) { LogError("Could not send data to worker!"); return false; } - if (!m_channel.read(recv)) { + if (!m_channel.read(buffer)) { LogError("Could not recv data from worker!"); return false; } - Deserialization::Deserialize(recv, status); + Deserialization::Deserialize(buffer, status); return status == 0; } diff --git a/src/common/worker.cpp b/src/common/worker.cpp index a83ad25a..e5cdedba 100644 --- a/src/common/worker.cpp +++ b/src/common/worker.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2020 Samsung Electronics Co., Ltd. All rights reserved. + * Copyright (c) 2017-2022 Samsung Electronics Co., Ltd. All rights reserved. * * This file is licensed under the terms of MIT License or the Apache License * Version 2.0 of your choice. See the LICENSE.MIT file for MIT license details. @@ -102,21 +102,22 @@ void Worker::mainLoop() { NSMountLogic::EntryVector entryVector; + MessageBuffer buffer; + buffer.InitBuffer(); do { int status; NSMountLogic::EntryVector entryVector; - MessageBuffer recv; - MessageBuffer send; - if (!m_channel.read(recv)) { + if (!m_channel.read(buffer)) { LogError("Error reading command socket. The Security-manager worker will exit"); break; } - Deserialization::Deserialize(recv, entryVector); + Deserialization::Deserialize(buffer, entryVector); status = doWork(entryVector); - Serialization::Serialize(send, status); - m_channel.write(send); + buffer.ModeStreaming(); + Serialization::Serialize(buffer, status); + m_channel.write(buffer); } while (true); } diff --git a/src/dpl/core/include/dpl/binary_queue.h b/src/dpl/core/include/dpl/binary_queue.h deleted file mode 100644 index 58d2dfe8..00000000 --- a/src/dpl/core/include/dpl/binary_queue.h +++ /dev/null @@ -1,301 +0,0 @@ -/* - * Copyright (c) 2011-2020 Samsung Electronics Co., Ltd. All rights reserved. - * - * This file is licensed under the terms of MIT License or the Apache License - * Version 2.0 of your choice. See the LICENSE.MIT file for MIT license details. - * See the LICENSE file or the notice below for Apache License Version 2.0 - * details. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * @file binary_queue.h - * @author Przemyslaw Dobrowolski (p.dobrowolsk@samsung.com) - * @version 1.0 - * @brief This file is the header file of binary queue - */ -#pragma once - -//#include -#include -#include -#include -#include - -namespace SecurityManager { -/** - * Binary queue auto pointer - */ -class BinaryQueue; -typedef std::unique_ptr BinaryQueueUniquePtr; - -/** - * Binary stream implemented as constant size bucket list - * - * @todo Add optimized implementation for FlattenConsume - */ -class BinaryQueue -// : public AbstractInputOutput -{ - public: - class Exception - { - public: - DECLARE_EXCEPTION_TYPE(SecurityManager::Exception, Base) - DECLARE_EXCEPTION_TYPE(Base, OutOfData) - }; - - typedef void (*BufferDeleter)(const void *buffer, size_t bufferSize, - void *userParam); - static void BufferDeleterFree(const void *buffer, - size_t bufferSize, - void *userParam); - - class BucketVisitor - { - public: - /** - * Destructor - */ - virtual ~BucketVisitor(); - - /** - * Visit bucket - * - * @return none - * @param[in] buffer Constant pointer to bucket data buffer - * @param[in] bufferSize Number of bytes in bucket - */ - virtual void OnVisitBucket(const void *buffer, size_t bufferSize) = 0; - }; - - private: - struct Bucket : - private Noncopyable - { - const void *buffer; - const void *ptr; - size_t size; - size_t left; - - BufferDeleter deleter; - void *param; - - Bucket(const void *buffer, - size_t bufferSize, - BufferDeleter deleter, - void *userParam); - virtual ~Bucket(); - }; - - typedef std::list BucketList; - BucketList m_buckets; - size_t m_size; - - static void DeleteBucket(Bucket *bucket); - - class BucketVisitorCall - { - private: - BucketVisitor *m_visitor; - - public: - BucketVisitorCall(BucketVisitor *visitor); - virtual ~BucketVisitorCall(); - - void operator()(Bucket *bucket) const; - }; - - public: - /** - * Construct empty binary queue - */ - BinaryQueue(); - - /** - * Construct binary queue via bare copy of other binary queue - * - * @param[in] other Other binary queue to copy from - * @warning One cannot assume that bucket structure is preserved during copy - */ - BinaryQueue(const BinaryQueue &other); - - /** - * Destructor - */ - virtual ~BinaryQueue(); - - /** - * Construct binary queue via bare copy of other binary queue - * - * @param[in] other Other binary queue to copy from - * @warning One cannot assume that bucket structure is preserved during copy - */ - const BinaryQueue &operator=(const BinaryQueue &other); - - /** - * Append copy of @a bufferSize bytes from memory pointed by @a buffer - * to the end of binary queue. Uses default deleter based on free. - * - * @return none - * @param[in] buffer Pointer to buffer to copy data from - * @param[in] bufferSize Number of bytes to copy - * @exception std::bad_alloc Cannot allocate memory to hold additional data - * @see BinaryQueue::BufferDeleterFree - */ - void AppendCopy(const void *buffer, size_t bufferSize); - - /** - * Append @a bufferSize bytes from memory pointed by @a buffer - * to the end of binary queue. Uses custom provided deleter. - * Responsibility for deleting provided buffer is transfered to BinaryQueue. - * - * @return none - * @param[in] buffer Pointer to data buffer - * @param[in] bufferSize Number of bytes available in buffer - * @param[in] deleter Pointer to deleter procedure used to free provided - * buffer - * @param[in] userParam User parameter passed to deleter routine - * @exception std::bad_alloc Cannot allocate memory to hold additional data - */ - void AppendUnmanaged( - const void *buffer, - size_t bufferSize, - BufferDeleter deleter = - &BinaryQueue::BufferDeleterFree, - void *userParam = NULL); - - /** - * Append copy of other binary queue to the end of this binary queue - * - * @return none - * @param[in] other Constant reference to other binary queue to copy data - * from - * @exception std::bad_alloc Cannot allocate memory to hold additional data - * @warning One cannot assume that bucket structure is preserved during copy - */ - void AppendCopyFrom(const BinaryQueue &other); - - /** - * Move bytes from other binary queue to the end of this binary queue. - * This also removes all bytes from other binary queue. - * This method is designed to be as fast as possible (only pointer swaps) - * and is suggested over making copies of binary queues. - * Bucket structure is preserved after operation. - * - * @return none - * @param[in] other Reference to other binary queue to move data from - * @exception std::bad_alloc Cannot allocate memory to hold additional data - */ - void AppendMoveFrom(BinaryQueue &other); - - /** - * Append copy of binary queue to the end of other binary queue - * - * @return none - * @param[in] other Constant reference to other binary queue to copy data to - * @exception std::bad_alloc Cannot allocate memory to hold additional data - * @warning One cannot assume that bucket structure is preserved during copy - */ - void AppendCopyTo(BinaryQueue &other) const; - - /** - * Move bytes from binary queue to the end of other binary queue. - * This also removes all bytes from binary queue. - * This method is designed to be as fast as possible (only pointer swaps) - * and is suggested over making copies of binary queues. - * Bucket structure is preserved after operation. - * - * @return none - * @param[in] other Reference to other binary queue to move data to - * @exception std::bad_alloc Cannot allocate memory to hold additional data - */ - void AppendMoveTo(BinaryQueue &other); - - /** - * Retrieve total size of all data contained in binary queue - * - * @return Number of bytes in binary queue - */ - size_t Size() const; - - /** - * Remove all data from binary queue - * - * @return none - */ - void Clear(); - - /** - * Check if binary queue is empty - * - * @return true if binary queue is empty, false otherwise - */ - bool Empty() const; - - /** - * Remove @a size bytes from beginning of binary queue - * - * @return none - * @param[in] size Number of bytes to remove - * @exception BinaryQueue::Exception::OutOfData Number of bytes is larger - * than available bytes in binary queue - */ - void Consume(size_t size); - - /** - * Retrieve @a bufferSize bytes from beginning of binary queue and copy them - * to user supplied buffer - * - * @return none - * @param[in] buffer Pointer to user buffer to receive bytes - * @param[in] bufferSize Size of user buffer pointed by @a buffer - * @exception BinaryQueue::Exception::OutOfData Number of bytes to flatten - * is larger than available bytes in binary queue - */ - void Flatten(void *buffer, size_t bufferSize) const; - - /** - * Retrieve @a bufferSize bytes from beginning of binary queue, copy them - * to user supplied buffer, and remove from binary queue - * - * @return none - * @param[in] buffer Pointer to user buffer to receive bytes - * @param[in] bufferSize Size of user buffer pointed by @a buffer - * @exception BinaryQueue::Exception::OutOfData Number of bytes to flatten - * is larger than available bytes in binary queue - */ - void FlattenConsume(void *buffer, size_t bufferSize); - - /** - * Visit each buffer with data using visitor object - * - * @return none - * @param[in] visitor Pointer to bucket visitor - * @see BinaryQueue::BucketVisitor - */ - void VisitBuckets(BucketVisitor *visitor) const; - - /** - * IAbstractInput interface - */ - virtual BinaryQueueUniquePtr Read(size_t size); - - /** - * IAbstractOutput interface - */ - virtual size_t Write(const BinaryQueue &buffer, size_t bufferSize); -}; - -} // namespace SecurityManager diff --git a/src/dpl/core/src/binary_queue.cpp b/src/dpl/core/src/binary_queue.cpp deleted file mode 100644 index 94c01b99..00000000 --- a/src/dpl/core/src/binary_queue.cpp +++ /dev/null @@ -1,324 +0,0 @@ -/* - * Copyright (c) 2011-2020 Samsung Electronics Co., Ltd. All rights reserved. - * - * This file is licensed under the terms of MIT License or the Apache License - * Version 2.0 of your choice. See the LICENSE.MIT file for MIT license details. - * See the LICENSE file or the notice below for Apache License Version 2.0 - * details. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * @file binary_queue.cpp - * @author Przemyslaw Dobrowolski (p.dobrowolsk@samsung.com) - * @version 1.0 - * @brief This file is the implementation file of binary queue - */ -#include -#include -#include -#include -#include -#include -#include -#include - -namespace SecurityManager { -BinaryQueue::BinaryQueue() : - m_size(0) -{} - -BinaryQueue::BinaryQueue(const BinaryQueue &other) : - m_size(0) -{ - AppendCopyFrom(other); -} - -BinaryQueue::~BinaryQueue() -{ - // Remove all remainig buckets - Clear(); -} - -const BinaryQueue &BinaryQueue::operator=(const BinaryQueue &other) -{ - if (this != &other) { - Clear(); - AppendCopyFrom(other); - } - - return *this; -} - -void BinaryQueue::AppendCopyFrom(const BinaryQueue &other) -{ - // To speed things up, always copy as one bucket - void *bufferCopy = malloc(other.m_size); - - if (bufferCopy == NULL) { - throw std::bad_alloc(); - } - - try { - other.Flatten(bufferCopy, other.m_size); - AppendUnmanaged(bufferCopy, other.m_size, &BufferDeleterFree, NULL); - } catch (const std::bad_alloc &) { - // Free allocated memory - free(bufferCopy); - throw; - } -} - -void BinaryQueue::AppendMoveFrom(BinaryQueue &other) -{ - // Copy all buckets - std::copy(other.m_buckets.begin(), - other.m_buckets.end(), std::back_inserter(m_buckets)); - m_size += other.m_size; - - // Clear other, but do not free memory - other.m_buckets.clear(); - other.m_size = 0; -} - -void BinaryQueue::AppendCopyTo(BinaryQueue &other) const -{ - other.AppendCopyFrom(*this); -} - -void BinaryQueue::AppendMoveTo(BinaryQueue &other) -{ - other.AppendMoveFrom(*this); -} - -void BinaryQueue::Clear() -{ - std::for_each(m_buckets.begin(), m_buckets.end(), &DeleteBucket); - m_buckets.clear(); - m_size = 0; -} - -void BinaryQueue::AppendCopy(const void* buffer, size_t bufferSize) -{ - // Create data copy with malloc/free - void *bufferCopy = malloc(bufferSize); - - // Check if allocation succeded - if (bufferCopy == NULL) { - throw std::bad_alloc(); - } - - // Copy user data - memcpy(bufferCopy, buffer, bufferSize); - - try { - // Try to append new bucket - AppendUnmanaged(bufferCopy, bufferSize, &BufferDeleterFree, NULL); - } catch (const std::bad_alloc &) { - // Free allocated memory - free(bufferCopy); - throw; - } -} - -void BinaryQueue::AppendUnmanaged(const void* buffer, - size_t bufferSize, - BufferDeleter deleter, - void* userParam) -{ - // Do not attach empty buckets - if (bufferSize == 0) { - deleter(buffer, bufferSize, userParam); - return; - } - - // Just add new bucket with selected deleter - Bucket *bucket = new Bucket(buffer, bufferSize, deleter, userParam); - try { - m_buckets.push_back(bucket); - } catch (const std::bad_alloc &) { - delete bucket; - throw; - } - - // Increase total queue size - m_size += bufferSize; -} - -size_t BinaryQueue::Size() const -{ - return m_size; -} - -bool BinaryQueue::Empty() const -{ - return m_size == 0; -} - -void BinaryQueue::Consume(size_t size) -{ - // Check parameters - if (size > m_size) { - Throw(Exception::OutOfData); - } - - size_t bytesLeft = size; - - // Consume data and/or remove buckets - while (bytesLeft > 0) { - // Get consume size - size_t count = std::min(bytesLeft, m_buckets.front()->left); - - m_buckets.front()->ptr = - static_cast(m_buckets.front()->ptr) + count; - m_buckets.front()->left -= count; - bytesLeft -= count; - m_size -= count; - - if (m_buckets.front()->left == 0) { - DeleteBucket(m_buckets.front()); - m_buckets.pop_front(); - } - } -} - -void BinaryQueue::Flatten(void *buffer, size_t bufferSize) const -{ - // Check parameters - if (bufferSize == 0) { - return; - } - - if (bufferSize > m_size) { - Throw(Exception::OutOfData); - } - - size_t bytesLeft = bufferSize; - void *ptr = buffer; - BucketList::const_iterator bucketIterator = m_buckets.begin(); - Assert(m_buckets.end() != bucketIterator); - - // Flatten data - while (bytesLeft > 0) { - // Get consume size - size_t count = std::min(bytesLeft, (*bucketIterator)->left); - - // Copy data to user pointer - memcpy(ptr, (*bucketIterator)->ptr, count); - - // Update flattened bytes count - bytesLeft -= count; - ptr = static_cast(ptr) + count; - - // Take next bucket - ++bucketIterator; - } -} - -void BinaryQueue::FlattenConsume(void *buffer, size_t bufferSize) -{ - // FIXME: Optimize - Flatten(buffer, bufferSize); - Consume(bufferSize); -} - -void BinaryQueue::DeleteBucket(BinaryQueue::Bucket *bucket) -{ - delete bucket; -} - -void BinaryQueue::BufferDeleterFree(const void* data, - size_t dataSize, - void* userParam) -{ - (void)dataSize; - (void)userParam; - - // Default free deleter - free(const_cast(data)); -} - -BinaryQueue::Bucket::Bucket(const void* data, - size_t dataSize, - BufferDeleter dataDeleter, - void* userParam) : - buffer(data), - ptr(data), - size(dataSize), - left(dataSize), - deleter(dataDeleter), - param(userParam) -{ - Assert(data != NULL); - Assert(deleter != NULL); -} - -BinaryQueue::Bucket::~Bucket() -{ - // Invoke deleter on bucket data - deleter(buffer, size, param); -} - -BinaryQueue::BucketVisitor::~BucketVisitor() -{} - -BinaryQueue::BucketVisitorCall::BucketVisitorCall(BucketVisitor *visitor) : - m_visitor(visitor) -{} - -BinaryQueue::BucketVisitorCall::~BucketVisitorCall() -{} - -void BinaryQueue::BucketVisitorCall::operator()(Bucket *bucket) const -{ - m_visitor->OnVisitBucket(bucket->ptr, bucket->left); -} - -void BinaryQueue::VisitBuckets(BucketVisitor *visitor) const -{ - Assert(visitor != NULL); - - // Visit all buckets - std::for_each(m_buckets.begin(), m_buckets.end(), BucketVisitorCall(visitor)); -} - -BinaryQueueUniquePtr BinaryQueue::Read(size_t size) -{ - // Simulate input stream - size_t available = std::min(size, m_size); - - std::unique_ptr> - bufferCopy(malloc(available), free); - - if (!bufferCopy.get()) { - throw std::bad_alloc(); - } - - BinaryQueueUniquePtr result(new BinaryQueue()); - - Flatten(bufferCopy.get(), available); - result->AppendUnmanaged( - bufferCopy.release(), available, &BufferDeleterFree, NULL); - Consume(available); - - return result; -} - -size_t BinaryQueue::Write(const BinaryQueue &buffer, size_t bufferSize) -{ - // Simulate output stream - AppendCopyFrom(buffer); - return bufferSize; -} -} // namespace SecurityManager diff --git a/src/server/main/include/generic-socket-manager.h b/src/server/main/include/generic-socket-manager.h index c916c62d..c545894a 100644 --- a/src/server/main/include/generic-socket-manager.h +++ b/src/server/main/include/generic-socket-manager.h @@ -28,13 +28,13 @@ #pragma once -#include #include #include #include #include +#include namespace SecurityManager { @@ -46,8 +46,6 @@ struct ConnectionID { } }; -typedef std::vector RawBuffer; - struct GenericSocketManager; struct GenericSocketService { @@ -64,28 +62,15 @@ struct GenericSocketService { ServiceHandlerPath serviceHandlerPath; // Path to file }; - struct AcceptEvent : public GenericEvent { - AcceptEvent(ConnectionID connection, Credentials crd) + struct MessageEvent : public GenericEvent { + MessageEvent(ConnectionID connection, Credentials &&crd, MessageBuffer &&messageBuffer) : connectionID(connection) , creds(std::move(crd)) + , messageBuffer(std::move(messageBuffer)) {} ConnectionID connectionID; Credentials creds; - }; - - struct WriteEvent : public GenericEvent { - ConnectionID connectionID; - size_t size; - size_t left; - }; - - struct ReadEvent : public GenericEvent { - ConnectionID connectionID; - RawBuffer rawBuffer; - }; - - struct CloseEvent : public GenericEvent { - ConnectionID connectionID; + MessageBuffer messageBuffer; }; virtual void SetSocketManager(GenericSocketManager *manager) { @@ -93,10 +78,7 @@ struct GenericSocketService { } virtual const ServiceDescription &GetServiceDescription() const = 0; - virtual void Event(AcceptEvent &&event) = 0; - virtual void Event(WriteEvent &&event) = 0; - virtual void Event(ReadEvent &&event) = 0; - virtual void Event(CloseEvent &&event) = 0; + virtual void Event(MessageEvent &&event) = 0; virtual void Start() {}; virtual void Stop() {}; @@ -111,7 +93,7 @@ struct GenericSocketManager { virtual void MainLoop() = 0; virtual void RegisterSocketService(GenericSocketService *ptr) = 0; virtual void Close(ConnectionID connectionID) = 0; - virtual void Write(ConnectionID connectionID, const RawBuffer &rawBuffer) = 0; + virtual void Write(ConnectionID connectionID, MessageBuffer &&messageBuffer) = 0; virtual ~GenericSocketManager(){} }; diff --git a/src/server/main/include/socket-manager.h b/src/server/main/include/socket-manager.h index ffcbba0f..9e6ffa23 100644 --- a/src/server/main/include/socket-manager.h +++ b/src/server/main/include/socket-manager.h @@ -53,7 +53,7 @@ public: virtual void RegisterSocketService(GenericSocketService *service); virtual void Close(ConnectionID connectionID); - virtual void Write(ConnectionID connectionID, const RawBuffer &rawBuffer); + virtual void Write(ConnectionID connectionID, MessageBuffer &&messageBuffer); protected: void CreateDomainSocket( @@ -76,18 +76,18 @@ protected: struct SocketDescription { bool isOpen = false; bool isTimeout = false; - time_t timeout = 0; - RawBuffer rawBuffer; int counter = -1; + time_t timeout = 0; + MessageBuffer buffer; }; - SocketDescription& CreateDefaultReadSocketDescription(int sock); + void CreateDefaultReadSocketDescription(int sock); typedef std::vector SocketDescriptionVector; struct WriteBuffer { ConnectionID connectionID; - RawBuffer rawBuffer; + MessageBuffer buffer; }; struct Timeout { diff --git a/src/server/main/socket-manager.cpp b/src/server/main/socket-manager.cpp index 673ee33d..1284081e 100644 --- a/src/server/main/socket-manager.cpp +++ b/src/server/main/socket-manager.cpp @@ -55,7 +55,7 @@ namespace { -const time_t SOCKET_TIMEOUT = 300; +constexpr time_t SOCKET_TIMEOUT = 300; } // namespace anonymous @@ -66,16 +66,16 @@ void SocketManager::RegisterFdForReading(int fd) { m_maxDesc = std::max(m_maxDesc, fd); } -SocketManager::SocketDescription& -SocketManager::CreateDefaultReadSocketDescription(int sock) +void SocketManager::CreateDefaultReadSocketDescription(int sock) { if ((int)m_socketDescriptionVector.size() <= sock) m_socketDescriptionVector.resize(sock+20); auto &desc = m_socketDescriptionVector[sock]; desc.isOpen = true; - desc.timeout = monotonicNow() + SOCKET_TIMEOUT; desc.counter = ++m_counter; + desc.timeout = monotonicNow() + SOCKET_TIMEOUT; + desc.buffer.InitForInput(); if (false == desc.isTimeout) { desc.isTimeout = true; @@ -86,7 +86,6 @@ SocketManager::CreateDefaultReadSocketDescription(int sock) } RegisterFdForReading(sock); - return desc; } SocketManager::SocketManager() @@ -144,12 +143,7 @@ void SocketManager::ReadyForAccept() { return; } - auto &desc = CreateDefaultReadSocketDescription(client); - - GenericSocketService::AcceptEvent event( - ConnectionID{client, desc.counter}, - Credentials::getCredentialsFromSocket(client)); - m_service->Event(std::move(event)); + CreateDefaultReadSocketDescription(client); } // true if quit mainloop @@ -178,23 +172,16 @@ bool SocketManager::GotSigTerm() const { } void SocketManager::ReadyForRead(int sock) { - GenericSocketService::ReadEvent event; - event.connectionID.sock = sock; - event.connectionID.counter = m_socketDescriptionVector[sock].counter; - event.rawBuffer.resize(4096); - auto &desc = m_socketDescriptionVector[sock]; + auto &buffer = desc.buffer; desc.timeout = monotonicNow() + SOCKET_TIMEOUT; - ssize_t size = read(sock, &event.rawBuffer[0], 4096); + ssize_t size = read(sock, buffer.Ptr(), buffer.InputSize()); if (size == 0) { LogDebug("Reading returned 0 bytes, closing socket: " << sock); - CloseSocket(sock); - } else if (size >= 0) { - event.rawBuffer.resize(size); - m_service->Event(std::move(event)); - } else if (size == -1) { + goto close; + } else if (size < 0) { int err = errno; switch (err) { case EAGAIN: @@ -202,15 +189,34 @@ void SocketManager::ReadyForRead(int sock) { break; default: LogError("Reading sock error: " << GetErrnoString(err)); - CloseSocket(sock); + goto close; + } + } else { + switch (buffer.InputDone(size)) { + case MessageBuffer::InputResult::ProtocolBroken: + goto close; + case MessageBuffer::InputResult::Pending: + break; + case MessageBuffer::InputResult::Done: + buffer.ModeStreaming(); + FD_CLR(sock, &m_readSet); // the one and only call on this socket is complete + m_service->Event( + GenericSocketService::MessageEvent(ConnectionID{sock, desc.counter}, + Credentials::getCredentialsFromSocket(sock), + std::move(buffer))); + break; } } + + return; +close: + CloseSocket(sock); } void SocketManager::ReadyForWrite(int sock) { auto &desc = m_socketDescriptionVector[sock]; - size_t size = desc.rawBuffer.size(); - ssize_t result = write(sock, &desc.rawBuffer[0], size); + auto &buffer = desc.buffer; + ssize_t result = write(sock, buffer.Ptr(), buffer.OutputSize()); if (result == -1) { int err = errno; switch (err) { @@ -226,24 +232,13 @@ void SocketManager::ReadyForWrite(int sock) { return; // We do not want to propagate error to next layer } - desc.rawBuffer.erase(desc.rawBuffer.begin(), desc.rawBuffer.begin()+result); - desc.timeout = monotonicNow() + SOCKET_TIMEOUT; - - if (desc.rawBuffer.empty()) - FD_CLR(sock, &m_writeSet); - - GenericSocketService::WriteEvent event; - event.connectionID.sock = sock; - event.connectionID.counter = desc.counter; - event.size = result; - event.left = desc.rawBuffer.size(); - - m_service->Event(std::move(event)); + if (buffer.OutputDone(result)) + CloseSocket(sock); } void SocketManager::MainLoop() { - // remove evironment values passed by systemd + // remove environment values passed by systemd sd_listen_fds(1); // Daemon is ready to work. @@ -277,7 +272,7 @@ void SocketManager::MainLoop() { } if (m_timeoutQueue.empty()) { - LogDebug("No usaable timeout found."); + LogDebug("No usable timeout found."); ptrTimeout = NULL; // select will wait without timeout } else { time_t currentTime = monotonicNow(); @@ -343,25 +338,27 @@ void SocketManager::MainLoop() { if (GotSigTerm()) return; FD_CLR(m_signalFd, &readSet); + ret--; } if (FD_ISSET(m_listenSock, &readSet)) { ReadyForAccept(); FD_CLR(m_listenSock, &readSet); + ret--; } if (FD_ISSET(m_notifyMe, &readSet)) { eventfd_t dummyValue; TEMP_FAILURE_RETRY(eventfd_read(m_notifyMe, &dummyValue)); FD_CLR(m_notifyMe, &readSet); + ret--; } for (int i = 0; i < m_maxDesc+1 && ret; ++i) { if (FD_ISSET(i, &readSet)) { ReadyForRead(i); - --ret; - } - if (FD_ISSET(i, &writeSet)) { + ret--; + } else if (FD_ISSET(i, &writeSet)) { ReadyForWrite(i); - --ret; + ret--; } } ProcessQueue(); @@ -493,13 +490,10 @@ void SocketManager::Close(ConnectionID connectionID) { NotifyMe(); } -void SocketManager::Write(ConnectionID connectionID, const RawBuffer &rawBuffer) { - WriteBuffer buffer; - buffer.connectionID = connectionID; - buffer.rawBuffer = rawBuffer; +void SocketManager::Write(ConnectionID connectionID, MessageBuffer &&buffer) { { std::lock_guard ulock(m_eventQueueMutex); - m_writeBufferQueue.push(buffer); + m_writeBufferQueue.push(WriteBuffer { connectionID, std::move(buffer) }); } NotifyMe(); } @@ -509,11 +503,10 @@ void SocketManager::NotifyMe() { } void SocketManager::ProcessQueue() { - WriteBuffer buffer; { std::lock_guard ulock(m_eventQueueMutex); while (!m_writeBufferQueue.empty()) { - buffer = m_writeBufferQueue.front(); + auto buffer = std::move(m_writeBufferQueue.front()); m_writeBufferQueue.pop(); auto &desc = m_socketDescriptionVector[buffer.connectionID.sock]; @@ -531,11 +524,9 @@ void SocketManager::ProcessQueue() { continue; } - std::copy( - buffer.rawBuffer.begin(), - buffer.rawBuffer.end(), - std::back_inserter(desc.rawBuffer)); + desc.buffer = std::move(buffer.buffer); + desc.buffer.ModeOutput(); FD_SET(buffer.connectionID.sock, &m_writeSet); } } @@ -564,19 +555,15 @@ void SocketManager::CloseSocket(int sock) { LogDebug("Closing socket: " << sock); auto &desc = m_socketDescriptionVector[sock]; - if (!(desc.isOpen)) { - // This may happend when some information was waiting for write to the + if (!desc.isOpen) { + // This may happen when some information was waiting for write to the // socket and in the same time socket was closed by the client. - LogError("Socket " << sock << " is not open. Nothing to do!"); + LogDebug("Socket " << sock << " is not open. Nothing to do!"); return; } - GenericSocketService::CloseEvent event; - event.connectionID.sock = sock; - event.connectionID.counter = desc.counter; - desc.isOpen = false; - desc.rawBuffer.clear(); + desc.buffer.Clear(); close(sock); FD_CLR(sock, &m_readSet); diff --git a/src/server/service/base-service.cpp b/src/server/service/base-service.cpp index 8450dfd1..dc2c931b 100644 --- a/src/server/service/base-service.cpp +++ b/src/server/service/base-service.cpp @@ -44,42 +44,11 @@ BaseService::BaseService(Offline offline) { } -void BaseService::accept(AcceptEvent &&event) +void BaseService::process(MessageEvent &&event) { - LogDebug("Accept event. ConnectionID.sock: " << event.connectionID.sock << - " ConnectionID.counter: " << event.connectionID.counter); + LogDebug("Message event for counter: " << event.connectionID.counter); - m_connectionInfoMap.emplace( - std::make_pair( - event.connectionID.counter, - ConnectionInfo(std::move(event.creds)))); -} - -void BaseService::write(WriteEvent &&event) -{ - LogDebug("WriteEvent. ConnectionID: " << event.connectionID.sock << - " Size: " << event.size << - " Left: " << event.left); - - if (event.left == 0) - m_serviceManager->Close(event.connectionID); -} - -void BaseService::process(ReadEvent &&event) -{ - LogDebug("Read event for counter: " << event.connectionID.counter); - auto &info = m_connectionInfoMap.at(event.connectionID.counter); - info.buffer.Push(event.rawBuffer); - - // We can get several requests in one package. - // Extract and process them all - while (processOne(event.connectionID, info.buffer)); -} - -void BaseService::close(CloseEvent &&event) -{ - LogDebug("CloseEvent. ConnectionID: " << event.connectionID.sock); - m_connectionInfoMap.erase(event.connectionID.counter); + processOne(event.connectionID, event.creds, event.messageBuffer); } void BaseService::Start() diff --git a/src/server/service/include/base-service.h b/src/server/service/include/base-service.h index 2ca2d39f..f89e09a4 100644 --- a/src/server/service/include/base-service.h +++ b/src/server/service/include/base-service.h @@ -31,7 +31,6 @@ #include #include #include -#include #include namespace SecurityManager { @@ -52,15 +51,9 @@ public: explicit BaseService(Offline offline); virtual const ServiceDescription &GetServiceDescription() const = 0; - DECLARE_THREAD_EVENT(AcceptEvent, accept) - DECLARE_THREAD_EVENT(WriteEvent, write) - DECLARE_THREAD_EVENT(ReadEvent, process) - DECLARE_THREAD_EVENT(CloseEvent, close) + DECLARE_THREAD_EVENT(MessageEvent, process) - void accept(AcceptEvent &&event); - void write(WriteEvent &&event); - void process(ReadEvent &&event); - void close(CloseEvent &&event); + void process(MessageEvent &&event); void Start(); void Stop(); @@ -68,17 +61,14 @@ public: protected: ServiceImpl serviceImpl; - ConnectionInfoMap m_connectionInfoMap; - /** * Handle request from a client * * @param conn Socket connection information - * @param buffer Raw received data buffer + * @param buffer Received message buffer * @return true on success */ - virtual bool processOne(const ConnectionID &conn, - MessageBuffer &buffer) = 0; + virtual void processOne(const ConnectionID &conn, Credentials &creds, MessageBuffer &buffer) = 0; }; } // namespace SecurityManager diff --git a/src/server/service/include/service.h b/src/server/service/include/service.h index b5b66d51..7fe5dede 100644 --- a/src/server/service/include/service.h +++ b/src/server/service/include/service.h @@ -57,79 +57,71 @@ private: /** * Handle request from a client * - * @param conn Socket connection information - * @param buffer Raw received data buffer - * @return true on success + * @param conn Socket connection information + * @param buffer Input/output message buffer + * @return true on success */ - bool processOne(const ConnectionID &conn, MessageBuffer &buffer); - + void processOne(const ConnectionID &conn, Credentials &creds, MessageBuffer &buffer) override; /** * Process getting application manifest policy - * @param buffer Raw received data buffer - * @param send Raw data buffer to be sent - * @param creds credentials of the requesting process + * @param buffer Input/output message buffer + * @param creds credentials of the requesting process */ - void processAppGetManifestPolicy(MessageBuffer &buffer, MessageBuffer &send, const Credentials &creds); + void processAppGetManifestPolicy(MessageBuffer &buffer, const Credentials &creds); /** * Process application installation * - * @param buffer Raw received data buffer - * @param send Raw data buffer to be sent - * @param creds credentials of the requesting process + * @param buffer Input/output message buffer + * @param creds credentials of the requesting process */ - void processAppInstall(MessageBuffer &buffer, MessageBuffer &send, const Credentials &creds); + void processAppInstall(MessageBuffer &buffer, const Credentials &creds); /** * Process application update (currently only for cases when hybrid setting is changed) * - * @param buffer Raw received data buffer - * @param send Raw data buffer to be sent - * @param creds credentials of the requesting process + * @param buffer Input/output message buffer + * @param creds credentials of the requesting process */ - void processAppUpdate(MessageBuffer &buffer, MessageBuffer &send, const Credentials &creds); + void processAppUpdate(MessageBuffer &buffer, const Credentials &creds); /** * Process application uninstallation * - * @param buffer Raw received data buffer - * @param send Raw data buffer to be sent - * @param creds credentials of the requesting process + * @param buffer Input/output message buffer + * @param creds credentials of the requesting process */ - void processAppUninstall(MessageBuffer &buffer, MessageBuffer &send, const Credentials &creds); + void processAppUninstall(MessageBuffer &buffer, const Credentials &creds); /** * Process getting package identifier from an app identifier * - * @param buffer Raw received data buffer - * @param send Raw data buffer to be sent + * @param buffer Input/output message buffer */ - void processGetPkgName(MessageBuffer &buffer, MessageBuffer &send); + void processGetPkgName(MessageBuffer &buffer); - void processUserAdd(MessageBuffer &buffer, MessageBuffer &send, const Credentials &creds); + void processUserAdd(MessageBuffer &buffer, const Credentials &creds); - void processUserDelete(MessageBuffer &buffer, MessageBuffer &send, const Credentials &creds); + void processUserDelete(MessageBuffer &buffer, const Credentials &creds); /** * Process policy update request * - * @param buffer Raw received data buffer - * @param send Raw data buffer to be sent - * @param creds credentials of the requesting process + * @param buffer Input/output message buffer + * @param creds credentials of the requesting process */ - void processPolicyUpdate(MessageBuffer &buffer, MessageBuffer &send, const Credentials &creds); + void processPolicyUpdate(MessageBuffer &buffer, const Credentials &creds); /** * List all privileges for specific user, placed in Cynara's PRIVACY_MANAGER * or ADMIN's bucket - choice based on forAdmin parameter * - * @param buffer Raw received data buffer - * @param send Raw data buffer to be sent - * @param creds credentials of the requesting process - * @param forAdmin determines internal type of request + * @param buffer Input/output message buffer + * @param creds credentials of the requesting process + * @param forAdmin determines internal type of request */ - void processGetConfiguredPolicy(MessageBuffer &buffer, MessageBuffer &send, const Credentials &creds, bool forAdmin); + void processGetConfiguredPolicy(MessageBuffer &buffer, const Credentials &creds, bool forAdmin); /** * Get whole policy for specific user. Whole policy is a list of all apps, @@ -139,130 +131,116 @@ private: * will be listed. If caller is privileged, then apps for all the users will * be listed. * - * @param buffer Raw received data buffer - * @param send Raw data buffer to be sent - * @param creds credentials of the requesting process + * @param buffer Input/output message buffer + * @param creds credentials of the requesting process */ - void processGetPolicy(MessageBuffer &buffer, MessageBuffer &send, const Credentials &creds); + void processGetPolicy(MessageBuffer &buffer, const Credentials &creds); /** * Process getting policies descriptions as strings from Cynara * - * @param recv Raw received data buffer - * @param send Raw data buffer to be sent + * @param buffer Input/output message buffer */ - void processPolicyGetDesc(MessageBuffer &send); + void processPolicyGetDesc(MessageBuffer &buffer); /** * Process getting groups bound with privileges and permitted group ids for app id * - * @param buffer Raw received data buffer - * @param send Raw data buffer to be sent - * @param creds credentials of the requesting process + * @param buffer Input/output message buffer + * @param creds credentials of the requesting process */ - void processGetForbiddenAndAllowedGroups(MessageBuffer &buffer, MessageBuffer &send, const Credentials &creds); + void processGetForbiddenAndAllowedGroups(MessageBuffer &buffer, const Credentials &creds); /** * Process getting groups bound with privileges for given uid * - * @param send Raw data buffer to be sent + * @param buffer Output message buffer */ - void processGroupsForUid(MessageBuffer &recv, MessageBuffer &send); + void processGroupsForUid(MessageBuffer &buffer); /** * Process checking application's privilege access based on app_id * - * @param recv Raw received data buffer - * @param send Raw data buffer to be sent - * @param creds credentials of the requesting process + * @param buffer Input/output message buffer + * @param creds credentials of the requesting process */ - void processAppHasPrivilege(MessageBuffer &recv, MessageBuffer &send, const Credentials &creds); + void processAppHasPrivilege(MessageBuffer &buffer, const Credentials &creds); /** * Process applying private path sharing between applications. * - * @param recv Raw received data buffer - * @param send Raw data buffer to be sent - * @param creds credentials of the requesting process + * @param buffer Input/output message buffer + * @param creds credentials of the requesting process */ - void processApplyPrivateSharing(MessageBuffer &recv, MessageBuffer &send, const Credentials &creds); + void processApplyPrivateSharing(MessageBuffer &buffer, const Credentials &creds); /** * Process drop private path sharing between applications. * - * @param recv Raw received data buffer - * @param send Raw data buffer to be sent - * @param creds credentials of the requesting process + * @param buffer Input/output message buffer + * @param creds credentials of the requesting process */ - void processDropPrivateSharing(MessageBuffer &recv, MessageBuffer &send, const Credentials &creds); + void processDropPrivateSharing(MessageBuffer &buffer, const Credentials &creds); /** * Process package paths registration request * - * @param recv Raw received data buffer - * @param send Raw data buffer to be sent - * @param creds credentials of the requesting process + * @param buffer Input/output message buffer + * @param creds credentials of the requesting process */ - void processPathsRegister(MessageBuffer &recv, MessageBuffer &send, const Credentials &creds); + void processPathsRegister(MessageBuffer &buffer, const Credentials &creds); /** * Process shared memory access request * - * @param recv Raw received data buffer - * @param send Raw data buffer to be sent - * @param creds credentials of the requesting process + * @param buffer Input/output message buffer + * @param creds credentials of the requesting process */ - void processShmAppName(MessageBuffer &recv, MessageBuffer &send, const Credentials &creds); + void processShmAppName(MessageBuffer &buffer, const Credentials &creds); /** * Process getting provider(app_id, pkg_id) of privilege * - * @param buffer Raw received data buffer - * @param send Raw data buffer to be sent + * @param buffer Input/output message buffer */ - void processGetAppDefinedPrivilegeProvider(MessageBuffer &buffer, MessageBuffer &send); + void processGetAppDefinedPrivilegeProvider(MessageBuffer &buffer); /** * Process getting license of privilege * - * @param buffer Raw received data buffer - * @param send Raw data buffer to be sent + * @param buffer Input/output message buffer */ - void processGetAppDefinedPrivilegeLicense(MessageBuffer &buffer, MessageBuffer &send); + void processGetAppDefinedPrivilegeLicense(MessageBuffer &buffer); /** * Process getting license of privilege * - * @param buffer Raw received data buffer - * @param send Raw data buffer to be sent + * @param buffer Input/output message buffer */ - void processGetClientPrivilegeLicense(MessageBuffer &buffer, MessageBuffer &send); + void processGetClientPrivilegeLicense(MessageBuffer &buffer); /** * Process clean app namespace * - * @param buffer Raw received data buffer - * @param send Raw data buffer to be sent - * @param creds credentials of the requesting process + * @param buffer Input/output message buffer + * @param creds credentials of the requesting process */ - void processAppCleanNamespace(MessageBuffer &buffer, MessageBuffer &send, const Credentials &creds); + void processAppCleanNamespace(MessageBuffer &buffer, const Credentials &creds); /** * Get process label * - * @param buffer Raw received data buffer - * @param send Raw data buffer to be sent + * @param buffer Input/output message buffer */ - void processGetProcessLabel(MessageBuffer &buffer, MessageBuffer &send); + void processGetProcessLabel(MessageBuffer &buffer); /** * Get app info (process label, package name, shared_ro flag) and groups, setup app namespace * - * @param buffer Raw received data buffer - * @param send Raw data buffer to be sent - * @param creds credentials of the requesting process + * @param buffer Input/output message buffer + * @param creds credentials of the requesting process */ - void prepareApp(MessageBuffer &buffer, MessageBuffer &send, const Credentials &creds); + void prepareApp(MessageBuffer &buffer, const Credentials &creds); }; } // namespace SecurityManager diff --git a/src/server/service/service.cpp b/src/server/service/service.cpp index 8eea5068..3a3c8e10 100644 --- a/src/server/service/service.cpp +++ b/src/server/service/service.cpp @@ -53,19 +53,10 @@ const GenericSocketService::ServiceDescription &Service::GetServiceDescription() return serviceDesc; } -bool Service::processOne(const ConnectionID &conn, MessageBuffer &buffer) +void Service::processOne(const ConnectionID &conn, Credentials &creds, MessageBuffer &buffer) { LogDebug("Iteration begin."); - - //waiting for all data - if (!buffer.Ready()) { - return false; - } - - MessageBuffer send; - Try { - Credentials &creds = m_connectionInfoMap.at(conn.counter).creds; // deserialize API call type int call_type_int; Deserialization::Deserialize(buffer, call_type_int); @@ -74,111 +65,110 @@ bool Service::processOne(const ConnectionID &conn, MessageBuffer &buffer) switch (call_type) { case SecurityModuleCall::NOOP: LogDebug("call_type: SecurityModuleCall::NOOP"); - Serialization::Serialize(send, static_cast(SECURITY_MANAGER_SUCCESS)); + Serialization::Serialize(buffer, static_cast(SECURITY_MANAGER_SUCCESS)); break; case SecurityModuleCall::APP_INSTALL: LogDebug("call_type: SecurityModuleCall::APP_INSTALL"); - processAppInstall(buffer, send, creds); + processAppInstall(buffer, creds); break; case SecurityModuleCall::APP_UPDATE: LogDebug("call_type: SecurityModuleCall::APP_UPDATE"); - processAppUpdate(buffer, send, creds); + processAppUpdate(buffer, creds); break; case SecurityModuleCall::APP_UNINSTALL: LogDebug("call_type: SecurityModuleCall::APP_UNINSTALL"); - processAppUninstall(buffer, send, creds); + processAppUninstall(buffer, creds); break; case SecurityModuleCall::APP_GET_PKG_NAME: LogDebug("call_type: SecurityModuleCall::APP_GET_PKG_NAME"); - processGetPkgName(buffer, send); + processGetPkgName(buffer); break; case SecurityModuleCall::USER_ADD: LogDebug("call_type: SecurityModuleCall::USER_ADD"); - processUserAdd(buffer, send, creds); + processUserAdd(buffer, creds); break; case SecurityModuleCall::USER_DELETE: LogDebug("call_type: SecurityModuleCall::USER_DELETE"); - processUserDelete(buffer, send, creds); + processUserDelete(buffer, creds); break; case SecurityModuleCall::POLICY_UPDATE: LogDebug("call_type: SecurityModuleCall::POLICY_UPDATE"); - processPolicyUpdate(buffer, send, creds); + processPolicyUpdate(buffer, creds); break; case SecurityModuleCall::GET_CONF_POLICY_ADMIN: LogDebug("call_type: SecurityModuleCall::GET_CONF_POLICY_ADMIN"); - processGetConfiguredPolicy(buffer, send, creds, true); + processGetConfiguredPolicy(buffer, creds, true); break; case SecurityModuleCall::GET_CONF_POLICY_SELF: LogDebug("call_type: SecurityModuleCall::GET_CONF_POLICY_SELF"); - processGetConfiguredPolicy(buffer, send, creds, false); + processGetConfiguredPolicy(buffer, creds, false); break; case SecurityModuleCall::GET_POLICY: LogDebug("call_type: SecurityModuleCall::GET_POLICY"); - processGetPolicy(buffer, send, creds); + processGetPolicy(buffer, creds); break; case SecurityModuleCall::POLICY_GET_DESCRIPTIONS: LogDebug("call_type: SecurityModuleCall::POLICY_GET_DESCRIPTIONS"); - processPolicyGetDesc(send); + processPolicyGetDesc(buffer); break; case SecurityModuleCall::GROUPS_GET: LogDebug("call_type: SecurityModuleCall::GROUPS_GET"); - processGetForbiddenAndAllowedGroups(buffer, send, creds); + processGetForbiddenAndAllowedGroups(buffer, creds); break; case SecurityModuleCall::GROUPS_FOR_UID: - processGroupsForUid(buffer, send); + processGroupsForUid(buffer); break; case SecurityModuleCall::APP_HAS_PRIVILEGE: LogDebug("call_type: SecurityModuleCall::APP_HAS_PRIVILEGE"); - processAppHasPrivilege(buffer, send, creds); + processAppHasPrivilege(buffer, creds); break; case SecurityModuleCall::APP_APPLY_PRIVATE_SHARING: LogDebug("call_type: SecurityModuleCall::APP_APPLY_PRIVATE_SHARING"); - processApplyPrivateSharing(buffer, send, creds); + processApplyPrivateSharing(buffer, creds); break; case SecurityModuleCall::APP_DROP_PRIVATE_SHARING: LogDebug("call_type: SecurityModuleCall::APP_DROP_PRIVATE_SHARING"); - processDropPrivateSharing(buffer, send, creds); + processDropPrivateSharing(buffer, creds); break; case SecurityModuleCall::PATHS_REGISTER: - processPathsRegister(buffer, send, creds); + processPathsRegister(buffer, creds); break; case SecurityModuleCall::SHM_APP_NAME: - processShmAppName(buffer, send, creds); + processShmAppName(buffer, creds); break; case SecurityModuleCall::GET_APP_DEFINED_PRIVILEGE_PROVIDER: LogDebug("call_type: SecurityModuleCall::GET_APP_DEFINED_PRIVILEGE_PROVIDER"); - processGetAppDefinedPrivilegeProvider(buffer, send); + processGetAppDefinedPrivilegeProvider(buffer); break; case SecurityModuleCall::GET_APP_DEFINED_PRIVILEGE_LICENSE: LogDebug("call_type: SecurityModuleCall::GET_APP_DEFINED_PRIVILEGE_LICENSE"); - processGetAppDefinedPrivilegeLicense(buffer, send); + processGetAppDefinedPrivilegeLicense(buffer); break; case SecurityModuleCall::GET_CLIENT_PRIVILEGE_LICENSE: LogDebug("call_type: SecurityModuleCall::GET_CLIENT_PRIVILEGE_PROVIDER"); - processGetClientPrivilegeLicense(buffer, send); + processGetClientPrivilegeLicense(buffer); break; case SecurityModuleCall::APP_CLEAN_NAMESPACE: LogDebug("call_type: SecurityModuleCall::APP_CLEAN_NAMESPACE"); - processAppCleanNamespace(buffer, send, creds); + processAppCleanNamespace(buffer, creds); break; case SecurityModuleCall::GET_APP_MANIFEST_POLICY: LogDebug("call_type: SecurityModuleCall::GET_APP_MANIFEST_POLICY"); - processAppGetManifestPolicy(buffer, send, creds); + processAppGetManifestPolicy(buffer, creds); break; case SecurityModuleCall::GET_PROCESS_LABEL: - processGetProcessLabel(buffer, send); + processGetProcessLabel(buffer); break; case SecurityModuleCall::PREPARE_APP: - prepareApp(buffer, send, creds); + prepareApp(buffer, creds); break; default: LogError("Invalid call: " << call_type_int); Throw(ServiceException::InvalidAction); } // if we reach this point, the protocol is OK - LogDebug("Writing response to client, size of serialized response: " << send.SerializedSize()); - m_serviceManager->Write(conn, send.Pop()); - return true; + LogDebug("Writing response to client."); + return m_serviceManager->Write(conn, std::move(buffer)); } Catch(MessageBuffer::Exception::Base) { LogError("Broken protocol."); } Catch(ServiceException::Base) { @@ -191,11 +181,9 @@ bool Service::processOne(const ConnectionID &conn, MessageBuffer &buffer) LogError("Closing socket because of error"); m_serviceManager->Close(conn); - - return false; } -void Service::processAppGetManifestPolicy(MessageBuffer &buffer, MessageBuffer &send, const Credentials &creds) +void Service::processAppGetManifestPolicy(MessageBuffer &buffer, const Credentials &creds) { std::string appName; uid_t uid; @@ -207,39 +195,43 @@ void Service::processAppGetManifestPolicy(MessageBuffer &buffer, MessageBuffer & ret = serviceImpl.getAppManifestPolicy(creds, appName, uid, privileges); - Serialization::Serialize(send, ret); + buffer.ModeStreaming(); + Serialization::Serialize(buffer, ret); if (ret == SECURITY_MANAGER_SUCCESS) { - Serialization::Serialize(send, static_cast(privileges.size())); + Serialization::Serialize(buffer, static_cast(privileges.size())); for (const auto &privilege : privileges) - Serialization::Serialize(send, privilege); + Serialization::Serialize(buffer, privilege); } } -void Service::processAppInstall(MessageBuffer &buffer, MessageBuffer &send, const Credentials &creds) +void Service::processAppInstall(MessageBuffer &buffer, const Credentials &creds) { app_inst_req req; Deserialization::Deserialize(buffer, req); - Serialization::Serialize(send, serviceImpl.appInstall(creds, req)); + buffer.ModeStreaming(); + Serialization::Serialize(buffer, serviceImpl.appInstall(creds, req)); } -void Service::processAppUpdate(MessageBuffer &buffer, MessageBuffer &send, const Credentials &creds) +void Service::processAppUpdate(MessageBuffer &buffer, const Credentials &creds) { app_inst_req req; Deserialization::Deserialize(buffer, req); - Serialization::Serialize(send, serviceImpl.appUpdate(creds, req)); + buffer.ModeStreaming(); + Serialization::Serialize(buffer, serviceImpl.appUpdate(creds, req)); } -void Service::processAppUninstall(MessageBuffer &buffer, MessageBuffer &send, const Credentials &creds) +void Service::processAppUninstall(MessageBuffer &buffer, const Credentials &creds) { app_inst_req req; Deserialization::Deserialize(buffer, req); - Serialization::Serialize(send, serviceImpl.appUninstall(creds, req)); + buffer.ModeStreaming(); + Serialization::Serialize(buffer, serviceImpl.appUninstall(creds, req)); } -void Service::processGetPkgName(MessageBuffer &buffer, MessageBuffer &send) +void Service::processGetPkgName(MessageBuffer &buffer) { std::string appName; std::string pkgName; @@ -247,12 +239,13 @@ void Service::processGetPkgName(MessageBuffer &buffer, MessageBuffer &send) Deserialization::Deserialize(buffer, appName); ret = serviceImpl.getPkgName(appName, pkgName); - Serialization::Serialize(send, ret); + buffer.ModeStreaming(); + Serialization::Serialize(buffer, ret); if (ret == SECURITY_MANAGER_SUCCESS) - Serialization::Serialize(send, pkgName); + Serialization::Serialize(buffer, pkgName); } -void Service::processUserAdd(MessageBuffer &buffer, MessageBuffer &send, const Credentials &creds) +void Service::processUserAdd(MessageBuffer &buffer, const Credentials &creds) { int ret; uid_t uidAdded; @@ -262,10 +255,11 @@ void Service::processUserAdd(MessageBuffer &buffer, MessageBuffer &send, const C Deserialization::Deserialize(buffer, userType); ret = serviceImpl.userAdd(creds, uidAdded, userType); - Serialization::Serialize(send, ret); + buffer.ModeStreaming(); + Serialization::Serialize(buffer, ret); } -void Service::processUserDelete(MessageBuffer &buffer, MessageBuffer &send, const Credentials &creds) +void Service::processUserDelete(MessageBuffer &buffer, const Credentials &creds) { int ret; uid_t uidRemoved; @@ -273,10 +267,11 @@ void Service::processUserDelete(MessageBuffer &buffer, MessageBuffer &send, cons Deserialization::Deserialize(buffer, uidRemoved); ret = serviceImpl.userDelete(creds, uidRemoved); - Serialization::Serialize(send, ret); + buffer.ModeStreaming(); + Serialization::Serialize(buffer, ret); } -void Service::processPolicyUpdate(MessageBuffer &buffer, MessageBuffer &send, const Credentials &creds) +void Service::processPolicyUpdate(MessageBuffer &buffer, const Credentials &creds) { int ret; std::vector policyEntries; @@ -284,10 +279,11 @@ void Service::processPolicyUpdate(MessageBuffer &buffer, MessageBuffer &send, co Deserialization::Deserialize(buffer, policyEntries); ret = serviceImpl.policyUpdate(creds, policyEntries); - Serialization::Serialize(send, ret); + buffer.ModeStreaming(); + Serialization::Serialize(buffer, ret); } -void Service::processGetConfiguredPolicy(MessageBuffer &buffer, MessageBuffer &send, const Credentials &creds, bool forAdmin) +void Service::processGetConfiguredPolicy(MessageBuffer &buffer, const Credentials &creds, bool forAdmin) { int ret; policy_entry filter; @@ -296,14 +292,15 @@ void Service::processGetConfiguredPolicy(MessageBuffer &buffer, MessageBuffer &s ret = serviceImpl.getConfiguredPolicy(creds, forAdmin, filter, policyEntries); - Serialization::Serialize(send, ret); - Serialization::Serialize(send, static_cast(policyEntries.size())); + buffer.ModeStreaming(); + Serialization::Serialize(buffer, ret); + Serialization::Serialize(buffer, static_cast(policyEntries.size())); for (const auto &policyEntry : policyEntries) { - Serialization::Serialize(send, policyEntry); + Serialization::Serialize(buffer, policyEntry); }; } -void Service::processGetPolicy(MessageBuffer &buffer, MessageBuffer &send, const Credentials &creds) +void Service::processGetPolicy(MessageBuffer &buffer, const Credentials &creds) { int ret; policy_entry filter; @@ -312,31 +309,33 @@ void Service::processGetPolicy(MessageBuffer &buffer, MessageBuffer &send, const ret = serviceImpl.getPolicy(creds, filter, policyEntries); - Serialization::Serialize(send, ret); - Serialization::Serialize(send, static_cast(policyEntries.size())); + buffer.ModeStreaming(); + Serialization::Serialize(buffer, ret); + Serialization::Serialize(buffer, static_cast(policyEntries.size())); for (const auto &policyEntry : policyEntries) { - Serialization::Serialize(send, policyEntry); + Serialization::Serialize(buffer, policyEntry); }; } -void Service::processPolicyGetDesc(MessageBuffer &send) +void Service::processPolicyGetDesc(MessageBuffer &buffer) { int ret; std::vector descriptions; ret = serviceImpl.policyGetDesc(descriptions); - Serialization::Serialize(send, ret); + buffer.ModeStreaming(); + Serialization::Serialize(buffer, ret); if (ret == SECURITY_MANAGER_SUCCESS) { - Serialization::Serialize(send, static_cast(descriptions.size())); + Serialization::Serialize(buffer, static_cast(descriptions.size())); for (std::vector::size_type i = 0; i != descriptions.size(); i++) { - Serialization::Serialize(send, descriptions[i]); + Serialization::Serialize(buffer, descriptions[i]); } } } -void Service::processGetForbiddenAndAllowedGroups(MessageBuffer &buffer, MessageBuffer &send, const Credentials &creds) +void Service::processGetForbiddenAndAllowedGroups(MessageBuffer &buffer, const Credentials &creds) { std::string appName; std::vector forbiddenGroups, allowedGroups; @@ -346,94 +345,100 @@ void Service::processGetForbiddenAndAllowedGroups(MessageBuffer &buffer, Message std::string label = serviceImpl.getProcessLabel(appName); std::vector allowedPrivileges; int ret = serviceImpl.getAppAllowedPrivileges(label, creds.uid, allowedPrivileges); - if (ret != SECURITY_MANAGER_SUCCESS) { + if (ret == SECURITY_MANAGER_SUCCESS) + ret = serviceImpl.getForbiddenAndAllowedGroups(label, allowedPrivileges, forbiddenGroups, + allowedGroups); + else LogError("Failed to fetch allowed privileges for " << label); - Serialization::Serialize(send, ret); - return; - } - ret = serviceImpl.getForbiddenAndAllowedGroups(label, allowedPrivileges, forbiddenGroups, - allowedGroups); - Serialization::Serialize(send, ret); + + buffer.ModeStreaming(); + Serialization::Serialize(buffer, ret); if (ret == SECURITY_MANAGER_SUCCESS) { - Serialization::Serialize(send, forbiddenGroups, allowedGroups); + Serialization::Serialize(buffer, forbiddenGroups, allowedGroups); } } -void Service::processGroupsForUid(MessageBuffer &recv, MessageBuffer &send) +void Service::processGroupsForUid(MessageBuffer &buffer) { uid_t uid; std::vector groups; - Deserialization::Deserialize(recv, uid); + Deserialization::Deserialize(buffer, uid); int ret = serviceImpl.policyGroupsForUid(uid, groups); - Serialization::Serialize(send, ret); + buffer.ModeStreaming(); + Serialization::Serialize(buffer, ret); if (ret == SECURITY_MANAGER_SUCCESS) { - Serialization::Serialize(send, groups); + Serialization::Serialize(buffer, groups); } } -void Service::processAppHasPrivilege(MessageBuffer &recv, MessageBuffer &send, const Credentials &creds) +void Service::processAppHasPrivilege(MessageBuffer &buffer, const Credentials &creds) { std::string appName; std::string privilege; uid_t uid; - Deserialization::Deserialize(recv, appName); - Deserialization::Deserialize(recv, privilege); - Deserialization::Deserialize(recv, uid); + Deserialization::Deserialize(buffer, appName); + Deserialization::Deserialize(buffer, privilege); + Deserialization::Deserialize(buffer, uid); bool result; int ret = serviceImpl.appHasPrivilege(creds, appName, privilege, uid, result); - Serialization::Serialize(send, ret); + buffer.ModeStreaming(); + Serialization::Serialize(buffer, ret); if (ret == SECURITY_MANAGER_SUCCESS) - Serialization::Serialize(send, static_cast(result)); + Serialization::Serialize(buffer, static_cast(result)); } -void Service::processApplyPrivateSharing(MessageBuffer &recv, MessageBuffer &send, const Credentials &creds) +void Service::processApplyPrivateSharing(MessageBuffer &buffer, const Credentials &creds) { std::string ownerAppName, targetAppName; std::vector paths; - Deserialization::Deserialize(recv, ownerAppName); - Deserialization::Deserialize(recv, targetAppName); - Deserialization::Deserialize(recv, paths); + Deserialization::Deserialize(buffer, ownerAppName); + Deserialization::Deserialize(buffer, targetAppName); + Deserialization::Deserialize(buffer, paths); int ret = serviceImpl.applyPrivatePathSharing(creds, ownerAppName, targetAppName, paths); - Serialization::Serialize(send, ret); + buffer.ModeStreaming(); + Serialization::Serialize(buffer, ret); } -void Service::processDropPrivateSharing(MessageBuffer &recv, MessageBuffer &send, const Credentials &creds) +void Service::processDropPrivateSharing(MessageBuffer &buffer, const Credentials &creds) { std::string ownerAppName, targetAppName; std::vector paths; - Deserialization::Deserialize(recv, ownerAppName); - Deserialization::Deserialize(recv, targetAppName); - Deserialization::Deserialize(recv, paths); + Deserialization::Deserialize(buffer, ownerAppName); + Deserialization::Deserialize(buffer, targetAppName); + Deserialization::Deserialize(buffer, paths); int ret = serviceImpl.dropPrivatePathSharing(creds, ownerAppName, targetAppName, paths); - Serialization::Serialize(send, ret); + buffer.ModeStreaming(); + Serialization::Serialize(buffer, ret); } -void Service::processPathsRegister(MessageBuffer &recv, MessageBuffer &send, const Credentials &creds) +void Service::processPathsRegister(MessageBuffer &buffer, const Credentials &creds) { path_req req; - Deserialization::Deserialize(recv, req.pkgName); - Deserialization::Deserialize(recv, req.uid); - Deserialization::Deserialize(recv, req.pkgPaths); - Deserialization::Deserialize(recv, req.installationType); + Deserialization::Deserialize(buffer, req.pkgName); + Deserialization::Deserialize(buffer, req.uid); + Deserialization::Deserialize(buffer, req.pkgPaths); + Deserialization::Deserialize(buffer, req.installationType); int ret = serviceImpl.pathsRegister(creds, std::move(req)); - Serialization::Serialize(send, ret); + buffer.ModeStreaming(); + Serialization::Serialize(buffer, ret); } -void Service::processShmAppName(MessageBuffer &recv, MessageBuffer &send, const Credentials &creds) +void Service::processShmAppName(MessageBuffer &buffer, const Credentials &creds) { std::string shmName, appName; - Deserialization::Deserialize(recv, shmName, appName); + Deserialization::Deserialize(buffer, shmName, appName); int ret = serviceImpl.shmAppName(creds, shmName, appName); - Serialization::Serialize(send, ret); + buffer.ModeStreaming(); + Serialization::Serialize(buffer, ret); } -void Service::processGetAppDefinedPrivilegeProvider(MessageBuffer &buffer, MessageBuffer &send) +void Service::processGetAppDefinedPrivilegeProvider(MessageBuffer &buffer) { int ret; std::string privilege, appName, pkgName; @@ -441,12 +446,13 @@ void Service::processGetAppDefinedPrivilegeProvider(MessageBuffer &buffer, Messa Deserialization::Deserialize(buffer, uid, privilege); ret = serviceImpl.getAppDefinedPrivilegeProvider(uid, privilege, appName, pkgName); - Serialization::Serialize(send, ret); + buffer.ModeStreaming(); + Serialization::Serialize(buffer, ret); if (ret == SECURITY_MANAGER_SUCCESS) - Serialization::Serialize(send, appName, pkgName); + Serialization::Serialize(buffer, appName, pkgName); } -void Service::processGetAppDefinedPrivilegeLicense(MessageBuffer &buffer, MessageBuffer &send) +void Service::processGetAppDefinedPrivilegeLicense(MessageBuffer &buffer) { int ret; std::string privilege, license; @@ -454,12 +460,13 @@ void Service::processGetAppDefinedPrivilegeLicense(MessageBuffer &buffer, Messag Deserialization::Deserialize(buffer, uid, privilege); ret = serviceImpl.getAppDefinedPrivilegeLicense(uid, privilege, license); - Serialization::Serialize(send, ret); + buffer.ModeStreaming(); + Serialization::Serialize(buffer, ret); if (ret == SECURITY_MANAGER_SUCCESS) - Serialization::Serialize(send, license); + Serialization::Serialize(buffer, license); } -void Service::processGetClientPrivilegeLicense(MessageBuffer &buffer, MessageBuffer &send) +void Service::processGetClientPrivilegeLicense(MessageBuffer &buffer) { int ret; std::string appName, pkgName, privilege, license; @@ -467,30 +474,33 @@ void Service::processGetClientPrivilegeLicense(MessageBuffer &buffer, MessageBuf Deserialization::Deserialize(buffer, appName, pkgName, uid, privilege); ret = serviceImpl.getClientPrivilegeLicense(appName, pkgName, uid, privilege, license); - Serialization::Serialize(send, ret); + buffer.ModeStreaming(); + Serialization::Serialize(buffer, ret); if (ret == SECURITY_MANAGER_SUCCESS) - Serialization::Serialize(send, license); + Serialization::Serialize(buffer, license); } -void Service::processAppCleanNamespace(MessageBuffer &buffer, MessageBuffer &send, const Credentials &creds) +void Service::processAppCleanNamespace(MessageBuffer &buffer, const Credentials &creds) { std::string appName; uid_t uid; pid_t pid; Deserialization::Deserialize(buffer, appName, uid, pid); int ret = serviceImpl.appCleanNamespace(creds, appName, uid, pid); - Serialization::Serialize(send, ret); + buffer.ModeStreaming(); + Serialization::Serialize(buffer, ret); } -void Service::processGetProcessLabel(MessageBuffer &buffer, MessageBuffer &send) +void Service::processGetProcessLabel(MessageBuffer &buffer) { std::string appName; Deserialization::Deserialize(buffer, appName); - Serialization::Serialize(send, SECURITY_MANAGER_SUCCESS); - Serialization::Serialize(send, serviceImpl.getProcessLabel(appName)); + buffer.ModeStreaming(); + Serialization::Serialize(buffer, SECURITY_MANAGER_SUCCESS); + Serialization::Serialize(buffer, serviceImpl.getProcessLabel(appName)); } -void Service::prepareApp(MessageBuffer &buffer, MessageBuffer &send, const Credentials &creds) +void Service::prepareApp(MessageBuffer &buffer, const Credentials &creds) { std::string appName, pkgName, label; PrepareAppFlags prepareAppFlags; @@ -500,9 +510,10 @@ void Service::prepareApp(MessageBuffer &buffer, MessageBuffer &send, const Crede Deserialization::Deserialize(buffer, appName, privPathsVector); int ret = serviceImpl.prepareApp(creds, appName, privPathsVector, label, pkgName, prepareAppFlags, forbiddenGroups, allowedGroups, privPathsStatusVector); - Serialization::Serialize(send, ret); + buffer.ModeStreaming(); + Serialization::Serialize(buffer, ret); if (ret == SECURITY_MANAGER_SUCCESS) - Serialization::Serialize(send, forbiddenGroups, allowedGroups, privPathsStatusVector, label, pkgName, prepareAppFlags); + Serialization::Serialize(buffer, forbiddenGroups, allowedGroups, privPathsStatusVector, label, pkgName, prepareAppFlags); } } // namespace SecurityManager diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 11e79554..5f93fb06 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -94,6 +94,7 @@ SET(SM_TESTS_SOURCES ${SM_TEST_SRC}/test_smack-labels.cpp ${SM_TEST_SRC}/test_smack-rules.cpp ${SM_TEST_SRC}/test_check_proper_drop.cpp + ${SM_TEST_SRC}/test_message_buffer.cpp ${SM_TEST_SRC}/test_misc.cpp ${SM_TEST_SRC}/test_template-manager.cpp ${DPL_PATH}/core/src/assert.cpp diff --git a/test/test_message_buffer.cpp b/test/test_message_buffer.cpp new file mode 100644 index 00000000..c9ff0337 --- /dev/null +++ b/test/test_message_buffer.cpp @@ -0,0 +1,157 @@ +/* + * Copyright (c) 2022 Samsung Electronics Co., Ltd. All rights reserved. + * + * This file is licensed under the terms of MIT License or the Apache License + * Version 2.0 of your choice. See the LICENSE.MIT file for MIT license details. + * See the LICENSE file or the notice below for Apache License Version 2.0 + * details. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +using namespace SecurityManager; +using InputResult = MessageBuffer::InputResult; + +BOOST_AUTO_TEST_SUITE(MESSAGE_BUFFER_TEST) + +POSITIVE_TEST_CASE(T1200_messageBuffer) +{ + MessageBuffer buffer; + buffer.InitForInput(); + + constexpr char payload[] = {'a', 'b', 'c', 'd', 'e', 'f', '\0'}; + constexpr size_t size = sizeof payload; + + auto ptr = buffer.Ptr(); + BOOST_REQUIRE(ptr); + memcpy(ptr, &size, sizeof size); + for (unsigned i = 0; i < sizeof size; i++) { + BOOST_REQUIRE(buffer.InputSize() >= sizeof size - i); + BOOST_REQUIRE_EQUAL(buffer.Ptr(), &ptr[i]); + BOOST_REQUIRE(buffer.InputDone(1) == InputResult::Pending); + } + ptr = buffer.Ptr(); + BOOST_REQUIRE(ptr); + memcpy(ptr, payload, sizeof payload); + BOOST_REQUIRE(buffer.InputSize() >= sizeof payload); + for (unsigned i = 0; i < sizeof payload - 1; i++) { + BOOST_REQUIRE(buffer.InputDone(1) == InputResult::Pending); + BOOST_REQUIRE(buffer.InputSize() >= sizeof payload - (i+1)); + BOOST_REQUIRE_EQUAL(buffer.Ptr(), &ptr[i+1]); + } + BOOST_REQUIRE(buffer.InputDone(1) == InputResult::Done); + + buffer.ModeStreaming(); + + for (unsigned i = 0; i < sizeof payload; i++) { + char c; + Deserialization::Deserialize(buffer, c); + BOOST_REQUIRE_EQUAL(c, payload[i]); + } + BOOST_REQUIRE(buffer.DeserializationDone()); + + bool thrown = false; + try { + char c; + Deserialization::Deserialize(buffer, c); + } catch (...) { + thrown = true; + } + BOOST_REQUIRE(thrown); + + buffer.ModeStreaming(); + constexpr int End = 65537; + for (int i = 0; i < End; i++) + Serialization::Serialize(buffer, i); + + buffer.ModeOutput(); + BOOST_REQUIRE_EQUAL(buffer.OutputSize(), sizeof(size_t) + End * sizeof(int)); + ptr = buffer.Ptr(); + BOOST_REQUIRE_EQUAL(*reinterpret_cast(buffer.Ptr()), End * sizeof(int)); + BOOST_REQUIRE(!buffer.OutputDone(sizeof(size_t))); + BOOST_REQUIRE_EQUAL(buffer.Ptr(), &ptr[sizeof(size_t)]); + for (int i = 0; i < End - 1; i++) { + BOOST_REQUIRE_EQUAL(*reinterpret_cast(buffer.Ptr()), i); + BOOST_REQUIRE(!buffer.OutputDone(sizeof(int))); + BOOST_REQUIRE_EQUAL(buffer.OutputSize(), (End - (i+1)) * sizeof(int)); + BOOST_REQUIRE_EQUAL(buffer.Ptr(), &ptr[sizeof(size_t) + (i+1) * sizeof(int)]); + } + BOOST_REQUIRE_EQUAL(*reinterpret_cast(buffer.Ptr()), End - 1); + BOOST_REQUIRE(buffer.OutputDone(sizeof(int))); + + buffer.ModeStreaming(); + for (int i = 0; i < End; i++) { + int j; + Deserialization::Deserialize(buffer, j); + BOOST_REQUIRE_EQUAL(i, j); + } + BOOST_REQUIRE(buffer.DeserializationDone()); + + thrown = false; + try { + char c; + Deserialization::Deserialize(buffer, c); + } catch (...) { + thrown = true; + } + BOOST_REQUIRE(thrown); +} + +NEGATIVE_TEST_CASE(T1201_messageBuffer_excessTrailingBytes) +{ + MessageBuffer buffer; + buffer.InitBuffer(); + + // reading of a message with a huge payload + + buffer.ModeInput(); + + constexpr size_t hugePayloadSize = 1000*1000*1000; + size_t size = hugePayloadSize; + + BOOST_REQUIRE(buffer.Ptr()); + BOOST_REQUIRE(buffer.InputSize() >= sizeof size); + memcpy(buffer.Ptr(), &size, sizeof size); + BOOST_REQUIRE(buffer.InputDone(sizeof size) == InputResult::Pending); + + BOOST_REQUIRE(buffer.Ptr()); + BOOST_REQUIRE(buffer.InputSize() >= hugePayloadSize); + BOOST_REQUIRE(buffer.InputDone(hugePayloadSize) == InputResult::Done); + + // protocol broken by an excess trailing byte + + buffer.ModeInput(); + + size--; + + BOOST_REQUIRE(buffer.Ptr()); + BOOST_REQUIRE(buffer.InputSize() >= sizeof size + hugePayloadSize); + memcpy(buffer.Ptr(), &size, sizeof size); + BOOST_REQUIRE(buffer.InputDone(sizeof size + hugePayloadSize) == InputResult::ProtocolBroken); + + // protocol breakage must not break the buffer itself - check one huge input + + buffer.ModeInput(); + + size++; + + BOOST_REQUIRE(buffer.Ptr()); + BOOST_REQUIRE(buffer.InputSize() >= sizeof size + hugePayloadSize); + memcpy(buffer.Ptr(), &size, sizeof size); + BOOST_REQUIRE(buffer.InputDone(sizeof size + hugePayloadSize) == InputResult::Done); +} + +BOOST_AUTO_TEST_SUITE_END()