1 // Copyright 2011 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
21 // fakeDriver is a fake database that implements Go's driver.Driver
22 // interface, just for testing.
24 // It speaks a query language that's semantically similar to but
25 // syntactically different and simpler than SQL. The syntax is as
29 // CREATE|<tablename>|<col>=<type>,<col>=<type>,...
30 // where types are: "string", [u]int{8,16,32,64}, "bool"
31 // INSERT|<tablename>|col=val,col2=val2,col3=?
32 // SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=?
34 // When opening a fakeDriver's database, it starts empty with no
35 // tables. All tables and data are stored in memory only.
36 type fakeDriver struct {
39 dbs map[string]*fakeDB
47 tables map[string]*table
57 func (t *table) columnIndex(name string) int {
58 for n, nname := range t.colname {
67 cols []interface{} // must be same size as its table colname + coltype
70 func (r *row) clone() *row {
71 nrow := &row{cols: make([]interface{}, len(r.cols))}
72 copy(nrow.cols, r.cols)
76 type fakeConn struct {
77 db *fakeDB // where to return ourselves to
88 func (c *fakeConn) incrStat(v *int) {
98 type fakeStmt struct {
100 q string // just for debugging
107 colName []string // used by CREATE, INSERT, SELECT (selected columns)
108 colType []string // used by CREATE
109 colValue []interface{} // used by INSERT (mix of strings and "?" for bound params)
110 placeholders int // used by INSERT/SELECT: number of ? params
112 whereCol []string // used by SELECT (all placeholders)
114 placeholderConverter []driver.ValueConverter // used by INSERT
117 var fdriver driver.Driver = &fakeDriver{}
120 Register("test", fdriver)
123 // Supports dsn forms:
125 // <dbname>;<opts> (no currently supported options)
126 func (d *fakeDriver) Open(dsn string) (driver.Conn, error) {
127 parts := strings.Split(dsn, ";")
129 return nil, errors.New("fakedb: no database name")
138 return &fakeConn{db: db}, nil
141 func (d *fakeDriver) getDB(name string) *fakeDB {
145 d.dbs = make(map[string]*fakeDB)
147 db, ok := d.dbs[name]
149 db = &fakeDB{name: name}
155 func (db *fakeDB) wipe() {
161 func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error {
164 if db.tables == nil {
165 db.tables = make(map[string]*table)
167 if _, exist := db.tables[name]; exist {
168 return fmt.Errorf("table %q already exists", name)
170 if len(columnNames) != len(columnTypes) {
171 return fmt.Errorf("create table of %q len(names) != len(types): %d vs %d",
172 name, len(columnNames), len(columnTypes))
174 db.tables[name] = &table{colname: columnNames, coltype: columnTypes}
178 // must be called with db.mu lock held
179 func (db *fakeDB) table(table string) (*table, bool) {
180 if db.tables == nil {
183 t, ok := db.tables[table]
187 func (db *fakeDB) columnType(table, column string) (typ string, ok bool) {
190 t, ok := db.table(table)
194 for n, cname := range t.colname {
196 return t.coltype[n], true
202 func (c *fakeConn) Begin() (driver.Tx, error) {
204 return nil, errors.New("already in a transaction")
206 c.currTx = &fakeTx{c: c}
210 func (c *fakeConn) Close() error {
212 return errors.New("can't close fakeConn; in a Transaction")
215 return errors.New("can't close fakeConn; already closed")
217 if c.stmtsMade > c.stmtsClosed {
218 return errors.New("can't close; dangling statement(s)")
224 func checkSubsetTypes(args []driver.Value) error {
225 for n, arg := range args {
227 case int64, float64, bool, nil, []byte, string, time.Time:
229 return fmt.Errorf("fakedb_test: invalid argument #%d: %v, type %T", n+1, arg, arg)
235 func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) {
236 // This is an optional interface, but it's implemented here
237 // just to check that all the args are of the proper types.
238 // ErrSkip is returned so the caller acts as if we didn't
239 // implement this at all.
240 err := checkSubsetTypes(args)
244 return nil, driver.ErrSkip
247 func errf(msg string, args ...interface{}) error {
248 return errors.New("fakedb: " + fmt.Sprintf(msg, args...))
251 // parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=?
252 // (note that where columns must always contain ? marks,
253 // just a limitation for fakedb)
254 func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
257 return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
259 stmt.table = parts[0]
260 stmt.colName = strings.Split(parts[1], ",")
261 for n, colspec := range strings.Split(parts[2], ",") {
265 nameVal := strings.Split(colspec, "=")
266 if len(nameVal) != 2 {
268 return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
270 column, value := nameVal[0], nameVal[1]
271 _, ok := c.db.columnType(stmt.table, column)
274 return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column)
278 return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark",
281 stmt.whereCol = append(stmt.whereCol, column)
287 // parts are table|col=type,col2=type2
288 func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
291 return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
293 stmt.table = parts[0]
294 for n, colspec := range strings.Split(parts[1], ",") {
295 nameType := strings.Split(colspec, "=")
296 if len(nameType) != 2 {
298 return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
300 stmt.colName = append(stmt.colName, nameType[0])
301 stmt.colType = append(stmt.colType, nameType[1])
306 // parts are table|col=?,col2=val
307 func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
310 return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
312 stmt.table = parts[0]
313 for n, colspec := range strings.Split(parts[1], ",") {
314 nameVal := strings.Split(colspec, "=")
315 if len(nameVal) != 2 {
317 return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
319 column, value := nameVal[0], nameVal[1]
320 ctype, ok := c.db.columnType(stmt.table, column)
323 return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column)
325 stmt.colName = append(stmt.colName, column)
328 var subsetVal interface{}
329 // Convert to driver subset type
332 subsetVal = []byte(value)
334 subsetVal = []byte(value)
336 i, err := strconv.Atoi(value)
339 return nil, errf("invalid conversion to int32 from %q", value)
341 subsetVal = int64(i) // int64 is a subset type, but not int32
344 return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype)
346 stmt.colValue = append(stmt.colValue, subsetVal)
349 stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype))
350 stmt.colValue = append(stmt.colValue, "?")
356 func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
359 panic("nil c.db; conn = " + fmt.Sprintf("%#v", c))
361 parts := strings.Split(query, "|")
363 return nil, errf("empty query")
367 stmt := &fakeStmt{q: query, c: c, cmd: cmd}
368 c.incrStat(&c.stmtsMade)
373 return c.prepareSelect(stmt, parts)
375 return c.prepareCreate(stmt, parts)
377 return c.prepareInsert(stmt, parts)
380 return nil, errf("unsupported command type %q", cmd)
385 func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
386 return s.placeholderConverter[idx]
389 func (s *fakeStmt) Close() error {
391 s.c.incrStat(&s.c.stmtsClosed)
397 var errClosed = errors.New("fakedb: statement has been closed")
399 func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
401 return nil, errClosed
403 err := checkSubsetTypes(args)
412 return driver.ResultNoRows, nil
414 if err := db.createTable(s.table, s.colName, s.colType); err != nil {
417 return driver.ResultNoRows, nil
419 return s.execInsert(args)
421 fmt.Printf("EXEC statement, cmd=%q: %#v\n", s.cmd, s)
422 return nil, fmt.Errorf("unimplemented statement Exec command type of %q", s.cmd)
425 func (s *fakeStmt) execInsert(args []driver.Value) (driver.Result, error) {
427 if len(args) != s.placeholders {
428 panic("error in pkg db; should only get here if size is correct")
431 t, ok := db.table(s.table)
434 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
440 cols := make([]interface{}, len(t.colname))
442 for n, colname := range s.colName {
443 colidx := t.columnIndex(colname)
445 return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname)
448 if strvalue, ok := s.colValue[n].(string); ok && strvalue == "?" {
457 t.rows = append(t.rows, &row{cols: cols})
458 return driver.RowsAffected(1), nil
461 func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
463 return nil, errClosed
465 err := checkSubsetTypes(args)
471 if len(args) != s.placeholders {
472 panic("error in pkg db; should only get here if size is correct")
476 t, ok := db.table(s.table)
479 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
484 colIdx := make(map[string]int) // select column name -> column index in table
485 for _, name := range s.colName {
486 idx := t.columnIndex(name)
488 return nil, fmt.Errorf("fakedb: unknown column name %q", name)
495 for _, trow := range t.rows {
496 // Process the where clause, skipping non-match rows. This is lazy
497 // and just uses fmt.Sprintf("%v") to test equality. Good enough
499 for widx, wcol := range s.whereCol {
500 idx := t.columnIndex(wcol)
502 return nil, fmt.Errorf("db: invalid where clause column %q", wcol)
504 tcol := trow.cols[idx]
505 if bs, ok := tcol.([]byte); ok {
506 // lazy hack to avoid sprintf %v on a []byte
509 if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", args[widx]) {
513 mrow := &row{cols: make([]interface{}, len(s.colName))}
514 for seli, name := range s.colName {
515 mrow.cols[seli] = trow.cols[colIdx[name]]
517 mrows = append(mrows, mrow)
520 cursor := &rowsCursor{
528 func (s *fakeStmt) NumInput() int {
529 return s.placeholders
532 func (tx *fakeTx) Commit() error {
537 func (tx *fakeTx) Rollback() error {
542 type rowsCursor struct {
548 // a clone of slices to give out to clients, indexed by the
549 // original slice's first byte address. we clone them
550 // just so we're able to corrupt them on close.
551 bytesClone map[*byte][]byte
554 func (rc *rowsCursor) Close() error {
556 for _, bs := range rc.bytesClone {
557 bs[0] = 255 // first byte corrupted
564 func (rc *rowsCursor) Columns() []string {
568 func (rc *rowsCursor) Next(dest []driver.Value) error {
570 return errors.New("fakedb: cursor is closed")
573 if rc.pos >= len(rc.rows) {
574 return io.EOF // per interface spec
576 for i, v := range rc.rows[rc.pos].cols {
577 // TODO(bradfitz): convert to subset types? naah, I
578 // think the subset types should only be input to
579 // driver, but the sql package should be able to handle
580 // a wider range of types coming out of drivers. all
581 // for ease of drivers, and to prevent drivers from
582 // messing up conversions or doing them differently.
585 if bs, ok := v.([]byte); ok {
586 if rc.bytesClone == nil {
587 rc.bytesClone = make(map[*byte][]byte)
589 clone, ok := rc.bytesClone[&bs[0]]
591 clone = make([]byte, len(bs))
593 rc.bytesClone[&bs[0]] = clone
601 func converterForType(typ string) driver.ValueConverter {
606 return driver.Null{Converter: driver.Bool}
610 return driver.NotNull{Converter: driver.String}
612 return driver.Null{Converter: driver.String}
614 // TODO(coopernurse): add type-specific converter
615 return driver.NotNull{Converter: driver.DefaultParameterConverter}
617 // TODO(coopernurse): add type-specific converter
618 return driver.Null{Converter: driver.DefaultParameterConverter}
620 // TODO(coopernurse): add type-specific converter
621 return driver.NotNull{Converter: driver.DefaultParameterConverter}
623 // TODO(coopernurse): add type-specific converter
624 return driver.Null{Converter: driver.DefaultParameterConverter}
626 return driver.DefaultParameterConverter
628 panic("invalid fakedb column type of " + typ)