progress on v2, graceful restarts not working yet

master
Jaime Pillora 2016-02-08 01:38:24 +11:00
parent 457309e897
commit 9ef850b1bc
18 changed files with 740 additions and 518 deletions

2
.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
tmp
myapp*

View File

@ -2,13 +2,13 @@
#NOTE: DONT CTRL+C OR CLEANUP WONT OCCUR
#upgrade server (any http server)
#binary hosting server (any file server)
# go get github.com/jpillora/serve
serve &
#initial build
echo "BUILDING APP 0.3.0"
go build -ldflags "-X main.VERSION 0.3.0" -o myapp
echo "BUILDING APP: A"
go build -ldflags "-X main.FOO=A" -o myapp
#run!
echo "RUNNING APP"
@ -16,17 +16,17 @@ echo "RUNNING APP"
sleep 3
echo "BUILDING APP 0.3.1"
go build -ldflags "-X main.VERSION 0.3.1" -o myapp.0.3.1
echo "BUILDING APP: B"
go build -ldflags "-X main.FOO=B" -o newmyapp
sleep 4
echo "BUILDING APP 0.4.0"
go build -ldflags "-X main.VERSION 0.4.0" -o myapp.0.4.0
echo "BUILDING APP: C"
go build -ldflags "-X main.FOO=C" -o newmyapp
sleep 4
#end demo - cleanup
killall serve
killall myapp
rm myapp* 2> /dev/null
rm myapp* 2> /dev/null

View File

@ -1,33 +1,40 @@
package main
import (
"fmt"
"log"
"os"
"net/http"
"time"
"github.com/jpillora/go-upgrade"
"github.com/jpillora/go-upgrade/fetcher"
)
var VERSION = "0.0.0" //set with ldflags
var FOO = "" //set manually or with with ldflags
//change your 'main' into a 'prog'
func prog() {
log.Printf("Running version %s...", VERSION)
select {}
//convert your 'main()' into a 'prog(state)'
func prog(state upgrade.State) {
log.Printf("app (%s) listening...", state.ID)
http.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Foo", FOO)
w.Header().Set("Header-Time", time.Now().String())
w.WriteHeader(200)
time.Sleep(30 * time.Second)
fmt.Fprintf(w, "Body-Time: %s (Foo: %s)", time.Now(), FOO)
}))
http.Serve(state.Listener, nil)
}
//then create another 'main' which runs the upgrades
func main() {
upgrade.Run(upgrade.Config{
Program: prog,
Version: VERSION,
Fetcher: upgrade.BasicFetcher(
"http://localhost:3000/myapp_{{.Version}}",
),
FetchInterval: 2 * time.Second,
Signal: os.Interrupt,
//display logs of actions
Logging: true,
Address: "0.0.0.0:3000",
Fetcher: &fetcher.HTTP{
URL: "http://localhost:4000/myapp2",
Interval: 5 * time.Second,
},
Logging: true, //display log of go-upgrade actions
})
}

View File

@ -1 +0,0 @@
lee

View File

@ -1 +0,0 @@
nee

View File

@ -1,5 +0,0 @@
package upgrade
type Fetcher interface {
Fetch(currentVersion string) (binary []byte, err error)
}

24
fetcher/fetcher.go Normal file
View File

@ -0,0 +1,24 @@
package fetcher
import "io"
type Interface interface {
//Fetch should check if there is an updated
//binary to fetch, and then stream it back the
//form of an io.Reader. If io.Reader is nil,
//then it is assumed there are no updates.
Fetch() (io.Reader, error)
}
//Converts a fetch function into interface
func Func(fn func() (io.Reader, error)) Interface {
return &fetcher{fn}
}
type fetcher struct {
fn func() (io.Reader, error)
}
func (f fetcher) Fetch() (io.Reader, error) {
return f.fn()
}

71
fetcher/fetcher_http.go Normal file
View File

@ -0,0 +1,71 @@
package fetcher
import (
"fmt"
"io"
"net/http"
"time"
)
//HTTPFetcher uses HEAD requests to poll the status of a given
//file. If it detects this file has been updated, it will fetch
//and stream out to the binary writer.
type HTTP struct {
//URL to poll for new binaries
URL string
Interval time.Duration
//interal state
delay bool
lasts map[string]string
}
//if any of these change, the binary has been updated
var httpHeaders = []string{"ETag", "If-Modified-Since", "Last-Modified", "Content-Length"}
func (h *HTTP) Fetch() (io.Reader, error) {
//apply defaults
if h.Interval == 0 {
h.Interval = 5 * time.Minute
}
if h.lasts == nil {
h.lasts = map[string]string{}
}
//delay fetches after first
if h.delay {
time.Sleep(h.Interval)
}
h.delay = true
//status check using HEAD
resp, err := http.Head(h.URL)
if err != nil {
return nil, fmt.Errorf("HEAD request failed (%s)", err)
}
resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("HEAD request failed (status code %d)", resp.StatusCode)
}
//if all headers match, skip update
matches, total := 0, 0
for _, header := range httpHeaders {
if curr := resp.Header.Get(header); curr != "" {
if last, ok := h.lasts[header]; ok && last == curr {
matches++
}
h.lasts[header] = curr
total++
}
}
if matches == total {
return nil, nil //skip, file match
}
//binary fetch using GET
resp, err = http.Get(h.URL)
if err != nil {
return nil, fmt.Errorf("GET request failed (%s)", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("GET request failed (status code %d)", resp.StatusCode)
}
//success!
return resp.Body, nil
}

64
fetcher/fetcher_util.go Normal file
View File

@ -0,0 +1,64 @@
package fetcher
//Similar to ioutil.ReadAll except it extracts binaries from
//the reader, whether the reader is a .zip .tar .tar.gz .gz or raw bytes
// func GetBinary(path string, r io.Reader) ([]byte, error) {
//
// if strings.HasSuffix(path, ".gz") {
// gr, err := gzip.NewReader(r)
// if err != nil {
// return nil, err
// }
// r = gr
// path = strings.TrimSuffix(path, ".gz")
// }
//
// if strings.HasSuffix(path, ".tar") {
// tr := tar.NewReader(r)
// var fr io.Reader
// for {
// info, err := tr.Next()
// if err != nil {
// return nil, err
// }
// if os.FileMode(info.Mode)&0111 != 0 {
// log.Printf("found exec %s", info.Name)
// fr = tr
// break
// }
// }
// if fr == nil {
// return nil, fmt.Errorf("binary not found in tar archive")
// }
// r = fr
//
// } else if strings.HasSuffix(path, ".zip") {
// bin, err := ioutil.ReadAll(r)
// if err != nil {
// return nil, err
// }
// buff := bytes.NewReader(bin)
// zr, err := zip.NewReader(buff, int64(buff.Len()))
// if err != nil {
// return nil, err
// }
//
// var fr io.Reader
// for _, f := range zr.File {
// info := f.FileInfo()
// if info.Mode()&0111 != 0 {
// log.Printf("found exec %s", info.Name())
// fr, err = f.Open()
// if err != nil {
// return nil, err
// }
// }
// }
// if fr == nil {
// return nil, fmt.Errorf("binary not found in zip archive")
// }
// r = fr
// }
//
// return ioutil.ReadAll(r)
// }

View File

