Compare commits
16 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| d13a639172 | |||
| c4e474edb5 | |||
| 3301166e4b | |||
| cbd307d5c5 | |||
| a3def79430 | |||
| 1920308d15 | |||
| 3c3d0f32eb | |||
| 557354a718 | |||
| e303047b52 | |||
| 9062ac191b | |||
| 3050d45eb1 | |||
| 4aa1aac8b8 | |||
| 5f3dd28fe1 | |||
| 3740a132a2 | |||
| af14ec844f | |||
| d5f6756cdf |
@@ -2909,6 +2909,8 @@ func parseExternalAuthProvidersFromEnv(prefix string, environ []string) ([]coder
|
||||
provider.MCPToolDenyRegex = v.Value
|
||||
case "PKCE_METHODS":
|
||||
provider.CodeChallengeMethodsSupported = strings.Split(v.Value, " ")
|
||||
case "API_BASE_URL":
|
||||
provider.APIBaseURL = v.Value
|
||||
}
|
||||
providers[providerNum] = provider
|
||||
}
|
||||
|
||||
@@ -108,6 +108,29 @@ func TestReadExternalAuthProvidersFromEnv(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestReadExternalAuthProvidersFromEnv_APIBaseURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
providers, err := cli.ReadExternalAuthProvidersFromEnv([]string{
|
||||
"CODER_EXTERNAL_AUTH_0_TYPE=github",
|
||||
"CODER_EXTERNAL_AUTH_0_CLIENT_ID=xxx",
|
||||
"CODER_EXTERNAL_AUTH_0_API_BASE_URL=https://ghes.corp.com/api/v3",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, providers, 1)
|
||||
assert.Equal(t, "https://ghes.corp.com/api/v3", providers[0].APIBaseURL)
|
||||
}
|
||||
|
||||
func TestReadExternalAuthProvidersFromEnv_APIBaseURLDefault(t *testing.T) {
|
||||
t.Parallel()
|
||||
providers, err := cli.ReadExternalAuthProvidersFromEnv([]string{
|
||||
"CODER_EXTERNAL_AUTH_0_TYPE=github",
|
||||
"CODER_EXTERNAL_AUTH_0_CLIENT_ID=xxx",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, providers, 1)
|
||||
assert.Equal(t, "", providers[0].APIBaseURL)
|
||||
}
|
||||
|
||||
// TestReadGitAuthProvidersFromEnv ensures that the deprecated `CODER_GITAUTH_`
|
||||
// environment variables are still supported.
|
||||
func TestReadGitAuthProvidersFromEnv(t *testing.T) {
|
||||
|
||||
Generated
+4
@@ -15448,6 +15448,10 @@ const docTemplate = `{
|
||||
"codersdk.ExternalAuthConfig": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"api_base_url": {
|
||||
"description": "APIBaseURL is the base URL for provider REST API calls\n(e.g., \"https://api.github.com\" for GitHub). Derived from\ndefaults when not explicitly configured.",
|
||||
"type": "string"
|
||||
},
|
||||
"app_install_url": {
|
||||
"type": "string"
|
||||
},
|
||||
|
||||
Generated
+4
@@ -13957,6 +13957,10 @@
|
||||
"codersdk.ExternalAuthConfig": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"api_base_url": {
|
||||
"description": "APIBaseURL is the base URL for provider REST API calls\n(e.g., \"https://api.github.com\" for GitHub). Derived from\ndefaults when not explicitly configured.",
|
||||
"type": "string"
|
||||
},
|
||||
"app_install_url": {
|
||||
"type": "string"
|
||||
},
|
||||
|
||||
+177
-668
@@ -13,7 +13,6 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -32,6 +31,8 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/externalauth"
|
||||
"github.com/coder/coder/v2/coderd/externalauth/gitprovider"
|
||||
"github.com/coder/coder/v2/coderd/gitsync"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/coderd/httpapi/httperror"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
@@ -39,16 +40,15 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
"github.com/coder/coder/v2/coderd/rbac/policy"
|
||||
"github.com/coder/coder/v2/coderd/tracing"
|
||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/wsjson"
|
||||
"github.com/coder/websocket"
|
||||
)
|
||||
|
||||
const (
|
||||
chatDiffStatusTTL = 120 * time.Second
|
||||
chatDiffBackgroundRefreshTimeout = 20 * time.Second
|
||||
githubAPIBaseURL = "https://api.github.com"
|
||||
chatStreamBatchSize = 256
|
||||
chatDiffStatusTTL = gitsync.DiffStatusTTL
|
||||
chatStreamBatchSize = 256
|
||||
|
||||
chatContextLimitModelConfigKey = "context_limit"
|
||||
chatContextCompressionThresholdModelConfigKey = "context_compression_threshold"
|
||||
@@ -57,19 +57,6 @@ const (
|
||||
maxChatContextCompressionThreshold = int32(100)
|
||||
)
|
||||
|
||||
// chatDiffRefreshBackoffSchedule defines the delays between successive
|
||||
// background diff refresh attempts. The trigger fires when the agent
|
||||
// obtains a GitHub token, which is typically right before a git push
|
||||
// or PR creation. The backoff gives progressively more time for the
|
||||
// push and any PR workflow to complete before querying the GitHub API.
|
||||
var chatDiffRefreshBackoffSchedule = []time.Duration{
|
||||
1 * time.Second,
|
||||
3 * time.Second,
|
||||
5 * time.Second,
|
||||
10 * time.Second,
|
||||
20 * time.Second,
|
||||
}
|
||||
|
||||
// chatGitRef holds the branch and remote origin reported by the
|
||||
// workspace agent during a git operation.
|
||||
type chatGitRef struct {
|
||||
@@ -77,32 +64,6 @@ type chatGitRef struct {
|
||||
RemoteOrigin string
|
||||
}
|
||||
|
||||
var (
|
||||
githubPullRequestPathPattern = regexp.MustCompile(
|
||||
`^https://github\.com/([A-Za-z0-9_.-]+)/([A-Za-z0-9_.-]+)/pull/([0-9]+)(?:[/?#].*)?$`,
|
||||
)
|
||||
githubRepositoryHTTPSPattern = regexp.MustCompile(
|
||||
`^https://github\.com/([A-Za-z0-9_.-]+)/([A-Za-z0-9_.-]+?)(?:\.git)?/?$`,
|
||||
)
|
||||
githubRepositorySSHPathPattern = regexp.MustCompile(
|
||||
`^(?:ssh://)?git@github\.com[:/]([A-Za-z0-9_.-]+)/([A-Za-z0-9_.-]+?)(?:\.git)?/?$`,
|
||||
)
|
||||
)
|
||||
|
||||
type githubPullRequestRef struct {
|
||||
Owner string
|
||||
Repo string
|
||||
Number int
|
||||
}
|
||||
|
||||
type githubPullRequestStatus struct {
|
||||
PullRequestState string
|
||||
ChangesRequested bool
|
||||
Additions int32
|
||||
Deletions int32
|
||||
ChangedFiles int32
|
||||
}
|
||||
|
||||
type chatRepositoryRef struct {
|
||||
Provider string
|
||||
RemoteOrigin string
|
||||
@@ -1256,193 +1217,6 @@ func shouldRefreshChatDiffStatus(status database.ChatDiffStatus, now time.Time,
|
||||
return chatDiffStatusIsStale(status, now)
|
||||
}
|
||||
|
||||
func (api *API) triggerWorkspaceChatDiffStatusRefresh(workspace database.Workspace, chatID uuid.NullUUID, gitRef chatGitRef) {
|
||||
if workspace.ID == uuid.Nil || workspace.OwnerID == uuid.Nil {
|
||||
return
|
||||
}
|
||||
|
||||
go func(workspaceID, workspaceOwnerID uuid.UUID, chatID uuid.NullUUID, gitRef chatGitRef) {
|
||||
ctx := api.ctx
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
//nolint:gocritic // Background goroutine for diff status refresh has no user context.
|
||||
ctx = dbauthz.AsSystemRestricted(ctx)
|
||||
|
||||
// Always store the git ref so the data is persisted even
|
||||
// before a PR exists. The frontend can show branch info
|
||||
// and the refresh loop can resolve a PR later.
|
||||
api.storeChatGitRef(ctx, workspaceID, workspaceOwnerID, chatID, gitRef)
|
||||
|
||||
for _, delay := range chatDiffRefreshBackoffSchedule {
|
||||
t := api.Clock.NewTimer(delay, "chat_diff_refresh")
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Stop()
|
||||
return
|
||||
case <-t.C:
|
||||
}
|
||||
|
||||
// Refresh and publish status on every iteration.
|
||||
// Stop the loop once a PR is discovered — there's
|
||||
// nothing more to wait for after that.
|
||||
if api.refreshWorkspaceChatDiffStatuses(ctx, workspaceID, workspaceOwnerID, chatID) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}(workspace.ID, workspace.OwnerID, chatID, gitRef)
|
||||
}
|
||||
|
||||
// storeChatGitRef persists the git branch and remote origin reported
|
||||
// by the workspace agent on the chat that initiated the git operation.
|
||||
// When chatID is set, only that specific chat is updated; otherwise all
|
||||
// chats associated with the workspace are updated (legacy fallback).
|
||||
func (api *API) storeChatGitRef(ctx context.Context, workspaceID, workspaceOwnerID uuid.UUID, chatID uuid.NullUUID, gitRef chatGitRef) {
|
||||
var chatsToUpdate []database.Chat
|
||||
|
||||
if chatID.Valid {
|
||||
chat, err := api.Database.GetChatByID(ctx, chatID.UUID)
|
||||
if err != nil {
|
||||
api.Logger.Warn(ctx, "failed to get chat for git ref storage",
|
||||
slog.F("chat_id", chatID.UUID),
|
||||
slog.F("workspace_id", workspaceID),
|
||||
slog.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
chatsToUpdate = []database.Chat{chat}
|
||||
} else {
|
||||
chats, err := api.Database.GetChatsByOwnerID(ctx, database.GetChatsByOwnerIDParams{
|
||||
OwnerID: workspaceOwnerID,
|
||||
})
|
||||
if err != nil {
|
||||
api.Logger.Warn(ctx, "failed to list chats for git ref storage",
|
||||
slog.F("workspace_id", workspaceID),
|
||||
slog.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
chatsToUpdate = filterChatsByWorkspaceID(chats, workspaceID)
|
||||
}
|
||||
|
||||
for _, chat := range chatsToUpdate {
|
||||
_, err := api.Database.UpsertChatDiffStatusReference(ctx, database.UpsertChatDiffStatusReferenceParams{
|
||||
ChatID: chat.ID,
|
||||
GitBranch: gitRef.Branch,
|
||||
GitRemoteOrigin: gitRef.RemoteOrigin,
|
||||
StaleAt: time.Now().UTC().Add(-time.Second),
|
||||
Url: sql.NullString{},
|
||||
})
|
||||
if err != nil {
|
||||
api.Logger.Warn(ctx, "failed to store git ref on chat diff status",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("workspace_id", workspaceID),
|
||||
slog.Error(err),
|
||||
)
|
||||
continue
|
||||
}
|
||||
api.publishChatDiffStatusEvent(ctx, chat.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// refreshWorkspaceChatDiffStatuses refreshes the diff status for chats
|
||||
// associated with the given workspace. When chatID is set, only that
|
||||
// specific chat is refreshed; otherwise all chats for the workspace
|
||||
// are refreshed (legacy fallback). It returns true when every
|
||||
// refreshed chat has a PR URL resolved, signaling that the caller
|
||||
// can stop polling.
|
||||
func (api *API) refreshWorkspaceChatDiffStatuses(ctx context.Context, workspaceID, workspaceOwnerID uuid.UUID, chatID uuid.NullUUID) bool {
|
||||
var filtered []database.Chat
|
||||
|
||||
if chatID.Valid {
|
||||
chat, err := api.Database.GetChatByID(ctx, chatID.UUID)
|
||||
if err != nil {
|
||||
api.Logger.Warn(ctx, "failed to get chat for diff refresh",
|
||||
slog.F("chat_id", chatID.UUID),
|
||||
slog.F("workspace_id", workspaceID),
|
||||
slog.Error(err),
|
||||
)
|
||||
return false
|
||||
}
|
||||
filtered = []database.Chat{chat}
|
||||
} else {
|
||||
chats, err := api.Database.GetChatsByOwnerID(ctx, database.GetChatsByOwnerIDParams{
|
||||
OwnerID: workspaceOwnerID,
|
||||
})
|
||||
if err != nil {
|
||||
api.Logger.Warn(ctx, "failed to list workspace owner chats for diff refresh",
|
||||
slog.F("workspace_id", workspaceID),
|
||||
slog.F("workspace_owner_id", workspaceOwnerID),
|
||||
slog.Error(err),
|
||||
)
|
||||
return false
|
||||
}
|
||||
filtered = filterChatsByWorkspaceID(chats, workspaceID)
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
allHavePR := true
|
||||
for _, chat := range filtered {
|
||||
refreshCtx, cancel := context.WithTimeout(ctx, chatDiffBackgroundRefreshTimeout)
|
||||
status, err := api.resolveChatDiffStatusWithOptions(refreshCtx, chat, true)
|
||||
cancel()
|
||||
if err != nil {
|
||||
api.Logger.Warn(ctx, "failed to refresh chat diff status after workspace external auth",
|
||||
slog.F("workspace_id", workspaceID),
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.Error(err),
|
||||
)
|
||||
allHavePR = false
|
||||
} else if status == nil || !status.Url.Valid || strings.TrimSpace(status.Url.String) == "" {
|
||||
allHavePR = false
|
||||
}
|
||||
|
||||
api.publishChatStatusEvent(ctx, chat.ID)
|
||||
api.publishChatDiffStatusEvent(ctx, chat.ID)
|
||||
}
|
||||
|
||||
return allHavePR
|
||||
}
|
||||
|
||||
func filterChatsByWorkspaceID(chats []database.Chat, workspaceID uuid.UUID) []database.Chat {
|
||||
filteredChats := make([]database.Chat, 0, len(chats))
|
||||
for _, chat := range chats {
|
||||
if !chat.WorkspaceID.Valid || chat.WorkspaceID.UUID != workspaceID {
|
||||
continue
|
||||
}
|
||||
filteredChats = append(filteredChats, chat)
|
||||
}
|
||||
return filteredChats
|
||||
}
|
||||
|
||||
func (api *API) publishChatStatusEvent(ctx context.Context, chatID uuid.UUID) {
|
||||
if api.chatDaemon == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := api.chatDaemon.RefreshStatus(ctx, chatID); err != nil {
|
||||
api.Logger.Debug(ctx, "failed to refresh published chat status",
|
||||
slog.F("chat_id", chatID),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (api *API) publishChatDiffStatusEvent(ctx context.Context, chatID uuid.UUID) {
|
||||
if api.chatDaemon == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := api.chatDaemon.PublishDiffStatusChange(ctx, chatID); err != nil {
|
||||
api.Logger.Debug(ctx, "failed to publish chat diff status change",
|
||||
slog.F("chat_id", chatID),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (api *API) resolveChatDiffContents(
|
||||
ctx context.Context,
|
||||
chat database.Chat,
|
||||
@@ -1490,22 +1264,36 @@ func (api *API) resolveChatDiffContents(
|
||||
if reference.RepositoryRef == nil {
|
||||
return result, nil
|
||||
}
|
||||
if !strings.EqualFold(reference.RepositoryRef.Provider, string(codersdk.EnhancedExternalAuthProviderGitHub)) {
|
||||
|
||||
gp := api.resolveGitProvider(reference.RepositoryRef.RemoteOrigin)
|
||||
if gp == nil {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
token := api.resolveChatGitHubAccessToken(ctx, chat.OwnerID)
|
||||
token, err := api.resolveChatGitAccessToken(ctx, chat.OwnerID, reference.RepositoryRef.RemoteOrigin)
|
||||
if err != nil {
|
||||
return result, xerrors.Errorf("resolve git access token: %w", err)
|
||||
} else if token == nil {
|
||||
return result, xerrors.New("nil git access token")
|
||||
}
|
||||
|
||||
if reference.PullRequestURL != "" {
|
||||
diff, err := api.fetchGitHubPullRequestDiff(ctx, reference.PullRequestURL, token)
|
||||
ref, ok := gp.ParsePullRequestURL(reference.PullRequestURL)
|
||||
if !ok {
|
||||
return result, xerrors.Errorf("invalid pull request URL %q", reference.PullRequestURL)
|
||||
}
|
||||
diff, err := gp.FetchPullRequestDiff(ctx, *token, ref)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
result.Diff = diff
|
||||
return result, nil
|
||||
}
|
||||
|
||||
diff, err := api.fetchGitHubCompareDiff(ctx, *reference.RepositoryRef, token)
|
||||
diff, err := gp.FetchBranchDiff(ctx, *token, gitprovider.BranchRef{
|
||||
Owner: reference.RepositoryRef.Owner,
|
||||
Repo: reference.RepositoryRef.Repo,
|
||||
Branch: reference.RepositoryRef.Branch,
|
||||
})
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
@@ -1539,34 +1327,53 @@ func (api *API) resolveChatDiffReference(
|
||||
// If we have a repo ref with a branch, try to resolve the
|
||||
// current open PR. This picks up new PRs after the previous
|
||||
// one was closed.
|
||||
if reference.RepositoryRef != nil &&
|
||||
strings.EqualFold(reference.RepositoryRef.Provider, string(codersdk.EnhancedExternalAuthProviderGitHub)) {
|
||||
pullRequestURL, lookupErr := api.resolveGitHubPullRequestURLFromRepositoryRef(ctx, chat.OwnerID, *reference.RepositoryRef)
|
||||
if lookupErr != nil {
|
||||
api.Logger.Debug(ctx, "failed to resolve pull request from repository reference",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("provider", reference.RepositoryRef.Provider),
|
||||
slog.F("remote_origin", reference.RepositoryRef.RemoteOrigin),
|
||||
slog.F("branch", reference.RepositoryRef.Branch),
|
||||
slog.Error(lookupErr),
|
||||
)
|
||||
} else if pullRequestURL != "" {
|
||||
reference.PullRequestURL = pullRequestURL
|
||||
if reference.RepositoryRef != nil && reference.RepositoryRef.Owner != "" {
|
||||
gp := api.resolveGitProvider(reference.RepositoryRef.RemoteOrigin)
|
||||
if gp != nil {
|
||||
token, err := api.resolveChatGitAccessToken(ctx, chat.OwnerID, reference.RepositoryRef.RemoteOrigin)
|
||||
if token == nil || errors.Is(err, gitsync.ErrNoTokenAvailable) {
|
||||
// No token available yet.
|
||||
return reference, nil
|
||||
} else if err != nil {
|
||||
return chatDiffReference{}, xerrors.Errorf("resolve git access token: %w", err)
|
||||
}
|
||||
prRef, lookupErr := gp.ResolveBranchPullRequest(ctx, *token, gitprovider.BranchRef{
|
||||
Owner: reference.RepositoryRef.Owner,
|
||||
Repo: reference.RepositoryRef.Repo,
|
||||
Branch: reference.RepositoryRef.Branch,
|
||||
})
|
||||
if lookupErr != nil {
|
||||
api.Logger.Debug(ctx, "failed to resolve pull request from repository reference",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("provider", reference.RepositoryRef.Provider),
|
||||
slog.F("remote_origin", reference.RepositoryRef.RemoteOrigin),
|
||||
slog.F("branch", reference.RepositoryRef.Branch),
|
||||
slog.Error(lookupErr),
|
||||
)
|
||||
} else if prRef != nil {
|
||||
reference.PullRequestURL = gp.BuildPullRequestURL(*prRef)
|
||||
}
|
||||
reference.PullRequestURL = gp.NormalizePullRequestURL(reference.PullRequestURL)
|
||||
}
|
||||
}
|
||||
|
||||
reference.PullRequestURL = normalizeGitHubPullRequestURL(reference.PullRequestURL)
|
||||
|
||||
// If we have a PR URL but no repo ref (e.g. the agent hasn't
|
||||
// reported branch/origin yet), derive a partial ref from the
|
||||
// PR URL so the caller can still show provider/owner/repo.
|
||||
if reference.RepositoryRef == nil && reference.PullRequestURL != "" {
|
||||
if parsed, ok := parseGitHubPullRequestURL(reference.PullRequestURL); ok {
|
||||
reference.RepositoryRef = &chatRepositoryRef{
|
||||
Provider: string(codersdk.EnhancedExternalAuthProviderGitHub),
|
||||
RemoteOrigin: fmt.Sprintf("https://github.com/%s/%s", parsed.Owner, parsed.Repo),
|
||||
Owner: parsed.Owner,
|
||||
Repo: parsed.Repo,
|
||||
for _, extAuth := range api.ExternalAuthConfigs {
|
||||
gp := extAuth.Git(api.HTTPClient)
|
||||
if gp == nil {
|
||||
continue
|
||||
}
|
||||
if parsed, ok := gp.ParsePullRequestURL(reference.PullRequestURL); ok {
|
||||
reference.RepositoryRef = &chatRepositoryRef{
|
||||
Provider: strings.ToLower(extAuth.Type),
|
||||
Owner: parsed.Owner,
|
||||
Repo: parsed.Repo,
|
||||
RemoteOrigin: gp.BuildRepositoryURL(parsed.Owner, parsed.Repo),
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1584,19 +1391,18 @@ func (api *API) buildChatRepositoryRefFromStatus(status database.ChatDiffStatus)
|
||||
return nil
|
||||
}
|
||||
|
||||
providerType, gp := api.resolveExternalAuth(origin)
|
||||
repoRef := &chatRepositoryRef{
|
||||
Provider: strings.TrimSpace(api.resolveExternalAuthProviderType(origin)),
|
||||
Provider: providerType,
|
||||
RemoteOrigin: origin,
|
||||
Branch: branch,
|
||||
}
|
||||
|
||||
if owner, repo, normalizedOrigin, ok := parseGitHubRepositoryOrigin(repoRef.RemoteOrigin); ok {
|
||||
if repoRef.Provider == "" {
|
||||
repoRef.Provider = string(codersdk.EnhancedExternalAuthProviderGitHub)
|
||||
if gp != nil {
|
||||
if owner, repo, normalizedOrigin, ok := gp.ParseRepositoryOrigin(repoRef.RemoteOrigin); ok {
|
||||
repoRef.RemoteOrigin = normalizedOrigin
|
||||
repoRef.Owner = owner
|
||||
repoRef.Repo = repo
|
||||
}
|
||||
repoRef.RemoteOrigin = normalizedOrigin
|
||||
repoRef.Owner = owner
|
||||
repoRef.Repo = repo
|
||||
}
|
||||
|
||||
if repoRef.Provider == "" {
|
||||
@@ -1650,60 +1456,31 @@ func (api *API) getCachedChatDiffStatus(
|
||||
)
|
||||
}
|
||||
|
||||
func (api *API) resolveExternalAuthProviderType(match string) string {
|
||||
match = strings.TrimSpace(match)
|
||||
if match == "" {
|
||||
return ""
|
||||
// resolveExternalAuth finds the external auth config matching the
|
||||
// given remote origin URL and returns both the provider type string
|
||||
// (e.g. "github") and the gitprovider.Provider. Returns ("", nil)
|
||||
// if no matching config is found.
|
||||
func (api *API) resolveExternalAuth(origin string) (providerType string, gp gitprovider.Provider) {
|
||||
origin = strings.TrimSpace(origin)
|
||||
if origin == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
for _, extAuth := range api.ExternalAuthConfigs {
|
||||
if extAuth.Regex == nil || !extAuth.Regex.MatchString(match) {
|
||||
if extAuth.Regex == nil || !extAuth.Regex.MatchString(origin) {
|
||||
continue
|
||||
}
|
||||
return strings.ToLower(strings.TrimSpace(extAuth.Type))
|
||||
return strings.ToLower(strings.TrimSpace(extAuth.Type)),
|
||||
extAuth.Git(api.HTTPClient)
|
||||
}
|
||||
|
||||
return ""
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func parseGitHubRepositoryOrigin(raw string) (owner string, repo string, normalizedOrigin string, ok bool) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return "", "", "", false
|
||||
}
|
||||
|
||||
matches := githubRepositoryHTTPSPattern.FindStringSubmatch(raw)
|
||||
if len(matches) != 3 {
|
||||
matches = githubRepositorySSHPathPattern.FindStringSubmatch(raw)
|
||||
}
|
||||
if len(matches) != 3 {
|
||||
return "", "", "", false
|
||||
}
|
||||
|
||||
owner = strings.TrimSpace(matches[1])
|
||||
repo = strings.TrimSpace(matches[2])
|
||||
repo = strings.TrimSuffix(repo, ".git")
|
||||
if owner == "" || repo == "" {
|
||||
return "", "", "", false
|
||||
}
|
||||
|
||||
return owner, repo, fmt.Sprintf("https://github.com/%s/%s", owner, repo), true
|
||||
}
|
||||
|
||||
func buildGitHubBranchURL(owner string, repo string, branch string) string {
|
||||
owner = strings.TrimSpace(owner)
|
||||
repo = strings.TrimSpace(repo)
|
||||
branch = strings.TrimSpace(branch)
|
||||
if owner == "" || repo == "" || branch == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"https://github.com/%s/%s/tree/%s",
|
||||
owner,
|
||||
repo,
|
||||
url.PathEscape(branch),
|
||||
)
|
||||
// resolveGitProvider finds the external auth config matching the
|
||||
// given remote origin URL and returns its git provider. Returns
|
||||
// nil if no matching git provider is configured.
|
||||
func (api *API) resolveGitProvider(origin string) gitprovider.Provider {
|
||||
_, gp := api.resolveExternalAuth(origin)
|
||||
return gp
|
||||
}
|
||||
|
||||
func chatDiffStatusIsStale(status database.ChatDiffStatus, now time.Time) bool {
|
||||
@@ -1719,11 +1496,32 @@ func (api *API) refreshChatDiffStatus(
|
||||
chatID uuid.UUID,
|
||||
pullRequestURL string,
|
||||
) (database.ChatDiffStatus, error) {
|
||||
status, err := api.fetchGitHubPullRequestStatus(
|
||||
ctx,
|
||||
pullRequestURL,
|
||||
api.resolveChatGitHubAccessToken(ctx, chatOwnerID),
|
||||
)
|
||||
// Find a provider that can handle this PR URL.
|
||||
var gp gitprovider.Provider
|
||||
var ref gitprovider.PRRef
|
||||
for _, extAuth := range api.ExternalAuthConfigs {
|
||||
p := extAuth.Git(api.HTTPClient)
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
if parsed, ok := p.ParsePullRequestURL(pullRequestURL); ok {
|
||||
gp = p
|
||||
ref = parsed
|
||||
break
|
||||
}
|
||||
}
|
||||
if gp == nil {
|
||||
return database.ChatDiffStatus{}, xerrors.Errorf("no git provider found for PR URL %q", pullRequestURL)
|
||||
}
|
||||
|
||||
origin := gp.BuildRepositoryURL(ref.Owner, ref.Repo)
|
||||
token, err := api.resolveChatGitAccessToken(ctx, chatOwnerID, origin)
|
||||
if err != nil {
|
||||
return database.ChatDiffStatus{}, xerrors.Errorf("resolve git access token: %w", err)
|
||||
} else if token == nil {
|
||||
return database.ChatDiffStatus{}, xerrors.New("nil git access token")
|
||||
}
|
||||
status, err := gp.FetchPullRequestStatus(ctx, *token, ref)
|
||||
if err != nil {
|
||||
return database.ChatDiffStatus{}, err
|
||||
}
|
||||
@@ -1735,13 +1533,13 @@ func (api *API) refreshChatDiffStatus(
|
||||
ChatID: chatID,
|
||||
Url: sql.NullString{String: pullRequestURL, Valid: true},
|
||||
PullRequestState: sql.NullString{
|
||||
String: status.PullRequestState,
|
||||
Valid: status.PullRequestState != "",
|
||||
String: string(status.State),
|
||||
Valid: status.State != "",
|
||||
},
|
||||
ChangesRequested: status.ChangesRequested,
|
||||
Additions: status.Additions,
|
||||
Deletions: status.Deletions,
|
||||
ChangedFiles: status.ChangedFiles,
|
||||
Additions: status.DiffStats.Additions,
|
||||
Deletions: status.DiffStats.Deletions,
|
||||
ChangedFiles: status.DiffStats.ChangedFiles,
|
||||
RefreshedAt: refreshedAt,
|
||||
StaleAt: refreshedAt.Add(chatDiffStatusTTL),
|
||||
},
|
||||
@@ -1752,23 +1550,49 @@ func (api *API) refreshChatDiffStatus(
|
||||
return refreshedStatus, nil
|
||||
}
|
||||
|
||||
func (api *API) resolveChatGitHubAccessToken(
|
||||
func (api *API) resolveChatGitAccessToken(
|
||||
ctx context.Context,
|
||||
userID uuid.UUID,
|
||||
) string {
|
||||
// Build a map of provider ID -> config so we can refresh tokens
|
||||
// using the same code path as provisionerdserver.
|
||||
ghConfigs := make(map[string]*externalauth.Config)
|
||||
providerIDs := []string{"github"}
|
||||
for _, config := range api.ExternalAuthConfigs {
|
||||
if !strings.EqualFold(
|
||||
config.Type,
|
||||
string(codersdk.EnhancedExternalAuthProviderGitHub),
|
||||
) {
|
||||
continue
|
||||
origin string,
|
||||
) (*string, error) {
|
||||
origin = strings.TrimSpace(origin)
|
||||
|
||||
// If we have an origin, find the specific matching config first.
|
||||
// This ensures multi-provider setups (github.com + GHE) get the
|
||||
// correct token.
|
||||
if origin != "" {
|
||||
for _, config := range api.ExternalAuthConfigs {
|
||||
if config.Regex == nil || !config.Regex.MatchString(origin) {
|
||||
continue
|
||||
}
|
||||
link, err := api.Database.GetExternalAuthLink(ctx,
|
||||
database.GetExternalAuthLinkParams{
|
||||
ProviderID: config.ID,
|
||||
UserID: userID,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
refreshed, refreshErr := config.RefreshToken(ctx, api.Database, link)
|
||||
if refreshErr == nil {
|
||||
link = refreshed
|
||||
}
|
||||
token := strings.TrimSpace(link.OAuthAccessToken)
|
||||
if token != "" {
|
||||
return ptr.Ref(token), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: iterate all external auth configs.
|
||||
// Used when origin is empty (inline refresh from HTTP handler)
|
||||
// or when the origin-specific lookup above failed.
|
||||
configs := make(map[string]*externalauth.Config)
|
||||
providerIDs := []string{}
|
||||
for _, config := range api.ExternalAuthConfigs {
|
||||
providerIDs = append(providerIDs, config.ID)
|
||||
ghConfigs[config.ID] = config
|
||||
configs[config.ID] = config
|
||||
}
|
||||
|
||||
seen := map[string]struct{}{}
|
||||
@@ -1792,7 +1616,7 @@ func (api *API) resolveChatGitHubAccessToken(
|
||||
// Refresh the token if there is a matching config, mirroring
|
||||
// the same code path used by provisionerdserver when handing
|
||||
// tokens to provisioners.
|
||||
if cfg, ok := ghConfigs[providerID]; ok {
|
||||
if cfg, ok := configs[providerID]; ok {
|
||||
refreshed, refreshErr := cfg.RefreshToken(ctx, api.Database, link)
|
||||
if refreshErr != nil {
|
||||
api.Logger.Debug(ctx, "failed to refresh external auth token for chat diff",
|
||||
@@ -1809,336 +1633,11 @@ func (api *API) resolveChatGitHubAccessToken(
|
||||
|
||||
token := strings.TrimSpace(link.OAuthAccessToken)
|
||||
if token != "" {
|
||||
return token
|
||||
return ptr.Ref(token), nil
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func (api *API) resolveGitHubPullRequestURLFromRepositoryRef(
|
||||
ctx context.Context,
|
||||
userID uuid.UUID,
|
||||
repositoryRef chatRepositoryRef,
|
||||
) (string, error) {
|
||||
if repositoryRef.Owner == "" || repositoryRef.Repo == "" || repositoryRef.Branch == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
query := url.Values{}
|
||||
query.Set("state", "open")
|
||||
query.Set("head", fmt.Sprintf("%s:%s", repositoryRef.Owner, repositoryRef.Branch))
|
||||
query.Set("sort", "updated")
|
||||
query.Set("direction", "desc")
|
||||
query.Set("per_page", "1")
|
||||
|
||||
requestURL := fmt.Sprintf(
|
||||
"%s/repos/%s/%s/pulls?%s",
|
||||
githubAPIBaseURL,
|
||||
repositoryRef.Owner,
|
||||
repositoryRef.Repo,
|
||||
query.Encode(),
|
||||
)
|
||||
|
||||
var pulls []struct {
|
||||
HTMLURL string `json:"html_url"`
|
||||
}
|
||||
|
||||
token := api.resolveChatGitHubAccessToken(ctx, userID)
|
||||
if err := api.decodeGitHubJSON(ctx, requestURL, token, &pulls); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if len(pulls) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
return normalizeGitHubPullRequestURL(pulls[0].HTMLURL), nil
|
||||
}
|
||||
|
||||
func (api *API) fetchGitHubPullRequestDiff(
|
||||
ctx context.Context,
|
||||
pullRequestURL string,
|
||||
token string,
|
||||
) (string, error) {
|
||||
ref, ok := parseGitHubPullRequestURL(pullRequestURL)
|
||||
if !ok {
|
||||
return "", xerrors.Errorf("invalid GitHub pull request URL %q", pullRequestURL)
|
||||
}
|
||||
|
||||
requestURL := fmt.Sprintf(
|
||||
"%s/repos/%s/%s/pulls/%d",
|
||||
githubAPIBaseURL,
|
||||
ref.Owner,
|
||||
ref.Repo,
|
||||
ref.Number,
|
||||
)
|
||||
|
||||
return api.fetchGitHubDiff(ctx, requestURL, token)
|
||||
}
|
||||
|
||||
func (api *API) fetchGitHubCompareDiff(
|
||||
ctx context.Context,
|
||||
repositoryRef chatRepositoryRef,
|
||||
token string,
|
||||
) (string, error) {
|
||||
if repositoryRef.Owner == "" || repositoryRef.Repo == "" || repositoryRef.Branch == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
var repository struct {
|
||||
DefaultBranch string `json:"default_branch"`
|
||||
}
|
||||
|
||||
repositoryURL := fmt.Sprintf(
|
||||
"%s/repos/%s/%s",
|
||||
githubAPIBaseURL,
|
||||
repositoryRef.Owner,
|
||||
repositoryRef.Repo,
|
||||
)
|
||||
if err := api.decodeGitHubJSON(ctx, repositoryURL, token, &repository); err != nil {
|
||||
return "", err
|
||||
}
|
||||
defaultBranch := strings.TrimSpace(repository.DefaultBranch)
|
||||
if defaultBranch == "" {
|
||||
return "", xerrors.New("github repository default branch is empty")
|
||||
}
|
||||
|
||||
requestURL := fmt.Sprintf(
|
||||
"%s/repos/%s/%s/compare/%s...%s",
|
||||
githubAPIBaseURL,
|
||||
repositoryRef.Owner,
|
||||
repositoryRef.Repo,
|
||||
url.PathEscape(defaultBranch),
|
||||
url.PathEscape(repositoryRef.Branch),
|
||||
)
|
||||
|
||||
return api.fetchGitHubDiff(ctx, requestURL, token)
|
||||
}
|
||||
|
||||
func (api *API) fetchGitHubDiff(
|
||||
ctx context.Context,
|
||||
requestURL string,
|
||||
token string,
|
||||
) (string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("create github diff request: %w", err)
|
||||
}
|
||||
req.Header.Set("Accept", "application/vnd.github.diff")
|
||||
req.Header.Set("X-GitHub-Api-Version", "2022-11-28")
|
||||
req.Header.Set("User-Agent", "coder-chat-diff")
|
||||
if token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
|
||||
httpClient := api.HTTPClient
|
||||
if httpClient == nil {
|
||||
httpClient = http.DefaultClient
|
||||
}
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("execute github diff request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, readErr := io.ReadAll(io.LimitReader(resp.Body, 8192))
|
||||
if readErr != nil {
|
||||
return "", xerrors.Errorf("github diff request failed with status %d", resp.StatusCode)
|
||||
}
|
||||
return "", xerrors.Errorf(
|
||||
"github diff request failed with status %d: %s",
|
||||
resp.StatusCode,
|
||||
strings.TrimSpace(string(body)),
|
||||
)
|
||||
}
|
||||
|
||||
diff, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20))
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("read github diff response: %w", err)
|
||||
}
|
||||
return string(diff), nil
|
||||
}
|
||||
|
||||
func (api *API) fetchGitHubPullRequestStatus(
|
||||
ctx context.Context,
|
||||
pullRequestURL string,
|
||||
token string,
|
||||
) (githubPullRequestStatus, error) {
|
||||
ref, ok := parseGitHubPullRequestURL(pullRequestURL)
|
||||
if !ok {
|
||||
return githubPullRequestStatus{}, xerrors.Errorf(
|
||||
"invalid GitHub pull request URL %q",
|
||||
pullRequestURL,
|
||||
)
|
||||
}
|
||||
|
||||
pullEndpoint := fmt.Sprintf(
|
||||
"%s/repos/%s/%s/pulls/%d",
|
||||
githubAPIBaseURL,
|
||||
ref.Owner,
|
||||
ref.Repo,
|
||||
ref.Number,
|
||||
)
|
||||
|
||||
var pull struct {
|
||||
State string `json:"state"`
|
||||
Additions int32 `json:"additions"`
|
||||
Deletions int32 `json:"deletions"`
|
||||
ChangedFiles int32 `json:"changed_files"`
|
||||
}
|
||||
if err := api.decodeGitHubJSON(ctx, pullEndpoint, token, &pull); err != nil {
|
||||
return githubPullRequestStatus{}, err
|
||||
}
|
||||
|
||||
var reviews []struct {
|
||||
ID int64 `json:"id"`
|
||||
State string `json:"state"`
|
||||
User struct {
|
||||
Login string `json:"login"`
|
||||
} `json:"user"`
|
||||
}
|
||||
if err := api.decodeGitHubJSON(
|
||||
ctx,
|
||||
pullEndpoint+"/reviews?per_page=100",
|
||||
token,
|
||||
&reviews,
|
||||
); err != nil {
|
||||
return githubPullRequestStatus{}, err
|
||||
}
|
||||
|
||||
return githubPullRequestStatus{
|
||||
PullRequestState: strings.ToLower(strings.TrimSpace(pull.State)),
|
||||
ChangesRequested: hasOutstandingGitHubChangesRequested(reviews),
|
||||
Additions: pull.Additions,
|
||||
Deletions: pull.Deletions,
|
||||
ChangedFiles: pull.ChangedFiles,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (api *API) decodeGitHubJSON(
|
||||
ctx context.Context,
|
||||
requestURL string,
|
||||
token string,
|
||||
dest any,
|
||||
) error {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create github request: %w", err)
|
||||
}
|
||||
req.Header.Set("Accept", "application/vnd.github+json")
|
||||
req.Header.Set("X-GitHub-Api-Version", "2022-11-28")
|
||||
req.Header.Set("User-Agent", "coder-chat-diff-status")
|
||||
if token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
|
||||
httpClient := api.HTTPClient
|
||||
if httpClient == nil {
|
||||
httpClient = http.DefaultClient
|
||||
}
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("execute github request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, readErr := io.ReadAll(io.LimitReader(resp.Body, 8192))
|
||||
if readErr != nil {
|
||||
return xerrors.Errorf(
|
||||
"github request failed with status %d",
|
||||
resp.StatusCode,
|
||||
)
|
||||
}
|
||||
return xerrors.Errorf(
|
||||
"github request failed with status %d: %s",
|
||||
resp.StatusCode,
|
||||
strings.TrimSpace(string(body)),
|
||||
)
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(dest); err != nil {
|
||||
return xerrors.Errorf("decode github response: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func hasOutstandingGitHubChangesRequested(
|
||||
reviews []struct {
|
||||
ID int64 `json:"id"`
|
||||
State string `json:"state"`
|
||||
User struct {
|
||||
Login string `json:"login"`
|
||||
} `json:"user"`
|
||||
},
|
||||
) bool {
|
||||
type reviewerState struct {
|
||||
reviewID int64
|
||||
state string
|
||||
}
|
||||
|
||||
statesByReviewer := make(map[string]reviewerState)
|
||||
for _, review := range reviews {
|
||||
login := strings.ToLower(strings.TrimSpace(review.User.Login))
|
||||
if login == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
state := strings.ToUpper(strings.TrimSpace(review.State))
|
||||
switch state {
|
||||
case "CHANGES_REQUESTED", "APPROVED", "DISMISSED":
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
current, exists := statesByReviewer[login]
|
||||
if exists && current.reviewID > review.ID {
|
||||
continue
|
||||
}
|
||||
statesByReviewer[login] = reviewerState{
|
||||
reviewID: review.ID,
|
||||
state: state,
|
||||
}
|
||||
}
|
||||
|
||||
for _, state := range statesByReviewer {
|
||||
if state.state == "CHANGES_REQUESTED" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func normalizeGitHubPullRequestURL(raw string) string {
|
||||
ref, ok := parseGitHubPullRequestURL(strings.TrimRight(
|
||||
strings.TrimSpace(raw),
|
||||
"),.;",
|
||||
))
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("https://github.com/%s/%s/pull/%d", ref.Owner, ref.Repo, ref.Number)
|
||||
}
|
||||
|
||||
func parseGitHubPullRequestURL(raw string) (githubPullRequestRef, bool) {
|
||||
matches := githubPullRequestPathPattern.FindStringSubmatch(strings.TrimSpace(raw))
|
||||
if len(matches) != 4 {
|
||||
return githubPullRequestRef{}, false
|
||||
}
|
||||
|
||||
number, err := strconv.Atoi(matches[3])
|
||||
if err != nil {
|
||||
return githubPullRequestRef{}, false
|
||||
}
|
||||
|
||||
return githubPullRequestRef{
|
||||
Owner: matches[1],
|
||||
Repo: matches[2],
|
||||
Number: number,
|
||||
}, true
|
||||
return nil, gitsync.ErrNoTokenAvailable
|
||||
}
|
||||
|
||||
type createChatWorkspaceSelection struct {
|
||||
@@ -2696,11 +2195,21 @@ func convertChatDiffStatus(chatID uuid.UUID, status *database.ChatDiffStatus) co
|
||||
}
|
||||
}
|
||||
if result.URL == nil {
|
||||
owner, repo, _, ok := parseGitHubRepositoryOrigin(status.GitRemoteOrigin)
|
||||
if ok {
|
||||
branchURL := buildGitHubBranchURL(owner, repo, status.GitBranch)
|
||||
if branchURL != "" {
|
||||
result.URL = &branchURL
|
||||
// Try to build a branch URL from the stored origin.
|
||||
// Since convertChatDiffStatus does not have access to
|
||||
// the API instance, we construct a GitHub provider
|
||||
// directly as a best-effort fallback.
|
||||
// TODO: This uses the default github.com API base URL,
|
||||
// so branch URLs for GitHub Enterprise instances will
|
||||
// be incorrect. To fix this, convertChatDiffStatus
|
||||
// would need access to the external auth configs.
|
||||
gp := gitprovider.New("github", "", nil)
|
||||
if gp != nil {
|
||||
if owner, repo, _, ok := gp.ParseRepositoryOrigin(status.GitRemoteOrigin); ok {
|
||||
branchURL := gp.BuildBranchURL(owner, repo, status.GitBranch)
|
||||
if branchURL != "" {
|
||||
result.URL = &branchURL
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2521,7 +2521,7 @@ func TestGetChatDiffStatus(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, cachedStatusChat.ID, cachedStatus.ChatID)
|
||||
require.NotNil(t, cachedStatus.URL)
|
||||
require.Equal(t, "https://github.com/coder/coder/tree/feature%2Fdiff-status", *cachedStatus.URL)
|
||||
require.Equal(t, "https://github.com/coder/coder/tree/feature/diff-status", *cachedStatus.URL)
|
||||
require.NotNil(t, cachedStatus.PullRequestState)
|
||||
require.Equal(t, "open", *cachedStatus.PullRequestState)
|
||||
require.True(t, cachedStatus.ChangesRequested)
|
||||
|
||||
@@ -61,6 +61,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/externalauth"
|
||||
"github.com/coder/coder/v2/coderd/files"
|
||||
"github.com/coder/coder/v2/coderd/gitsshkey"
|
||||
"github.com/coder/coder/v2/coderd/gitsync"
|
||||
"github.com/coder/coder/v2/coderd/healthcheck"
|
||||
"github.com/coder/coder/v2/coderd/healthcheck/derphealth"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
@@ -772,6 +773,20 @@ func New(options *Options) *API {
|
||||
Pubsub: options.Pubsub,
|
||||
WebpushDispatcher: options.WebPushDispatcher,
|
||||
})
|
||||
refresher := gitsync.NewRefresher(
|
||||
api.resolveGitProvider,
|
||||
api.resolveChatGitAccessToken,
|
||||
options.Logger.Named("gitsync").Named("refresher"),
|
||||
quartz.NewReal(),
|
||||
)
|
||||
api.gitSyncWorker = gitsync.NewWorker(options.Database,
|
||||
refresher,
|
||||
api.chatDaemon,
|
||||
quartz.NewReal(),
|
||||
options.Logger.Named("gitsync"),
|
||||
)
|
||||
// nolint:gocritic // chat diff worker needs to be able to CRUD chats.
|
||||
go api.gitSyncWorker.Start(dbauthz.AsChatd(api.ctx))
|
||||
if options.DeploymentValues.Prometheus.Enable {
|
||||
options.PrometheusRegistry.MustRegister(stn)
|
||||
api.lifecycleMetrics = agentapi.NewLifecycleMetrics(options.PrometheusRegistry)
|
||||
@@ -1989,6 +2004,9 @@ type API struct {
|
||||
dbRolluper *dbrollup.Rolluper
|
||||
// chatDaemon handles background processing of pending chats.
|
||||
chatDaemon *chatd.Server
|
||||
// gitSyncWorker refreshes stale chat diff statuses in the
|
||||
// background.
|
||||
gitSyncWorker *gitsync.Worker
|
||||
}
|
||||
|
||||
// Close waits for all WebSocket connections to drain before returning.
|
||||
@@ -2018,6 +2036,13 @@ func (api *API) Close() error {
|
||||
api.Logger.Warn(api.ctx, "websocket shutdown timed out after 10 seconds")
|
||||
}
|
||||
api.dbRolluper.Close()
|
||||
// chatDiffWorker is unconditionally initialized in New().
|
||||
select {
|
||||
case <-api.gitSyncWorker.Done():
|
||||
case <-time.After(10 * time.Second):
|
||||
api.Logger.Warn(context.Background(),
|
||||
"chat diff refresh worker did not exit in time")
|
||||
}
|
||||
if err := api.chatDaemon.Close(); err != nil {
|
||||
api.Logger.Warn(api.ctx, "close chat processor", slog.Error(err))
|
||||
}
|
||||
|
||||
@@ -1539,6 +1539,17 @@ func (q *querier) AcquireProvisionerJob(ctx context.Context, arg database.Acquir
|
||||
return q.db.AcquireProvisionerJob(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) AcquireStaleChatDiffStatuses(ctx context.Context, limitVal int32) ([]database.AcquireStaleChatDiffStatusesRow, error) {
|
||||
// This is a system-level batch operation used by the gitsync
|
||||
// background worker. Per-object authorization is impractical
|
||||
// for a SKIP LOCKED acquisition query; callers must use
|
||||
// AsChatd or AsSystemRestricted context.
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceChat); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.AcquireStaleChatDiffStatuses(ctx, limitVal)
|
||||
}
|
||||
|
||||
func (q *querier) ActivityBumpWorkspace(ctx context.Context, arg database.ActivityBumpWorkspaceParams) error {
|
||||
fetch := func(ctx context.Context, arg database.ActivityBumpWorkspaceParams) (database.Workspace, error) {
|
||||
return q.db.GetWorkspaceByID(ctx, arg.WorkspaceID)
|
||||
@@ -1577,6 +1588,16 @@ func (q *querier) ArchiveUnusedTemplateVersions(ctx context.Context, arg databas
|
||||
return q.db.ArchiveUnusedTemplateVersions(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) BackoffChatDiffStatus(ctx context.Context, arg database.BackoffChatDiffStatusParams) error {
|
||||
// This is a system-level operation used by the gitsync
|
||||
// background worker to reschedule failed refreshes. Same
|
||||
// authorization pattern as AcquireStaleChatDiffStatuses.
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceChat); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.BackoffChatDiffStatus(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) BatchUpdateWorkspaceAgentMetadata(ctx context.Context, arg database.BatchUpdateWorkspaceAgentMetadataParams) error {
|
||||
// Could be any workspace agent and checking auth to each workspace agent is overkill for
|
||||
// the purpose of this function.
|
||||
|
||||
@@ -758,6 +758,18 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().UpsertChatDiffStatusReference(gomock.Any(), arg).Return(diffStatus, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(diffStatus)
|
||||
}))
|
||||
s.Run("AcquireStaleChatDiffStatuses", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().AcquireStaleChatDiffStatuses(gomock.Any(), int32(10)).Return([]database.AcquireStaleChatDiffStatusesRow{}, nil).AnyTimes()
|
||||
check.Args(int32(10)).Asserts(rbac.ResourceChat, policy.ActionUpdate).Returns([]database.AcquireStaleChatDiffStatusesRow{})
|
||||
}))
|
||||
s.Run("BackoffChatDiffStatus", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.BackoffChatDiffStatusParams{
|
||||
ChatID: uuid.New(),
|
||||
StaleAt: dbtime.Now(),
|
||||
}
|
||||
dbm.EXPECT().BackoffChatDiffStatus(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceChat, policy.ActionUpdate).Returns()
|
||||
}))
|
||||
}
|
||||
|
||||
func (s *MethodTestSuite) TestFile() {
|
||||
|
||||
@@ -136,6 +136,14 @@ func (m queryMetricsStore) AcquireProvisionerJob(ctx context.Context, arg databa
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) AcquireStaleChatDiffStatuses(ctx context.Context, limitVal int32) ([]database.AcquireStaleChatDiffStatusesRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.AcquireStaleChatDiffStatuses(ctx, limitVal)
|
||||
m.queryLatencies.WithLabelValues("AcquireStaleChatDiffStatuses").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "AcquireStaleChatDiffStatuses").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ActivityBumpWorkspace(ctx context.Context, arg database.ActivityBumpWorkspaceParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.ActivityBumpWorkspace(ctx, arg)
|
||||
@@ -168,6 +176,14 @@ func (m queryMetricsStore) ArchiveUnusedTemplateVersions(ctx context.Context, ar
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) BackoffChatDiffStatus(ctx context.Context, arg database.BackoffChatDiffStatusParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.BackoffChatDiffStatus(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("BackoffChatDiffStatus").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "BackoffChatDiffStatus").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) BatchUpdateWorkspaceAgentMetadata(ctx context.Context, arg database.BatchUpdateWorkspaceAgentMetadataParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.BatchUpdateWorkspaceAgentMetadata(ctx, arg)
|
||||
|
||||
@@ -103,6 +103,21 @@ func (mr *MockStoreMockRecorder) AcquireProvisionerJob(ctx, arg any) *gomock.Cal
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcquireProvisionerJob", reflect.TypeOf((*MockStore)(nil).AcquireProvisionerJob), ctx, arg)
|
||||
}
|
||||
|
||||
// AcquireStaleChatDiffStatuses mocks base method.
|
||||
func (m *MockStore) AcquireStaleChatDiffStatuses(ctx context.Context, limitVal int32) ([]database.AcquireStaleChatDiffStatusesRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AcquireStaleChatDiffStatuses", ctx, limitVal)
|
||||
ret0, _ := ret[0].([]database.AcquireStaleChatDiffStatusesRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// AcquireStaleChatDiffStatuses indicates an expected call of AcquireStaleChatDiffStatuses.
|
||||
func (mr *MockStoreMockRecorder) AcquireStaleChatDiffStatuses(ctx, limitVal any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcquireStaleChatDiffStatuses", reflect.TypeOf((*MockStore)(nil).AcquireStaleChatDiffStatuses), ctx, limitVal)
|
||||
}
|
||||
|
||||
// ActivityBumpWorkspace mocks base method.
|
||||
func (m *MockStore) ActivityBumpWorkspace(ctx context.Context, arg database.ActivityBumpWorkspaceParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -161,6 +176,20 @@ func (mr *MockStoreMockRecorder) ArchiveUnusedTemplateVersions(ctx, arg any) *go
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ArchiveUnusedTemplateVersions", reflect.TypeOf((*MockStore)(nil).ArchiveUnusedTemplateVersions), ctx, arg)
|
||||
}
|
||||
|
||||
// BackoffChatDiffStatus mocks base method.
|
||||
func (m *MockStore) BackoffChatDiffStatus(ctx context.Context, arg database.BackoffChatDiffStatusParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "BackoffChatDiffStatus", ctx, arg)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// BackoffChatDiffStatus indicates an expected call of BackoffChatDiffStatus.
|
||||
func (mr *MockStoreMockRecorder) BackoffChatDiffStatus(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BackoffChatDiffStatus", reflect.TypeOf((*MockStore)(nil).BackoffChatDiffStatus), ctx, arg)
|
||||
}
|
||||
|
||||
// BatchUpdateWorkspaceAgentMetadata mocks base method.
|
||||
func (m *MockStore) BatchUpdateWorkspaceAgentMetadata(ctx context.Context, arg database.BatchUpdateWorkspaceAgentMetadataParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -39,6 +39,7 @@ type sqlcQuerier interface {
|
||||
// multiple provisioners from acquiring the same jobs. See:
|
||||
// https://www.postgresql.org/docs/9.5/sql-select.html#SQL-FOR-UPDATE-SHARE
|
||||
AcquireProvisionerJob(ctx context.Context, arg AcquireProvisionerJobParams) (ProvisionerJob, error)
|
||||
AcquireStaleChatDiffStatuses(ctx context.Context, limitVal int32) ([]AcquireStaleChatDiffStatusesRow, error)
|
||||
// Bumps the workspace deadline by the template's configured "activity_bump"
|
||||
// duration (default 1h). If the workspace bump will cross an autostart
|
||||
// threshold, then the bump is autostart + TTL. This is the deadline behavior if
|
||||
@@ -60,6 +61,7 @@ type sqlcQuerier interface {
|
||||
// Only unused template versions will be archived, which are any versions not
|
||||
// referenced by the latest build of a workspace.
|
||||
ArchiveUnusedTemplateVersions(ctx context.Context, arg ArchiveUnusedTemplateVersionsParams) ([]uuid.UUID, error)
|
||||
BackoffChatDiffStatus(ctx context.Context, arg BackoffChatDiffStatusParams) error
|
||||
BatchUpdateWorkspaceAgentMetadata(ctx context.Context, arg BatchUpdateWorkspaceAgentMetadataParams) error
|
||||
BatchUpdateWorkspaceLastUsedAt(ctx context.Context, arg BatchUpdateWorkspaceLastUsedAtParams) error
|
||||
BatchUpdateWorkspaceNextStartAt(ctx context.Context, arg BatchUpdateWorkspaceNextStartAtParams) error
|
||||
|
||||
@@ -3026,6 +3026,102 @@ func (q *sqlQuerier) AcquireChat(ctx context.Context, arg AcquireChatParams) (Ch
|
||||
return i, err
|
||||
}
|
||||
|
||||
const acquireStaleChatDiffStatuses = `-- name: AcquireStaleChatDiffStatuses :many
|
||||
WITH acquired AS (
|
||||
UPDATE
|
||||
chat_diff_statuses
|
||||
SET
|
||||
-- Claim for 5 minutes. The worker sets the real stale_at
|
||||
-- after refresh. If the worker crashes, rows become eligible
|
||||
-- again after this interval.
|
||||
stale_at = NOW() + INTERVAL '5 minutes',
|
||||
updated_at = NOW()
|
||||
WHERE
|
||||
chat_id IN (
|
||||
SELECT
|
||||
cds.chat_id
|
||||
FROM
|
||||
chat_diff_statuses cds
|
||||
INNER JOIN
|
||||
chats c ON c.id = cds.chat_id
|
||||
WHERE
|
||||
cds.stale_at <= NOW()
|
||||
AND cds.git_remote_origin != ''
|
||||
AND cds.git_branch != ''
|
||||
AND c.archived = FALSE
|
||||
ORDER BY
|
||||
cds.stale_at ASC
|
||||
FOR UPDATE OF cds
|
||||
SKIP LOCKED
|
||||
LIMIT
|
||||
$1::int
|
||||
)
|
||||
RETURNING chat_id, url, pull_request_state, changes_requested, additions, deletions, changed_files, refreshed_at, stale_at, created_at, updated_at, git_branch, git_remote_origin
|
||||
)
|
||||
SELECT
|
||||
acquired.chat_id, acquired.url, acquired.pull_request_state, acquired.changes_requested, acquired.additions, acquired.deletions, acquired.changed_files, acquired.refreshed_at, acquired.stale_at, acquired.created_at, acquired.updated_at, acquired.git_branch, acquired.git_remote_origin,
|
||||
c.owner_id
|
||||
FROM
|
||||
acquired
|
||||
INNER JOIN
|
||||
chats c ON c.id = acquired.chat_id
|
||||
`
|
||||
|
||||
type AcquireStaleChatDiffStatusesRow struct {
|
||||
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
|
||||
Url sql.NullString `db:"url" json:"url"`
|
||||
PullRequestState sql.NullString `db:"pull_request_state" json:"pull_request_state"`
|
||||
ChangesRequested bool `db:"changes_requested" json:"changes_requested"`
|
||||
Additions int32 `db:"additions" json:"additions"`
|
||||
Deletions int32 `db:"deletions" json:"deletions"`
|
||||
ChangedFiles int32 `db:"changed_files" json:"changed_files"`
|
||||
RefreshedAt sql.NullTime `db:"refreshed_at" json:"refreshed_at"`
|
||||
StaleAt time.Time `db:"stale_at" json:"stale_at"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
GitBranch string `db:"git_branch" json:"git_branch"`
|
||||
GitRemoteOrigin string `db:"git_remote_origin" json:"git_remote_origin"`
|
||||
OwnerID uuid.UUID `db:"owner_id" json:"owner_id"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) AcquireStaleChatDiffStatuses(ctx context.Context, limitVal int32) ([]AcquireStaleChatDiffStatusesRow, error) {
|
||||
rows, err := q.db.QueryContext(ctx, acquireStaleChatDiffStatuses, limitVal)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []AcquireStaleChatDiffStatusesRow
|
||||
for rows.Next() {
|
||||
var i AcquireStaleChatDiffStatusesRow
|
||||
if err := rows.Scan(
|
||||
&i.ChatID,
|
||||
&i.Url,
|
||||
&i.PullRequestState,
|
||||
&i.ChangesRequested,
|
||||
&i.Additions,
|
||||
&i.Deletions,
|
||||
&i.ChangedFiles,
|
||||
&i.RefreshedAt,
|
||||
&i.StaleAt,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.GitBranch,
|
||||
&i.GitRemoteOrigin,
|
||||
&i.OwnerID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const archiveChatByID = `-- name: ArchiveChatByID :exec
|
||||
UPDATE chats SET archived = true, updated_at = NOW()
|
||||
WHERE id = $1 OR root_chat_id = $1
|
||||
@@ -3036,6 +3132,26 @@ func (q *sqlQuerier) ArchiveChatByID(ctx context.Context, id uuid.UUID) error {
|
||||
return err
|
||||
}
|
||||
|
||||
const backoffChatDiffStatus = `-- name: BackoffChatDiffStatus :exec
|
||||
UPDATE
|
||||
chat_diff_statuses
|
||||
SET
|
||||
stale_at = $1::timestamptz,
|
||||
updated_at = NOW()
|
||||
WHERE
|
||||
chat_id = $2::uuid
|
||||
`
|
||||
|
||||
type BackoffChatDiffStatusParams struct {
|
||||
StaleAt time.Time `db:"stale_at" json:"stale_at"`
|
||||
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) BackoffChatDiffStatus(ctx context.Context, arg BackoffChatDiffStatusParams) error {
|
||||
_, err := q.db.ExecContext(ctx, backoffChatDiffStatus, arg.StaleAt, arg.ChatID)
|
||||
return err
|
||||
}
|
||||
|
||||
const deleteAllChatQueuedMessages = `-- name: DeleteAllChatQueuedMessages :exec
|
||||
DELETE FROM chat_queued_messages WHERE chat_id = $1
|
||||
`
|
||||
|
||||
@@ -410,3 +410,52 @@ RETURNING *;
|
||||
|
||||
-- name: GetChatByIDForUpdate :one
|
||||
SELECT * FROM chats WHERE id = @id::uuid FOR UPDATE;
|
||||
|
||||
-- name: AcquireStaleChatDiffStatuses :many
|
||||
WITH acquired AS (
|
||||
UPDATE
|
||||
chat_diff_statuses
|
||||
SET
|
||||
-- Claim for 5 minutes. The worker sets the real stale_at
|
||||
-- after refresh. If the worker crashes, rows become eligible
|
||||
-- again after this interval.
|
||||
stale_at = NOW() + INTERVAL '5 minutes',
|
||||
updated_at = NOW()
|
||||
WHERE
|
||||
chat_id IN (
|
||||
SELECT
|
||||
cds.chat_id
|
||||
FROM
|
||||
chat_diff_statuses cds
|
||||
INNER JOIN
|
||||
chats c ON c.id = cds.chat_id
|
||||
WHERE
|
||||
cds.stale_at <= NOW()
|
||||
AND cds.git_remote_origin != ''
|
||||
AND cds.git_branch != ''
|
||||
AND c.archived = FALSE
|
||||
ORDER BY
|
||||
cds.stale_at ASC
|
||||
FOR UPDATE OF cds
|
||||
SKIP LOCKED
|
||||
LIMIT
|
||||
@limit_val::int
|
||||
)
|
||||
RETURNING *
|
||||
)
|
||||
SELECT
|
||||
acquired.*,
|
||||
c.owner_id
|
||||
FROM
|
||||
acquired
|
||||
INNER JOIN
|
||||
chats c ON c.id = acquired.chat_id;
|
||||
|
||||
-- name: BackoffChatDiffStatus :exec
|
||||
UPDATE
|
||||
chat_diff_statuses
|
||||
SET
|
||||
stale_at = @stale_at::timestamptz,
|
||||
updated_at = NOW()
|
||||
WHERE
|
||||
chat_id = @chat_id::uuid;
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/externalauth/gitprovider"
|
||||
"github.com/coder/coder/v2/coderd/promoauth"
|
||||
"github.com/coder/coder/v2/coderd/util/slice"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
@@ -82,6 +83,10 @@ type Config struct {
|
||||
// a Git clone. e.g. "Username for 'https://github.com':"
|
||||
// The regex would be `github\.com`..
|
||||
Regex *regexp.Regexp
|
||||
// APIBaseURL is the base URL for provider REST API calls
|
||||
// (e.g., "https://api.github.com" for GitHub). Derived from
|
||||
// defaults when not explicitly configured.
|
||||
APIBaseURL string
|
||||
// AppInstallURL is for GitHub App's (and hopefully others eventually)
|
||||
// to provide a link to install the app. There's installation
|
||||
// of the application, and user authentication. It's possible
|
||||
@@ -106,12 +111,23 @@ type Config struct {
|
||||
CodeChallengeMethodsSupported []promoauth.Oauth2PKCEChallengeMethod
|
||||
}
|
||||
|
||||
// Git returns a Provider for this config if the provider type
|
||||
// is a supported git hosting provider. Returns nil for non-git
|
||||
// providers (e.g. Slack, JFrog).
|
||||
func (c *Config) Git(client *http.Client) gitprovider.Provider {
|
||||
norm := strings.ToLower(c.Type)
|
||||
if !codersdk.EnhancedExternalAuthProvider(norm).Git() {
|
||||
return nil
|
||||
}
|
||||
return gitprovider.New(norm, c.APIBaseURL, client)
|
||||
}
|
||||
|
||||
// GenerateTokenExtra generates the extra token data to store in the database.
|
||||
func (c *Config) GenerateTokenExtra(token *oauth2.Token) (pqtype.NullRawMessage, error) {
|
||||
if len(c.ExtraTokenKeys) == 0 {
|
||||
return pqtype.NullRawMessage{}, nil
|
||||
}
|
||||
extraMap := map[string]interface{}{}
|
||||
extraMap := map[string]any{}
|
||||
for _, key := range c.ExtraTokenKeys {
|
||||
extraMap[key] = token.Extra(key)
|
||||
}
|
||||
@@ -729,6 +745,7 @@ func ConvertConfig(instrument *promoauth.Factory, entries []codersdk.ExternalAut
|
||||
ClientID: entry.ClientID,
|
||||
ClientSecret: entry.ClientSecret,
|
||||
Regex: regex,
|
||||
APIBaseURL: entry.APIBaseURL,
|
||||
Type: entry.Type,
|
||||
NoRefresh: entry.NoRefresh,
|
||||
ValidateURL: entry.ValidateURL,
|
||||
@@ -765,7 +782,7 @@ func ConvertConfig(instrument *promoauth.Factory, entries []codersdk.ExternalAut
|
||||
|
||||
// applyDefaultsToConfig applies defaults to the config entry.
|
||||
func applyDefaultsToConfig(config *codersdk.ExternalAuthConfig) {
|
||||
configType := codersdk.EnhancedExternalAuthProvider(config.Type)
|
||||
configType := codersdk.EnhancedExternalAuthProvider(strings.ToLower(config.Type))
|
||||
if configType == "bitbucket" {
|
||||
// For backwards compatibility, we need to support the "bitbucket" string.
|
||||
configType = codersdk.EnhancedExternalAuthProviderBitBucketCloud
|
||||
@@ -782,7 +799,7 @@ func applyDefaultsToConfig(config *codersdk.ExternalAuthConfig) {
|
||||
}
|
||||
|
||||
// Dynamic defaults
|
||||
switch codersdk.EnhancedExternalAuthProvider(config.Type) {
|
||||
switch configType {
|
||||
case codersdk.EnhancedExternalAuthProviderGitHub:
|
||||
copyDefaultSettings(config, gitHubDefaults(config))
|
||||
return
|
||||
@@ -863,6 +880,18 @@ func copyDefaultSettings(config *codersdk.ExternalAuthConfig, defaults codersdk.
|
||||
if config.CodeChallengeMethodsSupported == nil {
|
||||
config.CodeChallengeMethodsSupported = []string{string(promoauth.PKCEChallengeMethodSha256)}
|
||||
}
|
||||
|
||||
// Set default API base URL for providers that need one.
|
||||
if config.APIBaseURL == "" {
|
||||
switch codersdk.EnhancedExternalAuthProvider(config.Type) {
|
||||
case codersdk.EnhancedExternalAuthProviderGitHub:
|
||||
config.APIBaseURL = "https://api.github.com"
|
||||
case codersdk.EnhancedExternalAuthProviderGitLab:
|
||||
config.APIBaseURL = "https://gitlab.com/api/v4"
|
||||
case codersdk.EnhancedExternalAuthProviderGitea:
|
||||
config.APIBaseURL = "https://gitea.com/api/v1"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// gitHubDefaults returns default config values for GitHub.
|
||||
|
||||
@@ -25,6 +25,7 @@ func TestGitlabDefaults(t *testing.T) {
|
||||
DisplayName: "GitLab",
|
||||
DisplayIcon: "/icon/gitlab.svg",
|
||||
Regex: `^(https?://)?gitlab\.com(/.*)?$`,
|
||||
APIBaseURL: "https://gitlab.com/api/v4",
|
||||
Scopes: []string{"write_repository"},
|
||||
CodeChallengeMethodsSupported: []string{string(promoauth.PKCEChallengeMethodSha256)},
|
||||
}
|
||||
|
||||
@@ -844,6 +844,40 @@ func setupOauth2Test(t *testing.T, settings testConfig) (*oidctest.FakeIDP, *ext
|
||||
return fake, config, link
|
||||
}
|
||||
|
||||
func TestApplyDefaultsToConfig_CaseInsensitive(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
instrument := promoauth.NewFactory(prometheus.NewRegistry())
|
||||
accessURL, err := url.Parse("https://coder.example.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, tc := range []struct {
|
||||
Name string
|
||||
Type string
|
||||
}{
|
||||
{Name: "GitHub", Type: "GitHub"},
|
||||
{Name: "GITLAB", Type: "GITLAB"},
|
||||
{Name: "Gitea", Type: "Gitea"},
|
||||
} {
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
configs, err := externalauth.ConvertConfig(
|
||||
instrument,
|
||||
[]codersdk.ExternalAuthConfig{{
|
||||
Type: tc.Type,
|
||||
ClientID: "test-id",
|
||||
ClientSecret: "test-secret",
|
||||
}},
|
||||
accessURL,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, configs, 1)
|
||||
// Defaults should have been applied despite mixed-case Type.
|
||||
assert.NotEmpty(t, configs[0].AuthCodeURL("state"), "auth URL should be populated from defaults")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type roundTripper func(req *http.Request) (*http.Response, error)
|
||||
|
||||
func (r roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
|
||||
@@ -0,0 +1,536 @@
|
||||
package gitprovider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
const defaultGitHubAPIBaseURL = "https://api.github.com"
|
||||
|
||||
type githubProvider struct {
|
||||
apiBaseURL string
|
||||
webBaseURL string
|
||||
httpClient *http.Client
|
||||
clock quartz.Clock
|
||||
|
||||
// Compiled per-instance to support GitHub Enterprise hosts.
|
||||
pullRequestPathPattern *regexp.Regexp
|
||||
repositoryHTTPSPattern *regexp.Regexp
|
||||
repositorySSHPathPattern *regexp.Regexp
|
||||
}
|
||||
|
||||
func newGitHub(apiBaseURL string, httpClient *http.Client, clock quartz.Clock) *githubProvider {
|
||||
if apiBaseURL == "" {
|
||||
apiBaseURL = defaultGitHubAPIBaseURL
|
||||
}
|
||||
apiBaseURL = strings.TrimRight(apiBaseURL, "/")
|
||||
if httpClient == nil {
|
||||
httpClient = http.DefaultClient
|
||||
}
|
||||
|
||||
// Derive the web base URL from the API base URL.
|
||||
// github.com: api.github.com → github.com
|
||||
// GHE: ghes.corp.com/api/v3 → ghes.corp.com
|
||||
webBaseURL := deriveWebBaseURL(apiBaseURL)
|
||||
|
||||
// Parse the host for regex construction.
|
||||
host := extractHost(webBaseURL)
|
||||
|
||||
// Escape the host for use in regex patterns.
|
||||
escapedHost := regexp.QuoteMeta(host)
|
||||
|
||||
return &githubProvider{
|
||||
apiBaseURL: apiBaseURL,
|
||||
webBaseURL: webBaseURL,
|
||||
httpClient: httpClient,
|
||||
clock: clock,
|
||||
pullRequestPathPattern: regexp.MustCompile(
|
||||
`^https://` + escapedHost + `/([A-Za-z0-9_.-]+)/([A-Za-z0-9_.-]+)/pull/([0-9]+)(?:[/?#].*)?$`,
|
||||
),
|
||||
repositoryHTTPSPattern: regexp.MustCompile(
|
||||
`^https://` + escapedHost + `/([A-Za-z0-9_.-]+)/([A-Za-z0-9_.-]+?)(?:\.git)?/?$`,
|
||||
),
|
||||
repositorySSHPathPattern: regexp.MustCompile(
|
||||
`^(?:ssh://)?git@` + escapedHost + `[:/]([A-Za-z0-9_.-]+)/([A-Za-z0-9_.-]+?)(?:\.git)?/?$`,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
// deriveWebBaseURL converts a GitHub API base URL to the
|
||||
// corresponding web base URL.
|
||||
//
|
||||
// github.com: https://api.github.com → https://github.com
|
||||
// GHE: https://ghes.corp.com/api/v3 → https://ghes.corp.com
|
||||
func deriveWebBaseURL(apiBaseURL string) string {
|
||||
u, err := url.Parse(apiBaseURL)
|
||||
if err != nil {
|
||||
return "https://github.com"
|
||||
}
|
||||
|
||||
// Standard github.com: API host is api.github.com.
|
||||
if strings.EqualFold(u.Host, "api.github.com") {
|
||||
return "https://github.com"
|
||||
}
|
||||
|
||||
// GHE: strip /api/v3 path suffix.
|
||||
u.Path = strings.TrimSuffix(u.Path, "/api/v3")
|
||||
u.Path = strings.TrimSuffix(u.Path, "/")
|
||||
return u.String()
|
||||
}
|
||||
|
||||
// extractHost returns the host portion of a URL.
|
||||
func extractHost(rawURL string) string {
|
||||
u, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return "github.com"
|
||||
}
|
||||
return u.Host
|
||||
}
|
||||
|
||||
func (g *githubProvider) ParseRepositoryOrigin(raw string) (owner string, repo string, normalizedOrigin string, ok bool) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return "", "", "", false
|
||||
}
|
||||
|
||||
matches := g.repositoryHTTPSPattern.FindStringSubmatch(raw)
|
||||
if len(matches) != 3 {
|
||||
matches = g.repositorySSHPathPattern.FindStringSubmatch(raw)
|
||||
}
|
||||
if len(matches) != 3 {
|
||||
return "", "", "", false
|
||||
}
|
||||
|
||||
owner = strings.TrimSpace(matches[1])
|
||||
repo = strings.TrimSpace(matches[2])
|
||||
repo = strings.TrimSuffix(repo, ".git")
|
||||
if owner == "" || repo == "" {
|
||||
return "", "", "", false
|
||||
}
|
||||
|
||||
return owner, repo, fmt.Sprintf("%s/%s/%s", g.webBaseURL, url.PathEscape(owner), url.PathEscape(repo)), true
|
||||
}
|
||||
|
||||
func (g *githubProvider) ParsePullRequestURL(raw string) (PRRef, bool) {
|
||||
matches := g.pullRequestPathPattern.FindStringSubmatch(strings.TrimSpace(raw))
|
||||
if len(matches) != 4 {
|
||||
return PRRef{}, false
|
||||
}
|
||||
|
||||
number, err := strconv.Atoi(matches[3])
|
||||
if err != nil {
|
||||
return PRRef{}, false
|
||||
}
|
||||
|
||||
return PRRef{
|
||||
Owner: matches[1],
|
||||
Repo: matches[2],
|
||||
Number: number,
|
||||
}, true
|
||||
}
|
||||
|
||||
func (g *githubProvider) NormalizePullRequestURL(raw string) string {
|
||||
ref, ok := g.ParsePullRequestURL(strings.TrimRight(
|
||||
strings.TrimSpace(raw),
|
||||
"),.;",
|
||||
))
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%s/%s/%s/pull/%d", g.webBaseURL, url.PathEscape(ref.Owner), url.PathEscape(ref.Repo), ref.Number)
|
||||
}
|
||||
|
||||
// escapePathPreserveSlashes escapes each segment of a path
|
||||
// individually, preserving `/` separators. This is needed for
|
||||
// web URLs where GitHub expects literal slashes (e.g.
|
||||
// /tree/feat/new-thing).
|
||||
func escapePathPreserveSlashes(s string) string {
|
||||
segments := strings.Split(s, "/")
|
||||
for i, seg := range segments {
|
||||
segments[i] = url.PathEscape(seg)
|
||||
}
|
||||
return strings.Join(segments, "/")
|
||||
}
|
||||
|
||||
func (g *githubProvider) BuildBranchURL(owner string, repo string, branch string) string {
|
||||
owner = strings.TrimSpace(owner)
|
||||
repo = strings.TrimSpace(repo)
|
||||
branch = strings.TrimSpace(branch)
|
||||
if owner == "" || repo == "" || branch == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"%s/%s/%s/tree/%s",
|
||||
g.webBaseURL,
|
||||
url.PathEscape(owner),
|
||||
url.PathEscape(repo),
|
||||
escapePathPreserveSlashes(branch),
|
||||
)
|
||||
}
|
||||
|
||||
func (g *githubProvider) BuildRepositoryURL(owner string, repo string) string {
|
||||
owner = strings.TrimSpace(owner)
|
||||
repo = strings.TrimSpace(repo)
|
||||
if owner == "" || repo == "" {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%s/%s/%s", g.webBaseURL, url.PathEscape(owner), url.PathEscape(repo))
|
||||
}
|
||||
|
||||
func (g *githubProvider) BuildPullRequestURL(ref PRRef) string {
|
||||
if ref.Owner == "" || ref.Repo == "" || ref.Number <= 0 {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%s/%s/%s/pull/%d", g.webBaseURL, url.PathEscape(ref.Owner), url.PathEscape(ref.Repo), ref.Number)
|
||||
}
|
||||
|
||||
func (g *githubProvider) ResolveBranchPullRequest(
|
||||
ctx context.Context,
|
||||
token string,
|
||||
ref BranchRef,
|
||||
) (*PRRef, error) {
|
||||
if ref.Owner == "" || ref.Repo == "" || ref.Branch == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
query := url.Values{}
|
||||
query.Set("state", "open")
|
||||
query.Set("head", fmt.Sprintf("%s:%s", ref.Owner, ref.Branch))
|
||||
query.Set("sort", "updated")
|
||||
query.Set("direction", "desc")
|
||||
query.Set("per_page", "1")
|
||||
|
||||
requestURL := fmt.Sprintf(
|
||||
"%s/repos/%s/%s/pulls?%s",
|
||||
g.apiBaseURL,
|
||||
url.PathEscape(ref.Owner),
|
||||
url.PathEscape(ref.Repo),
|
||||
query.Encode(),
|
||||
)
|
||||
|
||||
var pulls []struct {
|
||||
HTMLURL string `json:"html_url"`
|
||||
Number int `json:"number"`
|
||||
}
|
||||
|
||||
if err := g.decodeJSON(ctx, requestURL, token, &pulls); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(pulls) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
prRef, ok := g.ParsePullRequestURL(pulls[0].HTMLURL)
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
return &prRef, nil
|
||||
}
|
||||
|
||||
func (g *githubProvider) FetchPullRequestStatus(
|
||||
ctx context.Context,
|
||||
token string,
|
||||
ref PRRef,
|
||||
) (*PRStatus, error) {
|
||||
pullEndpoint := fmt.Sprintf(
|
||||
"%s/repos/%s/%s/pulls/%d",
|
||||
g.apiBaseURL,
|
||||
url.PathEscape(ref.Owner),
|
||||
url.PathEscape(ref.Repo),
|
||||
ref.Number,
|
||||
)
|
||||
|
||||
var pull struct {
|
||||
State string `json:"state"`
|
||||
Merged bool `json:"merged"`
|
||||
Draft bool `json:"draft"`
|
||||
Additions int32 `json:"additions"`
|
||||
Deletions int32 `json:"deletions"`
|
||||
ChangedFiles int32 `json:"changed_files"`
|
||||
Head struct {
|
||||
SHA string `json:"sha"`
|
||||
} `json:"head"`
|
||||
}
|
||||
if err := g.decodeJSON(ctx, pullEndpoint, token, &pull); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var reviews []struct {
|
||||
ID int64 `json:"id"`
|
||||
State string `json:"state"`
|
||||
User struct {
|
||||
Login string `json:"login"`
|
||||
} `json:"user"`
|
||||
}
|
||||
// GitHub returns at most 100 reviews per page. We do not
|
||||
// paginate because PRs with >100 reviews are extremely rare,
|
||||
// and the cost of multiple API calls per refresh is not
|
||||
// justified. If needed, pagination can be added later.
|
||||
if err := g.decodeJSON(
|
||||
ctx,
|
||||
pullEndpoint+"/reviews?per_page=100",
|
||||
token,
|
||||
&reviews,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
state := PRState(strings.ToLower(strings.TrimSpace(pull.State)))
|
||||
if pull.Merged {
|
||||
state = PRStateMerged
|
||||
}
|
||||
|
||||
return &PRStatus{
|
||||
State: state,
|
||||
Draft: pull.Draft,
|
||||
HeadSHA: pull.Head.SHA,
|
||||
DiffStats: DiffStats{
|
||||
Additions: pull.Additions,
|
||||
Deletions: pull.Deletions,
|
||||
ChangedFiles: pull.ChangedFiles,
|
||||
},
|
||||
ChangesRequested: hasOutstandingChangesRequested(reviews),
|
||||
FetchedAt: g.clock.Now().UTC(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (g *githubProvider) FetchPullRequestDiff(
|
||||
ctx context.Context,
|
||||
token string,
|
||||
ref PRRef,
|
||||
) (string, error) {
|
||||
requestURL := fmt.Sprintf(
|
||||
"%s/repos/%s/%s/pulls/%d",
|
||||
g.apiBaseURL,
|
||||
url.PathEscape(ref.Owner),
|
||||
url.PathEscape(ref.Repo),
|
||||
ref.Number,
|
||||
)
|
||||
return g.fetchDiff(ctx, requestURL, token)
|
||||
}
|
||||
|
||||
func (g *githubProvider) FetchBranchDiff(
|
||||
ctx context.Context,
|
||||
token string,
|
||||
ref BranchRef,
|
||||
) (string, error) {
|
||||
if ref.Owner == "" || ref.Repo == "" || ref.Branch == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
var repository struct {
|
||||
DefaultBranch string `json:"default_branch"`
|
||||
}
|
||||
|
||||
repositoryURL := fmt.Sprintf(
|
||||
"%s/repos/%s/%s",
|
||||
g.apiBaseURL,
|
||||
url.PathEscape(ref.Owner),
|
||||
url.PathEscape(ref.Repo),
|
||||
)
|
||||
if err := g.decodeJSON(ctx, repositoryURL, token, &repository); err != nil {
|
||||
return "", err
|
||||
}
|
||||
defaultBranch := strings.TrimSpace(repository.DefaultBranch)
|
||||
if defaultBranch == "" {
|
||||
return "", xerrors.New("github repository default branch is empty")
|
||||
}
|
||||
|
||||
requestURL := fmt.Sprintf(
|
||||
"%s/repos/%s/%s/compare/%s...%s",
|
||||
g.apiBaseURL,
|
||||
url.PathEscape(ref.Owner),
|
||||
url.PathEscape(ref.Repo),
|
||||
url.PathEscape(defaultBranch),
|
||||
url.PathEscape(ref.Branch),
|
||||
)
|
||||
|
||||
return g.fetchDiff(ctx, requestURL, token)
|
||||
}
|
||||
|
||||
func (g *githubProvider) decodeJSON(
|
||||
ctx context.Context,
|
||||
requestURL string,
|
||||
token string,
|
||||
dest any,
|
||||
) error {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create github request: %w", err)
|
||||
}
|
||||
req.Header.Set("Accept", "application/vnd.github+json")
|
||||
req.Header.Set("X-GitHub-Api-Version", "2022-11-28")
|
||||
req.Header.Set("User-Agent", "coder-chat-diff-status")
|
||||
if token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
|
||||
resp, err := g.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("execute github request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if resp.StatusCode == http.StatusForbidden || resp.StatusCode == http.StatusTooManyRequests {
|
||||
retryAfter := ParseRetryAfter(resp.Header, g.clock)
|
||||
if retryAfter > 0 {
|
||||
return &RateLimitError{RetryAfter: g.clock.Now().Add(retryAfter)}
|
||||
}
|
||||
// No rate-limit headers — fall through to generic error.
|
||||
}
|
||||
body, readErr := io.ReadAll(io.LimitReader(resp.Body, 8192))
|
||||
if readErr != nil {
|
||||
return xerrors.Errorf(
|
||||
"github request failed with status %d",
|
||||
resp.StatusCode,
|
||||
)
|
||||
}
|
||||
return xerrors.Errorf(
|
||||
"github request failed with status %d: %s",
|
||||
resp.StatusCode,
|
||||
strings.TrimSpace(string(body)),
|
||||
)
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(dest); err != nil {
|
||||
return xerrors.Errorf("decode github response: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *githubProvider) fetchDiff(
|
||||
ctx context.Context,
|
||||
requestURL string,
|
||||
token string,
|
||||
) (string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("create github diff request: %w", err)
|
||||
}
|
||||
req.Header.Set("Accept", "application/vnd.github.diff")
|
||||
req.Header.Set("X-GitHub-Api-Version", "2022-11-28")
|
||||
req.Header.Set("User-Agent", "coder-chat-diff")
|
||||
if token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
|
||||
resp, err := g.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("execute github diff request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if resp.StatusCode == http.StatusForbidden || resp.StatusCode == http.StatusTooManyRequests {
|
||||
retryAfter := ParseRetryAfter(resp.Header, g.clock)
|
||||
if retryAfter > 0 {
|
||||
return "", &RateLimitError{RetryAfter: g.clock.Now().Add(retryAfter)}
|
||||
}
|
||||
}
|
||||
body, readErr := io.ReadAll(io.LimitReader(resp.Body, 8192))
|
||||
if readErr != nil {
|
||||
return "", xerrors.Errorf("github diff request failed with status %d", resp.StatusCode)
|
||||
}
|
||||
return "", xerrors.Errorf(
|
||||
"github diff request failed with status %d: %s",
|
||||
resp.StatusCode,
|
||||
strings.TrimSpace(string(body)),
|
||||
)
|
||||
}
|
||||
|
||||
// Read one extra byte beyond MaxDiffSize so we can detect
|
||||
// whether the diff exceeds the limit. LimitReader stops us
|
||||
// allocating an arbitrarily large buffer by accident.
|
||||
buf, err := io.ReadAll(io.LimitReader(resp.Body, MaxDiffSize+1))
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("read github diff response: %w", err)
|
||||
}
|
||||
if len(buf) > MaxDiffSize {
|
||||
return "", ErrDiffTooLarge
|
||||
}
|
||||
return string(buf), nil
|
||||
}
|
||||
|
||||
// ParseRetryAfter extracts a retry-after time from GitHub
|
||||
// rate-limit headers. Returns zero value if no recognizable header is
|
||||
// present.
|
||||
func ParseRetryAfter(h http.Header, clk quartz.Clock) time.Duration {
|
||||
if clk == nil {
|
||||
clk = quartz.NewReal()
|
||||
}
|
||||
// Retry-After header: seconds until retry.
|
||||
if ra := h.Get("Retry-After"); ra != "" {
|
||||
if secs, err := strconv.Atoi(ra); err == nil {
|
||||
return time.Duration(secs) * time.Second
|
||||
}
|
||||
}
|
||||
// X-Ratelimit-Reset header: unix timestamp. We compute the
|
||||
// duration from now according to the caller's clock.
|
||||
if reset := h.Get("X-Ratelimit-Reset"); reset != "" {
|
||||
if ts, err := strconv.ParseInt(reset, 10, 64); err == nil {
|
||||
d := time.Unix(ts, 0).Sub(clk.Now())
|
||||
return d
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func hasOutstandingChangesRequested(
|
||||
reviews []struct {
|
||||
ID int64 `json:"id"`
|
||||
State string `json:"state"`
|
||||
User struct {
|
||||
Login string `json:"login"`
|
||||
} `json:"user"`
|
||||
},
|
||||
) bool {
|
||||
type reviewerState struct {
|
||||
reviewID int64
|
||||
state string
|
||||
}
|
||||
|
||||
statesByReviewer := make(map[string]reviewerState)
|
||||
for _, review := range reviews {
|
||||
login := strings.ToLower(strings.TrimSpace(review.User.Login))
|
||||
if login == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
state := strings.ToUpper(strings.TrimSpace(review.State))
|
||||
switch state {
|
||||
case "CHANGES_REQUESTED", "APPROVED", "DISMISSED":
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
current, exists := statesByReviewer[login]
|
||||
if exists && current.reviewID > review.ID {
|
||||
continue
|
||||
}
|
||||
statesByReviewer[login] = reviewerState{
|
||||
reviewID: review.ID,
|
||||
state: state,
|
||||
}
|
||||
}
|
||||
|
||||
for _, state := range statesByReviewer {
|
||||
if state.state == "CHANGES_REQUESTED" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,994 @@
|
||||
package gitprovider_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/externalauth/gitprovider"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
func TestGitHubParseRepositoryOrigin(t *testing.T) {
|
||||
t.Parallel()
|
||||
gp := gitprovider.New("github", "", nil)
|
||||
require.NotNil(t, gp)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
raw string
|
||||
expectOK bool
|
||||
expectOwner string
|
||||
expectRepo string
|
||||
expectNormalized string
|
||||
}{
|
||||
{
|
||||
name: "HTTPS URL",
|
||||
raw: "https://github.com/coder/coder",
|
||||
expectOK: true,
|
||||
expectOwner: "coder",
|
||||
expectRepo: "coder",
|
||||
expectNormalized: "https://github.com/coder/coder",
|
||||
},
|
||||
{
|
||||
name: "HTTPS URL with .git",
|
||||
raw: "https://github.com/coder/coder.git",
|
||||
expectOK: true,
|
||||
expectOwner: "coder",
|
||||
expectRepo: "coder",
|
||||
expectNormalized: "https://github.com/coder/coder",
|
||||
},
|
||||
{
|
||||
name: "HTTPS URL with trailing slash",
|
||||
raw: "https://github.com/coder/coder/",
|
||||
expectOK: true,
|
||||
expectOwner: "coder",
|
||||
expectRepo: "coder",
|
||||
expectNormalized: "https://github.com/coder/coder",
|
||||
},
|
||||
{
|
||||
name: "SSH URL",
|
||||
raw: "git@github.com:coder/coder.git",
|
||||
expectOK: true,
|
||||
expectOwner: "coder",
|
||||
expectRepo: "coder",
|
||||
expectNormalized: "https://github.com/coder/coder",
|
||||
},
|
||||
{
|
||||
name: "SSH URL without .git",
|
||||
raw: "git@github.com:coder/coder",
|
||||
expectOK: true,
|
||||
expectOwner: "coder",
|
||||
expectRepo: "coder",
|
||||
expectNormalized: "https://github.com/coder/coder",
|
||||
},
|
||||
{
|
||||
name: "SSH URL with ssh:// prefix",
|
||||
raw: "ssh://git@github.com/coder/coder.git",
|
||||
expectOK: true,
|
||||
expectOwner: "coder",
|
||||
expectRepo: "coder",
|
||||
expectNormalized: "https://github.com/coder/coder",
|
||||
},
|
||||
{
|
||||
name: "GitLab URL does not match",
|
||||
raw: "https://gitlab.com/coder/coder",
|
||||
expectOK: false,
|
||||
},
|
||||
{
|
||||
name: "Empty string",
|
||||
raw: "",
|
||||
expectOK: false,
|
||||
},
|
||||
{
|
||||
name: "Not a URL",
|
||||
raw: "not-a-url",
|
||||
expectOK: false,
|
||||
},
|
||||
{
|
||||
name: "Hyphenated owner and repo",
|
||||
raw: "https://github.com/my-org/my-repo.git",
|
||||
expectOK: true,
|
||||
expectOwner: "my-org",
|
||||
expectRepo: "my-repo",
|
||||
expectNormalized: "https://github.com/my-org/my-repo",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
owner, repo, normalized, ok := gp.ParseRepositoryOrigin(tt.raw)
|
||||
assert.Equal(t, tt.expectOK, ok)
|
||||
if tt.expectOK {
|
||||
assert.Equal(t, tt.expectOwner, owner)
|
||||
assert.Equal(t, tt.expectRepo, repo)
|
||||
assert.Equal(t, tt.expectNormalized, normalized)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitHubParsePullRequestURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
gp := gitprovider.New("github", "", nil)
|
||||
require.NotNil(t, gp)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
raw string
|
||||
expectOK bool
|
||||
expectOwner string
|
||||
expectRepo string
|
||||
expectNumber int
|
||||
}{
|
||||
{
|
||||
name: "Standard PR URL",
|
||||
raw: "https://github.com/coder/coder/pull/123",
|
||||
expectOK: true,
|
||||
expectOwner: "coder",
|
||||
expectRepo: "coder",
|
||||
expectNumber: 123,
|
||||
},
|
||||
{
|
||||
name: "PR URL with query string",
|
||||
raw: "https://github.com/coder/coder/pull/456?diff=split",
|
||||
expectOK: true,
|
||||
expectOwner: "coder",
|
||||
expectRepo: "coder",
|
||||
expectNumber: 456,
|
||||
},
|
||||
{
|
||||
name: "PR URL with fragment",
|
||||
raw: "https://github.com/coder/coder/pull/789#discussion",
|
||||
expectOK: true,
|
||||
expectOwner: "coder",
|
||||
expectRepo: "coder",
|
||||
expectNumber: 789,
|
||||
},
|
||||
{
|
||||
name: "Not a PR URL",
|
||||
raw: "https://github.com/coder/coder",
|
||||
expectOK: false,
|
||||
},
|
||||
{
|
||||
name: "Issue URL (not PR)",
|
||||
raw: "https://github.com/coder/coder/issues/123",
|
||||
expectOK: false,
|
||||
},
|
||||
{
|
||||
name: "GitLab MR URL",
|
||||
raw: "https://gitlab.com/coder/coder/-/merge_requests/123",
|
||||
expectOK: false,
|
||||
},
|
||||
{
|
||||
name: "Empty string",
|
||||
raw: "",
|
||||
expectOK: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ref, ok := gp.ParsePullRequestURL(tt.raw)
|
||||
assert.Equal(t, tt.expectOK, ok)
|
||||
if tt.expectOK {
|
||||
assert.Equal(t, tt.expectOwner, ref.Owner)
|
||||
assert.Equal(t, tt.expectRepo, ref.Repo)
|
||||
assert.Equal(t, tt.expectNumber, ref.Number)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitHubNormalizePullRequestURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
gp := gitprovider.New("github", "", nil)
|
||||
require.NotNil(t, gp)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
raw string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Already normalized",
|
||||
raw: "https://github.com/coder/coder/pull/123",
|
||||
expected: "https://github.com/coder/coder/pull/123",
|
||||
},
|
||||
{
|
||||
name: "With trailing punctuation",
|
||||
raw: "https://github.com/coder/coder/pull/123).",
|
||||
expected: "https://github.com/coder/coder/pull/123",
|
||||
},
|
||||
{
|
||||
name: "With query string",
|
||||
raw: "https://github.com/coder/coder/pull/123?diff=split",
|
||||
expected: "https://github.com/coder/coder/pull/123",
|
||||
},
|
||||
{
|
||||
name: "With whitespace",
|
||||
raw: " https://github.com/coder/coder/pull/123 ",
|
||||
expected: "https://github.com/coder/coder/pull/123",
|
||||
},
|
||||
{
|
||||
name: "Not a PR URL",
|
||||
raw: "https://example.com",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Empty string",
|
||||
raw: "",
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result := gp.NormalizePullRequestURL(tt.raw)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitHubBuildBranchURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
gp := gitprovider.New("github", "", nil)
|
||||
require.NotNil(t, gp)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
owner string
|
||||
repo string
|
||||
branch string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Simple branch",
|
||||
owner: "coder",
|
||||
repo: "coder",
|
||||
branch: "main",
|
||||
expected: "https://github.com/coder/coder/tree/main",
|
||||
},
|
||||
{
|
||||
name: "Branch with slash",
|
||||
owner: "coder",
|
||||
repo: "coder",
|
||||
branch: "feat/new-thing",
|
||||
expected: "https://github.com/coder/coder/tree/feat/new-thing",
|
||||
},
|
||||
{
|
||||
name: "Empty owner",
|
||||
owner: "",
|
||||
repo: "coder",
|
||||
branch: "main",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Empty repo",
|
||||
owner: "coder",
|
||||
repo: "",
|
||||
branch: "main",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Empty branch",
|
||||
owner: "coder",
|
||||
repo: "coder",
|
||||
branch: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Branch with slashes",
|
||||
owner: "my-org",
|
||||
repo: "my-repo",
|
||||
branch: "feat/new-thing",
|
||||
expected: "https://github.com/my-org/my-repo/tree/feat/new-thing",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result := gp.BuildBranchURL(tt.owner, tt.repo, tt.branch)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitHubBuildPullRequestURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
gp := gitprovider.New("github", "", nil)
|
||||
require.NotNil(t, gp)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ref gitprovider.PRRef
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Valid PR ref",
|
||||
ref: gitprovider.PRRef{Owner: "coder", Repo: "coder", Number: 123},
|
||||
expected: "https://github.com/coder/coder/pull/123",
|
||||
},
|
||||
{
|
||||
name: "Empty owner",
|
||||
ref: gitprovider.PRRef{Owner: "", Repo: "coder", Number: 123},
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Empty repo",
|
||||
ref: gitprovider.PRRef{Owner: "coder", Repo: "", Number: 123},
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Zero number",
|
||||
ref: gitprovider.PRRef{Owner: "coder", Repo: "coder", Number: 0},
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Negative number",
|
||||
ref: gitprovider.PRRef{Owner: "coder", Repo: "coder", Number: -1},
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result := gp.BuildPullRequestURL(tt.ref)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitHubEnterpriseURLs(t *testing.T) {
|
||||
t.Parallel()
|
||||
gp := gitprovider.New("github", "https://ghes.corp.com/api/v3", nil)
|
||||
require.NotNil(t, gp)
|
||||
|
||||
t.Run("ParseRepositoryOrigin HTTPS", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
owner, repo, normalized, ok := gp.ParseRepositoryOrigin("https://ghes.corp.com/org/repo.git")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "org", owner)
|
||||
assert.Equal(t, "repo", repo)
|
||||
assert.Equal(t, "https://ghes.corp.com/org/repo", normalized)
|
||||
})
|
||||
|
||||
t.Run("ParseRepositoryOrigin SSH", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
owner, repo, normalized, ok := gp.ParseRepositoryOrigin("git@ghes.corp.com:org/repo.git")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "org", owner)
|
||||
assert.Equal(t, "repo", repo)
|
||||
assert.Equal(t, "https://ghes.corp.com/org/repo", normalized)
|
||||
})
|
||||
|
||||
t.Run("ParsePullRequestURL", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ref, ok := gp.ParsePullRequestURL("https://ghes.corp.com/org/repo/pull/42")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "org", ref.Owner)
|
||||
assert.Equal(t, "repo", ref.Repo)
|
||||
assert.Equal(t, 42, ref.Number)
|
||||
})
|
||||
|
||||
t.Run("NormalizePullRequestURL", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result := gp.NormalizePullRequestURL("https://ghes.corp.com/org/repo/pull/42?x=y")
|
||||
assert.Equal(t, "https://ghes.corp.com/org/repo/pull/42", result)
|
||||
})
|
||||
|
||||
t.Run("BuildBranchURL", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result := gp.BuildBranchURL("org", "repo", "main")
|
||||
assert.Equal(t, "https://ghes.corp.com/org/repo/tree/main", result)
|
||||
})
|
||||
|
||||
t.Run("BuildPullRequestURL", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result := gp.BuildPullRequestURL(gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 42})
|
||||
assert.Equal(t, "https://ghes.corp.com/org/repo/pull/42", result)
|
||||
})
|
||||
|
||||
t.Run("github.com URLs do not match GHE instance", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, _, _, ok := gp.ParseRepositoryOrigin("https://github.com/coder/coder")
|
||||
assert.False(t, ok, "github.com HTTPS URL should not match GHE instance")
|
||||
|
||||
_, _, _, ok = gp.ParseRepositoryOrigin("git@github.com:coder/coder.git")
|
||||
assert.False(t, ok, "github.com SSH URL should not match GHE instance")
|
||||
|
||||
_, ok = gp.ParsePullRequestURL("https://github.com/coder/coder/pull/123")
|
||||
assert.False(t, ok, "github.com PR URL should not match GHE instance")
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewUnsupportedProvider(t *testing.T) {
|
||||
t.Parallel()
|
||||
gp := gitprovider.New("unsupported", "", nil)
|
||||
assert.Nil(t, gp, "unsupported provider type should return nil")
|
||||
}
|
||||
|
||||
func TestGitHubRatelimit_403WithResetHeader(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
resetTime := time.Now().Add(60 * time.Second)
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("X-Ratelimit-Reset", fmt.Sprintf("%d", resetTime.Unix()))
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
_, _ = w.Write([]byte(`{"message": "API rate limit exceeded"}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client())
|
||||
require.NotNil(t, gp)
|
||||
|
||||
_, err := gp.FetchPullRequestStatus(
|
||||
context.Background(),
|
||||
"test-token",
|
||||
gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 1},
|
||||
)
|
||||
require.Error(t, err)
|
||||
|
||||
var rlErr *gitprovider.RateLimitError
|
||||
require.True(t, errors.As(err, &rlErr), "error should be *RateLimitError, got: %T", err)
|
||||
assert.WithinDuration(t, resetTime, rlErr.RetryAfter, 2*time.Second)
|
||||
}
|
||||
|
||||
func TestGitHubRatelimit_429WithRetryAfter(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Retry-After", "120")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
_, _ = w.Write([]byte(`{"message": "secondary rate limit"}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client())
|
||||
require.NotNil(t, gp)
|
||||
|
||||
_, err := gp.FetchPullRequestStatus(
|
||||
context.Background(),
|
||||
"test-token",
|
||||
gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 1},
|
||||
)
|
||||
require.Error(t, err)
|
||||
|
||||
var rlErr *gitprovider.RateLimitError
|
||||
require.True(t, errors.As(err, &rlErr), "error should be *RateLimitError, got: %T", err)
|
||||
|
||||
// Retry-After: 120 means ~120s from now.
|
||||
expected := time.Now().Add(120 * time.Second)
|
||||
assert.WithinDuration(t, expected, rlErr.RetryAfter, 5*time.Second)
|
||||
}
|
||||
|
||||
func TestGitHubRatelimit_403NormalError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
_, _ = w.Write([]byte(`{"message": "Bad credentials"}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client())
|
||||
require.NotNil(t, gp)
|
||||
|
||||
_, err := gp.FetchPullRequestStatus(
|
||||
context.Background(),
|
||||
"bad-token",
|
||||
gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 1},
|
||||
)
|
||||
require.Error(t, err)
|
||||
|
||||
var rlErr *gitprovider.RateLimitError
|
||||
assert.False(t, errors.As(err, &rlErr), "error should NOT be *RateLimitError")
|
||||
assert.Contains(t, err.Error(), "403")
|
||||
}
|
||||
|
||||
func TestGitHubFetchPullRequestDiff(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const smallDiff = "diff --git a/file.go b/file.go\n--- a/file.go\n+++ b/file.go\n@@ -1 +1 @@\n-old\n+new\n"
|
||||
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
_, _ = w.Write([]byte(smallDiff))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client())
|
||||
require.NotNil(t, gp)
|
||||
|
||||
diff, err := gp.FetchPullRequestDiff(
|
||||
context.Background(),
|
||||
"test-token",
|
||||
gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 1},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, smallDiff, diff)
|
||||
})
|
||||
|
||||
t.Run("ExactlyMaxSize", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
exactDiff := string(make([]byte, gitprovider.MaxDiffSize))
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
_, _ = w.Write([]byte(exactDiff))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client())
|
||||
require.NotNil(t, gp)
|
||||
|
||||
diff, err := gp.FetchPullRequestDiff(
|
||||
context.Background(),
|
||||
"test-token",
|
||||
gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 1},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, diff, gitprovider.MaxDiffSize)
|
||||
})
|
||||
|
||||
t.Run("TooLarge", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
oversizeDiff := string(make([]byte, gitprovider.MaxDiffSize+1024))
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
_, _ = w.Write([]byte(oversizeDiff))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client())
|
||||
require.NotNil(t, gp)
|
||||
|
||||
_, err := gp.FetchPullRequestDiff(
|
||||
context.Background(),
|
||||
"test-token",
|
||||
gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 1},
|
||||
)
|
||||
assert.ErrorIs(t, err, gitprovider.ErrDiffTooLarge)
|
||||
})
|
||||
}
|
||||
|
||||
func TestFetchPullRequestDiff_Ratelimit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Retry-After", "60")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
_, _ = w.Write([]byte(`{"message": "rate limit"}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client())
|
||||
require.NotNil(t, gp)
|
||||
|
||||
_, err := gp.FetchPullRequestDiff(
|
||||
context.Background(),
|
||||
"test-token",
|
||||
gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 1},
|
||||
)
|
||||
require.Error(t, err)
|
||||
|
||||
var rlErr *gitprovider.RateLimitError
|
||||
require.True(t, errors.As(err, &rlErr), "error should be *RateLimitError, got: %T", err)
|
||||
expected := time.Now().Add(60 * time.Second)
|
||||
assert.WithinDuration(t, expected, rlErr.RetryAfter, 5*time.Second)
|
||||
}
|
||||
|
||||
func TestFetchBranchDiff_Ratelimit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/compare/") {
|
||||
// Second request: compare endpoint returns 429.
|
||||
w.Header().Set("Retry-After", "60")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
_, _ = w.Write([]byte(`{"message": "rate limit"}`))
|
||||
return
|
||||
}
|
||||
// First request: repo metadata.
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"default_branch":"main"}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client())
|
||||
require.NotNil(t, gp)
|
||||
|
||||
_, err := gp.FetchBranchDiff(
|
||||
context.Background(),
|
||||
"test-token",
|
||||
gitprovider.BranchRef{Owner: "org", Repo: "repo", Branch: "feat"},
|
||||
)
|
||||
require.Error(t, err)
|
||||
|
||||
var rlErr *gitprovider.RateLimitError
|
||||
require.True(t, errors.As(err, &rlErr), "error should be *RateLimitError, got: %T", err)
|
||||
expected := time.Now().Add(60 * time.Second)
|
||||
assert.WithinDuration(t, expected, rlErr.RetryAfter, 5*time.Second)
|
||||
}
|
||||
|
||||
func TestFetchPullRequestStatus(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type review struct {
|
||||
ID int64 `json:"id"`
|
||||
State string `json:"state"`
|
||||
User struct {
|
||||
Login string `json:"login"`
|
||||
} `json:"user"`
|
||||
}
|
||||
|
||||
makeReview := func(id int64, state, login string) review {
|
||||
r := review{ID: id, State: state}
|
||||
r.User.Login = login
|
||||
return r
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
pullJSON string
|
||||
reviews []review
|
||||
expectedState gitprovider.PRState
|
||||
expectedDraft bool
|
||||
changesRequested bool
|
||||
}{
|
||||
{
|
||||
name: "OpenPR/NoReviews",
|
||||
pullJSON: `{"state":"open","merged":false,"draft":false,"additions":10,"deletions":5,"changed_files":3,"head":{"sha":"abc123"}}`,
|
||||
reviews: []review{},
|
||||
expectedState: gitprovider.PRStateOpen,
|
||||
expectedDraft: false,
|
||||
changesRequested: false,
|
||||
},
|
||||
{
|
||||
name: "OpenPR/SingleChangesRequested",
|
||||
pullJSON: `{"state":"open","merged":false,"draft":false,"additions":10,"deletions":5,"changed_files":3,"head":{"sha":"abc123"}}`,
|
||||
reviews: []review{makeReview(1, "CHANGES_REQUESTED", "alice")},
|
||||
expectedState: gitprovider.PRStateOpen,
|
||||
changesRequested: true,
|
||||
},
|
||||
{
|
||||
name: "OpenPR/ChangesRequestedThenApproved",
|
||||
pullJSON: `{"state":"open","merged":false,"draft":false,"additions":10,"deletions":5,"changed_files":3,"head":{"sha":"abc123"}}`,
|
||||
reviews: []review{
|
||||
makeReview(1, "CHANGES_REQUESTED", "alice"),
|
||||
makeReview(2, "APPROVED", "alice"),
|
||||
},
|
||||
expectedState: gitprovider.PRStateOpen,
|
||||
changesRequested: false,
|
||||
},
|
||||
{
|
||||
name: "OpenPR/ChangesRequestedThenDismissed",
|
||||
pullJSON: `{"state":"open","merged":false,"draft":false,"additions":10,"deletions":5,"changed_files":3,"head":{"sha":"abc123"}}`,
|
||||
reviews: []review{
|
||||
makeReview(1, "CHANGES_REQUESTED", "alice"),
|
||||
makeReview(2, "DISMISSED", "alice"),
|
||||
},
|
||||
expectedState: gitprovider.PRStateOpen,
|
||||
changesRequested: false,
|
||||
},
|
||||
{
|
||||
name: "OpenPR/MultipleReviewersMixed",
|
||||
pullJSON: `{"state":"open","merged":false,"draft":false,"additions":10,"deletions":5,"changed_files":3,"head":{"sha":"abc123"}}`,
|
||||
reviews: []review{
|
||||
makeReview(1, "APPROVED", "alice"),
|
||||
makeReview(2, "CHANGES_REQUESTED", "bob"),
|
||||
},
|
||||
expectedState: gitprovider.PRStateOpen,
|
||||
changesRequested: true,
|
||||
},
|
||||
{
|
||||
name: "OpenPR/CommentedDoesNotAffect",
|
||||
pullJSON: `{"state":"open","merged":false,"draft":false,"additions":10,"deletions":5,"changed_files":3,"head":{"sha":"abc123"}}`,
|
||||
reviews: []review{
|
||||
makeReview(1, "COMMENTED", "alice"),
|
||||
},
|
||||
expectedState: gitprovider.PRStateOpen,
|
||||
changesRequested: false,
|
||||
},
|
||||
{
|
||||
name: "MergedPR",
|
||||
pullJSON: `{"state":"closed","merged":true,"draft":false,"additions":10,"deletions":5,"changed_files":3,"head":{"sha":"abc123"}}`,
|
||||
reviews: []review{},
|
||||
expectedState: gitprovider.PRStateMerged,
|
||||
changesRequested: false,
|
||||
},
|
||||
{
|
||||
name: "DraftPR",
|
||||
pullJSON: `{"state":"open","merged":false,"draft":true,"additions":10,"deletions":5,"changed_files":3,"head":{"sha":"abc123"}}`,
|
||||
reviews: []review{},
|
||||
expectedState: gitprovider.PRStateOpen,
|
||||
expectedDraft: true,
|
||||
changesRequested: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
reviewsJSON, err := json.Marshal(tc.reviews)
|
||||
require.NoError(t, err)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/api/v3/repos/owner/repo/pulls/1/reviews", func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write(reviewsJSON)
|
||||
})
|
||||
mux.HandleFunc("/api/v3/repos/owner/repo/pulls/1", func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(tc.pullJSON))
|
||||
})
|
||||
|
||||
srv := httptest.NewServer(mux)
|
||||
defer srv.Close()
|
||||
|
||||
gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client())
|
||||
require.NotNil(t, gp)
|
||||
|
||||
before := time.Now().UTC()
|
||||
status, err := gp.FetchPullRequestStatus(
|
||||
context.Background(),
|
||||
"test-token",
|
||||
gitprovider.PRRef{Owner: "owner", Repo: "repo", Number: 1},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.expectedState, status.State)
|
||||
assert.Equal(t, tc.expectedDraft, status.Draft)
|
||||
assert.Equal(t, tc.changesRequested, status.ChangesRequested)
|
||||
assert.Equal(t, "abc123", status.HeadSHA)
|
||||
assert.Equal(t, int32(10), status.DiffStats.Additions)
|
||||
assert.Equal(t, int32(5), status.DiffStats.Deletions)
|
||||
assert.Equal(t, int32(3), status.DiffStats.ChangedFiles)
|
||||
assert.False(t, status.FetchedAt.IsZero())
|
||||
assert.True(t, !status.FetchedAt.Before(before), "FetchedAt should be >= test start time")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveBranchPullRequest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("Found", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var srvURL string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Verify query parameters.
|
||||
assert.Equal(t, "open", r.URL.Query().Get("state"))
|
||||
assert.Equal(t, "owner:feat", r.URL.Query().Get("head"))
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
// Use the test server's URL so ParsePullRequestURL
|
||||
// matches the provider's derived web host.
|
||||
htmlURL := fmt.Sprintf("https://%s/owner/repo/pull/42",
|
||||
strings.TrimPrefix(strings.TrimPrefix(srvURL, "http://"), "https://"))
|
||||
_, _ = w.Write([]byte(fmt.Sprintf(`[{"html_url":%q,"number":42}]`, htmlURL)))
|
||||
}))
|
||||
defer srv.Close()
|
||||
srvURL = srv.URL
|
||||
|
||||
gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client())
|
||||
require.NotNil(t, gp)
|
||||
|
||||
prRef, err := gp.ResolveBranchPullRequest(
|
||||
context.Background(),
|
||||
"test-token",
|
||||
gitprovider.BranchRef{Owner: "owner", Repo: "repo", Branch: "feat"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, prRef)
|
||||
assert.Equal(t, "owner", prRef.Owner)
|
||||
assert.Equal(t, "repo", prRef.Repo)
|
||||
assert.Equal(t, 42, prRef.Number)
|
||||
})
|
||||
|
||||
t.Run("NoneOpen", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`[]`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client())
|
||||
require.NotNil(t, gp)
|
||||
|
||||
prRef, err := gp.ResolveBranchPullRequest(
|
||||
context.Background(),
|
||||
"test-token",
|
||||
gitprovider.BranchRef{Owner: "owner", Repo: "repo", Branch: "feat"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, prRef)
|
||||
})
|
||||
|
||||
t.Run("InvalidHTMLURL", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// If html_url can't be parsed as a PR URL, ResolveBranchPullRequest
|
||||
// returns nil, nil.
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`[{"html_url":"not-a-valid-url","number":42}]`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client())
|
||||
require.NotNil(t, gp)
|
||||
|
||||
prRef, err := gp.ResolveBranchPullRequest(
|
||||
context.Background(),
|
||||
"test-token",
|
||||
gitprovider.BranchRef{Owner: "owner", Repo: "repo", Branch: "feat"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, prRef)
|
||||
})
|
||||
}
|
||||
|
||||
func TestFetchBranchDiff(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const smallDiff = "diff --git a/file.go b/file.go\n--- a/file.go\n+++ b/file.go\n@@ -1 +1 @@\n-old\n+new\n"
|
||||
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/compare/") {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
_, _ = w.Write([]byte(smallDiff))
|
||||
return
|
||||
}
|
||||
// Repo metadata.
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"default_branch":"main"}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client())
|
||||
require.NotNil(t, gp)
|
||||
|
||||
diff, err := gp.FetchBranchDiff(
|
||||
context.Background(),
|
||||
"test-token",
|
||||
gitprovider.BranchRef{Owner: "org", Repo: "repo", Branch: "feat"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, smallDiff, diff)
|
||||
})
|
||||
|
||||
t.Run("EmptyDefaultBranch", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"default_branch":""}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client())
|
||||
require.NotNil(t, gp)
|
||||
|
||||
_, err := gp.FetchBranchDiff(
|
||||
context.Background(),
|
||||
"test-token",
|
||||
gitprovider.BranchRef{Owner: "org", Repo: "repo", Branch: "feat"},
|
||||
)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "default branch is empty")
|
||||
})
|
||||
|
||||
t.Run("DiffTooLarge", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
oversizeDiff := string(make([]byte, gitprovider.MaxDiffSize+1024))
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/compare/") {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
_, _ = w.Write([]byte(oversizeDiff))
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"default_branch":"main"}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client())
|
||||
require.NotNil(t, gp)
|
||||
|
||||
_, err := gp.FetchBranchDiff(
|
||||
context.Background(),
|
||||
"test-token",
|
||||
gitprovider.BranchRef{Owner: "org", Repo: "repo", Branch: "feat"},
|
||||
)
|
||||
assert.ErrorIs(t, err, gitprovider.ErrDiffTooLarge)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEscapePathPreserveSlashes(t *testing.T) {
|
||||
t.Parallel()
|
||||
// The function is unexported, so test it indirectly via BuildBranchURL.
|
||||
// A branch with a space in a segment should be escaped, but slashes preserved.
|
||||
gp := gitprovider.New("github", "", nil)
|
||||
require.NotNil(t, gp)
|
||||
got := gp.BuildBranchURL("owner", "repo", "feat/my thing")
|
||||
assert.Equal(t, "https://github.com/owner/repo/tree/feat/my%20thing", got)
|
||||
}
|
||||
|
||||
func TestParseRetryAfter(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
clk := quartz.NewMock(t)
|
||||
clk.Set(time.Now())
|
||||
|
||||
t.Run("RetryAfterSeconds", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
h := http.Header{}
|
||||
h.Set("Retry-After", "120")
|
||||
d := gitprovider.ParseRetryAfter(h, clk)
|
||||
assert.Equal(t, 120*time.Second, d)
|
||||
})
|
||||
|
||||
t.Run("XRatelimitReset", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
future := clk.Now().Add(90 * time.Second)
|
||||
t.Logf("now: %d future: %d", clk.Now().Unix(), future.Unix())
|
||||
h := http.Header{}
|
||||
h.Set("X-Ratelimit-Reset", strconv.FormatInt(future.Unix(), 10))
|
||||
d := gitprovider.ParseRetryAfter(h, clk)
|
||||
assert.WithinDuration(t, future, clk.Now().Add(d), time.Second)
|
||||
})
|
||||
|
||||
t.Run("NoHeaders", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
h := http.Header{}
|
||||
d := gitprovider.ParseRetryAfter(h, clk)
|
||||
assert.Equal(t, time.Duration(0), d)
|
||||
})
|
||||
|
||||
t.Run("InvalidValue", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
h := http.Header{}
|
||||
h.Set("Retry-After", "not-a-number")
|
||||
d := gitprovider.ParseRetryAfter(h, clk)
|
||||
assert.Equal(t, time.Duration(0), d)
|
||||
})
|
||||
|
||||
t.Run("RetryAfterTakesPrecedence", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
h := http.Header{}
|
||||
h.Set("Retry-After", "60")
|
||||
h.Set("X-Ratelimit-Reset", strconv.FormatInt(
|
||||
clk.Now().Unix()+120, 10,
|
||||
))
|
||||
d := gitprovider.ParseRetryAfter(h, clk)
|
||||
assert.Equal(t, 60*time.Second, d)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,179 @@
|
||||
package gitprovider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
// providerOptions holds optional configuration for provider
|
||||
// construction.
|
||||
type providerOptions struct {
|
||||
clock quartz.Clock
|
||||
}
|
||||
|
||||
// Option configures optional behavior for a Provider.
|
||||
type Option func(*providerOptions)
|
||||
|
||||
// WithClock sets the clock used by the provider. Defaults to
|
||||
// quartz.NewReal() if not provided.
|
||||
func WithClock(c quartz.Clock) Option {
|
||||
return func(o *providerOptions) {
|
||||
o.clock = c
|
||||
}
|
||||
}
|
||||
|
||||
// PRState is the normalized state of a pull/merge request across
|
||||
// all providers.
|
||||
type PRState string
|
||||
|
||||
const (
|
||||
PRStateOpen PRState = "open"
|
||||
PRStateClosed PRState = "closed"
|
||||
PRStateMerged PRState = "merged"
|
||||
)
|
||||
|
||||
// PRRef identifies a pull request on any provider.
|
||||
type PRRef struct {
|
||||
// Owner is the repository owner / project / workspace.
|
||||
Owner string
|
||||
// Repo is the repository name or slug.
|
||||
Repo string
|
||||
// Number is the PR number / IID / index.
|
||||
Number int
|
||||
}
|
||||
|
||||
// BranchRef identifies a branch in a repository, used for
|
||||
// branch-to-PR resolution.
|
||||
type BranchRef struct {
|
||||
Owner string
|
||||
Repo string
|
||||
Branch string
|
||||
}
|
||||
|
||||
// DiffStats summarizes the size of a PR's changes.
|
||||
type DiffStats struct {
|
||||
Additions int32
|
||||
Deletions int32
|
||||
ChangedFiles int32
|
||||
}
|
||||
|
||||
// PRStatus is the complete status of a pull/merge request.
|
||||
// This is the universal return type that all providers populate.
|
||||
type PRStatus struct {
|
||||
// State is the PR's lifecycle state.
|
||||
State PRState
|
||||
// Draft indicates the PR is marked as draft/WIP.
|
||||
Draft bool
|
||||
// HeadSHA is the SHA of the head commit.
|
||||
HeadSHA string
|
||||
// DiffStats summarizes additions/deletions/files changed.
|
||||
DiffStats DiffStats
|
||||
// ChangesRequested is a convenience boolean: true if any
|
||||
// reviewer's current state is "changes_requested".
|
||||
ChangesRequested bool
|
||||
// FetchedAt is when this status was fetched.
|
||||
FetchedAt time.Time
|
||||
}
|
||||
|
||||
// MaxDiffSize is the maximum number of bytes read from a diff
|
||||
// response. Diffs exceeding this limit are rejected with
|
||||
// ErrDiffTooLarge.
|
||||
const MaxDiffSize = 4 << 20 // 4 MiB
|
||||
|
||||
// ErrDiffTooLarge is returned when a diff exceeds MaxDiffSize.
|
||||
var ErrDiffTooLarge = xerrors.Errorf("diff exceeds maximum size of %d bytes", MaxDiffSize)
|
||||
|
||||
// Provider defines the interface that all Git hosting providers
|
||||
// implement. Each method is designed to minimize API round-trips
|
||||
// for the specific provider.
|
||||
type Provider interface {
|
||||
// FetchPullRequestStatus retrieves the complete status of a
|
||||
// pull request in the minimum number of API calls for this
|
||||
// provider.
|
||||
FetchPullRequestStatus(ctx context.Context, token string, ref PRRef) (*PRStatus, error)
|
||||
|
||||
// ResolveBranchPullRequest finds the open PR (if any) for
|
||||
// the given branch. Returns nil, nil if no open PR exists.
|
||||
ResolveBranchPullRequest(ctx context.Context, token string, ref BranchRef) (*PRRef, error)
|
||||
|
||||
// FetchPullRequestDiff returns the raw unified diff for a
|
||||
// pull request. This uses the PR's actual base branch (which
|
||||
// may differ from the repo default branch, e.g. a PR
|
||||
// targeting "staging" instead of "main"), so it matches what
|
||||
// the provider shows on the PR's "Files changed" tab.
|
||||
// Returns ErrDiffTooLarge if the diff exceeds MaxDiffSize.
|
||||
FetchPullRequestDiff(ctx context.Context, token string, ref PRRef) (string, error)
|
||||
|
||||
// FetchBranchDiff returns the diff of a branch compared
|
||||
// against the repository's default branch. This is the
|
||||
// fallback when no pull request exists yet (e.g. the agent
|
||||
// pushed a branch but hasn't opened a PR). Returns
|
||||
// ErrDiffTooLarge if the diff exceeds MaxDiffSize.
|
||||
FetchBranchDiff(ctx context.Context, token string, ref BranchRef) (string, error)
|
||||
|
||||
// ParseRepositoryOrigin parses a remote origin URL (HTTPS
|
||||
// or SSH) into owner and repo components, returning the
|
||||
// normalized HTTPS URL. Returns false if the URL does not
|
||||
// match this provider.
|
||||
ParseRepositoryOrigin(raw string) (owner, repo, normalizedOrigin string, ok bool)
|
||||
|
||||
// ParsePullRequestURL parses a pull request URL into a
|
||||
// PRRef. Returns false if the URL does not match this
|
||||
// provider.
|
||||
ParsePullRequestURL(raw string) (PRRef, bool)
|
||||
|
||||
// NormalizePullRequestURL normalizes a pull request URL,
|
||||
// stripping trailing punctuation, query strings, and
|
||||
// fragments. Returns empty string if the URL does not
|
||||
// match this provider.
|
||||
NormalizePullRequestURL(raw string) string
|
||||
|
||||
// BuildBranchURL constructs a URL to view a branch on
|
||||
// the provider's web UI.
|
||||
BuildBranchURL(owner, repo, branch string) string
|
||||
|
||||
// BuildRepositoryURL constructs a URL to view a repository
|
||||
// on the provider's web UI.
|
||||
BuildRepositoryURL(owner, repo string) string
|
||||
|
||||
// BuildPullRequestURL constructs a URL to view a pull
|
||||
// request on the provider's web UI.
|
||||
BuildPullRequestURL(ref PRRef) string
|
||||
}
|
||||
|
||||
// New creates a Provider for the given provider type and API base
|
||||
// URL. Returns nil if the provider type is not a supported git
|
||||
// provider.
|
||||
func New(providerType string, apiBaseURL string, httpClient *http.Client, opts ...Option) Provider {
|
||||
o := providerOptions{}
|
||||
for _, opt := range opts {
|
||||
opt(&o)
|
||||
}
|
||||
if o.clock == nil {
|
||||
o.clock = quartz.NewReal()
|
||||
}
|
||||
|
||||
switch providerType {
|
||||
case "github":
|
||||
return newGitHub(apiBaseURL, httpClient, o.clock)
|
||||
default:
|
||||
// Other providers (gitlab, bitbucket-cloud, etc.) will be
|
||||
// added here as they are implemented.
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// RateLimitError indicates the git provider's API rate limit was hit.
|
||||
type RateLimitError struct {
|
||||
RetryAfter time.Time
|
||||
}
|
||||
|
||||
func (e *RateLimitError) Error() string {
|
||||
return fmt.Sprintf("rate limited until %s", e.RetryAfter.Format(time.RFC3339))
|
||||
}
|
||||
@@ -0,0 +1,230 @@
|
||||
package gitsync
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/externalauth/gitprovider"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
const (
|
||||
// DiffStatusTTL is how long a successfully refreshed
|
||||
// diff status remains fresh before becoming stale again.
|
||||
DiffStatusTTL = 120 * time.Second
|
||||
)
|
||||
|
||||
// ProviderResolver maps a git remote origin to the gitprovider
|
||||
// that handles it. Returns nil if no provider matches.
|
||||
type ProviderResolver func(origin string) gitprovider.Provider
|
||||
|
||||
var ErrNoTokenAvailable error = errors.New("no token available")
|
||||
|
||||
// TokenResolver obtains the user's git access token for a given
|
||||
// remote origin. Should return nil if no token is available, in
|
||||
// which case ErrNoTokenAvailable will be returned.
|
||||
type TokenResolver func(
|
||||
ctx context.Context,
|
||||
userID uuid.UUID,
|
||||
origin string,
|
||||
) (*string, error)
|
||||
|
||||
// Refresher contains the stateless business logic for fetching
|
||||
// fresh PR data from a git provider given a stale
|
||||
// database.ChatDiffStatus row.
|
||||
type Refresher struct {
|
||||
providers ProviderResolver
|
||||
tokens TokenResolver
|
||||
logger slog.Logger
|
||||
clock quartz.Clock
|
||||
}
|
||||
|
||||
// NewRefresher creates a Refresher with the given dependency
|
||||
// functions.
|
||||
func NewRefresher(
|
||||
providers ProviderResolver,
|
||||
tokens TokenResolver,
|
||||
logger slog.Logger,
|
||||
clock quartz.Clock,
|
||||
) *Refresher {
|
||||
return &Refresher{
|
||||
providers: providers,
|
||||
tokens: tokens,
|
||||
logger: logger,
|
||||
clock: clock,
|
||||
}
|
||||
}
|
||||
|
||||
// RefreshRequest pairs a stale row with the chat owner who
|
||||
// holds the git token needed for API calls.
|
||||
type RefreshRequest struct {
|
||||
Row database.ChatDiffStatus
|
||||
OwnerID uuid.UUID
|
||||
}
|
||||
|
||||
// RefreshResult is the outcome for a single row.
|
||||
// - Params != nil, Error == nil → success, caller should upsert.
|
||||
// - Params == nil, Error == nil → no PR yet, caller should skip.
|
||||
// - Params == nil, Error != nil → row-level failure.
|
||||
type RefreshResult struct {
|
||||
Request RefreshRequest
|
||||
Params *database.UpsertChatDiffStatusParams
|
||||
Error error
|
||||
}
|
||||
|
||||
// groupKey identifies a unique (owner, origin) pair so that
|
||||
// provider and token resolution happen once per group.
|
||||
type groupKey struct {
|
||||
ownerID uuid.UUID
|
||||
origin string
|
||||
}
|
||||
|
||||
// Refresh fetches fresh PR data for a batch of stale rows.
|
||||
// Rows are grouped internally by (ownerID, origin) so that
|
||||
// provider and token resolution happen once per group. A
|
||||
// top-level error is returned only when the entire batch
|
||||
// fails catastrophically. Per-row outcomes are in the
|
||||
// returned RefreshResult slice (one per input request, same
|
||||
// order).
|
||||
func (r *Refresher) Refresh(
|
||||
ctx context.Context,
|
||||
requests []RefreshRequest,
|
||||
) ([]RefreshResult, error) {
|
||||
results := make([]RefreshResult, len(requests))
|
||||
for i, req := range requests {
|
||||
results[i].Request = req
|
||||
}
|
||||
|
||||
// Group request indices by (ownerID, origin).
|
||||
groups := make(map[groupKey][]int)
|
||||
for i, req := range requests {
|
||||
key := groupKey{
|
||||
ownerID: req.OwnerID,
|
||||
origin: req.Row.GitRemoteOrigin,
|
||||
}
|
||||
groups[key] = append(groups[key], i)
|
||||
}
|
||||
|
||||
for key, indices := range groups {
|
||||
provider := r.providers(key.origin)
|
||||
if provider == nil {
|
||||
err := xerrors.Errorf("no provider for origin %q", key.origin)
|
||||
for _, i := range indices {
|
||||
results[i].Error = err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
token, err := r.tokens(ctx, key.ownerID, key.origin)
|
||||
if err != nil {
|
||||
err = xerrors.Errorf("resolve token: %w", err)
|
||||
} else if token == nil || len(*token) == 0 {
|
||||
err = ErrNoTokenAvailable
|
||||
}
|
||||
if err != nil {
|
||||
for _, i := range indices {
|
||||
results[i].Error = err
|
||||
}
|
||||
continue
|
||||
}
|
||||
// This is technically unnecessary but kept here as a future molly-guard.
|
||||
if token == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for i, idx := range indices {
|
||||
req := requests[idx]
|
||||
params, err := r.refreshOne(ctx, provider, *token, req.Row)
|
||||
results[idx] = RefreshResult{Request: req, Params: params, Error: err}
|
||||
|
||||
// If rate-limited, skip remaining rows in this group.
|
||||
var rlErr *gitprovider.RateLimitError
|
||||
if errors.As(err, &rlErr) {
|
||||
for _, remaining := range indices[i+1:] {
|
||||
results[remaining] = RefreshResult{
|
||||
Request: requests[remaining],
|
||||
Error: fmt.Errorf("skipped: %w", rlErr),
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// refreshOne processes a single row using an already-resolved
|
||||
// provider and token. This is the old Refresh logic, unchanged.
|
||||
func (r *Refresher) refreshOne(
|
||||
ctx context.Context,
|
||||
provider gitprovider.Provider,
|
||||
token string,
|
||||
row database.ChatDiffStatus,
|
||||
) (*database.UpsertChatDiffStatusParams, error) {
|
||||
var ref gitprovider.PRRef
|
||||
var prURL string
|
||||
|
||||
if row.Url.Valid && row.Url.String != "" {
|
||||
// Row already has a PR URL — parse it directly.
|
||||
parsed, ok := provider.ParsePullRequestURL(row.Url.String)
|
||||
if !ok {
|
||||
return nil, xerrors.Errorf("parse pull request URL %q", row.Url.String)
|
||||
}
|
||||
ref = parsed
|
||||
prURL = row.Url.String
|
||||
} else {
|
||||
// No PR URL — resolve owner/repo from the remote origin,
|
||||
// then look up the open PR for this branch.
|
||||
owner, repo, _, ok := provider.ParseRepositoryOrigin(row.GitRemoteOrigin)
|
||||
if !ok {
|
||||
return nil, xerrors.Errorf("parse repository origin %q", row.GitRemoteOrigin)
|
||||
}
|
||||
|
||||
resolved, err := provider.ResolveBranchPullRequest(ctx, token, gitprovider.BranchRef{
|
||||
Owner: owner,
|
||||
Repo: repo,
|
||||
Branch: row.GitBranch,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("resolve branch pull request: %w", err)
|
||||
}
|
||||
if resolved == nil {
|
||||
// No PR exists yet for this branch.
|
||||
return nil, nil
|
||||
}
|
||||
ref = *resolved
|
||||
prURL = provider.BuildPullRequestURL(ref)
|
||||
}
|
||||
|
||||
status, err := provider.FetchPullRequestStatus(ctx, token, ref)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("fetch pull request status: %w", err)
|
||||
}
|
||||
|
||||
now := r.clock.Now().UTC()
|
||||
params := &database.UpsertChatDiffStatusParams{
|
||||
ChatID: row.ChatID,
|
||||
Url: sql.NullString{String: prURL, Valid: prURL != ""},
|
||||
PullRequestState: sql.NullString{
|
||||
String: string(status.State),
|
||||
Valid: status.State != "",
|
||||
},
|
||||
ChangesRequested: status.ChangesRequested,
|
||||
Additions: status.DiffStats.Additions,
|
||||
Deletions: status.DiffStats.Deletions,
|
||||
ChangedFiles: status.DiffStats.ChangedFiles,
|
||||
RefreshedAt: now,
|
||||
StaleAt: now.Add(DiffStatusTTL),
|
||||
}
|
||||
|
||||
return params, nil
|
||||
}
|
||||
@@ -0,0 +1,775 @@
|
||||
package gitsync_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/externalauth/gitprovider"
|
||||
"github.com/coder/coder/v2/coderd/gitsync"
|
||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
// mockProvider implements gitprovider.Provider with function fields
|
||||
// so each test can wire only the methods it needs. Any method left
|
||||
// nil panics with "unexpected call".
|
||||
type mockProvider struct {
|
||||
fetchPullRequestStatus func(ctx context.Context, token string, ref gitprovider.PRRef) (*gitprovider.PRStatus, error)
|
||||
resolveBranchPR func(ctx context.Context, token string, ref gitprovider.BranchRef) (*gitprovider.PRRef, error)
|
||||
fetchPullRequestDiff func(ctx context.Context, token string, ref gitprovider.PRRef) (string, error)
|
||||
fetchBranchDiff func(ctx context.Context, token string, ref gitprovider.BranchRef) (string, error)
|
||||
parseRepositoryOrigin func(raw string) (string, string, string, bool)
|
||||
parsePullRequestURL func(raw string) (gitprovider.PRRef, bool)
|
||||
normalizePullRequestURL func(raw string) string
|
||||
buildBranchURL func(owner, repo, branch string) string
|
||||
buildRepositoryURL func(owner, repo string) string
|
||||
buildPullRequestURL func(ref gitprovider.PRRef) string
|
||||
}
|
||||
|
||||
func (m *mockProvider) FetchPullRequestStatus(ctx context.Context, token string, ref gitprovider.PRRef) (*gitprovider.PRStatus, error) {
|
||||
if m.fetchPullRequestStatus == nil {
|
||||
panic("unexpected call to FetchPullRequestStatus")
|
||||
}
|
||||
return m.fetchPullRequestStatus(ctx, token, ref)
|
||||
}
|
||||
|
||||
func (m *mockProvider) ResolveBranchPullRequest(ctx context.Context, token string, ref gitprovider.BranchRef) (*gitprovider.PRRef, error) {
|
||||
if m.resolveBranchPR == nil {
|
||||
panic("unexpected call to ResolveBranchPullRequest")
|
||||
}
|
||||
return m.resolveBranchPR(ctx, token, ref)
|
||||
}
|
||||
|
||||
func (m *mockProvider) FetchPullRequestDiff(ctx context.Context, token string, ref gitprovider.PRRef) (string, error) {
|
||||
if m.fetchPullRequestDiff == nil {
|
||||
panic("unexpected call to FetchPullRequestDiff")
|
||||
}
|
||||
return m.fetchPullRequestDiff(ctx, token, ref)
|
||||
}
|
||||
|
||||
func (m *mockProvider) FetchBranchDiff(ctx context.Context, token string, ref gitprovider.BranchRef) (string, error) {
|
||||
if m.fetchBranchDiff == nil {
|
||||
panic("unexpected call to FetchBranchDiff")
|
||||
}
|
||||
return m.fetchBranchDiff(ctx, token, ref)
|
||||
}
|
||||
|
||||
func (m *mockProvider) ParseRepositoryOrigin(raw string) (string, string, string, bool) {
|
||||
if m.parseRepositoryOrigin == nil {
|
||||
panic("unexpected call to ParseRepositoryOrigin")
|
||||
}
|
||||
return m.parseRepositoryOrigin(raw)
|
||||
}
|
||||
|
||||
func (m *mockProvider) ParsePullRequestURL(raw string) (gitprovider.PRRef, bool) {
|
||||
if m.parsePullRequestURL == nil {
|
||||
panic("unexpected call to ParsePullRequestURL")
|
||||
}
|
||||
return m.parsePullRequestURL(raw)
|
||||
}
|
||||
|
||||
func (m *mockProvider) NormalizePullRequestURL(raw string) string {
|
||||
if m.normalizePullRequestURL == nil {
|
||||
panic("unexpected call to NormalizePullRequestURL")
|
||||
}
|
||||
return m.normalizePullRequestURL(raw)
|
||||
}
|
||||
|
||||
func (m *mockProvider) BuildBranchURL(owner, repo, branch string) string {
|
||||
if m.buildBranchURL == nil {
|
||||
panic("unexpected call to BuildBranchURL")
|
||||
}
|
||||
return m.buildBranchURL(owner, repo, branch)
|
||||
}
|
||||
|
||||
func (m *mockProvider) BuildRepositoryURL(owner, repo string) string {
|
||||
if m.buildRepositoryURL == nil {
|
||||
panic("unexpected call to BuildRepositoryURL")
|
||||
}
|
||||
return m.buildRepositoryURL(owner, repo)
|
||||
}
|
||||
|
||||
func (m *mockProvider) BuildPullRequestURL(ref gitprovider.PRRef) string {
|
||||
if m.buildPullRequestURL == nil {
|
||||
panic("unexpected call to BuildPullRequestURL")
|
||||
}
|
||||
return m.buildPullRequestURL(ref)
|
||||
}
|
||||
|
||||
func TestRefresher_WithPRURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mp := &mockProvider{
|
||||
parsePullRequestURL: func(raw string) (gitprovider.PRRef, bool) {
|
||||
return gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 42}, true
|
||||
},
|
||||
fetchPullRequestStatus: func(_ context.Context, _ string, _ gitprovider.PRRef) (*gitprovider.PRStatus, error) {
|
||||
return &gitprovider.PRStatus{
|
||||
State: gitprovider.PRStateOpen,
|
||||
DiffStats: gitprovider.DiffStats{
|
||||
Additions: 10,
|
||||
Deletions: 5,
|
||||
ChangedFiles: 3,
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
providers := func(_ string) gitprovider.Provider { return mp }
|
||||
tokens := func(_ context.Context, _ uuid.UUID, _ string) (*string, error) {
|
||||
return ptr.Ref("test-token"), nil
|
||||
}
|
||||
|
||||
r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), quartz.NewReal())
|
||||
|
||||
chatID := uuid.New()
|
||||
row := database.ChatDiffStatus{
|
||||
ChatID: chatID,
|
||||
Url: sql.NullString{String: "https://github.com/org/repo/pull/42", Valid: true},
|
||||
GitRemoteOrigin: "https://github.com/org/repo",
|
||||
GitBranch: "feature",
|
||||
}
|
||||
|
||||
ownerID := uuid.New()
|
||||
results, err := r.Refresh(context.Background(), []gitsync.RefreshRequest{
|
||||
{Row: row, OwnerID: ownerID},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, results, 1)
|
||||
res := results[0]
|
||||
|
||||
require.NoError(t, res.Error)
|
||||
require.NotNil(t, res.Params)
|
||||
|
||||
assert.Equal(t, chatID, res.Params.ChatID)
|
||||
assert.Equal(t, "open", res.Params.PullRequestState.String)
|
||||
assert.True(t, res.Params.PullRequestState.Valid)
|
||||
assert.Equal(t, int32(10), res.Params.Additions)
|
||||
assert.Equal(t, int32(5), res.Params.Deletions)
|
||||
assert.Equal(t, int32(3), res.Params.ChangedFiles)
|
||||
|
||||
// StaleAt should be ~120s after RefreshedAt.
|
||||
diff := res.Params.StaleAt.Sub(res.Params.RefreshedAt)
|
||||
assert.InDelta(t, 120, diff.Seconds(), 5)
|
||||
}
|
||||
|
||||
func TestRefresher_BranchResolvesToPR(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mp := &mockProvider{
|
||||
parseRepositoryOrigin: func(_ string) (string, string, string, bool) {
|
||||
return "org", "repo", "https://github.com/org/repo", true
|
||||
},
|
||||
resolveBranchPR: func(_ context.Context, _ string, _ gitprovider.BranchRef) (*gitprovider.PRRef, error) {
|
||||
return &gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 7}, nil
|
||||
},
|
||||
fetchPullRequestStatus: func(_ context.Context, _ string, _ gitprovider.PRRef) (*gitprovider.PRStatus, error) {
|
||||
return &gitprovider.PRStatus{State: gitprovider.PRStateOpen}, nil
|
||||
},
|
||||
buildPullRequestURL: func(_ gitprovider.PRRef) string {
|
||||
return "https://github.com/org/repo/pull/7"
|
||||
},
|
||||
}
|
||||
|
||||
providers := func(_ string) gitprovider.Provider { return mp }
|
||||
tokens := func(_ context.Context, _ uuid.UUID, _ string) (*string, error) {
|
||||
return ptr.Ref("test-token"), nil
|
||||
}
|
||||
|
||||
r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), quartz.NewReal())
|
||||
|
||||
row := database.ChatDiffStatus{
|
||||
ChatID: uuid.New(),
|
||||
Url: sql.NullString{},
|
||||
GitRemoteOrigin: "https://github.com/org/repo",
|
||||
GitBranch: "feature",
|
||||
}
|
||||
|
||||
ownerID := uuid.New()
|
||||
results, err := r.Refresh(context.Background(), []gitsync.RefreshRequest{
|
||||
{Row: row, OwnerID: ownerID},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, results, 1)
|
||||
res := results[0]
|
||||
|
||||
require.NoError(t, res.Error)
|
||||
require.NotNil(t, res.Params)
|
||||
|
||||
assert.Contains(t, res.Params.Url.String, "pull/7")
|
||||
assert.True(t, res.Params.Url.Valid)
|
||||
assert.Equal(t, "open", res.Params.PullRequestState.String)
|
||||
}
|
||||
|
||||
func TestRefresher_BranchNoPRYet(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mp := &mockProvider{
|
||||
parseRepositoryOrigin: func(_ string) (string, string, string, bool) {
|
||||
return "org", "repo", "https://github.com/org/repo", true
|
||||
},
|
||||
resolveBranchPR: func(_ context.Context, _ string, _ gitprovider.BranchRef) (*gitprovider.PRRef, error) {
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
providers := func(_ string) gitprovider.Provider { return mp }
|
||||
tokens := func(_ context.Context, _ uuid.UUID, _ string) (*string, error) {
|
||||
return ptr.Ref("test-token"), nil
|
||||
}
|
||||
|
||||
r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), quartz.NewReal())
|
||||
|
||||
row := database.ChatDiffStatus{
|
||||
ChatID: uuid.New(),
|
||||
Url: sql.NullString{},
|
||||
GitRemoteOrigin: "https://github.com/org/repo",
|
||||
GitBranch: "feature",
|
||||
}
|
||||
|
||||
ownerID := uuid.New()
|
||||
results, err := r.Refresh(context.Background(), []gitsync.RefreshRequest{
|
||||
{Row: row, OwnerID: ownerID},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, results, 1)
|
||||
res := results[0]
|
||||
|
||||
assert.NoError(t, res.Error)
|
||||
assert.Nil(t, res.Params)
|
||||
}
|
||||
|
||||
func TestRefresher_NoProviderForOrigin(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
providers := func(_ string) gitprovider.Provider { return nil }
|
||||
tokens := func(_ context.Context, _ uuid.UUID, _ string) (*string, error) {
|
||||
return ptr.Ref("test-token"), nil
|
||||
}
|
||||
|
||||
r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), quartz.NewReal())
|
||||
|
||||
row := database.ChatDiffStatus{
|
||||
ChatID: uuid.New(),
|
||||
Url: sql.NullString{String: "https://example.com/pr/1", Valid: true},
|
||||
GitRemoteOrigin: "https://example.com/org/repo",
|
||||
GitBranch: "feature",
|
||||
}
|
||||
|
||||
ownerID := uuid.New()
|
||||
results, err := r.Refresh(context.Background(), []gitsync.RefreshRequest{
|
||||
{Row: row, OwnerID: ownerID},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, results, 1)
|
||||
res := results[0]
|
||||
|
||||
assert.Nil(t, res.Params)
|
||||
require.Error(t, res.Error)
|
||||
assert.Contains(t, res.Error.Error(), "no provider")
|
||||
}
|
||||
|
||||
func TestRefresher_TokenResolutionFails(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var fetchCalled atomic.Bool
|
||||
mp := &mockProvider{
|
||||
fetchPullRequestStatus: func(_ context.Context, _ string, _ gitprovider.PRRef) (*gitprovider.PRStatus, error) {
|
||||
fetchCalled.Store(true)
|
||||
return nil, errors.New("should not be called")
|
||||
},
|
||||
parsePullRequestURL: func(_ string) (gitprovider.PRRef, bool) {
|
||||
return gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 1}, true
|
||||
},
|
||||
}
|
||||
|
||||
providers := func(_ string) gitprovider.Provider { return mp }
|
||||
tokens := func(_ context.Context, _ uuid.UUID, _ string) (*string, error) {
|
||||
return nil, errors.New("token lookup failed")
|
||||
}
|
||||
|
||||
r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), quartz.NewReal())
|
||||
|
||||
row := database.ChatDiffStatus{
|
||||
ChatID: uuid.New(),
|
||||
Url: sql.NullString{String: "https://github.com/org/repo/pull/1", Valid: true},
|
||||
GitRemoteOrigin: "https://github.com/org/repo",
|
||||
GitBranch: "feature",
|
||||
}
|
||||
|
||||
ownerID := uuid.New()
|
||||
results, err := r.Refresh(context.Background(), []gitsync.RefreshRequest{
|
||||
{Row: row, OwnerID: ownerID},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, results, 1)
|
||||
res := results[0]
|
||||
|
||||
assert.Nil(t, res.Params)
|
||||
require.Error(t, res.Error)
|
||||
assert.False(t, fetchCalled.Load(), "FetchPullRequestStatus should not be called when token resolution fails")
|
||||
}
|
||||
|
||||
func TestRefresher_EmptyToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mp := &mockProvider{}
|
||||
|
||||
providers := func(_ string) gitprovider.Provider { return mp }
|
||||
tokens := func(_ context.Context, _ uuid.UUID, _ string) (*string, error) {
|
||||
return ptr.Ref(""), nil
|
||||
}
|
||||
|
||||
r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), quartz.NewReal())
|
||||
|
||||
row := database.ChatDiffStatus{
|
||||
ChatID: uuid.New(),
|
||||
Url: sql.NullString{String: "https://github.com/org/repo/pull/1", Valid: true},
|
||||
GitRemoteOrigin: "https://github.com/org/repo",
|
||||
GitBranch: "feature",
|
||||
}
|
||||
|
||||
ownerID := uuid.New()
|
||||
results, err := r.Refresh(context.Background(), []gitsync.RefreshRequest{
|
||||
{Row: row, OwnerID: ownerID},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, results, 1)
|
||||
res := results[0]
|
||||
|
||||
assert.Nil(t, res.Params)
|
||||
require.ErrorIs(t, res.Error, gitsync.ErrNoTokenAvailable)
|
||||
}
|
||||
|
||||
func TestRefresher_ProviderFetchFails(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mp := &mockProvider{
|
||||
parsePullRequestURL: func(_ string) (gitprovider.PRRef, bool) {
|
||||
return gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 42}, true
|
||||
},
|
||||
fetchPullRequestStatus: func(_ context.Context, _ string, _ gitprovider.PRRef) (*gitprovider.PRStatus, error) {
|
||||
return nil, errors.New("api error")
|
||||
},
|
||||
}
|
||||
|
||||
providers := func(_ string) gitprovider.Provider { return mp }
|
||||
tokens := func(_ context.Context, _ uuid.UUID, _ string) (*string, error) {
|
||||
return ptr.Ref("test-token"), nil
|
||||
}
|
||||
|
||||
r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), quartz.NewReal())
|
||||
|
||||
row := database.ChatDiffStatus{
|
||||
ChatID: uuid.New(),
|
||||
Url: sql.NullString{String: "https://github.com/org/repo/pull/42", Valid: true},
|
||||
GitRemoteOrigin: "https://github.com/org/repo",
|
||||
GitBranch: "feature",
|
||||
}
|
||||
|
||||
ownerID := uuid.New()
|
||||
results, err := r.Refresh(context.Background(), []gitsync.RefreshRequest{
|
||||
{Row: row, OwnerID: ownerID},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, results, 1)
|
||||
res := results[0]
|
||||
|
||||
assert.Nil(t, res.Params)
|
||||
require.Error(t, res.Error)
|
||||
assert.Contains(t, res.Error.Error(), "api error")
|
||||
}
|
||||
|
||||
func TestRefresher_PRURLParseFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mp := &mockProvider{
|
||||
parsePullRequestURL: func(_ string) (gitprovider.PRRef, bool) {
|
||||
return gitprovider.PRRef{}, false
|
||||
},
|
||||
}
|
||||
|
||||
providers := func(_ string) gitprovider.Provider { return mp }
|
||||
tokens := func(_ context.Context, _ uuid.UUID, _ string) (*string, error) {
|
||||
return ptr.Ref("test-token"), nil
|
||||
}
|
||||
|
||||
r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), quartz.NewReal())
|
||||
|
||||
row := database.ChatDiffStatus{
|
||||
ChatID: uuid.New(),
|
||||
Url: sql.NullString{String: "https://github.com/org/repo/not-a-pr", Valid: true},
|
||||
GitRemoteOrigin: "https://github.com/org/repo",
|
||||
GitBranch: "feature",
|
||||
}
|
||||
|
||||
ownerID := uuid.New()
|
||||
results, err := r.Refresh(context.Background(), []gitsync.RefreshRequest{
|
||||
{Row: row, OwnerID: ownerID},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, results, 1)
|
||||
res := results[0]
|
||||
|
||||
assert.Nil(t, res.Params)
|
||||
require.Error(t, res.Error)
|
||||
}
|
||||
|
||||
func TestRefresher_BatchGroupsByOwnerAndOrigin(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mp := &mockProvider{
|
||||
parsePullRequestURL: func(_ string) (gitprovider.PRRef, bool) {
|
||||
return gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 1}, true
|
||||
},
|
||||
fetchPullRequestStatus: func(_ context.Context, _ string, _ gitprovider.PRRef) (*gitprovider.PRStatus, error) {
|
||||
return &gitprovider.PRStatus{State: gitprovider.PRStateOpen}, nil
|
||||
},
|
||||
}
|
||||
|
||||
providers := func(_ string) gitprovider.Provider { return mp }
|
||||
|
||||
var tokenCalls atomic.Int32
|
||||
tokens := func(_ context.Context, _ uuid.UUID, _ string) (*string, error) {
|
||||
tokenCalls.Add(1)
|
||||
return ptr.Ref("test-token"), nil
|
||||
}
|
||||
|
||||
r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), quartz.NewReal())
|
||||
|
||||
ownerID := uuid.New()
|
||||
originA := "https://github.com/org/repo"
|
||||
originB := "https://gitlab.com/org/repo"
|
||||
|
||||
requests := []gitsync.RefreshRequest{
|
||||
{
|
||||
Row: database.ChatDiffStatus{
|
||||
ChatID: uuid.New(),
|
||||
Url: sql.NullString{String: "https://github.com/org/repo/pull/1", Valid: true},
|
||||
GitRemoteOrigin: originA,
|
||||
GitBranch: "feature-1",
|
||||
},
|
||||
OwnerID: ownerID,
|
||||
},
|
||||
{
|
||||
Row: database.ChatDiffStatus{
|
||||
ChatID: uuid.New(),
|
||||
Url: sql.NullString{String: "https://github.com/org/repo/pull/1", Valid: true},
|
||||
GitRemoteOrigin: originA,
|
||||
GitBranch: "feature-2",
|
||||
},
|
||||
OwnerID: ownerID,
|
||||
},
|
||||
{
|
||||
Row: database.ChatDiffStatus{
|
||||
ChatID: uuid.New(),
|
||||
Url: sql.NullString{String: "https://gitlab.com/org/repo/pull/1", Valid: true},
|
||||
GitRemoteOrigin: originB,
|
||||
GitBranch: "feature-3",
|
||||
},
|
||||
OwnerID: ownerID,
|
||||
},
|
||||
}
|
||||
|
||||
results, err := r.Refresh(context.Background(), requests)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, results, 3)
|
||||
|
||||
for i, res := range results {
|
||||
require.NoError(t, res.Error, "result[%d] should not have an error", i)
|
||||
require.NotNil(t, res.Params, "result[%d] should have params", i)
|
||||
}
|
||||
|
||||
// Two distinct (ownerID, origin) groups → exactly 2 token
|
||||
// resolution calls.
|
||||
assert.Equal(t, int32(2), tokenCalls.Load(),
|
||||
"TokenResolver should be called once per (owner, origin) group")
|
||||
}
|
||||
|
||||
func TestRefresher_UsesInjectedClock(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
fixedTime := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC)
|
||||
mClock.Set(fixedTime)
|
||||
|
||||
mp := &mockProvider{
|
||||
parsePullRequestURL: func(raw string) (gitprovider.PRRef, bool) {
|
||||
return gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 42}, true
|
||||
},
|
||||
fetchPullRequestStatus: func(_ context.Context, _ string, _ gitprovider.PRRef) (*gitprovider.PRStatus, error) {
|
||||
return &gitprovider.PRStatus{
|
||||
State: gitprovider.PRStateOpen,
|
||||
DiffStats: gitprovider.DiffStats{
|
||||
Additions: 10,
|
||||
Deletions: 5,
|
||||
ChangedFiles: 3,
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
providers := func(_ string) gitprovider.Provider { return mp }
|
||||
tokens := func(_ context.Context, _ uuid.UUID, _ string) (*string, error) {
|
||||
return ptr.Ref("test-token"), nil
|
||||
}
|
||||
|
||||
r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), mClock)
|
||||
|
||||
chatID := uuid.New()
|
||||
row := database.ChatDiffStatus{
|
||||
ChatID: chatID,
|
||||
Url: sql.NullString{String: "https://github.com/org/repo/pull/42", Valid: true},
|
||||
GitRemoteOrigin: "https://github.com/org/repo",
|
||||
GitBranch: "feature",
|
||||
}
|
||||
|
||||
ownerID := uuid.New()
|
||||
results, err := r.Refresh(context.Background(), []gitsync.RefreshRequest{
|
||||
{Row: row, OwnerID: ownerID},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, results, 1)
|
||||
res := results[0]
|
||||
|
||||
require.NoError(t, res.Error)
|
||||
require.NotNil(t, res.Params)
|
||||
|
||||
// The mock clock is deterministic, so times must be exact.
|
||||
assert.Equal(t, fixedTime, res.Params.RefreshedAt)
|
||||
assert.Equal(t, fixedTime.Add(gitsync.DiffStatusTTL), res.Params.StaleAt)
|
||||
}
|
||||
|
||||
func TestRefresher_RateLimitSkipsRemainingInGroup(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var callCount atomic.Int32
|
||||
|
||||
mp := &mockProvider{
|
||||
parsePullRequestURL: func(raw string) (gitprovider.PRRef, bool) {
|
||||
var num int
|
||||
switch {
|
||||
case strings.HasSuffix(raw, "/pull/1"):
|
||||
num = 1
|
||||
case strings.HasSuffix(raw, "/pull/2"):
|
||||
num = 2
|
||||
case strings.HasSuffix(raw, "/pull/3"):
|
||||
num = 3
|
||||
default:
|
||||
return gitprovider.PRRef{}, false
|
||||
}
|
||||
return gitprovider.PRRef{Owner: "org", Repo: "repo", Number: num}, true
|
||||
},
|
||||
fetchPullRequestStatus: func(_ context.Context, _ string, ref gitprovider.PRRef) (*gitprovider.PRStatus, error) {
|
||||
call := callCount.Add(1)
|
||||
switch call {
|
||||
case 1:
|
||||
// First call succeeds.
|
||||
return &gitprovider.PRStatus{
|
||||
State: gitprovider.PRStateOpen,
|
||||
DiffStats: gitprovider.DiffStats{
|
||||
Additions: 5,
|
||||
Deletions: 2,
|
||||
ChangedFiles: 1,
|
||||
},
|
||||
}, nil
|
||||
case 2:
|
||||
// Second call hits rate limit.
|
||||
return nil, &gitprovider.RateLimitError{
|
||||
RetryAfter: time.Now().Add(60 * time.Second),
|
||||
}
|
||||
default:
|
||||
// Third call should never happen.
|
||||
t.Fatal("FetchPullRequestStatus called more than 2 times")
|
||||
return nil, nil
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
providers := func(_ string) gitprovider.Provider { return mp }
|
||||
tokens := func(_ context.Context, _ uuid.UUID, _ string) (*string, error) {
|
||||
return ptr.Ref("test-token"), nil
|
||||
}
|
||||
|
||||
r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), quartz.NewReal())
|
||||
|
||||
ownerID := uuid.New()
|
||||
origin := "https://github.com/org/repo"
|
||||
|
||||
requests := []gitsync.RefreshRequest{
|
||||
{
|
||||
Row: database.ChatDiffStatus{
|
||||
ChatID: uuid.New(),
|
||||
Url: sql.NullString{String: "https://github.com/org/repo/pull/1", Valid: true},
|
||||
GitRemoteOrigin: origin,
|
||||
GitBranch: "feat-1",
|
||||
},
|
||||
OwnerID: ownerID,
|
||||
},
|
||||
{
|
||||
Row: database.ChatDiffStatus{
|
||||
ChatID: uuid.New(),
|
||||
Url: sql.NullString{String: "https://github.com/org/repo/pull/2", Valid: true},
|
||||
GitRemoteOrigin: origin,
|
||||
GitBranch: "feat-2",
|
||||
},
|
||||
OwnerID: ownerID,
|
||||
},
|
||||
{
|
||||
Row: database.ChatDiffStatus{
|
||||
ChatID: uuid.New(),
|
||||
Url: sql.NullString{String: "https://github.com/org/repo/pull/3", Valid: true},
|
||||
GitRemoteOrigin: origin,
|
||||
GitBranch: "feat-3",
|
||||
},
|
||||
OwnerID: ownerID,
|
||||
},
|
||||
}
|
||||
|
||||
results, err := r.Refresh(context.Background(), requests)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, results, 3)
|
||||
|
||||
// Row 0: success.
|
||||
assert.NoError(t, results[0].Error)
|
||||
assert.NotNil(t, results[0].Params)
|
||||
|
||||
// Row 1: rate-limited.
|
||||
require.Error(t, results[1].Error)
|
||||
var rlErr1 *gitprovider.RateLimitError
|
||||
assert.True(t, errors.As(results[1].Error, &rlErr1),
|
||||
"result[1] error should be *RateLimitError")
|
||||
|
||||
// Row 2: skipped due to rate limit.
|
||||
require.Error(t, results[2].Error)
|
||||
var rlErr2 *gitprovider.RateLimitError
|
||||
assert.True(t, errors.As(results[2].Error, &rlErr2),
|
||||
"result[2] error should wrap *RateLimitError")
|
||||
assert.Contains(t, results[2].Error.Error(), "skipped")
|
||||
|
||||
// Provider should have been called exactly twice.
|
||||
assert.Equal(t, int32(2), callCount.Load(),
|
||||
"FetchPullRequestStatus should be called exactly 2 times")
|
||||
}
|
||||
|
||||
func TestRefresher_CorrectTokenPerOrigin(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var tokenCalls atomic.Int32
|
||||
tokens := func(_ context.Context, _ uuid.UUID, origin string) (*string, error) {
|
||||
tokenCalls.Add(1)
|
||||
switch {
|
||||
case strings.Contains(origin, "github.com"):
|
||||
return ptr.Ref("gh-public-token"), nil
|
||||
case strings.Contains(origin, "ghes.corp.com"):
|
||||
return ptr.Ref("ghe-private-token"), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected origin: %s", origin)
|
||||
}
|
||||
}
|
||||
|
||||
// Track which token each FetchPullRequestStatus call received,
|
||||
// keyed by chat ID. We pass the chat ID through the PRRef.Number
|
||||
// field (unique per request) so FetchPullRequestStatus can
|
||||
// identify which row it's processing.
|
||||
var mu sync.Mutex
|
||||
tokensByPR := make(map[int]string)
|
||||
|
||||
mp := &mockProvider{
|
||||
parsePullRequestURL: func(raw string) (gitprovider.PRRef, bool) {
|
||||
// Extract a unique PR number from the URL to identify
|
||||
// each row inside FetchPullRequestStatus.
|
||||
var num int
|
||||
switch {
|
||||
case strings.HasSuffix(raw, "/pull/1"):
|
||||
num = 1
|
||||
case strings.HasSuffix(raw, "/pull/2"):
|
||||
num = 2
|
||||
case strings.HasSuffix(raw, "/pull/10"):
|
||||
num = 10
|
||||
default:
|
||||
return gitprovider.PRRef{}, false
|
||||
}
|
||||
return gitprovider.PRRef{Owner: "org", Repo: "repo", Number: num}, true
|
||||
},
|
||||
fetchPullRequestStatus: func(_ context.Context, token string, ref gitprovider.PRRef) (*gitprovider.PRStatus, error) {
|
||||
mu.Lock()
|
||||
tokensByPR[ref.Number] = token
|
||||
mu.Unlock()
|
||||
return &gitprovider.PRStatus{State: gitprovider.PRStateOpen}, nil
|
||||
},
|
||||
}
|
||||
|
||||
providers := func(_ string) gitprovider.Provider { return mp }
|
||||
|
||||
r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), quartz.NewReal())
|
||||
|
||||
ownerID := uuid.New()
|
||||
|
||||
requests := []gitsync.RefreshRequest{
|
||||
{
|
||||
Row: database.ChatDiffStatus{
|
||||
ChatID: uuid.New(),
|
||||
Url: sql.NullString{String: "https://github.com/org/repo/pull/1", Valid: true},
|
||||
GitRemoteOrigin: "https://github.com/org/repo",
|
||||
GitBranch: "feature-1",
|
||||
},
|
||||
OwnerID: ownerID,
|
||||
},
|
||||
{
|
||||
Row: database.ChatDiffStatus{
|
||||
ChatID: uuid.New(),
|
||||
Url: sql.NullString{String: "https://github.com/org/repo/pull/2", Valid: true},
|
||||
GitRemoteOrigin: "https://github.com/org/repo",
|
||||
GitBranch: "feature-2",
|
||||
},
|
||||
OwnerID: ownerID,
|
||||
},
|
||||
{
|
||||
Row: database.ChatDiffStatus{
|
||||
ChatID: uuid.New(),
|
||||
Url: sql.NullString{String: "https://ghes.corp.com/org/repo/pull/10", Valid: true},
|
||||
GitRemoteOrigin: "https://ghes.corp.com/org/repo",
|
||||
GitBranch: "feature-3",
|
||||
},
|
||||
OwnerID: ownerID,
|
||||
},
|
||||
}
|
||||
|
||||
results, err := r.Refresh(context.Background(), requests)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, results, 3)
|
||||
|
||||
for i, res := range results {
|
||||
require.NoError(t, res.Error, "result[%d] should not have an error", i)
|
||||
require.NotNil(t, res.Params, "result[%d] should have params", i)
|
||||
}
|
||||
|
||||
// github.com rows (PR #1 and #2) should use the public token.
|
||||
assert.Equal(t, "gh-public-token", tokensByPR[1],
|
||||
"github.com PR #1 should use gh-public-token")
|
||||
assert.Equal(t, "gh-public-token", tokensByPR[2],
|
||||
"github.com PR #2 should use gh-public-token")
|
||||
|
||||
// ghes.corp.com row (PR #10) should use the GHE token.
|
||||
assert.Equal(t, "ghe-private-token", tokensByPR[10],
|
||||
"ghes.corp.com PR #10 should use ghe-private-token")
|
||||
|
||||
// Token resolution should be called exactly twice — once per
|
||||
// (owner, origin) group.
|
||||
assert.Equal(t, int32(2), tokenCalls.Load(),
|
||||
"TokenResolver should be called once per (owner, origin) group")
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
// Package gitsyncmock contains generated mocks for the gitsync package.
|
||||
package gitsyncmock
|
||||
|
||||
//go:generate mockgen -destination ./store.go -package gitsyncmock github.com/coder/coder/v2/coderd/gitsync Store
|
||||
//go:generate mockgen -destination ./publisher.go -package gitsyncmock github.com/coder/coder/v2/coderd/gitsync EventPublisher
|
||||
@@ -0,0 +1,56 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/coder/coder/v2/coderd/gitsync (interfaces: EventPublisher)
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -destination ./publisher.go -package gitsyncmock github.com/coder/coder/v2/coderd/gitsync EventPublisher
|
||||
//
|
||||
|
||||
// Package gitsyncmock is a generated GoMock package.
|
||||
package gitsyncmock
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
uuid "github.com/google/uuid"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockEventPublisher is a mock of EventPublisher interface.
|
||||
type MockEventPublisher struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockEventPublisherMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockEventPublisherMockRecorder is the mock recorder for MockEventPublisher.
|
||||
type MockEventPublisherMockRecorder struct {
|
||||
mock *MockEventPublisher
|
||||
}
|
||||
|
||||
// NewMockEventPublisher creates a new mock instance.
|
||||
func NewMockEventPublisher(ctrl *gomock.Controller) *MockEventPublisher {
|
||||
mock := &MockEventPublisher{ctrl: ctrl}
|
||||
mock.recorder = &MockEventPublisherMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockEventPublisher) EXPECT() *MockEventPublisherMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// PublishDiffStatusChange mocks base method.
|
||||
func (m *MockEventPublisher) PublishDiffStatusChange(ctx context.Context, chatID uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "PublishDiffStatusChange", ctx, chatID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// PublishDiffStatusChange indicates an expected call of PublishDiffStatusChange.
|
||||
func (mr *MockEventPublisherMockRecorder) PublishDiffStatusChange(ctx, chatID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PublishDiffStatusChange", reflect.TypeOf((*MockEventPublisher)(nil).PublishDiffStatusChange), ctx, chatID)
|
||||
}
|
||||
@@ -0,0 +1,116 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/coder/coder/v2/coderd/gitsync (interfaces: Store)
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -destination ./store.go -package gitsyncmock github.com/coder/coder/v2/coderd/gitsync Store
|
||||
//
|
||||
|
||||
// Package gitsyncmock is a generated GoMock package.
|
||||
package gitsyncmock
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
database "github.com/coder/coder/v2/coderd/database"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockStore is a mock of Store interface.
|
||||
type MockStore struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockStoreMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockStoreMockRecorder is the mock recorder for MockStore.
|
||||
type MockStoreMockRecorder struct {
|
||||
mock *MockStore
|
||||
}
|
||||
|
||||
// NewMockStore creates a new mock instance.
|
||||
func NewMockStore(ctrl *gomock.Controller) *MockStore {
|
||||
mock := &MockStore{ctrl: ctrl}
|
||||
mock.recorder = &MockStoreMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockStore) EXPECT() *MockStoreMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AcquireStaleChatDiffStatuses mocks base method.
|
||||
func (m *MockStore) AcquireStaleChatDiffStatuses(ctx context.Context, limitVal int32) ([]database.AcquireStaleChatDiffStatusesRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AcquireStaleChatDiffStatuses", ctx, limitVal)
|
||||
ret0, _ := ret[0].([]database.AcquireStaleChatDiffStatusesRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// AcquireStaleChatDiffStatuses indicates an expected call of AcquireStaleChatDiffStatuses.
|
||||
func (mr *MockStoreMockRecorder) AcquireStaleChatDiffStatuses(ctx, limitVal any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcquireStaleChatDiffStatuses", reflect.TypeOf((*MockStore)(nil).AcquireStaleChatDiffStatuses), ctx, limitVal)
|
||||
}
|
||||
|
||||
// BackoffChatDiffStatus mocks base method.
|
||||
func (m *MockStore) BackoffChatDiffStatus(ctx context.Context, arg database.BackoffChatDiffStatusParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "BackoffChatDiffStatus", ctx, arg)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// BackoffChatDiffStatus indicates an expected call of BackoffChatDiffStatus.
|
||||
func (mr *MockStoreMockRecorder) BackoffChatDiffStatus(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BackoffChatDiffStatus", reflect.TypeOf((*MockStore)(nil).BackoffChatDiffStatus), ctx, arg)
|
||||
}
|
||||
|
||||
// GetChatsByOwnerID mocks base method.
|
||||
func (m *MockStore) GetChatsByOwnerID(ctx context.Context, arg database.GetChatsByOwnerIDParams) ([]database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatsByOwnerID", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatsByOwnerID indicates an expected call of GetChatsByOwnerID.
|
||||
func (mr *MockStoreMockRecorder) GetChatsByOwnerID(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatsByOwnerID", reflect.TypeOf((*MockStore)(nil).GetChatsByOwnerID), ctx, arg)
|
||||
}
|
||||
|
||||
// UpsertChatDiffStatus mocks base method.
|
||||
func (m *MockStore) UpsertChatDiffStatus(ctx context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpsertChatDiffStatus", ctx, arg)
|
||||
ret0, _ := ret[0].(database.ChatDiffStatus)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpsertChatDiffStatus indicates an expected call of UpsertChatDiffStatus.
|
||||
func (mr *MockStoreMockRecorder) UpsertChatDiffStatus(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatDiffStatus", reflect.TypeOf((*MockStore)(nil).UpsertChatDiffStatus), ctx, arg)
|
||||
}
|
||||
|
||||
// UpsertChatDiffStatusReference mocks base method.
|
||||
func (m *MockStore) UpsertChatDiffStatusReference(ctx context.Context, arg database.UpsertChatDiffStatusReferenceParams) (database.ChatDiffStatus, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpsertChatDiffStatusReference", ctx, arg)
|
||||
ret0, _ := ret[0].(database.ChatDiffStatus)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpsertChatDiffStatusReference indicates an expected call of UpsertChatDiffStatusReference.
|
||||
func (mr *MockStoreMockRecorder) UpsertChatDiffStatusReference(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatDiffStatusReference", reflect.TypeOf((*MockStore)(nil).UpsertChatDiffStatusReference), ctx, arg)
|
||||
}
|
||||
@@ -0,0 +1,248 @@
|
||||
package gitsync
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
const (
|
||||
// defaultBatchSize is the maximum number of stale rows fetched
|
||||
// per tick.
|
||||
defaultBatchSize int32 = 50
|
||||
|
||||
// defaultInterval is the polling interval between ticks.
|
||||
defaultInterval = 10 * time.Second
|
||||
)
|
||||
|
||||
// Store is the narrow DB interface the Worker needs.
|
||||
type Store interface {
|
||||
AcquireStaleChatDiffStatuses(
|
||||
ctx context.Context, limitVal int32,
|
||||
) ([]database.AcquireStaleChatDiffStatusesRow, error)
|
||||
BackoffChatDiffStatus(
|
||||
ctx context.Context, arg database.BackoffChatDiffStatusParams,
|
||||
) error
|
||||
UpsertChatDiffStatus(
|
||||
ctx context.Context, arg database.UpsertChatDiffStatusParams,
|
||||
) (database.ChatDiffStatus, error)
|
||||
UpsertChatDiffStatusReference(
|
||||
ctx context.Context, arg database.UpsertChatDiffStatusReferenceParams,
|
||||
) (database.ChatDiffStatus, error)
|
||||
GetChatsByOwnerID(
|
||||
ctx context.Context, arg database.GetChatsByOwnerIDParams,
|
||||
) ([]database.Chat, error)
|
||||
}
|
||||
|
||||
// EventPublisher notifies the frontend of diff status changes.
|
||||
type EventPublisher interface {
|
||||
PublishDiffStatusChange(ctx context.Context, chatID uuid.UUID) error
|
||||
}
|
||||
|
||||
// Worker is a background loop that periodically refreshes stale
|
||||
// chat diff statuses by delegating to a Refresher.
|
||||
type Worker struct {
|
||||
store Store
|
||||
refresher *Refresher
|
||||
publisher EventPublisher
|
||||
clock quartz.Clock
|
||||
logger slog.Logger
|
||||
batchSize int32
|
||||
interval time.Duration
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
// NewWorker creates a Worker with default batch size and interval.
|
||||
func NewWorker(
|
||||
store Store,
|
||||
refresher *Refresher,
|
||||
publisher EventPublisher,
|
||||
clock quartz.Clock,
|
||||
logger slog.Logger,
|
||||
) *Worker {
|
||||
return &Worker{
|
||||
store: store,
|
||||
refresher: refresher,
|
||||
publisher: publisher,
|
||||
clock: clock,
|
||||
logger: logger,
|
||||
batchSize: defaultBatchSize,
|
||||
interval: defaultInterval,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Start launches the background loop. It blocks until ctx is
|
||||
// cancelled, then closes w.done.
|
||||
func (w *Worker) Start(ctx context.Context) {
|
||||
defer close(w.done)
|
||||
|
||||
ticker := w.clock.NewTicker(w.interval, "gitsync", "worker")
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
w.tick(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Done returns a channel that is closed when the worker exits.
|
||||
func (w *Worker) Done() <-chan struct{} {
|
||||
return w.done
|
||||
}
|
||||
|
||||
func chatDiffStatusFromRow(row database.AcquireStaleChatDiffStatusesRow) database.ChatDiffStatus {
|
||||
return database.ChatDiffStatus{
|
||||
ChatID: row.ChatID,
|
||||
Url: row.Url,
|
||||
PullRequestState: row.PullRequestState,
|
||||
ChangesRequested: row.ChangesRequested,
|
||||
Additions: row.Additions,
|
||||
Deletions: row.Deletions,
|
||||
ChangedFiles: row.ChangedFiles,
|
||||
RefreshedAt: row.RefreshedAt,
|
||||
StaleAt: row.StaleAt,
|
||||
CreatedAt: row.CreatedAt,
|
||||
UpdatedAt: row.UpdatedAt,
|
||||
GitBranch: row.GitBranch,
|
||||
GitRemoteOrigin: row.GitRemoteOrigin,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Worker) tick(ctx context.Context) {
|
||||
acquiredRows, err := w.store.AcquireStaleChatDiffStatuses(ctx, w.batchSize)
|
||||
if err != nil {
|
||||
w.logger.Warn(ctx, "acquire stale chat diff statuses",
|
||||
slog.Error(err))
|
||||
return
|
||||
}
|
||||
if len(acquiredRows) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Build refresh requests directly from acquired rows.
|
||||
requests := make([]RefreshRequest, 0, len(acquiredRows))
|
||||
for _, row := range acquiredRows {
|
||||
requests = append(requests, RefreshRequest{
|
||||
Row: chatDiffStatusFromRow(row),
|
||||
OwnerID: row.OwnerID,
|
||||
})
|
||||
}
|
||||
|
||||
results, err := w.refresher.Refresh(ctx, requests)
|
||||
if err != nil {
|
||||
w.logger.Warn(ctx, "batch refresh chat diff statuses",
|
||||
slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
for _, res := range results {
|
||||
if res.Error != nil {
|
||||
w.logger.Debug(ctx, "refresh chat diff status",
|
||||
slog.F("chat_id", res.Request.Row.ChatID),
|
||||
slog.Error(res.Error))
|
||||
// Back off so the row isn't retried immediately.
|
||||
if err := w.store.BackoffChatDiffStatus(ctx,
|
||||
database.BackoffChatDiffStatusParams{
|
||||
ChatID: res.Request.Row.ChatID,
|
||||
StaleAt: w.clock.Now().UTC().Add(DiffStatusTTL),
|
||||
},
|
||||
); err != nil {
|
||||
w.logger.Warn(ctx, "backoff failed chat diff status",
|
||||
slog.F("chat_id", res.Request.Row.ChatID),
|
||||
slog.Error(err))
|
||||
}
|
||||
continue
|
||||
}
|
||||
if res.Params == nil {
|
||||
// No PR yet — skip.
|
||||
continue
|
||||
}
|
||||
if _, err := w.store.UpsertChatDiffStatus(ctx, *res.Params); err != nil {
|
||||
w.logger.Warn(ctx, "upsert refreshed chat diff status",
|
||||
slog.F("chat_id", res.Request.Row.ChatID),
|
||||
slog.Error(err))
|
||||
continue
|
||||
}
|
||||
if err := w.publisher.PublishDiffStatusChange(ctx, res.Request.Row.ChatID); err != nil {
|
||||
w.logger.Debug(ctx, "publish diff status change",
|
||||
slog.F("chat_id", res.Request.Row.ChatID),
|
||||
slog.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MarkStale persists the git ref on all chats for a workspace,
|
||||
// setting stale_at to the past so the next tick picks them up.
|
||||
// Publishes a diff status event for each affected chat.
|
||||
// Called from workspaceagents handlers. No goroutines spawned.
|
||||
func (w *Worker) MarkStale(
|
||||
ctx context.Context,
|
||||
workspaceID, ownerID uuid.UUID,
|
||||
branch, origin string,
|
||||
) {
|
||||
if branch == "" || origin == "" {
|
||||
return
|
||||
}
|
||||
|
||||
chats, err := w.store.GetChatsByOwnerID(ctx, database.GetChatsByOwnerIDParams{
|
||||
OwnerID: ownerID,
|
||||
})
|
||||
if err != nil {
|
||||
w.logger.Warn(ctx, "list chats for git ref storage",
|
||||
slog.F("workspace_id", workspaceID),
|
||||
slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
for _, chat := range filterChatsByWorkspaceID(chats, workspaceID) {
|
||||
_, err := w.store.UpsertChatDiffStatusReference(ctx,
|
||||
database.UpsertChatDiffStatusReferenceParams{
|
||||
ChatID: chat.ID,
|
||||
GitBranch: branch,
|
||||
GitRemoteOrigin: origin,
|
||||
StaleAt: w.clock.Now().Add(-time.Second),
|
||||
Url: sql.NullString{},
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
w.logger.Warn(ctx, "store git ref on chat diff status",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("workspace_id", workspaceID),
|
||||
slog.Error(err))
|
||||
continue
|
||||
}
|
||||
// Notify the frontend immediately so the UI shows the
|
||||
// branch info even before the worker refreshes PR data.
|
||||
if pubErr := w.publisher.PublishDiffStatusChange(ctx, chat.ID); pubErr != nil {
|
||||
w.logger.Debug(ctx, "publish diff status after mark stale",
|
||||
slog.F("chat_id", chat.ID), slog.Error(pubErr))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// filterChatsByWorkspaceID returns only chats associated with
|
||||
// the given workspace.
|
||||
func filterChatsByWorkspaceID(
|
||||
chats []database.Chat,
|
||||
workspaceID uuid.UUID,
|
||||
) []database.Chat {
|
||||
filtered := make([]database.Chat, 0, len(chats))
|
||||
for _, chat := range chats {
|
||||
if !chat.WorkspaceID.Valid || chat.WorkspaceID.UUID != workspaceID {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, chat)
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
@@ -0,0 +1,757 @@
|
||||
package gitsync_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/externalauth/gitprovider"
|
||||
"github.com/coder/coder/v2/coderd/gitsync"
|
||||
"github.com/coder/coder/v2/coderd/gitsync/gitsyncmock"
|
||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
|
||||
|
||||
// testRefresherCfg configures newTestRefresher.
|
||||
type testRefresherCfg struct {
|
||||
resolveBranchPR func(context.Context, string, gitprovider.BranchRef) (*gitprovider.PRRef, error)
|
||||
fetchPRStatus func(context.Context, string, gitprovider.PRRef) (*gitprovider.PRStatus, error)
|
||||
}
|
||||
|
||||
type testRefresherOpt func(*testRefresherCfg)
|
||||
|
||||
func withResolveBranchPR(f func(context.Context, string, gitprovider.BranchRef) (*gitprovider.PRRef, error)) testRefresherOpt {
|
||||
return func(c *testRefresherCfg) { c.resolveBranchPR = f }
|
||||
}
|
||||
|
||||
// newTestRefresher creates a Refresher backed by mock
|
||||
// provider/token resolvers. The provider recognises any origin,
|
||||
// resolves branches to a canned PR, and returns a canned PRStatus.
|
||||
func newTestRefresher(t *testing.T, clk quartz.Clock, opts ...testRefresherOpt) *gitsync.Refresher {
|
||||
t.Helper()
|
||||
|
||||
cfg := testRefresherCfg{
|
||||
resolveBranchPR: func(context.Context, string, gitprovider.BranchRef) (*gitprovider.PRRef, error) {
|
||||
return &gitprovider.PRRef{Owner: "o", Repo: "r", Number: 1}, nil
|
||||
},
|
||||
fetchPRStatus: func(context.Context, string, gitprovider.PRRef) (*gitprovider.PRStatus, error) {
|
||||
return &gitprovider.PRStatus{
|
||||
State: gitprovider.PRStateOpen,
|
||||
DiffStats: gitprovider.DiffStats{
|
||||
Additions: 10,
|
||||
Deletions: 3,
|
||||
ChangedFiles: 2,
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
for _, o := range opts {
|
||||
o(&cfg)
|
||||
}
|
||||
|
||||
prov := &mockProvider{
|
||||
parseRepositoryOrigin: func(string) (string, string, string, bool) {
|
||||
return "owner", "repo", "https://github.com/owner/repo", true
|
||||
},
|
||||
parsePullRequestURL: func(raw string) (gitprovider.PRRef, bool) {
|
||||
return gitprovider.PRRef{Owner: "owner", Repo: "repo", Number: 1}, raw != ""
|
||||
},
|
||||
resolveBranchPR: cfg.resolveBranchPR,
|
||||
fetchPullRequestStatus: cfg.fetchPRStatus,
|
||||
buildPullRequestURL: func(ref gitprovider.PRRef) string {
|
||||
return fmt.Sprintf("https://github.com/%s/%s/pull/%d", ref.Owner, ref.Repo, ref.Number)
|
||||
},
|
||||
}
|
||||
|
||||
providers := func(string) gitprovider.Provider { return prov }
|
||||
tokens := func(context.Context, uuid.UUID, string) (*string, error) {
|
||||
return ptr.Ref("tok"), nil
|
||||
}
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
return gitsync.NewRefresher(providers, tokens, logger, clk)
|
||||
}
|
||||
|
||||
// makeAcquiredRow returns an AcquireStaleChatDiffStatusesRow with
|
||||
// a non-empty branch/origin so the Refresher goes through the
|
||||
// branch-resolution path.
|
||||
func makeAcquiredRow(chatID, ownerID uuid.UUID) database.AcquireStaleChatDiffStatusesRow {
|
||||
return database.AcquireStaleChatDiffStatusesRow{
|
||||
ChatID: chatID,
|
||||
GitBranch: "feature",
|
||||
GitRemoteOrigin: "https://github.com/owner/repo",
|
||||
StaleAt: time.Now().Add(-time.Minute),
|
||||
OwnerID: ownerID,
|
||||
}
|
||||
}
|
||||
|
||||
// tickOnce traps the worker's NewTicker call, starts the worker,
|
||||
// fires one tick, waits for it to finish by observing the given
|
||||
// tickDone channel, then shuts the worker down. The tickDone
|
||||
// channel must be closed when the last expected operation in the
|
||||
// tick completes. For tests where the tick does nothing (e.g. 0
|
||||
// stale rows or store error), tickDone should be closed inside
|
||||
// acquireStaleChatDiffStatuses.
|
||||
func tickOnce(
|
||||
ctx context.Context,
|
||||
t *testing.T,
|
||||
mClock *quartz.Mock,
|
||||
worker *gitsync.Worker,
|
||||
tickDone <-chan struct{},
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
trap := mClock.Trap().NewTicker("gitsync", "worker")
|
||||
defer trap.Close()
|
||||
|
||||
workerCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
go worker.Start(workerCtx)
|
||||
|
||||
// Wait for the worker to create its ticker.
|
||||
trap.MustWait(ctx).MustRelease(ctx)
|
||||
|
||||
// Fire one tick. The waiter resolves when the channel receive
|
||||
// completes, not when w.tick() returns, so we use tickDone to
|
||||
// know when to proceed.
|
||||
_, w := mClock.AdvanceNext()
|
||||
w.MustWait(ctx)
|
||||
|
||||
// Wait for the tick's business logic to finish.
|
||||
select {
|
||||
case <-tickDone:
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out waiting for tick to complete")
|
||||
}
|
||||
|
||||
cancel()
|
||||
<-worker.Done()
|
||||
}
|
||||
|
||||
func TestWorker_PicksUpStaleRows(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
chat1 := uuid.New()
|
||||
chat2 := uuid.New()
|
||||
ownerID := uuid.New()
|
||||
|
||||
var upsertCount atomic.Int32
|
||||
var publishCount atomic.Int32
|
||||
tickDone := make(chan struct{})
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
store := gitsyncmock.NewMockStore(ctrl)
|
||||
pub := gitsyncmock.NewMockEventPublisher(ctrl)
|
||||
|
||||
store.EXPECT().AcquireStaleChatDiffStatuses(gomock.Any(), gomock.Any()).
|
||||
Return([]database.AcquireStaleChatDiffStatusesRow{
|
||||
makeAcquiredRow(chat1, ownerID),
|
||||
makeAcquiredRow(chat2, ownerID),
|
||||
}, nil)
|
||||
store.EXPECT().UpsertChatDiffStatus(gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
|
||||
upsertCount.Add(1)
|
||||
return database.ChatDiffStatus{ChatID: arg.ChatID}, nil
|
||||
}).Times(2)
|
||||
pub.EXPECT().PublishDiffStatusChange(gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(_ context.Context, _ uuid.UUID) error {
|
||||
if publishCount.Add(1) == 2 {
|
||||
close(tickDone)
|
||||
}
|
||||
return nil
|
||||
}).Times(2)
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
refresher := newTestRefresher(t, mClock)
|
||||
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)
|
||||
|
||||
tickOnce(ctx, t, mClock, worker, tickDone)
|
||||
|
||||
assert.Equal(t, int32(2), upsertCount.Load())
|
||||
assert.Equal(t, int32(2), publishCount.Load())
|
||||
}
|
||||
|
||||
func TestWorker_SkipsFreshRows(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
tickDone := make(chan struct{})
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
store := gitsyncmock.NewMockStore(ctrl)
|
||||
pub := gitsyncmock.NewMockEventPublisher(ctrl)
|
||||
|
||||
store.EXPECT().AcquireStaleChatDiffStatuses(gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(context.Context, int32) ([]database.AcquireStaleChatDiffStatusesRow, error) {
|
||||
// No stale rows — tick returns immediately.
|
||||
close(tickDone)
|
||||
return nil, nil
|
||||
})
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
refresher := newTestRefresher(t, mClock)
|
||||
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)
|
||||
|
||||
tickOnce(ctx, t, mClock, worker, tickDone)
|
||||
}
|
||||
|
||||
func TestWorker_LimitsToNRows(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
var capturedLimit atomic.Int32
|
||||
var upsertCount atomic.Int32
|
||||
ownerID := uuid.New()
|
||||
const numRows = 5
|
||||
tickDone := make(chan struct{})
|
||||
|
||||
rows := make([]database.AcquireStaleChatDiffStatusesRow, numRows)
|
||||
for i := range rows {
|
||||
rows[i] = makeAcquiredRow(uuid.New(), ownerID)
|
||||
}
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
store := gitsyncmock.NewMockStore(ctrl)
|
||||
pub := gitsyncmock.NewMockEventPublisher(ctrl)
|
||||
|
||||
store.EXPECT().AcquireStaleChatDiffStatuses(gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(_ context.Context, limitVal int32) ([]database.AcquireStaleChatDiffStatusesRow, error) {
|
||||
capturedLimit.Store(limitVal)
|
||||
return rows, nil
|
||||
})
|
||||
store.EXPECT().UpsertChatDiffStatus(gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
|
||||
upsertCount.Add(1)
|
||||
return database.ChatDiffStatus{ChatID: arg.ChatID}, nil
|
||||
}).Times(numRows)
|
||||
pub.EXPECT().PublishDiffStatusChange(gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(context.Context, uuid.UUID) error {
|
||||
if upsertCount.Load() == numRows {
|
||||
close(tickDone)
|
||||
}
|
||||
return nil
|
||||
}).Times(numRows)
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
refresher := newTestRefresher(t, mClock)
|
||||
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)
|
||||
|
||||
tickOnce(ctx, t, mClock, worker, tickDone)
|
||||
|
||||
// The default batch size is 50.
|
||||
assert.Equal(t, int32(50), capturedLimit.Load())
|
||||
assert.Equal(t, int32(numRows), upsertCount.Load())
|
||||
}
|
||||
|
||||
func TestWorker_RefresherReturnsNilNil_SkipsUpsert(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
chatID := uuid.New()
|
||||
ownerID := uuid.New()
|
||||
|
||||
// When the Refresher returns (nil, nil) the worker skips the
|
||||
// upsert and publish. We signal tickDone from the refresher
|
||||
// mock since that is the last operation before the tick
|
||||
// returns.
|
||||
tickDone := make(chan struct{})
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
store := gitsyncmock.NewMockStore(ctrl)
|
||||
pub := gitsyncmock.NewMockEventPublisher(ctrl)
|
||||
|
||||
store.EXPECT().AcquireStaleChatDiffStatuses(gomock.Any(), gomock.Any()).
|
||||
Return([]database.AcquireStaleChatDiffStatusesRow{makeAcquiredRow(chatID, ownerID)}, nil)
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
|
||||
// ResolveBranchPullRequest returns nil → Refresher returns
|
||||
// (nil, nil).
|
||||
refresher := newTestRefresher(t, mClock, withResolveBranchPR(
|
||||
func(context.Context, string, gitprovider.BranchRef) (*gitprovider.PRRef, error) {
|
||||
close(tickDone)
|
||||
return nil, nil
|
||||
},
|
||||
))
|
||||
|
||||
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)
|
||||
|
||||
tickOnce(ctx, t, mClock, worker, tickDone)
|
||||
}
|
||||
|
||||
func TestWorker_RefresherError_BacksOffRow(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
chat1 := uuid.New()
|
||||
chat2 := uuid.New()
|
||||
ownerID := uuid.New()
|
||||
|
||||
var upsertCount atomic.Int32
|
||||
var publishCount atomic.Int32
|
||||
var backoffCount atomic.Int32
|
||||
var mu sync.Mutex
|
||||
var backoffArgs []database.BackoffChatDiffStatusParams
|
||||
tickDone := make(chan struct{})
|
||||
var closeOnce sync.Once
|
||||
|
||||
// Two rows processed: one fails (backoff), one succeeds
|
||||
// (upsert+publish). Both must finish before we close tickDone.
|
||||
var terminalOps atomic.Int32
|
||||
signalIfDone := func() {
|
||||
if terminalOps.Add(1) == 2 {
|
||||
closeOnce.Do(func() { close(tickDone) })
|
||||
}
|
||||
}
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
store := gitsyncmock.NewMockStore(ctrl)
|
||||
pub := gitsyncmock.NewMockEventPublisher(ctrl)
|
||||
|
||||
store.EXPECT().AcquireStaleChatDiffStatuses(gomock.Any(), gomock.Any()).
|
||||
Return([]database.AcquireStaleChatDiffStatusesRow{
|
||||
makeAcquiredRow(chat1, ownerID),
|
||||
makeAcquiredRow(chat2, ownerID),
|
||||
}, nil)
|
||||
store.EXPECT().BackoffChatDiffStatus(gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(_ context.Context, arg database.BackoffChatDiffStatusParams) error {
|
||||
backoffCount.Add(1)
|
||||
mu.Lock()
|
||||
backoffArgs = append(backoffArgs, arg)
|
||||
mu.Unlock()
|
||||
signalIfDone()
|
||||
return nil
|
||||
})
|
||||
store.EXPECT().UpsertChatDiffStatus(gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
|
||||
upsertCount.Add(1)
|
||||
return database.ChatDiffStatus{ChatID: arg.ChatID}, nil
|
||||
})
|
||||
pub.EXPECT().PublishDiffStatusChange(gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(context.Context, uuid.UUID) error {
|
||||
// Only the successful row publishes.
|
||||
publishCount.Add(1)
|
||||
signalIfDone()
|
||||
return nil
|
||||
})
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
|
||||
// Fail ResolveBranchPullRequest for the first call, succeed
|
||||
// for the second.
|
||||
var callCount atomic.Int32
|
||||
refresher := newTestRefresher(t, mClock, withResolveBranchPR(
|
||||
func(context.Context, string, gitprovider.BranchRef) (*gitprovider.PRRef, error) {
|
||||
n := callCount.Add(1)
|
||||
if n == 1 {
|
||||
return nil, fmt.Errorf("simulated provider error")
|
||||
}
|
||||
return &gitprovider.PRRef{Owner: "o", Repo: "r", Number: 1}, nil
|
||||
},
|
||||
))
|
||||
|
||||
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)
|
||||
|
||||
tickOnce(ctx, t, mClock, worker, tickDone)
|
||||
|
||||
// BackoffChatDiffStatus was called for the failed row.
|
||||
assert.Equal(t, int32(1), backoffCount.Load())
|
||||
mu.Lock()
|
||||
require.Len(t, backoffArgs, 1)
|
||||
assert.Equal(t, chat1, backoffArgs[0].ChatID)
|
||||
// stale_at should be approximately clock.Now() + DiffStatusTTL (120s).
|
||||
expectedStaleAt := mClock.Now().UTC().Add(gitsync.DiffStatusTTL)
|
||||
assert.WithinDuration(t, expectedStaleAt, backoffArgs[0].StaleAt, time.Second)
|
||||
mu.Unlock()
|
||||
|
||||
// UpsertChatDiffStatus was called for the successful row.
|
||||
assert.Equal(t, int32(1), upsertCount.Load())
|
||||
// PublishDiffStatusChange was called only for the successful row.
|
||||
assert.Equal(t, int32(1), publishCount.Load())
|
||||
}
|
||||
|
||||
func TestWorker_UpsertError_ContinuesNextRow(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
chat1 := uuid.New()
|
||||
chat2 := uuid.New()
|
||||
ownerID := uuid.New()
|
||||
|
||||
var publishCount atomic.Int32
|
||||
tickDone := make(chan struct{})
|
||||
var closeOnce sync.Once
|
||||
var mu sync.Mutex
|
||||
upsertedChatIDs := make(map[uuid.UUID]struct{})
|
||||
|
||||
// We have 2 rows. The upsert for chat1 fails; the upsert
|
||||
// for chat2 succeeds and publishes. Because goroutines run
|
||||
// concurrently we don't know which finishes last, so we
|
||||
// track the total number of "terminal" events (upsert error
|
||||
// + publish success) and close tickDone when both have
|
||||
// occurred.
|
||||
var terminalOps atomic.Int32
|
||||
signalIfDone := func() {
|
||||
if terminalOps.Add(1) == 2 {
|
||||
closeOnce.Do(func() { close(tickDone) })
|
||||
}
|
||||
}
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
store := gitsyncmock.NewMockStore(ctrl)
|
||||
pub := gitsyncmock.NewMockEventPublisher(ctrl)
|
||||
|
||||
store.EXPECT().AcquireStaleChatDiffStatuses(gomock.Any(), gomock.Any()).
|
||||
Return([]database.AcquireStaleChatDiffStatusesRow{
|
||||
makeAcquiredRow(chat1, ownerID),
|
||||
makeAcquiredRow(chat2, ownerID),
|
||||
}, nil)
|
||||
store.EXPECT().UpsertChatDiffStatus(gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
|
||||
if arg.ChatID == chat1 {
|
||||
// Terminal event for the failing row.
|
||||
signalIfDone()
|
||||
return database.ChatDiffStatus{}, fmt.Errorf("db write error")
|
||||
}
|
||||
mu.Lock()
|
||||
upsertedChatIDs[arg.ChatID] = struct{}{}
|
||||
mu.Unlock()
|
||||
return database.ChatDiffStatus{ChatID: arg.ChatID}, nil
|
||||
}).Times(2)
|
||||
pub.EXPECT().PublishDiffStatusChange(gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(context.Context, uuid.UUID) error {
|
||||
publishCount.Add(1)
|
||||
// Terminal event for the successful row.
|
||||
signalIfDone()
|
||||
return nil
|
||||
})
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
refresher := newTestRefresher(t, mClock)
|
||||
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)
|
||||
|
||||
tickOnce(ctx, t, mClock, worker, tickDone)
|
||||
|
||||
mu.Lock()
|
||||
_, gotChat2 := upsertedChatIDs[chat2]
|
||||
mu.Unlock()
|
||||
assert.True(t, gotChat2, "chat2 should have been upserted")
|
||||
assert.Equal(t, int32(1), publishCount.Load())
|
||||
}
|
||||
|
||||
func TestWorker_RespectsShutdown(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
store := gitsyncmock.NewMockStore(ctrl)
|
||||
pub := gitsyncmock.NewMockEventPublisher(ctrl)
|
||||
|
||||
store.EXPECT().AcquireStaleChatDiffStatuses(gomock.Any(), gomock.Any()).
|
||||
Return(nil, nil).AnyTimes()
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
refresher := newTestRefresher(t, mClock)
|
||||
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)
|
||||
|
||||
trap := mClock.Trap().NewTicker("gitsync", "worker")
|
||||
defer trap.Close()
|
||||
|
||||
workerCtx, cancel := context.WithCancel(ctx)
|
||||
go worker.Start(workerCtx)
|
||||
|
||||
// Wait for ticker creation so the worker is running.
|
||||
trap.MustWait(ctx).MustRelease(ctx)
|
||||
|
||||
// Cancel immediately.
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case <-worker.Done():
|
||||
// Success — worker shut down.
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out waiting for worker to shut down")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWorker_BatchRefreshAllRows(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
ownerID := uuid.New()
|
||||
const numRows = 5
|
||||
rows := make([]database.AcquireStaleChatDiffStatusesRow, numRows)
|
||||
for i := range rows {
|
||||
rows[i] = makeAcquiredRow(uuid.New(), ownerID)
|
||||
}
|
||||
|
||||
var upsertCount atomic.Int32
|
||||
var publishCount atomic.Int32
|
||||
tickDone := make(chan struct{})
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
store := gitsyncmock.NewMockStore(ctrl)
|
||||
pub := gitsyncmock.NewMockEventPublisher(ctrl)
|
||||
|
||||
store.EXPECT().AcquireStaleChatDiffStatuses(gomock.Any(), gomock.Any()).
|
||||
Return(rows, nil)
|
||||
store.EXPECT().UpsertChatDiffStatus(gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
|
||||
upsertCount.Add(1)
|
||||
return database.ChatDiffStatus{ChatID: arg.ChatID}, nil
|
||||
}).Times(numRows)
|
||||
pub.EXPECT().PublishDiffStatusChange(gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(context.Context, uuid.UUID) error {
|
||||
if publishCount.Add(1) == numRows {
|
||||
close(tickDone)
|
||||
}
|
||||
return nil
|
||||
}).Times(numRows)
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
refresher := newTestRefresher(t, mClock)
|
||||
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)
|
||||
|
||||
tickOnce(ctx, t, mClock, worker, tickDone)
|
||||
|
||||
assert.Equal(t, int32(numRows), upsertCount.Load())
|
||||
assert.Equal(t, int32(numRows), publishCount.Load())
|
||||
}
|
||||
|
||||
// --- MarkStale tests ---
|
||||
|
||||
func TestWorker_MarkStale_UpsertAndPublish(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
workspaceID := uuid.New()
|
||||
ownerID := uuid.New()
|
||||
chat1 := uuid.New()
|
||||
chat2 := uuid.New()
|
||||
chatOther := uuid.New()
|
||||
|
||||
var mu sync.Mutex
|
||||
var upsertRefCalls []database.UpsertChatDiffStatusReferenceParams
|
||||
var publishedIDs []uuid.UUID
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
store := gitsyncmock.NewMockStore(ctrl)
|
||||
pub := gitsyncmock.NewMockEventPublisher(ctrl)
|
||||
|
||||
store.EXPECT().GetChatsByOwnerID(gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(_ context.Context, arg database.GetChatsByOwnerIDParams) ([]database.Chat, error) {
|
||||
require.Equal(t, ownerID, arg.OwnerID)
|
||||
return []database.Chat{
|
||||
{ID: chat1, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}},
|
||||
{ID: chat2, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}},
|
||||
{ID: chatOther, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true}},
|
||||
}, nil
|
||||
})
|
||||
store.EXPECT().UpsertChatDiffStatusReference(gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusReferenceParams) (database.ChatDiffStatus, error) {
|
||||
mu.Lock()
|
||||
upsertRefCalls = append(upsertRefCalls, arg)
|
||||
mu.Unlock()
|
||||
return database.ChatDiffStatus{ChatID: arg.ChatID}, nil
|
||||
}).Times(2)
|
||||
pub.EXPECT().PublishDiffStatusChange(gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(_ context.Context, chatID uuid.UUID) error {
|
||||
mu.Lock()
|
||||
publishedIDs = append(publishedIDs, chatID)
|
||||
mu.Unlock()
|
||||
return nil
|
||||
}).Times(2)
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
now := mClock.Now()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
refresher := newTestRefresher(t, mClock)
|
||||
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)
|
||||
|
||||
worker.MarkStale(ctx, workspaceID, ownerID, "feature", "https://github.com/owner/repo")
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
require.Len(t, upsertRefCalls, 2)
|
||||
for _, call := range upsertRefCalls {
|
||||
assert.Equal(t, "feature", call.GitBranch)
|
||||
assert.Equal(t, "https://github.com/owner/repo", call.GitRemoteOrigin)
|
||||
assert.True(t, call.StaleAt.Before(now),
|
||||
"stale_at should be in the past, got %v vs now %v", call.StaleAt, now)
|
||||
assert.Equal(t, sql.NullString{}, call.Url)
|
||||
}
|
||||
|
||||
require.Len(t, publishedIDs, 2)
|
||||
assert.ElementsMatch(t, []uuid.UUID{chat1, chat2}, publishedIDs)
|
||||
}
|
||||
|
||||
func TestWorker_MarkStale_NoMatchingChats(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
workspaceID := uuid.New()
|
||||
ownerID := uuid.New()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
store := gitsyncmock.NewMockStore(ctrl)
|
||||
pub := gitsyncmock.NewMockEventPublisher(ctrl)
|
||||
|
||||
store.EXPECT().GetChatsByOwnerID(gomock.Any(), gomock.Any()).
|
||||
Return([]database.Chat{
|
||||
{ID: uuid.New(), OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true}},
|
||||
{ID: uuid.New(), OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true}},
|
||||
}, nil)
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
refresher := newTestRefresher(t, mClock)
|
||||
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)
|
||||
|
||||
worker.MarkStale(ctx, workspaceID, ownerID, "main", "https://github.com/x/y")
|
||||
}
|
||||
|
||||
func TestWorker_MarkStale_UpsertFails_ContinuesNext(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
workspaceID := uuid.New()
|
||||
ownerID := uuid.New()
|
||||
chat1 := uuid.New()
|
||||
chat2 := uuid.New()
|
||||
|
||||
var publishCount atomic.Int32
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
store := gitsyncmock.NewMockStore(ctrl)
|
||||
pub := gitsyncmock.NewMockEventPublisher(ctrl)
|
||||
|
||||
store.EXPECT().GetChatsByOwnerID(gomock.Any(), gomock.Any()).
|
||||
Return([]database.Chat{
|
||||
{ID: chat1, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}},
|
||||
{ID: chat2, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}},
|
||||
}, nil)
|
||||
store.EXPECT().UpsertChatDiffStatusReference(gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusReferenceParams) (database.ChatDiffStatus, error) {
|
||||
if arg.ChatID == chat1 {
|
||||
return database.ChatDiffStatus{}, fmt.Errorf("upsert ref error")
|
||||
}
|
||||
return database.ChatDiffStatus{ChatID: arg.ChatID}, nil
|
||||
}).Times(2)
|
||||
pub.EXPECT().PublishDiffStatusChange(gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(context.Context, uuid.UUID) error {
|
||||
publishCount.Add(1)
|
||||
return nil
|
||||
})
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
refresher := newTestRefresher(t, mClock)
|
||||
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)
|
||||
|
||||
worker.MarkStale(ctx, workspaceID, ownerID, "dev", "https://github.com/a/b")
|
||||
|
||||
assert.Equal(t, int32(1), publishCount.Load())
|
||||
}
|
||||
|
||||
func TestWorker_MarkStale_GetChatsFails(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
store := gitsyncmock.NewMockStore(ctrl)
|
||||
pub := gitsyncmock.NewMockEventPublisher(ctrl)
|
||||
|
||||
store.EXPECT().GetChatsByOwnerID(gomock.Any(), gomock.Any()).
|
||||
Return(nil, fmt.Errorf("db error"))
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
refresher := newTestRefresher(t, mClock)
|
||||
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)
|
||||
|
||||
worker.MarkStale(ctx, uuid.New(), uuid.New(), "main", "https://github.com/x/y")
|
||||
}
|
||||
|
||||
func TestWorker_TickStoreError(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
tickDone := make(chan struct{})
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
store := gitsyncmock.NewMockStore(ctrl)
|
||||
pub := gitsyncmock.NewMockEventPublisher(ctrl)
|
||||
|
||||
store.EXPECT().AcquireStaleChatDiffStatuses(gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(context.Context, int32) ([]database.AcquireStaleChatDiffStatusesRow, error) {
|
||||
close(tickDone)
|
||||
return nil, fmt.Errorf("database unavailable")
|
||||
})
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
refresher := newTestRefresher(t, mClock)
|
||||
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)
|
||||
|
||||
tickOnce(ctx, t, mClock, worker, tickDone)
|
||||
}
|
||||
|
||||
func TestWorker_MarkStale_EmptyBranchOrOrigin(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
branch string
|
||||
origin string
|
||||
}{
|
||||
{"both empty", "", ""},
|
||||
{"branch empty", "", "https://github.com/x/y"},
|
||||
{"origin empty", "main", ""},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
store := gitsyncmock.NewMockStore(ctrl)
|
||||
pub := gitsyncmock.NewMockEventPublisher(ctrl)
|
||||
|
||||
|
||||
mClock := quartz.NewMock(t)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
refresher := newTestRefresher(t, mClock)
|
||||
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)
|
||||
|
||||
worker.MarkStale(ctx, uuid.New(), uuid.New(), tc.branch, tc.origin)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1835,18 +1835,6 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
|
||||
Branch: strings.TrimSpace(query.Get("git_branch")),
|
||||
RemoteOrigin: strings.TrimSpace(query.Get("git_remote_origin")),
|
||||
}
|
||||
var chatID uuid.NullUUID
|
||||
if rawChatID := query.Get("chat_id"); rawChatID != "" {
|
||||
parsed, err := uuid.Parse(rawChatID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid chat_id.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
chatID = uuid.NullUUID{UUID: parsed, Valid: true}
|
||||
}
|
||||
// Either match or configID must be provided!
|
||||
match := query.Get("match")
|
||||
if match == "" {
|
||||
@@ -1942,9 +1930,9 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
|
||||
|
||||
// Persist git refs as soon as the agent requests external auth so branch
|
||||
// context is retained even if the flow requires an out-of-band login.
|
||||
if gitRef.Branch != "" || gitRef.RemoteOrigin != "" {
|
||||
//nolint:gocritic // System context required to persist chat git refs.
|
||||
api.storeChatGitRef(dbauthz.AsSystemRestricted(ctx), workspace.ID, workspace.OwnerID, chatID, gitRef)
|
||||
if gitRef.Branch != "" && gitRef.RemoteOrigin != "" {
|
||||
//nolint:gocritic // Chat processor context required for cross-user chat lookup
|
||||
api.gitSyncWorker.MarkStale(dbauthz.AsChatd(ctx), workspace.ID, workspace.OwnerID, gitRef.Branch, gitRef.RemoteOrigin)
|
||||
}
|
||||
|
||||
var previousToken *database.ExternalAuthLink
|
||||
@@ -1960,7 +1948,7 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
|
||||
return
|
||||
}
|
||||
|
||||
api.workspaceAgentsExternalAuthListen(ctx, rw, previousToken, externalAuthConfig, workspace, chatID, gitRef)
|
||||
api.workspaceAgentsExternalAuthListen(ctx, rw, previousToken, externalAuthConfig, workspace, gitRef)
|
||||
}
|
||||
|
||||
// This is the URL that will redirect the user with a state token.
|
||||
@@ -2018,11 +2006,10 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
|
||||
})
|
||||
return
|
||||
}
|
||||
api.triggerWorkspaceChatDiffStatusRefresh(workspace, chatID, gitRef)
|
||||
httpapi.Write(ctx, rw, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func (api *API) workspaceAgentsExternalAuthListen(ctx context.Context, rw http.ResponseWriter, previous *database.ExternalAuthLink, externalAuthConfig *externalauth.Config, workspace database.Workspace, chatID uuid.NullUUID, gitRef chatGitRef) {
|
||||
func (api *API) workspaceAgentsExternalAuthListen(ctx context.Context, rw http.ResponseWriter, previous *database.ExternalAuthLink, externalAuthConfig *externalauth.Config, workspace database.Workspace, gitRef chatGitRef) {
|
||||
// Since we're ticking frequently and this sign-in operation is rare,
|
||||
// we are OK with polling to avoid the complexity of pubsub.
|
||||
ticker, done := api.NewTicker(time.Second)
|
||||
@@ -2092,7 +2079,8 @@ func (api *API) workspaceAgentsExternalAuthListen(ctx context.Context, rw http.R
|
||||
})
|
||||
return
|
||||
}
|
||||
api.triggerWorkspaceChatDiffStatusRefresh(workspace, chatID, gitRef)
|
||||
//nolint:gocritic // Chat processor context required for cross-user chat lookup
|
||||
api.gitSyncWorker.MarkStale(dbauthz.AsChatd(ctx), workspace.ID, workspace.OwnerID, gitRef.Branch, gitRef.RemoteOrigin)
|
||||
httpapi.Write(ctx, rw, http.StatusOK, resp)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -980,6 +980,10 @@ type ExternalAuthConfig struct {
|
||||
// 'Username for "https://github.com":'
|
||||
// And sending it to the Coder server to match against the Regex.
|
||||
Regex string `json:"regex" yaml:"regex"`
|
||||
// APIBaseURL is the base URL for provider REST API calls
|
||||
// (e.g., "https://api.github.com" for GitHub). Derived from
|
||||
// defaults when not explicitly configured.
|
||||
APIBaseURL string `json:"api_base_url" yaml:"api_base_url"`
|
||||
// DisplayName is shown in the UI to identify the auth config.
|
||||
DisplayName string `json:"display_name" yaml:"display_name"`
|
||||
// DisplayIcon is a URL to an icon to display in the UI.
|
||||
|
||||
Vendored
+1
@@ -22,6 +22,7 @@ externalAuthProviders:
|
||||
mcp_tool_allow_regex: .*
|
||||
mcp_tool_deny_regex: create_gist
|
||||
regex: ^https://example.com/.*$
|
||||
api_base_url: ""
|
||||
display_name: GitHub
|
||||
display_icon: /static/icons/github.svg
|
||||
code_challenge_methods_supported:
|
||||
|
||||
Generated
+1
@@ -279,6 +279,7 @@ curl -X GET http://coder-server:8080/api/v2/deployment/config \
|
||||
"external_auth": {
|
||||
"value": [
|
||||
{
|
||||
"api_base_url": "string",
|
||||
"app_install_url": "string",
|
||||
"app_installations_url": "string",
|
||||
"auth_url": "string",
|
||||
|
||||
Generated
+21
-16
@@ -2786,6 +2786,7 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o
|
||||
"external_auth": {
|
||||
"value": [
|
||||
{
|
||||
"api_base_url": "string",
|
||||
"app_install_url": "string",
|
||||
"app_installations_url": "string",
|
||||
"auth_url": "string",
|
||||
@@ -3357,6 +3358,7 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o
|
||||
"external_auth": {
|
||||
"value": [
|
||||
{
|
||||
"api_base_url": "string",
|
||||
"app_install_url": "string",
|
||||
"app_installations_url": "string",
|
||||
"auth_url": "string",
|
||||
@@ -4104,6 +4106,7 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o
|
||||
|
||||
```json
|
||||
{
|
||||
"api_base_url": "string",
|
||||
"app_install_url": "string",
|
||||
"app_installations_url": "string",
|
||||
"auth_url": "string",
|
||||
@@ -4133,22 +4136,23 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o
|
||||
|
||||
### Properties
|
||||
|
||||
| Name | Type | Required | Restrictions | Description |
|
||||
|------------------------------------|-----------------|----------|--------------|-------------------------------------------------------------------------------------------------------------------|
|
||||
| `app_install_url` | string | false | | |
|
||||
| `app_installations_url` | string | false | | |
|
||||
| `auth_url` | string | false | | |
|
||||
| `client_id` | string | false | | |
|
||||
| `code_challenge_methods_supported` | array of string | false | | Code challenge methods supported lists the PKCE code challenge methods The only one supported by Coder is "S256". |
|
||||
| `device_code_url` | string | false | | |
|
||||
| `device_flow` | boolean | false | | |
|
||||
| `display_icon` | string | false | | Display icon is a URL to an icon to display in the UI. |
|
||||
| `display_name` | string | false | | Display name is shown in the UI to identify the auth config. |
|
||||
| `id` | string | false | | ID is a unique identifier for the auth config. It defaults to `type` when not provided. |
|
||||
| `mcp_tool_allow_regex` | string | false | | |
|
||||
| `mcp_tool_deny_regex` | string | false | | |
|
||||
| `mcp_url` | string | false | | |
|
||||
| `no_refresh` | boolean | false | | |
|
||||
| Name | Type | Required | Restrictions | Description |
|
||||
|------------------------------------|-----------------|----------|--------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| `api_base_url` | string | false | | Api base URL is the base URL for provider REST API calls (e.g., "https://api.github.com" for GitHub). Derived from defaults when not explicitly configured. |
|
||||
| `app_install_url` | string | false | | |
|
||||
| `app_installations_url` | string | false | | |
|
||||
| `auth_url` | string | false | | |
|
||||
| `client_id` | string | false | | |
|
||||
| `code_challenge_methods_supported` | array of string | false | | Code challenge methods supported lists the PKCE code challenge methods The only one supported by Coder is "S256". |
|
||||
| `device_code_url` | string | false | | |
|
||||
| `device_flow` | boolean | false | | |
|
||||
| `display_icon` | string | false | | Display icon is a URL to an icon to display in the UI. |
|
||||
| `display_name` | string | false | | Display name is shown in the UI to identify the auth config. |
|
||||
| `id` | string | false | | ID is a unique identifier for the auth config. It defaults to `type` when not provided. |
|
||||
| `mcp_tool_allow_regex` | string | false | | |
|
||||
| `mcp_tool_deny_regex` | string | false | | |
|
||||
| `mcp_url` | string | false | | |
|
||||
| `no_refresh` | boolean | false | | |
|
||||
|`regex`|string|false||Regex allows API requesters to match an auth config by a string (e.g. coder.com) instead of by it's type.
|
||||
Git clone makes use of this by parsing the URL from: 'Username for "https://github.com":' And sending it to the Coder server to match against the Regex.|
|
||||
|`revoke_url`|string|false|||
|
||||
@@ -14196,6 +14200,7 @@ None
|
||||
{
|
||||
"value": [
|
||||
{
|
||||
"api_base_url": "string",
|
||||
"app_install_url": "string",
|
||||
"app_installations_url": "string",
|
||||
"auth_url": "string",
|
||||
|
||||
Generated
+6
@@ -2682,6 +2682,12 @@ export interface ExternalAuthConfig {
|
||||
* And sending it to the Coder server to match against the Regex.
|
||||
*/
|
||||
readonly regex: string;
|
||||
/**
|
||||
* APIBaseURL is the base URL for provider REST API calls
|
||||
* (e.g., "https://api.github.com" for GitHub). Derived from
|
||||
* defaults when not explicitly configured.
|
||||
*/
|
||||
readonly api_base_url: string;
|
||||
/**
|
||||
* DisplayName is shown in the UI to identify the auth config.
|
||||
*/
|
||||
|
||||
+1
@@ -12,6 +12,7 @@ const meta: Meta<typeof ExternalAuthSettingsPageView> = {
|
||||
type: "GitHub",
|
||||
client_id: "client_id",
|
||||
regex: "regex",
|
||||
api_base_url: "",
|
||||
auth_url: "",
|
||||
token_url: "",
|
||||
validate_url: "",
|
||||
|
||||
Reference in New Issue
Block a user