diff --git a/fetcher/fetcher_github.go b/fetcher/fetcher_github.go new file mode 100644 index 0000000..a161e30 --- /dev/null +++ b/fetcher/fetcher_github.go @@ -0,0 +1,142 @@ +package fetcher + +import ( + "compress/gzip" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "runtime" + "strings" + "time" +) + +//Github uses the Github V3 API to retrieve the latest release +//of a given repository and enumerate its assets. If a release +//contains a matching asset, it will fetch +//and return its io.Reader stream. +type Github struct { + //Github username and repository name + User, Repo string + //Interval between fetches + Interval time.Duration + //Asset is used to find matching release asset. + //By default a file will match if it contains + //both GOOS and GOARCH. + Asset func(filename string) bool + //interal state + releaseURL string + delay bool + lastETag string + latestRelease struct { + TagName string `json:"tag_name"` + Assets []struct { + Name string `json:"name"` + URL string `json:"browser_download_url"` + } `json:"assets"` + } +} + +func (h *Github) defaultAsset(filename string) bool { + return strings.Contains(filename, runtime.GOOS) && strings.Contains(filename, runtime.GOARCH) +} + +func (h *Github) Init() error { + //apply defaults + if h.User == "" { + return fmt.Errorf("User required") + } + if h.Repo == "" { + return fmt.Errorf("Repo required") + } + if h.Asset == nil { + h.Asset = h.defaultAsset + } + h.releaseURL = "https://api.github.com/repos/" + h.User + "/" + h.Repo + "/releases/latest" + if h.Interval == 0 { + h.Interval = 5 * time.Minute + } else if h.Interval < 1*time.Minute { + log.Printf("[overseer.github] warning: intervals less than 1 minute will surpass the public rate limit") + } + return nil +} + +func (h *Github) Fetch() (io.Reader, error) { + //delay fetches after first + if h.delay { + time.Sleep(h.Interval) + } + h.delay = true + //check release status + resp, err := http.Get(h.releaseURL) + if err != nil { + return nil, fmt.Errorf("release info request failed (%s)", err) + } + if resp.StatusCode != http.StatusOK { + resp.Body.Close() + return nil, fmt.Errorf("release info request failed (status code %d)", resp.StatusCode) + } + //clear assets + h.latestRelease.Assets = nil + if err := json.NewDecoder(resp.Body).Decode(&h.latestRelease); err != nil { + return nil, fmt.Errorf("invalid request info (%s)", err) + } + resp.Body.Close() + //find appropriate asset + assetURL := "" + for _, a := range h.latestRelease.Assets { + if h.Asset(a.Name) { + assetURL = a.URL + break + } + } + if assetURL == "" { + return nil, fmt.Errorf("no matching assets in this release (%s)", h.latestRelease.TagName) + } + //fetch location + req, _ := http.NewRequest("HEAD", assetURL, nil) + resp, err = http.DefaultTransport.RoundTrip(req) + if err != nil { + return nil, fmt.Errorf("release location request failed (%s)", err) + } + resp.Body.Close() + if resp.StatusCode != http.StatusFound { + return nil, fmt.Errorf("release location request failed (status code %d)", resp.StatusCode) + } + s3URL := resp.Header.Get("Location") + //psuedo-HEAD request + req, err = http.NewRequest("GET", s3URL, nil) + if err != nil { + return nil, fmt.Errorf("release location url error (%s)", err) + } + req.Header.Set("Range", "bytes=0-0") // HEAD not allowed so we request for 1 byte + resp, err = http.DefaultTransport.RoundTrip(req) + if err != nil { + return nil, fmt.Errorf("release location request failed (%s)", err) + } + resp.Body.Close() + if resp.StatusCode != http.StatusPartialContent { + return nil, fmt.Errorf("release location request failed (status code %d)", resp.StatusCode) + } + etag := resp.Header.Get("ETag") + if etag != "" && h.lastETag == etag { + return nil, nil //skip, hash match + } + //get binary request + resp, err = http.Get(s3URL) + if err != nil { + return nil, fmt.Errorf("release binary request failed (%s)", err) + } + if resp.StatusCode != http.StatusOK { + resp.Body.Close() + return nil, fmt.Errorf("release binary request failed (status code %d)", resp.StatusCode) + } + h.lastETag = etag + //success! + //extract gz files + if strings.HasSuffix(assetURL, ".gz") && resp.Header.Get("Content-Encoding") != "gzip" { + return gzip.NewReader(resp.Body) + } + return resp.Body, nil +} diff --git a/fetcher/fetcher_http.go b/fetcher/fetcher_http.go index 671bdca..64ff400 100644 --- a/fetcher/fetcher_http.go +++ b/fetcher/fetcher_http.go @@ -1,9 +1,11 @@ package fetcher import ( + "compress/gzip" "fmt" "io" "net/http" + "strings" "time" ) @@ -75,6 +77,10 @@ func (h *HTTP) Fetch() (io.Reader, error) { if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("GET request failed (status code %d)", resp.StatusCode) } + //extract gz files + if strings.HasSuffix(h.URL, ".gz") && resp.Header.Get("Content-Encoding") != "gzip" { + return gzip.NewReader(resp.Body) + } //success! return resp.Body, nil } diff --git a/fetcher/fetcher_s3.go b/fetcher/fetcher_s3.go index 8ecb573..3bf8f8a 100644 --- a/fetcher/fetcher_s3.go +++ b/fetcher/fetcher_s3.go @@ -1,10 +1,12 @@ package fetcher import ( + "compress/gzip" "errors" "fmt" "io" "os" + "strings" "time" "github.com/aws/aws-sdk-go/aws" @@ -49,11 +51,6 @@ func (s *S3) Init() error { Region: &s.Region, } s.client = s3.New(session.New(config)) - - //TODO include this? maybe given access to bucket after init - // resp, err := s.client.HeadBucketRequest(&s3.HeadBucketInput{Bucket: &s.Bucket}) - // if err != nil {} - //apply defaults if s.Interval == 0 { s.Interval = 5 * time.Minute @@ -81,6 +78,10 @@ func (s *S3) Fetch() (io.Reader, error) { if err != nil { return nil, fmt.Errorf("GET request failed (%s)", err) } + //extract gz files + if strings.HasSuffix(s.Key, ".gz") && *get.ContentEncoding != "gzip" { + return gzip.NewReader(get.Body) + } //success! return get.Body, nil } diff --git a/fetcher/fetcher_util.go b/fetcher/fetcher_util.go deleted file mode 100644 index d174299..0000000 --- a/fetcher/fetcher_util.go +++ /dev/null @@ -1,64 +0,0 @@ -package fetcher - -//Similar to ioutil.ReadAll except it extracts binaries from -//the reader, whether the reader is a .zip .tar .tar.gz .gz or raw bytes -// func GetBinary(path string, r io.Reader) ([]byte, error) { -// -// if strings.HasSuffix(path, ".gz") { -// gr, err := gzip.NewReader(r) -// if err != nil { -// return nil, err -// } -// r = gr -// path = strings.TrimSuffix(path, ".gz") -// } -// -// if strings.HasSuffix(path, ".tar") { -// tr := tar.NewReader(r) -// var fr io.Reader -// for { -// info, err := tr.Next() -// if err != nil { -// return nil, err -// } -// if os.FileMode(info.Mode)&0111 != 0 { -// log.Printf("found exec %s", info.Name) -// fr = tr -// break -// } -// } -// if fr == nil { -// return nil, fmt.Errorf("binary not found in tar archive") -// } -// r = fr -// -// } else if strings.HasSuffix(path, ".zip") { -// bin, err := ioutil.ReadAll(r) -// if err != nil { -// return nil, err -// } -// buff := bytes.NewReader(bin) -// zr, err := zip.NewReader(buff, int64(buff.Len())) -// if err != nil { -// return nil, err -// } -// -// var fr io.Reader -// for _, f := range zr.File { -// info := f.FileInfo() -// if info.Mode()&0111 != 0 { -// log.Printf("found exec %s", info.Name()) -// fr, err = f.Open() -// if err != nil { -// return nil, err -// } -// } -// } -// if fr == nil { -// return nil, fmt.Errorf("binary not found in zip archive") -// } -// r = fr -// } -// -// return ioutil.ReadAll(r) -// } diff --git a/graceful.go b/graceful.go index adc659c..dfd65b0 100644 --- a/graceful.go +++ b/graceful.go @@ -1,5 +1,9 @@ package overseer +//overseer listeners and connections allow graceful +//restarts by tracking when all connections from a listener +//have been closed + import ( "net" "os" @@ -7,29 +11,29 @@ import ( "time" ) -func newUpListener(l net.Listener) *upListener { - return &upListener{ +func newOverseerListener(l net.Listener) *overseerListener { + return &overseerListener{ Listener: l, closeByForce: make(chan bool), } } //gracefully closing net.Listener -type upListener struct { +type overseerListener struct { net.Listener closeError error closeByForce chan bool wg sync.WaitGroup } -func (l *upListener) Accept() (net.Conn, error) { +func (l *overseerListener) Accept() (net.Conn, error) { conn, err := l.Listener.(*net.TCPListener).AcceptTCP() if err != nil { return nil, err } conn.SetKeepAlive(true) // see http.tcpKeepAliveListener conn.SetKeepAlivePeriod(3 * time.Minute) // see http.tcpKeepAliveListener - uconn := upConn{ + uconn := overseerConn{ Conn: conn, wg: &l.wg, closed: make(chan bool), @@ -48,7 +52,7 @@ func (l *upListener) Accept() (net.Conn, error) { } //non-blocking trigger close -func (l *upListener) release(timeout time.Duration) { +func (l *overseerListener) release(timeout time.Duration) { //stop accepting connections - release fd l.closeError = l.Listener.Close() //start timer, close by force if deadline not met @@ -68,12 +72,12 @@ func (l *upListener) release(timeout time.Duration) { } //blocking wait for close -func (l *upListener) Close() error { +func (l *overseerListener) Close() error { l.wg.Wait() return l.closeError } -func (l *upListener) File() *os.File { +func (l *overseerListener) File() *os.File { // returns a dup(2) - FD_CLOEXEC flag *not* set tl := l.Listener.(*net.TCPListener) fl, _ := tl.File() @@ -81,17 +85,17 @@ func (l *upListener) File() *os.File { } //notifying on close net.Conn -type upConn struct { +type overseerConn struct { net.Conn wg *sync.WaitGroup closed chan bool } -func (uconn upConn) Close() error { - err := uconn.Conn.Close() +func (o overseerConn) Close() error { + err := o.Conn.Close() if err == nil { - uconn.wg.Done() - uconn.closed <- true + o.wg.Done() + o.closed <- true } return err } diff --git a/proc_master.go b/proc_master.go index ca0061b..2e6e697 100644 --- a/proc_master.go +++ b/proc_master.go @@ -268,7 +268,18 @@ func (mp *master) fetch() { tokenIn := token() cmd := exec.Command(tmpBinPath) cmd.Env = []string{envBinCheck + "=" + tokenIn} + returned := false + go func() { + time.Sleep(5 * time.Second) + if !returned { + mp.warnf("sanity check against fetched executable timed-out, check overseer is running") + if cmd.Process != nil { + cmd.Process.Kill() + } + } + }() tokenOut, err := cmd.Output() + returned = true if err != nil { mp.warnf("failed to run temp binary: %s", err) return diff --git a/proc_slave.go b/proc_slave.go index d63b9ed..05f1c4b 100644 --- a/proc_slave.go +++ b/proc_slave.go @@ -49,7 +49,7 @@ type State struct { type slave struct { *Config id string - listeners []*upListener + listeners []*overseerListener masterPid int masterProc *os.Process state State @@ -104,7 +104,7 @@ func (sp *slave) initFileDescriptors() error { if err != nil { return fmt.Errorf("invalid %s integer", envNumFDs) } - sp.listeners = make([]*upListener, numFDs) + sp.listeners = make([]*overseerListener, numFDs) sp.state.Listeners = make([]net.Listener, numFDs) for i := 0; i < numFDs; i++ { f := os.NewFile(uintptr(3+i), "") @@ -112,7 +112,7 @@ func (sp *slave) initFileDescriptors() error { if err != nil { return fmt.Errorf("failed to inherit file descriptor: %d", i) } - u := newUpListener(l) + u := newOverseerListener(l) sp.listeners[i] = u sp.state.Listeners[i] = u }