@ -1,126 +0,0 @@
package upgrade
import (
"bytes"
"fmt"
"log"
"net/http"
"regexp"
"runtime"
"strconv"
"strings"
"text/template"
)
type basicFetcher struct {
url string
urlTempl *template.Template
}
func BasicFetcher(url string) Fetcher {
t := template.New("url")
t, err := t.Parse(url)
if err != nil {
log.Fatalf("upgrade.BasicFetcher.url invalid: %s", err)
}
b := &basicFetcher{
url: url,
urlTempl: t,
}
//test template
b.getURL("0.1.0")
return b
}
func (b *basicFetcher) getURL(version string) string {
//run url template with this version
var urlb bytes.Buffer
if err := b.urlTempl.Execute(&urlb, struct {
Version, OS, Arch string
}{
version, runtime.GOOS, runtime.GOARCH,
}); err != nil {
//execute will fail if theres a data error
log.Fatalf("upgrade.BasicFetcher.url invalid: %s", err)
}
return urlb.String()
}
func (b *basicFetcher) Fetch(currentVersion string) ([]byte, error) {
//get version permutations
versions, err := getAllVersionIncrements(currentVersion)
if err != nil {
return nil, err
}
var bin []byte
var errs []string
//try all versions
for _, v := range versions {
url := b.getURL(v)
resp, e := http.Get(url)
if e != nil {
errs = append(errs, "invalid request")
continue
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
errs = append(errs, v)
continue
}
b, err := ReadAll(url, resp.Body)
if err != nil {
errs = append(errs, "download binary failed: "+err.Error())
continue
}
//success!
bin = b
break
}
if bin != nil {
return bin, nil
}
return nil, fmt.Errorf(strings.Join(errs, ", "))
}
func getAllVersionIncrements(version string) ([]string, error) {
var versions []string
curr := 0
for {
re := regexp.MustCompile(`\d+`)
groups := re.FindAllString(version, -1)
numGroups := len(groups)
if numGroups == 0 {
return nil, fmt.Errorf("No digits to increment in version: %s", version)
}
i := 0
//we replace the version string numGroup times, swapping out one group at time
v := re.ReplaceAllStringFunc(version, func(d string) string {
l := len(d)
//going from right to left
if i == numGroups-1-curr {
n, _ := strconv.Atoi(d)
d = strconv.Itoa(n + 1)
} else if i > numGroups-1-curr {
//reset all numbers to the right to 0
d = "0"
}
for len(d) < l {
d = "0" + d
}
i++
return d
})
versions = append(versions, v)
curr++
if curr == numGroups {
break
}
}
return versions, nil
}

View File

@ -1,28 +0,0 @@
package upgrade
import (
"reflect"
"testing"
)
func TestVersionIncs(t *testing.T) {
tests := []struct {
version string
expected []string
}{
{"0.1.0", []string{"0.1.1", "0.2.0", "1.0.0"}},
{"0.3.1", []string{"0.3.2", "0.4.0", "1.0.0"}},
}
for i, test := range tests {
vers, err := getAllVersionIncrements(test.version)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(vers, test.expected) {
t.Fatalf("test %d failed:\nexpecting: %#v\n got: %#v", i, test.expected, vers)
}
}
}

49
graceful.go Normal file
View File

@ -0,0 +1,49 @@
package upgrade
import (
"net"
"os"
"sync"
"time"
)
//gracefully closing net.Listener
type upListener struct {
net.Listener
wg sync.WaitGroup
}
func (l *upListener) Accept() (net.Conn, error) {
conn, err := l.Listener.(*net.TCPListener).AcceptTCP()
if err != nil {
return nil, err
}
conn.SetKeepAlive(true) // see http.tcpKeepAliveListener
conn.SetKeepAlivePeriod(3 * time.Minute) // see http.tcpKeepAliveListener
uconn := upConn{
Conn: conn,
wg: &l.wg,
}
l.wg.Add(1)
return uconn, nil
}
func (l *upListener) File() *os.File {
// returns a dup(2) - FD_CLOEXEC flag *not* set
tl := l.Listener.(*net.TCPListener)
fl, _ := tl.File()
return fl
}
type upConn struct {
net.Conn
wg *sync.WaitGroup
}
func (uconn upConn) Close() error {
err := uconn.Conn.Close()
if err == nil {
uconn.wg.Done()
}
return err
}

9
graceful_test.go Normal file
View File

@ -0,0 +1,9 @@
package upgrade
import "testing"
func TestGraceful(t *testing.T) {
}

310
proc_master.go Normal file
View File

