Imported Upstream version 2.4.2
[scm/test.git] / test / cmd / lfstest-gitserver.go
1 // +build testtools
2
3 package main
4
5 import (
6         "bufio"
7         "bytes"
8         "crypto/rand"
9         "crypto/rsa"
10         "crypto/sha256"
11         "crypto/tls"
12         "crypto/x509"
13         "crypto/x509/pkix"
14         "encoding/base64"
15         "encoding/hex"
16         "encoding/json"
17         "encoding/pem"
18         "errors"
19         "fmt"
20         "io"
21         "io/ioutil"
22         "log"
23         "math"
24         "math/big"
25         "net/http"
26         "net/http/httptest"
27         "net/textproto"
28         "os"
29         "os/exec"
30         "regexp"
31         "sort"
32         "strconv"
33         "strings"
34         "sync"
35         "time"
36
37         "github.com/ThomsonReutersEikon/go-ntlm/ntlm"
38 )
39
var (
	// repoDir is the directory git http-backend serves repositories from
	// (GIT_PROJECT_ROOT); main reads it from LFSTEST_DIR.
	// largeObjects stores the LFS objects uploaded through storageHandler.
	// server, serverTLS, and serverClientCert are the plain HTTP, HTTPS, and
	// client-certificate-authenticated HTTPS test servers started by main.
	repoDir          string
	largeObjects     = newLfsStorage()
	server           *httptest.Server
	serverTLS        *httptest.Server
	serverClientCert *httptest.Server

	// oidHandlers maps OIDs to content strings. Both the LFS and storage test
	// handlers below look objects up in it by OID.
	// NOTE(review): it is not populated in the visible part of this file;
	// presumably it is built from contentHandlers elsewhere — confirm.
	oidHandlers map[string]string

	// These magic strings tell the test LFS server to change its behavior so
	// that the integration tests can exercise those use cases. Tests create
	// objects whose contents are one of the magic strings, e.g.:
	//
	//   printf "status-batch-404" > 404.dat
	//
	contentHandlers = []string{
		"status-batch-403", "status-batch-404", "status-batch-410", "status-batch-422", "status-batch-500",
		"status-storage-403", "status-storage-404", "status-storage-410", "status-storage-422", "status-storage-500", "status-storage-503",
		"status-batch-resume-206", "batch-resume-fail-fallback", "return-expired-action", "return-expired-action-forever", "return-invalid-size",
		"object-authenticated", "storage-download-retry", "storage-upload-retry", "unknown-oid",
		"send-verify-action", "send-deprecated-links",
	}
)
65
66 func main() {
67         repoDir = os.Getenv("LFSTEST_DIR")
68
69         mux := http.NewServeMux()
70         server = httptest.NewServer(mux)
71         serverTLS = httptest.NewTLSServer(mux)
72         serverClientCert = httptest.NewUnstartedServer(mux)
73
74         //setup Client Cert server
75         rootKey, rootCert := generateCARootCertificates()
76         _, clientCertPEM, clientKeyPEM := generateClientCertificates(rootCert, rootKey)
77
78         certPool := x509.NewCertPool()
79         certPool.AddCert(rootCert)
80
81         serverClientCert.TLS = &tls.Config{
82                 Certificates: []tls.Certificate{serverTLS.TLS.Certificates[0]},
83                 ClientAuth:   tls.RequireAndVerifyClientCert,
84                 ClientCAs:    certPool,
85         }
86         serverClientCert.StartTLS()
87
88         ntlmSession, err := ntlm.CreateServerSession(ntlm.Version2, ntlm.ConnectionOrientedMode)
89         if err != nil {
90                 fmt.Println("Error creating ntlm session:", err)
91                 os.Exit(1)
92         }
93         ntlmSession.SetUserInfo("ntlmuser", "ntlmpass", "NTLMDOMAIN")
94
95         stopch := make(chan bool)
96
97         mux.HandleFunc("/shutdown", func(w http.ResponseWriter, r *http.Request) {
98                 stopch <- true
99         })
100
101         mux.HandleFunc("/storage/", storageHandler)
102         mux.HandleFunc("/verify", verifyHandler)
103         mux.HandleFunc("/redirect307/", redirect307Handler)
104         mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
105                 id, ok := reqId(w)
106                 if !ok {
107                         return
108                 }
109
110                 if strings.Contains(r.URL.Path, "/info/lfs") {
111                         if !skipIfBadAuth(w, r, id, ntlmSession) {
112                                 lfsHandler(w, r, id)
113                         }
114
115                         return
116                 }
117
118                 debug(id, "git http-backend %s %s", r.Method, r.URL)
119                 gitHandler(w, r)
120         })
121
122         urlname := writeTestStateFile([]byte(server.URL), "LFSTEST_URL", "lfstest-gitserver")
123         defer os.RemoveAll(urlname)
124
125         sslurlname := writeTestStateFile([]byte(serverTLS.URL), "LFSTEST_SSL_URL", "lfstest-gitserver-ssl")
126         defer os.RemoveAll(sslurlname)
127
128         clientCertUrlname := writeTestStateFile([]byte(serverClientCert.URL), "LFSTEST_CLIENT_CERT_URL", "lfstest-gitserver-ssl")
129         defer os.RemoveAll(clientCertUrlname)
130
131         block := &pem.Block{}
132         block.Type = "CERTIFICATE"
133         block.Bytes = serverTLS.TLS.Certificates[0].Certificate[0]
134         pembytes := pem.EncodeToMemory(block)
135
136         certname := writeTestStateFile(pembytes, "LFSTEST_CERT", "lfstest-gitserver-cert")
137         defer os.RemoveAll(certname)
138
139         cccertname := writeTestStateFile(clientCertPEM, "LFSTEST_CLIENT_CERT", "lfstest-gitserver-client-cert")
140         defer os.RemoveAll(cccertname)
141
142         ckcertname := writeTestStateFile(clientKeyPEM, "LFSTEST_CLIENT_KEY", "lfstest-gitserver-client-key")
143         defer os.RemoveAll(ckcertname)
144
145         debug("init", "server url: %s", server.URL)
146         debug("init", "server tls url: %s", serverTLS.URL)
147         debug("init", "server client cert url: %s", serverClientCert.URL)
148
149         <-stopch
150         debug("init", "git server done")
151 }
152
// writeTestStateFile writes contents to the file named by the environment
// variable envVar. If that variable is unset or empty, it writes to
// defaultFilename instead. It returns the filename that was used. An existing
// file is truncated. Any error creating, writing, or closing the file is
// fatal (log.Fatalln).
func writeTestStateFile(contents []byte, envVar, defaultFilename string) string {
	f := os.Getenv(envVar)
	if len(f) == 0 {
		f = defaultFilename
	}
	// ioutil.WriteFile creates or truncates the file with the same 0666
	// (pre-umask) mode as os.Create. It also reports write and close errors,
	// which the previous Create/Write/Close sequence silently dropped.
	if err := ioutil.WriteFile(f, contents, 0666); err != nil {
		log.Fatalln(err)
	}
	return f
}
169
// lfsObject is one entry in the "objects" array of a batch request or
// response. Actions holds links keyed by operation name (plus "verify"). For
// objects whose handler is "send-deprecated-links", lfsBatchHandler puts the
// link in Links (the "_links" field) instead of Actions. Err carries a
// per-object error.
type lfsObject struct {
	Oid           string              `json:"oid,omitempty"`
	Size          int64               `json:"size,omitempty"`
	Authenticated bool                `json:"authenticated,omitempty"`
	Actions       map[string]*lfsLink `json:"actions,omitempty"`
	Links         map[string]*lfsLink `json:"_links,omitempty"`
	Err           *lfsError           `json:"error,omitempty"`
}

