Imported Upstream version 2.4.2
[scm/test.git] / t / cmd / lfstest-gitserver.go
1 // +build testtools
2
3 package main
4
5 import (
6         "bufio"
7         "bytes"
8         "crypto/rand"
9         "crypto/rsa"
10         "crypto/sha256"
11         "crypto/tls"
12         "crypto/x509"
13         "crypto/x509/pkix"
14         "encoding/base64"
15         "encoding/hex"
16         "encoding/json"
17         "encoding/pem"
18         "errors"
19         "fmt"
20         "io"
21         "io/ioutil"
22         "log"
23         "math"
24         "math/big"
25         "net/http"
26         "net/http/httptest"
27         "net/textproto"
28         "os"
29         "os/exec"
30         "regexp"
31         "sort"
32         "strconv"
33         "strings"
34         "sync"
35         "time"
36
37         "github.com/ThomsonReutersEikon/go-ntlm/ntlm"
38 )
39
40 var (
41         repoDir          string
42         largeObjects     = newLfsStorage()
43         server           *httptest.Server
44         serverTLS        *httptest.Server
45         serverClientCert *httptest.Server
46
47         // maps OIDs to content strings. Both the LFS and Storage test servers below
48         // see OIDs.
49         oidHandlers map[string]string
50
51         // These magic strings tell the test lfs server change their behavior so the
52         // integration tests can check those use cases. Tests will create objects with
53         // the magic strings as the contents.
54         //
55         //   printf "status:lfs:404" > 404.dat
56         //
57         contentHandlers = []string{
58                 "status-batch-403", "status-batch-404", "status-batch-410", "status-batch-422", "status-batch-500",
59                 "status-storage-403", "status-storage-404", "status-storage-410", "status-storage-422", "status-storage-500", "status-storage-503",
60                 "status-batch-resume-206", "batch-resume-fail-fallback", "return-expired-action", "return-expired-action-forever", "return-invalid-size",
61                 "object-authenticated", "storage-download-retry", "storage-upload-retry", "unknown-oid",
62                 "send-verify-action", "send-deprecated-links",
63         }
64 )
65
66 func main() {
67         repoDir = os.Getenv("LFSTEST_DIR")
68
69         mux := http.NewServeMux()
70         server = httptest.NewServer(mux)
71         serverTLS = httptest.NewTLSServer(mux)
72         serverClientCert = httptest.NewUnstartedServer(mux)
73
74         //setup Client Cert server
75         rootKey, rootCert := generateCARootCertificates()
76         _, clientCertPEM, clientKeyPEM := generateClientCertificates(rootCert, rootKey)
77
78         certPool := x509.NewCertPool()
79         certPool.AddCert(rootCert)
80
81         serverClientCert.TLS = &tls.Config{
82                 Certificates: []tls.Certificate{serverTLS.TLS.Certificates[0]},
83                 ClientAuth:   tls.RequireAndVerifyClientCert,
84                 ClientCAs:    certPool,
85         }
86         serverClientCert.StartTLS()
87
88         ntlmSession, err := ntlm.CreateServerSession(ntlm.Version2, ntlm.ConnectionOrientedMode)
89         if err != nil {
90                 fmt.Println("Error creating ntlm session:", err)
91                 os.Exit(1)
92         }
93         ntlmSession.SetUserInfo("ntlmuser", "ntlmpass", "NTLMDOMAIN")
94
95         stopch := make(chan bool)
96
97         mux.HandleFunc("/shutdown", func(w http.ResponseWriter, r *http.Request) {
98                 stopch <- true
99         })
100
101         mux.HandleFunc("/storage/", storageHandler)
102         mux.HandleFunc("/verify", verifyHandler)
103         mux.HandleFunc("/redirect307/", redirect307Handler)
104         mux.HandleFunc("/status", func(w http.ResponseWriter, r *http.Request) {
105                 fmt.Fprintf(w, "%s\n", time.Now().String())
106         })
107         mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
108                 id, ok := reqId(w)
109                 if !ok {
110                         return
111                 }
112
113                 if strings.Contains(r.URL.Path, "/info/lfs") {
114                         if !skipIfBadAuth(w, r, id, ntlmSession) {
115                                 lfsHandler(w, r, id)
116                         }
117
118                         return
119                 }
120
121                 debug(id, "git http-backend %s %s", r.Method, r.URL)
122                 gitHandler(w, r)
123         })
124
125         urlname := writeTestStateFile([]byte(server.URL), "LFSTEST_URL", "lfstest-gitserver")
126         defer os.RemoveAll(urlname)
127
128         sslurlname := writeTestStateFile([]byte(serverTLS.URL), "LFSTEST_SSL_URL", "lfstest-gitserver-ssl")
129         defer os.RemoveAll(sslurlname)
130
131         clientCertUrlname := writeTestStateFile([]byte(serverClientCert.URL), "LFSTEST_CLIENT_CERT_URL", "lfstest-gitserver-ssl")
132         defer os.RemoveAll(clientCertUrlname)
133
134         block := &pem.Block{}
135         block.Type = "CERTIFICATE"
136         block.Bytes = serverTLS.TLS.Certificates[0].Certificate[0]
137         pembytes := pem.EncodeToMemory(block)
138
139         certname := writeTestStateFile(pembytes, "LFSTEST_CERT", "lfstest-gitserver-cert")
140         defer os.RemoveAll(certname)
141
142         cccertname := writeTestStateFile(clientCertPEM, "LFSTEST_CLIENT_CERT", "lfstest-gitserver-client-cert")
143         defer os.RemoveAll(cccertname)
144
145         ckcertname := writeTestStateFile(clientKeyPEM, "LFSTEST_CLIENT_KEY", "lfstest-gitserver-client-key")
146         defer os.RemoveAll(ckcertname)
147
148         debug("init", "server url: %s", server.URL)
149         debug("init", "server tls url: %s", serverTLS.URL)
150         debug("init", "server client cert url: %s", serverClientCert.URL)
151
152         <-stopch
153         debug("init", "git server done")
154 }
155
// writeTestStateFile writes contents to the file named by the environment
// variable envVar, or to defaultFilename when envVar is unset or empty. Any
// existing file is truncated. It returns the filename that was used.
//
// Any failure to create or fully write the file is fatal. The test scripts
// read these files, so a silently short or missing file would only surface
// later as a confusing test failure.
func writeTestStateFile(contents []byte, envVar, defaultFilename string) string {
	f := os.Getenv(envVar)
	if len(f) == 0 {
		f = defaultFilename
	}
	// Same 0666 mode (before umask) that os.Create used. Unlike the previous
	// Write/Close pair, this reports write and close errors.
	if err := ioutil.WriteFile(f, contents, 0666); err != nil {
		log.Fatalln(err)
	}
	return f
}
172
// lfsObject is one object entry in a batch API request or response.
//
// Requests supply Oid and Size. Responses may additionally set:
//   - Authenticated;
//   - Actions (operation name -> link);
//   - Links, the deprecated "_links" form, which lfsBatchHandler fills
//     instead of Actions for the "send-deprecated-links" handler;
//   - Err, a per-object error.
//
// The field order determines the JSON key order.
type lfsObject struct {
	Oid           string              `json:"oid,omitempty"`
	Size          int64               `json:"size,omitempty"`
	Authenticated bool                `json:"authenticated,omitempty"`
	Actions       map[string]*lfsLink `json:"actions,omitempty"`
	Links         map[string]*lfsLink `json:"_links,omitempty"`
	Err           *lfsError           `json:"error,omitempty"`
}
181
// lfsLink is an action link in a batch response: the URL to call, extra
// request headers to send with it, and optional expiry information.
//
// NOTE: encoding/json's omitempty never omits struct values, so ExpiresAt is
// always emitted ("0001-01-01T00:00:00Z" when unset). ExpiresIn is omitted
// when it is 0.
type lfsLink struct {
	Href      string            `json:"href"`
	Header    map[string]string `json:"header,omitempty"`
	ExpiresAt time.Time         `json:"expires_at,omitempty"`
	ExpiresIn int               `json:"expires_in,omitempty"`
}
188
// lfsError is an LFS error body. It is used both as a per-object "error" in
// batch responses and as the top-level body written by writeLFSError.
// Code is omitted from the JSON when it is zero; Message is always present.
type lfsError struct {
	Code    int    `json:"code,omitempty"`
	Message string `json:"message"`
}
193
194 func writeLFSError(w http.ResponseWriter, code int, msg string) {
195         by, err := json.Marshal(&lfsError{Message: msg})
196         if err != nil {
197                 http.Error(w, "json encoding error: "+err.Error(), 500)
198                 return
199         }
200
201         w.Header().Set("Content-Type", "application/vnd.git-lfs+json")
202         w.WriteHeader(code)
203         w.Write(by)
204 }
205
206 // handles any requests with "{name}.server.git/info/lfs" in the path
207 func lfsHandler(w http.ResponseWriter, r *http.Request, id string) {
208         repo, err := repoFromLfsUrl(r.URL.Path)
209         if err != nil {
210                 w.WriteHeader(500)
211                 w.Write([]byte(err.Error()))
212                 return
213         }
214
215         debug(id, "git lfs %s %s repo: %s", r.Method, r.URL, repo)
216         w.Header().Set("Content-Type", "application/vnd.git-lfs+json")
217         switch r.Method {
218         case "POST":
219                 if strings.HasSuffix(r.URL.String(), "batch") {
220                         lfsBatchHandler(w, r, id, repo)
221                 } else {
222                         locksHandler(w, r, repo)
223                 }
224         case "DELETE":
225                 lfsDeleteHandler(w, r, id, repo)
226         case "GET":
227                 if strings.Contains(r.URL.String(), "/locks") {
228                         locksHandler(w, r, repo)
229                 } else {
230                         w.WriteHeader(404)
231                         w.Write([]byte("lock request"))
232                 }
233         default:
234                 w.WriteHeader(405)
235         }
236 }
237
238 func lfsUrl(repo, oid string) string {
239         return server.URL + "/storage/" + oid + "?r=" + repo
240 }
241
// retries counts how many times each retry handler has been hit, keyed by
// "<direction>:<repo>:<oid>". Guarded by retriesMu.
var retries = map[string]uint32{}

