Imported Upstream version 2.3.4
[scm/test.git] / lfs / pointer_smudge.go
1 package lfs
2
3 import (
4         "fmt"
5         "io"
6         "os"
7         "path/filepath"
8
9         "github.com/git-lfs/git-lfs/tools"
10         "github.com/git-lfs/git-lfs/tools/humanize"
11         "github.com/git-lfs/git-lfs/tq"
12
13         "github.com/git-lfs/git-lfs/config"
14         "github.com/git-lfs/git-lfs/errors"
15         "github.com/git-lfs/git-lfs/progress"
16         "github.com/rubyist/tracerx"
17 )
18
19 func PointerSmudgeToFile(filename string, ptr *Pointer, download bool, manifest *tq.Manifest, cb progress.CopyCallback) error {
20         os.MkdirAll(filepath.Dir(filename), 0755)
21         file, err := os.Create(filename)
22         if err != nil {
23                 return fmt.Errorf("Could not create working directory file: %v", err)
24         }
25         defer file.Close()
26         if _, err := PointerSmudge(file, ptr, filename, download, manifest, cb); err != nil {
27                 if errors.IsDownloadDeclinedError(err) {
28                         // write placeholder data instead
29                         file.Seek(0, os.SEEK_SET)
30                         ptr.Encode(file)
31                         return err
32                 } else {
33                         return fmt.Errorf("Could not write working directory file: %v", err)
34                 }
35         }
36         return nil
37 }
38
39 func PointerSmudge(writer io.Writer, ptr *Pointer, workingfile string, download bool, manifest *tq.Manifest, cb progress.CopyCallback) (int64, error) {
40         mediafile, err := LocalMediaPath(ptr.Oid)
41         if err != nil {
42                 return 0, err
43         }
44
45         LinkOrCopyFromReference(ptr.Oid, ptr.Size)
46
47         stat, statErr := os.Stat(mediafile)
48         if statErr == nil && stat != nil {
49                 fileSize := stat.Size()
50                 if fileSize == 0 || fileSize != ptr.Size {
51                         tracerx.Printf("Removing %s, size %d is invalid", mediafile, fileSize)
52                         os.RemoveAll(mediafile)
53                         stat = nil
54                 }
55         }
56
57         var n int64
58
59         if statErr != nil || stat == nil {
60                 if download {
61                         n, err = downloadFile(writer, ptr, workingfile, mediafile, manifest, cb)
62                 } else {
63                         return 0, errors.NewDownloadDeclinedError(statErr, "smudge")
64                 }
65         } else {
66                 n, err = readLocalFile(writer, ptr, mediafile, workingfile, cb)
67         }
68
69         if err != nil {
70                 return 0, errors.NewSmudgeError(err, ptr.Oid, mediafile)
71         }
72
73         return n, nil
74 }
75
76 func downloadFile(writer io.Writer, ptr *Pointer, workingfile, mediafile string, manifest *tq.Manifest, cb progress.CopyCallback) (int64, error) {
77         fmt.Fprintf(os.Stderr, "Downloading %s (%s)\n", workingfile, humanize.FormatBytes(uint64(ptr.Size)))
78
79         // NOTE: if given, "cb" is a progress.CopyCallback which writes updates
80         // to the logpath specified by GIT_LFS_PROGRESS.
81         //
82         // Either way, forward it into the *tq.TransferQueue so that updates are
83         // sent over correctly.
84         q := tq.NewTransferQueue(tq.Download, manifest, "", tq.WithProgressCallback(cb))
85         q.Add(filepath.Base(workingfile), mediafile, ptr.Oid, ptr.Size)
86         q.Wait()
87
88         if errs := q.Errors(); len(errs) > 0 {
89                 var multiErr error
90                 for _, e := range errs {
91                         if multiErr != nil {
92                                 multiErr = fmt.Errorf("%v\n%v", multiErr, e)
93                         } else {
94                                 multiErr = e
95                         }
96                         return 0, errors.Wrapf(multiErr, "Error downloading %s (%s)", workingfile, ptr.Oid)
97                 }
98         }
99
100         return readLocalFile(writer, ptr, mediafile, workingfile, nil)
101 }
102
103 func readLocalFile(writer io.Writer, ptr *Pointer, mediafile string, workingfile string, cb progress.CopyCallback) (int64, error) {
104         reader, err := os.Open(mediafile)
105         if err != nil {
106                 return 0, errors.Wrapf(err, "Error opening media file.")
107         }
108         defer reader.Close()
109
110         if ptr.Size == 0 {
111                 if stat, _ := os.Stat(mediafile); stat != nil {
112                         ptr.Size = stat.Size()
113                 }
114         }
115
116         if len(ptr.Extensions) > 0 {
117                 registeredExts := config.Config.Extensions()
118                 extensions := make(map[string]config.Extension)
119                 for _, ptrExt := range ptr.Extensions {
120                         ext, ok := registeredExts[ptrExt.Name]
121                         if !ok {
122                                 err := fmt.Errorf("Extension '%s' is not configured.", ptrExt.Name)
123                                 return 0, errors.Wrap(err, "smudge")
124                         }
125                         ext.Priority = ptrExt.Priority
126                         extensions[ext.Name] = ext
127                 }
128                 exts, err := config.SortExtensions(extensions)
129                 if err != nil {
130                         return 0, errors.Wrap(err, "smudge")
131                 }
132
133                 // pipe extensions in reverse order
134                 var extsR []config.Extension
135                 for i := range exts {
136                         ext := exts[len(exts)-1-i]
137                         extsR = append(extsR, ext)
138                 }
139
140                 request := &pipeRequest{"smudge", reader, workingfile, extsR}
141
142                 response, err := pipeExtensions(request)
143                 if err != nil {
144                         return 0, errors.Wrap(err, "smudge")
145                 }
146
147                 actualExts := make(map[string]*pipeExtResult)
148                 for _, result := range response.results {
149                         actualExts[result.name] = result
150                 }
151
152                 // verify name, order, and oids
153                 oid := response.results[0].oidIn
154                 if ptr.Oid != oid {
155                         err = fmt.Errorf("Actual oid %s during smudge does not match expected %s", oid, ptr.Oid)
156                         return 0, errors.Wrap(err, "smudge")
157                 }
158
159                 for _, expected := range ptr.Extensions {
160                         actual := actualExts[expected.Name]
161                         if actual.name != expected.Name {
162                                 err = fmt.Errorf("Actual extension name '%s' does not match expected '%s'", actual.name, expected.Name)
163                                 return 0, errors.Wrap(err, "smudge")
164                         }
165                         if actual.oidOut != expected.Oid {
166                                 err = fmt.Errorf("Actual oid %s for extension '%s' does not match expected %s", actual.oidOut, expected.Name, expected.Oid)
167                                 return 0, errors.Wrap(err, "smudge")
168                         }
169                 }
170
171                 // setup reader
172                 reader, err = os.Open(response.file.Name())
173                 if err != nil {
174                         return 0, errors.Wrapf(err, "Error opening smudged file: %s", err)
175                 }
176                 defer reader.Close()
177         }
178
179         n, err := tools.CopyWithCallback(writer, reader, ptr.Size, cb)
180         if err != nil {
181                 return n, errors.Wrapf(err, "Error reading from media file: %s", err)
182         }
183
184         return n, nil
185 }