Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c174b3037b | |||
| f5165d304f |
@@ -0,0 +1,2 @@
|
||||
enabled: true
|
||||
preservePullRequestTitle: true
|
||||
@@ -31,7 +31,8 @@ updates:
|
||||
patterns:
|
||||
- "golang.org/x/*"
|
||||
ignore:
|
||||
# Ignore patch updates for all dependencies
|
||||
# Patch updates are handled by the security-patch-prs workflow so this
|
||||
# lane stays focused on broader dependency updates.
|
||||
- dependency-name: "*"
|
||||
update-types:
|
||||
- version-update:semver-patch
|
||||
@@ -56,7 +57,7 @@ updates:
|
||||
labels: []
|
||||
ignore:
|
||||
# We need to coordinate terraform updates with the version hardcoded in
|
||||
# our Go code.
|
||||
# our Go code. These are handled by the security-patch-prs workflow.
|
||||
- dependency-name: "terraform"
|
||||
|
||||
- package-ecosystem: "npm"
|
||||
@@ -117,11 +118,11 @@ updates:
|
||||
interval: "weekly"
|
||||
commit-message:
|
||||
prefix: "chore"
|
||||
labels: []
|
||||
groups:
|
||||
coder-modules:
|
||||
patterns:
|
||||
- "coder/*/coder"
|
||||
labels: []
|
||||
ignore:
|
||||
- dependency-name: "*"
|
||||
update-types:
|
||||
|
||||
@@ -1,174 +0,0 @@
|
||||
# Automatically backport merged PRs to the last N release branches when the
|
||||
# "backport" label is applied. Works whether the label is added before or
|
||||
# after the PR is merged.
|
||||
#
|
||||
# Usage:
|
||||
# 1. Add the "backport" label to a PR targeting main.
|
||||
# 2. When the PR merges (or if already merged), the workflow detects the
|
||||
# latest release/* branches and opens one cherry-pick PR per branch.
|
||||
#
|
||||
# The created backport PRs follow existing repo conventions:
|
||||
# - Branch: backport/<pr>-to-<version>
|
||||
# - Title: <original PR title> (#<pr>)
|
||||
# - Body: links back to the original PR and merge commit
|
||||
|
||||
name: Backport
|
||||
on:
|
||||
pull_request_target:
|
||||
branches:
|
||||
- main
|
||||
types:
|
||||
- closed
|
||||
- labeled
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
|
||||
# Prevent duplicate runs for the same PR when both 'closed' and 'labeled'
|
||||
# fire in quick succession.
|
||||
concurrency:
|
||||
group: backport-${{ github.event.pull_request.number }}
|
||||
|
||||
jobs:
|
||||
detect:
|
||||
name: Detect target branches
|
||||
if: >
|
||||
github.event.pull_request.merged == true &&
|
||||
contains(github.event.pull_request.labels.*.name, 'backport')
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
branches: ${{ steps.find.outputs.branches }}
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
# Need all refs to discover release branches.
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Find latest release branches
|
||||
id: find
|
||||
run: |
|
||||
# List remote release branches matching the exact release/2.X
|
||||
# pattern (no suffixes like release/2.31_hotfix), sort by minor
|
||||
# version descending, and take the top 3.
|
||||
BRANCHES=$(
|
||||
git branch -r \
|
||||
| grep -E '^\s*origin/release/2\.[0-9]+$' \
|
||||
| sed 's|.*origin/||' \
|
||||
| sort -t. -k2 -n -r \
|
||||
| head -3
|
||||
)
|
||||
|
||||
if [ -z "$BRANCHES" ]; then
|
||||
echo "No release branches found."
|
||||
echo "branches=[]" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Convert to JSON array for the matrix.
|
||||
JSON=$(echo "$BRANCHES" | jq -Rnc '[inputs | select(length > 0)]')
|
||||
echo "branches=$JSON" >> "$GITHUB_OUTPUT"
|
||||
echo "Will backport to: $JSON"
|
||||
|
||||
backport:
|
||||
name: "Backport to ${{ matrix.branch }}"
|
||||
needs: detect
|
||||
if: needs.detect.outputs.branches != '[]'
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
branch: ${{ fromJson(needs.detect.outputs.branches) }}
|
||||
fail-fast: false
|
||||
env:
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
PR_TITLE: ${{ github.event.pull_request.title }}
|
||||
PR_URL: ${{ github.event.pull_request.html_url }}
|
||||
MERGE_SHA: ${{ github.event.pull_request.merge_commit_sha }}
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
# Full history required for cherry-pick.
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Cherry-pick and open PR
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
RELEASE_VERSION="${{ matrix.branch }}"
|
||||
# Strip the release/ prefix for naming.
|
||||
VERSION="${RELEASE_VERSION#release/}"
|
||||
BACKPORT_BRANCH="backport/${PR_NUMBER}-to-${VERSION}"
|
||||
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
|
||||
|
||||
# Check if backport branch already exists (idempotency for re-runs).
|
||||
if git ls-remote --exit-code origin "refs/heads/${BACKPORT_BRANCH}" >/dev/null 2>&1; then
|
||||
echo "Backport branch ${BACKPORT_BRANCH} already exists, skipping."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Create the backport branch from the target release branch.
|
||||
git checkout -b "$BACKPORT_BRANCH" "origin/${RELEASE_VERSION}"
|
||||
|
||||
# Cherry-pick the merge commit. Use -x to record provenance and
|
||||
# -m1 to pick the first parent (the main branch side).
|
||||
CONFLICTS=false
|
||||
if ! git cherry-pick -x -m1 "$MERGE_SHA"; then
|
||||
echo "::warning::Cherry-pick to ${RELEASE_VERSION} had conflicts."
|
||||
CONFLICTS=true
|
||||
|
||||
# Abort the failed cherry-pick and create an empty commit
|
||||
# explaining the situation.
|
||||
git cherry-pick --abort
|
||||
git commit --allow-empty -m "Cherry-pick of #${PR_NUMBER} requires manual resolution
|
||||
|
||||
The automatic cherry-pick of ${MERGE_SHA} to ${RELEASE_VERSION} had conflicts.
|
||||
Please cherry-pick manually:
|
||||
|
||||
git cherry-pick -x -m1 ${MERGE_SHA}"
|
||||
fi
|
||||
|
||||
git push origin "$BACKPORT_BRANCH"
|
||||
|
||||
TITLE="${PR_TITLE} (#${PR_NUMBER})"
|
||||
BODY=$(cat <<EOF
|
||||
Backport of ${PR_URL}
|
||||
|
||||
Original PR: #${PR_NUMBER} — ${PR_TITLE}
|
||||
Merge commit: ${MERGE_SHA}
|
||||
EOF
|
||||
)
|
||||
|
||||
if [ "$CONFLICTS" = true ]; then
|
||||
TITLE="${TITLE} (conflicts)"
|
||||
BODY="${BODY}
|
||||
|
||||
> [!WARNING]
|
||||
> The automatic cherry-pick had conflicts.
|
||||
> Please resolve manually by cherry-picking the original merge commit:
|
||||
>
|
||||
> \`\`\`
|
||||
> git fetch origin ${BACKPORT_BRANCH}
|
||||
> git checkout ${BACKPORT_BRANCH}
|
||||
> git reset --hard origin/${RELEASE_VERSION}
|
||||
> git cherry-pick -x -m1 ${MERGE_SHA}
|
||||
> # resolve conflicts, then push
|
||||
> \`\`\`"
|
||||
fi
|
||||
|
||||
# Check if a PR already exists for this branch (idempotency
|
||||
# for re-runs).
|
||||
EXISTING_PR=$(gh pr list --head "$BACKPORT_BRANCH" --base "$RELEASE_VERSION" --state all --json number --jq '.[0].number // empty')
|
||||
if [ -n "$EXISTING_PR" ]; then
|
||||
echo "PR #${EXISTING_PR} already exists for ${BACKPORT_BRANCH}, skipping."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
gh pr create \
|
||||
--base "$RELEASE_VERSION" \
|
||||
--head "$BACKPORT_BRANCH" \
|
||||
--title "$TITLE" \
|
||||
--body "$BODY"
|
||||
@@ -1,139 +0,0 @@
|
||||
# Automatically cherry-pick merged PRs to the latest release branch when the
|
||||
# "cherry-pick" label is applied. Works whether the label is added before or
|
||||
# after the PR is merged.
|
||||
#
|
||||
# Usage:
|
||||
# 1. Add the "cherry-pick" label to a PR targeting main.
|
||||
# 2. When the PR merges (or if already merged), the workflow detects the
|
||||
# latest release/* branch and opens a cherry-pick PR against it.
|
||||
#
|
||||
# The created PRs follow existing repo conventions:
|
||||
# - Branch: backport/<pr>-to-<version>
|
||||
# - Title: <original PR title> (#<pr>)
|
||||
# - Body: links back to the original PR and merge commit
|
||||
|
||||
name: Cherry-pick to release
|
||||
on:
|
||||
pull_request_target:
|
||||
branches:
|
||||
- main
|
||||
types:
|
||||
- closed
|
||||
- labeled
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
|
||||
# Prevent duplicate runs for the same PR when both 'closed' and 'labeled'
|
||||
# fire in quick succession.
|
||||
concurrency:
|
||||
group: cherry-pick-${{ github.event.pull_request.number }}
|
||||
|
||||
jobs:
|
||||
cherry-pick:
|
||||
name: Cherry-pick to latest release
|
||||
if: >
|
||||
github.event.pull_request.merged == true &&
|
||||
contains(github.event.pull_request.labels.*.name, 'cherry-pick')
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
PR_TITLE: ${{ github.event.pull_request.title }}
|
||||
PR_URL: ${{ github.event.pull_request.html_url }}
|
||||
MERGE_SHA: ${{ github.event.pull_request.merge_commit_sha }}
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
# Full history required for cherry-pick and branch discovery.
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Cherry-pick and open PR
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
# Find the latest release branch matching the exact release/2.X
|
||||
# pattern (no suffixes like release/2.31_hotfix).
|
||||
RELEASE_BRANCH=$(
|
||||
git branch -r \
|
||||
| grep -E '^\s*origin/release/2\.[0-9]+$' \
|
||||
| sed 's|.*origin/||' \
|
||||
| sort -t. -k2 -n -r \
|
||||
| head -1
|
||||
)
|
||||
|
||||
if [ -z "$RELEASE_BRANCH" ]; then
|
||||
echo "::error::No release branch found."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Strip the release/ prefix for naming.
|
||||
VERSION="${RELEASE_BRANCH#release/}"
|
||||
BACKPORT_BRANCH="backport/${PR_NUMBER}-to-${VERSION}"
|
||||
|
||||
echo "Target branch: $RELEASE_BRANCH"
|
||||
echo "Backport branch: $BACKPORT_BRANCH"
|
||||
|
||||
# Check if backport branch already exists (idempotency for re-runs).
|
||||
if git ls-remote --exit-code origin "refs/heads/${BACKPORT_BRANCH}" >/dev/null 2>&1; then
|
||||
echo "Branch ${BACKPORT_BRANCH} already exists, skipping."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
|
||||
|
||||
# Create the backport branch from the target release branch.
|
||||
git checkout -b "$BACKPORT_BRANCH" "origin/${RELEASE_BRANCH}"
|
||||
|
||||
# Cherry-pick the merge commit. Use -x to record provenance and
|
||||
# -m1 to pick the first parent (the main branch side).
|
||||
CONFLICT=false
|
||||
if ! git cherry-pick -x -m1 "$MERGE_SHA"; then
|
||||
CONFLICT=true
|
||||
echo "::warning::Cherry-pick to ${RELEASE_BRANCH} had conflicts."
|
||||
|
||||
# Abort the failed cherry-pick and create an empty commit with
|
||||
# instructions so the PR can still be opened.
|
||||
git cherry-pick --abort
|
||||
git commit --allow-empty -m "cherry-pick of #${PR_NUMBER} failed — resolve conflicts manually
|
||||
|
||||
Cherry-pick of ${MERGE_SHA} onto ${RELEASE_BRANCH} had conflicts.
|
||||
To resolve:
|
||||
git fetch origin ${BACKPORT_BRANCH}
|
||||
git checkout ${BACKPORT_BRANCH}
|
||||
git cherry-pick -x -m1 ${MERGE_SHA}
|
||||
# resolve conflicts
|
||||
git push origin ${BACKPORT_BRANCH}"
|
||||
fi
|
||||
|
||||
git push origin "$BACKPORT_BRANCH"
|
||||
|
||||
BODY=$(cat <<EOF
|
||||
Cherry-pick of ${PR_URL}
|
||||
|
||||
Original PR: #${PR_NUMBER} — ${PR_TITLE}
|
||||
Merge commit: ${MERGE_SHA}
|
||||
EOF
|
||||
)
|
||||
|
||||
TITLE="${PR_TITLE} (#${PR_NUMBER})"
|
||||
if [ "$CONFLICT" = true ]; then
|
||||
TITLE="[CONFLICT] ${TITLE}"
|
||||
fi
|
||||
|
||||
# Check if a PR already exists for this branch (idempotency
|
||||
# for re-runs). Use --state all to catch closed/merged PRs too.
|
||||
EXISTING_PR=$(gh pr list --head "$BACKPORT_BRANCH" --base "$RELEASE_BRANCH" --state all --json number --jq '.[0].number // empty')
|
||||
if [ -n "$EXISTING_PR" ]; then
|
||||
echo "PR #${EXISTING_PR} already exists for ${BACKPORT_BRANCH}, skipping."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
gh pr create \
|
||||
--base "$RELEASE_BRANCH" \
|
||||
--head "$BACKPORT_BRANCH" \
|
||||
--title "$TITLE" \
|
||||
--body "$BODY"
|
||||
@@ -121,22 +121,22 @@ jobs:
|
||||
fi
|
||||
|
||||
# Derive the release branch from the version tag.
|
||||
# Non-RC releases must be on a release/X.Y branch.
|
||||
# RC tags are allowed on any branch (typically main).
|
||||
# Standard: 2.10.2 -> release/2.10
|
||||
# RC: 2.32.0-rc.0 -> release/2.32-rc.0
|
||||
version="$(./scripts/version.sh)"
|
||||
# Strip any pre-release suffix first (e.g. 2.32.0-rc.0 -> 2.32.0)
|
||||
base_version="${version%%-*}"
|
||||
# Then strip patch to get major.minor (e.g. 2.32.0 -> 2.32)
|
||||
release_branch="release/${base_version%.*}"
|
||||
|
||||
if [[ "$version" == *-rc.* ]]; then
|
||||
echo "RC release detected — skipping release branch check (RC tags are cut from main)."
|
||||
# Extract major.minor and rc suffix from e.g. 2.32.0-rc.0
|
||||
base_version="${version%%-rc.*}" # 2.32.0
|
||||
major_minor="${base_version%.*}" # 2.32
|
||||
rc_suffix="${version##*-rc.}" # 0
|
||||
release_branch="release/${major_minor}-rc.${rc_suffix}"
|
||||
else
|
||||
branch_contains_tag=$(git branch --remotes --contains "${GITHUB_REF}" --list "*/${release_branch}" --format='%(refname)')
|
||||
if [[ -z "${branch_contains_tag}" ]]; then
|
||||
echo "Ref tag must exist in a branch named ${release_branch} when creating a non-RC release, did you use scripts/release.sh?"
|
||||
exit 1
|
||||
fi
|
||||
release_branch=release/${version%.*}
|
||||
fi
|
||||
branch_contains_tag=$(git branch --remotes --contains "${GITHUB_REF}" --list "*/${release_branch}" --format='%(refname)')
|
||||
if [[ -z "${branch_contains_tag}" ]]; then
|
||||
echo "Ref tag must exist in a branch named ${release_branch} when creating a release, did you use scripts/release.sh?"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ -z "${CODER_RELEASE_NOTES}" ]]; then
|
||||
|
||||
@@ -0,0 +1,354 @@
|
||||
name: security-backport
|
||||
|
||||
on:
|
||||
pull_request_target:
|
||||
types:
|
||||
- labeled
|
||||
- unlabeled
|
||||
- closed
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
pull_request:
|
||||
description: Pull request number to backport.
|
||||
required: true
|
||||
type: string
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
issues: write
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || inputs.pull_request }}
|
||||
cancel-in-progress: false
|
||||
|
||||
env:
|
||||
LATEST_BRANCH: release/2.31
|
||||
STABLE_BRANCH: release/2.30
|
||||
STABLE_1_BRANCH: release/2.29
|
||||
|
||||
jobs:
|
||||
label-policy:
|
||||
if: github.event_name == 'pull_request_target'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Apply security backport label policy
|
||||
shell: bash
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
pr_number="$(jq -r '.pull_request.number' "${GITHUB_EVENT_PATH}")"
|
||||
pr_json="$(gh pr view "${pr_number}" --json number,title,url,baseRefName,labels)"
|
||||
|
||||
PR_JSON="${pr_json}" \
|
||||
python3 - <<'PY'
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
pr = json.loads(os.environ["PR_JSON"])
|
||||
|
||||
pr_number = pr["number"]
|
||||
labels = [label["name"] for label in pr.get("labels", [])]
|
||||
|
||||
def has(label: str) -> bool:
|
||||
return label in labels
|
||||
|
||||
def ensure_label(label: str) -> None:
|
||||
if not has(label):
|
||||
subprocess.run(
|
||||
["gh", "pr", "edit", str(pr_number), "--add-label", label],
|
||||
check=False,
|
||||
)
|
||||
|
||||
def remove_label(label: str) -> None:
|
||||
if has(label):
|
||||
subprocess.run(
|
||||
["gh", "pr", "edit", str(pr_number), "--remove-label", label],
|
||||
check=False,
|
||||
)
|
||||
|
||||
def comment(body: str) -> None:
|
||||
subprocess.run(
|
||||
["gh", "pr", "comment", str(pr_number), "--body", body],
|
||||
check=True,
|
||||
)
|
||||
|
||||
if not has("security:patch"):
|
||||
remove_label("status:needs-severity")
|
||||
sys.exit(0)
|
||||
|
||||
severity_labels = [
|
||||
label
|
||||
for label in ("severity:medium", "severity:high", "severity:critical")
|
||||
if has(label)
|
||||
]
|
||||
if len(severity_labels) == 0:
|
||||
ensure_label("status:needs-severity")
|
||||
comment(
|
||||
"This PR is labeled `security:patch` but is missing a severity "
|
||||
"label. Add one of `severity:medium`, `severity:high`, or "
|
||||
"`severity:critical` before backport automation can proceed."
|
||||
)
|
||||
sys.exit(0)
|
||||
|
||||
if len(severity_labels) > 1:
|
||||
comment(
|
||||
"This PR has multiple severity labels. Keep exactly one of "
|
||||
"`severity:medium`, `severity:high`, or `severity:critical`."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
remove_label("status:needs-severity")
|
||||
|
||||
target_labels = [
|
||||
label
|
||||
for label in ("backport:stable", "backport:stable-1")
|
||||
if has(label)
|
||||
]
|
||||
has_none = has("backport:none")
|
||||
if has_none and target_labels:
|
||||
comment(
|
||||
"`backport:none` cannot be combined with other backport labels. "
|
||||
"Remove `backport:none` or remove the explicit backport targets."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
if not has_none and not target_labels:
|
||||
ensure_label("backport:stable")
|
||||
ensure_label("backport:stable-1")
|
||||
comment(
|
||||
"Applied default backport labels `backport:stable` and "
|
||||
"`backport:stable-1` for a qualifying security patch."
|
||||
)
|
||||
PY
|
||||
|
||||
backport:
|
||||
if: >
|
||||
github.event_name == 'workflow_dispatch' ||
|
||||
(
|
||||
github.event_name == 'pull_request_target' &&
|
||||
github.event.pull_request.merged == true
|
||||
)
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Resolve PR metadata
|
||||
id: metadata
|
||||
shell: bash
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
INPUT_PR_NUMBER: ${{ inputs.pull_request }}
|
||||
LATEST_BRANCH: ${{ env.LATEST_BRANCH }}
|
||||
STABLE_BRANCH: ${{ env.STABLE_BRANCH }}
|
||||
STABLE_1_BRANCH: ${{ env.STABLE_1_BRANCH }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
if [ "${GITHUB_EVENT_NAME}" = "workflow_dispatch" ]; then
|
||||
pr_number="${INPUT_PR_NUMBER}"
|
||||
else
|
||||
pr_number="$(jq -r '.pull_request.number' "${GITHUB_EVENT_PATH}")"
|
||||
fi
|
||||
|
||||
case "${pr_number}" in
|
||||
''|*[!0-9]*)
|
||||
echo "A valid pull request number is required."
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
|
||||
pr_json="$(gh pr view "${pr_number}" --json number,title,url,mergeCommit,baseRefName,labels,mergedAt,author)"
|
||||
|
||||
PR_JSON="${pr_json}" \
|
||||
python3 - <<'PY'
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
pr = json.loads(os.environ["PR_JSON"])
|
||||
github_output = os.environ["GITHUB_OUTPUT"]
|
||||
|
||||
labels = [label["name"] for label in pr.get("labels", [])]
|
||||
if "security:patch" not in labels:
|
||||
print("Not a security patch PR; skipping.")
|
||||
sys.exit(0)
|
||||
|
||||
severity_labels = [
|
||||
label
|
||||
for label in ("severity:medium", "severity:high", "severity:critical")
|
||||
if label in labels
|
||||
]
|
||||
if len(severity_labels) != 1:
|
||||
raise SystemExit(
|
||||
"Merged security patch PR must have exactly one severity label."
|
||||
)
|
||||
|
||||
if not pr.get("mergedAt"):
|
||||
raise SystemExit(f"PR #{pr['number']} is not merged.")
|
||||
|
||||
if "backport:none" in labels:
|
||||
target_pairs = []
|
||||
else:
|
||||
mapping = {
|
||||
"backport:stable": os.environ["STABLE_BRANCH"],
|
||||
"backport:stable-1": os.environ["STABLE_1_BRANCH"],
|
||||
}
|
||||
target_pairs = []
|
||||
for label_name, branch in mapping.items():
|
||||
if label_name in labels and branch and branch != pr["baseRefName"]:
|
||||
target_pairs.append({"label": label_name, "branch": branch})
|
||||
|
||||
with open(github_output, "a", encoding="utf-8") as f:
|
||||
f.write(f"pr_number={pr['number']}\n")
|
||||
f.write(f"merge_sha={pr['mergeCommit']['oid']}\n")
|
||||
f.write(f"title={pr['title']}\n")
|
||||
f.write(f"url={pr['url']}\n")
|
||||
f.write(f"author={pr['author']['login']}\n")
|
||||
f.write(f"severity_label={severity_labels[0]}\n")
|
||||
f.write(f"target_pairs={json.dumps(target_pairs)}\n")
|
||||
PY
|
||||
|
||||
- name: Backport to release branches
|
||||
if: ${{ steps.metadata.outputs.target_pairs != '[]' }}
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
PR_NUMBER: ${{ steps.metadata.outputs.pr_number }}
|
||||
MERGE_SHA: ${{ steps.metadata.outputs.merge_sha }}
|
||||
PR_TITLE: ${{ steps.metadata.outputs.title }}
|
||||
PR_URL: ${{ steps.metadata.outputs.url }}
|
||||
PR_AUTHOR: ${{ steps.metadata.outputs.author }}
|
||||
SEVERITY_LABEL: ${{ steps.metadata.outputs.severity_label }}
|
||||
TARGET_PAIRS: ${{ steps.metadata.outputs.target_pairs }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
|
||||
|
||||
git remote set-url origin "https://x-access-token:${GH_TOKEN}@github.com/${GITHUB_REPOSITORY}.git"
|
||||
git fetch origin --prune
|
||||
|
||||
merge_parent_count="$(git rev-list --parents -n 1 "${MERGE_SHA}" | awk '{print NF-1}')"
|
||||
|
||||
failures=()
|
||||
successes=()
|
||||
|
||||
while IFS=$'\t' read -r backport_label target_branch; do
|
||||
[ -n "${target_branch}" ] || continue
|
||||
|
||||
safe_branch_name="${target_branch//\//-}"
|
||||
head_branch="backport/${safe_branch_name}/pr-${PR_NUMBER}"
|
||||
|
||||
existing_pr="$(gh pr list \
|
||||
--base "${target_branch}" \
|
||||
--head "${head_branch}" \
|
||||
--state all \
|
||||
--json number,url \
|
||||
--jq '.[0]')"
|
||||
if [ -n "${existing_pr}" ] && [ "${existing_pr}" != "null" ]; then
|
||||
pr_url="$(printf '%s' "${existing_pr}" | jq -r '.url')"
|
||||
successes+=("${target_branch}:existing:${pr_url}")
|
||||
continue
|
||||
fi
|
||||
|
||||
git checkout -B "${head_branch}" "origin/${target_branch}"
|
||||
|
||||
if [ "${merge_parent_count}" -gt 1 ]; then
|
||||
cherry_pick_args=(-m 1 "${MERGE_SHA}")
|
||||
else
|
||||
cherry_pick_args=("${MERGE_SHA}")
|
||||
fi
|
||||
|
||||
if ! git cherry-pick -x "${cherry_pick_args[@]}"; then
|
||||
git cherry-pick --abort || true
|
||||
gh pr edit "${PR_NUMBER}" --add-label "backport:conflict" || true
|
||||
gh pr comment "${PR_NUMBER}" --body \
|
||||
"Automatic backport to \`${target_branch}\` conflicted. The original author or release manager should resolve it manually."
|
||||
failures+=("${target_branch}:cherry-pick failed")
|
||||
continue
|
||||
fi
|
||||
|
||||
git push --force-with-lease origin "${head_branch}"
|
||||
|
||||
body_file="$(mktemp)"
|
||||
printf '%s\n' \
|
||||
"Automated backport of [#${PR_NUMBER}](${PR_URL})." \
|
||||
"" \
|
||||
"- Source PR: #${PR_NUMBER}" \
|
||||
"- Source commit: ${MERGE_SHA}" \
|
||||
"- Target branch: ${target_branch}" \
|
||||
"- Severity: ${SEVERITY_LABEL}" \
|
||||
> "${body_file}"
|
||||
|
||||
pr_url="$(gh pr create \
|
||||
--base "${target_branch}" \
|
||||
--head "${head_branch}" \
|
||||
--title "${PR_TITLE} (backport to ${target_branch})" \
|
||||
--body-file "${body_file}")"
|
||||
|
||||
backport_pr_number="$(gh pr list \
|
||||
--base "${target_branch}" \
|
||||
--head "${head_branch}" \
|
||||
--state open \
|
||||
--json number \
|
||||
--jq '.[0].number')"
|
||||
|
||||
gh pr edit "${backport_pr_number}" \
|
||||
--add-label "security:patch" \
|
||||
--add-label "${SEVERITY_LABEL}" \
|
||||
--add-label "${backport_label}" || true
|
||||
|
||||
successes+=("${target_branch}:created:${pr_url}")
|
||||
done < <(
|
||||
python3 - <<'PY'
|
||||
import json
|
||||
import os
|
||||
for pair in json.loads(os.environ["TARGET_PAIRS"]):
|
||||
print(f"{pair['label']}\t{pair['branch']}")
|
||||
PY
|
||||
)
|
||||
|
||||
summary_file="$(mktemp)"
|
||||
{
|
||||
echo "## Security backport summary"
|
||||
echo
|
||||
if [ "${#successes[@]}" -gt 0 ]; then
|
||||
echo "### Created or existing"
|
||||
for entry in "${successes[@]}"; do
|
||||
echo "- ${entry}"
|
||||
done
|
||||
echo
|
||||
fi
|
||||
if [ "${#failures[@]}" -gt 0 ]; then
|
||||
echo "### Failures"
|
||||
for entry in "${failures[@]}"; do
|
||||
echo "- ${entry}"
|
||||
done
|
||||
fi
|
||||
} | tee -a "${GITHUB_STEP_SUMMARY}" > "${summary_file}"
|
||||
|
||||
gh pr comment "${PR_NUMBER}" --body-file "${summary_file}"
|
||||
|
||||
if [ "${#failures[@]}" -gt 0 ]; then
|
||||
printf 'Backport failures:\n%s\n' "${failures[@]}" >&2
|
||||
exit 1
|
||||
fi
|
||||
@@ -0,0 +1,214 @@
|
||||
name: security-patch-prs
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: "0 3 * * 1-5"
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
patch:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
lane:
|
||||
- gomod
|
||||
- terraform
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup Go
|
||||
uses: ./.github/actions/setup-go
|
||||
|
||||
- name: Patch Go dependencies
|
||||
if: matrix.lane == 'gomod'
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
go get -u=patch ./...
|
||||
go mod tidy
|
||||
|
||||
# Guardrail: do not auto-edit replace directives.
|
||||
if git diff --unified=0 -- go.mod | grep -E '^[+-]replace '; then
|
||||
echo "Refusing to auto-edit go.mod replace directives"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Guardrail: only go.mod / go.sum may change.
|
||||
extra="$(git diff --name-only | grep -Ev '^(go\.mod|go\.sum)$' || true)"
|
||||
test -z "$extra" || { echo "Unexpected files changed:"; echo "$extra"; exit 1; }
|
||||
|
||||
- name: Patch bundled Terraform
|
||||
if: matrix.lane == 'terraform'
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
current="$(
|
||||
grep -oE 'NewVersion\("[0-9]+\.[0-9]+\.[0-9]+"\)' \
|
||||
provisioner/terraform/install.go \
|
||||
| head -1 \
|
||||
| grep -oE '[0-9]+\.[0-9]+\.[0-9]+'
|
||||
)"
|
||||
|
||||
series="$(echo "$current" | cut -d. -f1,2)"
|
||||
|
||||
latest="$(
|
||||
curl -fsSL https://releases.hashicorp.com/terraform/index.json \
|
||||
| jq -r --arg series "$series" '
|
||||
.versions
|
||||
| keys[]
|
||||
| select(startswith($series + "."))
|
||||
' \
|
||||
| sort -V \
|
||||
| tail -1
|
||||
)"
|
||||
|
||||
test -n "$latest"
|
||||
[ "$latest" != "$current" ] || exit 0
|
||||
|
||||
CURRENT_TERRAFORM_VERSION="$current" \
|
||||
LATEST_TERRAFORM_VERSION="$latest" \
|
||||
python3 - <<'PY'
|
||||
from pathlib import Path
|
||||
import os
|
||||
|
||||
current = os.environ["CURRENT_TERRAFORM_VERSION"]
|
||||
latest = os.environ["LATEST_TERRAFORM_VERSION"]
|
||||
|
||||
updates = {
|
||||
"scripts/Dockerfile.base": (
|
||||
f"terraform/{current}/",
|
||||
f"terraform/{latest}/",
|
||||
),
|
||||
"provisioner/terraform/install.go": (
|
||||
f'NewVersion("{current}")',
|
||||
f'NewVersion("{latest}")',
|
||||
),
|
||||
"install.sh": (
|
||||
f'TERRAFORM_VERSION="{current}"',
|
||||
f'TERRAFORM_VERSION="{latest}"',
|
||||
),
|
||||
}
|
||||
|
||||
for path_str, (before, after) in updates.items():
|
||||
path = Path(path_str)
|
||||
content = path.read_text()
|
||||
if before not in content:
|
||||
raise SystemExit(f"did not find expected text in {path_str}: {before}")
|
||||
path.write_text(content.replace(before, after))
|
||||
PY
|
||||
|
||||
# Guardrail: only the Terraform-version files may change.
|
||||
extra="$(git diff --name-only | grep -Ev '^(scripts/Dockerfile.base|provisioner/terraform/install.go|install.sh)$' || true)"
|
||||
test -z "$extra" || { echo "Unexpected files changed:"; echo "$extra"; exit 1; }
|
||||
|
||||
- name: Validate Go dependency patch
|
||||
if: matrix.lane == 'gomod'
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
go test ./...
|
||||
|
||||
- name: Validate Terraform patch
|
||||
if: matrix.lane == 'terraform'
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
go test ./provisioner/terraform/...
|
||||
docker build -f scripts/Dockerfile.base .
|
||||
|
||||
- name: Skip PR creation when there are no changes
|
||||
id: changes
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
if git diff --quiet; then
|
||||
echo "has_changes=false" >> "${GITHUB_OUTPUT}"
|
||||
else
|
||||
echo "has_changes=true" >> "${GITHUB_OUTPUT}"
|
||||
fi
|
||||
|
||||
- name: Commit changes
|
||||
if: steps.changes.outputs.has_changes == 'true'
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
|
||||
|
||||
git checkout -B "secpatch/${{ matrix.lane }}"
|
||||
git add -A
|
||||
git commit -m "security: patch ${{ matrix.lane }}"
|
||||
|
||||
- name: Push branch
|
||||
if: steps.changes.outputs.has_changes == 'true'
|
||||
shell: bash
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
git push --force-with-lease \
|
||||
"https://x-access-token:${GH_TOKEN}@github.com/${{ github.repository }}.git" \
|
||||
"HEAD:refs/heads/secpatch/${{ matrix.lane }}"
|
||||
|
||||
- name: Create or update PR
|
||||
if: steps.changes.outputs.has_changes == 'true'
|
||||
shell: bash
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
branch="secpatch/${{ matrix.lane }}"
|
||||
title="security: patch ${{ matrix.lane }}"
|
||||
body="$(cat <<'EOF'
|
||||
Automated security patch PR for `${{ matrix.lane }}`.
|
||||
|
||||
Scope:
|
||||
- gomod: patch-level Go dependency updates only
|
||||
- terraform: bundled Terraform patch updates only
|
||||
|
||||
Guardrails:
|
||||
- no application-code edits
|
||||
- no auto-editing of go.mod replace directives
|
||||
- CI must pass
|
||||
EOF
|
||||
)"
|
||||
|
||||
existing_pr="$(gh pr list --head "${branch}" --base main --json number --jq '.[0].number // empty')"
|
||||
if [[ -n "${existing_pr}" ]]; then
|
||||
gh pr edit "${existing_pr}" \
|
||||
--title "${title}" \
|
||||
--body "${body}"
|
||||
pr_number="${existing_pr}"
|
||||
else
|
||||
gh pr create \
|
||||
--base main \
|
||||
--head "${branch}" \
|
||||
--title "${title}" \
|
||||
--body "${body}"
|
||||
pr_number="$(gh pr list --head "${branch}" --base main --json number --jq '.[0].number // empty')"
|
||||
fi
|
||||
|
||||
for label in security dependencies automated-pr; do
|
||||
if gh label list --json name --jq '.[].name' | grep -Fxq "${label}"; then
|
||||
gh pr edit "${pr_number}" --add-label "${label}"
|
||||
fi
|
||||
done
|
||||
@@ -36,7 +36,6 @@ typ = "typ"
|
||||
styl = "styl"
|
||||
edn = "edn"
|
||||
Inferrable = "Inferrable"
|
||||
IIF = "IIF"
|
||||
|
||||
[files]
|
||||
extend-exclude = [
|
||||
|
||||
@@ -103,6 +103,3 @@ PLAN.md
|
||||
|
||||
# Ignore any dev licenses
|
||||
license.txt
|
||||
-e
|
||||
# Agent planning documents (local working files).
|
||||
docs/plans/
|
||||
|
||||
+6
-14
@@ -102,8 +102,6 @@ type Options struct {
|
||||
ReportMetadataInterval time.Duration
|
||||
ServiceBannerRefreshInterval time.Duration
|
||||
BlockFileTransfer bool
|
||||
BlockReversePortForwarding bool
|
||||
BlockLocalPortForwarding bool
|
||||
Execer agentexec.Execer
|
||||
Devcontainers bool
|
||||
DevcontainerAPIOptions []agentcontainers.Option // Enable Devcontainers for these to be effective.
|
||||
@@ -216,8 +214,6 @@ func New(options Options) Agent {
|
||||
subsystems: options.Subsystems,
|
||||
logSender: agentsdk.NewLogSender(options.Logger),
|
||||
blockFileTransfer: options.BlockFileTransfer,
|
||||
blockReversePortForwarding: options.BlockReversePortForwarding,
|
||||
blockLocalPortForwarding: options.BlockLocalPortForwarding,
|
||||
|
||||
prometheusRegistry: prometheusRegistry,
|
||||
metrics: newAgentMetrics(prometheusRegistry),
|
||||
@@ -284,8 +280,6 @@ type agent struct {
|
||||
sshServer *agentssh.Server
|
||||
sshMaxTimeout time.Duration
|
||||
blockFileTransfer bool
|
||||
blockReversePortForwarding bool
|
||||
blockLocalPortForwarding bool
|
||||
|
||||
lifecycleUpdate chan struct{}
|
||||
lifecycleReported chan codersdk.WorkspaceAgentLifecycle
|
||||
@@ -337,14 +331,12 @@ func (a *agent) TailnetConn() *tailnet.Conn {
|
||||
func (a *agent) init() {
|
||||
// pass the "hard" context because we explicitly close the SSH server as part of graceful shutdown.
|
||||
sshSrv, err := agentssh.NewServer(a.hardCtx, a.logger.Named("ssh-server"), a.prometheusRegistry, a.filesystem, a.execer, &agentssh.Config{
|
||||
MaxTimeout: a.sshMaxTimeout,
|
||||
MOTDFile: func() string { return a.manifest.Load().MOTDFile },
|
||||
AnnouncementBanners: func() *[]codersdk.BannerConfig { return a.announcementBanners.Load() },
|
||||
UpdateEnv: a.updateCommandEnv,
|
||||
WorkingDirectory: func() string { return a.manifest.Load().Directory },
|
||||
BlockFileTransfer: a.blockFileTransfer,
|
||||
BlockReversePortForwarding: a.blockReversePortForwarding,
|
||||
BlockLocalPortForwarding: a.blockLocalPortForwarding,
|
||||
MaxTimeout: a.sshMaxTimeout,
|
||||
MOTDFile: func() string { return a.manifest.Load().MOTDFile },
|
||||
AnnouncementBanners: func() *[]codersdk.BannerConfig { return a.announcementBanners.Load() },
|
||||
UpdateEnv: a.updateCommandEnv,
|
||||
WorkingDirectory: func() string { return a.manifest.Load().Directory },
|
||||
BlockFileTransfer: a.blockFileTransfer,
|
||||
ReportConnection: func(id uuid.UUID, magicType agentssh.MagicSessionType, ip string) func(code int, reason string) {
|
||||
var connectionType proto.Connection_Type
|
||||
switch magicType {
|
||||
|
||||
@@ -986,161 +986,6 @@ func TestAgent_TCPRemoteForwarding(t *testing.T) {
|
||||
requireEcho(t, conn)
|
||||
}
|
||||
|
||||
func TestAgent_TCPLocalForwardingBlocked(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
rl, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer rl.Close()
|
||||
tcpAddr, valid := rl.Addr().(*net.TCPAddr)
|
||||
require.True(t, valid)
|
||||
remotePort := tcpAddr.Port
|
||||
|
||||
//nolint:dogsled
|
||||
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||||
o.BlockLocalPortForwarding = true
|
||||
})
|
||||
sshClient, err := agentConn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
|
||||
_, err = sshClient.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", remotePort))
|
||||
require.ErrorContains(t, err, "administratively prohibited")
|
||||
}
|
||||
|
||||
func TestAgent_TCPRemoteForwardingBlocked(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
//nolint:dogsled
|
||||
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||||
o.BlockReversePortForwarding = true
|
||||
})
|
||||
sshClient, err := agentConn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
|
||||
localhost := netip.MustParseAddr("127.0.0.1")
|
||||
randomPort := testutil.RandomPortNoListen(t)
|
||||
addr := net.TCPAddrFromAddrPort(netip.AddrPortFrom(localhost, randomPort))
|
||||
_, err = sshClient.ListenTCP(addr)
|
||||
require.ErrorContains(t, err, "tcpip-forward request denied by peer")
|
||||
}
|
||||
|
||||
func TestAgent_UnixLocalForwardingBlocked(t *testing.T) {
|
||||
t.Parallel()
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("unix domain sockets are not fully supported on Windows")
|
||||
}
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
tmpdir := testutil.TempDirUnixSocket(t)
|
||||
remoteSocketPath := filepath.Join(tmpdir, "remote-socket")
|
||||
|
||||
l, err := net.Listen("unix", remoteSocketPath)
|
||||
require.NoError(t, err)
|
||||
defer l.Close()
|
||||
|
||||
//nolint:dogsled
|
||||
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||||
o.BlockLocalPortForwarding = true
|
||||
})
|
||||
sshClient, err := agentConn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
|
||||
_, err = sshClient.Dial("unix", remoteSocketPath)
|
||||
require.ErrorContains(t, err, "administratively prohibited")
|
||||
}
|
||||
|
||||
func TestAgent_UnixRemoteForwardingBlocked(t *testing.T) {
|
||||
t.Parallel()
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("unix domain sockets are not fully supported on Windows")
|
||||
}
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
tmpdir := testutil.TempDirUnixSocket(t)
|
||||
remoteSocketPath := filepath.Join(tmpdir, "remote-socket")
|
||||
|
||||
//nolint:dogsled
|
||||
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||||
o.BlockReversePortForwarding = true
|
||||
})
|
||||
sshClient, err := agentConn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
|
||||
_, err = sshClient.ListenUnix(remoteSocketPath)
|
||||
require.ErrorContains(t, err, "streamlocal-forward@openssh.com request denied by peer")
|
||||
}
|
||||
|
||||
// TestAgent_LocalBlockedDoesNotAffectReverse verifies that blocking
|
||||
// local port forwarding does not prevent reverse port forwarding from
|
||||
// working. A field-name transposition at any plumbing hop would cause
|
||||
// both directions to be blocked when only one flag is set.
|
||||
func TestAgent_LocalBlockedDoesNotAffectReverse(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
//nolint:dogsled
|
||||
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||||
o.BlockLocalPortForwarding = true
|
||||
})
|
||||
sshClient, err := agentConn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
|
||||
// Reverse forwarding must still work.
|
||||
localhost := netip.MustParseAddr("127.0.0.1")
|
||||
var ll net.Listener
|
||||
for {
|
||||
randomPort := testutil.RandomPortNoListen(t)
|
||||
addr := net.TCPAddrFromAddrPort(netip.AddrPortFrom(localhost, randomPort))
|
||||
ll, err = sshClient.ListenTCP(addr)
|
||||
if err != nil {
|
||||
t.Logf("error remote forwarding: %s", err.Error())
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out getting random listener")
|
||||
default:
|
||||
continue
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
_ = ll.Close()
|
||||
}
|
||||
|
||||
// TestAgent_ReverseBlockedDoesNotAffectLocal verifies that blocking
|
||||
// reverse port forwarding does not prevent local port forwarding from
|
||||
// working.
|
||||
func TestAgent_ReverseBlockedDoesNotAffectLocal(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
rl, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer rl.Close()
|
||||
tcpAddr, valid := rl.Addr().(*net.TCPAddr)
|
||||
require.True(t, valid)
|
||||
remotePort := tcpAddr.Port
|
||||
go echoOnce(t, rl)
|
||||
|
||||
//nolint:dogsled
|
||||
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||||
o.BlockReversePortForwarding = true
|
||||
})
|
||||
sshClient, err := agentConn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
|
||||
// Local forwarding must still work.
|
||||
conn, err := sshClient.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", remotePort))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
requireEcho(t, conn)
|
||||
}
|
||||
|
||||
func TestAgent_UnixLocalForwarding(t *testing.T) {
|
||||
t.Parallel()
|
||||
if runtime.GOOS == "windows" {
|
||||
|
||||
@@ -117,10 +117,6 @@ type Config struct {
|
||||
X11MaxPort *int
|
||||
// BlockFileTransfer restricts use of file transfer applications.
|
||||
BlockFileTransfer bool
|
||||
// BlockReversePortForwarding disables reverse port forwarding (ssh -R).
|
||||
BlockReversePortForwarding bool
|
||||
// BlockLocalPortForwarding disables local port forwarding (ssh -L).
|
||||
BlockLocalPortForwarding bool
|
||||
// ReportConnection.
|
||||
ReportConnection reportConnectionFunc
|
||||
// Experimental: allow connecting to running containers via Docker exec.
|
||||
@@ -194,7 +190,7 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
|
||||
}
|
||||
|
||||
forwardHandler := &ssh.ForwardedTCPHandler{}
|
||||
unixForwardHandler := newForwardedUnixHandler(logger, config.BlockReversePortForwarding)
|
||||
unixForwardHandler := newForwardedUnixHandler(logger)
|
||||
|
||||
metrics := newSSHServerMetrics(prometheusRegistry)
|
||||
s := &Server{
|
||||
@@ -233,15 +229,8 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
|
||||
wrapped := NewJetbrainsChannelWatcher(ctx, s.logger, s.config.ReportConnection, newChan, &s.connCountJetBrains)
|
||||
ssh.DirectTCPIPHandler(srv, conn, wrapped, ctx)
|
||||
},
|
||||
"direct-streamlocal@openssh.com": func(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) {
|
||||
if s.config.BlockLocalPortForwarding {
|
||||
s.logger.Warn(ctx, "unix local port forward blocked")
|
||||
_ = newChan.Reject(gossh.Prohibited, "local port forwarding is disabled")
|
||||
return
|
||||
}
|
||||
directStreamLocalHandler(srv, conn, newChan, ctx)
|
||||
},
|
||||
"session": ssh.DefaultSessionHandler,
|
||||
"direct-streamlocal@openssh.com": directStreamLocalHandler,
|
||||
"session": ssh.DefaultSessionHandler,
|
||||
},
|
||||
ConnectionFailedCallback: func(conn net.Conn, err error) {
|
||||
s.logger.Warn(ctx, "ssh connection failed",
|
||||
@@ -261,12 +250,6 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
|
||||
// be set before we start listening.
|
||||
HostSigners: []ssh.Signer{},
|
||||
LocalPortForwardingCallback: func(ctx ssh.Context, destinationHost string, destinationPort uint32) bool {
|
||||
if s.config.BlockLocalPortForwarding {
|
||||
s.logger.Warn(ctx, "local port forward blocked",
|
||||
slog.F("destination_host", destinationHost),
|
||||
slog.F("destination_port", destinationPort))
|
||||
return false
|
||||
}
|
||||
// Allow local port forwarding all!
|
||||
s.logger.Debug(ctx, "local port forward",
|
||||
slog.F("destination_host", destinationHost),
|
||||
@@ -277,12 +260,6 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
|
||||
return true
|
||||
},
|
||||
ReversePortForwardingCallback: func(ctx ssh.Context, bindHost string, bindPort uint32) bool {
|
||||
if s.config.BlockReversePortForwarding {
|
||||
s.logger.Warn(ctx, "reverse port forward blocked",
|
||||
slog.F("bind_host", bindHost),
|
||||
slog.F("bind_port", bindPort))
|
||||
return false
|
||||
}
|
||||
// Allow reverse port forwarding all!
|
||||
s.logger.Debug(ctx, "reverse port forward",
|
||||
slog.F("bind_host", bindHost),
|
||||
|
||||
@@ -35,9 +35,8 @@ type forwardedStreamLocalPayload struct {
|
||||
// streamlocal forwarding (aka. unix forwarding) instead of TCP forwarding.
|
||||
type forwardedUnixHandler struct {
|
||||
sync.Mutex
|
||||
log slog.Logger
|
||||
forwards map[forwardKey]net.Listener
|
||||
blockReversePortForwarding bool
|
||||
log slog.Logger
|
||||
forwards map[forwardKey]net.Listener
|
||||
}
|
||||
|
||||
type forwardKey struct {
|
||||
@@ -45,11 +44,10 @@ type forwardKey struct {
|
||||
addr string
|
||||
}
|
||||
|
||||
func newForwardedUnixHandler(log slog.Logger, blockReversePortForwarding bool) *forwardedUnixHandler {
|
||||
func newForwardedUnixHandler(log slog.Logger) *forwardedUnixHandler {
|
||||
return &forwardedUnixHandler{
|
||||
log: log,
|
||||
forwards: make(map[forwardKey]net.Listener),
|
||||
blockReversePortForwarding: blockReversePortForwarding,
|
||||
log: log,
|
||||
forwards: make(map[forwardKey]net.Listener),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -64,10 +62,6 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
|
||||
|
||||
switch req.Type {
|
||||
case "streamlocal-forward@openssh.com":
|
||||
if h.blockReversePortForwarding {
|
||||
log.Warn(ctx, "unix reverse port forward blocked")
|
||||
return false, nil
|
||||
}
|
||||
var reqPayload streamLocalForwardPayload
|
||||
err := gossh.Unmarshal(req.Payload, &reqPayload)
|
||||
if err != nil {
|
||||
|
||||
+4
-22
@@ -53,8 +53,6 @@ func workspaceAgent() *serpent.Command {
|
||||
slogJSONPath string
|
||||
slogStackdriverPath string
|
||||
blockFileTransfer bool
|
||||
blockReversePortForwarding bool
|
||||
blockLocalPortForwarding bool
|
||||
agentHeaderCommand string
|
||||
agentHeader []string
|
||||
devcontainers bool
|
||||
@@ -321,12 +319,10 @@ func workspaceAgent() *serpent.Command {
|
||||
SSHMaxTimeout: sshMaxTimeout,
|
||||
Subsystems: subsystems,
|
||||
|
||||
PrometheusRegistry: prometheusRegistry,
|
||||
BlockFileTransfer: blockFileTransfer,
|
||||
BlockReversePortForwarding: blockReversePortForwarding,
|
||||
BlockLocalPortForwarding: blockLocalPortForwarding,
|
||||
Execer: execer,
|
||||
Devcontainers: devcontainers,
|
||||
PrometheusRegistry: prometheusRegistry,
|
||||
BlockFileTransfer: blockFileTransfer,
|
||||
Execer: execer,
|
||||
Devcontainers: devcontainers,
|
||||
DevcontainerAPIOptions: []agentcontainers.Option{
|
||||
agentcontainers.WithSubAgentURL(agentAuth.agentURL.String()),
|
||||
agentcontainers.WithProjectDiscovery(devcontainerProjectDiscovery),
|
||||
@@ -497,20 +493,6 @@ func workspaceAgent() *serpent.Command {
|
||||
Description: fmt.Sprintf("Block file transfer using known applications: %s.", strings.Join(agentssh.BlockedFileTransferCommands, ",")),
|
||||
Value: serpent.BoolOf(&blockFileTransfer),
|
||||
},
|
||||
{
|
||||
Flag: "block-reverse-port-forwarding",
|
||||
Default: "false",
|
||||
Env: "CODER_AGENT_BLOCK_REVERSE_PORT_FORWARDING",
|
||||
Description: "Block reverse port forwarding through the SSH server (ssh -R).",
|
||||
Value: serpent.BoolOf(&blockReversePortForwarding),
|
||||
},
|
||||
{
|
||||
Flag: "block-local-port-forwarding",
|
||||
Default: "false",
|
||||
Env: "CODER_AGENT_BLOCK_LOCAL_PORT_FORWARDING",
|
||||
Description: "Block local port forwarding through the SSH server (ssh -L).",
|
||||
Value: serpent.BoolOf(&blockLocalPortForwarding),
|
||||
},
|
||||
{
|
||||
Flag: "devcontainers-enable",
|
||||
Default: "true",
|
||||
|
||||
@@ -768,30 +768,10 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
|
||||
return xerrors.Errorf("create pubsub: %w", err)
|
||||
}
|
||||
options.Pubsub = ps
|
||||
options.ChatPubsub = ps
|
||||
if options.DeploymentValues.Prometheus.Enable {
|
||||
options.PrometheusRegistry.MustRegister(ps)
|
||||
}
|
||||
defer options.Pubsub.Close()
|
||||
chatPubsub, err := pubsub.NewBatching(
|
||||
ctx,
|
||||
logger.Named("chatd").Named("pubsub_batch"),
|
||||
ps,
|
||||
sqlDB,
|
||||
dbURL,
|
||||
pubsub.BatchingConfig{
|
||||
FlushInterval: options.DeploymentValues.AI.Chat.PubsubFlushInterval.Value(),
|
||||
QueueSize: int(options.DeploymentValues.AI.Chat.PubsubQueueSize.Value()),
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create chat pubsub batcher: %w", err)
|
||||
}
|
||||
options.ChatPubsub = chatPubsub
|
||||
if options.DeploymentValues.Prometheus.Enable {
|
||||
options.PrometheusRegistry.MustRegister(chatPubsub)
|
||||
}
|
||||
defer options.ChatPubsub.Close()
|
||||
psWatchdog := pubsub.NewWatchdog(ctx, logger.Named("pswatch"), ps)
|
||||
pubsubWatchdogTimeout = psWatchdog.Timeout()
|
||||
defer psWatchdog.Close()
|
||||
|
||||
+17
-97
@@ -52,10 +52,6 @@ import (
|
||||
|
||||
const (
|
||||
disableUsageApp = "disable"
|
||||
|
||||
// Retry transient errors during SSH connection establishment.
|
||||
sshRetryInterval = 2 * time.Second
|
||||
sshMaxAttempts = 10 // initial + retries per step
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -66,51 +62,6 @@ var (
|
||||
workspaceNameRe = regexp.MustCompile(`[/.]+|--`)
|
||||
)
|
||||
|
||||
// isRetryableError checks for transient connection errors worth
|
||||
// retrying: DNS failures, connection refused, and server 5xx.
|
||||
func isRetryableError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
|
||||
return false
|
||||
}
|
||||
if codersdk.IsConnectionError(err) {
|
||||
return true
|
||||
}
|
||||
var sdkErr *codersdk.Error
|
||||
if xerrors.As(err, &sdkErr) {
|
||||
return sdkErr.StatusCode() >= 500
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// retryWithInterval calls fn up to maxAttempts times, waiting
|
||||
// interval between attempts. Stops on success, non-retryable
|
||||
// error, or context cancellation.
|
||||
func retryWithInterval(ctx context.Context, logger slog.Logger, interval time.Duration, maxAttempts int, fn func() error) error {
|
||||
var lastErr error
|
||||
attempt := 0
|
||||
for r := retry.New(interval, interval); r.Wait(ctx); {
|
||||
lastErr = fn()
|
||||
if lastErr == nil || !isRetryableError(lastErr) {
|
||||
return lastErr
|
||||
}
|
||||
attempt++
|
||||
if attempt >= maxAttempts {
|
||||
break
|
||||
}
|
||||
logger.Warn(ctx, "transient error, retrying",
|
||||
slog.Error(lastErr),
|
||||
slog.F("attempt", attempt),
|
||||
)
|
||||
}
|
||||
if lastErr != nil {
|
||||
return lastErr
|
||||
}
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
func (r *RootCmd) ssh() *serpent.Command {
|
||||
var (
|
||||
stdio bool
|
||||
@@ -326,17 +277,10 @@ func (r *RootCmd) ssh() *serpent.Command {
|
||||
HostnameSuffix: hostnameSuffix,
|
||||
}
|
||||
|
||||
// Populated by the closure below.
|
||||
var workspace codersdk.Workspace
|
||||
var workspaceAgent codersdk.WorkspaceAgent
|
||||
resolveWorkspace := func() error {
|
||||
var err error
|
||||
workspace, workspaceAgent, err = findWorkspaceAndAgentByHostname(
|
||||
ctx, inv, client,
|
||||
inv.Args[0], cliConfig, disableAutostart)
|
||||
return err
|
||||
}
|
||||
if err := retryWithInterval(ctx, logger, sshRetryInterval, sshMaxAttempts, resolveWorkspace); err != nil {
|
||||
workspace, workspaceAgent, err := findWorkspaceAndAgentByHostname(
|
||||
ctx, inv, client,
|
||||
inv.Args[0], cliConfig, disableAutostart)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -362,13 +306,8 @@ func (r *RootCmd) ssh() *serpent.Command {
|
||||
wait = false
|
||||
}
|
||||
|
||||
var templateVersion codersdk.TemplateVersion
|
||||
fetchVersion := func() error {
|
||||
var err error
|
||||
templateVersion, err = client.TemplateVersion(ctx, workspace.LatestBuild.TemplateVersionID)
|
||||
return err
|
||||
}
|
||||
if err := retryWithInterval(ctx, logger, sshRetryInterval, sshMaxAttempts, fetchVersion); err != nil {
|
||||
templateVersion, err := client.TemplateVersion(ctx, workspace.LatestBuild.TemplateVersionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -408,12 +347,8 @@ func (r *RootCmd) ssh() *serpent.Command {
|
||||
// If we're in stdio mode, check to see if we can use Coder Connect.
|
||||
// We don't support Coder Connect over non-stdio coder ssh yet.
|
||||
if stdio && !forceNewTunnel {
|
||||
var connInfo workspacesdk.AgentConnectionInfo
|
||||
if err := retryWithInterval(ctx, logger, sshRetryInterval, sshMaxAttempts, func() error {
|
||||
var err error
|
||||
connInfo, err = wsClient.AgentConnectionInfoGeneric(ctx)
|
||||
return err
|
||||
}); err != nil {
|
||||
connInfo, err := wsClient.AgentConnectionInfoGeneric(ctx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get agent connection info: %w", err)
|
||||
}
|
||||
coderConnectHost := fmt.Sprintf("%s.%s.%s.%s",
|
||||
@@ -449,27 +384,23 @@ func (r *RootCmd) ssh() *serpent.Command {
|
||||
})
|
||||
defer closeUsage()
|
||||
}
|
||||
return runCoderConnectStdio(ctx, fmt.Sprintf("%s:22", coderConnectHost), stdioReader, stdioWriter, stack, logger)
|
||||
return runCoderConnectStdio(ctx, fmt.Sprintf("%s:22", coderConnectHost), stdioReader, stdioWriter, stack)
|
||||
}
|
||||
}
|
||||
|
||||
if r.disableDirect {
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "Direct connections disabled.")
|
||||
}
|
||||
var conn workspacesdk.AgentConn
|
||||
if err := retryWithInterval(ctx, logger, sshRetryInterval, sshMaxAttempts, func() error {
|
||||
var err error
|
||||
conn, err = wsClient.DialAgent(ctx, workspaceAgent.ID, &workspacesdk.DialAgentOptions{
|
||||
conn, err := wsClient.
|
||||
DialAgent(ctx, workspaceAgent.ID, &workspacesdk.DialAgentOptions{
|
||||
Logger: logger,
|
||||
BlockEndpoints: r.disableDirect,
|
||||
EnableTelemetry: !r.disableNetworkTelemetry,
|
||||
})
|
||||
return err
|
||||
}); err != nil {
|
||||
if err != nil {
|
||||
return xerrors.Errorf("dial agent: %w", err)
|
||||
}
|
||||
if err = stack.push("agent conn", conn); err != nil {
|
||||
_ = conn.Close()
|
||||
return err
|
||||
}
|
||||
conn.AwaitReachable(ctx)
|
||||
@@ -1647,27 +1578,16 @@ func WithTestOnlyCoderConnectDialer(ctx context.Context, dialer coderConnectDial
|
||||
func testOrDefaultDialer(ctx context.Context) coderConnectDialer {
|
||||
dialer, ok := ctx.Value(coderConnectDialerContextKey{}).(coderConnectDialer)
|
||||
if !ok || dialer == nil {
|
||||
// Timeout prevents hanging on broken tunnels (OS default is very long).
|
||||
return &net.Dialer{
|
||||
Timeout: 5 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}
|
||||
return &net.Dialer{}
|
||||
}
|
||||
return dialer
|
||||
}
|
||||
|
||||
func runCoderConnectStdio(ctx context.Context, addr string, stdin io.Reader, stdout io.Writer, stack *closerStack, logger slog.Logger) error {
|
||||
func runCoderConnectStdio(ctx context.Context, addr string, stdin io.Reader, stdout io.Writer, stack *closerStack) error {
|
||||
dialer := testOrDefaultDialer(ctx)
|
||||
var conn net.Conn
|
||||
if err := retryWithInterval(ctx, logger, sshRetryInterval, sshMaxAttempts, func() error {
|
||||
var err error
|
||||
conn, err = dialer.DialContext(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("dial coder connect host %q over tcp: %w", addr, err)
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return err
|
||||
conn, err := dialer.DialContext(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("dial coder connect host: %w", err)
|
||||
}
|
||||
if err := stack.push("tcp conn", conn); err != nil {
|
||||
return err
|
||||
|
||||
+1
-149
@@ -5,9 +5,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -228,41 +226,6 @@ func TestCloserStack_Timeout(t *testing.T) {
|
||||
testutil.TryReceive(ctx, t, closed)
|
||||
}
|
||||
|
||||
func TestCloserStack_PushAfterClose_ConnClosed(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
uut := newCloserStack(ctx, logger, quartz.NewMock(t))
|
||||
|
||||
uut.close(xerrors.New("canceled"))
|
||||
|
||||
closes := new([]*fakeCloser)
|
||||
fc := &fakeCloser{closes: closes}
|
||||
err := uut.push("conn", fc)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, []*fakeCloser{fc}, *closes, "should close conn on failed push")
|
||||
}
|
||||
|
||||
func TestCoderConnectDialer_DefaultTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
dialer := testOrDefaultDialer(ctx)
|
||||
d, ok := dialer.(*net.Dialer)
|
||||
require.True(t, ok, "expected *net.Dialer")
|
||||
assert.Equal(t, 5*time.Second, d.Timeout)
|
||||
assert.Equal(t, 30*time.Second, d.KeepAlive)
|
||||
}
|
||||
|
||||
func TestCoderConnectDialer_Overridden(t *testing.T) {
|
||||
t.Parallel()
|
||||
custom := &net.Dialer{Timeout: 99 * time.Second}
|
||||
ctx := WithTestOnlyCoderConnectDialer(context.Background(), custom)
|
||||
|
||||
dialer := testOrDefaultDialer(ctx)
|
||||
assert.Equal(t, custom, dialer)
|
||||
}
|
||||
|
||||
func TestCoderConnectStdio(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -291,7 +254,7 @@ func TestCoderConnectStdio(t *testing.T) {
|
||||
|
||||
stdioDone := make(chan struct{})
|
||||
go func() {
|
||||
err = runCoderConnectStdio(ctx, ln.Addr().String(), clientOutput, serverInput, stack, logger)
|
||||
err = runCoderConnectStdio(ctx, ln.Addr().String(), clientOutput, serverInput, stack)
|
||||
assert.NoError(t, err)
|
||||
close(stdioDone)
|
||||
}()
|
||||
@@ -485,114 +448,3 @@ func Test_getWorkspaceAgent(t *testing.T) {
|
||||
assert.Contains(t, err.Error(), "available agents: [clark krypton zod]")
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsRetryableError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
retryable bool
|
||||
}{
|
||||
{"Nil", nil, false},
|
||||
{"ContextCanceled", context.Canceled, false},
|
||||
{"ContextDeadlineExceeded", context.DeadlineExceeded, false},
|
||||
{"WrappedContextCanceled", xerrors.Errorf("wrapped: %w", context.Canceled), false},
|
||||
{"DNSError", &net.DNSError{Err: "no such host", Name: "example.com", IsNotFound: true}, true},
|
||||
{"OpError", &net.OpError{Op: "dial", Net: "tcp", Err: &os.SyscallError{}}, true},
|
||||
{"WrappedDNSError", xerrors.Errorf("connect: %w", &net.DNSError{Err: "no such host", Name: "example.com"}), true},
|
||||
{"SDKError_500", codersdk.NewTestError(http.StatusInternalServerError, "GET", "/api"), true},
|
||||
{"SDKError_502", codersdk.NewTestError(http.StatusBadGateway, "GET", "/api"), true},
|
||||
{"SDKError_503", codersdk.NewTestError(http.StatusServiceUnavailable, "GET", "/api"), true},
|
||||
{"SDKError_401", codersdk.NewTestError(http.StatusUnauthorized, "GET", "/api"), false},
|
||||
{"SDKError_403", codersdk.NewTestError(http.StatusForbidden, "GET", "/api"), false},
|
||||
{"SDKError_404", codersdk.NewTestError(http.StatusNotFound, "GET", "/api"), false},
|
||||
{"GenericError", xerrors.New("something went wrong"), false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.Equal(t, tt.retryable, isRetryableError(tt.err))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryWithInterval(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const interval = time.Millisecond
|
||||
const maxAttempts = 3
|
||||
|
||||
dnsErr := &net.DNSError{Err: "no such host", Name: "example.com", IsNotFound: true}
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
|
||||
t.Run("Succeeds_FirstTry", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
attempts := 0
|
||||
err := retryWithInterval(ctx, logger, interval, maxAttempts, func() error {
|
||||
attempts++
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, attempts)
|
||||
})
|
||||
|
||||
t.Run("Succeeds_AfterTransientFailures", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
attempts := 0
|
||||
err := retryWithInterval(ctx, logger, interval, maxAttempts, func() error {
|
||||
attempts++
|
||||
if attempts < 3 {
|
||||
return dnsErr
|
||||
}
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, attempts)
|
||||
})
|
||||
|
||||
t.Run("Stops_NonRetryableError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
attempts := 0
|
||||
err := retryWithInterval(ctx, logger, interval, maxAttempts, func() error {
|
||||
attempts++
|
||||
return xerrors.New("permanent failure")
|
||||
})
|
||||
require.ErrorContains(t, err, "permanent failure")
|
||||
assert.Equal(t, 1, attempts)
|
||||
})
|
||||
|
||||
t.Run("Stops_MaxAttemptsExhausted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
attempts := 0
|
||||
err := retryWithInterval(ctx, logger, interval, maxAttempts, func() error {
|
||||
attempts++
|
||||
return dnsErr
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, maxAttempts, attempts)
|
||||
})
|
||||
|
||||
t.Run("Stops_ContextCanceled", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
attempts := 0
|
||||
err := retryWithInterval(ctx, logger, interval, maxAttempts, func() error {
|
||||
attempts++
|
||||
cancel()
|
||||
return dnsErr
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, 1, attempts)
|
||||
})
|
||||
}
|
||||
|
||||
-6
@@ -39,12 +39,6 @@ OPTIONS:
|
||||
--block-file-transfer bool, $CODER_AGENT_BLOCK_FILE_TRANSFER (default: false)
|
||||
Block file transfer using known applications: nc,rsync,scp,sftp.
|
||||
|
||||
--block-local-port-forwarding bool, $CODER_AGENT_BLOCK_LOCAL_PORT_FORWARDING (default: false)
|
||||
Block local port forwarding through the SSH server (ssh -L).
|
||||
|
||||
--block-reverse-port-forwarding bool, $CODER_AGENT_BLOCK_REVERSE_PORT_FORWARDING (default: false)
|
||||
Block reverse port forwarding through the SSH server (ssh -R).
|
||||
|
||||
--boundary-log-proxy-socket-path string, $CODER_AGENT_BOUNDARY_LOG_PROXY_SOCKET_PATH (default: /tmp/boundary-audit.sock)
|
||||
The path for the boundary log proxy server Unix socket. Boundary
|
||||
should write audit logs to this socket.
|
||||
|
||||
@@ -134,7 +134,6 @@ func TestUserCreate(t *testing.T) {
|
||||
{
|
||||
name: "ServiceAccount",
|
||||
args: []string{"--service-account", "-u", "dean"},
|
||||
err: "Premium feature",
|
||||
},
|
||||
{
|
||||
name: "ServiceAccountLoginType",
|
||||
|
||||
@@ -77,9 +77,8 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea
|
||||
level := make([]database.LogLevel, 0)
|
||||
outputLength := 0
|
||||
for _, logEntry := range req.Logs {
|
||||
sanitizedOutput := agentsdk.SanitizeLogOutput(logEntry.Output)
|
||||
output = append(output, sanitizedOutput)
|
||||
outputLength += len(sanitizedOutput)
|
||||
output = append(output, logEntry.Output)
|
||||
outputLength += len(logEntry.Output)
|
||||
|
||||
var dbLevel database.LogLevel
|
||||
switch logEntry.Level {
|
||||
|
||||
@@ -139,59 +139,6 @@ func TestBatchCreateLogs(t *testing.T) {
|
||||
require.True(t, publishWorkspaceAgentLogsUpdateCalled)
|
||||
})
|
||||
|
||||
t.Run("SanitizesOutput", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbM := dbmock.NewMockStore(gomock.NewController(t))
|
||||
now := dbtime.Now()
|
||||
api := &agentapi.LogsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Database: dbM,
|
||||
Log: testutil.Logger(t),
|
||||
TimeNowFn: func() time.Time {
|
||||
return now
|
||||
},
|
||||
}
|
||||
|
||||
rawOutput := "before\x00middle\xc3\x28after"
|
||||
sanitizedOutput := agentsdk.SanitizeLogOutput(rawOutput)
|
||||
expectedOutputLength := int32(len(sanitizedOutput)) //nolint:gosec // Test-controlled string length is small.
|
||||
req := &agentproto.BatchCreateLogsRequest{
|
||||
LogSourceId: logSource.ID[:],
|
||||
Logs: []*agentproto.Log{
|
||||
{
|
||||
CreatedAt: timestamppb.New(now),
|
||||
Level: agentproto.Log_WARN,
|
||||
Output: rawOutput,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
dbM.EXPECT().InsertWorkspaceAgentLogs(gomock.Any(), database.InsertWorkspaceAgentLogsParams{
|
||||
AgentID: agent.ID,
|
||||
LogSourceID: logSource.ID,
|
||||
CreatedAt: now,
|
||||
Output: []string{sanitizedOutput},
|
||||
Level: []database.LogLevel{database.LogLevelWarn},
|
||||
OutputLength: expectedOutputLength,
|
||||
}).Return([]database.WorkspaceAgentLog{
|
||||
{
|
||||
AgentID: agent.ID,
|
||||
CreatedAt: now,
|
||||
ID: 1,
|
||||
Output: sanitizedOutput,
|
||||
Level: database.LogLevelWarn,
|
||||
LogSourceID: logSource.ID,
|
||||
},
|
||||
}, nil)
|
||||
|
||||
resp, err := api.BatchCreateLogs(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &agentproto.BatchCreateLogsResponse{}, resp)
|
||||
})
|
||||
|
||||
t.Run("NoWorkspacePublishIfNotFirstLogs", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
Generated
-78
@@ -1266,68 +1266,6 @@ const docTemplate = `{
|
||||
]
|
||||
}
|
||||
},
|
||||
"/experimental/chats/config/retention-days": {
|
||||
"get": {
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Chats"
|
||||
],
|
||||
"summary": "Get chat retention days",
|
||||
"operationId": "get-chat-retention-days",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.ChatRetentionDaysResponse"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"x-apidocgen": {
|
||||
"skip": true
|
||||
}
|
||||
},
|
||||
"put": {
|
||||
"consumes": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Chats"
|
||||
],
|
||||
"summary": "Update chat retention days",
|
||||
"operationId": "update-chat-retention-days",
|
||||
"parameters": [
|
||||
{
|
||||
"description": "Request body",
|
||||
"name": "request",
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UpdateChatRetentionDaysRequest"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"204": {
|
||||
"description": "No Content"
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"x-apidocgen": {
|
||||
"skip": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"/experimental/watch-all-workspacebuilds": {
|
||||
"get": {
|
||||
"produces": [
|
||||
@@ -14482,14 +14420,6 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.ChatRetentionDaysResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"retention_days": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.ConnectionLatency": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -21022,14 +20952,6 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UpdateChatRetentionDaysRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"retention_days": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UpdateCheckResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
||||
Generated
-70
@@ -1103,60 +1103,6 @@
|
||||
]
|
||||
}
|
||||
},
|
||||
"/experimental/chats/config/retention-days": {
|
||||
"get": {
|
||||
"produces": ["application/json"],
|
||||
"tags": ["Chats"],
|
||||
"summary": "Get chat retention days",
|
||||
"operationId": "get-chat-retention-days",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.ChatRetentionDaysResponse"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"x-apidocgen": {
|
||||
"skip": true
|
||||
}
|
||||
},
|
||||
"put": {
|
||||
"consumes": ["application/json"],
|
||||
"tags": ["Chats"],
|
||||
"summary": "Update chat retention days",
|
||||
"operationId": "update-chat-retention-days",
|
||||
"parameters": [
|
||||
{
|
||||
"description": "Request body",
|
||||
"name": "request",
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UpdateChatRetentionDaysRequest"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"204": {
|
||||
"description": "No Content"
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"x-apidocgen": {
|
||||
"skip": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"/experimental/watch-all-workspacebuilds": {
|
||||
"get": {
|
||||
"produces": ["application/json"],
|
||||
@@ -13017,14 +12963,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.ChatRetentionDaysResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"retention_days": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.ConnectionLatency": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -19305,14 +19243,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UpdateChatRetentionDaysRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"retention_days": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UpdateCheckResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
||||
+2
-13
@@ -159,10 +159,7 @@ type Options struct {
|
||||
Logger slog.Logger
|
||||
Database database.Store
|
||||
Pubsub pubsub.Pubsub
|
||||
// ChatPubsub allows chatd to use a dedicated publish path without changing
|
||||
// the shared pubsub used by the rest of coderd.
|
||||
ChatPubsub pubsub.Pubsub
|
||||
RuntimeConfig *runtimeconfig.Manager
|
||||
RuntimeConfig *runtimeconfig.Manager
|
||||
|
||||
// CacheDir is used for caching files served by the API.
|
||||
CacheDir string
|
||||
@@ -780,11 +777,6 @@ func New(options *Options) *API {
|
||||
maxChatsPerAcquire = math.MinInt32
|
||||
}
|
||||
|
||||
chatPubsub := options.ChatPubsub
|
||||
if chatPubsub == nil {
|
||||
chatPubsub = options.Pubsub
|
||||
}
|
||||
|
||||
api.chatDaemon = chatd.New(chatd.Config{
|
||||
Logger: options.Logger.Named("chatd"),
|
||||
Database: options.Database,
|
||||
@@ -797,7 +789,7 @@ func New(options *Options) *API {
|
||||
InstructionLookupTimeout: options.ChatdInstructionLookupTimeout,
|
||||
CreateWorkspace: api.chatCreateWorkspace,
|
||||
StartWorkspace: api.chatStartWorkspace,
|
||||
Pubsub: chatPubsub,
|
||||
Pubsub: options.Pubsub,
|
||||
WebpushDispatcher: options.WebPushDispatcher,
|
||||
UsageTracker: options.WorkspaceUsageTracker,
|
||||
})
|
||||
@@ -1197,8 +1189,6 @@ func New(options *Options) *API {
|
||||
r.Delete("/user-compaction-thresholds/{modelConfig}", api.deleteUserChatCompactionThreshold)
|
||||
r.Get("/workspace-ttl", api.getChatWorkspaceTTL)
|
||||
r.Put("/workspace-ttl", api.putChatWorkspaceTTL)
|
||||
r.Get("/retention-days", api.getChatRetentionDays)
|
||||
r.Put("/retention-days", api.putChatRetentionDays)
|
||||
r.Get("/template-allowlist", api.getChatTemplateAllowlist)
|
||||
r.Put("/template-allowlist", api.putChatTemplateAllowlist)
|
||||
})
|
||||
@@ -1253,7 +1243,6 @@ func New(options *Options) *API {
|
||||
r.Get("/git", api.watchChatGit)
|
||||
})
|
||||
r.Post("/interrupt", api.interruptChat)
|
||||
r.Post("/tool-results", api.postChatToolResults)
|
||||
r.Post("/title/regenerate", api.regenerateChatTitle)
|
||||
r.Get("/diff", api.getChatDiffContents)
|
||||
r.Route("/queue/{queuedMessage}", func(r chi.Router) {
|
||||
|
||||
@@ -123,10 +123,6 @@ func UsersPagination(
|
||||
require.Contains(t, gotUsers[0].Name, "after")
|
||||
}
|
||||
|
||||
type UsersFilterOptions struct {
|
||||
CreateServiceAccounts bool
|
||||
}
|
||||
|
||||
// UsersFilter creates a set of users to run various filters against for
|
||||
// testing. It can be used to test filtering both users and group members.
|
||||
func UsersFilter(
|
||||
@@ -134,16 +130,11 @@ func UsersFilter(
|
||||
t *testing.T,
|
||||
client *codersdk.Client,
|
||||
db database.Store,
|
||||
options *UsersFilterOptions,
|
||||
setup func(users []codersdk.User),
|
||||
fetch func(ctx context.Context, req codersdk.UsersRequest) []codersdk.ReducedUser,
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
if options == nil {
|
||||
options = &UsersFilterOptions{}
|
||||
}
|
||||
|
||||
firstUser, err := client.User(setupCtx, codersdk.Me)
|
||||
require.NoError(t, err, "fetch me")
|
||||
|
||||
@@ -220,13 +211,11 @@ func UsersFilter(
|
||||
}
|
||||
|
||||
// Add some service accounts.
|
||||
if options.CreateServiceAccounts {
|
||||
for range 3 {
|
||||
_, user := CreateAnotherUserMutators(t, client, orgID, nil, func(r *codersdk.CreateUserRequestWithOrgs) {
|
||||
r.ServiceAccount = true
|
||||
})
|
||||
users = append(users, user)
|
||||
}
|
||||
for range 3 {
|
||||
_, user := CreateAnotherUserMutators(t, client, orgID, nil, func(r *codersdk.CreateUserRequestWithOrgs) {
|
||||
r.ServiceAccount = true
|
||||
})
|
||||
users = append(users, user)
|
||||
}
|
||||
|
||||
hashedPassword, err := userpassword.Hash("SomeStrongPassword!")
|
||||
|
||||
@@ -1715,41 +1715,3 @@ func ChatDiffStatus(chatID uuid.UUID, status *database.ChatDiffStatus) codersdk.
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// UserSecret converts a database ListUserSecretsRow (metadata only,
|
||||
// no value) to an SDK UserSecret.
|
||||
func UserSecret(secret database.ListUserSecretsRow) codersdk.UserSecret {
|
||||
return codersdk.UserSecret{
|
||||
ID: secret.ID,
|
||||
Name: secret.Name,
|
||||
Description: secret.Description,
|
||||
EnvName: secret.EnvName,
|
||||
FilePath: secret.FilePath,
|
||||
CreatedAt: secret.CreatedAt,
|
||||
UpdatedAt: secret.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
// UserSecretFromFull converts a full database UserSecret row to an
|
||||
// SDK UserSecret, omitting the value and encryption key ID.
|
||||
func UserSecretFromFull(secret database.UserSecret) codersdk.UserSecret {
|
||||
return codersdk.UserSecret{
|
||||
ID: secret.ID,
|
||||
Name: secret.Name,
|
||||
Description: secret.Description,
|
||||
EnvName: secret.EnvName,
|
||||
FilePath: secret.FilePath,
|
||||
CreatedAt: secret.CreatedAt,
|
||||
UpdatedAt: secret.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
// UserSecrets converts a slice of database ListUserSecretsRow to
|
||||
// SDK UserSecret values.
|
||||
func UserSecrets(secrets []database.ListUserSecretsRow) []codersdk.UserSecret {
|
||||
result := make([]codersdk.UserSecret, 0, len(secrets))
|
||||
for _, s := range secrets {
|
||||
result = append(result, UserSecret(s))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -552,10 +552,6 @@ func TestChat_AllFieldsPopulated(t *testing.T) {
|
||||
RawMessage: json.RawMessage(`[{"type":"context-file","context_file_path":"/AGENTS.md"}]`),
|
||||
Valid: true,
|
||||
},
|
||||
DynamicTools: pqtype.NullRawMessage{
|
||||
RawMessage: json.RawMessage(`[{"name":"tool1","description":"test tool","inputSchema":{"type":"object"}}]`),
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
// Only ChatID is needed here. This test checks that
|
||||
// Chat.DiffStatus is non-nil, not that every DiffStatus
|
||||
|
||||
@@ -2031,20 +2031,6 @@ func (q *querier) DeleteOldAuditLogs(ctx context.Context, arg database.DeleteOld
|
||||
return q.db.DeleteOldAuditLogs(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteOldChatFiles(ctx context.Context, arg database.DeleteOldChatFilesParams) (int64, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceSystem); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return q.db.DeleteOldChatFiles(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteOldChats(ctx context.Context, arg database.DeleteOldChatsParams) (int64, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceSystem); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return q.db.DeleteOldChats(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteOldConnectionLogs(ctx context.Context, arg database.DeleteOldConnectionLogsParams) (int64, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceSystem); err != nil {
|
||||
return 0, err
|
||||
@@ -2169,12 +2155,17 @@ func (q *querier) DeleteUserChatProviderKey(ctx context.Context, arg database.De
|
||||
return q.db.DeleteUserChatProviderKey(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) error {
|
||||
obj := rbac.ResourceUserSecret.WithOwner(arg.UserID.String())
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, obj); err != nil {
|
||||
func (q *querier) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
|
||||
// First get the secret to check ownership
|
||||
secret, err := q.GetUserSecret(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteUserSecretByUserIDAndName(ctx, arg)
|
||||
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, secret); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.DeleteUserSecret(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg database.DeleteWebpushSubscriptionByUserIDAndEndpointParams) error {
|
||||
@@ -2636,14 +2627,6 @@ func (q *querier) GetChatMessageByID(ctx context.Context, id int64) (database.Ch
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (q *querier) GetChatMessageSummariesPerChat(ctx context.Context, createdAfter time.Time) ([]database.GetChatMessageSummariesPerChatRow, error) {
|
||||
// Telemetry queries are called from system contexts only.
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetChatMessageSummariesPerChat(ctx, createdAfter)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatMessagesByChatID(ctx context.Context, arg database.GetChatMessagesByChatIDParams) ([]database.ChatMessage, error) {
|
||||
// Authorize read on the parent chat.
|
||||
_, err := q.GetChatByID(ctx, arg.ChatID)
|
||||
@@ -2692,14 +2675,6 @@ func (q *querier) GetChatModelConfigs(ctx context.Context) ([]database.ChatModel
|
||||
return q.db.GetChatModelConfigs(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatModelConfigsForTelemetry(ctx context.Context) ([]database.GetChatModelConfigsForTelemetryRow, error) {
|
||||
// Telemetry queries are called from system contexts only.
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetChatModelConfigsForTelemetry(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
@@ -2729,15 +2704,6 @@ func (q *querier) GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) (
|
||||
return q.db.GetChatQueuedMessages(ctx, chatID)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatRetentionDays(ctx context.Context) (int32, error) {
|
||||
// Chat retention is a deployment-wide config read by dbpurge.
|
||||
// Only requires a valid actor in context.
|
||||
if _, ok := ActorFromContext(ctx); !ok {
|
||||
return 0, ErrNoActor
|
||||
}
|
||||
return q.db.GetChatRetentionDays(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatSystemPrompt(ctx context.Context) (string, error) {
|
||||
// The system prompt is a deployment-wide setting read during chat
|
||||
// creation by every authenticated user, so no RBAC policy check
|
||||
@@ -2816,14 +2782,6 @@ func (q *querier) GetChatsByWorkspaceIDs(ctx context.Context, ids []uuid.UUID) (
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetChatsByWorkspaceIDs)(ctx, ids)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatsUpdatedAfter(ctx context.Context, updatedAfter time.Time) ([]database.GetChatsUpdatedAfterRow, error) {
|
||||
// Telemetry queries are called from system contexts only.
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetChatsUpdatedAfter(ctx, updatedAfter)
|
||||
}
|
||||
|
||||
func (q *querier) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
|
||||
// Just like with the audit logs query, shortcut if the user is an owner.
|
||||
err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceConnectionLog)
|
||||
@@ -4170,6 +4128,19 @@ func (q *querier) GetUserNotificationPreferences(ctx context.Context, userID uui
|
||||
return q.db.GetUserNotificationPreferences(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) GetUserSecret(ctx context.Context, id uuid.UUID) (database.UserSecret, error) {
|
||||
// First get the secret to check ownership
|
||||
secret, err := q.db.GetUserSecret(ctx, id)
|
||||
if err != nil {
|
||||
return database.UserSecret{}, err
|
||||
}
|
||||
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, secret); err != nil {
|
||||
return database.UserSecret{}, err
|
||||
}
|
||||
return secret, nil
|
||||
}
|
||||
|
||||
func (q *querier) GetUserSecretByUserIDAndName(ctx context.Context, arg database.GetUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
|
||||
obj := rbac.ResourceUserSecret.WithOwner(arg.UserID.String())
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, obj); err != nil {
|
||||
@@ -5553,7 +5524,7 @@ func (q *querier) ListUserChatCompactionThresholds(ctx context.Context, userID u
|
||||
return q.db.ListUserChatCompactionThresholds(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.ListUserSecretsRow, error) {
|
||||
func (q *querier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
|
||||
obj := rbac.ResourceUserSecret.WithOwner(userID.String())
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, obj); err != nil {
|
||||
return nil, err
|
||||
@@ -5561,16 +5532,6 @@ func (q *querier) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]data
|
||||
return q.db.ListUserSecrets(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
|
||||
// This query returns decrypted secret values and must only be called
|
||||
// from system contexts (provisioner, agent manifest). REST API
|
||||
// handlers should use ListUserSecrets (metadata only).
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.ListUserSecretsWithValues(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]database.WorkspaceAgentPortShare, error) {
|
||||
workspace, err := q.db.GetWorkspaceByID(ctx, workspaceID)
|
||||
if err != nil {
|
||||
@@ -6671,12 +6632,17 @@ func (q *querier) UpdateUserRoles(ctx context.Context, arg database.UpdateUserRo
|
||||
return q.db.UpdateUserRoles(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateUserSecretByUserIDAndName(ctx context.Context, arg database.UpdateUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
|
||||
obj := rbac.ResourceUserSecret.WithOwner(arg.UserID.String())
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, obj); err != nil {
|
||||
func (q *querier) UpdateUserSecret(ctx context.Context, arg database.UpdateUserSecretParams) (database.UserSecret, error) {
|
||||
// First get the secret to check ownership
|
||||
secret, err := q.db.GetUserSecret(ctx, arg.ID)
|
||||
if err != nil {
|
||||
return database.UserSecret{}, err
|
||||
}
|
||||
return q.db.UpdateUserSecretByUserIDAndName(ctx, arg)
|
||||
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, secret); err != nil {
|
||||
return database.UserSecret{}, err
|
||||
}
|
||||
return q.db.UpdateUserSecret(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateUserStatus(ctx context.Context, arg database.UpdateUserStatusParams) (database.User, error) {
|
||||
@@ -7078,13 +7044,6 @@ func (q *querier) UpsertChatIncludeDefaultSystemPrompt(ctx context.Context, incl
|
||||
return q.db.UpsertChatIncludeDefaultSystemPrompt(ctx, includeDefaultSystemPrompt)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertChatRetentionDays(ctx context.Context, retentionDays int32) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.UpsertChatRetentionDays(ctx, retentionDays)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertChatSystemPrompt(ctx context.Context, value string) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
|
||||
@@ -600,22 +600,6 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().GetChatFileMetadataByChatID(gomock.Any(), file.ID).Return(rows, nil).AnyTimes()
|
||||
check.Args(file.ID).Asserts(rbac.ResourceChat.WithOwner(file.OwnerID.String()).InOrg(file.OrganizationID).WithID(file.ID), policy.ActionRead).Returns(rows)
|
||||
}))
|
||||
s.Run("DeleteOldChatFiles", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().DeleteOldChatFiles(gomock.Any(), database.DeleteOldChatFilesParams{}).Return(int64(0), nil).AnyTimes()
|
||||
check.Args(database.DeleteOldChatFilesParams{}).Asserts(rbac.ResourceSystem, policy.ActionDelete)
|
||||
}))
|
||||
s.Run("DeleteOldChats", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().DeleteOldChats(gomock.Any(), database.DeleteOldChatsParams{}).Return(int64(0), nil).AnyTimes()
|
||||
check.Args(database.DeleteOldChatsParams{}).Asserts(rbac.ResourceSystem, policy.ActionDelete)
|
||||
}))
|
||||
s.Run("GetChatRetentionDays", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().GetChatRetentionDays(gomock.Any()).Return(int32(30), nil).AnyTimes()
|
||||
check.Args().Asserts()
|
||||
}))
|
||||
s.Run("UpsertChatRetentionDays", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().UpsertChatRetentionDays(gomock.Any(), int32(30)).Return(nil).AnyTimes()
|
||||
check.Args(int32(30)).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("GetChatMessageByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
msg := testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})
|
||||
@@ -4012,20 +3996,6 @@ func (s *MethodTestSuite) TestSystemFunctions() {
|
||||
dbm.EXPECT().GetWorkspaceAgentsCreatedAfter(gomock.Any(), ts).Return([]database.WorkspaceAgent{}, nil).AnyTimes()
|
||||
check.Args(ts).Asserts(rbac.ResourceSystem, policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetChatsUpdatedAfter", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
ts := dbtime.Now()
|
||||
dbm.EXPECT().GetChatsUpdatedAfter(gomock.Any(), ts).Return([]database.GetChatsUpdatedAfterRow{}, nil).AnyTimes()
|
||||
check.Args(ts).Asserts(rbac.ResourceSystem, policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetChatMessageSummariesPerChat", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
ts := dbtime.Now()
|
||||
dbm.EXPECT().GetChatMessageSummariesPerChat(gomock.Any(), ts).Return([]database.GetChatMessageSummariesPerChatRow{}, nil).AnyTimes()
|
||||
check.Args(ts).Asserts(rbac.ResourceSystem, policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetChatModelConfigsForTelemetry", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().GetChatModelConfigsForTelemetry(gomock.Any()).Return([]database.GetChatModelConfigsForTelemetryRow{}, nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceSystem, policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetWorkspaceAppsCreatedAfter", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
ts := dbtime.Now()
|
||||
dbm.EXPECT().GetWorkspaceAppsCreatedAfter(gomock.Any(), ts).Return([]database.WorkspaceApp{}, nil).AnyTimes()
|
||||
@@ -5376,20 +5346,19 @@ func (s *MethodTestSuite) TestUserSecrets() {
|
||||
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionRead).
|
||||
Returns(secret)
|
||||
}))
|
||||
s.Run("GetUserSecret", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
secret := testutil.Fake(s.T(), faker, database.UserSecret{})
|
||||
dbm.EXPECT().GetUserSecret(gomock.Any(), secret.ID).Return(secret, nil).AnyTimes()
|
||||
check.Args(secret.ID).
|
||||
Asserts(secret, policy.ActionRead).
|
||||
Returns(secret)
|
||||
}))
|
||||
s.Run("ListUserSecrets", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
user := testutil.Fake(s.T(), faker, database.User{})
|
||||
row := testutil.Fake(s.T(), faker, database.ListUserSecretsRow{UserID: user.ID})
|
||||
dbm.EXPECT().ListUserSecrets(gomock.Any(), user.ID).Return([]database.ListUserSecretsRow{row}, nil).AnyTimes()
|
||||
secret := testutil.Fake(s.T(), faker, database.UserSecret{UserID: user.ID})
|
||||
dbm.EXPECT().ListUserSecrets(gomock.Any(), user.ID).Return([]database.UserSecret{secret}, nil).AnyTimes()
|
||||
check.Args(user.ID).
|
||||
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionRead).
|
||||
Returns([]database.ListUserSecretsRow{row})
|
||||
}))
|
||||
s.Run("ListUserSecretsWithValues", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
user := testutil.Fake(s.T(), faker, database.User{})
|
||||
secret := testutil.Fake(s.T(), faker, database.UserSecret{UserID: user.ID})
|
||||
dbm.EXPECT().ListUserSecretsWithValues(gomock.Any(), user.ID).Return([]database.UserSecret{secret}, nil).AnyTimes()
|
||||
check.Args(user.ID).
|
||||
Asserts(rbac.ResourceSystem, policy.ActionRead).
|
||||
Returns([]database.UserSecret{secret})
|
||||
}))
|
||||
s.Run("CreateUserSecret", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
@@ -5401,21 +5370,22 @@ func (s *MethodTestSuite) TestUserSecrets() {
|
||||
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionCreate).
|
||||
Returns(ret)
|
||||
}))
|
||||
s.Run("UpdateUserSecretByUserIDAndName", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
user := testutil.Fake(s.T(), faker, database.User{})
|
||||
updated := testutil.Fake(s.T(), faker, database.UserSecret{UserID: user.ID})
|
||||
arg := database.UpdateUserSecretByUserIDAndNameParams{UserID: user.ID, Name: "test"}
|
||||
dbm.EXPECT().UpdateUserSecretByUserIDAndName(gomock.Any(), arg).Return(updated, nil).AnyTimes()
|
||||
s.Run("UpdateUserSecret", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
secret := testutil.Fake(s.T(), faker, database.UserSecret{})
|
||||
updated := testutil.Fake(s.T(), faker, database.UserSecret{ID: secret.ID})
|
||||
arg := database.UpdateUserSecretParams{ID: secret.ID}
|
||||
dbm.EXPECT().GetUserSecret(gomock.Any(), secret.ID).Return(secret, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateUserSecret(gomock.Any(), arg).Return(updated, nil).AnyTimes()
|
||||
check.Args(arg).
|
||||
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionUpdate).
|
||||
Asserts(secret, policy.ActionUpdate).
|
||||
Returns(updated)
|
||||
}))
|
||||
s.Run("DeleteUserSecretByUserIDAndName", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
user := testutil.Fake(s.T(), faker, database.User{})
|
||||
arg := database.DeleteUserSecretByUserIDAndNameParams{UserID: user.ID, Name: "test"}
|
||||
dbm.EXPECT().DeleteUserSecretByUserIDAndName(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
check.Args(arg).
|
||||
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionDelete).
|
||||
s.Run("DeleteUserSecret", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
secret := testutil.Fake(s.T(), faker, database.UserSecret{})
|
||||
dbm.EXPECT().GetUserSecret(gomock.Any(), secret.ID).Return(secret, nil).AnyTimes()
|
||||
dbm.EXPECT().DeleteUserSecret(gomock.Any(), secret.ID).Return(nil).AnyTimes()
|
||||
check.Args(secret.ID).
|
||||
Asserts(secret, policy.ActionRead, secret, policy.ActionDelete).
|
||||
Returns()
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -1597,7 +1597,6 @@ func UserSecret(t testing.TB, db database.Store, seed database.UserSecret) datab
|
||||
Name: takeFirst(seed.Name, "secret-name"),
|
||||
Description: takeFirst(seed.Description, "secret description"),
|
||||
Value: takeFirst(seed.Value, "secret value"),
|
||||
ValueKeyID: seed.ValueKeyID,
|
||||
EnvName: takeFirst(seed.EnvName, "SECRET_ENV_NAME"),
|
||||
FilePath: takeFirst(seed.FilePath, "~/secret/file/path"),
|
||||
})
|
||||
@@ -1644,8 +1643,6 @@ func AIBridgeInterception(t testing.TB, db database.Store, seed database.InsertA
|
||||
ThreadParentInterceptionID: seed.ThreadParentInterceptionID,
|
||||
ThreadRootInterceptionID: seed.ThreadRootInterceptionID,
|
||||
ClientSessionID: seed.ClientSessionID,
|
||||
CredentialKind: takeFirst(seed.CredentialKind, database.CredentialKindCentralized),
|
||||
CredentialHint: takeFirst(seed.CredentialHint, ""),
|
||||
})
|
||||
if endedAt != nil {
|
||||
interception, err = db.UpdateAIBridgeInterceptionEnded(genCtx, database.UpdateAIBridgeInterceptionEndedParams{
|
||||
|
||||
@@ -592,22 +592,6 @@ func (m queryMetricsStore) DeleteOldAuditLogs(ctx context.Context, arg database.
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteOldChatFiles(ctx context.Context, arg database.DeleteOldChatFilesParams) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.DeleteOldChatFiles(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("DeleteOldChatFiles").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteOldChatFiles").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteOldChats(ctx context.Context, arg database.DeleteOldChatsParams) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.DeleteOldChats(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("DeleteOldChats").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteOldChats").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteOldConnectionLogs(ctx context.Context, arg database.DeleteOldConnectionLogsParams) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.DeleteOldConnectionLogs(ctx, arg)
|
||||
@@ -728,11 +712,11 @@ func (m queryMetricsStore) DeleteUserChatProviderKey(ctx context.Context, arg da
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) error {
|
||||
func (m queryMetricsStore) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteUserSecretByUserIDAndName(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("DeleteUserSecretByUserIDAndName").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteUserSecretByUserIDAndName").Inc()
|
||||
r0 := m.s.DeleteUserSecret(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("DeleteUserSecret").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteUserSecret").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
@@ -1176,14 +1160,6 @@ func (m queryMetricsStore) GetChatMessageByID(ctx context.Context, id int64) (da
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatMessageSummariesPerChat(ctx context.Context, createdAfter time.Time) ([]database.GetChatMessageSummariesPerChatRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatMessageSummariesPerChat(ctx, createdAfter)
|
||||
m.queryLatencies.WithLabelValues("GetChatMessageSummariesPerChat").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatMessageSummariesPerChat").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatMessagesByChatID(ctx context.Context, chatID database.GetChatMessagesByChatIDParams) ([]database.ChatMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatMessagesByChatID(ctx, chatID)
|
||||
@@ -1232,14 +1208,6 @@ func (m queryMetricsStore) GetChatModelConfigs(ctx context.Context) ([]database.
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatModelConfigsForTelemetry(ctx context.Context) ([]database.GetChatModelConfigsForTelemetryRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatModelConfigsForTelemetry(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetChatModelConfigsForTelemetry").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatModelConfigsForTelemetry").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatProviderByID(ctx, id)
|
||||
@@ -1272,14 +1240,6 @@ func (m queryMetricsStore) GetChatQueuedMessages(ctx context.Context, chatID uui
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatRetentionDays(ctx context.Context) (int32, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatRetentionDays(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetChatRetentionDays").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatRetentionDays").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatSystemPrompt(ctx context.Context) (string, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatSystemPrompt(ctx)
|
||||
@@ -1352,14 +1312,6 @@ func (m queryMetricsStore) GetChatsByWorkspaceIDs(ctx context.Context, ids []uui
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatsUpdatedAfter(ctx context.Context, updatedAfter time.Time) ([]database.GetChatsUpdatedAfterRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatsUpdatedAfter(ctx, updatedAfter)
|
||||
m.queryLatencies.WithLabelValues("GetChatsUpdatedAfter").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatsUpdatedAfter").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetConnectionLogsOffset(ctx, arg)
|
||||
@@ -2672,6 +2624,14 @@ func (m queryMetricsStore) GetUserNotificationPreferences(ctx context.Context, u
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetUserSecret(ctx context.Context, id uuid.UUID) (database.UserSecret, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetUserSecret(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("GetUserSecret").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetUserSecret").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetUserSecretByUserIDAndName(ctx context.Context, arg database.GetUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetUserSecretByUserIDAndName(ctx, arg)
|
||||
@@ -3960,7 +3920,7 @@ func (m queryMetricsStore) ListUserChatCompactionThresholds(ctx context.Context,
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.ListUserSecretsRow, error) {
|
||||
func (m queryMetricsStore) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListUserSecrets(ctx, userID)
|
||||
m.queryLatencies.WithLabelValues("ListUserSecrets").Observe(time.Since(start).Seconds())
|
||||
@@ -3968,14 +3928,6 @@ func (m queryMetricsStore) ListUserSecrets(ctx context.Context, userID uuid.UUID
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListUserSecretsWithValues(ctx, userID)
|
||||
m.queryLatencies.WithLabelValues("ListUserSecretsWithValues").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListUserSecretsWithValues").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]database.WorkspaceAgentPortShare, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListWorkspaceAgentPortShares(ctx, workspaceID)
|
||||
@@ -4744,11 +4696,11 @@ func (m queryMetricsStore) UpdateUserRoles(ctx context.Context, arg database.Upd
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateUserSecretByUserIDAndName(ctx context.Context, arg database.UpdateUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
|
||||
func (m queryMetricsStore) UpdateUserSecret(ctx context.Context, arg database.UpdateUserSecretParams) (database.UserSecret, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateUserSecretByUserIDAndName(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateUserSecretByUserIDAndName").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateUserSecretByUserIDAndName").Inc()
|
||||
r0, r1 := m.s.UpdateUserSecret(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateUserSecret").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateUserSecret").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
@@ -5048,14 +5000,6 @@ func (m queryMetricsStore) UpsertChatIncludeDefaultSystemPrompt(ctx context.Cont
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertChatRetentionDays(ctx context.Context, retentionDays int32) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpsertChatRetentionDays(ctx, retentionDays)
|
||||
m.queryLatencies.WithLabelValues("UpsertChatRetentionDays").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatRetentionDays").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertChatSystemPrompt(ctx context.Context, value string) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpsertChatSystemPrompt(ctx, value)
|
||||
|
||||
@@ -984,36 +984,6 @@ func (mr *MockStoreMockRecorder) DeleteOldAuditLogs(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOldAuditLogs", reflect.TypeOf((*MockStore)(nil).DeleteOldAuditLogs), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteOldChatFiles mocks base method.
|
||||
func (m *MockStore) DeleteOldChatFiles(ctx context.Context, arg database.DeleteOldChatFilesParams) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteOldChatFiles", ctx, arg)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// DeleteOldChatFiles indicates an expected call of DeleteOldChatFiles.
|
||||
func (mr *MockStoreMockRecorder) DeleteOldChatFiles(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOldChatFiles", reflect.TypeOf((*MockStore)(nil).DeleteOldChatFiles), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteOldChats mocks base method.
|
||||
func (m *MockStore) DeleteOldChats(ctx context.Context, arg database.DeleteOldChatsParams) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteOldChats", ctx, arg)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// DeleteOldChats indicates an expected call of DeleteOldChats.
|
||||
func (mr *MockStoreMockRecorder) DeleteOldChats(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOldChats", reflect.TypeOf((*MockStore)(nil).DeleteOldChats), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteOldConnectionLogs mocks base method.
|
||||
func (m *MockStore) DeleteOldConnectionLogs(ctx context.Context, arg database.DeleteOldConnectionLogsParams) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1229,18 +1199,18 @@ func (mr *MockStoreMockRecorder) DeleteUserChatProviderKey(ctx, arg any) *gomock
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserChatProviderKey", reflect.TypeOf((*MockStore)(nil).DeleteUserChatProviderKey), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteUserSecretByUserIDAndName mocks base method.
|
||||
func (m *MockStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) error {
|
||||
// DeleteUserSecret mocks base method.
|
||||
func (m *MockStore) DeleteUserSecret(ctx context.Context, id uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteUserSecretByUserIDAndName", ctx, arg)
|
||||
ret := m.ctrl.Call(m, "DeleteUserSecret", ctx, id)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteUserSecretByUserIDAndName indicates an expected call of DeleteUserSecretByUserIDAndName.
|
||||
func (mr *MockStoreMockRecorder) DeleteUserSecretByUserIDAndName(ctx, arg any) *gomock.Call {
|
||||
// DeleteUserSecret indicates an expected call of DeleteUserSecret.
|
||||
func (mr *MockStoreMockRecorder) DeleteUserSecret(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserSecretByUserIDAndName", reflect.TypeOf((*MockStore)(nil).DeleteUserSecretByUserIDAndName), ctx, arg)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserSecret", reflect.TypeOf((*MockStore)(nil).DeleteUserSecret), ctx, id)
|
||||
}
|
||||
|
||||
// DeleteWebpushSubscriptionByUserIDAndEndpoint mocks base method.
|
||||
@@ -2162,21 +2132,6 @@ func (mr *MockStoreMockRecorder) GetChatMessageByID(ctx, id any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatMessageByID", reflect.TypeOf((*MockStore)(nil).GetChatMessageByID), ctx, id)
|
||||
}
|
||||
|
||||
// GetChatMessageSummariesPerChat mocks base method.
|
||||
func (m *MockStore) GetChatMessageSummariesPerChat(ctx context.Context, createdAfter time.Time) ([]database.GetChatMessageSummariesPerChatRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatMessageSummariesPerChat", ctx, createdAfter)
|
||||
ret0, _ := ret[0].([]database.GetChatMessageSummariesPerChatRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatMessageSummariesPerChat indicates an expected call of GetChatMessageSummariesPerChat.
|
||||
func (mr *MockStoreMockRecorder) GetChatMessageSummariesPerChat(ctx, createdAfter any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatMessageSummariesPerChat", reflect.TypeOf((*MockStore)(nil).GetChatMessageSummariesPerChat), ctx, createdAfter)
|
||||
}
|
||||
|
||||
// GetChatMessagesByChatID mocks base method.
|
||||
func (m *MockStore) GetChatMessagesByChatID(ctx context.Context, arg database.GetChatMessagesByChatIDParams) ([]database.ChatMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2267,21 +2222,6 @@ func (mr *MockStoreMockRecorder) GetChatModelConfigs(ctx any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatModelConfigs", reflect.TypeOf((*MockStore)(nil).GetChatModelConfigs), ctx)
|
||||
}
|
||||
|
||||
// GetChatModelConfigsForTelemetry mocks base method.
|
||||
func (m *MockStore) GetChatModelConfigsForTelemetry(ctx context.Context) ([]database.GetChatModelConfigsForTelemetryRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatModelConfigsForTelemetry", ctx)
|
||||
ret0, _ := ret[0].([]database.GetChatModelConfigsForTelemetryRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatModelConfigsForTelemetry indicates an expected call of GetChatModelConfigsForTelemetry.
|
||||
func (mr *MockStoreMockRecorder) GetChatModelConfigsForTelemetry(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatModelConfigsForTelemetry", reflect.TypeOf((*MockStore)(nil).GetChatModelConfigsForTelemetry), ctx)
|
||||
}
|
||||
|
||||
// GetChatProviderByID mocks base method.
|
||||
func (m *MockStore) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2342,21 +2282,6 @@ func (mr *MockStoreMockRecorder) GetChatQueuedMessages(ctx, chatID any) *gomock.
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatQueuedMessages", reflect.TypeOf((*MockStore)(nil).GetChatQueuedMessages), ctx, chatID)
|
||||
}
|
||||
|
||||
// GetChatRetentionDays mocks base method.
|
||||
func (m *MockStore) GetChatRetentionDays(ctx context.Context) (int32, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatRetentionDays", ctx)
|
||||
ret0, _ := ret[0].(int32)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatRetentionDays indicates an expected call of GetChatRetentionDays.
|
||||
func (mr *MockStoreMockRecorder) GetChatRetentionDays(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatRetentionDays", reflect.TypeOf((*MockStore)(nil).GetChatRetentionDays), ctx)
|
||||
}
|
||||
|
||||
// GetChatSystemPrompt mocks base method.
|
||||
func (m *MockStore) GetChatSystemPrompt(ctx context.Context) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2492,21 +2417,6 @@ func (mr *MockStoreMockRecorder) GetChatsByWorkspaceIDs(ctx, ids any) *gomock.Ca
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatsByWorkspaceIDs", reflect.TypeOf((*MockStore)(nil).GetChatsByWorkspaceIDs), ctx, ids)
|
||||
}
|
||||
|
||||
// GetChatsUpdatedAfter mocks base method.
|
||||
func (m *MockStore) GetChatsUpdatedAfter(ctx context.Context, updatedAfter time.Time) ([]database.GetChatsUpdatedAfterRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatsUpdatedAfter", ctx, updatedAfter)
|
||||
ret0, _ := ret[0].([]database.GetChatsUpdatedAfterRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatsUpdatedAfter indicates an expected call of GetChatsUpdatedAfter.
|
||||
func (mr *MockStoreMockRecorder) GetChatsUpdatedAfter(ctx, updatedAfter any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatsUpdatedAfter", reflect.TypeOf((*MockStore)(nil).GetChatsUpdatedAfter), ctx, updatedAfter)
|
||||
}
|
||||
|
||||
// GetConnectionLogsOffset mocks base method.
|
||||
func (m *MockStore) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -4997,6 +4907,21 @@ func (mr *MockStoreMockRecorder) GetUserNotificationPreferences(ctx, userID any)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserNotificationPreferences", reflect.TypeOf((*MockStore)(nil).GetUserNotificationPreferences), ctx, userID)
|
||||
}
|
||||
|
||||
// GetUserSecret mocks base method.
|
||||
func (m *MockStore) GetUserSecret(ctx context.Context, id uuid.UUID) (database.UserSecret, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetUserSecret", ctx, id)
|
||||
ret0, _ := ret[0].(database.UserSecret)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetUserSecret indicates an expected call of GetUserSecret.
|
||||
func (mr *MockStoreMockRecorder) GetUserSecret(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserSecret", reflect.TypeOf((*MockStore)(nil).GetUserSecret), ctx, id)
|
||||
}
|
||||
|
||||
// GetUserSecretByUserIDAndName mocks base method.
|
||||
func (m *MockStore) GetUserSecretByUserIDAndName(ctx context.Context, arg database.GetUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -7487,10 +7412,10 @@ func (mr *MockStoreMockRecorder) ListUserChatCompactionThresholds(ctx, userID an
|
||||
}
|
||||
|
||||
// ListUserSecrets mocks base method.
|
||||
func (m *MockStore) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.ListUserSecretsRow, error) {
|
||||
func (m *MockStore) ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListUserSecrets", ctx, userID)
|
||||
ret0, _ := ret[0].([]database.ListUserSecretsRow)
|
||||
ret0, _ := ret[0].([]database.UserSecret)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
@@ -7501,21 +7426,6 @@ func (mr *MockStoreMockRecorder) ListUserSecrets(ctx, userID any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListUserSecrets", reflect.TypeOf((*MockStore)(nil).ListUserSecrets), ctx, userID)
|
||||
}
|
||||
|
||||
// ListUserSecretsWithValues mocks base method.
|
||||
func (m *MockStore) ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]database.UserSecret, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListUserSecretsWithValues", ctx, userID)
|
||||
ret0, _ := ret[0].([]database.UserSecret)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ListUserSecretsWithValues indicates an expected call of ListUserSecretsWithValues.
|
||||
func (mr *MockStoreMockRecorder) ListUserSecretsWithValues(ctx, userID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListUserSecretsWithValues", reflect.TypeOf((*MockStore)(nil).ListUserSecretsWithValues), ctx, userID)
|
||||
}
|
||||
|
||||
// ListWorkspaceAgentPortShares mocks base method.
|
||||
func (m *MockStore) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]database.WorkspaceAgentPortShare, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -8944,19 +8854,19 @@ func (mr *MockStoreMockRecorder) UpdateUserRoles(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserRoles", reflect.TypeOf((*MockStore)(nil).UpdateUserRoles), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateUserSecretByUserIDAndName mocks base method.
|
||||
func (m *MockStore) UpdateUserSecretByUserIDAndName(ctx context.Context, arg database.UpdateUserSecretByUserIDAndNameParams) (database.UserSecret, error) {
|
||||
// UpdateUserSecret mocks base method.
|
||||
func (m *MockStore) UpdateUserSecret(ctx context.Context, arg database.UpdateUserSecretParams) (database.UserSecret, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateUserSecretByUserIDAndName", ctx, arg)
|
||||
ret := m.ctrl.Call(m, "UpdateUserSecret", ctx, arg)
|
||||
ret0, _ := ret[0].(database.UserSecret)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateUserSecretByUserIDAndName indicates an expected call of UpdateUserSecretByUserIDAndName.
|
||||
func (mr *MockStoreMockRecorder) UpdateUserSecretByUserIDAndName(ctx, arg any) *gomock.Call {
|
||||
// UpdateUserSecret indicates an expected call of UpdateUserSecret.
|
||||
func (mr *MockStoreMockRecorder) UpdateUserSecret(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserSecretByUserIDAndName", reflect.TypeOf((*MockStore)(nil).UpdateUserSecretByUserIDAndName), ctx, arg)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserSecret", reflect.TypeOf((*MockStore)(nil).UpdateUserSecret), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateUserStatus mocks base method.
|
||||
@@ -9489,20 +9399,6 @@ func (mr *MockStoreMockRecorder) UpsertChatIncludeDefaultSystemPrompt(ctx, inclu
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatIncludeDefaultSystemPrompt", reflect.TypeOf((*MockStore)(nil).UpsertChatIncludeDefaultSystemPrompt), ctx, includeDefaultSystemPrompt)
|
||||
}
|
||||
|
||||
// UpsertChatRetentionDays mocks base method.
|
||||
func (m *MockStore) UpsertChatRetentionDays(ctx context.Context, retentionDays int32) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpsertChatRetentionDays", ctx, retentionDays)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpsertChatRetentionDays indicates an expected call of UpsertChatRetentionDays.
|
||||
func (mr *MockStoreMockRecorder) UpsertChatRetentionDays(ctx, retentionDays any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatRetentionDays", reflect.TypeOf((*MockStore)(nil).UpsertChatRetentionDays), ctx, retentionDays)
|
||||
}
|
||||
|
||||
// UpsertChatSystemPrompt mocks base method.
|
||||
func (m *MockStore) UpsertChatSystemPrompt(ctx context.Context, value string) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -34,11 +34,6 @@ const (
|
||||
// long enough to cover the maximum interval of a heartbeat event (currently
|
||||
// 1 hour) plus some buffer.
|
||||
maxTelemetryHeartbeatAge = 24 * time.Hour
|
||||
// Batch sizes for chat purging. Both use 1000, which is smaller
|
||||
// than audit/connection log batches (10000), because chat_files
|
||||
// rows contain bytea blob data that make large batches heavier.
|
||||
chatsBatchSize = 1000
|
||||
chatFilesBatchSize = 1000
|
||||
)
|
||||
|
||||
// New creates a new periodically purging database instance.
|
||||
@@ -114,17 +109,6 @@ func New(ctx context.Context, logger slog.Logger, db database.Store, vals *coder
|
||||
// purgeTick performs a single purge iteration. It returns an error if the
|
||||
// purge fails.
|
||||
func (i *instance) purgeTick(ctx context.Context, db database.Store, start time.Time) error {
|
||||
// Read chat retention config outside the transaction to
|
||||
// avoid poisoning the tx if the stored value is corrupt.
|
||||
// A SQL-level cast error (e.g. non-numeric text) puts PG
|
||||
// into error state, failing all subsequent queries in the
|
||||
// same transaction.
|
||||
chatRetentionDays, err := db.GetChatRetentionDays(ctx)
|
||||
if err != nil {
|
||||
i.logger.Warn(ctx, "failed to read chat retention config, skipping chat purge", slog.Error(err))
|
||||
chatRetentionDays = 0
|
||||
}
|
||||
|
||||
// Start a transaction to grab advisory lock, we don't want to run
|
||||
// multiple purges at the same time (multiple replicas).
|
||||
return db.InTx(func(tx database.Store) error {
|
||||
@@ -229,43 +213,12 @@ func (i *instance) purgeTick(ctx context.Context, db database.Store, start time.
|
||||
}
|
||||
}
|
||||
|
||||
// Chat retention is configured via site_configs. When
|
||||
// enabled, old archived chats are deleted first, then
|
||||
// orphaned chat files. Deleting a chat cascades to
|
||||
// chat_file_links (removing references) but not to
|
||||
// chat_files directly, so files from deleted chats
|
||||
// become orphaned and are caught by DeleteOldChatFiles
|
||||
// in the same tick.
|
||||
var purgedChats int64
|
||||
var purgedChatFiles int64
|
||||
if chatRetentionDays > 0 {
|
||||
chatRetention := time.Duration(chatRetentionDays) * 24 * time.Hour
|
||||
deleteChatsBefore := start.Add(-chatRetention)
|
||||
|
||||
purgedChats, err = tx.DeleteOldChats(ctx, database.DeleteOldChatsParams{
|
||||
BeforeTime: deleteChatsBefore,
|
||||
LimitCount: chatsBatchSize,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to delete old chats: %w", err)
|
||||
}
|
||||
|
||||
purgedChatFiles, err = tx.DeleteOldChatFiles(ctx, database.DeleteOldChatFilesParams{
|
||||
BeforeTime: deleteChatsBefore,
|
||||
LimitCount: chatFilesBatchSize,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to delete old chat files: %w", err)
|
||||
}
|
||||
}
|
||||
i.logger.Debug(ctx, "purged old database entries",
|
||||
slog.F("workspace_agent_logs", purgedWorkspaceAgentLogs),
|
||||
slog.F("expired_api_keys", expiredAPIKeys),
|
||||
slog.F("aibridge_records", purgedAIBridgeRecords),
|
||||
slog.F("connection_logs", purgedConnectionLogs),
|
||||
slog.F("audit_logs", purgedAuditLogs),
|
||||
slog.F("chats", purgedChats),
|
||||
slog.F("chat_files", purgedChatFiles),
|
||||
slog.F("duration", i.clk.Since(start)),
|
||||
)
|
||||
|
||||
@@ -279,8 +232,6 @@ func (i *instance) purgeTick(ctx context.Context, db database.Store, start time.
|
||||
i.recordsPurged.WithLabelValues("aibridge_records").Add(float64(purgedAIBridgeRecords))
|
||||
i.recordsPurged.WithLabelValues("connection_logs").Add(float64(purgedConnectionLogs))
|
||||
i.recordsPurged.WithLabelValues("audit_logs").Add(float64(purgedAuditLogs))
|
||||
i.recordsPurged.WithLabelValues("chats").Add(float64(purgedChats))
|
||||
i.recordsPurged.WithLabelValues("chat_files").Add(float64(purgedChatFiles))
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/lib/pq"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -54,7 +53,6 @@ func TestPurge(t *testing.T) {
|
||||
clk := quartz.NewMock(t)
|
||||
done := awaitDoTick(ctx, t, clk)
|
||||
mDB := dbmock.NewMockStore(gomock.NewController(t))
|
||||
mDB.EXPECT().GetChatRetentionDays(gomock.Any()).Return(int32(0), nil).AnyTimes()
|
||||
mDB.EXPECT().InTx(gomock.Any(), database.DefaultTXOptions().WithID("db_purge")).Return(nil).Times(2)
|
||||
purger := dbpurge.New(context.Background(), testutil.Logger(t), mDB, &codersdk.DeploymentValues{}, clk, prometheus.NewRegistry())
|
||||
<-done // wait for doTick() to run.
|
||||
@@ -127,16 +125,6 @@ func TestMetrics(t *testing.T) {
|
||||
"record_type": "audit_logs",
|
||||
})
|
||||
require.GreaterOrEqual(t, auditLogs, 0)
|
||||
|
||||
chats := promhelp.CounterValue(t, reg, "coderd_dbpurge_records_purged_total", prometheus.Labels{
|
||||
"record_type": "chats",
|
||||
})
|
||||
require.GreaterOrEqual(t, chats, 0)
|
||||
|
||||
chatFiles := promhelp.CounterValue(t, reg, "coderd_dbpurge_records_purged_total", prometheus.Labels{
|
||||
"record_type": "chat_files",
|
||||
})
|
||||
require.GreaterOrEqual(t, chatFiles, 0)
|
||||
})
|
||||
|
||||
t.Run("FailedIteration", func(t *testing.T) {
|
||||
@@ -150,7 +138,6 @@ func TestMetrics(t *testing.T) {
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mDB := dbmock.NewMockStore(ctrl)
|
||||
mDB.EXPECT().GetChatRetentionDays(gomock.Any()).Return(int32(0), nil).AnyTimes()
|
||||
mDB.EXPECT().InTx(gomock.Any(), database.DefaultTXOptions().WithID("db_purge")).
|
||||
Return(xerrors.New("simulated database error")).
|
||||
MinTimes(1)
|
||||
@@ -1647,488 +1634,3 @@ func TestDeleteExpiredAPIKeys(t *testing.T) {
|
||||
func ptr[T any](v T) *T {
|
||||
return &v
|
||||
}
|
||||
|
||||
//nolint:paralleltest // It uses LockIDDBPurge.
|
||||
func TestDeleteOldChatFiles(t *testing.T) {
|
||||
now := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
// createChatFile inserts a chat file and backdates created_at.
|
||||
createChatFile := func(ctx context.Context, t *testing.T, db database.Store, rawDB *sql.DB, ownerID, orgID uuid.UUID, createdAt time.Time) uuid.UUID {
|
||||
t.Helper()
|
||||
row, err := db.InsertChatFile(ctx, database.InsertChatFileParams{
|
||||
OwnerID: ownerID,
|
||||
OrganizationID: orgID,
|
||||
Name: "test.png",
|
||||
Mimetype: "image/png",
|
||||
Data: []byte("fake-image-data"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = rawDB.ExecContext(ctx, "UPDATE chat_files SET created_at = $1 WHERE id = $2", createdAt, row.ID)
|
||||
require.NoError(t, err)
|
||||
return row.ID
|
||||
}
|
||||
|
||||
// createChat inserts a chat and optionally archives it, then
|
||||
// backdates updated_at to control the "archived since" window.
|
||||
createChat := func(ctx context.Context, t *testing.T, db database.Store, rawDB *sql.DB, ownerID, modelConfigID uuid.UUID, archived bool, updatedAt time.Time) database.Chat {
|
||||
t.Helper()
|
||||
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
OwnerID: ownerID,
|
||||
LastModelConfigID: modelConfigID,
|
||||
Title: "test-chat",
|
||||
Status: database.ChatStatusWaiting,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
if archived {
|
||||
_, err = db.ArchiveChatByID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
_, err = rawDB.ExecContext(ctx, "UPDATE chats SET updated_at = $1 WHERE id = $2", updatedAt, chat.ID)
|
||||
require.NoError(t, err)
|
||||
return chat
|
||||
}
|
||||
// setupChatDeps creates the common dependencies needed for
|
||||
// chat-related tests: user, org, org member, provider, model config.
|
||||
type chatDeps struct {
|
||||
user database.User
|
||||
org database.Organization
|
||||
modelConfig database.ChatModelConfig
|
||||
}
|
||||
setupChatDeps := func(ctx context.Context, t *testing.T, db database.Store) chatDeps {
|
||||
t.Helper()
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: user.ID, OrganizationID: org.ID})
|
||||
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
mc, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
Provider: "openai",
|
||||
Model: "test-model",
|
||||
ContextLimit: 8192,
|
||||
Options: json.RawMessage("{}"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return chatDeps{user: user, org: org, modelConfig: mc}
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
run func(t *testing.T)
|
||||
}{
|
||||
{
|
||||
name: "ChatRetentionDisabled",
|
||||
run: func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
clk := quartz.NewMock(t)
|
||||
clk.Set(now).MustWait(ctx)
|
||||
|
||||
db, _, rawDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
deps := setupChatDeps(ctx, t, db)
|
||||
|
||||
// Disable retention.
|
||||
err := db.UpsertChatRetentionDays(ctx, int32(0))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create an old archived chat and an orphaned old file.
|
||||
oldChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, true, now.Add(-31*24*time.Hour))
|
||||
oldFileID := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-31*24*time.Hour))
|
||||
|
||||
done := awaitDoTick(ctx, t, clk)
|
||||
closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, clk, prometheus.NewRegistry())
|
||||
defer closer.Close()
|
||||
testutil.TryReceive(ctx, t, done)
|
||||
|
||||
// Both should still exist.
|
||||
_, err = db.GetChatByID(ctx, oldChat.ID)
|
||||
require.NoError(t, err, "chat should not be deleted when retention is disabled")
|
||||
_, err = db.GetChatFileByID(ctx, oldFileID)
|
||||
require.NoError(t, err, "chat file should not be deleted when retention is disabled")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "OldArchivedChatsDeleted",
|
||||
run: func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
clk := quartz.NewMock(t)
|
||||
clk.Set(now).MustWait(ctx)
|
||||
|
||||
db, _, rawDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
deps := setupChatDeps(ctx, t, db)
|
||||
|
||||
err := db.UpsertChatRetentionDays(ctx, int32(30))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Old archived chat (31 days) — should be deleted.
|
||||
oldChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, true, now.Add(-31*24*time.Hour))
|
||||
// Insert a message so we can verify CASCADE.
|
||||
_, err = db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
|
||||
ChatID: oldChat.ID,
|
||||
CreatedBy: []uuid.UUID{deps.user.ID},
|
||||
ModelConfigID: []uuid.UUID{deps.modelConfig.ID},
|
||||
Role: []database.ChatMessageRole{database.ChatMessageRoleUser},
|
||||
Content: []string{`[{"type":"text","text":"hello"}]`},
|
||||
ContentVersion: []int16{0},
|
||||
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
|
||||
InputTokens: []int64{0},
|
||||
OutputTokens: []int64{0},
|
||||
TotalTokens: []int64{0},
|
||||
ReasoningTokens: []int64{0},
|
||||
CacheCreationTokens: []int64{0},
|
||||
CacheReadTokens: []int64{0},
|
||||
ContextLimit: []int64{0},
|
||||
Compressed: []bool{false},
|
||||
TotalCostMicros: []int64{0},
|
||||
RuntimeMs: []int64{0},
|
||||
ProviderResponseID: []string{""},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Recently archived chat (10 days) — should be retained.
|
||||
recentChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, true, now.Add(-10*24*time.Hour))
|
||||
|
||||
// Active chat — should be retained.
|
||||
activeChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, false, now)
|
||||
|
||||
done := awaitDoTick(ctx, t, clk)
|
||||
closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, clk, prometheus.NewRegistry())
|
||||
defer closer.Close()
|
||||
testutil.TryReceive(ctx, t, done)
|
||||
|
||||
// Old archived chat should be gone.
|
||||
_, err = db.GetChatByID(ctx, oldChat.ID)
|
||||
require.Error(t, err, "old archived chat should be deleted")
|
||||
|
||||
// Its messages should be gone too (CASCADE).
|
||||
msgs, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: oldChat.ID,
|
||||
AfterID: 0,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, msgs, "messages should be cascade-deleted")
|
||||
|
||||
// Recent archived and active chats should remain.
|
||||
_, err = db.GetChatByID(ctx, recentChat.ID)
|
||||
require.NoError(t, err, "recently archived chat should be retained")
|
||||
_, err = db.GetChatByID(ctx, activeChat.ID)
|
||||
require.NoError(t, err, "active chat should be retained")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "OrphanedOldFilesDeleted",
|
||||
run: func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
clk := quartz.NewMock(t)
|
||||
clk.Set(now).MustWait(ctx)
|
||||
|
||||
db, _, rawDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
deps := setupChatDeps(ctx, t, db)
|
||||
|
||||
err := db.UpsertChatRetentionDays(ctx, int32(30))
|
||||
require.NoError(t, err)
|
||||
|
||||
// File A: 31 days old, NOT in any chat -> should be deleted.
|
||||
fileA := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-31*24*time.Hour))
|
||||
|
||||
// File B: 31 days old, in an active chat -> should be retained.
|
||||
fileB := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-31*24*time.Hour))
|
||||
activeChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, false, now)
|
||||
_, err = db.LinkChatFiles(ctx, database.LinkChatFilesParams{
|
||||
ChatID: activeChat.ID,
|
||||
MaxFileLinks: 100,
|
||||
FileIds: []uuid.UUID{fileB},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// File C: 10 days old, NOT in any chat -> should be retained (too young).
|
||||
fileC := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-10*24*time.Hour))
|
||||
|
||||
// File near boundary: 29d23h old — close to threshold.
|
||||
fileBoundary := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-30*24*time.Hour).Add(time.Hour))
|
||||
|
||||
done := awaitDoTick(ctx, t, clk)
|
||||
closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, clk, prometheus.NewRegistry())
|
||||
defer closer.Close()
|
||||
testutil.TryReceive(ctx, t, done)
|
||||
|
||||
_, err = db.GetChatFileByID(ctx, fileA)
|
||||
require.Error(t, err, "orphaned old file A should be deleted")
|
||||
|
||||
_, err = db.GetChatFileByID(ctx, fileB)
|
||||
require.NoError(t, err, "file B in active chat should be retained")
|
||||
|
||||
_, err = db.GetChatFileByID(ctx, fileC)
|
||||
require.NoError(t, err, "young file C should be retained")
|
||||
|
||||
_, err = db.GetChatFileByID(ctx, fileBoundary)
|
||||
require.NoError(t, err, "file near 30d boundary should be retained")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ArchivedChatFilesDeleted",
|
||||
run: func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
clk := quartz.NewMock(t)
|
||||
clk.Set(now).MustWait(ctx)
|
||||
|
||||
db, _, rawDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
deps := setupChatDeps(ctx, t, db)
|
||||
|
||||
err := db.UpsertChatRetentionDays(ctx, int32(30))
|
||||
require.NoError(t, err)
|
||||
|
||||
// File D: 31 days old, in a chat archived 31 days ago -> should be deleted.
|
||||
fileD := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-31*24*time.Hour))
|
||||
oldArchivedChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, true, now.Add(-31*24*time.Hour))
|
||||
_, err = db.LinkChatFiles(ctx, database.LinkChatFilesParams{
|
||||
ChatID: oldArchivedChat.ID,
|
||||
MaxFileLinks: 100,
|
||||
FileIds: []uuid.UUID{fileD},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
// LinkChatFiles does not update chats.updated_at, so backdate.
|
||||
_, err = rawDB.ExecContext(ctx, "UPDATE chats SET updated_at = $1 WHERE id = $2",
|
||||
now.Add(-31*24*time.Hour), oldArchivedChat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// File E: 31 days old, in a chat archived 10 days ago -> should be retained.
|
||||
fileE := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-31*24*time.Hour))
|
||||
recentArchivedChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, true, now.Add(-10*24*time.Hour))
|
||||
_, err = db.LinkChatFiles(ctx, database.LinkChatFilesParams{
|
||||
ChatID: recentArchivedChat.ID,
|
||||
MaxFileLinks: 100,
|
||||
FileIds: []uuid.UUID{fileE},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = rawDB.ExecContext(ctx, "UPDATE chats SET updated_at = $1 WHERE id = $2",
|
||||
now.Add(-10*24*time.Hour), recentArchivedChat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// File F: 31 days old, in BOTH an active chat AND an old archived chat -> should be retained.
|
||||
fileF := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-31*24*time.Hour))
|
||||
anotherOldArchivedChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, true, now.Add(-31*24*time.Hour))
|
||||
_, err = db.LinkChatFiles(ctx, database.LinkChatFilesParams{
|
||||
ChatID: anotherOldArchivedChat.ID,
|
||||
MaxFileLinks: 100,
|
||||
FileIds: []uuid.UUID{fileF},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = rawDB.ExecContext(ctx, "UPDATE chats SET updated_at = $1 WHERE id = $2",
|
||||
now.Add(-31*24*time.Hour), anotherOldArchivedChat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
activeChatForF := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, false, now)
|
||||
_, err = db.LinkChatFiles(ctx, database.LinkChatFilesParams{
|
||||
ChatID: activeChatForF.ID,
|
||||
MaxFileLinks: 100,
|
||||
FileIds: []uuid.UUID{fileF},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
done := awaitDoTick(ctx, t, clk)
|
||||
closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, clk, prometheus.NewRegistry())
|
||||
defer closer.Close()
|
||||
testutil.TryReceive(ctx, t, done)
|
||||
|
||||
_, err = db.GetChatFileByID(ctx, fileD)
|
||||
require.Error(t, err, "file D in old archived chat should be deleted")
|
||||
|
||||
_, err = db.GetChatFileByID(ctx, fileE)
|
||||
require.NoError(t, err, "file E in recently archived chat should be retained")
|
||||
|
||||
_, err = db.GetChatFileByID(ctx, fileF)
|
||||
require.NoError(t, err, "file F in active + old archived chat should be retained")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "UnarchiveAfterFilePurge",
|
||||
run: func(t *testing.T) {
|
||||
// Validates that when dbpurge deletes chat_files rows,
|
||||
// the FK cascade on chat_file_links automatically
|
||||
// removes the stale links. Unarchiving a chat after
|
||||
// file purge should show only surviving files.
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
db, _, rawDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
|
||||
deps := setupChatDeps(ctx, t, db)
|
||||
|
||||
// Create a chat with three attached files.
|
||||
fileA := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now)
|
||||
fileB := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now)
|
||||
fileC := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now)
|
||||
|
||||
chat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, false, now)
|
||||
_, err := db.LinkChatFiles(ctx, database.LinkChatFilesParams{
|
||||
ChatID: chat.ID,
|
||||
MaxFileLinks: 100,
|
||||
FileIds: []uuid.UUID{fileA, fileB, fileC},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Archive the chat.
|
||||
_, err = db.ArchiveChatByID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Simulate dbpurge deleting files A and B. The FK
|
||||
// cascade on chat_file_links_file_id_fkey should
|
||||
// automatically remove the corresponding link rows.
|
||||
_, err = rawDB.ExecContext(ctx, "DELETE FROM chat_files WHERE id = ANY($1)", pq.Array([]uuid.UUID{fileA, fileB}))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Unarchive the chat.
|
||||
_, err = db.UnarchiveChatByID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Only file C should remain linked (FK cascade
|
||||
// removed the links for deleted files A and B).
|
||||
files, err := db.GetChatFileMetadataByChatID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, files, 1, "only surviving file should be linked")
|
||||
require.Equal(t, fileC, files[0].ID)
|
||||
|
||||
// Edge case: delete the last file too. The chat
|
||||
// should have zero linked files, not an error.
|
||||
_, err = db.ArchiveChatByID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
_, err = rawDB.ExecContext(ctx, "DELETE FROM chat_files WHERE id = $1", fileC)
|
||||
require.NoError(t, err)
|
||||
_, err = db.UnarchiveChatByID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
files, err = db.GetChatFileMetadataByChatID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, files, "all-files-deleted should yield empty result")
|
||||
|
||||
// Test parent+child cascade: deleting files should
|
||||
// clean up links for both parent and child chats
|
||||
// independently via FK cascade.
|
||||
parentChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, false, now)
|
||||
childChat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
OwnerID: deps.user.ID,
|
||||
LastModelConfigID: deps.modelConfig.ID,
|
||||
Title: "child-chat",
|
||||
Status: database.ChatStatusWaiting,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
// Set root_chat_id to link child to parent.
|
||||
_, err = rawDB.ExecContext(ctx, "UPDATE chats SET root_chat_id = $1 WHERE id = $2", parentChat.ID, childChat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Attach different files to parent and child.
|
||||
parentFileKeep := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now)
|
||||
parentFileStale := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now)
|
||||
childFileKeep := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now)
|
||||
childFileStale := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now)
|
||||
|
||||
_, err = db.LinkChatFiles(ctx, database.LinkChatFilesParams{
|
||||
ChatID: parentChat.ID,
|
||||
MaxFileLinks: 100,
|
||||
FileIds: []uuid.UUID{parentFileKeep, parentFileStale},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = db.LinkChatFiles(ctx, database.LinkChatFilesParams{
|
||||
ChatID: childChat.ID,
|
||||
MaxFileLinks: 100,
|
||||
FileIds: []uuid.UUID{childFileKeep, childFileStale},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Archive via parent (cascades to child).
|
||||
_, err = db.ArchiveChatByID(ctx, parentChat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete one file from each chat.
|
||||
_, err = rawDB.ExecContext(ctx, "DELETE FROM chat_files WHERE id = ANY($1)",
|
||||
pq.Array([]uuid.UUID{parentFileStale, childFileStale}))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Unarchive via parent.
|
||||
_, err = db.UnarchiveChatByID(ctx, parentChat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
parentFiles, err := db.GetChatFileMetadataByChatID(ctx, parentChat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, parentFiles, 1)
|
||||
require.Equal(t, parentFileKeep, parentFiles[0].ID,
|
||||
"parent should retain only non-stale file")
|
||||
|
||||
childFiles, err := db.GetChatFileMetadataByChatID(ctx, childChat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, childFiles, 1)
|
||||
require.Equal(t, childFileKeep, childFiles[0].ID,
|
||||
"child should retain only non-stale file")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "BatchLimitFiles",
|
||||
run: func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
db, _, rawDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
|
||||
deps := setupChatDeps(ctx, t, db)
|
||||
|
||||
// Create 3 deletable orphaned files (all 31 days old).
|
||||
for range 3 {
|
||||
createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-31*24*time.Hour))
|
||||
}
|
||||
|
||||
// Delete with limit 2 — should delete 2, leave 1.
|
||||
deleted, err := db.DeleteOldChatFiles(ctx, database.DeleteOldChatFilesParams{
|
||||
BeforeTime: now.Add(-30 * 24 * time.Hour),
|
||||
LimitCount: 2,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(2), deleted, "should delete exactly 2 files")
|
||||
|
||||
// Delete again — should delete the remaining 1.
|
||||
deleted, err = db.DeleteOldChatFiles(ctx, database.DeleteOldChatFilesParams{
|
||||
BeforeTime: now.Add(-30 * 24 * time.Hour),
|
||||
LimitCount: 2,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), deleted, "should delete remaining 1 file")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "BatchLimitChats",
|
||||
run: func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
db, _, rawDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
|
||||
deps := setupChatDeps(ctx, t, db)
|
||||
|
||||
// Create 3 deletable old archived chats.
|
||||
for range 3 {
|
||||
createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, true, now.Add(-31*24*time.Hour))
|
||||
}
|
||||
|
||||
// Delete with limit 2 — should delete 2, leave 1.
|
||||
deleted, err := db.DeleteOldChats(ctx, database.DeleteOldChatsParams{
|
||||
BeforeTime: now.Add(-30 * 24 * time.Hour),
|
||||
LimitCount: 2,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(2), deleted, "should delete exactly 2 chats")
|
||||
|
||||
// Delete again — should delete the remaining 1.
|
||||
deleted, err = db.DeleteOldChats(ctx, database.DeleteOldChatsParams{
|
||||
BeforeTime: now.Add(-30 * 24 * time.Hour),
|
||||
LimitCount: 2,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), deleted, "should delete remaining 1 chat")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tc.run(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Generated
+3
-16
@@ -293,8 +293,7 @@ CREATE TYPE chat_status AS ENUM (
|
||||
'running',
|
||||
'paused',
|
||||
'completed',
|
||||
'error',
|
||||
'requires_action'
|
||||
'error'
|
||||
);
|
||||
|
||||
CREATE TYPE connection_status AS ENUM (
|
||||
@@ -316,11 +315,6 @@ CREATE TYPE cors_behavior AS ENUM (
|
||||
'passthru'
|
||||
);
|
||||
|
||||
CREATE TYPE credential_kind AS ENUM (
|
||||
'centralized',
|
||||
'byok'
|
||||
);
|
||||
|
||||
CREATE TYPE crypto_key_feature AS ENUM (
|
||||
'workspace_apps_token',
|
||||
'workspace_apps_api_key',
|
||||
@@ -1107,9 +1101,7 @@ CREATE TABLE aibridge_interceptions (
|
||||
thread_root_id uuid,
|
||||
client_session_id character varying(256),
|
||||
session_id text GENERATED ALWAYS AS (COALESCE(client_session_id, ((thread_root_id)::text)::character varying, ((id)::text)::character varying)) STORED NOT NULL,
|
||||
provider_name text DEFAULT ''::text NOT NULL,
|
||||
credential_kind credential_kind DEFAULT 'centralized'::credential_kind NOT NULL,
|
||||
credential_hint character varying(15) DEFAULT ''::character varying NOT NULL
|
||||
provider_name text DEFAULT ''::text NOT NULL
|
||||
);
|
||||
|
||||
COMMENT ON TABLE aibridge_interceptions IS 'Audit log of requests intercepted by AI Bridge';
|
||||
@@ -1126,10 +1118,6 @@ COMMENT ON COLUMN aibridge_interceptions.session_id IS 'Groups related intercept
|
||||
|
||||
COMMENT ON COLUMN aibridge_interceptions.provider_name IS 'The provider instance name which may differ from provider when multiple instances of the same provider type exist.';
|
||||
|
||||
COMMENT ON COLUMN aibridge_interceptions.credential_kind IS 'How the request was authenticated: centralized or byok.';
|
||||
|
||||
COMMENT ON COLUMN aibridge_interceptions.credential_hint IS 'Masked credential identifier for audit (e.g. sk-a***efgh).';
|
||||
|
||||
CREATE TABLE aibridge_model_thoughts (
|
||||
interception_id uuid NOT NULL,
|
||||
content text NOT NULL,
|
||||
@@ -1430,8 +1418,7 @@ CREATE TABLE chats (
|
||||
agent_id uuid,
|
||||
pin_order integer DEFAULT 0 NOT NULL,
|
||||
last_read_message_id bigint,
|
||||
last_injected_context jsonb,
|
||||
dynamic_tools jsonb
|
||||
last_injected_context jsonb
|
||||
);
|
||||
|
||||
CREATE TABLE connection_logs (
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
-- First update any rows using the value we're about to remove.
|
||||
-- The column type is still the original chat_status at this point.
|
||||
UPDATE chats SET status = 'error' WHERE status = 'requires_action';
|
||||
|
||||
-- Drop the column (this is independent of the enum).
|
||||
ALTER TABLE chats DROP COLUMN IF EXISTS dynamic_tools;
|
||||
|
||||
-- Drop the partial index that references the chat_status enum type.
|
||||
-- It must be removed before the rename-create-cast-drop cycle
|
||||
-- because the index's WHERE clause (status = 'pending'::chat_status)
|
||||
-- would otherwise cause a cross-type comparison failure.
|
||||
DROP INDEX IF EXISTS idx_chats_pending;
|
||||
|
||||
-- Now recreate the enum without requires_action.
|
||||
-- We must use the rename-create-cast-drop pattern.
|
||||
ALTER TYPE chat_status RENAME TO chat_status_old;
|
||||
CREATE TYPE chat_status AS ENUM (
|
||||
'waiting',
|
||||
'pending',
|
||||
'running',
|
||||
'paused',
|
||||
'completed',
|
||||
'error'
|
||||
);
|
||||
ALTER TABLE chats ALTER COLUMN status DROP DEFAULT;
|
||||
ALTER TABLE chats ALTER COLUMN status TYPE chat_status USING status::text::chat_status;
|
||||
ALTER TABLE chats ALTER COLUMN status SET DEFAULT 'waiting';
|
||||
DROP TYPE chat_status_old;
|
||||
|
||||
-- Recreate the partial index.
|
||||
CREATE INDEX idx_chats_pending ON chats USING btree (status) WHERE (status = 'pending'::chat_status);
|
||||
@@ -1,3 +0,0 @@
|
||||
ALTER TYPE chat_status ADD VALUE IF NOT EXISTS 'requires_action';
|
||||
|
||||
ALTER TABLE chats ADD COLUMN dynamic_tools JSONB DEFAULT NULL;
|
||||
@@ -1,5 +0,0 @@
|
||||
ALTER TABLE aibridge_interceptions
|
||||
DROP COLUMN IF EXISTS credential_kind,
|
||||
DROP COLUMN IF EXISTS credential_hint;
|
||||
|
||||
DROP TYPE IF EXISTS credential_kind;
|
||||
@@ -1,12 +0,0 @@
|
||||
CREATE TYPE credential_kind AS ENUM ('centralized', 'byok');
|
||||
|
||||
-- Records how each LLM request was authenticated and a masked credential
|
||||
-- identifier for audit purposes. Existing rows default to 'centralized'
|
||||
-- with an empty hint since we cannot retroactively determine their values.
|
||||
ALTER TABLE aibridge_interceptions
|
||||
ADD COLUMN credential_kind credential_kind NOT NULL DEFAULT 'centralized',
|
||||
-- Length capped as a safety measure to ensure only masked values are stored.
|
||||
ADD COLUMN credential_hint CHARACTER VARYING(15) NOT NULL DEFAULT '';
|
||||
|
||||
COMMENT ON COLUMN aibridge_interceptions.credential_kind IS 'How the request was authenticated: centralized or byok.';
|
||||
COMMENT ON COLUMN aibridge_interceptions.credential_hint IS 'Masked credential identifier for audit (e.g. sk-a***efgh).';
|
||||
@@ -798,7 +798,6 @@ func (q *sqlQuerier) GetAuthorizedChats(ctx context.Context, arg GetChatsParams,
|
||||
&i.Chat.PinOrder,
|
||||
&i.Chat.LastReadMessageID,
|
||||
&i.Chat.LastInjectedContext,
|
||||
&i.Chat.DynamicTools,
|
||||
&i.HasUnread); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -869,8 +868,6 @@ func (q *sqlQuerier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, ar
|
||||
&i.AIBridgeInterception.ClientSessionID,
|
||||
&i.AIBridgeInterception.SessionID,
|
||||
&i.AIBridgeInterception.ProviderName,
|
||||
&i.AIBridgeInterception.CredentialKind,
|
||||
&i.AIBridgeInterception.CredentialHint,
|
||||
&i.VisibleUser.ID,
|
||||
&i.VisibleUser.Username,
|
||||
&i.VisibleUser.Name,
|
||||
@@ -1134,8 +1131,6 @@ func (q *sqlQuerier) ListAuthorizedAIBridgeSessionThreads(ctx context.Context, a
|
||||
&i.AIBridgeInterception.ClientSessionID,
|
||||
&i.AIBridgeInterception.SessionID,
|
||||
&i.AIBridgeInterception.ProviderName,
|
||||
&i.AIBridgeInterception.CredentialKind,
|
||||
&i.AIBridgeInterception.CredentialHint,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1290,13 +1290,12 @@ func AllChatModeValues() []ChatMode {
|
||||
type ChatStatus string
|
||||
|
||||
const (
|
||||
ChatStatusWaiting ChatStatus = "waiting"
|
||||
ChatStatusPending ChatStatus = "pending"
|
||||
ChatStatusRunning ChatStatus = "running"
|
||||
ChatStatusPaused ChatStatus = "paused"
|
||||
ChatStatusCompleted ChatStatus = "completed"
|
||||
ChatStatusError ChatStatus = "error"
|
||||
ChatStatusRequiresAction ChatStatus = "requires_action"
|
||||
ChatStatusWaiting ChatStatus = "waiting"
|
||||
ChatStatusPending ChatStatus = "pending"
|
||||
ChatStatusRunning ChatStatus = "running"
|
||||
ChatStatusPaused ChatStatus = "paused"
|
||||
ChatStatusCompleted ChatStatus = "completed"
|
||||
ChatStatusError ChatStatus = "error"
|
||||
)
|
||||
|
||||
func (e *ChatStatus) Scan(src interface{}) error {
|
||||
@@ -1341,8 +1340,7 @@ func (e ChatStatus) Valid() bool {
|
||||
ChatStatusRunning,
|
||||
ChatStatusPaused,
|
||||
ChatStatusCompleted,
|
||||
ChatStatusError,
|
||||
ChatStatusRequiresAction:
|
||||
ChatStatusError:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
@@ -1356,7 +1354,6 @@ func AllChatStatusValues() []ChatStatus {
|
||||
ChatStatusPaused,
|
||||
ChatStatusCompleted,
|
||||
ChatStatusError,
|
||||
ChatStatusRequiresAction,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1546,64 +1543,6 @@ func AllCorsBehaviorValues() []CorsBehavior {
|
||||
}
|
||||
}
|
||||
|
||||
type CredentialKind string
|
||||
|
||||
const (
|
||||
CredentialKindCentralized CredentialKind = "centralized"
|
||||
CredentialKindByok CredentialKind = "byok"
|
||||
)
|
||||
|
||||
func (e *CredentialKind) Scan(src interface{}) error {
|
||||
switch s := src.(type) {
|
||||
case []byte:
|
||||
*e = CredentialKind(s)
|
||||
case string:
|
||||
*e = CredentialKind(s)
|
||||
default:
|
||||
return fmt.Errorf("unsupported scan type for CredentialKind: %T", src)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type NullCredentialKind struct {
|
||||
CredentialKind CredentialKind `json:"credential_kind"`
|
||||
Valid bool `json:"valid"` // Valid is true if CredentialKind is not NULL
|
||||
}
|
||||
|
||||
// Scan implements the Scanner interface.
|
||||
func (ns *NullCredentialKind) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
ns.CredentialKind, ns.Valid = "", false
|
||||
return nil
|
||||
}
|
||||
ns.Valid = true
|
||||
return ns.CredentialKind.Scan(value)
|
||||
}
|
||||
|
||||
// Value implements the driver Valuer interface.
|
||||
func (ns NullCredentialKind) Value() (driver.Value, error) {
|
||||
if !ns.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
return string(ns.CredentialKind), nil
|
||||
}
|
||||
|
||||
func (e CredentialKind) Valid() bool {
|
||||
switch e {
|
||||
case CredentialKindCentralized,
|
||||
CredentialKindByok:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func AllCredentialKindValues() []CredentialKind {
|
||||
return []CredentialKind{
|
||||
CredentialKindCentralized,
|
||||
CredentialKindByok,
|
||||
}
|
||||
}
|
||||
|
||||
type CryptoKeyFeature string
|
||||
|
||||
const (
|
||||
@@ -4101,10 +4040,6 @@ type AIBridgeInterception struct {
|
||||
SessionID string `db:"session_id" json:"session_id"`
|
||||
// The provider instance name which may differ from provider when multiple instances of the same provider type exist.
|
||||
ProviderName string `db:"provider_name" json:"provider_name"`
|
||||
// How the request was authenticated: centralized or byok.
|
||||
CredentialKind CredentialKind `db:"credential_kind" json:"credential_kind"`
|
||||
// Masked credential identifier for audit (e.g. sk-a***efgh).
|
||||
CredentialHint string `db:"credential_hint" json:"credential_hint"`
|
||||
}
|
||||
|
||||
// Audit log of model thinking in intercepted requests in AI Bridge
|
||||
@@ -4245,7 +4180,6 @@ type Chat struct {
|
||||
PinOrder int32 `db:"pin_order" json:"pin_order"`
|
||||
LastReadMessageID sql.NullInt64 `db:"last_read_message_id" json:"last_read_message_id"`
|
||||
LastInjectedContext pqtype.NullRawMessage `db:"last_injected_context" json:"last_injected_context"`
|
||||
DynamicTools pqtype.NullRawMessage `db:"dynamic_tools" json:"dynamic_tools"`
|
||||
}
|
||||
|
||||
type ChatDiffStatus struct {
|
||||
|
||||
@@ -1,749 +0,0 @@
|
||||
package pubsub
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultBatchingFlushInterval is the default upper bound on how long chatd
|
||||
// publishes wait before a scheduled flush when nearby publishes do not
|
||||
// naturally coalesce sooner.
|
||||
DefaultBatchingFlushInterval = 50 * time.Millisecond
|
||||
// DefaultBatchingQueueSize is the default number of buffered chatd publish
|
||||
// requests waiting to be flushed.
|
||||
DefaultBatchingQueueSize = 8192
|
||||
|
||||
defaultBatchingPressureWait = 10 * time.Millisecond
|
||||
defaultBatchingFinalFlushLimit = 15 * time.Second
|
||||
batchingWarnInterval = 10 * time.Second
|
||||
|
||||
batchFlushScheduled = "scheduled"
|
||||
batchFlushShutdown = "shutdown"
|
||||
|
||||
batchFlushStageNone = "none"
|
||||
batchFlushStageBegin = "begin"
|
||||
batchFlushStageExec = "exec"
|
||||
batchFlushStageCommit = "commit"
|
||||
|
||||
batchDelegateFallbackReasonQueueFull = "queue_full"
|
||||
batchDelegateFallbackReasonFlushError = "flush_error"
|
||||
|
||||
batchChannelClassStreamNotify = "stream_notify"
|
||||
batchChannelClassOwnerEvent = "owner_event"
|
||||
batchChannelClassConfigChange = "config_change"
|
||||
batchChannelClassOther = "other"
|
||||
)
|
||||
|
||||
// ErrBatchingPubsubClosed is returned when a batched pubsub publish is
|
||||
// attempted after shutdown has started.
|
||||
var ErrBatchingPubsubClosed = xerrors.New("batched pubsub is closed")
|
||||
|
||||
// BatchingConfig controls the chatd-specific PostgreSQL pubsub batching path.
|
||||
// Flush timing is automatic: the run loop wakes every FlushInterval (or on
|
||||
// backpressure) and drains everything currently queued into a single
|
||||
// transaction. There is no fixed batch-size knob — the batch size is simply
|
||||
// whatever accumulated since the last flush, which naturally adapts to load.
|
||||
type BatchingConfig struct {
|
||||
FlushInterval time.Duration
|
||||
QueueSize int
|
||||
PressureWait time.Duration
|
||||
FinalFlushTimeout time.Duration
|
||||
Clock quartz.Clock
|
||||
}
|
||||
|
||||
// queuedPublish is one pending pg_notify call waiting in the batching queue.
type queuedPublish struct {
	// event is the notification channel name.
	event string
	// channelClass is the precomputed metrics label (see batchChannelClass).
	channelClass string
	// message is a private copy of the caller's payload (cloned on Publish).
	message []byte
}

// batchSender delivers an accumulated batch to PostgreSQL. Implementations
// are owned by the run() goroutine; see pgBatchSender.
type batchSender interface {
	// Flush sends the whole batch; errors may carry a *batchFlushError stage.
	Flush(ctx context.Context, batch []queuedPublish) error
	// Close releases the sender's underlying resources.
	Close() error
}

// batchFlushError tags a flush failure with the transaction stage
// (begin/exec/commit) that produced it, for metrics labeling.
type batchFlushError struct {
	// stage is one of the batchFlushStage* constants.
	stage string
	// err is the underlying cause; exposed via Unwrap.
	err error
}
|
||||
|
||||
func (e *batchFlushError) Error() string {
|
||||
return e.err.Error()
|
||||
}
|
||||
|
||||
func (e *batchFlushError) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
// BatchingPubsub batches chatd publish traffic onto a dedicated PostgreSQL
// sender connection while delegating subscribe behavior to the shared listener
// pubsub instance.
type BatchingPubsub struct {
	logger   slog.Logger
	// delegate serves all subscribes and is the fallback publish path.
	delegate *PGPubsub
	// sender is only accessed from the run() goroutine (including
	// flushBatch and resetSender which it calls). Do not read or
	// write this field from Publish or any other goroutine.
	sender batchSender
	// newSender rebuilds the dedicated sender after flush failures;
	// may be nil (e.g. in tests), which disables resets.
	newSender func(context.Context) (batchSender, error)
	clock     quartz.Clock

	// publishCh buffers queued publishes for the run loop.
	publishCh chan queuedPublish
	// flushCh (capacity 1) coalesces backpressure flush signals.
	flushCh chan struct{}
	// closeCh is closed by Close to stop the run loop.
	closeCh chan struct{}
	// doneCh is closed by run() when it has fully exited.
	doneCh chan struct{}

	// spaceMu guards spaceSignal, which is closed-and-replaced each time
	// queue space is freed (a broadcast to blocked publishers).
	spaceMu     sync.Mutex
	spaceSignal chan struct{}

	// warnTicker rate-limits queue-full warnings (see logPublishRejection).
	warnTicker *quartz.Ticker

	// Effective configuration after defaults were applied.
	flushInterval     time.Duration
	pressureWait      time.Duration
	finalFlushTimeout time.Duration

	// queuedCount mirrors the queue depth for the Prometheus gauge.
	queuedCount atomic.Int64
	// closed flips once Close starts; Publish rejects afterwards.
	closed    atomic.Bool
	closeOnce sync.Once
	// closeErr/runErr: runErr is written by run() before doneCh closes;
	// Close copies it into closeErr after waiting on doneCh.
	closeErr error
	runErr   error

	// runCtx/cancel bound the run loop's flush operations.
	runCtx  context.Context
	cancel  context.CancelFunc
	metrics batchingMetrics
}
|
||||
|
||||
// batchingMetrics bundles the Prometheus instruments for the batching path.
type batchingMetrics struct {
	QueueDepth               prometheus.Gauge
	BatchSize                prometheus.Histogram
	FlushDuration            *prometheus.HistogramVec
	DelegateFallbacksTotal   *prometheus.CounterVec
	SenderResetsTotal        prometheus.Counter
	SenderResetFailuresTotal prometheus.Counter
}

// newBatchingMetrics constructs all batching metrics under the
// coder_pubsub_* namespace. The metrics are registered by the caller via
// the Collector implementation on batchingMetrics.
func newBatchingMetrics() batchingMetrics {
	return batchingMetrics{
		QueueDepth: prometheus.NewGauge(prometheus.GaugeOpts{
			Namespace: "coder",
			Subsystem: "pubsub",
			Name:      "batch_queue_depth",
			Help:      "The number of chatd notifications waiting in the batching queue.",
		}),
		BatchSize: prometheus.NewHistogram(prometheus.HistogramOpts{
			Namespace: "coder",
			Subsystem: "pubsub",
			Name:      "batch_size",
			Help:      "The number of logical notifications sent in each chatd batch flush.",
			// Powers of two up to the default queue size.
			Buckets: []float64{1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192},
		}),
		FlushDuration: prometheus.NewHistogramVec(prometheus.HistogramOpts{
			Namespace: "coder",
			Subsystem: "pubsub",
			Name:      "batch_flush_duration_seconds",
			Help:      "The time spent flushing one chatd batch to PostgreSQL.",
			Buckets:   []float64{0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10, 20, 30},
		}, []string{"reason"}),
		DelegateFallbacksTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
			Namespace: "coder",
			Subsystem: "pubsub",
			Name:      "batch_delegate_fallbacks_total",
			Help:      "The number of chatd publishes that fell back to the shared pubsub pool by channel class, reason, and flush stage.",
		}, []string{"channel_class", "reason", "stage"}),
		SenderResetsTotal: prometheus.NewCounter(prometheus.CounterOpts{
			Namespace: "coder",
			Subsystem: "pubsub",
			Name:      "batch_sender_resets_total",
			Help:      "The number of successful batched pubsub sender resets after flush failures.",
		}),
		SenderResetFailuresTotal: prometheus.NewCounter(prometheus.CounterOpts{
			Namespace: "coder",
			Subsystem: "pubsub",
			Name:      "batch_sender_reset_failures_total",
			Help:      "The number of batched pubsub sender reset attempts that failed.",
		}),
	}
}
|
||||
|
||||
func (m batchingMetrics) Describe(descs chan<- *prometheus.Desc) {
|
||||
m.QueueDepth.Describe(descs)
|
||||
m.BatchSize.Describe(descs)
|
||||
m.FlushDuration.Describe(descs)
|
||||
m.DelegateFallbacksTotal.Describe(descs)
|
||||
m.SenderResetsTotal.Describe(descs)
|
||||
m.SenderResetFailuresTotal.Describe(descs)
|
||||
}
|
||||
|
||||
func (m batchingMetrics) Collect(metrics chan<- prometheus.Metric) {
|
||||
m.QueueDepth.Collect(metrics)
|
||||
m.BatchSize.Collect(metrics)
|
||||
m.FlushDuration.Collect(metrics)
|
||||
m.DelegateFallbacksTotal.Collect(metrics)
|
||||
m.SenderResetsTotal.Collect(metrics)
|
||||
m.SenderResetFailuresTotal.Collect(metrics)
|
||||
}
|
||||
|
||||
// NewBatching creates a chatd-specific batched pubsub wrapper around the
// shared PostgreSQL listener implementation.
//
// prototype supplies the connector settings and connectURL the DSN for the
// dedicated sender connection; delegate handles all subscribes and acts as
// the publish fallback path. The returned pubsub owns the sender.
func NewBatching(
	ctx context.Context,
	logger slog.Logger,
	delegate *PGPubsub,
	prototype *sql.DB,
	connectURL string,
	cfg BatchingConfig,
) (*BatchingPubsub, error) {
	if delegate == nil {
		return nil, xerrors.New("delegate pubsub is nil")
	}
	if prototype == nil {
		return nil, xerrors.New("prototype database is nil")
	}
	if connectURL == "" {
		return nil, xerrors.New("connect URL is empty")
	}

	newSender := func(ctx context.Context) (batchSender, error) {
		return newPGBatchSender(ctx, logger.Named("sender"), prototype, connectURL)
	}

	sender, err := newSender(ctx)
	if err != nil {
		return nil, err
	}

	ps, err := newBatchingPubsub(logger, delegate, sender, cfg)
	if err != nil {
		// Construction failed before the pubsub took ownership, so the
		// sender must be closed here.
		_ = sender.Close()
		return nil, err
	}
	// Installed after construction so resetSender can rebuild the dedicated
	// connection after flush failures.
	ps.newSender = newSender
	return ps, nil
}
|
||||
|
||||
// newBatchingPubsub validates cfg (zero values select package defaults,
// negatives are rejected), wires the channels and metrics, and starts the
// run() goroutine. ps.newSender stays nil here; NewBatching installs it.
func newBatchingPubsub(
	logger slog.Logger,
	delegate *PGPubsub,
	sender batchSender,
	cfg BatchingConfig,
) (*BatchingPubsub, error) {
	if delegate == nil {
		return nil, xerrors.New("delegate pubsub is nil")
	}
	if sender == nil {
		return nil, xerrors.New("batch sender is nil")
	}

	// Each knob follows the same pattern: zero => default, negative => error.
	flushInterval := cfg.FlushInterval
	if flushInterval == 0 {
		flushInterval = DefaultBatchingFlushInterval
	}
	if flushInterval < 0 {
		return nil, xerrors.New("flush interval must be positive")
	}

	queueSize := cfg.QueueSize
	if queueSize == 0 {
		queueSize = DefaultBatchingQueueSize
	}
	if queueSize < 0 {
		return nil, xerrors.New("queue size must be positive")
	}

	pressureWait := cfg.PressureWait
	if pressureWait == 0 {
		pressureWait = defaultBatchingPressureWait
	}
	if pressureWait < 0 {
		return nil, xerrors.New("pressure wait must be positive")
	}

	finalFlushTimeout := cfg.FinalFlushTimeout
	if finalFlushTimeout == 0 {
		finalFlushTimeout = defaultBatchingFinalFlushLimit
	}
	if finalFlushTimeout < 0 {
		return nil, xerrors.New("final flush timeout must be positive")
	}

	clock := cfg.Clock
	if clock == nil {
		clock = quartz.NewReal()
	}

	runCtx, cancel := context.WithCancel(context.Background())
	ps := &BatchingPubsub{
		logger:            logger,
		delegate:          delegate,
		sender:            sender,
		clock:             clock,
		publishCh:         make(chan queuedPublish, queueSize),
		flushCh:           make(chan struct{}, 1),
		closeCh:           make(chan struct{}),
		doneCh:            make(chan struct{}),
		spaceSignal:       make(chan struct{}),
		warnTicker:        clock.NewTicker(batchingWarnInterval, "pubsubBatcher", "warn"),
		flushInterval:     flushInterval,
		pressureWait:      pressureWait,
		finalFlushTimeout: finalFlushTimeout,
		runCtx:            runCtx,
		cancel:            cancel,
		metrics:           newBatchingMetrics(),
	}
	// Publish the initial (empty) queue depth so the gauge exists at 0.
	ps.metrics.QueueDepth.Set(0)

	go ps.run()
	return ps, nil
}
|
||||
|
||||
// Describe implements prometheus.Collector.
func (p *BatchingPubsub) Describe(descs chan<- *prometheus.Desc) {
	p.metrics.Describe(descs)
}

// Collect implements prometheus.Collector.
func (p *BatchingPubsub) Collect(metrics chan<- prometheus.Metric) {
	p.metrics.Collect(metrics)
}

// Subscribe delegates to the shared PostgreSQL listener pubsub; only the
// publish path is batched.
func (p *BatchingPubsub) Subscribe(event string, listener Listener) (func(), error) {
	return p.delegate.Subscribe(event, listener)
}

// SubscribeWithErr delegates to the shared PostgreSQL listener pubsub.
func (p *BatchingPubsub) SubscribeWithErr(event string, listener ListenerWithErr) (func(), error) {
	return p.delegate.SubscribeWithErr(event, listener)
}
|
||||
|
||||
// Publish enqueues a logical notification for asynchronous batched delivery.
//
// Fast path: a non-blocking enqueue. When the queue is full, the caller
// nudges the run loop to flush (backpressure), waits briefly for space, and
// as a last resort falls back to the shared delegate pubsub so the
// notification is delivered rather than dropped.
func (p *BatchingPubsub) Publish(event string, message []byte) error {
	channelClass := batchChannelClass(event)
	if p.closed.Load() {
		return ErrBatchingPubsubClosed
	}

	// Clone the payload so the caller may reuse its buffer after we return.
	req := queuedPublish{
		event:        event,
		channelClass: channelClass,
		message:      bytes.Clone(message),
	}
	if p.tryEnqueue(req) {
		return nil
	}

	// The timer is started once, so the total extra wait across retries is
	// capped at pressureWait.
	timer := p.clock.NewTimer(p.pressureWait, "pubsubBatcher", "pressureWait")
	defer timer.Stop("pubsubBatcher", "pressureWait")

	for {
		if p.closed.Load() {
			return ErrBatchingPubsubClosed
		}
		p.signalPressureFlush()
		// Snapshot the space signal BEFORE retrying the enqueue so a flush
		// that frees space between the attempt and the select is not missed.
		spaceSignal := p.currentSpaceSignal()
		if p.tryEnqueue(req) {
			return nil
		}

		select {
		case <-spaceSignal:
			continue
		case <-timer.C:
			if p.tryEnqueue(req) {
				return nil
			}
			// The batching queue is still full after a pressure
			// flush and brief wait. Fall back to the shared
			// pubsub pool so the notification is still delivered
			// rather than dropped.
			p.observeDelegateFallback(channelClass, batchDelegateFallbackReasonQueueFull, batchFlushStageNone)
			p.logPublishRejection(event)
			return p.delegate.Publish(event, message)
		case <-p.doneCh:
			return ErrBatchingPubsubClosed
		}
	}
}

// Close stops accepting new publishes, performs a bounded best-effort drain,
// and then closes the dedicated sender connection.
//
// Safe to call multiple times; later calls return the first result.
func (p *BatchingPubsub) Close() error {
	p.closeOnce.Do(func() {
		p.closed.Store(true)
		p.cancel()
		// Wake any publisher blocked on queue space so it observes the
		// closed flag instead of waiting forever.
		p.notifySpaceAvailable()
		close(p.closeCh)
		// Wait for run() to finish its final drain; runErr is written
		// before doneCh is closed.
		<-p.doneCh
		p.closeErr = p.runErr
	})
	return p.closeErr
}
|
||||
|
||||
// tryEnqueue attempts a non-blocking enqueue onto the batching queue,
// returning false when the pubsub is closed or the queue is full.
func (p *BatchingPubsub) tryEnqueue(req queuedPublish) bool {
	if p.closed.Load() {
		return false
	}
	select {
	case p.publishCh <- req:
		queuedDepth := p.queuedCount.Add(1)
		p.observeQueueDepth(queuedDepth)
		return true
	default:
		return false
	}
}

// observeQueueDepth mirrors the queued count into the Prometheus gauge.
func (p *BatchingPubsub) observeQueueDepth(depth int64) {
	p.metrics.QueueDepth.Set(float64(depth))
}

// signalPressureFlush nudges the run loop to flush now. The size-1 buffered
// channel coalesces concurrent signals; a full channel means a flush is
// already pending, so the signal is dropped.
func (p *BatchingPubsub) signalPressureFlush() {
	select {
	case p.flushCh <- struct{}{}:
	default:
	}
}

// currentSpaceSignal returns the channel that will be closed the next time
// queue space is freed. Callers must snapshot it before retrying an enqueue.
func (p *BatchingPubsub) currentSpaceSignal() <-chan struct{} {
	p.spaceMu.Lock()
	defer p.spaceMu.Unlock()
	return p.spaceSignal
}

// notifySpaceAvailable broadcasts that queue space may now be available by
// closing the current signal channel and installing a fresh one for the
// next round of waiters.
func (p *BatchingPubsub) notifySpaceAvailable() {
	p.spaceMu.Lock()
	defer p.spaceMu.Unlock()
	close(p.spaceSignal)
	p.spaceSignal = make(chan struct{})
}
|
||||
|
||||
func batchChannelClass(event string) string {
|
||||
switch {
|
||||
case strings.HasPrefix(event, "chat:stream:"):
|
||||
return batchChannelClassStreamNotify
|
||||
case strings.HasPrefix(event, "chat:owner:"):
|
||||
return batchChannelClassOwnerEvent
|
||||
case event == "chat:config_change":
|
||||
return batchChannelClassConfigChange
|
||||
default:
|
||||
return batchChannelClassOther
|
||||
}
|
||||
}
|
||||
|
||||
func (p *BatchingPubsub) observeDelegateFallback(channelClass string, reason string, stage string) {
|
||||
p.metrics.DelegateFallbacksTotal.WithLabelValues(channelClass, reason, stage).Inc()
|
||||
}
|
||||
|
||||
func (p *BatchingPubsub) observeDelegateFallbackBatch(batch []queuedPublish, reason string, stage string) {
|
||||
if len(batch) == 0 {
|
||||
return
|
||||
}
|
||||
counts := make(map[string]int)
|
||||
for _, item := range batch {
|
||||
counts[item.channelClass]++
|
||||
}
|
||||
for channelClass, count := range counts {
|
||||
p.metrics.DelegateFallbacksTotal.WithLabelValues(channelClass, reason, stage).Add(float64(count))
|
||||
}
|
||||
}
|
||||
|
||||
func batchFlushStage(err error) string {
|
||||
var flushErr *batchFlushError
|
||||
if errors.As(err, &flushErr) {
|
||||
return flushErr.stage
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// run is the single goroutine that owns p.sender. It accumulates queued
// publishes and flushes them on the scheduled timer, on backpressure
// signals, and one final time during shutdown.
func (p *BatchingPubsub) run() {
	defer close(p.doneCh)
	defer p.warnTicker.Stop("pubsubBatcher", "warn")

	batch := make([]queuedPublish, 0, 64)
	timer := p.clock.NewTimer(p.flushInterval, "pubsubBatcher", "scheduledFlush")
	defer timer.Stop("pubsubBatcher", "scheduledFlush")

	// flush drains everything currently queued, writes it as one batch, and
	// restarts the scheduled-flush timer. Flush errors are handled and
	// logged inside flushBatch, hence the discarded error here.
	flush := func(reason string) {
		batch = p.drainIntoBatch(batch)
		batch, _ = p.flushBatch(p.runCtx, batch, reason)
		timer.Reset(p.flushInterval, "pubsubBatcher", reason+"Flush")
	}

	for {
		select {
		case item := <-p.publishCh:
			// An item arrived before the timer fired. Append it and
			// let the timer or pressure signal trigger the actual
			// flush so that nearby publishes coalesce naturally.
			batch = append(batch, item)
			p.notifySpaceAvailable()
		case <-timer.C:
			flush(batchFlushScheduled)
		case <-p.flushCh:
			flush("pressure")
		case <-p.closeCh:
			// Final bounded drain, then close the dedicated sender.
			p.runErr = errors.Join(p.drain(batch), p.sender.Close())
			return
		}
	}
}

// drainIntoBatch moves everything currently buffered in publishCh onto
// batch without blocking, broadcasting to waiting publishers if any queue
// space was freed.
func (p *BatchingPubsub) drainIntoBatch(batch []queuedPublish) []queuedPublish {
	drained := false
	for {
		select {
		case item := <-p.publishCh:
			batch = append(batch, item)
			drained = true
		default:
			if drained {
				p.notifySpaceAvailable()
			}
			return batch
		}
	}
}
|
||||
|
||||
// flushBatch writes the accumulated batch through the dedicated sender. On
// failure it replays every item via the delegate pubsub and (outside of
// shutdown) rebuilds the sender. It always returns the batch truncated to
// zero length so its backing array is reused.
//
// Note: a sender error alone is NOT returned when the delegate replay and
// the sender reset both succeed — the notifications were still delivered.
func (p *BatchingPubsub) flushBatch(
	ctx context.Context,
	batch []queuedPublish,
	reason string,
) ([]queuedPublish, error) {
	if len(batch) == 0 {
		return batch[:0], nil
	}

	count := len(batch)
	totalBytes := 0
	for _, item := range batch {
		totalBytes += len(item.message)
	}

	p.metrics.BatchSize.Observe(float64(count))
	start := p.clock.Now()
	senderErr := p.sender.Flush(ctx, batch)
	elapsed := p.clock.Since(start)
	p.metrics.FlushDuration.WithLabelValues(reason).Observe(elapsed.Seconds())

	var err error
	if senderErr != nil {
		stage := batchFlushStage(senderErr)
		delivered, failed, fallbackErr := p.replayBatchViaDelegate(batch, batchDelegateFallbackReasonFlushError, stage)
		var resetErr error
		// Skip the sender reset during shutdown: the sender is about to
		// be closed anyway.
		if reason != batchFlushShutdown {
			resetErr = p.resetSender()
		}
		p.logFlushFailure(reason, stage, count, totalBytes, delivered, failed, senderErr, fallbackErr, resetErr)
		if fallbackErr != nil || resetErr != nil {
			err = errors.Join(senderErr, fallbackErr, resetErr)
		}
	} else if p.delegate != nil {
		// Account successful batched publishes in the delegate's shared
		// publish metrics so totals stay in one place.
		p.delegate.publishesTotal.WithLabelValues("true").Add(float64(count))
		p.delegate.publishedBytesTotal.Add(float64(totalBytes))
	}

	queuedDepth := p.queuedCount.Add(-int64(count))
	p.observeQueueDepth(queuedDepth)
	// Zero the entries before truncating so message buffers can be GC'd.
	clear(batch)
	return batch[:0], err
}
|
||||
|
||||
func (p *BatchingPubsub) replayBatchViaDelegate(batch []queuedPublish, reason string, stage string) (delivered int, failed int, err error) {
|
||||
if len(batch) == 0 {
|
||||
return 0, 0, nil
|
||||
}
|
||||
p.observeDelegateFallbackBatch(batch, reason, stage)
|
||||
if p.delegate == nil {
|
||||
return 0, len(batch), xerrors.New("delegate pubsub is nil")
|
||||
}
|
||||
|
||||
var errs []error
|
||||
for _, item := range batch {
|
||||
if err := p.delegate.Publish(item.event, item.message); err != nil {
|
||||
failed++
|
||||
errs = append(errs, xerrors.Errorf("delegate publish %q: %w", item.event, err))
|
||||
continue
|
||||
}
|
||||
delivered++
|
||||
}
|
||||
return delivered, failed, errors.Join(errs...)
|
||||
}
|
||||
|
||||
// resetSender replaces the dedicated sender connection after a flush
// failure. It is only called from the run() goroutine, which owns p.sender.
// A nil newSender (e.g. in tests) disables resets.
func (p *BatchingPubsub) resetSender() error {
	if p.newSender == nil {
		return nil
	}
	newSender, err := p.newSender(context.Background())
	if err != nil {
		p.metrics.SenderResetFailuresTotal.Inc()
		return err
	}
	oldSender := p.sender
	p.sender = newSender
	p.metrics.SenderResetsTotal.Inc()
	if oldSender == nil {
		return nil
	}
	// Closing the old sender is best-effort: the replacement is already
	// installed, so a close failure is only logged.
	if err := oldSender.Close(); err != nil {
		p.logger.Warn(context.Background(), "failed to close old batched pubsub sender after reset", slog.Error(err))
	}
	return nil
}
|
||||
|
||||
func (p *BatchingPubsub) logFlushFailure(reason string, stage string, count int, totalBytes int, delivered int, failed int, senderErr error, fallbackErr error, resetErr error) {
|
||||
fields := []slog.Field{
|
||||
slog.F("reason", reason),
|
||||
slog.F("stage", stage),
|
||||
slog.F("count", count),
|
||||
slog.F("total_bytes", totalBytes),
|
||||
slog.F("delegate_delivered", delivered),
|
||||
slog.F("delegate_failed", failed),
|
||||
slog.Error(senderErr),
|
||||
}
|
||||
if fallbackErr != nil {
|
||||
fields = append(fields, slog.F("delegate_error", fallbackErr.Error()))
|
||||
}
|
||||
if resetErr != nil {
|
||||
fields = append(fields, slog.F("sender_reset_error", resetErr.Error()))
|
||||
}
|
||||
p.logger.Error(context.Background(), "batched pubsub flush failed", fields...)
|
||||
}
|
||||
|
||||
// drain repeatedly flushes everything still queued, bounded by the final
// flush timeout, then drops (and reports) anything left over. Called once
// from run() during shutdown.
func (p *BatchingPubsub) drain(batch []queuedPublish) error {
	ctx, cancel := context.WithTimeout(context.Background(), p.finalFlushTimeout)
	defer cancel()

	var errs []error
	for {
		batch = p.drainIntoBatch(batch)
		if len(batch) == 0 {
			break
		}
		var err error
		batch, err = p.flushBatch(ctx, batch, batchFlushShutdown)
		if err != nil {
			errs = append(errs, err)
		}
		// Stop looping once the deadline has hit; leftovers are dropped
		// below and reported as an error.
		if ctx.Err() != nil {
			break
		}
	}

	dropped := p.dropPendingPublishes()
	if dropped > 0 {
		errs = append(errs, xerrors.Errorf("dropped %d queued notifications during shutdown", dropped))
	}
	if ctx.Err() != nil {
		errs = append(errs, xerrors.Errorf("shutdown flush timed out: %w", ctx.Err()))
	}
	return errors.Join(errs...)
}
|
||||
|
||||
func (p *BatchingPubsub) dropPendingPublishes() int {
|
||||
count := 0
|
||||
for {
|
||||
select {
|
||||
case <-p.publishCh:
|
||||
count++
|
||||
default:
|
||||
if count > 0 {
|
||||
queuedDepth := p.queuedCount.Add(-int64(count))
|
||||
p.observeQueueDepth(queuedDepth)
|
||||
}
|
||||
return count
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// logPublishRejection reports a full batching queue. To avoid log spam it
// logs at Warn at most once per warn-ticker interval (when a tick is
// pending) and at Debug otherwise.
func (p *BatchingPubsub) logPublishRejection(event string) {
	fields := []slog.Field{
		slog.F("event", event),
		slog.F("queue_size", cap(p.publishCh)),
		slog.F("queued", p.queuedCount.Load()),
	}
	select {
	case <-p.warnTicker.C:
		p.logger.Warn(context.Background(), "batched pubsub queue is full", fields...)
	default:
		p.logger.Debug(context.Background(), "batched pubsub queue is full", fields...)
	}
}
|
||||
|
||||
// pgBatchSender flushes batches over a dedicated single-connection *sql.DB.
type pgBatchSender struct {
	logger slog.Logger
	db     *sql.DB
}

// newPGBatchSender opens a dedicated database handle for batched pg_notify
// writes, pins it to a single never-expiring connection, and verifies it
// with a bounded ping.
func newPGBatchSender(
	ctx context.Context,
	logger slog.Logger,
	prototype *sql.DB,
	connectURL string,
) (*pgBatchSender, error) {
	connector, err := newConnector(ctx, logger, prototype, connectURL)
	if err != nil {
		return nil, err
	}

	db := sql.OpenDB(connector)
	// Pin the pool to exactly one connection with no idle/lifetime limits
	// so every batch reuses the same session.
	db.SetMaxOpenConns(1)
	db.SetMaxIdleConns(1)
	db.SetConnMaxIdleTime(0)
	db.SetConnMaxLifetime(0)

	pingCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
	defer cancel()
	if err := db.PingContext(pingCtx); err != nil {
		_ = db.Close()
		return nil, xerrors.Errorf("ping batched pubsub sender database: %w", err)
	}

	return &pgBatchSender{logger: logger, db: db}, nil
}
|
||||
|
||||
// Flush delivers the whole batch as pg_notify calls inside one transaction,
// so a batch is committed (and delivered) atomically. Errors are wrapped in
// *batchFlushError tagged with the stage (begin/exec/commit) that failed.
func (s *pgBatchSender) Flush(ctx context.Context, batch []queuedPublish) error {
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return &batchFlushError{stage: batchFlushStageBegin, err: xerrors.Errorf("begin batched pubsub transaction: %w", err)}
	}
	committed := false
	// Roll back on any early return so the pinned connection is left clean.
	defer func() {
		if !committed {
			_ = tx.Rollback()
		}
	}()

	for _, item := range batch {
		// This is safe because we are calling pq.QuoteLiteral. pg_notify does
		// not support the first parameter being a prepared statement.
		//nolint:gosec
		_, err = tx.ExecContext(ctx, `select pg_notify(`+pq.QuoteLiteral(item.event)+`, $1)`, item.message)
		if err != nil {
			return &batchFlushError{stage: batchFlushStageExec, err: xerrors.Errorf("exec pg_notify: %w", err)}
		}
	}

	if err := tx.Commit(); err != nil {
		return &batchFlushError{stage: batchFlushStageCommit, err: xerrors.Errorf("commit batched pubsub transaction: %w", err)}
	}
	committed = true
	return nil
}

// Close releases the dedicated sender connection.
func (s *pgBatchSender) Close() error {
	return s.db.Close()
}
|
||||
@@ -1,520 +0,0 @@
|
||||
package pubsub
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
prom_testutil "github.com/prometheus/client_golang/prometheus/testutil"
|
||||
dto "github.com/prometheus/client_model/go"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
// TestBatchingPubsubScheduledFlush verifies that publishes accumulate until
// the scheduled-flush timer fires, then go out as a single batch with the
// matching size/duration/depth metrics.
func TestBatchingPubsubScheduledFlush(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitShort)
	clock := quartz.NewMock(t)
	newTimerTrap := clock.Trap().NewTimer("pubsubBatcher", "scheduledFlush")
	defer newTimerTrap.Close()
	resetTrap := clock.Trap().TimerReset("pubsubBatcher", "scheduledFlush")
	defer resetTrap.Close()

	sender := newFakeBatchSender()
	ps, _ := newTestBatchingPubsub(t, sender, BatchingConfig{
		Clock:         clock,
		FlushInterval: 10 * time.Millisecond,
		QueueSize:     8,
	})

	// Wait for the run loop to create its scheduled-flush timer so
	// advancing the mock clock is race-free.
	call, err := newTimerTrap.Wait(ctx)
	require.NoError(t, err)
	call.MustRelease(ctx)

	require.NoError(t, ps.Publish("chat:stream:a", []byte("one")))
	require.NoError(t, ps.Publish("chat:stream:a", []byte("two")))
	// Nothing flushes before the timer fires.
	require.Empty(t, sender.Batches())

	clock.Advance(10 * time.Millisecond).MustWait(ctx)
	resetCall, err := resetTrap.Wait(ctx)
	require.NoError(t, err)
	resetCall.MustRelease(ctx)

	batch := testutil.TryReceive(ctx, t, sender.flushes)
	require.Len(t, batch, 2)
	require.Equal(t, []byte("one"), batch[0].message)
	require.Equal(t, []byte("two"), batch[1].message)
	batchSizeCount, batchSizeSum := histogramCountAndSum(t, ps.metrics.BatchSize)
	require.Equal(t, uint64(1), batchSizeCount)
	require.InDelta(t, 2, batchSizeSum, 0.000001)
	flushDurationCount, _ := histogramCountAndSum(t, ps.metrics.FlushDuration.WithLabelValues(batchFlushScheduled))
	require.Equal(t, uint64(1), flushDurationCount)
	require.Zero(t, prom_testutil.ToFloat64(ps.metrics.QueueDepth))
}

// TestBatchingPubsubDefaultConfigUsesDedicatedSenderFirstDefaults verifies
// that a zero-valued BatchingConfig picks up all package defaults.
func TestBatchingPubsubDefaultConfigUsesDedicatedSenderFirstDefaults(t *testing.T) {
	t.Parallel()

	clock := quartz.NewMock(t)
	sender := newFakeBatchSender()
	ps, _ := newTestBatchingPubsub(t, sender, BatchingConfig{Clock: clock})

	require.Equal(t, DefaultBatchingFlushInterval, ps.flushInterval)
	require.Equal(t, DefaultBatchingQueueSize, cap(ps.publishCh))
	require.Equal(t, defaultBatchingPressureWait, ps.pressureWait)
	require.Equal(t, defaultBatchingFinalFlushLimit, ps.finalFlushTimeout)
}
|
||||
|
||||
// TestBatchChannelClass covers the event-name → channel-class mapping,
// including the catch-all "other" class for non-chat events.
func TestBatchChannelClass(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name  string
		event string
		want  string
	}{
		{name: "stream notify", event: "chat:stream:123", want: batchChannelClassStreamNotify},
		{name: "owner event", event: "chat:owner:123", want: batchChannelClassOwnerEvent},
		{name: "config change", event: "chat:config_change", want: batchChannelClassConfigChange},
		{name: "fallback", event: "workspace:owner:123", want: batchChannelClassOther},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			require.Equal(t, tt.want, batchChannelClass(tt.event))
		})
	}
}
|
||||
|
||||
// TestBatchingPubsubTimerFlushDrainsAll verifies that a single timer-driven
// flush drains everything queued since the last flush into one batch.
func TestBatchingPubsubTimerFlushDrainsAll(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitShort)
	clock := quartz.NewMock(t)
	newTimerTrap := clock.Trap().NewTimer("pubsubBatcher", "scheduledFlush")
	defer newTimerTrap.Close()
	resetTrap := clock.Trap().TimerReset("pubsubBatcher", "scheduledFlush")
	defer resetTrap.Close()

	sender := newFakeBatchSender()
	ps, _ := newTestBatchingPubsub(t, sender, BatchingConfig{
		Clock:         clock,
		FlushInterval: 10 * time.Millisecond,
		QueueSize:     64,
	})

	call, err := newTimerTrap.Wait(ctx)
	require.NoError(t, err)
	call.MustRelease(ctx)

	// Enqueue many messages before the timer fires — all should be
	// drained and flushed in a single batch.
	for _, msg := range []string{"one", "two", "three", "four", "five"} {
		require.NoError(t, ps.Publish("chat:stream:a", []byte(msg)))
	}
	require.Empty(t, sender.Batches())

	clock.Advance(10 * time.Millisecond).MustWait(ctx)
	resetCall, err := resetTrap.Wait(ctx)
	require.NoError(t, err)
	resetCall.MustRelease(ctx)

	batch := testutil.TryReceive(ctx, t, sender.flushes)
	require.Len(t, batch, 5)
	require.Equal(t, []byte("one"), batch[0].message)
	require.Equal(t, []byte("five"), batch[4].message)
	require.Zero(t, prom_testutil.ToFloat64(ps.metrics.QueueDepth))
}
|
||||
|
||||
// TestBatchingPubsubQueueFullFallsBackToDelegate verifies that when the
// queue stays full through a pressure flush and the bounded wait, Publish
// falls back to the shared delegate pubsub and records the fallback metric.
func TestBatchingPubsubQueueFullFallsBackToDelegate(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitShort)
	clock := quartz.NewMock(t)
	newTimerTrap := clock.Trap().NewTimer("pubsubBatcher", "scheduledFlush")
	defer newTimerTrap.Close()
	resetTrap := clock.Trap().TimerReset("pubsubBatcher", "scheduledFlush")
	defer resetTrap.Close()
	pressureTrap := clock.Trap().NewTimer("pubsubBatcher", "pressureWait")
	defer pressureTrap.Close()

	sender := newFakeBatchSender()
	// blockCh keeps the sender's flush in-flight so the queue cannot drain.
	sender.blockCh = make(chan struct{})
	ps, _ := newTestBatchingPubsub(t, sender, BatchingConfig{
		Clock:         clock,
		FlushInterval: 10 * time.Millisecond,
		QueueSize:     1,
		PressureWait:  10 * time.Millisecond,
	})

	call, err := newTimerTrap.Wait(ctx)
	require.NoError(t, err)
	call.MustRelease(ctx)

	// Fill the queue (capacity 1).
	require.NoError(t, ps.Publish("chat:stream:a", []byte("one")))

	// Fire the timer so the run loop starts flushing "one" — the
	// sender blocks on blockCh so the flush stays in-flight.
	clock.Advance(10 * time.Millisecond).MustWait(ctx)
	<-sender.started

	// The run loop is blocked in flushBatch. Fill the queue again.
	require.NoError(t, ps.Publish("chat:stream:a", []byte("two")))

	// A third publish should fall back to the delegate (which has a
	// closed db, so the delegate Publish itself will error — but we
	// verify the fallback metric was incremented).
	errCh := make(chan error, 1)
	go func() {
		errCh <- ps.Publish("chat:stream:a", []byte("three"))
	}()

	// Let the publisher's bounded pressure-wait timer fire.
	pressureCall, err := pressureTrap.Wait(ctx)
	require.NoError(t, err)
	pressureCall.MustRelease(ctx)
	clock.Advance(10 * time.Millisecond).MustWait(ctx)

	err = testutil.TryReceive(ctx, t, errCh)
	// The delegate has a closed db so it returns an error from the
	// shared pool, not a batching-specific sentinel.
	require.Error(t, err)
	require.Equal(t, float64(1), prom_testutil.ToFloat64(ps.metrics.DelegateFallbacksTotal.WithLabelValues(batchChannelClassStreamNotify, batchDelegateFallbackReasonQueueFull, batchFlushStageNone)))

	close(sender.blockCh)
	// Let the run loop finish the blocked flush and process "two".
	resetCall, err := resetTrap.Wait(ctx)
	require.NoError(t, err)
	resetCall.MustRelease(ctx)
	require.NoError(t, ps.Close())
}
|
||||
|
||||
// TestBatchingPubsubCloseDrainsQueue verifies that Close flushes everything
// still queued (with an effectively-disabled scheduled flush) and closes
// the sender exactly once.
func TestBatchingPubsubCloseDrainsQueue(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitShort)
	clock := quartz.NewMock(t)
	newTimerTrap := clock.Trap().NewTimer("pubsubBatcher", "scheduledFlush")
	defer newTimerTrap.Close()

	sender := newFakeBatchSender()
	// FlushInterval of an hour ensures only the shutdown drain flushes.
	ps, _ := newTestBatchingPubsub(t, sender, BatchingConfig{
		Clock:         clock,
		FlushInterval: time.Hour,
		QueueSize:     8,
	})

	call, err := newTimerTrap.Wait(ctx)
	require.NoError(t, err)
	call.MustRelease(ctx)

	require.NoError(t, ps.Publish("chat:stream:a", []byte("one")))
	require.NoError(t, ps.Publish("chat:stream:a", []byte("two")))
	require.NoError(t, ps.Publish("chat:stream:a", []byte("three")))

	require.NoError(t, ps.Close())
	batches := sender.Batches()
	require.Len(t, batches, 1)
	require.Len(t, batches[0], 3)
	require.Equal(t, []byte("one"), batches[0][0].message)
	require.Equal(t, []byte("two"), batches[0][1].message)
	require.Equal(t, []byte("three"), batches[0][2].message)
	require.Zero(t, prom_testutil.ToFloat64(ps.metrics.QueueDepth))
	require.Equal(t, 1, sender.CloseCalls())
}

// TestBatchingPubsubPreservesOrder verifies that publish order is preserved
// across however many batches the shutdown drain produces.
func TestBatchingPubsubPreservesOrder(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitShort)
	clock := quartz.NewMock(t)
	newTimerTrap := clock.Trap().NewTimer("pubsubBatcher", "scheduledFlush")
	defer newTimerTrap.Close()

	sender := newFakeBatchSender()
	ps, _ := newTestBatchingPubsub(t, sender, BatchingConfig{
		Clock:         clock,
		FlushInterval: time.Hour,
		QueueSize:     8,
	})

	call, err := newTimerTrap.Wait(ctx)
	require.NoError(t, err)
	call.MustRelease(ctx)

	for _, msg := range []string{"one", "two", "three", "four", "five"} {
		require.NoError(t, ps.Publish("chat:stream:a", []byte(msg)))
	}

	require.NoError(t, ps.Close())
	batches := sender.Batches()
	require.NotEmpty(t, batches)

	// Flatten all batches; the concatenation must match publish order.
	messages := make([]string, 0, 5)
	for _, batch := range batches {
		for _, item := range batch {
			messages = append(messages, string(item.message))
		}
	}
	require.Equal(t, []string{"one", "two", "three", "four", "five"}, messages)
}
|
||||
|
||||
func TestBatchingPubsubFlushFailureMetrics(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
clock := quartz.NewMock(t)
|
||||
newTimerTrap := clock.Trap().NewTimer("pubsubBatcher", "scheduledFlush")
|
||||
defer newTimerTrap.Close()
|
||||
resetTrap := clock.Trap().TimerReset("pubsubBatcher", "scheduledFlush")
|
||||
defer resetTrap.Close()
|
||||
|
||||
sender := newFakeBatchSender()
|
||||
sender.err = context.DeadlineExceeded
|
||||
sender.errStage = batchFlushStageExec
|
||||
ps, delegate := newTestBatchingPubsub(t, sender, BatchingConfig{
|
||||
Clock: clock,
|
||||
FlushInterval: 10 * time.Millisecond,
|
||||
QueueSize: 8,
|
||||
})
|
||||
|
||||
call, err := newTimerTrap.Wait(ctx)
|
||||
require.NoError(t, err)
|
||||
call.MustRelease(ctx)
|
||||
|
||||
require.NoError(t, ps.Publish("chat:stream:a", []byte("one")))
|
||||
|
||||
clock.Advance(10 * time.Millisecond).MustWait(ctx)
|
||||
resetCall, err := resetTrap.Wait(ctx)
|
||||
require.NoError(t, err)
|
||||
resetCall.MustRelease(ctx)
|
||||
|
||||
batchSizeCount, batchSizeSum := histogramCountAndSum(t, ps.metrics.BatchSize)
|
||||
require.Equal(t, uint64(1), batchSizeCount)
|
||||
require.InDelta(t, 1, batchSizeSum, 0.000001)
|
||||
flushDurationCount, _ := histogramCountAndSum(t, ps.metrics.FlushDuration.WithLabelValues(batchFlushScheduled))
|
||||
require.Equal(t, uint64(1), flushDurationCount)
|
||||
require.Equal(t, float64(1), prom_testutil.ToFloat64(delegate.publishesTotal.WithLabelValues("false")))
|
||||
require.Zero(t, prom_testutil.ToFloat64(delegate.publishesTotal.WithLabelValues("true")))
|
||||
require.Zero(t, prom_testutil.ToFloat64(ps.metrics.QueueDepth))
|
||||
require.Equal(t, float64(1), prom_testutil.ToFloat64(ps.metrics.DelegateFallbacksTotal.WithLabelValues(batchChannelClassStreamNotify, batchDelegateFallbackReasonFlushError, batchFlushStageExec)))
|
||||
}
|
||||
|
||||
func TestBatchingPubsubFlushFailureStageAccounting(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
stages := []string{batchFlushStageBegin, batchFlushStageExec, batchFlushStageCommit}
|
||||
for _, stage := range stages {
|
||||
stage := stage
|
||||
t.Run(stage, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sender := newFakeBatchSender()
|
||||
sender.err = context.DeadlineExceeded
|
||||
sender.errStage = stage
|
||||
ps, delegate := newTestBatchingPubsub(t, sender, BatchingConfig{Clock: quartz.NewMock(t)})
|
||||
|
||||
batch := []queuedPublish{{
|
||||
event: "chat:stream:test",
|
||||
channelClass: batchChannelClass("chat:stream:test"),
|
||||
message: []byte("fallback-" + stage),
|
||||
}}
|
||||
ps.queuedCount.Store(int64(len(batch)))
|
||||
_, err := ps.flushBatch(context.Background(), batch, batchFlushScheduled)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, float64(1), prom_testutil.ToFloat64(ps.metrics.DelegateFallbacksTotal.WithLabelValues(batchChannelClassStreamNotify, batchDelegateFallbackReasonFlushError, stage)))
|
||||
require.Equal(t, float64(1), prom_testutil.ToFloat64(delegate.publishesTotal.WithLabelValues("false")))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBatchingPubsubFlushFailureResetSender(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
clock := quartz.NewMock(t)
|
||||
firstSender := newFakeBatchSender()
|
||||
firstSender.err = context.DeadlineExceeded
|
||||
firstSender.errStage = batchFlushStageExec
|
||||
secondSender := newFakeBatchSender()
|
||||
ps, _ := newTestBatchingPubsub(t, firstSender, BatchingConfig{Clock: clock})
|
||||
ps.newSender = func(context.Context) (batchSender, error) {
|
||||
return secondSender, nil
|
||||
}
|
||||
|
||||
firstBatch := []queuedPublish{{
|
||||
event: "chat:stream:first",
|
||||
channelClass: batchChannelClass("chat:stream:first"),
|
||||
message: []byte("first"),
|
||||
}}
|
||||
ps.queuedCount.Store(int64(len(firstBatch)))
|
||||
_, err := ps.flushBatch(context.Background(), firstBatch, batchFlushScheduled)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, float64(1), prom_testutil.ToFloat64(ps.metrics.SenderResetsTotal))
|
||||
require.Equal(t, 1, firstSender.CloseCalls())
|
||||
|
||||
secondBatch := []queuedPublish{{
|
||||
event: "chat:stream:second",
|
||||
channelClass: batchChannelClass("chat:stream:second"),
|
||||
message: []byte("second"),
|
||||
}}
|
||||
ps.queuedCount.Store(int64(len(secondBatch)))
|
||||
_, err = ps.flushBatch(context.Background(), secondBatch, batchFlushScheduled)
|
||||
require.NoError(t, err)
|
||||
batches := secondSender.Batches()
|
||||
require.Len(t, batches, 1)
|
||||
require.Len(t, batches[0], 1)
|
||||
require.Equal(t, []byte("second"), batches[0][0].message)
|
||||
}
|
||||
|
||||
func TestBatchingPubsubFlushFailureReturnsJoinedErrorWhenReplayFails(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sender := newFakeBatchSender()
|
||||
sender.err = context.DeadlineExceeded
|
||||
sender.errStage = batchFlushStageExec
|
||||
ps, _ := newTestBatchingPubsub(t, sender, BatchingConfig{Clock: quartz.NewMock(t)})
|
||||
|
||||
batch := []queuedPublish{{
|
||||
event: "chat:stream:error",
|
||||
channelClass: batchChannelClass("chat:stream:error"),
|
||||
message: []byte("error"),
|
||||
}}
|
||||
ps.queuedCount.Store(int64(len(batch)))
|
||||
_, err := ps.flushBatch(context.Background(), batch, batchFlushScheduled)
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, context.DeadlineExceeded.Error())
|
||||
require.ErrorContains(t, err, `delegate publish "chat:stream:error"`)
|
||||
require.Equal(t, float64(1), prom_testutil.ToFloat64(ps.metrics.DelegateFallbacksTotal.WithLabelValues(batchChannelClassStreamNotify, batchDelegateFallbackReasonFlushError, batchFlushStageExec)))
|
||||
}
|
||||
|
||||
func newTestBatchingPubsub(t *testing.T, sender batchSender, cfg BatchingConfig) (*BatchingPubsub, *PGPubsub) {
|
||||
t.Helper()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
// Use a closed *sql.DB so that delegate.Publish returns a real
|
||||
// error instead of panicking on a nil pointer when the batching
|
||||
// queue falls back to the shared pool under pressure.
|
||||
closedDB := newClosedDB(t)
|
||||
delegate := newWithoutListener(logger.Named("delegate"), closedDB)
|
||||
ps, err := newBatchingPubsub(logger.Named("batcher"), delegate, sender, cfg)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = ps.Close()
|
||||
})
|
||||
return ps, delegate
|
||||
}
|
||||
|
||||
// newClosedDB returns an *sql.DB whose connections have been closed,
|
||||
// so any ExecContext call returns an error rather than panicking.
|
||||
func newClosedDB(t *testing.T) *sql.DB {
|
||||
t.Helper()
|
||||
db, err := sql.Open("postgres", "host=localhost dbname=closed_db_stub sslmode=disable connect_timeout=1")
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, db.Close())
|
||||
return db
|
||||
}
|
||||
|
||||
type fakeBatchSender struct {
|
||||
mu sync.Mutex
|
||||
batches [][]queuedPublish
|
||||
flushes chan []queuedPublish
|
||||
started chan struct{}
|
||||
blockCh chan struct{}
|
||||
err error
|
||||
errStage string
|
||||
closeErr error
|
||||
closeCall int
|
||||
}
|
||||
|
||||
func newFakeBatchSender() *fakeBatchSender {
|
||||
return &fakeBatchSender{
|
||||
flushes: make(chan []queuedPublish, 16),
|
||||
started: make(chan struct{}, 16),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *fakeBatchSender) Flush(ctx context.Context, batch []queuedPublish) error {
|
||||
select {
|
||||
case s.started <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
if s.blockCh != nil {
|
||||
select {
|
||||
case <-s.blockCh:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
clone := make([]queuedPublish, len(batch))
|
||||
for i, item := range batch {
|
||||
clone[i] = queuedPublish{
|
||||
event: item.event,
|
||||
message: bytes.Clone(item.message),
|
||||
}
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.batches = append(s.batches, clone)
|
||||
s.mu.Unlock()
|
||||
|
||||
select {
|
||||
case s.flushes <- clone:
|
||||
default:
|
||||
}
|
||||
if s.err == nil {
|
||||
return nil
|
||||
}
|
||||
if s.errStage != "" {
|
||||
return &batchFlushError{stage: s.errStage, err: s.err}
|
||||
}
|
||||
return s.err
|
||||
}
|
||||
|
||||
type metricWriter interface {
|
||||
Write(*dto.Metric) error
|
||||
}
|
||||
|
||||
func histogramCountAndSum(t *testing.T, observer any) (uint64, float64) {
|
||||
t.Helper()
|
||||
writer, ok := observer.(metricWriter)
|
||||
require.True(t, ok)
|
||||
|
||||
metric := &dto.Metric{}
|
||||
require.NoError(t, writer.Write(metric))
|
||||
histogram := metric.GetHistogram()
|
||||
require.NotNil(t, histogram)
|
||||
return histogram.GetSampleCount(), histogram.GetSampleSum()
|
||||
}
|
||||
|
||||
func (s *fakeBatchSender) Close() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.closeCall++
|
||||
return s.closeErr
|
||||
}
|
||||
|
||||
func (s *fakeBatchSender) Batches() [][]queuedPublish {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
clone := make([][]queuedPublish, len(s.batches))
|
||||
for i, batch := range s.batches {
|
||||
clone[i] = make([]queuedPublish, len(batch))
|
||||
copy(clone[i], batch)
|
||||
}
|
||||
return clone
|
||||
}
|
||||
|
||||
func (s *fakeBatchSender) CloseCalls() int {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.closeCall
|
||||
}
|
||||
@@ -1,130 +0,0 @@
|
||||
package pubsub_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestBatchingPubsubDedicatedSenderConnection(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
|
||||
connectionURL, err := dbtestutil.Open(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
trackedDriver := dbtestutil.NewDriver()
|
||||
defer trackedDriver.Close()
|
||||
tconn, err := trackedDriver.Connector(connectionURL)
|
||||
require.NoError(t, err)
|
||||
trackedDB := sql.OpenDB(tconn)
|
||||
defer trackedDB.Close()
|
||||
|
||||
base, err := pubsub.New(ctx, logger.Named("base"), trackedDB, connectionURL)
|
||||
require.NoError(t, err)
|
||||
defer base.Close()
|
||||
|
||||
listenerConn := testutil.TryReceive(ctx, t, trackedDriver.Connections)
|
||||
batched, err := pubsub.NewBatching(ctx, logger.Named("batched"), base, trackedDB, connectionURL, pubsub.BatchingConfig{
|
||||
FlushInterval: 10 * time.Millisecond,
|
||||
|
||||
QueueSize: 8,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer batched.Close()
|
||||
|
||||
senderConn := testutil.TryReceive(ctx, t, trackedDriver.Connections)
|
||||
require.NotEqual(t, fmt.Sprintf("%p", listenerConn), fmt.Sprintf("%p", senderConn))
|
||||
|
||||
event := t.Name()
|
||||
messageCh := make(chan []byte, 1)
|
||||
cancel, err := batched.Subscribe(event, func(_ context.Context, message []byte) {
|
||||
messageCh <- message
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer cancel()
|
||||
|
||||
require.NoError(t, batched.Publish(event, []byte("hello")))
|
||||
require.Equal(t, []byte("hello"), testutil.TryReceive(ctx, t, messageCh))
|
||||
}
|
||||
|
||||
func TestBatchingPubsubReconnectsAfterSenderDisconnect(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
|
||||
connectionURL, err := dbtestutil.Open(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
trackedDriver := dbtestutil.NewDriver()
|
||||
defer trackedDriver.Close()
|
||||
tconn, err := trackedDriver.Connector(connectionURL)
|
||||
require.NoError(t, err)
|
||||
trackedDB := sql.OpenDB(tconn)
|
||||
defer trackedDB.Close()
|
||||
|
||||
base, err := pubsub.New(ctx, logger.Named("base"), trackedDB, connectionURL)
|
||||
require.NoError(t, err)
|
||||
defer base.Close()
|
||||
|
||||
_ = testutil.TryReceive(ctx, t, trackedDriver.Connections) // listener connection
|
||||
batched, err := pubsub.NewBatching(ctx, logger.Named("batched"), base, trackedDB, connectionURL, pubsub.BatchingConfig{
|
||||
FlushInterval: 10 * time.Millisecond,
|
||||
|
||||
QueueSize: 8,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer batched.Close()
|
||||
|
||||
senderConn := testutil.TryReceive(ctx, t, trackedDriver.Connections)
|
||||
event := t.Name()
|
||||
messageCh := make(chan []byte, 4)
|
||||
cancel, err := batched.Subscribe(event, func(_ context.Context, message []byte) {
|
||||
messageCh <- message
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer cancel()
|
||||
|
||||
require.NoError(t, batched.Publish(event, []byte("before-disconnect")))
|
||||
require.Equal(t, []byte("before-disconnect"), testutil.TryReceive(ctx, t, messageCh))
|
||||
require.NoError(t, senderConn.Close())
|
||||
|
||||
reconnected := false
|
||||
delivered := false
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
if !reconnected {
|
||||
select {
|
||||
case conn := <-trackedDriver.Connections:
|
||||
reconnected = conn != nil
|
||||
default:
|
||||
}
|
||||
}
|
||||
select {
|
||||
case <-messageCh:
|
||||
default:
|
||||
}
|
||||
if err := batched.Publish(event, []byte("after-disconnect")); err != nil {
|
||||
return false
|
||||
}
|
||||
select {
|
||||
case msg := <-messageCh:
|
||||
delivered = string(msg) == "after-disconnect"
|
||||
case <-time.After(testutil.IntervalFast):
|
||||
delivered = false
|
||||
}
|
||||
return reconnected && delivered
|
||||
}, testutil.IntervalMedium, "batched sender did not recover after disconnect")
|
||||
}
|
||||
@@ -487,14 +487,12 @@ func (d logDialer) DialContext(ctx context.Context, network, address string) (ne
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func newConnector(ctx context.Context, logger slog.Logger, db *sql.DB, connectURL string) (driver.Connector, error) {
|
||||
if db == nil {
|
||||
return nil, xerrors.New("database is nil")
|
||||
}
|
||||
|
||||
func (p *PGPubsub) startListener(ctx context.Context, connectURL string) error {
|
||||
p.connected.Set(0)
|
||||
// Creates a new listener using pq.
|
||||
var (
|
||||
dialer = logDialer{
|
||||
logger: logger,
|
||||
logger: p.logger,
|
||||
// pq.defaultDialer uses a zero net.Dialer as well.
|
||||
d: net.Dialer{},
|
||||
}
|
||||
@@ -503,38 +501,28 @@ func newConnector(ctx context.Context, logger slog.Logger, db *sql.DB, connectUR
|
||||
)
|
||||
|
||||
// Create a custom connector if the database driver supports it.
|
||||
connectorCreator, ok := db.Driver().(database.ConnectorCreator)
|
||||
connectorCreator, ok := p.db.Driver().(database.ConnectorCreator)
|
||||
if ok {
|
||||
connector, err = connectorCreator.Connector(connectURL)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("create custom connector: %w", err)
|
||||
return xerrors.Errorf("create custom connector: %w", err)
|
||||
}
|
||||
} else {
|
||||
// Use the default pq connector otherwise.
|
||||
// use the default pq connector otherwise
|
||||
connector, err = pq.NewConnector(connectURL)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("create pq connector: %w", err)
|
||||
return xerrors.Errorf("create pq connector: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Set the dialer if the connector supports it.
|
||||
dc, ok := connector.(database.DialerConnector)
|
||||
if !ok {
|
||||
logger.Critical(ctx, "connector does not support setting log dialer, database connection debug logs will be missing")
|
||||
p.logger.Critical(ctx, "connector does not support setting log dialer, database connection debug logs will be missing")
|
||||
} else {
|
||||
dc.Dialer(dialer)
|
||||
}
|
||||
|
||||
return connector, nil
|
||||
}
|
||||
|
||||
func (p *PGPubsub) startListener(ctx context.Context, connectURL string) error {
|
||||
p.connected.Set(0)
|
||||
connector, err := newConnector(ctx, p.logger, p.db, connectURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var (
|
||||
errCh = make(chan error, 1)
|
||||
sentErrCh = false
|
||||
|
||||
@@ -128,22 +128,6 @@ type sqlcQuerier interface {
|
||||
// connection events (connect, disconnect, open, close) which are handled
|
||||
// separately by DeleteOldAuditLogConnectionEvents.
|
||||
DeleteOldAuditLogs(ctx context.Context, arg DeleteOldAuditLogsParams) (int64, error)
|
||||
// TODO(cian): Add indexes on chats(archived, updated_at) and
|
||||
// chat_files(created_at) for purge query performance.
|
||||
// See: https://github.com/coder/internal/issues/1438
|
||||
// Deletes chat files that are older than the given threshold and are
|
||||
// not referenced by any chat that is still active or was archived
|
||||
// within the same threshold window. This covers two cases:
|
||||
// 1. Orphaned files not linked to any chat.
|
||||
// 2. Files whose every referencing chat has been archived for longer
|
||||
// than the retention period.
|
||||
DeleteOldChatFiles(ctx context.Context, arg DeleteOldChatFilesParams) (int64, error)
|
||||
// Deletes chats that have been archived for longer than the given
|
||||
// threshold. Active (non-archived) chats are never deleted.
|
||||
// Related chat_messages, chat_diff_statuses, and
|
||||
// chat_queued_messages are removed via ON DELETE CASCADE.
|
||||
// Parent/root references on child chats are SET NULL.
|
||||
DeleteOldChats(ctx context.Context, arg DeleteOldChatsParams) (int64, error)
|
||||
DeleteOldConnectionLogs(ctx context.Context, arg DeleteOldConnectionLogsParams) (int64, error)
|
||||
// Delete all notification messages which have not been updated for over a week.
|
||||
DeleteOldNotificationMessages(ctx context.Context) error
|
||||
@@ -168,7 +152,7 @@ type sqlcQuerier interface {
|
||||
DeleteTask(ctx context.Context, arg DeleteTaskParams) (uuid.UUID, error)
|
||||
DeleteUserChatCompactionThreshold(ctx context.Context, arg DeleteUserChatCompactionThresholdParams) error
|
||||
DeleteUserChatProviderKey(ctx context.Context, arg DeleteUserChatProviderKeyParams) error
|
||||
DeleteUserSecretByUserIDAndName(ctx context.Context, arg DeleteUserSecretByUserIDAndNameParams) error
|
||||
DeleteUserSecret(ctx context.Context, id uuid.UUID) error
|
||||
DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg DeleteWebpushSubscriptionByUserIDAndEndpointParams) error
|
||||
DeleteWebpushSubscriptions(ctx context.Context, ids []uuid.UUID) error
|
||||
DeleteWorkspaceACLByID(ctx context.Context, id uuid.UUID) error
|
||||
@@ -271,27 +255,16 @@ type sqlcQuerier interface {
|
||||
// otherwise the setting defaults to true.
|
||||
GetChatIncludeDefaultSystemPrompt(ctx context.Context) (bool, error)
|
||||
GetChatMessageByID(ctx context.Context, id int64) (ChatMessage, error)
|
||||
// Aggregates message-level metrics per chat for messages created
|
||||
// after the given timestamp. Uses message created_at so that
|
||||
// ongoing activity in long-running chats is captured each window.
|
||||
GetChatMessageSummariesPerChat(ctx context.Context, createdAfter time.Time) ([]GetChatMessageSummariesPerChatRow, error)
|
||||
GetChatMessagesByChatID(ctx context.Context, arg GetChatMessagesByChatIDParams) ([]ChatMessage, error)
|
||||
GetChatMessagesByChatIDAscPaginated(ctx context.Context, arg GetChatMessagesByChatIDAscPaginatedParams) ([]ChatMessage, error)
|
||||
GetChatMessagesByChatIDDescPaginated(ctx context.Context, arg GetChatMessagesByChatIDDescPaginatedParams) ([]ChatMessage, error)
|
||||
GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]ChatMessage, error)
|
||||
GetChatModelConfigByID(ctx context.Context, id uuid.UUID) (ChatModelConfig, error)
|
||||
GetChatModelConfigs(ctx context.Context) ([]ChatModelConfig, error)
|
||||
// Returns all model configurations for telemetry snapshot collection.
|
||||
GetChatModelConfigsForTelemetry(ctx context.Context) ([]GetChatModelConfigsForTelemetryRow, error)
|
||||
GetChatProviderByID(ctx context.Context, id uuid.UUID) (ChatProvider, error)
|
||||
GetChatProviderByProvider(ctx context.Context, provider string) (ChatProvider, error)
|
||||
GetChatProviders(ctx context.Context) ([]ChatProvider, error)
|
||||
GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]ChatQueuedMessage, error)
|
||||
// Returns the chat retention period in days. Chats archived longer
|
||||
// than this and orphaned chat files older than this are purged by
|
||||
// dbpurge. Returns 30 (days) when no value has been configured.
|
||||
// A value of 0 disables chat purging entirely.
|
||||
GetChatRetentionDays(ctx context.Context) (int32, error)
|
||||
GetChatSystemPrompt(ctx context.Context) (string, error)
|
||||
// GetChatSystemPromptConfig returns both chat system prompt settings in a
|
||||
// single read to avoid torn reads between separate site-config lookups.
|
||||
@@ -310,10 +283,6 @@ type sqlcQuerier interface {
|
||||
GetChatWorkspaceTTL(ctx context.Context) (string, error)
|
||||
GetChats(ctx context.Context, arg GetChatsParams) ([]GetChatsRow, error)
|
||||
GetChatsByWorkspaceIDs(ctx context.Context, ids []uuid.UUID) ([]Chat, error)
|
||||
// Retrieves chats updated after the given timestamp for telemetry
|
||||
// snapshot collection. Uses updated_at so that long-running chats
|
||||
// still appear in each snapshot window while they are active.
|
||||
GetChatsUpdatedAfter(ctx context.Context, updatedAfter time.Time) ([]GetChatsUpdatedAfterRow, error)
|
||||
GetConnectionLogsOffset(ctx context.Context, arg GetConnectionLogsOffsetParams) ([]GetConnectionLogsOffsetRow, error)
|
||||
GetCryptoKeyByFeatureAndSequence(ctx context.Context, arg GetCryptoKeyByFeatureAndSequenceParams) (CryptoKey, error)
|
||||
GetCryptoKeys(ctx context.Context) ([]CryptoKey, error)
|
||||
@@ -509,10 +478,8 @@ type sqlcQuerier interface {
|
||||
GetReplicasUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]Replica, error)
|
||||
GetRunningPrebuiltWorkspaces(ctx context.Context) ([]GetRunningPrebuiltWorkspacesRow, error)
|
||||
GetRuntimeConfig(ctx context.Context, key string) (string, error)
|
||||
// Find chats that appear stuck and need recovery. This covers:
|
||||
// 1. Running chats whose heartbeat has expired (worker crash).
|
||||
// 2. Chats awaiting client action (requires_action) past the
|
||||
// timeout threshold (client disappeared).
|
||||
// Find chats that appear stuck (running but heartbeat has expired).
|
||||
// Used for recovery after coderd crashes or long hangs.
|
||||
GetStaleChats(ctx context.Context, staleThreshold time.Time) ([]Chat, error)
|
||||
GetTailnetPeers(ctx context.Context, id uuid.UUID) ([]TailnetPeer, error)
|
||||
GetTailnetTunnelPeerBindingsBatch(ctx context.Context, ids []uuid.UUID) ([]GetTailnetTunnelPeerBindingsBatchRow, error)
|
||||
@@ -631,6 +598,7 @@ type sqlcQuerier interface {
|
||||
GetUserLinkByUserIDLoginType(ctx context.Context, arg GetUserLinkByUserIDLoginTypeParams) (UserLink, error)
|
||||
GetUserLinksByUserID(ctx context.Context, userID uuid.UUID) ([]UserLink, error)
|
||||
GetUserNotificationPreferences(ctx context.Context, userID uuid.UUID) ([]NotificationPreference, error)
|
||||
GetUserSecret(ctx context.Context, id uuid.UUID) (UserSecret, error)
|
||||
GetUserSecretByUserIDAndName(ctx context.Context, arg GetUserSecretByUserIDAndNameParams) (UserSecret, error)
|
||||
// GetUserStatusCounts returns the count of users in each status over time.
|
||||
// The time range is inclusively defined by the start_time and end_time parameters.
|
||||
@@ -850,13 +818,7 @@ type sqlcQuerier interface {
|
||||
ListProvisionerKeysByOrganizationExcludeReserved(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerKey, error)
|
||||
ListTasks(ctx context.Context, arg ListTasksParams) ([]Task, error)
|
||||
ListUserChatCompactionThresholds(ctx context.Context, userID uuid.UUID) ([]UserConfig, error)
|
||||
// Returns metadata only (no value or value_key_id) for the
|
||||
// REST API list and get endpoints.
|
||||
ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]ListUserSecretsRow, error)
|
||||
// Returns all columns including the secret value. Used by the
|
||||
// provisioner (build-time injection) and the agent manifest
|
||||
// (runtime injection).
|
||||
ListUserSecretsWithValues(ctx context.Context, userID uuid.UUID) ([]UserSecret, error)
|
||||
ListUserSecrets(ctx context.Context, userID uuid.UUID) ([]UserSecret, error)
|
||||
ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]WorkspaceAgentPortShare, error)
|
||||
MarkAllInboxNotificationsAsRead(ctx context.Context, arg MarkAllInboxNotificationsAsReadParams) error
|
||||
OIDCClaimFieldValues(ctx context.Context, arg OIDCClaimFieldValuesParams) ([]string, error)
|
||||
@@ -898,10 +860,6 @@ type sqlcQuerier interface {
|
||||
// This must be called from within a transaction. The lock will be automatically
|
||||
// released when the transaction ends.
|
||||
TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error)
|
||||
// Unarchives a chat (and its children). Stale file references are
|
||||
// handled automatically by FK cascades on chat_file_links: when
|
||||
// dbpurge deletes a chat_files row, the corresponding
|
||||
// chat_file_links rows are cascade-deleted by PostgreSQL.
|
||||
UnarchiveChatByID(ctx context.Context, id uuid.UUID) ([]Chat, error)
|
||||
// This will always work regardless of the current state of the template version.
|
||||
UnarchiveTemplateVersion(ctx context.Context, arg UnarchiveTemplateVersionParams) error
|
||||
@@ -999,7 +957,7 @@ type sqlcQuerier interface {
|
||||
UpdateUserProfile(ctx context.Context, arg UpdateUserProfileParams) (User, error)
|
||||
UpdateUserQuietHoursSchedule(ctx context.Context, arg UpdateUserQuietHoursScheduleParams) (User, error)
|
||||
UpdateUserRoles(ctx context.Context, arg UpdateUserRolesParams) (User, error)
|
||||
UpdateUserSecretByUserIDAndName(ctx context.Context, arg UpdateUserSecretByUserIDAndNameParams) (UserSecret, error)
|
||||
UpdateUserSecret(ctx context.Context, arg UpdateUserSecretParams) (UserSecret, error)
|
||||
UpdateUserStatus(ctx context.Context, arg UpdateUserStatusParams) (User, error)
|
||||
UpdateUserTaskNotificationAlertDismissed(ctx context.Context, arg UpdateUserTaskNotificationAlertDismissedParams) (bool, error)
|
||||
UpdateUserTerminalFont(ctx context.Context, arg UpdateUserTerminalFontParams) (UserConfig, error)
|
||||
@@ -1043,7 +1001,6 @@ type sqlcQuerier interface {
|
||||
UpsertChatDiffStatus(ctx context.Context, arg UpsertChatDiffStatusParams) (ChatDiffStatus, error)
|
||||
UpsertChatDiffStatusReference(ctx context.Context, arg UpsertChatDiffStatusReferenceParams) (ChatDiffStatus, error)
|
||||
UpsertChatIncludeDefaultSystemPrompt(ctx context.Context, includeDefaultSystemPrompt bool) error
|
||||
UpsertChatRetentionDays(ctx context.Context, retentionDays int32) error
|
||||
UpsertChatSystemPrompt(ctx context.Context, value string) error
|
||||
UpsertChatTemplateAllowlist(ctx context.Context, templateAllowlist string) error
|
||||
UpsertChatUsageLimitConfig(ctx context.Context, arg UpsertChatUsageLimitConfigParams) (ChatUsageLimitConfig, error)
|
||||
|
||||
@@ -7339,7 +7339,13 @@ func TestUserSecretsCRUDOperations(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, secretID, createdSecret.ID)
|
||||
|
||||
// 2. READ by UserID and Name
|
||||
// 2. READ by ID
|
||||
readSecret, err := db.GetUserSecret(ctx, createdSecret.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, createdSecret.ID, readSecret.ID)
|
||||
assert.Equal(t, "workflow-secret", readSecret.Name)
|
||||
|
||||
// 3. READ by UserID and Name
|
||||
readByNameParams := database.GetUserSecretByUserIDAndNameParams{
|
||||
UserID: testUser.ID,
|
||||
Name: "workflow-secret",
|
||||
@@ -7347,43 +7353,33 @@ func TestUserSecretsCRUDOperations(t *testing.T) {
|
||||
readByNameSecret, err := db.GetUserSecretByUserIDAndName(ctx, readByNameParams)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, createdSecret.ID, readByNameSecret.ID)
|
||||
assert.Equal(t, "workflow-secret", readByNameSecret.Name)
|
||||
|
||||
// 3. LIST (metadata only)
|
||||
// 4. LIST
|
||||
secrets, err := db.ListUserSecrets(ctx, testUser.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, secrets, 1)
|
||||
assert.Equal(t, createdSecret.ID, secrets[0].ID)
|
||||
|
||||
// 4. LIST with values
|
||||
secretsWithValues, err := db.ListUserSecretsWithValues(ctx, testUser.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, secretsWithValues, 1)
|
||||
assert.Equal(t, "workflow-value", secretsWithValues[0].Value)
|
||||
|
||||
// 5. UPDATE (partial - only description)
|
||||
updateParams := database.UpdateUserSecretByUserIDAndNameParams{
|
||||
UserID: testUser.ID,
|
||||
Name: "workflow-secret",
|
||||
UpdateDescription: true,
|
||||
Description: "Updated workflow description",
|
||||
// 5. UPDATE
|
||||
updateParams := database.UpdateUserSecretParams{
|
||||
ID: createdSecret.ID,
|
||||
Description: "Updated workflow description",
|
||||
Value: "updated-workflow-value",
|
||||
EnvName: "UPDATED_WORKFLOW_ENV",
|
||||
FilePath: "/updated/workflow/path",
|
||||
}
|
||||
|
||||
updatedSecret, err := db.UpdateUserSecretByUserIDAndName(ctx, updateParams)
|
||||
updatedSecret, err := db.UpdateUserSecret(ctx, updateParams)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Updated workflow description", updatedSecret.Description)
|
||||
assert.Equal(t, "workflow-value", updatedSecret.Value) // Value unchanged
|
||||
assert.Equal(t, "WORKFLOW_ENV", updatedSecret.EnvName) // EnvName unchanged
|
||||
assert.Equal(t, "updated-workflow-value", updatedSecret.Value)
|
||||
|
||||
// 6. DELETE
|
||||
err = db.DeleteUserSecretByUserIDAndName(ctx, database.DeleteUserSecretByUserIDAndNameParams{
|
||||
UserID: testUser.ID,
|
||||
Name: "workflow-secret",
|
||||
})
|
||||
err = db.DeleteUserSecret(ctx, createdSecret.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify deletion
|
||||
_, err = db.GetUserSecretByUserIDAndName(ctx, readByNameParams)
|
||||
_, err = db.GetUserSecret(ctx, createdSecret.ID)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no rows in result set")
|
||||
|
||||
@@ -7453,13 +7449,9 @@ func TestUserSecretsCRUDOperations(t *testing.T) {
|
||||
})
|
||||
|
||||
// Verify both secrets exist
|
||||
_, err = db.GetUserSecretByUserIDAndName(ctx, database.GetUserSecretByUserIDAndNameParams{
|
||||
UserID: testUser.ID, Name: secret1.Name,
|
||||
})
|
||||
_, err = db.GetUserSecret(ctx, secret1.ID)
|
||||
require.NoError(t, err)
|
||||
_, err = db.GetUserSecretByUserIDAndName(ctx, database.GetUserSecretByUserIDAndNameParams{
|
||||
UserID: testUser.ID, Name: secret2.Name,
|
||||
})
|
||||
_, err = db.GetUserSecret(ctx, secret2.ID)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
@@ -7482,14 +7474,14 @@ func TestUserSecretsAuthorization(t *testing.T) {
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
|
||||
// Create secrets for users
|
||||
_ = dbgen.UserSecret(t, db, database.UserSecret{
|
||||
user1Secret := dbgen.UserSecret(t, db, database.UserSecret{
|
||||
UserID: user1.ID,
|
||||
Name: "user1-secret",
|
||||
Description: "User 1's secret",
|
||||
Value: "user1-value",
|
||||
})
|
||||
|
||||
_ = dbgen.UserSecret(t, db, database.UserSecret{
|
||||
user2Secret := dbgen.UserSecret(t, db, database.UserSecret{
|
||||
UserID: user2.ID,
|
||||
Name: "user2-secret",
|
||||
Description: "User 2's secret",
|
||||
@@ -7499,8 +7491,7 @@ func TestUserSecretsAuthorization(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
subject rbac.Subject
|
||||
lookupUserID uuid.UUID
|
||||
lookupName string
|
||||
secretID uuid.UUID
|
||||
expectedAccess bool
|
||||
}{
|
||||
{
|
||||
@@ -7510,8 +7501,7 @@ func TestUserSecretsAuthorization(t *testing.T) {
|
||||
Roles: rbac.RoleIdentifiers{rbac.RoleMember()},
|
||||
Scope: rbac.ScopeAll,
|
||||
},
|
||||
lookupUserID: user1.ID,
|
||||
lookupName: "user1-secret",
|
||||
secretID: user1Secret.ID,
|
||||
expectedAccess: true,
|
||||
},
|
||||
{
|
||||
@@ -7521,8 +7511,7 @@ func TestUserSecretsAuthorization(t *testing.T) {
|
||||
Roles: rbac.RoleIdentifiers{rbac.RoleMember()},
|
||||
Scope: rbac.ScopeAll,
|
||||
},
|
||||
lookupUserID: user2.ID,
|
||||
lookupName: "user2-secret",
|
||||
secretID: user2Secret.ID,
|
||||
expectedAccess: false,
|
||||
},
|
||||
{
|
||||
@@ -7532,8 +7521,7 @@ func TestUserSecretsAuthorization(t *testing.T) {
|
||||
Roles: rbac.RoleIdentifiers{rbac.RoleOwner()},
|
||||
Scope: rbac.ScopeAll,
|
||||
},
|
||||
lookupUserID: user1.ID,
|
||||
lookupName: "user1-secret",
|
||||
secretID: user1Secret.ID,
|
||||
expectedAccess: false,
|
||||
},
|
||||
{
|
||||
@@ -7543,8 +7531,7 @@ func TestUserSecretsAuthorization(t *testing.T) {
|
||||
Roles: rbac.RoleIdentifiers{rbac.ScopedRoleOrgAdmin(org.ID)},
|
||||
Scope: rbac.ScopeAll,
|
||||
},
|
||||
lookupUserID: user1.ID,
|
||||
lookupName: "user1-secret",
|
||||
secretID: user1Secret.ID,
|
||||
expectedAccess: false,
|
||||
},
|
||||
}
|
||||
@@ -7556,10 +7543,8 @@ func TestUserSecretsAuthorization(t *testing.T) {
|
||||
|
||||
authCtx := dbauthz.As(ctx, tc.subject)
|
||||
|
||||
_, err := authDB.GetUserSecretByUserIDAndName(authCtx, database.GetUserSecretByUserIDAndNameParams{
|
||||
UserID: tc.lookupUserID,
|
||||
Name: tc.lookupName,
|
||||
})
|
||||
// Test GetUserSecret
|
||||
_, err := authDB.GetUserSecret(authCtx, tc.secretID)
|
||||
|
||||
if tc.expectedAccess {
|
||||
require.NoError(t, err, "expected access to be granted")
|
||||
@@ -9085,11 +9070,10 @@ func TestUpdateAIBridgeInterceptionEnded(t *testing.T) {
|
||||
|
||||
for _, uid := range []uuid.UUID{{1}, {2}, {3}} {
|
||||
insertParams := database.InsertAIBridgeInterceptionParams{
|
||||
ID: uid,
|
||||
InitiatorID: user.ID,
|
||||
Metadata: json.RawMessage("{}"),
|
||||
Client: sql.NullString{String: "client", Valid: true},
|
||||
CredentialKind: database.CredentialKindCentralized,
|
||||
ID: uid,
|
||||
InitiatorID: user.ID,
|
||||
Metadata: json.RawMessage("{}"),
|
||||
Client: sql.NullString{String: "client", Valid: true},
|
||||
}
|
||||
|
||||
intc, err := db.InsertAIBridgeInterception(ctx, insertParams)
|
||||
|
||||
+91
-506
File diff suppressed because it is too large
Load Diff
@@ -1,8 +1,8 @@
|
||||
-- name: InsertAIBridgeInterception :one
|
||||
INSERT INTO aibridge_interceptions (
|
||||
id, api_key_id, initiator_id, provider, provider_name, model, metadata, started_at, client, client_session_id, thread_parent_id, thread_root_id, credential_kind, credential_hint
|
||||
id, api_key_id, initiator_id, provider, provider_name, model, metadata, started_at, client, client_session_id, thread_parent_id, thread_root_id
|
||||
) VALUES (
|
||||
@id, @api_key_id, @initiator_id, @provider, @provider_name, @model, COALESCE(@metadata::jsonb, '{}'::jsonb), @started_at, @client, sqlc.narg('client_session_id'), sqlc.narg('thread_parent_interception_id')::uuid, sqlc.narg('thread_root_interception_id')::uuid, @credential_kind, @credential_hint
|
||||
@id, @api_key_id, @initiator_id, @provider, @provider_name, @model, COALESCE(@metadata::jsonb, '{}'::jsonb), @started_at, @client, sqlc.narg('client_session_id'), sqlc.narg('thread_parent_interception_id')::uuid, sqlc.narg('thread_root_interception_id')::uuid
|
||||
)
|
||||
RETURNING *;
|
||||
|
||||
|
||||
@@ -18,37 +18,3 @@ FROM chat_files cf
|
||||
JOIN chat_file_links cfl ON cfl.file_id = cf.id
|
||||
WHERE cfl.chat_id = @chat_id::uuid
|
||||
ORDER BY cf.created_at ASC;
|
||||
|
||||
-- TODO(cian): Add indexes on chats(archived, updated_at) and
|
||||
-- chat_files(created_at) for purge query performance.
|
||||
-- See: https://github.com/coder/internal/issues/1438
|
||||
-- name: DeleteOldChatFiles :execrows
|
||||
-- Deletes chat files that are older than the given threshold and are
|
||||
-- not referenced by any chat that is still active or was archived
|
||||
-- within the same threshold window. This covers two cases:
|
||||
-- 1. Orphaned files not linked to any chat.
|
||||
-- 2. Files whose every referencing chat has been archived for longer
|
||||
-- than the retention period.
|
||||
WITH kept_file_ids AS (
|
||||
-- NOTE: This uses updated_at as a proxy for archive time
|
||||
-- because there is no archived_at column. Correctness
|
||||
-- requires that updated_at is never backdated on archived
|
||||
-- chats. See ArchiveChatByID.
|
||||
SELECT DISTINCT cfl.file_id
|
||||
FROM chat_file_links cfl
|
||||
JOIN chats c ON c.id = cfl.chat_id
|
||||
WHERE c.archived = false
|
||||
OR c.updated_at >= @before_time::timestamptz
|
||||
),
|
||||
deletable AS (
|
||||
SELECT cf.id
|
||||
FROM chat_files cf
|
||||
LEFT JOIN kept_file_ids k ON cf.id = k.file_id
|
||||
WHERE cf.created_at < @before_time::timestamptz
|
||||
AND k.file_id IS NULL
|
||||
ORDER BY cf.created_at ASC
|
||||
LIMIT @limit_count
|
||||
)
|
||||
DELETE FROM chat_files
|
||||
USING deletable
|
||||
WHERE chat_files.id = deletable.id;
|
||||
|
||||
@@ -10,14 +10,9 @@ FROM chats
|
||||
ORDER BY (id = @id::uuid) DESC, created_at ASC, id ASC;
|
||||
|
||||
-- name: UnarchiveChatByID :many
|
||||
-- Unarchives a chat (and its children). Stale file references are
|
||||
-- handled automatically by FK cascades on chat_file_links: when
|
||||
-- dbpurge deletes a chat_files row, the corresponding
|
||||
-- chat_file_links rows are cascade-deleted by PostgreSQL.
|
||||
WITH chats AS (
|
||||
UPDATE chats SET
|
||||
archived = false,
|
||||
updated_at = NOW()
|
||||
UPDATE chats
|
||||
SET archived = false, updated_at = NOW()
|
||||
WHERE id = @id::uuid OR root_chat_id = @id::uuid
|
||||
RETURNING *
|
||||
)
|
||||
@@ -399,8 +394,7 @@ INSERT INTO chats (
|
||||
mode,
|
||||
status,
|
||||
mcp_server_ids,
|
||||
labels,
|
||||
dynamic_tools
|
||||
labels
|
||||
) VALUES (
|
||||
@owner_id::uuid,
|
||||
sqlc.narg('workspace_id')::uuid,
|
||||
@@ -413,8 +407,7 @@ INSERT INTO chats (
|
||||
sqlc.narg('mode')::chat_mode,
|
||||
@status::chat_status,
|
||||
COALESCE(@mcp_server_ids::uuid[], '{}'::uuid[]),
|
||||
COALESCE(sqlc.narg('labels')::jsonb, '{}'::jsonb),
|
||||
sqlc.narg('dynamic_tools')::jsonb
|
||||
COALESCE(sqlc.narg('labels')::jsonb, '{}'::jsonb)
|
||||
)
|
||||
RETURNING
|
||||
*;
|
||||
@@ -671,19 +664,15 @@ RETURNING
|
||||
*;
|
||||
|
||||
-- name: GetStaleChats :many
|
||||
-- Find chats that appear stuck and need recovery. This covers:
|
||||
-- 1. Running chats whose heartbeat has expired (worker crash).
|
||||
-- 2. Chats awaiting client action (requires_action) past the
|
||||
-- timeout threshold (client disappeared).
|
||||
-- Find chats that appear stuck (running but heartbeat has expired).
|
||||
-- Used for recovery after coderd crashes or long hangs.
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chats
|
||||
WHERE
|
||||
(status = 'running'::chat_status
|
||||
AND heartbeat_at < @stale_threshold::timestamptz)
|
||||
OR (status = 'requires_action'::chat_status
|
||||
AND updated_at < @stale_threshold::timestamptz);
|
||||
status = 'running'::chat_status
|
||||
AND heartbeat_at < @stale_threshold::timestamptz;
|
||||
|
||||
-- name: UpdateChatHeartbeats :many
|
||||
-- Bumps the heartbeat timestamp for the given set of chat IDs,
|
||||
@@ -1231,65 +1220,3 @@ LIMIT 1;
|
||||
UPDATE chats
|
||||
SET last_read_message_id = @last_read_message_id::bigint
|
||||
WHERE id = @id::uuid;
|
||||
|
||||
-- name: DeleteOldChats :execrows
|
||||
-- Deletes chats that have been archived for longer than the given
|
||||
-- threshold. Active (non-archived) chats are never deleted.
|
||||
-- Related chat_messages, chat_diff_statuses, and
|
||||
-- chat_queued_messages are removed via ON DELETE CASCADE.
|
||||
-- Parent/root references on child chats are SET NULL.
|
||||
WITH deletable AS (
|
||||
SELECT id
|
||||
FROM chats
|
||||
WHERE archived = true
|
||||
AND updated_at < @before_time::timestamptz
|
||||
ORDER BY updated_at ASC
|
||||
LIMIT @limit_count
|
||||
)
|
||||
DELETE FROM chats
|
||||
USING deletable
|
||||
WHERE chats.id = deletable.id
|
||||
AND chats.archived = true;
|
||||
|
||||
-- name: GetChatsUpdatedAfter :many
|
||||
-- Retrieves chats updated after the given timestamp for telemetry
|
||||
-- snapshot collection. Uses updated_at so that long-running chats
|
||||
-- still appear in each snapshot window while they are active.
|
||||
SELECT
|
||||
id, owner_id, created_at, updated_at, status,
|
||||
(parent_chat_id IS NOT NULL)::bool AS has_parent,
|
||||
root_chat_id, workspace_id,
|
||||
mode, archived, last_model_config_id
|
||||
FROM chats
|
||||
WHERE updated_at > @updated_after;
|
||||
|
||||
-- name: GetChatMessageSummariesPerChat :many
|
||||
-- Aggregates message-level metrics per chat for messages created
|
||||
-- after the given timestamp. Uses message created_at so that
|
||||
-- ongoing activity in long-running chats is captured each window.
|
||||
SELECT
|
||||
cm.chat_id,
|
||||
COUNT(*)::bigint AS message_count,
|
||||
COUNT(*) FILTER (WHERE cm.role = 'user')::bigint AS user_message_count,
|
||||
COUNT(*) FILTER (WHERE cm.role = 'assistant')::bigint AS assistant_message_count,
|
||||
COUNT(*) FILTER (WHERE cm.role = 'tool')::bigint AS tool_message_count,
|
||||
COUNT(*) FILTER (WHERE cm.role = 'system')::bigint AS system_message_count,
|
||||
COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens,
|
||||
COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens,
|
||||
COALESCE(SUM(cm.reasoning_tokens), 0)::bigint AS total_reasoning_tokens,
|
||||
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens,
|
||||
COALESCE(SUM(cm.cache_read_tokens), 0)::bigint AS total_cache_read_tokens,
|
||||
COALESCE(SUM(cm.total_cost_micros), 0)::bigint AS total_cost_micros,
|
||||
COALESCE(SUM(cm.runtime_ms), 0)::bigint AS total_runtime_ms,
|
||||
COUNT(DISTINCT cm.model_config_id)::bigint AS distinct_model_count,
|
||||
COUNT(*) FILTER (WHERE cm.compressed)::bigint AS compressed_message_count
|
||||
FROM chat_messages cm
|
||||
WHERE cm.created_at > @created_after
|
||||
AND cm.deleted = false
|
||||
GROUP BY cm.chat_id;
|
||||
|
||||
-- name: GetChatModelConfigsForTelemetry :many
|
||||
-- Returns all model configurations for telemetry snapshot collection.
|
||||
SELECT id, provider, model, context_limit, enabled, is_default
|
||||
FROM chat_model_configs
|
||||
WHERE deleted = false;
|
||||
|
||||
@@ -236,20 +236,3 @@ VALUES ('agents_workspace_ttl', @workspace_ttl::text)
|
||||
ON CONFLICT (key) DO UPDATE
|
||||
SET value = @workspace_ttl::text
|
||||
WHERE site_configs.key = 'agents_workspace_ttl';
|
||||
|
||||
-- name: GetChatRetentionDays :one
|
||||
-- Returns the chat retention period in days. Chats archived longer
|
||||
-- than this and orphaned chat files older than this are purged by
|
||||
-- dbpurge. Returns 30 (days) when no value has been configured.
|
||||
-- A value of 0 disables chat purging entirely.
|
||||
SELECT COALESCE(
|
||||
(SELECT value::integer FROM site_configs
|
||||
WHERE key = 'agents_chat_retention_days'),
|
||||
30
|
||||
) :: integer AS retention_days;
|
||||
|
||||
-- name: UpsertChatRetentionDays :exec
|
||||
INSERT INTO site_configs (key, value)
|
||||
VALUES ('agents_chat_retention_days', CAST(@retention_days AS integer)::text)
|
||||
ON CONFLICT (key) DO UPDATE SET value = CAST(@retention_days AS integer)::text
|
||||
WHERE site_configs.key = 'agents_chat_retention_days';
|
||||
|
||||
@@ -1,26 +1,14 @@
|
||||
-- name: GetUserSecretByUserIDAndName :one
|
||||
SELECT *
|
||||
FROM user_secrets
|
||||
WHERE user_id = @user_id AND name = @name;
|
||||
SELECT * FROM user_secrets
|
||||
WHERE user_id = $1 AND name = $2;
|
||||
|
||||
-- name: GetUserSecret :one
|
||||
SELECT * FROM user_secrets
|
||||
WHERE id = $1;
|
||||
|
||||
-- name: ListUserSecrets :many
|
||||
-- Returns metadata only (no value or value_key_id) for the
|
||||
-- REST API list and get endpoints.
|
||||
SELECT
|
||||
id, user_id, name, description,
|
||||
env_name, file_path,
|
||||
created_at, updated_at
|
||||
FROM user_secrets
|
||||
WHERE user_id = @user_id
|
||||
ORDER BY name ASC;
|
||||
|
||||
-- name: ListUserSecretsWithValues :many
|
||||
-- Returns all columns including the secret value. Used by the
|
||||
-- provisioner (build-time injection) and the agent manifest
|
||||
-- (runtime injection).
|
||||
SELECT *
|
||||
FROM user_secrets
|
||||
WHERE user_id = @user_id
|
||||
SELECT * FROM user_secrets
|
||||
WHERE user_id = $1
|
||||
ORDER BY name ASC;
|
||||
|
||||
-- name: CreateUserSecret :one
|
||||
@@ -30,32 +18,23 @@ INSERT INTO user_secrets (
|
||||
name,
|
||||
description,
|
||||
value,
|
||||
value_key_id,
|
||||
env_name,
|
||||
file_path
|
||||
) VALUES (
|
||||
@id,
|
||||
@user_id,
|
||||
@name,
|
||||
@description,
|
||||
@value,
|
||||
@value_key_id,
|
||||
@env_name,
|
||||
@file_path
|
||||
$1, $2, $3, $4, $5, $6, $7
|
||||
) RETURNING *;
|
||||
|
||||
-- name: UpdateUserSecretByUserIDAndName :one
|
||||
-- name: UpdateUserSecret :one
|
||||
UPDATE user_secrets
|
||||
SET
|
||||
value = CASE WHEN @update_value::bool THEN @value ELSE value END,
|
||||
value_key_id = CASE WHEN @update_value::bool THEN @value_key_id ELSE value_key_id END,
|
||||
description = CASE WHEN @update_description::bool THEN @description ELSE description END,
|
||||
env_name = CASE WHEN @update_env_name::bool THEN @env_name ELSE env_name END,
|
||||
file_path = CASE WHEN @update_file_path::bool THEN @file_path ELSE file_path END,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
WHERE user_id = @user_id AND name = @name
|
||||
description = $2,
|
||||
value = $3,
|
||||
env_name = $4,
|
||||
file_path = $5,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = $1
|
||||
RETURNING *;
|
||||
|
||||
-- name: DeleteUserSecretByUserIDAndName :exec
|
||||
-- name: DeleteUserSecret :exec
|
||||
DELETE FROM user_secrets
|
||||
WHERE user_id = @user_id AND name = @name;
|
||||
WHERE id = $1;
|
||||
|
||||
@@ -398,10 +398,6 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Cap the raw request body to prevent excessive memory use
|
||||
// from large dynamic tool schemas.
|
||||
r.Body = http.MaxBytesReader(rw, r.Body, int64(2*maxSystemPromptLenBytes))
|
||||
|
||||
var req codersdk.CreateChatRequest
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
@@ -492,50 +488,6 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.UnsafeDynamicTools) > 250 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Too many dynamic tools.",
|
||||
Detail: "Maximum 250 dynamic tools per chat.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate that dynamic tool names are non-empty and unique
|
||||
// within the list. Name collision with built-in tools is
|
||||
// checked at chatloop time when the full tool set is known.
|
||||
if len(req.UnsafeDynamicTools) > 0 {
|
||||
seenNames := make(map[string]struct{}, len(req.UnsafeDynamicTools))
|
||||
for _, dt := range req.UnsafeDynamicTools {
|
||||
if dt.Name == "" {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Dynamic tool name must not be empty.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if _, exists := seenNames[dt.Name]; exists {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Duplicate dynamic tool name.",
|
||||
Detail: fmt.Sprintf("Tool %q appears more than once.", dt.Name),
|
||||
})
|
||||
return
|
||||
}
|
||||
seenNames[dt.Name] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
var dynamicToolsJSON json.RawMessage
|
||||
if len(req.UnsafeDynamicTools) > 0 {
|
||||
var err error
|
||||
dynamicToolsJSON, err = json.Marshal(req.UnsafeDynamicTools)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to marshal dynamic tools.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
chat, err := api.chatDaemon.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: apiKey.UserID,
|
||||
WorkspaceID: workspaceSelection.WorkspaceID,
|
||||
@@ -545,7 +497,6 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
|
||||
InitialUserContent: contentBlocks,
|
||||
MCPServerIDs: mcpServerIDs,
|
||||
Labels: labels,
|
||||
DynamicTools: dynamicToolsJSON,
|
||||
})
|
||||
if err != nil {
|
||||
if maybeWriteLimitErr(ctx, rw, err) {
|
||||
@@ -3232,70 +3183,6 @@ func (api *API) putChatWorkspaceTTL(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// @Summary Get chat retention days
|
||||
// @ID get-chat-retention-days
|
||||
// @Security CoderSessionToken
|
||||
// @Tags Chats
|
||||
// @Produce json
|
||||
// @Success 200 {object} codersdk.ChatRetentionDaysResponse
|
||||
// @Router /experimental/chats/config/retention-days [get]
|
||||
// @x-apidocgen {"skip": true}
|
||||
//
|
||||
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
|
||||
func (api *API) getChatRetentionDays(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
retentionDays, err := api.Database.GetChatRetentionDays(ctx)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to get chat retention days.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatRetentionDaysResponse{
|
||||
RetentionDays: retentionDays,
|
||||
})
|
||||
}
|
||||
|
||||
// Keep in sync with retentionDaysMaximum in
|
||||
// site/src/pages/AgentsPage/AgentSettingsBehaviorPageView.tsx.
|
||||
const retentionDaysMaximum = 3650 // ~10 years
|
||||
|
||||
// @Summary Update chat retention days
|
||||
// @ID update-chat-retention-days
|
||||
// @Security CoderSessionToken
|
||||
// @Tags Chats
|
||||
// @Accept json
|
||||
// @Param request body codersdk.UpdateChatRetentionDaysRequest true "Request body"
|
||||
// @Success 204
|
||||
// @Router /experimental/chats/config/retention-days [put]
|
||||
// @x-apidocgen {"skip": true}
|
||||
func (api *API) putChatRetentionDays(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
||||
httpapi.Forbidden(rw)
|
||||
return
|
||||
}
|
||||
var req codersdk.UpdateChatRetentionDaysRequest
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
}
|
||||
if req.RetentionDays < 0 || req.RetentionDays > retentionDaysMaximum {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: fmt.Sprintf("Retention days must be between 0 and %d.", retentionDaysMaximum),
|
||||
})
|
||||
return
|
||||
}
|
||||
if err := api.Database.UpsertChatRetentionDays(ctx, req.RetentionDays); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to update chat retention days.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
//
|
||||
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
|
||||
@@ -5800,77 +5687,3 @@ func (api *API) prInsights(rw http.ResponseWriter, r *http.Request) {
|
||||
RecentPRs: prEntries,
|
||||
})
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
//
|
||||
//nolint:revive // HTTP handler writes to ResponseWriter.
|
||||
func (api *API) postChatToolResults(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
chat := httpmw.ChatParam(r)
|
||||
apiKey := httpmw.APIKey(r)
|
||||
|
||||
// Cap the raw request body to prevent excessive memory use.
|
||||
r.Body = http.MaxBytesReader(rw, r.Body, int64(2*maxSystemPromptLenBytes))
|
||||
var req codersdk.SubmitToolResultsRequest
|
||||
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Results) == 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "At least one tool result is required.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Fast-path check outside the transaction. The authoritative
|
||||
// check happens inside SubmitToolResults under a row lock.
|
||||
if chat.Status != database.ChatStatusRequiresAction {
|
||||
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
||||
Message: "Chat is not waiting for tool results.",
|
||||
Detail: fmt.Sprintf("Chat status is %q, expected %q.", chat.Status, database.ChatStatusRequiresAction),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var dynamicTools json.RawMessage
|
||||
if chat.DynamicTools.Valid {
|
||||
dynamicTools = chat.DynamicTools.RawMessage
|
||||
}
|
||||
|
||||
err := api.chatDaemon.SubmitToolResults(ctx, chatd.SubmitToolResultsOptions{
|
||||
ChatID: chat.ID,
|
||||
UserID: apiKey.UserID,
|
||||
ModelConfigID: chat.LastModelConfigID,
|
||||
Results: req.Results,
|
||||
DynamicTools: dynamicTools,
|
||||
})
|
||||
if err != nil {
|
||||
var validationErr *chatd.ToolResultValidationError
|
||||
var conflictErr *chatd.ToolResultStatusConflictError
|
||||
switch {
|
||||
case errors.As(err, &conflictErr):
|
||||
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
||||
Message: "Chat is not waiting for tool results.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
case errors.As(err, &validationErr):
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: validationErr.Message,
|
||||
Detail: validationErr.Detail,
|
||||
})
|
||||
default:
|
||||
api.Logger.Error(ctx, "tool results submission failed",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.Error(err),
|
||||
)
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error submitting tool results.",
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
+18
-431
@@ -16,9 +16,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/shopspring/decimal"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
@@ -270,10 +268,17 @@ func TestPostChats(t *testing.T) {
|
||||
_ = createChatModelConfig(t, client)
|
||||
|
||||
// Member without agents-access should be denied.
|
||||
memberClientRaw, _ := coderdtest.CreateAnotherUser(t, client.Client, firstUser.OrganizationID)
|
||||
memberClientRaw, member := coderdtest.CreateAnotherUser(t, client.Client, firstUser.OrganizationID)
|
||||
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
|
||||
|
||||
_, err := memberClient.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
// Strip the auto-assigned agents-access role to test
|
||||
// the denied case.
|
||||
_, err := client.Client.UpdateUserRoles(ctx, member.Username, codersdk.UpdateRoles{
|
||||
Roles: []string{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = memberClient.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
@@ -283,7 +288,6 @@ func TestPostChats(t *testing.T) {
|
||||
})
|
||||
requireSDKError(t, err, http.StatusForbidden)
|
||||
})
|
||||
|
||||
t.Run("HidesSystemPromptMessages", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -752,7 +756,15 @@ func TestListChats(t *testing.T) {
|
||||
// returning empty because no chats exist.
|
||||
memberClientRaw, member := coderdtest.CreateAnotherUser(t, client.Client, firstUser.OrganizationID)
|
||||
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
|
||||
_, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
|
||||
|
||||
// Strip the auto-assigned agents-access role to test
|
||||
// the denied case.
|
||||
_, err := client.Client.UpdateUserRoles(ctx, member.Username, codersdk.UpdateRoles{
|
||||
Roles: []string{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: member.ID,
|
||||
LastModelConfigID: modelConfig.ID,
|
||||
@@ -7735,62 +7747,6 @@ func TestChatWorkspaceTTL(t *testing.T) {
|
||||
requireSDKError(t, err, http.StatusBadRequest)
|
||||
}
|
||||
|
||||
func TestChatRetentionDays(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
adminClient := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, adminClient.Client)
|
||||
memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID)
|
||||
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
|
||||
|
||||
// Default value is 30 (days) when nothing has been configured.
|
||||
resp, err := adminClient.GetChatRetentionDays(ctx)
|
||||
require.NoError(t, err, "get default")
|
||||
require.Equal(t, int32(30), resp.RetentionDays, "default should be 30")
|
||||
|
||||
// Admin can set retention days to 90.
|
||||
err = adminClient.UpdateChatRetentionDays(ctx, codersdk.UpdateChatRetentionDaysRequest{
|
||||
RetentionDays: 90,
|
||||
})
|
||||
require.NoError(t, err, "admin set 90")
|
||||
|
||||
resp, err = adminClient.GetChatRetentionDays(ctx)
|
||||
require.NoError(t, err, "get after set")
|
||||
require.Equal(t, int32(90), resp.RetentionDays, "should return 90")
|
||||
|
||||
// Non-admin member can read the value.
|
||||
resp, err = memberClient.GetChatRetentionDays(ctx)
|
||||
require.NoError(t, err, "member get")
|
||||
require.Equal(t, int32(90), resp.RetentionDays, "member should see same value")
|
||||
|
||||
// Non-admin member cannot write.
|
||||
err = memberClient.UpdateChatRetentionDays(ctx, codersdk.UpdateChatRetentionDaysRequest{RetentionDays: 7})
|
||||
requireSDKError(t, err, http.StatusForbidden)
|
||||
|
||||
// Admin can disable purge by setting 0.
|
||||
err = adminClient.UpdateChatRetentionDays(ctx, codersdk.UpdateChatRetentionDaysRequest{
|
||||
RetentionDays: 0,
|
||||
})
|
||||
require.NoError(t, err, "admin set 0")
|
||||
|
||||
resp, err = adminClient.GetChatRetentionDays(ctx)
|
||||
require.NoError(t, err, "get after zero")
|
||||
require.Equal(t, int32(0), resp.RetentionDays, "should be 0 after disable")
|
||||
|
||||
// Validation: negative value is rejected.
|
||||
err = adminClient.UpdateChatRetentionDays(ctx, codersdk.UpdateChatRetentionDaysRequest{
|
||||
RetentionDays: -1,
|
||||
})
|
||||
requireSDKError(t, err, http.StatusBadRequest)
|
||||
|
||||
// Validation: exceeding the 3650-day maximum is rejected.
|
||||
err = adminClient.UpdateChatRetentionDays(ctx, codersdk.UpdateChatRetentionDaysRequest{
|
||||
RetentionDays: 3651, // retentionDaysMaximum + 1; keep in sync with coderd/exp_chats.go.
|
||||
})
|
||||
requireSDKError(t, err, http.StatusBadRequest)
|
||||
}
|
||||
|
||||
//nolint:tparallel,paralleltest // Subtests share a single coderdtest instance.
|
||||
func TestUserChatCompactionThresholds(t *testing.T) {
|
||||
t.Parallel()
|
||||
@@ -8197,375 +8153,6 @@ func TestGetChatsByWorkspace(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestSubmitToolResults(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// setupRequiresAction creates a chat via the DB with dynamic tools,
|
||||
// inserts an assistant message containing tool-call parts for each
|
||||
// given toolCallID, and sets the chat status to requires_action.
|
||||
// It returns the chat row so callers can exercise the endpoint.
|
||||
setupRequiresAction := func(
|
||||
ctx context.Context,
|
||||
t *testing.T,
|
||||
db database.Store,
|
||||
ownerID uuid.UUID,
|
||||
modelConfigID uuid.UUID,
|
||||
dynamicToolName string,
|
||||
toolCallIDs []string,
|
||||
) database.Chat {
|
||||
t.Helper()
|
||||
|
||||
// Marshal dynamic tools into the chat row.
|
||||
dynamicTools := []mcp.Tool{{
|
||||
Name: dynamicToolName,
|
||||
Description: "a test dynamic tool",
|
||||
InputSchema: mcp.ToolInputSchema{Type: "object"},
|
||||
}}
|
||||
dtJSON, err := json.Marshal(dynamicTools)
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: ownerID,
|
||||
LastModelConfigID: modelConfigID,
|
||||
Title: "tool-results-test",
|
||||
DynamicTools: pqtype.NullRawMessage{RawMessage: dtJSON, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Build assistant message with tool-call parts.
|
||||
parts := make([]codersdk.ChatMessagePart, 0, len(toolCallIDs))
|
||||
for _, id := range toolCallIDs {
|
||||
parts = append(parts, codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeToolCall,
|
||||
ToolCallID: id,
|
||||
ToolName: dynamicToolName,
|
||||
Args: json.RawMessage(`{"key":"value"}`),
|
||||
})
|
||||
}
|
||||
content, err := chatprompt.MarshalParts(parts)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.InsertChatMessages(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessagesParams{
|
||||
ChatID: chat.ID,
|
||||
CreatedBy: []uuid.UUID{uuid.Nil},
|
||||
ModelConfigID: []uuid.UUID{modelConfigID},
|
||||
Role: []database.ChatMessageRole{database.ChatMessageRoleAssistant},
|
||||
ContentVersion: []int16{chatprompt.CurrentContentVersion},
|
||||
Content: []string{string(content.RawMessage)},
|
||||
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
|
||||
InputTokens: []int64{0},
|
||||
OutputTokens: []int64{0},
|
||||
TotalTokens: []int64{0},
|
||||
ReasoningTokens: []int64{0},
|
||||
CacheCreationTokens: []int64{0},
|
||||
CacheReadTokens: []int64{0},
|
||||
ContextLimit: []int64{0},
|
||||
Compressed: []bool{false},
|
||||
TotalCostMicros: []int64{0},
|
||||
RuntimeMs: []int64{0},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Transition to requires_action.
|
||||
chat, err = db.UpdateChatStatus(dbauthz.AsSystemRestricted(ctx), database.UpdateChatStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: database.ChatStatusRequiresAction,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, database.ChatStatusRequiresAction, chat.Status)
|
||||
|
||||
return chat
|
||||
}
|
||||
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
user := coderdtest.CreateFirstUser(t, client.Client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
const toolName = "my_dynamic_tool"
|
||||
toolCallIDs := []string{"call_abc", "call_def"}
|
||||
|
||||
chat := setupRequiresAction(ctx, t, db, user.UserID, modelConfig.ID, toolName, toolCallIDs)
|
||||
|
||||
err := client.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{
|
||||
Results: []codersdk.ToolResult{
|
||||
{ToolCallID: "call_abc", Output: json.RawMessage(`"result_a"`)},
|
||||
{ToolCallID: "call_def", Output: json.RawMessage(`"result_b"`)},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify status is no longer requires_action. The chatd
|
||||
// loop may have already picked the chat up and
|
||||
// transitioned it further (pending → running → …), so we
|
||||
// accept any non-requires_action status.
|
||||
gotChat, err := client.GetChat(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, codersdk.ChatStatusRequiresAction, gotChat.Status,
|
||||
"chat should no longer be in requires_action after submitting tool results")
|
||||
|
||||
// Verify tool-result messages were persisted.
|
||||
msgsResp, err := client.GetChatMessages(ctx, chat.ID, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
var toolResultCount int
|
||||
for _, msg := range msgsResp.Messages {
|
||||
if msg.Role == codersdk.ChatMessageRoleTool {
|
||||
toolResultCount++
|
||||
}
|
||||
}
|
||||
require.Equal(t, len(toolCallIDs), toolResultCount,
|
||||
"expected one tool-result message per submitted result")
|
||||
})
|
||||
|
||||
t.Run("WrongStatus", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
user := coderdtest.CreateFirstUser(t, client.Client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
// Create a chat that is NOT in requires_action status.
|
||||
chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: user.UserID,
|
||||
LastModelConfigID: modelConfig.ID,
|
||||
Title: "wrong-status-test",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = client.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{
|
||||
Results: []codersdk.ToolResult{
|
||||
{ToolCallID: "call_xyz", Output: json.RawMessage(`"nope"`)},
|
||||
},
|
||||
})
|
||||
requireSDKError(t, err, http.StatusConflict)
|
||||
})
|
||||
|
||||
t.Run("MissingResult", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
user := coderdtest.CreateFirstUser(t, client.Client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
const toolName = "my_dynamic_tool"
|
||||
toolCallIDs := []string{"call_one", "call_two"}
|
||||
|
||||
chat := setupRequiresAction(ctx, t, db, user.UserID, modelConfig.ID, toolName, toolCallIDs)
|
||||
|
||||
// Submit only one of the two required results.
|
||||
err := client.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{
|
||||
Results: []codersdk.ToolResult{
|
||||
{ToolCallID: "call_one", Output: json.RawMessage(`"partial"`)},
|
||||
},
|
||||
})
|
||||
requireSDKError(t, err, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("UnexpectedResult", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
user := coderdtest.CreateFirstUser(t, client.Client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
const toolName = "my_dynamic_tool"
|
||||
toolCallIDs := []string{"call_real"}
|
||||
|
||||
chat := setupRequiresAction(ctx, t, db, user.UserID, modelConfig.ID, toolName, toolCallIDs)
|
||||
|
||||
// Submit a result with a wrong tool_call_id.
|
||||
err := client.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{
|
||||
Results: []codersdk.ToolResult{
|
||||
{ToolCallID: "call_bogus", Output: json.RawMessage(`"wrong"`)},
|
||||
},
|
||||
})
|
||||
requireSDKError(t, err, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("InvalidJSONOutput", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
user := coderdtest.CreateFirstUser(t, client.Client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
const toolName = "my_dynamic_tool"
|
||||
toolCallIDs := []string{"call_json"}
|
||||
|
||||
chat := setupRequiresAction(ctx, t, db, user.UserID, modelConfig.ID, toolName, toolCallIDs)
|
||||
|
||||
// We must bypass the SDK client because json.RawMessage
|
||||
// rejects invalid JSON during json.Marshal. A raw HTTP
|
||||
// request lets the invalid payload reach the server so we
|
||||
// can verify server-side validation.
|
||||
rawBody := `{"results":[{"tool_call_id":"call_json","output":not-json,"is_error":false}]}`
|
||||
url := client.URL.JoinPath(fmt.Sprintf("/api/experimental/chats/%s/tool-results", chat.ID)).String()
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBufferString(rawBody))
|
||||
require.NoError(t, err)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set(codersdk.SessionTokenHeader, client.SessionToken())
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("DuplicateToolCallID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
user := coderdtest.CreateFirstUser(t, client.Client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
const toolName = "my_dynamic_tool"
|
||||
toolCallIDs := []string{"call_dup1", "call_dup2"}
|
||||
|
||||
chat := setupRequiresAction(ctx, t, db, user.UserID, modelConfig.ID, toolName, toolCallIDs)
|
||||
|
||||
err := client.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{
|
||||
Results: []codersdk.ToolResult{
|
||||
{ToolCallID: "call_dup1", Output: json.RawMessage(`"result_a"`)},
|
||||
{ToolCallID: "call_dup1", Output: json.RawMessage(`"result_b"`)},
|
||||
},
|
||||
})
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Contains(t, sdkErr.Message, "Duplicate tool_call_id")
|
||||
})
|
||||
|
||||
t.Run("EmptyResults", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
user := coderdtest.CreateFirstUser(t, client.Client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
const toolName = "my_dynamic_tool"
|
||||
toolCallIDs := []string{"call_empty"}
|
||||
|
||||
chat := setupRequiresAction(ctx, t, db, user.UserID, modelConfig.ID, toolName, toolCallIDs)
|
||||
|
||||
err := client.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{
|
||||
Results: []codersdk.ToolResult{},
|
||||
})
|
||||
requireSDKError(t, err, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("NotFoundForDifferentUser", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
user := coderdtest.CreateFirstUser(t, client.Client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
const toolName = "my_dynamic_tool"
|
||||
toolCallIDs := []string{"call_other"}
|
||||
|
||||
chat := setupRequiresAction(ctx, t, db, user.UserID, modelConfig.ID, toolName, toolCallIDs)
|
||||
|
||||
// Create a second user and try to submit tool results
|
||||
// to user A's chat.
|
||||
otherClientRaw, _ := coderdtest.CreateAnotherUser(
|
||||
t, client.Client, user.OrganizationID,
|
||||
rbac.RoleAgentsAccess(),
|
||||
)
|
||||
otherClient := codersdk.NewExperimentalClient(otherClientRaw)
|
||||
|
||||
err := otherClient.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{
|
||||
Results: []codersdk.ToolResult{
|
||||
{ToolCallID: "call_other", Output: json.RawMessage(`"nope"`)},
|
||||
},
|
||||
})
|
||||
requireSDKError(t, err, http.StatusNotFound)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPostChats_DynamicToolValidation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("TooManyTools", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
|
||||
tools := make([]codersdk.DynamicTool, 251)
|
||||
for i := range tools {
|
||||
tools[i] = codersdk.DynamicTool{
|
||||
Name: fmt.Sprintf("tool-%d", i),
|
||||
}
|
||||
}
|
||||
|
||||
_, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "hello",
|
||||
}},
|
||||
UnsafeDynamicTools: tools,
|
||||
})
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Equal(t, "Too many dynamic tools.", sdkErr.Message)
|
||||
})
|
||||
|
||||
t.Run("EmptyToolName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
|
||||
_, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "hello",
|
||||
}},
|
||||
UnsafeDynamicTools: []codersdk.DynamicTool{
|
||||
{Name: ""},
|
||||
},
|
||||
})
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Equal(t, "Dynamic tool name must not be empty.", sdkErr.Message)
|
||||
})
|
||||
|
||||
t.Run("DuplicateToolName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
|
||||
_, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "hello",
|
||||
}},
|
||||
UnsafeDynamicTools: []codersdk.DynamicTool{
|
||||
{Name: "dup-tool"},
|
||||
{Name: "dup-tool"},
|
||||
},
|
||||
})
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Equal(t, "Duplicate dynamic tool name.", sdkErr.Message)
|
||||
})
|
||||
}
|
||||
|
||||
func requireSDKError(t *testing.T, err error, expectedStatus int) *codersdk.Error {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -148,7 +148,7 @@ func TestGetOrgMembersFilter(t *testing.T) {
|
||||
setupCtx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
coderdtest.UsersFilter(setupCtx, t, client, api.Database, nil, nil, func(testCtx context.Context, req codersdk.UsersRequest) []codersdk.ReducedUser {
|
||||
coderdtest.UsersFilter(setupCtx, t, client, api.Database, nil, func(testCtx context.Context, req codersdk.UsersRequest) []codersdk.ReducedUser {
|
||||
res, err := client.OrganizationMembersPaginated(testCtx, first.OrganizationID, req)
|
||||
require.NoError(t, err)
|
||||
reduced := make([]codersdk.ReducedUser, len(res.Members))
|
||||
|
||||
@@ -32,9 +32,8 @@ func HandleChatEvent(cb func(ctx context.Context, payload ChatEvent, err error))
|
||||
}
|
||||
|
||||
type ChatEvent struct {
|
||||
Kind ChatEventKind `json:"kind"`
|
||||
Chat codersdk.Chat `json:"chat"`
|
||||
ToolCalls []codersdk.ChatStreamToolCall `json:"tool_calls,omitempty"`
|
||||
Kind ChatEventKind `json:"kind"`
|
||||
Chat codersdk.Chat `json:"chat"`
|
||||
}
|
||||
|
||||
type ChatEventKind string
|
||||
@@ -45,5 +44,4 @@ const (
|
||||
ChatEventKindCreated ChatEventKind = "created"
|
||||
ChatEventKindDeleted ChatEventKind = "deleted"
|
||||
ChatEventKindDiffStatusChange ChatEventKind = "diff_status_change"
|
||||
ChatEventKindActionRequired ChatEventKind = "action_required"
|
||||
)
|
||||
|
||||
@@ -776,40 +776,6 @@ func (r *remoteReporter) createSnapshot() (*Snapshot, error) {
|
||||
return nil
|
||||
})
|
||||
|
||||
eg.Go(func() error {
|
||||
chats, err := r.options.Database.GetChatsUpdatedAfter(ctx, createdAfter)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get chats updated after: %w", err)
|
||||
}
|
||||
snapshot.Chats = make([]Chat, 0, len(chats))
|
||||
for _, chat := range chats {
|
||||
snapshot.Chats = append(snapshot.Chats, ConvertChat(chat))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
eg.Go(func() error {
|
||||
summaries, err := r.options.Database.GetChatMessageSummariesPerChat(ctx, createdAfter)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get chat message summaries: %w", err)
|
||||
}
|
||||
snapshot.ChatMessageSummaries = make([]ChatMessageSummary, 0, len(summaries))
|
||||
for _, s := range summaries {
|
||||
snapshot.ChatMessageSummaries = append(snapshot.ChatMessageSummaries, ConvertChatMessageSummary(s))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
eg.Go(func() error {
|
||||
configs, err := r.options.Database.GetChatModelConfigsForTelemetry(ctx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get chat model configs: %w", err)
|
||||
}
|
||||
snapshot.ChatModelConfigs = make([]ChatModelConfig, 0, len(configs))
|
||||
for _, c := range configs {
|
||||
snapshot.ChatModelConfigs = append(snapshot.ChatModelConfigs, ConvertChatModelConfig(c))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
err := eg.Wait()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1537,9 +1503,6 @@ type Snapshot struct {
|
||||
AIBridgeInterceptionsSummaries []AIBridgeInterceptionsSummary `json:"aibridge_interceptions_summaries"`
|
||||
BoundaryUsageSummary *BoundaryUsageSummary `json:"boundary_usage_summary"`
|
||||
FirstUserOnboarding *FirstUserOnboarding `json:"first_user_onboarding"`
|
||||
Chats []Chat `json:"chats"`
|
||||
ChatMessageSummaries []ChatMessageSummary `json:"chat_message_summaries"`
|
||||
ChatModelConfigs []ChatModelConfig `json:"chat_model_configs"`
|
||||
}
|
||||
|
||||
// Deployment contains information about the host running Coder.
|
||||
@@ -2150,66 +2113,6 @@ func ConvertTask(task database.Task) Task {
|
||||
return t
|
||||
}
|
||||
|
||||
// ConvertChat converts a database chat row to a telemetry Chat.
|
||||
func ConvertChat(dbChat database.GetChatsUpdatedAfterRow) Chat {
|
||||
c := Chat{
|
||||
ID: dbChat.ID,
|
||||
OwnerID: dbChat.OwnerID,
|
||||
CreatedAt: dbChat.CreatedAt,
|
||||
UpdatedAt: dbChat.UpdatedAt,
|
||||
Status: string(dbChat.Status),
|
||||
HasParent: dbChat.HasParent,
|
||||
Archived: dbChat.Archived,
|
||||
LastModelConfigID: dbChat.LastModelConfigID,
|
||||
}
|
||||
if dbChat.RootChatID.Valid {
|
||||
c.RootChatID = &dbChat.RootChatID.UUID
|
||||
}
|
||||
if dbChat.WorkspaceID.Valid {
|
||||
c.WorkspaceID = &dbChat.WorkspaceID.UUID
|
||||
}
|
||||
if dbChat.Mode.Valid {
|
||||
mode := string(dbChat.Mode.ChatMode)
|
||||
c.Mode = &mode
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// ConvertChatMessageSummary converts a database chat message
|
||||
// summary row to a telemetry ChatMessageSummary.
|
||||
func ConvertChatMessageSummary(dbRow database.GetChatMessageSummariesPerChatRow) ChatMessageSummary {
|
||||
return ChatMessageSummary{
|
||||
ChatID: dbRow.ChatID,
|
||||
MessageCount: dbRow.MessageCount,
|
||||
UserMessageCount: dbRow.UserMessageCount,
|
||||
AssistantMessageCount: dbRow.AssistantMessageCount,
|
||||
ToolMessageCount: dbRow.ToolMessageCount,
|
||||
SystemMessageCount: dbRow.SystemMessageCount,
|
||||
TotalInputTokens: dbRow.TotalInputTokens,
|
||||
TotalOutputTokens: dbRow.TotalOutputTokens,
|
||||
TotalReasoningTokens: dbRow.TotalReasoningTokens,
|
||||
TotalCacheCreationTokens: dbRow.TotalCacheCreationTokens,
|
||||
TotalCacheReadTokens: dbRow.TotalCacheReadTokens,
|
||||
TotalCostMicros: dbRow.TotalCostMicros,
|
||||
TotalRuntimeMs: dbRow.TotalRuntimeMs,
|
||||
DistinctModelCount: dbRow.DistinctModelCount,
|
||||
CompressedMessageCount: dbRow.CompressedMessageCount,
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertChatModelConfig converts a database model config row to a
|
||||
// telemetry ChatModelConfig.
|
||||
func ConvertChatModelConfig(dbRow database.GetChatModelConfigsForTelemetryRow) ChatModelConfig {
|
||||
return ChatModelConfig{
|
||||
ID: dbRow.ID,
|
||||
Provider: dbRow.Provider,
|
||||
Model: dbRow.Model,
|
||||
ContextLimit: dbRow.ContextLimit,
|
||||
Enabled: dbRow.Enabled,
|
||||
IsDefault: dbRow.IsDefault,
|
||||
}
|
||||
}
|
||||
|
||||
type telemetryItemKey string
|
||||
|
||||
// The comment below gets rid of the warning that the name "TelemetryItemKey" has
|
||||
@@ -2331,53 +2234,6 @@ type BoundaryUsageSummary struct {
|
||||
PeriodDurationMilliseconds int64 `json:"period_duration_ms"`
|
||||
}
|
||||
|
||||
// Chat contains anonymized metadata about a chat for telemetry.
|
||||
// Titles and message content are excluded to avoid PII leakage.
|
||||
type Chat struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
OwnerID uuid.UUID `json:"owner_id"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
Status string `json:"status"`
|
||||
HasParent bool `json:"has_parent"`
|
||||
RootChatID *uuid.UUID `json:"root_chat_id"`
|
||||
WorkspaceID *uuid.UUID `json:"workspace_id"`
|
||||
Mode *string `json:"mode"`
|
||||
Archived bool `json:"archived"`
|
||||
LastModelConfigID uuid.UUID `json:"last_model_config_id"`
|
||||
}
|
||||
|
||||
// ChatMessageSummary contains per-chat aggregated message metrics
|
||||
// for telemetry. Individual message content is never included.
|
||||
type ChatMessageSummary struct {
|
||||
ChatID uuid.UUID `json:"chat_id"`
|
||||
MessageCount int64 `json:"message_count"`
|
||||
UserMessageCount int64 `json:"user_message_count"`
|
||||
AssistantMessageCount int64 `json:"assistant_message_count"`
|
||||
ToolMessageCount int64 `json:"tool_message_count"`
|
||||
SystemMessageCount int64 `json:"system_message_count"`
|
||||
TotalInputTokens int64 `json:"total_input_tokens"`
|
||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||
TotalReasoningTokens int64 `json:"total_reasoning_tokens"`
|
||||
TotalCacheCreationTokens int64 `json:"total_cache_creation_tokens"`
|
||||
TotalCacheReadTokens int64 `json:"total_cache_read_tokens"`
|
||||
TotalCostMicros int64 `json:"total_cost_micros"`
|
||||
TotalRuntimeMs int64 `json:"total_runtime_ms"`
|
||||
DistinctModelCount int64 `json:"distinct_model_count"`
|
||||
CompressedMessageCount int64 `json:"compressed_message_count"`
|
||||
}
|
||||
|
||||
// ChatModelConfig contains model configuration metadata for
|
||||
// telemetry. Sensitive fields like API keys are excluded.
|
||||
type ChatModelConfig struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
Provider string `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
ContextLimit int64 `json:"context_limit"`
|
||||
Enabled bool `json:"enabled"`
|
||||
IsDefault bool `json:"is_default"`
|
||||
}
|
||||
|
||||
func ConvertAIBridgeInterceptionsSummary(endTime time.Time, provider, model, client string, summary database.CalculateAIBridgeInterceptionsTelemetrySummaryRow) AIBridgeInterceptionsSummary {
|
||||
return AIBridgeInterceptionsSummary{
|
||||
ID: uuid.New(),
|
||||
|
||||
@@ -1549,303 +1549,3 @@ func TestTelemetry_BoundaryUsageSummary(t *testing.T) {
|
||||
require.Nil(t, snapshot2.BoundaryUsageSummary)
|
||||
})
|
||||
}
|
||||
|
||||
func TestChatsTelemetry(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
|
||||
// Create chat providers (required FK for model configs).
|
||||
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "anthropic",
|
||||
DisplayName: "Anthropic",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a model config.
|
||||
modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
Provider: "anthropic",
|
||||
Model: "claude-sonnet-4-20250514",
|
||||
DisplayName: "Claude Sonnet",
|
||||
Enabled: true,
|
||||
IsDefault: true,
|
||||
ContextLimit: 200000,
|
||||
CompressionThreshold: 70,
|
||||
Options: json.RawMessage("{}"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a second model config to test full dump.
|
||||
modelCfg2, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
Provider: "openai",
|
||||
Model: "gpt-4o",
|
||||
DisplayName: "GPT-4o",
|
||||
Enabled: true,
|
||||
IsDefault: false,
|
||||
ContextLimit: 128000,
|
||||
CompressionThreshold: 70,
|
||||
Options: json.RawMessage("{}"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a soft-deleted model config — should NOT appear in telemetry.
|
||||
deletedCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
Provider: "anthropic",
|
||||
Model: "claude-deleted",
|
||||
DisplayName: "Deleted Model",
|
||||
Enabled: true,
|
||||
IsDefault: false,
|
||||
ContextLimit: 100000,
|
||||
CompressionThreshold: 70,
|
||||
Options: json.RawMessage("{}"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
err = db.DeleteChatModelConfigByID(ctx, deletedCfg.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a root chat with a workspace.
|
||||
org, err := db.GetDefaultOrganization(ctx)
|
||||
require.NoError(t, err)
|
||||
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
||||
OrganizationID: org.ID,
|
||||
Type: database.ProvisionerJobTypeTemplateVersionDryRun,
|
||||
})
|
||||
tpl := dbgen.Template(t, db, database.Template{
|
||||
OrganizationID: org.ID,
|
||||
CreatedBy: user.ID,
|
||||
})
|
||||
tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
||||
OrganizationID: org.ID,
|
||||
TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true},
|
||||
CreatedBy: user.ID,
|
||||
JobID: job.ID,
|
||||
})
|
||||
ws := dbgen.Workspace(t, db, database.WorkspaceTable{
|
||||
OwnerID: user.ID,
|
||||
OrganizationID: org.ID,
|
||||
TemplateID: tpl.ID,
|
||||
})
|
||||
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
||||
Transition: database.WorkspaceTransitionStart,
|
||||
Reason: database.BuildReasonInitiator,
|
||||
WorkspaceID: ws.ID,
|
||||
TemplateVersionID: tv.ID,
|
||||
JobID: job.ID,
|
||||
})
|
||||
|
||||
rootChat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
OwnerID: user.ID,
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: "Root Chat",
|
||||
Status: database.ChatStatusRunning,
|
||||
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
|
||||
Mode: database.NullChatMode{ChatMode: database.ChatModeComputerUse, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a child chat (has parent + root).
|
||||
childChat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
OwnerID: user.ID,
|
||||
LastModelConfigID: modelCfg2.ID,
|
||||
Title: "Child Chat",
|
||||
Status: database.ChatStatusCompleted,
|
||||
ParentChatID: uuid.NullUUID{UUID: rootChat.ID, Valid: true},
|
||||
RootChatID: uuid.NullUUID{UUID: rootChat.ID, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert messages for root chat: 2 user, 2 assistant, 1 tool.
|
||||
_, err = db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
|
||||
ChatID: rootChat.ID,
|
||||
CreatedBy: []uuid.UUID{user.ID, uuid.Nil, user.ID, uuid.Nil, uuid.Nil},
|
||||
ModelConfigID: []uuid.UUID{modelCfg.ID, modelCfg.ID, modelCfg.ID, modelCfg.ID, modelCfg.ID},
|
||||
Role: []database.ChatMessageRole{database.ChatMessageRoleUser, database.ChatMessageRoleAssistant, database.ChatMessageRoleUser, database.ChatMessageRoleAssistant, database.ChatMessageRoleTool},
|
||||
Content: []string{`[{"type":"text","text":"hello"}]`, `[{"type":"text","text":"hi"}]`, `[{"type":"text","text":"help"}]`, `[{"type":"text","text":"sure"}]`, `[{"type":"text","text":"result"}]`},
|
||||
ContentVersion: []int16{1, 1, 1, 1, 1},
|
||||
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth, database.ChatMessageVisibilityBoth, database.ChatMessageVisibilityBoth, database.ChatMessageVisibilityBoth, database.ChatMessageVisibilityBoth},
|
||||
InputTokens: []int64{100, 200, 150, 300, 0},
|
||||
OutputTokens: []int64{0, 50, 0, 100, 0},
|
||||
TotalTokens: []int64{100, 250, 150, 400, 0},
|
||||
ReasoningTokens: []int64{0, 10, 0, 20, 0},
|
||||
CacheCreationTokens: []int64{50, 0, 30, 0, 0},
|
||||
CacheReadTokens: []int64{0, 25, 0, 40, 0},
|
||||
ContextLimit: []int64{200000, 200000, 200000, 200000, 200000},
|
||||
Compressed: []bool{false, false, false, false, false},
|
||||
TotalCostMicros: []int64{1000, 2000, 1500, 3000, 0},
|
||||
RuntimeMs: []int64{0, 500, 0, 800, 100},
|
||||
ProviderResponseID: []string{"", "resp-1", "", "resp-2", ""},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert messages for child chat: 1 user, 1 assistant (compressed).
|
||||
_, err = db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
|
||||
ChatID: childChat.ID,
|
||||
CreatedBy: []uuid.UUID{user.ID, uuid.Nil},
|
||||
ModelConfigID: []uuid.UUID{modelCfg2.ID, modelCfg2.ID},
|
||||
Role: []database.ChatMessageRole{database.ChatMessageRoleUser, database.ChatMessageRoleAssistant},
|
||||
Content: []string{`[{"type":"text","text":"q"}]`, `[{"type":"text","text":"a"}]`},
|
||||
ContentVersion: []int16{1, 1},
|
||||
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth, database.ChatMessageVisibilityBoth},
|
||||
InputTokens: []int64{500, 600},
|
||||
OutputTokens: []int64{0, 200},
|
||||
TotalTokens: []int64{500, 800},
|
||||
ReasoningTokens: []int64{0, 50},
|
||||
CacheCreationTokens: []int64{100, 0},
|
||||
CacheReadTokens: []int64{0, 75},
|
||||
ContextLimit: []int64{128000, 128000},
|
||||
Compressed: []bool{false, true},
|
||||
TotalCostMicros: []int64{5000, 8000},
|
||||
RuntimeMs: []int64{0, 1200},
|
||||
ProviderResponseID: []string{"", "resp-3"},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert a soft-deleted message on root chat with large token values.
|
||||
// This acts as "poison" — if the deleted filter is missing, totals
|
||||
// will be inflated and assertions below will fail.
|
||||
poisonMsgs, err := db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
|
||||
ChatID: rootChat.ID,
|
||||
CreatedBy: []uuid.UUID{uuid.Nil},
|
||||
ModelConfigID: []uuid.UUID{modelCfg.ID},
|
||||
Role: []database.ChatMessageRole{database.ChatMessageRoleAssistant},
|
||||
Content: []string{`[{"type":"text","text":"poison"}]`},
|
||||
ContentVersion: []int16{1},
|
||||
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
|
||||
InputTokens: []int64{999999},
|
||||
OutputTokens: []int64{999999},
|
||||
TotalTokens: []int64{999999},
|
||||
ReasoningTokens: []int64{999999},
|
||||
CacheCreationTokens: []int64{999999},
|
||||
CacheReadTokens: []int64{999999},
|
||||
ContextLimit: []int64{200000},
|
||||
Compressed: []bool{false},
|
||||
TotalCostMicros: []int64{999999},
|
||||
RuntimeMs: []int64{999999},
|
||||
ProviderResponseID: []string{""},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
err = db.SoftDeleteChatMessageByID(ctx, poisonMsgs[0].ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, snapshot := collectSnapshot(ctx, t, db, nil)
|
||||
|
||||
// --- Assert Chats ---
|
||||
require.Len(t, snapshot.Chats, 2)
|
||||
|
||||
// Find root and child by HasParent flag.
|
||||
var foundRoot, foundChild *telemetry.Chat
|
||||
for i := range snapshot.Chats {
|
||||
if !snapshot.Chats[i].HasParent {
|
||||
foundRoot = &snapshot.Chats[i]
|
||||
} else {
|
||||
foundChild = &snapshot.Chats[i]
|
||||
}
|
||||
}
|
||||
require.NotNil(t, foundRoot, "expected root chat")
|
||||
require.NotNil(t, foundChild, "expected child chat")
|
||||
|
||||
// Root chat assertions.
|
||||
assert.Equal(t, rootChat.ID, foundRoot.ID)
|
||||
assert.Equal(t, user.ID, foundRoot.OwnerID)
|
||||
assert.Equal(t, "running", foundRoot.Status)
|
||||
assert.False(t, foundRoot.HasParent)
|
||||
assert.Nil(t, foundRoot.RootChatID)
|
||||
require.NotNil(t, foundRoot.WorkspaceID)
|
||||
assert.Equal(t, ws.ID, *foundRoot.WorkspaceID)
|
||||
assert.Equal(t, modelCfg.ID, foundRoot.LastModelConfigID)
|
||||
require.NotNil(t, foundRoot.Mode)
|
||||
assert.Equal(t, "computer_use", *foundRoot.Mode)
|
||||
assert.False(t, foundRoot.Archived)
|
||||
|
||||
// Child chat assertions.
|
||||
assert.Equal(t, childChat.ID, foundChild.ID)
|
||||
assert.Equal(t, user.ID, foundChild.OwnerID)
|
||||
assert.True(t, foundChild.HasParent)
|
||||
require.NotNil(t, foundChild.RootChatID)
|
||||
assert.Equal(t, rootChat.ID, *foundChild.RootChatID)
|
||||
assert.Nil(t, foundChild.WorkspaceID)
|
||||
assert.Equal(t, "completed", foundChild.Status)
|
||||
assert.Equal(t, modelCfg2.ID, foundChild.LastModelConfigID)
|
||||
assert.Nil(t, foundChild.Mode)
|
||||
assert.False(t, foundChild.Archived)
|
||||
|
||||
// --- Assert ChatMessageSummaries ---
|
||||
require.Len(t, snapshot.ChatMessageSummaries, 2)
|
||||
|
||||
summaryMap := make(map[uuid.UUID]telemetry.ChatMessageSummary)
|
||||
for _, s := range snapshot.ChatMessageSummaries {
|
||||
summaryMap[s.ChatID] = s
|
||||
}
|
||||
|
||||
// Root chat summary: 2 user + 2 assistant + 1 tool = 5 messages.
|
||||
rootSummary, ok := summaryMap[rootChat.ID]
|
||||
require.True(t, ok, "expected summary for root chat")
|
||||
assert.Equal(t, int64(5), rootSummary.MessageCount)
|
||||
assert.Equal(t, int64(2), rootSummary.UserMessageCount)
|
||||
assert.Equal(t, int64(2), rootSummary.AssistantMessageCount)
|
||||
assert.Equal(t, int64(1), rootSummary.ToolMessageCount)
|
||||
assert.Equal(t, int64(0), rootSummary.SystemMessageCount)
|
||||
assert.Equal(t, int64(750), rootSummary.TotalInputTokens) // 100+200+150+300+0
|
||||
assert.Equal(t, int64(150), rootSummary.TotalOutputTokens) // 0+50+0+100+0
|
||||
assert.Equal(t, int64(30), rootSummary.TotalReasoningTokens) // 0+10+0+20+0
|
||||
assert.Equal(t, int64(80), rootSummary.TotalCacheCreationTokens) // 50+0+30+0+0
|
||||
assert.Equal(t, int64(65), rootSummary.TotalCacheReadTokens) // 0+25+0+40+0
|
||||
assert.Equal(t, int64(7500), rootSummary.TotalCostMicros) // 1000+2000+1500+3000+0
|
||||
assert.Equal(t, int64(1400), rootSummary.TotalRuntimeMs) // 0+500+0+800+100
|
||||
assert.Equal(t, int64(1), rootSummary.DistinctModelCount)
|
||||
assert.Equal(t, int64(0), rootSummary.CompressedMessageCount)
|
||||
|
||||
// Child chat summary: 1 user + 1 assistant = 2 messages, 1 compressed.
|
||||
childSummary, ok := summaryMap[childChat.ID]
|
||||
require.True(t, ok, "expected summary for child chat")
|
||||
assert.Equal(t, int64(2), childSummary.MessageCount)
|
||||
assert.Equal(t, int64(1), childSummary.UserMessageCount)
|
||||
assert.Equal(t, int64(1), childSummary.AssistantMessageCount)
|
||||
assert.Equal(t, int64(1100), childSummary.TotalInputTokens) // 500+600
|
||||
assert.Equal(t, int64(200), childSummary.TotalOutputTokens) // 0+200
|
||||
assert.Equal(t, int64(50), childSummary.TotalReasoningTokens) // 0+50
|
||||
assert.Equal(t, int64(0), childSummary.ToolMessageCount)
|
||||
assert.Equal(t, int64(0), childSummary.SystemMessageCount)
|
||||
assert.Equal(t, int64(100), childSummary.TotalCacheCreationTokens) // 100+0
|
||||
assert.Equal(t, int64(75), childSummary.TotalCacheReadTokens) // 0+75
|
||||
assert.Equal(t, int64(13000), childSummary.TotalCostMicros) // 5000+8000
|
||||
assert.Equal(t, int64(1200), childSummary.TotalRuntimeMs) // 0+1200
|
||||
assert.Equal(t, int64(1), childSummary.DistinctModelCount)
|
||||
assert.Equal(t, int64(1), childSummary.CompressedMessageCount)
|
||||
|
||||
// --- Assert ChatModelConfigs ---
|
||||
require.Len(t, snapshot.ChatModelConfigs, 2)
|
||||
|
||||
configMap := make(map[uuid.UUID]telemetry.ChatModelConfig)
|
||||
for _, c := range snapshot.ChatModelConfigs {
|
||||
configMap[c.ID] = c
|
||||
}
|
||||
|
||||
cfg1, ok := configMap[modelCfg.ID]
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "anthropic", cfg1.Provider)
|
||||
assert.Equal(t, "claude-sonnet-4-20250514", cfg1.Model)
|
||||
assert.Equal(t, int64(200000), cfg1.ContextLimit)
|
||||
assert.True(t, cfg1.Enabled)
|
||||
assert.True(t, cfg1.IsDefault)
|
||||
|
||||
cfg2, ok := configMap[modelCfg2.ID]
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "openai", cfg2.Provider)
|
||||
assert.Equal(t, "gpt-4o", cfg2.Model)
|
||||
assert.Equal(t, int64(128000), cfg2.ContextLimit)
|
||||
assert.True(t, cfg2.Enabled)
|
||||
assert.False(t, cfg2.IsDefault)
|
||||
}
|
||||
|
||||
+12
-8
@@ -475,14 +475,6 @@ func (api *API) postUser(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
req.UserLoginType = codersdk.LoginTypeNone
|
||||
|
||||
// Service accounts are a Premium feature.
|
||||
if !api.Entitlements.Enabled(codersdk.FeatureServiceAccounts) {
|
||||
httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{
|
||||
Message: fmt.Sprintf("%s is a Premium feature. Contact sales!", codersdk.FeatureServiceAccounts.Humanize()),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else if req.UserLoginType == "" {
|
||||
// Default to password auth
|
||||
req.UserLoginType = codersdk.LoginTypePassword
|
||||
@@ -1638,6 +1630,18 @@ func (api *API) CreateUser(ctx context.Context, store database.Store, req Create
|
||||
rbacRoles = req.RBACRoles
|
||||
}
|
||||
|
||||
// When the agents experiment is enabled, auto-assign the
|
||||
// agents-access role so new users can use Coder Agents
|
||||
// without manual admin intervention. Skip this for OIDC
|
||||
// users when site role sync is enabled, because the sync
|
||||
// will overwrite roles on every login anyway — those
|
||||
// admins should use --oidc-user-role-default instead.
|
||||
if api.Experiments.Enabled(codersdk.ExperimentAgents) &&
|
||||
!(req.LoginType == database.LoginTypeOIDC && api.IDPSync.SiteRoleSyncEnabled()) &&
|
||||
!slices.Contains(rbacRoles, codersdk.RoleAgentsAccess) {
|
||||
rbacRoles = append(rbacRoles, codersdk.RoleAgentsAccess)
|
||||
}
|
||||
|
||||
var user database.User
|
||||
err := store.InTx(func(tx database.Store) error {
|
||||
orgRoles := make([]string, 0)
|
||||
|
||||
+143
-6
@@ -829,6 +829,35 @@ func TestPostUsers(t *testing.T) {
|
||||
assert.Equal(t, firstUser.OrganizationID, user.OrganizationIDs[0])
|
||||
})
|
||||
|
||||
// CreateWithAgentsExperiment verifies that new users
|
||||
// are auto-assigned the agents-access role when the
|
||||
// experiment is enabled. The experiment-disabled case
|
||||
// is implicitly covered by TestInitialRoles, which
|
||||
// asserts exactly [owner] with no experiment — it
|
||||
// would fail if agents-access leaked through.
|
||||
t.Run("CreateWithAgentsExperiment", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
dv := coderdtest.DeploymentValues(t)
|
||||
dv.Experiments = []string{string(codersdk.ExperimentAgents)}
|
||||
client := coderdtest.New(t, &coderdtest.Options{DeploymentValues: dv})
|
||||
firstUser := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
user, err := client.CreateUserWithOrgs(ctx, codersdk.CreateUserRequestWithOrgs{
|
||||
OrganizationIDs: []uuid.UUID{firstUser.OrganizationID},
|
||||
Email: "another@user.org",
|
||||
Username: "someone-else",
|
||||
Password: "SomeSecurePassword!",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
roles, err := client.UserRoles(ctx, user.Username)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, roles.Roles, codersdk.RoleAgentsAccess,
|
||||
"new user should have agents-access role when agents experiment is enabled")
|
||||
})
|
||||
|
||||
t.Run("CreateWithStatus", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
auditor := audit.NewMock()
|
||||
@@ -950,7 +979,28 @@ func TestPostUsers(t *testing.T) {
|
||||
require.Equal(t, found.LoginType, codersdk.LoginTypeOIDC)
|
||||
})
|
||||
|
||||
t.Run("ServiceAccount/Unlicensed", func(t *testing.T) {
|
||||
t.Run("ServiceAccount/OK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
first := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
user, err := client.CreateUserWithOrgs(ctx, codersdk.CreateUserRequestWithOrgs{
|
||||
OrganizationIDs: []uuid.UUID{first.OrganizationID},
|
||||
Username: "service-acct-ok",
|
||||
UserLoginType: codersdk.LoginTypeNone,
|
||||
ServiceAccount: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.LoginTypeNone, user.LoginType)
|
||||
require.Empty(t, user.Email)
|
||||
require.Equal(t, "service-acct-ok", user.Username)
|
||||
require.Equal(t, codersdk.UserStatusDormant, user.Status)
|
||||
})
|
||||
|
||||
t.Run("ServiceAccount/WithEmail", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
first := coderdtest.CreateFirstUser(t, client)
|
||||
@@ -960,14 +1010,75 @@ func TestPostUsers(t *testing.T) {
|
||||
|
||||
_, err := client.CreateUserWithOrgs(ctx, codersdk.CreateUserRequestWithOrgs{
|
||||
OrganizationIDs: []uuid.UUID{first.OrganizationID},
|
||||
Username: "service-acct-ok",
|
||||
UserLoginType: codersdk.LoginTypeNone,
|
||||
Username: "service-acct-email",
|
||||
Email: "should-not-have@email.com",
|
||||
ServiceAccount: true,
|
||||
})
|
||||
var apiErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &apiErr)
|
||||
require.Equal(t, http.StatusForbidden, apiErr.StatusCode())
|
||||
require.Contains(t, apiErr.Message, "Premium feature")
|
||||
require.Equal(t, http.StatusBadRequest, apiErr.StatusCode())
|
||||
require.Contains(t, apiErr.Message, "Email cannot be set for service accounts")
|
||||
})
|
||||
|
||||
t.Run("ServiceAccount/WithPassword", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
first := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
_, err := client.CreateUserWithOrgs(ctx, codersdk.CreateUserRequestWithOrgs{
|
||||
OrganizationIDs: []uuid.UUID{first.OrganizationID},
|
||||
Username: "service-acct-password",
|
||||
Password: "ShouldNotHavePassword123!",
|
||||
ServiceAccount: true,
|
||||
})
|
||||
var apiErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &apiErr)
|
||||
require.Equal(t, http.StatusBadRequest, apiErr.StatusCode())
|
||||
require.Contains(t, apiErr.Message, "Password cannot be set for service accounts")
|
||||
})
|
||||
|
||||
t.Run("ServiceAccount/WithInvalidLoginType", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
first := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
_, err := client.CreateUserWithOrgs(ctx, codersdk.CreateUserRequestWithOrgs{
|
||||
OrganizationIDs: []uuid.UUID{first.OrganizationID},
|
||||
Username: "service-acct-login-type",
|
||||
UserLoginType: codersdk.LoginTypePassword,
|
||||
ServiceAccount: true,
|
||||
})
|
||||
var apiErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &apiErr)
|
||||
require.Equal(t, http.StatusBadRequest, apiErr.StatusCode())
|
||||
require.Contains(t, apiErr.Message, "Service accounts must use login type 'none'")
|
||||
})
|
||||
|
||||
t.Run("ServiceAccount/DefaultLoginType", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
first := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
user, err := client.CreateUserWithOrgs(ctx, codersdk.CreateUserRequestWithOrgs{
|
||||
OrganizationIDs: []uuid.UUID{first.OrganizationID},
|
||||
Username: "service-acct-default-login",
|
||||
ServiceAccount: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := client.User(ctx, user.ID.String())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.LoginTypeNone, found.LoginType)
|
||||
require.Empty(t, found.Email)
|
||||
})
|
||||
|
||||
t.Run("NonServiceAccount/WithoutEmail", func(t *testing.T) {
|
||||
@@ -987,6 +1098,32 @@ func TestPostUsers(t *testing.T) {
|
||||
require.ErrorAs(t, err, &apiErr)
|
||||
require.Equal(t, http.StatusBadRequest, apiErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("ServiceAccount/MultipleWithoutEmail", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
first := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
user1, err := client.CreateUserWithOrgs(ctx, codersdk.CreateUserRequestWithOrgs{
|
||||
OrganizationIDs: []uuid.UUID{first.OrganizationID},
|
||||
Username: "service-acct-multi-1",
|
||||
ServiceAccount: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, user1.Email)
|
||||
|
||||
user2, err := client.CreateUserWithOrgs(ctx, codersdk.CreateUserRequestWithOrgs{
|
||||
OrganizationIDs: []uuid.UUID{first.OrganizationID},
|
||||
Username: "service-acct-multi-2",
|
||||
ServiceAccount: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, user2.Email)
|
||||
require.NotEqual(t, user1.ID, user2.ID)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNotifyCreatedUser(t *testing.T) {
|
||||
@@ -1695,7 +1832,7 @@ func TestGetUsersFilter(t *testing.T) {
|
||||
setupCtx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
coderdtest.UsersFilter(setupCtx, t, client, api.Database, nil, nil, func(testCtx context.Context, req codersdk.UsersRequest) []codersdk.ReducedUser {
|
||||
coderdtest.UsersFilter(setupCtx, t, client, api.Database, nil, func(testCtx context.Context, req codersdk.UsersRequest) []codersdk.ReducedUser {
|
||||
res, err := client.Users(testCtx, req)
|
||||
require.NoError(t, err)
|
||||
reduced := make([]codersdk.ReducedUser, len(res.Users))
|
||||
|
||||
@@ -181,9 +181,8 @@ func (api *API) patchWorkspaceAgentLogs(rw http.ResponseWriter, r *http.Request)
|
||||
level := make([]database.LogLevel, 0)
|
||||
outputLength := 0
|
||||
for _, logEntry := range req.Logs {
|
||||
sanitizedOutput := agentsdk.SanitizeLogOutput(logEntry.Output)
|
||||
output = append(output, sanitizedOutput)
|
||||
outputLength += len(sanitizedOutput)
|
||||
output = append(output, logEntry.Output)
|
||||
outputLength += len(logEntry.Output)
|
||||
if logEntry.Level == "" {
|
||||
// Default to "info" to support older agents that didn't have the level field.
|
||||
logEntry.Level = codersdk.LogLevelInfo
|
||||
|
||||
@@ -260,50 +260,6 @@ func TestWorkspaceAgentLogs(t *testing.T) {
|
||||
require.Equal(t, "testing", logChunk[0].Output)
|
||||
require.Equal(t, "testing2", logChunk[1].Output)
|
||||
})
|
||||
t.Run("SanitizesNulBytesAndTracksSanitizedLength", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
client, db := coderdtest.NewWithDatabase(t, nil)
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OrganizationID: user.OrganizationID,
|
||||
OwnerID: user.UserID,
|
||||
}).WithAgent().Do()
|
||||
|
||||
rawOutput := "before\x00after"
|
||||
sanitizedOutput := agentsdk.SanitizeLogOutput(rawOutput)
|
||||
agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken))
|
||||
err := agentClient.PatchLogs(ctx, agentsdk.PatchLogs{
|
||||
Logs: []agentsdk.Log{
|
||||
{
|
||||
CreatedAt: dbtime.Now(),
|
||||
Output: rawOutput,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
agent, err := db.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), r.Agents[0].ID)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, len(sanitizedOutput), agent.LogsLength)
|
||||
|
||||
workspace, err := client.Workspace(ctx, r.Workspace.ID)
|
||||
require.NoError(t, err)
|
||||
logs, closer, err := client.WorkspaceAgentLogsAfter(ctx, workspace.LatestBuild.Resources[0].Agents[0].ID, 0, true)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = closer.Close()
|
||||
}()
|
||||
|
||||
var logChunk []codersdk.WorkspaceAgentLog
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case logChunk = <-logs:
|
||||
}
|
||||
require.NoError(t, ctx.Err())
|
||||
require.Len(t, logChunk, 1)
|
||||
require.Equal(t, sanitizedOutput, logChunk[0].Output)
|
||||
})
|
||||
t.Run("Close logs on outdated build", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
+25
-585
@@ -788,7 +788,6 @@ type CreateOptions struct {
|
||||
InitialUserContent []codersdk.ChatMessagePart
|
||||
MCPServerIDs []uuid.UUID
|
||||
Labels database.StringMap
|
||||
DynamicTools json.RawMessage
|
||||
}
|
||||
|
||||
// SendMessageBusyBehavior controls what happens when a chat is already active.
|
||||
@@ -900,10 +899,6 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C
|
||||
RawMessage: labelsJSON,
|
||||
Valid: true,
|
||||
},
|
||||
DynamicTools: pqtype.NullRawMessage{
|
||||
RawMessage: opts.DynamicTools,
|
||||
Valid: len(opts.DynamicTools) > 0,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("insert chat: %w", err)
|
||||
@@ -1551,238 +1546,6 @@ func (p *Server) PromoteQueued(
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// SubmitToolResultsOptions controls tool result submission.
|
||||
type SubmitToolResultsOptions struct {
|
||||
ChatID uuid.UUID
|
||||
UserID uuid.UUID
|
||||
ModelConfigID uuid.UUID
|
||||
Results []codersdk.ToolResult
|
||||
DynamicTools json.RawMessage
|
||||
}
|
||||
|
||||
// ToolResultValidationError indicates the submitted tool results
|
||||
// failed validation (e.g. missing, duplicate, or unexpected IDs,
|
||||
// or invalid JSON output).
|
||||
type ToolResultValidationError struct {
|
||||
Message string
|
||||
Detail string
|
||||
}
|
||||
|
||||
func (e *ToolResultValidationError) Error() string {
|
||||
if e.Detail != "" {
|
||||
return e.Message + ": " + e.Detail
|
||||
}
|
||||
return e.Message
|
||||
}
|
||||
|
||||
// ToolResultStatusConflictError indicates the chat is not in the
|
||||
// requires_action state expected for tool result submission.
|
||||
type ToolResultStatusConflictError struct {
|
||||
ActualStatus database.ChatStatus
|
||||
}
|
||||
|
||||
func (e *ToolResultStatusConflictError) Error() string {
|
||||
return fmt.Sprintf(
|
||||
"chat status is %q, expected %q",
|
||||
e.ActualStatus, database.ChatStatusRequiresAction,
|
||||
)
|
||||
}
|
||||
|
||||
// SubmitToolResults validates and persists client-provided tool
|
||||
// results, transitions the chat to pending, and wakes the run
|
||||
// loop. The caller is responsible for the fast-path status check;
|
||||
// this method performs an authoritative re-check under a row lock.
|
||||
func (p *Server) SubmitToolResults(
|
||||
ctx context.Context,
|
||||
opts SubmitToolResultsOptions,
|
||||
) error {
|
||||
dynamicToolNames, err := parseDynamicToolNames(pqtype.NullRawMessage{
|
||||
RawMessage: opts.DynamicTools,
|
||||
Valid: len(opts.DynamicTools) > 0,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("parse chat dynamic tools: %w", err)
|
||||
}
|
||||
|
||||
// The GetLastChatMessageByRole lookup and all subsequent
|
||||
// validation and persistence run inside a single transaction
|
||||
// so the assistant message cannot change between reads.
|
||||
var statusConflict *ToolResultStatusConflictError
|
||||
txErr := p.db.InTx(func(tx database.Store) error {
|
||||
// Authoritative status check under row lock.
|
||||
locked, lockErr := tx.GetChatByIDForUpdate(ctx, opts.ChatID)
|
||||
if lockErr != nil {
|
||||
return xerrors.Errorf("lock chat for update: %w", lockErr)
|
||||
}
|
||||
if locked.Status != database.ChatStatusRequiresAction {
|
||||
statusConflict = &ToolResultStatusConflictError{
|
||||
ActualStatus: locked.Status,
|
||||
}
|
||||
return statusConflict
|
||||
}
|
||||
|
||||
// Get the last assistant message inside the transaction
|
||||
// for consistency with the row lock above.
|
||||
lastAssistant, err := tx.GetLastChatMessageByRole(ctx, database.GetLastChatMessageByRoleParams{
|
||||
ChatID: opts.ChatID,
|
||||
Role: database.ChatMessageRoleAssistant,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get last assistant message: %w", err)
|
||||
}
|
||||
|
||||
// Collect tool-call IDs that already have results.
|
||||
// When a dynamic tool name collides with a built-in,
|
||||
// the chatloop executes it as a built-in and persists
|
||||
// the result. Those calls must not count as pending.
|
||||
afterMsgs, afterErr := tx.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: opts.ChatID,
|
||||
AfterID: lastAssistant.ID,
|
||||
})
|
||||
if afterErr != nil {
|
||||
return xerrors.Errorf("get messages after assistant: %w", afterErr)
|
||||
}
|
||||
handledCallIDs := make(map[string]bool)
|
||||
for _, msg := range afterMsgs {
|
||||
if msg.Role != database.ChatMessageRoleTool {
|
||||
continue
|
||||
}
|
||||
msgParts, msgParseErr := chatprompt.ParseContent(msg)
|
||||
if msgParseErr != nil {
|
||||
continue
|
||||
}
|
||||
for _, mp := range msgParts {
|
||||
if mp.Type == codersdk.ChatMessagePartTypeToolResult {
|
||||
handledCallIDs[mp.ToolCallID] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract pending dynamic tool-call IDs, skipping any
|
||||
// that were already handled by the chatloop.
|
||||
pendingCallIDs := make(map[string]bool)
|
||||
toolCallIDToName := make(map[string]string)
|
||||
parts, parseErr := chatprompt.ParseContent(lastAssistant)
|
||||
if parseErr != nil {
|
||||
return xerrors.Errorf("parse assistant message: %w", parseErr)
|
||||
}
|
||||
for _, part := range parts {
|
||||
if part.Type == codersdk.ChatMessagePartTypeToolCall &&
|
||||
dynamicToolNames[part.ToolName] &&
|
||||
!handledCallIDs[part.ToolCallID] {
|
||||
pendingCallIDs[part.ToolCallID] = true
|
||||
toolCallIDToName[part.ToolCallID] = part.ToolName
|
||||
}
|
||||
}
|
||||
|
||||
// Validate submitted results match pending calls exactly.
|
||||
submittedIDs := make(map[string]bool, len(opts.Results))
|
||||
for _, result := range opts.Results {
|
||||
if submittedIDs[result.ToolCallID] {
|
||||
return &ToolResultValidationError{
|
||||
Message: "Duplicate tool_call_id in results.",
|
||||
Detail: fmt.Sprintf("Duplicate tool call ID %q.", result.ToolCallID),
|
||||
}
|
||||
}
|
||||
submittedIDs[result.ToolCallID] = true
|
||||
}
|
||||
for id := range pendingCallIDs {
|
||||
if !submittedIDs[id] {
|
||||
return &ToolResultValidationError{
|
||||
Message: "Missing tool result.",
|
||||
Detail: fmt.Sprintf("Missing result for tool call %q.", id),
|
||||
}
|
||||
}
|
||||
}
|
||||
for id := range submittedIDs {
|
||||
if !pendingCallIDs[id] {
|
||||
return &ToolResultValidationError{
|
||||
Message: "Unexpected tool result.",
|
||||
Detail: fmt.Sprintf("No pending tool call with ID %q.", id),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Marshal each tool result into a separate message row.
|
||||
resultContents := make([]pqtype.NullRawMessage, 0, len(opts.Results))
|
||||
for _, result := range opts.Results {
|
||||
if !json.Valid(result.Output) {
|
||||
return &ToolResultValidationError{
|
||||
Message: "Tool result output must be valid JSON.",
|
||||
Detail: fmt.Sprintf("Output for tool call %q is not valid JSON.", result.ToolCallID),
|
||||
}
|
||||
}
|
||||
part := codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeToolResult,
|
||||
ToolCallID: result.ToolCallID,
|
||||
ToolName: toolCallIDToName[result.ToolCallID],
|
||||
Result: result.Output,
|
||||
IsError: result.IsError,
|
||||
}
|
||||
marshaled, marshalErr := chatprompt.MarshalParts([]codersdk.ChatMessagePart{part})
|
||||
if marshalErr != nil {
|
||||
return xerrors.Errorf("marshal tool result: %w", marshalErr)
|
||||
}
|
||||
resultContents = append(resultContents, marshaled)
|
||||
}
|
||||
|
||||
// Insert tool-result messages.
|
||||
n := len(resultContents)
|
||||
params := database.InsertChatMessagesParams{
|
||||
ChatID: opts.ChatID,
|
||||
CreatedBy: make([]uuid.UUID, n),
|
||||
ModelConfigID: make([]uuid.UUID, n),
|
||||
Role: make([]database.ChatMessageRole, n),
|
||||
Content: make([]string, n),
|
||||
ContentVersion: make([]int16, n),
|
||||
Visibility: make([]database.ChatMessageVisibility, n),
|
||||
InputTokens: make([]int64, n),
|
||||
OutputTokens: make([]int64, n),
|
||||
TotalTokens: make([]int64, n),
|
||||
ReasoningTokens: make([]int64, n),
|
||||
CacheCreationTokens: make([]int64, n),
|
||||
CacheReadTokens: make([]int64, n),
|
||||
ContextLimit: make([]int64, n),
|
||||
Compressed: make([]bool, n),
|
||||
TotalCostMicros: make([]int64, n),
|
||||
RuntimeMs: make([]int64, n),
|
||||
ProviderResponseID: make([]string, n),
|
||||
}
|
||||
for i, rc := range resultContents {
|
||||
params.CreatedBy[i] = opts.UserID
|
||||
params.ModelConfigID[i] = opts.ModelConfigID
|
||||
params.Role[i] = database.ChatMessageRoleTool
|
||||
params.Content[i] = string(rc.RawMessage)
|
||||
params.ContentVersion[i] = chatprompt.CurrentContentVersion
|
||||
params.Visibility[i] = database.ChatMessageVisibilityBoth
|
||||
}
|
||||
if _, insertErr := tx.InsertChatMessages(ctx, params); insertErr != nil {
|
||||
return xerrors.Errorf("insert tool results: %w", insertErr)
|
||||
}
|
||||
|
||||
// Transition chat to pending.
|
||||
if _, updateErr := tx.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
||||
ID: opts.ChatID,
|
||||
Status: database.ChatStatusPending,
|
||||
WorkerID: uuid.NullUUID{},
|
||||
StartedAt: sql.NullTime{},
|
||||
HeartbeatAt: sql.NullTime{},
|
||||
LastError: sql.NullString{},
|
||||
}); updateErr != nil {
|
||||
return xerrors.Errorf("update chat status: %w", updateErr)
|
||||
}
|
||||
|
||||
return nil
|
||||
}, nil)
|
||||
if txErr != nil {
|
||||
return txErr
|
||||
}
|
||||
|
||||
// Wake the chatd run loop so it processes the chat immediately.
|
||||
p.signalWake()
|
||||
return nil
|
||||
}
|
||||
|
||||
// InterruptChat interrupts execution, sets waiting status, and broadcasts status updates.
|
||||
func (p *Server) InterruptChat(
|
||||
ctx context.Context,
|
||||
@@ -1792,32 +1555,6 @@ func (p *Server) InterruptChat(
|
||||
return chat
|
||||
}
|
||||
|
||||
// If the chat is in requires_action, insert synthetic error
|
||||
// tool-result messages for each pending dynamic tool call
|
||||
// before transitioning to waiting. Without this, the LLM
|
||||
// would see unmatched tool-call parts on the next run.
|
||||
if chat.Status == database.ChatStatusRequiresAction {
|
||||
if txErr := p.db.InTx(func(tx database.Store) error {
|
||||
locked, lockErr := tx.GetChatByIDForUpdate(ctx, chat.ID)
|
||||
if lockErr != nil {
|
||||
return xerrors.Errorf("lock chat for interrupt: %w", lockErr)
|
||||
}
|
||||
// Another request may have already transitioned
|
||||
// the chat (e.g. SubmitToolResults committed
|
||||
// between our snapshot and this lock).
|
||||
if locked.Status != database.ChatStatusRequiresAction {
|
||||
return nil
|
||||
}
|
||||
return insertSyntheticToolResultsTx(ctx, tx, locked, "Tool execution interrupted by user")
|
||||
}, nil); txErr != nil {
|
||||
p.logger.Error(ctx, "failed to insert synthetic tool results during interrupt",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.Error(txErr),
|
||||
)
|
||||
// Fall through — still try to set waiting status.
|
||||
}
|
||||
}
|
||||
|
||||
updatedChat, err := p.setChatWaiting(ctx, chat.ID)
|
||||
if err != nil {
|
||||
p.logger.Error(ctx, "failed to mark chat as waiting",
|
||||
@@ -2608,7 +2345,7 @@ func insertUserMessageAndSetPending(
|
||||
// queued while a chat is active.
|
||||
func shouldQueueUserMessage(status database.ChatStatus) bool {
|
||||
switch status {
|
||||
case database.ChatStatusRunning, database.ChatStatusPending, database.ChatStatusRequiresAction:
|
||||
case database.ChatStatusRunning, database.ChatStatusPending:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
@@ -3481,12 +3218,8 @@ func (p *Server) Subscribe(
|
||||
// Pubsub will deliver a duplicate status
|
||||
// later; the frontend deduplicates it
|
||||
// (setChatStatus is idempotent).
|
||||
// action_required is also transient and
|
||||
// only published on the local stream, so
|
||||
// it must be forwarded here.
|
||||
if event.Type == codersdk.ChatStreamEventTypeMessagePart ||
|
||||
event.Type == codersdk.ChatStreamEventTypeStatus ||
|
||||
event.Type == codersdk.ChatStreamEventTypeActionRequired {
|
||||
event.Type == codersdk.ChatStreamEventTypeStatus {
|
||||
select {
|
||||
case <-mergedCtx.Done():
|
||||
return
|
||||
@@ -3612,51 +3345,6 @@ func (p *Server) publishChatPubsubEvent(chat database.Chat, kind coderdpubsub.Ch
|
||||
}
|
||||
}
|
||||
|
||||
// pendingToStreamToolCalls converts a slice of chatloop pending
|
||||
// tool calls into the SDK streaming representation.
|
||||
func pendingToStreamToolCalls(pending []chatloop.PendingToolCall) []codersdk.ChatStreamToolCall {
|
||||
calls := make([]codersdk.ChatStreamToolCall, len(pending))
|
||||
for i, tc := range pending {
|
||||
calls[i] = codersdk.ChatStreamToolCall{
|
||||
ToolCallID: tc.ToolCallID,
|
||||
ToolName: tc.ToolName,
|
||||
Args: tc.Args,
|
||||
}
|
||||
}
|
||||
return calls
|
||||
}
|
||||
|
||||
// publishChatActionRequired broadcasts an action_required event via
|
||||
// PostgreSQL pubsub so that global watchers can react to dynamic
|
||||
// tool calls without streaming each chat individually.
|
||||
func (p *Server) publishChatActionRequired(chat database.Chat, pending []chatloop.PendingToolCall) {
|
||||
if p.pubsub == nil {
|
||||
return
|
||||
}
|
||||
toolCalls := pendingToStreamToolCalls(pending)
|
||||
sdkChat := db2sdk.Chat(chat, nil, nil)
|
||||
|
||||
event := coderdpubsub.ChatEvent{
|
||||
Kind: coderdpubsub.ChatEventKindActionRequired,
|
||||
Chat: sdkChat,
|
||||
ToolCalls: toolCalls,
|
||||
}
|
||||
payload, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
p.logger.Error(context.Background(), "failed to marshal chat action_required pubsub event",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
if err := p.pubsub.Publish(coderdpubsub.ChatEventChannel(chat.OwnerID), payload); err != nil {
|
||||
p.logger.Error(context.Background(), "failed to publish chat action_required pubsub event",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// PublishDiffStatusChange broadcasts a diff_status_change event for
|
||||
// the given chat so that watching clients know to re-fetch the diff
|
||||
// status. This is called from the HTTP layer after the diff status
|
||||
@@ -4161,21 +3849,6 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) {
|
||||
}
|
||||
p.publishChatPubsubEvent(updatedChat, coderdpubsub.ChatEventKindStatusChange, nil)
|
||||
|
||||
// When the chat is parked in requires_action,
|
||||
// publish the stream event and global pubsub event
|
||||
// after the DB status has committed. Publishing
|
||||
// here (not in runChat) prevents a race where a
|
||||
// fast client reacts before the status is visible.
|
||||
if status == database.ChatStatusRequiresAction && len(runResult.PendingDynamicToolCalls) > 0 {
|
||||
toolCalls := pendingToStreamToolCalls(runResult.PendingDynamicToolCalls)
|
||||
p.publishEvent(chat.ID, codersdk.ChatStreamEvent{
|
||||
Type: codersdk.ChatStreamEventTypeActionRequired,
|
||||
ActionRequired: &codersdk.ChatStreamActionRequired{
|
||||
ToolCalls: toolCalls,
|
||||
},
|
||||
})
|
||||
p.publishChatActionRequired(updatedChat, runResult.PendingDynamicToolCalls)
|
||||
}
|
||||
if !wasInterrupted {
|
||||
p.maybeSendPushNotification(cleanupCtx, updatedChat, status, lastError, runResult, logger)
|
||||
}
|
||||
@@ -4204,13 +3877,6 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) {
|
||||
return
|
||||
}
|
||||
|
||||
// The LLM invoked a dynamic tool — park the chat in
|
||||
// requires_action so the client can supply tool results.
|
||||
if len(runResult.PendingDynamicToolCalls) > 0 {
|
||||
status = database.ChatStatusRequiresAction
|
||||
return
|
||||
}
|
||||
|
||||
// If runChat completed successfully but the server context was
|
||||
// canceled (e.g. during Close()), the chat should be returned
|
||||
// to pending so another replica can pick it up. There is a
|
||||
@@ -4277,10 +3943,9 @@ func (t *generatedChatTitle) Load() (string, bool) {
|
||||
}
|
||||
|
||||
type runChatResult struct {
|
||||
FinalAssistantText string
|
||||
PushSummaryModel fantasy.LanguageModel
|
||||
ProviderKeys chatprovider.ProviderAPIKeys
|
||||
PendingDynamicToolCalls []chatloop.PendingToolCall
|
||||
FinalAssistantText string
|
||||
PushSummaryModel fantasy.LanguageModel
|
||||
ProviderKeys chatprovider.ProviderAPIKeys
|
||||
}
|
||||
|
||||
func (p *Server) runChat(
|
||||
@@ -4584,8 +4249,8 @@ func (p *Server) runChat(
|
||||
// server.
|
||||
toolNameToConfigID := make(map[string]uuid.UUID)
|
||||
for _, t := range mcpTools {
|
||||
if mcpTool, ok := t.(mcpclient.MCPToolIdentifier); ok {
|
||||
toolNameToConfigID[t.Info().Name] = mcpTool.MCPServerConfigID()
|
||||
if mcp, ok := t.(mcpclient.MCPToolIdentifier); ok {
|
||||
toolNameToConfigID[t.Info().Name] = mcp.MCPServerConfigID()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4604,7 +4269,6 @@ func (p *Server) runChat(
|
||||
// (which is the common case).
|
||||
modelConfigContextLimit := modelConfig.ContextLimit
|
||||
var finalAssistantText string
|
||||
var pendingDynamicCalls []chatloop.PendingToolCall
|
||||
|
||||
persistStep := func(persistCtx context.Context, step chatloop.PersistedStep) error {
|
||||
// If the chat context has been canceled, bail out before
|
||||
@@ -4624,10 +4288,6 @@ func (p *Server) runChat(
|
||||
return persistCtx.Err()
|
||||
}
|
||||
|
||||
// Capture pending dynamic tool calls so the caller
|
||||
// can surface them after chatloop.Run returns.
|
||||
pendingDynamicCalls = step.PendingDynamicToolCalls
|
||||
|
||||
// Split the step content into assistant blocks and tool
|
||||
// result blocks so they can be stored as separate messages
|
||||
// with the appropriate roles. Provider-executed tool results
|
||||
@@ -4665,21 +4325,6 @@ func (p *Server) runChat(
|
||||
part.MCPServerConfigID = uuid.NullUUID{UUID: configID, Valid: true}
|
||||
}
|
||||
}
|
||||
// Apply recorded timestamps so persisted
|
||||
// tool-call parts carry accurate CreatedAt.
|
||||
if part.Type == codersdk.ChatMessagePartTypeToolCall && part.ToolCallID != "" && step.ToolCallCreatedAt != nil {
|
||||
if ts, ok := step.ToolCallCreatedAt[part.ToolCallID]; ok {
|
||||
part.CreatedAt = &ts
|
||||
}
|
||||
}
|
||||
// Provider-executed tool results appear in
|
||||
// assistantBlocks rather than toolResults,
|
||||
// so apply their timestamps here as well.
|
||||
if part.Type == codersdk.ChatMessagePartTypeToolResult && part.ToolCallID != "" && step.ToolResultCreatedAt != nil {
|
||||
if ts, ok := step.ToolResultCreatedAt[part.ToolCallID]; ok {
|
||||
part.CreatedAt = &ts
|
||||
}
|
||||
}
|
||||
sdkParts = append(sdkParts, part)
|
||||
}
|
||||
finalAssistantText = strings.TrimSpace(contentBlocksToText(sdkParts))
|
||||
@@ -4698,13 +4343,6 @@ func (p *Server) runChat(
|
||||
trPart.MCPServerConfigID = uuid.NullUUID{UUID: configID, Valid: true}
|
||||
}
|
||||
}
|
||||
// Apply recorded timestamps so persisted
|
||||
// tool-result parts carry accurate CreatedAt.
|
||||
if trPart.ToolCallID != "" && step.ToolResultCreatedAt != nil {
|
||||
if ts, ok := step.ToolResultCreatedAt[trPart.ToolCallID]; ok {
|
||||
trPart.CreatedAt = &ts
|
||||
}
|
||||
}
|
||||
var marshalErr error
|
||||
toolResultContents[i], marshalErr = chatprompt.MarshalParts([]codersdk.ChatMessagePart{trPart})
|
||||
if marshalErr != nil {
|
||||
@@ -5036,39 +4674,6 @@ func (p *Server) runChat(
|
||||
tools = append(tools, mcpTools...)
|
||||
tools = append(tools, workspaceMCPTools...)
|
||||
|
||||
// Append dynamic tools declared by the client at chat
|
||||
// creation time. These appear in the LLM's tool list but
|
||||
// are never executed by the chatloop — the client handles
|
||||
// execution via POST /tool-results.
|
||||
dynamicToolNames, err := parseDynamicToolNames(chat.DynamicTools)
|
||||
if err != nil {
|
||||
return result, xerrors.Errorf("parse dynamic tool names: %w", err)
|
||||
}
|
||||
// Unmarshal the full definitions separately so we can
|
||||
// build the filtered list below. parseDynamicToolNames
|
||||
// already validated the JSON, so this cannot fail.
|
||||
var dynamicToolDefs []codersdk.DynamicTool
|
||||
if chat.DynamicTools.Valid {
|
||||
if err := json.Unmarshal(chat.DynamicTools.RawMessage, &dynamicToolDefs); err != nil {
|
||||
return result, xerrors.Errorf("unmarshal dynamic tools: %w", err)
|
||||
}
|
||||
}
|
||||
for _, t := range tools {
|
||||
info := t.Info()
|
||||
if dynamicToolNames[info.Name] {
|
||||
logger.Warn(ctx, "dynamic tool name collides with built-in tool, built-in takes precedence",
|
||||
slog.F("tool_name", info.Name))
|
||||
delete(dynamicToolNames, info.Name)
|
||||
}
|
||||
}
|
||||
|
||||
var filteredDefs []codersdk.DynamicTool
|
||||
for _, dt := range dynamicToolDefs {
|
||||
if dynamicToolNames[dt.Name] {
|
||||
filteredDefs = append(filteredDefs, dt)
|
||||
}
|
||||
}
|
||||
tools = append(tools, dynamicToolsFromSDK(p.logger, filteredDefs)...)
|
||||
// Build provider-native tools (e.g., web search) based on
|
||||
// the model configuration.
|
||||
var providerTools []chatloop.ProviderTool
|
||||
@@ -5112,7 +4717,8 @@ func (p *Server) runChat(
|
||||
)
|
||||
prompt = filterPromptForChainMode(prompt, chainInfo.trailingUserCount)
|
||||
}
|
||||
err = chatloop.Run(ctx, chatloop.RunOptions{
|
||||
|
||||
err := chatloop.Run(ctx, chatloop.RunOptions{
|
||||
Model: model,
|
||||
Messages: prompt,
|
||||
Tools: tools, MaxSteps: maxChatSteps,
|
||||
@@ -5120,9 +4726,6 @@ func (p *Server) runChat(
|
||||
ModelConfig: callConfig,
|
||||
ProviderOptions: providerOptions,
|
||||
ProviderTools: providerTools,
|
||||
// dynamicToolNames now contains only names that don't
|
||||
// collide with built-in/MCP tools.
|
||||
DynamicToolNames: dynamicToolNames,
|
||||
|
||||
ContextLimitFallback: modelConfigContextLimit,
|
||||
|
||||
@@ -5200,15 +4803,6 @@ func (p *Server) runChat(
|
||||
p.logger.Warn(ctx, "failed to persist interrupted chat step", slog.Error(err))
|
||||
},
|
||||
})
|
||||
if errors.Is(err, chatloop.ErrDynamicToolCall) {
|
||||
// The stream event is published in processChat's
|
||||
// defer after the DB status transitions to
|
||||
// requires_action, preventing a race where a fast
|
||||
// client reacts before the status is committed.
|
||||
result.FinalAssistantText = finalAssistantText
|
||||
result.PendingDynamicToolCalls = pendingDynamicCalls
|
||||
return result, nil
|
||||
}
|
||||
if err != nil {
|
||||
classified := chaterror.Classify(err).WithProvider(model.Provider())
|
||||
return result, chaterror.WithClassification(err, classified)
|
||||
@@ -5830,9 +5424,7 @@ func (p *Server) recoverStaleChats(ctx context.Context) {
|
||||
|
||||
recovered := 0
|
||||
for _, chat := range staleChats {
|
||||
p.logger.Info(ctx, "recovering stale chat",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("status", chat.Status))
|
||||
p.logger.Info(ctx, "recovering stale chat", slog.F("chat_id", chat.ID))
|
||||
|
||||
// Use a transaction with FOR UPDATE to avoid a TOCTOU race:
|
||||
// between GetStaleChats (a bare SELECT) and here, the chat's
|
||||
@@ -5844,73 +5436,34 @@ func (p *Server) recoverStaleChats(ctx context.Context) {
|
||||
return xerrors.Errorf("lock chat for recovery: %w", lockErr)
|
||||
}
|
||||
|
||||
switch locked.Status {
|
||||
case database.ChatStatusRunning:
|
||||
// Re-check: only recover if the chat is still stale.
|
||||
// A valid heartbeat at or after the threshold means
|
||||
// the chat was refreshed after our snapshot.
|
||||
if locked.HeartbeatAt.Valid && !locked.HeartbeatAt.Time.Before(staleAfter) {
|
||||
p.logger.Debug(ctx, "chat heartbeat refreshed since snapshot, skipping recovery",
|
||||
slog.F("chat_id", chat.ID))
|
||||
return nil
|
||||
}
|
||||
case database.ChatStatusRequiresAction:
|
||||
// Re-check: the chat may have been updated after
|
||||
// our snapshot, similar to the heartbeat check for
|
||||
// running chats.
|
||||
if !locked.UpdatedAt.Before(staleAfter) {
|
||||
p.logger.Debug(ctx, "chat updated since snapshot, skipping recovery",
|
||||
slog.F("chat_id", chat.ID))
|
||||
return nil
|
||||
}
|
||||
default:
|
||||
// Status changed since our snapshot; skip.
|
||||
// Only recover chats that are still running.
|
||||
// Between GetStaleChats and this lock, the chat
|
||||
// may have completed normally.
|
||||
if locked.Status != database.ChatStatusRunning {
|
||||
p.logger.Debug(ctx, "chat status changed since snapshot, skipping recovery",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("status", locked.Status))
|
||||
return nil
|
||||
}
|
||||
|
||||
lastError := sql.NullString{}
|
||||
if locked.Status == database.ChatStatusRequiresAction {
|
||||
lastError = sql.NullString{
|
||||
String: "Dynamic tool execution timed out",
|
||||
Valid: true,
|
||||
}
|
||||
// Re-check: only recover if the chat is still stale.
|
||||
// A valid heartbeat that is at or after the stale
|
||||
// threshold means the chat was refreshed after our
|
||||
// initial snapshot — skip it.
|
||||
if locked.HeartbeatAt.Valid && !locked.HeartbeatAt.Time.Before(staleAfter) {
|
||||
p.logger.Debug(ctx, "chat heartbeat refreshed since snapshot, skipping recovery",
|
||||
slog.F("chat_id", chat.ID))
|
||||
return nil
|
||||
}
|
||||
|
||||
recoverStatus := database.ChatStatusPending
|
||||
if locked.Status == database.ChatStatusRequiresAction {
|
||||
// Timed-out requires_action chats have dangling
|
||||
// tool calls with no matching results. Setting
|
||||
// them back to pending would replay incomplete
|
||||
// tool calls to the LLM, so mark them as errors.
|
||||
recoverStatus = database.ChatStatusError
|
||||
}
|
||||
|
||||
// Insert synthetic error tool-result messages
|
||||
// so the LLM history remains valid if the user
|
||||
// retries the chat later.
|
||||
if locked.Status == database.ChatStatusRequiresAction {
|
||||
if synthErr := insertSyntheticToolResultsTx(ctx, tx, locked, "Dynamic tool execution timed out"); synthErr != nil {
|
||||
p.logger.Warn(ctx, "failed to insert synthetic tool results during stale recovery",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.Error(synthErr),
|
||||
)
|
||||
// Continue with error status even if
|
||||
// synthetic results fail to insert.
|
||||
}
|
||||
}
|
||||
|
||||
// Reset so any replica can pick it up (pending) or
|
||||
// the client sees the failure (error).
|
||||
// Reset to pending so any replica can pick it up.
|
||||
_, updateErr := tx.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: recoverStatus,
|
||||
Status: database.ChatStatusPending,
|
||||
WorkerID: uuid.NullUUID{},
|
||||
StartedAt: sql.NullTime{},
|
||||
HeartbeatAt: sql.NullTime{},
|
||||
LastError: lastError,
|
||||
LastError: sql.NullString{},
|
||||
})
|
||||
if updateErr != nil {
|
||||
return updateErr
|
||||
@@ -5929,119 +5482,6 @@ func (p *Server) recoverStaleChats(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// insertSyntheticToolResultsTx inserts error tool-result messages for
|
||||
// every pending dynamic tool call in the last assistant message. This
|
||||
// keeps the LLM message history valid (every tool-call has a matching
|
||||
// tool-result) when a requires_action chat times out or is interrupted.
|
||||
// It operates on the provided store, which may be a transaction handle.
|
||||
func insertSyntheticToolResultsTx(
|
||||
ctx context.Context,
|
||||
store database.Store,
|
||||
chat database.Chat,
|
||||
reason string,
|
||||
) error {
|
||||
dynamicToolNames, err := parseDynamicToolNames(chat.DynamicTools)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("parse dynamic tools: %w", err)
|
||||
}
|
||||
if len(dynamicToolNames) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get the last assistant message to find pending tool calls.
|
||||
lastAssistant, err := store.GetLastChatMessageByRole(ctx, database.GetLastChatMessageByRoleParams{
|
||||
ChatID: chat.ID,
|
||||
Role: database.ChatMessageRoleAssistant,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get last assistant message: %w", err)
|
||||
}
|
||||
|
||||
parts, err := chatprompt.ParseContent(lastAssistant)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("parse assistant message: %w", err)
|
||||
}
|
||||
|
||||
// Collect dynamic tool calls that need synthetic results.
|
||||
var resultContents []pqtype.NullRawMessage
|
||||
for _, part := range parts {
|
||||
if part.Type != codersdk.ChatMessagePartTypeToolCall || !dynamicToolNames[part.ToolName] {
|
||||
continue
|
||||
}
|
||||
resultPart := codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeToolResult,
|
||||
ToolCallID: part.ToolCallID,
|
||||
ToolName: part.ToolName,
|
||||
Result: json.RawMessage(fmt.Sprintf("%q", reason)),
|
||||
IsError: true,
|
||||
}
|
||||
marshaled, marshalErr := chatprompt.MarshalParts([]codersdk.ChatMessagePart{resultPart})
|
||||
if marshalErr != nil {
|
||||
return xerrors.Errorf("marshal synthetic tool result: %w", marshalErr)
|
||||
}
|
||||
resultContents = append(resultContents, marshaled)
|
||||
}
|
||||
|
||||
if len(resultContents) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Insert tool-result messages using the same pattern as
|
||||
// SubmitToolResults.
|
||||
n := len(resultContents)
|
||||
params := database.InsertChatMessagesParams{
|
||||
ChatID: chat.ID,
|
||||
CreatedBy: make([]uuid.UUID, n),
|
||||
ModelConfigID: make([]uuid.UUID, n),
|
||||
Role: make([]database.ChatMessageRole, n),
|
||||
Content: make([]string, n),
|
||||
ContentVersion: make([]int16, n),
|
||||
Visibility: make([]database.ChatMessageVisibility, n),
|
||||
InputTokens: make([]int64, n),
|
||||
OutputTokens: make([]int64, n),
|
||||
TotalTokens: make([]int64, n),
|
||||
ReasoningTokens: make([]int64, n),
|
||||
CacheCreationTokens: make([]int64, n),
|
||||
CacheReadTokens: make([]int64, n),
|
||||
ContextLimit: make([]int64, n),
|
||||
Compressed: make([]bool, n),
|
||||
TotalCostMicros: make([]int64, n),
|
||||
RuntimeMs: make([]int64, n),
|
||||
ProviderResponseID: make([]string, n),
|
||||
}
|
||||
for i, rc := range resultContents {
|
||||
params.CreatedBy[i] = uuid.Nil
|
||||
params.ModelConfigID[i] = chat.LastModelConfigID
|
||||
params.Role[i] = database.ChatMessageRoleTool
|
||||
params.Content[i] = string(rc.RawMessage)
|
||||
params.ContentVersion[i] = chatprompt.CurrentContentVersion
|
||||
params.Visibility[i] = database.ChatMessageVisibilityBoth
|
||||
}
|
||||
if _, err := store.InsertChatMessages(ctx, params); err != nil {
|
||||
return xerrors.Errorf("insert synthetic tool results: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseDynamicToolNames unmarshals the dynamic tools JSON column
|
||||
// and returns a map of tool names. This centralizes the repeated
|
||||
// pattern of deserializing DynamicTools into a name set.
|
||||
func parseDynamicToolNames(raw pqtype.NullRawMessage) (map[string]bool, error) {
|
||||
if !raw.Valid || len(raw.RawMessage) == 0 {
|
||||
return make(map[string]bool), nil
|
||||
}
|
||||
var tools []codersdk.DynamicTool
|
||||
if err := json.Unmarshal(raw.RawMessage, &tools); err != nil {
|
||||
return nil, xerrors.Errorf("unmarshal dynamic tools: %w", err)
|
||||
}
|
||||
names := make(map[string]bool, len(tools))
|
||||
for _, t := range tools {
|
||||
names[t.Name] = true
|
||||
}
|
||||
return names, nil
|
||||
}
|
||||
|
||||
// maybeSendPushNotification sends a web push notification when an
|
||||
// agent chat reaches a terminal state. For errors it dispatches
|
||||
// synchronously; for successful completions it spawns a goroutine
|
||||
|
||||
@@ -1531,70 +1531,6 @@ func TestRecoverStaleChatsPeriodically(t *testing.T) {
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
}
|
||||
|
||||
func TestRecoverStaleRequiresActionChat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps, rawDB := dbtestutil.NewDBWithSQLDB(t)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, model := seedChatDependencies(ctx, t, db)
|
||||
|
||||
// Use a very short stale threshold so the periodic recovery
|
||||
// kicks in quickly during the test.
|
||||
staleAfter := 500 * time.Millisecond
|
||||
|
||||
// Create a chat and set it to requires_action to simulate a
|
||||
// client that disappeared while the chat was waiting for
|
||||
// dynamic tool results.
|
||||
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: user.ID,
|
||||
Title: "stale-requires-action",
|
||||
LastModelConfigID: model.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: database.ChatStatusRequiresAction,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Backdate updated_at so the chat appears stale to the
|
||||
// recovery loop without needing time.Sleep.
|
||||
_, err = rawDB.ExecContext(ctx,
|
||||
"UPDATE chats SET updated_at = $1 WHERE id = $2",
|
||||
time.Now().Add(-time.Hour), chat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
server := chatd.New(chatd.Config{
|
||||
Logger: logger,
|
||||
Database: db,
|
||||
ReplicaID: uuid.New(),
|
||||
Pubsub: ps,
|
||||
PendingChatAcquireInterval: testutil.WaitLong,
|
||||
InFlightChatStaleAfter: staleAfter,
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, server.Close())
|
||||
})
|
||||
|
||||
// The stale recovery should transition the requires_action
|
||||
// chat to error with the timeout message.
|
||||
var chatResult database.Chat
|
||||
require.Eventually(t, func() bool {
|
||||
chatResult, err = db.GetChatByID(ctx, chat.ID)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return chatResult.Status == database.ChatStatusError
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
|
||||
require.Contains(t, chatResult.LastError.String, "Dynamic tool execution timed out")
|
||||
require.False(t, chatResult.WorkerID.Valid)
|
||||
}
|
||||
|
||||
func TestNewReplicaRecoversStaleChatFromDeadReplica(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -1946,518 +1882,6 @@ func TestPersistToolResultWithBinaryData(t *testing.T) {
|
||||
require.True(t, foundToolResultInSecondCall, "expected second streamed model call to include execute tool output")
|
||||
}
|
||||
|
||||
func TestDynamicToolCallPausesAndResumes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
// Track streaming calls to the mock LLM.
|
||||
var streamedCallCount atomic.Int32
|
||||
var streamedCallsMu sync.Mutex
|
||||
streamedCalls := make([]chattest.OpenAIRequest, 0, 2)
|
||||
|
||||
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
// Non-streaming requests are title generation — return a
|
||||
// simple title.
|
||||
if !req.Stream {
|
||||
return chattest.OpenAINonStreamingResponse("Dynamic tool test")
|
||||
}
|
||||
|
||||
// Capture the full request for later assertions.
|
||||
streamedCallsMu.Lock()
|
||||
streamedCalls = append(streamedCalls, chattest.OpenAIRequest{
|
||||
Messages: append([]chattest.OpenAIMessage(nil), req.Messages...),
|
||||
Tools: append([]chattest.OpenAITool(nil), req.Tools...),
|
||||
Stream: req.Stream,
|
||||
})
|
||||
streamedCallsMu.Unlock()
|
||||
|
||||
if streamedCallCount.Add(1) == 1 {
|
||||
// First call: the LLM invokes our dynamic tool.
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
chattest.OpenAIToolCallChunk(
|
||||
"my_dynamic_tool",
|
||||
`{"input":"hello world"}`,
|
||||
),
|
||||
)
|
||||
}
|
||||
// Second call: the LLM returns a normal text response.
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
chattest.OpenAITextChunks("Dynamic tool result received.")...,
|
||||
)
|
||||
})
|
||||
|
||||
user, model := seedChatDependenciesWithProvider(ctx, t, db, "openai-compat", openAIURL)
|
||||
|
||||
// Dynamic tools do not need a workspace connection, but the
|
||||
// chatd server always builds workspace tools. Use an active
|
||||
// server without an agent connection — the built-in tools
|
||||
// are never invoked because the only tool call targets our
|
||||
// dynamic tool.
|
||||
server := newActiveTestServer(t, db, ps)
|
||||
|
||||
// Create a chat with a dynamic tool.
|
||||
dynamicToolsJSON, err := json.Marshal([]mcpgo.Tool{{
|
||||
Name: "my_dynamic_tool",
|
||||
Description: "A test dynamic tool.",
|
||||
InputSchema: mcpgo.ToolInputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{
|
||||
"input": map[string]any{"type": "string"},
|
||||
},
|
||||
Required: []string{"input"},
|
||||
},
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "dynamic-tool-pause-resume",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText("Please call the dynamic tool."),
|
||||
},
|
||||
DynamicTools: dynamicToolsJSON,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// 1. Wait for the chat to reach requires_action status.
|
||||
var chatResult database.Chat
|
||||
require.Eventually(t, func() bool {
|
||||
got, getErr := db.GetChatByID(ctx, chat.ID)
|
||||
if getErr != nil {
|
||||
return false
|
||||
}
|
||||
chatResult = got
|
||||
return got.Status == database.ChatStatusRequiresAction ||
|
||||
got.Status == database.ChatStatusError
|
||||
}, testutil.WaitLong, testutil.IntervalFast)
|
||||
|
||||
require.Equal(t, database.ChatStatusRequiresAction, chatResult.Status,
|
||||
"expected requires_action, got %s (last_error=%q)",
|
||||
chatResult.Status, chatResult.LastError.String)
|
||||
|
||||
// 2. Read the assistant message to find the tool-call ID.
|
||||
var toolCallID string
|
||||
var toolCallFound bool
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
messages, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chat.ID,
|
||||
AfterID: 0,
|
||||
})
|
||||
if dbErr != nil {
|
||||
return false
|
||||
}
|
||||
for _, msg := range messages {
|
||||
if msg.Role != database.ChatMessageRoleAssistant {
|
||||
continue
|
||||
}
|
||||
parts, parseErr := chatprompt.ParseContent(msg)
|
||||
if parseErr != nil {
|
||||
continue
|
||||
}
|
||||
for _, part := range parts {
|
||||
if part.Type == codersdk.ChatMessagePartTypeToolCall && part.ToolName == "my_dynamic_tool" {
|
||||
toolCallID = part.ToolCallID
|
||||
toolCallFound = true
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}, testutil.IntervalFast)
|
||||
require.True(t, toolCallFound, "expected to find tool call for my_dynamic_tool")
|
||||
require.NotEmpty(t, toolCallID)
|
||||
|
||||
// 3. Submit tool results via SubmitToolResults.
|
||||
toolResultOutput := json.RawMessage(`{"result":"dynamic tool output"}`)
|
||||
err = server.SubmitToolResults(ctx, chatd.SubmitToolResultsOptions{
|
||||
ChatID: chat.ID,
|
||||
UserID: user.ID,
|
||||
ModelConfigID: chatResult.LastModelConfigID,
|
||||
Results: []codersdk.ToolResult{{
|
||||
ToolCallID: toolCallID,
|
||||
Output: toolResultOutput,
|
||||
}},
|
||||
DynamicTools: dynamicToolsJSON,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// 4. Wait for the chat to reach a terminal status.
|
||||
require.Eventually(t, func() bool {
|
||||
got, getErr := db.GetChatByID(ctx, chat.ID)
|
||||
if getErr != nil {
|
||||
return false
|
||||
}
|
||||
chatResult = got
|
||||
return got.Status == database.ChatStatusWaiting || got.Status == database.ChatStatusError
|
||||
}, testutil.WaitLong, testutil.IntervalFast)
|
||||
|
||||
// 5. Verify the chat completed successfully.
|
||||
if chatResult.Status == database.ChatStatusError {
|
||||
require.FailNowf(t, "chat run failed", "last_error=%q", chatResult.LastError.String)
|
||||
}
|
||||
|
||||
// 6. Verify the mock received exactly 2 streaming calls.
|
||||
require.Equal(t, int32(2), streamedCallCount.Load(),
|
||||
"expected exactly 2 streaming calls to the LLM")
|
||||
|
||||
streamedCallsMu.Lock()
|
||||
recordedCalls := append([]chattest.OpenAIRequest(nil), streamedCalls...)
|
||||
streamedCallsMu.Unlock()
|
||||
require.Len(t, recordedCalls, 2)
|
||||
|
||||
// 7. Verify the dynamic tool appeared in the first call's tool list.
|
||||
var foundDynamicTool bool
|
||||
for _, tool := range recordedCalls[0].Tools {
|
||||
if tool.Function.Name == "my_dynamic_tool" {
|
||||
foundDynamicTool = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, foundDynamicTool,
|
||||
"expected 'my_dynamic_tool' in the first LLM call's tool list")
|
||||
|
||||
// 8. Verify the second call's messages contain the tool result.
|
||||
var foundToolResultInSecondCall bool
|
||||
for _, message := range recordedCalls[1].Messages {
|
||||
if message.Role != "tool" {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(message.Content, "dynamic tool output") {
|
||||
foundToolResultInSecondCall = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, foundToolResultInSecondCall,
|
||||
"expected second LLM call to include the submitted dynamic tool result")
|
||||
}
|
||||
|
||||
func TestDynamicToolCallMixedWithBuiltIn(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
// Track streaming calls to the mock LLM.
|
||||
var streamedCallCount atomic.Int32
|
||||
var streamedCallsMu sync.Mutex
|
||||
streamedCalls := make([]chattest.OpenAIRequest, 0, 2)
|
||||
|
||||
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
if !req.Stream {
|
||||
return chattest.OpenAINonStreamingResponse("Mixed tool test")
|
||||
}
|
||||
|
||||
streamedCallsMu.Lock()
|
||||
streamedCalls = append(streamedCalls, chattest.OpenAIRequest{
|
||||
Messages: append([]chattest.OpenAIMessage(nil), req.Messages...),
|
||||
Tools: append([]chattest.OpenAITool(nil), req.Tools...),
|
||||
Stream: req.Stream,
|
||||
})
|
||||
streamedCallsMu.Unlock()
|
||||
|
||||
if streamedCallCount.Add(1) == 1 {
|
||||
// First call: return TWO tool calls in one
|
||||
// response — a built-in tool (read_file) and a
|
||||
// dynamic tool (my_dynamic_tool).
|
||||
builtinChunk := chattest.OpenAIToolCallChunk(
|
||||
"read_file",
|
||||
`{"path":"/tmp/test.txt"}`,
|
||||
)
|
||||
dynamicChunk := chattest.OpenAIToolCallChunk(
|
||||
"my_dynamic_tool",
|
||||
`{"input":"hello world"}`,
|
||||
)
|
||||
// Merge both tool calls into one chunk with
|
||||
// separate indices so the LLM appears to have
|
||||
// requested both tools simultaneously.
|
||||
mergedChunk := builtinChunk
|
||||
dynCall := dynamicChunk.Choices[0].ToolCalls[0]
|
||||
dynCall.Index = 1
|
||||
mergedChunk.Choices[0].ToolCalls = append(
|
||||
mergedChunk.Choices[0].ToolCalls,
|
||||
dynCall,
|
||||
)
|
||||
return chattest.OpenAIStreamingResponse(mergedChunk)
|
||||
}
|
||||
// Second call (after tool results): normal text
|
||||
// response.
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
chattest.OpenAITextChunks("All done.")...,
|
||||
)
|
||||
})
|
||||
|
||||
user, model := seedChatDependenciesWithProvider(ctx, t, db, "openai-compat", openAIURL)
|
||||
server := newActiveTestServer(t, db, ps)
|
||||
|
||||
// Create a chat with a dynamic tool.
|
||||
dynamicToolsJSON, err := json.Marshal([]mcpgo.Tool{{
|
||||
Name: "my_dynamic_tool",
|
||||
Description: "A test dynamic tool.",
|
||||
InputSchema: mcpgo.ToolInputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{
|
||||
"input": map[string]any{"type": "string"},
|
||||
},
|
||||
Required: []string{"input"},
|
||||
},
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "mixed-builtin-dynamic",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText("Call both tools."),
|
||||
},
|
||||
DynamicTools: dynamicToolsJSON,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// 1. Wait for the chat to reach requires_action status.
|
||||
var chatResult database.Chat
|
||||
require.Eventually(t, func() bool {
|
||||
got, getErr := db.GetChatByID(ctx, chat.ID)
|
||||
if getErr != nil {
|
||||
return false
|
||||
}
|
||||
chatResult = got
|
||||
return got.Status == database.ChatStatusRequiresAction ||
|
||||
got.Status == database.ChatStatusError
|
||||
}, testutil.WaitLong, testutil.IntervalFast)
|
||||
|
||||
require.Equal(t, database.ChatStatusRequiresAction, chatResult.Status,
|
||||
"expected requires_action, got %s (last_error=%q)",
|
||||
chatResult.Status, chatResult.LastError.String)
|
||||
|
||||
// 2. Verify the built-in tool (read_file) was already
|
||||
// executed by checking that a tool result message
|
||||
// exists for it in the database.
|
||||
var builtinToolResultFound bool
|
||||
var toolCallID string
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
messages, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chat.ID,
|
||||
AfterID: 0,
|
||||
})
|
||||
if dbErr != nil {
|
||||
return false
|
||||
}
|
||||
for _, msg := range messages {
|
||||
parts, parseErr := chatprompt.ParseContent(msg)
|
||||
if parseErr != nil {
|
||||
continue
|
||||
}
|
||||
for _, part := range parts {
|
||||
// Check for the built-in tool result.
|
||||
if part.Type == codersdk.ChatMessagePartTypeToolResult && part.ToolName == "read_file" {
|
||||
builtinToolResultFound = true
|
||||
}
|
||||
// Find the dynamic tool call ID.
|
||||
if part.Type == codersdk.ChatMessagePartTypeToolCall && part.ToolName == "my_dynamic_tool" {
|
||||
toolCallID = part.ToolCallID
|
||||
}
|
||||
}
|
||||
}
|
||||
return builtinToolResultFound && toolCallID != ""
|
||||
}, testutil.IntervalFast)
|
||||
|
||||
require.True(t, builtinToolResultFound,
|
||||
"expected read_file tool result in the DB before dynamic tool resolution")
|
||||
require.NotEmpty(t, toolCallID)
|
||||
|
||||
// 3. Submit dynamic tool results.
|
||||
err = server.SubmitToolResults(ctx, chatd.SubmitToolResultsOptions{
|
||||
ChatID: chat.ID,
|
||||
UserID: user.ID,
|
||||
ModelConfigID: chatResult.LastModelConfigID,
|
||||
Results: []codersdk.ToolResult{{
|
||||
ToolCallID: toolCallID,
|
||||
Output: json.RawMessage(`{"result":"dynamic output"}`),
|
||||
}},
|
||||
DynamicTools: dynamicToolsJSON,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// 4. Wait for the chat to complete.
|
||||
require.Eventually(t, func() bool {
|
||||
got, getErr := db.GetChatByID(ctx, chat.ID)
|
||||
if getErr != nil {
|
||||
return false
|
||||
}
|
||||
chatResult = got
|
||||
return got.Status == database.ChatStatusWaiting || got.Status == database.ChatStatusError
|
||||
}, testutil.WaitLong, testutil.IntervalFast)
|
||||
|
||||
if chatResult.Status == database.ChatStatusError {
|
||||
require.FailNowf(t, "chat run failed", "last_error=%q", chatResult.LastError.String)
|
||||
}
|
||||
|
||||
// 5. Verify the LLM received exactly 2 streaming calls.
|
||||
require.Equal(t, int32(2), streamedCallCount.Load(),
|
||||
"expected exactly 2 streaming calls to the LLM")
|
||||
}
|
||||
|
||||
func TestSubmitToolResultsConcurrency(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
// The mock LLM returns a dynamic tool call on the first streaming
|
||||
// request, then a plain text reply on the second.
|
||||
var streamedCallCount atomic.Int32
|
||||
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
if !req.Stream {
|
||||
return chattest.OpenAINonStreamingResponse("Concurrency test")
|
||||
}
|
||||
if streamedCallCount.Add(1) == 1 {
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
chattest.OpenAIToolCallChunk(
|
||||
"my_dynamic_tool",
|
||||
`{"input":"hello"}`,
|
||||
),
|
||||
)
|
||||
}
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
chattest.OpenAITextChunks("Done.")...,
|
||||
)
|
||||
})
|
||||
|
||||
user, model := seedChatDependenciesWithProvider(ctx, t, db, "openai-compat", openAIURL)
|
||||
server := newActiveTestServer(t, db, ps)
|
||||
|
||||
// Create a chat with a dynamic tool.
|
||||
dynamicToolsJSON, err := json.Marshal([]mcpgo.Tool{{
|
||||
Name: "my_dynamic_tool",
|
||||
Description: "A test dynamic tool.",
|
||||
InputSchema: mcpgo.ToolInputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{
|
||||
"input": map[string]any{"type": "string"},
|
||||
},
|
||||
Required: []string{"input"},
|
||||
},
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err := server.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: user.ID,
|
||||
Title: "concurrency-tool-results",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText("Please call the dynamic tool."),
|
||||
},
|
||||
DynamicTools: dynamicToolsJSON,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for the chat to reach requires_action status.
|
||||
var chatResult database.Chat
|
||||
require.Eventually(t, func() bool {
|
||||
got, getErr := db.GetChatByID(ctx, chat.ID)
|
||||
if getErr != nil {
|
||||
return false
|
||||
}
|
||||
chatResult = got
|
||||
return got.Status == database.ChatStatusRequiresAction ||
|
||||
got.Status == database.ChatStatusError
|
||||
}, testutil.WaitLong, testutil.IntervalFast)
|
||||
require.Equal(t, database.ChatStatusRequiresAction, chatResult.Status,
|
||||
"expected requires_action, got %s (last_error=%q)",
|
||||
chatResult.Status, chatResult.LastError.String)
|
||||
|
||||
// Find the tool call ID from the assistant message.
|
||||
var toolCallID string
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
messages, dbErr := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chat.ID,
|
||||
AfterID: 0,
|
||||
})
|
||||
if dbErr != nil {
|
||||
return false
|
||||
}
|
||||
for _, msg := range messages {
|
||||
if msg.Role != database.ChatMessageRoleAssistant {
|
||||
continue
|
||||
}
|
||||
parts, parseErr := chatprompt.ParseContent(msg)
|
||||
if parseErr != nil {
|
||||
continue
|
||||
}
|
||||
for _, part := range parts {
|
||||
if part.Type == codersdk.ChatMessagePartTypeToolCall && part.ToolName == "my_dynamic_tool" {
|
||||
toolCallID = part.ToolCallID
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}, testutil.IntervalFast)
|
||||
require.NotEmpty(t, toolCallID)
|
||||
|
||||
// Spawn N goroutines that all try to submit tool results at the
|
||||
// same time. Exactly one should succeed; the rest must get a
|
||||
// ToolResultStatusConflictError.
|
||||
const numGoroutines = 10
|
||||
var (
|
||||
wg sync.WaitGroup
|
||||
ready = make(chan struct{})
|
||||
successes atomic.Int32
|
||||
conflicts atomic.Int32
|
||||
unexpectedErrors = make(chan error, numGoroutines)
|
||||
)
|
||||
|
||||
for range numGoroutines {
|
||||
wg.Go(func() {
|
||||
// Wait for all goroutines to be ready.
|
||||
<-ready
|
||||
|
||||
submitErr := server.SubmitToolResults(ctx, chatd.SubmitToolResultsOptions{
|
||||
ChatID: chat.ID,
|
||||
UserID: user.ID,
|
||||
ModelConfigID: chatResult.LastModelConfigID,
|
||||
Results: []codersdk.ToolResult{{
|
||||
ToolCallID: toolCallID,
|
||||
Output: json.RawMessage(`{"result":"concurrent output"}`),
|
||||
}},
|
||||
DynamicTools: dynamicToolsJSON,
|
||||
})
|
||||
|
||||
if submitErr == nil {
|
||||
successes.Add(1)
|
||||
return
|
||||
}
|
||||
var conflict *chatd.ToolResultStatusConflictError
|
||||
if errors.As(submitErr, &conflict) {
|
||||
conflicts.Add(1)
|
||||
return
|
||||
}
|
||||
// Collect unexpected errors for assertion
|
||||
// outside the goroutine (require.NoError
|
||||
// calls t.FailNow which is illegal here).
|
||||
unexpectedErrors <- submitErr
|
||||
})
|
||||
}
|
||||
// Release all goroutines at once.
|
||||
close(ready)
|
||||
|
||||
wg.Wait()
|
||||
close(unexpectedErrors)
|
||||
|
||||
for ue := range unexpectedErrors {
|
||||
require.NoError(t, ue, "unexpected error from SubmitToolResults")
|
||||
}
|
||||
|
||||
require.Equal(t, int32(1), successes.Load(),
|
||||
"expected exactly 1 goroutine to succeed")
|
||||
require.Equal(t, int32(numGoroutines-1), conflicts.Load(),
|
||||
"expected %d conflict errors", numGoroutines-1)
|
||||
}
|
||||
|
||||
func ptrRef[T any](v T) *T {
|
||||
return &v
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"maps"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -19,7 +18,6 @@ import (
|
||||
"charm.land/fantasy/schema"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chaterror"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatretry"
|
||||
@@ -40,23 +38,13 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInterrupted = xerrors.New("chat interrupted")
|
||||
ErrDynamicToolCall = xerrors.New("dynamic tool call")
|
||||
ErrInterrupted = xerrors.New("chat interrupted")
|
||||
|
||||
errStartupTimeout = xerrors.New(
|
||||
"chat response did not start before the startup timeout",
|
||||
)
|
||||
)
|
||||
|
||||
// PendingToolCall describes a tool call that targets a dynamic
|
||||
// tool. These calls are not executed by the chatloop; instead
|
||||
// they are persisted so the caller can fulfill them externally.
|
||||
type PendingToolCall struct {
|
||||
ToolCallID string
|
||||
ToolName string
|
||||
Args string
|
||||
}
|
||||
|
||||
// PersistedStep contains the full content of a completed or
|
||||
// interrupted agent step. Content includes both assistant blocks
|
||||
// (text, reasoning, tool calls) and tool result blocks. The
|
||||
@@ -72,21 +60,6 @@ type PersistedStep struct {
|
||||
// Zero indicates the duration was not measured (e.g.
|
||||
// interrupted steps).
|
||||
Runtime time.Duration
|
||||
// PendingDynamicToolCalls lists tool calls that target
|
||||
// dynamic tools. When non-empty the chatloop exits with
|
||||
// ErrDynamicToolCall so the caller can execute them
|
||||
// externally and resume the loop.
|
||||
PendingDynamicToolCalls []PendingToolCall
|
||||
// ToolCallCreatedAt maps tool-call IDs to the time
|
||||
// the model emitted each tool call. Applied by the
|
||||
// persistence layer to set CreatedAt on persisted
|
||||
// tool-call ChatMessageParts.
|
||||
ToolCallCreatedAt map[string]time.Time
|
||||
// ToolResultCreatedAt maps tool-call IDs to the time
|
||||
// each tool result was produced (or interrupted).
|
||||
// Applied by the persistence layer to set CreatedAt
|
||||
// on persisted tool-result ChatMessageParts.
|
||||
ToolResultCreatedAt map[string]time.Time
|
||||
}
|
||||
|
||||
// RunOptions configures a single streaming chat loop run.
|
||||
@@ -104,12 +77,6 @@ type RunOptions struct {
|
||||
ActiveTools []string
|
||||
ContextLimitFallback int64
|
||||
|
||||
// DynamicToolNames lists tool names that are handled
|
||||
// externally. When the model invokes one of these tools
|
||||
// the chatloop persists partial results and exits with
|
||||
// ErrDynamicToolCall instead of executing the tool.
|
||||
DynamicToolNames map[string]bool
|
||||
|
||||
// ModelConfig holds per-call LLM parameters (temperature,
|
||||
// max tokens, etc.) read from the chat model configuration.
|
||||
ModelConfig codersdk.ChatModelCallConfig
|
||||
@@ -161,14 +128,12 @@ type ProviderTool struct {
|
||||
// step. Since we own the stream consumer, all content is tracked
|
||||
// directly here — no shadow draft state needed.
|
||||
type stepResult struct {
|
||||
content []fantasy.Content
|
||||
usage fantasy.Usage
|
||||
providerMetadata fantasy.ProviderMetadata
|
||||
finishReason fantasy.FinishReason
|
||||
toolCalls []fantasy.ToolCallContent
|
||||
shouldContinue bool
|
||||
toolCallCreatedAt map[string]time.Time
|
||||
toolResultCreatedAt map[string]time.Time
|
||||
content []fantasy.Content
|
||||
usage fantasy.Usage
|
||||
providerMetadata fantasy.ProviderMetadata
|
||||
finishReason fantasy.FinishReason
|
||||
toolCalls []fantasy.ToolCallContent
|
||||
shouldContinue bool
|
||||
}
|
||||
|
||||
// toResponseMessages converts step content into messages suitable
|
||||
@@ -420,72 +385,16 @@ func Run(ctx context.Context, opts RunOptions) error {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// Partition tool calls into built-in and dynamic.
|
||||
var builtinCalls, dynamicCalls []fantasy.ToolCallContent
|
||||
if len(opts.DynamicToolNames) > 0 {
|
||||
for _, tc := range result.toolCalls {
|
||||
if opts.DynamicToolNames[tc.ToolName] {
|
||||
dynamicCalls = append(dynamicCalls, tc)
|
||||
} else {
|
||||
builtinCalls = append(builtinCalls, tc)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
builtinCalls = result.toolCalls
|
||||
}
|
||||
|
||||
// Execute only built-in tools.
|
||||
toolResults = executeTools(ctx, opts.Tools, opts.ProviderTools, builtinCalls, func(tr fantasy.ToolResultContent, completedAt time.Time) {
|
||||
recordToolResultTimestamp(&result, tr.ToolCallID, completedAt)
|
||||
ssePart := chatprompt.PartFromContent(tr)
|
||||
ssePart.CreatedAt = &completedAt
|
||||
publishMessagePart(codersdk.ChatMessageRoleTool, ssePart)
|
||||
toolResults = executeTools(ctx, opts.Tools, opts.ProviderTools, result.toolCalls, func(tr fantasy.ToolResultContent) {
|
||||
publishMessagePart(
|
||||
codersdk.ChatMessageRoleTool,
|
||||
chatprompt.PartFromContent(tr),
|
||||
)
|
||||
})
|
||||
for _, tr := range toolResults {
|
||||
result.content = append(result.content, tr)
|
||||
}
|
||||
|
||||
// If dynamic tools were called, persist what we
|
||||
// have (assistant + built-in results) and exit so
|
||||
// the caller can execute them externally.
|
||||
if len(dynamicCalls) > 0 {
|
||||
pending := make([]PendingToolCall, 0, len(dynamicCalls))
|
||||
for _, dc := range dynamicCalls {
|
||||
pending = append(pending, PendingToolCall{
|
||||
ToolCallID: dc.ToolCallID,
|
||||
ToolName: dc.ToolName,
|
||||
Args: dc.Input,
|
||||
})
|
||||
}
|
||||
|
||||
contextLimit := extractContextLimit(result.providerMetadata)
|
||||
if !contextLimit.Valid && opts.ContextLimitFallback > 0 {
|
||||
contextLimit = sql.NullInt64{
|
||||
Int64: opts.ContextLimitFallback,
|
||||
Valid: true,
|
||||
}
|
||||
}
|
||||
|
||||
if err := opts.PersistStep(ctx, PersistedStep{
|
||||
Content: result.content,
|
||||
Usage: result.usage,
|
||||
ContextLimit: contextLimit,
|
||||
ProviderResponseID: extractOpenAIResponseIDIfStored(opts.ProviderOptions, result.providerMetadata),
|
||||
Runtime: time.Since(stepStart),
|
||||
PendingDynamicToolCalls: pending,
|
||||
}); err != nil {
|
||||
if errors.Is(err, ErrInterrupted) {
|
||||
persistInterruptedStep(ctx, opts, &result)
|
||||
return ErrInterrupted
|
||||
}
|
||||
return xerrors.Errorf("persist step: %w", err)
|
||||
}
|
||||
|
||||
tryCompactOnExit(ctx, opts, result.usage, result.providerMetadata)
|
||||
|
||||
return ErrDynamicToolCall
|
||||
}
|
||||
|
||||
// Check for interruption after tool execution.
|
||||
// Tools that were canceled mid-flight produce error
|
||||
// results via ctx cancellation. Persist the full
|
||||
@@ -512,13 +421,11 @@ func Run(ctx context.Context, opts RunOptions) error {
|
||||
// check and here, fall back to the interrupt-safe
|
||||
// path so partial content is not lost.
|
||||
if err := opts.PersistStep(ctx, PersistedStep{
|
||||
Content: result.content,
|
||||
Usage: result.usage,
|
||||
ContextLimit: contextLimit,
|
||||
ProviderResponseID: extractOpenAIResponseIDIfStored(opts.ProviderOptions, result.providerMetadata),
|
||||
Runtime: time.Since(stepStart),
|
||||
ToolCallCreatedAt: result.toolCallCreatedAt,
|
||||
ToolResultCreatedAt: result.toolResultCreatedAt,
|
||||
Content: result.content,
|
||||
Usage: result.usage,
|
||||
ContextLimit: contextLimit,
|
||||
ProviderResponseID: extractOpenAIResponseIDIfStored(opts.ProviderOptions, result.providerMetadata),
|
||||
Runtime: time.Since(stepStart),
|
||||
}); err != nil {
|
||||
if errors.Is(err, ErrInterrupted) {
|
||||
persistInterruptedStep(ctx, opts, &result)
|
||||
@@ -851,20 +758,9 @@ func processStepStream(
|
||||
// Clean up active tool call tracking.
|
||||
delete(activeToolCalls, part.ID)
|
||||
|
||||
// Record when the model emitted this tool call
|
||||
// so the persisted part carries an accurate
|
||||
// timestamp for duration computation.
|
||||
now := dbtime.Now()
|
||||
if result.toolCallCreatedAt == nil {
|
||||
result.toolCallCreatedAt = make(map[string]time.Time)
|
||||
}
|
||||
result.toolCallCreatedAt[part.ID] = now
|
||||
|
||||
ssePart := chatprompt.PartFromContent(tc)
|
||||
ssePart.CreatedAt = &now
|
||||
publishMessagePart(
|
||||
codersdk.ChatMessageRoleAssistant,
|
||||
ssePart,
|
||||
chatprompt.PartFromContent(tc),
|
||||
)
|
||||
|
||||
case fantasy.StreamPartTypeSource:
|
||||
@@ -894,18 +790,9 @@ func processStepStream(
|
||||
ProviderMetadata: part.ProviderMetadata,
|
||||
}
|
||||
result.content = append(result.content, tr)
|
||||
|
||||
now := dbtime.Now()
|
||||
if result.toolResultCreatedAt == nil {
|
||||
result.toolResultCreatedAt = make(map[string]time.Time)
|
||||
}
|
||||
result.toolResultCreatedAt[part.ID] = now
|
||||
|
||||
ssePart := chatprompt.PartFromContent(tr)
|
||||
ssePart.CreatedAt = &now
|
||||
publishMessagePart(
|
||||
codersdk.ChatMessageRoleTool,
|
||||
ssePart,
|
||||
chatprompt.PartFromContent(tr),
|
||||
)
|
||||
}
|
||||
case fantasy.StreamPartTypeFinish:
|
||||
@@ -974,7 +861,7 @@ func executeTools(
|
||||
allTools []fantasy.AgentTool,
|
||||
providerTools []ProviderTool,
|
||||
toolCalls []fantasy.ToolCallContent,
|
||||
onResult func(fantasy.ToolResultContent, time.Time),
|
||||
onResult func(fantasy.ToolResultContent),
|
||||
) []fantasy.ToolResultContent {
|
||||
if len(toolCalls) == 0 {
|
||||
return nil
|
||||
@@ -1007,11 +894,10 @@ func executeTools(
|
||||
}
|
||||
|
||||
results := make([]fantasy.ToolResultContent, len(localToolCalls))
|
||||
completedAt := make([]time.Time, len(localToolCalls))
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(localToolCalls))
|
||||
for i, tc := range localToolCalls {
|
||||
go func() {
|
||||
go func(i int, tc fantasy.ToolCallContent) {
|
||||
defer wg.Done()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
@@ -1023,21 +909,17 @@ func executeTools(
|
||||
},
|
||||
}
|
||||
}
|
||||
// Record when this tool completed (or panicked).
|
||||
// Captured per-goroutine so parallel tools get
|
||||
// accurate individual completion times.
|
||||
completedAt[i] = dbtime.Now()
|
||||
}()
|
||||
results[i] = executeSingleTool(ctx, toolMap, tc)
|
||||
}()
|
||||
}(i, tc)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Publish results in the original tool-call order so SSE
|
||||
// subscribers see a deterministic event sequence.
|
||||
if onResult != nil {
|
||||
for i, tr := range results {
|
||||
onResult(tr, completedAt[i])
|
||||
for _, tr := range results {
|
||||
onResult(tr)
|
||||
}
|
||||
}
|
||||
return results
|
||||
@@ -1173,24 +1055,11 @@ func persistInterruptedStep(
|
||||
}
|
||||
}
|
||||
|
||||
// Copy existing timestamps and add result timestamps for
|
||||
// interrupted tool calls so the frontend can show partial
|
||||
// duration.
|
||||
toolCallCreatedAt := maps.Clone(result.toolCallCreatedAt)
|
||||
if toolCallCreatedAt == nil {
|
||||
toolCallCreatedAt = make(map[string]time.Time)
|
||||
}
|
||||
toolResultCreatedAt := maps.Clone(result.toolResultCreatedAt)
|
||||
if toolResultCreatedAt == nil {
|
||||
toolResultCreatedAt = make(map[string]time.Time)
|
||||
}
|
||||
|
||||
// Build combined content: all accumulated content + synthetic
|
||||
// interrupted results for any unanswered tool calls.
|
||||
content := make([]fantasy.Content, 0, len(result.content))
|
||||
content = append(content, result.content...)
|
||||
|
||||
interruptedAt := dbtime.Now()
|
||||
for _, tc := range result.toolCalls {
|
||||
if tc.ToolCallID == "" {
|
||||
continue
|
||||
@@ -1206,20 +1075,12 @@ func persistInterruptedStep(
|
||||
Error: xerrors.New(interruptedToolResultErrorMessage),
|
||||
},
|
||||
})
|
||||
// Only stamp synthetic results; don't clobber
|
||||
// timestamps from tools that completed before
|
||||
// the interruption arrived.
|
||||
if _, exists := toolResultCreatedAt[tc.ToolCallID]; !exists {
|
||||
toolResultCreatedAt[tc.ToolCallID] = interruptedAt
|
||||
}
|
||||
answeredToolCalls[tc.ToolCallID] = struct{}{}
|
||||
}
|
||||
|
||||
persistCtx := context.WithoutCancel(ctx)
|
||||
if err := opts.PersistStep(persistCtx, PersistedStep{
|
||||
Content: content,
|
||||
ToolCallCreatedAt: toolCallCreatedAt,
|
||||
ToolResultCreatedAt: toolResultCreatedAt,
|
||||
Content: content,
|
||||
}); err != nil {
|
||||
if opts.OnInterruptedPersistError != nil {
|
||||
opts.OnInterruptedPersistError(err)
|
||||
@@ -1227,38 +1088,6 @@ func persistInterruptedStep(
|
||||
}
|
||||
}
|
||||
|
||||
// tryCompactOnExit runs compaction when the chatloop is about
|
||||
// to exit early (e.g. via ErrDynamicToolCall). The normal
|
||||
// inline and post-run compaction paths are unreachable in
|
||||
// early-exit scenarios, so this ensures the context window
|
||||
// doesn't grow unbounded.
|
||||
func tryCompactOnExit(
|
||||
ctx context.Context,
|
||||
opts RunOptions,
|
||||
usage fantasy.Usage,
|
||||
metadata fantasy.ProviderMetadata,
|
||||
) {
|
||||
if opts.Compaction == nil || opts.ReloadMessages == nil {
|
||||
return
|
||||
}
|
||||
reloaded, err := opts.ReloadMessages(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, compactErr := tryCompact(
|
||||
ctx,
|
||||
opts.Model,
|
||||
opts.Compaction,
|
||||
opts.ContextLimitFallback,
|
||||
usage,
|
||||
metadata,
|
||||
reloaded,
|
||||
)
|
||||
if compactErr != nil && opts.Compaction.OnError != nil {
|
||||
opts.Compaction.OnError(compactErr)
|
||||
}
|
||||
}
|
||||
|
||||
// buildToolDefinitions converts AgentTool definitions into the
|
||||
// fantasy.Tool slice expected by fantasy.Call. When activeTools
|
||||
// is non-empty, only function tools whose name appears in the
|
||||
@@ -1410,16 +1239,6 @@ func isResponsesStoreEnabled(providerOptions fantasy.ProviderOptions) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// recordToolResultTimestamp lazily initializes the
|
||||
// toolResultCreatedAt map on the stepResult and records
|
||||
// the completion timestamp for the given tool-call ID.
|
||||
func recordToolResultTimestamp(result *stepResult, toolCallID string, ts time.Time) {
|
||||
if result.toolResultCreatedAt == nil {
|
||||
result.toolResultCreatedAt = make(map[string]time.Time)
|
||||
}
|
||||
result.toolResultCreatedAt[toolCallID] = ts
|
||||
}
|
||||
|
||||
func extractContextLimit(metadata fantasy.ProviderMetadata) sql.NullInt64 {
|
||||
if len(metadata) == 0 {
|
||||
return sql.NullInt64{}
|
||||
|
||||
@@ -86,54 +86,6 @@ func TestRun_ActiveToolsPrepareBehavior(t *testing.T) {
|
||||
require.True(t, hasAnthropicEphemeralCacheControl(capturedCall.Prompt[4]))
|
||||
}
|
||||
|
||||
func TestProcessStepStream_AnthropicUsageMatchesFinalDelta(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := &loopTestModel{
|
||||
provider: fantasyanthropic.Name,
|
||||
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
return streamFromParts([]fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
|
||||
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "cached response"},
|
||||
{Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"},
|
||||
{
|
||||
Type: fantasy.StreamPartTypeFinish,
|
||||
Usage: fantasy.Usage{
|
||||
InputTokens: 200,
|
||||
OutputTokens: 75,
|
||||
TotalTokens: 275,
|
||||
CacheCreationTokens: 30,
|
||||
CacheReadTokens: 150,
|
||||
ReasoningTokens: 0,
|
||||
},
|
||||
FinishReason: fantasy.FinishReasonStop,
|
||||
},
|
||||
}), nil
|
||||
},
|
||||
}
|
||||
|
||||
var persistedStep PersistedStep
|
||||
|
||||
err := Run(context.Background(), RunOptions{
|
||||
Model: model,
|
||||
Messages: []fantasy.Message{
|
||||
textMessage(fantasy.MessageRoleUser, "hello"),
|
||||
},
|
||||
MaxSteps: 1,
|
||||
ContextLimitFallback: 4096,
|
||||
PersistStep: func(_ context.Context, step PersistedStep) error {
|
||||
persistedStep = step
|
||||
return nil
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(200), persistedStep.Usage.InputTokens)
|
||||
require.Equal(t, int64(75), persistedStep.Usage.OutputTokens)
|
||||
require.Equal(t, int64(275), persistedStep.Usage.TotalTokens)
|
||||
require.Equal(t, int64(30), persistedStep.Usage.CacheCreationTokens)
|
||||
require.Equal(t, int64(150), persistedStep.Usage.CacheReadTokens)
|
||||
}
|
||||
|
||||
func TestRun_OnRetryEnrichesProvider(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -583,7 +535,6 @@ func TestRun_InterruptedStepPersistsSyntheticToolResult(t *testing.T) {
|
||||
|
||||
persistedAssistantCtxErr := xerrors.New("unset")
|
||||
var persistedContent []fantasy.Content
|
||||
var persistedStep PersistedStep
|
||||
|
||||
err := Run(ctx, RunOptions{
|
||||
Model: model,
|
||||
@@ -597,7 +548,6 @@ func TestRun_InterruptedStepPersistsSyntheticToolResult(t *testing.T) {
|
||||
PersistStep: func(persistCtx context.Context, step PersistedStep) error {
|
||||
persistedAssistantCtxErr = persistCtx.Err()
|
||||
persistedContent = append([]fantasy.Content(nil), step.Content...)
|
||||
persistedStep = step
|
||||
return nil
|
||||
},
|
||||
})
|
||||
@@ -637,14 +587,6 @@ func TestRun_InterruptedStepPersistsSyntheticToolResult(t *testing.T) {
|
||||
require.True(t, foundText)
|
||||
require.True(t, foundToolCall)
|
||||
require.True(t, foundToolResult)
|
||||
|
||||
// The interrupted tool was flushed mid-stream (never reached
|
||||
// StreamPartTypeToolCall), so it has no call timestamp.
|
||||
// But the synthetic error result must have a result timestamp.
|
||||
require.Contains(t, persistedStep.ToolResultCreatedAt, "interrupt-tool-1",
|
||||
"interrupted tool result must have a result timestamp")
|
||||
require.NotContains(t, persistedStep.ToolCallCreatedAt, "interrupt-tool-1",
|
||||
"interrupted tool should have no call timestamp (never reached StreamPartTypeToolCall)")
|
||||
}
|
||||
|
||||
type loopTestModel struct {
|
||||
@@ -785,7 +727,6 @@ func TestRun_MultiStepToolExecution(t *testing.T) {
|
||||
}
|
||||
|
||||
var persistStepCalls int
|
||||
var persistedSteps []PersistedStep
|
||||
err := Run(context.Background(), RunOptions{
|
||||
Model: model,
|
||||
Messages: []fantasy.Message{
|
||||
@@ -795,9 +736,8 @@ func TestRun_MultiStepToolExecution(t *testing.T) {
|
||||
newNoopTool("read_file"),
|
||||
},
|
||||
MaxSteps: 5,
|
||||
PersistStep: func(_ context.Context, step PersistedStep) error {
|
||||
PersistStep: func(_ context.Context, _ PersistedStep) error {
|
||||
persistStepCalls++
|
||||
persistedSteps = append(persistedSteps, step)
|
||||
return nil
|
||||
},
|
||||
})
|
||||
@@ -838,112 +778,6 @@ func TestRun_MultiStepToolExecution(t *testing.T) {
|
||||
}
|
||||
require.True(t, foundAssistantToolCall, "second call prompt should contain assistant tool call from step 0")
|
||||
require.True(t, foundToolResult, "second call prompt should contain tool result message")
|
||||
|
||||
// The first persisted step (tool-call step) must carry
|
||||
// accurate timestamps for duration computation.
|
||||
require.Len(t, persistedSteps, 2)
|
||||
toolStep := persistedSteps[0]
|
||||
require.Contains(t, toolStep.ToolCallCreatedAt, "tc-1",
|
||||
"tool-call step must record when the model emitted the call")
|
||||
require.Contains(t, toolStep.ToolResultCreatedAt, "tc-1",
|
||||
"tool-call step must record when the tool result was produced")
|
||||
require.False(t, toolStep.ToolResultCreatedAt["tc-1"].Before(toolStep.ToolCallCreatedAt["tc-1"]),
|
||||
"tool-result timestamp must be >= tool-call timestamp")
|
||||
}
|
||||
|
||||
func TestRun_ParallelToolExecutionTimestamps(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var mu sync.Mutex
|
||||
var streamCalls int
|
||||
|
||||
model := &loopTestModel{
|
||||
provider: "fake",
|
||||
streamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
mu.Lock()
|
||||
step := streamCalls
|
||||
streamCalls++
|
||||
mu.Unlock()
|
||||
|
||||
_ = call
|
||||
|
||||
switch step {
|
||||
case 0:
|
||||
// Step 0: produce two tool calls in one stream.
|
||||
return streamFromParts([]fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-1", ToolCallName: "read_file"},
|
||||
{Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-1", Delta: `{"path":"a.go"}`},
|
||||
{Type: fantasy.StreamPartTypeToolInputEnd, ID: "tc-1"},
|
||||
{
|
||||
Type: fantasy.StreamPartTypeToolCall,
|
||||
ID: "tc-1",
|
||||
ToolCallName: "read_file",
|
||||
ToolCallInput: `{"path":"a.go"}`,
|
||||
},
|
||||
{Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-2", ToolCallName: "write_file"},
|
||||
{Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-2", Delta: `{"path":"b.go"}`},
|
||||
{Type: fantasy.StreamPartTypeToolInputEnd, ID: "tc-2"},
|
||||
{
|
||||
Type: fantasy.StreamPartTypeToolCall,
|
||||
ID: "tc-2",
|
||||
ToolCallName: "write_file",
|
||||
ToolCallInput: `{"path":"b.go"}`,
|
||||
},
|
||||
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonToolCalls},
|
||||
}), nil
|
||||
default:
|
||||
// Step 1: return plain text.
|
||||
return streamFromParts([]fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
|
||||
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "all done"},
|
||||
{Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"},
|
||||
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop},
|
||||
}), nil
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
var persistedSteps []PersistedStep
|
||||
err := Run(context.Background(), RunOptions{
|
||||
Model: model,
|
||||
Messages: []fantasy.Message{
|
||||
textMessage(fantasy.MessageRoleUser, "do both"),
|
||||
},
|
||||
Tools: []fantasy.AgentTool{
|
||||
newNoopTool("read_file"),
|
||||
newNoopTool("write_file"),
|
||||
},
|
||||
MaxSteps: 5,
|
||||
PersistStep: func(_ context.Context, step PersistedStep) error {
|
||||
persistedSteps = append(persistedSteps, step)
|
||||
return nil
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Two steps: tool-call step + text step.
|
||||
require.Equal(t, 2, streamCalls)
|
||||
require.Len(t, persistedSteps, 2)
|
||||
|
||||
toolStep := persistedSteps[0]
|
||||
|
||||
// Both tool-call IDs must appear in ToolCallCreatedAt.
|
||||
require.Contains(t, toolStep.ToolCallCreatedAt, "tc-1",
|
||||
"tool-call step must record when tc-1 was emitted")
|
||||
require.Contains(t, toolStep.ToolCallCreatedAt, "tc-2",
|
||||
"tool-call step must record when tc-2 was emitted")
|
||||
|
||||
// Both tool-call IDs must appear in ToolResultCreatedAt.
|
||||
require.Contains(t, toolStep.ToolResultCreatedAt, "tc-1",
|
||||
"tool-call step must record when tc-1 result was produced")
|
||||
require.Contains(t, toolStep.ToolResultCreatedAt, "tc-2",
|
||||
"tool-call step must record when tc-2 result was produced")
|
||||
|
||||
// Result timestamps must be >= call timestamps for both.
|
||||
require.False(t, toolStep.ToolResultCreatedAt["tc-1"].Before(toolStep.ToolCallCreatedAt["tc-1"]),
|
||||
"tc-1 tool-result timestamp must be >= tool-call timestamp")
|
||||
require.False(t, toolStep.ToolResultCreatedAt["tc-2"].Before(toolStep.ToolCallCreatedAt["tc-2"]),
|
||||
"tc-2 tool-result timestamp must be >= tool-call timestamp")
|
||||
}
|
||||
|
||||
func TestRun_PersistStepErrorPropagates(t *testing.T) {
|
||||
@@ -1349,77 +1183,6 @@ func TestRun_InterruptedDuringToolExecutionPersistsStep(t *testing.T) {
|
||||
require.True(t, foundToolResult, "persisted content should include the tool result (error from cancellation)")
|
||||
}
|
||||
|
||||
// TestRun_ProviderExecutedToolResultTimestamps verifies that
|
||||
// provider-executed tool results (e.g. web search) have their
|
||||
// timestamps recorded in PersistedStep.ToolResultCreatedAt so
|
||||
// the persistence layer can stamp CreatedAt on the parts.
|
||||
func TestRun_ProviderExecutedToolResultTimestamps(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := &loopTestModel{
|
||||
provider: "fake",
|
||||
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
// Simulate a provider-executed tool call and result
|
||||
// (e.g. Anthropic web search) followed by a text
|
||||
// response — all in a single stream.
|
||||
return streamFromParts([]fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeToolInputStart, ID: "ws-1", ToolCallName: "web_search", ProviderExecuted: true},
|
||||
{Type: fantasy.StreamPartTypeToolInputDelta, ID: "ws-1", Delta: `{"query":"coder"}`, ProviderExecuted: true},
|
||||
{Type: fantasy.StreamPartTypeToolInputEnd, ID: "ws-1"},
|
||||
{
|
||||
Type: fantasy.StreamPartTypeToolCall,
|
||||
ID: "ws-1",
|
||||
ToolCallName: "web_search",
|
||||
ToolCallInput: `{"query":"coder"}`,
|
||||
ProviderExecuted: true,
|
||||
},
|
||||
// Provider-executed tool result — emitted by
|
||||
// the provider, not our tool runner.
|
||||
{
|
||||
Type: fantasy.StreamPartTypeToolResult,
|
||||
ID: "ws-1",
|
||||
ToolCallName: "web_search",
|
||||
ProviderExecuted: true,
|
||||
},
|
||||
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
|
||||
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "search done"},
|
||||
{Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"},
|
||||
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop},
|
||||
}), nil
|
||||
},
|
||||
}
|
||||
|
||||
var persistedSteps []PersistedStep
|
||||
err := Run(context.Background(), RunOptions{
|
||||
Model: model,
|
||||
Messages: []fantasy.Message{
|
||||
textMessage(fantasy.MessageRoleUser, "search for coder"),
|
||||
},
|
||||
MaxSteps: 1,
|
||||
PersistStep: func(_ context.Context, step PersistedStep) error {
|
||||
persistedSteps = append(persistedSteps, step)
|
||||
return nil
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, persistedSteps, 1)
|
||||
|
||||
step := persistedSteps[0]
|
||||
|
||||
// Provider-executed tool call should have a call timestamp.
|
||||
require.Contains(t, step.ToolCallCreatedAt, "ws-1",
|
||||
"provider-executed tool call must record its timestamp")
|
||||
|
||||
// Provider-executed tool result should have a result
|
||||
// timestamp so the frontend can compute duration.
|
||||
require.Contains(t, step.ToolResultCreatedAt, "ws-1",
|
||||
"provider-executed tool result must record its timestamp")
|
||||
|
||||
require.False(t,
|
||||
step.ToolResultCreatedAt["ws-1"].Before(step.ToolCallCreatedAt["ws-1"]),
|
||||
"tool-result timestamp must be >= tool-call timestamp")
|
||||
}
|
||||
|
||||
// TestRun_PersistStepInterruptedFallback verifies that when the normal
|
||||
// PersistStep call returns ErrInterrupted (e.g., context canceled in a
|
||||
// race), the step is retried via the interrupt-safe path.
|
||||
|
||||
@@ -713,76 +713,4 @@ func TestRun_Compaction(t *testing.T) {
|
||||
}
|
||||
require.True(t, hasUser, "re-entry prompt must contain a user message (the compaction summary)")
|
||||
})
|
||||
|
||||
t.Run("TriggersOnDynamicToolExit", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var persistCompactionCalls int
|
||||
const summaryText = "compaction summary for dynamic tool exit"
|
||||
|
||||
// The LLM calls a dynamic tool. Usage is above the
|
||||
// compaction threshold so compaction should fire even
|
||||
// though the chatloop exits via ErrDynamicToolCall.
|
||||
model := &loopTestModel{
|
||||
provider: "fake",
|
||||
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
return streamFromParts([]fantasy.StreamPart{
|
||||
{Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-1", ToolCallName: "my_dynamic_tool"},
|
||||
{Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-1", Delta: `{"query": "test"}`},
|
||||
{Type: fantasy.StreamPartTypeToolInputEnd, ID: "tc-1"},
|
||||
{
|
||||
Type: fantasy.StreamPartTypeToolCall,
|
||||
ID: "tc-1",
|
||||
ToolCallName: "my_dynamic_tool",
|
||||
ToolCallInput: `{"query": "test"}`,
|
||||
},
|
||||
{
|
||||
Type: fantasy.StreamPartTypeFinish,
|
||||
FinishReason: fantasy.FinishReasonToolCalls,
|
||||
Usage: fantasy.Usage{
|
||||
InputTokens: 80,
|
||||
TotalTokens: 85,
|
||||
},
|
||||
},
|
||||
}), nil
|
||||
},
|
||||
generateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
|
||||
return &fantasy.Response{
|
||||
Content: []fantasy.Content{
|
||||
fantasy.TextContent{Text: summaryText},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
err := Run(context.Background(), RunOptions{
|
||||
Model: model,
|
||||
Messages: []fantasy.Message{
|
||||
textMessage(fantasy.MessageRoleUser, "hello"),
|
||||
},
|
||||
MaxSteps: 5,
|
||||
DynamicToolNames: map[string]bool{"my_dynamic_tool": true},
|
||||
PersistStep: func(_ context.Context, _ PersistedStep) error {
|
||||
return nil
|
||||
},
|
||||
ContextLimitFallback: 100,
|
||||
Compaction: &CompactionOptions{
|
||||
ThresholdPercent: 70,
|
||||
SummaryPrompt: "summarize now",
|
||||
Persist: func(_ context.Context, result CompactionResult) error {
|
||||
persistCompactionCalls++
|
||||
require.Contains(t, result.SystemSummary, summaryText)
|
||||
return nil
|
||||
},
|
||||
},
|
||||
ReloadMessages: func(_ context.Context) ([]fantasy.Message, error) {
|
||||
return []fantasy.Message{
|
||||
textMessage(fantasy.MessageRoleUser, "hello"),
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
require.ErrorIs(t, err, ErrDynamicToolCall)
|
||||
require.Equal(t, 1, persistCompactionCalls,
|
||||
"compaction must fire before dynamic tool exit")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2329,48 +2329,3 @@ func TestMediaToolResultRoundTrip(t *testing.T) {
|
||||
require.True(t, isText, "expected ToolResultOutputContentText")
|
||||
})
|
||||
}
|
||||
|
||||
func TestPartFromContent_CreatedAtNotStamped(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// PartFromContent must NOT stamp CreatedAt itself.
|
||||
// The chatloop layer records timestamps separately and
|
||||
// the persistence layer applies them. PartFromContent
|
||||
// is called in multiple contexts (SSE publishing,
|
||||
// persistence) so stamping inside it would produce
|
||||
// inaccurate durations.
|
||||
|
||||
t.Run("ToolCallHasNilCreatedAt", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
part := chatprompt.PartFromContent(fantasy.ToolCallContent{
|
||||
ToolCallID: "tc-1",
|
||||
ToolName: "execute",
|
||||
})
|
||||
assert.Nil(t, part.CreatedAt)
|
||||
})
|
||||
|
||||
t.Run("ToolCallPointerHasNilCreatedAt", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
part := chatprompt.PartFromContent(&fantasy.ToolCallContent{
|
||||
ToolCallID: "tc-1",
|
||||
ToolName: "execute",
|
||||
})
|
||||
assert.Nil(t, part.CreatedAt)
|
||||
})
|
||||
|
||||
t.Run("ToolResultHasNilCreatedAt", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
part := chatprompt.PartFromContent(fantasy.ToolResultContent{
|
||||
ToolCallID: "tc-1",
|
||||
ToolName: "execute",
|
||||
Result: fantasy.ToolResultOutputContentText{Text: "{}"},
|
||||
})
|
||||
assert.Nil(t, part.CreatedAt)
|
||||
})
|
||||
|
||||
t.Run("TextHasNilCreatedAt", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
part := chatprompt.PartFromContent(fantasy.TextContent{Text: "hello"})
|
||||
assert.Nil(t, part.CreatedAt)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -53,10 +53,8 @@ type AnthropicMessage struct {
|
||||
|
||||
// AnthropicUsage represents usage information in an Anthropic response.
|
||||
type AnthropicUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"`
|
||||
CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"`
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
}
|
||||
|
||||
// AnthropicChunk represents a streaming chunk from Anthropic.
|
||||
@@ -69,16 +67,14 @@ type AnthropicChunk struct {
|
||||
StopReason string `json:"stop_reason,omitempty"`
|
||||
StopSequence *string `json:"stop_sequence,omitempty"`
|
||||
Usage AnthropicUsage `json:"usage,omitempty"`
|
||||
UsageMap map[string]int `json:"-"`
|
||||
}
|
||||
|
||||
// AnthropicChunkMessage represents message metadata in a chunk.
|
||||
type AnthropicChunkMessage struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Role string `json:"role"`
|
||||
Model string `json:"model"`
|
||||
Usage map[string]int `json:"usage,omitempty"`
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Role string `json:"role"`
|
||||
Model string `json:"model"`
|
||||
}
|
||||
|
||||
// AnthropicContentBlock represents a content block in a chunk.
|
||||
@@ -210,11 +206,7 @@ func (s *anthropicServer) writeStreamingResponse(w http.ResponseWriter, chunks <
|
||||
"stop_reason": chunk.StopReason,
|
||||
"stop_sequence": chunk.StopSequence,
|
||||
}
|
||||
if chunk.UsageMap != nil {
|
||||
chunkData["usage"] = chunk.UsageMap
|
||||
} else {
|
||||
chunkData["usage"] = chunk.Usage
|
||||
}
|
||||
chunkData["usage"] = chunk.Usage
|
||||
case "message_stop":
|
||||
// No additional fields
|
||||
}
|
||||
@@ -350,80 +342,6 @@ func AnthropicTextChunks(deltas ...string) []AnthropicChunk {
|
||||
return chunks
|
||||
}
|
||||
|
||||
// AnthropicTextChunksWithCacheUsage creates a streaming response with text
|
||||
// deltas and explicit cache token usage. The message_start event carries
|
||||
// the initial input and cache token counts, and the final message_delta
|
||||
// carries the output token count.
|
||||
func AnthropicTextChunksWithCacheUsage(usage AnthropicUsage, deltas ...string) []AnthropicChunk {
|
||||
if len(deltas) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
messageID := fmt.Sprintf("msg-%s", uuid.New().String()[:8])
|
||||
model := "claude-3-opus-20240229"
|
||||
|
||||
messageUsage := map[string]int{
|
||||
"input_tokens": usage.InputTokens,
|
||||
}
|
||||
if usage.CacheCreationInputTokens != 0 {
|
||||
messageUsage["cache_creation_input_tokens"] = usage.CacheCreationInputTokens
|
||||
}
|
||||
if usage.CacheReadInputTokens != 0 {
|
||||
messageUsage["cache_read_input_tokens"] = usage.CacheReadInputTokens
|
||||
}
|
||||
|
||||
chunks := []AnthropicChunk{
|
||||
{
|
||||
Type: "message_start",
|
||||
Message: AnthropicChunkMessage{
|
||||
ID: messageID,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: model,
|
||||
Usage: messageUsage,
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "content_block_start",
|
||||
Index: 0,
|
||||
ContentBlock: AnthropicContentBlock{
|
||||
Type: "text",
|
||||
Text: "",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, delta := range deltas {
|
||||
chunks = append(chunks, AnthropicChunk{
|
||||
Type: "content_block_delta",
|
||||
Index: 0,
|
||||
Delta: AnthropicDeltaBlock{
|
||||
Type: "text_delta",
|
||||
Text: delta,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
chunks = append(chunks,
|
||||
AnthropicChunk{
|
||||
Type: "content_block_stop",
|
||||
Index: 0,
|
||||
},
|
||||
AnthropicChunk{
|
||||
Type: "message_delta",
|
||||
StopReason: "end_turn",
|
||||
UsageMap: map[string]int{
|
||||
"output_tokens": usage.OutputTokens,
|
||||
},
|
||||
},
|
||||
AnthropicChunk{
|
||||
Type: "message_stop",
|
||||
},
|
||||
)
|
||||
|
||||
return chunks
|
||||
}
|
||||
|
||||
// AnthropicToolCallChunks creates a complete streaming response for a tool call.
|
||||
// Input JSON can be split across multiple deltas, matching Anthropic's
|
||||
// input_json_delta streaming behavior.
|
||||
|
||||
@@ -63,59 +63,6 @@ func TestAnthropic_Streaming(t *testing.T) {
|
||||
require.Equal(t, len(expectedDeltas), deltaIndex, "Expected %d deltas, got %d. Total parts received: %d", len(expectedDeltas), deltaIndex, len(allParts))
|
||||
}
|
||||
|
||||
func TestAnthropic_StreamingUsageIncludesCacheTokens(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serverURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse {
|
||||
return chattest.AnthropicStreamingResponse(
|
||||
chattest.AnthropicTextChunksWithCacheUsage(chattest.AnthropicUsage{
|
||||
InputTokens: 200,
|
||||
OutputTokens: 75,
|
||||
CacheCreationInputTokens: 30,
|
||||
CacheReadInputTokens: 150,
|
||||
}, "cached", " response")...,
|
||||
)
|
||||
})
|
||||
|
||||
client, err := fantasyanthropic.New(
|
||||
fantasyanthropic.WithAPIKey("test-key"),
|
||||
fantasyanthropic.WithBaseURL(serverURL),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
model, err := client.LanguageModel(context.Background(), "claude-3-opus-20240229")
|
||||
require.NoError(t, err)
|
||||
|
||||
stream, err := model.Stream(context.Background(), fantasy.Call{
|
||||
Prompt: []fantasy.Message{
|
||||
{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hello"}},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var (
|
||||
finishPart fantasy.StreamPart
|
||||
found bool
|
||||
)
|
||||
for part := range stream {
|
||||
if part.Type != fantasy.StreamPartTypeFinish {
|
||||
continue
|
||||
}
|
||||
finishPart = part
|
||||
found = true
|
||||
}
|
||||
|
||||
require.True(t, found)
|
||||
require.Equal(t, int64(200), finishPart.Usage.InputTokens)
|
||||
require.Equal(t, int64(75), finishPart.Usage.OutputTokens)
|
||||
require.Equal(t, int64(275), finishPart.Usage.TotalTokens)
|
||||
require.Equal(t, int64(30), finishPart.Usage.CacheCreationTokens)
|
||||
require.Equal(t, int64(150), finishPart.Usage.CacheReadTokens)
|
||||
}
|
||||
|
||||
func TestAnthropic_ToolCalls(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -1,91 +0,0 @@
|
||||
package chatd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"charm.land/fantasy"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// dynamicTool wraps a codersdk.DynamicTool as a fantasy.AgentTool.
|
||||
// These tools are presented to the LLM but never executed by the
|
||||
// chatloop — when the LLM calls one, the chatloop exits with
|
||||
// requires_action status and the client handles execution.
|
||||
// The Run method should never be called; it returns an error if
|
||||
// it is, as a safety net.
|
||||
type dynamicTool struct {
|
||||
name string
|
||||
description string
|
||||
parameters map[string]any
|
||||
required []string
|
||||
opts fantasy.ProviderOptions
|
||||
}
|
||||
|
||||
// dynamicToolsFromSDK converts codersdk.DynamicTool definitions
|
||||
// into fantasy.AgentTool implementations for inclusion in the LLM
|
||||
// tool list.
|
||||
func dynamicToolsFromSDK(logger slog.Logger, tools []codersdk.DynamicTool) []fantasy.AgentTool {
|
||||
if len(tools) == 0 {
|
||||
return nil
|
||||
}
|
||||
result := make([]fantasy.AgentTool, 0, len(tools))
|
||||
for _, t := range tools {
|
||||
dt := &dynamicTool{
|
||||
name: t.Name,
|
||||
description: t.Description,
|
||||
}
|
||||
// InputSchema is a full JSON Schema object stored as
|
||||
// json.RawMessage. Extract the "properties" and
|
||||
// "required" fields that fantasy.ToolInfo expects.
|
||||
if len(t.InputSchema) > 0 {
|
||||
var schema struct {
|
||||
Properties map[string]any `json:"properties"`
|
||||
Required []string `json:"required"`
|
||||
}
|
||||
if err := json.Unmarshal(t.InputSchema, &schema); err != nil {
|
||||
// Defensive: present the tool with no parameter
|
||||
// constraints rather than failing. The LLM may
|
||||
// hallucinate argument shapes, but the tool will
|
||||
// still appear in the tool list.
|
||||
logger.Warn(context.Background(), "failed to parse dynamic tool input schema",
|
||||
slog.F("tool_name", t.Name),
|
||||
slog.Error(err))
|
||||
} else {
|
||||
dt.parameters = schema.Properties
|
||||
dt.required = schema.Required
|
||||
}
|
||||
}
|
||||
result = append(result, dt)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (t *dynamicTool) Info() fantasy.ToolInfo {
|
||||
return fantasy.ToolInfo{
|
||||
Name: t.name,
|
||||
Description: t.description,
|
||||
Parameters: t.parameters,
|
||||
Required: t.required,
|
||||
}
|
||||
}
|
||||
|
||||
func (*dynamicTool) Run(_ context.Context, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
// Dynamic tools are never executed by the chatloop. If this
|
||||
// method is called, it indicates a bug in the chatloop's
|
||||
// dynamic tool detection logic.
|
||||
return fantasy.NewTextErrorResponse(
|
||||
"dynamic tool called in chatloop — this is a bug; " +
|
||||
"dynamic tools should be handled by the client",
|
||||
), nil
|
||||
}
|
||||
|
||||
func (t *dynamicTool) ProviderOptions() fantasy.ProviderOptions {
|
||||
return t.opts
|
||||
}
|
||||
|
||||
func (t *dynamicTool) SetProviderOptions(opts fantasy.ProviderOptions) {
|
||||
t.opts = opts
|
||||
}
|
||||
@@ -1,114 +0,0 @@
|
||||
package chatd
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
func TestDynamicToolsFromSDK(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("EmptySlice", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
result := dynamicToolsFromSDK(logger, nil)
|
||||
require.Nil(t, result)
|
||||
})
|
||||
|
||||
t.Run("ValidToolWithSchema", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
tools := []codersdk.DynamicTool{
|
||||
{
|
||||
Name: "my_tool",
|
||||
Description: "A useful tool",
|
||||
InputSchema: json.RawMessage(`{"type":"object","properties":{"input":{"type":"string"}},"required":["input"]}`),
|
||||
},
|
||||
}
|
||||
result := dynamicToolsFromSDK(logger, tools)
|
||||
require.Len(t, result, 1)
|
||||
|
||||
info := result[0].Info()
|
||||
require.Equal(t, "my_tool", info.Name)
|
||||
require.Equal(t, "A useful tool", info.Description)
|
||||
require.NotNil(t, info.Parameters)
|
||||
require.Contains(t, info.Parameters, "input")
|
||||
require.Equal(t, []string{"input"}, info.Required)
|
||||
})
|
||||
|
||||
t.Run("ToolWithoutSchema", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
tools := []codersdk.DynamicTool{
|
||||
{
|
||||
Name: "no_schema",
|
||||
Description: "Tool with no schema",
|
||||
},
|
||||
}
|
||||
result := dynamicToolsFromSDK(logger, tools)
|
||||
require.Len(t, result, 1)
|
||||
|
||||
info := result[0].Info()
|
||||
require.Equal(t, "no_schema", info.Name)
|
||||
require.Nil(t, info.Parameters)
|
||||
require.Nil(t, info.Required)
|
||||
})
|
||||
|
||||
t.Run("MalformedSchema", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
tools := []codersdk.DynamicTool{
|
||||
{
|
||||
Name: "bad_schema",
|
||||
Description: "Tool with malformed schema",
|
||||
InputSchema: json.RawMessage("not-json"),
|
||||
},
|
||||
}
|
||||
result := dynamicToolsFromSDK(logger, tools)
|
||||
require.Len(t, result, 1)
|
||||
|
||||
info := result[0].Info()
|
||||
require.Equal(t, "bad_schema", info.Name)
|
||||
require.Nil(t, info.Parameters)
|
||||
require.Nil(t, info.Required)
|
||||
})
|
||||
|
||||
t.Run("MultipleTools", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
tools := []codersdk.DynamicTool{
|
||||
{Name: "first", Description: "First tool"},
|
||||
{Name: "second", Description: "Second tool"},
|
||||
{Name: "third", Description: "Third tool"},
|
||||
}
|
||||
result := dynamicToolsFromSDK(logger, tools)
|
||||
require.Len(t, result, 3)
|
||||
require.Equal(t, "first", result[0].Info().Name)
|
||||
require.Equal(t, "second", result[1].Info().Name)
|
||||
require.Equal(t, "third", result[2].Info().Name)
|
||||
})
|
||||
|
||||
t.Run("SchemaWithoutProperties", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
tools := []codersdk.DynamicTool{
|
||||
{
|
||||
Name: "bare_schema",
|
||||
Description: "Schema with no properties",
|
||||
InputSchema: json.RawMessage(`{"type":"object"}`),
|
||||
},
|
||||
}
|
||||
result := dynamicToolsFromSDK(logger, tools)
|
||||
require.Len(t, result, 1)
|
||||
|
||||
info := result[0].Info()
|
||||
require.Equal(t, "bare_schema", info.Name)
|
||||
require.Nil(t, info.Parameters)
|
||||
require.Nil(t, info.Required)
|
||||
})
|
||||
}
|
||||
@@ -180,7 +180,7 @@ func generateTitle(
|
||||
model fantasy.LanguageModel,
|
||||
input string,
|
||||
) (string, error) {
|
||||
title, err := generateStructuredTitle(ctx, model, titleGenerationPrompt, input)
|
||||
title, _, err := generateStructuredTitle(ctx, model, titleGenerationPrompt, input)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -192,24 +192,6 @@ func generateStructuredTitle(
|
||||
model fantasy.LanguageModel,
|
||||
systemPrompt string,
|
||||
userInput string,
|
||||
) (string, error) {
|
||||
title, _, err := generateStructuredTitleWithUsage(
|
||||
ctx,
|
||||
model,
|
||||
systemPrompt,
|
||||
userInput,
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return title, nil
|
||||
}
|
||||
|
||||
func generateStructuredTitleWithUsage(
|
||||
ctx context.Context,
|
||||
model fantasy.LanguageModel,
|
||||
systemPrompt string,
|
||||
userInput string,
|
||||
) (string, fantasy.Usage, error) {
|
||||
userInput = strings.TrimSpace(userInput)
|
||||
if userInput == "" {
|
||||
@@ -244,6 +226,8 @@ func generateStructuredTitleWithUsage(
|
||||
return genErr
|
||||
}, nil)
|
||||
if err != nil {
|
||||
// Extract usage from the error when available so that
|
||||
// failed attempts are still accounted for in usage tracking.
|
||||
var usage fantasy.Usage
|
||||
var noObjErr *fantasy.NoObjectGeneratedError
|
||||
if errors.As(err, &noObjErr) {
|
||||
@@ -545,7 +529,7 @@ func generateManualTitle(
|
||||
userInput = strings.TrimSpace(firstUserText)
|
||||
}
|
||||
|
||||
title, usage, err := generateStructuredTitleWithUsage(
|
||||
title, usage, err := generateStructuredTitle(
|
||||
titleCtx,
|
||||
fallbackModel,
|
||||
systemPrompt,
|
||||
@@ -595,7 +579,7 @@ func generatePushSummary(
|
||||
candidates = append(candidates, fallbackModel)
|
||||
|
||||
for _, model := range candidates {
|
||||
summary, err := generateShortText(summaryCtx, model, pushSummaryPrompt, input)
|
||||
summary, _, err := generateShortText(summaryCtx, model, pushSummaryPrompt, input)
|
||||
if err != nil {
|
||||
logger.Debug(ctx, "push summary model candidate failed",
|
||||
slog.Error(err),
|
||||
@@ -617,7 +601,7 @@ func generateShortText(
|
||||
model fantasy.LanguageModel,
|
||||
systemPrompt string,
|
||||
userInput string,
|
||||
) (string, error) {
|
||||
) (string, fantasy.Usage, error) {
|
||||
prompt := []fantasy.Message{
|
||||
{
|
||||
Role: fantasy.MessageRoleSystem,
|
||||
@@ -645,7 +629,7 @@ func generateShortText(
|
||||
return genErr
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("generate short text: %w", err)
|
||||
return "", fantasy.Usage{}, xerrors.Errorf("generate short text: %w", err)
|
||||
}
|
||||
|
||||
responseParts := make([]codersdk.ChatMessagePart, 0, len(response.Content))
|
||||
@@ -655,5 +639,5 @@ func generateShortText(
|
||||
}
|
||||
}
|
||||
text := normalizeShortTextOutput(contentBlocksToText(responseParts))
|
||||
return text, nil
|
||||
return text, response.Usage, nil
|
||||
}
|
||||
|
||||
@@ -515,9 +515,12 @@ func Test_generateShortText_NormalizesQuotedOutput(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
text, err := generateShortText(context.Background(), model, "system", "user")
|
||||
text, usage, err := generateShortText(context.Background(), model, "system", "user")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Quoted summary", text)
|
||||
require.Equal(t, int64(3), usage.InputTokens)
|
||||
require.Equal(t, int64(2), usage.OutputTokens)
|
||||
require.Equal(t, int64(5), usage.TotalTokens)
|
||||
}
|
||||
|
||||
type stubModel struct {
|
||||
|
||||
@@ -376,7 +376,7 @@ func ProtoFromLog(log Log) (*proto.Log, error) {
|
||||
}
|
||||
return &proto.Log{
|
||||
CreatedAt: timestamppb.New(log.CreatedAt),
|
||||
Output: SanitizeLogOutput(log.Output),
|
||||
Output: strings.ToValidUTF8(log.Output, "❌"),
|
||||
Level: proto.Log_Level(lvl),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -229,7 +229,7 @@ func TestLogSender_SkipHugeLog(t *testing.T) {
|
||||
require.ErrorIs(t, err, context.Canceled)
|
||||
}
|
||||
|
||||
func TestLogSender_SanitizeOutput(t *testing.T) {
|
||||
func TestLogSender_InvalidUTF8(t *testing.T) {
|
||||
t.Parallel()
|
||||
testCtx := testutil.Context(t, testutil.WaitShort)
|
||||
ctx, cancel := context.WithCancel(testCtx)
|
||||
@@ -243,7 +243,7 @@ func TestLogSender_SanitizeOutput(t *testing.T) {
|
||||
uut.Enqueue(ls1,
|
||||
Log{
|
||||
CreatedAt: t0,
|
||||
Output: "test log 0, src 1\x00\xc3\x28",
|
||||
Output: "test log 0, src 1\xc3\x28",
|
||||
Level: codersdk.LogLevelInfo,
|
||||
},
|
||||
Log{
|
||||
@@ -260,10 +260,10 @@ func TestLogSender_SanitizeOutput(t *testing.T) {
|
||||
|
||||
req := testutil.TryReceive(ctx, t, fDest.reqs)
|
||||
require.NotNil(t, req)
|
||||
require.Len(t, req.Logs, 2, "it should sanitize invalid output, but still send")
|
||||
// The sanitizer replaces the NUL byte and invalid UTF-8 with ❌ while
|
||||
// preserving the valid "(" byte that follows 0xc3.
|
||||
require.Equal(t, "test log 0, src 1❌❌(", req.Logs[0].GetOutput())
|
||||
require.Len(t, req.Logs, 2, "it should sanitize invalid UTF-8, but still send")
|
||||
// the 0xc3, 0x28 is an invalid 2-byte sequence in UTF-8. The sanitizer replaces 0xc3 with ❌, and then
|
||||
// interprets 0x28 as a 1-byte sequence "("
|
||||
require.Equal(t, "test log 0, src 1❌(", req.Logs[0].GetOutput())
|
||||
require.Equal(t, proto.Log_INFO, req.Logs[0].GetLevel())
|
||||
require.Equal(t, "test log 1, src 1", req.Logs[1].GetOutput())
|
||||
require.Equal(t, proto.Log_INFO, req.Logs[1].GetLevel())
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
package agentsdk
|
||||
|
||||
import "strings"
|
||||
|
||||
// SanitizeLogOutput replaces invalid UTF-8 and NUL characters in log output.
|
||||
// Invalid UTF-8 cannot be transported in protobuf string fields, and PostgreSQL
|
||||
// rejects NUL bytes in text columns.
|
||||
func SanitizeLogOutput(s string) string {
|
||||
s = strings.ToValidUTF8(s, "❌")
|
||||
return strings.ReplaceAll(s, "\x00", "❌")
|
||||
}
|
||||
@@ -17,54 +17,6 @@ import (
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestSanitizeLogOutput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "valid",
|
||||
in: "hello world",
|
||||
want: "hello world",
|
||||
},
|
||||
{
|
||||
name: "invalid utf8",
|
||||
in: "test log\xc3\x28",
|
||||
want: "test log❌(",
|
||||
},
|
||||
{
|
||||
name: "nul byte",
|
||||
in: "before\x00after",
|
||||
want: "before❌after",
|
||||
},
|
||||
{
|
||||
name: "invalid utf8 and nul byte",
|
||||
in: "before\x00middle\xc3\x28after",
|
||||
want: "before❌middle❌(after",
|
||||
},
|
||||
{
|
||||
name: "nul byte at edges",
|
||||
in: "\x00middle\x00",
|
||||
want: "❌middle❌",
|
||||
},
|
||||
{
|
||||
name: "invalid utf8 at edges",
|
||||
in: "\xc3middle\xc3",
|
||||
want: "❌middle❌",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, tt.want, agentsdk.SanitizeLogOutput(tt.in))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartupLogsWriter_Write(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
+20
-269
@@ -15,7 +15,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/invopop/jsonschema"
|
||||
"github.com/shopspring/decimal"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
@@ -43,13 +42,12 @@ func CompactionThresholdKey(modelConfigID uuid.UUID) string {
|
||||
type ChatStatus string
|
||||
|
||||
const (
|
||||
ChatStatusWaiting ChatStatus = "waiting"
|
||||
ChatStatusPending ChatStatus = "pending"
|
||||
ChatStatusRunning ChatStatus = "running"
|
||||
ChatStatusPaused ChatStatus = "paused"
|
||||
ChatStatusCompleted ChatStatus = "completed"
|
||||
ChatStatusError ChatStatus = "error"
|
||||
ChatStatusRequiresAction ChatStatus = "requires_action"
|
||||
ChatStatusWaiting ChatStatus = "waiting"
|
||||
ChatStatusPending ChatStatus = "pending"
|
||||
ChatStatusRunning ChatStatus = "running"
|
||||
ChatStatusPaused ChatStatus = "paused"
|
||||
ChatStatusCompleted ChatStatus = "completed"
|
||||
ChatStatusError ChatStatus = "error"
|
||||
)
|
||||
|
||||
// Chat represents a chat session with an AI agent.
|
||||
@@ -214,10 +212,6 @@ type ChatMessagePart struct {
|
||||
// ProviderExecuted indicates the tool call was executed by
|
||||
// the provider (e.g. Anthropic computer use).
|
||||
ProviderExecuted bool `json:"provider_executed,omitempty" variants:"tool-call?,tool-result?"`
|
||||
// CreatedAt records when this part was produced. Present on
|
||||
// tool-call and tool-result parts so the frontend can compute
|
||||
// tool execution duration.
|
||||
CreatedAt *time.Time `json:"created_at,omitempty" format:"date-time" variants:"tool-call?,tool-result?"`
|
||||
// ContextFilePath is the absolute path of a file loaded into
|
||||
// the LLM context (e.g. an AGENTS.md instruction file).
|
||||
ContextFilePath string `json:"context_file_path" variants:"context-file"`
|
||||
@@ -367,18 +361,6 @@ type ChatInputPart struct {
|
||||
Content string `json:"content,omitempty"`
|
||||
}
|
||||
|
||||
// SubmitToolResultsRequest is the body for POST /chats/{id}/tool-results.
|
||||
type SubmitToolResultsRequest struct {
|
||||
Results []ToolResult `json:"results"`
|
||||
}
|
||||
|
||||
// ToolResult is the client's response to a dynamic tool call.
|
||||
type ToolResult struct {
|
||||
ToolCallID string `json:"tool_call_id"`
|
||||
Output json.RawMessage `json:"output"`
|
||||
IsError bool `json:"is_error"`
|
||||
}
|
||||
|
||||
// CreateChatRequest is the request to create a new chat.
|
||||
type CreateChatRequest struct {
|
||||
Content []ChatInputPart `json:"content"`
|
||||
@@ -387,10 +369,6 @@ type CreateChatRequest struct {
|
||||
ModelConfigID *uuid.UUID `json:"model_config_id,omitempty" format:"uuid"`
|
||||
MCPServerIDs []uuid.UUID `json:"mcp_server_ids,omitempty" format:"uuid"`
|
||||
Labels map[string]string `json:"labels,omitempty"`
|
||||
// UnsafeDynamicTools declares client-executed tools that the
|
||||
// LLM can invoke. This API is highly experimental and highly
|
||||
// subject to change.
|
||||
UnsafeDynamicTools []DynamicTool `json:"unsafe_dynamic_tools,omitempty"`
|
||||
}
|
||||
|
||||
// UpdateChatRequest is the request to update a chat.
|
||||
@@ -567,17 +545,6 @@ type UpdateChatWorkspaceTTLRequest struct {
|
||||
WorkspaceTTLMillis int64 `json:"workspace_ttl_ms"`
|
||||
}
|
||||
|
||||
// ChatRetentionDaysResponse contains the current chat retention setting.
|
||||
type ChatRetentionDaysResponse struct {
|
||||
RetentionDays int32 `json:"retention_days"`
|
||||
}
|
||||
|
||||
// UpdateChatRetentionDaysRequest is a request to update the chat
|
||||
// retention period.
|
||||
type UpdateChatRetentionDaysRequest struct {
|
||||
RetentionDays int32 `json:"retention_days"`
|
||||
}
|
||||
|
||||
// ParseChatWorkspaceTTL parses a stored TTL string, returning the
|
||||
// default when the value is empty.
|
||||
func ParseChatWorkspaceTTL(s string) (time.Duration, error) {
|
||||
@@ -950,13 +917,12 @@ type ChatDiffContents struct {
|
||||
type ChatStreamEventType string
|
||||
|
||||
const (
|
||||
ChatStreamEventTypeMessagePart ChatStreamEventType = "message_part"
|
||||
ChatStreamEventTypeMessage ChatStreamEventType = "message"
|
||||
ChatStreamEventTypeStatus ChatStreamEventType = "status"
|
||||
ChatStreamEventTypeError ChatStreamEventType = "error"
|
||||
ChatStreamEventTypeQueueUpdate ChatStreamEventType = "queue_update"
|
||||
ChatStreamEventTypeRetry ChatStreamEventType = "retry"
|
||||
ChatStreamEventTypeActionRequired ChatStreamEventType = "action_required"
|
||||
ChatStreamEventTypeMessagePart ChatStreamEventType = "message_part"
|
||||
ChatStreamEventTypeMessage ChatStreamEventType = "message"
|
||||
ChatStreamEventTypeStatus ChatStreamEventType = "status"
|
||||
ChatStreamEventTypeError ChatStreamEventType = "error"
|
||||
ChatStreamEventTypeQueueUpdate ChatStreamEventType = "queue_update"
|
||||
ChatStreamEventTypeRetry ChatStreamEventType = "retry"
|
||||
)
|
||||
|
||||
// ChatQueuedMessage represents a queued message waiting to be processed.
|
||||
@@ -1011,123 +977,16 @@ type ChatStreamRetry struct {
|
||||
RetryingAt time.Time `json:"retrying_at" format:"date-time"`
|
||||
}
|
||||
|
||||
// ChatStreamActionRequired is the payload of an action_required stream event.
|
||||
type ChatStreamActionRequired struct {
|
||||
ToolCalls []ChatStreamToolCall `json:"tool_calls"`
|
||||
}
|
||||
|
||||
// ChatStreamToolCall describes a pending dynamic tool call that the client
|
||||
// must execute.
|
||||
type ChatStreamToolCall struct {
|
||||
ToolCallID string `json:"tool_call_id"`
|
||||
ToolName string `json:"tool_name"`
|
||||
Args string `json:"args"`
|
||||
}
|
||||
|
||||
// DynamicToolCall represents a pending tool invocation from the
|
||||
// chat stream that the client must execute and submit back.
|
||||
type DynamicToolCall struct {
|
||||
ToolCallID string `json:"tool_call_id"`
|
||||
ToolName string `json:"tool_name"`
|
||||
Args string `json:"args"`
|
||||
}
|
||||
|
||||
// DynamicToolResponse holds the output of a dynamic tool
|
||||
// execution. IsError indicates a tool-level error the LLM
|
||||
// should see, as opposed to an infrastructure failure
|
||||
// (returned as the error return value).
|
||||
type DynamicToolResponse struct {
|
||||
Content string `json:"content"`
|
||||
IsError bool `json:"is_error"`
|
||||
}
|
||||
|
||||
// DynamicTool describes a client-declared tool definition. On the
|
||||
// client side, the Handler callback executes the tool when the LLM
|
||||
// invokes it. On the server side, only Name, Description, and
|
||||
// InputSchema are used (Handler is not serialized).
|
||||
type DynamicTool struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
// InputSchema's JSON key "input_schema" uses snake_case for
|
||||
// SDK consistency, deviating from the camelCase "inputSchema"
|
||||
// convention used by MCP.
|
||||
InputSchema json.RawMessage `json:"input_schema"`
|
||||
|
||||
// Handler executes the tool when the LLM invokes it.
|
||||
// Not serialized — this only exists on the client side.
|
||||
Handler func(ctx context.Context, call DynamicToolCall) (DynamicToolResponse, error) `json:"-"`
|
||||
}
|
||||
|
||||
// NewDynamicTool creates a DynamicTool with a typed handler.
|
||||
// The JSON schema is derived from T using invopop/jsonschema.
|
||||
// The handler receives deserialized args and the DynamicToolCall metadata.
|
||||
func NewDynamicTool[T any](
|
||||
name, description string,
|
||||
handler func(ctx context.Context, args T, call DynamicToolCall) (DynamicToolResponse, error),
|
||||
) DynamicTool {
|
||||
reflector := jsonschema.Reflector{
|
||||
DoNotReference: true,
|
||||
Anonymous: true,
|
||||
AllowAdditionalProperties: true,
|
||||
}
|
||||
schema := reflector.Reflect(new(T))
|
||||
schema.Version = ""
|
||||
schemaJSON, err := json.Marshal(schema)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("codersdk: failed to marshal schema for %q: %v", name, err))
|
||||
}
|
||||
|
||||
return DynamicTool{
|
||||
Name: name,
|
||||
Description: description,
|
||||
InputSchema: schemaJSON,
|
||||
Handler: func(ctx context.Context, call DynamicToolCall) (DynamicToolResponse, error) {
|
||||
var parsed T
|
||||
if err := json.Unmarshal([]byte(call.Args), &parsed); err != nil {
|
||||
return DynamicToolResponse{
|
||||
Content: fmt.Sprintf("invalid parameters: %s", err),
|
||||
IsError: true,
|
||||
}, nil
|
||||
}
|
||||
return handler(ctx, parsed, call)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ChatWatchEventKind represents the kind of event in the chat watch stream.
|
||||
type ChatWatchEventKind string
|
||||
|
||||
const (
|
||||
ChatWatchEventKindStatusChange ChatWatchEventKind = "status_change"
|
||||
ChatWatchEventKindTitleChange ChatWatchEventKind = "title_change"
|
||||
ChatWatchEventKindCreated ChatWatchEventKind = "created"
|
||||
ChatWatchEventKindDeleted ChatWatchEventKind = "deleted"
|
||||
ChatWatchEventKindDiffStatusChange ChatWatchEventKind = "diff_status_change"
|
||||
ChatWatchEventKindActionRequired ChatWatchEventKind = "action_required"
|
||||
)
|
||||
|
||||
// ChatWatchEvent represents an event from the global chat watch stream.
|
||||
// It delivers lifecycle events (created, status change, title change)
|
||||
// for all of the authenticated user's chats. When Kind is
|
||||
// ActionRequired, ToolCalls contains the pending dynamic tool
|
||||
// invocations the client must execute and submit back.
|
||||
type ChatWatchEvent struct {
|
||||
Kind ChatWatchEventKind `json:"kind"`
|
||||
Chat Chat `json:"chat"`
|
||||
ToolCalls []ChatStreamToolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
|
||||
// ChatStreamEvent represents a real-time update for chat streaming.
|
||||
type ChatStreamEvent struct {
|
||||
Type ChatStreamEventType `json:"type"`
|
||||
ChatID uuid.UUID `json:"chat_id" format:"uuid"`
|
||||
Message *ChatMessage `json:"message,omitempty"`
|
||||
MessagePart *ChatStreamMessagePart `json:"message_part,omitempty"`
|
||||
Status *ChatStreamStatus `json:"status,omitempty"`
|
||||
Error *ChatStreamError `json:"error,omitempty"`
|
||||
Retry *ChatStreamRetry `json:"retry,omitempty"`
|
||||
QueuedMessages []ChatQueuedMessage `json:"queued_messages,omitempty"`
|
||||
ActionRequired *ChatStreamActionRequired `json:"action_required,omitempty"`
|
||||
Type ChatStreamEventType `json:"type"`
|
||||
ChatID uuid.UUID `json:"chat_id" format:"uuid"`
|
||||
Message *ChatMessage `json:"message,omitempty"`
|
||||
MessagePart *ChatStreamMessagePart `json:"message_part,omitempty"`
|
||||
Status *ChatStreamStatus `json:"status,omitempty"`
|
||||
Error *ChatStreamError `json:"error,omitempty"`
|
||||
Retry *ChatStreamRetry `json:"retry,omitempty"`
|
||||
QueuedMessages []ChatQueuedMessage `json:"queued_messages,omitempty"`
|
||||
}
|
||||
|
||||
type chatStreamEnvelope struct {
|
||||
@@ -1808,33 +1667,6 @@ func (c *ExperimentalClient) UpdateChatWorkspaceTTL(ctx context.Context, req Upd
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetChatRetentionDays returns the configured chat retention period.
|
||||
func (c *ExperimentalClient) GetChatRetentionDays(ctx context.Context) (ChatRetentionDaysResponse, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet, "/api/experimental/chats/config/retention-days", nil)
|
||||
if err != nil {
|
||||
return ChatRetentionDaysResponse{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return ChatRetentionDaysResponse{}, ReadBodyAsError(res)
|
||||
}
|
||||
var resp ChatRetentionDaysResponse
|
||||
return resp, json.NewDecoder(res.Body).Decode(&resp)
|
||||
}
|
||||
|
||||
// UpdateChatRetentionDays updates the chat retention period.
|
||||
func (c *ExperimentalClient) UpdateChatRetentionDays(ctx context.Context, req UpdateChatRetentionDaysRequest) error {
|
||||
res, err := c.Request(ctx, http.MethodPut, "/api/experimental/chats/config/retention-days", req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusNoContent {
|
||||
return ReadBodyAsError(res)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetChatTemplateAllowlist returns the deployment-wide chat template allowlist.
|
||||
func (c *ExperimentalClient) GetChatTemplateAllowlist(ctx context.Context) (ChatTemplateAllowlist, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet, "/api/experimental/chats/config/template-allowlist", nil)
|
||||
@@ -2070,73 +1902,6 @@ func (c *ExperimentalClient) StreamChat(ctx context.Context, chatID uuid.UUID, o
|
||||
}), nil
|
||||
}
|
||||
|
||||
// WatchChats streams lifecycle events for all of the authenticated
|
||||
// user's chats in real time. The returned channel emits
|
||||
// ChatWatchEvent values for status changes, title changes, creation,
|
||||
// deletion, diff-status changes, and action-required notifications.
|
||||
// Callers must close the returned io.Closer to release the websocket
|
||||
// connection when done.
|
||||
func (c *ExperimentalClient) WatchChats(ctx context.Context) (<-chan ChatWatchEvent, io.Closer, error) {
|
||||
conn, err := c.Dial(
|
||||
ctx,
|
||||
"/api/experimental/chats/watch",
|
||||
&websocket.DialOptions{CompressionMode: websocket.CompressionDisabled},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
conn.SetReadLimit(1 << 22) // 4MiB
|
||||
|
||||
streamCtx, streamCancel := context.WithCancel(ctx)
|
||||
events := make(chan ChatWatchEvent, 128)
|
||||
|
||||
go func() {
|
||||
defer close(events)
|
||||
defer streamCancel()
|
||||
defer func() {
|
||||
_ = conn.Close(websocket.StatusNormalClosure, "")
|
||||
}()
|
||||
|
||||
for {
|
||||
var envelope chatStreamEnvelope
|
||||
if err := wsjson.Read(streamCtx, conn, &envelope); err != nil {
|
||||
if streamCtx.Err() != nil {
|
||||
return
|
||||
}
|
||||
switch websocket.CloseStatus(err) {
|
||||
case websocket.StatusNormalClosure, websocket.StatusGoingAway:
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
switch envelope.Type {
|
||||
case ServerSentEventTypePing:
|
||||
continue
|
||||
case ServerSentEventTypeData:
|
||||
var event ChatWatchEvent
|
||||
if err := json.Unmarshal(envelope.Data, &event); err != nil {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-streamCtx.Done():
|
||||
return
|
||||
case events <- event:
|
||||
}
|
||||
case ServerSentEventTypeError:
|
||||
return
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return events, closeFunc(func() error {
|
||||
streamCancel()
|
||||
return nil
|
||||
}), nil
|
||||
}
|
||||
|
||||
// GetChat returns a chat by ID.
|
||||
func (c *ExperimentalClient) GetChat(ctx context.Context, chatID uuid.UUID) (Chat, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/experimental/chats/%s", chatID), nil)
|
||||
@@ -2444,20 +2209,6 @@ func (c *ExperimentalClient) GetMyChatUsageLimitStatus(ctx context.Context) (Cha
|
||||
return resp, json.NewDecoder(res.Body).Decode(&resp)
|
||||
}
|
||||
|
||||
// SubmitToolResults submits the results of dynamic tool calls for a chat
|
||||
// that is in requires_action status.
|
||||
func (c *ExperimentalClient) SubmitToolResults(ctx context.Context, chatID uuid.UUID, req SubmitToolResultsRequest) error {
|
||||
res, err := c.Request(ctx, http.MethodPost, fmt.Sprintf("/api/experimental/chats/%s/tool-results", chatID), req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusNoContent {
|
||||
return ReadBodyAsError(res)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetChatsByWorkspace returns a mapping of workspace ID to the latest
|
||||
// non-archived chat ID for each requested workspace. Workspaces with
|
||||
// no chats are omitted from the response.
|
||||
|
||||
@@ -329,42 +329,6 @@ func TestChatMessagePartVariantTags(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestChatMessagePart_CreatedAt_JSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("RoundTrips", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ts := time.Date(2025, 6, 15, 12, 30, 0, 0, time.UTC)
|
||||
part := codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeToolCall,
|
||||
ToolCallID: "tc-1",
|
||||
ToolName: "execute",
|
||||
CreatedAt: &ts,
|
||||
}
|
||||
data, err := json.Marshal(part)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(data), `"created_at"`)
|
||||
|
||||
var decoded codersdk.ChatMessagePart
|
||||
err = json.Unmarshal(data, &decoded)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, decoded.CreatedAt)
|
||||
require.True(t, ts.Equal(*decoded.CreatedAt))
|
||||
})
|
||||
|
||||
t.Run("OmittedWhenNil", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
part := codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeToolCall,
|
||||
ToolCallID: "tc-1",
|
||||
ToolName: "execute",
|
||||
}
|
||||
data, err := json.Marshal(part)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, string(data), `"created_at"`)
|
||||
})
|
||||
}
|
||||
|
||||
func TestModelCostConfig_LegacyNumericJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -505,68 +469,6 @@ func TestChat_JSONRoundTrip(t *testing.T) {
|
||||
require.Equal(t, original, decoded)
|
||||
}
|
||||
|
||||
func TestNewDynamicTool(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type testArgs struct {
|
||||
Query string `json:"query"`
|
||||
}
|
||||
|
||||
t.Run("CorrectSchema", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tool := codersdk.NewDynamicTool(
|
||||
"search", "search things",
|
||||
func(_ context.Context, args testArgs, _ codersdk.DynamicToolCall) (codersdk.DynamicToolResponse, error) {
|
||||
return codersdk.DynamicToolResponse{Content: args.Query}, nil
|
||||
},
|
||||
)
|
||||
|
||||
require.Equal(t, "search", tool.Name)
|
||||
require.Equal(t, "search things", tool.Description)
|
||||
require.Contains(t, string(tool.InputSchema), `"query"`)
|
||||
require.Contains(t, string(tool.InputSchema), `"string"`)
|
||||
})
|
||||
|
||||
t.Run("HandlerReceivesArgs", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var received testArgs
|
||||
tool := codersdk.NewDynamicTool(
|
||||
"search", "search things",
|
||||
func(_ context.Context, args testArgs, _ codersdk.DynamicToolCall) (codersdk.DynamicToolResponse, error) {
|
||||
received = args
|
||||
return codersdk.DynamicToolResponse{Content: "ok"}, nil
|
||||
},
|
||||
)
|
||||
|
||||
resp, err := tool.Handler(context.Background(), codersdk.DynamicToolCall{
|
||||
Args: `{"query":"hello"}`,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "ok", resp.Content)
|
||||
require.Equal(t, "hello", received.Query)
|
||||
})
|
||||
|
||||
t.Run("InvalidJSONArgs", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tool := codersdk.NewDynamicTool(
|
||||
"search", "search things",
|
||||
func(_ context.Context, args testArgs, _ codersdk.DynamicToolCall) (codersdk.DynamicToolResponse, error) {
|
||||
return codersdk.DynamicToolResponse{Content: "should not reach"}, nil
|
||||
},
|
||||
)
|
||||
|
||||
resp, err := tool.Handler(context.Background(), codersdk.DynamicToolCall{
|
||||
Args: "not-json",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, resp.IsError)
|
||||
require.Contains(t, resp.Content, "invalid parameters")
|
||||
})
|
||||
}
|
||||
|
||||
//nolint:tparallel,paralleltest
|
||||
func TestParseChatWorkspaceTTL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
+2
-30
@@ -196,7 +196,6 @@ const (
|
||||
FeatureWorkspaceExternalAgent FeatureName = "workspace_external_agent"
|
||||
FeatureAIBridge FeatureName = "aibridge"
|
||||
FeatureBoundary FeatureName = "boundary"
|
||||
FeatureServiceAccounts FeatureName = "service_accounts"
|
||||
FeatureAIGovernanceUserLimit FeatureName = "ai_governance_user_limit"
|
||||
)
|
||||
|
||||
@@ -228,7 +227,6 @@ var (
|
||||
FeatureWorkspaceExternalAgent,
|
||||
FeatureAIBridge,
|
||||
FeatureBoundary,
|
||||
FeatureServiceAccounts,
|
||||
FeatureAIGovernanceUserLimit,
|
||||
}
|
||||
|
||||
@@ -277,7 +275,6 @@ func (n FeatureName) AlwaysEnable() bool {
|
||||
FeatureWorkspacePrebuilds: true,
|
||||
FeatureWorkspaceExternalAgent: true,
|
||||
FeatureBoundary: true,
|
||||
FeatureServiceAccounts: true,
|
||||
}[n]
|
||||
}
|
||||
|
||||
@@ -285,7 +282,7 @@ func (n FeatureName) AlwaysEnable() bool {
|
||||
func (n FeatureName) Enterprise() bool {
|
||||
switch n {
|
||||
// Add all features that should be excluded in the Enterprise feature set.
|
||||
case FeatureMultipleOrganizations, FeatureCustomRoles, FeatureServiceAccounts:
|
||||
case FeatureMultipleOrganizations, FeatureCustomRoles:
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
@@ -3624,29 +3621,6 @@ Write out the current server config as YAML to stdout.`,
|
||||
YAML: "acquireBatchSize",
|
||||
Hidden: true, // Hidden because most operators should not need to modify this.
|
||||
},
|
||||
{
|
||||
Name: "Chat: Pubsub Flush Interval",
|
||||
Description: "The maximum time accepted chatd pubsub publishes wait before the batching loop schedules a flush.",
|
||||
Flag: "chat-pubsub-flush-interval",
|
||||
Env: "CODER_CHAT_PUBSUB_FLUSH_INTERVAL",
|
||||
Value: &c.AI.Chat.PubsubFlushInterval,
|
||||
Default: "50ms",
|
||||
Group: &deploymentGroupChat,
|
||||
YAML: "pubsubFlushInterval",
|
||||
Annotations: serpent.Annotations{}.Mark(annotationFormatDuration, "true"),
|
||||
Hidden: true,
|
||||
},
|
||||
{
|
||||
Name: "Chat: Pubsub Queue Size",
|
||||
Description: "How many chatd pubsub publishes can wait in memory for the dedicated sender path when PostgreSQL falls behind.",
|
||||
Flag: "chat-pubsub-queue-size",
|
||||
Env: "CODER_CHAT_PUBSUB_QUEUE_SIZE",
|
||||
Value: &c.AI.Chat.PubsubQueueSize,
|
||||
Default: "8192",
|
||||
Group: &deploymentGroupChat,
|
||||
YAML: "pubsubQueueSize",
|
||||
Hidden: true,
|
||||
},
|
||||
// AI Bridge Options
|
||||
{
|
||||
Name: "AI Bridge Enabled",
|
||||
@@ -4113,9 +4087,7 @@ type AIBridgeProxyConfig struct {
|
||||
}
|
||||
|
||||
type ChatConfig struct {
|
||||
AcquireBatchSize serpent.Int64 `json:"acquire_batch_size" typescript:",notnull"`
|
||||
PubsubFlushInterval serpent.Duration `json:"pubsub_flush_interval" typescript:",notnull"`
|
||||
PubsubQueueSize serpent.Int64 `json:"pubsub_queue_size" typescript:",notnull"`
|
||||
AcquireBatchSize serpent.Int64 `json:"acquire_batch_size" typescript:",notnull"`
|
||||
}
|
||||
|
||||
type AIConfig struct {
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
package codersdk
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// UserSecret represents a user secret's metadata. The secret value
|
||||
// is never included in API responses.
|
||||
type UserSecret struct {
|
||||
ID uuid.UUID `json:"id" format:"uuid"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
EnvName string `json:"env_name"`
|
||||
FilePath string `json:"file_path"`
|
||||
CreatedAt time.Time `json:"created_at" format:"date-time"`
|
||||
UpdatedAt time.Time `json:"updated_at" format:"date-time"`
|
||||
}
|
||||
|
||||
// CreateUserSecretRequest is the payload for creating a new user
|
||||
// secret. Name and Value are required. All other fields are optional
|
||||
// and default to empty string.
|
||||
type CreateUserSecretRequest struct {
|
||||
Name string `json:"name"`
|
||||
Value string `json:"value"`
|
||||
Description string `json:"description,omitempty"`
|
||||
EnvName string `json:"env_name,omitempty"`
|
||||
FilePath string `json:"file_path,omitempty"`
|
||||
}
|
||||
|
||||
// UpdateUserSecretRequest is the payload for partially updating a
|
||||
// user secret. At least one field must be non-nil. Pointer fields
|
||||
// distinguish "not sent" (nil) from "set to empty string" (pointer
|
||||
// to empty string).
|
||||
type UpdateUserSecretRequest struct {
|
||||
Value *string `json:"value,omitempty"`
|
||||
Description *string `json:"description,omitempty"`
|
||||
EnvName *string `json:"env_name,omitempty"`
|
||||
FilePath *string `json:"file_path,omitempty"`
|
||||
}
|
||||
@@ -1,191 +0,0 @@
|
||||
package codersdk
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// UserSecretEnvValidationOptions controls deployment-aware behavior
|
||||
// in environment variable name validation.
|
||||
type UserSecretEnvValidationOptions struct {
|
||||
// AIGatewayEnabled indicates that the deployment has AI Gateway
|
||||
// configured. When true, AI Gateway environment variables
|
||||
// (OPENAI_API_KEY, etc.) are reserved to prevent conflicts.
|
||||
AIGatewayEnabled bool
|
||||
}
|
||||
|
||||
var (
|
||||
// posixEnvNameRegex matches valid POSIX environment variable names:
|
||||
// must start with a letter or underscore, followed by letters,
|
||||
// digits, or underscores.
|
||||
posixEnvNameRegex = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)
|
||||
|
||||
// reservedEnvNames are system environment variables that must not
|
||||
// be overridden by user secrets. This list is intentionally
|
||||
// aggressive because it is easier to remove entries later than
|
||||
// to add them after users have already created conflicting
|
||||
// secrets.
|
||||
reservedEnvNames = map[string]struct{}{
|
||||
// Core POSIX/login variables. Overriding these breaks
|
||||
// basic shell and session behavior.
|
||||
"PATH": {},
|
||||
"HOME": {},
|
||||
"SHELL": {},
|
||||
"USER": {},
|
||||
"LOGNAME": {},
|
||||
"PWD": {},
|
||||
"OLDPWD": {},
|
||||
|
||||
// Locale and terminal. Agents and IDEs depend on these
|
||||
// being set correctly by the system.
|
||||
"LANG": {},
|
||||
"TERM": {},
|
||||
|
||||
// Shell behavior. Overriding these can silently break
|
||||
// word splitting, directory resolution, and script
|
||||
// execution in every shell session and agent script.
|
||||
"IFS": {},
|
||||
"CDPATH": {},
|
||||
|
||||
// Shell startup files. ENV is sourced by POSIX sh for
|
||||
// interactive shells; BASH_ENV is sourced by bash for
|
||||
// every non-interactive invocation (scripts, subshells).
|
||||
// Allowing users to set these would inject arbitrary
|
||||
// code into every shell and script in the workspace.
|
||||
"ENV": {},
|
||||
"BASH_ENV": {},
|
||||
|
||||
// Temp directories. Overriding these is a security risk
|
||||
// (symlink attacks, world-readable paths).
|
||||
"TMPDIR": {},
|
||||
"TMP": {},
|
||||
"TEMP": {},
|
||||
|
||||
// Host identity.
|
||||
"HOSTNAME": {},
|
||||
|
||||
// SSH session variables. The Coder agent sets
|
||||
// SSH_AUTH_SOCK in agentssh.go; the others are set by
|
||||
// sshd and should never be faked.
|
||||
"SSH_AUTH_SOCK": {},
|
||||
"SSH_CLIENT": {},
|
||||
"SSH_CONNECTION": {},
|
||||
"SSH_TTY": {},
|
||||
|
||||
// Editor/pager. The Coder agent sets these so that git
|
||||
// operations inside workspaces work non-interactively.
|
||||
"EDITOR": {},
|
||||
"VISUAL": {},
|
||||
"PAGER": {},
|
||||
|
||||
// IDE integration. The agent sets these for code-server
|
||||
// and VS Code Remote proxying.
|
||||
"VSCODE_PROXY_URI": {},
|
||||
"CS_DISABLE_GETTING_STARTED_OVERRIDE": {},
|
||||
|
||||
// XDG base directories. Overriding these redirects
|
||||
// config, cache, and runtime data for every tool in the
|
||||
// workspace.
|
||||
"XDG_RUNTIME_DIR": {},
|
||||
"XDG_CONFIG_HOME": {},
|
||||
"XDG_DATA_HOME": {},
|
||||
"XDG_CACHE_HOME": {},
|
||||
"XDG_STATE_HOME": {},
|
||||
|
||||
// OIDC token. The Coder agent injects a short-lived
|
||||
// OIDC token for cloud auth flows (e.g. GCP workload
|
||||
// identity). Overriding it could break provisioner and
|
||||
// agent authentication.
|
||||
"OIDC_TOKEN": {},
|
||||
}
|
||||
|
||||
// aiGatewayReservedEnvNames are reserved only when AI Gateway
|
||||
// is enabled on the deployment. When AI Gateway is disabled,
|
||||
// users may legitimately want to inject their own API keys
|
||||
// via secrets.
|
||||
aiGatewayReservedEnvNames = map[string]struct{}{
|
||||
"OPENAI_API_KEY": {},
|
||||
"OPENAI_BASE_URL": {},
|
||||
"ANTHROPIC_AUTH_TOKEN": {},
|
||||
"ANTHROPIC_BASE_URL": {},
|
||||
}
|
||||
|
||||
// reservedEnvPrefixes are namespace prefixes where every
|
||||
// variable in the family is reserved. Checked after the
|
||||
// exact-name map. The CODER / CODER_* namespace is handled
|
||||
// separately with its own error message (see below).
|
||||
reservedEnvPrefixes = []string{
|
||||
// The Coder agent sets GIT_SSH_COMMAND, GIT_ASKPASS,
|
||||
// GIT_AUTHOR_*, GIT_COMMITTER_*, and several others.
|
||||
// Blocking the entire GIT_* namespace avoids an arms
|
||||
// race with new git env vars.
|
||||
"GIT_",
|
||||
|
||||
// Locale variables. LC_ALL, LC_CTYPE, LC_MESSAGES,
|
||||
// etc. control character encoding, sorting, and
|
||||
// formatting. Overriding them can break text
|
||||
// processing in agents and IDEs.
|
||||
"LC_",
|
||||
|
||||
// Dynamic linker variables. Allowing users to set
|
||||
// these would let a secret inject arbitrary shared
|
||||
// libraries into every process in the workspace.
|
||||
"LD_",
|
||||
"DYLD_",
|
||||
}
|
||||
)
|
||||
|
||||
// UserSecretEnvNameValid validates an environment variable name for
|
||||
// a user secret. Empty string is allowed (means no env injection).
|
||||
// The opts parameter controls deployment-aware checks such as AI
|
||||
// bridge variable reservation.
|
||||
func UserSecretEnvNameValid(s string, opts UserSecretEnvValidationOptions) error {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !posixEnvNameRegex.MatchString(s) {
|
||||
return xerrors.New("must start with a letter or underscore, followed by letters, digits, or underscores")
|
||||
}
|
||||
|
||||
upper := strings.ToUpper(s)
|
||||
|
||||
if _, ok := reservedEnvNames[upper]; ok {
|
||||
return xerrors.Errorf("%s is a reserved environment variable name", upper)
|
||||
}
|
||||
|
||||
if upper == "CODER" || strings.HasPrefix(upper, "CODER_") {
|
||||
return xerrors.New("environment variable names starting with CODER_ are reserved for internal use")
|
||||
}
|
||||
|
||||
for _, prefix := range reservedEnvPrefixes {
|
||||
if strings.HasPrefix(upper, prefix) {
|
||||
return xerrors.Errorf("environment variables starting with %s are reserved", prefix)
|
||||
}
|
||||
}
|
||||
|
||||
if opts.AIGatewayEnabled {
|
||||
if _, ok := aiGatewayReservedEnvNames[upper]; ok {
|
||||
return xerrors.Errorf("%s is reserved when AI Gateway is enabled", upper)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UserSecretFilePathValid validates a file path for a user secret.
|
||||
// Empty string is allowed (means no file injection). Non-empty paths
|
||||
// must start with ~/ or /.
|
||||
func UserSecretFilePathValid(s string) error {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if strings.HasPrefix(s, "~/") || strings.HasPrefix(s, "/") {
|
||||
return nil
|
||||
}
|
||||
|
||||
return xerrors.New("file path must start with ~/ or /")
|
||||
}
|
||||
@@ -1,183 +0,0 @@
|
||||
package codersdk_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
func TestUserSecretEnvNameValid(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// noAIGateway is the default for most tests — AI Gateway disabled.
|
||||
noAIGateway := codersdk.UserSecretEnvValidationOptions{}
|
||||
withAIGateway := codersdk.UserSecretEnvValidationOptions{AIGatewayEnabled: true}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
opts codersdk.UserSecretEnvValidationOptions
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
// Valid names.
|
||||
{name: "SimpleUpper", input: "GITHUB_TOKEN", opts: noAIGateway},
|
||||
{name: "SimpleLower", input: "github_token", opts: noAIGateway},
|
||||
{name: "StartsWithUnderscore", input: "_FOO", opts: noAIGateway},
|
||||
{name: "SingleChar", input: "A", opts: noAIGateway},
|
||||
{name: "WithDigits", input: "A1B2", opts: noAIGateway},
|
||||
{name: "Empty", input: "", opts: noAIGateway},
|
||||
|
||||
// Invalid POSIX names.
|
||||
{name: "StartsWithDigit", input: "1FOO", opts: noAIGateway, wantErr: true, errMsg: "must start with"},
|
||||
{name: "ContainsHyphen", input: "FOO-BAR", opts: noAIGateway, wantErr: true, errMsg: "must start with"},
|
||||
{name: "ContainsDot", input: "FOO.BAR", opts: noAIGateway, wantErr: true, errMsg: "must start with"},
|
||||
{name: "ContainsSpace", input: "FOO BAR", opts: noAIGateway, wantErr: true, errMsg: "must start with"},
|
||||
|
||||
// Reserved system names — core POSIX/login.
|
||||
{name: "ReservedPATH", input: "PATH", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
|
||||
{name: "ReservedHOME", input: "HOME", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
|
||||
{name: "ReservedSHELL", input: "SHELL", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
|
||||
{name: "ReservedUSER", input: "USER", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
|
||||
{name: "ReservedLOGNAME", input: "LOGNAME", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
|
||||
{name: "ReservedPWD", input: "PWD", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
|
||||
{name: "ReservedOLDPWD", input: "OLDPWD", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
|
||||
|
||||
// Reserved system names — locale/terminal.
|
||||
{name: "ReservedLANG", input: "LANG", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
|
||||
{name: "ReservedTERM", input: "TERM", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
|
||||
|
||||
// Reserved system names — shell behavior.
|
||||
{name: "ReservedIFS", input: "IFS", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
|
||||
{name: "ReservedCDPATH", input: "CDPATH", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
|
||||
|
||||
// Reserved system names — shell startup files.
|
||||
{name: "ReservedENV", input: "ENV", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
|
||||
{name: "ReservedBASH_ENV", input: "BASH_ENV", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
|
||||
|
||||
// Reserved system names — temp directories.
|
||||
{name: "ReservedTMPDIR", input: "TMPDIR", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
|
||||
{name: "ReservedTMP", input: "TMP", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
|
||||
{name: "ReservedTEMP", input: "TEMP", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
|
||||
|
||||
// Reserved system names — host identity.
|
||||
{name: "ReservedHOSTNAME", input: "HOSTNAME", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
|
||||
|
||||
// Reserved system names — SSH.
|
||||
{name: "ReservedSSH_AUTH_SOCK", input: "SSH_AUTH_SOCK", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
|
||||
{name: "ReservedSSH_CLIENT", input: "SSH_CLIENT", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
|
||||
{name: "ReservedSSH_CONNECTION", input: "SSH_CONNECTION", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
|
||||
{name: "ReservedSSH_TTY", input: "SSH_TTY", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
|
||||
|
||||
// Reserved system names — editor/pager.
|
||||
{name: "ReservedEDITOR", input: "EDITOR", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
|
||||
{name: "ReservedVISUAL", input: "VISUAL", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
|
||||
{name: "ReservedPAGER", input: "PAGER", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
|
||||
|
||||
// Reserved system names — IDE integration.
|
||||
{name: "ReservedVSCODE_PROXY_URI", input: "VSCODE_PROXY_URI", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
|
||||
{name: "ReservedCS_DISABLE", input: "CS_DISABLE_GETTING_STARTED_OVERRIDE", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
|
||||
|
||||
// Reserved system names — XDG.
|
||||
{name: "ReservedXDG_RUNTIME_DIR", input: "XDG_RUNTIME_DIR", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
|
||||
{name: "ReservedXDG_CONFIG_HOME", input: "XDG_CONFIG_HOME", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
|
||||
{name: "ReservedXDG_DATA_HOME", input: "XDG_DATA_HOME", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
|
||||
{name: "ReservedXDG_CACHE_HOME", input: "XDG_CACHE_HOME", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
|
||||
{name: "ReservedXDG_STATE_HOME", input: "XDG_STATE_HOME", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
|
||||
|
||||
// Reserved system names — OIDC.
|
||||
{name: "ReservedOIDC_TOKEN", input: "OIDC_TOKEN", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
|
||||
|
||||
// AI Gateway vars — blocked when AI Gateway is enabled.
|
||||
{name: "AIGateway/OPENAI_API_KEY/Enabled", input: "OPENAI_API_KEY", opts: withAIGateway, wantErr: true, errMsg: "AI Gateway"},
|
||||
{name: "AIGateway/OPENAI_BASE_URL/Enabled", input: "OPENAI_BASE_URL", opts: withAIGateway, wantErr: true, errMsg: "AI Gateway"},
|
||||
{name: "AIGateway/ANTHROPIC_AUTH_TOKEN/Enabled", input: "ANTHROPIC_AUTH_TOKEN", opts: withAIGateway, wantErr: true, errMsg: "AI Gateway"},
|
||||
{name: "AIGateway/ANTHROPIC_BASE_URL/Enabled", input: "ANTHROPIC_BASE_URL", opts: withAIGateway, wantErr: true, errMsg: "AI Gateway"},
|
||||
|
||||
// AI Gateway vars — allowed when AI Gateway is disabled.
|
||||
{name: "AIGateway/OPENAI_API_KEY/Disabled", input: "OPENAI_API_KEY", opts: noAIGateway},
|
||||
{name: "AIGateway/OPENAI_BASE_URL/Disabled", input: "OPENAI_BASE_URL", opts: noAIGateway},
|
||||
{name: "AIGateway/ANTHROPIC_AUTH_TOKEN/Disabled", input: "ANTHROPIC_AUTH_TOKEN", opts: noAIGateway},
|
||||
{name: "AIGateway/ANTHROPIC_BASE_URL/Disabled", input: "ANTHROPIC_BASE_URL", opts: noAIGateway},
|
||||
|
||||
// Case insensitivity.
|
||||
{name: "ReservedCaseInsensitive", input: "path", opts: noAIGateway, wantErr: true, errMsg: "reserved"},
|
||||
|
||||
// CODER_ prefix.
|
||||
{name: "CoderExact", input: "CODER", opts: noAIGateway, wantErr: true, errMsg: "CODER_"},
|
||||
{name: "CoderPrefix", input: "CODER_WORKSPACE_NAME", opts: noAIGateway, wantErr: true, errMsg: "CODER_"},
|
||||
{name: "CoderAgentToken", input: "CODER_AGENT_TOKEN", opts: noAIGateway, wantErr: true, errMsg: "CODER_"},
|
||||
{name: "CoderLowerCase", input: "coder_foo", opts: noAIGateway, wantErr: true, errMsg: "CODER_"},
|
||||
|
||||
// GIT_* prefix.
|
||||
{name: "GitSSHCommand", input: "GIT_SSH_COMMAND", opts: noAIGateway, wantErr: true, errMsg: "GIT_"},
|
||||
{name: "GitAskpass", input: "GIT_ASKPASS", opts: noAIGateway, wantErr: true, errMsg: "GIT_"},
|
||||
{name: "GitAuthorName", input: "GIT_AUTHOR_NAME", opts: noAIGateway, wantErr: true, errMsg: "GIT_"},
|
||||
{name: "GitLowerCase", input: "git_editor", opts: noAIGateway, wantErr: true, errMsg: "GIT_"},
|
||||
|
||||
// LC_* prefix (locale).
|
||||
{name: "LcAll", input: "LC_ALL", opts: noAIGateway, wantErr: true, errMsg: "LC_"},
|
||||
{name: "LcCtype", input: "LC_CTYPE", opts: noAIGateway, wantErr: true, errMsg: "LC_"},
|
||||
|
||||
// LD_* prefix (dynamic linker).
|
||||
{name: "LdPreload", input: "LD_PRELOAD", opts: noAIGateway, wantErr: true, errMsg: "LD_"},
|
||||
{name: "LdLibraryPath", input: "LD_LIBRARY_PATH", opts: noAIGateway, wantErr: true, errMsg: "LD_"},
|
||||
|
||||
// DYLD_* prefix (macOS dynamic linker).
|
||||
{name: "DyldInsert", input: "DYLD_INSERT_LIBRARIES", opts: noAIGateway, wantErr: true, errMsg: "DYLD_"},
|
||||
{name: "DyldLibraryPath", input: "DYLD_LIBRARY_PATH", opts: noAIGateway, wantErr: true, errMsg: "DYLD_"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
err := codersdk.UserSecretEnvNameValid(tt.input, tt.opts)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errMsg != "" {
|
||||
assert.Contains(t, err.Error(), tt.errMsg)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserSecretFilePathValid(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantErr bool
|
||||
}{
|
||||
// Valid paths.
|
||||
{name: "TildePath", input: "~/foo"},
|
||||
{name: "TildeSSH", input: "~/.ssh/id_rsa"},
|
||||
{name: "AbsolutePath", input: "/home/coder/.ssh/id_rsa"},
|
||||
{name: "RootPath", input: "/"},
|
||||
{name: "Empty", input: ""},
|
||||
|
||||
// Invalid paths.
|
||||
{name: "BareRelative", input: "foo/bar", wantErr: true},
|
||||
{name: "DotRelative", input: ".ssh/id_rsa", wantErr: true},
|
||||
{name: "JustFilename", input: "credentials", wantErr: true},
|
||||
{name: "TildeNoSlash", input: "~foo", wantErr: true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
err := codersdk.UserSecretFilePathValid(tt.input)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "must start with")
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -211,53 +211,33 @@ Coder releases are initiated via
|
||||
[`./scripts/release.sh`](https://github.com/coder/coder/blob/main/scripts/release.sh)
|
||||
and automated via GitHub Actions. Specifically, the
|
||||
[`release.yaml`](https://github.com/coder/coder/blob/main/.github/workflows/release.yaml)
|
||||
workflow.
|
||||
workflow. They are created based on the current
|
||||
[`main`](https://github.com/coder/coder/tree/main) branch.
|
||||
|
||||
Release notes are automatically generated from commit titles and PR metadata.
|
||||
The release notes for a release are automatically generated from commit titles
|
||||
and metadata from PRs that are merged into `main`.
|
||||
|
||||
### Release types
|
||||
### Creating a release
|
||||
|
||||
| Type | Tag | Branch | Purpose |
|
||||
|------------------------|---------------|---------------|-----------------------------------------|
|
||||
| RC (release candidate) | `vX.Y.0-rc.W` | `main` | Ad-hoc pre-release for customer testing |
|
||||
| Release | `vX.Y.0` | `release/X.Y` | First release of a minor version |
|
||||
| Patch | `vX.Y.Z` | `release/X.Y` | Bug fixes and security patches |
|
||||
The creation of a release is initiated via
|
||||
[`./scripts/release.sh`](https://github.com/coder/coder/blob/main/scripts/release.sh).
|
||||
This script will show a preview of the release that will be created, and if you
|
||||
choose to continue, create and push the tag which will trigger the creation of
|
||||
the release via GitHub Actions.
|
||||
|
||||
### Workflow
|
||||
|
||||
RC tags are created directly on `main`. The `release/X.Y` branch is only cut
|
||||
when the release is ready. This avoids cherry-picking main's progress onto
|
||||
a release branch between the first RC and the release.
|
||||
|
||||
```text
|
||||
main: ──●──●──●──●──●──●──●──●──●──
|
||||
↑ ↑ ↑
|
||||
rc.0 rc.1 cut release/2.34, tag v2.34.0
|
||||
\
|
||||
release/2.34: ──●── v2.34.1 (patch)
|
||||
```
|
||||
|
||||
1. **RC:** On `main`, run `./scripts/release.sh`. The tool suggests the next
|
||||
RC version and tags it on `main`.
|
||||
2. **Release:** When the RC is blessed, create `release/X.Y` from `main` (or
|
||||
the specific RC commit). Switch to that branch and run
|
||||
`./scripts/release.sh`, which suggests `vX.Y.0`.
|
||||
3. **Patch:** Cherry-pick fixes onto `release/X.Y` and run
|
||||
`./scripts/release.sh` from that branch.
|
||||
|
||||
The release tool warns if you try to tag a non-RC on `main` or an RC on a
|
||||
release branch.
|
||||
See `./scripts/release.sh --help` for more information.
|
||||
|
||||
### Creating a release (via workflow dispatch)
|
||||
|
||||
If the
|
||||
[`release.yaml`](https://github.com/coder/coder/actions/workflows/release.yaml)
|
||||
workflow fails after the tag has been pushed, retry it from the GitHub Actions
|
||||
UI: press "Run workflow", set "Use workflow from" to the tag (e.g.
|
||||
`Tag: v2.34.0`), select the correct release channel, and do **not** select
|
||||
dry-run.
|
||||
Typically the workflow dispatch is only used to test (dry-run) a release,
|
||||
meaning no actual release will take place. The workflow can be dispatched
|
||||
manually from
|
||||
[Actions: Release](https://github.com/coder/coder/actions/workflows/release.yaml).
|
||||
Simply press "Run workflow" and choose dry-run.
|
||||
|
||||
To test the workflow without publishing, select dry-run.
|
||||
If a release has failed after the tag has been created and pushed, it can be
|
||||
retried again by pressing "Run workflow", changing "Use workflow from" from
|
||||
"Branch: main" to "Tag: vX.X.X" and not selecting dry-run.
|
||||
|
||||
### Commit messages
|
||||
|
||||
@@ -291,23 +271,6 @@ specification, however, it's still possible to merge PRs on GitHub with a badly
|
||||
formatted title. Take care when merging single-commit PRs as GitHub may prefer
|
||||
to use the original commit title instead of the PR title.
|
||||
|
||||
### Backporting fixes to release branches
|
||||
|
||||
When a merged PR on `main` should also ship in older releases, add the
|
||||
`backport` label to the PR. The
|
||||
[backport workflow](https://github.com/coder/coder/blob/main/.github/workflows/backport.yaml)
|
||||
will automatically detect the latest three `release/*` branches,
|
||||
cherry-pick the merge commit onto each one, and open PRs for
|
||||
review.
|
||||
|
||||
The label can be added before or after the PR is merged. Each backport
|
||||
PR reuses the original title (e.g.
|
||||
`fix(site): correct button alignment (#12345)`) so the change is
|
||||
meaningful in release notes.
|
||||
|
||||
If the cherry-pick encounters conflicts, the backport PR is still created
|
||||
with instructions for manual resolution — no conflict markers are committed.
|
||||
|
||||
### Breaking changes
|
||||
|
||||
Breaking changes can be triggered in two ways:
|
||||
|
||||
@@ -150,12 +150,6 @@ deployment. They will always be available from the agent.
|
||||
| `coder_derp_server_sent_pong_total` | counter | Total pongs sent. | |
|
||||
| `coder_derp_server_unknown_frames_total` | counter | Total unknown frames received. | |
|
||||
| `coder_derp_server_watchers` | gauge | Current watchers. | |
|
||||
| `coder_pubsub_batch_delegate_fallbacks_total` | counter | The number of chatd publishes that fell back to the shared pubsub pool by channel class, reason, and flush stage. | `channel_class` `reason` `stage` |
|
||||
| `coder_pubsub_batch_flush_duration_seconds` | histogram | The time spent flushing one chatd batch to PostgreSQL. | `reason` |
|
||||
| `coder_pubsub_batch_queue_depth` | gauge | The number of chatd notifications waiting in the batching queue. | |
|
||||
| `coder_pubsub_batch_sender_reset_failures_total` | counter | The number of batched pubsub sender reset attempts that failed. | |
|
||||
| `coder_pubsub_batch_sender_resets_total` | counter | The number of successful batched pubsub sender resets after flush failures. | |
|
||||
| `coder_pubsub_batch_size` | histogram | The number of logical notifications sent in each chatd batch flush. | |
|
||||
| `coder_pubsub_connected` | gauge | Whether we are connected (1) or not connected (0) to postgres | |
|
||||
| `coder_pubsub_current_events` | gauge | The current number of pubsub event channels listened for | |
|
||||
| `coder_pubsub_current_subscribers` | gauge | The current number of active pubsub subscribers | |
|
||||
|
||||
@@ -37,11 +37,14 @@ resource "docker_container" "workspace" {
|
||||
resource "coder_agent" "main" {
|
||||
arch = data.coder_provisioner.me.arch
|
||||
os = "linux"
|
||||
startup_script = <<-EOF
|
||||
startup_script = <<EOF
|
||||
#!/bin/sh
|
||||
set -e
|
||||
sudo service docker start
|
||||
EOF
|
||||
|
||||
# Start Docker
|
||||
sudo dockerd &
|
||||
|
||||
# ...
|
||||
EOF
|
||||
}
|
||||
```
|
||||
|
||||
@@ -75,10 +78,13 @@ resource "coder_agent" "main" {
|
||||
os = "linux"
|
||||
arch = "amd64"
|
||||
dir = "/home/coder"
|
||||
startup_script = <<-EOF
|
||||
startup_script = <<EOF
|
||||
#!/bin/sh
|
||||
set -e
|
||||
sudo service docker start
|
||||
|
||||
# Start Docker
|
||||
sudo dockerd &
|
||||
|
||||
# ...
|
||||
EOF
|
||||
}
|
||||
|
||||
|
||||
@@ -1,38 +1,31 @@
|
||||
# Headless Authentication
|
||||
|
||||
> [!NOTE]
|
||||
> Creating service accounts requires a [Premium license](https://coder.com/pricing).
|
||||
Headless user accounts that cannot use the web UI to log in to Coder. This is
|
||||
useful for creating accounts for automated systems, such as CI/CD pipelines or
|
||||
for users who only consume Coder via another client/API.
|
||||
|
||||
Service accounts are headless user accounts that cannot use the web UI to log in
|
||||
to Coder. This is useful for creating accounts for automated systems, such as
|
||||
CI/CD pipelines or for users who only consume Coder via another client/API. Service accounts do not have passwords or associated email addresses.
|
||||
You must have the User Admin role or above to create headless users.
|
||||
|
||||
You must have the User Admin role or above to create service accounts.
|
||||
|
||||
## Create a service account
|
||||
## Create a headless user
|
||||
|
||||
<div class="tabs">
|
||||
|
||||
## CLI
|
||||
|
||||
Use the `--service-account` flag to create a dedicated service account:
|
||||
|
||||
```sh
|
||||
coder users create \
|
||||
--email="coder-bot@coder.com" \
|
||||
--username="coder-bot" \
|
||||
  --service-account \
|
||||
  --login-type="none"
|
||||
```
|
||||
|
||||
## UI
|
||||
|
||||
Navigate to **Deployment** > **Users** > **Create user**, then select
|
||||
**Service account** as the login type.
|
||||
Navigate to `Users` > `Create user` in the topbar.
|
||||
|
||||

|
||||
|
||||
</div>
|
||||
|
||||
## Authenticate as a service account
|
||||
|
||||
To make API or CLI requests on behalf of the headless user, learn how to
|
||||
[generate API tokens on behalf of a user](./sessions-tokens.md#generate-a-long-lived-api-token-on-behalf-of-another-user).
|
||||
|
||||
@@ -180,15 +180,6 @@ configuration set by an administrator.
|
||||
|--------------|--------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| `web_search` | Searches the internet for up-to-date information. Available when web search is enabled for the configured Anthropic, OpenAI, or Google provider. |
|
||||
|
||||
### Workspace extension tools
|
||||
|
||||
These tools are conditionally available based on the workspace contents.
|
||||
|
||||
| Tool | What it does |
|
||||
|-------------------|--------------------------------------------------------------------------------------------------------------------------------|
|
||||
| `read_skill` | Reads the instructions for a workspace skill by name. Available when the workspace has skills discovered in `.agents/skills/`. |
|
||||
| `read_skill_file` | Reads a supporting file from a skill's directory. |
|
||||
|
||||
## What runs where
|
||||
|
||||
Understanding the split between the control plane and the workspace is central
|
||||
@@ -233,11 +224,10 @@ Because state lives in the database:
|
||||
- The agent can resume work by targeting a new workspace and continuing from the
|
||||
last git branch or checkpoint.
|
||||
|
||||
## Security posture
|
||||
## Security implications
|
||||
|
||||
The control plane architecture provides built-in security properties for AI
|
||||
coding workflows. These are structural guarantees, not configuration options —
|
||||
they hold by default for every agent session.
|
||||
The control plane architecture provides several security advantages for AI
|
||||
coding workflows.
|
||||
|
||||
### No API keys in workspaces
|
||||
|
||||
|
||||
@@ -65,9 +65,12 @@ Once the server restarts with the experiment enabled:
|
||||
1. Navigate to the **Agents** page in the Coder dashboard.
|
||||
1. Open **Admin** settings and configure at least one LLM provider and model.
|
||||
See [Models](./models.md) for detailed setup instructions.
|
||||
1. Grant the **Coder Agents User** role to users who need to create chats.
|
||||
Go to **Admin** > **Users**, click the roles icon next to each user,
|
||||
and enable **Coder Agents User**.
|
||||
1. Grant the **Coder Agents User** role to existing users who need to create
|
||||
chats. New users receive the role automatically. For existing users, go to
|
||||
**Admin** > **Users**, click the roles icon next to each user, and enable
|
||||
**Coder Agents User**. See
|
||||
[Grant Coder Agents User](./getting-started.md#step-3-grant-coder-agents-user)
|
||||
for a bulk CLI option.
|
||||
1. Developers can then start a new chat from the Agents page.
|
||||
|
||||
## Licensing and availability
|
||||
|
||||
@@ -24,8 +24,9 @@ Before you begin, confirm the following:
|
||||
for the agent to select when provisioning workspaces.
|
||||
- **Admin access** to the Coder deployment for enabling the experiment and
|
||||
configuring providers.
|
||||
- **Coder Agents User role** assigned to each user who needs to interact with Coder Agents.
|
||||
Owners can assign this from **Admin** > **Users**. See
|
||||
- **Coder Agents User role** is automatically assigned to new users when the
|
||||
`agents` experiment is enabled. For existing users, owners can assign it from
|
||||
**Admin** > **Users**. See
|
||||
[Grant Coder Agents User](#step-3-grant-coder-agents-user) below.
|
||||
|
||||
## Step 1: Enable the experiment
|
||||
@@ -74,14 +75,20 @@ Detailed instructions for each provider and model option are in the
|
||||
|
||||
## Step 3: Grant Coder Agents User
|
||||
|
||||
The **Coder Agents User** role controls which users can interact with Coder Agents.
|
||||
Members do not have Coder Agents User by default.
|
||||
The **Coder Agents User** role controls which users can interact with
|
||||
Coder Agents.
|
||||
|
||||
Owners always have full access and do not need the role. Repeat the following steps for each user who needs access.
|
||||
### New users
|
||||
|
||||
> [!NOTE]
|
||||
> Users who created conversations before this role was introduced are
|
||||
> automatically granted the role during upgrade.
|
||||
When the `agents` experiment is enabled, new users are automatically
|
||||
assigned the **Coder Agents User** role at account creation. No admin
|
||||
action is required.
|
||||
|
||||
### Existing users
|
||||
|
||||
Users who were created before the experiment was enabled do not receive
|
||||
the role automatically. Owners can assign it from the dashboard or in
|
||||
bulk via the CLI.
|
||||
|
||||
**Dashboard (individual):**
|
||||
|
||||
@@ -91,8 +98,7 @@ Owners always have full access and do not need the role. Repeat the following st
|
||||
|
||||
**CLI (bulk):**
|
||||
|
||||
You can also grant the role via CLI. For example, to grant the role to
|
||||
all active users at once:
|
||||
To grant the role to all active users at once:
|
||||
|
||||
```sh
|
||||
coder users list -o json \
|
||||
@@ -105,6 +111,12 @@ coder users list -o json \
|
||||
done
|
||||
```
|
||||
|
||||
Owners always have full access and do not need the role.
|
||||
|
||||
> [!NOTE]
|
||||
> Users who created conversations before this role was introduced are
|
||||
> automatically granted the role during upgrade.
|
||||
|
||||
## Step 4: Start your first Coder Agent
|
||||
|
||||
1. Go to the **Agents** page in the Coder dashboard.
|
||||
|
||||
@@ -232,43 +232,35 @@ model. Developers select from enabled models when starting a chat.
|
||||
The agent has access to a set of workspace tools that it uses to accomplish
|
||||
tasks:
|
||||
|
||||
| Tool | Description |
|
||||
|----------------------------|--------------------------------------------------------------------------|
|
||||
| `list_templates` | Browse available workspace templates |
|
||||
| `read_template` | Get template details and configurable parameters |
|
||||
| `create_workspace` | Create a workspace from a template |
|
||||
| `start_workspace` | Start a stopped workspace for the current chat |
|
||||
| `propose_plan` | Present a Markdown plan file for user review |
|
||||
| `read_file` | Read file contents from the workspace |
|
||||
| `write_file` | Write a file to the workspace |
|
||||
| `edit_files` | Perform search-and-replace edits across files |
|
||||
| `execute` | Run shell commands in the workspace |
|
||||
| `process_output` | Retrieve output from a background process |
|
||||
| `process_list` | List all tracked processes in the workspace |
|
||||
| `process_signal` | Send a signal (terminate/kill) to a tracked process |
|
||||
| `spawn_agent` | Delegate a task to a sub-agent running in parallel |
|
||||
| `wait_agent` | Wait for a sub-agent to complete and collect its result |
|
||||
| `message_agent` | Send a follow-up message to a running sub-agent |
|
||||
| `close_agent` | Stop a running sub-agent |
|
||||
| `spawn_computer_use_agent` | Spawn a sub-agent with desktop interaction (screenshot, mouse, keyboard) |
|
||||
| `read_skill` | Read the instructions for a workspace skill by name |
|
||||
| `read_skill_file` | Read a supporting file from a skill's directory |
|
||||
| `web_search` | Search the internet (provider-native, when enabled) |
|
||||
| Tool | Description |
|
||||
|--------------------|---------------------------------------------------------|
|
||||
| `list_templates` | Browse available workspace templates |
|
||||
| `read_template` | Get template details and configurable parameters |
|
||||
| `create_workspace` | Create a workspace from a template |
|
||||
| `start_workspace` | Start a stopped workspace for the current chat |
|
||||
| `propose_plan` | Present a Markdown plan file for user review |
|
||||
| `read_file` | Read file contents from the workspace |
|
||||
| `write_file` | Write a file to the workspace |
|
||||
| `edit_files` | Perform search-and-replace edits across files |
|
||||
| `execute` | Run shell commands in the workspace |
|
||||
| `process_output` | Retrieve output from a background process |
|
||||
| `process_list` | List all tracked processes in the workspace |
|
||||
| `process_signal` | Send a signal (terminate/kill) to a tracked process |
|
||||
| `spawn_agent` | Delegate a task to a sub-agent running in parallel |
|
||||
| `wait_agent` | Wait for a sub-agent to complete and collect its result |
|
||||
| `message_agent` | Send a follow-up message to a running sub-agent |
|
||||
| `close_agent` | Stop a running sub-agent |
|
||||
| `web_search` | Search the internet (provider-native, when enabled) |
|
||||
|
||||
These tools connect to the workspace over the same secure connection used for
|
||||
web terminals and IDE access. No additional ports or services are required in
|
||||
the workspace.
|
||||
|
||||
Platform tools (`list_templates`, `read_template`, `create_workspace`,
|
||||
`start_workspace`, `propose_plan`) and orchestration tools (`spawn_agent`,
|
||||
`wait_agent`, `message_agent`, `close_agent`, `spawn_computer_use_agent`)
|
||||
are only available to root chats. Sub-agents do not have access to these
|
||||
tools and cannot create workspaces or spawn further sub-agents.
|
||||
|
||||
`spawn_computer_use_agent` additionally requires an Anthropic provider and
|
||||
the virtual desktop feature to be enabled by an administrator.
|
||||
`read_skill` and `read_skill_file` are available when the workspace contains
|
||||
skills in its `.agents/skills/` directory.
|
||||
`start_workspace`, `propose_plan`) and orchestration tools (`spawn_agent`)
|
||||
are only available to root chats. Sub-agents do
|
||||
not have access to these tools and cannot create workspaces or spawn further
|
||||
sub-agents.
|
||||
|
||||
## Comparison to Coder Tasks
|
||||
|
||||
|
||||
@@ -1,13 +1,10 @@
|
||||
# Models
|
||||
|
||||
Administrators configure LLM providers and models from the Coder dashboard.
|
||||
Providers, models, and API keys are deployment-wide settings managed by
|
||||
platform teams. Developers select from the set of models that an administrator
|
||||
These are deployment-wide settings — developers do not manage API keys or
|
||||
provider configuration. They select from the set of models that an administrator
|
||||
has enabled.
|
||||
|
||||
Optionally, administrators can allow developers to supply their own API keys
|
||||
for specific providers. See [User API keys](#user-api-keys-byok) below.
|
||||
|
||||
## Providers
|
||||
|
||||
Each LLM provider has a type, an API key, and an optional base URL override.
|
||||
@@ -60,38 +57,6 @@ access to LLM providers. See
|
||||
[Architecture](./architecture.md#no-api-keys-in-workspaces) for details
|
||||
on this security model.
|
||||
|
||||
### Key policy
|
||||
|
||||
Each provider has three policy flags that control how API keys are sourced:
|
||||
|
||||
| Setting | Default | Description |
|
||||
|-------------------------|---------|-----------------------------------------------------------------------------------------------------|
|
||||
| Central API key | On | The provider uses a deployment-managed API key entered by an administrator. |
|
||||
| Allow user API keys | Off | Developers may supply their own API key for this provider. |
|
||||
| Central key as fallback | Off | When user keys are allowed, fall back to the central key if a developer has not set a personal key. |
|
||||
|
||||
At least one credential source must be enabled. These settings appear in the
|
||||
provider configuration form under **Key policy**.
|
||||
|
||||
The interaction between these flags determines whether a provider is available
|
||||
to a given developer:
|
||||
|
||||
| Central key | User keys allowed | Fallback | Developer has key | Result |
|
||||
|-------------|-------------------|----------|-------------------|----------------------|
|
||||
| On | Off | — | — | Uses central key |
|
||||
| Off | On | — | Yes | Uses developer's key |
|
||||
| Off | On | — | No | Unavailable |
|
||||
| On | On | Off | Yes | Uses developer's key |
|
||||
| On | On | Off | No | Unavailable |
|
||||
| On | On | On | Yes | Uses developer's key |
|
||||
| On | On | On | No | Uses central key |
|
||||
|
||||
When a developer's personal key is present, it always takes precedence over
|
||||
the central key. When user keys are required and fallback is disabled,
|
||||
the provider is unavailable to developers who have not saved a personal key —
|
||||
even if a central key exists. This is intentional: it enforces that each
|
||||
developer authenticates with their own credentials.
|
||||
|
||||
## Models
|
||||
|
||||
Each model belongs to a provider and has its own configuration for context limits,
|
||||
@@ -167,11 +132,11 @@ fields appear dynamically in the admin UI when you select a provider.
|
||||
|
||||
#### OpenAI
|
||||
|
||||
| Option | Description |
|
||||
|-----------------------|-------------------------------------------------------------------------------------------|
|
||||
| Reasoning Effort | How much effort the model spends reasoning (`minimal`, `low`, `medium`, `high`, `xhigh`). |
|
||||
| Max Completion Tokens | Cap on completion tokens for reasoning models. |
|
||||
| Parallel Tool Calls | Whether the model can call multiple tools at once. |
|
||||
| Option | Description |
|
||||
|-----------------------|---------------------------------------------------------------------------------------------------|
|
||||
| Reasoning Effort | How much effort the model spends reasoning (`none`, `minimal`, `low`, `medium`, `high`, `xhigh`). |
|
||||
| Max Completion Tokens | Cap on completion tokens for reasoning models. |
|
||||
| Parallel Tool Calls | Whether the model can call multiple tools at once. |
|
||||
|
||||
#### Google
|
||||
|
||||
@@ -182,10 +147,10 @@ fields appear dynamically in the admin UI when you select a provider.
|
||||
|
||||
#### OpenRouter
|
||||
|
||||
| Option | Description |
|
||||
|-------------------|---------------------------------------------------|
|
||||
| Reasoning Enabled | Enable extended reasoning mode. |
|
||||
| Reasoning Effort | Reasoning effort level (`low`, `medium`, `high`). |
|
||||
| Option | Description |
|
||||
|-------------------|-------------------------------------------------------------------------------|
|
||||
| Reasoning Enabled | Enable extended reasoning mode. |
|
||||
| Reasoning Effort | Reasoning effort level (`none`, `minimal`, `low`, `medium`, `high`, `xhigh`). |
|
||||
|
||||
#### Vercel AI Gateway
|
||||
|
||||
@@ -211,49 +176,10 @@ The model selector uses the following precedence to pre-select a model:
|
||||
1. **Admin-designated default** — the model marked with the star icon.
|
||||
1. **First available model** — if no default is set and no history exists.
|
||||
|
||||
Developers cannot add their own providers or models. If no models are
|
||||
configured, the chat interface displays a message directing developers to
|
||||
Developers cannot add their own providers, models, or API keys. If no models
|
||||
are configured, the chat interface displays a message directing developers to
|
||||
contact an administrator.
|
||||
|
||||
## User API keys (BYOK)
|
||||
|
||||
When an administrator enables **Allow user API keys** on a provider,
|
||||
developers can supply their own API key from the Agents settings page.
|
||||
|
||||
### Managing personal API keys
|
||||
|
||||
1. Navigate to the **Agents** page in the Coder dashboard.
|
||||
1. Open **Settings** and select the **API Keys** tab.
|
||||
1. Each provider that allows user keys is listed with a status indicator:
|
||||
- **Key saved** — your personal key is active and will be used for requests.
|
||||
- **Using shared key** — no personal key set, but the central deployment
|
||||
key is available as a fallback.
|
||||
- **No key** — you must add a personal key before you can use this provider.
|
||||
1. Enter your API key and click **Save**.
|
||||
|
||||
Personal API keys are encrypted at rest using the same database encryption
|
||||
as deployment-managed keys. The dashboard never displays a saved key — only
|
||||
whether one is set.
|
||||
|
||||
### How key selection works
|
||||
|
||||
When you start a chat, the control plane resolves which API key to use for
|
||||
each provider:
|
||||
|
||||
1. If you have a personal key for the provider, it is used.
|
||||
1. If you do not have a personal key and central key fallback is enabled,
|
||||
the deployment-managed key is used.
|
||||
1. If you do not have a personal key and fallback is disabled, the provider
|
||||
is unavailable to you. Models from that provider will not appear in the
|
||||
model selector.
|
||||
|
||||
### Removing a personal key
|
||||
|
||||
Click **Remove** on the provider card in the API Keys settings tab. If
|
||||
central key fallback is enabled, subsequent requests will use the shared
|
||||
deployment key. If fallback is disabled, the provider becomes unavailable
|
||||
until you add a new personal key.
|
||||
|
||||
## Using an LLM proxy
|
||||
|
||||
Organizations that route LLM traffic through a centralized proxy — such as
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
# Conversation Data Retention
|
||||
|
||||
Coder Agents automatically cleans up old conversation data to manage database
|
||||
growth. Archived conversations and their associated files are periodically
|
||||
purged based on a configurable retention period.
|
||||
|
||||
## How it works
|
||||
|
||||
A background process runs approximately every 10 minutes to remove expired
|
||||
conversation data. Only archived conversations are eligible for deletion —
|
||||
active (non-archived) conversations are never purged.
|
||||
|
||||
When an archived conversation exceeds the retention period, it is deleted along
|
||||
with its messages, diff statuses, and queued messages via cascade. Orphaned
|
||||
files (not referenced by any active or recently-archived conversation) are also
|
||||
deleted. Both operations run in batches of 1,000 rows per cycle.
|
||||
|
||||
## Configuration
|
||||
|
||||
Navigate to the **Agents** page, open **Settings**, and select the **Behavior**
|
||||
tab to configure the conversation retention period. The default is 30 days. Use the toggle to
|
||||
disable retention entirely.
|
||||
|
||||
The retention period is stored as the `agents_chat_retention_days` key in the
|
||||
`site_configs` table and can also be managed via the API at
|
||||
`/api/experimental/chats/config/retention-days`.
|
||||
|
||||
## What gets deleted
|
||||
|
||||
| Data | Condition | Cascade |
|
||||
|------------------------|------------------------------------------------------------------------------------------------|---------------------------------------------------------------|
|
||||
| Archived conversations | Archived longer than retention period | Messages, diff statuses, queued messages deleted via CASCADE. |
|
||||
| Conversation files | Older than retention period AND not referenced by any active or recently-archived conversation | — |
|
||||
|
||||
## Unarchive safety
|
||||
|
||||
If a user unarchives a conversation whose files were purged, stale file
|
||||
references are automatically cleaned up by FK cascades. The conversation
|
||||
remains usable but previously attached files are no longer available.
|
||||
@@ -11,12 +11,11 @@ This means:
|
||||
- **All agent configuration is admin-level.** Providers, models, system prompts,
|
||||
and tool permissions are set by platform teams from the control plane. These
|
||||
are not user preferences — they are deployment-wide policies.
|
||||
- **Developers never need to configure anything by default.** A developer just
|
||||
describes the work they want done. They do not need to pick a provider or
|
||||
write a system prompt — the platform team has already set all of that up.
|
||||
When a platform team enables user API keys for a provider, developers may
|
||||
optionally supply their own key — but this is an opt-in policy decision, not
|
||||
a requirement.
|
||||
- **Developers never need to configure anything.** A developer just describes
|
||||
the work they want done. They do not need to pick a provider, enter an API
|
||||
key, or write a system prompt — the platform team has already set all of
|
||||
that up. The goal is not to restrict developers, but to make configuration
|
||||
unnecessary for a great experience.
|
||||
- **Enforcement, not defaults.** Settings configured by administrators are
|
||||
enforced server-side. Developers cannot override them. This is a deliberate
|
||||
distinction — a setting that a user can change is a preference, not a policy.
|
||||
@@ -37,12 +36,8 @@ self-hosted models), and per-model parameters like context limits, thinking
|
||||
budgets, and reasoning effort.
|
||||
|
||||
Developers select from the set of models an administrator has enabled. They
|
||||
cannot add their own providers or access models that have not been explicitly
|
||||
configured.
|
||||
|
||||
When an administrator enables user API keys on a provider, developers can
|
||||
supply their own key from the Agents settings page. See
|
||||
[User API keys (BYOK)](../models.md#user-api-keys-byok) for details.
|
||||
cannot add their own providers, supply their own API keys, or access models that
|
||||
have not been explicitly configured.
|
||||
|
||||
See [Models](../models.md) for setup instructions.
|
||||
|
||||
@@ -89,30 +84,6 @@ opt-out, or opt-in for each chat.
|
||||
|
||||
See [MCP Servers](./mcp-servers.md) for configuration details.
|
||||
|
||||
### Virtual desktop
|
||||
|
||||
Administrators can enable a virtual desktop within agent workspaces.
|
||||
When enabled, agents can use `spawn_computer_use_agent` to interact with a
|
||||
desktop environment using screenshots, mouse, and keyboard input.
|
||||
|
||||
This setting is available under **Agents** > **Settings** > **Behavior**.
|
||||
It requires:
|
||||
|
||||
- The [portabledesktop](https://registry.coder.com/modules/coder/portabledesktop)
|
||||
module to be installed in the workspace template.
|
||||
- An Anthropic provider to be configured (computer use is an Anthropic
|
||||
capability).
|
||||
|
||||
### Workspace autostop fallback
|
||||
|
||||
Administrators can set a default autostop timer for agent-created workspaces
|
||||
that do not define one in their template. Template-defined autostop rules always
|
||||
take precedence. Active conversations extend the stop time automatically.
|
||||
|
||||
This setting is available under **Agents** > **Settings** > **Behavior**.
|
||||
The maximum configurable value is 30 days. When disabled, workspaces follow
|
||||
their template's autostop rules (or none, if the template does not define any).
|
||||
|
||||
### Usage limits and analytics
|
||||
|
||||
Administrators can set spend limits to cap LLM usage per user within a rolling
|
||||
@@ -122,19 +93,10 @@ breakdowns.
|
||||
|
||||
See [Usage & Analytics](./usage-insights.md) for details.
|
||||
|
||||
### Data retention
|
||||
|
||||
Administrators can configure a retention period for archived conversations.
|
||||
When enabled, archived conversations and orphaned files older than the
|
||||
retention period are automatically purged. The default is 30 days.
|
||||
|
||||
This setting is available under **Agents** > **Settings** > **Behavior**.
|
||||
See [Data Retention](./chat-retention.md) for details.
|
||||
|
||||
## Where we are headed
|
||||
|
||||
The controls above cover providers, models, system prompts, templates, MCP
|
||||
servers, usage limits, and data retention. We are continuing to invest in platform controls
|
||||
servers, and usage limits. We are continuing to invest in platform controls
|
||||
based on what we hear from customers deploying agents in regulated and
|
||||
enterprise environments.
|
||||
|
||||
|
||||
@@ -83,8 +83,3 @@ Select a user to see:
|
||||
bar shows current spend relative to the limit.
|
||||
- **Per-model breakdown** — table of costs and token usage by model.
|
||||
- **Per-chat breakdown** — table of costs and token usage by chat session.
|
||||
|
||||
> [!NOTE]
|
||||
> Automatic title generation uses lightweight models, such as Claude Haiku or GPT-4o
|
||||
> Mini. Its token usage is not counted towards usage limits or shown in usage
|
||||
> summaries.
|
||||
|
||||
@@ -10,13 +10,6 @@ We provide an example Grafana dashboard that you can import as a starting point
|
||||
|
||||
These logs and metrics can be used to determine usage patterns, track costs, and evaluate tooling adoption.
|
||||
|
||||
## Structured Logging
|
||||
|
||||
AI Bridge can emit structured logs for every interception event to your
|
||||
existing log pipeline. This is useful for exporting data to external SIEM or
|
||||
observability platforms. See [Structured Logging](./setup.md#structured-logging)
|
||||
in the setup guide for configuration and a full list of record types.
|
||||
|
||||
## Exporting Data
|
||||
|
||||
AI Bridge interception data can be exported for external analysis, compliance reporting, or integration with log aggregation systems.
|
||||
|
||||
@@ -150,14 +150,4 @@ ingestion, set `--log-json` to a file path or `/dev/stderr` so that records are
|
||||
emitted as JSON.
|
||||
|
||||
Filter for AI Bridge records in your logging pipeline by matching on the
|
||||
`"interception log"` message. Each log line includes a `record_type` field that
|
||||
indicates the kind of event captured:
|
||||
|
||||
| `record_type` | Description | Key fields |
|
||||
|----------------------|-----------------------------------------|--------------------------------------------------------------------------------|
|
||||
| `interception_start` | A new intercepted request begins. | `interception_id`, `initiator_id`, `provider`, `model`, `client`, `started_at` |
|
||||
| `interception_end` | An intercepted request completes. | `interception_id`, `ended_at` |
|
||||
| `token_usage` | Token consumption for a response. | `interception_id`, `input_tokens`, `output_tokens`, `created_at` |
|
||||
| `prompt_usage` | The last user prompt in a request. | `interception_id`, `prompt`, `created_at` |
|
||||
| `tool_usage` | A tool/function call made by the model. | `interception_id`, `tool`, `input`, `server_url`, `injected`, `created_at` |
|
||||
| `model_thought` | Model reasoning or thinking content. | `interception_id`, `content`, `created_at` |
|
||||
`"interception log"` message.
|
||||
|
||||
@@ -83,9 +83,9 @@ pages.
|
||||
| [2.26](https://coder.com/changelog/coder-2-26) | September 03, 2025 | Not Supported | [v2.26.6](https://github.com/coder/coder/releases/tag/v2.26.6) |
|
||||
| [2.27](https://coder.com/changelog/coder-2-27) | October 02, 2025 | Not Supported | [v2.27.11](https://github.com/coder/coder/releases/tag/v2.27.11) |
|
||||
| [2.28](https://coder.com/changelog/coder-2-28) | November 04, 2025 | Not Supported | [v2.28.11](https://github.com/coder/coder/releases/tag/v2.28.11) |
|
||||
| [2.29](https://coder.com/changelog/coder-2-29) | December 02, 2025 | Extended Support Release | [v2.29.9](https://github.com/coder/coder/releases/tag/v2.29.9) |
|
||||
| [2.30](https://coder.com/changelog/coder-2-30) | February 03, 2026 | Security Support | [v2.30.6](https://github.com/coder/coder/releases/tag/v2.30.6) |
|
||||
| [2.31](https://coder.com/changelog/coder-2-31) | February 23, 2026 | Stable | [v2.31.7](https://github.com/coder/coder/releases/tag/v2.31.7) |
|
||||
| [2.29](https://coder.com/changelog/coder-2-29) | December 02, 2025 | Security Support + ESR | [v2.29.8](https://github.com/coder/coder/releases/tag/v2.29.8) |
|
||||
| [2.30](https://coder.com/changelog/coder-2-30) | February 03, 2026 | Stable | [v2.30.3](https://github.com/coder/coder/releases/tag/v2.30.3) |
|
||||
| [2.31](https://coder.com/changelog/coder-2-31) | February 23, 2026 | Mainline | [v2.31.5](https://github.com/coder/coder/releases/tag/v2.31.5) |
|
||||
| 2.32 | | Not Released | N/A |
|
||||
<!-- RELEASE_CALENDAR_END -->
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user