// lfsLink is an action link in a batch response. It holds the URL to use,
// extra headers to send with it, and optional expiry information.
//
// NOTE: omitempty has no effect on the struct-typed ExpiresAt field, because
// encoding/json never treats a struct value as empty. An unset ExpiresAt is
// therefore still encoded as "0001-01-01T00:00:00Z". ExpiresIn is presumably
// a number of seconds (serveExpired sets it to -5) — confirm against the
// client.
type lfsLink struct {
	Href      string            `json:"href"`
	Header    map[string]string `json:"header,omitempty"`
	ExpiresAt time.Time         `json:"expires_at,omitempty"`
	ExpiresIn int               `json:"expires_in,omitempty"`
}

// lfsError is an error object in LFS JSON bodies. Code is omitted when zero;
// Message is always present.
type lfsError struct {
	Code    int    `json:"code,omitempty"`
	Message string `json:"message"`
}
190
191 func writeLFSError(w http.ResponseWriter, code int, msg string) {
192         by, err := json.Marshal(&lfsError{Message: msg})
193         if err != nil {
194                 http.Error(w, "json encoding error: "+err.Error(), 500)
195                 return
196         }
197
198         w.Header().Set("Content-Type", "application/vnd.git-lfs+json")
199         w.WriteHeader(code)
200         w.Write(by)
201 }
202
203 // handles any requests with "{name}.server.git/info/lfs" in the path
204 func lfsHandler(w http.ResponseWriter, r *http.Request, id string) {
205         repo, err := repoFromLfsUrl(r.URL.Path)
206         if err != nil {
207                 w.WriteHeader(500)
208                 w.Write([]byte(err.Error()))
209                 return
210         }
211
212         debug(id, "git lfs %s %s repo: %s", r.Method, r.URL, repo)
213         w.Header().Set("Content-Type", "application/vnd.git-lfs+json")
214         switch r.Method {
215         case "POST":
216                 if strings.HasSuffix(r.URL.String(), "batch") {
217                         lfsBatchHandler(w, r, id, repo)
218                 } else {
219                         locksHandler(w, r, repo)
220                 }
221         case "DELETE":
222                 lfsDeleteHandler(w, r, id, repo)
223         case "GET":
224                 if strings.Contains(r.URL.String(), "/locks") {
225                         locksHandler(w, r, repo)
226                 } else {
227                         w.WriteHeader(404)
228                         w.Write([]byte("lock request"))
229                 }
230         default:
231                 w.WriteHeader(405)
232         }
233 }
234
235 func lfsUrl(repo, oid string) string {
236         return server.URL + "/storage/" + oid + "?r=" + repo
237 }
238
// retries counts simulated retry attempts per "direction:repo:oid" key. It is
// only updated for objects whose content requests retries (see
// incrementRetriesFor). retriesMu guards it.
var (
	retries   = make(map[string]uint32)
	retriesMu sync.Mutex
)
243
244 func incrementRetriesFor(api, direction, repo, oid string, check bool) (after uint32, ok bool) {
245         // fmtStr formats a string like "<api>-<direction>-[check]-<retry>",
246         // i.e., "legacy-upload-check-retry", or "storage-download-retry".
247         var fmtStr string
248         if check {
249                 fmtStr = "%s-%s-check-retry"
250         } else {
251                 fmtStr = "%s-%s-retry"
252         }
253
254         if oidHandlers[oid] != fmt.Sprintf(fmtStr, api, direction) {
255                 return 0, false
256         }
257
258         retriesMu.Lock()
259         defer retriesMu.Unlock()
260
261         retryKey := strings.Join([]string{direction, repo, oid}, ":")
262
263         retries[retryKey]++
264         retries := retries[retryKey]
265
266         return retries, true
267 }
268
269 func lfsDeleteHandler(w http.ResponseWriter, r *http.Request, id, repo string) {
270         parts := strings.Split(r.URL.Path, "/")
271         oid := parts[len(parts)-1]
272
273         largeObjects.Delete(repo, oid)
274         debug(id, "DELETE:", oid)
275         w.WriteHeader(200)
276 }
277
// batchReq is the JSON body of a batch API request.
//   - Transfers lists the transfer adapters the client offers;
//     lfsBatchHandler looks for "tus" and "testcustom" in it.
//   - Operation is compared against "download"; anything else is handled
//     like an upload.
//   - Ref is optional.
type batchReq struct {
	Transfers []string    `json:"transfers"`
	Operation string      `json:"operation"`
	Objects   []lfsObject `json:"objects"`
	Ref       *Ref        `json:"ref,omitempty"`
}
284
285 func (r *batchReq) RefName() string {
286         if r.Ref == nil {
287                 return ""
288         }
289         return r.Ref.Name
290 }
291
// batchResp is the JSON body of a batch API response. Transfer names the
// chosen transfer adapter and is omitted when empty. Objects is always
// emitted.
type batchResp struct {
	Transfer string      `json:"transfer,omitempty"`
	Objects  []lfsObject `json:"objects"`
}
296
297 func lfsBatchHandler(w http.ResponseWriter, r *http.Request, id, repo string) {
298         checkingObject := r.Header.Get("X-Check-Object") == "1"
299         if !checkingObject && repo == "batchunsupported" {
300                 w.WriteHeader(404)
301                 return
302         }
303
304         if !checkingObject && repo == "badbatch" {
305                 w.WriteHeader(203)
306                 return
307         }
308
309         if repo == "netrctest" {
310                 user, pass, err := extractAuth(r.Header.Get("Authorization"))
311                 if err != nil || (user != "netrcuser" || pass != "netrcpass") {
312                         w.WriteHeader(403)
313                         return
314                 }
315         }
316
317         if missingRequiredCreds(w, r, repo) {
318                 return
319         }
320
321         buf := &bytes.Buffer{}
322         tee := io.TeeReader(r.Body, buf)
323         objs := &batchReq{}
324         err := json.NewDecoder(tee).Decode(objs)
325         io.Copy(ioutil.Discard, r.Body)
326         r.Body.Close()
327
328         debug(id, "REQUEST")
329         debug(id, buf.String())
330
331         if err != nil {
332                 log.Fatal(err)
333         }
334
335         if strings.HasSuffix(repo, "branch-required") {
336                 parts := strings.Split(repo, "-")
337                 lenParts := len(parts)
338                 if lenParts > 3 && "refs/heads/"+parts[lenParts-3] != objs.RefName() {
339                         w.WriteHeader(403)
340                         json.NewEncoder(w).Encode(struct {
341                                 Message string `json:"message"`
342                         }{fmt.Sprintf("Expected ref %q, got %q", "refs/heads/"+parts[lenParts-3], objs.RefName())})
343                         return
344                 }
345         }
346
347         res := []lfsObject{}
348         testingChunked := testingChunkedTransferEncoding(r)
349         testingTus := testingTusUploadInBatchReq(r)
350         testingTusInterrupt := testingTusUploadInterruptedInBatchReq(r)
351         testingCustomTransfer := testingCustomTransfer(r)
352         var transferChoice string
353         var searchForTransfer string
354         if testingTus {
355                 searchForTransfer = "tus"
356         } else if testingCustomTransfer {
357                 searchForTransfer = "testcustom"
358         }
359         if len(searchForTransfer) > 0 {
360                 for _, t := range objs.Transfers {
361                         if t == searchForTransfer {
362                                 transferChoice = searchForTransfer
363                                 break
364                         }
365
366                 }
367         }
368         for _, obj := range objs.Objects {
369                 handler := oidHandlers[obj.Oid]
370                 action := objs.Operation
371
372                 o := lfsObject{
373                         Size:    obj.Size,
374                         Actions: make(map[string]*lfsLink),
375                 }
376
377                 // Clobber the OID if told to do so.
378                 if handler == "unknown-oid" {
379                         o.Oid = "unknown-oid"
380                 } else {
381                         o.Oid = obj.Oid
382                 }
383
384                 exists := largeObjects.Has(repo, obj.Oid)
385                 addAction := true
386                 if action == "download" {
387                         if !exists {
388                                 o.Err = &lfsError{Code: 404, Message: fmt.Sprintf("Object %v does not exist", obj.Oid)}
389                                 addAction = false
390                         }
391                 } else {
392                         if exists {
393                                 // not an error but don't add an action
394                                 addAction = false
395                         }
396                 }
397
398                 if handler == "object-authenticated" {
399                         o.Authenticated = true
400                 }
401
402                 switch handler {
403                 case "status-batch-403":
404                         o.Err = &lfsError{Code: 403, Message: "welp"}
405                 case "status-batch-404":
406                         o.Err = &lfsError{Code: 404, Message: "welp"}
407                 case "status-batch-410":
408                         o.Err = &lfsError{Code: 410, Message: "welp"}
409                 case "status-batch-422":
410                         o.Err = &lfsError{Code: 422, Message: "welp"}
411                 case "status-batch-500":
412                         o.Err = &lfsError{Code: 500, Message: "welp"}
413                 default: // regular 200 response
414                         if handler == "return-invalid-size" {
415                                 o.Size = -1
416                         }
417
418                         if handler == "send-deprecated-links" {
419                                 o.Links = make(map[string]*lfsLink)
420                         }
421
422                         if addAction {
423                                 a := &lfsLink{
424                                         Href:   lfsUrl(repo, obj.Oid),
425                                         Header: map[string]string{},
426                                 }
427                                 a = serveExpired(a, repo, handler)
428
429                                 if handler == "send-deprecated-links" {
430                                         o.Links[action] = a
431                                 } else {
432                                         o.Actions[action] = a
433                                 }
434                         }
435
436                         if handler == "send-verify-action" {
437                                 o.Actions["verify"] = &lfsLink{
438                                         Href: server.URL + "/verify",
439                                         Header: map[string]string{
440                                                 "repo": repo,
441                                         },
442                                 }
443                         }
444                 }
445
446                 if testingChunked && addAction {
447                         if handler == "send-deprecated-links" {
448                                 o.Links[action].Header["Transfer-Encoding"] = "chunked"
449                         } else {
450                                 o.Actions[action].Header["Transfer-Encoding"] = "chunked"
451                         }
452                 }
453                 if testingTusInterrupt && addAction {
454                         if handler == "send-deprecated-links" {
455                                 o.Links[action].Header["Lfs-Tus-Interrupt"] = "true"
456                         } else {
457                                 o.Actions[action].Header["Lfs-Tus-Interrupt"] = "true"
458                         }
459                 }
460
461                 res = append(res, o)
462         }
463
464         ores := batchResp{Transfer: transferChoice, Objects: res}
465
466         by, err := json.Marshal(ores)
467         if err != nil {
468                 log.Fatal(err)
469         }
470
471         debug(id, "RESPONSE: 200")
472         debug(id, string(by))
473
474         w.WriteHeader(200)
475         w.Write(by)
476 }
477
// emu guards expiredRepos.
var emu sync.Mutex

