Merge tag 'pmdomain-v6.6-rc1-2' of git://git.kernel.org/pub/scm/linux/kernel/git...
[platform/kernel/linux-starfive.git] / net / sunrpc / auth_tls.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Copyright (c) 2021, 2022 Oracle.  All rights reserved.
4  *
5  * The AUTH_TLS credential is used only to probe a remote peer
6  * for RPC-over-TLS support.
7  */
8
9 #include <linux/types.h>
10 #include <linux/module.h>
11 #include <linux/sunrpc/clnt.h>
12
/*
 * Verifier body that tls_validate() requires in the reply to an AUTH_TLS
 * probe (RFC 9289 calls it the STARTTLS token). The length is derived from
 * the literal itself so the two can never drift apart; the terminating
 * NUL is not part of the on-the-wire token.
 */
static const char starttls_token[] = "STARTTLS";
static const size_t starttls_len = sizeof(starttls_token) - 1;
15
16 static struct rpc_auth tls_auth;
17 static struct rpc_cred tls_cred;
18
/*
 * XDR encoder for the probe's argument list. Nothing is written to @xdr:
 * the probe procedure takes no arguments.
 */
static void tls_encode_probe(struct rpc_rqst *rqstp, struct xdr_stream *xdr,
			     const void *obj)
{
	/* Intentionally empty: no call arguments to encode. */
}
23
/*
 * XDR decoder for the probe's result. Nothing is read from @xdr, and the
 * decoder always reports success (0).
 */
static int tls_decode_probe(struct rpc_rqst *rqstp, struct xdr_stream *xdr,
			    void *obj)
{
	int ret = 0;	/* no result body to decode */

	return ret;
}
29
30 static const struct rpc_procinfo rpcproc_tls_probe = {
31         .p_encode       = tls_encode_probe,
32         .p_decode       = tls_decode_probe,
33 };
34
35 static void rpc_tls_probe_call_prepare(struct rpc_task *task, void *data)
36 {
37         task->tk_flags &= ~RPC_TASK_NO_RETRANS_TIMEOUT;
38         rpc_call_start(task);
39 }
40
/* rpc_call_done callback for the probe task; it performs no work. */
static void rpc_tls_probe_call_done(struct rpc_task *task, void *data)
{
	/* Nothing to do: tls_probe() reads task->tk_status itself. */
}
44
45 static const struct rpc_call_ops rpc_tls_probe_ops = {
46         .rpc_call_prepare       = rpc_tls_probe_call_prepare,
47         .rpc_call_done          = rpc_tls_probe_call_done,
48 };
49
50 static int tls_probe(struct rpc_clnt *clnt)
51 {
52         struct rpc_message msg = {
53                 .rpc_proc       = &rpcproc_tls_probe,
54         };
55         struct rpc_task_setup task_setup_data = {
56                 .rpc_client     = clnt,
57                 .rpc_message    = &msg,
58                 .rpc_op_cred    = &tls_cred,
59                 .callback_ops   = &rpc_tls_probe_ops,
60                 .flags          = RPC_TASK_SOFT | RPC_TASK_SOFTCONN,
61         };
62         struct rpc_task *task;
63         int status;
64
65         task = rpc_run_task(&task_setup_data);
66         if (IS_ERR(task))
67                 return PTR_ERR(task);
68         status = task->tk_status;
69         rpc_put_task(task);
70         return status;
71 }
72
73 static struct rpc_auth *tls_create(const struct rpc_auth_create_args *args,
74                                    struct rpc_clnt *clnt)
75 {
76         refcount_inc(&tls_auth.au_count);
77         return &tls_auth;
78 }
79
/* ->destroy for AUTH_TLS; it performs no work. */
static void tls_destroy(struct rpc_auth *auth)
{
	/* tls_auth is a static object: there is nothing to release. */
}
83
84 static struct rpc_cred *tls_lookup_cred(struct rpc_auth *auth,
85                                         struct auth_cred *acred, int flags)
86 {
87         return get_rpccred(&tls_cred);
88 }
89
/* ->crdestroy for the AUTH_TLS credential; it performs no work. */
static void tls_destroy_cred(struct rpc_cred *cred)
{
	/* tls_cred is a static object: there is nothing to free. */
}
93
/*
 * ->crmatch for AUTH_TLS. The single static credential is reported as a
 * match (1) regardless of @acred, @cred or @taskflags.
 */
static int tls_match(struct auth_cred *acred, struct rpc_cred *cred, int taskflags)
{
	const int always_matches = 1;

	return always_matches;
}
98
99 static int tls_marshal(struct rpc_task *task, struct xdr_stream *xdr)
100 {
101         __be32 *p;
102
103         p = xdr_reserve_space(xdr, 4 * XDR_UNIT);
104         if (!p)
105                 return -EMSGSIZE;
106         /* Credential */
107         *p++ = rpc_auth_tls;
108         *p++ = xdr_zero;
109         /* Verifier */
110         *p++ = rpc_auth_null;
111         *p   = xdr_zero;
112         return 0;
113 }
114
115 static int tls_refresh(struct rpc_task *task)
116 {
117         set_bit(RPCAUTH_CRED_UPTODATE, &task->tk_rqstp->rq_cred->cr_flags);
118         return 0;
119 }
120
121 static int tls_validate(struct rpc_task *task, struct xdr_stream *xdr)
122 {
123         __be32 *p;
124         void *str;
125
126         p = xdr_inline_decode(xdr, XDR_UNIT);
127         if (!p)
128                 return -EIO;
129         if (*p != rpc_auth_null)
130                 return -EIO;
131         if (xdr_stream_decode_opaque_inline(xdr, &str, starttls_len) != starttls_len)
132                 return -EPROTONOSUPPORT;
133         if (memcmp(str, starttls_token, starttls_len))
134                 return -EPROTONOSUPPORT;
135         return 0;
136 }
137
138 const struct rpc_authops authtls_ops = {
139         .owner          = THIS_MODULE,
140         .au_flavor      = RPC_AUTH_TLS,
141         .au_name        = "NULL",
142         .create         = tls_create,
143         .destroy        = tls_destroy,
144         .lookup_cred    = tls_lookup_cred,
145         .ping           = tls_probe,
146 };
147
148 static struct rpc_auth tls_auth = {
149         .au_cslack      = NUL_CALLSLACK,
150         .au_rslack      = NUL_REPLYSLACK,
151         .au_verfsize    = NUL_REPLYSLACK,
152         .au_ralign      = NUL_REPLYSLACK,
153         .au_ops         = &authtls_ops,
154         .au_flavor      = RPC_AUTH_TLS,
155         .au_count       = REFCOUNT_INIT(1),
156 };
157
158 static const struct rpc_credops tls_credops = {
159         .cr_name        = "AUTH_TLS",
160         .crdestroy      = tls_destroy_cred,
161         .crmatch        = tls_match,
162         .crmarshal      = tls_marshal,
163         .crwrap_req     = rpcauth_wrap_req_encode,
164         .crrefresh      = tls_refresh,
165         .crvalidate     = tls_validate,
166         .crunwrap_resp  = rpcauth_unwrap_resp_decode,
167 };
168
169 static struct rpc_cred tls_cred = {
170         .cr_lru         = LIST_HEAD_INIT(tls_cred.cr_lru),
171         .cr_auth        = &tls_auth,
172         .cr_ops         = &tls_credops,
173         .cr_count       = REFCOUNT_INIT(2),
174         .cr_flags       = 1UL << RPCAUTH_CRED_UPTODATE,
175 };