Compare commits
91 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 72960aeb77 | |||
| a62ead8588 | |||
| b68c14dd04 | |||
| 508114d484 | |||
| e0fbb0e4ec | |||
| 7bde763b66 | |||
| 36141fafad | |||
| 3462c31f43 | |||
| a0ea71b74c | |||
| 0a14bb529e | |||
| 2c32d84f12 | |||
| 76d89f59af | |||
| 1a3a92bd1b | |||
| 4018320614 | |||
| d9700baa8d | |||
| 82456ff62e | |||
| 83fd4cf5c2 | |||
| 38d4da82b9 | |||
| 19e0e0e8e6 | |||
| 1d0653cdab | |||
| 95cff8c5fb | |||
| ad2415ede7 | |||
| 1e40cea199 | |||
| 9d6557d173 | |||
| 224db483d7 | |||
| 8237822441 | |||
| 65bf7c3b18 | |||
| 76cbc580f0 | |||
| 391b22aef7 | |||
| f8e8f979a2 | |||
| fb0ed1162b | |||
| 3f519744aa | |||
| 2505f6245f | |||
| 29ad2c6201 | |||
| 27e5ff0a8e | |||
| 128a7c23e6 | |||
| efb19eb748 | |||
| 2c499484b7 | |||
| 33d9d0d875 | |||
| f219834f5c | |||
| 7a94a683c4 | |||
| 2e6fdf2344 | |||
| 3d139c1a24 | |||
| f957981c8b | |||
| 584c61acb5 | |||
| f95a5202bf | |||
| d954460380 | |||
| f4240bb8c1 | |||
| 7caef4987f | |||
| 9b91af8ab7 | |||
| 506fba9ebf | |||
| 461a31e5d8 | |||
| e3a0dcd6fc | |||
| 12ada0115f | |||
| 7b0421d8c6 | |||
| 477d6d0cde | |||
| de61ac529d | |||
| 7f496c2f18 | |||
| 590235138f | |||
| 543c448b72 | |||
| 35c26ce22a | |||
| c2592c9f12 | |||
| b969d66978 | |||
| 1f808cdc62 | |||
| 497f637f58 | |||
| be686a8d0d | |||
| 7b7baea851 | |||
| a3de0fc78d | |||
| ab77154975 | |||
| c5d720f73d | |||
| 983819860f | |||
| f820945d9f | |||
| da5395a8ae | |||
| 86b919e4f7 | |||
| 233343c010 | |||
| 3a612898c6 | |||
| 3f7a3e3354 | |||
| 17a71aea72 | |||
| 7d3c5ac78c | |||
| d87c5ef439 | |||
| ef3e17317c | |||
| 1187b84c54 | |||
| 45336bd9ce | |||
| 36cf7debce | |||
| 027c222e82 | |||
| d00f148b76 | |||
| 48bc215f20 | |||
| 08bd9e672a | |||
| c5f1a2fccf | |||
| 655d647d40 | |||
| f3f0a2c553 |
@@ -1,2 +0,0 @@
|
||||
enabled: true
|
||||
preservePullRequestTitle: true
|
||||
@@ -0,0 +1,178 @@
|
||||
# Automatically backport merged PRs to the last N release branches when the
|
||||
# "backport" label is applied. Works whether the label is added before or
|
||||
# after the PR is merged.
|
||||
#
|
||||
# Usage:
|
||||
# 1. Add the "backport" label to a PR targeting main.
|
||||
# 2. When the PR merges (or if already merged), the workflow detects the
|
||||
# latest release/* branches and opens one cherry-pick PR per branch.
|
||||
#
|
||||
# The created backport PRs follow existing repo conventions:
|
||||
# - Branch: backport/<pr>-to-<version>
|
||||
# - Title: <original PR title> (#<pr>)
|
||||
# - Body: links back to the original PR and merge commit
|
||||
|
||||
name: Backport
|
||||
on:
|
||||
pull_request_target:
|
||||
branches:
|
||||
- main
|
||||
types:
|
||||
- closed
|
||||
- labeled
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
|
||||
# Prevent duplicate runs for the same PR when both 'closed' and 'labeled'
|
||||
# fire in quick succession.
|
||||
concurrency:
|
||||
group: backport-${{ github.event.pull_request.number }}
|
||||
|
||||
jobs:
|
||||
detect:
|
||||
name: Detect target branches
|
||||
if: >
|
||||
github.event.pull_request.merged == true &&
|
||||
contains(github.event.pull_request.labels.*.name, 'backport')
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
branches: ${{ steps.find.outputs.branches }}
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
# Need all refs to discover release branches.
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Find latest release branches
|
||||
id: find
|
||||
run: |
|
||||
# List remote release branches matching the exact release/2.X
|
||||
# pattern (no suffixes like release/2.31_hotfix), sort by minor
|
||||
# version descending, and take the top 3.
|
||||
BRANCHES=$(
|
||||
git branch -r \
|
||||
| grep -E '^\s*origin/release/2\.[0-9]+$' \
|
||||
| sed 's|.*origin/||' \
|
||||
| sort -t. -k2 -n -r \
|
||||
| head -3
|
||||
)
|
||||
|
||||
if [ -z "$BRANCHES" ]; then
|
||||
echo "No release branches found."
|
||||
echo "branches=[]" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Convert to JSON array for the matrix.
|
||||
JSON=$(echo "$BRANCHES" | jq -Rnc '[inputs | select(length > 0)]')
|
||||
echo "branches=$JSON" >> "$GITHUB_OUTPUT"
|
||||
echo "Will backport to: $JSON"
|
||||
|
||||
backport:
|
||||
name: "Backport to ${{ matrix.branch }}"
|
||||
needs: detect
|
||||
if: needs.detect.outputs.branches != '[]'
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
branch: ${{ fromJson(needs.detect.outputs.branches) }}
|
||||
fail-fast: false
|
||||
env:
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
PR_TITLE: ${{ github.event.pull_request.title }}
|
||||
PR_URL: ${{ github.event.pull_request.html_url }}
|
||||
MERGE_SHA: ${{ github.event.pull_request.merge_commit_sha }}
|
||||
SENDER: ${{ github.event.sender.login }}
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
# Full history required for cherry-pick.
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Cherry-pick and open PR
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
RELEASE_VERSION="${{ matrix.branch }}"
|
||||
# Strip the release/ prefix for naming.
|
||||
VERSION="${RELEASE_VERSION#release/}"
|
||||
BACKPORT_BRANCH="backport/${PR_NUMBER}-to-${VERSION}"
|
||||
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
|
||||
|
||||
# Check if backport branch already exists (idempotency for re-runs).
|
||||
if git ls-remote --exit-code origin "refs/heads/${BACKPORT_BRANCH}" >/dev/null 2>&1; then
|
||||
echo "Backport branch ${BACKPORT_BRANCH} already exists, skipping."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Create the backport branch from the target release branch.
|
||||
git checkout -b "$BACKPORT_BRANCH" "origin/${RELEASE_VERSION}"
|
||||
|
||||
# Cherry-pick the merge commit. Use -x to record provenance and
|
||||
# -m1 to pick the first parent (the main branch side).
|
||||
CONFLICTS=false
|
||||
if ! git cherry-pick -x -m1 "$MERGE_SHA"; then
|
||||
echo "::warning::Cherry-pick to ${RELEASE_VERSION} had conflicts."
|
||||
CONFLICTS=true
|
||||
|
||||
# Abort the failed cherry-pick and create an empty commit
|
||||
# explaining the situation.
|
||||
git cherry-pick --abort
|
||||
git commit --allow-empty -m "Cherry-pick of #${PR_NUMBER} requires manual resolution
|
||||
|
||||
The automatic cherry-pick of ${MERGE_SHA} to ${RELEASE_VERSION} had conflicts.
|
||||
Please cherry-pick manually:
|
||||
|
||||
git cherry-pick -x -m1 ${MERGE_SHA}"
|
||||
fi
|
||||
|
||||
git push origin "$BACKPORT_BRANCH"
|
||||
|
||||
TITLE="${PR_TITLE} (#${PR_NUMBER})"
|
||||
BODY=$(cat <<EOF
|
||||
Backport of ${PR_URL}
|
||||
|
||||
Original PR: #${PR_NUMBER} — ${PR_TITLE}
|
||||
Merge commit: ${MERGE_SHA}
|
||||
Requested by: @${SENDER}
|
||||
EOF
|
||||
)
|
||||
|
||||
if [ "$CONFLICTS" = true ]; then
|
||||
TITLE="${TITLE} (conflicts)"
|
||||
BODY="${BODY}
|
||||
|
||||
> [!WARNING]
|
||||
> The automatic cherry-pick had conflicts.
|
||||
> Please resolve manually by cherry-picking the original merge commit:
|
||||
>
|
||||
> \`\`\`
|
||||
> git fetch origin ${BACKPORT_BRANCH}
|
||||
> git checkout ${BACKPORT_BRANCH}
|
||||
> git reset --hard origin/${RELEASE_VERSION}
|
||||
> git cherry-pick -x -m1 ${MERGE_SHA}
|
||||
> # resolve conflicts, then push
|
||||
> \`\`\`"
|
||||
fi
|
||||
|
||||
# Check if a PR already exists for this branch (idempotency
|
||||
# for re-runs).
|
||||
EXISTING_PR=$(gh pr list --head "$BACKPORT_BRANCH" --base "$RELEASE_VERSION" --state all --json number --jq '.[0].number // empty')
|
||||
if [ -n "$EXISTING_PR" ]; then
|
||||
echo "PR #${EXISTING_PR} already exists for ${BACKPORT_BRANCH}, skipping."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
gh pr create \
|
||||
--base "$RELEASE_VERSION" \
|
||||
--head "$BACKPORT_BRANCH" \
|
||||
--title "$TITLE" \
|
||||
--body "$BODY" \
|
||||
--assignee "$SENDER" \
|
||||
--reviewer "$SENDER"
|
||||
@@ -0,0 +1,152 @@
|
||||
# Automatically cherry-pick merged PRs to the latest release branch when the
|
||||
# "cherry-pick" label is applied. Works whether the label is added before or
|
||||
# after the PR is merged.
|
||||
#
|
||||
# Usage:
|
||||
# 1. Add the "cherry-pick" label to a PR targeting main.
|
||||
# 2. When the PR merges (or if already merged), the workflow detects the
|
||||
# latest release/* branch and opens a cherry-pick PR against it.
|
||||
#
|
||||
# The created PRs follow existing repo conventions:
|
||||
# - Branch: backport/<pr>-to-<version>
|
||||
# - Title: <original PR title> (#<pr>)
|
||||
# - Body: links back to the original PR and merge commit
|
||||
|
||||
name: Cherry-pick to release
|
||||
on:
|
||||
pull_request_target:
|
||||
branches:
|
||||
- main
|
||||
types:
|
||||
- closed
|
||||
- labeled
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
|
||||
# Prevent duplicate runs for the same PR when both 'closed' and 'labeled'
|
||||
# fire in quick succession.
|
||||
concurrency:
|
||||
group: cherry-pick-${{ github.event.pull_request.number }}
|
||||
|
||||
jobs:
|
||||
cherry-pick:
|
||||
name: Cherry-pick to latest release
|
||||
if: >
|
||||
github.event.pull_request.merged == true &&
|
||||
contains(github.event.pull_request.labels.*.name, 'cherry-pick')
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
PR_TITLE: ${{ github.event.pull_request.title }}
|
||||
PR_URL: ${{ github.event.pull_request.html_url }}
|
||||
MERGE_SHA: ${{ github.event.pull_request.merge_commit_sha }}
|
||||
SENDER: ${{ github.event.sender.login }}
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
# Full history required for cherry-pick and branch discovery.
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Cherry-pick and open PR
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
# Find the latest release branch matching the exact release/2.X
|
||||
# pattern (no suffixes like release/2.31_hotfix).
|
||||
RELEASE_BRANCH=$(
|
||||
git branch -r \
|
||||
| grep -E '^\s*origin/release/2\.[0-9]+$' \
|
||||
| sed 's|.*origin/||' \
|
||||
| sort -t. -k2 -n -r \
|
||||
| head -1
|
||||
)
|
||||
|
||||
if [ -z "$RELEASE_BRANCH" ]; then
|
||||
echo "::error::No release branch found."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Strip the release/ prefix for naming.
|
||||
VERSION="${RELEASE_BRANCH#release/}"
|
||||
BACKPORT_BRANCH="backport/${PR_NUMBER}-to-${VERSION}"
|
||||
|
||||
echo "Target branch: $RELEASE_BRANCH"
|
||||
echo "Backport branch: $BACKPORT_BRANCH"
|
||||
|
||||
# Check if backport branch already exists (idempotency for re-runs).
|
||||
if git ls-remote --exit-code origin "refs/heads/${BACKPORT_BRANCH}" >/dev/null 2>&1; then
|
||||
echo "Branch ${BACKPORT_BRANCH} already exists, skipping."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
|
||||
|
||||
# Create the backport branch from the target release branch.
|
||||
git checkout -b "$BACKPORT_BRANCH" "origin/${RELEASE_BRANCH}"
|
||||
|
||||
# Cherry-pick the merge commit. Use -x to record provenance and
|
||||
# -m1 to pick the first parent (the main branch side).
|
||||
CONFLICT=false
|
||||
if ! git cherry-pick -x -m1 "$MERGE_SHA"; then
|
||||
CONFLICT=true
|
||||
echo "::warning::Cherry-pick to ${RELEASE_BRANCH} had conflicts."
|
||||
|
||||
# Abort the failed cherry-pick and create an empty commit with
|
||||
# instructions so the PR can still be opened.
|
||||
git cherry-pick --abort
|
||||
git commit --allow-empty -m "cherry-pick of #${PR_NUMBER} failed — resolve conflicts manually
|
||||
|
||||
Cherry-pick of ${MERGE_SHA} onto ${RELEASE_BRANCH} had conflicts.
|
||||
To resolve:
|
||||
git fetch origin ${BACKPORT_BRANCH}
|
||||
git checkout ${BACKPORT_BRANCH}
|
||||
git cherry-pick -x -m1 ${MERGE_SHA}
|
||||
# resolve conflicts
|
||||
git push origin ${BACKPORT_BRANCH}"
|
||||
fi
|
||||
|
||||
git push origin "$BACKPORT_BRANCH"
|
||||
|
||||
BODY=$(cat <<EOF
|
||||
Cherry-pick of ${PR_URL}
|
||||
|
||||
Original PR: #${PR_NUMBER} — ${PR_TITLE}
|
||||
Merge commit: ${MERGE_SHA}
|
||||
Requested by: @${SENDER}
|
||||
EOF
|
||||
)
|
||||
|
||||
TITLE="${PR_TITLE} (#${PR_NUMBER})"
|
||||
if [ "$CONFLICT" = true ]; then
|
||||
TITLE="[CONFLICT] ${TITLE}"
|
||||
fi
|
||||
|
||||
# Check if a PR already exists for this branch (idempotency
|
||||
# for re-runs). Use --state all to catch closed/merged PRs too.
|
||||
EXISTING_PR=$(gh pr list --head "$BACKPORT_BRANCH" --base "$RELEASE_BRANCH" --state all --json number --jq '.[0].number // empty')
|
||||
if [ -n "$EXISTING_PR" ]; then
|
||||
echo "PR #${EXISTING_PR} already exists for ${BACKPORT_BRANCH}, skipping."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
NEW_PR_URL=$(
|
||||
gh pr create \
|
||||
--base "$RELEASE_BRANCH" \
|
||||
--head "$BACKPORT_BRANCH" \
|
||||
--title "$TITLE" \
|
||||
--body "$BODY" \
|
||||
--assignee "$SENDER" \
|
||||
--reviewer "$SENDER"
|
||||
)
|
||||
|
||||
# Comment on the original PR to notify the author.
|
||||
COMMENT="Cherry-pick PR created: ${NEW_PR_URL}"
|
||||
if [ "$CONFLICT" = true ]; then
|
||||
COMMENT="${COMMENT} (⚠️ conflicts need manual resolution)"
|
||||
fi
|
||||
gh pr comment "$PR_NUMBER" --body "$COMMENT"
|
||||
@@ -0,0 +1,93 @@
|
||||
# Ensures that only bug fixes are cherry-picked to release branches.
|
||||
# PRs targeting release/* must have a title starting with "fix:" or "fix(scope):".
|
||||
name: PR Cherry-Pick Check
|
||||
|
||||
on:
|
||||
# zizmor: ignore[dangerous-triggers] Only reads PR metadata and comments; does not checkout PR code.
|
||||
pull_request_target:
|
||||
types: [opened, reopened, edited]
|
||||
branches:
|
||||
- "release/*"
|
||||
|
||||
permissions:
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
check-cherry-pick:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Harden Runner
|
||||
uses: step-security/harden-runner@fe104658747b27e96e4f7e80cd0a94068e53901d # v2.16.1
|
||||
with:
|
||||
egress-policy: audit
|
||||
|
||||
- name: Check PR title for bug fix
|
||||
uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1
|
||||
with:
|
||||
script: |
|
||||
const title = context.payload.pull_request.title;
|
||||
const prNumber = context.payload.pull_request.number;
|
||||
const baseBranch = context.payload.pull_request.base.ref;
|
||||
const author = context.payload.pull_request.user.login;
|
||||
|
||||
console.log(`PR #${prNumber}: "${title}" -> ${baseBranch}`);
|
||||
|
||||
// Match conventional commit "fix:" or "fix(scope):" prefix.
|
||||
const isBugFix = /^fix(\(.+\))?:/.test(title);
|
||||
|
||||
if (isBugFix) {
|
||||
console.log("PR title indicates a bug fix. No action needed.");
|
||||
return;
|
||||
}
|
||||
|
||||
console.log("PR title does not indicate a bug fix. Commenting.");
|
||||
|
||||
// Check for an existing comment from this bot to avoid duplicates
|
||||
// on title edits.
|
||||
const { data: comments } = await github.rest.issues.listComments({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: prNumber,
|
||||
});
|
||||
|
||||
const marker = "<!-- cherry-pick-check -->";
|
||||
const existingComment = comments.find(
|
||||
(c) => c.body && c.body.includes(marker),
|
||||
);
|
||||
|
||||
const body = [
|
||||
marker,
|
||||
`👋 Hey @${author}!`,
|
||||
"",
|
||||
`This PR is targeting the \`${baseBranch}\` release branch, but its title does not start with \`fix:\` or \`fix(scope):\`.`,
|
||||
"",
|
||||
"Only **bug fixes** should be cherry-picked to release branches. If this is a bug fix, please update the PR title to match the conventional commit format:",
|
||||
"",
|
||||
"```",
|
||||
"fix: description of the bug fix",
|
||||
"fix(scope): description of the bug fix",
|
||||
"```",
|
||||
"",
|
||||
"If this is **not** a bug fix, it likely should not target a release branch.",
|
||||
].join("\n");
|
||||
|
||||
if (existingComment) {
|
||||
console.log(`Updating existing comment ${existingComment.id}.`);
|
||||
await github.rest.issues.updateComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
comment_id: existingComment.id,
|
||||
body,
|
||||
});
|
||||
} else {
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: prNumber,
|
||||
body,
|
||||
});
|
||||
}
|
||||
|
||||
core.warning(
|
||||
`PR #${prNumber} targets ${baseBranch} but is not a bug fix. Title must start with "fix:" or "fix(scope):".`,
|
||||
);
|
||||
@@ -121,22 +121,22 @@ jobs:
|
||||
fi
|
||||
|
||||
# Derive the release branch from the version tag.
|
||||
# Standard: 2.10.2 -> release/2.10
|
||||
# RC: 2.32.0-rc.0 -> release/2.32-rc.0
|
||||
# Non-RC releases must be on a release/X.Y branch.
|
||||
# RC tags are allowed on any branch (typically main).
|
||||
version="$(./scripts/version.sh)"
|
||||
# Strip any pre-release suffix first (e.g. 2.32.0-rc.0 -> 2.32.0)
|
||||
base_version="${version%%-*}"
|
||||
# Then strip patch to get major.minor (e.g. 2.32.0 -> 2.32)
|
||||
release_branch="release/${base_version%.*}"
|
||||
|
||||
if [[ "$version" == *-rc.* ]]; then
|
||||
# Extract major.minor and rc suffix from e.g. 2.32.0-rc.0
|
||||
base_version="${version%%-rc.*}" # 2.32.0
|
||||
major_minor="${base_version%.*}" # 2.32
|
||||
rc_suffix="${version##*-rc.}" # 0
|
||||
release_branch="release/${major_minor}-rc.${rc_suffix}"
|
||||
echo "RC release detected — skipping release branch check (RC tags are cut from main)."
|
||||
else
|
||||
release_branch=release/${version%.*}
|
||||
fi
|
||||
branch_contains_tag=$(git branch --remotes --contains "${GITHUB_REF}" --list "*/${release_branch}" --format='%(refname)')
|
||||
if [[ -z "${branch_contains_tag}" ]]; then
|
||||
echo "Ref tag must exist in a branch named ${release_branch} when creating a release, did you use scripts/release.sh?"
|
||||
exit 1
|
||||
branch_contains_tag=$(git branch --remotes --contains "${GITHUB_REF}" --list "*/${release_branch}" --format='%(refname)')
|
||||
if [[ -z "${branch_contains_tag}" ]]; then
|
||||
echo "Ref tag must exist in a branch named ${release_branch} when creating a non-RC release, did you use scripts/release.sh?"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
if [[ -z "${CODER_RELEASE_NOTES}" ]]; then
|
||||
|
||||
@@ -36,6 +36,7 @@ typ = "typ"
|
||||
styl = "styl"
|
||||
edn = "edn"
|
||||
Inferrable = "Inferrable"
|
||||
IIF = "IIF"
|
||||
|
||||
[files]
|
||||
extend-exclude = [
|
||||
|
||||
@@ -103,3 +103,6 @@ PLAN.md
|
||||
|
||||
# Ignore any dev licenses
|
||||
license.txt
|
||||
-e
|
||||
# Agent planning documents (local working files).
|
||||
docs/plans/
|
||||
|
||||
@@ -91,6 +91,59 @@ define atomic_write
|
||||
mv "$$tmpfile" "$@" && rm -rf "$$tmpdir"
|
||||
endef
|
||||
|
||||
# Helper binary targets. Built with go build -o to avoid caching
|
||||
# link-stage executables in GOCACHE. Each binary is a real Make
|
||||
# target so parallel -j builds serialize correctly instead of
|
||||
# racing on the same output path.
|
||||
|
||||
_gen/bin/apitypings: $(wildcard scripts/apitypings/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/apitypings
|
||||
|
||||
_gen/bin/auditdocgen: $(wildcard scripts/auditdocgen/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/auditdocgen
|
||||
|
||||
_gen/bin/check-scopes: $(wildcard scripts/check-scopes/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/check-scopes
|
||||
|
||||
_gen/bin/clidocgen: $(wildcard scripts/clidocgen/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/clidocgen
|
||||
|
||||
_gen/bin/dbdump: $(wildcard coderd/database/gen/dump/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./coderd/database/gen/dump
|
||||
|
||||
_gen/bin/examplegen: $(wildcard scripts/examplegen/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/examplegen
|
||||
|
||||
_gen/bin/gensite: $(wildcard scripts/gensite/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/gensite
|
||||
|
||||
_gen/bin/apikeyscopesgen: $(wildcard scripts/apikeyscopesgen/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/apikeyscopesgen
|
||||
|
||||
_gen/bin/metricsdocgen: $(wildcard scripts/metricsdocgen/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/metricsdocgen
|
||||
|
||||
_gen/bin/metricsdocgen-scanner: $(wildcard scripts/metricsdocgen/scanner/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/metricsdocgen/scanner
|
||||
|
||||
_gen/bin/modeloptionsgen: $(wildcard scripts/modeloptionsgen/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/modeloptionsgen
|
||||
|
||||
_gen/bin/typegen: $(wildcard scripts/typegen/*.go) | _gen
|
||||
@mkdir -p _gen/bin
|
||||
go build -o $@ ./scripts/typegen
|
||||
|
||||
# Shared temp directory for atomic writes. Lives at the project root
|
||||
# so all targets share the same filesystem, and is gitignored.
|
||||
# Order-only prerequisite: recipes that need it depend on | _gen
|
||||
@@ -201,6 +254,7 @@ endif
|
||||
|
||||
clean:
|
||||
rm -rf build/ site/build/ site/out/
|
||||
rm -rf _gen/bin
|
||||
mkdir -p build/
|
||||
git restore site/out/
|
||||
.PHONY: clean
|
||||
@@ -654,8 +708,8 @@ lint/go:
|
||||
go tool github.com/coder/paralleltestctx/cmd/paralleltestctx -custom-funcs="testutil.Context" ./...
|
||||
.PHONY: lint/go
|
||||
|
||||
lint/examples:
|
||||
go run ./scripts/examplegen/main.go -lint
|
||||
lint/examples: | _gen/bin/examplegen
|
||||
_gen/bin/examplegen -lint
|
||||
.PHONY: lint/examples
|
||||
|
||||
# Use shfmt to determine the shell files, takes editorconfig into consideration.
|
||||
@@ -693,8 +747,8 @@ lint/actions/zizmor:
|
||||
.PHONY: lint/actions/zizmor
|
||||
|
||||
# Verify api_key_scope enum contains all RBAC <resource>:<action> values.
|
||||
lint/check-scopes: coderd/database/dump.sql
|
||||
go run ./scripts/check-scopes
|
||||
lint/check-scopes: coderd/database/dump.sql | _gen/bin/check-scopes
|
||||
_gen/bin/check-scopes
|
||||
.PHONY: lint/check-scopes
|
||||
|
||||
# Verify migrations do not hardcode the public schema.
|
||||
@@ -734,8 +788,8 @@ lint/typos: build/typos-$(TYPOS_VERSION)
|
||||
# The pre-push hook is allowlisted, see scripts/githooks/pre-push.
|
||||
#
|
||||
# pre-commit uses two phases: gen+fmt first, then lint+build. This
|
||||
# avoids races where gen's `go run` creates temporary .go files that
|
||||
# lint's find-based checks pick up. Within each phase, targets run in
|
||||
# avoids races where gen creates temporary .go files that lint's
|
||||
# find-based checks pick up. Within each phase, targets run in
|
||||
# parallel via -j. It fails if any tracked files have unstaged
|
||||
# changes afterward.
|
||||
|
||||
@@ -949,8 +1003,8 @@ gen/mark-fresh:
|
||||
|
||||
# Runs migrations to output a dump of the database schema after migrations are
|
||||
# applied.
|
||||
coderd/database/dump.sql: coderd/database/gen/dump/main.go $(wildcard coderd/database/migrations/*.sql)
|
||||
go run ./coderd/database/gen/dump/main.go
|
||||
coderd/database/dump.sql: coderd/database/gen/dump/main.go $(wildcard coderd/database/migrations/*.sql) | _gen/bin/dbdump
|
||||
_gen/bin/dbdump
|
||||
touch "$@"
|
||||
|
||||
# Generates Go code for querying the database.
|
||||
@@ -1067,88 +1121,88 @@ enterprise/aibridged/proto/aibridged.pb.go: enterprise/aibridged/proto/aibridged
|
||||
--go-drpc_opt=paths=source_relative \
|
||||
./enterprise/aibridged/proto/aibridged.proto
|
||||
|
||||
site/src/api/typesGenerated.ts: site/node_modules/.installed $(wildcard scripts/apitypings/*) $(shell find ./codersdk $(FIND_EXCLUSIONS) -type f -name '*.go') | _gen
|
||||
$(call atomic_write,go run -C ./scripts/apitypings main.go,./scripts/biome_format.sh)
|
||||
site/src/api/typesGenerated.ts: site/node_modules/.installed $(wildcard scripts/apitypings/*) $(shell find ./codersdk $(FIND_EXCLUSIONS) -type f -name '*.go') | _gen _gen/bin/apitypings
|
||||
$(call atomic_write,_gen/bin/apitypings,./scripts/biome_format.sh)
|
||||
|
||||
site/e2e/provisionerGenerated.ts: site/node_modules/.installed provisionerd/proto/provisionerd.pb.go provisionersdk/proto/provisioner.pb.go
|
||||
(cd site/ && pnpm run gen:provisioner)
|
||||
touch "$@"
|
||||
|
||||
site/src/theme/icons.json: site/node_modules/.installed $(wildcard scripts/gensite/*) $(wildcard site/static/icon/*) | _gen
|
||||
site/src/theme/icons.json: site/node_modules/.installed $(wildcard scripts/gensite/*) $(wildcard site/static/icon/*) | _gen _gen/bin/gensite
|
||||
tmpdir=$$(mktemp -d -p _gen) && tmpfile=$$(realpath "$$tmpdir")/$(notdir $@) && \
|
||||
go run ./scripts/gensite/ -icons "$$tmpfile" && \
|
||||
_gen/bin/gensite -icons "$$tmpfile" && \
|
||||
./scripts/biome_format.sh "$$tmpfile" && \
|
||||
mv "$$tmpfile" "$@" && rm -rf "$$tmpdir"
|
||||
|
||||
examples/examples.gen.json: scripts/examplegen/main.go examples/examples.go $(shell find ./examples/templates) | _gen
|
||||
$(call atomic_write,go run ./scripts/examplegen/main.go)
|
||||
examples/examples.gen.json: scripts/examplegen/main.go examples/examples.go $(shell find ./examples/templates) | _gen _gen/bin/examplegen
|
||||
$(call atomic_write,_gen/bin/examplegen)
|
||||
|
||||
coderd/rbac/object_gen.go: scripts/typegen/rbacobject.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go | _gen
|
||||
$(call atomic_write,go run ./scripts/typegen/main.go rbac object)
|
||||
coderd/rbac/object_gen.go: scripts/typegen/rbacobject.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go | _gen _gen/bin/typegen
|
||||
$(call atomic_write,_gen/bin/typegen rbac object)
|
||||
touch "$@"
|
||||
|
||||
# NOTE: depends on object_gen.go because `go run` compiles
|
||||
# coderd/rbac which includes it.
|
||||
# NOTE: depends on object_gen.go because the generator build
|
||||
# compiles coderd/rbac which includes it.
|
||||
coderd/rbac/scopes_constants_gen.go: scripts/typegen/scopenames.gotmpl scripts/typegen/main.go coderd/rbac/policy/policy.go \
|
||||
coderd/rbac/object_gen.go | _gen
|
||||
coderd/rbac/object_gen.go | _gen _gen/bin/typegen
|
||||
# Write to a temp file first to avoid truncating the package
|
||||
# during build since the generator imports the rbac package.
|
||||
$(call atomic_write,go run ./scripts/typegen/main.go rbac scopenames)
|
||||
$(call atomic_write,_gen/bin/typegen rbac scopenames)
|
||||
touch "$@"
|
||||
|
||||
# NOTE: depends on object_gen.go and scopes_constants_gen.go because
|
||||
# `go run` compiles coderd/rbac which includes both.
|
||||
# the generator build compiles coderd/rbac which includes both.
|
||||
codersdk/rbacresources_gen.go: scripts/typegen/codersdk.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go \
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen _gen/bin/typegen
|
||||
# Write to a temp file to avoid truncating the target, which
|
||||
# would break the codersdk package and any parallel build targets.
|
||||
$(call atomic_write,go run scripts/typegen/main.go rbac codersdk)
|
||||
$(call atomic_write,_gen/bin/typegen rbac codersdk)
|
||||
touch "$@"
|
||||
|
||||
# NOTE: depends on object_gen.go and scopes_constants_gen.go because
|
||||
# `go run` compiles coderd/rbac which includes both.
|
||||
# the generator build compiles coderd/rbac which includes both.
|
||||
codersdk/apikey_scopes_gen.go: scripts/apikeyscopesgen/main.go coderd/rbac/scopes_catalog.go coderd/rbac/scopes.go \
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen _gen/bin/apikeyscopesgen
|
||||
# Generate SDK constants for external API key scopes.
|
||||
$(call atomic_write,go run ./scripts/apikeyscopesgen)
|
||||
$(call atomic_write,_gen/bin/apikeyscopesgen)
|
||||
touch "$@"
|
||||
|
||||
# NOTE: depends on object_gen.go and scopes_constants_gen.go because
|
||||
# `go run` compiles coderd/rbac which includes both.
|
||||
# the generator build compiles coderd/rbac which includes both.
|
||||
site/src/api/rbacresourcesGenerated.ts: site/node_modules/.installed scripts/typegen/codersdk.gotmpl scripts/typegen/main.go coderd/rbac/object.go coderd/rbac/policy/policy.go \
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen
|
||||
$(call atomic_write,go run scripts/typegen/main.go rbac typescript,./scripts/biome_format.sh)
|
||||
coderd/rbac/object_gen.go coderd/rbac/scopes_constants_gen.go | _gen _gen/bin/typegen
|
||||
$(call atomic_write,_gen/bin/typegen rbac typescript,./scripts/biome_format.sh)
|
||||
|
||||
site/src/api/countriesGenerated.ts: site/node_modules/.installed scripts/typegen/countries.tstmpl scripts/typegen/main.go codersdk/countries.go | _gen
|
||||
$(call atomic_write,go run scripts/typegen/main.go countries,./scripts/biome_format.sh)
|
||||
site/src/api/countriesGenerated.ts: site/node_modules/.installed scripts/typegen/countries.tstmpl scripts/typegen/main.go codersdk/countries.go | _gen _gen/bin/typegen
|
||||
$(call atomic_write,_gen/bin/typegen countries,./scripts/biome_format.sh)
|
||||
|
||||
site/src/api/chatModelOptionsGenerated.json: scripts/modeloptionsgen/main.go codersdk/chats.go | _gen
|
||||
$(call atomic_write,go run ./scripts/modeloptionsgen/main.go | tail -n +2,./scripts/biome_format.sh)
|
||||
site/src/api/chatModelOptionsGenerated.json: scripts/modeloptionsgen/main.go codersdk/chats.go | _gen _gen/bin/modeloptionsgen
|
||||
$(call atomic_write,_gen/bin/modeloptionsgen | tail -n +2,./scripts/biome_format.sh)
|
||||
|
||||
scripts/metricsdocgen/generated_metrics: $(GO_SRC_FILES) | _gen
|
||||
$(call atomic_write,go run ./scripts/metricsdocgen/scanner)
|
||||
scripts/metricsdocgen/generated_metrics: $(GO_SRC_FILES) | _gen _gen/bin/metricsdocgen-scanner
|
||||
$(call atomic_write,_gen/bin/metricsdocgen-scanner)
|
||||
|
||||
docs/admin/integrations/prometheus.md: node_modules/.installed scripts/metricsdocgen/main.go scripts/metricsdocgen/metrics scripts/metricsdocgen/generated_metrics | _gen
|
||||
docs/admin/integrations/prometheus.md: node_modules/.installed scripts/metricsdocgen/main.go scripts/metricsdocgen/metrics scripts/metricsdocgen/generated_metrics | _gen _gen/bin/metricsdocgen
|
||||
tmpdir=$$(mktemp -d -p _gen) && tmpfile=$$(realpath "$$tmpdir")/$(notdir $@) && cp "$@" "$$tmpfile" && \
|
||||
go run scripts/metricsdocgen/main.go --prometheus-doc-file="$$tmpfile" && \
|
||||
_gen/bin/metricsdocgen --prometheus-doc-file="$$tmpfile" && \
|
||||
pnpm exec markdownlint-cli2 --fix "$$tmpfile" && \
|
||||
pnpm exec markdown-table-formatter "$$tmpfile" && \
|
||||
mv "$$tmpfile" "$@" && rm -rf "$$tmpdir"
|
||||
|
||||
docs/reference/cli/index.md: node_modules/.installed scripts/clidocgen/main.go examples/examples.gen.json $(GO_SRC_FILES) | _gen
|
||||
docs/reference/cli/index.md: node_modules/.installed scripts/clidocgen/main.go examples/examples.gen.json $(GO_SRC_FILES) | _gen _gen/bin/clidocgen
|
||||
tmpdir=$$(mktemp -d -p _gen) && \
|
||||
tmpdir=$$(realpath "$$tmpdir") && \
|
||||
mkdir -p "$$tmpdir/docs/reference/cli" && \
|
||||
cp docs/manifest.json "$$tmpdir/docs/manifest.json" && \
|
||||
CI=true DOCS_DIR="$$tmpdir/docs" go run ./scripts/clidocgen && \
|
||||
CI=true DOCS_DIR="$$tmpdir/docs" _gen/bin/clidocgen && \
|
||||
pnpm exec markdownlint-cli2 --fix "$$tmpdir/docs/reference/cli/*.md" && \
|
||||
pnpm exec markdown-table-formatter "$$tmpdir/docs/reference/cli/*.md" && \
|
||||
for f in "$$tmpdir/docs/reference/cli/"*.md; do mv "$$f" "docs/reference/cli/$$(basename "$$f")"; done && \
|
||||
rm -rf "$$tmpdir"
|
||||
|
||||
docs/admin/security/audit-logs.md: node_modules/.installed coderd/database/querier.go scripts/auditdocgen/main.go enterprise/audit/table.go coderd/rbac/object_gen.go | _gen
|
||||
docs/admin/security/audit-logs.md: node_modules/.installed coderd/database/querier.go scripts/auditdocgen/main.go enterprise/audit/table.go coderd/rbac/object_gen.go | _gen _gen/bin/auditdocgen
|
||||
tmpdir=$$(mktemp -d -p _gen) && tmpfile=$$(realpath "$$tmpdir")/$(notdir $@) && cp "$@" "$$tmpfile" && \
|
||||
go run scripts/auditdocgen/main.go --audit-doc-file="$$tmpfile" && \
|
||||
_gen/bin/auditdocgen --audit-doc-file="$$tmpfile" && \
|
||||
pnpm exec markdownlint-cli2 --fix "$$tmpfile" && \
|
||||
pnpm exec markdown-table-formatter "$$tmpfile" && \
|
||||
mv "$$tmpfile" "$@" && rm -rf "$$tmpdir"
|
||||
|
||||
+14
-6
@@ -102,6 +102,8 @@ type Options struct {
|
||||
ReportMetadataInterval time.Duration
|
||||
ServiceBannerRefreshInterval time.Duration
|
||||
BlockFileTransfer bool
|
||||
BlockReversePortForwarding bool
|
||||
BlockLocalPortForwarding bool
|
||||
Execer agentexec.Execer
|
||||
Devcontainers bool
|
||||
DevcontainerAPIOptions []agentcontainers.Option // Enable Devcontainers for these to be effective.
|
||||
@@ -214,6 +216,8 @@ func New(options Options) Agent {
|
||||
subsystems: options.Subsystems,
|
||||
logSender: agentsdk.NewLogSender(options.Logger),
|
||||
blockFileTransfer: options.BlockFileTransfer,
|
||||
blockReversePortForwarding: options.BlockReversePortForwarding,
|
||||
blockLocalPortForwarding: options.BlockLocalPortForwarding,
|
||||
|
||||
prometheusRegistry: prometheusRegistry,
|
||||
metrics: newAgentMetrics(prometheusRegistry),
|
||||
@@ -280,6 +284,8 @@ type agent struct {
|
||||
sshServer *agentssh.Server
|
||||
sshMaxTimeout time.Duration
|
||||
blockFileTransfer bool
|
||||
blockReversePortForwarding bool
|
||||
blockLocalPortForwarding bool
|
||||
|
||||
lifecycleUpdate chan struct{}
|
||||
lifecycleReported chan codersdk.WorkspaceAgentLifecycle
|
||||
@@ -331,12 +337,14 @@ func (a *agent) TailnetConn() *tailnet.Conn {
|
||||
func (a *agent) init() {
|
||||
// pass the "hard" context because we explicitly close the SSH server as part of graceful shutdown.
|
||||
sshSrv, err := agentssh.NewServer(a.hardCtx, a.logger.Named("ssh-server"), a.prometheusRegistry, a.filesystem, a.execer, &agentssh.Config{
|
||||
MaxTimeout: a.sshMaxTimeout,
|
||||
MOTDFile: func() string { return a.manifest.Load().MOTDFile },
|
||||
AnnouncementBanners: func() *[]codersdk.BannerConfig { return a.announcementBanners.Load() },
|
||||
UpdateEnv: a.updateCommandEnv,
|
||||
WorkingDirectory: func() string { return a.manifest.Load().Directory },
|
||||
BlockFileTransfer: a.blockFileTransfer,
|
||||
MaxTimeout: a.sshMaxTimeout,
|
||||
MOTDFile: func() string { return a.manifest.Load().MOTDFile },
|
||||
AnnouncementBanners: func() *[]codersdk.BannerConfig { return a.announcementBanners.Load() },
|
||||
UpdateEnv: a.updateCommandEnv,
|
||||
WorkingDirectory: func() string { return a.manifest.Load().Directory },
|
||||
BlockFileTransfer: a.blockFileTransfer,
|
||||
BlockReversePortForwarding: a.blockReversePortForwarding,
|
||||
BlockLocalPortForwarding: a.blockLocalPortForwarding,
|
||||
ReportConnection: func(id uuid.UUID, magicType agentssh.MagicSessionType, ip string) func(code int, reason string) {
|
||||
var connectionType proto.Connection_Type
|
||||
switch magicType {
|
||||
|
||||
@@ -986,6 +986,161 @@ func TestAgent_TCPRemoteForwarding(t *testing.T) {
|
||||
requireEcho(t, conn)
|
||||
}
|
||||
|
||||
func TestAgent_TCPLocalForwardingBlocked(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
rl, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer rl.Close()
|
||||
tcpAddr, valid := rl.Addr().(*net.TCPAddr)
|
||||
require.True(t, valid)
|
||||
remotePort := tcpAddr.Port
|
||||
|
||||
//nolint:dogsled
|
||||
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||||
o.BlockLocalPortForwarding = true
|
||||
})
|
||||
sshClient, err := agentConn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
|
||||
_, err = sshClient.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", remotePort))
|
||||
require.ErrorContains(t, err, "administratively prohibited")
|
||||
}
|
||||
|
||||
func TestAgent_TCPRemoteForwardingBlocked(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
//nolint:dogsled
|
||||
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||||
o.BlockReversePortForwarding = true
|
||||
})
|
||||
sshClient, err := agentConn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
|
||||
localhost := netip.MustParseAddr("127.0.0.1")
|
||||
randomPort := testutil.RandomPortNoListen(t)
|
||||
addr := net.TCPAddrFromAddrPort(netip.AddrPortFrom(localhost, randomPort))
|
||||
_, err = sshClient.ListenTCP(addr)
|
||||
require.ErrorContains(t, err, "tcpip-forward request denied by peer")
|
||||
}
|
||||
|
||||
func TestAgent_UnixLocalForwardingBlocked(t *testing.T) {
|
||||
t.Parallel()
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("unix domain sockets are not fully supported on Windows")
|
||||
}
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
tmpdir := testutil.TempDirUnixSocket(t)
|
||||
remoteSocketPath := filepath.Join(tmpdir, "remote-socket")
|
||||
|
||||
l, err := net.Listen("unix", remoteSocketPath)
|
||||
require.NoError(t, err)
|
||||
defer l.Close()
|
||||
|
||||
//nolint:dogsled
|
||||
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||||
o.BlockLocalPortForwarding = true
|
||||
})
|
||||
sshClient, err := agentConn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
|
||||
_, err = sshClient.Dial("unix", remoteSocketPath)
|
||||
require.ErrorContains(t, err, "administratively prohibited")
|
||||
}
|
||||
|
||||
func TestAgent_UnixRemoteForwardingBlocked(t *testing.T) {
|
||||
t.Parallel()
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("unix domain sockets are not fully supported on Windows")
|
||||
}
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
tmpdir := testutil.TempDirUnixSocket(t)
|
||||
remoteSocketPath := filepath.Join(tmpdir, "remote-socket")
|
||||
|
||||
//nolint:dogsled
|
||||
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||||
o.BlockReversePortForwarding = true
|
||||
})
|
||||
sshClient, err := agentConn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
|
||||
_, err = sshClient.ListenUnix(remoteSocketPath)
|
||||
require.ErrorContains(t, err, "streamlocal-forward@openssh.com request denied by peer")
|
||||
}
|
||||
|
||||
// TestAgent_LocalBlockedDoesNotAffectReverse verifies that blocking
|
||||
// local port forwarding does not prevent reverse port forwarding from
|
||||
// working. A field-name transposition at any plumbing hop would cause
|
||||
// both directions to be blocked when only one flag is set.
|
||||
func TestAgent_LocalBlockedDoesNotAffectReverse(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
//nolint:dogsled
|
||||
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||||
o.BlockLocalPortForwarding = true
|
||||
})
|
||||
sshClient, err := agentConn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
|
||||
// Reverse forwarding must still work.
|
||||
localhost := netip.MustParseAddr("127.0.0.1")
|
||||
var ll net.Listener
|
||||
for {
|
||||
randomPort := testutil.RandomPortNoListen(t)
|
||||
addr := net.TCPAddrFromAddrPort(netip.AddrPortFrom(localhost, randomPort))
|
||||
ll, err = sshClient.ListenTCP(addr)
|
||||
if err != nil {
|
||||
t.Logf("error remote forwarding: %s", err.Error())
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timed out getting random listener")
|
||||
default:
|
||||
continue
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
_ = ll.Close()
|
||||
}
|
||||
|
||||
// TestAgent_ReverseBlockedDoesNotAffectLocal verifies that blocking
|
||||
// reverse port forwarding does not prevent local port forwarding from
|
||||
// working.
|
||||
func TestAgent_ReverseBlockedDoesNotAffectLocal(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
rl, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer rl.Close()
|
||||
tcpAddr, valid := rl.Addr().(*net.TCPAddr)
|
||||
require.True(t, valid)
|
||||
remotePort := tcpAddr.Port
|
||||
go echoOnce(t, rl)
|
||||
|
||||
//nolint:dogsled
|
||||
agentConn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
|
||||
o.BlockReversePortForwarding = true
|
||||
})
|
||||
sshClient, err := agentConn.SSHClient(ctx)
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
|
||||
// Local forwarding must still work.
|
||||
conn, err := sshClient.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", remotePort))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
requireEcho(t, conn)
|
||||
}
|
||||
|
||||
func TestAgent_UnixLocalForwarding(t *testing.T) {
|
||||
t.Parallel()
|
||||
if runtime.GOOS == "windows" {
|
||||
|
||||
@@ -2862,6 +2862,126 @@ func TestAPI(t *testing.T) {
|
||||
"rebuilt agent should include updated display apps")
|
||||
})
|
||||
|
||||
// Verify that when a terraform-managed subagent is injected into
|
||||
// a devcontainer, the Directory field sent to Create reflects
|
||||
// the container-internal workspaceFolder from devcontainer
|
||||
// read-configuration, not the host-side workspace_folder from
|
||||
// the terraform resource. This is the scenario described in
|
||||
// https://linear.app/codercom/issue/PRODUCT-259:
|
||||
// 1. Non-terraform subagent → directory = /workspaces/foo (correct)
|
||||
// 2. Terraform subagent → directory was stuck on host path (bug)
|
||||
t.Run("TerraformDefinedSubAgentUsesContainerInternalDirectory", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Dev Container tests are not supported on Windows (this test uses mocks but fails due to Windows paths)")
|
||||
}
|
||||
|
||||
var (
|
||||
ctx = testutil.Context(t, testutil.WaitMedium)
|
||||
logger = slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
mCtrl = gomock.NewController(t)
|
||||
|
||||
terraformAgentID = uuid.New()
|
||||
containerID = "test-container-id"
|
||||
|
||||
// Given: A container with a host-side workspace folder.
|
||||
terraformContainer = codersdk.WorkspaceAgentContainer{
|
||||
ID: containerID,
|
||||
FriendlyName: "test-container",
|
||||
Image: "test-image",
|
||||
Running: true,
|
||||
CreatedAt: time.Now(),
|
||||
Labels: map[string]string{
|
||||
agentcontainers.DevcontainerLocalFolderLabel: "/home/coder/project",
|
||||
agentcontainers.DevcontainerConfigFileLabel: "/home/coder/project/.devcontainer/devcontainer.json",
|
||||
},
|
||||
}
|
||||
|
||||
// Given: A terraform-defined devcontainer whose
|
||||
// workspace_folder is the HOST-side path (set by provisioner).
|
||||
terraformDevcontainer = codersdk.WorkspaceAgentDevcontainer{
|
||||
ID: uuid.New(),
|
||||
Name: "terraform-devcontainer",
|
||||
WorkspaceFolder: "/home/coder/project",
|
||||
ConfigPath: "/home/coder/project/.devcontainer/devcontainer.json",
|
||||
SubagentID: uuid.NullUUID{UUID: terraformAgentID, Valid: true},
|
||||
}
|
||||
|
||||
fCCLI = &fakeContainerCLI{
|
||||
containers: codersdk.WorkspaceAgentListContainersResponse{
|
||||
Containers: []codersdk.WorkspaceAgentContainer{terraformContainer},
|
||||
},
|
||||
arch: runtime.GOARCH,
|
||||
}
|
||||
|
||||
// Given: devcontainer read-configuration returns the
|
||||
// CONTAINER-INTERNAL workspace folder.
|
||||
fDCCLI = &fakeDevcontainerCLI{
|
||||
upID: containerID,
|
||||
readConfig: agentcontainers.DevcontainerConfig{
|
||||
Workspace: agentcontainers.DevcontainerWorkspace{
|
||||
WorkspaceFolder: "/workspaces/project",
|
||||
},
|
||||
MergedConfiguration: agentcontainers.DevcontainerMergedConfiguration{
|
||||
Customizations: agentcontainers.DevcontainerMergedCustomizations{
|
||||
Coder: []agentcontainers.CoderCustomization{{}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mSAC = acmock.NewMockSubAgentClient(mCtrl)
|
||||
createCalls = make(chan agentcontainers.SubAgent, 1)
|
||||
closed bool
|
||||
)
|
||||
|
||||
mSAC.EXPECT().List(gomock.Any()).Return([]agentcontainers.SubAgent{}, nil).AnyTimes()
|
||||
|
||||
mSAC.EXPECT().Create(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(_ context.Context, agent agentcontainers.SubAgent) (agentcontainers.SubAgent, error) {
|
||||
agent.AuthToken = uuid.New()
|
||||
createCalls <- agent
|
||||
return agent, nil
|
||||
},
|
||||
).Times(1)
|
||||
|
||||
mSAC.EXPECT().Delete(gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, _ uuid.UUID) error {
|
||||
assert.True(t, closed, "Delete should only be called after Close")
|
||||
return nil
|
||||
}).AnyTimes()
|
||||
|
||||
api := agentcontainers.NewAPI(logger,
|
||||
agentcontainers.WithContainerCLI(fCCLI),
|
||||
agentcontainers.WithDevcontainerCLI(fDCCLI),
|
||||
agentcontainers.WithDevcontainers(
|
||||
[]codersdk.WorkspaceAgentDevcontainer{terraformDevcontainer},
|
||||
[]codersdk.WorkspaceAgentScript{{ID: terraformDevcontainer.ID, LogSourceID: uuid.New()}},
|
||||
),
|
||||
agentcontainers.WithSubAgentClient(mSAC),
|
||||
agentcontainers.WithSubAgentURL("test-subagent-url"),
|
||||
agentcontainers.WithWatcher(watcher.NewNoop()),
|
||||
)
|
||||
api.Start()
|
||||
defer func() {
|
||||
closed = true
|
||||
api.Close()
|
||||
}()
|
||||
|
||||
// When: The devcontainer is created (triggering injection).
|
||||
err := api.CreateDevcontainer(terraformDevcontainer.WorkspaceFolder, terraformDevcontainer.ConfigPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Then: The subagent sent to Create has the correct
|
||||
// container-internal directory, not the host path.
|
||||
createdAgent := testutil.RequireReceive(ctx, t, createCalls)
|
||||
assert.Equal(t, terraformAgentID, createdAgent.ID,
|
||||
"agent should use terraform-defined ID")
|
||||
assert.Equal(t, "/workspaces/project", createdAgent.Directory,
|
||||
"directory should be the container-internal path from devcontainer "+
|
||||
"read-configuration, not the host-side workspace_folder")
|
||||
})
|
||||
|
||||
t.Run("Error", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -134,6 +134,33 @@ func Config(workingDir string) (workspacesdk.ContextConfigResponse, []string) {
|
||||
}, ResolvePaths(mcpConfigFile, workingDir)
|
||||
}
|
||||
|
||||
// ContextPartsFromDir reads instruction files and discovers skills
|
||||
// from a specific directory, using default file names. This is used
|
||||
// by the CLI chat context commands to read context from an arbitrary
|
||||
// directory without consulting agent env vars.
|
||||
func ContextPartsFromDir(dir string) []codersdk.ChatMessagePart {
|
||||
var parts []codersdk.ChatMessagePart
|
||||
|
||||
if entry, found := readInstructionFileFromDir(dir, DefaultInstructionsFile); found {
|
||||
parts = append(parts, entry)
|
||||
}
|
||||
|
||||
// Reuse ResolvePaths so CLI skill discovery follows the same
|
||||
// project-relative path handling as agent config resolution.
|
||||
skillParts := discoverSkills(
|
||||
ResolvePaths(strings.Join([]string{DefaultSkillsDir, "skills"}, ","), dir),
|
||||
DefaultSkillMetaFile,
|
||||
)
|
||||
parts = append(parts, skillParts...)
|
||||
|
||||
// Guarantee non-nil slice.
|
||||
if parts == nil {
|
||||
parts = []codersdk.ChatMessagePart{}
|
||||
}
|
||||
|
||||
return parts
|
||||
}
|
||||
|
||||
// MCPConfigFiles returns the resolved MCP configuration file
|
||||
// paths for the agent's MCP manager.
|
||||
func (api *API) MCPConfigFiles() []string {
|
||||
|
||||
@@ -23,18 +23,144 @@ func filterParts(parts []codersdk.ChatMessagePart, t codersdk.ChatMessagePartTyp
|
||||
return out
|
||||
}
|
||||
|
||||
func TestConfig(t *testing.T) {
|
||||
t.Run("Defaults", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
func writeSkillMetaFileInRoot(t *testing.T, skillsRoot, name, description string) string {
|
||||
t.Helper()
|
||||
|
||||
// Clear all env vars so defaults are used.
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
skillDir := filepath.Join(skillsRoot, name)
|
||||
require.NoError(t, os.MkdirAll(skillDir, 0o755))
|
||||
require.NoError(t, os.WriteFile(
|
||||
filepath.Join(skillDir, "SKILL.md"),
|
||||
[]byte("---\nname: "+name+"\ndescription: "+description+"\n---\nSkill body"),
|
||||
0o600,
|
||||
))
|
||||
|
||||
return skillDir
|
||||
}
|
||||
|
||||
func writeSkillMetaFile(t *testing.T, dir, name, description string) string {
|
||||
t.Helper()
|
||||
return writeSkillMetaFileInRoot(t, filepath.Join(dir, ".agents", "skills"), name, description)
|
||||
}
|
||||
|
||||
func TestContextPartsFromDir(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("ReturnsInstructionFilePart", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
instructionPath := filepath.Join(dir, "AGENTS.md")
|
||||
require.NoError(t, os.WriteFile(instructionPath, []byte("project instructions"), 0o600))
|
||||
|
||||
parts := agentcontextconfig.ContextPartsFromDir(dir)
|
||||
contextParts := filterParts(parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
skillParts := filterParts(parts, codersdk.ChatMessagePartTypeSkill)
|
||||
|
||||
require.Len(t, parts, 1)
|
||||
require.Len(t, contextParts, 1)
|
||||
require.Empty(t, skillParts)
|
||||
require.Equal(t, instructionPath, contextParts[0].ContextFilePath)
|
||||
require.Equal(t, "project instructions", contextParts[0].ContextFileContent)
|
||||
require.False(t, contextParts[0].ContextFileTruncated)
|
||||
})
|
||||
|
||||
t.Run("ReturnsSkillParts", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
skillDir := writeSkillMetaFile(t, dir, "my-skill", "A test skill")
|
||||
|
||||
parts := agentcontextconfig.ContextPartsFromDir(dir)
|
||||
contextParts := filterParts(parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
skillParts := filterParts(parts, codersdk.ChatMessagePartTypeSkill)
|
||||
|
||||
require.Len(t, parts, 1)
|
||||
require.Empty(t, contextParts)
|
||||
require.Len(t, skillParts, 1)
|
||||
require.Equal(t, "my-skill", skillParts[0].SkillName)
|
||||
require.Equal(t, "A test skill", skillParts[0].SkillDescription)
|
||||
require.Equal(t, skillDir, skillParts[0].SkillDir)
|
||||
require.Equal(t, "SKILL.md", skillParts[0].ContextFileSkillMetaFile)
|
||||
})
|
||||
|
||||
t.Run("ReturnsSkillPartsFromSkillsDir", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
skillDir := writeSkillMetaFileInRoot(
|
||||
t,
|
||||
filepath.Join(dir, "skills"),
|
||||
"my-skill",
|
||||
"A test skill",
|
||||
)
|
||||
|
||||
parts := agentcontextconfig.ContextPartsFromDir(dir)
|
||||
contextParts := filterParts(parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
skillParts := filterParts(parts, codersdk.ChatMessagePartTypeSkill)
|
||||
|
||||
require.Len(t, parts, 1)
|
||||
require.Empty(t, contextParts)
|
||||
require.Len(t, skillParts, 1)
|
||||
require.Equal(t, "my-skill", skillParts[0].SkillName)
|
||||
require.Equal(t, "A test skill", skillParts[0].SkillDescription)
|
||||
require.Equal(t, skillDir, skillParts[0].SkillDir)
|
||||
require.Equal(t, "SKILL.md", skillParts[0].ContextFileSkillMetaFile)
|
||||
})
|
||||
|
||||
t.Run("ReturnsEmptyForEmptyDir", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
parts := agentcontextconfig.ContextPartsFromDir(t.TempDir())
|
||||
|
||||
require.NotNil(t, parts)
|
||||
require.Empty(t, parts)
|
||||
})
|
||||
|
||||
t.Run("ReturnsCombinedResults", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
instructionPath := filepath.Join(dir, "AGENTS.md")
|
||||
require.NoError(t, os.WriteFile(instructionPath, []byte("combined instructions"), 0o600))
|
||||
skillDir := writeSkillMetaFile(t, dir, "combined-skill", "Combined test skill")
|
||||
|
||||
parts := agentcontextconfig.ContextPartsFromDir(dir)
|
||||
contextParts := filterParts(parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
skillParts := filterParts(parts, codersdk.ChatMessagePartTypeSkill)
|
||||
|
||||
require.Len(t, parts, 2)
|
||||
require.Len(t, contextParts, 1)
|
||||
require.Len(t, skillParts, 1)
|
||||
require.Equal(t, instructionPath, contextParts[0].ContextFilePath)
|
||||
require.Equal(t, "combined instructions", contextParts[0].ContextFileContent)
|
||||
require.Equal(t, "combined-skill", skillParts[0].SkillName)
|
||||
require.Equal(t, skillDir, skillParts[0].SkillDir)
|
||||
})
|
||||
}
|
||||
|
||||
func setupConfigTestEnv(t *testing.T, overrides map[string]string) string {
|
||||
t.Helper()
|
||||
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
for key, value := range overrides {
|
||||
t.Setenv(key, value)
|
||||
}
|
||||
|
||||
return fakeHome
|
||||
}
|
||||
|
||||
func TestConfig(t *testing.T) {
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("Defaults", func(t *testing.T) {
|
||||
setupConfigTestEnv(t, nil)
|
||||
|
||||
workDir := platformAbsPath("work")
|
||||
cfg, mcpFiles := agentcontextconfig.Config(workDir)
|
||||
@@ -46,20 +172,18 @@ func TestConfig(t *testing.T) {
|
||||
require.Equal(t, []string{filepath.Join(workDir, ".mcp.json")}, mcpFiles)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("CustomEnvVars", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
|
||||
optInstructions := t.TempDir()
|
||||
optSkills := t.TempDir()
|
||||
optMCP := platformAbsPath("opt", "mcp.json")
|
||||
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, optInstructions)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "CUSTOM.md")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, optSkills)
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "META.yaml")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, optMCP)
|
||||
setupConfigTestEnv(t, map[string]string{
|
||||
agentcontextconfig.EnvInstructionsDirs: optInstructions,
|
||||
agentcontextconfig.EnvInstructionsFile: "CUSTOM.md",
|
||||
agentcontextconfig.EnvSkillsDirs: optSkills,
|
||||
agentcontextconfig.EnvSkillMetaFile: "META.yaml",
|
||||
agentcontextconfig.EnvMCPConfigFiles: optMCP,
|
||||
})
|
||||
|
||||
// Create files matching the custom names so we can
|
||||
// verify the env vars actually change lookup behavior.
|
||||
@@ -85,15 +209,12 @@ func TestConfig(t *testing.T) {
|
||||
require.Equal(t, "META.yaml", skillParts[0].ContextFileSkillMetaFile)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("WhitespaceInFileNames", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
fakeHome := setupConfigTestEnv(t, map[string]string{
|
||||
agentcontextconfig.EnvInstructionsFile: " CLAUDE.md ",
|
||||
})
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, " CLAUDE.md ")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
// Create a file matching the trimmed name.
|
||||
@@ -106,19 +227,13 @@ func TestConfig(t *testing.T) {
|
||||
require.Equal(t, "hello", ctxFiles[0].ContextFileContent)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("CommaSeparatedDirs", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
|
||||
a := t.TempDir()
|
||||
b := t.TempDir()
|
||||
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, a+","+b)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
setupConfigTestEnv(t, map[string]string{
|
||||
agentcontextconfig.EnvInstructionsDirs: a + "," + b,
|
||||
})
|
||||
|
||||
// Put instruction files in both dirs.
|
||||
require.NoError(t, os.WriteFile(filepath.Join(a, "AGENTS.md"), []byte("from a"), 0o600))
|
||||
@@ -133,17 +248,10 @@ func TestConfig(t *testing.T) {
|
||||
require.Equal(t, "from b", ctxFiles[1].ContextFileContent)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("ReadsInstructionFiles", func(t *testing.T) {
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
fakeHome := setupConfigTestEnv(t, nil)
|
||||
|
||||
// Create ~/.coder/AGENTS.md
|
||||
coderDir := filepath.Join(fakeHome, ".coder")
|
||||
@@ -164,16 +272,9 @@ func TestConfig(t *testing.T) {
|
||||
require.False(t, ctxFiles[0].ContextFileTruncated)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("ReadsWorkingDirInstructionFile", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
setupConfigTestEnv(t, nil)
|
||||
workDir := t.TempDir()
|
||||
|
||||
// Create AGENTS.md in the working directory.
|
||||
@@ -193,16 +294,9 @@ func TestConfig(t *testing.T) {
|
||||
require.Equal(t, filepath.Join(workDir, "AGENTS.md"), ctxFiles[0].ContextFilePath)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("TruncatesLargeInstructionFile", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
setupConfigTestEnv(t, nil)
|
||||
workDir := t.TempDir()
|
||||
largeContent := strings.Repeat("a", 64*1024+100)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(workDir, "AGENTS.md"), []byte(largeContent), 0o600))
|
||||
@@ -215,79 +309,47 @@ func TestConfig(t *testing.T) {
|
||||
require.Len(t, ctxFiles[0].ContextFileContent, 64*1024)
|
||||
})
|
||||
|
||||
t.Run("SanitizesHTMLComments", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
sanitizationTests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "SanitizesHTMLComments",
|
||||
input: "visible\n<!-- hidden -->content",
|
||||
expected: "visible\ncontent",
|
||||
},
|
||||
{
|
||||
name: "SanitizesInvisibleUnicode",
|
||||
input: "before\u200bafter",
|
||||
expected: "beforeafter",
|
||||
},
|
||||
{
|
||||
name: "NormalizesCRLF",
|
||||
input: "line1\r\nline2\rline3",
|
||||
expected: "line1\nline2\nline3",
|
||||
},
|
||||
}
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
for _, tt := range sanitizationTests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
setupConfigTestEnv(t, nil)
|
||||
workDir := t.TempDir()
|
||||
require.NoError(t, os.WriteFile(
|
||||
filepath.Join(workDir, "AGENTS.md"),
|
||||
[]byte(tt.input),
|
||||
0o600,
|
||||
))
|
||||
|
||||
workDir := t.TempDir()
|
||||
require.NoError(t, os.WriteFile(
|
||||
filepath.Join(workDir, "AGENTS.md"),
|
||||
[]byte("visible\n<!-- hidden -->content"),
|
||||
0o600,
|
||||
))
|
||||
cfg, _ := agentcontextconfig.Config(workDir)
|
||||
|
||||
cfg, _ := agentcontextconfig.Config(workDir)
|
||||
|
||||
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
require.Len(t, ctxFiles, 1)
|
||||
require.Equal(t, "visible\ncontent", ctxFiles[0].ContextFileContent)
|
||||
})
|
||||
|
||||
t.Run("SanitizesInvisibleUnicode", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
// U+200B (zero-width space) should be stripped.
|
||||
require.NoError(t, os.WriteFile(
|
||||
filepath.Join(workDir, "AGENTS.md"),
|
||||
[]byte("before\u200bafter"),
|
||||
0o600,
|
||||
))
|
||||
|
||||
cfg, _ := agentcontextconfig.Config(workDir)
|
||||
|
||||
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
require.Len(t, ctxFiles, 1)
|
||||
require.Equal(t, "beforeafter", ctxFiles[0].ContextFileContent)
|
||||
})
|
||||
|
||||
t.Run("NormalizesCRLF", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
require.NoError(t, os.WriteFile(
|
||||
filepath.Join(workDir, "AGENTS.md"),
|
||||
[]byte("line1\r\nline2\rline3"),
|
||||
0o600,
|
||||
))
|
||||
|
||||
cfg, _ := agentcontextconfig.Config(workDir)
|
||||
|
||||
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
require.Len(t, ctxFiles, 1)
|
||||
require.Equal(t, "line1\nline2\nline3", ctxFiles[0].ContextFileContent)
|
||||
})
|
||||
ctxFiles := filterParts(cfg.Parts, codersdk.ChatMessagePartTypeContextFile)
|
||||
require.Len(t, ctxFiles, 1)
|
||||
require.Equal(t, tt.expected, ctxFiles[0].ContextFileContent)
|
||||
})
|
||||
}
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("DiscoversSkills", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
@@ -320,17 +382,13 @@ func TestConfig(t *testing.T) {
|
||||
require.Equal(t, "SKILL.md", skillParts[0].ContextFileSkillMetaFile)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("SkipsMissingDirs", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
|
||||
nonExistent := filepath.Join(t.TempDir(), "does-not-exist")
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, nonExistent)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, nonExistent)
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
setupConfigTestEnv(t, map[string]string{
|
||||
agentcontextconfig.EnvInstructionsDirs: nonExistent,
|
||||
agentcontextconfig.EnvSkillsDirs: nonExistent,
|
||||
})
|
||||
|
||||
workDir := t.TempDir()
|
||||
cfg, _ := agentcontextconfig.Config(workDir)
|
||||
@@ -340,17 +398,13 @@ func TestConfig(t *testing.T) {
|
||||
require.Empty(t, cfg.Parts)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("MCPConfigFilesResolvedSeparately", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillsDirs, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
|
||||
optMCP := platformAbsPath("opt", "custom.json")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, optMCP)
|
||||
fakeHome := setupConfigTestEnv(t, map[string]string{
|
||||
agentcontextconfig.EnvMCPConfigFiles: optMCP,
|
||||
})
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
|
||||
|
||||
workDir := t.TempDir()
|
||||
_, mcpFiles := agentcontextconfig.Config(workDir)
|
||||
@@ -358,14 +412,10 @@ func TestConfig(t *testing.T) {
|
||||
require.Equal(t, []string{optMCP}, mcpFiles)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("SkillNameMustMatchDir", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
fakeHome := setupConfigTestEnv(t, nil)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
skillsDir := filepath.Join(workDir, "skills")
|
||||
@@ -385,14 +435,10 @@ func TestConfig(t *testing.T) {
|
||||
require.Empty(t, skillParts)
|
||||
})
|
||||
|
||||
//nolint:paralleltest // Uses t.Setenv to mutate process-wide environment.
|
||||
t.Run("DuplicateSkillsFirstWins", func(t *testing.T) {
|
||||
fakeHome := t.TempDir()
|
||||
t.Setenv("HOME", fakeHome)
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
fakeHome := setupConfigTestEnv(t, nil)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsDirs, fakeHome)
|
||||
t.Setenv(agentcontextconfig.EnvInstructionsFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvSkillMetaFile, "")
|
||||
t.Setenv(agentcontextconfig.EnvMCPConfigFiles, "")
|
||||
|
||||
workDir := t.TempDir()
|
||||
skillsDir1 := filepath.Join(workDir, "skills1")
|
||||
|
||||
@@ -117,6 +117,10 @@ type Config struct {
|
||||
X11MaxPort *int
|
||||
// BlockFileTransfer restricts use of file transfer applications.
|
||||
BlockFileTransfer bool
|
||||
// BlockReversePortForwarding disables reverse port forwarding (ssh -R).
|
||||
BlockReversePortForwarding bool
|
||||
// BlockLocalPortForwarding disables local port forwarding (ssh -L).
|
||||
BlockLocalPortForwarding bool
|
||||
// ReportConnection.
|
||||
ReportConnection reportConnectionFunc
|
||||
// Experimental: allow connecting to running containers via Docker exec.
|
||||
@@ -190,7 +194,7 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
|
||||
}
|
||||
|
||||
forwardHandler := &ssh.ForwardedTCPHandler{}
|
||||
unixForwardHandler := newForwardedUnixHandler(logger)
|
||||
unixForwardHandler := newForwardedUnixHandler(logger, config.BlockReversePortForwarding)
|
||||
|
||||
metrics := newSSHServerMetrics(prometheusRegistry)
|
||||
s := &Server{
|
||||
@@ -229,8 +233,15 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
|
||||
wrapped := NewJetbrainsChannelWatcher(ctx, s.logger, s.config.ReportConnection, newChan, &s.connCountJetBrains)
|
||||
ssh.DirectTCPIPHandler(srv, conn, wrapped, ctx)
|
||||
},
|
||||
"direct-streamlocal@openssh.com": directStreamLocalHandler,
|
||||
"session": ssh.DefaultSessionHandler,
|
||||
"direct-streamlocal@openssh.com": func(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) {
|
||||
if s.config.BlockLocalPortForwarding {
|
||||
s.logger.Warn(ctx, "unix local port forward blocked")
|
||||
_ = newChan.Reject(gossh.Prohibited, "local port forwarding is disabled")
|
||||
return
|
||||
}
|
||||
directStreamLocalHandler(srv, conn, newChan, ctx)
|
||||
},
|
||||
"session": ssh.DefaultSessionHandler,
|
||||
},
|
||||
ConnectionFailedCallback: func(conn net.Conn, err error) {
|
||||
s.logger.Warn(ctx, "ssh connection failed",
|
||||
@@ -250,6 +261,12 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
|
||||
// be set before we start listening.
|
||||
HostSigners: []ssh.Signer{},
|
||||
LocalPortForwardingCallback: func(ctx ssh.Context, destinationHost string, destinationPort uint32) bool {
|
||||
if s.config.BlockLocalPortForwarding {
|
||||
s.logger.Warn(ctx, "local port forward blocked",
|
||||
slog.F("destination_host", destinationHost),
|
||||
slog.F("destination_port", destinationPort))
|
||||
return false
|
||||
}
|
||||
// Allow local port forwarding all!
|
||||
s.logger.Debug(ctx, "local port forward",
|
||||
slog.F("destination_host", destinationHost),
|
||||
@@ -260,6 +277,12 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
|
||||
return true
|
||||
},
|
||||
ReversePortForwardingCallback: func(ctx ssh.Context, bindHost string, bindPort uint32) bool {
|
||||
if s.config.BlockReversePortForwarding {
|
||||
s.logger.Warn(ctx, "reverse port forward blocked",
|
||||
slog.F("bind_host", bindHost),
|
||||
slog.F("bind_port", bindPort))
|
||||
return false
|
||||
}
|
||||
// Allow reverse port forwarding all!
|
||||
s.logger.Debug(ctx, "reverse port forward",
|
||||
slog.F("bind_host", bindHost),
|
||||
|
||||
@@ -35,8 +35,9 @@ type forwardedStreamLocalPayload struct {
|
||||
// streamlocal forwarding (aka. unix forwarding) instead of TCP forwarding.
|
||||
type forwardedUnixHandler struct {
|
||||
sync.Mutex
|
||||
log slog.Logger
|
||||
forwards map[forwardKey]net.Listener
|
||||
log slog.Logger
|
||||
forwards map[forwardKey]net.Listener
|
||||
blockReversePortForwarding bool
|
||||
}
|
||||
|
||||
type forwardKey struct {
|
||||
@@ -44,10 +45,11 @@ type forwardKey struct {
|
||||
addr string
|
||||
}
|
||||
|
||||
func newForwardedUnixHandler(log slog.Logger) *forwardedUnixHandler {
|
||||
func newForwardedUnixHandler(log slog.Logger, blockReversePortForwarding bool) *forwardedUnixHandler {
|
||||
return &forwardedUnixHandler{
|
||||
log: log,
|
||||
forwards: make(map[forwardKey]net.Listener),
|
||||
log: log,
|
||||
forwards: make(map[forwardKey]net.Listener),
|
||||
blockReversePortForwarding: blockReversePortForwarding,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,6 +64,10 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
|
||||
|
||||
switch req.Type {
|
||||
case "streamlocal-forward@openssh.com":
|
||||
if h.blockReversePortForwarding {
|
||||
log.Warn(ctx, "unix reverse port forward blocked")
|
||||
return false, nil
|
||||
}
|
||||
var reqPayload streamLocalForwardPayload
|
||||
err := gossh.Unmarshal(req.Payload, &reqPayload)
|
||||
if err != nil {
|
||||
|
||||
@@ -5,7 +5,9 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -620,6 +622,11 @@ func (a *API) handleRecordingStop(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
defer artifact.Reader.Close()
|
||||
defer func() {
|
||||
if artifact.ThumbnailReader != nil {
|
||||
_ = artifact.ThumbnailReader.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
if artifact.Size > workspacesdk.MaxRecordingSize {
|
||||
a.logger.Warn(ctx, "recording file exceeds maximum size",
|
||||
@@ -633,10 +640,60 @@ func (a *API) handleRecordingStop(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
rw.Header().Set("Content-Type", "video/mp4")
|
||||
rw.Header().Set("Content-Length", strconv.FormatInt(artifact.Size, 10))
|
||||
// Discard the thumbnail if it exceeds the maximum size.
|
||||
// The server-side consumer also enforces this per-part, but
|
||||
// rejecting it here avoids streaming a large thumbnail over
|
||||
// the wire for nothing.
|
||||
if artifact.ThumbnailReader != nil && artifact.ThumbnailSize > workspacesdk.MaxThumbnailSize {
|
||||
a.logger.Warn(ctx, "thumbnail file exceeds maximum size, omitting",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.F("size", artifact.ThumbnailSize),
|
||||
slog.F("max_size", workspacesdk.MaxThumbnailSize),
|
||||
)
|
||||
_ = artifact.ThumbnailReader.Close()
|
||||
artifact.ThumbnailReader = nil
|
||||
artifact.ThumbnailSize = 0
|
||||
}
|
||||
|
||||
// The multipart response is best-effort: once WriteHeader(200) is
|
||||
// called, CreatePart failures produce a truncated response without
|
||||
// the closing boundary. The server-side consumer handles this
|
||||
// gracefully, preserving any parts read before the error.
|
||||
mw := multipart.NewWriter(rw)
|
||||
defer mw.Close()
|
||||
rw.Header().Set("Content-Type", "multipart/mixed; boundary="+mw.Boundary())
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
_, _ = io.Copy(rw, artifact.Reader)
|
||||
|
||||
// Part 1: video/mp4 (always present).
|
||||
videoPart, err := mw.CreatePart(textproto.MIMEHeader{
|
||||
"Content-Type": {"video/mp4"},
|
||||
})
|
||||
if err != nil {
|
||||
a.logger.Warn(ctx, "failed to create video multipart part",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.Error(err))
|
||||
return
|
||||
}
|
||||
if _, err := io.Copy(videoPart, artifact.Reader); err != nil {
|
||||
a.logger.Warn(ctx, "failed to write video multipart part",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
// Part 2: image/jpeg (present only when thumbnail was extracted).
|
||||
if artifact.ThumbnailReader != nil {
|
||||
thumbPart, err := mw.CreatePart(textproto.MIMEHeader{
|
||||
"Content-Type": {"image/jpeg"},
|
||||
})
|
||||
if err != nil {
|
||||
a.logger.Warn(ctx, "failed to create thumbnail multipart part",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.Error(err))
|
||||
return
|
||||
}
|
||||
_, _ = io.Copy(thumbPart, artifact.ThumbnailReader)
|
||||
}
|
||||
}
|
||||
|
||||
// coordFromAction extracts the coordinate pair from a DesktopAction,
|
||||
|
||||
@@ -4,12 +4,17 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -59,6 +64,8 @@ type fakeDesktop struct {
|
||||
lastKeyDown string
|
||||
lastKeyUp string
|
||||
|
||||
thumbnailData []byte // if set, StopRecording includes a thumbnail
|
||||
|
||||
// Recording tracking (guarded by recMu).
|
||||
recMu sync.Mutex
|
||||
recordings map[string]string // ID → file path
|
||||
@@ -187,10 +194,15 @@ func (f *fakeDesktop) StopRecording(_ context.Context, recordingID string) (*age
|
||||
_ = file.Close()
|
||||
return nil, err
|
||||
}
|
||||
return &agentdesktop.RecordingArtifact{
|
||||
artifact := &agentdesktop.RecordingArtifact{
|
||||
Reader: file,
|
||||
Size: info.Size(),
|
||||
}, nil
|
||||
}
|
||||
if f.thumbnailData != nil {
|
||||
artifact.ThumbnailReader = io.NopCloser(bytes.NewReader(f.thumbnailData))
|
||||
artifact.ThumbnailSize = int64(len(f.thumbnailData))
|
||||
}
|
||||
return artifact, nil
|
||||
}
|
||||
|
||||
func (f *fakeDesktop) RecordActivity() {
|
||||
@@ -785,8 +797,8 @@ func TestRecordingStartStop(t *testing.T) {
|
||||
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
|
||||
assert.Equal(t, []byte("fake-mp4-data-"+testRecIDDefault+"-1"), rr.Body.Bytes())
|
||||
parts := parseMultipartParts(t, rr.Header().Get("Content-Type"), rr.Body.Bytes())
|
||||
assert.Equal(t, []byte("fake-mp4-data-"+testRecIDDefault+"-1"), parts["video/mp4"])
|
||||
}
|
||||
|
||||
func TestRecordingStartFails(t *testing.T) {
|
||||
@@ -847,8 +859,8 @@ func TestRecordingStartIdempotent(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
|
||||
assert.Equal(t, []byte("fake-mp4-data-"+testRecIDStartIdempotent+"-1"), rr.Body.Bytes())
|
||||
parts := parseMultipartParts(t, rr.Header().Get("Content-Type"), rr.Body.Bytes())
|
||||
assert.Equal(t, []byte("fake-mp4-data-"+testRecIDStartIdempotent+"-1"), parts["video/mp4"])
|
||||
}
|
||||
|
||||
func TestRecordingStopIdempotent(t *testing.T) {
|
||||
@@ -872,7 +884,7 @@ func TestRecordingStopIdempotent(t *testing.T) {
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Stop twice - both should succeed with identical data.
|
||||
var bodies [2][]byte
|
||||
var videoParts [2][]byte
|
||||
for i := range 2 {
|
||||
body, err := json.Marshal(map[string]string{"recording_id": testRecIDStopIdempotent})
|
||||
require.NoError(t, err)
|
||||
@@ -880,10 +892,10 @@ func TestRecordingStopIdempotent(t *testing.T) {
|
||||
request := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(body))
|
||||
handler.ServeHTTP(recorder, request)
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
assert.Equal(t, "video/mp4", recorder.Header().Get("Content-Type"))
|
||||
bodies[i] = recorder.Body.Bytes()
|
||||
parts := parseMultipartParts(t, recorder.Header().Get("Content-Type"), recorder.Body.Bytes())
|
||||
videoParts[i] = parts["video/mp4"]
|
||||
}
|
||||
assert.Equal(t, bodies[0], bodies[1])
|
||||
assert.Equal(t, videoParts[0], videoParts[1])
|
||||
}
|
||||
|
||||
func TestRecordingStopInvalidIDFormat(t *testing.T) {
|
||||
@@ -1004,8 +1016,8 @@ func TestRecordingMultipleSimultaneous(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(body))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
|
||||
assert.Equal(t, expected[id], rr.Body.Bytes())
|
||||
parts := parseMultipartParts(t, rr.Header().Get("Content-Type"), rr.Body.Bytes())
|
||||
assert.Equal(t, expected[id], parts["video/mp4"])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1112,8 +1124,8 @@ func TestRecordingStartAfterCompleted(t *testing.T) {
|
||||
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
|
||||
firstData := rr.Body.Bytes()
|
||||
firstParts := parseMultipartParts(t, rr.Header().Get("Content-Type"), rr.Body.Bytes())
|
||||
firstData := firstParts["video/mp4"]
|
||||
require.NotEmpty(t, firstData)
|
||||
|
||||
// Step 3: Start again with the same ID - should succeed
|
||||
@@ -1128,8 +1140,8 @@ func TestRecordingStartAfterCompleted(t *testing.T) {
|
||||
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
assert.Equal(t, "video/mp4", rr.Header().Get("Content-Type"))
|
||||
secondData := rr.Body.Bytes()
|
||||
secondParts := parseMultipartParts(t, rr.Header().Get("Content-Type"), rr.Body.Bytes())
|
||||
secondData := secondParts["video/mp4"]
|
||||
require.NotEmpty(t, secondData)
|
||||
|
||||
// The two recordings should have different data because the
|
||||
@@ -1235,3 +1247,166 @@ func TestRecordingStopCorrupted(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Recording is corrupted.", respStop.Message)
|
||||
}
|
||||
|
||||
// parseMultipartParts parses a multipart/mixed response and returns
|
||||
// a map from Content-Type to body bytes.
|
||||
func parseMultipartParts(t *testing.T, contentType string, body []byte) map[string][]byte {
|
||||
t.Helper()
|
||||
_, params, err := mime.ParseMediaType(contentType)
|
||||
require.NoError(t, err, "parse Content-Type")
|
||||
boundary := params["boundary"]
|
||||
require.NotEmpty(t, boundary, "missing boundary")
|
||||
mr := multipart.NewReader(bytes.NewReader(body), boundary)
|
||||
parts := make(map[string][]byte)
|
||||
for {
|
||||
part, err := mr.NextPart()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
require.NoError(t, err, "unexpected multipart parse error")
|
||||
ct := part.Header.Get("Content-Type")
|
||||
data, readErr := io.ReadAll(part)
|
||||
require.NoError(t, readErr)
|
||||
parts[ct] = data
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
func TestHandleRecordingStop_WithThumbnail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
// Create a fake JPEG header: 0xFF 0xD8 0xFF followed by 509 zero bytes.
|
||||
thumbnail := make([]byte, 512)
|
||||
thumbnail[0] = 0xff
|
||||
thumbnail[1] = 0xd8
|
||||
thumbnail[2] = 0xff
|
||||
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
thumbnailData: thumbnail,
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
// Start recording.
|
||||
recID := uuid.New().String()
|
||||
startBody, err := json.Marshal(map[string]string{"recording_id": recID})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(startBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Stop recording.
|
||||
stopBody, err := json.Marshal(map[string]string{"recording_id": recID})
|
||||
require.NoError(t, err)
|
||||
rr = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Verify multipart response.
|
||||
ct := rr.Header().Get("Content-Type")
|
||||
assert.True(t, strings.HasPrefix(ct, "multipart/mixed"),
|
||||
"expected multipart/mixed Content-Type, got %s", ct)
|
||||
|
||||
parts := parseMultipartParts(t, ct, rr.Body.Bytes())
|
||||
assert.Len(t, parts, 2, "expected exactly 2 parts (video + thumbnail)")
|
||||
|
||||
// The fake writes "fake-mp4-data-<id>-<counter>" as the MP4 content.
|
||||
expectedMP4 := []byte("fake-mp4-data-" + recID + "-1")
|
||||
assert.Equal(t, expectedMP4, parts["video/mp4"])
|
||||
assert.Equal(t, thumbnail, parts["image/jpeg"])
|
||||
}
|
||||
|
||||
func TestHandleRecordingStop_NoThumbnail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
// Start recording.
|
||||
recID := uuid.New().String()
|
||||
startBody, err := json.Marshal(map[string]string{"recording_id": recID})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(startBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Stop recording.
|
||||
stopBody, err := json.Marshal(map[string]string{"recording_id": recID})
|
||||
require.NoError(t, err)
|
||||
rr = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Verify multipart response.
|
||||
ct := rr.Header().Get("Content-Type")
|
||||
assert.True(t, strings.HasPrefix(ct, "multipart/mixed"),
|
||||
"expected multipart/mixed Content-Type, got %s", ct)
|
||||
|
||||
parts := parseMultipartParts(t, ct, rr.Body.Bytes())
|
||||
assert.Len(t, parts, 1, "expected exactly 1 part (video only)")
|
||||
|
||||
expectedMP4 := []byte("fake-mp4-data-" + recID + "-1")
|
||||
assert.Equal(t, expectedMP4, parts["video/mp4"])
|
||||
}
|
||||
|
||||
func TestHandleRecordingStop_OversizedThumbnail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
// Create thumbnail data that exceeds MaxThumbnailSize.
|
||||
oversizedThumb := make([]byte, workspacesdk.MaxThumbnailSize+1)
|
||||
oversizedThumb[0] = 0xff
|
||||
oversizedThumb[1] = 0xd8
|
||||
oversizedThumb[2] = 0xff
|
||||
|
||||
fake := &fakeDesktop{
|
||||
startCfg: agentdesktop.DisplayConfig{Width: 1920, Height: 1080},
|
||||
thumbnailData: oversizedThumb,
|
||||
}
|
||||
api := agentdesktop.NewAPI(logger, fake, nil)
|
||||
defer api.Close()
|
||||
|
||||
handler := api.Routes()
|
||||
|
||||
// Start recording.
|
||||
recID := uuid.New().String()
|
||||
startBody, err := json.Marshal(map[string]string{"recording_id": recID})
|
||||
require.NoError(t, err)
|
||||
rr := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/recording/start", bytes.NewReader(startBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Stop recording.
|
||||
stopBody, err := json.Marshal(map[string]string{"recording_id": recID})
|
||||
require.NoError(t, err)
|
||||
rr = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/recording/stop", bytes.NewReader(stopBody))
|
||||
handler.ServeHTTP(rr, req)
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
// Verify multipart response contains only the video part.
|
||||
ct := rr.Header().Get("Content-Type")
|
||||
assert.True(t, strings.HasPrefix(ct, "multipart/mixed"),
|
||||
"expected multipart/mixed Content-Type, got %s", ct)
|
||||
|
||||
parts := parseMultipartParts(t, ct, rr.Body.Bytes())
|
||||
assert.Len(t, parts, 1, "expected exactly 1 part (video only, oversized thumbnail discarded)")
|
||||
|
||||
expectedMP4 := []byte("fake-mp4-data-" + recID + "-1")
|
||||
assert.Equal(t, expectedMP4, parts["video/mp4"])
|
||||
}
|
||||
|
||||
@@ -105,6 +105,11 @@ type RecordingArtifact struct {
|
||||
Reader io.ReadCloser
|
||||
// Size is the byte length of the MP4 content.
|
||||
Size int64
|
||||
// ThumbnailReader is the JPEG thumbnail. May be nil if no
|
||||
// thumbnail was produced. Callers must close it when done.
|
||||
ThumbnailReader io.ReadCloser
|
||||
// ThumbnailSize is the byte length of the thumbnail.
|
||||
ThumbnailSize int64
|
||||
}
|
||||
|
||||
// DisplayConfig describes a running desktop session.
|
||||
|
||||
@@ -56,6 +56,7 @@ type screenshotOutput struct {
|
||||
type recordingProcess struct {
|
||||
cmd *exec.Cmd
|
||||
filePath string
|
||||
thumbPath string
|
||||
stopped bool
|
||||
killed bool // true when the process was SIGKILLed
|
||||
done chan struct{} // closed when cmd.Wait() returns
|
||||
@@ -383,13 +384,20 @@ func (p *portableDesktop) StartRecording(ctx context.Context, recordingID string
|
||||
}
|
||||
}
|
||||
// Completed recording - discard old file, start fresh.
|
||||
if err := os.Remove(rec.filePath); err != nil && !os.IsNotExist(err) {
|
||||
if err := os.Remove(rec.filePath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
p.logger.Warn(ctx, "failed to remove old recording file",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.F("file_path", rec.filePath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
if err := os.Remove(rec.thumbPath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
p.logger.Warn(ctx, "failed to remove old thumbnail file",
|
||||
slog.F("recording_id", recordingID),
|
||||
slog.F("thumbnail_path", rec.thumbPath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
delete(p.recordings, recordingID)
|
||||
}
|
||||
|
||||
@@ -406,6 +414,7 @@ func (p *portableDesktop) StartRecording(ctx context.Context, recordingID string
|
||||
}
|
||||
|
||||
filePath := filepath.Join(os.TempDir(), "coder-recording-"+recordingID+".mp4")
|
||||
thumbPath := filepath.Join(os.TempDir(), "coder-recording-"+recordingID+".thumb.jpg")
|
||||
|
||||
// Use a background context so the process outlives the HTTP
|
||||
// request that triggered it.
|
||||
@@ -419,6 +428,7 @@ func (p *portableDesktop) StartRecording(ctx context.Context, recordingID string
|
||||
"--idle-speedup", "20",
|
||||
"--idle-min-duration", "0.35",
|
||||
"--idle-noise-tolerance", "-38dB",
|
||||
"--thumbnail", thumbPath,
|
||||
filePath)
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
@@ -427,9 +437,10 @@ func (p *portableDesktop) StartRecording(ctx context.Context, recordingID string
|
||||
}
|
||||
|
||||
rec := &recordingProcess{
|
||||
cmd: cmd,
|
||||
filePath: filePath,
|
||||
done: make(chan struct{}),
|
||||
cmd: cmd,
|
||||
filePath: filePath,
|
||||
thumbPath: thumbPath,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
go func() {
|
||||
rec.waitErr = cmd.Wait()
|
||||
@@ -499,10 +510,35 @@ func (p *portableDesktop) StopRecording(ctx context.Context, recordingID string)
|
||||
_ = f.Close()
|
||||
return nil, xerrors.Errorf("stat recording artifact: %w", err)
|
||||
}
|
||||
return &RecordingArtifact{
|
||||
artifact := &RecordingArtifact{
|
||||
Reader: f,
|
||||
Size: info.Size(),
|
||||
}, nil
|
||||
}
|
||||
// Attach thumbnail if the subprocess wrote one.
|
||||
thumbFile, err := os.Open(rec.thumbPath)
|
||||
if err != nil {
|
||||
p.logger.Warn(ctx, "thumbnail not available",
|
||||
slog.F("thumbnail_path", rec.thumbPath),
|
||||
slog.Error(err))
|
||||
return artifact, nil
|
||||
}
|
||||
thumbInfo, err := thumbFile.Stat()
|
||||
if err != nil {
|
||||
_ = thumbFile.Close()
|
||||
p.logger.Warn(ctx, "thumbnail stat failed",
|
||||
slog.F("thumbnail_path", rec.thumbPath),
|
||||
slog.Error(err))
|
||||
return artifact, nil
|
||||
}
|
||||
if thumbInfo.Size() == 0 {
|
||||
_ = thumbFile.Close()
|
||||
p.logger.Warn(ctx, "thumbnail file is empty",
|
||||
slog.F("thumbnail_path", rec.thumbPath))
|
||||
return artifact, nil
|
||||
}
|
||||
artifact.ThumbnailReader = thumbFile
|
||||
artifact.ThumbnailSize = thumbInfo.Size()
|
||||
return artifact, nil
|
||||
}
|
||||
|
||||
// lockedStopRecordingProcess stops a single recording via stopOnce.
|
||||
@@ -571,18 +607,33 @@ func (p *portableDesktop) lockedCleanStaleRecordings(ctx context.Context) {
|
||||
}
|
||||
info, err := os.Stat(rec.filePath)
|
||||
if err != nil {
|
||||
// File already removed or inaccessible; drop entry.
|
||||
// File already removed or inaccessible; clean up
|
||||
// any leftover thumbnail and drop the entry.
|
||||
if err := os.Remove(rec.thumbPath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
p.logger.Warn(ctx, "failed to remove stale thumbnail file",
|
||||
slog.F("recording_id", id),
|
||||
slog.F("thumbnail_path", rec.thumbPath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
delete(p.recordings, id)
|
||||
continue
|
||||
}
|
||||
if p.clock.Since(info.ModTime()) > time.Hour {
|
||||
if err := os.Remove(rec.filePath); err != nil && !os.IsNotExist(err) {
|
||||
if err := os.Remove(rec.filePath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
p.logger.Warn(ctx, "failed to remove stale recording file",
|
||||
slog.F("recording_id", id),
|
||||
slog.F("file_path", rec.filePath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
if err := os.Remove(rec.thumbPath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
p.logger.Warn(ctx, "failed to remove stale thumbnail file",
|
||||
slog.F("recording_id", id),
|
||||
slog.F("thumbnail_path", rec.thumbPath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
delete(p.recordings, id)
|
||||
}
|
||||
}
|
||||
@@ -603,13 +654,14 @@ func (p *portableDesktop) Close() error {
|
||||
// Snapshot recording file paths and idle goroutine channels
|
||||
// for cleanup, then clear the map.
|
||||
type recEntry struct {
|
||||
id string
|
||||
filePath string
|
||||
idleDone chan struct{}
|
||||
id string
|
||||
filePath string
|
||||
thumbPath string
|
||||
idleDone chan struct{}
|
||||
}
|
||||
var allRecs []recEntry
|
||||
for id, rec := range p.recordings {
|
||||
allRecs = append(allRecs, recEntry{id: id, filePath: rec.filePath, idleDone: rec.idleDone})
|
||||
allRecs = append(allRecs, recEntry{id: id, filePath: rec.filePath, thumbPath: rec.thumbPath, idleDone: rec.idleDone})
|
||||
delete(p.recordings, id)
|
||||
}
|
||||
session := p.session
|
||||
@@ -630,13 +682,20 @@ func (p *portableDesktop) Close() error {
|
||||
go func() {
|
||||
defer close(cleanupDone)
|
||||
for _, entry := range allRecs {
|
||||
if err := os.Remove(entry.filePath); err != nil && !os.IsNotExist(err) {
|
||||
if err := os.Remove(entry.filePath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
p.logger.Warn(context.Background(), "failed to remove recording file on close",
|
||||
slog.F("recording_id", entry.id),
|
||||
slog.F("file_path", entry.filePath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
if err := os.Remove(entry.thumbPath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
p.logger.Warn(context.Background(), "failed to remove thumbnail file on close",
|
||||
slog.F("recording_id", entry.id),
|
||||
slog.F("thumbnail_path", entry.thumbPath),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
if session != nil {
|
||||
session.cancel()
|
||||
|
||||
@@ -2,6 +2,7 @@ package agentdesktop
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
@@ -584,6 +585,7 @@ func TestPortableDesktop_StartRecording(t *testing.T) {
|
||||
joined := strings.Join(cmd, " ")
|
||||
if strings.Contains(joined, "record") && strings.Contains(joined, "coder-recording-"+recID) {
|
||||
found = true
|
||||
assert.Contains(t, joined, "--thumbnail", "record command should include --thumbnail flag")
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -666,6 +668,66 @@ func TestPortableDesktop_StopRecording_ReturnsArtifact(t *testing.T) {
|
||||
defer artifact.Reader.Close()
|
||||
assert.Equal(t, int64(len("fake-mp4-data")), artifact.Size)
|
||||
|
||||
// No thumbnail file exists, so ThumbnailReader should be nil.
|
||||
assert.Nil(t, artifact.ThumbnailReader, "ThumbnailReader should be nil when no thumbnail file exists")
|
||||
|
||||
require.NoError(t, pd.Close())
|
||||
}
|
||||
|
||||
func TestPortableDesktop_StopRecording_WithThumbnail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
rec := &recordedExecer{
|
||||
scripts: map[string]string{
|
||||
"record": `trap 'exit 0' INT; sleep 120 & wait`,
|
||||
"up": `printf '{"vncPort":5901,"geometry":"1920x1080"}\n' && sleep 120`,
|
||||
},
|
||||
}
|
||||
|
||||
clk := quartz.NewReal()
|
||||
pd := &portableDesktop{
|
||||
logger: logger,
|
||||
execer: rec,
|
||||
scriptBinDir: t.TempDir(),
|
||||
clock: clk,
|
||||
binPath: "portabledesktop",
|
||||
recordings: make(map[string]*recordingProcess),
|
||||
}
|
||||
pd.lastDesktopActionAt.Store(clk.Now().UnixNano())
|
||||
|
||||
ctx := t.Context()
|
||||
recID := uuid.New().String()
|
||||
err := pd.StartRecording(ctx, recID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Write a dummy MP4 file at the expected path.
|
||||
filePath := filepath.Join(os.TempDir(), "coder-recording-"+recID+".mp4")
|
||||
require.NoError(t, os.WriteFile(filePath, []byte("fake-mp4-data"), 0o600))
|
||||
t.Cleanup(func() { _ = os.Remove(filePath) })
|
||||
|
||||
// Write a thumbnail file at the expected path.
|
||||
thumbPath := filepath.Join(os.TempDir(), "coder-recording-"+recID+".thumb.jpg")
|
||||
thumbContent := []byte("fake-jpeg-thumbnail")
|
||||
require.NoError(t, os.WriteFile(thumbPath, thumbContent, 0o600))
|
||||
t.Cleanup(func() { _ = os.Remove(thumbPath) })
|
||||
|
||||
artifact, err := pd.StopRecording(ctx, recID)
|
||||
require.NoError(t, err)
|
||||
defer artifact.Reader.Close()
|
||||
|
||||
assert.Equal(t, int64(len("fake-mp4-data")), artifact.Size)
|
||||
|
||||
// Thumbnail should be attached.
|
||||
require.NotNil(t, artifact.ThumbnailReader, "ThumbnailReader should be non-nil when thumbnail file exists")
|
||||
defer artifact.ThumbnailReader.Close()
|
||||
assert.Equal(t, int64(len(thumbContent)), artifact.ThumbnailSize)
|
||||
|
||||
// Read and verify thumbnail content.
|
||||
thumbData, err := io.ReadAll(artifact.ThumbnailReader)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, thumbContent, thumbData)
|
||||
|
||||
require.NoError(t, pd.Close())
|
||||
}
|
||||
|
||||
|
||||
@@ -87,6 +87,12 @@ func IsDevVersion(v string) bool {
|
||||
return strings.Contains(v, "-"+develPreRelease)
|
||||
}
|
||||
|
||||
// IsRCVersion returns true if the version has a release candidate
|
||||
// pre-release tag, e.g. "v2.31.0-rc.0".
|
||||
func IsRCVersion(v string) bool {
|
||||
return strings.Contains(v, "-rc.")
|
||||
}
|
||||
|
||||
// IsDev returns true if this is a development build.
|
||||
// CI builds are also considered development builds.
|
||||
func IsDev() bool {
|
||||
|
||||
@@ -102,3 +102,29 @@ func TestBuildInfo(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsRCVersion(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
version string
|
||||
expected bool
|
||||
}{
|
||||
{"RC0", "v2.31.0-rc.0", true},
|
||||
{"RC1WithBuild", "v2.31.0-rc.1+abc123", true},
|
||||
{"RC10", "v2.31.0-rc.10", true},
|
||||
{"RCDevel", "v2.33.0-rc.1-devel+727ec00f7", true},
|
||||
{"DevelVersion", "v2.31.0-devel+abc123", false},
|
||||
{"StableVersion", "v2.31.0", false},
|
||||
{"DevNoVersion", "v0.0.0-devel+abc123", false},
|
||||
{"BetaVersion", "v2.31.0-beta.1", false},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, c.expected, buildinfo.IsRCVersion(c.version))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
+22
-4
@@ -53,6 +53,8 @@ func workspaceAgent() *serpent.Command {
|
||||
slogJSONPath string
|
||||
slogStackdriverPath string
|
||||
blockFileTransfer bool
|
||||
blockReversePortForwarding bool
|
||||
blockLocalPortForwarding bool
|
||||
agentHeaderCommand string
|
||||
agentHeader []string
|
||||
devcontainers bool
|
||||
@@ -319,10 +321,12 @@ func workspaceAgent() *serpent.Command {
|
||||
SSHMaxTimeout: sshMaxTimeout,
|
||||
Subsystems: subsystems,
|
||||
|
||||
PrometheusRegistry: prometheusRegistry,
|
||||
BlockFileTransfer: blockFileTransfer,
|
||||
Execer: execer,
|
||||
Devcontainers: devcontainers,
|
||||
PrometheusRegistry: prometheusRegistry,
|
||||
BlockFileTransfer: blockFileTransfer,
|
||||
BlockReversePortForwarding: blockReversePortForwarding,
|
||||
BlockLocalPortForwarding: blockLocalPortForwarding,
|
||||
Execer: execer,
|
||||
Devcontainers: devcontainers,
|
||||
DevcontainerAPIOptions: []agentcontainers.Option{
|
||||
agentcontainers.WithSubAgentURL(agentAuth.agentURL.String()),
|
||||
agentcontainers.WithProjectDiscovery(devcontainerProjectDiscovery),
|
||||
@@ -493,6 +497,20 @@ func workspaceAgent() *serpent.Command {
|
||||
Description: fmt.Sprintf("Block file transfer using known applications: %s.", strings.Join(agentssh.BlockedFileTransferCommands, ",")),
|
||||
Value: serpent.BoolOf(&blockFileTransfer),
|
||||
},
|
||||
{
|
||||
Flag: "block-reverse-port-forwarding",
|
||||
Default: "false",
|
||||
Env: "CODER_AGENT_BLOCK_REVERSE_PORT_FORWARDING",
|
||||
Description: "Block reverse port forwarding through the SSH server (ssh -R).",
|
||||
Value: serpent.BoolOf(&blockReversePortForwarding),
|
||||
},
|
||||
{
|
||||
Flag: "block-local-port-forwarding",
|
||||
Default: "false",
|
||||
Env: "CODER_AGENT_BLOCK_LOCAL_PORT_FORWARDING",
|
||||
Description: "Block local port forwarding through the SSH server (ssh -L).",
|
||||
Value: serpent.BoolOf(&blockLocalPortForwarding),
|
||||
},
|
||||
{
|
||||
Flag: "devcontainers-enable",
|
||||
Default: "true",
|
||||
|
||||
+194
@@ -0,0 +1,194 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/agent/agentcontextconfig"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func (r *RootCmd) chatCommand() *serpent.Command {
|
||||
return &serpent.Command{
|
||||
Use: "chat",
|
||||
Short: "Manage agent chats",
|
||||
Long: "Commands for interacting with chats from within a workspace.",
|
||||
Handler: func(i *serpent.Invocation) error {
|
||||
return i.Command.HelpHandler(i)
|
||||
},
|
||||
Children: []*serpent.Command{
|
||||
r.chatContextCommand(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RootCmd) chatContextCommand() *serpent.Command {
|
||||
return &serpent.Command{
|
||||
Use: "context",
|
||||
Short: "Manage chat context",
|
||||
Long: "Add or clear context files and skills for an active chat session.",
|
||||
Handler: func(i *serpent.Invocation) error {
|
||||
return i.Command.HelpHandler(i)
|
||||
},
|
||||
Children: []*serpent.Command{
|
||||
r.chatContextAddCommand(),
|
||||
r.chatContextClearCommand(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (*RootCmd) chatContextAddCommand() *serpent.Command {
|
||||
var (
|
||||
dir string
|
||||
chatID string
|
||||
)
|
||||
agentAuth := &AgentAuth{}
|
||||
cmd := &serpent.Command{
|
||||
Use: "add",
|
||||
Short: "Add context to an active chat",
|
||||
Long: "Read instruction files and discover skills from a directory, then add " +
|
||||
"them as context to an active chat session. Multiple calls " +
|
||||
"are additive.",
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
ctx := inv.Context()
|
||||
ctx, stop := inv.SignalNotifyContext(ctx, StopSignals...)
|
||||
defer stop()
|
||||
|
||||
if dir == "" && inv.Environ.Get("CODER") != "true" {
|
||||
return xerrors.New("this command must be run inside a Coder workspace (set --dir to override)")
|
||||
}
|
||||
|
||||
client, err := agentAuth.CreateClient()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create agent client: %w", err)
|
||||
}
|
||||
|
||||
resolvedDir := dir
|
||||
if resolvedDir == "" {
|
||||
resolvedDir, err = os.Getwd()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get working directory: %w", err)
|
||||
}
|
||||
}
|
||||
resolvedDir, err = filepath.Abs(resolvedDir)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("resolve directory: %w", err)
|
||||
}
|
||||
info, err := os.Stat(resolvedDir)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("cannot read directory %q: %w", resolvedDir, err)
|
||||
}
|
||||
if !info.IsDir() {
|
||||
return xerrors.Errorf("%q is not a directory", resolvedDir)
|
||||
}
|
||||
|
||||
parts := agentcontextconfig.ContextPartsFromDir(resolvedDir)
|
||||
if len(parts) == 0 {
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "No context files or skills found in "+resolvedDir)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Resolve chat ID from flag or auto-detect.
|
||||
resolvedChatID, err := parseChatID(chatID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := client.AddChatContext(ctx, agentsdk.AddChatContextRequest{
|
||||
ChatID: resolvedChatID,
|
||||
Parts: parts,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("add chat context: %w", err)
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(inv.Stdout, "Added %d context part(s) to chat %s\n", resp.Count, resp.ChatID)
|
||||
return nil
|
||||
},
|
||||
Options: serpent.OptionSet{
|
||||
{
|
||||
Name: "Directory",
|
||||
Flag: "dir",
|
||||
Description: "Directory to read context files and skills from. Defaults to the current working directory.",
|
||||
Value: serpent.StringOf(&dir),
|
||||
},
|
||||
{
|
||||
Name: "Chat ID",
|
||||
Flag: "chat",
|
||||
Env: "CODER_CHAT_ID",
|
||||
Description: "Chat ID to add context to. Auto-detected from CODER_CHAT_ID, the only active chat, or the only top-level active chat.",
|
||||
Value: serpent.StringOf(&chatID),
|
||||
},
|
||||
},
|
||||
}
|
||||
agentAuth.AttachOptions(cmd, false)
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (*RootCmd) chatContextClearCommand() *serpent.Command {
|
||||
var chatID string
|
||||
agentAuth := &AgentAuth{}
|
||||
cmd := &serpent.Command{
|
||||
Use: "clear",
|
||||
Short: "Clear context from an active chat",
|
||||
Long: "Soft-delete all context-file and skill messages from an active chat. " +
|
||||
"The next turn will re-fetch default context from the agent.",
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
ctx := inv.Context()
|
||||
ctx, stop := inv.SignalNotifyContext(ctx, StopSignals...)
|
||||
defer stop()
|
||||
|
||||
client, err := agentAuth.CreateClient()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create agent client: %w", err)
|
||||
}
|
||||
|
||||
resolvedChatID, err := parseChatID(chatID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := client.ClearChatContext(ctx, agentsdk.ClearChatContextRequest{
|
||||
ChatID: resolvedChatID,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("clear chat context: %w", err)
|
||||
}
|
||||
|
||||
if resp.ChatID == uuid.Nil {
|
||||
_, _ = fmt.Fprintln(inv.Stdout, "No active chats to clear.")
|
||||
} else {
|
||||
_, _ = fmt.Fprintf(inv.Stdout, "Cleared context from chat %s\n", resp.ChatID)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Options: serpent.OptionSet{{
|
||||
Name: "Chat ID",
|
||||
Flag: "chat",
|
||||
Env: "CODER_CHAT_ID",
|
||||
Description: "Chat ID to clear context from. Auto-detected from CODER_CHAT_ID, the only active chat, or the only top-level active chat.",
|
||||
Value: serpent.StringOf(&chatID),
|
||||
}},
|
||||
}
|
||||
agentAuth.AttachOptions(cmd, false)
|
||||
return cmd
|
||||
}
|
||||
|
||||
// parseChatID returns the chat UUID from the flag value (which
|
||||
// serpent already populates from --chat or CODER_CHAT_ID). Returns
|
||||
// uuid.Nil if empty (the server will auto-detect).
|
||||
func parseChatID(flagValue string) (uuid.UUID, error) {
|
||||
if flagValue == "" {
|
||||
return uuid.Nil, nil
|
||||
}
|
||||
parsed, err := uuid.Parse(flagValue)
|
||||
if err != nil {
|
||||
return uuid.Nil, xerrors.Errorf("invalid chat ID %q: %w", flagValue, err)
|
||||
}
|
||||
return parsed, nil
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
package cli_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
)
|
||||
|
||||
func TestExpChatContextAdd(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("RequiresWorkspaceOrDir", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
inv, _ := clitest.New(t, "exp", "chat", "context", "add")
|
||||
|
||||
err := inv.Run()
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "this command must be run inside a Coder workspace")
|
||||
})
|
||||
|
||||
t.Run("AllowsExplicitDir", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
inv, _ := clitest.New(t, "exp", "chat", "context", "add", "--dir", t.TempDir())
|
||||
|
||||
err := inv.Run()
|
||||
if err != nil {
|
||||
require.NotContains(t, err.Error(), "this command must be run inside a Coder workspace")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("AllowsWorkspaceEnv", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
inv, _ := clitest.New(t, "exp", "chat", "context", "add")
|
||||
inv.Environ.Set("CODER", "true")
|
||||
|
||||
err := inv.Run()
|
||||
if err != nil {
|
||||
require.NotContains(t, err.Error(), "this command must be run inside a Coder workspace")
|
||||
}
|
||||
})
|
||||
}
|
||||
+29
-5
@@ -7,6 +7,7 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -148,6 +149,7 @@ func (r *RootCmd) AGPLExperimental() []*serpent.Command {
|
||||
return []*serpent.Command{
|
||||
r.scaletestCmd(),
|
||||
r.errorExample(),
|
||||
r.chatCommand(),
|
||||
r.mcpCommand(),
|
||||
r.promptExample(),
|
||||
r.rptyCommand(),
|
||||
@@ -710,7 +712,7 @@ func (r *RootCmd) createHTTPClient(ctx context.Context, serverURL *url.URL, inv
|
||||
transport = wrapTransportWithTelemetryHeader(transport, inv)
|
||||
transport = wrapTransportWithUserAgentHeader(transport, inv)
|
||||
if !r.noVersionCheck {
|
||||
transport = wrapTransportWithVersionMismatchCheck(transport, inv, buildinfo.Version(), func(ctx context.Context) (codersdk.BuildInfoResponse, error) {
|
||||
transport = wrapTransportWithVersionCheck(transport, inv, buildinfo.Version(), func(ctx context.Context) (codersdk.BuildInfoResponse, error) {
|
||||
// Create a new client without any wrapped transport
|
||||
// otherwise it creates an infinite loop!
|
||||
basicClient := codersdk.New(serverURL)
|
||||
@@ -1434,6 +1436,21 @@ func defaultUpgradeMessage(version string) string {
|
||||
return fmt.Sprintf("download the server version with: 'curl -L https://coder.com/install.sh | sh -s -- --version %s'", version)
|
||||
}
|
||||
|
||||
// serverVersionMessage returns a warning message if the server version
|
||||
// is a release candidate or development build. Returns empty string
|
||||
// for stable versions. RC is checked before devel because RC dev
|
||||
// builds (e.g. v2.33.0-rc.1-devel+hash) contain both tags.
|
||||
func serverVersionMessage(serverVersion string) string {
|
||||
switch {
|
||||
case buildinfo.IsRCVersion(serverVersion):
|
||||
return fmt.Sprintf("the server is running a release candidate of Coder (%s)", serverVersion)
|
||||
case buildinfo.IsDevVersion(serverVersion):
|
||||
return fmt.Sprintf("the server is running a development version of Coder (%s)", serverVersion)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// wrapTransportWithEntitlementsCheck adds a middleware to the HTTP transport
|
||||
// that checks for entitlement warnings and prints them to the user.
|
||||
func wrapTransportWithEntitlementsCheck(rt http.RoundTripper, w io.Writer) http.RoundTripper {
|
||||
@@ -1452,10 +1469,10 @@ func wrapTransportWithEntitlementsCheck(rt http.RoundTripper, w io.Writer) http.
|
||||
})
|
||||
}
|
||||
|
||||
// wrapTransportWithVersionMismatchCheck adds a middleware to the HTTP transport
|
||||
// that checks for version mismatches between the client and server. If a mismatch
|
||||
// is detected, a warning is printed to the user.
|
||||
func wrapTransportWithVersionMismatchCheck(rt http.RoundTripper, inv *serpent.Invocation, clientVersion string, getBuildInfo func(ctx context.Context) (codersdk.BuildInfoResponse, error)) http.RoundTripper {
|
||||
// wrapTransportWithVersionCheck adds a middleware to the HTTP transport
|
||||
// that checks the server version and warns about development builds,
|
||||
// release candidates, and client/server version mismatches.
|
||||
func wrapTransportWithVersionCheck(rt http.RoundTripper, inv *serpent.Invocation, clientVersion string, getBuildInfo func(ctx context.Context) (codersdk.BuildInfoResponse, error)) http.RoundTripper {
|
||||
var once sync.Once
|
||||
return roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
res, err := rt.RoundTrip(req)
|
||||
@@ -1467,9 +1484,16 @@ func wrapTransportWithVersionMismatchCheck(rt http.RoundTripper, inv *serpent.In
|
||||
if serverVersion == "" {
|
||||
return
|
||||
}
|
||||
// Warn about non-stable server versions. Skip
|
||||
// during tests to avoid polluting golden files.
|
||||
if msg := serverVersionMessage(serverVersion); msg != "" && flag.Lookup("test.v") == nil {
|
||||
warning := pretty.Sprint(cliui.DefaultStyles.Warn, msg)
|
||||
_, _ = fmt.Fprintln(inv.Stderr, warning)
|
||||
}
|
||||
if buildinfo.VersionsMatch(clientVersion, serverVersion) {
|
||||
return
|
||||
}
|
||||
|
||||
upgradeMessage := defaultUpgradeMessage(semver.Canonical(serverVersion))
|
||||
if serverInfo, err := getBuildInfo(inv.Context()); err == nil {
|
||||
switch {
|
||||
|
||||
@@ -91,7 +91,7 @@ func Test_formatExamples(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_wrapTransportWithVersionMismatchCheck(t *testing.T) {
|
||||
func Test_wrapTransportWithVersionCheck(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("NoOutput", func(t *testing.T) {
|
||||
@@ -102,7 +102,7 @@ func Test_wrapTransportWithVersionMismatchCheck(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
inv := cmd.Invoke()
|
||||
inv.Stderr = &buf
|
||||
rt := wrapTransportWithVersionMismatchCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
rt := wrapTransportWithVersionCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
@@ -131,7 +131,7 @@ func Test_wrapTransportWithVersionMismatchCheck(t *testing.T) {
|
||||
inv := cmd.Invoke()
|
||||
inv.Stderr = &buf
|
||||
expectedUpgradeMessage := "My custom upgrade message"
|
||||
rt := wrapTransportWithVersionMismatchCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
rt := wrapTransportWithVersionCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
@@ -159,6 +159,53 @@ func Test_wrapTransportWithVersionMismatchCheck(t *testing.T) {
|
||||
expectedOutput := fmt.Sprintln(pretty.Sprint(cliui.DefaultStyles.Warn, fmtOutput))
|
||||
require.Equal(t, expectedOutput, buf.String())
|
||||
})
|
||||
|
||||
t.Run("ServerStableVersion", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := &RootCmd{}
|
||||
cmd, err := r.Command(nil)
|
||||
require.NoError(t, err)
|
||||
var buf bytes.Buffer
|
||||
inv := cmd.Invoke()
|
||||
inv.Stderr = &buf
|
||||
rt := wrapTransportWithVersionCheck(roundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
codersdk.BuildVersionHeader: []string{"v2.31.0"},
|
||||
},
|
||||
Body: io.NopCloser(nil),
|
||||
}, nil
|
||||
}), inv, "v2.31.0", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
res, err := rt.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
require.Empty(t, buf.String())
|
||||
})
|
||||
}
|
||||
|
||||
func Test_serverVersionMessage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
version string
|
||||
expected string
|
||||
}{
|
||||
{"Stable", "v2.31.0", ""},
|
||||
{"Dev", "v0.0.0-devel+abc123", "the server is running a development version of Coder (v0.0.0-devel+abc123)"},
|
||||
{"RC", "v2.31.0-rc.1", "the server is running a release candidate of Coder (v2.31.0-rc.1)"},
|
||||
{"RCDevel", "v2.33.0-rc.1-devel+727ec00f7", "the server is running a release candidate of Coder (v2.33.0-rc.1-devel+727ec00f7)"},
|
||||
{"Empty", "", ""},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.Equal(t, c.expected, serverVersionMessage(c.version))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_wrapTransportWithTelemetryHeader(t *testing.T) {
|
||||
|
||||
+99
-17
@@ -52,6 +52,10 @@ import (
|
||||
|
||||
const (
|
||||
disableUsageApp = "disable"
|
||||
|
||||
// Retry transient errors during SSH connection establishment.
|
||||
sshRetryInterval = 2 * time.Second
|
||||
sshMaxAttempts = 10 // initial + retries per step
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -62,6 +66,53 @@ var (
|
||||
workspaceNameRe = regexp.MustCompile(`[/.]+|--`)
|
||||
)
|
||||
|
||||
// isRetryableError checks for transient connection errors worth
|
||||
// retrying: DNS failures, connection refused, and server 5xx.
|
||||
func isRetryableError(err error) bool {
|
||||
if err == nil || xerrors.Is(err, context.Canceled) {
|
||||
return false
|
||||
}
|
||||
// Check connection errors before context.DeadlineExceeded because
|
||||
// net.Dialer.Timeout produces *net.OpError that matches both.
|
||||
if codersdk.IsConnectionError(err) {
|
||||
return true
|
||||
}
|
||||
if xerrors.Is(err, context.DeadlineExceeded) {
|
||||
return false
|
||||
}
|
||||
var sdkErr *codersdk.Error
|
||||
if xerrors.As(err, &sdkErr) {
|
||||
return sdkErr.StatusCode() >= 500
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// retryWithInterval calls fn up to maxAttempts times, waiting
|
||||
// interval between attempts. Stops on success, non-retryable
|
||||
// error, or context cancellation.
|
||||
func retryWithInterval(ctx context.Context, logger slog.Logger, interval time.Duration, maxAttempts int, fn func() error) error {
|
||||
var lastErr error
|
||||
attempt := 0
|
||||
for r := retry.New(interval, interval); r.Wait(ctx); {
|
||||
lastErr = fn()
|
||||
if lastErr == nil || !isRetryableError(lastErr) {
|
||||
return lastErr
|
||||
}
|
||||
attempt++
|
||||
if attempt >= maxAttempts {
|
||||
break
|
||||
}
|
||||
logger.Warn(ctx, "transient error, retrying",
|
||||
slog.Error(lastErr),
|
||||
slog.F("attempt", attempt),
|
||||
)
|
||||
}
|
||||
if lastErr != nil {
|
||||
return lastErr
|
||||
}
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
func (r *RootCmd) ssh() *serpent.Command {
|
||||
var (
|
||||
stdio bool
|
||||
@@ -277,10 +328,17 @@ func (r *RootCmd) ssh() *serpent.Command {
|
||||
HostnameSuffix: hostnameSuffix,
|
||||
}
|
||||
|
||||
workspace, workspaceAgent, err := findWorkspaceAndAgentByHostname(
|
||||
ctx, inv, client,
|
||||
inv.Args[0], cliConfig, disableAutostart)
|
||||
if err != nil {
|
||||
// Populated by the closure below.
|
||||
var workspace codersdk.Workspace
|
||||
var workspaceAgent codersdk.WorkspaceAgent
|
||||
resolveWorkspace := func() error {
|
||||
var err error
|
||||
workspace, workspaceAgent, err = findWorkspaceAndAgentByHostname(
|
||||
ctx, inv, client,
|
||||
inv.Args[0], cliConfig, disableAutostart)
|
||||
return err
|
||||
}
|
||||
if err := retryWithInterval(ctx, logger, sshRetryInterval, sshMaxAttempts, resolveWorkspace); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -306,8 +364,13 @@ func (r *RootCmd) ssh() *serpent.Command {
|
||||
wait = false
|
||||
}
|
||||
|
||||
templateVersion, err := client.TemplateVersion(ctx, workspace.LatestBuild.TemplateVersionID)
|
||||
if err != nil {
|
||||
var templateVersion codersdk.TemplateVersion
|
||||
fetchVersion := func() error {
|
||||
var err error
|
||||
templateVersion, err = client.TemplateVersion(ctx, workspace.LatestBuild.TemplateVersionID)
|
||||
return err
|
||||
}
|
||||
if err := retryWithInterval(ctx, logger, sshRetryInterval, sshMaxAttempts, fetchVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -347,8 +410,12 @@ func (r *RootCmd) ssh() *serpent.Command {
|
||||
// If we're in stdio mode, check to see if we can use Coder Connect.
|
||||
// We don't support Coder Connect over non-stdio coder ssh yet.
|
||||
if stdio && !forceNewTunnel {
|
||||
connInfo, err := wsClient.AgentConnectionInfoGeneric(ctx)
|
||||
if err != nil {
|
||||
var connInfo workspacesdk.AgentConnectionInfo
|
||||
if err := retryWithInterval(ctx, logger, sshRetryInterval, sshMaxAttempts, func() error {
|
||||
var err error
|
||||
connInfo, err = wsClient.AgentConnectionInfoGeneric(ctx)
|
||||
return err
|
||||
}); err != nil {
|
||||
return xerrors.Errorf("get agent connection info: %w", err)
|
||||
}
|
||||
coderConnectHost := fmt.Sprintf("%s.%s.%s.%s",
|
||||
@@ -384,23 +451,27 @@ func (r *RootCmd) ssh() *serpent.Command {
|
||||
})
|
||||
defer closeUsage()
|
||||
}
|
||||
return runCoderConnectStdio(ctx, fmt.Sprintf("%s:22", coderConnectHost), stdioReader, stdioWriter, stack)
|
||||
return runCoderConnectStdio(ctx, fmt.Sprintf("%s:22", coderConnectHost), stdioReader, stdioWriter, stack, logger)
|
||||
}
|
||||
}
|
||||
|
||||
if r.disableDirect {
|
||||
_, _ = fmt.Fprintln(inv.Stderr, "Direct connections disabled.")
|
||||
}
|
||||
conn, err := wsClient.
|
||||
DialAgent(ctx, workspaceAgent.ID, &workspacesdk.DialAgentOptions{
|
||||
var conn workspacesdk.AgentConn
|
||||
if err := retryWithInterval(ctx, logger, sshRetryInterval, sshMaxAttempts, func() error {
|
||||
var err error
|
||||
conn, err = wsClient.DialAgent(ctx, workspaceAgent.ID, &workspacesdk.DialAgentOptions{
|
||||
Logger: logger,
|
||||
BlockEndpoints: r.disableDirect,
|
||||
EnableTelemetry: !r.disableNetworkTelemetry,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}); err != nil {
|
||||
return xerrors.Errorf("dial agent: %w", err)
|
||||
}
|
||||
if err = stack.push("agent conn", conn); err != nil {
|
||||
_ = conn.Close()
|
||||
return err
|
||||
}
|
||||
conn.AwaitReachable(ctx)
|
||||
@@ -1578,16 +1649,27 @@ func WithTestOnlyCoderConnectDialer(ctx context.Context, dialer coderConnectDial
|
||||
func testOrDefaultDialer(ctx context.Context) coderConnectDialer {
|
||||
dialer, ok := ctx.Value(coderConnectDialerContextKey{}).(coderConnectDialer)
|
||||
if !ok || dialer == nil {
|
||||
return &net.Dialer{}
|
||||
// Timeout prevents hanging on broken tunnels (OS default is very long).
|
||||
return &net.Dialer{
|
||||
Timeout: 5 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}
|
||||
}
|
||||
return dialer
|
||||
}
|
||||
|
||||
func runCoderConnectStdio(ctx context.Context, addr string, stdin io.Reader, stdout io.Writer, stack *closerStack) error {
|
||||
func runCoderConnectStdio(ctx context.Context, addr string, stdin io.Reader, stdout io.Writer, stack *closerStack, logger slog.Logger) error {
|
||||
dialer := testOrDefaultDialer(ctx)
|
||||
conn, err := dialer.DialContext(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("dial coder connect host: %w", err)
|
||||
var conn net.Conn
|
||||
if err := retryWithInterval(ctx, logger, sshRetryInterval, sshMaxAttempts, func() error {
|
||||
var err error
|
||||
conn, err = dialer.DialContext(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("dial coder connect host %q over tcp: %w", addr, err)
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := stack.push("tcp conn", conn); err != nil {
|
||||
return err
|
||||
|
||||
+166
-1
@@ -5,7 +5,9 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -226,6 +228,41 @@ func TestCloserStack_Timeout(t *testing.T) {
|
||||
testutil.TryReceive(ctx, t, closed)
|
||||
}
|
||||
|
||||
func TestCloserStack_PushAfterClose_ConnClosed(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
uut := newCloserStack(ctx, logger, quartz.NewMock(t))
|
||||
|
||||
uut.close(xerrors.New("canceled"))
|
||||
|
||||
closes := new([]*fakeCloser)
|
||||
fc := &fakeCloser{closes: closes}
|
||||
err := uut.push("conn", fc)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, []*fakeCloser{fc}, *closes, "should close conn on failed push")
|
||||
}
|
||||
|
||||
func TestCoderConnectDialer_DefaultTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
dialer := testOrDefaultDialer(ctx)
|
||||
d, ok := dialer.(*net.Dialer)
|
||||
require.True(t, ok, "expected *net.Dialer")
|
||||
assert.Equal(t, 5*time.Second, d.Timeout)
|
||||
assert.Equal(t, 30*time.Second, d.KeepAlive)
|
||||
}
|
||||
|
||||
func TestCoderConnectDialer_Overridden(t *testing.T) {
|
||||
t.Parallel()
|
||||
custom := &net.Dialer{Timeout: 99 * time.Second}
|
||||
ctx := WithTestOnlyCoderConnectDialer(context.Background(), custom)
|
||||
|
||||
dialer := testOrDefaultDialer(ctx)
|
||||
assert.Equal(t, custom, dialer)
|
||||
}
|
||||
|
||||
func TestCoderConnectStdio(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -254,7 +291,7 @@ func TestCoderConnectStdio(t *testing.T) {
|
||||
|
||||
stdioDone := make(chan struct{})
|
||||
go func() {
|
||||
err = runCoderConnectStdio(ctx, ln.Addr().String(), clientOutput, serverInput, stack)
|
||||
err = runCoderConnectStdio(ctx, ln.Addr().String(), clientOutput, serverInput, stack, logger)
|
||||
assert.NoError(t, err)
|
||||
close(stdioDone)
|
||||
}()
|
||||
@@ -448,3 +485,131 @@ func Test_getWorkspaceAgent(t *testing.T) {
|
||||
assert.Contains(t, err.Error(), "available agents: [clark krypton zod]")
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsRetryableError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
retryable bool
|
||||
}{
|
||||
{"Nil", nil, false},
|
||||
{"ContextCanceled", context.Canceled, false},
|
||||
{"ContextDeadlineExceeded", context.DeadlineExceeded, false},
|
||||
{"WrappedContextCanceled", xerrors.Errorf("wrapped: %w", context.Canceled), false},
|
||||
{"DNSError", &net.DNSError{Err: "no such host", Name: "example.com", IsNotFound: true}, true},
|
||||
{"OpError", &net.OpError{Op: "dial", Net: "tcp", Err: &os.SyscallError{}}, true},
|
||||
{"WrappedDNSError", xerrors.Errorf("connect: %w", &net.DNSError{Err: "no such host", Name: "example.com"}), true},
|
||||
{"SDKError_500", codersdk.NewTestError(http.StatusInternalServerError, "GET", "/api"), true},
|
||||
{"SDKError_502", codersdk.NewTestError(http.StatusBadGateway, "GET", "/api"), true},
|
||||
{"SDKError_503", codersdk.NewTestError(http.StatusServiceUnavailable, "GET", "/api"), true},
|
||||
{"SDKError_401", codersdk.NewTestError(http.StatusUnauthorized, "GET", "/api"), false},
|
||||
{"SDKError_403", codersdk.NewTestError(http.StatusForbidden, "GET", "/api"), false},
|
||||
{"SDKError_404", codersdk.NewTestError(http.StatusNotFound, "GET", "/api"), false},
|
||||
{"GenericError", xerrors.New("something went wrong"), false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.Equal(t, tt.retryable, isRetryableError(tt.err))
|
||||
})
|
||||
}
|
||||
|
||||
// net.Dialer.Timeout produces *net.OpError that matches both
|
||||
// IsConnectionError and context.DeadlineExceeded. Verify it is retryable.
|
||||
t.Run("DialTimeout", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancel := context.WithDeadline(context.Background(), time.Now())
|
||||
defer cancel()
|
||||
<-ctx.Done() // ensure deadline has fired
|
||||
_, err := (&net.Dialer{}).DialContext(ctx, "tcp", "127.0.0.1:1")
|
||||
require.Error(t, err)
|
||||
// Proves the ambiguity: this error matches BOTH checks.
|
||||
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
require.ErrorAs(t, err, new(*net.OpError))
|
||||
assert.True(t, isRetryableError(err))
|
||||
// Also when wrapped, as runCoderConnectStdio does.
|
||||
assert.True(t, isRetryableError(xerrors.Errorf("dial coder connect: %w", err)))
|
||||
})
|
||||
}
|
||||
|
||||
func TestRetryWithInterval(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const interval = time.Millisecond
|
||||
const maxAttempts = 3
|
||||
|
||||
dnsErr := &net.DNSError{Err: "no such host", Name: "example.com", IsNotFound: true}
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
|
||||
t.Run("Succeeds_FirstTry", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
attempts := 0
|
||||
err := retryWithInterval(ctx, logger, interval, maxAttempts, func() error {
|
||||
attempts++
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, attempts)
|
||||
})
|
||||
|
||||
t.Run("Succeeds_AfterTransientFailures", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
attempts := 0
|
||||
err := retryWithInterval(ctx, logger, interval, maxAttempts, func() error {
|
||||
attempts++
|
||||
if attempts < 3 {
|
||||
return dnsErr
|
||||
}
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, attempts)
|
||||
})
|
||||
|
||||
t.Run("Stops_NonRetryableError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
attempts := 0
|
||||
err := retryWithInterval(ctx, logger, interval, maxAttempts, func() error {
|
||||
attempts++
|
||||
return xerrors.New("permanent failure")
|
||||
})
|
||||
require.ErrorContains(t, err, "permanent failure")
|
||||
assert.Equal(t, 1, attempts)
|
||||
})
|
||||
|
||||
t.Run("Stops_MaxAttemptsExhausted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
attempts := 0
|
||||
err := retryWithInterval(ctx, logger, interval, maxAttempts, func() error {
|
||||
attempts++
|
||||
return dnsErr
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, maxAttempts, attempts)
|
||||
})
|
||||
|
||||
t.Run("Stops_ContextCanceled", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
attempts := 0
|
||||
err := retryWithInterval(ctx, logger, interval, maxAttempts, func() error {
|
||||
attempts++
|
||||
cancel()
|
||||
return dnsErr
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, 1, attempts)
|
||||
})
|
||||
}
|
||||
|
||||
+6
@@ -39,6 +39,12 @@ OPTIONS:
|
||||
--block-file-transfer bool, $CODER_AGENT_BLOCK_FILE_TRANSFER (default: false)
|
||||
Block file transfer using known applications: nc,rsync,scp,sftp.
|
||||
|
||||
--block-local-port-forwarding bool, $CODER_AGENT_BLOCK_LOCAL_PORT_FORWARDING (default: false)
|
||||
Block local port forwarding through the SSH server (ssh -L).
|
||||
|
||||
--block-reverse-port-forwarding bool, $CODER_AGENT_BLOCK_REVERSE_PORT_FORWARDING (default: false)
|
||||
Block reverse port forwarding through the SSH server (ssh -R).
|
||||
|
||||
--boundary-log-proxy-socket-path string, $CODER_AGENT_BOUNDARY_LOG_PROXY_SOCKET_PATH (default: /tmp/boundary-audit.sock)
|
||||
The path for the boundary log proxy server Unix socket. Boundary
|
||||
should write audit logs to this socket.
|
||||
|
||||
+1
-1
@@ -11,7 +11,7 @@ OPTIONS:
|
||||
-O, --org string, $CODER_ORGANIZATION
|
||||
Select which organization (uuid or name) to use.
|
||||
|
||||
-c, --column [id|created at|started at|completed at|canceled at|error|error code|status|worker id|worker name|file id|tags|queue position|queue size|organization id|initiator id|template version id|workspace build id|type|available workers|template version name|template id|template name|template display name|template icon|workspace id|workspace name|logs overflowed|organization|queue] (default: created at,id,type,template display name,status,queue,tags)
|
||||
-c, --column [id|created at|started at|completed at|canceled at|error|error code|status|worker id|worker name|file id|tags|queue position|queue size|organization id|initiator id|template version id|workspace build id|type|available workers|template version name|template id|template name|template display name|template icon|workspace id|workspace name|workspace build transition|logs overflowed|organization|queue] (default: created at,id,type,template display name,status,queue,tags)
|
||||
Columns to display in table output.
|
||||
|
||||
-i, --initiator string, $CODER_PROVISIONER_JOB_LIST_INITIATOR
|
||||
|
||||
@@ -58,7 +58,8 @@
|
||||
"template_display_name": "",
|
||||
"template_icon": "",
|
||||
"workspace_id": "===========[workspace ID]===========",
|
||||
"workspace_name": "test-workspace"
|
||||
"workspace_name": "test-workspace",
|
||||
"workspace_build_transition": "start"
|
||||
},
|
||||
"logs_overflowed": false,
|
||||
"organization_name": "Coder"
|
||||
|
||||
@@ -134,6 +134,7 @@ func TestUserCreate(t *testing.T) {
|
||||
{
|
||||
name: "ServiceAccount",
|
||||
args: []string{"--service-account", "-u", "dean"},
|
||||
err: "Premium feature",
|
||||
},
|
||||
{
|
||||
name: "ServiceAccountLoginType",
|
||||
|
||||
@@ -77,8 +77,9 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea
|
||||
level := make([]database.LogLevel, 0)
|
||||
outputLength := 0
|
||||
for _, logEntry := range req.Logs {
|
||||
output = append(output, logEntry.Output)
|
||||
outputLength += len(logEntry.Output)
|
||||
sanitizedOutput := agentsdk.SanitizeLogOutput(logEntry.Output)
|
||||
output = append(output, sanitizedOutput)
|
||||
outputLength += len(sanitizedOutput)
|
||||
|
||||
var dbLevel database.LogLevel
|
||||
switch logEntry.Level {
|
||||
|
||||
@@ -139,6 +139,59 @@ func TestBatchCreateLogs(t *testing.T) {
|
||||
require.True(t, publishWorkspaceAgentLogsUpdateCalled)
|
||||
})
|
||||
|
||||
t.Run("SanitizesOutput", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbM := dbmock.NewMockStore(gomock.NewController(t))
|
||||
now := dbtime.Now()
|
||||
api := &agentapi.LogsAPI{
|
||||
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
|
||||
return agent, nil
|
||||
},
|
||||
Database: dbM,
|
||||
Log: testutil.Logger(t),
|
||||
TimeNowFn: func() time.Time {
|
||||
return now
|
||||
},
|
||||
}
|
||||
|
||||
rawOutput := "before\x00middle\xc3\x28after"
|
||||
sanitizedOutput := agentsdk.SanitizeLogOutput(rawOutput)
|
||||
expectedOutputLength := int32(len(sanitizedOutput)) //nolint:gosec // Test-controlled string length is small.
|
||||
req := &agentproto.BatchCreateLogsRequest{
|
||||
LogSourceId: logSource.ID[:],
|
||||
Logs: []*agentproto.Log{
|
||||
{
|
||||
CreatedAt: timestamppb.New(now),
|
||||
Level: agentproto.Log_WARN,
|
||||
Output: rawOutput,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
dbM.EXPECT().InsertWorkspaceAgentLogs(gomock.Any(), database.InsertWorkspaceAgentLogsParams{
|
||||
AgentID: agent.ID,
|
||||
LogSourceID: logSource.ID,
|
||||
CreatedAt: now,
|
||||
Output: []string{sanitizedOutput},
|
||||
Level: []database.LogLevel{database.LogLevelWarn},
|
||||
OutputLength: expectedOutputLength,
|
||||
}).Return([]database.WorkspaceAgentLog{
|
||||
{
|
||||
AgentID: agent.ID,
|
||||
CreatedAt: now,
|
||||
ID: 1,
|
||||
Output: sanitizedOutput,
|
||||
Level: database.LogLevelWarn,
|
||||
LogSourceID: logSource.ID,
|
||||
},
|
||||
}, nil)
|
||||
|
||||
resp, err := api.BatchCreateLogs(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, &agentproto.BatchCreateLogsResponse{}, resp)
|
||||
})
|
||||
|
||||
t.Run("NoWorkspacePublishIfNotFirstLogs", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -71,7 +71,7 @@ func (a *SubAgentAPI) CreateSubAgent(ctx context.Context, req *agentproto.Create
|
||||
// An ID is only given in the request when it is a terraform-defined devcontainer
|
||||
// that has attached resources. These subagents are pre-provisioned by terraform
|
||||
// (the agent record already exists), so we update configurable fields like
|
||||
// display_apps rather than creating a new agent.
|
||||
// display_apps and directory rather than creating a new agent.
|
||||
if req.Id != nil {
|
||||
id, err := uuid.FromBytes(req.Id)
|
||||
if err != nil {
|
||||
@@ -97,6 +97,16 @@ func (a *SubAgentAPI) CreateSubAgent(ctx context.Context, req *agentproto.Create
|
||||
return nil, xerrors.Errorf("update workspace agent display apps: %w", err)
|
||||
}
|
||||
|
||||
if req.Directory != "" {
|
||||
if err := a.Database.UpdateWorkspaceAgentDirectoryByID(ctx, database.UpdateWorkspaceAgentDirectoryByIDParams{
|
||||
ID: id,
|
||||
Directory: req.Directory,
|
||||
UpdatedAt: createdAt,
|
||||
}); err != nil {
|
||||
return nil, xerrors.Errorf("update workspace agent directory: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &agentproto.CreateSubAgentResponse{
|
||||
Agent: &agentproto.SubAgent{
|
||||
Name: subAgent.Name,
|
||||
|
||||
@@ -1267,11 +1267,11 @@ func TestSubAgentAPI(t *testing.T) {
|
||||
agentID, err := uuid.FromBytes(resp.Agent.Id)
|
||||
require.NoError(t, err)
|
||||
|
||||
// And: The database agent's other fields are unchanged.
|
||||
// And: The database agent's name, architecture, and OS are unchanged.
|
||||
updatedAgent, err := db.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), agentID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, baseChildAgent.Name, updatedAgent.Name)
|
||||
require.Equal(t, baseChildAgent.Directory, updatedAgent.Directory)
|
||||
require.Equal(t, "/different/path", updatedAgent.Directory)
|
||||
require.Equal(t, baseChildAgent.Architecture, updatedAgent.Architecture)
|
||||
require.Equal(t, baseChildAgent.OperatingSystem, updatedAgent.OperatingSystem)
|
||||
|
||||
@@ -1280,6 +1280,42 @@ func TestSubAgentAPI(t *testing.T) {
|
||||
require.Equal(t, database.DisplayAppWebTerminal, updatedAgent.DisplayApps[0])
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "OK_DirectoryUpdated",
|
||||
setup: func(t *testing.T, db database.Store, agent database.WorkspaceAgent) *proto.CreateSubAgentRequest {
|
||||
// Given: An existing child agent with a stale host-side
|
||||
// directory (as set by the provisioner at build time).
|
||||
childAgent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
||||
ParentID: uuid.NullUUID{Valid: true, UUID: agent.ID},
|
||||
ResourceID: agent.ResourceID,
|
||||
Name: baseChildAgent.Name,
|
||||
Directory: "/home/coder/project",
|
||||
Architecture: baseChildAgent.Architecture,
|
||||
OperatingSystem: baseChildAgent.OperatingSystem,
|
||||
DisplayApps: baseChildAgent.DisplayApps,
|
||||
})
|
||||
|
||||
// When: Agent injection sends the correct
|
||||
// container-internal path.
|
||||
return &proto.CreateSubAgentRequest{
|
||||
Id: childAgent.ID[:],
|
||||
Directory: "/workspaces/project",
|
||||
DisplayApps: []proto.CreateSubAgentRequest_DisplayApp{
|
||||
proto.CreateSubAgentRequest_WEB_TERMINAL,
|
||||
},
|
||||
}
|
||||
},
|
||||
check: func(t *testing.T, ctx context.Context, db database.Store, resp *proto.CreateSubAgentResponse, agent database.WorkspaceAgent) {
|
||||
agentID, err := uuid.FromBytes(resp.Agent.Id)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Then: Directory is updated to the container-internal
|
||||
// path.
|
||||
updatedAgent, err := db.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), agentID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "/workspaces/project", updatedAgent.Directory)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Error/MalformedID",
|
||||
setup: func(t *testing.T, db database.Store, agent database.WorkspaceAgent) *proto.CreateSubAgentRequest {
|
||||
|
||||
Generated
+359
@@ -1266,6 +1266,68 @@ const docTemplate = `{
|
||||
]
|
||||
}
|
||||
},
|
||||
"/experimental/chats/config/retention-days": {
|
||||
"get": {
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Chats"
|
||||
],
|
||||
"summary": "Get chat retention days",
|
||||
"operationId": "get-chat-retention-days",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.ChatRetentionDaysResponse"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"x-apidocgen": {
|
||||
"skip": true
|
||||
}
|
||||
},
|
||||
"put": {
|
||||
"consumes": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Chats"
|
||||
],
|
||||
"summary": "Update chat retention days",
|
||||
"operationId": "update-chat-retention-days",
|
||||
"parameters": [
|
||||
{
|
||||
"description": "Request body",
|
||||
"name": "request",
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UpdateChatRetentionDaysRequest"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"204": {
|
||||
"description": "No Content"
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"x-apidocgen": {
|
||||
"skip": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"/experimental/watch-all-workspacebuilds": {
|
||||
"get": {
|
||||
"produces": [
|
||||
@@ -9452,6 +9514,212 @@ const docTemplate = `{
|
||||
]
|
||||
}
|
||||
},
|
||||
"/users/{user}/secrets": {
|
||||
"get": {
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Secrets"
|
||||
],
|
||||
"summary": "List user secrets",
|
||||
"operationId": "list-user-secrets",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"post": {
|
||||
"consumes": [
|
||||
"application/json"
|
||||
],
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Secrets"
|
||||
],
|
||||
"summary": "Create a new user secret",
|
||||
"operationId": "create-a-new-user-secret",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"description": "Create secret request",
|
||||
"name": "request",
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.CreateUserSecretRequest"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"201": {
|
||||
"description": "Created",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/users/{user}/secrets/{name}": {
|
||||
"get": {
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Secrets"
|
||||
],
|
||||
"summary": "Get a user secret by name",
|
||||
"operationId": "get-a-user-secret-by-name",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Secret name",
|
||||
"name": "name",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"delete": {
|
||||
"tags": [
|
||||
"Secrets"
|
||||
],
|
||||
"summary": "Delete a user secret",
|
||||
"operationId": "delete-a-user-secret",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Secret name",
|
||||
"name": "name",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"204": {
|
||||
"description": "No Content"
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"patch": {
|
||||
"consumes": [
|
||||
"application/json"
|
||||
],
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"Secrets"
|
||||
],
|
||||
"summary": "Update a user secret",
|
||||
"operationId": "update-a-user-secret",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Secret name",
|
||||
"name": "name",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"description": "Update secret request",
|
||||
"name": "request",
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UpdateUserSecretRequest"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/users/{user}/status/activate": {
|
||||
"put": {
|
||||
"produces": [
|
||||
@@ -13177,6 +13445,12 @@ const docTemplate = `{
|
||||
"$ref": "#/definitions/codersdk.AIBridgeAgenticAction"
|
||||
}
|
||||
},
|
||||
"credential_hint": {
|
||||
"type": "string"
|
||||
},
|
||||
"credential_kind": {
|
||||
"type": "string"
|
||||
},
|
||||
"ended_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
@@ -14420,6 +14694,14 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.ChatRetentionDaysResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"retention_days": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.ConnectionLatency": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -15072,6 +15354,26 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.CreateUserSecretRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"description": {
|
||||
"type": "string"
|
||||
},
|
||||
"env_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"file_path": {
|
||||
"type": "string"
|
||||
},
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"value": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.CreateWorkspaceBuildReason": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
@@ -18847,6 +19149,9 @@ const docTemplate = `{
|
||||
"template_version_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"workspace_build_transition": {
|
||||
"$ref": "#/definitions/codersdk.WorkspaceTransition"
|
||||
},
|
||||
"workspace_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
@@ -20952,6 +21257,14 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UpdateChatRetentionDaysRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"retention_days": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UpdateCheckResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -21193,6 +21506,23 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UpdateUserSecretRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"description": {
|
||||
"type": "string"
|
||||
},
|
||||
"env_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"file_path": {
|
||||
"type": "string"
|
||||
},
|
||||
"value": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UpdateWorkspaceACL": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -21648,6 +21978,35 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UserSecret": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"created_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"description": {
|
||||
"type": "string"
|
||||
},
|
||||
"env_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"file_path": {
|
||||
"type": "string"
|
||||
},
|
||||
"id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"updated_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UserStatus": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
|
||||
Generated
+329
@@ -1103,6 +1103,60 @@
|
||||
]
|
||||
}
|
||||
},
|
||||
"/experimental/chats/config/retention-days": {
|
||||
"get": {
|
||||
"produces": ["application/json"],
|
||||
"tags": ["Chats"],
|
||||
"summary": "Get chat retention days",
|
||||
"operationId": "get-chat-retention-days",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.ChatRetentionDaysResponse"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"x-apidocgen": {
|
||||
"skip": true
|
||||
}
|
||||
},
|
||||
"put": {
|
||||
"consumes": ["application/json"],
|
||||
"tags": ["Chats"],
|
||||
"summary": "Update chat retention days",
|
||||
"operationId": "update-chat-retention-days",
|
||||
"parameters": [
|
||||
{
|
||||
"description": "Request body",
|
||||
"name": "request",
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UpdateChatRetentionDaysRequest"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"204": {
|
||||
"description": "No Content"
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
],
|
||||
"x-apidocgen": {
|
||||
"skip": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"/experimental/watch-all-workspacebuilds": {
|
||||
"get": {
|
||||
"produces": ["application/json"],
|
||||
@@ -8377,6 +8431,190 @@
|
||||
]
|
||||
}
|
||||
},
|
||||
"/users/{user}/secrets": {
|
||||
"get": {
|
||||
"produces": ["application/json"],
|
||||
"tags": ["Secrets"],
|
||||
"summary": "List user secrets",
|
||||
"operationId": "list-user-secrets",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"post": {
|
||||
"consumes": ["application/json"],
|
||||
"produces": ["application/json"],
|
||||
"tags": ["Secrets"],
|
||||
"summary": "Create a new user secret",
|
||||
"operationId": "create-a-new-user-secret",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"description": "Create secret request",
|
||||
"name": "request",
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.CreateUserSecretRequest"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"201": {
|
||||
"description": "Created",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/users/{user}/secrets/{name}": {
|
||||
"get": {
|
||||
"produces": ["application/json"],
|
||||
"tags": ["Secrets"],
|
||||
"summary": "Get a user secret by name",
|
||||
"operationId": "get-a-user-secret-by-name",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Secret name",
|
||||
"name": "name",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"delete": {
|
||||
"tags": ["Secrets"],
|
||||
"summary": "Delete a user secret",
|
||||
"operationId": "delete-a-user-secret",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Secret name",
|
||||
"name": "name",
|
||||
"in": "path",
|
||||
"required": true
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"204": {
|
||||
"description": "No Content"
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"patch": {
|
||||
"consumes": ["application/json"],
|
||||
"produces": ["application/json"],
|
||||
"tags": ["Secrets"],
|
||||
"summary": "Update a user secret",
|
||||
"operationId": "update-a-user-secret",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "User ID, username, or me",
|
||||
"name": "user",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Secret name",
|
||||
"name": "name",
|
||||
"in": "path",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"description": "Update secret request",
|
||||
"name": "request",
|
||||
"in": "body",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UpdateUserSecretRequest"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.UserSecret"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/users/{user}/status/activate": {
|
||||
"put": {
|
||||
"produces": ["application/json"],
|
||||
@@ -11755,6 +11993,12 @@
|
||||
"$ref": "#/definitions/codersdk.AIBridgeAgenticAction"
|
||||
}
|
||||
},
|
||||
"credential_hint": {
|
||||
"type": "string"
|
||||
},
|
||||
"credential_kind": {
|
||||
"type": "string"
|
||||
},
|
||||
"ended_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
@@ -12963,6 +13207,14 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.ChatRetentionDaysResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"retention_days": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.ConnectionLatency": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -13581,6 +13833,26 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.CreateUserSecretRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"description": {
|
||||
"type": "string"
|
||||
},
|
||||
"env_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"file_path": {
|
||||
"type": "string"
|
||||
},
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"value": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.CreateWorkspaceBuildReason": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
@@ -17237,6 +17509,9 @@
|
||||
"template_version_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"workspace_build_transition": {
|
||||
"$ref": "#/definitions/codersdk.WorkspaceTransition"
|
||||
},
|
||||
"workspace_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
@@ -19243,6 +19518,14 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UpdateChatRetentionDaysRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"retention_days": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UpdateCheckResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -19475,6 +19758,23 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UpdateUserSecretRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"description": {
|
||||
"type": "string"
|
||||
},
|
||||
"env_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"file_path": {
|
||||
"type": "string"
|
||||
},
|
||||
"value": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UpdateWorkspaceACL": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -19905,6 +20205,35 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UserSecret": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"created_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"description": {
|
||||
"type": "string"
|
||||
},
|
||||
"env_name": {
|
||||
"type": "string"
|
||||
},
|
||||
"file_path": {
|
||||
"type": "string"
|
||||
},
|
||||
"id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"updated_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.UserStatus": {
|
||||
"type": "string",
|
||||
"enum": ["active", "dormant", "suspended"],
|
||||
|
||||
@@ -1189,6 +1189,8 @@ func New(options *Options) *API {
|
||||
r.Delete("/user-compaction-thresholds/{modelConfig}", api.deleteUserChatCompactionThreshold)
|
||||
r.Get("/workspace-ttl", api.getChatWorkspaceTTL)
|
||||
r.Put("/workspace-ttl", api.putChatWorkspaceTTL)
|
||||
r.Get("/retention-days", api.getChatRetentionDays)
|
||||
r.Put("/retention-days", api.putChatRetentionDays)
|
||||
r.Get("/template-allowlist", api.getChatTemplateAllowlist)
|
||||
r.Put("/template-allowlist", api.putChatTemplateAllowlist)
|
||||
})
|
||||
@@ -1243,6 +1245,7 @@ func New(options *Options) *API {
|
||||
r.Get("/git", api.watchChatGit)
|
||||
})
|
||||
r.Post("/interrupt", api.interruptChat)
|
||||
r.Post("/tool-results", api.postChatToolResults)
|
||||
r.Post("/title/regenerate", api.regenerateChatTitle)
|
||||
r.Get("/diff", api.getChatDiffContents)
|
||||
r.Route("/queue/{queuedMessage}", func(r chi.Router) {
|
||||
@@ -1605,6 +1608,15 @@ func New(options *Options) *API {
|
||||
|
||||
r.Get("/gitsshkey", api.gitSSHKey)
|
||||
r.Put("/gitsshkey", api.regenerateGitSSHKey)
|
||||
r.Route("/secrets", func(r chi.Router) {
|
||||
r.Post("/", api.postUserSecret)
|
||||
r.Get("/", api.getUserSecrets)
|
||||
r.Route("/{name}", func(r chi.Router) {
|
||||
r.Get("/", api.getUserSecret)
|
||||
r.Patch("/", api.patchUserSecret)
|
||||
r.Delete("/", api.deleteUserSecret)
|
||||
})
|
||||
})
|
||||
r.Route("/notifications", func(r chi.Router) {
|
||||
r.Route("/preferences", func(r chi.Router) {
|
||||
r.Get("/", api.userNotificationPreferences)
|
||||
@@ -1650,6 +1662,10 @@ func New(options *Options) *API {
|
||||
r.Get("/gitsshkey", api.agentGitSSHKey)
|
||||
r.Post("/log-source", api.workspaceAgentPostLogSource)
|
||||
r.Get("/reinit", api.workspaceAgentReinit)
|
||||
r.Route("/experimental", func(r chi.Router) {
|
||||
r.Post("/chat-context", api.workspaceAgentAddChatContext)
|
||||
r.Delete("/chat-context", api.workspaceAgentClearChatContext)
|
||||
})
|
||||
r.Route("/tasks/{task}", func(r chi.Router) {
|
||||
r.Post("/log-snapshot", api.postWorkspaceAgentTaskLogSnapshot)
|
||||
})
|
||||
|
||||
@@ -147,6 +147,10 @@ func parseSwaggerComment(commentGroup *ast.CommentGroup) SwaggerComment {
|
||||
return c
|
||||
}
|
||||
|
||||
func isExperimentalEndpoint(route string) bool {
|
||||
return strings.HasPrefix(route, "/workspaceagents/me/experimental/")
|
||||
}
|
||||
|
||||
func VerifySwaggerDefinitions(t *testing.T, router chi.Router, swaggerComments []SwaggerComment) {
|
||||
assertUniqueRoutes(t, swaggerComments)
|
||||
assertSingleAnnotations(t, swaggerComments)
|
||||
@@ -165,6 +169,9 @@ func VerifySwaggerDefinitions(t *testing.T, router chi.Router, swaggerComments [
|
||||
if strings.HasSuffix(route, "/*") {
|
||||
return
|
||||
}
|
||||
if isExperimentalEndpoint(route) {
|
||||
return
|
||||
}
|
||||
|
||||
c := findSwaggerCommentByMethodAndRoute(swaggerComments, method, route)
|
||||
assert.NotNil(t, c, "Missing @Router annotation")
|
||||
|
||||
@@ -123,6 +123,10 @@ func UsersPagination(
|
||||
require.Contains(t, gotUsers[0].Name, "after")
|
||||
}
|
||||
|
||||
type UsersFilterOptions struct {
|
||||
CreateServiceAccounts bool
|
||||
}
|
||||
|
||||
// UsersFilter creates a set of users to run various filters against for
|
||||
// testing. It can be used to test filtering both users and group members.
|
||||
func UsersFilter(
|
||||
@@ -130,11 +134,16 @@ func UsersFilter(
|
||||
t *testing.T,
|
||||
client *codersdk.Client,
|
||||
db database.Store,
|
||||
options *UsersFilterOptions,
|
||||
setup func(users []codersdk.User),
|
||||
fetch func(ctx context.Context, req codersdk.UsersRequest) []codersdk.ReducedUser,
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
if options == nil {
|
||||
options = &UsersFilterOptions{}
|
||||
}
|
||||
|
||||
firstUser, err := client.User(setupCtx, codersdk.Me)
|
||||
require.NoError(t, err, "fetch me")
|
||||
|
||||
@@ -211,11 +220,13 @@ func UsersFilter(
|
||||
}
|
||||
|
||||
// Add some service accounts.
|
||||
for range 3 {
|
||||
_, user := CreateAnotherUserMutators(t, client, orgID, nil, func(r *codersdk.CreateUserRequestWithOrgs) {
|
||||
r.ServiceAccount = true
|
||||
})
|
||||
users = append(users, user)
|
||||
if options.CreateServiceAccounts {
|
||||
for range 3 {
|
||||
_, user := CreateAnotherUserMutators(t, client, orgID, nil, func(r *codersdk.CreateUserRequestWithOrgs) {
|
||||
r.ServiceAccount = true
|
||||
})
|
||||
users = append(users, user)
|
||||
}
|
||||
}
|
||||
|
||||
hashedPassword, err := userpassword.Hash("SomeStrongPassword!")
|
||||
|
||||
@@ -538,6 +538,12 @@ func WorkspaceAgent(derpMap *tailcfg.DERPMap, coordinator tailnet.Coordinator,
|
||||
switch {
|
||||
case workspaceAgent.Status != codersdk.WorkspaceAgentConnected && workspaceAgent.LifecycleState == codersdk.WorkspaceAgentLifecycleOff:
|
||||
workspaceAgent.Health.Reason = "agent is not running"
|
||||
case workspaceAgent.Status == codersdk.WorkspaceAgentConnecting:
|
||||
// Note: the case above catches connecting+off as "not running".
|
||||
// This case handles connecting agents with a non-off lifecycle
|
||||
// (e.g. "created" or "starting"), where the agent binary has
|
||||
// not yet established a connection to coderd.
|
||||
workspaceAgent.Health.Reason = "agent has not yet connected"
|
||||
case workspaceAgent.Status == codersdk.WorkspaceAgentTimeout:
|
||||
workspaceAgent.Health.Reason = "agent is taking too long to connect"
|
||||
case workspaceAgent.Status == codersdk.WorkspaceAgentDisconnected:
|
||||
@@ -1234,6 +1240,8 @@ func buildAIBridgeThread(
|
||||
if rootIntc != nil {
|
||||
thread.Model = rootIntc.Model
|
||||
thread.Provider = rootIntc.Provider
|
||||
thread.CredentialKind = string(rootIntc.CredentialKind)
|
||||
thread.CredentialHint = rootIntc.CredentialHint
|
||||
// Get first user prompt from root interception.
|
||||
// A thread can only have one prompt, by definition, since we currently
|
||||
// only store the last prompt observed in an interception.
|
||||
@@ -1715,3 +1723,41 @@ func ChatDiffStatus(chatID uuid.UUID, status *database.ChatDiffStatus) codersdk.
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// UserSecret converts a database ListUserSecretsRow (metadata only,
|
||||
// no value) to an SDK UserSecret.
|
||||
func UserSecret(secret database.ListUserSecretsRow) codersdk.UserSecret {
|
||||
return codersdk.UserSecret{
|
||||
ID: secret.ID,
|
||||
Name: secret.Name,
|
||||
Description: secret.Description,
|
||||
EnvName: secret.EnvName,
|
||||
FilePath: secret.FilePath,
|
||||
CreatedAt: secret.CreatedAt,
|
||||
UpdatedAt: secret.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
// UserSecretFromFull converts a full database UserSecret row to an
|
||||
// SDK UserSecret, omitting the value and encryption key ID.
|
||||
func UserSecretFromFull(secret database.UserSecret) codersdk.UserSecret {
|
||||
return codersdk.UserSecret{
|
||||
ID: secret.ID,
|
||||
Name: secret.Name,
|
||||
Description: secret.Description,
|
||||
EnvName: secret.EnvName,
|
||||
FilePath: secret.FilePath,
|
||||
CreatedAt: secret.CreatedAt,
|
||||
UpdatedAt: secret.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
// UserSecrets converts a slice of database ListUserSecretsRow to
|
||||
// SDK UserSecret values.
|
||||
func UserSecrets(secrets []database.ListUserSecretsRow) []codersdk.UserSecret {
|
||||
result := make([]codersdk.UserSecret, 0, len(secrets))
|
||||
for _, s := range secrets {
|
||||
result = append(result, UserSecret(s))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -552,6 +552,10 @@ func TestChat_AllFieldsPopulated(t *testing.T) {
|
||||
RawMessage: json.RawMessage(`[{"type":"context-file","context_file_path":"/AGENTS.md"}]`),
|
||||
Valid: true,
|
||||
},
|
||||
DynamicTools: pqtype.NullRawMessage{
|
||||
RawMessage: json.RawMessage(`[{"name":"tool1","description":"test tool","inputSchema":{"type":"object"}}]`),
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
// Only ChatID is needed here. This test checks that
|
||||
// Chat.DiffStatus is non-nil, not that every DiffStatus
|
||||
|
||||
@@ -1708,6 +1708,17 @@ func (q *querier) CleanupDeletedMCPServerIDsFromChats(ctx context.Context) error
|
||||
return q.db.CleanupDeletedMCPServerIDsFromChats(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) ClearChatMessageProviderResponseIDsByChatID(ctx context.Context, chatID uuid.UUID) error {
|
||||
chat, err := q.db.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.ClearChatMessageProviderResponseIDsByChatID(ctx, chatID)
|
||||
}
|
||||
|
||||
func (q *querier) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) {
|
||||
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type)
|
||||
if err != nil {
|
||||
@@ -2031,6 +2042,20 @@ func (q *querier) DeleteOldAuditLogs(ctx context.Context, arg database.DeleteOld
|
||||
return q.db.DeleteOldAuditLogs(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteOldChatFiles(ctx context.Context, arg database.DeleteOldChatFilesParams) (int64, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceSystem); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return q.db.DeleteOldChatFiles(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteOldChats(ctx context.Context, arg database.DeleteOldChatsParams) (int64, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceSystem); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return q.db.DeleteOldChats(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteOldConnectionLogs(ctx context.Context, arg database.DeleteOldConnectionLogsParams) (int64, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceSystem); err != nil {
|
||||
return 0, err
|
||||
@@ -2155,10 +2180,10 @@ func (q *querier) DeleteUserChatProviderKey(ctx context.Context, arg database.De
|
||||
return q.db.DeleteUserChatProviderKey(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) error {
|
||||
func (q *querier) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) (int64, error) {
|
||||
obj := rbac.ResourceUserSecret.WithOwner(arg.UserID.String())
|
||||
if err := q.authorizeContext(ctx, policy.ActionDelete, obj); err != nil {
|
||||
return err
|
||||
return 0, err
|
||||
}
|
||||
return q.db.DeleteUserSecretByUserIDAndName(ctx, arg)
|
||||
}
|
||||
@@ -2399,6 +2424,10 @@ func (q *querier) GetActiveAISeatCount(ctx context.Context) (int64, error) {
|
||||
return q.db.GetActiveAISeatCount(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetActiveChatsByAgentID(ctx context.Context, agentID uuid.UUID) ([]database.Chat, error) {
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetActiveChatsByAgentID)(ctx, agentID)
|
||||
}
|
||||
|
||||
func (q *querier) GetActivePresetPrebuildSchedules(ctx context.Context) ([]database.TemplateVersionPresetPrebuildSchedule, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceTemplate.All()); err != nil {
|
||||
return nil, err
|
||||
@@ -2622,6 +2651,14 @@ func (q *querier) GetChatMessageByID(ctx context.Context, id int64) (database.Ch
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (q *querier) GetChatMessageSummariesPerChat(ctx context.Context, createdAfter time.Time) ([]database.GetChatMessageSummariesPerChatRow, error) {
|
||||
// Telemetry queries are called from system contexts only.
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetChatMessageSummariesPerChat(ctx, createdAfter)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatMessagesByChatID(ctx context.Context, arg database.GetChatMessagesByChatIDParams) ([]database.ChatMessage, error) {
|
||||
// Authorize read on the parent chat.
|
||||
_, err := q.GetChatByID(ctx, arg.ChatID)
|
||||
@@ -2670,6 +2707,14 @@ func (q *querier) GetChatModelConfigs(ctx context.Context) ([]database.ChatModel
|
||||
return q.db.GetChatModelConfigs(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatModelConfigsForTelemetry(ctx context.Context) ([]database.GetChatModelConfigsForTelemetryRow, error) {
|
||||
// Telemetry queries are called from system contexts only.
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetChatModelConfigsForTelemetry(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return database.ChatProvider{}, err
|
||||
@@ -2699,6 +2744,15 @@ func (q *querier) GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) (
|
||||
return q.db.GetChatQueuedMessages(ctx, chatID)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatRetentionDays(ctx context.Context) (int32, error) {
|
||||
// Chat retention is a deployment-wide config read by dbpurge.
|
||||
// Only requires a valid actor in context.
|
||||
if _, ok := ActorFromContext(ctx); !ok {
|
||||
return 0, ErrNoActor
|
||||
}
|
||||
return q.db.GetChatRetentionDays(ctx)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatSystemPrompt(ctx context.Context) (string, error) {
|
||||
// The system prompt is a deployment-wide setting read during chat
|
||||
// creation by every authenticated user, so no RBAC policy check
|
||||
@@ -2777,6 +2831,14 @@ func (q *querier) GetChatsByWorkspaceIDs(ctx context.Context, ids []uuid.UUID) (
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetChatsByWorkspaceIDs)(ctx, ids)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatsUpdatedAfter(ctx context.Context, updatedAfter time.Time) ([]database.GetChatsUpdatedAfterRow, error) {
|
||||
// Telemetry queries are called from system contexts only.
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetChatsUpdatedAfter(ctx, updatedAfter)
|
||||
}
|
||||
|
||||
func (q *querier) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
|
||||
// Just like with the audit logs query, shortcut if the user is an owner.
|
||||
err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceConnectionLog)
|
||||
@@ -3339,11 +3401,11 @@ func (q *querier) GetPRInsightsPerModel(ctx context.Context, arg database.GetPRI
|
||||
return q.db.GetPRInsightsPerModel(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetPRInsightsRecentPRs(ctx context.Context, arg database.GetPRInsightsRecentPRsParams) ([]database.GetPRInsightsRecentPRsRow, error) {
|
||||
func (q *querier) GetPRInsightsPullRequests(ctx context.Context, arg database.GetPRInsightsPullRequestsParams) ([]database.GetPRInsightsPullRequestsRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetPRInsightsRecentPRs(ctx, arg)
|
||||
return q.db.GetPRInsightsPullRequests(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetPRInsightsSummary(ctx context.Context, arg database.GetPRInsightsSummaryParams) (database.GetPRInsightsSummaryRow, error) {
|
||||
@@ -5681,6 +5743,17 @@ func (q *querier) SoftDeleteChatMessagesAfterID(ctx context.Context, arg databas
|
||||
return q.db.SoftDeleteChatMessagesAfterID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) SoftDeleteContextFileMessages(ctx context.Context, chatID uuid.UUID) error {
|
||||
chat, err := q.db.GetChatByID(ctx, chatID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.SoftDeleteContextFileMessages(ctx, chatID)
|
||||
}
|
||||
|
||||
func (q *querier) TryAcquireLock(ctx context.Context, id int64) (bool, error) {
|
||||
return q.db.TryAcquireLock(ctx, id)
|
||||
}
|
||||
@@ -6710,6 +6783,19 @@ func (q *querier) UpdateWorkspaceAgentConnectionByID(ctx context.Context, arg da
|
||||
return q.db.UpdateWorkspaceAgentConnectionByID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateWorkspaceAgentDirectoryByID(ctx context.Context, arg database.UpdateWorkspaceAgentDirectoryByIDParams) error {
|
||||
workspace, err := q.db.GetWorkspaceByAgentID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdateAgent, workspace); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return q.db.UpdateWorkspaceAgentDirectoryByID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateWorkspaceAgentDisplayAppsByID(ctx context.Context, arg database.UpdateWorkspaceAgentDisplayAppsByIDParams) error {
|
||||
workspace, err := q.db.GetWorkspaceByAgentID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
@@ -7031,6 +7117,13 @@ func (q *querier) UpsertChatIncludeDefaultSystemPrompt(ctx context.Context, incl
|
||||
return q.db.UpsertChatIncludeDefaultSystemPrompt(ctx, includeDefaultSystemPrompt)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertChatRetentionDays(ctx context.Context, retentionDays int32) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
return q.db.UpsertChatRetentionDays(ctx, retentionDays)
|
||||
}
|
||||
|
||||
func (q *querier) UpsertChatSystemPrompt(ctx context.Context, value string) error {
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
|
||||
return err
|
||||
|
||||
@@ -478,6 +478,24 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().GetChatsByWorkspaceIDs(gomock.Any(), arg).Return([]database.Chat{chatA, chatB}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chatA, policy.ActionRead, chatB, policy.ActionRead).Returns([]database.Chat{chatA, chatB})
|
||||
}))
|
||||
s.Run("GetActiveChatsByAgentID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
agentID := uuid.New()
|
||||
dbm.EXPECT().GetActiveChatsByAgentID(gomock.Any(), agentID).Return([]database.Chat{chat}, nil).AnyTimes()
|
||||
check.Args(agentID).Asserts(chat, policy.ActionRead).Returns([]database.Chat{chat})
|
||||
}))
|
||||
s.Run("SoftDeleteContextFileMessages", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().SoftDeleteContextFileMessages(gomock.Any(), chat.ID).Return(nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns()
|
||||
}))
|
||||
s.Run("ClearChatMessageProviderResponseIDsByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().ClearChatMessageProviderResponseIDsByChatID(gomock.Any(), chat.ID).Return(nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns()
|
||||
}))
|
||||
s.Run("GetChatCostPerChat", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetChatCostPerChatParams{
|
||||
OwnerID: uuid.New(),
|
||||
@@ -600,6 +618,22 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().GetChatFileMetadataByChatID(gomock.Any(), file.ID).Return(rows, nil).AnyTimes()
|
||||
check.Args(file.ID).Asserts(rbac.ResourceChat.WithOwner(file.OwnerID.String()).InOrg(file.OrganizationID).WithID(file.ID), policy.ActionRead).Returns(rows)
|
||||
}))
|
||||
s.Run("DeleteOldChatFiles", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().DeleteOldChatFiles(gomock.Any(), database.DeleteOldChatFilesParams{}).Return(int64(0), nil).AnyTimes()
|
||||
check.Args(database.DeleteOldChatFilesParams{}).Asserts(rbac.ResourceSystem, policy.ActionDelete)
|
||||
}))
|
||||
s.Run("DeleteOldChats", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().DeleteOldChats(gomock.Any(), database.DeleteOldChatsParams{}).Return(int64(0), nil).AnyTimes()
|
||||
check.Args(database.DeleteOldChatsParams{}).Asserts(rbac.ResourceSystem, policy.ActionDelete)
|
||||
}))
|
||||
s.Run("GetChatRetentionDays", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().GetChatRetentionDays(gomock.Any()).Return(int32(30), nil).AnyTimes()
|
||||
check.Args().Asserts()
|
||||
}))
|
||||
s.Run("UpsertChatRetentionDays", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().UpsertChatRetentionDays(gomock.Any(), int32(30)).Return(nil).AnyTimes()
|
||||
check.Args(int32(30)).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
|
||||
}))
|
||||
s.Run("GetChatMessageByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
msg := testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})
|
||||
@@ -2227,9 +2261,9 @@ func (s *MethodTestSuite) TestTemplate() {
|
||||
dbm.EXPECT().GetPRInsightsPerModel(gomock.Any(), arg).Return([]database.GetPRInsightsPerModelRow{}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetPRInsightsRecentPRs", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetPRInsightsRecentPRsParams{}
|
||||
dbm.EXPECT().GetPRInsightsRecentPRs(gomock.Any(), arg).Return([]database.GetPRInsightsRecentPRsRow{}, nil).AnyTimes()
|
||||
s.Run("GetPRInsightsPullRequests", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetPRInsightsPullRequestsParams{}
|
||||
dbm.EXPECT().GetPRInsightsPullRequests(gomock.Any(), arg).Return([]database.GetPRInsightsPullRequestsRow{}, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetTelemetryTaskEvents", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
@@ -2901,6 +2935,17 @@ func (s *MethodTestSuite) TestWorkspace() {
|
||||
dbm.EXPECT().UpdateWorkspaceAgentStartupByID(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
check.Args(arg).Asserts(w, policy.ActionUpdate).Returns()
|
||||
}))
|
||||
s.Run("UpdateWorkspaceAgentDirectoryByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
w := testutil.Fake(s.T(), faker, database.Workspace{})
|
||||
agt := testutil.Fake(s.T(), faker, database.WorkspaceAgent{})
|
||||
arg := database.UpdateWorkspaceAgentDirectoryByIDParams{
|
||||
ID: agt.ID,
|
||||
Directory: "/workspaces/project",
|
||||
}
|
||||
dbm.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agt.ID).Return(w, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateWorkspaceAgentDirectoryByID(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
check.Args(arg).Asserts(w, policy.ActionUpdateAgent).Returns()
|
||||
}))
|
||||
s.Run("UpdateWorkspaceAgentDisplayAppsByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
w := testutil.Fake(s.T(), faker, database.Workspace{})
|
||||
agt := testutil.Fake(s.T(), faker, database.WorkspaceAgent{})
|
||||
@@ -3996,6 +4041,20 @@ func (s *MethodTestSuite) TestSystemFunctions() {
|
||||
dbm.EXPECT().GetWorkspaceAgentsCreatedAfter(gomock.Any(), ts).Return([]database.WorkspaceAgent{}, nil).AnyTimes()
|
||||
check.Args(ts).Asserts(rbac.ResourceSystem, policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetChatsUpdatedAfter", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
ts := dbtime.Now()
|
||||
dbm.EXPECT().GetChatsUpdatedAfter(gomock.Any(), ts).Return([]database.GetChatsUpdatedAfterRow{}, nil).AnyTimes()
|
||||
check.Args(ts).Asserts(rbac.ResourceSystem, policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetChatMessageSummariesPerChat", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
ts := dbtime.Now()
|
||||
dbm.EXPECT().GetChatMessageSummariesPerChat(gomock.Any(), ts).Return([]database.GetChatMessageSummariesPerChatRow{}, nil).AnyTimes()
|
||||
check.Args(ts).Asserts(rbac.ResourceSystem, policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetChatModelConfigsForTelemetry", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
dbm.EXPECT().GetChatModelConfigsForTelemetry(gomock.Any()).Return([]database.GetChatModelConfigsForTelemetryRow{}, nil).AnyTimes()
|
||||
check.Args().Asserts(rbac.ResourceSystem, policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetWorkspaceAppsCreatedAfter", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
ts := dbtime.Now()
|
||||
dbm.EXPECT().GetWorkspaceAppsCreatedAfter(gomock.Any(), ts).Return([]database.WorkspaceApp{}, nil).AnyTimes()
|
||||
@@ -5383,10 +5442,10 @@ func (s *MethodTestSuite) TestUserSecrets() {
|
||||
s.Run("DeleteUserSecretByUserIDAndName", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
user := testutil.Fake(s.T(), faker, database.User{})
|
||||
arg := database.DeleteUserSecretByUserIDAndNameParams{UserID: user.ID, Name: "test"}
|
||||
dbm.EXPECT().DeleteUserSecretByUserIDAndName(gomock.Any(), arg).Return(nil).AnyTimes()
|
||||
dbm.EXPECT().DeleteUserSecretByUserIDAndName(gomock.Any(), arg).Return(int64(1), nil).AnyTimes()
|
||||
check.Args(arg).
|
||||
Asserts(rbac.ResourceUserSecret.WithOwner(user.ID.String()), policy.ActionDelete).
|
||||
Returns()
|
||||
Returns(int64(1))
|
||||
}))
|
||||
}
|
||||
|
||||
|
||||
@@ -1644,6 +1644,8 @@ func AIBridgeInterception(t testing.TB, db database.Store, seed database.InsertA
|
||||
ThreadParentInterceptionID: seed.ThreadParentInterceptionID,
|
||||
ThreadRootInterceptionID: seed.ThreadRootInterceptionID,
|
||||
ClientSessionID: seed.ClientSessionID,
|
||||
CredentialKind: takeFirst(seed.CredentialKind, database.CredentialKindCentralized),
|
||||
CredentialHint: takeFirst(seed.CredentialHint, ""),
|
||||
})
|
||||
if endedAt != nil {
|
||||
interception, err = db.UpdateAIBridgeInterceptionEnded(genCtx, database.UpdateAIBridgeInterceptionEndedParams{
|
||||
|
||||
@@ -280,6 +280,14 @@ func (m queryMetricsStore) CleanupDeletedMCPServerIDsFromChats(ctx context.Conte
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ClearChatMessageProviderResponseIDsByChatID(ctx context.Context, chatID uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.ClearChatMessageProviderResponseIDsByChatID(ctx, chatID)
|
||||
m.queryLatencies.WithLabelValues("ClearChatMessageProviderResponseIDsByChatID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ClearChatMessageProviderResponseIDsByChatID").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.CountAIBridgeInterceptions(ctx, arg)
|
||||
@@ -592,6 +600,22 @@ func (m queryMetricsStore) DeleteOldAuditLogs(ctx context.Context, arg database.
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteOldChatFiles(ctx context.Context, arg database.DeleteOldChatFilesParams) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.DeleteOldChatFiles(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("DeleteOldChatFiles").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteOldChatFiles").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteOldChats(ctx context.Context, arg database.DeleteOldChatsParams) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.DeleteOldChats(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("DeleteOldChats").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteOldChats").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteOldConnectionLogs(ctx context.Context, arg database.DeleteOldConnectionLogsParams) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.DeleteOldConnectionLogs(ctx, arg)
|
||||
@@ -712,12 +736,12 @@ func (m queryMetricsStore) DeleteUserChatProviderKey(ctx context.Context, arg da
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) error {
|
||||
func (m queryMetricsStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) (int64, error) {
|
||||
start := time.Now()
|
||||
r0 := m.s.DeleteUserSecretByUserIDAndName(ctx, arg)
|
||||
r0, r1 := m.s.DeleteUserSecretByUserIDAndName(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("DeleteUserSecretByUserIDAndName").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteUserSecretByUserIDAndName").Inc()
|
||||
return r0
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg database.DeleteWebpushSubscriptionByUserIDAndEndpointParams) error {
|
||||
@@ -952,6 +976,14 @@ func (m queryMetricsStore) GetActiveAISeatCount(ctx context.Context) (int64, err
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetActiveChatsByAgentID(ctx context.Context, agentID uuid.UUID) ([]database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetActiveChatsByAgentID(ctx, agentID)
|
||||
m.queryLatencies.WithLabelValues("GetActiveChatsByAgentID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetActiveChatsByAgentID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetActivePresetPrebuildSchedules(ctx context.Context) ([]database.TemplateVersionPresetPrebuildSchedule, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetActivePresetPrebuildSchedules(ctx)
|
||||
@@ -1160,6 +1192,14 @@ func (m queryMetricsStore) GetChatMessageByID(ctx context.Context, id int64) (da
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatMessageSummariesPerChat(ctx context.Context, createdAfter time.Time) ([]database.GetChatMessageSummariesPerChatRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatMessageSummariesPerChat(ctx, createdAfter)
|
||||
m.queryLatencies.WithLabelValues("GetChatMessageSummariesPerChat").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatMessageSummariesPerChat").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatMessagesByChatID(ctx context.Context, chatID database.GetChatMessagesByChatIDParams) ([]database.ChatMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatMessagesByChatID(ctx, chatID)
|
||||
@@ -1208,6 +1248,14 @@ func (m queryMetricsStore) GetChatModelConfigs(ctx context.Context) ([]database.
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatModelConfigsForTelemetry(ctx context.Context) ([]database.GetChatModelConfigsForTelemetryRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatModelConfigsForTelemetry(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetChatModelConfigsForTelemetry").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatModelConfigsForTelemetry").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatProviderByID(ctx, id)
|
||||
@@ -1240,6 +1288,14 @@ func (m queryMetricsStore) GetChatQueuedMessages(ctx context.Context, chatID uui
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatRetentionDays(ctx context.Context) (int32, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatRetentionDays(ctx)
|
||||
m.queryLatencies.WithLabelValues("GetChatRetentionDays").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatRetentionDays").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatSystemPrompt(ctx context.Context) (string, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatSystemPrompt(ctx)
|
||||
@@ -1312,6 +1368,14 @@ func (m queryMetricsStore) GetChatsByWorkspaceIDs(ctx context.Context, ids []uui
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatsUpdatedAfter(ctx context.Context, updatedAfter time.Time) ([]database.GetChatsUpdatedAfterRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatsUpdatedAfter(ctx, updatedAfter)
|
||||
m.queryLatencies.WithLabelValues("GetChatsUpdatedAfter").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatsUpdatedAfter").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetConnectionLogsOffset(ctx, arg)
|
||||
@@ -1928,11 +1992,11 @@ func (m queryMetricsStore) GetPRInsightsPerModel(ctx context.Context, arg databa
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetPRInsightsRecentPRs(ctx context.Context, arg database.GetPRInsightsRecentPRsParams) ([]database.GetPRInsightsRecentPRsRow, error) {
|
||||
func (m queryMetricsStore) GetPRInsightsPullRequests(ctx context.Context, arg database.GetPRInsightsPullRequestsParams) ([]database.GetPRInsightsPullRequestsRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetPRInsightsRecentPRs(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetPRInsightsRecentPRs").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetPRInsightsRecentPRs").Inc()
|
||||
r0, r1 := m.s.GetPRInsightsPullRequests(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetPRInsightsPullRequests").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetPRInsightsPullRequests").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
@@ -4056,6 +4120,14 @@ func (m queryMetricsStore) SoftDeleteChatMessagesAfterID(ctx context.Context, ar
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) SoftDeleteContextFileMessages(ctx context.Context, chatID uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.SoftDeleteContextFileMessages(ctx, chatID)
|
||||
m.queryLatencies.WithLabelValues("SoftDeleteContextFileMessages").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "SoftDeleteContextFileMessages").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.TryAcquireLock(ctx, pgTryAdvisoryXactLock)
|
||||
@@ -4768,6 +4840,14 @@ func (m queryMetricsStore) UpdateWorkspaceAgentConnectionByID(ctx context.Contex
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateWorkspaceAgentDirectoryByID(ctx context.Context, arg database.UpdateWorkspaceAgentDirectoryByIDParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpdateWorkspaceAgentDirectoryByID(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateWorkspaceAgentDirectoryByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateWorkspaceAgentDirectoryByID").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateWorkspaceAgentDisplayAppsByID(ctx context.Context, arg database.UpdateWorkspaceAgentDisplayAppsByIDParams) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpdateWorkspaceAgentDisplayAppsByID(ctx, arg)
|
||||
@@ -5000,6 +5080,14 @@ func (m queryMetricsStore) UpsertChatIncludeDefaultSystemPrompt(ctx context.Cont
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertChatRetentionDays(ctx context.Context, retentionDays int32) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpsertChatRetentionDays(ctx, retentionDays)
|
||||
m.queryLatencies.WithLabelValues("UpsertChatRetentionDays").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatRetentionDays").Inc()
|
||||
return r0
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpsertChatSystemPrompt(ctx context.Context, value string) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.UpsertChatSystemPrompt(ctx, value)
|
||||
|
||||
@@ -363,6 +363,20 @@ func (mr *MockStoreMockRecorder) CleanupDeletedMCPServerIDsFromChats(ctx any) *g
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupDeletedMCPServerIDsFromChats", reflect.TypeOf((*MockStore)(nil).CleanupDeletedMCPServerIDsFromChats), ctx)
|
||||
}
|
||||
|
||||
// ClearChatMessageProviderResponseIDsByChatID mocks base method.
|
||||
func (m *MockStore) ClearChatMessageProviderResponseIDsByChatID(ctx context.Context, chatID uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ClearChatMessageProviderResponseIDsByChatID", ctx, chatID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ClearChatMessageProviderResponseIDsByChatID indicates an expected call of ClearChatMessageProviderResponseIDsByChatID.
|
||||
func (mr *MockStoreMockRecorder) ClearChatMessageProviderResponseIDsByChatID(ctx, chatID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClearChatMessageProviderResponseIDsByChatID", reflect.TypeOf((*MockStore)(nil).ClearChatMessageProviderResponseIDsByChatID), ctx, chatID)
|
||||
}
|
||||
|
||||
// CountAIBridgeInterceptions mocks base method.
|
||||
func (m *MockStore) CountAIBridgeInterceptions(ctx context.Context, arg database.CountAIBridgeInterceptionsParams) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -984,6 +998,36 @@ func (mr *MockStoreMockRecorder) DeleteOldAuditLogs(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOldAuditLogs", reflect.TypeOf((*MockStore)(nil).DeleteOldAuditLogs), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteOldChatFiles mocks base method.
|
||||
func (m *MockStore) DeleteOldChatFiles(ctx context.Context, arg database.DeleteOldChatFilesParams) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteOldChatFiles", ctx, arg)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// DeleteOldChatFiles indicates an expected call of DeleteOldChatFiles.
|
||||
func (mr *MockStoreMockRecorder) DeleteOldChatFiles(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOldChatFiles", reflect.TypeOf((*MockStore)(nil).DeleteOldChatFiles), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteOldChats mocks base method.
|
||||
func (m *MockStore) DeleteOldChats(ctx context.Context, arg database.DeleteOldChatsParams) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteOldChats", ctx, arg)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// DeleteOldChats indicates an expected call of DeleteOldChats.
|
||||
func (mr *MockStoreMockRecorder) DeleteOldChats(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOldChats", reflect.TypeOf((*MockStore)(nil).DeleteOldChats), ctx, arg)
|
||||
}
|
||||
|
||||
// DeleteOldConnectionLogs mocks base method.
|
||||
func (m *MockStore) DeleteOldConnectionLogs(ctx context.Context, arg database.DeleteOldConnectionLogsParams) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1200,11 +1244,12 @@ func (mr *MockStoreMockRecorder) DeleteUserChatProviderKey(ctx, arg any) *gomock
|
||||
}
|
||||
|
||||
// DeleteUserSecretByUserIDAndName mocks base method.
|
||||
func (m *MockStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) error {
|
||||
func (m *MockStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteUserSecretByUserIDAndName", ctx, arg)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// DeleteUserSecretByUserIDAndName indicates an expected call of DeleteUserSecretByUserIDAndName.
|
||||
@@ -1637,6 +1682,21 @@ func (mr *MockStoreMockRecorder) GetActiveAISeatCount(ctx any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveAISeatCount", reflect.TypeOf((*MockStore)(nil).GetActiveAISeatCount), ctx)
|
||||
}
|
||||
|
||||
// GetActiveChatsByAgentID mocks base method.
|
||||
func (m *MockStore) GetActiveChatsByAgentID(ctx context.Context, agentID uuid.UUID) ([]database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetActiveChatsByAgentID", ctx, agentID)
|
||||
ret0, _ := ret[0].([]database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetActiveChatsByAgentID indicates an expected call of GetActiveChatsByAgentID.
|
||||
func (mr *MockStoreMockRecorder) GetActiveChatsByAgentID(ctx, agentID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveChatsByAgentID", reflect.TypeOf((*MockStore)(nil).GetActiveChatsByAgentID), ctx, agentID)
|
||||
}
|
||||
|
||||
// GetActivePresetPrebuildSchedules mocks base method.
|
||||
func (m *MockStore) GetActivePresetPrebuildSchedules(ctx context.Context) ([]database.TemplateVersionPresetPrebuildSchedule, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2132,6 +2192,21 @@ func (mr *MockStoreMockRecorder) GetChatMessageByID(ctx, id any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatMessageByID", reflect.TypeOf((*MockStore)(nil).GetChatMessageByID), ctx, id)
|
||||
}
|
||||
|
||||
// GetChatMessageSummariesPerChat mocks base method.
|
||||
func (m *MockStore) GetChatMessageSummariesPerChat(ctx context.Context, createdAfter time.Time) ([]database.GetChatMessageSummariesPerChatRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatMessageSummariesPerChat", ctx, createdAfter)
|
||||
ret0, _ := ret[0].([]database.GetChatMessageSummariesPerChatRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatMessageSummariesPerChat indicates an expected call of GetChatMessageSummariesPerChat.
|
||||
func (mr *MockStoreMockRecorder) GetChatMessageSummariesPerChat(ctx, createdAfter any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatMessageSummariesPerChat", reflect.TypeOf((*MockStore)(nil).GetChatMessageSummariesPerChat), ctx, createdAfter)
|
||||
}
|
||||
|
||||
// GetChatMessagesByChatID mocks base method.
|
||||
func (m *MockStore) GetChatMessagesByChatID(ctx context.Context, arg database.GetChatMessagesByChatIDParams) ([]database.ChatMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2222,6 +2297,21 @@ func (mr *MockStoreMockRecorder) GetChatModelConfigs(ctx any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatModelConfigs", reflect.TypeOf((*MockStore)(nil).GetChatModelConfigs), ctx)
|
||||
}
|
||||
|
||||
// GetChatModelConfigsForTelemetry mocks base method.
|
||||
func (m *MockStore) GetChatModelConfigsForTelemetry(ctx context.Context) ([]database.GetChatModelConfigsForTelemetryRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatModelConfigsForTelemetry", ctx)
|
||||
ret0, _ := ret[0].([]database.GetChatModelConfigsForTelemetryRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatModelConfigsForTelemetry indicates an expected call of GetChatModelConfigsForTelemetry.
|
||||
func (mr *MockStoreMockRecorder) GetChatModelConfigsForTelemetry(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatModelConfigsForTelemetry", reflect.TypeOf((*MockStore)(nil).GetChatModelConfigsForTelemetry), ctx)
|
||||
}
|
||||
|
||||
// GetChatProviderByID mocks base method.
|
||||
func (m *MockStore) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2282,6 +2372,21 @@ func (mr *MockStoreMockRecorder) GetChatQueuedMessages(ctx, chatID any) *gomock.
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatQueuedMessages", reflect.TypeOf((*MockStore)(nil).GetChatQueuedMessages), ctx, chatID)
|
||||
}
|
||||
|
||||
// GetChatRetentionDays mocks base method.
|
||||
func (m *MockStore) GetChatRetentionDays(ctx context.Context) (int32, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatRetentionDays", ctx)
|
||||
ret0, _ := ret[0].(int32)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatRetentionDays indicates an expected call of GetChatRetentionDays.
|
||||
func (mr *MockStoreMockRecorder) GetChatRetentionDays(ctx any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatRetentionDays", reflect.TypeOf((*MockStore)(nil).GetChatRetentionDays), ctx)
|
||||
}
|
||||
|
||||
// GetChatSystemPrompt mocks base method.
|
||||
func (m *MockStore) GetChatSystemPrompt(ctx context.Context) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2417,6 +2522,21 @@ func (mr *MockStoreMockRecorder) GetChatsByWorkspaceIDs(ctx, ids any) *gomock.Ca
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatsByWorkspaceIDs", reflect.TypeOf((*MockStore)(nil).GetChatsByWorkspaceIDs), ctx, ids)
|
||||
}
|
||||
|
||||
// GetChatsUpdatedAfter mocks base method.
|
||||
func (m *MockStore) GetChatsUpdatedAfter(ctx context.Context, updatedAfter time.Time) ([]database.GetChatsUpdatedAfterRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatsUpdatedAfter", ctx, updatedAfter)
|
||||
ret0, _ := ret[0].([]database.GetChatsUpdatedAfterRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatsUpdatedAfter indicates an expected call of GetChatsUpdatedAfter.
|
||||
func (mr *MockStoreMockRecorder) GetChatsUpdatedAfter(ctx, updatedAfter any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatsUpdatedAfter", reflect.TypeOf((*MockStore)(nil).GetChatsUpdatedAfter), ctx, updatedAfter)
|
||||
}
|
||||
|
||||
// GetConnectionLogsOffset mocks base method.
|
||||
func (m *MockStore) GetConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams) ([]database.GetConnectionLogsOffsetRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -3572,19 +3692,19 @@ func (mr *MockStoreMockRecorder) GetPRInsightsPerModel(ctx, arg any) *gomock.Cal
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPRInsightsPerModel", reflect.TypeOf((*MockStore)(nil).GetPRInsightsPerModel), ctx, arg)
|
||||
}
|
||||
|
||||
// GetPRInsightsRecentPRs mocks base method.
|
||||
func (m *MockStore) GetPRInsightsRecentPRs(ctx context.Context, arg database.GetPRInsightsRecentPRsParams) ([]database.GetPRInsightsRecentPRsRow, error) {
|
||||
// GetPRInsightsPullRequests mocks base method.
|
||||
func (m *MockStore) GetPRInsightsPullRequests(ctx context.Context, arg database.GetPRInsightsPullRequestsParams) ([]database.GetPRInsightsPullRequestsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetPRInsightsRecentPRs", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.GetPRInsightsRecentPRsRow)
|
||||
ret := m.ctrl.Call(m, "GetPRInsightsPullRequests", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.GetPRInsightsPullRequestsRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetPRInsightsRecentPRs indicates an expected call of GetPRInsightsRecentPRs.
|
||||
func (mr *MockStoreMockRecorder) GetPRInsightsRecentPRs(ctx, arg any) *gomock.Call {
|
||||
// GetPRInsightsPullRequests indicates an expected call of GetPRInsightsPullRequests.
|
||||
func (mr *MockStoreMockRecorder) GetPRInsightsPullRequests(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPRInsightsRecentPRs", reflect.TypeOf((*MockStore)(nil).GetPRInsightsRecentPRs), ctx, arg)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPRInsightsPullRequests", reflect.TypeOf((*MockStore)(nil).GetPRInsightsPullRequests), ctx, arg)
|
||||
}
|
||||
|
||||
// GetPRInsightsSummary mocks base method.
|
||||
@@ -7690,6 +7810,20 @@ func (mr *MockStoreMockRecorder) SoftDeleteChatMessagesAfterID(ctx, arg any) *go
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SoftDeleteChatMessagesAfterID", reflect.TypeOf((*MockStore)(nil).SoftDeleteChatMessagesAfterID), ctx, arg)
|
||||
}
|
||||
|
||||
// SoftDeleteContextFileMessages mocks base method.
|
||||
func (m *MockStore) SoftDeleteContextFileMessages(ctx context.Context, chatID uuid.UUID) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SoftDeleteContextFileMessages", ctx, chatID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SoftDeleteContextFileMessages indicates an expected call of SoftDeleteContextFileMessages.
|
||||
func (mr *MockStoreMockRecorder) SoftDeleteContextFileMessages(ctx, chatID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SoftDeleteContextFileMessages", reflect.TypeOf((*MockStore)(nil).SoftDeleteContextFileMessages), ctx, chatID)
|
||||
}
|
||||
|
||||
// TryAcquireLock mocks base method.
|
||||
func (m *MockStore) TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -8986,6 +9120,20 @@ func (mr *MockStoreMockRecorder) UpdateWorkspaceAgentConnectionByID(ctx, arg any
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateWorkspaceAgentConnectionByID", reflect.TypeOf((*MockStore)(nil).UpdateWorkspaceAgentConnectionByID), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateWorkspaceAgentDirectoryByID mocks base method.
|
||||
func (m *MockStore) UpdateWorkspaceAgentDirectoryByID(ctx context.Context, arg database.UpdateWorkspaceAgentDirectoryByIDParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateWorkspaceAgentDirectoryByID", ctx, arg)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpdateWorkspaceAgentDirectoryByID indicates an expected call of UpdateWorkspaceAgentDirectoryByID.
|
||||
func (mr *MockStoreMockRecorder) UpdateWorkspaceAgentDirectoryByID(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateWorkspaceAgentDirectoryByID", reflect.TypeOf((*MockStore)(nil).UpdateWorkspaceAgentDirectoryByID), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateWorkspaceAgentDisplayAppsByID mocks base method.
|
||||
func (m *MockStore) UpdateWorkspaceAgentDisplayAppsByID(ctx context.Context, arg database.UpdateWorkspaceAgentDisplayAppsByIDParams) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -9399,6 +9547,20 @@ func (mr *MockStoreMockRecorder) UpsertChatIncludeDefaultSystemPrompt(ctx, inclu
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatIncludeDefaultSystemPrompt", reflect.TypeOf((*MockStore)(nil).UpsertChatIncludeDefaultSystemPrompt), ctx, includeDefaultSystemPrompt)
|
||||
}
|
||||
|
||||
// UpsertChatRetentionDays mocks base method.
|
||||
func (m *MockStore) UpsertChatRetentionDays(ctx context.Context, retentionDays int32) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpsertChatRetentionDays", ctx, retentionDays)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpsertChatRetentionDays indicates an expected call of UpsertChatRetentionDays.
|
||||
func (mr *MockStoreMockRecorder) UpsertChatRetentionDays(ctx, retentionDays any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatRetentionDays", reflect.TypeOf((*MockStore)(nil).UpsertChatRetentionDays), ctx, retentionDays)
|
||||
}
|
||||
|
||||
// UpsertChatSystemPrompt mocks base method.
|
||||
func (m *MockStore) UpsertChatSystemPrompt(ctx context.Context, value string) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -34,6 +34,11 @@ const (
|
||||
// long enough to cover the maximum interval of a heartbeat event (currently
|
||||
// 1 hour) plus some buffer.
|
||||
maxTelemetryHeartbeatAge = 24 * time.Hour
|
||||
// Batch sizes for chat purging. Both use 1000, which is smaller
|
||||
// than audit/connection log batches (10000), because chat_files
|
||||
// rows contain bytea blob data that make large batches heavier.
|
||||
chatsBatchSize = 1000
|
||||
chatFilesBatchSize = 1000
|
||||
)
|
||||
|
||||
// New creates a new periodically purging database instance.
|
||||
@@ -109,6 +114,17 @@ func New(ctx context.Context, logger slog.Logger, db database.Store, vals *coder
|
||||
// purgeTick performs a single purge iteration. It returns an error if the
|
||||
// purge fails.
|
||||
func (i *instance) purgeTick(ctx context.Context, db database.Store, start time.Time) error {
|
||||
// Read chat retention config outside the transaction to
|
||||
// avoid poisoning the tx if the stored value is corrupt.
|
||||
// A SQL-level cast error (e.g. non-numeric text) puts PG
|
||||
// into error state, failing all subsequent queries in the
|
||||
// same transaction.
|
||||
chatRetentionDays, err := db.GetChatRetentionDays(ctx)
|
||||
if err != nil {
|
||||
i.logger.Warn(ctx, "failed to read chat retention config, skipping chat purge", slog.Error(err))
|
||||
chatRetentionDays = 0
|
||||
}
|
||||
|
||||
// Start a transaction to grab advisory lock, we don't want to run
|
||||
// multiple purges at the same time (multiple replicas).
|
||||
return db.InTx(func(tx database.Store) error {
|
||||
@@ -213,12 +229,43 @@ func (i *instance) purgeTick(ctx context.Context, db database.Store, start time.
|
||||
}
|
||||
}
|
||||
|
||||
// Chat retention is configured via site_configs. When
|
||||
// enabled, old archived chats are deleted first, then
|
||||
// orphaned chat files. Deleting a chat cascades to
|
||||
// chat_file_links (removing references) but not to
|
||||
// chat_files directly, so files from deleted chats
|
||||
// become orphaned and are caught by DeleteOldChatFiles
|
||||
// in the same tick.
|
||||
var purgedChats int64
|
||||
var purgedChatFiles int64
|
||||
if chatRetentionDays > 0 {
|
||||
chatRetention := time.Duration(chatRetentionDays) * 24 * time.Hour
|
||||
deleteChatsBefore := start.Add(-chatRetention)
|
||||
|
||||
purgedChats, err = tx.DeleteOldChats(ctx, database.DeleteOldChatsParams{
|
||||
BeforeTime: deleteChatsBefore,
|
||||
LimitCount: chatsBatchSize,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to delete old chats: %w", err)
|
||||
}
|
||||
|
||||
purgedChatFiles, err = tx.DeleteOldChatFiles(ctx, database.DeleteOldChatFilesParams{
|
||||
BeforeTime: deleteChatsBefore,
|
||||
LimitCount: chatFilesBatchSize,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to delete old chat files: %w", err)
|
||||
}
|
||||
}
|
||||
i.logger.Debug(ctx, "purged old database entries",
|
||||
slog.F("workspace_agent_logs", purgedWorkspaceAgentLogs),
|
||||
slog.F("expired_api_keys", expiredAPIKeys),
|
||||
slog.F("aibridge_records", purgedAIBridgeRecords),
|
||||
slog.F("connection_logs", purgedConnectionLogs),
|
||||
slog.F("audit_logs", purgedAuditLogs),
|
||||
slog.F("chats", purgedChats),
|
||||
slog.F("chat_files", purgedChatFiles),
|
||||
slog.F("duration", i.clk.Since(start)),
|
||||
)
|
||||
|
||||
@@ -232,6 +279,8 @@ func (i *instance) purgeTick(ctx context.Context, db database.Store, start time.
|
||||
i.recordsPurged.WithLabelValues("aibridge_records").Add(float64(purgedAIBridgeRecords))
|
||||
i.recordsPurged.WithLabelValues("connection_logs").Add(float64(purgedConnectionLogs))
|
||||
i.recordsPurged.WithLabelValues("audit_logs").Add(float64(purgedAuditLogs))
|
||||
i.recordsPurged.WithLabelValues("chats").Add(float64(purgedChats))
|
||||
i.recordsPurged.WithLabelValues("chat_files").Add(float64(purgedChatFiles))
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/lib/pq"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -53,6 +54,7 @@ func TestPurge(t *testing.T) {
|
||||
clk := quartz.NewMock(t)
|
||||
done := awaitDoTick(ctx, t, clk)
|
||||
mDB := dbmock.NewMockStore(gomock.NewController(t))
|
||||
mDB.EXPECT().GetChatRetentionDays(gomock.Any()).Return(int32(0), nil).AnyTimes()
|
||||
mDB.EXPECT().InTx(gomock.Any(), database.DefaultTXOptions().WithID("db_purge")).Return(nil).Times(2)
|
||||
purger := dbpurge.New(context.Background(), testutil.Logger(t), mDB, &codersdk.DeploymentValues{}, clk, prometheus.NewRegistry())
|
||||
<-done // wait for doTick() to run.
|
||||
@@ -125,6 +127,16 @@ func TestMetrics(t *testing.T) {
|
||||
"record_type": "audit_logs",
|
||||
})
|
||||
require.GreaterOrEqual(t, auditLogs, 0)
|
||||
|
||||
chats := promhelp.CounterValue(t, reg, "coderd_dbpurge_records_purged_total", prometheus.Labels{
|
||||
"record_type": "chats",
|
||||
})
|
||||
require.GreaterOrEqual(t, chats, 0)
|
||||
|
||||
chatFiles := promhelp.CounterValue(t, reg, "coderd_dbpurge_records_purged_total", prometheus.Labels{
|
||||
"record_type": "chat_files",
|
||||
})
|
||||
require.GreaterOrEqual(t, chatFiles, 0)
|
||||
})
|
||||
|
||||
t.Run("FailedIteration", func(t *testing.T) {
|
||||
@@ -138,6 +150,7 @@ func TestMetrics(t *testing.T) {
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mDB := dbmock.NewMockStore(ctrl)
|
||||
mDB.EXPECT().GetChatRetentionDays(gomock.Any()).Return(int32(0), nil).AnyTimes()
|
||||
mDB.EXPECT().InTx(gomock.Any(), database.DefaultTXOptions().WithID("db_purge")).
|
||||
Return(xerrors.New("simulated database error")).
|
||||
MinTimes(1)
|
||||
@@ -1634,3 +1647,488 @@ func TestDeleteExpiredAPIKeys(t *testing.T) {
|
||||
func ptr[T any](v T) *T {
|
||||
return &v
|
||||
}
|
||||
|
||||
//nolint:paralleltest // It uses LockIDDBPurge.
|
||||
func TestDeleteOldChatFiles(t *testing.T) {
|
||||
now := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
// createChatFile inserts a chat file and backdates created_at.
|
||||
createChatFile := func(ctx context.Context, t *testing.T, db database.Store, rawDB *sql.DB, ownerID, orgID uuid.UUID, createdAt time.Time) uuid.UUID {
|
||||
t.Helper()
|
||||
row, err := db.InsertChatFile(ctx, database.InsertChatFileParams{
|
||||
OwnerID: ownerID,
|
||||
OrganizationID: orgID,
|
||||
Name: "test.png",
|
||||
Mimetype: "image/png",
|
||||
Data: []byte("fake-image-data"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = rawDB.ExecContext(ctx, "UPDATE chat_files SET created_at = $1 WHERE id = $2", createdAt, row.ID)
|
||||
require.NoError(t, err)
|
||||
return row.ID
|
||||
}
|
||||
|
||||
// createChat inserts a chat and optionally archives it, then
|
||||
// backdates updated_at to control the "archived since" window.
|
||||
createChat := func(ctx context.Context, t *testing.T, db database.Store, rawDB *sql.DB, ownerID, modelConfigID uuid.UUID, archived bool, updatedAt time.Time) database.Chat {
|
||||
t.Helper()
|
||||
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
OwnerID: ownerID,
|
||||
LastModelConfigID: modelConfigID,
|
||||
Title: "test-chat",
|
||||
Status: database.ChatStatusWaiting,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
if archived {
|
||||
_, err = db.ArchiveChatByID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
_, err = rawDB.ExecContext(ctx, "UPDATE chats SET updated_at = $1 WHERE id = $2", updatedAt, chat.ID)
|
||||
require.NoError(t, err)
|
||||
return chat
|
||||
}
|
||||
// setupChatDeps creates the common dependencies needed for
|
||||
// chat-related tests: user, org, org member, provider, model config.
|
||||
type chatDeps struct {
|
||||
user database.User
|
||||
org database.Organization
|
||||
modelConfig database.ChatModelConfig
|
||||
}
|
||||
setupChatDeps := func(ctx context.Context, t *testing.T, db database.Store) chatDeps {
|
||||
t.Helper()
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: user.ID, OrganizationID: org.ID})
|
||||
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
mc, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
Provider: "openai",
|
||||
Model: "test-model",
|
||||
ContextLimit: 8192,
|
||||
Options: json.RawMessage("{}"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return chatDeps{user: user, org: org, modelConfig: mc}
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
run func(t *testing.T)
|
||||
}{
|
||||
{
|
||||
name: "ChatRetentionDisabled",
|
||||
run: func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
clk := quartz.NewMock(t)
|
||||
clk.Set(now).MustWait(ctx)
|
||||
|
||||
db, _, rawDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
deps := setupChatDeps(ctx, t, db)
|
||||
|
||||
// Disable retention.
|
||||
err := db.UpsertChatRetentionDays(ctx, int32(0))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create an old archived chat and an orphaned old file.
|
||||
oldChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, true, now.Add(-31*24*time.Hour))
|
||||
oldFileID := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-31*24*time.Hour))
|
||||
|
||||
done := awaitDoTick(ctx, t, clk)
|
||||
closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, clk, prometheus.NewRegistry())
|
||||
defer closer.Close()
|
||||
testutil.TryReceive(ctx, t, done)
|
||||
|
||||
// Both should still exist.
|
||||
_, err = db.GetChatByID(ctx, oldChat.ID)
|
||||
require.NoError(t, err, "chat should not be deleted when retention is disabled")
|
||||
_, err = db.GetChatFileByID(ctx, oldFileID)
|
||||
require.NoError(t, err, "chat file should not be deleted when retention is disabled")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "OldArchivedChatsDeleted",
|
||||
run: func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
clk := quartz.NewMock(t)
|
||||
clk.Set(now).MustWait(ctx)
|
||||
|
||||
db, _, rawDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
deps := setupChatDeps(ctx, t, db)
|
||||
|
||||
err := db.UpsertChatRetentionDays(ctx, int32(30))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Old archived chat (31 days) — should be deleted.
|
||||
oldChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, true, now.Add(-31*24*time.Hour))
|
||||
// Insert a message so we can verify CASCADE.
|
||||
_, err = db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
|
||||
ChatID: oldChat.ID,
|
||||
CreatedBy: []uuid.UUID{deps.user.ID},
|
||||
ModelConfigID: []uuid.UUID{deps.modelConfig.ID},
|
||||
Role: []database.ChatMessageRole{database.ChatMessageRoleUser},
|
||||
Content: []string{`[{"type":"text","text":"hello"}]`},
|
||||
ContentVersion: []int16{0},
|
||||
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
|
||||
InputTokens: []int64{0},
|
||||
OutputTokens: []int64{0},
|
||||
TotalTokens: []int64{0},
|
||||
ReasoningTokens: []int64{0},
|
||||
CacheCreationTokens: []int64{0},
|
||||
CacheReadTokens: []int64{0},
|
||||
ContextLimit: []int64{0},
|
||||
Compressed: []bool{false},
|
||||
TotalCostMicros: []int64{0},
|
||||
RuntimeMs: []int64{0},
|
||||
ProviderResponseID: []string{""},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Recently archived chat (10 days) — should be retained.
|
||||
recentChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, true, now.Add(-10*24*time.Hour))
|
||||
|
||||
// Active chat — should be retained.
|
||||
activeChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, false, now)
|
||||
|
||||
done := awaitDoTick(ctx, t, clk)
|
||||
closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, clk, prometheus.NewRegistry())
|
||||
defer closer.Close()
|
||||
testutil.TryReceive(ctx, t, done)
|
||||
|
||||
// Old archived chat should be gone.
|
||||
_, err = db.GetChatByID(ctx, oldChat.ID)
|
||||
require.Error(t, err, "old archived chat should be deleted")
|
||||
|
||||
// Its messages should be gone too (CASCADE).
|
||||
msgs, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: oldChat.ID,
|
||||
AfterID: 0,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, msgs, "messages should be cascade-deleted")
|
||||
|
||||
// Recent archived and active chats should remain.
|
||||
_, err = db.GetChatByID(ctx, recentChat.ID)
|
||||
require.NoError(t, err, "recently archived chat should be retained")
|
||||
_, err = db.GetChatByID(ctx, activeChat.ID)
|
||||
require.NoError(t, err, "active chat should be retained")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "OrphanedOldFilesDeleted",
|
||||
run: func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
clk := quartz.NewMock(t)
|
||||
clk.Set(now).MustWait(ctx)
|
||||
|
||||
db, _, rawDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
deps := setupChatDeps(ctx, t, db)
|
||||
|
||||
err := db.UpsertChatRetentionDays(ctx, int32(30))
|
||||
require.NoError(t, err)
|
||||
|
||||
// File A: 31 days old, NOT in any chat -> should be deleted.
|
||||
fileA := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-31*24*time.Hour))
|
||||
|
||||
// File B: 31 days old, in an active chat -> should be retained.
|
||||
fileB := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-31*24*time.Hour))
|
||||
activeChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, false, now)
|
||||
_, err = db.LinkChatFiles(ctx, database.LinkChatFilesParams{
|
||||
ChatID: activeChat.ID,
|
||||
MaxFileLinks: 100,
|
||||
FileIds: []uuid.UUID{fileB},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// File C: 10 days old, NOT in any chat -> should be retained (too young).
|
||||
fileC := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-10*24*time.Hour))
|
||||
|
||||
// File near boundary: 29d23h old — close to threshold.
|
||||
fileBoundary := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-30*24*time.Hour).Add(time.Hour))
|
||||
|
||||
done := awaitDoTick(ctx, t, clk)
|
||||
closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, clk, prometheus.NewRegistry())
|
||||
defer closer.Close()
|
||||
testutil.TryReceive(ctx, t, done)
|
||||
|
||||
_, err = db.GetChatFileByID(ctx, fileA)
|
||||
require.Error(t, err, "orphaned old file A should be deleted")
|
||||
|
||||
_, err = db.GetChatFileByID(ctx, fileB)
|
||||
require.NoError(t, err, "file B in active chat should be retained")
|
||||
|
||||
_, err = db.GetChatFileByID(ctx, fileC)
|
||||
require.NoError(t, err, "young file C should be retained")
|
||||
|
||||
_, err = db.GetChatFileByID(ctx, fileBoundary)
|
||||
require.NoError(t, err, "file near 30d boundary should be retained")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ArchivedChatFilesDeleted",
|
||||
run: func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
clk := quartz.NewMock(t)
|
||||
clk.Set(now).MustWait(ctx)
|
||||
|
||||
db, _, rawDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
deps := setupChatDeps(ctx, t, db)
|
||||
|
||||
err := db.UpsertChatRetentionDays(ctx, int32(30))
|
||||
require.NoError(t, err)
|
||||
|
||||
// File D: 31 days old, in a chat archived 31 days ago -> should be deleted.
|
||||
fileD := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-31*24*time.Hour))
|
||||
oldArchivedChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, true, now.Add(-31*24*time.Hour))
|
||||
_, err = db.LinkChatFiles(ctx, database.LinkChatFilesParams{
|
||||
ChatID: oldArchivedChat.ID,
|
||||
MaxFileLinks: 100,
|
||||
FileIds: []uuid.UUID{fileD},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
// LinkChatFiles does not update chats.updated_at, so backdate.
|
||||
_, err = rawDB.ExecContext(ctx, "UPDATE chats SET updated_at = $1 WHERE id = $2",
|
||||
now.Add(-31*24*time.Hour), oldArchivedChat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// File E: 31 days old, in a chat archived 10 days ago -> should be retained.
|
||||
fileE := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-31*24*time.Hour))
|
||||
recentArchivedChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, true, now.Add(-10*24*time.Hour))
|
||||
_, err = db.LinkChatFiles(ctx, database.LinkChatFilesParams{
|
||||
ChatID: recentArchivedChat.ID,
|
||||
MaxFileLinks: 100,
|
||||
FileIds: []uuid.UUID{fileE},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = rawDB.ExecContext(ctx, "UPDATE chats SET updated_at = $1 WHERE id = $2",
|
||||
now.Add(-10*24*time.Hour), recentArchivedChat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// File F: 31 days old, in BOTH an active chat AND an old archived chat -> should be retained.
|
||||
fileF := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-31*24*time.Hour))
|
||||
anotherOldArchivedChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, true, now.Add(-31*24*time.Hour))
|
||||
_, err = db.LinkChatFiles(ctx, database.LinkChatFilesParams{
|
||||
ChatID: anotherOldArchivedChat.ID,
|
||||
MaxFileLinks: 100,
|
||||
FileIds: []uuid.UUID{fileF},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = rawDB.ExecContext(ctx, "UPDATE chats SET updated_at = $1 WHERE id = $2",
|
||||
now.Add(-31*24*time.Hour), anotherOldArchivedChat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
activeChatForF := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, false, now)
|
||||
_, err = db.LinkChatFiles(ctx, database.LinkChatFilesParams{
|
||||
ChatID: activeChatForF.ID,
|
||||
MaxFileLinks: 100,
|
||||
FileIds: []uuid.UUID{fileF},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
done := awaitDoTick(ctx, t, clk)
|
||||
closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, clk, prometheus.NewRegistry())
|
||||
defer closer.Close()
|
||||
testutil.TryReceive(ctx, t, done)
|
||||
|
||||
_, err = db.GetChatFileByID(ctx, fileD)
|
||||
require.Error(t, err, "file D in old archived chat should be deleted")
|
||||
|
||||
_, err = db.GetChatFileByID(ctx, fileE)
|
||||
require.NoError(t, err, "file E in recently archived chat should be retained")
|
||||
|
||||
_, err = db.GetChatFileByID(ctx, fileF)
|
||||
require.NoError(t, err, "file F in active + old archived chat should be retained")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "UnarchiveAfterFilePurge",
|
||||
run: func(t *testing.T) {
|
||||
// Validates that when dbpurge deletes chat_files rows,
|
||||
// the FK cascade on chat_file_links automatically
|
||||
// removes the stale links. Unarchiving a chat after
|
||||
// file purge should show only surviving files.
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
db, _, rawDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
|
||||
deps := setupChatDeps(ctx, t, db)
|
||||
|
||||
// Create a chat with three attached files.
|
||||
fileA := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now)
|
||||
fileB := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now)
|
||||
fileC := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now)
|
||||
|
||||
chat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, false, now)
|
||||
_, err := db.LinkChatFiles(ctx, database.LinkChatFilesParams{
|
||||
ChatID: chat.ID,
|
||||
MaxFileLinks: 100,
|
||||
FileIds: []uuid.UUID{fileA, fileB, fileC},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Archive the chat.
|
||||
_, err = db.ArchiveChatByID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Simulate dbpurge deleting files A and B. The FK
|
||||
// cascade on chat_file_links_file_id_fkey should
|
||||
// automatically remove the corresponding link rows.
|
||||
_, err = rawDB.ExecContext(ctx, "DELETE FROM chat_files WHERE id = ANY($1)", pq.Array([]uuid.UUID{fileA, fileB}))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Unarchive the chat.
|
||||
_, err = db.UnarchiveChatByID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Only file C should remain linked (FK cascade
|
||||
// removed the links for deleted files A and B).
|
||||
files, err := db.GetChatFileMetadataByChatID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, files, 1, "only surviving file should be linked")
|
||||
require.Equal(t, fileC, files[0].ID)
|
||||
|
||||
// Edge case: delete the last file too. The chat
|
||||
// should have zero linked files, not an error.
|
||||
_, err = db.ArchiveChatByID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
_, err = rawDB.ExecContext(ctx, "DELETE FROM chat_files WHERE id = $1", fileC)
|
||||
require.NoError(t, err)
|
||||
_, err = db.UnarchiveChatByID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
files, err = db.GetChatFileMetadataByChatID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, files, "all-files-deleted should yield empty result")
|
||||
|
||||
// Test parent+child cascade: deleting files should
|
||||
// clean up links for both parent and child chats
|
||||
// independently via FK cascade.
|
||||
parentChat := createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, false, now)
|
||||
childChat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
OwnerID: deps.user.ID,
|
||||
LastModelConfigID: deps.modelConfig.ID,
|
||||
Title: "child-chat",
|
||||
Status: database.ChatStatusWaiting,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
// Set root_chat_id to link child to parent.
|
||||
_, err = rawDB.ExecContext(ctx, "UPDATE chats SET root_chat_id = $1 WHERE id = $2", parentChat.ID, childChat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Attach different files to parent and child.
|
||||
parentFileKeep := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now)
|
||||
parentFileStale := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now)
|
||||
childFileKeep := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now)
|
||||
childFileStale := createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now)
|
||||
|
||||
_, err = db.LinkChatFiles(ctx, database.LinkChatFilesParams{
|
||||
ChatID: parentChat.ID,
|
||||
MaxFileLinks: 100,
|
||||
FileIds: []uuid.UUID{parentFileKeep, parentFileStale},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = db.LinkChatFiles(ctx, database.LinkChatFilesParams{
|
||||
ChatID: childChat.ID,
|
||||
MaxFileLinks: 100,
|
||||
FileIds: []uuid.UUID{childFileKeep, childFileStale},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Archive via parent (cascades to child).
|
||||
_, err = db.ArchiveChatByID(ctx, parentChat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete one file from each chat.
|
||||
_, err = rawDB.ExecContext(ctx, "DELETE FROM chat_files WHERE id = ANY($1)",
|
||||
pq.Array([]uuid.UUID{parentFileStale, childFileStale}))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Unarchive via parent.
|
||||
_, err = db.UnarchiveChatByID(ctx, parentChat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
parentFiles, err := db.GetChatFileMetadataByChatID(ctx, parentChat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, parentFiles, 1)
|
||||
require.Equal(t, parentFileKeep, parentFiles[0].ID,
|
||||
"parent should retain only non-stale file")
|
||||
|
||||
childFiles, err := db.GetChatFileMetadataByChatID(ctx, childChat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, childFiles, 1)
|
||||
require.Equal(t, childFileKeep, childFiles[0].ID,
|
||||
"child should retain only non-stale file")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "BatchLimitFiles",
|
||||
run: func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
db, _, rawDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
|
||||
deps := setupChatDeps(ctx, t, db)
|
||||
|
||||
// Create 3 deletable orphaned files (all 31 days old).
|
||||
for range 3 {
|
||||
createChatFile(ctx, t, db, rawDB, deps.user.ID, deps.org.ID, now.Add(-31*24*time.Hour))
|
||||
}
|
||||
|
||||
// Delete with limit 2 — should delete 2, leave 1.
|
||||
deleted, err := db.DeleteOldChatFiles(ctx, database.DeleteOldChatFilesParams{
|
||||
BeforeTime: now.Add(-30 * 24 * time.Hour),
|
||||
LimitCount: 2,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(2), deleted, "should delete exactly 2 files")
|
||||
|
||||
// Delete again — should delete the remaining 1.
|
||||
deleted, err = db.DeleteOldChatFiles(ctx, database.DeleteOldChatFilesParams{
|
||||
BeforeTime: now.Add(-30 * 24 * time.Hour),
|
||||
LimitCount: 2,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), deleted, "should delete remaining 1 file")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "BatchLimitChats",
|
||||
run: func(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
db, _, rawDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure())
|
||||
deps := setupChatDeps(ctx, t, db)
|
||||
|
||||
// Create 3 deletable old archived chats.
|
||||
for range 3 {
|
||||
createChat(ctx, t, db, rawDB, deps.user.ID, deps.modelConfig.ID, true, now.Add(-31*24*time.Hour))
|
||||
}
|
||||
|
||||
// Delete with limit 2 — should delete 2, leave 1.
|
||||
deleted, err := db.DeleteOldChats(ctx, database.DeleteOldChatsParams{
|
||||
BeforeTime: now.Add(-30 * 24 * time.Hour),
|
||||
LimitCount: 2,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(2), deleted, "should delete exactly 2 chats")
|
||||
|
||||
// Delete again — should delete the remaining 1.
|
||||
deleted, err = db.DeleteOldChats(ctx, database.DeleteOldChatsParams{
|
||||
BeforeTime: now.Add(-30 * 24 * time.Hour),
|
||||
LimitCount: 2,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), deleted, "should delete remaining 1 chat")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tc.run(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Generated
+18
-5
@@ -293,7 +293,8 @@ CREATE TYPE chat_status AS ENUM (
|
||||
'running',
|
||||
'paused',
|
||||
'completed',
|
||||
'error'
|
||||
'error',
|
||||
'requires_action'
|
||||
);
|
||||
|
||||
CREATE TYPE connection_status AS ENUM (
|
||||
@@ -315,6 +316,11 @@ CREATE TYPE cors_behavior AS ENUM (
|
||||
'passthru'
|
||||
);
|
||||
|
||||
CREATE TYPE credential_kind AS ENUM (
|
||||
'centralized',
|
||||
'byok'
|
||||
);
|
||||
|
||||
CREATE TYPE crypto_key_feature AS ENUM (
|
||||
'workspace_apps_token',
|
||||
'workspace_apps_api_key',
|
||||
@@ -1101,7 +1107,9 @@ CREATE TABLE aibridge_interceptions (
|
||||
thread_root_id uuid,
|
||||
client_session_id character varying(256),
|
||||
session_id text GENERATED ALWAYS AS (COALESCE(client_session_id, ((thread_root_id)::text)::character varying, ((id)::text)::character varying)) STORED NOT NULL,
|
||||
provider_name text DEFAULT ''::text NOT NULL
|
||||
provider_name text DEFAULT ''::text NOT NULL,
|
||||
credential_kind credential_kind DEFAULT 'centralized'::credential_kind NOT NULL,
|
||||
credential_hint character varying(15) DEFAULT ''::character varying NOT NULL
|
||||
);
|
||||
|
||||
COMMENT ON TABLE aibridge_interceptions IS 'Audit log of requests intercepted by AI Bridge';
|
||||
@@ -1118,6 +1126,10 @@ COMMENT ON COLUMN aibridge_interceptions.session_id IS 'Groups related intercept
|
||||
|
||||
COMMENT ON COLUMN aibridge_interceptions.provider_name IS 'The provider instance name which may differ from provider when multiple instances of the same provider type exist.';
|
||||
|
||||
COMMENT ON COLUMN aibridge_interceptions.credential_kind IS 'How the request was authenticated: centralized or byok.';
|
||||
|
||||
COMMENT ON COLUMN aibridge_interceptions.credential_hint IS 'Masked credential identifier for audit (e.g. sk-a***efgh).';
|
||||
|
||||
CREATE TABLE aibridge_model_thoughts (
|
||||
interception_id uuid NOT NULL,
|
||||
content text NOT NULL,
|
||||
@@ -1418,7 +1430,8 @@ CREATE TABLE chats (
|
||||
agent_id uuid,
|
||||
pin_order integer DEFAULT 0 NOT NULL,
|
||||
last_read_message_id bigint,
|
||||
last_injected_context jsonb
|
||||
last_injected_context jsonb,
|
||||
dynamic_tools jsonb
|
||||
);
|
||||
|
||||
CREATE TABLE connection_logs (
|
||||
@@ -3770,14 +3783,14 @@ CREATE INDEX idx_chat_providers_enabled ON chat_providers USING btree (enabled);
|
||||
|
||||
CREATE INDEX idx_chat_queued_messages_chat_id ON chat_queued_messages USING btree (chat_id);
|
||||
|
||||
CREATE INDEX idx_chats_agent_id ON chats USING btree (agent_id) WHERE (agent_id IS NOT NULL);
|
||||
|
||||
CREATE INDEX idx_chats_labels ON chats USING gin (labels);
|
||||
|
||||
CREATE INDEX idx_chats_last_model_config_id ON chats USING btree (last_model_config_id);
|
||||
|
||||
CREATE INDEX idx_chats_owner ON chats USING btree (owner_id);
|
||||
|
||||
CREATE INDEX idx_chats_owner_updated_id ON chats USING btree (owner_id, updated_at DESC, id DESC);
|
||||
|
||||
CREATE INDEX idx_chats_parent_chat_id ON chats USING btree (parent_chat_id);
|
||||
|
||||
CREATE INDEX idx_chats_pending ON chats USING btree (status) WHERE (status = 'pending'::chat_status);
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
-- First update any rows using the value we're about to remove.
|
||||
-- The column type is still the original chat_status at this point.
|
||||
UPDATE chats SET status = 'error' WHERE status = 'requires_action';
|
||||
|
||||
-- Drop the column (this is independent of the enum).
|
||||
ALTER TABLE chats DROP COLUMN IF EXISTS dynamic_tools;
|
||||
|
||||
-- Drop the partial index that references the chat_status enum type.
|
||||
-- It must be removed before the rename-create-cast-drop cycle
|
||||
-- because the index's WHERE clause (status = 'pending'::chat_status)
|
||||
-- would otherwise cause a cross-type comparison failure.
|
||||
DROP INDEX IF EXISTS idx_chats_pending;
|
||||
|
||||
-- Now recreate the enum without requires_action.
|
||||
-- We must use the rename-create-cast-drop pattern.
|
||||
ALTER TYPE chat_status RENAME TO chat_status_old;
|
||||
CREATE TYPE chat_status AS ENUM (
|
||||
'waiting',
|
||||
'pending',
|
||||
'running',
|
||||
'paused',
|
||||
'completed',
|
||||
'error'
|
||||
);
|
||||
ALTER TABLE chats ALTER COLUMN status DROP DEFAULT;
|
||||
ALTER TABLE chats ALTER COLUMN status TYPE chat_status USING status::text::chat_status;
|
||||
ALTER TABLE chats ALTER COLUMN status SET DEFAULT 'waiting';
|
||||
DROP TYPE chat_status_old;
|
||||
|
||||
-- Recreate the partial index.
|
||||
CREATE INDEX idx_chats_pending ON chats USING btree (status) WHERE (status = 'pending'::chat_status);
|
||||
@@ -0,0 +1,3 @@
|
||||
ALTER TYPE chat_status ADD VALUE IF NOT EXISTS 'requires_action';
|
||||
|
||||
ALTER TABLE chats ADD COLUMN dynamic_tools JSONB DEFAULT NULL;
|
||||
@@ -0,0 +1,5 @@
|
||||
ALTER TABLE aibridge_interceptions
|
||||
DROP COLUMN IF EXISTS credential_kind,
|
||||
DROP COLUMN IF EXISTS credential_hint;
|
||||
|
||||
DROP TYPE IF EXISTS credential_kind;
|
||||
@@ -0,0 +1,12 @@
|
||||
CREATE TYPE credential_kind AS ENUM ('centralized', 'byok');
|
||||
|
||||
-- Records how each LLM request was authenticated and a masked credential
|
||||
-- identifier for audit purposes. Existing rows default to 'centralized'
|
||||
-- with an empty hint since we cannot retroactively determine their values.
|
||||
ALTER TABLE aibridge_interceptions
|
||||
ADD COLUMN credential_kind credential_kind NOT NULL DEFAULT 'centralized',
|
||||
-- Length capped as a safety measure to ensure only masked values are stored.
|
||||
ADD COLUMN credential_hint CHARACTER VARYING(15) NOT NULL DEFAULT '';
|
||||
|
||||
COMMENT ON COLUMN aibridge_interceptions.credential_kind IS 'How the request was authenticated: centralized or byok.';
|
||||
COMMENT ON COLUMN aibridge_interceptions.credential_hint IS 'Masked credential identifier for audit (e.g. sk-a***efgh).';
|
||||
@@ -0,0 +1 @@
|
||||
DROP INDEX IF EXISTS idx_chats_agent_id;
|
||||
@@ -0,0 +1 @@
|
||||
CREATE INDEX idx_chats_agent_id ON chats(agent_id) WHERE agent_id IS NOT NULL;
|
||||
@@ -0,0 +1 @@
|
||||
CREATE INDEX idx_chats_owner_updated_id ON chats (owner_id, updated_at DESC, id DESC);
|
||||
@@ -0,0 +1,5 @@
|
||||
-- The GetChats ORDER BY changed from (updated_at, id) DESC to a 4-column
|
||||
-- expression sort (pinned-first flag, negated pin_order, updated_at, id).
|
||||
-- This index was purpose-built for the old sort and no longer provides
|
||||
-- read benefit. The simpler idx_chats_owner covers the owner_id filter.
|
||||
DROP INDEX IF EXISTS idx_chats_owner_updated_id;
|
||||
@@ -798,6 +798,7 @@ func (q *sqlQuerier) GetAuthorizedChats(ctx context.Context, arg GetChatsParams,
|
||||
&i.Chat.PinOrder,
|
||||
&i.Chat.LastReadMessageID,
|
||||
&i.Chat.LastInjectedContext,
|
||||
&i.Chat.DynamicTools,
|
||||
&i.HasUnread); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -868,6 +869,8 @@ func (q *sqlQuerier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, ar
|
||||
&i.AIBridgeInterception.ClientSessionID,
|
||||
&i.AIBridgeInterception.SessionID,
|
||||
&i.AIBridgeInterception.ProviderName,
|
||||
&i.AIBridgeInterception.CredentialKind,
|
||||
&i.AIBridgeInterception.CredentialHint,
|
||||
&i.VisibleUser.ID,
|
||||
&i.VisibleUser.Username,
|
||||
&i.VisibleUser.Name,
|
||||
@@ -1131,6 +1134,8 @@ func (q *sqlQuerier) ListAuthorizedAIBridgeSessionThreads(ctx context.Context, a
|
||||
&i.AIBridgeInterception.ClientSessionID,
|
||||
&i.AIBridgeInterception.SessionID,
|
||||
&i.AIBridgeInterception.ProviderName,
|
||||
&i.AIBridgeInterception.CredentialKind,
|
||||
&i.AIBridgeInterception.CredentialHint,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1290,12 +1290,13 @@ func AllChatModeValues() []ChatMode {
|
||||
type ChatStatus string
|
||||
|
||||
const (
|
||||
ChatStatusWaiting ChatStatus = "waiting"
|
||||
ChatStatusPending ChatStatus = "pending"
|
||||
ChatStatusRunning ChatStatus = "running"
|
||||
ChatStatusPaused ChatStatus = "paused"
|
||||
ChatStatusCompleted ChatStatus = "completed"
|
||||
ChatStatusError ChatStatus = "error"
|
||||
ChatStatusWaiting ChatStatus = "waiting"
|
||||
ChatStatusPending ChatStatus = "pending"
|
||||
ChatStatusRunning ChatStatus = "running"
|
||||
ChatStatusPaused ChatStatus = "paused"
|
||||
ChatStatusCompleted ChatStatus = "completed"
|
||||
ChatStatusError ChatStatus = "error"
|
||||
ChatStatusRequiresAction ChatStatus = "requires_action"
|
||||
)
|
||||
|
||||
func (e *ChatStatus) Scan(src interface{}) error {
|
||||
@@ -1340,7 +1341,8 @@ func (e ChatStatus) Valid() bool {
|
||||
ChatStatusRunning,
|
||||
ChatStatusPaused,
|
||||
ChatStatusCompleted,
|
||||
ChatStatusError:
|
||||
ChatStatusError,
|
||||
ChatStatusRequiresAction:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
@@ -1354,6 +1356,7 @@ func AllChatStatusValues() []ChatStatus {
|
||||
ChatStatusPaused,
|
||||
ChatStatusCompleted,
|
||||
ChatStatusError,
|
||||
ChatStatusRequiresAction,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1543,6 +1546,64 @@ func AllCorsBehaviorValues() []CorsBehavior {
|
||||
}
|
||||
}
|
||||
|
||||
type CredentialKind string
|
||||
|
||||
const (
|
||||
CredentialKindCentralized CredentialKind = "centralized"
|
||||
CredentialKindByok CredentialKind = "byok"
|
||||
)
|
||||
|
||||
func (e *CredentialKind) Scan(src interface{}) error {
|
||||
switch s := src.(type) {
|
||||
case []byte:
|
||||
*e = CredentialKind(s)
|
||||
case string:
|
||||
*e = CredentialKind(s)
|
||||
default:
|
||||
return fmt.Errorf("unsupported scan type for CredentialKind: %T", src)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type NullCredentialKind struct {
|
||||
CredentialKind CredentialKind `json:"credential_kind"`
|
||||
Valid bool `json:"valid"` // Valid is true if CredentialKind is not NULL
|
||||
}
|
||||
|
||||
// Scan implements the Scanner interface.
|
||||
func (ns *NullCredentialKind) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
ns.CredentialKind, ns.Valid = "", false
|
||||
return nil
|
||||
}
|
||||
ns.Valid = true
|
||||
return ns.CredentialKind.Scan(value)
|
||||
}
|
||||
|
||||
// Value implements the driver Valuer interface.
|
||||
func (ns NullCredentialKind) Value() (driver.Value, error) {
|
||||
if !ns.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
return string(ns.CredentialKind), nil
|
||||
}
|
||||
|
||||
func (e CredentialKind) Valid() bool {
|
||||
switch e {
|
||||
case CredentialKindCentralized,
|
||||
CredentialKindByok:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func AllCredentialKindValues() []CredentialKind {
|
||||
return []CredentialKind{
|
||||
CredentialKindCentralized,
|
||||
CredentialKindByok,
|
||||
}
|
||||
}
|
||||
|
||||
type CryptoKeyFeature string
|
||||
|
||||
const (
|
||||
@@ -4040,6 +4101,10 @@ type AIBridgeInterception struct {
|
||||
SessionID string `db:"session_id" json:"session_id"`
|
||||
// The provider instance name which may differ from provider when multiple instances of the same provider type exist.
|
||||
ProviderName string `db:"provider_name" json:"provider_name"`
|
||||
// How the request was authenticated: centralized or byok.
|
||||
CredentialKind CredentialKind `db:"credential_kind" json:"credential_kind"`
|
||||
// Masked credential identifier for audit (e.g. sk-a***efgh).
|
||||
CredentialHint string `db:"credential_hint" json:"credential_hint"`
|
||||
}
|
||||
|
||||
// Audit log of model thinking in intercepted requests in AI Bridge
|
||||
@@ -4180,6 +4245,7 @@ type Chat struct {
|
||||
PinOrder int32 `db:"pin_order" json:"pin_order"`
|
||||
LastReadMessageID sql.NullInt64 `db:"last_read_message_id" json:"last_read_message_id"`
|
||||
LastInjectedContext pqtype.NullRawMessage `db:"last_injected_context" json:"last_injected_context"`
|
||||
DynamicTools pqtype.NullRawMessage `db:"dynamic_tools" json:"dynamic_tools"`
|
||||
}
|
||||
|
||||
type ChatDiffStatus struct {
|
||||
|
||||
@@ -81,8 +81,8 @@ func newMsgQueue(ctx context.Context, l Listener, le ListenerWithErr) *msgQueue
|
||||
}
|
||||
|
||||
func (q *msgQueue) run() {
|
||||
var batch [maxDrainBatch]msgOrErr
|
||||
for {
|
||||
// wait until there is something on the queue or we are closed
|
||||
q.cond.L.Lock()
|
||||
for q.size == 0 && !q.closed {
|
||||
q.cond.Wait()
|
||||
@@ -91,32 +91,28 @@ func (q *msgQueue) run() {
|
||||
q.cond.L.Unlock()
|
||||
return
|
||||
}
|
||||
// Drain up to maxDrainBatch items while holding the lock.
|
||||
n := min(q.size, maxDrainBatch)
|
||||
for i := range n {
|
||||
batch[i] = q.q[q.front]
|
||||
q.front = (q.front + 1) % BufferSize
|
||||
}
|
||||
q.size -= n
|
||||
item := q.q[q.front]
|
||||
q.front = (q.front + 1) % BufferSize
|
||||
q.size--
|
||||
q.cond.L.Unlock()
|
||||
|
||||
// Dispatch each message individually without holding the lock.
|
||||
for i := range n {
|
||||
item := batch[i]
|
||||
if item.err == nil {
|
||||
if q.l != nil {
|
||||
q.l(q.ctx, item.msg)
|
||||
continue
|
||||
}
|
||||
if q.le != nil {
|
||||
q.le(q.ctx, item.msg, nil)
|
||||
continue
|
||||
}
|
||||
// process item without holding lock
|
||||
if item.err == nil {
|
||||
// real message
|
||||
if q.l != nil {
|
||||
q.l(q.ctx, item.msg)
|
||||
continue
|
||||
}
|
||||
if q.le != nil {
|
||||
q.le(q.ctx, nil, item.err)
|
||||
q.le(q.ctx, item.msg, nil)
|
||||
continue
|
||||
}
|
||||
// unhittable
|
||||
continue
|
||||
}
|
||||
// if the listener wants errors, send it.
|
||||
if q.le != nil {
|
||||
q.le(q.ctx, nil, item.err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -237,12 +233,6 @@ type PGPubsub struct {
|
||||
// for a subscriber before dropping messages.
|
||||
const BufferSize = 2048
|
||||
|
||||
// maxDrainBatch is the maximum number of messages to drain from the ring
|
||||
// buffer per iteration. Batching amortizes the cost of mutex
|
||||
// acquire/release and cond.Wait across many messages, improving drain
|
||||
// throughput during bursts.
|
||||
const maxDrainBatch = 256
|
||||
|
||||
// Subscribe calls the listener when an event matching the name is received.
|
||||
func (p *PGPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) {
|
||||
return p.subscribeQueue(event, newMsgQueue(context.Background(), listener, nil))
|
||||
|
||||
@@ -76,6 +76,7 @@ type sqlcQuerier interface {
|
||||
CleanTailnetLostPeers(ctx context.Context) error
|
||||
CleanTailnetTunnels(ctx context.Context) error
|
||||
CleanupDeletedMCPServerIDsFromChats(ctx context.Context) error
|
||||
ClearChatMessageProviderResponseIDsByChatID(ctx context.Context, chatID uuid.UUID) error
|
||||
CountAIBridgeInterceptions(ctx context.Context, arg CountAIBridgeInterceptionsParams) (int64, error)
|
||||
CountAIBridgeSessions(ctx context.Context, arg CountAIBridgeSessionsParams) (int64, error)
|
||||
CountAuditLogs(ctx context.Context, arg CountAuditLogsParams) (int64, error)
|
||||
@@ -128,6 +129,22 @@ type sqlcQuerier interface {
|
||||
// connection events (connect, disconnect, open, close) which are handled
|
||||
// separately by DeleteOldAuditLogConnectionEvents.
|
||||
DeleteOldAuditLogs(ctx context.Context, arg DeleteOldAuditLogsParams) (int64, error)
|
||||
// TODO(cian): Add indexes on chats(archived, updated_at) and
|
||||
// chat_files(created_at) for purge query performance.
|
||||
// See: https://github.com/coder/internal/issues/1438
|
||||
// Deletes chat files that are older than the given threshold and are
|
||||
// not referenced by any chat that is still active or was archived
|
||||
// within the same threshold window. This covers two cases:
|
||||
// 1. Orphaned files not linked to any chat.
|
||||
// 2. Files whose every referencing chat has been archived for longer
|
||||
// than the retention period.
|
||||
DeleteOldChatFiles(ctx context.Context, arg DeleteOldChatFilesParams) (int64, error)
|
||||
// Deletes chats that have been archived for longer than the given
|
||||
// threshold. Active (non-archived) chats are never deleted.
|
||||
// Related chat_messages, chat_diff_statuses, and
|
||||
// chat_queued_messages are removed via ON DELETE CASCADE.
|
||||
// Parent/root references on child chats are SET NULL.
|
||||
DeleteOldChats(ctx context.Context, arg DeleteOldChatsParams) (int64, error)
|
||||
DeleteOldConnectionLogs(ctx context.Context, arg DeleteOldConnectionLogsParams) (int64, error)
|
||||
// Delete all notification messages which have not been updated for over a week.
|
||||
DeleteOldNotificationMessages(ctx context.Context) error
|
||||
@@ -152,7 +169,7 @@ type sqlcQuerier interface {
|
||||
DeleteTask(ctx context.Context, arg DeleteTaskParams) (uuid.UUID, error)
|
||||
DeleteUserChatCompactionThreshold(ctx context.Context, arg DeleteUserChatCompactionThresholdParams) error
|
||||
DeleteUserChatProviderKey(ctx context.Context, arg DeleteUserChatProviderKeyParams) error
|
||||
DeleteUserSecretByUserIDAndName(ctx context.Context, arg DeleteUserSecretByUserIDAndNameParams) error
|
||||
DeleteUserSecretByUserIDAndName(ctx context.Context, arg DeleteUserSecretByUserIDAndNameParams) (int64, error)
|
||||
DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg DeleteWebpushSubscriptionByUserIDAndEndpointParams) error
|
||||
DeleteWebpushSubscriptions(ctx context.Context, ids []uuid.UUID) error
|
||||
DeleteWorkspaceACLByID(ctx context.Context, id uuid.UUID) error
|
||||
@@ -199,6 +216,7 @@ type sqlcQuerier interface {
|
||||
GetAPIKeysByUserID(ctx context.Context, arg GetAPIKeysByUserIDParams) ([]APIKey, error)
|
||||
GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Time) ([]APIKey, error)
|
||||
GetActiveAISeatCount(ctx context.Context) (int64, error)
|
||||
GetActiveChatsByAgentID(ctx context.Context, agentID uuid.UUID) ([]Chat, error)
|
||||
GetActivePresetPrebuildSchedules(ctx context.Context) ([]TemplateVersionPresetPrebuildSchedule, error)
|
||||
GetActiveUserCount(ctx context.Context, includeSystem bool) (int64, error)
|
||||
GetActiveWorkspaceBuildsByTemplateID(ctx context.Context, templateID uuid.UUID) ([]WorkspaceBuild, error)
|
||||
@@ -255,16 +273,27 @@ type sqlcQuerier interface {
|
||||
// otherwise the setting defaults to true.
|
||||
GetChatIncludeDefaultSystemPrompt(ctx context.Context) (bool, error)
|
||||
GetChatMessageByID(ctx context.Context, id int64) (ChatMessage, error)
|
||||
// Aggregates message-level metrics per chat for messages created
|
||||
// after the given timestamp. Uses message created_at so that
|
||||
// ongoing activity in long-running chats is captured each window.
|
||||
GetChatMessageSummariesPerChat(ctx context.Context, createdAfter time.Time) ([]GetChatMessageSummariesPerChatRow, error)
|
||||
GetChatMessagesByChatID(ctx context.Context, arg GetChatMessagesByChatIDParams) ([]ChatMessage, error)
|
||||
GetChatMessagesByChatIDAscPaginated(ctx context.Context, arg GetChatMessagesByChatIDAscPaginatedParams) ([]ChatMessage, error)
|
||||
GetChatMessagesByChatIDDescPaginated(ctx context.Context, arg GetChatMessagesByChatIDDescPaginatedParams) ([]ChatMessage, error)
|
||||
GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]ChatMessage, error)
|
||||
GetChatModelConfigByID(ctx context.Context, id uuid.UUID) (ChatModelConfig, error)
|
||||
GetChatModelConfigs(ctx context.Context) ([]ChatModelConfig, error)
|
||||
// Returns all model configurations for telemetry snapshot collection.
|
||||
GetChatModelConfigsForTelemetry(ctx context.Context) ([]GetChatModelConfigsForTelemetryRow, error)
|
||||
GetChatProviderByID(ctx context.Context, id uuid.UUID) (ChatProvider, error)
|
||||
GetChatProviderByProvider(ctx context.Context, provider string) (ChatProvider, error)
|
||||
GetChatProviders(ctx context.Context) ([]ChatProvider, error)
|
||||
GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]ChatQueuedMessage, error)
|
||||
// Returns the chat retention period in days. Chats archived longer
|
||||
// than this and orphaned chat files older than this are purged by
|
||||
// dbpurge. Returns 30 (days) when no value has been configured.
|
||||
// A value of 0 disables chat purging entirely.
|
||||
GetChatRetentionDays(ctx context.Context) (int32, error)
|
||||
GetChatSystemPrompt(ctx context.Context) (string, error)
|
||||
// GetChatSystemPromptConfig returns both chat system prompt settings in a
|
||||
// single read to avoid torn reads between separate site-config lookups.
|
||||
@@ -283,6 +312,10 @@ type sqlcQuerier interface {
|
||||
GetChatWorkspaceTTL(ctx context.Context) (string, error)
|
||||
GetChats(ctx context.Context, arg GetChatsParams) ([]GetChatsRow, error)
|
||||
GetChatsByWorkspaceIDs(ctx context.Context, ids []uuid.UUID) ([]Chat, error)
|
||||
// Retrieves chats updated after the given timestamp for telemetry
|
||||
// snapshot collection. Uses updated_at so that long-running chats
|
||||
// still appear in each snapshot window while they are active.
|
||||
GetChatsUpdatedAfter(ctx context.Context, updatedAfter time.Time) ([]GetChatsUpdatedAfterRow, error)
|
||||
GetConnectionLogsOffset(ctx context.Context, arg GetConnectionLogsOffsetParams) ([]GetConnectionLogsOffsetRow, error)
|
||||
GetCryptoKeyByFeatureAndSequence(ctx context.Context, arg GetCryptoKeyByFeatureAndSequenceParams) (CryptoKey, error)
|
||||
GetCryptoKeys(ctx context.Context) ([]CryptoKey, error)
|
||||
@@ -385,11 +418,12 @@ type sqlcQuerier interface {
|
||||
// per PR for state/additions/deletions/model (model comes from the
|
||||
// most recent chat).
|
||||
GetPRInsightsPerModel(ctx context.Context, arg GetPRInsightsPerModelParams) ([]GetPRInsightsPerModelRow, error)
|
||||
// Returns individual PR rows with cost for the recent PRs table.
|
||||
// Returns all individual PR rows with cost for the selected time range.
|
||||
// Uses two CTEs: pr_costs sums cost for the PR-linked chat and its
|
||||
// direct children (that lack their own PR), and deduped picks one row
|
||||
// per PR for metadata.
|
||||
GetPRInsightsRecentPRs(ctx context.Context, arg GetPRInsightsRecentPRsParams) ([]GetPRInsightsRecentPRsRow, error)
|
||||
// per PR for metadata. A safety-cap LIMIT guards against unexpectedly
|
||||
// large result sets from direct API callers.
|
||||
GetPRInsightsPullRequests(ctx context.Context, arg GetPRInsightsPullRequestsParams) ([]GetPRInsightsPullRequestsRow, error)
|
||||
// PR Insights queries for the /agents analytics dashboard.
|
||||
// These aggregate data from chat_diff_statuses (PR metadata) joined
|
||||
// with chats and chat_messages (cost) to power the PR Insights view.
|
||||
@@ -478,8 +512,10 @@ type sqlcQuerier interface {
|
||||
GetReplicasUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]Replica, error)
|
||||
GetRunningPrebuiltWorkspaces(ctx context.Context) ([]GetRunningPrebuiltWorkspacesRow, error)
|
||||
GetRuntimeConfig(ctx context.Context, key string) (string, error)
|
||||
// Find chats that appear stuck (running but heartbeat has expired).
|
||||
// Used for recovery after coderd crashes or long hangs.
|
||||
// Find chats that appear stuck and need recovery. This covers:
|
||||
// 1. Running chats whose heartbeat has expired (worker crash).
|
||||
// 2. Chats awaiting client action (requires_action) past the
|
||||
// timeout threshold (client disappeared).
|
||||
GetStaleChats(ctx context.Context, staleThreshold time.Time) ([]Chat, error)
|
||||
GetTailnetPeers(ctx context.Context, id uuid.UUID) ([]TailnetPeer, error)
|
||||
GetTailnetTunnelPeerBindingsBatch(ctx context.Context, ids []uuid.UUID) ([]GetTailnetTunnelPeerBindingsBatchRow, error)
|
||||
@@ -860,11 +896,16 @@ type sqlcQuerier interface {
|
||||
SelectUsageEventsForPublishing(ctx context.Context, now time.Time) ([]UsageEvent, error)
|
||||
SoftDeleteChatMessageByID(ctx context.Context, id int64) error
|
||||
SoftDeleteChatMessagesAfterID(ctx context.Context, arg SoftDeleteChatMessagesAfterIDParams) error
|
||||
SoftDeleteContextFileMessages(ctx context.Context, chatID uuid.UUID) error
|
||||
// Non blocking lock. Returns true if the lock was acquired, false otherwise.
|
||||
//
|
||||
// This must be called from within a transaction. The lock will be automatically
|
||||
// released when the transaction ends.
|
||||
TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error)
|
||||
// Unarchives a chat (and its children). Stale file references are
|
||||
// handled automatically by FK cascades on chat_file_links: when
|
||||
// dbpurge deletes a chat_files row, the corresponding
|
||||
// chat_file_links rows are cascade-deleted by PostgreSQL.
|
||||
UnarchiveChatByID(ctx context.Context, id uuid.UUID) ([]Chat, error)
|
||||
// This will always work regardless of the current state of the template version.
|
||||
UnarchiveTemplateVersion(ctx context.Context, arg UnarchiveTemplateVersionParams) error
|
||||
@@ -971,6 +1012,7 @@ type sqlcQuerier interface {
|
||||
UpdateWorkspace(ctx context.Context, arg UpdateWorkspaceParams) (WorkspaceTable, error)
|
||||
UpdateWorkspaceACLByID(ctx context.Context, arg UpdateWorkspaceACLByIDParams) error
|
||||
UpdateWorkspaceAgentConnectionByID(ctx context.Context, arg UpdateWorkspaceAgentConnectionByIDParams) error
|
||||
UpdateWorkspaceAgentDirectoryByID(ctx context.Context, arg UpdateWorkspaceAgentDirectoryByIDParams) error
|
||||
UpdateWorkspaceAgentDisplayAppsByID(ctx context.Context, arg UpdateWorkspaceAgentDisplayAppsByIDParams) error
|
||||
UpdateWorkspaceAgentLifecycleStateByID(ctx context.Context, arg UpdateWorkspaceAgentLifecycleStateByIDParams) error
|
||||
UpdateWorkspaceAgentLogOverflowByID(ctx context.Context, arg UpdateWorkspaceAgentLogOverflowByIDParams) error
|
||||
@@ -1006,6 +1048,7 @@ type sqlcQuerier interface {
|
||||
UpsertChatDiffStatus(ctx context.Context, arg UpsertChatDiffStatusParams) (ChatDiffStatus, error)
|
||||
UpsertChatDiffStatusReference(ctx context.Context, arg UpsertChatDiffStatusReferenceParams) (ChatDiffStatus, error)
|
||||
UpsertChatIncludeDefaultSystemPrompt(ctx context.Context, includeDefaultSystemPrompt bool) error
|
||||
UpsertChatRetentionDays(ctx context.Context, retentionDays int32) error
|
||||
UpsertChatSystemPrompt(ctx context.Context, value string) error
|
||||
UpsertChatTemplateAllowlist(ctx context.Context, templateAllowlist string) error
|
||||
UpsertChatUsageLimitConfig(ctx context.Context, arg UpsertChatUsageLimitConfigParams) (ChatUsageLimitConfig, error)
|
||||
|
||||
@@ -7376,7 +7376,7 @@ func TestUserSecretsCRUDOperations(t *testing.T) {
|
||||
assert.Equal(t, "WORKFLOW_ENV", updatedSecret.EnvName) // EnvName unchanged
|
||||
|
||||
// 6. DELETE
|
||||
err = db.DeleteUserSecretByUserIDAndName(ctx, database.DeleteUserSecretByUserIDAndNameParams{
|
||||
_, err = db.DeleteUserSecretByUserIDAndName(ctx, database.DeleteUserSecretByUserIDAndNameParams{
|
||||
UserID: testUser.ID,
|
||||
Name: "workflow-secret",
|
||||
})
|
||||
@@ -9085,10 +9085,11 @@ func TestUpdateAIBridgeInterceptionEnded(t *testing.T) {
|
||||
|
||||
for _, uid := range []uuid.UUID{{1}, {2}, {3}} {
|
||||
insertParams := database.InsertAIBridgeInterceptionParams{
|
||||
ID: uid,
|
||||
InitiatorID: user.ID,
|
||||
Metadata: json.RawMessage("{}"),
|
||||
Client: sql.NullString{String: "client", Valid: true},
|
||||
ID: uid,
|
||||
InitiatorID: user.ID,
|
||||
Metadata: json.RawMessage("{}"),
|
||||
Client: sql.NullString{String: "client", Valid: true},
|
||||
CredentialKind: database.CredentialKindCentralized,
|
||||
}
|
||||
|
||||
intc, err := db.InsertAIBridgeInterception(ctx, insertParams)
|
||||
@@ -10407,11 +10408,10 @@ func TestGetPRInsights(t *testing.T) {
|
||||
assert.Equal(t, int64(1), summary.TotalPrsCreated)
|
||||
assert.Equal(t, int64(8_000_000), summary.TotalCostMicros)
|
||||
|
||||
recent, err := store.GetPRInsightsRecentPRs(context.Background(), database.GetPRInsightsRecentPRsParams{
|
||||
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
OwnerID: noOwner,
|
||||
LimitVal: 20,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, recent, 1)
|
||||
@@ -10441,11 +10441,10 @@ func TestGetPRInsights(t *testing.T) {
|
||||
assert.Equal(t, int64(1), summary.TotalPrsMerged)
|
||||
|
||||
// RecentPRs ordered by created_at DESC: chatB is newer.
|
||||
recent, err := store.GetPRInsightsRecentPRs(context.Background(), database.GetPRInsightsRecentPRsParams{
|
||||
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
OwnerID: noOwner,
|
||||
LimitVal: 20,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, recent, 2)
|
||||
@@ -10490,11 +10489,10 @@ func TestGetPRInsights(t *testing.T) {
|
||||
assert.Equal(t, int64(1), summary.TotalPrsCreated)
|
||||
assert.Equal(t, int64(1), summary.TotalPrsMerged)
|
||||
|
||||
recent, err := store.GetPRInsightsRecentPRs(context.Background(), database.GetPRInsightsRecentPRsParams{
|
||||
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
OwnerID: noOwner,
|
||||
LimitVal: 20,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, recent, 1)
|
||||
@@ -10532,11 +10530,10 @@ func TestGetPRInsights(t *testing.T) {
|
||||
assert.Equal(t, int64(9_000_000), summary.TotalCostMicros)
|
||||
|
||||
// RecentPRs should return 1 row with the full tree cost.
|
||||
recent, err := store.GetPRInsightsRecentPRs(context.Background(), database.GetPRInsightsRecentPRsParams{
|
||||
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
OwnerID: noOwner,
|
||||
LimitVal: 20,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, recent, 1)
|
||||
@@ -10574,11 +10571,10 @@ func TestGetPRInsights(t *testing.T) {
|
||||
assert.Equal(t, int64(2), summary.TotalPrsCreated)
|
||||
assert.Equal(t, int64(8_000_000), summary.TotalCostMicros)
|
||||
|
||||
recent, err := store.GetPRInsightsRecentPRs(context.Background(), database.GetPRInsightsRecentPRsParams{
|
||||
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
OwnerID: noOwner,
|
||||
LimitVal: 20,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, recent, 2)
|
||||
@@ -10620,11 +10616,10 @@ func TestGetPRInsights(t *testing.T) {
|
||||
assert.Equal(t, int64(2), summary.TotalPrsCreated)
|
||||
assert.Equal(t, int64(17_000_000), summary.TotalCostMicros)
|
||||
|
||||
recent, err := store.GetPRInsightsRecentPRs(context.Background(), database.GetPRInsightsRecentPRsParams{
|
||||
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
OwnerID: noOwner,
|
||||
LimitVal: 20,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, recent, 2)
|
||||
@@ -10657,11 +10652,10 @@ func TestGetPRInsights(t *testing.T) {
|
||||
assert.Equal(t, int64(2), summary.TotalPrsCreated)
|
||||
assert.Equal(t, int64(10_000_000), summary.TotalCostMicros)
|
||||
|
||||
recent, err := store.GetPRInsightsRecentPRs(context.Background(), database.GetPRInsightsRecentPRsParams{
|
||||
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
OwnerID: noOwner,
|
||||
LimitVal: 20,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, recent, 2)
|
||||
@@ -10694,11 +10688,10 @@ func TestGetPRInsights(t *testing.T) {
|
||||
assert.Equal(t, int64(1), summary.TotalPrsCreated)
|
||||
assert.Equal(t, int64(15_000_000), summary.TotalCostMicros)
|
||||
|
||||
recent, err := store.GetPRInsightsRecentPRs(context.Background(), database.GetPRInsightsRecentPRsParams{
|
||||
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
OwnerID: noOwner,
|
||||
LimitVal: 20,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, recent, 1)
|
||||
@@ -10723,11 +10716,10 @@ func TestGetPRInsights(t *testing.T) {
|
||||
assert.Equal(t, int64(1), summary.TotalPrsCreated)
|
||||
assert.Equal(t, int64(0), summary.TotalCostMicros)
|
||||
|
||||
recent, err := store.GetPRInsightsRecentPRs(context.Background(), database.GetPRInsightsRecentPRsParams{
|
||||
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
OwnerID: noOwner,
|
||||
LimitVal: 20,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, recent, 1)
|
||||
@@ -10766,11 +10758,10 @@ func TestGetPRInsights(t *testing.T) {
|
||||
require.Len(t, byModel, 1)
|
||||
assert.Equal(t, modelName, byModel[0].DisplayName)
|
||||
|
||||
recent, err := store.GetPRInsightsRecentPRs(context.Background(), database.GetPRInsightsRecentPRsParams{
|
||||
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
OwnerID: noOwner,
|
||||
LimitVal: 20,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, recent, 1)
|
||||
@@ -10802,6 +10793,30 @@ func TestGetPRInsights(t *testing.T) {
|
||||
assert.Equal(t, int64(8_000_000), summary.TotalCostMicros)
|
||||
assert.Equal(t, int64(5_000_000), summary.MergedCostMicros)
|
||||
})
|
||||
|
||||
t.Run("AllPRsReturnedWithSafetyCap", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
store, userID, mcID := setupChatInfra(t)
|
||||
|
||||
// Create 25 distinct PRs — more than the old LIMIT 20 — and
|
||||
// verify all are returned.
|
||||
const prCount = 25
|
||||
for i := range prCount {
|
||||
chat := createChat(t, store, userID, mcID, fmt.Sprintf("chat-%d", i))
|
||||
insertCostMessage(t, store, chat.ID, userID, mcID, 1_000_000)
|
||||
linkPR(t, store, chat.ID,
|
||||
fmt.Sprintf("https://github.com/org/repo/pull/%d", 100+i),
|
||||
"merged", fmt.Sprintf("fix: pr-%d", i), 10, 2, 1)
|
||||
}
|
||||
|
||||
recent, err := store.GetPRInsightsPullRequests(context.Background(), database.GetPRInsightsPullRequestsParams{
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
OwnerID: noOwner,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, recent, prCount, "all PRs within the date range should be returned")
|
||||
})
|
||||
}
|
||||
|
||||
func TestChatPinOrderQueries(t *testing.T) {
|
||||
|
||||
+550
-89
File diff suppressed because it is too large
Load Diff
@@ -1,8 +1,8 @@
|
||||
-- name: InsertAIBridgeInterception :one
|
||||
INSERT INTO aibridge_interceptions (
|
||||
id, api_key_id, initiator_id, provider, provider_name, model, metadata, started_at, client, client_session_id, thread_parent_id, thread_root_id
|
||||
id, api_key_id, initiator_id, provider, provider_name, model, metadata, started_at, client, client_session_id, thread_parent_id, thread_root_id, credential_kind, credential_hint
|
||||
) VALUES (
|
||||
@id, @api_key_id, @initiator_id, @provider, @provider_name, @model, COALESCE(@metadata::jsonb, '{}'::jsonb), @started_at, @client, sqlc.narg('client_session_id'), sqlc.narg('thread_parent_interception_id')::uuid, sqlc.narg('thread_root_interception_id')::uuid
|
||||
@id, @api_key_id, @initiator_id, @provider, @provider_name, @model, COALESCE(@metadata::jsonb, '{}'::jsonb), @started_at, @client, sqlc.narg('client_session_id'), sqlc.narg('thread_parent_interception_id')::uuid, sqlc.narg('thread_root_interception_id')::uuid, @credential_kind, @credential_hint
|
||||
)
|
||||
RETURNING *;
|
||||
|
||||
|
||||
@@ -18,3 +18,37 @@ FROM chat_files cf
|
||||
JOIN chat_file_links cfl ON cfl.file_id = cf.id
|
||||
WHERE cfl.chat_id = @chat_id::uuid
|
||||
ORDER BY cf.created_at ASC;
|
||||
|
||||
-- TODO(cian): Add indexes on chats(archived, updated_at) and
|
||||
-- chat_files(created_at) for purge query performance.
|
||||
-- See: https://github.com/coder/internal/issues/1438
|
||||
-- name: DeleteOldChatFiles :execrows
|
||||
-- Deletes chat files that are older than the given threshold and are
|
||||
-- not referenced by any chat that is still active or was archived
|
||||
-- within the same threshold window. This covers two cases:
|
||||
-- 1. Orphaned files not linked to any chat.
|
||||
-- 2. Files whose every referencing chat has been archived for longer
|
||||
-- than the retention period.
|
||||
WITH kept_file_ids AS (
|
||||
-- NOTE: This uses updated_at as a proxy for archive time
|
||||
-- because there is no archived_at column. Correctness
|
||||
-- requires that updated_at is never backdated on archived
|
||||
-- chats. See ArchiveChatByID.
|
||||
SELECT DISTINCT cfl.file_id
|
||||
FROM chat_file_links cfl
|
||||
JOIN chats c ON c.id = cfl.chat_id
|
||||
WHERE c.archived = false
|
||||
OR c.updated_at >= @before_time::timestamptz
|
||||
),
|
||||
deletable AS (
|
||||
SELECT cf.id
|
||||
FROM chat_files cf
|
||||
LEFT JOIN kept_file_ids k ON cf.id = k.file_id
|
||||
WHERE cf.created_at < @before_time::timestamptz
|
||||
AND k.file_id IS NULL
|
||||
ORDER BY cf.created_at ASC
|
||||
LIMIT @limit_count
|
||||
)
|
||||
DELETE FROM chat_files
|
||||
USING deletable
|
||||
WHERE chat_files.id = deletable.id;
|
||||
|
||||
@@ -173,11 +173,12 @@ JOIN pr_costs pc ON pc.pr_key = d.pr_key
|
||||
GROUP BY d.model_config_id, d.display_name, d.model, d.provider
|
||||
ORDER BY total_prs DESC;
|
||||
|
||||
-- name: GetPRInsightsRecentPRs :many
|
||||
-- Returns individual PR rows with cost for the recent PRs table.
|
||||
-- name: GetPRInsightsPullRequests :many
|
||||
-- Returns all individual PR rows with cost for the selected time range.
|
||||
-- Uses two CTEs: pr_costs sums cost for the PR-linked chat and its
|
||||
-- direct children (that lack their own PR), and deduped picks one row
|
||||
-- per PR for metadata.
|
||||
-- per PR for metadata. A safety-cap LIMIT guards against unexpectedly
|
||||
-- large result sets from direct API callers.
|
||||
WITH pr_costs AS (
|
||||
SELECT
|
||||
prc.pr_key,
|
||||
@@ -264,4 +265,4 @@ SELECT * FROM (
|
||||
JOIN pr_costs pc ON pc.pr_key = d.pr_key
|
||||
) sub
|
||||
ORDER BY sub.created_at DESC
|
||||
LIMIT @limit_val::int;
|
||||
LIMIT 500;
|
||||
|
||||
@@ -10,9 +10,14 @@ FROM chats
|
||||
ORDER BY (id = @id::uuid) DESC, created_at ASC, id ASC;
|
||||
|
||||
-- name: UnarchiveChatByID :many
|
||||
-- Unarchives a chat (and its children). Stale file references are
|
||||
-- handled automatically by FK cascades on chat_file_links: when
|
||||
-- dbpurge deletes a chat_files row, the corresponding
|
||||
-- chat_file_links rows are cascade-deleted by PostgreSQL.
|
||||
WITH chats AS (
|
||||
UPDATE chats
|
||||
SET archived = false, updated_at = NOW()
|
||||
UPDATE chats SET
|
||||
archived = false,
|
||||
updated_at = NOW()
|
||||
WHERE id = @id::uuid OR root_chat_id = @id::uuid
|
||||
RETURNING *
|
||||
)
|
||||
@@ -348,20 +353,18 @@ WHERE
|
||||
ELSE chats.archived = sqlc.narg('archived') :: boolean
|
||||
END
|
||||
AND CASE
|
||||
-- This allows using the last element on a page as effectively a cursor.
|
||||
-- This is an important option for scripts that need to paginate without
|
||||
-- duplicating or missing data.
|
||||
-- Cursor pagination: the last element on a page acts as the cursor.
|
||||
-- The 4-tuple matches the ORDER BY below. All columns sort DESC
|
||||
-- (pin_order is negated so lower values sort first in DESC order),
|
||||
-- which lets us use a single tuple < comparison.
|
||||
WHEN @after_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN (
|
||||
-- The pagination cursor is the last ID of the previous page.
|
||||
-- The query is ordered by the updated_at field, so select all
|
||||
-- rows before the cursor.
|
||||
(updated_at, id) < (
|
||||
(CASE WHEN pin_order > 0 THEN 1 ELSE 0 END, -pin_order, updated_at, id) < (
|
||||
SELECT
|
||||
updated_at, id
|
||||
CASE WHEN c2.pin_order > 0 THEN 1 ELSE 0 END, -c2.pin_order, c2.updated_at, c2.id
|
||||
FROM
|
||||
chats
|
||||
chats c2
|
||||
WHERE
|
||||
id = @after_id
|
||||
c2.id = @after_id
|
||||
)
|
||||
)
|
||||
ELSE true
|
||||
@@ -373,9 +376,15 @@ WHERE
|
||||
-- Authorize Filter clause will be injected below in GetAuthorizedChats
|
||||
-- @authorize_filter
|
||||
ORDER BY
|
||||
-- Deterministic and consistent ordering of all rows, even if they share
|
||||
-- a timestamp. This is to ensure consistent pagination.
|
||||
(updated_at, id) DESC OFFSET @offset_opt
|
||||
-- Pinned chats (pin_order > 0) sort before unpinned ones. Within
|
||||
-- pinned chats, lower pin_order values come first. The negation
|
||||
-- trick (-pin_order) keeps all sort columns DESC so the cursor
|
||||
-- tuple < comparison works with uniform direction.
|
||||
CASE WHEN pin_order > 0 THEN 1 ELSE 0 END DESC,
|
||||
-pin_order DESC,
|
||||
updated_at DESC,
|
||||
id DESC
|
||||
OFFSET @offset_opt
|
||||
LIMIT
|
||||
-- The chat list is unbounded and expected to grow large.
|
||||
-- Default to 50 to prevent accidental excessively large queries.
|
||||
@@ -394,7 +403,8 @@ INSERT INTO chats (
|
||||
mode,
|
||||
status,
|
||||
mcp_server_ids,
|
||||
labels
|
||||
labels,
|
||||
dynamic_tools
|
||||
) VALUES (
|
||||
@owner_id::uuid,
|
||||
sqlc.narg('workspace_id')::uuid,
|
||||
@@ -407,7 +417,8 @@ INSERT INTO chats (
|
||||
sqlc.narg('mode')::chat_mode,
|
||||
@status::chat_status,
|
||||
COALESCE(@mcp_server_ids::uuid[], '{}'::uuid[]),
|
||||
COALESCE(sqlc.narg('labels')::jsonb, '{}'::jsonb)
|
||||
COALESCE(sqlc.narg('labels')::jsonb, '{}'::jsonb),
|
||||
sqlc.narg('dynamic_tools')::jsonb
|
||||
)
|
||||
RETURNING
|
||||
*;
|
||||
@@ -664,15 +675,19 @@ RETURNING
|
||||
*;
|
||||
|
||||
-- name: GetStaleChats :many
|
||||
-- Find chats that appear stuck (running but heartbeat has expired).
|
||||
-- Used for recovery after coderd crashes or long hangs.
|
||||
-- Find chats that appear stuck and need recovery. This covers:
|
||||
-- 1. Running chats whose heartbeat has expired (worker crash).
|
||||
-- 2. Chats awaiting client action (requires_action) past the
|
||||
-- timeout threshold (client disappeared).
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chats
|
||||
WHERE
|
||||
status = 'running'::chat_status
|
||||
AND heartbeat_at < @stale_threshold::timestamptz;
|
||||
(status = 'running'::chat_status
|
||||
AND heartbeat_at < @stale_threshold::timestamptz)
|
||||
OR (status = 'requires_action'::chat_status
|
||||
AND updated_at < @stale_threshold::timestamptz);
|
||||
|
||||
-- name: UpdateChatHeartbeats :many
|
||||
-- Bumps the heartbeat timestamp for the given set of chat IDs,
|
||||
@@ -1220,3 +1235,88 @@ LIMIT 1;
|
||||
UPDATE chats
|
||||
SET last_read_message_id = @last_read_message_id::bigint
|
||||
WHERE id = @id::uuid;
|
||||
|
||||
-- name: DeleteOldChats :execrows
|
||||
-- Deletes chats that have been archived for longer than the given
|
||||
-- threshold. Active (non-archived) chats are never deleted.
|
||||
-- Related chat_messages, chat_diff_statuses, and
|
||||
-- chat_queued_messages are removed via ON DELETE CASCADE.
|
||||
-- Parent/root references on child chats are SET NULL.
|
||||
WITH deletable AS (
|
||||
SELECT id
|
||||
FROM chats
|
||||
WHERE archived = true
|
||||
AND updated_at < @before_time::timestamptz
|
||||
ORDER BY updated_at ASC
|
||||
LIMIT @limit_count
|
||||
)
|
||||
DELETE FROM chats
|
||||
USING deletable
|
||||
WHERE chats.id = deletable.id
|
||||
AND chats.archived = true;
|
||||
|
||||
-- name: GetChatsUpdatedAfter :many
|
||||
-- Retrieves chats updated after the given timestamp for telemetry
|
||||
-- snapshot collection. Uses updated_at so that long-running chats
|
||||
-- still appear in each snapshot window while they are active.
|
||||
SELECT
|
||||
id, owner_id, created_at, updated_at, status,
|
||||
(parent_chat_id IS NOT NULL)::bool AS has_parent,
|
||||
root_chat_id, workspace_id,
|
||||
mode, archived, last_model_config_id
|
||||
FROM chats
|
||||
WHERE updated_at > @updated_after;
|
||||
|
||||
-- name: GetChatMessageSummariesPerChat :many
|
||||
-- Aggregates message-level metrics per chat for messages created
|
||||
-- after the given timestamp. Uses message created_at so that
|
||||
-- ongoing activity in long-running chats is captured each window.
|
||||
SELECT
|
||||
cm.chat_id,
|
||||
COUNT(*)::bigint AS message_count,
|
||||
COUNT(*) FILTER (WHERE cm.role = 'user')::bigint AS user_message_count,
|
||||
COUNT(*) FILTER (WHERE cm.role = 'assistant')::bigint AS assistant_message_count,
|
||||
COUNT(*) FILTER (WHERE cm.role = 'tool')::bigint AS tool_message_count,
|
||||
COUNT(*) FILTER (WHERE cm.role = 'system')::bigint AS system_message_count,
|
||||
COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens,
|
||||
COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens,
|
||||
COALESCE(SUM(cm.reasoning_tokens), 0)::bigint AS total_reasoning_tokens,
|
||||
COALESCE(SUM(cm.cache_creation_tokens), 0)::bigint AS total_cache_creation_tokens,
|
||||
COALESCE(SUM(cm.cache_read_tokens), 0)::bigint AS total_cache_read_tokens,
|
||||
COALESCE(SUM(cm.total_cost_micros), 0)::bigint AS total_cost_micros,
|
||||
COALESCE(SUM(cm.runtime_ms), 0)::bigint AS total_runtime_ms,
|
||||
COUNT(DISTINCT cm.model_config_id)::bigint AS distinct_model_count,
|
||||
COUNT(*) FILTER (WHERE cm.compressed)::bigint AS compressed_message_count
|
||||
FROM chat_messages cm
|
||||
WHERE cm.created_at > @created_after
|
||||
AND cm.deleted = false
|
||||
GROUP BY cm.chat_id;
|
||||
|
||||
-- name: GetChatModelConfigsForTelemetry :many
|
||||
-- Returns all model configurations for telemetry snapshot collection.
|
||||
SELECT id, provider, model, context_limit, enabled, is_default
|
||||
FROM chat_model_configs
|
||||
WHERE deleted = false;
|
||||
-- name: GetActiveChatsByAgentID :many
|
||||
SELECT *
|
||||
FROM chats
|
||||
WHERE agent_id = @agent_id::uuid
|
||||
AND archived = false
|
||||
-- Active statuses only: waiting, pending, running, paused,
|
||||
-- requires_action.
|
||||
-- Excludes completed and error (terminal states).
|
||||
AND status IN ('waiting', 'running', 'paused', 'pending', 'requires_action')
|
||||
ORDER BY updated_at DESC;
|
||||
|
||||
-- name: ClearChatMessageProviderResponseIDsByChatID :exec
|
||||
UPDATE chat_messages
|
||||
SET provider_response_id = NULL
|
||||
WHERE chat_id = @chat_id::uuid
|
||||
AND deleted = false
|
||||
AND provider_response_id IS NOT NULL;
|
||||
|
||||
-- name: SoftDeleteContextFileMessages :exec
|
||||
UPDATE chat_messages SET deleted = true
|
||||
WHERE chat_id = @chat_id::uuid
|
||||
AND deleted = false
|
||||
AND content::jsonb @> '[{"type": "context-file"}]';
|
||||
|
||||
@@ -195,7 +195,8 @@ SELECT
|
||||
w.id AS workspace_id,
|
||||
COALESCE(w.name, '') AS workspace_name,
|
||||
-- Include the name of the provisioner_daemon associated to the job
|
||||
COALESCE(pd.name, '') AS worker_name
|
||||
COALESCE(pd.name, '') AS worker_name,
|
||||
wb.transition as workspace_build_transition
|
||||
FROM
|
||||
provisioner_jobs pj
|
||||
LEFT JOIN
|
||||
@@ -240,7 +241,8 @@ GROUP BY
|
||||
t.icon,
|
||||
w.id,
|
||||
w.name,
|
||||
pd.name
|
||||
pd.name,
|
||||
wb.transition
|
||||
ORDER BY
|
||||
pj.created_at DESC
|
||||
LIMIT
|
||||
|
||||
@@ -236,3 +236,20 @@ VALUES ('agents_workspace_ttl', @workspace_ttl::text)
|
||||
ON CONFLICT (key) DO UPDATE
|
||||
SET value = @workspace_ttl::text
|
||||
WHERE site_configs.key = 'agents_workspace_ttl';
|
||||
|
||||
-- name: GetChatRetentionDays :one
|
||||
-- Returns the chat retention period in days. Chats archived longer
|
||||
-- than this and orphaned chat files older than this are purged by
|
||||
-- dbpurge. Returns 30 (days) when no value has been configured.
|
||||
-- A value of 0 disables chat purging entirely.
|
||||
SELECT COALESCE(
|
||||
(SELECT value::integer FROM site_configs
|
||||
WHERE key = 'agents_chat_retention_days'),
|
||||
30
|
||||
) :: integer AS retention_days;
|
||||
|
||||
-- name: UpsertChatRetentionDays :exec
|
||||
INSERT INTO site_configs (key, value)
|
||||
VALUES ('agents_chat_retention_days', CAST(@retention_days AS integer)::text)
|
||||
ON CONFLICT (key) DO UPDATE SET value = CAST(@retention_days AS integer)::text
|
||||
WHERE site_configs.key = 'agents_chat_retention_days';
|
||||
|
||||
@@ -56,6 +56,6 @@ SET
|
||||
WHERE user_id = @user_id AND name = @name
|
||||
RETURNING *;
|
||||
|
||||
-- name: DeleteUserSecretByUserIDAndName :exec
|
||||
-- name: DeleteUserSecretByUserIDAndName :execrows
|
||||
DELETE FROM user_secrets
|
||||
WHERE user_id = @user_id AND name = @name;
|
||||
|
||||
@@ -190,6 +190,14 @@ SET
|
||||
WHERE
|
||||
id = $1;
|
||||
|
||||
-- name: UpdateWorkspaceAgentDirectoryByID :exec
|
||||
UPDATE
|
||||
workspace_agents
|
||||
SET
|
||||
directory = $2, updated_at = $3
|
||||
WHERE
|
||||
id = $1;
|
||||
|
||||
-- name: GetWorkspaceAgentLogsAfter :many
|
||||
SELECT
|
||||
*
|
||||
|
||||
+257
-77
@@ -137,8 +137,9 @@ func publishChatConfigEvent(logger slog.Logger, ps dbpubsub.Pubsub, kind pubsub.
|
||||
func (api *API) watchChats(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
apiKey := httpmw.APIKey(r)
|
||||
logger := api.Logger.Named("chat_watcher")
|
||||
|
||||
sendEvent, senderClosed, err := httpapi.OneWayWebSocketEventSender(api.Logger)(rw, r)
|
||||
conn, err := websocket.Accept(rw, r, nil)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to open chat watch stream.",
|
||||
@@ -146,54 +147,44 @@ func (api *API) watchChats(rw http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
<-senderClosed
|
||||
}()
|
||||
|
||||
cancelSubscribe, err := api.Pubsub.SubscribeWithErr(pubsub.ChatEventChannel(apiKey.UserID),
|
||||
pubsub.HandleChatEvent(
|
||||
func(ctx context.Context, payload pubsub.ChatEvent, err error) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
_ = conn.CloseRead(context.Background())
|
||||
|
||||
ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageText)
|
||||
defer wsNetConn.Close()
|
||||
|
||||
go httpapi.HeartbeatClose(ctx, logger, cancel, conn)
|
||||
|
||||
// The encoder is only written from the SubscribeWithErr callback,
|
||||
// which delivers serially per subscription. Do not add a second
|
||||
// write path without introducing synchronization.
|
||||
encoder := json.NewEncoder(wsNetConn)
|
||||
|
||||
cancelSubscribe, err := api.Pubsub.SubscribeWithErr(pubsub.ChatWatchEventChannel(apiKey.UserID),
|
||||
pubsub.HandleChatWatchEvent(
|
||||
func(ctx context.Context, payload codersdk.ChatWatchEvent, err error) {
|
||||
if err != nil {
|
||||
api.Logger.Error(ctx, "chat event subscription error", slog.Error(err))
|
||||
logger.Error(ctx, "chat watch event subscription error", slog.Error(err))
|
||||
return
|
||||
}
|
||||
if err := sendEvent(codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypeData,
|
||||
Data: payload,
|
||||
}); err != nil {
|
||||
api.Logger.Debug(ctx, "failed to send chat event", slog.Error(err))
|
||||
if err := encoder.Encode(payload); err != nil {
|
||||
logger.Debug(ctx, "failed to send chat watch event", slog.Error(err))
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
},
|
||||
))
|
||||
if err != nil {
|
||||
if err := sendEvent(codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypeError,
|
||||
Data: codersdk.Response{
|
||||
Message: "Internal error subscribing to chat events.",
|
||||
Detail: err.Error(),
|
||||
},
|
||||
}); err != nil {
|
||||
api.Logger.Debug(ctx, "failed to send chat subscribe error event", slog.Error(err))
|
||||
}
|
||||
logger.Error(ctx, "failed to subscribe to chat watch events", slog.Error(err))
|
||||
_ = conn.Close(websocket.StatusInternalError, "Failed to subscribe to chat events.")
|
||||
return
|
||||
}
|
||||
defer cancelSubscribe()
|
||||
|
||||
// Send initial ping to signal the connection is ready.
|
||||
if err := sendEvent(codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypePing,
|
||||
}); err != nil {
|
||||
api.Logger.Debug(ctx, "failed to send chat ping event", slog.Error(err))
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-senderClosed:
|
||||
return
|
||||
}
|
||||
}
|
||||
<-ctx.Done()
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: chatsByWorkspace returns a mapping of workspace ID to
|
||||
@@ -398,6 +389,10 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Cap the raw request body to prevent excessive memory use
|
||||
// from large dynamic tool schemas.
|
||||
r.Body = http.MaxBytesReader(rw, r.Body, int64(2*maxSystemPromptLenBytes))
|
||||
|
||||
var req codersdk.CreateChatRequest
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
@@ -488,6 +483,50 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.UnsafeDynamicTools) > 250 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Too many dynamic tools.",
|
||||
Detail: "Maximum 250 dynamic tools per chat.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate that dynamic tool names are non-empty and unique
|
||||
// within the list. Name collision with built-in tools is
|
||||
// checked at chatloop time when the full tool set is known.
|
||||
if len(req.UnsafeDynamicTools) > 0 {
|
||||
seenNames := make(map[string]struct{}, len(req.UnsafeDynamicTools))
|
||||
for _, dt := range req.UnsafeDynamicTools {
|
||||
if dt.Name == "" {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Dynamic tool name must not be empty.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if _, exists := seenNames[dt.Name]; exists {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Duplicate dynamic tool name.",
|
||||
Detail: fmt.Sprintf("Tool %q appears more than once.", dt.Name),
|
||||
})
|
||||
return
|
||||
}
|
||||
seenNames[dt.Name] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
var dynamicToolsJSON json.RawMessage
|
||||
if len(req.UnsafeDynamicTools) > 0 {
|
||||
var err error
|
||||
dynamicToolsJSON, err = json.Marshal(req.UnsafeDynamicTools)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to marshal dynamic tools.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
chat, err := api.chatDaemon.CreateChat(ctx, chatd.CreateOptions{
|
||||
OwnerID: apiKey.UserID,
|
||||
WorkspaceID: workspaceSelection.WorkspaceID,
|
||||
@@ -497,6 +536,7 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
|
||||
InitialUserContent: contentBlocks,
|
||||
MCPServerIDs: mcpServerIDs,
|
||||
Labels: labels,
|
||||
DynamicTools: dynamicToolsJSON,
|
||||
})
|
||||
if err != nil {
|
||||
if maybeWriteLimitErr(ctx, rw, err) {
|
||||
@@ -1770,9 +1810,9 @@ func (api *API) patchChat(rw http.ResponseWriter, r *http.Request) {
|
||||
// - pinOrder > 0 && already pinned: reorder (shift
|
||||
// neighbors, clamp to [1, count]).
|
||||
// - pinOrder > 0 && not pinned: append to end. The
|
||||
// requested value is intentionally ignored because
|
||||
// PinChatByID also bumps updated_at to keep the
|
||||
// chat visible in the paginated sidebar.
|
||||
// requested value is intentionally ignored; the
|
||||
// SQL ORDER BY sorts pinned chats first so they
|
||||
// appear on page 1 of the paginated sidebar.
|
||||
var err error
|
||||
errMsg := "Failed to pin chat."
|
||||
switch {
|
||||
@@ -2127,6 +2167,7 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
chat := httpmw.ChatParam(r)
|
||||
chatID := chat.ID
|
||||
logger := api.Logger.Named("chat_streamer").With(slog.F("chat_id", chatID))
|
||||
|
||||
if api.chatDaemon == nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
@@ -2149,7 +2190,22 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
sendEvent, senderClosed, err := httpapi.OneWayWebSocketEventSender(api.Logger)(rw, r)
|
||||
// Subscribe before accepting the WebSocket so that failures
|
||||
// can still be reported as normal HTTP errors.
|
||||
snapshot, events, cancelSub, ok := api.chatDaemon.Subscribe(ctx, chatID, r.Header, afterMessageID)
|
||||
// Subscribe only fails today when the receiver is nil, which
|
||||
// the chatDaemon == nil guard above already catches. This is
|
||||
// defensive against future Subscribe failure modes.
|
||||
if !ok {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Chat streaming is not available.",
|
||||
Detail: "Chat stream state is not configured.",
|
||||
})
|
||||
return
|
||||
}
|
||||
defer cancelSub()
|
||||
|
||||
conn, err := websocket.Accept(rw, r, nil)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to open chat stream.",
|
||||
@@ -2157,41 +2213,30 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
return
|
||||
}
|
||||
snapshot, events, cancel, ok := api.chatDaemon.Subscribe(ctx, chatID, r.Header, afterMessageID)
|
||||
if !ok {
|
||||
if err := sendEvent(codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypeError,
|
||||
Data: codersdk.Response{
|
||||
Message: "Chat streaming is not available.",
|
||||
Detail: "Chat stream state is not configured.",
|
||||
},
|
||||
}); err != nil {
|
||||
api.Logger.Debug(ctx, "failed to send chat stream unavailable event", slog.Error(err))
|
||||
}
|
||||
// Ensure the WebSocket is closed so senderClosed
|
||||
// completes and the handler can return.
|
||||
<-senderClosed
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
<-senderClosed
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
_ = conn.CloseRead(context.Background())
|
||||
|
||||
ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageText)
|
||||
defer wsNetConn.Close()
|
||||
|
||||
go httpapi.HeartbeatClose(ctx, logger, cancel, conn)
|
||||
|
||||
// Mark the chat as read when the stream connects and again
|
||||
// when it disconnects so we avoid per-message API calls while
|
||||
// messages are actively streaming.
|
||||
api.markChatAsRead(ctx, chatID)
|
||||
defer api.markChatAsRead(context.WithoutCancel(ctx), chatID)
|
||||
|
||||
encoder := json.NewEncoder(wsNetConn)
|
||||
|
||||
sendChatStreamBatch := func(batch []codersdk.ChatStreamEvent) error {
|
||||
if len(batch) == 0 {
|
||||
return nil
|
||||
}
|
||||
return sendEvent(codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypeData,
|
||||
Data: batch,
|
||||
})
|
||||
return encoder.Encode(batch)
|
||||
}
|
||||
|
||||
drainChatStreamBatch := func(
|
||||
@@ -2224,7 +2269,7 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
end = len(snapshot)
|
||||
}
|
||||
if err := sendChatStreamBatch(snapshot[start:end]); err != nil {
|
||||
api.Logger.Debug(ctx, "failed to send chat stream snapshot", slog.Error(err))
|
||||
logger.Debug(ctx, "failed to send chat stream snapshot", slog.Error(err))
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -2233,8 +2278,6 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-senderClosed:
|
||||
return
|
||||
case firstEvent, ok := <-events:
|
||||
if !ok {
|
||||
return
|
||||
@@ -2244,7 +2287,7 @@ func (api *API) streamChat(rw http.ResponseWriter, r *http.Request) {
|
||||
chatStreamBatchSize,
|
||||
)
|
||||
if err := sendChatStreamBatch(batch); err != nil {
|
||||
api.Logger.Debug(ctx, "failed to send chat stream event", slog.Error(err))
|
||||
logger.Debug(ctx, "failed to send chat stream event", slog.Error(err))
|
||||
return
|
||||
}
|
||||
if streamClosed {
|
||||
@@ -2259,6 +2302,7 @@ func (api *API) interruptChat(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
chat := httpmw.ChatParam(r)
|
||||
chatID := chat.ID
|
||||
logger := api.Logger.Named("chat_interrupt").With(slog.F("chat_id", chatID))
|
||||
|
||||
if api.chatDaemon != nil {
|
||||
chat = api.chatDaemon.InterruptChat(ctx, chat)
|
||||
@@ -2272,8 +2316,7 @@ func (api *API) interruptChat(rw http.ResponseWriter, r *http.Request) {
|
||||
LastError: sql.NullString{},
|
||||
})
|
||||
if updateErr != nil {
|
||||
api.Logger.Error(ctx, "failed to mark chat as waiting",
|
||||
slog.F("chat_id", chatID), slog.Error(updateErr))
|
||||
logger.Error(ctx, "failed to mark chat as waiting", slog.Error(updateErr))
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to interrupt chat.",
|
||||
Detail: updateErr.Error(),
|
||||
@@ -3183,6 +3226,70 @@ func (api *API) putChatWorkspaceTTL(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// @Summary Get chat retention days
|
||||
// @ID get-chat-retention-days
|
||||
// @Security CoderSessionToken
|
||||
// @Tags Chats
|
||||
// @Produce json
|
||||
// @Success 200 {object} codersdk.ChatRetentionDaysResponse
|
||||
// @Router /experimental/chats/config/retention-days [get]
|
||||
// @x-apidocgen {"skip": true}
|
||||
//
|
||||
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
|
||||
func (api *API) getChatRetentionDays(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
retentionDays, err := api.Database.GetChatRetentionDays(ctx)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to get chat retention days.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatRetentionDaysResponse{
|
||||
RetentionDays: retentionDays,
|
||||
})
|
||||
}
|
||||
|
||||
// Keep in sync with retentionDaysMaximum in
|
||||
// site/src/pages/AgentsPage/AgentSettingsBehaviorPageView.tsx.
|
||||
const retentionDaysMaximum = 3650 // ~10 years
|
||||
|
||||
// @Summary Update chat retention days
|
||||
// @ID update-chat-retention-days
|
||||
// @Security CoderSessionToken
|
||||
// @Tags Chats
|
||||
// @Accept json
|
||||
// @Param request body codersdk.UpdateChatRetentionDaysRequest true "Request body"
|
||||
// @Success 204
|
||||
// @Router /experimental/chats/config/retention-days [put]
|
||||
// @x-apidocgen {"skip": true}
|
||||
func (api *API) putChatRetentionDays(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) {
|
||||
httpapi.Forbidden(rw)
|
||||
return
|
||||
}
|
||||
var req codersdk.UpdateChatRetentionDaysRequest
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
}
|
||||
if req.RetentionDays < 0 || req.RetentionDays > retentionDaysMaximum {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: fmt.Sprintf("Retention days must be between 0 and %d.", retentionDaysMaximum),
|
||||
})
|
||||
return
|
||||
}
|
||||
if err := api.Database.UpsertChatRetentionDays(ctx, req.RetentionDays); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to update chat retention days.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
//
|
||||
//nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler.
|
||||
@@ -5519,7 +5626,7 @@ func (api *API) prInsights(rw http.ResponseWriter, r *http.Request) {
|
||||
previousSummary database.GetPRInsightsSummaryRow
|
||||
timeSeries []database.GetPRInsightsTimeSeriesRow
|
||||
byModel []database.GetPRInsightsPerModelRow
|
||||
recentPRs []database.GetPRInsightsRecentPRsRow
|
||||
recentPRs []database.GetPRInsightsPullRequestsRow
|
||||
)
|
||||
|
||||
eg, egCtx := errgroup.WithContext(ctx)
|
||||
@@ -5567,11 +5674,10 @@ func (api *API) prInsights(rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
eg.Go(func() error {
|
||||
var err error
|
||||
recentPRs, err = api.Database.GetPRInsightsRecentPRs(egCtx, database.GetPRInsightsRecentPRsParams{
|
||||
recentPRs, err = api.Database.GetPRInsightsPullRequests(egCtx, database.GetPRInsightsPullRequestsParams{
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
OwnerID: ownerID,
|
||||
LimitVal: 20,
|
||||
})
|
||||
return err
|
||||
})
|
||||
@@ -5681,9 +5787,83 @@ func (api *API) prInsights(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.PRInsightsResponse{
|
||||
Summary: summary,
|
||||
TimeSeries: tsEntries,
|
||||
ByModel: modelEntries,
|
||||
RecentPRs: prEntries,
|
||||
Summary: summary,
|
||||
TimeSeries: tsEntries,
|
||||
ByModel: modelEntries,
|
||||
PullRequests: prEntries,
|
||||
})
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
//
|
||||
//nolint:revive // HTTP handler writes to ResponseWriter.
|
||||
func (api *API) postChatToolResults(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
chat := httpmw.ChatParam(r)
|
||||
apiKey := httpmw.APIKey(r)
|
||||
|
||||
// Cap the raw request body to prevent excessive memory use.
|
||||
r.Body = http.MaxBytesReader(rw, r.Body, int64(2*maxSystemPromptLenBytes))
|
||||
var req codersdk.SubmitToolResultsRequest
|
||||
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Results) == 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "At least one tool result is required.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Fast-path check outside the transaction. The authoritative
|
||||
// check happens inside SubmitToolResults under a row lock.
|
||||
if chat.Status != database.ChatStatusRequiresAction {
|
||||
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
||||
Message: "Chat is not waiting for tool results.",
|
||||
Detail: fmt.Sprintf("Chat status is %q, expected %q.", chat.Status, database.ChatStatusRequiresAction),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var dynamicTools json.RawMessage
|
||||
if chat.DynamicTools.Valid {
|
||||
dynamicTools = chat.DynamicTools.RawMessage
|
||||
}
|
||||
|
||||
err := api.chatDaemon.SubmitToolResults(ctx, chatd.SubmitToolResultsOptions{
|
||||
ChatID: chat.ID,
|
||||
UserID: apiKey.UserID,
|
||||
ModelConfigID: chat.LastModelConfigID,
|
||||
Results: req.Results,
|
||||
DynamicTools: dynamicTools,
|
||||
})
|
||||
if err != nil {
|
||||
var validationErr *chatd.ToolResultValidationError
|
||||
var conflictErr *chatd.ToolResultStatusConflictError
|
||||
switch {
|
||||
case errors.As(err, &conflictErr):
|
||||
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
||||
Message: "Chat is not waiting for tool results.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
case errors.As(err, &validationErr):
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: validationErr.Message,
|
||||
Detail: validationErr.Detail,
|
||||
})
|
||||
default:
|
||||
api.Logger.Error(ctx, "tool results submission failed",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.Error(err),
|
||||
)
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error submitting tool results.",
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
+630
-116
@@ -16,7 +16,9 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/shopspring/decimal"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
@@ -268,17 +270,10 @@ func TestPostChats(t *testing.T) {
|
||||
_ = createChatModelConfig(t, client)
|
||||
|
||||
// Member without agents-access should be denied.
|
||||
memberClientRaw, member := coderdtest.CreateAnotherUser(t, client.Client, firstUser.OrganizationID)
|
||||
memberClientRaw, _ := coderdtest.CreateAnotherUser(t, client.Client, firstUser.OrganizationID)
|
||||
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
|
||||
|
||||
// Strip the auto-assigned agents-access role to test
|
||||
// the denied case.
|
||||
_, err := client.Client.UpdateUserRoles(ctx, member.Username, codersdk.UpdateRoles{
|
||||
Roles: []string{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = memberClient.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
_, err := memberClient.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
@@ -288,6 +283,7 @@ func TestPostChats(t *testing.T) {
|
||||
})
|
||||
requireSDKError(t, err, http.StatusForbidden)
|
||||
})
|
||||
|
||||
t.Run("HidesSystemPromptMessages", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -756,15 +752,7 @@ func TestListChats(t *testing.T) {
|
||||
// returning empty because no chats exist.
|
||||
memberClientRaw, member := coderdtest.CreateAnotherUser(t, client.Client, firstUser.OrganizationID)
|
||||
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
|
||||
|
||||
// Strip the auto-assigned agents-access role to test
|
||||
// the denied case.
|
||||
_, err := client.Client.UpdateUserRoles(ctx, member.Username, codersdk.UpdateRoles{
|
||||
Roles: []string{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
|
||||
_, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: member.ID,
|
||||
LastModelConfigID: modelConfig.ID,
|
||||
@@ -888,6 +876,186 @@ func TestListChats(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Len(t, allChats, totalChats)
|
||||
})
|
||||
|
||||
// Test that a pinned chat with an old updated_at appears on page 1.
|
||||
t.Run("PinnedOnFirstPage", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, _ := newChatClientWithDatabase(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
|
||||
// Create the chat that will later be pinned. It gets the
|
||||
// earliest updated_at because it is inserted first.
|
||||
pinnedChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "pinned-chat",
|
||||
}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Fill page 1 with newer chats so the pinned chat would
|
||||
// normally be pushed off the first page (default limit 50).
|
||||
const fillerCount = 51
|
||||
fillerChats := make([]codersdk.Chat, 0, fillerCount)
|
||||
for i := range fillerCount {
|
||||
c, createErr := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: fmt.Sprintf("filler-%d", i),
|
||||
}},
|
||||
})
|
||||
require.NoError(t, createErr)
|
||||
fillerChats = append(fillerChats, c)
|
||||
}
|
||||
|
||||
// Wait for all chats to reach a terminal status so
|
||||
// updated_at is stable before paginating. A single
|
||||
// polling loop checks every chat per tick to avoid
|
||||
// O(N) separate Eventually loops.
|
||||
allCreated := append([]codersdk.Chat{pinnedChat}, fillerChats...)
|
||||
pending := make(map[uuid.UUID]struct{}, len(allCreated))
|
||||
for _, c := range allCreated {
|
||||
pending[c.ID] = struct{}{}
|
||||
}
|
||||
testutil.Eventually(ctx, t, func(_ context.Context) bool {
|
||||
all, listErr := client.ListChats(ctx, &codersdk.ListChatsOptions{
|
||||
Pagination: codersdk.Pagination{Limit: fillerCount + 10},
|
||||
})
|
||||
if listErr != nil {
|
||||
return false
|
||||
}
|
||||
for _, ch := range all {
|
||||
if _, ok := pending[ch.ID]; ok && ch.Status != codersdk.ChatStatusPending && ch.Status != codersdk.ChatStatusRunning {
|
||||
delete(pending, ch.ID)
|
||||
}
|
||||
}
|
||||
return len(pending) == 0
|
||||
}, testutil.IntervalFast)
|
||||
|
||||
// Pin the earliest chat.
|
||||
err = client.UpdateChat(ctx, pinnedChat.ID, codersdk.UpdateChatRequest{
|
||||
PinOrder: ptr.Ref(int32(1)),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Fetch page 1 with default limit (50).
|
||||
page1, err := client.ListChats(ctx, &codersdk.ListChatsOptions{
|
||||
Pagination: codersdk.Pagination{Limit: 50},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// The pinned chat must appear on page 1.
|
||||
page1IDs := make(map[uuid.UUID]struct{}, len(page1))
|
||||
for _, c := range page1 {
|
||||
page1IDs[c.ID] = struct{}{}
|
||||
}
|
||||
_, found := page1IDs[pinnedChat.ID]
|
||||
require.True(t, found, "pinned chat should appear on page 1")
|
||||
|
||||
// The pinned chat should be the first item in the list.
|
||||
require.Equal(t, pinnedChat.ID, page1[0].ID, "pinned chat should be first")
|
||||
})
|
||||
|
||||
// Test cursor pagination with a mix of pinned and unpinned chats.
|
||||
t.Run("CursorWithPins", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, _ := newChatClientWithDatabase(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
|
||||
// Create 5 chats: 2 will be pinned, 3 unpinned.
|
||||
const totalChats = 5
|
||||
createdChats := make([]codersdk.Chat, 0, totalChats)
|
||||
for i := range totalChats {
|
||||
c, createErr := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: fmt.Sprintf("cursor-pin-chat-%d", i),
|
||||
}},
|
||||
})
|
||||
require.NoError(t, createErr)
|
||||
createdChats = append(createdChats, c)
|
||||
}
|
||||
|
||||
// Wait for all chats to reach terminal status.
|
||||
// Check each chat by ID rather than fetching the full list.
|
||||
testutil.Eventually(ctx, t, func(_ context.Context) bool {
|
||||
for _, c := range createdChats {
|
||||
ch, err := client.GetChat(ctx, c.ID)
|
||||
require.NoError(t, err, "GetChat should succeed for just-created chat %s", c.ID)
|
||||
if ch.Status == codersdk.ChatStatusPending || ch.Status == codersdk.ChatStatusRunning {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}, testutil.IntervalFast)
|
||||
|
||||
// Pin the first two chats (oldest updated_at).
|
||||
err := client.UpdateChat(ctx, createdChats[0].ID, codersdk.UpdateChatRequest{
|
||||
PinOrder: ptr.Ref(int32(1)),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
err = client.UpdateChat(ctx, createdChats[1].ID, codersdk.UpdateChatRequest{
|
||||
PinOrder: ptr.Ref(int32(1)),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Paginate with limit=2 using cursor (after_id).
|
||||
const pageSize = 2
|
||||
maxPages := totalChats/pageSize + 2
|
||||
var allPaginated []codersdk.Chat
|
||||
var afterID uuid.UUID
|
||||
for range maxPages {
|
||||
opts := &codersdk.ListChatsOptions{
|
||||
Pagination: codersdk.Pagination{Limit: pageSize},
|
||||
}
|
||||
if afterID != uuid.Nil {
|
||||
opts.Pagination.AfterID = afterID
|
||||
}
|
||||
page, listErr := client.ListChats(ctx, opts)
|
||||
require.NoError(t, listErr)
|
||||
if len(page) == 0 {
|
||||
break
|
||||
}
|
||||
allPaginated = append(allPaginated, page...)
|
||||
afterID = page[len(page)-1].ID
|
||||
}
|
||||
|
||||
// All chats should appear exactly once.
|
||||
seenIDs := make(map[uuid.UUID]struct{}, len(allPaginated))
|
||||
for _, c := range allPaginated {
|
||||
_, dup := seenIDs[c.ID]
|
||||
require.False(t, dup, "chat %s appeared more than once", c.ID)
|
||||
seenIDs[c.ID] = struct{}{}
|
||||
}
|
||||
require.Len(t, seenIDs, totalChats, "all chats should appear in paginated results")
|
||||
|
||||
// Pinned chats should come before unpinned ones, and
|
||||
// within the pinned group, lower pin_order sorts first.
|
||||
pinnedSeen := false
|
||||
unpinnedSeen := false
|
||||
for _, c := range allPaginated {
|
||||
if c.PinOrder > 0 {
|
||||
require.False(t, unpinnedSeen, "pinned chat %s appeared after unpinned chat", c.ID)
|
||||
pinnedSeen = true
|
||||
} else {
|
||||
unpinnedSeen = true
|
||||
}
|
||||
}
|
||||
require.True(t, pinnedSeen, "at least one pinned chat should exist")
|
||||
|
||||
// Verify within-pinned ordering: pin_order=1 before
|
||||
// pin_order=2 (the -pin_order DESC column).
|
||||
require.Equal(t, createdChats[0].ID, allPaginated[0].ID,
|
||||
"pin_order=1 chat should be first")
|
||||
require.Equal(t, createdChats[1].ID, allPaginated[1].ID,
|
||||
"pin_order=2 chat should be second")
|
||||
})
|
||||
}
|
||||
|
||||
func TestListChatModels(t *testing.T) {
|
||||
@@ -1126,17 +1294,6 @@ func TestWatchChats(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer conn.Close(websocket.StatusNormalClosure, "done")
|
||||
|
||||
type watchEvent struct {
|
||||
Type codersdk.ServerSentEventType `json:"type"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
var event watchEvent
|
||||
err = wsjson.Read(ctx, conn, &event)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.ServerSentEventTypePing, event.Type)
|
||||
require.True(t, len(event.Data) == 0 || string(event.Data) == "null")
|
||||
|
||||
createdChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
@@ -1148,25 +1305,16 @@ func TestWatchChats(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
for {
|
||||
var update watchEvent
|
||||
err = wsjson.Read(ctx, conn, &update)
|
||||
var payload codersdk.ChatWatchEvent
|
||||
err = wsjson.Read(ctx, conn, &payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
if update.Type == codersdk.ServerSentEventTypePing {
|
||||
continue
|
||||
}
|
||||
require.Equal(t, codersdk.ServerSentEventTypeData, update.Type)
|
||||
|
||||
var payload coderdpubsub.ChatEvent
|
||||
err = json.Unmarshal(update.Data, &payload)
|
||||
require.NoError(t, err)
|
||||
if payload.Kind == coderdpubsub.ChatEventKindCreated &&
|
||||
if payload.Kind == codersdk.ChatWatchEventKindCreated &&
|
||||
payload.Chat.ID == createdChat.ID {
|
||||
break
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CreatedEventIncludesAllChatFields", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -1186,18 +1334,6 @@ func TestWatchChats(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer conn.Close(websocket.StatusNormalClosure, "done")
|
||||
|
||||
type watchEvent struct {
|
||||
Type codersdk.ServerSentEventType `json:"type"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// Skip the initial ping.
|
||||
var event watchEvent
|
||||
err = wsjson.Read(ctx, conn, &event)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.ServerSentEventTypePing, event.Type)
|
||||
require.True(t, len(event.Data) == 0 || string(event.Data) == "null")
|
||||
|
||||
createdChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
@@ -1210,18 +1346,11 @@ func TestWatchChats(t *testing.T) {
|
||||
|
||||
var got codersdk.Chat
|
||||
testutil.Eventually(ctx, t, func(_ context.Context) bool {
|
||||
var update watchEvent
|
||||
if readErr := wsjson.Read(ctx, conn, &update); readErr != nil {
|
||||
var payload codersdk.ChatWatchEvent
|
||||
if readErr := wsjson.Read(ctx, conn, &payload); readErr != nil {
|
||||
return false
|
||||
}
|
||||
if update.Type != codersdk.ServerSentEventTypeData {
|
||||
return false
|
||||
}
|
||||
var payload coderdpubsub.ChatEvent
|
||||
if unmarshalErr := json.Unmarshal(update.Data, &payload); unmarshalErr != nil {
|
||||
return false
|
||||
}
|
||||
if payload.Kind == coderdpubsub.ChatEventKindCreated &&
|
||||
if payload.Kind == codersdk.ChatWatchEventKindCreated &&
|
||||
payload.Chat.ID == createdChat.ID {
|
||||
got = payload.Chat
|
||||
return true
|
||||
@@ -1294,25 +1423,14 @@ func TestWatchChats(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer conn.Close(websocket.StatusNormalClosure, "done")
|
||||
|
||||
type watchEvent struct {
|
||||
Type codersdk.ServerSentEventType `json:"type"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// Read the initial ping.
|
||||
var ping watchEvent
|
||||
err = wsjson.Read(ctx, conn, &ping)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.ServerSentEventTypePing, ping.Type)
|
||||
|
||||
// Publish a diff_status_change event via pubsub,
|
||||
// mimicking what PublishDiffStatusChange does after
|
||||
// it reads the diff status from the DB.
|
||||
dbStatus, err := db.GetChatDiffStatusByChatID(dbauthz.AsSystemRestricted(ctx), chat.ID)
|
||||
require.NoError(t, err)
|
||||
sdkDiffStatus := db2sdk.ChatDiffStatus(chat.ID, &dbStatus)
|
||||
event := coderdpubsub.ChatEvent{
|
||||
Kind: coderdpubsub.ChatEventKindDiffStatusChange,
|
||||
event := codersdk.ChatWatchEvent{
|
||||
Kind: codersdk.ChatWatchEventKindDiffStatusChange,
|
||||
Chat: codersdk.Chat{
|
||||
ID: chat.ID,
|
||||
OwnerID: chat.OwnerID,
|
||||
@@ -1325,25 +1443,15 @@ func TestWatchChats(t *testing.T) {
|
||||
}
|
||||
payload, err := json.Marshal(event)
|
||||
require.NoError(t, err)
|
||||
err = api.Pubsub.Publish(coderdpubsub.ChatEventChannel(user.UserID), payload)
|
||||
err = api.Pubsub.Publish(coderdpubsub.ChatWatchEventChannel(user.UserID), payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Read events until we find the diff_status_change.
|
||||
for {
|
||||
var update watchEvent
|
||||
err = wsjson.Read(ctx, conn, &update)
|
||||
var received codersdk.ChatWatchEvent
|
||||
err = wsjson.Read(ctx, conn, &received)
|
||||
require.NoError(t, err)
|
||||
|
||||
if update.Type == codersdk.ServerSentEventTypePing {
|
||||
continue
|
||||
}
|
||||
require.Equal(t, codersdk.ServerSentEventTypeData, update.Type)
|
||||
|
||||
var received coderdpubsub.ChatEvent
|
||||
err = json.Unmarshal(update.Data, &received)
|
||||
require.NoError(t, err)
|
||||
|
||||
if received.Kind != coderdpubsub.ChatEventKindDiffStatusChange ||
|
||||
if received.Kind != codersdk.ChatWatchEventKindDiffStatusChange ||
|
||||
received.Chat.ID != chat.ID {
|
||||
continue
|
||||
}
|
||||
@@ -1362,7 +1470,6 @@ func TestWatchChats(t *testing.T) {
|
||||
break
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ArchiveAndUnarchiveEmitEventsForDescendants", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -1405,31 +1512,13 @@ func TestWatchChats(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer conn.Close(websocket.StatusNormalClosure, "done")
|
||||
|
||||
type watchEvent struct {
|
||||
Type codersdk.ServerSentEventType `json:"type"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
var ping watchEvent
|
||||
err = wsjson.Read(ctx, conn, &ping)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.ServerSentEventTypePing, ping.Type)
|
||||
|
||||
collectLifecycleEvents := func(expectedKind coderdpubsub.ChatEventKind) map[uuid.UUID]coderdpubsub.ChatEvent {
|
||||
collectLifecycleEvents := func(expectedKind codersdk.ChatWatchEventKind) map[uuid.UUID]codersdk.ChatWatchEvent {
|
||||
t.Helper()
|
||||
|
||||
events := make(map[uuid.UUID]coderdpubsub.ChatEvent, 3)
|
||||
events := make(map[uuid.UUID]codersdk.ChatWatchEvent, 3)
|
||||
for len(events) < 3 {
|
||||
var update watchEvent
|
||||
err = wsjson.Read(ctx, conn, &update)
|
||||
require.NoError(t, err)
|
||||
if update.Type == codersdk.ServerSentEventTypePing {
|
||||
continue
|
||||
}
|
||||
require.Equal(t, codersdk.ServerSentEventTypeData, update.Type)
|
||||
|
||||
var payload coderdpubsub.ChatEvent
|
||||
err = json.Unmarshal(update.Data, &payload)
|
||||
var payload codersdk.ChatWatchEvent
|
||||
err = wsjson.Read(ctx, conn, &payload)
|
||||
require.NoError(t, err)
|
||||
if payload.Kind != expectedKind {
|
||||
continue
|
||||
@@ -1439,7 +1528,7 @@ func TestWatchChats(t *testing.T) {
|
||||
return events
|
||||
}
|
||||
|
||||
assertLifecycleEvents := func(events map[uuid.UUID]coderdpubsub.ChatEvent, archived bool) {
|
||||
assertLifecycleEvents := func(events map[uuid.UUID]codersdk.ChatWatchEvent, archived bool) {
|
||||
t.Helper()
|
||||
|
||||
require.Len(t, events, 3)
|
||||
@@ -1452,12 +1541,12 @@ func TestWatchChats(t *testing.T) {
|
||||
|
||||
err = client.UpdateChat(ctx, parentChat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(true)})
|
||||
require.NoError(t, err)
|
||||
deletedEvents := collectLifecycleEvents(coderdpubsub.ChatEventKindDeleted)
|
||||
deletedEvents := collectLifecycleEvents(codersdk.ChatWatchEventKindDeleted)
|
||||
assertLifecycleEvents(deletedEvents, true)
|
||||
|
||||
err = client.UpdateChat(ctx, parentChat.ID, codersdk.UpdateChatRequest{Archived: ptr.Ref(false)})
|
||||
require.NoError(t, err)
|
||||
createdEvents := collectLifecycleEvents(coderdpubsub.ChatEventKindCreated)
|
||||
createdEvents := collectLifecycleEvents(codersdk.ChatWatchEventKindCreated)
|
||||
assertLifecycleEvents(createdEvents, false)
|
||||
})
|
||||
|
||||
@@ -7747,6 +7836,62 @@ func TestChatWorkspaceTTL(t *testing.T) {
|
||||
requireSDKError(t, err, http.StatusBadRequest)
|
||||
}
|
||||
|
||||
func TestChatRetentionDays(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
adminClient := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, adminClient.Client)
|
||||
memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID)
|
||||
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
|
||||
|
||||
// Default value is 30 (days) when nothing has been configured.
|
||||
resp, err := adminClient.GetChatRetentionDays(ctx)
|
||||
require.NoError(t, err, "get default")
|
||||
require.Equal(t, int32(30), resp.RetentionDays, "default should be 30")
|
||||
|
||||
// Admin can set retention days to 90.
|
||||
err = adminClient.UpdateChatRetentionDays(ctx, codersdk.UpdateChatRetentionDaysRequest{
|
||||
RetentionDays: 90,
|
||||
})
|
||||
require.NoError(t, err, "admin set 90")
|
||||
|
||||
resp, err = adminClient.GetChatRetentionDays(ctx)
|
||||
require.NoError(t, err, "get after set")
|
||||
require.Equal(t, int32(90), resp.RetentionDays, "should return 90")
|
||||
|
||||
// Non-admin member can read the value.
|
||||
resp, err = memberClient.GetChatRetentionDays(ctx)
|
||||
require.NoError(t, err, "member get")
|
||||
require.Equal(t, int32(90), resp.RetentionDays, "member should see same value")
|
||||
|
||||
// Non-admin member cannot write.
|
||||
err = memberClient.UpdateChatRetentionDays(ctx, codersdk.UpdateChatRetentionDaysRequest{RetentionDays: 7})
|
||||
requireSDKError(t, err, http.StatusForbidden)
|
||||
|
||||
// Admin can disable purge by setting 0.
|
||||
err = adminClient.UpdateChatRetentionDays(ctx, codersdk.UpdateChatRetentionDaysRequest{
|
||||
RetentionDays: 0,
|
||||
})
|
||||
require.NoError(t, err, "admin set 0")
|
||||
|
||||
resp, err = adminClient.GetChatRetentionDays(ctx)
|
||||
require.NoError(t, err, "get after zero")
|
||||
require.Equal(t, int32(0), resp.RetentionDays, "should be 0 after disable")
|
||||
|
||||
// Validation: negative value is rejected.
|
||||
err = adminClient.UpdateChatRetentionDays(ctx, codersdk.UpdateChatRetentionDaysRequest{
|
||||
RetentionDays: -1,
|
||||
})
|
||||
requireSDKError(t, err, http.StatusBadRequest)
|
||||
|
||||
// Validation: exceeding the 3650-day maximum is rejected.
|
||||
err = adminClient.UpdateChatRetentionDays(ctx, codersdk.UpdateChatRetentionDaysRequest{
|
||||
RetentionDays: 3651, // retentionDaysMaximum + 1; keep in sync with coderd/exp_chats.go.
|
||||
})
|
||||
requireSDKError(t, err, http.StatusBadRequest)
|
||||
}
|
||||
|
||||
//nolint:tparallel,paralleltest // Subtests share a single coderdtest instance.
|
||||
func TestUserChatCompactionThresholds(t *testing.T) {
|
||||
t.Parallel()
|
||||
@@ -8153,6 +8298,375 @@ func TestGetChatsByWorkspace(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestSubmitToolResults(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// setupRequiresAction creates a chat via the DB with dynamic tools,
|
||||
// inserts an assistant message containing tool-call parts for each
|
||||
// given toolCallID, and sets the chat status to requires_action.
|
||||
// It returns the chat row so callers can exercise the endpoint.
|
||||
setupRequiresAction := func(
|
||||
ctx context.Context,
|
||||
t *testing.T,
|
||||
db database.Store,
|
||||
ownerID uuid.UUID,
|
||||
modelConfigID uuid.UUID,
|
||||
dynamicToolName string,
|
||||
toolCallIDs []string,
|
||||
) database.Chat {
|
||||
t.Helper()
|
||||
|
||||
// Marshal dynamic tools into the chat row.
|
||||
dynamicTools := []mcp.Tool{{
|
||||
Name: dynamicToolName,
|
||||
Description: "a test dynamic tool",
|
||||
InputSchema: mcp.ToolInputSchema{Type: "object"},
|
||||
}}
|
||||
dtJSON, err := json.Marshal(dynamicTools)
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: ownerID,
|
||||
LastModelConfigID: modelConfigID,
|
||||
Title: "tool-results-test",
|
||||
DynamicTools: pqtype.NullRawMessage{RawMessage: dtJSON, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Build assistant message with tool-call parts.
|
||||
parts := make([]codersdk.ChatMessagePart, 0, len(toolCallIDs))
|
||||
for _, id := range toolCallIDs {
|
||||
parts = append(parts, codersdk.ChatMessagePart{
|
||||
Type: codersdk.ChatMessagePartTypeToolCall,
|
||||
ToolCallID: id,
|
||||
ToolName: dynamicToolName,
|
||||
Args: json.RawMessage(`{"key":"value"}`),
|
||||
})
|
||||
}
|
||||
content, err := chatprompt.MarshalParts(parts)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.InsertChatMessages(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessagesParams{
|
||||
ChatID: chat.ID,
|
||||
CreatedBy: []uuid.UUID{uuid.Nil},
|
||||
ModelConfigID: []uuid.UUID{modelConfigID},
|
||||
Role: []database.ChatMessageRole{database.ChatMessageRoleAssistant},
|
||||
ContentVersion: []int16{chatprompt.CurrentContentVersion},
|
||||
Content: []string{string(content.RawMessage)},
|
||||
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
|
||||
InputTokens: []int64{0},
|
||||
OutputTokens: []int64{0},
|
||||
TotalTokens: []int64{0},
|
||||
ReasoningTokens: []int64{0},
|
||||
CacheCreationTokens: []int64{0},
|
||||
CacheReadTokens: []int64{0},
|
||||
ContextLimit: []int64{0},
|
||||
Compressed: []bool{false},
|
||||
TotalCostMicros: []int64{0},
|
||||
RuntimeMs: []int64{0},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Transition to requires_action.
|
||||
chat, err = db.UpdateChatStatus(dbauthz.AsSystemRestricted(ctx), database.UpdateChatStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: database.ChatStatusRequiresAction,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, database.ChatStatusRequiresAction, chat.Status)
|
||||
|
||||
return chat
|
||||
}
|
||||
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
user := coderdtest.CreateFirstUser(t, client.Client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
const toolName = "my_dynamic_tool"
|
||||
toolCallIDs := []string{"call_abc", "call_def"}
|
||||
|
||||
chat := setupRequiresAction(ctx, t, db, user.UserID, modelConfig.ID, toolName, toolCallIDs)
|
||||
|
||||
err := client.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{
|
||||
Results: []codersdk.ToolResult{
|
||||
{ToolCallID: "call_abc", Output: json.RawMessage(`"result_a"`)},
|
||||
{ToolCallID: "call_def", Output: json.RawMessage(`"result_b"`)},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify status is no longer requires_action. The chatd
|
||||
// loop may have already picked the chat up and
|
||||
// transitioned it further (pending → running → …), so we
|
||||
// accept any non-requires_action status.
|
||||
gotChat, err := client.GetChat(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, codersdk.ChatStatusRequiresAction, gotChat.Status,
|
||||
"chat should no longer be in requires_action after submitting tool results")
|
||||
|
||||
// Verify tool-result messages were persisted.
|
||||
msgsResp, err := client.GetChatMessages(ctx, chat.ID, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
var toolResultCount int
|
||||
for _, msg := range msgsResp.Messages {
|
||||
if msg.Role == codersdk.ChatMessageRoleTool {
|
||||
toolResultCount++
|
||||
}
|
||||
}
|
||||
require.Equal(t, len(toolCallIDs), toolResultCount,
|
||||
"expected one tool-result message per submitted result")
|
||||
})
|
||||
|
||||
t.Run("WrongStatus", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
user := coderdtest.CreateFirstUser(t, client.Client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
// Create a chat that is NOT in requires_action status.
|
||||
chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
|
||||
Status: database.ChatStatusWaiting,
|
||||
OwnerID: user.UserID,
|
||||
LastModelConfigID: modelConfig.ID,
|
||||
Title: "wrong-status-test",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = client.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{
|
||||
Results: []codersdk.ToolResult{
|
||||
{ToolCallID: "call_xyz", Output: json.RawMessage(`"nope"`)},
|
||||
},
|
||||
})
|
||||
requireSDKError(t, err, http.StatusConflict)
|
||||
})
|
||||
|
||||
t.Run("MissingResult", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
user := coderdtest.CreateFirstUser(t, client.Client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
const toolName = "my_dynamic_tool"
|
||||
toolCallIDs := []string{"call_one", "call_two"}
|
||||
|
||||
chat := setupRequiresAction(ctx, t, db, user.UserID, modelConfig.ID, toolName, toolCallIDs)
|
||||
|
||||
// Submit only one of the two required results.
|
||||
err := client.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{
|
||||
Results: []codersdk.ToolResult{
|
||||
{ToolCallID: "call_one", Output: json.RawMessage(`"partial"`)},
|
||||
},
|
||||
})
|
||||
requireSDKError(t, err, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("UnexpectedResult", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
user := coderdtest.CreateFirstUser(t, client.Client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
const toolName = "my_dynamic_tool"
|
||||
toolCallIDs := []string{"call_real"}
|
||||
|
||||
chat := setupRequiresAction(ctx, t, db, user.UserID, modelConfig.ID, toolName, toolCallIDs)
|
||||
|
||||
// Submit a result with a wrong tool_call_id.
|
||||
err := client.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{
|
||||
Results: []codersdk.ToolResult{
|
||||
{ToolCallID: "call_bogus", Output: json.RawMessage(`"wrong"`)},
|
||||
},
|
||||
})
|
||||
requireSDKError(t, err, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("InvalidJSONOutput", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
user := coderdtest.CreateFirstUser(t, client.Client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
const toolName = "my_dynamic_tool"
|
||||
toolCallIDs := []string{"call_json"}
|
||||
|
||||
chat := setupRequiresAction(ctx, t, db, user.UserID, modelConfig.ID, toolName, toolCallIDs)
|
||||
|
||||
// We must bypass the SDK client because json.RawMessage
|
||||
// rejects invalid JSON during json.Marshal. A raw HTTP
|
||||
// request lets the invalid payload reach the server so we
|
||||
// can verify server-side validation.
|
||||
rawBody := `{"results":[{"tool_call_id":"call_json","output":not-json,"is_error":false}]}`
|
||||
url := client.URL.JoinPath(fmt.Sprintf("/api/experimental/chats/%s/tool-results", chat.ID)).String()
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBufferString(rawBody))
|
||||
require.NoError(t, err)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set(codersdk.SessionTokenHeader, client.SessionToken())
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("DuplicateToolCallID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
user := coderdtest.CreateFirstUser(t, client.Client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
const toolName = "my_dynamic_tool"
|
||||
toolCallIDs := []string{"call_dup1", "call_dup2"}
|
||||
|
||||
chat := setupRequiresAction(ctx, t, db, user.UserID, modelConfig.ID, toolName, toolCallIDs)
|
||||
|
||||
err := client.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{
|
||||
Results: []codersdk.ToolResult{
|
||||
{ToolCallID: "call_dup1", Output: json.RawMessage(`"result_a"`)},
|
||||
{ToolCallID: "call_dup1", Output: json.RawMessage(`"result_b"`)},
|
||||
},
|
||||
})
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Contains(t, sdkErr.Message, "Duplicate tool_call_id")
|
||||
})
|
||||
|
||||
t.Run("EmptyResults", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
user := coderdtest.CreateFirstUser(t, client.Client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
const toolName = "my_dynamic_tool"
|
||||
toolCallIDs := []string{"call_empty"}
|
||||
|
||||
chat := setupRequiresAction(ctx, t, db, user.UserID, modelConfig.ID, toolName, toolCallIDs)
|
||||
|
||||
err := client.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{
|
||||
Results: []codersdk.ToolResult{},
|
||||
})
|
||||
requireSDKError(t, err, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("NotFoundForDifferentUser", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
user := coderdtest.CreateFirstUser(t, client.Client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
const toolName = "my_dynamic_tool"
|
||||
toolCallIDs := []string{"call_other"}
|
||||
|
||||
chat := setupRequiresAction(ctx, t, db, user.UserID, modelConfig.ID, toolName, toolCallIDs)
|
||||
|
||||
// Create a second user and try to submit tool results
|
||||
// to user A's chat.
|
||||
otherClientRaw, _ := coderdtest.CreateAnotherUser(
|
||||
t, client.Client, user.OrganizationID,
|
||||
rbac.RoleAgentsAccess(),
|
||||
)
|
||||
otherClient := codersdk.NewExperimentalClient(otherClientRaw)
|
||||
|
||||
err := otherClient.SubmitToolResults(ctx, chat.ID, codersdk.SubmitToolResultsRequest{
|
||||
Results: []codersdk.ToolResult{
|
||||
{ToolCallID: "call_other", Output: json.RawMessage(`"nope"`)},
|
||||
},
|
||||
})
|
||||
requireSDKError(t, err, http.StatusNotFound)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPostChats_DynamicToolValidation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("TooManyTools", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
|
||||
tools := make([]codersdk.DynamicTool, 251)
|
||||
for i := range tools {
|
||||
tools[i] = codersdk.DynamicTool{
|
||||
Name: fmt.Sprintf("tool-%d", i),
|
||||
}
|
||||
}
|
||||
|
||||
_, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "hello",
|
||||
}},
|
||||
UnsafeDynamicTools: tools,
|
||||
})
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Equal(t, "Too many dynamic tools.", sdkErr.Message)
|
||||
})
|
||||
|
||||
t.Run("EmptyToolName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
|
||||
_, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "hello",
|
||||
}},
|
||||
UnsafeDynamicTools: []codersdk.DynamicTool{
|
||||
{Name: ""},
|
||||
},
|
||||
})
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Equal(t, "Dynamic tool name must not be empty.", sdkErr.Message)
|
||||
})
|
||||
|
||||
t.Run("DuplicateToolName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
|
||||
_, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "hello",
|
||||
}},
|
||||
UnsafeDynamicTools: []codersdk.DynamicTool{
|
||||
{Name: "dup-tool"},
|
||||
{Name: "dup-tool"},
|
||||
},
|
||||
})
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Equal(t, "Duplicate dynamic tool name.", sdkErr.Message)
|
||||
})
|
||||
}
|
||||
|
||||
func requireSDKError(t *testing.T, err error, expectedStatus int) *codersdk.Error {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
package coderd
|
||||
|
||||
// InsertAgentChatTestModelConfig exposes insertAgentChatTestModelConfig for external tests.
|
||||
var InsertAgentChatTestModelConfig = insertAgentChatTestModelConfig
|
||||
@@ -148,7 +148,7 @@ func TestGetOrgMembersFilter(t *testing.T) {
|
||||
setupCtx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
coderdtest.UsersFilter(setupCtx, t, client, api.Database, nil, func(testCtx context.Context, req codersdk.UsersRequest) []codersdk.ReducedUser {
|
||||
coderdtest.UsersFilter(setupCtx, t, client, api.Database, nil, nil, func(testCtx context.Context, req codersdk.UsersRequest) []codersdk.ReducedUser {
|
||||
res, err := client.OrganizationMembersPaginated(testCtx, first.OrganizationID, req)
|
||||
require.NoError(t, err)
|
||||
reduced := make([]codersdk.ReducedUser, len(res.Members))
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
htmltemplate "html/template"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
@@ -146,12 +147,35 @@ func ShowAuthorizePage(accessURL *url.URL) http.HandlerFunc {
|
||||
cancel := params.redirectURL
|
||||
cancelQuery := params.redirectURL.Query()
|
||||
cancelQuery.Add("error", "access_denied")
|
||||
cancelQuery.Add("error_description", "The resource owner or authorization server denied the request")
|
||||
if params.state != "" {
|
||||
cancelQuery.Add("state", params.state)
|
||||
}
|
||||
cancel.RawQuery = cancelQuery.Encode()
|
||||
|
||||
cancelURI := cancel.String()
|
||||
if err := codersdk.ValidateRedirectURIScheme(cancel); err != nil {
|
||||
site.RenderStaticErrorPage(rw, r, site.ErrorPageData{
|
||||
Status: http.StatusBadRequest,
|
||||
HideStatus: false,
|
||||
Title: "Invalid Callback URL",
|
||||
Description: "The application's registered callback URL has an invalid scheme.",
|
||||
Actions: []site.Action{
|
||||
{
|
||||
URL: accessURL.String(),
|
||||
Text: "Back to site",
|
||||
},
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
site.RenderOAuthAllowPage(rw, r, site.RenderOAuthAllowData{
|
||||
AppIcon: app.Icon,
|
||||
AppName: app.Name,
|
||||
CancelURI: cancel.String(),
|
||||
AppIcon: app.Icon,
|
||||
AppName: app.Name,
|
||||
// #nosec G203 -- The scheme is validated by
|
||||
// codersdk.ValidateRedirectURIScheme above.
|
||||
CancelURI: htmltemplate.URL(cancelURI),
|
||||
RedirectURI: r.URL.String(),
|
||||
CSRFToken: nosurf.Token(r),
|
||||
Username: ua.FriendlyName,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package oauth2provider_test
|
||||
|
||||
import (
|
||||
htmltemplate "html/template"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
@@ -20,7 +21,7 @@ func TestOAuthConsentFormIncludesCSRFToken(t *testing.T) {
|
||||
|
||||
site.RenderOAuthAllowPage(rec, req, site.RenderOAuthAllowData{
|
||||
AppName: "Test OAuth App",
|
||||
CancelURI: "https://coder.com/cancel",
|
||||
CancelURI: htmltemplate.URL("https://coder.com/cancel"),
|
||||
RedirectURI: "https://coder.com/oauth2/authorize?client_id=test",
|
||||
CSRFToken: csrfFieldValue,
|
||||
Username: "test-user",
|
||||
|
||||
@@ -435,6 +435,9 @@ func convertProvisionerJobWithQueuePosition(pj database.GetProvisionerJobsByOrga
|
||||
if pj.WorkspaceID.Valid {
|
||||
job.Metadata.WorkspaceID = &pj.WorkspaceID.UUID
|
||||
}
|
||||
if pj.WorkspaceBuildTransition.Valid {
|
||||
job.Metadata.WorkspaceBuildTransition = codersdk.WorkspaceTransition(pj.WorkspaceBuildTransition.WorkspaceTransition)
|
||||
}
|
||||
return job
|
||||
}
|
||||
|
||||
|
||||
@@ -97,13 +97,14 @@ func TestProvisionerJobs(t *testing.T) {
|
||||
|
||||
// Verify that job metadata is correct.
|
||||
assert.Equal(t, job2.Metadata, codersdk.ProvisionerJobMetadata{
|
||||
TemplateVersionName: version.Name,
|
||||
TemplateID: template.ID,
|
||||
TemplateName: template.Name,
|
||||
TemplateDisplayName: template.DisplayName,
|
||||
TemplateIcon: template.Icon,
|
||||
WorkspaceID: &w.ID,
|
||||
WorkspaceName: w.Name,
|
||||
TemplateVersionName: version.Name,
|
||||
TemplateID: template.ID,
|
||||
TemplateName: template.Name,
|
||||
TemplateDisplayName: template.DisplayName,
|
||||
TemplateIcon: template.Icon,
|
||||
WorkspaceID: &w.ID,
|
||||
WorkspaceName: w.Name,
|
||||
WorkspaceBuildTransition: codersdk.WorkspaceTransitionStart,
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
const ChatConfigEventChannel = "chat:config_change"
|
||||
|
||||
// HandleChatConfigEvent wraps a typed callback for ChatConfigEvent
|
||||
// messages, following the same pattern as HandleChatEvent.
|
||||
// messages, following the same pattern as HandleChatWatchEvent.
|
||||
func HandleChatConfigEvent(cb func(ctx context.Context, payload ChatConfigEvent, err error)) func(ctx context.Context, message []byte, err error) {
|
||||
return func(ctx context.Context, message []byte, err error) {
|
||||
if err != nil {
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
package pubsub
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
func ChatEventChannel(ownerID uuid.UUID) string {
|
||||
return fmt.Sprintf("chat:owner:%s", ownerID)
|
||||
}
|
||||
|
||||
func HandleChatEvent(cb func(ctx context.Context, payload ChatEvent, err error)) func(ctx context.Context, message []byte, err error) {
|
||||
return func(ctx context.Context, message []byte, err error) {
|
||||
if err != nil {
|
||||
cb(ctx, ChatEvent{}, xerrors.Errorf("chat event pubsub: %w", err))
|
||||
return
|
||||
}
|
||||
var payload ChatEvent
|
||||
if err := json.Unmarshal(message, &payload); err != nil {
|
||||
cb(ctx, ChatEvent{}, xerrors.Errorf("unmarshal chat event: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
cb(ctx, payload, err)
|
||||
}
|
||||
}
|
||||
|
||||
type ChatEvent struct {
|
||||
Kind ChatEventKind `json:"kind"`
|
||||
Chat codersdk.Chat `json:"chat"`
|
||||
}
|
||||
|
||||
type ChatEventKind string
|
||||
|
||||
const (
|
||||
ChatEventKindStatusChange ChatEventKind = "status_change"
|
||||
ChatEventKindTitleChange ChatEventKind = "title_change"
|
||||
ChatEventKindCreated ChatEventKind = "created"
|
||||
ChatEventKindDeleted ChatEventKind = "deleted"
|
||||
ChatEventKindDiffStatusChange ChatEventKind = "diff_status_change"
|
||||
)
|
||||
@@ -0,0 +1,36 @@
|
||||
package pubsub
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// ChatWatchEventChannel returns the pubsub channel for chat
|
||||
// lifecycle events scoped to a single user.
|
||||
func ChatWatchEventChannel(ownerID uuid.UUID) string {
|
||||
return fmt.Sprintf("chat:owner:%s", ownerID)
|
||||
}
|
||||
|
||||
// HandleChatWatchEvent wraps a typed callback for
|
||||
// ChatWatchEvent messages delivered via pubsub.
|
||||
func HandleChatWatchEvent(cb func(ctx context.Context, payload codersdk.ChatWatchEvent, err error)) func(ctx context.Context, message []byte, err error) {
|
||||
return func(ctx context.Context, message []byte, err error) {
|
||||
if err != nil {
|
||||
cb(ctx, codersdk.ChatWatchEvent{}, xerrors.Errorf("chat watch event pubsub: %w", err))
|
||||
return
|
||||
}
|
||||
var payload codersdk.ChatWatchEvent
|
||||
if err := json.Unmarshal(message, &payload); err != nil {
|
||||
cb(ctx, codersdk.ChatWatchEvent{}, xerrors.Errorf("unmarshal chat watch event: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
cb(ctx, payload, err)
|
||||
}
|
||||
}
|
||||
+17
-35
@@ -19,7 +19,6 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/sync/singleflight"
|
||||
"golang.org/x/xerrors"
|
||||
"tailscale.com/derp"
|
||||
"tailscale.com/tailcfg"
|
||||
@@ -390,7 +389,6 @@ type MultiAgentController struct {
|
||||
// connections to the destination
|
||||
tickets map[uuid.UUID]map[uuid.UUID]struct{}
|
||||
coordination *tailnet.BasicCoordination
|
||||
sendGroup singleflight.Group
|
||||
|
||||
cancel context.CancelFunc
|
||||
expireOldAgentsDone chan struct{}
|
||||
@@ -420,44 +418,28 @@ func (m *MultiAgentController) New(client tailnet.CoordinatorClient) tailnet.Clo
|
||||
|
||||
func (m *MultiAgentController) ensureAgent(agentID uuid.UUID) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
_, ok := m.connectionTimes[agentID]
|
||||
if ok {
|
||||
m.connectionTimes[agentID] = time.Now()
|
||||
m.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
m.logger.Debug(context.Background(),
|
||||
"subscribing to agent", slog.F("agent_id", agentID))
|
||||
|
||||
_, err, _ := m.sendGroup.Do(agentID.String(), func() (interface{}, error) {
|
||||
m.mu.Lock()
|
||||
coord := m.coordination
|
||||
m.mu.Unlock()
|
||||
if coord == nil {
|
||||
return nil, xerrors.New("no active coordination")
|
||||
// If we don't have the agent, subscribe.
|
||||
if !ok {
|
||||
m.logger.Debug(context.Background(),
|
||||
"subscribing to agent", slog.F("agent_id", agentID))
|
||||
if m.coordination != nil {
|
||||
err := m.coordination.Client.Send(&proto.CoordinateRequest{
|
||||
AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]},
|
||||
})
|
||||
if err != nil {
|
||||
err = xerrors.Errorf("subscribe agent: %w", err)
|
||||
m.coordination.SendErr(err)
|
||||
_ = m.coordination.Client.Close()
|
||||
m.coordination = nil
|
||||
return err
|
||||
}
|
||||
}
|
||||
err := coord.Client.Send(&proto.CoordinateRequest{
|
||||
AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.mu.Lock()
|
||||
m.tickets[agentID] = map[uuid.UUID]struct{}{}
|
||||
m.mu.Unlock()
|
||||
return nil, nil
|
||||
})
|
||||
if err != nil {
|
||||
m.logger.Error(context.Background(), "ensureAgent send failed",
|
||||
slog.F("agent_id", agentID), slog.Error(err))
|
||||
return xerrors.Errorf("send AddTunnel: %w", err)
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
m.connectionTimes[agentID] = time.Now()
|
||||
m.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -776,6 +776,40 @@ func (r *remoteReporter) createSnapshot() (*Snapshot, error) {
|
||||
return nil
|
||||
})
|
||||
|
||||
eg.Go(func() error {
|
||||
chats, err := r.options.Database.GetChatsUpdatedAfter(ctx, createdAfter)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get chats updated after: %w", err)
|
||||
}
|
||||
snapshot.Chats = make([]Chat, 0, len(chats))
|
||||
for _, chat := range chats {
|
||||
snapshot.Chats = append(snapshot.Chats, ConvertChat(chat))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
eg.Go(func() error {
|
||||
summaries, err := r.options.Database.GetChatMessageSummariesPerChat(ctx, createdAfter)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get chat message summaries: %w", err)
|
||||
}
|
||||
snapshot.ChatMessageSummaries = make([]ChatMessageSummary, 0, len(summaries))
|
||||
for _, s := range summaries {
|
||||
snapshot.ChatMessageSummaries = append(snapshot.ChatMessageSummaries, ConvertChatMessageSummary(s))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
eg.Go(func() error {
|
||||
configs, err := r.options.Database.GetChatModelConfigsForTelemetry(ctx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get chat model configs: %w", err)
|
||||
}
|
||||
snapshot.ChatModelConfigs = make([]ChatModelConfig, 0, len(configs))
|
||||
for _, c := range configs {
|
||||
snapshot.ChatModelConfigs = append(snapshot.ChatModelConfigs, ConvertChatModelConfig(c))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
err := eg.Wait()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1503,6 +1537,9 @@ type Snapshot struct {
|
||||
AIBridgeInterceptionsSummaries []AIBridgeInterceptionsSummary `json:"aibridge_interceptions_summaries"`
|
||||
BoundaryUsageSummary *BoundaryUsageSummary `json:"boundary_usage_summary"`
|
||||
FirstUserOnboarding *FirstUserOnboarding `json:"first_user_onboarding"`
|
||||
Chats []Chat `json:"chats"`
|
||||
ChatMessageSummaries []ChatMessageSummary `json:"chat_message_summaries"`
|
||||
ChatModelConfigs []ChatModelConfig `json:"chat_model_configs"`
|
||||
}
|
||||
|
||||
// Deployment contains information about the host running Coder.
|
||||
@@ -2113,6 +2150,66 @@ func ConvertTask(task database.Task) Task {
|
||||
return t
|
||||
}
|
||||
|
||||
// ConvertChat converts a database chat row to a telemetry Chat.
|
||||
func ConvertChat(dbChat database.GetChatsUpdatedAfterRow) Chat {
|
||||
c := Chat{
|
||||
ID: dbChat.ID,
|
||||
OwnerID: dbChat.OwnerID,
|
||||
CreatedAt: dbChat.CreatedAt,
|
||||
UpdatedAt: dbChat.UpdatedAt,
|
||||
Status: string(dbChat.Status),
|
||||
HasParent: dbChat.HasParent,
|
||||
Archived: dbChat.Archived,
|
||||
LastModelConfigID: dbChat.LastModelConfigID,
|
||||
}
|
||||
if dbChat.RootChatID.Valid {
|
||||
c.RootChatID = &dbChat.RootChatID.UUID
|
||||
}
|
||||
if dbChat.WorkspaceID.Valid {
|
||||
c.WorkspaceID = &dbChat.WorkspaceID.UUID
|
||||
}
|
||||
if dbChat.Mode.Valid {
|
||||
mode := string(dbChat.Mode.ChatMode)
|
||||
c.Mode = &mode
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// ConvertChatMessageSummary converts a database chat message
|
||||
// summary row to a telemetry ChatMessageSummary.
|
||||
func ConvertChatMessageSummary(dbRow database.GetChatMessageSummariesPerChatRow) ChatMessageSummary {
|
||||
return ChatMessageSummary{
|
||||
ChatID: dbRow.ChatID,
|
||||
MessageCount: dbRow.MessageCount,
|
||||
UserMessageCount: dbRow.UserMessageCount,
|
||||
AssistantMessageCount: dbRow.AssistantMessageCount,
|
||||
ToolMessageCount: dbRow.ToolMessageCount,
|
||||
SystemMessageCount: dbRow.SystemMessageCount,
|
||||
TotalInputTokens: dbRow.TotalInputTokens,
|
||||
TotalOutputTokens: dbRow.TotalOutputTokens,
|
||||
TotalReasoningTokens: dbRow.TotalReasoningTokens,
|
||||
TotalCacheCreationTokens: dbRow.TotalCacheCreationTokens,
|
||||
TotalCacheReadTokens: dbRow.TotalCacheReadTokens,
|
||||
TotalCostMicros: dbRow.TotalCostMicros,
|
||||
TotalRuntimeMs: dbRow.TotalRuntimeMs,
|
||||
DistinctModelCount: dbRow.DistinctModelCount,
|
||||
CompressedMessageCount: dbRow.CompressedMessageCount,
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertChatModelConfig converts a database model config row to a
|
||||
// telemetry ChatModelConfig.
|
||||
func ConvertChatModelConfig(dbRow database.GetChatModelConfigsForTelemetryRow) ChatModelConfig {
|
||||
return ChatModelConfig{
|
||||
ID: dbRow.ID,
|
||||
Provider: dbRow.Provider,
|
||||
Model: dbRow.Model,
|
||||
ContextLimit: dbRow.ContextLimit,
|
||||
Enabled: dbRow.Enabled,
|
||||
IsDefault: dbRow.IsDefault,
|
||||
}
|
||||
}
|
||||
|
||||
type telemetryItemKey string
|
||||
|
||||
// The comment below gets rid of the warning that the name "TelemetryItemKey" has
|
||||
@@ -2234,6 +2331,53 @@ type BoundaryUsageSummary struct {
|
||||
PeriodDurationMilliseconds int64 `json:"period_duration_ms"`
|
||||
}
|
||||
|
||||
// Chat contains anonymized metadata about a chat for telemetry.
|
||||
// Titles and message content are excluded to avoid PII leakage.
|
||||
type Chat struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
OwnerID uuid.UUID `json:"owner_id"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
Status string `json:"status"`
|
||||
HasParent bool `json:"has_parent"`
|
||||
RootChatID *uuid.UUID `json:"root_chat_id"`
|
||||
WorkspaceID *uuid.UUID `json:"workspace_id"`
|
||||
Mode *string `json:"mode"`
|
||||
Archived bool `json:"archived"`
|
||||
LastModelConfigID uuid.UUID `json:"last_model_config_id"`
|
||||
}
|
||||
|
||||
// ChatMessageSummary contains per-chat aggregated message metrics
|
||||
// for telemetry. Individual message content is never included.
|
||||
type ChatMessageSummary struct {
|
||||
ChatID uuid.UUID `json:"chat_id"`
|
||||
MessageCount int64 `json:"message_count"`
|
||||
UserMessageCount int64 `json:"user_message_count"`
|
||||
AssistantMessageCount int64 `json:"assistant_message_count"`
|
||||
ToolMessageCount int64 `json:"tool_message_count"`
|
||||
SystemMessageCount int64 `json:"system_message_count"`
|
||||
TotalInputTokens int64 `json:"total_input_tokens"`
|
||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||
TotalReasoningTokens int64 `json:"total_reasoning_tokens"`
|
||||
TotalCacheCreationTokens int64 `json:"total_cache_creation_tokens"`
|
||||
TotalCacheReadTokens int64 `json:"total_cache_read_tokens"`
|
||||
TotalCostMicros int64 `json:"total_cost_micros"`
|
||||
TotalRuntimeMs int64 `json:"total_runtime_ms"`
|
||||
DistinctModelCount int64 `json:"distinct_model_count"`
|
||||
CompressedMessageCount int64 `json:"compressed_message_count"`
|
||||
}
|
||||
|
||||
// ChatModelConfig contains model configuration metadata for
|
||||
// telemetry. Sensitive fields like API keys are excluded.
|
||||
type ChatModelConfig struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
Provider string `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
ContextLimit int64 `json:"context_limit"`
|
||||
Enabled bool `json:"enabled"`
|
||||
IsDefault bool `json:"is_default"`
|
||||
}
|
||||
|
||||
func ConvertAIBridgeInterceptionsSummary(endTime time.Time, provider, model, client string, summary database.CalculateAIBridgeInterceptionsTelemetrySummaryRow) AIBridgeInterceptionsSummary {
|
||||
return AIBridgeInterceptionsSummary{
|
||||
ID: uuid.New(),
|
||||
|
||||
@@ -1549,3 +1549,303 @@ func TestTelemetry_BoundaryUsageSummary(t *testing.T) {
|
||||
require.Nil(t, snapshot2.BoundaryUsageSummary)
|
||||
})
|
||||
}
|
||||
|
||||
func TestChatsTelemetry(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
|
||||
// Create chat providers (required FK for model configs).
|
||||
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "anthropic",
|
||||
DisplayName: "Anthropic",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a model config.
|
||||
modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
Provider: "anthropic",
|
||||
Model: "claude-sonnet-4-20250514",
|
||||
DisplayName: "Claude Sonnet",
|
||||
Enabled: true,
|
||||
IsDefault: true,
|
||||
ContextLimit: 200000,
|
||||
CompressionThreshold: 70,
|
||||
Options: json.RawMessage("{}"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a second model config to test full dump.
|
||||
modelCfg2, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
Provider: "openai",
|
||||
Model: "gpt-4o",
|
||||
DisplayName: "GPT-4o",
|
||||
Enabled: true,
|
||||
IsDefault: false,
|
||||
ContextLimit: 128000,
|
||||
CompressionThreshold: 70,
|
||||
Options: json.RawMessage("{}"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a soft-deleted model config — should NOT appear in telemetry.
|
||||
deletedCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
Provider: "anthropic",
|
||||
Model: "claude-deleted",
|
||||
DisplayName: "Deleted Model",
|
||||
Enabled: true,
|
||||
IsDefault: false,
|
||||
ContextLimit: 100000,
|
||||
CompressionThreshold: 70,
|
||||
Options: json.RawMessage("{}"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
err = db.DeleteChatModelConfigByID(ctx, deletedCfg.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a root chat with a workspace.
|
||||
org, err := db.GetDefaultOrganization(ctx)
|
||||
require.NoError(t, err)
|
||||
job := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{
|
||||
OrganizationID: org.ID,
|
||||
Type: database.ProvisionerJobTypeTemplateVersionDryRun,
|
||||
})
|
||||
tpl := dbgen.Template(t, db, database.Template{
|
||||
OrganizationID: org.ID,
|
||||
CreatedBy: user.ID,
|
||||
})
|
||||
tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{
|
||||
OrganizationID: org.ID,
|
||||
TemplateID: uuid.NullUUID{UUID: tpl.ID, Valid: true},
|
||||
CreatedBy: user.ID,
|
||||
JobID: job.ID,
|
||||
})
|
||||
ws := dbgen.Workspace(t, db, database.WorkspaceTable{
|
||||
OwnerID: user.ID,
|
||||
OrganizationID: org.ID,
|
||||
TemplateID: tpl.ID,
|
||||
})
|
||||
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
||||
Transition: database.WorkspaceTransitionStart,
|
||||
Reason: database.BuildReasonInitiator,
|
||||
WorkspaceID: ws.ID,
|
||||
TemplateVersionID: tv.ID,
|
||||
JobID: job.ID,
|
||||
})
|
||||
|
||||
rootChat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
OwnerID: user.ID,
|
||||
LastModelConfigID: modelCfg.ID,
|
||||
Title: "Root Chat",
|
||||
Status: database.ChatStatusRunning,
|
||||
WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true},
|
||||
Mode: database.NullChatMode{ChatMode: database.ChatModeComputerUse, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a child chat (has parent + root).
|
||||
childChat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
OwnerID: user.ID,
|
||||
LastModelConfigID: modelCfg2.ID,
|
||||
Title: "Child Chat",
|
||||
Status: database.ChatStatusCompleted,
|
||||
ParentChatID: uuid.NullUUID{UUID: rootChat.ID, Valid: true},
|
||||
RootChatID: uuid.NullUUID{UUID: rootChat.ID, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert messages for root chat: 2 user, 2 assistant, 1 tool.
|
||||
_, err = db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
|
||||
ChatID: rootChat.ID,
|
||||
CreatedBy: []uuid.UUID{user.ID, uuid.Nil, user.ID, uuid.Nil, uuid.Nil},
|
||||
ModelConfigID: []uuid.UUID{modelCfg.ID, modelCfg.ID, modelCfg.ID, modelCfg.ID, modelCfg.ID},
|
||||
Role: []database.ChatMessageRole{database.ChatMessageRoleUser, database.ChatMessageRoleAssistant, database.ChatMessageRoleUser, database.ChatMessageRoleAssistant, database.ChatMessageRoleTool},
|
||||
Content: []string{`[{"type":"text","text":"hello"}]`, `[{"type":"text","text":"hi"}]`, `[{"type":"text","text":"help"}]`, `[{"type":"text","text":"sure"}]`, `[{"type":"text","text":"result"}]`},
|
||||
ContentVersion: []int16{1, 1, 1, 1, 1},
|
||||
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth, database.ChatMessageVisibilityBoth, database.ChatMessageVisibilityBoth, database.ChatMessageVisibilityBoth, database.ChatMessageVisibilityBoth},
|
||||
InputTokens: []int64{100, 200, 150, 300, 0},
|
||||
OutputTokens: []int64{0, 50, 0, 100, 0},
|
||||
TotalTokens: []int64{100, 250, 150, 400, 0},
|
||||
ReasoningTokens: []int64{0, 10, 0, 20, 0},
|
||||
CacheCreationTokens: []int64{50, 0, 30, 0, 0},
|
||||
CacheReadTokens: []int64{0, 25, 0, 40, 0},
|
||||
ContextLimit: []int64{200000, 200000, 200000, 200000, 200000},
|
||||
Compressed: []bool{false, false, false, false, false},
|
||||
TotalCostMicros: []int64{1000, 2000, 1500, 3000, 0},
|
||||
RuntimeMs: []int64{0, 500, 0, 800, 100},
|
||||
ProviderResponseID: []string{"", "resp-1", "", "resp-2", ""},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert messages for child chat: 1 user, 1 assistant (compressed).
|
||||
_, err = db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
|
||||
ChatID: childChat.ID,
|
||||
CreatedBy: []uuid.UUID{user.ID, uuid.Nil},
|
||||
ModelConfigID: []uuid.UUID{modelCfg2.ID, modelCfg2.ID},
|
||||
Role: []database.ChatMessageRole{database.ChatMessageRoleUser, database.ChatMessageRoleAssistant},
|
||||
Content: []string{`[{"type":"text","text":"q"}]`, `[{"type":"text","text":"a"}]`},
|
||||
ContentVersion: []int16{1, 1},
|
||||
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth, database.ChatMessageVisibilityBoth},
|
||||
InputTokens: []int64{500, 600},
|
||||
OutputTokens: []int64{0, 200},
|
||||
TotalTokens: []int64{500, 800},
|
||||
ReasoningTokens: []int64{0, 50},
|
||||
CacheCreationTokens: []int64{100, 0},
|
||||
CacheReadTokens: []int64{0, 75},
|
||||
ContextLimit: []int64{128000, 128000},
|
||||
Compressed: []bool{false, true},
|
||||
TotalCostMicros: []int64{5000, 8000},
|
||||
RuntimeMs: []int64{0, 1200},
|
||||
ProviderResponseID: []string{"", "resp-3"},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert a soft-deleted message on root chat with large token values.
|
||||
// This acts as "poison" — if the deleted filter is missing, totals
|
||||
// will be inflated and assertions below will fail.
|
||||
poisonMsgs, err := db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
|
||||
ChatID: rootChat.ID,
|
||||
CreatedBy: []uuid.UUID{uuid.Nil},
|
||||
ModelConfigID: []uuid.UUID{modelCfg.ID},
|
||||
Role: []database.ChatMessageRole{database.ChatMessageRoleAssistant},
|
||||
Content: []string{`[{"type":"text","text":"poison"}]`},
|
||||
ContentVersion: []int16{1},
|
||||
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
|
||||
InputTokens: []int64{999999},
|
||||
OutputTokens: []int64{999999},
|
||||
TotalTokens: []int64{999999},
|
||||
ReasoningTokens: []int64{999999},
|
||||
CacheCreationTokens: []int64{999999},
|
||||
CacheReadTokens: []int64{999999},
|
||||
ContextLimit: []int64{200000},
|
||||
Compressed: []bool{false},
|
||||
TotalCostMicros: []int64{999999},
|
||||
RuntimeMs: []int64{999999},
|
||||
ProviderResponseID: []string{""},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
err = db.SoftDeleteChatMessageByID(ctx, poisonMsgs[0].ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, snapshot := collectSnapshot(ctx, t, db, nil)
|
||||
|
||||
// --- Assert Chats ---
|
||||
require.Len(t, snapshot.Chats, 2)
|
||||
|
||||
// Find root and child by HasParent flag.
|
||||
var foundRoot, foundChild *telemetry.Chat
|
||||
for i := range snapshot.Chats {
|
||||
if !snapshot.Chats[i].HasParent {
|
||||
foundRoot = &snapshot.Chats[i]
|
||||
} else {
|
||||
foundChild = &snapshot.Chats[i]
|
||||
}
|
||||
}
|
||||
require.NotNil(t, foundRoot, "expected root chat")
|
||||
require.NotNil(t, foundChild, "expected child chat")
|
||||
|
||||
// Root chat assertions.
|
||||
assert.Equal(t, rootChat.ID, foundRoot.ID)
|
||||
assert.Equal(t, user.ID, foundRoot.OwnerID)
|
||||
assert.Equal(t, "running", foundRoot.Status)
|
||||
assert.False(t, foundRoot.HasParent)
|
||||
assert.Nil(t, foundRoot.RootChatID)
|
||||
require.NotNil(t, foundRoot.WorkspaceID)
|
||||
assert.Equal(t, ws.ID, *foundRoot.WorkspaceID)
|
||||
assert.Equal(t, modelCfg.ID, foundRoot.LastModelConfigID)
|
||||
require.NotNil(t, foundRoot.Mode)
|
||||
assert.Equal(t, "computer_use", *foundRoot.Mode)
|
||||
assert.False(t, foundRoot.Archived)
|
||||
|
||||
// Child chat assertions.
|
||||
assert.Equal(t, childChat.ID, foundChild.ID)
|
||||
assert.Equal(t, user.ID, foundChild.OwnerID)
|
||||
assert.True(t, foundChild.HasParent)
|
||||
require.NotNil(t, foundChild.RootChatID)
|
||||
assert.Equal(t, rootChat.ID, *foundChild.RootChatID)
|
||||
assert.Nil(t, foundChild.WorkspaceID)
|
||||
assert.Equal(t, "completed", foundChild.Status)
|
||||
assert.Equal(t, modelCfg2.ID, foundChild.LastModelConfigID)
|
||||
assert.Nil(t, foundChild.Mode)
|
||||
assert.False(t, foundChild.Archived)
|
||||
|
||||
// --- Assert ChatMessageSummaries ---
|
||||
require.Len(t, snapshot.ChatMessageSummaries, 2)
|
||||
|
||||
summaryMap := make(map[uuid.UUID]telemetry.ChatMessageSummary)
|
||||
for _, s := range snapshot.ChatMessageSummaries {
|
||||
summaryMap[s.ChatID] = s
|
||||
}
|
||||
|
||||
// Root chat summary: 2 user + 2 assistant + 1 tool = 5 messages.
|
||||
rootSummary, ok := summaryMap[rootChat.ID]
|
||||
require.True(t, ok, "expected summary for root chat")
|
||||
assert.Equal(t, int64(5), rootSummary.MessageCount)
|
||||
assert.Equal(t, int64(2), rootSummary.UserMessageCount)
|
||||
assert.Equal(t, int64(2), rootSummary.AssistantMessageCount)
|
||||
assert.Equal(t, int64(1), rootSummary.ToolMessageCount)
|
||||
assert.Equal(t, int64(0), rootSummary.SystemMessageCount)
|
||||
assert.Equal(t, int64(750), rootSummary.TotalInputTokens) // 100+200+150+300+0
|
||||
assert.Equal(t, int64(150), rootSummary.TotalOutputTokens) // 0+50+0+100+0
|
||||
assert.Equal(t, int64(30), rootSummary.TotalReasoningTokens) // 0+10+0+20+0
|
||||
assert.Equal(t, int64(80), rootSummary.TotalCacheCreationTokens) // 50+0+30+0+0
|
||||
assert.Equal(t, int64(65), rootSummary.TotalCacheReadTokens) // 0+25+0+40+0
|
||||
assert.Equal(t, int64(7500), rootSummary.TotalCostMicros) // 1000+2000+1500+3000+0
|
||||
assert.Equal(t, int64(1400), rootSummary.TotalRuntimeMs) // 0+500+0+800+100
|
||||
assert.Equal(t, int64(1), rootSummary.DistinctModelCount)
|
||||
assert.Equal(t, int64(0), rootSummary.CompressedMessageCount)
|
||||
|
||||
// Child chat summary: 1 user + 1 assistant = 2 messages, 1 compressed.
|
||||
childSummary, ok := summaryMap[childChat.ID]
|
||||
require.True(t, ok, "expected summary for child chat")
|
||||
assert.Equal(t, int64(2), childSummary.MessageCount)
|
||||
assert.Equal(t, int64(1), childSummary.UserMessageCount)
|
||||
assert.Equal(t, int64(1), childSummary.AssistantMessageCount)
|
||||
assert.Equal(t, int64(1100), childSummary.TotalInputTokens) // 500+600
|
||||
assert.Equal(t, int64(200), childSummary.TotalOutputTokens) // 0+200
|
||||
assert.Equal(t, int64(50), childSummary.TotalReasoningTokens) // 0+50
|
||||
assert.Equal(t, int64(0), childSummary.ToolMessageCount)
|
||||
assert.Equal(t, int64(0), childSummary.SystemMessageCount)
|
||||
assert.Equal(t, int64(100), childSummary.TotalCacheCreationTokens) // 100+0
|
||||
assert.Equal(t, int64(75), childSummary.TotalCacheReadTokens) // 0+75
|
||||
assert.Equal(t, int64(13000), childSummary.TotalCostMicros) // 5000+8000
|
||||
assert.Equal(t, int64(1200), childSummary.TotalRuntimeMs) // 0+1200
|
||||
assert.Equal(t, int64(1), childSummary.DistinctModelCount)
|
||||
assert.Equal(t, int64(1), childSummary.CompressedMessageCount)
|
||||
|
||||
// --- Assert ChatModelConfigs ---
|
||||
require.Len(t, snapshot.ChatModelConfigs, 2)
|
||||
|
||||
configMap := make(map[uuid.UUID]telemetry.ChatModelConfig)
|
||||
for _, c := range snapshot.ChatModelConfigs {
|
||||
configMap[c.ID] = c
|
||||
}
|
||||
|
||||
cfg1, ok := configMap[modelCfg.ID]
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "anthropic", cfg1.Provider)
|
||||
assert.Equal(t, "claude-sonnet-4-20250514", cfg1.Model)
|
||||
assert.Equal(t, int64(200000), cfg1.ContextLimit)
|
||||
assert.True(t, cfg1.Enabled)
|
||||
assert.True(t, cfg1.IsDefault)
|
||||
|
||||
cfg2, ok := configMap[modelCfg2.ID]
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "openai", cfg2.Provider)
|
||||
assert.Equal(t, "gpt-4o", cfg2.Model)
|
||||
assert.Equal(t, int64(128000), cfg2.ContextLimit)
|
||||
assert.True(t, cfg2.Enabled)
|
||||
assert.False(t, cfg2.IsDefault)
|
||||
}
|
||||
|
||||
+8
-12
@@ -475,6 +475,14 @@ func (api *API) postUser(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
req.UserLoginType = codersdk.LoginTypeNone
|
||||
|
||||
// Service accounts are a Premium feature.
|
||||
if !api.Entitlements.Enabled(codersdk.FeatureServiceAccounts) {
|
||||
httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{
|
||||
Message: fmt.Sprintf("%s is a Premium feature. Contact sales!", codersdk.FeatureServiceAccounts.Humanize()),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else if req.UserLoginType == "" {
|
||||
// Default to password auth
|
||||
req.UserLoginType = codersdk.LoginTypePassword
|
||||
@@ -1630,18 +1638,6 @@ func (api *API) CreateUser(ctx context.Context, store database.Store, req Create
|
||||
rbacRoles = req.RBACRoles
|
||||
}
|
||||
|
||||
// When the agents experiment is enabled, auto-assign the
|
||||
// agents-access role so new users can use Coder Agents
|
||||
// without manual admin intervention. Skip this for OIDC
|
||||
// users when site role sync is enabled, because the sync
|
||||
// will overwrite roles on every login anyway — those
|
||||
// admins should use --oidc-user-role-default instead.
|
||||
if api.Experiments.Enabled(codersdk.ExperimentAgents) &&
|
||||
!(req.LoginType == database.LoginTypeOIDC && api.IDPSync.SiteRoleSyncEnabled()) &&
|
||||
!slices.Contains(rbacRoles, codersdk.RoleAgentsAccess) {
|
||||
rbacRoles = append(rbacRoles, codersdk.RoleAgentsAccess)
|
||||
}
|
||||
|
||||
var user database.User
|
||||
err := store.InTx(func(tx database.Store) error {
|
||||
orgRoles := make([]string, 0)
|
||||
|
||||
+5
-142
@@ -829,35 +829,6 @@ func TestPostUsers(t *testing.T) {
|
||||
assert.Equal(t, firstUser.OrganizationID, user.OrganizationIDs[0])
|
||||
})
|
||||
|
||||
// CreateWithAgentsExperiment verifies that new users
|
||||
// are auto-assigned the agents-access role when the
|
||||
// experiment is enabled. The experiment-disabled case
|
||||
// is implicitly covered by TestInitialRoles, which
|
||||
// asserts exactly [owner] with no experiment — it
|
||||
// would fail if agents-access leaked through.
|
||||
t.Run("CreateWithAgentsExperiment", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
dv := coderdtest.DeploymentValues(t)
|
||||
dv.Experiments = []string{string(codersdk.ExperimentAgents)}
|
||||
client := coderdtest.New(t, &coderdtest.Options{DeploymentValues: dv})
|
||||
firstUser := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
user, err := client.CreateUserWithOrgs(ctx, codersdk.CreateUserRequestWithOrgs{
|
||||
OrganizationIDs: []uuid.UUID{firstUser.OrganizationID},
|
||||
Email: "another@user.org",
|
||||
Username: "someone-else",
|
||||
Password: "SomeSecurePassword!",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
roles, err := client.UserRoles(ctx, user.Username)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, roles.Roles, codersdk.RoleAgentsAccess,
|
||||
"new user should have agents-access role when agents experiment is enabled")
|
||||
})
|
||||
|
||||
t.Run("CreateWithStatus", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
auditor := audit.NewMock()
|
||||
@@ -979,7 +950,7 @@ func TestPostUsers(t *testing.T) {
|
||||
require.Equal(t, found.LoginType, codersdk.LoginTypeOIDC)
|
||||
})
|
||||
|
||||
t.Run("ServiceAccount/OK", func(t *testing.T) {
|
||||
t.Run("ServiceAccount/Unlicensed", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
first := coderdtest.CreateFirstUser(t, client)
|
||||
@@ -987,98 +958,16 @@ func TestPostUsers(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
user, err := client.CreateUserWithOrgs(ctx, codersdk.CreateUserRequestWithOrgs{
|
||||
_, err := client.CreateUserWithOrgs(ctx, codersdk.CreateUserRequestWithOrgs{
|
||||
OrganizationIDs: []uuid.UUID{first.OrganizationID},
|
||||
Username: "service-acct-ok",
|
||||
UserLoginType: codersdk.LoginTypeNone,
|
||||
ServiceAccount: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.LoginTypeNone, user.LoginType)
|
||||
require.Empty(t, user.Email)
|
||||
require.Equal(t, "service-acct-ok", user.Username)
|
||||
require.Equal(t, codersdk.UserStatusDormant, user.Status)
|
||||
})
|
||||
|
||||
t.Run("ServiceAccount/WithEmail", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
first := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
_, err := client.CreateUserWithOrgs(ctx, codersdk.CreateUserRequestWithOrgs{
|
||||
OrganizationIDs: []uuid.UUID{first.OrganizationID},
|
||||
Username: "service-acct-email",
|
||||
Email: "should-not-have@email.com",
|
||||
ServiceAccount: true,
|
||||
})
|
||||
var apiErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &apiErr)
|
||||
require.Equal(t, http.StatusBadRequest, apiErr.StatusCode())
|
||||
require.Contains(t, apiErr.Message, "Email cannot be set for service accounts")
|
||||
})
|
||||
|
||||
t.Run("ServiceAccount/WithPassword", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
first := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
_, err := client.CreateUserWithOrgs(ctx, codersdk.CreateUserRequestWithOrgs{
|
||||
OrganizationIDs: []uuid.UUID{first.OrganizationID},
|
||||
Username: "service-acct-password",
|
||||
Password: "ShouldNotHavePassword123!",
|
||||
ServiceAccount: true,
|
||||
})
|
||||
var apiErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &apiErr)
|
||||
require.Equal(t, http.StatusBadRequest, apiErr.StatusCode())
|
||||
require.Contains(t, apiErr.Message, "Password cannot be set for service accounts")
|
||||
})
|
||||
|
||||
t.Run("ServiceAccount/WithInvalidLoginType", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
first := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
_, err := client.CreateUserWithOrgs(ctx, codersdk.CreateUserRequestWithOrgs{
|
||||
OrganizationIDs: []uuid.UUID{first.OrganizationID},
|
||||
Username: "service-acct-login-type",
|
||||
UserLoginType: codersdk.LoginTypePassword,
|
||||
ServiceAccount: true,
|
||||
})
|
||||
var apiErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &apiErr)
|
||||
require.Equal(t, http.StatusBadRequest, apiErr.StatusCode())
|
||||
require.Contains(t, apiErr.Message, "Service accounts must use login type 'none'")
|
||||
})
|
||||
|
||||
t.Run("ServiceAccount/DefaultLoginType", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
first := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
user, err := client.CreateUserWithOrgs(ctx, codersdk.CreateUserRequestWithOrgs{
|
||||
OrganizationIDs: []uuid.UUID{first.OrganizationID},
|
||||
Username: "service-acct-default-login",
|
||||
ServiceAccount: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := client.User(ctx, user.ID.String())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.LoginTypeNone, found.LoginType)
|
||||
require.Empty(t, found.Email)
|
||||
require.Equal(t, http.StatusForbidden, apiErr.StatusCode())
|
||||
require.Contains(t, apiErr.Message, "Premium feature")
|
||||
})
|
||||
|
||||
t.Run("NonServiceAccount/WithoutEmail", func(t *testing.T) {
|
||||
@@ -1098,32 +987,6 @@ func TestPostUsers(t *testing.T) {
|
||||
require.ErrorAs(t, err, &apiErr)
|
||||
require.Equal(t, http.StatusBadRequest, apiErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("ServiceAccount/MultipleWithoutEmail", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
first := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
user1, err := client.CreateUserWithOrgs(ctx, codersdk.CreateUserRequestWithOrgs{
|
||||
OrganizationIDs: []uuid.UUID{first.OrganizationID},
|
||||
Username: "service-acct-multi-1",
|
||||
ServiceAccount: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, user1.Email)
|
||||
|
||||
user2, err := client.CreateUserWithOrgs(ctx, codersdk.CreateUserRequestWithOrgs{
|
||||
OrganizationIDs: []uuid.UUID{first.OrganizationID},
|
||||
Username: "service-acct-multi-2",
|
||||
ServiceAccount: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, user2.Email)
|
||||
require.NotEqual(t, user1.ID, user2.ID)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNotifyCreatedUser(t *testing.T) {
|
||||
@@ -1832,7 +1695,7 @@ func TestGetUsersFilter(t *testing.T) {
|
||||
setupCtx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
coderdtest.UsersFilter(setupCtx, t, client, api.Database, nil, func(testCtx context.Context, req codersdk.UsersRequest) []codersdk.ReducedUser {
|
||||
coderdtest.UsersFilter(setupCtx, t, client, api.Database, nil, nil, func(testCtx context.Context, req codersdk.UsersRequest) []codersdk.ReducedUser {
|
||||
res, err := client.Users(testCtx, req)
|
||||
require.NoError(t, err)
|
||||
reduced := make([]codersdk.ReducedUser, len(res.Users))
|
||||
|
||||
@@ -0,0 +1,280 @@
|
||||
package coderd
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// @Summary Create a new user secret
|
||||
// @ID create-a-new-user-secret
|
||||
// @Security CoderSessionToken
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Tags Secrets
|
||||
// @Param user path string true "User ID, username, or me"
|
||||
// @Param request body codersdk.CreateUserSecretRequest true "Create secret request"
|
||||
// @Success 201 {object} codersdk.UserSecret
|
||||
// @Router /users/{user}/secrets [post]
|
||||
func (api *API) postUserSecret(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
user := httpmw.UserParam(r)
|
||||
|
||||
var req codersdk.CreateUserSecretRequest
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Name is required.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if req.Value == "" {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Value is required.",
|
||||
})
|
||||
return
|
||||
}
|
||||
envOpts := codersdk.UserSecretEnvValidationOptions{
|
||||
AIGatewayEnabled: api.DeploymentValues.AI.BridgeConfig.Enabled.Value(),
|
||||
}
|
||||
if err := codersdk.UserSecretEnvNameValid(req.EnvName, envOpts); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid environment variable name.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if err := codersdk.UserSecretFilePathValid(req.FilePath); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid file path.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
secret, err := api.Database.CreateUserSecret(ctx, database.CreateUserSecretParams{
|
||||
ID: uuid.New(),
|
||||
UserID: user.ID,
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Value: req.Value,
|
||||
ValueKeyID: sql.NullString{},
|
||||
EnvName: req.EnvName,
|
||||
FilePath: req.FilePath,
|
||||
})
|
||||
if err != nil {
|
||||
if database.IsUniqueViolation(err) {
|
||||
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
||||
Message: "A secret with that name, environment variable, or file path already exists.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error creating secret.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusCreated, db2sdk.UserSecretFromFull(secret))
|
||||
}
|
||||
|
||||
// @Summary List user secrets
|
||||
// @ID list-user-secrets
|
||||
// @Security CoderSessionToken
|
||||
// @Produce json
|
||||
// @Tags Secrets
|
||||
// @Param user path string true "User ID, username, or me"
|
||||
// @Success 200 {array} codersdk.UserSecret
|
||||
// @Router /users/{user}/secrets [get]
|
||||
func (api *API) getUserSecrets(rw http.ResponseWriter, r *http.Request) { //nolint:revive // Method name matches route.
|
||||
ctx := r.Context()
|
||||
user := httpmw.UserParam(r)
|
||||
|
||||
secrets, err := api.Database.ListUserSecrets(ctx, user.ID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error listing secrets.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.UserSecrets(secrets))
|
||||
}
|
||||
|
||||
// @Summary Get a user secret by name
|
||||
// @ID get-a-user-secret-by-name
|
||||
// @Security CoderSessionToken
|
||||
// @Produce json
|
||||
// @Tags Secrets
|
||||
// @Param user path string true "User ID, username, or me"
|
||||
// @Param name path string true "Secret name"
|
||||
// @Success 200 {object} codersdk.UserSecret
|
||||
// @Router /users/{user}/secrets/{name} [get]
|
||||
func (api *API) getUserSecret(rw http.ResponseWriter, r *http.Request) { //nolint:revive // Method name matches route.
|
||||
ctx := r.Context()
|
||||
user := httpmw.UserParam(r)
|
||||
name := chi.URLParam(r, "name")
|
||||
|
||||
secret, err := api.Database.GetUserSecretByUserIDAndName(ctx, database.GetUserSecretByUserIDAndNameParams{
|
||||
UserID: user.ID,
|
||||
Name: name,
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching secret.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.UserSecretFromFull(secret))
|
||||
}
|
||||
|
||||
// @Summary Update a user secret
|
||||
// @ID update-a-user-secret
|
||||
// @Security CoderSessionToken
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Tags Secrets
|
||||
// @Param user path string true "User ID, username, or me"
|
||||
// @Param name path string true "Secret name"
|
||||
// @Param request body codersdk.UpdateUserSecretRequest true "Update secret request"
|
||||
// @Success 200 {object} codersdk.UserSecret
|
||||
// @Router /users/{user}/secrets/{name} [patch]
|
||||
func (api *API) patchUserSecret(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
user := httpmw.UserParam(r)
|
||||
name := chi.URLParam(r, "name")
|
||||
|
||||
var req codersdk.UpdateUserSecretRequest
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.Value == nil && req.Description == nil && req.EnvName == nil && req.FilePath == nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "At least one field must be provided.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if req.EnvName != nil {
|
||||
envOpts := codersdk.UserSecretEnvValidationOptions{
|
||||
AIGatewayEnabled: api.DeploymentValues.AI.BridgeConfig.Enabled.Value(),
|
||||
}
|
||||
if err := codersdk.UserSecretEnvNameValid(*req.EnvName, envOpts); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid environment variable name.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
if req.FilePath != nil {
|
||||
if err := codersdk.UserSecretFilePathValid(*req.FilePath); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid file path.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
params := database.UpdateUserSecretByUserIDAndNameParams{
|
||||
UserID: user.ID,
|
||||
Name: name,
|
||||
UpdateValue: req.Value != nil,
|
||||
Value: "",
|
||||
ValueKeyID: sql.NullString{},
|
||||
UpdateDescription: req.Description != nil,
|
||||
Description: "",
|
||||
UpdateEnvName: req.EnvName != nil,
|
||||
EnvName: "",
|
||||
UpdateFilePath: req.FilePath != nil,
|
||||
FilePath: "",
|
||||
}
|
||||
if req.Value != nil {
|
||||
params.Value = *req.Value
|
||||
}
|
||||
if req.Description != nil {
|
||||
params.Description = *req.Description
|
||||
}
|
||||
if req.EnvName != nil {
|
||||
params.EnvName = *req.EnvName
|
||||
}
|
||||
if req.FilePath != nil {
|
||||
params.FilePath = *req.FilePath
|
||||
}
|
||||
|
||||
secret, err := api.Database.UpdateUserSecretByUserIDAndName(ctx, params)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
if database.IsUniqueViolation(err) {
|
||||
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
||||
Message: "Update would conflict with an existing secret.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error updating secret.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.UserSecretFromFull(secret))
|
||||
}
|
||||
|
||||
// @Summary Delete a user secret
|
||||
// @ID delete-a-user-secret
|
||||
// @Security CoderSessionToken
|
||||
// @Tags Secrets
|
||||
// @Param user path string true "User ID, username, or me"
|
||||
// @Param name path string true "Secret name"
|
||||
// @Success 204
|
||||
// @Router /users/{user}/secrets/{name} [delete]
|
||||
func (api *API) deleteUserSecret(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
user := httpmw.UserParam(r)
|
||||
name := chi.URLParam(r, "name")
|
||||
|
||||
rowsAffected, err := api.Database.DeleteUserSecretByUserIDAndName(ctx, database.DeleteUserSecretByUserIDAndNameParams{
|
||||
UserID: user.ID,
|
||||
Name: name,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error deleting secret.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if rowsAffected == 0 {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
@@ -0,0 +1,413 @@
|
||||
package coderd_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestPostUserSecret(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
secret, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "github-token",
|
||||
Value: "ghp_xxxxxxxxxxxx",
|
||||
Description: "Personal GitHub PAT",
|
||||
EnvName: "GITHUB_TOKEN",
|
||||
FilePath: "~/.github-token",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "github-token", secret.Name)
|
||||
assert.Equal(t, "Personal GitHub PAT", secret.Description)
|
||||
assert.Equal(t, "GITHUB_TOKEN", secret.EnvName)
|
||||
assert.Equal(t, "~/.github-token", secret.FilePath)
|
||||
assert.NotZero(t, secret.ID)
|
||||
assert.NotZero(t, secret.CreatedAt)
|
||||
})
|
||||
|
||||
t.Run("MissingName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Value: "some-value",
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
assert.Contains(t, sdkErr.Message, "Name is required")
|
||||
})
|
||||
|
||||
t.Run("MissingValue", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "missing-value-secret",
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
assert.Contains(t, sdkErr.Message, "Value is required")
|
||||
})
|
||||
|
||||
t.Run("DuplicateName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "dup-secret",
|
||||
Value: "value1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "dup-secret",
|
||||
Value: "value2",
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusConflict, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("DuplicateEnvName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "env-dup-1",
|
||||
Value: "value1",
|
||||
EnvName: "DUPLICATE_ENV",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "env-dup-2",
|
||||
Value: "value2",
|
||||
EnvName: "DUPLICATE_ENV",
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusConflict, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("DuplicateFilePath", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "fp-dup-1",
|
||||
Value: "value1",
|
||||
FilePath: "/tmp/dup-file",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "fp-dup-2",
|
||||
Value: "value2",
|
||||
FilePath: "/tmp/dup-file",
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusConflict, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("InvalidEnvName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "invalid-env-secret",
|
||||
Value: "value",
|
||||
EnvName: "1INVALID",
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("ReservedEnvName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "reserved-env-secret",
|
||||
Value: "value",
|
||||
EnvName: "PATH",
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("CoderPrefixEnvName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "coder-prefix-secret",
|
||||
Value: "value",
|
||||
EnvName: "CODER_AGENT_TOKEN",
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("InvalidFilePath", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "bad-path-secret",
|
||||
Value: "value",
|
||||
FilePath: "relative/path",
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetUserSecrets(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
// Verify no secrets exist on a fresh user.
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
secrets, err := client.UserSecrets(ctx, codersdk.Me)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, secrets)
|
||||
|
||||
t.Run("WithSecrets", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "list-secret-a",
|
||||
Value: "value-a",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "list-secret-b",
|
||||
Value: "value-b",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
secrets, err := client.UserSecrets(ctx, codersdk.Me)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, secrets, 2)
|
||||
// Sorted by name.
|
||||
assert.Equal(t, "list-secret-a", secrets[0].Name)
|
||||
assert.Equal(t, "list-secret-b", secrets[1].Name)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetUserSecret(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
t.Run("Found", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
created, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "get-found-secret",
|
||||
Value: "my-value",
|
||||
EnvName: "GET_FOUND_SECRET",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := client.UserSecretByName(ctx, codersdk.Me, "get-found-secret")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, created.ID, got.ID)
|
||||
assert.Equal(t, "get-found-secret", got.Name)
|
||||
assert.Equal(t, "GET_FOUND_SECRET", got.EnvName)
|
||||
})
|
||||
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.UserSecretByName(ctx, codersdk.Me, "nonexistent")
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
|
||||
})
|
||||
}
|
||||
|
||||
func TestPatchUserSecret(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
t.Run("UpdateDescription", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "patch-desc-secret",
|
||||
Value: "my-value",
|
||||
Description: "original",
|
||||
EnvName: "PATCH_DESC_ENV",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
newDesc := "updated"
|
||||
updated, err := client.UpdateUserSecret(ctx, codersdk.Me, "patch-desc-secret", codersdk.UpdateUserSecretRequest{
|
||||
Description: &newDesc,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "updated", updated.Description)
|
||||
// Other fields unchanged.
|
||||
assert.Equal(t, "PATCH_DESC_ENV", updated.EnvName)
|
||||
})
|
||||
|
||||
t.Run("NoFields", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "patch-nofields-secret",
|
||||
Value: "my-value",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.UpdateUserSecret(ctx, codersdk.Me, "patch-nofields-secret", codersdk.UpdateUserSecretRequest{})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
newVal := "new-value"
|
||||
_, err := client.UpdateUserSecret(ctx, codersdk.Me, "nonexistent", codersdk.UpdateUserSecretRequest{
|
||||
Value: &newVal,
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("ConflictEnvName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "conflict-env-1",
|
||||
Value: "value1",
|
||||
EnvName: "CONFLICT_TAKEN_ENV",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "conflict-env-2",
|
||||
Value: "value2",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
taken := "CONFLICT_TAKEN_ENV"
|
||||
_, err = client.UpdateUserSecret(ctx, codersdk.Me, "conflict-env-2", codersdk.UpdateUserSecretRequest{
|
||||
EnvName: &taken,
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusConflict, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("ConflictFilePath", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "conflict-fp-1",
|
||||
Value: "value1",
|
||||
FilePath: "/tmp/conflict-taken",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "conflict-fp-2",
|
||||
Value: "value2",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
taken := "/tmp/conflict-taken"
|
||||
_, err = client.UpdateUserSecret(ctx, codersdk.Me, "conflict-fp-2", codersdk.UpdateUserSecretRequest{
|
||||
FilePath: &taken,
|
||||
})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusConflict, sdkErr.StatusCode())
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeleteUserSecret(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
_, err := client.CreateUserSecret(ctx, codersdk.Me, codersdk.CreateUserSecretRequest{
|
||||
Name: "delete-me-secret",
|
||||
Value: "my-value",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = client.DeleteUserSecret(ctx, codersdk.Me, "delete-me-secret")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify it's gone.
|
||||
_, err = client.UserSecretByName(ctx, codersdk.Me, "delete-me-secret")
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
err := client.DeleteUserSecret(ctx, codersdk.Me, "nonexistent")
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
assert.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
|
||||
})
|
||||
}
|
||||
+600
-2
@@ -42,6 +42,8 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/telemetry"
|
||||
maputil "github.com/coder/coder/v2/coderd/util/maps"
|
||||
"github.com/coder/coder/v2/coderd/wspubsub"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/x/gitsync"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
@@ -181,8 +183,9 @@ func (api *API) patchWorkspaceAgentLogs(rw http.ResponseWriter, r *http.Request)
|
||||
level := make([]database.LogLevel, 0)
|
||||
outputLength := 0
|
||||
for _, logEntry := range req.Logs {
|
||||
output = append(output, logEntry.Output)
|
||||
outputLength += len(logEntry.Output)
|
||||
sanitizedOutput := agentsdk.SanitizeLogOutput(logEntry.Output)
|
||||
output = append(output, sanitizedOutput)
|
||||
outputLength += len(sanitizedOutput)
|
||||
if logEntry.Level == "" {
|
||||
// Default to "info" to support older agents that didn't have the level field.
|
||||
logEntry.Level = codersdk.LogLevelInfo
|
||||
@@ -2392,3 +2395,598 @@ func convertWorkspaceAgentLogs(logs []database.WorkspaceAgentLog) []codersdk.Wor
|
||||
}
|
||||
return sdk
|
||||
}
|
||||
|
||||
// maxChatContextParts caps the number of parts per request to
|
||||
// prevent unbounded message payloads.
|
||||
const maxChatContextParts = 100
|
||||
|
||||
// maxChatContextFileBytes caps each context-file part to the same
|
||||
// 64KiB budget used when the agent reads instruction files from disk.
|
||||
const maxChatContextFileBytes = 64 * 1024
|
||||
|
||||
// maxChatContextRequestBodyBytes caps the JSON request body size for
|
||||
// agent-added context to roughly the same per-part budget used when
|
||||
// reading instruction files from disk.
|
||||
const maxChatContextRequestBodyBytes int64 = maxChatContextParts * maxChatContextFileBytes
|
||||
|
||||
// sanitizeWorkspaceAgentContextFileContent applies prompt
|
||||
// sanitization, then enforces the 64KiB per-file budget. The
|
||||
// truncated flag is preserved when the caller already capped the
|
||||
// file before sending it.
|
||||
func sanitizeWorkspaceAgentContextFileContent(
|
||||
content string,
|
||||
truncated bool,
|
||||
) (string, bool) {
|
||||
content = chatd.SanitizePromptText(content)
|
||||
if len(content) > maxChatContextFileBytes {
|
||||
content = content[:maxChatContextFileBytes]
|
||||
truncated = true
|
||||
}
|
||||
return content, truncated
|
||||
}
|
||||
|
||||
// readChatContextBody reads and validates the request body for chat
|
||||
// context endpoints. It handles MaxBytesReader wrapping, error
|
||||
// responses, and body rewind. If the body is empty or whitespace-only
|
||||
// and allowEmpty is true, it returns false without writing an error.
|
||||
//
|
||||
//nolint:revive // Add and clear endpoints only differ by empty-body handling.
|
||||
func readChatContextBody(ctx context.Context, rw http.ResponseWriter, r *http.Request, dst any, allowEmpty bool) bool {
|
||||
r.Body = http.MaxBytesReader(rw, r.Body, maxChatContextRequestBodyBytes)
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
var maxBytesErr *http.MaxBytesError
|
||||
if errors.As(err, &maxBytesErr) {
|
||||
httpapi.Write(ctx, rw, http.StatusRequestEntityTooLarge, codersdk.Response{
|
||||
Message: "Request body too large.",
|
||||
Detail: fmt.Sprintf("Maximum request body size is %d bytes.", maxChatContextRequestBodyBytes),
|
||||
})
|
||||
return false
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Failed to read request body.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return false
|
||||
}
|
||||
if allowEmpty && len(bytes.TrimSpace(body)) == 0 {
|
||||
r.Body = http.NoBody
|
||||
return false
|
||||
}
|
||||
|
||||
r.Body = io.NopCloser(bytes.NewReader(body))
|
||||
return httpapi.Read(ctx, rw, r, dst)
|
||||
}
|
||||
|
||||
// @x-apidocgen {"skip": true}
|
||||
func (api *API) workspaceAgentAddChatContext(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
workspaceAgent := httpmw.WorkspaceAgent(r)
|
||||
|
||||
var req agentsdk.AddChatContextRequest
|
||||
if !readChatContextBody(ctx, rw, r, &req, false) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Parts) == 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "No context parts provided.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Parts) > maxChatContextParts {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: fmt.Sprintf("Too many context parts (%d). Maximum is %d.", len(req.Parts), maxChatContextParts),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Filter to only non-empty context-file and skill parts.
|
||||
filtered := chatd.FilterContextParts(req.Parts, false)
|
||||
if len(filtered) == 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "No context-file or skill parts provided.",
|
||||
})
|
||||
return
|
||||
}
|
||||
req.Parts = filtered
|
||||
responsePartCount := 0
|
||||
|
||||
// Use system context for chat operations since the
|
||||
// workspace agent scope does not include chat resources.
|
||||
// We verify agent-to-chat ownership explicitly below.
|
||||
//nolint:gocritic // Agent needs system access to read/write chat resources.
|
||||
sysCtx := dbauthz.AsSystemRestricted(ctx)
|
||||
workspace, err := api.Database.GetWorkspaceByAgentID(sysCtx, workspaceAgent.ID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to determine workspace from agent token.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
chat, err := resolveAgentChat(sysCtx, api.Database, workspaceAgent.ID, workspace.OwnerID, req.ChatID)
|
||||
if err != nil {
|
||||
writeAgentChatError(ctx, rw, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Stamp each persisted part with the agent identity. Context-file
|
||||
// parts also get server-authoritative workspace metadata.
|
||||
directory := workspaceAgent.ExpandedDirectory
|
||||
if directory == "" {
|
||||
directory = workspaceAgent.Directory
|
||||
}
|
||||
for i := range req.Parts {
|
||||
req.Parts[i].ContextFileAgentID = uuid.NullUUID{
|
||||
UUID: workspaceAgent.ID,
|
||||
Valid: true,
|
||||
}
|
||||
if req.Parts[i].Type != codersdk.ChatMessagePartTypeContextFile {
|
||||
continue
|
||||
}
|
||||
req.Parts[i].ContextFileContent, req.Parts[i].ContextFileTruncated = sanitizeWorkspaceAgentContextFileContent(
|
||||
req.Parts[i].ContextFileContent,
|
||||
req.Parts[i].ContextFileTruncated,
|
||||
)
|
||||
req.Parts[i].ContextFileOS = workspaceAgent.OperatingSystem
|
||||
req.Parts[i].ContextFileDirectory = directory
|
||||
}
|
||||
req.Parts = chatd.FilterContextParts(req.Parts, false)
|
||||
if len(req.Parts) == 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "No context-file or skill parts provided.",
|
||||
})
|
||||
return
|
||||
}
|
||||
responsePartCount = len(req.Parts)
|
||||
|
||||
// Skill-only messages need a sentinel context-file part so the turn
|
||||
// pipeline trusts the associated skill metadata.
|
||||
req.Parts = prependAgentChatContextSentinelIfNeeded(
|
||||
req.Parts,
|
||||
workspaceAgent.ID,
|
||||
workspaceAgent.OperatingSystem,
|
||||
directory,
|
||||
)
|
||||
|
||||
content, err := chatprompt.MarshalParts(req.Parts)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to marshal context parts.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
err = api.Database.InTx(func(tx database.Store) error {
|
||||
locked, err := tx.GetChatByIDForUpdate(sysCtx, chat.ID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("lock chat: %w", err)
|
||||
}
|
||||
if !isActiveAgentChat(locked) {
|
||||
return errChatNotActive
|
||||
}
|
||||
if !locked.AgentID.Valid || locked.AgentID.UUID != workspaceAgent.ID {
|
||||
return errChatDoesNotBelongToAgent
|
||||
}
|
||||
if locked.OwnerID != workspace.OwnerID {
|
||||
return errChatDoesNotBelongToWorkspaceOwner
|
||||
}
|
||||
if _, err := tx.InsertChatMessages(sysCtx, chatd.BuildSingleChatMessageInsertParams(
|
||||
chat.ID,
|
||||
database.ChatMessageRoleUser,
|
||||
content,
|
||||
database.ChatMessageVisibilityBoth,
|
||||
locked.LastModelConfigID,
|
||||
chatprompt.CurrentContentVersion,
|
||||
uuid.Nil,
|
||||
)); err != nil {
|
||||
return xerrors.Errorf("insert context message: %w", err)
|
||||
}
|
||||
if err := updateAgentChatLastInjectedContextFromMessages(sysCtx, api.Logger, tx, chat.ID); err != nil {
|
||||
return xerrors.Errorf("rebuild injected context cache: %w", err)
|
||||
}
|
||||
return nil
|
||||
}, nil)
|
||||
if err != nil {
|
||||
if errors.Is(err, errChatNotActive) || errors.Is(err, errChatDoesNotBelongToAgent) || errors.Is(err, errChatDoesNotBelongToWorkspaceOwner) {
|
||||
writeAgentChatError(ctx, rw, err)
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to persist context message.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, agentsdk.AddChatContextResponse{
|
||||
ChatID: chat.ID,
|
||||
Count: responsePartCount,
|
||||
})
|
||||
}
|
||||
|
||||
// @x-apidocgen {"skip": true}
|
||||
func (api *API) workspaceAgentClearChatContext(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
workspaceAgent := httpmw.WorkspaceAgent(r)
|
||||
|
||||
var req agentsdk.ClearChatContextRequest
|
||||
populated := readChatContextBody(ctx, rw, r, &req, true)
|
||||
if !populated && r.Body != http.NoBody {
|
||||
return
|
||||
}
|
||||
|
||||
// Use system context for chat operations since the
|
||||
// workspace agent scope does not include chat resources.
|
||||
//nolint:gocritic // Agent needs system access to read/write chat resources.
|
||||
sysCtx := dbauthz.AsSystemRestricted(ctx)
|
||||
workspace, err := api.Database.GetWorkspaceByAgentID(sysCtx, workspaceAgent.ID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to determine workspace from agent token.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
chat, err := resolveAgentChat(sysCtx, api.Database, workspaceAgent.ID, workspace.OwnerID, req.ChatID)
|
||||
if err != nil {
|
||||
// Zero active chats is not an error for clear.
|
||||
if errors.Is(err, errNoActiveChats) {
|
||||
httpapi.Write(ctx, rw, http.StatusOK, agentsdk.ClearChatContextResponse{})
|
||||
return
|
||||
}
|
||||
writeAgentChatError(ctx, rw, err)
|
||||
return
|
||||
}
|
||||
|
||||
err = clearAgentChatContext(sysCtx, api.Database, chat.ID, workspaceAgent.ID, workspace.OwnerID)
|
||||
if err != nil {
|
||||
if errors.Is(err, errChatNotActive) || errors.Is(err, errChatDoesNotBelongToAgent) || errors.Is(err, errChatDoesNotBelongToWorkspaceOwner) {
|
||||
writeAgentChatError(ctx, rw, err)
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to clear context from chat.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, agentsdk.ClearChatContextResponse{
|
||||
ChatID: chat.ID,
|
||||
})
|
||||
}
|
||||
|
||||
var (
|
||||
errNoActiveChats = xerrors.New("no active chats found")
|
||||
errChatNotFound = xerrors.New("chat not found")
|
||||
errChatNotActive = xerrors.New("chat is not active")
|
||||
errChatDoesNotBelongToAgent = xerrors.New("chat does not belong to this agent")
|
||||
errChatDoesNotBelongToWorkspaceOwner = xerrors.New("chat does not belong to this workspace owner")
|
||||
)
|
||||
|
||||
type multipleActiveChatsError struct {
|
||||
count int
|
||||
}
|
||||
|
||||
func (e *multipleActiveChatsError) Error() string {
|
||||
return fmt.Sprintf(
|
||||
"multiple active chats (%d) found for this agent, specify a chat ID",
|
||||
e.count,
|
||||
)
|
||||
}
|
||||
|
||||
func resolveDefaultAgentChat(chats []database.Chat) (database.Chat, error) {
|
||||
switch len(chats) {
|
||||
case 0:
|
||||
return database.Chat{}, errNoActiveChats
|
||||
case 1:
|
||||
return chats[0], nil
|
||||
}
|
||||
|
||||
var rootChat *database.Chat
|
||||
for i := range chats {
|
||||
chat := &chats[i]
|
||||
if chat.ParentChatID.Valid {
|
||||
continue
|
||||
}
|
||||
if rootChat != nil {
|
||||
return database.Chat{}, &multipleActiveChatsError{count: len(chats)}
|
||||
}
|
||||
rootChat = chat
|
||||
}
|
||||
if rootChat != nil {
|
||||
return *rootChat, nil
|
||||
}
|
||||
return database.Chat{}, &multipleActiveChatsError{count: len(chats)}
|
||||
}
|
||||
|
||||
// resolveAgentChat finds the target chat from either an explicit ID
|
||||
// or auto-detection via the agent's active chats.
|
||||
func resolveAgentChat(
|
||||
ctx context.Context,
|
||||
db database.Store,
|
||||
agentID uuid.UUID,
|
||||
workspaceOwnerID uuid.UUID,
|
||||
explicitChatID uuid.UUID,
|
||||
) (database.Chat, error) {
|
||||
if explicitChatID == uuid.Nil {
|
||||
chats, err := db.GetActiveChatsByAgentID(ctx, agentID)
|
||||
if err != nil {
|
||||
return database.Chat{}, xerrors.Errorf("list active chats: %w", err)
|
||||
}
|
||||
ownerChats := make([]database.Chat, 0, len(chats))
|
||||
for _, chat := range chats {
|
||||
if chat.OwnerID != workspaceOwnerID {
|
||||
continue
|
||||
}
|
||||
ownerChats = append(ownerChats, chat)
|
||||
}
|
||||
return resolveDefaultAgentChat(ownerChats)
|
||||
}
|
||||
|
||||
chat, err := db.GetChatByID(ctx, explicitChatID)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return database.Chat{}, errChatNotFound
|
||||
}
|
||||
return database.Chat{}, xerrors.Errorf("get chat by id: %w", err)
|
||||
}
|
||||
if !chat.AgentID.Valid || chat.AgentID.UUID != agentID {
|
||||
return database.Chat{}, errChatDoesNotBelongToAgent
|
||||
}
|
||||
if chat.OwnerID != workspaceOwnerID {
|
||||
return database.Chat{}, errChatDoesNotBelongToWorkspaceOwner
|
||||
}
|
||||
if !isActiveAgentChat(chat) {
|
||||
return database.Chat{}, errChatNotActive
|
||||
}
|
||||
return chat, nil
|
||||
}
|
||||
|
||||
func isActiveAgentChat(chat database.Chat) bool {
|
||||
if chat.Archived {
|
||||
return false
|
||||
}
|
||||
|
||||
switch chat.Status {
|
||||
case database.ChatStatusWaiting,
|
||||
database.ChatStatusPending,
|
||||
database.ChatStatusRunning,
|
||||
database.ChatStatusPaused,
|
||||
database.ChatStatusRequiresAction:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func clearAgentChatContext(
|
||||
ctx context.Context,
|
||||
db database.Store,
|
||||
chatID uuid.UUID,
|
||||
agentID uuid.UUID,
|
||||
workspaceOwnerID uuid.UUID,
|
||||
) error {
|
||||
return db.InTx(func(tx database.Store) error {
|
||||
locked, err := tx.GetChatByIDForUpdate(ctx, chatID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("lock chat: %w", err)
|
||||
}
|
||||
if !isActiveAgentChat(locked) {
|
||||
return errChatNotActive
|
||||
}
|
||||
if !locked.AgentID.Valid || locked.AgentID.UUID != agentID {
|
||||
return errChatDoesNotBelongToAgent
|
||||
}
|
||||
if locked.OwnerID != workspaceOwnerID {
|
||||
return errChatDoesNotBelongToWorkspaceOwner
|
||||
}
|
||||
messages, err := tx.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 0,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get chat messages: %w", err)
|
||||
}
|
||||
hadInjectedContext := locked.LastInjectedContext.Valid
|
||||
var skillOnlyMessageIDs []int64
|
||||
for _, msg := range messages {
|
||||
if !msg.Content.Valid {
|
||||
continue
|
||||
}
|
||||
hasContextFile := messageHasPartTypes(msg.Content.RawMessage, codersdk.ChatMessagePartTypeContextFile)
|
||||
hasSkill := messageHasPartTypes(msg.Content.RawMessage, codersdk.ChatMessagePartTypeSkill)
|
||||
if hasContextFile || hasSkill {
|
||||
hadInjectedContext = true
|
||||
}
|
||||
if hasSkill && !hasContextFile {
|
||||
skillOnlyMessageIDs = append(skillOnlyMessageIDs, msg.ID)
|
||||
}
|
||||
}
|
||||
if !hadInjectedContext {
|
||||
return nil
|
||||
}
|
||||
if err := tx.SoftDeleteContextFileMessages(ctx, chatID); err != nil {
|
||||
return xerrors.Errorf("soft delete context-file messages: %w", err)
|
||||
}
|
||||
for _, messageID := range skillOnlyMessageIDs {
|
||||
if err := tx.SoftDeleteChatMessageByID(ctx, messageID); err != nil {
|
||||
return xerrors.Errorf("soft delete context message %d: %w", messageID, err)
|
||||
}
|
||||
}
|
||||
// Reset provider-side Responses chaining so the next turn replays
|
||||
// the post-clear history instead of inheriting cleared context.
|
||||
if err := tx.ClearChatMessageProviderResponseIDsByChatID(ctx, chatID); err != nil {
|
||||
return xerrors.Errorf("clear provider response chain: %w", err)
|
||||
}
|
||||
// Clear the injected-context cache inside the transaction so it is
|
||||
// atomic with the soft-deletes.
|
||||
param, err := chatd.BuildLastInjectedContext(nil)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("clear injected context cache: %w", err)
|
||||
}
|
||||
if _, err := tx.UpdateChatLastInjectedContext(ctx, database.UpdateChatLastInjectedContextParams{
|
||||
ID: chatID,
|
||||
LastInjectedContext: param,
|
||||
}); err != nil {
|
||||
return xerrors.Errorf("clear injected context cache: %w", err)
|
||||
}
|
||||
return nil
|
||||
}, nil)
|
||||
}
|
||||
|
||||
// prependAgentChatContextSentinelIfNeeded adds an empty context-file
|
||||
// part when the request only carries skills. The turn pipeline uses
|
||||
// the sentinel's agent metadata to trust the skill parts.
|
||||
func prependAgentChatContextSentinelIfNeeded(
|
||||
parts []codersdk.ChatMessagePart,
|
||||
agentID uuid.UUID,
|
||||
operatingSystem string,
|
||||
directory string,
|
||||
) []codersdk.ChatMessagePart {
|
||||
hasContextFile := false
|
||||
hasSkill := false
|
||||
for _, part := range parts {
|
||||
switch part.Type {
|
||||
case codersdk.ChatMessagePartTypeContextFile:
|
||||
hasContextFile = true
|
||||
case codersdk.ChatMessagePartTypeSkill:
|
||||
hasSkill = true
|
||||
}
|
||||
if hasContextFile && hasSkill {
|
||||
return parts
|
||||
}
|
||||
}
|
||||
if !hasSkill || hasContextFile {
|
||||
return parts
|
||||
}
|
||||
return append([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: chatd.AgentChatContextSentinelPath,
|
||||
ContextFileAgentID: uuid.NullUUID{
|
||||
UUID: agentID,
|
||||
Valid: true,
|
||||
},
|
||||
ContextFileOS: operatingSystem,
|
||||
ContextFileDirectory: directory,
|
||||
}}, parts...)
|
||||
}
|
||||
|
||||
func sortChatMessagesByCreatedAtAndID(messages []database.ChatMessage) {
|
||||
sort.SliceStable(messages, func(i, j int) bool {
|
||||
if messages[i].CreatedAt.Equal(messages[j].CreatedAt) {
|
||||
return messages[i].ID < messages[j].ID
|
||||
}
|
||||
return messages[i].CreatedAt.Before(messages[j].CreatedAt)
|
||||
})
|
||||
}
|
||||
|
||||
// updateAgentChatLastInjectedContextFromMessages rebuilds the
|
||||
// injected-context cache from all persisted context-file and skill parts.
|
||||
func updateAgentChatLastInjectedContextFromMessages(
|
||||
ctx context.Context,
|
||||
logger slog.Logger,
|
||||
db database.Store,
|
||||
chatID uuid.UUID,
|
||||
) error {
|
||||
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 0,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("load context messages for injected context: %w", err)
|
||||
}
|
||||
|
||||
sortChatMessagesByCreatedAtAndID(messages)
|
||||
|
||||
parts, err := chatd.CollectContextPartsFromMessages(ctx, logger, messages, true)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("collect injected context parts: %w", err)
|
||||
}
|
||||
parts = chatd.FilterContextPartsToLatestAgent(parts)
|
||||
|
||||
param, err := chatd.BuildLastInjectedContext(parts)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("update injected context: %w", err)
|
||||
}
|
||||
if _, err := db.UpdateChatLastInjectedContext(ctx, database.UpdateChatLastInjectedContextParams{
|
||||
ID: chatID,
|
||||
LastInjectedContext: param,
|
||||
}); err != nil {
|
||||
return xerrors.Errorf("update injected context: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func messageHasPartTypes(raw []byte, types ...codersdk.ChatMessagePartType) bool {
|
||||
var parts []codersdk.ChatMessagePart
|
||||
if err := json.Unmarshal(raw, &parts); err != nil {
|
||||
return false
|
||||
}
|
||||
for _, part := range parts {
|
||||
for _, typ := range types {
|
||||
if part.Type == typ {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// writeAgentChatError translates resolveAgentChat errors to HTTP
|
||||
// responses.
|
||||
func writeAgentChatError(
|
||||
ctx context.Context,
|
||||
rw http.ResponseWriter,
|
||||
err error,
|
||||
) {
|
||||
if errors.Is(err, errNoActiveChats) {
|
||||
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
|
||||
Message: "No active chats found for this agent.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if errors.Is(err, errChatNotFound) {
|
||||
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
|
||||
Message: "Chat not found.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if errors.Is(err, errChatDoesNotBelongToAgent) {
|
||||
httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{
|
||||
Message: "Chat does not belong to this agent.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if errors.Is(err, errChatDoesNotBelongToWorkspaceOwner) {
|
||||
httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{
|
||||
Message: "Chat does not belong to this workspace owner.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if errors.Is(err, errChatNotActive) {
|
||||
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
||||
Message: "Cannot modify context: this chat is no longer active.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var multipleErr *multipleActiveChatsError
|
||||
if errors.As(err, &multipleErr) {
|
||||
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
||||
Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to resolve chat.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
package coderd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbfake"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestActiveAgentChatDefinitionsAgree(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitMedium))
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
|
||||
org, err := db.GetDefaultOrganization(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
owner := dbgen.User(t, db, database.User{})
|
||||
workspace := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OrganizationID: org.ID,
|
||||
OwnerID: owner.ID,
|
||||
}).WithAgent().Do()
|
||||
modelConfig := insertAgentChatTestModelConfig(ctx, t, db, owner.ID)
|
||||
|
||||
insertedChats := make([]database.Chat, 0, len(database.AllChatStatusValues())*2)
|
||||
for _, archived := range []bool{false, true} {
|
||||
for _, status := range database.AllChatStatusValues() {
|
||||
chat, err := db.InsertChat(ctx, database.InsertChatParams{
|
||||
Status: status,
|
||||
OwnerID: owner.ID,
|
||||
LastModelConfigID: modelConfig.ID,
|
||||
Title: fmt.Sprintf("%s-archived-%t", status, archived),
|
||||
AgentID: uuid.NullUUID{UUID: workspace.Agents[0].ID, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
if archived {
|
||||
_, err = db.ArchiveChatByID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err = db.GetChatByID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
insertedChats = append(insertedChats, chat)
|
||||
}
|
||||
}
|
||||
|
||||
activeChats, err := db.GetActiveChatsByAgentID(ctx, workspace.Agents[0].ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
activeByID := make(map[uuid.UUID]bool, len(activeChats))
|
||||
for _, chat := range activeChats {
|
||||
activeByID[chat.ID] = true
|
||||
}
|
||||
|
||||
for _, chat := range insertedChats {
|
||||
require.Equalf(
|
||||
t,
|
||||
isActiveAgentChat(chat),
|
||||
activeByID[chat.ID],
|
||||
"status=%s archived=%t",
|
||||
chat.Status,
|
||||
chat.Archived,
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,128 @@
|
||||
package coderd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
func TestUpdateAgentChatLastInjectedContextFromMessagesUsesMessageIDTieBreaker(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
chatID := uuid.New()
|
||||
createdAt := time.Date(2026, time.April, 9, 13, 0, 0, 0, time.UTC)
|
||||
oldAgentID := uuid.New()
|
||||
newAgentID := uuid.New()
|
||||
|
||||
oldContent, err := json.Marshal([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/old/AGENTS.md",
|
||||
ContextFileContent: "old instructions",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: oldAgentID, Valid: true},
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
newContent, err := json.Marshal([]codersdk.ChatMessagePart{{
|
||||
Type: codersdk.ChatMessagePartTypeContextFile,
|
||||
ContextFilePath: "/new/AGENTS.md",
|
||||
ContextFileContent: "new instructions",
|
||||
ContextFileAgentID: uuid.NullUUID{UUID: newAgentID, Valid: true},
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
|
||||
db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 0,
|
||||
}).Return([]database.ChatMessage{
|
||||
{
|
||||
ID: 2,
|
||||
CreatedAt: createdAt,
|
||||
Content: pqtype.NullRawMessage{
|
||||
RawMessage: newContent,
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: 1,
|
||||
CreatedAt: createdAt,
|
||||
Content: pqtype.NullRawMessage{
|
||||
RawMessage: oldContent,
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
|
||||
db.EXPECT().UpdateChatLastInjectedContext(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(_ context.Context, arg database.UpdateChatLastInjectedContextParams) (database.Chat, error) {
|
||||
require.Equal(t, chatID, arg.ID)
|
||||
require.True(t, arg.LastInjectedContext.Valid)
|
||||
var cached []codersdk.ChatMessagePart
|
||||
require.NoError(t, json.Unmarshal(arg.LastInjectedContext.RawMessage, &cached))
|
||||
require.Len(t, cached, 1)
|
||||
require.Equal(t, "/new/AGENTS.md", cached[0].ContextFilePath)
|
||||
require.Equal(t, uuid.NullUUID{UUID: newAgentID, Valid: true}, cached[0].ContextFileAgentID)
|
||||
return database.Chat{}, nil
|
||||
},
|
||||
)
|
||||
|
||||
err = updateAgentChatLastInjectedContextFromMessages(
|
||||
context.Background(),
|
||||
slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
||||
db,
|
||||
chatID,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func insertAgentChatTestModelConfig(
|
||||
ctx context.Context,
|
||||
t testing.TB,
|
||||
db database.Store,
|
||||
userID uuid.UUID,
|
||||
) database.ChatModelConfig {
|
||||
t.Helper()
|
||||
|
||||
sysCtx := dbauthz.AsSystemRestricted(ctx)
|
||||
createdBy := uuid.NullUUID{UUID: userID, Valid: true}
|
||||
|
||||
_, err := db.InsertChatProvider(sysCtx, database.InsertChatProviderParams{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-api-key",
|
||||
ApiKeyKeyID: sql.NullString{},
|
||||
CreatedBy: createdBy,
|
||||
Enabled: true,
|
||||
CentralApiKeyEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
model, err := db.InsertChatModelConfig(sysCtx, database.InsertChatModelConfigParams{
|
||||
Provider: "openai",
|
||||
Model: "gpt-4o-mini",
|
||||
DisplayName: "Test Model",
|
||||
CreatedBy: createdBy,
|
||||
UpdatedBy: createdBy,
|
||||
Enabled: true,
|
||||
IsDefault: true,
|
||||
ContextLimit: 128000,
|
||||
CompressionThreshold: 70,
|
||||
Options: json.RawMessage(`{}`),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return model
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -91,7 +91,7 @@ func TestWorkspaceAgent(t *testing.T) {
|
||||
require.Equal(t, tmpDir, workspace.LatestBuild.Resources[0].Agents[0].Directory)
|
||||
_, err = anotherClient.WorkspaceAgent(ctx, workspace.LatestBuild.Resources[0].Agents[0].ID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, workspace.LatestBuild.Resources[0].Agents[0].Health.Healthy)
|
||||
require.False(t, workspace.LatestBuild.Resources[0].Agents[0].Health.Healthy)
|
||||
})
|
||||
t.Run("HasFallbackTroubleshootingURL", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
@@ -260,6 +260,50 @@ func TestWorkspaceAgentLogs(t *testing.T) {
|
||||
require.Equal(t, "testing", logChunk[0].Output)
|
||||
require.Equal(t, "testing2", logChunk[1].Output)
|
||||
})
|
||||
t.Run("SanitizesNulBytesAndTracksSanitizedLength", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
client, db := coderdtest.NewWithDatabase(t, nil)
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
|
||||
OrganizationID: user.OrganizationID,
|
||||
OwnerID: user.UserID,
|
||||
}).WithAgent().Do()
|
||||
|
||||
rawOutput := "before\x00after"
|
||||
sanitizedOutput := agentsdk.SanitizeLogOutput(rawOutput)
|
||||
agentClient := agentsdk.New(client.URL, agentsdk.WithFixedToken(r.AgentToken))
|
||||
err := agentClient.PatchLogs(ctx, agentsdk.PatchLogs{
|
||||
Logs: []agentsdk.Log{
|
||||
{
|
||||
CreatedAt: dbtime.Now(),
|
||||
Output: rawOutput,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
agent, err := db.GetWorkspaceAgentByID(dbauthz.AsSystemRestricted(ctx), r.Agents[0].ID)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, len(sanitizedOutput), agent.LogsLength)
|
||||
|
||||
workspace, err := client.Workspace(ctx, r.Workspace.ID)
|
||||
require.NoError(t, err)
|
||||
logs, closer, err := client.WorkspaceAgentLogsAfter(ctx, workspace.LatestBuild.Resources[0].Agents[0].ID, 0, true)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = closer.Close()
|
||||
}()
|
||||
|
||||
var logChunk []codersdk.WorkspaceAgentLog
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case logChunk = <-logs:
|
||||
}
|
||||
require.NoError(t, ctx.Err())
|
||||
require.Len(t, logChunk, 1)
|
||||
require.Equal(t, sanitizedOutput, logChunk[0].Output)
|
||||
})
|
||||
t.Run("Close logs on outdated build", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
@@ -730,10 +730,7 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
log := s.Logger.With(
|
||||
slog.F("agent_id", appToken.AgentID),
|
||||
slog.F("workspace_id", appToken.WorkspaceID),
|
||||
)
|
||||
log := s.Logger.With(slog.F("agent_id", appToken.AgentID))
|
||||
log.Debug(ctx, "resolved PTY request")
|
||||
|
||||
values := r.URL.Query()
|
||||
@@ -768,21 +765,19 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
return
|
||||
}
|
||||
go httpapi.HeartbeatClose(ctx, s.Logger, cancel, conn)
|
||||
|
||||
ctx, wsNetConn := WebsocketNetConn(ctx, conn, websocket.MessageBinary)
|
||||
defer wsNetConn.Close() // Also closes conn.
|
||||
|
||||
go httpapi.HeartbeatClose(ctx, log, cancel, conn)
|
||||
|
||||
dialStart := time.Now()
|
||||
|
||||
agentConn, release, err := s.AgentProvider.AgentConn(ctx, appToken.AgentID)
|
||||
if err != nil {
|
||||
log.Debug(ctx, "dial workspace agent", slog.Error(err), slog.F("elapsed_ms", time.Since(dialStart).Milliseconds()))
|
||||
log.Debug(ctx, "dial workspace agent", slog.Error(err))
|
||||
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial workspace agent: %s", err))
|
||||
return
|
||||
}
|
||||
defer release()
|
||||
log.Debug(ctx, "dialed workspace agent", slog.F("elapsed_ms", time.Since(dialStart).Milliseconds()))
|
||||
log.Debug(ctx, "dialed workspace agent")
|
||||
// #nosec G115 - Safe conversion for terminal height/width which are expected to be within uint16 range (0-65535)
|
||||
ptNetConn, err := agentConn.ReconnectingPTY(ctx, reconnect, uint16(height), uint16(width), r.URL.Query().Get("command"), func(arp *workspacesdk.AgentReconnectingPTYInit) {
|
||||
arp.Container = container
|
||||
@@ -790,12 +785,12 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
|
||||
arp.BackendType = backendType
|
||||
})
|
||||
if err != nil {
|
||||
log.Debug(ctx, "dial reconnecting pty server in workspace agent", slog.Error(err), slog.F("elapsed_ms", time.Since(dialStart).Milliseconds()))
|
||||
log.Debug(ctx, "dial reconnecting pty server in workspace agent", slog.Error(err))
|
||||
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial: %s", err))
|
||||
return
|
||||
}
|
||||
defer ptNetConn.Close()
|
||||
log.Debug(ctx, "obtained PTY", slog.F("elapsed_ms", time.Since(dialStart).Milliseconds()))
|
||||
log.Debug(ctx, "obtained PTY")
|
||||
|
||||
report := newStatsReportFromSignedToken(*appToken)
|
||||
s.collectStats(report)
|
||||
@@ -805,7 +800,7 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
|
||||
}()
|
||||
|
||||
agentssh.Bicopy(ctx, wsNetConn, ptNetConn)
|
||||
log.Debug(ctx, "pty Bicopy finished", slog.F("elapsed_ms", time.Since(dialStart).Milliseconds()))
|
||||
log.Debug(ctx, "pty Bicopy finished")
|
||||
}
|
||||
|
||||
func (s *Server) collectStats(stats StatsReport) {
|
||||
|
||||
+69
-12
@@ -213,6 +213,39 @@ func TestWorkspace(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("Healthy", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
authToken := uuid.NewString()
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
||||
Parse: echo.ParseComplete,
|
||||
ProvisionGraph: echo.ProvisionGraphWithAgent(authToken),
|
||||
})
|
||||
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
|
||||
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
|
||||
workspace := coderdtest.CreateWorkspace(t, client, template.ID)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
|
||||
|
||||
_ = agenttest.New(t, client.URL, authToken)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
var err error
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
workspace, err = client.Workspace(ctx, workspace.ID)
|
||||
return assert.NoError(t, err) && workspace.Health.Healthy
|
||||
}, testutil.IntervalMedium)
|
||||
|
||||
agent := workspace.LatestBuild.Resources[0].Agents[0]
|
||||
|
||||
assert.True(t, workspace.Health.Healthy)
|
||||
assert.Equal(t, []uuid.UUID{}, workspace.Health.FailingAgents)
|
||||
assert.True(t, agent.Health.Healthy)
|
||||
assert.Empty(t, agent.Health.Reason)
|
||||
})
|
||||
|
||||
t.Run("Connecting", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
@@ -247,10 +280,10 @@ func TestWorkspace(t *testing.T) {
|
||||
|
||||
agent := workspace.LatestBuild.Resources[0].Agents[0]
|
||||
|
||||
assert.True(t, workspace.Health.Healthy)
|
||||
assert.Equal(t, []uuid.UUID{}, workspace.Health.FailingAgents)
|
||||
assert.True(t, agent.Health.Healthy)
|
||||
assert.Empty(t, agent.Health.Reason)
|
||||
assert.False(t, workspace.Health.Healthy)
|
||||
assert.Equal(t, []uuid.UUID{agent.ID}, workspace.Health.FailingAgents)
|
||||
assert.False(t, agent.Health.Healthy)
|
||||
assert.Equal(t, "agent has not yet connected", agent.Health.Reason)
|
||||
})
|
||||
|
||||
t.Run("Unhealthy", func(t *testing.T) {
|
||||
@@ -302,6 +335,7 @@ func TestWorkspace(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
a1AuthToken := uuid.NewString()
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
||||
Parse: echo.ParseComplete,
|
||||
ProvisionGraph: []*proto.Response{{
|
||||
@@ -313,7 +347,9 @@ func TestWorkspace(t *testing.T) {
|
||||
Agents: []*proto.Agent{{
|
||||
Id: uuid.NewString(),
|
||||
Name: "a1",
|
||||
Auth: &proto.Agent_Token{},
|
||||
Auth: &proto.Agent_Token{
|
||||
Token: a1AuthToken,
|
||||
},
|
||||
}, {
|
||||
Id: uuid.NewString(),
|
||||
Name: "a2",
|
||||
@@ -330,13 +366,21 @@ func TestWorkspace(t *testing.T) {
|
||||
workspace := coderdtest.CreateWorkspace(t, client, template.ID)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
|
||||
|
||||
_ = agenttest.New(t, client.URL, a1AuthToken)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
var err error
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
workspace, err = client.Workspace(ctx, workspace.ID)
|
||||
return assert.NoError(t, err) && !workspace.Health.Healthy
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
// Wait for the mixed state: a1 connected (healthy)
|
||||
// and workspace unhealthy (because a2 timed out).
|
||||
agent1 := workspace.LatestBuild.Resources[0].Agents[0]
|
||||
return agent1.Health.Healthy && !workspace.Health.Healthy
|
||||
}, testutil.IntervalMedium)
|
||||
|
||||
assert.False(t, workspace.Health.Healthy)
|
||||
@@ -360,6 +404,7 @@ func TestWorkspace(t *testing.T) {
|
||||
// disconnected, but this should not make the workspace unhealthy.
|
||||
client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
authToken := uuid.NewString()
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
||||
Parse: echo.ParseComplete,
|
||||
ProvisionGraph: []*proto.Response{{
|
||||
@@ -371,7 +416,9 @@ func TestWorkspace(t *testing.T) {
|
||||
Agents: []*proto.Agent{{
|
||||
Id: uuid.NewString(),
|
||||
Name: "parent",
|
||||
Auth: &proto.Agent_Token{},
|
||||
Auth: &proto.Agent_Token{
|
||||
Token: authToken,
|
||||
},
|
||||
}},
|
||||
}},
|
||||
},
|
||||
@@ -383,14 +430,23 @@ func TestWorkspace(t *testing.T) {
|
||||
workspace := coderdtest.CreateWorkspace(t, client, template.ID)
|
||||
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
|
||||
|
||||
_ = agenttest.New(t, client.URL, authToken)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
// Get the workspace and parent agent.
|
||||
workspace, err := client.Workspace(ctx, workspace.ID)
|
||||
require.NoError(t, err)
|
||||
parentAgent := workspace.LatestBuild.Resources[0].Agents[0]
|
||||
require.True(t, parentAgent.Health.Healthy, "parent agent should be healthy initially")
|
||||
// Wait for the parent agent to connect and be healthy.
|
||||
var parentAgent codersdk.WorkspaceAgent
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
var err error
|
||||
workspace, err = client.Workspace(ctx, workspace.ID)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
parentAgent = workspace.LatestBuild.Resources[0].Agents[0]
|
||||
return parentAgent.Health.Healthy
|
||||
}, testutil.IntervalMedium)
|
||||
require.True(t, parentAgent.Health.Healthy, "parent agent should be healthy")
|
||||
|
||||
// Create a sub-agent with a short connection timeout so it becomes
|
||||
// unhealthy quickly (simulating a devcontainer rebuild scenario).
|
||||
@@ -404,6 +460,7 @@ func TestWorkspace(t *testing.T) {
|
||||
// Wait for the sub-agent to become unhealthy due to timeout.
|
||||
var subAgentUnhealthy bool
|
||||
require.Eventually(t, func() bool {
|
||||
var err error
|
||||
workspace, err = client.Workspace(ctx, workspace.ID)
|
||||
if err != nil {
|
||||
return false
|
||||
|
||||
+711
-75
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user