@ -0,0 +1,310 @@
package upgrade
import (
"bytes"
"crypto/rand"
"crypto/sha1"
"encoding/hex"
"fmt"
"io"
"log"
"net"
"os"
"os/exec"
"os/signal"
"path/filepath"
"strconv"
"syscall"
"time"
"github.com/kardianos/osext"
)
var tmpBinPath = filepath.Join(os.TempDir(), "goupgrade")
//a go-upgrade master process
type master struct {
Config
slaveCmd *exec.Cmd
slaveExtraFiles []*os.File
binPath string
binPerms os.FileMode
binHash []byte
restarting bool
restartedAt time.Time
restarted chan bool
descriptorsReleased chan bool
signalledAt time.Time
signals chan os.Signal
}
func (mp *master) run() {
mp.readBinary()
mp.setupSignalling()
mp.retreiveFileDescriptors()
go mp.fetchLoop()
mp.forkLoop()
}
func (mp *master) readBinary() {
//get path to binary and confirm its writable
binPath, err := osext.Executable()
binFound := false
binWritable := false
if err != nil {
mp.logf("failed to find binary path")
} else if f, err := os.OpenFile(binPath, os.O_RDWR, os.ModePerm); err == nil {
if info, err := f.Stat(); err == nil && info.Size() > 0 {
mp.binPath = binPath
binFound = true
//initial hash of file
hash := sha1.New()
io.Copy(hash, f)
mp.binHash = hash.Sum(nil)
//copy permissions
mp.binPerms = info.Mode()
//test write
sample := make([]byte, 1)
if n, err := f.ReadAt(sample, 0); err == nil && n == 1 {
//read 1 byte, now write
if n, err = f.WriteAt(sample, 0); err == nil && n == 1 {
//write success
binWritable = true
}
}
}
f.Close()
}
//is the program? or failed to find the writable binary path?
if !binWritable {
var err error
if !binFound {
err = fmt.Errorf("binary path not found: %s", binPath)
} else if !binWritable {
err = fmt.Errorf("binary path not writable: %s", binPath)
}
if err != nil {
if mp.Config.Optional {
log.Print(err)
} else {
fatalf("%s", err)
}
}
mp.Program(DisabledState)
return
}
}
func (mp *master) setupSignalling() {
//updater-forker comms
mp.restarted = make(chan bool)
mp.descriptorsReleased = make(chan bool)
//read all master process signals
mp.signals = make(chan os.Signal)
signal.Notify(mp.signals)
go func() {
for s := range mp.signals {
if s.String() == "child exited" {
continue
}
//**during a restart** a SIGUSR1 signals
//to the master process that, the file
//descriptors have been released
if mp.restarting && s == syscall.SIGUSR1 {
mp.descriptorsReleased <- true
continue
}
if mp.slaveCmd != nil && mp.slaveCmd.Process != nil {
mp.logf("proxy signal (%s)", s)
mp.slaveCmd.Process.Signal(s)
} else if s == syscall.SIGINT {
mp.logf("interupt with no slave")
os.Exit(1)
} else {
mp.logf("signal discarded (%s), no slave process", s)
}
}
}()
}
func (mp *master) retreiveFileDescriptors() {
mp.slaveExtraFiles = make([]*os.File, len(mp.Config.Addresses))
for i, addr := range mp.Config.Addresses {
a, err := net.ResolveTCPAddr("tcp", addr)
if err != nil {
fatalf("invalid address: %s (%s)", addr, err)
}
l, err := net.ListenTCP("tcp", a)
if err != nil {
fatalf(err.Error())
}
f, err := l.File()
if err != nil {
fatalf("failed to retreive fd for: %s (%s)", addr, err)
}
if err := l.Close(); err != nil {
fatalf("failed to close listener for: %s (%s)", addr, err)
}
mp.slaveExtraFiles[i] = f
}
}
func (mp *master) fetchLoop() {
for {
mp.fetch()
time.Sleep(time.Second) //fetches should be throttled by the fetcher!
}
}
func (mp *master) fetch() {
mp.logf("checking for updates...")
reader, err := mp.Fetcher.Fetch()
if err != nil {
mp.logf("failed to get latest version: %s", err)
return
}
if reader == nil {
return //fetcher has explicitly said there are no updates
}
//optional closer
if closer, ok := reader.(io.Closer); ok {
defer closer.Close()
}
tmpBin, err := os.Create(tmpBinPath)
if err != nil {
mp.logf("failed to open temp binary: %s", err)
return
}
defer os.Remove(tmpBinPath)
//tee off to sha1
hash := sha1.New()
reader = io.TeeReader(reader, hash)
//write to temp
_, err = io.Copy(tmpBin, reader)
if err != nil {
mp.logf("failed to write temp binary: %s", err)
return
}
//compare hash
newHash := hash.Sum(nil)
if bytes.Equal(mp.binHash, newHash) {
return
}
//copy permissions
if err := tmpBin.Chmod(mp.binPerms); err != nil {
mp.logf("failed to make binary executable: %s", err)
return
}
tmpBin.Close()
//sanity check
buff := make([]byte, 8)
rand.Read(buff)
tokenIn := hex.EncodeToString(buff)
cmd := exec.Command(tmpBinPath)
cmd.Env = []string{envBinCheck + "=" + tokenIn}
tokenOut, err := cmd.Output()
if err != nil {
mp.logf("failed to run temp binary: %s", err)
return
}
if tokenIn != string(tokenOut) {
mp.logf("sanity check failed")
return
}
//replace!
if err := os.Rename(tmpBinPath, mp.binPath); err != nil {
mp.logf("failed to replace binary: %s", err)
return
}
mp.logf("upgraded binary (%x -> %x)", mp.binHash[:12], newHash[:12])
mp.binHash = newHash
//binary successfully replaced, perform graceful restart
mp.restarting = true
mp.signalledAt = time.Now()
mp.signals <- syscall.SIGTERM //ask nicely to terminate
select {
case <-mp.restarted:
//success
case <-time.After(mp.TerminateTimeout):
//times up process, we did ask nicely!
mp.logf("graceful timeout, forcing exit")
mp.signals <- syscall.SIGKILL
}
//and keep fetching...
return
}
//not a real fork
func (mp *master) forkLoop() {
//loop, restart command
for {
mp.fork()
}
}
func (mp *master) fork() {
mp.logf("starting %s", mp.binPath)
cmd := exec.Command(mp.binPath)
mp.slaveCmd = cmd
e := os.Environ()
e = append(e, envBinID+"="+hex.EncodeToString(mp.binHash))
e = append(e, envIsSlave+"=1")
e = append(e, envNumFDs+"="+strconv.Itoa(len(mp.Config.Addresses)))
cmd.Env = e
cmd.Args = os.Args
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
cmd.ExtraFiles = mp.slaveExtraFiles
if err := cmd.Start(); err != nil {
fatalf("failed to fork: %s", err)
os.Exit(1)
}
if mp.restarting {
mp.restartedAt = time.Now()
mp.restarting = false
mp.restarted <- true
}
//convert wait into channel
cmdwait := make(chan error)
go func() {
cmdwait <- cmd.Wait()
}()
//wait....
select {
case err := <-cmdwait:
//program exited before releasing descriptors
if mp.restarting {
//restart requested
return
}
//proxy exit code out to master
code := 0
if err != nil {
code = 1
if exiterr, ok := err.(*exec.ExitError); ok {
if status, ok := exiterr.Sys().(syscall.WaitStatus); ok {
code = status.ExitStatus()
}
}
}
mp.logf("prog exited with %d", code)
//proxy exit with same code
os.Exit(code)
case <-mp.descriptorsReleased:
log.Printf("descriptors released")
//if descriptors are released, the program
//is yielding control of the socket and
//should restart
}
}
func (mp *master) logf(f string, args ...interface{}) {
if mp.Logging {
log.Printf("[go-upgrade master] "+f, args...)
}
}

