remove unused files
[platform/upstream/gcc48.git] / libgo / go / database / sql / fakedb_test.go
1 // Copyright 2011 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package sql
6
7 import (
8         "database/sql/driver"
9         "errors"
10         "fmt"
11         "io"
12         "log"
13         "strconv"
14         "strings"
15         "sync"
16         "testing"
17         "time"
18 )
19
20 var _ = log.Printf
21
22 // fakeDriver is a fake database that implements Go's driver.Driver
23 // interface, just for testing.
24 //
25 // It speaks a query language that's semantically similar to but
26 // syntantically different and simpler than SQL.  The syntax is as
27 // follows:
28 //
29 //   WIPE
30 //   CREATE|<tablename>|<col>=<type>,<col>=<type>,...
31 //     where types are: "string", [u]int{8,16,32,64}, "bool"
32 //   INSERT|<tablename>|col=val,col2=val2,col3=?
33 //   SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=?
34 //
35 // When opening a fakeDriver's database, it starts empty with no
36 // tables.  All tables and data are stored in memory only.
37 type fakeDriver struct {
38         mu         sync.Mutex // guards 3 following fields
39         openCount  int        // conn opens
40         closeCount int        // conn closes
41         dbs        map[string]*fakeDB
42 }
43
44 type fakeDB struct {
45         name string
46
47         mu      sync.Mutex
48         free    []*fakeConn
49         tables  map[string]*table
50         badConn bool
51 }
52
53 type table struct {
54         mu      sync.Mutex
55         colname []string
56         coltype []string
57         rows    []*row
58 }
59
60 func (t *table) columnIndex(name string) int {
61         for n, nname := range t.colname {
62                 if name == nname {
63                         return n
64                 }
65         }
66         return -1
67 }
68
69 type row struct {
70         cols []interface{} // must be same size as its table colname + coltype
71 }
72
73 func (r *row) clone() *row {
74         nrow := &row{cols: make([]interface{}, len(r.cols))}
75         copy(nrow.cols, r.cols)
76         return nrow
77 }
78
79 type fakeConn struct {
80         db *fakeDB // where to return ourselves to
81
82         currTx *fakeTx
83
84         // Stats for tests:
85         mu          sync.Mutex
86         stmtsMade   int
87         stmtsClosed int
88         numPrepare  int
89         bad         bool
90 }
91
92 func (c *fakeConn) incrStat(v *int) {
93         c.mu.Lock()
94         *v++
95         c.mu.Unlock()
96 }
97
98 type fakeTx struct {
99         c *fakeConn
100 }
101
102 type fakeStmt struct {
103         c *fakeConn
104         q string // just for debugging
105
106         cmd   string
107         table string
108
109         closed bool
110
111         colName      []string      // used by CREATE, INSERT, SELECT (selected columns)
112         colType      []string      // used by CREATE
113         colValue     []interface{} // used by INSERT (mix of strings and "?" for bound params)
114         placeholders int           // used by INSERT/SELECT: number of ? params
115
116         whereCol []string // used by SELECT (all placeholders)
117
118         placeholderConverter []driver.ValueConverter // used by INSERT
119 }
120
121 var fdriver driver.Driver = &fakeDriver{}
122
123 func init() {
124         Register("test", fdriver)
125 }
126
127 // Supports dsn forms:
128 //    <dbname>
129 //    <dbname>;<opts>  (only currently supported option is `badConn`,
130 //                      which causes driver.ErrBadConn to be returned on
131 //                      every other conn.Begin())
132 func (d *fakeDriver) Open(dsn string) (driver.Conn, error) {
133         parts := strings.Split(dsn, ";")
134         if len(parts) < 1 {
135                 return nil, errors.New("fakedb: no database name")
136         }
137         name := parts[0]
138
139         db := d.getDB(name)
140
141         d.mu.Lock()
142         d.openCount++
143         d.mu.Unlock()
144         conn := &fakeConn{db: db}
145
146         if len(parts) >= 2 && parts[1] == "badConn" {
147                 conn.bad = true
148         }
149         return conn, nil
150 }
151
152 func (d *fakeDriver) getDB(name string) *fakeDB {
153         d.mu.Lock()
154         defer d.mu.Unlock()
155         if d.dbs == nil {
156                 d.dbs = make(map[string]*fakeDB)
157         }
158         db, ok := d.dbs[name]
159         if !ok {
160                 db = &fakeDB{name: name}
161                 d.dbs[name] = db
162         }
163         return db
164 }
165
166 func (db *fakeDB) wipe() {
167         db.mu.Lock()
168         defer db.mu.Unlock()
169         db.tables = nil
170 }
171
172 func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error {
173         db.mu.Lock()
174         defer db.mu.Unlock()
175         if db.tables == nil {
176                 db.tables = make(map[string]*table)
177         }
178         if _, exist := db.tables[name]; exist {
179                 return fmt.Errorf("table %q already exists", name)
180         }
181         if len(columnNames) != len(columnTypes) {
182                 return fmt.Errorf("create table of %q len(names) != len(types): %d vs %d",
183                         name, len(columnNames), len(columnTypes))
184         }
185         db.tables[name] = &table{colname: columnNames, coltype: columnTypes}
186         return nil
187 }
188
189 // must be called with db.mu lock held
190 func (db *fakeDB) table(table string) (*table, bool) {
191         if db.tables == nil {
192                 return nil, false
193         }
194         t, ok := db.tables[table]
195         return t, ok
196 }
197
198 func (db *fakeDB) columnType(table, column string) (typ string, ok bool) {
199         db.mu.Lock()
200         defer db.mu.Unlock()
201         t, ok := db.table(table)
202         if !ok {
203                 return
204         }
205         for n, cname := range t.colname {
206                 if cname == column {
207                         return t.coltype[n], true
208                 }
209         }
210         return "", false
211 }
212
213 func (c *fakeConn) isBad() bool {
214         // if not simulating bad conn, do nothing
215         if !c.bad {
216                 return false
217         }
218         // alternate between bad conn and not bad conn
219         c.db.badConn = !c.db.badConn
220         return c.db.badConn
221 }
222
223 func (c *fakeConn) Begin() (driver.Tx, error) {
224         if c.isBad() {
225                 return nil, driver.ErrBadConn
226         }
227         if c.currTx != nil {
228                 return nil, errors.New("already in a transaction")
229         }
230         c.currTx = &fakeTx{c: c}
231         return c.currTx, nil
232 }
233
234 var hookPostCloseConn struct {
235         sync.Mutex
236         fn func(*fakeConn, error)
237 }
238
239 func setHookpostCloseConn(fn func(*fakeConn, error)) {
240         hookPostCloseConn.Lock()
241         defer hookPostCloseConn.Unlock()
242         hookPostCloseConn.fn = fn
243 }
244
245 var testStrictClose *testing.T
246
247 // setStrictFakeConnClose sets the t to Errorf on when fakeConn.Close
248 // fails to close. If nil, the check is disabled.
249 func setStrictFakeConnClose(t *testing.T) {
250         testStrictClose = t
251 }
252
253 func (c *fakeConn) Close() (err error) {
254         drv := fdriver.(*fakeDriver)
255         defer func() {
256                 if err != nil && testStrictClose != nil {
257                         testStrictClose.Errorf("failed to close a test fakeConn: %v", err)
258                 }
259                 hookPostCloseConn.Lock()
260                 fn := hookPostCloseConn.fn
261                 hookPostCloseConn.Unlock()
262                 if fn != nil {
263                         fn(c, err)
264                 }
265                 if err == nil {
266                         drv.mu.Lock()
267                         drv.closeCount++
268                         drv.mu.Unlock()
269                 }
270         }()
271         if c.currTx != nil {
272                 return errors.New("can't close fakeConn; in a Transaction")
273         }
274         if c.db == nil {
275                 return errors.New("can't close fakeConn; already closed")
276         }
277         if c.stmtsMade > c.stmtsClosed {
278                 return errors.New("can't close; dangling statement(s)")
279         }
280         c.db = nil
281         return nil
282 }
283
284 func checkSubsetTypes(args []driver.Value) error {
285         for n, arg := range args {
286                 switch arg.(type) {
287                 case int64, float64, bool, nil, []byte, string, time.Time:
288                 default:
289                         return fmt.Errorf("fakedb_test: invalid argument #%d: %v, type %T", n+1, arg, arg)
290                 }
291         }
292         return nil
293 }
294
295 func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) {
296         // This is an optional interface, but it's implemented here
297         // just to check that all the args are of the proper types.
298         // ErrSkip is returned so the caller acts as if we didn't
299         // implement this at all.
300         err := checkSubsetTypes(args)
301         if err != nil {
302                 return nil, err
303         }
304         return nil, driver.ErrSkip
305 }
306
307 func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) {
308         // This is an optional interface, but it's implemented here
309         // just to check that all the args are of the proper types.
310         // ErrSkip is returned so the caller acts as if we didn't
311         // implement this at all.
312         err := checkSubsetTypes(args)
313         if err != nil {
314                 return nil, err
315         }
316         return nil, driver.ErrSkip
317 }
318
319 func errf(msg string, args ...interface{}) error {
320         return errors.New("fakedb: " + fmt.Sprintf(msg, args...))
321 }
322
323 // parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=?
324 // (note that where columns must always contain ? marks,
325 //  just a limitation for fakedb)
326 func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
327         if len(parts) != 3 {
328                 stmt.Close()
329                 return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
330         }
331         stmt.table = parts[0]
332         stmt.colName = strings.Split(parts[1], ",")
333         for n, colspec := range strings.Split(parts[2], ",") {
334                 if colspec == "" {
335                         continue
336                 }
337                 nameVal := strings.Split(colspec, "=")
338                 if len(nameVal) != 2 {
339                         stmt.Close()
340                         return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
341                 }
342                 column, value := nameVal[0], nameVal[1]
343                 _, ok := c.db.columnType(stmt.table, column)
344                 if !ok {
345                         stmt.Close()
346                         return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column)
347                 }
348                 if value != "?" {
349                         stmt.Close()
350                         return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark",
351                                 stmt.table, column)
352                 }
353                 stmt.whereCol = append(stmt.whereCol, column)
354                 stmt.placeholders++
355         }
356         return stmt, nil
357 }
358
359 // parts are table|col=type,col2=type2
360 func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
361         if len(parts) != 2 {
362                 stmt.Close()
363                 return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
364         }
365         stmt.table = parts[0]
366         for n, colspec := range strings.Split(parts[1], ",") {
367                 nameType := strings.Split(colspec, "=")
368                 if len(nameType) != 2 {
369                         stmt.Close()
370                         return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
371                 }
372                 stmt.colName = append(stmt.colName, nameType[0])
373                 stmt.colType = append(stmt.colType, nameType[1])
374         }
375         return stmt, nil
376 }
377
378 // parts are table|col=?,col2=val
379 func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
380         if len(parts) != 2 {
381                 stmt.Close()
382                 return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
383         }
384         stmt.table = parts[0]
385         for n, colspec := range strings.Split(parts[1], ",") {
386                 nameVal := strings.Split(colspec, "=")
387                 if len(nameVal) != 2 {
388                         stmt.Close()
389                         return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
390                 }
391                 column, value := nameVal[0], nameVal[1]
392                 ctype, ok := c.db.columnType(stmt.table, column)
393                 if !ok {
394                         stmt.Close()
395                         return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column)
396                 }
397                 stmt.colName = append(stmt.colName, column)
398
399                 if value != "?" {
400                         var subsetVal interface{}
401                         // Convert to driver subset type
402                         switch ctype {
403                         case "string":
404                                 subsetVal = []byte(value)
405                         case "blob":
406                                 subsetVal = []byte(value)
407                         case "int32":
408                                 i, err := strconv.Atoi(value)
409                                 if err != nil {
410                                         stmt.Close()
411                                         return nil, errf("invalid conversion to int32 from %q", value)
412                                 }
413                                 subsetVal = int64(i) // int64 is a subset type, but not int32
414                         default:
415                                 stmt.Close()
416                                 return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype)
417                         }
418                         stmt.colValue = append(stmt.colValue, subsetVal)
419                 } else {
420                         stmt.placeholders++
421                         stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype))
422                         stmt.colValue = append(stmt.colValue, "?")
423                 }
424         }
425         return stmt, nil
426 }
427
428 func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
429         c.numPrepare++
430         if c.db == nil {
431                 panic("nil c.db; conn = " + fmt.Sprintf("%#v", c))
432         }
433         parts := strings.Split(query, "|")
434         if len(parts) < 1 {
435                 return nil, errf("empty query")
436         }
437         cmd := parts[0]
438         parts = parts[1:]
439         stmt := &fakeStmt{q: query, c: c, cmd: cmd}
440         c.incrStat(&c.stmtsMade)
441         switch cmd {
442         case "WIPE":
443                 // Nothing
444         case "SELECT":
445                 return c.prepareSelect(stmt, parts)
446         case "CREATE":
447                 return c.prepareCreate(stmt, parts)
448         case "INSERT":
449                 return c.prepareInsert(stmt, parts)
450         default:
451                 stmt.Close()
452                 return nil, errf("unsupported command type %q", cmd)
453         }
454         return stmt, nil
455 }
456
457 func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
458         if len(s.placeholderConverter) == 0 {
459                 return driver.DefaultParameterConverter
460         }
461         return s.placeholderConverter[idx]
462 }
463
464 func (s *fakeStmt) Close() error {
465         if s.c == nil {
466                 panic("nil conn in fakeStmt.Close")
467         }
468         if s.c.db == nil {
469                 panic("in fakeStmt.Close, conn's db is nil (already closed)")
470         }
471         if !s.closed {
472                 s.c.incrStat(&s.c.stmtsClosed)
473                 s.closed = true
474         }
475         return nil
476 }
477
478 var errClosed = errors.New("fakedb: statement has been closed")
479
480 func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
481         if s.closed {
482                 return nil, errClosed
483         }
484         err := checkSubsetTypes(args)
485         if err != nil {
486                 return nil, err
487         }
488
489         db := s.c.db
490         switch s.cmd {
491         case "WIPE":
492                 db.wipe()
493                 return driver.ResultNoRows, nil
494         case "CREATE":
495                 if err := db.createTable(s.table, s.colName, s.colType); err != nil {
496                         return nil, err
497                 }
498                 return driver.ResultNoRows, nil
499         case "INSERT":
500                 return s.execInsert(args)
501         }
502         fmt.Printf("EXEC statement, cmd=%q: %#v\n", s.cmd, s)
503         return nil, fmt.Errorf("unimplemented statement Exec command type of %q", s.cmd)
504 }
505
506 func (s *fakeStmt) execInsert(args []driver.Value) (driver.Result, error) {
507         db := s.c.db
508         if len(args) != s.placeholders {
509                 panic("error in pkg db; should only get here if size is correct")
510         }
511         db.mu.Lock()
512         t, ok := db.table(s.table)
513         db.mu.Unlock()
514         if !ok {
515                 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
516         }
517
518         t.mu.Lock()
519         defer t.mu.Unlock()
520
521         cols := make([]interface{}, len(t.colname))
522         argPos := 0
523         for n, colname := range s.colName {
524                 colidx := t.columnIndex(colname)
525                 if colidx == -1 {
526                         return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname)
527                 }
528                 var val interface{}
529                 if strvalue, ok := s.colValue[n].(string); ok && strvalue == "?" {
530                         val = args[argPos]
531                         argPos++
532                 } else {
533                         val = s.colValue[n]
534                 }
535                 cols[colidx] = val
536         }
537
538         t.rows = append(t.rows, &row{cols: cols})
539         return driver.RowsAffected(1), nil
540 }
541
542 func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
543         if s.closed {
544                 return nil, errClosed
545         }
546         err := checkSubsetTypes(args)
547         if err != nil {
548                 return nil, err
549         }
550
551         db := s.c.db
552         if len(args) != s.placeholders {
553                 panic("error in pkg db; should only get here if size is correct")
554         }
555
556         db.mu.Lock()
557         t, ok := db.table(s.table)
558         db.mu.Unlock()
559         if !ok {
560                 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
561         }
562
563         if s.table == "magicquery" {
564                 if len(s.whereCol) == 2 && s.whereCol[0] == "op" && s.whereCol[1] == "millis" {
565                         if args[0] == "sleep" {
566                                 time.Sleep(time.Duration(args[1].(int64)) * time.Millisecond)
567                         }
568                 }
569         }
570
571         t.mu.Lock()
572         defer t.mu.Unlock()
573
574         colIdx := make(map[string]int) // select column name -> column index in table
575         for _, name := range s.colName {
576                 idx := t.columnIndex(name)
577                 if idx == -1 {
578                         return nil, fmt.Errorf("fakedb: unknown column name %q", name)
579                 }
580                 colIdx[name] = idx
581         }
582
583         mrows := []*row{}
584 rows:
585         for _, trow := range t.rows {
586                 // Process the where clause, skipping non-match rows. This is lazy
587                 // and just uses fmt.Sprintf("%v") to test equality.  Good enough
588                 // for test code.
589                 for widx, wcol := range s.whereCol {
590                         idx := t.columnIndex(wcol)
591                         if idx == -1 {
592                                 return nil, fmt.Errorf("db: invalid where clause column %q", wcol)
593                         }
594                         tcol := trow.cols[idx]
595                         if bs, ok := tcol.([]byte); ok {
596                                 // lazy hack to avoid sprintf %v on a []byte
597                                 tcol = string(bs)
598                         }
599                         if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", args[widx]) {
600                                 continue rows
601                         }
602                 }
603                 mrow := &row{cols: make([]interface{}, len(s.colName))}
604                 for seli, name := range s.colName {
605                         mrow.cols[seli] = trow.cols[colIdx[name]]
606                 }
607                 mrows = append(mrows, mrow)
608         }
609
610         cursor := &rowsCursor{
611                 pos:  -1,
612                 rows: mrows,
613                 cols: s.colName,
614         }
615         return cursor, nil
616 }
617
618 func (s *fakeStmt) NumInput() int {
619         return s.placeholders
620 }
621
622 func (tx *fakeTx) Commit() error {
623         tx.c.currTx = nil
624         return nil
625 }
626
627 func (tx *fakeTx) Rollback() error {
628         tx.c.currTx = nil
629         return nil
630 }
631
632 type rowsCursor struct {
633         cols   []string
634         pos    int
635         rows   []*row
636         closed bool
637
638         // a clone of slices to give out to clients, indexed by the
639         // the original slice's first byte address.  we clone them
640         // just so we're able to corrupt them on close.
641         bytesClone map[*byte][]byte
642 }
643
644 func (rc *rowsCursor) Close() error {
645         if !rc.closed {
646                 for _, bs := range rc.bytesClone {
647                         bs[0] = 255 // first byte corrupted
648                 }
649         }
650         rc.closed = true
651         return nil
652 }
653
654 func (rc *rowsCursor) Columns() []string {
655         return rc.cols
656 }
657
658 func (rc *rowsCursor) Next(dest []driver.Value) error {
659         if rc.closed {
660                 return errors.New("fakedb: cursor is closed")
661         }
662         rc.pos++
663         if rc.pos >= len(rc.rows) {
664                 return io.EOF // per interface spec
665         }
666         for i, v := range rc.rows[rc.pos].cols {
667                 // TODO(bradfitz): convert to subset types? naah, I
668                 // think the subset types should only be input to
669                 // driver, but the sql package should be able to handle
670                 // a wider range of types coming out of drivers. all
671                 // for ease of drivers, and to prevent drivers from
672                 // messing up conversions or doing them differently.
673                 dest[i] = v
674
675                 if bs, ok := v.([]byte); ok {
676                         if rc.bytesClone == nil {
677                                 rc.bytesClone = make(map[*byte][]byte)
678                         }
679                         clone, ok := rc.bytesClone[&bs[0]]
680                         if !ok {
681                                 clone = make([]byte, len(bs))
682                                 copy(clone, bs)
683                                 rc.bytesClone[&bs[0]] = clone
684                         }
685                         dest[i] = clone
686                 }
687         }
688         return nil
689 }
690
691 // fakeDriverString is like driver.String, but indirects pointers like
692 // DefaultValueConverter.
693 //
694 // This could be surprising behavior to retroactively apply to
695 // driver.String now that Go1 is out, but this is convenient for
696 // our TestPointerParamsAndScans.
697 //
698 type fakeDriverString struct{}
699
700 func (fakeDriverString) ConvertValue(v interface{}) (driver.Value, error) {
701         switch c := v.(type) {
702         case string, []byte:
703                 return v, nil
704         case *string:
705                 if c == nil {
706                         return nil, nil
707                 }
708                 return *c, nil
709         }
710         return fmt.Sprintf("%v", v), nil
711 }
712
713 func converterForType(typ string) driver.ValueConverter {
714         switch typ {
715         case "bool":
716                 return driver.Bool
717         case "nullbool":
718                 return driver.Null{Converter: driver.Bool}
719         case "int32":
720                 return driver.Int32
721         case "string":
722                 return driver.NotNull{Converter: fakeDriverString{}}
723         case "nullstring":
724                 return driver.Null{Converter: fakeDriverString{}}
725         case "int64":
726                 // TODO(coopernurse): add type-specific converter
727                 return driver.NotNull{Converter: driver.DefaultParameterConverter}
728         case "nullint64":
729                 // TODO(coopernurse): add type-specific converter
730                 return driver.Null{Converter: driver.DefaultParameterConverter}
731         case "float64":
732                 // TODO(coopernurse): add type-specific converter
733                 return driver.NotNull{Converter: driver.DefaultParameterConverter}
734         case "nullfloat64":
735                 // TODO(coopernurse): add type-specific converter
736                 return driver.Null{Converter: driver.DefaultParameterConverter}
737         case "datetime":
738                 return driver.DefaultParameterConverter
739         }
740         panic("invalid fakedb column type of " + typ)
741 }