SUNRPC: Add a TCP-with-TLS RPC transport class
authorChuck Lever <chuck.lever@oracle.com>
Wed, 7 Jun 2023 13:59:15 +0000 (09:59 -0400)
committerTrond Myklebust <trond.myklebust@hammerspace.com>
Mon, 19 Jun 2023 16:28:10 +0000 (12:28 -0400)
Use the new TLS handshake API to enable the SunRPC client code
to request a TLS handshake. This implements support for RFC 9289,
only on TCP sockets.

Upper layers such as NFS use RPC-with-TLS to protect in-transit
traffic.

Signed-off-by: Chuck Lever <chuck.lever@oracle.com>
Signed-off-by: Trond Myklebust <trond.myklebust@hammerspace.com>
include/linux/sunrpc/xprt.h
include/linux/sunrpc/xprtsock.h
include/trace/events/sunrpc.h
net/sunrpc/sysfs.c
net/sunrpc/xprtsock.c

index 9e7f12c..b52411b 100644 (file)
@@ -200,6 +200,7 @@ enum xprt_transports {
        XPRT_TRANSPORT_RDMA     = 256,
        XPRT_TRANSPORT_BC_RDMA  = XPRT_TRANSPORT_RDMA | XPRT_TRANSPORT_BC,
        XPRT_TRANSPORT_LOCAL    = 257,
+       XPRT_TRANSPORT_TCP_TLS  = 258,
 };
 
 struct rpc_sysfs_xprt;
