libgo: update to Go1.10rc2
[official-gcc.git] / libgo / go / database / sql / fakedb_test.go
blobabb8d40fc022ed5c492e2709bdd1ec82b8fcccac
1 // Copyright 2011 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
5 package sql
7 import (
8 "context"
9 "database/sql/driver"
10 "errors"
11 "fmt"
12 "io"
13 "log"
14 "reflect"
15 "sort"
16 "strconv"
17 "strings"
18 "sync"
19 "testing"
20 "time"
var (
	// Referencing log.Printf keeps the "log" import in use, presumably so
	// ad-hoc debug logging can be added without editing the import list.
	_ = log.Printf
)
25 // fakeDriver is a fake database that implements Go's driver.Driver
26 // interface, just for testing.
28 // It speaks a query language that's semantically similar to but
29 // syntactically different and simpler than SQL. The syntax is as
30 // follows:
32 // WIPE
33 // CREATE|<tablename>|<col>=<type>,<col>=<type>,...
34 // where types are: "string", [u]int{8,16,32,64}, "bool"
35 // INSERT|<tablename>|col=val,col2=val2,col3=?
36 // SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=?
37 // SELECT|<tablename>|projectcol1,projectcol2|filtercol=?param1,filtercol2=?param2
39 // Any of these can be preceded by PANIC|<method>|, to cause the
40 // named method on fakeStmt to panic.
42 // Any of these can be proceeded by WAIT|<duration>|, to cause the
43 // named method on fakeStmt to sleep for the specified duration.
45 // Multiple of these can be combined when separated with a semicolon.
47 // When opening a fakeDriver's database, it starts empty with no
48 // tables. All tables and data are stored in memory only.
49 type fakeDriver struct {
50 mu sync.Mutex // guards 3 following fields
51 openCount int // conn opens
52 closeCount int // conn closes
53 waitCh chan struct{}
54 waitingCh chan struct{}
55 dbs map[string]*fakeDB
// fakeConnector is the driver.Connector handed out by
// fakeDriverCtx.OpenConnector.
type fakeConnector struct {
	name   string                // DSN passed to fdriver.Open
	waiter func(context.Context) // installed on every conn Connect opens
}
64 func (c *fakeConnector) Connect(context.Context) (driver.Conn, error) {
65 conn, err := fdriver.Open(c.name)
66 conn.(*fakeConn).waiter = c.waiter
67 return conn, err
70 func (c *fakeConnector) Driver() driver.Driver {
71 return fdriver
74 type fakeDriverCtx struct {
75 fakeDriver
78 var _ driver.DriverContext = &fakeDriverCtx{}
80 func (cc *fakeDriverCtx) OpenConnector(name string) (driver.Connector, error) {
81 return &fakeConnector{name: name}, nil
84 type fakeDB struct {
85 name string
87 mu sync.Mutex
88 tables map[string]*table
89 badConn bool
90 allowAny bool
93 type table struct {
94 mu sync.Mutex
95 colname []string
96 coltype []string
97 rows []*row
100 func (t *table) columnIndex(name string) int {
101 for n, nname := range t.colname {
102 if name == nname {
103 return n
106 return -1
// row holds the values of one stored row, or of one projected result row.
type row struct {
	// cols has one entry per table column for stored rows, and one entry
	// per selected column for rows produced by QueryContext.
	cols []interface{}
}
// memToucher is implemented by objects that deliberately write a field on
// every operation, so the race detector reports unsynchronized use.
type memToucher interface {
	// touchMem reads & writes some memory, to help find data races.
	touchMem()
}
118 type fakeConn struct {
119 db *fakeDB // where to return ourselves to
121 currTx *fakeTx
123 // Every operation writes to line to enable the race detector
124 // check for data races.
125 line int64
127 // Stats for tests:
128 mu sync.Mutex
129 stmtsMade int
130 stmtsClosed int
131 numPrepare int
133 // bad connection tests; see isBad()
134 bad bool
135 stickyBad bool
137 skipDirtySession bool // tests that use Conn should set this to true.
139 // dirtySession tests ResetSession, true if a query has executed
140 // until ResetSession is called.
141 dirtySession bool
143 // The waiter is called before each query. May be used in place of the "WAIT"
144 // directive.
145 waiter func(context.Context)
148 func (c *fakeConn) touchMem() {
149 c.line++
152 func (c *fakeConn) incrStat(v *int) {
153 c.mu.Lock()
154 *v++
155 c.mu.Unlock()
158 type fakeTx struct {
159 c *fakeConn
// boundCol is one WHERE-clause filter of a prepared SELECT.
type boundCol struct {
	Column      string // table column being compared
	Placeholder string // "?" for a positional arg, "?name" for a named arg
	Ordinal     int    // 1-based position among the statement's placeholders
}
168 type fakeStmt struct {
169 memToucher
170 c *fakeConn
171 q string // just for debugging
173 cmd string
174 table string
175 panic string
176 wait time.Duration
178 next *fakeStmt // used for returning multiple results.
180 closed bool
182 colName []string // used by CREATE, INSERT, SELECT (selected columns)
183 colType []string // used by CREATE
184 colValue []interface{} // used by INSERT (mix of strings and "?" for bound params)
185 placeholders int // used by INSERT/SELECT: number of ? params
187 whereCol []boundCol // used by SELECT (all placeholders)
189 placeholderConverter []driver.ValueConverter // used by INSERT
192 var fdriver driver.Driver = &fakeDriver{}
194 func init() {
195 Register("test", fdriver)
// contains reports whether y occurs in list.
func contains(list []string, y string) bool {
	for i := range list {
		if list[i] == y {
			return true
		}
	}
	return false
}
// Dummy is a do-nothing driver that TestDrivers registers as "invalid".
// Its embedded driver.Driver is nil, so calling Open on Dummy{} would panic.
type Dummy struct {
	driver.Driver
}
211 func TestDrivers(t *testing.T) {
212 unregisterAllDrivers()
213 Register("test", fdriver)
214 Register("invalid", Dummy{})
215 all := Drivers()
216 if len(all) < 2 || !sort.StringsAreSorted(all) || !contains(all, "test") || !contains(all, "invalid") {
217 t.Fatalf("Drivers = %v, want sorted list with at least [invalid, test]", all)
// hookOpenErr simulates connection failures: when fn is set and returns a
// non-nil error, fakeDriver.Open fails with that error.
var hookOpenErr struct {
	sync.Mutex
	fn func() error
}

// setHookOpenErr installs fn as the Open failure hook; nil removes it.
func setHookOpenErr(fn func() error) {
	hookOpenErr.Lock()
	hookOpenErr.fn = fn
	hookOpenErr.Unlock()
}
233 // Supports dsn forms:
234 // <dbname>
235 // <dbname>;<opts> (only currently supported option is `badConn`,
236 // which causes driver.ErrBadConn to be returned on
237 // every other conn.Begin())
238 func (d *fakeDriver) Open(dsn string) (driver.Conn, error) {
239 hookOpenErr.Lock()
240 fn := hookOpenErr.fn
241 hookOpenErr.Unlock()
242 if fn != nil {
243 if err := fn(); err != nil {
244 return nil, err
247 parts := strings.Split(dsn, ";")
248 if len(parts) < 1 {
249 return nil, errors.New("fakedb: no database name")
251 name := parts[0]
253 db := d.getDB(name)
255 d.mu.Lock()
256 d.openCount++
257 d.mu.Unlock()
258 conn := &fakeConn{db: db}
260 if len(parts) >= 2 && parts[1] == "badConn" {
261 conn.bad = true
263 if d.waitCh != nil {
264 d.waitingCh <- struct{}{}
265 <-d.waitCh
266 d.waitCh = nil
267 d.waitingCh = nil
269 return conn, nil
272 func (d *fakeDriver) getDB(name string) *fakeDB {
273 d.mu.Lock()
274 defer d.mu.Unlock()
275 if d.dbs == nil {
276 d.dbs = make(map[string]*fakeDB)
278 db, ok := d.dbs[name]
279 if !ok {
280 db = &fakeDB{name: name}
281 d.dbs[name] = db
283 return db
286 func (db *fakeDB) wipe() {
287 db.mu.Lock()
288 defer db.mu.Unlock()
289 db.tables = nil
292 func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error {
293 db.mu.Lock()
294 defer db.mu.Unlock()
295 if db.tables == nil {
296 db.tables = make(map[string]*table)
298 if _, exist := db.tables[name]; exist {
299 return fmt.Errorf("table %q already exists", name)
301 if len(columnNames) != len(columnTypes) {
302 return fmt.Errorf("create table of %q len(names) != len(types): %d vs %d",
303 name, len(columnNames), len(columnTypes))
305 db.tables[name] = &table{colname: columnNames, coltype: columnTypes}
306 return nil
309 // must be called with db.mu lock held
310 func (db *fakeDB) table(table string) (*table, bool) {
311 if db.tables == nil {
312 return nil, false
314 t, ok := db.tables[table]
315 return t, ok
318 func (db *fakeDB) columnType(table, column string) (typ string, ok bool) {
319 db.mu.Lock()
320 defer db.mu.Unlock()
321 t, ok := db.table(table)
322 if !ok {
323 return
325 for n, cname := range t.colname {
326 if cname == column {
327 return t.coltype[n], true
330 return "", false
333 func (c *fakeConn) isBad() bool {
334 if c.stickyBad {
335 return true
336 } else if c.bad {
337 if c.db == nil {
338 return false
340 // alternate between bad conn and not bad conn
341 c.db.badConn = !c.db.badConn
342 return c.db.badConn
343 } else {
344 return false
348 func (c *fakeConn) isDirtyAndMark() bool {
349 if c.skipDirtySession {
350 return false
352 if c.currTx != nil {
353 c.dirtySession = true
354 return false
356 if c.dirtySession {
357 return true
359 c.dirtySession = true
360 return false
363 func (c *fakeConn) Begin() (driver.Tx, error) {
364 if c.isBad() {
365 return nil, driver.ErrBadConn
367 if c.currTx != nil {
368 return nil, errors.New("already in a transaction")
370 c.touchMem()
371 c.currTx = &fakeTx{c: c}
372 return c.currTx, nil
375 var hookPostCloseConn struct {
376 sync.Mutex
377 fn func(*fakeConn, error)
380 func setHookpostCloseConn(fn func(*fakeConn, error)) {
381 hookPostCloseConn.Lock()
382 defer hookPostCloseConn.Unlock()
383 hookPostCloseConn.fn = fn
// testStrictClose, when non-nil, gets an Errorf for every fakeConn.Close
// that fails.
var testStrictClose *testing.T

// setStrictFakeConnClose makes t report an error whenever fakeConn.Close
// fails to close. Passing nil disables the check.
func setStrictFakeConnClose(t *testing.T) {
	testStrictClose = t
}
394 func (c *fakeConn) ResetSession(ctx context.Context) error {
395 c.dirtySession = false
396 if c.isBad() {
397 return driver.ErrBadConn
399 return nil
402 func (c *fakeConn) Close() (err error) {
403 drv := fdriver.(*fakeDriver)
404 defer func() {
405 if err != nil && testStrictClose != nil {
406 testStrictClose.Errorf("failed to close a test fakeConn: %v", err)
408 hookPostCloseConn.Lock()
409 fn := hookPostCloseConn.fn
410 hookPostCloseConn.Unlock()
411 if fn != nil {
412 fn(c, err)
414 if err == nil {
415 drv.mu.Lock()
416 drv.closeCount++
417 drv.mu.Unlock()
420 c.touchMem()
421 if c.currTx != nil {
422 return errors.New("can't close fakeConn; in a Transaction")
424 if c.db == nil {
425 return errors.New("can't close fakeConn; already closed")
427 if c.stmtsMade > c.stmtsClosed {
428 return errors.New("can't close; dangling statement(s)")
430 c.db = nil
431 return nil
// checkSubsetTypes verifies that every argument value is one of the
// driver.Value subset types: int64, float64, bool, nil, []byte, string or
// time.Time.
//
// When allowAny is true, every type is accepted.
func checkSubsetTypes(allowAny bool, args []driver.NamedValue) error {
	if allowAny {
		return nil
	}
	for _, arg := range args {
		switch arg.Value.(type) {
		case int64, float64, bool, nil, []byte, string, time.Time:
			continue
		}
		return fmt.Errorf("fakedb_test: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value)
	}
	return nil
}
447 func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) {
448 // Ensure that ExecContext is called if available.
449 panic("ExecContext was not called.")
452 func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
453 // This is an optional interface, but it's implemented here
454 // just to check that all the args are of the proper types.
455 // ErrSkip is returned so the caller acts as if we didn't
456 // implement this at all.
457 err := checkSubsetTypes(c.db.allowAny, args)
458 if err != nil {
459 return nil, err
461 return nil, driver.ErrSkip
464 func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) {
465 // Ensure that ExecContext is called if available.
466 panic("QueryContext was not called.")
469 func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
470 // This is an optional interface, but it's implemented here
471 // just to check that all the args are of the proper types.
472 // ErrSkip is returned so the caller acts as if we didn't
473 // implement this at all.
474 err := checkSubsetTypes(c.db.allowAny, args)
475 if err != nil {
476 return nil, err
478 return nil, driver.ErrSkip
// errf formats msg with args and returns it as an error prefixed with
// "fakedb: ".
func errf(msg string, args ...interface{}) error {
	text := fmt.Sprintf(msg, args...)
	return errors.New("fakedb: " + text)
}
485 // parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=?
486 // (note that where columns must always contain ? marks,
487 // just a limitation for fakedb)
488 func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
489 if len(parts) != 3 {
490 stmt.Close()
491 return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
493 stmt.table = parts[0]
495 stmt.colName = strings.Split(parts[1], ",")
496 for n, colspec := range strings.Split(parts[2], ",") {
497 if colspec == "" {
498 continue
500 nameVal := strings.Split(colspec, "=")
501 if len(nameVal) != 2 {
502 stmt.Close()
503 return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
505 column, value := nameVal[0], nameVal[1]
506 _, ok := c.db.columnType(stmt.table, column)
507 if !ok {
508 stmt.Close()
509 return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column)
511 if !strings.HasPrefix(value, "?") {
512 stmt.Close()
513 return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark",
514 stmt.table, column)
516 stmt.placeholders++
517 stmt.whereCol = append(stmt.whereCol, boundCol{Column: column, Placeholder: value, Ordinal: stmt.placeholders})
519 return stmt, nil
522 // parts are table|col=type,col2=type2
523 func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
524 if len(parts) != 2 {
525 stmt.Close()
526 return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
528 stmt.table = parts[0]
529 for n, colspec := range strings.Split(parts[1], ",") {
530 nameType := strings.Split(colspec, "=")
531 if len(nameType) != 2 {
532 stmt.Close()
533 return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
535 stmt.colName = append(stmt.colName, nameType[0])
536 stmt.colType = append(stmt.colType, nameType[1])
538 return stmt, nil
541 // parts are table|col=?,col2=val
542 func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
543 if len(parts) != 2 {
544 stmt.Close()
545 return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
547 stmt.table = parts[0]
548 for n, colspec := range strings.Split(parts[1], ",") {
549 nameVal := strings.Split(colspec, "=")
550 if len(nameVal) != 2 {
551 stmt.Close()
552 return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
554 column, value := nameVal[0], nameVal[1]
555 ctype, ok := c.db.columnType(stmt.table, column)
556 if !ok {
557 stmt.Close()
558 return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column)
560 stmt.colName = append(stmt.colName, column)
562 if !strings.HasPrefix(value, "?") {
563 var subsetVal interface{}
564 // Convert to driver subset type
565 switch ctype {
566 case "string":
567 subsetVal = []byte(value)
568 case "blob":
569 subsetVal = []byte(value)
570 case "int32":
571 i, err := strconv.Atoi(value)
572 if err != nil {
573 stmt.Close()
574 return nil, errf("invalid conversion to int32 from %q", value)
576 subsetVal = int64(i) // int64 is a subset type, but not int32
577 default:
578 stmt.Close()
579 return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype)
581 stmt.colValue = append(stmt.colValue, subsetVal)
582 } else {
583 stmt.placeholders++
584 stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype))
585 stmt.colValue = append(stmt.colValue, value)
588 return stmt, nil
591 // hook to simulate broken connections
592 var hookPrepareBadConn func() bool
594 func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
595 panic("use PrepareContext")
598 func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
599 c.numPrepare++
600 if c.db == nil {
601 panic("nil c.db; conn = " + fmt.Sprintf("%#v", c))
604 if c.stickyBad || (hookPrepareBadConn != nil && hookPrepareBadConn()) {
605 return nil, driver.ErrBadConn
608 c.touchMem()
609 var firstStmt, prev *fakeStmt
610 for _, query := range strings.Split(query, ";") {
611 parts := strings.Split(query, "|")
612 if len(parts) < 1 {
613 return nil, errf("empty query")
615 stmt := &fakeStmt{q: query, c: c, memToucher: c}
616 if firstStmt == nil {
617 firstStmt = stmt
619 if len(parts) >= 3 {
620 switch parts[0] {
621 case "PANIC":
622 stmt.panic = parts[1]
623 parts = parts[2:]
624 case "WAIT":
625 wait, err := time.ParseDuration(parts[1])
626 if err != nil {
627 return nil, errf("expected section after WAIT to be a duration, got %q %v", parts[1], err)
629 parts = parts[2:]
630 stmt.wait = wait
633 cmd := parts[0]
634 stmt.cmd = cmd
635 parts = parts[1:]
637 if c.waiter != nil {
638 c.waiter(ctx)
641 if stmt.wait > 0 {
642 wait := time.NewTimer(stmt.wait)
643 select {
644 case <-wait.C:
645 case <-ctx.Done():
646 wait.Stop()
647 return nil, ctx.Err()
651 c.incrStat(&c.stmtsMade)
652 var err error
653 switch cmd {
654 case "WIPE":
655 // Nothing
656 case "SELECT":
657 stmt, err = c.prepareSelect(stmt, parts)
658 case "CREATE":
659 stmt, err = c.prepareCreate(stmt, parts)
660 case "INSERT":
661 stmt, err = c.prepareInsert(stmt, parts)
662 case "NOSERT":
663 // Do all the prep-work like for an INSERT but don't actually insert the row.
664 // Used for some of the concurrent tests.
665 stmt, err = c.prepareInsert(stmt, parts)
666 default:
667 stmt.Close()
668 return nil, errf("unsupported command type %q", cmd)
670 if err != nil {
671 return nil, err
673 if prev != nil {
674 prev.next = stmt
676 prev = stmt
678 return firstStmt, nil
681 func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
682 if s.panic == "ColumnConverter" {
683 panic(s.panic)
685 if len(s.placeholderConverter) == 0 {
686 return driver.DefaultParameterConverter
688 return s.placeholderConverter[idx]
691 func (s *fakeStmt) Close() error {
692 if s.panic == "Close" {
693 panic(s.panic)
695 if s.c == nil {
696 panic("nil conn in fakeStmt.Close")
698 if s.c.db == nil {
699 panic("in fakeStmt.Close, conn's db is nil (already closed)")
701 s.touchMem()
702 if !s.closed {
703 s.c.incrStat(&s.c.stmtsClosed)
704 s.closed = true
706 if s.next != nil {
707 s.next.Close()
709 return nil
712 var errClosed = errors.New("fakedb: statement has been closed")
714 // hook to simulate broken connections
715 var hookExecBadConn func() bool
717 func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
718 panic("Using ExecContext")
720 func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
721 if s.panic == "Exec" {
722 panic(s.panic)
724 if s.closed {
725 return nil, errClosed
728 if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) {
729 return nil, driver.ErrBadConn
731 if s.c.isDirtyAndMark() {
732 return nil, errors.New("session is dirty")
735 err := checkSubsetTypes(s.c.db.allowAny, args)
736 if err != nil {
737 return nil, err
739 s.touchMem()
741 if s.wait > 0 {
742 time.Sleep(s.wait)
745 select {
746 default:
747 case <-ctx.Done():
748 return nil, ctx.Err()
751 db := s.c.db
752 switch s.cmd {
753 case "WIPE":
754 db.wipe()
755 return driver.ResultNoRows, nil
756 case "CREATE":
757 if err := db.createTable(s.table, s.colName, s.colType); err != nil {
758 return nil, err
760 return driver.ResultNoRows, nil
761 case "INSERT":
762 return s.execInsert(args, true)
763 case "NOSERT":
764 // Do all the prep-work like for an INSERT but don't actually insert the row.
765 // Used for some of the concurrent tests.
766 return s.execInsert(args, false)
768 fmt.Printf("EXEC statement, cmd=%q: %#v\n", s.cmd, s)
769 return nil, fmt.Errorf("unimplemented statement Exec command type of %q", s.cmd)
772 // When doInsert is true, add the row to the table.
773 // When doInsert is false do prep-work and error checking, but don't
774 // actually add the row to the table.
775 func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.Result, error) {
776 db := s.c.db
777 if len(args) != s.placeholders {
778 panic("error in pkg db; should only get here if size is correct")
780 db.mu.Lock()
781 t, ok := db.table(s.table)
782 db.mu.Unlock()
783 if !ok {
784 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
787 t.mu.Lock()
788 defer t.mu.Unlock()
790 var cols []interface{}
791 if doInsert {
792 cols = make([]interface{}, len(t.colname))
794 argPos := 0
795 for n, colname := range s.colName {
796 colidx := t.columnIndex(colname)
797 if colidx == -1 {
798 return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname)
800 var val interface{}
801 if strvalue, ok := s.colValue[n].(string); ok && strings.HasPrefix(strvalue, "?") {
802 if strvalue == "?" {
803 val = args[argPos].Value
804 } else {
805 // Assign value from argument placeholder name.
806 for _, a := range args {
807 if a.Name == strvalue[1:] {
808 val = a.Value
809 break
813 argPos++
814 } else {
815 val = s.colValue[n]
817 if doInsert {
818 cols[colidx] = val
822 if doInsert {
823 t.rows = append(t.rows, &row{cols: cols})
825 return driver.RowsAffected(1), nil
828 // hook to simulate broken connections
829 var hookQueryBadConn func() bool
831 func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
832 panic("Use QueryContext")
835 func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
836 if s.panic == "Query" {
837 panic(s.panic)
839 if s.closed {
840 return nil, errClosed
843 if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) {
844 return nil, driver.ErrBadConn
846 if s.c.isDirtyAndMark() {
847 return nil, errors.New("session is dirty")
850 err := checkSubsetTypes(s.c.db.allowAny, args)
851 if err != nil {
852 return nil, err
855 s.touchMem()
856 db := s.c.db
857 if len(args) != s.placeholders {
858 panic("error in pkg db; should only get here if size is correct")
861 setMRows := make([][]*row, 0, 1)
862 setColumns := make([][]string, 0, 1)
863 setColType := make([][]string, 0, 1)
865 for {
866 db.mu.Lock()
867 t, ok := db.table(s.table)
868 db.mu.Unlock()
869 if !ok {
870 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
873 if s.table == "magicquery" {
874 if len(s.whereCol) == 2 && s.whereCol[0].Column == "op" && s.whereCol[1].Column == "millis" {
875 if args[0].Value == "sleep" {
876 time.Sleep(time.Duration(args[1].Value.(int64)) * time.Millisecond)
881 t.mu.Lock()
883 colIdx := make(map[string]int) // select column name -> column index in table
884 for _, name := range s.colName {
885 idx := t.columnIndex(name)
886 if idx == -1 {
887 t.mu.Unlock()
888 return nil, fmt.Errorf("fakedb: unknown column name %q", name)
890 colIdx[name] = idx
893 mrows := []*row{}
894 rows:
895 for _, trow := range t.rows {
896 // Process the where clause, skipping non-match rows. This is lazy
897 // and just uses fmt.Sprintf("%v") to test equality. Good enough
898 // for test code.
899 for _, wcol := range s.whereCol {
900 idx := t.columnIndex(wcol.Column)
901 if idx == -1 {
902 t.mu.Unlock()
903 return nil, fmt.Errorf("db: invalid where clause column %q", wcol)
905 tcol := trow.cols[idx]
906 if bs, ok := tcol.([]byte); ok {
907 // lazy hack to avoid sprintf %v on a []byte
908 tcol = string(bs)
910 var argValue interface{}
911 if wcol.Placeholder == "?" {
912 argValue = args[wcol.Ordinal-1].Value
913 } else {
914 // Assign arg value from placeholder name.
915 for _, a := range args {
916 if a.Name == wcol.Placeholder[1:] {
917 argValue = a.Value
918 break
922 if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) {
923 continue rows
926 mrow := &row{cols: make([]interface{}, len(s.colName))}
927 for seli, name := range s.colName {
928 mrow.cols[seli] = trow.cols[colIdx[name]]
930 mrows = append(mrows, mrow)
933 var colType []string
934 for _, column := range s.colName {
935 colType = append(colType, t.coltype[t.columnIndex(column)])
938 t.mu.Unlock()
940 setMRows = append(setMRows, mrows)
941 setColumns = append(setColumns, s.colName)
942 setColType = append(setColType, colType)
944 if s.next == nil {
945 break
947 s = s.next
950 cursor := &rowsCursor{
951 parentMem: s.c,
952 posRow: -1,
953 rows: setMRows,
954 cols: setColumns,
955 colType: setColType,
956 errPos: -1,
958 return cursor, nil
961 func (s *fakeStmt) NumInput() int {
962 if s.panic == "NumInput" {
963 panic(s.panic)
965 return s.placeholders
968 // hook to simulate broken connections
969 var hookCommitBadConn func() bool
971 func (tx *fakeTx) Commit() error {
972 tx.c.currTx = nil
973 if hookCommitBadConn != nil && hookCommitBadConn() {
974 return driver.ErrBadConn
976 tx.c.touchMem()
977 return nil
980 // hook to simulate broken connections
981 var hookRollbackBadConn func() bool
983 func (tx *fakeTx) Rollback() error {
984 tx.c.currTx = nil
985 if hookRollbackBadConn != nil && hookRollbackBadConn() {
986 return driver.ErrBadConn
988 tx.c.touchMem()
989 return nil
992 type rowsCursor struct {
993 parentMem memToucher
994 cols [][]string
995 colType [][]string
996 posSet int
997 posRow int
998 rows [][]*row
999 closed bool
1001 // errPos and err are for making Next return early with error.
1002 errPos int
1003 err error
1005 // a clone of slices to give out to clients, indexed by the
1006 // the original slice's first byte address. we clone them
1007 // just so we're able to corrupt them on close.
1008 bytesClone map[*byte][]byte
1010 // Every operation writes to line to enable the race detector
1011 // check for data races.
1012 // This is separate from the fakeConn.line to allow for drivers that
1013 // can start multiple queries on the same transaction at the same time.
1014 line int64
1017 func (rc *rowsCursor) touchMem() {
1018 rc.parentMem.touchMem()
1019 rc.line++
1022 func (rc *rowsCursor) Close() error {
1023 rc.touchMem()
1024 rc.parentMem.touchMem()
1025 rc.closed = true
1026 return nil
1029 func (rc *rowsCursor) Columns() []string {
1030 return rc.cols[rc.posSet]
1033 func (rc *rowsCursor) ColumnTypeScanType(index int) reflect.Type {
1034 return colTypeToReflectType(rc.colType[rc.posSet][index])
1037 var rowsCursorNextHook func(dest []driver.Value) error
1039 func (rc *rowsCursor) Next(dest []driver.Value) error {
1040 if rowsCursorNextHook != nil {
1041 return rowsCursorNextHook(dest)
1044 if rc.closed {
1045 return errors.New("fakedb: cursor is closed")
1047 rc.touchMem()
1048 rc.posRow++
1049 if rc.posRow == rc.errPos {
1050 return rc.err
1052 if rc.posRow >= len(rc.rows[rc.posSet]) {
1053 return io.EOF // per interface spec
1055 for i, v := range rc.rows[rc.posSet][rc.posRow].cols {
1056 // TODO(bradfitz): convert to subset types? naah, I
1057 // think the subset types should only be input to
1058 // driver, but the sql package should be able to handle
1059 // a wider range of types coming out of drivers. all
1060 // for ease of drivers, and to prevent drivers from
1061 // messing up conversions or doing them differently.
1062 dest[i] = v
1064 if bs, ok := v.([]byte); ok {
1065 if rc.bytesClone == nil {
1066 rc.bytesClone = make(map[*byte][]byte)
1068 clone, ok := rc.bytesClone[&bs[0]]
1069 if !ok {
1070 clone = make([]byte, len(bs))
1071 copy(clone, bs)
1072 rc.bytesClone[&bs[0]] = clone
1074 dest[i] = clone
1077 return nil
1080 func (rc *rowsCursor) HasNextResultSet() bool {
1081 rc.touchMem()
1082 return rc.posSet < len(rc.rows)-1
1085 func (rc *rowsCursor) NextResultSet() error {
1086 rc.touchMem()
1087 if rc.HasNextResultSet() {
1088 rc.posSet++
1089 rc.posRow = -1
1090 return nil
1092 return io.EOF // Per interface spec.
// fakeDriverString is like driver.String, but indirects pointers like
// DefaultValueConverter.
//
// This could be surprising behavior to retroactively apply to
// driver.String now that Go1 is out, but this is convenient for
// our TestPointerParamsAndScans.
type fakeDriverString struct{}

// ConvertValue handles each kind of input as follows:
//   - strings and byte slices pass through unchanged;
//   - a *string is dereferenced, and a nil *string becomes nil;
//   - anything else is rendered with %v.
func (fakeDriverString) ConvertValue(v interface{}) (driver.Value, error) {
	if p, isPtr := v.(*string); isPtr {
		if p == nil {
			return nil, nil
		}
		return *p, nil
	}
	switch v.(type) {
	case string, []byte:
		return v, nil
	}
	return fmt.Sprintf("%v", v), nil
}
1117 type anyTypeConverter struct{}
1119 func (anyTypeConverter) ConvertValue(v interface{}) (driver.Value, error) {
1120 return v, nil
1123 func converterForType(typ string) driver.ValueConverter {
1124 switch typ {
1125 case "bool":
1126 return driver.Bool
1127 case "nullbool":
1128 return driver.Null{Converter: driver.Bool}
1129 case "int32":
1130 return driver.Int32
1131 case "string":
1132 return driver.NotNull{Converter: fakeDriverString{}}
1133 case "nullstring":
1134 return driver.Null{Converter: fakeDriverString{}}
1135 case "int64":
1136 // TODO(coopernurse): add type-specific converter
1137 return driver.NotNull{Converter: driver.DefaultParameterConverter}
1138 case "nullint64":
1139 // TODO(coopernurse): add type-specific converter
1140 return driver.Null{Converter: driver.DefaultParameterConverter}
1141 case "float64":
1142 // TODO(coopernurse): add type-specific converter
1143 return driver.NotNull{Converter: driver.DefaultParameterConverter}
1144 case "nullfloat64":
1145 // TODO(coopernurse): add type-specific converter
1146 return driver.Null{Converter: driver.DefaultParameterConverter}
1147 case "datetime":
1148 return driver.DefaultParameterConverter
1149 case "any":
1150 return anyTypeConverter{}
1152 panic("invalid fakedb column type of " + typ)
1155 func colTypeToReflectType(typ string) reflect.Type {
1156 switch typ {
1157 case "bool":
1158 return reflect.TypeOf(false)
1159 case "nullbool":
1160 return reflect.TypeOf(NullBool{})
1161 case "int32":
1162 return reflect.TypeOf(int32(0))
1163 case "string":
1164 return reflect.TypeOf("")
1165 case "nullstring":
1166 return reflect.TypeOf(NullString{})
1167 case "int64":
1168 return reflect.TypeOf(int64(0))
1169 case "nullint64":
1170 return reflect.TypeOf(NullInt64{})
1171 case "float64":
1172 return reflect.TypeOf(float64(0))
1173 case "nullfloat64":
1174 return reflect.TypeOf(NullFloat64{})
1175 case "datetime":
1176 return reflect.TypeOf(time.Time{})
1177 case "any":
1178 return reflect.TypeOf(new(interface{})).Elem()
1180 panic("invalid fakedb column type of " + typ)