// retriesMu guards retries.
var retriesMu sync.Mutex
246
247 func incrementRetriesFor(api, direction, repo, oid string, check bool) (after uint32, ok bool) {
248         // fmtStr formats a string like "<api>-<direction>-[check]-<retry>",
249         // i.e., "legacy-upload-check-retry", or "storage-download-retry".
250         var fmtStr string
251         if check {
252                 fmtStr = "%s-%s-check-retry"
253         } else {
254                 fmtStr = "%s-%s-retry"
255         }
256
257         if oidHandlers[oid] != fmt.Sprintf(fmtStr, api, direction) {
258                 return 0, false
259         }
260
261         retriesMu.Lock()
262         defer retriesMu.Unlock()
263
264         retryKey := strings.Join([]string{direction, repo, oid}, ":")
265
266         retries[retryKey]++
267         retries := retries[retryKey]
268
269         return retries, true
270 }
271
272 func lfsDeleteHandler(w http.ResponseWriter, r *http.Request, id, repo string) {
273         parts := strings.Split(r.URL.Path, "/")
274         oid := parts[len(parts)-1]
275
276         largeObjects.Delete(repo, oid)
277         debug(id, "DELETE:", oid)
278         w.WriteHeader(200)
279 }
280
// batchReq is the decoded body of a batch API request.
type batchReq struct {
	Transfers []string    `json:"transfers"` // transfer adapters offered by the client
	Operation string      `json:"operation"` // e.g. "download"; lfsBatchHandler treats anything else as an upload
	Objects   []lfsObject `json:"objects"`   // objects the client asks about
	Ref       *Ref        `json:"ref,omitempty"` // optional ref; nil when absent (see RefName)
}
287
288 func (r *batchReq) RefName() string {
289         if r.Ref == nil {
290                 return ""
291         }
292         return r.Ref.Name
293 }
294
// batchResp is the body of a successful batch API response.
type batchResp struct {
	Transfer string      `json:"transfer,omitempty"` // chosen transfer adapter; omitted when empty
	Objects  []lfsObject `json:"objects"`            // one entry per requested object
}
299
300 func lfsBatchHandler(w http.ResponseWriter, r *http.Request, id, repo string) {
301         checkingObject := r.Header.Get("X-Check-Object") == "1"
302         if !checkingObject && repo == "batchunsupported" {
303                 w.WriteHeader(404)
304                 return
305         }
306
307         if !checkingObject && repo == "badbatch" {
308                 w.WriteHeader(203)
309                 return
310         }
311
312         if repo == "netrctest" {
313                 user, pass, err := extractAuth(r.Header.Get("Authorization"))
314                 if err != nil || (user != "netrcuser" || pass != "netrcpass") {
315                         w.WriteHeader(403)
316                         return
317                 }
318         }
319
320         if missingRequiredCreds(w, r, repo) {
321                 return
322         }
323
324         buf := &bytes.Buffer{}
325         tee := io.TeeReader(r.Body, buf)
326         objs := &batchReq{}
327         err := json.NewDecoder(tee).Decode(objs)
328         io.Copy(ioutil.Discard, r.Body)
329         r.Body.Close()
330
331         debug(id, "REQUEST")
332         debug(id, buf.String())
333
334         if err != nil {
335                 log.Fatal(err)
336         }
337
338         if strings.HasSuffix(repo, "branch-required") {
339                 parts := strings.Split(repo, "-")
340                 lenParts := len(parts)
341                 if lenParts > 3 && "refs/heads/"+parts[lenParts-3] != objs.RefName() {
342                         w.WriteHeader(403)
343                         json.NewEncoder(w).Encode(struct {
344                                 Message string `json:"message"`
345                         }{fmt.Sprintf("Expected ref %q, got %q", "refs/heads/"+parts[lenParts-3], objs.RefName())})
346                         return
347                 }
348         }
349
350         res := []lfsObject{}
351         testingChunked := testingChunkedTransferEncoding(r)
352         testingTus := testingTusUploadInBatchReq(r)
353         testingTusInterrupt := testingTusUploadInterruptedInBatchReq(r)
354         testingCustomTransfer := testingCustomTransfer(r)
355         var transferChoice string
356         var searchForTransfer string
357         if testingTus {
358                 searchForTransfer = "tus"
359         } else if testingCustomTransfer {
360                 searchForTransfer = "testcustom"
361         }
362         if len(searchForTransfer) > 0 {
363                 for _, t := range objs.Transfers {
364                         if t == searchForTransfer {
365                                 transferChoice = searchForTransfer
366                                 break
367                         }
368
369                 }
370         }
371         for _, obj := range objs.Objects {
372                 handler := oidHandlers[obj.Oid]
373                 action := objs.Operation
374
375                 o := lfsObject{
376                         Size:    obj.Size,
377                         Actions: make(map[string]*lfsLink),
378                 }
379
380                 // Clobber the OID if told to do so.
381                 if handler == "unknown-oid" {
382                         o.Oid = "unknown-oid"
383                 } else {
384                         o.Oid = obj.Oid
385                 }
386
387                 exists := largeObjects.Has(repo, obj.Oid)
388                 addAction := true
389                 if action == "download" {
390                         if !exists {
391                                 o.Err = &lfsError{Code: 404, Message: fmt.Sprintf("Object %v does not exist", obj.Oid)}
392                                 addAction = false
393                         }
394                 } else {
395                         if exists {
396                                 // not an error but don't add an action
397                                 addAction = false
398                         }
399                 }
400
401                 if handler == "object-authenticated" {
402                         o.Authenticated = true
403                 }
404
405                 switch handler {
406                 case "status-batch-403":
407                         o.Err = &lfsError{Code: 403, Message: "welp"}
408                 case "status-batch-404":
409                         o.Err = &lfsError{Code: 404, Message: "welp"}
410                 case "status-batch-410":
411                         o.Err = &lfsError{Code: 410, Message: "welp"}
412                 case "status-batch-422":
413                         o.Err = &lfsError{Code: 422, Message: "welp"}
414                 case "status-batch-500":
415                         o.Err = &lfsError{Code: 500, Message: "welp"}
416                 default: // regular 200 response
417                         if handler == "return-invalid-size" {
418                                 o.Size = -1
419                         }
420
421                         if handler == "send-deprecated-links" {
422                                 o.Links = make(map[string]*lfsLink)
423                         }
424
425                         if addAction {
426                                 a := &lfsLink{
427                                         Href:   lfsUrl(repo, obj.Oid),
428                                         Header: map[string]string{},
429                                 }
430                                 a = serveExpired(a, repo, handler)
431
432                                 if handler == "send-deprecated-links" {
433                                         o.Links[action] = a
434                                 } else {
435                                         o.Actions[action] = a
436                                 }
437                         }
438
439                         if handler == "send-verify-action" {
440                                 o.Actions["verify"] = &lfsLink{
441                                         Href: server.URL + "/verify",
442                                         Header: map[string]string{
443                                                 "repo": repo,
444                                         },
445                                 }
446                         }
447                 }
448
449                 if testingChunked && addAction {
450                         if handler == "send-deprecated-links" {
451                                 o.Links[action].Header["Transfer-Encoding"] = "chunked"
452                         } else {
453                                 o.Actions[action].Header["Transfer-Encoding"] = "chunked"
454                         }
455                 }
456                 if testingTusInterrupt && addAction {
457                         if handler == "send-deprecated-links" {
458                                 o.Links[action].Header["Lfs-Tus-Interrupt"] = "true"
459                         } else {
460                                 o.Actions[action].Header["Lfs-Tus-Interrupt"] = "true"
461                         }
462                 }
463
464                 res = append(res, o)
465         }
466
467         ores := batchResp{Transfer: transferChoice, Objects: res}
468
469         by, err := json.Marshal(ores)
470         if err != nil {
471                 log.Fatal(err)
472         }
473
474         debug(id, "RESPONSE: 200")
475         debug(id, string(by))
476
477         w.WriteHeader(200)
478         w.Write(by)
479 }
480
var (
	// emu guards expiredRepos.
	emu sync.Mutex

	// expiredRepos records, per repository name, whether that repository has
	// already served an expired object.
	expiredRepos = map[string]bool{}
)
487
488 // serveExpired marks the given repo as having served an expired object, making
489 // it unable for that same repository to return an expired object in the future,
490 func serveExpired(a *lfsLink, repo, handler string) *lfsLink {
491         var (
492                 dur = -5 * time.Minute
493                 at  = time.Now().Add(dur)
494         )
495
496         if handler == "return-expired-action-forever" ||
497                 (handler == "return-expired-action" && canServeExpired(repo)) {
498
499                 emu.Lock()
500                 expiredRepos[repo] = true
501                 emu.Unlock()
502
503                 a.ExpiresAt = at
504                 return a
505         }
506
507         switch repo {
508         case "expired-absolute":
509                 a.ExpiresAt = at
510         case "expired-relative":
511                 a.ExpiresIn = -5
512         case "expired-both":
513                 a.ExpiresAt = at
514                 a.ExpiresIn = -5
515         }
516
517         return a
518 }
519
520 // canServeExpired returns whether or not a repository is capable of serving an
521 // expired object. In other words, canServeExpired returns whether or not the
522 // given repo has yet served an expired object.
523 func canServeExpired(repo string) bool {
524         emu.Lock()
525         defer emu.Unlock()
526
527         return !expiredRepos[repo]
528 }
529
// Counters that persist across storage requests. Both start at zero and are
// not guarded by a lock.
var (
	batchResumeFailFallbackStorageAttempts int
	tusStorageAttempts                     int
)

