mirror of https://github.com/hak5/overseer.git
progress on v2, graceful restarts not working yet
parent
457309e897
commit
9ef850b1bc
|
@ -0,0 +1,2 @@
|
|||
tmp
|
||||
myapp*
|
|
@ -2,13 +2,13 @@
|
|||
|
||||
#NOTE: DONT CTRL+C OR CLEANUP WONT OCCUR
|
||||
|
||||
#upgrade server (any http server)
|
||||
#binary hosting server (any file server)
|
||||
# go get github.com/jpillora/serve
|
||||
serve &
|
||||
|
||||
#initial build
|
||||
echo "BUILDING APP 0.3.0"
|
||||
go build -ldflags "-X main.VERSION 0.3.0" -o myapp
|
||||
echo "BUILDING APP: A"
|
||||
go build -ldflags "-X main.FOO=A" -o myapp
|
||||
|
||||
#run!
|
||||
echo "RUNNING APP"
|
||||
|
@ -16,17 +16,17 @@ echo "RUNNING APP"
|
|||
|
||||
sleep 3
|
||||
|
||||
echo "BUILDING APP 0.3.1"
|
||||
go build -ldflags "-X main.VERSION 0.3.1" -o myapp.0.3.1
|
||||
echo "BUILDING APP: B"
|
||||
go build -ldflags "-X main.FOO=B" -o newmyapp
|
||||
|
||||
sleep 4
|
||||
|
||||
echo "BUILDING APP 0.4.0"
|
||||
go build -ldflags "-X main.VERSION 0.4.0" -o myapp.0.4.0
|
||||
echo "BUILDING APP: C"
|
||||
go build -ldflags "-X main.FOO=C" -o newmyapp
|
||||
|
||||
sleep 4
|
||||
|
||||
#end demo - cleanup
|
||||
killall serve
|
||||
killall myapp
|
||||
rm myapp* 2> /dev/null
|
||||
rm myapp* 2> /dev/null
|
||||
|
|
|
@ -1,33 +1,40 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/jpillora/go-upgrade"
|
||||
"github.com/jpillora/go-upgrade/fetcher"
|
||||
)
|
||||
|
||||
var VERSION = "0.0.0" //set with ldflags
|
||||
var FOO = "" //set manually or with with ldflags
|
||||
|
||||
//change your 'main' into a 'prog'
|
||||
func prog() {
|
||||
log.Printf("Running version %s...", VERSION)
|
||||
select {}
|
||||
//convert your 'main()' into a 'prog(state)'
|
||||
func prog(state upgrade.State) {
|
||||
log.Printf("app (%s) listening...", state.ID)
|
||||
http.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Foo", FOO)
|
||||
w.Header().Set("Header-Time", time.Now().String())
|
||||
w.WriteHeader(200)
|
||||
time.Sleep(30 * time.Second)
|
||||
fmt.Fprintf(w, "Body-Time: %s (Foo: %s)", time.Now(), FOO)
|
||||
}))
|
||||
http.Serve(state.Listener, nil)
|
||||
}
|
||||
|
||||
//then create another 'main' which runs the upgrades
|
||||
func main() {
|
||||
upgrade.Run(upgrade.Config{
|
||||
Program: prog,
|
||||
Version: VERSION,
|
||||
Fetcher: upgrade.BasicFetcher(
|
||||
"http://localhost:3000/myapp_{{.Version}}",
|
||||
),
|
||||
FetchInterval: 2 * time.Second,
|
||||
Signal: os.Interrupt,
|
||||
//display logs of actions
|
||||
Logging: true,
|
||||
Address: "0.0.0.0:3000",
|
||||
Fetcher: &fetcher.HTTP{
|
||||
URL: "http://localhost:4000/myapp2",
|
||||
Interval: 5 * time.Second,
|
||||
},
|
||||
Logging: true, //display log of go-upgrade actions
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -1 +0,0 @@
|
|||
lee
|
|
@ -1 +0,0 @@
|
|||
nee
|
|
@ -1,5 +0,0 @@
|
|||
package upgrade
|
||||
|
||||
type Fetcher interface {
|
||||
Fetch(currentVersion string) (binary []byte, err error)
|
||||
}
|
|
@ -0,0 +1,24 @@
|
|||
package fetcher
|
||||
|
||||
import "io"
|
||||
|
||||
type Interface interface {
|
||||
//Fetch should check if there is an updated
|
||||
//binary to fetch, and then stream it back the
|
||||
//form of an io.Reader. If io.Reader is nil,
|
||||
//then it is assumed there are no updates.
|
||||
Fetch() (io.Reader, error)
|
||||
}
|
||||
|
||||
//Converts a fetch function into interface
|
||||
func Func(fn func() (io.Reader, error)) Interface {
|
||||
return &fetcher{fn}
|
||||
}
|
||||
|
||||
type fetcher struct {
|
||||
fn func() (io.Reader, error)
|
||||
}
|
||||
|
||||
func (f fetcher) Fetch() (io.Reader, error) {
|
||||
return f.fn()
|
||||
}
|
|
@ -0,0 +1,71 @@
|
|||
package fetcher
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
//HTTPFetcher uses HEAD requests to poll the status of a given
|
||||
//file. If it detects this file has been updated, it will fetch
|
||||
//and stream out to the binary writer.
|
||||
type HTTP struct {
|
||||
//URL to poll for new binaries
|
||||
URL string
|
||||
Interval time.Duration
|
||||
//interal state
|
||||
delay bool
|
||||
lasts map[string]string
|
||||
}
|
||||
|
||||
//if any of these change, the binary has been updated
|
||||
var httpHeaders = []string{"ETag", "If-Modified-Since", "Last-Modified", "Content-Length"}
|
||||
|
||||
func (h *HTTP) Fetch() (io.Reader, error) {
|
||||
//apply defaults
|
||||
if h.Interval == 0 {
|
||||
h.Interval = 5 * time.Minute
|
||||
}
|
||||
if h.lasts == nil {
|
||||
h.lasts = map[string]string{}
|
||||
}
|
||||
//delay fetches after first
|
||||
if h.delay {
|
||||
time.Sleep(h.Interval)
|
||||
}
|
||||
h.delay = true
|
||||
//status check using HEAD
|
||||
resp, err := http.Head(h.URL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("HEAD request failed (%s)", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("HEAD request failed (status code %d)", resp.StatusCode)
|
||||
}
|
||||
//if all headers match, skip update
|
||||
matches, total := 0, 0
|
||||
for _, header := range httpHeaders {
|
||||
if curr := resp.Header.Get(header); curr != "" {
|
||||
if last, ok := h.lasts[header]; ok && last == curr {
|
||||
matches++
|
||||
}
|
||||
h.lasts[header] = curr
|
||||
total++
|
||||
}
|
||||
}
|
||||
if matches == total {
|
||||
return nil, nil //skip, file match
|
||||
}
|
||||
//binary fetch using GET
|
||||
resp, err = http.Get(h.URL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("GET request failed (%s)", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("GET request failed (status code %d)", resp.StatusCode)
|
||||
}
|
||||
//success!
|
||||
return resp.Body, nil
|
||||
}
|
|
@ -0,0 +1,64 @@
|
|||
package fetcher
|
||||
|
||||
//Similar to ioutil.ReadAll except it extracts binaries from
|
||||
//the reader, whether the reader is a .zip .tar .tar.gz .gz or raw bytes
|
||||
// func GetBinary(path string, r io.Reader) ([]byte, error) {
|
||||
//
|
||||
// if strings.HasSuffix(path, ".gz") {
|
||||
// gr, err := gzip.NewReader(r)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// r = gr
|
||||
// path = strings.TrimSuffix(path, ".gz")
|
||||
// }
|
||||
//
|
||||
// if strings.HasSuffix(path, ".tar") {
|
||||
// tr := tar.NewReader(r)
|
||||
// var fr io.Reader
|
||||
// for {
|
||||
// info, err := tr.Next()
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// if os.FileMode(info.Mode)&0111 != 0 {
|
||||
// log.Printf("found exec %s", info.Name)
|
||||
// fr = tr
|
||||
// break
|
||||
// }
|
||||
// }
|
||||
// if fr == nil {
|
||||
// return nil, fmt.Errorf("binary not found in tar archive")
|
||||
// }
|
||||
// r = fr
|
||||
//
|
||||
// } else if strings.HasSuffix(path, ".zip") {
|
||||
// bin, err := ioutil.ReadAll(r)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// buff := bytes.NewReader(bin)
|
||||
// zr, err := zip.NewReader(buff, int64(buff.Len()))
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
//
|
||||
// var fr io.Reader
|
||||
// for _, f := range zr.File {
|
||||
// info := f.FileInfo()
|
||||
// if info.Mode()&0111 != 0 {
|
||||
// log.Printf("found exec %s", info.Name())
|
||||
// fr, err = f.Open()
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// if fr == nil {
|
||||
// return nil, fmt.Errorf("binary not found in zip archive")
|
||||
// }
|
||||
// r = fr
|
||||
// }
|
||||
//
|
||||
// return ioutil.ReadAll(r)
|
||||
// }
|
126
fetcher_basic.go
126
fetcher_basic.go
|
@ -1,126 +0,0 @@
|
|||
package upgrade
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"text/template"
|
||||
)
|
||||
|
||||
type basicFetcher struct {
|
||||
url string
|
||||
urlTempl *template.Template
|
||||
}
|
||||
|
||||
func BasicFetcher(url string) Fetcher {
|
||||
t := template.New("url")
|
||||
t, err := t.Parse(url)
|
||||
if err != nil {
|
||||
log.Fatalf("upgrade.BasicFetcher.url invalid: %s", err)
|
||||
}
|
||||
b := &basicFetcher{
|
||||
url: url,
|
||||
urlTempl: t,
|
||||
}
|
||||
//test template
|
||||
b.getURL("0.1.0")
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *basicFetcher) getURL(version string) string {
|
||||
//run url template with this version
|
||||
var urlb bytes.Buffer
|
||||
if err := b.urlTempl.Execute(&urlb, struct {
|
||||
Version, OS, Arch string
|
||||
}{
|
||||
version, runtime.GOOS, runtime.GOARCH,
|
||||
}); err != nil {
|
||||
//execute will fail if theres a data error
|
||||
log.Fatalf("upgrade.BasicFetcher.url invalid: %s", err)
|
||||
}
|
||||
return urlb.String()
|
||||
}
|
||||
|
||||
func (b *basicFetcher) Fetch(currentVersion string) ([]byte, error) {
|
||||
|
||||
//get version permutations
|
||||
versions, err := getAllVersionIncrements(currentVersion)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var bin []byte
|
||||
var errs []string
|
||||
//try all versions
|
||||
for _, v := range versions {
|
||||
url := b.getURL(v)
|
||||
resp, e := http.Get(url)
|
||||
if e != nil {
|
||||
errs = append(errs, "invalid request")
|
||||
continue
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
errs = append(errs, v)
|
||||
continue
|
||||
}
|
||||
|
||||
b, err := ReadAll(url, resp.Body)
|
||||
if err != nil {
|
||||
errs = append(errs, "download binary failed: "+err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
//success!
|
||||
bin = b
|
||||
break
|
||||
}
|
||||
|
||||
if bin != nil {
|
||||
return bin, nil
|
||||
}
|
||||
return nil, fmt.Errorf(strings.Join(errs, ", "))
|
||||
}
|
||||
|
||||
func getAllVersionIncrements(version string) ([]string, error) {
|
||||
var versions []string
|
||||
curr := 0
|
||||
for {
|
||||
re := regexp.MustCompile(`\d+`)
|
||||
groups := re.FindAllString(version, -1)
|
||||
numGroups := len(groups)
|
||||
if numGroups == 0 {
|
||||
return nil, fmt.Errorf("No digits to increment in version: %s", version)
|
||||
}
|
||||
i := 0
|
||||
//we replace the version string numGroup times, swapping out one group at time
|
||||
v := re.ReplaceAllStringFunc(version, func(d string) string {
|
||||
l := len(d)
|
||||
//going from right to left
|
||||
if i == numGroups-1-curr {
|
||||
n, _ := strconv.Atoi(d)
|
||||
d = strconv.Itoa(n + 1)
|
||||
} else if i > numGroups-1-curr {
|
||||
//reset all numbers to the right to 0
|
||||
d = "0"
|
||||
}
|
||||
for len(d) < l {
|
||||
d = "0" + d
|
||||
}
|
||||
i++
|
||||
return d
|
||||
})
|
||||
versions = append(versions, v)
|
||||
curr++
|
||||
if curr == numGroups {
|
||||
break
|
||||
}
|
||||
}
|
||||
return versions, nil
|
||||
}
|
|
@ -1,28 +0,0 @@
|
|||
package upgrade
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestVersionIncs(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
version string
|
||||
expected []string
|
||||
}{
|
||||
{"0.1.0", []string{"0.1.1", "0.2.0", "1.0.0"}},
|
||||
{"0.3.1", []string{"0.3.2", "0.4.0", "1.0.0"}},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
|
||||
vers, err := getAllVersionIncrements(test.version)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !reflect.DeepEqual(vers, test.expected) {
|
||||
t.Fatalf("test %d failed:\nexpecting: %#v\n got: %#v", i, test.expected, vers)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,49 @@
|
|||
package upgrade
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
//gracefully closing net.Listener
|
||||
type upListener struct {
|
||||
net.Listener
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
func (l *upListener) Accept() (net.Conn, error) {
|
||||
conn, err := l.Listener.(*net.TCPListener).AcceptTCP()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn.SetKeepAlive(true) // see http.tcpKeepAliveListener
|
||||
conn.SetKeepAlivePeriod(3 * time.Minute) // see http.tcpKeepAliveListener
|
||||
uconn := upConn{
|
||||
Conn: conn,
|
||||
wg: &l.wg,
|
||||
}
|
||||
l.wg.Add(1)
|
||||
return uconn, nil
|
||||
}
|
||||
|
||||
func (l *upListener) File() *os.File {
|
||||
// returns a dup(2) - FD_CLOEXEC flag *not* set
|
||||
tl := l.Listener.(*net.TCPListener)
|
||||
fl, _ := tl.File()
|
||||
return fl
|
||||
}
|
||||
|
||||
type upConn struct {
|
||||
net.Conn
|
||||
wg *sync.WaitGroup
|
||||
}
|
||||
|
||||
func (uconn upConn) Close() error {
|
||||
err := uconn.Conn.Close()
|
||||
if err == nil {
|
||||
uconn.wg.Done()
|
||||
}
|
||||
return err
|
||||
}
|
|
@ -0,0 +1,9 @@
|
|||
package upgrade
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestGraceful(t *testing.T) {
|
||||
|
||||
|
||||
|
||||
}
|
|
@ -0,0 +1,310 @@
|
|||
package upgrade
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/kardianos/osext"
|
||||
)
|
||||
|
||||
var tmpBinPath = filepath.Join(os.TempDir(), "goupgrade")
|
||||
|
||||
//a go-upgrade master process
|
||||
type master struct {
|
||||
Config
|
||||
slaveCmd *exec.Cmd
|
||||
slaveExtraFiles []*os.File
|
||||
binPath string
|
||||
binPerms os.FileMode
|
||||
binHash []byte
|
||||
restarting bool
|
||||
restartedAt time.Time
|
||||
restarted chan bool
|
||||
descriptorsReleased chan bool
|
||||
signalledAt time.Time
|
||||
signals chan os.Signal
|
||||
}
|
||||
|
||||
func (mp *master) run() {
|
||||
mp.readBinary()
|
||||
mp.setupSignalling()
|
||||
mp.retreiveFileDescriptors()
|
||||
go mp.fetchLoop()
|
||||
mp.forkLoop()
|
||||
}
|
||||
|
||||
func (mp *master) readBinary() {
|
||||
//get path to binary and confirm its writable
|
||||
binPath, err := osext.Executable()
|
||||
binFound := false
|
||||
binWritable := false
|
||||
if err != nil {
|
||||
mp.logf("failed to find binary path")
|
||||
} else if f, err := os.OpenFile(binPath, os.O_RDWR, os.ModePerm); err == nil {
|
||||
if info, err := f.Stat(); err == nil && info.Size() > 0 {
|
||||
mp.binPath = binPath
|
||||
binFound = true
|
||||
//initial hash of file
|
||||
hash := sha1.New()
|
||||
io.Copy(hash, f)
|
||||
mp.binHash = hash.Sum(nil)
|
||||
//copy permissions
|
||||
mp.binPerms = info.Mode()
|
||||
//test write
|
||||
sample := make([]byte, 1)
|
||||
if n, err := f.ReadAt(sample, 0); err == nil && n == 1 {
|
||||
//read 1 byte, now write
|
||||
if n, err = f.WriteAt(sample, 0); err == nil && n == 1 {
|
||||
//write success
|
||||
binWritable = true
|
||||
}
|
||||
}
|
||||
}
|
||||
f.Close()
|
||||
}
|
||||
//is the program? or failed to find the writable binary path?
|
||||
if !binWritable {
|
||||
var err error
|
||||
if !binFound {
|
||||
err = fmt.Errorf("binary path not found: %s", binPath)
|
||||
} else if !binWritable {
|
||||
err = fmt.Errorf("binary path not writable: %s", binPath)
|
||||
}
|
||||
if err != nil {
|
||||
if mp.Config.Optional {
|
||||
log.Print(err)
|
||||
} else {
|
||||
fatalf("%s", err)
|
||||
}
|
||||
}
|
||||
mp.Program(DisabledState)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (mp *master) setupSignalling() {
|
||||
//updater-forker comms
|
||||
mp.restarted = make(chan bool)
|
||||
mp.descriptorsReleased = make(chan bool)
|
||||
//read all master process signals
|
||||
mp.signals = make(chan os.Signal)
|
||||
signal.Notify(mp.signals)
|
||||
go func() {
|
||||
for s := range mp.signals {
|
||||
|
||||
if s.String() == "child exited" {
|
||||
continue
|
||||
}
|
||||
|
||||
//**during a restart** a SIGUSR1 signals
|
||||
//to the master process that, the file
|
||||
//descriptors have been released
|
||||
if mp.restarting && s == syscall.SIGUSR1 {
|
||||
mp.descriptorsReleased <- true
|
||||
continue
|
||||
}
|
||||
|
||||
if mp.slaveCmd != nil && mp.slaveCmd.Process != nil {
|
||||
mp.logf("proxy signal (%s)", s)
|
||||
mp.slaveCmd.Process.Signal(s)
|
||||
} else if s == syscall.SIGINT {
|
||||
mp.logf("interupt with no slave")
|
||||
os.Exit(1)
|
||||
} else {
|
||||
mp.logf("signal discarded (%s), no slave process", s)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (mp *master) retreiveFileDescriptors() {
|
||||
mp.slaveExtraFiles = make([]*os.File, len(mp.Config.Addresses))
|
||||
for i, addr := range mp.Config.Addresses {
|
||||
a, err := net.ResolveTCPAddr("tcp", addr)
|
||||
if err != nil {
|
||||
fatalf("invalid address: %s (%s)", addr, err)
|
||||
}
|
||||
l, err := net.ListenTCP("tcp", a)
|
||||
if err != nil {
|
||||
fatalf(err.Error())
|
||||
}
|
||||
f, err := l.File()
|
||||
if err != nil {
|
||||
fatalf("failed to retreive fd for: %s (%s)", addr, err)
|
||||
}
|
||||
if err := l.Close(); err != nil {
|
||||
fatalf("failed to close listener for: %s (%s)", addr, err)
|
||||
}
|
||||
mp.slaveExtraFiles[i] = f
|
||||
}
|
||||
}
|
||||
|
||||
func (mp *master) fetchLoop() {
|
||||
for {
|
||||
mp.fetch()
|
||||
time.Sleep(time.Second) //fetches should be throttled by the fetcher!
|
||||
}
|
||||
}
|
||||
|
||||
func (mp *master) fetch() {
|
||||
mp.logf("checking for updates...")
|
||||
reader, err := mp.Fetcher.Fetch()
|
||||
if err != nil {
|
||||
mp.logf("failed to get latest version: %s", err)
|
||||
return
|
||||
}
|
||||
if reader == nil {
|
||||
return //fetcher has explicitly said there are no updates
|
||||
}
|
||||
//optional closer
|
||||
if closer, ok := reader.(io.Closer); ok {
|
||||
defer closer.Close()
|
||||
}
|
||||
tmpBin, err := os.Create(tmpBinPath)
|
||||
if err != nil {
|
||||
mp.logf("failed to open temp binary: %s", err)
|
||||
return
|
||||
}
|
||||
defer os.Remove(tmpBinPath)
|
||||
//tee off to sha1
|
||||
hash := sha1.New()
|
||||
reader = io.TeeReader(reader, hash)
|
||||
//write to temp
|
||||
_, err = io.Copy(tmpBin, reader)
|
||||
if err != nil {
|
||||
mp.logf("failed to write temp binary: %s", err)
|
||||
return
|
||||
}
|
||||
//compare hash
|
||||
newHash := hash.Sum(nil)
|
||||
if bytes.Equal(mp.binHash, newHash) {
|
||||
return
|
||||
}
|
||||
//copy permissions
|
||||
if err := tmpBin.Chmod(mp.binPerms); err != nil {
|
||||
mp.logf("failed to make binary executable: %s", err)
|
||||
return
|
||||
}
|
||||
tmpBin.Close()
|
||||
//sanity check
|
||||
buff := make([]byte, 8)
|
||||
rand.Read(buff)
|
||||
tokenIn := hex.EncodeToString(buff)
|
||||
cmd := exec.Command(tmpBinPath)
|
||||
cmd.Env = []string{envBinCheck + "=" + tokenIn}
|
||||
tokenOut, err := cmd.Output()
|
||||
if err != nil {
|
||||
mp.logf("failed to run temp binary: %s", err)
|
||||
return
|
||||
}
|
||||
if tokenIn != string(tokenOut) {
|
||||
mp.logf("sanity check failed")
|
||||
return
|
||||
}
|
||||
//replace!
|
||||
if err := os.Rename(tmpBinPath, mp.binPath); err != nil {
|
||||
mp.logf("failed to replace binary: %s", err)
|
||||
return
|
||||
}
|
||||
mp.logf("upgraded binary (%x -> %x)", mp.binHash[:12], newHash[:12])
|
||||
mp.binHash = newHash
|
||||
//binary successfully replaced, perform graceful restart
|
||||
mp.restarting = true
|
||||
mp.signalledAt = time.Now()
|
||||
mp.signals <- syscall.SIGTERM //ask nicely to terminate
|
||||
select {
|
||||
case <-mp.restarted:
|
||||
//success
|
||||
case <-time.After(mp.TerminateTimeout):
|
||||
//times up process, we did ask nicely!
|
||||
mp.logf("graceful timeout, forcing exit")
|
||||
mp.signals <- syscall.SIGKILL
|
||||
}
|
||||
//and keep fetching...
|
||||
return
|
||||
}
|
||||
|
||||
//not a real fork
|
||||
func (mp *master) forkLoop() {
|
||||
//loop, restart command
|
||||
for {
|
||||
mp.fork()
|
||||
}
|
||||
}
|
||||
|
||||
func (mp *master) fork() {
|
||||
mp.logf("starting %s", mp.binPath)
|
||||
cmd := exec.Command(mp.binPath)
|
||||
mp.slaveCmd = cmd
|
||||
|
||||
e := os.Environ()
|
||||
e = append(e, envBinID+"="+hex.EncodeToString(mp.binHash))
|
||||
e = append(e, envIsSlave+"=1")
|
||||
e = append(e, envNumFDs+"="+strconv.Itoa(len(mp.Config.Addresses)))
|
||||
cmd.Env = e
|
||||
cmd.Args = os.Args
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
cmd.ExtraFiles = mp.slaveExtraFiles
|
||||
if err := cmd.Start(); err != nil {
|
||||
fatalf("failed to fork: %s", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
if mp.restarting {
|
||||
mp.restartedAt = time.Now()
|
||||
mp.restarting = false
|
||||
mp.restarted <- true
|
||||
}
|
||||
//convert wait into channel
|
||||
cmdwait := make(chan error)
|
||||
go func() {
|
||||
cmdwait <- cmd.Wait()
|
||||
}()
|
||||
//wait....
|
||||
select {
|
||||
case err := <-cmdwait:
|
||||
//program exited before releasing descriptors
|
||||
if mp.restarting {
|
||||
//restart requested
|
||||
return
|
||||
}
|
||||
//proxy exit code out to master
|
||||
code := 0
|
||||
if err != nil {
|
||||
code = 1
|
||||
if exiterr, ok := err.(*exec.ExitError); ok {
|
||||
if status, ok := exiterr.Sys().(syscall.WaitStatus); ok {
|
||||
code = status.ExitStatus()
|
||||
}
|
||||
}
|
||||
}
|
||||
mp.logf("prog exited with %d", code)
|
||||
//proxy exit with same code
|
||||
os.Exit(code)
|
||||
case <-mp.descriptorsReleased:
|
||||
log.Printf("descriptors released")
|
||||
//if descriptors are released, the program
|
||||
//is yielding control of the socket and
|
||||
//should restart
|
||||
}
|
||||
}
|
||||
|
||||
func (mp *master) logf(f string, args ...interface{}) {
|
||||
if mp.Logging {
|
||||
log.Printf("[go-upgrade master] "+f, args...)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,98 @@
|
|||
package upgrade
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
//DisabledState is a placeholder state for when
|
||||
//go-upgrade is disabled and the program function
|
||||
//is run manually.
|
||||
DisabledState = State{Enabled: false}
|
||||
)
|
||||
|
||||
type State struct {
|
||||
//whether go-upgrade is running enabled. When enabled,
|
||||
//this program will be running in a child process and
|
||||
//go-upgrade will perform rolling upgrades.
|
||||
Enabled bool
|
||||
//ID is a SHA-1 hash of the current running binary
|
||||
ID string
|
||||
//StartedAt records the start time of the program
|
||||
StartedAt time.Time
|
||||
//Listener is the first net.Listener in Listeners
|
||||
Listener net.Listener
|
||||
//Listeners are the set of acquired sockets by the master
|
||||
//process. These are all passed into this program in the
|
||||
//same order they are specified in Config.Addresses.
|
||||
Listeners []net.Listener
|
||||
}
|
||||
|
||||
//a go-upgrade slave process
|
||||
|
||||
type slave struct {
|
||||
Config
|
||||
listeners []*upListener
|
||||
state State
|
||||
}
|
||||
|
||||
func (sp *slave) run() {
|
||||
sp.state.Enabled = true
|
||||
sp.state.ID = os.Getenv(envBinID)
|
||||
sp.state.StartedAt = time.Now()
|
||||
sp.initFileDescriptors()
|
||||
//find parent
|
||||
|
||||
//run program with state
|
||||
sp.logf("start program")
|
||||
sp.Config.Program(sp.state)
|
||||
}
|
||||
|
||||
func (sp *slave) initFileDescriptors() {
|
||||
//inspect file descriptors
|
||||
numFDs, err := strconv.Atoi(os.Getenv(envNumFDs))
|
||||
if err != nil {
|
||||
fatalf("invalid %s integer", envNumFDs)
|
||||
}
|
||||
sp.listeners = make([]*upListener, numFDs)
|
||||
sp.state.Listeners = make([]net.Listener, numFDs)
|
||||
for i := 0; i < numFDs; i++ {
|
||||
f := os.NewFile(uintptr(3+i), "")
|
||||
l, err := net.FileListener(f)
|
||||
if err != nil {
|
||||
fatalf("failed to inherit file descriptor: %d", i)
|
||||
}
|
||||
u := &upListener{Listener: l}
|
||||
sp.listeners[i] = u
|
||||
sp.state.Listeners[i] = u
|
||||
}
|
||||
if len(sp.state.Listeners) > 0 {
|
||||
sp.state.Listener = sp.state.Listeners[0]
|
||||
}
|
||||
}
|
||||
|
||||
func (sp *slave) watchSignal() {
|
||||
signals := make(chan os.Signal)
|
||||
signal.Notify(signals, sp.Config.Signal)
|
||||
go func() {
|
||||
<-signals
|
||||
//do graceful shutdown
|
||||
|
||||
//stop listening
|
||||
|
||||
//signal released fds
|
||||
|
||||
//listeners should be waiting on connections and close
|
||||
}()
|
||||
}
|
||||
|
||||
func (sp *slave) logf(f string, args ...interface{}) {
|
||||
if sp.Logging {
|
||||
log.Printf("[go-upgrade slave] "+f, args...)
|
||||
}
|
||||
}
|
78
read_all.go
78
read_all.go
|
@ -1,78 +0,0 @@
|
|||
package upgrade
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"strings"
|
||||
)
|
||||
|
||||
//Similar to ioutil.ReadAll except it extracts binaries from
|
||||
//the reader, whether the reader is a .zip .tar .tar.gz .gz or raw bytes
|
||||
func ReadAll(path string, r io.Reader) ([]byte, error) {
|
||||
|
||||
if strings.HasSuffix(path, ".gz") {
|
||||
gr, err := gzip.NewReader(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r = gr
|
||||
path = strings.TrimSuffix(path, ".gz")
|
||||
}
|
||||
|
||||
if strings.HasSuffix(path, ".tar") {
|
||||
tr := tar.NewReader(r)
|
||||
var fr io.Reader
|
||||
for {
|
||||
info, err := tr.Next()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if os.FileMode(info.Mode)&0111 != 0 {
|
||||
log.Printf("found exec %s", info.Name)
|
||||
fr = tr
|
||||
break
|
||||
}
|
||||
}
|
||||
if fr == nil {
|
||||
return nil, fmt.Errorf("binary not found in tar archive")
|
||||
}
|
||||
r = fr
|
||||
|
||||
} else if strings.HasSuffix(path, ".zip") {
|
||||
bin, err := ioutil.ReadAll(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buff := bytes.NewReader(bin)
|
||||
zr, err := zip.NewReader(buff, int64(buff.Len()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var fr io.Reader
|
||||
for _, f := range zr.File {
|
||||
info := f.FileInfo()
|
||||
if info.Mode()&0111 != 0 {
|
||||
log.Printf("found exec %s", info.Name())
|
||||
fr, err = f.Open()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
if fr == nil {
|
||||
return nil, fmt.Errorf("binary not found in zip archive")
|
||||
}
|
||||
r = fr
|
||||
}
|
||||
|
||||
return ioutil.ReadAll(r)
|
||||
}
|
|
@ -0,0 +1,84 @@
|
|||
package upgrade
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/jpillora/go-upgrade/fetcher"
|
||||
)
|
||||
|
||||
const (
|
||||
envIsSlave = "GO_UPGRADE_IS_SLAVE"
|
||||
envNumFDs = "GO_UPGRADE_NUM_FDS"
|
||||
envBinID = "GO_UPGRADE_BIN_ID"
|
||||
envBinCheck = "GO_UPGRADE_BIN_CHECK"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
//Optional allows go-upgrade to fallback to running
|
||||
//running the program with
|
||||
Optional bool
|
||||
//Program's main function
|
||||
Program func(state State)
|
||||
//Program's zero-downtime socket listening address (set this or Addresses)
|
||||
Address string
|
||||
//Program's zero-downtime socket listening addresses (set this or Address)
|
||||
Addresses []string
|
||||
//Signal program will accept to initiate graceful
|
||||
//application termination. Defaults to SIGTERM.
|
||||
Signal os.Signal
|
||||
//TerminateTimeout controls how long go-upgrade should
|
||||
//wait for the program to terminate itself. After this
|
||||
//timeout, go-upgrade will issue a SIGKILL.
|
||||
TerminateTimeout time.Duration
|
||||
//Restarts will be throttled by this duration.
|
||||
ThrottleRestarts time.Duration
|
||||
//Logging enables [go-upgrade] logs to be sent to stdout.
|
||||
Logging bool
|
||||
//Fetcher will be used to fetch binaries.
|
||||
Fetcher fetcher.Interface
|
||||
}
|
||||
|
||||
func fatalf(f string, args ...interface{}) {
|
||||
log.Fatalf("[go-upgrade] "+f, args...)
|
||||
}
|
||||
|
||||
func Run(c Config) {
|
||||
//sanity check
|
||||
if token := os.Getenv(envBinCheck); token != "" {
|
||||
fmt.Fprint(os.Stdout, token)
|
||||
os.Exit(0)
|
||||
}
|
||||
//validate
|
||||
if c.Program == nil {
|
||||
fatalf("upgrade.Config.Program required")
|
||||
}
|
||||
if c.Address != "" {
|
||||
if len(c.Addresses) > 0 {
|
||||
fatalf("upgrade.Config.Address and Addresses cant both be set")
|
||||
}
|
||||
c.Addresses = []string{c.Address}
|
||||
}
|
||||
if c.Signal == nil {
|
||||
c.Signal = syscall.SIGTERM
|
||||
}
|
||||
if c.TerminateTimeout == 0 {
|
||||
c.TerminateTimeout = 30 * time.Second
|
||||
}
|
||||
if c.Fetcher == nil {
|
||||
fatalf("upgrade.Config.Fetcher required")
|
||||
}
|
||||
//run either in master or slave mode
|
||||
if os.Getenv(envIsSlave) == "1" {
|
||||
sp := slave{Config: c}
|
||||
sp.logf("run")
|
||||
sp.run()
|
||||
} else {
|
||||
mp := master{Config: c}
|
||||
mp.logf("run")
|
||||
mp.run()
|
||||
}
|
||||
}
|
257
upgrade.go
257
upgrade.go
|
@ -1,257 +0,0 @@
|
|||
package upgrade
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/kardianos/osext"
|
||||
)
|
||||
|
||||
const (
|
||||
isProgVar = "GO_UPGRADE_IS_PROG"
|
||||
getVersionVar = "GO_UPGRADE_GET_VERSION"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Program func() //Programs main function
|
||||
Version string //Current version of the program
|
||||
Fetcher Fetcher //Used to fetch binaries
|
||||
URL string //Template to create upgrade URLs
|
||||
Signal os.Signal //Signal to send to the program on upgrade
|
||||
FetchInterval time.Duration //Check for upgrades at this interval
|
||||
RestartTimeout time.Duration //Restarts will only occur within this timeout
|
||||
Logging bool //Enable logging
|
||||
}
|
||||
|
||||
type upsig struct {
|
||||
upgrade bool
|
||||
sig os.Signal
|
||||
}
|
||||
|
||||
type upgrader struct {
|
||||
Config
|
||||
signals chan *upsig
|
||||
binPath string
|
||||
binPerms os.FileMode
|
||||
binUpgrade []byte
|
||||
upgradedAt time.Time
|
||||
upgraded bool
|
||||
}
|
||||
|
||||
func Run(c Config) {
|
||||
//validate
|
||||
if c.Program == nil {
|
||||
log.Fatalf("upgrade.Config.Program required")
|
||||
}
|
||||
if c.Fetcher == nil {
|
||||
log.Fatalf("upgrade.Config.Fetcher required")
|
||||
}
|
||||
|
||||
//prepare
|
||||
u := upgrader{}
|
||||
u.signals = make(chan *upsig)
|
||||
//apply defaults
|
||||
if c.FetchInterval == 0 {
|
||||
c.FetchInterval = 30 * time.Second
|
||||
}
|
||||
if c.RestartTimeout == 0 {
|
||||
c.RestartTimeout = 10 * time.Second
|
||||
}
|
||||
u.Config = c
|
||||
u.run()
|
||||
}
|
||||
|
||||
func (u *upgrader) run() {
|
||||
|
||||
//get path to binary and confirm its writable
|
||||
binPath, err := osext.Executable()
|
||||
binModifiable := false
|
||||
if err != nil {
|
||||
u.printf("failed to find binary path")
|
||||
} else if f, err := os.OpenFile(binPath, os.O_RDWR, os.ModePerm); err == nil {
|
||||
if info, err := f.Stat(); err == nil && info.Size() > 0 {
|
||||
u.binPerms = info.Mode()
|
||||
sample := make([]byte, 1)
|
||||
if n, err := f.Read(sample); err == nil && n == 1 {
|
||||
//read 1 byte, now write
|
||||
if n, err = f.WriteAt(sample, 0); err == nil && n == 1 {
|
||||
//write success
|
||||
u.binPath = binPath
|
||||
binModifiable = true
|
||||
}
|
||||
}
|
||||
}
|
||||
f.Close()
|
||||
}
|
||||
|
||||
//is the program? or failed to find bin?
|
||||
if os.Getenv(isProgVar) == "1" || !binModifiable {
|
||||
if !binModifiable {
|
||||
u.printf("binary is not writable")
|
||||
}
|
||||
u.Program()
|
||||
return
|
||||
}
|
||||
|
||||
//version request
|
||||
if os.Getenv(getVersionVar) == "1" {
|
||||
fmt.Print(u.Version)
|
||||
os.Exit(0)
|
||||
return
|
||||
}
|
||||
|
||||
//check loop
|
||||
go u.check()
|
||||
//fork loop
|
||||
u.fork()
|
||||
}
|
||||
|
||||
func (u *upgrader) check() {
|
||||
|
||||
first := true
|
||||
for {
|
||||
//wait till next update
|
||||
if first {
|
||||
first = false
|
||||
} else {
|
||||
time.Sleep(u.FetchInterval)
|
||||
}
|
||||
|
||||
u.printf("checking for updates...")
|
||||
|
||||
bin, err := u.Fetcher.Fetch(u.Version)
|
||||
if err != nil {
|
||||
u.printf("failed to get latest version: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if len(bin) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
tmpBinPath := filepath.Join(os.TempDir(), "goupgrade")
|
||||
if err := ioutil.WriteFile(tmpBinPath, bin, 0700); err != nil {
|
||||
u.printf("failed to write temp binary: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
cmd := exec.Command(tmpBinPath)
|
||||
cmd.Env = []string{getVersionVar + "=1"}
|
||||
cmdVer, err := cmd.Output()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to run temp binary: %s", err)
|
||||
}
|
||||
ver := string(cmdVer)
|
||||
if ver == u.Version {
|
||||
err = fmt.Errorf("version check failed, upgrade contained same version")
|
||||
}
|
||||
|
||||
//best-effort remove tmp file
|
||||
os.Remove(tmpBinPath)
|
||||
|
||||
if err != nil {
|
||||
u.printf("%s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
//version confirmed, replace!
|
||||
if err := ioutil.WriteFile(u.binPath, bin, u.binPerms); err != nil {
|
||||
u.printf("failed to replace binary: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
//note new version
|
||||
u.Version = ver
|
||||
u.printf("upgraded prog to: %s", ver)
|
||||
u.upgraded = true
|
||||
u.upgradedAt = time.Now()
|
||||
|
||||
//send the chosen signal to prog
|
||||
if u.Signal != nil {
|
||||
u.printf("sending program signal: %s", u.Signal)
|
||||
u.signals <- &upsig{upgrade: true, sig: u.Signal}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (u *upgrader) fork() {
|
||||
|
||||
var cmd *exec.Cmd = nil
|
||||
//proxy native signals through to the child proc
|
||||
nativesigs := make(chan os.Signal)
|
||||
signal.Notify(nativesigs)
|
||||
go func() {
|
||||
for sig := range nativesigs {
|
||||
u.signals <- &upsig{upgrade: false, sig: sig}
|
||||
}
|
||||
}()
|
||||
|
||||
//recieve all native and upgrade signals
|
||||
go func() {
|
||||
for s := range u.signals {
|
||||
if cmd == nil || cmd.Process == nil {
|
||||
continue
|
||||
}
|
||||
//child exited was meant for go-upgrade
|
||||
if !s.upgrade && s.sig.String() == "child exited" {
|
||||
continue
|
||||
}
|
||||
if err := cmd.Process.Signal(s.sig); err != nil {
|
||||
u.printf("failed to signal: %s (%s)", s.sig, err)
|
||||
} else {
|
||||
u.printf("signaled: %s", s.sig)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
u.printf("starting %s", u.binPath)
|
||||
cmd = exec.Command(u.binPath)
|
||||
e := os.Environ()
|
||||
e = append(e, isProgVar+"=1")
|
||||
cmd.Env = e
|
||||
cmd.Args = os.Args
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
log.Fatal("failed to start fork: %s", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
err := cmd.Wait()
|
||||
cmd = nil
|
||||
code := 0
|
||||
if err != nil {
|
||||
code = 1
|
||||
if exiterr, ok := err.(*exec.ExitError); ok {
|
||||
if status, ok := exiterr.Sys().(syscall.WaitStatus); ok {
|
||||
code = status.ExitStatus()
|
||||
}
|
||||
}
|
||||
}
|
||||
u.printf("prog exited with %d", code)
|
||||
|
||||
//if go-upgrade recently sent a signal, then allow it restart
|
||||
if u.upgraded && time.Now().Sub(u.upgradedAt) < u.RestartTimeout {
|
||||
u.upgraded = false
|
||||
continue
|
||||
}
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
}
|
||||
|
||||
func (u *upgrader) printf(f string, args ...interface{}) {
|
||||
if u.Logging {
|
||||
log.Printf("[go-upgrade] "+f, args...)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue