Compare commits

..

8 Commits

Author SHA1 Message Date
Lukasz c174b3037b Merge branch 'main' into security-patch-train-doc 2026-04-07 16:31:36 +02:00
Lukasz f5165d304f ci(.github): automate security patch PRs and backports 2026-04-07 16:27:25 +02:00
Kyle Carberry 684f21740d perf(coderd): batch chat heartbeat queries into single UPDATE per interval (#24037)
## Summary

Replaces N per-chat heartbeat goroutines with a single centralized
heartbeat loop that issues one `UPDATE` per 30s interval for all running
chats on a worker.

## Problem

Each running chat spawned a dedicated goroutine that issued an
individual `UPDATE chats SET heartbeat_at = NOW() WHERE id = $1 AND
worker_id = $2 AND status = 'running'` query every 30 seconds. At 10,000
concurrent chats this produces **~333 DB queries/second** just for
heartbeats, plus ~333 `ActivityBumpWorkspace` CTE queries/second from
`trackWorkspaceUsage`.

## Solution

A new `UpdateChatHeartbeats` (plural) SQL query replaces the old singular
`UpdateChatHeartbeat`:

```sql
UPDATE chats
SET    heartbeat_at = @now::timestamptz
WHERE  worker_id = @worker_id::uuid
  AND  status = 'running'::chat_status
RETURNING id;
```

A single `heartbeatLoop` goroutine on the `Server`:
1. Ticks every `chatHeartbeatInterval` (30s)
2. Issues one batch UPDATE for all registered chats
3. Detects stolen/completed chats via set-difference (the batch equivalent
of the old `rows == 0` check)
4. Calls `trackWorkspaceUsage` for surviving chats

`processChat` registers an entry in the heartbeat registry instead of
spawning a goroutine.
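The set-difference step can be sketched as follows. This is an illustrative standalone sketch, not the actual coderd implementation; `staleChats` and the string IDs are hypothetical names:

```go
package main

import "fmt"

// staleChats returns the registered chat IDs that were NOT returned by the
// batch `UPDATE ... RETURNING id` query. A missing ID means another worker
// stole the chat or it completed -- the batch equivalent of the old
// per-chat "rows == 0" check.
func staleChats(registered, updated []string) []string {
	seen := make(map[string]bool, len(updated))
	for _, id := range updated {
		seen[id] = true
	}
	var stale []string
	for _, id := range registered {
		if !seen[id] {
			stale = append(stale, id)
		}
	}
	return stale
}

func main() {
	registered := []string{"chat-a", "chat-b", "chat-c"}
	// chat-b was stolen or completed, so the UPDATE did not return it.
	updated := []string{"chat-a", "chat-c"}
	fmt.Println(staleChats(registered, updated))
}
```

One O(n) pass over a map replaces 10,000 individual `rows == 0` checks per interval.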

## Impact

| Metric | Before (10K chats) | After (10K chats) |
|---|---|---|
| Heartbeat queries/sec | ~333 | ~0.03 (1 per 30s per replica) |
| Heartbeat goroutines | 10,000 | 1 |
| Self-interrupt detection | Per-chat `rows==0` | Batch set-difference |

---

> 🤖 Generated by Coder Agents

<details><summary>Implementation notes</summary>

- Uses `@now` parameter instead of `NOW()` so tests with `quartz.Mock`
can control timestamps.
- `heartbeatEntry` stores `context.CancelCauseFunc` + workspace state
for the centralized loop.
- `recoverStaleChats` is unaffected — it reads `heartbeat_at` which is
still updated.
- The old singular `UpdateChatHeartbeat` is removed entirely.
- `dbauthz` wrapper uses system-level `rbac.ResourceChat` authorization
(same pattern as `AcquireChats`).

</details>
2026-04-07 10:25:46 -04:00
George K 86ca61d6ca perf: cap count queries and emit native UUID comparisons for audit/connection logs (#23835)
Audit and connection log pages were timing out due to expensive COUNT(*)
queries over large tables. This commit adds opt-in count capping: responses
include a `count_cap` field signaling that the count was truncated at a
threshold, avoiding the full table scans that caused the page timeouts.

Text-cast UUID comparisons in regosql-generated authorization queries
also contributed to the slowdown by preventing index usage for connection
and audit log queries. These now emit native UUID operators.

Frontend changes handle the capped state in usePaginatedQuery and
PaginationWidget, optionally displaying a capped count in the pagination
UI (e.g. "Showing 2,076 to 2,100 of 2,000+ logs").
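Because the backend runs the COUNT inside a subquery with `LIMIT cap+1`, a returned count greater than the cap means "at least cap+1 rows exist". A minimal sketch of how a caller might render that (function name and formatting are illustrative, not the actual site code):

```go
package main

import "fmt"

// displayTotal renders a possibly-capped count. A count above the cap is
// shown as "N+" because the query stopped counting at cap+1 rows; a cap of
// 0 means capping is disabled and the count is exact.
func displayTotal(count, countCap int64) string {
	if countCap > 0 && count > countCap {
		return fmt.Sprintf("%d+", countCap)
	}
	return fmt.Sprintf("%d", count)
}

func main() {
	fmt.Println(displayTotal(2001, 2000)) // capped: at least 2001 rows exist
	fmt.Println(displayTotal(1500, 2000)) // under the cap: exact count
}
```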

Related to:
https://linear.app/codercom/issue/PLAT-31/connectionaudit-log-performance-issue
2026-04-07 07:24:53 -07:00
Jake Howell f0521cfa3c fix: resolve <LogLine /> storybook flake (#24084)
This pull request stabilizes the test by pinning the content to a fixed
date, so it no longer changes with every new Storybook artifact.

Closes https://github.com/coder/internal/issues/1454
2026-04-08 00:17:06 +10:00
Danielle Maywood 0c5d189aff fix(site): stabilize mutation callbacks for React Compiler memoization (#24089) 2026-04-07 15:05:27 +01:00
Michael Suchacz d7c8213eee fix(coderd/x/chatd/mcpclient): deterministic external MCP tool ordering (#24075)
> This PR was authored by Mux on behalf of Mike.

External MCP tools returned by `ConnectAll` were ordered by goroutine
completion, making the tool list nondeterministic across chat turns.
This broke prompt-cache stability since tools are serialized in order.

Sort tools by their model-visible name after all connections complete,
matching the existing pattern in workspace MCP tools
(`agent/x/agentmcp/manager.go`). Also guards against a nil-client panic
in cleanup when a connected server contributes zero tools after
filtering.
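The fix described above amounts to a name sort after all connections complete. A minimal sketch under assumed types (`tool` and `sortTools` are illustrative, not the mcpclient API):

```go
package main

import (
	"fmt"
	"sort"
)

// tool stands in for an external MCP tool; only the model-visible name
// matters for ordering.
type tool struct{ Name string }

// sortTools orders tools by name. Collecting results from concurrent
// connections yields goroutine-completion order, which varies per run;
// sorting before serialization keeps the prompt cache stable across turns.
func sortTools(tools []tool) {
	sort.Slice(tools, func(i, j int) bool { return tools[i].Name < tools[j].Name })
}

func main() {
	// Simulated completion order from three concurrent connections.
	tools := []tool{{"web_search"}, {"fetch"}, {"run_sql"}}
	sortTools(tools)
	for _, t := range tools {
		fmt.Println(t.Name)
	}
}
```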
2026-04-07 14:42:30 +02:00
Cian Johnston 63924ac687 fix(site): use async findByLabelText in ProviderAccordionCards story (#24087)
- Use async `findByLabelText` instead of sync `getByLabelText` in the
`ProviderAccordionCards` story
- The same bug was fixed in #23999 for three other stories but missed for
this one

> 🤖 Written by a Coder Agent. Will be reviewed by a human.
2026-04-07 14:13:56 +02:00
60 changed files with 2524 additions and 1581 deletions
+4 -3
@@ -31,7 +31,8 @@ updates:
patterns:
- "golang.org/x/*"
ignore:
# Ignore patch updates for all dependencies
# Patch updates are handled by the security-patch-prs workflow so this
# lane stays focused on broader dependency updates.
- dependency-name: "*"
update-types:
- version-update:semver-patch
@@ -56,7 +57,7 @@ updates:
labels: []
ignore:
# We need to coordinate terraform updates with the version hardcoded in
# our Go code.
# our Go code. These are handled by the security-patch-prs workflow.
- dependency-name: "terraform"
- package-ecosystem: "npm"
@@ -117,11 +118,11 @@ updates:
interval: "weekly"
commit-message:
prefix: "chore"
labels: []
groups:
coder-modules:
patterns:
- "coder/*/coder"
labels: []
ignore:
- dependency-name: "*"
update-types:
+354
@@ -0,0 +1,354 @@
name: security-backport
on:
pull_request_target:
types:
- labeled
- unlabeled
- closed
workflow_dispatch:
inputs:
pull_request:
description: Pull request number to backport.
required: true
type: string
permissions:
contents: write
pull-requests: write
issues: write
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || inputs.pull_request }}
cancel-in-progress: false
env:
LATEST_BRANCH: release/2.31
STABLE_BRANCH: release/2.30
STABLE_1_BRANCH: release/2.29
jobs:
label-policy:
if: github.event_name == 'pull_request_target'
runs-on: ubuntu-latest
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
- name: Apply security backport label policy
shell: bash
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
set -euo pipefail
pr_number="$(jq -r '.pull_request.number' "${GITHUB_EVENT_PATH}")"
pr_json="$(gh pr view "${pr_number}" --json number,title,url,baseRefName,labels)"
PR_JSON="${pr_json}" \
python3 - <<'PY'
import json
import os
import subprocess
import sys
pr = json.loads(os.environ["PR_JSON"])
pr_number = pr["number"]
labels = [label["name"] for label in pr.get("labels", [])]
def has(label: str) -> bool:
return label in labels
def ensure_label(label: str) -> None:
if not has(label):
subprocess.run(
["gh", "pr", "edit", str(pr_number), "--add-label", label],
check=False,
)
def remove_label(label: str) -> None:
if has(label):
subprocess.run(
["gh", "pr", "edit", str(pr_number), "--remove-label", label],
check=False,
)
def comment(body: str) -> None:
subprocess.run(
["gh", "pr", "comment", str(pr_number), "--body", body],
check=True,
)
if not has("security:patch"):
remove_label("status:needs-severity")
sys.exit(0)
severity_labels = [
label
for label in ("severity:medium", "severity:high", "severity:critical")
if has(label)
]
if len(severity_labels) == 0:
ensure_label("status:needs-severity")
comment(
"This PR is labeled `security:patch` but is missing a severity "
"label. Add one of `severity:medium`, `severity:high`, or "
"`severity:critical` before backport automation can proceed."
)
sys.exit(0)
if len(severity_labels) > 1:
comment(
"This PR has multiple severity labels. Keep exactly one of "
"`severity:medium`, `severity:high`, or `severity:critical`."
)
sys.exit(1)
remove_label("status:needs-severity")
target_labels = [
label
for label in ("backport:stable", "backport:stable-1")
if has(label)
]
has_none = has("backport:none")
if has_none and target_labels:
comment(
"`backport:none` cannot be combined with other backport labels. "
"Remove `backport:none` or remove the explicit backport targets."
)
sys.exit(1)
if not has_none and not target_labels:
ensure_label("backport:stable")
ensure_label("backport:stable-1")
comment(
"Applied default backport labels `backport:stable` and "
"`backport:stable-1` for a qualifying security patch."
)
PY
backport:
if: >
github.event_name == 'workflow_dispatch' ||
(
github.event_name == 'pull_request_target' &&
github.event.pull_request.merged == true
)
runs-on: ubuntu-latest
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0
persist-credentials: false
- name: Resolve PR metadata
id: metadata
shell: bash
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
INPUT_PR_NUMBER: ${{ inputs.pull_request }}
LATEST_BRANCH: ${{ env.LATEST_BRANCH }}
STABLE_BRANCH: ${{ env.STABLE_BRANCH }}
STABLE_1_BRANCH: ${{ env.STABLE_1_BRANCH }}
run: |
set -euo pipefail
if [ "${GITHUB_EVENT_NAME}" = "workflow_dispatch" ]; then
pr_number="${INPUT_PR_NUMBER}"
else
pr_number="$(jq -r '.pull_request.number' "${GITHUB_EVENT_PATH}")"
fi
case "${pr_number}" in
''|*[!0-9]*)
echo "A valid pull request number is required."
exit 1
;;
esac
pr_json="$(gh pr view "${pr_number}" --json number,title,url,mergeCommit,baseRefName,labels,mergedAt,author)"
PR_JSON="${pr_json}" \
python3 - <<'PY'
import json
import os
import sys
pr = json.loads(os.environ["PR_JSON"])
github_output = os.environ["GITHUB_OUTPUT"]
labels = [label["name"] for label in pr.get("labels", [])]
if "security:patch" not in labels:
print("Not a security patch PR; skipping.")
sys.exit(0)
severity_labels = [
label
for label in ("severity:medium", "severity:high", "severity:critical")
if label in labels
]
if len(severity_labels) != 1:
raise SystemExit(
"Merged security patch PR must have exactly one severity label."
)
if not pr.get("mergedAt"):
raise SystemExit(f"PR #{pr['number']} is not merged.")
if "backport:none" in labels:
target_pairs = []
else:
mapping = {
"backport:stable": os.environ["STABLE_BRANCH"],
"backport:stable-1": os.environ["STABLE_1_BRANCH"],
}
target_pairs = []
for label_name, branch in mapping.items():
if label_name in labels and branch and branch != pr["baseRefName"]:
target_pairs.append({"label": label_name, "branch": branch})
with open(github_output, "a", encoding="utf-8") as f:
f.write(f"pr_number={pr['number']}\n")
f.write(f"merge_sha={pr['mergeCommit']['oid']}\n")
f.write(f"title={pr['title']}\n")
f.write(f"url={pr['url']}\n")
f.write(f"author={pr['author']['login']}\n")
f.write(f"severity_label={severity_labels[0]}\n")
f.write(f"target_pairs={json.dumps(target_pairs)}\n")
PY
- name: Backport to release branches
if: ${{ steps.metadata.outputs.target_pairs != '[]' }}
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PR_NUMBER: ${{ steps.metadata.outputs.pr_number }}
MERGE_SHA: ${{ steps.metadata.outputs.merge_sha }}
PR_TITLE: ${{ steps.metadata.outputs.title }}
PR_URL: ${{ steps.metadata.outputs.url }}
PR_AUTHOR: ${{ steps.metadata.outputs.author }}
SEVERITY_LABEL: ${{ steps.metadata.outputs.severity_label }}
TARGET_PAIRS: ${{ steps.metadata.outputs.target_pairs }}
run: |
set -euo pipefail
git config user.name "github-actions[bot]"
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
git remote set-url origin "https://x-access-token:${GH_TOKEN}@github.com/${GITHUB_REPOSITORY}.git"
git fetch origin --prune
merge_parent_count="$(git rev-list --parents -n 1 "${MERGE_SHA}" | awk '{print NF-1}')"
failures=()
successes=()
while IFS=$'\t' read -r backport_label target_branch; do
[ -n "${target_branch}" ] || continue
safe_branch_name="${target_branch//\//-}"
head_branch="backport/${safe_branch_name}/pr-${PR_NUMBER}"
existing_pr="$(gh pr list \
--base "${target_branch}" \
--head "${head_branch}" \
--state all \
--json number,url \
--jq '.[0]')"
if [ -n "${existing_pr}" ] && [ "${existing_pr}" != "null" ]; then
pr_url="$(printf '%s' "${existing_pr}" | jq -r '.url')"
successes+=("${target_branch}:existing:${pr_url}")
continue
fi
git checkout -B "${head_branch}" "origin/${target_branch}"
if [ "${merge_parent_count}" -gt 1 ]; then
cherry_pick_args=(-m 1 "${MERGE_SHA}")
else
cherry_pick_args=("${MERGE_SHA}")
fi
if ! git cherry-pick -x "${cherry_pick_args[@]}"; then
git cherry-pick --abort || true
gh pr edit "${PR_NUMBER}" --add-label "backport:conflict" || true
gh pr comment "${PR_NUMBER}" --body \
"Automatic backport to \`${target_branch}\` conflicted. The original author or release manager should resolve it manually."
failures+=("${target_branch}:cherry-pick failed")
continue
fi
git push --force-with-lease origin "${head_branch}"
body_file="$(mktemp)"
printf '%s\n' \
"Automated backport of [#${PR_NUMBER}](${PR_URL})." \
"" \
"- Source PR: #${PR_NUMBER}" \
"- Source commit: ${MERGE_SHA}" \
"- Target branch: ${target_branch}" \
"- Severity: ${SEVERITY_LABEL}" \
> "${body_file}"
pr_url="$(gh pr create \
--base "${target_branch}" \
--head "${head_branch}" \
--title "${PR_TITLE} (backport to ${target_branch})" \
--body-file "${body_file}")"
backport_pr_number="$(gh pr list \
--base "${target_branch}" \
--head "${head_branch}" \
--state open \
--json number \
--jq '.[0].number')"
gh pr edit "${backport_pr_number}" \
--add-label "security:patch" \
--add-label "${SEVERITY_LABEL}" \
--add-label "${backport_label}" || true
successes+=("${target_branch}:created:${pr_url}")
done < <(
python3 - <<'PY'
import json
import os
for pair in json.loads(os.environ["TARGET_PAIRS"]):
print(f"{pair['label']}\t{pair['branch']}")
PY
)
summary_file="$(mktemp)"
{
echo "## Security backport summary"
echo
if [ "${#successes[@]}" -gt 0 ]; then
echo "### Created or existing"
for entry in "${successes[@]}"; do
echo "- ${entry}"
done
echo
fi
if [ "${#failures[@]}" -gt 0 ]; then
echo "### Failures"
for entry in "${failures[@]}"; do
echo "- ${entry}"
done
fi
} | tee -a "${GITHUB_STEP_SUMMARY}" > "${summary_file}"
gh pr comment "${PR_NUMBER}" --body-file "${summary_file}"
if [ "${#failures[@]}" -gt 0 ]; then
printf 'Backport failures:\n%s\n' "${failures[@]}" >&2
exit 1
fi
+214
@@ -0,0 +1,214 @@
name: security-patch-prs
on:
workflow_dispatch:
schedule:
- cron: "0 3 * * 1-5"
permissions:
contents: write
pull-requests: write
jobs:
patch:
strategy:
fail-fast: false
matrix:
lane:
- gomod
- terraform
runs-on: ubuntu-latest
steps:
- name: Harden Runner
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
with:
egress-policy: audit
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Setup Go
uses: ./.github/actions/setup-go
- name: Patch Go dependencies
if: matrix.lane == 'gomod'
shell: bash
run: |
set -euo pipefail
go get -u=patch ./...
go mod tidy
# Guardrail: do not auto-edit replace directives.
if git diff --unified=0 -- go.mod | grep -E '^[+-]replace '; then
echo "Refusing to auto-edit go.mod replace directives"
exit 1
fi
# Guardrail: only go.mod / go.sum may change.
extra="$(git diff --name-only | grep -Ev '^(go\.mod|go\.sum)$' || true)"
test -z "$extra" || { echo "Unexpected files changed:"; echo "$extra"; exit 1; }
- name: Patch bundled Terraform
if: matrix.lane == 'terraform'
shell: bash
run: |
set -euo pipefail
current="$(
grep -oE 'NewVersion\("[0-9]+\.[0-9]+\.[0-9]+"\)' \
provisioner/terraform/install.go \
| head -1 \
| grep -oE '[0-9]+\.[0-9]+\.[0-9]+'
)"
series="$(echo "$current" | cut -d. -f1,2)"
latest="$(
curl -fsSL https://releases.hashicorp.com/terraform/index.json \
| jq -r --arg series "$series" '
.versions
| keys[]
| select(startswith($series + "."))
' \
| sort -V \
| tail -1
)"
test -n "$latest"
[ "$latest" != "$current" ] || exit 0
CURRENT_TERRAFORM_VERSION="$current" \
LATEST_TERRAFORM_VERSION="$latest" \
python3 - <<'PY'
from pathlib import Path
import os
current = os.environ["CURRENT_TERRAFORM_VERSION"]
latest = os.environ["LATEST_TERRAFORM_VERSION"]
updates = {
"scripts/Dockerfile.base": (
f"terraform/{current}/",
f"terraform/{latest}/",
),
"provisioner/terraform/install.go": (
f'NewVersion("{current}")',
f'NewVersion("{latest}")',
),
"install.sh": (
f'TERRAFORM_VERSION="{current}"',
f'TERRAFORM_VERSION="{latest}"',
),
}
for path_str, (before, after) in updates.items():
path = Path(path_str)
content = path.read_text()
if before not in content:
raise SystemExit(f"did not find expected text in {path_str}: {before}")
path.write_text(content.replace(before, after))
PY
# Guardrail: only the Terraform-version files may change.
extra="$(git diff --name-only | grep -Ev '^(scripts/Dockerfile.base|provisioner/terraform/install.go|install.sh)$' || true)"
test -z "$extra" || { echo "Unexpected files changed:"; echo "$extra"; exit 1; }
- name: Validate Go dependency patch
if: matrix.lane == 'gomod'
shell: bash
run: |
set -euo pipefail
go test ./...
- name: Validate Terraform patch
if: matrix.lane == 'terraform'
shell: bash
run: |
set -euo pipefail
go test ./provisioner/terraform/...
docker build -f scripts/Dockerfile.base .
- name: Skip PR creation when there are no changes
id: changes
shell: bash
run: |
set -euo pipefail
if git diff --quiet; then
echo "has_changes=false" >> "${GITHUB_OUTPUT}"
else
echo "has_changes=true" >> "${GITHUB_OUTPUT}"
fi
- name: Commit changes
if: steps.changes.outputs.has_changes == 'true'
shell: bash
run: |
set -euo pipefail
git config user.name "github-actions[bot]"
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
git checkout -B "secpatch/${{ matrix.lane }}"
git add -A
git commit -m "security: patch ${{ matrix.lane }}"
- name: Push branch
if: steps.changes.outputs.has_changes == 'true'
shell: bash
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
set -euo pipefail
git push --force-with-lease \
"https://x-access-token:${GH_TOKEN}@github.com/${{ github.repository }}.git" \
"HEAD:refs/heads/secpatch/${{ matrix.lane }}"
- name: Create or update PR
if: steps.changes.outputs.has_changes == 'true'
shell: bash
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
set -euo pipefail
branch="secpatch/${{ matrix.lane }}"
title="security: patch ${{ matrix.lane }}"
body="$(cat <<'EOF'
Automated security patch PR for `${{ matrix.lane }}`.
Scope:
- gomod: patch-level Go dependency updates only
- terraform: bundled Terraform patch updates only
Guardrails:
- no application-code edits
- no auto-editing of go.mod replace directives
- CI must pass
EOF
)"
existing_pr="$(gh pr list --head "${branch}" --base main --json number --jq '.[0].number // empty')"
if [[ -n "${existing_pr}" ]]; then
gh pr edit "${existing_pr}" \
--title "${title}" \
--body "${body}"
pr_number="${existing_pr}"
else
gh pr create \
--base main \
--head "${branch}" \
--title "${title}" \
--body "${body}"
pr_number="$(gh pr list --head "${branch}" --base main --json number --jq '.[0].number // empty')"
fi
for label in security dependencies automated-pr; do
if gh label list --json name --jq '.[].name' | grep -Fxq "${label}"; then
gh pr edit "${pr_number}" --add-label "${label}"
fi
done
+6
@@ -14175,6 +14175,9 @@ const docTemplate = `{
},
"count": {
"type": "integer"
},
"count_cap": {
"type": "integer"
}
}
},
@@ -14496,6 +14499,9 @@ const docTemplate = `{
},
"count": {
"type": "integer"
},
"count_cap": {
"type": "integer"
}
}
},
+6
@@ -12739,6 +12739,9 @@
},
"count": {
"type": "integer"
},
"count_cap": {
"type": "integer"
}
}
},
@@ -13039,6 +13042,9 @@
},
"count": {
"type": "integer"
},
"count_cap": {
"type": "integer"
}
}
},
+8 -1
@@ -26,6 +26,11 @@ import (
"github.com/coder/coder/v2/codersdk"
)
// Limit the count query to avoid a slow sequential scan due to joins
// on a large table. Set to 0 to disable capping (but also see the note
// in the SQL query).
const auditLogCountCap = 2000
// @Summary Get audit logs
// @ID get-audit-logs
// @Security CoderSessionToken
@@ -66,7 +71,7 @@ func (api *API) auditLogs(rw http.ResponseWriter, r *http.Request) {
countFilter.Username = ""
}
// Use the same filters to count the number of audit logs
countFilter.CountCap = auditLogCountCap
count, err := api.Database.CountAuditLogs(ctx, countFilter)
if dbauthz.IsNotAuthorizedError(err) {
httpapi.Forbidden(rw)
@@ -81,6 +86,7 @@ func (api *API) auditLogs(rw http.ResponseWriter, r *http.Request) {
httpapi.Write(ctx, rw, http.StatusOK, codersdk.AuditLogResponse{
AuditLogs: []codersdk.AuditLog{},
Count: 0,
CountCap: auditLogCountCap,
})
return
}
@@ -98,6 +104,7 @@ func (api *API) auditLogs(rw http.ResponseWriter, r *http.Request) {
httpapi.Write(ctx, rw, http.StatusOK, codersdk.AuditLogResponse{
AuditLogs: api.convertAuditLogs(ctx, dblogs),
Count: count,
CountCap: auditLogCountCap,
})
}
+8 -8
@@ -5782,15 +5782,15 @@ func (q *querier) UpdateChatByID(ctx context.Context, arg database.UpdateChatByI
return q.db.UpdateChatByID(ctx, arg)
}
func (q *querier) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateChatHeartbeatParams) (int64, error) {
chat, err := q.db.GetChatByID(ctx, arg.ID)
if err != nil {
return 0, err
func (q *querier) UpdateChatHeartbeats(ctx context.Context, arg database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
// The batch heartbeat is a system-level operation filtered by
// worker_id. Authorization is enforced by the AsChatd context
// at the call site rather than per-row, because checking each
// row individually would defeat the purpose of batching.
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceChat); err != nil {
return nil, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return 0, err
}
return q.db.UpdateChatHeartbeat(ctx, arg)
return q.db.UpdateChatHeartbeats(ctx, arg)
}
func (q *querier) UpdateChatLabelsByID(ctx context.Context, arg database.UpdateChatLabelsByIDParams) (database.Chat, error) {
+7 -7
@@ -842,15 +842,15 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().UpdateChatStatusPreserveUpdatedAt(gomock.Any(), arg).Return(chat, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat)
}))
s.Run("UpdateChatHeartbeat", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := database.UpdateChatHeartbeatParams{
ID: chat.ID,
s.Run("UpdateChatHeartbeats", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
resultID := uuid.New()
arg := database.UpdateChatHeartbeatsParams{
IDs: []uuid.UUID{resultID},
WorkerID: uuid.New(),
Now: time.Now(),
}
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().UpdateChatHeartbeat(gomock.Any(), arg).Return(int64(1), nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(int64(1))
dbm.EXPECT().UpdateChatHeartbeats(gomock.Any(), arg).Return([]uuid.UUID{resultID}, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceChat, policy.ActionUpdate).Returns([]uuid.UUID{resultID})
}))
s.Run("UpdateChatMessageByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
+4 -4
@@ -4136,11 +4136,11 @@ func (m queryMetricsStore) UpdateChatByID(ctx context.Context, arg database.Upda
return r0, r1
}
func (m queryMetricsStore) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateChatHeartbeatParams) (int64, error) {
func (m queryMetricsStore) UpdateChatHeartbeats(ctx context.Context, arg database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
start := time.Now()
r0, r1 := m.s.UpdateChatHeartbeat(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateChatHeartbeat").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatHeartbeat").Inc()
r0, r1 := m.s.UpdateChatHeartbeats(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateChatHeartbeats").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatHeartbeats").Inc()
return r0, r1
}
+7 -7
@@ -7835,19 +7835,19 @@ func (mr *MockStoreMockRecorder) UpdateChatByID(ctx, arg any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatByID", reflect.TypeOf((*MockStore)(nil).UpdateChatByID), ctx, arg)
}
// UpdateChatHeartbeat mocks base method.
func (m *MockStore) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateChatHeartbeatParams) (int64, error) {
// UpdateChatHeartbeats mocks base method.
func (m *MockStore) UpdateChatHeartbeats(ctx context.Context, arg database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateChatHeartbeat", ctx, arg)
ret0, _ := ret[0].(int64)
ret := m.ctrl.Call(m, "UpdateChatHeartbeats", ctx, arg)
ret0, _ := ret[0].([]uuid.UUID)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateChatHeartbeat indicates an expected call of UpdateChatHeartbeat.
func (mr *MockStoreMockRecorder) UpdateChatHeartbeat(ctx, arg any) *gomock.Call {
// UpdateChatHeartbeats indicates an expected call of UpdateChatHeartbeats.
func (mr *MockStoreMockRecorder) UpdateChatHeartbeats(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateChatHeartbeat), ctx, arg)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatHeartbeats", reflect.TypeOf((*MockStore)(nil).UpdateChatHeartbeats), ctx, arg)
}
// UpdateChatLabelsByID mocks base method.
+2
@@ -584,6 +584,7 @@ func (q *sqlQuerier) CountAuthorizedAuditLogs(ctx context.Context, arg CountAudi
arg.DateTo,
arg.BuildReason,
arg.RequestID,
arg.CountCap,
)
if err != nil {
return 0, err
@@ -720,6 +721,7 @@ func (q *sqlQuerier) CountAuthorizedConnectionLogs(ctx context.Context, arg Coun
arg.WorkspaceID,
arg.ConnectionID,
arg.Status,
arg.CountCap,
)
if err != nil {
return 0, err
@@ -145,5 +145,13 @@ func extractWhereClause(query string) string {
// Remove SQL comments
whereClause = regexp.MustCompile(`(?m)--.*$`).ReplaceAllString(whereClause, "")
// Normalize indentation so subquery wrapping doesn't cause
// mismatches.
lines := strings.Split(whereClause, "\n")
for i, line := range lines {
lines[i] = strings.TrimLeft(line, " \t")
}
whereClause = strings.Join(lines, "\n")
return strings.TrimSpace(whereClause)
}
+5 -3
@@ -870,9 +870,11 @@ type sqlcQuerier interface {
UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error
UpdateChatBuildAgentBinding(ctx context.Context, arg UpdateChatBuildAgentBindingParams) (Chat, error)
UpdateChatByID(ctx context.Context, arg UpdateChatByIDParams) (Chat, error)
// Bumps the heartbeat timestamp for a running chat so that other
// replicas know the worker is still alive.
UpdateChatHeartbeat(ctx context.Context, arg UpdateChatHeartbeatParams) (int64, error)
// Bumps the heartbeat timestamp for the given set of chat IDs,
// provided they are still running and owned by the specified
// worker. Returns the IDs that were actually updated so the
// caller can detect stolen or completed chats via set-difference.
UpdateChatHeartbeats(ctx context.Context, arg UpdateChatHeartbeatsParams) ([]uuid.UUID, error)
UpdateChatLabelsByID(ctx context.Context, arg UpdateChatLabelsByIDParams) (Chat, error)
// Updates the cached injected context parts (AGENTS.md +
// skills) on the chat row. Called only when context changes
+242 -204
@@ -2275,93 +2275,105 @@ func (q *sqlQuerier) UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDP
}
const countAuditLogs = `-- name: CountAuditLogs :one
SELECT COUNT(*)
FROM audit_logs
LEFT JOIN users ON audit_logs.user_id = users.id
LEFT JOIN organizations ON audit_logs.organization_id = organizations.id
-- First join on workspaces to get the initial workspace create
-- to workspace build 1 id. This is because the first create
-- is a different audit log than subsequent starts.
LEFT JOIN workspaces ON audit_logs.resource_type = 'workspace'
AND audit_logs.resource_id = workspaces.id
-- Get the reason from the build if the resource type
-- is a workspace_build
LEFT JOIN workspace_builds wb_build ON audit_logs.resource_type = 'workspace_build'
AND audit_logs.resource_id = wb_build.id
-- Get the reason from the build #1 if this is the first
-- workspace create.
LEFT JOIN workspace_builds wb_workspace ON audit_logs.resource_type = 'workspace'
AND audit_logs.action = 'create'
AND workspaces.id = wb_workspace.workspace_id
AND wb_workspace.build_number = 1
WHERE
-- Filter resource_type
CASE
WHEN $1::text != '' THEN resource_type = $1::resource_type
ELSE true
END
-- Filter resource_id
AND CASE
WHEN $2::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN resource_id = $2
ELSE true
END
-- Filter organization_id
AND CASE
WHEN $3::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.organization_id = $3
ELSE true
END
-- Filter by resource_target
AND CASE
WHEN $4::text != '' THEN resource_target = $4
ELSE true
END
-- Filter action
AND CASE
WHEN $5::text != '' THEN action = $5::audit_action
ELSE true
END
-- Filter by user_id
AND CASE
WHEN $6::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN user_id = $6
ELSE true
END
-- Filter by username
AND CASE
WHEN $7::text != '' THEN user_id = (
SELECT id
FROM users
WHERE lower(username) = lower($7)
AND deleted = false
)
ELSE true
END
-- Filter by user_email
AND CASE
WHEN $8::text != '' THEN users.email = $8
ELSE true
END
-- Filter by date_from
AND CASE
WHEN $9::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" >= $9
ELSE true
END
-- Filter by date_to
AND CASE
WHEN $10::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" <= $10
ELSE true
END
-- Filter by build_reason
AND CASE
WHEN $11::text != '' THEN COALESCE(wb_build.reason::text, wb_workspace.reason::text) = $11
ELSE true
END
-- Filter request_id
AND CASE
WHEN $12::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.request_id = $12
ELSE true
END
-- Authorize Filter clause will be injected below in CountAuthorizedAuditLogs
-- @authorize_filter
SELECT COUNT(*) FROM (
SELECT 1
FROM audit_logs
LEFT JOIN users ON audit_logs.user_id = users.id
LEFT JOIN organizations ON audit_logs.organization_id = organizations.id
-- First join on workspaces to get the initial workspace create
-- to workspace build 1 id. This is because the first create
-- is a different audit log than subsequent starts.
LEFT JOIN workspaces ON audit_logs.resource_type = 'workspace'
AND audit_logs.resource_id = workspaces.id
-- Get the reason from the build if the resource type
-- is a workspace_build
LEFT JOIN workspace_builds wb_build ON audit_logs.resource_type = 'workspace_build'
AND audit_logs.resource_id = wb_build.id
-- Get the reason from the build #1 if this is the first
-- workspace create.
LEFT JOIN workspace_builds wb_workspace ON audit_logs.resource_type = 'workspace'
AND audit_logs.action = 'create'
AND workspaces.id = wb_workspace.workspace_id
AND wb_workspace.build_number = 1
WHERE
-- Filter resource_type
CASE
WHEN $1::text != '' THEN resource_type = $1::resource_type
ELSE true
END
-- Filter resource_id
AND CASE
WHEN $2::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN resource_id = $2
ELSE true
END
-- Filter organization_id
AND CASE
WHEN $3::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.organization_id = $3
ELSE true
END
-- Filter by resource_target
AND CASE
WHEN $4::text != '' THEN resource_target = $4
ELSE true
END
-- Filter action
AND CASE
WHEN $5::text != '' THEN action = $5::audit_action
ELSE true
END
-- Filter by user_id
AND CASE
WHEN $6::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN user_id = $6
ELSE true
END
-- Filter by username
AND CASE
WHEN $7::text != '' THEN user_id = (
SELECT id
FROM users
WHERE lower(username) = lower($7)
AND deleted = false
)
ELSE true
END
-- Filter by user_email
AND CASE
WHEN $8::text != '' THEN users.email = $8
ELSE true
END
-- Filter by date_from
AND CASE
WHEN $9::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" >= $9
ELSE true
END
-- Filter by date_to
AND CASE
WHEN $10::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" <= $10
ELSE true
END
-- Filter by build_reason
AND CASE
WHEN $11::text != '' THEN COALESCE(wb_build.reason::text, wb_workspace.reason::text) = $11
ELSE true
END
-- Filter request_id
AND CASE
WHEN $12::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.request_id = $12
ELSE true
END
-- Authorize Filter clause will be injected below in CountAuthorizedAuditLogs
-- @authorize_filter
-- Avoid a slow scan on a large table with joins. The caller
-- passes the count cap and we add 1 so the frontend can detect
-- capping and show "... of N+". A cap of 0 means no limit (NULLIF
-- -> NULL + 1 = NULL).
-- NOTE: This is parameterized so we can easily change the cap,
-- e.g., from 2000 to 5000. However, use a literal NULL (or no
-- LIMIT) here if permanently disabling the cap on a large table.
-- That way the PG planner can plan parallel execution for
-- potentially large wins.
LIMIT NULLIF($13::int, 0) + 1
) AS limited_count
`
type CountAuditLogsParams struct {
@@ -2377,6 +2389,7 @@ type CountAuditLogsParams struct {
DateTo time.Time `db:"date_to" json:"date_to"`
BuildReason string `db:"build_reason" json:"build_reason"`
RequestID uuid.UUID `db:"request_id" json:"request_id"`
CountCap int32 `db:"count_cap" json:"count_cap"`
}
func (q *sqlQuerier) CountAuditLogs(ctx context.Context, arg CountAuditLogsParams) (int64, error) {
@@ -2393,6 +2406,7 @@ func (q *sqlQuerier) CountAuditLogs(ctx context.Context, arg CountAuditLogsParam
arg.DateTo,
arg.BuildReason,
arg.RequestID,
arg.CountCap,
)
var count int64
err := row.Scan(&count)
@@ -6601,30 +6615,49 @@ func (q *sqlQuerier) UpdateChatByID(ctx context.Context, arg UpdateChatByIDParam
return i, err
}
const updateChatHeartbeat = `-- name: UpdateChatHeartbeat :execrows
const updateChatHeartbeats = `-- name: UpdateChatHeartbeats :many
UPDATE
chats
SET
heartbeat_at = NOW()
heartbeat_at = $1::timestamptz
WHERE
id = $1::uuid
AND worker_id = $2::uuid
id = ANY($2::uuid[])
AND worker_id = $3::uuid
AND status = 'running'::chat_status
RETURNING id
`
type UpdateChatHeartbeatParams struct {
ID uuid.UUID `db:"id" json:"id"`
WorkerID uuid.UUID `db:"worker_id" json:"worker_id"`
type UpdateChatHeartbeatsParams struct {
Now time.Time `db:"now" json:"now"`
IDs []uuid.UUID `db:"ids" json:"ids"`
WorkerID uuid.UUID `db:"worker_id" json:"worker_id"`
}
// Bumps the heartbeat timestamp for a running chat so that other
// replicas know the worker is still alive.
func (q *sqlQuerier) UpdateChatHeartbeat(ctx context.Context, arg UpdateChatHeartbeatParams) (int64, error) {
result, err := q.db.ExecContext(ctx, updateChatHeartbeat, arg.ID, arg.WorkerID)
// Bumps the heartbeat timestamp for the given set of chat IDs,
// provided they are still running and owned by the specified
// worker. Returns the IDs that were actually updated so the
// caller can detect stolen or completed chats via set-difference.
func (q *sqlQuerier) UpdateChatHeartbeats(ctx context.Context, arg UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
rows, err := q.db.QueryContext(ctx, updateChatHeartbeats, arg.Now, pq.Array(arg.IDs), arg.WorkerID)
if err != nil {
return 0, err
return nil, err
}
return result.RowsAffected()
defer rows.Close()
var items []uuid.UUID
for rows.Next() {
var id uuid.UUID
if err := rows.Scan(&id); err != nil {
return nil, err
}
items = append(items, id)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const updateChatLabelsByID = `-- name: UpdateChatLabelsByID :one
@@ -7571,110 +7604,113 @@ func (q *sqlQuerier) BatchUpsertConnectionLogs(ctx context.Context, arg BatchUps
}
const countConnectionLogs = `-- name: CountConnectionLogs :one
SELECT
COUNT(*) AS count
FROM
connection_logs
JOIN users AS workspace_owner ON
connection_logs.workspace_owner_id = workspace_owner.id
LEFT JOIN users ON
connection_logs.user_id = users.id
JOIN organizations ON
connection_logs.organization_id = organizations.id
WHERE
-- Filter organization_id
CASE
WHEN $1 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.organization_id = $1
ELSE true
END
-- Filter by workspace owner username
AND CASE
WHEN $2 :: text != '' THEN
workspace_owner_id = (
SELECT id FROM users
WHERE lower(username) = lower($2) AND deleted = false
)
ELSE true
END
-- Filter by workspace_owner_id
AND CASE
WHEN $3 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
workspace_owner_id = $3
ELSE true
END
-- Filter by workspace_owner_email
AND CASE
WHEN $4 :: text != '' THEN
workspace_owner_id = (
SELECT id FROM users
WHERE email = $4 AND deleted = false
)
ELSE true
END
-- Filter by type
AND CASE
WHEN $5 :: text != '' THEN
type = $5 :: connection_type
ELSE true
END
-- Filter by user_id
AND CASE
WHEN $6 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
user_id = $6
ELSE true
END
-- Filter by username
AND CASE
WHEN $7 :: text != '' THEN
user_id = (
SELECT id FROM users
WHERE lower(username) = lower($7) AND deleted = false
)
ELSE true
END
-- Filter by user_email
AND CASE
WHEN $8 :: text != '' THEN
users.email = $8
ELSE true
END
-- Filter by connected_after
AND CASE
WHEN $9 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
connect_time >= $9
ELSE true
END
-- Filter by connected_before
AND CASE
WHEN $10 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
connect_time <= $10
ELSE true
END
-- Filter by workspace_id
AND CASE
WHEN $11 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.workspace_id = $11
ELSE true
END
-- Filter by connection_id
AND CASE
WHEN $12 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.connection_id = $12
ELSE true
END
-- Filter by whether the session has a disconnect_time
AND CASE
WHEN $13 :: text != '' THEN
(($13 = 'ongoing' AND disconnect_time IS NULL) OR
($13 = 'completed' AND disconnect_time IS NOT NULL)) AND
-- Exclude web events, since we don't know their close time.
"type" NOT IN ('workspace_app', 'port_forwarding')
ELSE true
END
-- Authorize Filter clause will be injected below in
-- CountAuthorizedConnectionLogs
-- @authorize_filter
SELECT COUNT(*) AS count FROM (
SELECT 1
FROM
connection_logs
JOIN users AS workspace_owner ON
connection_logs.workspace_owner_id = workspace_owner.id
LEFT JOIN users ON
connection_logs.user_id = users.id
JOIN organizations ON
connection_logs.organization_id = organizations.id
WHERE
-- Filter organization_id
CASE
WHEN $1 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.organization_id = $1
ELSE true
END
-- Filter by workspace owner username
AND CASE
WHEN $2 :: text != '' THEN
workspace_owner_id = (
SELECT id FROM users
WHERE lower(username) = lower($2) AND deleted = false
)
ELSE true
END
-- Filter by workspace_owner_id
AND CASE
WHEN $3 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
workspace_owner_id = $3
ELSE true
END
-- Filter by workspace_owner_email
AND CASE
WHEN $4 :: text != '' THEN
workspace_owner_id = (
SELECT id FROM users
WHERE email = $4 AND deleted = false
)
ELSE true
END
-- Filter by type
AND CASE
WHEN $5 :: text != '' THEN
type = $5 :: connection_type
ELSE true
END
-- Filter by user_id
AND CASE
WHEN $6 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
user_id = $6
ELSE true
END
-- Filter by username
AND CASE
WHEN $7 :: text != '' THEN
user_id = (
SELECT id FROM users
WHERE lower(username) = lower($7) AND deleted = false
)
ELSE true
END
-- Filter by user_email
AND CASE
WHEN $8 :: text != '' THEN
users.email = $8
ELSE true
END
-- Filter by connected_after
AND CASE
WHEN $9 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
connect_time >= $9
ELSE true
END
-- Filter by connected_before
AND CASE
WHEN $10 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
connect_time <= $10
ELSE true
END
-- Filter by workspace_id
AND CASE
WHEN $11 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.workspace_id = $11
ELSE true
END
-- Filter by connection_id
AND CASE
WHEN $12 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.connection_id = $12
ELSE true
END
-- Filter by whether the session has a disconnect_time
AND CASE
WHEN $13 :: text != '' THEN
(($13 = 'ongoing' AND disconnect_time IS NULL) OR
($13 = 'completed' AND disconnect_time IS NOT NULL)) AND
-- Exclude web events, since we don't know their close time.
"type" NOT IN ('workspace_app', 'port_forwarding')
ELSE true
END
-- Authorize Filter clause will be injected below in
-- CountAuthorizedConnectionLogs
-- @authorize_filter
-- NOTE: See the CountAuditLogs LIMIT note.
LIMIT NULLIF($14::int, 0) + 1
) AS limited_count
`
type CountConnectionLogsParams struct {
@@ -7691,6 +7727,7 @@ type CountConnectionLogsParams struct {
WorkspaceID uuid.UUID `db:"workspace_id" json:"workspace_id"`
ConnectionID uuid.UUID `db:"connection_id" json:"connection_id"`
Status string `db:"status" json:"status"`
CountCap int32 `db:"count_cap" json:"count_cap"`
}
func (q *sqlQuerier) CountConnectionLogs(ctx context.Context, arg CountConnectionLogsParams) (int64, error) {
@@ -7708,6 +7745,7 @@ func (q *sqlQuerier) CountConnectionLogs(ctx context.Context, arg CountConnectio
arg.WorkspaceID,
arg.ConnectionID,
arg.Status,
arg.CountCap,
)
var count int64
err := row.Scan(&count)
+99 -88
@@ -149,94 +149,105 @@ VALUES (
RETURNING *;
-- name: CountAuditLogs :one
SELECT COUNT(*)
FROM audit_logs
LEFT JOIN users ON audit_logs.user_id = users.id
LEFT JOIN organizations ON audit_logs.organization_id = organizations.id
-- First join on workspaces to get the initial workspace create
-- to workspace build 1 id. This is because the first create
-- is a different audit log than subsequent starts.
LEFT JOIN workspaces ON audit_logs.resource_type = 'workspace'
AND audit_logs.resource_id = workspaces.id
-- Get the reason from the build if the resource type
-- is a workspace_build
LEFT JOIN workspace_builds wb_build ON audit_logs.resource_type = 'workspace_build'
AND audit_logs.resource_id = wb_build.id
-- Get the reason from the build #1 if this is the first
-- workspace create.
LEFT JOIN workspace_builds wb_workspace ON audit_logs.resource_type = 'workspace'
AND audit_logs.action = 'create'
AND workspaces.id = wb_workspace.workspace_id
AND wb_workspace.build_number = 1
WHERE
-- Filter resource_type
CASE
WHEN @resource_type::text != '' THEN resource_type = @resource_type::resource_type
ELSE true
END
-- Filter resource_id
AND CASE
WHEN @resource_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN resource_id = @resource_id
ELSE true
END
-- Filter organization_id
AND CASE
WHEN @organization_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.organization_id = @organization_id
ELSE true
END
-- Filter by resource_target
AND CASE
WHEN @resource_target::text != '' THEN resource_target = @resource_target
ELSE true
END
-- Filter action
AND CASE
WHEN @action::text != '' THEN action = @action::audit_action
ELSE true
END
-- Filter by user_id
AND CASE
WHEN @user_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN user_id = @user_id
ELSE true
END
-- Filter by username
AND CASE
WHEN @username::text != '' THEN user_id = (
SELECT id
FROM users
WHERE lower(username) = lower(@username)
AND deleted = false
)
ELSE true
END
-- Filter by user_email
AND CASE
WHEN @email::text != '' THEN users.email = @email
ELSE true
END
-- Filter by date_from
AND CASE
WHEN @date_from::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" >= @date_from
ELSE true
END
-- Filter by date_to
AND CASE
WHEN @date_to::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" <= @date_to
ELSE true
END
-- Filter by build_reason
AND CASE
WHEN @build_reason::text != '' THEN COALESCE(wb_build.reason::text, wb_workspace.reason::text) = @build_reason
ELSE true
END
-- Filter request_id
AND CASE
WHEN @request_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.request_id = @request_id
ELSE true
END
-- Authorize Filter clause will be injected below in CountAuthorizedAuditLogs
-- @authorize_filter
;
SELECT COUNT(*) FROM (
SELECT 1
FROM audit_logs
LEFT JOIN users ON audit_logs.user_id = users.id
LEFT JOIN organizations ON audit_logs.organization_id = organizations.id
-- First join on workspaces to get the initial workspace create
-- to workspace build 1 id. This is because the first create
-- is a different audit log than subsequent starts.
LEFT JOIN workspaces ON audit_logs.resource_type = 'workspace'
AND audit_logs.resource_id = workspaces.id
-- Get the reason from the build if the resource type
-- is a workspace_build
LEFT JOIN workspace_builds wb_build ON audit_logs.resource_type = 'workspace_build'
AND audit_logs.resource_id = wb_build.id
-- Get the reason from the build #1 if this is the first
-- workspace create.
LEFT JOIN workspace_builds wb_workspace ON audit_logs.resource_type = 'workspace'
AND audit_logs.action = 'create'
AND workspaces.id = wb_workspace.workspace_id
AND wb_workspace.build_number = 1
WHERE
-- Filter resource_type
CASE
WHEN @resource_type::text != '' THEN resource_type = @resource_type::resource_type
ELSE true
END
-- Filter resource_id
AND CASE
WHEN @resource_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN resource_id = @resource_id
ELSE true
END
-- Filter organization_id
AND CASE
WHEN @organization_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.organization_id = @organization_id
ELSE true
END
-- Filter by resource_target
AND CASE
WHEN @resource_target::text != '' THEN resource_target = @resource_target
ELSE true
END
-- Filter action
AND CASE
WHEN @action::text != '' THEN action = @action::audit_action
ELSE true
END
-- Filter by user_id
AND CASE
WHEN @user_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN user_id = @user_id
ELSE true
END
-- Filter by username
AND CASE
WHEN @username::text != '' THEN user_id = (
SELECT id
FROM users
WHERE lower(username) = lower(@username)
AND deleted = false
)
ELSE true
END
-- Filter by user_email
AND CASE
WHEN @email::text != '' THEN users.email = @email
ELSE true
END
-- Filter by date_from
AND CASE
WHEN @date_from::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" >= @date_from
ELSE true
END
-- Filter by date_to
AND CASE
WHEN @date_to::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" <= @date_to
ELSE true
END
-- Filter by build_reason
AND CASE
WHEN @build_reason::text != '' THEN COALESCE(wb_build.reason::text, wb_workspace.reason::text) = @build_reason
ELSE true
END
-- Filter request_id
AND CASE
WHEN @request_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.request_id = @request_id
ELSE true
END
-- Authorize Filter clause will be injected below in CountAuthorizedAuditLogs
-- @authorize_filter
-- Avoid a slow scan on a large table with joins. The caller
-- passes the count cap and we add 1 so the frontend can detect
-- capping and show "... of N+". A cap of 0 means no limit (NULLIF
-- -> NULL + 1 = NULL).
-- NOTE: This is parameterized so we can easily change the cap,
-- e.g., from 2000 to 5000. However, use a literal NULL (or no
-- LIMIT) here if permanently disabling the cap on a large table.
-- That way the PG planner can plan parallel execution for
-- potentially large wins.
LIMIT NULLIF(@count_cap::int, 0) + 1
) AS limited_count;
-- name: DeleteOldAuditLogConnectionEvents :exec
DELETE FROM audit_logs
+9 -6
@@ -674,17 +674,20 @@ WHERE
status = 'running'::chat_status
AND heartbeat_at < @stale_threshold::timestamptz;
-- name: UpdateChatHeartbeat :execrows
-- Bumps the heartbeat timestamp for a running chat so that other
-- replicas know the worker is still alive.
-- name: UpdateChatHeartbeats :many
-- Bumps the heartbeat timestamp for the given set of chat IDs,
-- provided they are still running and owned by the specified
-- worker. Returns the IDs that were actually updated so the
-- caller can detect stolen or completed chats via set-difference.
UPDATE
chats
SET
heartbeat_at = NOW()
heartbeat_at = @now::timestamptz
WHERE
id = @id::uuid
id = ANY(@ids::uuid[])
AND worker_id = @worker_id::uuid
AND status = 'running'::chat_status;
AND status = 'running'::chat_status
RETURNING id;
-- name: GetChatDiffStatusByChatID :one
SELECT
+107 -105
@@ -133,111 +133,113 @@ OFFSET
@offset_opt;
-- name: CountConnectionLogs :one
SELECT
COUNT(*) AS count
FROM
connection_logs
JOIN users AS workspace_owner ON
connection_logs.workspace_owner_id = workspace_owner.id
LEFT JOIN users ON
connection_logs.user_id = users.id
JOIN organizations ON
connection_logs.organization_id = organizations.id
WHERE
-- Filter organization_id
CASE
WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.organization_id = @organization_id
ELSE true
END
-- Filter by workspace owner username
AND CASE
WHEN @workspace_owner :: text != '' THEN
workspace_owner_id = (
SELECT id FROM users
WHERE lower(username) = lower(@workspace_owner) AND deleted = false
)
ELSE true
END
-- Filter by workspace_owner_id
AND CASE
WHEN @workspace_owner_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
workspace_owner_id = @workspace_owner_id
ELSE true
END
-- Filter by workspace_owner_email
AND CASE
WHEN @workspace_owner_email :: text != '' THEN
workspace_owner_id = (
SELECT id FROM users
WHERE email = @workspace_owner_email AND deleted = false
)
ELSE true
END
-- Filter by type
AND CASE
WHEN @type :: text != '' THEN
type = @type :: connection_type
ELSE true
END
-- Filter by user_id
AND CASE
WHEN @user_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
user_id = @user_id
ELSE true
END
-- Filter by username
AND CASE
WHEN @username :: text != '' THEN
user_id = (
SELECT id FROM users
WHERE lower(username) = lower(@username) AND deleted = false
)
ELSE true
END
-- Filter by user_email
AND CASE
WHEN @user_email :: text != '' THEN
users.email = @user_email
ELSE true
END
-- Filter by connected_after
AND CASE
WHEN @connected_after :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
connect_time >= @connected_after
ELSE true
END
-- Filter by connected_before
AND CASE
WHEN @connected_before :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
connect_time <= @connected_before
ELSE true
END
-- Filter by workspace_id
AND CASE
WHEN @workspace_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.workspace_id = @workspace_id
ELSE true
END
-- Filter by connection_id
AND CASE
WHEN @connection_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.connection_id = @connection_id
ELSE true
END
-- Filter by whether the session has a disconnect_time
AND CASE
WHEN @status :: text != '' THEN
((@status = 'ongoing' AND disconnect_time IS NULL) OR
(@status = 'completed' AND disconnect_time IS NOT NULL)) AND
-- Exclude web events, since we don't know their close time.
"type" NOT IN ('workspace_app', 'port_forwarding')
ELSE true
END
-- Authorize Filter clause will be injected below in
-- CountAuthorizedConnectionLogs
-- @authorize_filter
;
SELECT COUNT(*) AS count FROM (
SELECT 1
FROM
connection_logs
JOIN users AS workspace_owner ON
connection_logs.workspace_owner_id = workspace_owner.id
LEFT JOIN users ON
connection_logs.user_id = users.id
JOIN organizations ON
connection_logs.organization_id = organizations.id
WHERE
-- Filter organization_id
CASE
WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.organization_id = @organization_id
ELSE true
END
-- Filter by workspace owner username
AND CASE
WHEN @workspace_owner :: text != '' THEN
workspace_owner_id = (
SELECT id FROM users
WHERE lower(username) = lower(@workspace_owner) AND deleted = false
)
ELSE true
END
-- Filter by workspace_owner_id
AND CASE
WHEN @workspace_owner_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
workspace_owner_id = @workspace_owner_id
ELSE true
END
-- Filter by workspace_owner_email
AND CASE
WHEN @workspace_owner_email :: text != '' THEN
workspace_owner_id = (
SELECT id FROM users
WHERE email = @workspace_owner_email AND deleted = false
)
ELSE true
END
-- Filter by type
AND CASE
WHEN @type :: text != '' THEN
type = @type :: connection_type
ELSE true
END
-- Filter by user_id
AND CASE
WHEN @user_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
user_id = @user_id
ELSE true
END
-- Filter by username
AND CASE
WHEN @username :: text != '' THEN
user_id = (
SELECT id FROM users
WHERE lower(username) = lower(@username) AND deleted = false
)
ELSE true
END
-- Filter by user_email
AND CASE
WHEN @user_email :: text != '' THEN
users.email = @user_email
ELSE true
END
-- Filter by connected_after
AND CASE
WHEN @connected_after :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
connect_time >= @connected_after
ELSE true
END
-- Filter by connected_before
AND CASE
WHEN @connected_before :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
connect_time <= @connected_before
ELSE true
END
-- Filter by workspace_id
AND CASE
WHEN @workspace_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.workspace_id = @workspace_id
ELSE true
END
-- Filter by connection_id
AND CASE
WHEN @connection_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
connection_logs.connection_id = @connection_id
ELSE true
END
-- Filter by whether the session has a disconnect_time
AND CASE
WHEN @status :: text != '' THEN
((@status = 'ongoing' AND disconnect_time IS NULL) OR
(@status = 'completed' AND disconnect_time IS NOT NULL)) AND
-- Exclude web events, since we don't know their close time.
"type" NOT IN ('workspace_app', 'port_forwarding')
ELSE true
END
-- Authorize Filter clause will be injected below in
-- CountAuthorizedConnectionLogs
-- @authorize_filter
-- NOTE: See the CountAuditLogs LIMIT note.
LIMIT NULLIF(@count_cap::int, 0) + 1
) AS limited_count;
-- name: DeleteOldConnectionLogs :execrows
WITH old_logs AS (
+34
@@ -298,6 +298,40 @@ neq(input.object.owner, "");
ExpectedSQL: p("'' = 'org-id'"),
VariableConverter: regosql.ChatConverter(),
},
{
Name: "AuditLogUUID",
Queries: []string{
`"8c0b9bdc-a013-4b14-a49b-5747bc335708" = input.object.org_owner`,
`input.object.org_owner != ""`,
`neq(input.object.org_owner, "8c0b9bdc-a013-4b14-a49b-5747bc335708")`,
`input.object.org_owner in {"8c0b9bdc-a013-4b14-a49b-5747bc335708", "05f58202-4bfc-43ce-9ba4-5ff6e0174a71"}`,
`"read" in input.object.acl_group_list[input.object.org_owner]`,
},
ExpectedSQL: p(
p("audit_logs.organization_id = '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " +
p("audit_logs.organization_id IS NOT NULL") + " OR " +
p("audit_logs.organization_id != '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " +
p("audit_logs.organization_id = ANY(ARRAY ['05f58202-4bfc-43ce-9ba4-5ff6e0174a71'::uuid,'8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid])") + " OR " +
"(false)"),
VariableConverter: regosql.AuditLogConverter(),
},
{
Name: "ConnectionLogUUID",
Queries: []string{
`"8c0b9bdc-a013-4b14-a49b-5747bc335708" = input.object.org_owner`,
`input.object.org_owner != ""`,
`neq(input.object.org_owner, "8c0b9bdc-a013-4b14-a49b-5747bc335708")`,
`input.object.org_owner in {"8c0b9bdc-a013-4b14-a49b-5747bc335708"}`,
`"read" in input.object.acl_group_list[input.object.org_owner]`,
},
ExpectedSQL: p(
p("connection_logs.organization_id = '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " +
p("connection_logs.organization_id IS NOT NULL") + " OR " +
p("connection_logs.organization_id != '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " +
p("connection_logs.organization_id = ANY(ARRAY ['8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid])") + " OR " +
"(false)"),
VariableConverter: regosql.ConnectionLogConverter(),
},
}
for _, tc := range testCases {
+2 -2
@@ -53,7 +53,7 @@ func WorkspaceConverter() *sqltypes.VariableConverter {
func AuditLogConverter() *sqltypes.VariableConverter {
matcher := sqltypes.NewVariableConverter().RegisterMatcher(
resourceIDMatcher(),
sqltypes.StringVarMatcher("COALESCE(audit_logs.organization_id :: text, '')", []string{"input", "object", "org_owner"}),
sqltypes.UUIDVarMatcher("audit_logs.organization_id", []string{"input", "object", "org_owner"}),
// Audit logs have no user owner; they are owned only by an organization.
sqltypes.AlwaysFalse(userOwnerMatcher()),
)
@@ -67,7 +67,7 @@ func AuditLogConverter() *sqltypes.VariableConverter {
func ConnectionLogConverter() *sqltypes.VariableConverter {
matcher := sqltypes.NewVariableConverter().RegisterMatcher(
resourceIDMatcher(),
sqltypes.StringVarMatcher("COALESCE(connection_logs.organization_id :: text, '')", []string{"input", "object", "org_owner"}),
sqltypes.UUIDVarMatcher("connection_logs.organization_id", []string{"input", "object", "org_owner"}),
// Connection logs have no user owner; they are owned only by an organization.
sqltypes.AlwaysFalse(userOwnerMatcher()),
)
+114
@@ -0,0 +1,114 @@
package sqltypes
import (
"fmt"
"strings"
"github.com/open-policy-agent/opa/ast"
"golang.org/x/xerrors"
)
var (
_ VariableMatcher = astUUIDVar{}
_ Node = astUUIDVar{}
_ SupportsEquality = astUUIDVar{}
)
// astUUIDVar is a variable that represents a UUID column. Unlike
// astStringVar it emits native UUID comparisons (column = 'val'::uuid)
// instead of text-based ones (COALESCE(column::text, '') = 'val').
// This allows PostgreSQL to use indexes on UUID columns.
type astUUIDVar struct {
Source RegoSource
FieldPath []string
ColumnString string
}
func UUIDVarMatcher(sqlColumn string, regoPath []string) VariableMatcher {
return astUUIDVar{FieldPath: regoPath, ColumnString: sqlColumn}
}
func (astUUIDVar) UseAs() Node { return astUUIDVar{} }
func (u astUUIDVar) ConvertVariable(rego ast.Ref) (Node, bool) {
left, err := RegoVarPath(u.FieldPath, rego)
if err == nil && len(left) == 0 {
return astUUIDVar{
Source: RegoSource(rego.String()),
FieldPath: u.FieldPath,
ColumnString: u.ColumnString,
}, true
}
return nil, false
}
func (u astUUIDVar) SQLString(_ *SQLGenerator) string {
return u.ColumnString
}
// EqualsSQLString handles equality comparisons for UUID columns.
// Rego always produces string literals, so we accept AstString and
// cast the literal to ::uuid in the output SQL. This lets PG use
// native UUID indexes instead of falling back to text comparisons.
// nolint:revive
func (u astUUIDVar) EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error) {
switch other.UseAs().(type) {
case AstString:
// The other side is a rego string literal like
// "8c0b9bdc-a013-4b14-a49b-5747bc335708". Emit a comparison
// that casts the literal to uuid so PG can use indexes:
// column = 'val'::uuid
// instead of the text-based:
// 'val' = COALESCE(column::text, '')
s, ok := other.(AstString)
if !ok {
return "", xerrors.Errorf("expected AstString, got %T", other)
}
if s.Value == "" {
// Empty string in rego means "no value". Compare the
// column against NULL since UUID columns represent
// absent values as NULL, not empty strings.
op := "IS NULL"
if not {
op = "IS NOT NULL"
}
return fmt.Sprintf("%s %s", u.ColumnString, op), nil
}
return fmt.Sprintf("%s %s '%s'::uuid",
u.ColumnString, equalsOp(not), s.Value), nil
case astUUIDVar:
return basicSQLEquality(cfg, not, u, other), nil
default:
return "", xerrors.Errorf("unsupported equality: %T %s %T",
u, equalsOp(not), other)
}
}
// ContainedInSQL implements SupportsContainedIn so that a UUID column
// can appear in membership checks like `col = ANY(ARRAY[...])`. The
// array elements are rego strings, so we cast each to ::uuid.
func (u astUUIDVar) ContainedInSQL(_ *SQLGenerator, haystack Node) (string, error) {
arr, ok := haystack.(ASTArray)
if !ok {
return "", xerrors.Errorf("unsupported containedIn: %T in %T", u, haystack)
}
if len(arr.Value) == 0 {
return "false", nil
}
// Build ARRAY['uuid1'::uuid, 'uuid2'::uuid, ...]
values := make([]string, 0, len(arr.Value))
for _, v := range arr.Value {
s, ok := v.(AstString)
if !ok {
return "", xerrors.Errorf("expected AstString array element, got %T", v)
}
values = append(values, fmt.Sprintf("'%s'::uuid", s.Value))
}
return fmt.Sprintf("%s = ANY(ARRAY [%s])",
u.ColumnString,
strings.Join(values, ",")), nil
}
+2 -1
@@ -66,7 +66,7 @@ func AuditLogs(ctx context.Context, db database.Store, query string) (database.G
}
// Prepare the count filter, which uses the same parameters as the GetAuditLogsOffsetParams.
// nolint:exhaustruct // UserID is not obtained from the query parameters.
// nolint:exhaustruct // UserID and CountCap are not obtained from the query parameters.
countFilter := database.CountAuditLogsParams{
RequestID: filter.RequestID,
ResourceID: filter.ResourceID,
@@ -123,6 +123,7 @@ func ConnectionLogs(ctx context.Context, db database.Store, query string, apiKey
}
// This MUST be kept in sync with the above
// nolint:exhaustruct // CountCap is not obtained from the query parameters.
countFilter := database.CountConnectionLogsParams{
OrganizationID: filter.OrganizationID,
WorkspaceOwner: filter.WorkspaceOwner,
+124 -28
@@ -7,6 +7,7 @@ import (
"encoding/json"
"errors"
"fmt"
"maps"
"net/http"
"slices"
"strconv"
@@ -151,6 +152,12 @@ type Server struct {
inFlightChatStaleAfter time.Duration
chatHeartbeatInterval time.Duration
// heartbeatMu guards heartbeatRegistry.
heartbeatMu sync.Mutex
// heartbeatRegistry maps chat IDs to their cancel functions
// and workspace state for the centralized heartbeat loop.
heartbeatRegistry map[uuid.UUID]*heartbeatEntry
// wakeCh is signaled by SendMessage, EditMessage, CreateChat,
// and PromoteQueued so the run loop calls processOnce
// immediately instead of waiting for the next ticker.
@@ -706,6 +713,17 @@ type chatStreamState struct {
bufferRetainedAt time.Time
}
// heartbeatEntry tracks a single chat's cancel function and workspace
// state for the centralized heartbeat loop. Instead of spawning a
// per-chat goroutine, processChat registers an entry here and the
// single heartbeatLoop goroutine handles all chats.
type heartbeatEntry struct {
cancelWithCause context.CancelCauseFunc
chatID uuid.UUID
workspaceID uuid.NullUUID
logger slog.Logger
}
// resetDropCounters zeroes the rate-limiting state for both buffer
// and subscriber drop warnings. The caller must hold s.mu.
func (s *chatStreamState) resetDropCounters() {
@@ -2420,8 +2438,8 @@ func New(cfg Config) *Server {
clock: clk,
recordingSem: make(chan struct{}, maxConcurrentRecordingUploads),
wakeCh: make(chan struct{}, 1),
heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
}
//nolint:gocritic // The chat processor uses a scoped chatd context.
ctx = dbauthz.AsChatd(ctx)
@@ -2461,6 +2479,9 @@ func (p *Server) start(ctx context.Context) {
// to handle chats orphaned by crashed or redeployed workers.
p.recoverStaleChats(ctx)
// Single heartbeat loop for all chats on this replica.
go p.heartbeatLoop(ctx)
acquireTicker := p.clock.NewTicker(
p.pendingChatAcquireInterval,
"chatd",
@@ -2730,6 +2751,97 @@ func (p *Server) cleanupStreamIfIdle(chatID uuid.UUID, state *chatStreamState) {
p.workspaceMCPToolsCache.Delete(chatID)
}
// registerHeartbeat enrolls a chat in the centralized batch
// heartbeat loop. Must be called after chatCtx is created.
func (p *Server) registerHeartbeat(entry *heartbeatEntry) {
p.heartbeatMu.Lock()
defer p.heartbeatMu.Unlock()
if _, exists := p.heartbeatRegistry[entry.chatID]; exists {
p.logger.Warn(context.Background(),
"duplicate heartbeat registration, skipping",
slog.F("chat_id", entry.chatID))
return
}
p.heartbeatRegistry[entry.chatID] = entry
}
// unregisterHeartbeat removes a chat from the centralized
// heartbeat loop when chat processing finishes.
func (p *Server) unregisterHeartbeat(chatID uuid.UUID) {
p.heartbeatMu.Lock()
defer p.heartbeatMu.Unlock()
delete(p.heartbeatRegistry, chatID)
}
// heartbeatLoop runs in a single goroutine, issuing one batch
// heartbeat query per interval for all registered chats.
func (p *Server) heartbeatLoop(ctx context.Context) {
ticker := p.clock.NewTicker(p.chatHeartbeatInterval, "chatd", "batch-heartbeat")
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
p.heartbeatTick(ctx)
}
}
}
// heartbeatTick issues a single batch UPDATE for all running chats
// owned by this worker. Chats missing from the result set are
// interrupted (stolen by another replica or already completed).
func (p *Server) heartbeatTick(ctx context.Context) {
// Snapshot the registry under the lock.
p.heartbeatMu.Lock()
snapshot := maps.Clone(p.heartbeatRegistry)
p.heartbeatMu.Unlock()
if len(snapshot) == 0 {
return
}
// Collect the IDs we believe we own.
ids := slices.Collect(maps.Keys(snapshot))
//nolint:gocritic // AsChatd provides narrowly-scoped daemon
// access for batch-updating heartbeats.
chatdCtx := dbauthz.AsChatd(ctx)
updatedIDs, err := p.db.UpdateChatHeartbeats(chatdCtx, database.UpdateChatHeartbeatsParams{
IDs: ids,
WorkerID: p.workerID,
Now: p.clock.Now(),
})
if err != nil {
p.logger.Error(ctx, "batch heartbeat failed", slog.Error(err))
return
}
// Build a set of IDs that were successfully updated.
updated := make(map[uuid.UUID]struct{}, len(updatedIDs))
for _, id := range updatedIDs {
updated[id] = struct{}{}
}
// Interrupt registered chats that were not in the result
// (stolen by another replica or already completed).
for id, entry := range snapshot {
if _, ok := updated[id]; !ok {
entry.logger.Warn(ctx, "chat not in batch heartbeat result, interrupting")
entry.cancelWithCause(chatloop.ErrInterrupted)
continue
}
// Bump workspace usage for surviving chats.
newWsID := p.trackWorkspaceUsage(ctx, entry.chatID, entry.workspaceID, entry.logger)
// Update workspace ID in the registry for next tick.
p.heartbeatMu.Lock()
if current, exists := p.heartbeatRegistry[id]; exists {
current.workspaceID = newWsID
}
p.heartbeatMu.Unlock()
}
}
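The set-difference step in heartbeatTick can be sketched in isolation. The helper below is illustrative only (the real loop keys on uuid.UUID and cancels each missing entry with chatloop.ErrInterrupted); it shows how chats registered locally but absent from the batch UPDATE's RETURNING set are flagged:

```go
package main

import "fmt"

// stolenChats returns the registered IDs that the batch UPDATE did not
// return, i.e. chats stolen by another replica or already completed.
// Illustrative stand-in for heartbeatTick's set-difference check.
func stolenChats(registered, updated []string) []string {
	seen := make(map[string]struct{}, len(updated))
	for _, id := range updated {
		seen[id] = struct{}{}
	}
	var stolen []string
	for _, id := range registered {
		if _, ok := seen[id]; !ok {
			stolen = append(stolen, id)
		}
	}
	return stolen
}

func main() {
	// chat3 is registered locally but absent from the UPDATE result.
	fmt.Println(stolenChats(
		[]string{"chat1", "chat2", "chat3"},
		[]string{"chat1", "chat2"},
	)) // [chat3]
}
```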
func (p *Server) Subscribe(
ctx context.Context,
chatID uuid.UUID,
@@ -3575,33 +3687,17 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) {
}
}()
// Periodically update the heartbeat so other replicas know this
// worker is still alive. The goroutine stops when chatCtx is
// canceled (either by completion or interruption).
go func() {
ticker := p.clock.NewTicker(p.chatHeartbeatInterval, "chatd", "heartbeat")
defer ticker.Stop()
for {
select {
case <-chatCtx.Done():
return
case <-ticker.C:
rows, err := p.db.UpdateChatHeartbeat(chatCtx, database.UpdateChatHeartbeatParams{
ID: chat.ID,
WorkerID: p.workerID,
})
if err != nil {
logger.Warn(chatCtx, "failed to update chat heartbeat", slog.Error(err))
continue
}
if rows == 0 {
cancel(chatloop.ErrInterrupted)
return
}
chat.WorkspaceID = p.trackWorkspaceUsage(chatCtx, chat.ID, chat.WorkspaceID, logger)
}
}
}()
// Register with the centralized heartbeat loop instead of
// running a per-chat goroutine. The loop issues a single batch
// UPDATE for all chats on this worker and detects stolen chats
// via set-difference.
p.registerHeartbeat(&heartbeatEntry{
cancelWithCause: cancel,
chatID: chat.ID,
workspaceID: chat.WorkspaceID,
logger: logger,
})
defer p.unregisterHeartbeat(chat.ID)
// Start buffering stream events BEFORE publishing the running
// status. This closes a race where a subscriber sees
@@ -21,6 +21,7 @@ import (
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
"github.com/coder/coder/v2/coderd/x/chatd/chaterror"
"github.com/coder/coder/v2/coderd/x/chatd/chatloop"
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
"github.com/coder/coder/v2/coderd/x/chatd/chattool"
@@ -2071,6 +2072,7 @@ func TestProcessChat_IgnoresStaleControlNotification(t *testing.T) {
workerID: workerID,
chatHeartbeatInterval: time.Minute,
configCache: newChatConfigCache(ctx, db, clock),
heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
}
// Publish a stale "pending" notification on the control channel
@@ -2133,3 +2135,130 @@ func TestProcessChat_IgnoresStaleControlNotification(t *testing.T) {
require.Equal(t, database.ChatStatusError, finalStatus,
"processChat should have reached runChat (error), not been interrupted (waiting)")
}
// TestHeartbeatTick_StolenChatIsInterrupted verifies that when the
// batch heartbeat UPDATE does not return a registered chat's ID
// (because another replica stole it or it was completed), the
// heartbeat tick cancels that chat's context with ErrInterrupted
// while leaving surviving chats untouched.
func TestHeartbeatTick_StolenChatIsInterrupted(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
clock := quartz.NewMock(t)
workerID := uuid.New()
server := &Server{
db: db,
logger: logger,
clock: clock,
workerID: workerID,
chatHeartbeatInterval: time.Minute,
heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
}
// Create three chats with independent cancel functions.
chat1 := uuid.New()
chat2 := uuid.New()
chat3 := uuid.New()
_, cancel1 := context.WithCancelCause(ctx)
_, cancel2 := context.WithCancelCause(ctx)
ctx3, cancel3 := context.WithCancelCause(ctx)
server.registerHeartbeat(&heartbeatEntry{
cancelWithCause: cancel1,
chatID: chat1,
logger: logger,
})
server.registerHeartbeat(&heartbeatEntry{
cancelWithCause: cancel2,
chatID: chat2,
logger: logger,
})
server.registerHeartbeat(&heartbeatEntry{
cancelWithCause: cancel3,
chatID: chat3,
logger: logger,
})
// The batch UPDATE returns only chat1 and chat2 —
// chat3 was "stolen" by another replica.
db.EXPECT().UpdateChatHeartbeats(gomock.Any(), gomock.Any()).DoAndReturn(
func(_ context.Context, params database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
require.Equal(t, workerID, params.WorkerID)
require.Len(t, params.IDs, 3)
// Return only chat1 and chat2 as surviving.
return []uuid.UUID{chat1, chat2}, nil
},
)
server.heartbeatTick(ctx)
// chat3's context should be canceled with ErrInterrupted.
require.ErrorIs(t, context.Cause(ctx3), chatloop.ErrInterrupted,
"stolen chat should be interrupted")
// The heartbeat tick only cancels a stolen chat's context;
// it does not remove the entry from the registry. In
// production, processChat's deferred unregisterHeartbeat
// performs that cleanup. Verify all three entries are
// still present.
server.heartbeatMu.Lock()
_, chat1Exists := server.heartbeatRegistry[chat1]
_, chat2Exists := server.heartbeatRegistry[chat2]
_, chat3Exists := server.heartbeatRegistry[chat3]
server.heartbeatMu.Unlock()
require.True(t, chat1Exists, "surviving chat1 should remain registered")
require.True(t, chat2Exists, "surviving chat2 should remain registered")
require.True(t, chat3Exists,
"stolen chat3 should still be in registry (processChat defer removes it)")
}
// TestHeartbeatTick_DBErrorDoesNotInterruptChats verifies that a
// transient database failure causes the tick to log and return
// without canceling any registered chats.
func TestHeartbeatTick_DBErrorDoesNotInterruptChats(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
clock := quartz.NewMock(t)
server := &Server{
db: db,
logger: logger,
clock: clock,
workerID: uuid.New(),
chatHeartbeatInterval: time.Minute,
heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
}
chatID := uuid.New()
chatCtx, cancel := context.WithCancelCause(ctx)
server.registerHeartbeat(&heartbeatEntry{
cancelWithCause: cancel,
chatID: chatID,
logger: logger,
})
// Simulate a transient DB error.
db.EXPECT().UpdateChatHeartbeats(gomock.Any(), gomock.Any()).Return(
nil, xerrors.New("connection reset"),
)
server.heartbeatTick(ctx)
// Chat should NOT be interrupted — the tick logged and
// returned early.
require.NoError(t, chatCtx.Err(),
"chat context should not be canceled on transient DB error")
}
@@ -474,7 +474,7 @@ func TestArchiveChatInterruptsActiveProcessing(t *testing.T) {
require.Equal(t, 1, userMessages, "expected queued message to stay queued after archive")
}
func TestUpdateChatHeartbeatRequiresOwnership(t *testing.T) {
func TestUpdateChatHeartbeatsRequiresOwnership(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
@@ -501,19 +501,24 @@ func TestUpdateChatHeartbeatRequiresOwnership(t *testing.T) {
})
require.NoError(t, err)
rows, err := db.UpdateChatHeartbeat(ctx, database.UpdateChatHeartbeatParams{
ID: chat.ID,
// Wrong worker_id should return no IDs.
ids, err := db.UpdateChatHeartbeats(ctx, database.UpdateChatHeartbeatsParams{
IDs: []uuid.UUID{chat.ID},
WorkerID: uuid.New(),
Now: time.Now(),
})
require.NoError(t, err)
require.Equal(t, int64(0), rows)
require.Empty(t, ids)
rows, err = db.UpdateChatHeartbeat(ctx, database.UpdateChatHeartbeatParams{
ID: chat.ID,
// Correct worker_id should return the chat's ID.
ids, err = db.UpdateChatHeartbeats(ctx, database.UpdateChatHeartbeatsParams{
IDs: []uuid.UUID{chat.ID},
WorkerID: workerID,
Now: time.Now(),
})
require.NoError(t, err)
require.Equal(t, int64(1), rows)
require.Len(t, ids, 1)
require.Equal(t, chat.ID, ids[0])
}
func TestSendMessageQueueBehaviorQueuesWhenBusy(t *testing.T) {
@@ -9,6 +9,8 @@ import (
"cmp"
"fmt"
"net/http"
"net/url"
"slices"
"strings"
"sync"
"time"
@@ -49,10 +50,11 @@ const connectTimeout = 10 * time.Second
const toolCallTimeout = 60 * time.Second
// ConnectAll connects to all configured MCP servers, discovers
// their tools, and returns them as fantasy.AgentTool values. It
// skips servers that fail to connect and logs warnings. The
// returned cleanup function must be called to close all
// connections.
// their tools, and returns them as fantasy.AgentTool values.
// Tools are sorted by their prefixed name so callers
// receive a deterministic order. It skips servers that fail to
// connect and logs warnings. The returned cleanup function
// must be called to close all connections.
func ConnectAll(
ctx context.Context,
logger slog.Logger,
@@ -108,7 +110,9 @@ func ConnectAll(
}
mu.Lock()
clients = append(clients, mcpClient)
if mcpClient != nil {
clients = append(clients, mcpClient)
}
tools = append(tools, serverTools...)
mu.Unlock()
return nil
@@ -119,6 +123,31 @@ func ConnectAll(
// discarded.
_ = eg.Wait()
// Sort tools by prefixed name for deterministic ordering
// regardless of goroutine completion order. Ties, possible
// when the __ separator produces ambiguous prefixed names,
// are broken by config ID. Stable prompt construction
// depends on consistent tool ordering.
slices.SortFunc(tools, func(a, b fantasy.AgentTool) int {
// All tools in this slice are mcpToolWrapper values
// created by connectOne above, so these checked
// assertions should always succeed. The config ID
// tiebreaker resolves the __ separator ambiguity
// documented at the top of this file.
aTool, ok := a.(MCPToolIdentifier)
if !ok {
panic(fmt.Sprintf("unexpected tool type %T", a))
}
bTool, ok := b.(MCPToolIdentifier)
if !ok {
panic(fmt.Sprintf("unexpected tool type %T", b))
}
return cmp.Or(
cmp.Compare(a.Info().Name, b.Info().Name),
cmp.Compare(aTool.MCPServerConfigID().String(), bTool.MCPServerConfigID().String()),
)
})
return tools, cleanup
}
@@ -63,6 +63,17 @@ func greetTool() mcpserver.ServerTool {
}
}
// makeTool returns a ServerTool with the given name and a
// no-op handler that always returns "ok".
func makeTool(name string) mcpserver.ServerTool {
return mcpserver.ServerTool{
Tool: mcp.NewTool(name),
Handler: func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return mcp.NewToolResultText("ok"), nil
},
}
}
// makeConfig builds a database.MCPServerConfig suitable for tests.
func makeConfig(slug, url string) database.MCPServerConfig {
return database.MCPServerConfig{
@@ -198,6 +209,121 @@ func TestConnectAll_MultipleServers(t *testing.T) {
assert.Contains(t, names, "beta__greet")
}
func TestConnectAll_NoToolsAfterFiltering(t *testing.T) {
t.Parallel()
ctx := context.Background()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
ts := newTestMCPServer(t, echoTool())
cfg := makeConfig("filtered", ts.URL)
cfg.ToolAllowList = []string{"greet"}
tools, cleanup := mcpclient.ConnectAll(
ctx,
logger,
[]database.MCPServerConfig{cfg},
nil,
)
require.Empty(t, tools)
assert.NotPanics(t, cleanup)
}
func TestConnectAll_DeterministicOrder(t *testing.T) {
t.Parallel()
t.Run("AcrossServers", func(t *testing.T) {
t.Parallel()
ctx := context.Background()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
ts1 := newTestMCPServer(t, makeTool("zebra"))
ts2 := newTestMCPServer(t, makeTool("alpha"))
ts3 := newTestMCPServer(t, makeTool("middle"))
tools, cleanup := mcpclient.ConnectAll(
ctx,
logger,
[]database.MCPServerConfig{
makeConfig("srv3", ts3.URL),
makeConfig("srv1", ts1.URL),
makeConfig("srv2", ts2.URL),
},
nil,
)
t.Cleanup(cleanup)
require.Len(t, tools, 3)
// Sorted by full prefixed name (slug__tool), so slug
// order determines the sequence, not the tool name.
assert.Equal(t,
[]string{"srv1__zebra", "srv2__alpha", "srv3__middle"},
toolNames(tools),
)
})
t.Run("WithMultiToolServer", func(t *testing.T) {
t.Parallel()
ctx := context.Background()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
multi := newTestMCPServer(t, makeTool("zeta"), makeTool("beta"))
other := newTestMCPServer(t, makeTool("gamma"))
tools, cleanup := mcpclient.ConnectAll(
ctx,
logger,
[]database.MCPServerConfig{
makeConfig("zzz", multi.URL),
makeConfig("aaa", other.URL),
},
nil,
)
t.Cleanup(cleanup)
require.Len(t, tools, 3)
assert.Equal(t,
[]string{"aaa__gamma", "zzz__beta", "zzz__zeta"},
toolNames(tools),
)
})
t.Run("TiebreakByConfigID", func(t *testing.T) {
t.Parallel()
ctx := context.Background()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
ts1 := newTestMCPServer(t, makeTool("b__z"))
ts2 := newTestMCPServer(t, makeTool("z"))
// Use fixed UUIDs so the tiebreaker order is
// predictable. Both servers produce the same prefixed
// name, a__b__z, due to the __ separator ambiguity.
cfg1 := makeConfig("a", ts1.URL)
cfg1.ID = uuid.MustParse("00000000-0000-0000-0000-000000000002")
cfg2 := makeConfig("a__b", ts2.URL)
cfg2.ID = uuid.MustParse("00000000-0000-0000-0000-000000000001")
tools, cleanup := mcpclient.ConnectAll(
ctx,
logger,
[]database.MCPServerConfig{cfg1, cfg2},
nil,
)
t.Cleanup(cleanup)
require.Len(t, tools, 2)
assert.Equal(t, []string{"a__b__z", "a__b__z"}, toolNames(tools))
id0 := tools[0].(mcpclient.MCPToolIdentifier).MCPServerConfigID()
id1 := tools[1].(mcpclient.MCPToolIdentifier).MCPServerConfigID()
assert.Equal(t, cfg2.ID, id0, "lower config ID should sort first")
assert.Equal(t, cfg1.ID, id1, "higher config ID should sort second")
})
}
func TestConnectAll_AuthHeaders(t *testing.T) {
t.Parallel()
ctx := context.Background()
@@ -212,6 +212,7 @@ type AuditLogsRequest struct {
type AuditLogResponse struct {
AuditLogs []AuditLog `json:"audit_logs"`
Count int64 `json:"count"`
CountCap int64 `json:"count_cap"`
}
type CreateTestAuditLogRequest struct {
@@ -96,6 +96,7 @@ type ConnectionLogsRequest struct {
type ConnectionLogResponse struct {
ConnectionLogs []ConnectionLog `json:"connection_logs"`
Count int64 `json:"count"`
CountCap int64 `json:"count_cap"`
}
func (c *Client) ConnectionLogs(ctx context.Context, req ConnectionLogsRequest) (ConnectionLogResponse, error) {
@@ -90,7 +90,8 @@ curl -X GET http://coder-server:8080/api/v2/audit?limit=0 \
"user_agent": "string"
}
],
"count": 0
"count": 0,
"count_cap": 0
}
```
@@ -291,7 +291,8 @@ curl -X GET http://coder-server:8080/api/v2/connectionlog?limit=0 \
"workspace_owner_username": "string"
}
],
"count": 0
"count": 0,
"count_cap": 0
}
```
@@ -1740,7 +1740,8 @@
"user_agent": "string"
}
],
"count": 0
"count": 0,
"count_cap": 0
}
```
@@ -1750,6 +1751,7 @@
|--------------|-------------------------------------------------|----------|--------------|-------------|
| `audit_logs` | array of [codersdk.AuditLog](#codersdkauditlog) | false | | |
| `count` | integer | false | | |
| `count_cap` | integer | false | | |
## codersdk.AuthMethod
@@ -2173,7 +2175,8 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in
"workspace_owner_username": "string"
}
],
"count": 0
"count": 0,
"count_cap": 0
}
```
@@ -2183,6 +2186,7 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in
|-------------------|-----------------------------------------------------------|----------|--------------|-------------|
| `connection_logs` | array of [codersdk.ConnectionLog](#codersdkconnectionlog) | false | | |
| `count` | integer | false | | |
| `count_cap` | integer | false | | |
## codersdk.ConnectionLogSSHInfo
@@ -16,6 +16,9 @@ import (
"github.com/coder/coder/v2/codersdk"
)
// NOTE: See the auditLogCountCap note.
const connectionLogCountCap = 2000
// @Summary Get connection logs
// @ID get-connection-logs
// @Security CoderSessionToken
@@ -49,6 +52,7 @@ func (api *API) connectionLogs(rw http.ResponseWriter, r *http.Request) {
// #nosec G115 - Safe conversion as pagination limit is expected to be within int32 range
filter.LimitOpt = int32(page.Limit)
countFilter.CountCap = connectionLogCountCap
count, err := api.Database.CountConnectionLogs(ctx, countFilter)
if dbauthz.IsNotAuthorizedError(err) {
httpapi.Forbidden(rw)
@@ -63,6 +67,7 @@ func (api *API) connectionLogs(rw http.ResponseWriter, r *http.Request) {
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ConnectionLogResponse{
ConnectionLogs: []codersdk.ConnectionLog{},
Count: 0,
CountCap: connectionLogCountCap,
})
return
}
@@ -80,6 +85,7 @@ func (api *API) connectionLogs(rw http.ResponseWriter, r *http.Request) {
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ConnectionLogResponse{
ConnectionLogs: convertConnectionLogs(dblogs),
Count: count,
CountCap: connectionLogCountCap,
})
}
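The cap convention the handler and frontend appear to rely on can be sketched as follows. This is an assumption inferred from the observed values (a test mock returns count 2001 against a count_cap of 2000), not the actual query implementation, which is outside this diff:

```go
package main

import "fmt"

// countUpToCapPlusOne models the apparent convention: the count query
// stops at cap+1 rows, so clients detect overflow via count > cap and
// can render "2,000+" instead of paying for an exact count. Purely
// illustrative; the real capping is applied by the database query via
// countFilter.CountCap.
func countUpToCapPlusOne(matching, cap int64) int64 {
	if matching > cap {
		return cap + 1
	}
	return matching
}

func main() {
	const countCap = int64(2000)
	count := countUpToCapPlusOne(50_000, countCap)
	fmt.Println(count, count > countCap) // 2001 true
}
```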
@@ -913,6 +913,7 @@ export interface AuditLog {
export interface AuditLogResponse {
readonly audit_logs: readonly AuditLog[];
readonly count: number;
readonly count_cap: number;
}
// From codersdk/audit.go
@@ -2269,6 +2270,7 @@ export interface ConnectionLog {
export interface ConnectionLogResponse {
readonly connection_logs: readonly ConnectionLog[];
readonly count: number;
readonly count_cap: number;
}
// From codersdk/connectionlog.go
@@ -35,10 +35,10 @@ export const AvatarData: FC<AvatarDataProps> = ({
}
return (
<div className="flex items-center w-full min-w-0 gap-3">
<div className="flex items-center w-full gap-3">
{avatar}
<div className="flex flex-col w-full min-w-0">
<div className="flex flex-col w-full">
<span className="text-sm font-semibold text-content-primary">
{title}
</span>
@@ -7,6 +7,7 @@ type PaginationHeaderProps = {
limit: number;
totalRecords: number | undefined;
currentOffsetStart: number | undefined;
countIsCapped?: boolean;
// Temporary escape hatch until Workspaces can be switched over to using
// PaginationContainer
@@ -18,6 +19,7 @@ export const PaginationAmount: FC<PaginationHeaderProps> = ({
limit,
totalRecords,
currentOffsetStart,
countIsCapped,
className,
}) => {
const theme = useTheme();
@@ -52,10 +54,16 @@ export const PaginationAmount: FC<PaginationHeaderProps> = ({
<strong>
{(
currentOffsetStart +
Math.min(limit - 1, totalRecords - currentOffsetStart)
(countIsCapped
? limit - 1
: Math.min(limit - 1, totalRecords - currentOffsetStart))
).toLocaleString()}
</strong>{" "}
of <strong>{totalRecords.toLocaleString()}</strong>{" "}
of{" "}
<strong>
{totalRecords.toLocaleString()}
{countIsCapped && "+"}
</strong>{" "}
{paginationUnitLabel}
</div>
)}
@@ -18,6 +18,7 @@ export const mockPaginationResultBase: ResultBase = {
limit: 25,
hasNextPage: false,
hasPreviousPage: false,
countIsCapped: false,
goToPreviousPage: () => {},
goToNextPage: () => {},
goToFirstPage: () => {},
@@ -33,6 +34,7 @@ export const mockInitialRenderResult: PaginationResult = {
hasPreviousPage: false,
totalRecords: undefined,
totalPages: undefined,
countIsCapped: false,
};
export const mockSuccessResult: PaginationResult = {
@@ -94,7 +94,7 @@ export const FirstPageWithTonsOfData: Story = {
currentPage: 2,
currentOffsetStart: 1000,
totalRecords: 123_456,
totalPages: 1235,
totalPages: 4939,
hasPreviousPage: false,
hasNextPage: true,
isPlaceholderData: false,
@@ -135,3 +135,54 @@ export const SecondPageWithData: Story = {
children: <div>New data for page 2</div>,
},
};
export const CappedCountFirstPage: Story = {
args: {
query: {
...mockPaginationResultBase,
isSuccess: true,
currentPage: 1,
currentOffsetStart: 1,
totalRecords: 2000,
totalPages: 80,
hasPreviousPage: false,
hasNextPage: true,
isPlaceholderData: false,
countIsCapped: true,
},
},
};
export const CappedCountMiddlePage: Story = {
args: {
query: {
...mockPaginationResultBase,
isSuccess: true,
currentPage: 3,
currentOffsetStart: 51,
totalRecords: 2000,
totalPages: 80,
hasPreviousPage: true,
hasNextPage: true,
isPlaceholderData: false,
countIsCapped: true,
},
},
};
export const CappedCountBeyondKnownPages: Story = {
args: {
query: {
...mockPaginationResultBase,
isSuccess: true,
currentPage: 85,
currentOffsetStart: 2101,
totalRecords: 2000,
totalPages: 85,
hasPreviousPage: true,
hasNextPage: true,
isPlaceholderData: false,
countIsCapped: true,
},
},
};
@@ -27,12 +27,14 @@ export const PaginationContainer: FC<PaginationProps> = ({
totalRecords={query.totalRecords}
currentOffsetStart={query.currentOffsetStart}
paginationUnitLabel={paginationUnitLabel}
countIsCapped={query.countIsCapped}
className="justify-end"
/>
{query.isSuccess && (
<PaginationWidgetBase
totalRecords={query.totalRecords}
totalPages={query.totalPages}
currentPage={query.currentPage}
pageSize={query.limit}
onPageChange={query.onPageChange}
@@ -12,6 +12,10 @@ export type PaginationWidgetBaseProps = {
hasPreviousPage?: boolean;
hasNextPage?: boolean;
/** Override the computed totalPages.
* Used when, e.g., the row count is capped and the user navigates beyond
* the known range, so totalPages stays at least as high as currentPage. */
totalPages?: number;
};
export const PaginationWidgetBase: FC<PaginationWidgetBaseProps> = ({
@@ -21,8 +25,9 @@ export const PaginationWidgetBase: FC<PaginationWidgetBaseProps> = ({
onPageChange,
hasPreviousPage,
hasNextPage,
totalPages: totalPagesProp,
}) => {
const totalPages = Math.ceil(totalRecords / pageSize);
const totalPages = totalPagesProp ?? Math.ceil(totalRecords / pageSize);
if (totalPages < 2) {
return null;
@@ -258,6 +258,78 @@ describe(usePaginatedQuery.name, () => {
});
});
describe("Capped count behavior", () => {
const mockQueryKey = vi.fn(() => ["mock"]);
// Returns count 2001 (capped) with items on pages up to page 84
// (84 * 25 = 2100 items total).
const mockCappedQueryFn = vi.fn(({ pageNumber, limit }) => {
const totalItems = 2100;
const offset = (pageNumber - 1) * limit;
// Returns 0 items when the requested page is past the end, simulating
// an empty server response.
const itemsOnPage = Math.max(0, Math.min(limit, totalItems - offset));
return Promise.resolve({
data: new Array(itemsOnPage).fill(pageNumber),
count: 2001,
count_cap: 2000,
});
});
it("Caps totalRecords at 2000 when count exceeds cap", async () => {
const { result } = await render({
queryKey: mockQueryKey,
queryFn: mockCappedQueryFn,
});
await waitFor(() => expect(result.current.isSuccess).toBe(true));
expect(result.current.totalRecords).toBe(2000);
});
it("hasNextPage is true when count is capped", async () => {
const { result } = await render(
{ queryKey: mockQueryKey, queryFn: mockCappedQueryFn },
"/?page=80",
);
await waitFor(() => expect(result.current.isSuccess).toBe(true));
expect(result.current.hasNextPage).toBe(true);
});
it("hasPreviousPage is true when count is capped and page is beyond cap", async () => {
const { result } = await render(
{ queryKey: mockQueryKey, queryFn: mockCappedQueryFn },
"/?page=83",
);
await waitFor(() => expect(result.current.isSuccess).toBe(true));
expect(result.current.hasPreviousPage).toBe(true);
});
it("Does not redirect to last page when count is capped and page is valid", async () => {
const { result } = await render(
{ queryKey: mockQueryKey, queryFn: mockCappedQueryFn },
"/?page=83",
);
await waitFor(() => expect(result.current.isSuccess).toBe(true));
// Should stay on page 83 — not redirect to page 80.
expect(result.current.currentPage).toBe(83);
});
it("Redirects to last known page when navigating beyond actual data", async () => {
const { result } = await render(
{ queryKey: mockQueryKey, queryFn: mockCappedQueryFn },
"/?page=999",
);
// Page 999 has no items. Should redirect to page 81
// (ceil(2001 / 25) = 81), the last page guaranteed to
// have data.
await waitFor(() => expect(result.current.currentPage).toBe(81));
});
});
describe("Passing in searchParams property", () => {
const mockQueryKey = vi.fn(() => ["mock"]);
const mockQueryFn = vi.fn(({ pageNumber, limit }) =>
@@ -144,16 +144,44 @@ export function usePaginatedQuery<
placeholderData: keepPreviousData,
});
const totalRecords = query.data?.count;
const totalPages =
totalRecords !== undefined ? Math.ceil(totalRecords / limit) : undefined;
const count = query.data?.count;
const countCap = query.data?.count_cap;
const countIsCapped =
countCap !== undefined &&
countCap > 0 &&
count !== undefined &&
count > countCap;
const totalRecords = countIsCapped ? countCap : count;
let totalPages =
totalRecords !== undefined
? Math.max(
Math.ceil(totalRecords / limit),
// The true count is unknown; let the user navigate
// forward until they hit an empty page (checked below).
countIsCapped ? currentPage : 0,
)
: undefined;
// When the true count is unknown, the user can navigate past
// all actual data. If that happens, redirect (via
// updatePageIfInvalid) to the last page that is guaranteed to
// be non-empty.
const pageIsEmpty =
query.data != null &&
!Object.values(query.data).some((v) => Array.isArray(v) && v.length > 0);
if (pageIsEmpty) {
totalPages = count !== undefined ? Math.ceil(count / limit) : 1;
}
const hasNextPage =
totalRecords !== undefined && limit + currentPageOffset < totalRecords;
totalRecords !== undefined &&
((countIsCapped && !pageIsEmpty) ||
limit + currentPageOffset < totalRecords);
const hasPreviousPage =
totalRecords !== undefined &&
currentPage > 1 &&
currentPageOffset - limit < totalRecords;
((countIsCapped && !pageIsEmpty) ||
currentPageOffset - limit < totalRecords);
const queryClient = useQueryClient();
const prefetchPage = useEffectEvent((newPage: number) => {
@@ -224,10 +252,14 @@ export function usePaginatedQuery<
});
useEffect(() => {
if (!query.isFetching && totalPages !== undefined) {
if (
!query.isFetching &&
totalPages !== undefined &&
currentPage > totalPages
) {
void updatePageIfInvalid(totalPages);
}
}, [updatePageIfInvalid, query.isFetching, totalPages]);
}, [updatePageIfInvalid, query.isFetching, totalPages, currentPage]);
const onPageChange = (newPage: number) => {
// Page 1 is the only page that can be safely navigated to without knowing
@@ -236,7 +268,12 @@ export function usePaginatedQuery<
return;
}
const cleanedInput = clamp(Math.trunc(newPage), 1, totalPages ?? 1);
// If the true count is unknown, we allow navigating past the
// known page range.
const upperBound = countIsCapped
? Number.MAX_SAFE_INTEGER
: (totalPages ?? 1);
const cleanedInput = clamp(Math.trunc(newPage), 1, upperBound);
if (Number.isNaN(cleanedInput)) {
return;
}
@@ -274,6 +311,7 @@ export function usePaginatedQuery<
totalRecords: totalRecords as number,
totalPages: totalPages as number,
currentOffsetStart: currentPageOffset + 1,
countIsCapped,
}
: {
isSuccess: false,
@@ -282,6 +320,7 @@ export function usePaginatedQuery<
totalRecords: undefined,
totalPages: undefined,
currentOffsetStart: undefined,
countIsCapped: false as const,
}),
};
@@ -323,6 +362,7 @@ export type PaginationResultInfo = {
totalRecords: undefined;
totalPages: undefined;
currentOffsetStart: undefined;
countIsCapped: false;
}
| {
isSuccess: true;
@@ -331,6 +371,7 @@ export type PaginationResultInfo = {
totalRecords: number;
totalPages: number;
currentOffsetStart: number;
countIsCapped: boolean;
}
);
@@ -417,6 +458,7 @@ type QueryPageParamsWithPayload<TPayload = never> = QueryPageParams & {
*/
export type PaginatedData = {
count: number;
count_cap?: number;
};
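The hook's capped-count arithmetic can be condensed into one function. This is an illustrative Go port of the TypeScript logic with hypothetical names, assuming the same inputs the hook reads (count, count_cap, limit, current page):

```go
package main

import "fmt"

// cappedTotals clamps totalRecords to the cap when the reported count
// exceeds it, and keeps totalPages at least as high as the current page
// so forward navigation stays open until an empty page is hit.
func cappedTotals(count, countCap, limit, currentPage int) (totalRecords, totalPages int, capped bool) {
	capped = countCap > 0 && count > countCap
	totalRecords = count
	if capped {
		totalRecords = countCap
	}
	totalPages = (totalRecords + limit - 1) / limit // ceiling division
	if capped && currentPage > totalPages {
		totalPages = currentPage
	}
	return totalRecords, totalPages, capped
}

func main() {
	// Matches the CappedCountBeyondKnownPages story: page 85 with a
	// capped count of 2000 and 25 rows per page.
	tr, tp, capped := cappedTotals(2001, 2000, 25, 85)
	fmt.Println(tr, tp, capped) // 2000 85 true
}
```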
/**
@@ -76,6 +76,8 @@ const defaultAgentMetadata = [
},
];
const fixedLogTimestamp = "2021-05-05T00:00:00.000Z";
const logs = [
"\x1b[91mCloning Git repository...",
"\x1b[2;37;41mStarting Docker Daemon...",
@@ -87,7 +89,7 @@ const logs = [
level: "info",
output: line,
source_id: M.MockWorkspaceAgentLogSource.id,
created_at: new Date().toISOString(),
created_at: fixedLogTimestamp,
}));
const installScriptLogSource: TypesGen.WorkspaceAgentLogSource = {
@@ -102,21 +104,21 @@ const tabbedLogs = [
level: "info",
output: "startup: preparing workspace",
source_id: M.MockWorkspaceAgentLogSource.id,
created_at: new Date().toISOString(),
created_at: fixedLogTimestamp,
},
{
id: 101,
level: "info",
output: "install: pnpm install",
source_id: installScriptLogSource.id,
created_at: new Date().toISOString(),
created_at: fixedLogTimestamp,
},
{
id: 102,
level: "info",
output: "install: setup complete",
source_id: installScriptLogSource.id,
created_at: new Date().toISOString(),
created_at: fixedLogTimestamp,
},
];
@@ -153,7 +155,7 @@ const overflowLogs = overflowLogSources.map((source, index) => ({
level: "info",
output: `${source.display_name}: line`,
source_id: source.id,
created_at: new Date().toISOString(),
created_at: fixedLogTimestamp,
}));
const meta: Meta<typeof AgentRow> = {
@@ -1,22 +1,19 @@
import dayjs, { type Dayjs } from "dayjs";
import { type FC, useState } from "react";
import { useQuery } from "react-query";
import { useSearchParams } from "react-router";
import { chatCostSummary } from "#/api/queries/chats";
import type { DateRangeValue } from "#/components/DateRangePicker/DateRangePicker";
import { ScrollArea } from "#/components/ScrollArea/ScrollArea";
import { useAuthContext } from "#/contexts/auth/AuthProvider";
import { AgentAnalyticsPageView } from "./AgentAnalyticsPageView";
import { AgentPageHeader } from "./components/AgentPageHeader";
const startDateSearchParam = "startDate";
const endDateSearchParam = "endDate";
const getDefaultDateRange = (now?: Dayjs): DateRangeValue => {
const createDateRange = (now?: Dayjs) => {
const end = now ?? dayjs();
const start = end.subtract(30, "day");
return {
startDate: end.subtract(30, "day").toDate(),
endDate: end.toDate(),
startDate: start.toISOString(),
endDate: end.toISOString(),
rangeLabel: `${start.format("MMM D")} ${end.format("MMM D, YYYY")}`,
};
};
@@ -27,44 +24,13 @@ interface AgentAnalyticsPageProps {
const AgentAnalyticsPage: FC<AgentAnalyticsPageProps> = ({ now }) => {
const { user } = useAuthContext();
const [searchParams, setSearchParams] = useSearchParams();
const startDateParam = searchParams.get(startDateSearchParam)?.trim() ?? "";
const endDateParam = searchParams.get(endDateSearchParam)?.trim() ?? "";
const [defaultDateRange] = useState(() => getDefaultDateRange(now));
let dateRange = defaultDateRange;
let hasExplicitDateRange = false;
if (startDateParam && endDateParam) {
const parsedStartDate = new Date(startDateParam);
const parsedEndDate = new Date(endDateParam);
if (
!Number.isNaN(parsedStartDate.getTime()) &&
!Number.isNaN(parsedEndDate.getTime()) &&
parsedStartDate.getTime() <= parsedEndDate.getTime()
) {
dateRange = {
startDate: parsedStartDate,
endDate: parsedEndDate,
};
hasExplicitDateRange = true;
}
}
const onDateRangeChange = (value: DateRangeValue) => {
setSearchParams((prev) => {
const next = new URLSearchParams(prev);
next.set(startDateSearchParam, value.startDate.toISOString());
next.set(endDateSearchParam, value.endDate.toISOString());
return next;
});
};
const [anchor] = useState<Dayjs>(() => dayjs());
const dateRange = createDateRange(now ?? anchor);
const summaryQuery = useQuery({
...chatCostSummary(user?.id ?? "me", {
start_date: dateRange.startDate.toISOString(),
end_date: dateRange.endDate.toISOString(),
start_date: dateRange.startDate,
end_date: dateRange.endDate,
}),
enabled: Boolean(user?.id),
});
@@ -77,9 +43,7 @@ const AgentAnalyticsPage: FC<AgentAnalyticsPageProps> = ({ now }) => {
isLoading={summaryQuery.isLoading}
error={summaryQuery.error}
onRetry={() => void summaryQuery.refetch()}
dateRange={dateRange}
onDateRangeChange={onDateRangeChange}
hasExplicitDateRange={hasExplicitDateRange}
rangeLabel={dateRange.rangeLabel}
/>
</ScrollArea>
);
@@ -1,21 +1,15 @@
import { BarChart3Icon } from "lucide-react";
import type { FC } from "react";
import type { ChatCostSummary } from "#/api/typesGenerated";
import {
DateRangePicker,
type DateRangeValue,
} from "#/components/DateRangePicker/DateRangePicker";
import { ChatCostSummaryView } from "./components/ChatCostSummaryView";
import { SectionHeader } from "./components/SectionHeader";
import { toInclusiveDateRange } from "./utils/dateRange";
interface AgentAnalyticsPageViewProps {
summary: ChatCostSummary | undefined;
isLoading: boolean;
error: unknown;
onRetry: () => void;
dateRange: DateRangeValue;
onDateRangeChange: (value: DateRangeValue) => void;
hasExplicitDateRange: boolean;
rangeLabel: string;
}
export const AgentAnalyticsPageView: FC<AgentAnalyticsPageViewProps> = ({
@@ -23,15 +17,8 @@ export const AgentAnalyticsPageView: FC<AgentAnalyticsPageViewProps> = ({
isLoading,
error,
onRetry,
dateRange,
onDateRangeChange,
hasExplicitDateRange,
rangeLabel,
}) => {
const displayDateRange = toInclusiveDateRange(
dateRange,
hasExplicitDateRange,
);
return (
<div className="flex flex-col p-4 pt-8">
<div className="mx-auto w-full max-w-3xl">
@@ -39,10 +26,10 @@ export const AgentAnalyticsPageView: FC<AgentAnalyticsPageViewProps> = ({
label="Analytics"
description="Review your personal Coder Agents usage and cost breakdowns."
action={
<DateRangePicker
value={displayDateRange}
onChange={onDateRangeChange}
/>
<div className="flex items-center gap-2 text-xs text-content-secondary">
<BarChart3Icon className="h-4 w-4" />
<span>{rangeLabel}</span>
</div>
}
/>
@@ -646,17 +646,24 @@ const AgentChatPage: FC = () => {
const isRegenerateTitleDisabled = isArchived || isRegeneratingThisChat;
const chatLastModelConfigID = chatRecord?.last_model_config_id;
const sendMutation = useMutation(
// Destructure mutation results directly so the React Compiler
// tracks stable primitives/functions instead of the whole result
// object (TanStack Query v5 recreates it every render via object
// spread). Keeping no intermediate variable prevents future code
// from accidentally closing over the unstable object.
const { isPending: isSendPending, mutateAsync: sendMessage } = useMutation(
createChatMessage(queryClient, agentId ?? ""),
);
const editMutation = useMutation(editChatMessage(queryClient, agentId ?? ""));
const interruptMutation = useMutation(
const { isPending: isEditPending, mutateAsync: editMessage } = useMutation(
editChatMessage(queryClient, agentId ?? ""),
);
const { isPending: isInterruptPending, mutateAsync: interrupt } = useMutation(
interruptChat(queryClient, agentId ?? ""),
);
const deleteQueuedMutation = useMutation(
const { mutateAsync: deleteQueuedMessage } = useMutation(
deleteChatQueuedMessage(queryClient, agentId ?? ""),
);
const promoteQueuedMutation = useMutation(
const { mutateAsync: promoteQueuedMessage } = useMutation(
promoteChatQueuedMessage(queryClient, agentId ?? ""),
);
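The React Compiler rationale in the comment above can be illustrated without React: a result object rebuilt by object spread on each call has a fresh identity every time, while the function reference inside it stays stable. A minimal sketch (names are illustrative, not from the codebase):

```typescript
// Simulates what TanStack Query v5 does each render: the hook's result
// object is recreated via spread, but `mutateAsync` keeps its reference.
function makeResult(mutateAsync: () => Promise<void>, isPending: boolean) {
  return { ...{ mutateAsync, isPending } };
}

const mutateAsync = async () => {};
const a = makeResult(mutateAsync, false); // "render 1"
const b = makeResult(mutateAsync, false); // "render 2"

// Closing over the whole result object defeats memoization...
console.log(a === b); // false
// ...while destructured fields compare stable across renders.
console.log(a.mutateAsync === b.mutateAsync); // true
```

This is why the diff destructures `{ isPending, mutateAsync }` at the call site instead of keeping the `sendMutation`-style result object around.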
@@ -754,9 +761,7 @@ const AgentChatPage: FC = () => {
hasUserFixableModelProviders,
});
const isSubmissionPending =
sendMutation.isPending ||
editMutation.isPending ||
interruptMutation.isPending;
isSendPending || isEditPending || isInterruptPending;
const isInputDisabled = !hasModelOptions || isArchived;
const handleUsageLimitError = (error: unknown): void => {
@@ -842,7 +847,7 @@ const AgentChatPage: FC = () => {
setPendingEditMessageId(editedMessageID);
scrollToBottomRef.current?.();
try {
await editMutation.mutateAsync({
await editMessage({
messageId: editedMessageID,
req: request,
});
@@ -873,9 +878,9 @@ const AgentChatPage: FC = () => {
// For queued sends the WebSocket status events handle
// clearing; for non-queued sends we clear explicitly
// below. Clearing eagerly causes a visible cutoff.
let response: Awaited<ReturnType<typeof sendMutation.mutateAsync>>;
let response: Awaited<ReturnType<typeof sendMessage>>;
try {
response = await sendMutation.mutateAsync(request);
response = await sendMessage(request);
} catch (error) {
handleUsageLimitError(error);
throw error;
@@ -908,10 +913,10 @@ const AgentChatPage: FC = () => {
};
const handleInterrupt = () => {
if (!agentId || interruptMutation.isPending) {
if (!agentId || isInterruptPending) {
return;
}
void interruptMutation.mutateAsync();
void interrupt();
};
const handleDeleteQueuedMessage = async (id: number) => {
@@ -920,7 +925,7 @@ const AgentChatPage: FC = () => {
previousQueuedMessages.filter((message) => message.id !== id),
);
try {
await deleteQueuedMutation.mutateAsync(id);
await deleteQueuedMessage(id);
} catch (error) {
store.setQueuedMessages(previousQueuedMessages);
throw error;
@@ -941,7 +946,7 @@ const AgentChatPage: FC = () => {
store.clearStreamError();
store.setChatStatus("pending");
try {
const promotedMessage = await promoteQueuedMutation.mutateAsync(id);
const promotedMessage = await promoteQueuedMessage(id);
// Insert the promoted message into the store and cache
// immediately so it appears in the timeline without
// waiting for the WebSocket to deliver it.
@@ -990,7 +995,8 @@ const AgentChatPage: FC = () => {
? `ssh ${workspaceAgent.name}.${workspace.name}.${workspace.owner_name}.${sshConfigQuery.data.hostname_suffix}`
: undefined;
const generateKeyMutation = useMutation({
// See mutation destructuring comment above (React Compiler).
const { mutate: generateKey } = useMutation({
mutationFn: () => API.getApiKey(),
});
@@ -1005,7 +1011,7 @@ const AgentChatPage: FC = () => {
const repoRoots = Array.from(gitWatcher.repositories.keys()).sort();
const folder = repoRoots[0] ?? workspaceAgent.expanded_directory;
generateKeyMutation.mutate(undefined, {
generateKey(undefined, {
onSuccess: ({ key }) => {
location.href = getVSCodeHref(editor, {
owner: workspace.owner_name,
@@ -1141,7 +1147,7 @@ const AgentChatPage: FC = () => {
compressionThreshold={compressionThreshold}
isInputDisabled={isInputDisabled}
isSubmissionPending={isSubmissionPending}
isInterruptPending={interruptMutation.isPending}
isInterruptPending={isInterruptPending}
isSidebarCollapsed={isSidebarCollapsed}
onToggleSidebarCollapsed={onToggleSidebarCollapsed}
showSidebarPanel={showSidebarPanel}
@@ -372,7 +372,7 @@ export const DefaultAutostopNotVisibleToNonAdmin: Story = {
const canvas = within(canvasElement);
// Personal Instructions should be visible.
await canvas.findByText("Instructions");
await canvas.findByText("Personal Instructions");
// Admin-only sections should not be present.
expect(canvas.queryByText("Workspace Autostop Fallback")).toBeNull();
@@ -412,7 +412,7 @@ export const InvisibleUnicodeWarningUserPrompt: Story = {
play: async ({ canvasElement }) => {
const canvas = within(canvasElement);
await canvas.findByText("Instructions");
await canvas.findByText("Personal Instructions");
const alert = await canvas.findByText(/invisible Unicode/);
expect(alert).toBeInTheDocument();
@@ -454,7 +454,7 @@ export const NoWarningForCleanPrompt: Story = {
play: async ({ canvasElement }) => {
const canvas = within(canvasElement);
await canvas.findByText("Instructions");
await canvas.findByText("Personal Instructions");
await canvas.findByText("System Instructions");
expect(canvas.queryByText(/invisible Unicode/)).toBeNull();
@@ -9,13 +9,6 @@ import { Switch } from "#/components/Switch/Switch";
import { cn } from "#/utils/cn";
import { countInvisibleCharacters } from "#/utils/invisibleUnicode";
import { AdminBadge } from "./components/AdminBadge";
import {
CollapsibleSection,
CollapsibleSectionContent,
CollapsibleSectionDescription,
CollapsibleSectionHeader,
CollapsibleSectionTitle,
} from "./components/CollapsibleSection";
import { DurationField } from "./components/DurationField/DurationField";
import { SectionHeader } from "./components/SectionHeader";
import { TextPreviewDialog } from "./components/TextPreviewDialog";
@@ -253,43 +246,142 @@ export const AgentSettingsBehaviorPageView: FC<
};
return (
<div className="space-y-6">
<>
<SectionHeader
label="Behavior"
description="Custom instructions that shape how the agent responds in your conversations."
/>
{/* ── Personal prompt (always visible) ── */}
<form
className="space-y-2"
onSubmit={(event) => void handleSaveUserPrompt(event)}
>
<h3 className="m-0 text-[13px] font-semibold text-content-primary">
Personal Instructions
</h3>
<p className="!mt-0.5 m-0 text-xs text-content-secondary">
Applied to all your conversations. Only visible to you.
</p>
<TextareaAutosize
className={cn(
textareaBaseClassName,
isUserPromptOverflowing && textareaOverflowClassName,
)}
placeholder="Additional behavior, style, and tone preferences"
value={userPromptDraft}
onChange={(event) => setLocalUserEdit(event.target.value)}
onHeightChange={(height) =>
setIsUserPromptOverflowing(height >= textareaMaxHeight)
}
disabled={isPromptSaving}
minRows={1}
/>
{userInvisibleCharCount > 0 && (
<Alert severity="warning">
<AlertDescription>
This text contains {userInvisibleCharCount} invisible Unicode{" "}
{userInvisibleCharCount !== 1 ? "characters" : "character"} that
could hide content. These will be stripped on save.
</AlertDescription>
</Alert>
)}
<div className="flex justify-end gap-2">
<Button
size="sm"
variant="outline"
type="button"
onClick={() => setLocalUserEdit("")}
disabled={isPromptSaving || !userPromptDraft}
>
Clear
</Button>
<Button
size="sm"
type="submit"
disabled={isPromptSaving || !isUserPromptDirty}
>
Save
</Button>
</div>
{isSaveUserPromptError && (
<p className="m-0 text-xs text-content-destructive">
Failed to save personal instructions.
</p>
)}
</form>
{/* Personal prompt (always visible) */}
<CollapsibleSection>
<CollapsibleSectionHeader>
<CollapsibleSectionTitle>Instructions</CollapsibleSectionTitle>
<CollapsibleSectionDescription>Applied to all your conversations. Only visible to you.</CollapsibleSectionDescription>
</CollapsibleSectionHeader>
<CollapsibleSectionContent>
<hr className="my-5 border-0 border-t border-solid border-border" />
<UserCompactionThresholdSettings
modelConfigs={modelConfigsData ?? []}
modelConfigsError={modelConfigsError}
isLoadingModelConfigs={isLoadingModelConfigs}
thresholds={thresholds}
isThresholdsLoading={isThresholdsLoading}
thresholdsError={thresholdsError}
onSaveThreshold={onSaveThreshold}
onResetThreshold={onResetThreshold}
/>
{/* ── Admin system prompt (admin only) ── */}
{canSetSystemPrompt && (
<>
<hr className="my-5 border-0 border-t border-solid border-border" />
<form
className="space-y-2"
onSubmit={(event) => void handleSaveUserPrompt(event)}
onSubmit={(event) => void handleSaveSystemPrompt(event)}
>
<div className="flex items-center gap-2">
<h3 className="m-0 text-[13px] font-semibold text-content-primary">
System Instructions
</h3>
<AdminBadge />
</div>
<div className="flex items-center justify-between gap-4">
<div className="flex min-w-0 items-center gap-2 text-xs font-medium text-content-primary">
<span>Include Coder Agents default system prompt</span>
<Button
size="xs"
variant="subtle"
type="button"
onClick={() => setShowDefaultPromptPreview(true)}
disabled={!hasLoadedSystemPrompt}
className="min-w-0 px-0 text-content-link hover:text-content-link"
>
Preview
</Button>
</div>
<Switch
checked={includeDefaultDraft}
onCheckedChange={setLocalIncludeDefault}
aria-label="Include Coder Agents default system prompt"
disabled={isSystemPromptDisabled}
/>
</div>
<p className="!mt-0.5 m-0 text-xs text-content-secondary">
{includeDefaultDraft
? "The built-in Coder Agents prompt is prepended. Additional instructions below are appended."
: "Only the additional instructions below are used. When empty, no deployment-wide system prompt is sent."}
</p>
<TextareaAutosize
className={cn(
textareaBaseClassName,
isUserPromptOverflowing && textareaOverflowClassName,
isSystemPromptOverflowing && textareaOverflowClassName,
)}
placeholder="Additional behavior, style, and tone preferences"
value={userPromptDraft}
onChange={(event) => setLocalUserEdit(event.target.value)}
placeholder="Additional instructions for all users"
value={systemPromptDraft}
onChange={(event) => setLocalEdit(event.target.value)}
onHeightChange={(height) =>
setIsUserPromptOverflowing(height >= textareaMaxHeight)
setIsSystemPromptOverflowing(height >= textareaMaxHeight)
}
disabled={isPromptSaving}
disabled={isSystemPromptDisabled}
minRows={1}
/>
{userInvisibleCharCount > 0 && (
{systemInvisibleCharCount > 0 && (
<Alert severity="warning">
<AlertDescription>
This text contains {userInvisibleCharCount} invisible Unicode{" "}
{userInvisibleCharCount !== 1 ? "characters" : "character"} that
could hide content. These will be stripped on save.
This text contains {systemInvisibleCharCount} invisible
Unicode{" "}
{systemInvisibleCharCount !== 1 ? "characters" : "character"}{" "}
that could hide content. These will be stripped on save.
</AlertDescription>
</Alert>
)}
@@ -298,304 +390,164 @@ export const AgentSettingsBehaviorPageView: FC<
size="sm"
variant="outline"
type="button"
onClick={() => setLocalUserEdit("")}
disabled={isPromptSaving || !userPromptDraft}
onClick={() => setLocalEdit("")}
disabled={isSystemPromptDisabled || !systemPromptDraft}
>
Clear
</Button>
<Button
size="sm"
type="submit"
disabled={isPromptSaving || !isUserPromptDirty}
disabled={isSystemPromptDisabled || !isSystemPromptDirty}
>
Save
</Button>
</Button>{" "}
</div>
{isSaveUserPromptError && (
{isSaveSystemPromptError && (
<p className="m-0 text-xs text-content-destructive">
Failed to save personal instructions.
Failed to save system prompt.
</p>
)}
</form>
</CollapsibleSectionContent>
</CollapsibleSection>
{/* Context compaction */}
<CollapsibleSection>
<CollapsibleSectionHeader>
<CollapsibleSectionTitle>Context compaction</CollapsibleSectionTitle>
<CollapsibleSectionDescription>Control when conversation context is automatically summarized for each model. Setting 100% means the conversation will never auto-compact.</CollapsibleSectionDescription>
</CollapsibleSectionHeader>
<CollapsibleSectionContent>
<UserCompactionThresholdSettings
hideHeader
modelConfigs={modelConfigsData ?? []}
modelConfigsError={modelConfigsError}
isLoadingModelConfigs={isLoadingModelConfigs}
thresholds={thresholds}
isThresholdsLoading={isThresholdsLoading}
thresholdsError={thresholdsError}
onSaveThreshold={onSaveThreshold}
onResetThreshold={onResetThreshold}
/>
</CollapsibleSectionContent>
</CollapsibleSection>
{/* Admin system prompt (admin only) */}
{canSetSystemPrompt && (
<>
<CollapsibleSection>
<CollapsibleSectionHeader>
<div className="flex items-center gap-2">
<CollapsibleSectionTitle>System Instructions</CollapsibleSectionTitle>
<AdminBadge />
</div>
<CollapsibleSectionDescription>
{includeDefaultDraft
? "The built-in Coder Agents prompt is prepended. Additional instructions below are appended."
: "Only the additional instructions below are used. When empty, no deployment-wide system prompt is sent."}
</CollapsibleSectionDescription>
</CollapsibleSectionHeader>
<CollapsibleSectionContent>
<form
className="space-y-2"
onSubmit={(event) => void handleSaveSystemPrompt(event)}
>
<div className="flex items-center justify-between gap-4">
<div className="flex min-w-0 items-center gap-2 text-xs font-medium text-content-primary">
<span>Include Coder Agents default system prompt</span>
<Button
size="xs"
variant="subtle"
type="button"
onClick={() => setShowDefaultPromptPreview(true)}
disabled={!hasLoadedSystemPrompt}
className="min-w-0 px-0 text-content-link hover:text-content-link"
>
Preview
</Button>
</div>
<Switch
checked={includeDefaultDraft}
onCheckedChange={setLocalIncludeDefault}
aria-label="Include Coder Agents default system prompt"
disabled={isSystemPromptDisabled}
/>
</div>
<TextareaAutosize
className={cn(
textareaBaseClassName,
isSystemPromptOverflowing && textareaOverflowClassName,
)}
placeholder="Additional instructions for all users"
value={systemPromptDraft}
onChange={(event) => setLocalEdit(event.target.value)}
onHeightChange={(height) =>
setIsSystemPromptOverflowing(height >= textareaMaxHeight)
}
disabled={isSystemPromptDisabled}
minRows={1}
/>
{systemInvisibleCharCount > 0 && (
<Alert severity="warning">
<AlertDescription>
This text contains {systemInvisibleCharCount} invisible
Unicode{" "}
{systemInvisibleCharCount !== 1
? "characters"
: "character"}{" "}
that could hide content. These will be stripped on save.
</AlertDescription>
</Alert>
)}
<div className="flex justify-end gap-2">
<Button
<hr className="my-5 border-0 border-t border-solid border-border" />
<div className="space-y-2">
<div className="flex items-center gap-2">
<h3 className="m-0 text-[13px] font-semibold text-content-primary">
Virtual Desktop
</h3>
<AdminBadge />
</div>
<div className="flex items-center justify-between gap-4">
<div className="!mt-0.5 m-0 flex-1 text-xs text-content-secondary">
<p className="m-0">
Allow agents to use a virtual, graphical desktop within
workspaces. Requires the{" "}
<Link
href="https://registry.coder.com/modules/coder/portabledesktop"
target="_blank"
size="sm"
variant="outline"
type="button"
onClick={() => setLocalEdit("")}
disabled={isSystemPromptDisabled || !systemPromptDraft}
>
Clear
</Button>
<Button
size="sm"
type="submit"
disabled={isSystemPromptDisabled || !isSystemPromptDirty}
>
Save
</Button>
</div>
{isSaveSystemPromptError && (
<p className="m-0 text-xs text-content-destructive">
Failed to save system prompt.
</p>
)}
</form>
</CollapsibleSectionContent>
</CollapsibleSection>
<CollapsibleSection>
<CollapsibleSectionHeader>
<div className="flex items-center gap-2">
<CollapsibleSectionTitle>Virtual Desktop</CollapsibleSectionTitle>
<AdminBadge />
</div>
</CollapsibleSectionHeader>
<CollapsibleSectionContent>
<div className="flex items-start gap-4">
<div className="min-w-0 flex-1 space-y-1">
<span className="text-sm font-medium text-content-primary">
Enable virtual desktop
</span>
<p className="m-0 text-xs text-content-secondary">
Allow agents to use a virtual, graphical desktop within
workspaces. Requires the{" "}
<Link
href="https://registry.coder.com/modules/coder/portabledesktop"
target="_blank"
size="sm"
>
portabledesktop module
</Link>{" "}
to be installed in the workspace and the Anthropic provider to
be configured.
</p>
<p className="m-0 text-xs text-content-secondary font-semibold">
Warning: This is a work-in-progress feature, and you're likely
to encounter bugs if you enable it.
</p>
</div>
<Switch
checked={desktopEnabled}
onCheckedChange={(checked) =>
onSaveDesktopEnabled({ enable_desktop: checked })
}
aria-label="Enable"
disabled={isSavingDesktopEnabled}
/>
</div>{" "}
{isSaveDesktopEnabledError && (
<p className="m-0 text-xs text-content-destructive">
Failed to save desktop setting.
portabledesktop module
</Link>{" "}
to be installed in the workspace and the Anthropic provider to
be configured.
</p>
<p className="mt-2 mb-0 font-semibold text-content-secondary">
Warning: This is a work-in-progress feature, and you're likely
to encounter bugs if you enable it.
</p>
)}
</CollapsibleSectionContent>
</CollapsibleSection>
<CollapsibleSection>
<CollapsibleSectionHeader>
<div className="flex items-center gap-2">
<CollapsibleSectionTitle>Workspace Autostop Fallback</CollapsibleSectionTitle>
<AdminBadge />
</div>
</CollapsibleSectionHeader>
<CollapsibleSectionContent>
<div className="flex items-start gap-4">
<div className="min-w-0 flex-1 space-y-1">
<span className="text-sm font-medium text-content-primary">
Enable autostop fallback
</span>
<p className="m-0 text-xs text-content-secondary">
Set a default autostop for agent-created workspaces that don't
have one defined in their template. Template-defined autostop
rules always take precedence. Active conversations will extend
the stop time.
</p>
</div>
<Switch
checked={isAutostopEnabled}
onCheckedChange={handleToggleAutostop}
aria-label="Enable default autostop"
disabled={isSavingWorkspaceTTL || isWorkspaceTTLLoading}
/>
<Switch
checked={desktopEnabled}
onCheckedChange={(checked) =>
onSaveDesktopEnabled({ enable_desktop: checked })
}
aria-label="Enable"
disabled={isSavingDesktopEnabled}
/>
</div>
{isSaveDesktopEnabledError && (
<p className="m-0 text-xs text-content-destructive">
Failed to save desktop setting.
</p>
)}
</div>
<hr className="my-5 border-0 border-t border-solid border-border" />
<form
className="space-y-2"
onSubmit={(event) => void handleSaveChatWorkspaceTTL(event)}
>
<div className="flex items-center gap-2">
<h3 className="m-0 text-[13px] font-semibold text-content-primary">
Workspace Autostop Fallback
</h3>
<AdminBadge />
</div>
<div className="flex items-center justify-between gap-4">
<p className="!mt-0.5 m-0 flex-1 text-xs text-content-secondary">
Set a default autostop for agent-created workspaces that don't
have one defined in their template. Template-defined autostop
rules always take precedence. Active conversations will extend
the stop time.
</p>
<Switch
checked={isAutostopEnabled}
onCheckedChange={handleToggleAutostop}
aria-label="Enable default autostop"
disabled={isSavingWorkspaceTTL || isWorkspaceTTLLoading}
/>{" "}
</div>
{isAutostopEnabled && (
<DurationField
valueMs={ttlMs}
onChange={handleTTLChange}
label="Autostop Fallback"
disabled={isSavingWorkspaceTTL || isWorkspaceTTLLoading}
error={isTTLOverMax || isTTLZero}
helperText={
isTTLZero
? "Duration must be greater than zero."
: isTTLOverMax
? "Must not exceed 30 days (720 hours)."
: undefined
}
/>
)}
{isAutostopEnabled && (
<div className="flex justify-end">
<Button
size="sm"
type="submit"
disabled={
isSavingWorkspaceTTL ||
!isTTLDirty ||
isTTLOverMax ||
isTTLZero
}
>
Save
</Button>
</div>
<form
className="space-y-2 pt-4"
onSubmit={(event) => void handleSaveChatWorkspaceTTL(event)}
>
{" "}
{isAutostopEnabled && (
<DurationField
valueMs={ttlMs}
onChange={handleTTLChange}
label="Autostop Fallback"
disabled={isSavingWorkspaceTTL || isWorkspaceTTLLoading}
error={isTTLOverMax || isTTLZero}
helperText={
isTTLZero
? "Duration must be greater than zero."
: isTTLOverMax
? "Must not exceed 30 days (720 hours)."
: undefined
}
/>
)}
{isAutostopEnabled && (
<div className="flex justify-end">
<Button
size="sm"
type="submit"
disabled={
isSavingWorkspaceTTL ||
!isTTLDirty ||
isTTLOverMax ||
isTTLZero
}
>
Save
</Button>
</div>
)}
{isSaveWorkspaceTTLError && (
<p className="m-0 text-xs text-content-destructive">
Failed to save autostop setting.
</p>
)}
{isWorkspaceTTLLoadError && (
<p className="m-0 text-xs text-content-destructive">
Failed to load autostop setting.
</p>
)}
</form>
</CollapsibleSectionContent>
</CollapsibleSection>
)}
{isSaveWorkspaceTTLError && (
<p className="m-0 text-xs text-content-destructive">
Failed to save autostop setting.
</p>
)}
{isWorkspaceTTLLoadError && (
<p className="m-0 text-xs text-content-destructive">
Failed to load autostop setting.
</p>
)}
</form>
</>
)}
{/* Kyleosophy toggle (always visible) */}
<CollapsibleSection>
<CollapsibleSectionHeader>
<CollapsibleSectionTitle>Kyleosophy</CollapsibleSectionTitle>
</CollapsibleSectionHeader>
<CollapsibleSectionContent>
<div className="flex items-start gap-4">
<div className="min-w-0 flex-1 space-y-1">
<span className="text-sm font-medium text-content-primary">
Enable Kyleosophy
<hr className="my-5 border-0 border-t border-solid border-border" />
{/* ── Kyleosophy toggle (always visible) ── */}
<div className="space-y-2">
<h3 className="m-0 text-[13px] font-semibold text-content-primary">
Kyleosophy
</h3>
<div className="flex items-center justify-between gap-4">
<p className="!mt-0.5 m-0 flex-1 text-xs text-content-secondary">
Replace the standard completion chime. IYKYK.
{kylesophyForced && (
<span className="ml-1 font-semibold">
Kyleosophy is mandatory on <code>dev.coder.com</code>.
</span>
<p className="m-0 text-xs text-content-secondary">
{" "}
Replace the standard completion chime. IYKYK.
{kylesophyForced && (
<span className="ml-1 font-semibold">
Kyleosophy is mandatory on <code>dev.coder.com</code>.
</span>
)}
</p>
</div>
<Switch
checked={kylesophyEnabled}
onCheckedChange={(checked) => {
setKylesophyEnabled(checked);
setLocalKylesophy(checked);
}}
aria-label="Enable Kyleosophy"
disabled={kylesophyForced}
/>
</div>
</CollapsibleSectionContent>
</CollapsibleSection>
)}
</p>
<Switch
checked={kylesophyEnabled}
onCheckedChange={(checked) => {
setKylesophyEnabled(checked);
setLocalKylesophy(checked);
}}
aria-label="Enable Kyleosophy"
disabled={kylesophyForced}
/>
</div>
</div>
{showDefaultPromptPreview && (
<TextPreviewDialog
content={defaultSystemPrompt}
@@ -603,6 +555,6 @@ export const AgentSettingsBehaviorPageView: FC<
onClose={() => setShowDefaultPromptPreview(false)}
/>
)}
</div>
</>
);
};
@@ -255,7 +255,7 @@ export const ProviderAccordionCards: Story = {
expect(body.queryByText("OpenAI")).not.toBeInTheDocument();
await userEvent.click(body.getByRole("button", { name: /OpenRouter/i }));
await expect(body.getByLabelText("Base URL")).toBeInTheDocument();
await expect(await body.findByLabelText("Base URL")).toBeInTheDocument();
},
};
@@ -285,7 +285,7 @@ export const ChatModelAdminPanel: FC<ChatModelAdminPanelProps> = ({
const modelConfigsUnavailable = modelConfigsData === null;
return (
<div className={cn("flex min-h-full flex-col space-y-6", className)}>
<div className={cn("flex min-h-full flex-col space-y-3", className)}>
{isLoading && (
<div className="flex items-center gap-1.5 text-xs text-content-secondary">
<Spinner className="h-4 w-4" loading />
@@ -1,5 +1,11 @@
import { useFormik } from "formik";
import { ChevronLeftIcon, InfoIcon, PencilIcon } from "lucide-react";
import {
ChevronDownIcon,
ChevronLeftIcon,
ChevronRightIcon,
InfoIcon,
PencilIcon,
} from "lucide-react";
import { type FC, useState } from "react";
import * as Yup from "yup";
import type * as TypesGen from "#/api/typesGenerated";
@@ -35,13 +41,6 @@ import {
} from "#/components/Tooltip/Tooltip";
import { cn } from "#/utils/cn";
import { getFormHelpers } from "#/utils/formUtils";
import {
CollapsibleSection,
CollapsibleSectionContent,
CollapsibleSectionDescription,
CollapsibleSectionHeader,
CollapsibleSectionTitle,
} from "../CollapsibleSection";
import type { ProviderState } from "./ChatModelAdminPanel";
import {
GeneralModelConfigFields,
@@ -117,6 +116,9 @@ export const ModelForm: FC<ModelFormProps> = ({
}) => {
const isEditing = Boolean(editingModel);
const isDefaultModel = isEditing && editingModel?.is_default === true;
const [showAdvanced, setShowAdvanced] = useState(false);
const [showPricing, setShowPricing] = useState(false);
const [showProviderConfig, setShowProviderConfig] = useState(false);
const [confirmingDelete, setConfirmingDelete] = useState(false);
const canManageModels = Boolean(
@@ -486,18 +488,29 @@ export const ModelForm: FC<ModelFormProps> = ({
</div>
{/* Usage Tracking */}
<CollapsibleSection variant="inline" defaultOpen={false}>
<CollapsibleSectionHeader>
<CollapsibleSectionTitle as="h3">
Cost Tracking
</CollapsibleSectionTitle>
<CollapsibleSectionDescription>
Set per-token pricing so Coder can track costs and enforce
spending limits.
</CollapsibleSectionDescription>
</CollapsibleSectionHeader>
<CollapsibleSectionContent>
<div className="grid grid-cols-2 gap-3 sm:grid-cols-4">
<div className="border-0 border-t border-solid border-border pt-4">
<button
type="button"
onClick={() => setShowPricing((v) => !v)}
className="flex w-full cursor-pointer items-start justify-between border-0 bg-transparent p-0 text-left transition-colors hover:text-content-primary"
>
<div>
<h3 className="m-0 text-sm font-medium text-content-primary">
Cost Tracking{" "}
</h3>
<p className="m-0 text-xs text-content-secondary">
Set per-token pricing so Coder can track costs and enforce
spending limits.
</p>
</div>
{showPricing ? (
<ChevronDownIcon className="mt-0.5 h-4 w-4 shrink-0 text-content-secondary" />
) : (
<ChevronRightIcon className="mt-0.5 h-4 w-4 shrink-0 text-content-secondary" />
)}
</button>
{showPricing && (
<div className="grid grid-cols-2 gap-3 pt-3 sm:grid-cols-4">
<PricingModelConfigFields
provider={selectedProviderState.provider}
form={form}
@@ -505,21 +518,33 @@ export const ModelForm: FC<ModelFormProps> = ({
disabled={isSaving}
/>
</div>
</CollapsibleSectionContent>
</CollapsibleSection>
)}
</div>
{/* Provider Configuration */}
<CollapsibleSection variant="inline" defaultOpen={false}>
<CollapsibleSectionHeader>
<CollapsibleSectionTitle as="h3">
Provider Configuration
</CollapsibleSectionTitle>
<CollapsibleSectionDescription>
Tune provider-specific behavior like reasoning, tool calling,
and web search.
</CollapsibleSectionDescription>
</CollapsibleSectionHeader>
<CollapsibleSectionContent>
<div className="border-0 border-t border-solid border-border pt-4">
<button
type="button"
onClick={() => setShowProviderConfig((v) => !v)}
className="flex w-full cursor-pointer items-start justify-between border-0 bg-transparent p-0 text-left transition-colors hover:text-content-primary"
>
<div>
<h3 className="m-0 text-sm font-medium text-content-primary">
Provider Configuration
</h3>
<p className="m-0 text-xs text-content-secondary">
Tune provider-specific behavior like reasoning, tool calling,
and web search.
</p>
</div>
{showProviderConfig ? (
<ChevronDownIcon className="mt-0.5 h-4 w-4 shrink-0 text-content-secondary" />
) : (
<ChevronRightIcon className="mt-0.5 h-4 w-4 shrink-0 text-content-secondary" />
)}
</button>
{showProviderConfig && (
<div className="pt-3">
<ModelConfigFields
provider={selectedProviderState.provider}
form={form}
@@ -527,21 +552,33 @@ export const ModelForm: FC<ModelFormProps> = ({
disabled={isSaving}
/>
</div>
</CollapsibleSectionContent>
</CollapsibleSection>
)}
</div>
{/* Advanced */}
<CollapsibleSection variant="inline" defaultOpen={false}>
<CollapsibleSectionHeader>
<CollapsibleSectionTitle as="h3">
Advanced
</CollapsibleSectionTitle>
<CollapsibleSectionDescription>
Low-level parameters like temperature and penalties. Rarely need
changing.
</CollapsibleSectionDescription>
</CollapsibleSectionHeader>
<CollapsibleSectionContent>
<div className="grid grid-cols-2 gap-3 sm:grid-cols-3">
<div className="border-0 border-t border-solid border-border pt-4">
<button
type="button"
onClick={() => setShowAdvanced((v) => !v)}
className="flex w-full cursor-pointer items-start justify-between border-0 bg-transparent p-0 text-left transition-colors hover:text-content-primary"
>
<div>
<h3 className="m-0 text-sm font-medium text-content-primary">
Advanced
</h3>
<p className="m-0 text-xs text-content-secondary">
Low-level parameters like temperature and penalties. Rarely
need changing.
</p>
</div>
{showAdvanced ? (
<ChevronDownIcon className="mt-0.5 h-4 w-4 shrink-0 text-content-secondary" />
) : (
<ChevronRightIcon className="mt-0.5 h-4 w-4 shrink-0 text-content-secondary" />
)}
</button>
{showAdvanced && (
<div className="grid grid-cols-2 gap-3 pt-3 sm:grid-cols-3">
<GeneralModelConfigFields
provider={selectedProviderState.provider}
form={form}
@@ -592,11 +629,10 @@ export const ModelForm: FC<ModelFormProps> = ({
)}
</div>
</div>
</CollapsibleSectionContent>
</CollapsibleSection>
)}
</div>
</div>
<div className="mt-auto py-6">
{" "}
<hr className="mb-4 border-0 border-t border-solid border-border" />
<div className="flex items-center justify-between">
{isEditing && editingModel && onDeleteModel ? (
@@ -1,193 +0,0 @@
import type { Meta, StoryObj } from "@storybook/react-vite";
import { expect, userEvent, within } from "storybook/test";
import { AdminBadge } from "./AdminBadge";
import {
CollapsibleSection,
CollapsibleSectionContent,
CollapsibleSectionDescription,
CollapsibleSectionHeader,
CollapsibleSectionTitle,
} from "./CollapsibleSection";
const meta: Meta<typeof CollapsibleSection> = {
title: "pages/AgentsPage/CollapsibleSection",
component: CollapsibleSection,
decorators: [
(Story) => (
<div style={{ maxWidth: 600 }}>
<Story />
</div>
),
],
};
export default meta;
type Story = StoryObj<typeof CollapsibleSection>;
const Placeholder = () => (
<p className="text-sm text-content-secondary">Placeholder content</p>
);
export const DefaultOpen: Story = {
render: () => (
<CollapsibleSection>
<CollapsibleSectionHeader>
<div className="flex items-center gap-2">
<CollapsibleSectionTitle>Default spend limit</CollapsibleSectionTitle>
<AdminBadge />
</div>
<CollapsibleSectionDescription>
The deployment-wide spending cap.
</CollapsibleSectionDescription>
</CollapsibleSectionHeader>
<CollapsibleSectionContent>
<Placeholder />
</CollapsibleSectionContent>
</CollapsibleSection>
),
play: async ({ canvasElement }) => {
const canvas = within(canvasElement);
expect(canvas.getByText("Placeholder content")).toBeInTheDocument();
const header = canvas.getByRole("button", {
name: /Default spend limit/i,
});
expect(header).toHaveAttribute("aria-expanded", "true");
await userEvent.click(header);
expect(header).toHaveAttribute("aria-expanded", "false");
expect(canvas.queryByText("Placeholder content")).not.toBeInTheDocument();
await userEvent.click(header);
expect(header).toHaveAttribute("aria-expanded", "true");
expect(canvas.getByText("Placeholder content")).toBeInTheDocument();
},
};
export const Collapsed: Story = {
render: () => (
<CollapsibleSection defaultOpen={false}>
<CollapsibleSectionHeader>
<div className="flex items-center gap-2">
<CollapsibleSectionTitle>Default spend limit</CollapsibleSectionTitle>
<AdminBadge />
</div>
<CollapsibleSectionDescription>
The deployment-wide spending cap.
</CollapsibleSectionDescription>
</CollapsibleSectionHeader>
<CollapsibleSectionContent>
<Placeholder />
</CollapsibleSectionContent>
</CollapsibleSection>
),
play: async ({ canvasElement }) => {
const canvas = within(canvasElement);
expect(canvas.queryByText("Placeholder content")).not.toBeInTheDocument();
const header = canvas.getByRole("button", {
name: /Default spend limit/i,
});
expect(header).toHaveAttribute("aria-expanded", "false");
await userEvent.click(header);
expect(header).toHaveAttribute("aria-expanded", "true");
expect(canvas.getByText("Placeholder content")).toBeInTheDocument();
await userEvent.click(header);
expect(header).toHaveAttribute("aria-expanded", "false");
expect(canvas.queryByText("Placeholder content")).not.toBeInTheDocument();
},
};
export const NoBadge: Story = {
render: () => (
<CollapsibleSection>
<CollapsibleSectionHeader>
<CollapsibleSectionTitle>Group limits</CollapsibleSectionTitle>
<CollapsibleSectionDescription>
Override defaults for groups.
</CollapsibleSectionDescription>
</CollapsibleSectionHeader>
<CollapsibleSectionContent>
<Placeholder />
</CollapsibleSectionContent>
</CollapsibleSection>
),
};
export const InlineVariant: Story = {
render: () => (
<CollapsibleSection variant="inline" defaultOpen={false}>
<CollapsibleSectionHeader>
<CollapsibleSectionTitle as="h3">Cost Tracking</CollapsibleSectionTitle>
<CollapsibleSectionDescription>
Set per-token pricing so Coder can track costs and enforce spending
limits.
</CollapsibleSectionDescription>
</CollapsibleSectionHeader>
<CollapsibleSectionContent>
<Placeholder />
</CollapsibleSectionContent>
</CollapsibleSection>
),
play: async ({ canvasElement }) => {
const canvas = within(canvasElement);
expect(canvas.queryByText("Placeholder content")).not.toBeInTheDocument();
const header = canvas.getByRole("button", {
name: /Cost Tracking/i,
});
expect(header).toHaveAttribute("aria-expanded", "false");
await userEvent.click(header);
expect(header).toHaveAttribute("aria-expanded", "true");
expect(canvas.getByText("Placeholder content")).toBeInTheDocument();
await userEvent.click(header);
expect(header).toHaveAttribute("aria-expanded", "false");
expect(canvas.queryByText("Placeholder content")).not.toBeInTheDocument();
},
};
export const KeyboardToggle: Story = {
render: () => (
<CollapsibleSection>
<CollapsibleSectionHeader>
<CollapsibleSectionTitle>Keyboard section</CollapsibleSectionTitle>
</CollapsibleSectionHeader>
<CollapsibleSectionContent>
<Placeholder />
</CollapsibleSectionContent>
</CollapsibleSection>
),
play: async ({ canvasElement }) => {
const canvas = within(canvasElement);
const header = canvas.getByRole("button", {
name: /Keyboard section/i,
});
expect(header).toHaveAttribute("aria-expanded", "true");
expect(canvas.getByText("Placeholder content")).toBeInTheDocument();
header.focus();
await userEvent.keyboard("{Enter}");
expect(header).toHaveAttribute("aria-expanded", "false");
expect(canvas.queryByText("Placeholder content")).not.toBeInTheDocument();
await userEvent.keyboard("{Enter}");
expect(header).toHaveAttribute("aria-expanded", "true");
expect(canvas.getByText("Placeholder content")).toBeInTheDocument();
await userEvent.keyboard(" ");
expect(header).toHaveAttribute("aria-expanded", "false");
expect(canvas.queryByText("Placeholder content")).not.toBeInTheDocument();
await userEvent.keyboard(" ");
expect(header).toHaveAttribute("aria-expanded", "true");
expect(canvas.getByText("Placeholder content")).toBeInTheDocument();
},
};
@@ -1,211 +0,0 @@
import { cva, type VariantProps } from "class-variance-authority";
import { ChevronDownIcon } from "lucide-react";
import {
type ComponentProps,
createContext,
type FC,
type ReactNode,
useContext,
useState,
} from "react";
import {
Collapsible,
CollapsibleContent,
CollapsibleTrigger,
} from "#/components/Collapsible/Collapsible";
import { cn } from "#/utils/cn";
// ---------------------------------------------------------------------------
// Variant styles
// ---------------------------------------------------------------------------
const wrapperStyles = cva("", {
variants: {
variant: {
card: "rounded-lg border border-solid border-border-default",
inline: "border-0 border-t border-solid border-border-default pt-4",
},
},
defaultVariants: { variant: "card" },
});
const triggerStyles = cva(
"flex w-full cursor-pointer items-start justify-between gap-4 border-0 bg-transparent text-left focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-content-link",
{
variants: {
variant: {
card: "rounded-lg px-6 py-5",
inline: "rounded-md p-0",
},
},
defaultVariants: { variant: "card" },
},
);
const titleStyles = cva("m-0 text-content-primary", {
variants: {
variant: {
card: "text-lg font-semibold",
inline: "text-sm font-medium",
},
},
defaultVariants: { variant: "card" },
});
const descriptionStyles = cva("m-0 text-content-secondary", {
variants: {
variant: {
card: "mt-1 text-sm",
inline: "text-xs",
},
},
defaultVariants: { variant: "card" },
});
const contentStyles = cva("", {
variants: {
variant: {
card: "border-0 border-t border-solid border-border-default px-6 pb-5 pt-4",
inline: "pt-3",
},
},
defaultVariants: { variant: "card" },
});
// ---------------------------------------------------------------------------
// Context
// ---------------------------------------------------------------------------
type Variant = "card" | "inline";
interface CollapsibleSectionContextValue {
variant: Variant;
open: boolean;
}
const CollapsibleSectionContext = createContext<CollapsibleSectionContextValue>(
{
variant: "card",
open: true,
},
);
const useCollapsibleSection = () => useContext(CollapsibleSectionContext);
// ---------------------------------------------------------------------------
// Root
// ---------------------------------------------------------------------------
interface CollapsibleSectionProps extends VariantProps<typeof wrapperStyles> {
defaultOpen?: boolean;
/** Controlled open state. */
open?: boolean;
onOpenChange?: (open: boolean) => void;
children: ReactNode;
}
export const CollapsibleSection: FC<CollapsibleSectionProps> = ({
defaultOpen,
open: controlledOpen,
onOpenChange: controlledOnOpenChange,
variant = "card",
children,
}) => {
const [uncontrolledOpen, setUncontrolledOpen] = useState(defaultOpen ?? true);
const isControlled = controlledOpen !== undefined;
const open = isControlled ? controlledOpen : uncontrolledOpen;
const onOpenChange = isControlled
? controlledOnOpenChange
: setUncontrolledOpen;
return (
<CollapsibleSectionContext.Provider
value={{ variant: variant ?? "card", open }}
>
<Collapsible open={open} onOpenChange={onOpenChange}>
<div className={wrapperStyles({ variant })}>{children}</div>
</Collapsible>
</CollapsibleSectionContext.Provider>
);
};
// ---------------------------------------------------------------------------
// Header (trigger)
// ---------------------------------------------------------------------------
interface CollapsibleSectionHeaderProps {
children: ReactNode;
}
export const CollapsibleSectionHeader: FC<CollapsibleSectionHeaderProps> = ({
children,
}) => {
const { variant, open } = useCollapsibleSection();
return (
<CollapsibleTrigger className={triggerStyles({ variant })}>
<div className="min-w-0 flex-1">{children}</div>
<ChevronDownIcon
className={cn(
"h-4 w-4 shrink-0 text-content-secondary transition-transform duration-200",
open && "rotate-180",
)}
/>
</CollapsibleTrigger>
);
};
// ---------------------------------------------------------------------------
// Title — renders the heading element at whatever level the consumer picks.
// ---------------------------------------------------------------------------
type HeadingTag = "h1" | "h2" | "h3" | "h4" | "h5" | "h6";
interface CollapsibleSectionTitleProps extends ComponentProps<HeadingTag> {
as?: HeadingTag;
}
export const CollapsibleSectionTitle: FC<CollapsibleSectionTitleProps> = ({
as: Component = "h2",
className,
...props
}) => {
const { variant } = useCollapsibleSection();
return (
<Component className={cn(titleStyles({ variant }), className)} {...props} />
);
};
// ---------------------------------------------------------------------------
// Description
// ---------------------------------------------------------------------------
export const CollapsibleSectionDescription: FC<ComponentProps<"p">> = ({
className,
...props
}) => {
const { variant } = useCollapsibleSection();
return (
<p className={cn(descriptionStyles({ variant }), className)} {...props} />
);
};
// ---------------------------------------------------------------------------
// Content
// ---------------------------------------------------------------------------
interface CollapsibleSectionContentProps {
children: ReactNode;
}
export const CollapsibleSectionContent: FC<CollapsibleSectionContentProps> = ({
children,
}) => {
const { variant } = useCollapsibleSection();
return (
<CollapsibleContent>
<div className={contentStyles({ variant })}>{children}</div>
</CollapsibleContent>
);
};
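The removed `CollapsibleSection` root above uses the common React controlled/uncontrolled fallback: a caller-supplied `open` prop wins, otherwise internal state seeded from `defaultOpen` is used. A framework-free sketch of just that resolution logic (names here are illustrative, not from the component):

```typescript
// Minimal sketch of the controlled/uncontrolled fallback pattern:
// `open` is the optional controlled value; `uncontrolledOpen` stands in
// for the useState value seeded from `defaultOpen ?? true`.
type OpenState = {
  open?: boolean;            // controlled value, if the caller provides one
  uncontrolledOpen: boolean; // internal state fallback
};

function resolveOpen(state: OpenState): boolean {
  // Mirrors `controlledOpen !== undefined` in the component: presence of
  // the prop, not its truthiness, decides which source of truth is used.
  const isControlled = state.open !== undefined;
  return isControlled ? (state.open as boolean) : state.uncontrolledOpen;
}
```

The `!== undefined` check matters: an explicit `open={false}` must still count as controlled, which a plain truthiness test would get wrong.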
@@ -989,7 +989,7 @@ export const MCPServerAdminPanel: FC<MCPServerAdminPanelProps> = ({
}
return (
<div className="flex min-h-full flex-col space-y-6">
<div className="flex min-h-full flex-col space-y-3">
{!isFormView ? (
<ServerList
servers={servers}
@@ -21,9 +21,7 @@ import {
} from "#/components/Table/Table";
import { cn } from "#/utils/cn";
import { formatCostMicros } from "#/utils/currency";
import { AdminBadge } from "./AdminBadge";
import { PrStateIcon } from "./GitPanel/GitPanel";
import { SectionHeader } from "./SectionHeader";
dayjs.extend(relativeTime);
@@ -297,16 +295,19 @@ export const PRInsightsView: FC<PRInsightsViewProps> = ({
const isEmpty = summary.total_prs_created === 0;
return (
<div className="space-y-6">
<div className="space-y-8">
{/* ── Header ── */}
<SectionHeader
label="Pull Request Insights"
description="Code changes detected by Agents."
badge={<AdminBadge />}
action={
<TimeRangeFilter value={timeRange} onChange={onTimeRangeChange} />
}
/>
<div className="flex items-end justify-between">
<div>
<h2 className="m-0 text-xl font-semibold tracking-tight text-content-primary">
Pull Request Insights
</h2>
<p className="m-0 mt-1 text-[13px] text-content-secondary">
Code changes detected by Agents.
</p>
</div>
<TimeRangeFilter value={timeRange} onChange={onTimeRangeChange} />
</div>
{isEmpty ? (
<EmptyState />
@@ -354,7 +355,7 @@ export const PRInsightsView: FC<PRInsightsViewProps> = ({
</section>
{/* ── Model breakdown + Recent PRs side by side ── */}
<div className="grid grid-cols-1 gap-6">
<div className="grid grid-cols-1 gap-6 lg:grid-cols-2">
{/* ── Model performance (simplified) ── */}
{by_model.length > 0 && (
<section>
@@ -15,8 +15,8 @@ export const SectionHeader: FC<SectionHeaderProps> = ({
}) => (
<>
<div className="flex items-start justify-between gap-4">
<div className="min-w-0 flex-1">
<div className="flex items-center gap-2">
<div>
<div className="flex w-full items-center gap-2">
<h2 className="m-0 text-lg font-medium text-content-primary">
{label}
</h2>
@@ -78,8 +78,11 @@ export const Default: Story = {
expect(canvas.getByText("Claude Sonnet")).toBeInTheDocument();
expect(canvas.queryByText("GPT-3.5 (Disabled)")).not.toBeInTheDocument();
// Save button should be disabled when nothing is dirty.
expect(canvas.getByRole("button", { name: /Save/i })).toBeDisabled();
// No footer visible when nothing is dirty
expect(
canvas.queryByRole("button", { name: /Save/i }),
).not.toBeInTheDocument();
// Type a value to make the footer appear
await userEvent.type(gpt4oInput, "95");
await waitFor(() => {
@@ -160,9 +163,11 @@ export const CancelChanges: Story = {
const cancelButton = await canvas.findByRole("button", { name: /Cancel/i });
await userEvent.click(cancelButton);
// Save button should be disabled after cancel.
// Footer should disappear after cancel
await waitFor(() => {
expect(canvas.getByRole("button", { name: /Save/i })).toBeDisabled();
expect(
canvas.queryByRole("button", { name: /Save/i }),
).not.toBeInTheDocument();
});
// Input should be cleared back to empty (no override)
@@ -189,8 +194,10 @@ export const InvalidDraftShowsFooter: Story = {
// Cancel button should be visible so user can discard the edit
expect(canvas.getByRole("button", { name: /Cancel/i })).toBeInTheDocument();
// Save button should be disabled (nothing valid to save).
expect(canvas.getByRole("button", { name: /Save/i })).toBeDisabled();
// Save button should NOT be visible (nothing valid to save)
expect(
canvas.queryByRole("button", { name: /Save/i }),
).not.toBeInTheDocument();
},
};
@@ -9,6 +9,7 @@ import {
Table,
TableBody,
TableCell,
TableFooter,
TableHead,
TableHeader,
TableRow,
@@ -32,7 +33,6 @@ interface UserCompactionThresholdSettingsProps {
thresholdPercent: number,
) => Promise<unknown>;
onResetThreshold: (modelConfigId: string) => Promise<unknown>;
hideHeader?: boolean;
}
const parseThresholdDraft = (value: string): number | null => {
@@ -60,7 +60,6 @@ export const UserCompactionThresholdSettings: FC<
thresholdsError,
onSaveThreshold,
onResetThreshold,
hideHeader,
}) => {
const [drafts, setDrafts] = useState<Record<string, string>>({});
const [rowErrors, setRowErrors] = useState<Record<string, string>>({});
@@ -173,23 +172,19 @@ export const UserCompactionThresholdSettings: FC<
};
const hasAnyPending = pendingModels.size > 0;
const headerBlock = !hideHeader ? (
<>
<h3 className="m-0 text-[13px] font-semibold text-content-primary">
Context Compaction
</h3>
<p className="!mt-0.5 m-0 text-xs text-content-secondary">
Control when conversation context is automatically summarized for each
model. Setting 100% means the conversation will never auto-compact.
</p>
</>
) : null;
const hasAnyErrors = Object.keys(rowErrors).length > 0;
const hasAnyDrafts = Object.keys(drafts).length > 0;
if (isThresholdsLoading) {
return (
<div className="space-y-2">
{headerBlock}
<h3 className="m-0 text-[13px] font-semibold text-content-primary">
Context Compaction
</h3>
<p className="!mt-0.5 m-0 text-xs text-content-secondary">
Control when conversation context is automatically summarized for each
model. Setting 100% means the conversation will never auto-compact.
</p>
<div className="flex items-center gap-2 text-sm text-content-secondary">
<Spinner loading className="h-4 w-4" />
Loading thresholds...
@@ -201,7 +196,13 @@ export const UserCompactionThresholdSettings: FC<
if (thresholdsError != null) {
return (
<div className="space-y-2">
{headerBlock}
<h3 className="m-0 text-[13px] font-semibold text-content-primary">
Context Compaction
</h3>
<p className="!mt-0.5 m-0 text-xs text-content-secondary">
Control when conversation context is automatically summarized for each
model. Setting 100% means the conversation will never auto-compact.
</p>
<p className="m-0 text-xs text-content-destructive">
{getErrorMessage(
thresholdsError,
@@ -214,7 +215,13 @@ export const UserCompactionThresholdSettings: FC<
return (
<div className="space-y-2">
{headerBlock}
<h3 className="m-0 text-[13px] font-semibold text-content-primary">
Context Compaction
</h3>
<p className="!mt-0.5 m-0 text-xs text-content-secondary">
Control when conversation context is automatically summarized for each
model. Setting 100% means the conversation will never auto-compact.
</p>
{isLoadingModelConfigs ? (
<div className="flex items-center gap-2 text-sm text-content-secondary">
<Spinner loading className="h-4 w-4" />
@@ -233,163 +240,165 @@ export const UserCompactionThresholdSettings: FC<
models before compaction thresholds can be set.
</p>
) : (
<>
<Table>
<TableHeader>
<TableRow>
<TableHead>Model</TableHead>
<TableHead className="w-0 whitespace-nowrap">Default</TableHead>
<TableHead className="w-0 whitespace-nowrap">
Threshold
</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{enabledModelConfigs.map((modelConfig) => {
const existingOverride = overridesByModelID.get(modelConfig.id);
const hasOverride = overridesByModelID.has(modelConfig.id);
const draftValue =
drafts[modelConfig.id] ??
(existingOverride !== undefined
? String(existingOverride)
: "");
const parsedDraftValue = parseThresholdDraft(draftValue);
const isThisModelMutating = pendingModels.has(modelConfig.id);
const isInvalid =
draftValue.length > 0 && parsedDraftValue === null;
// Only warn when user-typed, not when loaded from
// the server.
const isDraftDisablingCompaction =
draftValue === "100" && drafts[modelConfig.id] !== undefined;
const rowError = rowErrors[modelConfig.id];
const modelName = modelConfig.display_name || modelConfig.model;
<Table>
<TableHeader>
<TableRow>
<TableHead>Model</TableHead>
<TableHead className="w-0 whitespace-nowrap">Default</TableHead>
<TableHead className="w-0 whitespace-nowrap">Threshold</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{enabledModelConfigs.map((modelConfig) => {
const existingOverride = overridesByModelID.get(modelConfig.id);
const hasOverride = overridesByModelID.has(modelConfig.id);
const draftValue =
drafts[modelConfig.id] ??
(existingOverride !== undefined
? String(existingOverride)
: "");
const parsedDraftValue = parseThresholdDraft(draftValue);
const isThisModelMutating = pendingModels.has(modelConfig.id);
const isInvalid =
draftValue.length > 0 && parsedDraftValue === null;
// Only warn when user-typed, not when loaded from
// the server.
const isDraftDisablingCompaction =
draftValue === "100" && drafts[modelConfig.id] !== undefined;
const rowError = rowErrors[modelConfig.id];
const modelName = modelConfig.display_name || modelConfig.model;
return (
<TableRow key={modelConfig.id}>
<TableCell className="text-[13px] font-medium text-content-primary">
{modelName}
{rowError && (
<p
aria-live="polite"
className="m-0 mt-0.5 text-2xs font-normal text-content-destructive"
>
{rowError}
</p>
)}
</TableCell>
<TableCell className="w-0 whitespace-nowrap tabular-nums">
{modelConfig.compression_threshold}%
</TableCell>
<TableCell className="w-0 whitespace-nowrap">
<div className="flex items-center gap-1.5">
<Tooltip>
<TooltipTrigger asChild>
<Input
aria-label={`${modelName} compaction threshold`}
aria-invalid={isInvalid || undefined}
type="number"
min={0}
max={100}
inputMode="numeric"
className={cn(
"h-7 w-16 px-2 text-xs tabular-nums",
isInvalid &&
"border-content-destructive focus:ring-content-destructive/30",
)}
value={draftValue}
placeholder={String(
modelConfig.compression_threshold,
)}
onChange={(event) => {
setDrafts((currentDrafts) => ({
...currentDrafts,
[modelConfig.id]: event.target.value,
}));
clearRowError(modelConfig.id);
}}
disabled={isThisModelMutating}
/>
</TooltipTrigger>
{(isInvalid || isDraftDisablingCompaction) && (
<TooltipContent>
{isInvalid
? "Enter a whole number between 0 and 100."
: "Setting 100% will disable auto-compaction for this model."}
</TooltipContent>
)}
</Tooltip>
<span className="text-xs text-content-secondary">
%
</span>
<Tooltip>
<TooltipTrigger asChild>
<Button
size="icon"
variant="subtle"
className={cn(
"size-7",
hasOverride
? "opacity-100"
: "pointer-events-none opacity-0",
)}
aria-label={`Reset ${modelName} to default`}
aria-hidden={!hasOverride}
tabIndex={hasOverride ? 0 : -1}
disabled={isThisModelMutating || !hasOverride}
onClick={() => handleReset(modelConfig.id)}
>
<RotateCcwIcon className="size-3.5" />
</Button>
</TooltipTrigger>
{hasOverride && (
<TooltipContent>
Reset to default (
{modelConfig.compression_threshold}%)
</TooltipContent>
)}
</Tooltip>
</div>
{isInvalid && (
<span className="sr-only" aria-live="polite">
Enter a whole number between 0 and 100.
</span>
)}
{isDraftDisablingCompaction && (
<span className="sr-only" aria-live="polite">
Setting 100% will disable auto-compaction for this
model.
</span>
)}
</TableCell>
</TableRow>
);
})}
</TableBody>
</Table>
<div className="flex justify-end gap-2 pt-2">
<Button
size="sm"
variant="outline"
type="button"
onClick={handleCancelAll}
disabled={hasAnyPending}
>
Cancel
</Button>
<Button
size="sm"
type="button"
disabled={hasAnyPending || dirtyRows.length === 0}
onClick={handleSaveAll}
>
{hasAnyPending
? "Saving..."
: dirtyRows.length > 0
? `Save ${dirtyRows.length} ${dirtyRows.length === 1 ? "change" : "changes"}`
: "Save"}
</Button>
</div>
</>
return (
<TableRow key={modelConfig.id}>
<TableCell className="text-[13px] font-medium text-content-primary">
{modelName}
{rowError && (
<p
aria-live="polite"
className="m-0 mt-0.5 text-2xs font-normal text-content-destructive"
>
{rowError}
</p>
)}
</TableCell>
<TableCell className="w-0 whitespace-nowrap tabular-nums">
{modelConfig.compression_threshold}%
</TableCell>
<TableCell className="w-0 whitespace-nowrap">
<div className="flex items-center gap-1.5">
<Tooltip>
<TooltipTrigger asChild>
<Input
aria-label={`${modelName} compaction threshold`}
aria-invalid={isInvalid || undefined}
type="number"
min={0}
max={100}
inputMode="numeric"
className={cn(
"h-7 w-16 px-2 text-xs tabular-nums",
isInvalid &&
"border-content-destructive focus:ring-content-destructive/30",
)}
value={draftValue}
placeholder={String(
modelConfig.compression_threshold,
)}
onChange={(event) => {
setDrafts((currentDrafts) => ({
...currentDrafts,
[modelConfig.id]: event.target.value,
}));
clearRowError(modelConfig.id);
}}
disabled={isThisModelMutating}
/>
</TooltipTrigger>
{(isInvalid || isDraftDisablingCompaction) && (
<TooltipContent>
{isInvalid
? "Enter a whole number between 0 and 100."
: "Setting 100% will disable auto-compaction for this model."}
</TooltipContent>
)}
</Tooltip>
<span className="text-xs text-content-secondary">%</span>
<Tooltip>
<TooltipTrigger asChild>
<Button
size="icon"
variant="subtle"
className={cn(
"size-7",
hasOverride
? "opacity-100"
: "pointer-events-none opacity-0",
)}
aria-label={`Reset ${modelName} to default`}
aria-hidden={!hasOverride}
tabIndex={hasOverride ? 0 : -1}
disabled={isThisModelMutating || !hasOverride}
onClick={() => handleReset(modelConfig.id)}
>
<RotateCcwIcon className="size-3.5" />
</Button>
</TooltipTrigger>
{hasOverride && (
<TooltipContent>
Reset to default (
{modelConfig.compression_threshold}%)
</TooltipContent>
)}
</Tooltip>
</div>
{isInvalid && (
<span className="sr-only" aria-live="polite">
Enter a whole number between 0 and 100.
</span>
)}
{isDraftDisablingCompaction && (
<span className="sr-only" aria-live="polite">
Setting 100% will disable auto-compaction for this
model.
</span>
)}
</TableCell>
</TableRow>
);
})}
</TableBody>
{(dirtyRows.length > 0 || hasAnyErrors || hasAnyDrafts) && (
<TableFooter className="bg-transparent">
<TableRow className="border-0">
<TableCell colSpan={3} className="border-0 p-0">
<div className="flex items-center justify-end gap-2 px-3 py-1.5">
<Button
size="sm"
variant="outline"
type="button"
onClick={handleCancelAll}
disabled={hasAnyPending}
>
Cancel
</Button>
{dirtyRows.length > 0 && (
<Button
size="sm"
type="button"
disabled={hasAnyPending}
onClick={handleSaveAll}
>
{hasAnyPending
? "Saving..."
: `Save ${dirtyRows.length} ${dirtyRows.length === 1 ? "change" : "changes"}`}
</Button>
)}
</div>
</TableCell>
</TableRow>
</TableFooter>
)}
</Table>
)}
</div>
);
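The diff only shows `parseThresholdDraft`'s signature, not its body. Based on the UI copy ("Enter a whole number between 0 and 100") and the `isInvalid` check, a plausible sketch of what it validates — this is an assumption, not the actual implementation:

```typescript
// Hypothetical reconstruction of parseThresholdDraft: accepts only a
// whole number in [0, 100]; anything else (fractional, negative,
// non-numeric) yields null, which the row treats as invalid input.
function parseThresholdDraft(value: string): number | null {
  const trimmed = value.trim();
  if (!/^\d+$/.test(trimmed)) {
    return null; // rejects "4.5", "-1", "abc", ""
  }
  const parsed = Number(trimmed);
  return parsed >= 0 && parsed <= 100 ? parsed : null;
}
```

Note the row-level `isInvalid` flag also requires `draftValue.length > 0`, so an empty draft is treated as "no override" rather than an error.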
@@ -1,35 +0,0 @@
import type { DateRangeValue } from "#/components/DateRangePicker/DateRangePicker";
/**
* Returns true when the given date falls exactly on midnight
* (00:00:00.000). Date-range pickers use midnight of the *following*
* day as the exclusive upper bound for a full-day selection. Detecting
* this lets call sites subtract 1 ms (or 1 day) so the UI shows the
* inclusive end date instead.
*/
function isMidnight(date: Date): boolean {
return (
date.getHours() === 0 &&
date.getMinutes() === 0 &&
date.getSeconds() === 0 &&
date.getMilliseconds() === 0
);
}
/**
* When the user picks an explicit date range whose end boundary is
* midnight of the following day, adjust it by 1 ms so the
* DateRangePicker highlights the inclusive end date.
*/
export function toInclusiveDateRange(
dateRange: DateRangeValue,
hasExplicitDateRange: boolean,
): DateRangeValue {
if (hasExplicitDateRange && isMidnight(dateRange.endDate)) {
return {
startDate: dateRange.startDate,
endDate: new Date(dateRange.endDate.getTime() - 1),
};
}
return dateRange;
}
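The deleted helper above is small enough to exercise directly. A self-contained version (the `DateRangeValue` shape is inferred from usage, since the import path is not shown here):

```typescript
// Assumed shape, inferred from how the deleted file uses it.
interface DateRangeValue {
  startDate: Date;
  endDate: Date;
}

function isMidnight(date: Date): boolean {
  return (
    date.getHours() === 0 &&
    date.getMinutes() === 0 &&
    date.getSeconds() === 0 &&
    date.getMilliseconds() === 0
  );
}

function toInclusiveDateRange(
  dateRange: DateRangeValue,
  hasExplicitDateRange: boolean,
): DateRangeValue {
  if (hasExplicitDateRange && isMidnight(dateRange.endDate)) {
    return {
      startDate: dateRange.startDate,
      endDate: new Date(dateRange.endDate.getTime() - 1),
    };
  }
  return dateRange;
}

// A full-day selection of Jan 1–Jan 2 arrives with an exclusive upper
// bound of midnight Jan 3; subtracting 1 ms yields the inclusive end
// the picker should highlight: Jan 2, 23:59:59.999 local time.
const picked: DateRangeValue = {
  startDate: new Date(2026, 0, 1),
  endDate: new Date(2026, 0, 3), // local midnight, exclusive bound
};
const shown = toInclusiveDateRange(picked, true);
```

When `hasExplicitDateRange` is false (e.g. a relative preset like "last 7 days"), the range passes through untouched, since the end is not a user-picked day boundary.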
@@ -71,6 +71,7 @@ describe("AuditPage", () => {
const getAuditLogsSpy = vi.spyOn(API, "getAuditLogs").mockResolvedValue({
audit_logs: [MockAuditLog, MockAuditLog2],
count: 2,
count_cap: 0,
});
// When
@@ -90,6 +91,7 @@ describe("AuditPage", () => {
vi.spyOn(API, "getAuditLogs").mockResolvedValue({
audit_logs: [MockAuditLog],
count: 1,
count_cap: 0,
});
await renderPage();
@@ -114,6 +116,7 @@ describe("AuditPage", () => {
vi.spyOn(API, "getAuditLogs").mockResolvedValue({
audit_logs: [MockAuditLog],
count: 1,
count_cap: 0,
});
await renderPage();
@@ -140,9 +143,11 @@ describe("AuditPage", () => {
describe("Filtering", () => {
it("filters by URL", async () => {
const getAuditLogsSpy = vi
.spyOn(API, "getAuditLogs")
.mockResolvedValue({ audit_logs: [MockAuditLog], count: 1 });
const getAuditLogsSpy = vi.spyOn(API, "getAuditLogs").mockResolvedValue({
audit_logs: [MockAuditLog],
count: 1,
count_cap: 0,
});
const query = "resource_type:workspace action:create";
await renderPage({ filter: query });
@@ -173,4 +178,29 @@ describe("AuditPage", () => {
);
});
});
describe("Capped count", () => {
it("shows capped count indicator and navigates to next page with correct offset", async () => {
vi.spyOn(API, "getAuditLogs").mockResolvedValue({
audit_logs: [MockAuditLog, MockAuditLog2],
count: 2001,
count_cap: 2000,
});
const user = userEvent.setup();
await renderPage();
await screen.findByText(/2,000\+/);
await user.click(screen.getByRole("button", { name: /next page/i }));
await waitFor(() =>
expect(API.getAuditLogs).toHaveBeenLastCalledWith<[AuditLogsRequest]>({
limit: DEFAULT_RECORDS_PER_PAGE,
offset: DEFAULT_RECORDS_PER_PAGE,
q: "",
}),
);
});
});
});
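The capped-count test above expects `count: 2001` with `count_cap: 2000` to render as `2,000+`. A sketch of that display rule (the helper name is invented for illustration; the actual rendering lives elsewhere in the page):

```typescript
// Hypothetical formatter matching the test's expectations: a count_cap
// of 0 means the server returned an exact count; a nonzero cap that the
// count exceeds means the total was truncated, shown as "<cap>+".
function formatRowCount(count: number, countCap: number): string {
  if (countCap > 0 && count > countCap) {
    return `${countCap.toLocaleString("en-US")}+`;
  }
  return count.toLocaleString("en-US");
}
```

Pagination is unaffected by the cap: as the test asserts, the next-page request still uses `offset: DEFAULT_RECORDS_PER_PAGE`, since the offset depends on page size, not on the (possibly capped) total.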
@@ -69,6 +69,7 @@ describe("ConnectionLogPage", () => {
MockDisconnectedSSHConnectionLog,
],
count: 2,
count_cap: 0,
});
// When
@@ -95,6 +96,7 @@ describe("ConnectionLogPage", () => {
.mockResolvedValue({
connection_logs: [MockConnectedSSHConnectionLog],
count: 1,
count_cap: 0,
});
const query = "type:ssh status:ongoing";