// State for the verify endpoint.
var (
	// vmu guards verifyCounts.
	vmu sync.Mutex

	// verifyCounts maps "<repo>:<oid>" to the number of verify requests seen.
	verifyCounts = map[string]int{}

	// verifyRetryRe matches repository names ending in "verify-fail-<N>-time"
	// or "verify-fail-<N>-times", capturing N.
	verifyRetryRe = regexp.MustCompile(`verify-fail-(\d+)-times?$`)
)
539
540 func verifyHandler(w http.ResponseWriter, r *http.Request) {
541         repo := r.Header.Get("repo")
542         var payload struct {
543                 Oid  string `json:"oid"`
544                 Size int64  `json:"size"`
545         }
546
547         if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
548                 writeLFSError(w, http.StatusUnprocessableEntity, err.Error())
549                 return
550         }
551
552         var max int
553         if matches := verifyRetryRe.FindStringSubmatch(repo); len(matches) < 2 {
554                 return
555         } else {
556                 max, _ = strconv.Atoi(matches[1])
557         }
558
559         key := strings.Join([]string{repo, payload.Oid}, ":")
560
561         vmu.Lock()
562         verifyCounts[key] = verifyCounts[key] + 1
563         count := verifyCounts[key]
564         vmu.Unlock()
565
566         if count < max {
567                 writeLFSError(w, http.StatusServiceUnavailable, fmt.Sprintf(
568                         "intentionally failing verify request %d (out of %d)", count, max,
569                 ))
570                 return
571         }
572 }
573
574 // handles any /storage/{oid} requests
575 func storageHandler(w http.ResponseWriter, r *http.Request) {
576         id, ok := reqId(w)
577         if !ok {
578                 return
579         }
580
581         repo := r.URL.Query().Get("r")
582         parts := strings.Split(r.URL.Path, "/")
583         oid := parts[len(parts)-1]
584         if missingRequiredCreds(w, r, repo) {
585                 return
586         }
587
588         debug(id, "storage %s %s repo: %s", r.Method, oid, repo)
589         switch r.Method {
590         case "PUT":
591                 switch oidHandlers[oid] {
592                 case "status-storage-403":
593                         w.WriteHeader(403)
594                         return
595                 case "status-storage-404":
596                         w.WriteHeader(404)
597                         return
598                 case "status-storage-410":
599                         w.WriteHeader(410)
600                         return
601                 case "status-storage-422":
602                         w.WriteHeader(422)
603                         return
604                 case "status-storage-500":
605                         w.WriteHeader(500)
606                         return
607                 case "status-storage-503":
608                         writeLFSError(w, 503, "LFS is temporarily unavailable")
609                         return
610                 case "object-authenticated":
611                         if len(r.Header.Get("Authorization")) > 0 {
612                                 w.WriteHeader(400)
613                                 w.Write([]byte("Should not send authentication"))
614                         }
615                         return
616                 case "storage-upload-retry":
617                         if retries, ok := incrementRetriesFor("storage", "upload", repo, oid, false); ok && retries < 3 {
618                                 w.WriteHeader(500)
619                                 w.Write([]byte("malformed content"))
620
621                                 return
622                         }
623                 }
624
625                 if testingChunkedTransferEncoding(r) {
626                         valid := false
627                         for _, value := range r.TransferEncoding {
628                                 if value == "chunked" {
629                                         valid = true
630                                         break
631                                 }
632                         }
633                         if !valid {
634                                 debug(id, "Chunked transfer encoding expected")
635                         }
636                 }
637
638                 hash := sha256.New()
639                 buf := &bytes.Buffer{}
640
641                 io.Copy(io.MultiWriter(hash, buf), r.Body)
642                 oid := hex.EncodeToString(hash.Sum(nil))
643                 if !strings.HasSuffix(r.URL.Path, "/"+oid) {
644                         w.WriteHeader(403)
645                         return
646                 }
647
648                 largeObjects.Set(repo, oid, buf.Bytes())
649
650         case "GET":
651                 parts := strings.Split(r.URL.Path, "/")
652                 oid := parts[len(parts)-1]
653                 statusCode := 200
654                 byteLimit := 0
655                 resumeAt := int64(0)
656
657                 if by, ok := largeObjects.Get(repo, oid); ok {
658                         if len(by) == len("storage-download-retry") && string(by) == "storage-download-retry" {
659                                 if retries, ok := incrementRetriesFor("storage", "download", repo, oid, false); ok && retries < 3 {
660                                         statusCode = 500
661                                         by = []byte("malformed content")
662                                 }
663                         } else if len(by) == len("status-batch-resume-206") && string(by) == "status-batch-resume-206" {
664                                 // Resume if header includes range, otherwise deliberately interrupt
665                                 if rangeHdr := r.Header.Get("Range"); rangeHdr != "" {
666                                         regex := regexp.MustCompile(`bytes=(\d+)\-.*`)
667                                         match := regex.FindStringSubmatch(rangeHdr)
668                                         if match != nil && len(match) > 1 {
669                                                 statusCode = 206
670                                                 resumeAt, _ = strconv.ParseInt(match[1], 10, 32)
671                                                 w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", resumeAt, len(by), resumeAt-int64(len(by))))
672                                         }
673                                 } else {
674                                         byteLimit = 10
675                                 }
676                         } else if len(by) == len("batch-resume-fail-fallback") && string(by) == "batch-resume-fail-fallback" {
677                                 // Fail any Range: request even though we said we supported it
678                                 // To make sure client can fall back
679                                 if rangeHdr := r.Header.Get("Range"); rangeHdr != "" {
680                                         w.WriteHeader(416)
681                                         return
682                                 }
683                                 if batchResumeFailFallbackStorageAttempts == 0 {
684                                         // Truncate output on FIRST attempt to cause resume
685                                         // Second attempt (without range header) is fallback, complete successfully
686                                         byteLimit = 8
687                                         batchResumeFailFallbackStorageAttempts++
688                                 }
689                         }
690                         w.WriteHeader(statusCode)
691                         if byteLimit > 0 {
692                                 w.Write(by[0:byteLimit])
693                         } else if resumeAt > 0 {
694                                 w.Write(by[resumeAt:])
695                         } else {
696                                 w.Write(by)
697                         }
698                         return
699                 }
700
701                 w.WriteHeader(404)
702         case "HEAD":
703                 // tus.io
704                 if !validateTusHeaders(r, id) {
705                         w.WriteHeader(400)
706                         return
707                 }
708                 parts := strings.Split(r.URL.Path, "/")
709                 oid := parts[len(parts)-1]
710                 var offset int64
711                 if by, ok := largeObjects.GetIncomplete(repo, oid); ok {
712                         offset = int64(len(by))
713                 }
714                 w.Header().Set("Upload-Offset", strconv.FormatInt(offset, 10))
715                 w.WriteHeader(200)
716         case "PATCH":
717                 // tus.io
718                 if !validateTusHeaders(r, id) {
719                         w.WriteHeader(400)
720                         return
721                 }
722                 parts := strings.Split(r.URL.Path, "/")
723                 oid := parts[len(parts)-1]
724
725                 offsetHdr := r.Header.Get("Upload-Offset")
726                 offset, err := strconv.ParseInt(offsetHdr, 10, 64)
727                 if err != nil {
728                         log.Fatal("Unable to parse Upload-Offset header in request: ", err)
729                         w.WriteHeader(400)
730                         return
731                 }
732                 hash := sha256.New()
733                 buf := &bytes.Buffer{}
734                 out := io.MultiWriter(hash, buf)
735
736                 if by, ok := largeObjects.GetIncomplete(repo, oid); ok {
737                         if offset != int64(len(by)) {
738                                 log.Fatal(fmt.Sprintf("Incorrect offset in request, got %d expected %d", offset, len(by)))
739                                 w.WriteHeader(400)
740                                 return
741                         }
742                         _, err := out.Write(by)
743                         if err != nil {
744                                 log.Fatal("Error reading incomplete bytes from store: ", err)
745                                 w.WriteHeader(500)
746                                 return
747                         }
748                         largeObjects.DeleteIncomplete(repo, oid)
749                         debug(id, "Resuming upload of %v at byte %d", oid, offset)
750                 }
751
752                 // As a test, we intentionally break the upload from byte 0 by only
753                 // reading some bytes the quitting & erroring, this forces a resume
754                 // any offset > 0 will work ok
755                 var copyErr error
756                 if r.Header.Get("Lfs-Tus-Interrupt") == "true" && offset == 0 {
757                         chdr := r.Header.Get("Content-Length")
758                         contentLen, err := strconv.ParseInt(chdr, 10, 64)
759                         if err != nil {
760                                 log.Fatal(fmt.Sprintf("Invalid Content-Length %q", chdr))
761                                 w.WriteHeader(400)
762                                 return
763                         }
764                         truncated := contentLen / 3
765                         _, _ = io.CopyN(out, r.Body, truncated)
766                         r.Body.Close()
767                         copyErr = fmt.Errorf("Simulated copy error")
768                 } else {
769                         _, copyErr = io.Copy(out, r.Body)
770                 }
771                 if copyErr != nil {
772                         b := buf.Bytes()
773                         if len(b) > 0 {
774                                 debug(id, "Incomplete upload of %v, %d bytes", oid, len(b))
775                                 largeObjects.SetIncomplete(repo, oid, b)
776                         }
777                         w.WriteHeader(500)
778                 } else {
779                         checkoid := hex.EncodeToString(hash.Sum(nil))
780                         if checkoid != oid {
781                                 log.Fatal(fmt.Sprintf("Incorrect oid after calculation, got %q expected %q", checkoid, oid))
782                                 w.WriteHeader(403)
783                                 return
784                         }
785
786                         b := buf.Bytes()
787                         largeObjects.Set(repo, oid, b)
788                         w.Header().Set("Upload-Offset", strconv.FormatInt(int64(len(b)), 10))
789                         w.WriteHeader(204)
790                 }
791
792         default:
793                 w.WriteHeader(405)
794         }
795 }
796
797 func validateTusHeaders(r *http.Request, id string) bool {
798         if len(r.Header.Get("Tus-Resumable")) == 0 {
799                 debug(id, "Missing Tus-Resumable header in request")
800                 return false
801         }
802         return true
803 }
804
805 func gitHandler(w http.ResponseWriter, r *http.Request) {
806         defer func() {
807                 io.Copy(ioutil.Discard, r.Body)
808                 r.Body.Close()
809         }()
810
811         cmd := exec.Command("git", "http-backend")
812         cmd.Env = []string{
813                 fmt.Sprintf("GIT_PROJECT_ROOT=%s", repoDir),
814                 fmt.Sprintf("GIT_HTTP_EXPORT_ALL="),
815                 fmt.Sprintf("PATH_INFO=%s", r.URL.Path),
816                 fmt.Sprintf("QUERY_STRING=%s", r.URL.RawQuery),
817                 fmt.Sprintf("REQUEST_METHOD=%s", r.Method),
818                 fmt.Sprintf("CONTENT_TYPE=%s", r.Header.Get("Content-Type")),
819         }
820
821         buffer := &bytes.Buffer{}
822         cmd.Stdin = r.Body
823         cmd.Stdout = buffer
824         cmd.Stderr = os.Stderr
825
826         if err := cmd.Run(); err != nil {
827                 log.Fatal(err)
828         }
829
830         text := textproto.NewReader(bufio.NewReader(buffer))
831
832         code, _, _ := text.ReadCodeLine(-1)
833
834         if code != 0 {
835                 w.WriteHeader(code)
836         }
837
838         headers, _ := text.ReadMIMEHeader()
839         head := w.Header()
840         for key, values := range headers {
841                 for _, value := range values {
842                         head.Add(key, value)
843                 }
844         }
845
846         io.Copy(w, text.R)
847 }
848
849 func redirect307Handler(w http.ResponseWriter, r *http.Request) {
850         id, ok := reqId(w)
851         if !ok {
852                 return
853         }
854
855         // Send a redirect to info/lfs
856         // Make it either absolute or relative depending on subpath
857         parts := strings.Split(r.URL.Path, "/")
858         // first element is always blank since rooted
859         var redirectTo string
860         if parts[2] == "rel" {
861                 redirectTo = "/" + strings.Join(parts[3:], "/")
862         } else if parts[2] == "abs" {
863                 redirectTo = server.URL + "/" + strings.Join(parts[3:], "/")
864         } else {
865                 debug(id, "Invalid URL for redirect: %v", r.URL)
866                 w.WriteHeader(404)
867                 return
868         }
869         w.Header().Set("Location", redirectTo)
870         w.WriteHeader(307)
871 }
872
// User is the owner of a lock as reported by the lock API.
type User struct {
	Name string `json:"name"`
}

