Compare commits

...

16 Commits

Author SHA1 Message Date
Cian Johnston d13a639172 fix: address review comments for externalauth, coderd shutdown, and gitsync test
- Remove pointer mutation of config.Type in applyDefaultsToConfig; use
  local configType variable with lowercasing instead
- Fix dynamic defaults switch to use configType instead of re-reading
  config.Type
- Revert quartz Clock usage in Close() shutdown timers back to standard
  time.NewTimer and time.After
- Remove unnecessary comment in worker_test.go
2026-03-09 22:00:26 +00:00
Cian Johnston c4e474edb5 refactor(gitsync): replace hand-rolled mocks with gomock 2026-03-09 21:50:18 +00:00
Cian Johnston 3301166e4b fix(coderd): replace bare time.After/NewTimer with quartz Clock in Close() 2026-03-09 21:38:18 +00:00
Cian Johnston cbd307d5c5 fix(externalauth): normalize config.Type to lowercase in applyDefaultsToConfig
Previously, if a user specified Type: "GitHub" (mixed case), the
defaults would not apply because EnhancedExternalAuthProvider constants
are all lowercase. This adds strings.ToLower normalization at the top
of applyDefaultsToConfig, matching what Config.Git() already does.

Adds TestApplyDefaultsToConfig_CaseInsensitive to verify defaults are
applied for mixed-case type values.
2026-03-09 21:38:08 +00:00
Cian Johnston a3def79430 fix casing 2026-03-09 16:29:38 +00:00
Cian Johnston 1920308d15 address copilot nits 2026-03-09 16:22:47 +00:00
Cian Johnston 3c3d0f32eb fixup! feat(coderd): wire gitsync worker and refactor chats.go 2026-03-09 16:03:58 +00:00
Cian Johnston 557354a718 fixup! feat(coderd): wire gitsync worker and refactor chats.go 2026-03-09 15:48:40 +00:00
Cian Johnston e303047b52 fixup! feat(gitsync): background worker for chat diff status refresh 2026-03-09 15:48:21 +00:00
Cian Johnston 9062ac191b fixup! feat(coderd): wire gitsync worker and refactor chats.go 2026-03-09 14:57:19 +00:00
Cian Johnston 3050d45eb1 fixup! feat(coderd): wire gitsync worker and refactor chats.go 2026-03-09 14:35:03 +00:00
Cian Johnston 4aa1aac8b8 fixup! feat(coderd): wire gitsync worker and refactor chats.go 2026-03-09 14:29:06 +00:00
Cian Johnston 5f3dd28fe1 feat(coderd): wire gitsync worker and refactor chats.go
Wire the gitsync Refresher and Worker into coderd.go, started with
AsSystemRestricted context. Call sites in workspaceagents.go invoke
MarkStale on external auth events.

Refactor chats.go: remove inlined goroutine-based refresh logic,
delegate to the gitsync worker for background polling. Remove unused
publishChatStatusEvent and publishChatDiffStatusEvent functions.
resolveChatGitAccessToken now filters ExternalAuthConfigs by origin
regex before falling back to iterating all GitHub configs.
2026-03-09 13:44:34 +00:00
Cian Johnston 3740a132a2 feat(gitsync): background worker for chat diff status refresh
New coderd/gitsync package with two components:

Refresher: batches refresh requests by (ownerID, origin), resolves
the git provider and access token once per group, and calls
FetchPullRequestStatus for each row. Rate-limit errors short-circuit
remaining rows in the group. Uses injected quartz.Clock for
deterministic test timing.

Worker: polls for stale chat_diff_statuses using AcquireStaleChatDiffStatuses
(SELECT ... FOR UPDATE SKIP LOCKED) to prevent duplicate work across
replicas. Failed refreshes are backed off via BackoffChatDiffStatus
which only updates stale_at without clobbering PR data. MarkStale
triggers immediate staleness with pubsub notification.

Database changes (in existing 000422_chats.up.sql migration):
- AcquireStaleChatDiffStatuses: CTE with FOR UPDATE SKIP LOCKED,
  joins chats for owner_id, filters archived, orders by stale_at ASC
- BackoffChatDiffStatus: updates only stale_at and updated_at
- Partial index on stale_at WHERE origin and branch are non-empty
- dbauthz wrappers with ResourceChat type-level authorization
2026-03-09 13:44:34 +00:00
Cian Johnston af14ec844f feat(gitprovider): GitHub PR status, diff fetching, and URL abstraction
New coderd/externalauth/gitprovider package providing a Provider
interface for git hosting platforms. The GitHub implementation
supports:

- Fetching PR status (mergeable, CI checks, reviews, labels)
- Resolving branch-to-PR mapping
- Fetching PR and branch diffs
- Building repository, branch, and PR URLs
- Parsing repository URLs (github.com and GHE)

