e8ee5901e9be57c356efce86a80889e4d5f03422
[platform/upstream/nss.git] / lib / ssl / ssl3ecc.c
1 /*
2  * SSL3 Protocol
3  *
4  * This Source Code Form is subject to the terms of the Mozilla Public
5  * License, v. 2.0. If a copy of the MPL was not distributed with this
6  * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
7
8 /* ECC code moved here from ssl3con.c */
9
10 #include "nss.h"
11 #include "cert.h"
12 #include "ssl.h"
13 #include "cryptohi.h"   /* for DSAU_ stuff */
14 #include "keyhi.h"
15 #include "secder.h"
16 #include "secitem.h"
17
18 #include "sslimpl.h"
19 #include "sslproto.h"
20 #include "sslerr.h"
21 #include "prtime.h"
22 #include "prinrval.h"
23 #include "prerror.h"
24 #include "pratom.h"
25 #include "prthread.h"
26 #include "prinit.h"
27
28 #include "pk11func.h"
29 #include "secmod.h"
30
31 #include <stdio.h>
32
33 #ifndef NSS_DISABLE_ECC
34
35 #ifndef PK11_SETATTRS
36 #define PK11_SETATTRS(x,id,v,l) (x)->type = (id); \
37                 (x)->pValue=(v); (x)->ulValueLen = (l);
38 #endif
39
40 #define SSL_GET_SERVER_PUBLIC_KEY(sock, type) \
41     (ss->serverCerts[type].serverKeyPair ? \
42     ss->serverCerts[type].serverKeyPair->pubKey : NULL)
43
44 #define SSL_IS_CURVE_NEGOTIATED(curvemsk, curveName) \
45     ((curveName > ec_noName) && \
46      (curveName < ec_pastLastName) && \
47      ((1UL << curveName) & curvemsk) != 0)
48
49
50
51 static SECStatus ssl3_CreateECDHEphemeralKeys(sslSocket *ss, ECName ec_curve);
52
53 #define supportedCurve(x) (((x) > ec_noName) && ((x) < ec_pastLastName))
54
55 /* Table containing OID tags for elliptic curves named in the
56  * ECC-TLS IETF draft.
57  */
58 static const SECOidTag ecName2OIDTag[] = {
59         0,
60         SEC_OID_SECG_EC_SECT163K1,  /*  1 */
61         SEC_OID_SECG_EC_SECT163R1,  /*  2 */
62         SEC_OID_SECG_EC_SECT163R2,  /*  3 */
63         SEC_OID_SECG_EC_SECT193R1,  /*  4 */
64         SEC_OID_SECG_EC_SECT193R2,  /*  5 */
65         SEC_OID_SECG_EC_SECT233K1,  /*  6 */
66         SEC_OID_SECG_EC_SECT233R1,  /*  7 */
67         SEC_OID_SECG_EC_SECT239K1,  /*  8 */
68         SEC_OID_SECG_EC_SECT283K1,  /*  9 */
69         SEC_OID_SECG_EC_SECT283R1,  /* 10 */
70         SEC_OID_SECG_EC_SECT409K1,  /* 11 */
71         SEC_OID_SECG_EC_SECT409R1,  /* 12 */
72         SEC_OID_SECG_EC_SECT571K1,  /* 13 */
73         SEC_OID_SECG_EC_SECT571R1,  /* 14 */
74         SEC_OID_SECG_EC_SECP160K1,  /* 15 */
75         SEC_OID_SECG_EC_SECP160R1,  /* 16 */
76         SEC_OID_SECG_EC_SECP160R2,  /* 17 */
77         SEC_OID_SECG_EC_SECP192K1,  /* 18 */
78         SEC_OID_SECG_EC_SECP192R1,  /* 19 */
79         SEC_OID_SECG_EC_SECP224K1,  /* 20 */
80         SEC_OID_SECG_EC_SECP224R1,  /* 21 */
81         SEC_OID_SECG_EC_SECP256K1,  /* 22 */
82         SEC_OID_SECG_EC_SECP256R1,  /* 23 */
83         SEC_OID_SECG_EC_SECP384R1,  /* 24 */
84         SEC_OID_SECG_EC_SECP521R1,  /* 25 */
85 };
86
87 static const PRUint16 curve2bits[] = {
88           0, /*  ec_noName     = 0,   */
89         163, /*  ec_sect163k1  = 1,   */
90         163, /*  ec_sect163r1  = 2,   */
91         163, /*  ec_sect163r2  = 3,   */
92         193, /*  ec_sect193r1  = 4,   */
93         193, /*  ec_sect193r2  = 5,   */
94         233, /*  ec_sect233k1  = 6,   */
95         233, /*  ec_sect233r1  = 7,   */
96         239, /*  ec_sect239k1  = 8,   */
97         283, /*  ec_sect283k1  = 9,   */
98         283, /*  ec_sect283r1  = 10,  */
99         409, /*  ec_sect409k1  = 11,  */
100         409, /*  ec_sect409r1  = 12,  */
101         571, /*  ec_sect571k1  = 13,  */
102         571, /*  ec_sect571r1  = 14,  */
103         160, /*  ec_secp160k1  = 15,  */
104         160, /*  ec_secp160r1  = 16,  */
105         160, /*  ec_secp160r2  = 17,  */
106         192, /*  ec_secp192k1  = 18,  */
107         192, /*  ec_secp192r1  = 19,  */
108         224, /*  ec_secp224k1  = 20,  */
109         224, /*  ec_secp224r1  = 21,  */
110         256, /*  ec_secp256k1  = 22,  */
111         256, /*  ec_secp256r1  = 23,  */
112         384, /*  ec_secp384r1  = 24,  */
113         521, /*  ec_secp521r1  = 25,  */
114       65535  /*  ec_pastLastName      */
115 };
116
117 typedef struct Bits2CurveStr {
118     PRUint16    bits;
119     ECName      curve;
120 } Bits2Curve;
121
122 static const Bits2Curve bits2curve [] = {
123    {    192,     ec_secp192r1    /*  = 19,  fast */  },
124    {    160,     ec_secp160r2    /*  = 17,  fast */  },
125    {    160,     ec_secp160k1    /*  = 15,  */       },
126    {    160,     ec_secp160r1    /*  = 16,  */       },
127    {    163,     ec_sect163k1    /*  = 1,   */       },
128    {    163,     ec_sect163r1    /*  = 2,   */       },
129    {    163,     ec_sect163r2    /*  = 3,   */       },
130    {    192,     ec_secp192k1    /*  = 18,  */       },
131    {    193,     ec_sect193r1    /*  = 4,   */       },
132    {    193,     ec_sect193r2    /*  = 5,   */       },
133    {    224,     ec_secp224r1    /*  = 21,  fast */  },
134    {    224,     ec_secp224k1    /*  = 20,  */       },
135    {    233,     ec_sect233k1    /*  = 6,   */       },
136    {    233,     ec_sect233r1    /*  = 7,   */       },
137    {    239,     ec_sect239k1    /*  = 8,   */       },
138    {    256,     ec_secp256r1    /*  = 23,  fast */  },
139    {    256,     ec_secp256k1    /*  = 22,  */       },
140    {    283,     ec_sect283k1    /*  = 9,   */       },
141    {    283,     ec_sect283r1    /*  = 10,  */       },
142    {    384,     ec_secp384r1    /*  = 24,  fast */  },
143    {    409,     ec_sect409k1    /*  = 11,  */       },
144    {    409,     ec_sect409r1    /*  = 12,  */       },
145    {    521,     ec_secp521r1    /*  = 25,  fast */  },
146    {    571,     ec_sect571k1    /*  = 13,  */       },
147    {    571,     ec_sect571r1    /*  = 14,  */       },
148    {  65535,     ec_noName    }
149 };
150
151 typedef struct ECDHEKeyPairStr {
152     ssl3KeyPair *  pair;
153     int            error;  /* error code of the call-once function */
154     PRCallOnceType once;
155 } ECDHEKeyPair;
156
157 /* arrays of ECDHE KeyPairs */
158 static ECDHEKeyPair gECDHEKeyPairs[ec_pastLastName];
159
160 SECStatus
161 ssl3_ECName2Params(PLArenaPool * arena, ECName curve, SECKEYECParams * params)
162 {
163     SECOidData *oidData = NULL;
164
165     if ((curve <= ec_noName) || (curve >= ec_pastLastName) ||
166         ((oidData = SECOID_FindOIDByTag(ecName2OIDTag[curve])) == NULL)) {
167         PORT_SetError(SEC_ERROR_UNSUPPORTED_ELLIPTIC_CURVE);
168         return SECFailure;
169     }
170
171     SECITEM_AllocItem(arena, params, (2 + oidData->oid.len));
172     /*
173      * params->data needs to contain the ASN encoding of an object ID (OID)
174      * representing the named curve. The actual OID is in
175      * oidData->oid.data so we simply prepend 0x06 and OID length
176      */
177     params->data[0] = SEC_ASN1_OBJECT_ID;
178     params->data[1] = oidData->oid.len;
179     memcpy(params->data + 2, oidData->oid.data, oidData->oid.len);
180
181     return SECSuccess;
182 }
183
184 static ECName
185 params2ecName(SECKEYECParams * params)
186 {
187     SECItem oid = { siBuffer, NULL, 0};
188     SECOidData *oidData = NULL;
189     ECName i;
190
191     /*
192      * params->data needs to contain the ASN encoding of an object ID (OID)
193      * representing a named curve. Here, we strip away everything
194      * before the actual OID and use the OID to look up a named curve.
195      */
196     if (params->data[0] != SEC_ASN1_OBJECT_ID) return ec_noName;
197     oid.len = params->len - 2;
198     oid.data = params->data + 2;
199     if ((oidData = SECOID_FindOID(&oid)) == NULL) return ec_noName;
200     for (i = ec_noName + 1; i < ec_pastLastName; i++) {
201         if (ecName2OIDTag[i] == oidData->offset)
202             return i;
203     }
204
205     return ec_noName;
206 }
207
208 /* Caller must set hiLevel error code. */
209 static SECStatus
210 ssl3_ComputeECDHKeyHash(SECOidTag hashAlg,
211                         SECItem ec_params, SECItem server_ecpoint,
212                         SSL3Random *client_rand, SSL3Random *server_rand,
213                         SSL3Hashes *hashes, PRBool bypassPKCS11)
214 {
215     PRUint8     * hashBuf;
216     PRUint8     * pBuf;
217     SECStatus     rv            = SECSuccess;
218     unsigned int  bufLen;
219     /*
220      * XXX For now, we only support named curves (the appropriate
221      * checks are made before this method is called) so ec_params
222      * takes up only two bytes. ECPoint needs to fit in 256 bytes
223      * (because the spec says the length must fit in one byte)
224      */
225     PRUint8       buf[2*SSL3_RANDOM_LENGTH + 2 + 1 + 256];
226
227     bufLen = 2*SSL3_RANDOM_LENGTH + ec_params.len + 1 + server_ecpoint.len;
228     if (bufLen <= sizeof buf) {
229         hashBuf = buf;
230     } else {
231         hashBuf = PORT_Alloc(bufLen);
232         if (!hashBuf) {
233             return SECFailure;
234         }
235     }
236
237     memcpy(hashBuf, client_rand, SSL3_RANDOM_LENGTH);
238         pBuf = hashBuf + SSL3_RANDOM_LENGTH;
239     memcpy(pBuf, server_rand, SSL3_RANDOM_LENGTH);
240         pBuf += SSL3_RANDOM_LENGTH;
241     memcpy(pBuf, ec_params.data, ec_params.len);
242         pBuf += ec_params.len;
243     pBuf[0] = (PRUint8)(server_ecpoint.len);
244     pBuf += 1;
245     memcpy(pBuf, server_ecpoint.data, server_ecpoint.len);
246         pBuf += server_ecpoint.len;
247     PORT_Assert((unsigned int)(pBuf - hashBuf) == bufLen);
248
249     rv = ssl3_ComputeCommonKeyHash(hashAlg, hashBuf, bufLen, hashes,
250                                    bypassPKCS11);
251
252     PRINT_BUF(95, (NULL, "ECDHkey hash: ", hashBuf, bufLen));
253     PRINT_BUF(95, (NULL, "ECDHkey hash: MD5 result",
254               hashes->u.s.md5, MD5_LENGTH));
255     PRINT_BUF(95, (NULL, "ECDHkey hash: SHA1 result",
256               hashes->u.s.sha, SHA1_LENGTH));
257
258     if (hashBuf != buf)
259         PORT_Free(hashBuf);
260     return rv;
261 }
262
263
264 /* Called from ssl3_SendClientKeyExchange(). */
265 SECStatus
266 ssl3_SendECDHClientKeyExchange(sslSocket * ss, SECKEYPublicKey * svrPubKey)
267 {
268     PK11SymKey *        pms             = NULL;
269     SECStatus           rv              = SECFailure;
270     PRBool              isTLS, isTLS12;
271     CK_MECHANISM_TYPE   target;
272     SECKEYPublicKey     *pubKey = NULL;         /* Ephemeral ECDH key */
273     SECKEYPrivateKey    *privKey = NULL;        /* Ephemeral ECDH key */
274
275     PORT_Assert( ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss) );
276     PORT_Assert( ss->opt.noLocks || ssl_HaveXmitBufLock(ss));
277
278     isTLS = (PRBool)(ss->ssl3.pwSpec->version > SSL_LIBRARY_VERSION_3_0);
279     isTLS12 = (PRBool)(ss->ssl3.pwSpec->version >= SSL_LIBRARY_VERSION_TLS_1_2);
280
281     /* Generate ephemeral EC keypair */
282     if (svrPubKey->keyType != ecKey) {
283         PORT_SetError(SEC_ERROR_BAD_KEY);
284         goto loser;
285     }
286     /* XXX SHOULD CALL ssl3_CreateECDHEphemeralKeys here, instead! */
287     privKey = SECKEY_CreateECPrivateKey(&svrPubKey->u.ec.DEREncodedParams,
288                                         &pubKey, ss->pkcs11PinArg);
289     if (!privKey || !pubKey) {
290             ssl_MapLowLevelError(SEC_ERROR_KEYGEN_FAIL);
291             rv = SECFailure;
292             goto loser;
293     }
294     PRINT_BUF(50, (ss, "ECDH public value:",
295                                         pubKey->u.ec.publicValue.data,
296                                         pubKey->u.ec.publicValue.len));
297
298     if (isTLS12) {
299         target = CKM_NSS_TLS_MASTER_KEY_DERIVE_DH_SHA256;
300     } else if (isTLS) {
301         target = CKM_TLS_MASTER_KEY_DERIVE_DH;
302     } else {
303         target = CKM_SSL3_MASTER_KEY_DERIVE_DH;
304     }
305
306     /*  Determine the PMS */
307     pms = PK11_PubDeriveWithKDF(privKey, svrPubKey, PR_FALSE, NULL, NULL,
308                             CKM_ECDH1_DERIVE, target, CKA_DERIVE, 0,
309                             CKD_NULL, NULL, NULL);
310
311     if (pms == NULL) {
312         SSL3AlertDescription desc  = illegal_parameter;
313         (void)SSL3_SendAlert(ss, alert_fatal, desc);
314         ssl_MapLowLevelError(SSL_ERROR_CLIENT_KEY_EXCHANGE_FAILURE);
315         goto loser;
316     }
317
318     SECKEY_DestroyPrivateKey(privKey);
319     privKey = NULL;
320
321     rv = ssl3_InitPendingCipherSpec(ss,  pms);
322     PK11_FreeSymKey(pms); pms = NULL;
323
324     if (rv != SECSuccess) {
325         ssl_MapLowLevelError(SSL_ERROR_CLIENT_KEY_EXCHANGE_FAILURE);
326         goto loser;
327     }
328
329     rv = ssl3_AppendHandshakeHeader(ss, client_key_exchange,
330                                         pubKey->u.ec.publicValue.len + 1);
331     if (rv != SECSuccess) {
332         goto loser;     /* err set by ssl3_AppendHandshake* */
333     }
334
335     rv = ssl3_AppendHandshakeVariable(ss,
336                                         pubKey->u.ec.publicValue.data,
337                                         pubKey->u.ec.publicValue.len, 1);
338     SECKEY_DestroyPublicKey(pubKey);
339     pubKey = NULL;
340
341     if (rv != SECSuccess) {
342         goto loser;     /* err set by ssl3_AppendHandshake* */
343     }
344
345     rv = SECSuccess;
346
347 loser:
348     if(pms) PK11_FreeSymKey(pms);
349     if(privKey) SECKEY_DestroyPrivateKey(privKey);
350     if(pubKey) SECKEY_DestroyPublicKey(pubKey);
351     return rv;
352 }
353
354
355 /*
356 ** Called from ssl3_HandleClientKeyExchange()
357 */
358 SECStatus
359 ssl3_HandleECDHClientKeyExchange(sslSocket *ss, SSL3Opaque *b,
360                                      PRUint32 length,
361                                      SECKEYPublicKey *srvrPubKey,
362                                      SECKEYPrivateKey *srvrPrivKey)
363 {
364     PK11SymKey *      pms;
365     SECStatus         rv;
366     SECKEYPublicKey   clntPubKey;
367     CK_MECHANISM_TYPE   target;
368     PRBool isTLS, isTLS12;
369
370     PORT_Assert( ss->opt.noLocks || ssl_HaveRecvBufLock(ss) );
371     PORT_Assert( ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss) );
372
373     clntPubKey.keyType = ecKey;
374     clntPubKey.u.ec.DEREncodedParams.len =
375         srvrPubKey->u.ec.DEREncodedParams.len;
376     clntPubKey.u.ec.DEREncodedParams.data =
377         srvrPubKey->u.ec.DEREncodedParams.data;
378
379     rv = ssl3_ConsumeHandshakeVariable(ss, &clntPubKey.u.ec.publicValue,
380                                        1, &b, &length);
381     if (rv != SECSuccess) {
382         SEND_ALERT
383         return SECFailure;      /* XXX Who sets the error code?? */
384     }
385
386     isTLS = (PRBool)(ss->ssl3.prSpec->version > SSL_LIBRARY_VERSION_3_0);
387     isTLS12 = (PRBool)(ss->ssl3.prSpec->version >= SSL_LIBRARY_VERSION_TLS_1_2);
388
389     if (isTLS12) {
390         target = CKM_NSS_TLS_MASTER_KEY_DERIVE_DH_SHA256;
391     } else if (isTLS) {
392         target = CKM_TLS_MASTER_KEY_DERIVE_DH;
393     } else {
394         target = CKM_SSL3_MASTER_KEY_DERIVE_DH;
395     }
396
397     /*  Determine the PMS */
398     pms = PK11_PubDeriveWithKDF(srvrPrivKey, &clntPubKey, PR_FALSE, NULL, NULL,
399                             CKM_ECDH1_DERIVE, target, CKA_DERIVE, 0,
400                             CKD_NULL, NULL, NULL);
401
402     if (pms == NULL) {
403         /* last gasp.  */
404         ssl_MapLowLevelError(SSL_ERROR_CLIENT_KEY_EXCHANGE_FAILURE);
405         return SECFailure;
406     }
407
408     rv = ssl3_InitPendingCipherSpec(ss,  pms);
409     PK11_FreeSymKey(pms);
410     if (rv != SECSuccess) {
411         SEND_ALERT
412         return SECFailure; /* error code set by ssl3_InitPendingCipherSpec */
413     }
414     return SECSuccess;
415 }
416
417 ECName
418 ssl3_GetCurveWithECKeyStrength(PRUint32 curvemsk, int requiredECCbits)
419 {
420     int    i;
421
422     for ( i = 0; bits2curve[i].curve != ec_noName; i++) {
423         if (bits2curve[i].bits < requiredECCbits)
424             continue;
425         if (SSL_IS_CURVE_NEGOTIATED(curvemsk, bits2curve[i].curve)) {
426             return bits2curve[i].curve;
427         }
428     }
429     PORT_SetError(SSL_ERROR_NO_CYPHER_OVERLAP);
430     return ec_noName;
431 }
432
433 /* find the "weakest link".  Get strength of signature key and of sym key.
434  * choose curve for the weakest of those two.
435  */
436 ECName
437 ssl3_GetCurveNameForServerSocket(sslSocket *ss)
438 {
439     SECKEYPublicKey * svrPublicKey = NULL;
440     ECName ec_curve = ec_noName;
441     int    signatureKeyStrength = 521;
442     int    requiredECCbits = ss->sec.secretKeyBits * 2;
443
444     if (ss->ssl3.hs.kea_def->kea == kea_ecdhe_ecdsa) {
445         svrPublicKey = SSL_GET_SERVER_PUBLIC_KEY(ss, kt_ecdh);
446         if (svrPublicKey)
447             ec_curve = params2ecName(&svrPublicKey->u.ec.DEREncodedParams);
448         if (!SSL_IS_CURVE_NEGOTIATED(ss->ssl3.hs.negotiatedECCurves, ec_curve)) {
449             PORT_SetError(SSL_ERROR_NO_CYPHER_OVERLAP);
450             return ec_noName;
451         }
452         signatureKeyStrength = curve2bits[ ec_curve ];
453     } else {
454         /* RSA is our signing cert */
455         int serverKeyStrengthInBits;
456
457         svrPublicKey = SSL_GET_SERVER_PUBLIC_KEY(ss, kt_rsa);
458         if (!svrPublicKey) {
459             PORT_SetError(SSL_ERROR_NO_CYPHER_OVERLAP);
460             return ec_noName;
461         }
462
463         /* currently strength in bytes */
464         serverKeyStrengthInBits = svrPublicKey->u.rsa.modulus.len;
465         if (svrPublicKey->u.rsa.modulus.data[0] == 0) {
466             serverKeyStrengthInBits--;
467         }
468         /* convert to strength in bits */
469         serverKeyStrengthInBits *= BPB;
470
471         signatureKeyStrength =
472             SSL_RSASTRENGTH_TO_ECSTRENGTH(serverKeyStrengthInBits);
473     }
474     if ( requiredECCbits > signatureKeyStrength )
475          requiredECCbits = signatureKeyStrength;
476
477     return ssl3_GetCurveWithECKeyStrength(ss->ssl3.hs.negotiatedECCurves,
478                                           requiredECCbits);
479 }
480
481 /* function to clear out the lists */
482 static SECStatus
483 ssl3_ShutdownECDHECurves(void *appData, void *nssData)
484 {
485     int i;
486     ECDHEKeyPair *keyPair = &gECDHEKeyPairs[0];
487
488     for (i=0; i < ec_pastLastName; i++, keyPair++) {
489         if (keyPair->pair) {
490             ssl3_FreeKeyPair(keyPair->pair);
491         }
492     }
493     memset(gECDHEKeyPairs, 0, sizeof gECDHEKeyPairs);
494     return SECSuccess;
495 }
496
497 static PRStatus
498 ssl3_ECRegister(void)
499 {
500     SECStatus rv;
501     rv = NSS_RegisterShutdown(ssl3_ShutdownECDHECurves, gECDHEKeyPairs);
502     if (rv != SECSuccess) {
503         gECDHEKeyPairs[ec_noName].error = PORT_GetError();
504     }
505     return (PRStatus)rv;
506 }
507
508 /* CallOnce function, called once for each named curve. */
509 static PRStatus
510 ssl3_CreateECDHEphemeralKeyPair(void * arg)
511 {
512     SECKEYPrivateKey *    privKey  = NULL;
513     SECKEYPublicKey *     pubKey   = NULL;
514     ssl3KeyPair *         keyPair  = NULL;
515     ECName                ec_curve = (ECName)arg;
516     SECKEYECParams        ecParams = { siBuffer, NULL, 0 };
517
518     PORT_Assert(gECDHEKeyPairs[ec_curve].pair == NULL);
519
520     /* ok, no one has generated a global key for this curve yet, do so */
521     if (ssl3_ECName2Params(NULL, ec_curve, &ecParams) != SECSuccess) {
522         gECDHEKeyPairs[ec_curve].error = PORT_GetError();
523         return PR_FAILURE;
524     }
525
526     privKey = SECKEY_CreateECPrivateKey(&ecParams, &pubKey, NULL);
527     SECITEM_FreeItem(&ecParams, PR_FALSE);
528
529     if (!privKey || !pubKey || !(keyPair = ssl3_NewKeyPair(privKey, pubKey))) {
530         if (privKey) {
531             SECKEY_DestroyPrivateKey(privKey);
532         }
533         if (pubKey) {
534             SECKEY_DestroyPublicKey(pubKey);
535         }
536         ssl_MapLowLevelError(SEC_ERROR_KEYGEN_FAIL);
537         gECDHEKeyPairs[ec_curve].error = PORT_GetError();
538         return PR_FAILURE;
539     }
540
541     gECDHEKeyPairs[ec_curve].pair = keyPair;
542     return PR_SUCCESS;
543 }
544
545 /*
546  * Creates the ephemeral public and private ECDH keys used by
547  * server in ECDHE_RSA and ECDHE_ECDSA handshakes.
548  * For now, the elliptic curve is chosen to be the same
549  * strength as the signing certificate (ECC or RSA).
550  * We need an API to specify the curve. This won't be a real
551  * issue until we further develop server-side support for ECC
552  * cipher suites.
553  */
554 static SECStatus
555 ssl3_CreateECDHEphemeralKeys(sslSocket *ss, ECName ec_curve)
556 {
557     ssl3KeyPair *         keyPair        = NULL;
558
559     /* if there's no global key for this curve, make one. */
560     if (gECDHEKeyPairs[ec_curve].pair == NULL) {
561         PRStatus status;
562
563         status = PR_CallOnce(&gECDHEKeyPairs[ec_noName].once, ssl3_ECRegister);
564         if (status != PR_SUCCESS) {
565             PORT_SetError(gECDHEKeyPairs[ec_noName].error);
566             return SECFailure;
567         }
568         status = PR_CallOnceWithArg(&gECDHEKeyPairs[ec_curve].once,
569                                     ssl3_CreateECDHEphemeralKeyPair,
570                                     (void *)ec_curve);
571         if (status != PR_SUCCESS) {
572             PORT_SetError(gECDHEKeyPairs[ec_curve].error);
573             return SECFailure;
574         }
575     }
576
577     keyPair = gECDHEKeyPairs[ec_curve].pair;
578     PORT_Assert(keyPair != NULL);
579     if (!keyPair)
580         return SECFailure;
581     ss->ephemeralECDHKeyPair = ssl3_GetKeyPairRef(keyPair);
582
583     return SECSuccess;
584 }
585
586 SECStatus
587 ssl3_HandleECDHServerKeyExchange(sslSocket *ss, SSL3Opaque *b, PRUint32 length)
588 {
589     PLArenaPool *    arena     = NULL;
590     SECKEYPublicKey *peerKey   = NULL;
591     PRBool           isTLS, isTLS12;
592     SECStatus        rv;
593     int              errCode   = SSL_ERROR_RX_MALFORMED_SERVER_KEY_EXCH;
594     SSL3AlertDescription desc  = illegal_parameter;
595     SSL3Hashes       hashes;
596     SECItem          signature = {siBuffer, NULL, 0};
597
598     SECItem          ec_params = {siBuffer, NULL, 0};
599     SECItem          ec_point  = {siBuffer, NULL, 0};
600     unsigned char    paramBuf[3]; /* only for curve_type == named_curve */
601     SSL3SignatureAndHashAlgorithm sigAndHash;
602
603     sigAndHash.hashAlg = SEC_OID_UNKNOWN;
604
605     isTLS = (PRBool)(ss->ssl3.prSpec->version > SSL_LIBRARY_VERSION_3_0);
606     isTLS12 = (PRBool)(ss->ssl3.prSpec->version >= SSL_LIBRARY_VERSION_TLS_1_2);
607
608     /* XXX This works only for named curves, revisit this when
609      * we support generic curves.
610      */
611     ec_params.len  = sizeof paramBuf;
612     ec_params.data = paramBuf;
613     rv = ssl3_ConsumeHandshake(ss, ec_params.data, ec_params.len, &b, &length);
614     if (rv != SECSuccess) {
615         goto loser;             /* malformed. */
616     }
617
618     /* Fail if the curve is not a named curve */
619     if ((ec_params.data[0] != ec_type_named) ||
620         (ec_params.data[1] != 0) ||
621         !supportedCurve(ec_params.data[2])) {
622             errCode = SEC_ERROR_UNSUPPORTED_ELLIPTIC_CURVE;
623             desc = handshake_failure;
624             goto alert_loser;
625     }
626
627     rv = ssl3_ConsumeHandshakeVariable(ss, &ec_point, 1, &b, &length);
628     if (rv != SECSuccess) {
629         goto loser;             /* malformed. */
630     }
631     /* Fail if the ec point uses compressed representation */
632     if (ec_point.data[0] != EC_POINT_FORM_UNCOMPRESSED) {
633             errCode = SEC_ERROR_UNSUPPORTED_EC_POINT_FORM;
634             desc = handshake_failure;
635             goto alert_loser;
636     }
637
638     if (isTLS12) {
639         rv = ssl3_ConsumeSignatureAndHashAlgorithm(ss, &b, &length,
640                                                    &sigAndHash);
641         if (rv != SECSuccess) {
642             goto loser;         /* malformed or unsupported. */
643         }
644         rv = ssl3_CheckSignatureAndHashAlgorithmConsistency(
645                 &sigAndHash, ss->sec.peerCert);
646         if (rv != SECSuccess) {
647             goto loser;
648         }
649     }
650
651     rv = ssl3_ConsumeHandshakeVariable(ss, &signature, 2, &b, &length);
652     if (rv != SECSuccess) {
653         goto loser;             /* malformed. */
654     }
655
656     if (length != 0) {
657         if (isTLS)
658             desc = decode_error;
659         goto alert_loser;               /* malformed. */
660     }
661
662     PRINT_BUF(60, (NULL, "Server EC params", ec_params.data,
663         ec_params.len));
664     PRINT_BUF(60, (NULL, "Server EC point", ec_point.data, ec_point.len));
665
666     /* failures after this point are not malformed handshakes. */
667     /* TLS: send decrypt_error if signature failed. */
668     desc = isTLS ? decrypt_error : handshake_failure;
669
670     /*
671      *  check to make sure the hash is signed by right guy
672      */
673     rv = ssl3_ComputeECDHKeyHash(sigAndHash.hashAlg, ec_params, ec_point,
674                                  &ss->ssl3.hs.client_random,
675                                  &ss->ssl3.hs.server_random,
676                                  &hashes, ss->opt.bypassPKCS11);
677
678     if (rv != SECSuccess) {
679         errCode =
680             ssl_MapLowLevelError(SSL_ERROR_SERVER_KEY_EXCHANGE_FAILURE);
681         goto alert_loser;
682     }
683     rv = ssl3_VerifySignedHashes(&hashes, ss->sec.peerCert, &signature,
684                                 isTLS, ss->pkcs11PinArg);
685     if (rv != SECSuccess)  {
686         errCode =
687             ssl_MapLowLevelError(SSL_ERROR_SERVER_KEY_EXCHANGE_FAILURE);
688         goto alert_loser;
689     }
690
691     arena = PORT_NewArena(DER_DEFAULT_CHUNKSIZE);
692     if (arena == NULL) {
693         goto no_memory;
694     }
695
696     ss->sec.peerKey = peerKey = PORT_ArenaZNew(arena, SECKEYPublicKey);
697     if (peerKey == NULL) {
698         goto no_memory;
699     }
700
701     peerKey->arena                 = arena;
702     peerKey->keyType               = ecKey;
703
704     /* set up EC parameters in peerKey */
705     if (ssl3_ECName2Params(arena, ec_params.data[2],
706             &peerKey->u.ec.DEREncodedParams) != SECSuccess) {
707         /* we should never get here since we already
708          * checked that we are dealing with a supported curve
709          */
710         errCode = SEC_ERROR_UNSUPPORTED_ELLIPTIC_CURVE;
711         goto alert_loser;
712     }
713
714     /* copy publicValue in peerKey */
715     if (SECITEM_CopyItem(arena, &peerKey->u.ec.publicValue,  &ec_point))
716     {
717         PORT_FreeArena(arena, PR_FALSE);
718         goto no_memory;
719     }
720     peerKey->pkcs11Slot         = NULL;
721     peerKey->pkcs11ID           = CK_INVALID_HANDLE;
722
723     ss->sec.peerKey = peerKey;
724     ss->ssl3.hs.ws = wait_cert_request;
725
726     return SECSuccess;
727
728 alert_loser:
729     (void)SSL3_SendAlert(ss, alert_fatal, desc);
730 loser:
731     PORT_SetError( errCode );
732     return SECFailure;
733
734 no_memory:      /* no-memory error has already been set. */
735     ssl_MapLowLevelError(SSL_ERROR_SERVER_KEY_EXCHANGE_FAILURE);
736     return SECFailure;
737 }
738
739 SECStatus
740 ssl3_SendECDHServerKeyExchange(
741     sslSocket *ss,
742     const SSL3SignatureAndHashAlgorithm *sigAndHash)
743 {
744     const ssl3KEADef * kea_def     = ss->ssl3.hs.kea_def;
745     SECStatus          rv          = SECFailure;
746     int                length;
747     PRBool             isTLS, isTLS12;
748     SECItem            signed_hash = {siBuffer, NULL, 0};
749     SSL3Hashes         hashes;
750
751     SECKEYPublicKey *  ecdhePub;
752     SECItem            ec_params = {siBuffer, NULL, 0};
753     unsigned char      paramBuf[3];
754     ECName             curve;
755     SSL3KEAType        certIndex;
756
757     /* Generate ephemeral ECDH key pair and send the public key */
758     curve = ssl3_GetCurveNameForServerSocket(ss);
759     if (curve == ec_noName) {
760         goto loser;
761     }
762     rv = ssl3_CreateECDHEphemeralKeys(ss, curve);
763     if (rv != SECSuccess) {
764         goto loser;     /* err set by AppendHandshake. */
765     }
766     ecdhePub = ss->ephemeralECDHKeyPair->pubKey;
767     PORT_Assert(ecdhePub != NULL);
768     if (!ecdhePub) {
769         PORT_SetError(SSL_ERROR_SERVER_KEY_EXCHANGE_FAILURE);
770         return SECFailure;
771     }
772
773     ec_params.len  = sizeof paramBuf;
774     ec_params.data = paramBuf;
775     curve = params2ecName(&ecdhePub->u.ec.DEREncodedParams);
776     if (curve != ec_noName) {
777         ec_params.data[0] = ec_type_named;
778         ec_params.data[1] = 0x00;
779         ec_params.data[2] = curve;
780     } else {
781         PORT_SetError(SEC_ERROR_UNSUPPORTED_ELLIPTIC_CURVE);
782         goto loser;
783     }
784
785     rv = ssl3_ComputeECDHKeyHash(sigAndHash->hashAlg,
786                                  ec_params,
787                                  ecdhePub->u.ec.publicValue,
788                                  &ss->ssl3.hs.client_random,
789                                  &ss->ssl3.hs.server_random,
790                                  &hashes, ss->opt.bypassPKCS11);
791     if (rv != SECSuccess) {
792         ssl_MapLowLevelError(SSL_ERROR_SERVER_KEY_EXCHANGE_FAILURE);
793         goto loser;
794     }
795
796     isTLS = (PRBool)(ss->ssl3.pwSpec->version > SSL_LIBRARY_VERSION_3_0);
797     isTLS12 = (PRBool)(ss->ssl3.pwSpec->version >= SSL_LIBRARY_VERSION_TLS_1_2);
798
799     /* XXX SSLKEAType isn't really a good choice for
800      * indexing certificates but that's all we have
801      * for now.
802      */
803     if (kea_def->kea == kea_ecdhe_rsa)
804         certIndex = kt_rsa;
805     else /* kea_def->kea == kea_ecdhe_ecdsa */
806         certIndex = kt_ecdh;
807
808     rv = ssl3_SignHashes(&hashes, ss->serverCerts[certIndex].SERVERKEY,
809                          &signed_hash, isTLS);
810     if (rv != SECSuccess) {
811         goto loser;             /* ssl3_SignHashes has set err. */
812     }
813     if (signed_hash.data == NULL) {
814         /* how can this happen and rv == SECSuccess ?? */
815         PORT_SetError(SSL_ERROR_SERVER_KEY_EXCHANGE_FAILURE);
816         goto loser;
817     }
818
819     length = ec_params.len +
820              1 + ecdhePub->u.ec.publicValue.len +
821              (isTLS12 ? 2 : 0) + 2 + signed_hash.len;
822
823     rv = ssl3_AppendHandshakeHeader(ss, server_key_exchange, length);
824     if (rv != SECSuccess) {
825         goto loser;     /* err set by AppendHandshake. */
826     }
827
828     rv = ssl3_AppendHandshake(ss, ec_params.data, ec_params.len);
829     if (rv != SECSuccess) {
830         goto loser;     /* err set by AppendHandshake. */
831     }
832
833     rv = ssl3_AppendHandshakeVariable(ss, ecdhePub->u.ec.publicValue.data,
834                                       ecdhePub->u.ec.publicValue.len, 1);
835     if (rv != SECSuccess) {
836         goto loser;     /* err set by AppendHandshake. */
837     }
838
839     if (isTLS12) {
840         rv = ssl3_AppendSignatureAndHashAlgorithm(ss, sigAndHash);
841         if (rv != SECSuccess) {
842             goto loser;         /* err set by AppendHandshake. */
843         }
844     }
845
846     rv = ssl3_AppendHandshakeVariable(ss, signed_hash.data,
847                                       signed_hash.len, 2);
848     if (rv != SECSuccess) {
849         goto loser;     /* err set by AppendHandshake. */
850     }
851
852     PORT_Free(signed_hash.data);
853     return SECSuccess;
854
855 loser:
856     if (signed_hash.data != NULL)
857         PORT_Free(signed_hash.data);
858     return SECFailure;
859 }
860
861 /* Lists of ECC cipher suites for searching and disabling. */
862
863 static const ssl3CipherSuite ecdh_suites[] = {
864     TLS_ECDH_ECDSA_WITH_3DES_EDE_CBC_SHA,
865     TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA,
866     TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA,
867     TLS_ECDH_ECDSA_WITH_NULL_SHA,
868     TLS_ECDH_ECDSA_WITH_RC4_128_SHA,
869     TLS_ECDH_RSA_WITH_3DES_EDE_CBC_SHA,
870     TLS_ECDH_RSA_WITH_AES_128_CBC_SHA,
871     TLS_ECDH_RSA_WITH_AES_256_CBC_SHA,
872     TLS_ECDH_RSA_WITH_NULL_SHA,
873     TLS_ECDH_RSA_WITH_RC4_128_SHA,
874     0 /* end of list marker */
875 };
876
877 static const ssl3CipherSuite ecdh_ecdsa_suites[] = {
878     TLS_ECDH_ECDSA_WITH_3DES_EDE_CBC_SHA,
879     TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA,
880     TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA,
881     TLS_ECDH_ECDSA_WITH_NULL_SHA,
882     TLS_ECDH_ECDSA_WITH_RC4_128_SHA,
883     0 /* end of list marker */
884 };
885
886 static const ssl3CipherSuite ecdh_rsa_suites[] = {
887     TLS_ECDH_RSA_WITH_3DES_EDE_CBC_SHA,
888     TLS_ECDH_RSA_WITH_AES_128_CBC_SHA,
889     TLS_ECDH_RSA_WITH_AES_256_CBC_SHA,
890     TLS_ECDH_RSA_WITH_NULL_SHA,
891     TLS_ECDH_RSA_WITH_RC4_128_SHA,
892     0 /* end of list marker */
893 };
894
895 static const ssl3CipherSuite ecdhe_ecdsa_suites[] = {
896     TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA,
897     TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
898     TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,
899     TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
900     TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
901     TLS_ECDHE_ECDSA_WITH_NULL_SHA,
902     TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
903     0 /* end of list marker */
904 };
905
906 static const ssl3CipherSuite ecdhe_rsa_suites[] = {
907     TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
908     TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
909     TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
910     TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
911     TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
912     TLS_ECDHE_RSA_WITH_NULL_SHA,
913     TLS_ECDHE_RSA_WITH_RC4_128_SHA,
914     0 /* end of list marker */
915 };
916
917 /* List of all ECC cipher suites */
918 static const ssl3CipherSuite ecSuites[] = {
919     TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA,
920     TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
921     TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,
922     TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
923     TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
924     TLS_ECDHE_ECDSA_WITH_NULL_SHA,
925     TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
926     TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
927     TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
928     TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
929     TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
930     TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
931     TLS_ECDHE_RSA_WITH_NULL_SHA,
932     TLS_ECDHE_RSA_WITH_RC4_128_SHA,
933     TLS_ECDH_ECDSA_WITH_3DES_EDE_CBC_SHA,
934     TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA,
935     TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA,
936     TLS_ECDH_ECDSA_WITH_NULL_SHA,
937     TLS_ECDH_ECDSA_WITH_RC4_128_SHA,
938     TLS_ECDH_RSA_WITH_3DES_EDE_CBC_SHA,
939     TLS_ECDH_RSA_WITH_AES_128_CBC_SHA,
940     TLS_ECDH_RSA_WITH_AES_256_CBC_SHA,
941     TLS_ECDH_RSA_WITH_NULL_SHA,
942     TLS_ECDH_RSA_WITH_RC4_128_SHA,
943     0 /* end of list marker */
944 };
945
946 /* On this socket, Disable the ECC cipher suites in the argument's list */
947 SECStatus
948 ssl3_DisableECCSuites(sslSocket * ss, const ssl3CipherSuite * suite)
949 {
950     if (!suite)
951         suite = ecSuites;
952     for (; *suite; ++suite) {
953         SECStatus rv      = ssl3_CipherPrefSet(ss, *suite, PR_FALSE);
954
955         PORT_Assert(rv == SECSuccess); /* else is coding error */
956     }
957     return SECSuccess;
958 }
959
960 /* Look at the server certs configured on this socket, and disable any
961  * ECC cipher suites that are not supported by those certs.
962  */
963 void
964 ssl3_FilterECCipherSuitesByServerCerts(sslSocket * ss)
965 {
966     CERTCertificate * svrCert;
967
968     svrCert = ss->serverCerts[kt_rsa].serverCert;
969     if (!svrCert) {
970         ssl3_DisableECCSuites(ss, ecdhe_rsa_suites);
971     }
972
973     svrCert = ss->serverCerts[kt_ecdh].serverCert;
974     if (!svrCert) {
975         ssl3_DisableECCSuites(ss, ecdh_suites);
976         ssl3_DisableECCSuites(ss, ecdhe_ecdsa_suites);
977     } else {
978         SECOidTag sigTag = SECOID_GetAlgorithmTag(&svrCert->signature);
979
980         switch (sigTag) {
981         case SEC_OID_PKCS1_RSA_ENCRYPTION:
982         case SEC_OID_PKCS1_MD2_WITH_RSA_ENCRYPTION:
983         case SEC_OID_PKCS1_MD4_WITH_RSA_ENCRYPTION:
984         case SEC_OID_PKCS1_MD5_WITH_RSA_ENCRYPTION:
985         case SEC_OID_PKCS1_SHA1_WITH_RSA_ENCRYPTION:
986         case SEC_OID_PKCS1_SHA224_WITH_RSA_ENCRYPTION:
987         case SEC_OID_PKCS1_SHA256_WITH_RSA_ENCRYPTION:
988         case SEC_OID_PKCS1_SHA384_WITH_RSA_ENCRYPTION:
989         case SEC_OID_PKCS1_SHA512_WITH_RSA_ENCRYPTION:
990             ssl3_DisableECCSuites(ss, ecdh_ecdsa_suites);
991             break;
992         case SEC_OID_ANSIX962_ECDSA_SHA1_SIGNATURE:
993         case SEC_OID_ANSIX962_ECDSA_SHA224_SIGNATURE:
994         case SEC_OID_ANSIX962_ECDSA_SHA256_SIGNATURE:
995         case SEC_OID_ANSIX962_ECDSA_SHA384_SIGNATURE:
996         case SEC_OID_ANSIX962_ECDSA_SHA512_SIGNATURE:
997         case SEC_OID_ANSIX962_ECDSA_SIGNATURE_RECOMMENDED_DIGEST:
998         case SEC_OID_ANSIX962_ECDSA_SIGNATURE_SPECIFIED_DIGEST:
999             ssl3_DisableECCSuites(ss, ecdh_rsa_suites);
1000             break;
1001         default:
1002             ssl3_DisableECCSuites(ss, ecdh_suites);
1003             break;
1004         }
1005     }
1006 }
1007
1008 /* Ask: is ANY ECC cipher suite enabled on this socket? */
1009 /* Order(N^2).  Yuk.  Also, this ignores export policy. */
1010 PRBool
1011 ssl3_IsECCEnabled(sslSocket * ss)
1012 {
1013     const ssl3CipherSuite * suite;
1014     PK11SlotInfo *slot;
1015
1016     /* make sure we can do ECC */
1017     slot = PK11_GetBestSlot(CKM_ECDH1_DERIVE,  ss->pkcs11PinArg);
1018     if (!slot) {
1019         return PR_FALSE;
1020     }
1021     PK11_FreeSlot(slot);
1022
1023     /* make sure an ECC cipher is enabled */
1024     for (suite = ecSuites; *suite; ++suite) {
1025         PRBool    enabled = PR_FALSE;
1026         SECStatus rv      = ssl3_CipherPrefGet(ss, *suite, &enabled);
1027
1028         PORT_Assert(rv == SECSuccess); /* else is coding error */
1029         if (rv == SECSuccess && enabled)
1030             return PR_TRUE;
1031     }
1032     return PR_FALSE;
1033 }
1034
1035 #define BE(n) 0, n
1036
1037 /* Prefabricated TLS client hello extension, Elliptic Curves List,
1038  * offers only 3 curves, the Suite B curves, 23-25
1039  */
1040 static const PRUint8 suiteBECList[12] = {
1041     BE(10),         /* Extension type */
1042     BE( 8),         /* octets that follow ( 3 pairs + 1 length pair) */
1043     BE( 6),         /* octets that follow ( 3 pairs) */
1044     BE(23), BE(24), BE(25)
1045 };
1046
1047 /* Prefabricated TLS client hello extension, Elliptic Curves List,
1048  * offers curves 1-25.
1049  */
1050 static const PRUint8 tlsECList[56] = {
1051     BE(10),         /* Extension type */
1052     BE(52),         /* octets that follow (25 pairs + 1 length pair) */
1053     BE(50),         /* octets that follow (25 pairs) */
1054             BE( 1), BE( 2), BE( 3), BE( 4), BE( 5), BE( 6), BE( 7),
1055     BE( 8), BE( 9), BE(10), BE(11), BE(12), BE(13), BE(14), BE(15),
1056     BE(16), BE(17), BE(18), BE(19), BE(20), BE(21), BE(22), BE(23),
1057     BE(24), BE(25)
1058 };
1059
1060 static const PRUint8 ecPtFmt[6] = {
1061     BE(11),         /* Extension type */
1062     BE( 2),         /* octets that follow */
1063              1,     /* octets that follow */
1064                  0  /* uncompressed type only */
1065 };
1066
1067 /* This function already presumes we can do ECC, ssl3_IsECCEnabled must be
1068  * called before this function. It looks to see if we have a token which
1069  * is capable of doing smaller than SuiteB curves. If the token can, we
1070  * presume the token can do the whole SSL suite of curves. If it can't we
1071  * presume the token that allowed ECC to be enabled can only do suite B
1072  * curves. */
1073 static PRBool
1074 ssl3_SuiteBOnly(sslSocket *ss)
1075 {
1076     /* See if we can support small curves (like 163). If not, assume we can
1077      * only support Suite-B curves (P-256, P-384, P-521). */
1078     PK11SlotInfo *slot =
1079         PK11_GetBestSlotWithAttributes(CKM_ECDH1_DERIVE, 0, 163,
1080                                        ss ? ss->pkcs11PinArg : NULL);
1081
1082     if (!slot) {
1083         /* nope, presume we can only do suite B */
1084         return PR_TRUE;
1085     }
1086     /* we can, presume we can do all curves */
1087     PK11_FreeSlot(slot);
1088     return PR_FALSE;
1089 }
1090
1091 /* Send our "canned" (precompiled) Supported Elliptic Curves extension,
1092  * which says that we support all TLS-defined named curves.
1093  */
1094 PRInt32
1095 ssl3_SendSupportedCurvesXtn(
1096                         sslSocket * ss,
1097                         PRBool      append,
1098                         PRUint32    maxBytes)
1099 {
1100     PRInt32 ecListSize = 0;
1101     const PRUint8 *ecList = NULL;
1102
1103     if (!ss || !ssl3_IsECCEnabled(ss))
1104         return 0;
1105
1106     if (ssl3_SuiteBOnly(ss)) {
1107         ecListSize = sizeof suiteBECList;
1108         ecList = suiteBECList;
1109     } else {
1110         ecListSize = sizeof tlsECList;
1111         ecList = tlsECList;
1112     }
1113
1114     if (append && maxBytes >= ecListSize) {
1115         SECStatus rv = ssl3_AppendHandshake(ss, ecList, ecListSize);
1116         if (rv != SECSuccess)
1117             return -1;
1118         if (!ss->sec.isServer) {
1119             TLSExtensionData *xtnData = &ss->xtnData;
1120             xtnData->advertised[xtnData->numAdvertised++] =
1121                 ssl_elliptic_curves_xtn;
1122         }
1123     }
1124     return ecListSize;
1125 }
1126
1127 PRUint32
1128 ssl3_GetSupportedECCurveMask(sslSocket *ss)
1129 {
1130     if (ssl3_SuiteBOnly(ss)) {
1131         return SSL3_SUITE_B_SUPPORTED_CURVES_MASK;
1132     }
1133     return SSL3_ALL_SUPPORTED_CURVES_MASK;
1134 }
1135
1136 /* Send our "canned" (precompiled) Supported Point Formats extension,
1137  * which says that we only support uncompressed points.
1138  */
1139 PRInt32
1140 ssl3_SendSupportedPointFormatsXtn(
1141                         sslSocket * ss,
1142                         PRBool      append,
1143                         PRUint32    maxBytes)
1144 {
1145     if (!ss || !ssl3_IsECCEnabled(ss))
1146         return 0;
1147     if (append && maxBytes >= (sizeof ecPtFmt)) {
1148         SECStatus rv = ssl3_AppendHandshake(ss, ecPtFmt, (sizeof ecPtFmt));
1149         if (rv != SECSuccess)
1150             return -1;
1151         if (!ss->sec.isServer) {
1152             TLSExtensionData *xtnData = &ss->xtnData;
1153             xtnData->advertised[xtnData->numAdvertised++] =
1154                 ssl_ec_point_formats_xtn;
1155         }
1156     }
1157     return (sizeof ecPtFmt);
1158 }
1159
1160 /* Just make sure that the remote client supports uncompressed points,
1161  * Since that is all we support.  Disable ECC cipher suites if it doesn't.
1162  */
1163 SECStatus
1164 ssl3_HandleSupportedPointFormatsXtn(sslSocket *ss, PRUint16 ex_type,
1165                                     SECItem *data)
1166 {
1167     int i;
1168
1169     if (data->len < 2 || data->len > 255 || !data->data ||
1170         data->len != (unsigned int)data->data[0] + 1) {
1171         /* malformed */
1172         goto loser;
1173     }
1174     for (i = data->len; --i > 0; ) {
1175         if (data->data[i] == 0) {
1176             /* indicate that we should send a reply */
1177             SECStatus rv;
1178             rv = ssl3_RegisterServerHelloExtensionSender(ss, ex_type,
1179                               &ssl3_SendSupportedPointFormatsXtn);
1180             return rv;
1181         }
1182     }
1183 loser:
1184     /* evil client doesn't support uncompressed */
1185     ssl3_DisableECCSuites(ss, ecSuites);
1186     return SECFailure;
1187 }
1188
1189
1190 #define SSL3_GET_SERVER_PUBLICKEY(sock, type) \
1191     (ss->serverCerts[type].serverKeyPair ? \
1192     ss->serverCerts[type].serverKeyPair->pubKey : NULL)
1193
1194 /* Extract the TLS curve name for the public key in our EC server cert. */
1195 ECName ssl3_GetSvrCertCurveName(sslSocket *ss)
1196 {
1197     SECKEYPublicKey       *srvPublicKey;
1198     ECName                ec_curve       = ec_noName;
1199
1200     srvPublicKey = SSL3_GET_SERVER_PUBLICKEY(ss, kt_ecdh);
1201     if (srvPublicKey) {
1202         ec_curve = params2ecName(&srvPublicKey->u.ec.DEREncodedParams);
1203     }
1204     return ec_curve;
1205 }
1206
1207 /* Ensure that the curve in our server cert is one of the ones suppored
1208  * by the remote client, and disable all ECC cipher suites if not.
1209  */
1210 SECStatus
1211 ssl3_HandleSupportedCurvesXtn(sslSocket *ss, PRUint16 ex_type, SECItem *data)
1212 {
1213     PRInt32  list_len;
1214     PRUint32 peerCurves   = 0;
1215     PRUint32 mutualCurves = 0;
1216     PRUint16 svrCertCurveName;
1217
1218     if (!data->data || data->len < 4 || data->len > 65535)
1219         goto loser;
1220     /* get the length of elliptic_curve_list */
1221     list_len = ssl3_ConsumeHandshakeNumber(ss, 2, &data->data, &data->len);
1222     if (list_len < 0 || data->len != list_len || (data->len % 2) != 0) {
1223         /* malformed */
1224         goto loser;
1225     }
1226     /* build bit vector of peer's supported curve names */
1227     while (data->len) {
1228         PRInt32  curve_name =
1229                  ssl3_ConsumeHandshakeNumber(ss, 2, &data->data, &data->len);
1230         if (curve_name > ec_noName && curve_name < ec_pastLastName) {
1231             peerCurves |= (1U << curve_name);
1232         }
1233     }
1234     /* What curves do we support in common? */
1235     mutualCurves = ss->ssl3.hs.negotiatedECCurves &= peerCurves;
1236     if (!mutualCurves) { /* no mutually supported EC Curves */
1237         goto loser;
1238     }
1239
1240     /* if our ECC cert doesn't use one of these supported curves,
1241      * disable ECC cipher suites that require an ECC cert.
1242      */
1243     svrCertCurveName = ssl3_GetSvrCertCurveName(ss);
1244     if (svrCertCurveName != ec_noName &&
1245         (mutualCurves & (1U << svrCertCurveName)) != 0) {
1246         return SECSuccess;
1247     }
1248     /* Our EC cert doesn't contain a mutually supported curve.
1249      * Disable all ECC cipher suites that require an EC cert
1250      */
1251     ssl3_DisableECCSuites(ss, ecdh_ecdsa_suites);
1252     ssl3_DisableECCSuites(ss, ecdhe_ecdsa_suites);
1253     return SECFailure;
1254
1255 loser:
1256     /* no common curve supported */
1257     ssl3_DisableECCSuites(ss, ecSuites);
1258     return SECFailure;
1259 }
1260
1261 #endif /* NSS_DISABLE_ECC */