// Lock is a single file lock, as kept in repoLocks and serialized in lock
// API responses.
type Lock struct {
	Id       string    `json:"id"`
	Path     string    `json:"path"`
	Owner    User      `json:"owner"`
	LockedAt time.Time `json:"locked_at"`
}

// LockRequest is the JSON body of a create-lock request.
type LockRequest struct {
	Path string `json:"path"`
	Ref  *Ref   `json:"ref,omitempty"`
}

// RefName returns the requested ref name, or "" when the request carried no
// ref.
func (r *LockRequest) RefName() string {
	if r.Ref == nil {
		return ""
	}
	return r.Ref.Name
}

// LockResponse answers a create-lock request with either the new Lock or a
// Message explaining why none was returned.
type LockResponse struct {
	Lock    *Lock  `json:"lock"`
	Message string `json:"message,omitempty"`
}

// UnlockRequest is the JSON body of an unlock request.
type UnlockRequest struct {
	Force bool `json:"force"`
	Ref   *Ref `json:"ref,omitempty"`
}

// RefName returns the requested ref name, or "" when the request carried no
// ref.
func (r *UnlockRequest) RefName() string {
	if r.Ref == nil {
		return ""
	}
	return r.Ref.Name
}

// UnlockResponse answers an unlock request with either the removed Lock or a
// Message explaining the failure.
type UnlockResponse struct {
	Lock    *Lock  `json:"lock"`
	Message string `json:"message,omitempty"`
}