Rate-limit detection returns a typed RateLimitError with RetryAfter,
parsed from Retry-After and X-RateLimit-Reset headers. Only 403/429
responses with rate-limit headers are treated as rate limits; other
403s (e.g. bad credentials) remain generic errors.
2026-03-09 13:44:33 +00:00
Cian Johnston d5f6756cdf feat(codersdk): add APIBaseURL field to ExternalAuthConfig
Add API_BASE_URL to the external auth provider environment variable
parser in cli/server.go, allowing GitHub Enterprise deployments to
specify a custom API endpoint. The field is included in the SDK type,
generated docs, and Storybook fixtures.
2026-03-09 13:44:33 +00:00
34 changed files with 4485 additions and 707 deletions
+2
View File
@@ -2909,6 +2909,8 @@ func parseExternalAuthProvidersFromEnv(prefix string, environ []string) ([]coder
provider.MCPToolDenyRegex = v.Value
case "PKCE_METHODS":
provider.CodeChallengeMethodsSupported = strings.Split(v.Value, " ")
case "API_BASE_URL":
provider.APIBaseURL = v.Value
}
providers[providerNum] = provider
}
+23
View File
@@ -108,6 +108,29 @@ func TestReadExternalAuthProvidersFromEnv(t *testing.T) {
})
}
func TestReadExternalAuthProvidersFromEnv_APIBaseURL(t *testing.T) {
t.Parallel()
providers, err := cli.ReadExternalAuthProvidersFromEnv([]string{
"CODER_EXTERNAL_AUTH_0_TYPE=github",
"CODER_EXTERNAL_AUTH_0_CLIENT_ID=xxx",
"CODER_EXTERNAL_AUTH_0_API_BASE_URL=https://ghes.corp.com/api/v3",
})
require.NoError(t, err)
require.Len(t, providers, 1)
assert.Equal(t, "https://ghes.corp.com/api/v3", providers[0].APIBaseURL)
}
func TestReadExternalAuthProvidersFromEnv_APIBaseURLDefault(t *testing.T) {
t.Parallel()
providers, err := cli.ReadExternalAuthProvidersFromEnv([]string{
"CODER_EXTERNAL_AUTH_0_TYPE=github",
"CODER_EXTERNAL_AUTH_0_CLIENT_ID=xxx",
})
require.NoError(t, err)
require.Len(t, providers, 1)
assert.Equal(t, "", providers[0].APIBaseURL)
}
// TestReadGitAuthProvidersFromEnv ensures that the deprecated `CODER_GITAUTH_`
// environment variables are still supported.
func TestReadGitAuthProvidersFromEnv(t *testing.T) {
+4
View File
@@ -15448,6 +15448,10 @@ const docTemplate = `{
"codersdk.ExternalAuthConfig": {
"type": "object",
"properties": {
"api_base_url": {
"description": "APIBaseURL is the base URL for provider REST API calls\n(e.g., \"https://api.github.com\" for GitHub). Derived from\ndefaults when not explicitly configured.",
"type": "string"
},
"app_install_url": {
"type": "string"
},
+4
View File
@@ -13957,6 +13957,10 @@
"codersdk.ExternalAuthConfig": {
"type": "object",
"properties": {
"api_base_url": {
"description": "APIBaseURL is the base URL for provider REST API calls\n(e.g., \"https://api.github.com\" for GitHub). Derived from\ndefaults when not explicitly configured.",
"type": "string"
},
"app_install_url": {
"type": "string"
},
+177 -668
View File
@@ -13,7 +13,6 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"regexp"
"strconv"
"strings"
"sync"
@@ -32,6 +31,8 @@ import (
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/externalauth"
"github.com/coder/coder/v2/coderd/externalauth/gitprovider"
"github.com/coder/coder/v2/coderd/gitsync"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpapi/httperror"
"github.com/coder/coder/v2/coderd/httpmw"
@@ -39,16 +40,15 @@ import (
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/rbac/policy"
"github.com/coder/coder/v2/coderd/tracing"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/wsjson"
"github.com/coder/websocket"
)
const (
chatDiffStatusTTL = 120 * time.Second
chatDiffBackgroundRefreshTimeout = 20 * time.Second
githubAPIBaseURL = "https://api.github.com"
chatStreamBatchSize = 256
chatDiffStatusTTL = gitsync.DiffStatusTTL
chatStreamBatchSize = 256
chatContextLimitModelConfigKey = "context_limit"
chatContextCompressionThresholdModelConfigKey = "context_compression_threshold"
@@ -57,19 +57,6 @@ const (
maxChatContextCompressionThreshold = int32(100)
)
// chatDiffRefreshBackoffSchedule defines the delays between successive
// background diff refresh attempts. The trigger fires when the agent
// obtains a GitHub token, which is typically right before a git push
// or PR creation. The backoff gives progressively more time for the
// push and any PR workflow to complete before querying the GitHub API.
var chatDiffRefreshBackoffSchedule = []time.Duration{
1 * time.Second,
3 * time.Second,
5 * time.Second,
10 * time.Second,
20 * time.Second,
}
// chatGitRef holds the branch and remote origin reported by the
// workspace agent during a git operation.
type chatGitRef struct {
@@ -77,32 +64,6 @@ type chatGitRef struct {
RemoteOrigin string
}
var (
githubPullRequestPathPattern = regexp.MustCompile(
`^https://github\.com/([A-Za-z0-9_.-]+)/([A-Za-z0-9_.-]+)/pull/([0-9]+)(?:[/?#].*)?$`,
)
githubRepositoryHTTPSPattern = regexp.MustCompile(
`^https://github\.com/([A-Za-z0-9_.-]+)/([A-Za-z0-9_.-]+?)(?:\.git)?/?$`,
)
githubRepositorySSHPathPattern = regexp.MustCompile(
`^(?:ssh://)?git@github\.com[:/]([A-Za-z0-9_.-]+)/([A-Za-z0-9_.-]+?)(?:\.git)?/?$`,
)
)
type githubPullRequestRef struct {
Owner string
Repo string
Number int
}
type githubPullRequestStatus struct {
PullRequestState string
ChangesRequested bool
Additions int32
Deletions int32
ChangedFiles int32
}
type chatRepositoryRef struct {
Provider string
RemoteOrigin string
@@ -1256,193 +1217,6 @@ func shouldRefreshChatDiffStatus(status database.ChatDiffStatus, now time.Time,
return chatDiffStatusIsStale(status, now)
}
func (api *API) triggerWorkspaceChatDiffStatusRefresh(workspace database.Workspace, chatID uuid.NullUUID, gitRef chatGitRef) {
if workspace.ID == uuid.Nil || workspace.OwnerID == uuid.Nil {
return
}
go func(workspaceID, workspaceOwnerID uuid.UUID, chatID uuid.NullUUID, gitRef chatGitRef) {
ctx := api.ctx
if ctx == nil {
ctx = context.Background()
}
//nolint:gocritic // Background goroutine for diff status refresh has no user context.
ctx = dbauthz.AsSystemRestricted(ctx)
// Always store the git ref so the data is persisted even
// before a PR exists. The frontend can show branch info
// and the refresh loop can resolve a PR later.
api.storeChatGitRef(ctx, workspaceID, workspaceOwnerID, chatID, gitRef)
for _, delay := range chatDiffRefreshBackoffSchedule {
t := api.Clock.NewTimer(delay, "chat_diff_refresh")
select {
case <-ctx.Done():
t.Stop()
return
case <-t.C:
}
// Refresh and publish status on every iteration.
// Stop the loop once a PR is discovered — there's
// nothing more to wait for after that.
if api.refreshWorkspaceChatDiffStatuses(ctx, workspaceID, workspaceOwnerID, chatID) {
return
}
}
}(workspace.ID, workspace.OwnerID, chatID, gitRef)
}
// storeChatGitRef persists the git branch and remote origin reported
// by the workspace agent on the chat that initiated the git operation.
// When chatID is set, only that specific chat is updated; otherwise all
// chats associated with the workspace are updated (legacy fallback).
func (api *API) storeChatGitRef(ctx context.Context, workspaceID, workspaceOwnerID uuid.UUID, chatID uuid.NullUUID, gitRef chatGitRef) {
var chatsToUpdate []database.Chat
if chatID.Valid {
chat, err := api.Database.GetChatByID(ctx, chatID.UUID)
if err != nil {
api.Logger.Warn(ctx, "failed to get chat for git ref storage",
slog.F("chat_id", chatID.UUID),
slog.F("workspace_id", workspaceID),
slog.Error(err),
)
return
}
chatsToUpdate = []database.Chat{chat}
} else {
chats, err := api.Database.GetChatsByOwnerID(ctx, database.GetChatsByOwnerIDParams{
OwnerID: workspaceOwnerID,
})
if err != nil {
api.Logger.Warn(ctx, "failed to list chats for git ref storage",
slog.F("workspace_id", workspaceID),
slog.Error(err),
)
return
}
chatsToUpdate = filterChatsByWorkspaceID(chats, workspaceID)
}
for _, chat := range chatsToUpdate {
_, err := api.Database.UpsertChatDiffStatusReference(ctx, database.UpsertChatDiffStatusReferenceParams{
ChatID: chat.ID,
GitBranch: gitRef.Branch,
GitRemoteOrigin: gitRef.RemoteOrigin,
StaleAt: time.Now().UTC().Add(-time.Second),
Url: sql.NullString{},
})
if err != nil {
api.Logger.Warn(ctx, "failed to store git ref on chat diff status",
slog.F("chat_id", chat.ID),
slog.F("workspace_id", workspaceID),
slog.Error(err),
)
continue
}
api.publishChatDiffStatusEvent(ctx, chat.ID)
}
}
// refreshWorkspaceChatDiffStatuses refreshes the diff status for chats
// associated with the given workspace. When chatID is set, only that
// specific chat is refreshed; otherwise all chats for the workspace
// are refreshed (legacy fallback). It returns true when every
// refreshed chat has a PR URL resolved, signaling that the caller
// can stop polling.
func (api *API) refreshWorkspaceChatDiffStatuses(ctx context.Context, workspaceID, workspaceOwnerID uuid.UUID, chatID uuid.NullUUID) bool {
var filtered []database.Chat
if chatID.Valid {
chat, err := api.Database.GetChatByID(ctx, chatID.UUID)
if err != nil {
api.Logger.Warn(ctx, "failed to get chat for diff refresh",
slog.F("chat_id", chatID.UUID),
slog.F("workspace_id", workspaceID),
slog.Error(err),
)
return false
}
filtered = []database.Chat{chat}
} else {
chats, err := api.Database.GetChatsByOwnerID(ctx, database.GetChatsByOwnerIDParams{
OwnerID: workspaceOwnerID,
})
if err != nil {
api.Logger.Warn(ctx, "failed to list workspace owner chats for diff refresh",
slog.F("workspace_id", workspaceID),
slog.F("workspace_owner_id", workspaceOwnerID),
slog.Error(err),
)
return false
}
filtered = filterChatsByWorkspaceID(chats, workspaceID)
}
if len(filtered) == 0 {
return false
}
allHavePR := true
for _, chat := range filtered {
refreshCtx, cancel := context.WithTimeout(ctx, chatDiffBackgroundRefreshTimeout)
status, err := api.resolveChatDiffStatusWithOptions(refreshCtx, chat, true)
cancel()
if err != nil {
api.Logger.Warn(ctx, "failed to refresh chat diff status after workspace external auth",
slog.F("workspace_id", workspaceID),
slog.F("chat_id", chat.ID),
slog.Error(err),
)
allHavePR = false
} else if status == nil || !status.Url.Valid || strings.TrimSpace(status.Url.String) == "" {
allHavePR = false
}
api.publishChatStatusEvent(ctx, chat.ID)
api.publishChatDiffStatusEvent(ctx, chat.ID)
}
return allHavePR
}
func filterChatsByWorkspaceID(chats []database.Chat, workspaceID uuid.UUID) []database.Chat {
filteredChats := make([]database.Chat, 0, len(chats))
for _, chat := range chats {
if !chat.WorkspaceID.Valid || chat.WorkspaceID.UUID != workspaceID {
continue
}
filteredChats = append(filteredChats, chat)
}
return filteredChats
}
func (api *API) publishChatStatusEvent(ctx context.Context, chatID uuid.UUID) {
if api.chatDaemon == nil {
return
}
if err := api.chatDaemon.RefreshStatus(ctx, chatID); err != nil {
api.Logger.Debug(ctx, "failed to refresh published chat status",
slog.F("chat_id", chatID),
slog.Error(err),
)
}
}
func (api *API) publishChatDiffStatusEvent(ctx context.Context, chatID uuid.UUID) {
if api.chatDaemon == nil {
return
}
if err := api.chatDaemon.PublishDiffStatusChange(ctx, chatID); err != nil {
api.Logger.Debug(ctx, "failed to publish chat diff status change",
slog.F("chat_id", chatID),
slog.Error(err),
)
}
}
func (api *API) resolveChatDiffContents(
ctx context.Context,
chat database.Chat,
@@ -1490,22 +1264,36 @@ func (api *API) resolveChatDiffContents(
if reference.RepositoryRef == nil {
return result, nil
}
if !strings.EqualFold(reference.RepositoryRef.Provider, string(codersdk.EnhancedExternalAuthProviderGitHub)) {
gp := api.resolveGitProvider(reference.RepositoryRef.RemoteOrigin)
if gp == nil {
return result, nil
}
token := api.resolveChatGitHubAccessToken(ctx, chat.OwnerID)
token, err := api.resolveChatGitAccessToken(ctx, chat.OwnerID, reference.RepositoryRef.RemoteOrigin)
if err != nil {
return result, xerrors.Errorf("resolve git access token: %w", err)
} else if token == nil {
return result, xerrors.New("nil git access token")
}
if reference.PullRequestURL != "" {
diff, err := api.fetchGitHubPullRequestDiff(ctx, reference.PullRequestURL, token)
ref, ok := gp.ParsePullRequestURL(reference.PullRequestURL)
if !ok {
return result, xerrors.Errorf("invalid pull request URL %q", reference.PullRequestURL)
}
diff, err := gp.FetchPullRequestDiff(ctx, *token, ref)
if err != nil {
return result, err
}
result.Diff = diff
return result, nil
}
diff, err := api.fetchGitHubCompareDiff(ctx, *reference.RepositoryRef, token)
diff, err := gp.FetchBranchDiff(ctx, *token, gitprovider.BranchRef{
Owner: reference.RepositoryRef.Owner,
Repo: reference.RepositoryRef.Repo,
Branch: reference.RepositoryRef.Branch,
})
if err != nil {
return result, err
}
@@ -1539,34 +1327,53 @@ func (api *API) resolveChatDiffReference(
// If we have a repo ref with a branch, try to resolve the
// current open PR. This picks up new PRs after the previous
// one was closed.
if reference.RepositoryRef != nil &&
strings.EqualFold(reference.RepositoryRef.Provider, string(codersdk.EnhancedExternalAuthProviderGitHub)) {
pullRequestURL, lookupErr := api.resolveGitHubPullRequestURLFromRepositoryRef(ctx, chat.OwnerID, *reference.RepositoryRef)
if lookupErr != nil {
api.Logger.Debug(ctx, "failed to resolve pull request from repository reference",
slog.F("chat_id", chat.ID),
slog.F("provider", reference.RepositoryRef.Provider),
slog.F("remote_origin", reference.RepositoryRef.RemoteOrigin),
slog.F("branch", reference.RepositoryRef.Branch),
slog.Error(lookupErr),
)
} else if pullRequestURL != "" {
reference.PullRequestURL = pullRequestURL
if reference.RepositoryRef != nil && reference.RepositoryRef.Owner != "" {
gp := api.resolveGitProvider(reference.RepositoryRef.RemoteOrigin)
if gp != nil {
token, err := api.resolveChatGitAccessToken(ctx, chat.OwnerID, reference.RepositoryRef.RemoteOrigin)
if token == nil || errors.Is(err, gitsync.ErrNoTokenAvailable) {
// No token available yet.
return reference, nil
} else if err != nil {
return chatDiffReference{}, xerrors.Errorf("resolve git access token: %w", err)
}
prRef, lookupErr := gp.ResolveBranchPullRequest(ctx, *token, gitprovider.BranchRef{
Owner: reference.RepositoryRef.Owner,
Repo: reference.RepositoryRef.Repo,
Branch: reference.RepositoryRef.Branch,
})
if lookupErr != nil {
api.Logger.Debug(ctx, "failed to resolve pull request from repository reference",
slog.F("chat_id", chat.ID),
slog.F("provider", reference.RepositoryRef.Provider),
slog.F("remote_origin", reference.RepositoryRef.RemoteOrigin),
slog.F("branch", reference.RepositoryRef.Branch),
slog.Error(lookupErr),
)
} else if prRef != nil {
reference.PullRequestURL = gp.BuildPullRequestURL(*prRef)
}
reference.PullRequestURL = gp.NormalizePullRequestURL(reference.PullRequestURL)
}
}
reference.PullRequestURL = normalizeGitHubPullRequestURL(reference.PullRequestURL)
// If we have a PR URL but no repo ref (e.g. the agent hasn't
// reported branch/origin yet), derive a partial ref from the
// PR URL so the caller can still show provider/owner/repo.
if reference.RepositoryRef == nil && reference.PullRequestURL != "" {
if parsed, ok := parseGitHubPullRequestURL(reference.PullRequestURL); ok {
reference.RepositoryRef = &chatRepositoryRef{
Provider: string(codersdk.EnhancedExternalAuthProviderGitHub),
RemoteOrigin: fmt.Sprintf("https://github.com/%s/%s", parsed.Owner, parsed.Repo),
Owner: parsed.Owner,
Repo: parsed.Repo,
for _, extAuth := range api.ExternalAuthConfigs {
gp := extAuth.Git(api.HTTPClient)
if gp == nil {
continue
}
if parsed, ok := gp.ParsePullRequestURL(reference.PullRequestURL); ok {
reference.RepositoryRef = &chatRepositoryRef{
Provider: strings.ToLower(extAuth.Type),
Owner: parsed.Owner,
Repo: parsed.Repo,
RemoteOrigin: gp.BuildRepositoryURL(parsed.Owner, parsed.Repo),
}
break
}
}
}
@@ -1584,19 +1391,18 @@ func (api *API) buildChatRepositoryRefFromStatus(status database.ChatDiffStatus)
return nil
}
providerType, gp := api.resolveExternalAuth(origin)
repoRef := &chatRepositoryRef{
Provider: strings.TrimSpace(api.resolveExternalAuthProviderType(origin)),
Provider: providerType,
RemoteOrigin: origin,
Branch: branch,
}
if owner, repo, normalizedOrigin, ok := parseGitHubRepositoryOrigin(repoRef.RemoteOrigin); ok {
if repoRef.Provider == "" {
repoRef.Provider = string(codersdk.EnhancedExternalAuthProviderGitHub)
if gp != nil {
if owner, repo, normalizedOrigin, ok := gp.ParseRepositoryOrigin(repoRef.RemoteOrigin); ok {
repoRef.RemoteOrigin = normalizedOrigin
repoRef.Owner = owner
repoRef.Repo = repo
}
repoRef.RemoteOrigin = normalizedOrigin
repoRef.Owner = owner
repoRef.Repo = repo
}
if repoRef.Provider == "" {
@@ -1650,60 +1456,31 @@ func (api *API) getCachedChatDiffStatus(
)
}
func (api *API) resolveExternalAuthProviderType(match string) string {
match = strings.TrimSpace(match)
if match == "" {
return ""
// resolveExternalAuth finds the external auth config matching the
// given remote origin URL and returns both the provider type string
// (e.g. "github") and the gitprovider.Provider. Returns ("", nil)
// if no matching config is found.
func (api *API) resolveExternalAuth(origin string) (providerType string, gp gitprovider.Provider) {
origin = strings.TrimSpace(origin)
if origin == "" {
return "", nil
}
for _, extAuth := range api.ExternalAuthConfigs {
if extAuth.Regex == nil || !extAuth.Regex.MatchString(match) {
if extAuth.Regex == nil || !extAuth.Regex.MatchString(origin) {
continue
}
return strings.ToLower(strings.TrimSpace(extAuth.Type))
return strings.ToLower(strings.TrimSpace(extAuth.Type)),
extAuth.Git(api.HTTPClient)
}
return ""
return "", nil
}
func parseGitHubRepositoryOrigin(raw string) (owner string, repo string, normalizedOrigin string, ok bool) {
raw = strings.TrimSpace(raw)
if raw == "" {
return "", "", "", false
}
matches := githubRepositoryHTTPSPattern.FindStringSubmatch(raw)
if len(matches) != 3 {
matches = githubRepositorySSHPathPattern.FindStringSubmatch(raw)
}
if len(matches) != 3 {
return "", "", "", false
}
owner = strings.TrimSpace(matches[1])
repo = strings.TrimSpace(matches[2])
repo = strings.TrimSuffix(repo, ".git")
if owner == "" || repo == "" {
return "", "", "", false
}
return owner, repo, fmt.Sprintf("https://github.com/%s/%s", owner, repo), true
}
func buildGitHubBranchURL(owner string, repo string, branch string) string {
owner = strings.TrimSpace(owner)
repo = strings.TrimSpace(repo)
branch = strings.TrimSpace(branch)
if owner == "" || repo == "" || branch == "" {
return ""
}
return fmt.Sprintf(
"https://github.com/%s/%s/tree/%s",
owner,
repo,
url.PathEscape(branch),
)
// resolveGitProvider finds the external auth config matching the
// given remote origin URL and returns its git provider. Returns
// nil if no matching git provider is configured.
func (api *API) resolveGitProvider(origin string) gitprovider.Provider {
_, gp := api.resolveExternalAuth(origin)
return gp
}
func chatDiffStatusIsStale(status database.ChatDiffStatus, now time.Time) bool {
@@ -1719,11 +1496,32 @@ func (api *API) refreshChatDiffStatus(
chatID uuid.UUID,
pullRequestURL string,
) (database.ChatDiffStatus, error) {
status, err := api.fetchGitHubPullRequestStatus(
ctx,
pullRequestURL,
api.resolveChatGitHubAccessToken(ctx, chatOwnerID),
)
// Find a provider that can handle this PR URL.
var gp gitprovider.Provider
var ref gitprovider.PRRef
for _, extAuth := range api.ExternalAuthConfigs {
p := extAuth.Git(api.HTTPClient)
if p == nil {
continue
}
if parsed, ok := p.ParsePullRequestURL(pullRequestURL); ok {
gp = p
ref = parsed
break
}
}
if gp == nil {
return database.ChatDiffStatus{}, xerrors.Errorf("no git provider found for PR URL %q", pullRequestURL)
}
origin := gp.BuildRepositoryURL(ref.Owner, ref.Repo)
token, err := api.resolveChatGitAccessToken(ctx, chatOwnerID, origin)
if err != nil {
return database.ChatDiffStatus{}, xerrors.Errorf("resolve git access token: %w", err)
} else if token == nil {
return database.ChatDiffStatus{}, xerrors.New("nil git access token")
}
status, err := gp.FetchPullRequestStatus(ctx, *token, ref)
if err != nil {
return database.ChatDiffStatus{}, err
}
@@ -1735,13 +1533,13 @@ func (api *API) refreshChatDiffStatus(
ChatID: chatID,
Url: sql.NullString{String: pullRequestURL, Valid: true},
PullRequestState: sql.NullString{
String: status.PullRequestState,
Valid: status.PullRequestState != "",
String: string(status.State),
Valid: status.State != "",
},
ChangesRequested: status.ChangesRequested,
Additions: status.Additions,
Deletions: status.Deletions,
ChangedFiles: status.ChangedFiles,
Additions: status.DiffStats.Additions,
Deletions: status.DiffStats.Deletions,
ChangedFiles: status.DiffStats.ChangedFiles,
RefreshedAt: refreshedAt,
StaleAt: refreshedAt.Add(chatDiffStatusTTL),
},
@@ -1752,23 +1550,49 @@ func (api *API) refreshChatDiffStatus(
return refreshedStatus, nil
}
func (api *API) resolveChatGitHubAccessToken(
func (api *API) resolveChatGitAccessToken(
ctx context.Context,
userID uuid.UUID,
) string {
// Build a map of provider ID -> config so we can refresh tokens
// using the same code path as provisionerdserver.
ghConfigs := make(map[string]*externalauth.Config)
providerIDs := []string{"github"}
for _, config := range api.ExternalAuthConfigs {
if !strings.EqualFold(
config.Type,
string(codersdk.EnhancedExternalAuthProviderGitHub),
) {
continue
origin string,
) (*string, error) {
origin = strings.TrimSpace(origin)
// If we have an origin, find the specific matching config first.
// This ensures multi-provider setups (github.com + GHE) get the
// correct token.
if origin != "" {
for _, config := range api.ExternalAuthConfigs {
if config.Regex == nil || !config.Regex.MatchString(origin) {
continue
}
link, err := api.Database.GetExternalAuthLink(ctx,
database.GetExternalAuthLinkParams{
ProviderID: config.ID,
UserID: userID,
},
)
if err != nil {
continue
}
refreshed, refreshErr := config.RefreshToken(ctx, api.Database, link)
if refreshErr == nil {
link = refreshed
}
token := strings.TrimSpace(link.OAuthAccessToken)
if token != "" {
return ptr.Ref(token), nil
}
}
}
// Fallback: iterate all external auth configs.
// Used when origin is empty (inline refresh from HTTP handler)
// or when the origin-specific lookup above failed.
configs := make(map[string]*externalauth.Config)
providerIDs := []string{}
for _, config := range api.ExternalAuthConfigs {
providerIDs = append(providerIDs, config.ID)
ghConfigs[config.ID] = config
configs[config.ID] = config
}
seen := map[string]struct{}{}
@@ -1792,7 +1616,7 @@ func (api *API) resolveChatGitHubAccessToken(
// Refresh the token if there is a matching config, mirroring
// the same code path used by provisionerdserver when handing
// tokens to provisioners.
if cfg, ok := ghConfigs[providerID]; ok {
if cfg, ok := configs[providerID]; ok {
refreshed, refreshErr := cfg.RefreshToken(ctx, api.Database, link)
if refreshErr != nil {
api.Logger.Debug(ctx, "failed to refresh external auth token for chat diff",
@@ -1809,336 +1633,11 @@ func (api *API) resolveChatGitHubAccessToken(
token := strings.TrimSpace(link.OAuthAccessToken)
if token != "" {
return token
return ptr.Ref(token), nil
}
}
return ""
}
func (api *API) resolveGitHubPullRequestURLFromRepositoryRef(
ctx context.Context,
userID uuid.UUID,
repositoryRef chatRepositoryRef,
) (string, error) {
if repositoryRef.Owner == "" || repositoryRef.Repo == "" || repositoryRef.Branch == "" {
return "", nil
}
query := url.Values{}
query.Set("state", "open")
query.Set("head", fmt.Sprintf("%s:%s", repositoryRef.Owner, repositoryRef.Branch))
query.Set("sort", "updated")
query.Set("direction", "desc")
query.Set("per_page", "1")
requestURL := fmt.Sprintf(
"%s/repos/%s/%s/pulls?%s",
githubAPIBaseURL,
repositoryRef.Owner,
repositoryRef.Repo,
query.Encode(),
)
var pulls []struct {
HTMLURL string `json:"html_url"`
}
token := api.resolveChatGitHubAccessToken(ctx, userID)
if err := api.decodeGitHubJSON(ctx, requestURL, token, &pulls); err != nil {
return "", err
}
if len(pulls) == 0 {
return "", nil
}
return normalizeGitHubPullRequestURL(pulls[0].HTMLURL), nil
}
func (api *API) fetchGitHubPullRequestDiff(
ctx context.Context,
pullRequestURL string,
token string,
) (string, error) {
ref, ok := parseGitHubPullRequestURL(pullRequestURL)
if !ok {
return "", xerrors.Errorf("invalid GitHub pull request URL %q", pullRequestURL)
}
requestURL := fmt.Sprintf(
"%s/repos/%s/%s/pulls/%d",
githubAPIBaseURL,
ref.Owner,
ref.Repo,
ref.Number,
)
return api.fetchGitHubDiff(ctx, requestURL, token)
}
func (api *API) fetchGitHubCompareDiff(
ctx context.Context,
repositoryRef chatRepositoryRef,
token string,
) (string, error) {
if repositoryRef.Owner == "" || repositoryRef.Repo == "" || repositoryRef.Branch == "" {
return "", nil
}
var repository struct {
DefaultBranch string `json:"default_branch"`
}
repositoryURL := fmt.Sprintf(
"%s/repos/%s/%s",
githubAPIBaseURL,
repositoryRef.Owner,
repositoryRef.Repo,
)
if err := api.decodeGitHubJSON(ctx, repositoryURL, token, &repository); err != nil {
return "", err
}
defaultBranch := strings.TrimSpace(repository.DefaultBranch)
if defaultBranch == "" {
return "", xerrors.New("github repository default branch is empty")
}
requestURL := fmt.Sprintf(
"%s/repos/%s/%s/compare/%s...%s",
githubAPIBaseURL,
repositoryRef.Owner,
repositoryRef.Repo,
url.PathEscape(defaultBranch),
url.PathEscape(repositoryRef.Branch),
)
return api.fetchGitHubDiff(ctx, requestURL, token)
}
func (api *API) fetchGitHubDiff(
ctx context.Context,
requestURL string,
token string,
) (string, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
if err != nil {
return "", xerrors.Errorf("create github diff request: %w", err)
}
req.Header.Set("Accept", "application/vnd.github.diff")
req.Header.Set("X-GitHub-Api-Version", "2022-11-28")
req.Header.Set("User-Agent", "coder-chat-diff")
if token != "" {
req.Header.Set("Authorization", "Bearer "+token)
}
httpClient := api.HTTPClient
if httpClient == nil {
httpClient = http.DefaultClient
}
resp, err := httpClient.Do(req)
if err != nil {
return "", xerrors.Errorf("execute github diff request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, readErr := io.ReadAll(io.LimitReader(resp.Body, 8192))
if readErr != nil {
return "", xerrors.Errorf("github diff request failed with status %d", resp.StatusCode)
}
return "", xerrors.Errorf(
"github diff request failed with status %d: %s",
resp.StatusCode,
strings.TrimSpace(string(body)),
)
}
diff, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20))
if err != nil {
return "", xerrors.Errorf("read github diff response: %w", err)
}
return string(diff), nil
}
func (api *API) fetchGitHubPullRequestStatus(
ctx context.Context,
pullRequestURL string,
token string,
) (githubPullRequestStatus, error) {
ref, ok := parseGitHubPullRequestURL(pullRequestURL)
if !ok {
return githubPullRequestStatus{}, xerrors.Errorf(
"invalid GitHub pull request URL %q",
pullRequestURL,
)
}
pullEndpoint := fmt.Sprintf(
"%s/repos/%s/%s/pulls/%d",
githubAPIBaseURL,
ref.Owner,
ref.Repo,
ref.Number,
)
var pull struct {
State string `json:"state"`
Additions int32 `json:"additions"`
Deletions int32 `json:"deletions"`
ChangedFiles int32 `json:"changed_files"`
}
if err := api.decodeGitHubJSON(ctx, pullEndpoint, token, &pull); err != nil {
return githubPullRequestStatus{}, err
}
var reviews []struct {
ID int64 `json:"id"`
State string `json:"state"`
User struct {
Login string `json:"login"`
} `json:"user"`
}
if err := api.decodeGitHubJSON(
ctx,
pullEndpoint+"/reviews?per_page=100",
token,
&reviews,
); err != nil {
return githubPullRequestStatus{}, err
}
return githubPullRequestStatus{
PullRequestState: strings.ToLower(strings.TrimSpace(pull.State)),
ChangesRequested: hasOutstandingGitHubChangesRequested(reviews),
Additions: pull.Additions,
Deletions: pull.Deletions,
ChangedFiles: pull.ChangedFiles,
}, nil
}
func (api *API) decodeGitHubJSON(
ctx context.Context,
requestURL string,
token string,
dest any,
) error {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
if err != nil {
return xerrors.Errorf("create github request: %w", err)
}
req.Header.Set("Accept", "application/vnd.github+json")
req.Header.Set("X-GitHub-Api-Version", "2022-11-28")
req.Header.Set("User-Agent", "coder-chat-diff-status")
if token != "" {
req.Header.Set("Authorization", "Bearer "+token)
}
httpClient := api.HTTPClient
if httpClient == nil {
httpClient = http.DefaultClient
}
resp, err := httpClient.Do(req)
if err != nil {
return xerrors.Errorf("execute github request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, readErr := io.ReadAll(io.LimitReader(resp.Body, 8192))
if readErr != nil {
return xerrors.Errorf(
"github request failed with status %d",
resp.StatusCode,
)
}
return xerrors.Errorf(
"github request failed with status %d: %s",
resp.StatusCode,
strings.TrimSpace(string(body)),
)
}
if err := json.NewDecoder(resp.Body).Decode(dest); err != nil {
return xerrors.Errorf("decode github response: %w", err)
}
return nil
}
func hasOutstandingGitHubChangesRequested(
reviews []struct {
ID int64 `json:"id"`
State string `json:"state"`
User struct {
Login string `json:"login"`
} `json:"user"`
},
) bool {
type reviewerState struct {
reviewID int64
state string
}
statesByReviewer := make(map[string]reviewerState)
for _, review := range reviews {
login := strings.ToLower(strings.TrimSpace(review.User.Login))
if login == "" {
continue
}
state := strings.ToUpper(strings.TrimSpace(review.State))
switch state {
case "CHANGES_REQUESTED", "APPROVED", "DISMISSED":
default:
continue
}
current, exists := statesByReviewer[login]
if exists && current.reviewID > review.ID {
continue
}
statesByReviewer[login] = reviewerState{
reviewID: review.ID,
state: state,
}
}
for _, state := range statesByReviewer {
if state.state == "CHANGES_REQUESTED" {
return true
}
}
return false
}
func normalizeGitHubPullRequestURL(raw string) string {
ref, ok := parseGitHubPullRequestURL(strings.TrimRight(
strings.TrimSpace(raw),
"),.;",
))
if !ok {
return ""
}
return fmt.Sprintf("https://github.com/%s/%s/pull/%d", ref.Owner, ref.Repo, ref.Number)
}
func parseGitHubPullRequestURL(raw string) (githubPullRequestRef, bool) {
matches := githubPullRequestPathPattern.FindStringSubmatch(strings.TrimSpace(raw))
if len(matches) != 4 {
return githubPullRequestRef{}, false
}
number, err := strconv.Atoi(matches[3])
if err != nil {
return githubPullRequestRef{}, false
}
return githubPullRequestRef{
Owner: matches[1],
Repo: matches[2],
Number: number,
}, true
return nil, gitsync.ErrNoTokenAvailable
}
type createChatWorkspaceSelection struct {
@@ -2696,11 +2195,21 @@ func convertChatDiffStatus(chatID uuid.UUID, status *database.ChatDiffStatus) co
}
}
if result.URL == nil {
owner, repo, _, ok := parseGitHubRepositoryOrigin(status.GitRemoteOrigin)
if ok {
branchURL := buildGitHubBranchURL(owner, repo, status.GitBranch)
if branchURL != "" {
result.URL = &branchURL
// Try to build a branch URL from the stored origin.
// Since convertChatDiffStatus does not have access to
// the API instance, we construct a GitHub provider
// directly as a best-effort fallback.
// TODO: This uses the default github.com API base URL,
// so branch URLs for GitHub Enterprise instances will
// be incorrect. To fix this, convertChatDiffStatus
// would need access to the external auth configs.
gp := gitprovider.New("github", "", nil)
if gp != nil {
if owner, repo, _, ok := gp.ParseRepositoryOrigin(status.GitRemoteOrigin); ok {
branchURL := gp.BuildBranchURL(owner, repo, status.GitBranch)
if branchURL != "" {
result.URL = &branchURL
}
}
}
}
+1 -1
View File
@@ -2521,7 +2521,7 @@ func TestGetChatDiffStatus(t *testing.T) {
require.NoError(t, err)
require.Equal(t, cachedStatusChat.ID, cachedStatus.ChatID)
require.NotNil(t, cachedStatus.URL)
require.Equal(t, "https://github.com/coder/coder/tree/feature%2Fdiff-status", *cachedStatus.URL)
require.Equal(t, "https://github.com/coder/coder/tree/feature/diff-status", *cachedStatus.URL)
require.NotNil(t, cachedStatus.PullRequestState)
require.Equal(t, "open", *cachedStatus.PullRequestState)
require.True(t, cachedStatus.ChangesRequested)
+25
View File
@@ -61,6 +61,7 @@ import (
"github.com/coder/coder/v2/coderd/externalauth"
"github.com/coder/coder/v2/coderd/files"
"github.com/coder/coder/v2/coderd/gitsshkey"
"github.com/coder/coder/v2/coderd/gitsync"
"github.com/coder/coder/v2/coderd/healthcheck"
"github.com/coder/coder/v2/coderd/healthcheck/derphealth"
"github.com/coder/coder/v2/coderd/httpapi"
@@ -772,6 +773,20 @@ func New(options *Options) *API {
Pubsub: options.Pubsub,
WebpushDispatcher: options.WebPushDispatcher,
})
refresher := gitsync.NewRefresher(
api.resolveGitProvider,
api.resolveChatGitAccessToken,
options.Logger.Named("gitsync").Named("refresher"),
quartz.NewReal(),
)
api.gitSyncWorker = gitsync.NewWorker(options.Database,
refresher,
api.chatDaemon,
quartz.NewReal(),
options.Logger.Named("gitsync"),
)
// nolint:gocritic // chat diff worker needs to be able to CRUD chats.
go api.gitSyncWorker.Start(dbauthz.AsChatd(api.ctx))
if options.DeploymentValues.Prometheus.Enable {
options.PrometheusRegistry.MustRegister(stn)
api.lifecycleMetrics = agentapi.NewLifecycleMetrics(options.PrometheusRegistry)
@@ -1989,6 +2004,9 @@ type API struct {
dbRolluper *dbrollup.Rolluper
// chatDaemon handles background processing of pending chats.
chatDaemon *chatd.Server
// gitSyncWorker refreshes stale chat diff statuses in the
// background.
gitSyncWorker *gitsync.Worker
}
// Close waits for all WebSocket connections to drain before returning.
@@ -2018,6 +2036,13 @@ func (api *API) Close() error {
api.Logger.Warn(api.ctx, "websocket shutdown timed out after 10 seconds")
}
api.dbRolluper.Close()
// chatDiffWorker is unconditionally initialized in New().
select {
case <-api.gitSyncWorker.Done():
case <-time.After(10 * time.Second):
api.Logger.Warn(context.Background(),
"chat diff refresh worker did not exit in time")
}
if err := api.chatDaemon.Close(); err != nil {
api.Logger.Warn(api.ctx, "close chat processor", slog.Error(err))
}
+21
View File
@@ -1539,6 +1539,17 @@ func (q *querier) AcquireProvisionerJob(ctx context.Context, arg database.Acquir
return q.db.AcquireProvisionerJob(ctx, arg)
}
func (q *querier) AcquireStaleChatDiffStatuses(ctx context.Context, limitVal int32) ([]database.AcquireStaleChatDiffStatusesRow, error) {
// This is a system-level batch operation used by the gitsync
// background worker. Per-object authorization is impractical
// for a SKIP LOCKED acquisition query; callers must use
// AsChatd or AsSystemRestricted context.
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceChat); err != nil {
return nil, err
}
return q.db.AcquireStaleChatDiffStatuses(ctx, limitVal)
}
func (q *querier) ActivityBumpWorkspace(ctx context.Context, arg database.ActivityBumpWorkspaceParams) error {
fetch := func(ctx context.Context, arg database.ActivityBumpWorkspaceParams) (database.Workspace, error) {
return q.db.GetWorkspaceByID(ctx, arg.WorkspaceID)
@@ -1577,6 +1588,16 @@ func (q *querier) ArchiveUnusedTemplateVersions(ctx context.Context, arg databas
return q.db.ArchiveUnusedTemplateVersions(ctx, arg)
}
func (q *querier) BackoffChatDiffStatus(ctx context.Context, arg database.BackoffChatDiffStatusParams) error {
// This is a system-level operation used by the gitsync
// background worker to reschedule failed refreshes. Same
// authorization pattern as AcquireStaleChatDiffStatuses.
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceChat); err != nil {
return err
}
return q.db.BackoffChatDiffStatus(ctx, arg)
}
func (q *querier) BatchUpdateWorkspaceAgentMetadata(ctx context.Context, arg database.BatchUpdateWorkspaceAgentMetadataParams) error {
// Could be any workspace agent and checking auth to each workspace agent is overkill for
// the purpose of this function.
+12
View File
@@ -758,6 +758,18 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().UpsertChatDiffStatusReference(gomock.Any(), arg).Return(diffStatus, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(diffStatus)
}))
s.Run("AcquireStaleChatDiffStatuses", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().AcquireStaleChatDiffStatuses(gomock.Any(), int32(10)).Return([]database.AcquireStaleChatDiffStatusesRow{}, nil).AnyTimes()
check.Args(int32(10)).Asserts(rbac.ResourceChat, policy.ActionUpdate).Returns([]database.AcquireStaleChatDiffStatusesRow{})
}))
s.Run("BackoffChatDiffStatus", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.BackoffChatDiffStatusParams{
ChatID: uuid.New(),
StaleAt: dbtime.Now(),
}
dbm.EXPECT().BackoffChatDiffStatus(gomock.Any(), arg).Return(nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceChat, policy.ActionUpdate).Returns()
}))
}
func (s *MethodTestSuite) TestFile() {
+16
View File
@@ -136,6 +136,14 @@ func (m queryMetricsStore) AcquireProvisionerJob(ctx context.Context, arg databa
return r0, r1
}
func (m queryMetricsStore) AcquireStaleChatDiffStatuses(ctx context.Context, limitVal int32) ([]database.AcquireStaleChatDiffStatusesRow, error) {
start := time.Now()
r0, r1 := m.s.AcquireStaleChatDiffStatuses(ctx, limitVal)
m.queryLatencies.WithLabelValues("AcquireStaleChatDiffStatuses").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "AcquireStaleChatDiffStatuses").Inc()
return r0, r1
}
func (m queryMetricsStore) ActivityBumpWorkspace(ctx context.Context, arg database.ActivityBumpWorkspaceParams) error {
start := time.Now()
r0 := m.s.ActivityBumpWorkspace(ctx, arg)
@@ -168,6 +176,14 @@ func (m queryMetricsStore) ArchiveUnusedTemplateVersions(ctx context.Context, ar
return r0, r1
}
func (m queryMetricsStore) BackoffChatDiffStatus(ctx context.Context, arg database.BackoffChatDiffStatusParams) error {
start := time.Now()
r0 := m.s.BackoffChatDiffStatus(ctx, arg)
m.queryLatencies.WithLabelValues("BackoffChatDiffStatus").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "BackoffChatDiffStatus").Inc()
return r0
}
func (m queryMetricsStore) BatchUpdateWorkspaceAgentMetadata(ctx context.Context, arg database.BatchUpdateWorkspaceAgentMetadataParams) error {
start := time.Now()
r0 := m.s.BatchUpdateWorkspaceAgentMetadata(ctx, arg)
+29
View File
@@ -103,6 +103,21 @@ func (mr *MockStoreMockRecorder) AcquireProvisionerJob(ctx, arg any) *gomock.Cal
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcquireProvisionerJob", reflect.TypeOf((*MockStore)(nil).AcquireProvisionerJob), ctx, arg)
}
// AcquireStaleChatDiffStatuses mocks base method.
func (m *MockStore) AcquireStaleChatDiffStatuses(ctx context.Context, limitVal int32) ([]database.AcquireStaleChatDiffStatusesRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AcquireStaleChatDiffStatuses", ctx, limitVal)
ret0, _ := ret[0].([]database.AcquireStaleChatDiffStatusesRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// AcquireStaleChatDiffStatuses indicates an expected call of AcquireStaleChatDiffStatuses.
func (mr *MockStoreMockRecorder) AcquireStaleChatDiffStatuses(ctx, limitVal any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcquireStaleChatDiffStatuses", reflect.TypeOf((*MockStore)(nil).AcquireStaleChatDiffStatuses), ctx, limitVal)
}
// ActivityBumpWorkspace mocks base method.
func (m *MockStore) ActivityBumpWorkspace(ctx context.Context, arg database.ActivityBumpWorkspaceParams) error {
m.ctrl.T.Helper()
@@ -161,6 +176,20 @@ func (mr *MockStoreMockRecorder) ArchiveUnusedTemplateVersions(ctx, arg any) *go
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ArchiveUnusedTemplateVersions", reflect.TypeOf((*MockStore)(nil).ArchiveUnusedTemplateVersions), ctx, arg)
}
// BackoffChatDiffStatus mocks base method.
func (m *MockStore) BackoffChatDiffStatus(ctx context.Context, arg database.BackoffChatDiffStatusParams) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BackoffChatDiffStatus", ctx, arg)
ret0, _ := ret[0].(error)
return ret0
}
// BackoffChatDiffStatus indicates an expected call of BackoffChatDiffStatus.
func (mr *MockStoreMockRecorder) BackoffChatDiffStatus(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BackoffChatDiffStatus", reflect.TypeOf((*MockStore)(nil).BackoffChatDiffStatus), ctx, arg)
}
// BatchUpdateWorkspaceAgentMetadata mocks base method.
func (m *MockStore) BatchUpdateWorkspaceAgentMetadata(ctx context.Context, arg database.BatchUpdateWorkspaceAgentMetadataParams) error {
m.ctrl.T.Helper()
+2
View File
@@ -39,6 +39,7 @@ type sqlcQuerier interface {
// multiple provisioners from acquiring the same jobs. See:
// https://www.postgresql.org/docs/9.5/sql-select.html#SQL-FOR-UPDATE-SHARE
AcquireProvisionerJob(ctx context.Context, arg AcquireProvisionerJobParams) (ProvisionerJob, error)
AcquireStaleChatDiffStatuses(ctx context.Context, limitVal int32) ([]AcquireStaleChatDiffStatusesRow, error)
// Bumps the workspace deadline by the template's configured "activity_bump"
// duration (default 1h). If the workspace bump will cross an autostart
// threshold, then the bump is autostart + TTL. This is the deadline behavior if
@@ -60,6 +61,7 @@ type sqlcQuerier interface {
// Only unused template versions will be archived, which are any versions not
// referenced by the latest build of a workspace.
ArchiveUnusedTemplateVersions(ctx context.Context, arg ArchiveUnusedTemplateVersionsParams) ([]uuid.UUID, error)
BackoffChatDiffStatus(ctx context.Context, arg BackoffChatDiffStatusParams) error
BatchUpdateWorkspaceAgentMetadata(ctx context.Context, arg BatchUpdateWorkspaceAgentMetadataParams) error
BatchUpdateWorkspaceLastUsedAt(ctx context.Context, arg BatchUpdateWorkspaceLastUsedAtParams) error
BatchUpdateWorkspaceNextStartAt(ctx context.Context, arg BatchUpdateWorkspaceNextStartAtParams) error
+116
View File
@@ -3026,6 +3026,102 @@ func (q *sqlQuerier) AcquireChat(ctx context.Context, arg AcquireChatParams) (Ch
return i, err
}
const acquireStaleChatDiffStatuses = `-- name: AcquireStaleChatDiffStatuses :many
WITH acquired AS (
UPDATE
chat_diff_statuses
SET
-- Claim for 5 minutes. The worker sets the real stale_at
-- after refresh. If the worker crashes, rows become eligible
-- again after this interval.
stale_at = NOW() + INTERVAL '5 minutes',
updated_at = NOW()
WHERE
chat_id IN (
SELECT
cds.chat_id
FROM
chat_diff_statuses cds
INNER JOIN
chats c ON c.id = cds.chat_id
WHERE
cds.stale_at <= NOW()
AND cds.git_remote_origin != ''
AND cds.git_branch != ''
AND c.archived = FALSE
ORDER BY
cds.stale_at ASC
FOR UPDATE OF cds
SKIP LOCKED
LIMIT
$1::int
)
RETURNING chat_id, url, pull_request_state, changes_requested, additions, deletions, changed_files, refreshed_at, stale_at, created_at, updated_at, git_branch, git_remote_origin
)
SELECT
acquired.chat_id, acquired.url, acquired.pull_request_state, acquired.changes_requested, acquired.additions, acquired.deletions, acquired.changed_files, acquired.refreshed_at, acquired.stale_at, acquired.created_at, acquired.updated_at, acquired.git_branch, acquired.git_remote_origin,
c.owner_id
FROM
acquired
INNER JOIN
chats c ON c.id = acquired.chat_id
`
type AcquireStaleChatDiffStatusesRow struct {
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
Url sql.NullString `db:"url" json:"url"`
PullRequestState sql.NullString `db:"pull_request_state" json:"pull_request_state"`
ChangesRequested bool `db:"changes_requested" json:"changes_requested"`
Additions int32 `db:"additions" json:"additions"`
Deletions int32 `db:"deletions" json:"deletions"`
ChangedFiles int32 `db:"changed_files" json:"changed_files"`
RefreshedAt sql.NullTime `db:"refreshed_at" json:"refreshed_at"`
StaleAt time.Time `db:"stale_at" json:"stale_at"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
GitBranch string `db:"git_branch" json:"git_branch"`
GitRemoteOrigin string `db:"git_remote_origin" json:"git_remote_origin"`
OwnerID uuid.UUID `db:"owner_id" json:"owner_id"`
}
func (q *sqlQuerier) AcquireStaleChatDiffStatuses(ctx context.Context, limitVal int32) ([]AcquireStaleChatDiffStatusesRow, error) {
rows, err := q.db.QueryContext(ctx, acquireStaleChatDiffStatuses, limitVal)
if err != nil {
return nil, err
}
defer rows.Close()
var items []AcquireStaleChatDiffStatusesRow
for rows.Next() {
var i AcquireStaleChatDiffStatusesRow
if err := rows.Scan(
&i.ChatID,
&i.Url,
&i.PullRequestState,
&i.ChangesRequested,
&i.Additions,
&i.Deletions,
&i.ChangedFiles,
&i.RefreshedAt,
&i.StaleAt,
&i.CreatedAt,
&i.UpdatedAt,
&i.GitBranch,
&i.GitRemoteOrigin,
&i.OwnerID,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const archiveChatByID = `-- name: ArchiveChatByID :exec
UPDATE chats SET archived = true, updated_at = NOW()
WHERE id = $1 OR root_chat_id = $1
@@ -3036,6 +3132,26 @@ func (q *sqlQuerier) ArchiveChatByID(ctx context.Context, id uuid.UUID) error {
return err
}
const backoffChatDiffStatus = `-- name: BackoffChatDiffStatus :exec
UPDATE
chat_diff_statuses
SET
stale_at = $1::timestamptz,
updated_at = NOW()
WHERE
chat_id = $2::uuid
`
type BackoffChatDiffStatusParams struct {
StaleAt time.Time `db:"stale_at" json:"stale_at"`
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
}
func (q *sqlQuerier) BackoffChatDiffStatus(ctx context.Context, arg BackoffChatDiffStatusParams) error {
_, err := q.db.ExecContext(ctx, backoffChatDiffStatus, arg.StaleAt, arg.ChatID)
return err
}
const deleteAllChatQueuedMessages = `-- name: DeleteAllChatQueuedMessages :exec
DELETE FROM chat_queued_messages WHERE chat_id = $1
`
+49
View File
@@ -410,3 +410,52 @@ RETURNING *;
-- name: GetChatByIDForUpdate :one
SELECT * FROM chats WHERE id = @id::uuid FOR UPDATE;
-- name: AcquireStaleChatDiffStatuses :many
WITH acquired AS (
UPDATE
chat_diff_statuses
SET
-- Claim for 5 minutes. The worker sets the real stale_at
-- after refresh. If the worker crashes, rows become eligible
-- again after this interval.
stale_at = NOW() + INTERVAL '5 minutes',
updated_at = NOW()
WHERE
chat_id IN (
SELECT
cds.chat_id
FROM
chat_diff_statuses cds
INNER JOIN
chats c ON c.id = cds.chat_id
WHERE
cds.stale_at <= NOW()
AND cds.git_remote_origin != ''
AND cds.git_branch != ''
AND c.archived = FALSE
ORDER BY
cds.stale_at ASC
FOR UPDATE OF cds
SKIP LOCKED
LIMIT
@limit_val::int
)
RETURNING *
)
SELECT
acquired.*,
c.owner_id
FROM
acquired
INNER JOIN
chats c ON c.id = acquired.chat_id;
-- name: BackoffChatDiffStatus :exec
UPDATE
chat_diff_statuses
SET
stale_at = @stale_at::timestamptz,
updated_at = NOW()
WHERE
chat_id = @chat_id::uuid;
+32 -3
View File
@@ -23,6 +23,7 @@ import (
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/externalauth/gitprovider"
"github.com/coder/coder/v2/coderd/promoauth"
"github.com/coder/coder/v2/coderd/util/slice"
"github.com/coder/coder/v2/codersdk"
@@ -82,6 +83,10 @@ type Config struct {
// a Git clone. e.g. "Username for 'https://github.com':"
// The regex would be `github\.com`..
Regex *regexp.Regexp
// APIBaseURL is the base URL for provider REST API calls
// (e.g., "https://api.github.com" for GitHub). Derived from
// defaults when not explicitly configured.
APIBaseURL string
// AppInstallURL is for GitHub App's (and hopefully others eventually)
// to provide a link to install the app. There's installation
// of the application, and user authentication. It's possible
@@ -106,12 +111,23 @@ type Config struct {
CodeChallengeMethodsSupported []promoauth.Oauth2PKCEChallengeMethod
}
// Git returns a Provider for this config if the provider type
// is a supported git hosting provider. Returns nil for non-git
// providers (e.g. Slack, JFrog).
func (c *Config) Git(client *http.Client) gitprovider.Provider {
norm := strings.ToLower(c.Type)
if !codersdk.EnhancedExternalAuthProvider(norm).Git() {
return nil
}
return gitprovider.New(norm, c.APIBaseURL, client)
}
// GenerateTokenExtra generates the extra token data to store in the database.
func (c *Config) GenerateTokenExtra(token *oauth2.Token) (pqtype.NullRawMessage, error) {
if len(c.ExtraTokenKeys) == 0 {
return pqtype.NullRawMessage{}, nil
}
extraMap := map[string]interface{}{}
extraMap := map[string]any{}
for _, key := range c.ExtraTokenKeys {
extraMap[key] = token.Extra(key)
}
@@ -729,6 +745,7 @@ func ConvertConfig(instrument *promoauth.Factory, entries []codersdk.ExternalAut
ClientID: entry.ClientID,
ClientSecret: entry.ClientSecret,
Regex: regex,
APIBaseURL: entry.APIBaseURL,
Type: entry.Type,
NoRefresh: entry.NoRefresh,
ValidateURL: entry.ValidateURL,
@@ -765,7 +782,7 @@ func ConvertConfig(instrument *promoauth.Factory, entries []codersdk.ExternalAut
// applyDefaultsToConfig applies defaults to the config entry.
func applyDefaultsToConfig(config *codersdk.ExternalAuthConfig) {
configType := codersdk.EnhancedExternalAuthProvider(config.Type)
configType := codersdk.EnhancedExternalAuthProvider(strings.ToLower(config.Type))
if configType == "bitbucket" {
// For backwards compatibility, we need to support the "bitbucket" string.
configType = codersdk.EnhancedExternalAuthProviderBitBucketCloud
@@ -782,7 +799,7 @@ func applyDefaultsToConfig(config *codersdk.ExternalAuthConfig) {
}
// Dynamic defaults
switch codersdk.EnhancedExternalAuthProvider(config.Type) {
switch configType {
case codersdk.EnhancedExternalAuthProviderGitHub:
copyDefaultSettings(config, gitHubDefaults(config))
return
@@ -863,6 +880,18 @@ func copyDefaultSettings(config *codersdk.ExternalAuthConfig, defaults codersdk.
if config.CodeChallengeMethodsSupported == nil {
config.CodeChallengeMethodsSupported = []string{string(promoauth.PKCEChallengeMethodSha256)}
}
// Set default API base URL for providers that need one.
if config.APIBaseURL == "" {
switch codersdk.EnhancedExternalAuthProvider(config.Type) {
case codersdk.EnhancedExternalAuthProviderGitHub:
config.APIBaseURL = "https://api.github.com"
case codersdk.EnhancedExternalAuthProviderGitLab:
config.APIBaseURL = "https://gitlab.com/api/v4"
case codersdk.EnhancedExternalAuthProviderGitea:
config.APIBaseURL = "https://gitea.com/api/v1"
}
}
}
// gitHubDefaults returns default config values for GitHub.
@@ -25,6 +25,7 @@ func TestGitlabDefaults(t *testing.T) {
DisplayName: "GitLab",
DisplayIcon: "/icon/gitlab.svg",
Regex: `^(https?://)?gitlab\.com(/.*)?$`,
APIBaseURL: "https://gitlab.com/api/v4",
Scopes: []string{"write_repository"},
CodeChallengeMethodsSupported: []string{string(promoauth.PKCEChallengeMethodSha256)},
}
+34
View File
@@ -844,6 +844,40 @@ func setupOauth2Test(t *testing.T, settings testConfig) (*oidctest.FakeIDP, *ext
return fake, config, link
}
func TestApplyDefaultsToConfig_CaseInsensitive(t *testing.T) {
t.Parallel()
instrument := promoauth.NewFactory(prometheus.NewRegistry())
accessURL, err := url.Parse("https://coder.example.com")
require.NoError(t, err)
for _, tc := range []struct {
Name string
Type string
}{
{Name: "GitHub", Type: "GitHub"},
{Name: "GITLAB", Type: "GITLAB"},
{Name: "Gitea", Type: "Gitea"},
} {
t.Run(tc.Name, func(t *testing.T) {
t.Parallel()
configs, err := externalauth.ConvertConfig(
instrument,
[]codersdk.ExternalAuthConfig{{
Type: tc.Type,
ClientID: "test-id",
ClientSecret: "test-secret",
}},
accessURL,
)
require.NoError(t, err)
require.Len(t, configs, 1)
// Defaults should have been applied despite mixed-case Type.
assert.NotEmpty(t, configs[0].AuthCodeURL("state"), "auth URL should be populated from defaults")
})
}
}
type roundTripper func(req *http.Request) (*http.Response, error)
func (r roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
+536
View File
@@ -0,0 +1,536 @@
package gitprovider
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"regexp"
"strconv"
"strings"
"time"
"golang.org/x/xerrors"
"github.com/coder/quartz"
)
const defaultGitHubAPIBaseURL = "https://api.github.com"
type githubProvider struct {
apiBaseURL string
webBaseURL string
httpClient *http.Client
clock quartz.Clock
// Compiled per-instance to support GitHub Enterprise hosts.
pullRequestPathPattern *regexp.Regexp
repositoryHTTPSPattern *regexp.Regexp
repositorySSHPathPattern *regexp.Regexp
}
func newGitHub(apiBaseURL string, httpClient *http.Client, clock quartz.Clock) *githubProvider {
if apiBaseURL == "" {
apiBaseURL = defaultGitHubAPIBaseURL
}
apiBaseURL = strings.TrimRight(apiBaseURL, "/")
if httpClient == nil {
httpClient = http.DefaultClient
}
// Derive the web base URL from the API base URL.
// github.com: api.github.com → github.com
// GHE: ghes.corp.com/api/v3 → ghes.corp.com
webBaseURL := deriveWebBaseURL(apiBaseURL)
// Parse the host for regex construction.
host := extractHost(webBaseURL)
// Escape the host for use in regex patterns.
escapedHost := regexp.QuoteMeta(host)
return &githubProvider{
apiBaseURL: apiBaseURL,
webBaseURL: webBaseURL,
httpClient: httpClient,
clock: clock,
pullRequestPathPattern: regexp.MustCompile(
`^https://` + escapedHost + `/([A-Za-z0-9_.-]+)/([A-Za-z0-9_.-]+)/pull/([0-9]+)(?:[/?#].*)?$`,
),
repositoryHTTPSPattern: regexp.MustCompile(
`^https://` + escapedHost + `/([A-Za-z0-9_.-]+)/([A-Za-z0-9_.-]+?)(?:\.git)?/?$`,
),
repositorySSHPathPattern: regexp.MustCompile(
`^(?:ssh://)?git@` + escapedHost + `[:/]([A-Za-z0-9_.-]+)/([A-Za-z0-9_.-]+?)(?:\.git)?/?$`,
),
}
}
// deriveWebBaseURL converts a GitHub API base URL to the
// corresponding web base URL.
//
// github.com: https://api.github.com → https://github.com
// GHE: https://ghes.corp.com/api/v3 → https://ghes.corp.com
func deriveWebBaseURL(apiBaseURL string) string {
u, err := url.Parse(apiBaseURL)
if err != nil {
return "https://github.com"
}
// Standard github.com: API host is api.github.com.
if strings.EqualFold(u.Host, "api.github.com") {
return "https://github.com"
}
// GHE: strip /api/v3 path suffix.
u.Path = strings.TrimSuffix(u.Path, "/api/v3")
u.Path = strings.TrimSuffix(u.Path, "/")
return u.String()
}
// extractHost returns the host portion of a URL.
func extractHost(rawURL string) string {
u, err := url.Parse(rawURL)
if err != nil {
return "github.com"
}
return u.Host
}
func (g *githubProvider) ParseRepositoryOrigin(raw string) (owner string, repo string, normalizedOrigin string, ok bool) {
raw = strings.TrimSpace(raw)
if raw == "" {
return "", "", "", false
}
matches := g.repositoryHTTPSPattern.FindStringSubmatch(raw)
if len(matches) != 3 {
matches = g.repositorySSHPathPattern.FindStringSubmatch(raw)
}
if len(matches) != 3 {
return "", "", "", false
}
owner = strings.TrimSpace(matches[1])
repo = strings.TrimSpace(matches[2])
repo = strings.TrimSuffix(repo, ".git")
if owner == "" || repo == "" {
return "", "", "", false
}
return owner, repo, fmt.Sprintf("%s/%s/%s", g.webBaseURL, url.PathEscape(owner), url.PathEscape(repo)), true
}
func (g *githubProvider) ParsePullRequestURL(raw string) (PRRef, bool) {
matches := g.pullRequestPathPattern.FindStringSubmatch(strings.TrimSpace(raw))
if len(matches) != 4 {
return PRRef{}, false
}
number, err := strconv.Atoi(matches[3])
if err != nil {
return PRRef{}, false
}
return PRRef{
Owner: matches[1],
Repo: matches[2],
Number: number,
}, true
}
func (g *githubProvider) NormalizePullRequestURL(raw string) string {
ref, ok := g.ParsePullRequestURL(strings.TrimRight(
strings.TrimSpace(raw),
"),.;",
))
if !ok {
return ""
}
return fmt.Sprintf("%s/%s/%s/pull/%d", g.webBaseURL, url.PathEscape(ref.Owner), url.PathEscape(ref.Repo), ref.Number)
}
// escapePathPreserveSlashes escapes each segment of a path
// individually, preserving `/` separators. This is needed for
// web URLs where GitHub expects literal slashes (e.g.
// /tree/feat/new-thing).
func escapePathPreserveSlashes(s string) string {
segments := strings.Split(s, "/")
for i, seg := range segments {
segments[i] = url.PathEscape(seg)
}
return strings.Join(segments, "/")
}
func (g *githubProvider) BuildBranchURL(owner string, repo string, branch string) string {
owner = strings.TrimSpace(owner)
repo = strings.TrimSpace(repo)
branch = strings.TrimSpace(branch)
if owner == "" || repo == "" || branch == "" {
return ""
}
return fmt.Sprintf(
"%s/%s/%s/tree/%s",
g.webBaseURL,
url.PathEscape(owner),
url.PathEscape(repo),
escapePathPreserveSlashes(branch),
)
}
func (g *githubProvider) BuildRepositoryURL(owner string, repo string) string {
owner = strings.TrimSpace(owner)
repo = strings.TrimSpace(repo)
if owner == "" || repo == "" {
return ""
}
return fmt.Sprintf("%s/%s/%s", g.webBaseURL, url.PathEscape(owner), url.PathEscape(repo))
}
func (g *githubProvider) BuildPullRequestURL(ref PRRef) string {
if ref.Owner == "" || ref.Repo == "" || ref.Number <= 0 {
return ""
}
return fmt.Sprintf("%s/%s/%s/pull/%d", g.webBaseURL, url.PathEscape(ref.Owner), url.PathEscape(ref.Repo), ref.Number)
}
func (g *githubProvider) ResolveBranchPullRequest(
ctx context.Context,
token string,
ref BranchRef,
) (*PRRef, error) {
if ref.Owner == "" || ref.Repo == "" || ref.Branch == "" {
return nil, nil
}
query := url.Values{}
query.Set("state", "open")
query.Set("head", fmt.Sprintf("%s:%s", ref.Owner, ref.Branch))
query.Set("sort", "updated")
query.Set("direction", "desc")
query.Set("per_page", "1")
requestURL := fmt.Sprintf(
"%s/repos/%s/%s/pulls?%s",
g.apiBaseURL,
url.PathEscape(ref.Owner),
url.PathEscape(ref.Repo),
query.Encode(),
)
var pulls []struct {
HTMLURL string `json:"html_url"`
Number int `json:"number"`
}
if err := g.decodeJSON(ctx, requestURL, token, &pulls); err != nil {
return nil, err
}
if len(pulls) == 0 {
return nil, nil
}
prRef, ok := g.ParsePullRequestURL(pulls[0].HTMLURL)
if !ok {
return nil, nil
}
return &prRef, nil
}
func (g *githubProvider) FetchPullRequestStatus(
ctx context.Context,
token string,
ref PRRef,
) (*PRStatus, error) {
pullEndpoint := fmt.Sprintf(
"%s/repos/%s/%s/pulls/%d",
g.apiBaseURL,
url.PathEscape(ref.Owner),
url.PathEscape(ref.Repo),
ref.Number,
)
var pull struct {
State string `json:"state"`
Merged bool `json:"merged"`
Draft bool `json:"draft"`
Additions int32 `json:"additions"`
Deletions int32 `json:"deletions"`
ChangedFiles int32 `json:"changed_files"`
Head struct {
SHA string `json:"sha"`
} `json:"head"`
}
if err := g.decodeJSON(ctx, pullEndpoint, token, &pull); err != nil {
return nil, err
}
var reviews []struct {
ID int64 `json:"id"`
State string `json:"state"`
User struct {
Login string `json:"login"`
} `json:"user"`
}
// GitHub returns at most 100 reviews per page. We do not
// paginate because PRs with >100 reviews are extremely rare,
// and the cost of multiple API calls per refresh is not
// justified. If needed, pagination can be added later.
if err := g.decodeJSON(
ctx,
pullEndpoint+"/reviews?per_page=100",
token,
&reviews,
); err != nil {
return nil, err
}
state := PRState(strings.ToLower(strings.TrimSpace(pull.State)))
if pull.Merged {
state = PRStateMerged
}
return &PRStatus{
State: state,
Draft: pull.Draft,
HeadSHA: pull.Head.SHA,
DiffStats: DiffStats{
Additions: pull.Additions,
Deletions: pull.Deletions,
ChangedFiles: pull.ChangedFiles,
},
ChangesRequested: hasOutstandingChangesRequested(reviews),
FetchedAt: g.clock.Now().UTC(),
}, nil
}
func (g *githubProvider) FetchPullRequestDiff(
ctx context.Context,
token string,
ref PRRef,
) (string, error) {
requestURL := fmt.Sprintf(
"%s/repos/%s/%s/pulls/%d",
g.apiBaseURL,
url.PathEscape(ref.Owner),
url.PathEscape(ref.Repo),
ref.Number,
)
return g.fetchDiff(ctx, requestURL, token)
}
func (g *githubProvider) FetchBranchDiff(
ctx context.Context,
token string,
ref BranchRef,
) (string, error) {
if ref.Owner == "" || ref.Repo == "" || ref.Branch == "" {
return "", nil
}
var repository struct {
DefaultBranch string `json:"default_branch"`
}
repositoryURL := fmt.Sprintf(
"%s/repos/%s/%s",
g.apiBaseURL,
url.PathEscape(ref.Owner),
url.PathEscape(ref.Repo),
)
if err := g.decodeJSON(ctx, repositoryURL, token, &repository); err != nil {
return "", err
}
defaultBranch := strings.TrimSpace(repository.DefaultBranch)
if defaultBranch == "" {
return "", xerrors.New("github repository default branch is empty")
}
requestURL := fmt.Sprintf(
"%s/repos/%s/%s/compare/%s...%s",
g.apiBaseURL,
url.PathEscape(ref.Owner),
url.PathEscape(ref.Repo),
url.PathEscape(defaultBranch),
url.PathEscape(ref.Branch),
)
return g.fetchDiff(ctx, requestURL, token)
}
func (g *githubProvider) decodeJSON(
ctx context.Context,
requestURL string,
token string,
dest any,
) error {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
if err != nil {
return xerrors.Errorf("create github request: %w", err)
}
req.Header.Set("Accept", "application/vnd.github+json")
req.Header.Set("X-GitHub-Api-Version", "2022-11-28")
req.Header.Set("User-Agent", "coder-chat-diff-status")
if token != "" {
req.Header.Set("Authorization", "Bearer "+token)
}
resp, err := g.httpClient.Do(req)
if err != nil {
return xerrors.Errorf("execute github request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if resp.StatusCode == http.StatusForbidden || resp.StatusCode == http.StatusTooManyRequests {
retryAfter := ParseRetryAfter(resp.Header, g.clock)
if retryAfter > 0 {
return &RateLimitError{RetryAfter: g.clock.Now().Add(retryAfter)}
}
// No rate-limit headers — fall through to generic error.
}
body, readErr := io.ReadAll(io.LimitReader(resp.Body, 8192))
if readErr != nil {
return xerrors.Errorf(
"github request failed with status %d",
resp.StatusCode,
)
}
return xerrors.Errorf(
"github request failed with status %d: %s",
resp.StatusCode,
strings.TrimSpace(string(body)),
)
}
if err := json.NewDecoder(resp.Body).Decode(dest); err != nil {
return xerrors.Errorf("decode github response: %w", err)
}
return nil
}
func (g *githubProvider) fetchDiff(
ctx context.Context,
requestURL string,
token string,
) (string, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
if err != nil {
return "", xerrors.Errorf("create github diff request: %w", err)
}
req.Header.Set("Accept", "application/vnd.github.diff")
req.Header.Set("X-GitHub-Api-Version", "2022-11-28")
req.Header.Set("User-Agent", "coder-chat-diff")
if token != "" {
req.Header.Set("Authorization", "Bearer "+token)
}
resp, err := g.httpClient.Do(req)
if err != nil {
return "", xerrors.Errorf("execute github diff request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if resp.StatusCode == http.StatusForbidden || resp.StatusCode == http.StatusTooManyRequests {
retryAfter := ParseRetryAfter(resp.Header, g.clock)
if retryAfter > 0 {
return "", &RateLimitError{RetryAfter: g.clock.Now().Add(retryAfter)}
}
}
body, readErr := io.ReadAll(io.LimitReader(resp.Body, 8192))
if readErr != nil {
return "", xerrors.Errorf("github diff request failed with status %d", resp.StatusCode)
}
return "", xerrors.Errorf(
"github diff request failed with status %d: %s",
resp.StatusCode,
strings.TrimSpace(string(body)),
)
}
// Read one extra byte beyond MaxDiffSize so we can detect
// whether the diff exceeds the limit. LimitReader stops us
// allocating an arbitrarily large buffer by accident.
buf, err := io.ReadAll(io.LimitReader(resp.Body, MaxDiffSize+1))
if err != nil {
return "", xerrors.Errorf("read github diff response: %w", err)
}
if len(buf) > MaxDiffSize {
return "", ErrDiffTooLarge
}
return string(buf), nil
}
// ParseRetryAfter extracts a retry-after time from GitHub
// rate-limit headers. Returns zero value if no recognizable header is
// present.
func ParseRetryAfter(h http.Header, clk quartz.Clock) time.Duration {
if clk == nil {
clk = quartz.NewReal()
}
// Retry-After header: seconds until retry.
if ra := h.Get("Retry-After"); ra != "" {
if secs, err := strconv.Atoi(ra); err == nil {
return time.Duration(secs) * time.Second
}
}
// X-Ratelimit-Reset header: unix timestamp. We compute the
// duration from now according to the caller's clock.
if reset := h.Get("X-Ratelimit-Reset"); reset != "" {
if ts, err := strconv.ParseInt(reset, 10, 64); err == nil {
d := time.Unix(ts, 0).Sub(clk.Now())
return d
}
}
return 0
}
func hasOutstandingChangesRequested(
reviews []struct {
ID int64 `json:"id"`
State string `json:"state"`
User struct {
Login string `json:"login"`
} `json:"user"`
},
) bool {
type reviewerState struct {
reviewID int64
state string
}
statesByReviewer := make(map[string]reviewerState)
for _, review := range reviews {
login := strings.ToLower(strings.TrimSpace(review.User.Login))
if login == "" {
continue
}
state := strings.ToUpper(strings.TrimSpace(review.State))
switch state {
case "CHANGES_REQUESTED", "APPROVED", "DISMISSED":
default:
continue
}
current, exists := statesByReviewer[login]
if exists && current.reviewID > review.ID {
continue
}
statesByReviewer[login] = reviewerState{
reviewID: review.ID,
state: state,
}
}
for _, state := range statesByReviewer {
if state.state == "CHANGES_REQUESTED" {
return true
}
}
return false
}
@@ -0,0 +1,994 @@
package gitprovider_test
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"strconv"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/externalauth/gitprovider"
"github.com/coder/quartz"
)
func TestGitHubParseRepositoryOrigin(t *testing.T) {
t.Parallel()
gp := gitprovider.New("github", "", nil)
require.NotNil(t, gp)
tests := []struct {
name string
raw string
expectOK bool
expectOwner string
expectRepo string
expectNormalized string
}{
{
name: "HTTPS URL",
raw: "https://github.com/coder/coder",
expectOK: true,
expectOwner: "coder",
expectRepo: "coder",
expectNormalized: "https://github.com/coder/coder",
},
{
name: "HTTPS URL with .git",
raw: "https://github.com/coder/coder.git",
expectOK: true,
expectOwner: "coder",
expectRepo: "coder",
expectNormalized: "https://github.com/coder/coder",
},
{
name: "HTTPS URL with trailing slash",
raw: "https://github.com/coder/coder/",
expectOK: true,
expectOwner: "coder",
expectRepo: "coder",
expectNormalized: "https://github.com/coder/coder",
},
{
name: "SSH URL",
raw: "git@github.com:coder/coder.git",
expectOK: true,
expectOwner: "coder",
expectRepo: "coder",
expectNormalized: "https://github.com/coder/coder",
},
{
name: "SSH URL without .git",
raw: "git@github.com:coder/coder",
expectOK: true,
expectOwner: "coder",
expectRepo: "coder",
expectNormalized: "https://github.com/coder/coder",
},
{
name: "SSH URL with ssh:// prefix",
raw: "ssh://git@github.com/coder/coder.git",
expectOK: true,
expectOwner: "coder",
expectRepo: "coder",
expectNormalized: "https://github.com/coder/coder",
},
{
name: "GitLab URL does not match",
raw: "https://gitlab.com/coder/coder",
expectOK: false,
},
{
name: "Empty string",
raw: "",
expectOK: false,
},
{
name: "Not a URL",
raw: "not-a-url",
expectOK: false,
},
{
name: "Hyphenated owner and repo",
raw: "https://github.com/my-org/my-repo.git",
expectOK: true,
expectOwner: "my-org",
expectRepo: "my-repo",
expectNormalized: "https://github.com/my-org/my-repo",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
owner, repo, normalized, ok := gp.ParseRepositoryOrigin(tt.raw)
assert.Equal(t, tt.expectOK, ok)
if tt.expectOK {
assert.Equal(t, tt.expectOwner, owner)
assert.Equal(t, tt.expectRepo, repo)
assert.Equal(t, tt.expectNormalized, normalized)
}
})
}
}
func TestGitHubParsePullRequestURL(t *testing.T) {
t.Parallel()
gp := gitprovider.New("github", "", nil)
require.NotNil(t, gp)
tests := []struct {
name string
raw string
expectOK bool
expectOwner string
expectRepo string
expectNumber int
}{
{
name: "Standard PR URL",
raw: "https://github.com/coder/coder/pull/123",
expectOK: true,
expectOwner: "coder",
expectRepo: "coder",
expectNumber: 123,
},
{
name: "PR URL with query string",
raw: "https://github.com/coder/coder/pull/456?diff=split",
expectOK: true,
expectOwner: "coder",
expectRepo: "coder",
expectNumber: 456,
},
{
name: "PR URL with fragment",
raw: "https://github.com/coder/coder/pull/789#discussion",
expectOK: true,
expectOwner: "coder",
expectRepo: "coder",
expectNumber: 789,
},
{
name: "Not a PR URL",
raw: "https://github.com/coder/coder",
expectOK: false,
},
{
name: "Issue URL (not PR)",
raw: "https://github.com/coder/coder/issues/123",
expectOK: false,
},
{
name: "GitLab MR URL",
raw: "https://gitlab.com/coder/coder/-/merge_requests/123",
expectOK: false,
},
{
name: "Empty string",
raw: "",
expectOK: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ref, ok := gp.ParsePullRequestURL(tt.raw)
assert.Equal(t, tt.expectOK, ok)
if tt.expectOK {
assert.Equal(t, tt.expectOwner, ref.Owner)
assert.Equal(t, tt.expectRepo, ref.Repo)
assert.Equal(t, tt.expectNumber, ref.Number)
}
})
}
}
func TestGitHubNormalizePullRequestURL(t *testing.T) {
t.Parallel()
gp := gitprovider.New("github", "", nil)
require.NotNil(t, gp)
tests := []struct {
name string
raw string
expected string
}{
{
name: "Already normalized",
raw: "https://github.com/coder/coder/pull/123",
expected: "https://github.com/coder/coder/pull/123",
},
{
name: "With trailing punctuation",
raw: "https://github.com/coder/coder/pull/123).",
expected: "https://github.com/coder/coder/pull/123",
},
{
name: "With query string",
raw: "https://github.com/coder/coder/pull/123?diff=split",
expected: "https://github.com/coder/coder/pull/123",
},
{
name: "With whitespace",
raw: " https://github.com/coder/coder/pull/123 ",
expected: "https://github.com/coder/coder/pull/123",
},
{
name: "Not a PR URL",
raw: "https://example.com",
expected: "",
},
{
name: "Empty string",
raw: "",
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
result := gp.NormalizePullRequestURL(tt.raw)
assert.Equal(t, tt.expected, result)
})
}
}
func TestGitHubBuildBranchURL(t *testing.T) {
t.Parallel()
gp := gitprovider.New("github", "", nil)
require.NotNil(t, gp)
tests := []struct {
name string
owner string
repo string
branch string
expected string
}{
{
name: "Simple branch",
owner: "coder",
repo: "coder",
branch: "main",
expected: "https://github.com/coder/coder/tree/main",
},
{
name: "Branch with slash",
owner: "coder",
repo: "coder",
branch: "feat/new-thing",
expected: "https://github.com/coder/coder/tree/feat/new-thing",
},
{
name: "Empty owner",
owner: "",
repo: "coder",
branch: "main",
expected: "",
},
{
name: "Empty repo",
owner: "coder",
repo: "",
branch: "main",
expected: "",
},
{
name: "Empty branch",
owner: "coder",
repo: "coder",
branch: "",
expected: "",
},
{
name: "Branch with slashes",
owner: "my-org",
repo: "my-repo",
branch: "feat/new-thing",
expected: "https://github.com/my-org/my-repo/tree/feat/new-thing",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
result := gp.BuildBranchURL(tt.owner, tt.repo, tt.branch)
assert.Equal(t, tt.expected, result)
})
}
}
func TestGitHubBuildPullRequestURL(t *testing.T) {
t.Parallel()
gp := gitprovider.New("github", "", nil)
require.NotNil(t, gp)
tests := []struct {
name string
ref gitprovider.PRRef
expected string
}{
{
name: "Valid PR ref",
ref: gitprovider.PRRef{Owner: "coder", Repo: "coder", Number: 123},
expected: "https://github.com/coder/coder/pull/123",
},
{
name: "Empty owner",
ref: gitprovider.PRRef{Owner: "", Repo: "coder", Number: 123},
expected: "",
},
{
name: "Empty repo",
ref: gitprovider.PRRef{Owner: "coder", Repo: "", Number: 123},
expected: "",
},
{
name: "Zero number",
ref: gitprovider.PRRef{Owner: "coder", Repo: "coder", Number: 0},
expected: "",
},
{
name: "Negative number",
ref: gitprovider.PRRef{Owner: "coder", Repo: "coder", Number: -1},
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
result := gp.BuildPullRequestURL(tt.ref)
assert.Equal(t, tt.expected, result)
})
}
}
func TestGitHubEnterpriseURLs(t *testing.T) {
t.Parallel()
gp := gitprovider.New("github", "https://ghes.corp.com/api/v3", nil)
require.NotNil(t, gp)
t.Run("ParseRepositoryOrigin HTTPS", func(t *testing.T) {
t.Parallel()
owner, repo, normalized, ok := gp.ParseRepositoryOrigin("https://ghes.corp.com/org/repo.git")
assert.True(t, ok)
assert.Equal(t, "org", owner)
assert.Equal(t, "repo", repo)
assert.Equal(t, "https://ghes.corp.com/org/repo", normalized)
})
t.Run("ParseRepositoryOrigin SSH", func(t *testing.T) {
t.Parallel()
owner, repo, normalized, ok := gp.ParseRepositoryOrigin("git@ghes.corp.com:org/repo.git")
assert.True(t, ok)
assert.Equal(t, "org", owner)
assert.Equal(t, "repo", repo)
assert.Equal(t, "https://ghes.corp.com/org/repo", normalized)
})
t.Run("ParsePullRequestURL", func(t *testing.T) {
t.Parallel()
ref, ok := gp.ParsePullRequestURL("https://ghes.corp.com/org/repo/pull/42")
assert.True(t, ok)
assert.Equal(t, "org", ref.Owner)
assert.Equal(t, "repo", ref.Repo)
assert.Equal(t, 42, ref.Number)
})
t.Run("NormalizePullRequestURL", func(t *testing.T) {
t.Parallel()
result := gp.NormalizePullRequestURL("https://ghes.corp.com/org/repo/pull/42?x=y")
assert.Equal(t, "https://ghes.corp.com/org/repo/pull/42", result)
})
t.Run("BuildBranchURL", func(t *testing.T) {
t.Parallel()
result := gp.BuildBranchURL("org", "repo", "main")
assert.Equal(t, "https://ghes.corp.com/org/repo/tree/main", result)
})
t.Run("BuildPullRequestURL", func(t *testing.T) {
t.Parallel()
result := gp.BuildPullRequestURL(gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 42})
assert.Equal(t, "https://ghes.corp.com/org/repo/pull/42", result)
})
t.Run("github.com URLs do not match GHE instance", func(t *testing.T) {
t.Parallel()
_, _, _, ok := gp.ParseRepositoryOrigin("https://github.com/coder/coder")
assert.False(t, ok, "github.com HTTPS URL should not match GHE instance")
_, _, _, ok = gp.ParseRepositoryOrigin("git@github.com:coder/coder.git")
assert.False(t, ok, "github.com SSH URL should not match GHE instance")
_, ok = gp.ParsePullRequestURL("https://github.com/coder/coder/pull/123")
assert.False(t, ok, "github.com PR URL should not match GHE instance")
})
}
func TestNewUnsupportedProvider(t *testing.T) {
t.Parallel()
gp := gitprovider.New("unsupported", "", nil)
assert.Nil(t, gp, "unsupported provider type should return nil")
}
func TestGitHubRatelimit_403WithResetHeader(t *testing.T) {
t.Parallel()
resetTime := time.Now().Add(60 * time.Second)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("X-Ratelimit-Reset", fmt.Sprintf("%d", resetTime.Unix()))
w.WriteHeader(http.StatusForbidden)
_, _ = w.Write([]byte(`{"message": "API rate limit exceeded"}`))
}))
defer srv.Close()
gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client())
require.NotNil(t, gp)
_, err := gp.FetchPullRequestStatus(
context.Background(),
"test-token",
gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 1},
)
require.Error(t, err)
var rlErr *gitprovider.RateLimitError
require.True(t, errors.As(err, &rlErr), "error should be *RateLimitError, got: %T", err)
assert.WithinDuration(t, resetTime, rlErr.RetryAfter, 2*time.Second)
}
func TestGitHubRatelimit_429WithRetryAfter(t *testing.T) {
t.Parallel()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Retry-After", "120")
w.WriteHeader(http.StatusTooManyRequests)
_, _ = w.Write([]byte(`{"message": "secondary rate limit"}`))
}))
defer srv.Close()
gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client())
require.NotNil(t, gp)
_, err := gp.FetchPullRequestStatus(
context.Background(),
"test-token",
gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 1},
)
require.Error(t, err)
var rlErr *gitprovider.RateLimitError
require.True(t, errors.As(err, &rlErr), "error should be *RateLimitError, got: %T", err)
// Retry-After: 120 means ~120s from now.
expected := time.Now().Add(120 * time.Second)
assert.WithinDuration(t, expected, rlErr.RetryAfter, 5*time.Second)
}
func TestGitHubRatelimit_403NormalError(t *testing.T) {
t.Parallel()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusForbidden)
_, _ = w.Write([]byte(`{"message": "Bad credentials"}`))
}))
defer srv.Close()
gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client())
require.NotNil(t, gp)
_, err := gp.FetchPullRequestStatus(
context.Background(),
"bad-token",
gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 1},
)
require.Error(t, err)
var rlErr *gitprovider.RateLimitError
assert.False(t, errors.As(err, &rlErr), "error should NOT be *RateLimitError")
assert.Contains(t, err.Error(), "403")
}
func TestGitHubFetchPullRequestDiff(t *testing.T) {
t.Parallel()
const smallDiff = "diff --git a/file.go b/file.go\n--- a/file.go\n+++ b/file.go\n@@ -1 +1 @@\n-old\n+new\n"
t.Run("OK", func(t *testing.T) {
t.Parallel()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/plain")
_, _ = w.Write([]byte(smallDiff))
}))
defer srv.Close()
gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client())
require.NotNil(t, gp)
diff, err := gp.FetchPullRequestDiff(
context.Background(),
"test-token",
gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 1},
)
require.NoError(t, err)
assert.Equal(t, smallDiff, diff)
})
t.Run("ExactlyMaxSize", func(t *testing.T) {
t.Parallel()
exactDiff := string(make([]byte, gitprovider.MaxDiffSize))
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/plain")
_, _ = w.Write([]byte(exactDiff))
}))
defer srv.Close()
gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client())
require.NotNil(t, gp)
diff, err := gp.FetchPullRequestDiff(
context.Background(),
"test-token",
gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 1},
)
require.NoError(t, err)
assert.Len(t, diff, gitprovider.MaxDiffSize)
})
t.Run("TooLarge", func(t *testing.T) {
t.Parallel()
oversizeDiff := string(make([]byte, gitprovider.MaxDiffSize+1024))
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/plain")
_, _ = w.Write([]byte(oversizeDiff))
}))
defer srv.Close()
gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client())
require.NotNil(t, gp)
_, err := gp.FetchPullRequestDiff(
context.Background(),
"test-token",
gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 1},
)
assert.ErrorIs(t, err, gitprovider.ErrDiffTooLarge)
})
}
func TestFetchPullRequestDiff_Ratelimit(t *testing.T) {
t.Parallel()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Retry-After", "60")
w.WriteHeader(http.StatusTooManyRequests)
_, _ = w.Write([]byte(`{"message": "rate limit"}`))
}))
defer srv.Close()
gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client())
require.NotNil(t, gp)
_, err := gp.FetchPullRequestDiff(
context.Background(),
"test-token",
gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 1},
)
require.Error(t, err)
var rlErr *gitprovider.RateLimitError
require.True(t, errors.As(err, &rlErr), "error should be *RateLimitError, got: %T", err)
expected := time.Now().Add(60 * time.Second)
assert.WithinDuration(t, expected, rlErr.RetryAfter, 5*time.Second)
}
func TestFetchBranchDiff_Ratelimit(t *testing.T) {
t.Parallel()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/compare/") {
// Second request: compare endpoint returns 429.
w.Header().Set("Retry-After", "60")
w.WriteHeader(http.StatusTooManyRequests)
_, _ = w.Write([]byte(`{"message": "rate limit"}`))
return
}
// First request: repo metadata.
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"default_branch":"main"}`))
}))
defer srv.Close()
gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client())
require.NotNil(t, gp)
_, err := gp.FetchBranchDiff(
context.Background(),
"test-token",
gitprovider.BranchRef{Owner: "org", Repo: "repo", Branch: "feat"},
)
require.Error(t, err)
var rlErr *gitprovider.RateLimitError
require.True(t, errors.As(err, &rlErr), "error should be *RateLimitError, got: %T", err)
expected := time.Now().Add(60 * time.Second)
assert.WithinDuration(t, expected, rlErr.RetryAfter, 5*time.Second)
}
func TestFetchPullRequestStatus(t *testing.T) {
t.Parallel()
type review struct {
ID int64 `json:"id"`
State string `json:"state"`
User struct {
Login string `json:"login"`
} `json:"user"`
}
makeReview := func(id int64, state, login string) review {
r := review{ID: id, State: state}
r.User.Login = login
return r
}
tests := []struct {
name string
pullJSON string
reviews []review
expectedState gitprovider.PRState
expectedDraft bool
changesRequested bool
}{
{
name: "OpenPR/NoReviews",
pullJSON: `{"state":"open","merged":false,"draft":false,"additions":10,"deletions":5,"changed_files":3,"head":{"sha":"abc123"}}`,
reviews: []review{},
expectedState: gitprovider.PRStateOpen,
expectedDraft: false,
changesRequested: false,
},
{
name: "OpenPR/SingleChangesRequested",
pullJSON: `{"state":"open","merged":false,"draft":false,"additions":10,"deletions":5,"changed_files":3,"head":{"sha":"abc123"}}`,
reviews: []review{makeReview(1, "CHANGES_REQUESTED", "alice")},
expectedState: gitprovider.PRStateOpen,
changesRequested: true,
},
{
name: "OpenPR/ChangesRequestedThenApproved",
pullJSON: `{"state":"open","merged":false,"draft":false,"additions":10,"deletions":5,"changed_files":3,"head":{"sha":"abc123"}}`,
reviews: []review{
makeReview(1, "CHANGES_REQUESTED", "alice"),
makeReview(2, "APPROVED", "alice"),
},
expectedState: gitprovider.PRStateOpen,
changesRequested: false,
},
{
name: "OpenPR/ChangesRequestedThenDismissed",
pullJSON: `{"state":"open","merged":false,"draft":false,"additions":10,"deletions":5,"changed_files":3,"head":{"sha":"abc123"}}`,
reviews: []review{
makeReview(1, "CHANGES_REQUESTED", "alice"),
makeReview(2, "DISMISSED", "alice"),
},
expectedState: gitprovider.PRStateOpen,
changesRequested: false,
},
{
name: "OpenPR/MultipleReviewersMixed",
pullJSON: `{"state":"open","merged":false,"draft":false,"additions":10,"deletions":5,"changed_files":3,"head":{"sha":"abc123"}}`,
reviews: []review{
makeReview(1, "APPROVED", "alice"),
makeReview(2, "CHANGES_REQUESTED", "bob"),
},
expectedState: gitprovider.PRStateOpen,
changesRequested: true,
},
{
name: "OpenPR/CommentedDoesNotAffect",
pullJSON: `{"state":"open","merged":false,"draft":false,"additions":10,"deletions":5,"changed_files":3,"head":{"sha":"abc123"}}`,
reviews: []review{
makeReview(1, "COMMENTED", "alice"),
},
expectedState: gitprovider.PRStateOpen,
changesRequested: false,
},
{
name: "MergedPR",
pullJSON: `{"state":"closed","merged":true,"draft":false,"additions":10,"deletions":5,"changed_files":3,"head":{"sha":"abc123"}}`,
reviews: []review{},
expectedState: gitprovider.PRStateMerged,
changesRequested: false,
},
{
name: "DraftPR",
pullJSON: `{"state":"open","merged":false,"draft":true,"additions":10,"deletions":5,"changed_files":3,"head":{"sha":"abc123"}}`,
reviews: []review{},
expectedState: gitprovider.PRStateOpen,
expectedDraft: true,
changesRequested: false,
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
reviewsJSON, err := json.Marshal(tc.reviews)
require.NoError(t, err)
mux := http.NewServeMux()
mux.HandleFunc("/api/v3/repos/owner/repo/pulls/1/reviews", func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write(reviewsJSON)
})
mux.HandleFunc("/api/v3/repos/owner/repo/pulls/1", func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(tc.pullJSON))
})
srv := httptest.NewServer(mux)
defer srv.Close()
gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client())
require.NotNil(t, gp)
before := time.Now().UTC()
status, err := gp.FetchPullRequestStatus(
context.Background(),
"test-token",
gitprovider.PRRef{Owner: "owner", Repo: "repo", Number: 1},
)
require.NoError(t, err)
assert.Equal(t, tc.expectedState, status.State)
assert.Equal(t, tc.expectedDraft, status.Draft)
assert.Equal(t, tc.changesRequested, status.ChangesRequested)
assert.Equal(t, "abc123", status.HeadSHA)
assert.Equal(t, int32(10), status.DiffStats.Additions)
assert.Equal(t, int32(5), status.DiffStats.Deletions)
assert.Equal(t, int32(3), status.DiffStats.ChangedFiles)
assert.False(t, status.FetchedAt.IsZero())
assert.True(t, !status.FetchedAt.Before(before), "FetchedAt should be >= test start time")
})
}
}
func TestResolveBranchPullRequest(t *testing.T) {
t.Parallel()
t.Run("Found", func(t *testing.T) {
t.Parallel()
var srvURL string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Verify query parameters.
assert.Equal(t, "open", r.URL.Query().Get("state"))
assert.Equal(t, "owner:feat", r.URL.Query().Get("head"))
w.Header().Set("Content-Type", "application/json")
// Use the test server's URL so ParsePullRequestURL
// matches the provider's derived web host.
htmlURL := fmt.Sprintf("https://%s/owner/repo/pull/42",
strings.TrimPrefix(strings.TrimPrefix(srvURL, "http://"), "https://"))
_, _ = w.Write([]byte(fmt.Sprintf(`[{"html_url":%q,"number":42}]`, htmlURL)))
}))
defer srv.Close()
srvURL = srv.URL
gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client())
require.NotNil(t, gp)
prRef, err := gp.ResolveBranchPullRequest(
context.Background(),
"test-token",
gitprovider.BranchRef{Owner: "owner", Repo: "repo", Branch: "feat"},
)
require.NoError(t, err)
require.NotNil(t, prRef)
assert.Equal(t, "owner", prRef.Owner)
assert.Equal(t, "repo", prRef.Repo)
assert.Equal(t, 42, prRef.Number)
})
t.Run("NoneOpen", func(t *testing.T) {
t.Parallel()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`[]`))
}))
defer srv.Close()
gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client())
require.NotNil(t, gp)
prRef, err := gp.ResolveBranchPullRequest(
context.Background(),
"test-token",
gitprovider.BranchRef{Owner: "owner", Repo: "repo", Branch: "feat"},
)
require.NoError(t, err)
assert.Nil(t, prRef)
})
t.Run("InvalidHTMLURL", func(t *testing.T) {
t.Parallel()
// If html_url can't be parsed as a PR URL, ResolveBranchPullRequest
// returns nil, nil.
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`[{"html_url":"not-a-valid-url","number":42}]`))
}))
defer srv.Close()
gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client())
require.NotNil(t, gp)
prRef, err := gp.ResolveBranchPullRequest(
context.Background(),
"test-token",
gitprovider.BranchRef{Owner: "owner", Repo: "repo", Branch: "feat"},
)
require.NoError(t, err)
assert.Nil(t, prRef)
})
}
func TestFetchBranchDiff(t *testing.T) {
t.Parallel()
const smallDiff = "diff --git a/file.go b/file.go\n--- a/file.go\n+++ b/file.go\n@@ -1 +1 @@\n-old\n+new\n"
t.Run("OK", func(t *testing.T) {
t.Parallel()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/compare/") {
w.Header().Set("Content-Type", "text/plain")
_, _ = w.Write([]byte(smallDiff))
return
}
// Repo metadata.
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"default_branch":"main"}`))
}))
defer srv.Close()
gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client())
require.NotNil(t, gp)
diff, err := gp.FetchBranchDiff(
context.Background(),
"test-token",
gitprovider.BranchRef{Owner: "org", Repo: "repo", Branch: "feat"},
)
require.NoError(t, err)
assert.Equal(t, smallDiff, diff)
})
t.Run("EmptyDefaultBranch", func(t *testing.T) {
t.Parallel()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"default_branch":""}`))
}))
defer srv.Close()
gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client())
require.NotNil(t, gp)
_, err := gp.FetchBranchDiff(
context.Background(),
"test-token",
gitprovider.BranchRef{Owner: "org", Repo: "repo", Branch: "feat"},
)
require.Error(t, err)
assert.Contains(t, err.Error(), "default branch is empty")
})
t.Run("DiffTooLarge", func(t *testing.T) {
t.Parallel()
oversizeDiff := string(make([]byte, gitprovider.MaxDiffSize+1024))
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/compare/") {
w.Header().Set("Content-Type", "text/plain")
_, _ = w.Write([]byte(oversizeDiff))
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"default_branch":"main"}`))
}))
defer srv.Close()
gp := gitprovider.New("github", srv.URL+"/api/v3", srv.Client())
require.NotNil(t, gp)
_, err := gp.FetchBranchDiff(
context.Background(),
"test-token",
gitprovider.BranchRef{Owner: "org", Repo: "repo", Branch: "feat"},
)
assert.ErrorIs(t, err, gitprovider.ErrDiffTooLarge)
})
}
func TestEscapePathPreserveSlashes(t *testing.T) {
t.Parallel()
// The function is unexported, so test it indirectly via BuildBranchURL.
// A branch with a space in a segment should be escaped, but slashes preserved.
gp := gitprovider.New("github", "", nil)
require.NotNil(t, gp)
got := gp.BuildBranchURL("owner", "repo", "feat/my thing")
assert.Equal(t, "https://github.com/owner/repo/tree/feat/my%20thing", got)
}
func TestParseRetryAfter(t *testing.T) {
t.Parallel()
clk := quartz.NewMock(t)
clk.Set(time.Now())
t.Run("RetryAfterSeconds", func(t *testing.T) {
t.Parallel()
h := http.Header{}
h.Set("Retry-After", "120")
d := gitprovider.ParseRetryAfter(h, clk)
assert.Equal(t, 120*time.Second, d)
})
t.Run("XRatelimitReset", func(t *testing.T) {
t.Parallel()
future := clk.Now().Add(90 * time.Second)
t.Logf("now: %d future: %d", clk.Now().Unix(), future.Unix())
h := http.Header{}
h.Set("X-Ratelimit-Reset", strconv.FormatInt(future.Unix(), 10))
d := gitprovider.ParseRetryAfter(h, clk)
assert.WithinDuration(t, future, clk.Now().Add(d), time.Second)
})
t.Run("NoHeaders", func(t *testing.T) {
t.Parallel()
h := http.Header{}
d := gitprovider.ParseRetryAfter(h, clk)
assert.Equal(t, time.Duration(0), d)
})
t.Run("InvalidValue", func(t *testing.T) {
t.Parallel()
h := http.Header{}
h.Set("Retry-After", "not-a-number")
d := gitprovider.ParseRetryAfter(h, clk)
assert.Equal(t, time.Duration(0), d)
})
t.Run("RetryAfterTakesPrecedence", func(t *testing.T) {
t.Parallel()
h := http.Header{}
h.Set("Retry-After", "60")
h.Set("X-Ratelimit-Reset", strconv.FormatInt(
clk.Now().Unix()+120, 10,
))
d := gitprovider.ParseRetryAfter(h, clk)
assert.Equal(t, 60*time.Second, d)
})
}
@@ -0,0 +1,179 @@
package gitprovider
import (
"context"
"fmt"
"net/http"
"time"
"golang.org/x/xerrors"
"github.com/coder/quartz"
)
// providerOptions holds optional configuration for provider
// construction.
type providerOptions struct {
clock quartz.Clock
}
// Option configures optional behavior for a Provider.
type Option func(*providerOptions)
// WithClock sets the clock used by the provider. Defaults to
// quartz.NewReal() if not provided.
func WithClock(c quartz.Clock) Option {
return func(o *providerOptions) {
o.clock = c
}
}
// PRState is the normalized state of a pull/merge request across
// all providers.
type PRState string
const (
PRStateOpen PRState = "open"
PRStateClosed PRState = "closed"
PRStateMerged PRState = "merged"
)
// PRRef identifies a pull request on any provider.
type PRRef struct {
// Owner is the repository owner / project / workspace.
Owner string
// Repo is the repository name or slug.
Repo string
// Number is the PR number / IID / index.
Number int
}
// BranchRef identifies a branch in a repository, used for
// branch-to-PR resolution.
type BranchRef struct {
Owner string
Repo string
Branch string
}
// DiffStats summarizes the size of a PR's changes.
type DiffStats struct {
Additions int32
Deletions int32
ChangedFiles int32
}
// PRStatus is the complete status of a pull/merge request.
// This is the universal return type that all providers populate.
type PRStatus struct {
// State is the PR's lifecycle state.
State PRState
// Draft indicates the PR is marked as draft/WIP.
Draft bool
// HeadSHA is the SHA of the head commit.
HeadSHA string
// DiffStats summarizes additions/deletions/files changed.
DiffStats DiffStats
// ChangesRequested is a convenience boolean: true if any
// reviewer's current state is "changes_requested".
ChangesRequested bool
// FetchedAt is when this status was fetched.
FetchedAt time.Time
}
// MaxDiffSize is the maximum number of bytes read from a diff
// response. Diffs exceeding this limit are rejected with
// ErrDiffTooLarge.
const MaxDiffSize = 4 << 20 // 4 MiB
// ErrDiffTooLarge is returned when a diff exceeds MaxDiffSize.
var ErrDiffTooLarge = xerrors.Errorf("diff exceeds maximum size of %d bytes", MaxDiffSize)
// Provider defines the interface that all Git hosting providers
// implement. Each method is designed to minimize API round-trips
// for the specific provider.
type Provider interface {
// FetchPullRequestStatus retrieves the complete status of a
// pull request in the minimum number of API calls for this
// provider.
FetchPullRequestStatus(ctx context.Context, token string, ref PRRef) (*PRStatus, error)
// ResolveBranchPullRequest finds the open PR (if any) for
// the given branch. Returns nil, nil if no open PR exists.
ResolveBranchPullRequest(ctx context.Context, token string, ref BranchRef) (*PRRef, error)
// FetchPullRequestDiff returns the raw unified diff for a
// pull request. This uses the PR's actual base branch (which
// may differ from the repo default branch, e.g. a PR
// targeting "staging" instead of "main"), so it matches what
// the provider shows on the PR's "Files changed" tab.
// Returns ErrDiffTooLarge if the diff exceeds MaxDiffSize.
FetchPullRequestDiff(ctx context.Context, token string, ref PRRef) (string, error)
// FetchBranchDiff returns the diff of a branch compared
// against the repository's default branch. This is the
// fallback when no pull request exists yet (e.g. the agent
// pushed a branch but hasn't opened a PR). Returns
// ErrDiffTooLarge if the diff exceeds MaxDiffSize.
FetchBranchDiff(ctx context.Context, token string, ref BranchRef) (string, error)
// ParseRepositoryOrigin parses a remote origin URL (HTTPS
// or SSH) into owner and repo components, returning the
// normalized HTTPS URL. Returns false if the URL does not
// match this provider.
ParseRepositoryOrigin(raw string) (owner, repo, normalizedOrigin string, ok bool)
// ParsePullRequestURL parses a pull request URL into a
// PRRef. Returns false if the URL does not match this
// provider.
ParsePullRequestURL(raw string) (PRRef, bool)
// NormalizePullRequestURL normalizes a pull request URL,
// stripping trailing punctuation, query strings, and
// fragments. Returns empty string if the URL does not
// match this provider.
NormalizePullRequestURL(raw string) string
// BuildBranchURL constructs a URL to view a branch on
// the provider's web UI.
BuildBranchURL(owner, repo, branch string) string
// BuildRepositoryURL constructs a URL to view a repository
// on the provider's web UI.
BuildRepositoryURL(owner, repo string) string
// BuildPullRequestURL constructs a URL to view a pull
// request on the provider's web UI.
BuildPullRequestURL(ref PRRef) string
}
// New creates a Provider for the given provider type and API base
// URL. Returns nil if the provider type is not a supported git
// provider.
func New(providerType string, apiBaseURL string, httpClient *http.Client, opts ...Option) Provider {
o := providerOptions{}
for _, opt := range opts {
opt(&o)
}
if o.clock == nil {
o.clock = quartz.NewReal()
}
switch providerType {
case "github":
return newGitHub(apiBaseURL, httpClient, o.clock)
default:
// Other providers (gitlab, bitbucket-cloud, etc.) will be
// added here as they are implemented.
return nil
}
}
// RateLimitError indicates the git provider's API rate limit was hit.
type RateLimitError struct {
RetryAfter time.Time
}
func (e *RateLimitError) Error() string {
return fmt.Sprintf("rate limited until %s", e.RetryAfter.Format(time.RFC3339))
}
+230
View File
@@ -0,0 +1,230 @@
package gitsync
import (
"context"
"database/sql"
"errors"
"fmt"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/externalauth/gitprovider"
"github.com/coder/quartz"
)
const (
// DiffStatusTTL is how long a successfully refreshed
// diff status remains fresh before becoming stale again.
DiffStatusTTL = 120 * time.Second
)
// ProviderResolver maps a git remote origin to the gitprovider
// that handles it. Returns nil if no provider matches.
type ProviderResolver func(origin string) gitprovider.Provider
var ErrNoTokenAvailable error = errors.New("no token available")
// TokenResolver obtains the user's git access token for a given
// remote origin. Should return nil if no token is available, in
// which case ErrNoTokenAvailable will be returned.
type TokenResolver func(
ctx context.Context,
userID uuid.UUID,
origin string,
) (*string, error)
// Refresher contains the stateless business logic for fetching
// fresh PR data from a git provider given a stale
// database.ChatDiffStatus row.
type Refresher struct {
providers ProviderResolver
tokens TokenResolver
logger slog.Logger
clock quartz.Clock
}
// NewRefresher creates a Refresher with the given dependency
// functions.
func NewRefresher(
providers ProviderResolver,
tokens TokenResolver,
logger slog.Logger,
clock quartz.Clock,
) *Refresher {
return &Refresher{
providers: providers,
tokens: tokens,
logger: logger,
clock: clock,
}
}
// RefreshRequest pairs a stale row with the chat owner who
// holds the git token needed for API calls.
type RefreshRequest struct {
Row database.ChatDiffStatus
OwnerID uuid.UUID
}
// RefreshResult is the outcome for a single row.
// - Params != nil, Error == nil → success, caller should upsert.
// - Params == nil, Error == nil → no PR yet, caller should skip.
// - Params == nil, Error != nil → row-level failure.
type RefreshResult struct {
Request RefreshRequest
Params *database.UpsertChatDiffStatusParams
Error error
}
// groupKey identifies a unique (owner, origin) pair so that
// provider and token resolution happen once per group.
type groupKey struct {
ownerID uuid.UUID
origin string
}
// Refresh fetches fresh PR data for a batch of stale rows.
// Rows are grouped internally by (ownerID, origin) so that
// provider and token resolution happen once per group. A
// top-level error is returned only when the entire batch
// fails catastrophically. Per-row outcomes are in the
// returned RefreshResult slice (one per input request, same
// order).
func (r *Refresher) Refresh(
ctx context.Context,
requests []RefreshRequest,
) ([]RefreshResult, error) {
results := make([]RefreshResult, len(requests))
for i, req := range requests {
results[i].Request = req
}
// Group request indices by (ownerID, origin).
groups := make(map[groupKey][]int)
for i, req := range requests {
key := groupKey{
ownerID: req.OwnerID,
origin: req.Row.GitRemoteOrigin,
}
groups[key] = append(groups[key], i)
}
for key, indices := range groups {
provider := r.providers(key.origin)
if provider == nil {
err := xerrors.Errorf("no provider for origin %q", key.origin)
for _, i := range indices {
results[i].Error = err
}
continue
}
token, err := r.tokens(ctx, key.ownerID, key.origin)
if err != nil {
err = xerrors.Errorf("resolve token: %w", err)
} else if token == nil || len(*token) == 0 {
err = ErrNoTokenAvailable
}
if err != nil {
for _, i := range indices {
results[i].Error = err
}
continue
}
// This is technically unnecessary but kept here as a future molly-guard.
if token == nil {
continue
}
for i, idx := range indices {
req := requests[idx]
params, err := r.refreshOne(ctx, provider, *token, req.Row)
results[idx] = RefreshResult{Request: req, Params: params, Error: err}
// If rate-limited, skip remaining rows in this group.
var rlErr *gitprovider.RateLimitError
if errors.As(err, &rlErr) {
for _, remaining := range indices[i+1:] {
results[remaining] = RefreshResult{
Request: requests[remaining],
Error: fmt.Errorf("skipped: %w", rlErr),
}
}
break
}
}
}
return results, nil
}
// refreshOne processes a single row using an already-resolved
// provider and token. This is the old Refresh logic, unchanged.
func (r *Refresher) refreshOne(
ctx context.Context,
provider gitprovider.Provider,
token string,
row database.ChatDiffStatus,
) (*database.UpsertChatDiffStatusParams, error) {
var ref gitprovider.PRRef
var prURL string
if row.Url.Valid && row.Url.String != "" {
// Row already has a PR URL — parse it directly.
parsed, ok := provider.ParsePullRequestURL(row.Url.String)
if !ok {
return nil, xerrors.Errorf("parse pull request URL %q", row.Url.String)
}
ref = parsed
prURL = row.Url.String
} else {
// No PR URL — resolve owner/repo from the remote origin,
// then look up the open PR for this branch.
owner, repo, _, ok := provider.ParseRepositoryOrigin(row.GitRemoteOrigin)
if !ok {
return nil, xerrors.Errorf("parse repository origin %q", row.GitRemoteOrigin)
}
resolved, err := provider.ResolveBranchPullRequest(ctx, token, gitprovider.BranchRef{
Owner: owner,
Repo: repo,
Branch: row.GitBranch,
})
if err != nil {
return nil, xerrors.Errorf("resolve branch pull request: %w", err)
}
if resolved == nil {
// No PR exists yet for this branch.
return nil, nil
}
ref = *resolved
prURL = provider.BuildPullRequestURL(ref)
}
status, err := provider.FetchPullRequestStatus(ctx, token, ref)
if err != nil {
return nil, xerrors.Errorf("fetch pull request status: %w", err)
}
now := r.clock.Now().UTC()
params := &database.UpsertChatDiffStatusParams{
ChatID: row.ChatID,
Url: sql.NullString{String: prURL, Valid: prURL != ""},
PullRequestState: sql.NullString{
String: string(status.State),
Valid: status.State != "",
},
ChangesRequested: status.ChangesRequested,
Additions: status.DiffStats.Additions,
Deletions: status.DiffStats.Deletions,
ChangedFiles: status.DiffStats.ChangedFiles,
RefreshedAt: now,
StaleAt: now.Add(DiffStatusTTL),
}
return params, nil
}
+775
View File
@@ -0,0 +1,775 @@
package gitsync_test
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/externalauth/gitprovider"
"github.com/coder/coder/v2/coderd/gitsync"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/quartz"
)
// mockProvider implements gitprovider.Provider with function fields
// so each test can wire only the methods it needs. Any method left
// nil panics with "unexpected call".
type mockProvider struct {
fetchPullRequestStatus func(ctx context.Context, token string, ref gitprovider.PRRef) (*gitprovider.PRStatus, error)
resolveBranchPR func(ctx context.Context, token string, ref gitprovider.BranchRef) (*gitprovider.PRRef, error)
fetchPullRequestDiff func(ctx context.Context, token string, ref gitprovider.PRRef) (string, error)
fetchBranchDiff func(ctx context.Context, token string, ref gitprovider.BranchRef) (string, error)
parseRepositoryOrigin func(raw string) (string, string, string, bool)
parsePullRequestURL func(raw string) (gitprovider.PRRef, bool)
normalizePullRequestURL func(raw string) string
buildBranchURL func(owner, repo, branch string) string
buildRepositoryURL func(owner, repo string) string
buildPullRequestURL func(ref gitprovider.PRRef) string
}
func (m *mockProvider) FetchPullRequestStatus(ctx context.Context, token string, ref gitprovider.PRRef) (*gitprovider.PRStatus, error) {
if m.fetchPullRequestStatus == nil {
panic("unexpected call to FetchPullRequestStatus")
}
return m.fetchPullRequestStatus(ctx, token, ref)
}
func (m *mockProvider) ResolveBranchPullRequest(ctx context.Context, token string, ref gitprovider.BranchRef) (*gitprovider.PRRef, error) {
if m.resolveBranchPR == nil {
panic("unexpected call to ResolveBranchPullRequest")
}
return m.resolveBranchPR(ctx, token, ref)
}
func (m *mockProvider) FetchPullRequestDiff(ctx context.Context, token string, ref gitprovider.PRRef) (string, error) {
if m.fetchPullRequestDiff == nil {
panic("unexpected call to FetchPullRequestDiff")
}
return m.fetchPullRequestDiff(ctx, token, ref)
}
func (m *mockProvider) FetchBranchDiff(ctx context.Context, token string, ref gitprovider.BranchRef) (string, error) {
if m.fetchBranchDiff == nil {
panic("unexpected call to FetchBranchDiff")
}
return m.fetchBranchDiff(ctx, token, ref)
}
func (m *mockProvider) ParseRepositoryOrigin(raw string) (string, string, string, bool) {
if m.parseRepositoryOrigin == nil {
panic("unexpected call to ParseRepositoryOrigin")
}
return m.parseRepositoryOrigin(raw)
}
func (m *mockProvider) ParsePullRequestURL(raw string) (gitprovider.PRRef, bool) {
if m.parsePullRequestURL == nil {
panic("unexpected call to ParsePullRequestURL")
}
return m.parsePullRequestURL(raw)
}
func (m *mockProvider) NormalizePullRequestURL(raw string) string {
if m.normalizePullRequestURL == nil {
panic("unexpected call to NormalizePullRequestURL")
}
return m.normalizePullRequestURL(raw)
}
func (m *mockProvider) BuildBranchURL(owner, repo, branch string) string {
if m.buildBranchURL == nil {
panic("unexpected call to BuildBranchURL")
}
return m.buildBranchURL(owner, repo, branch)
}
func (m *mockProvider) BuildRepositoryURL(owner, repo string) string {
if m.buildRepositoryURL == nil {
panic("unexpected call to BuildRepositoryURL")
}
return m.buildRepositoryURL(owner, repo)
}
func (m *mockProvider) BuildPullRequestURL(ref gitprovider.PRRef) string {
if m.buildPullRequestURL == nil {
panic("unexpected call to BuildPullRequestURL")
}
return m.buildPullRequestURL(ref)
}
func TestRefresher_WithPRURL(t *testing.T) {
t.Parallel()
mp := &mockProvider{
parsePullRequestURL: func(raw string) (gitprovider.PRRef, bool) {
return gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 42}, true
},
fetchPullRequestStatus: func(_ context.Context, _ string, _ gitprovider.PRRef) (*gitprovider.PRStatus, error) {
return &gitprovider.PRStatus{
State: gitprovider.PRStateOpen,
DiffStats: gitprovider.DiffStats{
Additions: 10,
Deletions: 5,
ChangedFiles: 3,
},
}, nil
},
}
providers := func(_ string) gitprovider.Provider { return mp }
tokens := func(_ context.Context, _ uuid.UUID, _ string) (*string, error) {
return ptr.Ref("test-token"), nil
}
r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), quartz.NewReal())
chatID := uuid.New()
row := database.ChatDiffStatus{
ChatID: chatID,
Url: sql.NullString{String: "https://github.com/org/repo/pull/42", Valid: true},
GitRemoteOrigin: "https://github.com/org/repo",
GitBranch: "feature",
}
ownerID := uuid.New()
results, err := r.Refresh(context.Background(), []gitsync.RefreshRequest{
{Row: row, OwnerID: ownerID},
})
require.NoError(t, err)
require.Len(t, results, 1)
res := results[0]
require.NoError(t, res.Error)
require.NotNil(t, res.Params)
assert.Equal(t, chatID, res.Params.ChatID)
assert.Equal(t, "open", res.Params.PullRequestState.String)
assert.True(t, res.Params.PullRequestState.Valid)
assert.Equal(t, int32(10), res.Params.Additions)
assert.Equal(t, int32(5), res.Params.Deletions)
assert.Equal(t, int32(3), res.Params.ChangedFiles)
// StaleAt should be ~120s after RefreshedAt.
diff := res.Params.StaleAt.Sub(res.Params.RefreshedAt)
assert.InDelta(t, 120, diff.Seconds(), 5)
}
func TestRefresher_BranchResolvesToPR(t *testing.T) {
t.Parallel()
mp := &mockProvider{
parseRepositoryOrigin: func(_ string) (string, string, string, bool) {
return "org", "repo", "https://github.com/org/repo", true
},
resolveBranchPR: func(_ context.Context, _ string, _ gitprovider.BranchRef) (*gitprovider.PRRef, error) {
return &gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 7}, nil
},
fetchPullRequestStatus: func(_ context.Context, _ string, _ gitprovider.PRRef) (*gitprovider.PRStatus, error) {
return &gitprovider.PRStatus{State: gitprovider.PRStateOpen}, nil
},
buildPullRequestURL: func(_ gitprovider.PRRef) string {
return "https://github.com/org/repo/pull/7"
},
}
providers := func(_ string) gitprovider.Provider { return mp }
tokens := func(_ context.Context, _ uuid.UUID, _ string) (*string, error) {
return ptr.Ref("test-token"), nil
}
r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), quartz.NewReal())
row := database.ChatDiffStatus{
ChatID: uuid.New(),
Url: sql.NullString{},
GitRemoteOrigin: "https://github.com/org/repo",
GitBranch: "feature",
}
ownerID := uuid.New()
results, err := r.Refresh(context.Background(), []gitsync.RefreshRequest{
{Row: row, OwnerID: ownerID},
})
require.NoError(t, err)
require.Len(t, results, 1)
res := results[0]
require.NoError(t, res.Error)
require.NotNil(t, res.Params)
assert.Contains(t, res.Params.Url.String, "pull/7")
assert.True(t, res.Params.Url.Valid)
assert.Equal(t, "open", res.Params.PullRequestState.String)
}
func TestRefresher_BranchNoPRYet(t *testing.T) {
t.Parallel()
mp := &mockProvider{
parseRepositoryOrigin: func(_ string) (string, string, string, bool) {
return "org", "repo", "https://github.com/org/repo", true
},
resolveBranchPR: func(_ context.Context, _ string, _ gitprovider.BranchRef) (*gitprovider.PRRef, error) {
return nil, nil
},
}
providers := func(_ string) gitprovider.Provider { return mp }
tokens := func(_ context.Context, _ uuid.UUID, _ string) (*string, error) {
return ptr.Ref("test-token"), nil
}
r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), quartz.NewReal())
row := database.ChatDiffStatus{
ChatID: uuid.New(),
Url: sql.NullString{},
GitRemoteOrigin: "https://github.com/org/repo",
GitBranch: "feature",
}
ownerID := uuid.New()
results, err := r.Refresh(context.Background(), []gitsync.RefreshRequest{
{Row: row, OwnerID: ownerID},
})
require.NoError(t, err)
require.Len(t, results, 1)
res := results[0]
assert.NoError(t, res.Error)
assert.Nil(t, res.Params)
}
func TestRefresher_NoProviderForOrigin(t *testing.T) {
t.Parallel()
providers := func(_ string) gitprovider.Provider { return nil }
tokens := func(_ context.Context, _ uuid.UUID, _ string) (*string, error) {
return ptr.Ref("test-token"), nil
}
r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), quartz.NewReal())
row := database.ChatDiffStatus{
ChatID: uuid.New(),
Url: sql.NullString{String: "https://example.com/pr/1", Valid: true},
GitRemoteOrigin: "https://example.com/org/repo",
GitBranch: "feature",
}
ownerID := uuid.New()
results, err := r.Refresh(context.Background(), []gitsync.RefreshRequest{
{Row: row, OwnerID: ownerID},
})
require.NoError(t, err)
require.Len(t, results, 1)
res := results[0]
assert.Nil(t, res.Params)
require.Error(t, res.Error)
assert.Contains(t, res.Error.Error(), "no provider")
}
func TestRefresher_TokenResolutionFails(t *testing.T) {
t.Parallel()
var fetchCalled atomic.Bool
mp := &mockProvider{
fetchPullRequestStatus: func(_ context.Context, _ string, _ gitprovider.PRRef) (*gitprovider.PRStatus, error) {
fetchCalled.Store(true)
return nil, errors.New("should not be called")
},
parsePullRequestURL: func(_ string) (gitprovider.PRRef, bool) {
return gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 1}, true
},
}
providers := func(_ string) gitprovider.Provider { return mp }
tokens := func(_ context.Context, _ uuid.UUID, _ string) (*string, error) {
return nil, errors.New("token lookup failed")
}
r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), quartz.NewReal())
row := database.ChatDiffStatus{
ChatID: uuid.New(),
Url: sql.NullString{String: "https://github.com/org/repo/pull/1", Valid: true},
GitRemoteOrigin: "https://github.com/org/repo",
GitBranch: "feature",
}
ownerID := uuid.New()
results, err := r.Refresh(context.Background(), []gitsync.RefreshRequest{
{Row: row, OwnerID: ownerID},
})
require.NoError(t, err)
require.Len(t, results, 1)
res := results[0]
assert.Nil(t, res.Params)
require.Error(t, res.Error)
assert.False(t, fetchCalled.Load(), "FetchPullRequestStatus should not be called when token resolution fails")
}
func TestRefresher_EmptyToken(t *testing.T) {
t.Parallel()
mp := &mockProvider{}
providers := func(_ string) gitprovider.Provider { return mp }
tokens := func(_ context.Context, _ uuid.UUID, _ string) (*string, error) {
return ptr.Ref(""), nil
}
r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), quartz.NewReal())
row := database.ChatDiffStatus{
ChatID: uuid.New(),
Url: sql.NullString{String: "https://github.com/org/repo/pull/1", Valid: true},
GitRemoteOrigin: "https://github.com/org/repo",
GitBranch: "feature",
}
ownerID := uuid.New()
results, err := r.Refresh(context.Background(), []gitsync.RefreshRequest{
{Row: row, OwnerID: ownerID},
})
require.NoError(t, err)
require.Len(t, results, 1)
res := results[0]
assert.Nil(t, res.Params)
require.ErrorIs(t, res.Error, gitsync.ErrNoTokenAvailable)
}
func TestRefresher_ProviderFetchFails(t *testing.T) {
t.Parallel()
mp := &mockProvider{
parsePullRequestURL: func(_ string) (gitprovider.PRRef, bool) {
return gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 42}, true
},
fetchPullRequestStatus: func(_ context.Context, _ string, _ gitprovider.PRRef) (*gitprovider.PRStatus, error) {
return nil, errors.New("api error")
},
}
providers := func(_ string) gitprovider.Provider { return mp }
tokens := func(_ context.Context, _ uuid.UUID, _ string) (*string, error) {
return ptr.Ref("test-token"), nil
}
r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), quartz.NewReal())
row := database.ChatDiffStatus{
ChatID: uuid.New(),
Url: sql.NullString{String: "https://github.com/org/repo/pull/42", Valid: true},
GitRemoteOrigin: "https://github.com/org/repo",
GitBranch: "feature",
}
ownerID := uuid.New()
results, err := r.Refresh(context.Background(), []gitsync.RefreshRequest{
{Row: row, OwnerID: ownerID},
})
require.NoError(t, err)
require.Len(t, results, 1)
res := results[0]
assert.Nil(t, res.Params)
require.Error(t, res.Error)
assert.Contains(t, res.Error.Error(), "api error")
}
func TestRefresher_PRURLParseFailure(t *testing.T) {
t.Parallel()
mp := &mockProvider{
parsePullRequestURL: func(_ string) (gitprovider.PRRef, bool) {
return gitprovider.PRRef{}, false
},
}
providers := func(_ string) gitprovider.Provider { return mp }
tokens := func(_ context.Context, _ uuid.UUID, _ string) (*string, error) {
return ptr.Ref("test-token"), nil
}
r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), quartz.NewReal())
row := database.ChatDiffStatus{
ChatID: uuid.New(),
Url: sql.NullString{String: "https://github.com/org/repo/not-a-pr", Valid: true},
GitRemoteOrigin: "https://github.com/org/repo",
GitBranch: "feature",
}
ownerID := uuid.New()
results, err := r.Refresh(context.Background(), []gitsync.RefreshRequest{
{Row: row, OwnerID: ownerID},
})
require.NoError(t, err)
require.Len(t, results, 1)
res := results[0]
assert.Nil(t, res.Params)
require.Error(t, res.Error)
}
func TestRefresher_BatchGroupsByOwnerAndOrigin(t *testing.T) {
t.Parallel()
mp := &mockProvider{
parsePullRequestURL: func(_ string) (gitprovider.PRRef, bool) {
return gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 1}, true
},
fetchPullRequestStatus: func(_ context.Context, _ string, _ gitprovider.PRRef) (*gitprovider.PRStatus, error) {
return &gitprovider.PRStatus{State: gitprovider.PRStateOpen}, nil
},
}
providers := func(_ string) gitprovider.Provider { return mp }
var tokenCalls atomic.Int32
tokens := func(_ context.Context, _ uuid.UUID, _ string) (*string, error) {
tokenCalls.Add(1)
return ptr.Ref("test-token"), nil
}
r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), quartz.NewReal())
ownerID := uuid.New()
originA := "https://github.com/org/repo"
originB := "https://gitlab.com/org/repo"
requests := []gitsync.RefreshRequest{
{
Row: database.ChatDiffStatus{
ChatID: uuid.New(),
Url: sql.NullString{String: "https://github.com/org/repo/pull/1", Valid: true},
GitRemoteOrigin: originA,
GitBranch: "feature-1",
},
OwnerID: ownerID,
},
{
Row: database.ChatDiffStatus{
ChatID: uuid.New(),
Url: sql.NullString{String: "https://github.com/org/repo/pull/1", Valid: true},
GitRemoteOrigin: originA,
GitBranch: "feature-2",
},
OwnerID: ownerID,
},
{
Row: database.ChatDiffStatus{
ChatID: uuid.New(),
Url: sql.NullString{String: "https://gitlab.com/org/repo/pull/1", Valid: true},
GitRemoteOrigin: originB,
GitBranch: "feature-3",
},
OwnerID: ownerID,
},
}
results, err := r.Refresh(context.Background(), requests)
require.NoError(t, err)
require.Len(t, results, 3)
for i, res := range results {
require.NoError(t, res.Error, "result[%d] should not have an error", i)
require.NotNil(t, res.Params, "result[%d] should have params", i)
}
// Two distinct (ownerID, origin) groups → exactly 2 token
// resolution calls.
assert.Equal(t, int32(2), tokenCalls.Load(),
"TokenResolver should be called once per (owner, origin) group")
}
func TestRefresher_UsesInjectedClock(t *testing.T) {
t.Parallel()
mClock := quartz.NewMock(t)
fixedTime := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC)
mClock.Set(fixedTime)
mp := &mockProvider{
parsePullRequestURL: func(raw string) (gitprovider.PRRef, bool) {
return gitprovider.PRRef{Owner: "org", Repo: "repo", Number: 42}, true
},
fetchPullRequestStatus: func(_ context.Context, _ string, _ gitprovider.PRRef) (*gitprovider.PRStatus, error) {
return &gitprovider.PRStatus{
State: gitprovider.PRStateOpen,
DiffStats: gitprovider.DiffStats{
Additions: 10,
Deletions: 5,
ChangedFiles: 3,
},
}, nil
},
}
providers := func(_ string) gitprovider.Provider { return mp }
tokens := func(_ context.Context, _ uuid.UUID, _ string) (*string, error) {
return ptr.Ref("test-token"), nil
}
r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), mClock)
chatID := uuid.New()
row := database.ChatDiffStatus{
ChatID: chatID,
Url: sql.NullString{String: "https://github.com/org/repo/pull/42", Valid: true},
GitRemoteOrigin: "https://github.com/org/repo",
GitBranch: "feature",
}
ownerID := uuid.New()
results, err := r.Refresh(context.Background(), []gitsync.RefreshRequest{
{Row: row, OwnerID: ownerID},
})
require.NoError(t, err)
require.Len(t, results, 1)
res := results[0]
require.NoError(t, res.Error)
require.NotNil(t, res.Params)
// The mock clock is deterministic, so times must be exact.
assert.Equal(t, fixedTime, res.Params.RefreshedAt)
assert.Equal(t, fixedTime.Add(gitsync.DiffStatusTTL), res.Params.StaleAt)
}
func TestRefresher_RateLimitSkipsRemainingInGroup(t *testing.T) {
t.Parallel()
var callCount atomic.Int32
mp := &mockProvider{
parsePullRequestURL: func(raw string) (gitprovider.PRRef, bool) {
var num int
switch {
case strings.HasSuffix(raw, "/pull/1"):
num = 1
case strings.HasSuffix(raw, "/pull/2"):
num = 2
case strings.HasSuffix(raw, "/pull/3"):
num = 3
default:
return gitprovider.PRRef{}, false
}
return gitprovider.PRRef{Owner: "org", Repo: "repo", Number: num}, true
},
fetchPullRequestStatus: func(_ context.Context, _ string, ref gitprovider.PRRef) (*gitprovider.PRStatus, error) {
call := callCount.Add(1)
switch call {
case 1:
// First call succeeds.
return &gitprovider.PRStatus{
State: gitprovider.PRStateOpen,
DiffStats: gitprovider.DiffStats{
Additions: 5,
Deletions: 2,
ChangedFiles: 1,
},
}, nil
case 2:
// Second call hits rate limit.
return nil, &gitprovider.RateLimitError{
RetryAfter: time.Now().Add(60 * time.Second),
}
default:
// Third call should never happen.
t.Fatal("FetchPullRequestStatus called more than 2 times")
return nil, nil
}
},
}
providers := func(_ string) gitprovider.Provider { return mp }
tokens := func(_ context.Context, _ uuid.UUID, _ string) (*string, error) {
return ptr.Ref("test-token"), nil
}
r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), quartz.NewReal())
ownerID := uuid.New()
origin := "https://github.com/org/repo"
requests := []gitsync.RefreshRequest{
{
Row: database.ChatDiffStatus{
ChatID: uuid.New(),
Url: sql.NullString{String: "https://github.com/org/repo/pull/1", Valid: true},
GitRemoteOrigin: origin,
GitBranch: "feat-1",
},
OwnerID: ownerID,
},
{
Row: database.ChatDiffStatus{
ChatID: uuid.New(),
Url: sql.NullString{String: "https://github.com/org/repo/pull/2", Valid: true},
GitRemoteOrigin: origin,
GitBranch: "feat-2",
},
OwnerID: ownerID,
},
{
Row: database.ChatDiffStatus{
ChatID: uuid.New(),
Url: sql.NullString{String: "https://github.com/org/repo/pull/3", Valid: true},
GitRemoteOrigin: origin,
GitBranch: "feat-3",
},
OwnerID: ownerID,
},
}
results, err := r.Refresh(context.Background(), requests)
require.NoError(t, err)
require.Len(t, results, 3)
// Row 0: success.
assert.NoError(t, results[0].Error)
assert.NotNil(t, results[0].Params)
// Row 1: rate-limited.
require.Error(t, results[1].Error)
var rlErr1 *gitprovider.RateLimitError
assert.True(t, errors.As(results[1].Error, &rlErr1),
"result[1] error should be *RateLimitError")
// Row 2: skipped due to rate limit.
require.Error(t, results[2].Error)
var rlErr2 *gitprovider.RateLimitError
assert.True(t, errors.As(results[2].Error, &rlErr2),
"result[2] error should wrap *RateLimitError")
assert.Contains(t, results[2].Error.Error(), "skipped")
// Provider should have been called exactly twice.
assert.Equal(t, int32(2), callCount.Load(),
"FetchPullRequestStatus should be called exactly 2 times")
}
func TestRefresher_CorrectTokenPerOrigin(t *testing.T) {
t.Parallel()
var tokenCalls atomic.Int32
tokens := func(_ context.Context, _ uuid.UUID, origin string) (*string, error) {
tokenCalls.Add(1)
switch {
case strings.Contains(origin, "github.com"):
return ptr.Ref("gh-public-token"), nil
case strings.Contains(origin, "ghes.corp.com"):
return ptr.Ref("ghe-private-token"), nil
default:
return nil, fmt.Errorf("unexpected origin: %s", origin)
}
}
// Track which token each FetchPullRequestStatus call received,
// keyed by chat ID. We pass the chat ID through the PRRef.Number
// field (unique per request) so FetchPullRequestStatus can
// identify which row it's processing.
var mu sync.Mutex
tokensByPR := make(map[int]string)
mp := &mockProvider{
parsePullRequestURL: func(raw string) (gitprovider.PRRef, bool) {
// Extract a unique PR number from the URL to identify
// each row inside FetchPullRequestStatus.
var num int
switch {
case strings.HasSuffix(raw, "/pull/1"):
num = 1
case strings.HasSuffix(raw, "/pull/2"):
num = 2
case strings.HasSuffix(raw, "/pull/10"):
num = 10
default:
return gitprovider.PRRef{}, false
}
return gitprovider.PRRef{Owner: "org", Repo: "repo", Number: num}, true
},
fetchPullRequestStatus: func(_ context.Context, token string, ref gitprovider.PRRef) (*gitprovider.PRStatus, error) {
mu.Lock()
tokensByPR[ref.Number] = token
mu.Unlock()
return &gitprovider.PRStatus{State: gitprovider.PRStateOpen}, nil
},
}
providers := func(_ string) gitprovider.Provider { return mp }
r := gitsync.NewRefresher(providers, tokens, slogtest.Make(t, nil), quartz.NewReal())
ownerID := uuid.New()
requests := []gitsync.RefreshRequest{
{
Row: database.ChatDiffStatus{
ChatID: uuid.New(),
Url: sql.NullString{String: "https://github.com/org/repo/pull/1", Valid: true},
GitRemoteOrigin: "https://github.com/org/repo",
GitBranch: "feature-1",
},
OwnerID: ownerID,
},
{
Row: database.ChatDiffStatus{
ChatID: uuid.New(),
Url: sql.NullString{String: "https://github.com/org/repo/pull/2", Valid: true},
GitRemoteOrigin: "https://github.com/org/repo",
GitBranch: "feature-2",
},
OwnerID: ownerID,
},
{
Row: database.ChatDiffStatus{
ChatID: uuid.New(),
Url: sql.NullString{String: "https://ghes.corp.com/org/repo/pull/10", Valid: true},
GitRemoteOrigin: "https://ghes.corp.com/org/repo",
GitBranch: "feature-3",
},
OwnerID: ownerID,
},
}
results, err := r.Refresh(context.Background(), requests)
require.NoError(t, err)
require.Len(t, results, 3)
for i, res := range results {
require.NoError(t, res.Error, "result[%d] should not have an error", i)
require.NotNil(t, res.Params, "result[%d] should have params", i)
}
// github.com rows (PR #1 and #2) should use the public token.
assert.Equal(t, "gh-public-token", tokensByPR[1],
"github.com PR #1 should use gh-public-token")
assert.Equal(t, "gh-public-token", tokensByPR[2],
"github.com PR #2 should use gh-public-token")
// ghes.corp.com row (PR #10) should use the GHE token.
assert.Equal(t, "ghe-private-token", tokensByPR[10],
"ghes.corp.com PR #10 should use ghe-private-token")
// Token resolution should be called exactly twice — once per
// (owner, origin) group.
assert.Equal(t, int32(2), tokenCalls.Load(),
"TokenResolver should be called once per (owner, origin) group")
}
+5
View File
@@ -0,0 +1,5 @@
// Package gitsyncmock contains generated mocks for the gitsync package.
package gitsyncmock
//go:generate mockgen -destination ./store.go -package gitsyncmock github.com/coder/coder/v2/coderd/gitsync Store
//go:generate mockgen -destination ./publisher.go -package gitsyncmock github.com/coder/coder/v2/coderd/gitsync EventPublisher
+56
View File
@@ -0,0 +1,56 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/coder/coder/v2/coderd/gitsync (interfaces: EventPublisher)
//
// Generated by this command:
//
// mockgen -destination ./publisher.go -package gitsyncmock github.com/coder/coder/v2/coderd/gitsync EventPublisher
//
// Package gitsyncmock is a generated GoMock package.
package gitsyncmock
import (
context "context"
reflect "reflect"
uuid "github.com/google/uuid"
gomock "go.uber.org/mock/gomock"
)
// MockEventPublisher is a mock of EventPublisher interface.
type MockEventPublisher struct {
ctrl *gomock.Controller
recorder *MockEventPublisherMockRecorder
isgomock struct{}
}
// MockEventPublisherMockRecorder is the mock recorder for MockEventPublisher.
type MockEventPublisherMockRecorder struct {
mock *MockEventPublisher
}
// NewMockEventPublisher creates a new mock instance.
func NewMockEventPublisher(ctrl *gomock.Controller) *MockEventPublisher {
mock := &MockEventPublisher{ctrl: ctrl}
mock.recorder = &MockEventPublisherMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockEventPublisher) EXPECT() *MockEventPublisherMockRecorder {
return m.recorder
}
// PublishDiffStatusChange mocks base method.
func (m *MockEventPublisher) PublishDiffStatusChange(ctx context.Context, chatID uuid.UUID) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "PublishDiffStatusChange", ctx, chatID)
ret0, _ := ret[0].(error)
return ret0
}
// PublishDiffStatusChange indicates an expected call of PublishDiffStatusChange.
func (mr *MockEventPublisherMockRecorder) PublishDiffStatusChange(ctx, chatID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PublishDiffStatusChange", reflect.TypeOf((*MockEventPublisher)(nil).PublishDiffStatusChange), ctx, chatID)
}
+116
View File
@@ -0,0 +1,116 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/coder/coder/v2/coderd/gitsync (interfaces: Store)
//
// Generated by this command:
//
// mockgen -destination ./store.go -package gitsyncmock github.com/coder/coder/v2/coderd/gitsync Store
//
// Package gitsyncmock is a generated GoMock package.
package gitsyncmock
import (
context "context"
reflect "reflect"
database "github.com/coder/coder/v2/coderd/database"
gomock "go.uber.org/mock/gomock"
)
// MockStore is a mock of Store interface.
type MockStore struct {
ctrl *gomock.Controller
recorder *MockStoreMockRecorder
isgomock struct{}
}
// MockStoreMockRecorder is the mock recorder for MockStore.
type MockStoreMockRecorder struct {
mock *MockStore
}
// NewMockStore creates a new mock instance.
func NewMockStore(ctrl *gomock.Controller) *MockStore {
mock := &MockStore{ctrl: ctrl}
mock.recorder = &MockStoreMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockStore) EXPECT() *MockStoreMockRecorder {
return m.recorder
}
// AcquireStaleChatDiffStatuses mocks base method.
func (m *MockStore) AcquireStaleChatDiffStatuses(ctx context.Context, limitVal int32) ([]database.AcquireStaleChatDiffStatusesRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AcquireStaleChatDiffStatuses", ctx, limitVal)
ret0, _ := ret[0].([]database.AcquireStaleChatDiffStatusesRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// AcquireStaleChatDiffStatuses indicates an expected call of AcquireStaleChatDiffStatuses.
func (mr *MockStoreMockRecorder) AcquireStaleChatDiffStatuses(ctx, limitVal any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcquireStaleChatDiffStatuses", reflect.TypeOf((*MockStore)(nil).AcquireStaleChatDiffStatuses), ctx, limitVal)
}
// BackoffChatDiffStatus mocks base method.
func (m *MockStore) BackoffChatDiffStatus(ctx context.Context, arg database.BackoffChatDiffStatusParams) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BackoffChatDiffStatus", ctx, arg)
ret0, _ := ret[0].(error)
return ret0
}
// BackoffChatDiffStatus indicates an expected call of BackoffChatDiffStatus.
func (mr *MockStoreMockRecorder) BackoffChatDiffStatus(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BackoffChatDiffStatus", reflect.TypeOf((*MockStore)(nil).BackoffChatDiffStatus), ctx, arg)
}
// GetChatsByOwnerID mocks base method.
func (m *MockStore) GetChatsByOwnerID(ctx context.Context, arg database.GetChatsByOwnerIDParams) ([]database.Chat, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatsByOwnerID", ctx, arg)
ret0, _ := ret[0].([]database.Chat)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatsByOwnerID indicates an expected call of GetChatsByOwnerID.
func (mr *MockStoreMockRecorder) GetChatsByOwnerID(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatsByOwnerID", reflect.TypeOf((*MockStore)(nil).GetChatsByOwnerID), ctx, arg)
}
// UpsertChatDiffStatus mocks base method.
func (m *MockStore) UpsertChatDiffStatus(ctx context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpsertChatDiffStatus", ctx, arg)
ret0, _ := ret[0].(database.ChatDiffStatus)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpsertChatDiffStatus indicates an expected call of UpsertChatDiffStatus.
func (mr *MockStoreMockRecorder) UpsertChatDiffStatus(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatDiffStatus", reflect.TypeOf((*MockStore)(nil).UpsertChatDiffStatus), ctx, arg)
}
// UpsertChatDiffStatusReference mocks base method.
func (m *MockStore) UpsertChatDiffStatusReference(ctx context.Context, arg database.UpsertChatDiffStatusReferenceParams) (database.ChatDiffStatus, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpsertChatDiffStatusReference", ctx, arg)
ret0, _ := ret[0].(database.ChatDiffStatus)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpsertChatDiffStatusReference indicates an expected call of UpsertChatDiffStatusReference.
func (mr *MockStoreMockRecorder) UpsertChatDiffStatusReference(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatDiffStatusReference", reflect.TypeOf((*MockStore)(nil).UpsertChatDiffStatusReference), ctx, arg)
}
+248
View File
@@ -0,0 +1,248 @@
package gitsync
import (
"context"
"database/sql"
"time"
"github.com/google/uuid"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/quartz"
)
const (
// defaultBatchSize is the maximum number of stale rows fetched
// per tick.
defaultBatchSize int32 = 50
// defaultInterval is the polling interval between ticks.
defaultInterval = 10 * time.Second
)
// Store is the narrow DB interface the Worker needs.
type Store interface {
AcquireStaleChatDiffStatuses(
ctx context.Context, limitVal int32,
) ([]database.AcquireStaleChatDiffStatusesRow, error)
BackoffChatDiffStatus(
ctx context.Context, arg database.BackoffChatDiffStatusParams,
) error
UpsertChatDiffStatus(
ctx context.Context, arg database.UpsertChatDiffStatusParams,
) (database.ChatDiffStatus, error)
UpsertChatDiffStatusReference(
ctx context.Context, arg database.UpsertChatDiffStatusReferenceParams,
) (database.ChatDiffStatus, error)
GetChatsByOwnerID(
ctx context.Context, arg database.GetChatsByOwnerIDParams,
) ([]database.Chat, error)
}
// EventPublisher notifies the frontend of diff status changes.
type EventPublisher interface {
PublishDiffStatusChange(ctx context.Context, chatID uuid.UUID) error
}
// Worker is a background loop that periodically refreshes stale
// chat diff statuses by delegating to a Refresher.
type Worker struct {
store Store
refresher *Refresher
publisher EventPublisher
clock quartz.Clock
logger slog.Logger
batchSize int32
interval time.Duration
done chan struct{}
}
// NewWorker creates a Worker with default batch size and interval.
func NewWorker(
store Store,
refresher *Refresher,
publisher EventPublisher,
clock quartz.Clock,
logger slog.Logger,
) *Worker {
return &Worker{
store: store,
refresher: refresher,
publisher: publisher,
clock: clock,
logger: logger,
batchSize: defaultBatchSize,
interval: defaultInterval,
done: make(chan struct{}),
}
}
// Start launches the background loop. It blocks until ctx is
// cancelled, then closes w.done.
func (w *Worker) Start(ctx context.Context) {
defer close(w.done)
ticker := w.clock.NewTicker(w.interval, "gitsync", "worker")
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
w.tick(ctx)
}
}
}
// Done returns a channel that is closed when the worker exits.
func (w *Worker) Done() <-chan struct{} {
return w.done
}
func chatDiffStatusFromRow(row database.AcquireStaleChatDiffStatusesRow) database.ChatDiffStatus {
return database.ChatDiffStatus{
ChatID: row.ChatID,
Url: row.Url,
PullRequestState: row.PullRequestState,
ChangesRequested: row.ChangesRequested,
Additions: row.Additions,
Deletions: row.Deletions,
ChangedFiles: row.ChangedFiles,
RefreshedAt: row.RefreshedAt,
StaleAt: row.StaleAt,
CreatedAt: row.CreatedAt,
UpdatedAt: row.UpdatedAt,
GitBranch: row.GitBranch,
GitRemoteOrigin: row.GitRemoteOrigin,
}
}
func (w *Worker) tick(ctx context.Context) {
acquiredRows, err := w.store.AcquireStaleChatDiffStatuses(ctx, w.batchSize)
if err != nil {
w.logger.Warn(ctx, "acquire stale chat diff statuses",
slog.Error(err))
return
}
if len(acquiredRows) == 0 {
return
}
// Build refresh requests directly from acquired rows.
requests := make([]RefreshRequest, 0, len(acquiredRows))
for _, row := range acquiredRows {
requests = append(requests, RefreshRequest{
Row: chatDiffStatusFromRow(row),
OwnerID: row.OwnerID,
})
}
results, err := w.refresher.Refresh(ctx, requests)
if err != nil {
w.logger.Warn(ctx, "batch refresh chat diff statuses",
slog.Error(err))
return
}
for _, res := range results {
if res.Error != nil {
w.logger.Debug(ctx, "refresh chat diff status",
slog.F("chat_id", res.Request.Row.ChatID),
slog.Error(res.Error))
// Back off so the row isn't retried immediately.
if err := w.store.BackoffChatDiffStatus(ctx,
database.BackoffChatDiffStatusParams{
ChatID: res.Request.Row.ChatID,
StaleAt: w.clock.Now().UTC().Add(DiffStatusTTL),
},
); err != nil {
w.logger.Warn(ctx, "backoff failed chat diff status",
slog.F("chat_id", res.Request.Row.ChatID),
slog.Error(err))
}
continue
}
if res.Params == nil {
// No PR yet — skip.
continue
}
if _, err := w.store.UpsertChatDiffStatus(ctx, *res.Params); err != nil {
w.logger.Warn(ctx, "upsert refreshed chat diff status",
slog.F("chat_id", res.Request.Row.ChatID),
slog.Error(err))
continue
}
if err := w.publisher.PublishDiffStatusChange(ctx, res.Request.Row.ChatID); err != nil {
w.logger.Debug(ctx, "publish diff status change",
slog.F("chat_id", res.Request.Row.ChatID),
slog.Error(err))
}
}
}
// MarkStale persists the git ref on all chats for a workspace,
// setting stale_at to the past so the next tick picks them up.
// Publishes a diff status event for each affected chat.
// Called from workspaceagents handlers. No goroutines spawned.
func (w *Worker) MarkStale(
ctx context.Context,
workspaceID, ownerID uuid.UUID,
branch, origin string,
) {
if branch == "" || origin == "" {
return
}
chats, err := w.store.GetChatsByOwnerID(ctx, database.GetChatsByOwnerIDParams{
OwnerID: ownerID,
})
if err != nil {
w.logger.Warn(ctx, "list chats for git ref storage",
slog.F("workspace_id", workspaceID),
slog.Error(err))
return
}
for _, chat := range filterChatsByWorkspaceID(chats, workspaceID) {
_, err := w.store.UpsertChatDiffStatusReference(ctx,
database.UpsertChatDiffStatusReferenceParams{
ChatID: chat.ID,
GitBranch: branch,
GitRemoteOrigin: origin,
StaleAt: w.clock.Now().Add(-time.Second),
Url: sql.NullString{},
},
)
if err != nil {
w.logger.Warn(ctx, "store git ref on chat diff status",
slog.F("chat_id", chat.ID),
slog.F("workspace_id", workspaceID),
slog.Error(err))
continue
}
// Notify the frontend immediately so the UI shows the
// branch info even before the worker refreshes PR data.
if pubErr := w.publisher.PublishDiffStatusChange(ctx, chat.ID); pubErr != nil {
w.logger.Debug(ctx, "publish diff status after mark stale",
slog.F("chat_id", chat.ID), slog.Error(pubErr))
}
}
}
// filterChatsByWorkspaceID returns only chats associated with
// the given workspace.
func filterChatsByWorkspaceID(
chats []database.Chat,
workspaceID uuid.UUID,
) []database.Chat {
filtered := make([]database.Chat, 0, len(chats))
for _, chat := range chats {
if !chat.WorkspaceID.Valid || chat.WorkspaceID.UUID != workspaceID {
continue
}
filtered = append(filtered, chat)
}
return filtered
}
+757
View File
@@ -0,0 +1,757 @@
package gitsync_test
import (
"context"
"database/sql"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/externalauth/gitprovider"
"github.com/coder/coder/v2/coderd/gitsync"
"github.com/coder/coder/v2/coderd/gitsync/gitsyncmock"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
// testRefresherCfg configures newTestRefresher.
type testRefresherCfg struct {
resolveBranchPR func(context.Context, string, gitprovider.BranchRef) (*gitprovider.PRRef, error)
fetchPRStatus func(context.Context, string, gitprovider.PRRef) (*gitprovider.PRStatus, error)
}
type testRefresherOpt func(*testRefresherCfg)
func withResolveBranchPR(f func(context.Context, string, gitprovider.BranchRef) (*gitprovider.PRRef, error)) testRefresherOpt {
return func(c *testRefresherCfg) { c.resolveBranchPR = f }
}
// newTestRefresher creates a Refresher backed by mock
// provider/token resolvers. The provider recognises any origin,
// resolves branches to a canned PR, and returns a canned PRStatus.
func newTestRefresher(t *testing.T, clk quartz.Clock, opts ...testRefresherOpt) *gitsync.Refresher {
t.Helper()
cfg := testRefresherCfg{
resolveBranchPR: func(context.Context, string, gitprovider.BranchRef) (*gitprovider.PRRef, error) {
return &gitprovider.PRRef{Owner: "o", Repo: "r", Number: 1}, nil
},
fetchPRStatus: func(context.Context, string, gitprovider.PRRef) (*gitprovider.PRStatus, error) {
return &gitprovider.PRStatus{
State: gitprovider.PRStateOpen,
DiffStats: gitprovider.DiffStats{
Additions: 10,
Deletions: 3,
ChangedFiles: 2,
},
}, nil
},
}
for _, o := range opts {
o(&cfg)
}
prov := &mockProvider{
parseRepositoryOrigin: func(string) (string, string, string, bool) {
return "owner", "repo", "https://github.com/owner/repo", true
},
parsePullRequestURL: func(raw string) (gitprovider.PRRef, bool) {
return gitprovider.PRRef{Owner: "owner", Repo: "repo", Number: 1}, raw != ""
},
resolveBranchPR: cfg.resolveBranchPR,
fetchPullRequestStatus: cfg.fetchPRStatus,
buildPullRequestURL: func(ref gitprovider.PRRef) string {
return fmt.Sprintf("https://github.com/%s/%s/pull/%d", ref.Owner, ref.Repo, ref.Number)
},
}
providers := func(string) gitprovider.Provider { return prov }
tokens := func(context.Context, uuid.UUID, string) (*string, error) {
return ptr.Ref("tok"), nil
}
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
return gitsync.NewRefresher(providers, tokens, logger, clk)
}
// makeAcquiredRow returns an AcquireStaleChatDiffStatusesRow with
// a non-empty branch/origin so the Refresher goes through the
// branch-resolution path.
func makeAcquiredRow(chatID, ownerID uuid.UUID) database.AcquireStaleChatDiffStatusesRow {
return database.AcquireStaleChatDiffStatusesRow{
ChatID: chatID,
GitBranch: "feature",
GitRemoteOrigin: "https://github.com/owner/repo",
StaleAt: time.Now().Add(-time.Minute),
OwnerID: ownerID,
}
}
// tickOnce traps the worker's NewTicker call, starts the worker,
// fires one tick, waits for it to finish by observing the given
// tickDone channel, then shuts the worker down. The tickDone
// channel must be closed when the last expected operation in the
// tick completes. For tests where the tick does nothing (e.g. 0
// stale rows or store error), tickDone should be closed inside
// acquireStaleChatDiffStatuses.
func tickOnce(
ctx context.Context,
t *testing.T,
mClock *quartz.Mock,
worker *gitsync.Worker,
tickDone <-chan struct{},
) {
t.Helper()
trap := mClock.Trap().NewTicker("gitsync", "worker")
defer trap.Close()
workerCtx, cancel := context.WithCancel(ctx)
defer cancel()
go worker.Start(workerCtx)
// Wait for the worker to create its ticker.
trap.MustWait(ctx).MustRelease(ctx)
// Fire one tick. The waiter resolves when the channel receive
// completes, not when w.tick() returns, so we use tickDone to
// know when to proceed.
_, w := mClock.AdvanceNext()
w.MustWait(ctx)
// Wait for the tick's business logic to finish.
select {
case <-tickDone:
case <-ctx.Done():
t.Fatal("timed out waiting for tick to complete")
}
cancel()
<-worker.Done()
}
func TestWorker_PicksUpStaleRows(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
chat1 := uuid.New()
chat2 := uuid.New()
ownerID := uuid.New()
var upsertCount atomic.Int32
var publishCount atomic.Int32
tickDone := make(chan struct{})
ctrl := gomock.NewController(t)
store := gitsyncmock.NewMockStore(ctrl)
pub := gitsyncmock.NewMockEventPublisher(ctrl)
store.EXPECT().AcquireStaleChatDiffStatuses(gomock.Any(), gomock.Any()).
Return([]database.AcquireStaleChatDiffStatusesRow{
makeAcquiredRow(chat1, ownerID),
makeAcquiredRow(chat2, ownerID),
}, nil)
store.EXPECT().UpsertChatDiffStatus(gomock.Any(), gomock.Any()).
DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
upsertCount.Add(1)
return database.ChatDiffStatus{ChatID: arg.ChatID}, nil
}).Times(2)
pub.EXPECT().PublishDiffStatusChange(gomock.Any(), gomock.Any()).
DoAndReturn(func(_ context.Context, _ uuid.UUID) error {
if publishCount.Add(1) == 2 {
close(tickDone)
}
return nil
}).Times(2)
mClock := quartz.NewMock(t)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
refresher := newTestRefresher(t, mClock)
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)
tickOnce(ctx, t, mClock, worker, tickDone)
assert.Equal(t, int32(2), upsertCount.Load())
assert.Equal(t, int32(2), publishCount.Load())
}
func TestWorker_SkipsFreshRows(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
tickDone := make(chan struct{})
ctrl := gomock.NewController(t)
store := gitsyncmock.NewMockStore(ctrl)
pub := gitsyncmock.NewMockEventPublisher(ctrl)
store.EXPECT().AcquireStaleChatDiffStatuses(gomock.Any(), gomock.Any()).
DoAndReturn(func(context.Context, int32) ([]database.AcquireStaleChatDiffStatusesRow, error) {
// No stale rows — tick returns immediately.
close(tickDone)
return nil, nil
})
mClock := quartz.NewMock(t)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
refresher := newTestRefresher(t, mClock)
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)
tickOnce(ctx, t, mClock, worker, tickDone)
}
func TestWorker_LimitsToNRows(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
var capturedLimit atomic.Int32
var upsertCount atomic.Int32
ownerID := uuid.New()
const numRows = 5
tickDone := make(chan struct{})
rows := make([]database.AcquireStaleChatDiffStatusesRow, numRows)
for i := range rows {
rows[i] = makeAcquiredRow(uuid.New(), ownerID)
}
ctrl := gomock.NewController(t)
store := gitsyncmock.NewMockStore(ctrl)
pub := gitsyncmock.NewMockEventPublisher(ctrl)
store.EXPECT().AcquireStaleChatDiffStatuses(gomock.Any(), gomock.Any()).
DoAndReturn(func(_ context.Context, limitVal int32) ([]database.AcquireStaleChatDiffStatusesRow, error) {
capturedLimit.Store(limitVal)
return rows, nil
})
store.EXPECT().UpsertChatDiffStatus(gomock.Any(), gomock.Any()).
DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
upsertCount.Add(1)
return database.ChatDiffStatus{ChatID: arg.ChatID}, nil
}).Times(numRows)
pub.EXPECT().PublishDiffStatusChange(gomock.Any(), gomock.Any()).
DoAndReturn(func(context.Context, uuid.UUID) error {
if upsertCount.Load() == numRows {
close(tickDone)
}
return nil
}).Times(numRows)
mClock := quartz.NewMock(t)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
refresher := newTestRefresher(t, mClock)
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)
tickOnce(ctx, t, mClock, worker, tickDone)
// The default batch size is 50.
assert.Equal(t, int32(50), capturedLimit.Load())
assert.Equal(t, int32(numRows), upsertCount.Load())
}
func TestWorker_RefresherReturnsNilNil_SkipsUpsert(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
chatID := uuid.New()
ownerID := uuid.New()
// When the Refresher returns (nil, nil) the worker skips the
// upsert and publish. We signal tickDone from the refresher
// mock since that is the last operation before the tick
// returns.
tickDone := make(chan struct{})
ctrl := gomock.NewController(t)
store := gitsyncmock.NewMockStore(ctrl)
pub := gitsyncmock.NewMockEventPublisher(ctrl)
store.EXPECT().AcquireStaleChatDiffStatuses(gomock.Any(), gomock.Any()).
Return([]database.AcquireStaleChatDiffStatusesRow{makeAcquiredRow(chatID, ownerID)}, nil)
mClock := quartz.NewMock(t)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
// ResolveBranchPullRequest returns nil → Refresher returns
// (nil, nil).
refresher := newTestRefresher(t, mClock, withResolveBranchPR(
func(context.Context, string, gitprovider.BranchRef) (*gitprovider.PRRef, error) {
close(tickDone)
return nil, nil
},
))
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)
tickOnce(ctx, t, mClock, worker, tickDone)
}
func TestWorker_RefresherError_BacksOffRow(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
chat1 := uuid.New()
chat2 := uuid.New()
ownerID := uuid.New()
var upsertCount atomic.Int32
var publishCount atomic.Int32
var backoffCount atomic.Int32
var mu sync.Mutex
var backoffArgs []database.BackoffChatDiffStatusParams
tickDone := make(chan struct{})
var closeOnce sync.Once
// Two rows processed: one fails (backoff), one succeeds
// (upsert+publish). Both must finish before we close tickDone.
var terminalOps atomic.Int32
signalIfDone := func() {
if terminalOps.Add(1) == 2 {
closeOnce.Do(func() { close(tickDone) })
}
}
mClock := quartz.NewMock(t)
ctrl := gomock.NewController(t)
store := gitsyncmock.NewMockStore(ctrl)
pub := gitsyncmock.NewMockEventPublisher(ctrl)
store.EXPECT().AcquireStaleChatDiffStatuses(gomock.Any(), gomock.Any()).
Return([]database.AcquireStaleChatDiffStatusesRow{
makeAcquiredRow(chat1, ownerID),
makeAcquiredRow(chat2, ownerID),
}, nil)
store.EXPECT().BackoffChatDiffStatus(gomock.Any(), gomock.Any()).
DoAndReturn(func(_ context.Context, arg database.BackoffChatDiffStatusParams) error {
backoffCount.Add(1)
mu.Lock()
backoffArgs = append(backoffArgs, arg)
mu.Unlock()
signalIfDone()
return nil
})
store.EXPECT().UpsertChatDiffStatus(gomock.Any(), gomock.Any()).
DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
upsertCount.Add(1)
return database.ChatDiffStatus{ChatID: arg.ChatID}, nil
})
pub.EXPECT().PublishDiffStatusChange(gomock.Any(), gomock.Any()).
DoAndReturn(func(context.Context, uuid.UUID) error {
// Only the successful row publishes.
publishCount.Add(1)
signalIfDone()
return nil
})
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
// Fail ResolveBranchPullRequest for the first call, succeed
// for the second.
var callCount atomic.Int32
refresher := newTestRefresher(t, mClock, withResolveBranchPR(
func(context.Context, string, gitprovider.BranchRef) (*gitprovider.PRRef, error) {
n := callCount.Add(1)
if n == 1 {
return nil, fmt.Errorf("simulated provider error")
}
return &gitprovider.PRRef{Owner: "o", Repo: "r", Number: 1}, nil
},
))
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)
tickOnce(ctx, t, mClock, worker, tickDone)
// BackoffChatDiffStatus was called for the failed row.
assert.Equal(t, int32(1), backoffCount.Load())
mu.Lock()
require.Len(t, backoffArgs, 1)
assert.Equal(t, chat1, backoffArgs[0].ChatID)
// stale_at should be approximately clock.Now() + DiffStatusTTL (120s).
expectedStaleAt := mClock.Now().UTC().Add(gitsync.DiffStatusTTL)
assert.WithinDuration(t, expectedStaleAt, backoffArgs[0].StaleAt, time.Second)
mu.Unlock()
// UpsertChatDiffStatus was called for the successful row.
assert.Equal(t, int32(1), upsertCount.Load())
// PublishDiffStatusChange was called only for the successful row.
assert.Equal(t, int32(1), publishCount.Load())
}
func TestWorker_UpsertError_ContinuesNextRow(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
chat1 := uuid.New()
chat2 := uuid.New()
ownerID := uuid.New()
var publishCount atomic.Int32
tickDone := make(chan struct{})
var closeOnce sync.Once
var mu sync.Mutex
upsertedChatIDs := make(map[uuid.UUID]struct{})
// We have 2 rows. The upsert for chat1 fails; the upsert
// for chat2 succeeds and publishes. Because goroutines run
// concurrently we don't know which finishes last, so we
// track the total number of "terminal" events (upsert error
// + publish success) and close tickDone when both have
// occurred.
var terminalOps atomic.Int32
signalIfDone := func() {
if terminalOps.Add(1) == 2 {
closeOnce.Do(func() { close(tickDone) })
}
}
ctrl := gomock.NewController(t)
store := gitsyncmock.NewMockStore(ctrl)
pub := gitsyncmock.NewMockEventPublisher(ctrl)
store.EXPECT().AcquireStaleChatDiffStatuses(gomock.Any(), gomock.Any()).
Return([]database.AcquireStaleChatDiffStatusesRow{
makeAcquiredRow(chat1, ownerID),
makeAcquiredRow(chat2, ownerID),
}, nil)
store.EXPECT().UpsertChatDiffStatus(gomock.Any(), gomock.Any()).
DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
if arg.ChatID == chat1 {
// Terminal event for the failing row.
signalIfDone()
return database.ChatDiffStatus{}, fmt.Errorf("db write error")
}
mu.Lock()
upsertedChatIDs[arg.ChatID] = struct{}{}
mu.Unlock()
return database.ChatDiffStatus{ChatID: arg.ChatID}, nil
}).Times(2)
pub.EXPECT().PublishDiffStatusChange(gomock.Any(), gomock.Any()).
DoAndReturn(func(context.Context, uuid.UUID) error {
publishCount.Add(1)
// Terminal event for the successful row.
signalIfDone()
return nil
})
mClock := quartz.NewMock(t)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
refresher := newTestRefresher(t, mClock)
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)
tickOnce(ctx, t, mClock, worker, tickDone)
mu.Lock()
_, gotChat2 := upsertedChatIDs[chat2]
mu.Unlock()
assert.True(t, gotChat2, "chat2 should have been upserted")
assert.Equal(t, int32(1), publishCount.Load())
}
func TestWorker_RespectsShutdown(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
ctrl := gomock.NewController(t)
store := gitsyncmock.NewMockStore(ctrl)
pub := gitsyncmock.NewMockEventPublisher(ctrl)
store.EXPECT().AcquireStaleChatDiffStatuses(gomock.Any(), gomock.Any()).
Return(nil, nil).AnyTimes()
mClock := quartz.NewMock(t)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
refresher := newTestRefresher(t, mClock)
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)
trap := mClock.Trap().NewTicker("gitsync", "worker")
defer trap.Close()
workerCtx, cancel := context.WithCancel(ctx)
go worker.Start(workerCtx)
// Wait for ticker creation so the worker is running.
trap.MustWait(ctx).MustRelease(ctx)
// Cancel immediately.
cancel()
select {
case <-worker.Done():
// Success — worker shut down.
case <-ctx.Done():
t.Fatal("timed out waiting for worker to shut down")
}
}
func TestWorker_BatchRefreshAllRows(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
ownerID := uuid.New()
const numRows = 5
rows := make([]database.AcquireStaleChatDiffStatusesRow, numRows)
for i := range rows {
rows[i] = makeAcquiredRow(uuid.New(), ownerID)
}
var upsertCount atomic.Int32
var publishCount atomic.Int32
tickDone := make(chan struct{})
ctrl := gomock.NewController(t)
store := gitsyncmock.NewMockStore(ctrl)
pub := gitsyncmock.NewMockEventPublisher(ctrl)
store.EXPECT().AcquireStaleChatDiffStatuses(gomock.Any(), gomock.Any()).
Return(rows, nil)
store.EXPECT().UpsertChatDiffStatus(gomock.Any(), gomock.Any()).
DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusParams) (database.ChatDiffStatus, error) {
upsertCount.Add(1)
return database.ChatDiffStatus{ChatID: arg.ChatID}, nil
}).Times(numRows)
pub.EXPECT().PublishDiffStatusChange(gomock.Any(), gomock.Any()).
DoAndReturn(func(context.Context, uuid.UUID) error {
if publishCount.Add(1) == numRows {
close(tickDone)
}
return nil
}).Times(numRows)
mClock := quartz.NewMock(t)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
refresher := newTestRefresher(t, mClock)
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)
tickOnce(ctx, t, mClock, worker, tickDone)
assert.Equal(t, int32(numRows), upsertCount.Load())
assert.Equal(t, int32(numRows), publishCount.Load())
}
// --- MarkStale tests ---
func TestWorker_MarkStale_UpsertAndPublish(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
workspaceID := uuid.New()
ownerID := uuid.New()
chat1 := uuid.New()
chat2 := uuid.New()
chatOther := uuid.New()
var mu sync.Mutex
var upsertRefCalls []database.UpsertChatDiffStatusReferenceParams
var publishedIDs []uuid.UUID
ctrl := gomock.NewController(t)
store := gitsyncmock.NewMockStore(ctrl)
pub := gitsyncmock.NewMockEventPublisher(ctrl)
store.EXPECT().GetChatsByOwnerID(gomock.Any(), gomock.Any()).
DoAndReturn(func(_ context.Context, arg database.GetChatsByOwnerIDParams) ([]database.Chat, error) {
require.Equal(t, ownerID, arg.OwnerID)
return []database.Chat{
{ID: chat1, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}},
{ID: chat2, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}},
{ID: chatOther, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true}},
}, nil
})
store.EXPECT().UpsertChatDiffStatusReference(gomock.Any(), gomock.Any()).
DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusReferenceParams) (database.ChatDiffStatus, error) {
mu.Lock()
upsertRefCalls = append(upsertRefCalls, arg)
mu.Unlock()
return database.ChatDiffStatus{ChatID: arg.ChatID}, nil
}).Times(2)
pub.EXPECT().PublishDiffStatusChange(gomock.Any(), gomock.Any()).
DoAndReturn(func(_ context.Context, chatID uuid.UUID) error {
mu.Lock()
publishedIDs = append(publishedIDs, chatID)
mu.Unlock()
return nil
}).Times(2)
mClock := quartz.NewMock(t)
now := mClock.Now()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
refresher := newTestRefresher(t, mClock)
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)
worker.MarkStale(ctx, workspaceID, ownerID, "feature", "https://github.com/owner/repo")
mu.Lock()
defer mu.Unlock()
require.Len(t, upsertRefCalls, 2)
for _, call := range upsertRefCalls {
assert.Equal(t, "feature", call.GitBranch)
assert.Equal(t, "https://github.com/owner/repo", call.GitRemoteOrigin)
assert.True(t, call.StaleAt.Before(now),
"stale_at should be in the past, got %v vs now %v", call.StaleAt, now)
assert.Equal(t, sql.NullString{}, call.Url)
}
require.Len(t, publishedIDs, 2)
assert.ElementsMatch(t, []uuid.UUID{chat1, chat2}, publishedIDs)
}
func TestWorker_MarkStale_NoMatchingChats(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
workspaceID := uuid.New()
ownerID := uuid.New()
ctrl := gomock.NewController(t)
store := gitsyncmock.NewMockStore(ctrl)
pub := gitsyncmock.NewMockEventPublisher(ctrl)
store.EXPECT().GetChatsByOwnerID(gomock.Any(), gomock.Any()).
Return([]database.Chat{
{ID: uuid.New(), OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true}},
{ID: uuid.New(), OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: uuid.New(), Valid: true}},
}, nil)
mClock := quartz.NewMock(t)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
refresher := newTestRefresher(t, mClock)
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)
worker.MarkStale(ctx, workspaceID, ownerID, "main", "https://github.com/x/y")
}
func TestWorker_MarkStale_UpsertFails_ContinuesNext(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
workspaceID := uuid.New()
ownerID := uuid.New()
chat1 := uuid.New()
chat2 := uuid.New()
var publishCount atomic.Int32
ctrl := gomock.NewController(t)
store := gitsyncmock.NewMockStore(ctrl)
pub := gitsyncmock.NewMockEventPublisher(ctrl)
store.EXPECT().GetChatsByOwnerID(gomock.Any(), gomock.Any()).
Return([]database.Chat{
{ID: chat1, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}},
{ID: chat2, OwnerID: ownerID, WorkspaceID: uuid.NullUUID{UUID: workspaceID, Valid: true}},
}, nil)
store.EXPECT().UpsertChatDiffStatusReference(gomock.Any(), gomock.Any()).
DoAndReturn(func(_ context.Context, arg database.UpsertChatDiffStatusReferenceParams) (database.ChatDiffStatus, error) {
if arg.ChatID == chat1 {
return database.ChatDiffStatus{}, fmt.Errorf("upsert ref error")
}
return database.ChatDiffStatus{ChatID: arg.ChatID}, nil
}).Times(2)
pub.EXPECT().PublishDiffStatusChange(gomock.Any(), gomock.Any()).
DoAndReturn(func(context.Context, uuid.UUID) error {
publishCount.Add(1)
return nil
})
mClock := quartz.NewMock(t)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
refresher := newTestRefresher(t, mClock)
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)
worker.MarkStale(ctx, workspaceID, ownerID, "dev", "https://github.com/a/b")
assert.Equal(t, int32(1), publishCount.Load())
}
func TestWorker_MarkStale_GetChatsFails(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
ctrl := gomock.NewController(t)
store := gitsyncmock.NewMockStore(ctrl)
pub := gitsyncmock.NewMockEventPublisher(ctrl)
store.EXPECT().GetChatsByOwnerID(gomock.Any(), gomock.Any()).
Return(nil, fmt.Errorf("db error"))
mClock := quartz.NewMock(t)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
refresher := newTestRefresher(t, mClock)
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)
worker.MarkStale(ctx, uuid.New(), uuid.New(), "main", "https://github.com/x/y")
}
func TestWorker_TickStoreError(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
tickDone := make(chan struct{})
ctrl := gomock.NewController(t)
store := gitsyncmock.NewMockStore(ctrl)
pub := gitsyncmock.NewMockEventPublisher(ctrl)
store.EXPECT().AcquireStaleChatDiffStatuses(gomock.Any(), gomock.Any()).
DoAndReturn(func(context.Context, int32) ([]database.AcquireStaleChatDiffStatusesRow, error) {
close(tickDone)
return nil, fmt.Errorf("database unavailable")
})
mClock := quartz.NewMock(t)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
refresher := newTestRefresher(t, mClock)
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)
tickOnce(ctx, t, mClock, worker, tickDone)
}
func TestWorker_MarkStale_EmptyBranchOrOrigin(t *testing.T) {
t.Parallel()
tests := []struct {
name string
branch string
origin string
}{
{"both empty", "", ""},
{"branch empty", "", "https://github.com/x/y"},
{"origin empty", "main", ""},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
ctrl := gomock.NewController(t)
store := gitsyncmock.NewMockStore(ctrl)
pub := gitsyncmock.NewMockEventPublisher(ctrl)
mClock := quartz.NewMock(t)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
refresher := newTestRefresher(t, mClock)
worker := gitsync.NewWorker(store, refresher, pub, mClock, logger)
worker.MarkStale(ctx, uuid.New(), uuid.New(), tc.branch, tc.origin)
})
}
}
+7 -19
View File
@@ -1835,18 +1835,6 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
Branch: strings.TrimSpace(query.Get("git_branch")),
RemoteOrigin: strings.TrimSpace(query.Get("git_remote_origin")),
}
var chatID uuid.NullUUID
if rawChatID := query.Get("chat_id"); rawChatID != "" {
parsed, err := uuid.Parse(rawChatID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid chat_id.",
Detail: err.Error(),
})
return
}
chatID = uuid.NullUUID{UUID: parsed, Valid: true}
}
// Either match or configID must be provided!
match := query.Get("match")
if match == "" {
@@ -1942,9 +1930,9 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
// Persist git refs as soon as the agent requests external auth so branch
// context is retained even if the flow requires an out-of-band login.
if gitRef.Branch != "" || gitRef.RemoteOrigin != "" {
//nolint:gocritic // System context required to persist chat git refs.
api.storeChatGitRef(dbauthz.AsSystemRestricted(ctx), workspace.ID, workspace.OwnerID, chatID, gitRef)
if gitRef.Branch != "" && gitRef.RemoteOrigin != "" {
//nolint:gocritic // Chat processor context required for cross-user chat lookup
api.gitSyncWorker.MarkStale(dbauthz.AsChatd(ctx), workspace.ID, workspace.OwnerID, gitRef.Branch, gitRef.RemoteOrigin)
}
var previousToken *database.ExternalAuthLink
@@ -1960,7 +1948,7 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
return
}
api.workspaceAgentsExternalAuthListen(ctx, rw, previousToken, externalAuthConfig, workspace, chatID, gitRef)
api.workspaceAgentsExternalAuthListen(ctx, rw, previousToken, externalAuthConfig, workspace, gitRef)
}
// This is the URL that will redirect the user with a state token.
@@ -2018,11 +2006,10 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
})
return
}
api.triggerWorkspaceChatDiffStatusRefresh(workspace, chatID, gitRef)
httpapi.Write(ctx, rw, http.StatusOK, resp)
}
func (api *API) workspaceAgentsExternalAuthListen(ctx context.Context, rw http.ResponseWriter, previous *database.ExternalAuthLink, externalAuthConfig *externalauth.Config, workspace database.Workspace, chatID uuid.NullUUID, gitRef chatGitRef) {
func (api *API) workspaceAgentsExternalAuthListen(ctx context.Context, rw http.ResponseWriter, previous *database.ExternalAuthLink, externalAuthConfig *externalauth.Config, workspace database.Workspace, gitRef chatGitRef) {
// Since we're ticking frequently and this sign-in operation is rare,
// we are OK with polling to avoid the complexity of pubsub.
ticker, done := api.NewTicker(time.Second)
@@ -2092,7 +2079,8 @@ func (api *API) workspaceAgentsExternalAuthListen(ctx context.Context, rw http.R
})
return
}
api.triggerWorkspaceChatDiffStatusRefresh(workspace, chatID, gitRef)
//nolint:gocritic // Chat processor context required for cross-user chat lookup
api.gitSyncWorker.MarkStale(dbauthz.AsChatd(ctx), workspace.ID, workspace.OwnerID, gitRef.Branch, gitRef.RemoteOrigin)
httpapi.Write(ctx, rw, http.StatusOK, resp)
return
}
+4
View File
@@ -980,6 +980,10 @@ type ExternalAuthConfig struct {
// 'Username for "https://github.com":'
// And sending it to the Coder server to match against the Regex.
Regex string `json:"regex" yaml:"regex"`
// APIBaseURL is the base URL for provider REST API calls
// (e.g., "https://api.github.com" for GitHub). Derived from
// defaults when not explicitly configured.
APIBaseURL string `json:"api_base_url" yaml:"api_base_url"`
// DisplayName is shown in the UI to identify the auth config.
DisplayName string `json:"display_name" yaml:"display_name"`
// DisplayIcon is a URL to an icon to display in the UI.
+1
View File
@@ -22,6 +22,7 @@ externalAuthProviders:
mcp_tool_allow_regex: .*
mcp_tool_deny_regex: create_gist
regex: ^https://example.com/.*$
api_base_url: ""
display_name: GitHub
display_icon: /static/icons/github.svg
code_challenge_methods_supported:
+1
View File
@@ -279,6 +279,7 @@ curl -X GET http://coder-server:8080/api/v2/deployment/config \
"external_auth": {
"value": [
{
"api_base_url": "string",
"app_install_url": "string",
"app_installations_url": "string",
"auth_url": "string",
+21 -16
View File
@@ -2786,6 +2786,7 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o
"external_auth": {
"value": [
{
"api_base_url": "string",
"app_install_url": "string",
"app_installations_url": "string",
"auth_url": "string",
@@ -3357,6 +3358,7 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o
"external_auth": {
"value": [
{
"api_base_url": "string",
"app_install_url": "string",
"app_installations_url": "string",
"auth_url": "string",
@@ -4104,6 +4106,7 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o
```json
{
"api_base_url": "string",
"app_install_url": "string",
"app_installations_url": "string",
"auth_url": "string",
@@ -4133,22 +4136,23 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o
### Properties
| Name | Type | Required | Restrictions | Description |
|------------------------------------|-----------------|----------|--------------|-------------------------------------------------------------------------------------------------------------------|
| `app_install_url` | string | false | | |
| `app_installations_url` | string | false | | |
| `auth_url` | string | false | | |
| `client_id` | string | false | | |
| `code_challenge_methods_supported` | array of string | false | | Code challenge methods supported lists the PKCE code challenge methods The only one supported by Coder is "S256". |
| `device_code_url` | string | false | | |
| `device_flow` | boolean | false | | |
| `display_icon` | string | false | | Display icon is a URL to an icon to display in the UI. |
| `display_name` | string | false | | Display name is shown in the UI to identify the auth config. |
| `id` | string | false | | ID is a unique identifier for the auth config. It defaults to `type` when not provided. |
| `mcp_tool_allow_regex` | string | false | | |
| `mcp_tool_deny_regex` | string | false | | |
| `mcp_url` | string | false | | |
| `no_refresh` | boolean | false | | |
| Name | Type | Required | Restrictions | Description |
|------------------------------------|-----------------|----------|--------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `api_base_url` | string | false | | Api base URL is the base URL for provider REST API calls (e.g., "https://api.github.com" for GitHub). Derived from defaults when not explicitly configured. |
| `app_install_url` | string | false | | |
| `app_installations_url` | string | false | | |
| `auth_url` | string | false | | |
| `client_id` | string | false | | |
| `code_challenge_methods_supported` | array of string | false | | Code challenge methods supported lists the PKCE code challenge methods The only one supported by Coder is "S256". |
| `device_code_url` | string | false | | |
| `device_flow` | boolean | false | | |
| `display_icon` | string | false | | Display icon is a URL to an icon to display in the UI. |
| `display_name` | string | false | | Display name is shown in the UI to identify the auth config. |
| `id` | string | false | | ID is a unique identifier for the auth config. It defaults to `type` when not provided. |
| `mcp_tool_allow_regex` | string | false | | |
| `mcp_tool_deny_regex` | string | false | | |
| `mcp_url` | string | false | | |
| `no_refresh` | boolean | false | | |
|`regex`|string|false||Regex allows API requesters to match an auth config by a string (e.g. coder.com) instead of by it's type.
Git clone makes use of this by parsing the URL from: 'Username for "https://github.com":' And sending it to the Coder server to match against the Regex.|
|`revoke_url`|string|false|||
@@ -14196,6 +14200,7 @@ None
{
"value": [
{
"api_base_url": "string",
"app_install_url": "string",
"app_installations_url": "string",
"auth_url": "string",
+6
View File
@@ -2682,6 +2682,12 @@ export interface ExternalAuthConfig {
* And sending it to the Coder server to match against the Regex.
*/
readonly regex: string;
/**
* APIBaseURL is the base URL for provider REST API calls
* (e.g., "https://api.github.com" for GitHub). Derived from
* defaults when not explicitly configured.
*/
readonly api_base_url: string;
/**
* DisplayName is shown in the UI to identify the auth config.
*/
@@ -12,6 +12,7 @@ const meta: Meta<typeof ExternalAuthSettingsPageView> = {
type: "GitHub",
client_id: "client_id",
regex: "regex",
api_base_url: "",
auth_url: "",
token_url: "",
validate_url: "",