From b037a616903746de8e647f53503d4edca29192ec Mon Sep 17 00:00:00 2001 From: David Anderson Date: Fri, 21 Aug 2015 17:12:18 -0700 Subject: [PATCH] Add support for interface{} arguments in Go SQLite functions. This enabled support for functions like Foo(a interface{}) and Bar(a ...interface{}). --- callback.go | 24 ++++++++++++++++++++++++ sqlite3.go | 13 ++++++++----- sqlite3_test.go | 28 ++++++++++++++++++++++++++++ 3 files changed, 60 insertions(+), 5 deletions(-) diff --git a/callback.go b/callback.go index 1692106..b1704fe 100644 --- a/callback.go +++ b/callback.go @@ -108,8 +108,32 @@ func callbackArgString(v *C.sqlite3_value) (reflect.Value, error) { } } +func callbackArgGeneric(v *C.sqlite3_value) (reflect.Value, error) { + switch C.sqlite3_value_type(v) { + case C.SQLITE_INTEGER: + return callbackArgInt64(v) + case C.SQLITE_FLOAT: + return callbackArgFloat64(v) + case C.SQLITE_TEXT: + return callbackArgString(v) + case C.SQLITE_BLOB: + return callbackArgBytes(v) + case C.SQLITE_NULL: + // Interpret NULL as a nil byte slice. + var ret []byte + return reflect.ValueOf(ret), nil + default: + panic("unreachable") + } +} + func callbackArg(typ reflect.Type) (callbackArgConverter, error) { switch typ.Kind() { + case reflect.Interface: + if typ.NumMethod() != 0 { + return nil, errors.New("the only supported interface type is interface{}") + } + return callbackArgGeneric, nil case reflect.Slice: if typ.Elem().Kind() != reflect.Uint8 { return nil, errors.New("the only supported slice type is []byte") diff --git a/sqlite3.go b/sqlite3.go index 8bb9826..73e67e3 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -232,11 +232,14 @@ func (tx *SQLiteTx) Rollback() error { // RegisterFunc makes a Go function available as a SQLite function. // -// The function can accept arguments of any real numeric type -// (i.e. not complex), as well as []byte and string. It must return a -// value of one of those types, and optionally an error as a second -// value. Variadic functions are allowed, if the variadic argument is -// one of the allowed types. +// The Go function can have arguments of the following types: any +// numeric type except complex, bool, []byte, string and +// interface{}. interface{} arguments are given the direct translation +// of the SQLite data type: int64 for INTEGER, float64 for FLOAT, +// []byte for BLOB, string for TEXT. +// +// The function can additionally be variadic, as long as the type of +// the variadic argument is one of the above. // // If pure is true. SQLite will assume that the function's return // value depends only on its inputs, and make more aggressive diff --git a/sqlite3_test.go b/sqlite3_test.go index a563c08..62db05b 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -1071,6 +1071,20 @@ func TestFunctionRegistration(t *testing.T) { regex := func(re, s string) (bool, error) { return regexp.MatchString(re, s) } + generic := func(a interface{}) int64 { + switch a.(type) { + case int64: + return 1 + case float64: + return 2 + case []byte: + return 3 + case string: + return 4 + default: + panic("unreachable") + } + } variadic := func(a, b int64, c ...int64) int64 { ret := a + b for _, d := range c { @@ -1078,6 +1092,9 @@ func TestFunctionRegistration(t *testing.T) { } return ret } + variadicGeneric := func(a ...interface{}) int64 { + return int64(len(a)) + } sql.Register("sqlite3_FunctionRegistration", &SQLiteDriver{ ConnectHook: func(conn *SQLiteConn) error { @@ -1105,9 +1122,15 @@ func TestFunctionRegistration(t *testing.T) { if err := conn.RegisterFunc("regex", regex, true); err != nil { return err } + if err := conn.RegisterFunc("generic", generic, true); err != nil { + return err + } if err := conn.RegisterFunc("variadic", variadic, true); err != nil { return err } + if err := conn.RegisterFunc("variadicGeneric", variadicGeneric, true); err != nil { + return err + } return nil }, }) @@ -1131,9 +1154,14 @@ func TestFunctionRegistration(t *testing.T) { {"SELECT not(0)", true}, {`SELECT regex("^foo.*", "foobar")`, true}, {`SELECT regex("^foo.*", "barfoobar")`, false}, + {"SELECT generic(1)", int64(1)}, + {"SELECT generic(1.1)", int64(2)}, + {`SELECT generic(NULL)`, int64(3)}, + {`SELECT generic("foo")`, int64(4)}, {"SELECT variadic(1,2)", int64(3)}, {"SELECT variadic(1,2,3,4)", int64(10)}, {"SELECT variadic(1,1,1,1,1,1,1,1,1,1)", int64(10)}, + {`SELECT variadicGeneric(1,"foo",2.3, NULL)`, int64(4)}, } for _, op := range ops {