// LockList is the response to a lock listing: one page of Locks plus the
// cursor for the next page, or an error Message.
type LockList struct {
	Locks      []Lock `json:"locks"`
	NextCursor string `json:"next_cursor,omitempty"`
	Message    string `json:"message,omitempty"`
}

// Ref names a Git ref in lock API requests.
type Ref struct {
	Name string `json:"name,omitempty"`
}

// VerifiableLockRequest is the JSON body of a lock-verification request.
type VerifiableLockRequest struct {
	Ref    *Ref   `json:"ref,omitempty"`
	Cursor string `json:"cursor,omitempty"`
	Limit  int    `json:"limit,omitempty"`
}

// RefName returns the requested ref name, or "" when the request carried no
// ref.
func (r *VerifiableLockRequest) RefName() string {
	if r.Ref == nil {
		return ""
	}
	return r.Ref.Name
}

// VerifiableLockList is the lock-verification response, with one page of
// locks split into Ours and Theirs.
type VerifiableLockList struct {
	Ours       []Lock `json:"ours"`
	Theirs     []Lock `json:"theirs"`
	NextCursor string `json:"next_cursor,omitempty"`
	Message    string `json:"message,omitempty"`
}

// lmu guards repoLocks, which maps a repo name to its locks. addLocks keeps
// each repo's slice sorted by LockedAt.
var (
	lmu       sync.RWMutex
	repoLocks = map[string][]Lock{}
)

// addLocks appends l to repo's locks and re-sorts them by LockedAt.
func addLocks(repo string, l ...Lock) {
	lmu.Lock()
	defer lmu.Unlock()
	repoLocks[repo] = append(repoLocks[repo], l...)
	sort.Sort(LocksByCreatedAt(repoLocks[repo]))
}

// getLocks returns a copy of repo's locks, so callers may filter or slice
// the result without holding lmu.
func getLocks(repo string) []Lock {
	lmu.RLock()
	defer lmu.RUnlock()

	locks := repoLocks[repo]
	cp := make([]Lock, len(locks))
	copy(cp, locks)
	return cp
}

// getFilteredLocks returns one page of repo's locks and the cursor for the
// next page ("" when there is none).
//
//   - cursor: if non-empty it must be the Id of an existing lock. The page
//     then starts at, and includes, that lock. An unknown cursor is an error.
//   - path: if non-empty, only locks on exactly that path are kept.
//   - limit: if non-empty it must parse as an integer. Its value is not
//     otherwise used: a non-empty limit simply turns on paging, with at most
//     pageSize locks per page.
//
// When a page is cut short, the returned cursor is the Id of the first lock
// NOT in the page, so a follow-up request with that cursor resumes exactly
// where this page ended.
func getFilteredLocks(repo, path, cursor, limit string) ([]Lock, string, error) {
	// Fixed page size, independent of the requested limit (presumably so the
	// integration tests always exercise pagination).
	const pageSize = 3

	locks := getLocks(repo)
	if cursor != "" {
		lastSeen := -1
		for i, l := range locks {
			if l.Id == cursor {
				lastSeen = i
				break
			}
		}

		if lastSeen < 0 {
			return nil, "", fmt.Errorf("cursor (%s) not found", cursor)
		}
		locks = locks[lastSeen:]
	}

	if path != "" {
		var filtered []Lock
		for _, l := range locks {
			if l.Path == path {
				filtered = append(filtered, l)
			}
		}

		locks = filtered
	}

	if limit != "" {
		if _, err := strconv.Atoi(limit); err != nil {
			return nil, "", errors.New("unable to parse limit amount")
		}

		size := int(math.Min(float64(len(locks)), pageSize))
		if size < len(locks) {
			// The next cursor must be locks[size] (the first lock left out),
			// because the cursor lookup above resumes AT the cursor. Using
			// locks[size+1] would silently skip one lock per page.
			return locks[:size], locks[size].Id, nil
		}
	}

	return locks, "", nil
}

// delLock removes the lock with the given id from repo. It returns a pointer
// to a copy of the removed lock, or nil if repo has no lock with that id.
func delLock(repo string, id string) *Lock {
	// This rewrites repoLocks, so it needs the exclusive lock; an RLock here
	// would race with concurrent readers and writers.
	lmu.Lock()
	defer lmu.Unlock()

	var deleted *Lock
	existing := repoLocks[repo]
	locks := make([]Lock, 0, len(existing))
	for i := range existing {
		if existing[i].Id == id {
			// Copy into a variable scoped to this iteration. Taking the
			// address of a range variable aliases one shared variable
			// (before Go 1.22) that later iterations overwrite.
			l := existing[i]
			deleted = &l
			continue
		}
		locks = append(locks, existing[i])
	}
	repoLocks[repo] = locks
	return deleted
}

