mirror of https://github.com/hak5/overseer.git
426 lines
10 KiB
Go
426 lines
10 KiB
Go
package overseer
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/rand"
|
|
"crypto/sha1"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"os"
|
|
"os/exec"
|
|
"os/signal"
|
|
"path/filepath"
|
|
"strconv"
|
|
"sync"
|
|
"syscall"
|
|
"time"
|
|
|
|
"github.com/kardianos/osext"
|
|
)
|
|
|
|
// tmpBinPath is where fetch stages a downloaded binary before installing it,
// and the scratch location checkBinary uses for its move test. The random
// suffix presumably keeps separate overseer instances from sharing one file.
var tmpBinPath = filepath.Join(os.TempDir(), fmt.Sprintf("overseer-%s", token()))
|
|
|
|
//a overseer master process
|
|
type master struct {
|
|
*Config
|
|
slaveID int
|
|
slaveCmd *exec.Cmd
|
|
slaveExtraFiles []*os.File
|
|
binPath, tmpBinPath string
|
|
binPerms os.FileMode
|
|
binHash []byte
|
|
restartMux sync.Mutex
|
|
restarting bool
|
|
restartedAt time.Time
|
|
restarted chan bool
|
|
awaitingUSR1 bool
|
|
descriptorsReleased chan bool
|
|
signalledAt time.Time
|
|
printCheckUpdate bool
|
|
}
|
|
|
|
func (mp *master) run() error {
|
|
mp.debugf("run")
|
|
if err := mp.checkBinary(); err != nil {
|
|
return err
|
|
}
|
|
if mp.Config.Fetcher != nil {
|
|
if err := mp.Config.Fetcher.Init(); err != nil {
|
|
mp.warnf("fetcher init failed (%s). fetcher disabled.", err)
|
|
mp.Config.Fetcher = nil
|
|
}
|
|
}
|
|
mp.setupSignalling()
|
|
if err := mp.retreiveFileDescriptors(); err != nil {
|
|
return err
|
|
}
|
|
if mp.Config.Fetcher != nil {
|
|
mp.printCheckUpdate = true
|
|
mp.fetch()
|
|
go mp.fetchLoop()
|
|
}
|
|
return mp.forkLoop()
|
|
}
|
|
|
|
func (mp *master) checkBinary() error {
|
|
//get path to binary and confirm its writable
|
|
binPath, err := osext.Executable()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to find binary path (%s)", err)
|
|
}
|
|
mp.binPath = binPath
|
|
if info, err := os.Stat(binPath); err != nil {
|
|
return fmt.Errorf("failed to stat binary (%s)", err)
|
|
} else if info.Size() == 0 {
|
|
return fmt.Errorf("binary file is empty")
|
|
} else {
|
|
//copy permissions
|
|
mp.binPerms = info.Mode()
|
|
}
|
|
f, err := os.Open(binPath)
|
|
if err != nil {
|
|
return fmt.Errorf("cannot read binary (%s)", err)
|
|
}
|
|
//initial hash of file
|
|
hash := sha1.New()
|
|
io.Copy(hash, f)
|
|
mp.binHash = hash.Sum(nil)
|
|
f.Close()
|
|
//test bin<->tmpbin moves
|
|
if err := move(tmpBinPath, mp.binPath); err != nil {
|
|
return fmt.Errorf("cannot move binary (%s)", err)
|
|
}
|
|
if err := move(mp.binPath, tmpBinPath); err != nil {
|
|
return fmt.Errorf("cannot move binary back (%s)", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (mp *master) setupSignalling() {
|
|
//updater-forker comms
|
|
mp.restarted = make(chan bool)
|
|
mp.descriptorsReleased = make(chan bool)
|
|
//read all master process signals
|
|
signals := make(chan os.Signal)
|
|
signal.Notify(signals)
|
|
go func() {
|
|
for s := range signals {
|
|
mp.handleSignal(s)
|
|
}
|
|
}()
|
|
}
|
|
|
|
func (mp *master) handleSignal(s os.Signal) {
|
|
if s == mp.RestartSignal {
|
|
//user initiated manual restart
|
|
mp.triggerRestart()
|
|
} else if s.String() == "child exited" {
|
|
// will occur on every restart, ignore it
|
|
} else
|
|
//**during a restart** a SIGUSR1 signals
|
|
//to the master process that, the file
|
|
//descriptors have been released
|
|
if mp.awaitingUSR1 && s == SIGUSR1 {
|
|
mp.debugf("signaled, sockets ready")
|
|
mp.awaitingUSR1 = false
|
|
mp.descriptorsReleased <- true
|
|
} else
|
|
//while the slave process is running, proxy
|
|
//all signals through
|
|
if mp.slaveCmd != nil && mp.slaveCmd.Process != nil {
|
|
mp.debugf("proxy signal (%s)", s)
|
|
mp.sendSignal(s)
|
|
} else
|
|
//otherwise if not running, kill on CTRL+c
|
|
if s == os.Interrupt {
|
|
mp.debugf("interupt with no slave")
|
|
os.Exit(1)
|
|
} else {
|
|
mp.debugf("signal discarded (%s), no slave process", s)
|
|
}
|
|
}
|
|
|
|
func (mp *master) sendSignal(s os.Signal) {
|
|
if err := mp.slaveCmd.Process.Signal(s); err != nil {
|
|
mp.debugf("signal failed (%s), assuming slave process died unexpectedly", err)
|
|
os.Exit(1)
|
|
}
|
|
}
|
|
|
|
func (mp *master) retreiveFileDescriptors() error {
|
|
mp.slaveExtraFiles = make([]*os.File, len(mp.Config.Addresses))
|
|
for i, addr := range mp.Config.Addresses {
|
|
a, err := net.ResolveTCPAddr("tcp", addr)
|
|
if err != nil {
|
|
return fmt.Errorf("Invalid address %s (%s)", addr, err)
|
|
}
|
|
l, err := net.ListenTCP("tcp", a)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
f, err := l.File()
|
|
if err != nil {
|
|
return fmt.Errorf("Failed to retreive fd for: %s (%s)", addr, err)
|
|
}
|
|
if err := l.Close(); err != nil {
|
|
return fmt.Errorf("Failed to close listener for: %s (%s)", addr, err)
|
|
}
|
|
mp.slaveExtraFiles[i] = f
|
|
}
|
|
return nil
|
|
}
|
|
|
|
//fetchLoop is run in a goroutine
|
|
func (mp *master) fetchLoop() {
|
|
min := mp.Config.MinFetchInterval
|
|
time.Sleep(min)
|
|
for {
|
|
t0 := time.Now()
|
|
mp.fetch()
|
|
//duration fetch of fetch
|
|
diff := time.Now().Sub(t0)
|
|
if diff < min {
|
|
delay := min - diff
|
|
//ensures at least MinFetchInterval delay.
|
|
//should be throttled by the fetcher!
|
|
time.Sleep(delay)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (mp *master) fetch() {
|
|
if mp.restarting {
|
|
return //skip if restarting
|
|
}
|
|
if mp.printCheckUpdate {
|
|
mp.debugf("checking for updates...")
|
|
}
|
|
reader, err := mp.Fetcher.Fetch()
|
|
if err != nil {
|
|
mp.debugf("failed to get latest version: %s", err)
|
|
return
|
|
}
|
|
if reader == nil {
|
|
if mp.printCheckUpdate {
|
|
mp.debugf("no updates")
|
|
}
|
|
mp.printCheckUpdate = false
|
|
return //fetcher has explicitly said there are no updates
|
|
}
|
|
mp.printCheckUpdate = true
|
|
mp.debugf("streaming update...")
|
|
//optional closer
|
|
if closer, ok := reader.(io.Closer); ok {
|
|
defer closer.Close()
|
|
}
|
|
tmpBin, err := os.OpenFile(tmpBinPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0666)
|
|
if err != nil {
|
|
mp.warnf("failed to open temp binary: %s", err)
|
|
return
|
|
}
|
|
defer func() {
|
|
tmpBin.Close()
|
|
os.Remove(tmpBinPath)
|
|
}()
|
|
//tee off to sha1
|
|
hash := sha1.New()
|
|
reader = io.TeeReader(reader, hash)
|
|
//write to a temp file
|
|
_, err = io.Copy(tmpBin, reader)
|
|
if err != nil {
|
|
mp.warnf("failed to write temp binary: %s", err)
|
|
return
|
|
}
|
|
//compare hash
|
|
newHash := hash.Sum(nil)
|
|
if bytes.Equal(mp.binHash, newHash) {
|
|
mp.debugf("hash match - skip")
|
|
return
|
|
}
|
|
//copy permissions
|
|
if err := chmod(tmpBin, mp.binPerms); err != nil {
|
|
mp.warnf("failed to make temp binary executable: %s", err)
|
|
return
|
|
}
|
|
if err := chown(tmpBin, uid, gid); err != nil {
|
|
mp.warnf("failed to change owner of binary: %s", err)
|
|
return
|
|
}
|
|
if _, err := tmpBin.Stat(); err != nil {
|
|
mp.warnf("failed to stat temp binary: %s", err)
|
|
return
|
|
}
|
|
tmpBin.Close()
|
|
if _, err := os.Stat(tmpBinPath); err != nil {
|
|
mp.warnf("failed to stat temp binary by path: %s", err)
|
|
return
|
|
}
|
|
if mp.Config.PreUpgrade != nil {
|
|
if err := mp.Config.PreUpgrade(tmpBinPath); err != nil {
|
|
mp.warnf("user cancelled upgrade: %s", err)
|
|
return
|
|
}
|
|
}
|
|
//overseer sanity check, dont replace our good binary with a non-executable file
|
|
tokenIn := token()
|
|
cmd := exec.Command(tmpBinPath)
|
|
cmd.Env = []string{envBinCheck + "=" + tokenIn}
|
|
returned := false
|
|
go func() {
|
|
time.Sleep(5 * time.Second)
|
|
if !returned {
|
|
mp.warnf("sanity check against fetched executable timed-out, check overseer is running")
|
|
if cmd.Process != nil {
|
|
cmd.Process.Kill()
|
|
}
|
|
}
|
|
}()
|
|
tokenOut, err := cmd.Output()
|
|
returned = true
|
|
if err != nil {
|
|
mp.warnf("failed to run temp binary: %s", err)
|
|
return
|
|
}
|
|
if tokenIn != string(tokenOut) {
|
|
mp.warnf("sanity check failed")
|
|
return
|
|
}
|
|
//overwrite!
|
|
if err := move(mp.binPath, tmpBinPath); err != nil {
|
|
mp.warnf("failed to overwrite binary: %s", err)
|
|
return
|
|
}
|
|
mp.debugf("upgraded binary (%x -> %x)", mp.binHash[:12], newHash[:12])
|
|
mp.binHash = newHash
|
|
//binary successfully replaced
|
|
if !mp.Config.NoRestartAfterFetch {
|
|
mp.triggerRestart()
|
|
}
|
|
//and keep fetching...
|
|
return
|
|
}
|
|
|
|
func (mp *master) triggerRestart() {
|
|
if mp.restarting {
|
|
mp.debugf("already graceful restarting")
|
|
return //skip
|
|
} else if mp.slaveCmd == nil || mp.restarting {
|
|
mp.debugf("no slave process")
|
|
return //skip
|
|
}
|
|
mp.debugf("graceful restart triggered")
|
|
mp.restarting = true
|
|
mp.awaitingUSR1 = true
|
|
mp.signalledAt = time.Now()
|
|
mp.sendSignal(mp.Config.RestartSignal) //ask nicely to terminate
|
|
select {
|
|
case <-mp.restarted:
|
|
//success
|
|
mp.debugf("restart success")
|
|
case <-time.After(mp.TerminateTimeout):
|
|
//times up mr. process, we did ask nicely!
|
|
mp.debugf("graceful timeout, forcing exit")
|
|
mp.sendSignal(os.Kill)
|
|
}
|
|
}
|
|
|
|
//not a real fork
|
|
func (mp *master) forkLoop() error {
|
|
//loop, restart command
|
|
for {
|
|
if err := mp.fork(); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
func (mp *master) fork() error {
|
|
mp.debugf("starting %s", mp.binPath)
|
|
cmd := exec.Command(mp.binPath)
|
|
//mark this new process as the "active" slave process.
|
|
//this process is assumed to be holding the socket files.
|
|
mp.slaveCmd = cmd
|
|
mp.slaveID++
|
|
//provide the slave process with some state
|
|
e := os.Environ()
|
|
e = append(e, envBinID+"="+hex.EncodeToString(mp.binHash))
|
|
e = append(e, envBinPath+"="+mp.binPath)
|
|
e = append(e, envSlaveID+"="+strconv.Itoa(mp.slaveID))
|
|
e = append(e, envIsSlave+"=1")
|
|
e = append(e, envNumFDs+"="+strconv.Itoa(len(mp.slaveExtraFiles)))
|
|
cmd.Env = e
|
|
//inherit master args/stdfiles
|
|
cmd.Args = os.Args
|
|
cmd.Stdin = os.Stdin
|
|
cmd.Stdout = os.Stdout
|
|
cmd.Stderr = os.Stderr
|
|
//include socket files
|
|
cmd.ExtraFiles = mp.slaveExtraFiles
|
|
if err := cmd.Start(); err != nil {
|
|
return fmt.Errorf("Failed to start slave process: %s", err)
|
|
}
|
|
//was scheduled to restart, notify success
|
|
if mp.restarting {
|
|
mp.restartedAt = time.Now()
|
|
mp.restarting = false
|
|
mp.restarted <- true
|
|
}
|
|
//convert wait into channel
|
|
cmdwait := make(chan error)
|
|
go func() {
|
|
cmdwait <- cmd.Wait()
|
|
}()
|
|
//wait....
|
|
select {
|
|
case err := <-cmdwait:
|
|
//program exited before releasing descriptors
|
|
//proxy exit code out to master
|
|
code := 0
|
|
if err != nil {
|
|
code = 1
|
|
if exiterr, ok := err.(*exec.ExitError); ok {
|
|
if status, ok := exiterr.Sys().(syscall.WaitStatus); ok {
|
|
code = status.ExitStatus()
|
|
}
|
|
}
|
|
}
|
|
mp.debugf("prog exited with %d", code)
|
|
//if a restarts are disabled or if it was an
|
|
//unexpected creash, proxy this exit straight
|
|
//through to the main process
|
|
if mp.NoRestart || !mp.restarting {
|
|
os.Exit(code)
|
|
}
|
|
case <-mp.descriptorsReleased:
|
|
//if descriptors are released, the program
|
|
//has yielded control of its sockets and
|
|
//a parallel instance of the program can be
|
|
//started safely. it should serve state.Listeners
|
|
//to ensure downtime is kept at <1sec. The previous
|
|
//cmd.Wait() will still be consumed though the
|
|
//result will be discarded.
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (mp *master) debugf(f string, args ...interface{}) {
|
|
if mp.Config.Debug {
|
|
log.Printf("[overseer master] "+f, args...)
|
|
}
|
|
}
|
|
|
|
func (mp *master) warnf(f string, args ...interface{}) {
|
|
if mp.Config.Debug || !mp.Config.NoWarn {
|
|
log.Printf("[overseer master] "+f, args...)
|
|
}
|
|
}
|
|
|
|
// token returns 16 lowercase hex characters encoding 8 bytes read from
// crypto/rand. It is used to name the temp binary path and as the challenge
// value in the fetched-binary sanity check.
//
// It panics if the system random source fails. Silently continuing would
// yield an all-zero, predictable token, and the string-only signature leaves
// no way to report the error.
func token() string {
	buff := make([]byte, 8)
	if _, err := rand.Read(buff); err != nil {
		panic("overseer: crypto/rand unavailable: " + err.Error())
	}
	return hex.EncodeToString(buff)
}
|