98
proc_slave.go Normal file
View File

@ -0,0 +1,98 @@
package upgrade
import (
"log"
"net"
"os"
"os/signal"
"strconv"
"time"
)
var (
//DisabledState is a placeholder state for when
//go-upgrade is disabled and the program function
//is run manually.
DisabledState = State{Enabled: false}
)
type State struct {
//whether go-upgrade is running enabled. When enabled,
//this program will be running in a child process and
//go-upgrade will perform rolling upgrades.
Enabled bool
//ID is a SHA-1 hash of the current running binary
ID string
//StartedAt records the start time of the program
StartedAt time.Time
//Listener is the first net.Listener in Listeners
Listener net.Listener
//Listeners are the set of acquired sockets by the master
//process. These are all passed into this program in the
//same order they are specified in Config.Addresses.
Listeners []net.Listener
}
//a go-upgrade slave process
type slave struct {
Config
listeners []*upListener
state State
}
func (sp *slave) run() {
sp.state.Enabled = true
sp.state.ID = os.Getenv(envBinID)
sp.state.StartedAt = time.Now()
sp.initFileDescriptors()
//find parent
//run program with state
sp.logf("start program")
sp.Config.Program(sp.state)
}
func (sp *slave) initFileDescriptors() {
//inspect file descriptors
numFDs, err := strconv.Atoi(os.Getenv(envNumFDs))
if err != nil {
fatalf("invalid %s integer", envNumFDs)
}
sp.listeners = make([]*upListener, numFDs)
sp.state.Listeners = make([]net.Listener, numFDs)
for i := 0; i < numFDs; i++ {
f := os.NewFile(uintptr(3+i), "")
l, err := net.FileListener(f)
if err != nil {
fatalf("failed to inherit file descriptor: %d", i)
}
u := &upListener{Listener: l}
sp.listeners[i] = u
sp.state.Listeners[i] = u
}
if len(sp.state.Listeners) > 0 {
sp.state.Listener = sp.state.Listeners[0]
}
}
func (sp *slave) watchSignal() {
signals := make(chan os.Signal)
signal.Notify(signals, sp.Config.Signal)
go func() {
<-signals
//do graceful shutdown
//stop listening
//signal released fds
//listeners should be waiting on connections and close
}()
}
func (sp *slave) logf(f string, args ...interface{}) {
if sp.Logging {
log.Printf("[go-upgrade slave] "+f, args...)
}
}

