184e7756c51e4228e89e420393519bff05361a81
[platform/upstream/gcc48.git] / libgo / go / database / sql / fakedb_test.go
1 // Copyright 2011 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package sql
6
7 import (
8         "database/sql/driver"
9         "errors"
10         "fmt"
11         "io"
12         "log"
13         "strconv"
14         "strings"
15         "sync"
16         "time"
17 )
18
19 var _ = log.Printf
20
21 // fakeDriver is a fake database that implements Go's driver.Driver
22 // interface, just for testing.
23 //
24 // It speaks a query language that's semantically similar to but
25 // syntantically different and simpler than SQL.  The syntax is as
26 // follows:
27 //
28 //   WIPE
29 //   CREATE|<tablename>|<col>=<type>,<col>=<type>,...
30 //     where types are: "string", [u]int{8,16,32,64}, "bool"
31 //   INSERT|<tablename>|col=val,col2=val2,col3=?
32 //   SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=?
33 //
34 // When opening a a fakeDriver's database, it starts empty with no
35 // tables.  All tables and data are stored in memory only.
36 type fakeDriver struct {
37         mu        sync.Mutex
38         openCount int
39         dbs       map[string]*fakeDB
40 }
41
42 type fakeDB struct {
43         name string
44
45         mu     sync.Mutex
46         free   []*fakeConn
47         tables map[string]*table
48 }
49
50 type table struct {
51         mu      sync.Mutex
52         colname []string
53         coltype []string
54         rows    []*row
55 }
56
57 func (t *table) columnIndex(name string) int {
58         for n, nname := range t.colname {
59                 if name == nname {
60                         return n
61                 }
62         }
63         return -1
64 }
65
66 type row struct {
67         cols []interface{} // must be same size as its table colname + coltype
68 }
69
70 func (r *row) clone() *row {
71         nrow := &row{cols: make([]interface{}, len(r.cols))}
72         copy(nrow.cols, r.cols)
73         return nrow
74 }
75
76 type fakeConn struct {
77         db *fakeDB // where to return ourselves to
78
79         currTx *fakeTx
80
81         // Stats for tests:
82         mu          sync.Mutex
83         stmtsMade   int
84         stmtsClosed int
85         numPrepare  int
86 }
87
88 func (c *fakeConn) incrStat(v *int) {
89         c.mu.Lock()
90         *v++
91         c.mu.Unlock()
92 }
93
94 type fakeTx struct {
95         c *fakeConn
96 }
97
98 type fakeStmt struct {
99         c *fakeConn
100         q string // just for debugging
101
102         cmd   string
103         table string
104
105         closed bool
106
107         colName      []string      // used by CREATE, INSERT, SELECT (selected columns)
108         colType      []string      // used by CREATE
109         colValue     []interface{} // used by INSERT (mix of strings and "?" for bound params)
110         placeholders int           // used by INSERT/SELECT: number of ? params
111
112         whereCol []string // used by SELECT (all placeholders)
113
114         placeholderConverter []driver.ValueConverter // used by INSERT
115 }
116
117 var fdriver driver.Driver = &fakeDriver{}
118
119 func init() {
120         Register("test", fdriver)
121 }
122
123 // Supports dsn forms:
124 //    <dbname>
125 //    <dbname>;<opts>  (no currently supported options)
126 func (d *fakeDriver) Open(dsn string) (driver.Conn, error) {
127         parts := strings.Split(dsn, ";")
128         if len(parts) < 1 {
129                 return nil, errors.New("fakedb: no database name")
130         }
131         name := parts[0]
132
133         db := d.getDB(name)
134
135         d.mu.Lock()
136         d.openCount++
137         d.mu.Unlock()
138         return &fakeConn{db: db}, nil
139 }
140
141 func (d *fakeDriver) getDB(name string) *fakeDB {
142         d.mu.Lock()
143         defer d.mu.Unlock()
144         if d.dbs == nil {
145                 d.dbs = make(map[string]*fakeDB)
146         }
147         db, ok := d.dbs[name]
148         if !ok {
149                 db = &fakeDB{name: name}
150                 d.dbs[name] = db
151         }
152         return db
153 }
154
155 func (db *fakeDB) wipe() {
156         db.mu.Lock()
157         defer db.mu.Unlock()
158         db.tables = nil
159 }
160
161 func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error {
162         db.mu.Lock()
163         defer db.mu.Unlock()
164         if db.tables == nil {
165                 db.tables = make(map[string]*table)
166         }
167         if _, exist := db.tables[name]; exist {
168                 return fmt.Errorf("table %q already exists", name)
169         }
170         if len(columnNames) != len(columnTypes) {
171                 return fmt.Errorf("create table of %q len(names) != len(types): %d vs %d",
172                         name, len(columnNames), len(columnTypes))
173         }
174         db.tables[name] = &table{colname: columnNames, coltype: columnTypes}
175         return nil
176 }
177
178 // must be called with db.mu lock held
179 func (db *fakeDB) table(table string) (*table, bool) {
180         if db.tables == nil {
181                 return nil, false
182         }
183         t, ok := db.tables[table]
184         return t, ok
185 }
186
187 func (db *fakeDB) columnType(table, column string) (typ string, ok bool) {
188         db.mu.Lock()
189         defer db.mu.Unlock()
190         t, ok := db.table(table)
191         if !ok {
192                 return
193         }
194         for n, cname := range t.colname {
195                 if cname == column {
196                         return t.coltype[n], true
197                 }
198         }
199         return "", false
200 }
201
202 func (c *fakeConn) Begin() (driver.Tx, error) {
203         if c.currTx != nil {
204                 return nil, errors.New("already in a transaction")
205         }
206         c.currTx = &fakeTx{c: c}
207         return c.currTx, nil
208 }
209
210 func (c *fakeConn) Close() error {
211         if c.currTx != nil {
212                 return errors.New("can't close fakeConn; in a Transaction")
213         }
214         if c.db == nil {
215                 return errors.New("can't close fakeConn; already closed")
216         }
217         if c.stmtsMade > c.stmtsClosed {
218                 return errors.New("can't close; dangling statement(s)")
219         }
220         c.db = nil
221         return nil
222 }
223
224 func checkSubsetTypes(args []driver.Value) error {
225         for n, arg := range args {
226                 switch arg.(type) {
227                 case int64, float64, bool, nil, []byte, string, time.Time:
228                 default:
229                         return fmt.Errorf("fakedb_test: invalid argument #%d: %v, type %T", n+1, arg, arg)
230                 }
231         }
232         return nil
233 }
234
235 func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) {
236         // This is an optional interface, but it's implemented here
237         // just to check that all the args of of the proper types.
238         // ErrSkip is returned so the caller acts as if we didn't
239         // implement this at all.
240         err := checkSubsetTypes(args)
241         if err != nil {
242                 return nil, err
243         }
244         return nil, driver.ErrSkip
245 }
246
247 func errf(msg string, args ...interface{}) error {
248         return errors.New("fakedb: " + fmt.Sprintf(msg, args...))
249 }
250
251 // parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=?
252 // (note that where where columns must always contain ? marks,
253 //  just a limitation for fakedb)
254 func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
255         if len(parts) != 3 {
256                 stmt.Close()
257                 return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
258         }
259         stmt.table = parts[0]
260         stmt.colName = strings.Split(parts[1], ",")
261         for n, colspec := range strings.Split(parts[2], ",") {
262                 if colspec == "" {
263                         continue
264                 }
265                 nameVal := strings.Split(colspec, "=")
266                 if len(nameVal) != 2 {
267                         stmt.Close()
268                         return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
269                 }
270                 column, value := nameVal[0], nameVal[1]
271                 _, ok := c.db.columnType(stmt.table, column)
272                 if !ok {
273                         stmt.Close()
274                         return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column)
275                 }
276                 if value != "?" {
277                         stmt.Close()
278                         return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark",
279                                 stmt.table, column)
280                 }
281                 stmt.whereCol = append(stmt.whereCol, column)
282                 stmt.placeholders++
283         }
284         return stmt, nil
285 }
286
287 // parts are table|col=type,col2=type2
288 func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
289         if len(parts) != 2 {
290                 stmt.Close()
291                 return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
292         }
293         stmt.table = parts[0]
294         for n, colspec := range strings.Split(parts[1], ",") {
295                 nameType := strings.Split(colspec, "=")
296                 if len(nameType) != 2 {
297                         stmt.Close()
298                         return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
299                 }
300                 stmt.colName = append(stmt.colName, nameType[0])
301                 stmt.colType = append(stmt.colType, nameType[1])
302         }
303         return stmt, nil
304 }
305
306 // parts are table|col=?,col2=val
307 func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
308         if len(parts) != 2 {
309                 stmt.Close()
310                 return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
311         }
312         stmt.table = parts[0]
313         for n, colspec := range strings.Split(parts[1], ",") {
314                 nameVal := strings.Split(colspec, "=")
315                 if len(nameVal) != 2 {
316                         stmt.Close()
317                         return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
318                 }
319                 column, value := nameVal[0], nameVal[1]
320                 ctype, ok := c.db.columnType(stmt.table, column)
321                 if !ok {
322                         stmt.Close()
323                         return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column)
324                 }
325                 stmt.colName = append(stmt.colName, column)
326
327                 if value != "?" {
328                         var subsetVal interface{}
329                         // Convert to driver subset type
330                         switch ctype {
331                         case "string":
332                                 subsetVal = []byte(value)
333                         case "blob":
334                                 subsetVal = []byte(value)
335                         case "int32":
336                                 i, err := strconv.Atoi(value)
337                                 if err != nil {
338                                         stmt.Close()
339                                         return nil, errf("invalid conversion to int32 from %q", value)
340                                 }
341                                 subsetVal = int64(i) // int64 is a subset type, but not int32
342                         default:
343                                 stmt.Close()
344                                 return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype)
345                         }
346                         stmt.colValue = append(stmt.colValue, subsetVal)
347                 } else {
348                         stmt.placeholders++
349                         stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype))
350                         stmt.colValue = append(stmt.colValue, "?")
351                 }
352         }
353         return stmt, nil
354 }
355
356 func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
357         c.numPrepare++
358         if c.db == nil {
359                 panic("nil c.db; conn = " + fmt.Sprintf("%#v", c))
360         }
361         parts := strings.Split(query, "|")
362         if len(parts) < 1 {
363                 return nil, errf("empty query")
364         }
365         cmd := parts[0]
366         parts = parts[1:]
367         stmt := &fakeStmt{q: query, c: c, cmd: cmd}
368         c.incrStat(&c.stmtsMade)
369         switch cmd {
370         case "WIPE":
371                 // Nothing
372         case "SELECT":
373                 return c.prepareSelect(stmt, parts)
374         case "CREATE":
375                 return c.prepareCreate(stmt, parts)
376         case "INSERT":
377                 return c.prepareInsert(stmt, parts)
378         default:
379                 stmt.Close()
380                 return nil, errf("unsupported command type %q", cmd)
381         }
382         return stmt, nil
383 }
384
385 func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
386         return s.placeholderConverter[idx]
387 }
388
389 func (s *fakeStmt) Close() error {
390         if !s.closed {
391                 s.c.incrStat(&s.c.stmtsClosed)
392                 s.closed = true
393         }
394         return nil
395 }
396
397 var errClosed = errors.New("fakedb: statement has been closed")
398
399 func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
400         if s.closed {
401                 return nil, errClosed
402         }
403         err := checkSubsetTypes(args)
404         if err != nil {
405                 return nil, err
406         }
407
408         db := s.c.db
409         switch s.cmd {
410         case "WIPE":
411                 db.wipe()
412                 return driver.ResultNoRows, nil
413         case "CREATE":
414                 if err := db.createTable(s.table, s.colName, s.colType); err != nil {
415                         return nil, err
416                 }
417                 return driver.ResultNoRows, nil
418         case "INSERT":
419                 return s.execInsert(args)
420         }
421         fmt.Printf("EXEC statement, cmd=%q: %#v\n", s.cmd, s)
422         return nil, fmt.Errorf("unimplemented statement Exec command type of %q", s.cmd)
423 }
424
425 func (s *fakeStmt) execInsert(args []driver.Value) (driver.Result, error) {
426         db := s.c.db
427         if len(args) != s.placeholders {
428                 panic("error in pkg db; should only get here if size is correct")
429         }
430         db.mu.Lock()
431         t, ok := db.table(s.table)
432         db.mu.Unlock()
433         if !ok {
434                 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
435         }
436
437         t.mu.Lock()
438         defer t.mu.Unlock()
439
440         cols := make([]interface{}, len(t.colname))
441         argPos := 0
442         for n, colname := range s.colName {
443                 colidx := t.columnIndex(colname)
444                 if colidx == -1 {
445                         return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname)
446                 }
447                 var val interface{}
448                 if strvalue, ok := s.colValue[n].(string); ok && strvalue == "?" {
449                         val = args[argPos]
450                         argPos++
451                 } else {
452                         val = s.colValue[n]
453                 }
454                 cols[colidx] = val
455         }
456
457         t.rows = append(t.rows, &row{cols: cols})
458         return driver.RowsAffected(1), nil
459 }
460
461 func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
462         if s.closed {
463                 return nil, errClosed
464         }
465         err := checkSubsetTypes(args)
466         if err != nil {
467                 return nil, err
468         }
469
470         db := s.c.db
471         if len(args) != s.placeholders {
472                 panic("error in pkg db; should only get here if size is correct")
473         }
474
475         db.mu.Lock()
476         t, ok := db.table(s.table)
477         db.mu.Unlock()
478         if !ok {
479                 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
480         }
481         t.mu.Lock()
482         defer t.mu.Unlock()
483
484         colIdx := make(map[string]int) // select column name -> column index in table
485         for _, name := range s.colName {
486                 idx := t.columnIndex(name)
487                 if idx == -1 {
488                         return nil, fmt.Errorf("fakedb: unknown column name %q", name)
489                 }
490                 colIdx[name] = idx
491         }
492
493         mrows := []*row{}
494 rows:
495         for _, trow := range t.rows {
496                 // Process the where clause, skipping non-match rows. This is lazy
497                 // and just uses fmt.Sprintf("%v") to test equality.  Good enough
498                 // for test code.
499                 for widx, wcol := range s.whereCol {
500                         idx := t.columnIndex(wcol)
501                         if idx == -1 {
502                                 return nil, fmt.Errorf("db: invalid where clause column %q", wcol)
503                         }
504                         tcol := trow.cols[idx]
505                         if bs, ok := tcol.([]byte); ok {
506                                 // lazy hack to avoid sprintf %v on a []byte
507                                 tcol = string(bs)
508                         }
509                         if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", args[widx]) {
510                                 continue rows
511                         }
512                 }
513                 mrow := &row{cols: make([]interface{}, len(s.colName))}
514                 for seli, name := range s.colName {
515                         mrow.cols[seli] = trow.cols[colIdx[name]]
516                 }
517                 mrows = append(mrows, mrow)
518         }
519
520         cursor := &rowsCursor{
521                 pos:  -1,
522                 rows: mrows,
523                 cols: s.colName,
524         }
525         return cursor, nil
526 }
527
528 func (s *fakeStmt) NumInput() int {
529         return s.placeholders
530 }
531
532 func (tx *fakeTx) Commit() error {
533         tx.c.currTx = nil
534         return nil
535 }
536
537 func (tx *fakeTx) Rollback() error {
538         tx.c.currTx = nil
539         return nil
540 }
541
542 type rowsCursor struct {
543         cols   []string
544         pos    int
545         rows   []*row
546         closed bool
547
548         // a clone of slices to give out to clients, indexed by the
549         // the original slice's first byte address.  we clone them
550         // just so we're able to corrupt them on close.
551         bytesClone map[*byte][]byte
552 }
553
554 func (rc *rowsCursor) Close() error {
555         if !rc.closed {
556                 for _, bs := range rc.bytesClone {
557                         bs[0] = 255 // first byte corrupted
558                 }
559         }
560         rc.closed = true
561         return nil
562 }
563
564 func (rc *rowsCursor) Columns() []string {
565         return rc.cols
566 }
567
568 func (rc *rowsCursor) Next(dest []driver.Value) error {
569         if rc.closed {
570                 return errors.New("fakedb: cursor is closed")
571         }
572         rc.pos++
573         if rc.pos >= len(rc.rows) {
574                 return io.EOF // per interface spec
575         }
576         for i, v := range rc.rows[rc.pos].cols {
577                 // TODO(bradfitz): convert to subset types? naah, I
578                 // think the subset types should only be input to
579                 // driver, but the sql package should be able to handle
580                 // a wider range of types coming out of drivers. all
581                 // for ease of drivers, and to prevent drivers from
582                 // messing up conversions or doing them differently.
583                 dest[i] = v
584
585                 if bs, ok := v.([]byte); ok {
586                         if rc.bytesClone == nil {
587                                 rc.bytesClone = make(map[*byte][]byte)
588                         }
589                         clone, ok := rc.bytesClone[&bs[0]]
590                         if !ok {
591                                 clone = make([]byte, len(bs))
592                                 copy(clone, bs)
593                                 rc.bytesClone[&bs[0]] = clone
594                         }
595                         dest[i] = clone
596                 }
597         }
598         return nil
599 }
600
601 func converterForType(typ string) driver.ValueConverter {
602         switch typ {
603         case "bool":
604                 return driver.Bool
605         case "nullbool":
606                 return driver.Null{Converter: driver.Bool}
607         case "int32":
608                 return driver.Int32
609         case "string":
610                 return driver.NotNull{Converter: driver.String}
611         case "nullstring":
612                 return driver.Null{Converter: driver.String}
613         case "int64":
614                 // TODO(coopernurse): add type-specific converter
615                 return driver.NotNull{Converter: driver.DefaultParameterConverter}
616         case "nullint64":
617                 // TODO(coopernurse): add type-specific converter
618                 return driver.Null{Converter: driver.DefaultParameterConverter}
619         case "float64":
620                 // TODO(coopernurse): add type-specific converter
621                 return driver.NotNull{Converter: driver.DefaultParameterConverter}
622         case "nullfloat64":
623                 // TODO(coopernurse): add type-specific converter
624                 return driver.Null{Converter: driver.DefaultParameterConverter}
625         case "datetime":
626                 return driver.DefaultParameterConverter
627         }
628         panic("invalid fakedb column type of " + typ)
629 }