diff options
Diffstat (limited to 'libgo/go/database/sql/sql_test.go')
-rw-r--r-- | libgo/go/database/sql/sql_test.go | 555 |
1 files changed, 533 insertions, 22 deletions
diff --git a/libgo/go/database/sql/sql_test.go b/libgo/go/database/sql/sql_test.go index 08df0c7666a..63e1292cb1f 100644 --- a/libgo/go/database/sql/sql_test.go +++ b/libgo/go/database/sql/sql_test.go @@ -5,6 +5,7 @@ package sql import ( + "context" "database/sql/driver" "errors" "fmt" @@ -23,6 +24,17 @@ func init() { c *driverConn } freedFrom := make(map[dbConn]string) + var mu sync.Mutex + getFreedFrom := func(c dbConn) string { + mu.Lock() + defer mu.Unlock() + return freedFrom[c] + } + setFreedFrom := func(c dbConn, s string) { + mu.Lock() + defer mu.Unlock() + freedFrom[c] = s + } putConnHook = func(db *DB, c *driverConn) { idx := -1 for i, v := range db.freeConn { @@ -35,10 +47,10 @@ func init() { // print before panic, as panic may get lost due to conflicting panic // (all goroutines asleep) elsewhere, since we might not unlock // the mutex in freeConn here. - println("double free of conn. conflicts are:\nA) " + freedFrom[dbConn{db, c}] + "\n\nand\nB) " + stack()) + println("double free of conn. conflicts are:\nA) " + getFreedFrom(dbConn{db, c}) + "\n\nand\nB) " + stack()) panic("double free of conn.") } - freedFrom[dbConn{db, c}] = stack() + setFreedFrom(dbConn{db, c}, stack()) } } @@ -140,10 +152,7 @@ func closeDB(t testing.TB, db *DB) { if err != nil { t.Fatalf("error closing DB: %v", err) } - db.mu.Lock() - count := db.numOpen - db.mu.Unlock() - if count != 0 { + if count := db.numOpenConns(); count != 0 { t.Fatalf("%d connections still open after closing DB", count) } } @@ -182,6 +191,12 @@ func (db *DB) numFreeConns() int { return len(db.freeConn) } +func (db *DB) numOpenConns() int { + db.mu.Lock() + defer db.mu.Unlock() + return db.numOpen +} + // clearAllConns closes all connections in db. func (db *DB) clearAllConns(t *testing.T) { db.SetMaxIdleConns(0) @@ -260,6 +275,257 @@ func TestQuery(t *testing.T) { } } +func TestQueryContext(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + prepares0 := numPrepares(t, db) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + rows, err := db.QueryContext(ctx, "SELECT|people|age,name|") + if err != nil { + t.Fatalf("Query: %v", err) + } + type row struct { + age int + name string + } + got := []row{} + index := 0 + for rows.Next() { + if index == 2 { + cancel() + time.Sleep(10 * time.Millisecond) + } + var r row + err = rows.Scan(&r.age, &r.name) + if err != nil { + if index == 2 { + break + } + t.Fatalf("Scan: %v", err) + } + if index == 2 && err == nil { + t.Fatal("expected an error on last scan") + } + got = append(got, r) + index++ + } + err = rows.Err() + if err != nil { + t.Fatalf("Err: %v", err) + } + want := []row{ + {age: 1, name: "Alice"}, + {age: 2, name: "Bob"}, + } + if !reflect.DeepEqual(got, want) { + t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want) + } + + // And verify that the final rows.Next() call, which hit EOF, + // also closed the rows connection. + if n := db.numFreeConns(); n != 1 { + t.Fatalf("free conns after query hitting EOF = %d; want 1", n) + } + if prepares := numPrepares(t, db) - prepares0; prepares != 1 { + t.Errorf("executed %d Prepare statements; want 1", prepares) + } +} + +func waitCondition(waitFor, checkEvery time.Duration, fn func() bool) bool { + deadline := time.Now().Add(waitFor) + for time.Now().Before(deadline) { + if fn() { + return true + } + time.Sleep(checkEvery) + } + return false +} + +func TestQueryContextWait(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + prepares0 := numPrepares(t, db) + + ctx, _ := context.WithTimeout(context.Background(), time.Millisecond*15) + + // This will trigger the *fakeConn.Prepare method which will take time + // performing the query. The ctxDriverPrepare func will check the context + // after this and close the rows and return an error. + _, err := db.QueryContext(ctx, "WAIT|1s|SELECT|people|age,name|") + if err != context.DeadlineExceeded { + t.Fatalf("expected QueryContext to error with context deadline exceeded but returned %v", err) + } + + // Verify closed rows connection after error condition. + if n := db.numFreeConns(); n != 1 { + t.Fatalf("free conns after query hitting EOF = %d; want 1", n) + } + if prepares := numPrepares(t, db) - prepares0; prepares != 1 { + t.Errorf("executed %d Prepare statements; want 1", prepares) + } +} + +func TestTxContextWait(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + ctx, _ := context.WithTimeout(context.Background(), time.Millisecond*15) + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + t.Fatal(err) + } + + // This will trigger the *fakeConn.Prepare method which will take time + // performing the query. The ctxDriverPrepare func will check the context + // after this and close the rows and return an error. + _, err = tx.QueryContext(ctx, "WAIT|1s|SELECT|people|age,name|") + if err != context.DeadlineExceeded { + t.Fatalf("expected QueryContext to error with context deadline exceeded but returned %v", err) + } + + var numFree int + if !waitCondition(5*time.Second, 5*time.Millisecond, func() bool { + numFree = db.numFreeConns() + return numFree == 0 + }) { + t.Fatalf("free conns after hitting EOF = %d; want 0", numFree) + } + + // Ensure the dropped connection allows more connections to be made. + // Checked on DB Close. + waitCondition(5*time.Second, 5*time.Millisecond, func() bool { + return db.numOpenConns() == 0 + }) +} + +func TestMultiResultSetQuery(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + prepares0 := numPrepares(t, db) + rows, err := db.Query("SELECT|people|age,name|;SELECT|people|name|") + if err != nil { + t.Fatalf("Query: %v", err) + } + type row1 struct { + age int + name string + } + type row2 struct { + name string + } + got1 := []row1{} + for rows.Next() { + var r row1 + err = rows.Scan(&r.age, &r.name) + if err != nil { + t.Fatalf("Scan: %v", err) + } + got1 = append(got1, r) + } + err = rows.Err() + if err != nil { + t.Fatalf("Err: %v", err) + } + want1 := []row1{ + {age: 1, name: "Alice"}, + {age: 2, name: "Bob"}, + {age: 3, name: "Chris"}, + } + if !reflect.DeepEqual(got1, want1) { + t.Errorf("mismatch.\n got1: %#v\nwant: %#v", got1, want1) + } + + if !rows.NextResultSet() { + t.Errorf("expected another result set") + } + + got2 := []row2{} + for rows.Next() { + var r row2 + err = rows.Scan(&r.name) + if err != nil { + t.Fatalf("Scan: %v", err) + } + got2 = append(got2, r) + } + err = rows.Err() + if err != nil { + t.Fatalf("Err: %v", err) + } + want2 := []row2{ + {name: "Alice"}, + {name: "Bob"}, + {name: "Chris"}, + } + if !reflect.DeepEqual(got2, want2) { + t.Errorf("mismatch.\n got: %#v\nwant: %#v", got2, want2) + } + if rows.NextResultSet() { + t.Errorf("expected no more result sets") + } + + // And verify that the final rows.Next() call, which hit EOF, + // also closed the rows connection. + if n := db.numFreeConns(); n != 1 { + t.Fatalf("free conns after query hitting EOF = %d; want 1", n) + } + if prepares := numPrepares(t, db) - prepares0; prepares != 1 { + t.Errorf("executed %d Prepare statements; want 1", prepares) + } +} + +func TestQueryNamedArg(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + prepares0 := numPrepares(t, db) + rows, err := db.Query( + // Ensure the name and age parameters only match on placeholder name, not position. + "SELECT|people|age,name|name=?name,age=?age", + Named("age", 2), + Named("name", "Bob"), + ) + if err != nil { + t.Fatalf("Query: %v", err) + } + type row struct { + age int + name string + } + got := []row{} + for rows.Next() { + var r row + err = rows.Scan(&r.age, &r.name) + if err != nil { + t.Fatalf("Scan: %v", err) + } + got = append(got, r) + } + err = rows.Err() + if err != nil { + t.Fatalf("Err: %v", err) + } + want := []row{ + {age: 2, name: "Bob"}, + } + if !reflect.DeepEqual(got, want) { + t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want) + } + + // And verify that the final rows.Next() call, which hit EOF, + // also closed the rows connection. + if n := db.numFreeConns(); n != 1 { + t.Fatalf("free conns after query hitting EOF = %d; want 1", n) + } + if prepares := numPrepares(t, db) - prepares0; prepares != 1 { + t.Errorf("executed %d Prepare statements; want 1", prepares) + } +} + func TestByteOwnership(t *testing.T) { db := newTestDB(t, "people") defer closeDB(t, db) @@ -317,6 +583,56 @@ func TestRowsColumns(t *testing.T) { } } +func TestRowsColumnTypes(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + rows, err := db.Query("SELECT|people|age,name|") + if err != nil { + t.Fatalf("Query: %v", err) + } + tt, err := rows.ColumnTypes() + if err != nil { + t.Fatalf("ColumnTypes: %v", err) + } + + types := make([]reflect.Type, len(tt)) + for i, tp := range tt { + st := tp.ScanType() + if st == nil { + t.Errorf("scantype is null for column %q", tp.Name()) + continue + } + types[i] = st + } + values := make([]interface{}, len(tt)) + for i := range values { + values[i] = reflect.New(types[i]).Interface() + } + ct := 0 + for rows.Next() { + err = rows.Scan(values...) + if err != nil { + t.Fatalf("failed to scan values in %v", err) + } + ct++ + if ct == 0 { + if values[0].(string) != "Bob" { + t.Errorf("Expected Bob, got %v", values[0]) + } + if values[1].(int) != 2 { + t.Errorf("Expected 2, got %v", values[1]) + } + } + } + if ct != 3 { + t.Errorf("expected 3 rows, got %d", ct) + } + + if err := rows.Close(); err != nil { + t.Errorf("error closing rows: %s", err) + } +} + func TestQueryRow(t *testing.T) { db := newTestDB(t, "people") defer closeDB(t, db) @@ -367,6 +683,37 @@ func TestQueryRow(t *testing.T) { } } +func TestTxRollbackCommitErr(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + tx, err := db.Begin() + if err != nil { + t.Fatal(err) + } + err = tx.Rollback() + if err != nil { + t.Errorf("expected nil error from Rollback; got %v", err) + } + err = tx.Commit() + if err != ErrTxDone { + t.Errorf("expected %q from Commit; got %q", ErrTxDone, err) + } + + tx, err = db.Begin() + if err != nil { + t.Fatal(err) + } + err = tx.Commit() + if err != nil { + t.Errorf("expected nil error from Commit; got %v", err) + } + err = tx.Rollback() + if err != ErrTxDone { + t.Errorf("expected %q from Rollback; got %q", ErrTxDone, err) + } +} + func TestStatementErrorAfterClose(t *testing.T) { db := newTestDB(t, "people") defer closeDB(t, db) @@ -439,7 +786,7 @@ func TestStatementClose(t *testing.T) { msg string }{ {&Stmt{stickyErr: want}, "stickyErr not propagated"}, - {&Stmt{tx: &Tx{}, txsi: &driverStmt{&sync.Mutex{}, stubDriverStmt{want}}}, "driverStmt.Close() error not propagated"}, + {&Stmt{tx: &Tx{}, txds: &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{want}}}, "driverStmt.Close() error not propagated"}, } for _, test := range tests { if err := test.stmt.Close(); err != want { @@ -513,8 +860,8 @@ func TestExec(t *testing.T) { {[]interface{}{7, 9}, ""}, // Invalid conversions: - {[]interface{}{"Brad", int64(0xFFFFFFFF)}, "sql: converting argument #1's type: sql/driver: value 4294967295 overflows int32"}, - {[]interface{}{"Brad", "strconv fail"}, "sql: converting argument #1's type: sql/driver: value \"strconv fail\" can't be converted to int32"}, + {[]interface{}{"Brad", int64(0xFFFFFFFF)}, "sql: converting argument $2 type: sql/driver: value 4294967295 overflows int32"}, + {[]interface{}{"Brad", "strconv fail"}, `sql: converting argument $2 type: sql/driver: value "strconv fail" can't be converted to int32`}, // Wrong number of args: {[]interface{}{}, "sql: expected 2 arguments, got 0"}, @@ -1159,17 +1506,19 @@ func TestMaxOpenConnsOnBusy(t *testing.T) { db.SetMaxOpenConns(3) - conn0, err := db.conn(cachedOrNewConn) + ctx := context.Background() + + conn0, err := db.conn(ctx, cachedOrNewConn) if err != nil { t.Fatalf("db open conn fail: %v", err) } - conn1, err := db.conn(cachedOrNewConn) + conn1, err := db.conn(ctx, cachedOrNewConn) if err != nil { t.Fatalf("db open conn fail: %v", err) } - conn2, err := db.conn(cachedOrNewConn) + conn2, err := db.conn(ctx, cachedOrNewConn) if err != nil { t.Fatalf("db open conn fail: %v", err) } @@ -1203,7 +1552,11 @@ func TestPendingConnsAfterErr(t *testing.T) { tryOpen = maxOpen*2 + 2 ) - db := newTestDB(t, "people") + // No queries will be run. + db, err := Open("test", fakeDBName) + if err != nil { + t.Fatalf("Open: %v", err) + } defer closeDB(t, db) defer func() { for k, v := range db.lastPut { @@ -1215,29 +1568,29 @@ func TestPendingConnsAfterErr(t *testing.T) { db.SetMaxIdleConns(0) errOffline := errors.New("db offline") + defer func() { setHookOpenErr(nil) }() errs := make(chan error, tryOpen) - unblock := make(chan struct{}) + var opening sync.WaitGroup + opening.Add(tryOpen) + setHookOpenErr(func() error { - <-unblock // block until all connections are in flight + // Wait for all connections to enqueue. + opening.Wait() return errOffline }) - var opening sync.WaitGroup - opening.Add(tryOpen) for i := 0; i < tryOpen; i++ { go func() { opening.Done() // signal one connection is in flight - _, err := db.Exec("INSERT|people|name=Julia,age=19") + _, err := db.Exec("will never run") errs <- err }() } - opening.Wait() // wait for all workers to begin running - time.Sleep(10 * time.Millisecond) // make extra sure all workers are blocked - close(unblock) // let all workers proceed + opening.Wait() // wait for all workers to begin running const timeout = 5 * time.Second to := time.NewTimer(timeout) @@ -1254,6 +1607,24 @@ func TestPendingConnsAfterErr(t *testing.T) { t.Fatalf("orphaned connection request(s), still waiting after %v", timeout) } } + + // Wait a reasonable time for the database to close all connections. + tick := time.NewTicker(3 * time.Millisecond) + defer tick.Stop() + for { + select { + case <-tick.C: + db.mu.Lock() + if db.numOpen == 0 { + db.mu.Unlock() + return + } + db.mu.Unlock() + case <-to.C: + // Closing the database will check for numOpen and fail the test. + return + } + } } func TestSingleOpenConn(t *testing.T) { @@ -2236,6 +2607,54 @@ func TestIssue6081(t *testing.T) { } } +// TestIssue18429 attempts to stress rolling back the transaction from a +// context cancel while simultaneously calling Tx.Rollback. Rolling back from a +// context happens concurrently so tx.rollback and tx.Commit must guard against +// double entry. +// +// In the test, a context is canceled while the query is in process so +// the internal rollback will run concurrently with the explicitly called +// Tx.Rollback. +func TestIssue18429(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + ctx := context.Background() + sem := make(chan bool, 20) + var wg sync.WaitGroup + + const milliWait = 30 + + for i := 0; i < 100; i++ { + sem <- true + wg.Add(1) + go func() { + defer func() { + <-sem + wg.Done() + }() + qwait := (time.Duration(rand.Intn(milliWait)) * time.Millisecond).String() + + ctx, cancel := context.WithTimeout(ctx, time.Duration(rand.Intn(milliWait))*time.Millisecond) + defer cancel() + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return + } + rows, err := tx.QueryContext(ctx, "WAIT|"+qwait+"|SELECT|people|name|") + if rows != nil { + rows.Close() + } + // This call will race with the context cancel rollback to complete + // if the rollback itself isn't guarded. + tx.Rollback() + }() + } + wg.Wait() + time.Sleep(milliWait * 3 * time.Millisecond) +} + func TestConcurrency(t *testing.T) { doConcurrentTest(t, new(concurrentDBQueryTest)) doConcurrentTest(t, new(concurrentDBExecTest)) @@ -2279,7 +2698,8 @@ func TestConnectionLeak(t *testing.T) { go func() { r, err := db.Query("SELECT|people|name|") if err != nil { - t.Fatal(err) + t.Error(err) + return } r.Close() wg.Done() @@ -2299,6 +2719,97 @@ func TestConnectionLeak(t *testing.T) { wg.Wait() } +// badConn implements a bad driver.Conn, for TestBadDriver. +// The Exec method panics. +type badConn struct{} + +func (bc badConn) Prepare(query string) (driver.Stmt, error) { + return nil, errors.New("badConn Prepare") +} + +func (bc badConn) Close() error { + return nil +} + +func (bc badConn) Begin() (driver.Tx, error) { + return nil, errors.New("badConn Begin") +} + +func (bc badConn) Exec(query string, args []driver.Value) (driver.Result, error) { + panic("badConn.Exec") +} + +// badDriver is a driver.Driver that uses badConn. +type badDriver struct{} + +func (bd badDriver) Open(name string) (driver.Conn, error) { + return badConn{}, nil +} + +// Issue 15901. +func TestBadDriver(t *testing.T) { + Register("bad", badDriver{}) + db, err := Open("bad", "ignored") + if err != nil { + t.Fatal(err) + } + defer func() { + if r := recover(); r == nil { + t.Error("expected panic") + } else { + if want := "badConn.Exec"; r.(string) != want { + t.Errorf("panic was %v, expected %v", r, want) + } + } + }() + defer db.Close() + db.Exec("ignored") +} + +type pingDriver struct { + fails bool +} + +type pingConn struct { + badConn + driver *pingDriver +} + +var pingError = errors.New("Ping failed") + +func (pc pingConn) Ping(ctx context.Context) error { + if pc.driver.fails { + return pingError + } + return nil +} + +var _ driver.Pinger = pingConn{} + +func (pd *pingDriver) Open(name string) (driver.Conn, error) { + return pingConn{driver: pd}, nil +} + +func TestPing(t *testing.T) { + driver := &pingDriver{} + Register("ping", driver) + + db, err := Open("ping", "ignored") + if err != nil { + t.Fatal(err) + } + + if err := db.Ping(); err != nil { + t.Errorf("err was %#v, expected nil", err) + return + } + + driver.fails = true + if err := db.Ping(); err != pingError { + t.Errorf("err was %#v, expected pingError", err) + } +} + func BenchmarkConcurrentDBExec(b *testing.B) { b.ReportAllocs() ct := new(concurrentDBExecTest) |