libgo: update to final Go 1.8 release
[official-gcc.git] / libgo / go / database / sql / fakedb_test.go
blob4b15f5bec7bfb273aeb8b40bcc08b10387107934
1 // Copyright 2011 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
5 package sql
7 import (
8 "context"
9 "database/sql/driver"
10 "errors"
11 "fmt"
12 "io"
13 "log"
14 "reflect"
15 "sort"
16 "strconv"
17 "strings"
18 "sync"
19 "testing"
20 "time"
// Reference log.Printf so the "log" import is always used by this file;
// presumably this lets temporary debug logging be added and removed
// without touching the import block.
var _ = log.Printf
25 // fakeDriver is a fake database that implements Go's driver.Driver
26 // interface, just for testing.
28 // It speaks a query language that's semantically similar to but
29 // syntactically different and simpler than SQL. The syntax is as
30 // follows:
32 // WIPE
33 // CREATE|<tablename>|<col>=<type>,<col>=<type>,...
34 // where types are: "string", [u]int{8,16,32,64}, "bool"
35 // INSERT|<tablename>|col=val,col2=val2,col3=?
36 // SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=?
37 // SELECT|<tablename>|projectcol1,projectcol2|filtercol=?param1,filtercol2=?param2
39 // Any of these can be preceded by PANIC|<method>|, to cause the
40 // named method on fakeStmt to panic.
42 // Any of these can be proceeded by WAIT|<duration>|, to cause the
43 // named method on fakeStmt to sleep for the specified duration.
45 // Multiple of these can be combined when separated with a semicolon.
47 // When opening a fakeDriver's database, it starts empty with no
48 // tables. All tables and data are stored in memory only.
49 type fakeDriver struct {
50 mu sync.Mutex // guards 3 following fields
51 openCount int // conn opens
52 closeCount int // conn closes
53 waitCh chan struct{}
54 waitingCh chan struct{}
55 dbs map[string]*fakeDB
// fakeDB is a named in-memory database; every fakeConn opened with the
// same DSN name shares one fakeDB (see fakeDriver.getDB).
type fakeDB struct {
	name string

	mu      sync.Mutex
	tables  map[string]*table // created lazily by createTable; reset to nil by wipe
	badConn bool              // toggled by fakeConn.isBad for "badConn" DSNs
}
// table is an in-memory table: parallel column name/type slices plus
// the stored rows.
type table struct {
	mu      sync.Mutex // held by execInsert and QueryContext while they use rows
	colname []string
	coltype []string
	rows    []*row
}
73 func (t *table) columnIndex(name string) int {
74 for n, nname := range t.colname {
75 if name == nname {
76 return n
79 return -1
// row is a single stored row; cols is indexed like the owning table's
// colname/coltype slices.
type row struct {
	cols []interface{} // must be same size as its table colname + coltype
}
// fakeConn is the driver.Conn handed out by fakeDriver.Open.
type fakeConn struct {
	db *fakeDB // where to return ourselves to; set to nil by Close

	currTx *fakeTx // non-nil while a transaction is open (see Begin)

	// Stats for tests. stmtsMade and stmtsClosed are updated through
	// incrStat, which holds mu.
	mu          sync.Mutex
	stmtsMade   int
	stmtsClosed int
	numPrepare  int

	// bad connection tests; see isBad()
	bad       bool
	stickyBad bool
}
102 func (c *fakeConn) incrStat(v *int) {
103 c.mu.Lock()
104 *v++
105 c.mu.Unlock()
// fakeTx is the driver.Tx returned by fakeConn.Begin.
type fakeTx struct {
	c *fakeConn // conn whose currTx is cleared by Commit/Rollback
}
// boundCol is one SELECT where-clause condition, written as column=? or
// column=?name.
type boundCol struct {
	Column      string // table column compared against the argument
	Placeholder string // "?" for a positional argument, "?name" for a named one
	Ordinal     int    // 1-based position among the statement's placeholders
}
// fakeStmt is a prepared fakedb statement. A query made of several
// ';'-separated statements is prepared as a chain linked through next.
type fakeStmt struct {
	c *fakeConn
	q string // just for debugging

	cmd   string        // WIPE, CREATE, INSERT, NOSERT or SELECT
	table string
	panic string        // method to panic in, from a PANIC|<method>| prefix
	wait  time.Duration // delay from a WAIT|<duration>| prefix

	next *fakeStmt // used for returning multiple results.

	closed bool

	colName      []string      // used by CREATE, INSERT, SELECT (selected columns)
	colType      []string      // used by CREATE
	colValue     []interface{} // used by INSERT: pre-bound values already converted to driver subset types, or "?"/"?name" placeholder strings
	placeholders int           // used by INSERT/SELECT: number of ? params

	whereCol []boundCol // used by SELECT (all placeholders)

	placeholderConverter []driver.ValueConverter // used by INSERT
}
// fdriver is the shared fakeDriver instance, registered under the name
// "test" when the package is initialized.
var fdriver driver.Driver = &fakeDriver{}

func init() {
	Register("test", fdriver)
}
// contains reports whether y is an element of list.
func contains(list []string, y string) bool {
	for i := range list {
		if list[i] == y {
			return true
		}
	}
	return false
}
// Dummy embeds driver.Driver without setting it; TestDrivers registers a
// zero Dummy only so that its name appears in Drivers().
type Dummy struct {
	driver.Driver
}
160 func TestDrivers(t *testing.T) {
161 unregisterAllDrivers()
162 Register("test", fdriver)
163 Register("invalid", Dummy{})
164 all := Drivers()
165 if len(all) < 2 || !sort.StringsAreSorted(all) || !contains(all, "test") || !contains(all, "invalid") {
166 t.Fatalf("Drivers = %v, want sorted list with at least [invalid, test]", all)
// hook to simulate connection failures: when fn is non-nil, fakeDriver.Open
// calls it first and returns its error, if any, without opening a conn.
var hookOpenErr struct {
	sync.Mutex
	fn func() error
}
176 func setHookOpenErr(fn func() error) {
177 hookOpenErr.Lock()
178 defer hookOpenErr.Unlock()
179 hookOpenErr.fn = fn
182 // Supports dsn forms:
183 // <dbname>
184 // <dbname>;<opts> (only currently supported option is `badConn`,
185 // which causes driver.ErrBadConn to be returned on
186 // every other conn.Begin())
187 func (d *fakeDriver) Open(dsn string) (driver.Conn, error) {
188 hookOpenErr.Lock()
189 fn := hookOpenErr.fn
190 hookOpenErr.Unlock()
191 if fn != nil {
192 if err := fn(); err != nil {
193 return nil, err
196 parts := strings.Split(dsn, ";")
197 if len(parts) < 1 {
198 return nil, errors.New("fakedb: no database name")
200 name := parts[0]
202 db := d.getDB(name)
204 d.mu.Lock()
205 d.openCount++
206 d.mu.Unlock()
207 conn := &fakeConn{db: db}
209 if len(parts) >= 2 && parts[1] == "badConn" {
210 conn.bad = true
212 if d.waitCh != nil {
213 d.waitingCh <- struct{}{}
214 <-d.waitCh
215 d.waitCh = nil
216 d.waitingCh = nil
218 return conn, nil
221 func (d *fakeDriver) getDB(name string) *fakeDB {
222 d.mu.Lock()
223 defer d.mu.Unlock()
224 if d.dbs == nil {
225 d.dbs = make(map[string]*fakeDB)
227 db, ok := d.dbs[name]
228 if !ok {
229 db = &fakeDB{name: name}
230 d.dbs[name] = db
232 return db
235 func (db *fakeDB) wipe() {
236 db.mu.Lock()
237 defer db.mu.Unlock()
238 db.tables = nil
241 func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error {
242 db.mu.Lock()
243 defer db.mu.Unlock()
244 if db.tables == nil {
245 db.tables = make(map[string]*table)
247 if _, exist := db.tables[name]; exist {
248 return fmt.Errorf("table %q already exists", name)
250 if len(columnNames) != len(columnTypes) {
251 return fmt.Errorf("create table of %q len(names) != len(types): %d vs %d",
252 name, len(columnNames), len(columnTypes))
254 db.tables[name] = &table{colname: columnNames, coltype: columnTypes}
255 return nil
258 // must be called with db.mu lock held
259 func (db *fakeDB) table(table string) (*table, bool) {
260 if db.tables == nil {
261 return nil, false
263 t, ok := db.tables[table]
264 return t, ok
267 func (db *fakeDB) columnType(table, column string) (typ string, ok bool) {
268 db.mu.Lock()
269 defer db.mu.Unlock()
270 t, ok := db.table(table)
271 if !ok {
272 return
274 for n, cname := range t.colname {
275 if cname == column {
276 return t.coltype[n], true
279 return "", false
282 func (c *fakeConn) isBad() bool {
283 if c.stickyBad {
284 return true
285 } else if c.bad {
286 // alternate between bad conn and not bad conn
287 c.db.badConn = !c.db.badConn
288 return c.db.badConn
289 } else {
290 return false
294 func (c *fakeConn) Begin() (driver.Tx, error) {
295 if c.isBad() {
296 return nil, driver.ErrBadConn
298 if c.currTx != nil {
299 return nil, errors.New("already in a transaction")
301 c.currTx = &fakeTx{c: c}
302 return c.currTx, nil
// hookPostCloseConn, when fn is non-nil, is called at the end of every
// fakeConn.Close with the conn and the error Close is returning.
var hookPostCloseConn struct {
	sync.Mutex
	fn func(*fakeConn, error)
}
310 func setHookpostCloseConn(fn func(*fakeConn, error)) {
311 hookPostCloseConn.Lock()
312 defer hookPostCloseConn.Unlock()
313 hookPostCloseConn.fn = fn
// testStrictClose, when non-nil, gets an Errorf call for every
// fakeConn.Close that returns an error.
var testStrictClose *testing.T

// setStrictFakeConnClose makes every failing fakeConn.Close report an
// error on t. Passing nil turns the check off.
func setStrictFakeConnClose(t *testing.T) {
	testStrictClose = t
}
324 func (c *fakeConn) Close() (err error) {
325 drv := fdriver.(*fakeDriver)
326 defer func() {
327 if err != nil && testStrictClose != nil {
328 testStrictClose.Errorf("failed to close a test fakeConn: %v", err)
330 hookPostCloseConn.Lock()
331 fn := hookPostCloseConn.fn
332 hookPostCloseConn.Unlock()
333 if fn != nil {
334 fn(c, err)
336 if err == nil {
337 drv.mu.Lock()
338 drv.closeCount++
339 drv.mu.Unlock()
342 if c.currTx != nil {
343 return errors.New("can't close fakeConn; in a Transaction")
345 if c.db == nil {
346 return errors.New("can't close fakeConn; already closed")
348 if c.stmtsMade > c.stmtsClosed {
349 return errors.New("can't close; dangling statement(s)")
351 c.db = nil
352 return nil
// checkSubsetTypes returns an error naming the first argument whose value
// is not one of the driver subset types (int64, float64, bool, nil,
// []byte, string, time.Time).
func checkSubsetTypes(args []driver.NamedValue) error {
	for i := range args {
		switch args[i].Value.(type) {
		case int64, float64, bool, nil, []byte, string, time.Time:
			continue
		}
		return fmt.Errorf("fakedb_test: invalid argument ordinal %[1]d: %[2]v, type %[2]T", args[i].Ordinal, args[i].Value)
	}
	return nil
}
366 func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) {
367 // Ensure that ExecContext is called if available.
368 panic("ExecContext was not called.")
371 func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
372 // This is an optional interface, but it's implemented here
373 // just to check that all the args are of the proper types.
374 // ErrSkip is returned so the caller acts as if we didn't
375 // implement this at all.
376 err := checkSubsetTypes(args)
377 if err != nil {
378 return nil, err
380 return nil, driver.ErrSkip
383 func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) {
384 // Ensure that ExecContext is called if available.
385 panic("QueryContext was not called.")
388 func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
389 // This is an optional interface, but it's implemented here
390 // just to check that all the args are of the proper types.
391 // ErrSkip is returned so the caller acts as if we didn't
392 // implement this at all.
393 err := checkSubsetTypes(args)
394 if err != nil {
395 return nil, err
397 return nil, driver.ErrSkip
// errf formats msg with args like fmt.Sprintf and returns it as an error
// prefixed with "fakedb: ".
func errf(msg string, args ...interface{}) error {
	text := fmt.Sprintf(msg, args...)
	return errors.New("fakedb: " + text)
}
404 // parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=?
405 // (note that where columns must always contain ? marks,
406 // just a limitation for fakedb)
407 func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
408 if len(parts) != 3 {
409 stmt.Close()
410 return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
412 stmt.table = parts[0]
414 stmt.colName = strings.Split(parts[1], ",")
415 for n, colspec := range strings.Split(parts[2], ",") {
416 if colspec == "" {
417 continue
419 nameVal := strings.Split(colspec, "=")
420 if len(nameVal) != 2 {
421 stmt.Close()
422 return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
424 column, value := nameVal[0], nameVal[1]
425 _, ok := c.db.columnType(stmt.table, column)
426 if !ok {
427 stmt.Close()
428 return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column)
430 if !strings.HasPrefix(value, "?") {
431 stmt.Close()
432 return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark",
433 stmt.table, column)
435 stmt.placeholders++
436 stmt.whereCol = append(stmt.whereCol, boundCol{Column: column, Placeholder: value, Ordinal: stmt.placeholders})
438 return stmt, nil
441 // parts are table|col=type,col2=type2
442 func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
443 if len(parts) != 2 {
444 stmt.Close()
445 return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
447 stmt.table = parts[0]
448 for n, colspec := range strings.Split(parts[1], ",") {
449 nameType := strings.Split(colspec, "=")
450 if len(nameType) != 2 {
451 stmt.Close()
452 return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
454 stmt.colName = append(stmt.colName, nameType[0])
455 stmt.colType = append(stmt.colType, nameType[1])
457 return stmt, nil
460 // parts are table|col=?,col2=val
461 func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
462 if len(parts) != 2 {
463 stmt.Close()
464 return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
466 stmt.table = parts[0]
467 for n, colspec := range strings.Split(parts[1], ",") {
468 nameVal := strings.Split(colspec, "=")
469 if len(nameVal) != 2 {
470 stmt.Close()
471 return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
473 column, value := nameVal[0], nameVal[1]
474 ctype, ok := c.db.columnType(stmt.table, column)
475 if !ok {
476 stmt.Close()
477 return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column)
479 stmt.colName = append(stmt.colName, column)
481 if !strings.HasPrefix(value, "?") {
482 var subsetVal interface{}
483 // Convert to driver subset type
484 switch ctype {
485 case "string":
486 subsetVal = []byte(value)
487 case "blob":
488 subsetVal = []byte(value)
489 case "int32":
490 i, err := strconv.Atoi(value)
491 if err != nil {
492 stmt.Close()
493 return nil, errf("invalid conversion to int32 from %q", value)
495 subsetVal = int64(i) // int64 is a subset type, but not int32
496 default:
497 stmt.Close()
498 return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype)
500 stmt.colValue = append(stmt.colValue, subsetVal)
501 } else {
502 stmt.placeholders++
503 stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype))
504 stmt.colValue = append(stmt.colValue, value)
507 return stmt, nil
// hook to simulate broken connections: when it is non-nil and returns
// true, fakeConn.PrepareContext fails with driver.ErrBadConn.
var hookPrepareBadConn func() bool
513 func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
514 panic("use PrepareContext")
517 func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
518 c.numPrepare++
519 if c.db == nil {
520 panic("nil c.db; conn = " + fmt.Sprintf("%#v", c))
523 if c.stickyBad || (hookPrepareBadConn != nil && hookPrepareBadConn()) {
524 return nil, driver.ErrBadConn
527 var firstStmt, prev *fakeStmt
528 for _, query := range strings.Split(query, ";") {
529 parts := strings.Split(query, "|")
530 if len(parts) < 1 {
531 return nil, errf("empty query")
533 stmt := &fakeStmt{q: query, c: c}
534 if firstStmt == nil {
535 firstStmt = stmt
537 if len(parts) >= 3 {
538 switch parts[0] {
539 case "PANIC":
540 stmt.panic = parts[1]
541 parts = parts[2:]
542 case "WAIT":
543 wait, err := time.ParseDuration(parts[1])
544 if err != nil {
545 return nil, errf("expected section after WAIT to be a duration, got %q %v", parts[1], err)
547 parts = parts[2:]
548 stmt.wait = wait
551 cmd := parts[0]
552 stmt.cmd = cmd
553 parts = parts[1:]
555 if stmt.wait > 0 {
556 wait := time.NewTimer(stmt.wait)
557 select {
558 case <-wait.C:
559 case <-ctx.Done():
560 wait.Stop()
561 return nil, ctx.Err()
565 c.incrStat(&c.stmtsMade)
566 var err error
567 switch cmd {
568 case "WIPE":
569 // Nothing
570 case "SELECT":
571 stmt, err = c.prepareSelect(stmt, parts)
572 case "CREATE":
573 stmt, err = c.prepareCreate(stmt, parts)
574 case "INSERT":
575 stmt, err = c.prepareInsert(stmt, parts)
576 case "NOSERT":
577 // Do all the prep-work like for an INSERT but don't actually insert the row.
578 // Used for some of the concurrent tests.
579 stmt, err = c.prepareInsert(stmt, parts)
580 default:
581 stmt.Close()
582 return nil, errf("unsupported command type %q", cmd)
584 if err != nil {
585 return nil, err
587 if prev != nil {
588 prev.next = stmt
590 prev = stmt
592 return firstStmt, nil
595 func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
596 if s.panic == "ColumnConverter" {
597 panic(s.panic)
599 if len(s.placeholderConverter) == 0 {
600 return driver.DefaultParameterConverter
602 return s.placeholderConverter[idx]
605 func (s *fakeStmt) Close() error {
606 if s.panic == "Close" {
607 panic(s.panic)
609 if s.c == nil {
610 panic("nil conn in fakeStmt.Close")
612 if s.c.db == nil {
613 panic("in fakeStmt.Close, conn's db is nil (already closed)")
615 if !s.closed {
616 s.c.incrStat(&s.c.stmtsClosed)
617 s.closed = true
619 if s.next != nil {
620 s.next.Close()
622 return nil
// errClosed is returned by fakeStmt's ExecContext and QueryContext when
// the statement has already been closed.
var errClosed = errors.New("fakedb: statement has been closed")

// hook to simulate broken connections: when it is non-nil and returns
// true, fakeStmt.ExecContext fails with driver.ErrBadConn.
var hookExecBadConn func() bool
630 func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
631 panic("Using ExecContext")
633 func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
634 if s.panic == "Exec" {
635 panic(s.panic)
637 if s.closed {
638 return nil, errClosed
641 if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) {
642 return nil, driver.ErrBadConn
645 err := checkSubsetTypes(args)
646 if err != nil {
647 return nil, err
650 if s.wait > 0 {
651 time.Sleep(s.wait)
654 select {
655 default:
656 case <-ctx.Done():
657 return nil, ctx.Err()
660 db := s.c.db
661 switch s.cmd {
662 case "WIPE":
663 db.wipe()
664 return driver.ResultNoRows, nil
665 case "CREATE":
666 if err := db.createTable(s.table, s.colName, s.colType); err != nil {
667 return nil, err
669 return driver.ResultNoRows, nil
670 case "INSERT":
671 return s.execInsert(args, true)
672 case "NOSERT":
673 // Do all the prep-work like for an INSERT but don't actually insert the row.
674 // Used for some of the concurrent tests.
675 return s.execInsert(args, false)
677 fmt.Printf("EXEC statement, cmd=%q: %#v\n", s.cmd, s)
678 return nil, fmt.Errorf("unimplemented statement Exec command type of %q", s.cmd)
681 // When doInsert is true, add the row to the table.
682 // When doInsert is false do prep-work and error checking, but don't
683 // actually add the row to the table.
684 func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.Result, error) {
685 db := s.c.db
686 if len(args) != s.placeholders {
687 panic("error in pkg db; should only get here if size is correct")
689 db.mu.Lock()
690 t, ok := db.table(s.table)
691 db.mu.Unlock()
692 if !ok {
693 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
696 t.mu.Lock()
697 defer t.mu.Unlock()
699 var cols []interface{}
700 if doInsert {
701 cols = make([]interface{}, len(t.colname))
703 argPos := 0
704 for n, colname := range s.colName {
705 colidx := t.columnIndex(colname)
706 if colidx == -1 {
707 return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname)
709 var val interface{}
710 if strvalue, ok := s.colValue[n].(string); ok && strings.HasPrefix(strvalue, "?") {
711 if strvalue == "?" {
712 val = args[argPos].Value
713 } else {
714 // Assign value from argument placeholder name.
715 for _, a := range args {
716 if a.Name == strvalue[1:] {
717 val = a.Value
718 break
722 argPos++
723 } else {
724 val = s.colValue[n]
726 if doInsert {
727 cols[colidx] = val
731 if doInsert {
732 t.rows = append(t.rows, &row{cols: cols})
734 return driver.RowsAffected(1), nil
// hook to simulate broken connections: when it is non-nil and returns
// true, fakeStmt.QueryContext fails with driver.ErrBadConn.
var hookQueryBadConn func() bool
740 func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
741 panic("Use QueryContext")
744 func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
745 if s.panic == "Query" {
746 panic(s.panic)
748 if s.closed {
749 return nil, errClosed
752 if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) {
753 return nil, driver.ErrBadConn
756 err := checkSubsetTypes(args)
757 if err != nil {
758 return nil, err
761 db := s.c.db
762 if len(args) != s.placeholders {
763 panic("error in pkg db; should only get here if size is correct")
766 setMRows := make([][]*row, 0, 1)
767 setColumns := make([][]string, 0, 1)
768 setColType := make([][]string, 0, 1)
770 for {
771 db.mu.Lock()
772 t, ok := db.table(s.table)
773 db.mu.Unlock()
774 if !ok {
775 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
778 if s.table == "magicquery" {
779 if len(s.whereCol) == 2 && s.whereCol[0].Column == "op" && s.whereCol[1].Column == "millis" {
780 if args[0].Value == "sleep" {
781 time.Sleep(time.Duration(args[1].Value.(int64)) * time.Millisecond)
786 t.mu.Lock()
788 colIdx := make(map[string]int) // select column name -> column index in table
789 for _, name := range s.colName {
790 idx := t.columnIndex(name)
791 if idx == -1 {
792 t.mu.Unlock()
793 return nil, fmt.Errorf("fakedb: unknown column name %q", name)
795 colIdx[name] = idx
798 mrows := []*row{}
799 rows:
800 for _, trow := range t.rows {
801 // Process the where clause, skipping non-match rows. This is lazy
802 // and just uses fmt.Sprintf("%v") to test equality. Good enough
803 // for test code.
804 for _, wcol := range s.whereCol {
805 idx := t.columnIndex(wcol.Column)
806 if idx == -1 {
807 t.mu.Unlock()
808 return nil, fmt.Errorf("db: invalid where clause column %q", wcol)
810 tcol := trow.cols[idx]
811 if bs, ok := tcol.([]byte); ok {
812 // lazy hack to avoid sprintf %v on a []byte
813 tcol = string(bs)
815 var argValue interface{}
816 if wcol.Placeholder == "?" {
817 argValue = args[wcol.Ordinal-1].Value
818 } else {
819 // Assign arg value from placeholder name.
820 for _, a := range args {
821 if a.Name == wcol.Placeholder[1:] {
822 argValue = a.Value
823 break
827 if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) {
828 continue rows
831 mrow := &row{cols: make([]interface{}, len(s.colName))}
832 for seli, name := range s.colName {
833 mrow.cols[seli] = trow.cols[colIdx[name]]
835 mrows = append(mrows, mrow)
838 var colType []string
839 for _, column := range s.colName {
840 colType = append(colType, t.coltype[t.columnIndex(column)])
843 t.mu.Unlock()
845 setMRows = append(setMRows, mrows)
846 setColumns = append(setColumns, s.colName)
847 setColType = append(setColType, colType)
849 if s.next == nil {
850 break
852 s = s.next
855 cursor := &rowsCursor{
856 posRow: -1,
857 rows: setMRows,
858 cols: setColumns,
859 colType: setColType,
860 errPos: -1,
862 return cursor, nil
865 func (s *fakeStmt) NumInput() int {
866 if s.panic == "NumInput" {
867 panic(s.panic)
869 return s.placeholders
// hook to simulate broken connections: when it is non-nil and returns
// true, fakeTx.Commit reports driver.ErrBadConn (after ending the tx).
var hookCommitBadConn func() bool
875 func (tx *fakeTx) Commit() error {
876 tx.c.currTx = nil
877 if hookCommitBadConn != nil && hookCommitBadConn() {
878 return driver.ErrBadConn
880 return nil
// hook to simulate broken connections: when it is non-nil and returns
// true, fakeTx.Rollback reports driver.ErrBadConn (after ending the tx).
var hookRollbackBadConn func() bool
886 func (tx *fakeTx) Rollback() error {
887 tx.c.currTx = nil
888 if hookRollbackBadConn != nil && hookRollbackBadConn() {
889 return driver.ErrBadConn
891 return nil
// rowsCursor is the driver.Rows returned by fakeStmt.QueryContext. It
// holds one materialized result set per statement in the query chain.
type rowsCursor struct {
	cols    [][]string // column names, per result set
	colType [][]string // declared column types, per result set
	posSet  int        // index of the current result set
	posRow  int        // current row within rows[posSet]; -1 before the first Next
	rows    [][]*row
	closed  bool

	// errPos and err are for making Next return early with error.
	errPos int
	err    error

	// a clone of slices to give out to clients, indexed by the
	// original slice's first byte address. we clone them
	// just so we're able to corrupt them on close.
	bytesClone map[*byte][]byte
}
912 func (rc *rowsCursor) Close() error {
913 if !rc.closed {
914 for _, bs := range rc.bytesClone {
915 bs[0] = 255 // first byte corrupted
918 rc.closed = true
919 return nil
922 func (rc *rowsCursor) Columns() []string {
923 return rc.cols[rc.posSet]
926 func (rc *rowsCursor) ColumnTypeScanType(index int) reflect.Type {
927 return colTypeToReflectType(rc.colType[rc.posSet][index])
// rowsCursorNextHook, when non-nil, replaces rowsCursor.Next entirely:
// Next returns whatever the hook returns for dest.
var rowsCursorNextHook func(dest []driver.Value) error
932 func (rc *rowsCursor) Next(dest []driver.Value) error {
933 if rowsCursorNextHook != nil {
934 return rowsCursorNextHook(dest)
937 if rc.closed {
938 return errors.New("fakedb: cursor is closed")
940 rc.posRow++
941 if rc.posRow == rc.errPos {
942 return rc.err
944 if rc.posRow >= len(rc.rows[rc.posSet]) {
945 return io.EOF // per interface spec
947 for i, v := range rc.rows[rc.posSet][rc.posRow].cols {
948 // TODO(bradfitz): convert to subset types? naah, I
949 // think the subset types should only be input to
950 // driver, but the sql package should be able to handle
951 // a wider range of types coming out of drivers. all
952 // for ease of drivers, and to prevent drivers from
953 // messing up conversions or doing them differently.
954 dest[i] = v
956 if bs, ok := v.([]byte); ok {
957 if rc.bytesClone == nil {
958 rc.bytesClone = make(map[*byte][]byte)
960 clone, ok := rc.bytesClone[&bs[0]]
961 if !ok {
962 clone = make([]byte, len(bs))
963 copy(clone, bs)
964 rc.bytesClone[&bs[0]] = clone
966 dest[i] = clone
969 return nil
972 func (rc *rowsCursor) HasNextResultSet() bool {
973 return rc.posSet < len(rc.rows)-1
976 func (rc *rowsCursor) NextResultSet() error {
977 if rc.HasNextResultSet() {
978 rc.posSet++
979 rc.posRow = -1
980 return nil
982 return io.EOF // Per interface spec.
// fakeDriverString is like driver.String, but indirects pointers like
// DefaultValueConverter.
//
// This could be surprising behavior to retroactively apply to
// driver.String now that Go1 is out, but this is convenient for
// our TestPointerParamsAndScans.
type fakeDriverString struct{}

// ConvertValue passes strings and []byte through unchanged, dereferences
// a *string (a nil pointer becomes nil), and renders anything else with
// fmt's %v verb.
func (fakeDriverString) ConvertValue(v interface{}) (driver.Value, error) {
	if sp, isPtr := v.(*string); isPtr {
		if sp == nil {
			return nil, nil
		}
		return *sp, nil
	}
	switch v.(type) {
	case string, []byte:
		return v, nil
	}
	return fmt.Sprintf("%v", v), nil
}
1007 func converterForType(typ string) driver.ValueConverter {
1008 switch typ {
1009 case "bool":
1010 return driver.Bool
1011 case "nullbool":
1012 return driver.Null{Converter: driver.Bool}
1013 case "int32":
1014 return driver.Int32
1015 case "string":
1016 return driver.NotNull{Converter: fakeDriverString{}}
1017 case "nullstring":
1018 return driver.Null{Converter: fakeDriverString{}}
1019 case "int64":
1020 // TODO(coopernurse): add type-specific converter
1021 return driver.NotNull{Converter: driver.DefaultParameterConverter}
1022 case "nullint64":
1023 // TODO(coopernurse): add type-specific converter
1024 return driver.Null{Converter: driver.DefaultParameterConverter}
1025 case "float64":
1026 // TODO(coopernurse): add type-specific converter
1027 return driver.NotNull{Converter: driver.DefaultParameterConverter}
1028 case "nullfloat64":
1029 // TODO(coopernurse): add type-specific converter
1030 return driver.Null{Converter: driver.DefaultParameterConverter}
1031 case "datetime":
1032 return driver.DefaultParameterConverter
1034 panic("invalid fakedb column type of " + typ)
1037 func colTypeToReflectType(typ string) reflect.Type {
1038 switch typ {
1039 case "bool":
1040 return reflect.TypeOf(false)
1041 case "nullbool":
1042 return reflect.TypeOf(NullBool{})
1043 case "int32":
1044 return reflect.TypeOf(int32(0))
1045 case "string":
1046 return reflect.TypeOf("")
1047 case "nullstring":
1048 return reflect.TypeOf(NullString{})
1049 case "int64":
1050 return reflect.TypeOf(int64(0))
1051 case "nullint64":
1052 return reflect.TypeOf(NullInt64{})
1053 case "float64":
1054 return reflect.TypeOf(float64(0))
1055 case "nullfloat64":
1056 return reflect.TypeOf(NullFloat64{})
1057 case "datetime":
1058 return reflect.TypeOf(time.Time{})
1060 panic("invalid fakedb column type of " + typ)