Add strict mode.

master
Ben Johnson 2014-05-14 18:08:55 -06:00
parent a6d6d964b6
commit 1f5fb0208b
3 changed files with 134 additions and 63 deletions

80
db.go
View File

@ -6,6 +6,7 @@ import (
"hash/fnv"
"io"
"os"
"strings"
"sync"
"syscall"
"unsafe"
@ -47,6 +48,12 @@ var (
// All data access is performed through transactions which can be obtained through the DB.
// All the functions on DB will return a ErrDatabaseNotOpen if accessed before Open() is called.
type DB struct {
// When enabled, the database will perform a Check() after every commit.
// A panic is issued if the database is in an inconsistent state. This
// flag has a large performance impact so it should only be used for
// debugging purposes.
StrictMode bool
path string
file *os.File
data []byte
@ -533,69 +540,7 @@ func (db *DB) Stats() Stats {
// An error is returned if any inconsistency is found.
func (db *DB) Check() error {
return db.Update(func(tx *Tx) error {
var errors ErrorList
// Track every reachable page.
reachable := make(map[pgid]*page)
reachable[0] = db.page(0) // meta0
reachable[1] = db.page(1) // meta1
for i := uint32(0); i <= db.page(tx.meta.freelist).overflow; i++ {
reachable[tx.meta.freelist+pgid(i)] = db.page(tx.meta.freelist)
}
// Recursively check buckets.
db.checkBucket(&tx.root, reachable, &errors)
// Ensure all pages below high water mark are either reachable or freed.
for i := pgid(0); i < tx.meta.pgid; i++ {
_, isReachable := reachable[i]
if !isReachable && !db.freelist.isFree(i) {
errors = append(errors, fmt.Errorf("page %d: unreachable unfreed", int(i)))
}
}
if len(errors) > 0 {
return errors
}
return nil
})
}
func (db *DB) checkBucket(b *Bucket, reachable map[pgid]*page, errors *ErrorList) {
// Ignore inline buckets.
if b.root == 0 {
return
}
// Check every page used by this bucket.
b.tx.forEachPage(b.root, 0, func(p *page, _ int) {
// Ensure each page is only referenced once.
for i := pgid(0); i <= pgid(p.overflow); i++ {
var id = p.id + i
if _, ok := reachable[id]; ok {
*errors = append(*errors, fmt.Errorf("page %d: multiple references", int(id)))
}
reachable[id] = p
}
// Retrieve page info.
info, err := b.tx.Page(int(p.id))
if err != nil {
*errors = append(*errors, err)
} else if info == nil {
*errors = append(*errors, fmt.Errorf("page %d: out of bounds: %d", int(p.id), int(b.tx.meta.pgid)))
} else if info.Type != "branch" && info.Type != "leaf" {
*errors = append(*errors, fmt.Errorf("page %d: invalid type: %s", int(p.id), info.Type))
}
})
// Check each bucket within this bucket.
_ = b.ForEach(func(k, v []byte) error {
if child := b.Bucket(k); child != nil {
db.checkBucket(child, reachable, errors)
}
return nil
return tx.Check()
})
}
@ -734,6 +679,15 @@ func (l ErrorList) Error() string {
return fmt.Sprintf("%d errors occurred", len(l))
}
// join returns a error messages joined by a string.
func (l ErrorList) join(sep string) string {
var a []string
for _, e := range l {
a = append(a, e.Error())
}
return strings.Join(a, sep)
}
// _assert will panic with a given formatted message if the given condition is false.
func _assert(condition bool, msg string, v ...interface{}) {
if !condition {

92
tx.go
View File

@ -2,6 +2,7 @@ package bolt
import (
"errors"
"fmt"
"sort"
"time"
"unsafe"
@ -175,6 +176,14 @@ func (tx *Tx) Commit() error {
return err
}
// If strict mode is enabled then perform a consistency check.
if tx.db.StrictMode {
if err := tx.Check(); err != nil {
err := err.(ErrorList)
panic("check fail: " + err.Error() + ": " + err.join("; "))
}
}
// Write meta to disk.
if err := tx.writeMeta(); err != nil {
tx.close()
@ -218,6 +227,89 @@ func (tx *Tx) close() {
tx.db = nil
}
// Check performs several consistency checks on the database for this transaction.
// An error is returned if any inconsistency is found or if executed on a read-only transaction.
func (tx *Tx) Check() error {
if !tx.writable {
return ErrTxNotWritable
}
var errors ErrorList
// Check if any pages are double freed.
freed := make(map[pgid]bool)
for _, id := range tx.db.freelist.all() {
if freed[id] {
errors = append(errors, fmt.Errorf("page %d: already freed", id))
}
freed[id] = true
}
// Track every reachable page.
reachable := make(map[pgid]*page)
reachable[0] = tx.page(0) // meta0
reachable[1] = tx.page(1) // meta1
for i := uint32(0); i <= tx.page(tx.meta.freelist).overflow; i++ {
reachable[tx.meta.freelist+pgid(i)] = tx.page(tx.meta.freelist)
}
// Recursively check buckets.
tx.checkBucket(&tx.root, reachable, &errors)
// Ensure all pages below high water mark are either reachable or freed.
for i := pgid(0); i < tx.meta.pgid; i++ {
_, isReachable := reachable[i]
if !isReachable && !freed[i] {
errors = append(errors, fmt.Errorf("page %d: unreachable unfreed", int(i)))
} else if isReachable && freed[i] {
errors = append(errors, fmt.Errorf("page %d: reachable freed", int(i)))
}
}
if len(errors) > 0 {
return errors
}
return nil
}
func (tx *Tx) checkBucket(b *Bucket, reachable map[pgid]*page, errors *ErrorList) {
// Ignore inline buckets.
if b.root == 0 {
return
}
// Check every page used by this bucket.
b.tx.forEachPage(b.root, 0, func(p *page, _ int) {
// Ensure each page is only referenced once.
for i := pgid(0); i <= pgid(p.overflow); i++ {
var id = p.id + i
if _, ok := reachable[id]; ok {
*errors = append(*errors, fmt.Errorf("page %d: multiple references", int(id)))
}
reachable[id] = p
}
// Retrieve page info.
info, err := b.tx.Page(int(p.id))
if err != nil {
*errors = append(*errors, err)
} else if info == nil {
*errors = append(*errors, fmt.Errorf("page %d: out of bounds: %d", int(p.id), int(b.tx.meta.pgid)))
} else if info.Type != "branch" && info.Type != "leaf" {
*errors = append(*errors, fmt.Errorf("page %d: invalid type: %s", int(p.id), info.Type))
}
})
// Check each bucket within this bucket.
_ = b.ForEach(func(k, v []byte) error {
if child := b.Bucket(k); child != nil {
tx.checkBucket(child, reachable, errors)
}
return nil
})
}
// allocate returns a contiguous block of memory starting at a given page.
func (tx *Tx) allocate(count int) (*page, error) {
p, err := tx.db.allocate(count)

View File

@ -288,6 +288,31 @@ func TestTx_OnCommit_Rollback(t *testing.T) {
assert.Equal(t, 0, x)
}
// Ensure that a Tx in strict mode will fail when corrupted.
func TestTx_Check_Corrupt(t *testing.T) {
var msg string
func() {
defer func() {
msg = fmt.Sprintf("%s", recover())
}()
withOpenDB(func(db *DB, path string) {
db.StrictMode = true
db.Update(func(tx *Tx) error {
tx.CreateBucket([]byte("foo"))
// Corrupt the DB by adding a page to the freelist.
warn("---")
db.freelist.free(0, tx.page(3))
return nil
})
})
}()
assert.Equal(t, "check fail: 1 errors occurred: page 3: already freed", msg)
}
func ExampleTx_Rollback() {
// Open the database.
db, _ := Open(tempfile(), 0666)