diff --git a/util/resolver/authorizer.go b/util/resolver/authorizer.go index 80e2a351..e58038c4 100644 --- a/util/resolver/authorizer.go +++ b/util/resolver/authorizer.go @@ -220,15 +220,13 @@ func (a *dockerAuthorizer) AddResponses(ctx context.Context, responses []*http.R // authResult is used to control limit rate. type authResult struct { - sync.WaitGroup token string - err error expires time.Time } // authHandler is used to handle auth request per registry server. type authHandler struct { - sync.Mutex + g flightcontrol.Group client *http.Client @@ -240,7 +238,8 @@ type authHandler struct { // scopedTokens caches token indexed by scopes, which used in // bearer auth case - scopedTokens map[string]*authResult + scopedTokens map[string]*authResult + scopedTokensMu sync.Mutex lastUsed time.Time @@ -292,46 +291,44 @@ func (ah *authHandler) doBearerAuth(ctx context.Context, sm *session.Manager, g // Docs: https://docs.docker.com/registry/spec/auth/scope scoped := strings.Join(to.Scopes, " ") - ah.Lock() - for { + res, err := ah.g.Do(ctx, scoped, func(ctx context.Context) (interface{}, error) { + ah.scopedTokensMu.Lock() r, exist := ah.scopedTokens[scoped] - if !exist { - // no entry cached - break - } - ah.Unlock() - r.Wait() - if r.err != nil { - select { - case <-ctx.Done(): - return "", r.err - default: + ah.scopedTokensMu.Unlock() + if exist { + if r.expires.IsZero() || r.expires.After(time.Now()) { + return r, nil } } - if !errors.Is(r.err, context.Canceled) && - (r.expires.IsZero() || r.expires.After(time.Now())) { - return r.token, r.err - } - // r.err is canceled or token expired. Get rid of it and try again - ah.Lock() - r2, exist := ah.scopedTokens[scoped] - if exist && r == r2 { - delete(ah.scopedTokens, scoped) + r, err := ah.fetchToken(ctx, sm, g, to) + if err != nil { + return nil, err } + ah.scopedTokensMu.Lock() + ah.scopedTokens[scoped] = r + ah.scopedTokensMu.Unlock() + return r, nil + }) + + if err != nil || res == nil { + return "", err } + r := res.(*authResult) + if r == nil { + return "", nil + } + return r.token, nil +} - // only one fetch token job - r := new(authResult) - r.Add(1) - ah.scopedTokens[scoped] = r - ah.Unlock() - +func (ah *authHandler) fetchToken(ctx context.Context, sm *session.Manager, g session.Group, to auth.TokenOptions) (r *authResult, err error) { var issuedAt time.Time var expires int + var token string defer func() { token = fmt.Sprintf("Bearer %s", token) - r.token, r.err = token, err + if err == nil { + r = &authResult{token: token} if issuedAt.IsZero() { issuedAt = time.Now() } @@ -339,7 +336,6 @@ func (ah *authHandler) doBearerAuth(ctx context.Context, sm *session.Manager, g r.expires = exp } } - r.Done() }() if ah.authority != nil { @@ -351,10 +347,11 @@ func (ah *authHandler) doBearerAuth(ctx context.Context, sm *session.Manager, g Scopes: to.Scopes, }, sm, g) if err != nil { - return "", err + return nil, err } issuedAt, expires = time.Unix(resp.IssuedAt, 0), int(resp.ExpiresIn) - return resp.Token, nil + token = resp.Token + return nil, nil } // fetch token for the resource scope @@ -374,29 +371,32 @@ func (ah *authHandler) doBearerAuth(ctx context.Context, sm *session.Manager, g if (errStatus.StatusCode == 405 && to.Username != "") || errStatus.StatusCode == 404 || errStatus.StatusCode == 401 { resp, err := auth.FetchTokenWithOAuth(ctx, ah.client, nil, "buildkit-client", to) if err != nil { - return "", err + return nil, err } issuedAt, expires = resp.IssuedAt, resp.ExpiresIn - return resp.AccessToken, nil + token = resp.AccessToken + return nil, nil } log.G(ctx).WithFields(logrus.Fields{ "status": errStatus.Status, "body": string(errStatus.Body), }).Debugf("token request failed") } - return "", err + return nil, err } issuedAt, expires = resp.IssuedAt, resp.ExpiresIn - return resp.Token, nil + token = resp.Token + return nil, nil } // do request anonymously resp, err := auth.FetchToken(ctx, ah.client, nil, to) if err != nil { - return "", errors.Wrap(err, "failed to fetch anonymous token") + return nil, errors.Wrap(err, "failed to fetch anonymous token") } issuedAt, expires = resp.IssuedAt, resp.ExpiresIn - return resp.Token, nil + token = resp.Token + return nil, nil } func invalidAuthorization(c auth.Challenge, responses []*http.Response) error {