diff --git a/db.go b/db.go index 0b74621..758c911 100644 --- a/db.go +++ b/db.go @@ -6,6 +6,7 @@ import ( "hash/fnv" "io" "os" + "strings" "sync" "syscall" "unsafe" @@ -47,6 +48,12 @@ var ( // All data access is performed through transactions which can be obtained through the DB. // All the functions on DB will return a ErrDatabaseNotOpen if accessed before Open() is called. type DB struct { + // When enabled, the database will perform a Check() after every commit. + // A panic is issued if the database is in an inconsistent state. This + // flag has a large performance impact so it should only be used for + // debugging purposes. + StrictMode bool + path string file *os.File data []byte @@ -533,69 +540,7 @@ func (db *DB) Stats() Stats { // An error is returned if any inconsistency is found. func (db *DB) Check() error { return db.Update(func(tx *Tx) error { - var errors ErrorList - - // Track every reachable page. - reachable := make(map[pgid]*page) - reachable[0] = db.page(0) // meta0 - reachable[1] = db.page(1) // meta1 - for i := uint32(0); i <= db.page(tx.meta.freelist).overflow; i++ { - reachable[tx.meta.freelist+pgid(i)] = db.page(tx.meta.freelist) - } - - // Recursively check buckets. - db.checkBucket(&tx.root, reachable, &errors) - - // Ensure all pages below high water mark are either reachable or freed. - for i := pgid(0); i < tx.meta.pgid; i++ { - _, isReachable := reachable[i] - if !isReachable && !db.freelist.isFree(i) { - errors = append(errors, fmt.Errorf("page %d: unreachable unfreed", int(i))) - } - } - - if len(errors) > 0 { - return errors - } - - return nil - }) -} - -func (db *DB) checkBucket(b *Bucket, reachable map[pgid]*page, errors *ErrorList) { - // Ignore inline buckets. - if b.root == 0 { - return - } - - // Check every page used by this bucket. - b.tx.forEachPage(b.root, 0, func(p *page, _ int) { - // Ensure each page is only referenced once. 
- for i := pgid(0); i <= pgid(p.overflow); i++ { - var id = p.id + i - if _, ok := reachable[id]; ok { - *errors = append(*errors, fmt.Errorf("page %d: multiple references", int(id))) - } - reachable[id] = p - } - - // Retrieve page info. - info, err := b.tx.Page(int(p.id)) - if err != nil { - *errors = append(*errors, err) - } else if info == nil { - *errors = append(*errors, fmt.Errorf("page %d: out of bounds: %d", int(p.id), int(b.tx.meta.pgid))) - } else if info.Type != "branch" && info.Type != "leaf" { - *errors = append(*errors, fmt.Errorf("page %d: invalid type: %s", int(p.id), info.Type)) - } - }) - - // Check each bucket within this bucket. - _ = b.ForEach(func(k, v []byte) error { - if child := b.Bucket(k); child != nil { - db.checkBucket(child, reachable, errors) - } - return nil + return tx.Check() }) } @@ -734,6 +679,15 @@ func (l ErrorList) Error() string { return fmt.Sprintf("%d errors occurred", len(l)) } +// join returns the error messages joined by the given separator. +func (l ErrorList) join(sep string) string { + var a []string + for _, e := range l { + a = append(a, e.Error()) + } + return strings.Join(a, sep) +} + // _assert will panic with a given formatted message if the given condition is false. func _assert(condition bool, msg string, v ...interface{}) { if !condition { diff --git a/tx.go b/tx.go index 7cdadae..fd456eb 100644 --- a/tx.go +++ b/tx.go @@ -2,6 +2,7 @@ package bolt import ( "errors" + "fmt" "sort" "time" "unsafe" @@ -175,6 +176,14 @@ func (tx *Tx) Commit() error { return err } + // If strict mode is enabled then perform a consistency check. + if tx.db.StrictMode { + if err := tx.Check(); err != nil { + err := err.(ErrorList) + panic("check fail: " + err.Error() + ": " + err.join("; ")) + } + } + // Write meta to disk. if err := tx.writeMeta(); err != nil { tx.close() @@ -218,6 +227,89 @@ func (tx *Tx) close() { tx.db = nil } +// Check performs several consistency checks on the database for this transaction.
+// An error is returned if any inconsistency is found or if executed on a read-only transaction. +func (tx *Tx) Check() error { + if !tx.writable { + return ErrTxNotWritable + } + + var errors ErrorList + + // Check if any pages are double freed. + freed := make(map[pgid]bool) + for _, id := range tx.db.freelist.all() { + if freed[id] { + errors = append(errors, fmt.Errorf("page %d: already freed", id)) + } + freed[id] = true + } + + // Track every reachable page. + reachable := make(map[pgid]*page) + reachable[0] = tx.page(0) // meta0 + reachable[1] = tx.page(1) // meta1 + for i := uint32(0); i <= tx.page(tx.meta.freelist).overflow; i++ { + reachable[tx.meta.freelist+pgid(i)] = tx.page(tx.meta.freelist) + } + + // Recursively check buckets. + tx.checkBucket(&tx.root, reachable, &errors) + + // Ensure all pages below high water mark are either reachable or freed. + for i := pgid(0); i < tx.meta.pgid; i++ { + _, isReachable := reachable[i] + if !isReachable && !freed[i] { + errors = append(errors, fmt.Errorf("page %d: unreachable unfreed", int(i))) + } else if isReachable && freed[i] { + errors = append(errors, fmt.Errorf("page %d: reachable freed", int(i))) + } + } + + if len(errors) > 0 { + return errors + } + + return nil +} + +func (tx *Tx) checkBucket(b *Bucket, reachable map[pgid]*page, errors *ErrorList) { + // Ignore inline buckets. + if b.root == 0 { + return + } + + // Check every page used by this bucket. + b.tx.forEachPage(b.root, 0, func(p *page, _ int) { + // Ensure each page is only referenced once. + for i := pgid(0); i <= pgid(p.overflow); i++ { + var id = p.id + i + if _, ok := reachable[id]; ok { + *errors = append(*errors, fmt.Errorf("page %d: multiple references", int(id))) + } + reachable[id] = p + } + + // Retrieve page info. 
+ info, err := b.tx.Page(int(p.id)) + if err != nil { + *errors = append(*errors, err) + } else if info == nil { + *errors = append(*errors, fmt.Errorf("page %d: out of bounds: %d", int(p.id), int(b.tx.meta.pgid))) + } else if info.Type != "branch" && info.Type != "leaf" { + *errors = append(*errors, fmt.Errorf("page %d: invalid type: %s", int(p.id), info.Type)) + } + }) + + // Check each bucket within this bucket. + _ = b.ForEach(func(k, v []byte) error { + if child := b.Bucket(k); child != nil { + tx.checkBucket(child, reachable, errors) + } + return nil + }) +} + // allocate returns a contiguous block of memory starting at a given page. func (tx *Tx) allocate(count int) (*page, error) { p, err := tx.db.allocate(count) diff --git a/tx_test.go b/tx_test.go index e466765..7bf369b 100644 --- a/tx_test.go +++ b/tx_test.go @@ -288,6 +288,31 @@ func TestTx_OnCommit_Rollback(t *testing.T) { assert.Equal(t, 0, x) } +// Ensure that a Tx in strict mode will fail when corrupted. +func TestTx_Check_Corrupt(t *testing.T) { + var msg string + func() { + defer func() { + msg = fmt.Sprintf("%s", recover()) + }() + + withOpenDB(func(db *DB, path string) { + db.StrictMode = true + db.Update(func(tx *Tx) error { + tx.CreateBucket([]byte("foo")) + + // Corrupt the DB by adding a page to the freelist. + warn("---") + db.freelist.free(0, tx.page(3)) + + return nil + }) + }) + }() + + assert.Equal(t, "check fail: 1 errors occurred: page 3: already freed", msg) +} + func ExampleTx_Rollback() { // Open the database. db, _ := Open(tempfile(), 0666)