Protect deserialization against malicious clients 35/182235/4
authorTomasz Swierczek <t.swierczek@samsung.com>
Thu, 21 Jun 2018 09:35:43 +0000 (11:35 +0200)
committerTomasz Swierczek <t.swierczek@samsung.com>
Tue, 26 Jun 2018 04:35:54 +0000 (06:35 +0200)
Added protection against memory leaks when deserializing data
of bad size & detection of invalid STL sizes.

Change-Id: Ia2781b352585ce32e401ca3830b8304e43233e5c

src/dpl/core/include/dpl/serialization.h
tests/CMakeLists.txt

index d078c8d1f5424b0e98e0ddc8d1098d493d632ed0..f18b53f90df115b9677483816e79b8a241b35144 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2011 Samsung Electronics Co., Ltd All Rights Reserved
+ * Copyright (c) 2011 - 2018 Samsung Electronics Co., Ltd All Rights Reserved
  *
  *    Licensed under the Apache License, Version 2.0 (the "License");
  *    you may not use this file except in compliance with the License.
 #include <map>
 #include <memory>
 
+#include <dpl/exception.h>
+
 namespace AuthPasswd {
+
+class SerializationException {
+public:
+       DECLARE_EXCEPTION_TYPE(AuthPasswd::Exception, Base)
+       DECLARE_EXCEPTION_TYPE(Base, InvalidStreamData)
+};
+
 // Abstract data stream buffer
 class IStream {
 public:
@@ -223,7 +232,9 @@ struct Deserialization {
        }
        static void Deserialize(IStream &stream, char *&value) {
                value = new char;
+               std::unique_ptr<char> ptr(value);
                stream.Read(sizeof(*value), value);
+               ptr.release();
        }
 
        // unsigned char
@@ -232,7 +243,9 @@ struct Deserialization {
        }
        static void Deserialize(IStream &stream, unsigned char *&value) {
                value = new unsigned char;
+               std::unique_ptr<unsigned char> ptr(value);
                stream.Read(sizeof(*value), value);
+               ptr.release();
        }
 
        // unsigned int
@@ -241,7 +254,9 @@ struct Deserialization {
        }
        static void Deserialize(IStream &stream, unsigned *&value) {
                value = new unsigned;
+               std::unique_ptr<unsigned> ptr(value);
                stream.Read(sizeof(*value), value);
+               ptr.release();
        }
 
        // int
@@ -250,7 +265,9 @@ struct Deserialization {
        }
        static void Deserialize(IStream &stream, int *&value) {
                value = new int;
+               std::unique_ptr<int> ptr(value);
                stream.Read(sizeof(*value), value);
+               ptr.release();
        }
 
        // bool
@@ -259,7 +276,9 @@ struct Deserialization {
        }
        static void Deserialize(IStream &stream, bool *&value) {
                value = new bool;
+               std::unique_ptr<bool> ptr(value);
                stream.Read(sizeof(*value), value);
+               ptr.release();
        }
 
        // time_t
@@ -268,27 +287,33 @@ struct Deserialization {
        }
        static void Deserialize(IStream &stream, time_t *&value) {
                value = new time_t;
+               std::unique_ptr<time_t> ptr(value);
                stream.Read(sizeof(*value), value);
+               ptr.release();
        }
 
        // std::string
        static void Deserialize(IStream &stream, std::string &str) {
                int length;
                stream.Read(sizeof(length), &length);
-               char *buf = new char[length + 1];
+               if (length < 0)
+                       ThrowMsg(SerializationException::InvalidStreamData, "Invalid length of std::string (less than 0)");
+               char * buf = new char[length + 1];
+               std::unique_ptr<char[]> ptr(buf);
                stream.Read(length, buf);
                buf[length] = 0;
                str = std::string(buf);
-               delete[] buf;
        }
        static void Deserialize(IStream &stream, std::string *&str) {
                int length;
                stream.Read(sizeof(length), &length);
-               char *buf = new char[length + 1];
+               if (length < 0)
+                       ThrowMsg(SerializationException::InvalidStreamData, "Invalid length of std::string (less than 0)");
+               char * buf = new char[length + 1];
+               std::unique_ptr<char[]> ptr(buf);
                stream.Read(length, buf);
                buf[length] = 0;
                str = new std::string(buf);
-               delete[] buf;
        }
 
        // STL templates
@@ -298,7 +323,8 @@ struct Deserialization {
        static void Deserialize(IStream &stream, std::list<T> &list) {
                int length;
                stream.Read(sizeof(length), &length);
-
+               if (length < 0)
+                       ThrowMsg(SerializationException::InvalidStreamData, "Invalid length of std::list (less than 0)");
                for (int i = 0; i < length; ++i) {
                        T obj;
                        Deserialize(stream, obj);
@@ -308,7 +334,9 @@ struct Deserialization {
        template <typename T>
        static void Deserialize(IStream &stream, std::list<T> *&list) {
                list = new std::list<T>;
+               std::unique_ptr<std::list<T>> ptr(list);
                Deserialize(stream, *list);
+               ptr.release();
        }
 
        // std::vector
@@ -316,7 +344,8 @@ struct Deserialization {
        static void Deserialize(IStream &stream, std::vector<T> &vec) {
                int length;
                stream.Read(sizeof(length), &length);
-
+               if (length < 0)
+                       ThrowMsg(SerializationException::InvalidStreamData, "Invalid length of std::vector (less than 0)");
                for (int i = 0; i < length; ++i) {
                        T obj;
                        Deserialize(stream, obj);
@@ -326,7 +355,9 @@ struct Deserialization {
        template <typename T>
        static void Deserialize(IStream &stream, std::vector<T> *&vec) {
                vec = new std::vector<T>;
+               std::unique_ptr<std::vector<T>> ptr(vec);
                Deserialize(stream, *vec);
+               ptr.release();
        }
 
        // std::set
@@ -334,7 +365,6 @@ struct Deserialization {
        static void Deserialize(IStream &stream, std::set<T> &set) {
                size_t length;
                stream.Read(sizeof(length), &length);
-
                for (size_t i = 0; i < length; ++i) {
                        T obj;
                        Deserialize(stream, obj);
@@ -351,7 +381,30 @@ struct Deserialization {
        template <typename A, typename B>
        static void Deserialize(IStream &stream, std::pair<A, B> *&p) {
                p = new std::pair<A, B>;
+               std::unique_ptr<std::pair<A, B>> ptr(p);
                Deserialize(stream, *p);
+               ptr.release();
+       }
+
+       // std::tuple
+       template <std::size_t I = 0, typename... Tp>
+       static inline typename std::enable_if<I == sizeof...(Tp), void>::type
+       Deserialize(IStream&, std::tuple<Tp...>&)
+       {}
+
+       template <std::size_t I = 0, typename... Tp>
+       static inline typename std::enable_if<I < sizeof...(Tp), void>::type
+       Deserialize(IStream& stream, std::tuple<Tp...>& t) {
+               Deserialize(stream, std::get<I>(t));
+               Deserialize<I+1>(stream, t);
+       }
+
+       template <typename... Tp>
+       static void Deserialize(IStream &stream, std::tuple<Tp...> *&t) {
+               t = new std::tuple<Tp...>;
+               std::unique_ptr<std::tuple<Tp...>> ptr(t);
+               Deserialize(stream, *t);
+               ptr.release();
        }
 
        // std::map
@@ -359,7 +412,8 @@ struct Deserialization {
        static void Deserialize(IStream &stream, std::map<K, T> &map) {
                int length;
                stream.Read(sizeof(length), &length);
-
+               if (length < 0)
+                       ThrowMsg(SerializationException::InvalidStreamData, "Invalid size of std::map (less than 0)");
                for (int i = 0; i < length; ++i) {
                        K key;
                        T obj;
@@ -371,7 +425,15 @@ struct Deserialization {
        template <typename K, typename T>
        static void Deserialize(IStream &stream, std::map<K, T> *&map) {
                map = new std::map<K, T>;
+               std::unique_ptr<std::map<K, T>> ptr(map);
                Deserialize(stream, *map);
+               ptr.release();
+       }
+
+       template<typename T1, typename T2, typename... Tail>
+       static void Deserialize(IStream &stream, T1 &first, T2 &second, Tail&... tail) {
+               Deserialization::Deserialize(stream, first);
+               Deserialization::Deserialize(stream, second, tail...);
        }
 }; // struct Deserialization
 } // namespace AuthPasswd
index 64746c6cd3fea589ee9d499d66dafe9956ff7dae..9e401e5b57150ffd341fb2fc83fb0e93dfca95e2 100644 (file)
@@ -32,6 +32,7 @@ INCLUDE_DIRECTORIES(SYSTEM ${${TARGET_TEST}_DEP_INCLUDE_DIRS}
                                                   ${SERVER_PATH}/service/include
                                                   ${DPL_PATH}/core/include
                                                   ${DPL_PATH}/log/include
+                                                  ${COMMON_PATH}/include
                                                   ${PLUGIN_PATH})
 
 ADD_EXECUTABLE(${TARGET_TEST} ${TEST_SRCS})