Rebase.
[official-gcc.git] / libgo / go / database / sql / fakedb_test.go
blobc7db0dd77b370a1efa5f27b31707117a3168fd78
1 // Copyright 2011 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
5 package sql
7 import (
8 "database/sql/driver"
9 "errors"
10 "fmt"
11 "io"
12 "log"
13 "strconv"
14 "strings"
15 "sync"
16 "testing"
17 "time"
20 var _ = log.Printf
22 // fakeDriver is a fake database that implements Go's driver.Driver
23 // interface, just for testing.
25 // It speaks a query language that's semantically similar to but
26 // syntactically different and simpler than SQL. The syntax is as
27 // follows:
29 // WIPE
30 // CREATE|<tablename>|<col>=<type>,<col>=<type>,...
31 // where types are: "string", [u]int{8,16,32,64}, "bool"
32 // INSERT|<tablename>|col=val,col2=val2,col3=?
33 // SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=?
35 // When opening a fakeDriver's database, it starts empty with no
36 // tables. All tables and data are stored in memory only.
37 type fakeDriver struct {
38 mu sync.Mutex // guards 3 following fields
39 openCount int // conn opens
40 closeCount int // conn closes
41 waitCh chan struct{}
42 waitingCh chan struct{}
43 dbs map[string]*fakeDB
46 type fakeDB struct {
47 name string
49 mu sync.Mutex
50 free []*fakeConn
51 tables map[string]*table
52 badConn bool
// table is an in-memory table: parallel column name/type slices plus
// its rows. Inserts and queries hold mu while touching rows.
type table struct {
	mu      sync.Mutex
	colname []string
	coltype []string
	rows    []*row
}

// columnIndex returns the position of the named column in t, or -1 if
// t has no column with that name.
func (t *table) columnIndex(name string) int {
	for i := range t.colname {
		if t.colname[i] == name {
			return i
		}
	}
	return -1
}

// row is a single table row.
type row struct {
	cols []interface{} // must be same size as its table colname + coltype
}

// clone returns a shallow copy of r whose cols slice does not share
// backing storage with r's.
func (r *row) clone() *row {
	cols := append(make([]interface{}, 0, len(r.cols)), r.cols...)
	return &row{cols: cols}
}
81 type fakeConn struct {
82 db *fakeDB // where to return ourselves to
84 currTx *fakeTx
86 // Stats for tests:
87 mu sync.Mutex
88 stmtsMade int
89 stmtsClosed int
90 numPrepare int
91 bad bool
94 func (c *fakeConn) incrStat(v *int) {
95 c.mu.Lock()
96 *v++
97 c.mu.Unlock()
100 type fakeTx struct {
101 c *fakeConn
104 type fakeStmt struct {
105 c *fakeConn
106 q string // just for debugging
108 cmd string
109 table string
111 closed bool
113 colName []string // used by CREATE, INSERT, SELECT (selected columns)
114 colType []string // used by CREATE
115 colValue []interface{} // used by INSERT (mix of strings and "?" for bound params)
116 placeholders int // used by INSERT/SELECT: number of ? params
118 whereCol []string // used by SELECT (all placeholders)
120 placeholderConverter []driver.ValueConverter // used by INSERT
123 var fdriver driver.Driver = &fakeDriver{}
125 func init() {
126 Register("test", fdriver)
129 // Supports dsn forms:
130 // <dbname>
131 // <dbname>;<opts> (only currently supported option is `badConn`,
132 // which causes driver.ErrBadConn to be returned on
133 // every other conn.Begin())
134 func (d *fakeDriver) Open(dsn string) (driver.Conn, error) {
135 parts := strings.Split(dsn, ";")
136 if len(parts) < 1 {
137 return nil, errors.New("fakedb: no database name")
139 name := parts[0]
141 db := d.getDB(name)
143 d.mu.Lock()
144 d.openCount++
145 d.mu.Unlock()
146 conn := &fakeConn{db: db}
148 if len(parts) >= 2 && parts[1] == "badConn" {
149 conn.bad = true
151 if d.waitCh != nil {
152 d.waitingCh <- struct{}{}
153 <-d.waitCh
154 d.waitCh = nil
155 d.waitingCh = nil
157 return conn, nil
160 func (d *fakeDriver) getDB(name string) *fakeDB {
161 d.mu.Lock()
162 defer d.mu.Unlock()
163 if d.dbs == nil {
164 d.dbs = make(map[string]*fakeDB)
166 db, ok := d.dbs[name]
167 if !ok {
168 db = &fakeDB{name: name}
169 d.dbs[name] = db
171 return db
174 func (db *fakeDB) wipe() {
175 db.mu.Lock()
176 defer db.mu.Unlock()
177 db.tables = nil
180 func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error {
181 db.mu.Lock()
182 defer db.mu.Unlock()
183 if db.tables == nil {
184 db.tables = make(map[string]*table)
186 if _, exist := db.tables[name]; exist {
187 return fmt.Errorf("table %q already exists", name)
189 if len(columnNames) != len(columnTypes) {
190 return fmt.Errorf("create table of %q len(names) != len(types): %d vs %d",
191 name, len(columnNames), len(columnTypes))
193 db.tables[name] = &table{colname: columnNames, coltype: columnTypes}
194 return nil
197 // must be called with db.mu lock held
198 func (db *fakeDB) table(table string) (*table, bool) {
199 if db.tables == nil {
200 return nil, false
202 t, ok := db.tables[table]
203 return t, ok
206 func (db *fakeDB) columnType(table, column string) (typ string, ok bool) {
207 db.mu.Lock()
208 defer db.mu.Unlock()
209 t, ok := db.table(table)
210 if !ok {
211 return
213 for n, cname := range t.colname {
214 if cname == column {
215 return t.coltype[n], true
218 return "", false
221 func (c *fakeConn) isBad() bool {
222 // if not simulating bad conn, do nothing
223 if !c.bad {
224 return false
226 // alternate between bad conn and not bad conn
227 c.db.badConn = !c.db.badConn
228 return c.db.badConn
231 func (c *fakeConn) Begin() (driver.Tx, error) {
232 if c.isBad() {
233 return nil, driver.ErrBadConn
235 if c.currTx != nil {
236 return nil, errors.New("already in a transaction")
238 c.currTx = &fakeTx{c: c}
239 return c.currTx, nil
242 var hookPostCloseConn struct {
243 sync.Mutex
244 fn func(*fakeConn, error)
247 func setHookpostCloseConn(fn func(*fakeConn, error)) {
248 hookPostCloseConn.Lock()
249 defer hookPostCloseConn.Unlock()
250 hookPostCloseConn.fn = fn
253 var testStrictClose *testing.T
255 // setStrictFakeConnClose sets the t to Errorf on when fakeConn.Close
256 // fails to close. If nil, the check is disabled.
257 func setStrictFakeConnClose(t *testing.T) {
258 testStrictClose = t
261 func (c *fakeConn) Close() (err error) {
262 drv := fdriver.(*fakeDriver)
263 defer func() {
264 if err != nil && testStrictClose != nil {
265 testStrictClose.Errorf("failed to close a test fakeConn: %v", err)
267 hookPostCloseConn.Lock()
268 fn := hookPostCloseConn.fn
269 hookPostCloseConn.Unlock()
270 if fn != nil {
271 fn(c, err)
273 if err == nil {
274 drv.mu.Lock()
275 drv.closeCount++
276 drv.mu.Unlock()
279 if c.currTx != nil {
280 return errors.New("can't close fakeConn; in a Transaction")
282 if c.db == nil {
283 return errors.New("can't close fakeConn; already closed")
285 if c.stmtsMade > c.stmtsClosed {
286 return errors.New("can't close; dangling statement(s)")
288 c.db = nil
289 return nil
// checkSubsetTypes returns an error naming the first argument whose
// dynamic type is not a driver subset type (int64, float64, bool, nil,
// []byte, string or time.Time). Argument numbers in the message are
// 1-based.
func checkSubsetTypes(args []driver.Value) error {
	for i := range args {
		switch a := args[i]; a.(type) {
		case int64, float64, bool, nil, []byte, string, time.Time:
			continue
		default:
			return fmt.Errorf("fakedb_test: invalid argument #%d: %v, type %T", i+1, a, a)
		}
	}
	return nil
}
303 func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) {
304 // This is an optional interface, but it's implemented here
305 // just to check that all the args are of the proper types.
306 // ErrSkip is returned so the caller acts as if we didn't
307 // implement this at all.
308 err := checkSubsetTypes(args)
309 if err != nil {
310 return nil, err
312 return nil, driver.ErrSkip
315 func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) {
316 // This is an optional interface, but it's implemented here
317 // just to check that all the args are of the proper types.
318 // ErrSkip is returned so the caller acts as if we didn't
319 // implement this at all.
320 err := checkSubsetTypes(args)
321 if err != nil {
322 return nil, err
324 return nil, driver.ErrSkip
// errf formats msg with args in the fmt.Sprintf style and returns an
// error whose text is that result prefixed with "fakedb: ".
func errf(msg string, args ...interface{}) error {
	text := fmt.Sprintf(msg, args...)
	return errors.New("fakedb: " + text)
}
331 // parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=?
332 // (note that where columns must always contain ? marks,
333 // just a limitation for fakedb)
334 func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
335 if len(parts) != 3 {
336 stmt.Close()
337 return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
339 stmt.table = parts[0]
340 stmt.colName = strings.Split(parts[1], ",")
341 for n, colspec := range strings.Split(parts[2], ",") {
342 if colspec == "" {
343 continue
345 nameVal := strings.Split(colspec, "=")
346 if len(nameVal) != 2 {
347 stmt.Close()
348 return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
350 column, value := nameVal[0], nameVal[1]
351 _, ok := c.db.columnType(stmt.table, column)
352 if !ok {
353 stmt.Close()
354 return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column)
356 if value != "?" {
357 stmt.Close()
358 return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark",
359 stmt.table, column)
361 stmt.whereCol = append(stmt.whereCol, column)
362 stmt.placeholders++
364 return stmt, nil
367 // parts are table|col=type,col2=type2
368 func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
369 if len(parts) != 2 {
370 stmt.Close()
371 return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
373 stmt.table = parts[0]
374 for n, colspec := range strings.Split(parts[1], ",") {
375 nameType := strings.Split(colspec, "=")
376 if len(nameType) != 2 {
377 stmt.Close()
378 return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
380 stmt.colName = append(stmt.colName, nameType[0])
381 stmt.colType = append(stmt.colType, nameType[1])
383 return stmt, nil
386 // parts are table|col=?,col2=val
387 func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
388 if len(parts) != 2 {
389 stmt.Close()
390 return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
392 stmt.table = parts[0]
393 for n, colspec := range strings.Split(parts[1], ",") {
394 nameVal := strings.Split(colspec, "=")
395 if len(nameVal) != 2 {
396 stmt.Close()
397 return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
399 column, value := nameVal[0], nameVal[1]
400 ctype, ok := c.db.columnType(stmt.table, column)
401 if !ok {
402 stmt.Close()
403 return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column)
405 stmt.colName = append(stmt.colName, column)
407 if value != "?" {
408 var subsetVal interface{}
409 // Convert to driver subset type
410 switch ctype {
411 case "string":
412 subsetVal = []byte(value)
413 case "blob":
414 subsetVal = []byte(value)
415 case "int32":
416 i, err := strconv.Atoi(value)
417 if err != nil {
418 stmt.Close()
419 return nil, errf("invalid conversion to int32 from %q", value)
421 subsetVal = int64(i) // int64 is a subset type, but not int32
422 default:
423 stmt.Close()
424 return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype)
426 stmt.colValue = append(stmt.colValue, subsetVal)
427 } else {
428 stmt.placeholders++
429 stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype))
430 stmt.colValue = append(stmt.colValue, "?")
433 return stmt, nil
436 // hook to simulate broken connections
437 var hookPrepareBadConn func() bool
439 func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
440 c.numPrepare++
441 if c.db == nil {
442 panic("nil c.db; conn = " + fmt.Sprintf("%#v", c))
445 if hookPrepareBadConn != nil && hookPrepareBadConn() {
446 return nil, driver.ErrBadConn
449 parts := strings.Split(query, "|")
450 if len(parts) < 1 {
451 return nil, errf("empty query")
453 cmd := parts[0]
454 parts = parts[1:]
455 stmt := &fakeStmt{q: query, c: c, cmd: cmd}
456 c.incrStat(&c.stmtsMade)
457 switch cmd {
458 case "WIPE":
459 // Nothing
460 case "SELECT":
461 return c.prepareSelect(stmt, parts)
462 case "CREATE":
463 return c.prepareCreate(stmt, parts)
464 case "INSERT":
465 return c.prepareInsert(stmt, parts)
466 case "NOSERT":
467 // Do all the prep-work like for an INSERT but don't actually insert the row.
468 // Used for some of the concurrent tests.
469 return c.prepareInsert(stmt, parts)
470 default:
471 stmt.Close()
472 return nil, errf("unsupported command type %q", cmd)
474 return stmt, nil
477 func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
478 if len(s.placeholderConverter) == 0 {
479 return driver.DefaultParameterConverter
481 return s.placeholderConverter[idx]
484 func (s *fakeStmt) Close() error {
485 if s.c == nil {
486 panic("nil conn in fakeStmt.Close")
488 if s.c.db == nil {
489 panic("in fakeStmt.Close, conn's db is nil (already closed)")
491 if !s.closed {
492 s.c.incrStat(&s.c.stmtsClosed)
493 s.closed = true
495 return nil
498 var errClosed = errors.New("fakedb: statement has been closed")
500 // hook to simulate broken connections
501 var hookExecBadConn func() bool
503 func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
504 if s.closed {
505 return nil, errClosed
508 if hookExecBadConn != nil && hookExecBadConn() {
509 return nil, driver.ErrBadConn
512 err := checkSubsetTypes(args)
513 if err != nil {
514 return nil, err
517 db := s.c.db
518 switch s.cmd {
519 case "WIPE":
520 db.wipe()
521 return driver.ResultNoRows, nil
522 case "CREATE":
523 if err := db.createTable(s.table, s.colName, s.colType); err != nil {
524 return nil, err
526 return driver.ResultNoRows, nil
527 case "INSERT":
528 return s.execInsert(args, true)
529 case "NOSERT":
530 // Do all the prep-work like for an INSERT but don't actually insert the row.
531 // Used for some of the concurrent tests.
532 return s.execInsert(args, false)
534 fmt.Printf("EXEC statement, cmd=%q: %#v\n", s.cmd, s)
535 return nil, fmt.Errorf("unimplemented statement Exec command type of %q", s.cmd)
538 // When doInsert is true, add the row to the table.
539 // When doInsert is false do prep-work and error checking, but don't
540 // actually add the row to the table.
541 func (s *fakeStmt) execInsert(args []driver.Value, doInsert bool) (driver.Result, error) {
542 db := s.c.db
543 if len(args) != s.placeholders {
544 panic("error in pkg db; should only get here if size is correct")
546 db.mu.Lock()
547 t, ok := db.table(s.table)
548 db.mu.Unlock()
549 if !ok {
550 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
553 t.mu.Lock()
554 defer t.mu.Unlock()
556 var cols []interface{}
557 if doInsert {
558 cols = make([]interface{}, len(t.colname))
560 argPos := 0
561 for n, colname := range s.colName {
562 colidx := t.columnIndex(colname)
563 if colidx == -1 {
564 return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname)
566 var val interface{}
567 if strvalue, ok := s.colValue[n].(string); ok && strvalue == "?" {
568 val = args[argPos]
569 argPos++
570 } else {
571 val = s.colValue[n]
573 if doInsert {
574 cols[colidx] = val
578 if doInsert {
579 t.rows = append(t.rows, &row{cols: cols})
581 return driver.RowsAffected(1), nil
584 // hook to simulate broken connections
585 var hookQueryBadConn func() bool
587 func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
588 if s.closed {
589 return nil, errClosed
592 if hookQueryBadConn != nil && hookQueryBadConn() {
593 return nil, driver.ErrBadConn
596 err := checkSubsetTypes(args)
597 if err != nil {
598 return nil, err
601 db := s.c.db
602 if len(args) != s.placeholders {
603 panic("error in pkg db; should only get here if size is correct")
606 db.mu.Lock()
607 t, ok := db.table(s.table)
608 db.mu.Unlock()
609 if !ok {
610 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
613 if s.table == "magicquery" {
614 if len(s.whereCol) == 2 && s.whereCol[0] == "op" && s.whereCol[1] == "millis" {
615 if args[0] == "sleep" {
616 time.Sleep(time.Duration(args[1].(int64)) * time.Millisecond)
621 t.mu.Lock()
622 defer t.mu.Unlock()
624 colIdx := make(map[string]int) // select column name -> column index in table
625 for _, name := range s.colName {
626 idx := t.columnIndex(name)
627 if idx == -1 {
628 return nil, fmt.Errorf("fakedb: unknown column name %q", name)
630 colIdx[name] = idx
633 mrows := []*row{}
634 rows:
635 for _, trow := range t.rows {
636 // Process the where clause, skipping non-match rows. This is lazy
637 // and just uses fmt.Sprintf("%v") to test equality. Good enough
638 // for test code.
639 for widx, wcol := range s.whereCol {
640 idx := t.columnIndex(wcol)
641 if idx == -1 {
642 return nil, fmt.Errorf("db: invalid where clause column %q", wcol)
644 tcol := trow.cols[idx]
645 if bs, ok := tcol.([]byte); ok {
646 // lazy hack to avoid sprintf %v on a []byte
647 tcol = string(bs)
649 if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", args[widx]) {
650 continue rows
653 mrow := &row{cols: make([]interface{}, len(s.colName))}
654 for seli, name := range s.colName {
655 mrow.cols[seli] = trow.cols[colIdx[name]]
657 mrows = append(mrows, mrow)
660 cursor := &rowsCursor{
661 pos: -1,
662 rows: mrows,
663 cols: s.colName,
664 errPos: -1,
666 return cursor, nil
669 func (s *fakeStmt) NumInput() int {
670 return s.placeholders
673 func (tx *fakeTx) Commit() error {
674 tx.c.currTx = nil
675 return nil
678 func (tx *fakeTx) Rollback() error {
679 tx.c.currTx = nil
680 return nil
683 type rowsCursor struct {
684 cols []string
685 pos int
686 rows []*row
687 closed bool
689 // errPos and err are for making Next return early with error.
690 errPos int
691 err error
693 // a clone of slices to give out to clients, indexed by the
694 // the original slice's first byte address. we clone them
695 // just so we're able to corrupt them on close.
696 bytesClone map[*byte][]byte
699 func (rc *rowsCursor) Close() error {
700 if !rc.closed {
701 for _, bs := range rc.bytesClone {
702 bs[0] = 255 // first byte corrupted
705 rc.closed = true
706 return nil
709 func (rc *rowsCursor) Columns() []string {
710 return rc.cols
713 var rowsCursorNextHook func(dest []driver.Value) error
715 func (rc *rowsCursor) Next(dest []driver.Value) error {
716 if rowsCursorNextHook != nil {
717 return rowsCursorNextHook(dest)
720 if rc.closed {
721 return errors.New("fakedb: cursor is closed")
723 rc.pos++
724 if rc.pos == rc.errPos {
725 return rc.err
727 if rc.pos >= len(rc.rows) {
728 return io.EOF // per interface spec
730 for i, v := range rc.rows[rc.pos].cols {
731 // TODO(bradfitz): convert to subset types? naah, I
732 // think the subset types should only be input to
733 // driver, but the sql package should be able to handle
734 // a wider range of types coming out of drivers. all
735 // for ease of drivers, and to prevent drivers from
736 // messing up conversions or doing them differently.
737 dest[i] = v
739 if bs, ok := v.([]byte); ok {
740 if rc.bytesClone == nil {
741 rc.bytesClone = make(map[*byte][]byte)
743 clone, ok := rc.bytesClone[&bs[0]]
744 if !ok {
745 clone = make([]byte, len(bs))
746 copy(clone, bs)
747 rc.bytesClone[&bs[0]] = clone
749 dest[i] = clone
752 return nil
// fakeDriverString is like driver.String, but indirects pointers like
// DefaultValueConverter.
//
// This could be surprising behavior to retroactively apply to
// driver.String now that Go1 is out, but this is convenient for
// our TestPointerParamsAndScans.
type fakeDriverString struct{}

// ConvertValue passes string and []byte through unchanged. It
// dereferences a *string, mapping a nil pointer to nil. It renders any
// other value with fmt's "%v".
func (fakeDriverString) ConvertValue(v interface{}) (driver.Value, error) {
	if sp, isPtr := v.(*string); isPtr {
		if sp == nil {
			return nil, nil
		}
		return *sp, nil
	}
	switch v.(type) {
	case string, []byte:
		return v, nil
	}
	return fmt.Sprintf("%v", v), nil
}

// converterForType returns the placeholder converter for a fakedb
// column type. It panics on an unknown type.
func converterForType(typ string) driver.ValueConverter {
	switch typ {
	case "bool":
		return driver.Bool
	case "nullbool":
		return driver.Null{Converter: driver.Bool}
	case "int32":
		return driver.Int32
	case "string":
		return driver.NotNull{Converter: fakeDriverString{}}
	case "nullstring":
		return driver.Null{Converter: fakeDriverString{}}
	case "int64", "float64":
		// TODO(coopernurse): add type-specific converter
		return driver.NotNull{Converter: driver.DefaultParameterConverter}
	case "nullint64", "nullfloat64":
		// TODO(coopernurse): add type-specific converter
		return driver.Null{Converter: driver.DefaultParameterConverter}
	case "datetime":
		return driver.DefaultParameterConverter
	}
	panic("invalid fakedb column type of " + typ)
}