View File

@ -1,78 +0,0 @@
package upgrade
import (
"archive/tar"
"archive/zip"
"bytes"
"compress/gzip"
"fmt"
"io"
"io/ioutil"
"log"
"os"
"strings"
)
//Similar to ioutil.ReadAll except it extracts binaries from
//the reader, whether the reader is a .zip .tar .tar.gz .gz or raw bytes
func ReadAll(path string, r io.Reader) ([]byte, error) {
if strings.HasSuffix(path, ".gz") {
gr, err := gzip.NewReader(r)
if err != nil {
return nil, err
}
r = gr
path = strings.TrimSuffix(path, ".gz")
}
if strings.HasSuffix(path, ".tar") {
tr := tar.NewReader(r)
var fr io.Reader
for {
info, err := tr.Next()
if err != nil {
return nil, err
}
if os.FileMode(info.Mode)&0111 != 0 {
log.Printf("found exec %s", info.Name)
fr = tr
break
}
}
if fr == nil {
return nil, fmt.Errorf("binary not found in tar archive")
}
r = fr
} else if strings.HasSuffix(path, ".zip") {
bin, err := ioutil.ReadAll(r)
if err != nil {
return nil, err
}
buff := bytes.NewReader(bin)
zr, err := zip.NewReader(buff, int64(buff.Len()))
if err != nil {
return nil, err
}
var fr io.Reader
for _, f := range zr.File {
info := f.FileInfo()
if info.Mode()&0111 != 0 {
log.Printf("found exec %s", info.Name())
fr, err = f.Open()
if err != nil {
return nil, err
}
}
}
if fr == nil {
return nil, fmt.Errorf("binary not found in zip archive")
}
r = fr
}
return ioutil.ReadAll(r)
}

84
run.go Normal file
View File

