socket: add support to control no. of connections from one source (#3607)
authorSusant Sahani <ssahani@users.noreply.github.com>
Tue, 2 Aug 2016 17:48:23 +0000 (23:18 +0530)
committerZbigniew Jędrzejewski-Szmek <zbyszek@in.waw.pl>
Tue, 2 Aug 2016 17:48:23 +0000 (13:48 -0400)
Introduce MaxConnectionsPerSource= that is number of concurrent
connections allowed per IP.

RFE: 1939

man/systemd.socket.xml
src/core/dbus-socket.c
src/core/load-fragment-gperf.gperf.m4
src/core/service.c
src/core/service.h
src/core/socket.c
src/core/socket.h

index 5bf54d8..26e5d3c 100644 (file)
       </varlistentry>
 
       <varlistentry>
+        <term><varname>MaxConnectionsPerSource=</varname></term>
+        <listitem><para>The maximum number of connections for a service per source IP address.
+        This is is very similar to the <varname>MaxConnections=</varname> directive
+        above. Disabled by default.</para>
+        </listitem>
+      </varlistentry>
+
+       <varlistentry>
         <term><varname>KeepAlive=</varname></term>
         <listitem><para>Takes a boolean argument. If true, the TCP/IP
         stack will send a keep alive message after 2h (depending on
index 9613406..9a071a1 100644 (file)
@@ -137,6 +137,7 @@ const sd_bus_vtable bus_socket_vtable[] = {
         SD_BUS_PROPERTY("Symlinks", "as", NULL, offsetof(Socket, symlinks), SD_BUS_VTABLE_PROPERTY_CONST),
         SD_BUS_PROPERTY("Mark", "i", bus_property_get_int, offsetof(Socket, mark), SD_BUS_VTABLE_PROPERTY_CONST),
         SD_BUS_PROPERTY("MaxConnections", "u", bus_property_get_unsigned, offsetof(Socket, max_connections), SD_BUS_VTABLE_PROPERTY_CONST),
+        SD_BUS_PROPERTY("MaxConnectionsPerSource", "u", bus_property_get_unsigned, offsetof(Socket, max_connections_per_source), SD_BUS_VTABLE_PROPERTY_CONST),
         SD_BUS_PROPERTY("MessageQueueMaxMessages", "x", bus_property_get_long, offsetof(Socket, mq_maxmsg), SD_BUS_VTABLE_PROPERTY_CONST),
         SD_BUS_PROPERTY("MessageQueueMessageSize", "x", bus_property_get_long, offsetof(Socket, mq_msgsize), SD_BUS_VTABLE_PROPERTY_CONST),
         SD_BUS_PROPERTY("ReusePort", "b",  bus_property_get_bool, offsetof(Socket, reuse_port), SD_BUS_VTABLE_PROPERTY_CONST),
index c9cdbe8..396f847 100644 (file)
@@ -293,6 +293,7 @@ Socket.DirectoryMode,            config_parse_mode,                  0,
 Socket.Accept,                   config_parse_bool,                  0,                             offsetof(Socket, accept)
 Socket.Writable,                 config_parse_bool,                  0,                             offsetof(Socket, writable)
 Socket.MaxConnections,           config_parse_unsigned,              0,                             offsetof(Socket, max_connections)
+Socket.MaxConnectionsPerSource,  config_parse_unsigned,              0,                             offsetof(Socket, max_connections_per_source)
 Socket.KeepAlive,                config_parse_bool,                  0,                             offsetof(Socket, keep_alive)
 Socket.KeepAliveTimeSec,         config_parse_sec,                   0,                             offsetof(Socket, keep_alive_time)
 Socket.KeepAliveIntervalSec,     config_parse_sec,                   0,                             offsetof(Socket, keep_alive_interval)
index 4d59d78..eb125cb 100644 (file)
@@ -342,6 +342,7 @@ static void service_done(Unit *u) {
         s->bus_name_owner = mfree(s->bus_name_owner);
 
         service_close_socket_fd(s);
+        s->peer = socket_peer_unref(s->peer);
 
         unit_ref_unset(&s->accept_socket);
 
index 8e56e1a..888007c 100644 (file)
@@ -152,6 +152,7 @@ struct Service {
 
         pid_t main_pid, control_pid;
         int socket_fd;
+        SocketPeer *peer;
         bool socket_fd_selinux_context_net;
 
         bool permissions_start_only;
index 1ce41a1..ff55885 100644 (file)
@@ -57,6 +57,7 @@
 #include "unit-printf.h"
 #include "unit.h"
 #include "user-util.h"
+#include "in-addr-util.h"
 
 static const UnitActiveState state_translation_table[_SOCKET_STATE_MAX] = {
         [SOCKET_DEAD] = UNIT_INACTIVE,
@@ -77,6 +78,9 @@ static const UnitActiveState state_translation_table[_SOCKET_STATE_MAX] = {
 static int socket_dispatch_io(sd_event_source *source, int fd, uint32_t revents, void *userdata);
 static int socket_dispatch_timer(sd_event_source *source, usec_t usec, void *userdata);
 
+SocketPeer *socket_peer_new(void);
+int socket_find_peer(Socket *s, int fd, SocketPeer **p);
+
 static void socket_init(Unit *u) {
         Socket *s = SOCKET(u);
 
@@ -141,11 +145,17 @@ void socket_free_ports(Socket *s) {
 
 static void socket_done(Unit *u) {
         Socket *s = SOCKET(u);
+        SocketPeer *p;
 
         assert(s);
 
         socket_free_ports(s);
 
+        while ((p = hashmap_steal_first(s->peers_by_address)))
+                p->socket = NULL;
+
+        s->peers_by_address = hashmap_free(s->peers_by_address);
+
         s->exec_runtime = exec_runtime_unref(s->exec_runtime);
         exec_command_free_array(s->exec_command, _SOCKET_EXEC_COMMAND_MAX);
         s->control_command = NULL;
@@ -468,6 +478,40 @@ static int socket_verify(Socket *s) {
         return 0;
 }
 
+static void peer_address_hash_func(const void *p, struct siphash *state) {
+        const SocketPeer *s = p;
+
+        assert(s);
+
+        if (s->peer.sa.sa_family == AF_INET)
+                siphash24_compress(&s->peer.in.sin_addr, sizeof(s->peer.in.sin_addr), state);
+        else if (s->peer.sa.sa_family == AF_INET6)
+                siphash24_compress(&s->peer.in6.sin6_addr, sizeof(s->peer.in6.sin6_addr), state);
+}
+
+static int peer_address_compare_func(const void *a, const void *b) {
+        const SocketPeer *x = a, *y = b;
+
+        if (x->peer.sa.sa_family < y->peer.sa.sa_family)
+                return -1;
+        if (x->peer.sa.sa_family > y->peer.sa.sa_family)
+                return 1;
+
+        switch(x->peer.sa.sa_family) {
+        case AF_INET:
+                return memcmp(&x->peer.in.sin_addr, &y->peer.in.sin_addr, sizeof(x->peer.in.sin_addr));
+        case AF_INET6:
+                return memcmp(&x->peer.in6.sin6_addr, &y->peer.in6.sin6_addr, sizeof(x->peer.in6.sin6_addr));
+        }
+
+        return -1;
+}
+
+const struct hash_ops peer_address_hash_ops = {
+        .hash = peer_address_hash_func,
+        .compare = peer_address_compare_func
+};
+
 static int socket_load(Unit *u) {
         Socket *s = SOCKET(u);
         int r;
@@ -475,6 +519,10 @@ static int socket_load(Unit *u) {
         assert(u);
         assert(u->load_state == UNIT_STUB);
 
+        r = hashmap_ensure_allocated(&s->peers_by_address, &peer_address_hash_ops);
+        if (r < 0)
+                return r;
+
         r = unit_load_fragment_and_dropin(u);
         if (r < 0)
                 return r;
@@ -2050,6 +2098,7 @@ static void socket_enter_running(Socket *s, int cfd) {
                 socket_set_state(s, SOCKET_RUNNING);
         } else {
                 _cleanup_free_ char *prefix = NULL, *instance = NULL, *name = NULL;
+                _cleanup_(socket_peer_unrefp) SocketPeer *p = NULL;
                 Service *service;
 
                 if (s->n_connections >= s->max_connections) {
@@ -2058,6 +2107,21 @@ static void socket_enter_running(Socket *s, int cfd) {
                         return;
                 }
 
+                if (s->max_connections_per_source > 0) {
+                        r = socket_find_peer(s, cfd, &p);
+                        if (r < 0) {
+                                safe_close(cfd);
+                                return;
+                        }
+
+                        if (p->n_ref > s->max_connections_per_source) {
+                                log_unit_warning(UNIT(s), "Too many incoming connections (%u) from source, refusing connection attempt.", p->n_ref);
+                                safe_close(cfd);
+                                p = NULL;
+                                return;
+                        }
+                }
+
                 r = socket_instantiate_service(s);
                 if (r < 0)
                         goto fail;
@@ -2099,6 +2163,11 @@ static void socket_enter_running(Socket *s, int cfd) {
                 cfd = -1; /* We passed ownership of the fd to the service now. Forget it here. */
                 s->n_connections++;
 
+                if (s->max_connections_per_source > 0) {
+                        service->peer = socket_peer_ref(p);
+                        p = NULL;
+                }
+
                 r = manager_add_job(UNIT(s)->manager, JOB_START, UNIT(service), JOB_REPLACE, &error, NULL);
                 if (r < 0) {
                         /* We failed to activate the new service, but it still exists. Let's make sure the service
@@ -2244,7 +2313,9 @@ static int socket_stop(Unit *u) {
 
 static int socket_serialize(Unit *u, FILE *f, FDSet *fds) {
         Socket *s = SOCKET(u);
+        SocketPeer *k;
         SocketPort *p;
+        Iterator i;
         int r;
 
         assert(u);
@@ -2295,6 +2366,16 @@ static int socket_serialize(Unit *u, FILE *f, FDSet *fds) {
                 }
         }
 
+        HASHMAP_FOREACH(k, s->peers_by_address, i) {
+                _cleanup_free_ char *t = NULL;
+
+                r = sockaddr_pretty(&k->peer.sa, FAMILY_ADDRESS_SIZE(k->peer.sa.sa_family), true, true, &t);
+                if (r < 0)
+                        return r;
+
+                unit_serialize_item_format(u, f, "peer", "%u %s", k->n_ref, t);
+        }
+
         return 0;
 }
 
@@ -2458,6 +2539,33 @@ static int socket_deserialize_item(Unit *u, const char *key, const char *value,
                         }
                 }
 
+        } else if (streq(key, "peer")) {
+                _cleanup_(socket_peer_unrefp) SocketPeer *p;
+                int n_ref, skip = 0;
+                SocketAddress a;
+                int r;
+
+                if (sscanf(value, "%u %n", &n_ref, &skip) < 1 || n_ref < 1)
+                        log_unit_debug(u, "Failed to parse socket peer value: %s", value);
+                else {
+                        r = socket_address_parse(&a, value+skip);
+                        if (r < 0)
+                                return r;
+
+                        p = socket_peer_new();
+                        if (!p)
+                                return log_oom();
+
+                        p->n_ref = n_ref;
+                        memcpy(&p->peer, &a.sockaddr, sizeof(a.sockaddr));
+                        p->socket = s;
+
+                        r = hashmap_put(s->peers_by_address, p, p);
+                        if (r < 0)
+                                return r;
+
+                        p = NULL;
+                }
         } else
                 log_unit_debug(UNIT(s), "Unknown serialization key: %s", key);
 
@@ -2554,6 +2662,83 @@ _pure_ static bool socket_check_gc(Unit *u) {
         return s->n_connections > 0;
 }
 
+SocketPeer *socket_peer_new(void) {
+        SocketPeer *p;
+
+        p = new0(SocketPeer, 1);
+        if (!p)
+                return NULL;
+
+        p->n_ref = 1;
+
+        return p;
+}
+
+SocketPeer *socket_peer_ref(SocketPeer *p) {
+        if (!p)
+                return NULL;
+
+        assert(p->n_ref > 0);
+        p->n_ref++;
+
+        return p;
+}
+
+SocketPeer *socket_peer_unref(SocketPeer *p) {
+        if (!p)
+                return NULL;
+
+        assert(p->n_ref > 0);
+
+        p->n_ref--;
+
+        if (p->n_ref > 0)
+                return NULL;
+
+        if (p->socket)
+                (void) hashmap_remove(p->socket->peers_by_address, p);
+
+        free(p);
+
+        return NULL;
+}
+
+int socket_find_peer(Socket *s, int fd, SocketPeer **p) {
+        _cleanup_free_ SocketPeer *remote = NULL;
+        SocketPeer sa, *i;
+        socklen_t salen = sizeof(sa.peer);
+        int r;
+
+        assert(fd >= 0);
+        assert(s);
+
+        r = getpeername(fd, &sa.peer.sa, &salen);
+        if (r < 0)
+                return log_error_errno(errno, "getpeername failed: %m");
+
+        i = hashmap_get(s->peers_by_address, &sa);
+        if (i) {
+                *p = i;
+                return 1;
+        }
+
+        remote = socket_peer_new();
+        if (!remote)
+                return log_oom();
+
+        memcpy(&remote->peer, &sa.peer, sizeof(union sockaddr_union));
+        remote->socket = s;
+
+        r = hashmap_put(s->peers_by_address, remote, remote);
+        if (r < 0)
+                return r;
+
+        *p = remote;
+        remote = NULL;
+
+        return 0;
+}
+
 static int socket_dispatch_io(sd_event_source *source, int fd, uint32_t revents, void *userdata) {
         SocketPort *p = userdata;
         int cfd = -1;
index 6c32d67..2fe38ef 100644 (file)
@@ -20,6 +20,7 @@
 ***/
 
 typedef struct Socket Socket;
+typedef struct SocketPeer SocketPeer;
 
 #include "mount.h"
 #include "service.h"
@@ -79,9 +80,12 @@ struct Socket {
 
         LIST_HEAD(SocketPort, ports);
 
+        Hashmap *peers_by_address;
+
         unsigned n_accepted;
         unsigned n_connections;
         unsigned max_connections;
+        unsigned max_connections_per_source;
 
         unsigned backlog;
         unsigned keep_alive_cnt;
@@ -164,6 +168,18 @@ struct Socket {
         RateLimit trigger_limit;
 };
 
+struct SocketPeer {
+        unsigned n_ref;
+
+        Socket *socket;
+        union sockaddr_union peer;
+};
+
+SocketPeer *socket_peer_ref(SocketPeer *p);
+SocketPeer *socket_peer_unref(SocketPeer *p);
+
+DEFINE_TRIVIAL_CLEANUP_FUNC(SocketPeer*, socket_peer_unref);
+
 /* Called from the service code when collecting fds */
 int socket_collect_fds(Socket *s, int **fds);