Compare commits

..

2 Commits

Author SHA1 Message Date
Lukasz c174b3037b Merge branch 'main' into security-patch-train-doc 2026-04-07 16:31:36 +02:00
Lukasz f5165d304f ci(.github): automate security patch PRs and backports 2026-04-07 16:27:25 +02:00
236 changed files with 3890 additions and 12718 deletions
+2
View File
@@ -0,0 +1,2 @@
# Two-line configuration file added by this commit (filename not shown in
# this capture). NOTE(review): consumer unknown from this view —
# `preservePullRequestTitle` suggests an auto-merge / squash-merge settings
# file; confirm which tool reads it before editing.
enabled: true
preservePullRequestTitle: true
+4 -3
View File
@@ -31,7 +31,8 @@ updates:
patterns:
- "golang.org/x/*"
ignore:
# Ignore patch updates for all dependencies
# Patch updates are handled by the security-patch-prs workflow so this
# lane stays focused on broader dependency updates.
- dependency-name: "*"
update-types:
- version-update:semver-patch
@@ -56,7 +57,7 @@ updates:
labels: []
ignore:
# We need to coordinate terraform updates with the version hardcoded in
# our Go code.
# our Go code. These are handled by the security-patch-prs workflow.
- dependency-name: "terraform"
- package-ecosystem: "npm"
@@ -117,11 +118,11 @@ updates:
interval: "weekly"
commit-message:
prefix: "chore"
labels: []
groups:
coder-modules:
patterns:
- "coder/*/coder"
labels: []
ignore:
- dependency-name: "*"
update-types:
-174
View File
@@ -1,174 +0,0 @@
# NOTE(review): captured from a diff view — the original file's indentation
# was stripped, so this block is not valid YAML as-is; it is reproduced
# byte-for-byte below. This file is DELETED by this commit (-174 lines),
# presumably superseded by the security-backport workflow — confirm.
# Automatically backport merged PRs to the last N release branches when the
# "backport" label is applied. Works whether the label is added before or
# after the PR is merged.
#
# Usage:
# 1. Add the "backport" label to a PR targeting main.
# 2. When the PR merges (or if already merged), the workflow detects the
# latest release/* branches and opens one cherry-pick PR per branch.
#
# The created backport PRs follow existing repo conventions:
# - Branch: backport/<pr>-to-<version>
# - Title: <original PR title> (#<pr>)
# - Body: links back to the original PR and merge commit
name: Backport
on:
pull_request_target:
branches:
- main
types:
- closed
- labeled
permissions:
contents: write
pull-requests: write
# Prevent duplicate runs for the same PR when both 'closed' and 'labeled'
# fire in quick succession.
concurrency:
group: backport-${{ github.event.pull_request.number }}
# NOTE(review): no cancel-in-progress set, so duplicate events queue rather
# than cancel; the branch/PR existence checks below make re-runs idempotent.
jobs:
detect:
name: Detect target branches
if: >
github.event.pull_request.merged == true &&
contains(github.event.pull_request.labels.*.name, 'backport')
runs-on: ubuntu-latest
outputs:
branches: ${{ steps.find.outputs.branches }}
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
# Need all refs to discover release branches.
fetch-depth: 0
- name: Find latest release branches
id: find
run: |
# List remote release branches matching the exact release/2.X
# pattern (no suffixes like release/2.31_hotfix), sort by minor
# version descending, and take the top 3.
BRANCHES=$(
git branch -r \
| grep -E '^\s*origin/release/2\.[0-9]+$' \
| sed 's|.*origin/||' \
| sort -t. -k2 -n -r \
| head -3
)
if [ -z "$BRANCHES" ]; then
echo "No release branches found."
echo "branches=[]" >> "$GITHUB_OUTPUT"
exit 0
fi
# Convert to JSON array for the matrix.
JSON=$(echo "$BRANCHES" | jq -Rnc '[inputs | select(length > 0)]')
echo "branches=$JSON" >> "$GITHUB_OUTPUT"
echo "Will backport to: $JSON"
backport:
name: "Backport to ${{ matrix.branch }}"
needs: detect
if: needs.detect.outputs.branches != '[]'
runs-on: ubuntu-latest
strategy:
matrix:
branch: ${{ fromJson(needs.detect.outputs.branches) }}
fail-fast: false
env:
PR_NUMBER: ${{ github.event.pull_request.number }}
PR_TITLE: ${{ github.event.pull_request.title }}
PR_URL: ${{ github.event.pull_request.html_url }}
MERGE_SHA: ${{ github.event.pull_request.merge_commit_sha }}
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
# Full history required for cherry-pick.
fetch-depth: 0
- name: Cherry-pick and open PR
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
set -euo pipefail
RELEASE_VERSION="${{ matrix.branch }}"
# Strip the release/ prefix for naming.
VERSION="${RELEASE_VERSION#release/}"
BACKPORT_BRANCH="backport/${PR_NUMBER}-to-${VERSION}"
git config user.name "github-actions[bot]"
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
# Check if backport branch already exists (idempotency for re-runs).
if git ls-remote --exit-code origin "refs/heads/${BACKPORT_BRANCH}" >/dev/null 2>&1; then
echo "Backport branch ${BACKPORT_BRANCH} already exists, skipping."
exit 0
fi
# Create the backport branch from the target release branch.
git checkout -b "$BACKPORT_BRANCH" "origin/${RELEASE_VERSION}"
# Cherry-pick the merge commit. Use -x to record provenance and
# -m1 to pick the first parent (the main branch side).
CONFLICTS=false
if ! git cherry-pick -x -m1 "$MERGE_SHA"; then
echo "::warning::Cherry-pick to ${RELEASE_VERSION} had conflicts."
CONFLICTS=true
# Abort the failed cherry-pick and create an empty commit
# explaining the situation.
git cherry-pick --abort
git commit --allow-empty -m "Cherry-pick of #${PR_NUMBER} requires manual resolution
The automatic cherry-pick of ${MERGE_SHA} to ${RELEASE_VERSION} had conflicts.
Please cherry-pick manually:
git cherry-pick -x -m1 ${MERGE_SHA}"
fi
git push origin "$BACKPORT_BRANCH"
TITLE="${PR_TITLE} (#${PR_NUMBER})"
BODY=$(cat <<EOF
Backport of ${PR_URL}
Original PR: #${PR_NUMBER} — ${PR_TITLE}
Merge commit: ${MERGE_SHA}
EOF
)
if [ "$CONFLICTS" = true ]; then
TITLE="${TITLE} (conflicts)"
BODY="${BODY}
> [!WARNING]
> The automatic cherry-pick had conflicts.
> Please resolve manually by cherry-picking the original merge commit:
>
> \`\`\`
> git fetch origin ${BACKPORT_BRANCH}
> git checkout ${BACKPORT_BRANCH}
> git reset --hard origin/${RELEASE_VERSION}
> git cherry-pick -x -m1 ${MERGE_SHA}
> # resolve conflicts, then push
> \`\`\`"
fi
# Check if a PR already exists for this branch (idempotency
# for re-runs).
EXISTING_PR=$(gh pr list --head "$BACKPORT_BRANCH" --base "$RELEASE_VERSION" --state all --json number --jq '.[0].number // empty')
if [ -n "$EXISTING_PR" ]; then
echo "PR #${EXISTING_PR} already exists for ${BACKPORT_BRANCH}, skipping."
exit 0
fi
# NOTE(review): PRs created with the default GITHUB_TOKEN do not trigger
# other workflows — confirm these backport PRs get CI some other way.
gh pr create \
--base "$RELEASE_VERSION" \
--head "$BACKPORT_BRANCH" \
--title "$TITLE" \
--body "$BODY"
-139
View File
@@ -1,139 +0,0 @@
# NOTE(review): captured from a diff view — the original file's indentation
# was stripped, so this block is not valid YAML as-is; it is reproduced
# byte-for-byte below. This file is DELETED by this commit (-139 lines),
# presumably superseded by the security-backport workflow — confirm.
# Automatically cherry-pick merged PRs to the latest release branch when the
# "cherry-pick" label is applied. Works whether the label is added before or
# after the PR is merged.
#
# Usage:
# 1. Add the "cherry-pick" label to a PR targeting main.
# 2. When the PR merges (or if already merged), the workflow detects the
# latest release/* branch and opens a cherry-pick PR against it.
#
# The created PRs follow existing repo conventions:
# - Branch: backport/<pr>-to-<version>
# - Title: <original PR title> (#<pr>)
# - Body: links back to the original PR and merge commit
name: Cherry-pick to release
on:
pull_request_target:
branches:
- main
types:
- closed
- labeled
permissions:
contents: write
pull-requests: write
# Prevent duplicate runs for the same PR when both 'closed' and 'labeled'
# fire in quick succession.
concurrency:
group: cherry-pick-${{ github.event.pull_request.number }}
jobs:
cherry-pick:
name: Cherry-pick to latest release
if: >
github.event.pull_request.merged == true &&
contains(github.event.pull_request.labels.*.name, 'cherry-pick')
runs-on: ubuntu-latest
env:
PR_NUMBER: ${{ github.event.pull_request.number }}
PR_TITLE: ${{ github.event.pull_request.title }}
PR_URL: ${{ github.event.pull_request.html_url }}
MERGE_SHA: ${{ github.event.pull_request.merge_commit_sha }}
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
# Full history required for cherry-pick and branch discovery.
fetch-depth: 0
- name: Cherry-pick and open PR
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
set -euo pipefail
# Find the latest release branch matching the exact release/2.X
# pattern (no suffixes like release/2.31_hotfix).
RELEASE_BRANCH=$(
git branch -r \
| grep -E '^\s*origin/release/2\.[0-9]+$' \
| sed 's|.*origin/||' \
| sort -t. -k2 -n -r \
| head -1
)
if [ -z "$RELEASE_BRANCH" ]; then
echo "::error::No release branch found."
exit 1
fi
# Strip the release/ prefix for naming.
VERSION="${RELEASE_BRANCH#release/}"
# NOTE(review): same backport/<pr>-to-<version> branch naming as the
# Backport workflow — a PR carrying both the 'backport' and 'cherry-pick'
# labels would race on the same branch name.
BACKPORT_BRANCH="backport/${PR_NUMBER}-to-${VERSION}"
echo "Target branch: $RELEASE_BRANCH"
echo "Backport branch: $BACKPORT_BRANCH"
# Check if backport branch already exists (idempotency for re-runs).
if git ls-remote --exit-code origin "refs/heads/${BACKPORT_BRANCH}" >/dev/null 2>&1; then
echo "Branch ${BACKPORT_BRANCH} already exists, skipping."
exit 0
fi
git config user.name "github-actions[bot]"
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
# Create the backport branch from the target release branch.
git checkout -b "$BACKPORT_BRANCH" "origin/${RELEASE_BRANCH}"
# Cherry-pick the merge commit. Use -x to record provenance and
# -m1 to pick the first parent (the main branch side).
CONFLICT=false
if ! git cherry-pick -x -m1 "$MERGE_SHA"; then
CONFLICT=true
echo "::warning::Cherry-pick to ${RELEASE_BRANCH} had conflicts."
# Abort the failed cherry-pick and create an empty commit with
# instructions so the PR can still be opened.
git cherry-pick --abort
git commit --allow-empty -m "cherry-pick of #${PR_NUMBER} failed — resolve conflicts manually
Cherry-pick of ${MERGE_SHA} onto ${RELEASE_BRANCH} had conflicts.
To resolve:
git fetch origin ${BACKPORT_BRANCH}
git checkout ${BACKPORT_BRANCH}
git cherry-pick -x -m1 ${MERGE_SHA}
# resolve conflicts
git push origin ${BACKPORT_BRANCH}"
fi
git push origin "$BACKPORT_BRANCH"
BODY=$(cat <<EOF
Cherry-pick of ${PR_URL}
Original PR: #${PR_NUMBER} — ${PR_TITLE}
Merge commit: ${MERGE_SHA}
EOF
)
TITLE="${PR_TITLE} (#${PR_NUMBER})"
if [ "$CONFLICT" = true ]; then
TITLE="[CONFLICT] ${TITLE}"
fi
# Check if a PR already exists for this branch (idempotency
# for re-runs). Use --state all to catch closed/merged PRs too.
EXISTING_PR=$(gh pr list --head "$BACKPORT_BRANCH" --base "$RELEASE_BRANCH" --state all --json number --jq '.[0].number // empty')
if [ -n "$EXISTING_PR" ]; then
echo "PR #${EXISTING_PR} already exists for ${BACKPORT_BRANCH}, skipping."
exit 0
fi
gh pr create \
--base "$RELEASE_BRANCH" \
--head "$BACKPORT_BRANCH" \
--title "$TITLE" \
--body "$BODY"
+13 -13
View File
@@ -121,22 +121,22 @@ jobs:
fi
# Derive the release branch from the version tag.
# Non-RC releases must be on a release/X.Y branch.
# RC tags are allowed on any branch (typically main).
# Standard: 2.10.2 -> release/2.10
# RC: 2.32.0-rc.0 -> release/2.32-rc.0
version="$(./scripts/version.sh)"
# Strip any pre-release suffix first (e.g. 2.32.0-rc.0 -> 2.32.0)
base_version="${version%%-*}"
# Then strip patch to get major.minor (e.g. 2.32.0 -> 2.32)
release_branch="release/${base_version%.*}"
if [[ "$version" == *-rc.* ]]; then
echo "RC release detected — skipping release branch check (RC tags are cut from main)."
# Extract major.minor and rc suffix from e.g. 2.32.0-rc.0
base_version="${version%%-rc.*}" # 2.32.0
major_minor="${base_version%.*}" # 2.32
rc_suffix="${version##*-rc.}" # 0
release_branch="release/${major_minor}-rc.${rc_suffix}"
else
branch_contains_tag=$(git branch --remotes --contains "${GITHUB_REF}" --list "*/${release_branch}" --format='%(refname)')
if [[ -z "${branch_contains_tag}" ]]; then
echo "Ref tag must exist in a branch named ${release_branch} when creating a non-RC release, did you use scripts/release.sh?"
exit 1
fi
release_branch=release/${version%.*}
fi
branch_contains_tag=$(git branch --remotes --contains "${GITHUB_REF}" --list "*/${release_branch}" --format='%(refname)')
if [[ -z "${branch_contains_tag}" ]]; then
echo "Ref tag must exist in a branch named ${release_branch} when creating a release, did you use scripts/release.sh?"
exit 1
fi
if [[ -z "${CODER_RELEASE_NOTES}" ]]; then
+354
View File
@@ -0,0 +1,354 @@
# Security patch backport automation.
#
# Two responsibilities:
#   1. label-policy: enforce the labeling contract on PRs labeled
#      `security:patch` — exactly one severity label, and default backport
#      targets unless `backport:none` is present.
#   2. backport: after merge (or via manual dispatch), cherry-pick the merge
#      commit onto the configured release branches and open one PR per target.
#
# NOTE(review): this block was captured from a diff view with indentation
# stripped; the structure below was restored from the Actions schema and the
# shell/Python syntax — verify against the original file before committing.
name: security-backport

on:
  # pull_request_target runs with a write token against the BASE repository.
  # The backport job checks out the base repo only (no PR head code is ever
  # executed), which is what makes the write token acceptable here.
  pull_request_target:
    types:
      - labeled
      - unlabeled
      - closed
  workflow_dispatch:
    inputs:
      pull_request:
        description: Pull request number to backport.
        required: true
        type: string

permissions:
  contents: write
  pull-requests: write
  issues: write

# One run per PR at a time; queued rather than cancelled so a labeling event
# cannot kill an in-flight backport.
concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || inputs.pull_request }}
  cancel-in-progress: false

# Release-train branch names. NOTE(review): hard-coded — these must be bumped
# every release cycle; confirm that is part of the release checklist.
env:
  LATEST_BRANCH: release/2.31
  STABLE_BRANCH: release/2.30
  STABLE_1_BRANCH: release/2.29

jobs:
  # Keeps security-PR labels consistent on every label/unlabel/close event.
  label-policy:
    if: github.event_name == 'pull_request_target'
    runs-on: ubuntu-latest
    steps:
      - name: Harden Runner
        uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
        with:
          egress-policy: audit
      - name: Apply security backport label policy
        shell: bash
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          # Fix: this job performs no checkout, so there is no local git repo
          # for gh to infer a target from — point it at the repo explicitly.
          GH_REPO: ${{ github.repository }}
        run: |
          set -euo pipefail
          pr_number="$(jq -r '.pull_request.number' "${GITHUB_EVENT_PATH}")"
          pr_json="$(gh pr view "${pr_number}" --json number,title,url,baseRefName,labels)"
          PR_JSON="${pr_json}" \
          python3 - <<'PY'
          import json
          import os
          import subprocess
          import sys

          pr = json.loads(os.environ["PR_JSON"])
          pr_number = pr["number"]
          labels = [label["name"] for label in pr.get("labels", [])]

          def has(label: str) -> bool:
              return label in labels

          def ensure_label(label: str) -> None:
              # check=False: label edits are best-effort; policy decisions
              # below must not die on a transient gh failure.
              if not has(label):
                  subprocess.run(
                      ["gh", "pr", "edit", str(pr_number), "--add-label", label],
                      check=False,
                  )

          def remove_label(label: str) -> None:
              if has(label):
                  subprocess.run(
                      ["gh", "pr", "edit", str(pr_number), "--remove-label", label],
                      check=False,
                  )

          def comment(body: str) -> None:
              subprocess.run(
                  ["gh", "pr", "comment", str(pr_number), "--body", body],
                  check=True,
              )

          # Non-security PRs: just make sure the marker label is gone.
          if not has("security:patch"):
              remove_label("status:needs-severity")
              sys.exit(0)

          severity_labels = [
              label
              for label in ("severity:medium", "severity:high", "severity:critical")
              if has(label)
          ]
          if len(severity_labels) == 0:
              ensure_label("status:needs-severity")
              comment(
                  "This PR is labeled `security:patch` but is missing a severity "
                  "label. Add one of `severity:medium`, `severity:high`, or "
                  "`severity:critical` before backport automation can proceed."
              )
              sys.exit(0)
          if len(severity_labels) > 1:
              comment(
                  "This PR has multiple severity labels. Keep exactly one of "
                  "`severity:medium`, `severity:high`, or `severity:critical`."
              )
              sys.exit(1)
          remove_label("status:needs-severity")

          target_labels = [
              label
              for label in ("backport:stable", "backport:stable-1")
              if has(label)
          ]
          has_none = has("backport:none")
          if has_none and target_labels:
              comment(
                  "`backport:none` cannot be combined with other backport labels. "
                  "Remove `backport:none` or remove the explicit backport targets."
              )
              sys.exit(1)
          # Default: a qualifying security patch backports to both trains.
          if not has_none and not target_labels:
              ensure_label("backport:stable")
              ensure_label("backport:stable-1")
              comment(
                  "Applied default backport labels `backport:stable` and "
                  "`backport:stable-1` for a qualifying security patch."
              )
          PY

  backport:
    if: >
      github.event_name == 'workflow_dispatch' ||
      (
        github.event_name == 'pull_request_target' &&
        github.event.pull_request.merged == true
      )
    runs-on: ubuntu-latest
    steps:
      - name: Harden Runner
        uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
        with:
          egress-policy: audit
      - name: Checkout
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
        with:
          # Full history is required so the merge commit can be cherry-picked
          # onto release branches.
          fetch-depth: 0
          # Credentials are injected explicitly via the remote URL below.
          persist-credentials: false
      - name: Resolve PR metadata
        id: metadata
        shell: bash
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          INPUT_PR_NUMBER: ${{ inputs.pull_request }}
          LATEST_BRANCH: ${{ env.LATEST_BRANCH }}
          STABLE_BRANCH: ${{ env.STABLE_BRANCH }}
          STABLE_1_BRANCH: ${{ env.STABLE_1_BRANCH }}
        run: |
          set -euo pipefail
          if [ "${GITHUB_EVENT_NAME}" = "workflow_dispatch" ]; then
            pr_number="${INPUT_PR_NUMBER}"
          else
            pr_number="$(jq -r '.pull_request.number' "${GITHUB_EVENT_PATH}")"
          fi
          # Reject empty or non-numeric PR numbers (workflow_dispatch input is
          # free-form text).
          case "${pr_number}" in
            ''|*[!0-9]*)
              echo "A valid pull request number is required."
              exit 1
              ;;
          esac
          pr_json="$(gh pr view "${pr_number}" --json number,title,url,mergeCommit,baseRefName,labels,mergedAt,author)"
          PR_JSON="${pr_json}" \
          python3 - <<'PY'
          import json
          import os
          import sys

          pr = json.loads(os.environ["PR_JSON"])
          github_output = os.environ["GITHUB_OUTPUT"]
          labels = [label["name"] for label in pr.get("labels", [])]

          # Early exit WITHOUT writing outputs: downstream step must treat
          # empty outputs as "skip" (see the step condition below).
          if "security:patch" not in labels:
              print("Not a security patch PR; skipping.")
              sys.exit(0)

          severity_labels = [
              label
              for label in ("severity:medium", "severity:high", "severity:critical")
              if label in labels
          ]
          if len(severity_labels) != 1:
              raise SystemExit(
                  "Merged security patch PR must have exactly one severity label."
              )
          if not pr.get("mergedAt"):
              raise SystemExit(f"PR #{pr['number']} is not merged.")

          if "backport:none" in labels:
              target_pairs = []
          else:
              mapping = {
                  "backport:stable": os.environ["STABLE_BRANCH"],
                  "backport:stable-1": os.environ["STABLE_1_BRANCH"],
              }
              target_pairs = []
              for label_name, branch in mapping.items():
                  # Never backport onto the branch the PR already targeted.
                  if label_name in labels and branch and branch != pr["baseRefName"]:
                      target_pairs.append({"label": label_name, "branch": branch})

          with open(github_output, "a", encoding="utf-8") as f:
              f.write(f"pr_number={pr['number']}\n")
              f.write(f"merge_sha={pr['mergeCommit']['oid']}\n")
              f.write(f"title={pr['title']}\n")
              f.write(f"url={pr['url']}\n")
              f.write(f"author={pr['author']['login']}\n")
              f.write(f"severity_label={severity_labels[0]}\n")
              f.write(f"target_pairs={json.dumps(target_pairs)}\n")
          PY
      - name: Backport to release branches
        # Fix: also skip on empty string. When the PR is not a security patch
        # the metadata script exits before writing any outputs, so
        # target_pairs is "" (not "[]") and the original `!= '[]'` test alone
        # would run this step and fail on empty TARGET_PAIRS / MERGE_SHA.
        if: ${{ steps.metadata.outputs.target_pairs != '' && steps.metadata.outputs.target_pairs != '[]' }}
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          PR_NUMBER: ${{ steps.metadata.outputs.pr_number }}
          MERGE_SHA: ${{ steps.metadata.outputs.merge_sha }}
          PR_TITLE: ${{ steps.metadata.outputs.title }}
          PR_URL: ${{ steps.metadata.outputs.url }}
          # NOTE(review): PR_AUTHOR is exported but not referenced below.
          PR_AUTHOR: ${{ steps.metadata.outputs.author }}
          SEVERITY_LABEL: ${{ steps.metadata.outputs.severity_label }}
          TARGET_PAIRS: ${{ steps.metadata.outputs.target_pairs }}
        run: |
          set -euo pipefail
          git config user.name "github-actions[bot]"
          git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
          # Checkout ran with persist-credentials: false, so inject the token
          # into the push URL explicitly.
          git remote set-url origin "https://x-access-token:${GH_TOKEN}@github.com/${GITHUB_REPOSITORY}.git"
          git fetch origin --prune
          # Merge commits need -m 1 (pick the mainline side); squash/rebase
          # merges have a single parent and must not pass -m.
          merge_parent_count="$(git rev-list --parents -n 1 "${MERGE_SHA}" | awk '{print NF-1}')"
          failures=()
          successes=()
          while IFS=$'\t' read -r backport_label target_branch; do
            [ -n "${target_branch}" ] || continue
            safe_branch_name="${target_branch//\//-}"
            head_branch="backport/${safe_branch_name}/pr-${PR_NUMBER}"
            # Idempotency: reuse any PR (open or closed) already raised for
            # this head/base combination.
            existing_pr="$(gh pr list \
              --base "${target_branch}" \
              --head "${head_branch}" \
              --state all \
              --json number,url \
              --jq '.[0]')"
            if [ -n "${existing_pr}" ] && [ "${existing_pr}" != "null" ]; then
              pr_url="$(printf '%s' "${existing_pr}" | jq -r '.url')"
              successes+=("${target_branch}:existing:${pr_url}")
              continue
            fi
            git checkout -B "${head_branch}" "origin/${target_branch}"
            if [ "${merge_parent_count}" -gt 1 ]; then
              cherry_pick_args=(-m 1 "${MERGE_SHA}")
            else
              cherry_pick_args=("${MERGE_SHA}")
            fi
            if ! git cherry-pick -x "${cherry_pick_args[@]}"; then
              # Conflict: surface on the source PR and keep going with the
              # remaining targets; the job fails at the end.
              git cherry-pick --abort || true
              gh pr edit "${PR_NUMBER}" --add-label "backport:conflict" || true
              gh pr comment "${PR_NUMBER}" --body \
                "Automatic backport to \`${target_branch}\` conflicted. The original author or release manager should resolve it manually."
              failures+=("${target_branch}:cherry-pick failed")
              continue
            fi
            git push --force-with-lease origin "${head_branch}"
            body_file="$(mktemp)"
            printf '%s\n' \
              "Automated backport of [#${PR_NUMBER}](${PR_URL})." \
              "" \
              "- Source PR: #${PR_NUMBER}" \
              "- Source commit: ${MERGE_SHA}" \
              "- Target branch: ${target_branch}" \
              "- Severity: ${SEVERITY_LABEL}" \
              > "${body_file}"
            pr_url="$(gh pr create \
              --base "${target_branch}" \
              --head "${head_branch}" \
              --title "${PR_TITLE} (backport to ${target_branch})" \
              --body-file "${body_file}")"
            backport_pr_number="$(gh pr list \
              --base "${target_branch}" \
              --head "${head_branch}" \
              --state open \
              --json number \
              --jq '.[0].number')"
            gh pr edit "${backport_pr_number}" \
              --add-label "security:patch" \
              --add-label "${SEVERITY_LABEL}" \
              --add-label "${backport_label}" || true
            successes+=("${target_branch}:created:${pr_url}")
          done < <(
            # Emit "label<TAB>branch" lines for the loop above.
            python3 - <<'PY'
          import json
          import os

          for pair in json.loads(os.environ["TARGET_PAIRS"]):
              print(f"{pair['label']}\t{pair['branch']}")
          PY
          )
          # Publish a summary both to the job summary and as a PR comment.
          summary_file="$(mktemp)"
          {
            echo "## Security backport summary"
            echo
            if [ "${#successes[@]}" -gt 0 ]; then
              echo "### Created or existing"
              for entry in "${successes[@]}"; do
                echo "- ${entry}"
              done
              echo
            fi
            if [ "${#failures[@]}" -gt 0 ]; then
              echo "### Failures"
              for entry in "${failures[@]}"; do
                echo "- ${entry}"
              done
            fi
          } | tee -a "${GITHUB_STEP_SUMMARY}" > "${summary_file}"
          gh pr comment "${PR_NUMBER}" --body-file "${summary_file}"
          if [ "${#failures[@]}" -gt 0 ]; then
            printf 'Backport failures:\n%s\n' "${failures[@]}" >&2
            exit 1
          fi
+214
View File
@@ -0,0 +1,214 @@
# Scheduled automation that opens patch-level security update PRs against main.
# Two independent lanes run as a matrix:
#   - gomod:     `go get -u=patch` across all modules (go.mod/go.sum only)
#   - terraform: bump the bundled Terraform to the latest patch in its series
#
# NOTE(review): this block was captured from a diff view with indentation
# stripped; the structure below was restored from the Actions schema and the
# shell/Python syntax — verify against the original file before committing.
name: security-patch-prs

on:
  workflow_dispatch:
  schedule:
    # Weekday mornings, 03:00 UTC.
    - cron: "0 3 * * 1-5"

permissions:
  contents: write
  pull-requests: write

jobs:
  patch:
    strategy:
      fail-fast: false
      matrix:
        lane:
          - gomod
          - terraform
    runs-on: ubuntu-latest
    steps:
      - name: Harden Runner
        uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
        with:
          egress-policy: audit
      - name: Checkout
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
        with:
          # Pushes below authenticate via an explicit token-bearing URL.
          persist-credentials: false
      - name: Setup Go
        uses: ./.github/actions/setup-go
      - name: Patch Go dependencies
        if: matrix.lane == 'gomod'
        shell: bash
        run: |
          set -euo pipefail
          go get -u=patch ./...
          go mod tidy
          # Guardrail: do not auto-edit replace directives.
          if git diff --unified=0 -- go.mod | grep -E '^[+-]replace '; then
            echo "Refusing to auto-edit go.mod replace directives"
            exit 1
          fi
          # Guardrail: only go.mod / go.sum may change.
          extra="$(git diff --name-only | grep -Ev '^(go\.mod|go\.sum)$' || true)"
          test -z "$extra" || { echo "Unexpected files changed:"; echo "$extra"; exit 1; }
      - name: Patch bundled Terraform
        if: matrix.lane == 'terraform'
        shell: bash
        run: |
          set -euo pipefail
          # Current pinned version, parsed out of the Go installer source.
          current="$(
            grep -oE 'NewVersion\("[0-9]+\.[0-9]+\.[0-9]+"\)' \
              provisioner/terraform/install.go \
              | head -1 \
              | grep -oE '[0-9]+\.[0-9]+\.[0-9]+'
          )"
          # Stay within the same major.minor series; patch bumps only.
          series="$(echo "$current" | cut -d. -f1,2)"
          latest="$(
            curl -fsSL https://releases.hashicorp.com/terraform/index.json \
              | jq -r --arg series "$series" '
                  .versions
                  | keys[]
                  | select(startswith($series + "."))
                ' \
              | sort -V \
              | tail -1
          )"
          test -n "$latest"
          # Already on the newest patch — nothing to do. NOTE(review): later
          # steps (tests, docker build) still run in this case; only the
          # has_changes gate below prevents a pointless PR.
          [ "$latest" != "$current" ] || exit 0
          # Rewrite every location where the Terraform version is pinned,
          # failing loudly if any expected marker text is missing.
          CURRENT_TERRAFORM_VERSION="$current" \
          LATEST_TERRAFORM_VERSION="$latest" \
          python3 - <<'PY'
          from pathlib import Path
          import os

          current = os.environ["CURRENT_TERRAFORM_VERSION"]
          latest = os.environ["LATEST_TERRAFORM_VERSION"]
          updates = {
              "scripts/Dockerfile.base": (
                  f"terraform/{current}/",
                  f"terraform/{latest}/",
              ),
              "provisioner/terraform/install.go": (
                  f'NewVersion("{current}")',
                  f'NewVersion("{latest}")',
              ),
              "install.sh": (
                  f'TERRAFORM_VERSION="{current}"',
                  f'TERRAFORM_VERSION="{latest}"',
              ),
          }
          for path_str, (before, after) in updates.items():
              path = Path(path_str)
              content = path.read_text()
              if before not in content:
                  raise SystemExit(f"did not find expected text in {path_str}: {before}")
              path.write_text(content.replace(before, after))
          PY
          # Guardrail: only the Terraform-version files may change.
          extra="$(git diff --name-only | grep -Ev '^(scripts/Dockerfile.base|provisioner/terraform/install.go|install.sh)$' || true)"
          test -z "$extra" || { echo "Unexpected files changed:"; echo "$extra"; exit 1; }
      - name: Validate Go dependency patch
        if: matrix.lane == 'gomod'
        shell: bash
        run: |
          set -euo pipefail
          go test ./...
      - name: Validate Terraform patch
        if: matrix.lane == 'terraform'
        shell: bash
        run: |
          set -euo pipefail
          go test ./provisioner/terraform/...
          docker build -f scripts/Dockerfile.base .
      - name: Skip PR creation when there are no changes
        id: changes
        shell: bash
        run: |
          set -euo pipefail
          # NOTE: `git diff --quiet` only sees tracked files; both lanes edit
          # tracked files only, so that is sufficient here.
          if git diff --quiet; then
            echo "has_changes=false" >> "${GITHUB_OUTPUT}"
          else
            echo "has_changes=true" >> "${GITHUB_OUTPUT}"
          fi
      - name: Commit changes
        if: steps.changes.outputs.has_changes == 'true'
        shell: bash
        run: |
          set -euo pipefail
          git config user.name "github-actions[bot]"
          git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
          git checkout -B "secpatch/${{ matrix.lane }}"
          git add -A
          git commit -m "security: patch ${{ matrix.lane }}"
      - name: Push branch
        if: steps.changes.outputs.has_changes == 'true'
        shell: bash
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        run: |
          set -euo pipefail
          # NOTE(review): --force-with-lease against a bare URL has no
          # remote-tracking ref to compare, so the lease may not protect (or
          # may refuse the push) if the branch already exists remotely —
          # confirm behavior on a re-run.
          git push --force-with-lease \
            "https://x-access-token:${GH_TOKEN}@github.com/${{ github.repository }}.git" \
            "HEAD:refs/heads/secpatch/${{ matrix.lane }}"
      - name: Create or update PR
        if: steps.changes.outputs.has_changes == 'true'
        shell: bash
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        run: |
          set -euo pipefail
          branch="secpatch/${{ matrix.lane }}"
          title="security: patch ${{ matrix.lane }}"
          # The heredoc is quoted for the shell; ${{ }} is still expanded by
          # the Actions runner before the script executes.
          body="$(cat <<'EOF'
          Automated security patch PR for `${{ matrix.lane }}`.

          Scope:
          - gomod: patch-level Go dependency updates only
          - terraform: bundled Terraform patch updates only

          Guardrails:
          - no application-code edits
          - no auto-editing of go.mod replace directives
          - CI must pass
          EOF
          )"
          # Idempotency: update the existing open PR for this lane if present.
          existing_pr="$(gh pr list --head "${branch}" --base main --json number --jq '.[0].number // empty')"
          if [[ -n "${existing_pr}" ]]; then
            gh pr edit "${existing_pr}" \
              --title "${title}" \
              --body "${body}"
            pr_number="${existing_pr}"
          else
            gh pr create \
              --base main \
              --head "${branch}" \
              --title "${title}" \
              --body "${body}"
            pr_number="$(gh pr list --head "${branch}" --base main --json number --jq '.[0].number // empty')"
          fi
          # Apply standard labels, but only those that exist in the repo.
          for label in security dependencies automated-pr; do
            if gh label list --json name --jq '.[].name' | grep -Fxq "${label}"; then
              gh pr edit "${pr_number}" --add-label "${label}"
            fi
          done
-1
View File
@@ -36,7 +36,6 @@ typ = "typ"
styl = "styl"
edn = "edn"
Inferrable = "Inferrable"
IIF = "IIF"
[files]
extend-exclude = [
-3
View File
@@ -103,6 +103,3 @@ PLAN.md
# Ignore any dev licenses
license.txt
-e
# Agent planning documents (local working files).
docs/plans/
+6 -14
View File
@@ -102,8 +102,6 @@ type Options struct {
ReportMetadataInterval time.Duration
ServiceBannerRefreshInterval time.Duration
BlockFileTransfer bool
BlockReversePortForwarding bool
BlockLocalPortForwarding bool
Execer agentexec.Execer
Devcontainers bool
DevcontainerAPIOptions []agentcontainers.Option // Enable Devcontainers for these to be effective.
@@ -216,8 +214,6 @@ func New(options Options) Agent {
subsystems: options.Subsystems,
logSender: agentsdk.NewLogSender(options.Logger),
blockFileTransfer: options.BlockFileTransfer,
blockReversePortForwarding: options.BlockReversePortForwarding,
blockLocalPortForwarding: options.BlockLocalPortForwarding,
prometheusRegistry: prometheusRegistry,
metrics: newAgentMetrics(prometheusRegistry),
@@ -284,8 +280,6 @@ type agent struct {
sshServer *agentssh.Server
sshMaxTimeout time.Duration
blockFileTransfer bool
blockReversePortForwarding bool
blockLocalPortForwarding bool
lifecycleUpdate chan struct{}
lifecycleReported chan codersdk.WorkspaceAgentLifecycle
@@ -337,14 +331,12 @@ func (a *agent) TailnetConn() *tailnet.Conn {
func (a *agent) init() {
// pass the "hard" context because we explicitly close the SSH server as part of graceful shutdown.
sshSrv, err := agentssh.NewServer(a.hardCtx, a.logger.Named("ssh-server"), a.prometheusRegistry, a.filesystem, a.execer, &agentssh.Config{
MaxTimeout: a.sshMaxTimeout,
MOTDFile: func() string { return a.manifest.Load().MOTDFile },
AnnouncementBanners: func() *[]codersdk.BannerConfig { return a.announcementBanners.Load() },
UpdateEnv: a.updateCommandEnv,
WorkingDirectory: func() string { return a.manifest.Load().Directory },
BlockFileTransfer: a.blockFileTransfer,
BlockReversePortForwarding: a.blockReversePortForwarding,
BlockLocalPortForwarding: a.blockLocalPortForwarding,
MaxTimeout: a.sshMaxTimeout,
MOTDFile: func() string { return a.manifest.Load().MOTDFile },
AnnouncementBanners: func() *[]codersdk.BannerConfig { return a.announcementBanners.Load() },
UpdateEnv: a.updateCommandEnv,
WorkingDirectory: func() string { return a.manifest.Load().Directory },
BlockFileTransfer: a.blockFileTransfer,
ReportConnection: func(id uuid.UUID, magicType agentssh.MagicSessionType, ip string) func(code int, reason string) {
var connectionType proto.Connection_Type
switch magicType {
-155
View File
@@ -986,161 +986,6 @@ func TestAgent_TCPRemoteForwarding(t *testing.T) {
requireEcho(t, conn)
}
func TestAgent_TCPLocalForwardingBlocked(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
rl, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer rl.Close()
tcpAddr, valid := rl.Addr().(*net.TCPAddr)
require.True(t, valid)
remotePort := tcpAddr.Port
//nolint:dogsled
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
o.BlockLocalPortForwarding = true
})
sshClient, err := agentConn.SSHClient(ctx)
require.NoError(t, err)
defer sshClient.Close()
_, err = sshClient.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", remotePort))
require.ErrorContains(t, err, "administratively prohibited")
}
func TestAgent_TCPRemoteForwardingBlocked(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
//nolint:dogsled
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
o.BlockReversePortForwarding = true
})
sshClient, err := agentConn.SSHClient(ctx)
require.NoError(t, err)
defer sshClient.Close()
localhost := netip.MustParseAddr("127.0.0.1")
randomPort := testutil.RandomPortNoListen(t)
addr := net.TCPAddrFromAddrPort(netip.AddrPortFrom(localhost, randomPort))
_, err = sshClient.ListenTCP(addr)
require.ErrorContains(t, err, "tcpip-forward request denied by peer")
}
func TestAgent_UnixLocalForwardingBlocked(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
t.Skip("unix domain sockets are not fully supported on Windows")
}
ctx := testutil.Context(t, testutil.WaitLong)
tmpdir := testutil.TempDirUnixSocket(t)
remoteSocketPath := filepath.Join(tmpdir, "remote-socket")
l, err := net.Listen("unix", remoteSocketPath)
require.NoError(t, err)
defer l.Close()
//nolint:dogsled
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
o.BlockLocalPortForwarding = true
})
sshClient, err := agentConn.SSHClient(ctx)
require.NoError(t, err)
defer sshClient.Close()
_, err = sshClient.Dial("unix", remoteSocketPath)
require.ErrorContains(t, err, "administratively prohibited")
}
func TestAgent_UnixRemoteForwardingBlocked(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
t.Skip("unix domain sockets are not fully supported on Windows")
}
ctx := testutil.Context(t, testutil.WaitLong)
tmpdir := testutil.TempDirUnixSocket(t)
remoteSocketPath := filepath.Join(tmpdir, "remote-socket")
//nolint:dogsled
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
o.BlockReversePortForwarding = true
})
sshClient, err := agentConn.SSHClient(ctx)
require.NoError(t, err)
defer sshClient.Close()
_, err = sshClient.ListenUnix(remoteSocketPath)
require.ErrorContains(t, err, "streamlocal-forward@openssh.com request denied by peer")
}
// TestAgent_LocalBlockedDoesNotAffectReverse verifies that blocking
// local port forwarding does not prevent reverse port forwarding from
// working. A field-name transposition at any plumbing hop would cause
// both directions to be blocked when only one flag is set.
func TestAgent_LocalBlockedDoesNotAffectReverse(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
//nolint:dogsled
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
o.BlockLocalPortForwarding = true
})
sshClient, err := agentConn.SSHClient(ctx)
require.NoError(t, err)
defer sshClient.Close()
// Reverse forwarding must still work.
localhost := netip.MustParseAddr("127.0.0.1")
var ll net.Listener
for {
randomPort := testutil.RandomPortNoListen(t)
addr := net.TCPAddrFromAddrPort(netip.AddrPortFrom(localhost, randomPort))
ll, err = sshClient.ListenTCP(addr)
if err != nil {
t.Logf("error remote forwarding: %s", err.Error())
select {
case <-ctx.Done():
t.Fatal("timed out getting random listener")
default:
continue
}
}
break
}
_ = ll.Close()
}
// TestAgent_ReverseBlockedDoesNotAffectLocal verifies that blocking
// reverse port forwarding does not prevent local port forwarding from
// working.
func TestAgent_ReverseBlockedDoesNotAffectLocal(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitLong)

	// Stand up a local TCP listener for the forwarded connection to reach.
	echoListener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	defer echoListener.Close()
	addr, ok := echoListener.Addr().(*net.TCPAddr)
	require.True(t, ok)
	go echoOnce(t, echoListener)

	//nolint:dogsled
	conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
		o.BlockReversePortForwarding = true
	})

	client, err := conn.SSHClient(ctx)
	require.NoError(t, err)
	defer client.Close()

	// Local forwarding must still work.
	fwd, err := client.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", addr.Port))
	require.NoError(t, err)
	defer fwd.Close()
	requireEcho(t, fwd)
}
func TestAgent_UnixLocalForwarding(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
+3 -26
View File
@@ -117,10 +117,6 @@ type Config struct {
X11MaxPort *int
// BlockFileTransfer restricts use of file transfer applications.
BlockFileTransfer bool
// BlockReversePortForwarding disables reverse port forwarding (ssh -R).
BlockReversePortForwarding bool
// BlockLocalPortForwarding disables local port forwarding (ssh -L).
BlockLocalPortForwarding bool
// ReportConnection.
ReportConnection reportConnectionFunc
// Experimental: allow connecting to running containers via Docker exec.
@@ -194,7 +190,7 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
}
forwardHandler := &ssh.ForwardedTCPHandler{}
unixForwardHandler := newForwardedUnixHandler(logger, config.BlockReversePortForwarding)
unixForwardHandler := newForwardedUnixHandler(logger)
metrics := newSSHServerMetrics(prometheusRegistry)
s := &Server{
@@ -233,15 +229,8 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
wrapped := NewJetbrainsChannelWatcher(ctx, s.logger, s.config.ReportConnection, newChan, &s.connCountJetBrains)
ssh.DirectTCPIPHandler(srv, conn, wrapped, ctx)
},
"direct-streamlocal@openssh.com": func(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) {
if s.config.BlockLocalPortForwarding {
s.logger.Warn(ctx, "unix local port forward blocked")
_ = newChan.Reject(gossh.Prohibited, "local port forwarding is disabled")
return
}
directStreamLocalHandler(srv, conn, newChan, ctx)
},
"session": ssh.DefaultSessionHandler,
"direct-streamlocal@openssh.com": directStreamLocalHandler,
"session": ssh.DefaultSessionHandler,
},
ConnectionFailedCallback: func(conn net.Conn, err error) {
s.logger.Warn(ctx, "ssh connection failed",
@@ -261,12 +250,6 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
// be set before we start listening.
HostSigners: []ssh.Signer{},
LocalPortForwardingCallback: func(ctx ssh.Context, destinationHost string, destinationPort uint32) bool {
if s.config.BlockLocalPortForwarding {
s.logger.Warn(ctx, "local port forward blocked",
slog.F("destination_host", destinationHost),
slog.F("destination_port", destinationPort))
return false
}
// Allow local port forwarding all!
s.logger.Debug(ctx, "local port forward",
slog.F("destination_host", destinationHost),
@@ -277,12 +260,6 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
return true
},
ReversePortForwardingCallback: func(ctx ssh.Context, bindHost string, bindPort uint32) bool {
if s.config.BlockReversePortForwarding {
s.logger.Warn(ctx, "reverse port forward blocked",
slog.F("bind_host", bindHost),
slog.F("bind_port", bindPort))
return false
}
// Allow reverse port forwarding all!
s.logger.Debug(ctx, "reverse port forward",
slog.F("bind_host", bindHost),
+5 -11
View File
@@ -35,9 +35,8 @@ type forwardedStreamLocalPayload struct {
// streamlocal forwarding (aka. unix forwarding) instead of TCP forwarding.
type forwardedUnixHandler struct {
sync.Mutex
log slog.Logger
forwards map[forwardKey]net.Listener
blockReversePortForwarding bool
log slog.Logger
forwards map[forwardKey]net.Listener
}
type forwardKey struct {
@@ -45,11 +44,10 @@ type forwardKey struct {
addr string
}
func newForwardedUnixHandler(log slog.Logger, blockReversePortForwarding bool) *forwardedUnixHandler {
func newForwardedUnixHandler(log slog.Logger) *forwardedUnixHandler {
return &forwardedUnixHandler{
log: log,
forwards: make(map[forwardKey]net.Listener),
blockReversePortForwarding: blockReversePortForwarding,
log: log,
forwards: make(map[forwardKey]net.Listener),
}
}
@@ -64,10 +62,6 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
switch req.Type {
case "streamlocal-forward@openssh.com":
if h.blockReversePortForwarding {
log.Warn(ctx, "unix reverse port forward blocked")
return false, nil
}
var reqPayload streamLocalForwardPayload
err := gossh.Unmarshal(req.Payload, &reqPayload)
if err != nil {
+4 -22
View File
@@ -53,8 +53,6 @@ func workspaceAgent() *serpent.Command {
slogJSONPath string
slogStackdriverPath string
blockFileTransfer bool
blockReversePortForwarding bool
blockLocalPortForwarding bool
agentHeaderCommand string
agentHeader []string
devcontainers bool
@@ -321,12 +319,10 @@ func workspaceAgent() *serpent.Command {
SSHMaxTimeout: sshMaxTimeout,
Subsystems: subsystems,
PrometheusRegistry: prometheusRegistry,
BlockFileTransfer: blockFileTransfer,
BlockReversePortForwarding: blockReversePortForwarding,
BlockLocalPortForwarding: blockLocalPortForwarding,
Execer: execer,
Devcontainers: devcontainers,
PrometheusRegistry: prometheusRegistry,
BlockFileTransfer: blockFileTransfer,
Execer: execer,
Devcontainers: devcontainers,
DevcontainerAPIOptions: []agentcontainers.Option{
agentcontainers.WithSubAgentURL(agentAuth.agentURL.String()),
agentcontainers.WithProjectDiscovery(devcontainerProjectDiscovery),
@@ -497,20 +493,6 @@ func workspaceAgent() *serpent.Command {
Description: fmt.Sprintf("Block file transfer using known applications: %s.", strings.Join(agentssh.BlockedFileTransferCommands, ",")),
Value: serpent.BoolOf(&blockFileTransfer),
},
{
Flag: "block-reverse-port-forwarding",
Default: "false",
Env: "CODER_AGENT_BLOCK_REVERSE_PORT_FORWARDING",
Description: "Block reverse port forwarding through the SSH server (ssh -R).",
Value: serpent.BoolOf(&blockReversePortForwarding),
},
{
Flag: "block-local-port-forwarding",
Default: "false",
Env: "CODER_AGENT_BLOCK_LOCAL_PORT_FORWARDING",
Description: "Block local port forwarding through the SSH server (ssh -L).",
Value: serpent.BoolOf(&blockLocalPortForwarding),
},
{
Flag: "devcontainers-enable",
Default: "true",
-20
View File
@@ -768,30 +768,10 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
return xerrors.Errorf("create pubsub: %w", err)
}
options.Pubsub = ps
options.ChatPubsub = ps
if options.DeploymentValues.Prometheus.Enable {
options.PrometheusRegistry.MustRegister(ps)
}
defer options.Pubsub.Close()
chatPubsub, err := pubsub.NewBatching(
ctx,
logger.Named("chatd").Named("pubsub_batch"),
ps,
sqlDB,
dbURL,
pubsub.BatchingConfig{
FlushInterval: options.DeploymentValues.AI.Chat.PubsubFlushInterval.Value(),
QueueSize: int(options.DeploymentValues.AI.Chat.PubsubQueueSize.Value()),
},
)
if err != nil {
return xerrors.Errorf("create chat pubsub batcher: %w", err)
}
options.ChatPubsub = chatPubsub
if options.DeploymentValues.Prometheus.Enable {
options.PrometheusRegistry.MustRegister(chatPubsub)
}
defer options.ChatPubsub.Close()
psWatchdog := pubsub.NewWatchdog(ctx, logger.Named("pswatch"), ps)
pubsubWatchdogTimeout = psWatchdog.Timeout()
defer psWatchdog.Close()
+17 -97
View File
@@ -52,10 +52,6 @@ import (
const (
disableUsageApp = "disable"
// Retry transient errors during SSH connection establishment.
sshRetryInterval = 2 * time.Second
sshMaxAttempts = 10 // initial + retries per step
)
var (
@@ -66,51 +62,6 @@ var (
workspaceNameRe = regexp.MustCompile(`[/.]+|--`)
)
// isRetryableError checks for transient connection errors worth
// retrying: DNS failures, connection refused, and server 5xx.
func isRetryableError(err error) bool {
if err == nil {
return false
}
if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
return false
}
if codersdk.IsConnectionError(err) {
return true
}
var sdkErr *codersdk.Error
if xerrors.As(err, &sdkErr) {
return sdkErr.StatusCode() >= 500
}
return false
}
// retryWithInterval calls fn up to maxAttempts times, waiting
// interval between attempts. Stops on success, non-retryable
// error, or context cancellation.
func retryWithInterval(ctx context.Context, logger slog.Logger, interval time.Duration, maxAttempts int, fn func() error) error {
var lastErr error
attempt := 0
for r := retry.New(interval, interval); r.Wait(ctx); {
lastErr = fn()
if lastErr == nil || !isRetryableError(lastErr) {
return lastErr
}
attempt++
if attempt >= maxAttempts {
break
}
logger.Warn(ctx, "transient error, retrying",
slog.Error(lastErr),
slog.F("attempt", attempt),
)
}
if lastErr != nil {
return lastErr
}
return ctx.Err()
}
func (r *RootCmd) ssh() *serpent.Command {
var (
stdio bool
@@ -326,17 +277,10 @@ func (r *RootCmd) ssh() *serpent.Command {
HostnameSuffix: hostnameSuffix,
}
// Populated by the closure below.
var workspace codersdk.Workspace
var workspaceAgent codersdk.WorkspaceAgent
resolveWorkspace := func() error {
var err error
workspace, workspaceAgent, err = findWorkspaceAndAgentByHostname(
ctx, inv, client,
inv.Args[0], cliConfig, disableAutostart)
return err
}
if err := retryWithInterval(ctx, logger, sshRetryInterval, sshMaxAttempts, resolveWorkspace); err != nil {
workspace, workspaceAgent, err := findWorkspaceAndAgentByHostname(
ctx, inv, client,
inv.Args[0], cliConfig, disableAutostart)
if err != nil {
return err
}
@@ -362,13 +306,8 @@ func (r *RootCmd) ssh() *serpent.Command {
wait = false
}
var templateVersion codersdk.TemplateVersion
fetchVersion := func() error {
var err error
templateVersion, err = client.TemplateVersion(ctx, workspace.LatestBuild.TemplateVersionID)
return err
}
if err := retryWithInterval(ctx, logger, sshRetryInterval, sshMaxAttempts, fetchVersion); err != nil {
templateVersion, err := client.TemplateVersion(ctx, workspace.LatestBuild.TemplateVersionID)
if err != nil {
return err
}
@@ -408,12 +347,8 @@ func (r *RootCmd) ssh() *serpent.Command {
// If we're in stdio mode, check to see if we can use Coder Connect.
// We don't support Coder Connect over non-stdio coder ssh yet.
if stdio && !forceNewTunnel {
var connInfo workspacesdk.AgentConnectionInfo
if err := retryWithInterval(ctx, logger, sshRetryInterval, sshMaxAttempts, func() error {
var err error
connInfo, err = wsClient.AgentConnectionInfoGeneric(ctx)
return err
}); err != nil {
connInfo, err := wsClient.AgentConnectionInfoGeneric(ctx)
if err != nil {
return xerrors.Errorf("get agent connection info: %w", err)
}
coderConnectHost := fmt.Sprintf("%s.%s.%s.%s",
@@ -449,27 +384,23 @@ func (r *RootCmd) ssh() *serpent.Command {
})
defer closeUsage()
}
return runCoderConnectStdio(ctx, fmt.Sprintf("%s:22", coderConnectHost), stdioReader, stdioWriter, stack, logger)
return runCoderConnectStdio(ctx, fmt.Sprintf("%s:22", coderConnectHost), stdioReader, stdioWriter, stack)
}
}
if r.disableDirect {
_, _ = fmt.Fprintln(inv.Stderr, "Direct connections disabled.")
}
var conn workspacesdk.AgentConn
if err := retryWithInterval(ctx, logger, sshRetryInterval, sshMaxAttempts, func() error {
var err error
conn, err = wsClient.DialAgent(ctx, workspaceAgent.ID, &workspacesdk.DialAgentOptions{
conn, err := wsClient.
DialAgent(ctx, workspaceAgent.ID, &workspacesdk.DialAgentOptions{
Logger: logger,
BlockEndpoints: r.disableDirect,
EnableTelemetry: !r.disableNetworkTelemetry,
})
return err
}); err != nil {
if err != nil {
return xerrors.Errorf("dial agent: %w", err)
}
if err = stack.push("agent conn", conn); err != nil {
_ = conn.Close()
return err
}
conn.AwaitReachable(ctx)
@@ -1647,27 +1578,16 @@ func WithTestOnlyCoderConnectDialer(ctx context.Context, dialer coderConnectDial
func testOrDefaultDialer(ctx context.Context) coderConnectDialer {
dialer, ok := ctx.Value(coderConnectDialerContextKey{}).(coderConnectDialer)
if !ok || dialer == nil {
// Timeout prevents hanging on broken tunnels (OS default is very long).
return &net.Dialer{
Timeout: 5 * time.Second,
KeepAlive: 30 * time.Second,
}
return &net.Dialer{}
}
return dialer
}
func runCoderConnectStdio(ctx context.Context, addr string, stdin io.Reader, stdout io.Writer, stack *closerStack, logger slog.Logger) error {
func runCoderConnectStdio(ctx context.Context, addr string, stdin io.Reader, stdout io.Writer, stack *closerStack) error {
dialer := testOrDefaultDialer(ctx)
var conn net.Conn
if err := retryWithInterval(ctx, logger, sshRetryInterval, sshMaxAttempts, func() error {
var err error
conn, err = dialer.DialContext(ctx, "tcp", addr)
if err != nil {
return xerrors.Errorf("dial coder connect host %q over tcp: %w", addr, err)
}
return nil
}); err != nil {
return err
conn, err := dialer.DialContext(ctx, "tcp", addr)
if err != nil {
return xerrors.Errorf("dial coder connect host: %w", err)
}
if err := stack.push("tcp conn", conn); err != nil {
return err
+1 -149
View File
@@ -5,9 +5,7 @@ import (
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"sync"
"testing"
"time"
@@ -228,41 +226,6 @@ func TestCloserStack_Timeout(t *testing.T) {
testutil.TryReceive(ctx, t, closed)
}
func TestCloserStack_PushAfterClose_ConnClosed(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
uut := newCloserStack(ctx, logger, quartz.NewMock(t))
uut.close(xerrors.New("canceled"))
closes := new([]*fakeCloser)
fc := &fakeCloser{closes: closes}
err := uut.push("conn", fc)
require.Error(t, err)
require.Equal(t, []*fakeCloser{fc}, *closes, "should close conn on failed push")
}
func TestCoderConnectDialer_DefaultTimeout(t *testing.T) {
t.Parallel()
ctx := context.Background()
dialer := testOrDefaultDialer(ctx)
d, ok := dialer.(*net.Dialer)
require.True(t, ok, "expected *net.Dialer")
assert.Equal(t, 5*time.Second, d.Timeout)
assert.Equal(t, 30*time.Second, d.KeepAlive)
}
func TestCoderConnectDialer_Overridden(t *testing.T) {
t.Parallel()
custom := &net.Dialer{Timeout: 99 * time.Second}
ctx := WithTestOnlyCoderConnectDialer(context.Background(), custom)
dialer := testOrDefaultDialer(ctx)
assert.Equal(t, custom, dialer)
}
func TestCoderConnectStdio(t *testing.T) {
t.Parallel()
@@ -291,7 +254,7 @@ func TestCoderConnectStdio(t *testing.T) {
stdioDone := make(chan struct{})
go func() {
err = runCoderConnectStdio(ctx, ln.Addr().String(), clientOutput, serverInput, stack, logger)
err = runCoderConnectStdio(ctx, ln.Addr().String(), clientOutput, serverInput, stack)
assert.NoError(t, err)
close(stdioDone)
}()
@@ -485,114 +448,3 @@ func Test_getWorkspaceAgent(t *testing.T) {
assert.Contains(t, err.Error(), "available agents: [clark krypton zod]")
})
}
func TestIsRetryableError(t *testing.T) {
t.Parallel()
tests := []struct {
name string
err error
retryable bool
}{
{"Nil", nil, false},
{"ContextCanceled", context.Canceled, false},
{"ContextDeadlineExceeded", context.DeadlineExceeded, false},
{"WrappedContextCanceled", xerrors.Errorf("wrapped: %w", context.Canceled), false},
{"DNSError", &net.DNSError{Err: "no such host", Name: "example.com", IsNotFound: true}, true},
{"OpError", &net.OpError{Op: "dial", Net: "tcp", Err: &os.SyscallError{}}, true},
{"WrappedDNSError", xerrors.Errorf("connect: %w", &net.DNSError{Err: "no such host", Name: "example.com"}), true},
{"SDKError_500", codersdk.NewTestError(http.StatusInternalServerError, "GET", "/api"), true},
{"SDKError_502", codersdk.NewTestError(http.StatusBadGateway, "GET", "/api"), true},
{"SDKError_503", codersdk.NewTestError(http.StatusServiceUnavailable, "GET", "/api"), true},
{"SDKError_401", codersdk.NewTestError(http.StatusUnauthorized, "GET", "/api"), false},
{"SDKError_403", codersdk.NewTestError(http.StatusForbidden, "GET", "/api"), false},
{"SDKError_404", codersdk.NewTestError(http.StatusNotFound, "GET", "/api"), false},
{"GenericError", xerrors.New("something went wrong"), false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
assert.Equal(t, tt.retryable, isRetryableError(tt.err))
})
}
}
func TestRetryWithInterval(t *testing.T) {
t.Parallel()
const interval = time.Millisecond
const maxAttempts = 3
dnsErr := &net.DNSError{Err: "no such host", Name: "example.com", IsNotFound: true}
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
t.Run("Succeeds_FirstTry", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
attempts := 0
err := retryWithInterval(ctx, logger, interval, maxAttempts, func() error {
attempts++
return nil
})
require.NoError(t, err)
assert.Equal(t, 1, attempts)
})
t.Run("Succeeds_AfterTransientFailures", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
attempts := 0
err := retryWithInterval(ctx, logger, interval, maxAttempts, func() error {
attempts++
if attempts < 3 {
return dnsErr
}
return nil
})
require.NoError(t, err)
assert.Equal(t, 3, attempts)
})
t.Run("Stops_NonRetryableError", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
attempts := 0
err := retryWithInterval(ctx, logger, interval, maxAttempts, func() error {
attempts++
return xerrors.New("permanent failure")
})
require.ErrorContains(t, err, "permanent failure")
assert.Equal(t, 1, attempts)
})
t.Run("Stops_MaxAttemptsExhausted", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
attempts := 0
err := retryWithInterval(ctx, logger, interval, maxAttempts, func() error {
attempts++
return dnsErr
})
require.Error(t, err)
assert.Equal(t, maxAttempts, attempts)
})
t.Run("Stops_ContextCanceled", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
attempts := 0
err := retryWithInterval(ctx, logger, interval, maxAttempts, func() error {
attempts++
cancel()
return dnsErr
})
require.Error(t, err)
assert.Equal(t, 1, attempts)
})
}
-6
View File
@@ -39,12 +39,6 @@ OPTIONS:
--block-file-transfer bool, $CODER_AGENT_BLOCK_FILE_TRANSFER (default: false)
Block file transfer using known applications: nc,rsync,scp,sftp.
--block-local-port-forwarding bool, $CODER_AGENT_BLOCK_LOCAL_PORT_FORWARDING (default: false)
Block local port forwarding through the SSH server (ssh -L).
--block-reverse-port-forwarding bool, $CODER_AGENT_BLOCK_REVERSE_PORT_FORWARDING (default: false)
Block reverse port forwarding through the SSH server (ssh -R).
--boundary-log-proxy-socket-path string, $CODER_AGENT_BOUNDARY_LOG_PROXY_SOCKET_PATH (default: /tmp/boundary-audit.sock)
The path for the boundary log proxy server Unix socket. Boundary
should write audit logs to this socket.
-1
View File
@@ -134,7 +134,6 @@ func TestUserCreate(t *testing.T) {
{
name: "ServiceAccount",
args: []string{"--service-account", "-u", "dean"},
err: "Premium feature",
},
{
name: "ServiceAccountLoginType",
+2 -3
View File
@@ -77,9 +77,8 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea
level := make([]database.LogLevel, 0)
outputLength := 0
for _, logEntry := range req.Logs {
sanitizedOutput := agentsdk.SanitizeLogOutput(logEntry.Output)
output = append(output, sanitizedOutput)
outputLength += len(sanitizedOutput)
output = append(output, logEntry.Output)
outputLength += len(logEntry.Output)
var dbLevel database.LogLevel
switch logEntry.Level {
-53
View File
@@ -139,59 +139,6 @@ func TestBatchCreateLogs(t *testing.T) {
require.True(t, publishWorkspaceAgentLogsUpdateCalled)
})
t.Run("SanitizesOutput", func(t *testing.T) {
t.Parallel()
dbM := dbmock.NewMockStore(gomock.NewController(t))
now := dbtime.Now()
api := &agentapi.LogsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
Database: dbM,
Log: testutil.Logger(t),
TimeNowFn: func() time.Time {
return now
},
}
rawOutput := "before\x00middle\xc3\x28after"
sanitizedOutput := agentsdk.SanitizeLogOutput(rawOutput)
expectedOutputLength := int32(len(sanitizedOutput)) //nolint:gosec // Test-controlled string length is small.
req := &agentproto.BatchCreateLogsRequest{
LogSourceId: logSource.ID[:],
Logs: []*agentproto.Log{
{
CreatedAt: timestamppb.New(now),
Level: agentproto.Log_WARN,
Output: rawOutput,
},
},
}
dbM.EXPECT().InsertWorkspaceAgentLogs(gomock.Any(), database.InsertWorkspaceAgentLogsParams{
AgentID: agent.ID,
LogSourceID: logSource.ID,
CreatedAt: now,
Output: []string{sanitizedOutput},
Level: []database.LogLevel{database.LogLevelWarn},
OutputLength: expectedOutputLength,
}).Return([]database.WorkspaceAgentLog{
{
AgentID: agent.ID,
CreatedAt: now,
ID: 1,
Output: sanitizedOutput,
Level: database.LogLevelWarn,
LogSourceID: logSource.ID,
},
}, nil)
resp, err := api.BatchCreateLogs(context.Background(), req)
require.NoError(t, err)
require.Equal(t, &agentproto.BatchCreateLogsResponse{}, resp)
})
t.Run("NoWorkspacePublishIfNotFirstLogs", func(t *testing.T) {
t.Parallel()
-78
View File
@@ -1266,68 +1266,6 @@ const docTemplate = `{
]
}
},
"/experimental/chats/config/retention-days": {
"get": {
"produces": [
"application/json"
],
"tags": [
"Chats"
],
"summary": "Get chat retention days",
"operationId": "get-chat-retention-days",
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/codersdk.ChatRetentionDaysResponse"
}
}
},
"security": [
{
"CoderSessionToken": []
}
],
"x-apidocgen": {
"skip": true
}
},
"put": {
"consumes": [
"application/json"
],
"tags": [
"Chats"
],
"summary": "Update chat retention days",
"operationId": "update-chat-retention-days",
"parameters": [
{
"description": "Request body",
"name": "request",
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/codersdk.UpdateChatRetentionDaysRequest"
}
}
],
"responses": {
"204": {
"description": "No Content"
}
},
"security": [
{
"CoderSessionToken": []
}
],
"x-apidocgen": {
"skip": true
}
}
},
"/experimental/watch-all-workspacebuilds": {
"get": {
"produces": [
@@ -14482,14 +14420,6 @@ const docTemplate = `{
}
}
},
"codersdk.ChatRetentionDaysResponse": {
"type": "object",
"properties": {
"retention_days": {
"type": "integer"
}
}
},
"codersdk.ConnectionLatency": {
"type": "object",
"properties": {
@@ -21022,14 +20952,6 @@ const docTemplate = `{
}
}
},
"codersdk.UpdateChatRetentionDaysRequest": {
"type": "object",
"properties": {
"retention_days": {
"type": "integer"
}
}
},
"codersdk.UpdateCheckResponse": {
"type": "object",
"properties": {
-70
View File
@@ -1103,60 +1103,6 @@
]
}
},
"/experimental/chats/config/retention-days": {
"get": {
"produces": ["application/json"],
"tags": ["Chats"],
"summary": "Get chat retention days",
"operationId": "get-chat-retention-days",
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/codersdk.ChatRetentionDaysResponse"
}
}
},
"security": [
{
"CoderSessionToken": []
}
],
"x-apidocgen": {
"skip": true
}
},
"put": {
"consumes": ["application/json"],
"tags": ["Chats"],
"summary": "Update chat retention days",
"operationId": "update-chat-retention-days",
"parameters": [
{
"description": "Request body",
"name": "request",
"in": "body",
"required": true,
"schema": {
"$ref": "#/definitions/codersdk.UpdateChatRetentionDaysRequest"
}
}
],
"responses": {
"204": {
"description": "No Content"
}
},
"security": [
{
"CoderSessionToken": []
}
],
"x-apidocgen": {
"skip": true
}
}
},
"/experimental/watch-all-workspacebuilds": {
"get": {
"produces": ["application/json"],
@@ -13017,14 +12963,6 @@
}
}
},
"codersdk.ChatRetentionDaysResponse": {
"type": "object",
"properties": {
"retention_days": {
"type": "integer"
}
}
},
"codersdk.ConnectionLatency": {
"type": "object",
"properties": {
@@ -19305,14 +19243,6 @@
}
}
},
"codersdk.UpdateChatRetentionDaysRequest": {
"type": "object",
"properties": {
"retention_days": {
"type": "integer"
}
}
},
"codersdk.UpdateCheckResponse": {
"type": "object",
"properties": {
+2 -13
View File
@@ -159,10 +159,7 @@ type Options struct {
Logger slog.Logger
Database database.Store
Pubsub pubsub.Pubsub
// ChatPubsub allows chatd to use a dedicated publish path without changing
// the shared pubsub used by the rest of coderd.
ChatPubsub pubsub.Pubsub
RuntimeConfig *runtimeconfig.Manager
RuntimeConfig *runtimeconfig.Manager
// CacheDir is used for caching files served by the API.
CacheDir string
@@ -780,11 +777,6 @@ func New(options *Options) *API {
maxChatsPerAcquire = math.MinInt32
}
chatPubsub := options.ChatPubsub
if chatPubsub == nil {
chatPubsub = options.Pubsub
}
api.chatDaemon = chatd.New(chatd.Config{
Logger: options.Logger.Named("chatd"),
Database: options.Database,
@@ -797,7 +789,7 @@ func New(options *Options) *API {
InstructionLookupTimeout: options.ChatdInstructionLookupTimeout,
CreateWorkspace: api.chatCreateWorkspace,
StartWorkspace: api.chatStartWorkspace,
Pubsub: chatPubsub,
Pubsub: options.Pubsub,
WebpushDispatcher: options.WebPushDispatcher,
UsageTracker: options.WorkspaceUsageTracker,
})
@@ -1197,8 +1189,6 @@ func New(options *Options) *API {
r.Delete("/user-compaction-thresholds/{modelConfig}", api.deleteUserChatCompactionThreshold)
r.Get("/workspace-ttl", api.getChatWorkspaceTTL)
r.Put("/workspace-ttl", api.putChatWorkspaceTTL)
r.Get("/retention-days", api.getChatRetentionDays)
r.Put("/retention-days", api.putChatRetentionDays)
r.Get("/template-allowlist", api.getChatTemplateAllowlist)
r.Put("/template-allowlist", api.putChatTemplateAllowlist)
})
@@ -1253,7 +1243,6 @@ func New(options *Options) *API {
r.Get("/git", api.watchChatGit)
})
r.Post("/interrupt", api.interruptChat)
r.Post("/tool-results", api.postChatToolResults)
r.Post("/title/regenerate", api.regenerateChatTitle)
r.Get("/diff", api.getChatDiffContents)
r.Route("/queue/{queuedMessage}", func(r chi.Router) {
+5 -16
View File
@@ -123,10 +123,6 @@ func UsersPagination(
require.Contains(t, gotUsers[0].Name, "after")
}
type UsersFilterOptions struct {
CreateServiceAccounts bool
}
// UsersFilter creates a set of users to run various filters against for
// testing. It can be used to test filtering both users and group members.
func UsersFilter(
@@ -134,16 +130,11 @@ func UsersFilter(
t *testing.T,
client *codersdk.Client,
db database.Store,
options *UsersFilterOptions,
setup func(users []codersdk.User),
fetch func(ctx context.Context, req codersdk.UsersRequest) []codersdk.ReducedUser,
) {
t.Helper()
if options == nil {
options = &UsersFilterOptions{}
}
firstUser, err := client.User(setupCtx, codersdk.Me)
require.NoError(t, err, "fetch me")
@@ -220,13 +211,11 @@ func UsersFilter(
}
// Add some service accounts.
if options.CreateServiceAccounts {
for range 3 {
_, user := CreateAnotherUserMutators(t, client, orgID, nil, func(r *codersdk.CreateUserRequestWithOrgs) {
r.ServiceAccount = true
})
users = append(users, user)
}
for range 3 {
_, user := CreateAnotherUserMutators(t, client, orgID, nil, func(r *codersdk.CreateUserRequestWithOrgs) {
r.ServiceAccount = true
})
users = append(users, user)
}
hashedPassword, err := userpassword.Hash("SomeStrongPassword!")
-38
View File
@@ -1715,41 +1715,3 @@ func ChatDiffStatus(chatID uuid.UUID, status *database.ChatDiffStatus) codersdk.
return result
}
// UserSecret converts a database ListUserSecretsRow (metadata only,
// no value) to an SDK UserSecret.
func UserSecret(secret database.ListUserSecretsRow) codersdk.UserSecret {
return codersdk.UserSecret{
ID: secret.ID,
Name: secret.Name,
Description: secret.Description,
EnvName: secret.EnvName,
FilePath: secret.FilePath,
CreatedAt: secret.CreatedAt,
UpdatedAt: secret.UpdatedAt,
}
}
// UserSecretFromFull converts a full database UserSecret row to an
// SDK UserSecret, omitting the value and encryption key ID.
func UserSecretFromFull(secret database.UserSecret) codersdk.UserSecret {
return codersdk.UserSecret{
ID: secret.ID,
Name: secret.Name,
Description: secret.Description,
EnvName: secret.EnvName,
FilePath: secret.FilePath,
CreatedAt: secret.CreatedAt,
UpdatedAt: secret.UpdatedAt,
}
}
// UserSecrets converts a slice of database ListUserSecretsRow to
// SDK UserSecret values.
func UserSecrets(secrets []database.ListUserSecretsRow) []codersdk.UserSecret {
result := make([]codersdk.UserSecret, 0, len(secrets))
for _, s := range secrets {
result = append(result, UserSecret(s))
}
return result
}
-4
View File
@@ -552,10 +552,6 @@ func TestChat_AllFieldsPopulated(t *testing.T) {
RawMessage: json.RawMessage(`[{"type":"context-file","context_file_path":"/AGENTS.md"}]`),
Valid: true,
},
DynamicTools: pqtype.NullRawMessage{
RawMessage: json.RawMessage(`[{"name":"tool1","description":"test tool","inputSchema":{"type":"object"}}]`),
Valid: true,
},
}
// Only ChatID is needed here. This test checks that
// Chat.DiffStatus is non-nil, not that every DiffStatus
+32 -73
View File
@@ -2031,20 +2031,6 @@ func (q *querier) DeleteOldAuditLogs(ctx context.Context, arg database.DeleteOld
return q.db.DeleteOldAuditLogs(ctx, arg)
}
func (q *querier) DeleteOldChatFiles(ctx context.Context, arg database.DeleteOldChatFilesParams) (int64, error) {
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceSystem); err != nil {
return 0, err
}
return q.db.DeleteOldChatFiles(ctx, arg)
}
func (q *querier) DeleteOldChats(ctx context.Context, arg database.DeleteOldChatsParams) (int64, error) {
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceSystem); err != nil {
return 0, err
}
return q.db.DeleteOldChats(ctx, arg)
}
func (q *querier) DeleteOldConnectionLogs(ctx context.Context, arg database.DeleteOldConnectionLogsParams) (int64, error) {
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceSystem); err != nil {
return 0, err
@@ -2169,12 +2155,17 @@ func (q *querier) DeleteUserChatProviderKey(ctx context.Context, arg database.De
return q.db.DeleteUserChatProviderKey(ctx, arg)
}
func (q *querier) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) error {
obj := rbac.ResourceUserSecret.WithOwner(arg.UserID.String())
if err := q.authorizeContext(ctx, policy.ActionDelete, obj); err != nil {
func (q *querier) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
// First get the secret to check ownership
secret, err := q.GetUserSecret(ctx, id)
if err != nil {
return err
}
return q.db.DeleteUserSecretByUserIDAndName(ctx, arg)
if err := q.authorizeContext(ctx, policy.ActionDelete, secret); err != nil {
return err
}
return q.db.DeleteUserSecret(ctx, id)
}
func (q *querier) DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg database.DeleteWebpushSubscriptionByUserIDAndEndpointParams) error {
@@ -2636,14 +2627,6 @@ func (q *querier) GetChatMessageByID(ctx context.Context, id int64) (database.Ch
return msg, nil
}
func (q *querier) GetChatMessageSummariesPerChat(ctx context.Context, createdAfter time.Time) ([]database.GetChatMessageSummariesPerChatRow, error) {
// Telemetry queries are called from system contexts only.
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
return nil, err
}
return q.db.GetChatMessageSummariesPerChat(ctx, createdAfter)
}
func (q *querier) GetChatMessagesByChatID(ctx context.Context, arg database.GetChatMessagesByChatIDParams) ([]database.ChatMessage, error) {
// Authorize read on the parent chat.
_, err := q.GetChatByID(ctx, arg.ChatID)
@@ -2692,14 +2675,6 @@ func (q *querier) GetChatModelConfigs(ctx context.Context) ([]database.ChatModel
return q.db.GetChatModelConfigs(ctx)
}
func (q *querier) GetChatModelConfigsForTelemetry(ctx context.Context) ([]database.GetChatModelConfigsForTelemetryRow, error) {
// Telemetry queries are called from system contexts only.
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
return nil, err
}
return q.db.GetChatModelConfigsForTelemetry(ctx)
}
func (q *querier) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return database.ChatProvider{}, err
@@ -2729,15 +2704,6 @@ func (q *querier) GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) (
return q.db.GetChatQueuedMessages(ctx, chatID)
}
func (q *querier) GetChatRetentionDays(ctx context.Context) (int32, error) {
// Chat retention is a deployment-wide config read by dbpurge.
// Only requires a valid actor in context.
if _, ok := ActorFromContext(ctx); !ok {
return 0, ErrNoActor
}
return q.db.GetChatRetentionDays(ctx)
}
func (q *querier) GetChatSystemPrompt(ctx context.Context) (string, error) {
// The system prompt is a deployment-wide setting read during chat
// creation by every authenticated user, so no RBAC policy check
@@ -2816,14 +2782,6 @@ func (q *querier) GetChatsByWorkspaceIDs(ctx context.Context, ids []uuid.UUID) (
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetChatsByWorkspaceIDs)(ctx, ids)
}
func (q *querier) GetChatsUpdatedAfter(ctx context.Context, updatedAfter time.Time) ([]database.GetChatsUpdatedAfterRow, error) {
// Telemetry queries are called from system contexts only.
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
return nil, err
}
return q.db.GetChatsUpdatedAfter(ctx, updatedAfter)
}
func (q *querier) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
// Just like with the audit logs query, shortcut if the user is an owner.
err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceConnectionLog)
@@ -4170,6 +4128,19 @@ func (q *querier) GetUserNotificationPreferences(ctx context.Context, userID uui
return q.db.GetUserNotificationPreferences(ctx, userID)
}
func (q *querier) GetUserSecret(ctx context.Context, id uuid.UUID) (database.UserSecret, error) {
// First get the secret to check ownership
secret, err := q.db.GetUserSecret(ctx, id)
if err != nil {
return database.UserSecret{}, err
}
if err := q.authorizeContext(ctx, policy.ActionRead, secret); err != nil {
return database.UserSecret{}, err
}
return secret, nil
}
func (q *querier) GetUserSecretByUserIDAndName(ctx context.Context, arg database.GetUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
obj := rbac.ResourceUserSecret.WithOwner(arg.UserID.String())
if err := q.authorizeContext(ctx, policy.ActionRead, obj); err != nil {
@@ -5553,7 +5524,7 @@ func (q *querier) ListUserChatCompactionThresholds(ctx context.Context, userID u
return q.db.ListUserChatCompactionThresholds(ctx, userID)
}
func (q *querier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.ListUserSecretsRow, error) {
func (q *querier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
obj := rbac.ResourceUserSecret.WithOwner(userID.String())
if err := q.authorizeContext(ctx, policy.ActionRead, obj); err != nil {
return nil, err
@@ -5561,16 +5532,6 @@ func (q *querier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]data
return q.db.ListUserSecrets(ctx, userID)
}
func (q *querier) ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
// This query returns decrypted secret values and must only be called
// from system contexts (provisioner, agent manifest). REST API
// handlers should use ListUserSecrets (metadata only).
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
return nil, err
}
return q.db.ListUserSecretsWithValues(ctx, userID)
}
func (q *querier) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]database.WorkspaceAgentPortShare, error) {
workspace, err := q.db.GetWorkspaceByID(ctx, workspaceID)
if err != nil {
@@ -6671,12 +6632,17 @@ func (q *querier) UpdateUserRoles(ctx context.Context, arg database.UpdateUserRo
return q.db.UpdateUserRoles(ctx, arg)
}
func (q *querier) UpdateUserSecretByUserIDAndName(ctx context.Context, arg database.UpdateUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
obj := rbac.ResourceUserSecret.WithOwner(arg.UserID.String())
if err := q.authorizeContext(ctx, policy.ActionUpdate, obj); err != nil {
func (q *querier) UpdateUserSecret(ctx context.Context, arg database.UpdateUserSecretParams) (database.UserSecret, error) {
// First get the secret to check ownership
secret, err := q.db.GetUserSecret(ctx, arg.ID)
if err != nil {
return database.UserSecret{}, err
}
return q.db.UpdateUserSecretByUserIDAndName(ctx, arg)
if err := q.authorizeContext(ctx, policy.ActionUpdate, secret); err != nil {
return database.UserSecret{}, err
}
return q.db.UpdateUserSecret(ctx, arg)
}
func (q *querier) UpdateUserStatus(ctx context.Context, arg database.UpdateUserStatusParams) (database.User, error) {
@@ -7078,13 +7044,6 @@ func (q *querier) UpsertChatIncludeDefaultSystemPrompt(ctx context.Context, incl
return q.db.UpsertChatIncludeDefaultSystemPrompt(ctx, includeDefaultSystemPrompt)
}
func (q *querier) UpsertChatRetentionDays(ctx context.Context, retentionDays int32) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
}
return q.db.UpsertChatRetentionDays(ctx, retentionDays)
}
func (q *querier) UpsertChatSystemPrompt(ctx context.Context, value string) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
+22 -52
View File
@@ -600,22 +600,6 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().GetChatFileMetadataByChatID(gomock.Any(), file.ID).Return(rows, nil).AnyTimes()
check.Args(file.ID).Asserts(rbac.ResourceChat.WithOwner(file.OwnerID.String()).InOrg(file.OrganizationID).WithID(file.ID), policy.ActionRead).Returns(rows)
}))
s.Run("DeleteOldChatFiles", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().DeleteOldChatFiles(gomock.Any(), database.DeleteOldChatFilesParams{}).Return(int64(0), nil).AnyTimes()
check.Args(database.DeleteOldChatFilesParams{}).Asserts(rbac.ResourceSystem, policy.ActionDelete)
}))
s.Run("DeleteOldChats", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().DeleteOldChats(gomock.Any(), database.DeleteOldChatsParams{}).Return(int64(0), nil).AnyTimes()
check.Args(database.DeleteOldChatsParams{}).Asserts(rbac.ResourceSystem, policy.ActionDelete)
}))
s.Run("GetChatRetentionDays", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().GetChatRetentionDays(gomock.Any()).Return(int32(30), nil).AnyTimes()
check.Args().Asserts()
}))
s.Run("UpsertChatRetentionDays", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().UpsertChatRetentionDays(gomock.Any(), int32(30)).Return(nil).AnyTimes()
check.Args(int32(30)).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
}))
s.Run("GetChatMessageByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
msg := testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})
@@ -4012,20 +3996,6 @@ func (s *MethodTestSuite) TestSystemFunctions() {
dbm.EXPECT().GetWorkspaceAgentsCreatedAfter(gomock.Any(), ts).Return([]database.WorkspaceAgent{}, nil).AnyTimes()
check.Args(ts).Asserts(rbac.ResourceSystem, policy.ActionRead)
}))
s.Run("GetChatsUpdatedAfter", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
ts := dbtime.Now()
dbm.EXPECT().GetChatsUpdatedAfter(gomock.Any(), ts).Return([]database.GetChatsUpdatedAfterRow{}, nil).AnyTimes()
check.Args(ts).Asserts(rbac.ResourceSystem, policy.ActionRead)
}))
s.Run("GetChatMessageSummariesPerChat", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
ts := dbtime.Now()
dbm.EXPECT().GetChatMessageSummariesPerChat(gomock.Any(), ts).Return([]database.GetChatMessageSummariesPerChatRow{}, nil).AnyTimes()
check.Args(ts).Asserts(rbac.ResourceSystem, policy.ActionRead)
}))
s.Run("GetChatModelConfigsForTelemetry", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
dbm.EXPECT().GetChatModelConfigsForTelemetry(gomock.Any()).Return([]database.GetChatModelConfigsForTelemetryRow{}, nil).AnyTimes()
check.Args().Asserts(rbac.ResourceSystem, policy.ActionRead)
}))
s.Run("GetWorkspaceAppsCreatedAfter", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
ts := dbtime.Now()
dbm.EXPECT().GetWorkspaceAppsCreatedAfter(gomock.Any(), ts).Return([]database.WorkspaceApp{}, nil).AnyTimes()
@@ -5376,20 +5346,19 @@ func (s *MethodTestSuite) TestUserSecrets() {
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionRead).
Returns(secret)
}))
s.Run("GetUserSecret", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
secret := testutil.Fake(s.T(), faker, database.UserSecret{})
dbm.EXPECT().GetUserSecret(gomock.Any(), secret.ID).Return(secret, nil).AnyTimes()
check.Args(secret.ID).
Asserts(secret, policy.ActionRead).
Returns(secret)
}))
s.Run("ListUserSecrets", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
user := testutil.Fake(s.T(), faker, database.User{})
row := testutil.Fake(s.T(), faker, database.ListUserSecretsRow{UserID: user.ID})
dbm.EXPECT().ListUserSecrets(gomock.Any(), user.ID).Return([]database.ListUserSecretsRow{row}, nil).AnyTimes()
secret := testutil.Fake(s.T(), faker, database.UserSecret{UserID: user.ID})
dbm.EXPECT().ListUserSecrets(gomock.Any(), user.ID).Return([]database.UserSecret{secret}, nil).AnyTimes()
check.Args(user.ID).
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionRead).
Returns([]database.ListUserSecretsRow{row})
}))
s.Run("ListUserSecretsWithValues", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
user := testutil.Fake(s.T(), faker, database.User{})
secret := testutil.Fake(s.T(), faker, database.UserSecret{UserID: user.ID})
dbm.EXPECT().ListUserSecretsWithValues(gomock.Any(), user.ID).Return([]database.UserSecret{secret}, nil).AnyTimes()
check.Args(user.ID).
Asserts(rbac.ResourceSystem, policy.ActionRead).
Returns([]database.UserSecret{secret})
}))
s.Run("CreateUserSecret", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
@@ -5401,21 +5370,22 @@ func (s *MethodTestSuite) TestUserSecrets() {
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionCreate).
Returns(ret)
}))
s.Run("UpdateUserSecretByUserIDAndName", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
user := testutil.Fake(s.T(), faker, database.User{})
updated := testutil.Fake(s.T(), faker, database.UserSecret{UserID: user.ID})
arg := database.UpdateUserSecretByUserIDAndNameParams{UserID: user.ID, Name: "test"}
dbm.EXPECT().UpdateUserSecretByUserIDAndName(gomock.Any(), arg).Return(updated, nil).AnyTimes()
s.Run("UpdateUserSecret", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
secret := testutil.Fake(s.T(), faker, database.UserSecret{})
updated := testutil.Fake(s.T(), faker, database.UserSecret{ID: secret.ID})
arg := database.UpdateUserSecretParams{ID: secret.ID}
dbm.EXPECT().GetUserSecret(gomock.Any(), secret.ID).Return(secret, nil).AnyTimes()
dbm.EXPECT().UpdateUserSecret(gomock.Any(), arg).Return(updated, nil).AnyTimes()
check.Args(arg).
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionUpdate).
Asserts(secret, policy.ActionUpdate).
Returns(updated)
}))
s.Run("DeleteUserSecretByUserIDAndName", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
user := testutil.Fake(s.T(), faker, database.User{})
arg := database.DeleteUserSecretByUserIDAndNameParams{UserID: user.ID, Name: "test"}
dbm.EXPECT().DeleteUserSecretByUserIDAndName(gomock.Any(), arg).Return(nil).AnyTimes()
check.Args(arg).
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionDelete).
s.Run("DeleteUserSecret", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
secret := testutil.Fake(s.T(), faker, database.UserSecret{})
dbm.EXPECT().GetUserSecret(gomock.Any(), secret.ID).Return(secret, nil).AnyTimes()
dbm.EXPECT().DeleteUserSecret(gomock.Any(), secret.ID).Return(nil).AnyTimes()
check.Args(secret.ID).
Asserts(secret, policy.ActionRead, secret, policy.ActionDelete).
Returns()
}))
}
-3
View File
@@ -1597,7 +1597,6 @@ func UserSecret(t testing.TB, db database.Store, seed database.UserSecret) datab
Name: takeFirst(seed.Name, "secret-name"),
Description: takeFirst(seed.Description, "secret description"),
Value: takeFirst(seed.Value, "secret value"),
ValueKeyID: seed.ValueKeyID,
EnvName: takeFirst(seed.EnvName, "SECRET_ENV_NAME"),
FilePath: takeFirst(seed.FilePath, "~/secret/file/path"),
})
@@ -1644,8 +1643,6 @@ func AIBridgeInterception(t testing.TB, db database.Store, seed database.InsertA
ThreadParentInterceptionID: seed.ThreadParentInterceptionID,
ThreadRootInterceptionID: seed.ThreadRootInterceptionID,
ClientSessionID: seed.ClientSessionID,
CredentialKind: takeFirst(seed.CredentialKind, database.CredentialKindCentralized),
CredentialHint: takeFirst(seed.CredentialHint, ""),
})
if endedAt != nil {
interception, err = db.UpdateAIBridgeInterceptionEnded(genCtx, database.UpdateAIBridgeInterceptionEndedParams{
+17 -73
View File
@@ -592,22 +592,6 @@ func (m queryMetricsStore) DeleteOldAuditLogs(ctx context.Context, arg database.
return r0, r1
}
func (m queryMetricsStore) DeleteOldChatFiles(ctx context.Context, arg database.DeleteOldChatFilesParams) (int64, error) {
start := time.Now()
r0, r1 := m.s.DeleteOldChatFiles(ctx, arg)
m.queryLatencies.WithLabelValues("DeleteOldChatFiles").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteOldChatFiles").Inc()
return r0, r1
}
func (m queryMetricsStore) DeleteOldChats(ctx context.Context, arg database.DeleteOldChatsParams) (int64, error) {
start := time.Now()
r0, r1 := m.s.DeleteOldChats(ctx, arg)
m.queryLatencies.WithLabelValues("DeleteOldChats").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteOldChats").Inc()
return r0, r1
}
func (m queryMetricsStore) DeleteOldConnectionLogs(ctx context.Context, arg database.DeleteOldConnectionLogsParams) (int64, error) {
start := time.Now()
r0, r1 := m.s.DeleteOldConnectionLogs(ctx, arg)
@@ -728,11 +712,11 @@ func (m queryMetricsStore) DeleteUserChatProviderKey(ctx context.Context, arg da
return r0
}
func (m queryMetricsStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) error {
func (m queryMetricsStore) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
start := time.Now()
r0 := m.s.DeleteUserSecretByUserIDAndName(ctx, arg)
m.queryLatencies.WithLabelValues("DeleteUserSecretByUserIDAndName").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteUserSecretByUserIDAndName").Inc()
r0 := m.s.DeleteUserSecret(ctx, id)
m.queryLatencies.WithLabelValues("DeleteUserSecret").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteUserSecret").Inc()
return r0
}
@@ -1176,14 +1160,6 @@ func (m queryMetricsStore) GetChatMessageByID(ctx context.Context, id int64) (da
return r0, r1
}
func (m queryMetricsStore) GetChatMessageSummariesPerChat(ctx context.Context, createdAfter time.Time) ([]database.GetChatMessageSummariesPerChatRow, error) {
start := time.Now()
r0, r1 := m.s.GetChatMessageSummariesPerChat(ctx, createdAfter)
m.queryLatencies.WithLabelValues("GetChatMessageSummariesPerChat").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatMessageSummariesPerChat").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatMessagesByChatID(ctx context.Context, chatID database.GetChatMessagesByChatIDParams) ([]database.ChatMessage, error) {
start := time.Now()
r0, r1 := m.s.GetChatMessagesByChatID(ctx, chatID)
@@ -1232,14 +1208,6 @@ func (m queryMetricsStore) GetChatModelConfigs(ctx context.Context) ([]database.
return r0, r1
}
func (m queryMetricsStore) GetChatModelConfigsForTelemetry(ctx context.Context) ([]database.GetChatModelConfigsForTelemetryRow, error) {
start := time.Now()
r0, r1 := m.s.GetChatModelConfigsForTelemetry(ctx)
m.queryLatencies.WithLabelValues("GetChatModelConfigsForTelemetry").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatModelConfigsForTelemetry").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
start := time.Now()
r0, r1 := m.s.GetChatProviderByID(ctx, id)
@@ -1272,14 +1240,6 @@ func (m queryMetricsStore) GetChatQueuedMessages(ctx context.Context, chatID uui
return r0, r1
}
func (m queryMetricsStore) GetChatRetentionDays(ctx context.Context) (int32, error) {
start := time.Now()
r0, r1 := m.s.GetChatRetentionDays(ctx)
m.queryLatencies.WithLabelValues("GetChatRetentionDays").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatRetentionDays").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatSystemPrompt(ctx context.Context) (string, error) {
start := time.Now()
r0, r1 := m.s.GetChatSystemPrompt(ctx)
@@ -1352,14 +1312,6 @@ func (m queryMetricsStore) GetChatsByWorkspaceIDs(ctx context.Context, ids []uui
return r0, r1
}
func (m queryMetricsStore) GetChatsUpdatedAfter(ctx context.Context, updatedAfter time.Time) ([]database.GetChatsUpdatedAfterRow, error) {
start := time.Now()
r0, r1 := m.s.GetChatsUpdatedAfter(ctx, updatedAfter)
m.queryLatencies.WithLabelValues("GetChatsUpdatedAfter").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatsUpdatedAfter").Inc()
return r0, r1
}
func (m queryMetricsStore) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
start := time.Now()
r0, r1 := m.s.GetConnectionLogsOffset(ctx, arg)
@@ -2672,6 +2624,14 @@ func (m queryMetricsStore) GetUserNotificationPreferences(ctx context.Context, u
return r0, r1
}
func (m queryMetricsStore) GetUserSecret(ctx context.Context, id uuid.UUID) (database.UserSecret, error) {
start := time.Now()
r0, r1 := m.s.GetUserSecret(ctx, id)
m.queryLatencies.WithLabelValues("GetUserSecret").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetUserSecret").Inc()
return r0, r1
}
func (m queryMetricsStore) GetUserSecretByUserIDAndName(ctx context.Context, arg database.GetUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
start := time.Now()
r0, r1 := m.s.GetUserSecretByUserIDAndName(ctx, arg)
@@ -3960,7 +3920,7 @@ func (m queryMetricsStore) ListUserChatCompactionThresholds(ctx context.Context,
return r0, r1
}
func (m queryMetricsStore) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.ListUserSecretsRow, error) {
func (m queryMetricsStore) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
start := time.Now()
r0, r1 := m.s.ListUserSecrets(ctx, userID)
m.queryLatencies.WithLabelValues("ListUserSecrets").Observe(time.Since(start).Seconds())
@@ -3968,14 +3928,6 @@ func (m queryMetricsStore) ListUserSecrets(ctx context.Context, userID uuid.UUID
return r0, r1
}
func (m queryMetricsStore) ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
start := time.Now()
r0, r1 := m.s.ListUserSecretsWithValues(ctx, userID)
m.queryLatencies.WithLabelValues("ListUserSecretsWithValues").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListUserSecretsWithValues").Inc()
return r0, r1
}
func (m queryMetricsStore) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]database.WorkspaceAgentPortShare, error) {
start := time.Now()
r0, r1 := m.s.ListWorkspaceAgentPortShares(ctx, workspaceID)
@@ -4744,11 +4696,11 @@ func (m queryMetricsStore) UpdateUserRoles(ctx context.Context, arg database.Upd
return r0, r1
}
func (m queryMetricsStore) UpdateUserSecretByUserIDAndName(ctx context.Context, arg database.UpdateUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
func (m queryMetricsStore) UpdateUserSecret(ctx context.Context, arg database.UpdateUserSecretParams) (database.UserSecret, error) {
start := time.Now()
r0, r1 := m.s.UpdateUserSecretByUserIDAndName(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateUserSecretByUserIDAndName").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateUserSecretByUserIDAndName").Inc()
r0, r1 := m.s.UpdateUserSecret(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateUserSecret").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateUserSecret").Inc()
return r0, r1
}
@@ -5048,14 +5000,6 @@ func (m queryMetricsStore) UpsertChatIncludeDefaultSystemPrompt(ctx context.Cont
return r0
}
func (m queryMetricsStore) UpsertChatRetentionDays(ctx context.Context, retentionDays int32) error {
start := time.Now()
r0 := m.s.UpsertChatRetentionDays(ctx, retentionDays)
m.queryLatencies.WithLabelValues("UpsertChatRetentionDays").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatRetentionDays").Inc()
return r0
}
func (m queryMetricsStore) UpsertChatSystemPrompt(ctx context.Context, value string) error {
start := time.Now()
r0 := m.s.UpsertChatSystemPrompt(ctx, value)
+29 -133
View File
@@ -984,36 +984,6 @@ func (mr *MockStoreMockRecorder) DeleteOldAuditLogs(ctx, arg any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOldAuditLogs", reflect.TypeOf((*MockStore)(nil).DeleteOldAuditLogs), ctx, arg)
}
// DeleteOldChatFiles mocks base method.
func (m *MockStore) DeleteOldChatFiles(ctx context.Context, arg database.DeleteOldChatFilesParams) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteOldChatFiles", ctx, arg)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// DeleteOldChatFiles indicates an expected call of DeleteOldChatFiles.
func (mr *MockStoreMockRecorder) DeleteOldChatFiles(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOldChatFiles", reflect.TypeOf((*MockStore)(nil).DeleteOldChatFiles), ctx, arg)
}
// DeleteOldChats mocks base method.
func (m *MockStore) DeleteOldChats(ctx context.Context, arg database.DeleteOldChatsParams) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteOldChats", ctx, arg)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// DeleteOldChats indicates an expected call of DeleteOldChats.
func (mr *MockStoreMockRecorder) DeleteOldChats(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOldChats", reflect.TypeOf((*MockStore)(nil).DeleteOldChats), ctx, arg)
}
// DeleteOldConnectionLogs mocks base method.
func (m *MockStore) DeleteOldConnectionLogs(ctx context.Context, arg database.DeleteOldConnectionLogsParams) (int64, error) {
m.ctrl.T.Helper()
@@ -1229,18 +1199,18 @@ func (mr *MockStoreMockRecorder) DeleteUserChatProviderKey(ctx, arg any) *gomock
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserChatProviderKey", reflect.TypeOf((*MockStore)(nil).DeleteUserChatProviderKey), ctx, arg)
}
// DeleteUserSecretByUserIDAndName mocks base method.
func (m *MockStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) error {
// DeleteUserSecret mocks base method.
func (m *MockStore) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteUserSecretByUserIDAndName", ctx, arg)
ret := m.ctrl.Call(m, "DeleteUserSecret", ctx, id)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteUserSecretByUserIDAndName indicates an expected call of DeleteUserSecretByUserIDAndName.
func (mr *MockStoreMockRecorder) DeleteUserSecretByUserIDAndName(ctx, arg any) *gomock.Call {
// DeleteUserSecret indicates an expected call of DeleteUserSecret.
func (mr *MockStoreMockRecorder) DeleteUserSecret(ctx, id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserSecretByUserIDAndName", reflect.TypeOf((*MockStore)(nil).DeleteUserSecretByUserIDAndName), ctx, arg)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserSecret", reflect.TypeOf((*MockStore)(nil).DeleteUserSecret), ctx, id)
}
// DeleteWebpushSubscriptionByUserIDAndEndpoint mocks base method.
@@ -2162,21 +2132,6 @@ func (mr *MockStoreMockRecorder) GetChatMessageByID(ctx, id any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatMessageByID", reflect.TypeOf((*MockStore)(nil).GetChatMessageByID), ctx, id)
}
// GetChatMessageSummariesPerChat mocks base method.
func (m *MockStore) GetChatMessageSummariesPerChat(ctx context.Context, createdAfter time.Time) ([]database.GetChatMessageSummariesPerChatRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatMessageSummariesPerChat", ctx, createdAfter)
ret0, _ := ret[0].([]database.GetChatMessageSummariesPerChatRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatMessageSummariesPerChat indicates an expected call of GetChatMessageSummariesPerChat.
func (mr *MockStoreMockRecorder) GetChatMessageSummariesPerChat(ctx, createdAfter any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatMessageSummariesPerChat", reflect.TypeOf((*MockStore)(nil).GetChatMessageSummariesPerChat), ctx, createdAfter)
}
// GetChatMessagesByChatID mocks base method.
func (m *MockStore) GetChatMessagesByChatID(ctx context.Context, arg database.GetChatMessagesByChatIDParams) ([]database.ChatMessage, error) {
m.ctrl.T.Helper()
@@ -2267,21 +2222,6 @@ func (mr *MockStoreMockRecorder) GetChatModelConfigs(ctx any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatModelConfigs", reflect.TypeOf((*MockStore)(nil).GetChatModelConfigs), ctx)
}
// GetChatModelConfigsForTelemetry mocks base method.
func (m *MockStore) GetChatModelConfigsForTelemetry(ctx context.Context) ([]database.GetChatModelConfigsForTelemetryRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatModelConfigsForTelemetry", ctx)
ret0, _ := ret[0].([]database.GetChatModelConfigsForTelemetryRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatModelConfigsForTelemetry indicates an expected call of GetChatModelConfigsForTelemetry.
func (mr *MockStoreMockRecorder) GetChatModelConfigsForTelemetry(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatModelConfigsForTelemetry", reflect.TypeOf((*MockStore)(nil).GetChatModelConfigsForTelemetry), ctx)
}
// GetChatProviderByID mocks base method.
func (m *MockStore) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
m.ctrl.T.Helper()
@@ -2342,21 +2282,6 @@ func (mr *MockStoreMockRecorder) GetChatQueuedMessages(ctx, chatID any) *gomock.
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatQueuedMessages", reflect.TypeOf((*MockStore)(nil).GetChatQueuedMessages), ctx, chatID)
}
// GetChatRetentionDays mocks base method.
func (m *MockStore) GetChatRetentionDays(ctx context.Context) (int32, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatRetentionDays", ctx)
ret0, _ := ret[0].(int32)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatRetentionDays indicates an expected call of GetChatRetentionDays.
func (mr *MockStoreMockRecorder) GetChatRetentionDays(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatRetentionDays", reflect.TypeOf((*MockStore)(nil).GetChatRetentionDays), ctx)
}
// GetChatSystemPrompt mocks base method.
func (m *MockStore) GetChatSystemPrompt(ctx context.Context) (string, error) {
m.ctrl.T.Helper()
@@ -2492,21 +2417,6 @@ func (mr *MockStoreMockRecorder) GetChatsByWorkspaceIDs(ctx, ids any) *gomock.Ca
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatsByWorkspaceIDs", reflect.TypeOf((*MockStore)(nil).GetChatsByWorkspaceIDs), ctx, ids)
}
// GetChatsUpdatedAfter mocks base method.
func (m *MockStore) GetChatsUpdatedAfter(ctx context.Context, updatedAfter time.Time) ([]database.GetChatsUpdatedAfterRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatsUpdatedAfter", ctx, updatedAfter)
ret0, _ := ret[0].([]database.GetChatsUpdatedAfterRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatsUpdatedAfter indicates an expected call of GetChatsUpdatedAfter.
func (mr *MockStoreMockRecorder) GetChatsUpdatedAfter(ctx, updatedAfter any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatsUpdatedAfter", reflect.TypeOf((*MockStore)(nil).GetChatsUpdatedAfter), ctx, updatedAfter)
}
// GetConnectionLogsOffset mocks base method.
func (m *MockStore) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
m.ctrl.T.Helper()
@@ -4997,6 +4907,21 @@ func (mr *MockStoreMockRecorder) GetUserNotificationPreferences(ctx, userID any)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserNotificationPreferences", reflect.TypeOf((*MockStore)(nil).GetUserNotificationPreferences), ctx, userID)
}
// GetUserSecret mocks base method.
func (m *MockStore) GetUserSecret(ctx context.Context, id uuid.UUID) (database.UserSecret, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetUserSecret", ctx, id)
ret0, _ := ret[0].(database.UserSecret)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetUserSecret indicates an expected call of GetUserSecret.
func (mr *MockStoreMockRecorder) GetUserSecret(ctx, id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserSecret", reflect.TypeOf((*MockStore)(nil).GetUserSecret), ctx, id)
}
// GetUserSecretByUserIDAndName mocks base method.
func (m *MockStore) GetUserSecretByUserIDAndName(ctx context.Context, arg database.GetUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
m.ctrl.T.Helper()
@@ -7487,10 +7412,10 @@ func (mr *MockStoreMockRecorder) ListUserChatCompactionThresholds(ctx, userID an
}
// ListUserSecrets mocks base method.
func (m *MockStore) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.ListUserSecretsRow, error) {
func (m *MockStore) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListUserSecrets", ctx, userID)
ret0, _ := ret[0].([]database.ListUserSecretsRow)
ret0, _ := ret[0].([]database.UserSecret)
ret1, _ := ret[1].(error)
return ret0, ret1
}
@@ -7501,21 +7426,6 @@ func (mr *MockStoreMockRecorder) ListUserSecrets(ctx, userID any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListUserSecrets", reflect.TypeOf((*MockStore)(nil).ListUserSecrets), ctx, userID)
}
// ListUserSecretsWithValues mocks base method.
func (m *MockStore) ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListUserSecretsWithValues", ctx, userID)
ret0, _ := ret[0].([]database.UserSecret)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ListUserSecretsWithValues indicates an expected call of ListUserSecretsWithValues.
func (mr *MockStoreMockRecorder) ListUserSecretsWithValues(ctx, userID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListUserSecretsWithValues", reflect.TypeOf((*MockStore)(nil).ListUserSecretsWithValues), ctx, userID)
}
// ListWorkspaceAgentPortShares mocks base method.
func (m *MockStore) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]database.WorkspaceAgentPortShare, error) {
m.ctrl.T.Helper()
@@ -8944,19 +8854,19 @@ func (mr *MockStoreMockRecorder) UpdateUserRoles(ctx, arg any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserRoles", reflect.TypeOf((*MockStore)(nil).UpdateUserRoles), ctx, arg)
}
// UpdateUserSecretByUserIDAndName mocks base method.
func (m *MockStore) UpdateUserSecretByUserIDAndName(ctx context.Context, arg database.UpdateUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
// UpdateUserSecret mocks base method.
func (m *MockStore) UpdateUserSecret(ctx context.Context, arg database.UpdateUserSecretParams) (database.UserSecret, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateUserSecretByUserIDAndName", ctx, arg)
ret := m.ctrl.Call(m, "UpdateUserSecret", ctx, arg)
ret0, _ := ret[0].(database.UserSecret)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateUserSecretByUserIDAndName indicates an expected call of UpdateUserSecretByUserIDAndName.
func (mr *MockStoreMockRecorder) UpdateUserSecretByUserIDAndName(ctx, arg any) *gomock.Call {
// UpdateUserSecret indicates an expected call of UpdateUserSecret.
func (mr *MockStoreMockRecorder) UpdateUserSecret(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserSecretByUserIDAndName", reflect.TypeOf((*MockStore)(nil).UpdateUserSecretByUserIDAndName), ctx, arg)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserSecret", reflect.TypeOf((*MockStore)(nil).UpdateUserSecret), ctx, arg)
}
// UpdateUserStatus mocks base method.
@@ -9489,20 +9399,6 @@ func (mr *MockStoreMockRecorder) UpsertChatIncludeDefaultSystemPrompt(ctx, inclu
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatIncludeDefaultSystemPrompt", reflect.TypeOf((*MockStore)(nil).UpsertChatIncludeDefaultSystemPrompt), ctx, includeDefaultSystemPrompt)
}
// UpsertChatRetentionDays mocks base method.
func (m *MockStore) UpsertChatRetentionDays(ctx context.Context, retentionDays int32) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpsertChatRetentionDays", ctx, retentionDays)
ret0, _ := ret[0].(error)
return ret0
}
// UpsertChatRetentionDays indicates an expected call of UpsertChatRetentionDays.
func (mr *MockStoreMockRecorder) UpsertChatRetentionDays(ctx, retentionDays any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatRetentionDays", reflect.TypeOf((*MockStore)(nil).UpsertChatRetentionDays), ctx, retentionDays)
}
// UpsertChatSystemPrompt mocks base method.
func (m *MockStore) UpsertChatSystemPrompt(ctx context.Context, value string) error {
m.ctrl.T.Helper()
-49
View File
@@ -34,11 +34,6 @@ const (
// long enough to cover the maximum interval of a heartbeat event (currently
// 1 hour) plus some buffer.
maxTelemetryHeartbeatAge = 24 * time.Hour
// Batch sizes for chat purging. Both use 1000, which is smaller
// than audit/connection log batches (10000), because chat_files
// rows contain bytea blob data that make large batches heavier.
chatsBatchSize = 1000
chatFilesBatchSize = 1000
)
// New creates a new periodically purging database instance.
@@ -114,17 +109,6 @@ func New(ctx context.Context, logger slog.Logger, db database.Store, vals *coder
// purgeTick performs a single purge iteration. It returns an error if the
// purge fails.
func (i *instance) purgeTick(ctx context.Context, db database.Store, start time.Time) error {
// Read chat retention config outside the transaction to
// avoid poisoning the tx if the stored value is corrupt.
// A SQL-level cast error (e.g. non-numeric text) puts PG
// into error state, failing all subsequent queries in the
// same transaction.
chatRetentionDays, err := db.GetChatRetentionDays(ctx)
if err != nil {
i.logger.Warn(ctx, "failed to read chat retention config, skipping chat purge", slog.Error(err))
chatRetentionDays = 0
}
// Start a transaction to grab advisory lock, we don't want to run
// multiple purges at the same time (multiple replicas).
return db.InTx(func(tx database.Store) error {
@@ -229,43 +213,12 @@ func (i *instance) purgeTick(ctx context.Context, db database.Store, start time.
}
}
// Chat retention is configured via site_configs. When
// enabled, old archived chats are deleted first, then
// orphaned chat files. Deleting a chat cascades to
// chat_file_links (removing references) but not to
// chat_files directly, so files from deleted chats
// become orphaned and are caught by DeleteOldChatFiles
// in the same tick.
var purgedChats int64
var purgedChatFiles int64
if chatRetentionDays > 0 {
chatRetention := time.Duration(chatRetentionDays) * 24 * time.Hour
deleteChatsBefore := start.Add(-chatRetention)
purgedChats, err = tx.DeleteOldChats(ctx, database.DeleteOldChatsParams{
BeforeTime: deleteChatsBefore,
LimitCount: chatsBatchSize,
})
if err != nil {
return xerrors.Errorf("failed to delete old chats: %w", err)
}
purgedChatFiles, err = tx.DeleteOldChatFiles(ctx, database.DeleteOldChatFilesParams{
BeforeTime: deleteChatsBefore,
LimitCount: chatFilesBatchSize,
})
if err != nil {
return xerrors.Errorf("failed to delete old chat files: %w", err)
}
}
i.logger.Debug(ctx, "purged old database entries",
slog.F("workspace_agent_logs", purgedWorkspaceAgentLogs),
slog.F("expired_api_keys", expiredAPIKeys),
slog.F("aibridge_records", purgedAIBridgeRecords),
slog.F("connection_logs", purgedConnectionLogs),
slog.F("audit_logs", purgedAuditLogs),
slog.F("chats", purgedChats),
slog.F("chat_files", purgedChatFiles),
slog.F("duration", i.clk.Since(start)),
)
@@ -279,8 +232,6 @@ func (i *instance) purgeTick(ctx context.Context, db database.Store, start time.
i.recordsPurged.WithLabelValues("aibridge_records").Add(float64(purgedAIBridgeRecords))
i.recordsPurged.WithLabelValues("connection_logs").Add(float64(purgedConnectionLogs))
i.recordsPurged.WithLabelValues("audit_logs").Add(float64(purgedAuditLogs))
i.recordsPurged.WithLabelValues("chats").Add(float64(purgedChats))
i.recordsPurged.WithLabelValues("chat_files").Add(float64(purgedChatFiles))
}
return nil
-498
View File
@@ -12,7 +12,6 @@ import (
"time"
"github.com/google/uuid"
"github.com/lib/pq"
"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -54,7 +53,6 @@ func TestPurge(t *testing.T) {
clk := quartz.NewMock(t)
done := awaitDoTick(ctx, t, clk)
mDB := dbmock.NewMockStore(gomock.NewController(t))
mDB.EXPECT().GetChatRetentionDays(gomock.Any()).Return(int32(0), nil).AnyTimes()
mDB.EXPECT().InTx(gomock.Any(), database.DefaultTXOptions().WithID("db_purge")).Return(nil).Times(2)
purger := dbpurge.New(context.Background(), testutil.Logger(t), mDB, &codersdk.DeploymentValues{}, clk, prometheus.NewRegistry())
<-done // wait for doTick() to run.
@@ -127,16 +125,6 @@ func TestMetrics(t *testing.T) {
"record_type": "audit_logs",
})
require.GreaterOrEqual(t, auditLogs, 0)
chats := promhelp.CounterValue(t, reg, "coderd_dbpurge_records_purged_total", prometheus.Labels{
"record_type": "chats",
})
require.GreaterOrEqual(t, chats, 0)
chatFiles := promhelp.CounterValue(t, reg, "coderd_dbpurge_records_purged_total", prometheus.Labels{
"record_type": "chat_files",
})
require.GreaterOrEqual(t, chatFiles, 0)
})
t.Run("FailedIteration", func(t *testing.T) {
@@ -150,7 +138,6 @@ func TestMetrics(t *testing.T) {
ctrl := gomock.NewController(t)
mDB := dbmock.NewMockStore(ctrl)
mDB.EXPECT().GetChatRetentionDays(gomock.Any()).Return(int32(0), nil).AnyTimes()
mDB.EXPECT().InTx(gomock.Any(), database.DefaultTXOptions().WithID("db_purge")).
Return(xerrors.New("simulated database error")).
MinTimes(1)
@@ -1647,488 +1634,3 @@ func TestDeleteExpiredAPIKeys(t *testing.T) {
func ptr[T any](v T) *T {
return &v
}
//nolint:paralleltest // It uses LockIDDBPurge.
func TestDeleteOldChatFiles(t *testing.T) {
now := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC)
// createChatFile inserts a chat file and backdates created_at.
createChatFile := func(ctx context.Context, t *testing.T, db database.Store, rawDB *sql.DB, ownerID, orgID uuid.UUID, createdAt time.Time) uuid.UUID {
t.Helper()
row, err := db.InsertChatFile(ctx, database.InsertChatFileParams{
OwnerID: ownerID,
OrganizationID: orgID,
Name: "test.png",
Mimetype: "image/png",
Data: []byte("fake-image-data"),
})
require.NoError(t, err)
_, err = rawDB.ExecContext(ctx, "UPDATE chat_files SET created_at = $1 WHERE id = $2", createdAt, row.ID)
require.NoError(t, err)
return row.ID
}
// createChat inserts a chat and optionally archives it, then
// backdates updated_at to control the "archived since" window.
createChat := func(ctx context.Context, t *testing.T, db database.Store, rawDB *sql.DB, ownerID, modelConfigID uuid.UUID, archived bool, updatedAt time.Time) database.Chat {
t.Helper()
chat, err := db.InsertChat(ctx, database.InsertChatParams{
OwnerID: ownerID,
LastModelConfigID: modelConfigID,
Title: "test-chat",
Status: database.ChatStatusWaiting,
})
require.NoError(t, err)
if archived {
_, err = db.ArchiveChatByID(ctx, chat.ID)
require.NoError(t, err)
}
_, err = rawDB.ExecContext(ctx, "UPDATE chats SET updated_at = $1 WHERE id = $2", updatedAt, chat.ID)
require.NoError(t, err)
return chat
}
// setupChatDeps creates the common dependencies needed for
// chat-related tests: user, org, org member, provider, model config.
type chatDeps struct {
user database.User
org database.Organization
modelConfig database.ChatModelConfig
}
setupChatDeps := func(ctx context.Context, t *testing.T, db database.Store) chatDeps {
t.Helper()
user := dbgen.User(t, db, database.User{})
org := dbgen.Organization(t, db, database.Organization{})
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: user.ID, OrganizationID: org.ID})
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: "openai",
DisplayName: "OpenAI",
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
mc, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
Provider: "openai",
Model: "test-model",
ContextLimit: 8192,
Options: json.RawMessage("{}"),
})
require.NoError(t, err)
return chatDeps{user: user, org: org, modelConfig: mc}
}
tests := []struct {
name string
run func(t *testing.T)
}{
{
name: "ChatRetentionDisabled",
run: func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
clk := quartz.NewMock(t)
clk.Set(now).MustWait(ctx)
db, _, rawDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
deps := setupChatDeps(ctx, t, db)
// Disable retention.
err := db.UpsertChatRetentionDays(ctx, int32(0))
require.NoError(t, err)
// Create an old archived chat and an orphaned old file.
oldChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, true, now.Add(-31*24*time.Hour))
oldFileID := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-31*24*time.Hour))
done := awaitDoTick(ctx, t, clk)
closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, clk, prometheus.NewRegistry())
defer closer.Close()
testutil.TryReceive(ctx, t, done)
// Both should still exist.
_, err = db.GetChatByID(ctx, oldChat.ID)
require.NoError(t, err, "chat should not be deleted when retention is disabled")
_, err = db.GetChatFileByID(ctx, oldFileID)
require.NoError(t, err, "chat file should not be deleted when retention is disabled")
},
},
{
name: "OldArchivedChatsDeleted",
run: func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
clk := quartz.NewMock(t)
clk.Set(now).MustWait(ctx)
db, _, rawDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
deps := setupChatDeps(ctx, t, db)
err := db.UpsertChatRetentionDays(ctx, int32(30))
require.NoError(t, err)
// Old archived chat (31 days) — should be deleted.
oldChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, true, now.Add(-31*24*time.Hour))
// Insert a message so we can verify CASCADE.
_, err = db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
ChatID: oldChat.ID,
CreatedBy: []uuid.UUID{deps.user.ID},
ModelConfigID: []uuid.UUID{deps.modelConfig.ID},
Role: []database.ChatMessageRole{database.ChatMessageRoleUser},
Content: []string{`[{"type":"text","text":"hello"}]`},
ContentVersion: []int16{0},
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
InputTokens: []int64{0},
OutputTokens: []int64{0},
TotalTokens: []int64{0},
ReasoningTokens: []int64{0},
CacheCreationTokens: []int64{0},
CacheReadTokens: []int64{0},
ContextLimit: []int64{0},
Compressed: []bool{false},
TotalCostMicros: []int64{0},
RuntimeMs: []int64{0},
ProviderResponseID: []string{""},
})
require.NoError(t, err)
// Recently archived chat (10 days) — should be retained.
recentChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, true, now.Add(-10*24*time.Hour))
// Active chat — should be retained.
activeChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, false, now)
done := awaitDoTick(ctx, t, clk)
closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, clk, prometheus.NewRegistry())
defer closer.Close()
testutil.TryReceive(ctx, t, done)
// Old archived chat should be gone.
_, err = db.GetChatByID(ctx, oldChat.ID)
require.Error(t, err, "old archived chat should be deleted")
// Its messages should be gone too (CASCADE).
msgs, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: oldChat.ID,
AfterID: 0,
})
require.NoError(t, err)
require.Empty(t, msgs, "messages should be cascade-deleted")
// Recent archived and active chats should remain.
_, err = db.GetChatByID(ctx, recentChat.ID)
require.NoError(t, err, "recently archived chat should be retained")
_, err = db.GetChatByID(ctx, activeChat.ID)
require.NoError(t, err, "active chat should be retained")
},
},
{
name: "OrphanedOldFilesDeleted",
run: func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
clk := quartz.NewMock(t)
clk.Set(now).MustWait(ctx)
db, _, rawDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
deps := setupChatDeps(ctx, t, db)
err := db.UpsertChatRetentionDays(ctx, int32(30))
require.NoError(t, err)
// File A: 31 days old, NOT in any chat -> should be deleted.
fileA := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-31*24*time.Hour))
// File B: 31 days old, in an active chat -> should be retained.
fileB := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-31*24*time.Hour))
activeChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, false, now)
_, err = db.LinkChatFiles(ctx, database.LinkChatFilesParams{
ChatID: activeChat.ID,
MaxFileLinks: 100,
FileIds: []uuid.UUID{fileB},
})
require.NoError(t, err)
// File C: 10 days old, NOT in any chat -> should be retained (too young).
fileC := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-10*24*time.Hour))
// File near boundary: 29d23h old — close to threshold.
fileBoundary := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-30*24*time.Hour).Add(time.Hour))
done := awaitDoTick(ctx, t, clk)
closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, clk, prometheus.NewRegistry())
defer closer.Close()
testutil.TryReceive(ctx, t, done)
_, err = db.GetChatFileByID(ctx, fileA)
require.Error(t, err, "orphaned old file A should be deleted")
_, err = db.GetChatFileByID(ctx, fileB)
require.NoError(t, err, "file B in active chat should be retained")
_, err = db.GetChatFileByID(ctx, fileC)
require.NoError(t, err, "young file C should be retained")
_, err = db.GetChatFileByID(ctx, fileBoundary)
require.NoError(t, err, "file near 30d boundary should be retained")
},
},
{
name: "ArchivedChatFilesDeleted",
run: func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
clk := quartz.NewMock(t)
clk.Set(now).MustWait(ctx)
db, _, rawDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
deps := setupChatDeps(ctx, t, db)
err := db.UpsertChatRetentionDays(ctx, int32(30))
require.NoError(t, err)
// File D: 31 days old, in a chat archived 31 days ago -> should be deleted.
fileD := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-31*24*time.Hour))
oldArchivedChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, true, now.Add(-31*24*time.Hour))
_, err = db.LinkChatFiles(ctx, database.LinkChatFilesParams{
ChatID: oldArchivedChat.ID,
MaxFileLinks: 100,
FileIds: []uuid.UUID{fileD},
})
require.NoError(t, err)
// LinkChatFiles does not update chats.updated_at, so backdate.
_, err = rawDB.ExecContext(ctx, "UPDATE chats SET updated_at = $1 WHERE id = $2",
now.Add(-31*24*time.Hour), oldArchivedChat.ID)
require.NoError(t, err)
// File E: 31 days old, in a chat archived 10 days ago -> should be retained.
fileE := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-31*24*time.Hour))
recentArchivedChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, true, now.Add(-10*24*time.Hour))
_, err = db.LinkChatFiles(ctx, database.LinkChatFilesParams{
ChatID: recentArchivedChat.ID,
MaxFileLinks: 100,
FileIds: []uuid.UUID{fileE},
})
require.NoError(t, err)
_, err = rawDB.ExecContext(ctx, "UPDATE chats SET updated_at = $1 WHERE id = $2",
now.Add(-10*24*time.Hour), recentArchivedChat.ID)
require.NoError(t, err)
// File F: 31 days old, in BOTH an active chat AND an old archived chat -> should be retained.
fileF := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-31*24*time.Hour))
anotherOldArchivedChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, true, now.Add(-31*24*time.Hour))
_, err = db.LinkChatFiles(ctx, database.LinkChatFilesParams{
ChatID: anotherOldArchivedChat.ID,
MaxFileLinks: 100,
FileIds: []uuid.UUID{fileF},
})
require.NoError(t, err)
_, err = rawDB.ExecContext(ctx, "UPDATE chats SET updated_at = $1 WHERE id = $2",
now.Add(-31*24*time.Hour), anotherOldArchivedChat.ID)
require.NoError(t, err)
activeChatForF := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, false, now)
_, err = db.LinkChatFiles(ctx, database.LinkChatFilesParams{
ChatID: activeChatForF.ID,
MaxFileLinks: 100,
FileIds: []uuid.UUID{fileF},
})
require.NoError(t, err)
done := awaitDoTick(ctx, t, clk)
closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, clk, prometheus.NewRegistry())
defer closer.Close()
testutil.TryReceive(ctx, t, done)
_, err = db.GetChatFileByID(ctx, fileD)
require.Error(t, err, "file D in old archived chat should be deleted")
_, err = db.GetChatFileByID(ctx, fileE)
require.NoError(t, err, "file E in recently archived chat should be retained")
_, err = db.GetChatFileByID(ctx, fileF)
require.NoError(t, err, "file F in active + old archived chat should be retained")
},
},
{
name: "UnarchiveAfterFilePurge",
run: func(t *testing.T) {
// Validates that when dbpurge deletes chat_files rows,
// the FK cascade on chat_file_links automatically
// removes the stale links. Unarchiving a chat after
// file purge should show only surviving files.
ctx := testutil.Context(t, testutil.WaitLong)
db, _, rawDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
deps := setupChatDeps(ctx, t, db)
// Create a chat with three attached files.
fileA := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now)
fileB := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now)
fileC := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now)
chat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, false, now)
_, err := db.LinkChatFiles(ctx, database.LinkChatFilesParams{
ChatID: chat.ID,
MaxFileLinks: 100,
FileIds: []uuid.UUID{fileA, fileB, fileC},
})
require.NoError(t, err)
// Archive the chat.
_, err = db.ArchiveChatByID(ctx, chat.ID)
require.NoError(t, err)
// Simulate dbpurge deleting files A and B. The FK
// cascade on chat_file_links_file_id_fkey should
// automatically remove the corresponding link rows.
_, err = rawDB.ExecContext(ctx, "DELETE FROM chat_files WHERE id = ANY($1)", pq.Array([]uuid.UUID{fileA, fileB}))
require.NoError(t, err)
// Unarchive the chat.
_, err = db.UnarchiveChatByID(ctx, chat.ID)
require.NoError(t, err)
// Only file C should remain linked (FK cascade
// removed the links for deleted files A and B).
files, err := db.GetChatFileMetadataByChatID(ctx, chat.ID)
require.NoError(t, err)
require.Len(t, files, 1, "only surviving file should be linked")
require.Equal(t, fileC, files[0].ID)
// Edge case: delete the last file too. The chat
// should have zero linked files, not an error.
_, err = db.ArchiveChatByID(ctx, chat.ID)
require.NoError(t, err)
_, err = rawDB.ExecContext(ctx, "DELETE FROM chat_files WHERE id = $1", fileC)
require.NoError(t, err)
_, err = db.UnarchiveChatByID(ctx, chat.ID)
require.NoError(t, err)
files, err = db.GetChatFileMetadataByChatID(ctx, chat.ID)
require.NoError(t, err)
require.Empty(t, files, "all-files-deleted should yield empty result")
// Test parent+child cascade: deleting files should
// clean up links for both parent and child chats
// independently via FK cascade.
parentChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, false, now)
childChat, err := db.InsertChat(ctx, database.InsertChatParams{
OwnerID: deps.user.ID,
LastModelConfigID: deps.modelConfig.ID,
Title: "child-chat",
Status: database.ChatStatusWaiting,
})
require.NoError(t, err)
// Set root_chat_id to link child to parent.
_, err = rawDB.ExecContext(ctx, "UPDATE chats SET root_chat_id = $1 WHERE id = $2", parentChat.ID, childChat.ID)
require.NoError(t, err)
// Attach different files to parent and child.
parentFileKeep := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now)
parentFileStale := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now)
childFileKeep := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now)
childFileStale := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now)
_, err = db.LinkChatFiles(ctx, database.LinkChatFilesParams{
ChatID: parentChat.ID,
MaxFileLinks: 100,
FileIds: []uuid.UUID{parentFileKeep, parentFileStale},
})
require.NoError(t, err)
_, err = db.LinkChatFiles(ctx, database.LinkChatFilesParams{
ChatID: childChat.ID,
MaxFileLinks: 100,
FileIds: []uuid.UUID{childFileKeep, childFileStale},
})
require.NoError(t, err)
// Archive via parent (cascades to child).
_, err = db.ArchiveChatByID(ctx, parentChat.ID)
require.NoError(t, err)
// Delete one file from each chat.
_, err = rawDB.ExecContext(ctx, "DELETE FROM chat_files WHERE id = ANY($1)",
pq.Array([]uuid.UUID{parentFileStale, childFileStale}))
require.NoError(t, err)
// Unarchive via parent.
_, err = db.UnarchiveChatByID(ctx, parentChat.ID)
require.NoError(t, err)
parentFiles, err := db.GetChatFileMetadataByChatID(ctx, parentChat.ID)
require.NoError(t, err)
require.Len(t, parentFiles, 1)
require.Equal(t, parentFileKeep, parentFiles[0].ID,
"parent should retain only non-stale file")
childFiles, err := db.GetChatFileMetadataByChatID(ctx, childChat.ID)
require.NoError(t, err)
require.Len(t, childFiles, 1)
require.Equal(t, childFileKeep, childFiles[0].ID,
"child should retain only non-stale file")
},
},
{
name: "BatchLimitFiles",
run: func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
db, _, rawDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
deps := setupChatDeps(ctx, t, db)
// Create 3 deletable orphaned files (all 31 days old).
for range 3 {
createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-31*24*time.Hour))
}
// Delete with limit 2 — should delete 2, leave 1.
deleted, err := db.DeleteOldChatFiles(ctx, database.DeleteOldChatFilesParams{
BeforeTime: now.Add(-30 * 24 * time.Hour),
LimitCount: 2,
})
require.NoError(t, err)
require.Equal(t, int64(2), deleted, "should delete exactly 2 files")
// Delete again — should delete the remaining 1.
deleted, err = db.DeleteOldChatFiles(ctx, database.DeleteOldChatFilesParams{
BeforeTime: now.Add(-30 * 24 * time.Hour),
LimitCount: 2,
})
require.NoError(t, err)
require.Equal(t, int64(1), deleted, "should delete remaining 1 file")
},
},
{
name: "BatchLimitChats",
run: func(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
db, _, rawDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
deps := setupChatDeps(ctx, t, db)
// Create 3 deletable old archived chats.
for range 3 {
createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, true, now.Add(-31*24*time.Hour))
}
// Delete with limit 2 — should delete 2, leave 1.
deleted, err := db.DeleteOldChats(ctx, database.DeleteOldChatsParams{
BeforeTime: now.Add(-30 * 24 * time.Hour),
LimitCount: 2,
})
require.NoError(t, err)
require.Equal(t, int64(2), deleted, "should delete exactly 2 chats")
// Delete again — should delete the remaining 1.
deleted, err = db.DeleteOldChats(ctx, database.DeleteOldChatsParams{
BeforeTime: now.Add(-30 * 24 * time.Hour),
LimitCount: 2,
})
require.NoError(t, err)
require.Equal(t, int64(1), deleted, "should delete remaining 1 chat")
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
tc.run(t)
})
}
}
+3 -16
View File
@@ -293,8 +293,7 @@ CREATE TYPE chat_status AS ENUM (
'running',
'paused',
'completed',
'error',
'requires_action'
'error'
);
CREATE TYPE connection_status AS ENUM (
@@ -316,11 +315,6 @@ CREATE TYPE cors_behavior AS ENUM (
'passthru'
);
CREATE TYPE credential_kind AS ENUM (
'centralized',
'byok'
);
CREATE TYPE crypto_key_feature AS ENUM (
'workspace_apps_token',
'workspace_apps_api_key',
@@ -1107,9 +1101,7 @@ CREATE TABLE aibridge_interceptions (
thread_root_id uuid,
client_session_id character varying(256),
session_id text GENERATED ALWAYS AS (COALESCE(client_session_id, ((thread_root_id)::text)::character varying, ((id)::text)::character varying)) STORED NOT NULL,
provider_name text DEFAULT ''::text NOT NULL,
credential_kind credential_kind DEFAULT 'centralized'::credential_kind NOT NULL,
credential_hint character varying(15) DEFAULT ''::character varying NOT NULL
provider_name text DEFAULT ''::text NOT NULL
);
COMMENT ON TABLE aibridge_interceptions IS 'Audit log of requests intercepted by AI Bridge';
@@ -1126,10 +1118,6 @@ COMMENT ON COLUMN aibridge_interceptions.session_id IS 'Groups related intercept
COMMENT ON COLUMN aibridge_interceptions.provider_name IS 'The provider instance name which may differ from provider when multiple instances of the same provider type exist.';
COMMENT ON COLUMN aibridge_interceptions.credential_kind IS 'How the request was authenticated: centralized or byok.';
COMMENT ON COLUMN aibridge_interceptions.credential_hint IS 'Masked credential identifier for audit (e.g. sk-a***efgh).';
CREATE TABLE aibridge_model_thoughts (
interception_id uuid NOT NULL,
content text NOT NULL,
@@ -1430,8 +1418,7 @@ CREATE TABLE chats (
agent_id uuid,
pin_order integer DEFAULT 0 NOT NULL,
last_read_message_id bigint,
last_injected_context jsonb,
dynamic_tools jsonb
last_injected_context jsonb
);
CREATE TABLE connection_logs (
@@ -1,31 +0,0 @@
-- First update any rows using the value we're about to remove.
-- The column type is still the original chat_status at this point.
UPDATE chats SET status = 'error' WHERE status = 'requires_action';
-- Drop the column (this is independent of the enum).
ALTER TABLE chats DROP COLUMN IF EXISTS dynamic_tools;
-- Drop the partial index that references the chat_status enum type.
-- It must be removed before the rename-create-cast-drop cycle
-- because the index's WHERE clause (status = 'pending'::chat_status)
-- would otherwise cause a cross-type comparison failure.
DROP INDEX IF EXISTS idx_chats_pending;
-- Now recreate the enum without requires_action.
-- We must use the rename-create-cast-drop pattern.
ALTER TYPE chat_status RENAME TO chat_status_old;
CREATE TYPE chat_status AS ENUM (
'waiting',
'pending',
'running',
'paused',
'completed',
'error'
);
ALTER TABLE chats ALTER COLUMN status DROP DEFAULT;
ALTER TABLE chats ALTER COLUMN status TYPE chat_status USING status::text::chat_status;
ALTER TABLE chats ALTER COLUMN status SET DEFAULT 'waiting';
DROP TYPE chat_status_old;
-- Recreate the partial index.
CREATE INDEX idx_chats_pending ON chats USING btree (status) WHERE (status = 'pending'::chat_status);
@@ -1,3 +0,0 @@
ALTER TYPE chat_status ADD VALUE IF NOT EXISTS 'requires_action';
ALTER TABLE chats ADD COLUMN dynamic_tools JSONB DEFAULT NULL;
@@ -1,5 +0,0 @@
ALTER TABLE aibridge_interceptions
DROP COLUMN IF EXISTS credential_kind,
DROP COLUMN IF EXISTS credential_hint;
DROP TYPE IF EXISTS credential_kind;
@@ -1,12 +0,0 @@
CREATE TYPE credential_kind AS ENUM ('centralized', 'byok');
-- Records how each LLM request was authenticated and a masked credential
-- identifier for audit purposes. Existing rows default to 'centralized'
-- with an empty hint since we cannot retroactively determine their values.
ALTER TABLE aibridge_interceptions
ADD COLUMN credential_kind credential_kind NOT NULL DEFAULT 'centralized',
-- Length capped as a safety measure to ensure only masked values are stored.
ADD COLUMN credential_hint CHARACTER VARYING(15) NOT NULL DEFAULT '';
COMMENT ON COLUMN aibridge_interceptions.credential_kind IS 'How the request was authenticated: centralized or byok.';
COMMENT ON COLUMN aibridge_interceptions.credential_hint IS 'Masked credential identifier for audit (e.g. sk-a***efgh).';
-5
View File
@@ -798,7 +798,6 @@ func (q *sqlQuerier) GetAuthorizedChats(ctx context.Context, arg GetChatsParams,
&i.Chat.PinOrder,
&i.Chat.LastReadMessageID,
&i.Chat.LastInjectedContext,
&i.Chat.DynamicTools,
&i.HasUnread); err != nil {
return nil, err
}
@@ -869,8 +868,6 @@ func (q *sqlQuerier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, ar
&i.AIBridgeInterception.ClientSessionID,
&i.AIBridgeInterception.SessionID,
&i.AIBridgeInterception.ProviderName,
&i.AIBridgeInterception.CredentialKind,
&i.AIBridgeInterception.CredentialHint,
&i.VisibleUser.ID,
&i.VisibleUser.Username,
&i.VisibleUser.Name,
@@ -1134,8 +1131,6 @@ func (q *sqlQuerier) ListAuthorizedAIBridgeSessionThreads(ctx context.Context, a
&i.AIBridgeInterception.ClientSessionID,
&i.AIBridgeInterception.SessionID,
&i.AIBridgeInterception.ProviderName,
&i.AIBridgeInterception.CredentialKind,
&i.AIBridgeInterception.CredentialHint,
); err != nil {
return nil, err
}
+7 -73
View File
@@ -1290,13 +1290,12 @@ func AllChatModeValues() []ChatMode {
type ChatStatus string
const (
ChatStatusWaiting ChatStatus = "waiting"
ChatStatusPending ChatStatus = "pending"
ChatStatusRunning ChatStatus = "running"
ChatStatusPaused ChatStatus = "paused"
ChatStatusCompleted ChatStatus = "completed"
ChatStatusError ChatStatus = "error"
ChatStatusRequiresAction ChatStatus = "requires_action"
ChatStatusWaiting ChatStatus = "waiting"
ChatStatusPending ChatStatus = "pending"
ChatStatusRunning ChatStatus = "running"
ChatStatusPaused ChatStatus = "paused"
ChatStatusCompleted ChatStatus = "completed"
ChatStatusError ChatStatus = "error"
)
func (e *ChatStatus) Scan(src interface{}) error {
@@ -1341,8 +1340,7 @@ func (e ChatStatus) Valid() bool {
ChatStatusRunning,
ChatStatusPaused,
ChatStatusCompleted,
ChatStatusError,
ChatStatusRequiresAction:
ChatStatusError:
return true
}
return false
@@ -1356,7 +1354,6 @@ func AllChatStatusValues() []ChatStatus {
ChatStatusPaused,
ChatStatusCompleted,
ChatStatusError,
ChatStatusRequiresAction,
}
}
@@ -1546,64 +1543,6 @@ func AllCorsBehaviorValues() []CorsBehavior {
}
}
// CredentialKind mirrors the credential_kind database enum: how a request was
// authenticated ("centralized" or "byok").
type CredentialKind string

const (
	CredentialKindCentralized CredentialKind = "centralized"
	CredentialKindByok        CredentialKind = "byok"
)

// Scan implements sql.Scanner, accepting []byte or string column values.
func (e *CredentialKind) Scan(src interface{}) error {
	switch s := src.(type) {
	case []byte:
		*e = CredentialKind(s)
	case string:
		*e = CredentialKind(s)
	default:
		return fmt.Errorf("unsupported scan type for CredentialKind: %T", src)
	}
	return nil
}

// NullCredentialKind is the nullable wrapper for CredentialKind columns.
type NullCredentialKind struct {
	CredentialKind CredentialKind `json:"credential_kind"`
	Valid          bool           `json:"valid"` // Valid is true if CredentialKind is not NULL
}

// Scan implements the Scanner interface.
func (ns *NullCredentialKind) Scan(value interface{}) error {
	if value == nil {
		ns.CredentialKind, ns.Valid = "", false
		return nil
	}
	ns.Valid = true
	return ns.CredentialKind.Scan(value)
}

// Value implements the driver Valuer interface.
func (ns NullCredentialKind) Value() (driver.Value, error) {
	if !ns.Valid {
		return nil, nil
	}
	return string(ns.CredentialKind), nil
}

// Valid reports whether e is one of the declared enum values.
func (e CredentialKind) Valid() bool {
	switch e {
	case CredentialKindCentralized,
		CredentialKindByok:
		return true
	}
	return false
}

// AllCredentialKindValues returns every declared enum value.
func AllCredentialKindValues() []CredentialKind {
	return []CredentialKind{
		CredentialKindCentralized,
		CredentialKindByok,
	}
}
type CryptoKeyFeature string
const (
@@ -4101,10 +4040,6 @@ type AIBridgeInterception struct {
SessionID string `db:"session_id" json:"session_id"`
// The provider instance name which may differ from provider when multiple instances of the same provider type exist.
ProviderName string `db:"provider_name" json:"provider_name"`
// How the request was authenticated: centralized or byok.
CredentialKind CredentialKind `db:"credential_kind" json:"credential_kind"`
// Masked credential identifier for audit (e.g. sk-a***efgh).
CredentialHint string `db:"credential_hint" json:"credential_hint"`
}
// Audit log of model thinking in intercepted requests in AI Bridge
@@ -4245,7 +4180,6 @@ type Chat struct {
PinOrder int32 `db:"pin_order" json:"pin_order"`
LastReadMessageID sql.NullInt64 `db:"last_read_message_id" json:"last_read_message_id"`
LastInjectedContext pqtype.NullRawMessage `db:"last_injected_context" json:"last_injected_context"`
DynamicTools pqtype.NullRawMessage `db:"dynamic_tools" json:"dynamic_tools"`
}
type ChatDiffStatus struct {
-749
View File
@@ -1,749 +0,0 @@
package pubsub
import (
"bytes"
"context"
"database/sql"
"errors"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/lib/pq"
"github.com/prometheus/client_golang/prometheus"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/quartz"
)
const (
	// DefaultBatchingFlushInterval is the default upper bound on how long chatd
	// publishes wait before a scheduled flush when nearby publishes do not
	// naturally coalesce sooner.
	DefaultBatchingFlushInterval = 50 * time.Millisecond
	// DefaultBatchingQueueSize is the default number of buffered chatd publish
	// requests waiting to be flushed.
	DefaultBatchingQueueSize = 8192
	// defaultBatchingPressureWait bounds how long Publish blocks waiting for
	// queue space before falling back to the shared pubsub pool.
	defaultBatchingPressureWait = 10 * time.Millisecond
	// defaultBatchingFinalFlushLimit bounds the best-effort drain performed
	// during Close before remaining queued notifications are dropped.
	defaultBatchingFinalFlushLimit = 15 * time.Second
	// batchingWarnInterval rate-limits "queue is full" warnings; rejections
	// between ticks are logged at debug level instead.
	batchingWarnInterval = 10 * time.Second

	// Flush reasons, used as the label on the FlushDuration metric.
	batchFlushScheduled = "scheduled"
	batchFlushShutdown  = "shutdown"

	// Flush stages, used as the "stage" label on delegate fallback metrics.
	batchFlushStageNone   = "none"
	batchFlushStageBegin  = "begin"
	batchFlushStageExec   = "exec"
	batchFlushStageCommit = "commit"

	// Reasons a publish was replayed through the shared delegate pool.
	batchDelegateFallbackReasonQueueFull  = "queue_full"
	batchDelegateFallbackReasonFlushError = "flush_error"

	// Channel classes derived from the event name; see batchChannelClass.
	batchChannelClassStreamNotify = "stream_notify"
	batchChannelClassOwnerEvent   = "owner_event"
	batchChannelClassConfigChange = "config_change"
	batchChannelClassOther        = "other"
)

// ErrBatchingPubsubClosed is returned when a batched pubsub publish is
// attempted after shutdown has started.
var ErrBatchingPubsubClosed = xerrors.New("batched pubsub is closed")
// BatchingConfig controls the chatd-specific PostgreSQL pubsub batching path.
// Flush timing is automatic: the run loop wakes every FlushInterval (or on
// backpressure) and drains everything currently queued into a single
// transaction. There is no fixed batch-size knob — the batch size is simply
// whatever accumulated since the last flush, which naturally adapts to load.
//
// A zero value for any field selects the corresponding package default.
type BatchingConfig struct {
	// FlushInterval is the upper bound between scheduled flushes.
	FlushInterval time.Duration
	// QueueSize is the capacity of the buffered publish channel.
	QueueSize int
	// PressureWait is how long Publish blocks for space on a full queue
	// before falling back to the shared delegate pool.
	PressureWait time.Duration
	// FinalFlushTimeout bounds the drain performed during Close.
	FinalFlushTimeout time.Duration
	// Clock is injectable for tests; defaults to the real clock.
	Clock quartz.Clock
}

// queuedPublish is one logical notification waiting in the batching queue.
type queuedPublish struct {
	event string
	// channelClass is derived once via batchChannelClass for metric labels.
	channelClass string
	// message is defensively cloned from the caller's slice in Publish.
	message []byte
}
// batchSender delivers a whole batch of notifications, typically inside a
// single transaction on a dedicated connection.
type batchSender interface {
	Flush(ctx context.Context, batch []queuedPublish) error
	Close() error
}

// batchFlushError wraps a sender failure with the flush stage (begin, exec,
// or commit) at which it occurred, for metric labeling via batchFlushStage.
type batchFlushError struct {
	stage string
	err   error
}

// Error returns the underlying error text; the stage is exposed only through
// batchFlushStage.
func (e *batchFlushError) Error() string {
	return e.err.Error()
}

// Unwrap supports errors.Is/errors.As on the wrapped error.
func (e *batchFlushError) Unwrap() error {
	return e.err
}
// BatchingPubsub batches chatd publish traffic onto a dedicated PostgreSQL
// sender connection while delegating subscribe behavior to the shared listener
// pubsub instance.
type BatchingPubsub struct {
	logger   slog.Logger
	delegate *PGPubsub
	// sender is only accessed from the run() goroutine (including
	// flushBatch and resetSender which it calls). Do not read or
	// write this field from Publish or any other goroutine.
	sender    batchSender
	newSender func(context.Context) (batchSender, error)
	clock     quartz.Clock

	// publishCh carries queued publishes to the run() goroutine; its
	// capacity is the configured queue size.
	publishCh chan queuedPublish
	// flushCh requests an early "pressure" flush (capacity 1; a pending
	// signal already covers subsequent requests).
	flushCh chan struct{}
	closeCh chan struct{}
	doneCh  chan struct{}

	// spaceSignal is closed and replaced whenever queue space frees up;
	// guarded by spaceMu.
	spaceMu     sync.Mutex
	spaceSignal chan struct{}

	// warnTicker rate-limits queue-full warnings in logPublishRejection.
	warnTicker *quartz.Ticker

	flushInterval     time.Duration
	pressureWait      time.Duration
	finalFlushTimeout time.Duration

	// queuedCount tracks the logical queue depth (enqueued minus flushed
	// or dropped) and backs the QueueDepth gauge.
	queuedCount atomic.Int64
	closed      atomic.Bool
	closeOnce   sync.Once
	closeErr    error
	runErr      error
	runCtx      context.Context
	cancel      context.CancelFunc

	metrics batchingMetrics
}

// batchingMetrics holds the Prometheus collectors for the batching path.
type batchingMetrics struct {
	QueueDepth               prometheus.Gauge
	BatchSize                prometheus.Histogram
	FlushDuration            *prometheus.HistogramVec
	DelegateFallbacksTotal   *prometheus.CounterVec
	SenderResetsTotal        prometheus.Counter
	SenderResetFailuresTotal prometheus.Counter
}
// newBatchingMetrics constructs all batching collectors under the
// coder_pubsub_* namespace. They are exposed through BatchingPubsub's
// prometheus.Collector implementation rather than registered here.
func newBatchingMetrics() batchingMetrics {
	return batchingMetrics{
		QueueDepth: prometheus.NewGauge(prometheus.GaugeOpts{
			Namespace: "coder",
			Subsystem: "pubsub",
			Name:      "batch_queue_depth",
			Help:      "The number of chatd notifications waiting in the batching queue.",
		}),
		BatchSize: prometheus.NewHistogram(prometheus.HistogramOpts{
			Namespace: "coder",
			Subsystem: "pubsub",
			Name:      "batch_size",
			Help:      "The number of logical notifications sent in each chatd batch flush.",
			Buckets:   []float64{1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192},
		}),
		FlushDuration: prometheus.NewHistogramVec(prometheus.HistogramOpts{
			Namespace: "coder",
			Subsystem: "pubsub",
			Name:      "batch_flush_duration_seconds",
			Help:      "The time spent flushing one chatd batch to PostgreSQL.",
			Buckets:   []float64{0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10, 20, 30},
		}, []string{"reason"}),
		DelegateFallbacksTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
			Namespace: "coder",
			Subsystem: "pubsub",
			Name:      "batch_delegate_fallbacks_total",
			Help:      "The number of chatd publishes that fell back to the shared pubsub pool by channel class, reason, and flush stage.",
		}, []string{"channel_class", "reason", "stage"}),
		SenderResetsTotal: prometheus.NewCounter(prometheus.CounterOpts{
			Namespace: "coder",
			Subsystem: "pubsub",
			Name:      "batch_sender_resets_total",
			Help:      "The number of successful batched pubsub sender resets after flush failures.",
		}),
		SenderResetFailuresTotal: prometheus.NewCounter(prometheus.CounterOpts{
			Namespace: "coder",
			Subsystem: "pubsub",
			Name:      "batch_sender_reset_failures_total",
			Help:      "The number of batched pubsub sender reset attempts that failed.",
		}),
	}
}
// Describe forwards all collector descriptions; part of BatchingPubsub's
// prometheus.Collector implementation.
func (m batchingMetrics) Describe(descs chan<- *prometheus.Desc) {
	m.QueueDepth.Describe(descs)
	m.BatchSize.Describe(descs)
	m.FlushDuration.Describe(descs)
	m.DelegateFallbacksTotal.Describe(descs)
	m.SenderResetsTotal.Describe(descs)
	m.SenderResetFailuresTotal.Describe(descs)
}

// Collect forwards all collector samples; part of BatchingPubsub's
// prometheus.Collector implementation.
func (m batchingMetrics) Collect(metrics chan<- prometheus.Metric) {
	m.QueueDepth.Collect(metrics)
	m.BatchSize.Collect(metrics)
	m.FlushDuration.Collect(metrics)
	m.DelegateFallbacksTotal.Collect(metrics)
	m.SenderResetsTotal.Collect(metrics)
	m.SenderResetFailuresTotal.Collect(metrics)
}
// NewBatching creates a chatd-specific batched pubsub wrapper around the
// shared PostgreSQL listener implementation. A dedicated sender connection is
// dialed via connectURL (through newConnector, using the prototype DB); the
// delegate continues to serve subscriptions and acts as the fallback publish
// path. On construction failure the sender is closed before returning.
func NewBatching(
	ctx context.Context,
	logger slog.Logger,
	delegate *PGPubsub,
	prototype *sql.DB,
	connectURL string,
	cfg BatchingConfig,
) (*BatchingPubsub, error) {
	if delegate == nil {
		return nil, xerrors.New("delegate pubsub is nil")
	}
	if prototype == nil {
		return nil, xerrors.New("prototype database is nil")
	}
	if connectURL == "" {
		return nil, xerrors.New("connect URL is empty")
	}
	// newSender is retained on the struct so resetSender can re-dial after
	// flush failures.
	newSender := func(ctx context.Context) (batchSender, error) {
		return newPGBatchSender(ctx, logger.Named("sender"), prototype, connectURL)
	}
	sender, err := newSender(ctx)
	if err != nil {
		return nil, err
	}
	ps, err := newBatchingPubsub(logger, delegate, sender, cfg)
	if err != nil {
		_ = sender.Close()
		return nil, err
	}
	ps.newSender = newSender
	return ps, nil
}
// newBatchingPubsub wires up a BatchingPubsub around an already-constructed
// sender, applies defaults for zero-valued config fields, rejects negative
// ones, and starts the run() goroutine.
//
// NOTE(review): the "must be positive" messages fire only for negative
// values — zero already selected the default above each check — so "must not
// be negative" would describe the actual constraint more precisely.
func newBatchingPubsub(
	logger slog.Logger,
	delegate *PGPubsub,
	sender batchSender,
	cfg BatchingConfig,
) (*BatchingPubsub, error) {
	if delegate == nil {
		return nil, xerrors.New("delegate pubsub is nil")
	}
	if sender == nil {
		return nil, xerrors.New("batch sender is nil")
	}
	flushInterval := cfg.FlushInterval
	if flushInterval == 0 {
		flushInterval = DefaultBatchingFlushInterval
	}
	if flushInterval < 0 {
		return nil, xerrors.New("flush interval must be positive")
	}
	queueSize := cfg.QueueSize
	if queueSize == 0 {
		queueSize = DefaultBatchingQueueSize
	}
	if queueSize < 0 {
		return nil, xerrors.New("queue size must be positive")
	}
	pressureWait := cfg.PressureWait
	if pressureWait == 0 {
		pressureWait = defaultBatchingPressureWait
	}
	if pressureWait < 0 {
		return nil, xerrors.New("pressure wait must be positive")
	}
	finalFlushTimeout := cfg.FinalFlushTimeout
	if finalFlushTimeout == 0 {
		finalFlushTimeout = defaultBatchingFinalFlushLimit
	}
	if finalFlushTimeout < 0 {
		return nil, xerrors.New("final flush timeout must be positive")
	}
	clock := cfg.Clock
	if clock == nil {
		clock = quartz.NewReal()
	}
	// runCtx outlives any caller context: flushes continue until Close.
	runCtx, cancel := context.WithCancel(context.Background())
	ps := &BatchingPubsub{
		logger:            logger,
		delegate:          delegate,
		sender:            sender,
		clock:             clock,
		publishCh:         make(chan queuedPublish, queueSize),
		flushCh:           make(chan struct{}, 1),
		closeCh:           make(chan struct{}),
		doneCh:            make(chan struct{}),
		spaceSignal:       make(chan struct{}),
		warnTicker:        clock.NewTicker(batchingWarnInterval, "pubsubBatcher", "warn"),
		flushInterval:     flushInterval,
		pressureWait:      pressureWait,
		finalFlushTimeout: finalFlushTimeout,
		runCtx:            runCtx,
		cancel:            cancel,
		metrics:           newBatchingMetrics(),
	}
	ps.metrics.QueueDepth.Set(0)
	go ps.run()
	return ps, nil
}
// Describe implements prometheus.Collector.
func (p *BatchingPubsub) Describe(descs chan<- *prometheus.Desc) {
	p.metrics.Describe(descs)
}

// Collect implements prometheus.Collector.
func (p *BatchingPubsub) Collect(metrics chan<- prometheus.Metric) {
	p.metrics.Collect(metrics)
}

// Subscribe delegates to the shared PostgreSQL listener pubsub; batching
// applies only to the publish path.
func (p *BatchingPubsub) Subscribe(event string, listener Listener) (func(), error) {
	return p.delegate.Subscribe(event, listener)
}

// SubscribeWithErr delegates to the shared PostgreSQL listener pubsub.
func (p *BatchingPubsub) SubscribeWithErr(event string, listener ListenerWithErr) (func(), error) {
	return p.delegate.SubscribeWithErr(event, listener)
}
// Publish enqueues a logical notification for asynchronous batched delivery.
//
// Fast path: clone the message and enqueue without blocking. If the queue is
// full, request a pressure flush and wait up to pressureWait for space; if
// still full, publish synchronously through the shared delegate pool so the
// notification is delivered rather than dropped. Returns
// ErrBatchingPubsubClosed once shutdown has started.
func (p *BatchingPubsub) Publish(event string, message []byte) error {
	channelClass := batchChannelClass(event)
	if p.closed.Load() {
		return ErrBatchingPubsubClosed
	}
	// Clone so the caller may reuse its slice after Publish returns.
	req := queuedPublish{
		event:        event,
		channelClass: channelClass,
		message:      bytes.Clone(message),
	}
	if p.tryEnqueue(req) {
		return nil
	}
	timer := p.clock.NewTimer(p.pressureWait, "pubsubBatcher", "pressureWait")
	defer timer.Stop("pubsubBatcher", "pressureWait")
	for {
		if p.closed.Load() {
			return ErrBatchingPubsubClosed
		}
		p.signalPressureFlush()
		// Snapshot the space signal before retrying the enqueue: if
		// space is released between the failed enqueue and the select,
		// the snapshot channel is already closed and we won't sleep
		// through the wakeup.
		spaceSignal := p.currentSpaceSignal()
		if p.tryEnqueue(req) {
			return nil
		}
		select {
		case <-spaceSignal:
			continue
		case <-timer.C:
			if p.tryEnqueue(req) {
				return nil
			}
			// The batching queue is still full after a pressure
			// flush and brief wait. Fall back to the shared
			// pubsub pool so the notification is still delivered
			// rather than dropped.
			p.observeDelegateFallback(channelClass, batchDelegateFallbackReasonQueueFull, batchFlushStageNone)
			p.logPublishRejection(event)
			return p.delegate.Publish(event, message)
		case <-p.doneCh:
			return ErrBatchingPubsubClosed
		}
	}
}
// Close stops accepting new publishes, performs a bounded best-effort drain,
// and then closes the dedicated sender connection. Safe to call multiple
// times; later calls return the first result.
func (p *BatchingPubsub) Close() error {
	p.closeOnce.Do(func() {
		p.closed.Store(true)
		p.cancel()
		// Wake publishers blocked waiting for queue space so they can
		// observe the closed state.
		p.notifySpaceAvailable()
		close(p.closeCh)
		// Wait for run() to finish its final drain and sender close.
		<-p.doneCh
		p.closeErr = p.runErr
	})
	return p.closeErr
}
// tryEnqueue attempts a non-blocking enqueue, updating the queue depth gauge
// on success. Returns false when the pubsub is closed or the queue is full.
func (p *BatchingPubsub) tryEnqueue(req queuedPublish) bool {
	if p.closed.Load() {
		return false
	}
	select {
	case p.publishCh <- req:
		queuedDepth := p.queuedCount.Add(1)
		p.observeQueueDepth(queuedDepth)
		return true
	default:
		return false
	}
}
// observeQueueDepth records the current logical queue depth on the gauge.
func (p *BatchingPubsub) observeQueueDepth(depth int64) {
	p.metrics.QueueDepth.Set(float64(depth))
}

// signalPressureFlush nudges the run loop to flush early. Non-blocking:
// flushCh has capacity 1 and an already-pending signal is equivalent.
func (p *BatchingPubsub) signalPressureFlush() {
	select {
	case p.flushCh <- struct{}{}:
	default:
	}
}

// currentSpaceSignal returns the channel that will be closed the next time
// queue space is released.
func (p *BatchingPubsub) currentSpaceSignal() <-chan struct{} {
	p.spaceMu.Lock()
	defer p.spaceMu.Unlock()
	return p.spaceSignal
}

// notifySpaceAvailable wakes every publisher blocked on the current signal by
// closing it, then installs a fresh channel for the next round.
func (p *BatchingPubsub) notifySpaceAvailable() {
	p.spaceMu.Lock()
	defer p.spaceMu.Unlock()
	close(p.spaceSignal)
	p.spaceSignal = make(chan struct{})
}
// batchChannelClass buckets an event name into a coarse class used for metric
// labels, based on the chatd channel naming scheme.
func batchChannelClass(event string) string {
	if strings.HasPrefix(event, "chat:stream:") {
		return batchChannelClassStreamNotify
	}
	if strings.HasPrefix(event, "chat:owner:") {
		return batchChannelClassOwnerEvent
	}
	if event == "chat:config_change" {
		return batchChannelClassConfigChange
	}
	return batchChannelClassOther
}
// observeDelegateFallback counts a single publish that was rerouted through
// the shared delegate pool.
func (p *BatchingPubsub) observeDelegateFallback(channelClass string, reason string, stage string) {
	p.metrics.DelegateFallbacksTotal.WithLabelValues(channelClass, reason, stage).Inc()
}

// observeDelegateFallbackBatch counts a whole replayed batch, aggregated per
// channel class so metric updates scale with distinct classes, not items.
func (p *BatchingPubsub) observeDelegateFallbackBatch(batch []queuedPublish, reason string, stage string) {
	if len(batch) == 0 {
		return
	}
	counts := make(map[string]int)
	for _, item := range batch {
		counts[item.channelClass]++
	}
	for channelClass, count := range counts {
		p.metrics.DelegateFallbacksTotal.WithLabelValues(channelClass, reason, stage).Add(float64(count))
	}
}

// batchFlushStage extracts the stage label from a sender error, or "unknown"
// when the error is not a *batchFlushError.
func batchFlushStage(err error) string {
	var flushErr *batchFlushError
	if errors.As(err, &flushErr) {
		return flushErr.stage
	}
	return "unknown"
}
// run is the single consumer goroutine: it accumulates queued publishes and
// flushes them on the scheduled timer, on a pressure signal, or at shutdown.
// It is the only goroutine allowed to touch p.sender.
func (p *BatchingPubsub) run() {
	defer close(p.doneCh)
	defer p.warnTicker.Stop("pubsubBatcher", "warn")
	batch := make([]queuedPublish, 0, 64)
	timer := p.clock.NewTimer(p.flushInterval, "pubsubBatcher", "scheduledFlush")
	defer timer.Stop("pubsubBatcher", "scheduledFlush")
	// flush drains everything queued right now, sends it, and re-arms the
	// scheduled timer so the next flush is flushInterval from this one.
	flush := func(reason string) {
		batch = p.drainIntoBatch(batch)
		batch, _ = p.flushBatch(p.runCtx, batch, reason)
		timer.Reset(p.flushInterval, "pubsubBatcher", reason+"Flush")
	}
	for {
		select {
		case item := <-p.publishCh:
			// An item arrived before the timer fired. Append it and
			// let the timer or pressure signal trigger the actual
			// flush so that nearby publishes coalesce naturally.
			batch = append(batch, item)
			p.notifySpaceAvailable()
		case <-timer.C:
			flush(batchFlushScheduled)
		case <-p.flushCh:
			flush("pressure")
		case <-p.closeCh:
			// Final bounded drain, then close the dedicated sender.
			p.runErr = errors.Join(p.drain(batch), p.sender.Close())
			return
		}
	}
}
// drainIntoBatch non-blockingly moves every currently-queued publish into
// batch, and signals waiting publishers once if any space was freed.
func (p *BatchingPubsub) drainIntoBatch(batch []queuedPublish) []queuedPublish {
	before := len(batch)
	for done := false; !done; {
		select {
		case item := <-p.publishCh:
			batch = append(batch, item)
		default:
			done = true
		}
	}
	// Only wake blocked publishers when something was actually removed
	// from the channel.
	if len(batch) > before {
		p.notifySpaceAvailable()
	}
	return batch
}
// flushBatch sends the accumulated batch through the dedicated sender,
// recording size and duration metrics. On sender failure the whole batch is
// replayed through the shared delegate pool and (outside shutdown) the sender
// connection is reset. It always returns the cleared batch slice for reuse.
//
// Called only from the run() goroutine.
func (p *BatchingPubsub) flushBatch(
	ctx context.Context,
	batch []queuedPublish,
	reason string,
) ([]queuedPublish, error) {
	if len(batch) == 0 {
		return batch[:0], nil
	}
	count := len(batch)
	totalBytes := 0
	for _, item := range batch {
		totalBytes += len(item.message)
	}
	p.metrics.BatchSize.Observe(float64(count))
	start := p.clock.Now()
	senderErr := p.sender.Flush(ctx, batch)
	elapsed := p.clock.Since(start)
	p.metrics.FlushDuration.WithLabelValues(reason).Observe(elapsed.Seconds())
	var err error
	if senderErr != nil {
		stage := batchFlushStage(senderErr)
		delivered, failed, fallbackErr := p.replayBatchViaDelegate(batch, batchDelegateFallbackReasonFlushError, stage)
		var resetErr error
		// Skip the reset during shutdown: the sender is about to be
		// closed anyway.
		if reason != batchFlushShutdown {
			resetErr = p.resetSender()
		}
		p.logFlushFailure(reason, stage, count, totalBytes, delivered, failed, senderErr, fallbackErr, resetErr)
		if fallbackErr != nil || resetErr != nil {
			err = errors.Join(senderErr, fallbackErr, resetErr)
		}
	} else if p.delegate != nil {
		// Account successful batched publishes on the delegate's shared
		// counters so publish totals live in one place.
		p.delegate.publishesTotal.WithLabelValues("true").Add(float64(count))
		p.delegate.publishedBytesTotal.Add(float64(totalBytes))
	}
	queuedDepth := p.queuedCount.Add(-int64(count))
	p.observeQueueDepth(queuedDepth)
	// Drop message references so they can be collected, then reuse the
	// backing array for the next batch.
	clear(batch)
	return batch[:0], err
}
// replayBatchViaDelegate re-publishes each item of a failed batch through the
// shared delegate pool one at a time, recording fallback metrics first. It
// continues past individual failures and returns delivered/failed counts plus
// the joined per-item errors.
func (p *BatchingPubsub) replayBatchViaDelegate(batch []queuedPublish, reason string, stage string) (delivered int, failed int, err error) {
	if len(batch) == 0 {
		return 0, 0, nil
	}
	p.observeDelegateFallbackBatch(batch, reason, stage)
	if p.delegate == nil {
		return 0, len(batch), xerrors.New("delegate pubsub is nil")
	}
	var errs []error
	for _, item := range batch {
		if err := p.delegate.Publish(item.event, item.message); err != nil {
			failed++
			errs = append(errs, xerrors.Errorf("delegate publish %q: %w", item.event, err))
			continue
		}
		delivered++
	}
	return delivered, failed, errors.Join(errs...)
}
// resetSender replaces the dedicated sender connection after a flush failure.
// A nil newSender makes this a no-op. Called only from the run() goroutine,
// so swapping p.sender needs no locking. A failure to close the old sender is
// logged but not returned — the new sender is already installed.
func (p *BatchingPubsub) resetSender() error {
	if p.newSender == nil {
		return nil
	}
	newSender, err := p.newSender(context.Background())
	if err != nil {
		p.metrics.SenderResetFailuresTotal.Inc()
		return err
	}
	oldSender := p.sender
	p.sender = newSender
	p.metrics.SenderResetsTotal.Inc()
	if oldSender == nil {
		return nil
	}
	if err := oldSender.Close(); err != nil {
		p.logger.Warn(context.Background(), "failed to close old batched pubsub sender after reset", slog.Error(err))
	}
	return nil
}
// logFlushFailure emits a single structured error line summarizing a failed
// batch flush, including the delegate-replay outcome and, when attempted, the
// sender-reset result.
func (p *BatchingPubsub) logFlushFailure(reason string, stage string, count int, totalBytes int, delivered int, failed int, senderErr error, fallbackErr error, resetErr error) {
	fields := []slog.Field{
		slog.F("reason", reason),
		slog.F("stage", stage),
		slog.F("count", count),
		slog.F("total_bytes", totalBytes),
		slog.F("delegate_delivered", delivered),
		slog.F("delegate_failed", failed),
		slog.Error(senderErr),
	}
	if fallbackErr != nil {
		fields = append(fields, slog.F("delegate_error", fallbackErr.Error()))
	}
	if resetErr != nil {
		fields = append(fields, slog.F("sender_reset_error", resetErr.Error()))
	}
	p.logger.Error(context.Background(), "batched pubsub flush failed", fields...)
}
// drain performs the bounded shutdown flush: it repeatedly drains the queue
// and flushes until empty or finalFlushTimeout elapses, then drops whatever
// remains, reporting drops and the timeout as joined errors.
func (p *BatchingPubsub) drain(batch []queuedPublish) error {
	ctx, cancel := context.WithTimeout(context.Background(), p.finalFlushTimeout)
	defer cancel()
	var errs []error
	for {
		batch = p.drainIntoBatch(batch)
		if len(batch) == 0 {
			break
		}
		var err error
		batch, err = p.flushBatch(ctx, batch, batchFlushShutdown)
		if err != nil {
			errs = append(errs, err)
		}
		// Stop flushing once the deadline passes; leftovers are dropped
		// below.
		if ctx.Err() != nil {
			break
		}
	}
	dropped := p.dropPendingPublishes()
	if dropped > 0 {
		errs = append(errs, xerrors.Errorf("dropped %d queued notifications during shutdown", dropped))
	}
	if ctx.Err() != nil {
		errs = append(errs, xerrors.Errorf("shutdown flush timed out: %w", ctx.Err()))
	}
	return errors.Join(errs...)
}
// dropPendingPublishes discards everything still queued, adjusts the depth
// gauge accordingly, and returns the number of dropped notifications.
func (p *BatchingPubsub) dropPendingPublishes() int {
	dropped := 0
	for {
		select {
		case <-p.publishCh:
			dropped++
			continue
		default:
		}
		break
	}
	if dropped > 0 {
		p.observeQueueDepth(p.queuedCount.Add(-int64(dropped)))
	}
	return dropped
}
// logPublishRejection logs a queue-full fallback. A non-blocking read of the
// warn ticker rate-limits warn-level output to once per batchingWarnInterval;
// all other rejections are logged at debug level.
func (p *BatchingPubsub) logPublishRejection(event string) {
	fields := []slog.Field{
		slog.F("event", event),
		slog.F("queue_size", cap(p.publishCh)),
		slog.F("queued", p.queuedCount.Load()),
	}
	select {
	case <-p.warnTicker.C:
		p.logger.Warn(context.Background(), "batched pubsub queue is full", fields...)
	default:
		p.logger.Debug(context.Background(), "batched pubsub queue is full", fields...)
	}
}
// pgBatchSender flushes batches over a dedicated single-connection *sql.DB.
type pgBatchSender struct {
	logger slog.Logger
	db     *sql.DB
}

// newPGBatchSender dials a dedicated PostgreSQL connection — capped at one
// open/idle conn with no idle or lifetime limits so the same session is
// reused for every flush — and verifies it with a 30-second ping.
func newPGBatchSender(
	ctx context.Context,
	logger slog.Logger,
	prototype *sql.DB,
	connectURL string,
) (*pgBatchSender, error) {
	connector, err := newConnector(ctx, logger, prototype, connectURL)
	if err != nil {
		return nil, err
	}
	db := sql.OpenDB(connector)
	db.SetMaxOpenConns(1)
	db.SetMaxIdleConns(1)
	db.SetConnMaxIdleTime(0)
	db.SetConnMaxLifetime(0)
	pingCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
	defer cancel()
	if err := db.PingContext(pingCtx); err != nil {
		_ = db.Close()
		return nil, xerrors.Errorf("ping batched pubsub sender database: %w", err)
	}
	return &pgBatchSender{logger: logger, db: db}, nil
}
// Flush delivers the batch as pg_notify calls inside one transaction, so the
// whole batch is handed to listeners on commit. Failures are wrapped in
// batchFlushError tagged with the stage (begin/exec/commit) that failed; the
// deferred rollback undoes any partially-executed batch.
func (s *pgBatchSender) Flush(ctx context.Context, batch []queuedPublish) error {
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return &batchFlushError{stage: batchFlushStageBegin, err: xerrors.Errorf("begin batched pubsub transaction: %w", err)}
	}
	committed := false
	defer func() {
		if !committed {
			_ = tx.Rollback()
		}
	}()
	for _, item := range batch {
		// This is safe because we are calling pq.QuoteLiteral. pg_notify does
		// not support the first parameter being a prepared statement.
		//nolint:gosec
		_, err = tx.ExecContext(ctx, `select pg_notify(`+pq.QuoteLiteral(item.event)+`, $1)`, item.message)
		if err != nil {
			return &batchFlushError{stage: batchFlushStageExec, err: xerrors.Errorf("exec pg_notify: %w", err)}
		}
	}
	if err := tx.Commit(); err != nil {
		return &batchFlushError{stage: batchFlushStageCommit, err: xerrors.Errorf("commit batched pubsub transaction: %w", err)}
	}
	committed = true
	return nil
}

// Close tears down the dedicated sender connection.
func (s *pgBatchSender) Close() error {
	return s.db.Close()
}
@@ -1,520 +0,0 @@
package pubsub
import (
"bytes"
"context"
"database/sql"
"sync"
"testing"
"time"
_ "github.com/lib/pq"
prom_testutil "github.com/prometheus/client_golang/prometheus/testutil"
dto "github.com/prometheus/client_model/go"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
// TestBatchingPubsubScheduledFlush verifies that publishes buffered before
// the scheduled timer fires are delivered as one batch and that the batch
// size, flush duration, and queue depth metrics reflect exactly that flush.
func TestBatchingPubsubScheduledFlush(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitShort)
	clock := quartz.NewMock(t)
	newTimerTrap := clock.Trap().NewTimer("pubsubBatcher", "scheduledFlush")
	defer newTimerTrap.Close()
	resetTrap := clock.Trap().TimerReset("pubsubBatcher", "scheduledFlush")
	defer resetTrap.Close()
	sender := newFakeBatchSender()
	ps, _ := newTestBatchingPubsub(t, sender, BatchingConfig{
		Clock:         clock,
		FlushInterval: 10 * time.Millisecond,
		QueueSize:     8,
	})
	// Release the run loop's initial scheduled-flush timer creation.
	call, err := newTimerTrap.Wait(ctx)
	require.NoError(t, err)
	call.MustRelease(ctx)
	require.NoError(t, ps.Publish("chat:stream:a", []byte("one")))
	require.NoError(t, ps.Publish("chat:stream:a", []byte("two")))
	// Nothing flushes until the timer fires.
	require.Empty(t, sender.Batches())
	clock.Advance(10 * time.Millisecond).MustWait(ctx)
	resetCall, err := resetTrap.Wait(ctx)
	require.NoError(t, err)
	resetCall.MustRelease(ctx)
	batch := testutil.TryReceive(ctx, t, sender.flushes)
	require.Len(t, batch, 2)
	require.Equal(t, []byte("one"), batch[0].message)
	require.Equal(t, []byte("two"), batch[1].message)
	batchSizeCount, batchSizeSum := histogramCountAndSum(t, ps.metrics.BatchSize)
	require.Equal(t, uint64(1), batchSizeCount)
	require.InDelta(t, 2, batchSizeSum, 0.000001)
	flushDurationCount, _ := histogramCountAndSum(t, ps.metrics.FlushDuration.WithLabelValues(batchFlushScheduled))
	require.Equal(t, uint64(1), flushDurationCount)
	require.Zero(t, prom_testutil.ToFloat64(ps.metrics.QueueDepth))
}
// TestBatchingPubsubDefaultConfigUsesDedicatedSenderFirstDefaults checks that
// zero-valued BatchingConfig fields are replaced by the package defaults.
func TestBatchingPubsubDefaultConfigUsesDedicatedSenderFirstDefaults(t *testing.T) {
	t.Parallel()
	clock := quartz.NewMock(t)
	sender := newFakeBatchSender()
	ps, _ := newTestBatchingPubsub(t, sender, BatchingConfig{Clock: clock})
	require.Equal(t, DefaultBatchingFlushInterval, ps.flushInterval)
	require.Equal(t, DefaultBatchingQueueSize, cap(ps.publishCh))
	require.Equal(t, defaultBatchingPressureWait, ps.pressureWait)
	require.Equal(t, defaultBatchingFinalFlushLimit, ps.finalFlushTimeout)
}
// TestBatchChannelClass exercises the event-name to metric-class mapping for
// each recognized chatd prefix plus the fallback class.
func TestBatchChannelClass(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name  string
		event string
		want  string
	}{
		{name: "stream notify", event: "chat:stream:123", want: batchChannelClassStreamNotify},
		{name: "owner event", event: "chat:owner:123", want: batchChannelClassOwnerEvent},
		{name: "config change", event: "chat:config_change", want: batchChannelClassConfigChange},
		{name: "fallback", event: "workspace:owner:123", want: batchChannelClassOther},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			require.Equal(t, tt.want, batchChannelClass(tt.event))
		})
	}
}
// TestBatchingPubsubTimerFlushDrainsAll verifies that a single scheduled
// flush drains everything queued since the last flush into one batch, in
// publish order.
func TestBatchingPubsubTimerFlushDrainsAll(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitShort)
	clock := quartz.NewMock(t)
	newTimerTrap := clock.Trap().NewTimer("pubsubBatcher", "scheduledFlush")
	defer newTimerTrap.Close()
	resetTrap := clock.Trap().TimerReset("pubsubBatcher", "scheduledFlush")
	defer resetTrap.Close()
	sender := newFakeBatchSender()
	ps, _ := newTestBatchingPubsub(t, sender, BatchingConfig{
		Clock:         clock,
		FlushInterval: 10 * time.Millisecond,
		QueueSize:     64,
	})
	call, err := newTimerTrap.Wait(ctx)
	require.NoError(t, err)
	call.MustRelease(ctx)
	// Enqueue many messages before the timer fires — all should be
	// drained and flushed in a single batch.
	for _, msg := range []string{"one", "two", "three", "four", "five"} {
		require.NoError(t, ps.Publish("chat:stream:a", []byte(msg)))
	}
	require.Empty(t, sender.Batches())
	clock.Advance(10 * time.Millisecond).MustWait(ctx)
	resetCall, err := resetTrap.Wait(ctx)
	require.NoError(t, err)
	resetCall.MustRelease(ctx)
	batch := testutil.TryReceive(ctx, t, sender.flushes)
	require.Len(t, batch, 5)
	require.Equal(t, []byte("one"), batch[0].message)
	require.Equal(t, []byte("five"), batch[4].message)
	require.Zero(t, prom_testutil.ToFloat64(ps.metrics.QueueDepth))
}
// TestBatchingPubsubQueueFullFallsBackToDelegate verifies the backpressure
// path: with the sender blocked mid-flush and the queue full, a publish waits
// out pressureWait and then falls back to the shared delegate pool, counting
// a queue_full fallback.
func TestBatchingPubsubQueueFullFallsBackToDelegate(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitShort)
	clock := quartz.NewMock(t)
	newTimerTrap := clock.Trap().NewTimer("pubsubBatcher", "scheduledFlush")
	defer newTimerTrap.Close()
	resetTrap := clock.Trap().TimerReset("pubsubBatcher", "scheduledFlush")
	defer resetTrap.Close()
	pressureTrap := clock.Trap().NewTimer("pubsubBatcher", "pressureWait")
	defer pressureTrap.Close()
	sender := newFakeBatchSender()
	sender.blockCh = make(chan struct{})
	ps, _ := newTestBatchingPubsub(t, sender, BatchingConfig{
		Clock:         clock,
		FlushInterval: 10 * time.Millisecond,
		QueueSize:     1,
		PressureWait:  10 * time.Millisecond,
	})
	call, err := newTimerTrap.Wait(ctx)
	require.NoError(t, err)
	call.MustRelease(ctx)
	// Fill the queue (capacity 1).
	require.NoError(t, ps.Publish("chat:stream:a", []byte("one")))
	// Fire the timer so the run loop starts flushing "one" — the
	// sender blocks on blockCh so the flush stays in-flight.
	clock.Advance(10 * time.Millisecond).MustWait(ctx)
	<-sender.started
	// The run loop is blocked in flushBatch. Fill the queue again.
	require.NoError(t, ps.Publish("chat:stream:a", []byte("two")))
	// A third publish should fall back to the delegate (which has a
	// closed db, so the delegate Publish itself will error — but we
	// verify the fallback metric was incremented).
	errCh := make(chan error, 1)
	go func() {
		errCh <- ps.Publish("chat:stream:a", []byte("three"))
	}()
	pressureCall, err := pressureTrap.Wait(ctx)
	require.NoError(t, err)
	pressureCall.MustRelease(ctx)
	clock.Advance(10 * time.Millisecond).MustWait(ctx)
	err = testutil.TryReceive(ctx, t, errCh)
	// The delegate has a closed db so it returns an error from the
	// shared pool, not a batching-specific sentinel.
	require.Error(t, err)
	require.Equal(t, float64(1), prom_testutil.ToFloat64(ps.metrics.DelegateFallbacksTotal.WithLabelValues(batchChannelClassStreamNotify, batchDelegateFallbackReasonQueueFull, batchFlushStageNone)))
	close(sender.blockCh)
	// Let the run loop finish the blocked flush and process "two".
	resetCall, err := resetTrap.Wait(ctx)
	require.NoError(t, err)
	resetCall.MustRelease(ctx)
	require.NoError(t, ps.Close())
}
// TestBatchingPubsubCloseDrainsQueue verifies that Close flushes everything
// still queued (the scheduled timer never fires here) as one final batch and
// closes the sender exactly once.
func TestBatchingPubsubCloseDrainsQueue(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitShort)
	clock := quartz.NewMock(t)
	newTimerTrap := clock.Trap().NewTimer("pubsubBatcher", "scheduledFlush")
	defer newTimerTrap.Close()
	sender := newFakeBatchSender()
	ps, _ := newTestBatchingPubsub(t, sender, BatchingConfig{
		Clock:         clock,
		FlushInterval: time.Hour,
		QueueSize:     8,
	})
	call, err := newTimerTrap.Wait(ctx)
	require.NoError(t, err)
	call.MustRelease(ctx)
	require.NoError(t, ps.Publish("chat:stream:a", []byte("one")))
	require.NoError(t, ps.Publish("chat:stream:a", []byte("two")))
	require.NoError(t, ps.Publish("chat:stream:a", []byte("three")))
	require.NoError(t, ps.Close())
	batches := sender.Batches()
	require.Len(t, batches, 1)
	require.Len(t, batches[0], 3)
	require.Equal(t, []byte("one"), batches[0][0].message)
	require.Equal(t, []byte("two"), batches[0][1].message)
	require.Equal(t, []byte("three"), batches[0][2].message)
	require.Zero(t, prom_testutil.ToFloat64(ps.metrics.QueueDepth))
	require.Equal(t, 1, sender.CloseCalls())
}
// TestBatchingPubsubPreservesOrder verifies that publish order is preserved
// across however many batches the shutdown drain produces.
func TestBatchingPubsubPreservesOrder(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitShort)
	clock := quartz.NewMock(t)
	newTimerTrap := clock.Trap().NewTimer("pubsubBatcher", "scheduledFlush")
	defer newTimerTrap.Close()
	sender := newFakeBatchSender()
	ps, _ := newTestBatchingPubsub(t, sender, BatchingConfig{
		Clock:         clock,
		FlushInterval: time.Hour,
		QueueSize:     8,
	})
	call, err := newTimerTrap.Wait(ctx)
	require.NoError(t, err)
	call.MustRelease(ctx)
	for _, msg := range []string{"one", "two", "three", "four", "five"} {
		require.NoError(t, ps.Publish("chat:stream:a", []byte(msg)))
	}
	require.NoError(t, ps.Close())
	batches := sender.Batches()
	require.NotEmpty(t, batches)
	// Flatten all flushed batches and compare against publish order.
	messages := make([]string, 0, 5)
	for _, batch := range batches {
		for _, item := range batch {
			messages = append(messages, string(item.message))
		}
	}
	require.Equal(t, []string{"one", "two", "three", "four", "five"}, messages)
}
// TestBatchingPubsubFlushFailureMetrics verifies that a failing sender flush
// still records size/duration metrics, replays the batch through the delegate
// (counted as a flush_error fallback at the exec stage), and returns the
// queue depth gauge to zero.
func TestBatchingPubsubFlushFailureMetrics(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitShort)
	clock := quartz.NewMock(t)
	newTimerTrap := clock.Trap().NewTimer("pubsubBatcher", "scheduledFlush")
	defer newTimerTrap.Close()
	resetTrap := clock.Trap().TimerReset("pubsubBatcher", "scheduledFlush")
	defer resetTrap.Close()
	sender := newFakeBatchSender()
	sender.err = context.DeadlineExceeded
	sender.errStage = batchFlushStageExec
	ps, delegate := newTestBatchingPubsub(t, sender, BatchingConfig{
		Clock:         clock,
		FlushInterval: 10 * time.Millisecond,
		QueueSize:     8,
	})
	call, err := newTimerTrap.Wait(ctx)
	require.NoError(t, err)
	call.MustRelease(ctx)
	require.NoError(t, ps.Publish("chat:stream:a", []byte("one")))
	clock.Advance(10 * time.Millisecond).MustWait(ctx)
	resetCall, err := resetTrap.Wait(ctx)
	require.NoError(t, err)
	resetCall.MustRelease(ctx)
	batchSizeCount, batchSizeSum := histogramCountAndSum(t, ps.metrics.BatchSize)
	require.Equal(t, uint64(1), batchSizeCount)
	require.InDelta(t, 1, batchSizeSum, 0.000001)
	flushDurationCount, _ := histogramCountAndSum(t, ps.metrics.FlushDuration.WithLabelValues(batchFlushScheduled))
	require.Equal(t, uint64(1), flushDurationCount)
	require.Equal(t, float64(1), prom_testutil.ToFloat64(delegate.publishesTotal.WithLabelValues("false")))
	require.Zero(t, prom_testutil.ToFloat64(delegate.publishesTotal.WithLabelValues("true")))
	require.Zero(t, prom_testutil.ToFloat64(ps.metrics.QueueDepth))
	require.Equal(t, float64(1), prom_testutil.ToFloat64(ps.metrics.DelegateFallbacksTotal.WithLabelValues(batchChannelClassStreamNotify, batchDelegateFallbackReasonFlushError, batchFlushStageExec)))
}
// TestBatchingPubsubFlushFailureStageAccounting verifies that a flush
// failing at each stage (begin, exec, commit) increments the delegate
// fallback counter with the matching stage label, and that the fallback
// publish itself is counted as failed against the closed delegate DB.
func TestBatchingPubsubFlushFailureStageAccounting(t *testing.T) {
	t.Parallel()
	for _, flushStage := range []string{batchFlushStageBegin, batchFlushStageExec, batchFlushStageCommit} {
		flushStage := flushStage
		t.Run(flushStage, func(t *testing.T) {
			t.Parallel()
			// Fail every flush at the stage under test.
			failing := newFakeBatchSender()
			failing.err = context.DeadlineExceeded
			failing.errStage = flushStage
			ps, delegate := newTestBatchingPubsub(t, failing, BatchingConfig{Clock: quartz.NewMock(t)})
			const event = "chat:stream:test"
			queued := queuedPublish{
				event:        event,
				channelClass: batchChannelClass(event),
				message:      []byte("fallback-" + flushStage),
			}
			ps.queuedCount.Store(1)
			_, err := ps.flushBatch(context.Background(), []queuedPublish{queued}, batchFlushScheduled)
			require.Error(t, err)
			// The fallback counter must carry the stage at which the
			// flush failed.
			require.Equal(t, float64(1), prom_testutil.ToFloat64(ps.metrics.DelegateFallbacksTotal.WithLabelValues(batchChannelClassStreamNotify, batchDelegateFallbackReasonFlushError, flushStage)))
			require.Equal(t, float64(1), prom_testutil.ToFloat64(delegate.publishesTotal.WithLabelValues("false")))
		})
	}
}
// TestBatchingPubsubFlushFailureResetSender verifies that after a flush
// failure the batcher closes the failed sender, records a sender reset,
// and obtains a fresh sender via newSender; the next flush then succeeds
// through the replacement sender.
func TestBatchingPubsubFlushFailureResetSender(t *testing.T) {
	t.Parallel()
	clock := quartz.NewMock(t)
	// The first sender always fails at the exec stage; the second is
	// healthy and should take over after the reset.
	firstSender := newFakeBatchSender()
	firstSender.err = context.DeadlineExceeded
	firstSender.errStage = batchFlushStageExec
	secondSender := newFakeBatchSender()
	ps, _ := newTestBatchingPubsub(t, firstSender, BatchingConfig{Clock: clock})
	// Swap in the replacement sender factory before triggering the
	// failing flush.
	ps.newSender = func(context.Context) (batchSender, error) {
		return secondSender, nil
	}
	firstBatch := []queuedPublish{{
		event:        "chat:stream:first",
		channelClass: batchChannelClass("chat:stream:first"),
		message:      []byte("first"),
	}}
	ps.queuedCount.Store(int64(len(firstBatch)))
	_, err := ps.flushBatch(context.Background(), firstBatch, batchFlushScheduled)
	require.Error(t, err)
	// The failure must count one sender reset and close the old sender.
	require.Equal(t, float64(1), prom_testutil.ToFloat64(ps.metrics.SenderResetsTotal))
	require.Equal(t, 1, firstSender.CloseCalls())
	secondBatch := []queuedPublish{{
		event:        "chat:stream:second",
		channelClass: batchChannelClass("chat:stream:second"),
		message:      []byte("second"),
	}}
	ps.queuedCount.Store(int64(len(secondBatch)))
	_, err = ps.flushBatch(context.Background(), secondBatch, batchFlushScheduled)
	require.NoError(t, err)
	// The replacement sender received exactly the second batch.
	batches := secondSender.Batches()
	require.Len(t, batches, 1)
	require.Len(t, batches[0], 1)
	require.Equal(t, []byte("second"), batches[0][0].message)
}
// TestBatchingPubsubFlushFailureReturnsJoinedErrorWhenReplayFails checks
// that when both the batched flush and the per-message delegate replay
// fail, flushBatch surfaces an error containing both failures, while
// still counting the delegate fallback attempt.
func TestBatchingPubsubFlushFailureReturnsJoinedErrorWhenReplayFails(t *testing.T) {
	t.Parallel()
	failing := newFakeBatchSender()
	failing.err = context.DeadlineExceeded
	failing.errStage = batchFlushStageExec
	ps, _ := newTestBatchingPubsub(t, failing, BatchingConfig{Clock: quartz.NewMock(t)})
	const event = "chat:stream:error"
	pending := []queuedPublish{{
		event:        event,
		channelClass: batchChannelClass(event),
		message:      []byte("error"),
	}}
	ps.queuedCount.Store(int64(len(pending)))
	_, err := ps.flushBatch(context.Background(), pending, batchFlushScheduled)
	require.Error(t, err)
	// Both the original flush error and the delegate replay error must
	// be visible in the returned error chain.
	require.ErrorContains(t, err, context.DeadlineExceeded.Error())
	require.ErrorContains(t, err, `delegate publish "chat:stream:error"`)
	require.Equal(t, float64(1), prom_testutil.ToFloat64(ps.metrics.DelegateFallbacksTotal.WithLabelValues(batchChannelClassStreamNotify, batchDelegateFallbackReasonFlushError, batchFlushStageExec)))
}
// newTestBatchingPubsub wires a BatchingPubsub around the given sender
// with a delegate PGPubsub backed by a closed *sql.DB, so fallback
// publishes through the delegate fail with a real error. The batcher is
// closed via t.Cleanup; both the batcher and the delegate are returned
// so tests can assert on each side's metrics.
func newTestBatchingPubsub(t *testing.T, sender batchSender, cfg BatchingConfig) (*BatchingPubsub, *PGPubsub) {
	t.Helper()
	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
	// Use a closed *sql.DB so that delegate.Publish returns a real
	// error instead of panicking on a nil pointer when the batching
	// queue falls back to the shared pool under pressure.
	closedDB := newClosedDB(t)
	delegate := newWithoutListener(logger.Named("delegate"), closedDB)
	ps, err := newBatchingPubsub(logger.Named("batcher"), delegate, sender, cfg)
	require.NoError(t, err)
	t.Cleanup(func() {
		_ = ps.Close()
	})
	return ps, delegate
}
// newClosedDB returns an *sql.DB whose pool has already been closed, so
// any later ExecContext (or other) call fails with an error instead of
// panicking on a nil handle.
func newClosedDB(t *testing.T) *sql.DB {
	t.Helper()
	// The DSN is never actually dialed: sql.Open only validates it, and
	// the pool is closed immediately afterwards.
	const dsn = "host=localhost dbname=closed_db_stub sslmode=disable connect_timeout=1"
	db, err := sql.Open("postgres", dsn)
	require.NoError(t, err)
	require.NoError(t, db.Close())
	return db
}
// fakeBatchSender is a test double for batchSender. It records every
// flushed batch, signals observers through buffered channels, and can be
// configured to block or to fail at a specific flush stage.
type fakeBatchSender struct {
	mu sync.Mutex // guards batches and closeCall
	batches [][]queuedPublish // copies of every flushed batch, in flush order
	flushes chan []queuedPublish // best-effort copy of each flushed batch
	started chan struct{} // best-effort signal when Flush begins
	blockCh chan struct{} // when non-nil, Flush blocks on it (or ctx) before recording
	err error // returned by Flush after recording; nil means success
	errStage string // when set alongside err, Flush wraps err in a *batchFlushError with this stage
	closeErr error // returned by Close
	closeCall int // number of Close invocations; read via CloseCalls
}
// newFakeBatchSender builds a fakeBatchSender whose observation channels
// are buffered (capacity 16) so Flush never blocks on a slow observer.
func newFakeBatchSender() *fakeBatchSender {
	s := &fakeBatchSender{}
	s.flushes = make(chan []queuedPublish, 16)
	s.started = make(chan struct{}, 16)
	return s
}
// Flush records a copy of the batch, signals observers on the started
// and flushes channels (both best-effort, never blocking), optionally
// blocks on blockCh to simulate a slow sender, and finally returns the
// configured error — wrapped in a *batchFlushError when errStage is set.
//
// The batch is cloned because the caller may reuse or mutate the slice
// and its message buffers after Flush returns.
func (s *fakeBatchSender) Flush(ctx context.Context, batch []queuedPublish) error {
	// Signal that a flush has started; drop the signal if nobody is
	// listening so Flush cannot block here.
	select {
	case s.started <- struct{}{}:
	default:
	}
	// Optionally simulate a slow sender until released or canceled.
	if s.blockCh != nil {
		select {
		case <-s.blockCh:
		case <-ctx.Done():
			return ctx.Err()
		}
	}
	clone := make([]queuedPublish, len(batch))
	for i, item := range batch {
		clone[i] = queuedPublish{
			event: item.event,
			// Preserve channelClass in the recorded copy so tests can
			// assert on message classification; previously it was
			// silently dropped and Batches() returned zero values.
			channelClass: item.channelClass,
			message:      bytes.Clone(item.message),
		}
	}
	s.mu.Lock()
	s.batches = append(s.batches, clone)
	s.mu.Unlock()
	// Best-effort notification; drop when the buffer is full.
	select {
	case s.flushes <- clone:
	default:
	}
	if s.err == nil {
		return nil
	}
	if s.errStage != "" {
		return &batchFlushError{stage: s.errStage, err: s.err}
	}
	return s.err
}
// metricWriter matches any Prometheus metric that can serialize itself
// into a dto.Metric; used by histogramCountAndSum to read raw samples
// from an observer without depending on a concrete collector type.
type metricWriter interface {
	Write(*dto.Metric) error
}
// histogramCountAndSum extracts the sample count and sample sum from a
// Prometheus observer. The observer must implement metricWriter (as
// concrete Prometheus histograms do); the test fails otherwise, or when
// the written metric carries no histogram payload.
func histogramCountAndSum(t *testing.T, observer any) (uint64, float64) {
	t.Helper()
	writer, ok := observer.(metricWriter)
	require.True(t, ok)
	var m dto.Metric
	require.NoError(t, writer.Write(&m))
	hist := m.GetHistogram()
	require.NotNil(t, hist)
	return hist.GetSampleCount(), hist.GetSampleSum()
}
// Close counts the invocation (observable via CloseCalls) and returns
// the configured closeErr, which is nil unless a test sets it.
func (s *fakeBatchSender) Close() error {
	s.mu.Lock()
	s.closeCall++
	err := s.closeErr
	s.mu.Unlock()
	return err
}
// Batches returns a defensive copy of every batch recorded by Flush, in
// flush order, so callers can inspect results without racing Flush.
func (s *fakeBatchSender) Batches() [][]queuedPublish {
	s.mu.Lock()
	defer s.mu.Unlock()
	out := make([][]queuedPublish, 0, len(s.batches))
	for _, batch := range s.batches {
		out = append(out, append([]queuedPublish(nil), batch...))
	}
	return out
}
// CloseCalls reports how many times Close has been invoked.
func (s *fakeBatchSender) CloseCalls() int {
	s.mu.Lock()
	calls := s.closeCall
	s.mu.Unlock()
	return calls
}
-130
View File
@@ -1,130 +0,0 @@
package pubsub_test
import (
"context"
"database/sql"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/database/pubsub"
"github.com/coder/coder/v2/testutil"
)
func TestBatchingPubsubDedicatedSenderConnection(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
connectionURL, err := dbtestutil.Open(t)
require.NoError(t, err)
trackedDriver := dbtestutil.NewDriver()
defer trackedDriver.Close()
tconn, err := trackedDriver.Connector(connectionURL)
require.NoError(t, err)
trackedDB := sql.OpenDB(tconn)
defer trackedDB.Close()
base, err := pubsub.New(ctx, logger.Named("base"), trackedDB, connectionURL)
require.NoError(t, err)
defer base.Close()
listenerConn := testutil.TryReceive(ctx, t, trackedDriver.Connections)
batched, err := pubsub.NewBatching(ctx, logger.Named("batched"), base, trackedDB, connectionURL, pubsub.BatchingConfig{
FlushInterval: 10 * time.Millisecond,
QueueSize: 8,
})
require.NoError(t, err)
defer batched.Close()
senderConn := testutil.TryReceive(ctx, t, trackedDriver.Connections)
require.NotEqual(t, fmt.Sprintf("%p", listenerConn), fmt.Sprintf("%p", senderConn))
event := t.Name()
messageCh := make(chan []byte, 1)
cancel, err := batched.Subscribe(event, func(_ context.Context, message []byte) {
messageCh <- message
})
require.NoError(t, err)
defer cancel()
require.NoError(t, batched.Publish(event, []byte("hello")))
require.Equal(t, []byte("hello"), testutil.TryReceive(ctx, t, messageCh))
}
func TestBatchingPubsubReconnectsAfterSenderDisconnect(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
connectionURL, err := dbtestutil.Open(t)
require.NoError(t, err)
trackedDriver := dbtestutil.NewDriver()
defer trackedDriver.Close()
tconn, err := trackedDriver.Connector(connectionURL)
require.NoError(t, err)
trackedDB := sql.OpenDB(tconn)
defer trackedDB.Close()
base, err := pubsub.New(ctx, logger.Named("base"), trackedDB, connectionURL)
require.NoError(t, err)
defer base.Close()
_ = testutil.TryReceive(ctx, t, trackedDriver.Connections) // listener connection
batched, err := pubsub.NewBatching(ctx, logger.Named("batched"), base, trackedDB, connectionURL, pubsub.BatchingConfig{
FlushInterval: 10 * time.Millisecond,
QueueSize: 8,
})
require.NoError(t, err)
defer batched.Close()
senderConn := testutil.TryReceive(ctx, t, trackedDriver.Connections)
event := t.Name()
messageCh := make(chan []byte, 4)
cancel, err := batched.Subscribe(event, func(_ context.Context, message []byte) {
messageCh <- message
})
require.NoError(t, err)
defer cancel()
require.NoError(t, batched.Publish(event, []byte("before-disconnect")))
require.Equal(t, []byte("before-disconnect"), testutil.TryReceive(ctx, t, messageCh))
require.NoError(t, senderConn.Close())
reconnected := false
delivered := false
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
if !reconnected {
select {
case conn := <-trackedDriver.Connections:
reconnected = conn != nil
default:
}
}
select {
case <-messageCh:
default:
}
if err := batched.Publish(event, []byte("after-disconnect")); err != nil {
return false
}
select {
case msg := <-messageCh:
delivered = string(msg) == "after-disconnect"
case <-time.After(testutil.IntervalFast):
delivered = false
}
return reconnected && delivered
}, testutil.IntervalMedium, "batched sender did not recover after disconnect")
}
+9 -21
View File
@@ -487,14 +487,12 @@ func (d logDialer) DialContext(ctx context.Context, network, address string) (ne
return conn, nil
}
func newConnector(ctx context.Context, logger slog.Logger, db *sql.DB, connectURL string) (driver.Connector, error) {
if db == nil {
return nil, xerrors.New("database is nil")
}
func (p *PGPubsub) startListener(ctx context.Context, connectURL string) error {
p.connected.Set(0)
// Creates a new listener using pq.
var (
dialer = logDialer{
logger: logger,
logger: p.logger,
// pq.defaultDialer uses a zero net.Dialer as well.
d: net.Dialer{},
}
@@ -503,38 +501,28 @@ func newConnector(ctx context.Context, logger slog.Logger, db *sql.DB, connectUR
)
// Create a custom connector if the database driver supports it.
connectorCreator, ok := db.Driver().(database.ConnectorCreator)
connectorCreator, ok := p.db.Driver().(database.ConnectorCreator)
if ok {
connector, err = connectorCreator.Connector(connectURL)
if err != nil {
return nil, xerrors.Errorf("create custom connector: %w", err)
return xerrors.Errorf("create custom connector: %w", err)
}
} else {
// Use the default pq connector otherwise.
// use the default pq connector otherwise
connector, err = pq.NewConnector(connectURL)
if err != nil {
return nil, xerrors.Errorf("create pq connector: %w", err)
return xerrors.Errorf("create pq connector: %w", err)
}
}
// Set the dialer if the connector supports it.
dc, ok := connector.(database.DialerConnector)
if !ok {
logger.Critical(ctx, "connector does not support setting log dialer, database connection debug logs will be missing")
p.logger.Critical(ctx, "connector does not support setting log dialer, database connection debug logs will be missing")
} else {
dc.Dialer(dialer)
}
return connector, nil
}
func (p *PGPubsub) startListener(ctx context.Context, connectURL string) error {
p.connected.Set(0)
connector, err := newConnector(ctx, p.logger, p.db, connectURL)
if err != nil {
return err
}
var (
errCh = make(chan error, 1)
sentErrCh = false
+6 -49
View File
@@ -128,22 +128,6 @@ type sqlcQuerier interface {
// connection events (connect, disconnect, open, close) which are handled
// separately by DeleteOldAuditLogConnectionEvents.
DeleteOldAuditLogs(ctx context.Context, arg DeleteOldAuditLogsParams) (int64, error)
// TODO(cian): Add indexes on chats(archived, updated_at) and
// chat_files(created_at) for purge query performance.
// See: https://github.com/coder/internal/issues/1438
// Deletes chat files that are older than the given threshold and are
// not referenced by any chat that is still active or was archived
// within the same threshold window. This covers two cases:
// 1. Orphaned files not linked to any chat.
// 2. Files whose every referencing chat has been archived for longer
// than the retention period.
DeleteOldChatFiles(ctx context.Context, arg DeleteOldChatFilesParams) (int64, error)
// Deletes chats that have been archived for longer than the given
// threshold. Active (non-archived) chats are never deleted.
// Related chat_messages, chat_diff_statuses, and
// chat_queued_messages are removed via ON DELETE CASCADE.
// Parent/root references on child chats are SET NULL.
DeleteOldChats(ctx context.Context, arg DeleteOldChatsParams) (int64, error)
DeleteOldConnectionLogs(ctx context.Context, arg DeleteOldConnectionLogsParams) (int64, error)
// Delete all notification messages which have not been updated for over a week.
DeleteOldNotificationMessages(ctx context.Context) error
@@ -168,7 +152,7 @@ type sqlcQuerier interface {
DeleteTask(ctx context.Context, arg DeleteTaskParams) (uuid.UUID, error)
DeleteUserChatCompactionThreshold(ctx context.Context, arg DeleteUserChatCompactionThresholdParams) error
DeleteUserChatProviderKey(ctx context.Context, arg DeleteUserChatProviderKeyParams) error
DeleteUserSecretByUserIDAndName(ctx context.Context, arg DeleteUserSecretByUserIDAndNameParams) error
DeleteUserSecret(ctx context.Context, id uuid.UUID) error
DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg DeleteWebpushSubscriptionByUserIDAndEndpointParams) error
DeleteWebpushSubscriptions(ctx context.Context, ids []uuid.UUID) error
DeleteWorkspaceACLByID(ctx context.Context, id uuid.UUID) error
@@ -271,27 +255,16 @@ type sqlcQuerier interface {
// otherwise the setting defaults to true.
GetChatIncludeDefaultSystemPrompt(ctx context.Context) (bool, error)
GetChatMessageByID(ctx context.Context, id int64) (ChatMessage, error)
// Aggregates message-level metrics per chat for messages created
// after the given timestamp. Uses message created_at so that
// ongoing activity in long-running chats is captured each window.
GetChatMessageSummariesPerChat(ctx context.Context, createdAfter time.Time) ([]GetChatMessageSummariesPerChatRow, error)
GetChatMessagesByChatID(ctx context.Context, arg GetChatMessagesByChatIDParams) ([]ChatMessage, error)
GetChatMessagesByChatIDAscPaginated(ctx context.Context, arg GetChatMessagesByChatIDAscPaginatedParams) ([]ChatMessage, error)
GetChatMessagesByChatIDDescPaginated(ctx context.Context, arg GetChatMessagesByChatIDDescPaginatedParams) ([]ChatMessage, error)
GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]ChatMessage, error)
GetChatModelConfigByID(ctx context.Context, id uuid.UUID) (ChatModelConfig, error)
GetChatModelConfigs(ctx context.Context) ([]ChatModelConfig, error)
// Returns all model configurations for telemetry snapshot collection.
GetChatModelConfigsForTelemetry(ctx context.Context) ([]GetChatModelConfigsForTelemetryRow, error)
GetChatProviderByID(ctx context.Context, id uuid.UUID) (ChatProvider, error)
GetChatProviderByProvider(ctx context.Context, provider string) (ChatProvider, error)
GetChatProviders(ctx context.Context) ([]ChatProvider, error)
GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]ChatQueuedMessage, error)
// Returns the chat retention period in days. Chats archived longer
// than this and orphaned chat files older than this are purged by
// dbpurge. Returns 30 (days) when no value has been configured.
// A value of 0 disables chat purging entirely.
GetChatRetentionDays(ctx context.Context) (int32, error)
GetChatSystemPrompt(ctx context.Context) (string, error)
// GetChatSystemPromptConfig returns both chat system prompt settings in a
// single read to avoid torn reads between separate site-config lookups.
@@ -310,10 +283,6 @@ type sqlcQuerier interface {
GetChatWorkspaceTTL(ctx context.Context) (string, error)
GetChats(ctx context.Context, arg GetChatsParams) ([]GetChatsRow, error)
GetChatsByWorkspaceIDs(ctx context.Context, ids []uuid.UUID) ([]Chat, error)
// Retrieves chats updated after the given timestamp for telemetry
// snapshot collection. Uses updated_at so that long-running chats
// still appear in each snapshot window while they are active.
GetChatsUpdatedAfter(ctx context.Context, updatedAfter time.Time) ([]GetChatsUpdatedAfterRow, error)
GetConnectionLogsOffset(ctx context.Context, arg GetConnectionLogsOffsetParams) ([]GetConnectionLogsOffsetRow, error)
GetCryptoKeyByFeatureAndSequence(ctx context.Context, arg GetCryptoKeyByFeatureAndSequenceParams) (CryptoKey, error)
GetCryptoKeys(ctx context.Context) ([]CryptoKey, error)
@@ -509,10 +478,8 @@ type sqlcQuerier interface {
GetReplicasUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]Replica, error)
GetRunningPrebuiltWorkspaces(ctx context.Context) ([]GetRunningPrebuiltWorkspacesRow, error)
GetRuntimeConfig(ctx context.Context, key string) (string, error)
// Find chats that appear stuck and need recovery. This covers:
// 1. Running chats whose heartbeat has expired (worker crash).
// 2. Chats awaiting client action (requires_action) past the
// timeout threshold (client disappeared).
// Find chats that appear stuck (running but heartbeat has expired).
// Used for recovery after coderd crashes or long hangs.
GetStaleChats(ctx context.Context, staleThreshold time.Time) ([]Chat, error)
GetTailnetPeers(ctx context.Context, id uuid.UUID) ([]TailnetPeer, error)
GetTailnetTunnelPeerBindingsBatch(ctx context.Context, ids []uuid.UUID) ([]GetTailnetTunnelPeerBindingsBatchRow, error)
@@ -631,6 +598,7 @@ type sqlcQuerier interface {
GetUserLinkByUserIDLoginType(ctx context.Context, arg GetUserLinkByUserIDLoginTypeParams) (UserLink, error)
GetUserLinksByUserID(ctx context.Context, userID uuid.UUID) ([]UserLink, error)
GetUserNotificationPreferences(ctx context.Context, userID uuid.UUID) ([]NotificationPreference, error)
GetUserSecret(ctx context.Context, id uuid.UUID) (UserSecret, error)
GetUserSecretByUserIDAndName(ctx context.Context, arg GetUserSecretByUserIDAndNameParams) (UserSecret, error)
// GetUserStatusCounts returns the count of users in each status over time.
// The time range is inclusively defined by the start_time and end_time parameters.
@@ -850,13 +818,7 @@ type sqlcQuerier interface {
ListProvisionerKeysByOrganizationExcludeReserved(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerKey, error)
ListTasks(ctx context.Context, arg ListTasksParams) ([]Task, error)
ListUserChatCompactionThresholds(ctx context.Context, userID uuid.UUID) ([]UserConfig, error)
// Returns metadata only (no value or value_key_id) for the
// REST API list and get endpoints.
ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]ListUserSecretsRow, error)
// Returns all columns including the secret value. Used by the
// provisioner (build-time injection) and the agent manifest
// (runtime injection).
ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]UserSecret, error)
ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]UserSecret, error)
ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]WorkspaceAgentPortShare, error)
MarkAllInboxNotificationsAsRead(ctx context.Context, arg MarkAllInboxNotificationsAsReadParams) error
OIDCClaimFieldValues(ctx context.Context, arg OIDCClaimFieldValuesParams) ([]string, error)
@@ -898,10 +860,6 @@ type sqlcQuerier interface {
// This must be called from within a transaction. The lock will be automatically
// released when the transaction ends.
TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error)
// Unarchives a chat (and its children). Stale file references are
// handled automatically by FK cascades on chat_file_links: when
// dbpurge deletes a chat_files row, the corresponding
// chat_file_links rows are cascade-deleted by PostgreSQL.
UnarchiveChatByID(ctx context.Context, id uuid.UUID) ([]Chat, error)
// This will always work regardless of the current state of the template version.
UnarchiveTemplateVersion(ctx context.Context, arg UnarchiveTemplateVersionParams) error
@@ -999,7 +957,7 @@ type sqlcQuerier interface {
UpdateUserProfile(ctx context.Context, arg UpdateUserProfileParams) (User, error)
UpdateUserQuietHoursSchedule(ctx context.Context, arg UpdateUserQuietHoursScheduleParams) (User, error)
UpdateUserRoles(ctx context.Context, arg UpdateUserRolesParams) (User, error)
UpdateUserSecretByUserIDAndName(ctx context.Context, arg UpdateUserSecretByUserIDAndNameParams) (UserSecret, error)
UpdateUserSecret(ctx context.Context, arg UpdateUserSecretParams) (UserSecret, error)
UpdateUserStatus(ctx context.Context, arg UpdateUserStatusParams) (User, error)
UpdateUserTaskNotificationAlertDismissed(ctx context.Context, arg UpdateUserTaskNotificationAlertDismissedParams) (bool, error)
UpdateUserTerminalFont(ctx context.Context, arg UpdateUserTerminalFontParams) (UserConfig, error)
@@ -1043,7 +1001,6 @@ type sqlcQuerier interface {
UpsertChatDiffStatus(ctx context.Context, arg UpsertChatDiffStatusParams) (ChatDiffStatus, error)
UpsertChatDiffStatusReference(ctx context.Context, arg UpsertChatDiffStatusReferenceParams) (ChatDiffStatus, error)
UpsertChatIncludeDefaultSystemPrompt(ctx context.Context, includeDefaultSystemPrompt bool) error
UpsertChatRetentionDays(ctx context.Context, retentionDays int32) error
UpsertChatSystemPrompt(ctx context.Context, value string) error
UpsertChatTemplateAllowlist(ctx context.Context, templateAllowlist string) error
UpsertChatUsageLimitConfig(ctx context.Context, arg UpsertChatUsageLimitConfigParams) (ChatUsageLimitConfig, error)
+34 -50
View File
@@ -7339,7 +7339,13 @@ func TestUserSecretsCRUDOperations(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, secretID, createdSecret.ID)
// 2. READ by UserID and Name
// 2. READ by ID
readSecret, err := db.GetUserSecret(ctx, createdSecret.ID)
require.NoError(t, err)
assert.Equal(t, createdSecret.ID, readSecret.ID)
assert.Equal(t, "workflow-secret", readSecret.Name)
// 3. READ by UserID and Name
readByNameParams := database.GetUserSecretByUserIDAndNameParams{
UserID: testUser.ID,
Name: "workflow-secret",
@@ -7347,43 +7353,33 @@ func TestUserSecretsCRUDOperations(t *testing.T) {
readByNameSecret, err := db.GetUserSecretByUserIDAndName(ctx, readByNameParams)
require.NoError(t, err)
assert.Equal(t, createdSecret.ID, readByNameSecret.ID)
assert.Equal(t, "workflow-secret", readByNameSecret.Name)
// 3. LIST (metadata only)
// 4. LIST
secrets, err := db.ListUserSecrets(ctx, testUser.ID)
require.NoError(t, err)
require.Len(t, secrets, 1)
assert.Equal(t, createdSecret.ID, secrets[0].ID)
// 4. LIST with values
secretsWithValues, err := db.ListUserSecretsWithValues(ctx, testUser.ID)
require.NoError(t, err)
require.Len(t, secretsWithValues, 1)
assert.Equal(t, "workflow-value", secretsWithValues[0].Value)
// 5. UPDATE (partial - only description)
updateParams := database.UpdateUserSecretByUserIDAndNameParams{
UserID: testUser.ID,
Name: "workflow-secret",
UpdateDescription: true,
Description: "Updated workflow description",
// 5. UPDATE
updateParams := database.UpdateUserSecretParams{
ID: createdSecret.ID,
Description: "Updated workflow description",
Value: "updated-workflow-value",
EnvName: "UPDATED_WORKFLOW_ENV",
FilePath: "/updated/workflow/path",
}
updatedSecret, err := db.UpdateUserSecretByUserIDAndName(ctx, updateParams)
updatedSecret, err := db.UpdateUserSecret(ctx, updateParams)
require.NoError(t, err)
assert.Equal(t, "Updated workflow description", updatedSecret.Description)
assert.Equal(t, "workflow-value", updatedSecret.Value) // Value unchanged
assert.Equal(t, "WORKFLOW_ENV", updatedSecret.EnvName) // EnvName unchanged
assert.Equal(t, "updated-workflow-value", updatedSecret.Value)
// 6. DELETE
err = db.DeleteUserSecretByUserIDAndName(ctx, database.DeleteUserSecretByUserIDAndNameParams{
UserID: testUser.ID,
Name: "workflow-secret",
})
err = db.DeleteUserSecret(ctx, createdSecret.ID)
require.NoError(t, err)
// Verify deletion
_, err = db.GetUserSecretByUserIDAndName(ctx, readByNameParams)
_, err = db.GetUserSecret(ctx, createdSecret.ID)
require.Error(t, err)
assert.Contains(t, err.Error(), "no rows in result set")
@@ -7453,13 +7449,9 @@ func TestUserSecretsCRUDOperations(t *testing.T) {
})
// Verify both secrets exist
_, err = db.GetUserSecretByUserIDAndName(ctx, database.GetUserSecretByUserIDAndNameParams{
UserID: testUser.ID, Name: secret1.Name,
})
_, err = db.GetUserSecret(ctx, secret1.ID)
require.NoError(t, err)
_, err = db.GetUserSecretByUserIDAndName(ctx, database.GetUserSecretByUserIDAndNameParams{
UserID: testUser.ID, Name: secret2.Name,
})
_, err = db.GetUserSecret(ctx, secret2.ID)
require.NoError(t, err)
})
}
@@ -7482,14 +7474,14 @@ func TestUserSecretsAuthorization(t *testing.T) {
org := dbgen.Organization(t, db, database.Organization{})
// Create secrets for users
_ = dbgen.UserSecret(t, db, database.UserSecret{
user1Secret := dbgen.UserSecret(t, db, database.UserSecret{
UserID: user1.ID,
Name: "user1-secret",
Description: "User 1's secret",
Value: "user1-value",
})
_ = dbgen.UserSecret(t, db, database.UserSecret{
user2Secret := dbgen.UserSecret(t, db, database.UserSecret{
UserID: user2.ID,
Name: "user2-secret",
Description: "User 2's secret",
@@ -7499,8 +7491,7 @@ func TestUserSecretsAuthorization(t *testing.T) {
testCases := []struct {
name string
subject rbac.Subject
lookupUserID uuid.UUID
lookupName string
secretID uuid.UUID
expectedAccess bool
}{
{
@@ -7510,8 +7501,7 @@ func TestUserSecretsAuthorization(t *testing.T) {
Roles: rbac.RoleIdentifiers{rbac.RoleMember()},
Scope: rbac.ScopeAll,
},
lookupUserID: user1.ID,
lookupName: "user1-secret",
secretID: user1Secret.ID,
expectedAccess: true,
},
{
@@ -7521,8 +7511,7 @@ func TestUserSecretsAuthorization(t *testing.T) {
Roles: rbac.RoleIdentifiers{rbac.RoleMember()},
Scope: rbac.ScopeAll,
},
lookupUserID: user2.ID,
lookupName: "user2-secret",
secretID: user2Secret.ID,
expectedAccess: false,
},
{
@@ -7532,8 +7521,7 @@ func TestUserSecretsAuthorization(t *testing.T) {
Roles: rbac.RoleIdentifiers{rbac.RoleOwner()},
Scope: rbac.ScopeAll,
},
lookupUserID: user1.ID,
lookupName: "user1-secret",
secretID: user1Secret.ID,
expectedAccess: false,
},
{
@@ -7543,8 +7531,7 @@ func TestUserSecretsAuthorization(t *testing.T) {
Roles: rbac.RoleIdentifiers{rbac.ScopedRoleOrgAdmin(org.ID)},
Scope: rbac.ScopeAll,
},
lookupUserID: user1.ID,
lookupName: "user1-secret",
secretID: user1Secret.ID,
expectedAccess: false,
},
}
@@ -7556,10 +7543,8 @@ func TestUserSecretsAuthorization(t *testing.T) {
authCtx := dbauthz.As(ctx, tc.subject)
_, err := authDB.GetUserSecretByUserIDAndName(authCtx, database.GetUserSecretByUserIDAndNameParams{
UserID: tc.lookupUserID,
Name: tc.lookupName,
})
// Test GetUserSecret
_, err := authDB.GetUserSecret(authCtx, tc.secretID)
if tc.expectedAccess {
require.NoError(t, err, "expected access to be granted")
@@ -9085,11 +9070,10 @@ func TestUpdateAIBridgeInterceptionEnded(t *testing.T) {
for _, uid := range []uuid.UUID{{1}, {2}, {3}} {
insertParams := database.InsertAIBridgeInterceptionParams{
ID: uid,
InitiatorID: user.ID,
Metadata: json.RawMessage("{}"),
Client: sql.NullString{String: "client", Valid: true},
CredentialKind: database.CredentialKindCentralized,
ID: uid,
InitiatorID: user.ID,
Metadata: json.RawMessage("{}"),
Client: sql.NullString{String: "client", Valid: true},
}
intc, err := db.InsertAIBridgeInterception(ctx, insertParams)
File diff suppressed because it is too large Load Diff
+2 -2
View File
@@ -1,8 +1,8 @@
-- name: InsertAIBridgeInterception :one
INSERT INTO aibridge_interceptions (
id, api_key_id, initiator_id, provider, provider_name, model, metadata, started_at, client, client_session_id, thread_parent_id, thread_root_id, credential_kind, credential_hint
id, api_key_id, initiator_id, provider, provider_name, model, metadata, started_at, client, client_session_id, thread_parent_id, thread_root_id
) VALUES (
@id, @api_key_id, @initiator_id, @provider, @provider_name, @model, COALESCE(@metadata::jsonb, '{}'::jsonb), @started_at, @client, sqlc.narg('client_session_id'), sqlc.narg('thread_parent_interception_id')::uuid, sqlc.narg('thread_root_interception_id')::uuid, @credential_kind, @credential_hint
@id, @api_key_id, @initiator_id, @provider, @provider_name, @model, COALESCE(@metadata::jsonb, '{}'::jsonb), @started_at, @client, sqlc.narg('client_session_id'), sqlc.narg('thread_parent_interception_id')::uuid, sqlc.narg('thread_root_interception_id')::uuid
)
RETURNING *;
-34
View File
@@ -18,37 +18,3 @@ FROM chat_files cf
JOIN chat_file_links cfl ON cfl.file_id = cf.id
WHERE cfl.chat_id = @chat_id::uuid
ORDER BY cf.created_at ASC;
-- TODO(cian): Add indexes on chats(archived, updated_at) and
-- chat_files(created_at) for purge query performance.
-- See: https://github.com/coder/internal/issues/1438
-- name: DeleteOldChatFiles :execrows
-- Deletes chat files that are older than the given threshold and are
-- not referenced by any chat that is still active or was archived
-- within the same threshold window. This covers two cases:
-- 1. Orphaned files not linked to any chat.
-- 2. Files whose every referencing chat has been archived for longer
-- than the retention period.
WITH kept_file_ids AS (
-- NOTE: This uses updated_at as a proxy for archive time
-- because there is no archived_at column. Correctness
-- requires that updated_at is never backdated on archived
-- chats. See ArchiveChatByID.
SELECT DISTINCT cfl.file_id
FROM chat_file_links cfl
JOIN chats c ON c.id = cfl.chat_id
WHERE c.archived = false
OR c.updated_at >= @before_time::timestamptz
),
deletable AS (
SELECT cf.id
FROM chat_files cf
LEFT JOIN kept_file_ids k ON cf.id = k.file_id
WHERE cf.created_at < @before_time::timestamptz
AND k.file_id IS NULL
ORDER BY cf.created_at ASC
LIMIT @limit_count
)
DELETE FROM chat_files
USING deletable
WHERE chat_files.id = deletable.id;
+8 -81
View File
@@ -10,14 +10,9 @@ FROM chats
ORDER BY (id = @id::uuid) DESC, created_at ASC, id ASC;
-- name: UnarchiveChatByID :many
-- Unarchives a chat (and its children). Stale file references are
-- handled automatically by FK cascades on chat_file_links: when
-- dbpurge deletes a chat_files row, the corresponding
-- chat_file_links rows are cascade-deleted by PostgreSQL.
WITH chats AS (
UPDATE chats SET
archived = false,
updated_at = NOW()
UPDATE chats
SET archived = false, updated_at = NOW()
WHERE id = @id::uuid OR root_chat_id = @id::uuid
RETURNING *
)
@@ -399,8 +394,7 @@ INSERT INTO chats (
mode,
status,
mcp_server_ids,
labels,
dynamic_tools
labels
) VALUES (
@owner_id::uuid,
sqlc.narg('workspace_id')::uuid,
@@ -413,8 +407,7 @@ INSERT INTO chats (
sqlc.narg('mode')::chat_mode,
@status::chat_status,
COALESCE(@mcp_server_ids::uuid[], '{}'::uuid[]),
COALESCE(sqlc.narg('labels')::jsonb, '{}'::jsonb),
sqlc.narg('dynamic_tools')::jsonb
COALESCE(sqlc.narg('labels')::jsonb, '{}'::jsonb)
)
RETURNING
*;
@@ -671,19 +664,15 @@ RETURNING
*;
-- name: GetStaleChats :many
-- Find chats that appear stuck and need recovery. This covers:
-- 1. Running chats whose heartbeat has expired (worker crash).
-- 2. Chats awaiting client action (requires_action) past the
-- timeout threshold (client disappeared).
-- Find chats that appear stuck (running but heartbeat has expired).
-- Used for recovery after coderd crashes or long hangs.
SELECT
*
FROM
chats
WHERE
(status = 'running'::chat_status
AND heartbeat_at < @stale_threshold::timestamptz)
OR (status = 'requires_action'::chat_status
AND updated_at < @stale_threshold::timestamptz);
status = 'running'::chat_status
AND heartbeat_at < @stale_threshold::timestamptz;
-- name: UpdateChatHeartbeats :many
-- Bumps the heartbeat timestamp for the given set of chat IDs,
@@ -1231,65 +1220,3 @@ LIMIT 1;
UPDATE chats
SET last_read_message_id = @last_read_message_id::bigint
WHERE id = @id::uuid;
-- name: DeleteOldChats :execrows
-- Deletes chats that have been archived for longer than the given
-- threshold. Active (non-archived) chats are never deleted.
-- Related chat_messages, chat_diff_statuses, and
-- chat_queued_messages are removed via ON DELETE CASCADE.
-- Parent/root references on child chats are SET NULL.
WITH deletable AS (
SELECT id
FROM chats
WHERE archived = true
AND updated_at < @before_time::timestamptz
ORDER BY updated_at ASC
LIMIT @limit_count
)
DELETE FROM chats
USING deletable
WHERE chats.id = deletable.id
AND chats.archived = true;
-- name: GetChatsUpdatedAfter :many
-- Retrieves chats updated after the given timestamp for telemetry
-- snapshot collection. Uses updated_at so that long-running chats
-- still appear in each snapshot window while they are active.
SELECT
id, owner_id, created_at, updated_at, status,
(parent_chat_id IS NOT NULL)::bool AS has_parent,
root_chat_id, workspace_id,
mode, archived, last_model_config_id
FROM chats
WHERE updated_at > @updated_after;
-- name: GetChatMessageSummariesPerChat :many
-- Aggregates message-level metrics per chat for messages created
-- after the given timestamp. Uses message created_at so that
-- ongoing activity in long-running chats is captured each window.
SELECT
cm.chat_id,
COUNT(*)::bigint AS message_count,
COUNT(*) FILTER (WHERE cm.role = 'user')::bigint AS user_message_count,
COUNT(*) FILTER (WHERE cm.role = 'assistant')::bigint AS assistant_message_count,
COUNT(*) FILTER (WHERE cm.role = 'tool')::bigint AS tool_message_count,
COUNT(*) FILTER (WHERE cm.role = 'system')::bigint AS system_message_count,
COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens,
COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens,
COALESCE(SUM(cm.reasoning_tokens), 0)::bigint AS total_reasoning_tokens,
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens,
COALESCE(SUM(cm.cache_read_tokens), 0)::bigint AS total_cache_read_tokens,
COALESCE(SUM(cm.total_cost_micros), 0)::bigint AS total_cost_micros,
COALESCE(SUM(cm.runtime_ms), 0)::bigint AS total_runtime_ms,
COUNT(DISTINCT cm.model_config_id)::bigint AS distinct_model_count,
COUNT(*) FILTER (WHERE cm.compressed)::bigint AS compressed_message_count
FROM chat_messages cm
WHERE cm.created_at > @created_after
AND cm.deleted = false
GROUP BY cm.chat_id;
-- name: GetChatModelConfigsForTelemetry :many
-- Returns all model configurations for telemetry snapshot collection.
SELECT id, provider, model, context_limit, enabled, is_default
FROM chat_model_configs
WHERE deleted = false;
-17
View File
@@ -236,20 +236,3 @@ VALUES ('agents_workspace_ttl', @workspace_ttl::text)
ON CONFLICT (key) DO UPDATE
SET value = @workspace_ttl::text
WHERE site_configs.key = 'agents_workspace_ttl';
-- name: GetChatRetentionDays :one
-- Returns the chat retention period in days. Chats archived longer
-- than this and orphaned chat files older than this are purged by
-- dbpurge. Returns 30 (days) when no value has been configured.
-- A value of 0 disables chat purging entirely.
SELECT COALESCE(
(SELECT value::integer FROM site_configs
WHERE key = 'agents_chat_retention_days'),
30
) :: integer AS retention_days;
-- name: UpsertChatRetentionDays :exec
INSERT INTO site_configs (key, value)
VALUES ('agents_chat_retention_days', CAST(@retention_days AS integer)::text)
ON CONFLICT (key) DO UPDATE SET value = CAST(@retention_days AS integer)::text
WHERE site_configs.key = 'agents_chat_retention_days';
+18 -39
View File
@@ -1,26 +1,14 @@
-- name: GetUserSecretByUserIDAndName :one
SELECT *
FROM user_secrets
WHERE user_id = @user_id AND name = @name;
SELECT * FROM user_secrets
WHERE user_id = $1 AND name = $2;
-- name: GetUserSecret :one
SELECT * FROM user_secrets
WHERE id = $1;
-- name: ListUserSecrets :many
-- Returns metadata only (no value or value_key_id) for the
-- REST API list and get endpoints.
SELECT
id, user_id, name, description,
env_name, file_path,
created_at, updated_at
FROM user_secrets
WHERE user_id = @user_id
ORDER BY name ASC;
-- name: ListUserSecretsWithValues :many
-- Returns all columns including the secret value. Used by the
-- provisioner (build-time injection) and the agent manifest
-- (runtime injection).
SELECT *
FROM user_secrets
WHERE user_id = @user_id
SELECT * FROM user_secrets
WHERE user_id = $1
ORDER BY name ASC;
-- name: CreateUserSecret :one
@@ -30,32 +18,23 @@ INSERT INTO user_secrets (
name,
description,
value,
value_key_id,
env_name,
file_path
) VALUES (
@id,
@user_id,
@name,
@description,
@value,
@value_key_id,
@env_name,
@file_path
$1, $2, $3, $4, $5, $6, $7
) RETURNING *;
-- name: UpdateUserSecretByUserIDAndName :one
-- name: UpdateUserSecret :one
UPDATE user_secrets
SET
value = CASE WHEN @update_value::bool THEN @value ELSE value END,
value_key_id = CASE WHEN @update_value::bool THEN @value_key_id ELSE value_key_id END,
description = CASE WHEN @update_description::bool THEN @description ELSE description END,
env_name = CASE WHEN @update_env_name::bool THEN @env_name ELSE env_name END,
file_path = CASE WHEN @update_file_path::bool THEN @file_path ELSE file_path END,
updated_at = CURRENT_TIMESTAMP
WHERE user_id = @user_id AND name = @name
description = $2,
value = $3,
env_name = $4,
file_path = $5,
updated_at = CURRENT_TIMESTAMP
WHERE id = $1
RETURNING *;
-- name: DeleteUserSecretByUserIDAndName :exec
-- name: DeleteUserSecret :exec
DELETE FROM user_secrets
WHERE user_id = @user_id AND name = @name;
WHERE id = $1;
-187
View File
@@ -398,10 +398,6 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
return
}
// Cap the raw request body to prevent excessive memory use
// from large dynamic tool schemas.
r.Body = http.MaxBytesReader(rw, r.Body, int64(2*maxSystemPromptLenBytes))
var req codersdk.CreateChatRequest
if !httpapi.Read(ctx, rw, r, &req) {
return
@@ -492,50 +488,6 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
return
}
if len(req.UnsafeDynamicTools) > 250 {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Too many dynamic tools.",
Detail: "Maximum 250 dynamic tools per chat.",
})
return
}
// Validate that dynamic tool names are non-empty and unique
// within the list. Name collision with built-in tools is
// checked at chatloop time when the full tool set is known.
if len(req.UnsafeDynamicTools) > 0 {
seenNames := make(map[string]struct{}, len(req.UnsafeDynamicTools))
for _, dt := range req.UnsafeDynamicTools {
if dt.Name == "" {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Dynamic tool name must not be empty.",
})
return
}
if _, exists := seenNames[dt.Name]; exists {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Duplicate dynamic tool name.",
Detail: fmt.Sprintf("Tool %q appears more than once.", dt.Name),
})
return
}
seenNames[dt.Name] = struct{}{}
}
}
var dynamicToolsJSON json.RawMessage
if len(req.UnsafeDynamicTools) > 0 {
var err error
dynamicToolsJSON, err = json.Marshal(req.UnsafeDynamicTools)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to marshal dynamic tools.",
Detail: err.Error(),
})
return
}
}
chat, err := api.chatDaemon.CreateChat(ctx, chatd.CreateOptions{
OwnerID: apiKey.UserID,
WorkspaceID: workspaceSelection.WorkspaceID,
@@ -545,7 +497,6 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
InitialUserContent: contentBlocks,
MCPServerIDs: mcpServerIDs,
Labels: labels,
DynamicTools: dynamicToolsJSON,
})
if err != nil {
if maybeWriteLimitErr(ctx, rw, err) {
@@ -3232,70 +3183,6 @@ func (api *API) putChatWorkspaceTTL(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(http.StatusNoContent)
}
// @Summary Get chat retention days
// @ID get-chat-retention-days
// @Security CoderSessionToken
// @Tags Chats
// @Produce json
// @Success 200 {object} codersdk.ChatRetentionDaysResponse
// @Router /experimental/chats/config/retention-days [get]
// @x-apidocgen {"skip": true}
//
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
func (api *API) getChatRetentionDays(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
retentionDays, err := api.Database.GetChatRetentionDays(ctx)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to get chat retention days.",
Detail: err.Error(),
})
return
}
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatRetentionDaysResponse{
RetentionDays: retentionDays,
})
}
// Keep in sync with retentionDaysMaximum in
// site/src/pages/AgentsPage/AgentSettingsBehaviorPageView.tsx.
const retentionDaysMaximum = 3650 // ~10 years
// @Summary Update chat retention days
// @ID update-chat-retention-days
// @Security CoderSessionToken
// @Tags Chats
// @Accept json
// @Param request body codersdk.UpdateChatRetentionDaysRequest true "Request body"
// @Success 204
// @Router /experimental/chats/config/retention-days [put]
// @x-apidocgen {"skip": true}
func (api *API) putChatRetentionDays(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
httpapi.Forbidden(rw)
return
}
var req codersdk.UpdateChatRetentionDaysRequest
if !httpapi.Read(ctx, rw, r, &req) {
return
}
if req.RetentionDays < 0 || req.RetentionDays > retentionDaysMaximum {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: fmt.Sprintf("Retention days must be between 0 and %d.", retentionDaysMaximum),
})
return
}
if err := api.Database.UpsertChatRetentionDays(ctx, req.RetentionDays); err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to update chat retention days.",
Detail: err.Error(),
})
return
}
rw.WriteHeader(http.StatusNoContent)
}
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
//
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
@@ -5800,77 +5687,3 @@ func (api *API) prInsights(rw http.ResponseWriter, r *http.Request) {
RecentPRs: prEntries,
})
}
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
//
//nolint:revive // HTTP handler writes to ResponseWriter.
func (api *API) postChatToolResults(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
chat := httpmw.ChatParam(r)
apiKey := httpmw.APIKey(r)
// Cap the raw request body to prevent excessive memory use.
r.Body = http.MaxBytesReader(rw, r.Body, int64(2*maxSystemPromptLenBytes))
var req codersdk.SubmitToolResultsRequest
if !httpapi.Read(ctx, rw, r, &req) {
return
}
if len(req.Results) == 0 {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "At least one tool result is required.",
})
return
}
// Fast-path check outside the transaction. The authoritative
// check happens inside SubmitToolResults under a row lock.
if chat.Status != database.ChatStatusRequiresAction {
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
Message: "Chat is not waiting for tool results.",
Detail: fmt.Sprintf("Chat status is %q, expected %q.", chat.Status, database.ChatStatusRequiresAction),
})
return
}
var dynamicTools json.RawMessage
if chat.DynamicTools.Valid {
dynamicTools = chat.DynamicTools.RawMessage
}
err := api.chatDaemon.SubmitToolResults(ctx, chatd.SubmitToolResultsOptions{
ChatID: chat.ID,
UserID: apiKey.UserID,
ModelConfigID: chat.LastModelConfigID,
Results: req.Results,
DynamicTools: dynamicTools,
})
if err != nil {
var validationErr *chatd.ToolResultValidationError
var conflictErr *chatd.ToolResultStatusConflictError
switch {
case errors.As(err, &conflictErr):
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
Message: "Chat is not waiting for tool results.",
Detail: err.Error(),
})
case errors.As(err, &validationErr):
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: validationErr.Message,
Detail: validationErr.Detail,
})
default:
api.Logger.Error(ctx, "tool results submission failed",
slog.F("chat_id", chat.ID),
slog.Error(err),
)
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error submitting tool results.",
})
}
return
}
rw.WriteHeader(http.StatusNoContent)
}
+18 -431
View File
@@ -16,9 +16,7 @@ import (
"time"
"github.com/google/uuid"
"github.com/mark3labs/mcp-go/mcp"
"github.com/shopspring/decimal"
"github.com/sqlc-dev/pqtype"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
@@ -270,10 +268,17 @@ func TestPostChats(t *testing.T) {
_ = createChatModelConfig(t, client)
// Member without agents-access should be denied.
memberClientRaw, _ := coderdtest.CreateAnotherUser(t, client.Client, firstUser.OrganizationID)
memberClientRaw, member := coderdtest.CreateAnotherUser(t, client.Client, firstUser.OrganizationID)
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
_, err := memberClient.CreateChat(ctx, codersdk.CreateChatRequest{
// Strip the auto-assigned agents-access role to test
// the denied case.
_, err := client.Client.UpdateUserRoles(ctx, member.Username, codersdk.UpdateRoles{
Roles: []string{},
})
require.NoError(t, err)
_, err = memberClient.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{
{
Type: codersdk.ChatInputPartTypeText,
@@ -283,7 +288,6 @@ func TestPostChats(t *testing.T) {
})
requireSDKError(t, err, http.StatusForbidden)
})
t.Run("HidesSystemPromptMessages", func(t *testing.T) {
t.Parallel()
@@ -752,7 +756,15 @@ func TestListChats(t *testing.T) {
// returning empty because no chats exist.
memberClientRaw, member := coderdtest.CreateAnotherUser(t, client.Client, firstUser.OrganizationID)
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
_, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
// Strip the auto-assigned agents-access role to test
// the denied case.
_, err := client.Client.UpdateUserRoles(ctx, member.Username, codersdk.UpdateRoles{
Roles: []string{},
})
require.NoError(t, err)
_, err = db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
Status: database.ChatStatusWaiting,
OwnerID: member.ID,
LastModelConfigID: modelConfig.ID,
@@ -7735,62 +7747,6 @@ func TestChatWorkspaceTTL(t *testing.T) {
requireSDKError(t, err, http.StatusBadRequest)
}
func TestChatRetentionDays(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
adminClient := newChatClient(t)
firstUser := coderdtest.CreateFirstUser(t, adminClient.Client)
memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID)
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
// Default value is 30 (days) when nothing has been configured.
resp, err := adminClient.GetChatRetentionDays(ctx)
require.NoError(t, err, "get default")
require.Equal(t, int32(30), resp.RetentionDays, "default should be 30")
// Admin can set retention days to 90.
err = adminClient.UpdateChatRetentionDays(ctx, codersdk.UpdateChatRetentionDaysRequest{
RetentionDays: 90,
})
require.NoError(t, err, "admin set 90")
resp, err = adminClient.GetChatRetentionDays(ctx)
require.NoError(t, err, "get after set")
require.Equal(t, int32(90), resp.RetentionDays, "should return 90")
// Non-admin member can read the value.
resp, err = memberClient.GetChatRetentionDays(ctx)
require.NoError(t, err, "member get")
require.Equal(t, int32(90), resp.RetentionDays, "member should see same value")
// Non-admin member cannot write.
err = memberClient.UpdateChatRetentionDays(ctx, codersdk.UpdateChatRetentionDaysRequest{RetentionDays: 7})
requireSDKError(t, err, http.StatusForbidden)
// Admin can disable purge by setting 0.
err = adminClient.UpdateChatRetentionDays(ctx, codersdk.UpdateChatRetentionDaysRequest{
RetentionDays: 0,
})
require.NoError(t, err, "admin set 0")
resp, err = adminClient.GetChatRetentionDays(ctx)
require.NoError(t, err, "get after zero")
require.Equal(t, int32(0), resp.RetentionDays, "should be 0 after disable")
// Validation: negative value is rejected.
err = adminClient.UpdateChatRetentionDays(ctx, codersdk.UpdateChatRetentionDaysRequest{
RetentionDays: -1,
})
requireSDKError(t, err, http.StatusBadRequest)
// Validation: exceeding the 3650-day maximum is rejected.
err = adminClient.UpdateChatRetentionDays(ctx, codersdk.UpdateChatRetentionDaysRequest{
RetentionDays: 3651, // retentionDaysMaximum + 1; keep in sync with coderd/exp_chats.go.
})
requireSDKError(t, err, http.StatusBadRequest)
}
//nolint:tparallel,paralleltest // Subtests share a single coderdtest instance.
func TestUserChatCompactionThresholds(t *testing.T) {
t.Parallel()
@@ -8197,375 +8153,6 @@ func TestGetChatsByWorkspace(t *testing.T) {
})
}
func TestSubmitToolResults(t *testing.T) {
t.Parallel()
// setupRequiresAction creates a chat via the DB with dynamic tools,
// inserts an assistant message containing tool-call parts for each
// given toolCallID, and sets the chat status to requires_action.
// It returns the chat row so callers can exercise the endpoint.
setupRequiresAction := func(
ctx context.Context,
t *testing.T,
db database.Store,
ownerID uuid.UUID,
modelConfigID uuid.UUID,
dynamicToolName string,
toolCallIDs []string,
) database.Chat {
t.Helper()
// Marshal dynamic tools into the chat row.
dynamicTools := []mcp.Tool{{
Name: dynamicToolName,
Description: "a test dynamic tool",
InputSchema: mcp.ToolInputSchema{Type: "object"},
}}
dtJSON, err := json.Marshal(dynamicTools)
require.NoError(t, err)
chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
Status: database.ChatStatusWaiting,
OwnerID: ownerID,
LastModelConfigID: modelConfigID,
Title: "tool-results-test",
DynamicTools: pqtype.NullRawMessage{RawMessage: dtJSON, Valid: true},
})
require.NoError(t, err)
// Build assistant message with tool-call parts.
parts := make([]codersdk.ChatMessagePart, 0, len(toolCallIDs))
for _, id := range toolCallIDs {
parts = append(parts, codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeToolCall,
ToolCallID: id,
ToolName: dynamicToolName,
Args: json.RawMessage(`{"key":"value"}`),
})
}
content, err := chatprompt.MarshalParts(parts)
require.NoError(t, err)
_, err = db.InsertChatMessages(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessagesParams{
ChatID: chat.ID,
CreatedBy: []uuid.UUID{uuid.Nil},
ModelConfigID: []uuid.UUID{modelConfigID},
Role: []database.ChatMessageRole{database.ChatMessageRoleAssistant},
ContentVersion: []int16{chatprompt.CurrentContentVersion},
Content: []string{string(content.RawMessage)},
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
InputTokens: []int64{0},
OutputTokens: []int64{0},
TotalTokens: []int64{0},
ReasoningTokens: []int64{0},
CacheCreationTokens: []int64{0},
CacheReadTokens: []int64{0},
ContextLimit: []int64{0},
Compressed: []bool{false},
TotalCostMicros: []int64{0},
RuntimeMs: []int64{0},
})
require.NoError(t, err)
// Transition to requires_action.
chat, err = db.UpdateChatStatus(dbauthz.AsSystemRestricted(ctx), database.UpdateChatStatusParams{
ID: chat.ID,
Status: database.ChatStatusRequiresAction,
})
require.NoError(t, err)
require.Equal(t, database.ChatStatusRequiresAction, chat.Status)
return chat
}
t.Run("Success", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
user := coderdtest.CreateFirstUser(t, client.Client)
modelConfig := createChatModelConfig(t, client)
const toolName = "my_dynamic_tool"
toolCallIDs := []string{"call_abc", "call_def"}
chat := setupRequiresAction(ctx, t, db, user.UserID, modelConfig.ID, toolName, toolCallIDs)
err := client.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{
Results: []codersdk.ToolResult{
{ToolCallID: "call_abc", Output: json.RawMessage(`"result_a"`)},
{ToolCallID: "call_def", Output: json.RawMessage(`"result_b"`)},
},
})
require.NoError(t, err)
// Verify status is no longer requires_action. The chatd
// loop may have already picked the chat up and
// transitioned it further (pending → running → …), so we
// accept any non-requires_action status.
gotChat, err := client.GetChat(ctx, chat.ID)
require.NoError(t, err)
require.NotEqual(t, codersdk.ChatStatusRequiresAction, gotChat.Status,
"chat should no longer be in requires_action after submitting tool results")
// Verify tool-result messages were persisted.
msgsResp, err := client.GetChatMessages(ctx, chat.ID, nil)
require.NoError(t, err)
var toolResultCount int
for _, msg := range msgsResp.Messages {
if msg.Role == codersdk.ChatMessageRoleTool {
toolResultCount++
}
}
require.Equal(t, len(toolCallIDs), toolResultCount,
"expected one tool-result message per submitted result")
})
t.Run("WrongStatus", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
user := coderdtest.CreateFirstUser(t, client.Client)
modelConfig := createChatModelConfig(t, client)
// Create a chat that is NOT in requires_action status.
chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
Status: database.ChatStatusWaiting,
OwnerID: user.UserID,
LastModelConfigID: modelConfig.ID,
Title: "wrong-status-test",
})
require.NoError(t, err)
err = client.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{
Results: []codersdk.ToolResult{
{ToolCallID: "call_xyz", Output: json.RawMessage(`"nope"`)},
},
})
requireSDKError(t, err, http.StatusConflict)
})
t.Run("MissingResult", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
user := coderdtest.CreateFirstUser(t, client.Client)
modelConfig := createChatModelConfig(t, client)
const toolName = "my_dynamic_tool"
toolCallIDs := []string{"call_one", "call_two"}
chat := setupRequiresAction(ctx, t, db, user.UserID, modelConfig.ID, toolName, toolCallIDs)
// Submit only one of the two required results.
err := client.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{
Results: []codersdk.ToolResult{
{ToolCallID: "call_one", Output: json.RawMessage(`"partial"`)},
},
})
requireSDKError(t, err, http.StatusBadRequest)
})
t.Run("UnexpectedResult", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
user := coderdtest.CreateFirstUser(t, client.Client)
modelConfig := createChatModelConfig(t, client)
const toolName = "my_dynamic_tool"
toolCallIDs := []string{"call_real"}
chat := setupRequiresAction(ctx, t, db, user.UserID, modelConfig.ID, toolName, toolCallIDs)
// Submit a result with a wrong tool_call_id.
err := client.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{
Results: []codersdk.ToolResult{
{ToolCallID: "call_bogus", Output: json.RawMessage(`"wrong"`)},
},
})
requireSDKError(t, err, http.StatusBadRequest)
})
t.Run("InvalidJSONOutput", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
user := coderdtest.CreateFirstUser(t, client.Client)
modelConfig := createChatModelConfig(t, client)
const toolName = "my_dynamic_tool"
toolCallIDs := []string{"call_json"}
chat := setupRequiresAction(ctx, t, db, user.UserID, modelConfig.ID, toolName, toolCallIDs)
// We must bypass the SDK client because json.RawMessage
// rejects invalid JSON during json.Marshal. A raw HTTP
// request lets the invalid payload reach the server so we
// can verify server-side validation.
rawBody := `{"results":[{"tool_call_id":"call_json","output":not-json,"is_error":false}]}`
url := client.URL.JoinPath(fmt.Sprintf("/api/experimental/chats/%s/tool-results", chat.ID)).String()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBufferString(rawBody))
require.NoError(t, err)
req.Header.Set("Content-Type", "application/json")
req.Header.Set(codersdk.SessionTokenHeader, client.SessionToken())
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
})
t.Run("DuplicateToolCallID", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
user := coderdtest.CreateFirstUser(t, client.Client)
modelConfig := createChatModelConfig(t, client)
const toolName = "my_dynamic_tool"
toolCallIDs := []string{"call_dup1", "call_dup2"}
chat := setupRequiresAction(ctx, t, db, user.UserID, modelConfig.ID, toolName, toolCallIDs)
err := client.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{
Results: []codersdk.ToolResult{
{ToolCallID: "call_dup1", Output: json.RawMessage(`"result_a"`)},
{ToolCallID: "call_dup1", Output: json.RawMessage(`"result_b"`)},
},
})
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
require.Contains(t, sdkErr.Message, "Duplicate tool_call_id")
})
t.Run("EmptyResults", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
user := coderdtest.CreateFirstUser(t, client.Client)
modelConfig := createChatModelConfig(t, client)
const toolName = "my_dynamic_tool"
toolCallIDs := []string{"call_empty"}
chat := setupRequiresAction(ctx, t, db, user.UserID, modelConfig.ID, toolName, toolCallIDs)
err := client.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{
Results: []codersdk.ToolResult{},
})
requireSDKError(t, err, http.StatusBadRequest)
})
t.Run("NotFoundForDifferentUser", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
user := coderdtest.CreateFirstUser(t, client.Client)
modelConfig := createChatModelConfig(t, client)
const toolName = "my_dynamic_tool"
toolCallIDs := []string{"call_other"}
chat := setupRequiresAction(ctx, t, db, user.UserID, modelConfig.ID, toolName, toolCallIDs)
// Create a second user and try to submit tool results
// to user A's chat.
otherClientRaw, _ := coderdtest.CreateAnotherUser(
t, client.Client, user.OrganizationID,
rbac.RoleAgentsAccess(),
)
otherClient := codersdk.NewExperimentalClient(otherClientRaw)
err := otherClient.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{
Results: []codersdk.ToolResult{
{ToolCallID: "call_other", Output: json.RawMessage(`"nope"`)},
},
})
requireSDKError(t, err, http.StatusNotFound)
})
}
func TestPostChats_DynamicToolValidation(t *testing.T) {
t.Parallel()
t.Run("TooManyTools", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := newChatClient(t)
_ = coderdtest.CreateFirstUser(t, client.Client)
_ = createChatModelConfig(t, client)
tools := make([]codersdk.DynamicTool, 251)
for i := range tools {
tools[i] = codersdk.DynamicTool{
Name: fmt.Sprintf("tool-%d", i),
}
}
_, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{{
Type: codersdk.ChatInputPartTypeText,
Text: "hello",
}},
UnsafeDynamicTools: tools,
})
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
require.Equal(t, "Too many dynamic tools.", sdkErr.Message)
})
t.Run("EmptyToolName", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := newChatClient(t)
_ = coderdtest.CreateFirstUser(t, client.Client)
_ = createChatModelConfig(t, client)
_, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{{
Type: codersdk.ChatInputPartTypeText,
Text: "hello",
}},
UnsafeDynamicTools: []codersdk.DynamicTool{
{Name: ""},
},
})
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
require.Equal(t, "Dynamic tool name must not be empty.", sdkErr.Message)
})
t.Run("DuplicateToolName", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := newChatClient(t)
_ = coderdtest.CreateFirstUser(t, client.Client)
_ = createChatModelConfig(t, client)
_, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{{
Type: codersdk.ChatInputPartTypeText,
Text: "hello",
}},
UnsafeDynamicTools: []codersdk.DynamicTool{
{Name: "dup-tool"},
{Name: "dup-tool"},
},
})
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
require.Equal(t, "Duplicate dynamic tool name.", sdkErr.Message)
})
}
func requireSDKError(t *testing.T, err error, expectedStatus int) *codersdk.Error {
t.Helper()
+1 -1
View File
@@ -148,7 +148,7 @@ func TestGetOrgMembersFilter(t *testing.T) {
setupCtx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
coderdtest.UsersFilter(setupCtx, t, client, api.Database, nil, nil, func(testCtx context.Context, req codersdk.UsersRequest) []codersdk.ReducedUser {
coderdtest.UsersFilter(setupCtx, t, client, api.Database, nil, func(testCtx context.Context, req codersdk.UsersRequest) []codersdk.ReducedUser {
res, err := client.OrganizationMembersPaginated(testCtx, first.OrganizationID, req)
require.NoError(t, err)
reduced := make([]codersdk.ReducedUser, len(res.Members))
+2 -4
View File
@@ -32,9 +32,8 @@ func HandleChatEvent(cb func(ctx context.Context, payload ChatEvent, err error))
}
type ChatEvent struct {
Kind ChatEventKind `json:"kind"`
Chat codersdk.Chat `json:"chat"`
ToolCalls []codersdk.ChatStreamToolCall `json:"tool_calls,omitempty"`
Kind ChatEventKind `json:"kind"`
Chat codersdk.Chat `json:"chat"`
}
type ChatEventKind string
@@ -45,5 +44,4 @@ const (
ChatEventKindCreated ChatEventKind = "created"
ChatEventKindDeleted ChatEventKind = "deleted"
ChatEventKindDiffStatusChange ChatEventKind = "diff_status_change"
ChatEventKindActionRequired ChatEventKind = "action_required"
)
-144
View File
@@ -776,40 +776,6 @@ func (r *remoteReporter) createSnapshot() (*Snapshot, error) {
return nil
})
eg.Go(func() error {
chats, err := r.options.Database.GetChatsUpdatedAfter(ctx, createdAfter)
if err != nil {
return xerrors.Errorf("get chats updated after: %w", err)
}
snapshot.Chats = make([]Chat, 0, len(chats))
for _, chat := range chats {
snapshot.Chats = append(snapshot.Chats, ConvertChat(chat))
}
return nil
})
eg.Go(func() error {
summaries, err := r.options.Database.GetChatMessageSummariesPerChat(ctx, createdAfter)
if err != nil {
return xerrors.Errorf("get chat message summaries: %w", err)
}
snapshot.ChatMessageSummaries = make([]ChatMessageSummary, 0, len(summaries))
for _, s := range summaries {
snapshot.ChatMessageSummaries = append(snapshot.ChatMessageSummaries, ConvertChatMessageSummary(s))
}
return nil
})
eg.Go(func() error {
configs, err := r.options.Database.GetChatModelConfigsForTelemetry(ctx)
if err != nil {
return xerrors.Errorf("get chat model configs: %w", err)
}
snapshot.ChatModelConfigs = make([]ChatModelConfig, 0, len(configs))
for _, c := range configs {
snapshot.ChatModelConfigs = append(snapshot.ChatModelConfigs, ConvertChatModelConfig(c))
}
return nil
})
err := eg.Wait()
if err != nil {
return nil, err
@@ -1537,9 +1503,6 @@ type Snapshot struct {
AIBridgeInterceptionsSummaries []AIBridgeInterceptionsSummary `json:"aibridge_interceptions_summaries"`
BoundaryUsageSummary *BoundaryUsageSummary `json:"boundary_usage_summary"`
FirstUserOnboarding *FirstUserOnboarding `json:"first_user_onboarding"`
Chats []Chat `json:"chats"`
ChatMessageSummaries []ChatMessageSummary `json:"chat_message_summaries"`
ChatModelConfigs []ChatModelConfig `json:"chat_model_configs"`
}
// Deployment contains information about the host running Coder.
@@ -2150,66 +2113,6 @@ func ConvertTask(task database.Task) Task {
return t
}
// ConvertChat converts a database chat row to a telemetry Chat.
func ConvertChat(dbChat database.GetChatsUpdatedAfterRow) Chat {
c := Chat{
ID: dbChat.ID,
OwnerID: dbChat.OwnerID,
CreatedAt: dbChat.CreatedAt,
UpdatedAt: dbChat.UpdatedAt,
Status: string(dbChat.Status),
HasParent: dbChat.HasParent,
Archived: dbChat.Archived,
LastModelConfigID: dbChat.LastModelConfigID,
}
if dbChat.RootChatID.Valid {
c.RootChatID = &dbChat.RootChatID.UUID
}
if dbChat.WorkspaceID.Valid {
c.WorkspaceID = &dbChat.WorkspaceID.UUID
}
if dbChat.Mode.Valid {
mode := string(dbChat.Mode.ChatMode)
c.Mode = &mode
}
return c
}
// ConvertChatMessageSummary converts a database chat message
// summary row to a telemetry ChatMessageSummary.
func ConvertChatMessageSummary(dbRow database.GetChatMessageSummariesPerChatRow) ChatMessageSummary {
return ChatMessageSummary{
ChatID: dbRow.ChatID,
MessageCount: dbRow.MessageCount,
UserMessageCount: dbRow.UserMessageCount,
AssistantMessageCount: dbRow.AssistantMessageCount,
ToolMessageCount: dbRow.ToolMessageCount,
SystemMessageCount: dbRow.SystemMessageCount,
TotalInputTokens: dbRow.TotalInputTokens,
TotalOutputTokens: dbRow.TotalOutputTokens,
TotalReasoningTokens: dbRow.TotalReasoningTokens,
TotalCacheCreationTokens: dbRow.TotalCacheCreationTokens,
TotalCacheReadTokens: dbRow.TotalCacheReadTokens,
TotalCostMicros: dbRow.TotalCostMicros,
TotalRuntimeMs: dbRow.TotalRuntimeMs,
DistinctModelCount: dbRow.DistinctModelCount,
CompressedMessageCount: dbRow.CompressedMessageCount,
}
}
// ConvertChatModelConfig converts a database model config row to a
// telemetry ChatModelConfig.
func ConvertChatModelConfig(dbRow database.GetChatModelConfigsForTelemetryRow) ChatModelConfig {
return ChatModelConfig{
ID: dbRow.ID,
Provider: dbRow.Provider,
Model: dbRow.Model,
ContextLimit: dbRow.ContextLimit,
Enabled: dbRow.Enabled,
IsDefault: dbRow.IsDefault,
}
}
type telemetryItemKey string
// The comment below gets rid of the warning that the name "TelemetryItemKey" has
@@ -2331,53 +2234,6 @@ type BoundaryUsageSummary struct {
PeriodDurationMilliseconds int64 `json:"period_duration_ms"`
}
// Chat contains anonymized metadata about a chat for telemetry.
// Titles and message content are excluded to avoid PII leakage.
type Chat struct {
ID uuid.UUID `json:"id"`
OwnerID uuid.UUID `json:"owner_id"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
Status string `json:"status"`
HasParent bool `json:"has_parent"`
RootChatID *uuid.UUID `json:"root_chat_id"`
WorkspaceID *uuid.UUID `json:"workspace_id"`
Mode *string `json:"mode"`
Archived bool `json:"archived"`
LastModelConfigID uuid.UUID `json:"last_model_config_id"`
}
// ChatMessageSummary contains per-chat aggregated message metrics
// for telemetry. Individual message content is never included.
type ChatMessageSummary struct {
ChatID uuid.UUID `json:"chat_id"`
MessageCount int64 `json:"message_count"`
UserMessageCount int64 `json:"user_message_count"`
AssistantMessageCount int64 `json:"assistant_message_count"`
ToolMessageCount int64 `json:"tool_message_count"`
SystemMessageCount int64 `json:"system_message_count"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
TotalReasoningTokens int64 `json:"total_reasoning_tokens"`
TotalCacheCreationTokens int64 `json:"total_cache_creation_tokens"`
TotalCacheReadTokens int64 `json:"total_cache_read_tokens"`
TotalCostMicros int64 `json:"total_cost_micros"`
TotalRuntimeMs int64 `json:"total_runtime_ms"`
DistinctModelCount int64 `json:"distinct_model_count"`
CompressedMessageCount int64 `json:"compressed_message_count"`
}
// ChatModelConfig contains model configuration metadata for
// telemetry. Sensitive fields like API keys are excluded.
type ChatModelConfig struct {
ID uuid.UUID `json:"id"`
Provider string `json:"provider"`
Model string `json:"model"`
ContextLimit int64 `json:"context_limit"`
Enabled bool `json:"enabled"`
IsDefault bool `json:"is_default"`
}
func ConvertAIBridgeInterceptionsSummary(endTime time.Time, provider, model, client string, summary database.CalculateAIBridgeInterceptionsTelemetrySummaryRow) AIBridgeInterceptionsSummary {
return AIBridgeInterceptionsSummary{
ID: uuid.New(),
-300
View File
@@ -1549,303 +1549,3 @@ func TestTelemetry_BoundaryUsageSummary(t *testing.T) {
require.Nil(t, snapshot2.BoundaryUsageSummary)
})
}
func TestChatsTelemetry(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
db, _ := dbtestutil.NewDB(t)
user := dbgen.User(t, db, database.User{})
// Create chat providers (required FK for model configs).
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: "anthropic",
DisplayName: "Anthropic",
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
_, err = db.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: "openai",
DisplayName: "OpenAI",
Enabled: true,
CentralApiKeyEnabled: true,
})
require.NoError(t, err)
// Create a model config.
modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
Provider: "anthropic",
Model: "claude-sonnet-4-20250514",
DisplayName: "Claude Sonnet",
Enabled: true,
IsDefault: true,
ContextLimit: 200000,
CompressionThreshold: 70,
Options: json.RawMessage("{}"),
})
require.NoError(t, err)
// Create a second model config to test full dump.
modelCfg2, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
Provider: "openai",
Model: "gpt-4o",
DisplayName: "GPT-4o",
Enabled: true,
IsDefault: false,
ContextLimit: 128000,
CompressionThreshold: 70,
Options: json.RawMessage("{}"),
})
require.NoError(t, err)
// Create a soft-deleted model config — should NOT appear in telemetry.
deletedCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
Provider: "anthropic",
Model: "claude-deleted",
DisplayName: "Deleted Model",
Enabled: true,
IsDefault: false,
ContextLimit: 100000,
CompressionThreshold: 70,
Options: json.RawMessage("{}"),
})
require.NoError(t, err)
err = db.DeleteChatModelConfigByID(ctx, deletedCfg.ID)
require.NoError(t, err)
// Create a root chat with a workspace.
org, err := db.GetDefaultOrganization(ctx)
require.NoError(t, err)
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
OrganizationID: org.ID,
Type: database.ProvisionerJobTypeTemplateVersionDryRun,
})
tpl := dbgen.Template(t, db, database.Template{
OrganizationID: org.ID,
CreatedBy: user.ID,
})
tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{
OrganizationID: org.ID,
TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true},
CreatedBy: user.ID,
JobID: job.ID,
})
ws := dbgen.Workspace(t, db, database.WorkspaceTable{
OwnerID: user.ID,
OrganizationID: org.ID,
TemplateID: tpl.ID,
})
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
Transition: database.WorkspaceTransitionStart,
Reason: database.BuildReasonInitiator,
WorkspaceID: ws.ID,
TemplateVersionID: tv.ID,
JobID: job.ID,
})
rootChat, err := db.InsertChat(ctx, database.InsertChatParams{
OwnerID: user.ID,
LastModelConfigID: modelCfg.ID,
Title: "Root Chat",
Status: database.ChatStatusRunning,
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
Mode: database.NullChatMode{ChatMode: database.ChatModeComputerUse, Valid: true},
})
require.NoError(t, err)
// Create a child chat (has parent + root).
childChat, err := db.InsertChat(ctx, database.InsertChatParams{
OwnerID: user.ID,
LastModelConfigID: modelCfg2.ID,
Title: "Child Chat",
Status: database.ChatStatusCompleted,
ParentChatID: uuid.NullUUID{UUID: rootChat.ID, Valid: true},
RootChatID: uuid.NullUUID{UUID: rootChat.ID, Valid: true},
})
require.NoError(t, err)
// Insert messages for root chat: 2 user, 2 assistant, 1 tool.
_, err = db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
ChatID: rootChat.ID,
CreatedBy: []uuid.UUID{user.ID, uuid.Nil, user.ID, uuid.Nil, uuid.Nil},
ModelConfigID: []uuid.UUID{modelCfg.ID, modelCfg.ID, modelCfg.ID, modelCfg.ID, modelCfg.ID},
Role: []database.ChatMessageRole{database.ChatMessageRoleUser, database.ChatMessageRoleAssistant, database.ChatMessageRoleUser, database.ChatMessageRoleAssistant, database.ChatMessageRoleTool},
Content: []string{`[{"type":"text","text":"hello"}]`, `[{"type":"text","text":"hi"}]`, `[{"type":"text","text":"help"}]`, `[{"type":"text","text":"sure"}]`, `[{"type":"text","text":"result"}]`},
ContentVersion: []int16{1, 1, 1, 1, 1},
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth, database.ChatMessageVisibilityBoth, database.ChatMessageVisibilityBoth, database.ChatMessageVisibilityBoth, database.ChatMessageVisibilityBoth},
InputTokens: []int64{100, 200, 150, 300, 0},
OutputTokens: []int64{0, 50, 0, 100, 0},
TotalTokens: []int64{100, 250, 150, 400, 0},
ReasoningTokens: []int64{0, 10, 0, 20, 0},
CacheCreationTokens: []int64{50, 0, 30, 0, 0},
CacheReadTokens: []int64{0, 25, 0, 40, 0},
ContextLimit: []int64{200000, 200000, 200000, 200000, 200000},
Compressed: []bool{false, false, false, false, false},
TotalCostMicros: []int64{1000, 2000, 1500, 3000, 0},
RuntimeMs: []int64{0, 500, 0, 800, 100},
ProviderResponseID: []string{"", "resp-1", "", "resp-2", ""},
})
require.NoError(t, err)
// Insert messages for child chat: 1 user, 1 assistant (compressed).
_, err = db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
ChatID: childChat.ID,
CreatedBy: []uuid.UUID{user.ID, uuid.Nil},
ModelConfigID: []uuid.UUID{modelCfg2.ID, modelCfg2.ID},
Role: []database.ChatMessageRole{database.ChatMessageRoleUser, database.ChatMessageRoleAssistant},
Content: []string{`[{"type":"text","text":"q"}]`, `[{"type":"text","text":"a"}]`},
ContentVersion: []int16{1, 1},
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth, database.ChatMessageVisibilityBoth},
InputTokens: []int64{500, 600},
OutputTokens: []int64{0, 200},
TotalTokens: []int64{500, 800},
ReasoningTokens: []int64{0, 50},
CacheCreationTokens: []int64{100, 0},
CacheReadTokens: []int64{0, 75},
ContextLimit: []int64{128000, 128000},
Compressed: []bool{false, true},
TotalCostMicros: []int64{5000, 8000},
RuntimeMs: []int64{0, 1200},
ProviderResponseID: []string{"", "resp-3"},
})
require.NoError(t, err)
// Insert a soft-deleted message on root chat with large token values.
// This acts as "poison" — if the deleted filter is missing, totals
// will be inflated and assertions below will fail.
poisonMsgs, err := db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
ChatID: rootChat.ID,
CreatedBy: []uuid.UUID{uuid.Nil},
ModelConfigID: []uuid.UUID{modelCfg.ID},
Role: []database.ChatMessageRole{database.ChatMessageRoleAssistant},
Content: []string{`[{"type":"text","text":"poison"}]`},
ContentVersion: []int16{1},
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
InputTokens: []int64{999999},
OutputTokens: []int64{999999},
TotalTokens: []int64{999999},
ReasoningTokens: []int64{999999},
CacheCreationTokens: []int64{999999},
CacheReadTokens: []int64{999999},
ContextLimit: []int64{200000},
Compressed: []bool{false},
TotalCostMicros: []int64{999999},
RuntimeMs: []int64{999999},
ProviderResponseID: []string{""},
})
require.NoError(t, err)
err = db.SoftDeleteChatMessageByID(ctx, poisonMsgs[0].ID)
require.NoError(t, err)
_, snapshot := collectSnapshot(ctx, t, db, nil)
// --- Assert Chats ---
require.Len(t, snapshot.Chats, 2)
// Find root and child by HasParent flag.
var foundRoot, foundChild *telemetry.Chat
for i := range snapshot.Chats {
if !snapshot.Chats[i].HasParent {
foundRoot = &snapshot.Chats[i]
} else {
foundChild = &snapshot.Chats[i]
}
}
require.NotNil(t, foundRoot, "expected root chat")
require.NotNil(t, foundChild, "expected child chat")
// Root chat assertions.
assert.Equal(t, rootChat.ID, foundRoot.ID)
assert.Equal(t, user.ID, foundRoot.OwnerID)
assert.Equal(t, "running", foundRoot.Status)
assert.False(t, foundRoot.HasParent)
assert.Nil(t, foundRoot.RootChatID)
require.NotNil(t, foundRoot.WorkspaceID)
assert.Equal(t, ws.ID, *foundRoot.WorkspaceID)
assert.Equal(t, modelCfg.ID, foundRoot.LastModelConfigID)
require.NotNil(t, foundRoot.Mode)
assert.Equal(t, "computer_use", *foundRoot.Mode)
assert.False(t, foundRoot.Archived)
// Child chat assertions.
assert.Equal(t, childChat.ID, foundChild.ID)
assert.Equal(t, user.ID, foundChild.OwnerID)
assert.True(t, foundChild.HasParent)
require.NotNil(t, foundChild.RootChatID)
assert.Equal(t, rootChat.ID, *foundChild.RootChatID)
assert.Nil(t, foundChild.WorkspaceID)
assert.Equal(t, "completed", foundChild.Status)
assert.Equal(t, modelCfg2.ID, foundChild.LastModelConfigID)
assert.Nil(t, foundChild.Mode)
assert.False(t, foundChild.Archived)
// --- Assert ChatMessageSummaries ---
require.Len(t, snapshot.ChatMessageSummaries, 2)
summaryMap := make(map[uuid.UUID]telemetry.ChatMessageSummary)
for _, s := range snapshot.ChatMessageSummaries {
summaryMap[s.ChatID] = s
}
// Root chat summary: 2 user + 2 assistant + 1 tool = 5 messages.
rootSummary, ok := summaryMap[rootChat.ID]
require.True(t, ok, "expected summary for root chat")
assert.Equal(t, int64(5), rootSummary.MessageCount)
assert.Equal(t, int64(2), rootSummary.UserMessageCount)
assert.Equal(t, int64(2), rootSummary.AssistantMessageCount)
assert.Equal(t, int64(1), rootSummary.ToolMessageCount)
assert.Equal(t, int64(0), rootSummary.SystemMessageCount)
assert.Equal(t, int64(750), rootSummary.TotalInputTokens) // 100+200+150+300+0
assert.Equal(t, int64(150), rootSummary.TotalOutputTokens) // 0+50+0+100+0
assert.Equal(t, int64(30), rootSummary.TotalReasoningTokens) // 0+10+0+20+0
assert.Equal(t, int64(80), rootSummary.TotalCacheCreationTokens) // 50+0+30+0+0
assert.Equal(t, int64(65), rootSummary.TotalCacheReadTokens) // 0+25+0+40+0
assert.Equal(t, int64(7500), rootSummary.TotalCostMicros) // 1000+2000+1500+3000+0
assert.Equal(t, int64(1400), rootSummary.TotalRuntimeMs) // 0+500+0+800+100
assert.Equal(t, int64(1), rootSummary.DistinctModelCount)
assert.Equal(t, int64(0), rootSummary.CompressedMessageCount)
// Child chat summary: 1 user + 1 assistant = 2 messages, 1 compressed.
childSummary, ok := summaryMap[childChat.ID]
require.True(t, ok, "expected summary for child chat")
assert.Equal(t, int64(2), childSummary.MessageCount)
assert.Equal(t, int64(1), childSummary.UserMessageCount)
assert.Equal(t, int64(1), childSummary.AssistantMessageCount)
assert.Equal(t, int64(1100), childSummary.TotalInputTokens) // 500+600
assert.Equal(t, int64(200), childSummary.TotalOutputTokens) // 0+200
assert.Equal(t, int64(50), childSummary.TotalReasoningTokens) // 0+50
assert.Equal(t, int64(0), childSummary.ToolMessageCount)
assert.Equal(t, int64(0), childSummary.SystemMessageCount)
assert.Equal(t, int64(100), childSummary.TotalCacheCreationTokens) // 100+0
assert.Equal(t, int64(75), childSummary.TotalCacheReadTokens) // 0+75
assert.Equal(t, int64(13000), childSummary.TotalCostMicros) // 5000+8000
assert.Equal(t, int64(1200), childSummary.TotalRuntimeMs) // 0+1200
assert.Equal(t, int64(1), childSummary.DistinctModelCount)
assert.Equal(t, int64(1), childSummary.CompressedMessageCount)
// --- Assert ChatModelConfigs ---
require.Len(t, snapshot.ChatModelConfigs, 2)
configMap := make(map[uuid.UUID]telemetry.ChatModelConfig)
for _, c := range snapshot.ChatModelConfigs {
configMap[c.ID] = c
}
cfg1, ok := configMap[modelCfg.ID]
require.True(t, ok)
assert.Equal(t, "anthropic", cfg1.Provider)
assert.Equal(t, "claude-sonnet-4-20250514", cfg1.Model)
assert.Equal(t, int64(200000), cfg1.ContextLimit)
assert.True(t, cfg1.Enabled)
assert.True(t, cfg1.IsDefault)
cfg2, ok := configMap[modelCfg2.ID]
require.True(t, ok)
assert.Equal(t, "openai", cfg2.Provider)
assert.Equal(t, "gpt-4o", cfg2.Model)
assert.Equal(t, int64(128000), cfg2.ContextLimit)
assert.True(t, cfg2.Enabled)
assert.False(t, cfg2.IsDefault)
}
+12 -8
View File
@@ -475,14 +475,6 @@ func (api *API) postUser(rw http.ResponseWriter, r *http.Request) {
}
req.UserLoginType = codersdk.LoginTypeNone
// Service accounts are a Premium feature.
if !api.Entitlements.Enabled(codersdk.FeatureServiceAccounts) {
httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{
Message: fmt.Sprintf("%s is a Premium feature. Contact sales!", codersdk.FeatureServiceAccounts.Humanize()),
})
return
}
} else if req.UserLoginType == "" {
// Default to password auth
req.UserLoginType = codersdk.LoginTypePassword
@@ -1638,6 +1630,18 @@ func (api *API) CreateUser(ctx context.Context, store database.Store, req Create
rbacRoles = req.RBACRoles
}
// When the agents experiment is enabled, auto-assign the
// agents-access role so new users can use Coder Agents
// without manual admin intervention. Skip this for OIDC
// users when site role sync is enabled, because the sync
// will overwrite roles on every login anyway — those
// admins should use --oidc-user-role-default instead.
if api.Experiments.Enabled(codersdk.ExperimentAgents) &&
!(req.LoginType == database.LoginTypeOIDC && api.IDPSync.SiteRoleSyncEnabled()) &&
!slices.Contains(rbacRoles, codersdk.RoleAgentsAccess) {
rbacRoles = append(rbacRoles, codersdk.RoleAgentsAccess)
}
var user database.User
err := store.InTx(func(tx database.Store) error {
orgRoles := make([]string, 0)
+143 -6
View File
@@ -829,6 +829,35 @@ func TestPostUsers(t *testing.T) {
assert.Equal(t, firstUser.OrganizationID, user.OrganizationIDs[0])
})
// CreateWithAgentsExperiment verifies that new users
// are auto-assigned the agents-access role when the
// experiment is enabled. The experiment-disabled case
// is implicitly covered by TestInitialRoles, which
// asserts exactly [owner] with no experiment — it
// would fail if agents-access leaked through.
t.Run("CreateWithAgentsExperiment", func(t *testing.T) {
t.Parallel()
dv := coderdtest.DeploymentValues(t)
dv.Experiments = []string{string(codersdk.ExperimentAgents)}
client := coderdtest.New(t, &coderdtest.Options{DeploymentValues: dv})
firstUser := coderdtest.CreateFirstUser(t, client)
ctx := testutil.Context(t, testutil.WaitLong)
user, err := client.CreateUserWithOrgs(ctx, codersdk.CreateUserRequestWithOrgs{
OrganizationIDs: []uuid.UUID{firstUser.OrganizationID},
Email: "another@user.org",
Username: "someone-else",
Password: "SomeSecurePassword!",
})
require.NoError(t, err)
roles, err := client.UserRoles(ctx, user.Username)
require.NoError(t, err)
require.Contains(t, roles.Roles, codersdk.RoleAgentsAccess,
"new user should have agents-access role when agents experiment is enabled")
})
t.Run("CreateWithStatus", func(t *testing.T) {
t.Parallel()
auditor := audit.NewMock()
@@ -950,7 +979,28 @@ func TestPostUsers(t *testing.T) {
require.Equal(t, found.LoginType, codersdk.LoginTypeOIDC)
})
t.Run("ServiceAccount/Unlicensed", func(t *testing.T) {
t.Run("ServiceAccount/OK", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
first := coderdtest.CreateFirstUser(t, client)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
user, err := client.CreateUserWithOrgs(ctx, codersdk.CreateUserRequestWithOrgs{
OrganizationIDs: []uuid.UUID{first.OrganizationID},
Username: "service-acct-ok",
UserLoginType: codersdk.LoginTypeNone,
ServiceAccount: true,
})
require.NoError(t, err)
require.Equal(t, codersdk.LoginTypeNone, user.LoginType)
require.Empty(t, user.Email)
require.Equal(t, "service-acct-ok", user.Username)
require.Equal(t, codersdk.UserStatusDormant, user.Status)
})
t.Run("ServiceAccount/WithEmail", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
first := coderdtest.CreateFirstUser(t, client)
@@ -960,14 +1010,75 @@ func TestPostUsers(t *testing.T) {
_, err := client.CreateUserWithOrgs(ctx, codersdk.CreateUserRequestWithOrgs{
OrganizationIDs: []uuid.UUID{first.OrganizationID},
Username: "service-acct-ok",
UserLoginType: codersdk.LoginTypeNone,
Username: "service-acct-email",
Email: "should-not-have@email.com",
ServiceAccount: true,
})
var apiErr *codersdk.Error
require.ErrorAs(t, err, &apiErr)
require.Equal(t, http.StatusForbidden, apiErr.StatusCode())
require.Contains(t, apiErr.Message, "Premium feature")
require.Equal(t, http.StatusBadRequest, apiErr.StatusCode())
require.Contains(t, apiErr.Message, "Email cannot be set for service accounts")
})
t.Run("ServiceAccount/WithPassword", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
first := coderdtest.CreateFirstUser(t, client)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
_, err := client.CreateUserWithOrgs(ctx, codersdk.CreateUserRequestWithOrgs{
OrganizationIDs: []uuid.UUID{first.OrganizationID},
Username: "service-acct-password",
Password: "ShouldNotHavePassword123!",
ServiceAccount: true,
})
var apiErr *codersdk.Error
require.ErrorAs(t, err, &apiErr)
require.Equal(t, http.StatusBadRequest, apiErr.StatusCode())
require.Contains(t, apiErr.Message, "Password cannot be set for service accounts")
})
t.Run("ServiceAccount/WithInvalidLoginType", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
first := coderdtest.CreateFirstUser(t, client)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
_, err := client.CreateUserWithOrgs(ctx, codersdk.CreateUserRequestWithOrgs{
OrganizationIDs: []uuid.UUID{first.OrganizationID},
Username: "service-acct-login-type",
UserLoginType: codersdk.LoginTypePassword,
ServiceAccount: true,
})
var apiErr *codersdk.Error
require.ErrorAs(t, err, &apiErr)
require.Equal(t, http.StatusBadRequest, apiErr.StatusCode())
require.Contains(t, apiErr.Message, "Service accounts must use login type 'none'")
})
t.Run("ServiceAccount/DefaultLoginType", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
first := coderdtest.CreateFirstUser(t, client)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
user, err := client.CreateUserWithOrgs(ctx, codersdk.CreateUserRequestWithOrgs{
OrganizationIDs: []uuid.UUID{first.OrganizationID},
Username: "service-acct-default-login",
ServiceAccount: true,
})
require.NoError(t, err)
found, err := client.User(ctx, user.ID.String())
require.NoError(t, err)
require.Equal(t, codersdk.LoginTypeNone, found.LoginType)
require.Empty(t, found.Email)
})
t.Run("NonServiceAccount/WithoutEmail", func(t *testing.T) {
@@ -987,6 +1098,32 @@ func TestPostUsers(t *testing.T) {
require.ErrorAs(t, err, &apiErr)
require.Equal(t, http.StatusBadRequest, apiErr.StatusCode())
})
t.Run("ServiceAccount/MultipleWithoutEmail", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
first := coderdtest.CreateFirstUser(t, client)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
user1, err := client.CreateUserWithOrgs(ctx, codersdk.CreateUserRequestWithOrgs{
OrganizationIDs: []uuid.UUID{first.OrganizationID},
Username: "service-acct-multi-1",
ServiceAccount: true,
})
require.NoError(t, err)
require.Empty(t, user1.Email)
user2, err := client.CreateUserWithOrgs(ctx, codersdk.CreateUserRequestWithOrgs{
OrganizationIDs: []uuid.UUID{first.OrganizationID},
Username: "service-acct-multi-2",
ServiceAccount: true,
})
require.NoError(t, err)
require.Empty(t, user2.Email)
require.NotEqual(t, user1.ID, user2.ID)
})
}
func TestNotifyCreatedUser(t *testing.T) {
@@ -1695,7 +1832,7 @@ func TestGetUsersFilter(t *testing.T) {
setupCtx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
coderdtest.UsersFilter(setupCtx, t, client, api.Database, nil, nil, func(testCtx context.Context, req codersdk.UsersRequest) []codersdk.ReducedUser {
coderdtest.UsersFilter(setupCtx, t, client, api.Database, nil, func(testCtx context.Context, req codersdk.UsersRequest) []codersdk.ReducedUser {
res, err := client.Users(testCtx, req)
require.NoError(t, err)
reduced := make([]codersdk.ReducedUser, len(res.Users))
+2 -3
View File
@@ -181,9 +181,8 @@ func (api *API) patchWorkspaceAgentLogs(rw http.ResponseWriter, r *http.Request)
level := make([]database.LogLevel, 0)
outputLength := 0
for _, logEntry := range req.Logs {
sanitizedOutput := agentsdk.SanitizeLogOutput(logEntry.Output)
output = append(output, sanitizedOutput)
outputLength += len(sanitizedOutput)
output = append(output, logEntry.Output)
outputLength += len(logEntry.Output)
if logEntry.Level == "" {
// Default to "info" to support older agents that didn't have the level field.
logEntry.Level = codersdk.LogLevelInfo
-44
View File
@@ -260,50 +260,6 @@ func TestWorkspaceAgentLogs(t *testing.T) {
require.Equal(t, "testing", logChunk[0].Output)
require.Equal(t, "testing2", logChunk[1].Output)
})
t.Run("SanitizesNulBytesAndTracksSanitizedLength", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
client, db := coderdtest.NewWithDatabase(t, nil)
user := coderdtest.CreateFirstUser(t, client)
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: user.OrganizationID,
OwnerID: user.UserID,
}).WithAgent().Do()
rawOutput := "before\x00after"
sanitizedOutput := agentsdk.SanitizeLogOutput(rawOutput)
agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken))
err := agentClient.PatchLogs(ctx, agentsdk.PatchLogs{
Logs: []agentsdk.Log{
{
CreatedAt: dbtime.Now(),
Output: rawOutput,
},
},
})
require.NoError(t, err)
agent, err := db.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), r.Agents[0].ID)
require.NoError(t, err)
require.EqualValues(t, len(sanitizedOutput), agent.LogsLength)
workspace, err := client.Workspace(ctx, r.Workspace.ID)
require.NoError(t, err)
logs, closer, err := client.WorkspaceAgentLogsAfter(ctx, workspace.LatestBuild.Resources[0].Agents[0].ID, 0, true)
require.NoError(t, err)
defer func() {
_ = closer.Close()
}()
var logChunk []codersdk.WorkspaceAgentLog
select {
case <-ctx.Done():
case logChunk = <-logs:
}
require.NoError(t, ctx.Err())
require.Len(t, logChunk, 1)
require.Equal(t, sanitizedOutput, logChunk[0].Output)
})
t.Run("Close logs on outdated build", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
+25 -585
View File
@@ -788,7 +788,6 @@ type CreateOptions struct {
InitialUserContent []codersdk.ChatMessagePart
MCPServerIDs []uuid.UUID
Labels database.StringMap
DynamicTools json.RawMessage
}
// SendMessageBusyBehavior controls what happens when a chat is already active.
@@ -900,10 +899,6 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C
RawMessage: labelsJSON,
Valid: true,
},
DynamicTools: pqtype.NullRawMessage{
RawMessage: opts.DynamicTools,
Valid: len(opts.DynamicTools) > 0,
},
})
if err != nil {
return xerrors.Errorf("insert chat: %w", err)
@@ -1551,238 +1546,6 @@ func (p *Server) PromoteQueued(
return result, nil
}
// SubmitToolResultsOptions controls tool result submission.
type SubmitToolResultsOptions struct {
ChatID uuid.UUID
UserID uuid.UUID
ModelConfigID uuid.UUID
Results []codersdk.ToolResult
DynamicTools json.RawMessage
}
// ToolResultValidationError indicates the submitted tool results
// failed validation (e.g. missing, duplicate, or unexpected IDs,
// or invalid JSON output).
type ToolResultValidationError struct {
Message string
Detail string
}
func (e *ToolResultValidationError) Error() string {
if e.Detail != "" {
return e.Message + ": " + e.Detail
}
return e.Message
}
// ToolResultStatusConflictError indicates the chat is not in the
// requires_action state expected for tool result submission.
type ToolResultStatusConflictError struct {
ActualStatus database.ChatStatus
}
func (e *ToolResultStatusConflictError) Error() string {
return fmt.Sprintf(
"chat status is %q, expected %q",
e.ActualStatus, database.ChatStatusRequiresAction,
)
}
// SubmitToolResults validates and persists client-provided tool
// results, transitions the chat to pending, and wakes the run
// loop. The caller is responsible for the fast-path status check;
// this method performs an authoritative re-check under a row lock.
func (p *Server) SubmitToolResults(
ctx context.Context,
opts SubmitToolResultsOptions,
) error {
dynamicToolNames, err := parseDynamicToolNames(pqtype.NullRawMessage{
RawMessage: opts.DynamicTools,
Valid: len(opts.DynamicTools) > 0,
})
if err != nil {
return xerrors.Errorf("parse chat dynamic tools: %w", err)
}
// The GetLastChatMessageByRole lookup and all subsequent
// validation and persistence run inside a single transaction
// so the assistant message cannot change between reads.
var statusConflict *ToolResultStatusConflictError
txErr := p.db.InTx(func(tx database.Store) error {
// Authoritative status check under row lock.
locked, lockErr := tx.GetChatByIDForUpdate(ctx, opts.ChatID)
if lockErr != nil {
return xerrors.Errorf("lock chat for update: %w", lockErr)
}
if locked.Status != database.ChatStatusRequiresAction {
statusConflict = &ToolResultStatusConflictError{
ActualStatus: locked.Status,
}
return statusConflict
}
// Get the last assistant message inside the transaction
// for consistency with the row lock above.
lastAssistant, err := tx.GetLastChatMessageByRole(ctx, database.GetLastChatMessageByRoleParams{
ChatID: opts.ChatID,
Role: database.ChatMessageRoleAssistant,
})
if err != nil {
return xerrors.Errorf("get last assistant message: %w", err)
}
// Collect tool-call IDs that already have results.
// When a dynamic tool name collides with a built-in,
// the chatloop executes it as a built-in and persists
// the result. Those calls must not count as pending.
afterMsgs, afterErr := tx.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: opts.ChatID,
AfterID: lastAssistant.ID,
})
if afterErr != nil {
return xerrors.Errorf("get messages after assistant: %w", afterErr)
}
handledCallIDs := make(map[string]bool)
for _, msg := range afterMsgs {
if msg.Role != database.ChatMessageRoleTool {
continue
}
msgParts, msgParseErr := chatprompt.ParseContent(msg)
if msgParseErr != nil {
continue
}
for _, mp := range msgParts {
if mp.Type == codersdk.ChatMessagePartTypeToolResult {
handledCallIDs[mp.ToolCallID] = true
}
}
}
// Extract pending dynamic tool-call IDs, skipping any
// that were already handled by the chatloop.
pendingCallIDs := make(map[string]bool)
toolCallIDToName := make(map[string]string)
parts, parseErr := chatprompt.ParseContent(lastAssistant)
if parseErr != nil {
return xerrors.Errorf("parse assistant message: %w", parseErr)
}
for _, part := range parts {
if part.Type == codersdk.ChatMessagePartTypeToolCall &&
dynamicToolNames[part.ToolName] &&
!handledCallIDs[part.ToolCallID] {
pendingCallIDs[part.ToolCallID] = true
toolCallIDToName[part.ToolCallID] = part.ToolName
}
}
// Validate submitted results match pending calls exactly.
submittedIDs := make(map[string]bool, len(opts.Results))
for _, result := range opts.Results {
if submittedIDs[result.ToolCallID] {
return &ToolResultValidationError{
Message: "Duplicate tool_call_id in results.",
Detail: fmt.Sprintf("Duplicate tool call ID %q.", result.ToolCallID),
}
}
submittedIDs[result.ToolCallID] = true
}
for id := range pendingCallIDs {
if !submittedIDs[id] {
return &ToolResultValidationError{
Message: "Missing tool result.",
Detail: fmt.Sprintf("Missing result for tool call %q.", id),
}
}
}
for id := range submittedIDs {
if !pendingCallIDs[id] {
return &ToolResultValidationError{
Message: "Unexpected tool result.",
Detail: fmt.Sprintf("No pending tool call with ID %q.", id),
}
}
}
// Marshal each tool result into a separate message row.
resultContents := make([]pqtype.NullRawMessage, 0, len(opts.Results))
for _, result := range opts.Results {
if !json.Valid(result.Output) {
return &ToolResultValidationError{
Message: "Tool result output must be valid JSON.",
Detail: fmt.Sprintf("Output for tool call %q is not valid JSON.", result.ToolCallID),
}
}
part := codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeToolResult,
ToolCallID: result.ToolCallID,
ToolName: toolCallIDToName[result.ToolCallID],
Result: result.Output,
IsError: result.IsError,
}
marshaled, marshalErr := chatprompt.MarshalParts([]codersdk.ChatMessagePart{part})
if marshalErr != nil {
return xerrors.Errorf("marshal tool result: %w", marshalErr)
}
resultContents = append(resultContents, marshaled)
}
// Insert tool-result messages.
n := len(resultContents)
params := database.InsertChatMessagesParams{
ChatID: opts.ChatID,
CreatedBy: make([]uuid.UUID, n),
ModelConfigID: make([]uuid.UUID, n),
Role: make([]database.ChatMessageRole, n),
Content: make([]string, n),
ContentVersion: make([]int16, n),
Visibility: make([]database.ChatMessageVisibility, n),
InputTokens: make([]int64, n),
OutputTokens: make([]int64, n),
TotalTokens: make([]int64, n),
ReasoningTokens: make([]int64, n),
CacheCreationTokens: make([]int64, n),
CacheReadTokens: make([]int64, n),
ContextLimit: make([]int64, n),
Compressed: make([]bool, n),
TotalCostMicros: make([]int64, n),
RuntimeMs: make([]int64, n),
ProviderResponseID: make([]string, n),
}
for i, rc := range resultContents {
params.CreatedBy[i] = opts.UserID
params.ModelConfigID[i] = opts.ModelConfigID
params.Role[i] = database.ChatMessageRoleTool
params.Content[i] = string(rc.RawMessage)
params.ContentVersion[i] = chatprompt.CurrentContentVersion
params.Visibility[i] = database.ChatMessageVisibilityBoth
}
if _, insertErr := tx.InsertChatMessages(ctx, params); insertErr != nil {
return xerrors.Errorf("insert tool results: %w", insertErr)
}
// Transition chat to pending.
if _, updateErr := tx.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
ID: opts.ChatID,
Status: database.ChatStatusPending,
WorkerID: uuid.NullUUID{},
StartedAt: sql.NullTime{},
HeartbeatAt: sql.NullTime{},
LastError: sql.NullString{},
}); updateErr != nil {
return xerrors.Errorf("update chat status: %w", updateErr)
}
return nil
}, nil)
if txErr != nil {
return txErr
}
// Wake the chatd run loop so it processes the chat immediately.
p.signalWake()
return nil
}
// InterruptChat interrupts execution, sets waiting status, and broadcasts status updates.
func (p *Server) InterruptChat(
ctx context.Context,
@@ -1792,32 +1555,6 @@ func (p *Server) InterruptChat(
return chat
}
// If the chat is in requires_action, insert synthetic error
// tool-result messages for each pending dynamic tool call
// before transitioning to waiting. Without this, the LLM
// would see unmatched tool-call parts on the next run.
if chat.Status == database.ChatStatusRequiresAction {
if txErr := p.db.InTx(func(tx database.Store) error {
locked, lockErr := tx.GetChatByIDForUpdate(ctx, chat.ID)
if lockErr != nil {
return xerrors.Errorf("lock chat for interrupt: %w", lockErr)
}
// Another request may have already transitioned
// the chat (e.g. SubmitToolResults committed
// between our snapshot and this lock).
if locked.Status != database.ChatStatusRequiresAction {
return nil
}
return insertSyntheticToolResultsTx(ctx, tx, locked, "Tool execution interrupted by user")
}, nil); txErr != nil {
p.logger.Error(ctx, "failed to insert synthetic tool results during interrupt",
slog.F("chat_id", chat.ID),
slog.Error(txErr),
)
// Fall through — still try to set waiting status.
}
}
updatedChat, err := p.setChatWaiting(ctx, chat.ID)
if err != nil {
p.logger.Error(ctx, "failed to mark chat as waiting",
@@ -2608,7 +2345,7 @@ func insertUserMessageAndSetPending(
// queued while a chat is active.
func shouldQueueUserMessage(status database.ChatStatus) bool {
switch status {
case database.ChatStatusRunning, database.ChatStatusPending, database.ChatStatusRequiresAction:
case database.ChatStatusRunning, database.ChatStatusPending:
return true
default:
return false
@@ -3481,12 +3218,8 @@ func (p *Server) Subscribe(
// Pubsub will deliver a duplicate status
// later; the frontend deduplicates it
// (setChatStatus is idempotent).
// action_required is also transient and
// only published on the local stream, so
// it must be forwarded here.
if event.Type == codersdk.ChatStreamEventTypeMessagePart ||
event.Type == codersdk.ChatStreamEventTypeStatus ||
event.Type == codersdk.ChatStreamEventTypeActionRequired {
event.Type == codersdk.ChatStreamEventTypeStatus {
select {
case <-mergedCtx.Done():
return
@@ -3612,51 +3345,6 @@ func (p *Server) publishChatPubsubEvent(chat database.Chat, kind coderdpubsub.Ch
}
}
// pendingToStreamToolCalls converts a slice of chatloop pending
// tool calls into the SDK streaming representation.
func pendingToStreamToolCalls(pending []chatloop.PendingToolCall) []codersdk.ChatStreamToolCall {
calls := make([]codersdk.ChatStreamToolCall, len(pending))
for i, tc := range pending {
calls[i] = codersdk.ChatStreamToolCall{
ToolCallID: tc.ToolCallID,
ToolName: tc.ToolName,
Args: tc.Args,
}
}
return calls
}
// publishChatActionRequired broadcasts an action_required event via
// PostgreSQL pubsub so that global watchers can react to dynamic
// tool calls without streaming each chat individually.
func (p *Server) publishChatActionRequired(chat database.Chat, pending []chatloop.PendingToolCall) {
if p.pubsub == nil {
return
}
toolCalls := pendingToStreamToolCalls(pending)
sdkChat := db2sdk.Chat(chat, nil, nil)
event := coderdpubsub.ChatEvent{
Kind: coderdpubsub.ChatEventKindActionRequired,
Chat: sdkChat,
ToolCalls: toolCalls,
}
payload, err := json.Marshal(event)
if err != nil {
p.logger.Error(context.Background(), "failed to marshal chat action_required pubsub event",
slog.F("chat_id", chat.ID),
slog.Error(err),
)
return
}
if err := p.pubsub.Publish(coderdpubsub.ChatEventChannel(chat.OwnerID), payload); err != nil {
p.logger.Error(context.Background(), "failed to publish chat action_required pubsub event",
slog.F("chat_id", chat.ID),
slog.Error(err),
)
}
}
// PublishDiffStatusChange broadcasts a diff_status_change event for
// the given chat so that watching clients know to re-fetch the diff
// status. This is called from the HTTP layer after the diff status
@@ -4161,21 +3849,6 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) {
}
p.publishChatPubsubEvent(updatedChat, coderdpubsub.ChatEventKindStatusChange, nil)
// When the chat is parked in requires_action,
// publish the stream event and global pubsub event
// after the DB status has committed. Publishing
// here (not in runChat) prevents a race where a
// fast client reacts before the status is visible.
if status == database.ChatStatusRequiresAction && len(runResult.PendingDynamicToolCalls) > 0 {
toolCalls := pendingToStreamToolCalls(runResult.PendingDynamicToolCalls)
p.publishEvent(chat.ID, codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeActionRequired,
ActionRequired: &codersdk.ChatStreamActionRequired{
ToolCalls: toolCalls,
},
})
p.publishChatActionRequired(updatedChat, runResult.PendingDynamicToolCalls)
}
if !wasInterrupted {
p.maybeSendPushNotification(cleanupCtx, updatedChat, status, lastError, runResult, logger)
}
@@ -4204,13 +3877,6 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) {
return
}
// The LLM invoked a dynamic tool — park the chat in
// requires_action so the client can supply tool results.
if len(runResult.PendingDynamicToolCalls) > 0 {
status = database.ChatStatusRequiresAction
return
}
// If runChat completed successfully but the server context was
// canceled (e.g. during Close()), the chat should be returned
// to pending so another replica can pick it up. There is a
@@ -4277,10 +3943,9 @@ func (t *generatedChatTitle) Load() (string, bool) {
}
type runChatResult struct {
FinalAssistantText string
PushSummaryModel fantasy.LanguageModel
ProviderKeys chatprovider.ProviderAPIKeys
PendingDynamicToolCalls []chatloop.PendingToolCall
FinalAssistantText string
PushSummaryModel fantasy.LanguageModel
ProviderKeys chatprovider.ProviderAPIKeys
}
func (p *Server) runChat(
@@ -4584,8 +4249,8 @@ func (p *Server) runChat(
// server.
toolNameToConfigID := make(map[string]uuid.UUID)
for _, t := range mcpTools {
if mcpTool, ok := t.(mcpclient.MCPToolIdentifier); ok {
toolNameToConfigID[t.Info().Name] = mcpTool.MCPServerConfigID()
if mcp, ok := t.(mcpclient.MCPToolIdentifier); ok {
toolNameToConfigID[t.Info().Name] = mcp.MCPServerConfigID()
}
}
@@ -4604,7 +4269,6 @@ func (p *Server) runChat(
// (which is the common case).
modelConfigContextLimit := modelConfig.ContextLimit
var finalAssistantText string
var pendingDynamicCalls []chatloop.PendingToolCall
persistStep := func(persistCtx context.Context, step chatloop.PersistedStep) error {
// If the chat context has been canceled, bail out before
@@ -4624,10 +4288,6 @@ func (p *Server) runChat(
return persistCtx.Err()
}
// Capture pending dynamic tool calls so the caller
// can surface them after chatloop.Run returns.
pendingDynamicCalls = step.PendingDynamicToolCalls
// Split the step content into assistant blocks and tool
// result blocks so they can be stored as separate messages
// with the appropriate roles. Provider-executed tool results
@@ -4665,21 +4325,6 @@ func (p *Server) runChat(
part.MCPServerConfigID = uuid.NullUUID{UUID: configID, Valid: true}
}
}
// Apply recorded timestamps so persisted
// tool-call parts carry accurate CreatedAt.
if part.Type == codersdk.ChatMessagePartTypeToolCall && part.ToolCallID != "" && step.ToolCallCreatedAt != nil {
if ts, ok := step.ToolCallCreatedAt[part.ToolCallID]; ok {
part.CreatedAt = &ts
}
}
// Provider-executed tool results appear in
// assistantBlocks rather than toolResults,
// so apply their timestamps here as well.
if part.Type == codersdk.ChatMessagePartTypeToolResult && part.ToolCallID != "" && step.ToolResultCreatedAt != nil {
if ts, ok := step.ToolResultCreatedAt[part.ToolCallID]; ok {
part.CreatedAt = &ts
}
}
sdkParts = append(sdkParts, part)
}
finalAssistantText = strings.TrimSpace(contentBlocksToText(sdkParts))
@@ -4698,13 +4343,6 @@ func (p *Server) runChat(
trPart.MCPServerConfigID = uuid.NullUUID{UUID: configID, Valid: true}
}
}
// Apply recorded timestamps so persisted
// tool-result parts carry accurate CreatedAt.
if trPart.ToolCallID != "" && step.ToolResultCreatedAt != nil {
if ts, ok := step.ToolResultCreatedAt[trPart.ToolCallID]; ok {
trPart.CreatedAt = &ts
}
}
var marshalErr error
toolResultContents[i], marshalErr = chatprompt.MarshalParts([]codersdk.ChatMessagePart{trPart})
if marshalErr != nil {
@@ -5036,39 +4674,6 @@ func (p *Server) runChat(
tools = append(tools, mcpTools...)
tools = append(tools, workspaceMCPTools...)
// Append dynamic tools declared by the client at chat
// creation time. These appear in the LLM's tool list but
// are never executed by the chatloop — the client handles
// execution via POST /tool-results.
dynamicToolNames, err := parseDynamicToolNames(chat.DynamicTools)
if err != nil {
return result, xerrors.Errorf("parse dynamic tool names: %w", err)
}
// Unmarshal the full definitions separately so we can
// build the filtered list below. parseDynamicToolNames
// already validated the JSON, so this cannot fail.
var dynamicToolDefs []codersdk.DynamicTool
if chat.DynamicTools.Valid {
if err := json.Unmarshal(chat.DynamicTools.RawMessage, &dynamicToolDefs); err != nil {
return result, xerrors.Errorf("unmarshal dynamic tools: %w", err)
}
}
for _, t := range tools {
info := t.Info()
if dynamicToolNames[info.Name] {
logger.Warn(ctx, "dynamic tool name collides with built-in tool, built-in takes precedence",
slog.F("tool_name", info.Name))
delete(dynamicToolNames, info.Name)
}
}
var filteredDefs []codersdk.DynamicTool
for _, dt := range dynamicToolDefs {
if dynamicToolNames[dt.Name] {
filteredDefs = append(filteredDefs, dt)
}
}
tools = append(tools, dynamicToolsFromSDK(p.logger, filteredDefs)...)
// Build provider-native tools (e.g., web search) based on
// the model configuration.
var providerTools []chatloop.ProviderTool
@@ -5112,7 +4717,8 @@ func (p *Server) runChat(
)
prompt = filterPromptForChainMode(prompt, chainInfo.trailingUserCount)
}
err = chatloop.Run(ctx, chatloop.RunOptions{
err := chatloop.Run(ctx, chatloop.RunOptions{
Model: model,
Messages: prompt,
Tools: tools, MaxSteps: maxChatSteps,
@@ -5120,9 +4726,6 @@ func (p *Server) runChat(
ModelConfig: callConfig,
ProviderOptions: providerOptions,
ProviderTools: providerTools,
// dynamicToolNames now contains only names that don't
// collide with built-in/MCP tools.
DynamicToolNames: dynamicToolNames,
ContextLimitFallback: modelConfigContextLimit,
@@ -5200,15 +4803,6 @@ func (p *Server) runChat(
p.logger.Warn(ctx, "failed to persist interrupted chat step", slog.Error(err))
},
})
if errors.Is(err, chatloop.ErrDynamicToolCall) {
// The stream event is published in processChat's
// defer after the DB status transitions to
// requires_action, preventing a race where a fast
// client reacts before the status is committed.
result.FinalAssistantText = finalAssistantText
result.PendingDynamicToolCalls = pendingDynamicCalls
return result, nil
}
if err != nil {
classified := chaterror.Classify(err).WithProvider(model.Provider())
return result, chaterror.WithClassification(err, classified)
@@ -5830,9 +5424,7 @@ func (p *Server) recoverStaleChats(ctx context.Context) {
recovered := 0
for _, chat := range staleChats {
p.logger.Info(ctx, "recovering stale chat",
slog.F("chat_id", chat.ID),
slog.F("status", chat.Status))
p.logger.Info(ctx, "recovering stale chat", slog.F("chat_id", chat.ID))
// Use a transaction with FOR UPDATE to avoid a TOCTOU race:
// between GetStaleChats (a bare SELECT) and here, the chat's
@@ -5844,73 +5436,34 @@ func (p *Server) recoverStaleChats(ctx context.Context) {
return xerrors.Errorf("lock chat for recovery: %w", lockErr)
}
switch locked.Status {
case database.ChatStatusRunning:
// Re-check: only recover if the chat is still stale.
// A valid heartbeat at or after the threshold means
// the chat was refreshed after our snapshot.
if locked.HeartbeatAt.Valid && !locked.HeartbeatAt.Time.Before(staleAfter) {
p.logger.Debug(ctx, "chat heartbeat refreshed since snapshot, skipping recovery",
slog.F("chat_id", chat.ID))
return nil
}
case database.ChatStatusRequiresAction:
// Re-check: the chat may have been updated after
// our snapshot, similar to the heartbeat check for
// running chats.
if !locked.UpdatedAt.Before(staleAfter) {
p.logger.Debug(ctx, "chat updated since snapshot, skipping recovery",
slog.F("chat_id", chat.ID))
return nil
}
default:
// Status changed since our snapshot; skip.
// Only recover chats that are still running.
// Between GetStaleChats and this lock, the chat
// may have completed normally.
if locked.Status != database.ChatStatusRunning {
p.logger.Debug(ctx, "chat status changed since snapshot, skipping recovery",
slog.F("chat_id", chat.ID),
slog.F("status", locked.Status))
return nil
}
lastError := sql.NullString{}
if locked.Status == database.ChatStatusRequiresAction {
lastError = sql.NullString{
String: "Dynamic tool execution timed out",
Valid: true,
}
// Re-check: only recover if the chat is still stale.
// A valid heartbeat that is at or after the stale
// threshold means the chat was refreshed after our
// initial snapshot — skip it.
if locked.HeartbeatAt.Valid && !locked.HeartbeatAt.Time.Before(staleAfter) {
p.logger.Debug(ctx, "chat heartbeat refreshed since snapshot, skipping recovery",
slog.F("chat_id", chat.ID))
return nil
}
recoverStatus := database.ChatStatusPending
if locked.Status == database.ChatStatusRequiresAction {
// Timed-out requires_action chats have dangling
// tool calls with no matching results. Setting
// them back to pending would replay incomplete
// tool calls to the LLM, so mark them as errors.
recoverStatus = database.ChatStatusError
}
// Insert synthetic error tool-result messages
// so the LLM history remains valid if the user
// retries the chat later.
if locked.Status == database.ChatStatusRequiresAction {
if synthErr := insertSyntheticToolResultsTx(ctx, tx, locked, "Dynamic tool execution timed out"); synthErr != nil {
p.logger.Warn(ctx, "failed to insert synthetic tool results during stale recovery",
slog.F("chat_id", chat.ID),
slog.Error(synthErr),
)
// Continue with error status even if
// synthetic results fail to insert.
}
}
// Reset so any replica can pick it up (pending) or
// the client sees the failure (error).
// Reset to pending so any replica can pick it up.
_, updateErr := tx.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
ID: chat.ID,
Status: recoverStatus,
Status: database.ChatStatusPending,
WorkerID: uuid.NullUUID{},
StartedAt: sql.NullTime{},
HeartbeatAt: sql.NullTime{},
LastError: lastError,
LastError: sql.NullString{},
})
if updateErr != nil {
return updateErr
@@ -5929,119 +5482,6 @@ func (p *Server) recoverStaleChats(ctx context.Context) {
}
}
// insertSyntheticToolResultsTx inserts error tool-result messages for
// every pending dynamic tool call in the last assistant message. This
// keeps the LLM message history valid (every tool-call has a matching
// tool-result) when a requires_action chat times out or is interrupted.
// It operates on the provided store, which may be a transaction handle.
func insertSyntheticToolResultsTx(
ctx context.Context,
store database.Store,
chat database.Chat,
reason string,
) error {
dynamicToolNames, err := parseDynamicToolNames(chat.DynamicTools)
if err != nil {
return xerrors.Errorf("parse dynamic tools: %w", err)
}
if len(dynamicToolNames) == 0 {
return nil
}
// Get the last assistant message to find pending tool calls.
lastAssistant, err := store.GetLastChatMessageByRole(ctx, database.GetLastChatMessageByRoleParams{
ChatID: chat.ID,
Role: database.ChatMessageRoleAssistant,
})
if err != nil {
return xerrors.Errorf("get last assistant message: %w", err)
}
parts, err := chatprompt.ParseContent(lastAssistant)
if err != nil {
return xerrors.Errorf("parse assistant message: %w", err)
}
// Collect dynamic tool calls that need synthetic results.
var resultContents []pqtype.NullRawMessage
for _, part := range parts {
if part.Type != codersdk.ChatMessagePartTypeToolCall || !dynamicToolNames[part.ToolName] {
continue
}
resultPart := codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeToolResult,
ToolCallID: part.ToolCallID,
ToolName: part.ToolName,
Result: json.RawMessage(fmt.Sprintf("%q", reason)),
IsError: true,
}
marshaled, marshalErr := chatprompt.MarshalParts([]codersdk.ChatMessagePart{resultPart})
if marshalErr != nil {
return xerrors.Errorf("marshal synthetic tool result: %w", marshalErr)
}
resultContents = append(resultContents, marshaled)
}
if len(resultContents) == 0 {
return nil
}
// Insert tool-result messages using the same pattern as
// SubmitToolResults.
n := len(resultContents)
params := database.InsertChatMessagesParams{
ChatID: chat.ID,
CreatedBy: make([]uuid.UUID, n),
ModelConfigID: make([]uuid.UUID, n),
Role: make([]database.ChatMessageRole, n),
Content: make([]string, n),
ContentVersion: make([]int16, n),
Visibility: make([]database.ChatMessageVisibility, n),
InputTokens: make([]int64, n),
OutputTokens: make([]int64, n),
TotalTokens: make([]int64, n),
ReasoningTokens: make([]int64, n),
CacheCreationTokens: make([]int64, n),
CacheReadTokens: make([]int64, n),
ContextLimit: make([]int64, n),
Compressed: make([]bool, n),
TotalCostMicros: make([]int64, n),
RuntimeMs: make([]int64, n),
ProviderResponseID: make([]string, n),
}
for i, rc := range resultContents {
params.CreatedBy[i] = uuid.Nil
params.ModelConfigID[i] = chat.LastModelConfigID
params.Role[i] = database.ChatMessageRoleTool
params.Content[i] = string(rc.RawMessage)
params.ContentVersion[i] = chatprompt.CurrentContentVersion
params.Visibility[i] = database.ChatMessageVisibilityBoth
}
if _, err := store.InsertChatMessages(ctx, params); err != nil {
return xerrors.Errorf("insert synthetic tool results: %w", err)
}
return nil
}
// parseDynamicToolNames unmarshals the dynamic tools JSON column
// and returns a map of tool names. This centralizes the repeated
// pattern of deserializing DynamicTools into a name set.
func parseDynamicToolNames(raw pqtype.NullRawMessage) (map[string]bool, error) {
if !raw.Valid || len(raw.RawMessage) == 0 {
return make(map[string]bool), nil
}
var tools []codersdk.DynamicTool
if err := json.Unmarshal(raw.RawMessage, &tools); err != nil {
return nil, xerrors.Errorf("unmarshal dynamic tools: %w", err)
}
names := make(map[string]bool, len(tools))
for _, t := range tools {
names[t.Name] = true
}
return names, nil
}
// maybeSendPushNotification sends a web push notification when an
// agent chat reaches a terminal state. For errors it dispatches
// synchronously; for successful completions it spawns a goroutine
-576
View File
@@ -1531,70 +1531,6 @@ func TestRecoverStaleChatsPeriodically(t *testing.T) {
}, testutil.WaitMedium, testutil.IntervalFast)
}
func TestRecoverStaleRequiresActionChat(t *testing.T) {
t.Parallel()
db, ps, rawDB := dbtestutil.NewDBWithSQLDB(t)
ctx := testutil.Context(t, testutil.WaitLong)
user, model := seedChatDependencies(ctx, t, db)
// Use a very short stale threshold so the periodic recovery
// kicks in quickly during the test.
staleAfter := 500 * time.Millisecond
// Create a chat and set it to requires_action to simulate a
// client that disappeared while the chat was waiting for
// dynamic tool results.
chat, err := db.InsertChat(ctx, database.InsertChatParams{
Status: database.ChatStatusWaiting,
OwnerID: user.ID,
Title: "stale-requires-action",
LastModelConfigID: model.ID,
})
require.NoError(t, err)
_, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
ID: chat.ID,
Status: database.ChatStatusRequiresAction,
})
require.NoError(t, err)
// Backdate updated_at so the chat appears stale to the
// recovery loop without needing time.Sleep.
_, err = rawDB.ExecContext(ctx,
"UPDATE chats SET updated_at = $1 WHERE id = $2",
time.Now().Add(-time.Hour), chat.ID)
require.NoError(t, err)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
server := chatd.New(chatd.Config{
Logger: logger,
Database: db,
ReplicaID: uuid.New(),
Pubsub: ps,
PendingChatAcquireInterval: testutil.WaitLong,
InFlightChatStaleAfter: staleAfter,
})
t.Cleanup(func() {
require.NoError(t, server.Close())
})
// The stale recovery should transition the requires_action
// chat to error with the timeout message.
var chatResult database.Chat
require.Eventually(t, func() bool {
chatResult, err = db.GetChatByID(ctx, chat.ID)
if err != nil {
return false
}
return chatResult.Status == database.ChatStatusError
}, testutil.WaitMedium, testutil.IntervalFast)
require.Contains(t, chatResult.LastError.String, "Dynamic tool execution timed out")
require.False(t, chatResult.WorkerID.Valid)
}
func TestNewReplicaRecoversStaleChatFromDeadReplica(t *testing.T) {
t.Parallel()
@@ -1946,518 +1882,6 @@ func TestPersistToolResultWithBinaryData(t *testing.T) {
require.True(t, foundToolResultInSecondCall, "expected second streamed model call to include execute tool output")
}
func TestDynamicToolCallPausesAndResumes(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitLong)
// Track streaming calls to the mock LLM.
var streamedCallCount atomic.Int32
var streamedCallsMu sync.Mutex
streamedCalls := make([]chattest.OpenAIRequest, 0, 2)
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
// Non-streaming requests are title generation — return a
// simple title.
if !req.Stream {
return chattest.OpenAINonStreamingResponse("Dynamic tool test")
}
// Capture the full request for later assertions.
streamedCallsMu.Lock()
streamedCalls = append(streamedCalls, chattest.OpenAIRequest{
Messages: append([]chattest.OpenAIMessage(nil), req.Messages...),
Tools: append([]chattest.OpenAITool(nil), req.Tools...),
Stream: req.Stream,
})
streamedCallsMu.Unlock()
if streamedCallCount.Add(1) == 1 {
// First call: the LLM invokes our dynamic tool.
return chattest.OpenAIStreamingResponse(
chattest.OpenAIToolCallChunk(
"my_dynamic_tool",
`{"input":"hello world"}`,
),
)
}
// Second call: the LLM returns a normal text response.
return chattest.OpenAIStreamingResponse(
chattest.OpenAITextChunks("Dynamic tool result received.")...,
)
})
user, model := seedChatDependenciesWithProvider(ctx, t, db, "openai-compat", openAIURL)
// Dynamic tools do not need a workspace connection, but the
// chatd server always builds workspace tools. Use an active
// server without an agent connection — the built-in tools
// are never invoked because the only tool call targets our
// dynamic tool.
server := newActiveTestServer(t, db, ps)
// Create a chat with a dynamic tool.
dynamicToolsJSON, err := json.Marshal([]mcpgo.Tool{{
Name: "my_dynamic_tool",
Description: "A test dynamic tool.",
InputSchema: mcpgo.ToolInputSchema{
Type: "object",
Properties: map[string]any{
"input": map[string]any{"type": "string"},
},
Required: []string{"input"},
},
}})
require.NoError(t, err)
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
OwnerID: user.ID,
Title: "dynamic-tool-pause-resume",
ModelConfigID: model.ID,
InitialUserContent: []codersdk.ChatMessagePart{
codersdk.ChatMessageText("Please call the dynamic tool."),
},
DynamicTools: dynamicToolsJSON,
})
require.NoError(t, err)
// 1. Wait for the chat to reach requires_action status.
var chatResult database.Chat
require.Eventually(t, func() bool {
got, getErr := db.GetChatByID(ctx, chat.ID)
if getErr != nil {
return false
}
chatResult = got
return got.Status == database.ChatStatusRequiresAction ||
got.Status == database.ChatStatusError
}, testutil.WaitLong, testutil.IntervalFast)
require.Equal(t, database.ChatStatusRequiresAction, chatResult.Status,
"expected requires_action, got %s (last_error=%q)",
chatResult.Status, chatResult.LastError.String)
// 2. Read the assistant message to find the tool-call ID.
var toolCallID string
var toolCallFound bool
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
messages, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: chat.ID,
AfterID: 0,
})
if dbErr != nil {
return false
}
for _, msg := range messages {
if msg.Role != database.ChatMessageRoleAssistant {
continue
}
parts, parseErr := chatprompt.ParseContent(msg)
if parseErr != nil {
continue
}
for _, part := range parts {
if part.Type == codersdk.ChatMessagePartTypeToolCall && part.ToolName == "my_dynamic_tool" {
toolCallID = part.ToolCallID
toolCallFound = true
return true
}
}
}
return false
}, testutil.IntervalFast)
require.True(t, toolCallFound, "expected to find tool call for my_dynamic_tool")
require.NotEmpty(t, toolCallID)
// 3. Submit tool results via SubmitToolResults.
toolResultOutput := json.RawMessage(`{"result":"dynamic tool output"}`)
err = server.SubmitToolResults(ctx, chatd.SubmitToolResultsOptions{
ChatID: chat.ID,
UserID: user.ID,
ModelConfigID: chatResult.LastModelConfigID,
Results: []codersdk.ToolResult{{
ToolCallID: toolCallID,
Output: toolResultOutput,
}},
DynamicTools: dynamicToolsJSON,
})
require.NoError(t, err)
// 4. Wait for the chat to reach a terminal status.
require.Eventually(t, func() bool {
got, getErr := db.GetChatByID(ctx, chat.ID)
if getErr != nil {
return false
}
chatResult = got
return got.Status == database.ChatStatusWaiting || got.Status == database.ChatStatusError
}, testutil.WaitLong, testutil.IntervalFast)
// 5. Verify the chat completed successfully.
if chatResult.Status == database.ChatStatusError {
require.FailNowf(t, "chat run failed", "last_error=%q", chatResult.LastError.String)
}
// 6. Verify the mock received exactly 2 streaming calls.
require.Equal(t, int32(2), streamedCallCount.Load(),
"expected exactly 2 streaming calls to the LLM")
streamedCallsMu.Lock()
recordedCalls := append([]chattest.OpenAIRequest(nil), streamedCalls...)
streamedCallsMu.Unlock()
require.Len(t, recordedCalls, 2)
// 7. Verify the dynamic tool appeared in the first call's tool list.
var foundDynamicTool bool
for _, tool := range recordedCalls[0].Tools {
if tool.Function.Name == "my_dynamic_tool" {
foundDynamicTool = true
break
}
}
require.True(t, foundDynamicTool,
"expected 'my_dynamic_tool' in the first LLM call's tool list")
// 8. Verify the second call's messages contain the tool result.
var foundToolResultInSecondCall bool
for _, message := range recordedCalls[1].Messages {
if message.Role != "tool" {
continue
}
if strings.Contains(message.Content, "dynamic tool output") {
foundToolResultInSecondCall = true
break
}
}
require.True(t, foundToolResultInSecondCall,
"expected second LLM call to include the submitted dynamic tool result")
}
func TestDynamicToolCallMixedWithBuiltIn(t *testing.T) {
t.Parallel()
db, ps := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitLong)
// Track streaming calls to the mock LLM.
var streamedCallCount atomic.Int32
var streamedCallsMu sync.Mutex
streamedCalls := make([]chattest.OpenAIRequest, 0, 2)
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
if !req.Stream {
return chattest.OpenAINonStreamingResponse("Mixed tool test")
}
streamedCallsMu.Lock()
streamedCalls = append(streamedCalls, chattest.OpenAIRequest{
Messages: append([]chattest.OpenAIMessage(nil), req.Messages...),
Tools: append([]chattest.OpenAITool(nil), req.Tools...),
Stream: req.Stream,
})
streamedCallsMu.Unlock()
if streamedCallCount.Add(1) == 1 {
// First call: return TWO tool calls in one
// response — a built-in tool (read_file) and a
// dynamic tool (my_dynamic_tool).
builtinChunk := chattest.OpenAIToolCallChunk(
"read_file",
`{"path":"/tmp/test.txt"}`,
)
dynamicChunk := chattest.OpenAIToolCallChunk(
"my_dynamic_tool",
`{"input":"hello world"}`,
)
// Merge both tool calls into one chunk with
// separate indices so the LLM appears to have
// requested both tools simultaneously.
mergedChunk := builtinChunk
dynCall := dynamicChunk.Choices[0].ToolCalls[0]
dynCall.Index = 1
mergedChunk.Choices[0].ToolCalls = append(
mergedChunk.Choices[0].ToolCalls,
dynCall,
)
return chattest.OpenAIStreamingResponse(mergedChunk)
}
// Second call (after tool results): normal text
// response.
return chattest.OpenAIStreamingResponse(
chattest.OpenAITextChunks("All done.")...,
)
})
user, model := seedChatDependenciesWithProvider(ctx, t, db, "openai-compat", openAIURL)
server := newActiveTestServer(t, db, ps)
// Create a chat with a dynamic tool.
dynamicToolsJSON, err := json.Marshal([]mcpgo.Tool{{
Name: "my_dynamic_tool",
Description: "A test dynamic tool.",
InputSchema: mcpgo.ToolInputSchema{
Type: "object",
Properties: map[string]any{
"input": map[string]any{"type": "string"},
},
Required: []string{"input"},
},
}})
require.NoError(t, err)
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
OwnerID: user.ID,
Title: "mixed-builtin-dynamic",
ModelConfigID: model.ID,
InitialUserContent: []codersdk.ChatMessagePart{
codersdk.ChatMessageText("Call both tools."),
},
DynamicTools: dynamicToolsJSON,
})
require.NoError(t, err)
// 1. Wait for the chat to reach requires_action status.
var chatResult database.Chat
require.Eventually(t, func() bool {
got, getErr := db.GetChatByID(ctx, chat.ID)
if getErr != nil {
return false
}
chatResult = got
return got.Status == database.ChatStatusRequiresAction ||
got.Status == database.ChatStatusError
}, testutil.WaitLong, testutil.IntervalFast)
require.Equal(t, database.ChatStatusRequiresAction, chatResult.Status,
"expected requires_action, got %s (last_error=%q)",
chatResult.Status, chatResult.LastError.String)
// 2. Verify the built-in tool (read_file) was already
// executed by checking that a tool result message
// exists for it in the database.
var builtinToolResultFound bool
var toolCallID string
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
messages, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: chat.ID,
AfterID: 0,
})
if dbErr != nil {
return false
}
for _, msg := range messages {
parts, parseErr := chatprompt.ParseContent(msg)
if parseErr != nil {
continue
}
for _, part := range parts {
// Check for the built-in tool result.
if part.Type == codersdk.ChatMessagePartTypeToolResult && part.ToolName == "read_file" {
builtinToolResultFound = true
}
// Find the dynamic tool call ID.
if part.Type == codersdk.ChatMessagePartTypeToolCall && part.ToolName == "my_dynamic_tool" {
toolCallID = part.ToolCallID
}
}
}
return builtinToolResultFound && toolCallID != ""
}, testutil.IntervalFast)
require.True(t, builtinToolResultFound,
"expected read_file tool result in the DB before dynamic tool resolution")
require.NotEmpty(t, toolCallID)
// 3. Submit dynamic tool results.
err = server.SubmitToolResults(ctx, chatd.SubmitToolResultsOptions{
ChatID: chat.ID,
UserID: user.ID,
ModelConfigID: chatResult.LastModelConfigID,
Results: []codersdk.ToolResult{{
ToolCallID: toolCallID,
Output: json.RawMessage(`{"result":"dynamic output"}`),
}},
DynamicTools: dynamicToolsJSON,
})
require.NoError(t, err)
// 4. Wait for the chat to complete.
require.Eventually(t, func() bool {
got, getErr := db.GetChatByID(ctx, chat.ID)
if getErr != nil {
return false
}
chatResult = got
return got.Status == database.ChatStatusWaiting || got.Status == database.ChatStatusError
}, testutil.WaitLong, testutil.IntervalFast)
if chatResult.Status == database.ChatStatusError {
require.FailNowf(t, "chat run failed", "last_error=%q", chatResult.LastError.String)
}
// 5. Verify the LLM received exactly 2 streaming calls.
require.Equal(t, int32(2), streamedCallCount.Load(),
"expected exactly 2 streaming calls to the LLM")
}
// TestSubmitToolResultsConcurrency verifies that when many callers race to
// submit results for the same dynamic tool call, exactly one submission
// succeeds and every other caller receives a *chatd.ToolResultStatusConflictError.
// This pins the status-transition guard inside chatd's SubmitToolResults.
func TestSubmitToolResultsConcurrency(t *testing.T) {
	t.Parallel()
	db, ps := dbtestutil.NewDB(t)
	ctx := testutil.Context(t, testutil.WaitLong)
	// The mock LLM returns a dynamic tool call on the first streaming
	// request, then a plain text reply on the second.
	var streamedCallCount atomic.Int32
	openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
		if !req.Stream {
			// Non-streaming requests get a canned reply and do not
			// count toward streamedCallCount. NOTE(review): these are
			// presumably auxiliary calls (e.g. title generation) —
			// confirm against the chatd server implementation.
			return chattest.OpenAINonStreamingResponse("Concurrency test")
		}
		if streamedCallCount.Add(1) == 1 {
			return chattest.OpenAIStreamingResponse(
				chattest.OpenAIToolCallChunk(
					"my_dynamic_tool",
					`{"input":"hello"}`,
				),
			)
		}
		return chattest.OpenAIStreamingResponse(
			chattest.OpenAITextChunks("Done.")...,
		)
	})
	user, model := seedChatDependenciesWithProvider(ctx, t, db, "openai-compat", openAIURL)
	server := newActiveTestServer(t, db, ps)
	// Create a chat with a dynamic tool.
	dynamicToolsJSON, err := json.Marshal([]mcpgo.Tool{{
		Name:        "my_dynamic_tool",
		Description: "A test dynamic tool.",
		InputSchema: mcpgo.ToolInputSchema{
			Type: "object",
			Properties: map[string]any{
				"input": map[string]any{"type": "string"},
			},
			Required: []string{"input"},
		},
	}})
	require.NoError(t, err)
	chat, err := server.CreateChat(ctx, chatd.CreateOptions{
		OwnerID:       user.ID,
		Title:         "concurrency-tool-results",
		ModelConfigID: model.ID,
		InitialUserContent: []codersdk.ChatMessagePart{
			codersdk.ChatMessageText("Please call the dynamic tool."),
		},
		DynamicTools: dynamicToolsJSON,
	})
	require.NoError(t, err)
	// Wait for the chat to reach requires_action status.
	// The closure also accepts ChatStatusError so a failed run
	// terminates the poll early instead of burning the full
	// WaitLong timeout; the require.Equal below then reports
	// the actual status and last error.
	var chatResult database.Chat
	require.Eventually(t, func() bool {
		got, getErr := db.GetChatByID(ctx, chat.ID)
		if getErr != nil {
			return false
		}
		chatResult = got
		return got.Status == database.ChatStatusRequiresAction ||
			got.Status == database.ChatStatusError
	}, testutil.WaitLong, testutil.IntervalFast)
	require.Equal(t, database.ChatStatusRequiresAction, chatResult.Status,
		"expected requires_action, got %s (last_error=%q)",
		chatResult.Status, chatResult.LastError.String)
	// Find the tool call ID from the assistant message.
	var toolCallID string
	testutil.Eventually(ctx, t, func(ctx context.Context) bool {
		messages, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
			ChatID:  chat.ID,
			AfterID: 0,
		})
		if dbErr != nil {
			return false
		}
		for _, msg := range messages {
			if msg.Role != database.ChatMessageRoleAssistant {
				continue
			}
			parts, parseErr := chatprompt.ParseContent(msg)
			if parseErr != nil {
				continue
			}
			for _, part := range parts {
				if part.Type == codersdk.ChatMessagePartTypeToolCall && part.ToolName == "my_dynamic_tool" {
					toolCallID = part.ToolCallID
					return true
				}
			}
		}
		return false
	}, testutil.IntervalFast)
	require.NotEmpty(t, toolCallID)
	// Spawn N goroutines that all try to submit tool results at the
	// same time. Exactly one should succeed; the rest must get a
	// ToolResultStatusConflictError.
	const numGoroutines = 10
	var (
		wg               sync.WaitGroup
		ready            = make(chan struct{})
		successes        atomic.Int32
		conflicts        atomic.Int32
		unexpectedErrors = make(chan error, numGoroutines)
	)
	for range numGoroutines {
		// wg.Go (Go 1.25) wraps Add/Done around the goroutine, so no
		// explicit bookkeeping is needed here.
		wg.Go(func() {
			// Wait for all goroutines to be ready.
			<-ready
			submitErr := server.SubmitToolResults(ctx, chatd.SubmitToolResultsOptions{
				ChatID:        chat.ID,
				UserID:        user.ID,
				ModelConfigID: chatResult.LastModelConfigID,
				Results: []codersdk.ToolResult{{
					ToolCallID: toolCallID,
					Output:     json.RawMessage(`{"result":"concurrent output"}`),
				}},
				DynamicTools: dynamicToolsJSON,
			})
			if submitErr == nil {
				successes.Add(1)
				return
			}
			var conflict *chatd.ToolResultStatusConflictError
			if errors.As(submitErr, &conflict) {
				conflicts.Add(1)
				return
			}
			// Collect unexpected errors for assertion
			// outside the goroutine (require.NoError
			// calls t.FailNow which is illegal here).
			unexpectedErrors <- submitErr
		})
	}
	// Release all goroutines at once.
	close(ready)
	wg.Wait()
	// Close the channel after wg.Wait so ranging over it below
	// terminates once all buffered errors are drained.
	close(unexpectedErrors)
	for ue := range unexpectedErrors {
		require.NoError(t, ue, "unexpected error from SubmitToolResults")
	}
	require.Equal(t, int32(1), successes.Load(),
		"expected exactly 1 goroutine to succeed")
	require.Equal(t, int32(numGoroutines-1), conflicts.Load(),
		"expected %d conflict errors", numGoroutines-1)
}
// ptrRef returns a pointer to a freshly allocated copy of v. Useful for
// building struct literals whose fields are pointers to scalar values.
func ptrRef[T any](v T) *T {
	p := new(T)
	*p = v
	return p
}
+25 -206
View File
@@ -5,7 +5,6 @@ import (
"database/sql"
"encoding/json"
"errors"
"maps"
"slices"
"strconv"
"strings"
@@ -19,7 +18,6 @@ import (
"charm.land/fantasy/schema"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/x/chatd/chaterror"
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/x/chatd/chatretry"
@@ -40,23 +38,13 @@ const (
)
var (
ErrInterrupted = xerrors.New("chat interrupted")
ErrDynamicToolCall = xerrors.New("dynamic tool call")
ErrInterrupted = xerrors.New("chat interrupted")
errStartupTimeout = xerrors.New(
"chat response did not start before the startup timeout",
)
)
// PendingToolCall describes a tool call that targets a dynamic
// tool. These calls are not executed by the chatloop; instead
// they are persisted so the caller can fulfill them externally.
type PendingToolCall struct {
ToolCallID string
ToolName string
Args string
}
// PersistedStep contains the full content of a completed or
// interrupted agent step. Content includes both assistant blocks
// (text, reasoning, tool calls) and tool result blocks. The
@@ -72,21 +60,6 @@ type PersistedStep struct {
// Zero indicates the duration was not measured (e.g.
// interrupted steps).
Runtime time.Duration
// PendingDynamicToolCalls lists tool calls that target
// dynamic tools. When non-empty the chatloop exits with
// ErrDynamicToolCall so the caller can execute them
// externally and resume the loop.
PendingDynamicToolCalls []PendingToolCall
// ToolCallCreatedAt maps tool-call IDs to the time
// the model emitted each tool call. Applied by the
// persistence layer to set CreatedAt on persisted
// tool-call ChatMessageParts.
ToolCallCreatedAt map[string]time.Time
// ToolResultCreatedAt maps tool-call IDs to the time
// each tool result was produced (or interrupted).
// Applied by the persistence layer to set CreatedAt
// on persisted tool-result ChatMessageParts.
ToolResultCreatedAt map[string]time.Time
}
// RunOptions configures a single streaming chat loop run.
@@ -104,12 +77,6 @@ type RunOptions struct {
ActiveTools []string
ContextLimitFallback int64
// DynamicToolNames lists tool names that are handled
// externally. When the model invokes one of these tools
// the chatloop persists partial results and exits with
// ErrDynamicToolCall instead of executing the tool.
DynamicToolNames map[string]bool
// ModelConfig holds per-call LLM parameters (temperature,
// max tokens, etc.) read from the chat model configuration.
ModelConfig codersdk.ChatModelCallConfig
@@ -161,14 +128,12 @@ type ProviderTool struct {
// step. Since we own the stream consumer, all content is tracked
// directly here — no shadow draft state needed.
type stepResult struct {
content []fantasy.Content
usage fantasy.Usage
providerMetadata fantasy.ProviderMetadata
finishReason fantasy.FinishReason
toolCalls []fantasy.ToolCallContent
shouldContinue bool
toolCallCreatedAt map[string]time.Time
toolResultCreatedAt map[string]time.Time
content []fantasy.Content
usage fantasy.Usage
providerMetadata fantasy.ProviderMetadata
finishReason fantasy.FinishReason
toolCalls []fantasy.ToolCallContent
shouldContinue bool
}
// toResponseMessages converts step content into messages suitable
@@ -420,72 +385,16 @@ func Run(ctx context.Context, opts RunOptions) error {
return ctx.Err()
}
// Partition tool calls into built-in and dynamic.
var builtinCalls, dynamicCalls []fantasy.ToolCallContent
if len(opts.DynamicToolNames) > 0 {
for _, tc := range result.toolCalls {
if opts.DynamicToolNames[tc.ToolName] {
dynamicCalls = append(dynamicCalls, tc)
} else {
builtinCalls = append(builtinCalls, tc)
}
}
} else {
builtinCalls = result.toolCalls
}
// Execute only built-in tools.
toolResults = executeTools(ctx, opts.Tools, opts.ProviderTools, builtinCalls, func(tr fantasy.ToolResultContent, completedAt time.Time) {
recordToolResultTimestamp(&result, tr.ToolCallID, completedAt)
ssePart := chatprompt.PartFromContent(tr)
ssePart.CreatedAt = &completedAt
publishMessagePart(codersdk.ChatMessageRoleTool, ssePart)
toolResults = executeTools(ctx, opts.Tools, opts.ProviderTools, result.toolCalls, func(tr fantasy.ToolResultContent) {
publishMessagePart(
codersdk.ChatMessageRoleTool,
chatprompt.PartFromContent(tr),
)
})
for _, tr := range toolResults {
result.content = append(result.content, tr)
}
// If dynamic tools were called, persist what we
// have (assistant + built-in results) and exit so
// the caller can execute them externally.
if len(dynamicCalls) > 0 {
pending := make([]PendingToolCall, 0, len(dynamicCalls))
for _, dc := range dynamicCalls {
pending = append(pending, PendingToolCall{
ToolCallID: dc.ToolCallID,
ToolName: dc.ToolName,
Args: dc.Input,
})
}
contextLimit := extractContextLimit(result.providerMetadata)
if !contextLimit.Valid && opts.ContextLimitFallback > 0 {
contextLimit = sql.NullInt64{
Int64: opts.ContextLimitFallback,
Valid: true,
}
}
if err := opts.PersistStep(ctx, PersistedStep{
Content: result.content,
Usage: result.usage,
ContextLimit: contextLimit,
ProviderResponseID: extractOpenAIResponseIDIfStored(opts.ProviderOptions, result.providerMetadata),
Runtime: time.Since(stepStart),
PendingDynamicToolCalls: pending,
}); err != nil {
if errors.Is(err, ErrInterrupted) {
persistInterruptedStep(ctx, opts, &result)
return ErrInterrupted
}
return xerrors.Errorf("persist step: %w", err)
}
tryCompactOnExit(ctx, opts, result.usage, result.providerMetadata)
return ErrDynamicToolCall
}
// Check for interruption after tool execution.
// Tools that were canceled mid-flight produce error
// results via ctx cancellation. Persist the full
@@ -512,13 +421,11 @@ func Run(ctx context.Context, opts RunOptions) error {
// check and here, fall back to the interrupt-safe
// path so partial content is not lost.
if err := opts.PersistStep(ctx, PersistedStep{
Content: result.content,
Usage: result.usage,
ContextLimit: contextLimit,
ProviderResponseID: extractOpenAIResponseIDIfStored(opts.ProviderOptions, result.providerMetadata),
Runtime: time.Since(stepStart),
ToolCallCreatedAt: result.toolCallCreatedAt,
ToolResultCreatedAt: result.toolResultCreatedAt,
Content: result.content,
Usage: result.usage,
ContextLimit: contextLimit,
ProviderResponseID: extractOpenAIResponseIDIfStored(opts.ProviderOptions, result.providerMetadata),
Runtime: time.Since(stepStart),
}); err != nil {
if errors.Is(err, ErrInterrupted) {
persistInterruptedStep(ctx, opts, &result)
@@ -851,20 +758,9 @@ func processStepStream(
// Clean up active tool call tracking.
delete(activeToolCalls, part.ID)
// Record when the model emitted this tool call
// so the persisted part carries an accurate
// timestamp for duration computation.
now := dbtime.Now()
if result.toolCallCreatedAt == nil {
result.toolCallCreatedAt = make(map[string]time.Time)
}
result.toolCallCreatedAt[part.ID] = now
ssePart := chatprompt.PartFromContent(tc)
ssePart.CreatedAt = &now
publishMessagePart(
codersdk.ChatMessageRoleAssistant,
ssePart,
chatprompt.PartFromContent(tc),
)
case fantasy.StreamPartTypeSource:
@@ -894,18 +790,9 @@ func processStepStream(
ProviderMetadata: part.ProviderMetadata,
}
result.content = append(result.content, tr)
now := dbtime.Now()
if result.toolResultCreatedAt == nil {
result.toolResultCreatedAt = make(map[string]time.Time)
}
result.toolResultCreatedAt[part.ID] = now
ssePart := chatprompt.PartFromContent(tr)
ssePart.CreatedAt = &now
publishMessagePart(
codersdk.ChatMessageRoleTool,
ssePart,
chatprompt.PartFromContent(tr),
)
}
case fantasy.StreamPartTypeFinish:
@@ -974,7 +861,7 @@ func executeTools(
allTools []fantasy.AgentTool,
providerTools []ProviderTool,
toolCalls []fantasy.ToolCallContent,
onResult func(fantasy.ToolResultContent, time.Time),
onResult func(fantasy.ToolResultContent),
) []fantasy.ToolResultContent {
if len(toolCalls) == 0 {
return nil
@@ -1007,11 +894,10 @@ func executeTools(
}
results := make([]fantasy.ToolResultContent, len(localToolCalls))
completedAt := make([]time.Time, len(localToolCalls))
var wg sync.WaitGroup
wg.Add(len(localToolCalls))
for i, tc := range localToolCalls {
go func() {
go func(i int, tc fantasy.ToolCallContent) {
defer wg.Done()
defer func() {
if r := recover(); r != nil {
@@ -1023,21 +909,17 @@ func executeTools(
},
}
}
// Record when this tool completed (or panicked).
// Captured per-goroutine so parallel tools get
// accurate individual completion times.
completedAt[i] = dbtime.Now()
}()
results[i] = executeSingleTool(ctx, toolMap, tc)
}()
}(i, tc)
}
wg.Wait()
// Publish results in the original tool-call order so SSE
// subscribers see a deterministic event sequence.
if onResult != nil {
for i, tr := range results {
onResult(tr, completedAt[i])
for _, tr := range results {
onResult(tr)
}
}
return results
@@ -1173,24 +1055,11 @@ func persistInterruptedStep(
}
}
// Copy existing timestamps and add result timestamps for
// interrupted tool calls so the frontend can show partial
// duration.
toolCallCreatedAt := maps.Clone(result.toolCallCreatedAt)
if toolCallCreatedAt == nil {
toolCallCreatedAt = make(map[string]time.Time)
}
toolResultCreatedAt := maps.Clone(result.toolResultCreatedAt)
if toolResultCreatedAt == nil {
toolResultCreatedAt = make(map[string]time.Time)
}
// Build combined content: all accumulated content + synthetic
// interrupted results for any unanswered tool calls.
content := make([]fantasy.Content, 0, len(result.content))
content = append(content, result.content...)
interruptedAt := dbtime.Now()
for _, tc := range result.toolCalls {
if tc.ToolCallID == "" {
continue
@@ -1206,20 +1075,12 @@ func persistInterruptedStep(
Error: xerrors.New(interruptedToolResultErrorMessage),
},
})
// Only stamp synthetic results; don't clobber
// timestamps from tools that completed before
// the interruption arrived.
if _, exists := toolResultCreatedAt[tc.ToolCallID]; !exists {
toolResultCreatedAt[tc.ToolCallID] = interruptedAt
}
answeredToolCalls[tc.ToolCallID] = struct{}{}
}
persistCtx := context.WithoutCancel(ctx)
if err := opts.PersistStep(persistCtx, PersistedStep{
Content: content,
ToolCallCreatedAt: toolCallCreatedAt,
ToolResultCreatedAt: toolResultCreatedAt,
Content: content,
}); err != nil {
if opts.OnInterruptedPersistError != nil {
opts.OnInterruptedPersistError(err)
@@ -1227,38 +1088,6 @@ func persistInterruptedStep(
}
}
// tryCompactOnExit runs compaction when the chatloop is about
// to exit early (e.g. via ErrDynamicToolCall). The normal
// inline and post-run compaction paths are unreachable in
// early-exit scenarios, so this ensures the context window
// doesn't grow unbounded.
func tryCompactOnExit(
ctx context.Context,
opts RunOptions,
usage fantasy.Usage,
metadata fantasy.ProviderMetadata,
) {
if opts.Compaction == nil || opts.ReloadMessages == nil {
return
}
reloaded, err := opts.ReloadMessages(ctx)
if err != nil {
return
}
_, compactErr := tryCompact(
ctx,
opts.Model,
opts.Compaction,
opts.ContextLimitFallback,
usage,
metadata,
reloaded,
)
if compactErr != nil && opts.Compaction.OnError != nil {
opts.Compaction.OnError(compactErr)
}
}
// buildToolDefinitions converts AgentTool definitions into the
// fantasy.Tool slice expected by fantasy.Call. When activeTools
// is non-empty, only function tools whose name appears in the
@@ -1410,16 +1239,6 @@ func isResponsesStoreEnabled(providerOptions fantasy.ProviderOptions) bool {
return false
}
// recordToolResultTimestamp lazily initializes the
// toolResultCreatedAt map on the stepResult and records
// the completion timestamp for the given tool-call ID.
func recordToolResultTimestamp(result *stepResult, toolCallID string, ts time.Time) {
if result.toolResultCreatedAt == nil {
result.toolResultCreatedAt = make(map[string]time.Time)
}
result.toolResultCreatedAt[toolCallID] = ts
}
func extractContextLimit(metadata fantasy.ProviderMetadata) sql.NullInt64 {
if len(metadata) == 0 {
return sql.NullInt64{}
+1 -238
View File
@@ -86,54 +86,6 @@ func TestRun_ActiveToolsPrepareBehavior(t *testing.T) {
require.True(t, hasAnthropicEphemeralCacheControl(capturedCall.Prompt[4]))
}
func TestProcessStepStream_AnthropicUsageMatchesFinalDelta(t *testing.T) {
t.Parallel()
model := &loopTestModel{
provider: fantasyanthropic.Name,
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
return streamFromParts([]fantasy.StreamPart{
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "cached response"},
{Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"},
{
Type: fantasy.StreamPartTypeFinish,
Usage: fantasy.Usage{
InputTokens: 200,
OutputTokens: 75,
TotalTokens: 275,
CacheCreationTokens: 30,
CacheReadTokens: 150,
ReasoningTokens: 0,
},
FinishReason: fantasy.FinishReasonStop,
},
}), nil
},
}
var persistedStep PersistedStep
err := Run(context.Background(), RunOptions{
Model: model,
Messages: []fantasy.Message{
textMessage(fantasy.MessageRoleUser, "hello"),
},
MaxSteps: 1,
ContextLimitFallback: 4096,
PersistStep: func(_ context.Context, step PersistedStep) error {
persistedStep = step
return nil
},
})
require.NoError(t, err)
require.Equal(t, int64(200), persistedStep.Usage.InputTokens)
require.Equal(t, int64(75), persistedStep.Usage.OutputTokens)
require.Equal(t, int64(275), persistedStep.Usage.TotalTokens)
require.Equal(t, int64(30), persistedStep.Usage.CacheCreationTokens)
require.Equal(t, int64(150), persistedStep.Usage.CacheReadTokens)
}
func TestRun_OnRetryEnrichesProvider(t *testing.T) {
t.Parallel()
@@ -583,7 +535,6 @@ func TestRun_InterruptedStepPersistsSyntheticToolResult(t *testing.T) {
persistedAssistantCtxErr := xerrors.New("unset")
var persistedContent []fantasy.Content
var persistedStep PersistedStep
err := Run(ctx, RunOptions{
Model: model,
@@ -597,7 +548,6 @@ func TestRun_InterruptedStepPersistsSyntheticToolResult(t *testing.T) {
PersistStep: func(persistCtx context.Context, step PersistedStep) error {
persistedAssistantCtxErr = persistCtx.Err()
persistedContent = append([]fantasy.Content(nil), step.Content...)
persistedStep = step
return nil
},
})
@@ -637,14 +587,6 @@ func TestRun_InterruptedStepPersistsSyntheticToolResult(t *testing.T) {
require.True(t, foundText)
require.True(t, foundToolCall)
require.True(t, foundToolResult)
// The interrupted tool was flushed mid-stream (never reached
// StreamPartTypeToolCall), so it has no call timestamp.
// But the synthetic error result must have a result timestamp.
require.Contains(t, persistedStep.ToolResultCreatedAt, "interrupt-tool-1",
"interrupted tool result must have a result timestamp")
require.NotContains(t, persistedStep.ToolCallCreatedAt, "interrupt-tool-1",
"interrupted tool should have no call timestamp (never reached StreamPartTypeToolCall)")
}
type loopTestModel struct {
@@ -785,7 +727,6 @@ func TestRun_MultiStepToolExecution(t *testing.T) {
}
var persistStepCalls int
var persistedSteps []PersistedStep
err := Run(context.Background(), RunOptions{
Model: model,
Messages: []fantasy.Message{
@@ -795,9 +736,8 @@ func TestRun_MultiStepToolExecution(t *testing.T) {
newNoopTool("read_file"),
},
MaxSteps: 5,
PersistStep: func(_ context.Context, step PersistedStep) error {
PersistStep: func(_ context.Context, _ PersistedStep) error {
persistStepCalls++
persistedSteps = append(persistedSteps, step)
return nil
},
})
@@ -838,112 +778,6 @@ func TestRun_MultiStepToolExecution(t *testing.T) {
}
require.True(t, foundAssistantToolCall, "second call prompt should contain assistant tool call from step 0")
require.True(t, foundToolResult, "second call prompt should contain tool result message")
// The first persisted step (tool-call step) must carry
// accurate timestamps for duration computation.
require.Len(t, persistedSteps, 2)
toolStep := persistedSteps[0]
require.Contains(t, toolStep.ToolCallCreatedAt, "tc-1",
"tool-call step must record when the model emitted the call")
require.Contains(t, toolStep.ToolResultCreatedAt, "tc-1",
"tool-call step must record when the tool result was produced")
require.False(t, toolStep.ToolResultCreatedAt["tc-1"].Before(toolStep.ToolCallCreatedAt["tc-1"]),
"tool-result timestamp must be >= tool-call timestamp")
}
func TestRun_ParallelToolExecutionTimestamps(t *testing.T) {
t.Parallel()
var mu sync.Mutex
var streamCalls int
model := &loopTestModel{
provider: "fake",
streamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
mu.Lock()
step := streamCalls
streamCalls++
mu.Unlock()
_ = call
switch step {
case 0:
// Step 0: produce two tool calls in one stream.
return streamFromParts([]fantasy.StreamPart{
{Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-1", ToolCallName: "read_file"},
{Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-1", Delta: `{"path":"a.go"}`},
{Type: fantasy.StreamPartTypeToolInputEnd, ID: "tc-1"},
{
Type: fantasy.StreamPartTypeToolCall,
ID: "tc-1",
ToolCallName: "read_file",
ToolCallInput: `{"path":"a.go"}`,
},
{Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-2", ToolCallName: "write_file"},
{Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-2", Delta: `{"path":"b.go"}`},
{Type: fantasy.StreamPartTypeToolInputEnd, ID: "tc-2"},
{
Type: fantasy.StreamPartTypeToolCall,
ID: "tc-2",
ToolCallName: "write_file",
ToolCallInput: `{"path":"b.go"}`,
},
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonToolCalls},
}), nil
default:
// Step 1: return plain text.
return streamFromParts([]fantasy.StreamPart{
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "all done"},
{Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"},
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop},
}), nil
}
},
}
var persistedSteps []PersistedStep
err := Run(context.Background(), RunOptions{
Model: model,
Messages: []fantasy.Message{
textMessage(fantasy.MessageRoleUser, "do both"),
},
Tools: []fantasy.AgentTool{
newNoopTool("read_file"),
newNoopTool("write_file"),
},
MaxSteps: 5,
PersistStep: func(_ context.Context, step PersistedStep) error {
persistedSteps = append(persistedSteps, step)
return nil
},
})
require.NoError(t, err)
// Two steps: tool-call step + text step.
require.Equal(t, 2, streamCalls)
require.Len(t, persistedSteps, 2)
toolStep := persistedSteps[0]
// Both tool-call IDs must appear in ToolCallCreatedAt.
require.Contains(t, toolStep.ToolCallCreatedAt, "tc-1",
"tool-call step must record when tc-1 was emitted")
require.Contains(t, toolStep.ToolCallCreatedAt, "tc-2",
"tool-call step must record when tc-2 was emitted")
// Both tool-call IDs must appear in ToolResultCreatedAt.
require.Contains(t, toolStep.ToolResultCreatedAt, "tc-1",
"tool-call step must record when tc-1 result was produced")
require.Contains(t, toolStep.ToolResultCreatedAt, "tc-2",
"tool-call step must record when tc-2 result was produced")
// Result timestamps must be >= call timestamps for both.
require.False(t, toolStep.ToolResultCreatedAt["tc-1"].Before(toolStep.ToolCallCreatedAt["tc-1"]),
"tc-1 tool-result timestamp must be >= tool-call timestamp")
require.False(t, toolStep.ToolResultCreatedAt["tc-2"].Before(toolStep.ToolCallCreatedAt["tc-2"]),
"tc-2 tool-result timestamp must be >= tool-call timestamp")
}
func TestRun_PersistStepErrorPropagates(t *testing.T) {
@@ -1349,77 +1183,6 @@ func TestRun_InterruptedDuringToolExecutionPersistsStep(t *testing.T) {
require.True(t, foundToolResult, "persisted content should include the tool result (error from cancellation)")
}
// TestRun_ProviderExecutedToolResultTimestamps verifies that
// provider-executed tool results (e.g. web search) have their
// timestamps recorded in PersistedStep.ToolResultCreatedAt so
// the persistence layer can stamp CreatedAt on the parts.
func TestRun_ProviderExecutedToolResultTimestamps(t *testing.T) {
t.Parallel()
model := &loopTestModel{
provider: "fake",
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
// Simulate a provider-executed tool call and result
// (e.g. Anthropic web search) followed by a text
// response — all in a single stream.
return streamFromParts([]fantasy.StreamPart{
{Type: fantasy.StreamPartTypeToolInputStart, ID: "ws-1", ToolCallName: "web_search", ProviderExecuted: true},
{Type: fantasy.StreamPartTypeToolInputDelta, ID: "ws-1", Delta: `{"query":"coder"}`, ProviderExecuted: true},
{Type: fantasy.StreamPartTypeToolInputEnd, ID: "ws-1"},
{
Type: fantasy.StreamPartTypeToolCall,
ID: "ws-1",
ToolCallName: "web_search",
ToolCallInput: `{"query":"coder"}`,
ProviderExecuted: true,
},
// Provider-executed tool result — emitted by
// the provider, not our tool runner.
{
Type: fantasy.StreamPartTypeToolResult,
ID: "ws-1",
ToolCallName: "web_search",
ProviderExecuted: true,
},
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "search done"},
{Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"},
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop},
}), nil
},
}
var persistedSteps []PersistedStep
err := Run(context.Background(), RunOptions{
Model: model,
Messages: []fantasy.Message{
textMessage(fantasy.MessageRoleUser, "search for coder"),
},
MaxSteps: 1,
PersistStep: func(_ context.Context, step PersistedStep) error {
persistedSteps = append(persistedSteps, step)
return nil
},
})
require.NoError(t, err)
require.Len(t, persistedSteps, 1)
step := persistedSteps[0]
// Provider-executed tool call should have a call timestamp.
require.Contains(t, step.ToolCallCreatedAt, "ws-1",
"provider-executed tool call must record its timestamp")
// Provider-executed tool result should have a result
// timestamp so the frontend can compute duration.
require.Contains(t, step.ToolResultCreatedAt, "ws-1",
"provider-executed tool result must record its timestamp")
require.False(t,
step.ToolResultCreatedAt["ws-1"].Before(step.ToolCallCreatedAt["ws-1"]),
"tool-result timestamp must be >= tool-call timestamp")
}
// TestRun_PersistStepInterruptedFallback verifies that when the normal
// PersistStep call returns ErrInterrupted (e.g., context canceled in a
// race), the step is retried via the interrupt-safe path.
@@ -713,76 +713,4 @@ func TestRun_Compaction(t *testing.T) {
}
require.True(t, hasUser, "re-entry prompt must contain a user message (the compaction summary)")
})
t.Run("TriggersOnDynamicToolExit", func(t *testing.T) {
t.Parallel()
var persistCompactionCalls int
const summaryText = "compaction summary for dynamic tool exit"
// The LLM calls a dynamic tool. Usage is above the
// compaction threshold so compaction should fire even
// though the chatloop exits via ErrDynamicToolCall.
model := &loopTestModel{
provider: "fake",
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
return streamFromParts([]fantasy.StreamPart{
{Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-1", ToolCallName: "my_dynamic_tool"},
{Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-1", Delta: `{"query": "test"}`},
{Type: fantasy.StreamPartTypeToolInputEnd, ID: "tc-1"},
{
Type: fantasy.StreamPartTypeToolCall,
ID: "tc-1",
ToolCallName: "my_dynamic_tool",
ToolCallInput: `{"query": "test"}`,
},
{
Type: fantasy.StreamPartTypeFinish,
FinishReason: fantasy.FinishReasonToolCalls,
Usage: fantasy.Usage{
InputTokens: 80,
TotalTokens: 85,
},
},
}), nil
},
generateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
return &fantasy.Response{
Content: []fantasy.Content{
fantasy.TextContent{Text: summaryText},
},
}, nil
},
}
err := Run(context.Background(), RunOptions{
Model: model,
Messages: []fantasy.Message{
textMessage(fantasy.MessageRoleUser, "hello"),
},
MaxSteps: 5,
DynamicToolNames: map[string]bool{"my_dynamic_tool": true},
PersistStep: func(_ context.Context, _ PersistedStep) error {
return nil
},
ContextLimitFallback: 100,
Compaction: &CompactionOptions{
ThresholdPercent: 70,
SummaryPrompt: "summarize now",
Persist: func(_ context.Context, result CompactionResult) error {
persistCompactionCalls++
require.Contains(t, result.SystemSummary, summaryText)
return nil
},
},
ReloadMessages: func(_ context.Context) ([]fantasy.Message, error) {
return []fantasy.Message{
textMessage(fantasy.MessageRoleUser, "hello"),
}, nil
},
})
require.ErrorIs(t, err, ErrDynamicToolCall)
require.Equal(t, 1, persistCompactionCalls,
"compaction must fire before dynamic tool exit")
})
}
@@ -2329,48 +2329,3 @@ func TestMediaToolResultRoundTrip(t *testing.T) {
require.True(t, isText, "expected ToolResultOutputContentText")
})
}
func TestPartFromContent_CreatedAtNotStamped(t *testing.T) {
t.Parallel()
// PartFromContent must NOT stamp CreatedAt itself.
// The chatloop layer records timestamps separately and
// the persistence layer applies them. PartFromContent
// is called in multiple contexts (SSE publishing,
// persistence) so stamping inside it would produce
// inaccurate durations.
t.Run("ToolCallHasNilCreatedAt", func(t *testing.T) {
t.Parallel()
part := chatprompt.PartFromContent(fantasy.ToolCallContent{
ToolCallID: "tc-1",
ToolName: "execute",
})
assert.Nil(t, part.CreatedAt)
})
t.Run("ToolCallPointerHasNilCreatedAt", func(t *testing.T) {
t.Parallel()
part := chatprompt.PartFromContent(&fantasy.ToolCallContent{
ToolCallID: "tc-1",
ToolName: "execute",
})
assert.Nil(t, part.CreatedAt)
})
t.Run("ToolResultHasNilCreatedAt", func(t *testing.T) {
t.Parallel()
part := chatprompt.PartFromContent(fantasy.ToolResultContent{
ToolCallID: "tc-1",
ToolName: "execute",
Result: fantasy.ToolResultOutputContentText{Text: "{}"},
})
assert.Nil(t, part.CreatedAt)
})
t.Run("TextHasNilCreatedAt", func(t *testing.T) {
t.Parallel()
part := chatprompt.PartFromContent(fantasy.TextContent{Text: "hello"})
assert.Nil(t, part.CreatedAt)
})
}
+7 -89
View File
@@ -53,10 +53,8 @@ type AnthropicMessage struct {
// AnthropicUsage represents usage information in an Anthropic response.
type AnthropicUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"`
CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
// AnthropicChunk represents a streaming chunk from Anthropic.
@@ -69,16 +67,14 @@ type AnthropicChunk struct {
StopReason string `json:"stop_reason,omitempty"`
StopSequence *string `json:"stop_sequence,omitempty"`
Usage AnthropicUsage `json:"usage,omitempty"`
UsageMap map[string]int `json:"-"`
}
// AnthropicChunkMessage represents message metadata in a chunk.
type AnthropicChunkMessage struct {
ID string `json:"id"`
Type string `json:"type"`
Role string `json:"role"`
Model string `json:"model"`
Usage map[string]int `json:"usage,omitempty"`
ID string `json:"id"`
Type string `json:"type"`
Role string `json:"role"`
Model string `json:"model"`
}
// AnthropicContentBlock represents a content block in a chunk.
@@ -210,11 +206,7 @@ func (s *anthropicServer) writeStreamingResponse(w http.ResponseWriter, chunks <
"stop_reason": chunk.StopReason,
"stop_sequence": chunk.StopSequence,
}
if chunk.UsageMap != nil {
chunkData["usage"] = chunk.UsageMap
} else {
chunkData["usage"] = chunk.Usage
}
chunkData["usage"] = chunk.Usage
case "message_stop":
// No additional fields
}
@@ -350,80 +342,6 @@ func AnthropicTextChunks(deltas ...string) []AnthropicChunk {
return chunks
}
// AnthropicTextChunksWithCacheUsage creates a streaming response with text
// deltas and explicit cache token usage. The message_start event carries
// the initial input and cache token counts, and the final message_delta
// carries the output token count.
func AnthropicTextChunksWithCacheUsage(usage AnthropicUsage, deltas ...string) []AnthropicChunk {
if len(deltas) == 0 {
return nil
}
messageID := fmt.Sprintf("msg-%s", uuid.New().String()[:8])
model := "claude-3-opus-20240229"
messageUsage := map[string]int{
"input_tokens": usage.InputTokens,
}
if usage.CacheCreationInputTokens != 0 {
messageUsage["cache_creation_input_tokens"] = usage.CacheCreationInputTokens
}
if usage.CacheReadInputTokens != 0 {
messageUsage["cache_read_input_tokens"] = usage.CacheReadInputTokens
}
chunks := []AnthropicChunk{
{
Type: "message_start",
Message: AnthropicChunkMessage{
ID: messageID,
Type: "message",
Role: "assistant",
Model: model,
Usage: messageUsage,
},
},
{
Type: "content_block_start",
Index: 0,
ContentBlock: AnthropicContentBlock{
Type: "text",
Text: "",
},
},
}
for _, delta := range deltas {
chunks = append(chunks, AnthropicChunk{
Type: "content_block_delta",
Index: 0,
Delta: AnthropicDeltaBlock{
Type: "text_delta",
Text: delta,
},
})
}
chunks = append(chunks,
AnthropicChunk{
Type: "content_block_stop",
Index: 0,
},
AnthropicChunk{
Type: "message_delta",
StopReason: "end_turn",
UsageMap: map[string]int{
"output_tokens": usage.OutputTokens,
},
},
AnthropicChunk{
Type: "message_stop",
},
)
return chunks
}
// AnthropicToolCallChunks creates a complete streaming response for a tool call.
// Input JSON can be split across multiple deltas, matching Anthropic's
// input_json_delta streaming behavior.
-53
View File
@@ -63,59 +63,6 @@ func TestAnthropic_Streaming(t *testing.T) {
require.Equal(t, len(expectedDeltas), deltaIndex, "Expected %d deltas, got %d. Total parts received: %d", len(expectedDeltas), deltaIndex, len(allParts))
}
func TestAnthropic_StreamingUsageIncludesCacheTokens(t *testing.T) {
t.Parallel()
serverURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse {
return chattest.AnthropicStreamingResponse(
chattest.AnthropicTextChunksWithCacheUsage(chattest.AnthropicUsage{
InputTokens: 200,
OutputTokens: 75,
CacheCreationInputTokens: 30,
CacheReadInputTokens: 150,
}, "cached", " response")...,
)
})
client, err := fantasyanthropic.New(
fantasyanthropic.WithAPIKey("test-key"),
fantasyanthropic.WithBaseURL(serverURL),
)
require.NoError(t, err)
model, err := client.LanguageModel(context.Background(), "claude-3-opus-20240229")
require.NoError(t, err)
stream, err := model.Stream(context.Background(), fantasy.Call{
Prompt: []fantasy.Message{
{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hello"}},
},
},
})
require.NoError(t, err)
var (
finishPart fantasy.StreamPart
found bool
)
for part := range stream {
if part.Type != fantasy.StreamPartTypeFinish {
continue
}
finishPart = part
found = true
}
require.True(t, found)
require.Equal(t, int64(200), finishPart.Usage.InputTokens)
require.Equal(t, int64(75), finishPart.Usage.OutputTokens)
require.Equal(t, int64(275), finishPart.Usage.TotalTokens)
require.Equal(t, int64(30), finishPart.Usage.CacheCreationTokens)
require.Equal(t, int64(150), finishPart.Usage.CacheReadTokens)
}
func TestAnthropic_ToolCalls(t *testing.T) {
t.Parallel()
-91
View File
@@ -1,91 +0,0 @@
package chatd
import (
"context"
"encoding/json"
"charm.land/fantasy"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/codersdk"
)
// dynamicTool wraps a codersdk.DynamicTool as a fantasy.AgentTool.
// These tools are presented to the LLM but never executed by the
// chatloop — when the LLM calls one, the chatloop exits with
// requires_action status and the client handles execution.
// The Run method should never be called; it returns an error if
// it is, as a safety net.
type dynamicTool struct {
name string
description string
parameters map[string]any
required []string
opts fantasy.ProviderOptions
}
// dynamicToolsFromSDK converts codersdk.DynamicTool definitions
// into fantasy.AgentTool implementations for inclusion in the LLM
// tool list.
func dynamicToolsFromSDK(logger slog.Logger, tools []codersdk.DynamicTool) []fantasy.AgentTool {
if len(tools) == 0 {
return nil
}
result := make([]fantasy.AgentTool, 0, len(tools))
for _, t := range tools {
dt := &dynamicTool{
name: t.Name,
description: t.Description,
}
// InputSchema is a full JSON Schema object stored as
// json.RawMessage. Extract the "properties" and
// "required" fields that fantasy.ToolInfo expects.
if len(t.InputSchema) > 0 {
var schema struct {
Properties map[string]any `json:"properties"`
Required []string `json:"required"`
}
if err := json.Unmarshal(t.InputSchema, &schema); err != nil {
// Defensive: present the tool with no parameter
// constraints rather than failing. The LLM may
// hallucinate argument shapes, but the tool will
// still appear in the tool list.
logger.Warn(context.Background(), "failed to parse dynamic tool input schema",
slog.F("tool_name", t.Name),
slog.Error(err))
} else {
dt.parameters = schema.Properties
dt.required = schema.Required
}
}
result = append(result, dt)
}
return result
}
func (t *dynamicTool) Info() fantasy.ToolInfo {
return fantasy.ToolInfo{
Name: t.name,
Description: t.description,
Parameters: t.parameters,
Required: t.required,
}
}
func (*dynamicTool) Run(_ context.Context, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
// Dynamic tools are never executed by the chatloop. If this
// method is called, it indicates a bug in the chatloop's
// dynamic tool detection logic.
return fantasy.NewTextErrorResponse(
"dynamic tool called in chatloop — this is a bug; " +
"dynamic tools should be handled by the client",
), nil
}
func (t *dynamicTool) ProviderOptions() fantasy.ProviderOptions {
return t.opts
}
func (t *dynamicTool) SetProviderOptions(opts fantasy.ProviderOptions) {
t.opts = opts
}
-114
View File
@@ -1,114 +0,0 @@
package chatd
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/codersdk"
)
func TestDynamicToolsFromSDK(t *testing.T) {
t.Parallel()
t.Run("EmptySlice", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
result := dynamicToolsFromSDK(logger, nil)
require.Nil(t, result)
})
t.Run("ValidToolWithSchema", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
tools := []codersdk.DynamicTool{
{
Name: "my_tool",
Description: "A useful tool",
InputSchema: json.RawMessage(`{"type":"object","properties":{"input":{"type":"string"}},"required":["input"]}`),
},
}
result := dynamicToolsFromSDK(logger, tools)
require.Len(t, result, 1)
info := result[0].Info()
require.Equal(t, "my_tool", info.Name)
require.Equal(t, "A useful tool", info.Description)
require.NotNil(t, info.Parameters)
require.Contains(t, info.Parameters, "input")
require.Equal(t, []string{"input"}, info.Required)
})
t.Run("ToolWithoutSchema", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
tools := []codersdk.DynamicTool{
{
Name: "no_schema",
Description: "Tool with no schema",
},
}
result := dynamicToolsFromSDK(logger, tools)
require.Len(t, result, 1)
info := result[0].Info()
require.Equal(t, "no_schema", info.Name)
require.Nil(t, info.Parameters)
require.Nil(t, info.Required)
})
t.Run("MalformedSchema", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
tools := []codersdk.DynamicTool{
{
Name: "bad_schema",
Description: "Tool with malformed schema",
InputSchema: json.RawMessage("not-json"),
},
}
result := dynamicToolsFromSDK(logger, tools)
require.Len(t, result, 1)
info := result[0].Info()
require.Equal(t, "bad_schema", info.Name)
require.Nil(t, info.Parameters)
require.Nil(t, info.Required)
})
t.Run("MultipleTools", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
tools := []codersdk.DynamicTool{
{Name: "first", Description: "First tool"},
{Name: "second", Description: "Second tool"},
{Name: "third", Description: "Third tool"},
}
result := dynamicToolsFromSDK(logger, tools)
require.Len(t, result, 3)
require.Equal(t, "first", result[0].Info().Name)
require.Equal(t, "second", result[1].Info().Name)
require.Equal(t, "third", result[2].Info().Name)
})
t.Run("SchemaWithoutProperties", func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
tools := []codersdk.DynamicTool{
{
Name: "bare_schema",
Description: "Schema with no properties",
InputSchema: json.RawMessage(`{"type":"object"}`),
},
}
result := dynamicToolsFromSDK(logger, tools)
require.Len(t, result, 1)
info := result[0].Info()
require.Equal(t, "bare_schema", info.Name)
require.Nil(t, info.Parameters)
require.Nil(t, info.Required)
})
}
+8 -24
View File
@@ -180,7 +180,7 @@ func generateTitle(
model fantasy.LanguageModel,
input string,
) (string, error) {
title, err := generateStructuredTitle(ctx, model, titleGenerationPrompt, input)
title, _, err := generateStructuredTitle(ctx, model, titleGenerationPrompt, input)
if err != nil {
return "", err
}
@@ -192,24 +192,6 @@ func generateStructuredTitle(
model fantasy.LanguageModel,
systemPrompt string,
userInput string,
) (string, error) {
title, _, err := generateStructuredTitleWithUsage(
ctx,
model,
systemPrompt,
userInput,
)
if err != nil {
return "", err
}
return title, nil
}
func generateStructuredTitleWithUsage(
ctx context.Context,
model fantasy.LanguageModel,
systemPrompt string,
userInput string,
) (string, fantasy.Usage, error) {
userInput = strings.TrimSpace(userInput)
if userInput == "" {
@@ -244,6 +226,8 @@ func generateStructuredTitleWithUsage(
return genErr
}, nil)
if err != nil {
// Extract usage from the error when available so that
// failed attempts are still accounted for in usage tracking.
var usage fantasy.Usage
var noObjErr *fantasy.NoObjectGeneratedError
if errors.As(err, &noObjErr) {
@@ -545,7 +529,7 @@ func generateManualTitle(
userInput = strings.TrimSpace(firstUserText)
}
title, usage, err := generateStructuredTitleWithUsage(
title, usage, err := generateStructuredTitle(
titleCtx,
fallbackModel,
systemPrompt,
@@ -595,7 +579,7 @@ func generatePushSummary(
candidates = append(candidates, fallbackModel)
for _, model := range candidates {
summary, err := generateShortText(summaryCtx, model, pushSummaryPrompt, input)
summary, _, err := generateShortText(summaryCtx, model, pushSummaryPrompt, input)
if err != nil {
logger.Debug(ctx, "push summary model candidate failed",
slog.Error(err),
@@ -617,7 +601,7 @@ func generateShortText(
model fantasy.LanguageModel,
systemPrompt string,
userInput string,
) (string, error) {
) (string, fantasy.Usage, error) {
prompt := []fantasy.Message{
{
Role: fantasy.MessageRoleSystem,
@@ -645,7 +629,7 @@ func generateShortText(
return genErr
}, nil)
if err != nil {
return "", xerrors.Errorf("generate short text: %w", err)
return "", fantasy.Usage{}, xerrors.Errorf("generate short text: %w", err)
}
responseParts := make([]codersdk.ChatMessagePart, 0, len(response.Content))
@@ -655,5 +639,5 @@ func generateShortText(
}
}
text := normalizeShortTextOutput(contentBlocksToText(responseParts))
return text, nil
return text, response.Usage, nil
}
+4 -1
View File
@@ -515,9 +515,12 @@ func Test_generateShortText_NormalizesQuotedOutput(t *testing.T) {
},
}
text, err := generateShortText(context.Background(), model, "system", "user")
text, usage, err := generateShortText(context.Background(), model, "system", "user")
require.NoError(t, err)
require.Equal(t, "Quoted summary", text)
require.Equal(t, int64(3), usage.InputTokens)
require.Equal(t, int64(2), usage.OutputTokens)
require.Equal(t, int64(5), usage.TotalTokens)
}
type stubModel struct {
+1 -1
View File
@@ -376,7 +376,7 @@ func ProtoFromLog(log Log) (*proto.Log, error) {
}
return &proto.Log{
CreatedAt: timestamppb.New(log.CreatedAt),
Output: SanitizeLogOutput(log.Output),
Output: strings.ToValidUTF8(log.Output, "❌"),
Level: proto.Log_Level(lvl),
}, nil
}
+6 -6
View File
@@ -229,7 +229,7 @@ func TestLogSender_SkipHugeLog(t *testing.T) {
require.ErrorIs(t, err, context.Canceled)
}
func TestLogSender_SanitizeOutput(t *testing.T) {
func TestLogSender_InvalidUTF8(t *testing.T) {
t.Parallel()
testCtx := testutil.Context(t, testutil.WaitShort)
ctx, cancel := context.WithCancel(testCtx)
@@ -243,7 +243,7 @@ func TestLogSender_SanitizeOutput(t *testing.T) {
uut.Enqueue(ls1,
Log{
CreatedAt: t0,
Output: "test log 0, src 1\x00\xc3\x28",
Output: "test log 0, src 1\xc3\x28",
Level: codersdk.LogLevelInfo,
},
Log{
@@ -260,10 +260,10 @@ func TestLogSender_SanitizeOutput(t *testing.T) {
req := testutil.TryReceive(ctx, t, fDest.reqs)
require.NotNil(t, req)
require.Len(t, req.Logs, 2, "it should sanitize invalid output, but still send")
// The sanitizer replaces the NUL byte and invalid UTF-8 with ❌ while
// preserving the valid "(" byte that follows 0xc3.
require.Equal(t, "test log 0, src 1❌(", req.Logs[0].GetOutput())
require.Len(t, req.Logs, 2, "it should sanitize invalid UTF-8, but still send")
// the 0xc3, 0x28 is an invalid 2-byte sequence in UTF-8. The sanitizer replaces 0xc3 with ❌, and then
// interprets 0x28 as a 1-byte sequence "("
require.Equal(t, "test log 0, src 1❌(", req.Logs[0].GetOutput())
require.Equal(t, proto.Log_INFO, req.Logs[0].GetLevel())
require.Equal(t, "test log 1, src 1", req.Logs[1].GetOutput())
require.Equal(t, proto.Log_INFO, req.Logs[1].GetLevel())
-11
View File
@@ -1,11 +0,0 @@
package agentsdk
import "strings"
// SanitizeLogOutput replaces invalid UTF-8 and NUL characters in log output.
// Invalid UTF-8 cannot be transported in protobuf string fields, and PostgreSQL
// rejects NUL bytes in text columns.
func SanitizeLogOutput(s string) string {
s = strings.ToValidUTF8(s, "❌")
return strings.ReplaceAll(s, "\x00", "❌")
}
-48
View File
@@ -17,54 +17,6 @@ import (
"github.com/coder/coder/v2/testutil"
)
func TestSanitizeLogOutput(t *testing.T) {
t.Parallel()
tests := []struct {
name string
in string
want string
}{
{
name: "valid",
in: "hello world",
want: "hello world",
},
{
name: "invalid utf8",
in: "test log\xc3\x28",
want: "test log❌(",
},
{
name: "nul byte",
in: "before\x00after",
want: "before❌after",
},
{
name: "invalid utf8 and nul byte",
in: "before\x00middle\xc3\x28after",
want: "before❌middle❌(after",
},
{
name: "nul byte at edges",
in: "\x00middle\x00",
want: "❌middle❌",
},
{
name: "invalid utf8 at edges",
in: "\xc3middle\xc3",
want: "❌middle❌",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
require.Equal(t, tt.want, agentsdk.SanitizeLogOutput(tt.in))
})
}
}
func TestStartupLogsWriter_Write(t *testing.T) {
t.Parallel()
+20 -269
View File
@@ -15,7 +15,6 @@ import (
"time"
"github.com/google/uuid"
"github.com/invopop/jsonschema"
"github.com/shopspring/decimal"
"golang.org/x/xerrors"
@@ -43,13 +42,12 @@ func CompactionThresholdKey(modelConfigID uuid.UUID) string {
type ChatStatus string
const (
ChatStatusWaiting ChatStatus = "waiting"
ChatStatusPending ChatStatus = "pending"
ChatStatusRunning ChatStatus = "running"
ChatStatusPaused ChatStatus = "paused"
ChatStatusCompleted ChatStatus = "completed"
ChatStatusError ChatStatus = "error"
ChatStatusRequiresAction ChatStatus = "requires_action"
ChatStatusWaiting ChatStatus = "waiting"
ChatStatusPending ChatStatus = "pending"
ChatStatusRunning ChatStatus = "running"
ChatStatusPaused ChatStatus = "paused"
ChatStatusCompleted ChatStatus = "completed"
ChatStatusError ChatStatus = "error"
)
// Chat represents a chat session with an AI agent.
@@ -214,10 +212,6 @@ type ChatMessagePart struct {
// ProviderExecuted indicates the tool call was executed by
// the provider (e.g. Anthropic computer use).
ProviderExecuted bool `json:"provider_executed,omitempty" variants:"tool-call?,tool-result?"`
// CreatedAt records when this part was produced. Present on
// tool-call and tool-result parts so the frontend can compute
// tool execution duration.
CreatedAt *time.Time `json:"created_at,omitempty" format:"date-time" variants:"tool-call?,tool-result?"`
// ContextFilePath is the absolute path of a file loaded into
// the LLM context (e.g. an AGENTS.md instruction file).
ContextFilePath string `json:"context_file_path" variants:"context-file"`
@@ -367,18 +361,6 @@ type ChatInputPart struct {
Content string `json:"content,omitempty"`
}
// SubmitToolResultsRequest is the body for POST /chats/{id}/tool-results.
type SubmitToolResultsRequest struct {
Results []ToolResult `json:"results"`
}
// ToolResult is the client's response to a dynamic tool call.
type ToolResult struct {
ToolCallID string `json:"tool_call_id"`
Output json.RawMessage `json:"output"`
IsError bool `json:"is_error"`
}
// CreateChatRequest is the request to create a new chat.
type CreateChatRequest struct {
Content []ChatInputPart `json:"content"`
@@ -387,10 +369,6 @@ type CreateChatRequest struct {
ModelConfigID *uuid.UUID `json:"model_config_id,omitempty" format:"uuid"`
MCPServerIDs []uuid.UUID `json:"mcp_server_ids,omitempty" format:"uuid"`
Labels map[string]string `json:"labels,omitempty"`
// UnsafeDynamicTools declares client-executed tools that the
// LLM can invoke. This API is highly experimental and highly
// subject to change.
UnsafeDynamicTools []DynamicTool `json:"unsafe_dynamic_tools,omitempty"`
}
// UpdateChatRequest is the request to update a chat.
@@ -567,17 +545,6 @@ type UpdateChatWorkspaceTTLRequest struct {
WorkspaceTTLMillis int64 `json:"workspace_ttl_ms"`
}
// ChatRetentionDaysResponse contains the current chat retention setting.
type ChatRetentionDaysResponse struct {
RetentionDays int32 `json:"retention_days"`
}
// UpdateChatRetentionDaysRequest is a request to update the chat
// retention period.
type UpdateChatRetentionDaysRequest struct {
RetentionDays int32 `json:"retention_days"`
}
// ParseChatWorkspaceTTL parses a stored TTL string, returning the
// default when the value is empty.
func ParseChatWorkspaceTTL(s string) (time.Duration, error) {
@@ -950,13 +917,12 @@ type ChatDiffContents struct {
type ChatStreamEventType string
const (
ChatStreamEventTypeMessagePart ChatStreamEventType = "message_part"
ChatStreamEventTypeMessage ChatStreamEventType = "message"
ChatStreamEventTypeStatus ChatStreamEventType = "status"
ChatStreamEventTypeError ChatStreamEventType = "error"
ChatStreamEventTypeQueueUpdate ChatStreamEventType = "queue_update"
ChatStreamEventTypeRetry ChatStreamEventType = "retry"
ChatStreamEventTypeActionRequired ChatStreamEventType = "action_required"
ChatStreamEventTypeMessagePart ChatStreamEventType = "message_part"
ChatStreamEventTypeMessage ChatStreamEventType = "message"
ChatStreamEventTypeStatus ChatStreamEventType = "status"
ChatStreamEventTypeError ChatStreamEventType = "error"
ChatStreamEventTypeQueueUpdate ChatStreamEventType = "queue_update"
ChatStreamEventTypeRetry ChatStreamEventType = "retry"
)
// ChatQueuedMessage represents a queued message waiting to be processed.
@@ -1011,123 +977,16 @@ type ChatStreamRetry struct {
RetryingAt time.Time `json:"retrying_at" format:"date-time"`
}
// ChatStreamActionRequired is the payload of an action_required stream event.
type ChatStreamActionRequired struct {
ToolCalls []ChatStreamToolCall `json:"tool_calls"`
}
// ChatStreamToolCall describes a pending dynamic tool call that the client
// must execute.
type ChatStreamToolCall struct {
ToolCallID string `json:"tool_call_id"`
ToolName string `json:"tool_name"`
Args string `json:"args"`
}
// DynamicToolCall represents a pending tool invocation from the
// chat stream that the client must execute and submit back.
type DynamicToolCall struct {
ToolCallID string `json:"tool_call_id"`
ToolName string `json:"tool_name"`
Args string `json:"args"`
}
// DynamicToolResponse holds the output of a dynamic tool
// execution. IsError indicates a tool-level error the LLM
// should see, as opposed to an infrastructure failure
// (returned as the error return value).
type DynamicToolResponse struct {
Content string `json:"content"`
IsError bool `json:"is_error"`
}
// DynamicTool describes a client-declared tool definition. On the
// client side, the Handler callback executes the tool when the LLM
// invokes it. On the server side, only Name, Description, and
// InputSchema are used (Handler is not serialized).
type DynamicTool struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
// InputSchema's JSON key "input_schema" uses snake_case for
// SDK consistency, deviating from the camelCase "inputSchema"
// convention used by MCP.
InputSchema json.RawMessage `json:"input_schema"`
// Handler executes the tool when the LLM invokes it.
// Not serialized — this only exists on the client side.
Handler func(ctx context.Context, call DynamicToolCall) (DynamicToolResponse, error) `json:"-"`
}
// NewDynamicTool creates a DynamicTool with a typed handler.
// The JSON schema is derived from T using invopop/jsonschema.
// The handler receives deserialized args and the DynamicToolCall metadata.
func NewDynamicTool[T any](
name, description string,
handler func(ctx context.Context, args T, call DynamicToolCall) (DynamicToolResponse, error),
) DynamicTool {
reflector := jsonschema.Reflector{
DoNotReference: true,
Anonymous: true,
AllowAdditionalProperties: true,
}
schema := reflector.Reflect(new(T))
schema.Version = ""
schemaJSON, err := json.Marshal(schema)
if err != nil {
panic(fmt.Sprintf("codersdk: failed to marshal schema for %q: %v", name, err))
}
return DynamicTool{
Name: name,
Description: description,
InputSchema: schemaJSON,
Handler: func(ctx context.Context, call DynamicToolCall) (DynamicToolResponse, error) {
var parsed T
if err := json.Unmarshal([]byte(call.Args), &parsed); err != nil {
return DynamicToolResponse{
Content: fmt.Sprintf("invalid parameters: %s", err),
IsError: true,
}, nil
}
return handler(ctx, parsed, call)
},
}
}
// ChatWatchEventKind represents the kind of event in the chat watch stream.
type ChatWatchEventKind string
const (
ChatWatchEventKindStatusChange ChatWatchEventKind = "status_change"
ChatWatchEventKindTitleChange ChatWatchEventKind = "title_change"
ChatWatchEventKindCreated ChatWatchEventKind = "created"
ChatWatchEventKindDeleted ChatWatchEventKind = "deleted"
ChatWatchEventKindDiffStatusChange ChatWatchEventKind = "diff_status_change"
ChatWatchEventKindActionRequired ChatWatchEventKind = "action_required"
)
// ChatWatchEvent represents an event from the global chat watch stream.
// It delivers lifecycle events (created, status change, title change)
// for all of the authenticated user's chats. When Kind is
// ActionRequired, ToolCalls contains the pending dynamic tool
// invocations the client must execute and submit back.
type ChatWatchEvent struct {
Kind ChatWatchEventKind `json:"kind"`
Chat Chat `json:"chat"`
ToolCalls []ChatStreamToolCall `json:"tool_calls,omitempty"`
}
// ChatStreamEvent represents a real-time update for chat streaming.
type ChatStreamEvent struct {
Type ChatStreamEventType `json:"type"`
ChatID uuid.UUID `json:"chat_id" format:"uuid"`
Message *ChatMessage `json:"message,omitempty"`
MessagePart *ChatStreamMessagePart `json:"message_part,omitempty"`
Status *ChatStreamStatus `json:"status,omitempty"`
Error *ChatStreamError `json:"error,omitempty"`
Retry *ChatStreamRetry `json:"retry,omitempty"`
QueuedMessages []ChatQueuedMessage `json:"queued_messages,omitempty"`
ActionRequired *ChatStreamActionRequired `json:"action_required,omitempty"`
Type ChatStreamEventType `json:"type"`
ChatID uuid.UUID `json:"chat_id" format:"uuid"`
Message *ChatMessage `json:"message,omitempty"`
MessagePart *ChatStreamMessagePart `json:"message_part,omitempty"`
Status *ChatStreamStatus `json:"status,omitempty"`
Error *ChatStreamError `json:"error,omitempty"`
Retry *ChatStreamRetry `json:"retry,omitempty"`
QueuedMessages []ChatQueuedMessage `json:"queued_messages,omitempty"`
}
type chatStreamEnvelope struct {
@@ -1808,33 +1667,6 @@ func (c *ExperimentalClient) UpdateChatWorkspaceTTL(ctx context.Context, req Upd
return nil
}
// GetChatRetentionDays returns the configured chat retention period.
func (c *ExperimentalClient) GetChatRetentionDays(ctx context.Context) (ChatRetentionDaysResponse, error) {
res, err := c.Request(ctx, http.MethodGet, "/api/experimental/chats/config/retention-days", nil)
if err != nil {
return ChatRetentionDaysResponse{}, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return ChatRetentionDaysResponse{}, ReadBodyAsError(res)
}
var resp ChatRetentionDaysResponse
return resp, json.NewDecoder(res.Body).Decode(&resp)
}
// UpdateChatRetentionDays updates the chat retention period.
func (c *ExperimentalClient) UpdateChatRetentionDays(ctx context.Context, req UpdateChatRetentionDaysRequest) error {
res, err := c.Request(ctx, http.MethodPut, "/api/experimental/chats/config/retention-days", req)
if err != nil {
return err
}
defer res.Body.Close()
if res.StatusCode != http.StatusNoContent {
return ReadBodyAsError(res)
}
return nil
}
// GetChatTemplateAllowlist returns the deployment-wide chat template allowlist.
func (c *ExperimentalClient) GetChatTemplateAllowlist(ctx context.Context) (ChatTemplateAllowlist, error) {
res, err := c.Request(ctx, http.MethodGet, "/api/experimental/chats/config/template-allowlist", nil)
@@ -2070,73 +1902,6 @@ func (c *ExperimentalClient) StreamChat(ctx context.Context, chatID uuid.UUID, o
}), nil
}
// WatchChats streams lifecycle events for all of the authenticated
// user's chats in real time. The returned channel emits
// ChatWatchEvent values for status changes, title changes, creation,
// deletion, diff-status changes, and action-required notifications.
// Callers must close the returned io.Closer to release the websocket
// connection when done.
func (c *ExperimentalClient) WatchChats(ctx context.Context) (<-chan ChatWatchEvent, io.Closer, error) {
conn, err := c.Dial(
ctx,
"/api/experimental/chats/watch",
&websocket.DialOptions{CompressionMode: websocket.CompressionDisabled},
)
if err != nil {
return nil, nil, err
}
conn.SetReadLimit(1 << 22) // 4MiB
streamCtx, streamCancel := context.WithCancel(ctx)
events := make(chan ChatWatchEvent, 128)
go func() {
defer close(events)
defer streamCancel()
defer func() {
_ = conn.Close(websocket.StatusNormalClosure, "")
}()
for {
var envelope chatStreamEnvelope
if err := wsjson.Read(streamCtx, conn, &envelope); err != nil {
if streamCtx.Err() != nil {
return
}
switch websocket.CloseStatus(err) {
case websocket.StatusNormalClosure, websocket.StatusGoingAway:
return
}
return
}
switch envelope.Type {
case ServerSentEventTypePing:
continue
case ServerSentEventTypeData:
var event ChatWatchEvent
if err := json.Unmarshal(envelope.Data, &event); err != nil {
return
}
select {
case <-streamCtx.Done():
return
case events <- event:
}
case ServerSentEventTypeError:
return
default:
return
}
}
}()
return events, closeFunc(func() error {
streamCancel()
return nil
}), nil
}
// GetChat returns a chat by ID.
func (c *ExperimentalClient) GetChat(ctx context.Context, chatID uuid.UUID) (Chat, error) {
res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/experimental/chats/%s", chatID), nil)
@@ -2444,20 +2209,6 @@ func (c *ExperimentalClient) GetMyChatUsageLimitStatus(ctx context.Context) (Cha
return resp, json.NewDecoder(res.Body).Decode(&resp)
}
// SubmitToolResults submits the results of dynamic tool calls for a chat
// that is in requires_action status.
func (c *ExperimentalClient) SubmitToolResults(ctx context.Context, chatID uuid.UUID, req SubmitToolResultsRequest) error {
res, err := c.Request(ctx, http.MethodPost, fmt.Sprintf("/api/experimental/chats/%s/tool-results", chatID), req)
if err != nil {
return err
}
defer res.Body.Close()
if res.StatusCode != http.StatusNoContent {
return ReadBodyAsError(res)
}
return nil
}
// GetChatsByWorkspace returns a mapping of workspace ID to the latest
// non-archived chat ID for each requested workspace. Workspaces with
// no chats are omitted from the response.
-98
View File
@@ -329,42 +329,6 @@ func TestChatMessagePartVariantTags(t *testing.T) {
})
}
func TestChatMessagePart_CreatedAt_JSON(t *testing.T) {
t.Parallel()
t.Run("RoundTrips", func(t *testing.T) {
t.Parallel()
ts := time.Date(2025, 6, 15, 12, 30, 0, 0, time.UTC)
part := codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeToolCall,
ToolCallID: "tc-1",
ToolName: "execute",
CreatedAt: &ts,
}
data, err := json.Marshal(part)
require.NoError(t, err)
require.Contains(t, string(data), `"created_at"`)
var decoded codersdk.ChatMessagePart
err = json.Unmarshal(data, &decoded)
require.NoError(t, err)
require.NotNil(t, decoded.CreatedAt)
require.True(t, ts.Equal(*decoded.CreatedAt))
})
t.Run("OmittedWhenNil", func(t *testing.T) {
t.Parallel()
part := codersdk.ChatMessagePart{
Type: codersdk.ChatMessagePartTypeToolCall,
ToolCallID: "tc-1",
ToolName: "execute",
}
data, err := json.Marshal(part)
require.NoError(t, err)
require.NotContains(t, string(data), `"created_at"`)
})
}
func TestModelCostConfig_LegacyNumericJSON(t *testing.T) {
t.Parallel()
@@ -505,68 +469,6 @@ func TestChat_JSONRoundTrip(t *testing.T) {
require.Equal(t, original, decoded)
}
func TestNewDynamicTool(t *testing.T) {
t.Parallel()
type testArgs struct {
Query string `json:"query"`
}
t.Run("CorrectSchema", func(t *testing.T) {
t.Parallel()
tool := codersdk.NewDynamicTool(
"search", "search things",
func(_ context.Context, args testArgs, _ codersdk.DynamicToolCall) (codersdk.DynamicToolResponse, error) {
return codersdk.DynamicToolResponse{Content: args.Query}, nil
},
)
require.Equal(t, "search", tool.Name)
require.Equal(t, "search things", tool.Description)
require.Contains(t, string(tool.InputSchema), `"query"`)
require.Contains(t, string(tool.InputSchema), `"string"`)
})
t.Run("HandlerReceivesArgs", func(t *testing.T) {
t.Parallel()
var received testArgs
tool := codersdk.NewDynamicTool(
"search", "search things",
func(_ context.Context, args testArgs, _ codersdk.DynamicToolCall) (codersdk.DynamicToolResponse, error) {
received = args
return codersdk.DynamicToolResponse{Content: "ok"}, nil
},
)
resp, err := tool.Handler(context.Background(), codersdk.DynamicToolCall{
Args: `{"query":"hello"}`,
})
require.NoError(t, err)
require.Equal(t, "ok", resp.Content)
require.Equal(t, "hello", received.Query)
})
t.Run("InvalidJSONArgs", func(t *testing.T) {
t.Parallel()
tool := codersdk.NewDynamicTool(
"search", "search things",
func(_ context.Context, args testArgs, _ codersdk.DynamicToolCall) (codersdk.DynamicToolResponse, error) {
return codersdk.DynamicToolResponse{Content: "should not reach"}, nil
},
)
resp, err := tool.Handler(context.Background(), codersdk.DynamicToolCall{
Args: "not-json",
})
require.NoError(t, err)
require.True(t, resp.IsError)
require.Contains(t, resp.Content, "invalid parameters")
})
}
//nolint:tparallel,paralleltest
func TestParseChatWorkspaceTTL(t *testing.T) {
t.Parallel()
+2 -30
View File
@@ -196,7 +196,6 @@ const (
FeatureWorkspaceExternalAgent FeatureName = "workspace_external_agent"
FeatureAIBridge FeatureName = "aibridge"
FeatureBoundary FeatureName = "boundary"
FeatureServiceAccounts FeatureName = "service_accounts"
FeatureAIGovernanceUserLimit FeatureName = "ai_governance_user_limit"
)
@@ -228,7 +227,6 @@ var (
FeatureWorkspaceExternalAgent,
FeatureAIBridge,
FeatureBoundary,
FeatureServiceAccounts,
FeatureAIGovernanceUserLimit,
}
@@ -277,7 +275,6 @@ func (n FeatureName) AlwaysEnable() bool {
FeatureWorkspacePrebuilds: true,
FeatureWorkspaceExternalAgent: true,
FeatureBoundary: true,
FeatureServiceAccounts: true,
}[n]
}
@@ -285,7 +282,7 @@ func (n FeatureName) AlwaysEnable() bool {
func (n FeatureName) Enterprise() bool {
switch n {
// Add all features that should be excluded in the Enterprise feature set.
case FeatureMultipleOrganizations, FeatureCustomRoles, FeatureServiceAccounts:
case FeatureMultipleOrganizations, FeatureCustomRoles:
return false
default:
return true
@@ -3624,29 +3621,6 @@ Write out the current server config as YAML to stdout.`,
YAML: "acquireBatchSize",
Hidden: true, // Hidden because most operators should not need to modify this.
},
{
Name: "Chat: Pubsub Flush Interval",
Description: "The maximum time accepted chatd pubsub publishes wait before the batching loop schedules a flush.",
Flag: "chat-pubsub-flush-interval",
Env: "CODER_CHAT_PUBSUB_FLUSH_INTERVAL",
Value: &c.AI.Chat.PubsubFlushInterval,
Default: "50ms",
Group: &deploymentGroupChat,
YAML: "pubsubFlushInterval",
Annotations: serpent.Annotations{}.Mark(annotationFormatDuration, "true"),
Hidden: true,
},
{
Name: "Chat: Pubsub Queue Size",
Description: "How many chatd pubsub publishes can wait in memory for the dedicated sender path when PostgreSQL falls behind.",
Flag: "chat-pubsub-queue-size",
Env: "CODER_CHAT_PUBSUB_QUEUE_SIZE",
Value: &c.AI.Chat.PubsubQueueSize,
Default: "8192",
Group: &deploymentGroupChat,
YAML: "pubsubQueueSize",
Hidden: true,
},
// AI Bridge Options
{
Name: "AI Bridge Enabled",
@@ -4113,9 +4087,7 @@ type AIBridgeProxyConfig struct {
}
type ChatConfig struct {
AcquireBatchSize serpent.Int64 `json:"acquire_batch_size" typescript:",notnull"`
PubsubFlushInterval serpent.Duration `json:"pubsub_flush_interval" typescript:",notnull"`
PubsubQueueSize serpent.Int64 `json:"pubsub_queue_size" typescript:",notnull"`
AcquireBatchSize serpent.Int64 `json:"acquire_batch_size" typescript:",notnull"`
}
type AIConfig struct {
-41
View File
@@ -1,41 +0,0 @@
package codersdk
import (
"time"
"github.com/google/uuid"
)
// UserSecret represents a user secret's metadata. The secret value
// is never included in API responses.
type UserSecret struct {
ID uuid.UUID `json:"id" format:"uuid"`
Name string `json:"name"`
Description string `json:"description"`
EnvName string `json:"env_name"`
FilePath string `json:"file_path"`
CreatedAt time.Time `json:"created_at" format:"date-time"`
UpdatedAt time.Time `json:"updated_at" format:"date-time"`
}
// CreateUserSecretRequest is the payload for creating a new user
// secret. Name and Value are required. All other fields are optional
// and default to empty string.
type CreateUserSecretRequest struct {
Name string `json:"name"`
Value string `json:"value"`
Description string `json:"description,omitempty"`
EnvName string `json:"env_name,omitempty"`
FilePath string `json:"file_path,omitempty"`
}
// UpdateUserSecretRequest is the payload for partially updating a
// user secret. At least one field must be non-nil. Pointer fields
// distinguish "not sent" (nil) from "set to empty string" (pointer
// to empty string).
type UpdateUserSecretRequest struct {
Value *string `json:"value,omitempty"`
Description *string `json:"description,omitempty"`
EnvName *string `json:"env_name,omitempty"`
FilePath *string `json:"file_path,omitempty"`
}
-191
View File
@@ -1,191 +0,0 @@
package codersdk
import (
"regexp"
"strings"
"golang.org/x/xerrors"
)
// UserSecretEnvValidationOptions controls deployment-aware behavior
// in environment variable name validation.
type UserSecretEnvValidationOptions struct {
// AIGatewayEnabled indicates that the deployment has AI Gateway
// configured. When true, AI Gateway environment variables
// (OPENAI_API_KEY, etc.) are reserved to prevent conflicts.
AIGatewayEnabled bool
}
var (
// posixEnvNameRegex matches valid POSIX environment variable names:
// must start with a letter or underscore, followed by letters,
// digits, or underscores.
posixEnvNameRegex = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)
// reservedEnvNames are system environment variables that must not
// be overridden by user secrets. This list is intentionally
// aggressive because it is easier to remove entries later than
// to add them after users have already created conflicting
// secrets.
reservedEnvNames = map[string]struct{}{
// Core POSIX/login variables. Overriding these breaks
// basic shell and session behavior.
"PATH": {},
"HOME": {},
"SHELL": {},
"USER": {},
"LOGNAME": {},
"PWD": {},
"OLDPWD": {},
// Locale and terminal. Agents and IDEs depend on these
// being set correctly by the system.
"LANG": {},
"TERM": {},
// Shell behavior. Overriding these can silently break
// word splitting, directory resolution, and script
// execution in every shell session and agent script.
"IFS": {},
"CDPATH": {},
// Shell startup files. ENV is sourced by POSIX sh for
// interactive shells; BASH_ENV is sourced by bash for
// every non-interactive invocation (scripts, subshells).
// Allowing users to set these would inject arbitrary
// code into every shell and script in the workspace.
"ENV": {},
"BASH_ENV": {},
// Temp directories. Overriding these is a security risk
// (symlink attacks, world-readable paths).
"TMPDIR": {},
"TMP": {},
"TEMP": {},
// Host identity.
"HOSTNAME": {},
// SSH session variables. The Coder agent sets
// SSH_AUTH_SOCK in agentssh.go; the others are set by
// sshd and should never be faked.
"SSH_AUTH_SOCK": {},
"SSH_CLIENT": {},
"SSH_CONNECTION": {},
"SSH_TTY": {},
// Editor/pager. The Coder agent sets these so that git
// operations inside workspaces work non-interactively.
"EDITOR": {},
"VISUAL": {},
"PAGER": {},
// IDE integration. The agent sets these for code-server
// and VS Code Remote proxying.
"VSCODE_PROXY_URI": {},
"CS_DISABLE_GETTING_STARTED_OVERRIDE": {},
// XDG base directories. Overriding these redirects
// config, cache, and runtime data for every tool in the
// workspace.
"XDG_RUNTIME_DIR": {},
"XDG_CONFIG_HOME": {},
"XDG_DATA_HOME": {},
"XDG_CACHE_HOME": {},
"XDG_STATE_HOME": {},
// OIDC token. The Coder agent injects a short-lived
// OIDC token for cloud auth flows (e.g. GCP workload
// identity). Overriding it could break provisioner and
// agent authentication.
"OIDC_TOKEN": {},
}
// aiGatewayReservedEnvNames are reserved only when AI Gateway
// is enabled on the deployment. When AI Gateway is disabled,
// users may legitimately want to inject their own API keys
// via secrets.
aiGatewayReservedEnvNames = map[string]struct{}{
"OPENAI_API_KEY": {},
"OPENAI_BASE_URL": {},
"ANTHROPIC_AUTH_TOKEN": {},
"ANTHROPIC_BASE_URL": {},
}
// reservedEnvPrefixes are namespace prefixes where every
// variable in the family is reserved. Checked after the
// exact-name map. The CODER / CODER_* namespace is handled
// separately with its own error message (see below).
reservedEnvPrefixes = []string{
// The Coder agent sets GIT_SSH_COMMAND, GIT_ASKPASS,
// GIT_AUTHOR_*, GIT_COMMITTER_*, and several others.
// Blocking the entire GIT_* namespace avoids an arms
// race with new git env vars.
"GIT_",
// Locale variables. LC_ALL, LC_CTYPE, LC_MESSAGES,
// etc. control character encoding, sorting, and
// formatting. Overriding them can break text
// processing in agents and IDEs.
"LC_",
// Dynamic linker variables. Allowing users to set
// these would let a secret inject arbitrary shared
// libraries into every process in the workspace.
"LD_",
"DYLD_",
}
)
// UserSecretEnvNameValid validates an environment variable name for
// a user secret. Empty string is allowed (means no env injection).
// The opts parameter controls deployment-aware checks such as AI
// bridge variable reservation.
func UserSecretEnvNameValid(s string, opts UserSecretEnvValidationOptions) error {
if s == "" {
return nil
}
if !posixEnvNameRegex.MatchString(s) {
return xerrors.New("must start with a letter or underscore, followed by letters, digits, or underscores")
}
upper := strings.ToUpper(s)
if _, ok := reservedEnvNames[upper]; ok {
return xerrors.Errorf("%s is a reserved environment variable name", upper)
}
if upper == "CODER" || strings.HasPrefix(upper, "CODER_") {
return xerrors.New("environment variable names starting with CODER_ are reserved for internal use")
}
for _, prefix := range reservedEnvPrefixes {
if strings.HasPrefix(upper, prefix) {
return xerrors.Errorf("environment variables starting with %s are reserved", prefix)
}
}
if opts.AIGatewayEnabled {
if _, ok := aiGatewayReservedEnvNames[upper]; ok {
return xerrors.Errorf("%s is reserved when AI Gateway is enabled", upper)
}
}
return nil
}
// UserSecretFilePathValid validates a file path for a user secret.
// Empty string is allowed (means no file injection). Non-empty paths
// must start with ~/ or /.
func UserSecretFilePathValid(s string) error {
if s == "" {
return nil
}
if strings.HasPrefix(s, "~/") || strings.HasPrefix(s, "/") {
return nil
}
return xerrors.New("file path must start with ~/ or /")
}
-183
View File
@@ -1,183 +0,0 @@
package codersdk_test
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/coder/coder/v2/codersdk"
)
func TestUserSecretEnvNameValid(t *testing.T) {
t.Parallel()
// noAIGateway is the default for most tests — AI Gateway disabled.
noAIGateway := codersdk.UserSecretEnvValidationOptions{}
withAIGateway := codersdk.UserSecretEnvValidationOptions{AIGatewayEnabled: true}
tests := []struct {
name string
input string
opts codersdk.UserSecretEnvValidationOptions
wantErr bool
errMsg string
}{
// Valid names.
{name: "SimpleUpper", input: "GITHUB_TOKEN", opts: noAIGateway},
{name: "SimpleLower", input: "github_token", opts: noAIGateway},
{name: "StartsWithUnderscore", input: "_FOO", opts: noAIGateway},
{name: "SingleChar", input: "A", opts: noAIGateway},
{name: "WithDigits", input: "A1B2", opts: noAIGateway},
{name: "Empty", input: "", opts: noAIGateway},
// Invalid POSIX names.
{name: "StartsWithDigit", input: "1FOO", opts: noAIGateway, wantErr: true, errMsg: "must start with"},
{name: "ContainsHyphen", input: "FOO-BAR", opts: noAIGateway, wantErr: true, errMsg: "must start with"},
{name: "ContainsDot", input: "FOO.BAR", opts: noAIGateway, wantErr: true, errMsg: "must start with"},
{name: "ContainsSpace", input: "FOO BAR", opts: noAIGateway, wantErr: true, errMsg: "must start with"},
// Reserved system names — core POSIX/login.
{name: "ReservedPATH", input: "PATH", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
{name: "ReservedHOME", input: "HOME", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
{name: "ReservedSHELL", input: "SHELL", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
{name: "ReservedUSER", input: "USER", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
{name: "ReservedLOGNAME", input: "LOGNAME", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
{name: "ReservedPWD", input: "PWD", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
{name: "ReservedOLDPWD", input: "OLDPWD", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
// Reserved system names — locale/terminal.
{name: "ReservedLANG", input: "LANG", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
{name: "ReservedTERM", input: "TERM", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
// Reserved system names — shell behavior.
{name: "ReservedIFS", input: "IFS", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
{name: "ReservedCDPATH", input: "CDPATH", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
// Reserved system names — shell startup files.
{name: "ReservedENV", input: "ENV", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
{name: "ReservedBASH_ENV", input: "BASH_ENV", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
// Reserved system names — temp directories.
{name: "ReservedTMPDIR", input: "TMPDIR", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
{name: "ReservedTMP", input: "TMP", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
{name: "ReservedTEMP", input: "TEMP", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
// Reserved system names — host identity.
{name: "ReservedHOSTNAME", input: "HOSTNAME", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
// Reserved system names — SSH.
{name: "ReservedSSH_AUTH_SOCK", input: "SSH_AUTH_SOCK", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
{name: "ReservedSSH_CLIENT", input: "SSH_CLIENT", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
{name: "ReservedSSH_CONNECTION", input: "SSH_CONNECTION", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
{name: "ReservedSSH_TTY", input: "SSH_TTY", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
// Reserved system names — editor/pager.
{name: "ReservedEDITOR", input: "EDITOR", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
{name: "ReservedVISUAL", input: "VISUAL", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
{name: "ReservedPAGER", input: "PAGER", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
// Reserved system names — IDE integration.
{name: "ReservedVSCODE_PROXY_URI", input: "VSCODE_PROXY_URI", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
{name: "ReservedCS_DISABLE", input: "CS_DISABLE_GETTING_STARTED_OVERRIDE", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
// Reserved system names — XDG.
{name: "ReservedXDG_RUNTIME_DIR", input: "XDG_RUNTIME_DIR", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
{name: "ReservedXDG_CONFIG_HOME", input: "XDG_CONFIG_HOME", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
{name: "ReservedXDG_DATA_HOME", input: "XDG_DATA_HOME", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
{name: "ReservedXDG_CACHE_HOME", input: "XDG_CACHE_HOME", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
{name: "ReservedXDG_STATE_HOME", input: "XDG_STATE_HOME", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
// Reserved system names — OIDC.
{name: "ReservedOIDC_TOKEN", input: "OIDC_TOKEN", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
// AI Gateway vars — blocked when AI Gateway is enabled.
{name: "AIGateway/OPENAI_API_KEY/Enabled", input: "OPENAI_API_KEY", opts: withAIGateway, wantErr: true, errMsg: "AI Gateway"},
{name: "AIGateway/OPENAI_BASE_URL/Enabled", input: "OPENAI_BASE_URL", opts: withAIGateway, wantErr: true, errMsg: "AI Gateway"},
{name: "AIGateway/ANTHROPIC_AUTH_TOKEN/Enabled", input: "ANTHROPIC_AUTH_TOKEN", opts: withAIGateway, wantErr: true, errMsg: "AI Gateway"},
{name: "AIGateway/ANTHROPIC_BASE_URL/Enabled", input: "ANTHROPIC_BASE_URL", opts: withAIGateway, wantErr: true, errMsg: "AI Gateway"},
// AI Gateway vars — allowed when AI Gateway is disabled.
{name: "AIGateway/OPENAI_API_KEY/Disabled", input: "OPENAI_API_KEY", opts: noAIGateway},
{name: "AIGateway/OPENAI_BASE_URL/Disabled", input: "OPENAI_BASE_URL", opts: noAIGateway},
{name: "AIGateway/ANTHROPIC_AUTH_TOKEN/Disabled", input: "ANTHROPIC_AUTH_TOKEN", opts: noAIGateway},
{name: "AIGateway/ANTHROPIC_BASE_URL/Disabled", input: "ANTHROPIC_BASE_URL", opts: noAIGateway},
// Case insensitivity.
{name: "ReservedCaseInsensitive", input: "path", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
// CODER_ prefix.
{name: "CoderExact", input: "CODER", opts: noAIGateway, wantErr: true, errMsg: "CODER_"},
{name: "CoderPrefix", input: "CODER_WORKSPACE_NAME", opts: noAIGateway, wantErr: true, errMsg: "CODER_"},
{name: "CoderAgentToken", input: "CODER_AGENT_TOKEN", opts: noAIGateway, wantErr: true, errMsg: "CODER_"},
{name: "CoderLowerCase", input: "coder_foo", opts: noAIGateway, wantErr: true, errMsg: "CODER_"},
// GIT_* prefix.
{name: "GitSSHCommand", input: "GIT_SSH_COMMAND", opts: noAIGateway, wantErr: true, errMsg: "GIT_"},
{name: "GitAskpass", input: "GIT_ASKPASS", opts: noAIGateway, wantErr: true, errMsg: "GIT_"},
{name: "GitAuthorName", input: "GIT_AUTHOR_NAME", opts: noAIGateway, wantErr: true, errMsg: "GIT_"},
{name: "GitLowerCase", input: "git_editor", opts: noAIGateway, wantErr: true, errMsg: "GIT_"},
// LC_* prefix (locale).
{name: "LcAll", input: "LC_ALL", opts: noAIGateway, wantErr: true, errMsg: "LC_"},
{name: "LcCtype", input: "LC_CTYPE", opts: noAIGateway, wantErr: true, errMsg: "LC_"},
// LD_* prefix (dynamic linker).
{name: "LdPreload", input: "LD_PRELOAD", opts: noAIGateway, wantErr: true, errMsg: "LD_"},
{name: "LdLibraryPath", input: "LD_LIBRARY_PATH", opts: noAIGateway, wantErr: true, errMsg: "LD_"},
// DYLD_* prefix (macOS dynamic linker).
{name: "DyldInsert", input: "DYLD_INSERT_LIBRARIES", opts: noAIGateway, wantErr: true, errMsg: "DYLD_"},
{name: "DyldLibraryPath", input: "DYLD_LIBRARY_PATH", opts: noAIGateway, wantErr: true, errMsg: "DYLD_"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
err := codersdk.UserSecretEnvNameValid(tt.input, tt.opts)
if tt.wantErr {
assert.Error(t, err)
if tt.errMsg != "" {
assert.Contains(t, err.Error(), tt.errMsg)
}
} else {
assert.NoError(t, err)
}
})
}
}
func TestUserSecretFilePathValid(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
wantErr bool
}{
// Valid paths.
{name: "TildePath", input: "~/foo"},
{name: "TildeSSH", input: "~/.ssh/id_rsa"},
{name: "AbsolutePath", input: "/home/coder/.ssh/id_rsa"},
{name: "RootPath", input: "/"},
{name: "Empty", input: ""},
// Invalid paths.
{name: "BareRelative", input: "foo/bar", wantErr: true},
{name: "DotRelative", input: ".ssh/id_rsa", wantErr: true},
{name: "JustFilename", input: "credentials", wantErr: true},
{name: "TildeNoSlash", input: "~foo", wantErr: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
err := codersdk.UserSecretFilePathValid(tt.input)
if tt.wantErr {
assert.Error(t, err)
assert.Contains(t, err.Error(), "must start with")
} else {
assert.NoError(t, err)
}
})
}
}
+19 -56
View File
@@ -211,53 +211,33 @@ Coder releases are initiated via
[`./scripts/release.sh`](https://github.com/coder/coder/blob/main/scripts/release.sh)
and automated via GitHub Actions. Specifically, the
[`release.yaml`](https://github.com/coder/coder/blob/main/.github/workflows/release.yaml)
workflow.
workflow. They are created based on the current
[`main`](https://github.com/coder/coder/tree/main) branch.
Release notes are automatically generated from commit titles and PR metadata.
The release notes for a release are automatically generated from commit titles
and metadata from PRs that are merged into `main`.
### Release types
### Creating a release
| Type | Tag | Branch | Purpose |
|------------------------|---------------|---------------|-----------------------------------------|
| RC (release candidate) | `vX.Y.0-rc.W` | `main` | Ad-hoc pre-release for customer testing |
| Release | `vX.Y.0` | `release/X.Y` | First release of a minor version |
| Patch | `vX.Y.Z` | `release/X.Y` | Bug fixes and security patches |
The creation of a release is initiated via
[`./scripts/release.sh`](https://github.com/coder/coder/blob/main/scripts/release.sh).
This script will show a preview of the release that will be created, and if you
choose to continue, create and push the tag which will trigger the creation of
the release via GitHub Actions.
### Workflow
RC tags are created directly on `main`. The `release/X.Y` branch is only cut
when the release is ready. This avoids cherry-picking main's progress onto
a release branch between the first RC and the release.
```text
main: ──●──●──●──●──●──●──●──●──●──
↑ ↑ ↑
rc.0 rc.1 cut release/2.34, tag v2.34.0
\
release/2.34: ──●── v2.34.1 (patch)
```
1. **RC:** On `main`, run `./scripts/release.sh`. The tool suggests the next
RC version and tags it on `main`.
2. **Release:** When the RC is blessed, create `release/X.Y` from `main` (or
the specific RC commit). Switch to that branch and run
`./scripts/release.sh`, which suggests `vX.Y.0`.
3. **Patch:** Cherry-pick fixes onto `release/X.Y` and run
`./scripts/release.sh` from that branch.
The release tool warns if you try to tag a non-RC on `main` or an RC on a
release branch.
See `./scripts/release.sh --help` for more information.
### Creating a release (via workflow dispatch)
If the
[`release.yaml`](https://github.com/coder/coder/actions/workflows/release.yaml)
workflow fails after the tag has been pushed, retry it from the GitHub Actions
UI: press "Run workflow", set "Use workflow from" to the tag (e.g.
`Tag: v2.34.0`), select the correct release channel, and do **not** select
dry-run.
Typically the workflow dispatch is only used to test (dry-run) a release,
meaning no actual release will take place. The workflow can be dispatched
manually from
[Actions: Release](https://github.com/coder/coder/actions/workflows/release.yaml).
Simply press "Run workflow" and choose dry-run.
To test the workflow without publishing, select dry-run.
If a release has failed after the tag has been created and pushed, it can be
retried by again, pressing "Run workflow", changing "Use workflow from" from
"Branch: main" to "Tag: vX.X.X" and not selecting dry-run.
### Commit messages
@@ -291,23 +271,6 @@ specification, however, it's still possible to merge PRs on GitHub with a badly
formatted title. Take care when merging single-commit PRs as GitHub may prefer
to use the original commit title instead of the PR title.
### Backporting fixes to release branches
When a merged PR on `main` should also ship in older releases, add the
`backport` label to the PR. The
[backport workflow](https://github.com/coder/coder/blob/main/.github/workflows/backport.yaml)
will automatically detect the latest three `release/*` branches,
cherry-pick the merge commit onto each one, and open PRs for
review.
The label can be added before or after the PR is merged. Each backport
PR reuses the original title (e.g.
`fix(site): correct button alignment (#12345)`) so the change is
meaningful in release notes.
If the cherry-pick encounters conflicts, the backport PR is still created
with instructions for manual resolution — no conflict markers are committed.
### Breaking changes
Breaking changes can be triggered in two ways:
-6
View File
@@ -150,12 +150,6 @@ deployment. They will always be available from the agent.
| `coder_derp_server_sent_pong_total` | counter | Total pongs sent. | |
| `coder_derp_server_unknown_frames_total` | counter | Total unknown frames received. | |
| `coder_derp_server_watchers` | gauge | Current watchers. | |
| `coder_pubsub_batch_delegate_fallbacks_total` | counter | The number of chatd publishes that fell back to the shared pubsub pool by channel class, reason, and flush stage. | `channel_class` `reason` `stage` |
| `coder_pubsub_batch_flush_duration_seconds` | histogram | The time spent flushing one chatd batch to PostgreSQL. | `reason` |
| `coder_pubsub_batch_queue_depth` | gauge | The number of chatd notifications waiting in the batching queue. | |
| `coder_pubsub_batch_sender_reset_failures_total` | counter | The number of batched pubsub sender reset attempts that failed. | |
| `coder_pubsub_batch_sender_resets_total` | counter | The number of successful batched pubsub sender resets after flush failures. | |
| `coder_pubsub_batch_size` | histogram | The number of logical notifications sent in each chatd batch flush. | |
| `coder_pubsub_connected` | gauge | Whether we are connected (1) or not connected (0) to postgres | |
| `coder_pubsub_current_events` | gauge | The current number of pubsub event channels listened for | |
| `coder_pubsub_current_subscribers` | gauge | The current number of active pubsub subscribers | |
@@ -37,11 +37,14 @@ resource "docker_container" "workspace" {
resource "coder_agent" "main" {
arch = data.coder_provisioner.me.arch
os = "linux"
startup_script = <<-EOF
startup_script = <<EOF
#!/bin/sh
set -e
sudo service docker start
EOF
# Start Docker
sudo dockerd &
# ...
EOF
}
```
@@ -75,10 +78,13 @@ resource "coder_agent" "main" {
os = "linux"
arch = "amd64"
dir = "/home/coder"
startup_script = <<-EOF
startup_script = <<EOF
#!/bin/sh
set -e
sudo service docker start
# Start Docker
sudo dockerd &
# ...
EOF
}
+8 -15
View File
@@ -1,38 +1,31 @@
# Headless Authentication
> [!NOTE]
> Creating service accounts requires a [Premium license](https://coder.com/pricing).
Headless user accounts that cannot use the web UI to log in to Coder. This is
useful for creating accounts for automated systems, such as CI/CD pipelines or
for users who only consume Coder via another client/API.
Service accounts are headless user accounts that cannot use the web UI to log in
to Coder. This is useful for creating accounts for automated systems, such as
CI/CD pipelines or for users who only consume Coder via another client/API. Service accounts do not have passwords or associated email addresses.
You must have the User Admin role or above to create headless users.
You must have the User Admin role or above to create service accounts.
## Create a service account
## Create a headless user
<div class="tabs">
## CLI
Use the `--service-account` flag to create a dedicated service account:
```sh
coder users create \
--email="coder-bot@coder.com" \
--username="coder-bot" \
--service-account
--login-type="none" \
```
## UI
Navigate to **Deployment** > **Users** > **Create user**, then select
**Service account** as the login type.
Navigate to the `Users` > `Create user` in the topbar
![Create a user via the UI](../../images/admin/users/headless-user.png)
</div>
## Authenticate as a service account
To make API or CLI requests on behalf of the headless user, learn how to
[generate API tokens on behalf of a user](./sessions-tokens.md#generate-a-long-lived-api-token-on-behalf-of-another-user).
+3 -13
View File
@@ -180,15 +180,6 @@ configuration set by an administrator.
|--------------|--------------------------------------------------------------------------------------------------------------------------------------------------|
| `web_search` | Searches the internet for up-to-date information. Available when web search is enabled for the configured Anthropic, OpenAI, or Google provider. |
### Workspace extension tools
These tools are conditionally available based on the workspace contents.
| Tool | What it does |
|-------------------|--------------------------------------------------------------------------------------------------------------------------------|
| `read_skill` | Reads the instructions for a workspace skill by name. Available when the workspace has skills discovered in `.agents/skills/`. |
| `read_skill_file` | Reads a supporting file from a skill's directory. |
## What runs where
Understanding the split between the control plane and the workspace is central
@@ -233,11 +224,10 @@ Because state lives in the database:
- The agent can resume work by targeting a new workspace and continuing from the
last git branch or checkpoint.
## Security posture
## Security implications
The control plane architecture provides built-in security properties for AI
coding workflows. These are structural guarantees, not configuration options —
they hold by default for every agent session.
The control plane architecture provides several security advantages for AI
coding workflows.
### No API keys in workspaces
+6 -3
View File
@@ -65,9 +65,12 @@ Once the server restarts with the experiment enabled:
1. Navigate to the **Agents** page in the Coder dashboard.
1. Open **Admin** settings and configure at least one LLM provider and model.
See [Models](./models.md) for detailed setup instructions.
1. Grant the **Coder Agents User** role to users who need to create chats.
Go to **Admin** > **Users**, click the roles icon next to each user,
and enable **Coder Agents User**.
1. Grant the **Coder Agents User** role to existing users who need to create
chats. New users receive the role automatically. For existing users, go to
**Admin** > **Users**, click the roles icon next to each user, and enable
**Coder Agents User**. See
[Grant Coder Agents User](./getting-started.md#step-3-grant-coder-agents-user)
for a bulk CLI option.
1. Developers can then start a new chat from the Agents page.
## Licensing and availability
+22 -10
View File
@@ -24,8 +24,9 @@ Before you begin, confirm the following:
for the agent to select when provisioning workspaces.
- **Admin access** to the Coder deployment for enabling the experiment and
configuring providers.
- **Coder Agents User role** assigned to each user who needs to interact with Coder Agents.
Owners can assign this from **Admin** > **Users**. See
- **Coder Agents User role** is automatically assigned to new users when the
`agents` experiment is enabled. For existing users, owners can assign it from
**Admin** > **Users**. See
[Grant Coder Agents User](#step-3-grant-coder-agents-user) below.
## Step 1: Enable the experiment
@@ -74,14 +75,20 @@ Detailed instructions for each provider and model option are in the
## Step 3: Grant Coder Agents User
The **Coder Agents User** role controls which users can interact with Coder Agents.
Members do not have Coder Agents User by default.
The **Coder Agents User** role controls which users can interact with
Coder Agents.
Owners always have full access and do not need the role. Repeat the following steps for each user who needs access.
### New users
> [!NOTE]
> Users who created conversations before this role was introduced are
> automatically granted the role during upgrade.
When the `agents` experiment is enabled, new users are automatically
assigned the **Coder Agents User** role at account creation. No admin
action is required.
### Existing users
Users who were created before the experiment was enabled do not receive
the role automatically. Owners can assign it from the dashboard or in
bulk via the CLI.
**Dashboard (individual):**
@@ -91,8 +98,7 @@ Owners always have full access and do not need the role. Repeat the following st
**CLI (bulk):**
You can also grant the role via CLI. For example, to grant the role to
all active users at once:
To grant the role to all active users at once:
```sh
coder users list -o json \
@@ -105,6 +111,12 @@ coder users list -o json \
done
```
Owners always have full access and do not need the role.
> [!NOTE]
> Users who created conversations before this role was introduced are
> automatically granted the role during upgrade.
## Step 4: Start your first Coder Agent
1. Go to the **Agents** page in the Coder dashboard.
+23 -31
View File
@@ -232,43 +232,35 @@ model. Developers select from enabled models when starting a chat.
The agent has access to a set of workspace tools that it uses to accomplish
tasks:
| Tool | Description |
|----------------------------|--------------------------------------------------------------------------|
| `list_templates` | Browse available workspace templates |
| `read_template` | Get template details and configurable parameters |
| `create_workspace` | Create a workspace from a template |
| `start_workspace` | Start a stopped workspace for the current chat |
| `propose_plan` | Present a Markdown plan file for user review |
| `read_file` | Read file contents from the workspace |
| `write_file` | Write a file to the workspace |
| `edit_files` | Perform search-and-replace edits across files |
| `execute` | Run shell commands in the workspace |
| `process_output` | Retrieve output from a background process |
| `process_list` | List all tracked processes in the workspace |
| `process_signal` | Send a signal (terminate/kill) to a tracked process |
| `spawn_agent` | Delegate a task to a sub-agent running in parallel |
| `wait_agent` | Wait for a sub-agent to complete and collect its result |
| `message_agent` | Send a follow-up message to a running sub-agent |
| `close_agent` | Stop a running sub-agent |
| `spawn_computer_use_agent` | Spawn a sub-agent with desktop interaction (screenshot, mouse, keyboard) |
| `read_skill` | Read the instructions for a workspace skill by name |
| `read_skill_file` | Read a supporting file from a skill's directory |
| `web_search` | Search the internet (provider-native, when enabled) |
| Tool | Description |
|--------------------|---------------------------------------------------------|
| `list_templates` | Browse available workspace templates |
| `read_template` | Get template details and configurable parameters |
| `create_workspace` | Create a workspace from a template |
| `start_workspace` | Start a stopped workspace for the current chat |
| `propose_plan` | Present a Markdown plan file for user review |
| `read_file` | Read file contents from the workspace |
| `write_file` | Write a file to the workspace |
| `edit_files` | Perform search-and-replace edits across files |
| `execute` | Run shell commands in the workspace |
| `process_output` | Retrieve output from a background process |
| `process_list` | List all tracked processes in the workspace |
| `process_signal` | Send a signal (terminate/kill) to a tracked process |
| `spawn_agent` | Delegate a task to a sub-agent running in parallel |
| `wait_agent` | Wait for a sub-agent to complete and collect its result |
| `message_agent` | Send a follow-up message to a running sub-agent |
| `close_agent` | Stop a running sub-agent |
| `web_search` | Search the internet (provider-native, when enabled) |
These tools connect to the workspace over the same secure connection used for
web terminals and IDE access. No additional ports or services are required in
the workspace.
Platform tools (`list_templates`, `read_template`, `create_workspace`,
`start_workspace`, `propose_plan`) and orchestration tools (`spawn_agent`,
`wait_agent`, `message_agent`, `close_agent`, `spawn_computer_use_agent`)
are only available to root chats. Sub-agents do not have access to these
tools and cannot create workspaces or spawn further sub-agents.
`spawn_computer_use_agent` additionally requires an Anthropic provider and
the virtual desktop feature to be enabled by an administrator.
`read_skill` and `read_skill_file` are available when the workspace contains
skills in its `.agents/skills/` directory.
`start_workspace`, `propose_plan`) and orchestration tools (`spawn_agent`)
are only available to root chats. Sub-agents do
not have access to these tools and cannot create workspaces or spawn further
sub-agents.
## Comparison to Coder Tasks
+13 -87
View File
@@ -1,13 +1,10 @@
# Models
Administrators configure LLM providers and models from the Coder dashboard.
Providers, models, and API keys are deployment-wide settings managed by
platform teams. Developers select from the set of models that an administrator
These are deployment-wide settings — developers do not manage API keys or
provider configuration. They select from the set of models that an administrator
has enabled.
Optionally, administrators can allow developers to supply their own API keys
for specific providers. See [User API keys](#user-api-keys-byok) below.
## Providers
Each LLM provider has a type, an API key, and an optional base URL override.
@@ -60,38 +57,6 @@ access to LLM providers. See
[Architecture](./architecture.md#no-api-keys-in-workspaces) for details
on this security model.
### Key policy
Each provider has three policy flags that control how API keys are sourced:
| Setting | Default | Description |
|-------------------------|---------|-----------------------------------------------------------------------------------------------------|
| Central API key | On | The provider uses a deployment-managed API key entered by an administrator. |
| Allow user API keys | Off | Developers may supply their own API key for this provider. |
| Central key as fallback | Off | When user keys are allowed, fall back to the central key if a developer has not set a personal key. |
At least one credential source must be enabled. These settings appear in the
provider configuration form under **Key policy**.
The interaction between these flags determines whether a provider is available
to a given developer:
| Central key | User keys allowed | Fallback | Developer has key | Result |
|-------------|-------------------|----------|-------------------|----------------------|
| On | Off | — | — | Uses central key |
| Off | On | — | Yes | Uses developer's key |
| Off | On | — | No | Unavailable |
| On | On | Off | Yes | Uses developer's key |
| On | On | Off | No | Unavailable |
| On | On | On | Yes | Uses developer's key |
| On | On | On | No | Uses central key |
When a developer's personal key is present, it always takes precedence over
the central key. When user API keys are allowed and central key fallback is
disabled, the provider is unavailable to developers who have not saved a personal key —
even if a central key exists. This is intentional: it enforces that each
developer authenticates with their own credentials.
## Models
Each model belongs to a provider and has its own configuration for context limits,
@@ -167,11 +132,11 @@ fields appear dynamically in the admin UI when you select a provider.
#### OpenAI
| Option | Description |
|-----------------------|-------------------------------------------------------------------------------------------|
| Reasoning Effort | How much effort the model spends reasoning (`minimal`, `low`, `medium`, `high`, `xhigh`). |
| Max Completion Tokens | Cap on completion tokens for reasoning models. |
| Parallel Tool Calls | Whether the model can call multiple tools at once. |
| Option | Description |
|-----------------------|---------------------------------------------------------------------------------------------------|
| Reasoning Effort | How much effort the model spends reasoning (`none`, `minimal`, `low`, `medium`, `high`, `xhigh`). |
| Max Completion Tokens | Cap on completion tokens for reasoning models. |
| Parallel Tool Calls | Whether the model can call multiple tools at once. |
#### Google
@@ -182,10 +147,10 @@ fields appear dynamically in the admin UI when you select a provider.
#### OpenRouter
| Option | Description |
|-------------------|---------------------------------------------------|
| Reasoning Enabled | Enable extended reasoning mode. |
| Reasoning Effort | Reasoning effort level (`low`, `medium`, `high`). |
| Option | Description |
|-------------------|-------------------------------------------------------------------------------|
| Reasoning Enabled | Enable extended reasoning mode. |
| Reasoning Effort | Reasoning effort level (`none`, `minimal`, `low`, `medium`, `high`, `xhigh`). |
#### Vercel AI Gateway
@@ -211,49 +176,10 @@ The model selector uses the following precedence to pre-select a model:
1. **Admin-designated default** — the model marked with the star icon.
1. **First available model** — if no default is set and no history exists.
Developers cannot add their own providers or models. If no models are
configured, the chat interface displays a message directing developers to
Developers cannot add their own providers, models, or API keys. If no models
are configured, the chat interface displays a message directing developers to
contact an administrator.
## User API keys (BYOK)
When an administrator enables **Allow user API keys** on a provider,
developers can supply their own API key from the Agents settings page.
### Managing personal API keys
1. Navigate to the **Agents** page in the Coder dashboard.
1. Open **Settings** and select the **API Keys** tab.
1. Each provider that allows user keys is listed with a status indicator:
- **Key saved** — your personal key is active and will be used for requests.
- **Using shared key** — no personal key set, but the central deployment
key is available as a fallback.
- **No key** — you must add a personal key before you can use this provider.
1. Enter your API key and click **Save**.
Personal API keys are encrypted at rest using the same database encryption
as deployment-managed keys. The dashboard never displays a saved key — only
whether one is set.
### How key selection works
When you start a chat, the control plane resolves which API key to use for
each provider:
1. If you have a personal key for the provider, it is used.
1. If you do not have a personal key and central key fallback is enabled,
the deployment-managed key is used.
1. If you do not have a personal key and fallback is disabled, the provider
is unavailable to you. Models from that provider will not appear in the
model selector.
### Removing a personal key
Click **Remove** on the provider card in the API Keys settings tab. If
central key fallback is enabled, subsequent requests will use the shared
deployment key. If fallback is disabled, the provider becomes unavailable
until you add a new personal key.
## Using an LLM proxy
Organizations that route LLM traffic through a centralized proxy — such as
@@ -1,39 +0,0 @@
# Conversation Data Retention
Coder Agents automatically cleans up old conversation data to manage database
growth. Archived conversations and their associated files are periodically
purged based on a configurable retention period.
## How it works
A background process runs approximately every 10 minutes to remove expired
conversation data. Only archived conversations are eligible for deletion —
active (non-archived) conversations are never purged.
When an archived conversation exceeds the retention period, it is deleted along
with its messages, diff statuses, and queued messages via cascade. Orphaned
files (not referenced by any active or recently-archived conversation) are also
deleted. Both operations run in batches of 1,000 rows per cycle.
## Configuration
Navigate to the **Agents** page, open **Settings**, and select the **Behavior**
tab to configure the conversation retention period. The default is 30 days. Use the toggle to
disable retention entirely.
The retention period is stored as the `agents_chat_retention_days` key in the
`site_configs` table and can also be managed via the API at
`/api/experimental/chats/config/retention-days`.
## What gets deleted
| Data | Condition | Cascade |
|------------------------|------------------------------------------------------------------------------------------------|---------------------------------------------------------------|
| Archived conversations | Archived longer than retention period | Messages, diff statuses, queued messages deleted via CASCADE. |
| Conversation files | Older than retention period AND not referenced by any active or recently-archived conversation | — |
## Unarchive safety
If a user unarchives a conversation whose files were purged, stale file
references are automatically cleaned up by FK cascades. The conversation
remains usable but previously attached files are no longer available.
@@ -11,12 +11,11 @@ This means:
- **All agent configuration is admin-level.** Providers, models, system prompts,
and tool permissions are set by platform teams from the control plane. These
are not user preferences — they are deployment-wide policies.
- **Developers never need to configure anything by default.** A developer just
describes the work they want done. They do not need to pick a provider or
write a system prompt — the platform team has already set all of that up.
When a platform team enables user API keys for a provider, developers may
optionally supply their own key — but this is an opt-in policy decision, not
a requirement.
- **Developers never need to configure anything.** A developer just describes
the work they want done. They do not need to pick a provider, enter an API
key, or write a system prompt — the platform team has already set all of
that up. The goal is not to restrict developers, but to make configuration
unnecessary for a great experience.
- **Enforcement, not defaults.** Settings configured by administrators are
enforced server-side. Developers cannot override them. This is a deliberate
distinction — a setting that a user can change is a preference, not a policy.
@@ -37,12 +36,8 @@ self-hosted models), and per-model parameters like context limits, thinking
budgets, and reasoning effort.
Developers select from the set of models an administrator has enabled. They
cannot add their own providers or access models that have not been explicitly
configured.
When an administrator enables user API keys on a provider, developers can
supply their own key from the Agents settings page. See
[User API keys (BYOK)](../models.md#user-api-keys-byok) for details.
cannot add their own providers, supply their own API keys, or access models that
have not been explicitly configured.
See [Models](../models.md) for setup instructions.
@@ -89,30 +84,6 @@ opt-out, or opt-in for each chat.
See [MCP Servers](./mcp-servers.md) for configuration details.
### Virtual desktop
Administrators can enable a virtual desktop within agent workspaces.
When enabled, agents can use `spawn_computer_use_agent` to interact with a
desktop environment using screenshots, mouse, and keyboard input.
This setting is available under **Agents** > **Settings** > **Behavior**.
It requires:
- The [portabledesktop](https://registry.coder.com/modules/coder/portabledesktop)
module to be installed in the workspace template.
- An Anthropic provider to be configured (computer use is an Anthropic
capability).
### Workspace autostop fallback
Administrators can set a default autostop timer for agent-created workspaces
that do not define one in their template. Template-defined autostop rules always
take precedence. Active conversations extend the stop time automatically.
This setting is available under **Agents** > **Settings** > **Behavior**.
The maximum configurable value is 30 days. When disabled, workspaces follow
their template's autostop rules (or none, if the template does not define any).
### Usage limits and analytics
Administrators can set spend limits to cap LLM usage per user within a rolling
@@ -122,19 +93,10 @@ breakdowns.
See [Usage & Analytics](./usage-insights.md) for details.
### Data retention
Administrators can configure a retention period for archived conversations.
When enabled, archived conversations and orphaned files older than the
retention period are automatically purged. The default is 30 days.
This setting is available under **Agents** > **Settings** > **Behavior**.
See [Data Retention](./chat-retention.md) for details.
## Where we are headed
The controls above cover providers, models, system prompts, templates, MCP
servers, usage limits, and data retention. We are continuing to invest in platform controls
servers, and usage limits. We are continuing to invest in platform controls
based on what we hear from customers deploying agents in regulated and
enterprise environments.
@@ -83,8 +83,3 @@ Select a user to see:
bar shows current spend relative to the limit.
- **Per-model breakdown** — table of costs and token usage by model.
- **Per-chat breakdown** — table of costs and token usage by chat session.
> [!NOTE]
> Automatic title generation uses lightweight models, such as Claude Haiku or GPT-4o
> Mini. Its token usage is not counted towards usage limits or shown in usage
> summaries.
-7
View File
@@ -10,13 +10,6 @@ We provide an example Grafana dashboard that you can import as a starting point
These logs and metrics can be used to determine usage patterns, track costs, and evaluate tooling adoption.
## Structured Logging
AI Bridge can emit structured logs for every interception event to your
existing log pipeline. This is useful for exporting data to external SIEM or
observability platforms. See [Structured Logging](./setup.md#structured-logging)
in the setup guide for configuration and a full list of record types.
## Exporting Data
AI Bridge interception data can be exported for external analysis, compliance reporting, or integration with log aggregation systems.
+1 -11
View File
@@ -150,14 +150,4 @@ ingestion, set `--log-json` to a file path or `/dev/stderr` so that records are
emitted as JSON.
Filter for AI Bridge records in your logging pipeline by matching on the
`"interception log"` message. Each log line includes a `record_type` field that
indicates the kind of event captured:
| `record_type` | Description | Key fields |
|----------------------|-----------------------------------------|--------------------------------------------------------------------------------|
| `interception_start` | A new intercepted request begins. | `interception_id`, `initiator_id`, `provider`, `model`, `client`, `started_at` |
| `interception_end` | An intercepted request completes. | `interception_id`, `ended_at` |
| `token_usage` | Token consumption for a response. | `interception_id`, `input_tokens`, `output_tokens`, `created_at` |
| `prompt_usage` | The last user prompt in a request. | `interception_id`, `prompt`, `created_at` |
| `tool_usage` | A tool/function call made by the model. | `interception_id`, `tool`, `input`, `server_url`, `injected`, `created_at` |
| `model_thought` | Model reasoning or thinking content. | `interception_id`, `content`, `created_at` |
`"interception log"` message.
+3 -3
View File
@@ -83,9 +83,9 @@ pages.
| [2.26](https://coder.com/changelog/coder-2-26) | September 03, 2025 | Not Supported | [v2.26.6](https://github.com/coder/coder/releases/tag/v2.26.6) |
| [2.27](https://coder.com/changelog/coder-2-27) | October 02, 2025 | Not Supported | [v2.27.11](https://github.com/coder/coder/releases/tag/v2.27.11) |
| [2.28](https://coder.com/changelog/coder-2-28) | November 04, 2025 | Not Supported | [v2.28.11](https://github.com/coder/coder/releases/tag/v2.28.11) |
| [2.29](https://coder.com/changelog/coder-2-29) | December 02, 2025 | Extended Support Release | [v2.29.9](https://github.com/coder/coder/releases/tag/v2.29.9) |
| [2.30](https://coder.com/changelog/coder-2-30) | February 03, 2026 | Security Support | [v2.30.6](https://github.com/coder/coder/releases/tag/v2.30.6) |
| [2.31](https://coder.com/changelog/coder-2-31) | February 23, 2026 | Stable | [v2.31.7](https://github.com/coder/coder/releases/tag/v2.31.7) |
| [2.29](https://coder.com/changelog/coder-2-29) | December 02, 2025 | Security Support + ESR | [v2.29.8](https://github.com/coder/coder/releases/tag/v2.29.8) |
| [2.30](https://coder.com/changelog/coder-2-30) | February 03, 2026 | Stable | [v2.30.3](https://github.com/coder/coder/releases/tag/v2.30.3) |
| [2.31](https://coder.com/changelog/coder-2-31) | February 23, 2026 | Mainline | [v2.31.5](https://github.com/coder/coder/releases/tag/v2.31.5) |
| 2.32 | | Not Released | N/A |
<!-- RELEASE_CALENDAR_END -->

Some files were not shown because too many files have changed in this diff Show More