e31a7e92fa9504d507162b0486dde127ec253c79
[scm/test.git] / lfsapi / client_test.go
1 package lfsapi
2
3 import (
4         "encoding/json"
5         "fmt"
6         "net"
7         "net/http"
8         "net/http/httptest"
9         "sync/atomic"
10         "testing"
11
12         "github.com/stretchr/testify/assert"
13         "github.com/stretchr/testify/require"
14 )
15
// redirectTest is the payload POSTed through the redirect scenarios below.
// The /ok handlers decode it to confirm the body survived the redirect, and
// Test carries a label naming the scenario that sent it.
type redirectTest struct{ Test string }
19
20 func TestClientRedirect(t *testing.T) {
21         var srv3Https, srv3Http string
22
23         var called1 uint32
24         var called2 uint32
25         var called3 uint32
26         srv3 := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
27                 atomic.AddUint32(&called3, 1)
28                 t.Logf("srv3 req %s %s", r.Method, r.URL.Path)
29                 assert.Equal(t, "POST", r.Method)
30
31                 switch r.URL.Path {
32                 case "/upgrade":
33                         assert.Equal(t, "auth", r.Header.Get("Authorization"))
34                         assert.Equal(t, "1", r.Header.Get("A"))
35                         w.Header().Set("Location", srv3Https+"/upgraded")
36                         w.WriteHeader(301)
37                 case "/upgraded":
38                         // Since srv3 listens on both a TLS-enabled socket and a
39                         // TLS-disabled one, they are two different hosts.
40                         // Ensure that, even though this is a "secure" upgrade,
41                         // the authorization header is stripped.
42                         assert.Equal(t, "", r.Header.Get("Authorization"))
43                         assert.Equal(t, "1", r.Header.Get("A"))
44
45                 case "/downgrade":
46                         assert.Equal(t, "auth", r.Header.Get("Authorization"))
47                         assert.Equal(t, "1", r.Header.Get("A"))
48                         w.Header().Set("Location", srv3Http+"/404")
49                         w.WriteHeader(301)
50
51                 default:
52                         w.WriteHeader(404)
53                 }
54         }))
55
56         srv2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
57                 atomic.AddUint32(&called2, 1)
58                 t.Logf("srv2 req %s %s", r.Method, r.URL.Path)
59                 assert.Equal(t, "POST", r.Method)
60
61                 switch r.URL.Path {
62                 case "/ok":
63                         assert.Equal(t, "", r.Header.Get("Authorization"))
64                         assert.Equal(t, "1", r.Header.Get("A"))
65                         body := &redirectTest{}
66                         err := json.NewDecoder(r.Body).Decode(body)
67                         assert.Nil(t, err)
68                         assert.Equal(t, "External", body.Test)
69
70                         w.WriteHeader(200)
71                 default:
72                         w.WriteHeader(404)
73                 }
74         }))
75
76         srv1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
77                 atomic.AddUint32(&called1, 1)
78                 t.Logf("srv1 req %s %s", r.Method, r.URL.Path)
79                 assert.Equal(t, "POST", r.Method)
80
81                 switch r.URL.Path {
82                 case "/local":
83                         w.Header().Set("Location", "/ok")
84                         w.WriteHeader(307)
85                 case "/external":
86                         w.Header().Set("Location", srv2.URL+"/ok")
87                         w.WriteHeader(307)
88                 case "/ok":
89                         assert.Equal(t, "auth", r.Header.Get("Authorization"))
90                         assert.Equal(t, "1", r.Header.Get("A"))
91                         body := &redirectTest{}
92                         err := json.NewDecoder(r.Body).Decode(body)
93                         assert.Nil(t, err)
94                         assert.Equal(t, "Local", body.Test)
95
96                         w.WriteHeader(200)
97                 default:
98                         w.WriteHeader(404)
99                 }
100         }))
101         defer srv1.Close()
102         defer srv2.Close()
103         defer srv3.Close()
104
105         srv3InsecureListener, err := net.Listen("tcp", "127.0.0.1:0")
106         require.Nil(t, err)
107
108         go http.Serve(srv3InsecureListener, srv3.Config.Handler)
109         defer srv3InsecureListener.Close()
110
111         srv3Https = srv3.URL
112         srv3Http = fmt.Sprintf("http://%s", srv3InsecureListener.Addr().String())
113
114         c, err := NewClient(NewContext(nil, nil, map[string]string{
115                 fmt.Sprintf("http.%s.sslverify", srv3Https):  "false",
116                 fmt.Sprintf("http.%s/.sslverify", srv3Https): "false",
117                 fmt.Sprintf("http.%s.sslverify", srv3Http):   "false",
118                 fmt.Sprintf("http.%s/.sslverify", srv3Http):  "false",
119                 fmt.Sprintf("http.sslverify"):                "false",
120         }))
121         require.Nil(t, err)
122
123         // local redirect
124         req, err := http.NewRequest("POST", srv1.URL+"/local", nil)
125         require.Nil(t, err)
126         req.Header.Set("Authorization", "auth")
127         req.Header.Set("A", "1")
128
129         require.Nil(t, MarshalToRequest(req, &redirectTest{Test: "Local"}))
130
131         res, err := c.Do(req)
132         require.Nil(t, err)
133         assert.Equal(t, 200, res.StatusCode)
134         assert.EqualValues(t, 2, called1)
135         assert.EqualValues(t, 0, called2)
136
137         // external redirect
138         req, err = http.NewRequest("POST", srv1.URL+"/external", nil)
139         require.Nil(t, err)
140         req.Header.Set("Authorization", "auth")
141         req.Header.Set("A", "1")
142
143         require.Nil(t, MarshalToRequest(req, &redirectTest{Test: "External"}))
144
145         res, err = c.Do(req)
146         require.Nil(t, err)
147         assert.Equal(t, 200, res.StatusCode)
148         assert.EqualValues(t, 3, called1)
149         assert.EqualValues(t, 1, called2)
150
151         // http -> https (secure upgrade)
152
153         req, err = http.NewRequest("POST", srv3Http+"/upgrade", nil)
154         require.Nil(t, err)
155         req.Header.Set("Authorization", "auth")
156         req.Header.Set("A", "1")
157
158         require.Nil(t, MarshalToRequest(req, &redirectTest{Test: "http->https"}))
159
160         res, err = c.Do(req)
161         require.Nil(t, err)
162         assert.Equal(t, 200, res.StatusCode)
163         assert.EqualValues(t, 2, atomic.LoadUint32(&called3))
164
165         // https -> http (insecure downgrade)
166
167         req, err = http.NewRequest("POST", srv3Https+"/downgrade", nil)
168         require.Nil(t, err)
169         req.Header.Set("Authorization", "auth")
170         req.Header.Set("A", "1")
171
172         require.Nil(t, MarshalToRequest(req, &redirectTest{Test: "https->http"}))
173
174         _, err = c.Do(req)
175         assert.EqualError(t, err, "lfsapi/client: refusing insecure redirect, https->http")
176 }
177
178 func TestNewClient(t *testing.T) {
179         c, err := NewClient(NewContext(nil, nil, map[string]string{
180                 "lfs.dialtimeout":         "151",
181                 "lfs.keepalive":           "152",
182                 "lfs.tlstimeout":          "153",
183                 "lfs.concurrenttransfers": "154",
184         }))
185
186         require.Nil(t, err)
187         assert.Equal(t, 151, c.DialTimeout)
188         assert.Equal(t, 152, c.KeepaliveTimeout)
189         assert.Equal(t, 153, c.TLSTimeout)
190         assert.Equal(t, 154, c.ConcurrentTransfers)
191 }
192
193 func TestNewClientWithGitSSLVerify(t *testing.T) {
194         c, err := NewClient(nil)
195         assert.Nil(t, err)
196         assert.False(t, c.SkipSSLVerify)
197
198         for _, value := range []string{"true", "1", "t"} {
199                 c, err = NewClient(NewContext(nil, nil, map[string]string{
200                         "http.sslverify": value,
201                 }))
202                 t.Logf("http.sslverify: %q", value)
203                 assert.Nil(t, err)
204                 assert.False(t, c.SkipSSLVerify)
205         }
206
207         for _, value := range []string{"false", "0", "f"} {
208                 c, err = NewClient(NewContext(nil, nil, map[string]string{
209                         "http.sslverify": value,
210                 }))
211                 t.Logf("http.sslverify: %q", value)
212                 assert.Nil(t, err)
213                 assert.True(t, c.SkipSSLVerify)
214         }
215 }
216
217 func TestNewClientWithOSSSLVerify(t *testing.T) {
218         c, err := NewClient(nil)
219         assert.Nil(t, err)
220         assert.False(t, c.SkipSSLVerify)
221
222         for _, value := range []string{"false", "0", "f"} {
223                 c, err = NewClient(NewContext(nil, map[string]string{
224                         "GIT_SSL_NO_VERIFY": value,
225                 }, nil))
226                 t.Logf("GIT_SSL_NO_VERIFY: %q", value)
227                 assert.Nil(t, err)
228                 assert.False(t, c.SkipSSLVerify)
229         }
230
231         for _, value := range []string{"true", "1", "t"} {
232                 c, err = NewClient(NewContext(nil, map[string]string{
233                         "GIT_SSL_NO_VERIFY": value,
234                 }, nil))
235                 t.Logf("GIT_SSL_NO_VERIFY: %q", value)
236                 assert.Nil(t, err)
237                 assert.True(t, c.SkipSSLVerify)
238         }
239 }
240
241 func TestNewRequest(t *testing.T) {
242         tests := [][]string{
243                 {"https://example.com", "a", "https://example.com/a"},
244                 {"https://example.com/", "a", "https://example.com/a"},
245                 {"https://example.com/a", "b", "https://example.com/a/b"},
246                 {"https://example.com/a/", "b", "https://example.com/a/b"},
247         }
248
249         for _, test := range tests {
250                 c, err := NewClient(NewContext(nil, nil, map[string]string{
251                         "lfs.url": test[0],
252                 }))
253                 require.Nil(t, err)
254
255                 req, err := c.NewRequest("POST", c.Endpoints.Endpoint("", ""), test[1], nil)
256                 require.Nil(t, err)
257                 assert.Equal(t, "POST", req.Method)
258                 assert.Equal(t, test[2], req.URL.String(), fmt.Sprintf("endpoint: %s, suffix: %s, expected: %s", test[0], test[1], test[2]))
259         }
260 }
261
262 func TestNewRequestWithBody(t *testing.T) {
263         c, err := NewClient(NewContext(nil, nil, map[string]string{
264                 "lfs.url": "https://example.com",
265         }))
266         require.Nil(t, err)
267
268         body := struct {
269                 Test string
270         }{Test: "test"}
271         req, err := c.NewRequest("POST", c.Endpoints.Endpoint("", ""), "body", body)
272         require.Nil(t, err)
273
274         assert.NotNil(t, req.Body)
275         assert.Equal(t, "15", req.Header.Get("Content-Length"))
276         assert.EqualValues(t, 15, req.ContentLength)
277 }
278
279 func TestMarshalToRequest(t *testing.T) {
280         req, err := http.NewRequest("POST", "https://foo/bar", nil)
281         require.Nil(t, err)
282
283         assert.Nil(t, req.Body)
284         assert.Equal(t, "", req.Header.Get("Content-Length"))
285         assert.EqualValues(t, 0, req.ContentLength)
286
287         body := struct {
288                 Test string
289         }{Test: "test"}
290         require.Nil(t, MarshalToRequest(req, body))
291
292         assert.NotNil(t, req.Body)
293         assert.Equal(t, "15", req.Header.Get("Content-Length"))
294         assert.EqualValues(t, 15, req.ContentLength)
295 }