report actual error message if sqlite3_load_extension fails (#800)

* report actual error message if sqlite3_load_extension fails

* more fixes and test cases

Co-authored-by: Jesse Rittner <jrittner@lutron.com>
use-ignore
rittneje 2020-04-16 01:45:59 -04:00 committed by GitHub
parent 58b2310c97
commit 98a44bcf59
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 113 additions and 26 deletions

View File

@ -1,22 +1,27 @@
ifeq ($(OS),Windows_NT) ifeq ($(OS),Windows_NT)
EXE=extension.exe EXE=extension.exe
EXT=sqlite3_mod_regexp.dll LIB_EXT=dll
RM=cmd /c del RM=cmd /c del
LDFLAG= LDFLAG=
else else
EXE=extension EXE=extension
EXT=sqlite3_mod_regexp.so ifeq ($(shell uname -s),Darwin)
RM=rm LIB_EXT=dylib
else
LIB_EXT=so
endif
RM=rm -f
LDFLAG=-fPIC LDFLAG=-fPIC
endif endif
LIB=sqlite3_mod_regexp.$(LIB_EXT)
all : $(EXE) $(EXT) all : $(EXE) $(LIB)
$(EXE) : extension.go $(EXE) : extension.go
go build $< go build $<
$(EXT) : sqlite3_mod_regexp.c $(LIB) : sqlite3_mod_regexp.c
gcc $(LDFLAG) -shared -o $@ $< -lsqlite3 -lpcre gcc $(LDFLAG) -shared -o $@ $< -lsqlite3 -lpcre
clean : clean :
@-$(RM) $(EXE) $(EXT) @-$(RM) $(EXE) $(LIB)

View File

@ -1,24 +1,29 @@
ifeq ($(OS),Windows_NT) ifeq ($(OS),Windows_NT)
EXE=extension.exe EXE=extension.exe
EXT=sqlite3_mod_vtable.dll LIB_EXT=dll
RM=cmd /c del RM=cmd /c del
LIBCURL=-lcurldll LIBCURL=-lcurldll
LDFLAG= LDFLAG=
else else
EXE=extension EXE=extension
EXT=sqlite3_mod_vtable.so ifeq ($(shell uname -s),Darwin)
RM=rm LIB_EXT=dylib
else
LIB_EXT=so
endif
RM=rm -f
LDFLAG=-fPIC LDFLAG=-fPIC
LIBCURL=-lcurl LIBCURL=-lcurl
endif endif
LIB=sqlite3_mod_vtable.$(LIB_EXT)
all : $(EXE) $(EXT) all : $(EXE) $(LIB)
$(EXE) : extension.go $(EXE) : extension.go
go build $< go build $<
$(EXT) : sqlite3_mod_vtable.cc $(LIB) : sqlite3_mod_vtable.cc
g++ $(LDFLAG) -shared -o $@ $< -lsqlite3 $(LIBCURL) g++ $(LDFLAG) -shared -o $@ $< -lsqlite3 $(LIBCURL)
clean : clean :
@-$(RM) $(EXE) $(EXT) @-$(RM) $(EXE) $(LIB)

View File

@ -1,6 +1,6 @@
#include <string> #include <string>
#include <sstream> #include <sstream>
#include <sqlite3-binding.h> #include <sqlite3.h>
#include <sqlite3ext.h> #include <sqlite3ext.h>
#include <curl/curl.h> #include <curl/curl.h>
#include "picojson.h" #include "picojson.h"

View File

@ -28,12 +28,9 @@ func (c *SQLiteConn) loadExtensions(extensions []string) error {
} }
for _, extension := range extensions { for _, extension := range extensions {
cext := C.CString(extension) if err := c.loadExtension(extension, nil); err != nil {
defer C.free(unsafe.Pointer(cext))
rv = C.sqlite3_load_extension(c.db, cext, nil, nil)
if rv != C.SQLITE_OK {
C.sqlite3_enable_load_extension(c.db, 0) C.sqlite3_enable_load_extension(c.db, 0)
return errors.New(C.GoString(C.sqlite3_errmsg(c.db))) return err
} }
} }
@ -41,6 +38,7 @@ func (c *SQLiteConn) loadExtensions(extensions []string) error {
if rv != C.SQLITE_OK { if rv != C.SQLITE_OK {
return errors.New(C.GoString(C.sqlite3_errmsg(c.db))) return errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
} }
return nil return nil
} }
@ -51,14 +49,9 @@ func (c *SQLiteConn) LoadExtension(lib string, entry string) error {
return errors.New(C.GoString(C.sqlite3_errmsg(c.db))) return errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
} }
clib := C.CString(lib) if err := c.loadExtension(lib, &entry); err != nil {
defer C.free(unsafe.Pointer(clib)) C.sqlite3_enable_load_extension(c.db, 0)
centry := C.CString(entry) return err
defer C.free(unsafe.Pointer(centry))
rv = C.sqlite3_load_extension(c.db, clib, centry, nil)
if rv != C.SQLITE_OK {
return errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
} }
rv = C.sqlite3_enable_load_extension(c.db, 0) rv = C.sqlite3_enable_load_extension(c.db, 0)
@ -68,3 +61,24 @@ func (c *SQLiteConn) LoadExtension(lib string, entry string) error {
return nil return nil
} }
func (c *SQLiteConn) loadExtension(lib string, entry *string) error {
clib := C.CString(lib)
defer C.free(unsafe.Pointer(clib))
var centry *C.char
if entry != nil {
centry := C.CString(*entry)
defer C.free(unsafe.Pointer(centry))
}
var errMsg *C.char
defer C.sqlite3_free(unsafe.Pointer(errMsg))
rv := C.sqlite3_load_extension(c.db, clib, centry, &errMsg)
if rv != C.SQLITE_OK {
return errors.New(C.GoString(errMsg))
}
return nil
}

View File

@ -0,0 +1,63 @@
// Copyright (C) 2019 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
// +build !sqlite_omit_load_extension
package sqlite3
import (
"database/sql"
"testing"
)
func TestExtensionsError(t *testing.T) {
sql.Register("sqlite3_TestExtensionsError",
&SQLiteDriver{
Extensions: []string{
"foobar",
},
},
)
db, err := sql.Open("sqlite3_TestExtensionsError", ":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Ping()
if err == nil {
t.Fatal("expected error loading non-existent extension")
}
if err.Error() == "not an error" {
t.Fatal("expected error from sqlite3_enable_load_extension to be returned")
}
}
func TestLoadExtensionError(t *testing.T) {
sql.Register("sqlite3_TestLoadExtensionError",
&SQLiteDriver{
ConnectHook: func(c *SQLiteConn) error {
return c.LoadExtension("foobar", "")
},
},
)
db, err := sql.Open("sqlite3_TestLoadExtensionError", ":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Ping()
if err == nil {
t.Fatal("expected error loading non-existent extension")
}
if err.Error() == "not an error" {
t.Fatal("expected error from sqlite3_enable_load_extension to be returned")
}
}