// expiredRepos records, for each repository name, whether an expired action
// link has already been served for it (see serveExpired).
var expiredRepos = map[string]bool{}
484
485 // serveExpired marks the given repo as having served an expired object, making
486 // it unable for that same repository to return an expired object in the future,
487 func serveExpired(a *lfsLink, repo, handler string) *lfsLink {
488         var (
489                 dur = -5 * time.Minute
490                 at  = time.Now().Add(dur)
491         )
492
493         if handler == "return-expired-action-forever" ||
494                 (handler == "return-expired-action" && canServeExpired(repo)) {
495
496                 emu.Lock()
497                 expiredRepos[repo] = true
498                 emu.Unlock()
499
500                 a.ExpiresAt = at
501                 return a
502         }
503
504         switch repo {
505         case "expired-absolute":
506                 a.ExpiresAt = at
507         case "expired-relative":
508                 a.ExpiresIn = -5
509         case "expired-both":
510                 a.ExpiresAt = at
511                 a.ExpiresIn = -5
512         }
513
514         return a
515 }
516
517 // canServeExpired returns whether or not a repository is capable of serving an
518 // expired object. In other words, canServeExpired returns whether or not the
519 // given repo has yet served an expired object.
520 func canServeExpired(repo string) bool {
521         emu.Lock()
522         defer emu.Unlock()
523
524         return !expiredRepos[repo]
525 }
526
// Persistent state across requests.
//
// batchResumeFailFallbackStorageAttempts records whether storageHandler has
// already served the deliberately truncated first response for the
// "batch-resume-fail-fallback" object. It is only ever incremented from 0
// to 1.
// NOTE(review): it is read and written without a lock — confirm the tests
// never download that object concurrently.
// tusStorageAttempts is not referenced in the visible part of this file.
var batchResumeFailFallbackStorageAttempts = 0
var tusStorageAttempts = 0