index daef030..700a1e6 100644 (file)
@@ -57,9 +57,11 @@ struct sock_xprt {
        struct work_struct      error_worker;
        struct work_struct      recv_worker;
        struct mutex            recv_mutex;
+       struct completion       handshake_done;
        struct sockaddr_storage srcaddr;
        unsigned short          srcport;
        int                     xprt_err;
+       struct rpc_clnt         *clnt;
 
        /*
         * UDP socket buffer size parameters
index 34784f2..7cd4bbd 100644 (file)
@@ -1525,6 +1525,50 @@ TRACE_EVENT(rpcb_unregister,
        )
 );
 
+/**
+ ** RPC-over-TLS tracepoints
+ **/
+
+DECLARE_EVENT_CLASS(rpc_tls_class,
+       TP_PROTO(
+               const struct rpc_clnt *clnt,
+               const struct rpc_xprt *xprt
+       ),
+
+       TP_ARGS(clnt, xprt),
+
+       TP_STRUCT__entry(
+               __field(unsigned long, requested_policy)
+               __field(u32, version)
+               __string(servername, xprt->servername)
+               __string(progname, clnt->cl_program->name)
+       ),
+
+       TP_fast_assign(
+               __entry->requested_policy = clnt->cl_xprtsec.policy;
+               __entry->version = clnt->cl_vers;
+               __assign_str(servername, xprt->servername);
+               __assign_str(progname, clnt->cl_program->name)
+       ),
+
+       TP_printk("server=%s %sv%u requested_policy=%s",
+               __get_str(servername), __get_str(progname), __entry->version,
+               rpc_show_xprtsec_policy(__entry->requested_policy)
+       )
+);
+
+#define DEFINE_RPC_TLS_EVENT(name) \
+       DEFINE_EVENT(rpc_tls_class, rpc_tls_##name, \
+                       TP_PROTO( \
+                               const struct rpc_clnt *clnt, \
+                               const struct rpc_xprt *xprt \
+                       ), \
+                       TP_ARGS(clnt, xprt))
+
+DEFINE_RPC_TLS_EVENT(unavailable);
+DEFINE_RPC_TLS_EVENT(not_started);
+
+
 /* Record an xdr_buf containing a fully-formed RPC message */
 DECLARE_EVENT_CLASS(svc_xdr_msg_class,
        TP_PROTO(
index 0d0db4e..5c8ecda 100644 (file)
@@ -239,6 +239,7 @@ static ssize_t rpc_sysfs_xprt_dstaddr_store(struct kobject *kobj,
        if (!xprt)
                return 0;
        if (!(xprt->xprt_class->ident == XPRT_TRANSPORT_TCP ||
+             xprt->xprt_class->ident == XPRT_TRANSPORT_TCP_TLS ||
              xprt->xprt_class->ident == XPRT_TRANSPORT_RDMA)) {
                xprt_put(xprt);
                return -EOPNOTSUPP;
index 7e2f962..9f01036 100644 (file)
@@ -48,6 +48,7 @@
 #include <net/udp.h>
 #include <net/tcp.h>
 #include <net/tls.h>
+#include <net/handshake.h>
 
 #include <linux/bvec.h>
 #include <linux/highmem.h>
@@ -98,6 +99,7 @@ static struct ctl_table_header *sunrpc_table_header;
 static struct xprt_class xs_local_transport;
 static struct xprt_class xs_udp_transport;
 static struct xprt_class xs_tcp_transport;
+static struct xprt_class xs_tcp_tls_transport;
 static struct xprt_class xs_bc_tcp_transport;
 
 /*
@@ -189,6 +191,11 @@ static struct ctl_table xs_tunables_table[] = {
  */
 #define XS_IDLE_DISC_TO                (5U * 60 * HZ)
 
+/*
+ * TLS handshake timeout.
+ */
+#define XS_TLS_HANDSHAKE_TO    (10U * HZ)
+
 #if IS_ENABLED(CONFIG_SUNRPC_DEBUG)
 # undef  RPC_DEBUG_DATA
 # define RPCDBG_FACILITY       RPCDBG_TRANS
@@ -1243,6 +1250,8 @@ static void xs_reset_transport(struct sock_xprt *transport)
        if (atomic_read(&transport->xprt.swapper))
                sk_clear_memalloc(sk);
 
+       tls_handshake_cancel(sk);
+
        kernel_sock_shutdown(sock, SHUT_RDWR);
 
        mutex_lock(&transport->recv_mutex);
@@ -2416,6 +2425,267 @@ out_unlock:
        current_restore_flags(pflags, PF_MEMALLOC);
 }
 
+/*
+ * Transfer the connected socket to @upper_transport, then mark that
+ * xprt CONNECTED.
+ */
+static int xs_tcp_tls_finish_connecting(struct rpc_xprt *lower_xprt,
+                                       struct sock_xprt *upper_transport)
+{
+       struct sock_xprt *lower_transport =
+                       container_of(lower_xprt, struct sock_xprt, xprt);
+       struct rpc_xprt *upper_xprt = &upper_transport->xprt;
+
+       if (!upper_transport->inet) {
+               struct socket *sock = lower_transport->sock;
+               struct sock *sk = sock->sk;
+
+               /* Avoid temporary address, they are bad for long-lived
+                * connections such as NFS mounts.
+                * RFC4941, section 3.6 suggests that:
+                *    Individual applications, which have specific
+                *    knowledge about the normal duration of connections,
+                *    MAY override this as appropriate.
+                */
+               if (xs_addr(upper_xprt)->sa_family == PF_INET6)
+                       ip6_sock_set_addr_preferences(sk, IPV6_PREFER_SRC_PUBLIC);
+
+               xs_tcp_set_socket_timeouts(upper_xprt, sock);
+               tcp_sock_set_nodelay(sk);
+
+               lock_sock(sk);
+
+               /* @sk is already connected, so it now has the RPC callbacks.
+                * Reach into @lower_transport to save the original ones.
+                */
+               upper_transport->old_data_ready = lower_transport->old_data_ready;
+               upper_transport->old_state_change = lower_transport->old_state_change;
+               upper_transport->old_write_space = lower_transport->old_write_space;
+               upper_transport->old_error_report = lower_transport->old_error_report;
+               sk->sk_user_data = upper_xprt;
+
+               /* socket options */
+               sock_reset_flag(sk, SOCK_LINGER);
+
+               xprt_clear_connected(upper_xprt);
+
+               upper_transport->sock = sock;
+               upper_transport->inet = sk;
+               upper_transport->file = lower_transport->file;
+
+               release_sock(sk);
+
+               /* Reset lower_transport before shutting down its clnt */
+               mutex_lock(&lower_transport->recv_mutex);
+               lower_transport->inet = NULL;
+               lower_transport->sock = NULL;
+               lower_transport->file = NULL;
+
+               xprt_clear_connected(lower_xprt);
+               xs_sock_reset_connection_flags(lower_xprt);
+               xs_stream_reset_connect(lower_transport);
+               mutex_unlock(&lower_transport->recv_mutex);
+       }
+
+       if (!xprt_bound(upper_xprt))
+               return -ENOTCONN;
+
+       xs_set_memalloc(upper_xprt);
+
+       if (!xprt_test_and_set_connected(upper_xprt)) {
+               upper_xprt->connect_cookie++;
+               clear_bit(XPRT_SOCK_CONNECTING, &upper_transport->sock_state);
+               xprt_clear_connecting(upper_xprt);
+
+               upper_xprt->stat.connect_count++;
+               upper_xprt->stat.connect_time += (long)jiffies -
+                                          upper_xprt->stat.connect_start;
+               xs_run_error_worker(upper_transport, XPRT_SOCK_WAKE_PENDING);
+       }
+       return 0;
+}
+
+/**
+ * xs_tls_handshake_done - TLS handshake completion handler
+ * @data: address of xprt to wake
+ * @status: status of handshake
+ * @peerid: serial number of key containing the remote's identity
+ *
+ */
+static void xs_tls_handshake_done(void *data, int status, key_serial_t peerid)
+{
+       struct rpc_xprt *lower_xprt = data;
+       struct sock_xprt *lower_transport =
+                               container_of(lower_xprt, struct sock_xprt, xprt);
+
+       lower_transport->xprt_err = status ? -EACCES : 0;
+       complete(&lower_transport->handshake_done);
+       xprt_put(lower_xprt);
+}
+
+static int xs_tls_handshake_sync(struct rpc_xprt *lower_xprt, struct xprtsec_parms *xprtsec)
+{
+       struct sock_xprt *lower_transport =
+                               container_of(lower_xprt, struct sock_xprt, xprt);
+       struct tls_handshake_args args = {
+               .ta_sock        = lower_transport->sock,
+               .ta_done        = xs_tls_handshake_done,
+               .ta_data        = xprt_get(lower_xprt),
+               .ta_peername    = lower_xprt->servername,
+       };
+       struct sock *sk = lower_transport->inet;
+       int rc;
+
+       init_completion(&lower_transport->handshake_done);
+       set_bit(XPRT_SOCK_IGNORE_RECV, &lower_transport->sock_state);
+       lower_transport->xprt_err = -ETIMEDOUT;
+       switch (xprtsec->policy) {
+       case RPC_XPRTSEC_TLS_ANON:
+               rc = tls_client_hello_anon(&args, GFP_KERNEL);
+               if (rc)
+                       goto out_put_xprt;
+               break;
+       case RPC_XPRTSEC_TLS_X509:
+               args.ta_my_cert = xprtsec->cert_serial;
+               args.ta_my_privkey = xprtsec->privkey_serial;
+               rc = tls_client_hello_x509(&args, GFP_KERNEL);
+               if (rc)
+                       goto out_put_xprt;
+               break;
+       default:
+               rc = -EACCES;
+               goto out_put_xprt;
+       }
+
+       rc = wait_for_completion_interruptible_timeout(&lower_transport->handshake_done,
+                                                      XS_TLS_HANDSHAKE_TO);
+       if (rc <= 0) {
+               if (!tls_handshake_cancel(sk)) {
+                       if (rc == 0)
+                               rc = -ETIMEDOUT;
+                       goto out_put_xprt;
+               }
+       }
+
+       rc = lower_transport->xprt_err;
+
+out:
+       xs_stream_reset_connect(lower_transport);
+       clear_bit(XPRT_SOCK_IGNORE_RECV, &lower_transport->sock_state);
+       return rc;
+
+out_put_xprt:
+       xprt_put(lower_xprt);
+       goto out;
+}
+
+/**
+ * xs_tcp_tls_setup_socket - establish a TLS session on a TCP socket
+ * @work: queued work item
+ *
+ * Invoked by a work queue tasklet.
+ *
+ * For RPC-with-TLS, there is a two-stage connection process.
+ *
+ * The "upper-layer xprt" is visible to the RPC consumer. Once it has
+ * been marked connected, the consumer knows that a TCP connection and
+ * a TLS session have been established.
+ *
+ * A "lower-layer xprt", created in this function, handles the mechanics
+ * of connecting the TCP socket, performing the RPC_AUTH_TLS probe, and
+ * then driving the TLS handshake. Once all that is complete, the upper
+ * layer xprt is marked connected.
+ */
+static void xs_tcp_tls_setup_socket(struct work_struct *work)
+{
+       struct sock_xprt *upper_transport =
+               container_of(work, struct sock_xprt, connect_worker.work);
+       struct rpc_clnt *upper_clnt = upper_transport->clnt;
+       struct rpc_xprt *upper_xprt = &upper_transport->xprt;
+       struct rpc_create_args args = {
+               .net            = upper_xprt->xprt_net,
+               .protocol       = upper_xprt->prot,
+               .address        = (struct sockaddr *)&upper_xprt->addr,
+               .addrsize       = upper_xprt->addrlen,
+               .timeout        = upper_clnt->cl_timeout,
+               .servername     = upper_xprt->servername,
+               .program        = upper_clnt->cl_program,
+               .prognumber     = upper_clnt->cl_prog,
+               .version        = upper_clnt->cl_vers,
+               .authflavor     = RPC_AUTH_TLS,
+               .cred           = upper_clnt->cl_cred,
+               .xprtsec        = {
+                       .policy         = RPC_XPRTSEC_NONE,
+               },
+       };
+       unsigned int pflags = current->flags;
+       struct rpc_clnt *lower_clnt;
+       struct rpc_xprt *lower_xprt;
+       int status;
+
+       if (atomic_read(&upper_xprt->swapper))
+               current->flags |= PF_MEMALLOC;
+
+       xs_stream_start_connect(upper_transport);
+
+       /* This implicitly sends an RPC_AUTH_TLS probe */
+       lower_clnt = rpc_create(&args);
+       if (IS_ERR(lower_clnt)) {
+               trace_rpc_tls_unavailable(upper_clnt, upper_xprt);
+               clear_bit(XPRT_SOCK_CONNECTING, &upper_transport->sock_state);
+               xprt_clear_connecting(upper_xprt);
+               xprt_wake_pending_tasks(upper_xprt, PTR_ERR(lower_clnt));
+               xs_run_error_worker(upper_transport, XPRT_SOCK_WAKE_PENDING);
+               goto out_unlock;
+       }
+
+       /* RPC_AUTH_TLS probe was successful. Try a TLS handshake on
+        * the lower xprt.
+        */
+       rcu_read_lock();
+       lower_xprt = rcu_dereference(lower_clnt->cl_xprt);
+       rcu_read_unlock();
+       status = xs_tls_handshake_sync(lower_xprt, &upper_xprt->xprtsec);
+       if (status) {
+               trace_rpc_tls_not_started(upper_clnt, upper_xprt);
+               goto out_close;
+       }
+
+       status = xs_tcp_tls_finish_connecting(lower_xprt, upper_transport);
+       if (status)
+               goto out_close;
+
+       trace_rpc_socket_connect(upper_xprt, upper_transport->sock, 0);
+       if (!xprt_test_and_set_connected(upper_xprt)) {
+               upper_xprt->connect_cookie++;
+               clear_bit(XPRT_SOCK_CONNECTING, &upper_transport->sock_state);
+               xprt_clear_connecting(upper_xprt);
+
+               upper_xprt->stat.connect_count++;
+               upper_xprt->stat.connect_time += (long)jiffies -
+                                          upper_xprt->stat.connect_start;
+               xs_run_error_worker(upper_transport, XPRT_SOCK_WAKE_PENDING);
+       }
+       rpc_shutdown_client(lower_clnt);
+
+out_unlock:
+       current_restore_flags(pflags, PF_MEMALLOC);
+       upper_transport->clnt = NULL;
+       xprt_unlock_connect(upper_xprt, upper_transport);
+       return;
+
+out_close:
+       rpc_shutdown_client(lower_clnt);
+
+       /* xprt_force_disconnect() wakes tasks with a fixed tk_status code.
+        * Wake them first here to ensure they get our tk_status code.
+        */
+       xprt_wake_pending_tasks(upper_xprt, status);
+       xs_tcp_force_close(upper_xprt);
+       xprt_clear_connecting(upper_xprt);
+       goto out_unlock;
+}
+
 /**
  * xs_connect - connect a socket to a remote endpoint
  * @xprt: pointer to transport structure
@@ -2447,6 +2717,7 @@ static void xs_connect(struct rpc_xprt *xprt, struct rpc_task *task)
        } else
                dprintk("RPC:       xs_connect scheduled xprt %p\n", xprt);
 
+       transport->clnt = task->tk_client;
        queue_delayed_work(xprtiod_workqueue,
                        &transport->connect_worker,
                        delay);
@@ -3101,6 +3372,94 @@ out_err:
 }
 
 /**
+ * xs_setup_tcp_tls - Set up transport to use a TCP with TLS
+ * @args: rpc transport creation arguments
+ *
+ */
+static struct rpc_xprt *xs_setup_tcp_tls(struct xprt_create *args)
+{
+       struct sockaddr *addr = args->dstaddr;
+       struct rpc_xprt *xprt;
+       struct sock_xprt *transport;
+       struct rpc_xprt *ret;
+       unsigned int max_slot_table_size = xprt_max_tcp_slot_table_entries;
+
+       if (args->flags & XPRT_CREATE_INFINITE_SLOTS)
+               max_slot_table_size = RPC_MAX_SLOT_TABLE_LIMIT;
+
+       xprt = xs_setup_xprt(args, xprt_tcp_slot_table_entries,
+                            max_slot_table_size);
+       if (IS_ERR(xprt))
+               return xprt;
+       transport = container_of(xprt, struct sock_xprt, xprt);
+
+       xprt->prot = IPPROTO_TCP;
+       xprt->xprt_class = &xs_tcp_transport;
+       xprt->max_payload = RPC_MAX_FRAGMENT_SIZE;
+
+       xprt->bind_timeout = XS_BIND_TO;
+       xprt->reestablish_timeout = XS_TCP_INIT_REEST_TO;
+       xprt->idle_timeout = XS_IDLE_DISC_TO;
+
+       xprt->ops = &xs_tcp_ops;
+       xprt->timeout = &xs_tcp_default_timeout;
+
+       xprt->max_reconnect_timeout = xprt->timeout->to_maxval;
+       xprt->connect_timeout = xprt->timeout->to_initval *
+               (xprt->timeout->to_retries + 1);
+
+       INIT_WORK(&transport->recv_worker, xs_stream_data_receive_workfn);
+       INIT_WORK(&transport->error_worker, xs_error_handle);
+
+       switch (args->xprtsec.policy) {
+       case RPC_XPRTSEC_TLS_ANON:
+       case RPC_XPRTSEC_TLS_X509:
+               xprt->xprtsec = args->xprtsec;
+               INIT_DELAYED_WORK(&transport->connect_worker,
+                                 xs_tcp_tls_setup_socket);
+               break;
+       default:
+               ret = ERR_PTR(-EACCES);
+               goto out_err;
+       }
+
+       switch (addr->sa_family) {
+       case AF_INET:
+               if (((struct sockaddr_in *)addr)->sin_port != htons(0))
+                       xprt_set_bound(xprt);
+
+               xs_format_peer_addresses(xprt, "tcp", RPCBIND_NETID_TCP);
+               break;
+       case AF_INET6:
+               if (((struct sockaddr_in6 *)addr)->sin6_port != htons(0))
+                       xprt_set_bound(xprt);
+
+               xs_format_peer_addresses(xprt, "tcp", RPCBIND_NETID_TCP6);
+               break;
+       default:
+               ret = ERR_PTR(-EAFNOSUPPORT);
+               goto out_err;
+       }
+
+       if (xprt_bound(xprt))
+               dprintk("RPC:       set up xprt to %s (port %s) via %s\n",
+                       xprt->address_strings[RPC_DISPLAY_ADDR],
+                       xprt->address_strings[RPC_DISPLAY_PORT],
+                       xprt->address_strings[RPC_DISPLAY_PROTO]);
+       else
+               dprintk("RPC:       set up xprt to %s (autobind) via %s\n",
+                       xprt->address_strings[RPC_DISPLAY_ADDR],
+                       xprt->address_strings[RPC_DISPLAY_PROTO]);
+
+       if (try_module_get(THIS_MODULE))
+               return xprt;
+       ret = ERR_PTR(-EINVAL);
+out_err:
+       xs_xprt_free(xprt);
+       return ret;
+}
+
+/**
  * xs_setup_bc_tcp - Set up transport to use a TCP backchannel socket
  * @args: rpc transport creation arguments
  *
@@ -3209,6 +3568,15 @@ static struct xprt_class xs_tcp_transport = {
        .netid          = { "tcp", "tcp6", "" },
 };
 
+static struct xprt_class       xs_tcp_tls_transport = {
+       .list           = LIST_HEAD_INIT(xs_tcp_tls_transport.list),
+       .name           = "tcp-with-tls",
+       .owner          = THIS_MODULE,
+       .ident          = XPRT_TRANSPORT_TCP_TLS,
+       .setup          = xs_setup_tcp_tls,
+       .netid          = { "tcp", "tcp6", "" },
+};
+
 static struct xprt_class       xs_bc_tcp_transport = {
        .list           = LIST_HEAD_INIT(xs_bc_tcp_transport.list),
        .name           = "tcp NFSv4.1 backchannel",
@@ -3230,6 +3598,7 @@ int init_socket_xprt(void)
        xprt_register_transport(&xs_local_transport);
        xprt_register_transport(&xs_udp_transport);
        xprt_register_transport(&xs_tcp_transport);
+       xprt_register_transport(&xs_tcp_tls_transport);
        xprt_register_transport(&xs_bc_tcp_transport);
 
        return 0;
@@ -3249,6 +3618,7 @@ void cleanup_socket_xprt(void)
        xprt_unregister_transport(&xs_local_transport);
        xprt_unregister_transport(&xs_udp_transport);
        xprt_unregister_transport(&xs_tcp_transport);
+       xprt_unregister_transport(&xs_tcp_tls_transport);
        xprt_unregister_transport(&xs_bc_tcp_transport);
 }