@ -0,0 +1,84 @@
package upgrade
import (
"fmt"
"log"
"os"
"syscall"
"time"
"github.com/jpillora/go-upgrade/fetcher"
)
const (
envIsSlave = "GO_UPGRADE_IS_SLAVE"
envNumFDs = "GO_UPGRADE_NUM_FDS"
envBinID = "GO_UPGRADE_BIN_ID"
envBinCheck = "GO_UPGRADE_BIN_CHECK"
)
type Config struct {
//Optional allows go-upgrade to fallback to running
//running the program with
Optional bool
//Program's main function
Program func(state State)
//Program's zero-downtime socket listening address (set this or Addresses)
Address string
//Program's zero-downtime socket listening addresses (set this or Address)
Addresses []string
//Signal program will accept to initiate graceful
//application termination. Defaults to SIGTERM.
Signal os.Signal
//TerminateTimeout controls how long go-upgrade should
//wait for the program to terminate itself. After this
//timeout, go-upgrade will issue a SIGKILL.
TerminateTimeout time.Duration
//Restarts will be throttled by this duration.
ThrottleRestarts time.Duration
//Logging enables [go-upgrade] logs to be sent to stdout.
Logging bool
//Fetcher will be used to fetch binaries.
Fetcher fetcher.Interface
}
func fatalf(f string, args ...interface{}) {
log.Fatalf("[go-upgrade] "+f, args...)
}
func Run(c Config) {
//sanity check
if token := os.Getenv(envBinCheck); token != "" {
fmt.Fprint(os.Stdout, token)
os.Exit(0)
}
//validate
if c.Program == nil {
fatalf("upgrade.Config.Program required")
}
if c.Address != "" {
if len(c.Addresses) > 0 {
fatalf("upgrade.Config.Address and Addresses cant both be set")
}
c.Addresses = []string{c.Address}
}
if c.Signal == nil {
c.Signal = syscall.SIGTERM
}
if c.TerminateTimeout == 0 {
c.TerminateTimeout = 30 * time.Second
}
if c.Fetcher == nil {
fatalf("upgrade.Config.Fetcher required")
}
//run either in master or slave mode
if os.Getenv(envIsSlave) == "1" {
sp := slave{Config: c}
sp.logf("run")
sp.run()
} else {
mp := master{Config: c}
mp.logf("run")
mp.run()
}
}

View File

