report actual error message if sqlite3_load_extension fails (#800)

* report actual error message if sqlite3_load_extension fails

* more fixes and test cases

Co-authored-by: Jesse Rittner <jrittner@lutron.com>
use-ignore
rittneje 2020-04-16 01:45:59 -04:00 committed by GitHub
parent 58b2310c97
commit 98a44bcf59
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 113 additions and 26 deletions

View File

@ -1,22 +1,27 @@
ifeq ($(OS),Windows_NT)
EXE=extension.exe
EXT=sqlite3_mod_regexp.dll
LIB_EXT=dll
RM=cmd /c del
LDFLAG=
else
EXE=extension
EXT=sqlite3_mod_regexp.so
RM=rm
ifeq ($(shell uname -s),Darwin)
LIB_EXT=dylib
else
LIB_EXT=so
endif
RM=rm -f
LDFLAG=-fPIC
endif
LIB=sqlite3_mod_regexp.$(LIB_EXT)
all : $(EXE) $(EXT)
all : $(EXE) $(LIB)
$(EXE) : extension.go
go build $<
$(EXT) : sqlite3_mod_regexp.c
$(LIB) : sqlite3_mod_regexp.c
gcc $(LDFLAG) -shared -o $@ $< -lsqlite3 -lpcre
clean :
@-$(RM) $(EXE) $(EXT)
@-$(RM) $(EXE) $(LIB)

View File

@ -1,24 +1,29 @@
ifeq ($(OS),Windows_NT)
EXE=extension.exe
EXT=sqlite3_mod_vtable.dll
LIB_EXT=dll
RM=cmd /c del
LIBCURL=-lcurldll
LDFLAG=
else
EXE=extension
EXT=sqlite3_mod_vtable.so
RM=rm
ifeq ($(shell uname -s),Darwin)
LIB_EXT=dylib
else
LIB_EXT=so
endif
RM=rm -f
LDFLAG=-fPIC
LIBCURL=-lcurl
endif
LIB=sqlite3_mod_vtable.$(LIB_EXT)
all : $(EXE) $(EXT)
all : $(EXE) $(LIB)
$(EXE) : extension.go
go build $<
$(EXT) : sqlite3_mod_vtable.cc
$(LIB) : sqlite3_mod_vtable.cc
g++ $(LDFLAG) -shared -o $@ $< -lsqlite3 $(LIBCURL)
clean :
@-$(RM) $(EXE) $(EXT)
@-$(RM) $(EXE) $(LIB)

View File

@ -1,6 +1,6 @@
#include <string>
#include <sstream>
#include <sqlite3-binding.h>
#include <sqlite3.h>
#include <sqlite3ext.h>
#include <curl/curl.h>
#include "picojson.h"

View File

@ -28,12 +28,9 @@ func (c *SQLiteConn) loadExtensions(extensions []string) error {
}
for _, extension := range extensions {
cext := C.CString(extension)
defer C.free(unsafe.Pointer(cext))
rv = C.sqlite3_load_extension(c.db, cext, nil, nil)
if rv != C.SQLITE_OK {
if err := c.loadExtension(extension, nil); err != nil {
C.sqlite3_enable_load_extension(c.db, 0)
return errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
return err
}
}
@ -41,6 +38,7 @@ func (c *SQLiteConn) loadExtensions(extensions []string) error {
if rv != C.SQLITE_OK {
return errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
}
return nil
}
@ -51,14 +49,9 @@ func (c *SQLiteConn) LoadExtension(lib string, entry string) error {
return errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
}
clib := C.CString(lib)
defer C.free(unsafe.Pointer(clib))
centry := C.CString(entry)
defer C.free(unsafe.Pointer(centry))
rv = C.sqlite3_load_extension(c.db, clib, centry, nil)
if rv != C.SQLITE_OK {
return errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
if err := c.loadExtension(lib, &entry); err != nil {
C.sqlite3_enable_load_extension(c.db, 0)
return err
}
rv = C.sqlite3_enable_load_extension(c.db, 0)
@ -68,3 +61,24 @@ func (c *SQLiteConn) LoadExtension(lib string, entry string) error {
return nil
}
func (c *SQLiteConn) loadExtension(lib string, entry *string) error {
clib := C.CString(lib)
defer C.free(unsafe.Pointer(clib))
var centry *C.char
if entry != nil {
centry := C.CString(*entry)
defer C.free(unsafe.Pointer(centry))
}
var errMsg *C.char
defer C.sqlite3_free(unsafe.Pointer(errMsg))
rv := C.sqlite3_load_extension(c.db, clib, centry, &errMsg)
if rv != C.SQLITE_OK {
return errors.New(C.GoString(errMsg))
}
return nil
}

View File

@ -0,0 +1,63 @@
// Copyright (C) 2019 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
// +build !sqlite_omit_load_extension
package sqlite3
import (
"database/sql"
"testing"
)
func TestExtensionsError(t *testing.T) {
sql.Register("sqlite3_TestExtensionsError",
&SQLiteDriver{
Extensions: []string{
"foobar",
},
},
)
db, err := sql.Open("sqlite3_TestExtensionsError", ":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Ping()
if err == nil {
t.Fatal("expected error loading non-existent extension")
}
if err.Error() == "not an error" {
t.Fatal("expected error from sqlite3_enable_load_extension to be returned")
}
}
func TestLoadExtensionError(t *testing.T) {
sql.Register("sqlite3_TestLoadExtensionError",
&SQLiteDriver{
ConnectHook: func(c *SQLiteConn) error {
return c.LoadExtension("foobar", "")
},
},
)
db, err := sql.Open("sqlite3_TestLoadExtensionError", ":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Ping()
if err == nil {
t.Fatal("expected error loading non-existent extension")
}
if err.Error() == "not an error" {
t.Fatal("expected error from sqlite3_enable_load_extension to be returned")
}
}