Imported Upstream version 1.0.0
[platform/upstream/nghttp2.git] / integration-tests / server_tester.go
1 package nghttp2
2
3 import (
4         "bufio"
5         "bytes"
6         "crypto/tls"
7         "errors"
8         "fmt"
9         "github.com/bradfitz/http2"
10         "github.com/bradfitz/http2/hpack"
11         "github.com/tatsuhiro-t/go-nghttp2"
12         "golang.org/x/net/spdy"
13         "io"
14         "io/ioutil"
15         "net"
16         "net/http"
17         "net/http/httptest"
18         "net/url"
19         "os/exec"
20         "sort"
21         "strconv"
22         "strings"
23         "testing"
24         "time"
25 )
26
27 const (
28         serverBin  = buildDir + "/src/nghttpx"
29         serverPort = 3009
30         testDir    = buildDir + "/integration-tests"
31 )
32
33 func pair(name, value string) hpack.HeaderField {
34         return hpack.HeaderField{
35                 Name:  name,
36                 Value: value,
37         }
38 }
39
40 type serverTester struct {
41         args          []string  // command-line arguments
42         cmd           *exec.Cmd // test frontend server process, which is test subject
43         url           string    // test frontend server URL
44         t             *testing.T
45         ts            *httptest.Server // backend server
46         frontendHost  string           // frontend server host
47         backendHost   string           // backend server host
48         conn          net.Conn         // connection to frontend server
49         h2PrefaceSent bool             // HTTP/2 preface was sent in conn
50         nextStreamID  uint32           // next stream ID
51         fr            *http2.Framer    // HTTP/2 framer
52         spdyFr        *spdy.Framer     // SPDY/3.1 framer
53         headerBlkBuf  bytes.Buffer     // buffer to store encoded header block
54         enc           *hpack.Encoder   // HTTP/2 HPACK encoder
55         header        http.Header      // received header fields
56         dec           *hpack.Decoder   // HTTP/2 HPACK decoder
57         authority     string           // server's host:port
58         frCh          chan http2.Frame // used for incoming HTTP/2 frame
59         spdyFrCh      chan spdy.Frame  // used for incoming SPDY frame
60         errCh         chan error
61 }
62
63 // newServerTester creates test context for plain TCP frontend
64 // connection.
65 func newServerTester(args []string, t *testing.T, handler http.HandlerFunc) *serverTester {
66         return newServerTesterInternal(args, t, handler, false, nil)
67 }
68
69 // newServerTester creates test context for TLS frontend connection.
70 func newServerTesterTLS(args []string, t *testing.T, handler http.HandlerFunc) *serverTester {
71         return newServerTesterInternal(args, t, handler, true, nil)
72 }
73
74 // newServerTester creates test context for TLS frontend connection
75 // with given clientConfig
76 func newServerTesterTLSConfig(args []string, t *testing.T, handler http.HandlerFunc, clientConfig *tls.Config) *serverTester {
77         return newServerTesterInternal(args, t, handler, true, clientConfig)
78 }
79
80 // newServerTesterInternal creates test context.  If frontendTLS is
81 // true, set up TLS frontend connection.
82 func newServerTesterInternal(args []string, t *testing.T, handler http.HandlerFunc, frontendTLS bool, clientConfig *tls.Config) *serverTester {
83         ts := httptest.NewUnstartedServer(handler)
84
85         backendTLS := false
86         for _, k := range args {
87                 switch k {
88                 case "--http2-bridge":
89                         backendTLS = true
90                 }
91         }
92         if backendTLS {
93                 nghttp2.ConfigureServer(ts.Config, &nghttp2.Server{})
94                 // According to httptest/server.go, we have to set
95                 // NextProtos separately for ts.TLS.  NextProtos set
96                 // in nghttp2.ConfigureServer is effectively ignored.
97                 ts.TLS = new(tls.Config)
98                 ts.TLS.NextProtos = append(ts.TLS.NextProtos, "h2-14")
99                 ts.StartTLS()
100                 args = append(args, "-k")
101         } else {
102                 ts.Start()
103         }
104         scheme := "http"
105         if frontendTLS {
106                 scheme = "https"
107                 args = append(args, testDir+"/server.key", testDir+"/server.crt")
108         } else {
109                 args = append(args, "--frontend-no-tls")
110         }
111
112         backendURL, err := url.Parse(ts.URL)
113         if err != nil {
114                 t.Fatalf("Error parsing URL from httptest.Server: %v", err)
115         }
116
117         // URL.Host looks like "127.0.0.1:8080", but we want
118         // "127.0.0.1,8080"
119         b := "-b" + strings.Replace(backendURL.Host, ":", ",", -1)
120         args = append(args, fmt.Sprintf("-f127.0.0.1,%v", serverPort), b,
121                 "--errorlog-file="+testDir+"/log.txt", "-LINFO")
122
123         authority := fmt.Sprintf("127.0.0.1:%v", serverPort)
124
125         st := &serverTester{
126                 cmd:          exec.Command(serverBin, args...),
127                 t:            t,
128                 ts:           ts,
129                 url:          fmt.Sprintf("%v://%v", scheme, authority),
130                 frontendHost: fmt.Sprintf("127.0.0.1:%v", serverPort),
131                 backendHost:  backendURL.Host,
132                 nextStreamID: 1,
133                 authority:    authority,
134                 frCh:         make(chan http2.Frame),
135                 spdyFrCh:     make(chan spdy.Frame),
136                 errCh:        make(chan error),
137         }
138
139         if err := st.cmd.Start(); err != nil {
140                 st.t.Fatalf("Error starting %v: %v", serverBin, err)
141         }
142
143         retry := 0
144         for {
145                 var conn net.Conn
146                 var err error
147                 if frontendTLS {
148                         var tlsConfig *tls.Config
149                         if clientConfig == nil {
150                                 tlsConfig = new(tls.Config)
151                         } else {
152                                 tlsConfig = clientConfig
153                         }
154                         tlsConfig.InsecureSkipVerify = true
155                         tlsConfig.NextProtos = []string{"h2-14", "spdy/3.1"}
156                         conn, err = tls.Dial("tcp", authority, tlsConfig)
157                 } else {
158                         conn, err = net.Dial("tcp", authority)
159                 }
160                 if err != nil {
161                         retry += 1
162                         if retry >= 100 {
163                                 st.Close()
164                                 st.t.Fatalf("Error server is not responding too long; server command-line arguments may be invalid")
165                         }
166                         time.Sleep(150 * time.Millisecond)
167                         continue
168                 }
169                 if frontendTLS {
170                         tlsConn := conn.(*tls.Conn)
171                         cs := tlsConn.ConnectionState()
172                         if !cs.NegotiatedProtocolIsMutual {
173                                 st.Close()
174                                 st.t.Fatalf("Error negotiated next protocol is not mutual")
175                         }
176                 }
177                 st.conn = conn
178                 break
179         }
180
181         st.fr = http2.NewFramer(st.conn, st.conn)
182         spdyFr, err := spdy.NewFramer(st.conn, st.conn)
183         if err != nil {
184                 st.Close()
185                 st.t.Fatalf("Error spdy.NewFramer: %v", err)
186         }
187         st.spdyFr = spdyFr
188         st.enc = hpack.NewEncoder(&st.headerBlkBuf)
189         st.dec = hpack.NewDecoder(4096, func(f hpack.HeaderField) {
190                 st.header.Add(f.Name, f.Value)
191         })
192
193         return st
194 }
195
196 func (st *serverTester) Close() {
197         if st.conn != nil {
198                 st.conn.Close()
199         }
200         if st.cmd != nil {
201                 st.cmd.Process.Kill()
202                 st.cmd.Wait()
203         }
204         if st.ts != nil {
205                 st.ts.Close()
206         }
207 }
208
209 func (st *serverTester) readFrame() (http2.Frame, error) {
210         go func() {
211                 f, err := st.fr.ReadFrame()
212                 if err != nil {
213                         st.errCh <- err
214                         return
215                 }
216                 st.frCh <- f
217         }()
218
219         select {
220         case f := <-st.frCh:
221                 return f, nil
222         case err := <-st.errCh:
223                 return nil, err
224         case <-time.After(5 * time.Second):
225                 return nil, errors.New("timeout waiting for frame")
226         }
227 }
228
229 func (st *serverTester) readSpdyFrame() (spdy.Frame, error) {
230         go func() {
231                 f, err := st.spdyFr.ReadFrame()
232                 if err != nil {
233                         st.errCh <- err
234                         return
235                 }
236                 st.spdyFrCh <- f
237         }()
238
239         select {
240         case f := <-st.spdyFrCh:
241                 return f, nil
242         case err := <-st.errCh:
243                 return nil, err
244         case <-time.After(2 * time.Second):
245                 return nil, errors.New("timeout waiting for frame")
246         }
247 }
248
249 type requestParam struct {
250         name        string              // name for this request to identify the request in log easily
251         streamID    uint32              // stream ID, automatically assigned if 0
252         method      string              // method, defaults to GET
253         scheme      string              // scheme, defaults to http
254         authority   string              // authority, defaults to backend server address
255         path        string              // path, defaults to /
256         header      []hpack.HeaderField // additional request header fields
257         body        []byte              // request body
258         trailer     []hpack.HeaderField // trailer part
259         httpUpgrade bool                // true if upgraded to HTTP/2 through HTTP Upgrade
260 }
261
262 // wrapper for request body to set trailer part
263 type chunkedBodyReader struct {
264         trailer        []hpack.HeaderField
265         trailerWritten bool
266         body           io.Reader
267         req            *http.Request
268 }
269
270 func (cbr *chunkedBodyReader) Read(p []byte) (n int, err error) {
271         // document says that we have to set http.Request.Trailer
272         // after request was sent and before body returns EOF.
273         if !cbr.trailerWritten {
274                 cbr.trailerWritten = true
275                 for _, h := range cbr.trailer {
276                         cbr.req.Trailer.Set(h.Name, h.Value)
277                 }
278         }
279         return cbr.body.Read(p)
280 }
281
282 func (st *serverTester) http1(rp requestParam) (*serverResponse, error) {
283         method := "GET"
284         if rp.method != "" {
285                 method = rp.method
286         }
287
288         var body io.Reader
289         var cbr *chunkedBodyReader
290         if rp.body != nil {
291                 body = bytes.NewBuffer(rp.body)
292                 if len(rp.trailer) != 0 {
293                         cbr = &chunkedBodyReader{
294                                 trailer: rp.trailer,
295                                 body:    body,
296                         }
297                         body = cbr
298                 }
299         }
300
301         reqURL := st.url
302
303         if rp.path != "" {
304                 u, err := url.Parse(st.url)
305                 if err != nil {
306                         st.t.Fatalf("Error parsing URL from st.url %v: %v", st.url, err)
307                 }
308                 u.Path = rp.path
309                 reqURL = u.String()
310         }
311
312         req, err := http.NewRequest(method, reqURL, body)
313         if err != nil {
314                 return nil, err
315         }
316         for _, h := range rp.header {
317                 req.Header.Add(h.Name, h.Value)
318         }
319         req.Header.Add("Test-Case", rp.name)
320         if cbr != nil {
321                 cbr.req = req
322                 // this makes request use chunked encoding
323                 req.ContentLength = -1
324                 req.Trailer = make(http.Header)
325                 for _, h := range cbr.trailer {
326                         req.Trailer.Set(h.Name, "")
327                 }
328         }
329         if err := req.Write(st.conn); err != nil {
330                 return nil, err
331         }
332         resp, err := http.ReadResponse(bufio.NewReader(st.conn), req)
333         if err != nil {
334                 return nil, err
335         }
336         respBody, err := ioutil.ReadAll(resp.Body)
337         if err != nil {
338                 return nil, err
339         }
340         resp.Body.Close()
341
342         res := &serverResponse{
343                 status:    resp.StatusCode,
344                 header:    resp.Header,
345                 body:      respBody,
346                 connClose: resp.Close,
347         }
348
349         return res, nil
350 }
351
352 func (st *serverTester) spdy(rp requestParam) (*serverResponse, error) {
353         res := &serverResponse{}
354
355         var id spdy.StreamId
356         if rp.streamID != 0 {
357                 id = spdy.StreamId(rp.streamID)
358                 if id >= spdy.StreamId(st.nextStreamID) && id%2 == 1 {
359                         st.nextStreamID = uint32(id) + 2
360                 }
361         } else {
362                 id = spdy.StreamId(st.nextStreamID)
363                 st.nextStreamID += 2
364         }
365
366         method := "GET"
367         if rp.method != "" {
368                 method = rp.method
369         }
370
371         scheme := "http"
372         if rp.scheme != "" {
373                 scheme = rp.scheme
374         }
375
376         host := st.authority
377         if rp.authority != "" {
378                 host = rp.authority
379         }
380
381         path := "/"
382         if rp.path != "" {
383                 path = rp.path
384         }
385
386         header := make(http.Header)
387         header.Add(":method", method)
388         header.Add(":scheme", scheme)
389         header.Add(":host", host)
390         header.Add(":path", path)
391         header.Add(":version", "HTTP/1.1")
392         header.Add("test-case", rp.name)
393         for _, h := range rp.header {
394                 header.Add(h.Name, h.Value)
395         }
396
397         var synStreamFlags spdy.ControlFlags
398         if len(rp.body) == 0 {
399                 synStreamFlags = spdy.ControlFlagFin
400         }
401         if err := st.spdyFr.WriteFrame(&spdy.SynStreamFrame{
402                 CFHeader: spdy.ControlFrameHeader{
403                         Flags: synStreamFlags,
404                 },
405                 StreamId: id,
406                 Headers:  header,
407         }); err != nil {
408                 return nil, err
409         }
410
411         if len(rp.body) != 0 {
412                 if err := st.spdyFr.WriteFrame(&spdy.DataFrame{
413                         StreamId: id,
414                         Flags:    spdy.DataFlagFin,
415                         Data:     rp.body,
416                 }); err != nil {
417                         return nil, err
418                 }
419         }
420
421 loop:
422         for {
423                 fr, err := st.readSpdyFrame()
424                 if err != nil {
425                         return res, err
426                 }
427                 switch f := fr.(type) {
428                 case *spdy.SynReplyFrame:
429                         if f.StreamId != id {
430                                 break
431                         }
432                         res.header = cloneHeader(f.Headers)
433                         if _, err := fmt.Sscan(res.header.Get(":status"), &res.status); err != nil {
434                                 return res, fmt.Errorf("Error parsing status code: %v", err)
435                         }
436                         if f.CFHeader.Flags&spdy.ControlFlagFin != 0 {
437                                 break loop
438                         }
439                 case *spdy.DataFrame:
440                         if f.StreamId != id {
441                                 break
442                         }
443                         res.body = append(res.body, f.Data...)
444                         if f.Flags&spdy.DataFlagFin != 0 {
445                                 break loop
446                         }
447                 case *spdy.RstStreamFrame:
448                         if f.StreamId != id {
449                                 break
450                         }
451                         res.spdyRstErrCode = f.Status
452                         break loop
453                 case *spdy.GoAwayFrame:
454                         if f.Status == spdy.GoAwayOK {
455                                 break
456                         }
457                         res.spdyGoAwayErrCode = f.Status
458                         break loop
459                 }
460         }
461         return res, nil
462 }
463
464 func (st *serverTester) http2(rp requestParam) (*serverResponse, error) {
465         st.headerBlkBuf.Reset()
466         st.header = make(http.Header)
467
468         var id uint32
469         if rp.streamID != 0 {
470                 id = rp.streamID
471                 if id >= st.nextStreamID && id%2 == 1 {
472                         st.nextStreamID = id + 2
473                 }
474         } else {
475                 id = st.nextStreamID
476                 st.nextStreamID += 2
477         }
478
479         if !st.h2PrefaceSent {
480                 st.h2PrefaceSent = true
481                 fmt.Fprint(st.conn, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n")
482                 if err := st.fr.WriteSettings(); err != nil {
483                         return nil, err
484                 }
485         }
486
487         res := &serverResponse{
488                 streamID: id,
489         }
490
491         streams := make(map[uint32]*serverResponse)
492         streams[id] = res
493
494         if !rp.httpUpgrade {
495                 method := "GET"
496                 if rp.method != "" {
497                         method = rp.method
498                 }
499                 _ = st.enc.WriteField(pair(":method", method))
500
501                 scheme := "http"
502                 if rp.scheme != "" {
503                         scheme = rp.scheme
504                 }
505                 _ = st.enc.WriteField(pair(":scheme", scheme))
506
507                 authority := st.authority
508                 if rp.authority != "" {
509                         authority = rp.authority
510                 }
511                 _ = st.enc.WriteField(pair(":authority", authority))
512
513                 path := "/"
514                 if rp.path != "" {
515                         path = rp.path
516                 }
517                 _ = st.enc.WriteField(pair(":path", path))
518
519                 _ = st.enc.WriteField(pair("test-case", rp.name))
520
521                 for _, h := range rp.header {
522                         _ = st.enc.WriteField(h)
523                 }
524
525                 err := st.fr.WriteHeaders(http2.HeadersFrameParam{
526                         StreamID:      id,
527                         EndStream:     len(rp.body) == 0 && len(rp.trailer) == 0,
528                         EndHeaders:    true,
529                         BlockFragment: st.headerBlkBuf.Bytes(),
530                 })
531                 if err != nil {
532                         return nil, err
533                 }
534
535                 if len(rp.body) != 0 {
536                         // TODO we assume rp.body fits in 1 frame
537                         if err := st.fr.WriteData(id, len(rp.trailer) == 0, rp.body); err != nil {
538                                 return nil, err
539                         }
540                 }
541
542                 if len(rp.trailer) != 0 {
543                         st.headerBlkBuf.Reset()
544                         for _, h := range rp.trailer {
545                                 _ = st.enc.WriteField(h)
546                         }
547                         err := st.fr.WriteHeaders(http2.HeadersFrameParam{
548                                 StreamID:      id,
549                                 EndStream:     true,
550                                 EndHeaders:    true,
551                                 BlockFragment: st.headerBlkBuf.Bytes(),
552                         })
553                         if err != nil {
554                                 return nil, err
555                         }
556                 }
557         }
558 loop:
559         for {
560                 fr, err := st.readFrame()
561                 if err != nil {
562                         return res, err
563                 }
564                 switch f := fr.(type) {
565                 case *http2.HeadersFrame:
566                         _, err := st.dec.Write(f.HeaderBlockFragment())
567                         if err != nil {
568                                 return res, err
569                         }
570                         sr, ok := streams[f.FrameHeader.StreamID]
571                         if !ok {
572                                 st.header = make(http.Header)
573                                 break
574                         }
575                         sr.header = cloneHeader(st.header)
576                         var status int
577                         status, err = strconv.Atoi(sr.header.Get(":status"))
578                         if err != nil {
579                                 return res, fmt.Errorf("Error parsing status code: %v", err)
580                         }
581                         sr.status = status
582                         if f.StreamEnded() {
583                                 if streamEnded(res, streams, sr) {
584                                         break loop
585                                 }
586                         }
587                 case *http2.PushPromiseFrame:
588                         _, err := st.dec.Write(f.HeaderBlockFragment())
589                         if err != nil {
590                                 return res, err
591                         }
592                         sr := &serverResponse{
593                                 streamID:  f.PromiseID,
594                                 reqHeader: cloneHeader(st.header),
595                         }
596                         streams[sr.streamID] = sr
597                 case *http2.DataFrame:
598                         sr, ok := streams[f.FrameHeader.StreamID]
599                         if !ok {
600                                 break
601                         }
602                         sr.body = append(sr.body, f.Data()...)
603                         if f.StreamEnded() {
604                                 if streamEnded(res, streams, sr) {
605                                         break loop
606                                 }
607                         }
608                 case *http2.RSTStreamFrame:
609                         sr, ok := streams[f.FrameHeader.StreamID]
610                         if !ok {
611                                 break
612                         }
613                         sr.errCode = f.ErrCode
614                         if streamEnded(res, streams, sr) {
615                                 break loop
616                         }
617                 case *http2.GoAwayFrame:
618                         if f.ErrCode == http2.ErrCodeNo {
619                                 break
620                         }
621                         res.errCode = f.ErrCode
622                         res.connErr = true
623                         break loop
624                 case *http2.SettingsFrame:
625                         if f.IsAck() {
626                                 break
627                         }
628                         if err := st.fr.WriteSettingsAck(); err != nil {
629                                 return res, err
630                         }
631                 }
632         }
633         sort.Sort(ByStreamID(res.pushResponse))
634         return res, nil
635 }
636
637 func streamEnded(mainSr *serverResponse, streams map[uint32]*serverResponse, sr *serverResponse) bool {
638         delete(streams, sr.streamID)
639         if mainSr.streamID != sr.streamID {
640                 mainSr.pushResponse = append(mainSr.pushResponse, sr)
641         }
642         return len(streams) == 0
643 }
644
645 type serverResponse struct {
646         status            int                  // HTTP status code
647         header            http.Header          // response header fields
648         body              []byte               // response body
649         streamID          uint32               // stream ID in HTTP/2
650         errCode           http2.ErrCode        // error code received in HTTP/2 RST_STREAM or GOAWAY
651         connErr           bool                 // true if HTTP/2 connection error
652         spdyGoAwayErrCode spdy.GoAwayStatus    // status code received in SPDY RST_STREAM
653         spdyRstErrCode    spdy.RstStreamStatus // status code received in SPDY GOAWAY
654         connClose         bool                 // Conection: close is included in response header in HTTP/1 test
655         reqHeader         http.Header          // http request header, currently only sotres pushed request header
656         pushResponse      []*serverResponse    // pushed response
657 }
658
659 type ByStreamID []*serverResponse
660
661 func (b ByStreamID) Len() int {
662         return len(b)
663 }
664
665 func (b ByStreamID) Swap(i, j int) {
666         b[i], b[j] = b[j], b[i]
667 }
668
669 func (b ByStreamID) Less(i, j int) bool {
670         return b[i].streamID < b[j].streamID
671 }
672
// cloneHeader returns a deep copy of h. The returned map and every value
// slice in it are freshly allocated, so mutating the copy never affects h.
// A nil h yields an empty, non-nil header.
func cloneHeader(h http.Header) http.Header {
	dup := make(http.Header, len(h))
	for name, values := range h {
		dup[name] = append(make([]string, 0, len(values)), values...)
	}
	return dup
}
682
// noopHandler is a backend handler that writes nothing to w and ignores r.
func noopHandler(w http.ResponseWriter, r *http.Request) {
	// Intentionally empty.
}