Imported Upstream version 2.4.2
[scm/test.git] / lfsapi / client_test.go
index 456fd32..f5500f8 100644 (file)
@@ -1,11 +1,13 @@
 package lfsapi
 
 import (
+       "encoding/base64"
        "encoding/json"
        "fmt"
        "net"
        "net/http"
        "net/http/httptest"
+       "strings"
        "sync/atomic"
        "testing"
 
@@ -111,7 +113,7 @@ func TestClientRedirect(t *testing.T) {
        srv3Https = srv3.URL
        srv3Http = fmt.Sprintf("http://%s", srv3InsecureListener.Addr().String())
 
-       c, err := NewClient(nil, UniqTestEnv(map[string]string{
+       c, err := NewClient(NewContext(nil, nil, map[string]string{
                fmt.Sprintf("http.%s.sslverify", srv3Https):  "false",
                fmt.Sprintf("http.%s/.sslverify", srv3Https): "false",
                fmt.Sprintf("http.%s.sslverify", srv3Http):   "false",
@@ -175,8 +177,82 @@ func TestClientRedirect(t *testing.T) {
        assert.EqualError(t, err, "lfsapi/client: refusing insecure redirect, https->http")
 }
 
+func TestClientRedirectReauthenticate(t *testing.T) {
+       var srv1, srv2 *httptest.Server
+       var called1, called2 uint32
+       var creds1, creds2 Creds
+
+       srv1 = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+               atomic.AddUint32(&called1, 1)
+
+               if hdr := r.Header.Get("Authorization"); len(hdr) > 0 {
+                       parts := strings.SplitN(hdr, " ", 2)
+                       typ, b64 := parts[0], parts[1]
+
+                       auth, err := base64.URLEncoding.DecodeString(b64)
+                       assert.Nil(t, err)
+                       assert.Equal(t, "Basic", typ)
+                       assert.Equal(t, "user1:pass1", string(auth))
+
+                       http.Redirect(w, r, srv2.URL+r.URL.Path, http.StatusMovedPermanently)
+                       return
+               }
+               w.WriteHeader(http.StatusUnauthorized)
+       }))
+
+       srv2 = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+               atomic.AddUint32(&called2, 1)
+
+               parts := strings.SplitN(r.Header.Get("Authorization"), " ", 2)
+               typ, b64 := parts[0], parts[1]
+
+               auth, err := base64.URLEncoding.DecodeString(b64)
+               assert.Nil(t, err)
+               assert.Equal(t, "Basic", typ)
+               assert.Equal(t, "user2:pass2", string(auth))
+       }))
+
+       // Change the URL of srv2 to make it appears as if it is a different
+       // host.
+       srv2.URL = strings.Replace(srv2.URL, "127.0.0.1", "0.0.0.0", 1)
+
+       creds1 = Creds(map[string]string{
+               "protocol": "http",
+               "host":     strings.TrimPrefix(srv1.URL, "http://"),
+
+               "username": "user1",
+               "password": "pass1",
+       })
+       creds2 = Creds(map[string]string{
+               "protocol": "http",
+               "host":     strings.TrimPrefix(srv2.URL, "http://"),
+
+               "username": "user2",
+               "password": "pass2",
+       })
+
+       defer srv1.Close()
+       defer srv2.Close()
+
+       c, err := NewClient(NewContext(nil, nil, nil))
+       creds := newCredentialCacher()
+       creds.Approve(creds1)
+       creds.Approve(creds2)
+       c.Credentials = creds
+
+       req, err := http.NewRequest("GET", srv1.URL, nil)
+       require.Nil(t, err)
+
+       _, err = c.DoWithAuth("", req)
+       assert.Nil(t, err)
+
+       // called1 is 2 since LFS tries an unauthenticated request first
+       assert.EqualValues(t, 2, called1)
+       assert.EqualValues(t, 1, called2)
+}
+
 func TestNewClient(t *testing.T) {
-       c, err := NewClient(UniqTestEnv(map[string]string{}), UniqTestEnv(map[string]string{
+       c, err := NewClient(NewContext(nil, nil, map[string]string{
                "lfs.dialtimeout":         "151",
                "lfs.keepalive":           "152",
                "lfs.tlstimeout":          "153",
@@ -191,12 +267,12 @@ func TestNewClient(t *testing.T) {
 }
 
 func TestNewClientWithGitSSLVerify(t *testing.T) {
-       c, err := NewClient(nil, nil)
+       c, err := NewClient(nil)
        assert.Nil(t, err)
        assert.False(t, c.SkipSSLVerify)
 
        for _, value := range []string{"true", "1", "t"} {
-               c, err = NewClient(UniqTestEnv(map[string]string{}), UniqTestEnv(map[string]string{
+               c, err = NewClient(NewContext(nil, nil, map[string]string{
                        "http.sslverify": value,
                }))
                t.Logf("http.sslverify: %q", value)
@@ -205,7 +281,7 @@ func TestNewClientWithGitSSLVerify(t *testing.T) {
        }
 
        for _, value := range []string{"false", "0", "f"} {
-               c, err = NewClient(UniqTestEnv(map[string]string{}), UniqTestEnv(map[string]string{
+               c, err = NewClient(NewContext(nil, nil, map[string]string{
                        "http.sslverify": value,
                }))
                t.Logf("http.sslverify: %q", value)
@@ -215,23 +291,23 @@ func TestNewClientWithGitSSLVerify(t *testing.T) {
 }
 
 func TestNewClientWithOSSSLVerify(t *testing.T) {
-       c, err := NewClient(nil, nil)
+       c, err := NewClient(nil)
        assert.Nil(t, err)
        assert.False(t, c.SkipSSLVerify)
 
        for _, value := range []string{"false", "0", "f"} {
-               c, err = NewClient(UniqTestEnv(map[string]string{
+               c, err = NewClient(NewContext(nil, map[string]string{
                        "GIT_SSL_NO_VERIFY": value,
-               }), UniqTestEnv(map[string]string{}))
+               }, nil))
                t.Logf("GIT_SSL_NO_VERIFY: %q", value)
                assert.Nil(t, err)
                assert.False(t, c.SkipSSLVerify)
        }
 
        for _, value := range []string{"true", "1", "t"} {
-               c, err = NewClient(UniqTestEnv(map[string]string{
+               c, err = NewClient(NewContext(nil, map[string]string{
                        "GIT_SSL_NO_VERIFY": value,
-               }), UniqTestEnv(map[string]string{}))
+               }, nil))
                t.Logf("GIT_SSL_NO_VERIFY: %q", value)
                assert.Nil(t, err)
                assert.True(t, c.SkipSSLVerify)
@@ -247,7 +323,7 @@ func TestNewRequest(t *testing.T) {
        }
 
        for _, test := range tests {
-               c, err := NewClient(nil, UniqTestEnv(map[string]string{
+               c, err := NewClient(NewContext(nil, nil, map[string]string{
                        "lfs.url": test[0],
                }))
                require.Nil(t, err)
@@ -260,7 +336,7 @@ func TestNewRequest(t *testing.T) {
 }
 
 func TestNewRequestWithBody(t *testing.T) {
-       c, err := NewClient(nil, UniqTestEnv(map[string]string{
+       c, err := NewClient(NewContext(nil, nil, map[string]string{
                "lfs.url": "https://example.com",
        }))
        require.Nil(t, err)