// LocksByCreatedAt implements sort.Interface, ordering locks by LockedAt
// (earliest first).
type LocksByCreatedAt []Lock

func (c LocksByCreatedAt) Len() int           { return len(c) }
func (c LocksByCreatedAt) Less(i, j int) bool { return c[i].LockedAt.Before(c[j].LockedAt) }
func (c LocksByCreatedAt) Swap(i, j int)      { c[i], c[j] = c[j], c[i] }
1043
1044 var (
1045         lockRe   = regexp.MustCompile(`/locks/?$`)
1046         unlockRe = regexp.MustCompile(`locks/([^/]+)/unlock\z`)
1047 )
1048
1049 func locksHandler(w http.ResponseWriter, r *http.Request, repo string) {
1050         dec := json.NewDecoder(r.Body)
1051         enc := json.NewEncoder(w)
1052
1053         switch r.Method {
1054         case "GET":
1055                 if !lockRe.MatchString(r.URL.Path) {
1056                         w.Header().Set("Content-Type", "application/json")
1057                         w.WriteHeader(http.StatusNotFound)
1058                         w.Write([]byte(`{"message":"unknown path: ` + r.URL.Path + `"}`))
1059                         return
1060                 }
1061
1062                 if err := r.ParseForm(); err != nil {
1063                         http.Error(w, "could not parse form values", http.StatusInternalServerError)
1064                         return
1065                 }
1066
1067                 if strings.HasSuffix(repo, "branch-required") {
1068                         parts := strings.Split(repo, "-")
1069                         lenParts := len(parts)
1070                         if lenParts > 3 && "refs/heads/"+parts[lenParts-3] != r.FormValue("refspec") {
1071                                 w.WriteHeader(403)
1072                                 enc.Encode(struct {
1073                                         Message string `json:"message"`
1074                                 }{fmt.Sprintf("Expected ref %q, got %q", "refs/heads/"+parts[lenParts-3], r.FormValue("refspec"))})
1075                                 return
1076                         }
1077                 }
1078
1079                 ll := &LockList{}
1080                 w.Header().Set("Content-Type", "application/json")
1081                 locks, nextCursor, err := getFilteredLocks(repo,
1082                         r.FormValue("path"),
1083                         r.FormValue("cursor"),
1084                         r.FormValue("limit"))
1085
1086                 if err != nil {
1087                         ll.Message = err.Error()
1088                 } else {
1089                         ll.Locks = locks
1090                         ll.NextCursor = nextCursor
1091                 }
1092
1093                 enc.Encode(ll)
1094                 return
1095         case "POST":
1096                 w.Header().Set("Content-Type", "application/json")
1097                 if strings.HasSuffix(r.URL.Path, "unlock") {
1098                         var lockId string
1099                         if matches := unlockRe.FindStringSubmatch(r.URL.Path); len(matches) > 1 {
1100                                 lockId = matches[1]
1101                         }
1102
1103                         if len(lockId) == 0 {
1104                                 enc.Encode(&UnlockResponse{Message: "Invalid lock"})
1105                         }
1106
1107                         unlockRequest := &UnlockRequest{}
1108                         if err := dec.Decode(unlockRequest); err != nil {
1109                                 enc.Encode(&UnlockResponse{Message: err.Error()})
1110                                 return
1111                         }
1112
1113                         if strings.HasSuffix(repo, "branch-required") {
1114                                 parts := strings.Split(repo, "-")
1115                                 lenParts := len(parts)
1116                                 if lenParts > 3 && "refs/heads/"+parts[lenParts-3] != unlockRequest.RefName() {
1117                                         w.WriteHeader(403)
1118                                         enc.Encode(struct {
1119                                                 Message string `json:"message"`
1120                                         }{fmt.Sprintf("Expected ref %q, got %q", "refs/heads/"+parts[lenParts-3], unlockRequest.RefName())})
1121                                         return
1122                                 }
1123                         }
1124
1125                         if l := delLock(repo, lockId); l != nil {
1126                                 enc.Encode(&UnlockResponse{Lock: l})
1127                         } else {
1128                                 enc.Encode(&UnlockResponse{Message: "unable to find lock"})
1129                         }
1130                         return
1131                 }
1132
1133                 if strings.HasSuffix(r.URL.Path, "/locks/verify") {
1134                         if strings.HasSuffix(repo, "verify-5xx") {
1135                                 w.WriteHeader(500)
1136                                 return
1137                         }
1138                         if strings.HasSuffix(repo, "verify-501") {
1139                                 w.WriteHeader(501)
1140                                 return
1141                         }
1142                         if strings.HasSuffix(repo, "verify-403") {
1143                                 w.WriteHeader(403)
1144                                 return
1145                         }
1146
1147                         switch repo {
1148                         case "pre_push_locks_verify_404":
1149                                 w.WriteHeader(http.StatusNotFound)
1150                                 w.Write([]byte(`{"message":"pre_push_locks_verify_404"}`))
1151                                 return
1152                         case "pre_push_locks_verify_410":
1153                                 w.WriteHeader(http.StatusGone)
1154                                 w.Write([]byte(`{"message":"pre_push_locks_verify_410"}`))
1155                                 return
1156                         }
1157
1158                         reqBody := &VerifiableLockRequest{}
1159                         if err := dec.Decode(reqBody); err != nil {
1160                                 w.WriteHeader(http.StatusBadRequest)
1161                                 enc.Encode(struct {
1162                                         Message string `json:"message"`
1163                                 }{"json decode error: " + err.Error()})
1164                                 return
1165                         }
1166
1167                         if strings.HasSuffix(repo, "branch-required") {
1168                                 parts := strings.Split(repo, "-")
1169                                 lenParts := len(parts)
1170                                 if lenParts > 3 && "refs/heads/"+parts[lenParts-3] != reqBody.RefName() {
1171                                         w.WriteHeader(403)
1172                                         enc.Encode(struct {
1173                                                 Message string `json:"message"`
1174                                         }{fmt.Sprintf("Expected ref %q, got %q", "refs/heads/"+parts[lenParts-3], reqBody.RefName())})
1175                                         return
1176                                 }
1177                         }
1178
1179                         ll := &VerifiableLockList{}
1180                         locks, nextCursor, err := getFilteredLocks(repo, "",
1181                                 reqBody.Cursor,
1182                                 strconv.Itoa(reqBody.Limit))
1183                         if err != nil {
1184                                 ll.Message = err.Error()
1185                         } else {
1186                                 ll.NextCursor = nextCursor
1187
1188                                 for _, l := range locks {
1189                                         if strings.Contains(l.Path, "theirs") {
1190                                                 ll.Theirs = append(ll.Theirs, l)
1191                                         } else {
1192                                                 ll.Ours = append(ll.Ours, l)
1193                                         }
1194                                 }
1195                         }
1196
1197                         enc.Encode(ll)
1198                         return
1199                 }
1200
1201                 if strings.HasSuffix(r.URL.Path, "/locks") {
1202                         lockRequest := &LockRequest{}
1203                         if err := dec.Decode(lockRequest); err != nil {
1204                                 enc.Encode(&LockResponse{Message: err.Error()})
1205                         }
1206
1207                         if strings.HasSuffix(repo, "branch-required") {
1208                                 parts := strings.Split(repo, "-")
1209                                 lenParts := len(parts)
1210                                 if lenParts > 3 && "refs/heads/"+parts[lenParts-3] != lockRequest.RefName() {
1211                                         w.WriteHeader(403)
1212                                         enc.Encode(struct {
1213                                                 Message string `json:"message"`
1214                                         }{fmt.Sprintf("Expected ref %q, got %q", "refs/heads/"+parts[lenParts-3], lockRequest.RefName())})
1215                                         return
1216                                 }
1217                         }
1218
1219                         for _, l := range getLocks(repo) {
1220                                 if l.Path == lockRequest.Path {
1221                                         enc.Encode(&LockResponse{Message: "lock already created"})
1222                                         return
1223                                 }
1224                         }
1225
1226                         var id [20]byte
1227                         rand.Read(id[:])
1228
1229                         lock := &Lock{
1230                                 Id:       fmt.Sprintf("%x", id[:]),
1231                                 Path:     lockRequest.Path,
1232                                 Owner:    User{Name: "Git LFS Tests"},
1233                                 LockedAt: time.Now(),
1234                         }
1235
1236                         addLocks(repo, *lock)
1237
1238                         // TODO(taylor): commit_needed case
1239                         // TODO(taylor): err case
1240
1241                         enc.Encode(&LockResponse{
1242                                 Lock: lock,
1243                         })
1244                         return
1245                 }
1246         }
1247
1248         http.NotFound(w, r)
1249 }
1250
1251 func missingRequiredCreds(w http.ResponseWriter, r *http.Request, repo string) bool {
1252         if !strings.HasPrefix(repo, "requirecreds") {
1253                 return false
1254         }
1255
1256         auth := r.Header.Get("Authorization")
1257         user, pass, err := extractAuth(auth)
1258         if err != nil {
1259                 writeLFSError(w, 403, err.Error())
1260                 return true
1261         }
1262
1263         if user != "requirecreds" || pass != "pass" {
1264                 writeLFSError(w, 403, fmt.Sprintf("Got: '%s' => '%s' : '%s'", auth, user, pass))
1265                 return true
1266         }
1267
1268         return false
1269 }
1270
// requestURLHasPrefix reports whether the request URL, rendered with
// URL.String, starts with prefix.
func requestURLHasPrefix(r *http.Request, prefix string) bool {
	return strings.HasPrefix(r.URL.String(), prefix)
}

