From db4c9426f851c76327db11e1b5301e520e0e048b Mon Sep 17 00:00:00 2001 From: gber Date: Thu, 14 May 2020 14:28:04 +0000 Subject: [PATCH] Enable all prefixes for named parameters and allow for unused named parameters (#811) * Allow unused named parameters Try to bind all named parameters and ignore those not used. * Allow "@" and "$" for named parameters * Add tests for named parameters Co-authored-by: Guido Berhoerster --- sqlite3.go | 132 ++++++++++++++++++++++++++++-------------------- sqlite3_test.go | 39 ++++++++++++++ 2 files changed, 117 insertions(+), 54 deletions(-) diff --git a/sqlite3.go b/sqlite3.go index 59e3670..86c0f64 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -802,20 +802,29 @@ func (c *SQLiteConn) exec(ctx context.Context, query string, args []namedValue) } var res driver.Result if s.(*SQLiteStmt).s != nil { + stmtArgs := make([]namedValue, 0, len(args)) na := s.NumInput() - if len(args) < na { + if len(args) - start < na { s.Close() return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args)) } - for i := 0; i < na; i++ { - args[i].Ordinal -= start + // consume the number of arguments used in the current + // statement and append all named arguments not + // contained therein + stmtArgs = append(stmtArgs, args[start:start+na]...) + for i := range args { + if (i < start || i >= na) && args[i].Name != "" { + stmtArgs = append(stmtArgs, args[i]) + } } - res, err = s.(*SQLiteStmt).exec(ctx, args[:na]) + for i := range stmtArgs { + stmtArgs[i].Ordinal = i + 1 + } + res, err = s.(*SQLiteStmt).exec(ctx, stmtArgs) if err != nil && err != driver.ErrSkip { s.Close() return nil, err } - args = args[na:] start += na } tail := s.(*SQLiteStmt).t @@ -848,24 +857,33 @@ func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, erro func (c *SQLiteConn) query(ctx context.Context, query string, args []namedValue) (driver.Rows, error) { start := 0 for { + stmtArgs := make([]namedValue, 0, len(args)) s, err := c.prepare(ctx, query) if err != nil { return nil, err } s.(*SQLiteStmt).cls = true na := s.NumInput() - if len(args) < na { - return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args)) + if len(args) - start < na { + return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args) - start) } - for i := 0; i < na; i++ { - args[i].Ordinal -= start + // consume the number of arguments used in the current + // statement and append all named arguments not contained + // therein + stmtArgs = append(stmtArgs, args[start:start+na]...) + for i := range args { + if (i < start || i >= na) && args[i].Name != "" { + stmtArgs = append(stmtArgs, args[i]) + } } - rows, err := s.(*SQLiteStmt).query(ctx, args[:na]) + for i := range stmtArgs { + stmtArgs[i].Ordinal = i + 1 + } + rows, err := s.(*SQLiteStmt).query(ctx, stmtArgs) if err != nil && err != driver.ErrSkip { s.Close() return rows, err } - args = args[na:] start += na tail := s.(*SQLiteStmt).t if tail == "" { @@ -1778,11 +1796,6 @@ func (s *SQLiteStmt) NumInput() int { return int(C.sqlite3_bind_parameter_count(s.s)) } -type bindArg struct { - n int - v driver.Value -} - var placeHolder = []byte{0} func (s *SQLiteStmt) bind(args []namedValue) error { @@ -1791,52 +1804,63 @@ func (s *SQLiteStmt) bind(args []namedValue) error { return s.c.lastError() } + bindIndices := make([][3]int, len(args)) + prefixes := []string{":", "@", "$"} for i, v := range args { + bindIndices[i][0] = args[i].Ordinal if v.Name != "" { - cname := C.CString(":" + v.Name) - args[i].Ordinal = int(C.sqlite3_bind_parameter_index(s.s, cname)) - C.free(unsafe.Pointer(cname)) + for j := range prefixes { + cname := C.CString(prefixes[j] + v.Name) + bindIndices[i][j] = int(C.sqlite3_bind_parameter_index(s.s, cname)) + C.free(unsafe.Pointer(cname)) + } + args[i].Ordinal = bindIndices[i][0] } } - for _, arg := range args { - n := C.int(arg.Ordinal) - switch v := arg.Value.(type) { - case nil: - rv = C.sqlite3_bind_null(s.s, n) - case string: - if len(v) == 0 { - rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&placeHolder[0])), C.int(0)) - } else { - b := []byte(v) + for i, arg := range args { + for j := range bindIndices[i] { + if bindIndices[i][j] == 0 { + continue + } + n := C.int(bindIndices[i][j]) + switch v := arg.Value.(type) { + case nil: + rv = C.sqlite3_bind_null(s.s, n) + case string: + if len(v) == 0 { + rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&placeHolder[0])), C.int(0)) + } else { + b := []byte(v) + rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b))) + } + case int64: + rv = C.sqlite3_bind_int64(s.s, n, C.sqlite3_int64(v)) + case bool: + if v { + rv = C.sqlite3_bind_int(s.s, n, 1) + } else { + rv = C.sqlite3_bind_int(s.s, n, 0) + } + case float64: + rv = C.sqlite3_bind_double(s.s, n, C.double(v)) + case []byte: + if v == nil { + rv = C.sqlite3_bind_null(s.s, n) + } else { + ln := len(v) + if ln == 0 { + v = placeHolder + } + rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(&v[0]), C.int(ln)) + } + case time.Time: + b := []byte(v.Format(SQLiteTimestampFormats[0])) rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b))) } - case int64: - rv = C.sqlite3_bind_int64(s.s, n, C.sqlite3_int64(v)) - case bool: - if v { - rv = C.sqlite3_bind_int(s.s, n, 1) - } else { - rv = C.sqlite3_bind_int(s.s, n, 0) + if rv != C.SQLITE_OK { + return s.c.lastError() } - case float64: - rv = C.sqlite3_bind_double(s.s, n, C.double(v)) - case []byte: - if v == nil { - rv = C.sqlite3_bind_null(s.s, n) - } else { - ln := len(v) - if ln == 0 { - v = placeHolder - } - rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(&v[0]), C.int(ln)) - } - case time.Time: - b := []byte(v.Format(SQLiteTimestampFormats[0])) - rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b))) - } - if rv != C.SQLITE_OK { - return s.c.lastError() } } return nil diff --git a/sqlite3_test.go b/sqlite3_test.go index 4b8fe01..d5b0cea 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -1778,6 +1778,45 @@ func TestInsertNilByteSlice(t *testing.T) { } } +func TestNamedParam(t *testing.T) { + tempFilename := TempFilename(t) + defer os.Remove(tempFilename) + db, err := sql.Open("sqlite3", tempFilename) + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + _, err = db.Exec("drop table foo") + _, err = db.Exec("create table foo (id integer, name text, amount integer)") + if err != nil { + t.Fatal("Failed to create table:", err) + } + + _, err = db.Exec("insert into foo(id, name, amount) values(:id, @name, $amount)", + sql.Named("bar", 42), sql.Named("baz", "quux"), + sql.Named("amount", 123), sql.Named("corge", "waldo"), + sql.Named("id", 2), sql.Named("name", "grault")) + if err != nil { + t.Fatal("Failed to insert record with named parameters:", err) + } + + rows, err := db.Query("select id, name, amount from foo") + if err != nil { + t.Fatal("Failed to select records:", err) + } + defer rows.Close() + + rows.Next() + + var id, amount int + var name string + rows.Scan(&id, &name, &amount) + if id != 2 || name != "grault" || amount != 123 { + t.Errorf("Expected %d, %q, %d for fetched result, but got %d, %q, %d:", 2, "grault", 123, id, name, amount) + } +} + var customFunctionOnce sync.Once func BenchmarkCustomFunctions(b *testing.B) {