// vmu guards verifyCounts, which counts verify requests per "repo:oid" key.
// verifyRetryRe captures N from repository names ending in
// "verify-fail-N-time" or "verify-fail-N-times".
var (
	vmu           sync.Mutex
	verifyCounts  = make(map[string]int)
	verifyRetryRe = regexp.MustCompile(`verify-fail-(\d+)-times?$`)
)
536
537 func verifyHandler(w http.ResponseWriter, r *http.Request) {
538         repo := r.Header.Get("repo")
539         var payload struct {
540                 Oid  string `json:"oid"`
541                 Size int64  `json:"size"`
542         }
543
544         if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
545                 writeLFSError(w, http.StatusUnprocessableEntity, err.Error())
546                 return
547         }
548
549         var max int
550         if matches := verifyRetryRe.FindStringSubmatch(repo); len(matches) < 2 {
551                 return
552         } else {
553                 max, _ = strconv.Atoi(matches[1])
554         }
555
556         key := strings.Join([]string{repo, payload.Oid}, ":")
557
558         vmu.Lock()
559         verifyCounts[key] = verifyCounts[key] + 1
560         count := verifyCounts[key]
561         vmu.Unlock()
562
563         if count < max {
564                 writeLFSError(w, http.StatusServiceUnavailable, fmt.Sprintf(
565                         "intentionally failing verify request %d (out of %d)", count, max,
566                 ))
567                 return
568         }
569 }
570
571 // handles any /storage/{oid} requests
572 func storageHandler(w http.ResponseWriter, r *http.Request) {
573         id, ok := reqId(w)
574         if !ok {
575                 return
576         }
577
578         repo := r.URL.Query().Get("r")
579         parts := strings.Split(r.URL.Path, "/")
580         oid := parts[len(parts)-1]
581         if missingRequiredCreds(w, r, repo) {
582                 return
583         }
584
585         debug(id, "storage %s %s repo: %s", r.Method, oid, repo)
586         switch r.Method {
587         case "PUT":
588                 switch oidHandlers[oid] {
589                 case "status-storage-403":
590                         w.WriteHeader(403)
591                         return
592                 case "status-storage-404":
593                         w.WriteHeader(404)
594                         return
595                 case "status-storage-410":
596                         w.WriteHeader(410)
597                         return
598                 case "status-storage-422":
599                         w.WriteHeader(422)
600                         return
601                 case "status-storage-500":
602                         w.WriteHeader(500)
603                         return
604                 case "status-storage-503":
605                         writeLFSError(w, 503, "LFS is temporarily unavailable")
606                         return
607                 case "object-authenticated":
608                         if len(r.Header.Get("Authorization")) > 0 {
609                                 w.WriteHeader(400)
610                                 w.Write([]byte("Should not send authentication"))
611                         }
612                         return
613                 case "storage-upload-retry":
614                         if retries, ok := incrementRetriesFor("storage", "upload", repo, oid, false); ok && retries < 3 {
615                                 w.WriteHeader(500)
616                                 w.Write([]byte("malformed content"))
617
618                                 return
619                         }
620                 }
621
622                 if testingChunkedTransferEncoding(r) {
623                         valid := false
624                         for _, value := range r.TransferEncoding {
625                                 if value == "chunked" {
626                                         valid = true
627                                         break
628                                 }
629                         }
630                         if !valid {
631                                 debug(id, "Chunked transfer encoding expected")
632                         }
633                 }
634
635                 hash := sha256.New()
636                 buf := &bytes.Buffer{}
637
638                 io.Copy(io.MultiWriter(hash, buf), r.Body)
639                 oid := hex.EncodeToString(hash.Sum(nil))
640                 if !strings.HasSuffix(r.URL.Path, "/"+oid) {
641                         w.WriteHeader(403)
642                         return
643                 }
644
645                 largeObjects.Set(repo, oid, buf.Bytes())
646
647         case "GET":
648                 parts := strings.Split(r.URL.Path, "/")
649                 oid := parts[len(parts)-1]
650                 statusCode := 200
651                 byteLimit := 0
652                 resumeAt := int64(0)
653
654                 if by, ok := largeObjects.Get(repo, oid); ok {
655                         if len(by) == len("storage-download-retry") && string(by) == "storage-download-retry" {
656                                 if retries, ok := incrementRetriesFor("storage", "download", repo, oid, false); ok && retries < 3 {
657                                         statusCode = 500
658                                         by = []byte("malformed content")
659                                 }
660                         } else if len(by) == len("status-batch-resume-206") && string(by) == "status-batch-resume-206" {
661                                 // Resume if header includes range, otherwise deliberately interrupt
662                                 if rangeHdr := r.Header.Get("Range"); rangeHdr != "" {
663                                         regex := regexp.MustCompile(`bytes=(\d+)\-.*`)
664                                         match := regex.FindStringSubmatch(rangeHdr)
665                                         if match != nil && len(match) > 1 {
666                                                 statusCode = 206
667                                                 resumeAt, _ = strconv.ParseInt(match[1], 10, 32)
668                                                 w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", resumeAt, len(by), resumeAt-int64(len(by))))
669                                         }
670                                 } else {
671                                         byteLimit = 10
672                                 }
673                         } else if len(by) == len("batch-resume-fail-fallback") && string(by) == "batch-resume-fail-fallback" {
674                                 // Fail any Range: request even though we said we supported it
675                                 // To make sure client can fall back
676                                 if rangeHdr := r.Header.Get("Range"); rangeHdr != "" {
677                                         w.WriteHeader(416)
678                                         return
679                                 }
680                                 if batchResumeFailFallbackStorageAttempts == 0 {
681                                         // Truncate output on FIRST attempt to cause resume
682                                         // Second attempt (without range header) is fallback, complete successfully
683                                         byteLimit = 8
684                                         batchResumeFailFallbackStorageAttempts++
685                                 }
686                         }
687                         w.WriteHeader(statusCode)
688                         if byteLimit > 0 {
689                                 w.Write(by[0:byteLimit])
690                         } else if resumeAt > 0 {
691                                 w.Write(by[resumeAt:])
692                         } else {
693                                 w.Write(by)
694                         }
695                         return
696                 }
697
698                 w.WriteHeader(404)
699         case "HEAD":
700                 // tus.io
701                 if !validateTusHeaders(r, id) {
702                         w.WriteHeader(400)
703                         return
704                 }
705                 parts := strings.Split(r.URL.Path, "/")
706                 oid := parts[len(parts)-1]
707                 var offset int64
708                 if by, ok := largeObjects.GetIncomplete(repo, oid); ok {
709                         offset = int64(len(by))
710                 }
711                 w.Header().Set("Upload-Offset", strconv.FormatInt(offset, 10))
712                 w.WriteHeader(200)
713         case "PATCH":
714                 // tus.io
715                 if !validateTusHeaders(r, id) {
716                         w.WriteHeader(400)
717                         return
718                 }
719                 parts := strings.Split(r.URL.Path, "/")
720                 oid := parts[len(parts)-1]
721
722                 offsetHdr := r.Header.Get("Upload-Offset")
723                 offset, err := strconv.ParseInt(offsetHdr, 10, 64)
724                 if err != nil {
725                         log.Fatal("Unable to parse Upload-Offset header in request: ", err)
726                         w.WriteHeader(400)
727                         return
728                 }
729                 hash := sha256.New()
730                 buf := &bytes.Buffer{}
731                 out := io.MultiWriter(hash, buf)
732
733                 if by, ok := largeObjects.GetIncomplete(repo, oid); ok {
734                         if offset != int64(len(by)) {
735                                 log.Fatal(fmt.Sprintf("Incorrect offset in request, got %d expected %d", offset, len(by)))
736                                 w.WriteHeader(400)
737                                 return
738                         }
739                         _, err := out.Write(by)
740                         if err != nil {
741                                 log.Fatal("Error reading incomplete bytes from store: ", err)
742                                 w.WriteHeader(500)
743                                 return
744                         }
745                         largeObjects.DeleteIncomplete(repo, oid)
746                         debug(id, "Resuming upload of %v at byte %d", oid, offset)
747                 }
748
749                 // As a test, we intentionally break the upload from byte 0 by only
750                 // reading some bytes the quitting & erroring, this forces a resume
751                 // any offset > 0 will work ok
752                 var copyErr error
753                 if r.Header.Get("Lfs-Tus-Interrupt") == "true" && offset == 0 {
754                         chdr := r.Header.Get("Content-Length")
755                         contentLen, err := strconv.ParseInt(chdr, 10, 64)
756                         if err != nil {
757                                 log.Fatal(fmt.Sprintf("Invalid Content-Length %q", chdr))
758                                 w.WriteHeader(400)
759                                 return
760                         }
761                         truncated := contentLen / 3
762                         _, _ = io.CopyN(out, r.Body, truncated)
763                         r.Body.Close()
764                         copyErr = fmt.Errorf("Simulated copy error")
765                 } else {
766                         _, copyErr = io.Copy(out, r.Body)
767                 }
768                 if copyErr != nil {
769                         b := buf.Bytes()
770                         if len(b) > 0 {
771                                 debug(id, "Incomplete upload of %v, %d bytes", oid, len(b))
772                                 largeObjects.SetIncomplete(repo, oid, b)
773                         }
774                         w.WriteHeader(500)
775                 } else {
776                         checkoid := hex.EncodeToString(hash.Sum(nil))
777                         if checkoid != oid {
778                                 log.Fatal(fmt.Sprintf("Incorrect oid after calculation, got %q expected %q", checkoid, oid))
779                                 w.WriteHeader(403)
780                                 return
781                         }
782
783                         b := buf.Bytes()
784                         largeObjects.Set(repo, oid, b)
785                         w.Header().Set("Upload-Offset", strconv.FormatInt(int64(len(b)), 10))
786                         w.WriteHeader(204)
787                 }
788
789         default:
790                 w.WriteHeader(405)
791         }
792 }
793
794 func validateTusHeaders(r *http.Request, id string) bool {
795         if len(r.Header.Get("Tus-Resumable")) == 0 {
796                 debug(id, "Missing Tus-Resumable header in request")
797                 return false
798         }
799         return true
800 }
801
802 func gitHandler(w http.ResponseWriter, r *http.Request) {
803         defer func() {
804                 io.Copy(ioutil.Discard, r.Body)
805                 r.Body.Close()
806         }()
807
808         cmd := exec.Command("git", "http-backend")
809         cmd.Env = []string{
810                 fmt.Sprintf("GIT_PROJECT_ROOT=%s", repoDir),
811                 fmt.Sprintf("GIT_HTTP_EXPORT_ALL="),
812                 fmt.Sprintf("PATH_INFO=%s", r.URL.Path),
813                 fmt.Sprintf("QUERY_STRING=%s", r.URL.RawQuery),
814                 fmt.Sprintf("REQUEST_METHOD=%s", r.Method),
815                 fmt.Sprintf("CONTENT_TYPE=%s", r.Header.Get("Content-Type")),
816         }
817
818         buffer := &bytes.Buffer{}
819         cmd.Stdin = r.Body
820         cmd.Stdout = buffer
821         cmd.Stderr = os.Stderr
822
823         if err := cmd.Run(); err != nil {
824                 log.Fatal(err)
825         }
826
827         text := textproto.NewReader(bufio.NewReader(buffer))
828
829         code, _, _ := text.ReadCodeLine(-1)
830
831         if code != 0 {
832                 w.WriteHeader(code)
833         }
834
835         headers, _ := text.ReadMIMEHeader()
836         head := w.Header()
837         for key, values := range headers {
838                 for _, value := range values {
839                         head.Add(key, value)
840                 }
841         }
842
843         io.Copy(w, text.R)
844 }
845
846 func redirect307Handler(w http.ResponseWriter, r *http.Request) {
847         id, ok := reqId(w)
848         if !ok {
849                 return
850         }
851
852         // Send a redirect to info/lfs
853         // Make it either absolute or relative depending on subpath
854         parts := strings.Split(r.URL.Path, "/")
855         // first element is always blank since rooted
856         var redirectTo string
857         if parts[2] == "rel" {
858                 redirectTo = "/" + strings.Join(parts[3:], "/")
859         } else if parts[2] == "abs" {
860                 redirectTo = server.URL + "/" + strings.Join(parts[3:], "/")
861         } else {
862                 debug(id, "Invalid URL for redirect: %v", r.URL)
863                 w.WriteHeader(404)
864                 return
865         }
866         w.Header().Set("Location", redirectTo)
867         w.WriteHeader(307)
868 }
869
// Wire types for the Git LFS locking API served by this test server.
type (
	// User is the owner of a lock.
	User struct {
		Name string `json:"name"`
	}

	// Lock is a single file lock as returned to clients.
	Lock struct {
		Id       string    `json:"id"`
		Path     string    `json:"path"`
		Owner    User      `json:"owner"`
		LockedAt time.Time `json:"locked_at"`
	}

	// LockRequest is the JSON body of a create-lock request.
	LockRequest struct {
		Path string `json:"path"`
		Ref  *Ref   `json:"ref,omitempty"`
	}

	// LockResponse carries either the created Lock or an error Message.
	LockResponse struct {
		Lock    *Lock  `json:"lock"`
		Message string `json:"message,omitempty"`
	}

	// UnlockRequest is the JSON body of an unlock request.
	UnlockRequest struct {
		Force bool `json:"force"`
		Ref   *Ref `json:"ref,omitempty"`
	}

	// UnlockResponse carries either the removed Lock or an error Message.
	UnlockResponse struct {
		Lock    *Lock  `json:"lock"`
		Message string `json:"message,omitempty"`
	}

	// LockList is one page of a lock listing.
	LockList struct {
		Locks      []Lock `json:"locks"`
		NextCursor string `json:"next_cursor,omitempty"`
		Message    string `json:"message,omitempty"`
	}

	// Ref names the Git ref a lock operation applies to.
	Ref struct {
		Name string `json:"name,omitempty"`
	}

	// VerifiableLockRequest is the JSON body of a lock-verification request.
	VerifiableLockRequest struct {
		Ref    *Ref   `json:"ref,omitempty"`
		Cursor string `json:"cursor,omitempty"`
		Limit  int    `json:"limit,omitempty"`
	}

	// VerifiableLockList is one page of a verification response, with the
	// locks split into the caller's own ("ours") and other users' ("theirs").
	VerifiableLockList struct {
		Ours       []Lock `json:"ours"`
		Theirs     []Lock `json:"theirs"`
		NextCursor string `json:"next_cursor,omitempty"`
		Message    string `json:"message,omitempty"`
	}
)

