Imported Upstream version 2.4.2
[scm/test.git] / lfsapi / ntlm.go
1 package lfsapi
2
3 import (
4         "encoding/base64"
5         "fmt"
6         "io"
7         "io/ioutil"
8         "net/http"
9         "net/url"
10         "strings"
11
12         "github.com/git-lfs/git-lfs/errors"
13 )
14
15 type ntmlCredentials struct {
16         domain   string
17         username string
18         password string
19 }
20
21 func (c *Client) doWithNTLM(req *http.Request, credHelper CredentialHelper, creds Creds, credsURL *url.URL) (*http.Response, error) {
22         res, err := c.do(req, "", nil)
23         if err != nil && !errors.IsAuthError(err) {
24                 return res, err
25         }
26
27         if res.StatusCode != 401 {
28                 return res, nil
29         }
30
31         return c.ntlmReAuth(req, credHelper, creds, true)
32 }
33
34 // If the status is 401 then we need to re-authenticate
35 func (c *Client) ntlmReAuth(req *http.Request, credHelper CredentialHelper, creds Creds, retry bool) (*http.Response, error) {
36         ntmlCreds, err := ntlmGetCredentials(creds)
37         if err != nil {
38                 return nil, err
39         }
40
41         res, err := c.ntlmAuthenticateRequest(req, ntmlCreds)
42         if err != nil && !errors.IsAuthError(err) {
43                 return res, err
44         }
45
46         switch res.StatusCode {
47         case 401:
48                 credHelper.Reject(creds)
49                 if retry {
50                         return c.ntlmReAuth(req, credHelper, creds, false)
51                 }
52         case 403:
53                 credHelper.Reject(creds)
54         default:
55                 if res.StatusCode < 300 && res.StatusCode > 199 {
56                         credHelper.Approve(creds)
57                 }
58         }
59
60         return res, nil
61 }
62
63 func (c *Client) ntlmSendType1Message(req *http.Request, message []byte) (*http.Response, []byte, error) {
64         res, err := c.ntlmSendMessage(req, message)
65         if err != nil && !errors.IsAuthError(err) {
66                 return res, nil, err
67         }
68
69         io.Copy(ioutil.Discard, res.Body)
70         res.Body.Close()
71
72         by, err := parseChallengeResponse(res)
73         return res, by, err
74 }
75
76 func (c *Client) ntlmSendType3Message(req *http.Request, authenticate []byte) (*http.Response, error) {
77         return c.ntlmSendMessage(req, authenticate)
78 }
79
80 func (c *Client) ntlmSendMessage(req *http.Request, message []byte) (*http.Response, error) {
81         body, err := rewoundRequestBody(req)
82         if err != nil {
83                 return nil, err
84         }
85         req.Body = body
86
87         msg := base64.StdEncoding.EncodeToString(message)
88         req.Header.Set("Authorization", "NTLM "+msg)
89         return c.do(req, "", nil)
90 }
91
92 func parseChallengeResponse(res *http.Response) ([]byte, error) {
93         header := res.Header.Get("Www-Authenticate")
94         if len(header) < 6 {
95                 return nil, fmt.Errorf("Invalid NTLM challenge response: %q", header)
96         }
97
98         //parse out the "NTLM " at the beginning of the response
99         challenge := header[5:]
100         val, err := base64.StdEncoding.DecodeString(challenge)
101
102         if err != nil {
103                 return nil, err
104         }
105         return []byte(val), nil
106 }
107
108 func rewoundRequestBody(req *http.Request) (io.ReadCloser, error) {
109         if req.Body == nil {
110                 return nil, nil
111         }
112
113         body, ok := req.Body.(ReadSeekCloser)
114         if !ok {
115                 return nil, fmt.Errorf("Request body must implement io.ReadCloser and io.Seeker. Got: %T", body)
116         }
117
118         _, err := body.Seek(0, io.SeekStart)
119         return body, err
120 }
121
122 func ntlmGetCredentials(creds Creds) (*ntmlCredentials, error) {
123         username := creds["username"]
124         password := creds["password"]
125
126         if username == "" && password == "" {
127                 return nil, nil
128         }
129
130         splits := strings.Split(username, "\\")
131         if len(splits) != 2 {
132                 return nil, fmt.Errorf("Your user name must be of the form DOMAIN\\user. It is currently '%s'", username)
133         }
134
135         domain := strings.ToUpper(splits[0])
136         username = splits[1]
137
138         return &ntmlCredentials{domain: domain, username: username, password: password}, nil
139 }