diff --git a/sqlite3_test.go b/sqlite3_test.go index 55536e8..bcbba44 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -1,6 +1,7 @@ package sqlite3 import ( + "./sqltest" "crypto/rand" "database/sql" "encoding/hex" @@ -626,6 +627,16 @@ func TestWAL(t *testing.T) { } } +func TestSuite(t *testing.T) { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + sqltest.RunTests(t, db, sqltest.SQLITE) +} + // TODO: Execer & Queryer currently disabled // https://github.com/mattn/go-sqlite3/issues/82 //func TestExecer(t *testing.T) { diff --git a/sqltest/sqltest.go b/sqltest/sqltest.go new file mode 100644 index 0000000..c101c96 --- /dev/null +++ b/sqltest/sqltest.go @@ -0,0 +1,412 @@ +package sqltest + +import ( + "database/sql" + "fmt" + "math/rand" + "regexp" + "strconv" + "sync" + "testing" + "time" +) + +type Dialect int + +const ( + SQLITE Dialect = iota + POSTGRESQL + MYSQL +) + +type DB struct { + *testing.T + *sql.DB + dialect Dialect + once sync.Once +} + +var db *DB + +// the following tables will be created and dropped during the test +var testTables = []string{"foo", "bar", "t", "bench"} + +var tests = []testing.InternalTest{ + {"TestBlobs", TestBlobs}, + {"TestManyQueryRow", TestManyQueryRow}, + {"TestTxQuery", TestTxQuery}, + {"TestPreparedStmt", TestPreparedStmt}, +} + +var benchmarks = []testing.InternalBenchmark{ + {"BenchmarkExec", BenchmarkExec}, + {"BenchmarkQuery", BenchmarkQuery}, + {"BenchmarkParams", BenchmarkParams}, + {"BenchmarkStmt", BenchmarkStmt}, + {"BenchmarkRows", BenchmarkRows}, + {"BenchmarkStmtRows", BenchmarkStmtRows}, +} + +// RunTests runs the SQL test suite +func RunTests(t *testing.T, d *sql.DB, dialect Dialect) { + db = &DB{t, d, dialect, sync.Once{}} + testing.RunTests(func(string, string) (bool, error) { return true, nil }, tests) + + if !testing.Short() { + for _, b := range benchmarks { + fmt.Printf("%-20s", b.Name) + r := testing.Benchmark(b.F) + fmt.Printf("%10d %10.0f req/s\n", r.N, float64(r.N)/r.T.Seconds()) + } + } + db.tearDown() +} + +func (db *DB) mustExec(sql string, args ...interface{}) sql.Result { + res, err := db.Exec(sql, args...) + if err != nil { + db.Fatalf("Error running %q: %v", sql, err) + } + return res +} + +func (db *DB) tearDown() { + for _, tbl := range testTables { + switch db.dialect { + case SQLITE: + db.mustExec("drop table if exists " + tbl) + case MYSQL, POSTGRESQL: + db.mustExec("drop table if exists " + tbl) + default: + db.Fatal("unkown dialect") + } + } +} + +// q replaces ? parameters if needed +func (db *DB) q(sql string) string { + switch db.dialect { + case POSTGRESQL: // repace with $1, $2, .. + qrx := regexp.MustCompile(`\?`) + n := 0 + return qrx.ReplaceAllStringFunc(sql, func(string) string { + n++ + return "$" + strconv.Itoa(n) + }) + } + return sql +} + +func (db *DB) blobType(size int) string { + switch db.dialect { + case SQLITE: + return fmt.Sprintf("blob[%d]", size) + case POSTGRESQL: + return "bytea" + case MYSQL: + return fmt.Sprintf("VARBINARY(%d)", size) + } + panic("unkown dialect") +} + +func (db *DB) serialPK() string { + switch db.dialect { + case SQLITE: + return "integer primary key autoincrement" + case POSTGRESQL: + return "serial primary key" + case MYSQL: + return "integer primary key auto_increment" + } + panic("unkown dialect") +} + +func (db *DB) now() string { + switch db.dialect { + case SQLITE: + return "datetime('now')" + case POSTGRESQL: + return "now()" + case MYSQL: + return "now()" + } + panic("unkown dialect") +} + +func makeBench() { + if _, err := db.Exec("create table bench (n varchar(32), i integer, d double, s varchar(32), t datetime)"); err != nil { + panic(err) + } + st, err := db.Prepare("insert into bench values (?, ?, ?, ?, ?)") + if err != nil { + panic(err) + } + defer st.Close() + for i := 0; i < 100; i++ { + if _, err = st.Exec(nil, i, float64(i), fmt.Sprintf("%d", i), time.Now()); err != nil { + panic(err) + } + } +} + +func TestResult(t *testing.T) { + db.tearDown() + db.mustExec("create temporary table test (id " + db.serialPK() + ", name varchar(10))") + + for i := 1; i < 3; i++ { + r := db.mustExec(db.q("insert into test (name) values (?)"), fmt.Sprintf("row %d", i)) + n, err := r.RowsAffected() + if err != nil { + t.Fatal(err) + } + if n != 1 { + t.Errorf("got %v, want %v", n, 1) + } + n, err = r.LastInsertId() + if err != nil { + t.Fatal(err) + } + if n != int64(i) { + t.Errorf("got %v, want %v", n, i) + } + } + if _, err := db.Exec("error!"); err == nil { + t.Fatalf("expected error") + } +} + +func TestBlobs(t *testing.T) { + db.tearDown() + var blob = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + db.mustExec("create table foo (id integer primary key, bar " + db.blobType(16) + ")") + db.mustExec(db.q("insert into foo (id, bar) values(?,?)"), 0, blob) + + want := fmt.Sprintf("%x", blob) + + b := make([]byte, 16) + err := db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&b) + got := fmt.Sprintf("%x", b) + if err != nil { + t.Errorf("[]byte scan: %v", err) + } else if got != want { + t.Errorf("for []byte, got %q; want %q", got, want) + } + + err = db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&got) + want = string(blob) + if err != nil { + t.Errorf("string scan: %v", err) + } else if got != want { + t.Errorf("for string, got %q; want %q", got, want) + } +} + +func TestManyQueryRow(t *testing.T) { + if testing.Short() { + t.Log("skipping in short mode") + return + } + db.tearDown() + db.mustExec("create table foo (id integer primary key, name varchar(50))") + db.mustExec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob") + var name string + for i := 0; i < 10000; i++ { + err := db.QueryRow(db.q("select name from foo where id = ?"), 1).Scan(&name) + if err != nil || name != "bob" { + t.Fatalf("on query %d: err=%v, name=%q", i, err, name) + } + } +} + +func TestTxQuery(t *testing.T) { + db.tearDown() + tx, err := db.Begin() + if err != nil { + t.Fatal(err) + } + defer tx.Rollback() + + _, err = tx.Exec("create table foo (id integer primary key, name varchar(50))") + if err != nil { + t.Fatal(err) + } + + _, err = tx.Exec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob") + if err != nil { + t.Fatal(err) + } + + r, err := tx.Query(db.q("select name from foo where id = ?"), 1) + if err != nil { + t.Fatal(err) + } + defer r.Close() + + if !r.Next() { + if r.Err() != nil { + t.Fatal(err) + } + t.Fatal("expected one rows") + } + + var name string + err = r.Scan(&name) + if err != nil { + t.Fatal(err) + } +} + +func TestPreparedStmt(t *testing.T) { + db.tearDown() + db.mustExec("CREATE TABLE t (count INT)") + sel, err := db.Prepare("SELECT count FROM t ORDER BY count DESC") + if err != nil { + t.Fatalf("prepare 1: %v", err) + } + ins, err := db.Prepare(db.q("INSERT INTO t (count) VALUES (?)")) + if err != nil { + t.Fatalf("prepare 2: %v", err) + } + + for n := 1; n <= 3; n++ { + if _, err := ins.Exec(n); err != nil { + t.Fatalf("insert(%d) = %v", n, err) + } + } + + const nRuns = 10 + ch := make(chan bool) + for i := 0; i < nRuns; i++ { + go func() { + defer func() { + ch <- true + }() + for j := 0; j < 10; j++ { + count := 0 + if err := sel.QueryRow().Scan(&count); err != nil && err != sql.ErrNoRows { + t.Errorf("Query: %v", err) + return + } + if _, err := ins.Exec(rand.Intn(100)); err != nil { + t.Errorf("Insert: %v", err) + return + } + } + }() + } + for i := 0; i < nRuns; i++ { + <-ch + } +} + +// Benchmarks need to use panic() since b.Error errors are lost when +// running via testing.Benchmark() I would like to run these via go +// test -bench but calling Benchmark() from a benchmark test +// currently hangs go. + +func BenchmarkExec(b *testing.B) { + for i := 0; i < b.N; i++ { + if _, err := db.Exec("select 1"); err != nil { + panic(err) + } + } +} + +func BenchmarkQuery(b *testing.B) { + for i := 0; i < b.N; i++ { + var n sql.NullString + var i int + var f float64 + var s string +// var t time.Time + if err := db.QueryRow("select null, 1, 1.1, 'foo'").Scan(&n, &i, &f, &s); err != nil { + panic(err) + } + } +} + +func BenchmarkParams(b *testing.B) { + for i := 0; i < b.N; i++ { + var n sql.NullString + var i int + var f float64 + var s string +// var t time.Time + if err := db.QueryRow("select ?, ?, ?, ?", nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil { + panic(err) + } + } +} + +func BenchmarkStmt(b *testing.B) { + st, err := db.Prepare("select ?, ?, ?, ?") + if err != nil { + panic(err) + } + defer st.Close() + + for n := 0; n < b.N; n++ { + var n sql.NullString + var i int + var f float64 + var s string +// var t time.Time + if err := st.QueryRow(nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil { + panic(err) + } + } +} + +func BenchmarkRows(b *testing.B) { + db.once.Do(makeBench) + + for n := 0; n < b.N; n++ { + var n sql.NullString + var i int + var f float64 + var s string + var t time.Time + r, err := db.Query("select * from bench") + if err != nil { + panic(err) + } + for r.Next() { + if err = r.Scan(&n, &i, &f, &s, &t); err != nil { + panic(err) + } + } + if err = r.Err(); err != nil { + panic(err) + } + } +} + +func BenchmarkStmtRows(b *testing.B) { + db.once.Do(makeBench) + + st, err := db.Prepare("select * from bench") + if err != nil { + panic(err) + } + defer st.Close() + + for n := 0; n < b.N; n++ { + var n sql.NullString + var i int + var f float64 + var s string + var t time.Time + r, err := st.Query() + if err != nil { + panic(err) + } + for r.Next() { + if err = r.Scan(&n, &i, &f, &s, &t); err != nil { + panic(err) + } + } + if err = r.Err(); err != nil { + panic(err) + } + } +}