// refNameOf returns ref.Name, or "" when no ref was supplied.
func refNameOf(ref *Ref) string {
	if ref != nil {
		return ref.Name
	}
	return ""
}

// RefName returns the requested ref name, or "" when the request has none.
func (r *LockRequest) RefName() string { return refNameOf(r.Ref) }

// RefName returns the requested ref name, or "" when the request has none.
func (r *UnlockRequest) RefName() string { return refNameOf(r.Ref) }

// RefName returns the requested ref name, or "" when the request has none.
func (r *VerifiableLockRequest) RefName() string { return refNameOf(r.Ref) }
944
945 var (
946         lmu       sync.RWMutex
947         repoLocks = map[string][]Lock{}
948 )
949
950 func addLocks(repo string, l ...Lock) {
951         lmu.Lock()
952         defer lmu.Unlock()
953         repoLocks[repo] = append(repoLocks[repo], l...)
954         sort.Sort(LocksByCreatedAt(repoLocks[repo]))
955 }
956
957 func getLocks(repo string) []Lock {
958         lmu.RLock()
959         defer lmu.RUnlock()
960
961         locks := repoLocks[repo]
962         cp := make([]Lock, len(locks))
963         for i, l := range locks {
964                 cp[i] = l
965         }
966
967         return cp
968 }
969
970 func getFilteredLocks(repo, path, cursor, limit string) ([]Lock, string, error) {
971         locks := getLocks(repo)
972         if cursor != "" {
973                 lastSeen := -1
974                 for i, l := range locks {
975                         if l.Id == cursor {
976                                 lastSeen = i
977                                 break
978                         }
979                 }
980
981                 if lastSeen > -1 {
982                         locks = locks[lastSeen:]
983                 } else {
984                         return nil, "", fmt.Errorf("cursor (%s) not found", cursor)
985                 }
986         }
987
988         if path != "" {
989                 var filtered []Lock
990                 for _, l := range locks {
991                         if l.Path == path {
992                                 filtered = append(filtered, l)
993                         }
994                 }
995
996                 locks = filtered
997         }
998
999         if limit != "" {
1000                 size, err := strconv.Atoi(limit)
1001                 if err != nil {
1002                         return nil, "", errors.New("unable to parse limit amount")
1003                 }
1004
1005                 size = int(math.Min(float64(len(locks)), 3))
1006                 if size < 0 {
1007                         return nil, "", nil
1008                 }
1009
1010                 if size+1 < len(locks) {
1011                         return locks[:size], locks[size+1].Id, nil
1012                 }
1013         }
1014
1015         return locks, "", nil
1016 }
1017
1018 func delLock(repo string, id string) *Lock {
1019         lmu.RLock()
1020         defer lmu.RUnlock()
1021
1022         var deleted *Lock
1023         locks := make([]Lock, 0, len(repoLocks[repo]))
1024         for _, l := range repoLocks[repo] {
1025                 if l.Id == id {
1026                         deleted = &l
1027                         continue
1028                 }
1029                 locks = append(locks, l)
1030         }
1031         repoLocks[repo] = locks
1032         return deleted
1033 }
1034
1035 type LocksByCreatedAt []Lock
1036
1037 func (c LocksByCreatedAt) Len() int           { return len(c) }
1038 func (c LocksByCreatedAt) Less(i, j int) bool { return c[i].LockedAt.Before(c[j].LockedAt) }
1039 func (c LocksByCreatedAt) Swap(i, j int)      { c[i], c[j] = c[j], c[i] }
1040
// lockRe matches lock-listing paths: paths ending in "/locks", with an
// optional trailing slash.
var lockRe = regexp.MustCompile(`/locks/?$`)

