diff --git a/sqlite3.go b/sqlite3.go index 7569b73..bcf3099 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -1123,18 +1123,20 @@ func (s *SQLiteStmt) query(ctx context.Context, args []namedValue) (driver.Rows, done: make(chan struct{}), } - go func(db *C.sqlite3) { - select { - case <-ctx.Done(): + if ctxdone := ctx.Done(); ctxdone != nil { + go func(db *C.sqlite3) { select { + case <-ctxdone: + select { + case <-rows.done: + default: + C.sqlite3_interrupt(db) + rows.Close() + } case <-rows.done: - default: - C.sqlite3_interrupt(db) - rows.Close() } - case <-rows.done: - } - }(s.c.db) + }(s.c.db) + } return rows, nil } @@ -1168,19 +1170,21 @@ func (s *SQLiteStmt) exec(ctx context.Context, args []namedValue) (driver.Result return nil, err } - done := make(chan struct{}) - defer close(done) - go func(db *C.sqlite3) { - select { - case <-done: - case <-ctx.Done(): + if ctxdone := ctx.Done(); ctxdone != nil { + done := make(chan struct{}) + defer close(done) + go func(db *C.sqlite3) { select { case <-done: - default: - C.sqlite3_interrupt(db) + case <-ctxdone: + select { + case <-done: + default: + C.sqlite3_interrupt(db) + } } - } - }(s.c.db) + }(s.c.db) + } var rowid, changes C.longlong rv := C._sqlite3_step(s.s, &rowid, &changes)