tizen 2.4 release
[external/nghttp2.git] / integration-tests / server_tester.go
1 package nghttp2
2
3 import (
4         "bufio"
5         "bytes"
6         "crypto/tls"
7         "errors"
8         "fmt"
9         "github.com/bradfitz/http2"
10         "github.com/bradfitz/http2/hpack"
11         "github.com/tatsuhiro-t/go-nghttp2"
12         "golang.org/x/net/spdy"
13         "io"
14         "io/ioutil"
15         "net"
16         "net/http"
17         "net/http/httptest"
18         "net/url"
19         "os/exec"
20         "sort"
21         "strconv"
22         "strings"
23         "testing"
24         "time"
25 )
26
27 const (
28         serverBin  = buildDir + "/src/nghttpx"
29         serverPort = 3009
30         testDir    = buildDir + "/integration-tests"
31 )
32
33 func pair(name, value string) hpack.HeaderField {
34         return hpack.HeaderField{
35                 Name:  name,
36                 Value: value,
37         }
38 }
39
40 type serverTester struct {
41         args          []string  // command-line arguments
42         cmd           *exec.Cmd // test frontend server process, which is test subject
43         url           string    // test frontend server URL
44         t             *testing.T
45         ts            *httptest.Server // backend server
46         frontendHost  string           // frontend server host
47         backendHost   string           // backend server host
48         conn          net.Conn         // connection to frontend server
49         h2PrefaceSent bool             // HTTP/2 preface was sent in conn
50         nextStreamID  uint32           // next stream ID
51         fr            *http2.Framer    // HTTP/2 framer
52         spdyFr        *spdy.Framer     // SPDY/3.1 framer
53         headerBlkBuf  bytes.Buffer     // buffer to store encoded header block
54         enc           *hpack.Encoder   // HTTP/2 HPACK encoder
55         header        http.Header      // received header fields
56         dec           *hpack.Decoder   // HTTP/2 HPACK decoder
57         authority     string           // server's host:port
58         frCh          chan http2.Frame // used for incoming HTTP/2 frame
59         spdyFrCh      chan spdy.Frame  // used for incoming SPDY frame
60         errCh         chan error
61 }
62
63 // newServerTester creates test context for plain TCP frontend
64 // connection.
65 func newServerTester(args []string, t *testing.T, handler http.HandlerFunc) *serverTester {
66         return newServerTesterInternal(args, t, handler, false, nil)
67 }
68
69 // newServerTester creates test context for TLS frontend connection.
70 func newServerTesterTLS(args []string, t *testing.T, handler http.HandlerFunc) *serverTester {
71         return newServerTesterInternal(args, t, handler, true, nil)
72 }
73
74 // newServerTester creates test context for TLS frontend connection
75 // with given clientConfig
76 func newServerTesterTLSConfig(args []string, t *testing.T, handler http.HandlerFunc, clientConfig *tls.Config) *serverTester {
77         return newServerTesterInternal(args, t, handler, true, clientConfig)
78 }
79
80 // newServerTesterInternal creates test context.  If frontendTLS is
81 // true, set up TLS frontend connection.
82 func newServerTesterInternal(args []string, t *testing.T, handler http.HandlerFunc, frontendTLS bool, clientConfig *tls.Config) *serverTester {
83         ts := httptest.NewUnstartedServer(handler)
84
85         backendTLS := false
86         for _, k := range args {
87                 switch k {
88                 case "--http2-bridge":
89                         backendTLS = true
90                 }
91         }
92         if backendTLS {
93                 nghttp2.ConfigureServer(ts.Config, &nghttp2.Server{})
94                 // According to httptest/server.go, we have to set
95                 // NextProtos separately for ts.TLS.  NextProtos set
96                 // in nghttp2.ConfigureServer is effectively ignored.
97                 ts.TLS = new(tls.Config)
98                 ts.TLS.NextProtos = append(ts.TLS.NextProtos, "h2-14")
99                 ts.StartTLS()
100                 args = append(args, "-k")
101         } else {
102                 ts.Start()
103         }
104         scheme := "http"
105         if frontendTLS {
106                 scheme = "https"
107                 args = append(args, testDir+"/server.key", testDir+"/server.crt")
108         } else {
109                 args = append(args, "--frontend-no-tls")
110         }
111
112         backendURL, err := url.Parse(ts.URL)
113         if err != nil {
114                 t.Fatalf("Error parsing URL from httptest.Server: %v", err)
115         }
116
117         // URL.Host looks like "127.0.0.1:8080", but we want
118         // "127.0.0.1,8080"
119         b := "-b" + strings.Replace(backendURL.Host, ":", ",", -1)
120         args = append(args, fmt.Sprintf("-f127.0.0.1,%v", serverPort), b,
121                 "--errorlog-file="+testDir+"/log.txt", "-LINFO")
122
123         authority := fmt.Sprintf("127.0.0.1:%v", serverPort)
124
125         st := &serverTester{
126                 cmd:          exec.Command(serverBin, args...),
127                 t:            t,
128                 ts:           ts,
129                 url:          fmt.Sprintf("%v://%v", scheme, authority),
130                 frontendHost: fmt.Sprintf("127.0.0.1:%v", serverPort),
131                 backendHost:  backendURL.Host,
132                 nextStreamID: 1,
133                 authority:    authority,
134                 frCh:         make(chan http2.Frame),
135                 spdyFrCh:     make(chan spdy.Frame),
136                 errCh:        make(chan error),
137         }
138
139         if err := st.cmd.Start(); err != nil {
140                 st.t.Fatalf("Error starting %v: %v", serverBin, err)
141         }
142
143         retry := 0
144         for {
145                 var conn net.Conn
146                 var err error
147                 if frontendTLS {
148                         var tlsConfig *tls.Config
149                         if clientConfig == nil {
150                                 tlsConfig = new(tls.Config)
151                         } else {
152                                 tlsConfig = clientConfig
153                         }
154                         tlsConfig.InsecureSkipVerify = true
155                         tlsConfig.NextProtos = []string{"h2-14", "spdy/3.1"}
156                         conn, err = tls.Dial("tcp", authority, tlsConfig)
157                 } else {
158                         conn, err = net.Dial("tcp", authority)
159                 }
160                 if err != nil {
161                         retry += 1
162                         if retry >= 100 {
163                                 st.Close()
164                                 st.t.Fatalf("Error server is not responding too long; server command-line arguments may be invalid")
165                         }
166                         time.Sleep(150 * time.Millisecond)
167                         continue
168                 }
169                 if frontendTLS {
170                         tlsConn := conn.(*tls.Conn)
171                         cs := tlsConn.ConnectionState()
172                         if !cs.NegotiatedProtocolIsMutual {
173                                 st.Close()
174                                 st.t.Fatalf("Error negotiated next protocol is not mutual")
175                         }
176                 }
177                 st.conn = conn
178                 break
179         }
180
181         st.fr = http2.NewFramer(st.conn, st.conn)
182         spdyFr, err := spdy.NewFramer(st.conn, st.conn)
183         if err != nil {
184                 st.Close()
185                 st.t.Fatalf("Error spdy.NewFramer: %v", err)
186         }
187         st.spdyFr = spdyFr
188         st.enc = hpack.NewEncoder(&st.headerBlkBuf)
189         st.dec = hpack.NewDecoder(4096, func(f hpack.HeaderField) {
190                 st.header.Add(f.Name, f.Value)
191         })
192
193         return st
194 }
195
196 func (st *serverTester) Close() {
197         if st.conn != nil {
198                 st.conn.Close()
199         }
200         if st.cmd != nil {
201                 st.cmd.Process.Kill()
202                 st.cmd.Wait()
203         }
204         if st.ts != nil {
205                 st.ts.Close()
206         }
207 }
208
209 func (st *serverTester) readFrame() (http2.Frame, error) {
210         go func() {
211                 f, err := st.fr.ReadFrame()
212                 if err != nil {
213                         st.errCh <- err
214                         return
215                 }
216                 st.frCh <- f
217         }()
218
219         select {
220         case f := <-st.frCh:
221                 return f, nil
222         case err := <-st.errCh:
223                 return nil, err
224         case <-time.After(5 * time.Second):
225                 return nil, errors.New("timeout waiting for frame")
226         }
227 }
228
229 func (st *serverTester) readSpdyFrame() (spdy.Frame, error) {
230         go func() {
231                 f, err := st.spdyFr.ReadFrame()
232                 if err != nil {
233                         st.errCh <- err
234                         return
235                 }
236                 st.spdyFrCh <- f
237         }()
238
239         select {
240         case f := <-st.spdyFrCh:
241                 return f, nil
242         case err := <-st.errCh:
243                 return nil, err
244         case <-time.After(2 * time.Second):
245                 return nil, errors.New("timeout waiting for frame")
246         }
247 }
248
249 type requestParam struct {
250         name      string              // name for this request to identify the request in log easily
251         streamID  uint32              // stream ID, automatically assigned if 0
252         method    string              // method, defaults to GET
253         scheme    string              // scheme, defaults to http
254         authority string              // authority, defaults to backend server address
255         path      string              // path, defaults to /
256         header    []hpack.HeaderField // additional request header fields
257         body      []byte              // request body
258 }
259
260 func (st *serverTester) http1(rp requestParam) (*serverResponse, error) {
261         method := "GET"
262         if rp.method != "" {
263                 method = rp.method
264         }
265
266         var body io.Reader
267         if rp.body != nil {
268                 body = bytes.NewBuffer(rp.body)
269         }
270         req, err := http.NewRequest(method, st.url, body)
271         if err != nil {
272                 return nil, err
273         }
274         for _, h := range rp.header {
275                 req.Header.Add(h.Name, h.Value)
276         }
277         req.Header.Add("Test-Case", rp.name)
278
279         if err := req.Write(st.conn); err != nil {
280                 return nil, err
281         }
282         resp, err := http.ReadResponse(bufio.NewReader(st.conn), req)
283         if err != nil {
284                 return nil, err
285         }
286         respBody, err := ioutil.ReadAll(resp.Body)
287         if err != nil {
288                 return nil, err
289         }
290         resp.Body.Close()
291
292         res := &serverResponse{
293                 status:    resp.StatusCode,
294                 header:    resp.Header,
295                 body:      respBody,
296                 connClose: resp.Close,
297         }
298
299         return res, nil
300 }
301
302 func (st *serverTester) spdy(rp requestParam) (*serverResponse, error) {
303         res := &serverResponse{}
304
305         var id spdy.StreamId
306         if rp.streamID != 0 {
307                 id = spdy.StreamId(rp.streamID)
308                 if id >= spdy.StreamId(st.nextStreamID) && id%2 == 1 {
309                         st.nextStreamID = uint32(id) + 2
310                 }
311         } else {
312                 id = spdy.StreamId(st.nextStreamID)
313                 st.nextStreamID += 2
314         }
315
316         method := "GET"
317         if rp.method != "" {
318                 method = rp.method
319         }
320
321         scheme := "http"
322         if rp.scheme != "" {
323                 scheme = rp.scheme
324         }
325
326         host := st.authority
327         if rp.authority != "" {
328                 host = rp.authority
329         }
330
331         path := "/"
332         if rp.path != "" {
333                 path = rp.path
334         }
335
336         header := make(http.Header)
337         header.Add(":method", method)
338         header.Add(":scheme", scheme)
339         header.Add(":host", host)
340         header.Add(":path", path)
341         header.Add(":version", "HTTP/1.1")
342         header.Add("test-case", rp.name)
343         for _, h := range rp.header {
344                 header.Add(h.Name, h.Value)
345         }
346
347         var synStreamFlags spdy.ControlFlags
348         if len(rp.body) == 0 {
349                 synStreamFlags = spdy.ControlFlagFin
350         }
351         if err := st.spdyFr.WriteFrame(&spdy.SynStreamFrame{
352                 CFHeader: spdy.ControlFrameHeader{
353                         Flags: synStreamFlags,
354                 },
355                 StreamId: id,
356                 Headers:  header,
357         }); err != nil {
358                 return nil, err
359         }
360
361         if len(rp.body) != 0 {
362                 if err := st.spdyFr.WriteFrame(&spdy.DataFrame{
363                         StreamId: id,
364                         Flags:    spdy.DataFlagFin,
365                         Data:     rp.body,
366                 }); err != nil {
367                         return nil, err
368                 }
369         }
370
371 loop:
372         for {
373                 fr, err := st.readSpdyFrame()
374                 if err != nil {
375                         return res, err
376                 }
377                 switch f := fr.(type) {
378                 case *spdy.SynReplyFrame:
379                         if f.StreamId != id {
380                                 break
381                         }
382                         res.header = cloneHeader(f.Headers)
383                         if _, err := fmt.Sscan(res.header.Get(":status"), &res.status); err != nil {
384                                 return res, fmt.Errorf("Error parsing status code: %v", err)
385                         }
386                         if f.CFHeader.Flags&spdy.ControlFlagFin != 0 {
387                                 break loop
388                         }
389                 case *spdy.DataFrame:
390                         if f.StreamId != id {
391                                 break
392                         }
393                         res.body = append(res.body, f.Data...)
394                         if f.Flags&spdy.DataFlagFin != 0 {
395                                 break loop
396                         }
397                 case *spdy.RstStreamFrame:
398                         if f.StreamId != id {
399                                 break
400                         }
401                         res.spdyRstErrCode = f.Status
402                         break loop
403                 case *spdy.GoAwayFrame:
404                         if f.Status == spdy.GoAwayOK {
405                                 break
406                         }
407                         res.spdyGoAwayErrCode = f.Status
408                         break loop
409                 }
410         }
411         return res, nil
412 }
413
414 func (st *serverTester) http2(rp requestParam) (*serverResponse, error) {
415         st.headerBlkBuf.Reset()
416         st.header = make(http.Header)
417
418         var id uint32
419         if rp.streamID != 0 {
420                 id = rp.streamID
421                 if id >= st.nextStreamID && id%2 == 1 {
422                         st.nextStreamID = id + 2
423                 }
424         } else {
425                 id = st.nextStreamID
426                 st.nextStreamID += 2
427         }
428
429         if !st.h2PrefaceSent {
430                 st.h2PrefaceSent = true
431                 fmt.Fprint(st.conn, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n")
432                 if err := st.fr.WriteSettings(); err != nil {
433                         return nil, err
434                 }
435         }
436
437         res := &serverResponse{
438                 streamID: id,
439         }
440
441         streams := make(map[uint32]*serverResponse)
442         streams[id] = res
443
444         method := "GET"
445         if rp.method != "" {
446                 method = rp.method
447         }
448         _ = st.enc.WriteField(pair(":method", method))
449
450         scheme := "http"
451         if rp.scheme != "" {
452                 scheme = rp.scheme
453         }
454         _ = st.enc.WriteField(pair(":scheme", scheme))
455
456         authority := st.authority
457         if rp.authority != "" {
458                 authority = rp.authority
459         }
460         _ = st.enc.WriteField(pair(":authority", authority))
461
462         path := "/"
463         if rp.path != "" {
464                 path = rp.path
465         }
466         _ = st.enc.WriteField(pair(":path", path))
467
468         _ = st.enc.WriteField(pair("test-case", rp.name))
469
470         for _, h := range rp.header {
471                 _ = st.enc.WriteField(h)
472         }
473
474         err := st.fr.WriteHeaders(http2.HeadersFrameParam{
475                 StreamID:      id,
476                 EndStream:     len(rp.body) == 0,
477                 EndHeaders:    true,
478                 BlockFragment: st.headerBlkBuf.Bytes(),
479         })
480         if err != nil {
481                 return nil, err
482         }
483
484         if len(rp.body) != 0 {
485                 // TODO we assume rp.body fits in 1 frame
486                 if err := st.fr.WriteData(id, true, rp.body); err != nil {
487                         return nil, err
488                 }
489         }
490
491 loop:
492         for {
493                 fr, err := st.readFrame()
494                 if err != nil {
495                         return res, err
496                 }
497                 switch f := fr.(type) {
498                 case *http2.HeadersFrame:
499                         _, err := st.dec.Write(f.HeaderBlockFragment())
500                         if err != nil {
501                                 return res, err
502                         }
503                         sr, ok := streams[f.FrameHeader.StreamID]
504                         if !ok {
505                                 st.header = make(http.Header)
506                                 break
507                         }
508                         sr.header = cloneHeader(st.header)
509                         var status int
510                         status, err = strconv.Atoi(sr.header.Get(":status"))
511                         if err != nil {
512                                 return res, fmt.Errorf("Error parsing status code: %v", err)
513                         }
514                         sr.status = status
515                         if f.StreamEnded() {
516                                 if streamEnded(res, streams, sr) {
517                                         break loop
518                                 }
519                         }
520                 case *http2.PushPromiseFrame:
521                         _, err := st.dec.Write(f.HeaderBlockFragment())
522                         if err != nil {
523                                 return res, err
524                         }
525                         sr := &serverResponse{
526                                 streamID:  f.PromiseID,
527                                 reqHeader: cloneHeader(st.header),
528                         }
529                         streams[sr.streamID] = sr
530                 case *http2.DataFrame:
531                         sr, ok := streams[f.FrameHeader.StreamID]
532                         if !ok {
533                                 break
534                         }
535                         sr.body = append(sr.body, f.Data()...)
536                         if f.StreamEnded() {
537                                 if streamEnded(res, streams, sr) {
538                                         break loop
539                                 }
540                         }
541                 case *http2.RSTStreamFrame:
542                         sr, ok := streams[f.FrameHeader.StreamID]
543                         if !ok {
544                                 break
545                         }
546                         sr.errCode = f.ErrCode
547                         if streamEnded(res, streams, sr) {
548                                 break loop
549                         }
550                 case *http2.GoAwayFrame:
551                         if f.ErrCode == http2.ErrCodeNo {
552                                 break
553                         }
554                         res.errCode = f.ErrCode
555                         res.connErr = true
556                         break loop
557                 case *http2.SettingsFrame:
558                         if f.IsAck() {
559                                 break
560                         }
561                         if err := st.fr.WriteSettingsAck(); err != nil {
562                                 return res, err
563                         }
564                 }
565         }
566         sort.Sort(ByStreamID(res.pushResponse))
567         return res, nil
568 }
569
570 func streamEnded(mainSr *serverResponse, streams map[uint32]*serverResponse, sr *serverResponse) bool {
571         delete(streams, sr.streamID)
572         if mainSr.streamID != sr.streamID {
573                 mainSr.pushResponse = append(mainSr.pushResponse, sr)
574         }
575         return len(streams) == 0
576 }
577
578 type serverResponse struct {
579         status            int                  // HTTP status code
580         header            http.Header          // response header fields
581         body              []byte               // response body
582         streamID          uint32               // stream ID in HTTP/2
583         errCode           http2.ErrCode        // error code received in HTTP/2 RST_STREAM or GOAWAY
584         connErr           bool                 // true if HTTP/2 connection error
585         spdyGoAwayErrCode spdy.GoAwayStatus    // status code received in SPDY RST_STREAM
586         spdyRstErrCode    spdy.RstStreamStatus // status code received in SPDY GOAWAY
587         connClose         bool                 // Conection: close is included in response header in HTTP/1 test
588         reqHeader         http.Header          // http request header, currently only sotres pushed request header
589         pushResponse      []*serverResponse    // pushed response
590 }
591
592 type ByStreamID []*serverResponse
593
594 func (b ByStreamID) Len() int {
595         return len(b)
596 }
597
598 func (b ByStreamID) Swap(i, j int) {
599         b[i], b[j] = b[j], b[i]
600 }
601
602 func (b ByStreamID) Less(i, j int) bool {
603         return b[i].streamID < b[j].streamID
604 }
605
606 func cloneHeader(h http.Header) http.Header {
607         h2 := make(http.Header, len(h))
608         for k, vv := range h {
609                 vv2 := make([]string, len(vv))
610                 copy(vv2, vv)
611                 h2[k] = vv2
612         }
613         return h2
614 }
615
616 func noopHandler(w http.ResponseWriter, r *http.Request) {}