// unlockRe matches ".../locks/<id>/unlock" at the end of a path and captures <id>.
var unlockRe = regexp.MustCompile(`locks/([^/]+)/unlock\z`)
1045
1046 func locksHandler(w http.ResponseWriter, r *http.Request, repo string) {
1047         dec := json.NewDecoder(r.Body)
1048         enc := json.NewEncoder(w)
1049
1050         switch r.Method {
1051         case "GET":
1052                 if !lockRe.MatchString(r.URL.Path) {
1053                         w.Header().Set("Content-Type", "application/json")
1054                         w.WriteHeader(http.StatusNotFound)
1055                         w.Write([]byte(`{"message":"unknown path: ` + r.URL.Path + `"}`))
1056                         return
1057                 }
1058
1059                 if err := r.ParseForm(); err != nil {
1060                         http.Error(w, "could not parse form values", http.StatusInternalServerError)
1061                         return
1062                 }
1063
1064                 if strings.HasSuffix(repo, "branch-required") {
1065                         parts := strings.Split(repo, "-")
1066                         lenParts := len(parts)
1067                         if lenParts > 3 && "refs/heads/"+parts[lenParts-3] != r.FormValue("refspec") {
1068                                 w.WriteHeader(403)
1069                                 enc.Encode(struct {
1070                                         Message string `json:"message"`
1071                                 }{fmt.Sprintf("Expected ref %q, got %q", "refs/heads/"+parts[lenParts-3], r.FormValue("refspec"))})
1072                                 return
1073                         }
1074                 }
1075
1076                 ll := &LockList{}
1077                 w.Header().Set("Content-Type", "application/json")
1078                 locks, nextCursor, err := getFilteredLocks(repo,
1079                         r.FormValue("path"),
1080                         r.FormValue("cursor"),
1081                         r.FormValue("limit"))
1082
1083                 if err != nil {
1084                         ll.Message = err.Error()
1085                 } else {
1086                         ll.Locks = locks
1087                         ll.NextCursor = nextCursor
1088                 }
1089
1090                 enc.Encode(ll)
1091                 return
1092         case "POST":
1093                 w.Header().Set("Content-Type", "application/json")
1094                 if strings.HasSuffix(r.URL.Path, "unlock") {
1095                         var lockId string
1096                         if matches := unlockRe.FindStringSubmatch(r.URL.Path); len(matches) > 1 {
1097                                 lockId = matches[1]
1098                         }
1099
1100                         if len(lockId) == 0 {
1101                                 enc.Encode(&UnlockResponse{Message: "Invalid lock"})
1102                         }
1103
1104                         unlockRequest := &UnlockRequest{}
1105                         if err := dec.Decode(unlockRequest); err != nil {
1106                                 enc.Encode(&UnlockResponse{Message: err.Error()})
1107                                 return
1108                         }
1109
1110                         if strings.HasSuffix(repo, "branch-required") {
1111                                 parts := strings.Split(repo, "-")
1112                                 lenParts := len(parts)
1113                                 if lenParts > 3 && "refs/heads/"+parts[lenParts-3] != unlockRequest.RefName() {
1114                                         w.WriteHeader(403)
1115                                         enc.Encode(struct {
1116                                                 Message string `json:"message"`
1117                                         }{fmt.Sprintf("Expected ref %q, got %q", "refs/heads/"+parts[lenParts-3], unlockRequest.RefName())})
1118                                         return
1119                                 }
1120                         }
1121
1122                         if l := delLock(repo, lockId); l != nil {
1123                                 enc.Encode(&UnlockResponse{Lock: l})
1124                         } else {
1125                                 enc.Encode(&UnlockResponse{Message: "unable to find lock"})
1126                         }
1127                         return
1128                 }
1129
1130                 if strings.HasSuffix(r.URL.Path, "/locks/verify") {
1131                         if strings.HasSuffix(repo, "verify-5xx") {
1132                                 w.WriteHeader(500)
1133                                 return
1134                         }
1135                         if strings.HasSuffix(repo, "verify-501") {
1136                                 w.WriteHeader(501)
1137                                 return
1138                         }
1139                         if strings.HasSuffix(repo, "verify-403") {
1140                                 w.WriteHeader(403)
1141                                 return
1142                         }
1143
1144                         switch repo {
1145                         case "pre_push_locks_verify_404":
1146                                 w.WriteHeader(http.StatusNotFound)
1147                                 w.Write([]byte(`{"message":"pre_push_locks_verify_404"}`))
1148                                 return
1149                         case "pre_push_locks_verify_410":
1150                                 w.WriteHeader(http.StatusGone)
1151                                 w.Write([]byte(`{"message":"pre_push_locks_verify_410"}`))
1152                                 return
1153                         }
1154
1155                         reqBody := &VerifiableLockRequest{}
1156                         if err := dec.Decode(reqBody); err != nil {
1157                                 w.WriteHeader(http.StatusBadRequest)
1158                                 enc.Encode(struct {
1159                                         Message string `json:"message"`
1160                                 }{"json decode error: " + err.Error()})
1161                                 return
1162                         }
1163
1164                         if strings.HasSuffix(repo, "branch-required") {
1165                                 parts := strings.Split(repo, "-")
1166                                 lenParts := len(parts)
1167                                 if lenParts > 3 && "refs/heads/"+parts[lenParts-3] != reqBody.RefName() {
1168                                         w.WriteHeader(403)
1169                                         enc.Encode(struct {
1170                                                 Message string `json:"message"`
1171                                         }{fmt.Sprintf("Expected ref %q, got %q", "refs/heads/"+parts[lenParts-3], reqBody.RefName())})
1172                                         return
1173                                 }
1174                         }
1175
1176                         ll := &VerifiableLockList{}
1177                         locks, nextCursor, err := getFilteredLocks(repo, "",
1178                                 reqBody.Cursor,
1179                                 strconv.Itoa(reqBody.Limit))
1180                         if err != nil {
1181                                 ll.Message = err.Error()
1182                         } else {
1183                                 ll.NextCursor = nextCursor
1184
1185                                 for _, l := range locks {
1186                                         if strings.Contains(l.Path, "theirs") {
1187                                                 ll.Theirs = append(ll.Theirs, l)
1188                                         } else {
1189                                                 ll.Ours = append(ll.Ours, l)
1190                                         }
1191                                 }
1192                         }
1193
1194                         enc.Encode(ll)
1195                         return
1196                 }
1197
1198                 if strings.HasSuffix(r.URL.Path, "/locks") {
1199                         lockRequest := &LockRequest{}
1200                         if err := dec.Decode(lockRequest); err != nil {
1201                                 enc.Encode(&LockResponse{Message: err.Error()})
1202                         }
1203
1204                         if strings.HasSuffix(repo, "branch-required") {
1205                                 parts := strings.Split(repo, "-")
1206                                 lenParts := len(parts)
1207                                 if lenParts > 3 && "refs/heads/"+parts[lenParts-3] != lockRequest.RefName() {
1208                                         w.WriteHeader(403)
1209                                         enc.Encode(struct {
1210                                                 Message string `json:"message"`
1211                                         }{fmt.Sprintf("Expected ref %q, got %q", "refs/heads/"+parts[lenParts-3], lockRequest.RefName())})
1212                                         return
1213                                 }
1214                         }
1215
1216                         for _, l := range getLocks(repo) {
1217                                 if l.Path == lockRequest.Path {
1218                                         enc.Encode(&LockResponse{Message: "lock already created"})
1219                                         return
1220                                 }
1221                         }
1222
1223                         var id [20]byte
1224                         rand.Read(id[:])
1225
1226                         lock := &Lock{
1227                                 Id:       fmt.Sprintf("%x", id[:]),
1228                                 Path:     lockRequest.Path,
1229                                 Owner:    User{Name: "Git LFS Tests"},
1230                                 LockedAt: time.Now(),
1231                         }
1232
1233                         addLocks(repo, *lock)
1234
1235                         // TODO(taylor): commit_needed case
1236                         // TODO(taylor): err case
1237
1238                         enc.Encode(&LockResponse{
1239                                 Lock: lock,
1240                         })
1241                         return
1242                 }
1243         }
1244
1245         http.NotFound(w, r)
1246 }
1247
1248 func missingRequiredCreds(w http.ResponseWriter, r *http.Request, repo string) bool {
1249         if !strings.HasPrefix(repo, "requirecreds") {
1250                 return false
1251         }
1252
1253         auth := r.Header.Get("Authorization")
1254         user, pass, err := extractAuth(auth)
1255         if err != nil {
1256                 writeLFSError(w, 403, err.Error())
1257                 return true
1258         }
1259
1260         if user != "requirecreds" || pass != "pass" {
1261                 writeLFSError(w, 403, fmt.Sprintf("Got: '%s' => '%s' : '%s'", auth, user, pass))
1262                 return true
1263         }
1264
1265         return false
1266 }
1267
// requestURLHasPrefix reports whether r.URL, rendered with URL.String,
// begins with prefix. The predicates below use such URL prefixes to switch
// on test-specific server behaviour.
func requestURLHasPrefix(r *http.Request, prefix string) bool {
	return strings.HasPrefix(r.URL.String(), prefix)
}

