Imported Upstream version 2.4.2
[scm/test.git] / lfsapi / client_test.go
1 package lfsapi
2
3 import (
4         "encoding/base64"
5         "encoding/json"
6         "fmt"
7         "net"
8         "net/http"
9         "net/http/httptest"
10         "strings"
11         "sync/atomic"
12         "testing"
13
14         "github.com/stretchr/testify/assert"
15         "github.com/stretchr/testify/require"
16 )
17
// redirectTest is the small JSON payload attached to requests in
// TestClientRedirect. The test servers decode it and compare Test against
// the label of the redirect scenario they expect to serve.
type redirectTest struct {
	Test string `json:"Test"`
}
21
22 func TestClientRedirect(t *testing.T) {
23         var srv3Https, srv3Http string
24
25         var called1 uint32
26         var called2 uint32
27         var called3 uint32
28         srv3 := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
29                 atomic.AddUint32(&called3, 1)
30                 t.Logf("srv3 req %s %s", r.Method, r.URL.Path)
31                 assert.Equal(t, "POST", r.Method)
32
33                 switch r.URL.Path {
34                 case "/upgrade":
35                         assert.Equal(t, "auth", r.Header.Get("Authorization"))
36                         assert.Equal(t, "1", r.Header.Get("A"))
37                         w.Header().Set("Location", srv3Https+"/upgraded")
38                         w.WriteHeader(301)
39                 case "/upgraded":
40                         // Since srv3 listens on both a TLS-enabled socket and a
41                         // TLS-disabled one, they are two different hosts.
42                         // Ensure that, even though this is a "secure" upgrade,
43                         // the authorization header is stripped.
44                         assert.Equal(t, "", r.Header.Get("Authorization"))
45                         assert.Equal(t, "1", r.Header.Get("A"))
46
47                 case "/downgrade":
48                         assert.Equal(t, "auth", r.Header.Get("Authorization"))
49                         assert.Equal(t, "1", r.Header.Get("A"))
50                         w.Header().Set("Location", srv3Http+"/404")
51                         w.WriteHeader(301)
52
53                 default:
54                         w.WriteHeader(404)
55                 }
56         }))
57
58         srv2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
59                 atomic.AddUint32(&called2, 1)
60                 t.Logf("srv2 req %s %s", r.Method, r.URL.Path)
61                 assert.Equal(t, "POST", r.Method)
62
63                 switch r.URL.Path {
64                 case "/ok":
65                         assert.Equal(t, "", r.Header.Get("Authorization"))
66                         assert.Equal(t, "1", r.Header.Get("A"))
67                         body := &redirectTest{}
68                         err := json.NewDecoder(r.Body).Decode(body)
69                         assert.Nil(t, err)
70                         assert.Equal(t, "External", body.Test)
71
72                         w.WriteHeader(200)
73                 default:
74                         w.WriteHeader(404)
75                 }
76         }))
77
78         srv1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
79                 atomic.AddUint32(&called1, 1)
80                 t.Logf("srv1 req %s %s", r.Method, r.URL.Path)
81                 assert.Equal(t, "POST", r.Method)
82
83                 switch r.URL.Path {
84                 case "/local":
85                         w.Header().Set("Location", "/ok")
86                         w.WriteHeader(307)
87                 case "/external":
88                         w.Header().Set("Location", srv2.URL+"/ok")
89                         w.WriteHeader(307)
90                 case "/ok":
91                         assert.Equal(t, "auth", r.Header.Get("Authorization"))
92                         assert.Equal(t, "1", r.Header.Get("A"))
93                         body := &redirectTest{}
94                         err := json.NewDecoder(r.Body).Decode(body)
95                         assert.Nil(t, err)
96                         assert.Equal(t, "Local", body.Test)
97
98                         w.WriteHeader(200)
99                 default:
100                         w.WriteHeader(404)
101                 }
102         }))
103         defer srv1.Close()
104         defer srv2.Close()
105         defer srv3.Close()
106
107         srv3InsecureListener, err := net.Listen("tcp", "127.0.0.1:0")
108         require.Nil(t, err)
109
110         go http.Serve(srv3InsecureListener, srv3.Config.Handler)
111         defer srv3InsecureListener.Close()
112
113         srv3Https = srv3.URL
114         srv3Http = fmt.Sprintf("http://%s", srv3InsecureListener.Addr().String())
115
116         c, err := NewClient(NewContext(nil, nil, map[string]string{
117                 fmt.Sprintf("http.%s.sslverify", srv3Https):  "false",
118                 fmt.Sprintf("http.%s/.sslverify", srv3Https): "false",
119                 fmt.Sprintf("http.%s.sslverify", srv3Http):   "false",
120                 fmt.Sprintf("http.%s/.sslverify", srv3Http):  "false",
121                 fmt.Sprintf("http.sslverify"):                "false",
122         }))
123         require.Nil(t, err)
124
125         // local redirect
126         req, err := http.NewRequest("POST", srv1.URL+"/local", nil)
127         require.Nil(t, err)
128         req.Header.Set("Authorization", "auth")
129         req.Header.Set("A", "1")
130
131         require.Nil(t, MarshalToRequest(req, &redirectTest{Test: "Local"}))
132
133         res, err := c.Do(req)
134         require.Nil(t, err)
135         assert.Equal(t, 200, res.StatusCode)
136         assert.EqualValues(t, 2, called1)
137         assert.EqualValues(t, 0, called2)
138
139         // external redirect
140         req, err = http.NewRequest("POST", srv1.URL+"/external", nil)
141         require.Nil(t, err)
142         req.Header.Set("Authorization", "auth")
143         req.Header.Set("A", "1")
144
145         require.Nil(t, MarshalToRequest(req, &redirectTest{Test: "External"}))
146
147         res, err = c.Do(req)
148         require.Nil(t, err)
149         assert.Equal(t, 200, res.StatusCode)
150         assert.EqualValues(t, 3, called1)
151         assert.EqualValues(t, 1, called2)
152
153         // http -> https (secure upgrade)
154
155         req, err = http.NewRequest("POST", srv3Http+"/upgrade", nil)
156         require.Nil(t, err)
157         req.Header.Set("Authorization", "auth")
158         req.Header.Set("A", "1")
159
160         require.Nil(t, MarshalToRequest(req, &redirectTest{Test: "http->https"}))
161
162         res, err = c.Do(req)
163         require.Nil(t, err)
164         assert.Equal(t, 200, res.StatusCode)
165         assert.EqualValues(t, 2, atomic.LoadUint32(&called3))
166
167         // https -> http (insecure downgrade)
168
169         req, err = http.NewRequest("POST", srv3Https+"/downgrade", nil)
170         require.Nil(t, err)
171         req.Header.Set("Authorization", "auth")
172         req.Header.Set("A", "1")
173
174         require.Nil(t, MarshalToRequest(req, &redirectTest{Test: "https->http"}))
175
176         _, err = c.Do(req)
177         assert.EqualError(t, err, "lfsapi/client: refusing insecure redirect, https->http")
178 }
179
180 func TestClientRedirectReauthenticate(t *testing.T) {
181         var srv1, srv2 *httptest.Server
182         var called1, called2 uint32
183         var creds1, creds2 Creds
184
185         srv1 = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
186                 atomic.AddUint32(&called1, 1)
187
188                 if hdr := r.Header.Get("Authorization"); len(hdr) > 0 {
189                         parts := strings.SplitN(hdr, " ", 2)
190                         typ, b64 := parts[0], parts[1]
191
192                         auth, err := base64.URLEncoding.DecodeString(b64)
193                         assert.Nil(t, err)
194                         assert.Equal(t, "Basic", typ)
195                         assert.Equal(t, "user1:pass1", string(auth))
196
197                         http.Redirect(w, r, srv2.URL+r.URL.Path, http.StatusMovedPermanently)
198                         return
199                 }
200                 w.WriteHeader(http.StatusUnauthorized)
201         }))
202
203         srv2 = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
204                 atomic.AddUint32(&called2, 1)
205
206                 parts := strings.SplitN(r.Header.Get("Authorization"), " ", 2)
207                 typ, b64 := parts[0], parts[1]
208
209                 auth, err := base64.URLEncoding.DecodeString(b64)
210                 assert.Nil(t, err)
211                 assert.Equal(t, "Basic", typ)
212                 assert.Equal(t, "user2:pass2", string(auth))
213         }))
214
215         // Change the URL of srv2 to make it appears as if it is a different
216         // host.
217         srv2.URL = strings.Replace(srv2.URL, "127.0.0.1", "0.0.0.0", 1)
218
219         creds1 = Creds(map[string]string{
220                 "protocol": "http",
221                 "host":     strings.TrimPrefix(srv1.URL, "http://"),
222
223                 "username": "user1",
224                 "password": "pass1",
225         })
226         creds2 = Creds(map[string]string{
227                 "protocol": "http",
228                 "host":     strings.TrimPrefix(srv2.URL, "http://"),
229
230                 "username": "user2",
231                 "password": "pass2",
232         })
233
234         defer srv1.Close()
235         defer srv2.Close()
236
237         c, err := NewClient(NewContext(nil, nil, nil))
238         creds := newCredentialCacher()
239         creds.Approve(creds1)
240         creds.Approve(creds2)
241         c.Credentials = creds
242
243         req, err := http.NewRequest("GET", srv1.URL, nil)
244         require.Nil(t, err)
245
246         _, err = c.DoWithAuth("", req)
247         assert.Nil(t, err)
248
249         // called1 is 2 since LFS tries an unauthenticated request first
250         assert.EqualValues(t, 2, called1)
251         assert.EqualValues(t, 1, called2)
252 }
253
254 func TestNewClient(t *testing.T) {
255         c, err := NewClient(NewContext(nil, nil, map[string]string{
256                 "lfs.dialtimeout":         "151",
257                 "lfs.keepalive":           "152",
258                 "lfs.tlstimeout":          "153",
259                 "lfs.concurrenttransfers": "154",
260         }))
261
262         require.Nil(t, err)
263         assert.Equal(t, 151, c.DialTimeout)
264         assert.Equal(t, 152, c.KeepaliveTimeout)
265         assert.Equal(t, 153, c.TLSTimeout)
266         assert.Equal(t, 154, c.ConcurrentTransfers)
267 }
268
269 func TestNewClientWithGitSSLVerify(t *testing.T) {
270         c, err := NewClient(nil)
271         assert.Nil(t, err)
272         assert.False(t, c.SkipSSLVerify)
273
274         for _, value := range []string{"true", "1", "t"} {
275                 c, err = NewClient(NewContext(nil, nil, map[string]string{
276                         "http.sslverify": value,
277                 }))
278                 t.Logf("http.sslverify: %q", value)
279                 assert.Nil(t, err)
280                 assert.False(t, c.SkipSSLVerify)
281         }
282
283         for _, value := range []string{"false", "0", "f"} {
284                 c, err = NewClient(NewContext(nil, nil, map[string]string{
285                         "http.sslverify": value,
286                 }))
287                 t.Logf("http.sslverify: %q", value)
288                 assert.Nil(t, err)
289                 assert.True(t, c.SkipSSLVerify)
290         }
291 }
292
293 func TestNewClientWithOSSSLVerify(t *testing.T) {
294         c, err := NewClient(nil)
295         assert.Nil(t, err)
296         assert.False(t, c.SkipSSLVerify)
297
298         for _, value := range []string{"false", "0", "f"} {
299                 c, err = NewClient(NewContext(nil, map[string]string{
300                         "GIT_SSL_NO_VERIFY": value,
301                 }, nil))
302                 t.Logf("GIT_SSL_NO_VERIFY: %q", value)
303                 assert.Nil(t, err)
304                 assert.False(t, c.SkipSSLVerify)
305         }
306
307         for _, value := range []string{"true", "1", "t"} {
308                 c, err = NewClient(NewContext(nil, map[string]string{
309                         "GIT_SSL_NO_VERIFY": value,
310                 }, nil))
311                 t.Logf("GIT_SSL_NO_VERIFY: %q", value)
312                 assert.Nil(t, err)
313                 assert.True(t, c.SkipSSLVerify)
314         }
315 }
316
317 func TestNewRequest(t *testing.T) {
318         tests := [][]string{
319                 {"https://example.com", "a", "https://example.com/a"},
320                 {"https://example.com/", "a", "https://example.com/a"},
321                 {"https://example.com/a", "b", "https://example.com/a/b"},
322                 {"https://example.com/a/", "b", "https://example.com/a/b"},
323         }
324
325         for _, test := range tests {
326                 c, err := NewClient(NewContext(nil, nil, map[string]string{
327                         "lfs.url": test[0],
328                 }))
329                 require.Nil(t, err)
330
331                 req, err := c.NewRequest("POST", c.Endpoints.Endpoint("", ""), test[1], nil)
332                 require.Nil(t, err)
333                 assert.Equal(t, "POST", req.Method)
334                 assert.Equal(t, test[2], req.URL.String(), fmt.Sprintf("endpoint: %s, suffix: %s, expected: %s", test[0], test[1], test[2]))
335         }
336 }
337
338 func TestNewRequestWithBody(t *testing.T) {
339         c, err := NewClient(NewContext(nil, nil, map[string]string{
340                 "lfs.url": "https://example.com",
341         }))
342         require.Nil(t, err)
343
344         body := struct {
345                 Test string
346         }{Test: "test"}
347         req, err := c.NewRequest("POST", c.Endpoints.Endpoint("", ""), "body", body)
348         require.Nil(t, err)
349
350         assert.NotNil(t, req.Body)
351         assert.Equal(t, "15", req.Header.Get("Content-Length"))
352         assert.EqualValues(t, 15, req.ContentLength)
353 }
354
355 func TestMarshalToRequest(t *testing.T) {
356         req, err := http.NewRequest("POST", "https://foo/bar", nil)
357         require.Nil(t, err)
358
359         assert.Nil(t, req.Body)
360         assert.Equal(t, "", req.Header.Get("Content-Length"))
361         assert.EqualValues(t, 0, req.ContentLength)
362
363         body := struct {
364                 Test string
365         }{Test: "test"}
366         require.Nil(t, MarshalToRequest(req, body))
367
368         assert.NotNil(t, req.Body)
369         assert.Equal(t, "15", req.Header.Get("Content-Length"))
370         assert.EqualValues(t, 15, req.ContentLength)
371 }