// testingChunkedTransferEncoding reports whether r targets the
// chunked-transfer-encoding test endpoints.
func testingChunkedTransferEncoding(r *http.Request) bool {
	return requestURLHasPrefix(r, "/test-chunked-transfer-encoding")
}

// testingTusUploadInBatchReq reports whether r targets the tus upload test
// endpoints. The interrupted variant's URLs match this prefix too.
func testingTusUploadInBatchReq(r *http.Request) bool {
	return requestURLHasPrefix(r, "/test-tus-upload")
}

// testingTusUploadInterruptedInBatchReq reports whether r targets the
// interrupted tus upload test endpoints.
func testingTusUploadInterruptedInBatchReq(r *http.Request) bool {
	return requestURLHasPrefix(r, "/test-tus-upload-interrupt")
}

// testingCustomTransfer reports whether r targets the custom-transfer test
// endpoints.
func testingCustomTransfer(r *http.Request) bool {
	return requestURLHasPrefix(r, "/test-custom-transfer")
}
1284
// lfsUrlRE captures the repo name (the first path segment) of an LFS API URL.
var lfsUrlRE = regexp.MustCompile(`\A/?([^/]+)/info/lfs`)

// repoFromLfsUrl extracts the repo name from an LFS API URL path such as
// "/name.git/info/lfs/objects". One trailing ".git" is removed. It returns
// an error if the path does not match lfsUrlRE.
func repoFromLfsUrl(urlpath string) (string, error) {
	groups := lfsUrlRE.FindStringSubmatch(urlpath)
	if len(groups) == 2 {
		return strings.TrimSuffix(groups[1], ".git"), nil
	}
	return "", fmt.Errorf("LFS url '%s' does not match %v", urlpath, lfsUrlRE)
}
1299
// lfsStorage is an in-memory object store keyed by repo and then by OID. It
// keeps complete objects and partial ("incomplete") uploads in separate
// tables. Every method holds mutex for its whole duration.
type lfsStorage struct {
	objects    map[string]map[string][]byte
	incomplete map[string]map[string][]byte
	mutex      *sync.Mutex
}

// Get returns the stored object for repo/oid and whether it exists.
func (s *lfsStorage) Get(repo, oid string) ([]byte, bool) {
	s.mutex.Lock()
	defer s.mutex.Unlock()
	// A missing repo yields a nil inner map, and a lookup in a nil map reports
	// "not found", so one expression covers both levels.
	data, found := s.objects[repo][oid]
	return data, found
}

// Has reports whether an object is stored for repo/oid.
func (s *lfsStorage) Has(repo, oid string) bool {
	s.mutex.Lock()
	defer s.mutex.Unlock()
	_, found := s.objects[repo][oid]
	return found
}

// Set stores by as the object for repo/oid, creating the repo's table on
// first use.
func (s *lfsStorage) Set(repo, oid string, by []byte) {
	s.mutex.Lock()
	defer s.mutex.Unlock()
	table, exists := s.objects[repo]
	if !exists {
		table = make(map[string][]byte)
		s.objects[repo] = table
	}
	table[oid] = by
}

// Delete removes the object for repo/oid. A missing entry is a no-op.
func (s *lfsStorage) Delete(repo, oid string) {
	s.mutex.Lock()
	defer s.mutex.Unlock()
	// delete on a nil map (unknown repo) is a no-op.
	delete(s.objects[repo], oid)
}

// GetIncomplete returns the partial upload for repo/oid and whether it
// exists.
func (s *lfsStorage) GetIncomplete(repo, oid string) ([]byte, bool) {
	s.mutex.Lock()
	defer s.mutex.Unlock()
	data, found := s.incomplete[repo][oid]
	return data, found
}

// SetIncomplete stores by as the partial upload for repo/oid, creating the
// repo's table on first use.
func (s *lfsStorage) SetIncomplete(repo, oid string, by []byte) {
	s.mutex.Lock()
	defer s.mutex.Unlock()
	table, exists := s.incomplete[repo]
	if !exists {
		table = make(map[string][]byte)
		s.incomplete[repo] = table
	}
	table[oid] = by
}

// DeleteIncomplete removes the partial upload for repo/oid. A missing entry
// is a no-op.
func (s *lfsStorage) DeleteIncomplete(repo, oid string) {
	s.mutex.Lock()
	defer s.mutex.Unlock()
	delete(s.incomplete[repo], oid)
}