// testingChunkedTransferEncoding reports whether r targets "/test-chunked-transfer-encoding".
func testingChunkedTransferEncoding(r *http.Request) bool {
	return requestURLHasPrefix(r, "/test-chunked-transfer-encoding")
}

// testingTusUploadInBatchReq reports whether r targets "/test-tus-upload"
// (including "/test-tus-upload-interrupt").
func testingTusUploadInBatchReq(r *http.Request) bool {
	return requestURLHasPrefix(r, "/test-tus-upload")
}

// testingTusUploadInterruptedInBatchReq reports whether r targets "/test-tus-upload-interrupt".
func testingTusUploadInterruptedInBatchReq(r *http.Request) bool {
	return requestURLHasPrefix(r, "/test-tus-upload-interrupt")
}

// testingCustomTransfer reports whether r targets "/test-custom-transfer".
func testingCustomTransfer(r *http.Request) bool {
	return requestURLHasPrefix(r, "/test-custom-transfer")
}
1281
// lfsUrlRE captures the single path segment that precedes "/info/lfs" at the
// start of a URL path (an optional leading slash is allowed).
var lfsUrlRE = regexp.MustCompile(`\A/?([^/]+)/info/lfs`)

// repoFromLfsUrl extracts the repository name from an LFS URL path,
// dropping one trailing ".git" if present. It returns an error when the path
// does not match lfsUrlRE.
func repoFromLfsUrl(urlpath string) (string, error) {
	m := lfsUrlRE.FindStringSubmatch(urlpath)
	if len(m) == 2 {
		return strings.TrimSuffix(m[1], ".git"), nil
	}
	return "", fmt.Errorf("LFS url '%s' does not match %v", urlpath, lfsUrlRE)
}
1296
// lfsStorage is the test server's in-memory object store, keyed first by
// repository name and then by OID. Completed uploads live in objects;
// partially received uploads live in incomplete. A single mutex guards both.
type lfsStorage struct {
	objects    map[string]map[string][]byte
	incomplete map[string]map[string][]byte
	mutex      *sync.Mutex
}

// Get returns the stored content for oid in repo and whether it exists.
func (s *lfsStorage) Get(repo, oid string) ([]byte, bool) {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	// Indexing a missing inner map yields (nil, false), so no repo check is needed.
	content, found := s.objects[repo][oid]
	return content, found
}

// Has reports whether a completed object oid exists in repo.
func (s *lfsStorage) Has(repo, oid string) bool {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	_, found := s.objects[repo][oid]
	return found
}

// Set stores content as the completed object oid in repo, replacing any
// previous value.
func (s *lfsStorage) Set(repo, oid string, by []byte) {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	bucket, exists := s.objects[repo]
	if !exists {
		bucket = make(map[string][]byte)
		s.objects[repo] = bucket
	}
	bucket[oid] = by
}

// Delete removes the completed object oid from repo; it is a no-op when absent.
func (s *lfsStorage) Delete(repo, oid string) {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	// delete on a nil (missing) inner map is a no-op.
	delete(s.objects[repo], oid)
}

// GetIncomplete returns the partial upload recorded for oid in repo, if any.
func (s *lfsStorage) GetIncomplete(repo, oid string) ([]byte, bool) {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	partial, found := s.incomplete[repo][oid]
	return partial, found
}

// SetIncomplete records by as the partial upload for oid in repo.
func (s *lfsStorage) SetIncomplete(repo, oid string, by []byte) {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	bucket, exists := s.incomplete[repo]
	if !exists {
		bucket = make(map[string][]byte)
		s.incomplete[repo] = bucket
	}
	bucket[oid] = by
}

// DeleteIncomplete forgets the partial upload for oid in repo; it is a no-op
// when absent.
func (s *lfsStorage) DeleteIncomplete(repo, oid string) {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	delete(s.incomplete[repo], oid)
}

