Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c174b3037b | |||
| f5165d304f | |||
| 684f21740d | |||
| 86ca61d6ca | |||
| f0521cfa3c | |||
| 0c5d189aff | |||
| d7c8213eee | |||
| 63924ac687 |
@@ -31,7 +31,8 @@ updates:
|
||||
patterns:
|
||||
- "golang.org/x/*"
|
||||
ignore:
|
||||
# Ignore patch updates for all dependencies
|
||||
# Patch updates are handled by the security-patch-prs workflow so this
|
||||
# lane stays focused on broader dependency updates.
|
||||
- dependency-name: "*"
|
||||
update-types:
|
||||
- version-update:semver-patch
|
||||
@@ -56,7 +57,7 @@ updates:
|
||||
labels: []
|
||||
ignore:
|
||||
# We need to coordinate terraform updates with the version hardcoded in
|
||||
# our Go code.
|
||||
# our Go code. These are handled by the security-patch-prs workflow.
|
||||
- dependency-name: "terraform"
|
||||
|
||||
- package-ecosystem: "npm"
|
||||
@@ -117,11 +118,11 @@ updates:
|
||||
interval: "weekly"
|
||||
commit-message:
|
||||
prefix: "chore"
|
||||
labels: []
|
||||
groups:
|
||||
coder-modules:
|
||||
patterns:
|
||||
- "coder/*/coder"
|
||||
labels: []
|
||||
ignore:
|
||||
- dependency-name: "*"
|
||||
update-types:
|
||||
|
||||
@@ -0,0 +1,354 @@
|
||||
name: security-backport
|
||||
|
||||
on:
|
||||
pull_request_target:
|
||||
types:
|
||||
- labeled
|
||||
- unlabeled
|
||||
- closed
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
pull_request:
|
||||
description: Pull request number to backport.
|
||||
required: true
|
||||
type: string
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
issues: write
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || inputs.pull_request }}
|
||||
cancel-in-progress: false
|
||||
|
||||
env:
|
||||
LATEST_BRANCH: release/2.31
|
||||
STABLE_BRANCH: release/2.30
|
||||
STABLE_1_BRANCH: release/2.29
|
||||
|
||||
jobs:
|
||||
label-policy:
|
||||
if: github.event_name == 'pull_request_target'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Apply security backport label policy
|
||||
shell: bash
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
pr_number="$(jq -r '.pull_request.number' "${GITHUB_EVENT_PATH}")"
|
||||
pr_json="$(gh pr view "${pr_number}" --json number,title,url,baseRefName,labels)"
|
||||
|
||||
PR_JSON="${pr_json}" \
|
||||
python3 - <<'PY'
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
pr = json.loads(os.environ["PR_JSON"])
|
||||
|
||||
pr_number = pr["number"]
|
||||
labels = [label["name"] for label in pr.get("labels", [])]
|
||||
|
||||
def has(label: str) -> bool:
|
||||
return label in labels
|
||||
|
||||
def ensure_label(label: str) -> None:
|
||||
if not has(label):
|
||||
subprocess.run(
|
||||
["gh", "pr", "edit", str(pr_number), "--add-label", label],
|
||||
check=False,
|
||||
)
|
||||
|
||||
def remove_label(label: str) -> None:
|
||||
if has(label):
|
||||
subprocess.run(
|
||||
["gh", "pr", "edit", str(pr_number), "--remove-label", label],
|
||||
check=False,
|
||||
)
|
||||
|
||||
def comment(body: str) -> None:
|
||||
subprocess.run(
|
||||
["gh", "pr", "comment", str(pr_number), "--body", body],
|
||||
check=True,
|
||||
)
|
||||
|
||||
if not has("security:patch"):
|
||||
remove_label("status:needs-severity")
|
||||
sys.exit(0)
|
||||
|
||||
severity_labels = [
|
||||
label
|
||||
for label in ("severity:medium", "severity:high", "severity:critical")
|
||||
if has(label)
|
||||
]
|
||||
if len(severity_labels) == 0:
|
||||
ensure_label("status:needs-severity")
|
||||
comment(
|
||||
"This PR is labeled `security:patch` but is missing a severity "
|
||||
"label. Add one of `severity:medium`, `severity:high`, or "
|
||||
"`severity:critical` before backport automation can proceed."
|
||||
)
|
||||
sys.exit(0)
|
||||
|
||||
if len(severity_labels) > 1:
|
||||
comment(
|
||||
"This PR has multiple severity labels. Keep exactly one of "
|
||||
"`severity:medium`, `severity:high`, or `severity:critical`."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
remove_label("status:needs-severity")
|
||||
|
||||
target_labels = [
|
||||
label
|
||||
for label in ("backport:stable", "backport:stable-1")
|
||||
if has(label)
|
||||
]
|
||||
has_none = has("backport:none")
|
||||
if has_none and target_labels:
|
||||
comment(
|
||||
"`backport:none` cannot be combined with other backport labels. "
|
||||
"Remove `backport:none` or remove the explicit backport targets."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
if not has_none and not target_labels:
|
||||
ensure_label("backport:stable")
|
||||
ensure_label("backport:stable-1")
|
||||
comment(
|
||||
"Applied default backport labels `backport:stable` and "
|
||||
"`backport:stable-1` for a qualifying security patch."
|
||||
)
|
||||
PY
|
||||
|
||||
backport:
|
||||
if: >
|
||||
github.event_name == 'workflow_dispatch' ||
|
||||
(
|
||||
github.event_name == 'pull_request_target' &&
|
||||
github.event.pull_request.merged == true
|
||||
)
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Resolve PR metadata
|
||||
id: metadata
|
||||
shell: bash
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
INPUT_PR_NUMBER: ${{ inputs.pull_request }}
|
||||
LATEST_BRANCH: ${{ env.LATEST_BRANCH }}
|
||||
STABLE_BRANCH: ${{ env.STABLE_BRANCH }}
|
||||
STABLE_1_BRANCH: ${{ env.STABLE_1_BRANCH }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
if [ "${GITHUB_EVENT_NAME}" = "workflow_dispatch" ]; then
|
||||
pr_number="${INPUT_PR_NUMBER}"
|
||||
else
|
||||
pr_number="$(jq -r '.pull_request.number' "${GITHUB_EVENT_PATH}")"
|
||||
fi
|
||||
|
||||
case "${pr_number}" in
|
||||
''|*[!0-9]*)
|
||||
echo "A valid pull request number is required."
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
|
||||
pr_json="$(gh pr view "${pr_number}" --json number,title,url,mergeCommit,baseRefName,labels,mergedAt,author)"
|
||||
|
||||
PR_JSON="${pr_json}" \
|
||||
python3 - <<'PY'
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
pr = json.loads(os.environ["PR_JSON"])
|
||||
github_output = os.environ["GITHUB_OUTPUT"]
|
||||
|
||||
labels = [label["name"] for label in pr.get("labels", [])]
|
||||
if "security:patch" not in labels:
|
||||
print("Not a security patch PR; skipping.")
|
||||
sys.exit(0)
|
||||
|
||||
severity_labels = [
|
||||
label
|
||||
for label in ("severity:medium", "severity:high", "severity:critical")
|
||||
if label in labels
|
||||
]
|
||||
if len(severity_labels) != 1:
|
||||
raise SystemExit(
|
||||
"Merged security patch PR must have exactly one severity label."
|
||||
)
|
||||
|
||||
if not pr.get("mergedAt"):
|
||||
raise SystemExit(f"PR #{pr['number']} is not merged.")
|
||||
|
||||
if "backport:none" in labels:
|
||||
target_pairs = []
|
||||
else:
|
||||
mapping = {
|
||||
"backport:stable": os.environ["STABLE_BRANCH"],
|
||||
"backport:stable-1": os.environ["STABLE_1_BRANCH"],
|
||||
}
|
||||
target_pairs = []
|
||||
for label_name, branch in mapping.items():
|
||||
if label_name in labels and branch and branch != pr["baseRefName"]:
|
||||
target_pairs.append({"label": label_name, "branch": branch})
|
||||
|
||||
with open(github_output, "a", encoding="utf-8") as f:
|
||||
f.write(f"pr_number={pr['number']}\n")
|
||||
f.write(f"merge_sha={pr['mergeCommit']['oid']}\n")
|
||||
f.write(f"title={pr['title']}\n")
|
||||
f.write(f"url={pr['url']}\n")
|
||||
f.write(f"author={pr['author']['login']}\n")
|
||||
f.write(f"severity_label={severity_labels[0]}\n")
|
||||
f.write(f"target_pairs={json.dumps(target_pairs)}\n")
|
||||
PY
|
||||
|
||||
- name: Backport to release branches
|
||||
if: ${{ steps.metadata.outputs.target_pairs != '[]' }}
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
PR_NUMBER: ${{ steps.metadata.outputs.pr_number }}
|
||||
MERGE_SHA: ${{ steps.metadata.outputs.merge_sha }}
|
||||
PR_TITLE: ${{ steps.metadata.outputs.title }}
|
||||
PR_URL: ${{ steps.metadata.outputs.url }}
|
||||
PR_AUTHOR: ${{ steps.metadata.outputs.author }}
|
||||
SEVERITY_LABEL: ${{ steps.metadata.outputs.severity_label }}
|
||||
TARGET_PAIRS: ${{ steps.metadata.outputs.target_pairs }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
|
||||
|
||||
git remote set-url origin "https://x-access-token:${GH_TOKEN}@github.com/${GITHUB_REPOSITORY}.git"
|
||||
git fetch origin --prune
|
||||
|
||||
merge_parent_count="$(git rev-list --parents -n 1 "${MERGE_SHA}" | awk '{print NF-1}')"
|
||||
|
||||
failures=()
|
||||
successes=()
|
||||
|
||||
while IFS=$'\t' read -r backport_label target_branch; do
|
||||
[ -n "${target_branch}" ] || continue
|
||||
|
||||
safe_branch_name="${target_branch//\//-}"
|
||||
head_branch="backport/${safe_branch_name}/pr-${PR_NUMBER}"
|
||||
|
||||
existing_pr="$(gh pr list \
|
||||
--base "${target_branch}" \
|
||||
--head "${head_branch}" \
|
||||
--state all \
|
||||
--json number,url \
|
||||
--jq '.[0]')"
|
||||
if [ -n "${existing_pr}" ] && [ "${existing_pr}" != "null" ]; then
|
||||
pr_url="$(printf '%s' "${existing_pr}" | jq -r '.url')"
|
||||
successes+=("${target_branch}:existing:${pr_url}")
|
||||
continue
|
||||
fi
|
||||
|
||||
git checkout -B "${head_branch}" "origin/${target_branch}"
|
||||
|
||||
if [ "${merge_parent_count}" -gt 1 ]; then
|
||||
cherry_pick_args=(-m 1 "${MERGE_SHA}")
|
||||
else
|
||||
cherry_pick_args=("${MERGE_SHA}")
|
||||
fi
|
||||
|
||||
if ! git cherry-pick -x "${cherry_pick_args[@]}"; then
|
||||
git cherry-pick --abort || true
|
||||
gh pr edit "${PR_NUMBER}" --add-label "backport:conflict" || true
|
||||
gh pr comment "${PR_NUMBER}" --body \
|
||||
"Automatic backport to \`${target_branch}\` conflicted. The original author or release manager should resolve it manually."
|
||||
failures+=("${target_branch}:cherry-pick failed")
|
||||
continue
|
||||
fi
|
||||
|
||||
git push --force-with-lease origin "${head_branch}"
|
||||
|
||||
body_file="$(mktemp)"
|
||||
printf '%s\n' \
|
||||
"Automated backport of [#${PR_NUMBER}](${PR_URL})." \
|
||||
"" \
|
||||
"- Source PR: #${PR_NUMBER}" \
|
||||
"- Source commit: ${MERGE_SHA}" \
|
||||
"- Target branch: ${target_branch}" \
|
||||
"- Severity: ${SEVERITY_LABEL}" \
|
||||
> "${body_file}"
|
||||
|
||||
pr_url="$(gh pr create \
|
||||
--base "${target_branch}" \
|
||||
--head "${head_branch}" \
|
||||
--title "${PR_TITLE} (backport to ${target_branch})" \
|
||||
--body-file "${body_file}")"
|
||||
|
||||
backport_pr_number="$(gh pr list \
|
||||
--base "${target_branch}" \
|
||||
--head "${head_branch}" \
|
||||
--state open \
|
||||
--json number \
|
||||
--jq '.[0].number')"
|
||||
|
||||
gh pr edit "${backport_pr_number}" \
|
||||
--add-label "security:patch" \
|
||||
--add-label "${SEVERITY_LABEL}" \
|
||||
--add-label "${backport_label}" || true
|
||||
|
||||
successes+=("${target_branch}:created:${pr_url}")
|
||||
done < <(
|
||||
python3 - <<'PY'
|
||||
import json
|
||||
import os
|
||||
for pair in json.loads(os.environ["TARGET_PAIRS"]):
|
||||
print(f"{pair['label']}\t{pair['branch']}")
|
||||
PY
|
||||
)
|
||||
|
||||
summary_file="$(mktemp)"
|
||||
{
|
||||
echo "## Security backport summary"
|
||||
echo
|
||||
if [ "${#successes[@]}" -gt 0 ]; then
|
||||
echo "### Created or existing"
|
||||
for entry in "${successes[@]}"; do
|
||||
echo "- ${entry}"
|
||||
done
|
||||
echo
|
||||
fi
|
||||
if [ "${#failures[@]}" -gt 0 ]; then
|
||||
echo "### Failures"
|
||||
for entry in "${failures[@]}"; do
|
||||
echo "- ${entry}"
|
||||
done
|
||||
fi
|
||||
} | tee -a "${GITHUB_STEP_SUMMARY}" > "${summary_file}"
|
||||
|
||||
gh pr comment "${PR_NUMBER}" --body-file "${summary_file}"
|
||||
|
||||
if [ "${#failures[@]}" -gt 0 ]; then
|
||||
printf 'Backport failures:\n%s\n' "${failures[@]}" >&2
|
||||
exit 1
|
||||
fi
|
||||
@@ -0,0 +1,214 @@
|
||||
name: security-patch-prs
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: "0 3 * * 1-5"
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
patch:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
lane:
|
||||
- gomod
|
||||
- terraform
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup Go
|
||||
uses: ./.github/actions/setup-go
|
||||
|
||||
- name: Patch Go dependencies
|
||||
if: matrix.lane == 'gomod'
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
go get -u=patch ./...
|
||||
go mod tidy
|
||||
|
||||
# Guardrail: do not auto-edit replace directives.
|
||||
if git diff --unified=0 -- go.mod | grep -E '^[+-]replace '; then
|
||||
echo "Refusing to auto-edit go.mod replace directives"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Guardrail: only go.mod / go.sum may change.
|
||||
extra="$(git diff --name-only | grep -Ev '^(go\.mod|go\.sum)$' || true)"
|
||||
test -z "$extra" || { echo "Unexpected files changed:"; echo "$extra"; exit 1; }
|
||||
|
||||
- name: Patch bundled Terraform
|
||||
if: matrix.lane == 'terraform'
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
current="$(
|
||||
grep -oE 'NewVersion\("[0-9]+\.[0-9]+\.[0-9]+"\)' \
|
||||
provisioner/terraform/install.go \
|
||||
| head -1 \
|
||||
| grep -oE '[0-9]+\.[0-9]+\.[0-9]+'
|
||||
)"
|
||||
|
||||
series="$(echo "$current" | cut -d. -f1,2)"
|
||||
|
||||
latest="$(
|
||||
curl -fsSL https://releases.hashicorp.com/terraform/index.json \
|
||||
| jq -r --arg series "$series" '
|
||||
.versions
|
||||
| keys[]
|
||||
| select(startswith($series + "."))
|
||||
' \
|
||||
| sort -V \
|
||||
| tail -1
|
||||
)"
|
||||
|
||||
test -n "$latest"
|
||||
[ "$latest" != "$current" ] || exit 0
|
||||
|
||||
CURRENT_TERRAFORM_VERSION="$current" \
|
||||
LATEST_TERRAFORM_VERSION="$latest" \
|
||||
python3 - <<'PY'
|
||||
from pathlib import Path
|
||||
import os
|
||||
|
||||
current = os.environ["CURRENT_TERRAFORM_VERSION"]
|
||||
latest = os.environ["LATEST_TERRAFORM_VERSION"]
|
||||
|
||||
updates = {
|
||||
"scripts/Dockerfile.base": (
|
||||
f"terraform/{current}/",
|
||||
f"terraform/{latest}/",
|
||||
),
|
||||
"provisioner/terraform/install.go": (
|
||||
f'NewVersion("{current}")',
|
||||
f'NewVersion("{latest}")',
|
||||
),
|
||||
"install.sh": (
|
||||
f'TERRAFORM_VERSION="{current}"',
|
||||
f'TERRAFORM_VERSION="{latest}"',
|
||||
),
|
||||
}
|
||||
|
||||
for path_str, (before, after) in updates.items():
|
||||
path = Path(path_str)
|
||||
content = path.read_text()
|
||||
if before not in content:
|
||||
raise SystemExit(f"did not find expected text in {path_str}: {before}")
|
||||
path.write_text(content.replace(before, after))
|
||||
PY
|
||||
|
||||
# Guardrail: only the Terraform-version files may change.
|
||||
extra="$(git diff --name-only | grep -Ev '^(scripts/Dockerfile.base|provisioner/terraform/install.go|install.sh)$' || true)"
|
||||
test -z "$extra" || { echo "Unexpected files changed:"; echo "$extra"; exit 1; }
|
||||
|
||||
- name: Validate Go dependency patch
|
||||
if: matrix.lane == 'gomod'
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
go test ./...
|
||||
|
||||
- name: Validate Terraform patch
|
||||
if: matrix.lane == 'terraform'
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
go test ./provisioner/terraform/...
|
||||
docker build -f scripts/Dockerfile.base .
|
||||
|
||||
- name: Skip PR creation when there are no changes
|
||||
id: changes
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
if git diff --quiet; then
|
||||
echo "has_changes=false" >> "${GITHUB_OUTPUT}"
|
||||
else
|
||||
echo "has_changes=true" >> "${GITHUB_OUTPUT}"
|
||||
fi
|
||||
|
||||
- name: Commit changes
|
||||
if: steps.changes.outputs.has_changes == 'true'
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
|
||||
|
||||
git checkout -B "secpatch/${{ matrix.lane }}"
|
||||
git add -A
|
||||
git commit -m "security: patch ${{ matrix.lane }}"
|
||||
|
||||
- name: Push branch
|
||||
if: steps.changes.outputs.has_changes == 'true'
|
||||
shell: bash
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
git push --force-with-lease \
|
||||
"https://x-access-token:${GH_TOKEN}@github.com/${{ github.repository }}.git" \
|
||||
"HEAD:refs/heads/secpatch/${{ matrix.lane }}"
|
||||
|
||||
- name: Create or update PR
|
||||
if: steps.changes.outputs.has_changes == 'true'
|
||||
shell: bash
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
branch="secpatch/${{ matrix.lane }}"
|
||||
title="security: patch ${{ matrix.lane }}"
|
||||
body="$(cat <<'EOF'
|
||||
Automated security patch PR for `${{ matrix.lane }}`.
|
||||
|
||||
Scope:
|
||||
- gomod: patch-level Go dependency updates only
|
||||
- terraform: bundled Terraform patch updates only
|
||||
|
||||
Guardrails:
|
||||
- no application-code edits
|
||||
- no auto-editing of go.mod replace directives
|
||||
- CI must pass
|
||||
EOF
|
||||
)"
|
||||
|
||||
existing_pr="$(gh pr list --head "${branch}" --base main --json number --jq '.[0].number // empty')"
|
||||
if [[ -n "${existing_pr}" ]]; then
|
||||
gh pr edit "${existing_pr}" \
|
||||
--title "${title}" \
|
||||
--body "${body}"
|
||||
pr_number="${existing_pr}"
|
||||
else
|
||||
gh pr create \
|
||||
--base main \
|
||||
--head "${branch}" \
|
||||
--title "${title}" \
|
||||
--body "${body}"
|
||||
pr_number="$(gh pr list --head "${branch}" --base main --json number --jq '.[0].number // empty')"
|
||||
fi
|
||||
|
||||
for label in security dependencies automated-pr; do
|
||||
if gh label list --json name --jq '.[].name' | grep -Fxq "${label}"; then
|
||||
gh pr edit "${pr_number}" --add-label "${label}"
|
||||
fi
|
||||
done
|
||||
Generated
+6
@@ -14175,6 +14175,9 @@ const docTemplate = `{
|
||||
},
|
||||
"count": {
|
||||
"type": "integer"
|
||||
},
|
||||
"count_cap": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -14496,6 +14499,9 @@ const docTemplate = `{
|
||||
},
|
||||
"count": {
|
||||
"type": "integer"
|
||||
},
|
||||
"count_cap": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
Generated
+6
@@ -12739,6 +12739,9 @@
|
||||
},
|
||||
"count": {
|
||||
"type": "integer"
|
||||
},
|
||||
"count_cap": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -13039,6 +13042,9 @@
|
||||
},
|
||||
"count": {
|
||||
"type": "integer"
|
||||
},
|
||||
"count_cap": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
+8
-1
@@ -26,6 +26,11 @@ import (
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// Limit the count query to avoid a slow sequential scan due to joins
|
||||
// on a large table. Set to 0 to disable capping (but also see the note
|
||||
// in the SQL query).
|
||||
const auditLogCountCap = 2000
|
||||
|
||||
// @Summary Get audit logs
|
||||
// @ID get-audit-logs
|
||||
// @Security CoderSessionToken
|
||||
@@ -66,7 +71,7 @@ func (api *API) auditLogs(rw http.ResponseWriter, r *http.Request) {
|
||||
countFilter.Username = ""
|
||||
}
|
||||
|
||||
// Use the same filters to count the number of audit logs
|
||||
countFilter.CountCap = auditLogCountCap
|
||||
count, err := api.Database.CountAuditLogs(ctx, countFilter)
|
||||
if dbauthz.IsNotAuthorizedError(err) {
|
||||
httpapi.Forbidden(rw)
|
||||
@@ -81,6 +86,7 @@ func (api *API) auditLogs(rw http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.AuditLogResponse{
|
||||
AuditLogs: []codersdk.AuditLog{},
|
||||
Count: 0,
|
||||
CountCap: auditLogCountCap,
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -98,6 +104,7 @@ func (api *API) auditLogs(rw http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.AuditLogResponse{
|
||||
AuditLogs: api.convertAuditLogs(ctx, dblogs),
|
||||
Count: count,
|
||||
CountCap: auditLogCountCap,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -5782,15 +5782,15 @@ func (q *querier) UpdateChatByID(ctx context.Context, arg database.UpdateChatByI
|
||||
return q.db.UpdateChatByID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateChatHeartbeatParams) (int64, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
func (q *querier) UpdateChatHeartbeats(ctx context.Context, arg database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
|
||||
// The batch heartbeat is a system-level operation filtered by
|
||||
// worker_id. Authorization is enforced by the AsChatd context
|
||||
// at the call site rather than per-row, because checking each
|
||||
// row individually would defeat the purpose of batching.
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceChat); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return q.db.UpdateChatHeartbeat(ctx, arg)
|
||||
return q.db.UpdateChatHeartbeats(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatLabelsByID(ctx context.Context, arg database.UpdateChatLabelsByIDParams) (database.Chat, error) {
|
||||
|
||||
@@ -842,15 +842,15 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().UpdateChatStatusPreserveUpdatedAt(gomock.Any(), arg).Return(chat, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat)
|
||||
}))
|
||||
s.Run("UpdateChatHeartbeat", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.UpdateChatHeartbeatParams{
|
||||
ID: chat.ID,
|
||||
s.Run("UpdateChatHeartbeats", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
resultID := uuid.New()
|
||||
arg := database.UpdateChatHeartbeatsParams{
|
||||
IDs: []uuid.UUID{resultID},
|
||||
WorkerID: uuid.New(),
|
||||
Now: time.Now(),
|
||||
}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateChatHeartbeat(gomock.Any(), arg).Return(int64(1), nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(int64(1))
|
||||
dbm.EXPECT().UpdateChatHeartbeats(gomock.Any(), arg).Return([]uuid.UUID{resultID}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceChat, policy.ActionUpdate).Returns([]uuid.UUID{resultID})
|
||||
}))
|
||||
s.Run("UpdateChatMessageByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
|
||||
@@ -4136,11 +4136,11 @@ func (m queryMetricsStore) UpdateChatByID(ctx context.Context, arg database.Upda
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateChatHeartbeatParams) (int64, error) {
|
||||
func (m queryMetricsStore) UpdateChatHeartbeats(ctx context.Context, arg database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatHeartbeat(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatHeartbeat").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatHeartbeat").Inc()
|
||||
r0, r1 := m.s.UpdateChatHeartbeats(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatHeartbeats").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatHeartbeats").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
|
||||
@@ -7835,19 +7835,19 @@ func (mr *MockStoreMockRecorder) UpdateChatByID(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatByID", reflect.TypeOf((*MockStore)(nil).UpdateChatByID), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatHeartbeat mocks base method.
|
||||
func (m *MockStore) UpdateChatHeartbeat(ctx context.Context, arg database.UpdateChatHeartbeatParams) (int64, error) {
|
||||
// UpdateChatHeartbeats mocks base method.
|
||||
func (m *MockStore) UpdateChatHeartbeats(ctx context.Context, arg database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateChatHeartbeat", ctx, arg)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret := m.ctrl.Call(m, "UpdateChatHeartbeats", ctx, arg)
|
||||
ret0, _ := ret[0].([]uuid.UUID)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateChatHeartbeat indicates an expected call of UpdateChatHeartbeat.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatHeartbeat(ctx, arg any) *gomock.Call {
|
||||
// UpdateChatHeartbeats indicates an expected call of UpdateChatHeartbeats.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatHeartbeats(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateChatHeartbeat), ctx, arg)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatHeartbeats", reflect.TypeOf((*MockStore)(nil).UpdateChatHeartbeats), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatLabelsByID mocks base method.
|
||||
|
||||
@@ -584,6 +584,7 @@ func (q *sqlQuerier) CountAuthorizedAuditLogs(ctx context.Context, arg CountAudi
|
||||
arg.DateTo,
|
||||
arg.BuildReason,
|
||||
arg.RequestID,
|
||||
arg.CountCap,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
@@ -720,6 +721,7 @@ func (q *sqlQuerier) CountAuthorizedConnectionLogs(ctx context.Context, arg Coun
|
||||
arg.WorkspaceID,
|
||||
arg.ConnectionID,
|
||||
arg.Status,
|
||||
arg.CountCap,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
|
||||
@@ -145,5 +145,13 @@ func extractWhereClause(query string) string {
|
||||
// Remove SQL comments
|
||||
whereClause = regexp.MustCompile(`(?m)--.*$`).ReplaceAllString(whereClause, "")
|
||||
|
||||
// Normalize indentation so subquery wrapping doesn't cause
|
||||
// mismatches.
|
||||
lines := strings.Split(whereClause, "\n")
|
||||
for i, line := range lines {
|
||||
lines[i] = strings.TrimLeft(line, " \t")
|
||||
}
|
||||
whereClause = strings.Join(lines, "\n")
|
||||
|
||||
return strings.TrimSpace(whereClause)
|
||||
}
|
||||
|
||||
@@ -870,9 +870,11 @@ type sqlcQuerier interface {
|
||||
UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error
|
||||
UpdateChatBuildAgentBinding(ctx context.Context, arg UpdateChatBuildAgentBindingParams) (Chat, error)
|
||||
UpdateChatByID(ctx context.Context, arg UpdateChatByIDParams) (Chat, error)
|
||||
// Bumps the heartbeat timestamp for a running chat so that other
|
||||
// replicas know the worker is still alive.
|
||||
UpdateChatHeartbeat(ctx context.Context, arg UpdateChatHeartbeatParams) (int64, error)
|
||||
// Bumps the heartbeat timestamp for the given set of chat IDs,
|
||||
// provided they are still running and owned by the specified
|
||||
// worker. Returns the IDs that were actually updated so the
|
||||
// caller can detect stolen or completed chats via set-difference.
|
||||
UpdateChatHeartbeats(ctx context.Context, arg UpdateChatHeartbeatsParams) ([]uuid.UUID, error)
|
||||
UpdateChatLabelsByID(ctx context.Context, arg UpdateChatLabelsByIDParams) (Chat, error)
|
||||
// Updates the cached injected context parts (AGENTS.md +
|
||||
// skills) on the chat row. Called only when context changes
|
||||
|
||||
+242
-204
@@ -2275,93 +2275,105 @@ func (q *sqlQuerier) UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDP
|
||||
}
|
||||
|
||||
const countAuditLogs = `-- name: CountAuditLogs :one
|
||||
SELECT COUNT(*)
|
||||
FROM audit_logs
|
||||
LEFT JOIN users ON audit_logs.user_id = users.id
|
||||
LEFT JOIN organizations ON audit_logs.organization_id = organizations.id
|
||||
-- First join on workspaces to get the initial workspace create
|
||||
-- to workspace build 1 id. This is because the first create is
|
||||
-- is a different audit log than subsequent starts.
|
||||
LEFT JOIN workspaces ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.resource_id = workspaces.id
|
||||
-- Get the reason from the build if the resource type
|
||||
-- is a workspace_build
|
||||
LEFT JOIN workspace_builds wb_build ON audit_logs.resource_type = 'workspace_build'
|
||||
AND audit_logs.resource_id = wb_build.id
|
||||
-- Get the reason from the build #1 if this is the first
|
||||
-- workspace create.
|
||||
LEFT JOIN workspace_builds wb_workspace ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.action = 'create'
|
||||
AND workspaces.id = wb_workspace.workspace_id
|
||||
AND wb_workspace.build_number = 1
|
||||
WHERE
|
||||
-- Filter resource_type
|
||||
CASE
|
||||
WHEN $1::text != '' THEN resource_type = $1::resource_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter resource_id
|
||||
AND CASE
|
||||
WHEN $2::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN resource_id = $2
|
||||
ELSE true
|
||||
END
|
||||
-- Filter organization_id
|
||||
AND CASE
|
||||
WHEN $3::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.organization_id = $3
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by resource_target
|
||||
AND CASE
|
||||
WHEN $4::text != '' THEN resource_target = $4
|
||||
ELSE true
|
||||
END
|
||||
-- Filter action
|
||||
AND CASE
|
||||
WHEN $5::text != '' THEN action = $5::audit_action
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN $6::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN user_id = $6
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN $7::text != '' THEN user_id = (
|
||||
SELECT id
|
||||
FROM users
|
||||
WHERE lower(username) = lower($7)
|
||||
AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN $8::text != '' THEN users.email = $8
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_from
|
||||
AND CASE
|
||||
WHEN $9::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" >= $9
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_to
|
||||
AND CASE
|
||||
WHEN $10::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" <= $10
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by build_reason
|
||||
AND CASE
|
||||
WHEN $11::text != '' THEN COALESCE(wb_build.reason::text, wb_workspace.reason::text) = $11
|
||||
ELSE true
|
||||
END
|
||||
-- Filter request_id
|
||||
AND CASE
|
||||
WHEN $12::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.request_id = $12
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in CountAuthorizedAuditLogs
|
||||
-- @authorize_filter
|
||||
SELECT COUNT(*) FROM (
|
||||
SELECT 1
|
||||
FROM audit_logs
|
||||
LEFT JOIN users ON audit_logs.user_id = users.id
|
||||
LEFT JOIN organizations ON audit_logs.organization_id = organizations.id
|
||||
-- First join on workspaces to get the initial workspace create
|
||||
-- to workspace build 1 id. This is because the first create is
|
||||
-- is a different audit log than subsequent starts.
|
||||
LEFT JOIN workspaces ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.resource_id = workspaces.id
|
||||
-- Get the reason from the build if the resource type
|
||||
-- is a workspace_build
|
||||
LEFT JOIN workspace_builds wb_build ON audit_logs.resource_type = 'workspace_build'
|
||||
AND audit_logs.resource_id = wb_build.id
|
||||
-- Get the reason from the build #1 if this is the first
|
||||
-- workspace create.
|
||||
LEFT JOIN workspace_builds wb_workspace ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.action = 'create'
|
||||
AND workspaces.id = wb_workspace.workspace_id
|
||||
AND wb_workspace.build_number = 1
|
||||
WHERE
|
||||
-- Filter resource_type
|
||||
CASE
|
||||
WHEN $1::text != '' THEN resource_type = $1::resource_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter resource_id
|
||||
AND CASE
|
||||
WHEN $2::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN resource_id = $2
|
||||
ELSE true
|
||||
END
|
||||
-- Filter organization_id
|
||||
AND CASE
|
||||
WHEN $3::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.organization_id = $3
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by resource_target
|
||||
AND CASE
|
||||
WHEN $4::text != '' THEN resource_target = $4
|
||||
ELSE true
|
||||
END
|
||||
-- Filter action
|
||||
AND CASE
|
||||
WHEN $5::text != '' THEN action = $5::audit_action
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN $6::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN user_id = $6
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN $7::text != '' THEN user_id = (
|
||||
SELECT id
|
||||
FROM users
|
||||
WHERE lower(username) = lower($7)
|
||||
AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN $8::text != '' THEN users.email = $8
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_from
|
||||
AND CASE
|
||||
WHEN $9::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" >= $9
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_to
|
||||
AND CASE
|
||||
WHEN $10::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" <= $10
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by build_reason
|
||||
AND CASE
|
||||
WHEN $11::text != '' THEN COALESCE(wb_build.reason::text, wb_workspace.reason::text) = $11
|
||||
ELSE true
|
||||
END
|
||||
-- Filter request_id
|
||||
AND CASE
|
||||
WHEN $12::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.request_id = $12
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in CountAuthorizedAuditLogs
|
||||
-- @authorize_filter
|
||||
-- Avoid a slow scan on a large table with joins. The caller
|
||||
-- passes the count cap and we add 1 so the frontend can detect
|
||||
-- capping and show "... of N+". A cap of 0 means no limit (NULLIF
|
||||
-- -> NULL + 1 = NULL).
|
||||
-- NOTE: Parameterizing this so that we can easily change from,
|
||||
-- e.g., 2000 to 5000. However, use literal NULL (or no LIMIT)
|
||||
-- here if disabling the capping on a large table permanently.
|
||||
-- This way the PG planner can plan parallel execution for
|
||||
-- potential large wins.
|
||||
LIMIT NULLIF($13::int, 0) + 1
|
||||
) AS limited_count
|
||||
`
|
||||
|
||||
type CountAuditLogsParams struct {
|
||||
@@ -2377,6 +2389,7 @@ type CountAuditLogsParams struct {
|
||||
DateTo time.Time `db:"date_to" json:"date_to"`
|
||||
BuildReason string `db:"build_reason" json:"build_reason"`
|
||||
RequestID uuid.UUID `db:"request_id" json:"request_id"`
|
||||
CountCap int32 `db:"count_cap" json:"count_cap"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) CountAuditLogs(ctx context.Context, arg CountAuditLogsParams) (int64, error) {
|
||||
@@ -2393,6 +2406,7 @@ func (q *sqlQuerier) CountAuditLogs(ctx context.Context, arg CountAuditLogsParam
|
||||
arg.DateTo,
|
||||
arg.BuildReason,
|
||||
arg.RequestID,
|
||||
arg.CountCap,
|
||||
)
|
||||
var count int64
|
||||
err := row.Scan(&count)
|
||||
@@ -6601,30 +6615,49 @@ func (q *sqlQuerier) UpdateChatByID(ctx context.Context, arg UpdateChatByIDParam
|
||||
return i, err
|
||||
}
|
||||
|
||||
const updateChatHeartbeat = `-- name: UpdateChatHeartbeat :execrows
|
||||
const updateChatHeartbeats = `-- name: UpdateChatHeartbeats :many
|
||||
UPDATE
|
||||
chats
|
||||
SET
|
||||
heartbeat_at = NOW()
|
||||
heartbeat_at = $1::timestamptz
|
||||
WHERE
|
||||
id = $1::uuid
|
||||
AND worker_id = $2::uuid
|
||||
id = ANY($2::uuid[])
|
||||
AND worker_id = $3::uuid
|
||||
AND status = 'running'::chat_status
|
||||
RETURNING id
|
||||
`
|
||||
|
||||
type UpdateChatHeartbeatParams struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
WorkerID uuid.UUID `db:"worker_id" json:"worker_id"`
|
||||
type UpdateChatHeartbeatsParams struct {
|
||||
Now time.Time `db:"now" json:"now"`
|
||||
IDs []uuid.UUID `db:"ids" json:"ids"`
|
||||
WorkerID uuid.UUID `db:"worker_id" json:"worker_id"`
|
||||
}
|
||||
|
||||
// Bumps the heartbeat timestamp for a running chat so that other
|
||||
// replicas know the worker is still alive.
|
||||
func (q *sqlQuerier) UpdateChatHeartbeat(ctx context.Context, arg UpdateChatHeartbeatParams) (int64, error) {
|
||||
result, err := q.db.ExecContext(ctx, updateChatHeartbeat, arg.ID, arg.WorkerID)
|
||||
// Bumps the heartbeat timestamp for the given set of chat IDs,
|
||||
// provided they are still running and owned by the specified
|
||||
// worker. Returns the IDs that were actually updated so the
|
||||
// caller can detect stolen or completed chats via set-difference.
|
||||
func (q *sqlQuerier) UpdateChatHeartbeats(ctx context.Context, arg UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
|
||||
rows, err := q.db.QueryContext(ctx, updateChatHeartbeats, arg.Now, pq.Array(arg.IDs), arg.WorkerID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return nil, err
|
||||
}
|
||||
return result.RowsAffected()
|
||||
defer rows.Close()
|
||||
var items []uuid.UUID
|
||||
for rows.Next() {
|
||||
var id uuid.UUID
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, id)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const updateChatLabelsByID = `-- name: UpdateChatLabelsByID :one
|
||||
@@ -7571,110 +7604,113 @@ func (q *sqlQuerier) BatchUpsertConnectionLogs(ctx context.Context, arg BatchUps
|
||||
}
|
||||
|
||||
const countConnectionLogs = `-- name: CountConnectionLogs :one
|
||||
SELECT
|
||||
COUNT(*) AS count
|
||||
FROM
|
||||
connection_logs
|
||||
JOIN users AS workspace_owner ON
|
||||
connection_logs.workspace_owner_id = workspace_owner.id
|
||||
LEFT JOIN users ON
|
||||
connection_logs.user_id = users.id
|
||||
JOIN organizations ON
|
||||
connection_logs.organization_id = organizations.id
|
||||
WHERE
|
||||
-- Filter organization_id
|
||||
CASE
|
||||
WHEN $1 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.organization_id = $1
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace owner username
|
||||
AND CASE
|
||||
WHEN $2 :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower($2) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_id
|
||||
AND CASE
|
||||
WHEN $3 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
workspace_owner_id = $3
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_email
|
||||
AND CASE
|
||||
WHEN $4 :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE email = $4 AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by type
|
||||
AND CASE
|
||||
WHEN $5 :: text != '' THEN
|
||||
type = $5 :: connection_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN $6 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
user_id = $6
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN $7 :: text != '' THEN
|
||||
user_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower($7) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN $8 :: text != '' THEN
|
||||
users.email = $8
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_after
|
||||
AND CASE
|
||||
WHEN $9 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time >= $9
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_before
|
||||
AND CASE
|
||||
WHEN $10 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time <= $10
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_id
|
||||
AND CASE
|
||||
WHEN $11 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.workspace_id = $11
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connection_id
|
||||
AND CASE
|
||||
WHEN $12 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.connection_id = $12
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by whether the session has a disconnect_time
|
||||
AND CASE
|
||||
WHEN $13 :: text != '' THEN
|
||||
(($13 = 'ongoing' AND disconnect_time IS NULL) OR
|
||||
($13 = 'completed' AND disconnect_time IS NOT NULL)) AND
|
||||
-- Exclude web events, since we don't know their close time.
|
||||
"type" NOT IN ('workspace_app', 'port_forwarding')
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in
|
||||
-- CountAuthorizedConnectionLogs
|
||||
-- @authorize_filter
|
||||
SELECT COUNT(*) AS count FROM (
|
||||
SELECT 1
|
||||
FROM
|
||||
connection_logs
|
||||
JOIN users AS workspace_owner ON
|
||||
connection_logs.workspace_owner_id = workspace_owner.id
|
||||
LEFT JOIN users ON
|
||||
connection_logs.user_id = users.id
|
||||
JOIN organizations ON
|
||||
connection_logs.organization_id = organizations.id
|
||||
WHERE
|
||||
-- Filter organization_id
|
||||
CASE
|
||||
WHEN $1 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.organization_id = $1
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace owner username
|
||||
AND CASE
|
||||
WHEN $2 :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower($2) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_id
|
||||
AND CASE
|
||||
WHEN $3 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
workspace_owner_id = $3
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_email
|
||||
AND CASE
|
||||
WHEN $4 :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE email = $4 AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by type
|
||||
AND CASE
|
||||
WHEN $5 :: text != '' THEN
|
||||
type = $5 :: connection_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN $6 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
user_id = $6
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN $7 :: text != '' THEN
|
||||
user_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower($7) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN $8 :: text != '' THEN
|
||||
users.email = $8
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_after
|
||||
AND CASE
|
||||
WHEN $9 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time >= $9
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_before
|
||||
AND CASE
|
||||
WHEN $10 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time <= $10
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_id
|
||||
AND CASE
|
||||
WHEN $11 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.workspace_id = $11
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connection_id
|
||||
AND CASE
|
||||
WHEN $12 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.connection_id = $12
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by whether the session has a disconnect_time
|
||||
AND CASE
|
||||
WHEN $13 :: text != '' THEN
|
||||
(($13 = 'ongoing' AND disconnect_time IS NULL) OR
|
||||
($13 = 'completed' AND disconnect_time IS NOT NULL)) AND
|
||||
-- Exclude web events, since we don't know their close time.
|
||||
"type" NOT IN ('workspace_app', 'port_forwarding')
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in
|
||||
-- CountAuthorizedConnectionLogs
|
||||
-- @authorize_filter
|
||||
-- NOTE: See the CountAuditLogs LIMIT note.
|
||||
LIMIT NULLIF($14::int, 0) + 1
|
||||
) AS limited_count
|
||||
`
|
||||
|
||||
type CountConnectionLogsParams struct {
|
||||
@@ -7691,6 +7727,7 @@ type CountConnectionLogsParams struct {
|
||||
WorkspaceID uuid.UUID `db:"workspace_id" json:"workspace_id"`
|
||||
ConnectionID uuid.UUID `db:"connection_id" json:"connection_id"`
|
||||
Status string `db:"status" json:"status"`
|
||||
CountCap int32 `db:"count_cap" json:"count_cap"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) CountConnectionLogs(ctx context.Context, arg CountConnectionLogsParams) (int64, error) {
|
||||
@@ -7708,6 +7745,7 @@ func (q *sqlQuerier) CountConnectionLogs(ctx context.Context, arg CountConnectio
|
||||
arg.WorkspaceID,
|
||||
arg.ConnectionID,
|
||||
arg.Status,
|
||||
arg.CountCap,
|
||||
)
|
||||
var count int64
|
||||
err := row.Scan(&count)
|
||||
|
||||
@@ -149,94 +149,105 @@ VALUES (
|
||||
RETURNING *;
|
||||
|
||||
-- name: CountAuditLogs :one
|
||||
SELECT COUNT(*)
|
||||
FROM audit_logs
|
||||
LEFT JOIN users ON audit_logs.user_id = users.id
|
||||
LEFT JOIN organizations ON audit_logs.organization_id = organizations.id
|
||||
-- First join on workspaces to get the initial workspace create
|
||||
-- to workspace build 1 id. This is because the first create is
|
||||
-- is a different audit log than subsequent starts.
|
||||
LEFT JOIN workspaces ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.resource_id = workspaces.id
|
||||
-- Get the reason from the build if the resource type
|
||||
-- is a workspace_build
|
||||
LEFT JOIN workspace_builds wb_build ON audit_logs.resource_type = 'workspace_build'
|
||||
AND audit_logs.resource_id = wb_build.id
|
||||
-- Get the reason from the build #1 if this is the first
|
||||
-- workspace create.
|
||||
LEFT JOIN workspace_builds wb_workspace ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.action = 'create'
|
||||
AND workspaces.id = wb_workspace.workspace_id
|
||||
AND wb_workspace.build_number = 1
|
||||
WHERE
|
||||
-- Filter resource_type
|
||||
CASE
|
||||
WHEN @resource_type::text != '' THEN resource_type = @resource_type::resource_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter resource_id
|
||||
AND CASE
|
||||
WHEN @resource_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN resource_id = @resource_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter organization_id
|
||||
AND CASE
|
||||
WHEN @organization_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.organization_id = @organization_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by resource_target
|
||||
AND CASE
|
||||
WHEN @resource_target::text != '' THEN resource_target = @resource_target
|
||||
ELSE true
|
||||
END
|
||||
-- Filter action
|
||||
AND CASE
|
||||
WHEN @action::text != '' THEN action = @action::audit_action
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN @user_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN user_id = @user_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN @username::text != '' THEN user_id = (
|
||||
SELECT id
|
||||
FROM users
|
||||
WHERE lower(username) = lower(@username)
|
||||
AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN @email::text != '' THEN users.email = @email
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_from
|
||||
AND CASE
|
||||
WHEN @date_from::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" >= @date_from
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_to
|
||||
AND CASE
|
||||
WHEN @date_to::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" <= @date_to
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by build_reason
|
||||
AND CASE
|
||||
WHEN @build_reason::text != '' THEN COALESCE(wb_build.reason::text, wb_workspace.reason::text) = @build_reason
|
||||
ELSE true
|
||||
END
|
||||
-- Filter request_id
|
||||
AND CASE
|
||||
WHEN @request_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.request_id = @request_id
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in CountAuthorizedAuditLogs
|
||||
-- @authorize_filter
|
||||
;
|
||||
SELECT COUNT(*) FROM (
|
||||
SELECT 1
|
||||
FROM audit_logs
|
||||
LEFT JOIN users ON audit_logs.user_id = users.id
|
||||
LEFT JOIN organizations ON audit_logs.organization_id = organizations.id
|
||||
-- First join on workspaces to get the initial workspace create
|
||||
-- to workspace build 1 id. This is because the first create is
|
||||
-- is a different audit log than subsequent starts.
|
||||
LEFT JOIN workspaces ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.resource_id = workspaces.id
|
||||
-- Get the reason from the build if the resource type
|
||||
-- is a workspace_build
|
||||
LEFT JOIN workspace_builds wb_build ON audit_logs.resource_type = 'workspace_build'
|
||||
AND audit_logs.resource_id = wb_build.id
|
||||
-- Get the reason from the build #1 if this is the first
|
||||
-- workspace create.
|
||||
LEFT JOIN workspace_builds wb_workspace ON audit_logs.resource_type = 'workspace'
|
||||
AND audit_logs.action = 'create'
|
||||
AND workspaces.id = wb_workspace.workspace_id
|
||||
AND wb_workspace.build_number = 1
|
||||
WHERE
|
||||
-- Filter resource_type
|
||||
CASE
|
||||
WHEN @resource_type::text != '' THEN resource_type = @resource_type::resource_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter resource_id
|
||||
AND CASE
|
||||
WHEN @resource_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN resource_id = @resource_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter organization_id
|
||||
AND CASE
|
||||
WHEN @organization_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.organization_id = @organization_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by resource_target
|
||||
AND CASE
|
||||
WHEN @resource_target::text != '' THEN resource_target = @resource_target
|
||||
ELSE true
|
||||
END
|
||||
-- Filter action
|
||||
AND CASE
|
||||
WHEN @action::text != '' THEN action = @action::audit_action
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN @user_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN user_id = @user_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN @username::text != '' THEN user_id = (
|
||||
SELECT id
|
||||
FROM users
|
||||
WHERE lower(username) = lower(@username)
|
||||
AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN @email::text != '' THEN users.email = @email
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_from
|
||||
AND CASE
|
||||
WHEN @date_from::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" >= @date_from
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by date_to
|
||||
AND CASE
|
||||
WHEN @date_to::timestamp with time zone != '0001-01-01 00:00:00Z' THEN "time" <= @date_to
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by build_reason
|
||||
AND CASE
|
||||
WHEN @build_reason::text != '' THEN COALESCE(wb_build.reason::text, wb_workspace.reason::text) = @build_reason
|
||||
ELSE true
|
||||
END
|
||||
-- Filter request_id
|
||||
AND CASE
|
||||
WHEN @request_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN audit_logs.request_id = @request_id
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in CountAuthorizedAuditLogs
|
||||
-- @authorize_filter
|
||||
-- Avoid a slow scan on a large table with joins. The caller
|
||||
-- passes the count cap and we add 1 so the frontend can detect
|
||||
-- capping and show "... of N+". A cap of 0 means no limit (NULLIF
|
||||
-- -> NULL + 1 = NULL).
|
||||
-- NOTE: Parameterizing this so that we can easily change from,
|
||||
-- e.g., 2000 to 5000. However, use literal NULL (or no LIMIT)
|
||||
-- here if disabling the capping on a large table permanently.
|
||||
-- This way the PG planner can plan parallel execution for
|
||||
-- potential large wins.
|
||||
LIMIT NULLIF(@count_cap::int, 0) + 1
|
||||
) AS limited_count;
|
||||
|
||||
-- name: DeleteOldAuditLogConnectionEvents :exec
|
||||
DELETE FROM audit_logs
|
||||
|
||||
@@ -674,17 +674,20 @@ WHERE
|
||||
status = 'running'::chat_status
|
||||
AND heartbeat_at < @stale_threshold::timestamptz;
|
||||
|
||||
-- name: UpdateChatHeartbeat :execrows
|
||||
-- Bumps the heartbeat timestamp for a running chat so that other
|
||||
-- replicas know the worker is still alive.
|
||||
-- name: UpdateChatHeartbeats :many
|
||||
-- Bumps the heartbeat timestamp for the given set of chat IDs,
|
||||
-- provided they are still running and owned by the specified
|
||||
-- worker. Returns the IDs that were actually updated so the
|
||||
-- caller can detect stolen or completed chats via set-difference.
|
||||
UPDATE
|
||||
chats
|
||||
SET
|
||||
heartbeat_at = NOW()
|
||||
heartbeat_at = @now::timestamptz
|
||||
WHERE
|
||||
id = @id::uuid
|
||||
id = ANY(@ids::uuid[])
|
||||
AND worker_id = @worker_id::uuid
|
||||
AND status = 'running'::chat_status;
|
||||
AND status = 'running'::chat_status
|
||||
RETURNING id;
|
||||
|
||||
-- name: GetChatDiffStatusByChatID :one
|
||||
SELECT
|
||||
|
||||
@@ -133,111 +133,113 @@ OFFSET
|
||||
@offset_opt;
|
||||
|
||||
-- name: CountConnectionLogs :one
|
||||
SELECT
|
||||
COUNT(*) AS count
|
||||
FROM
|
||||
connection_logs
|
||||
JOIN users AS workspace_owner ON
|
||||
connection_logs.workspace_owner_id = workspace_owner.id
|
||||
LEFT JOIN users ON
|
||||
connection_logs.user_id = users.id
|
||||
JOIN organizations ON
|
||||
connection_logs.organization_id = organizations.id
|
||||
WHERE
|
||||
-- Filter organization_id
|
||||
CASE
|
||||
WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.organization_id = @organization_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace owner username
|
||||
AND CASE
|
||||
WHEN @workspace_owner :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower(@workspace_owner) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_id
|
||||
AND CASE
|
||||
WHEN @workspace_owner_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
workspace_owner_id = @workspace_owner_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_email
|
||||
AND CASE
|
||||
WHEN @workspace_owner_email :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE email = @workspace_owner_email AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by type
|
||||
AND CASE
|
||||
WHEN @type :: text != '' THEN
|
||||
type = @type :: connection_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN @user_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
user_id = @user_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN @username :: text != '' THEN
|
||||
user_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower(@username) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN @user_email :: text != '' THEN
|
||||
users.email = @user_email
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_after
|
||||
AND CASE
|
||||
WHEN @connected_after :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time >= @connected_after
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_before
|
||||
AND CASE
|
||||
WHEN @connected_before :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time <= @connected_before
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_id
|
||||
AND CASE
|
||||
WHEN @workspace_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.workspace_id = @workspace_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connection_id
|
||||
AND CASE
|
||||
WHEN @connection_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.connection_id = @connection_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by whether the session has a disconnect_time
|
||||
AND CASE
|
||||
WHEN @status :: text != '' THEN
|
||||
((@status = 'ongoing' AND disconnect_time IS NULL) OR
|
||||
(@status = 'completed' AND disconnect_time IS NOT NULL)) AND
|
||||
-- Exclude web events, since we don't know their close time.
|
||||
"type" NOT IN ('workspace_app', 'port_forwarding')
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in
|
||||
-- CountAuthorizedConnectionLogs
|
||||
-- @authorize_filter
|
||||
;
|
||||
SELECT COUNT(*) AS count FROM (
|
||||
SELECT 1
|
||||
FROM
|
||||
connection_logs
|
||||
JOIN users AS workspace_owner ON
|
||||
connection_logs.workspace_owner_id = workspace_owner.id
|
||||
LEFT JOIN users ON
|
||||
connection_logs.user_id = users.id
|
||||
JOIN organizations ON
|
||||
connection_logs.organization_id = organizations.id
|
||||
WHERE
|
||||
-- Filter organization_id
|
||||
CASE
|
||||
WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.organization_id = @organization_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace owner username
|
||||
AND CASE
|
||||
WHEN @workspace_owner :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower(@workspace_owner) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_id
|
||||
AND CASE
|
||||
WHEN @workspace_owner_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
workspace_owner_id = @workspace_owner_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_owner_email
|
||||
AND CASE
|
||||
WHEN @workspace_owner_email :: text != '' THEN
|
||||
workspace_owner_id = (
|
||||
SELECT id FROM users
|
||||
WHERE email = @workspace_owner_email AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by type
|
||||
AND CASE
|
||||
WHEN @type :: text != '' THEN
|
||||
type = @type :: connection_type
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_id
|
||||
AND CASE
|
||||
WHEN @user_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
user_id = @user_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by username
|
||||
AND CASE
|
||||
WHEN @username :: text != '' THEN
|
||||
user_id = (
|
||||
SELECT id FROM users
|
||||
WHERE lower(username) = lower(@username) AND deleted = false
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user_email
|
||||
AND CASE
|
||||
WHEN @user_email :: text != '' THEN
|
||||
users.email = @user_email
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_after
|
||||
AND CASE
|
||||
WHEN @connected_after :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time >= @connected_after
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connected_before
|
||||
AND CASE
|
||||
WHEN @connected_before :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN
|
||||
connect_time <= @connected_before
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by workspace_id
|
||||
AND CASE
|
||||
WHEN @workspace_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.workspace_id = @workspace_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by connection_id
|
||||
AND CASE
|
||||
WHEN @connection_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
connection_logs.connection_id = @connection_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by whether the session has a disconnect_time
|
||||
AND CASE
|
||||
WHEN @status :: text != '' THEN
|
||||
((@status = 'ongoing' AND disconnect_time IS NULL) OR
|
||||
(@status = 'completed' AND disconnect_time IS NOT NULL)) AND
|
||||
-- Exclude web events, since we don't know their close time.
|
||||
"type" NOT IN ('workspace_app', 'port_forwarding')
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in
|
||||
-- CountAuthorizedConnectionLogs
|
||||
-- @authorize_filter
|
||||
-- NOTE: See the CountAuditLogs LIMIT note.
|
||||
LIMIT NULLIF(@count_cap::int, 0) + 1
|
||||
) AS limited_count;
|
||||
|
||||
-- name: DeleteOldConnectionLogs :execrows
|
||||
WITH old_logs AS (
|
||||
|
||||
@@ -298,6 +298,40 @@ neq(input.object.owner, "");
|
||||
ExpectedSQL: p("'' = 'org-id'"),
|
||||
VariableConverter: regosql.ChatConverter(),
|
||||
},
|
||||
{
|
||||
Name: "AuditLogUUID",
|
||||
Queries: []string{
|
||||
`"8c0b9bdc-a013-4b14-a49b-5747bc335708" = input.object.org_owner`,
|
||||
`input.object.org_owner != ""`,
|
||||
`neq(input.object.org_owner, "8c0b9bdc-a013-4b14-a49b-5747bc335708")`,
|
||||
`input.object.org_owner in {"8c0b9bdc-a013-4b14-a49b-5747bc335708", "05f58202-4bfc-43ce-9ba4-5ff6e0174a71"}`,
|
||||
`"read" in input.object.acl_group_list[input.object.org_owner]`,
|
||||
},
|
||||
ExpectedSQL: p(
|
||||
p("audit_logs.organization_id = '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " +
|
||||
p("audit_logs.organization_id IS NOT NULL") + " OR " +
|
||||
p("audit_logs.organization_id != '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " +
|
||||
p("audit_logs.organization_id = ANY(ARRAY ['05f58202-4bfc-43ce-9ba4-5ff6e0174a71'::uuid,'8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid])") + " OR " +
|
||||
"(false)"),
|
||||
VariableConverter: regosql.AuditLogConverter(),
|
||||
},
|
||||
{
|
||||
Name: "ConnectionLogUUID",
|
||||
Queries: []string{
|
||||
`"8c0b9bdc-a013-4b14-a49b-5747bc335708" = input.object.org_owner`,
|
||||
`input.object.org_owner != ""`,
|
||||
`neq(input.object.org_owner, "8c0b9bdc-a013-4b14-a49b-5747bc335708")`,
|
||||
`input.object.org_owner in {"8c0b9bdc-a013-4b14-a49b-5747bc335708"}`,
|
||||
`"read" in input.object.acl_group_list[input.object.org_owner]`,
|
||||
},
|
||||
ExpectedSQL: p(
|
||||
p("connection_logs.organization_id = '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " +
|
||||
p("connection_logs.organization_id IS NOT NULL") + " OR " +
|
||||
p("connection_logs.organization_id != '8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid") + " OR " +
|
||||
p("connection_logs.organization_id = ANY(ARRAY ['8c0b9bdc-a013-4b14-a49b-5747bc335708'::uuid])") + " OR " +
|
||||
"(false)"),
|
||||
VariableConverter: regosql.ConnectionLogConverter(),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
|
||||
@@ -53,7 +53,7 @@ func WorkspaceConverter() *sqltypes.VariableConverter {
|
||||
func AuditLogConverter() *sqltypes.VariableConverter {
|
||||
matcher := sqltypes.NewVariableConverter().RegisterMatcher(
|
||||
resourceIDMatcher(),
|
||||
sqltypes.StringVarMatcher("COALESCE(audit_logs.organization_id :: text, '')", []string{"input", "object", "org_owner"}),
|
||||
sqltypes.UUIDVarMatcher("audit_logs.organization_id", []string{"input", "object", "org_owner"}),
|
||||
// Audit logs have no user owner, only owner by an organization.
|
||||
sqltypes.AlwaysFalse(userOwnerMatcher()),
|
||||
)
|
||||
@@ -67,7 +67,7 @@ func AuditLogConverter() *sqltypes.VariableConverter {
|
||||
func ConnectionLogConverter() *sqltypes.VariableConverter {
|
||||
matcher := sqltypes.NewVariableConverter().RegisterMatcher(
|
||||
resourceIDMatcher(),
|
||||
sqltypes.StringVarMatcher("COALESCE(connection_logs.organization_id :: text, '')", []string{"input", "object", "org_owner"}),
|
||||
sqltypes.UUIDVarMatcher("connection_logs.organization_id", []string{"input", "object", "org_owner"}),
|
||||
// Connection logs have no user owner, only owner by an organization.
|
||||
sqltypes.AlwaysFalse(userOwnerMatcher()),
|
||||
)
|
||||
|
||||
@@ -0,0 +1,114 @@
|
||||
package sqltypes
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
var (
|
||||
_ VariableMatcher = astUUIDVar{}
|
||||
_ Node = astUUIDVar{}
|
||||
_ SupportsEquality = astUUIDVar{}
|
||||
)
|
||||
|
||||
// astUUIDVar is a variable that represents a UUID column. Unlike
|
||||
// astStringVar it emits native UUID comparisons (column = 'val'::uuid)
|
||||
// instead of text-based ones (COALESCE(column::text, ”) = 'val').
|
||||
// This allows PostgreSQL to use indexes on UUID columns.
|
||||
type astUUIDVar struct {
|
||||
Source RegoSource
|
||||
FieldPath []string
|
||||
ColumnString string
|
||||
}
|
||||
|
||||
func UUIDVarMatcher(sqlColumn string, regoPath []string) VariableMatcher {
|
||||
return astUUIDVar{FieldPath: regoPath, ColumnString: sqlColumn}
|
||||
}
|
||||
|
||||
func (astUUIDVar) UseAs() Node { return astUUIDVar{} }
|
||||
|
||||
func (u astUUIDVar) ConvertVariable(rego ast.Ref) (Node, bool) {
|
||||
left, err := RegoVarPath(u.FieldPath, rego)
|
||||
if err == nil && len(left) == 0 {
|
||||
return astUUIDVar{
|
||||
Source: RegoSource(rego.String()),
|
||||
FieldPath: u.FieldPath,
|
||||
ColumnString: u.ColumnString,
|
||||
}, true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (u astUUIDVar) SQLString(_ *SQLGenerator) string {
|
||||
return u.ColumnString
|
||||
}
|
||||
|
||||
// EqualsSQLString handles equality comparisons for UUID columns.
|
||||
// Rego always produces string literals, so we accept AstString and
|
||||
// cast the literal to ::uuid in the output SQL. This lets PG use
|
||||
// native UUID indexes instead of falling back to text comparisons.
|
||||
// nolint:revive
|
||||
func (u astUUIDVar) EqualsSQLString(cfg *SQLGenerator, not bool, other Node) (string, error) {
|
||||
switch other.UseAs().(type) {
|
||||
case AstString:
|
||||
// The other side is a rego string literal like
|
||||
// "8c0b9bdc-a013-4b14-a49b-5747bc335708". Emit a comparison
|
||||
// that casts the literal to uuid so PG can use indexes:
|
||||
// column = 'val'::uuid
|
||||
// instead of the text-based:
|
||||
// 'val' = COALESCE(column::text, '')
|
||||
s, ok := other.(AstString)
|
||||
if !ok {
|
||||
return "", xerrors.Errorf("expected AstString, got %T", other)
|
||||
}
|
||||
if s.Value == "" {
|
||||
// Empty string in rego means "no value". Compare the
|
||||
// column against NULL since UUID columns represent
|
||||
// absent values as NULL, not empty strings.
|
||||
op := "IS NULL"
|
||||
if not {
|
||||
op = "IS NOT NULL"
|
||||
}
|
||||
return fmt.Sprintf("%s %s", u.ColumnString, op), nil
|
||||
}
|
||||
return fmt.Sprintf("%s %s '%s'::uuid",
|
||||
u.ColumnString, equalsOp(not), s.Value), nil
|
||||
case astUUIDVar:
|
||||
return basicSQLEquality(cfg, not, u, other), nil
|
||||
default:
|
||||
return "", xerrors.Errorf("unsupported equality: %T %s %T",
|
||||
u, equalsOp(not), other)
|
||||
}
|
||||
}
|
||||
|
||||
// ContainedInSQL implements SupportsContainedIn so that a UUID column
|
||||
// can appear in membership checks like `col = ANY(ARRAY[...])`. The
|
||||
// array elements are rego strings, so we cast each to ::uuid.
|
||||
func (u astUUIDVar) ContainedInSQL(_ *SQLGenerator, haystack Node) (string, error) {
|
||||
arr, ok := haystack.(ASTArray)
|
||||
if !ok {
|
||||
return "", xerrors.Errorf("unsupported containedIn: %T in %T", u, haystack)
|
||||
}
|
||||
|
||||
if len(arr.Value) == 0 {
|
||||
return "false", nil
|
||||
}
|
||||
|
||||
// Build ARRAY['uuid1'::uuid, 'uuid2'::uuid, ...]
|
||||
values := make([]string, 0, len(arr.Value))
|
||||
for _, v := range arr.Value {
|
||||
s, ok := v.(AstString)
|
||||
if !ok {
|
||||
return "", xerrors.Errorf("expected AstString array element, got %T", v)
|
||||
}
|
||||
values = append(values, fmt.Sprintf("'%s'::uuid", s.Value))
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s = ANY(ARRAY [%s])",
|
||||
u.ColumnString,
|
||||
strings.Join(values, ",")), nil
|
||||
}
|
||||
@@ -66,7 +66,7 @@ func AuditLogs(ctx context.Context, db database.Store, query string) (database.G
|
||||
}
|
||||
|
||||
// Prepare the count filter, which uses the same parameters as the GetAuditLogsOffsetParams.
|
||||
// nolint:exhaustruct // UserID is not obtained from the query parameters.
|
||||
// nolint:exhaustruct // UserID and CountCap are not obtained from the query parameters.
|
||||
countFilter := database.CountAuditLogsParams{
|
||||
RequestID: filter.RequestID,
|
||||
ResourceID: filter.ResourceID,
|
||||
@@ -123,6 +123,7 @@ func ConnectionLogs(ctx context.Context, db database.Store, query string, apiKey
|
||||
}
|
||||
|
||||
// This MUST be kept in sync with the above
|
||||
// nolint:exhaustruct // CountCap is not obtained from the query parameters.
|
||||
countFilter := database.CountConnectionLogsParams{
|
||||
OrganizationID: filter.OrganizationID,
|
||||
WorkspaceOwner: filter.WorkspaceOwner,
|
||||
|
||||
+124
-28
@@ -7,6 +7,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strconv"
|
||||
@@ -151,6 +152,12 @@ type Server struct {
|
||||
inFlightChatStaleAfter time.Duration
|
||||
chatHeartbeatInterval time.Duration
|
||||
|
||||
// heartbeatMu guards heartbeatRegistry.
|
||||
heartbeatMu sync.Mutex
|
||||
// heartbeatRegistry maps chat IDs to their cancel functions
|
||||
// and workspace state for the centralized heartbeat loop.
|
||||
heartbeatRegistry map[uuid.UUID]*heartbeatEntry
|
||||
|
||||
// wakeCh is signaled by SendMessage, EditMessage, CreateChat,
|
||||
// and PromoteQueued so the run loop calls processOnce
|
||||
// immediately instead of waiting for the next ticker.
|
||||
@@ -706,6 +713,17 @@ type chatStreamState struct {
|
||||
bufferRetainedAt time.Time
|
||||
}
|
||||
|
||||
// heartbeatEntry tracks a single chat's cancel function and workspace
|
||||
// state for the centralized heartbeat loop. Instead of spawning a
|
||||
// per-chat goroutine, processChat registers an entry here and the
|
||||
// single heartbeatLoop goroutine handles all chats.
|
||||
type heartbeatEntry struct {
|
||||
cancelWithCause context.CancelCauseFunc
|
||||
chatID uuid.UUID
|
||||
workspaceID uuid.NullUUID
|
||||
logger slog.Logger
|
||||
}
|
||||
|
||||
// resetDropCounters zeroes the rate-limiting state for both buffer
|
||||
// and subscriber drop warnings. The caller must hold s.mu.
|
||||
func (s *chatStreamState) resetDropCounters() {
|
||||
@@ -2420,8 +2438,8 @@ func New(cfg Config) *Server {
|
||||
clock: clk,
|
||||
recordingSem: make(chan struct{}, maxConcurrentRecordingUploads),
|
||||
wakeCh: make(chan struct{}, 1),
|
||||
heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
|
||||
}
|
||||
|
||||
//nolint:gocritic // The chat processor uses a scoped chatd context.
|
||||
ctx = dbauthz.AsChatd(ctx)
|
||||
|
||||
@@ -2461,6 +2479,9 @@ func (p *Server) start(ctx context.Context) {
|
||||
// to handle chats orphaned by crashed or redeployed workers.
|
||||
p.recoverStaleChats(ctx)
|
||||
|
||||
// Single heartbeat loop for all chats on this replica.
|
||||
go p.heartbeatLoop(ctx)
|
||||
|
||||
acquireTicker := p.clock.NewTicker(
|
||||
p.pendingChatAcquireInterval,
|
||||
"chatd",
|
||||
@@ -2730,6 +2751,97 @@ func (p *Server) cleanupStreamIfIdle(chatID uuid.UUID, state *chatStreamState) {
|
||||
p.workspaceMCPToolsCache.Delete(chatID)
|
||||
}
|
||||
|
||||
// registerHeartbeat enrolls a chat in the centralized batch
|
||||
// heartbeat loop. Must be called after chatCtx is created.
|
||||
func (p *Server) registerHeartbeat(entry *heartbeatEntry) {
|
||||
p.heartbeatMu.Lock()
|
||||
defer p.heartbeatMu.Unlock()
|
||||
if _, exists := p.heartbeatRegistry[entry.chatID]; exists {
|
||||
p.logger.Warn(context.Background(),
|
||||
"duplicate heartbeat registration, skipping",
|
||||
slog.F("chat_id", entry.chatID))
|
||||
return
|
||||
}
|
||||
p.heartbeatRegistry[entry.chatID] = entry
|
||||
}
|
||||
|
||||
// unregisterHeartbeat removes a chat from the centralized
|
||||
// heartbeat loop when chat processing finishes.
|
||||
func (p *Server) unregisterHeartbeat(chatID uuid.UUID) {
|
||||
p.heartbeatMu.Lock()
|
||||
defer p.heartbeatMu.Unlock()
|
||||
delete(p.heartbeatRegistry, chatID)
|
||||
}
|
||||
|
||||
// heartbeatLoop runs in a single goroutine, issuing one batch
|
||||
// heartbeat query per interval for all registered chats.
|
||||
func (p *Server) heartbeatLoop(ctx context.Context) {
|
||||
ticker := p.clock.NewTicker(p.chatHeartbeatInterval, "chatd", "batch-heartbeat")
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
p.heartbeatTick(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// heartbeatTick issues a single batch UPDATE for all running chats
|
||||
// owned by this worker. Chats missing from the result set are
|
||||
// interrupted (stolen by another replica or already completed).
|
||||
func (p *Server) heartbeatTick(ctx context.Context) {
|
||||
// Snapshot the registry under the lock.
|
||||
p.heartbeatMu.Lock()
|
||||
snapshot := maps.Clone(p.heartbeatRegistry)
|
||||
p.heartbeatMu.Unlock()
|
||||
|
||||
if len(snapshot) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Collect the IDs we believe we own.
|
||||
ids := slices.Collect(maps.Keys(snapshot))
|
||||
|
||||
//nolint:gocritic // AsChatd provides narrowly-scoped daemon
|
||||
// access for batch-updating heartbeats.
|
||||
chatdCtx := dbauthz.AsChatd(ctx)
|
||||
updatedIDs, err := p.db.UpdateChatHeartbeats(chatdCtx, database.UpdateChatHeartbeatsParams{
|
||||
IDs: ids,
|
||||
WorkerID: p.workerID,
|
||||
Now: p.clock.Now(),
|
||||
})
|
||||
if err != nil {
|
||||
p.logger.Error(ctx, "batch heartbeat failed", slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
// Build a set of IDs that were successfully updated.
|
||||
updated := make(map[uuid.UUID]struct{}, len(updatedIDs))
|
||||
for _, id := range updatedIDs {
|
||||
updated[id] = struct{}{}
|
||||
}
|
||||
|
||||
// Interrupt registered chats that were not in the result
|
||||
// (stolen by another replica or already completed).
|
||||
for id, entry := range snapshot {
|
||||
if _, ok := updated[id]; !ok {
|
||||
entry.logger.Warn(ctx, "chat not in batch heartbeat result, interrupting")
|
||||
entry.cancelWithCause(chatloop.ErrInterrupted)
|
||||
continue
|
||||
}
|
||||
// Bump workspace usage for surviving chats.
|
||||
newWsID := p.trackWorkspaceUsage(ctx, entry.chatID, entry.workspaceID, entry.logger)
|
||||
// Update workspace ID in the registry for next tick.
|
||||
p.heartbeatMu.Lock()
|
||||
if current, exists := p.heartbeatRegistry[id]; exists {
|
||||
current.workspaceID = newWsID
|
||||
}
|
||||
p.heartbeatMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Server) Subscribe(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
@@ -3575,33 +3687,17 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) {
|
||||
}
|
||||
}()
|
||||
|
||||
// Periodically update the heartbeat so other replicas know this
|
||||
// worker is still alive. The goroutine stops when chatCtx is
|
||||
// canceled (either by completion or interruption).
|
||||
go func() {
|
||||
ticker := p.clock.NewTicker(p.chatHeartbeatInterval, "chatd", "heartbeat")
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-chatCtx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
rows, err := p.db.UpdateChatHeartbeat(chatCtx, database.UpdateChatHeartbeatParams{
|
||||
ID: chat.ID,
|
||||
WorkerID: p.workerID,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Warn(chatCtx, "failed to update chat heartbeat", slog.Error(err))
|
||||
continue
|
||||
}
|
||||
if rows == 0 {
|
||||
cancel(chatloop.ErrInterrupted)
|
||||
return
|
||||
}
|
||||
chat.WorkspaceID = p.trackWorkspaceUsage(chatCtx, chat.ID, chat.WorkspaceID, logger)
|
||||
}
|
||||
}
|
||||
}()
|
||||
// Register with the centralized heartbeat loop instead of
|
||||
// running a per-chat goroutine. The loop issues a single batch
|
||||
// UPDATE for all chats on this worker and detects stolen chats
|
||||
// via set-difference.
|
||||
p.registerHeartbeat(&heartbeatEntry{
|
||||
cancelWithCause: cancel,
|
||||
chatID: chat.ID,
|
||||
workspaceID: chat.WorkspaceID,
|
||||
logger: logger,
|
||||
})
|
||||
defer p.unregisterHeartbeat(chat.ID)
|
||||
|
||||
// Start buffering stream events BEFORE publishing the running
|
||||
// status. This closes a race where a subscriber sees
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chaterror"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatloop"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chattool"
|
||||
@@ -2071,6 +2072,7 @@ func TestProcessChat_IgnoresStaleControlNotification(t *testing.T) {
|
||||
workerID: workerID,
|
||||
chatHeartbeatInterval: time.Minute,
|
||||
configCache: newChatConfigCache(ctx, db, clock),
|
||||
heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
|
||||
}
|
||||
|
||||
// Publish a stale "pending" notification on the control channel
|
||||
@@ -2133,3 +2135,130 @@ func TestProcessChat_IgnoresStaleControlNotification(t *testing.T) {
|
||||
require.Equal(t, database.ChatStatusError, finalStatus,
|
||||
"processChat should have reached runChat (error), not been interrupted (waiting)")
|
||||
}
|
||||
|
||||
// TestHeartbeatTick_StolenChatIsInterrupted verifies that when the
|
||||
// batch heartbeat UPDATE does not return a registered chat's ID
|
||||
// (because another replica stole it or it was completed), the
|
||||
// heartbeat tick cancels that chat's context with ErrInterrupted
|
||||
// while leaving surviving chats untouched.
|
||||
func TestHeartbeatTick_StolenChatIsInterrupted(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
clock := quartz.NewMock(t)
|
||||
|
||||
workerID := uuid.New()
|
||||
|
||||
server := &Server{
|
||||
db: db,
|
||||
logger: logger,
|
||||
clock: clock,
|
||||
workerID: workerID,
|
||||
chatHeartbeatInterval: time.Minute,
|
||||
heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
|
||||
}
|
||||
|
||||
// Create three chats with independent cancel functions.
|
||||
chat1 := uuid.New()
|
||||
chat2 := uuid.New()
|
||||
chat3 := uuid.New()
|
||||
|
||||
_, cancel1 := context.WithCancelCause(ctx)
|
||||
_, cancel2 := context.WithCancelCause(ctx)
|
||||
ctx3, cancel3 := context.WithCancelCause(ctx)
|
||||
|
||||
server.registerHeartbeat(&heartbeatEntry{
|
||||
cancelWithCause: cancel1,
|
||||
chatID: chat1,
|
||||
logger: logger,
|
||||
})
|
||||
server.registerHeartbeat(&heartbeatEntry{
|
||||
cancelWithCause: cancel2,
|
||||
chatID: chat2,
|
||||
logger: logger,
|
||||
})
|
||||
server.registerHeartbeat(&heartbeatEntry{
|
||||
cancelWithCause: cancel3,
|
||||
chatID: chat3,
|
||||
logger: logger,
|
||||
})
|
||||
|
||||
// The batch UPDATE returns only chat1 and chat2 —
|
||||
// chat3 was "stolen" by another replica.
|
||||
db.EXPECT().UpdateChatHeartbeats(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(_ context.Context, params database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) {
|
||||
require.Equal(t, workerID, params.WorkerID)
|
||||
require.Len(t, params.IDs, 3)
|
||||
// Return only chat1 and chat2 as surviving.
|
||||
return []uuid.UUID{chat1, chat2}, nil
|
||||
},
|
||||
)
|
||||
|
||||
server.heartbeatTick(ctx)
|
||||
|
||||
// chat3's context should be canceled with ErrInterrupted.
|
||||
require.ErrorIs(t, context.Cause(ctx3), chatloop.ErrInterrupted,
|
||||
"stolen chat should be interrupted")
|
||||
|
||||
// chat3 should have been removed from the registry by
|
||||
// unregister (in production this happens via defer in
|
||||
// processChat). The heartbeat tick itself does not
|
||||
// unregister — it only cancels. Verify the entry is
|
||||
// still present (processChat's defer would clean it up).
|
||||
server.heartbeatMu.Lock()
|
||||
_, chat1Exists := server.heartbeatRegistry[chat1]
|
||||
_, chat2Exists := server.heartbeatRegistry[chat2]
|
||||
_, chat3Exists := server.heartbeatRegistry[chat3]
|
||||
server.heartbeatMu.Unlock()
|
||||
|
||||
require.True(t, chat1Exists, "surviving chat1 should remain registered")
|
||||
require.True(t, chat2Exists, "surviving chat2 should remain registered")
|
||||
require.True(t, chat3Exists,
|
||||
"stolen chat3 should still be in registry (processChat defer removes it)")
|
||||
}
|
||||
|
||||
// TestHeartbeatTick_DBErrorDoesNotInterruptChats verifies that a
|
||||
// transient database failure causes the tick to log and return
|
||||
// without canceling any registered chats.
|
||||
func TestHeartbeatTick_DBErrorDoesNotInterruptChats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
clock := quartz.NewMock(t)
|
||||
|
||||
server := &Server{
|
||||
db: db,
|
||||
logger: logger,
|
||||
clock: clock,
|
||||
workerID: uuid.New(),
|
||||
chatHeartbeatInterval: time.Minute,
|
||||
heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry),
|
||||
}
|
||||
|
||||
chatID := uuid.New()
|
||||
chatCtx, cancel := context.WithCancelCause(ctx)
|
||||
|
||||
server.registerHeartbeat(&heartbeatEntry{
|
||||
cancelWithCause: cancel,
|
||||
chatID: chatID,
|
||||
logger: logger,
|
||||
})
|
||||
|
||||
// Simulate a transient DB error.
|
||||
db.EXPECT().UpdateChatHeartbeats(gomock.Any(), gomock.Any()).Return(
|
||||
nil, xerrors.New("connection reset"),
|
||||
)
|
||||
|
||||
server.heartbeatTick(ctx)
|
||||
|
||||
// Chat should NOT be interrupted — the tick logged and
|
||||
// returned early.
|
||||
require.NoError(t, chatCtx.Err(),
|
||||
"chat context should not be canceled on transient DB error")
|
||||
}
|
||||
|
||||
@@ -474,7 +474,7 @@ func TestArchiveChatInterruptsActiveProcessing(t *testing.T) {
|
||||
require.Equal(t, 1, userMessages, "expected queued message to stay queued after archive")
|
||||
}
|
||||
|
||||
func TestUpdateChatHeartbeatRequiresOwnership(t *testing.T) {
|
||||
func TestUpdateChatHeartbeatsRequiresOwnership(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
@@ -501,19 +501,24 @@ func TestUpdateChatHeartbeatRequiresOwnership(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
rows, err := db.UpdateChatHeartbeat(ctx, database.UpdateChatHeartbeatParams{
|
||||
ID: chat.ID,
|
||||
// Wrong worker_id should return no IDs.
|
||||
ids, err := db.UpdateChatHeartbeats(ctx, database.UpdateChatHeartbeatsParams{
|
||||
IDs: []uuid.UUID{chat.ID},
|
||||
WorkerID: uuid.New(),
|
||||
Now: time.Now(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(0), rows)
|
||||
require.Empty(t, ids)
|
||||
|
||||
rows, err = db.UpdateChatHeartbeat(ctx, database.UpdateChatHeartbeatParams{
|
||||
ID: chat.ID,
|
||||
// Correct worker_id should return the chat's ID.
|
||||
ids, err = db.UpdateChatHeartbeats(ctx, database.UpdateChatHeartbeatsParams{
|
||||
IDs: []uuid.UUID{chat.ID},
|
||||
WorkerID: workerID,
|
||||
Now: time.Now(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), rows)
|
||||
require.Len(t, ids, 1)
|
||||
require.Equal(t, chat.ID, ids[0])
|
||||
}
|
||||
|
||||
func TestSendMessageQueueBehaviorQueuesWhenBusy(t *testing.T) {
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -49,10 +50,11 @@ const connectTimeout = 10 * time.Second
|
||||
const toolCallTimeout = 60 * time.Second
|
||||
|
||||
// ConnectAll connects to all configured MCP servers, discovers
|
||||
// their tools, and returns them as fantasy.AgentTool values. It
|
||||
// skips servers that fail to connect and logs warnings. The
|
||||
// returned cleanup function must be called to close all
|
||||
// connections.
|
||||
// their tools, and returns them as fantasy.AgentTool values.
|
||||
// Tools are sorted by their prefixed name so callers
|
||||
// receive a deterministic order. It skips servers that fail to
|
||||
// connect and logs warnings. The returned cleanup function
|
||||
// must be called to close all connections.
|
||||
func ConnectAll(
|
||||
ctx context.Context,
|
||||
logger slog.Logger,
|
||||
@@ -108,7 +110,9 @@ func ConnectAll(
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
clients = append(clients, mcpClient)
|
||||
if mcpClient != nil {
|
||||
clients = append(clients, mcpClient)
|
||||
}
|
||||
tools = append(tools, serverTools...)
|
||||
mu.Unlock()
|
||||
return nil
|
||||
@@ -119,6 +123,31 @@ func ConnectAll(
|
||||
// discarded.
|
||||
_ = eg.Wait()
|
||||
|
||||
// Sort tools by prefixed name for deterministic ordering
|
||||
// regardless of goroutine completion order. Ties, possible
|
||||
// when the __ separator produces ambiguous prefixed names,
|
||||
// are broken by config ID. Stable prompt construction
|
||||
// depends on consistent tool ordering.
|
||||
slices.SortFunc(tools, func(a, b fantasy.AgentTool) int {
|
||||
// All tools in this slice are mcpToolWrapper values
|
||||
// created by connectOne above, so these checked
|
||||
// assertions should always succeed. The config ID
|
||||
// tiebreaker resolves the __ separator ambiguity
|
||||
// documented at the top of this file.
|
||||
aTool, ok := a.(MCPToolIdentifier)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("unexpected tool type %T", a))
|
||||
}
|
||||
bTool, ok := b.(MCPToolIdentifier)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("unexpected tool type %T", b))
|
||||
}
|
||||
return cmp.Or(
|
||||
cmp.Compare(a.Info().Name, b.Info().Name),
|
||||
cmp.Compare(aTool.MCPServerConfigID().String(), bTool.MCPServerConfigID().String()),
|
||||
)
|
||||
})
|
||||
|
||||
return tools, cleanup
|
||||
}
|
||||
|
||||
|
||||
@@ -63,6 +63,17 @@ func greetTool() mcpserver.ServerTool {
|
||||
}
|
||||
}
|
||||
|
||||
// makeTool returns a ServerTool with the given name and a
|
||||
// no-op handler that always returns "ok".
|
||||
func makeTool(name string) mcpserver.ServerTool {
|
||||
return mcpserver.ServerTool{
|
||||
Tool: mcp.NewTool(name),
|
||||
Handler: func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
return mcp.NewToolResultText("ok"), nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// makeConfig builds a database.MCPServerConfig suitable for tests.
|
||||
func makeConfig(slug, url string) database.MCPServerConfig {
|
||||
return database.MCPServerConfig{
|
||||
@@ -198,6 +209,121 @@ func TestConnectAll_MultipleServers(t *testing.T) {
|
||||
assert.Contains(t, names, "beta__greet")
|
||||
}
|
||||
|
||||
func TestConnectAll_NoToolsAfterFiltering(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
|
||||
ts := newTestMCPServer(t, echoTool())
|
||||
|
||||
cfg := makeConfig("filtered", ts.URL)
|
||||
cfg.ToolAllowList = []string{"greet"}
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(
|
||||
ctx,
|
||||
logger,
|
||||
[]database.MCPServerConfig{cfg},
|
||||
nil,
|
||||
)
|
||||
|
||||
require.Empty(t, tools)
|
||||
assert.NotPanics(t, cleanup)
|
||||
}
|
||||
|
||||
func TestConnectAll_DeterministicOrder(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("AcrossServers", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
|
||||
ts1 := newTestMCPServer(t, makeTool("zebra"))
|
||||
ts2 := newTestMCPServer(t, makeTool("alpha"))
|
||||
ts3 := newTestMCPServer(t, makeTool("middle"))
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(
|
||||
ctx,
|
||||
logger,
|
||||
[]database.MCPServerConfig{
|
||||
makeConfig("srv3", ts3.URL),
|
||||
makeConfig("srv1", ts1.URL),
|
||||
makeConfig("srv2", ts2.URL),
|
||||
},
|
||||
nil,
|
||||
)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
require.Len(t, tools, 3)
|
||||
// Sorted by full prefixed name (slug__tool), so slug
|
||||
// order determines the sequence, not the tool name.
|
||||
assert.Equal(t,
|
||||
[]string{"srv1__zebra", "srv2__alpha", "srv3__middle"},
|
||||
toolNames(tools),
|
||||
)
|
||||
})
|
||||
|
||||
t.Run("WithMultiToolServer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
|
||||
multi := newTestMCPServer(t, makeTool("zeta"), makeTool("beta"))
|
||||
other := newTestMCPServer(t, makeTool("gamma"))
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(
|
||||
ctx,
|
||||
logger,
|
||||
[]database.MCPServerConfig{
|
||||
makeConfig("zzz", multi.URL),
|
||||
makeConfig("aaa", other.URL),
|
||||
},
|
||||
nil,
|
||||
)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
require.Len(t, tools, 3)
|
||||
assert.Equal(t,
|
||||
[]string{"aaa__gamma", "zzz__beta", "zzz__zeta"},
|
||||
toolNames(tools),
|
||||
)
|
||||
})
|
||||
|
||||
t.Run("TiebreakByConfigID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
|
||||
ts1 := newTestMCPServer(t, makeTool("b__z"))
|
||||
ts2 := newTestMCPServer(t, makeTool("z"))
|
||||
|
||||
// Use fixed UUIDs so the tiebreaker order is
|
||||
// predictable. Both servers produce the same prefixed
|
||||
// name, a__b__z, due to the __ separator ambiguity.
|
||||
cfg1 := makeConfig("a", ts1.URL)
|
||||
cfg1.ID = uuid.MustParse("00000000-0000-0000-0000-000000000002")
|
||||
|
||||
cfg2 := makeConfig("a__b", ts2.URL)
|
||||
cfg2.ID = uuid.MustParse("00000000-0000-0000-0000-000000000001")
|
||||
|
||||
tools, cleanup := mcpclient.ConnectAll(
|
||||
ctx,
|
||||
logger,
|
||||
[]database.MCPServerConfig{cfg1, cfg2},
|
||||
nil,
|
||||
)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
require.Len(t, tools, 2)
|
||||
assert.Equal(t, []string{"a__b__z", "a__b__z"}, toolNames(tools))
|
||||
|
||||
id0 := tools[0].(mcpclient.MCPToolIdentifier).MCPServerConfigID()
|
||||
id1 := tools[1].(mcpclient.MCPToolIdentifier).MCPServerConfigID()
|
||||
assert.Equal(t, cfg2.ID, id0, "lower config ID should sort first")
|
||||
assert.Equal(t, cfg1.ID, id1, "higher config ID should sort second")
|
||||
})
|
||||
}
|
||||
|
||||
func TestConnectAll_AuthHeaders(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -212,6 +212,7 @@ type AuditLogsRequest struct {
|
||||
type AuditLogResponse struct {
|
||||
AuditLogs []AuditLog `json:"audit_logs"`
|
||||
Count int64 `json:"count"`
|
||||
CountCap int64 `json:"count_cap"`
|
||||
}
|
||||
|
||||
type CreateTestAuditLogRequest struct {
|
||||
|
||||
@@ -96,6 +96,7 @@ type ConnectionLogsRequest struct {
|
||||
type ConnectionLogResponse struct {
|
||||
ConnectionLogs []ConnectionLog `json:"connection_logs"`
|
||||
Count int64 `json:"count"`
|
||||
CountCap int64 `json:"count_cap"`
|
||||
}
|
||||
|
||||
func (c *Client) ConnectionLogs(ctx context.Context, req ConnectionLogsRequest) (ConnectionLogResponse, error) {
|
||||
|
||||
Generated
+2
-1
@@ -90,7 +90,8 @@ curl -X GET http://coder-server:8080/api/v2/audit?limit=0 \
|
||||
"user_agent": "string"
|
||||
}
|
||||
],
|
||||
"count": 0
|
||||
"count": 0,
|
||||
"count_cap": 0
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
Generated
+2
-1
@@ -291,7 +291,8 @@ curl -X GET http://coder-server:8080/api/v2/connectionlog?limit=0 \
|
||||
"workspace_owner_username": "string"
|
||||
}
|
||||
],
|
||||
"count": 0
|
||||
"count": 0,
|
||||
"count_cap": 0
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
Generated
+6
-2
@@ -1740,7 +1740,8 @@
|
||||
"user_agent": "string"
|
||||
}
|
||||
],
|
||||
"count": 0
|
||||
"count": 0,
|
||||
"count_cap": 0
|
||||
}
|
||||
```
|
||||
|
||||
@@ -1750,6 +1751,7 @@
|
||||
|--------------|-------------------------------------------------|----------|--------------|-------------|
|
||||
| `audit_logs` | array of [codersdk.AuditLog](#codersdkauditlog) | false | | |
|
||||
| `count` | integer | false | | |
|
||||
| `count_cap` | integer | false | | |
|
||||
|
||||
## codersdk.AuthMethod
|
||||
|
||||
@@ -2173,7 +2175,8 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in
|
||||
"workspace_owner_username": "string"
|
||||
}
|
||||
],
|
||||
"count": 0
|
||||
"count": 0,
|
||||
"count_cap": 0
|
||||
}
|
||||
```
|
||||
|
||||
@@ -2183,6 +2186,7 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in
|
||||
|-------------------|-----------------------------------------------------------|----------|--------------|-------------|
|
||||
| `connection_logs` | array of [codersdk.ConnectionLog](#codersdkconnectionlog) | false | | |
|
||||
| `count` | integer | false | | |
|
||||
| `count_cap` | integer | false | | |
|
||||
|
||||
## codersdk.ConnectionLogSSHInfo
|
||||
|
||||
|
||||
@@ -16,6 +16,9 @@ import (
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// NOTE: See the auditLogCountCap note.
|
||||
const connectionLogCountCap = 2000
|
||||
|
||||
// @Summary Get connection logs
|
||||
// @ID get-connection-logs
|
||||
// @Security CoderSessionToken
|
||||
@@ -49,6 +52,7 @@ func (api *API) connectionLogs(rw http.ResponseWriter, r *http.Request) {
|
||||
// #nosec G115 - Safe conversion as pagination limit is expected to be within int32 range
|
||||
filter.LimitOpt = int32(page.Limit)
|
||||
|
||||
countFilter.CountCap = connectionLogCountCap
|
||||
count, err := api.Database.CountConnectionLogs(ctx, countFilter)
|
||||
if dbauthz.IsNotAuthorizedError(err) {
|
||||
httpapi.Forbidden(rw)
|
||||
@@ -63,6 +67,7 @@ func (api *API) connectionLogs(rw http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ConnectionLogResponse{
|
||||
ConnectionLogs: []codersdk.ConnectionLog{},
|
||||
Count: 0,
|
||||
CountCap: connectionLogCountCap,
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -80,6 +85,7 @@ func (api *API) connectionLogs(rw http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ConnectionLogResponse{
|
||||
ConnectionLogs: convertConnectionLogs(dblogs),
|
||||
Count: count,
|
||||
CountCap: connectionLogCountCap,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
Generated
+2
@@ -913,6 +913,7 @@ export interface AuditLog {
|
||||
export interface AuditLogResponse {
|
||||
readonly audit_logs: readonly AuditLog[];
|
||||
readonly count: number;
|
||||
readonly count_cap: number;
|
||||
}
|
||||
|
||||
// From codersdk/audit.go
|
||||
@@ -2269,6 +2270,7 @@ export interface ConnectionLog {
|
||||
export interface ConnectionLogResponse {
|
||||
readonly connection_logs: readonly ConnectionLog[];
|
||||
readonly count: number;
|
||||
readonly count_cap: number;
|
||||
}
|
||||
|
||||
// From codersdk/connectionlog.go
|
||||
|
||||
@@ -35,10 +35,10 @@ export const AvatarData: FC<AvatarDataProps> = ({
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex items-center w-full min-w-0 gap-3">
|
||||
<div className="flex items-center w-full gap-3">
|
||||
{avatar}
|
||||
|
||||
<div className="flex flex-col w-full min-w-0">
|
||||
<div className="flex flex-col w-full">
|
||||
<span className="text-sm font-semibold text-content-primary">
|
||||
{title}
|
||||
</span>
|
||||
|
||||
@@ -7,6 +7,7 @@ type PaginationHeaderProps = {
|
||||
limit: number;
|
||||
totalRecords: number | undefined;
|
||||
currentOffsetStart: number | undefined;
|
||||
countIsCapped?: boolean;
|
||||
|
||||
// Temporary escape hatch until Workspaces can be switched over to using
|
||||
// PaginationContainer
|
||||
@@ -18,6 +19,7 @@ export const PaginationAmount: FC<PaginationHeaderProps> = ({
|
||||
limit,
|
||||
totalRecords,
|
||||
currentOffsetStart,
|
||||
countIsCapped,
|
||||
className,
|
||||
}) => {
|
||||
const theme = useTheme();
|
||||
@@ -52,10 +54,16 @@ export const PaginationAmount: FC<PaginationHeaderProps> = ({
|
||||
<strong>
|
||||
{(
|
||||
currentOffsetStart +
|
||||
Math.min(limit - 1, totalRecords - currentOffsetStart)
|
||||
(countIsCapped
|
||||
? limit - 1
|
||||
: Math.min(limit - 1, totalRecords - currentOffsetStart))
|
||||
).toLocaleString()}
|
||||
</strong>{" "}
|
||||
of <strong>{totalRecords.toLocaleString()}</strong>{" "}
|
||||
of{" "}
|
||||
<strong>
|
||||
{totalRecords.toLocaleString()}
|
||||
{countIsCapped && "+"}
|
||||
</strong>{" "}
|
||||
{paginationUnitLabel}
|
||||
</div>
|
||||
)}
|
||||
|
||||
@@ -18,6 +18,7 @@ export const mockPaginationResultBase: ResultBase = {
|
||||
limit: 25,
|
||||
hasNextPage: false,
|
||||
hasPreviousPage: false,
|
||||
countIsCapped: false,
|
||||
goToPreviousPage: () => {},
|
||||
goToNextPage: () => {},
|
||||
goToFirstPage: () => {},
|
||||
@@ -33,6 +34,7 @@ export const mockInitialRenderResult: PaginationResult = {
|
||||
hasPreviousPage: false,
|
||||
totalRecords: undefined,
|
||||
totalPages: undefined,
|
||||
countIsCapped: false,
|
||||
};
|
||||
|
||||
export const mockSuccessResult: PaginationResult = {
|
||||
|
||||
@@ -94,7 +94,7 @@ export const FirstPageWithTonsOfData: Story = {
|
||||
currentPage: 2,
|
||||
currentOffsetStart: 1000,
|
||||
totalRecords: 123_456,
|
||||
totalPages: 1235,
|
||||
totalPages: 4939,
|
||||
hasPreviousPage: false,
|
||||
hasNextPage: true,
|
||||
isPlaceholderData: false,
|
||||
@@ -135,3 +135,54 @@ export const SecondPageWithData: Story = {
|
||||
children: <div>New data for page 2</div>,
|
||||
},
|
||||
};
|
||||
|
||||
export const CappedCountFirstPage: Story = {
|
||||
args: {
|
||||
query: {
|
||||
...mockPaginationResultBase,
|
||||
isSuccess: true,
|
||||
currentPage: 1,
|
||||
currentOffsetStart: 1,
|
||||
totalRecords: 2000,
|
||||
totalPages: 80,
|
||||
hasPreviousPage: false,
|
||||
hasNextPage: true,
|
||||
isPlaceholderData: false,
|
||||
countIsCapped: true,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
export const CappedCountMiddlePage: Story = {
|
||||
args: {
|
||||
query: {
|
||||
...mockPaginationResultBase,
|
||||
isSuccess: true,
|
||||
currentPage: 3,
|
||||
currentOffsetStart: 51,
|
||||
totalRecords: 2000,
|
||||
totalPages: 80,
|
||||
hasPreviousPage: true,
|
||||
hasNextPage: true,
|
||||
isPlaceholderData: false,
|
||||
countIsCapped: true,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
export const CappedCountBeyondKnownPages: Story = {
|
||||
args: {
|
||||
query: {
|
||||
...mockPaginationResultBase,
|
||||
isSuccess: true,
|
||||
currentPage: 85,
|
||||
currentOffsetStart: 2101,
|
||||
totalRecords: 2000,
|
||||
totalPages: 85,
|
||||
hasPreviousPage: true,
|
||||
hasNextPage: true,
|
||||
isPlaceholderData: false,
|
||||
countIsCapped: true,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
@@ -27,12 +27,14 @@ export const PaginationContainer: FC<PaginationProps> = ({
|
||||
totalRecords={query.totalRecords}
|
||||
currentOffsetStart={query.currentOffsetStart}
|
||||
paginationUnitLabel={paginationUnitLabel}
|
||||
countIsCapped={query.countIsCapped}
|
||||
className="justify-end"
|
||||
/>
|
||||
|
||||
{query.isSuccess && (
|
||||
<PaginationWidgetBase
|
||||
totalRecords={query.totalRecords}
|
||||
totalPages={query.totalPages}
|
||||
currentPage={query.currentPage}
|
||||
pageSize={query.limit}
|
||||
onPageChange={query.onPageChange}
|
||||
|
||||
@@ -12,6 +12,10 @@ export type PaginationWidgetBaseProps = {
|
||||
|
||||
hasPreviousPage?: boolean;
|
||||
hasNextPage?: boolean;
|
||||
/** Override the computed totalPages.
|
||||
* Used when, e.g., the row count is capped and the user navigates beyond
|
||||
* the known range, so totalPages stays at least as high as currentPage. */
|
||||
totalPages?: number;
|
||||
};
|
||||
|
||||
export const PaginationWidgetBase: FC<PaginationWidgetBaseProps> = ({
|
||||
@@ -21,8 +25,9 @@ export const PaginationWidgetBase: FC<PaginationWidgetBaseProps> = ({
|
||||
onPageChange,
|
||||
hasPreviousPage,
|
||||
hasNextPage,
|
||||
totalPages: totalPagesProp,
|
||||
}) => {
|
||||
const totalPages = Math.ceil(totalRecords / pageSize);
|
||||
const totalPages = totalPagesProp ?? Math.ceil(totalRecords / pageSize);
|
||||
|
||||
if (totalPages < 2) {
|
||||
return null;
|
||||
|
||||
@@ -258,6 +258,78 @@ describe(usePaginatedQuery.name, () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe("Capped count behavior", () => {
|
||||
const mockQueryKey = vi.fn(() => ["mock"]);
|
||||
|
||||
// Returns count 2001 (capped) with items on pages up to page 84
|
||||
// (84 * 25 = 2100 items total).
|
||||
const mockCappedQueryFn = vi.fn(({ pageNumber, limit }) => {
|
||||
const totalItems = 2100;
|
||||
const offset = (pageNumber - 1) * limit;
|
||||
// Returns 0 items when the requested page is past the end, simulating
|
||||
// an empty server response.
|
||||
const itemsOnPage = Math.max(0, Math.min(limit, totalItems - offset));
|
||||
return Promise.resolve({
|
||||
data: new Array(itemsOnPage).fill(pageNumber),
|
||||
count: 2001,
|
||||
count_cap: 2000,
|
||||
});
|
||||
});
|
||||
|
||||
it("Caps totalRecords at 2000 when count exceeds cap", async () => {
|
||||
const { result } = await render({
|
||||
queryKey: mockQueryKey,
|
||||
queryFn: mockCappedQueryFn,
|
||||
});
|
||||
|
||||
await waitFor(() => expect(result.current.isSuccess).toBe(true));
|
||||
expect(result.current.totalRecords).toBe(2000);
|
||||
});
|
||||
|
||||
it("hasNextPage is true when count is capped", async () => {
|
||||
const { result } = await render(
|
||||
{ queryKey: mockQueryKey, queryFn: mockCappedQueryFn },
|
||||
"/?page=80",
|
||||
);
|
||||
|
||||
await waitFor(() => expect(result.current.isSuccess).toBe(true));
|
||||
expect(result.current.hasNextPage).toBe(true);
|
||||
});
|
||||
|
||||
it("hasPreviousPage is true when count is capped and page is beyond cap", async () => {
|
||||
const { result } = await render(
|
||||
{ queryKey: mockQueryKey, queryFn: mockCappedQueryFn },
|
||||
"/?page=83",
|
||||
);
|
||||
|
||||
await waitFor(() => expect(result.current.isSuccess).toBe(true));
|
||||
expect(result.current.hasPreviousPage).toBe(true);
|
||||
});
|
||||
|
||||
it("Does not redirect to last page when count is capped and page is valid", async () => {
|
||||
const { result } = await render(
|
||||
{ queryKey: mockQueryKey, queryFn: mockCappedQueryFn },
|
||||
"/?page=83",
|
||||
);
|
||||
|
||||
await waitFor(() => expect(result.current.isSuccess).toBe(true));
|
||||
// Should stay on page 83 — not redirect to page 80.
|
||||
expect(result.current.currentPage).toBe(83);
|
||||
});
|
||||
|
||||
it("Redirects to last known page when navigating beyond actual data", async () => {
|
||||
const { result } = await render(
|
||||
{ queryKey: mockQueryKey, queryFn: mockCappedQueryFn },
|
||||
"/?page=999",
|
||||
);
|
||||
|
||||
// Page 999 has no items. Should redirect to page 81
|
||||
// (ceil(2001 / 25) = 81), the last page guaranteed to
|
||||
// have data.
|
||||
await waitFor(() => expect(result.current.currentPage).toBe(81));
|
||||
});
|
||||
});
|
||||
|
||||
describe("Passing in searchParams property", () => {
|
||||
const mockQueryKey = vi.fn(() => ["mock"]);
|
||||
const mockQueryFn = vi.fn(({ pageNumber, limit }) =>
|
||||
|
||||
@@ -144,16 +144,44 @@ export function usePaginatedQuery<
|
||||
placeholderData: keepPreviousData,
|
||||
});
|
||||
|
||||
const totalRecords = query.data?.count;
|
||||
const totalPages =
|
||||
totalRecords !== undefined ? Math.ceil(totalRecords / limit) : undefined;
|
||||
const count = query.data?.count;
|
||||
const countCap = query.data?.count_cap;
|
||||
const countIsCapped =
|
||||
countCap !== undefined &&
|
||||
countCap > 0 &&
|
||||
count !== undefined &&
|
||||
count > countCap;
|
||||
const totalRecords = countIsCapped ? countCap : count;
|
||||
let totalPages =
|
||||
totalRecords !== undefined
|
||||
? Math.max(
|
||||
Math.ceil(totalRecords / limit),
|
||||
// True count is not known; let them navigate forward
|
||||
// until they hit an empty page (checked below).
|
||||
countIsCapped ? currentPage : 0,
|
||||
)
|
||||
: undefined;
|
||||
|
||||
// When the true count is unknown, the user can navigate past
|
||||
// all actual data. If that happens, we need to redirect (via
|
||||
// updatePageIfInvalid) to the last page guaranteed to be not
|
||||
// empty.
|
||||
const pageIsEmpty =
|
||||
query.data != null &&
|
||||
!Object.values(query.data).some((v) => Array.isArray(v) && v.length > 0);
|
||||
if (pageIsEmpty) {
|
||||
totalPages = count !== undefined ? Math.ceil(count / limit) : 1;
|
||||
}
|
||||
|
||||
const hasNextPage =
|
||||
totalRecords !== undefined && limit + currentPageOffset < totalRecords;
|
||||
totalRecords !== undefined &&
|
||||
((countIsCapped && !pageIsEmpty) ||
|
||||
limit + currentPageOffset < totalRecords);
|
||||
const hasPreviousPage =
|
||||
totalRecords !== undefined &&
|
||||
currentPage > 1 &&
|
||||
currentPageOffset - limit < totalRecords;
|
||||
((countIsCapped && !pageIsEmpty) ||
|
||||
currentPageOffset - limit < totalRecords);
|
||||
|
||||
const queryClient = useQueryClient();
|
||||
const prefetchPage = useEffectEvent((newPage: number) => {
|
||||
@@ -224,10 +252,14 @@ export function usePaginatedQuery<
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
if (!query.isFetching && totalPages !== undefined) {
|
||||
if (
|
||||
!query.isFetching &&
|
||||
totalPages !== undefined &&
|
||||
currentPage > totalPages
|
||||
) {
|
||||
void updatePageIfInvalid(totalPages);
|
||||
}
|
||||
}, [updatePageIfInvalid, query.isFetching, totalPages]);
|
||||
}, [updatePageIfInvalid, query.isFetching, totalPages, currentPage]);
|
||||
|
||||
const onPageChange = (newPage: number) => {
|
||||
// Page 1 is the only page that can be safely navigated to without knowing
|
||||
@@ -236,7 +268,12 @@ export function usePaginatedQuery<
|
||||
return;
|
||||
}
|
||||
|
||||
const cleanedInput = clamp(Math.trunc(newPage), 1, totalPages ?? 1);
|
||||
// If the true count is unknown, we allow navigating past the
|
||||
// known page range.
|
||||
const upperBound = countIsCapped
|
||||
? Number.MAX_SAFE_INTEGER
|
||||
: (totalPages ?? 1);
|
||||
const cleanedInput = clamp(Math.trunc(newPage), 1, upperBound);
|
||||
if (Number.isNaN(cleanedInput)) {
|
||||
return;
|
||||
}
|
||||
@@ -274,6 +311,7 @@ export function usePaginatedQuery<
|
||||
totalRecords: totalRecords as number,
|
||||
totalPages: totalPages as number,
|
||||
currentOffsetStart: currentPageOffset + 1,
|
||||
countIsCapped,
|
||||
}
|
||||
: {
|
||||
isSuccess: false,
|
||||
@@ -282,6 +320,7 @@ export function usePaginatedQuery<
|
||||
totalRecords: undefined,
|
||||
totalPages: undefined,
|
||||
currentOffsetStart: undefined,
|
||||
countIsCapped: false as const,
|
||||
}),
|
||||
};
|
||||
|
||||
@@ -323,6 +362,7 @@ export type PaginationResultInfo = {
|
||||
totalRecords: undefined;
|
||||
totalPages: undefined;
|
||||
currentOffsetStart: undefined;
|
||||
countIsCapped: false;
|
||||
}
|
||||
| {
|
||||
isSuccess: true;
|
||||
@@ -331,6 +371,7 @@ export type PaginationResultInfo = {
|
||||
totalRecords: number;
|
||||
totalPages: number;
|
||||
currentOffsetStart: number;
|
||||
countIsCapped: boolean;
|
||||
}
|
||||
);
|
||||
|
||||
@@ -417,6 +458,7 @@ type QueryPageParamsWithPayload<TPayload = never> = QueryPageParams & {
|
||||
*/
|
||||
export type PaginatedData = {
|
||||
count: number;
|
||||
count_cap?: number;
|
||||
};
|
||||
|
||||
/**
|
||||
|
||||
@@ -76,6 +76,8 @@ const defaultAgentMetadata = [
|
||||
},
|
||||
];
|
||||
|
||||
const fixedLogTimestamp = "2021-05-05T00:00:00.000Z";
|
||||
|
||||
const logs = [
|
||||
"\x1b[91mCloning Git repository...",
|
||||
"\x1b[2;37;41mStarting Docker Daemon...",
|
||||
@@ -87,7 +89,7 @@ const logs = [
|
||||
level: "info",
|
||||
output: line,
|
||||
source_id: M.MockWorkspaceAgentLogSource.id,
|
||||
created_at: new Date().toISOString(),
|
||||
created_at: fixedLogTimestamp,
|
||||
}));
|
||||
|
||||
const installScriptLogSource: TypesGen.WorkspaceAgentLogSource = {
|
||||
@@ -102,21 +104,21 @@ const tabbedLogs = [
|
||||
level: "info",
|
||||
output: "startup: preparing workspace",
|
||||
source_id: M.MockWorkspaceAgentLogSource.id,
|
||||
created_at: new Date().toISOString(),
|
||||
created_at: fixedLogTimestamp,
|
||||
},
|
||||
{
|
||||
id: 101,
|
||||
level: "info",
|
||||
output: "install: pnpm install",
|
||||
source_id: installScriptLogSource.id,
|
||||
created_at: new Date().toISOString(),
|
||||
created_at: fixedLogTimestamp,
|
||||
},
|
||||
{
|
||||
id: 102,
|
||||
level: "info",
|
||||
output: "install: setup complete",
|
||||
source_id: installScriptLogSource.id,
|
||||
created_at: new Date().toISOString(),
|
||||
created_at: fixedLogTimestamp,
|
||||
},
|
||||
];
|
||||
|
||||
@@ -153,7 +155,7 @@ const overflowLogs = overflowLogSources.map((source, index) => ({
|
||||
level: "info",
|
||||
output: `${source.display_name}: line`,
|
||||
source_id: source.id,
|
||||
created_at: new Date().toISOString(),
|
||||
created_at: fixedLogTimestamp,
|
||||
}));
|
||||
|
||||
const meta: Meta<typeof AgentRow> = {
|
||||
|
||||
@@ -1,22 +1,19 @@
|
||||
import dayjs, { type Dayjs } from "dayjs";
|
||||
import { type FC, useState } from "react";
|
||||
import { useQuery } from "react-query";
|
||||
import { useSearchParams } from "react-router";
|
||||
import { chatCostSummary } from "#/api/queries/chats";
|
||||
import type { DateRangeValue } from "#/components/DateRangePicker/DateRangePicker";
|
||||
import { ScrollArea } from "#/components/ScrollArea/ScrollArea";
|
||||
import { useAuthContext } from "#/contexts/auth/AuthProvider";
|
||||
import { AgentAnalyticsPageView } from "./AgentAnalyticsPageView";
|
||||
import { AgentPageHeader } from "./components/AgentPageHeader";
|
||||
|
||||
const startDateSearchParam = "startDate";
|
||||
const endDateSearchParam = "endDate";
|
||||
|
||||
const getDefaultDateRange = (now?: Dayjs): DateRangeValue => {
|
||||
const createDateRange = (now?: Dayjs) => {
|
||||
const end = now ?? dayjs();
|
||||
const start = end.subtract(30, "day");
|
||||
return {
|
||||
startDate: end.subtract(30, "day").toDate(),
|
||||
endDate: end.toDate(),
|
||||
startDate: start.toISOString(),
|
||||
endDate: end.toISOString(),
|
||||
rangeLabel: `${start.format("MMM D")} – ${end.format("MMM D, YYYY")}`,
|
||||
};
|
||||
};
|
||||
|
||||
@@ -27,44 +24,13 @@ interface AgentAnalyticsPageProps {
|
||||
|
||||
const AgentAnalyticsPage: FC<AgentAnalyticsPageProps> = ({ now }) => {
|
||||
const { user } = useAuthContext();
|
||||
|
||||
const [searchParams, setSearchParams] = useSearchParams();
|
||||
const startDateParam = searchParams.get(startDateSearchParam)?.trim() ?? "";
|
||||
const endDateParam = searchParams.get(endDateSearchParam)?.trim() ?? "";
|
||||
const [defaultDateRange] = useState(() => getDefaultDateRange(now));
|
||||
let dateRange = defaultDateRange;
|
||||
let hasExplicitDateRange = false;
|
||||
|
||||
if (startDateParam && endDateParam) {
|
||||
const parsedStartDate = new Date(startDateParam);
|
||||
const parsedEndDate = new Date(endDateParam);
|
||||
|
||||
if (
|
||||
!Number.isNaN(parsedStartDate.getTime()) &&
|
||||
!Number.isNaN(parsedEndDate.getTime()) &&
|
||||
parsedStartDate.getTime() <= parsedEndDate.getTime()
|
||||
) {
|
||||
dateRange = {
|
||||
startDate: parsedStartDate,
|
||||
endDate: parsedEndDate,
|
||||
};
|
||||
hasExplicitDateRange = true;
|
||||
}
|
||||
}
|
||||
|
||||
const onDateRangeChange = (value: DateRangeValue) => {
|
||||
setSearchParams((prev) => {
|
||||
const next = new URLSearchParams(prev);
|
||||
next.set(startDateSearchParam, value.startDate.toISOString());
|
||||
next.set(endDateSearchParam, value.endDate.toISOString());
|
||||
return next;
|
||||
});
|
||||
};
|
||||
const [anchor] = useState<Dayjs>(() => dayjs());
|
||||
const dateRange = createDateRange(now ?? anchor);
|
||||
|
||||
const summaryQuery = useQuery({
|
||||
...chatCostSummary(user?.id ?? "me", {
|
||||
start_date: dateRange.startDate.toISOString(),
|
||||
end_date: dateRange.endDate.toISOString(),
|
||||
start_date: dateRange.startDate,
|
||||
end_date: dateRange.endDate,
|
||||
}),
|
||||
enabled: Boolean(user?.id),
|
||||
});
|
||||
@@ -77,9 +43,7 @@ const AgentAnalyticsPage: FC<AgentAnalyticsPageProps> = ({ now }) => {
|
||||
isLoading={summaryQuery.isLoading}
|
||||
error={summaryQuery.error}
|
||||
onRetry={() => void summaryQuery.refetch()}
|
||||
dateRange={dateRange}
|
||||
onDateRangeChange={onDateRangeChange}
|
||||
hasExplicitDateRange={hasExplicitDateRange}
|
||||
rangeLabel={dateRange.rangeLabel}
|
||||
/>
|
||||
</ScrollArea>
|
||||
);
|
||||
|
||||
@@ -1,21 +1,15 @@
|
||||
import { BarChart3Icon } from "lucide-react";
|
||||
import type { FC } from "react";
|
||||
import type { ChatCostSummary } from "#/api/typesGenerated";
|
||||
import {
|
||||
DateRangePicker,
|
||||
type DateRangeValue,
|
||||
} from "#/components/DateRangePicker/DateRangePicker";
|
||||
import { ChatCostSummaryView } from "./components/ChatCostSummaryView";
|
||||
import { SectionHeader } from "./components/SectionHeader";
|
||||
import { toInclusiveDateRange } from "./utils/dateRange";
|
||||
|
||||
interface AgentAnalyticsPageViewProps {
|
||||
summary: ChatCostSummary | undefined;
|
||||
isLoading: boolean;
|
||||
error: unknown;
|
||||
onRetry: () => void;
|
||||
dateRange: DateRangeValue;
|
||||
onDateRangeChange: (value: DateRangeValue) => void;
|
||||
hasExplicitDateRange: boolean;
|
||||
rangeLabel: string;
|
||||
}
|
||||
|
||||
export const AgentAnalyticsPageView: FC<AgentAnalyticsPageViewProps> = ({
|
||||
@@ -23,15 +17,8 @@ export const AgentAnalyticsPageView: FC<AgentAnalyticsPageViewProps> = ({
|
||||
isLoading,
|
||||
error,
|
||||
onRetry,
|
||||
dateRange,
|
||||
onDateRangeChange,
|
||||
hasExplicitDateRange,
|
||||
rangeLabel,
|
||||
}) => {
|
||||
const displayDateRange = toInclusiveDateRange(
|
||||
dateRange,
|
||||
hasExplicitDateRange,
|
||||
);
|
||||
|
||||
return (
|
||||
<div className="flex flex-col p-4 pt-8">
|
||||
<div className="mx-auto w-full max-w-3xl">
|
||||
@@ -39,10 +26,10 @@ export const AgentAnalyticsPageView: FC<AgentAnalyticsPageViewProps> = ({
|
||||
label="Analytics"
|
||||
description="Review your personal Coder Agents usage and cost breakdowns."
|
||||
action={
|
||||
<DateRangePicker
|
||||
value={displayDateRange}
|
||||
onChange={onDateRangeChange}
|
||||
/>
|
||||
<div className="flex items-center gap-2 text-xs text-content-secondary">
|
||||
<BarChart3Icon className="h-4 w-4" />
|
||||
<span>{rangeLabel}</span>
|
||||
</div>
|
||||
}
|
||||
/>
|
||||
|
||||
|
||||
@@ -646,17 +646,24 @@ const AgentChatPage: FC = () => {
|
||||
const isRegenerateTitleDisabled = isArchived || isRegeneratingThisChat;
|
||||
const chatLastModelConfigID = chatRecord?.last_model_config_id;
|
||||
|
||||
const sendMutation = useMutation(
|
||||
// Destructure mutation results directly so the React Compiler
|
||||
// tracks stable primitives/functions instead of the whole result
|
||||
// object (TanStack Query v5 recreates it every render via object
|
||||
// spread). Keeping no intermediate variable prevents future code
|
||||
// from accidentally closing over the unstable object.
|
||||
const { isPending: isSendPending, mutateAsync: sendMessage } = useMutation(
|
||||
createChatMessage(queryClient, agentId ?? ""),
|
||||
);
|
||||
const editMutation = useMutation(editChatMessage(queryClient, agentId ?? ""));
|
||||
const interruptMutation = useMutation(
|
||||
const { isPending: isEditPending, mutateAsync: editMessage } = useMutation(
|
||||
editChatMessage(queryClient, agentId ?? ""),
|
||||
);
|
||||
const { isPending: isInterruptPending, mutateAsync: interrupt } = useMutation(
|
||||
interruptChat(queryClient, agentId ?? ""),
|
||||
);
|
||||
const deleteQueuedMutation = useMutation(
|
||||
const { mutateAsync: deleteQueuedMessage } = useMutation(
|
||||
deleteChatQueuedMessage(queryClient, agentId ?? ""),
|
||||
);
|
||||
const promoteQueuedMutation = useMutation(
|
||||
const { mutateAsync: promoteQueuedMessage } = useMutation(
|
||||
promoteChatQueuedMessage(queryClient, agentId ?? ""),
|
||||
);
|
||||
|
||||
@@ -754,9 +761,7 @@ const AgentChatPage: FC = () => {
|
||||
hasUserFixableModelProviders,
|
||||
});
|
||||
const isSubmissionPending =
|
||||
sendMutation.isPending ||
|
||||
editMutation.isPending ||
|
||||
interruptMutation.isPending;
|
||||
isSendPending || isEditPending || isInterruptPending;
|
||||
const isInputDisabled = !hasModelOptions || isArchived;
|
||||
|
||||
const handleUsageLimitError = (error: unknown): void => {
|
||||
@@ -842,7 +847,7 @@ const AgentChatPage: FC = () => {
|
||||
setPendingEditMessageId(editedMessageID);
|
||||
scrollToBottomRef.current?.();
|
||||
try {
|
||||
await editMutation.mutateAsync({
|
||||
await editMessage({
|
||||
messageId: editedMessageID,
|
||||
req: request,
|
||||
});
|
||||
@@ -873,9 +878,9 @@ const AgentChatPage: FC = () => {
|
||||
// For queued sends the WebSocket status events handle
|
||||
// clearing; for non-queued sends we clear explicitly
|
||||
// below. Clearing eagerly causes a visible cutoff.
|
||||
let response: Awaited<ReturnType<typeof sendMutation.mutateAsync>>;
|
||||
let response: Awaited<ReturnType<typeof sendMessage>>;
|
||||
try {
|
||||
response = await sendMutation.mutateAsync(request);
|
||||
response = await sendMessage(request);
|
||||
} catch (error) {
|
||||
handleUsageLimitError(error);
|
||||
throw error;
|
||||
@@ -908,10 +913,10 @@ const AgentChatPage: FC = () => {
|
||||
};
|
||||
|
||||
const handleInterrupt = () => {
|
||||
if (!agentId || interruptMutation.isPending) {
|
||||
if (!agentId || isInterruptPending) {
|
||||
return;
|
||||
}
|
||||
void interruptMutation.mutateAsync();
|
||||
void interrupt();
|
||||
};
|
||||
|
||||
const handleDeleteQueuedMessage = async (id: number) => {
|
||||
@@ -920,7 +925,7 @@ const AgentChatPage: FC = () => {
|
||||
previousQueuedMessages.filter((message) => message.id !== id),
|
||||
);
|
||||
try {
|
||||
await deleteQueuedMutation.mutateAsync(id);
|
||||
await deleteQueuedMessage(id);
|
||||
} catch (error) {
|
||||
store.setQueuedMessages(previousQueuedMessages);
|
||||
throw error;
|
||||
@@ -941,7 +946,7 @@ const AgentChatPage: FC = () => {
|
||||
store.clearStreamError();
|
||||
store.setChatStatus("pending");
|
||||
try {
|
||||
const promotedMessage = await promoteQueuedMutation.mutateAsync(id);
|
||||
const promotedMessage = await promoteQueuedMessage(id);
|
||||
// Insert the promoted message into the store and cache
|
||||
// immediately so it appears in the timeline without
|
||||
// waiting for the WebSocket to deliver it.
|
||||
@@ -990,7 +995,8 @@ const AgentChatPage: FC = () => {
|
||||
? `ssh ${workspaceAgent.name}.${workspace.name}.${workspace.owner_name}.${sshConfigQuery.data.hostname_suffix}`
|
||||
: undefined;
|
||||
|
||||
const generateKeyMutation = useMutation({
|
||||
// See mutation destructuring comment above (React Compiler).
|
||||
const { mutate: generateKey } = useMutation({
|
||||
mutationFn: () => API.getApiKey(),
|
||||
});
|
||||
|
||||
@@ -1005,7 +1011,7 @@ const AgentChatPage: FC = () => {
|
||||
const repoRoots = Array.from(gitWatcher.repositories.keys()).sort();
|
||||
const folder = repoRoots[0] ?? workspaceAgent.expanded_directory;
|
||||
|
||||
generateKeyMutation.mutate(undefined, {
|
||||
generateKey(undefined, {
|
||||
onSuccess: ({ key }) => {
|
||||
location.href = getVSCodeHref(editor, {
|
||||
owner: workspace.owner_name,
|
||||
@@ -1141,7 +1147,7 @@ const AgentChatPage: FC = () => {
|
||||
compressionThreshold={compressionThreshold}
|
||||
isInputDisabled={isInputDisabled}
|
||||
isSubmissionPending={isSubmissionPending}
|
||||
isInterruptPending={interruptMutation.isPending}
|
||||
isInterruptPending={isInterruptPending}
|
||||
isSidebarCollapsed={isSidebarCollapsed}
|
||||
onToggleSidebarCollapsed={onToggleSidebarCollapsed}
|
||||
showSidebarPanel={showSidebarPanel}
|
||||
|
||||
@@ -372,7 +372,7 @@ export const DefaultAutostopNotVisibleToNonAdmin: Story = {
|
||||
const canvas = within(canvasElement);
|
||||
|
||||
// Personal Instructions should be visible.
|
||||
await canvas.findByText("Instructions");
|
||||
await canvas.findByText("Personal Instructions");
|
||||
|
||||
// Admin-only sections should not be present.
|
||||
expect(canvas.queryByText("Workspace Autostop Fallback")).toBeNull();
|
||||
@@ -412,7 +412,7 @@ export const InvisibleUnicodeWarningUserPrompt: Story = {
|
||||
play: async ({ canvasElement }) => {
|
||||
const canvas = within(canvasElement);
|
||||
|
||||
await canvas.findByText("Instructions");
|
||||
await canvas.findByText("Personal Instructions");
|
||||
|
||||
const alert = await canvas.findByText(/invisible Unicode/);
|
||||
expect(alert).toBeInTheDocument();
|
||||
@@ -454,7 +454,7 @@ export const NoWarningForCleanPrompt: Story = {
|
||||
play: async ({ canvasElement }) => {
|
||||
const canvas = within(canvasElement);
|
||||
|
||||
await canvas.findByText("Instructions");
|
||||
await canvas.findByText("Personal Instructions");
|
||||
await canvas.findByText("System Instructions");
|
||||
|
||||
expect(canvas.queryByText(/invisible Unicode/)).toBeNull();
|
||||
|
||||
@@ -9,13 +9,6 @@ import { Switch } from "#/components/Switch/Switch";
|
||||
import { cn } from "#/utils/cn";
|
||||
import { countInvisibleCharacters } from "#/utils/invisibleUnicode";
|
||||
import { AdminBadge } from "./components/AdminBadge";
|
||||
import {
|
||||
CollapsibleSection,
|
||||
CollapsibleSectionContent,
|
||||
CollapsibleSectionDescription,
|
||||
CollapsibleSectionHeader,
|
||||
CollapsibleSectionTitle,
|
||||
} from "./components/CollapsibleSection";
|
||||
import { DurationField } from "./components/DurationField/DurationField";
|
||||
import { SectionHeader } from "./components/SectionHeader";
|
||||
import { TextPreviewDialog } from "./components/TextPreviewDialog";
|
||||
@@ -253,43 +246,142 @@ export const AgentSettingsBehaviorPageView: FC<
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="space-y-6">
|
||||
<>
|
||||
<SectionHeader
|
||||
label="Behavior"
|
||||
description="Custom instructions that shape how the agent responds in your conversations."
|
||||
/>
|
||||
{/* ── Personal prompt (always visible) ── */}
|
||||
<form
|
||||
className="space-y-2"
|
||||
onSubmit={(event) => void handleSaveUserPrompt(event)}
|
||||
>
|
||||
<h3 className="m-0 text-[13px] font-semibold text-content-primary">
|
||||
Personal Instructions
|
||||
</h3>
|
||||
<p className="!mt-0.5 m-0 text-xs text-content-secondary">
|
||||
Applied to all your conversations. Only visible to you.
|
||||
</p>
|
||||
<TextareaAutosize
|
||||
className={cn(
|
||||
textareaBaseClassName,
|
||||
isUserPromptOverflowing && textareaOverflowClassName,
|
||||
)}
|
||||
placeholder="Additional behavior, style, and tone preferences"
|
||||
value={userPromptDraft}
|
||||
onChange={(event) => setLocalUserEdit(event.target.value)}
|
||||
onHeightChange={(height) =>
|
||||
setIsUserPromptOverflowing(height >= textareaMaxHeight)
|
||||
}
|
||||
disabled={isPromptSaving}
|
||||
minRows={1}
|
||||
/>
|
||||
{userInvisibleCharCount > 0 && (
|
||||
<Alert severity="warning">
|
||||
<AlertDescription>
|
||||
This text contains {userInvisibleCharCount} invisible Unicode{" "}
|
||||
{userInvisibleCharCount !== 1 ? "characters" : "character"} that
|
||||
could hide content. These will be stripped on save.
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
)}
|
||||
<div className="flex justify-end gap-2">
|
||||
<Button
|
||||
size="sm"
|
||||
variant="outline"
|
||||
type="button"
|
||||
onClick={() => setLocalUserEdit("")}
|
||||
disabled={isPromptSaving || !userPromptDraft}
|
||||
>
|
||||
Clear
|
||||
</Button>
|
||||
<Button
|
||||
size="sm"
|
||||
type="submit"
|
||||
disabled={isPromptSaving || !isUserPromptDirty}
|
||||
>
|
||||
Save
|
||||
</Button>
|
||||
</div>
|
||||
{isSaveUserPromptError && (
|
||||
<p className="m-0 text-xs text-content-destructive">
|
||||
Failed to save personal instructions.
|
||||
</p>
|
||||
)}
|
||||
</form>
|
||||
|
||||
{/* Personal prompt (always visible) */}
|
||||
<CollapsibleSection>
|
||||
<CollapsibleSectionHeader>
|
||||
<CollapsibleSectionTitle>Instructions</CollapsibleSectionTitle>
|
||||
<CollapsibleSectionDescription>Applied to all your conversations. Only visible to you.</CollapsibleSectionDescription>
|
||||
</CollapsibleSectionHeader>
|
||||
<CollapsibleSectionContent>
|
||||
<hr className="my-5 border-0 border-t border-solid border-border" />
|
||||
<UserCompactionThresholdSettings
|
||||
modelConfigs={modelConfigsData ?? []}
|
||||
modelConfigsError={modelConfigsError}
|
||||
isLoadingModelConfigs={isLoadingModelConfigs}
|
||||
thresholds={thresholds}
|
||||
isThresholdsLoading={isThresholdsLoading}
|
||||
thresholdsError={thresholdsError}
|
||||
onSaveThreshold={onSaveThreshold}
|
||||
onResetThreshold={onResetThreshold}
|
||||
/>
|
||||
{/* ── Admin system prompt (admin only) ── */}
|
||||
{canSetSystemPrompt && (
|
||||
<>
|
||||
<hr className="my-5 border-0 border-t border-solid border-border" />
|
||||
<form
|
||||
className="space-y-2"
|
||||
onSubmit={(event) => void handleSaveUserPrompt(event)}
|
||||
onSubmit={(event) => void handleSaveSystemPrompt(event)}
|
||||
>
|
||||
<div className="flex items-center gap-2">
|
||||
<h3 className="m-0 text-[13px] font-semibold text-content-primary">
|
||||
System Instructions
|
||||
</h3>
|
||||
<AdminBadge />
|
||||
</div>
|
||||
<div className="flex items-center justify-between gap-4">
|
||||
<div className="flex min-w-0 items-center gap-2 text-xs font-medium text-content-primary">
|
||||
<span>Include Coder Agents default system prompt</span>
|
||||
<Button
|
||||
size="xs"
|
||||
variant="subtle"
|
||||
type="button"
|
||||
onClick={() => setShowDefaultPromptPreview(true)}
|
||||
disabled={!hasLoadedSystemPrompt}
|
||||
className="min-w-0 px-0 text-content-link hover:text-content-link"
|
||||
>
|
||||
Preview
|
||||
</Button>
|
||||
</div>
|
||||
<Switch
|
||||
checked={includeDefaultDraft}
|
||||
onCheckedChange={setLocalIncludeDefault}
|
||||
aria-label="Include Coder Agents default system prompt"
|
||||
disabled={isSystemPromptDisabled}
|
||||
/>
|
||||
</div>
|
||||
<p className="!mt-0.5 m-0 text-xs text-content-secondary">
|
||||
{includeDefaultDraft
|
||||
? "The built-in Coder Agents prompt is prepended. Additional instructions below are appended."
|
||||
: "Only the additional instructions below are used. When empty, no deployment-wide system prompt is sent."}
|
||||
</p>
|
||||
<TextareaAutosize
|
||||
className={cn(
|
||||
textareaBaseClassName,
|
||||
isUserPromptOverflowing && textareaOverflowClassName,
|
||||
isSystemPromptOverflowing && textareaOverflowClassName,
|
||||
)}
|
||||
placeholder="Additional behavior, style, and tone preferences"
|
||||
value={userPromptDraft}
|
||||
onChange={(event) => setLocalUserEdit(event.target.value)}
|
||||
placeholder="Additional instructions for all users"
|
||||
value={systemPromptDraft}
|
||||
onChange={(event) => setLocalEdit(event.target.value)}
|
||||
onHeightChange={(height) =>
|
||||
setIsUserPromptOverflowing(height >= textareaMaxHeight)
|
||||
setIsSystemPromptOverflowing(height >= textareaMaxHeight)
|
||||
}
|
||||
disabled={isPromptSaving}
|
||||
disabled={isSystemPromptDisabled}
|
||||
minRows={1}
|
||||
/>
|
||||
{userInvisibleCharCount > 0 && (
|
||||
{systemInvisibleCharCount > 0 && (
|
||||
<Alert severity="warning">
|
||||
<AlertDescription>
|
||||
This text contains {userInvisibleCharCount} invisible Unicode{" "}
|
||||
{userInvisibleCharCount !== 1 ? "characters" : "character"} that
|
||||
could hide content. These will be stripped on save.
|
||||
This text contains {systemInvisibleCharCount} invisible
|
||||
Unicode{" "}
|
||||
{systemInvisibleCharCount !== 1 ? "characters" : "character"}{" "}
|
||||
that could hide content. These will be stripped on save.
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
)}
|
||||
@@ -298,304 +390,164 @@ export const AgentSettingsBehaviorPageView: FC<
|
||||
size="sm"
|
||||
variant="outline"
|
||||
type="button"
|
||||
onClick={() => setLocalUserEdit("")}
|
||||
disabled={isPromptSaving || !userPromptDraft}
|
||||
onClick={() => setLocalEdit("")}
|
||||
disabled={isSystemPromptDisabled || !systemPromptDraft}
|
||||
>
|
||||
Clear
|
||||
</Button>
|
||||
<Button
|
||||
size="sm"
|
||||
type="submit"
|
||||
disabled={isPromptSaving || !isUserPromptDirty}
|
||||
disabled={isSystemPromptDisabled || !isSystemPromptDirty}
|
||||
>
|
||||
Save
|
||||
</Button>
|
||||
</Button>{" "}
|
||||
</div>
|
||||
{isSaveUserPromptError && (
|
||||
{isSaveSystemPromptError && (
|
||||
<p className="m-0 text-xs text-content-destructive">
|
||||
Failed to save personal instructions.
|
||||
Failed to save system prompt.
|
||||
</p>
|
||||
)}
|
||||
</form>
|
||||
</CollapsibleSectionContent>
|
||||
</CollapsibleSection>
|
||||
|
||||
{/* Context compaction */}
|
||||
<CollapsibleSection>
|
||||
<CollapsibleSectionHeader>
|
||||
<CollapsibleSectionTitle>Context compaction</CollapsibleSectionTitle>
|
||||
<CollapsibleSectionDescription>Control when conversation context is automatically summarized for each model. Setting 100% means the conversation will never auto-compact.</CollapsibleSectionDescription>
|
||||
</CollapsibleSectionHeader>
|
||||
<CollapsibleSectionContent>
|
||||
<UserCompactionThresholdSettings
|
||||
hideHeader
|
||||
modelConfigs={modelConfigsData ?? []}
|
||||
modelConfigsError={modelConfigsError}
|
||||
isLoadingModelConfigs={isLoadingModelConfigs}
|
||||
thresholds={thresholds}
|
||||
isThresholdsLoading={isThresholdsLoading}
|
||||
thresholdsError={thresholdsError}
|
||||
onSaveThreshold={onSaveThreshold}
|
||||
onResetThreshold={onResetThreshold}
|
||||
/>
|
||||
</CollapsibleSectionContent>
|
||||
</CollapsibleSection>
|
||||
|
||||
{/* Admin system prompt (admin only) */}
|
||||
{canSetSystemPrompt && (
|
||||
<>
|
||||
<CollapsibleSection>
|
||||
<CollapsibleSectionHeader>
|
||||
<div className="flex items-center gap-2">
|
||||
<CollapsibleSectionTitle>System Instructions</CollapsibleSectionTitle>
|
||||
<AdminBadge />
|
||||
</div>
|
||||
<CollapsibleSectionDescription>
|
||||
{includeDefaultDraft
|
||||
? "The built-in Coder Agents prompt is prepended. Additional instructions below are appended."
|
||||
: "Only the additional instructions below are used. When empty, no deployment-wide system prompt is sent."}
|
||||
</CollapsibleSectionDescription>
|
||||
</CollapsibleSectionHeader>
|
||||
<CollapsibleSectionContent>
|
||||
<form
|
||||
className="space-y-2"
|
||||
onSubmit={(event) => void handleSaveSystemPrompt(event)}
|
||||
>
|
||||
<div className="flex items-center justify-between gap-4">
|
||||
<div className="flex min-w-0 items-center gap-2 text-xs font-medium text-content-primary">
|
||||
<span>Include Coder Agents default system prompt</span>
|
||||
<Button
|
||||
size="xs"
|
||||
variant="subtle"
|
||||
type="button"
|
||||
onClick={() => setShowDefaultPromptPreview(true)}
|
||||
disabled={!hasLoadedSystemPrompt}
|
||||
className="min-w-0 px-0 text-content-link hover:text-content-link"
|
||||
>
|
||||
Preview
|
||||
</Button>
|
||||
</div>
|
||||
<Switch
|
||||
checked={includeDefaultDraft}
|
||||
onCheckedChange={setLocalIncludeDefault}
|
||||
aria-label="Include Coder Agents default system prompt"
|
||||
disabled={isSystemPromptDisabled}
|
||||
/>
|
||||
</div>
|
||||
<TextareaAutosize
|
||||
className={cn(
|
||||
textareaBaseClassName,
|
||||
isSystemPromptOverflowing && textareaOverflowClassName,
|
||||
)}
|
||||
placeholder="Additional instructions for all users"
|
||||
value={systemPromptDraft}
|
||||
onChange={(event) => setLocalEdit(event.target.value)}
|
||||
onHeightChange={(height) =>
|
||||
setIsSystemPromptOverflowing(height >= textareaMaxHeight)
|
||||
}
|
||||
disabled={isSystemPromptDisabled}
|
||||
minRows={1}
|
||||
/>
|
||||
{systemInvisibleCharCount > 0 && (
|
||||
<Alert severity="warning">
|
||||
<AlertDescription>
|
||||
This text contains {systemInvisibleCharCount} invisible
|
||||
Unicode{" "}
|
||||
{systemInvisibleCharCount !== 1
|
||||
? "characters"
|
||||
: "character"}{" "}
|
||||
that could hide content. These will be stripped on save.
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
)}
|
||||
<div className="flex justify-end gap-2">
|
||||
<Button
|
||||
<hr className="my-5 border-0 border-t border-solid border-border" />
|
||||
<div className="space-y-2">
|
||||
<div className="flex items-center gap-2">
|
||||
<h3 className="m-0 text-[13px] font-semibold text-content-primary">
|
||||
Virtual Desktop
|
||||
</h3>
|
||||
<AdminBadge />
|
||||
</div>
|
||||
<div className="flex items-center justify-between gap-4">
|
||||
<div className="!mt-0.5 m-0 flex-1 text-xs text-content-secondary">
|
||||
<p className="m-0">
|
||||
Allow agents to use a virtual, graphical desktop within
|
||||
workspaces. Requires the{" "}
|
||||
<Link
|
||||
href="https://registry.coder.com/modules/coder/portabledesktop"
|
||||
target="_blank"
|
||||
size="sm"
|
||||
variant="outline"
|
||||
type="button"
|
||||
onClick={() => setLocalEdit("")}
|
||||
disabled={isSystemPromptDisabled || !systemPromptDraft}
|
||||
>
|
||||
Clear
|
||||
</Button>
|
||||
<Button
|
||||
size="sm"
|
||||
type="submit"
|
||||
disabled={isSystemPromptDisabled || !isSystemPromptDirty}
|
||||
>
|
||||
Save
|
||||
</Button>
|
||||
</div>
|
||||
{isSaveSystemPromptError && (
|
||||
<p className="m-0 text-xs text-content-destructive">
|
||||
Failed to save system prompt.
|
||||
</p>
|
||||
)}
|
||||
</form>
|
||||
</CollapsibleSectionContent>
|
||||
</CollapsibleSection>
|
||||
|
||||
<CollapsibleSection>
|
||||
<CollapsibleSectionHeader>
|
||||
<div className="flex items-center gap-2">
|
||||
<CollapsibleSectionTitle>Virtual Desktop</CollapsibleSectionTitle>
|
||||
<AdminBadge />
|
||||
</div>
|
||||
</CollapsibleSectionHeader>
|
||||
<CollapsibleSectionContent>
|
||||
<div className="flex items-start gap-4">
|
||||
<div className="min-w-0 flex-1 space-y-1">
|
||||
<span className="text-sm font-medium text-content-primary">
|
||||
Enable virtual desktop
|
||||
</span>
|
||||
<p className="m-0 text-xs text-content-secondary">
|
||||
Allow agents to use a virtual, graphical desktop within
|
||||
workspaces. Requires the{" "}
|
||||
<Link
|
||||
href="https://registry.coder.com/modules/coder/portabledesktop"
|
||||
target="_blank"
|
||||
size="sm"
|
||||
>
|
||||
portabledesktop module
|
||||
</Link>{" "}
|
||||
to be installed in the workspace and the Anthropic provider to
|
||||
be configured.
|
||||
</p>
|
||||
<p className="m-0 text-xs text-content-secondary font-semibold">
|
||||
Warning: This is a work-in-progress feature, and you're likely
|
||||
to encounter bugs if you enable it.
|
||||
</p>
|
||||
</div>
|
||||
<Switch
|
||||
checked={desktopEnabled}
|
||||
onCheckedChange={(checked) =>
|
||||
onSaveDesktopEnabled({ enable_desktop: checked })
|
||||
}
|
||||
aria-label="Enable"
|
||||
disabled={isSavingDesktopEnabled}
|
||||
/>
|
||||
</div>{" "}
|
||||
{isSaveDesktopEnabledError && (
|
||||
<p className="m-0 text-xs text-content-destructive">
|
||||
Failed to save desktop setting.
|
||||
portabledesktop module
|
||||
</Link>{" "}
|
||||
to be installed in the workspace and the Anthropic provider to
|
||||
be configured.
|
||||
</p>
|
||||
<p className="mt-2 mb-0 font-semibold text-content-secondary">
|
||||
Warning: This is a work-in-progress feature, and you're likely
|
||||
to encounter bugs if you enable it.
|
||||
</p>
|
||||
)}
|
||||
</CollapsibleSectionContent>
|
||||
</CollapsibleSection>
|
||||
|
||||
<CollapsibleSection>
|
||||
<CollapsibleSectionHeader>
|
||||
<div className="flex items-center gap-2">
|
||||
<CollapsibleSectionTitle>Workspace Autostop Fallback</CollapsibleSectionTitle>
|
||||
<AdminBadge />
|
||||
</div>
|
||||
</CollapsibleSectionHeader>
|
||||
<CollapsibleSectionContent>
|
||||
<div className="flex items-start gap-4">
|
||||
<div className="min-w-0 flex-1 space-y-1">
|
||||
<span className="text-sm font-medium text-content-primary">
|
||||
Enable autostop fallback
|
||||
</span>
|
||||
<p className="m-0 text-xs text-content-secondary">
|
||||
Set a default autostop for agent-created workspaces that don't
|
||||
have one defined in their template. Template-defined autostop
|
||||
rules always take precedence. Active conversations will extend
|
||||
the stop time.
|
||||
</p>
|
||||
</div>
|
||||
<Switch
|
||||
checked={isAutostopEnabled}
|
||||
onCheckedChange={handleToggleAutostop}
|
||||
aria-label="Enable default autostop"
|
||||
disabled={isSavingWorkspaceTTL || isWorkspaceTTLLoading}
|
||||
/>
|
||||
<Switch
|
||||
checked={desktopEnabled}
|
||||
onCheckedChange={(checked) =>
|
||||
onSaveDesktopEnabled({ enable_desktop: checked })
|
||||
}
|
||||
aria-label="Enable"
|
||||
disabled={isSavingDesktopEnabled}
|
||||
/>
|
||||
</div>
|
||||
{isSaveDesktopEnabledError && (
|
||||
<p className="m-0 text-xs text-content-destructive">
|
||||
Failed to save desktop setting.
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
<hr className="my-5 border-0 border-t border-solid border-border" />
|
||||
<form
|
||||
className="space-y-2"
|
||||
onSubmit={(event) => void handleSaveChatWorkspaceTTL(event)}
|
||||
>
|
||||
<div className="flex items-center gap-2">
|
||||
<h3 className="m-0 text-[13px] font-semibold text-content-primary">
|
||||
Workspace Autostop Fallback
|
||||
</h3>
|
||||
<AdminBadge />
|
||||
</div>
|
||||
<div className="flex items-center justify-between gap-4">
|
||||
<p className="!mt-0.5 m-0 flex-1 text-xs text-content-secondary">
|
||||
Set a default autostop for agent-created workspaces that don't
|
||||
have one defined in their template. Template-defined autostop
|
||||
rules always take precedence. Active conversations will extend
|
||||
the stop time.
|
||||
</p>
|
||||
<Switch
|
||||
checked={isAutostopEnabled}
|
||||
onCheckedChange={handleToggleAutostop}
|
||||
aria-label="Enable default autostop"
|
||||
disabled={isSavingWorkspaceTTL || isWorkspaceTTLLoading}
|
||||
/>{" "}
|
||||
</div>
|
||||
{isAutostopEnabled && (
|
||||
<DurationField
|
||||
valueMs={ttlMs}
|
||||
onChange={handleTTLChange}
|
||||
label="Autostop Fallback"
|
||||
disabled={isSavingWorkspaceTTL || isWorkspaceTTLLoading}
|
||||
error={isTTLOverMax || isTTLZero}
|
||||
helperText={
|
||||
isTTLZero
|
||||
? "Duration must be greater than zero."
|
||||
: isTTLOverMax
|
||||
? "Must not exceed 30 days (720 hours)."
|
||||
: undefined
|
||||
}
|
||||
/>
|
||||
)}
|
||||
{isAutostopEnabled && (
|
||||
<div className="flex justify-end">
|
||||
<Button
|
||||
size="sm"
|
||||
type="submit"
|
||||
disabled={
|
||||
isSavingWorkspaceTTL ||
|
||||
!isTTLDirty ||
|
||||
isTTLOverMax ||
|
||||
isTTLZero
|
||||
}
|
||||
>
|
||||
Save
|
||||
</Button>
|
||||
</div>
|
||||
<form
|
||||
className="space-y-2 pt-4"
|
||||
onSubmit={(event) => void handleSaveChatWorkspaceTTL(event)}
|
||||
>
|
||||
{" "}
|
||||
{isAutostopEnabled && (
|
||||
<DurationField
|
||||
valueMs={ttlMs}
|
||||
onChange={handleTTLChange}
|
||||
label="Autostop Fallback"
|
||||
disabled={isSavingWorkspaceTTL || isWorkspaceTTLLoading}
|
||||
error={isTTLOverMax || isTTLZero}
|
||||
helperText={
|
||||
isTTLZero
|
||||
? "Duration must be greater than zero."
|
||||
: isTTLOverMax
|
||||
? "Must not exceed 30 days (720 hours)."
|
||||
: undefined
|
||||
}
|
||||
/>
|
||||
)}
|
||||
{isAutostopEnabled && (
|
||||
<div className="flex justify-end">
|
||||
<Button
|
||||
size="sm"
|
||||
type="submit"
|
||||
disabled={
|
||||
isSavingWorkspaceTTL ||
|
||||
!isTTLDirty ||
|
||||
isTTLOverMax ||
|
||||
isTTLZero
|
||||
}
|
||||
>
|
||||
Save
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
{isSaveWorkspaceTTLError && (
|
||||
<p className="m-0 text-xs text-content-destructive">
|
||||
Failed to save autostop setting.
|
||||
</p>
|
||||
)}
|
||||
{isWorkspaceTTLLoadError && (
|
||||
<p className="m-0 text-xs text-content-destructive">
|
||||
Failed to load autostop setting.
|
||||
</p>
|
||||
)}
|
||||
</form>
|
||||
</CollapsibleSectionContent>
|
||||
</CollapsibleSection>
|
||||
)}
|
||||
{isSaveWorkspaceTTLError && (
|
||||
<p className="m-0 text-xs text-content-destructive">
|
||||
Failed to save autostop setting.
|
||||
</p>
|
||||
)}
|
||||
{isWorkspaceTTLLoadError && (
|
||||
<p className="m-0 text-xs text-content-destructive">
|
||||
Failed to load autostop setting.
|
||||
</p>
|
||||
)}
|
||||
</form>
|
||||
</>
|
||||
)}
|
||||
|
||||
{/* Kyleosophy toggle (always visible) */}
|
||||
<CollapsibleSection>
|
||||
<CollapsibleSectionHeader>
|
||||
<CollapsibleSectionTitle>Kyleosophy</CollapsibleSectionTitle>
|
||||
</CollapsibleSectionHeader>
|
||||
<CollapsibleSectionContent>
|
||||
<div className="flex items-start gap-4">
|
||||
<div className="min-w-0 flex-1 space-y-1">
|
||||
<span className="text-sm font-medium text-content-primary">
|
||||
Enable Kyleosophy
|
||||
<hr className="my-5 border-0 border-t border-solid border-border" />
|
||||
{/* ── Kyleosophy toggle (always visible) ── */}
|
||||
<div className="space-y-2">
|
||||
<h3 className="m-0 text-[13px] font-semibold text-content-primary">
|
||||
Kyleosophy
|
||||
</h3>
|
||||
<div className="flex items-center justify-between gap-4">
|
||||
<p className="!mt-0.5 m-0 flex-1 text-xs text-content-secondary">
|
||||
Replace the standard completion chime. IYKYK.
|
||||
{kylesophyForced && (
|
||||
<span className="ml-1 font-semibold">
|
||||
Kyleosophy is mandatory on <code>dev.coder.com</code>.
|
||||
</span>
|
||||
<p className="m-0 text-xs text-content-secondary">
|
||||
{" "}
|
||||
Replace the standard completion chime. IYKYK.
|
||||
{kylesophyForced && (
|
||||
<span className="ml-1 font-semibold">
|
||||
Kyleosophy is mandatory on <code>dev.coder.com</code>.
|
||||
</span>
|
||||
)}
|
||||
</p>
|
||||
</div>
|
||||
<Switch
|
||||
checked={kylesophyEnabled}
|
||||
onCheckedChange={(checked) => {
|
||||
setKylesophyEnabled(checked);
|
||||
setLocalKylesophy(checked);
|
||||
}}
|
||||
aria-label="Enable Kyleosophy"
|
||||
disabled={kylesophyForced}
|
||||
/>
|
||||
</div>
|
||||
</CollapsibleSectionContent>
|
||||
</CollapsibleSection>
|
||||
)}
|
||||
</p>
|
||||
<Switch
|
||||
checked={kylesophyEnabled}
|
||||
onCheckedChange={(checked) => {
|
||||
setKylesophyEnabled(checked);
|
||||
setLocalKylesophy(checked);
|
||||
}}
|
||||
aria-label="Enable Kyleosophy"
|
||||
disabled={kylesophyForced}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
{showDefaultPromptPreview && (
|
||||
<TextPreviewDialog
|
||||
content={defaultSystemPrompt}
|
||||
@@ -603,6 +555,6 @@ export const AgentSettingsBehaviorPageView: FC<
|
||||
onClose={() => setShowDefaultPromptPreview(false)}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
+1
-1
@@ -255,7 +255,7 @@ export const ProviderAccordionCards: Story = {
|
||||
expect(body.queryByText("OpenAI")).not.toBeInTheDocument();
|
||||
|
||||
await userEvent.click(body.getByRole("button", { name: /OpenRouter/i }));
|
||||
await expect(body.getByLabelText("Base URL")).toBeInTheDocument();
|
||||
await expect(await body.findByLabelText("Base URL")).toBeInTheDocument();
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
@@ -285,7 +285,7 @@ export const ChatModelAdminPanel: FC<ChatModelAdminPanelProps> = ({
|
||||
const modelConfigsUnavailable = modelConfigsData === null;
|
||||
|
||||
return (
|
||||
<div className={cn("flex min-h-full flex-col space-y-6", className)}>
|
||||
<div className={cn("flex min-h-full flex-col space-y-3", className)}>
|
||||
{isLoading && (
|
||||
<div className="flex items-center gap-1.5 text-xs text-content-secondary">
|
||||
<Spinner className="h-4 w-4" loading />
|
||||
|
||||
@@ -1,5 +1,11 @@
|
||||
import { useFormik } from "formik";
|
||||
import { ChevronLeftIcon, InfoIcon, PencilIcon } from "lucide-react";
|
||||
import {
|
||||
ChevronDownIcon,
|
||||
ChevronLeftIcon,
|
||||
ChevronRightIcon,
|
||||
InfoIcon,
|
||||
PencilIcon,
|
||||
} from "lucide-react";
|
||||
import { type FC, useState } from "react";
|
||||
import * as Yup from "yup";
|
||||
import type * as TypesGen from "#/api/typesGenerated";
|
||||
@@ -35,13 +41,6 @@ import {
|
||||
} from "#/components/Tooltip/Tooltip";
|
||||
import { cn } from "#/utils/cn";
|
||||
import { getFormHelpers } from "#/utils/formUtils";
|
||||
import {
|
||||
CollapsibleSection,
|
||||
CollapsibleSectionContent,
|
||||
CollapsibleSectionDescription,
|
||||
CollapsibleSectionHeader,
|
||||
CollapsibleSectionTitle,
|
||||
} from "../CollapsibleSection";
|
||||
import type { ProviderState } from "./ChatModelAdminPanel";
|
||||
import {
|
||||
GeneralModelConfigFields,
|
||||
@@ -117,6 +116,9 @@ export const ModelForm: FC<ModelFormProps> = ({
|
||||
}) => {
|
||||
const isEditing = Boolean(editingModel);
|
||||
const isDefaultModel = isEditing && editingModel?.is_default === true;
|
||||
const [showAdvanced, setShowAdvanced] = useState(false);
|
||||
const [showPricing, setShowPricing] = useState(false);
|
||||
const [showProviderConfig, setShowProviderConfig] = useState(false);
|
||||
const [confirmingDelete, setConfirmingDelete] = useState(false);
|
||||
|
||||
const canManageModels = Boolean(
|
||||
@@ -486,18 +488,29 @@ export const ModelForm: FC<ModelFormProps> = ({
|
||||
</div>
|
||||
|
||||
{/* Usage Tracking */}
|
||||
<CollapsibleSection variant="inline" defaultOpen={false}>
|
||||
<CollapsibleSectionHeader>
|
||||
<CollapsibleSectionTitle as="h3">
|
||||
Cost Tracking
|
||||
</CollapsibleSectionTitle>
|
||||
<CollapsibleSectionDescription>
|
||||
Set per-token pricing so Coder can track costs and enforce
|
||||
spending limits.
|
||||
</CollapsibleSectionDescription>
|
||||
</CollapsibleSectionHeader>
|
||||
<CollapsibleSectionContent>
|
||||
<div className="grid grid-cols-2 gap-3 sm:grid-cols-4">
|
||||
<div className="border-0 border-t border-solid border-border pt-4">
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => setShowPricing((v) => !v)}
|
||||
className="flex w-full cursor-pointer items-start justify-between border-0 bg-transparent p-0 text-left transition-colors hover:text-content-primary"
|
||||
>
|
||||
<div>
|
||||
<h3 className="m-0 text-sm font-medium text-content-primary">
|
||||
Cost Tracking{" "}
|
||||
</h3>
|
||||
<p className="m-0 text-xs text-content-secondary">
|
||||
Set per-token pricing so Coder can track costs and enforce
|
||||
spending limits.
|
||||
</p>
|
||||
</div>
|
||||
{showPricing ? (
|
||||
<ChevronDownIcon className="mt-0.5 h-4 w-4 shrink-0 text-content-secondary" />
|
||||
) : (
|
||||
<ChevronRightIcon className="mt-0.5 h-4 w-4 shrink-0 text-content-secondary" />
|
||||
)}
|
||||
</button>
|
||||
{showPricing && (
|
||||
<div className="grid grid-cols-2 gap-3 pt-3 sm:grid-cols-4">
|
||||
<PricingModelConfigFields
|
||||
provider={selectedProviderState.provider}
|
||||
form={form}
|
||||
@@ -505,21 +518,33 @@ export const ModelForm: FC<ModelFormProps> = ({
|
||||
disabled={isSaving}
|
||||
/>
|
||||
</div>
|
||||
</CollapsibleSectionContent>
|
||||
</CollapsibleSection>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Provider Configuration */}
|
||||
<CollapsibleSection variant="inline" defaultOpen={false}>
|
||||
<CollapsibleSectionHeader>
|
||||
<CollapsibleSectionTitle as="h3">
|
||||
Provider Configuration
|
||||
</CollapsibleSectionTitle>
|
||||
<CollapsibleSectionDescription>
|
||||
Tune provider-specific behavior like reasoning, tool calling,
|
||||
and web search.
|
||||
</CollapsibleSectionDescription>
|
||||
</CollapsibleSectionHeader>
|
||||
<CollapsibleSectionContent>
|
||||
<div className="border-0 border-t border-solid border-border pt-4">
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => setShowProviderConfig((v) => !v)}
|
||||
className="flex w-full cursor-pointer items-start justify-between border-0 bg-transparent p-0 text-left transition-colors hover:text-content-primary"
|
||||
>
|
||||
<div>
|
||||
<h3 className="m-0 text-sm font-medium text-content-primary">
|
||||
Provider Configuration
|
||||
</h3>
|
||||
<p className="m-0 text-xs text-content-secondary">
|
||||
Tune provider-specific behavior like reasoning, tool calling,
|
||||
and web search.
|
||||
</p>
|
||||
</div>
|
||||
{showProviderConfig ? (
|
||||
<ChevronDownIcon className="mt-0.5 h-4 w-4 shrink-0 text-content-secondary" />
|
||||
) : (
|
||||
<ChevronRightIcon className="mt-0.5 h-4 w-4 shrink-0 text-content-secondary" />
|
||||
)}
|
||||
</button>
|
||||
{showProviderConfig && (
|
||||
<div className="pt-3">
|
||||
<ModelConfigFields
|
||||
provider={selectedProviderState.provider}
|
||||
form={form}
|
||||
@@ -527,21 +552,33 @@ export const ModelForm: FC<ModelFormProps> = ({
|
||||
disabled={isSaving}
|
||||
/>
|
||||
</div>
|
||||
</CollapsibleSectionContent>
|
||||
</CollapsibleSection>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Advanced */}
|
||||
<CollapsibleSection variant="inline" defaultOpen={false}>
|
||||
<CollapsibleSectionHeader>
|
||||
<CollapsibleSectionTitle as="h3">
|
||||
Advanced
|
||||
</CollapsibleSectionTitle>
|
||||
<CollapsibleSectionDescription>
|
||||
Low-level parameters like temperature and penalties. Rarely need
|
||||
changing.
|
||||
</CollapsibleSectionDescription>
|
||||
</CollapsibleSectionHeader>
|
||||
<CollapsibleSectionContent>
|
||||
<div className="grid grid-cols-2 gap-3 sm:grid-cols-3">
|
||||
<div className="border-0 border-t border-solid border-border pt-4">
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => setShowAdvanced((v) => !v)}
|
||||
className="flex w-full cursor-pointer items-start justify-between border-0 bg-transparent p-0 text-left transition-colors hover:text-content-primary"
|
||||
>
|
||||
<div>
|
||||
<h3 className="m-0 text-sm font-medium text-content-primary">
|
||||
Advanced
|
||||
</h3>
|
||||
<p className="m-0 text-xs text-content-secondary">
|
||||
Low-level parameters like temperature and penalties. Rarely
|
||||
need changing.
|
||||
</p>
|
||||
</div>
|
||||
{showAdvanced ? (
|
||||
<ChevronDownIcon className="mt-0.5 h-4 w-4 shrink-0 text-content-secondary" />
|
||||
) : (
|
||||
<ChevronRightIcon className="mt-0.5 h-4 w-4 shrink-0 text-content-secondary" />
|
||||
)}
|
||||
</button>
|
||||
{showAdvanced && (
|
||||
<div className="grid grid-cols-2 gap-3 pt-3 sm:grid-cols-3">
|
||||
<GeneralModelConfigFields
|
||||
provider={selectedProviderState.provider}
|
||||
form={form}
|
||||
@@ -592,11 +629,10 @@ export const ModelForm: FC<ModelFormProps> = ({
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</CollapsibleSectionContent>
|
||||
</CollapsibleSection>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
<div className="mt-auto py-6">
|
||||
{" "}
|
||||
<hr className="mb-4 border-0 border-t border-solid border-border" />
|
||||
<div className="flex items-center justify-between">
|
||||
{isEditing && editingModel && onDeleteModel ? (
|
||||
|
||||
@@ -1,193 +0,0 @@
|
||||
import type { Meta, StoryObj } from "@storybook/react-vite";
|
||||
import { expect, userEvent, within } from "storybook/test";
|
||||
import { AdminBadge } from "./AdminBadge";
|
||||
import {
|
||||
CollapsibleSection,
|
||||
CollapsibleSectionContent,
|
||||
CollapsibleSectionDescription,
|
||||
CollapsibleSectionHeader,
|
||||
CollapsibleSectionTitle,
|
||||
} from "./CollapsibleSection";
|
||||
|
||||
const meta: Meta<typeof CollapsibleSection> = {
|
||||
title: "pages/AgentsPage/CollapsibleSection",
|
||||
component: CollapsibleSection,
|
||||
decorators: [
|
||||
(Story) => (
|
||||
<div style={{ maxWidth: 600 }}>
|
||||
<Story />
|
||||
</div>
|
||||
),
|
||||
],
|
||||
};
|
||||
export default meta;
|
||||
type Story = StoryObj<typeof CollapsibleSection>;
|
||||
|
||||
const Placeholder = () => (
|
||||
<p className="text-sm text-content-secondary">Placeholder content</p>
|
||||
);
|
||||
|
||||
export const DefaultOpen: Story = {
|
||||
render: () => (
|
||||
<CollapsibleSection>
|
||||
<CollapsibleSectionHeader>
|
||||
<div className="flex items-center gap-2">
|
||||
<CollapsibleSectionTitle>Default spend limit</CollapsibleSectionTitle>
|
||||
<AdminBadge />
|
||||
</div>
|
||||
<CollapsibleSectionDescription>
|
||||
The deployment-wide spending cap.
|
||||
</CollapsibleSectionDescription>
|
||||
</CollapsibleSectionHeader>
|
||||
<CollapsibleSectionContent>
|
||||
<Placeholder />
|
||||
</CollapsibleSectionContent>
|
||||
</CollapsibleSection>
|
||||
),
|
||||
play: async ({ canvasElement }) => {
|
||||
const canvas = within(canvasElement);
|
||||
|
||||
expect(canvas.getByText("Placeholder content")).toBeInTheDocument();
|
||||
|
||||
const header = canvas.getByRole("button", {
|
||||
name: /Default spend limit/i,
|
||||
});
|
||||
expect(header).toHaveAttribute("aria-expanded", "true");
|
||||
|
||||
await userEvent.click(header);
|
||||
expect(header).toHaveAttribute("aria-expanded", "false");
|
||||
expect(canvas.queryByText("Placeholder content")).not.toBeInTheDocument();
|
||||
|
||||
await userEvent.click(header);
|
||||
expect(header).toHaveAttribute("aria-expanded", "true");
|
||||
expect(canvas.getByText("Placeholder content")).toBeInTheDocument();
|
||||
},
|
||||
};
|
||||
|
||||
export const Collapsed: Story = {
|
||||
render: () => (
|
||||
<CollapsibleSection defaultOpen={false}>
|
||||
<CollapsibleSectionHeader>
|
||||
<div className="flex items-center gap-2">
|
||||
<CollapsibleSectionTitle>Default spend limit</CollapsibleSectionTitle>
|
||||
<AdminBadge />
|
||||
</div>
|
||||
<CollapsibleSectionDescription>
|
||||
The deployment-wide spending cap.
|
||||
</CollapsibleSectionDescription>
|
||||
</CollapsibleSectionHeader>
|
||||
<CollapsibleSectionContent>
|
||||
<Placeholder />
|
||||
</CollapsibleSectionContent>
|
||||
</CollapsibleSection>
|
||||
),
|
||||
play: async ({ canvasElement }) => {
|
||||
const canvas = within(canvasElement);
|
||||
|
||||
expect(canvas.queryByText("Placeholder content")).not.toBeInTheDocument();
|
||||
|
||||
const header = canvas.getByRole("button", {
|
||||
name: /Default spend limit/i,
|
||||
});
|
||||
expect(header).toHaveAttribute("aria-expanded", "false");
|
||||
|
||||
await userEvent.click(header);
|
||||
expect(header).toHaveAttribute("aria-expanded", "true");
|
||||
expect(canvas.getByText("Placeholder content")).toBeInTheDocument();
|
||||
|
||||
await userEvent.click(header);
|
||||
expect(header).toHaveAttribute("aria-expanded", "false");
|
||||
expect(canvas.queryByText("Placeholder content")).not.toBeInTheDocument();
|
||||
},
|
||||
};
|
||||
|
||||
export const NoBadge: Story = {
|
||||
render: () => (
|
||||
<CollapsibleSection>
|
||||
<CollapsibleSectionHeader>
|
||||
<CollapsibleSectionTitle>Group limits</CollapsibleSectionTitle>
|
||||
<CollapsibleSectionDescription>
|
||||
Override defaults for groups.
|
||||
</CollapsibleSectionDescription>
|
||||
</CollapsibleSectionHeader>
|
||||
<CollapsibleSectionContent>
|
||||
<Placeholder />
|
||||
</CollapsibleSectionContent>
|
||||
</CollapsibleSection>
|
||||
),
|
||||
};
|
||||
|
||||
export const InlineVariant: Story = {
|
||||
render: () => (
|
||||
<CollapsibleSection variant="inline" defaultOpen={false}>
|
||||
<CollapsibleSectionHeader>
|
||||
<CollapsibleSectionTitle as="h3">Cost Tracking</CollapsibleSectionTitle>
|
||||
<CollapsibleSectionDescription>
|
||||
Set per-token pricing so Coder can track costs and enforce spending
|
||||
limits.
|
||||
</CollapsibleSectionDescription>
|
||||
</CollapsibleSectionHeader>
|
||||
<CollapsibleSectionContent>
|
||||
<Placeholder />
|
||||
</CollapsibleSectionContent>
|
||||
</CollapsibleSection>
|
||||
),
|
||||
play: async ({ canvasElement }) => {
|
||||
const canvas = within(canvasElement);
|
||||
|
||||
expect(canvas.queryByText("Placeholder content")).not.toBeInTheDocument();
|
||||
|
||||
const header = canvas.getByRole("button", {
|
||||
name: /Cost Tracking/i,
|
||||
});
|
||||
expect(header).toHaveAttribute("aria-expanded", "false");
|
||||
|
||||
await userEvent.click(header);
|
||||
expect(header).toHaveAttribute("aria-expanded", "true");
|
||||
expect(canvas.getByText("Placeholder content")).toBeInTheDocument();
|
||||
|
||||
await userEvent.click(header);
|
||||
expect(header).toHaveAttribute("aria-expanded", "false");
|
||||
expect(canvas.queryByText("Placeholder content")).not.toBeInTheDocument();
|
||||
},
|
||||
};
|
||||
|
||||
export const KeyboardToggle: Story = {
|
||||
render: () => (
|
||||
<CollapsibleSection>
|
||||
<CollapsibleSectionHeader>
|
||||
<CollapsibleSectionTitle>Keyboard section</CollapsibleSectionTitle>
|
||||
</CollapsibleSectionHeader>
|
||||
<CollapsibleSectionContent>
|
||||
<Placeholder />
|
||||
</CollapsibleSectionContent>
|
||||
</CollapsibleSection>
|
||||
),
|
||||
play: async ({ canvasElement }) => {
|
||||
const canvas = within(canvasElement);
|
||||
|
||||
const header = canvas.getByRole("button", {
|
||||
name: /Keyboard section/i,
|
||||
});
|
||||
expect(header).toHaveAttribute("aria-expanded", "true");
|
||||
expect(canvas.getByText("Placeholder content")).toBeInTheDocument();
|
||||
|
||||
header.focus();
|
||||
|
||||
await userEvent.keyboard("{Enter}");
|
||||
expect(header).toHaveAttribute("aria-expanded", "false");
|
||||
expect(canvas.queryByText("Placeholder content")).not.toBeInTheDocument();
|
||||
|
||||
await userEvent.keyboard("{Enter}");
|
||||
expect(header).toHaveAttribute("aria-expanded", "true");
|
||||
expect(canvas.getByText("Placeholder content")).toBeInTheDocument();
|
||||
|
||||
await userEvent.keyboard(" ");
|
||||
expect(header).toHaveAttribute("aria-expanded", "false");
|
||||
expect(canvas.queryByText("Placeholder content")).not.toBeInTheDocument();
|
||||
|
||||
await userEvent.keyboard(" ");
|
||||
expect(header).toHaveAttribute("aria-expanded", "true");
|
||||
expect(canvas.getByText("Placeholder content")).toBeInTheDocument();
|
||||
},
|
||||
};
|
||||
@@ -1,211 +0,0 @@
|
||||
import { cva, type VariantProps } from "class-variance-authority";
|
||||
import { ChevronDownIcon } from "lucide-react";
|
||||
import {
|
||||
type ComponentProps,
|
||||
createContext,
|
||||
type FC,
|
||||
type ReactNode,
|
||||
useContext,
|
||||
useState,
|
||||
} from "react";
|
||||
import {
|
||||
Collapsible,
|
||||
CollapsibleContent,
|
||||
CollapsibleTrigger,
|
||||
} from "#/components/Collapsible/Collapsible";
|
||||
import { cn } from "#/utils/cn";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Variant styles
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const wrapperStyles = cva("", {
|
||||
variants: {
|
||||
variant: {
|
||||
card: "rounded-lg border border-solid border-border-default",
|
||||
inline: "border-0 border-t border-solid border-border-default pt-4",
|
||||
},
|
||||
},
|
||||
defaultVariants: { variant: "card" },
|
||||
});
|
||||
|
||||
const triggerStyles = cva(
|
||||
"flex w-full cursor-pointer items-start justify-between gap-4 border-0 bg-transparent text-left focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-content-link",
|
||||
{
|
||||
variants: {
|
||||
variant: {
|
||||
card: "rounded-lg px-6 py-5",
|
||||
inline: "rounded-md p-0",
|
||||
},
|
||||
},
|
||||
defaultVariants: { variant: "card" },
|
||||
},
|
||||
);
|
||||
|
||||
const titleStyles = cva("m-0 text-content-primary", {
|
||||
variants: {
|
||||
variant: {
|
||||
card: "text-lg font-semibold",
|
||||
inline: "text-sm font-medium",
|
||||
},
|
||||
},
|
||||
defaultVariants: { variant: "card" },
|
||||
});
|
||||
|
||||
const descriptionStyles = cva("m-0 text-content-secondary", {
|
||||
variants: {
|
||||
variant: {
|
||||
card: "mt-1 text-sm",
|
||||
inline: "text-xs",
|
||||
},
|
||||
},
|
||||
defaultVariants: { variant: "card" },
|
||||
});
|
||||
|
||||
const contentStyles = cva("", {
|
||||
variants: {
|
||||
variant: {
|
||||
card: "border-0 border-t border-solid border-border-default px-6 pb-5 pt-4",
|
||||
inline: "pt-3",
|
||||
},
|
||||
},
|
||||
defaultVariants: { variant: "card" },
|
||||
});
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Context
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type Variant = "card" | "inline";
|
||||
|
||||
interface CollapsibleSectionContextValue {
|
||||
variant: Variant;
|
||||
open: boolean;
|
||||
}
|
||||
|
||||
const CollapsibleSectionContext = createContext<CollapsibleSectionContextValue>(
|
||||
{
|
||||
variant: "card",
|
||||
open: true,
|
||||
},
|
||||
);
|
||||
|
||||
const useCollapsibleSection = () => useContext(CollapsibleSectionContext);
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Root
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
interface CollapsibleSectionProps extends VariantProps<typeof wrapperStyles> {
|
||||
defaultOpen?: boolean;
|
||||
/** Controlled open state. */
|
||||
open?: boolean;
|
||||
onOpenChange?: (open: boolean) => void;
|
||||
children: ReactNode;
|
||||
}
|
||||
|
||||
export const CollapsibleSection: FC<CollapsibleSectionProps> = ({
|
||||
defaultOpen,
|
||||
open: controlledOpen,
|
||||
onOpenChange: controlledOnOpenChange,
|
||||
variant = "card",
|
||||
children,
|
||||
}) => {
|
||||
const [uncontrolledOpen, setUncontrolledOpen] = useState(defaultOpen ?? true);
|
||||
const isControlled = controlledOpen !== undefined;
|
||||
const open = isControlled ? controlledOpen : uncontrolledOpen;
|
||||
const onOpenChange = isControlled
|
||||
? controlledOnOpenChange
|
||||
: setUncontrolledOpen;
|
||||
|
||||
return (
|
||||
<CollapsibleSectionContext.Provider
|
||||
value={{ variant: variant ?? "card", open }}
|
||||
>
|
||||
<Collapsible open={open} onOpenChange={onOpenChange}>
|
||||
<div className={wrapperStyles({ variant })}>{children}</div>
|
||||
</Collapsible>
|
||||
</CollapsibleSectionContext.Provider>
|
||||
);
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Header (trigger)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
interface CollapsibleSectionHeaderProps {
|
||||
children: ReactNode;
|
||||
}
|
||||
|
||||
export const CollapsibleSectionHeader: FC<CollapsibleSectionHeaderProps> = ({
|
||||
children,
|
||||
}) => {
|
||||
const { variant, open } = useCollapsibleSection();
|
||||
|
||||
return (
|
||||
<CollapsibleTrigger className={triggerStyles({ variant })}>
|
||||
<div className="min-w-0 flex-1">{children}</div>
|
||||
<ChevronDownIcon
|
||||
className={cn(
|
||||
"h-4 w-4 shrink-0 text-content-secondary transition-transform duration-200",
|
||||
open && "rotate-180",
|
||||
)}
|
||||
/>
|
||||
</CollapsibleTrigger>
|
||||
);
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Title — renders the heading element at whatever level the consumer picks.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type HeadingTag = "h1" | "h2" | "h3" | "h4" | "h5" | "h6";
|
||||
|
||||
interface CollapsibleSectionTitleProps extends ComponentProps<HeadingTag> {
|
||||
as?: HeadingTag;
|
||||
}
|
||||
|
||||
export const CollapsibleSectionTitle: FC<CollapsibleSectionTitleProps> = ({
|
||||
as: Component = "h2",
|
||||
className,
|
||||
...props
|
||||
}) => {
|
||||
const { variant } = useCollapsibleSection();
|
||||
return (
|
||||
<Component className={cn(titleStyles({ variant }), className)} {...props} />
|
||||
);
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Description
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export const CollapsibleSectionDescription: FC<ComponentProps<"p">> = ({
|
||||
className,
|
||||
...props
|
||||
}) => {
|
||||
const { variant } = useCollapsibleSection();
|
||||
return (
|
||||
<p className={cn(descriptionStyles({ variant }), className)} {...props} />
|
||||
);
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Content
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
interface CollapsibleSectionContentProps {
|
||||
children: ReactNode;
|
||||
}
|
||||
|
||||
export const CollapsibleSectionContent: FC<CollapsibleSectionContentProps> = ({
|
||||
children,
|
||||
}) => {
|
||||
const { variant } = useCollapsibleSection();
|
||||
|
||||
return (
|
||||
<CollapsibleContent>
|
||||
<div className={contentStyles({ variant })}>{children}</div>
|
||||
</CollapsibleContent>
|
||||
);
|
||||
};
|
||||
@@ -989,7 +989,7 @@ export const MCPServerAdminPanel: FC<MCPServerAdminPanelProps> = ({
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex min-h-full flex-col space-y-6">
|
||||
<div className="flex min-h-full flex-col space-y-3">
|
||||
{!isFormView ? (
|
||||
<ServerList
|
||||
servers={servers}
|
||||
|
||||
@@ -21,9 +21,7 @@ import {
|
||||
} from "#/components/Table/Table";
|
||||
import { cn } from "#/utils/cn";
|
||||
import { formatCostMicros } from "#/utils/currency";
|
||||
import { AdminBadge } from "./AdminBadge";
|
||||
import { PrStateIcon } from "./GitPanel/GitPanel";
|
||||
import { SectionHeader } from "./SectionHeader";
|
||||
|
||||
dayjs.extend(relativeTime);
|
||||
|
||||
@@ -297,16 +295,19 @@ export const PRInsightsView: FC<PRInsightsViewProps> = ({
|
||||
const isEmpty = summary.total_prs_created === 0;
|
||||
|
||||
return (
|
||||
<div className="space-y-6">
|
||||
<div className="space-y-8">
|
||||
{/* ── Header ── */}
|
||||
<SectionHeader
|
||||
label="Pull Request Insights"
|
||||
description="Code changes detected by Agents."
|
||||
badge={<AdminBadge />}
|
||||
action={
|
||||
<TimeRangeFilter value={timeRange} onChange={onTimeRangeChange} />
|
||||
}
|
||||
/>
|
||||
<div className="flex items-end justify-between">
|
||||
<div>
|
||||
<h2 className="m-0 text-xl font-semibold tracking-tight text-content-primary">
|
||||
Pull Request Insights
|
||||
</h2>
|
||||
<p className="m-0 mt-1 text-[13px] text-content-secondary">
|
||||
Code changes detected by Agents.
|
||||
</p>
|
||||
</div>
|
||||
<TimeRangeFilter value={timeRange} onChange={onTimeRangeChange} />
|
||||
</div>
|
||||
|
||||
{isEmpty ? (
|
||||
<EmptyState />
|
||||
@@ -354,7 +355,7 @@ export const PRInsightsView: FC<PRInsightsViewProps> = ({
|
||||
</section>
|
||||
|
||||
{/* ── Model breakdown + Recent PRs side by side ── */}
|
||||
<div className="grid grid-cols-1 gap-6">
|
||||
<div className="grid grid-cols-1 gap-6 lg:grid-cols-2">
|
||||
{/* ── Model performance (simplified) ── */}
|
||||
{by_model.length > 0 && (
|
||||
<section>
|
||||
|
||||
@@ -15,8 +15,8 @@ export const SectionHeader: FC<SectionHeaderProps> = ({
|
||||
}) => (
|
||||
<>
|
||||
<div className="flex items-start justify-between gap-4">
|
||||
<div className="min-w-0 flex-1">
|
||||
<div className="flex items-center gap-2">
|
||||
<div>
|
||||
<div className="flex w-full items-center gap-2">
|
||||
<h2 className="m-0 text-lg font-medium text-content-primary">
|
||||
{label}
|
||||
</h2>
|
||||
|
||||
@@ -78,8 +78,11 @@ export const Default: Story = {
|
||||
expect(canvas.getByText("Claude Sonnet")).toBeInTheDocument();
|
||||
expect(canvas.queryByText("GPT-3.5 (Disabled)")).not.toBeInTheDocument();
|
||||
|
||||
// Save button should be disabled when nothing is dirty.
|
||||
expect(canvas.getByRole("button", { name: /Save/i })).toBeDisabled();
|
||||
// No footer visible when nothing is dirty
|
||||
expect(
|
||||
canvas.queryByRole("button", { name: /Save/i }),
|
||||
).not.toBeInTheDocument();
|
||||
|
||||
// Type a value to make the footer appear
|
||||
await userEvent.type(gpt4oInput, "95");
|
||||
await waitFor(() => {
|
||||
@@ -160,9 +163,11 @@ export const CancelChanges: Story = {
|
||||
const cancelButton = await canvas.findByRole("button", { name: /Cancel/i });
|
||||
await userEvent.click(cancelButton);
|
||||
|
||||
// Save button should be disabled after cancel.
|
||||
// Footer should disappear after cancel
|
||||
await waitFor(() => {
|
||||
expect(canvas.getByRole("button", { name: /Save/i })).toBeDisabled();
|
||||
expect(
|
||||
canvas.queryByRole("button", { name: /Save/i }),
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Input should be cleared back to empty (no override)
|
||||
@@ -189,8 +194,10 @@ export const InvalidDraftShowsFooter: Story = {
|
||||
// Cancel button should be visible so user can discard the edit
|
||||
expect(canvas.getByRole("button", { name: /Cancel/i })).toBeInTheDocument();
|
||||
|
||||
// Save button should be disabled (nothing valid to save).
|
||||
expect(canvas.getByRole("button", { name: /Save/i })).toBeDisabled();
|
||||
// Save button should NOT be visible (nothing valid to save)
|
||||
expect(
|
||||
canvas.queryByRole("button", { name: /Save/i }),
|
||||
).not.toBeInTheDocument();
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableFooter,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
@@ -32,7 +33,6 @@ interface UserCompactionThresholdSettingsProps {
|
||||
thresholdPercent: number,
|
||||
) => Promise<unknown>;
|
||||
onResetThreshold: (modelConfigId: string) => Promise<unknown>;
|
||||
hideHeader?: boolean;
|
||||
}
|
||||
|
||||
const parseThresholdDraft = (value: string): number | null => {
|
||||
@@ -60,7 +60,6 @@ export const UserCompactionThresholdSettings: FC<
|
||||
thresholdsError,
|
||||
onSaveThreshold,
|
||||
onResetThreshold,
|
||||
hideHeader,
|
||||
}) => {
|
||||
const [drafts, setDrafts] = useState<Record<string, string>>({});
|
||||
const [rowErrors, setRowErrors] = useState<Record<string, string>>({});
|
||||
@@ -173,23 +172,19 @@ export const UserCompactionThresholdSettings: FC<
|
||||
};
|
||||
|
||||
const hasAnyPending = pendingModels.size > 0;
|
||||
|
||||
const headerBlock = !hideHeader ? (
|
||||
<>
|
||||
<h3 className="m-0 text-[13px] font-semibold text-content-primary">
|
||||
Context Compaction
|
||||
</h3>
|
||||
<p className="!mt-0.5 m-0 text-xs text-content-secondary">
|
||||
Control when conversation context is automatically summarized for each
|
||||
model. Setting 100% means the conversation will never auto-compact.
|
||||
</p>
|
||||
</>
|
||||
) : null;
|
||||
const hasAnyErrors = Object.keys(rowErrors).length > 0;
|
||||
const hasAnyDrafts = Object.keys(drafts).length > 0;
|
||||
|
||||
if (isThresholdsLoading) {
|
||||
return (
|
||||
<div className="space-y-2">
|
||||
{headerBlock}
|
||||
<h3 className="m-0 text-[13px] font-semibold text-content-primary">
|
||||
Context Compaction
|
||||
</h3>
|
||||
<p className="!mt-0.5 m-0 text-xs text-content-secondary">
|
||||
Control when conversation context is automatically summarized for each
|
||||
model. Setting 100% means the conversation will never auto-compact.
|
||||
</p>
|
||||
<div className="flex items-center gap-2 text-sm text-content-secondary">
|
||||
<Spinner loading className="h-4 w-4" />
|
||||
Loading thresholds...
|
||||
@@ -201,7 +196,13 @@ export const UserCompactionThresholdSettings: FC<
|
||||
if (thresholdsError != null) {
|
||||
return (
|
||||
<div className="space-y-2">
|
||||
{headerBlock}
|
||||
<h3 className="m-0 text-[13px] font-semibold text-content-primary">
|
||||
Context Compaction
|
||||
</h3>
|
||||
<p className="!mt-0.5 m-0 text-xs text-content-secondary">
|
||||
Control when conversation context is automatically summarized for each
|
||||
model. Setting 100% means the conversation will never auto-compact.
|
||||
</p>
|
||||
<p className="m-0 text-xs text-content-destructive">
|
||||
{getErrorMessage(
|
||||
thresholdsError,
|
||||
@@ -214,7 +215,13 @@ export const UserCompactionThresholdSettings: FC<
|
||||
|
||||
return (
|
||||
<div className="space-y-2">
|
||||
{headerBlock}
|
||||
<h3 className="m-0 text-[13px] font-semibold text-content-primary">
|
||||
Context Compaction
|
||||
</h3>
|
||||
<p className="!mt-0.5 m-0 text-xs text-content-secondary">
|
||||
Control when conversation context is automatically summarized for each
|
||||
model. Setting 100% means the conversation will never auto-compact.
|
||||
</p>
|
||||
{isLoadingModelConfigs ? (
|
||||
<div className="flex items-center gap-2 text-sm text-content-secondary">
|
||||
<Spinner loading className="h-4 w-4" />
|
||||
@@ -233,163 +240,165 @@ export const UserCompactionThresholdSettings: FC<
|
||||
models before compaction thresholds can be set.
|
||||
</p>
|
||||
) : (
|
||||
<>
|
||||
<Table>
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>Model</TableHead>
|
||||
<TableHead className="w-0 whitespace-nowrap">Default</TableHead>
|
||||
<TableHead className="w-0 whitespace-nowrap">
|
||||
Threshold
|
||||
</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{enabledModelConfigs.map((modelConfig) => {
|
||||
const existingOverride = overridesByModelID.get(modelConfig.id);
|
||||
const hasOverride = overridesByModelID.has(modelConfig.id);
|
||||
const draftValue =
|
||||
drafts[modelConfig.id] ??
|
||||
(existingOverride !== undefined
|
||||
? String(existingOverride)
|
||||
: "");
|
||||
const parsedDraftValue = parseThresholdDraft(draftValue);
|
||||
const isThisModelMutating = pendingModels.has(modelConfig.id);
|
||||
const isInvalid =
|
||||
draftValue.length > 0 && parsedDraftValue === null;
|
||||
// Only warn when user-typed, not when loaded from
|
||||
// the server.
|
||||
const isDraftDisablingCompaction =
|
||||
draftValue === "100" && drafts[modelConfig.id] !== undefined;
|
||||
const rowError = rowErrors[modelConfig.id];
|
||||
const modelName = modelConfig.display_name || modelConfig.model;
|
||||
<Table>
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>Model</TableHead>
|
||||
<TableHead className="w-0 whitespace-nowrap">Default</TableHead>
|
||||
<TableHead className="w-0 whitespace-nowrap">Threshold</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{enabledModelConfigs.map((modelConfig) => {
|
||||
const existingOverride = overridesByModelID.get(modelConfig.id);
|
||||
const hasOverride = overridesByModelID.has(modelConfig.id);
|
||||
const draftValue =
|
||||
drafts[modelConfig.id] ??
|
||||
(existingOverride !== undefined
|
||||
? String(existingOverride)
|
||||
: "");
|
||||
const parsedDraftValue = parseThresholdDraft(draftValue);
|
||||
const isThisModelMutating = pendingModels.has(modelConfig.id);
|
||||
const isInvalid =
|
||||
draftValue.length > 0 && parsedDraftValue === null;
|
||||
// Only warn when user-typed, not when loaded from
|
||||
// the server.
|
||||
const isDraftDisablingCompaction =
|
||||
draftValue === "100" && drafts[modelConfig.id] !== undefined;
|
||||
const rowError = rowErrors[modelConfig.id];
|
||||
const modelName = modelConfig.display_name || modelConfig.model;
|
||||
|
||||
return (
|
||||
<TableRow key={modelConfig.id}>
|
||||
<TableCell className="text-[13px] font-medium text-content-primary">
|
||||
{modelName}
|
||||
{rowError && (
|
||||
<p
|
||||
aria-live="polite"
|
||||
className="m-0 mt-0.5 text-2xs font-normal text-content-destructive"
|
||||
>
|
||||
{rowError}
|
||||
</p>
|
||||
)}
|
||||
</TableCell>
|
||||
<TableCell className="w-0 whitespace-nowrap tabular-nums">
|
||||
{modelConfig.compression_threshold}%
|
||||
</TableCell>
|
||||
<TableCell className="w-0 whitespace-nowrap">
|
||||
<div className="flex items-center gap-1.5">
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<Input
|
||||
aria-label={`${modelName} compaction threshold`}
|
||||
aria-invalid={isInvalid || undefined}
|
||||
type="number"
|
||||
min={0}
|
||||
max={100}
|
||||
inputMode="numeric"
|
||||
className={cn(
|
||||
"h-7 w-16 px-2 text-xs tabular-nums",
|
||||
isInvalid &&
|
||||
"border-content-destructive focus:ring-content-destructive/30",
|
||||
)}
|
||||
value={draftValue}
|
||||
placeholder={String(
|
||||
modelConfig.compression_threshold,
|
||||
)}
|
||||
onChange={(event) => {
|
||||
setDrafts((currentDrafts) => ({
|
||||
...currentDrafts,
|
||||
[modelConfig.id]: event.target.value,
|
||||
}));
|
||||
clearRowError(modelConfig.id);
|
||||
}}
|
||||
disabled={isThisModelMutating}
|
||||
/>
|
||||
</TooltipTrigger>
|
||||
{(isInvalid || isDraftDisablingCompaction) && (
|
||||
<TooltipContent>
|
||||
{isInvalid
|
||||
? "Enter a whole number between 0 and 100."
|
||||
: "Setting 100% will disable auto-compaction for this model."}
|
||||
</TooltipContent>
|
||||
)}
|
||||
</Tooltip>
|
||||
<span className="text-xs text-content-secondary">
|
||||
%
|
||||
</span>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<Button
|
||||
size="icon"
|
||||
variant="subtle"
|
||||
className={cn(
|
||||
"size-7",
|
||||
hasOverride
|
||||
? "opacity-100"
|
||||
: "pointer-events-none opacity-0",
|
||||
)}
|
||||
aria-label={`Reset ${modelName} to default`}
|
||||
aria-hidden={!hasOverride}
|
||||
tabIndex={hasOverride ? 0 : -1}
|
||||
disabled={isThisModelMutating || !hasOverride}
|
||||
onClick={() => handleReset(modelConfig.id)}
|
||||
>
|
||||
<RotateCcwIcon className="size-3.5" />
|
||||
</Button>
|
||||
</TooltipTrigger>
|
||||
{hasOverride && (
|
||||
<TooltipContent>
|
||||
Reset to default (
|
||||
{modelConfig.compression_threshold}%)
|
||||
</TooltipContent>
|
||||
)}
|
||||
</Tooltip>
|
||||
</div>
|
||||
{isInvalid && (
|
||||
<span className="sr-only" aria-live="polite">
|
||||
Enter a whole number between 0 and 100.
|
||||
</span>
|
||||
)}
|
||||
{isDraftDisablingCompaction && (
|
||||
<span className="sr-only" aria-live="polite">
|
||||
Setting 100% will disable auto-compaction for this
|
||||
model.
|
||||
</span>
|
||||
)}
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
);
|
||||
})}
|
||||
</TableBody>
|
||||
</Table>
|
||||
<div className="flex justify-end gap-2 pt-2">
|
||||
<Button
|
||||
size="sm"
|
||||
variant="outline"
|
||||
type="button"
|
||||
onClick={handleCancelAll}
|
||||
disabled={hasAnyPending}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
size="sm"
|
||||
type="button"
|
||||
disabled={hasAnyPending || dirtyRows.length === 0}
|
||||
onClick={handleSaveAll}
|
||||
>
|
||||
{hasAnyPending
|
||||
? "Saving..."
|
||||
: dirtyRows.length > 0
|
||||
? `Save ${dirtyRows.length} ${dirtyRows.length === 1 ? "change" : "changes"}`
|
||||
: "Save"}
|
||||
</Button>
|
||||
</div>
|
||||
</>
|
||||
return (
|
||||
<TableRow key={modelConfig.id}>
|
||||
<TableCell className="text-[13px] font-medium text-content-primary">
|
||||
{modelName}
|
||||
{rowError && (
|
||||
<p
|
||||
aria-live="polite"
|
||||
className="m-0 mt-0.5 text-2xs font-normal text-content-destructive"
|
||||
>
|
||||
{rowError}
|
||||
</p>
|
||||
)}
|
||||
</TableCell>
|
||||
<TableCell className="w-0 whitespace-nowrap tabular-nums">
|
||||
{modelConfig.compression_threshold}%
|
||||
</TableCell>
|
||||
<TableCell className="w-0 whitespace-nowrap">
|
||||
<div className="flex items-center gap-1.5">
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<Input
|
||||
aria-label={`${modelName} compaction threshold`}
|
||||
aria-invalid={isInvalid || undefined}
|
||||
type="number"
|
||||
min={0}
|
||||
max={100}
|
||||
inputMode="numeric"
|
||||
className={cn(
|
||||
"h-7 w-16 px-2 text-xs tabular-nums",
|
||||
isInvalid &&
|
||||
"border-content-destructive focus:ring-content-destructive/30",
|
||||
)}
|
||||
value={draftValue}
|
||||
placeholder={String(
|
||||
modelConfig.compression_threshold,
|
||||
)}
|
||||
onChange={(event) => {
|
||||
setDrafts((currentDrafts) => ({
|
||||
...currentDrafts,
|
||||
[modelConfig.id]: event.target.value,
|
||||
}));
|
||||
clearRowError(modelConfig.id);
|
||||
}}
|
||||
disabled={isThisModelMutating}
|
||||
/>
|
||||
</TooltipTrigger>
|
||||
{(isInvalid || isDraftDisablingCompaction) && (
|
||||
<TooltipContent>
|
||||
{isInvalid
|
||||
? "Enter a whole number between 0 and 100."
|
||||
: "Setting 100% will disable auto-compaction for this model."}
|
||||
</TooltipContent>
|
||||
)}
|
||||
</Tooltip>
|
||||
<span className="text-xs text-content-secondary">%</span>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<Button
|
||||
size="icon"
|
||||
variant="subtle"
|
||||
className={cn(
|
||||
"size-7",
|
||||
hasOverride
|
||||
? "opacity-100"
|
||||
: "pointer-events-none opacity-0",
|
||||
)}
|
||||
aria-label={`Reset ${modelName} to default`}
|
||||
aria-hidden={!hasOverride}
|
||||
tabIndex={hasOverride ? 0 : -1}
|
||||
disabled={isThisModelMutating || !hasOverride}
|
||||
onClick={() => handleReset(modelConfig.id)}
|
||||
>
|
||||
<RotateCcwIcon className="size-3.5" />
|
||||
</Button>
|
||||
</TooltipTrigger>
|
||||
{hasOverride && (
|
||||
<TooltipContent>
|
||||
Reset to default (
|
||||
{modelConfig.compression_threshold}%)
|
||||
</TooltipContent>
|
||||
)}
|
||||
</Tooltip>
|
||||
</div>
|
||||
{isInvalid && (
|
||||
<span className="sr-only" aria-live="polite">
|
||||
Enter a whole number between 0 and 100.
|
||||
</span>
|
||||
)}
|
||||
{isDraftDisablingCompaction && (
|
||||
<span className="sr-only" aria-live="polite">
|
||||
Setting 100% will disable auto-compaction for this
|
||||
model.
|
||||
</span>
|
||||
)}
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
);
|
||||
})}
|
||||
</TableBody>
|
||||
{(dirtyRows.length > 0 || hasAnyErrors || hasAnyDrafts) && (
|
||||
<TableFooter className="bg-transparent">
|
||||
<TableRow className="border-0">
|
||||
<TableCell colSpan={3} className="border-0 p-0">
|
||||
<div className="flex items-center justify-end gap-2 px-3 py-1.5">
|
||||
<Button
|
||||
size="sm"
|
||||
variant="outline"
|
||||
type="button"
|
||||
onClick={handleCancelAll}
|
||||
disabled={hasAnyPending}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
{dirtyRows.length > 0 && (
|
||||
<Button
|
||||
size="sm"
|
||||
type="button"
|
||||
disabled={hasAnyPending}
|
||||
onClick={handleSaveAll}
|
||||
>
|
||||
{hasAnyPending
|
||||
? "Saving..."
|
||||
: `Save ${dirtyRows.length} ${dirtyRows.length === 1 ? "change" : "changes"}`}
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
</TableFooter>
|
||||
)}
|
||||
</Table>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
import type { DateRangeValue } from "#/components/DateRangePicker/DateRangePicker";
|
||||
|
||||
/**
|
||||
* Returns true when the given date falls exactly on midnight
|
||||
* (00:00:00.000). Date-range pickers use midnight of the *following*
|
||||
* day as the exclusive upper bound for a full-day selection. Detecting
|
||||
* this lets call sites subtract 1 ms (or 1 day) so the UI shows the
|
||||
* inclusive end date instead.
|
||||
*/
|
||||
function isMidnight(date: Date): boolean {
|
||||
return (
|
||||
date.getHours() === 0 &&
|
||||
date.getMinutes() === 0 &&
|
||||
date.getSeconds() === 0 &&
|
||||
date.getMilliseconds() === 0
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* When the user picks an explicit date range whose end boundary is
|
||||
* midnight of the following day, adjust it by −1 ms so the
|
||||
* DateRangePicker highlights the inclusive end date.
|
||||
*/
|
||||
export function toInclusiveDateRange(
|
||||
dateRange: DateRangeValue,
|
||||
hasExplicitDateRange: boolean,
|
||||
): DateRangeValue {
|
||||
if (hasExplicitDateRange && isMidnight(dateRange.endDate)) {
|
||||
return {
|
||||
startDate: dateRange.startDate,
|
||||
endDate: new Date(dateRange.endDate.getTime() - 1),
|
||||
};
|
||||
}
|
||||
return dateRange;
|
||||
}
|
||||
@@ -71,6 +71,7 @@ describe("AuditPage", () => {
|
||||
const getAuditLogsSpy = vi.spyOn(API, "getAuditLogs").mockResolvedValue({
|
||||
audit_logs: [MockAuditLog, MockAuditLog2],
|
||||
count: 2,
|
||||
count_cap: 0,
|
||||
});
|
||||
|
||||
// When
|
||||
@@ -90,6 +91,7 @@ describe("AuditPage", () => {
|
||||
vi.spyOn(API, "getAuditLogs").mockResolvedValue({
|
||||
audit_logs: [MockAuditLog],
|
||||
count: 1,
|
||||
count_cap: 0,
|
||||
});
|
||||
|
||||
await renderPage();
|
||||
@@ -114,6 +116,7 @@ describe("AuditPage", () => {
|
||||
vi.spyOn(API, "getAuditLogs").mockResolvedValue({
|
||||
audit_logs: [MockAuditLog],
|
||||
count: 1,
|
||||
count_cap: 0,
|
||||
});
|
||||
|
||||
await renderPage();
|
||||
@@ -140,9 +143,11 @@ describe("AuditPage", () => {
|
||||
|
||||
describe("Filtering", () => {
|
||||
it("filters by URL", async () => {
|
||||
const getAuditLogsSpy = vi
|
||||
.spyOn(API, "getAuditLogs")
|
||||
.mockResolvedValue({ audit_logs: [MockAuditLog], count: 1 });
|
||||
const getAuditLogsSpy = vi.spyOn(API, "getAuditLogs").mockResolvedValue({
|
||||
audit_logs: [MockAuditLog],
|
||||
count: 1,
|
||||
count_cap: 0,
|
||||
});
|
||||
|
||||
const query = "resource_type:workspace action:create";
|
||||
await renderPage({ filter: query });
|
||||
@@ -173,4 +178,29 @@ describe("AuditPage", () => {
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("Capped count", () => {
|
||||
it("shows capped count indicator and navigates to next page with correct offset", async () => {
|
||||
vi.spyOn(API, "getAuditLogs").mockResolvedValue({
|
||||
audit_logs: [MockAuditLog, MockAuditLog2],
|
||||
count: 2001,
|
||||
count_cap: 2000,
|
||||
});
|
||||
|
||||
const user = userEvent.setup();
|
||||
await renderPage();
|
||||
|
||||
await screen.findByText(/2,000\+/);
|
||||
|
||||
await user.click(screen.getByRole("button", { name: /next page/i }));
|
||||
|
||||
await waitFor(() =>
|
||||
expect(API.getAuditLogs).toHaveBeenLastCalledWith<[AuditLogsRequest]>({
|
||||
limit: DEFAULT_RECORDS_PER_PAGE,
|
||||
offset: DEFAULT_RECORDS_PER_PAGE,
|
||||
q: "",
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -69,6 +69,7 @@ describe("ConnectionLogPage", () => {
|
||||
MockDisconnectedSSHConnectionLog,
|
||||
],
|
||||
count: 2,
|
||||
count_cap: 0,
|
||||
});
|
||||
|
||||
// When
|
||||
@@ -95,6 +96,7 @@ describe("ConnectionLogPage", () => {
|
||||
.mockResolvedValue({
|
||||
connection_logs: [MockConnectedSSHConnectionLog],
|
||||
count: 1,
|
||||
count_cap: 0,
|
||||
});
|
||||
|
||||
const query = "type:ssh status:ongoing";
|
||||
|
||||
Reference in New Issue
Block a user