// newLfsStorage returns an empty, ready-to-use lfsStorage.
func newLfsStorage() *lfsStorage {
	s := &lfsStorage{mutex: new(sync.Mutex)}
	s.objects = make(map[string]map[string][]byte)
	s.incomplete = make(map[string]map[string][]byte)
	return s
}
1389
// extractAuth decodes an HTTP Basic Authorization header value into a
// username and password, splitting at the first ':'.
//
// A header without the "Basic " prefix, or a decoded payload without ':',
// yields two empty strings and a nil error. Invalid base64 yields the
// decoding error.
func extractAuth(auth string) (string, string, error) {
	const scheme = "Basic "
	if !strings.HasPrefix(auth, scheme) {
		return "", "", nil
	}

	raw, err := base64.StdEncoding.DecodeString(auth[len(scheme):])
	if err != nil {
		return "", "", err
	}

	if creds := strings.SplitN(string(raw), ":", 2); len(creds) == 2 {
		return creds[0], creds[1], nil
	}
	return "", "", nil
}
1408
1409 func skipIfBadAuth(w http.ResponseWriter, r *http.Request, id string, ntlmSession ntlm.ServerSession) bool {
1410         auth := r.Header.Get("Authorization")
1411         if strings.Contains(r.URL.Path, "ntlm") {
1412                 return false
1413         }
1414
1415         if auth == "" {
1416                 w.WriteHeader(401)
1417                 return true
1418         }
1419
1420         user, pass, err := extractAuth(auth)
1421         if err != nil {
1422                 w.WriteHeader(403)
1423                 debug(id, "Error decoding auth: %s", err)
1424                 return true
1425         }
1426
1427         switch user {
1428         case "user":
1429                 if pass == "pass" {
1430                         return false
1431                 }
1432         case "netrcuser", "requirecreds":
1433                 return false
1434         case "path":
1435                 if strings.HasPrefix(r.URL.Path, "/"+pass) {
1436                         return false
1437                 }
1438                 debug(id, "auth attempt against: %q", r.URL.Path)
1439         }
1440
1441         w.WriteHeader(403)
1442         debug(id, "Bad auth: %q", auth)
1443         return true
1444 }
1445
1446 func handleNTLM(w http.ResponseWriter, r *http.Request, authHeader string, session ntlm.ServerSession) {
1447         if strings.HasPrefix(strings.ToUpper(authHeader), "BASIC ") {
1448                 authHeader = ""
1449         }
1450
1451         switch authHeader {
1452         case "":
1453                 w.Header().Set("Www-Authenticate", "ntlm")
1454                 w.WriteHeader(401)
1455
1456         // ntlmNegotiateMessage from httputil pkg
1457         case "NTLM TlRMTVNTUAABAAAAB7IIogwADAAzAAAACwALACgAAAAKAAAoAAAAD1dJTExISS1NQUlOTk9SVEhBTUVSSUNB":
1458                 ch, err := session.GenerateChallengeMessage()
1459                 if err != nil {
1460                         writeLFSError(w, 500, err.Error())
1461                         return
1462                 }
1463
1464                 chMsg := base64.StdEncoding.EncodeToString(ch.Bytes())
1465                 w.Header().Set("Www-Authenticate", "ntlm "+chMsg)
1466                 w.WriteHeader(401)
1467
1468         default:
1469                 if !strings.HasPrefix(strings.ToUpper(authHeader), "NTLM ") {
1470                         writeLFSError(w, 500, "bad authorization header: "+authHeader)
1471                         return
1472                 }
1473
1474                 auth := authHeader[5:] // strip "ntlm " prefix
1475                 val, err := base64.StdEncoding.DecodeString(auth)
1476                 if err != nil {
1477                         writeLFSError(w, 500, "base64 decode error: "+err.Error())
1478                         return
1479                 }
1480
1481                 _, err = ntlm.ParseAuthenticateMessage(val, 2)
1482                 if err != nil {
1483                         writeLFSError(w, 500, "auth parse error: "+err.Error())
1484                         return
1485                 }
1486         }
1487 }
1488
1489 func init() {
1490         oidHandlers = make(map[string]string)
1491         for _, content := range contentHandlers {
1492                 h := sha256.New()
1493                 h.Write([]byte(content))
1494                 oidHandlers[hex.EncodeToString(h.Sum(nil))] = content
1495         }
1496 }
1497
// debug logs msg through the standard logger, prefixed with "[reqid] " and
// terminated by a newline. msg is a format string consumed with args.
func debug(reqid, msg string, args ...interface{}) {
	formatArgs := append([]interface{}{reqid}, args...)
	log.Printf("[%s] "+msg+"\n", formatArgs...)
}
1506
// reqId returns a random request id made of 16 random bytes. It is printed as
// lowercase hex in 8-4-4-4-12 groups, laid out like a UUID string, but no
// version or variant bits are set.
//
// If the random source fails, it writes a 500 to w and returns ("", false).
func reqId(w http.ResponseWriter) (string, bool) {
	var raw [16]byte
	if _, err := rand.Read(raw[:]); err != nil {
		http.Error(w, "error generating id: "+err.Error(), 500)
		return "", false
	}
	id := fmt.Sprintf("%x-%x-%x-%x-%x", raw[:4], raw[4:6], raw[6:8], raw[8:10], raw[10:])
	return id, true
}
1516
1517 // https://ericchiang.github.io/post/go-tls/
1518 func generateCARootCertificates() (rootKey *rsa.PrivateKey, rootCert *x509.Certificate) {
1519
1520         // generate a new key-pair
1521         rootKey, err := rsa.GenerateKey(rand.Reader, 2048)
1522         if err != nil {
1523                 log.Fatalf("generating random key: %v", err)
1524         }
1525
1526         rootCertTmpl, err := CertTemplate()
1527         if err != nil {
1528                 log.Fatalf("creating cert template: %v", err)
1529         }
1530         // describe what the certificate will be used for
1531         rootCertTmpl.IsCA = true
1532         rootCertTmpl.KeyUsage = x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature
1533         rootCertTmpl.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}
1534         //      rootCertTmpl.IPAddresses = []net.IP{net.ParseIP("127.0.0.1")}
1535
1536         rootCert, _, err = CreateCert(rootCertTmpl, rootCertTmpl, &rootKey.PublicKey, rootKey)
1537
1538         return
1539 }
1540
1541 func generateClientCertificates(rootCert *x509.Certificate, rootKey interface{}) (clientKey *rsa.PrivateKey, clientCertPEM []byte, clientKeyPEM []byte) {
1542
1543         // create a key-pair for the client
1544         clientKey, err := rsa.GenerateKey(rand.Reader, 2048)
1545         if err != nil {
1546                 log.Fatalf("generating random key: %v", err)
1547         }
1548
1549         // create a template for the client
1550         clientCertTmpl, err1 := CertTemplate()
1551         if err1 != nil {
1552                 log.Fatalf("creating cert template: %v", err1)
1553         }
1554         clientCertTmpl.KeyUsage = x509.KeyUsageDigitalSignature
1555         clientCertTmpl.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}
1556
1557         // the root cert signs the cert by again providing its private key
1558         _, clientCertPEM, err2 := CreateCert(clientCertTmpl, rootCert, &clientKey.PublicKey, rootKey)
1559         if err2 != nil {
1560                 log.Fatalf("error creating cert: %v", err2)
1561         }
1562
1563         // encode and load the cert and private key for the client
1564         clientKeyPEM = pem.EncodeToMemory(&pem.Block{
1565                 Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(clientKey),
1566         })
1567
1568         return
1569 }
1570
// CertTemplate returns a fresh certificate template with the following:
//   - a random serial number below 2^128
//   - the subject organization "Yhat, Inc."
//   - SHA256WithRSA signatures
//   - validity from now until one hour from now
//   - BasicConstraintsValid set
//
// Callers fill in key usages and the IsCA flag. It fails only if the random
// source does.
func CertTemplate() (*x509.Certificate, error) {
	// Random serial in [0, 2^128); a real CA would have a policy for this.
	upper := big.NewInt(1)
	upper.Lsh(upper, 128)
	serial, err := rand.Int(rand.Reader, upper)
	if err != nil {
		return nil, errors.New("failed to generate serial number: " + err.Error())
	}

	return &x509.Certificate{
		SerialNumber:          serial,
		Subject:               pkix.Name{Organization: []string{"Yhat, Inc."}},
		SignatureAlgorithm:    x509.SHA256WithRSA,
		NotBefore:             time.Now(),
		NotAfter:              time.Now().Add(time.Hour), // one hour of validity
		BasicConstraintsValid: true,
	}, nil
}
1590
1591 func CreateCert(template, parent *x509.Certificate, pub interface{}, parentPriv interface{}) (
1592         cert *x509.Certificate, certPEM []byte, err error) {
1593
1594         certDER, err := x509.CreateCertificate(rand.Reader, template, parent, pub, parentPriv)
1595         if err != nil {
1596                 return
1597         }
1598         // parse the resulting certificate so we can use it again
1599         cert, err = x509.ParseCertificate(certDER)
1600         if err != nil {
1601                 return
1602         }
1603         // PEM encode the certificate (this is a standard TLS encoding)
1604         b := pem.Block{Type: "CERTIFICATE", Bytes: certDER}
1605         certPEM = pem.EncodeToMemory(&b)
1606         return
1607 }