// newLfsStorage returns an empty, ready-to-use lfsStorage.
func newLfsStorage() *lfsStorage {
	return &lfsStorage{
		objects:    make(map[string]map[string][]byte),
		incomplete: make(map[string]map[string][]byte),
		mutex:      &sync.Mutex{},
	}
}
1386
// extractAuth decodes an HTTP Basic Authorization header value into
// (user, password). Headers that are not "Basic ..." and credentials without
// a ':' both yield empty strings and a nil error. Invalid base64 yields the
// decoding error. The password is everything after the first ':'.
func extractAuth(auth string) (string, string, error) {
	const basicPrefix = "Basic "
	if !strings.HasPrefix(auth, basicPrefix) {
		return "", "", nil
	}

	raw, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(auth, basicPrefix))
	if err != nil {
		return "", "", err
	}

	creds := string(raw)
	if sep := strings.Index(creds, ":"); sep >= 0 {
		return creds[:sep], creds[sep+1:], nil
	}
	return "", "", nil
}
1405
1406 func skipIfBadAuth(w http.ResponseWriter, r *http.Request, id string, ntlmSession ntlm.ServerSession) bool {
1407         auth := r.Header.Get("Authorization")
1408         if strings.Contains(r.URL.Path, "ntlm") {
1409                 return false
1410         }
1411
1412         if auth == "" {
1413                 w.WriteHeader(401)
1414                 return true
1415         }
1416
1417         user, pass, err := extractAuth(auth)
1418         if err != nil {
1419                 w.WriteHeader(403)
1420                 debug(id, "Error decoding auth: %s", err)
1421                 return true
1422         }
1423
1424         switch user {
1425         case "user":
1426                 if pass == "pass" {
1427                         return false
1428                 }
1429         case "netrcuser", "requirecreds":
1430                 return false
1431         case "path":
1432                 if strings.HasPrefix(r.URL.Path, "/"+pass) {
1433                         return false
1434                 }
1435                 debug(id, "auth attempt against: %q", r.URL.Path)
1436         }
1437
1438         w.WriteHeader(403)
1439         debug(id, "Bad auth: %q", auth)
1440         return true
1441 }
1442
1443 func handleNTLM(w http.ResponseWriter, r *http.Request, authHeader string, session ntlm.ServerSession) {
1444         if strings.HasPrefix(strings.ToUpper(authHeader), "BASIC ") {
1445                 authHeader = ""
1446         }
1447
1448         switch authHeader {
1449         case "":
1450                 w.Header().Set("Www-Authenticate", "ntlm")
1451                 w.WriteHeader(401)
1452
1453         // ntlmNegotiateMessage from httputil pkg
1454         case "NTLM TlRMTVNTUAABAAAAB7IIogwADAAzAAAACwALACgAAAAKAAAoAAAAD1dJTExISS1NQUlOTk9SVEhBTUVSSUNB":
1455                 ch, err := session.GenerateChallengeMessage()
1456                 if err != nil {
1457                         writeLFSError(w, 500, err.Error())
1458                         return
1459                 }
1460
1461                 chMsg := base64.StdEncoding.EncodeToString(ch.Bytes())
1462                 w.Header().Set("Www-Authenticate", "ntlm "+chMsg)
1463                 w.WriteHeader(401)
1464
1465         default:
1466                 if !strings.HasPrefix(strings.ToUpper(authHeader), "NTLM ") {
1467                         writeLFSError(w, 500, "bad authorization header: "+authHeader)
1468                         return
1469                 }
1470
1471                 auth := authHeader[5:] // strip "ntlm " prefix
1472                 val, err := base64.StdEncoding.DecodeString(auth)
1473                 if err != nil {
1474                         writeLFSError(w, 500, "base64 decode error: "+err.Error())
1475                         return
1476                 }
1477
1478                 _, err = ntlm.ParseAuthenticateMessage(val, 2)
1479                 if err != nil {
1480                         writeLFSError(w, 500, "auth parse error: "+err.Error())
1481                         return
1482                 }
1483         }
1484 }
1485
1486 func init() {
1487         oidHandlers = make(map[string]string)
1488         for _, content := range contentHandlers {
1489                 h := sha256.New()
1490                 h.Write([]byte(content))
1491                 oidHandlers[hex.EncodeToString(h.Sum(nil))] = content
1492         }
1493 }
1494
// debug logs msg, formatted with args, through the standard logger. The line
// is prefixed with "[reqid] " and terminated with a newline.
func debug(reqid, msg string, args ...interface{}) {
	// reqid fills the leading "[%s]" verb; args fill the verbs inside msg.
	log.Printf("[%s] "+msg+"\n", append([]interface{}{reqid}, args...)...)
}
1503
// reqId returns a random UUID-shaped request identifier: 16 random bytes as
// lowercase hex, grouped 8-4-4-4-12. If the random source fails it writes a
// 500 to w and returns ("", false).
func reqId(w http.ResponseWriter) (string, bool) {
	raw := make([]byte, 16)
	if _, err := rand.Read(raw); err != nil {
		http.Error(w, "error generating id: "+err.Error(), 500)
		return "", false
	}

	h := hex.EncodeToString(raw)
	return h[:8] + "-" + h[8:12] + "-" + h[12:16] + "-" + h[16:20] + "-" + h[20:], true
}
1513
1514 // https://ericchiang.github.io/post/go-tls/
1515 func generateCARootCertificates() (rootKey *rsa.PrivateKey, rootCert *x509.Certificate) {
1516
1517         // generate a new key-pair
1518         rootKey, err := rsa.GenerateKey(rand.Reader, 2048)
1519         if err != nil {
1520                 log.Fatalf("generating random key: %v", err)
1521         }
1522
1523         rootCertTmpl, err := CertTemplate()
1524         if err != nil {
1525                 log.Fatalf("creating cert template: %v", err)
1526         }
1527         // describe what the certificate will be used for
1528         rootCertTmpl.IsCA = true
1529         rootCertTmpl.KeyUsage = x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature
1530         rootCertTmpl.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}
1531         //      rootCertTmpl.IPAddresses = []net.IP{net.ParseIP("127.0.0.1")}
1532
1533         rootCert, _, err = CreateCert(rootCertTmpl, rootCertTmpl, &rootKey.PublicKey, rootKey)
1534
1535         return
1536 }
1537
1538 func generateClientCertificates(rootCert *x509.Certificate, rootKey interface{}) (clientKey *rsa.PrivateKey, clientCertPEM []byte, clientKeyPEM []byte) {
1539
1540         // create a key-pair for the client
1541         clientKey, err := rsa.GenerateKey(rand.Reader, 2048)
1542         if err != nil {
1543                 log.Fatalf("generating random key: %v", err)
1544         }
1545
1546         // create a template for the client
1547         clientCertTmpl, err1 := CertTemplate()
1548         if err1 != nil {
1549                 log.Fatalf("creating cert template: %v", err1)
1550         }
1551         clientCertTmpl.KeyUsage = x509.KeyUsageDigitalSignature
1552         clientCertTmpl.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}
1553
1554         // the root cert signs the cert by again providing its private key
1555         _, clientCertPEM, err2 := CreateCert(clientCertTmpl, rootCert, &clientKey.PublicKey, rootKey)
1556         if err2 != nil {
1557                 log.Fatalf("error creating cert: %v", err2)
1558         }
1559
1560         // encode and load the cert and private key for the client
1561         clientKeyPEM = pem.EncodeToMemory(&pem.Block{
1562                 Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(clientKey),
1563         })
1564
1565         return
1566 }
1567
// CertTemplate returns a certificate template with a random serial number
// below 2^128, subject organisation "Yhat, Inc.", SHA256-with-RSA signatures,
// valid basic constraints, and a validity window of about one hour starting
// now. Callers fill in key usages and CA status.
func CertTemplate() (*x509.Certificate, error) {
	// A real CA would allocate serials deliberately; random is enough for tests.
	limit := new(big.Int).Lsh(big.NewInt(1), 128)
	serial, err := rand.Int(rand.Reader, limit)
	if err != nil {
		return nil, errors.New("failed to generate serial number: " + err.Error())
	}

	return &x509.Certificate{
		SerialNumber:          serial,
		Subject:               pkix.Name{Organization: []string{"Yhat, Inc."}},
		SignatureAlgorithm:    x509.SHA256WithRSA,
		NotBefore:             time.Now(),
		NotAfter:              time.Now().Add(time.Hour), // valid for an hour
		BasicConstraintsValid: true,
	}, nil
}
1587
// CreateCert signs template with parentPriv, using parent as the issuer and
// pub as the subject's public key. It returns the parsed certificate and its
// PEM ("CERTIFICATE") encoding. On error, certPEM is nil, and so is cert
// when creation itself failed.
func CreateCert(template, parent *x509.Certificate, pub interface{}, parentPriv interface{}) (
	cert *x509.Certificate, certPEM []byte, err error) {

	der, err := x509.CreateCertificate(rand.Reader, template, parent, pub, parentPriv)
	if err != nil {
		return nil, nil, err
	}

	// Parse the DER back so callers get a usable *x509.Certificate.
	if cert, err = x509.ParseCertificate(der); err != nil {
		return cert, nil, err
	}

	return cert, pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}), nil
}