@ -1,257 +0,0 @@
package upgrade
import (
"fmt"
"io/ioutil"
"log"
"os"
"os/exec"
"os/signal"
"path/filepath"
"syscall"
"time"
"github.com/kardianos/osext"
)
const (
isProgVar = "GO_UPGRADE_IS_PROG"
getVersionVar = "GO_UPGRADE_GET_VERSION"
)
type Config struct {
Program func() //Programs main function
Version string //Current version of the program
Fetcher Fetcher //Used to fetch binaries
URL string //Template to create upgrade URLs
Signal os.Signal //Signal to send to the program on upgrade
FetchInterval time.Duration //Check for upgrades at this interval
RestartTimeout time.Duration //Restarts will only occur within this timeout
Logging bool //Enable logging
}
type upsig struct {
upgrade bool
sig os.Signal
}
type upgrader struct {
Config
signals chan *upsig
binPath string
binPerms os.FileMode
binUpgrade []byte
upgradedAt time.Time
upgraded bool
}
func Run(c Config) {
//validate
if c.Program == nil {
log.Fatalf("upgrade.Config.Program required")
}
if c.Fetcher == nil {
log.Fatalf("upgrade.Config.Fetcher required")
}
//prepare
u := upgrader{}
u.signals = make(chan *upsig)
//apply defaults
if c.FetchInterval == 0 {
c.FetchInterval = 30 * time.Second
}
if c.RestartTimeout == 0 {
c.RestartTimeout = 10 * time.Second
}
u.Config = c
u.run()
}
func (u *upgrader) run() {
//get path to binary and confirm its writable
binPath, err := osext.Executable()
binModifiable := false
if err != nil {
u.printf("failed to find binary path")
} else if f, err := os.OpenFile(binPath, os.O_RDWR, os.ModePerm); err == nil {
if info, err := f.Stat(); err == nil && info.Size() > 0 {
u.binPerms = info.Mode()
sample := make([]byte, 1)
if n, err := f.Read(sample); err == nil && n == 1 {
//read 1 byte, now write
if n, err = f.WriteAt(sample, 0); err == nil && n == 1 {
//write success
u.binPath = binPath
binModifiable = true
}
}
}
f.Close()
}
//is the program? or failed to find bin?
if os.Getenv(isProgVar) == "1" || !binModifiable {
if !binModifiable {
u.printf("binary is not writable")
}
u.Program()
return
}
//version request
if os.Getenv(getVersionVar) == "1" {
fmt.Print(u.Version)
os.Exit(0)
return
}
//check loop
go u.check()
//fork loop
u.fork()
}
func (u *upgrader) check() {
first := true
for {
//wait till next update
if first {
first = false
} else {
time.Sleep(u.FetchInterval)
}
u.printf("checking for updates...")
bin, err := u.Fetcher.Fetch(u.Version)
if err != nil {
u.printf("failed to get latest version: %s", err)
continue
}
if len(bin) == 0 {
continue
}
tmpBinPath := filepath.Join(os.TempDir(), "goupgrade")
if err := ioutil.WriteFile(tmpBinPath, bin, 0700); err != nil {
u.printf("failed to write temp binary: %s", err)
continue
}
cmd := exec.Command(tmpBinPath)
cmd.Env = []string{getVersionVar + "=1"}
cmdVer, err := cmd.Output()
if err != nil {
err = fmt.Errorf("failed to run temp binary: %s", err)
}
ver := string(cmdVer)
if ver == u.Version {
err = fmt.Errorf("version check failed, upgrade contained same version")
}
//best-effort remove tmp file
os.Remove(tmpBinPath)
if err != nil {
u.printf("%s", err)
continue
}
//version confirmed, replace!
if err := ioutil.WriteFile(u.binPath, bin, u.binPerms); err != nil {
u.printf("failed to replace binary: %s", err)
continue
}
//note new version
u.Version = ver
u.printf("upgraded prog to: %s", ver)
u.upgraded = true
u.upgradedAt = time.Now()
//send the chosen signal to prog
if u.Signal != nil {
u.printf("sending program signal: %s", u.Signal)
u.signals <- &upsig{upgrade: true, sig: u.Signal}
}
}
}
func (u *upgrader) fork() {
var cmd *exec.Cmd = nil
//proxy native signals through to the child proc
nativesigs := make(chan os.Signal)
signal.Notify(nativesigs)
go func() {
for sig := range nativesigs {
u.signals <- &upsig{upgrade: false, sig: sig}
}
}()
//recieve all native and upgrade signals
go func() {
for s := range u.signals {
if cmd == nil || cmd.Process == nil {
continue
}
//child exited was meant for go-upgrade
if !s.upgrade && s.sig.String() == "child exited" {
continue
}
if err := cmd.Process.Signal(s.sig); err != nil {
u.printf("failed to signal: %s (%s)", s.sig, err)
} else {
u.printf("signaled: %s", s.sig)
}
}
}()
for {
u.printf("starting %s", u.binPath)
cmd = exec.Command(u.binPath)
e := os.Environ()
e = append(e, isProgVar+"=1")
cmd.Env = e
cmd.Args = os.Args
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Start(); err != nil {
log.Fatal("failed to start fork: %s", err)
os.Exit(1)
}
err := cmd.Wait()
cmd = nil
code := 0
if err != nil {
code = 1
if exiterr, ok := err.(*exec.ExitError); ok {
if status, ok := exiterr.Sys().(syscall.WaitStatus); ok {
code = status.ExitStatus()
}
}
}
u.printf("prog exited with %d", code)
//if go-upgrade recently sent a signal, then allow it restart
if u.upgraded && time.Now().Sub(u.upgradedAt) < u.RestartTimeout {
u.upgraded = false
continue
}
os.Exit(code)
}
}
func (u *upgrader) printf(f string, args ...interface{}) {
if u.Logging {
log.Printf("[go-upgrade] "+f, args...)
}
}