1 // Copyright 2012 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
5 // This file implements the Socialist Millionaires Protocol as described in
6 // http://www.cypherpunks.ca/otr/Protocol-v2-3.1.0.html. The protocol
7 // specification is required in order to understand this code and, where
8 // possible, the variable names in the code match up with the spec.
// smpFailure is a string-based error type used for Socialist Millionaires
// Protocol failures; it implements the error interface via its Error method.
20 type smpFailure string
22 func (s smpFailure) Error() string {
// smpFailureError is reported when the final SMP comparison does not match
// (processSMP3 compares against Rb^b3, processSMP4 compares Rb^a3 with PaPb).
// NOTE(review): presumably this means the two parties' secrets differ —
// confirm against the OTR v2 SMP specification.
26 var smpFailureError = smpFailure("otr: SMP protocol failed")
// smpSecretMissingError is reported by processSMP when an SMP1 message arrives
// but no mutual secret has been set yet (c.smp.secret is nil).
27 var smpSecretMissingError = smpFailure("otr: mutual secret needed")
38 type smpState struct {
40 a2, a3, b2, b3, pb, qb *big.Int
43 g3b, papb, qaqb, ra *big.Int
49 func (c *Conversation) startSMP(question string) (tlvs []tlv) {
50 if c.smp.state != smpState1 {
51 tlvs = append(tlvs, c.generateSMPAbort())
53 tlvs = append(tlvs, c.generateSMP1(question))
55 c.smp.state = smpState2
59 func (c *Conversation) resetSMP() {
60 c.smp.state = smpState1
65 func (c *Conversation) processSMP(in tlv) (out tlv, complete bool, err error) {
70 if c.smp.state != smpState1 {
75 case tlvTypeSMP1WithQuestion:
76 // We preprocess this into a SMP1 message.
77 nulPos := bytes.IndexByte(data, 0)
79 err = errors.New("otr: SMP message with question didn't contain a NUL byte")
82 c.smp.question = string(data[:nulPos])
83 data = data[nulPos+1:]
86 numMPIs, data, ok := getU32(data)
87 if !ok || numMPIs > 20 {
88 err = errors.New("otr: corrupt SMP message")
92 mpis := make([]*big.Int, numMPIs)
95 mpis[i], data, ok = getMPI(data)
97 err = errors.New("otr: corrupt SMP message")
103 case tlvTypeSMP1, tlvTypeSMP1WithQuestion:
104 if c.smp.state != smpState1 {
106 out = c.generateSMPAbort()
109 if c.smp.secret == nil {
110 err = smpSecretMissingError
113 if err = c.processSMP1(mpis); err != nil {
116 c.smp.state = smpState3
117 out = c.generateSMP2()
119 if c.smp.state != smpState2 {
121 out = c.generateSMPAbort()
124 if out, err = c.processSMP2(mpis); err != nil {
125 out = c.generateSMPAbort()
128 c.smp.state = smpState4
130 if c.smp.state != smpState3 {
132 out = c.generateSMPAbort()
135 if out, err = c.processSMP3(mpis); err != nil {
138 c.smp.state = smpState1
142 if c.smp.state != smpState4 {
144 out = c.generateSMPAbort()
147 if err = c.processSMP4(mpis); err != nil {
148 out = c.generateSMPAbort()
151 c.smp.state = smpState1
155 panic("unknown SMP message")
161 func (c *Conversation) calcSMPSecret(mutualSecret []byte, weStarted bool) {
163 h.Write([]byte{smpVersion})
165 h.Write(c.PrivateKey.PublicKey.Fingerprint())
166 h.Write(c.TheirPublicKey.Fingerprint())
168 h.Write(c.TheirPublicKey.Fingerprint())
169 h.Write(c.PrivateKey.PublicKey.Fingerprint())
172 h.Write(mutualSecret)
173 c.smp.secret = new(big.Int).SetBytes(h.Sum(nil))
176 func (c *Conversation) generateSMP1(question string) tlv {
178 c.smp.a2 = c.randMPI(randBuf[:])
179 c.smp.a3 = c.randMPI(randBuf[:])
180 g2a := new(big.Int).Exp(g, c.smp.a2, p)
181 g3a := new(big.Int).Exp(g, c.smp.a3, p)
184 r2 := c.randMPI(randBuf[:])
185 r := new(big.Int).Exp(g, r2, p)
186 c2 := new(big.Int).SetBytes(hashMPIs(h, 1, r))
187 d2 := new(big.Int).Mul(c.smp.a2, c2)
194 r3 := c.randMPI(randBuf[:])
196 c3 := new(big.Int).SetBytes(hashMPIs(h, 2, r))
197 d3 := new(big.Int).Mul(c.smp.a3, c3)
205 if len(question) > 0 {
206 ret.typ = tlvTypeSMP1WithQuestion
207 ret.data = append(ret.data, question...)
208 ret.data = append(ret.data, 0)
210 ret.typ = tlvTypeSMP1
212 ret.data = appendU32(ret.data, 6)
213 ret.data = appendMPIs(ret.data, g2a, c2, d2, g3a, c3, d3)
217 func (c *Conversation) processSMP1(mpis []*big.Int) error {
219 return errors.New("otr: incorrect number of arguments in SMP1 message")
229 r := new(big.Int).Exp(g, d2, p)
230 s := new(big.Int).Exp(g2a, c2, p)
233 t := new(big.Int).SetBytes(hashMPIs(h, 1, r))
235 return errors.New("otr: ZKP c2 incorrect in SMP1 message")
241 t.SetBytes(hashMPIs(h, 2, r))
243 return errors.New("otr: ZKP c3 incorrect in SMP1 message")
251 func (c *Conversation) generateSMP2() tlv {
253 b2 := c.randMPI(randBuf[:])
254 c.smp.b3 = c.randMPI(randBuf[:])
255 r2 := c.randMPI(randBuf[:])
256 r3 := c.randMPI(randBuf[:])
257 r4 := c.randMPI(randBuf[:])
258 r5 := c.randMPI(randBuf[:])
259 r6 := c.randMPI(randBuf[:])
261 g2b := new(big.Int).Exp(g, b2, p)
262 g3b := new(big.Int).Exp(g, c.smp.b3, p)
264 r := new(big.Int).Exp(g, r2, p)
266 c2 := new(big.Int).SetBytes(hashMPIs(h, 3, r))
267 d2 := new(big.Int).Mul(b2, c2)
275 c3 := new(big.Int).SetBytes(hashMPIs(h, 4, r))
276 d3 := new(big.Int).Mul(c.smp.b3, c3)
283 c.smp.g2 = new(big.Int).Exp(c.smp.g2a, b2, p)
284 c.smp.g3 = new(big.Int).Exp(c.smp.g3a, c.smp.b3, p)
285 c.smp.pb = new(big.Int).Exp(c.smp.g3, r4, p)
286 c.smp.qb = new(big.Int).Exp(g, r4, p)
287 r.Exp(c.smp.g2, c.smp.secret, p)
288 c.smp.qb.Mul(c.smp.qb, r)
289 c.smp.qb.Mod(c.smp.qb, p)
292 s.Exp(c.smp.g2, r6, p)
296 r.Exp(c.smp.g3, r5, p)
297 cp := new(big.Int).SetBytes(hashMPIs(h, 5, r, s))
299 // D5 = r5 - r4 cP mod q and D6 = r6 - y cP mod q
303 d5 := new(big.Int).Mod(r, q)
308 s.Mul(c.smp.secret, cp)
310 d6 := new(big.Int).Mod(r, q)
316 ret.typ = tlvTypeSMP2
317 ret.data = appendU32(ret.data, 11)
318 ret.data = appendMPIs(ret.data, g2b, c2, d2, g3b, c3, d3, c.smp.pb, c.smp.qb, cp, d5, d6)
322 func (c *Conversation) processSMP2(mpis []*big.Int) (out tlv, err error) {
324 err = errors.New("otr: incorrect number of arguments in SMP2 message")
340 r := new(big.Int).Exp(g, d2, p)
341 s := new(big.Int).Exp(g2b, c2, p)
344 s.SetBytes(hashMPIs(h, 3, r))
346 err = errors.New("otr: ZKP c2 failed in SMP2 message")
354 s.SetBytes(hashMPIs(h, 4, r))
356 err = errors.New("otr: ZKP c3 failed in SMP2 message")
360 c.smp.g2 = new(big.Int).Exp(g2b, c.smp.a2, p)
361 c.smp.g3 = new(big.Int).Exp(g3b, c.smp.a3, p)
364 s.Exp(c.smp.g2, d6, p)
370 s.Exp(c.smp.g3, d5, p)
371 t := new(big.Int).Exp(pb, cp, p)
374 t.SetBytes(hashMPIs(h, 5, s, r))
376 err = errors.New("otr: ZKP cP failed in SMP2 message")
381 r4 := c.randMPI(randBuf[:])
382 r5 := c.randMPI(randBuf[:])
383 r6 := c.randMPI(randBuf[:])
384 r7 := c.randMPI(randBuf[:])
386 pa := new(big.Int).Exp(c.smp.g3, r4, p)
387 r.Exp(c.smp.g2, c.smp.secret, p)
388 qa := new(big.Int).Exp(g, r4, p)
393 s.Exp(c.smp.g2, r6, p)
397 s.Exp(c.smp.g3, r5, p)
398 cp.SetBytes(hashMPIs(h, 6, s, r))
401 d5 = new(big.Int).Sub(r5, r)
407 r.Mul(c.smp.secret, cp)
408 d6 = new(big.Int).Sub(r6, r)
415 qaqb := new(big.Int).Mul(qa, r)
418 ra := new(big.Int).Exp(qaqb, c.smp.a3, p)
421 cr := new(big.Int).SetBytes(hashMPIs(h, 7, s, r))
424 d7 := new(big.Int).Sub(r7, r)
434 c.smp.papb = new(big.Int).Mul(pa, r)
435 c.smp.papb.Mod(c.smp.papb, p)
438 out.typ = tlvTypeSMP3
439 out.data = appendU32(out.data, 8)
440 out.data = appendMPIs(out.data, pa, qa, cp, d5, d6, ra, cr, d7)
444 func (c *Conversation) processSMP3(mpis []*big.Int) (out tlv, err error) {
446 err = errors.New("otr: incorrect number of arguments in SMP3 message")
459 r := new(big.Int).Exp(g, d5, p)
460 s := new(big.Int).Exp(c.smp.g2, d6, p)
466 s.Exp(c.smp.g3, d5, p)
467 t := new(big.Int).Exp(pa, cp, p)
470 t.SetBytes(hashMPIs(h, 6, s, r))
472 err = errors.New("otr: ZKP cP failed in SMP3 message")
476 r.ModInverse(c.smp.qb, p)
477 qaqb := new(big.Int).Mul(qa, r)
486 t.Exp(c.smp.g3a, cr, p)
489 t.SetBytes(hashMPIs(h, 7, s, r))
491 err = errors.New("otr: ZKP cR failed in SMP3 message")
496 r7 := c.randMPI(randBuf[:])
497 rb := new(big.Int).Exp(qaqb, c.smp.b3, p)
501 cr = new(big.Int).SetBytes(hashMPIs(h, 8, s, r))
504 d7 = new(big.Int).Sub(r7, r)
510 out.typ = tlvTypeSMP4
511 out.data = appendU32(out.data, 3)
512 out.data = appendMPIs(out.data, rb, cr, d7)
514 r.ModInverse(c.smp.pb, p)
517 s.Exp(ra, c.smp.b3, p)
519 err = smpFailureError
525 func (c *Conversation) processSMP4(mpis []*big.Int) error {
527 return errors.New("otr: incorrect number of arguments in SMP4 message")
534 r := new(big.Int).Exp(c.smp.qaqb, d7, p)
535 s := new(big.Int).Exp(rb, cr, p)
540 t := new(big.Int).Exp(c.smp.g3b, cr, p)
543 t.SetBytes(hashMPIs(h, 8, s, r))
545 return errors.New("otr: ZKP cR failed in SMP4 message")
548 r.Exp(rb, c.smp.a3, p)
549 if r.Cmp(c.smp.papb) != 0 {
550 return smpFailureError
556 func (c *Conversation) generateSMPAbort() tlv {
557 return tlv{typ: tlvTypeSMPAbort}
560 func hashMPIs(h hash.Hash, magic byte, mpis ...*big.Int) []byte {
567 h.Write([]byte{magic})
568 for _, mpi := range mpis {
569 h.Write(appendMPI(nil, mpi))